/* gwbascii 1.0 - Detokenize/decode gwbasic/basica files */
/* Released to the public domain by its author, Arne de Bruijn, 2002. */

#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <stdarg.h>

#ifdef WIN32
#include <io.h>
#include <fcntl.h>
#endif

/* Number of elements in an array object (not meaningful for a pointer). */
#define ELMS(x) (sizeof(x)/sizeof((x)[0]))
/*
 * Keyword tables used to detokenize program text.  Entry k of each table
 * is the keyword for token byte 0x81 + k; the comment at the start of each
 * row gives the token byte of that row's first entry.  NULL entries are
 * bytes with no keyword; process_file() reports them as invalid tokens.
 *
 * kws0: single-byte tokens 0x81..0xF4 (statements, functions, operators).
 */
char *kws0[] = { 
 /*81*/ "END", "FOR", "NEXT", "DATA", "INPUT", "DIM", "READ", "LET",
 /*89*/ "GOTO", "RUN", "IF", "RESTORE", "GOSUB", "RETURN", "REM", "STOP",
 /*91*/ "PRINT", "CLEAR", "LIST", "NEW", "ON", "WAIT", "DEF", "POKE",
 /*99*/ "CONT", NULL, NULL, "OUT", "LPRINT", "LLIST", NULL, "WIDTH",
 /*a1*/ "ELSE", "TRON", "TROFF", "SWAP", "ERASE", "EDIT", "ERROR", "RESUME",
 /*a9*/ "DELETE", "AUTO", "RENUM", "DEFSTR", "DEFINT", "DEFSNG", "DEFDBL", "LINE",
 /*b1*/ "WHILE", "WEND", "CALL", NULL, NULL, NULL, "WRITE", "OPTION",
 /*b9*/ "RANDOMIZE", "OPEN", "CLOSE", "LOAD", "MERGE", "SAVE", "COLOR", "CLS",
 /*c1*/ "MOTOR", "BSAVE", "BLOAD", "SOUND", "BEEP", "PSET", "PRESET", "SCREEN",
 /*c9*/ "KEY", "LOCATE", NULL, "TO", "THEN", "TAB(", "STEP", "USR",
 /*d1*/ "FN", "SPC(", "NOT", "ERL", "ERR", "STRING$", "USING", "INSTR",
 /*d9*/ "'", "VARPTR", "CSRLIN", "POINT", "OFF", "INKEY$", NULL, NULL,
 /*e1*/ NULL, NULL, NULL, NULL, NULL, ">", "=", "<",
 /*e9*/ "+", "-", "*", "/", "^", "AND", "OR", "XOR",
 /*f1*/ "EQV", "IMP", "MOD", "\\"};

/* kws1: two-byte tokens 0xFD 0x81 .. 0xFD 0x8B. */
char *kws1[] = {
 /*81*/ "CVI", "CVS", "CVD", "MKI$", "MKS$", "MKD$", NULL, NULL,
 /*89*/ NULL, NULL, "EXTERR"};

/* kws2: two-byte tokens 0xFE 0x81 .. 0xFE 0xA8. */
char *kws2[] = {
 /*81*/ "FILES", "FIELD", "SYSTEM", "NAME", "LSET", "RSET", "KILL", "PUT",
 /*89*/ "GET", "RESET", "COMMON", "CHAIN", "DATE$", "TIME$", "PAINT", "COM",
 /*91*/ "CIRCLE", "DRAW", "PLAY", "TIMER", "ERDEV", "IOCTL", "CHDIR", "MKDIR",
 /*99*/ "RMDIR", "SHELL", "ENVIRON", "VIEW", "WINDOW", "PMAP", "PALETTE", "LCOPY",
 /*a1*/ "CALLS", NULL, NULL, NULL, "PCOPY", NULL, "LOCK", "UNLOCK"};

/* kws3: two-byte tokens 0xFF 0x81 .. 0xFF 0xA5. */
char *kws3[] = {
 /*81*/ "LEFT$", "RIGHT$", "MID$", "SGN", "INT", "ABS", "SQR", "RND",
 /*89*/ "SIN", "LOG", "EXP", "COS", "TAN", "ATN", "FRE", "INP",
 /*91*/ "POS", "LEN", "STR$", "VAL", "ASC", "CHR$", "PEEK", "SPACE$",
 /*99*/ "OCT$", "HEX$", "LPOS", "CINT", "CSNG", "CDBL", "FIX", "PEN",
 /*a1*/ "STICK", "STRIG", "EOF", "LOC", "LOF"};

/* Tables by prefix: index 0 = no prefix byte, 1..3 = prefix 0xFD..0xFF. */
char **kws[] = {kws0, kws1, kws2, kws3};
/* Entry counts of the tables above, used to bounds-check token bytes. */
int kwls[] = {ELMS(kws0), ELMS(kws1), ELMS(kws2), ELMS(kws3)};

/*
 * Descrambling keys for input files whose type byte is 0xFE (see decode()).
 * key1 (13 entries) is indexed by the counter c2 - 1, with c2 cycling 13..1;
 * key2 (11 entries) is indexed by the counter c1 - 1, with c1 cycling 11..1.
 */
int key1[] = {0x9A, 0xF7, 0x19, 0x83, 0x24, 0x63, 0x43, 0x83, 0x75, 0xCD, 0x8D, 0x84, 0xA9};
int key2[] = {0x7C, 0x88, 0x59, 0x74, 0xE0, 0x97, 0x26, 0x77, 0xC4, 0x1D, 0x1E};

/* Bytes in the little-endian binary work buffer used for float mantissas. */
#define FLTBINSIZE 10
/* Decimal digits (one per byte, least significant first) in a BCD buffer. */
#define FLTBCDSIZE 23

/*
 * Convert the unsigned little-endian integer in bin (FLTBINSIZE bytes) to
 * decimal: bcd[0] receives the units digit, bcd[1] the tens, and so on, one
 * digit value (0-9) per byte.  All FLTBCDSIZE entries of bcd are written;
 * unused high digits are zero.
 *
 * FLTBCDSIZE digits hold any value below 10^23.  fmtflt() clears the top
 * byte of bin, which keeps its values below 2^72.  A wider input would need
 * up to 25 digits; such input is truncated to its FLTBCDSIZE lowest digits
 * rather than overrunning bcd.
 */
void bin2bcd(const unsigned char *bin, unsigned char *bcd) {
	unsigned char buf[FLTBINSIZE];
	int j = 0;
	int nonzero;

	memcpy(buf, bin, FLTBINSIZE);
	do {
		int i;
		int rest = 0;

		/* divide buf by 10 in place, most significant byte first */
		nonzero = 0;
		for (i = FLTBINSIZE - 1; i >= 0; i--) {
			int n = buf[i] + rest * 256;
			buf[i] = n / 10;
			nonzero |= buf[i];
			rest = n % 10;
		}
		bcd[j++] = rest;
	} while (nonzero && j < FLTBCDSIZE);
	for (; j < FLTBCDSIZE; j++)
		bcd[j] = 0;
}

/*
 * Renormalise a BCD buffer whose top position bcd[FLTBCDSIZE - 1] became
 * nonzero: shift every digit one place toward bcd[0] (divide by ten),
 * rounding half up on the digit shifted out, and add one to *exp to
 * compensate.  Does nothing while the top digit is zero.
 */
void bcdadjexp(unsigned char *bcd, int *exp) {
	int pos;
	int up;

	if (bcd[FLTBCDSIZE - 1] == 0)
		return;
	up = (bcd[0] >= 5);
	for (pos = 1; pos < FLTBCDSIZE; pos++) {
		int digit = bcd[pos] + up;
		up = (digit >= 10);
		bcd[pos - 1] = digit % 10;
	}
	bcd[FLTBCDSIZE - 1] = 0;
	*exp += 1;
}

/*
 * Double the BCD number in bcd in place (digits least significant first).
 * A carry out of the most significant position is dropped.
 */
void bcdmul2(unsigned char *bcd) {
	unsigned char *digit = bcd;
	unsigned char *end = bcd + FLTBCDSIZE;
	int carry = 0;

	while (digit < end) {
		int doubled = *digit * 2 + carry;
		carry = (doubled >= 10);
		*digit++ = doubled % 10;
	}
}

/*
 * Halve the BCD number in bcd while keeping one more digit of precision.
 * First every digit moves up one place (multiply by ten, the old top digit
 * is discarded) and *exp is decremented so the represented magnitude is
 * unchanged.  Then the digits are divided by two from the most significant
 * down, passing a remainder of 10 to the next lower digit.
 */
void bcddiv2adj(unsigned char *bcd, int *exp) {
	int remainder = 0;
	int k;

	for (k = FLTBCDSIZE - 1; k > 0; k--)
		bcd[k] = bcd[k - 1];
	bcd[0] = 0;
	*exp -= 1;

	for (k = FLTBCDSIZE - 1; k >= 0; k--) {
		int value = bcd[k] + remainder;
		remainder = (value & 1) ? 10 : 0;
		bcd[k] = value >> 1;
	}
}

/*
 * Format a Microsoft Binary Format (MBF) floating-point constant as text.
 *
 * data      - size bytes, mantissa least significant byte first; the top
 *             bit of data[size - 2] is the sign and data[size - 1] is the
 *             exponent, biased by 128.  An all-zero value prints as zero.
 * size      - number of bytes in data (at most FLTBINSIZE); the callers
 *             in this file pass 4 (single) and 8 (double).
 * decimals  - digits kept after the leading one; the value is rounded
 *             half up to decimals + 1 significant digits
 *             (0 <= decimals <= FLTBCDSIZE - 3).
 * expchar   - exponent letter used when scientific notation is needed.
 * noexpchar - type suffix appended when no exponent is printed; a '!'
 *             suffix is left out when the text already contains a '.'.
 * out       - receives the NUL-terminated text.  It must have room for a
 *             sign, decimals + 1 digits, a '.', the suffix or exponent,
 *             and the terminator.
 */
void fmtflt(const unsigned char *data, int size, int decimals, char expchar, char noexpchar, char *out) {
	unsigned char bin[FLTBINSIZE];
	unsigned char bcd[FLTBCDSIZE];
	char *pout;
	int sign, iexp, oexp;
	int i;
	int hasdot = 0;
	int pbcd, bcdend;

	/* right-align the value so the exponent byte lands in bin[FLTBINSIZE - 1] */
	memset(bin, 0, FLTBINSIZE);
	memcpy(bin + FLTBINSIZE - size, data, size);
	for (i = 0; i < FLTBINSIZE && !bin[i]; i++)
		;
	if (i == FLTBINSIZE) { /* input is all zeros */
		sign = 0;
		iexp = 0;
		oexp = 0;
	} else {
		/* replace the sign bit by the implicit leading mantissa bit and
		 * drop the exponent byte, leaving an integer mantissa below 2^72 */
		bin[FLTBINSIZE - 2] |= 0x80;
		bin[FLTBINSIZE - 1] = 0;

		sign = data[size - 2] & 0x80;
		iexp = data[size - 1] - 128;

		/* value = mantissa * 2^(iexp - 72); the mantissa's leading decimal
		 * digit is its 10^21 digit, bcd[FLTBCDSIZE - 2] */
		oexp = 21;
		iexp -= 72;
	}

	bin2bcd(bin, bcd);

	/* apply the binary exponent one factor of two at a time, keeping the
	 * leading digit in bcd[FLTBCDSIZE - 2] and tracking the decimal one */
	while (iexp < 0) {
		iexp++;
		bcddiv2adj(bcd, &oexp);
		bcdadjexp(bcd, &oexp);
	}
	while (iexp > 0) {
		iexp--;
		bcdmul2(bcd);
		bcdadjexp(bcd, &oexp);
	}

	/* round half up to decimals + 1 significant digits */
	if (bcd[FLTBCDSIZE - 3 - decimals] >= 5) {
		int carry = 1;
		for (i = FLTBCDSIZE - 2 - decimals; i < FLTBCDSIZE; i++) {
			int n = bcd[i] + carry;
			carry = n >= 10 ? 1 : 0;
			bcd[i] = n % 10;
		}
		bcdadjexp(bcd, &oexp);
	}

	pout = out;
	if (sign)
		*pout++ = '-';
	pbcd = FLTBCDSIZE - 2;        /* index of the leading digit */
	bcdend = pbcd - decimals - 1; /* lowest printed digit is bcd[bcdend + 1] */
	/* trim trailing zeros, stopping at the leading digit so an all-zero
	 * value cannot walk past the end of bcd */
	while (bcdend < pbcd && !bcd[bcdend + 1])
		bcdend++;

	if (oexp > 0 && oexp <= decimals) {
		/* fixed notation: the whole integer part, digit by digit */
		for (i = 0; i <= oexp; i++)
			*pout++ = bcd[pbcd--] + '0';
		oexp = 0;
	} else if (oexp < 0 && (pbcd - bcdend) - oexp - 2 <= decimals) {
		/* small value short enough to print as ".000ddd" */
		*pout++ = '.';
		hasdot = 1;
		for (i = -1; i > oexp; i--)
			*pout++ = '0';
		oexp = 0;
	} else
		*pout++ = bcd[pbcd--] + '0';

	/* remaining significant digits go after the decimal point */
	if (pbcd > bcdend) {
		if (!hasdot) {
			*pout++ = '.';
			hasdot = 1;
		}
		while (pbcd > bcdend)
			*pout++ = bcd[pbcd--] + '0';
	}

	if (oexp) {
		*pout++ = expchar;
		sprintf(pout, "%+03d", oexp);
	} else {
		/* a '.' already makes '!' redundant (presumably mirroring how
		 * GW-BASIC lists single-precision constants) */
		if (noexpchar != '!' || !hasdot)
			*pout++ = noexpchar;
		*pout = 0;
	}
}

/*
 * Print the program banner and command-line summary to stdout.
 * NOTE(review): the usage line calls the program "gwbasasc" while the
 * banner says "gwbascii" - confirm which is the executable's real name.
 */
void usage(void) {
	static const char text[] =
		"gwbascii 1.0 - Detokenize/decode gwbasic/basica files\n\n"
		"Usage: gwbasasc [input] [output]\n"
		"(default stdin/stdout)";

	printf("%s\n", text);
}

/*
 * Parse the command line and open the input and output streams.
 *
 * Accepts up to two non-option arguments: the input file (opened in binary
 * mode) and the output file (opened in text mode).  A missing name selects
 * stdin / stdout.  "--" ends option processing.  Any other argument that
 * starts with '-', or a third file name, prints the usage text and fails.
 *
 * On success returns 0 with *fin / *fout set; *closein / *closeout are
 * nonzero for the streams the caller must fclose().  On failure prints a
 * message and returns -1, leaving no stream open for the caller to close.
 */
int parse_args(int argc, char *argv[], FILE **fin, FILE **fout,
	int *closein, int *closeout) {
	int i;
	int noopt = 0;
	char *fnin = NULL;
	char *fnout = NULL;

	*fin = NULL;
	*fout = NULL;
	*closein = 0;
	*closeout = 0;

	for (i = 1; i < argc; i++) {
		if (argv[i][0] == '-' && !noopt) {
			if (!strcmp(argv[i], "--"))
				noopt = 1; /* later arguments are file names */
			else {
				usage();
				return -1;
			}
		} else if (!fnin)
			fnin = argv[i];
		else if (!fnout)
			fnout = argv[i];
		else {
			usage(); /* more than two file names */
			return -1;
		}
	}

	if (fnin) {
		if (!(*fin = fopen(fnin, "rb"))) {
			perror("Error opening input file");
			return -1;
		}
		*closein = 1;
	} else {
#ifdef WIN32
		_setmode(_fileno(stdin), _O_BINARY );
#endif
		*fin = stdin;
	}

	if (fnout) {
		if (!(*fout = fopen(fnout, "w"))) {
			/* report first: the fclose() below may overwrite errno */
			perror("Error creating output file");
			if (*closein)
				fclose(*fin);
			*fin = NULL;
			*closein = 0;
			return -1;
		}
		*closeout = 1;
	} else
		*fout = stdout;
	return 0;
}


/*
 * State of the byte descrambler used by decode(): two countdown counters.
 * c1 selects the key2 entry and cycles 11..1; c2 selects the key1 entry
 * and cycles 13..1.
 */
struct decstate {
	int c1;
	int c2;
};

/* Reset the descrambler counters to their starting values (11, 13). */
void decode_init(struct decstate *decstate) {
	static const struct decstate start = { 11, 13 };

	*decstate = start;
}

int decode(struct decstate *decstate, int c) {
	c = (((256 + c - decstate->c1) ^ 
			key1[decstate->c2 - 1] ^ 
			key2[decstate->c1 - 1]) +
			decstate->c2) & 255;
	if (!--decstate->c1)
		decstate->c1 = 11;
	if (!--decstate->c2)
		decstate->c2 = 13;
	return c;
}

/*
 * Read one byte from f, passing it through decode() when decstate is
 * non-NULL.  Returns the byte (0-255), or -1 after printing a diagnostic
 * on a read error or an unexpected end of file.
 */
int readbyte(FILE *f, struct decstate *decstate) {
	int ch = fgetc(f);

	if (ch != EOF)
		return decstate ? decode(decstate, ch) : ch;

	if (ferror(f))
		perror("Error reading input file");
	else
		fprintf(stderr, "Unexpected end of input file.\n");
	return -1;
}

/*
 * Read a 16-bit little-endian word (low byte first) via readbyte() and
 * store it in *w.  Returns 0 on success, or -1 (diagnostic already
 * printed) leaving *w unchanged.
 */
int readword(FILE *f, struct decstate *decstate, unsigned int *w) {
	int lo, hi;

	lo = readbyte(f, decstate);
	if (lo == -1)
		return -1;
	hi = readbyte(f, decstate);
	if (hi == -1)
		return -1;
	*w = (unsigned int)hi * 256u + (unsigned int)lo;
	return 0;
}

/*
 * printf-style write to f.
 * Returns 0 on success.  On an output error it prints a diagnostic to
 * stderr and returns -1, so callers that test the result stop processing
 * (previously it always returned 0 and errors were ignored).
 */
int writefmt(FILE *f, const char *fmt, ...) {
	int rc;
	va_list vp;

	va_start(vp, fmt);
	rc = vfprintf(f, fmt, vp);
	va_end(vp);
	if (rc >= 0)
		return 0;

	if (ferror(f))
		perror("Error writing output file");
	else
		fprintf(stderr, "Error writing output file (disk full?)\n");
	return -1;
}

/*
 * Detokenize a saved GW-BASIC/BASICA program read from fin and write it to
 * fout as plain text, one program line per output line.
 *
 * The first byte selects the format: 0xFF for a plain tokenized file, or
 * 0xFE for a scrambled file whose remaining bytes go through decode()
 * (presumably a program saved with the protect option).  Each program line
 * is then a 16-bit next-line pointer (0 ends the program), a 16-bit line
 * number, and the tokenized text terminated by a NUL byte.
 *
 * Returns 0 on success, or -1 after a diagnostic has been printed.
 */
int process_file(FILE *fin, FILE *fout) {
	int c;
	unsigned int w;
	struct decstate sdecstate, *decstate;
	int j;
	unsigned char binbuf[8];
	char outbuf[32];

	if ((c = readbyte(fin, NULL)) == -1) /* read file type */
		return -1;
	if (c != 255 && c != 254) {
		fprintf(stderr, "Unknown input file format.\n");
		return -1;
	}
	if (c == 254) { /* scrambled file: decode every following byte */
		decode_init(&sdecstate);
		decstate = &sdecstate;
	} else
		decstate = NULL;

	for (;;) {
		int skipspace;

		if (readword(fin, decstate, &w)) /* read internal pointer to next line */
			return -1;
		if (!w) /* zero -> end of file */
			break;
		if (readword(fin, decstate, &w)) /* read line number */
			return -1;
		if (writefmt(fout, "%u ", w)) /* w is an unsigned 16-bit value */
			return -1;

		skipspace = w == 0; /* skip first space if line number is 0 */
		for (;;) {
			if ((c = readbyte(fin, decstate)) == -1) /* read token */
				return -1;
			if (c == 0x20 && skipspace) {
				skipspace = 0;
				continue;
			}
			while (c == ':') {
				if ((c = readbyte(fin, decstate)) == -1)
					return -1;
				if (c == 0x8f) { /* ":" REM: a REM statement or a ' comment */
					if ((c = readbyte(fin, decstate)) == -1)
						return -1;
					if (c == 0xd9) {
						if (writefmt(fout, "\'"))
							return -1;
					} else {
						if (writefmt(fout, ":REM"))
							return -1;
						/* c is the first comment character, or the line's
						 * terminating NUL for an empty REM; that NUL must be
						 * neither written nor read past */
						if (c && writefmt(fout, "%c", c))
							return -1;
					}
					/* copy the rest of the comment verbatim up to the NUL */
					while (c) {
						if ((c = readbyte(fin, decstate)) == -1)
							return -1;
						if (c && writefmt(fout, "%c", c))
							return -1;
					}
					break;
				}
				if (c != 0xa1) /* no : before ELSE */
					if (writefmt(fout, ":"))
						return -1;
			}
			if (!c)
				break;
			if (c < 0x20) {
				/* embedded numeric constants */
				switch (c) {
					case 11: /* 16-bit value shown in octal */
						if (readword(fin, decstate, &w))
							return -1;
						if (writefmt(fout, "&O%o", w))
							return -1;
						break;
					case 12: /* 16-bit value shown in hex */
						if (readword(fin, decstate, &w))
							return -1;
						if (writefmt(fout, "&H%X", w))
							return -1;
						break;
					case 14: /* unsigned 16-bit value */
						if (readword(fin, decstate, &w))
							return -1;
						if (writefmt(fout, "%u", w))
							return -1;
						break;
					case 15: /* one-byte value */
						if ((c = readbyte(fin, decstate)) == -1)
							return -1;
						if (writefmt(fout, "%d", c))
							return -1;
						break;
					case 17: /* tokens 17..26 stand for the digits 0..9 */
					case 18:
					case 19:
					case 20:
					case 21:
					case 22:
					case 23:
					case 24:
					case 25:
					case 26:
						if (writefmt(fout, "%d", c - 17))
							return -1;
						break;
					case 28: { /* signed (two's complement) 16-bit value */
						int sval;
						if (readword(fin, decstate, &w))
							return -1;
						/* convert in int arithmetic so %d receives an int */
						sval = (w & 0x8000) ? (int)w - 65536 : (int)w;
						if (writefmt(fout, "%d", sval))
							return -1;
						break;
					}
					case 29: /* 4-byte single-precision float */
						for (j = 0; j < 4; j++) {
							if ((c = readbyte(fin, decstate)) == -1)
								return -1;
							binbuf[j] = c;
						}
						fmtflt(binbuf, 4, 6, 'E', '!', outbuf);
						if (writefmt(fout, "%s", outbuf))
							return -1;
						break;
					case 31: /* 8-byte double-precision float */
						for (j = 0; j < 8; j++) {
							if ((c = readbyte(fin, decstate)) == -1)
								return -1;
							binbuf[j] = c;
						}
						fmtflt(binbuf, 8, 15, 'D', '#', outbuf);
						if (writefmt(fout, "%s", outbuf))
							return -1;
						break;
					default:
						fprintf(stderr, "Invalid token %d in input file.\n", c);
						return -1;
				}
			} else if (c < 0x80) {
				/* plain character; a '"' starts a literal copied verbatim */
				if (writefmt(fout, "%c", c))
					return -1;
				if (c == 0x22) {
					for (;;) {
						if ((c = readbyte(fin, decstate)) == -1)
							return -1;
						if (!c || c == 0x22)
							break;
						if (writefmt(fout, "%c", c))
							return -1;
					}
					if (!c) /* unterminated string ends with the line */
						break;
					if (writefmt(fout, "%c", c))
						return -1;
				}
			} else {
				/* keyword token, optionally behind a 0xFD..0xFF prefix */
				int kwt = 0;
				if (c >= 0xfd) {
					kwt = c - 0xfd + 1;
					if ((c = readbyte(fin, decstate)) == -1)
						return -1;
				}
				if (c < 129 || c - 129 >= kwls[kwt] || !kws[kwt][c - 129]) {
					fprintf(stderr, "Invalid token %d %d in input file.\n", kwt ? kwt + 0xfd - 1 : -1, c);
					return -1;
				}
				if (writefmt(fout, "%s", kws[kwt][c - 129]))
					return -1;

				if (!kwt && c == 0xb1) { /* skip unknown byte after WHILE */
					if ((c = readbyte(fin, decstate)) == -1)
						return -1;
				} else if (!kwt && c == 0x8f) { /* comment after REM */
					for (;;) {
						if ((c = readbyte(fin, decstate)) == -1)
							return -1;
						if (!c)
							break;
						if (writefmt(fout, "%c", c))
							return -1;
					}
					break;
				}
			}
		}
		if (writefmt(fout, "\n"))
			return -1;
	}
	return 0;
}

/*
 * Entry point: [input] [output], defaulting to stdin/stdout.
 * Returns EXIT_SUCCESS, or EXIT_FAILURE if argument parsing, conversion,
 * or flushing/closing the output fails.
 */
int main(int argc, char *argv[]) {
	FILE *fin, *fout;
	int closein, closeout;
	int status = EXIT_SUCCESS;

	if (parse_args(argc, argv, &fin, &fout, &closein, &closeout))
		return EXIT_FAILURE;

	if (process_file(fin, fout))
		status = EXIT_FAILURE; /* diagnostic already printed */

	if (closein)
		fclose(fin);
	if (closeout) {
		/* fclose() flushes buffered output, so late write errors show up here */
		if (fclose(fout) && status == EXIT_SUCCESS) {
			perror("Error closing output file");
			status = EXIT_FAILURE;
		}
	} else if (status == EXIT_SUCCESS && fflush(fout)) {
		/* stdout is never closed; flush it so buffered write errors are not lost */
		perror("Error writing output file");
		status = EXIT_FAILURE;
	}

	return status;
}
