#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <wchar.h>
#include <wctype.h>
#include <windows.h>
#include <ole2.h>
#include <openssl/sha.h>
#include <openssl/rc4.h>
#include "ecc.h"
#include "msdrm.h"

/*
 * Fatal-error reporter defined elsewhere in the project.  No caller in
 * this file handles a return from it (presumably it terminates the
 * program -- confirm against its definition).
 */
extern void error_exit(char *msg);
/*
 * Wide-string printer defined elsewhere in the project.  It is called
 * here between stderr diagnostics (presumably it writes to stderr as
 * well -- confirm against its definition).
 */
extern void printwcs(wchar_t * msg);

/*
 * In-memory layout of the object returned by the BlackBox library's
 * IBlackBox_CreateInstance2 entry point (see getkeypairs).  This file
 * reads only ecprivkey, clientid, numkeypairs and keypairs; the other
 * members are never accessed here and appear to exist only to keep the
 * offsets right -- TODO confirm against the library.
 */
typedef struct bboxobj_st {
	void *jtable;
	void *jtbl2;
	void *jtbl3;
	MS_BN ecprivkey;	/* copied into keypair[0].private */
	MS_ECCpt ecpt1;
	uchar clientid[84];	/* First part is public key */
	uchar hwid[20];
	uchar rc4key[6];
	uchar pad1[2];
	int numkeypairs;	/* number of records behind keypairs */
	uchar *keypairs;	/* 60-byte records: 40-byte public point, then 20-byte private key */
} BBOXOBJ;

/* Capacity of the keypair table below. */
#define MAXKEYPAIRS 50

/*
 * Key pairs copied out of the BlackBox object by getkeypairs().  Entry 0
 * comes from the object's ecprivkey/clientid members; the following
 * entries come from its keypairs array.  checkLicense() searches this
 * table for the public key named in a license.
 */
struct keypair_st {
	MS_ECCpt public;
	MS_BN private;
} keypair[MAXKEYPAIRS];
/* Number of valid entries in keypair[]; set by getkeypairs(). */
int numkeypairs = 0;


/*
 * Fixed elliptic-curve parameters handed to ECC_new_set() by
 * MSECC_new_set().  Each MS_BN stores its bytes least-significant first:
 * MS_BN_to_BN() reverses them before calling BN_bin2bn(), which expects
 * big-endian input.
 */

/* Field modulus (first argument of ECC_new_set). */
static MS_BN msec_mod = {
	{0xf7, 0x24, 0x14, 0x14, 0x26, 0x59, 0x41, 0x31, 0x18, 0x28,
	 0x18, 0x27, 0x67, 0x45, 0x23, 0x01, 0xef, 0xcd, 0xab, 0x89}
};

/* Second argument of ECC_new_set (presumably curve coefficient a). */
static MS_BN msec_a = {
	{0x97, 0x14, 0xe4, 0xeb, 0x09, 0xc0, 0x80, 0x47, 0x3d, 0xff,
	 0x32, 0x76, 0xe8, 0xbc, 0x77, 0xd2, 0xcc, 0xab, 0xa5, 0x37}
};

/* Third argument of ECC_new_set (presumably curve coefficient b). */
static MS_BN msec_b = {
	{0x9e, 0x23, 0x28, 0x93, 0xdf, 0xde, 0x8f, 0xd7, 0x1a, 0x5f,
	 0xe8, 0x28, 0x32, 0x2f, 0x5e, 0x72, 0xbf, 0xda, 0xd8, 0x0d}
};

/* x coordinate of the point passed as the fourth argument of ECC_new_set. */
static MS_BN msec_gx = {
	{0x20, 0xa1, 0x9f, 0x10, 0xf0, 0xda, 0x38, 0xba, 0x7d, 0xc0,
	 0x10, 0x35, 0xe5, 0xa1, 0xa3, 0xd6, 0x7f, 0x94, 0x23, 0x87}
};

/* y coordinate of the point passed as the fourth argument of ECC_new_set. */
static MS_BN msec_gy = {
	{0x6f, 0x93, 0x79, 0xa3, 0xcd, 0x7a, 0xed, 0xd4, 0x56, 0x58,
	 0x3c, 0x8c, 0x2d, 0x52, 0x75, 0x10, 0x91, 0x44, 0x57, 0x44}
};


/*
 * Writes num to stderr as hexadecimal, two digits per byte, starting
 * with the last byte of the array and ending with the first.  No
 * newline is emitted.
 */
static void printMSBN(MS_BN * num)
{
	int idx = MS_BN_LEN;

	while (idx-- > 0)
		fprintf(stderr, "%02x", num->d[idx]);
}


/*
 * Converts msnum into an OpenSSL BIGNUM.  The MS_BN bytes are copied in
 * reverse order into a scratch buffer, which is then handed to
 * BN_bin2bn() together with r.  Returns whatever BN_bin2bn() returns.
 */
static BIGNUM *MS_BN_to_BN(MS_BN * msnum, BIGNUM * r)
{
	uchar reversed[MS_BN_LEN];
	int src = MS_BN_LEN - 1;
	int dst = 0;

	while (dst < MS_BN_LEN) {
		reversed[dst] = msnum->d[src];
		dst++;
		src--;
	}

	return BN_bin2bn(reversed, MS_BN_LEN, r);
}


/*
 * Converts both coordinates of mspt into the BIGNUMs already held by r
 * (r->x and r->y are passed to MS_BN_to_BN as destinations).
 */
static void MS_ECCpt_to_ECCpt(MS_ECCpt * mspt, ECCpt * r)
{
	MS_BN *src_x = &mspt->x;
	MS_BN *src_y = &mspt->y;

	(void) MS_BN_to_BN(src_x, r->x);
	(void) MS_BN_to_BN(src_y, r->y);
}


/*
 * Builds an ECC context from the fixed msec_* constants and returns it.
 * All temporaries created here are released before returning, so
 * ECC_new_set() is assumed to keep its own copies -- confirm in ecc.h.
 */
static ECC *MSECC_new_set()
{
	BIGNUM *modulus = BN_new();
	BIGNUM *coef_a = BN_new();
	BIGNUM *coef_b = BN_new();
	ECCpt generator;
	ECC *curve;

	ECCpt_init(&generator);

	/* Load each constant into its OpenSSL counterpart. */
	MS_BN_to_BN(&msec_mod, modulus);
	MS_BN_to_BN(&msec_a, coef_a);
	MS_BN_to_BN(&msec_b, coef_b);
	MS_BN_to_BN(&msec_gx, generator.x);
	MS_BN_to_BN(&msec_gy, generator.y);

	curve = ECC_new_set(modulus, coef_a, coef_b, generator);

	/* The temporaries are no longer needed once the context exists. */
	ECCpt_free(&generator);
	BN_free(coef_b);
	BN_free(coef_a);
	BN_free(modulus);

	return curve;
}


/*
 * Stores the private key pk in ecc->privkey, allocating the BIGNUM
 * first if the context does not have one yet.
 */
static void MSECC_set_privkey(MS_BN * pk, ECC * ecc)
{
	BIGNUM *dest = ecc->privkey;

	if (dest == NULL) {
		dest = BN_new();
		ecc->privkey = dest;
	}

	MS_BN_to_BN(pk, dest);
}


/*
 * Converts the OpenSSL BIGNUM in into the fixed-size MS_BN out.
 * BN_bn2bin() output is right-aligned in a zeroed scratch value and the
 * bytes are then copied to out in reverse order.  A value needing more
 * than MS_BN_LEN bytes is reported through error_exit().
 */
static void BN_to_MS_BN(BIGNUM * in, MS_BN * out)
{
	MS_BN scratch;
	int nbytes = BN_num_bytes(in);
	int lo, hi;

	if (nbytes > MS_BN_LEN)
		error_exit
		    ("Bug in code:  Result is too big in BN_to_MS_BN");

	/* Leading bytes not written by BN_bn2bin must read as zero. */
	memset(scratch.d, 0, sizeof(scratch.d));
	BN_bn2bin(in, (uchar *) & scratch.d[MS_BN_LEN - nbytes]);

	for (lo = 0, hi = MS_BN_LEN - 1; lo < MS_BN_LEN; lo++, hi--)
		out->d[lo] = scratch.d[hi];
}


/*
 * Decrypts a ciphertext made of two consecutive points, ctext[0] and
 * ctext[1], using the private key stored in ecc, and writes the
 * resulting point to r.
 *
 * Steps: ctext[0] is passed through ECCpt_mul with the private key, its
 * y coordinate is replaced by modulus - y, and the result is combined
 * with ctext[1] through ECCpt_add.  (ECCpt_mul/ECCpt_add are taken to
 * write into their first argument -- confirm in ecc.h.)
 */
static void MSECC_decrypt(MS_ECCpt * r, MS_ECCpt * ctext, ECC * ecc)
{
	ECCpt mask, plain;

	if (ecc->privkey == NULL)
		error_exit
		    ("Bug in code:  MSECC_decrypt called with no private key!");

	ECCpt_init(&mask);
	ECCpt_init(&plain);
	MS_ECCpt_to_ECCpt(ctext, &mask);
	MS_ECCpt_to_ECCpt(ctext + 1, &plain);

	/* mask = privkey * ctext[0], then flip the sign of its y. */
	ECCpt_mul(&mask, &mask, ecc->privkey, ecc);
	BN_sub(mask.y, ecc->modulus, mask.y);

	/* plain = ctext[1] + mask */
	ECCpt_add(&plain, &plain, &mask, ecc);

	BN_to_MS_BN(plain.x, &r->x);
	BN_to_MS_BN(plain.y, &r->y);

	ECCpt_free(&mask);
	ECCpt_free(&plain);
}


/*
 * Decodes the base64 text in str into a freshly malloc'ed buffer stored
 * in *buff (the caller frees it) and returns the number of bytes written.
 *
 * Besides the standard alphabet, '!' is accepted for '+' and '*' for '/'.
 * Characters outside the alphabet are skipped.  Decoding stops at the
 * first '=': a pending group of two symbols yields one byte, a group of
 * three yields two bytes, and a shorter group yields nothing.  Symbols
 * left over at the end of the string without a '=' are discarded.
 */
static int MS_Base64Decode(wchar_t * str, char **buff)
{
	wchar_t *cp;
	char *ocp;
	size_t len;
	unsigned int block = 0;	/* unsigned so the shifts cannot overflow */
	int val;
	int count = 0;
	int ocount = 0;

	len = wcslen(str);

	/*
	 * Every four symbols give three bytes and a padded tail gives at
	 * most two more.  The slack also keeps the request non-zero for
	 * very short inputs.
	 */
	if ((*buff = malloc((len / 4) * 3 + 3)) == NULL)
		error_exit("Memory allocation failed in MS_Base64Decode.");

	ocp = *buff;
	for (cp = str; *cp != L'\0'; cp++) {
		if ((*cp >= L'A') && (*cp <= L'Z'))
			val = *cp - L'A';
		else if ((*cp >= L'a') && (*cp <= L'z'))
			val = *cp - L'a' + 26;
		else if ((*cp >= L'0') && (*cp <= L'9'))
			val = *cp - L'0' + 52;
		else if ((*cp == L'+') || (*cp == L'!'))
			val = 62;
		else if ((*cp == L'/') || (*cp == L'*'))
			val = 63;
		else if (*cp == L'=') {
			/* Flush only the whole bytes the pending group holds. */
			if (count == 2) {
				*ocp++ = (char) ((block >> 4) & 0xff);
				ocount++;
			} else if (count == 3) {
				*ocp++ = (char) ((block >> 10) & 0xff);
				*ocp++ = (char) ((block >> 2) & 0xff);
				ocount += 2;
			}
			break;
		} else
			val = -1;

		if (val >= 0) {
			block = (block << 6) | (unsigned int) val;
			if (++count == 4) {
				*ocp++ = (char) ((block >> 16) & 0xff);
				*ocp++ = (char) ((block >> 8) & 0xff);
				*ocp++ = (char) (block & 0xff);
				ocount += 3;
				count = 0;
				block = 0;
			}
		}
	}

	return ocount;
}


/*
 * Decrypts one packet of len bytes in place using the key material in
 * ckey.
 *
 * Packets shorter than 16 bytes are simply XORed with the leading bytes
 * of ckey->keyhash.  Otherwise the last complete 8-byte block is XORed
 * with inmask, run through DES (final argument 0, i.e. DES_DECRYPT in
 * OpenSSL) and XORed with outmask to form an 8-byte RC4 key; the whole
 * packet is RC4-processed; MultiSwapMAC is run over the blocks before
 * the last one; and MultiSwapDecode writes the final value of the last
 * block.
 */
void MSDRM_decr_packet(uchar * data, int len, CONTKEY * ckey)
{
	RC4_KEY rc4state;
	uchar work2[8];
	unsigned int pustate[2];
	unsigned int tmpd[2];
	uchar *keystart;
	int num64bits;
	int i;

	if (len < 16) {
		for (i = 0; i < len; i++)
			data[i] ^= ckey->keyhash[i];
		return;
	}

	/*
	 * Locate the last complete 8-byte block.  This is done only after
	 * the guard above so the pointer is never formed for packets that
	 * have no such block.
	 */
	num64bits = len / 8;
	keystart = data + (num64bits - 1) * 8;

	for (i = 0; i < 8; i++)
		keystart[i] ^= ckey->inmask[i];

	des_ecb_encrypt((const_des_cblock *) keystart,
			(des_cblock *) work2, ckey->keysched, 0);

	for (i = 0; i < 8; i++)
		work2[i] ^= ckey->outmask[i];

	RC4_set_key(&rc4state, 8, work2);
	RC4(&rc4state, len, data, data);

	MultiSwapMAC(&ckey->hashkey, (unsigned int *) data, num64bits - 1,
		     pustate);

	/*
	 * Swap the two words of work2.  memcpy is used instead of casting
	 * the byte array to int *, which is not a valid way to read it.
	 */
	memcpy(&tmpd[0], work2 + sizeof(tmpd[0]), sizeof(tmpd[0]));
	memcpy(&tmpd[1], work2, sizeof(tmpd[1]));

	MultiSwapDecode(&ckey->hashkey, pustate, tmpd,
			(unsigned int *) keystart);
}


/*
 * Recovers the content key from a license VALUE and fills in *out.
 *
 * value is base64 text holding two curve points; they are decrypted with
 * privkey.  Byte 0 of the decrypted x coordinate is the key length and
 * the key bytes follow it.  From the key this derives:
 *   - keyhash: SHA-1 of the content key
 *   - keysched: DES key schedule from keyhash[12..19]
 *   - outmask / inmask: bytes 48..55 and 56..63 of a 64-byte RC4 stream
 *     keyed with keyhash[0..11]
 *   - hashkey: set by MultiSwapSetKey from the start of that same stream
 *
 * Calls error_exit() if the VALUE is too short to hold two points or if
 * the decrypted length byte is out of range.
 */
static void MSDRM_setup(MS_BN * privkey, wchar_t * value, CONTKEY * out)
{
	ECC *msecc;
	MS_ECCpt dec;
	RC4_KEY rc4state;
	uchar rc4buff[64];
	int len;
	char *dynbuff;

	msecc = MSECC_new_set();
	MSECC_set_privkey(privkey, msecc);

	len = MS_Base64Decode(value, &dynbuff);

	/* MSECC_decrypt reads two whole points; refuse anything shorter. */
	if (len < (int) (2 * sizeof(MS_ECCpt)))
		error_exit
		    ("License VALUE is too short to hold an encrypted key!");

	MSECC_decrypt(&dec, (MS_ECCpt *) dynbuff, msecc);
	free(dynbuff);

	ECC_free(msecc);
	msecc = NULL;

	if ((uchar) dec.x.d[0] > MS_BN_LEN - 1)
		error_exit("Decrypted content key is too big!");

	out->ckeylen = (uchar) dec.x.d[0];
	memcpy(out->ckey, &dec.x.d[1], out->ckeylen);

	if (globalinfo.verbose) {
		int i;
		fprintf(stderr, "Content key:");
		for (i = 0; i < out->ckeylen; i++)
			fprintf(stderr, " %02x", out->ckey[i]);
		fprintf(stderr, "\n");
	}

	SHA1(out->ckey, out->ckeylen, out->keyhash);

	des_set_key_unchecked((des_cblock *) (&out->keyhash[12]),
			      out->keysched);

	/* Encrypting zeros leaves the raw RC4 key stream in rc4buff. */
	RC4_set_key(&rc4state, 12, out->keyhash);
	memset(rc4buff, 0, sizeof(rc4buff));
	RC4(&rc4state, sizeof(rc4buff), rc4buff, rc4buff);

	memcpy(out->outmask, &rc4buff[48], 8);
	memcpy(out->inmask, &rc4buff[56], 8);

	MultiSwapSetKey(&out->hashkey, (unsigned int *) rc4buff);
}


/* Stupid little fake XML parser. */

/*
 * Scans forward from str for the first '>' that is not inside a
 * double-quoted section and returns a pointer to the character after
 * it.  Returns NULL if the string ends first or a quote is left open.
 */
static wchar_t *find_close(wchar_t * str)
{
	wchar_t *p;

	for (p = str; *p != L'\0'; p++) {
		if (*p == L'>')
			return p + 1;

		if (*p == L'"') {
			/* Jump to the matching quote; the loop steps past it. */
			p = wcschr(p + 1, L'"');
			if (p == NULL)
				return NULL;
		}
	}

	return NULL;
}


/*
 * Returns a malloc'ed copy of the text between the first <tag ...>
 * opening tag in str and the first </tag> that follows it, or NULL if
 * no such element is found.  The caller frees the result.
 *
 * An occurrence of "<tag" immediately followed by an alphanumeric
 * character is treated as the start of a longer tag name and skipped.
 * The closing tag is searched for only after the end of the opening
 * tag, so a stray </tag> earlier in str cannot hide the element.
 */
wchar_t *get_element(wchar_t * tag, wchar_t * str)
{
	size_t len = wcslen(tag);
	size_t bodylen;
	wchar_t *tmptag;
	wchar_t *start, *body, *end;
	wchar_t *rval = NULL;

	/* Room for "</", the tag, ">" and the terminator. */
	if ((tmptag = malloc((len + 4) * sizeof(wchar_t))) == NULL)
		error_exit("Memory allocation failed in get_element (1)");

	swprintf(tmptag, L"<%s", tag);

	for (;;) {
		if ((start = wcsstr(str, tmptag)) == NULL)
			goto exit;
		if (!iswalnum(start[len + 1]))
			break;
		str = start + len + 1;
	}

	/* Element content begins right after the opening tag's '>'. */
	if ((body = find_close(start)) == NULL)
		goto exit;

	swprintf(tmptag, L"</%s>", tag);
	if ((end = wcsstr(body, tmptag)) == NULL)
		goto exit;

	bodylen = (size_t) (end - body);
	if ((rval = malloc((bodylen + 1) * sizeof(wchar_t))) == NULL)
		error_exit("Memory allocation failed in get_element (2)");
	memcpy(rval, body, bodylen * sizeof(wchar_t));
	rval[bodylen] = L'\0';

exit:
	free(tmptag);
	return rval;
}



/*
 * getDRMDataPath allocates extra room on the end (20 wchars) for 
 * appending a filename
 */
static wchar_t *getDRMDataPath()
{
	HKEY key_drm;
	long stat;
	DWORD dtype, dlen;
	wchar_t *buff;

	stat =
	    RegOpenKeyEx(HKEY_LOCAL_MACHINE, "Software\\Microsoft\\DRM", 0,
			 KEY_READ, &key_drm);
	if (FAILED(stat))
		return NULL;

	stat =
	    RegQueryValueEx(key_drm, "DataPath", NULL, NULL, NULL, &dlen);
	if (FAILED(stat))
		return NULL;

	if ((buff =
	     (wchar_t *) malloc(dlen + 20 * sizeof(wchar_t))) == NULL)
		error_exit("Memory allocation failed in getDRMDataPath");

	stat =
	    RegQueryValueEx(key_drm, "DataPath", NULL, &dtype,
			    (uchar *) buff, &dlen);
	if (FAILED(stat)) {
		free(buff);
		return NULL;
	}

	RegCloseKey(key_drm);

	return buff;
}



/*
 * Returns 1 if GetFileAttributes() can obtain attributes for fname and
 * 0 if it reports failure (a result equal to -1).
 */
int fileExistsA(char *fname)
{
	if (GetFileAttributes(fname) == -1)
		return 0;

	return 1;
}

/*
 * Wide-character front end for fileExistsA(): converts fname to the
 * ANSI code page and checks the converted name.  Returns 0 if the
 * conversion fails.
 */
int fileExistsW(wchar_t * fname)
{
	char narrow[MAX_PATH];
	int nchars = wcslen(fname);
	int converted;

	/* nchars + 1 so the terminator is converted as well. */
	converted =
	    WideCharToMultiByte(CP_ACP, 0, fname, nchars + 1, narrow,
				MAX_PATH, NULL, NULL);

	return (converted == 0) ? 0 : fileExistsA(narrow);
}

/*
 * Chooses the keystore file and BlackBox library to use and copies
 * their names into ksname and libname.  Both buffers must hold at least
 * MAX_PATH elements.
 *
 * The defaults are <DataPath>\v2ks.bla and "BlackBox.dll".  While both
 * the next keystore and the next key library exist on disk
 * (v2ksndv.bla / IndivBox.key first, then v2ksNNN.bla / IndivNNN.key
 * with NNN in hex starting at 002), that pair replaces the current
 * choice.
 *
 * If the DRM data path cannot be read, or is too long to build file
 * names from, both outputs are set to empty strings.
 */
void getKSFilename(wchar_t * ksname, char *libname)
{
	/* Space reserved for the longest suffix appended to the base path. */
	enum { SUFFIX_ROOM = 32 };
	wchar_t *basepath;
	wchar_t currks[MAX_PATH], lastks[MAX_PATH];
	char abasepath[MAX_PATH];
	char currlib[MAX_PATH], lastlib[MAX_PATH];
	int fnum;

	/* Give the caller defined output even when the lookup fails. */
	ksname[0] = L'\0';
	libname[0] = '\0';

	basepath = getDRMDataPath();
	if (basepath == NULL)
		return;

	/* The formatting calls below are unbounded, so check the fit first. */
	if (wcslen(basepath) >= MAX_PATH - SUFFIX_ROOM) {
		free(basepath);
		return;
	}

	swprintf(lastks, L"%s\\v2ks.bla", basepath);
	sprintf(lastlib, "BlackBox.dll");

	/* Probing for newer files needs the path in narrow form as well. */
	if (WideCharToMultiByte(CP_ACP, 0, basepath,
				(int) (wcslen(basepath) + 1), abasepath,
				MAX_PATH, NULL, NULL) != 0
	    && strlen(abasepath) < MAX_PATH - SUFFIX_ROOM) {
		swprintf(currks, L"%s\\v2ksndv.bla", basepath);
		sprintf(currlib, "%s\\IndivBox.key", abasepath);
		fnum = 1;
		while (fileExistsW(currks) && (fileExistsA(currlib))) {
			fnum++;
			wcscpy(lastks, currks);
			swprintf(currks, L"%s\\v2ks%03x.bla", basepath,
				 fnum);
			strcpy(lastlib, currlib);
			/* Narrow format string: use the narrow copy of the path. */
			sprintf(currlib, "%s\\Indiv%03x.key", abasepath,
				fnum);
		}
	}

	wcscpy(ksname, lastks);
	strcpy(libname, lastlib);
	free(basepath);
}


static int getkeypairs()
{
	HMODULE mylib;
	int rval;
	BBOXOBJ *bbobj;
	char errmsg[100];
	wchar_t KSFilename[MAX_PATH];
	char BBoxLib[MAX_PATH];
	int i;

	getKSFilename(KSFilename, BBoxLib);
	if (globalinfo.verbose) {
		fprintf(stderr, "BlackBox library to use: %s\n", BBoxLib);
		fprintf(stderr, "Keystore to use: ");
		printwcs(KSFilename);
		fprintf(stderr, "\n");
	}
	mylib = LoadLibraryA(BBoxLib);
	if (mylib == NULL) {
		DWORD err = GetLastError();
		sprintf(errmsg, "Failed loading library. Err code %08x",
			err);
		error_exit(errmsg);
	} else {
		typedef int (WINAPI * createfn) (BBOXOBJ **,
						 unsigned short *);
		createfn create =
		    (createfn) GetProcAddress(mylib,
					      "IBlackBox_CreateInstance2");
		if (create == NULL)
			error_exit("Failed finding proc address.");
		else {
			rval = (*create) (&bbobj, KSFilename);

			if (bbobj == NULL) {
				sprintf(errmsg,
					"Failed to create a black box object (err code %08x)\n",
					rval);
				error_exit(errmsg);
			}

			if (globalinfo.verbose) {
				fprintf(stderr,
					"Created BlackBox instance - extracting key pairs\n");
			}

			memcpy(&keypair[0].private, &bbobj->ecprivkey, 20);
			memcpy(&keypair[0].public, bbobj->clientid, 40);
			numkeypairs = bbobj->numkeypairs + 1;
			for (i = 0; i < bbobj->numkeypairs; i++) {
				memcpy(&keypair[i + 1].public,
				       bbobj->keypairs + 60 * i, 40);
				memcpy(&keypair[i + 1].private,
				       bbobj->keypairs + 60 * i + 40, 20);
			}

			if (globalinfo.verbose) {
				fprintf(stderr, "\n");
				for (i = 0; i < numkeypairs; i++) {
					fprintf(stderr,
						"Public key %d x: ",
						i + 1);
					printMSBN(&keypair[i].public.x);
					fprintf(stderr,
						"\nPublic key %d y: ",
						i + 1);
					printMSBN(&keypair[i].public.y);
					fprintf(stderr,
						"\nPrivate key %d:  ",
						i + 1);
					printMSBN(&keypair[i].private);
					fprintf(stderr, "\n\n");
				}
			}
		}
		FreeLibrary(mylib);
	}
	return 0;
}


/*
 * Examines one license document.  The PUBKEY and VALUE elements are
 * taken from inside its ENABLINGBITS element; if the decoded PUBKEY
 * equals the public half of one of the pairs in keypair[], the matching
 * private key is used to set up a content key from VALUE.
 *
 * Returns a malloc'ed CONTKEY on a match (the caller owns it) or NULL
 * when no stored key pair matches, including when the PUBKEY decodes to
 * fewer bytes than a public key.  A missing element is reported through
 * error_exit().
 */
static CONTKEY *checkLicense(wchar_t * license)
{
	/* Number of public-key bytes compared against keypair[]. */
	enum { PUBKEY_BYTES = 40 };
	wchar_t *ebits = NULL;
	wchar_t *pubkey = NULL;
	wchar_t *value = NULL;
	MS_BN *privkey = NULL;
	CONTKEY *ckey = NULL;
	MS_BN *thispubkey = NULL;
	int publen;
	int i;

	if ((ebits = get_element(L"ENABLINGBITS", license)) == NULL)
		error_exit("No ENABLINGBITS element in license!");

	if ((pubkey = get_element(L"PUBKEY", ebits)) == NULL)
		error_exit("No PUBKEY element in license!");

	if ((value = get_element(L"VALUE", ebits)) == NULL)
		error_exit("No VALUE element in license!");

	publen = MS_Base64Decode(pubkey, (char **) &thispubkey);

	/* Only look at the key if the buffer really holds a full one. */
	if (publen >= PUBKEY_BYTES) {
		if (globalinfo.verbose) {
			fprintf(stderr, "Checking license with PUBKEY ");
			printMSBN(thispubkey);
			fprintf(stderr, "\n");
		}
		for (i = 0; i < numkeypairs; i++) {
			if (memcmp(thispubkey, (char *) &keypair[i].public,
				   PUBKEY_BYTES) == 0) {
				privkey = &keypair[i].private;
				break;
			}
		}
	} else if (globalinfo.verbose) {
		fprintf(stderr, "Skipping license: PUBKEY is too short.\n");
	}

	if (privkey != NULL) {
		if ((ckey = malloc(sizeof(CONTKEY))) == NULL)
			error_exit
			    ("Memory allocation failed in checkLicense");
		if (globalinfo.verbose) {
			fprintf(stderr,
				"Matched public key!  Proceeding...\n");
		}
		MSDRM_setup(privkey, value, ckey);
	}

	free(thispubkey);
	free(value);
	free(pubkey);
	free(ebits);

	return ckey;
}


static CONTKEY *getContKey(wchar_t * licFile, wchar_t * kid)
{
	HRESULT hr;
	IStorage *pStg = NULL, *pStgChild = NULL;
	IStream *pStrmLicense;
	IEnumSTATSTG *penum = NULL;
	STATSTG statstg, statstrm;
	wchar_t *license = NULL;
	unsigned long statLen, reallyRead;
	CONTKEY *ckey;

	hr = StgOpenStorage(licFile, NULL,
			    STGM_READ | STGM_SHARE_EXCLUSIVE, NULL, 0,
			    &pStg);

	if (FAILED(hr))
		error_exit("Couldn't open license file!");

	hr = pStg->lpVtbl->OpenStorage(pStg, kid, NULL,
				       STGM_READ | STGM_SHARE_EXCLUSIVE,
				       NULL, 0, &pStgChild);

	if (FAILED(hr))
		return NULL;

	hr = pStgChild->lpVtbl->EnumElements(pStgChild, 0, NULL, 0,
					     &penum);
	if (FAILED(hr))
		error_exit("Couldn't EnumElements in storage.");

	memset(&statstg, 0, sizeof(statstg));
	hr = penum->lpVtbl->Next(penum, 1, &statstg, 0);
	while (S_OK == hr) {
		hr = pStgChild->lpVtbl->OpenStream(pStgChild,
						   statstg.pwcsName, NULL,
						   STGM_READ |
						   STGM_SHARE_EXCLUSIVE, 0,
						   &pStrmLicense);

		if (FAILED(hr))
			error_exit("Couldn't open license!");

		hr = pStrmLicense->lpVtbl->Stat(pStrmLicense, &statstrm,
						0);

		statLen = (unsigned long) statstrm.cbSize.QuadPart / 2;
		if ((license =
		     (wchar_t *) malloc(2 * (statLen + 1))) == NULL)
			error_exit
			    ("Memory allocation failed in getContKey.");
		hr = pStrmLicense->lpVtbl->Read(pStrmLicense, license,
						2 * statLen, &reallyRead);

		if ((FAILED(hr)) || (reallyRead != 2 * statLen))
			error_exit("License read failed.");

		license[statLen] = 0;

		pStrmLicense->lpVtbl->Release(pStrmLicense);

		if ((ckey = checkLicense(license + 1)) != NULL) {
			free(license);
			pStgChild->lpVtbl->Release(pStgChild);
			pStg->lpVtbl->Release(pStg);
			return ckey;
		}

		free(license);

		hr = penum->lpVtbl->Next(penum, 1, &statstg, 0);
	}

	return NULL;
}


/*
 * Rewrites kid in place, replacing every '/' with '@' and every '!'
 * with '%'.  All other characters are left untouched.
 */
static void convertKID(wchar_t * kid)
{
	wchar_t *p;

	for (p = kid; *p != L'\0'; p++) {
		switch (*p) {
		case L'/':
			*p = L'@';
			break;
		case L'!':
			*p = L'%';
			break;
		default:
			break;
		}
	}
}


/*
 * Entry point: obtains the content key for the content whose key ID is
 * kid.  kid is rewritten in place by convertKID().
 *
 * The license file is <DRM data path>\drmv2.lic.  The machine's key
 * pairs are loaded first, then the license store is searched.  Returns
 * the CONTKEY produced by getContKey(); failure to find the data path
 * or a usable license is reported through error_exit().
 */
CONTKEY *MSDRM_init(wchar_t * kid)
{
	wchar_t *licpath = getDRMDataPath();
	CONTKEY *found;

	if (licpath == NULL)
		error_exit("Couldn't get DRM data path from registry.");

	/* getDRMDataPath leaves spare room for this suffix. */
	wcscat(licpath, L"\\drmv2.lic");

	if (globalinfo.verbose) {
		fprintf(stderr, "License file full path: ");
		printwcs(licpath);
		fprintf(stderr, "\n");
	}

	getkeypairs();

	convertKID(kid);
	found = getContKey(licpath, kid);
	if (found == NULL)
		error_exit
		    ("Couldn't find a valid license for this content.");

	free(licpath);

	return found;
}

