/*
 * EAPTLS - EAPServer TLS module. This can be used to implement EAP-TLS by 
 *          verifying the client certificate, or as the foundation for
 *	    EAP-TTLS and EAP-PEAP by exporting the TLS payload (or
 *          InnerApplication records) as RADIUS A/V pairs or EAP packets,
 *          respectively.
 *
 * Author:
 * Emile van Bergen, emile@evbergen.xs4all.nl
 *
 * Permission to redistribute an original or modified version of this program
 * in source, intermediate or object code form is hereby granted exclusively
 * under the terms of the GNU General Public License, version 2. Please see the
 * file COPYING for details, or refer to http://www.gnu.org/copyleft/gpl.html.
 *
 * History:
 */

/* Identification/copyright string compiled into the module. It is not
 * referenced anywhere in this file; presumably it is kept so the notice is
 * visible in the binary (e.g. via strings(1)) - confirm before removing. */
char eaptls_id[] = "EAPTLS - Copyright (C) 2005 Emile van Bergen.";


/*
 * INCLUDES & DEFINES
 */


#include <sys/uio.h>
#include <netinet/in.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <unistd.h>

#include <evblib/sysdefs/byteorder.h>
#include <evblib/buffer/buffer.h>
#include <evblib/strio/strio.h>
#include <evblib/misc/misc.h>

#include <gnutls/gnutls.h>
#include <gnutls/extra.h>

#include <metadata.h>		/* For META_ORD */
#include <constants.h>		/* Server constants */


/* Calculates how far ahead a is from b on a ring of size mod. The result
 * lies in [0, mod) as long as (mod) + (a) - (b) does not go negative,
 * i.e. for a, b in [0, mod]. */

#define MODAHEAD(a, b, mod)	(((mod) + (a) - (b)) % (mod))


/* Some constants */

#define MODULE_VERSION		"v0.1"

/* Number of conversation slots in main's conversation ring */
#define MAX_CONVERSATIONS	128

/* Size of each per-conversation transport/application buffer */
#define TLS_BUFFER_SIZE		32768	    /* Two TLS max. size records */


/* Reply status codes, sent back as the C_DI_INT attribute. Any value
 * >= RET_ERROR is a failure; no TLS data is sent with such a reply. */

#define RET_ERROR	    	0x40

#define RET_NOMEM		0x40	/* same value as the RET_ERROR threshold */
#define RET_NOCONV		0x41	/* can't create/find a conversation */
#define RET_NODATA		0x42	/* no (usable) TLS attributes in request */
#define RET_NOBUF		0x43	/* TLS or payload data doesn't fit buffer */


/* Dictionary items not in common/constants.h */

#define C_DI_FRAMED_MTU		12	/* In C_DS_RAD_ATR */

#define C_DI_TLS_PAYLOAD	51	/* In C_DS_INTERNAL */
#define C_DI_TLS_RESULT		52	
#define C_DI_TLS_ACTION		53

/* Bits in the EAP-TLS Flags field: M (more fragments follow) and
 * L (TLS message length field included) */

#define C_DV_TLS_FLAGS_MORE	0x40
#define C_DV_TLS_FLAGS_LENGTH	0x80


/*
 * TYPES
 */


/* Conversation state, driving tls_state_machine(). TLS_IA and TLS_CLOSE
 * are declared but not handled by the state machine yet. */

typedef enum conv_state {
    TLS_HANDSHAKE, TLS_IA, TLS_APP, TLS_CLOSE
} conv_state_t;


/* One EAP-TLS conversation; a slot in main's conversation ring, indexed
 * by the Transaction-Id exchanged with the core. */

typedef struct conv {
    int inuse;			    /* slot allocated; the GnuTLS transport
				       callbacks refuse to work otherwise */
    conv_state_t state;
    time_t exp;			    /* expiry time; not set or checked yet */

    gnutls_session_t session;
    BUF_T fromtls, totls;	    /* transport buffers (below TLS) */
    BUF_T totls_app;		    /* application buffer (above TLS);
				       we don't buffer application data
				       from TLS, as we send that to OR
				       (OpenRADIUS) straight away */ 
    int result;			    /* last TLS-Result value received from
				       the core; stored but not used yet */
} CONV_T;


/* 
 * GLOBALS
 */


/* NOTE(review): this scratch buffer is not referenced anywhere in this
 * file; possibly a leftover - confirm before removing. */
char buf[TLS_BUFFER_SIZE];	    /* scratch buffer */

/* EAP-type-specific attribute spaces, both derived from the -s option:
 * tx_space is the given number (attributes we send), rx_space is that
 * number plus one (attributes we parse from requests). */
META_ORD rx_space, tx_space;
int debug;			/* verbosity; incremented once per -d */


/*
 * CALLBACKS (BRR)
 */


/* GnuTLS pull callback: give GnuTLS exactly 'len' bytes from the
 * conversation's totls buffer. Requests are served all-or-nothing: when
 * fewer than 'len' bytes are queued nothing is consumed and EAGAIN is
 * reported, so GnuTLS tries again after main() has fed in more data.
 * A released conversation slot yields EBADF. The buffer is a ring that
 * never wraps, because everything is always pulled before that occurs. */

ssize_t tlspull(gnutls_transport_ptr_t td, void *data, size_t len)
{
    CONV_T *c = (CONV_T *)td;
    int err = 0;

    if (!c->inuse)
	err = EBADF;
    else if (buf_maxget(&c->totls) < len)
	err = EAGAIN;

    if (err) {
	errno = err;
	return -1;
    }

    /* One more copy, on top of the ones GnuTLS makes on its own. */
    memcpy(data, c->totls.r, len);
    buf_get(&c->totls, len);
    return len;
}


/* GnuTLS push callback: append 'len' bytes of TLS transport data to the
 * conversation's fromtls buffer, from where main() forwards it to the
 * core. The buffer is a ring that never wraps, because main() empties it
 * as soon as anything appears there.
 *
 * GnuTLS is not expected to write more than two full-sized records per
 * top-level API call, so there should always be room. If there isn't,
 * the call is failed with ENOSYS (not EAGAIN), so the problem surfaces
 * instead of being retried. A released conversation slot yields EBADF. */

ssize_t tlspush(gnutls_transport_ptr_t td, const void *data, size_t len)
{
    CONV_T *c = (CONV_T *)td;

    if (!c->inuse) {
	errno = EBADF;
	return -1;
    }
    if (buf_maxput(&c->fromtls) < len) {
	errno = ENOSYS;
	return -1;
    }

    /* One more copy, on top of the ones GnuTLS makes on its own. */
    memcpy(c->fromtls.w, data, len);
    buf_put(&c->fromtls, len);
    return len;
}


/* GnuTLS log callback: copy the message to stderr, prefixed with its
 * GnuTLS log level in brackets. No newline is appended here. */

void tlslog(int level, const char *s)
{
    fprintf(stderr, "[%d] ", level);
    fputs(s, stderr);
}


/*
 * FUNCTIONS
 */


/* Print the command line synopsis to stderr and terminate with status 1.
 * _exit() is used, so atexit handlers and stdio flushing are skipped. */

void usage()
{
    static const char synopsis[] =
	"Usage: eaptls [-d] [-s attribute space base number]\n"
	"       eaptls -v\n"
	"       eaptls -h\n";

    fputs(synopsis, stderr);
    _exit(1);
}


/* Parse command line options into the globals debug, rx_space and tx_space.
 *
 *   -s N  attribute space base number: replies use space N (tx_space),
 *         requests are parsed from space N + 1 (rx_space). N must be a
 *         positive decimal number without trailing characters.
 *   -d    increase debug verbosity; may be repeated
 *   -v    print version and licence, then the usage text, and exit
 *   -h    print the usage text and exit
 *   -r X  accepted for compatibility, but ignored
 *
 * Invalid options or arguments print the usage text and exit with status 1
 * (through usage()). */

void parseoptions(int argc, char **argv)
{
    unsigned long spacenr;
    char *end;
    int c;

    /* Handle options */
    while((c = getopt(argc, argv, "dhvr:s:")) != -1) {
        switch(c) {
          case 's':
	    /* strtoul with an end pointer rather than atoi, so that input
	     * like "12x" is rejected instead of silently read as 12. The
	     * upper bound keeps N + 1 representable even if META_ORD is a
	     * signed 32-bit type (its exact type is not visible here). */
	    spacenr = strtoul(optarg, &end, 10);
	    if (end == optarg || *end != '\0' || spacenr == 0 ||
		spacenr >= 0x7fffffffUL) {
		fprintf(stderr, "eaptls: Invalid space number '%s'!\n",
			optarg);
		usage();
	    }
	    rx_space = tx_space = spacenr;
	    rx_space++;
	    break;
          case 'd': debug++; break;
          case 'r': break;	    /* accepted, but not used by this module */
          case 'v': fprintf(stderr, "\nEAPTLS module " MODULE_VERSION ". "
               "Copyright (C) 2005 Emile van Bergen / E-Advies.\n\n"
"Permission to redistribute an original or modified version of this program\n"
"in source, intermediate or object code form is hereby granted exclusively\n"
"under the terms of the GNU General Public License, version 2. Please see the\n"
"file COPYING for details, or refer to http://www.gnu.org/copyleft/gpl.html.\n"
"\n");
	    /* fall through: -v also prints the usage text and exits */
          case 'h':
          case '?':
          default:
            usage();
        }
    }
}


/* Append a 32-bit integer attribute from the internal attribute space at
 * *p and advance *p past it. The attribute is five network-order words:
 * space (C_DS_INTERNAL), vendor (META_VND_ANY), attribute number 'atrnr',
 * value length (4) and the value 'val'.
 * Returns the number of bytes appended, which is always 20. */

static inline int addint(uint32_t **p, uint32_t atrnr, uint32_t val)
{
    uint32_t *const w = *p;

    w[0] = net32(C_DS_INTERNAL);
    w[1] = net32(META_VND_ANY);
    w[2] = net32(atrnr);
    w[3] = net32(4);
    w[4] = net32(val);

    *p = w + 5;
    return (int)(5 * sizeof *w);
}


/* Disabled macro form of addint(), kept for reference only. It always
 * writes 'val' as a single 32-bit word, so it is only correct for 4-byte
 * values. Its result expression parses as
 * (16 + (sizeof(val) + 3)) & ~3, which equals the intended
 * 16 + ((sizeof(val) + 3) & ~3) only because 16 is a multiple of 4. */
#if 0
#define ADDINT(o,nr,val)		\
   (*(*(o))++ = net32(C_DS_INTERNAL),	\
    *(*(o))++ = net32(META_VND_ANY),	\
    *(*(o))++ = net32(nr),		\
    *(*(o))++ = net32(sizeof(val)),     \
    *(*(o))++ = net32(val),		\
    16 + (sizeof(val) + 3) & ~3)
#endif


/* Proceed with TLS conversation according to its state; may add
 * attributes to module response.
 *
 * conv        conversation whose GnuTLS session is driven
 * rep         in/out: write position in the reply buffer; advanced past
 *             any attribute that is added
 * repend      end of the space this function may use in the reply buffer
 * app_action  TLS-Action value from the request (~0 if absent); not used yet
 *
 * Returns the number of bytes added to the reply (0 if nothing was added).
 *
 * In TLS_HANDSHAKE state the handshake is continued. Once it completes,
 * received application data is collected in the same call, unless the
 * session negotiated TLS/IA, which is not implemented yet (IA_RECV ends up
 * in the "Unknown action" branch). In TLS_APP state, queued application
 * data (totls_app) is sent first, then received data is collected.
 * Received data is added as one C_DI_TLS_PAYLOAD attribute.
 * GNUTLS_E_AGAIN and GNUTLS_E_INTERRUPTED only mean that more transport
 * data is needed. Other GnuTLS errors are reported on stderr; tearing down
 * the session is not implemented yet. */

typedef enum { RETURN, APP_RECV, IA_RECV, PRERROR, ERROR, DUMPSESS } action_t;

uint32_t tls_state_machine(CONV_T *conv, uint32_t **rep, uint32_t *repend,
			   int app_action)
{
    action_t action;
    int ret, gr;
    uint32_t *o;
    ssize_t l;

    (void)app_action;		/* reserved for TLS-Action handling */

    ret = gr = 0;
    action = RETURN;

    /* State switch */

    if (debug) 
	fprintf(stderr, "eaptls: tls_state_machine, state %d\n", conv->state);

    switch(conv->state) {
      case TLS_HANDSHAKE:
	gr = gnutls_handshake(conv->session);
	if (debug) fprintf(stderr, "eaptls: gnutls_handshake returned %d: %s\n", gr, gr <= 0 ? gnutls_strerror(gr) : "(n/a)");
	if (gr >= 0) {
	    action = gnutls_ia_handshake_p(conv->session) ? IA_RECV : APP_RECV;
	    break;
	}
	/* Each error code must be compared explicitly; a bare
	 * '|| GNUTLS_E_AGAIN' would be a nonzero constant, always true */
	if (gr == GNUTLS_E_INTERRUPTED || gr == GNUTLS_E_AGAIN) break;
	action = PRERROR;
	break;

      case TLS_APP:
	action = APP_RECV;
	l = buf_maxget(&conv->totls_app);
	if (!l) break;
	gr = gnutls_record_send(conv->session, conv->totls_app.r, l); 
	if (debug) fprintf(stderr, "eaptls: gnutls_record_send %d returned %d: %s\n", (int)l, gr, gr <= 0 ? gnutls_strerror(gr) : (gr < l ? "short write" : "all data sent"));
	if (gr > 0) { if (gr > l) gr = l; buf_get(&conv->totls_app, gr); }
	/* Data sent (fully or partly) or retry needed: go on receiving */
	if (gr >= 0 || gr == GNUTLS_E_INTERRUPTED || gr == GNUTLS_E_AGAIN) break;
	action = PRERROR;
	break;

      default:
	fprintf(stderr, "eaptls: Unknown state %d!\n", conv->state);
	action = ERROR;
    }

    /* Action switch */

action:
    if (debug) 
	fprintf(stderr, "eaptls: tls_state_machine, action %d\n", action);

    switch(action) {

      /* Receive TLS data straight into reply packet buffer; keep header
       * only if we received any data at all */

      case APP_RECV:
	conv->state = TLS_APP;

	/* Room for payload after the 16-byte attribute header. Computed in
	 * signed arithmetic, so lack of room shows up as l <= 0 instead of
	 * wrapping around as an unsigned value. */
	l = (ssize_t)(repend - *rep) * (ssize_t)sizeof(**rep) - 16;
	if (l <= 0) {
	    fprintf(stderr, "eaptls: BUG: no room in OR reply buffer\n");
	    action = ERROR; goto action;
	}
	o = *rep;
	*o++ = net32(C_DS_INTERNAL);
	*o++ = net32(META_VND_ANY);
	*o++ = net32(C_DI_TLS_PAYLOAD);
	*o++ = 0;					/* length done later */
	gr = gnutls_record_recv(conv->session, o, l);
	if (debug) {
	    fprintf(stderr, "eaptls: gnutls_record_recv %d returned %d: %s\n", (int)l, gr, gr < 0 ? gnutls_strerror(gr) : "data:");
	    if (gr > 0) hexdumpfd(2, o, gr, 0);	/* never a negative length */
	}
	if (gr > 0) {
	    (*rep)[3] = net32(gr);	/* set length */
	    gr = (gr + 3) >> 2;		/* length in 32-bit words */
	    *rep = o + gr;
	    ret += 16 + (gr << 2);	/* rounded length plus header */
	    break;
	}

	/* 0 means the peer closed the TLS connection: nothing to add (session
	 * teardown is not implemented yet). AGAIN/INTERRUPTED: wait for more
	 * transport data. Anything else is a real error. */
	if (gr == 0 || gr == GNUTLS_E_INTERRUPTED || gr == GNUTLS_E_AGAIN) break;
	action = PRERROR;
	goto action;

      case RETURN: 
	break;

      /* Error handling */

      case PRERROR:
	fprintf(stderr, "eaptls: "); gnutls_perror(gr);
	/* fall through */
      case ERROR:
      case DUMPSESS:
	fprintf(stderr, "eaptls: FIXME: implement dump session action\n");
	break;

      default:
	fprintf(stderr, "eaptls: Unknown action %d!\n", action);
    }
    return ret;
}


/* Read exactly n bytes from fd into p. Short reads and EINTR are retried,
 * because a pipe may deliver one message in several pieces.
 * Returns n on success, fewer (possibly 0) if end of file came first, or
 * -1 with errno set on a read error. */

static ssize_t read_exactly(int fd, void *p, size_t n)
{
    char *dst = p;
    size_t got = 0;
    ssize_t r;

    while (got < n) {
	r = read(fd, dst + got, n - got);
	if (r < 0) {
	    if (errno == EINTR) continue;
	    return -1;
	}
	if (r == 0) break;			/* end of file */
	got += (size_t)r;
    }
    return (ssize_t)got;
}


/* Number of 32-bit words the request loop appends to the reply buffer after
 * tls_state_machine() has run: Transaction-Id (5), reply status (5), TLS
 * flags (5) and the TLS data header including the optional TLS length
 * field (5). The state machine is given a buffer end that leaves this much
 * room, so these attributes can't overflow txbuf. */

#define REPLY_TAIL_WORDS	20


/*
 * MAIN
 */

/* Module entry point. Loads the TLS credentials and DH parameters, then
 * serves requests from the OpenRADIUS core until input ends or an I/O or
 * framing error occurs. Returns 1 in every case.
 *
 * Each request read from stdin starts with the magic 0xbeefdead and its
 * total length in bytes. Attributes follow, each with four header words
 * (space, vendor, attribute number, value length in bytes) and the value
 * padded to a multiple of 4 bytes. The reply written to stdout has the same
 * layout with magic 0xdeadbeef. TLS transport data is appended to the reply
 * straight from the conversation's fromtls buffer using writev().
 *
 * Conversations live in convring, indexed by the Transaction-Id that is
 * returned with every successful reply. Expiry is not implemented yet
 * (convr never advances), so at most MAX_CONVERSATIONS - 1 conversations
 * can ever be created. */

int main(int argc, char **argv)
{
    static uint32_t rxbuf[(C_MAX_MSGSIZE >> 2) + 1];
    static uint32_t txbuf[(C_MAX_MSGSIZE >> 2) + 1];
    static CONV_T convring[MAX_CONVERSATIONS];
    static gnutls_certificate_credentials_t x509_cred;
    static gnutls_dh_params_t dh_params;
    CONV_T *conv;
    STR_T s, tls_data, rx_payload;
    uint32_t spc, vnd, atr, len, *i, *e, *o; 
    uint32_t convr, convw, convnr;
    uint32_t rx_flags, tx_flags, mtu, rx_result, rx_action, retint;
    gnutls_datum dat;
    struct iovec iov[4];
    int parts, n, gr, badattr;
    ssize_t l;

    /* Initialise static data, parse options */

    memset(convring, 0, sizeof(convring));
    parseoptions(argc, argv);

    /* Setup GnuTLS */

    gnutls_global_init();
    gnutls_global_init_extra();
    gnutls_global_set_log_function(tlslog);
    gnutls_global_set_log_level(4);
    gnutls_certificate_allocate_credentials(&x509_cred);
    
    /* Get CAs, CRL, our own keypair. Failures are reported but not fatal;
     * handshakes that need the missing material will fail later. */

    gr = gnutls_certificate_set_x509_trust_file(x509_cred, 
	"modules/eaptls/ca.pem", GNUTLS_X509_FMT_PEM);
    if (gr < 0)
	fprintf(stderr, "eaptls: Can't load CA certificates: %s\n",
		gnutls_strerror(gr));
    gr = gnutls_certificate_set_x509_crl_file(x509_cred, 
	"modules/eaptls/crl.pem", GNUTLS_X509_FMT_PEM);
    if (gr < 0)
	fprintf(stderr, "eaptls: Can't load CRL: %s\n", gnutls_strerror(gr));
    gr = gnutls_certificate_set_x509_key_file(x509_cred, 
	"modules/eaptls/cert.pem",
	"modules/eaptls/key.pem", GNUTLS_X509_FMT_PEM);
    if (gr < 0)
	fprintf(stderr, "eaptls: Can't load certificate and key: %s\n",
		gnutls_strerror(gr));

    /* Get DH parameters - we should read those inside the loop, from a 
     * file that is created upon installation and periodically re-created
     * by cron and reread here each time it's found to be newer than last
     * time it was read. That ensures security even if the module stays
     * running forever, and does not create any delays here. For testing
     * this will do though. */

    s = readstr("modules/eaptls/dh.pem");
    if (s.l < 0) {
	fprintf(stderr, "Can't read DH parameters: %s!\n", strerror(-s.l));
	return 1;
    }
    dat.data = (unsigned char *)s.p; dat.size = s.l;
    gnutls_dh_params_init(&dh_params);
    gr = gnutls_dh_params_import_pkcs3(dh_params, &dat, GNUTLS_X509_FMT_PEM);
    free(s.p);
    if (gr < 0) {
	/* Without valid parameters DH key exchange can't work; fail as we
	 * do when the file can't be read at all */
	fprintf(stderr, "Can't import DH parameters: %s!\n",
		gnutls_strerror(gr));
	return 1;
    }
    gnutls_certificate_set_dh_params(x509_cred, dh_params);

    /* Request loop */

    convr = convw = 0;
    if (debug) errputcs("eaptls: Ready for requests.\n");
    for(;;) {

	conv = NULL;		/* set once the conversation is known */

        /*
         * Get the request from OpenRADIUS
         */

        /* Read header */
	l = read_exactly(0, rxbuf, 8);
	if (l != 8) {
	    if (l < 0) perror("eaptls: read");
	    else fprintf(stderr, "eaptls: read: end of input\n");
	    break;
	}
        if (net32(*rxbuf) != 0xbeefdead) {
            fprintf(stderr, "eaptls: Invalid magic 0x%08x!\n", net32(*rxbuf)); 
            break;
        }
        len = net32(rxbuf[1]);
        if (len < 8 || len > sizeof(rxbuf) - 4) {
            fprintf(stderr, "eaptls: Invalid length %d!\n", len); 
            break;
        }

        /* Read rest of message */
	l = read_exactly(0, rxbuf + 2, len - 8);
	if (l != (ssize_t)(len - 8)) {
	    if (l < 0) perror("eaptls: read");
	    else fprintf(stderr, "eaptls: read: end of input\n");
	    break;
	}

	/* Initialize defaults */

	retint = convnr = rx_flags = tx_flags = rx_result = rx_action = ~0;
	tls_data.l = 0;
	rx_payload.l = 0;
	mtu = 1230;
	badattr = 0;

        /*
         * Loop through the attributes. The attribute spaces for
         * the Request- and Response-[T]TLS/PEAP-* attributes 
         * depend on the EAP type and are specified on the command
         * line. Note: we parse responses, we send requests. Attribute
         * numbers 0 and 1 correspond to the Flags and TLS Length/Data
         * fields of the EAP TLS packets.
         */

        e = rxbuf + (len >> 2);
        for(i = rxbuf + 2; e - i >= 4; i += ((len + 3) >> 2) + 4) {

            /* Get space/vendor/attribute/length tuple */

            spc = net32(i[0]); vnd = net32(i[1]);
            atr = net32(i[2]); len = net32(i[3]);
            if (debug) {
                fprintf(stderr, "eaptls: got space %d, vendor "
                        "%d, attribute %d, len %d\n", spc, vnd, atr, len);
            }

	    /* The value must lie within the message. This also keeps the
	     * loop increment from overflowing or leaving rxbuf. */

	    if (len > (uint32_t)((e - i - 4) * 4)) {
		fprintf(stderr, "eaptls: Attribute length %u exceeds "
				"message!\n", len);
		badattr = 1;
		break;
	    }

            /* Get TLS-Data, -Flags attributes from EAP type-specific space */

            if (spc == rx_space && (vnd == 0 || vnd == META_ORD_ERR)) {
                switch (atr) {
                    case 0: rx_flags = net32(i[4]); break;
                    case 1: tls_data.p = (char *)&(i[4]); 
			    tls_data.l = len;
                }
                continue;
            }

	    /* Get Transaction-Id (indexes timer ring containing
	     * conversations); TLS-Payload (cleartext data to be transmitted
	     * using as TLS Application Data) */

	    if (spc == C_DS_INTERNAL) switch(atr) {
		case C_DI_TID:		convnr = net32(i[4]); 
					continue;
		case C_DI_TLS_PAYLOAD:	rx_payload.p = (char *)&(i[4]); 
					rx_payload.l = len;
					continue;
		case C_DI_TLS_RESULT:	rx_result = net32(i[4]); 
					continue;
		case C_DI_TLS_ACTION:	rx_action = net32(i[4]); 
					continue;
	    }

	    /* Get Framed-MTU, used as maximum amount of data to get from TLS at
	     * a time */

            if (spc == C_DS_RAD_ATR && atr == C_DI_FRAMED_MTU) {
                mtu = net32(i[4]);
                continue;
            }
        }

        /* Initialise reply buffer */

        o = txbuf; *o++ = net32(0xdeadbeef); o++;
        e = txbuf + sizeof(txbuf) / sizeof(txbuf[0]);
        len = 8; 

	/* Refuse malformed requests outright */

	if (badattr) {
	    retint = RET_NODATA; goto reply;
	}

	/* See if we actually got any TLS attributes; flags must always be
	 * present and cannot be 0xffffffff (at most 0xff) */

	if (rx_flags == ~0) {
	    errputcs("eaptls: no TLS attributes from core - wrong space?\n");
	    retint = RET_NODATA; goto reply;
	}

        /* Obtain new conversation number if none given or given one invalid.
         * Valid numbers are 0 .. MAX_CONVERSATIONS - 1. */

        if (convnr >= MAX_CONVERSATIONS || !convring[convnr].inuse) {

            /* If we don't have room for a new conversation, fail now */

            if (MODAHEAD(convr, convw + 1, MAX_CONVERSATIONS) <= 0) {
		errputcs("eaptls: no room for new conversation!\n");
		retint = RET_NOCONV; goto reply;
            }

            /* Create a new EAP-TLS conversation. A conversation is
             * a series of requests and responses. It does not map
             * 1:1 to a TLS session, because a new conversation may
             * resume an existing TLS session. The write index wraps
             * around the ring, so it always stays a valid index. */

            convnr = convw;
	    convw = (convw + 1) % MAX_CONVERSATIONS;
	    conv = &convring[convnr];
	    conv->inuse = 1;
	    conv->state = TLS_HANDSHAKE;
	    if (!buf_init(&conv->fromtls, TLS_BUFFER_SIZE) ||
	        !buf_init(&conv->totls, TLS_BUFFER_SIZE) ||
	        !buf_init(&conv->totls_app, TLS_BUFFER_SIZE)) {
		errputcs("eaptls: no memory for new conversation!\n");

		/* Don't leave a half-built conversation marked as usable.
		 * NOTE(review): buffers that were allocated are not released
		 * here - needs the evblib buffer release API. */
		conv->inuse = 0;
		retint = RET_NOCONV; goto reply;
	    }

	    /* FIXME: add expiry time to ring entry */

	    /* Create Gnutls session - may be resumed from existing one during
	     * handshake */

            gr = gnutls_init(&conv->session, GNUTLS_SERVER);
	    if (gr < 0) {
		fprintf(stderr, "eaptls: Can't create TLS session: %s\n",
			gnutls_strerror(gr));
		conv->inuse = 0;
		retint = RET_NOCONV; goto reply;
	    }
            gnutls_set_default_priority(conv->session);
	    gnutls_credentials_set(conv->session, GNUTLS_CRD_CERTIFICATE, 
				   x509_cred);
	    gnutls_transport_set_lowat(conv->session, 0);
	    gnutls_transport_set_push_function(conv->session, tlspush);
	    gnutls_transport_set_pull_function(conv->session, tlspull);
	    gnutls_transport_set_ptr(conv->session, 
				     (gnutls_transport_ptr_t)conv);

	    fprintf(stderr, "eaptls: New conversation %u\n", convnr);
        }
	else conv = &convring[convnr];	    /* existing conversation */

	/* Skip length field in TLS data if length flag set. Data too short
	 * to hold the field is malformed and dropped, rather than leaving a
	 * negative length behind. */

	if (rx_flags & C_DV_TLS_FLAGS_LENGTH) {
	    if (tls_data.l >= 4) {
		tls_data.p += 4;
		tls_data.l -= 4;
	    }
	    else {
		if (tls_data.l)
		    fprintf(stderr, "eaptls: TLS data too short for length "
				    "field; ignored\n");
		tls_data.l = 0;
	    }
	}

	/* Show what we got */

	if (debug) {
	    fprintf(stderr, "eaptls: Got %4d bytes of TLS data, flags 0x%02x, MTU %d, conversation %d\n", (int)tls_data.l, rx_flags, mtu, convnr);
	    hexdumpfd(2, tls_data.p, tls_data.l, 0);
	}

	/* Process it */

	if (rx_result != ~0) {
	    conv->result = rx_result;	    /* Echoed when we drop the
					       conversation */
	}
	if (tls_data.l) { 

	    /* Feed TLS data to TLS transport buffer, from where GnuTLS will
	     * pull it using the gnutls_transport_read callback */

	    l = buf_maxput(&conv->totls);
	    if (tls_data.l > l) {
		fprintf(stderr, "eaptls: No room for %d bytes in transport "
				"buffer; has %d left!\n", (int)tls_data.l,
				(int)l);

		/* FIXME: We should dump this session now */
		retint = RET_NOBUF; goto reply;
	    }
	    memcpy(conv->totls.w, tls_data.p, tls_data.l);
	    buf_put(&conv->totls, tls_data.l);
	}
	if (rx_payload.l) {

	    /* Feed payload data to TLS application buffer, where it's kept
	     * until our state machine determines it's ready to send TLS
	     * Application or InnerApplication records and it's all sent
	     * in one go. */

	    l = buf_maxput(&conv->totls_app);
	    if (rx_payload.l > l) {
		fprintf(stderr, "eaptls: No room for %d bytes in payload "
				"buffer; has %d left!\n", (int)rx_payload.l,
				(int)l);

		/* FIXME: We should dump this session now */
		retint = RET_NOBUF; goto reply;
	    }
	    memcpy(conv->totls_app.w, rx_payload.p, rx_payload.l);
	    buf_put(&conv->totls_app, rx_payload.l);
	}

	/* Work TLS engine according to our state machine, unless the 'more'
	 * flag was set by the sender. May add attributes to response; the
	 * end we pass keeps room for the attributes appended below. */

	if (!(rx_flags & C_DV_TLS_FLAGS_MORE))
	    len += tls_state_machine(conv, &o, e - REPLY_TAIL_WORDS, rx_action);

	/* Add the given or new conversation number */

	len += addint(&o, C_DI_TID, convnr);

	/* We're OK now */

	retint = 0;

        /*
         * Send a reply with size 'len'; an 'int' attribute containing the 
	 * status in retint is added, together with the TLS data if there's
	 * any to send and we're not reporting an error.
         */
reply:
	/* Add reply status */

	len += addint(&o, C_DI_INT, retint);

	/* Check if we have any data to send. If the MTU is smaller than 256
	 * or there's no data in the buffer or we have an error to report, we
	 * won't send any data. */

	l = 0;
	if (conv && !(rx_flags & C_DV_TLS_FLAGS_MORE) && mtu >= 256 &&
	    retint < RET_ERROR && (l = buf_maxget(&conv->fromtls))) {

	    /* If this is the first part of our transmission, set
	     * length flag in tx_flags and add a length field at
	     * the start of the TLS data. */

	    tx_flags = 0;
	    if (conv->fromtls.r == conv->fromtls.bu.p) {
		tx_flags = C_DV_TLS_FLAGS_LENGTH;
	    }

	    /* If we can send less than our whole buffer, set the
	     * more flag in tx_flags. */

	    if (l > mtu - 12) {
		l = mtu - 12;
		tx_flags |= C_DV_TLS_FLAGS_MORE;
	    }

	    /* Add the flags */

	    *o++ = net32(tx_space); *o++ = net32(META_VND_ANY); *o++ = net32(0);
	    *o++ = net32(1);
	    *o++ = net32(tx_flags);
	    len += 20;

	    /* Add a header for TLS transport length (optional) and data */

	    *o++ = net32(tx_space); *o++ = net32(META_VND_ANY); *o++ = net32(1);
	    if (tx_flags & C_DV_TLS_FLAGS_LENGTH) {
		*o++ = net32(l + 4);
		*o++ = net32(buf_maxget(&conv->fromtls));
		len += 4;
	    }
	    else *o++ = net32(l);
	    len += 16;
	}

	/* Start to prepare response */

	parts = 0;
	iov[parts].iov_base = txbuf;			    /* Header */
	iov[parts].iov_len = len;
	parts++;
	if (l) {
	    iov[parts].iov_base = conv->fromtls.r;	    /* TLS trans data */
	    iov[parts].iov_len = l;
	    parts++;
	    len += l;
	    l = (4 - l) & 3;
	    if (l) {
		iov[parts].iov_base = "\x00\x00\x00\x00";   /* Padding */
		iov[parts].iov_len = l;
		parts++;
		len += l;
	    }
	}

	/* Send it */

        txbuf[1] = net32(len);
	if (debug) {
	    fprintf(stderr, "eaptls: replying, %d bytes in %d part(s):\n", 
		    len, parts);
	    for(n = 0; n < parts; n++)
		hexdumpfd(2, iov[n].iov_base, iov[n].iov_len, 0);
	}

        if (writev(1, iov, parts) != len) { perror("eaptls: write"); break; }
	if (parts >= 2) buf_get(&conv->fromtls, iov[1].iov_len);
    }

    gnutls_certificate_free_credentials(x509_cred);
    gnutls_dh_params_deinit(dh_params);
    gnutls_global_deinit();

    return 1;
}


/*
 * vim:softtabstop=4:sw=4
 */

