/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/* Some functions that didn't fit elsewhere */

/* $Id: extra.c,v 1.19 2000/03/10 19:19:12 nikos Exp $ */

#ifndef DEFINES_H
#define DEFINES_H
#include <defines.h>
#endif

#include <ctype.h>

#include "extra.h"
#include <xmalloc.h>
#include "locks.h"
#include <errors.h>
#ifndef NO_GETPASS
# include <getpass.h>
#endif
#include <keys.h>
#include <bits.h>
#include "mcrypt_int.h"

/* RCS identification string embedded in the object file; it is not
 * referenced by any code in this file. */
static char rcsid[] =
    "$Id: extra.c,v 1.19 2000/03/10 19:19:12 nikos Exp $";

/* Program-wide settings defined in another translation unit
 * (presumably the main driver -- confirm). */
extern int double_check;	/* get_password(): if not FALSE, ask twice when not encrypting */
extern char *outfile;		/* cleanUp(): output file removed on abort */
extern int stream_flag;		/* TRUE: read header/IV from stdin; cleanUp() keeps outfile */
extern int bare_flag;		/* not referenced in this file */
extern int cleanDelete;		/* cleanUp(): outfile is only removed when TRUE */
extern int noiv;		/* header flag bit 7: no IV in use (set on read, written on write) */
extern int hash_algorithm;	/* hash id parsed from / written to the file header */

/*
 * Prompt for a passphrase twice (the second time with "Re-" prepended to
 * the prompt) and return it only if both entries match.
 *
 * prt: prompt text shown for the first entry.
 *
 * Returns a MAX_KEY_LEN-byte buffer from secure_xmalloc() holding the
 * passphrase (truncated to MAX_KEY_LEN - 1 characters, NUL-terminated),
 * or NULL if the entries differ or getpass() fails.  The buffers returned
 * by getpass() are released with secure_xfree() when the project's own
 * getpass is in use (NO_GETPASS undefined), and wiped with Bzero()
 * otherwise.
 *
 * NOTE(review): the first getpass() buffer is released with size
 * strlen(atmp) while the second uses MAX_KEY_LEN; confirm which size
 * secure_xfree() expects.
 */
char *my_getpass(char *prt)
{
	char *atmp;
	char *btmp;
	char *ztmp;
	char string_tmp[200];
	int match;

	atmp = getpass(prt);
	if (atmp == NULL)
		return NULL;

	ztmp = secure_xmalloc(MAX_KEY_LEN);
	/* Zero first so the bounded strncpy() always leaves a terminator. */
	Bzero(ztmp, MAX_KEY_LEN);
	strncpy(ztmp, atmp, MAX_KEY_LEN - 1);

#ifndef NO_GETPASS
	secure_xfree(atmp, strlen(atmp));
#else
	Bzero(atmp, strlen(atmp));
#endif

	/* Build the confirmation prompt without overrunning string_tmp,
	 * whatever the length of prt or of the translated prefix. */
	Bzero(string_tmp, sizeof(string_tmp));
	snprintf(string_tmp, sizeof(string_tmp), "%s%s", _("Re-"), prt);

	btmp = getpass(string_tmp);
	if (btmp == NULL) {
		secure_xfree(ztmp, MAX_KEY_LEN);
		return NULL;
	}

	match = (strcmp(ztmp, btmp) == 0);

#ifndef NO_GETPASS
	secure_xfree(btmp, MAX_KEY_LEN);
#else
	Bzero(btmp, strlen(btmp));
#endif

	if (!match) {
		fprintf(stderr,
			_
			("Keywords do not match or they are too long.\n"));
		secure_xfree(ztmp, MAX_KEY_LEN);
		return NULL;
	}

	return ztmp;
}

/*
 * Obtain the passphrase from the terminal.
 *
 * mode: ENCRYPT always asks twice via my_getpass(); any other mode asks
 *       once with getpass() unless the global double_check is not FALSE.
 * len:  receives strlen() of the passphrase on success.
 * numofchar, salt: not used by this function.
 *
 * Returns the passphrase, or NULL if it could not be read or the two
 * entries did not match.
 *
 * NOTE(review): the single-entry path returns getpass()'s buffer directly,
 * whereas my_getpass() returns a secure_xmalloc() buffer; callers must
 * cope with both ownership models -- confirm.
 */
char *get_password(int numofchar, int mode, unsigned int *len, void *salt)
{
	char *tmp;
	char msg[200];

	(void) numofchar;
	(void) salt;

	/* Copy the translated prompt verbatim: using it as a format string
	 * would misinterpret any '%' that a translation contains. */
	snprintf(msg, sizeof(msg), "%s", _("Enter passphrase: "));

	if (mode == ENCRYPT) {
		fprintf(stderr,
			_
			("Enter the passphrase (maximum of %d characters)\n"),
			MAX_KEY_LEN - 1);
		fprintf(stderr,
			_
			("Please use a combination of upper and lower case letters and numbers.\n"));
		tmp = my_getpass(msg);
	} else if (double_check == FALSE) {
		tmp = getpass(msg);
	} else {
		tmp = my_getpass(msg);
	}

	if (tmp == NULL)
		return NULL;

	*len = strlen(tmp);
	return tmp;
}



#ifdef HAVE_STAT

/*
 * Report whether filename can be stat()ed.
 * Returns 1 when stat() succeeds (the path exists and is reachable),
 * 0 otherwise.
 */
int check_file(char *filename)
{
	struct stat sb;

	return (stat(filename, &sb) == 0) ? 1 : 0;
}


#endif



/*
 * Read a NUL-terminated string from fstream into pointer, consuming the
 * bytes up to and including the first NUL, but never more than 100 bytes.
 *
 * pointer must have room for at least 100 bytes.  The result is always
 * NUL-terminated:
 *  - on end-of-file or a read error the string ends after the last byte
 *    that was read successfully;
 *  - if no NUL appears within 100 bytes, the 100th byte (already consumed
 *    from the stream) is replaced by NUL.
 */
void read_until_null(char *pointer, FILE * fstream)
{
	int i;

	for (i = 0; i < 100; i++) {
		if (fread(&pointer[i], 1, 1, fstream) != 1) {
			/* Truncated input: never leave a stale byte here. */
			pointer[i] = '\0';
			return;
		}
		if (pointer[i] == '\0')
			return;
	}

	/* No terminator within the limit: truncate rather than hand an
	 * unterminated buffer to callers that strcpy() it. */
	pointer[99] = '\0';
}

/*
 * Parse the mcrypt header at the start of an encrypted stream.  When the
 * global stream_flag is TRUE the header is read from stdin instead of
 * fstream.
 *
 * Layout, as emitted by write_file_head():
 *   "\0m\3"            magic bytes
 *   flags              1 byte; bits 0-5 must be clear, bit 6 = salt
 *                      present, bit 7 = no IV
 *   algorithm          NUL-terminated name
 *   key length         a short int, byte-swapped on big-endian hosts
 *   mode, keymode      NUL-terminated names
 *   sflag + salt       only when flag bit 6 is set
 *   hash name          NUL-terminated
 *
 * algorithm, mode, keymode: output buffers for the names (up to 99
 *   characters each, plus a "-<bits>" suffix on algorithm under
 *   LIBMCRYPT_22); the caller must size them accordingly.
 * keysize: receives the key length.
 * salt, salt_size: receive the salt and its length when one is present;
 *   *salt_size is left untouched otherwise.  salt must hold at least
 *   *salt_size bytes (at most sizeof tmp_buf is ever copied).
 *
 * Side effects: sets the global noiv to TRUE when flag bit 7 is set, and
 * stores the parsed hash id in the global hash_algorithm.
 *
 * Returns 0 on success, -1 for an unsupported or malformed header, and
 * 1 when no current-format header is present.
 */
int check_file_head(FILE * fstream, char *algorithm, char *mode,
		    char *keymode, int *keysize, void *salt,
		    int *salt_size)
{
	char buf[3] = { 0, 0, 0 };
	char tmp_buf[101];
	short int keylen = 0;
	unsigned char flags = 0;
	unsigned char sflag = 0;
	int bit;

	if (stream_flag == TRUE) {
		fstream = stdin;
	}

	if (fread(buf, 1, 3, fstream) != 3
	    || fread(&flags, 1, 1, fstream) != 1) {
		/* Too short to carry any mcrypt header. */
		err_crit(_
			 ("Unable to get algorithm information. Use the --bare flag and specify the algorithm manualy.\n"));
		return 1;
	}

	if (buf[0] != '\0' || buf[1] != 'm' || buf[2] != '\3') {
		/* No current-format header present */
		if (buf[0] == '\0' && buf[1] == 'm' && buf[2] == '\2') {
			err_crit(_
				 ("This is a file encrypted with the 2.2 version of mcrypt. Unfortunately you'll need that version to decrypt it.\n"));
			return 1;
		}
		if (buf[0] == '\0' && buf[1] == 'm' && buf[2] == '\1') {
			err_crit(_
				 ("This is a file encrypted with the 2.1 version of mcrypt. Unfortunately you'll need that version to decrypt it.\n"));
			return 1;
		}
		err_crit(_
			 ("Unable to get algorithm information. Use the --bare flag and specify the algorithm manualy.\n"));
		return 1;
	}

	/* Bits 0-5 of the flag byte must be clear; a set bit marks a header
	 * feature this version does not understand. */
	for (bit = 0; bit <= 5; bit++) {
		if (m_getbit(bit, flags) != 0) {
			err_crit(_
				 ("Unsupported version of encrypted file\n"));
			return -1;
		}
	}

	if (m_getbit(7, flags) != 0) {
		err_warn(_
			 ("No Initialization vector was used.\n"));
		noiv = TRUE;	/* No iv is being used */
	}

	read_until_null(tmp_buf, fstream);
	strcpy(algorithm, tmp_buf);

	/* keylen stays 0 if the header is truncated here. */
	fread(&keylen, sizeof(short int), 1, fstream);
#ifdef WORDS_BIGENDIAN
	*keysize = byteswap16(keylen);
#else
	*keysize = keylen;
#endif
#ifdef LIBMCRYPT_22
	if (strstr(algorithm, "twofish") != NULL
	    || strstr(algorithm, "rc6") != NULL
	    || strstr(algorithm, "blowfish") != NULL
	    || strstr(algorithm, "rijndael") != NULL
	    || strstr(algorithm, "rc2") != NULL
	    || strstr(algorithm, "serpent") != NULL) {

		snprintf(tmp_buf, sizeof(tmp_buf), "-%d", *keysize * 8);
		strcat(algorithm, tmp_buf);
	}

	if (strstr(algorithm, "safer-sk64") != NULL) {
		strcpy(algorithm, "safer-64");
	}
	if (strstr(algorithm, "arcfour") != NULL) {
		strcpy(algorithm, "rc4");
	}
	if (strstr(algorithm, "safer-sk128") != NULL) {
		strcpy(algorithm, "safer-128");
	}
#endif

	read_until_null(tmp_buf, fstream);
	strcpy(mode, tmp_buf);

	read_until_null(tmp_buf, fstream);
	strcpy(keymode, tmp_buf);

	/* write_file_head() emits the salt-flag byte (and the salt) only
	 * when flag bit 6 is set, so consume it only in that case; reading
	 * it unconditionally would swallow the first byte of the hash name. */
	if (m_getbit(6, flags) != 0) {
		if (fread(&sflag, 1, 1, fstream) != 1) {
			err_crit(_("Invalid salt in the file header.\n"));
			return -1;
		}
		if (m_getbit(0, sflag) != 0) {
			*salt_size = m_setbit(0, sflag, 0);
			/* The size comes straight from the file: refuse
			 * anything that would overflow tmp_buf. */
			if (*salt_size > (int) sizeof(tmp_buf)) {
				err_crit(_("Invalid salt in the file header.\n"));
				return -1;
			}
			if (*salt_size > 0) {
				if (fread(tmp_buf, 1, *salt_size, fstream) !=
				    (size_t) *salt_size) {
					err_crit(_("Invalid salt in the file header.\n"));
					return -1;
				}
				memmove(salt, tmp_buf, *salt_size);
			}
		}
	}

	/* Hash algorithm name, mapped to an id by check_hash_algo(). */
	read_until_null(tmp_buf, fstream);
	hash_algorithm = check_hash_algo(tmp_buf);

	return 0;
}


/*
 * Read an ivsize-byte initialization vector from fstream (from stdin when
 * the global stream_flag is TRUE) into a newly xmalloc()ed buffer.
 *
 * If the input ends early, the bytes that could not be read are
 * zero-filled, so the buffer never contains uninitialized memory.
 * Returns the buffer; releasing it is the caller's responsibility.
 */
void *read_iv(FILE * fstream, int ivsize)
{
	char *IV;
	size_t got;

	if (stream_flag == TRUE)
		fstream = stdin;

	IV = xmalloc(ivsize);
	got = fread(IV, 1, ivsize, fstream);
	if (ivsize > 0 && got < (size_t) ivsize)
		Bzero(IV + got, (size_t) ivsize - got);

	return IV;
}

/*
 * Lower-case the first `size` bytes of str in place; bytes that are not
 * uppercase letters are left unchanged.
 *
 * Uses tolower() on an unsigned char value.  The XSI _tolower() macro
 * used previously is only defined for uppercase arguments (on some libcs
 * it shifts digits and punctuation as well), and passing a negative
 * `char` to either function is undefined.
 *
 * NOTE(review): identifiers starting with '_' are reserved at file scope;
 * the name is kept because it is part of this file's interface.
 */
void _tolow(char *str, int size)
{
	int i;

	for (i = 0; i < size; i++)
		str[i] = (char) tolower((unsigned char) str[i]);
}


/*
 * Write `len` bytes of `data` to `filedes`.
 * Returns 0 on success and 1 on a short write, matching the error
 * convention of write_file_head().
 */
static int write_head_bytes(FILE * filedes, const void *data, size_t len)
{
	return (fwrite(data, 1, len, filedes) == len) ? 0 : 1;
}

/*
 * Write the mcrypt header that check_file_head() reads back:
 *   "\0m\3", a flag byte (bit 6: salt present, bit 7: global noiv is
 *   TRUE), the NUL-terminated algorithm name, the key length (a short
 *   int, byte-swapped on big-endian hosts), the NUL-terminated mode and
 *   keymode names, the salt-flag byte plus salt_size bytes of salt when
 *   salt is non-NULL, and the NUL-terminated, lower-cased name of the
 *   global hash_algorithm.
 *
 * Under LIBMCRYPT_22 the algorithm name is translated to the old naming
 * scheme; note that this may truncate the caller's `algorithm` string in
 * place at its last '-'.
 *
 * NOTE(review): salt_size is stored in one byte and m_setbit(0, ...)
 * forces one of its bits on (check_file_head() clears it again), so salt
 * sizes that use that bit, or exceed 255, do not round-trip -- confirm
 * callers only pass compatible sizes.
 *
 * Returns 0 on success, 1 if any write fails or the hash name cannot be
 * obtained.
 */
int write_file_head(FILE * filedes, char *algorithm, char *mode,
		    char *keymode, int *keysize, void *salt, int salt_size)
{
	unsigned char header[4];
	short int keylen = *keysize;
	unsigned char null = 0;
	unsigned char sflag = 0;
	char *hash_name;
	int failed;
#ifdef LIBMCRYPT_22
	char *dash;
	char algo[70];
#endif

	header[0] = '\0';
	header[1] = 'm';
	header[2] = '\3';
	header[3] = '\0';	/* flags not yet fully supported */

	if (salt != NULL)
		header[3] = m_setbit(6, header[3], 1);
	if (noiv == TRUE)
		header[3] = m_setbit(7, header[3], 1);

	if (write_head_bytes(filedes, header, sizeof(header)) != 0)
		return 1;

#ifdef LIBMCRYPT_22
	if (strstr(algorithm, "cast") == NULL
	    && strstr(algorithm, "safer-") == NULL) {
		dash = strrchr(algorithm, '-');
		if (dash != NULL)
			*dash = '\0';
	}
	if (strstr(algorithm, "rc4") != NULL) {
		strcpy(algo, "arcfour");
		algorithm = algo;
	}
	if (strstr(algorithm, "safer-") != NULL) {
		dash = strrchr(algorithm, '-');
		snprintf(algo, sizeof(algo), "safer-sk%s",
			 dash != NULL ? dash + 1 : "");
		algorithm = algo;
	}
#endif
	if (write_head_bytes(filedes, algorithm, strlen(algorithm)) != 0)
		return 1;
#ifdef LIBMCRYPT_22
	if (strstr(algorithm, "rijndael") != NULL
	    && write_head_bytes(filedes, "-128", 4) != 0)
		return 1;
#endif
	if (write_head_bytes(filedes, &null, 1) != 0)
		return 1;

#ifdef WORDS_BIGENDIAN
	keylen = byteswap16(keylen);
#endif
	if (write_head_bytes(filedes, &keylen, sizeof(short int)) != 0)
		return 1;

	if (write_head_bytes(filedes, mode, strlen(mode)) != 0
	    || write_head_bytes(filedes, &null, 1) != 0)
		return 1;

	if (write_head_bytes(filedes, keymode, strlen(keymode)) != 0
	    || write_head_bytes(filedes, &null, 1) != 0)
		return 1;

	if (salt != NULL) {
		sflag = salt_size;
		sflag = m_setbit(0, sflag, 1);
		if (write_head_bytes(filedes, &sflag, 1) != 0
		    || write_head_bytes(filedes, salt, salt_size) != 0)
			return 1;
	}

	hash_name = mhash_get_hash_name(hash_algorithm);
	if (hash_name == NULL)
		return 1;
	_tolow(hash_name, strlen(hash_name));
	failed = write_head_bytes(filedes, hash_name, strlen(hash_name)) != 0
	    || write_head_bytes(filedes, &null, 1) != 0;
	/* Release the name on both the success and the failure path. */
	xfree(hash_name);

	return failed ? 1 : 0;
}


/*
 * Write an ivsize-byte initialization vector to filedes.  When IV is NULL
 * a block of zero bytes is written instead; nothing is written when
 * ivsize <= 0.
 *
 * Returns 0 on success, 1 on a short write.  The temporary copy is
 * released on both paths.
 */
int write_iv(FILE * filedes, void *IV, int ivsize)
{
	unsigned char *buf;
	int ret = 0;

	if (ivsize <= 0)
		return 0;

	buf = xmalloc(ivsize);
	if (IV != NULL)
		memmove(buf, IV, ivsize);
	else
		Bzero(buf, ivsize);

	if (fwrite(buf, 1, ivsize, filedes) != (size_t) ivsize)
		ret = 1;

	xfree(buf);
	return ret;
}

#ifdef HAVE_STAT
#ifdef HAVE_UTIME
/*
 * Copy the access and modification times of srcName onto dstName.
 *
 * Failures are reported with perror() and otherwise ignored.  If srcName
 * cannot be stat()ed, dstName is left untouched: there are no valid times
 * to copy, and the previous code applied uninitialized values.
 */
void copyDate(char *srcName, char *dstName)
{
	struct stat statBuf;
	struct utimbuf uTimBuf;

	if (stat(srcName, &statBuf) == -1) {
		perror("stat");
		return;
	}

	uTimBuf.actime = statBuf.st_atime;
	uTimBuf.modtime = statBuf.st_mtime;

	if (utime(dstName, &uTimBuf) == -1)
		perror("utime");
}
#endif


/*
 * Return TRUE if filename names a regular file, FALSE otherwise (including
 * when it cannot be examined).  When lstat() is available, a symbolic link
 * is inspected itself rather than followed, so a link to a regular file
 * is not reported as regular.
 */
int is_normal_file(char *filename)
{
	struct stat st;
	int rc;

#ifdef HAVE_LSTAT
	rc = lstat(filename, &st);
#else
	rc = stat(filename, &st);
#endif
	if (rc != 0)
		return FALSE;

	return S_ISREG(st.st_mode) ? TRUE : FALSE;
}

#endif

/*
 * Fatal-signal handler: report the signal, run cleanUp() unless the
 * signal is SIGSEGV (presumably because process state may be corrupt
 * then), and exit with status -1.
 *
 * NOTE(review): fprintf(), exit() and the stdio/remove calls in cleanUp()
 * are not async-signal-safe.
 */
void shandler(int signo)
{
	fprintf(stderr, _("Signal %d caught. Exiting.\n"), signo);

	if (signo == SIGSEGV) {
		exit(-1);
	}

	cleanUp();
	exit(-1);
}

/*
 * Signal handler installed while ask_overwrite() waits for an answer.
 * It starts its message on a new line, since the prompt has no trailing
 * newline, then always runs cleanUp() and exits with status -1.
 *
 * NOTE(review): fprintf(), exit() and cleanUp() are not async-signal-safe.
 */
void snhandler(int signo)
{
	const int status = -1;

	fprintf(stderr, _("\nSignal %d caught. Exiting.\n"), signo);
	cleanUp();
	exit(status);
}

/*
 * Flush every open output stream.  Then, when writing to a real file
 * (stream_flag is FALSE) that may be discarded (cleanDelete is TRUE),
 * remove the partially written outfile.
 */
void cleanUp()
{
	fflush(NULL);

	if (stream_flag != FALSE || cleanDelete != TRUE)
		return;

	remove(outfile);
}


/*
 * Read keys from `file`, one per line (newline stripped).  fgets() reads
 * at most MAX_KEY_LEN - 1 characters at a time, so longer lines are split
 * across several entries.
 *
 * The file is read-locked while it is read and is always closed before
 * returning.  The stack copy of the last line is wiped afterwards.
 *
 * Returns an array of *num strings (each from xmalloc(), the array from
 * realloc()), or NULL if the file is empty, cannot be opened or locked,
 * or memory runs out.  *num is only written on the paths that read the
 * file.
 */
char **read_key_file(char *file, int *num)
{
	FILE *FROMF;
	char keyword[MAX_KEY_LEN];
	char **keys = NULL;
	char **grown;
	size_t len;
	int x = 0;
	int i;

	FROMF = fopen(file, "r");
	if (FROMF == NULL) {
		fprintf(stderr,
			_("Keyfile could not be opened. Ignoring it.\n"));
		return NULL;
	}
	if (read_lock(fileno(FROMF)) == -1) {
		fprintf(stderr,
			_("Keyfile could not be locked. Ignoring it.\n"));
		fclose(FROMF);
		return NULL;
	}

	while (fgets(keyword, sizeof(keyword), FROMF) != NULL) {
		grown = realloc(keys, (x + 1) * sizeof(char *));
		if (grown == NULL) {
			fprintf(stderr,
				_("Out of memory while reading the keyfile. Ignoring it.\n"));
			for (i = 0; i < x; i++) {
				Bzero(keys[i], strlen(keys[i]));
				xfree(keys[i]);
			}
			free(keys);
			keys = NULL;
			x = 0;
			break;
		}
		keys = grown;

		len = strlen(keyword);
		/* Remove newline; len can be 0 if the line starts with NUL. */
		if (len > 0 && keyword[len - 1] == '\n')
			keyword[--len] = '\0';

		keys[x] = xmalloc(len + 1);
		memcpy(keys[x], keyword, len + 1);
		x++;
	}

	/* Do not leave key material behind on the stack. */
	Bzero(keyword, sizeof(keyword));

	*num = x;

	unlock(fileno(FROMF));
	fclose(FROMF);
	return keys;
}

#ifdef HAVE_GETPWUID
/*
 * Build the path of a per-user file: "<home of uid>/<cfile>" when uid has
 * a passwd entry, otherwise just a copy of cfile.  The returned string is
 * heap-allocated (xmalloc/xcalloc).
 */
char *get_cfile(int uid, char *cfile)
{
	struct passwd *entry = getpwuid(uid);
	char *path;

	if (entry == NULL) {
		/* No passwd entry: fall back to the bare name. */
		path = xcalloc(1, strlen(cfile) + 2);
		strcat(path, cfile);
		return path;
	}

	path = xmalloc(strlen(entry->pw_dir) + strlen(cfile) + 2);
	strcpy(path, entry->pw_dir);
	strcat(path, "/");
	strcat(path, cfile);
	return path;
}
#endif


/*
 * Ask whether `file` may be overwritten.  `name` is shown as the message
 * prefix.
 *
 * The answer is read from /dev/tty when it can be opened, otherwise from
 * an unbuffered stdin.  While waiting, the termination signals are routed
 * to snhandler(); afterwards they are set to shandler().
 *
 * Returns TRUE if the first character of the answer is 'y' or 'Y', FALSE
 * otherwise (including EOF).
 */
int ask_overwrite(char *name, char *file)
{
	int answer;
	int ch;
	int tty_opened = 0;
	FILE *fp;

#ifdef HAVE_SIGNAL_H
	Signal(SIGINT, snhandler);
	Signal(SIGQUIT, snhandler);
	Signal(SIGSEGV, snhandler);
	Signal(SIGPIPE, snhandler);
	Signal(SIGTERM, snhandler);
	Signal(SIGHUP, snhandler);
#endif

	fprintf(stderr,
		_
		("%s: %s already exists; do you wish to overwrite (y or n)?"),
		name, file);

	if ((fp = fopen("/dev/tty", "r")) == NULL) {
		fp = stdin;
		setbuf(fp, NULL);
	} else {
		tty_opened = 1;
	}

	answer = fgetc(fp);
	/* Consume the rest of the answer line.  On the unbuffered stdin
	 * fallback the trailing newline would otherwise be read as the
	 * answer to the next prompt. */
	ch = answer;
	while (ch != EOF && ch != '\n')
		ch = fgetc(fp);

	if (tty_opened != 0)
		fclose(fp);

#ifdef HAVE_SIGNAL_H
	Signal(SIGINT, shandler);
	Signal(SIGQUIT, shandler);
	Signal(SIGSEGV, shandler);
	Signal(SIGPIPE, shandler);
	Signal(SIGTERM, shandler);
	Signal(SIGHUP, shandler);
#endif

	return (answer == 'y' || answer == 'Y') ? TRUE : FALSE;
}


/*
 * Set the first n bytes of s to zero.  The implementation is chosen by
 * configuration: memset(), bzero(), or a portable byte loop.
 *
 * This file uses it to wipe passphrase buffers.  The portable loop
 * therefore writes through a volatile pointer so the stores are not
 * discarded as dead, and it indexes with size_t so any n is handled
 * without signed overflow.
 */
void Bzero(void *s, size_t n)
{
#ifdef HAVE_MEMSET
	memset(s, '\0', n);
#else
# ifdef HAVE_BZERO
	bzero(s, n);
# else
	volatile unsigned char *p = s;
	size_t i;

	for (i = 0; i < n; i++)
		p[i] = 0;
# endif
#endif
}
