/**
*  Class ThreeWay
* <P>
*   The 3-Way cypher, inspired, inter alia, by
*	code from Schneier's <cite>Applied Cryptography</cite>, 2nd edition.
*  <P>
*  Coded Mr. Tines &lt;tines@windsong.demon.co.uk&gt; 1998
*  and released into the public domain
*  <P>
* THIS SOFTWARE IS PROVIDED BY THE AUTHORS ''AS IS'' AND ANY EXPRESS
* OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
* BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
* WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
* OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
* EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*  <P>
* @author Mr. Tines
* @version 1.0 23-Dec-1998
*/

package uk.co.demon.windsong.crypt.cea;
import uk.co.demon.windsong.util.text.Hexprint;

public class ThreeWay implements CEA
{
    /**
     * Default constructor.  The new object holds no key schedule; call
     * {@link #init(byte[], int, boolean)} before {@link #ecb}.
     * <P>
     * The fields already start in their "destroyed" state (no schedule,
     * single mode), so the overridable {@link #destroy()} is deliberately
     * not invoked from here.
     */
    public ThreeWay()
    {
    }

    /**
     * Initial round constant for the encryption pass.
     */
    private static final int STRT_E = 0x0b0b;
    /**
     * Initial round constant for the decryption pass.
     */
    private static final int STRT_D = 0xb1b1;
    /**
     * Number of full rounds, as recommended.
     */
    private static final int ROUNDS = 11;

    /**
     * Expanded key schedule for one 96-bit key.
     */
    private static class TWAY_KEYSCHED
    {
        /** the user key as three 32-bit words */
        int[] givenKey = new int[3];
        /** mu(theta(givenKey)), used by the decryption pass */
        int[] inverseKey = new int[3];
        /** encryption round constants, one per key-mixing step (ROUNDS + 1) */
        int[] encryptRoundConstants = new int[ROUNDS + 1];
        /** decryption round constants, one per key-mixing step (ROUNDS + 1) */
        int[] decryptRoundConstants = new int[ROUNDS + 1];

        /**
         * Overwrite the whole schedule (keys and round constants) with zeros.
         */
        public void wipe()
        {
            for (int i = 0; i < 3; ++i)
            {
                givenKey[i] = 0;
                inverseKey[i] = 0;
            }
            for (int i = 0; i <= ROUNDS; ++i)
            {
                encryptRoundConstants[i] = 0;
                decryptRoundConstants[i] = 0;
            }
        }
    }

    /**
     * Primitive operation: bitwise-reverse a 96-bit value in place, so that
     * bit 0 of a[0] ends up as bit 31 of a[2], and so on.
     * @param a array of 3 ints, modified in place
     */
    private static final void mu(int[] a)
    {
        int[] b = new int[3];
        for (int i = 0; i < 3; i++)
        {
            int src = a[i];
            int rev = 0;
            for (int j = 0; j < 32; j++)
            {
                rev = (rev << 1) | (src & 1);
                src >>>= 1;
            }
            // reversing the whole 96 bits also reverses the word order
            b[2 - i] = rev;
        }
        for (int i = 0; i < 3; i++)
        {
            a[i] = b[i];
            b[i] = 0;
        }
    }

    /**
     * Cyclic index increment by one, modulo 3.
     */
    private static final int ADD1(int i) {return ((i + 1) % 3);}

    /**
     * Cyclic index increment by two, modulo 3.
     */
    private static final int ADD2(int i) {return ((i + 2) % 3);}

    /**
     * Primitive operation: non-linear mixing of 96 bits.
     * @param a array of 3 ints, modified in place
     */
    private static final void gamma(int[] a)
    {
        int[] b = new int[3];
        for (int i = 0; i < 3; i++)
        {
            b[i] = a[i] ^ (a[ADD1(i)] | (~a[ADD2(i)]));
        }
        for (int i = 0; i < 3; i++)
        {
            a[i] = b[i];
            b[i] = 0;
        }
    }

    /**
     * Primitive operation: linear mixing of 96 bits.
     * @param a array of 3 ints, modified in place
     */
    private static final void theta(int[] a)
    {
        int[] b = new int[3];
        for (int i = 0; i < 3; i++)
        {
            b[i] = a[i] ^
                (a[i] >>> 16) ^ (a[ADD1(i)] << 16) ^
                (a[ADD1(i)] >>> 16) ^ (a[ADD2(i)] << 16) ^
                (a[ADD1(i)] >>> 24) ^ (a[ADD2(i)] << 8) ^
                (a[ADD2(i)] >>> 8) ^ (a[i] << 24) ^
                (a[ADD2(i)] >>> 16) ^ (a[i] << 16) ^
                (a[ADD2(i)] >>> 24) ^ (a[i] << 8);
        }
        for (int i = 0; i < 3; i++)
        {
            a[i] = b[i];
            b[i] = 0;
        }
    }

    /**
     * Primitive operation: first permutation of 96 bits
     * (rotates a[0] right by 10 and a[2] left by 1).
     * @param a array of 3 ints, modified in place
     */
    private static final void pi1(int[] a)
    {
        a[0] = (a[0] >>> 10) ^ (a[0] << 22);
        a[2] = (a[2] << 1) ^ (a[2] >>> 31);
    }

    /**
     * Primitive operation: second permutation of 96 bits
     * (rotates a[2] right by 10 and a[0] left by 1).
     * @param a array of 3 ints, modified in place
     */
    private static final void pi2(int[] a)
    {
        a[2] = (a[2] >>> 10) ^ (a[2] << 22);
        a[0] = (a[0] << 1) ^ (a[0] >>> 31);
    }

    /**
     * The round function: theta, pi1, gamma, pi2.
     * @param a array of 3 ints, modified in place
     */
    private static final void tway_round(int[] a)
    {
        theta(a);
        pi1(a);
        gamma(a);
        pi2(a);
    }

    /**
     * Fill a table with the ROUNDS + 1 round constants.  Each constant is
     * derived from the previous one by a shift, with feedback 0x11011
     * whenever bit 16 is set.
     * @param init direction-specific initial constant
     * @param table destination, at least ROUNDS + 1 entries
     */
    private static final void tway_generate_const(int init, int[] table)
    {
        for (int i = 0; i <= ROUNDS; i++)
        {
            table[i] = init;
            init <<= 1;
            if ((init & 0x10000) != 0) init ^= 0x11011;
        }
    }

    /**
     * Shared round loop for both directions.  Each of the ROUNDS + 1 steps
     * mixes in the key and a round constant; all but the last are followed
     * by the full round function, and the last by theta alone.
     * @param key three key words (given key or inverse key)
     * @param rcon round constants for the direction
     * @param a array of 3 ints, modified in place
     */
    private static void applyRounds(int[] key, int[] rcon, int[] a)
    {
        for (int i = 0; i <= ROUNDS; i++)
        {
            a[0] ^= key[0] ^ (rcon[i] << 16);
            a[1] ^= key[1];
            a[2] ^= key[2] ^ rcon[i];
            if (i < ROUNDS) tway_round(a);
            else theta(a);
        }
    }

    /**
     * One block forward pass.
     * @param k the expanded key
     * @param a array of 3 ints, modified in place
     */
    private static final void Tway_encipher(TWAY_KEYSCHED k, int[] a)
    {
        applyRounds(k.givenKey, k.encryptRoundConstants, a);
    }

    /**
     * One block inverse pass: the forward structure run with the inverse
     * key and decryption constants, wrapped in bit reversals.
     * @param k the expanded key
     * @param a array of 3 ints, modified in place
     */
    private static final void Tway_decipher(TWAY_KEYSCHED k, int[] a)
    {
        mu(a);
        applyRounds(k.inverseKey, k.decryptRoundConstants, a);
        mu(a);
    }

    /**
     * Expand a user-supplied 96-bit key into a key schedule.
     * @param k the schedule to fill in
     * @param a the key as an array of 3 ints
     */
    private static void Tway_Key_Init(TWAY_KEYSCHED k, int[] a)
    {
        for (int i = 0; i < 3; i++)
        {
            k.givenKey[i] = a[i];
            k.inverseKey[i] = a[i];
        }
        theta(k.inverseKey);
        mu(k.inverseKey);
        tway_generate_const(STRT_E, k.encryptRoundConstants);
        tway_generate_const(STRT_D, k.decryptRoundConstants);
    }

    /**
     * Read a 32-bit word stored most-significant byte first.
     * @param buf source bytes
     * @param off index of the most significant byte
     * @return the assembled word
     */
    private static int readWord(byte[] buf, int off)
    {
        return ((buf[off] & 0xFF) << 24) |
               ((buf[off + 1] & 0xFF) << 16) |
               ((buf[off + 2] & 0xFF) << 8) |
                (buf[off + 3] & 0xFF);
    }

    /**
     * Store a 32-bit word most-significant byte first.
     * @param word value to store
     * @param buf destination bytes
     * @param off index that receives the most significant byte
     */
    private static void writeWord(int word, byte[] buf, int off)
    {
        buf[off]     = (byte) (word >>> 24);
        buf[off + 1] = (byte) (word >>> 16);
        buf[off + 2] = (byte) (word >>> 8);
        buf[off + 3] = (byte) word;
    }

    /**
     * Wipe every non-null schedule in an array; a null array is ignored.
     * @param schedules the schedules to wipe, may be null or contain nulls
     */
    private static void wipeSchedules(TWAY_KEYSCHED[] schedules)
    {
        if (schedules == null) return;
        for (int i = 0; i < schedules.length; ++i)
        {
            if (schedules[i] != null) schedules[i].wipe();
        }
    }

    /**
     * Is this wrapper doing triple encryption?
     */
    private boolean triple;
    /**
     * The key schedule data: one entry, or three in triple mode.
     */
    private TWAY_KEYSCHED[] ks = null;

    /**
     * Initialise the object with one or three key blocks.  Any previously
     * installed schedule is wiped.  If the key cannot be read (for example,
     * the array is too short), the exception propagates.  In that case the
     * partially built schedule is wiped and the object's previous state is
     * left untouched.
     * @param key array of key bytes, 1 or 3 key block lengths, MSB first
     * @param offset index of the first key byte within key
     * @param triple true if three keys for triple application
     */
    public void init(byte[] key, int offset, boolean triple)
    {
        int keys = triple ? 3 : 1;
        int[] a = new int[getKeysize() / 4];
        TWAY_KEYSCHED[] fresh = new TWAY_KEYSCHED[keys];
        boolean complete = false;

        try
        {
            for (int k = 0; k < keys; k++)
            {
                /* key data is MSB first */
                for (int w = 0; w < a.length; w++)
                {
                    a[w] = readWord(key, offset);
                    offset += 4;
                }
                fresh[k] = new TWAY_KEYSCHED();
                Tway_Key_Init(fresh[k], a);
            }
            complete = true;
        }
        finally
        {
            /* purge sensitive info, whether or not key set-up succeeded */
            for (int w = 0; w < a.length; w++) a[w] = 0;
            if (!complete) wipeSchedules(fresh);
        }

        // only replace (and wipe) the old schedule once the new one is whole
        wipeSchedules(ks);
        ks = fresh;
        this.triple = triple;
    }

    /**
     * Transform one block in ECB mode.  In triple mode, encryption applies
     * the three schedules in order and decryption applies them in reverse.
     * The whole input block is read before any output is written, so in
     * and out may be the same region.
     * @param encrypt true if forwards transformation
     * @param in input block
     * @param offin offset into block of input data
     * @param out output block
     * @param offout offset into block of output data
     * @throws NullPointerException if {@link #init} has not been called
     */
    public void ecb(boolean encrypt, byte[] in, int offin,
        byte[] out, int offout)
    {
        if (ks == null)
        {
            throw new NullPointerException("ThreeWay.ecb() called before init()");
        }
        int keys = triple ? 3 : 1;
        int[] a = new int[getBlocksize() / 4];

        /* reduce data from MSB-first form */
        for (int w = 0; w < a.length; w++)
        {
            a[w] = readWord(in, offin + 4 * w);
        }

        for (int i = 0; i < keys; i++)
        {
            if (encrypt)
            {
                Tway_encipher(ks[i], a);
            }
            else
            {
                Tway_decipher(ks[keys - (i + 1)], a);
            }
        }

        /* restore MSB-first form and wipe the working copy */
        for (int w = 0; w < a.length; w++)
        {
            writeWord(a[w], out, offout + 4 * w);
            a[w] = 0;
        }
    }

    /**
     * Wipe key schedule information and return to single mode.
     */
    public void destroy()
    {
        triple = false;
        wipeSchedules(ks);
        ks = null;
    }

    /**
     * Provide information on the desired key size.
     * @return byte length of key
     */
    public final int getKeysize() {return keysize();}

    /**
     * Provide information on the algorithm block size.
     * @return byte length of block
     */
    public final int getBlocksize() {return blocksize();}

    /**
     * Provide information on the desired key size.
     * @return byte length of key
     */
    public final static int keysize() {return 12;}

    /**
     * Provide information on the algorithm block size.
     * @return byte length of block
     */
    public final static int blocksize() {return 12;}

    /**
     * Encrypt consecutive 3-word blocks in place.
     * @param k the expanded key
     * @param data blocks laid out as 3 ints each
     * @param blocks number of blocks to process
     */
    private static void Tway_ECB_encrypt(TWAY_KEYSCHED k, int[] data, int blocks)
    {
        ecbWords(k, data, blocks, true);
    }

    /**
     * Decrypt consecutive 3-word blocks in place.
     * @param k the expanded key
     * @param data blocks laid out as 3 ints each
     * @param blocks number of blocks to process
     */
    private static void Tway_ECB_decrypt(TWAY_KEYSCHED k, int[] data, int blocks)
    {
        ecbWords(k, data, blocks, false);
    }

    /**
     * Shared body of the word-array ECB helpers.
     * @param k the expanded key
     * @param data blocks laid out as 3 ints each, modified in place
     * @param blocks number of blocks to process
     * @param encrypt true to encipher, false to decipher
     */
    private static void ecbWords(TWAY_KEYSCHED k, int[] data, int blocks, boolean encrypt)
    {
        int[] block = new int[3];
        for (int n = 0; n < blocks; n++)
        {
            System.arraycopy(data, 3 * n, block, 0, 3);
            if (encrypt) Tway_encipher(k, block);
            else Tway_decipher(k, block);
            System.arraycopy(block, 0, data, 3 * n, 3);
        }
        for (int j = 0; j < 3; j++) block[j] = 0;
    }

    /**
     * Print the first three words of a vector, most significant word first.
     */
    private static void printvec(String text, int[] vector)
    {
        System.out.println(
            text + " : " + Hexprint.fmt(vector[2]) + " " +
            Hexprint.fmt(vector[1]) + " " +
            Hexprint.fmt(vector[0]) + " ");
    }

    /**
     * Drive the standard set of test vectors and post to stdout.
     */
    public static void main(String[] args)
    {
        test();
    }

    /**
     * Run the standard set of test vectors and post the results to stdout.
     * The API section prints the expected results alongside for comparison.
     */
    public static void test()
    {
        /* one schedule object, re-keyed for each vector */
        TWAY_KEYSCHED ks = new TWAY_KEYSCHED();
        int[] userKey = new int[3];
        int[] data = new int[9];

        /* test 1 */
        for (int i = 0; i < 3; i++)
        {
            userKey[i] = 0;
            data[i] = 1;
        }
        runVectorTest("Test 1********", ks, userKey, data);

        /* test 2 */
        for (int i = 0; i < 3; i++)
        {
            userKey[i] = 6 - i;
            data[i] = 3 - i;
        }
        runVectorTest("Test 2********", ks, userKey, data);

        /* test 3 */
        userKey[2] = 0xbcdef012;
        userKey[1] = 0x456789ab;
        userKey[0] = 0xdef01234;
        data[2] = 0x01234567;
        data[1] = 0x9abcdef0;
        data[0] = 0x23456789;
        runVectorTest("Test 3********", ks, userKey, data);

        /* test 4 */
        userKey[2] = 0xcab920cd;
        userKey[1] = 0xd6144138;
        userKey[0] = 0xd2f05b5e;
        data[2] = 0xad21ecf7;
        data[1] = 0x83ae9dc4;
        data[0] = 0x4059c76e;
        runVectorTest("Test 4********", ks, userKey, data);

        /* block test, still keyed with the test 4 key */
        for (int i = 0; i < data.length; i++) data[i] = i;
        printBlocks(" set to      ", data);
        Tway_ECB_encrypt(ks, data, 3);
        printBlocks(" encrypts to ", data);
        Tway_ECB_decrypt(ks, data, 3);
        printBlocks(" decrypts to ", data);

        /* Packaged API tests: the same vectors through init()/ecb(), with
         * word i held MSB first in bytes 4i..4i+3 */
        System.out.println("API tests");

        byte[] key = new byte[keysize()];
        byte[] in = new byte[blocksize()];
        byte[] out = new byte[blocksize()];
        byte[] savekey = new byte[keysize()];
        byte[] savetext = new byte[blocksize()];
        ThreeWay keysched = new ThreeWay();

        setWords(key, 0, 0, 0);
        setWords(in, 1, 1, 1);
        encryptOnce(keysched, key, in, out);
        printWordsReversed(out);
        System.arraycopy(out, 0, savetext, 0, out.length);
        System.out.println("ad21ecf7 83ae9dc4 4059c76e is expected result");

        setWords(key, 6, 5, 4);
        setWords(in, 3, 2, 1);
        encryptOnce(keysched, key, in, out);
        printWordsReversed(out);
        System.arraycopy(out, 0, savekey, 0, out.length);
        System.out.println("cab920cd d6144138 d2f05b5e is expected result");

        setWords(key, 0xdef01234, 0x456789ab, 0xbcdef012);
        setWords(in, 0x23456789, 0x9abcdef0, 0x01234567);
        encryptOnce(keysched, key, in, out);
        printWordsReversed(out);
        System.out.println("7cdb76b2 9cdddb6d 0aa55dbb is expected result");

        /* chain: the previous two ciphertexts become key and plaintext */
        System.arraycopy(savekey, 0, key, 0, keysize());
        System.arraycopy(savetext, 0, in, 0, blocksize());
        encryptOnce(keysched, key, in, out);
        printWordsReversed(out);
        System.out.println("15b155ed 6b13f17c 478ea871 is expected result");
    }

    /**
     * Key a schedule, then print key, plaintext, ciphertext and the
     * decrypted result for one vector.
     * @param title heading line
     * @param ks schedule object to (re)initialise
     * @param userKey the key words
     * @param data the block, left holding the recovered plaintext
     */
    private static void runVectorTest(String title, TWAY_KEYSCHED ks,
        int[] userKey, int[] data)
    {
        Tway_Key_Init(ks, userKey);

        System.out.println(title);
        printvec("KEY       = ", userKey);
        printvec("PLAIN     = ", data);
        Tway_encipher(ks, data);
        printvec("CIPHER    = ", data);
        Tway_decipher(ks, data);
        printvec("RECOVERED = ", data);
    }

    /**
     * Print each 3-word block of data on its own line.
     * @param verb text placed between the block number and the words
     * @param data blocks laid out as 3 ints each
     */
    private static void printBlocks(String verb, int[] data)
    {
        for (int i = 0; i + 2 < data.length; i += 3)
        {
            System.out.println("Block " + (i / 3) + verb +
                Hexprint.fmt(data[i]) + " " +
                Hexprint.fmt(data[i + 1]) + " " +
                Hexprint.fmt(data[i + 2]));
        }
    }

    /**
     * Store three words MSB first at byte offsets 0, 4 and 8.
     */
    private static void setWords(byte[] buf, int w0, int w1, int w2)
    {
        writeWord(w0, buf, 0);
        writeWord(w1, buf, 4);
        writeWord(w2, buf, 8);
    }

    /**
     * Key with a single schedule, encrypt one block, then wipe the schedule.
     */
    private static void encryptOnce(ThreeWay engine, byte[] key, byte[] in, byte[] out)
    {
        engine.init(key, 0, false);
        engine.ecb(true, in, 0, out, 0);
        engine.destroy();
    }

    /**
     * Print a 12-byte block as three words with the word order reversed
     * (bytes 8..11 first, then 4..7, then 0..3), matching printvec.
     * @param block exactly 12 bytes
     */
    private static void printWordsReversed(byte[] block)
    {
        for (int i = 0; i < block.length; i++)
        {
            System.out.print(Hexprint.fmt(block[i + 8 * (1 - (i / 4))]));
            if (3 == i % 4) System.out.print(" ");
        }
        System.out.println("");
    }

}