/*
 *	SILCCryptoEngine.java		2002/11/14
 *	
 *	Copyright (c) 2002 Alistair K Phipps (jsilc@alistairphipps.com).
 *	All rights reserved.
 */

package com.alistairphipps.jsilc.silcprotocol;

import java.nio.ByteBuffer;
import java.util.Arrays;
import org.bouncycastle.crypto.*;
import org.bouncycastle.crypto.params.*;
import org.bouncycastle.crypto.macs.*;
import org.bouncycastle.crypto.modes.*;
import org.bouncycastle.crypto.engines.*;
import org.bouncycastle.crypto.digests.*;

/** SILC packet cryptography engine - encrypts / decrypts packets and
 * computes / verifies their HMACs.
 *
 * Cipher keys, IVs and HMAC keys are derived from a KE2 payload when the
 * engine is constructed. Separate cipher, HMAC and sequence-number state is
 * kept for the sending and receiving directions. Instances hold mutable
 * state and perform no synchronization of their own.
 * 
 * @author Alistair K Phipps
 * @version 20021114
 */
public class SILCCryptoEngine
{
	/** Length in bytes of an AES-128 key (independent of the cipher's block size) */
	private static final int AES_KEY_SIZE = 16;

	HMac _hmacSend;
	HMac _hmacRecv;
	BlockCipher _ciphSend;
	BlockCipher _ciphRecv;
	int _iSeqNumSend;
	int _iSeqNumRecv;
	
	/** Initialise packet crypto engine with parameters from given KE2 payload.
	 * Derives, for each direction, an IV, an encryption key and an HMAC key as
	 * hash( n | KEY | HASH ) for n = 0..5, truncated to the required length.
	 * @param pl KE2 payload received at end of KE negotiation
	 */
	public SILCCryptoEngine( SILCKeyExchange2Payload pl )
	{
		// create new MAC (HMAC-SHA1) - one for sending, one for receiving
		_hmacSend = new HMac( new SHA1Digest() );
		_hmacRecv = new HMac( new SHA1Digest() );

		// make sure HMAC length is correct (160bits)
		assert( _hmacSend.getMacSize() == 20 );
		assert( _hmacRecv.getMacSize() == 20 );

		// create new cipher (128-bit AES) in CBC mode
		// one for sending and one for receiving
		_ciphSend = new CBCBlockCipher( new RijndaelEngine( 128 ) );
		_ciphRecv = new CBCBlockCipher( new RijndaelEngine( 128 ) );
		
		// set up a hasher to calculate various keys and initial vectors
		Digest digest = new SHA1Digest();
		
		// build a byte list with KEY | HASH in
		byte[] ylDHKey = pl.getDHKey();
		byte[] ylHash = pl.getHash();
		ByteBuffer yb = ByteBuffer.allocate( ylDHKey.length + ylHash.length );
		yb.put( ylDHKey );
		yb.put( ylHash );
		byte[] yl = yb.array();
		
		// IVs are one cipher block long: hash( 0 | KEY | HASH ) / hash( 1 | KEY | HASH )
		byte[] ylSendIV = calculateTruncatedHash( digest, (byte)0, yl, _ciphSend.getBlockSize() );
		byte[] ylRecvIV = calculateTruncatedHash( digest, (byte)1, yl, _ciphRecv.getBlockSize() );
	
		// encryption keys are AES-128 key length: hash( 2 | KEY | HASH ) / hash( 3 | KEY | HASH )
		// (the key size is used here rather than the block size; for AES-128 both are 16 bytes)
		byte[] ylSendKey = calculateTruncatedHash( digest, (byte)2, yl, AES_KEY_SIZE );
		byte[] ylRecvKey = calculateTruncatedHash( digest, (byte)3, yl, AES_KEY_SIZE );

		// HMAC keys are MAC-size long: hash( 4 | KEY | HASH ) / hash( 5 | KEY | HASH )
		byte[] ylSendHMACKey = calculateTruncatedHash( digest, (byte)4, yl, _hmacSend.getMacSize() );
		byte[] ylRecvHMACKey = calculateTruncatedHash( digest, (byte)5, yl, _hmacRecv.getMacSize() );

		// initialise ciphers (send = encrypt, receive = decrypt)
		_ciphSend.init( true, new ParametersWithIV( new KeyParameter( ylSendKey ), ylSendIV ) );
		_ciphRecv.init( false, new ParametersWithIV( new KeyParameter( ylRecvKey ), ylRecvIV ) );

		// initialise hmacs
		_hmacSend.init( new KeyParameter( ylSendHMACKey ) );
		_hmacRecv.init( new KeyParameter( ylRecvHMACKey ) );
	}

	/** Helper function to calculate a hash of ( y | yl ) using the given digest and truncate it to the specified number of bytes
	 * @param d Digest function to create hash; it is reset before use
	 * @param y Byte prefix
	 * @param yl Byte list suffix
	 * @param iNumBytes Number of leading digest bytes to return; must not exceed the digest size
	 * @return First iNumBytes bytes of Digest( y | yl )
	 * @throws IllegalArgumentException if iNumBytes is larger than the digest output
	 */
	private static byte[] calculateTruncatedHash( Digest d, byte y, byte[] yl, int iNumBytes )
	{
		int iDigestSize = d.getDigestSize();
		// checked explicitly rather than by assert so it also holds when assertions are disabled
		if( iNumBytes > iDigestSize )
		{
			throw new IllegalArgumentException( "Requested " + iNumBytes + " bytes but digest only produces " + iDigestSize );
		}

		d.reset();
		d.update( y );
		d.update( yl, 0, yl.length );
		byte[] ylD = new byte[ iDigestSize ];
		d.doFinal( ylD, 0 );

		byte[] ylRet = new byte[ iNumBytes ];
		System.arraycopy( ylD, 0, ylRet, 0, iNumBytes );
		return ylRet;
	}

	/** Calculate HMAC( seq num | packet data ), with the sequence number as a 4-byte big-endian integer
	 * @param hmac Initialised HMAC to use
	 * @param iSeqNum Sequence number to prefix
	 * @param ylPacket Packet bytes
	 * @return MAC bytes (hmac.getMacSize() long)
	 */
	private static byte[] calculateMac( HMac hmac, int iSeqNum, byte[] ylPacket )
	{
		ByteBuffer yb = ByteBuffer.allocate( 4 + ylPacket.length );
		yb.putInt( iSeqNum );
		yb.put( ylPacket );
		byte[] ylData = yb.array();

		hmac.update( ylData, 0, ylData.length );	// seq num | packet
		byte[] ylMac = new byte[ hmac.getMacSize() ];
		hmac.doFinal( ylMac, 0 );
		return ylMac;
	}

	/** Run every block of the input through the given cipher.
	 * The cipher object is reused across calls and is not re-initialised here.
	 * @param ciph Initialised block cipher (encrypting or decrypting)
	 * @param ylIn Input bytes; length is expected to be a multiple of the block size
	 * @return Processed bytes, same length as the input
	 */
	private static byte[] processBlocks( BlockCipher ciph, byte[] ylIn )
	{
		int iBlockSize = ciph.getBlockSize();
		assert( ylIn.length % iBlockSize == 0 );	// padding should ensure this - we assume it
		byte[] ylOut = new byte[ ylIn.length ];
		for( int i = 0; i < ylIn.length; i += iBlockSize )
		{
			ciph.processBlock( ylIn, i, ylOut, i );
		}
		return ylOut;
	}

	/** Encrypt given packet to byte list with HMAC at the end.
	 * The HMAC is calculated over the unencrypted packet prefixed with the send sequence number,
	 * which is then incremented.
	 * @param p Packet to encrypt and add MAC to
	 * @return byte list with encrypted packet followed by the (unencrypted) MAC
	 * @throws InvalidCipherTextException declared for callers (kept for compatibility)
	 */
	public byte[] encryptPacketToByteList( SILCPacket p ) throws InvalidCipherTextException
	{
		// serialise the packet once so the MAC and the ciphertext cover exactly the same bytes
		byte[] ylIn = p.toByteList();

		// calculate sha1 HMAC from HMAC( seq num | packet data ) and increment sequence number
		byte[] ylHMAC = calculateMac( _hmacSend, _iSeqNumSend++, ylIn );

		// do encryption - TODO: only encrypt header and padding if payload is encrypted with private or channel key
		byte[] ylOut = processBlocks( _ciphSend, ylIn );

		// return encrypted packet | HMAC
		ByteBuffer ybRet = ByteBuffer.allocate( ylOut.length + ylHMAC.length );
		ybRet.put( ylOut );
		ybRet.put( ylHMAC );
		return ybRet.array();
	}

	/** Decrypt a leading portion of a packet (generally the 16-byte header) so the number of
	 * remaining bytes can be determined before reading the full packet.
	 * Uses the shared receive cipher, so the rest of the packet is expected to be decrypted
	 * afterwards with decryptByteListToPacket (presumably relying on the CBC state carrying over - confirm).
	 * @param yl Encrypted bytes; length should be a multiple of the cipher block size
	 * @return Byte list decrypted (same length as the input)
	 * @throws InvalidCipherTextException declared for callers (kept for compatibility)
	 */
	public byte[] decryptByteList( byte[] yl ) throws InvalidCipherTextException
	{
		return processBlocks( _ciphRecv, yl );
	}

	/** Decrypt the remainder of a packet, verify its MAC and return the SILCPacket.
	 * The MAC is checked against HMAC( recv seq num | decrypted packet ); the receive
	 * sequence number is incremented whether or not the MAC matches.
	 * @param yl1 Byte list with already-decrypted first bytes of packet (see decryptByteList)
	 * @param yl Byte list with rest of encrypted packet
	 * @param ylHMAC Byte list with HMAC (not encrypted)
	 * @return Packet that the byte list contained
	 * @throws InvalidCipherTextException declared for callers (kept for compatibility)
	 */
	public SILCPacket decryptByteListToPacket( byte[] yl1, byte[] yl, byte[] ylHMAC ) throws InvalidCipherTextException
	{
		// decrypt the rest of the packet
		byte[] ylRest = processBlocks( _ciphRecv, yl );

		// create unified packet byte list
		ByteBuffer yb = ByteBuffer.allocate( yl1.length + ylRest.length );
		yb.put( yl1 );
		yb.put( ylRest );
		byte[] ylPacket = yb.array();
		
		// calculate HMAC at this end and inc recv seq num
		byte[] ylHMAC2 = calculateMac( _hmacRecv, _iSeqNumRecv++, ylPacket );

		// NOTE(review): a mismatch only produces a warning and the packet is still returned,
		// i.e. unauthenticated packets are accepted; consider rejecting them once MAC interop is confirmed.
		if( !Arrays.equals( ylHMAC, ylHMAC2 ) )
		{
			System.out.println( "WARNING: HMACS DON'T MATCH" );
			System.out.println( "HMAC received: " + com.alistairphipps.util.Hex.toString( ylHMAC ) );
			System.out.println( "HMAC calculated: " + com.alistairphipps.util.Hex.toString( ylHMAC2 ) );
			System.out.println( "Packet data received: " + com.alistairphipps.util.Hex.toString( ylPacket ) );
		}
		
		return new SILCPacket( ylPacket );
	}
}
