/*************************************************
* CAST256 Source File                            *
* (C) 1999-2001 The OpenCL Project               *
*************************************************/

#include <opencl/cast256.h>

namespace OpenCL {

/*************************************************
* CAST256 Encryption                             *
*************************************************/
void CAST256::encrypt(const byte in[BLOCKSIZE], byte out[BLOCKSIZE]) const
   {
   /* Load the 16 input bytes as four words: W[0]..W[3] play the roles
      usually written A, B, C, D. */
   u32bit W[4];
   for(u32bit j = 0; j != 4; j++)
      W[j] = make_u32bit(in[4*j], in[4*j+1], in[4*j+2], in[4*j+3]);

   /* First half: six quad-rounds over subkeys 0..23, each updating the
      words in the order C, B, A, D. */
   for(u32bit k = 0; k != 24; k += 4)
      {
      round1(W[2], W[3], MK[k  ], RK[k  ]);
      round2(W[1], W[2], MK[k+1], RK[k+1]);
      round3(W[0], W[1], MK[k+2], RK[k+2]);
      round1(W[3], W[0], MK[k+3], RK[k+3]);
      }

   /* Second half: six quad-rounds over subkeys 24..47, with the four
      steps inside each group applied in the opposite order (D, A, B, C
      using subkeys k+3 down to k). */
   for(u32bit k = 24; k != 48; k += 4)
      {
      round1(W[3], W[0], MK[k+3], RK[k+3]);
      round3(W[0], W[1], MK[k+2], RK[k+2]);
      round2(W[1], W[2], MK[k+1], RK[k+1]);
      round1(W[2], W[3], MK[k  ], RK[k  ]);
      }

   /* Store each word back as four bytes, get_byte index 0 first. */
   for(u32bit j = 0; j != 4; j++)
      {
      out[4*j  ] = get_byte(0, W[j]);
      out[4*j+1] = get_byte(1, W[j]);
      out[4*j+2] = get_byte(2, W[j]);
      out[4*j+3] = get_byte(3, W[j]);
      }
   }

/*************************************************
* CAST256 Decryption                             *
*************************************************/
void CAST256::decrypt(const byte in[BLOCKSIZE], byte out[BLOCKSIZE]) const
   {
   /* Load the 16 input bytes as four words: W[0]..W[3] play the roles
      usually written A, B, C, D. */
   u32bit W[4];
   for(u32bit j = 0; j != 4; j++)
      W[j] = make_u32bit(in[4*j], in[4*j+1], in[4*j+2], in[4*j+3]);

   /* First half: same step order as the first half of encrypt()
      (C, B, A, D), but the subkey groups are walked downward:
      44..47, 40..43, ..., 24..27. */
   for(u32bit top = 48; top != 24; top -= 4)
      {
      const u32bit k = top - 4;
      round1(W[2], W[3], MK[k  ], RK[k  ]);
      round2(W[1], W[2], MK[k+1], RK[k+1]);
      round3(W[0], W[1], MK[k+2], RK[k+2]);
      round1(W[3], W[0], MK[k+3], RK[k+3]);
      }

   /* Second half: steps in the order D, A, B, C, with the subkeys
      consumed strictly downward from 23 to 0. */
   for(u32bit top = 24; top != 0; top -= 4)
      {
      const u32bit k = top - 4;
      round1(W[3], W[0], MK[k+3], RK[k+3]);
      round3(W[0], W[1], MK[k+2], RK[k+2]);
      round2(W[1], W[2], MK[k+1], RK[k+1]);
      round1(W[2], W[3], MK[k  ], RK[k  ]);
      }

   /* Store each word back as four bytes, get_byte index 0 first. */
   for(u32bit j = 0; j != 4; j++)
      {
      out[4*j  ] = get_byte(0, W[j]);
      out[4*j+1] = get_byte(1, W[j]);
      out[4*j+2] = get_byte(2, W[j]);
      out[4*j+3] = get_byte(3, W[j]);
      }
   }

/*************************************************
* CAST256 Round Type 1                           *
*************************************************/
void CAST256::round1(u32bit& out, u32bit in, u32bit mask, u32bit rot) const
   {
   /* Mix the input with the masking key by addition, then rotate. */
   u32bit I = rotate_left(mask + in, rot);

   /* Combine the four S-box lookups as ((S1 ^ S2) - S3) + S4. */
   u32bit F = CAST_SBOX1[get_byte(0, I)] ^ CAST_SBOX2[get_byte(1, I)];
   F -= CAST_SBOX3[get_byte(2, I)];
   F += CAST_SBOX4[get_byte(3, I)];

   out ^= F;
   }

/*************************************************
* CAST256 Round Type 2                           *
*************************************************/
void CAST256::round2(u32bit& out, u32bit in, u32bit mask, u32bit rot) const
   {
   /* Mix the input with the masking key by XOR, then rotate. */
   u32bit I = rotate_left(mask ^ in, rot);

   /* Combine the four S-box lookups as ((S1 - S2) + S3) ^ S4. */
   u32bit F = CAST_SBOX1[get_byte(0, I)] - CAST_SBOX2[get_byte(1, I)];
   F += CAST_SBOX3[get_byte(2, I)];
   F ^= CAST_SBOX4[get_byte(3, I)];

   out ^= F;
   }

/*************************************************
* CAST256 Round Type 3                           *
*************************************************/
void CAST256::round3(u32bit& out, u32bit in, u32bit mask, u32bit rot) const
   {
   /* Mix the input with the masking key by subtraction, then rotate. */
   u32bit I = rotate_left(mask - in, rot);

   /* Combine the four S-box lookups as ((S1 + S2) ^ S3) - S4. */
   u32bit F = CAST_SBOX1[get_byte(0, I)] + CAST_SBOX2[get_byte(1, I)];
   F ^= CAST_SBOX3[get_byte(2, I)];
   F -= CAST_SBOX4[get_byte(3, I)];

   out ^= F;
   }

/*************************************************
* CAST256 Key Setup                              *
*************************************************/
void CAST256::set_key(const byte key[], u32bit length) throw(InvalidKeyLength)
   {
   if(!valid_keylength(length))
      throw InvalidKeyLength("CAST-256", length);
   SecureBuffer<u32bit, 16> TMK, TRK;
   u32bit CM = 0x5A827999, CR = 19;
   for(u32bit j = 0; j != length; j++)
      TMK[j/4] = (TMK[j/4] << 8) + key[j];

   u32bit A = TMK[0], B = TMK[1], C = TMK[2], D = TMK[3],
          E = TMK[4], F = TMK[5], G = TMK[6], H = TMK[7];
   for(u32bit j = 0; j != 48; j += 4)
      {
      for(u32bit k = 0; k != 8; k++)
         { TMK[k] = CM; TRK[k] = CR % 32; CM += 0x6ED9EBA1; CR += 17; }
      round1(G, H, TMK[0], TRK[0]);   round2(F, G, TMK[1], TRK[1]);
      round3(E, F, TMK[2], TRK[2]);   round1(D, E, TMK[3], TRK[3]);
      round2(C, D, TMK[4], TRK[4]);   round3(B, C, TMK[5], TRK[5]);
      round1(A, B, TMK[6], TRK[6]);   round2(H, A, TMK[7], TRK[7]);
      for(u32bit k = 0; k != 8; k++)
         { TMK[k] = CM; TRK[k] = CR % 32; CM += 0x6ED9EBA1; CR += 17; }
      round1(G, H, TMK[0], TRK[0]);   round2(F, G, TMK[1], TRK[1]);
      round3(E, F, TMK[2], TRK[2]);   round1(D, E, TMK[3], TRK[3]);
      round2(C, D, TMK[4], TRK[4]);   round3(B, C, TMK[5], TRK[5]);
      round1(A, B, TMK[6], TRK[6]);   round2(H, A, TMK[7], TRK[7]);
      RK[j  ] = (A % 32); RK[j+1] = (C % 32);
      RK[j+2] = (E % 32); RK[j+3] = (G % 32);
      MK[j  ] = H; MK[j+1] = F; MK[j+2] = D; MK[j+3] = B;
      }
   }

}
