/* $Id: ssh_v2_transport.c,v 1.7 2001/02/11 03:35:32 tls Exp $ */

/*
 * Copyright (c) 2001 Eric Haszlakiewicz.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 * 3. The name of the author may not be used to endorse or promote products
 *    derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED
 * AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 */

#include <stdio.h>
#include <errno.h>
#include <string.h>

#include "sshd.h"
#include "ssh_buffer.h"
#include "ssh_cipher.h"
#include "ssh_event.h"
#include "ssh_mac.h"
#include "ssh_v2_messages.h"
#include "ssh_transport.h"
#include "ssh_util.h"

/*
 * Forward declarations of the file-local SSH v2 transport routines.
 * v2ssh_init_transport() installs each of them as a callback in the
 * struct ssh_transport it initializes.
 */
static int v2ssh_read_packet(ssh_context_t *, struct ssh_buf *);
static int v2ssh_xmit_data(ssh_context_t *, u_int8_t, const u_int8_t *, size_t);
static int v2ssh_xmit_packet(ssh_context_t *, const struct ssh_buf *);
static int v2ssh_get_sendsize(struct ssh_transport *);
static int v2ssh_send_kexinit(ssh_context_t *context, struct ssh_buf *bufp);
static int v2ssh_dispatch_transport(ssh_context_t *, struct ssh_buf *, size_t, int);

/*
 * Key exchange methods.  Each enumerator is also the index of the
 * method's name in kex_algo_map[]; the table is terminated by NULL.
 */
enum kex_algo
{
	SSH_V2_KEX_DIFFIE_HELLMAN
};
const char *kex_algo_map[] =
{
	[SSH_V2_KEX_DIFFIE_HELLMAN] = "diffie-hellman-group1-sha1",
	NULL	// end of table
};

/*
 * Server host key formats.  Each enumerator is also the index of the
 * format's name in shk_algo_map[]; the table is terminated by NULL.
 */
enum shk_algo
{
	SSH_V2_SHK_SSH_DSS,
	SSH_V2_SHK_X509V3,
	SSH_V2_SHK_SPKI,
	SSH_V2_SHK_PGP
};
const char *shk_algo_map[] = {
	[SSH_V2_SHK_SSH_DSS] = "ssh-dss",
	[SSH_V2_SHK_X509V3]  = "x509v3",
	[SSH_V2_SHK_SPKI]    = "spki",
	[SSH_V2_SHK_PGP]     = "pgp",
	NULL	// end of table
};

/*
 * Encryption algorithms.  Each enumerator is also the index of the
 * algorithm's name in encr_algo_map[]; the table is terminated by NULL.
 */
enum encr_algo {
	SSH_V2_ENCR_3DES_CBC,
	SSH_V2_ENCR_BLOWFISH_CBC,
	SSH_V2_ENCR_TWOFISH_CBC,
	SSH_V2_ENCR_ARCFOUR,
	SSH_V2_ENCR_IDEA_CBC,
	SSH_V2_ENCR_CAST128_CBC,
	SSH_V2_ENCR_NONE
};

const char *encr_algo_map[] = {
	[SSH_V2_ENCR_3DES_CBC]     = "3des-cbc",
	[SSH_V2_ENCR_BLOWFISH_CBC] = "blowfish-cbc",
	[SSH_V2_ENCR_TWOFISH_CBC]  = "twofish-cbc",
	[SSH_V2_ENCR_ARCFOUR]      = "arcfour",
	[SSH_V2_ENCR_IDEA_CBC]     = "idea-cbc",
	[SSH_V2_ENCR_CAST128_CBC]  = "cast128-cbc",
	[SSH_V2_ENCR_NONE]         = "none",
	NULL	// end of table
};

/*
 * MAC algorithms.  Each enumerator is also the index of the
 * algorithm's name in mac_algo_map[]; the table is terminated by NULL.
 */
enum mac_algo {
	SSH_V2_HMAC_SHA1,
	SSH_V2_HMAC_SHA1_96,
	SSH_V2_HMAC_MD5,
	SSH_V2_HMAC_MD5_96,
	SSH_V2_HMAC_NONE
};
const char *mac_algo_map[] = {
	[SSH_V2_HMAC_SHA1]    = "hmac-sha1",
	[SSH_V2_HMAC_SHA1_96] = "hmac-sha1-96",
	[SSH_V2_HMAC_MD5]     = "hmac-md5",
	[SSH_V2_HMAC_MD5_96]  = "hmac-md5-96",
	[SSH_V2_HMAC_NONE]    = "none",
	NULL	// end of table
};
/*
 * Compression algorithms.  Each enumerator is also the index of the
 * algorithm's name in comp_algo_map[]; the table is terminated by NULL.
 */
enum comp_algo {
	SSH_V2_COMP_ZLIB
};
const char *comp_algo_map[] = {
	[SSH_V2_COMP_ZLIB] = "zlib",
	NULL	// end of table
};

/*
 * v2ssh_init_transport
 *    Set up a transport structure for SSH protocol version 2.
 *
 *    Everything in *xport except the communication socket is cleared,
 *    the v2 callbacks are installed, the packet size limits are set and
 *    the two working buffers (encrypted and cleartext) are allocated.
 *
 *    isserver is not used.
 *
 * Returns:
 *     0  on success
 *    -1  if either working buffer could not be allocated
 */
int v2ssh_init_transport(struct ssh_transport *xport, int isserver)
{
	// Preserve the socket across the wipe of the structure.
	int s = xport->commsock;
	struct ssh_v2_transport *v2xport;

	memset(xport, 0, sizeof(*xport));
	xport->commsock = s;
	v2xport = &xport->t_v2;

	xport->kexinit = v2ssh_send_kexinit;
	xport->dispatch_transport = v2ssh_dispatch_transport;
	xport->init_compression = NULL; 	// XXX XAX might be version specific
	xport->init_decompress = NULL;		// XXX XAX might be version specific
	xport->read_packet = v2ssh_read_packet;
	xport->xmit_data = v2ssh_xmit_data;
	xport->xmit_packet = v2ssh_xmit_packet;
	xport->set_max_packet_size = NULL;	// XXX XAX might be version specific
	xport->get_sendsize = v2ssh_get_sendsize;

	v2xport->max_payload_size = 32768;	// XXX Can be higher.
	v2xport->max_packet_size = 35000; // XXX Can be higher.

	v2xport->sequence_recv = 0;
	v2xport->sequence_send = 0;
	v2xport->mac_send = NULL;
	v2xport->mac_recv = NULL;

	v2xport->enc_buf = buf_alloc(NULL, 2048);
	v2xport->clr_buf = buf_alloc(NULL, 2048);
	if (v2xport->enc_buf == NULL || v2xport->clr_buf == NULL)
	{
		// The packet routines dereference both buffers unconditionally,
		// so report the failure instead of handing back a transport
		// that will crash on first use.
		// NOTE(review): a buffer that was obtained is left in *xport;
		// confirm the caller's cleanup path releases it.
		return(-1);
	}
	return(0);
}

/*
 * v2ssh_send_kexinit
 *    This should be called whether we're in client mode or server mode.
 *    We assume that we're not sending an immediate guess packet.
 */
/*
 * v2ssh_mklist
 *    Produce the algorithm name-list string for one field of the
 *    KEXINIT message.
 *
 *    Placeholder implementation: the first entry of tagmap is returned
 *    as-is; context and tags are not consulted.
 *
 *    XAX still to do:
 *      - figure out the preferred mode and insert it into the string,
 *      - then, for every algorithm that is available and not already
 *        used, mark it used and insert it into the string,
 *      - return the string.
 */
static const char *v2ssh_mklist(ssh_context_t *context, int *tags, const char *tagmap[]);
static const char *v2ssh_mklist(ssh_context_t *context, int *tags, const char *tagmap[])
{
	const char *first_name = *tagmap;

	return first_name;
}

/*
 * Build an SSH_V2_MSG_KEXINIT message and pass it to ssh_sys_sendevent()
 * as an SSH_EVENT_SEND event.
 *
 * Message contents, in order:
 *    byte       SSH_V2_MSG_KEXINIT
 *    byte[16]   random cookie (also stored in context->v2_ctx.cookie)
 *    string x10 name-lists: kex, server host key, encryption c2s/s2c,
 *               mac c2s/s2c, compression c2s/s2c, languages c2s/s2c
 *    byte       first_key_packet_follows (false)
 *    uint32     reserved (0)
 *
 * bufp is not used.
 *
 * Returns:
 *     0          on success
 *     RET_FATAL  if the message buffer could not be allocated
 *     -1         if the event could not be sent
 */
static int v2ssh_send_kexinit(ssh_context_t *context, struct ssh_buf *bufp)
{
	enum { KEXINIT_NAMELISTS = 10 };
	const char *namelist[KEXINIT_NAMELISTS];
	struct ssh_buf msg;
	struct ssh_event event;
	struct ssh_ev_send *send_data;
	int idx;

	memset(&msg, 0, sizeof(msg));
	memset(&event, 0, sizeof(event));

	// Collect the name-lists in the order they appear in the message.
	namelist[0] = v2ssh_mklist(context, context->v2_ctx.kex_algo, kex_algo_map);
	namelist[1] = v2ssh_mklist(context, context->v2_ctx.shk_algo, shk_algo_map);
	namelist[2] = v2ssh_mklist(context, context->v2_ctx.encr_algo_c2s, encr_algo_map);
	namelist[3] = v2ssh_mklist(context, context->v2_ctx.encr_algo_s2c, encr_algo_map);
	namelist[4] = v2ssh_mklist(context, context->v2_ctx.mac_algo_c2s, mac_algo_map);
	namelist[5] = v2ssh_mklist(context, context->v2_ctx.mac_algo_s2c, mac_algo_map);
	namelist[6] = v2ssh_mklist(context, context->v2_ctx.comp_algo_c2s, comp_algo_map);
	namelist[7] = v2ssh_mklist(context, context->v2_ctx.comp_algo_s2c, comp_algo_map);
	namelist[8] = "";	// languages, client to server
	namelist[9] = "";	// languages, server to client

	if (buf_alloc(&msg, 1024) == NULL)
	{
		SSH_ERROR("buf_alloc for send_kexinit failed");
		return(RET_FATAL);
	}

	buf_put_byte(&msg, SSH_V2_MSG_KEXINIT);

	// A fresh cookie is generated for every KEXINIT and kept in the context.
	ssh_rand_bytes(16, context->v2_ctx.cookie);
	buf_put_nbytes(&msg, 16, context->v2_ctx.cookie);

	for (idx = 0; idx < KEXINIT_NAMELISTS; idx++)
		buf_put_binstr(&msg, namelist[idx], strlen(namelist[idx]));

	buf_put_byte(&msg, 0);	  // first_key_packet_follows == false
	buf_put_int32(&msg, 0);   // reserved

	// Copy the finished message into the event, then release the buffer.
	event.event_type = SSH_EVENT_SEND;
	send_data = (struct ssh_ev_send *)&event.event_data;
	send_data->dlen = buf_alllen(&msg);
	memcpy(send_data->data, buf_alldata(&msg), buf_alllen(&msg));
	buf_cleanup(&msg);

	if (ssh_sys_sendevent(context, &event) >= 0)
		return(0);

	SSH_DLOG(1, ("Unable to send KEXINIT package"));
	return(-1);
}

/*
 * Reads one SSH v2 binary packet from the socket, decrypts it, checks
 * its MAC (when a receive MAC is configured) and leaves the payload
 * in buf.
 *
 * Returns:
 *     -1 or RET_FATAL on error:
 *          -1 while reading or making room for the first cipher block
 *          (errno is EPIPE on EOF) and when inflate fails; RET_FATAL for
 *          every other failure (bad lengths, read error or EOF in the body
 *          of the packet, MAC mismatch, out of memory).
 *     otherwise buf_len(buf) after the packet has been processed.
 */
static int v2ssh_read_packet(ssh_context_t *context, struct ssh_buf *buf)
{
	struct ssh_buf *decr_buf;
	struct ssh_buf *enc_buf;

	int read_len;
	int packetlen = 0, pad_len;
	int8_t pad_byte = 0;
	int amt_to_read;
	int amt_to_decrypt;
	int maclen;

	int payload_len;
	int block_size;
	struct ssh_v2_transport *xport;

	/*
	 * E  uint32  packet length (= size of encrypted portion, not incl self)
	 * E  byte    padding length
	 * E  byte[n] payload (n = pkt len - pad len - 1)
	 * E  byte[n] random pad (n = pad len)
	 * byte[m] mac, m = mac len
	 *
	 * E = encrypted portion.
	 */

	xport = &context->transport_ctx.t_v2;
	enc_buf = xport->enc_buf;

	// Set decr_buf.  This is where we decrypt into
	// and where we copy/decompress out of.
	// If we're not compressing, we can decrypt
	// directly into the final buffer.
	if (xport->compressing)
		decr_buf = xport->clr_buf;
	else
		decr_buf = buf;

	// Start at the beginning of the buffer.
	// Calculation of the mac depends on this.
	buf_reset(decr_buf);

	// Use the same minimum block size as v2ssh_xmit_packet.  The first
	// block has to cover the five header bytes, and a zero block size
	// would make the modulo check below divide by zero.
	block_size = context->cipher->block_size;
	if (block_size < 8)
		block_size = 8;

	if (xport->mac_recv)
		maclen = xport->mac_recv->mac_length;
	else
		maclen = 0;

	/*
	 * We don't put more than the payload into the clear buf
	 * so it should be all used up by the time we get back here.
	 */
	/*assert(buf_len(buf) == 0);*/

	/*
	 * First read in enough cipher text to get the
	 * packet length and padding length.
	 */
	amt_to_read = block_size - buf_len(enc_buf);
	amt_to_decrypt = block_size;
	while (amt_to_read > 0)
	{
		read_len = buf_fillbuf(read, enc_buf, context->transport_ctx.commsock, amt_to_read);
		
		if (read_len < 0)
		{
			if (errno != EINTR)
				return(-1);
			continue;
		}
		if (read_len == 0)
		{
			SSH_DLOG(4, ("EOF\n"));
			errno = EPIPE;
			return(-1);
		}
		amt_to_read -= read_len;
	}

	/*
	 * Ok, we've got enough for the first block so
	 * we can go ahead and decrypt it to get the length fields.
	 */

	// Make sure there's enough space
	if (buf_avail(decr_buf) < amt_to_decrypt)
	{
		if (buf_makeavail(decr_buf, amt_to_decrypt) != 0)
		{
			SSH_ERROR("Unable to grow buffer for decrypted data");
			return(-1);
		}
	}

	// Decrypt, if necessary.
	if (context->cipher->decrypt)
		cipher_decrypt(context->cipher, buf_data(enc_buf),
					   buf_endofdata(decr_buf), amt_to_decrypt);
	else
		memcpy(buf_endofdata(decr_buf), buf_data(enc_buf), amt_to_decrypt);

	// Commit the first block.  Without this the length fields below are
	// read from an empty buffer and the second decrypt would overwrite
	// this block instead of following it.
	buf_adjlen(decr_buf, amt_to_decrypt);

	/*
	 * Grab the packet length and padding length.
	 * The padding length is a single unsigned byte on the wire, so it is
	 * read into a byte-sized object and widened afterwards; storing it
	 * directly into an int would set only one of the int's bytes.
	 * Both targets start at zero, so a failed read is rejected by the
	 * length checks below.
	 */
	buf_get_int32(decr_buf, &packetlen);
	buf_get_int8(decr_buf, &pad_byte);
	pad_len = (unsigned char)pad_byte;

	// Same condition as "payload_len < 0", written so that a hostile
	// packet length cannot overflow the subtraction.
	if (packetlen < pad_len + 1)
	{
		SSH_ERROR("Pad length longer than packet length!");
		return(RET_FATAL);
	}
	payload_len = packetlen - pad_len - 1;

	if (packetlen > xport->max_packet_size ||
	    payload_len > xport->max_payload_size)
	{
		SSH_ERROR("Packet length too long: %d, %d\n", packetlen, payload_len);
		return(RET_FATAL);
	}

	if ((packetlen + 4) % block_size != 0)
	{
		SSH_ERROR("Packet length+4 not multiple of cipher block size.\n"
		          "Packet length: %d\tCipher block size: %d\n",
		          packetlen, block_size);
		return(RET_FATAL);
	}

	/*
	 * Figure out how much we still need to read, decrypt
	 * Read enough to get the mac too.
	 *
	 * The encrypted portion is the 4 byte length field plus packetlen
	 * bytes, so both amounts include those 4 bytes.  The checks above
	 * guarantee 4 + packetlen >= block_size, so amt_to_decrypt is never
	 * negative.
	 */
	amt_to_read = 4 + packetlen + maclen - buf_len(enc_buf);
	amt_to_decrypt = 4 + packetlen - block_size;

	/* Skip the blocks we just decrypted */
	buf_get_skip(enc_buf, block_size);

	while(amt_to_read > 0)
	{
		read_len = buf_fillbuf(read, enc_buf, context->transport_ctx.commsock, amt_to_read);
		
		if (read_len < 0)
		{
			if (errno != EINTR)
				return(RET_FATAL);
			continue;
		}
		if (read_len == 0) {
			SSH_DLOG(4, ("EOF\n"));
			errno = EPIPE;
			return(RET_FATAL);
		}

		amt_to_read -= read_len;
	}

	/*
	 * Decrypt payload and padding
	 */

	/* Make sure there's enough space */
	if (buf_avail(decr_buf) < amt_to_decrypt)
	{
		if (buf_makeavail(decr_buf, amt_to_decrypt) != 0)
		{
			SSH_ERROR("Unable to grow buffer for decrypted data");
			return(RET_FATAL);
		}
	}

	// A packet that fits in a single block has nothing left to decrypt.
	if (amt_to_decrypt > 0)
	{
		if (context->cipher->decrypt)
			cipher_decrypt(context->cipher, buf_data(enc_buf),
						   buf_endofdata(decr_buf), amt_to_decrypt);
		else
			memcpy(buf_endofdata(decr_buf), buf_data(enc_buf), amt_to_decrypt);
	}

	/*
	 * Mark the buffer as containing the payload, but not the padding.
	 * Together with the first block this leaves
	 * 4 + 1 + payload_len bytes of data in decr_buf.
	 */
	buf_adjlen(decr_buf, amt_to_decrypt - pad_len);

	buf_get_skip(enc_buf, amt_to_decrypt);

	// Back to the start of the packet: the mac covers the length fields,
	// the payload and the padding.
	buf_rewind(decr_buf);

	if (xport->mac_recv)
	{
		if (memcmp(buf_data(enc_buf),
		       xport->mac_recv->mac_generate(xport->mac_recv,
		                                xport->sequence_recv,
		                                buf_data(decr_buf),
		                                buf_len(decr_buf) + pad_len),
		       maclen) != 0)
		{
			//XAX send SSH_DISCONNECT_MAC_ERROR
			SSH_ERROR("MAC check failed!");
			return(RET_FATAL);
		}
		buf_get_skip(enc_buf, maclen);
	}

	// One more packet received.  The sequence number feeds the mac of
	// the next packet, mirroring sequence_send in v2ssh_xmit_packet.
	xport->sequence_recv++;

	// Re-skip the packet length and padlen.
	buf_get_skip(decr_buf, 4 + 1);

	if (xport->compressing)
	{
		/*
		 * NOTE(review): this path is unchanged and looks unfinished:
		 *  - nothing here calls buf_adjlen(buf, ...), so buf_len(buf)
		 *    returned below may not include the inflated bytes;
		 *  - after buf_makeavail() moves the data, next_out is not
		 *    recomputed (see the XAX note below);
		 *  - the loop stops at the first Z_BUF_ERROR, even right after
		 *    more room was made.
		 */
		int ret;
		// decompress from decr_buf to buf
		int decompress_space = buf_len(decr_buf) * 2;

		if (buf_avail(buf) < decompress_space)
		{
			if (buf_makeavail(buf, decompress_space) != 0)
			{
				SSH_ERROR("Unable to grow buffer for decrypted data");
				return(RET_FATAL);
			}
		}

		xport->inz.next_in = buf_data(decr_buf);
		xport->inz.avail_in = buf_len(decr_buf);
		xport->inz.next_out = buf_endofdata(buf);
		xport->inz.avail_out = buf_avail(buf);
		do {
			ret = inflate(&(xport->inz), Z_PARTIAL_FLUSH);
			if (ret == Z_OK)
				continue;
			if (ret != Z_BUF_ERROR)
			{
				SSH_DLOG(1, ("inflate failed: %d\n", ret));
				return (-1);
			}
			if (xport->inz.avail_out == 0 ||
			    xport->inz.avail_out < xport->inz.avail_in)
			{
				void *orig_data = buf_alldata(buf);
				if (buf_makeavail(buf, xport->inz.avail_in * 2) != 0)
				{
					SSH_ERROR("Unable to grow buffer for decrypted data");
					return(RET_FATAL);
				}
				xport->inz.avail_out = xport->inz.avail_in * 2;
				if (orig_data != buf_alldata(buf))
				{
					// If the data pointer got reallocated, recalculate
					// where we're writing the compression output to.
// XAX XXX ???
//					xport->inz.next_out = buf_alldata(buf) +
//					       (xport->inz.next_out - orig_data);
				}
			}
		} while (ret != Z_BUF_ERROR);
		if (xport->inz.avail_in != 0)
		{
			SSH_DLOG(1, ("incomplete inflation: %d\n", xport->inz.avail_in));
			return(RET_FATAL);
		}
	}
	// else it's already in the final buf

	return(buf_len(buf));
}

/*
 * v2ssh_xmit_data
 *    Unimplemented stub.  None of the arguments are examined.
 *
 * Returns:
 *     RET_FATAL always, after logging "Unimplemented".
 */
static int v2ssh_xmit_data(ssh_context_t *context,
                           u_int8_t ptype, const u_int8_t *rawbuf, size_t len)
{
	FUNC_DECL(v2ssh_xmit_data);
	int status = RET_FATAL;

	SSH_ERROR("Unimplemented");
	return status;
}
/*
 * v2ssh_xmit_packet
 *    Assemble an SSH v2 binary packet around the payload in buf:
 *
 *       uint32  packet length (padding length byte + payload + padding)
 *       byte    padding length
 *       byte[]  payload (deflated when compression is on)
 *       byte[]  random padding
 *       byte[]  mac over the cleartext packet, when a send mac is set
 *
 *    The cleartext packet is built in xport->clr_buf.  When the cipher has
 *    an encrypt routine the ciphertext is written to xport->enc_buf and the
 *    mac is appended there, otherwise the mac is appended to clr_buf.
 *    The send sequence number is advanced on success.
 *
 *    NOTE(review): nothing in this function passes the assembled buffer
 *    (actual_send_buf) to the socket or to any other routine, so the
 *    packet is built but not transmitted here.
 *
 * Returns:
 *     RET_OK     once the packet has been assembled
 *     RET_FATAL  if a buffer cannot be allocated or deflate fails
 */
static int v2ssh_xmit_packet(ssh_context_t *context, const struct ssh_buf *buf)
{
	FUNC_DECL(v2ssh_xmit_packet);
	struct ssh_v2_transport *xport = &context->transport_ctx.t_v2;
	u_int32_t padlen, packetlen, maclen;
	int block_size;
	struct ssh_buf *actual_send_buf;


	block_size = context->cipher->block_size;
	if (block_size < 8)
		block_size = 8;

	// Pad so that length field + padding length byte + payload + padding
	// is a multiple of the block size.
	padlen = block_size - ((4 + 1 + buf_len(buf)) % block_size);
	// The protocol requires at least four bytes of padding; the line
	// above can yield as little as one.
	if (padlen < 4)
		padlen += block_size;
	packetlen = 1 + buf_len(buf) + padlen;
	if (packetlen < 16)
	{
		packetlen += block_size;
		padlen += block_size;
	}
	if (xport->mac_send)
		maclen = xport->mac_send->mac_length;
	else
		maclen = 0;

	{
		struct ssh_buf *newb;
		int encryptlen, totallen, alloclen;

		encryptlen = 4 + packetlen;
		totallen = encryptlen + maclen;
		alloclen = totallen;
		// Leave headroom for deflate output that is larger than its input.
		if (xport->compressing)
			alloclen = (totallen + 1) * 1.001 + 12;
		if ((newb = buf_alloc(xport->clr_buf, alloclen)) == NULL)
		{
			SSH_ERROR("Unable to buf_alloc clr_buf: %s\n", strerror(errno));
			return(RET_FATAL);
		}
		xport->clr_buf = newb;
	}

	buf_put_int32(xport->clr_buf, packetlen);
	buf_put_byte(xport->clr_buf, padlen);

	// Payload:
	if (xport->compressing)
	{
		// NOTE(review): packetlen and padlen above were computed from the
		// uncompressed payload length, but the bytes stored here are the
		// deflated ones, so the header may not match the actual packet.
		int ret;
		// Only zlib is supported.
		xport->outz.next_out = buf_endofdata(xport->clr_buf);
		xport->outz.avail_out = buf_avail(xport->clr_buf);
		xport->outz.next_in = buf_alldata(buf);
		xport->outz.avail_in = buf_alllen(buf);
		if ((ret = deflate(&(xport->outz), Z_PARTIAL_FLUSH)) != Z_OK)
		{
			SSH_ERROR("deflate failed: %d\n", ret);
			return(RET_FATAL);
		}
		// Space that was available minus space still left = bytes produced.
		buf_adjlen(xport->clr_buf,
		           buf_avail(xport->clr_buf) - xport->outz.avail_out);
	}
	else
	{
		buf_append(xport->clr_buf, buf);
	}
	ssh_rand_bytes(padlen, buf_endofdata(xport->clr_buf));
	buf_adjlen(xport->clr_buf, padlen);

	if (context->cipher->encrypt)
	{
		struct ssh_buf *newb;

		// Also does a buf_reset
		if ((newb = buf_alloc(xport->enc_buf,
		                      buf_alllen(xport->clr_buf) + maclen))
		    == NULL)
		{
			SSH_ERROR("Unable to allocate buffer for encrypted data: %s\n",
			          strerror(errno));
			return(RET_FATAL);
		}
		// Keep the returned buffer, as is done for clr_buf above.
		xport->enc_buf = newb;

		cipher_encrypt(context->cipher,
		               buf_alldata(xport->clr_buf),
		               buf_data(xport->enc_buf),
		               buf_alllen(xport->clr_buf));
		// Account for the ciphertext just written, so that the mac below
		// is appended after it rather than on top of it.
		buf_adjlen(xport->enc_buf, buf_alllen(xport->clr_buf));
		actual_send_buf = xport->enc_buf;
	}
	else
	{
		actual_send_buf = xport->clr_buf;
	}

	if (xport->mac_send)
	{
		// The mac is computed over the cleartext packet and the sequence
		// number, then appended to whatever is going out.
		memcpy(buf_endofdata(actual_send_buf),
		       xport->mac_send->mac_generate(xport->mac_send,
		                                    xport->sequence_send,
		                                    buf_alldata(xport->clr_buf),
		                                    buf_alllen(xport->clr_buf)),
		       xport->mac_send->mac_length);
		buf_adjlen(actual_send_buf, xport->mac_send->mac_length);
	}

	xport->sequence_send++;

	return(RET_OK);
}

/*
 * v2ssh_dispatch_transport
 *    Dispatch a transport-layer message by type.
 *
 *    XAX write me: KEXINIT, NEWKEYS, KEXDH_INIT and KEXDH_REPLY are
 *    recognized but no processing is done for them yet; context, msg and
 *    size are not examined.
 *
 * Returns:
 *     RET_OK     for one of the four recognized message types
 *     RET_FATAL  for any other message type
 */
static int v2ssh_dispatch_transport(ssh_context_t *context, struct ssh_buf *msg,
                                size_t size, int msg_type)
{
	FUNC_DECL(v2ssh_dispatch_transport);

	if (msg_type != SSH_V2_MSG_KEXINIT &&
	    msg_type != SSH_V2_MSG_NEWKEYS &&
	    msg_type != SSH_V2_MSG_KEXDH_INIT &&
	    msg_type != SSH_V2_MSG_KEXDH_REPLY)
	{
		return(RET_FATAL);
	}

	return(RET_OK);
}

static int v2ssh_get_sendsize(struct ssh_transport *context)
{
	return(context->t_v2.max_payload_size);
}
