/* This file is in the public domain */

#include <iostream>
#include <iomanip>
#include <ctime>
#include <cstdlib>
#include <string>
#include <exception>

#include <opencl/filters.h>
#include <opencl/randpool.h>
using namespace OpenCL_types;

/* Global random number generator; bench() uses it to fill the input buffer
   that is pushed through each filter under test. */
OpenCL::Randpool rng;

/* One row of the benchmark table: the category it is reported under, the
   name shown in the output, the name handed to the filter lookup, and the
   key / IV lengths (bench() turns each length into a string of 2*len 'A'
   characters -- presumably hex, so these are byte counts; 0 = none). */
struct algorithm
   {
   /* The display name is also used as the lookup name. */
   algorithm(const std::string& kind, const std::string& label,
             u32bit key_bytes = 0, u32bit iv_bytes = 0) :
      type(kind), name(label), filtername(label),
      keylen(key_bytes), ivlen(iv_bytes)
      {
      }

   /* Display name and lookup name differ (e.g. "RC5 (r = 12)" vs "RC5-12"). */
   algorithm(const std::string& kind, const std::string& label,
             const std::string& lookup_name,
             u32bit key_bytes = 0, u32bit iv_bytes = 0) :
      type(kind), name(label), filtername(lookup_name),
      keylen(key_bytes), ivlen(iv_bytes)
      {
      }

   std::string type;
   std::string name;
   std::string filtername;
   u32bit keylen;
   u32bit ivlen;
   };

/* Sentinel value marking the final row of the algorithms[] table. */
const std::string END("END");

/* Everything benchmark() knows how to time, in output order. Each row is
   (type, display name[, lookup name][, key length[, IV length]]); when no
   lookup name is given the display name is used. The table ends with a
   row whose type is END, which benchmark() scans for. */
algorithm algorithms[] = {

   /* ---- block ciphers ---- */
   algorithm("Block Cipher", "Blowfish",                            16),
   algorithm("Block Cipher", "CAST256",                             16),
   algorithm("Block Cipher", "CAST5",                               16),
   algorithm("Block Cipher", "CS-Cipher",                           16),
   algorithm("Block Cipher", "DES",                                  8),
   algorithm("Block Cipher", "DESX",                                24),
   algorithm("Block Cipher", "Triple-DES",                          24),
   algorithm("Block Cipher", "GOST",                                32),
   algorithm("Block Cipher", "IDEA",                                16),
   algorithm("Block Cipher", "Luby-Rackoff<SHA1>",                  16),
   algorithm("Block Cipher", "MISTY1",                              16),
   algorithm("Block Cipher", "RC2",                                 16),
   algorithm("Block Cipher", "RC5 (r = 12)",            "RC5-12",   16),
   algorithm("Block Cipher", "RC5 (r = 16)",            "RC5-16",   16),
   algorithm("Block Cipher", "RC6",                                 32),
   algorithm("Block Cipher", "Rijndael [AES] (r = 10)", "Rijndael", 16),
   algorithm("Block Cipher", "Rijndael [AES] (r = 12)", "Rijndael", 24),
   algorithm("Block Cipher", "Rijndael [AES] (r = 14)", "Rijndael", 32),
   algorithm("Block Cipher", "SAFER-SK128",                         16),
   algorithm("Block Cipher", "Serpent",                             32),
   algorithm("Block Cipher", "SHARK",                               16),
   algorithm("Block Cipher", "Skipjack",                            10),
   algorithm("Block Cipher", "Square",                              16),
   algorithm("Block Cipher", "TEA",                                 16),
   algorithm("Block Cipher", "ThreeWay",                            12),
   algorithm("Block Cipher", "Twofish",                             32),
   algorithm("Block Cipher", "XTEA",                                16),

   /* ---- cipher modes (key length, IV length) ---- */
   algorithm("Cipher Mode", "CBC<DES>",    "CBC_wPadding_Encryption<DES>", 8, 8),
   algorithm("Cipher Mode", "CFB<DES>(8)", "CFB_Encryption<DES>(8)",       8, 8),
   algorithm("Cipher Mode", "CFB<DES>(4)", "CFB_Encryption<DES>(4)",       8, 8),
   algorithm("Cipher Mode", "CFB<DES>(2)", "CFB_Encryption<DES>(2)",       8, 8),
   algorithm("Cipher Mode", "CFB<DES>(1)", "CFB_Encryption<DES>(1)",       8, 8),
   algorithm("Cipher Mode", "OFB<DES>",                                    8, 8),
   algorithm("Cipher Mode", "Counter<DES>",                                8, 8),

   /* ---- stream ciphers ---- */
   algorithm("Stream Cipher", "ARC4",  16),
   algorithm("Stream Cipher", "ISAAC", 16),
   algorithm("Stream Cipher", "SEAL",  20),

   /* ---- hash functions / checksums (unkeyed) ---- */
   algorithm("Hash", "Adler32"),
   algorithm("Hash", "CRC24"),
   algorithm("Hash", "CRC32"),
   algorithm("Hash", "HAVAL",      "HAVAL-256"),
   algorithm("Hash", "MD2"),
   algorithm("Hash", "MD4"),
   algorithm("Hash", "MD5"),
   algorithm("Hash", "RIPE-MD128"),
   algorithm("Hash", "RIPE-MD160"),
   algorithm("Hash", "SHA-1"),
   algorithm("Hash", "SHA2-256"),
   algorithm("Hash", "SHA2-512"),
   algorithm("Hash", "Tiger"),

   /* ---- message authentication codes ---- */
   algorithm("MAC", "EMAC<Square>", 16),
   algorithm("MAC", "MD5-MAC",      16),

   /* ---- random number generators ---- */
   algorithm("RNG", "X917<Square>"),
   algorithm("RNG", "Randpool"),

   /* sentinel: must stay last */
   algorithm(END, END)
};

/* Terminal filter for every benchmark pipe: whatever reaches it is thrown
   away, so no time is spent storing or printing the processed output. */
struct BitBucket : public OpenCL::Filter
   {
   void write(const byte* /* input */, u32bit /* length */)
      {
      /* intentionally empty: discard everything */
      }
   };

/* Return `text` with the HTML metacharacters replaced by character
   references, so names such as "CBC<DES>" render literally instead of
   being parsed as (unknown) tags. */
static std::string bench_html_escape(const std::string& text)
   {
   std::string out;
   for(std::string::size_type k = 0; k != text.length(); k++)
      {
      switch(text[k])
         {
         case '<': out += "&lt;";   break;
         case '>': out += "&gt;";   break;
         case '&': out += "&amp;";  break;
         case '"': out += "&quot;"; break;
         default:  out += text[k];  break;
         }
      }
   return out;
   }

/* Spaces that pad `name` out to a 25-column field. Returns "" for names of
   25 characters or more instead of underflowing the unsigned subtraction
   (which would request an enormous string). */
static std::string bench_padding(const std::string& name)
   {
   const std::string::size_type WIDTH = 25;
   if(name.length() >= WIDTH)
      return "";
   return std::string(WIDTH - name.length(), ' ');
   }

/* Time one algorithm and print its throughput in kbytes/second.

   name       - label printed in the result line (HTML-escaped in HTML mode)
   filtername - name handed to lookup() to construct the filter
   html       - emit an HTML table row instead of a plain-text line
   seconds    - approximate CPU time to spend (scaled by a 2/3 fudge factor)
   keylen     - key length; lookup() receives 2*keylen 'A' characters
   ivlen      - IV length; lookup() receives 2*ivlen 'A' characters

   Exits the program with status 1 if lookup() does not know filtername. */
void bench(const std::string& name, const std::string& filtername,
           bool html, double seconds,
           u32bit keylen = 0, u32bit ivlen = 0)
   {
   OpenCL::Filter* lookup(const std::string&, const std::string&,
                          const std::string&);

   OpenCL::Filter* filter = lookup(filtername, std::string(2 * keylen, 'A'),
                                               std::string(2 * ivlen, 'A'));
   if(!filter)
      {
      std::cout << "Filter lookup doesn't know about "
                << filtername << std::endl;
      std::exit(1);
      }
   /* The pipe ends in a BitBucket so output handling does not skew timing. */
   OpenCL::Pipe pipe(filter, new BitBucket);

   static const u32bit BUFFERSIZE = OpenCL::DEFAULT_BUFFERSIZE;
   byte buf[BUFFERSIZE];
   rng.randomize(buf, BUFFERSIZE);

   /* 2/3 is a fudge factor; even with it the tests take longer than one
      might expect */
   const double target_clocks = (2.0 / 3.0) * seconds * CLOCKS_PER_SEC;

   /* Keep doubling the number of buffers written until enough CPU time has
      elapsed. do/while guarantees at least one round, so the reported
      iteration count always matches what was really written. clock_t is
      kept at full width rather than truncated to 32 bits. */
   u32bit iterations = 1, j = 0;
   const std::clock_t start = std::clock();
   std::clock_t clocks_used = 0;
   do
      {
      iterations *= 2;
      for(; j != iterations; j++)
         pipe.write(buf, BUFFERSIZE);
      clocks_used = std::clock() - start;
      }
   while(clocks_used < target_clocks);

   /* A very short run (e.g. seconds <= 0) may finish within one clock
      tick; avoid dividing by zero. */
   if(clocks_used <= 0)
      clocks_used = 1;

   const double bytes_per_sec =
      (static_cast<double>(iterations) * BUFFERSIZE) /
      (static_cast<double>(clocks_used) / CLOCKS_PER_SEC);
   const u32bit kbytes_per_sec = static_cast<u32bit>(bytes_per_sec / 1024.0);

   /* Padding is computed from the raw name: it is the visible width. */
   const std::string pad = bench_padding(name);

   if(html)
      {
      std::cout << "<TR><TH>" << bench_html_escape(name) << pad;
      std::cout << "<TH>" << std::setw(6) << kbytes_per_sec << std::endl;
      }
   else
      {
      std::cout << name << ": " << pad;
      std::cout << std::setw(6) << kbytes_per_sec << " kbytes/sec"
                << std::endl;
      }
   }

void benchmark(const std::string& what, bool html, double seconds)
   {
   try {
      if(html)
         {
         std::cout << "<!DOCTYPE HTML PUBLIC \"-//W3C//DTD "
                   << "HTML 4.0 Transitional//EN\">\n"
                   << "<HTML>\n\n"
                   << "<TITLE>OpenCL Benchmarks</TITLE>\n\n"
                   << "<BODY>\n\n"
                   << "<P><TABLE BORDER CELLSPACING=1>\n"
                   << "<THEAD>\n"
                   << "<TR><TH>Algorithm                <TH>Kbytes / second\n"
                   << "<TBODY>\n";
         }

      for(u32bit j = 0; algorithms[j].type != END; j++)
         if(what == "All" || what == algorithms[j].type)
            bench(algorithms[j].name, algorithms[j].filtername,
                  html, seconds,
                  algorithms[j].keylen, algorithms[j].ivlen);

      if(html)
         std::cout << "</TABLE>\n\n</BODY></HTML>\n";
      }
   catch(OpenCL::Exception& e)
      {
      std::cout << "OpenCL exception caught: " << e.what() << std::endl;
      std::exit(1);
      }
   catch(std::exception& e)
      {
      std::cout << "Standard library exception caught: " << e.what()
                << std::endl;
      std::exit(1);
      }
   catch(...)
      {
      std::cout << "Unknown exception caught." << std::endl;
      std::exit(1);
      }
   }

