// Copyright (c) 2003 Robin J Carey. All rights reserved.
//
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// 1. Redistributions of source code must retain the above copyright
//    notice, this list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright
//    notice, this list of conditions and the following disclaimer in the
//    documentation and/or other materials provided with the distribution.
// 3. The name of the author, Robin J Carey, may not be used to endorse or
//    promote products derived from this software without specific prior
//    written permission.
// 4. This software may not be used for terrorism, paedophilia or crimes
//    against humanity.
//
// THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
// OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
// HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
// LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
// OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
// SUCH DAMAGE.
//
//

# include  "ASSERT.h"
# include  "ByteType.h"
# include  "Startup.h"
# include  "RSA.h"
# include  "FileUtil.h"
# include  "Key.h"
# include  "Fatal.h"

# include  <limits.h>		// For: PATH_MAX
# include  <stdio.h>		// For: printf(3), fprintf(3), fflush(3),
				//      fwrite(3), snprintf(3), fclose(3),
				//      feof(3)
# include  <stdlib.h>		// For: exit(3), atexit(3)
# include  <string.h>		// For: memcpy(3), memcmp(3), memset(3)
# include  <unistd.h>		// For: getopt(3), unlink(2)

// OpenSSL
# include  <openssl/bn.h>
# include  <openssl/crypto.h>
# include  <openssl/sha.h>

// Local wrapper around FATAL: every use in main() first closes "fOut" and
// removes "outputFile", then hands its arguments to the underlying FATAL
// (presumably the one declared in "Fatal.h" -- confirm).  The trailing
// FATAL token is not re-expanded because a macro name is never expanded
// recursively inside its own replacement list.
//
// The expansion is three separate statements and names the locals "fOut"
// and "outputFile" directly, so it may only be used inside a braced block
// in a scope where both are visible (never as the sole, unbraced body of
// an "if").
# define FATAL fclose (fOut); unlink (outputFile); FATAL

// Input path used by main() when no "-i" option is given.
static const char * const	INPUT_FILE = CIPHERTEXT_FILE;
// Shared trailer for the rejection diagnostics issued by main(); it is
// passed either as the final "%s" argument or as the whole message.
static const char * const	rejectMsg = "WARNING: Incorrect password or "
"corrupt/fraudulent ciphertext.\n";

// Print the command-line synopsis for this program to stderr, prefixed
// with the invocation name "argv0", and terminate with EXIT_FAILURE.
// Never returns.
static void
Usage (const char * const argv0)
{
  fprintf (stderr,
	   "%s [ -i input-file ] "
	   "-o output-file -f from-public-key\n",
	   argv0);
  exit (EXIT_FAILURE);
}

// Key material and per-record state.  These live at file scope so that
// ClearMemory(), which main() registers with atexit(3), can reach them.
// NOTE: the BN_new() results are not checked for NULL here.
//
// from_n  : loaded by main() from the "-f" key file; modulus used when
//           checking the signature.
// n       : loaded from PUBKEY; modulus for m = c^d mod n.
// d       : loaded from PRIVKEY; exponent for m = c^d mod n.
// message : result of the most recent decryption in main()'s loop.
static BIGNUM * const	from_n	= BN_new ();
static BIGNUM * const	n	= BN_new ();
static BIGNUM * const	d	= BN_new ();
static BIGNUM * const	message	= BN_new ();
// Data length and salt length of the most recently decoded message.
static ByteType		dataLen;
static ByteType		saltLen;

static void
ClearMemory (void)
{
  BN_clear (from_n);
  BN_clear (n);
  BN_clear (d);
  BN_clear (message);
  dataLen = saltLen = 0;
}

// Program entry point: decrypt a ciphertext file and verify its appended
// digital signature.
//
// Command line:  [ -i input-file ] -o output-file -f from-public-key
//   -i  ciphertext to read; defaults to INPUT_FILE.
//   -o  file the recovered data is written to (required).
//   -f  key file loaded into "from_n"; used as the modulus when the
//       signature is checked (required).
//
// Every non-empty record returned by FileUtil::FReadCharNewline() is a
// hexadecimal number that is decrypted as m = c^d mod n.  The decoded
// message is laid out as
//     [ saltLen (1 byte) | salt (saltLen bytes) | data ]
// Data from ordinary records is written to the output file and fed to
// SHA-1; data from records that start with SIGNATURE is collected into
// "sigBin".  After the loop M = sigBin^e mod from_n is computed and the
// SHA_DIGEST_LENGTH bytes that follow M's first byte must equal the
// SHA-1 digest of the written data.
//
// The FATAL paths and the I/O-error paths close and unlink the output
// file, so rejected or partial plaintext is not left behind.
//
// Never returns normally: the function ends in exit(3), with EXIT_SUCCESS
// only when the signature matched and the output file closed cleanly.
int
main (int argc, char * argv [])
{
  Startup	startup;
  startup.Useless ();

  char		inputFile	[ PATH_MAX + 1 ];
  char		outputFile	[ PATH_MAX + 1 ];
  char		fromKeyFile	[ PATH_MAX + 1 ];
  bool		gotOutputFile	= false;
  bool		gotFromKeyFile	= false;
  int		ch;

  // The input path has a default; "-o" and "-f" must be supplied.
  snprintf (inputFile, sizeof (inputFile), "%s", INPUT_FILE);
  while ((ch = getopt (argc, argv, "i:o:f:")) != -1) {
    switch (ch) {
      case 'i':
	snprintf (inputFile, sizeof (inputFile), "%s", optarg);
	break;
      case 'o':
	snprintf (outputFile, sizeof (outputFile), "%s", optarg);
	gotOutputFile = true;
	break;
      case 'f':
	snprintf (fromKeyFile, sizeof (fromKeyFile), "%s", optarg);
	gotFromKeyFile = true;
	break;
      case '?':
      default:
	Usage (argv [0]);	// Exits; no fall-through past the switch.
    }
  }
  if (!gotOutputFile || !gotFromKeyFile) {
    Usage (argv [0]);
  }
  // atexit(3) signals failure with any non-zero value, not only with a
  // negative one, so test against zero.
  if (atexit (&ClearMemory) != 0) {
    perror ("atexit(3)");
    exit (EXIT_FAILURE);
  }

  FILE * const		fIn = FileUtil::FOpenEmptyCheck (inputFile, "rb");

  // Read "from_n" public-key for signature computation:
  //
  Key::ReadKey (fromKeyFile, from_n);

  // Read own public-key "n" for RSA decryption:
  //
  Key::ReadKey (PUBKEY, n);
  //
  // Decrypt private-key "d" for RSA decryption:
  //
  Key::ReadKey (PRIVKEY, d);

  // Open "outputFile":
  //
  FILE * const		fOut = FileUtil::FOpenErrCheck (outputFile, "wb");

  // Main decryption loop; Operations:
  //
  // 1) RSA decrypt "inputFile"; m = c^d mod n
  // 2) RSA decrypt appendix-signature; generating the binary signature.
  //
  // NOTE(review): several BN_* calls below are evaluated inside ASSERT();
  // confirm that ASSERT is never compiled out, otherwise the calls vanish.
  //
  printf ("Decrypting ciphertext:");
  fflush (stdout);
  char			inBuf		[ HEXSTR_BUFSIZE ];
  // Variable-length array (a compiler extension in C++) with one byte per
  // byte of "from_n"; it bounds the total signature data accepted.
  ByteType		sigBin		[ BN_num_bytes (from_n) ];
  size_t		sigBin_i	= 0;
  bool			in_sig		= false;
  ssize_t		Got;
  BN_CTX * const	ctx		= BN_CTX_new ();
  BIGNUM *		ciphertext	= BN_new ();
  SHA_CTX		SHA1context;
  SHA1_Init (&SHA1context);
  do {
    Got = FileUtil::FReadCharNewline (fIn, inBuf, sizeof (inBuf));
    if (Got > 0) {
      // Once the first SIGNATURE record has been seen, every remaining
      // record must be a SIGNATURE record too.
      if (inBuf [ 0 ] == SIGNATURE || in_sig) {
	if (!in_sig) {
	  printf (" done.\n");
	  printf ("Decrypting digital signature:");
	  fflush (stdout);
	}
	in_sig = true;
	if (inBuf [ 0 ] != SIGNATURE) {
	  FATAL ("Corrupt digital signature.\n%s", rejectMsg);
	}
	// Skip the leading SIGNATURE marker before parsing the hex digits.
	if (BN_hex2bn (&ciphertext, &inBuf [ 1 ]) == 0) {
	  FATAL ("Corrupt digital signature hexadecimal encoding.\n%s",
								rejectMsg);
	}
      } else {
	if (BN_hex2bn (&ciphertext, inBuf) == 0) {
	  FATAL ("Corrupt ciphertext hexadecimal encoding.\n%s", rejectMsg);
	}
      }
      ASSERT (BN_mod_exp (message, ciphertext, d, n, ctx) == 1);
      if (BN_num_bytes (message) < (int) min_msg_size) {
	FATAL ("%s message-size below minimum threshold.\n%s",
		(in_sig) ? "Digital signature" : "Ciphertext", rejectMsg);
      }
      if (BN_num_bytes (message) > (int) max_msg_size) {
	FATAL ("%s message-size exceeded maximum limit.\n%s",
		(in_sig) ? "Digital signature" : "Ciphertext", rejectMsg);
      }
      ByteType		mesgBin [ BN_num_bytes (message) ];
      ASSERT (BN_bn2bin (message, mesgBin) == BN_num_bytes (message));
      // Byte 0 carries the salt length; the salt follows, then the data.
      saltLen = mesgBin[0];
      if (saltLen == 0) {
	FATAL ("Decoded %s salt-size equals zero.\n%s",
		(in_sig) ? "digital signature" : "ciphertext", rejectMsg);
      }
      if (saltLen > (sizeof (mesgBin) - 1)) {
	FATAL ("Decoded %s salt-size exceeded message-size.\n%s",
		(in_sig) ? "digital signature" : "ciphertext", rejectMsg);
      }
      if ((sizeof (mesgBin) - 1 - saltLen) > max_msg_data_size) {
	FATAL ("%s data-size exceeded maximum limit.\n%s",
		(in_sig) ? "Digital signature" : "Ciphertext", rejectMsg);
      }
      // NOTE: saltLen <= (sizeof (mesgBin) - 1)
      dataLen = sizeof (mesgBin) - 1 - saltLen;
      if (dataLen > 0) {
	const ByteType * const	data = &mesgBin [ 1 + saltLen ];
	if (in_sig) {
	  if ((dataLen + sigBin_i) > sizeof (sigBin)) {
	    FATAL ("Digital signature exceeded signature buffer size.\n%s",
								rejectMsg);
	  }
	  memcpy (&sigBin [ sigBin_i ], data, dataLen);
	  sigBin_i += dataLen;
	} else {
	  if (fwrite (data, sizeof (ByteType), dataLen, fOut) != dataLen) {
	    // Report first (while errno is still that of fwrite), then
	    // discard the partial plaintext just as the FATAL paths do.
	    FileUtil::PError ("fwrite(3)", outputFile);
	    fclose (fOut);
	    unlink (outputFile);
	    exit (EXIT_FAILURE);
	  }
	  SHA1_Update (&SHA1context, data, dataLen);
	}
      }
      // NOTE(review): a plain memset of a buffer that is about to go out
      // of scope may be elided by the optimiser; consider OPENSSL_cleanse.
      memset (mesgBin, 0, sizeof (mesgBin));
    } else if (Got < 0) {
      // Read failure: do not leave unverified partial plaintext behind.
      FileUtil::PError ("fgetc(3)", inputFile);
      fclose (fOut);
      unlink (outputFile);
      exit (EXIT_FAILURE);
    }
  } while (!feof (fIn));
  printf (" done.\n");
  fflush (stdout);

  // Verify digital signature: M = s^e mod from_n
  //
  printf ("Verifying digital signature:");
  fflush (stdout);
  unsigned char		SHA1digest [ SHA_DIGEST_LENGTH ];
  SHA1_Final (SHA1digest, &SHA1context);
  BIGNUM * const	e	= BN_new ();
  ASSERT (BN_set_word (e, encExp) == 1);
  BIGNUM * const	s	= BN_new ();
  ASSERT (BN_bin2bn (sigBin, sigBin_i, s) != NULL);
  memset (sigBin, 0, sizeof (sigBin));
  BIGNUM * const	M	= BN_new ();
  ASSERT (BN_mod_exp (M, s, e, from_n, ctx) == 1);
  BN_clear (s);
  // M must hold at least one leading byte plus a full SHA-1 digest.
  if (BN_num_bytes (M) < (int) (SHA_DIGEST_LENGTH + 1)) {
    // Pass the message as an argument, never as the format string.
    FATAL ("%s", rejectMsg);
  }
  ByteType		Mbin [ BN_num_bytes (M) ];
  ASSERT (BN_bn2bin (M, Mbin) == BN_num_bytes (M));
  BN_clear (M);
  // The digest is compared against the bytes that follow Mbin[0].
  if (memcmp (&Mbin [ 1 ], SHA1digest, sizeof (SHA1digest)) != 0) {
    FATAL ("%s", rejectMsg);
  } else {
    printf (" Signature verified.\n");
  }
  memset (SHA1digest, 0, sizeof (SHA1digest));
  memset (Mbin, 0, sizeof (Mbin));

  // Close the output explicitly so that a deferred write error (e.g. a
  // full disk) is reported instead of being lost inside exit(3).
  // NOTE(review): assumes FATAL never returns, as the checks above already
  // do; otherwise "fOut" would be closed twice on the rejection path.
  if (fclose (fOut) != 0) {
    FileUtil::PError ("fclose(3)", outputFile);
    unlink (outputFile);
    exit (EXIT_FAILURE);
  }

  // Let O/S take care of the remaining allocated resources: Just exit ...

  exit (EXIT_SUCCESS);
}
