/* aes_cipher.c */
/* ESP: performs ESP encryption and decryption of IP packets for a
   group of AES-candidate cryptographic algorithms */
/* Includes: MARS-CBC, SERPENT-CBC, TWOFISH-CBC, RC6-CBC, and RIJNDAEL-CBC */
            

#include <linux/types.h>
#include <linux/kernel.h>
#include <linux/skbuff.h>
#include <linux/errno.h>
#include <net/ip.h>
#include <linux/in.h>
#include <linux/in6.h>
#include <asm/uaccess.h>
#include <asm/checksum.h>
#include "../crypto/mars/mars.h"
#include "../crypto/rc6/rc6.h"
#include "../crypto/rijndael/rijndael.h"
#include "../crypto/serpent/serpent.h"
#include "../crypto/twofish/twofish.h"
#include "../sadb.h"
#include "../ipsec.h"
#include "../transform.h"
#include "aes_cipher.h"
#include "esp_mars_cbc.h"
#include "esp_rc6_cbc.h"
#include "esp_rijndael_cbc.h"
#include "esp_serpent_cbc.h"
#include "esp_twofish_cbc.h"

/* Scratch buffers used only to format the source and destination addresses
 * in the printk() diagnostics of this file: each ntoa() result is strcpy'd
 * into its own buffer so both addresses can appear in one message
 * (presumably because ntoa() returns a shared static buffer — confirm).
 * 18 bytes is enough for a dotted-quad IPv4 string ("255.255.255.255" is
 * 15 characters plus NUL), assuming that is what ntoa() produces.
 * NOTE(review): shared file-scope state with no visible locking; concurrent
 * error paths could interleave the text — confirm callers are serialized. */
static char addr_s1[18],addr_s2[18];

struct sk_buff *aes_cipher_input(struct sk_buff *skb, struct sadb_dst_node *d, struct ipsec_transform *t)
{
	unsigned short dlen,pad_len;
	unsigned char ivec[AES_IVEC_LENGTH];
	unsigned char *pdata=0;
	unsigned long spi=0;
	struct iphdr *iph = (struct iphdr *)skb->data;
	unsigned short ihl = iph->ihl<<2;

	pdata = skb_pull(skb,ihl);
	spi = d->sa_info.spi;

	memcpy(ivec,pdata+sizeof(spi)+SN_LENGTH+d->sa_info.alg.esp.auth_ivec_length,d->sa_info.alg.esp.crypto_ivec_length);

	pdata = skb_pull(skb,sizeof(spi)+d->sa_info.alg.esp.auth_ivec_length+d->sa_info.alg.esp.crypto_ivec_length+SN_LENGTH);

  	dlen = htons(iph->tot_len)-ihl - sizeof(spi)-d->sa_info.alg.esp.crypto_ivec_length-SN_LENGTH-
		d->sa_info.alg.esp.auth_ivec_length-(d->sa_info.alg.esp.auth_data_length<<2);

    	if ((dlen % (MAX_AES_PAD_LENGTH+1)) && (!(t->id == ESP_NULL))){
          	printk(KERN_INFO "%s: Inbound: Invalid Data Length: src %s dst %s spi %lx len %d\n", t->name,
                 	strcpy(addr_s1,ntoa(iph->saddr)), strcpy(addr_s2,ntoa(iph->daddr)), spi, dlen);
          	goto in_error;
        }

	/* Decrypt */
	switch (t->id)
	{
	case ESP_MARS_CBC:
        	mars_cbc_encrypt(pdata,dlen,(WORD *)d->sa_info.alg.esp.inbound_crypto_key.fast_key,ivec,MARS_DECRYPT);
	break;
	case ESP_RC6_CBC:
        	rc6_cbc_encrypt(pdata,dlen,*(rc6key *)d->sa_info.alg.esp.inbound_crypto_key.fast_key,ivec,RC6_DECRYPT);
	break;
	case ESP_RIJNDAEL_CBC:
        	rijndael_cbc_encrypt(pdata,dlen,(rijndaelkey *)d->sa_info.alg.esp.inbound_crypto_key.fast_key,ivec,RIJNDAEL_DECRYPT);
	break;
	case ESP_SERPENT_CBC:
        	serpent_cbc_encrypt(pdata,dlen,*(serpentkey *)d->sa_info.alg.esp.inbound_crypto_key.fast_key,ivec,SERPENT_DECRYPT);
	break;
	case ESP_TWOFISH_CBC:
        	twofish_cbc_encrypt(pdata,dlen,(twofishkey *)d->sa_info.alg.esp.inbound_crypto_key.fast_key,ivec,TWOFISH_DECRYPT);
	break;
	}

	pad_len = *(char *)(pdata+dlen-2);

  	if (pad_len > MAX_AES_PAD_LENGTH)
	{
    		printk(KERN_INFO "%s: Inbound: Decryption Error: src %s dst %s spi %lx\n", t->name,
         		strcpy(addr_s1,ntoa(iph->saddr)), strcpy(addr_s2,ntoa(iph->daddr)), spi);
    		goto in_error;
	}

  	iph->protocol=*(char *)(pdata+dlen-1);
	iph->tot_len = htons(dlen+ihl-pad_len-2);
	iph->check = 0;
	iph->check = ip_fast_csum((unsigned char *)iph,iph->ihl);
	skb_push(skb,ihl);

	/* IP header copy MUST be in tail to head order (hence r_memcpy) to avoid the (ihl<<2)-8 bytes of overlap */
	r_memcpy(skb->data,(unsigned char *)iph,ihl);
	skb_trim(skb,ihl+dlen-pad_len-2);
    	skb->nh.raw=skb->data;
	skb->h.raw=skb->data+ihl;

	return skb;

in_error:

	if (skb) kfree_skb(skb);
	return NULL;

}

struct sk_buff *aes_cipher_output(struct sk_buff *skb, struct sadb_dst_node *d, struct ipsec_transform *t)
{
	struct rtable *rt=(struct rtable*)skb->dst;
        struct device *dev = rt->u.dst.dev;
	struct iphdr *iph = (struct iphdr *)skb->data;
	unsigned long spi,sn;
	unsigned short dlen,pad_len,ihl;
	unsigned char ivec[AES_IVEC_LENGTH];
	int error;
	struct sk_buff *tmpskb=0;

#define errout(e) { error = e; printk(KERN_INFO "%s: Outbound: Protocol Error %d\n", t->name, e);goto out_error;}
	spi = d->sa_info.peer_spi;	

	if (d->sa_info.flags & IPSEC_TUNNEL_FLAG)
	{
		dlen = ntohs(iph->tot_len);
		ihl = 20;
	}
	else
	{
		dlen = ntohs(iph->tot_len)-(iph->ihl<<2);
		ihl = iph->ihl<<2;
	}

	/* setup sequence number before generating IVEC for RFC1829-compat mode */
	sn = htonl(d->sa_info.sn);
	d->sa_info.sn++;

	/* Generate IVEC */
       	generate_ivec(ivec,d->sa_info.alg.esp.crypto_ivec_length);

	/* allocate skbuff to it's largest size possible */

	if (!(tmpskb = ipsec_alloc_skb(skb,sizeof(spi)+dlen+((dev->hard_header_len+15)&~15)+SN_LENGTH+
		ihl+d->sa_info.alg.esp.crypto_ivec_length+MAX_AES_PAD_LENGTH+2+(d->sa_info.alg.esp.auth_data_length<<2)+
		d->sa_info.alg.esp.auth_ivec_length)))errout(ENOBUFS);

	tmpskb->mac.raw = tmpskb->data;

	skb_reserve(tmpskb,((dev->hard_header_len+15)&~15)+ihl+SN_LENGTH+
		d->sa_info.alg.esp.crypto_ivec_length+
		d->sa_info.alg.esp.auth_ivec_length+sizeof(spi));
	skb_put(tmpskb,dlen);

	/* add IP data */
	if (d->sa_info.flags & IPSEC_TUNNEL_FLAG)
		memcpy(tmpskb->data,skb->data,dlen);
	else 
		memcpy(tmpskb->data,skb->data+(iph->ihl<<2),dlen);

        pad_len = pad_data(tmpskb->data,dlen+2,dlen,MAX_AES_PAD_LENGTH+1);
	skb_put(tmpskb,pad_len+2);

        *(char *)(tmpskb->data+dlen+pad_len)=pad_len;

	/* store transport protocol type */
	if (d->sa_info.flags & IPSEC_TUNNEL_FLAG)
        	*(char *)(tmpskb->data+dlen+pad_len+1)=IPPROTO_IPIP;
  	else
        	*(char *)(tmpskb->data+dlen+pad_len+1)=iph->protocol;

	skb_push(tmpskb,ihl+sizeof(spi)+SN_LENGTH+d->sa_info.alg.esp.crypto_ivec_length+
		d->sa_info.alg.esp.auth_ivec_length);
	memcpy(tmpskb->data,skb->data,ihl);
	spi = htonl(spi);
	memcpy(tmpskb->data+ihl,&spi,sizeof(spi));
	memcpy(tmpskb->data+ihl+sizeof(spi),&sn,SN_LENGTH);
	memcpy(tmpskb->data+ihl+sizeof(spi)+SN_LENGTH+d->sa_info.alg.esp.auth_ivec_length,
		ivec,d->sa_info.alg.esp.crypto_ivec_length);

	skb_put(tmpskb,(d->sa_info.alg.esp.auth_data_length<<2));

	iph = (struct iphdr *)tmpskb->data;
	iph->protocol = IPPROTO_ESP;

    	iph->tot_len = htons(tmpskb->len);

	/* change address if IPIP encapsulation */
	/* fill in new header if tunneling */
	if (d->sa_info.flags & IPSEC_TUNNEL_FLAG){

		ip_rt_put(rt);
		error = ip_route_output(&rt,d->sa_info.peer_addr.addr_union.ip_addr.s_addr,0,iph->tos,0);
		if (error)errout(error);
    		iph->saddr = rt->rt_src;
    		iph->daddr = rt->rt_dst;
    		iph->ihl = ihl>>2;
    		iph->id = htons(ip_id_count++);
    		iph->frag_off &= htons(IP_DF);
    		iph->ttl = ip_statistics.IpDefaultTTL;

		tmpskb->dev = rt->u.dst.dev;
		tmpskb->dst = dst_clone(&rt->u.dst);
	}
	
	/* Encrypt */
	switch (t->id)
	{
	case ESP_MARS_CBC:
        	mars_cbc_encrypt(tmpskb->data+ihl+sizeof(spi)+SN_LENGTH+d->sa_info.alg.esp.auth_ivec_length+
			d->sa_info.alg.esp.crypto_ivec_length,dlen+pad_len+2,
			(WORD *)d->sa_info.alg.esp.outbound_crypto_key.fast_key,ivec,MARS_ENCRYPT);
	break;
	case ESP_RC6_CBC:
        	rc6_cbc_encrypt(tmpskb->data+ihl+sizeof(spi)+SN_LENGTH+d->sa_info.alg.esp.auth_ivec_length+
			d->sa_info.alg.esp.crypto_ivec_length,dlen+pad_len+2,
			*(rc6key *)d->sa_info.alg.esp.outbound_crypto_key.fast_key,ivec,RC6_ENCRYPT);
	break;
	case ESP_RIJNDAEL_CBC:
        	rijndael_cbc_encrypt(tmpskb->data+ihl+sizeof(spi)+SN_LENGTH+d->sa_info.alg.esp.auth_ivec_length+
			d->sa_info.alg.esp.crypto_ivec_length,dlen+pad_len+2,
			(rijndaelkey *)d->sa_info.alg.esp.outbound_crypto_key.fast_key,ivec,RIJNDAEL_ENCRYPT);
	break;
	case ESP_SERPENT_CBC:
        	serpent_cbc_encrypt(tmpskb->data+ihl+sizeof(spi)+SN_LENGTH+d->sa_info.alg.esp.auth_ivec_length+
			d->sa_info.alg.esp.crypto_ivec_length,dlen+pad_len+2,
			*(serpentkey *)d->sa_info.alg.esp.outbound_crypto_key.fast_key,ivec,SERPENT_ENCRYPT);
	break;
	case ESP_TWOFISH_CBC:
        	twofish_cbc_encrypt(tmpskb->data+ihl+sizeof(spi)+SN_LENGTH+d->sa_info.alg.esp.auth_ivec_length+
			d->sa_info.alg.esp.crypto_ivec_length,dlen+pad_len+2,
			(twofishkey *)d->sa_info.alg.esp.outbound_crypto_key.fast_key,ivec,TWOFISH_ENCRYPT);
	break;
	}

	if (!tmpskb->data){
		printk(KERN_INFO "%s: Outbound: Encryption error: src %s dst %s spi %lx\n", t->name,
              		strcpy(addr_s1,ntoa(iph->saddr)), strcpy(addr_s2,ntoa(iph->daddr)), spi);
          	goto out_error;
	}

	if (skb) kfree_skb(skb);
        return tmpskb;

out_error:
	if (tmpskb) kfree_skb(tmpskb);
	if (skb) kfree_skb(skb);
  	return NULL;

}
