/*	CAST-256 Block Cypher Implementation
	Copyright (C) 1999, Daniel Roethlisberger

	This program is free software; you can redistribute it and/or modify
	it under the terms of the GNU General Public License as published by
	the Free Software Foundation; either version 2 of the License, or
	(at your option) any later version.

	This program is distributed in the hope that it will be useful,
	but WITHOUT ANY WARRANTY; without even the implied warranty of
	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
	GNU General Public License for more details.

	You should have received a copy of the GNU General Public License
	along with this program; if not, write to the Free Software
	Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

	The author of this program may be contacted at admin@roe.ch.
*/

#include "cast256.h"
#include "cast256s.h"

// ========================================================================
//  Support defines
// ------------------------------------------------------------------------
// Unsigned working word used for all cipher arithmetic in this file.
// NOTE(review): ROTL below rotates within 32 bits, so the code presumably
// expects this type to be exactly 32 bits wide; `unsigned long` is wider than
// that on LP64 targets -- confirm the intended platforms.
#define DWORD unsigned long

// Rotate the low 32 bits of x left by n bit positions (n is taken modulo 32).
// The operand is masked to 32 bits first, so the result is correct even when
// the operand type is wider than 32 bits, and the right-shift count is reduced
// modulo 32 so that n == 0 never shifts by the full word width.
// Both x and n are expanded more than once: pass side-effect-free expressions.
#define ROTL(x,n) \
	(((((x) & 0xFFFFFFFFUL) << ((n) & 31)) | \
	  (((x) & 0xFFFFFFFFUL) >> ((32 - ((n) & 31)) & 31))) & 0xFFFFFFFFUL)

// Extract one byte of a 32-bit word as an int in the range 0..255:
//   EXT_A = bits 31..24, EXT_B = bits 23..16,
//   EXT_C = bits 15..8,  EXT_D = bits 7..0.
// The argument is parenthesized in every macro so that an expression such as
// EXT_A(p | q) shifts the whole value rather than only its right operand.
#define EXT_A(x)	((int)(((x) >> 24) & 0xFF))
#define EXT_B(x)	((int)(((x) >> 16) & 0xFF))
#define EXT_C(x)	((int)(((x) >>  8) & 0xFF))
#define EXT_D(x)	((int)((x)         & 0xFF))

// Keep only bits 4..0 of x, yielding a value in the range 0..31
#define FIVE_LSB(x) ((DWORD)0x1F & (x))

// Fill rotation subkey Kr[i] from the key-schedule state k: the five least
// significant bits of k.A, k.C, k.E and k.G are stored in Kr[i].A, .B, .C
// and .D respectively. An array or pointer named Kr must be in scope where
// the macro is used.
// The body is wrapped in do { } while (0) so the macro expands to exactly one
// statement and stays correct inside an unbraced if/else or loop body.
#define Kr_FILL(i,k) \
	do { \
		Kr[(i)].A = FIVE_LSB((k).A); \
		Kr[(i)].B = FIVE_LSB((k).C); \
		Kr[(i)].C = FIVE_LSB((k).E); \
		Kr[(i)].D = FIVE_LSB((k).G); \
	} while (0)

// Fill mask subkey Km[i] from the key-schedule state k: k.H, k.F, k.D and
// k.B are stored in Km[i].A, .B, .C and .D respectively. An array or pointer
// named Km must be in scope where the macro is used.
// The body is wrapped in do { } while (0) so the macro expands to exactly one
// statement and stays correct inside an unbraced if/else or loop body.
#define Km_FILL(i,k) \
	do { \
		Km[(i)].A = (k).H; \
		Km[(i)].B = (k).F; \
		Km[(i)].C = (k).D; \
		Km[(i)].D = (k).B; \
	} while (0)

// ========================================================================
//  The global rotation and mask round subkey generation tables
// ------------------------------------------------------------------------
// Rotation-value table: CAST256TableInit() fills every field with a five-bit
// value from a sequence that starts at 19 and advances by 17 (reduced with
// FIVE_LSB). CAST256KeyInit() passes it to W() as Tr, using two entries per
// generated round subkey. Holds zeros until CAST256TableInit() has run.
KAPPA g_Tr[24];
// Mask-value table: CAST256TableInit() fills every field from a sequence that
// starts at 0x5A827999 and advances by 0x6ED9EBA1. CAST256KeyInit() passes it
// to W() as Tm. Holds zeros until CAST256TableInit() has run.
KAPPA g_Tm[24];

// ========================================================================
//  The F-Functions
// ------------------------------------------------------------------------
// Round function, type 1.
//   D   - data word fed into the function
//   Kri - rotation amount for this step
//   Kmi - mask key for this step
// The mask key is ADDED to D, the sum is rotated left by Kri, and the four
// bytes of the result index the S-boxes, combined as ((S1 ^ S2) - S3) + S4.
DWORD __inline f1(DWORD D, DWORD Kri, DWORD Kmi)
{
	const DWORD mixed = ROTL(Kmi + D, Kri);
	const int a = EXT_A(mixed);
	const int b = EXT_B(mixed);
	const int c = EXT_C(mixed);
	const int d = EXT_D(mixed);

	return ((S1[a] ^ S2[b]) - S3[c]) + S4[d];
}
// Round function, type 2.
//   D   - data word fed into the function
//   Kri - rotation amount for this step
//   Kmi - mask key for this step
// The mask key is XORED with D, the result is rotated left by Kri, and its
// four bytes index the S-boxes, combined as ((S1 - S2) + S3) ^ S4.
DWORD __inline f2(DWORD D, DWORD Kri, DWORD Kmi)
{
	const DWORD mixed = ROTL(Kmi ^ D, Kri);
	const int a = EXT_A(mixed);
	const int b = EXT_B(mixed);
	const int c = EXT_C(mixed);
	const int d = EXT_D(mixed);

	return ((S1[a] - S2[b]) + S3[c]) ^ S4[d];
}
// Round function, type 3.
//   D   - data word fed into the function
//   Kri - rotation amount for this step
//   Kmi - mask key for this step
// D is SUBTRACTED from the mask key, the difference is rotated left by Kri,
// and its four bytes index the S-boxes, combined as ((S1 + S2) ^ S3) - S4.
DWORD __inline f3(DWORD D, DWORD Kri, DWORD Kmi)
{
	const DWORD mixed = ROTL(Kmi - D, Kri);
	const int a = EXT_A(mixed);
	const int b = EXT_B(mixed);
	const int c = EXT_C(mixed);
	const int d = EXT_D(mixed);

	return ((S1[a] + S2[b]) ^ S3[c]) - S4[d];
}

// ========================================================================
//  The Quad-Round and Reverse Quad Round
// ------------------------------------------------------------------------
// Forward quad-round.
//   round - index of the subkey pair to use
//   data  - 128-bit block, taken by value
//   Kr/Km - rotation and mask subkey arrays, indexed by round
// Updates C, B, A and D in that order; each word is XORed with a round
// function of the word updated just before it. Returns the updated block.
BETA __inline Q(int round, BETA data, BETA* Kr, BETA* Km)
{
	const BETA *rot  = &Kr[round];
	const BETA *mask = &Km[round];

	data.C ^= f1(data.D, rot->A, mask->A);
	data.B ^= f2(data.C, rot->B, mask->B);
	data.A ^= f3(data.B, rot->C, mask->C);
	data.D ^= f1(data.A, rot->D, mask->D);

	return data;
}
// Reverse quad-round: the same four steps as Q(), applied in the opposite
// order (D, A, B, then C).
//   round - index of the subkey pair to use
//   data  - 128-bit block, taken by value
//   Kr/Km - rotation and mask subkey arrays, indexed by round
// Returns the updated block.
BETA __inline QBAR(int round, BETA data, BETA* Kr, BETA* Km)
{
	const BETA *rot  = &Kr[round];
	const BETA *mask = &Km[round];

	data.D ^= f1(data.A, rot->D, mask->D);
	data.A ^= f3(data.B, rot->C, mask->C);
	data.B ^= f2(data.C, rot->B, mask->B);
	data.C ^= f1(data.D, rot->A, mask->A);

	return data;
}

// ========================================================================
//  The Forward Octave
// ------------------------------------------------------------------------
// Forward octave used by the key schedule.
//   round - index into the Tr/Tm tables
//   data  - eight-word key state, taken by value
//   Tr/Tm - rotation and mask tables, indexed by round
// Walks the eight words from G down to A and finally H; each word is XORed
// with a round function (cycling f1, f2, f3) of the word updated just before
// it. Returns the updated key state.
KAPPA __inline W(int round, KAPPA data, KAPPA* Tr, KAPPA* Tm)
{
	const KAPPA *rot  = &Tr[round];
	const KAPPA *mask = &Tm[round];

	data.G ^= f1(data.H, rot->A, mask->A);
	data.F ^= f2(data.G, rot->B, mask->B);
	data.E ^= f3(data.F, rot->C, mask->C);
	data.D ^= f1(data.E, rot->D, mask->D);
	data.C ^= f2(data.D, rot->E, mask->E);
	data.B ^= f3(data.C, rot->F, mask->F);
	data.A ^= f1(data.B, rot->G, mask->G);
	data.H ^= f2(data.A, rot->H, mask->H);

	return data;
}

// ========================================================================
//  Encrypt a 128bit block
// ------------------------------------------------------------------------
// Encrypt one 128-bit block in place.
//   Kr    - rotation subkeys; entries 0..11 are read
//   Km    - mask subkeys; entries 0..11 are read
//   pData - block to encrypt; overwritten with the result
// Applies six forward quad-rounds (subkeys 0..5) followed by six reverse
// quad-rounds (subkeys 6..11).
void CAST256Encrypt(BETA *Kr, BETA *Km, BETA *pData)
{
	// Declared once at function scope: a variable declared in a for-init
	// statement is not visible to the following loop on conforming compilers.
	int i;

	for(i = 0; i < 6; i++)
		*pData = Q(i, *pData, Kr, Km);
	for(i = 6; i < 12; i++)
		*pData = QBAR(i, *pData, Kr, Km);
}

// ========================================================================
//  Decrypt a 128bit block
// ------------------------------------------------------------------------
// Decrypt one 128-bit block in place.
//   Kr    - rotation subkeys; entries 0..11 are read
//   Km    - mask subkeys; entries 0..11 are read
//   pData - block to decrypt; overwritten with the result
// Uses the subkeys in descending order: six forward quad-rounds with
// subkeys 11..6 followed by six reverse quad-rounds with subkeys 5..0.
void CAST256Decrypt(BETA *Kr, BETA *Km, BETA *pData)
{
	// Declared once at function scope: a variable declared in a for-init
	// statement is not visible to the following loop on conforming compilers.
	int i;

	for(i = 11; i >= 6; i--)
		*pData = Q(i, *pData, Kr, Km);
	for(i = 5; i >= 0; i--)
		*pData = QBAR(i, *pData, Kr, Km);
}

// ========================================================================
//  Initialize the mask and rotation round subkey sets from a given user key
// ------------------------------------------------------------------------
// Derive the twelve rotation and mask round subkeys from a user key.
//   Kr      - receives rotation subkeys; entries 0..11 are written
//   Km      - receives mask subkeys; entries 0..11 are written
//   userKey - key state, taken by value and used as scratch space
// For each subkey the key state is run through two consecutive octaves of
// W() using the global g_Tr/g_Tm tables, then sampled into Kr and Km.
// The tables are filled by CAST256TableInit().
void CAST256KeyInit(BETA *Kr, BETA *Km, KAPPA userKey)
{
	int subkey;
	int octave = 0;	// advances by two per subkey: 2*subkey, then 2*subkey+1

	for(subkey = 0; subkey < 12; ++subkey)
	{
		userKey = W(octave++, userKey, g_Tr, g_Tm);
		userKey = W(octave++, userKey, g_Tr, g_Tm);
		Kr_FILL(subkey, userKey);
		Km_FILL(subkey, userKey);
	}
}

// ========================================================================
//  Initialize the tables used for key initialization / expansion
// ------------------------------------------------------------------------
// Fill the global key-schedule tables g_Tm (mask values) and g_Tr (rotation
// values). Takes no input and always produces the same tables, so calling it
// more than once is harmless.
// Mask values start at Cm and advance by Mm modulo 2^32; rotation values
// start at Cr and advance by Mr modulo 32. The starting values and steps are
// the constants given for the key schedule in RFC 2612.
void CAST256TableInit(void)
{
	// Explicit 32-bit reduction: without it the running mask value grows past
	// 32 bits whenever DWORD is wider than 32 bits.
	const DWORD WORD_MASK = 0xFFFFFFFFUL;
	DWORD Cm = 0x5A827999;
	DWORD Mm = 0x6ED9EBA1;
	DWORD Cr = 19;
	DWORD Mr = 17;
	int i;

	for(i = 0; i < 24; i++)
	{
		// Eight consecutive mask values per table entry
		g_Tm[i].A = Cm;
		Cm = (Cm + Mm) & WORD_MASK;
		g_Tm[i].B = Cm;
		Cm = (Cm + Mm) & WORD_MASK;
		g_Tm[i].C = Cm;
		Cm = (Cm + Mm) & WORD_MASK;
		g_Tm[i].D = Cm;
		Cm = (Cm + Mm) & WORD_MASK;
		g_Tm[i].E = Cm;
		Cm = (Cm + Mm) & WORD_MASK;
		g_Tm[i].F = Cm;
		Cm = (Cm + Mm) & WORD_MASK;
		g_Tm[i].G = Cm;
		Cm = (Cm + Mm) & WORD_MASK;
		g_Tm[i].H = Cm;
		Cm = (Cm + Mm) & WORD_MASK;

		// Eight consecutive five-bit rotation values per table entry
		g_Tr[i].A = Cr;
		Cr = FIVE_LSB(Cr + Mr);
		g_Tr[i].B = Cr;
		Cr = FIVE_LSB(Cr + Mr);
		g_Tr[i].C = Cr;
		Cr = FIVE_LSB(Cr + Mr);
		g_Tr[i].D = Cr;
		Cr = FIVE_LSB(Cr + Mr);
		g_Tr[i].E = Cr;
		Cr = FIVE_LSB(Cr + Mr);
		g_Tr[i].F = Cr;
		Cr = FIVE_LSB(Cr + Mr);
		g_Tr[i].G = Cr;
		Cr = FIVE_LSB(Cr + Mr);
		g_Tr[i].H = Cr;
		Cr = FIVE_LSB(Cr + Mr);
	}
}

// ========================================================================
//  CAST-256 Test Code
// ------------------------------------------------------------------------
/*
#include <stdio.h>
#include <string.h>
#include <conio.h>

void main(void)
{
	printf("CAST-256 Block Cypher Implementation\nCopyright (C) 1999, Daniel Roethlisberger <admin@roe.ch>\n\nThis program is free software under the GNU license.\nLet me know if you use it.\n\n");

	BETA Kr[CAST_ROUNDS];
	BETA Km[CAST_ROUNDS];

	// Simple encryption/decryption test

	printf("Block Encryption/Decryption Test: ");
	
	KAPPA uKey = {0x01234567, 0x89ABCDEF, 0x01234567, 0x89ABCDEF, 0x01234567, 0x89ABCDEF, 0x01234567, 0x89ABCDEF};
	BETA Plain = {0xDEAFDEAF, 0xDEAFDEAF, 0xDEAFDEAF, 0xDEAFDEAF};
	BETA Cypher = {0xDEAFDEAF, 0xDEAFDEAF, 0xDEAFDEAF, 0xDEAFDEAF};

	// Initialize tables and keys
	CAST256TableInit();
	CAST256KeyInit(Kr, Km, uKey);

	// Encrypt and decrypt Cypher
	CAST256Encrypt(Kr, Km, &Cypher);
	CAST256Decrypt(Kr, Km, &Cypher);

	// Check result. Cypher must be equal to Plain
	if(!memcmp(&Plain, &Cypher, 16))
		printf("PASSED\n");
	else
		printf("FAILED\n");


	// RFC2612's defined reference vector test

	printf("RFC 2612 Reference Vector Test:   ");

	// Defined user key
	uKey.A = 0x2342bb9e;
	uKey.B = 0xfa38542c;
	uKey.C = 0xbed0ac83;
	uKey.D = 0x940ac298;
	uKey.E = 0x8d7c47ce;
	uKey.F = 0x26490846;
	uKey.G = 0x1cc1b513;
	uKey.H = 0x7ae6b604;

	// Defined plaintext
	Plain.A = 0x00000000;
	Plain.B = 0x00000000;
	Plain.C = 0x00000000;
	Plain.D = 0x00000000;

	// Defined corresponding cyphertext
	Cypher.A = 0x4f6a2038;
	Cypher.B = 0x286897b9;
	Cypher.C = 0xc9870136;
	Cypher.D = 0x553317fa;

	// Initialize tables and keys
	CAST256TableInit();
	CAST256KeyInit(Kr, Km, uKey);

	// Encrypt the test vector (=Plain)
	CAST256Encrypt(Kr, Km, &Plain);

	// Check whether Plain encrypted to the correct cyphertext (=Cypher)
	if(!memcmp(&Plain, &Cypher, 16))
		printf("PASSED\n");
	else
		printf("FAILED\n");

	printf("\nPress any key to return...\n\n");
	getch();
}
/**/