#include "mhash.h"
#include "libdefs.h"
#include "keygen.h"


/* Key generation using OpenPGP Simple S2K algorithm.
 *
 * The key is built from consecutive digests of the password.  The digest
 * for block number i (counting from 0) is computed over i zero bytes
 * followed by the password, so every block comes from a differently
 * preloaded hash context.  The first key_size bytes of the concatenated
 * digests are copied into keyword.
 *
 *   algorithm - hash algorithm used to produce the digests
 *   keyword   - output buffer, must hold at least key_size bytes
 *   key_size  - number of key bytes to generate
 *   password  - passphrase bytes
 *   plen      - length of password in bytes
 *
 * Returns 0 on success and -1 on failure (unusable algorithm block size,
 * negative key_size, NULL keyword, allocation failure or hash failure).
 * Nothing is written to keyword on failure.
 */
int _mhash_gen_key_s2k_simple(hashid algorithm, void *keyword, int key_size,
		  unsigned char *password, int plen)
{
	word8 *key;
	word8 *digest=NULL;
	char null='\0';
	int i,j, times;
	int ret = -1;
	size_t key_buf_size;
	MHASH td;
	int block_size = mhash_get_block_size(algorithm);

	/* A non-positive block size would make the division below invalid. */
	if (block_size <= 0) return -1;
	if (key_size < 0) return -1;
	if (key_size == 0) return 0;	/* nothing to generate */
	if (keyword == NULL) return -1;

	/* Number of digest blocks needed to cover key_size bytes. */
	times = key_size/block_size;
	if (key_size%block_size != 0) times++;

	/* Computed in size_t so that rounding up cannot overflow an int. */
	key_buf_size = (size_t)times * (size_t)block_size;
	key=calloc(1, key_buf_size);
	if (key == NULL) return -1;

	for (i=0;i<times;i++) {
		td = mhash_init(algorithm);
		/* NOTE(review): failure test kept exactly as it was; confirm it
		 * matches the value mhash_init() returns on failure for the
		 * MHASH type declared in mhash.h.
		 */
		if (td<0) goto cleanup;

		/* Preload the context with i zero bytes. */
		for (j=0;j<i;j++)
			mhash(td, &null, 1);
		mhash(td, password, plen);
		digest=mhash_end(td);
		if (digest == NULL) goto cleanup;

		memmove( &key[(size_t)i * (size_t)block_size], digest, block_size);
		/* The digest is key material: wipe it before releasing it. */
		memset(digest, '\0', block_size);
		free(digest);
		digest = NULL;
	}
	memmove(keyword, key, key_size);
	ret = 0;

cleanup:
	/* Wipe the whole scratch buffer, not only the first key_size bytes:
	 * the tail of the last block is derived from the password as well.
	 */
	memset(key, '\0', key_buf_size);
	free(key);
	return ret;
}


/* Key generation using OpenPGP Salted S2K algorithm.
 *
 * The key is built from consecutive digests.  The digest for block number
 * i (counting from 0) is computed over i zero bytes, then the first 8 bytes
 * of salt, then the password.  The first key_size bytes of the concatenated
 * digests are copied into keyword.
 *
 *   algorithm - hash algorithm used to produce the digests
 *   keyword   - output buffer, must hold at least key_size bytes
 *   key_size  - number of key bytes to generate
 *   salt      - salt buffer; only its first 8 bytes are used
 *   salt_size - size of the salt buffer, must be at least 8
 *   password  - passphrase bytes
 *   plen      - length of password in bytes
 *
 * Returns 0 on success and -1 on failure (NULL or too short salt, unusable
 * algorithm block size, negative key_size, NULL keyword, allocation failure
 * or hash failure).  Nothing is written to keyword on failure.
 */
int _mhash_gen_key_s2k_salted(hashid algorithm, void *keyword, int key_size,
		  unsigned char* salt, int salt_size,
		  unsigned char *password, int plen)
{
	word8 *key;
	word8 *digest=NULL;
	char null='\0';
	int i,j, times;
	int ret = -1;
	size_t key_buf_size;
	MHASH td;
	int block_size = mhash_get_block_size(algorithm);

	if (salt==NULL) return -1;
	if (salt_size<8) return -1; /* This algorithm will use EXACTLY
				     * 8 bytes salt.
				     */
	/* A non-positive block size would make the division below invalid. */
	if (block_size <= 0) return -1;
	if (key_size < 0) return -1;
	if (key_size == 0) return 0;	/* nothing to generate */
	if (keyword == NULL) return -1;

	/* Number of digest blocks needed to cover key_size bytes. */
	times = key_size/block_size;
	if (key_size%block_size != 0) times++;

	/* Computed in size_t so that rounding up cannot overflow an int. */
	key_buf_size = (size_t)times * (size_t)block_size;
	key=calloc(1, key_buf_size);
	if (key == NULL) return -1;

	for (i=0;i<times;i++) {
		td = mhash_init(algorithm);
		/* NOTE(review): failure test kept exactly as it was; confirm it
		 * matches the value mhash_init() returns on failure for the
		 * MHASH type declared in mhash.h.
		 */
		if (td<0) goto cleanup;

		/* Preload the context with i zero bytes. */
		for (j=0;j<i;j++)
			mhash(td, &null, 1);

		mhash(td, salt, 8);
		mhash(td, password, plen);
		digest=mhash_end(td);
		if (digest == NULL) goto cleanup;

		memmove( &key[(size_t)i * (size_t)block_size], digest, block_size);
		/* The digest is key material: wipe it before releasing it. */
		memset(digest, '\0', block_size);
		free(digest);
		digest = NULL;
	}
	memmove(keyword, key, key_size);
	ret = 0;

cleanup:
	/* Wipe the whole scratch buffer, not only the first key_size bytes:
	 * the tail of the last block is derived from the password as well.
	 */
	memset(key, '\0', key_buf_size);
	free(key);
	return ret;
}

/* Key generation using OpenPGP Iterated and Salted S2K algorithm.
 *
 * The first 8 bytes of salt and the password are concatenated.  For key
 * block number i (counting from 0) a fresh hash context is fed i zero
 * bytes and then the salt+password buffer over and over until count bytes
 * have been hashed; the last repetition is truncated if needed.  When count
 * is smaller than the salt+password length, the whole buffer is still
 * hashed exactly once.  Every block hashes the same amount of data.  The
 * first key_size bytes of the concatenated digests are copied into keyword.
 *
 *   algorithm - hash algorithm used to produce the digests
 *   count     - number of bytes to hash per block (already decoded byte
 *               count, not the one-byte coded form)
 *   keyword   - output buffer, must hold at least key_size bytes
 *   key_size  - number of key bytes to generate
 *   salt      - salt buffer; only its first 8 bytes are used
 *   salt_size - size of the salt buffer, must be at least 8
 *   password  - passphrase bytes
 *   plen      - length of password in bytes
 *
 * Returns 0 on success and -1 on failure (NULL or too short salt, invalid
 * password arguments, unusable algorithm block size, negative key_size,
 * NULL keyword, allocation failure or hash failure).  Nothing is written
 * to keyword on failure.
 */
int _mhash_gen_key_s2k_isalted(hashid algorithm, unsigned long count, 
		  void *keyword, int key_size,
		  unsigned char* salt, int salt_size,
		  unsigned char *password, int plen)
{
	word8 *key=NULL;
	word8 *digest=NULL;
	char null='\0';
	int i,j, times;
	int ret = -1;
	unsigned long z, full_rounds, rest;
	size_t key_buf_size = 0;
	size_t saltpass_size;
	MHASH td;
	int block_size = mhash_get_block_size(algorithm);
	char *saltpass;

	/* Validate everything before allocating or touching salt. */
	if (salt==NULL) return -1;
	if (salt_size<8) return -1; /* This algorithm will use EXACTLY
				     * 8 bytes salt.
				     */
	if (plen < 0) return -1;
	if (password == NULL && plen > 0) return -1;
	/* A non-positive block size would make the division below invalid. */
	if (block_size <= 0) return -1;
	if (key_size < 0) return -1;
	if (key_size == 0) return 0;	/* nothing to generate */
	if (keyword == NULL) return -1;

	saltpass_size = (size_t)plen + 8;
	saltpass=calloc(1, saltpass_size);
	if (saltpass == NULL) return -1;

	memmove( saltpass, salt, 8);
	if (plen > 0)
		memmove( &saltpass[8], password, plen);

	/* Split count into whole passes over salt+password plus a partial
	 * pass.  This is computed once so that every key block hashes the
	 * same data.
	 */
	if (count < saltpass_size) {
		full_rounds = 1;
		rest = 0;
	} else {
		full_rounds = count / saltpass_size;
		rest = count % saltpass_size;
	}

	/* Number of digest blocks needed to cover key_size bytes. */
	times = key_size/block_size;
	if (key_size%block_size != 0) times++;

	/* Computed in size_t so that rounding up cannot overflow an int. */
	key_buf_size = (size_t)times * (size_t)block_size;
	key=calloc(1, key_buf_size);
	if (key == NULL) goto cleanup;

	for (i=0;i<times;i++) {
		td = mhash_init(algorithm);
		/* NOTE(review): failure test kept exactly as it was; confirm it
		 * matches the value mhash_init() returns on failure for the
		 * MHASH type declared in mhash.h.
		 */
		if (td<0) goto cleanup;

		/* Preload the context with i zero bytes. */
		for (j=0;j<i;j++)
			mhash(td, &null, 1);

		for (z=0;z<full_rounds;z++)
			mhash(td, saltpass, saltpass_size);
		if (rest > 0)
			mhash(td, saltpass, rest);

		digest=mhash_end(td);
		if (digest == NULL) goto cleanup;

		memmove( &key[(size_t)i * (size_t)block_size], digest, block_size);
		/* The digest is key material: wipe it before releasing it. */
		memset(digest, '\0', block_size);
		free(digest);
		digest = NULL;
	}
	memmove(keyword, key, key_size);
	ret = 0;

cleanup:
	/* Wipe both scratch buffers completely before releasing them. */
	if (key != NULL)
		memset(key, '\0', key_buf_size);
	memset(saltpass, '\0', saltpass_size);

	free(key);
	free(saltpass);

	return ret;
}
