/*
 *	SILCConnection.java		2002/11/09
 *	
 *	Copyright (c) 2002 Alistair K Phipps (jsilc@alistairphipps.com).
 *	All rights reserved.
 */

package com.alistairphipps.jsilc.silcprotocol;

import java.net.UnknownHostException;
import com.alistairphipps.jsilc.core.TCPConnection;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.logging.Logger;
import java.util.logging.Level;
import com.alistairphipps.util.Hex;
import java.util.Arrays;
import org.bouncycastle.crypto.InvalidCipherTextException;

import org.bouncycastle.crypto.modes.CBCBlockCipher;
import org.bouncycastle.crypto.*;
import org.bouncycastle.crypto.params.*;
import org.bouncycastle.crypto.macs.*;
import org.bouncycastle.crypto.engines.*;
import org.bouncycastle.crypto.digests.*;

/** Describes a SILC Connection, including handling session key decryption.
 * 
 * <p>Constructing an instance opens a TCP connection and runs the whole
 * connection set-up sequence:
 * <ul>
 * <li>key exchange (KES, KE_1, KE_2, SUCCESS);</li>
 * <li>connection authentication;</li>
 * <li>client registration (NEW_CLIENT, answered by NEW_ID).</li>
 * </ul>
 * After the key exchange completes, {@link #putPacket} and {@link #getPacket}
 * encrypt and decrypt packets with a {@code SILCCryptoEngine} built from the
 * server's KE_2 payload.
 *
 * @author Alistair K Phipps
 * @version 20021114
 */
public class SILCConnection
{
	/** logger is used for logging */
	private static final Logger logger = Logger.getLogger( "com.alistairphipps.jsilc.silcprotocol.silcconnection" );

	/** Number of plaintext bytes read from the start of a packet before its total length is known. */
	private static final int PLAIN_HEADER_READ = 8;

	/** Number of encrypted bytes read (and decrypted on their own) from the start of a packet before its total length is known. */
	private static final int ENCRYPTED_HEADER_READ = 16;

	/** Length in bytes of the MAC read after every encrypted packet.
	 * FIXME: should dynamically get correct length of HMAC rather than hard-coding it. */
	private static final int HMAC_LENGTH = 20;

	/** True once key exchange has completed and packets are sent/received encrypted. */
	private boolean _bEncryptedPackets;

	/** Underlying transport to the server. */
	private final TCPConnection _conn;

	/** Session crypto engine; null until key exchange has completed. */
	private SILCCryptoEngine _ce;

	/** ID type of this end; replaced by the server-assigned type during client setup. */
	private byte _ySourceIdType = SILCIdType.CLIENT;
	/** ID of this end; all zeros until the server assigns one during client setup. */
	private byte[] _ylSourceId = new byte[16];

	/** ID type of the remote end; taken from the source ID of the server's KES response. */
	private byte _yDestIdType = SILCIdType.NONE;
	/** ID of the remote end; taken from the source ID of the server's KES response. */
	private byte[] _ylDestId = new byte[16];

	/** "host:port" description of the server, used for logging. */
	private final String _strServerName;

	/** Returns the ID type of this (local) end of the connection.
	 * @return local ID type
	 */
	public byte getLocalIdType()
	{
		return _ySourceIdType;
	}

	/** Returns the ID of this (local) end of the connection.
	 * Note: this is the internal array, not a copy; callers must not modify it.
	 * @return local ID bytes
	 */
	public byte[] getLocalId()
	{
		return _ylSourceId;
	}

	/** Returns the ID type of the remote (server) end of the connection.
	 * @return remote ID type
	 */
	public byte getRemoteIdType()
	{
		return _yDestIdType;
	}

	/** Returns the ID of the remote (server) end of the connection.
	 * Note: this is the array obtained from the server's packet, not a copy; callers must not modify it.
	 * @return remote ID bytes
	 */
	public byte[] getRemoteId()
	{
		return _ylDestId;
	}

	/** Returns a "host:port" description of the server this connection was opened to.
	 * @return server name
	 */
	public String getServerName()
	{
		return _strServerName;
	}

	/** Opens a connection to a SILC server and performs key exchange, connection
	 * authentication and client registration.
	 *
	 * <p>NOTE(review): the underlying TCPConnection is not closed if any set-up
	 * step fails; TCPConnection's API is not visible from here.
	 *
	 * @param strUsername user name sent in the NEW_CLIENT payload
	 * @param strHost server host name or address
	 * @param iPort server port
	 * @throws UnknownHostException presumably if the host cannot be resolved (propagated from TCPConnection)
	 * @throws IOException on I/O failure or any protocol error during set-up; cipher failures
	 *         are wrapped in an IOException with the original exception as its cause
	 */
	public SILCConnection( String strUsername, String strHost, int iPort ) throws UnknownHostException, IOException
	{
		logger.entering( "SILCConnection", "SILCConnection" );
		_bEncryptedPackets = false;
		_strServerName = strHost + ":" + Integer.toString( iPort );
		// open connection (use the field directly: getServerName() is overridable)
		logger.info( "Opening connection to " + _strServerName );
		_conn = new TCPConnection( strHost, iPort );
		// do key exchange
		logger.info( "Performing key exchange" );
		try
		{
			SILCPacket pktKE2 = doKeyExchange();
			// we are now encrypting
			_ce = new SILCCryptoEngine( (SILCKeyExchange2Payload)pktKE2.getPayload() );
			_bEncryptedPackets = true;
			logger.info( "Performing connection authentication" );
			doConnAuth();
			logger.info( "Setting up client" );
			doSetup( strUsername );
		}
		catch( InvalidCipherTextException e ) // don't want to propagate this up, but keep it as the cause
		{
			String strMessage = "Cipher failure while setting up connection to " + _strServerName;
			logger.log( Level.SEVERE, strMessage, e );
			IOException ioe = new IOException( strMessage );
			ioe.initCause( e );
			throw ioe;
		}
		logger.exiting( "SILCConnection", "SILCConnection" );
	}

	/** Performs the SILC key exchange: KES, KE_1, KE_2, then a SUCCESS in each direction.
	 * Packets are still unencrypted throughout. Also records the server's ID as the
	 * destination for subsequent packets.
	 *
	 * @return the KE_2 packet received from the server, from which the session keys are derived
	 * @throws IOException on I/O failure or any protocol error (FAILURE reply, unexpected
	 *         packet type, cookie mismatch, invalid KE_2 signature)
	 * @throws InvalidCipherTextException propagated from packet I/O
	 */
	private SILCPacket doKeyExchange() throws IOException, InvalidCipherTextException
	{
		logger.entering( "SILCConnection", "doKeyExchange" );

		// start by negotiating protocols for the connection.  Currently protocols are fixed and defined in SILCKeyExchangeStartPayload.
		logger.finest( "Creating KES payload" );
		SILCPayload pl = new SILCKeyExchangeStartPayload();
		logger.finest( "Creating KES Packet" );
		SILCPacket p = new SILCPacket( SILCPacketType.KEY_EXCHANGE, _ySourceIdType, _ylSourceId, _yDestIdType, _ylDestId, pl );
		logger.finer( "Sending KES Packet: \n" + p.toString() );
		putPacket( p );
		logger.finest( "Waiting for KES Response Packet" );
		SILCPacket pin = getPacket();
		switch( pin.getType() )
		{
			case SILCPacketType.KEY_EXCHANGE: // this is what we should get
			{
				logger.finer( "Got KES Response Packet: \n" + pin.toString() );
				// check the responder echoed our cookie back
				byte[] ylSentCookie = ( (SILCKeyExchangeStartPayload)pl ).getCookie();
				byte[] ylGotCookie = ( (SILCKeyExchangeStartPayload)pin.getPayload() ).getCookie();
				if( !Arrays.equals( ylGotCookie, ylSentCookie ) )
				{
					logger.severe( "Cookie mismatch in KES Response packet" );
					logger.severe( "Expected: " + Hex.toString( ylSentCookie ) + " Got: " + Hex.toString( ylGotCookie ) );
					throw new IOException( "Cookie mismatch in KES Response packet" );
				}
				// XXX should check values returned in payload to ensure they're ok... but we only specified one type in our sent payload, so we must be getting the same ones back, else we should have received a FAILURE instead
				break;
			}
			case SILCPacketType.FAILURE: // problem with KE
				logger.severe( "Got Failure after KE Start" );
				throw new IOException( "Got Failure after KE Start" ); // XXX should probably use our own exception
			default:
				logger.severe( "Got unexpected packet type: " + SILCPacketType.toString( pin.getType() ) + " after KE Start" );
				throw new IOException( "Got unexpected packet type: " + SILCPacketType.toString( pin.getType() ) + " after KE Start" );
		}

		// if got here, have successfully negotiated protocols for the connection.  Continue with actual key exchange.
		// set destination to source of KES packet
		_yDestIdType = pin.getSourceIdType();
		_ylDestId = pin.getSourceId();
		// create KE_1 payload to send - calculates hash from *our* (initiator) KES payload, but uses mutual auth value from *server's* (responders) KES payload
		logger.finest( "Creating KE_1 payload" );
		SILCPayload pl2 = new SILCKeyExchange1Payload( p.getPayload().toByteList(), ( (SILCKeyExchangeStartPayload)pin.getPayload() ).getMutualAuth() );
		logger.finest( "Creating KE_1 packet" );
		SILCPacket p2 = new SILCPacket( SILCPacketType.KEY_EXCHANGE_1, _ySourceIdType, _ylSourceId, _yDestIdType, _ylDestId, pl2 );
		logger.finer( "Sending KE_1 packet: \n" + p2.toString() );
		putPacket( p2 );

		logger.finest( "Waiting for KE_2 packet" );
		SILCPacket pin2 = getPacket();
		switch( pin2.getType() )
		{
			case SILCPacketType.KEY_EXCHANGE_2: // this is what we should get; signature is verified below
				logger.finer( "Got KE_2 Packet: \n" + pin2.toString() );
				break;
			case SILCPacketType.FAILURE: // problem with KE_1
				logger.severe( "Got Failure after KE 1: " + pin2.toString() );
				throw new IOException( "Got Failure after KE 1" ); // XXX should probably use our own exception
			default:
				// report the type of the packet we actually received (pin2), not the earlier KES response
				logger.severe( "Got unexpected packet type: " + SILCPacketType.toString( pin2.getType() ) + " after KE 1" );
				throw new IOException( "Got unexpected packet type: " + SILCPacketType.toString( pin2.getType() ) + " after KE 1" );
		}

		// verify the server's KE_2 against our KES payload and the key material we sent in KE_1
		SILCKeyExchange1Payload plX = (SILCKeyExchange1Payload)pl2;
		if( !((SILCKeyExchange2Payload)pin2.getPayload()).isValid( p.getPayload().toByteList(), plX.getDHPrivateKey(), plX.getPublicKey(), plX.getDHPublicData() ) )
		{
			logger.severe( "KE_2 packet had invalid signature" );
			throw new IOException( "KE_2 packet had invalid signature" );
		}

		logger.finest( "Creating success payload" );
		SILCPayload pl3 = new SILCSuccessPayload();
		logger.finest( "Creating success packet" );
		SILCPacket p3 = new SILCPacket( SILCPacketType.SUCCESS, _ySourceIdType, _ylSourceId, _yDestIdType, _ylDestId, pl3 );
		logger.finer( "Sending Success packet: \n" + p3.toString() );
		putPacket( p3 );

		logger.finest( "Waiting for success packet" );
		SILCPacket pin3 = getPacket();
		switch( pin3.getType() )
		{
			case SILCPacketType.SUCCESS:
				logger.finer( "Got Success packet: \n" + pin3.toString() );
				break;
			case SILCPacketType.FAILURE:
				logger.severe( "Got Failure after Success: " + pin3.toString() );
				throw new IOException( "Got Failure after Success" );
			default:
				logger.severe( "Got unexpected packet type after Success: " + pin3.toString() );
				throw new IOException( "Got unexpected packet type after Success: " + SILCPacketType.toString( pin3.getType() ) );
		}
		logger.info( "Key exchange done" );
		logger.exiting( "SILCConnection", "doKeyExchange" );
		return pin2;	// return KE2 packet
	}

	/** Perform connection authentication: sends a CONNECTION_AUTH packet (with a
	 * default-constructed payload) and expects SUCCESS in reply.
	 *
	 * @throws IOException on I/O failure, a FAILURE reply or an unexpected packet type
	 * @throws InvalidCipherTextException propagated from packet I/O
	 */
	private void doConnAuth() throws IOException, InvalidCipherTextException
	{
		logger.entering( "SILCConnection", "doConnAuth" );

		// send CONNECTION_AUTH (NONE)
		logger.finest( "Creating Connection Auth payload" );
		SILCPayload pl = new SILCConnectionAuthPayload();
		logger.finest( "Creating Connection Auth packet" );
		SILCPacket p = new SILCPacket( SILCPacketType.CONNECTION_AUTH, _ySourceIdType, _ylSourceId, _yDestIdType, _ylDestId, pl );
		logger.finest( "Sending Connection Auth packet" );
		putPacket( p );
		logger.finest( "Waiting for success packet" );
		SILCPacket pin = getPacket();
		switch( pin.getType() )
		{
			case SILCPacketType.SUCCESS:
				logger.finer( "Got Success packet: \n" + pin.toString() );
				break;
			case SILCPacketType.FAILURE:
				logger.severe( "Got Failure after Connection Auth: " + pin.toString() );
				throw new IOException( "Got Failure after Connection Auth" );
			default:
				logger.severe( "Got unexpected packet type after Connection Auth: " + pin.toString() );
				throw new IOException( "Got unexpected packet type after Connection Auth: " + SILCPacketType.toString( pin.getType() ) );
		}
		logger.exiting( "SILCConnection", "doConnAuth" );
	}

	/** Perform client setup: sends NEW_CLIENT and adopts the ID returned in the
	 * server's NEW_ID reply as this end's source ID.
	 *
	 * @param strUsername user name sent in the NEW_CLIENT payload
	 * @throws IOException on I/O failure, a FAILURE reply or any reply other than NEW_ID
	 * @throws InvalidCipherTextException propagated from packet I/O
	 */
	private void doSetup( String strUsername ) throws IOException, InvalidCipherTextException
	{
		logger.entering( "SILCConnection", "doSetup" );
		logger.finest( "Creating new client payload" );
		SILCPayload pl = new SILCNewClientPayload( strUsername, "jSilc user" );
		logger.finest( "Creating new client packet" );
		SILCPacket p = new SILCPacket( SILCPacketType.NEW_CLIENT, _ySourceIdType, _ylSourceId, _yDestIdType, _ylDestId, pl );
		logger.finer( "Sending new client packet: \n" + p.toString() );
		putPacket( p );

		SILCPacket pin = getPacket();
		switch( pin.getType() )
		{
			case SILCPacketType.NEW_ID:
				logger.finer( "Got new ID: \n" + pin.toString() );
				break;
			case SILCPacketType.FAILURE:
				logger.severe( "Got Failure after new client: " + pin.toString() );
				throw new IOException( "Got Failure after new client" );
			default:
				logger.severe( "Got unexpected packet type after new client: " + pin.toString() );
				throw new IOException( "Got unexpected packet type after new client: " + SILCPacketType.toString( pin.getType() ) );
		}
		// from now on we identify ourselves with the server-assigned ID
		SILCIdPayload plId = (SILCIdPayload)pin.getPayload();
		_ySourceIdType = plId.getIdType();
		_ylSourceId = plId.getId();

		logger.exiting( "SILCConnection", "doSetup" );
	}

	/** Reads the next packet from the server, decrypting and verifying it if
	 * key exchange has completed.
	 *
	 * <p>The total packet length is taken from the header as payload length
	 * (bytes 0-1, unsigned) plus padding length (byte 4, unsigned). It is
	 * learned from the first 8 plaintext bytes, or from the first 16 bytes
	 * decrypted on their own once encryption is active.
	 *
	 * @return the received packet
	 * @throws IOException on I/O failure, or if the header declares a length shorter than the bytes already read
	 * @throws InvalidCipherTextException propagated from the crypto engine
	 */
	public SILCPacket getPacket() throws IOException, InvalidCipherTextException
	{
		logger.entering( "SILCConnection", "getPacket" );
		SILCPacket p;
		if( !_bEncryptedPackets )
		{
			logger.finer( "Reading first 8 bytes of header (and possibly padding)" );
			byte[] ylFirst = _conn.readData( PLAIN_HEADER_READ ); // first 8 bytes of header
			logger.finest( "Read." );
			int iTotalLength = readPacketLength( ylFirst, PLAIN_HEADER_READ );
			ByteBuffer yb = ByteBuffer.allocate( iTotalLength );
			yb.put( ylFirst );
			logger.finer( "Reading rest of packet" );
			yb.put( _conn.readData( iTotalLength - PLAIN_HEADER_READ ) );
			p = new SILCPacket( yb.array() );
		}
		else
		{
			logger.finer( "Reading first 16 encrypted bytes of header (and possibly padding)" );
			byte[] ylFirstE = _conn.readData( ENCRYPTED_HEADER_READ ); // first 16 bytes of header
			logger.finest( "Read." );
			byte[] ylFirst = _ce.decryptByteList( ylFirstE );
			logger.finest( "Decrypted.");
			int iTotalLength = readPacketLength( ylFirst, ENCRYPTED_HEADER_READ );
			logger.finer( "Reading rest of packet" );
			byte[] ylSecondE = _conn.readData( iTotalLength - ENCRYPTED_HEADER_READ ); // read rest of packet
			byte[] ylHMAC = _conn.readData( HMAC_LENGTH );
			logger.finest( "Read." );
			p = _ce.decryptByteListToPacket( ylFirst, ylSecondE, ylHMAC );
			logger.finest( "Decrypted." );
		}
		logger.exiting( "SILCConnection", "getPacket" );
		return p;	
	}

	/** Computes the total packet length (payload length + padding length) from
	 * the start of a plaintext packet header, and checks it covers the bytes already read.
	 *
	 * @param ylHeader plaintext header bytes; payload length at offset 0 (2 bytes), padding length at offset 4
	 * @param iAlreadyRead number of bytes of this packet already consumed from the connection
	 * @return total packet length in bytes (excluding any trailing MAC)
	 * @throws IOException if the declared length is shorter than {@code iAlreadyRead}
	 */
	private static int readPacketLength( byte[] ylHeader, int iAlreadyRead ) throws IOException
	{
		ByteBuffer ybHeader = ByteBuffer.wrap( ylHeader );
		int iPayloadLength = ybHeader.getShort( 0 ) & 0xFFFF; // read payload length (unsigned)
		int iPadLength = ybHeader.get( 4 ) & 0xFF; // read padding length (unsigned)
		logger.finest( "Payload length: " + Integer.toString( iPayloadLength ) + " Padding length: " + Integer.toString( iPadLength ) );
		int iTotalLength = iPayloadLength + iPadLength;
		// lengths come from the network: reject values that would give a negative remaining read
		if( iTotalLength < iAlreadyRead )
		{
			logger.severe( "Packet length " + iTotalLength + " shorter than header read of " + iAlreadyRead + " bytes" );
			throw new IOException( "Invalid packet length " + iTotalLength + " (less than " + iAlreadyRead + ")" );
		}
		return iTotalLength;
	}

	/** Sends a packet to the server, encrypting it with the session keys if
	 * key exchange has completed.
	 *
	 * @param packet packet to send
	 * @throws IOException on I/O failure
	 * @throws InvalidCipherTextException propagated from the crypto engine
	 */
	public void putPacket( SILCPacket packet ) throws IOException, InvalidCipherTextException
	{
		logger.entering( "SILCConnection", "putPacket" );
		if( !_bEncryptedPackets )
		{
			_conn.writeData( packet.toByteList() );
		}
		else
		{
			_conn.writeData( _ce.encryptPacketToByteList( packet ) );
		}
		logger.exiting( "SILCConnection", "putPacket" );
	}
}
