#include <stdio.h>
#include <unistd.h>

#include "options.h"
#include "util.h"
#include "session.h"
#include "buffer.h"
#include "ssh.h"
#include "packet.h"
#include "auth.h"
#include "authpasswd.h"
#include "authpubkey.h"
#include "pwd.h"

/* Forward declarations for the file-local helpers defined further down. */
/* Resets the per-user fields of ses.authstate. */
static void authclear();
/* Looks up and validates the requested user; returns CHECK_USER_RETURN or
 * CHECK_USER_CONTINUE. */
static int checkusername(unsigned char *username, unsigned int userlen);
/* Sends the banner held in ses.opts->banner, if there is one. */
static void send_msg_userauth_banner();

/* Set up ses.authstate for a fresh session: the per-user fields are reset
 * via authclear() and the failure counter starts from zero. */
void authinitialise() {

	/* authclear() does not touch failcount, so the counter is reset
	 * separately here */
	authclear();
	ses.authstate.failcount = 0;

}

/* Reset the per-user authentication state: no passwd entry, auth not yet
 * done, and every auth method that was compiled in is allowed again.
 * The failure counter is deliberately left alone. */
static void authclear() {

	/* collect the compiled-in methods first, then store them in one go */
	int allowed = 0;

#ifdef DROPBEAR_PUBKEY_AUTH
	allowed |= AUTH_TYPE_PUBKEY;
#endif
#ifdef DROPBEAR_PASSWORD_AUTH
	allowed |= AUTH_TYPE_PASSWORD;
#endif

	ses.authstate.authtypes = allowed;
	ses.authstate.pw = NULL;
	ses.authstate.authdone = 0;

}

/* Send the configured banner to the client as SSH_MSG_USERAUTH_BANNER
 * (banner text followed by the language tag "en"), then release it.
 *
 * Does nothing if ses.opts->banner is NULL. After sending, the banner
 * buffer is freed and the pointer is cleared, so the banner is sent at
 * most once and later calls cannot touch freed memory. */
static void send_msg_userauth_banner() {

	TRACE(("enter send_msg_userauth_banner"));
	if (ses.opts->banner == NULL) {
		TRACE(("leave send_msg_userauth_banner: banner is NULL"));
		return;
	}

	CHECKCLEARTOWRITE();

	buf_putbyte(ses.writepayload, SSH_MSG_USERAUTH_BANNER);
	/* the string length must be the number of bytes actually stored (len),
	 * matching the length the pointer was requested for - not size */
	buf_putstring(ses.writepayload, buf_getptr(ses.opts->banner,
				ses.opts->banner->len), ses.opts->banner->len);
	buf_putstring(ses.writepayload, "en", 2);

	encrypt_packet();
	buf_free(ses.opts->banner);
	/* clear the pointer so that the next auth request doesn't try to send
	 * (and free) the banner a second time */
	ses.opts->banner = NULL;

	TRACE(("leave send_msg_userauth_banner"));
}

/* Handle an incoming userauth request held in ses.payload.
 *
 * The payload is read as three strings: username, service name and method
 * name. Requests are ignored once ses.authstate.authdone is set. Any
 * service other than SSH_SERVICE_CONNECTION terminates via dropbear_exit().
 * Otherwise the request is dispatched on the method name:
 *   - "none"      -> failure reply that is not counted as a failed attempt
 *   - bad user    -> counted failure reply
 *   - password / publickey (when compiled in) -> the matching handler
 *   - anything else -> counted failure reply
 * The three strings are freed before returning. */
void recv_msg_userauth_request() {

	unsigned char *username, *servicename, *methodname;
	unsigned int userlen, servicelen, methodlen;

	TRACE(("enter recv_msg_userauth_request"));

	/* ignore packets if auth is already done */
	if (ses.authstate.authdone == 1) {
		return;
	}

	/* send the banner if one is set; send_msg_userauth_banner() frees it
	 * once it has been sent */
	if (ses.opts->banner) {
		send_msg_userauth_banner();
	}

	
	username = buf_getstring(ses.payload, &userlen);
	servicename = buf_getstring(ses.payload, &servicelen);
	methodname = buf_getstring(ses.payload, &methodlen);

	/* only handle 'ssh-connection' currently.
	 * The name has to match in length AND content, so the request is
	 * rejected as soon as either one differs. */
	if (servicelen != SSH_SERVICE_CONNECTION_LEN
			|| (strncmp((const char*)servicename, SSH_SERVICE_CONNECTION,
					SSH_SERVICE_CONNECTION_LEN) != 0)) {
		
		/* TODO - disconnect here */
		m_free(username);
		m_free(servicename);
		m_free(methodname);
		dropbear_exit("unknown service in auth");
	}

	if (methodlen == AUTH_METHOD_NONE_LEN &&
			strncmp((const char*)methodname, AUTH_METHOD_NONE,
				AUTH_METHOD_NONE_LEN) == 0) {
		/* 'none' is how clients query the allowed methods - don't count
		 * it as a failed attempt */
		send_msg_userauth_failure(0, 0);
	} else if (checkusername(username, userlen) == CHECK_USER_RETURN) {
		/* username is invalid/no shell/etc - send failure */
		TRACE(("sending checkusername failure"));
		send_msg_userauth_failure(0, 1);
#ifdef DROPBEAR_PASSWORD_AUTH
	} else if (methodlen == AUTH_METHOD_PASSWORD_LEN &&
			strncmp((const char*)methodname, AUTH_METHOD_PASSWORD,
				AUTH_METHOD_PASSWORD_LEN) == 0) {
		passwordauth(username, userlen);
#endif /* DROPBEAR_PASSWORD_AUTH */
#ifdef DROPBEAR_PUBKEY_AUTH
	} else if (methodlen == AUTH_METHOD_PUBKEY_LEN &&
			strncmp((const char*)methodname, AUTH_METHOD_PUBKEY,
				AUTH_METHOD_PUBKEY_LEN) == 0) {
		pubkeyauth(username, userlen);
#endif /* DROPBEAR_PUBKEY_AUTH */
	} else {
		/* unknown or disabled method */
		send_msg_userauth_failure(0, 1);
	}

	m_free(username);
	m_free(servicename);
	m_free(methodname);
}

/* Check whether the requested user may attempt to authenticate.
 *
 * username/userlen: the name taken from the auth request.
 *
 * If no user has been looked up yet, or the name differs from the one
 * currently held in ses.authstate.pw, the auth state is reset with
 * authclear() and the user is looked up again with getpwnam().
 *
 * Returns CHECK_USER_RETURN if the name is longer than MAX_USERNAME_LEN,
 * the user doesn't exist, the passwd entry has an empty password field, or
 * the user's shell isn't listed by getusershell(). An empty shell field is
 * accepted. Returns CHECK_USER_CONTINUE otherwise.
 *
 * This function never sends a reply itself: on CHECK_USER_RETURN the caller
 * sends the (single) failure message, so that each request gets exactly one
 * response and counts once towards the failure limit. */
static int checkusername(unsigned char *username, unsigned int userlen) {

	char* shell;
	
	TRACE(("enter checkusername"));
	if (userlen > MAX_USERNAME_LEN) {
		return CHECK_USER_RETURN;
	}

	/* new user or username has changed */
	if (ses.authstate.pw == NULL ||
		strcmp((char*)username, ses.authstate.pw->pw_name) != 0) {
			/* the username needs resetting */
			authclear();
			ses.authstate.pw = getpwnam((char*)username);
	}

	/* check that user exists */
	if (ses.authstate.pw == NULL) {
		TRACE(("leave checkusername: user doesn't exist"));
		return CHECK_USER_RETURN;
	}

	/* check for an empty password */
	if (ses.authstate.pw->pw_passwd[0] == '\0') {
		TRACE(("leave checkusername: empty pword"));
		return CHECK_USER_RETURN;
	}

	TRACE(("shell is %s", ses.authstate.pw->pw_shell));
	/* check that the shell is valid */
	/* XXX - todo check this is correct: empty shell is ok */
	if (ses.authstate.pw->pw_shell[0] == '\0') {
		goto goodshell;
	}
	/* walk the system's list of permitted login shells from the start */
	setusershell();
	while ((shell = getusershell()) != NULL) {
		TRACE(("test shell is '%s'", shell));
		if (strcmp(shell, ses.authstate.pw->pw_shell) == 0) {
			/* have a match */
			goto goodshell;
		}
	}
	/* no matching shell */
	endusershell();
	TRACE(("no matching shell"));
	return CHECK_USER_RETURN;
	
goodshell:
	endusershell();
	TRACE(("matching shell"));

	TRACE(("uid = %d\n", ses.authstate.pw->pw_uid));
	TRACE(("leave checkusername"));
	return CHECK_USER_CONTINUE;

}

/* Send SSH_MSG_USERAUTH_FAILURE, listing the auth methods still allowed
 * in ses.authstate.authtypes.
 *
 * partial:  non-zero sets the "partial success" flag in the reply.
 * incrfail: non-zero counts this failure towards the failure count (which
 *           is limited) and sleeps for FAIL_SLEEP_TIME afterwards.
 *
 * Once the failure count has reached MAX_AUTH_TRIES the process exits. */
void send_msg_userauth_failure(int partial, int incrfail) {

	buffer *methods;
	int pubkey_ok, password_ok;

	TRACE(("enter send_msg_userauth_failure"));

	CHECKCLEARTOWRITE();
	
	buf_putbyte(ses.writepayload, SSH_MSG_USERAUTH_FAILURE);

	/* build the comma separated list of allowed methods in a scratch
	 * buffer; 30 bytes is long enough for PUBKEY and PASSWORD */
	methods = buf_new(30);
	pubkey_ok = (ses.authstate.authtypes & AUTH_TYPE_PUBKEY) != 0;
	password_ok = (ses.authstate.authtypes & AUTH_TYPE_PASSWORD) != 0;

	if (pubkey_ok) {
		buf_putbytes(methods, AUTH_METHOD_PUBKEY, AUTH_METHOD_PUBKEY_LEN);
	}
	/* a separator is only wanted when both names are present */
	if (pubkey_ok && password_ok) {
		buf_putbyte(methods, ',');
	}
	if (password_ok) {
		buf_putbytes(methods, AUTH_METHOD_PASSWORD, AUTH_METHOD_PASSWORD_LEN);
	}

	/* rewind so that the whole list can be copied out as one string */
	buf_setpos(methods, 0);
	buf_putstring(ses.writepayload, buf_getptr(methods, methods->len),
			methods->len);
	buf_free(methods);

	buf_putbyte(ses.writepayload, partial != 0);
	encrypt_packet();

	if (incrfail) {
		ses.authstate.failcount++;
		sleep(FAIL_SLEEP_TIME);
	}

	if (ses.authstate.failcount >= MAX_AUTH_TRIES) {
		/* XXX - send disconnect ? */
		TRACE(("Max auth tries reached, exiting"));
		exit(0);
	}
	
	TRACE(("leave send_msg_userauth_failure"));
}

/* Tell the client that authentication succeeded: sends
 * SSH_MSG_USERAUTH_SUCCESS, marks ses.authstate.authdone and closes
 * ses.childpipe.
 *
 * only to be called when writebuf is empty */
void send_msg_userauth_success() {

	TRACE(("enter send_msg_userauth_success"));

	CHECKCLEARTOWRITE();

	/* the success message consists of nothing but its type byte */
	buf_putbyte(ses.writepayload, SSH_MSG_USERAUTH_SUCCESS);
	encrypt_packet();

	/* from here on recv_msg_userauth_request() ignores further requests */
	ses.authstate.authdone = 1;

	close(ses.childpipe);

	TRACE(("leave send_msg_userauth_success"));

}
