/* seskey.c - Session key routines
 *        Copyright (C) 2002 Timo Schulz
 *        Copyright (C) 1998-2002 Free Software Foundation, Inc.
 *
 * This file is part of OpenCDK.
 *
 * OpenCDK is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * OpenCDK is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with OpenCDK; if not, write to the Free Software Foundation,
 * Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
 */

#ifdef HAVE_CONFIG_H
# include <config.h>
#endif
#include <assert.h>
#include <stdio.h>

#include "opencdk.h"
#include "main.h"
#include "packet.h"

/* Query the ASN.1 object identifier prefix of message digest ALGO via
   gcry_md_algo_info (GCRYCTL_GET_ASNOID).  _cdk_pkcs1_digest calls it
   first with a NULL ASNBUF to learn the length in *ASNLEN and then again
   with an allocated buffer to obtain the prefix itself. */
#define gcry_md_get_asnoid(algo, asnbuf, asnlen) \
gcry_md_algo_info( (algo), GCRYCTL_GET_ASNOID, (asnbuf), (asnlen) )

/* Encode the message digest MD into a PKCS#1 v1.5 block of type 1:
 *
 * 0  1 PAD(n bytes)   0  ASN(asnlen bytes)  MD(len bytes)
 *
 * PAD consists of FF bytes.  The whole frame is (NBITS + 7) / 8 bytes.
 *
 * r_frame      receives the newly allocated frame (release with cdk_free).
 *              If NULL, the frame is released again and only the length
 *              is reported.
 * r_flen       receives the frame length; may be NULL.
 * md, len      the digest and its length in bytes.
 * algo         the digest algorithm; not needed by the encoding itself.
 * nbits        bit size that determines the frame length.
 * asn, asnlen  the ASN.1 prefix identifying the digest algorithm.
 *
 * Returns 0 on success, CDK_Inv_Value if MD or ASN is missing,
 * CDK_General_Error if the digest does not fit into the frame and
 * CDK_Out_Of_Core if the frame cannot be allocated.
 */
static int
do_encode_md (byte ** r_frame, size_t * r_flen, const byte * md, int algo,
              size_t len, unsigned nbits, const byte * asn, size_t asnlen)
{
  size_t nframe = (nbits + 7) / 8;
  size_t padlen, n = 0;
  byte *frame;

  (void) algo;

  if (!asn || !md)
    return CDK_Inv_Value;

  /* Besides ASN and MD the frame needs the two leading bytes, at least one
     pad byte and the zero separator.  Written as subtractions so that large
     LEN/ASNLEN values cannot wrap the sum around. */
  if (len > nframe || asnlen > nframe - len || nframe - len - asnlen < 4)
    return CDK_General_Error;

  frame = cdk_calloc (1, nframe);
  if (!frame)
    return CDK_Out_Of_Core;

  padlen = nframe - len - asnlen - 3; /* >= 1 because of the check above */
  frame[n++] = 0;
  frame[n++] = 1;
  memset (frame + n, 0xff, padlen);
  n += padlen;
  frame[n++] = 0;
  memcpy (frame + n, asn, asnlen);
  n += asnlen;
  memcpy (frame + n, md, len);
  n += len;
  /* By construction of padlen, n == nframe here. */

  if (r_flen)
    *r_flen = n;
  if (r_frame)
    *r_frame = frame;
  else
    cdk_free (frame); /* nobody takes ownership; do not leak it */
  return 0;
}


/* RFC2437 format:
 *  
 *  0  2  RND(n bytes)  0  [A  DEK(k bytes)  CSUM(2 bytes)]
 *  
 *  RND - randomized bytes for padding.
 *  A - cipher algorithm.
 *  DEK - random session key.
 *  CKSUM - algebraic checksum of the DEK.
 */
int
_cdk_pkcs1_sesskey (GCRY_MPI * esk, CDK_DEK dek, int nbits)
{
  int rc = 0;
  int i = 0;
  int nframe;
  byte *p, *frame;
  u16 chksum = 0;
  size_t n = 0;
  GCRY_MPI a = NULL;

  if (!esk || !dek)
    return CDK_Inv_Value;

  for (i = 0; i < dek->keylen; i++)
    chksum += dek->key[i];
  nframe = (nbits + 7) / 8;
  frame = cdk_scalloc (nframe + 1);
  if (!frame)
    return CDK_Out_Of_Core;
  n = 0;
  frame[n++] = 0x00;
  frame[n++] = 0x02;
  i = nframe - 6 - dek->keylen;
  p = gcry_random_bytes (i, GCRY_STRONG_RANDOM);
  /* replace zero bytes by new values */
  for (;;)
    {
      int j, k;
      byte *pp;

      /* count the zero bytes */
      for (j = k = 0; j < i; j++)
	{
	  if (!p[j])
	    k++;
	}
      if (!k)
	break; /* okay: no zero bytes */
      k += k / 128; /* better get some more */
      pp = gcry_random_bytes (k, GCRY_STRONG_RANDOM);
      for (j = 0; j < i && k; j++)
	{
	  if (!p[j])
	    p[j] = pp[--k];
	}
      cdk_free (pp);
    }
  memcpy (frame + n, p, i);
  cdk_free (p);
  n += i;
  frame[n++] = 0;
  frame[n++] = dek->algo;
  memcpy (frame + n, dek->key, dek->keylen);
  n += dek->keylen;
  frame[n++] = chksum >> 8;
  frame[n++] = chksum;
  rc = gcry_mpi_scan (&a, GCRYMPI_FMT_USG, frame, &nframe);
  if (rc)
    rc = CDK_Gcry_Error;
  cdk_free (frame);
  if (!rc)
    *esk = a;
  return rc;
}


/* Decode the PKCS#1 frame contained in the decrypted MPI ESK and return
 * the session key it carries.  Returns NULL if the frame is malformed, the
 * key length does not match the cipher, the checksum is wrong or no memory
 * is available.  The returned DEK comes from cdk_scalloc and is owned by
 * the caller.
 *
 * Old versions encode the DEK in this format (msb is left):
 *
 *     0  1  DEK(16 bytes)  CSUM(2 bytes)  0  RND(n bytes) 2
 *
 * Later versions encode the DEK like this:
 *
 *     0  2  RND(n bytes)  0  A  DEK(k bytes)  CSUM(2 bytes)
 *
 * Only the later format is accepted here.  The leading zero is not part of
 * the printed MPI, so the frame starts with the 2.
 *
 * RND  are non-zero random bytes.
 * A    is the cipher algorithm.
 * DEK  is the encryption key (session key) with length k.
 * CSUM is the 16-bit sum of the DEK bytes.
 */
static CDK_DEK
pkcs1_decode (GCRY_MPI esk)
{
  CDK_DEK dek = NULL;
  byte frame[4096];
  volatile byte *wipe;
  size_t nframe = sizeof frame - 1;
  size_t n, keylen, i;
  u16 csum, csum2 = 0;
  int algo;

  if (gcry_mpi_print (GCRYMPI_FMT_USG, frame, &nframe, esk))
    goto leave;

  if (nframe < 1 || frame[0] != 2)
    goto leave;

  /* Skip the random padding up to the zero separator (n <= nframe). */
  for (n = 1; n < nframe && frame[n]; n++)
    ;
  /* The separator, A and the two CSUM bytes must all be inside the frame;
     this also rejects frames without any separator (n == nframe). */
  if (nframe - n < 4)
    goto leave;

  algo = frame[n + 1];
  keylen = nframe - n - 4;
  if (!keylen || keylen > sizeof dek->key
      || keylen != (size_t) gcry_cipher_get_algo_keylen (algo))
    goto leave;

  csum = (u16) ((frame[nframe - 2] << 8) | frame[nframe - 1]);
  for (i = 0; i < keylen; i++)
    csum2 += frame[n + 2 + i];
  if (csum != csum2)
    goto leave;

  dek = cdk_scalloc (sizeof *dek);
  if (!dek)
    goto leave;
  dek->algo = algo;
  dek->keylen = keylen;
  memcpy (dek->key, frame + n + 2, keylen);

leave:
  /* The frame holds the plaintext session key; clear it before the stack
     space is reused.  Stores through a volatile pointer are not dropped by
     the optimizer. */
  for (wipe = frame, i = 0; i < sizeof frame; i++)
    wipe[i] = 0;
  return dek;
}


/* Prepare the digest MD (algorithm DIGEST_ALGO) for signing with a key of
 * public key algorithm PK_ALGO and NBITS bits.
 *
 * For DSA the digest is copied as is.  For all other algorithms it is
 * wrapped into a PKCS#1 block by do_encode_md.
 *
 * On success *R_MD receives a newly allocated buffer (release with
 * cdk_free) and *R_MDLEN its length.  Returns CDK_Inv_Value for missing
 * arguments, CDK_Inv_Algo for an unknown digest length, CDK_Gcry_Error if
 * the ASN.1 prefix cannot be queried, CDK_Out_Of_Core, or the error of
 * do_encode_md.
 */
int
_cdk_pkcs1_digest (byte ** r_md, size_t * r_mdlen, int pk_algo,
                   const byte * md, int digest_algo, unsigned nbits)
{
  byte *asn;
  size_t asnlen = 0, dlen;
  int rc;

  if (!md || !r_md || !r_mdlen)
    return CDK_Inv_Value;

  if (is_DSA (pk_algo))
    {
      dlen = gcry_md_get_algo_dlen (digest_algo);
      if (!dlen)
        return CDK_Inv_Algo;
      *r_md = cdk_malloc (dlen + 1);
      if (!*r_md)
        return CDK_Out_Of_Core;
      memcpy (*r_md, md, dlen);
      *r_mdlen = dlen;
      return 0;
    }

  /* First call only reports the length of the ASN.1 prefix. */
  rc = gcry_md_get_asnoid (digest_algo, NULL, &asnlen);
  if (rc || !asnlen)
    return CDK_Gcry_Error;
  dlen = gcry_md_get_algo_dlen (digest_algo);
  if (!dlen)
    return CDK_Inv_Algo;

  asn = cdk_malloc (asnlen + 1);
  if (!asn)
    return CDK_Out_Of_Core;
  rc = gcry_md_get_asnoid (digest_algo, asn, &asnlen);
  if (rc)
    rc = CDK_Gcry_Error;
  else
    rc = do_encode_md (r_md, r_mdlen, md, digest_algo, dlen,
                       nbits, asn, asnlen);
  cdk_free (asn); /* released on every path, including the error above */
  return rc;
}


/* Build the prompt used to ask for the passphrase of SK, for example
 * "1024-bit DSA key, ID 12345678\nEnter Passphrase: ".  Returns a buffer
 * allocated with cdk_calloc (release with cdk_free), or NULL if no memory
 * is available.
 */
static char *
passphrase_prompt (cdkPKT_secret_key * sk)
{
  enum { PROMPT_SIZE = 512 };
  int bits = cdk_pk_get_nbits (sk->pk), pk_algo = sk->pubkey_algo;
  u32 keyid = cdk_pk_get_keyid (sk->pk, NULL);
  const char *algo = "???";
  char *p;

  p = cdk_calloc (1, PROMPT_SIZE);
  if (!p)
    return NULL;

  if (is_RSA (pk_algo))
    algo = "RSA";
  else if (is_ELG (pk_algo))
    algo = "ELG";
  else if (is_DSA (pk_algo))
    algo = "DSA";

  /* %lX needs an unsigned long; keyid is a u32, so widen it explicitly. */
  snprintf (p, PROMPT_SIZE, "%d-bit %s key, ID %08lX\nEnter Passphrase: ",
            bits, algo, (unsigned long) keyid);
  return p;
}


/* Ask for the passphrase of SK and use it to unprotect the secret key.
 * Returns 0 when SK is not protected, otherwise the result of
 * cdk_seckey_unprotect.
 * NOTE(review): if _cdk_passphrase_get yields no passphrase this also
 * returns 0 although the key is still protected -- confirm callers
 * expect that.
 */
int
_cdk_seckey_unprotect2 (cdkPKT_secret_key * sk)
{
  char *prompt, *passwd;
  int err = 0;

  if (!sk->is_protected)
    return 0;

  prompt = passphrase_prompt (sk);
  passwd = _cdk_passphrase_get (prompt);
  if (passwd != NULL)
    err = cdk_seckey_unprotect (sk, passwd);
  _cdk_passphrase_free (passwd);
  cdk_free (prompt);
  return err;
}


/* Decrypt the public-key encrypted session key ENC with the secret key SK.
 * A protected SK is unprotected first, which asks for its passphrase.
 * Returns the decoded session key, or NULL on any failure.
 */
CDK_DEK
cdk_pkcs1_to_dek (cdkPKT_pubkey_enc * enc, cdkPKT_secret_key * sk)
{
  GCRY_MPI plain = NULL;
  CDK_DEK result = NULL;
  int err;

  if (enc == NULL || sk == NULL)
    return NULL;

  err = sk->is_protected ? _cdk_seckey_unprotect2 (sk) : 0;
  if (err == 0)
    {
      err = cdk_pk_decrypt (sk, enc, &plain);
      if (err == 0)
        result = pkcs1_decode (plain);
    }
  gcry_mpi_release (plain);
  return result;
}


/* Create a random session key for cipher ALGO (GCRY_CIPHER_CAST5 when ALGO
 * is 0) and store it in *R_DEK if R_DEK is not NULL.  A fresh key is drawn
 * for up to eight attempts while gcry_cipher_setkey rejects it.
 *
 * Returns 0 on success, CDK_Out_Of_Core, CDK_Gcry_Error if the cipher
 * cannot be opened, or CDK_Weak_Key if no key was accepted.  The DEK is
 * allocated with cdk_scalloc; release it with cdk_free.
 */
int
cdk_dek_new (CDK_DEK * r_dek, int algo)
{
  GCRY_CIPHER_HD hd;
  CDK_DEK dek;
  int i;

  if (r_dek)
    *r_dek = NULL;

  dek = cdk_scalloc (sizeof *dek);
  if (!dek)
    return CDK_Out_Of_Core;

  if (!algo)
    algo = GCRY_CIPHER_CAST5;
  dek->algo = algo;
  dek->keylen = gcry_cipher_get_algo_keylen (dek->algo);
  hd = gcry_cipher_open (dek->algo, GCRY_CIPHER_MODE_CFB, 1);
  if (!hd)
    {
      cdk_free (dek);
      return CDK_Gcry_Error;
    }

  for (i = 0; i < 8; i++)
    {
      gcry_randomize (dek->key, dek->keylen, GCRY_STRONG_RANDOM);
      if (!gcry_cipher_setkey (hd, dek->key, dek->keylen))
        {
          gcry_cipher_close (hd);
          if (r_dek)
            *r_dek = dek;
          else
            cdk_free (dek); /* no receiver: do not leak the key */
          return 0;
        }
    }
  gcry_cipher_close (hd);
  cdk_free (dek);
  return CDK_Weak_Key;
}


/* Derive a session key for CIPHER_ALGO from the passphrase PW with the
 * string-to-key parameters S2K.  MODE == 2 is passed to cdk_hash_passphrase
 * as its CREATE flag.  With MODE == 2 an empty passphrase yields a DEK with
 * keylen 0 instead of a hashed key.
 *
 * Returns a DEK allocated with cdk_scalloc (release with cdk_free), or NULL
 * if PW is NULL, no memory is available or the key derivation fails.
 */
CDK_DEK
cdk_passphrase_to_dek (int cipher_algo, CDK_S2K_HD s2k, int mode, char * pw)
{
  CDK_DEK dek;

  if (!pw)
    return NULL;

  dek = cdk_scalloc (sizeof *dek);
  if (!dek)
    return NULL;

  dek->algo = cipher_algo;
  if (!*pw && mode == 2)
    dek->keylen = 0;
  else if (cdk_hash_passphrase (dek, pw, s2k, mode == 2))
    {
      /* Never hand out a DEK whose key was not actually derived. */
      cdk_free (dek);
      return NULL;
    }
  return dek;
}


/* Derive the key for DEK->algo from the passphrase PW according to the
 * string-to-key specifier S2K.  The result is stored in DEK->key and
 * DEK->keylen.
 *
 * s2k->mode selects the hash input:
 *   1 or 3  the 8-byte salt followed by the passphrase.  For mode 3 this
 *           sequence is repeated until the number of octets encoded in
 *           s2k->count has been hashed.
 *   other   the passphrase alone.
 * A hash_algo of 0 in S2K is replaced by GCRY_MD_SHA1 (written back).
 *
 * If CREATE is non-zero and the mode is 1 or 3, a new salt is generated.
 * For mode 3 s2k->count is also set to 96.  Both are written back to S2K.
 *
 * When the key is longer than one digest, further passes are hashed with
 * the context preloaded with 1, 2, ... zero bytes, and the digests are
 * concatenated until keylen bytes are filled.
 *
 * Returns 0 on success, CDK_Inv_Value for a NULL argument or
 * CDK_Gcry_Error if the hash context cannot be opened.
 */
int
cdk_hash_passphrase (CDK_DEK dek, char *pw, CDK_S2K_HD s2k, int create)
{
  GCRY_MD_HD md;
  int pass, i;
  int used = 0;
  int pwlen = 0;

  if (!dek || !pw || !s2k)
    return CDK_Inv_Value;

  if (!s2k->hash_algo)
      s2k->hash_algo = GCRY_MD_SHA1;
  pwlen = strlen (pw);

  dek->keylen = gcry_cipher_get_algo_keylen (dek->algo);
  md = gcry_md_open (s2k->hash_algo, GCRY_MD_FLAG_SECURE);
  if (!md)
    return CDK_Gcry_Error;

  /* One digest per pass until the whole key is filled. */
  for (pass = 0; used < dek->keylen; pass++)
    {
      if (pass)
	{
	  gcry_md_reset (md);
	  for (i = 0; i < pass; i++) /* preset the hash context */
	    gcry_md_putc (md, 0);
	}
      if (s2k->mode == 1 || s2k->mode == 3)
	{
	  int len2 = pwlen + 8;
	  u32 count = len2;
	  if (create && !pass)
	    {
	      /* level 1 is GCRY_STRONG_RANDOM in libgcrypt's enum */
	      gcry_randomize (s2k->salt, 8, 1);
	      if (s2k->mode == 3)
		s2k->count = 96; /* decodes to 65536 octets to hash */
	    }
	  if (s2k->mode == 3)
	    {
	      /* Decode the one-byte count: (16 + low nibble) shifted left by
	         (high nibble + 6) octets, but never less than salt+pw once. */
	      count = (16ul + (s2k->count & 15)) << ((s2k->count >> 4) + 6);
	      if (count < len2)
		count = len2;
	    }
	  /* a little bit complicated because we need a ulong for count */
	  while (count > len2) /* maybe iterated+salted */
	    {
	      gcry_md_write (md, s2k->salt, 8);
	      gcry_md_write (md, pw, pwlen);
	      count -= len2;
	    }
	  /* Last round: hash only the remaining COUNT octets of salt||pw. */
	  if (count < 8)
	    gcry_md_write (md, s2k->salt, count);
	  else
	    {
	      gcry_md_write (md, s2k->salt, 8);
	      count -= 8;
	      gcry_md_write (md, pw, count);
	    }
	}
      else
	gcry_md_write (md, pw, pwlen);
      gcry_md_final (md);
      /* Append as much of this digest as the key still needs. */
      i = gcry_md_get_algo_dlen (s2k->hash_algo);
      if (i > dek->keylen - used)
	i = dek->keylen - used;
      memcpy (dek->key + used, gcry_md_read (md, s2k->hash_algo), i);
      used += i;
    }
  gcry_md_close (md);
  return 0;
}


/* Allocate a string-to-key handle with MODE (0..3) and hash algorithm
 * ALGO.  If SALT is not NULL its first 8 bytes become the salt; otherwise
 * the salt stays all zero.  The handle (from cdk_calloc) is stored in
 * *RET_S2K.  If RET_S2K is NULL the handle is released again.
 *
 * Returns 0 on success, CDK_Inv_Value for a bad MODE, the error of
 * _cdk_md_test_algo for an unusable ALGO, or CDK_Out_Of_Core.
 */
int
cdk_s2k_new (CDK_S2K_HD * ret_s2k, int mode, int algo, const byte * salt)
{
  CDK_S2K_HD s2k;
  int rc;

  if (mode < 0 || mode > 3)
    return CDK_Inv_Value;

  rc = _cdk_md_test_algo (algo);
  if (rc)
    return rc;

  s2k = cdk_calloc (1, sizeof *s2k);
  if (!s2k)
    return CDK_Out_Of_Core;
  s2k->mode = mode;
  s2k->hash_algo = algo;
  if (salt)
    memcpy (s2k->salt, salt, 8);

  if (ret_s2k)
    *ret_s2k = s2k;
  else
    cdk_free (s2k); /* no receiver: do not leak the handle */
  return 0;
}
