#include "libopgp.h"
#include <rsa.h>
#include <dh.h>
#include <dsa.h>

/*--------------------------------------------------*/
/*
 * Verify a DSA signature (r, s).
 *
 * r, u1:      the signature values r and s; they are read, never modified,
 *             so the caller may reuse them to try another key.
 * dsakey:     signer's public key.  Ownership passes to this function:
 *             it is freed on every return path.
 * hash, hlen: the message digest.  Only the leftmost |q| bits are used
 *             (RFC 4880 5.2.2 / FIPS 186-3).
 *
 * Returns 0 if the signature verifies, nonzero otherwise.
 */
static u32_t checkdsa(BIGNUM * r, BIGNUM * u1, DSA * dsakey,
                      u8_t * hash, int hlen)
{
  BN_CTX *ctx = BN_CTX_new();
  BIGNUM *t1 = BN_new(), *t2 = BN_new(), *t3 = BN_new(), *u2 = NULL;
  int ret = 1;                  /* assume failure until proven otherwise */
  int qbits;

  /* FIPS 186: a signature with r or s outside (0, q) is invalid. */
  if (BN_is_zero(r) || BN_ucmp(r, dsakey->q) >= 0
      || BN_is_zero(u1) || BN_ucmp(u1, dsakey->q) >= 0)
    goto done;

  /* w = s^-1 mod q; NULL when s has no inverse. */
  u2 = BN_mod_inverse(u1, dsakey->q, ctx);
  if (!u2)
    goto done;

  /* Use only the leftmost |q| bits of the digest. */
  qbits = BN_num_bits(dsakey->q);
  if (hlen > (qbits + 7) / 8)
    hlen = (qbits + 7) / 8;
  BN_bin2bn(hash, hlen, t3);
  if (hlen * 8 > qbits)
    BN_rshift(t3, t3, hlen * 8 - qbits);

  BN_mod_mul(t3, t3, u2, dsakey->q, ctx);       /* u1' = H * w mod q */
  BN_mod_mul(u2, r, u2, dsakey->q, ctx);        /* u2' = r * w mod q */
  BN_mod_exp(t1, dsakey->g, t3, dsakey->p, ctx);
  BN_mod_exp(t2, dsakey->pub_key, u2, dsakey->p, ctx);
  BN_mod_mul(t3, t1, t2, dsakey->p, ctx);
  BN_mod(t3, t3, dsakey->q, ctx);               /* v = (g^u1' y^u2' mod p) mod q */

  /* Compare before t3 is released below. */
  ret = BN_ucmp(t3, r);

done:
  DSA_free(dsakey);
  BN_CTX_free(ctx);
  BN_free(u2);                  /* BN_free(NULL) is a no-op */
  BN_free(t1), BN_free(t2), BN_free(t3);
  return ret;
}

/* Size in bytes of the stack scratch buffers used by checkelg() (padded
   hash block) and checkrsa() (big-endian signature value). */
#define MAXSIGM 512
/*--------------------------------------------------*/
/*
 * Verify an ElGamal signature (a, b): accept iff
 *   g^M == y^a * a^b  (mod p)
 * where M is the hash wrapped in a PKCS#1 v1.5 type-1 block
 * (00 01 FF..FF 00 hash) as wide as p.
 *
 * a, b:       signature values; read only, never modified, so the caller
 *             may reuse them when trying another key.
 * dhkey:      signer's public key (p, g, pub_key).  Ownership passes to
 *             this function: it is freed on every return path.
 * hash, hlen: the bytes placed at the tail of the padded block.
 *
 * Returns 0 if the signature verifies, nonzero otherwise.  A weak
 * signature (gcd(p-1, a) > 2) is rejected rather than aborting the
 * process.
 */
static u32_t checkelg(BIGNUM * a, BIGNUM * b, DH * dhkey,
                      u8_t * hash, int hlen)
{
  BN_CTX *ctx = BN_CTX_new();
  BIGNUM *t = BN_new(), *u = BN_new(), *v = BN_new();
  u8_t hbuf[MAXSIGM];
  u32_t k = 1;                  /* assume failure until proven otherwise */
  int plen;

  /* Require 0 < a < p; without it ElGamal signatures can be forged. */
  if (BN_is_zero(a) || BN_ucmp(a, dhkey->p) >= 0)
    goto done;

  /* Reject weak signatures: gcd(p-1, a) > 2. */
  BN_sub(t, dhkey->p, BN_value_one());
  BN_gcd(u, a, t, ctx);
  BN_set_word(t, 2);
  if (BN_ucmp(u, t) > 0)
    goto done;

  /* The padded block must fit the scratch buffer and leave room for at
     least 00 01 ... 00 ahead of the hash. */
  plen = BN_num_bytes(dhkey->p);
  if (hlen < 0 || plen > MAXSIGM || hlen + 3 > plen)
    goto done;
  memset(hbuf, 0xff, plen);
  hbuf[0] = 0, hbuf[1] = 1;
  hbuf[plen - hlen - 1] = 0;
  memcpy(&hbuf[plen - hlen], hash, hlen);
  BN_bin2bn(hbuf, plen, u);

  BN_mod_exp(t, dhkey->g, u, dhkey->p, ctx);        /* g^M     */
  BN_mod_exp(u, a, b, dhkey->p, ctx);               /* a^b     */
  BN_mod_exp(v, dhkey->pub_key, a, dhkey->p, ctx);  /* y^a     */
  BN_mod_mul(u, v, u, dhkey->p, ctx);               /* y^a a^b */
  k = BN_ucmp(t, u);

done:
  BN_CTX_free(ctx), BN_free(u), BN_free(t), BN_free(v), DH_free(dhkey);
  return k;
}

/*--------------------------------------------------*/
/*
 * Verify an RSA signature with PKCS#1 v1.5 padding.
 *
 * a:          signature value (read only).
 * key:        signer's public key.  Ownership passes to this function:
 *             it is freed on every return path.
 * hash, hlen: expected contents of the unpadded block (in PGP_sigck this
 *             is what PGP_hder() left in its buffer).
 *
 * Returns 0 if the decrypted block equals hash[0..hlen), 1 otherwise.
 * Signatures wider than the modulus, or moduli wider than the scratch
 * buffer, are rejected instead of overflowing / being truncated.
 */
static u32_t checkrsa(BIGNUM * a, RSA * key, u8_t * hash, int hlen)
{
  u8_t dbuf[MAXSIGM];
  int j, ll, bad = 1;

  j = BN_num_bytes(a);
  ll = BN_num_bytes(key->n);
  if (ll <= MAXSIGM && j <= ll) {
    /* Left-pad the big-endian signature with zeros to the modulus width. */
    memset(dbuf, 0, ll - j);
    BN_bn2bin(a, &dbuf[ll - j]);
    j = RSA_public_decrypt(ll, dbuf, dbuf, key, RSA_PKCS1_PADDING);
    /* j < 0 on error; never hand a negative length to memcmp. */
    bad = (j < 0 || j != hlen || memcmp(dbuf, hash, j) != 0);
  }
  RSA_free(key);
  return bad;
}

/* Set by PGP_sigck(): points at the 8-byte issuer key ID inside the
   signature buffer it was last given (so it is only valid while that buffer
   is), or NULL if no key ID was found. */
u8_t *PGP_sigk;
/*--------------------------------------------------*/
/*
 * Check a PGP signature packet body against the hash state in hctx.
 *
 * sgp:  signature packet body, starting at the version octet (3 or 4).
 * hctx: hash context already fed with the signed data.  This function
 *       hashes the signature's own hashed material / v4 trailer into it
 *       and finalizes the digest with PGP_hfin().
 *
 * Sets PGP_sigk to the 8-byte issuer key ID inside sgp (NULL if none).
 *
 * Returns:
 *   0   a key of the signature's algorithm verified it
 *  -1   the two-byte digest quick-check does not match
 *  -2   unsupported version or malformed subpacket area
 *  -3   no issuer key ID, or unsupported public-key algorithm
 *  -4   no (further) key of the signature's algorithm available
 *  other nonzero: result of the algorithm check for the last key tried
 */
int PGP_sigck(u8_t * sgp, void *hctx)
{
  u8_t hbuf[64], *bp, sgalg = 17, halg = 2;
  keyid_t keyid = 0;
  u32_t j;
  int i, k = 0;
  BIGNUM *r, *u1 = NULL;

  PGP_sigk = NULL;
  j = *sgp++;
  if (j == 3) {
    k = *sgp++;                 /* extramd length - fixed on old ver */
    PGP_hblk(hctx, sgp, k);
    sgp += k;
    PGP_sigk = sgp;             /* v3: key ID at a fixed position */
    sgp += 8;
    sgalg = *sgp++;
    halg = *sgp++;              /* hash algorithm */
  } else if (j == 4) {
    bp = sgp;
    sgp++;                      /* type */
    sgalg = *sgp++;
    halg = *sgp++;
    /* i == 0: hashed subpacket area; i == 1: unhashed subpacket area */
    for (i = 0; i < 2; i++) {
      j = *sgp++ * 256, j += *sgp++;
      if (!i) {
        /* Hash version..end of hashed subpackets, then the 6-byte
           trailer 04 FF <32-bit count of bytes just hashed>. */
        j += 6;
        PGP_hblk(hctx, --bp, j);
        bp = hbuf;
        *bp++ = 4, *bp++ = 0xff;
        *bp++ = j >> 24, *bp++ = j >> 16, *bp++ = j >> 8, *bp++ = j;
        PGP_hblk(hctx, hbuf, 6);
        j -= 6;
      }
      while (j > 0) {           /* walk subpackets looking for the keyid */
        u32_t hl, len;          /* length-header octets; body (type+data) */

        k = sgp[0];
        if (k < 192)
          hl = 1, len = k;
        else if (k < 255) {
          if (j < 2)
            return -2;
          hl = 2, len = ((u32_t) (k - 192) << 8) + sgp[1] + 192;
        } else {
          if (j < 5)
            return -2;
          hl = 5, len = ((u32_t) sgp[1] << 24) + ((u32_t) sgp[2] << 16)
            + ((u32_t) sgp[3] << 8) + sgp[4];
        }
        /* j is unsigned: reject an overrun instead of wrapping around. */
        if (len > j - hl)
          return -2;
        /* Take only the first issuer subpacket (type 16, critical bit
           masked), searching the unhashed area too.  The type octet sits
           after the length header, whatever its width. */
        if (len >= 9 && (sgp[hl] & 0x7f) == 16 && !PGP_sigk)
          PGP_sigk = &sgp[hl + 1];
        j -= hl + len;
        sgp += hl + len;
      }
    }
  } else
    return -2;

  PGP_hfin(hbuf, hctx);

  if (!(bp = PGP_sigk))
    return -3;
  for (j = 0; j < 8; j++)
    keyid = (keyid << 8) + *bp++;

  /* Quick check against the first two digest bytes. */
  if (*sgp++ != hbuf[0] || *sgp++ != hbuf[1])
    return -1;
  if (sgalg != 17 && sgalg != 16 && sgalg != 20 && sgalg != 1 && sgalg != 3)
    return -3;
  r = PGP_mpiBN(&sgp);
  if (sgalg != 1 && sgalg != 3)
    u1 = PGP_mpiBN(&sgp);

  /* NOTE(review): halg comes straight from the signature; confirm that
     PGP_hder/PGP_hlen bound-check it and that hbuf[64] is large enough
     for PGP_hder's output. */
  if (sgalg != 17)
    i = PGP_hder(halg, hbuf);
  else
    i = PGP_hlen[halg];

  /* Try each key matching keyid until one verifies or none remain. */
  for (;;) {
    void *sigkey;

    k = -4;
    j = PGP_gtkey(&sigkey, NULL, &keyid);
    if (j != sgalg)
      break;
    if (sgalg == 17)
      k = checkdsa(r, u1, sigkey, hbuf, i);
    else if (sgalg == 1 || sgalg == 3)
      k = checkrsa(r, sigkey, hbuf, i);
    else if (sgalg == 16 || sgalg == 20)
      k = checkelg(r, u1, sigkey, hbuf, i);
    if (j == sgalg && !k)
      break;
    keyid = NEXT_KEY;
  }
  BN_free(r);
  if (sgalg != 1 && sgalg != 3)
    BN_free(u1);
  return k;
}
