/* common_auth.c */
/* AH/ESP: computes and verifies the authentication data for ESP, and builds
   and checks the AH header, for a group of authentication algorithms  */
/* Includes: HMAC-MD5-96 and HMAC-SHA1-96 */
             
#include <linux/types.h>
#include <linux/kernel.h>
#include <linux/skbuff.h>
#include <linux/errno.h>
#include <net/ip.h>
#include <linux/in.h>
#include <linux/in6.h>
#include <asm/uaccess.h>
#include <asm/checksum.h>
#include "../crypto/sha/sha.h"
#include "../crypto/md5/md5.h"
#include "../sadb.h"
#include "../ipsec.h"
#include "../transform.h"
#include "hmacsha96.h"
#include "hmacmd596.h"
#include "common_auth.h"

/*
 * Scratch buffers that receive strcpy()'d copies of ntoa() results for the
 * log messages below (presumably dotted-quad text, at most 16 bytes with the
 * NUL -- confirm against ntoa()).  They are shared file-scope storage, so
 * concurrent callers can overwrite each other's text.
 */
static char addr_s1[18];
static char addr_s2[18];

struct sk_buff *common_auth_input(struct sk_buff *skb, struct sadb_dst_node *d, struct ipsec_transform *t)
{
	unsigned char *pdata=0, *tmpiph=0, adata[MAXHASHLEN],adata2[MAXHASHLEN];
	unsigned long spi=0;
	struct iphdr *iph = (struct iphdr *)skb->data;
	struct ah_header ah;
	unsigned short dlen, ihl = iph->ihl<<2;
	int error=0,sn_length;

	spi = d->sa_info.spi;

        switch (d->sa_info.protocol)
        {
        case IPPROTO_AH:

#define errin(e) { error = e; printk(KERN_INFO "%s: Inbound: Protocol Error %d\n", t->name, e);goto in_error;}

	pdata = skb_pull(skb,ihl);
	ah = *(struct ah_header *)pdata;
	sn_length = d->sa_info.sn?4:0;

	if ((ah.length<<2) != ((d->sa_info.alg.ah.auth_data_length<<2)+(d->sa_info.sn?4:0)))
	{
		printk(KERN_INFO "%s: SADB: Uses a data length of %d (got %d): dst %s spi %lx\n", t->name, (d->sa_info.alg.ah.auth_data_length+(d->sa_info.sn?1:0)),ah.length,
			strcpy(addr_s2,ntoa(d->dst.addr_union.ip_addr.s_addr)), d->sa_info.spi);
	  	return 0;
	}
	memcpy(adata,pdata+sizeof(ah)+sn_length,d->sa_info.alg.ah.auth_data_length<<2);
	memset(pdata+sizeof(ah)+sn_length,0,d->sa_info.alg.ah.auth_data_length<<2);

	if (!(tmpiph = (unsigned char *)kmalloc(ihl,GFP_ATOMIC))) errin(ENOBUFS);

	memcpy(tmpiph,iph,ihl);

	iph_ah_prep(tmpiph);

	switch (t->id){

	case HMACSHA96: 
	{
		SHA_CTX ctx;

		SHAInit(&ctx);
		SHAUpdate(&ctx,d->sa_info.alg.ah.inbound_auth_key.fast_key,SHABLOCKLEN);
		SHAUpdate(&ctx,tmpiph,ihl);
		SHAUpdate(&ctx,pdata, htons(iph->tot_len)-ihl);
		SHAFinal(&ctx);

		memcpy(adata2,ctx.buffer,SHAHASHLEN);
	
		SHAInit(&ctx);
		SHAUpdate(&ctx,d->sa_info.alg.ah.inbound_auth_key.fast_key+SHABLOCKLEN,SHABLOCKLEN);
		SHAUpdate(&ctx,adata2,SHAHASHLEN);
		SHAFinal(&ctx);

		if (memcmp(adata,(unsigned char *)ctx.buffer,d->sa_info.alg.ah.auth_data_length<<2)){
			printk(KERN_INFO "%s: Inbound: Authenticator Failure: src %s dst %s spi %lx alg_id %d\n", t->name,
        			strcpy(addr_s1,ntoa(iph->saddr)), strcpy(addr_s2,ntoa(iph->daddr)), spi,d->sa_info.alg.ah.auth_alg_id);
			bprintk(adata,d->sa_info.alg.ah.auth_data_length<<2,"found");
			bprintk((unsigned char*)ctx.buffer,d->sa_info.alg.ah.auth_data_length<<2,"expected");
			goto in_error;
		}
	}
	break;
	case HMACMD596: 
	{
		struct MD5Context ctx;

                MD5Init(&ctx);
                MD5Update(&ctx,d->sa_info.alg.ah.inbound_auth_key.fast_key,MD5BLOCKLEN);
                MD5Update(&ctx,tmpiph,ihl); 
                MD5Update(&ctx,pdata, htons(iph->tot_len)-ihl);
                MD5Final(adata2,&ctx);

                MD5Init(&ctx);
                MD5Update(&ctx,d->sa_info.alg.ah.inbound_auth_key.fast_key+MD5BLOCKLEN,MD5BLOCKLEN);
                MD5Update(&ctx,adata2,MD5HASHLEN);
                MD5Final(adata2,&ctx);

                if (memcmp(adata,adata2,d->sa_info.alg.ah.auth_data_length<<2)){
                        printk(KERN_INFO "%s: Inbound: Authenticator Failure: src %s dst %s spi %lx alg_id %d\n", 
                                t->name, strcpy(addr_s1,ntoa(iph->saddr)), strcpy(addr_s2,ntoa(iph->daddr)), 
                                spi,d->sa_info.alg.ah.auth_alg_id);
                        goto in_error;
                }

	}
	break;
	}
  	iph->protocol=ah.next_header;
	iph->tot_len = htons(ntohs(iph->tot_len) - sizeof(ah) - (d->sa_info.alg.ah.auth_data_length<<2) - sn_length);
	iph->check = 0;
	iph->check = ip_fast_csum((unsigned char *)iph,iph->ihl);
	skb_pull(skb,sizeof(ah)+sn_length+(d->sa_info.alg.ah.auth_data_length<<2));
	skb_push(skb,(iph->ihl<<2));

	/* IP header copy MUST be in tail to head order (hence r_memcpy) to avoid the 
	   (ihl<<2)-8 bytes of overlap */
	r_memcpy(skb->data,(unsigned char *)iph,ihl);
    	skb->nh.raw=skb->data;
	skb->h.raw=skb->data+ihl;

	break;

	case IPPROTO_ESP:
		dlen = ntohs(iph->tot_len) - ihl-(d->sa_info.alg.esp.auth_data_length<<2);

		switch (t->id)
		{
		case HMACSHA96:
		{
			SHA_CTX ctx;

			SHAInit(&ctx);
			SHAUpdate(&ctx,d->sa_info.alg.esp.inbound_auth_key.fast_key,SHABLOCKLEN);
			SHAUpdate(&ctx,skb->data+ihl,dlen);
			SHAFinal(&ctx);

			memcpy(adata,ctx.buffer,SHAHASHLEN);

			SHAInit(&ctx);
			SHAUpdate(&ctx,d->sa_info.alg.esp.inbound_auth_key.fast_key+SHABLOCKLEN,SHABLOCKLEN);
			SHAUpdate(&ctx,adata,SHAHASHLEN);
			SHAFinal(&ctx);

			if (memcmp(skb->data+dlen+ihl,(unsigned char *)ctx.buffer,(d->sa_info.alg.esp.auth_data_length<<2))){
        			printk(KERN_INFO "%s: Inbound: Authenticator Failure: src %s dst %s spi %lx alg_id %d\n",
                			t->name, strcpy(addr_s1,ntoa(iph->saddr)), strcpy(addr_s2,ntoa(iph->daddr)),
                			spi,d->sa_info.alg.esp.auth_alg_id);
        			goto in_error;
			}
		}
		break;
		case HMACMD596:
		{
			struct MD5Context ctx;


	                MD5Init(&ctx);
                	MD5Update(&ctx,d->sa_info.alg.esp.inbound_auth_key.fast_key,MD5BLOCKLEN);
                	MD5Update(&ctx,skb->data+ihl,dlen);
                	MD5Final(adata,&ctx);

                	MD5Init(&ctx);
                	MD5Update(&ctx,d->sa_info.alg.esp.inbound_auth_key.fast_key+MD5BLOCKLEN,MD5BLOCKLEN);
                	MD5Update(&ctx,adata,MD5HASHLEN);
                	MD5Final(adata,&ctx);

                	if (memcmp(skb->data+dlen+ihl,adata,(d->sa_info.alg.esp.auth_data_length<<2))){
                        	printk(KERN_INFO "%s: Inbound: Authenticator Failure: src %s dst %s spi %lx alg_id %d\n", 
                                	t->name, strcpy(addr_s1,ntoa(iph->saddr)), strcpy(addr_s2,ntoa(iph->daddr)), 
                                	spi,d->sa_info.alg.esp.auth_alg_id);
                        	goto in_error;
                	}
                }
		break;
		}
	break;
        }

	if (tmpiph) kfree(tmpiph);
	return skb;

in_error:

	if (skb) kfree_skb(skb);
	if (tmpiph) kfree(tmpiph);
	return NULL;
}

struct sk_buff *common_auth_output(struct sk_buff *skb, struct sadb_dst_node *d, struct ipsec_transform *t)
{
	struct rtable *rt=(struct rtable*)skb->dst;
        struct device *dev = rt->u.dst.dev;
	struct iphdr *iph = (struct iphdr *)skb->data;
	struct iphdr *tmpiph = iph;
	unsigned long spi;
	unsigned short dlen,ihl=(iph->ihl<<2);
	struct sk_buff *tmpskb=0;
	struct ah_header ah;
	unsigned char adata[MAXHASHLEN];
	int error=0,sn_length;

        switch (d->sa_info.protocol)
        {
        case IPPROTO_AH:

#define errout(e) { error = e; printk(KERN_INFO "%s: Outbound: Protocol Error %d\n", t->name, e);goto out_error;}

	spi = ah.spi = d->sa_info.peer_spi;
	ah.spi = htonl(ah.spi);
	ah.reserved = 0;

	ah.length = d->sa_info.alg.ah.auth_data_length + (d->sa_info.sn?1:0);

	if (d->sa_info.flags & IPSEC_TUNNEL_FLAG)
	{
		ah.next_header = IPPROTO_IPIP;
		dlen = ntohs(iph->tot_len);
		ihl = 20;
	}
	else
	{
		ah.next_header = iph->protocol;
		dlen = ntohs(iph->tot_len)-(iph->ihl<<2);
		ihl = (iph->ihl<<2);
	}

	sn_length = d->sa_info.sn?4:0;

	/* allocate skbuff to it's largest size possible */
	if (!(tmpskb = ipsec_alloc_skb(skb,dlen+d->sa_info.alg.ah.auth_ivec_length+((dev->hard_header_len+15)&~15)+ihl+sn_length+sizeof(ah)+(d->sa_info.alg.ah.auth_data_length<<2))))errout(ENOBUFS);

	skb_reserve(tmpskb,((dev->hard_header_len+15)&~15)+ihl+sizeof(ah)+sn_length+(d->sa_info.alg.ah.auth_data_length<<2));
	skb_put(tmpskb,dlen);

	/* add IP data */
	if (d->sa_info.flags & IPSEC_TUNNEL_FLAG)
		memcpy(tmpskb->data,skb->data,dlen);
	else 
		memcpy(tmpskb->data,skb->data+(iph->ihl<<2),dlen);


	skb_push(tmpskb,(d->sa_info.alg.ah.auth_data_length<<2)+sizeof(ah)+ihl+sn_length);

	memcpy(tmpskb->data,skb->data,ihl);

	memcpy(tmpskb->data+ihl,&ah,sizeof(ah));

	if (d->sa_info.sn)
	{
		unsigned long tmp_sn = ntohl(d->sa_info.sn);
		memcpy(tmpskb->data+ihl+sizeof(ah),&tmp_sn,sn_length);
		d->sa_info.sn++;
	}

	iph = (struct iphdr *)tmpskb->data;
	iph->protocol = IPPROTO_AH;

	/* change address if IPIP encapsulation */
	/* fill in new header if tunneling */
	if (d->sa_info.flags & IPSEC_TUNNEL_FLAG){

		ip_rt_put(rt);
		error = ip_route_output(&rt,d->sa_info.peer_addr.addr_union.ip_addr.s_addr,0,iph->tos,0);
		if (error)errout(error);
    		iph->saddr = rt->rt_src;
    		iph->daddr = rt->rt_dst;
    		iph->ihl = ihl>>2;
    		iph->id = htons(ip_id_count++);
    		iph->frag_off &= htons(IP_DF);
    		iph->ttl = ip_statistics.IpDefaultTTL;

		tmpskb->dev = rt->u.dst.dev;
		tmpskb->dst = dst_clone(&rt->u.dst);
	}

    	iph->tot_len = htons(ihl + dlen + sizeof(ah)+(d->sa_info.alg.ah.auth_data_length<<2)+sn_length);

	memcpy(tmpiph,iph,ihl);
	iph_ah_prep((unsigned char *)tmpiph);

	memset(adata,0,d->sa_info.alg.ah.auth_data_length<<2);

	switch (t->id)
	{
	case HMACSHA96:
	{
		SHA_CTX ctx;

		SHAInit(&ctx);
		SHAUpdate(&ctx,d->sa_info.alg.ah.inbound_auth_key.fast_key,SHABLOCKLEN);
		SHAUpdate(&ctx,(unsigned char *)tmpiph,ihl);
		SHAUpdate(&ctx,(unsigned char *)&ah,sizeof(ah));
		SHAUpdate(&ctx,tmpskb->data+ihl+sizeof(ah),sn_length);
		SHAUpdate(&ctx,adata,(d->sa_info.alg.ah.auth_data_length<<2));
		SHAUpdate(&ctx,tmpskb->data+ihl+sizeof(ah)+sn_length+(d->sa_info.alg.ah.auth_data_length<<2),dlen);
		SHAFinal(&ctx);
	
		memcpy(adata,ctx.buffer,SHAHASHLEN);

		SHAInit(&ctx);
		SHAUpdate(&ctx,d->sa_info.alg.ah.inbound_auth_key.fast_key+SHABLOCKLEN,SHABLOCKLEN);
		SHAUpdate(&ctx,adata,SHAHASHLEN);
		SHAFinal(&ctx);

		memcpy(tmpskb->data+ihl+sizeof(ah)+sn_length,ctx.buffer,d->sa_info.alg.ah.auth_data_length<<2);
	}
	break;
	case HMACMD596:
	{
		struct MD5Context ctx;

                MD5Init(&ctx);
                MD5Update(&ctx,d->sa_info.alg.ah.inbound_auth_key.fast_key,MD5BLOCKLEN);
                MD5Update(&ctx,(unsigned char *)tmpiph,ihl);
                MD5Update(&ctx,(unsigned char *)&ah,sizeof(ah));
                MD5Update(&ctx,tmpskb->data+ihl+sizeof(ah),sn_length);
                MD5Update(&ctx,adata,d->sa_info.alg.ah.auth_data_length<<2);
                MD5Update(&ctx,tmpskb->data+ihl+sizeof(ah)+sn_length+(d->sa_info.alg.ah.auth_data_length<<2),dlen);
                MD5Final(adata,&ctx);

                MD5Init(&ctx);
                MD5Update(&ctx,d->sa_info.alg.ah.inbound_auth_key.fast_key+MD5BLOCKLEN,MD5BLOCKLEN);
                MD5Update(&ctx,adata,MD5HASHLEN);
                MD5Final(adata,&ctx);

                memcpy(tmpskb->data+ihl+sizeof(ah)+sn_length,adata,d->sa_info.alg.ah.auth_data_length<<2);

	}
	break;
	}
	        if (skb) kfree_skb(skb);
        	return tmpskb;

        break;

        case IPPROTO_ESP:
                dlen = ntohs(iph->tot_len)-ihl-(d->sa_info.alg.esp.auth_data_length<<2);
		switch (t->id)
		{
		case HMACSHA96:
		{
			SHA_CTX ctx;

                	SHAInit(&ctx);
                	SHAUpdate(&ctx,d->sa_info.alg.esp.inbound_auth_key.fast_key,SHABLOCKLEN);
                	SHAUpdate(&ctx,skb->data+ihl,dlen);
                	SHAFinal(&ctx);

			memcpy(adata,ctx.buffer,SHAHASHLEN);

                	SHAInit(&ctx);
                	SHAUpdate(&ctx,d->sa_info.alg.esp.inbound_auth_key.fast_key+SHABLOCKLEN,SHABLOCKLEN);
                	SHAUpdate(&ctx,adata,SHAHASHLEN);
                	SHAFinal(&ctx);

	                memcpy(skb->data+ihl+dlen,ctx.buffer,(d->sa_info.alg.esp.auth_data_length<<2));
		}
		break;
		case HMACMD596:
		{
			struct MD5Context ctx;

                	MD5Init(&ctx);
	                MD5Update(&ctx,d->sa_info.alg.esp.inbound_auth_key.fast_key,MD5BLOCKLEN);
	                MD5Update(&ctx,skb->data+ihl,dlen);
                	MD5Final(adata,&ctx);

                	MD5Init(&ctx);
                	MD5Update(&ctx,d->sa_info.alg.esp.inbound_auth_key.fast_key+MD5BLOCKLEN,MD5BLOCKLEN);
                	MD5Update(&ctx,adata,MD5HASHLEN);
                	MD5Final(adata,&ctx);

                	memcpy(skb->data+ihl+dlen,adata,(d->sa_info.alg.esp.auth_data_length<<2));
		}
		break;

		}
		
                return skb;

        break;
	}

out_error:
	if (tmpskb) kfree_skb(tmpskb);
	if (skb) kfree_skb(skb);
  	return NULL;

}
