/*
 *  $Id$
 *
 *  libnet
 *  checksum implementation
 *  checksum.c - IP-type checksum routines
 *
 *  route|daemon9 <route@infonexus.com>
 */

#include "../include/libnet.h"


void
do_checksum(u_char *buf, int protocol, int len)
{
    struct tcphdr    *tcp_hdr;
    struct udphdr    *udp_hdr;
    struct ip        *ip_hdr;
    struct psuedohdr *p_hdr;
    u_char *p;

    switch (protocol)
    {
        case IPPROTO_TCP:
            ip_hdr  = (struct ip *)(buf);
            tcp_hdr = (struct tcphdr *)(buf + IP_H);
            tcp_hdr->th_sum = 0;
#if (INTEL)
            tcp_hdr->th_sum = tcp_check(tcp_hdr,
                                        len,
                                        ip_hdr->ip_src.s_addr,
                                        ip_hdr->ip_dst.s_addr);
#else
            /*
             *  Grab memory for a psuedoheader and the TCP packet.
             */
            p = (u_char *)malloc(P_H + len);
            if (!p)
            {
                perror("checksum: malloc");
                exit(1);
            }
            p_hdr = (struct psuedohdr *)p;

            memset(p_hdr, 0, P_H + len);

            p_hdr->ip_src   = ip_hdr->ip_src.s_addr;
            p_hdr->ip_dst   = ip_hdr->ip_dst.s_addr;
            p_hdr->protocol = IPPROTO_TCP;
            p_hdr->len      = htons(len);

            memcpy(p + P_H, tcp_hdr, len);
            tcp_hdr->th_sum = ip_check((u_short *)p, P_H + len);

            free(p);
            p = NULL;
#endif  /* INTEL */
            break;
        case IPPROTO_UDP:
            ip_hdr  = (struct ip *)(buf);
            udp_hdr = (struct udphdr *)(buf + IP_H);

            /*
             *  Grab memory for a psuedoheader and the UDP packet.
             */
            p = (u_char *)malloc(P_H + len);
            if (!p)
            {
                perror("checksum: malloc");
                exit(1);
            }
            p_hdr = (struct psuedohdr *)p;

            bzero(p_hdr, P_H + len);

            p_hdr->ip_src   = ip_hdr->ip_src.s_addr;
            p_hdr->ip_dst   = ip_hdr->ip_dst.s_addr;
            p_hdr->protocol = IPPROTO_UDP;
            p_hdr->len      = htons(len);

            udp_hdr->uh_sum = 0;
            bcopy(udp_hdr, p + P_H, len);
            udp_hdr->uh_sum = ip_check((u_short *)p, P_H + len);
            free(p);
            p = NULL;
            break;
        case IPPROTO_ICMP:
            break;
    }
}

#if (INTEL)
/*
 *  tcp_check - TCP checksum written in i386 inline assembly.
 *
 *  th     start of the bytes to be summed (the TCP header).  do_checksum()
 *         in this file clears th->th_sum before calling.
 *  len    number of bytes to sum starting at th.
 *  saddr  source address, taken as-is from the IP header by the caller.
 *  daddr  destination address, taken as-is from the IP header by the caller.
 *
 *  Returns the one's complement of the low 16 bits of the folded sum.
 *
 *  NOTE(review): both asm statements list registers in the clobber section
 *  that are also used as operands (bx/cx/dx/si), and the second statement
 *  modifies ecx and esi although they are declared as inputs only.  GCC's
 *  Extended Asm documentation says clobbers must not overlap operands, so
 *  confirm this still builds and behaves with the compiler in use.
 */
u_short
tcp_check(struct tcphdr *th, int len, u_long saddr, u_long daddr)
{
    u_long sum;

    /*
     *  Stage 1: start the 32-bit sum from the address/protocol/length words.
     *  ebx = daddr + saddr + edx, with each carry added back in.  edx holds
     *  the byte-swapped length in its upper 16 bits and IPPROTO_TCP * 256 in
     *  its lower 16 bits.
     */
    __asm__("\taddl %%ecx, %%ebx\n\t"
        "adcl %%edx, %%ebx\n\t"
        "adcl $0, %%ebx"
        : "=b"(sum)
        : "0"(daddr), "c"(saddr), "d"((ntohs(len) << 16) + IPPROTO_TCP * 256)
        : "bx", "cx", "dx" );

    /*
     *  Stage 2: add the len bytes at th to the running sum in ebx.
     *  The length is kept in edx while ecx is used as the loop counter.
     */
    __asm__("\tmovl %%ecx, %%edx\n\t"
            "cld\n\t"
            "cmpl $32, %%ecx\n\t"
            "jb 2f\n\t"
            /* len / 32 iterations, each adding eight dwords with carry */
            "shrl $5, %%ecx\n\t"
            "clc\n"
            "1:\t"
            "lodsl\n\t"
            "adcl %%eax, %%ebx\n\t"
            "lodsl\n\t"
            "adcl %%eax, %%ebx\n\t"
            "lodsl\n\t"
            "adcl %%eax, %%ebx\n\t"
            "lodsl\n\t"
            "adcl %%eax, %%ebx\n\t"
            "lodsl\n\t"
            "adcl %%eax, %%ebx\n\t"
            "lodsl\n\t"
            "adcl %%eax, %%ebx\n\t"
            "lodsl\n\t"
            "adcl %%eax, %%ebx\n\t"
            "lodsl\n\t"
            "adcl %%eax, %%ebx\n\t"
            "loop 1b\n\t"
            "adcl $0, %%ebx\n\t"
            "movl %%edx, %%ecx\n"
            "2:\t"
            /* (len & 28) / 4 remaining whole dwords, one per iteration */
            "andl $28, %%ecx\n\t"
            "je 4f\n\t"
            "shrl $2, %%ecx\n\t"
            "clc\n"
            "3:\t"
            "lodsl\n\t"
            "adcl %%eax, %%ebx\n\t"
            "loop 3b\n\t"
            "adcl $0, %%ebx\n"
            "4:\t"
            /* one trailing 16-bit word when bit 1 of len is set */
            "movl $0, %%eax\n\t"
            "testw $2, %%dx\n\t"
            "je 5f\n\t"
            "lodsw\n\t"
            "addl %%eax, %%ebx\n\t"
            "adcl $0, %%ebx\n\t"
            "movw $0, %%ax\n"
            "5:\t"
            /* one trailing byte when bit 0 of len is set */
            "test $1, %%edx\n\t"
            "je 6f\n\t"
            "lodsb\n\t"
            "addl %%eax, %%ebx\n\t"
            "adcl $0, %%ebx\n"
            "6:\t"
            /* fold the upper 16 bits of ebx into bx, adding the carry back */
            "movl %%ebx, %%eax\n\t"
            "shrl $16, %%eax\n\t"
            "addw %%ax, %%bx\n\t"
            "adcw $0, %%bx"
        : "=b"(sum)
        : "0"(sum), "c"(len), "S"(th)
        : "ax", "bx", "cx", "dx", "si" );

    /* Only the low 16 bits are meaningful after the fold. */
    return ((~sum) & 0xffff);
}


/*
 *  ip_check - Internet checksum written in x86 inline assembly.
 *
 *  buff  start of the data to sum.
 *  len   number of bytes to sum.
 *
 *  Returns the one's complement of the 16-bit one's complement sum of the
 *  data.  For len <= 0 nothing is read and 0xffff is returned.
 *
 *  The data is consumed in three stages, each advancing `buff':
 *  all whole 32-bit words, then one 16-bit word if bit 1 of len is set,
 *  then one byte if bit 0 of len is set.  Each stage must run exactly
 *  once; running a tail stage twice would read past the end of the buffer
 *  and corrupt the sum.
 *
 *  Registers that the asm modifies are expressed as output operands tied
 *  to their inputs rather than as clobbers, because an operand register
 *  may not also appear in the clobber list.  "memory" tells the compiler
 *  that the asm reads the buffer, "cc" that it changes the flags.
 */
u_short
ip_check(u_short *buff, int len)
{
    u_long sum = 0;

    if (len > 3)
    {
        /* `loop' counts ecx down to zero, so it is an in/out operand. */
        u_long count = (u_long)(len >> 2);

        __asm__("clc\n"
        "1:\t"
        "lodsl\n\t"
        "adcl %%eax, %%ebx\n\t"
        "loop 1b\n\t"
        "adcl $0, %%ebx\n\t"
        /* fold the upper 16 bits of ebx into bx, adding the carry back */
        "movl %%ebx, %%eax\n\t"
        "shrl $16, %%eax\n\t"
        "addw %%ax, %%bx\n\t"
        "adcw $0, %%bx"
        : "=b" (sum), "=S" (buff), "=c" (count)
        : "0" (sum), "1" (buff), "2" (count)
        : "ax", "cc", "memory");
    }
    if (len & 2)
    {
        __asm__("lodsw\n\t"
        "addw %%ax, %%bx\n\t"
        "adcw $0, %%bx"
        : "=b" (sum), "=S" (buff)
        : "0" (sum), "1" (buff)
        : "ax", "cc", "memory");
    }
    if (len & 1)
    {
        /* ah is cleared so that the lone byte is added as 0x00NN */
        __asm__("lodsb\n\t"
        "movb $0, %%ah\n\t"
        "addw %%ax, %%bx\n\t"
        "adcw $0, %%bx"
        : "=b" (sum), "=S" (buff)
        : "0" (sum), "1" (buff)
        : "ax", "cc", "memory");
    }
    sum  = ~sum;
    return (sum & 0xffff);
}
#else

/*
 *  ip_check - Internet checksum, portable C version.
 *
 *  addr  start of the data to sum; read as 16-bit words in host order.
 *  len   number of bytes to sum.
 *
 *  Returns the one's complement of the 16-bit one's complement sum of the
 *  data.  For len <= 0 nothing is read and 0xffff is returned.
 */
u_short
ip_check(register u_short *addr, register int len)
{
    register int nleft = len;
    register u_short *w = addr;
    register unsigned long sum = 0;
    u_short last = 0;

    /*
     *  Add sequential 16 bit words to an unsigned accumulator.  Whenever
     *  bit 31 becomes set the carries are folded back into the low 16
     *  bits, so the accumulator cannot overflow however long the buffer
     *  is.  Folding early does not change the final one's complement sum.
     */
    while (nleft > 1)
    {
        sum += *w++;
        nleft -= 2;
        if (sum & 0x80000000UL)
        {
            sum = (sum >> 16) + (sum & 0xffff);
        }
    }

    /*
     *  Mop up an odd byte, if necessary.  It is treated as the first byte
     *  of a 16 bit word whose second byte is zero, so it is stored into
     *  the first byte of `last' instead of being added as a number; the
     *  latter is only right on little-endian hosts.
     */
    if (nleft == 1)
    {
        *(u_char *)&last = *(u_char *)w;
        sum += last;
    }

    /*
     *  Add back carry outs from the upper bits until none remain.
     */
    while (sum >> 16)
    {
        sum = (sum >> 16) + (sum & 0xffff);
    }

    return ((u_short)(~sum & 0xffff));
}
#endif  /* INTEL */

/* EOF */
