#include "hexdump.h"
#include "crc32.h"

#include "CipeBlowfishEncryptor.h"
#include "CipeIdeaEncryptor.h"
#include "CipeNullEncryptor.h"
#include "CipeEncryptor.h"
#include "CipeSocketIO.h"
#include "CipeGenericIO.h"

//========================================================================================
//
//========================================================================================
// Builds a peer endpoint: copies the MAC, records the local and peer UDP
// addresses, opens and binds the local UDP socket and selects the encryption
// engine named by the peer configuration.
//
// p_Peer - peer configuration; a reference to it is kept in m_Peer.
//
// Throws CipeSocketIOException when the socket cannot be created or bound, or
// when no encryptor could be instantiated. A socket that was already opened is
// closed before the exception is thrown.
CipeSocketIO::CipeSocketIO (CipePeer &p_Peer) : m_Peer (p_Peer), m_RequestKey (FALSE), m_GenerateKey (FALSE), m_TickCount (0), m_PingTicks (0), m_RequestPing (FALSE)
   {
    // Both address structures start out zero-filled so that sin_zero never
    // carries stack garbage into bind() or the later send calls.
    SOCKADDR_IN l_LocalInfo = {0};
    SOCKADDR_IN l_PeerInfo = {0};

    CopyMAC();

    m_PreferredProtocol = htons (0x0800); // EtherType 0x0800 (IPv4), kept in network byte order

    // NOTE(review): the ports are stored without htons(); the bind failure log
    // applies ntohs(), so LocalPort()/PeerPort() presumably return network order - confirm.
    l_LocalInfo.sin_addr.s_addr = GetHostAddress (m_Peer.LocalIP());
    l_LocalInfo.sin_port = m_Peer.LocalPort();
    l_LocalInfo.sin_family = AF_INET;

    l_PeerInfo.sin_addr.s_addr = GetHostAddress (m_Peer.PeerIP());
    l_PeerInfo.sin_port = m_Peer.PeerPort();
    l_PeerInfo.sin_family = AF_INET;
    m_PeerInfo = l_PeerInfo;

    // No job handler or TAP object is attached until Enjoin() is called
    m_TaskObject = m_TapObject = 0;

    m_Socket = socket (AF_INET, SOCK_DGRAM, 0);

    if (m_Socket == SOCKET_ERROR)
       {
        DbgPrint ("[%s] can't create UDP socket\n", Name().c_str());
        throw CipeSocketIOException();
       }

    if (bind (m_Socket, (LPSOCKADDR) &l_LocalInfo, sizeof (struct sockaddr_in)))
       {
        DbgPrint ("[%s] Can't bind to port [%d]\n", Name().c_str(), ntohs (l_LocalInfo.sin_port));
        closesocket (m_Socket);
        throw CipeSocketIOException();
       }

    //================================================================
    // If any other algorithms are to be used, subclass CipeEncryptor
    // and conform to the interface, then instantiate the desired one
    // here. Any encryption type other than "BLOWFISH" or "IDEA" falls
    // back to the null encryptor. An established encryption engine
    // should "always" encrypt and not, itself, be conditionally
    // invokable.
    //================================================================
    if (m_Peer.EncryptionType() == "BLOWFISH")
       m_Encryptor = new CipeBlowfishEncryptor();
    else if (m_Peer.EncryptionType() == "IDEA")
       m_Encryptor = new CipeIdeaEncryptor();
    else
       m_Encryptor = new CipeNullEncryptor();

    // Kept for toolchains whose operator new reports failure by returning a null pointer
    if (! m_Encryptor)
       {
        DbgPrint ("[%s] Couldn't instantiate a %s encryptor\n", Name().c_str(), m_Peer.EncryptionType().c_str());
        closesocket (m_Socket);
        throw CipeSocketIOException();
       }

    DbgPrint ("[%s] Is using a %s encryptor\n", Name().c_str(), m_Peer.EncryptionType().c_str());
    DbgPrint ("[%s] peer started on adapter [%s]\n", Name().c_str(), m_Peer.Adapter().DisplayName().c_str());

    m_Encryptor->InstallStaticKey (m_Peer.TextKey());
   }

// Tears the endpoint down: hands the peer's shutdown script to the job handler
// (when one was enjoined), closes the sockets and releases the encryptor that
// the constructor allocated.
CipeSocketIO::~CipeSocketIO()
   {
    if (m_TaskObject)
       {
        m_TaskObject->Send (m_Peer.ShutdownScript());
       }

    if (m_Socket != SOCKET_ERROR)
       {
        closesocket (m_Socket);
        m_Socket = SOCKET_ERROR;
       }

    // NOTE(review): nothing in the visible source assigns m_SendSocket - confirm
    // it is initialised elsewhere before this comparison is trusted.
    if (m_SendSocket != SOCKET_ERROR)
       {
        closesocket (m_SendSocket);
        m_SendSocket = SOCKET_ERROR;
       }

    // The constructor creates the encryptor with new and nothing visible frees it.
    // NOTE(review): assumes CipeEncryptor declares a virtual destructor and that
    // no other object owns this pointer - confirm against the class declarations.
    delete m_Encryptor;
    m_Encryptor = 0;
   }

//========================================================================================
//
//========================================================================================
void CipeSocketIO::RequestAsyncReceive() throw (CipeSocketIOException)
   {
    if (! CheckForShutdownIndication())
       {
        m_PeerLen = sizeof (m_PeerInfo);
        DWORD l_Flags = 0;
        WSABUF l_Desc;

        l_Desc.buf = (char *) m_Buffer.buffer;
        l_Desc.len = m_Buffer.length = DescriptorSize();

        do
           {
            while (m_Synchronous = (WSARecvFrom (m_Socket, &l_Desc, 1, &m_Buffer.length, &l_Flags, (struct sockaddr *) &m_PeerInfo, &m_PeerLen, &m_Overlapped, NULL) != SOCKET_ERROR))
               {
                CompleteAsyncReceive();
               }
           }
        while (WSAGetLastError() == WSAECONNRESET);

        if (WSAGetLastError() != WSA_IO_PENDING)
           {
            DbgPrint ("[%s] failed on pending read attempt\n", Name().c_str());
           }

        if (m_GenerateKey)
           {
            PerformKeyExchange();
           }

        if (m_RequestKey)
           {
            RequestKeyExchange();
           }

        if (m_RequestPing)
           {
            RequestPing();
           }
       }
   }

// Processes the datagram currently held in m_Buffer: updates the key-lifetime
// bookkeeping, decrypts the payload and hands the result to HandlePeerEvent().
//
// A packet timeout of 0 disables the packet-count triggers, in the same way
// that a key timeout of 0 disables the lifetime trigger; this also avoids a
// modulo by zero.
void CipeSocketIO::CompleteAsyncReceive()
   {
    unsigned long l_KeyTimeout = m_Peer.KeyTimeout() * 1000; // compared against GetTickCount() deltas, i.e. milliseconds
    unsigned long l_PacketTimeout = m_Peer.PacketTimeout();
    unsigned long l_TickCount = GetTickCount();
    CipePacketDescriptor l_Packet;

    //------------------------------------------------------------
    // Remember, this routine artificially extends l_Packet by the
    // size of the (still unformatted) MAC info before dispatching
    //------------------------------------------------------------
    if (! m_Synchronous)
       {
        // Overlapped completion: the byte count is taken from the OVERLAPPED structure
        m_Buffer.length = m_Overlapped.InternalHigh;
       }

    if (l_TickCount < m_TickCount) // Counter wraps around every 49.7 days !
       {
        m_TickCount = l_TickCount;
       }

    if (l_PacketTimeout && (m_Tx % l_PacketTimeout) == 0 && m_Tx > 1)
       {
        DbgPrint ("[%s] Maximal tx count reached. Forcing key exchange\n", m_Peer.Name().c_str());
        m_GenerateKey = TRUE;
       }

    if (l_PacketTimeout && (m_Rx % l_PacketTimeout) == 0 && m_Rx > 1)
       {
        DbgPrint ("[%s] Maximal rx count reached. Requesting key exchange\n", Name().c_str());
        m_RequestKey = TRUE;
       }

    if (l_KeyTimeout && l_TickCount - m_TickCount > l_KeyTimeout)
       {
        DbgPrint ("[%s] Key lifetime expired. Forcing key exchange\n", Name().c_str());
        m_TickCount = l_TickCount;
        m_GenerateKey = TRUE;
        m_RequestKey = TRUE;
       }

    EventReset();

    if (m_Buffer.length > DescriptorSize() || m_Buffer.length < (sizeof (unsigned long) * 3))
       ; // Oversized or undersized datagram: silently dropped
    else if ((l_Packet.length = m_Encryptor->Decrypt (l_Packet.raw, m_Buffer.buffer, m_Buffer.length, KEY_DYNAMIC_DECRYPT, l_Packet.flags)) < sizeof (unsigned char))
       m_RequestKey = TRUE; // Nothing was decrypted: ask the peer for a fresh key
    else if (l_Packet.length += sizeof (l_Packet.mac)) // Account for the MAC info that precedes the payload
       HandlePeerEvent (l_Packet);
   }

// Periodic housekeeping. When an outstanding CT_PING has gone unanswered for
// more than PING_TIMEOUT ticks, a key exchange is requested and another ping is
// scheduled. Otherwise any pending ping, key generation or key request is sent.
void CipeSocketIO::TimeoutEvent()
   {
    // The elapsed time is computed as an unsigned difference so that the test
    // stays correct when GetTickCount() wraps around (every 49.7 days).
    if (m_PingTicks && (unsigned long) (GetTickCount() - m_PingTicks) > PING_TIMEOUT) // CT_PING timeout, try key exchange
       {
        RequestKeyExchange();
        m_RequestPing = TRUE;
        return;
       }

    if (m_RequestPing || (m_PingTicks == 0 && m_Rx == 0)) // No packets received yet, so CT_PING 'em
       {
        RequestPing();
       }

    if (m_GenerateKey)
       {
        PerformKeyExchange();
       }

    if (m_RequestKey)
       {
        RequestKeyExchange();
       }
   }

// Sends a data packet to the peer and counts it in m_Tx.
//
// p_Buffer - packet to send. When its destination MAC matches the broadcast
//            MAC, the IP destination is rewritten to the peer's PTP address
//            and the IP header checksum is recomputed, so the buffer is
//            modified in place.
void CipeSocketIO::Send (CipePacketDescriptor &p_Buffer)
   {
    const bool l_IsBroadcast = MatchingMAC (p_Buffer.mac.destination, BroadcastMAC());

    if (l_IsBroadcast)
       {
        // Redirect the broadcast to the peer, then repair the header checksum
        p_Buffer.ip.destination = GetHostAddress (m_Peer.PeerPTP());
        ComputeHeaderCheckSum (p_Buffer.ip);
       }

    Send (p_Buffer, NK_DATA);

    m_Tx++;
   }

// Encrypts a packet and transmits it to the peer over the UDP socket.
//
// p_Buffer - packet descriptor; its length must include the MAC info and must
//            not exceed DescriptorSize(), otherwise the packet is counted as a
//            transmit error and dropped.
// p_Type   - packet type handed to the encryptor.
// p_Mode   - key selection handed to the encryptor.
//
// Every failure (bad length, failed encryption, failed send, failed overlapped
// completion) increments m_TxErr; nothing is thrown from here.
void CipeSocketIO::Send (CipePacketDescriptor &p_Buffer, NK_Type p_Type, CipeEncryptionKeyType p_Mode)
   {
    unsigned long l_BytesWritten = 0, l_Flags = 0, l_PeerLen = sizeof (m_PeerInfo);
    UDPBUFFER l_Packet;
    WSABUF l_Desc;

    if (p_Buffer.length < sizeof (p_Buffer.mac) || p_Buffer.length > DescriptorSize())
       {
        ++m_TxErr;
        return;
       }

    //------------------------------------------------------------
    // Remember, this routine must strip the MAC info in p_Buffer:
    // only length - sizeof (mac) bytes are handed to the encryptor
    //------------------------------------------------------------
    l_Desc.len = m_Encryptor->Encrypt (l_Packet, p_Buffer.raw, p_Buffer.length - sizeof (p_Buffer.mac), p_Mode, p_Type);
    l_Desc.buf = (char *) l_Packet;

    if (l_Desc.len == 0)
       {
        DbgPrint ("[%s] failed encrypt attempt\n", Name().c_str());
        ++m_TxErr;
       }
    else if (WSASendTo (m_Socket, &l_Desc, 1, &l_BytesWritten, 0, (struct sockaddr *) &m_PeerInfo, l_PeerLen, &m_OverlappedEx, NULL) != SOCKET_ERROR)
       ; // Sent without having to wait
    else if (WSAGetLastError() != WSA_IO_PENDING || ! WSAGetOverlappedResult (m_Socket, &m_OverlappedEx, &l_BytesWritten, TRUE, &l_Flags))
       {
        // Either the send was rejected outright, or it was queued and the
        // blocking wait for its completion reported a failure.
        DbgPrint ("[%s] failed on write attempt\n", Name().c_str());
        ++m_TxErr;
       }

    ResetEvent (m_OverlappedEx.hEvent);
   }

void CipeSocketIO::Shutdown()
   {
    if (m_TaskObject)
       {
        string l_Command = string ("arp -d ") + inet_str (Address());
        m_TaskObject->Send (m_Peer.ShutdownScript());
        m_TaskObject->Send (l_Command);
        m_ShutdownIndicated = TRUE;
       }
   }

//========================================================================================
//
//========================================================================================
// Introduces another I/O object to this endpoint.
//
// p_Object - an object named "JOB HANDLER" becomes the task object and is
//            immediately sent the peer's startup script. Any other object that
//            is not itself a CipeSocketIO and whose name equals AssociateName()
//            becomes the TAP object. Everything else is ignored.
void CipeSocketIO::Enjoin (CipeGenericIO &p_Object)
   {
    if (p_Object.Name() == "JOB HANDLER")
       {
        m_TaskObject = &p_Object;
        m_TaskObject->Send (m_Peer.StartupScript());
        return;
       }

    // Another socket endpoint is never accepted as the TAP side
    if (typeid (CipeSocketIO) == typeid (p_Object))
       return;

    if (AssociateName() == p_Object.Name())
       m_TapObject = &p_Object;
   }

//========================================================================================
//
//========================================================================================
// Recomputes the checksum field of an IP header in place.
//
// p_IpInfo - header to update. Its checksum field is cleared first, the header
//            is then summed as 16-bit words with every carry out of bit 15
//            folded back in, and the complement of the sum is stored.
//
// The number of words comes from the low nibble of version_info multiplied by
// sizeof (unsigned long).
// NOTE(review): that is 4 bytes per unit only where unsigned long is 32 bits
// wide, as on the Win32 target - confirm before porting.
void CipeSocketIO::ComputeHeaderCheckSum (CipeIpInfo &p_IpInfo)
   {
    const unsigned long l_WordCount = ((p_IpInfo.version_info & 0x0f) * sizeof (unsigned long)) / sizeof (unsigned short);
    unsigned short *l_Words = (unsigned short *) &p_IpInfo;
    unsigned long l_Total = 0;

    p_IpInfo.checksum = 0; // The field must read as zero while the header is summed

    for (unsigned long l_Pos = 0; l_Pos < l_WordCount; ++l_Pos)
       {
        l_Total += ntohs (l_Words [l_Pos]);

        if (l_Total & 0x10000) // End-around carry
           l_Total = (l_Total & 0xffff) + 1;
       }

    p_IpInfo.checksum = htons (~l_Total);
   }

//========================================================================================
//
//========================================================================================
// Dispatches a decrypted packet received from the peer.
//
// p_Buffer - decrypted packet whose length already includes the MAC info.
//            Packets whose flags equal 2 are dispatched on nk.type; everything
//            else is treated as data. Data packets have their MAC source and
//            type filled in here (so the buffer is modified) and are forwarded
//            to the TAP object.
//
// The CT_* entries are real case labels. Without the case keyword they would
// only be goto labels and their handlers would be unreachable from the switch.
void CipeSocketIO::HandlePeerEvent (CipePacketDescriptor &p_Buffer)
   {
    //====================================================================================================
    // Do decode stuff here, forward decrypted data packet to TAP device. Don't forget to format MAC info !!!
    // Also remember that the incoming packet has had its length extended to include the MAC info
    //====================================================================================================

    unsigned long l_KeyDataLength = m_Encryptor->KeyDataSize();
    CipePacketDescriptor l_Datagram;

    // NOTE(review): the literal 2 presumably corresponds to NK_KEY_EXCHANGE - confirm
    switch (p_Buffer.flags == 2 ? (NK_Type) p_Buffer.nk.type : NK_DATA)
       {
        case NK_DATA:
           // DbgPrint ("[%s] received data packet of length=%ld\n", Name().c_str(), p_Buffer.length);
           memcpy (p_Buffer.mac.source, MAC(), sizeof (MACADDR)); // Ensures that source MAC is the peer's
           p_Buffer.mac.type = m_PreferredProtocol;

           // The TAP object stays 0 until Enjoin() supplies it; until then the packet is dropped
           if (m_TapObject)
              m_TapObject->Send (p_Buffer);

           m_Rx++;
           break;

        case NK_IND:
           // The peer appends the CRC of its new key right behind the key data
           l_Datagram.nk.ack.crc = *((unsigned long *) (p_Buffer.nk.ind.keydata + l_KeyDataLength));

           if (htonl (crc32 (p_Buffer.nk.ind.keydata, l_KeyDataLength)) == l_Datagram.nk.ack.crc)
              {
               DbgPrint ("[%s] NK_IND: Using peer's new key for decryption. Sending NK_ACK. CRC=%lx\n", Name().c_str(), l_Datagram.nk.ack.crc);
               m_Encryptor->InstallKey (p_Buffer.nk.ind.keydata, KEY_DYNAMIC_DECRYPT);
               l_Datagram.length = sizeof (l_Datagram.mac) + sizeof (l_Datagram.nk.type) + sizeof (l_Datagram.nk.ack);
               l_Datagram.nk.type = NK_ACK;
               Send (l_Datagram, NK_KEY_EXCHANGE, KEY_STATIC_PRIMARY);
              }

           break;

        case NK_REQ:
           DbgPrint ("[%s] NK_REQ: Received request from peer to generate new sending key. CRC=%lx KEYLEN=%ld\n", Name().c_str(), m_CRC, l_KeyDataLength);
           m_GenerateKey = TRUE;
           break;

        case NK_ACK:
           // The peer echoes the CRC of the key offered in PerformKeyExchange()
           if (p_Buffer.nk.ack.crc == m_CRC)
              {
               DbgPrint ("[%s] NK_ACK: Installing new encryption key. CRC=%lx\n", Name().c_str(), p_Buffer.nk.ack.crc);
               m_Encryptor->InstallKey (KEY_DYNAMIC_ENCRYPT);
              }
           else
              {
               DbgPrint ("[%s] NK_ACK: Bad CRC of my new key returned from peer CRC=%lx. Rejecting key\n", Name().c_str(), p_Buffer.nk.ack.crc);
              }

           break;

        case CT_DUMMY:
           break;

        case CT_DEBUG:
           DbgPrint ("[%s] Peer debug message received\n", Name().c_str());
           break;

        case CT_PING:
           l_Datagram.nk.type = CT_PONG;
           l_Datagram.length = sizeof (l_Datagram.mac) + sizeof (l_Datagram.nk.type);
           DbgPrint ("[%s] CT_PING message received\n", Name().c_str());
           Send (l_Datagram, NK_KEY_EXCHANGE, KEY_DYNAMIC_ENCRYPT);
           break;

        case CT_PONG:
           DbgPrint ("[%s] CT_PONG message received\n", Name().c_str());
           m_RequestPing = FALSE;
           m_PingTicks = 0; // No ping is outstanding any more
           break;

        case CT_KILL:
           DbgPrint ("[%s] remote is shutting down\n", Name().c_str());
           break;

        default:
           break;
       }
   }

void CipeSocketIO::RequestKeyExchange()
   {
    CipePacketDescriptor l_Datagram;

    m_RequestKey = FALSE;
    l_Datagram.nk.type = NK_REQ;
    l_Datagram.length = sizeof (l_Datagram.mac) + sizeof (l_Datagram.nk.type);

    Send (l_Datagram, NK_KEY_EXCHANGE, KEY_STATIC_PRIMARY);
   }


void CipeSocketIO::PerformKeyExchange()
   {
    unsigned char *l_KeyData = m_Encryptor->GenerateDynamicKeyData (KEY_DYNAMIC_ENCRYPT);
    unsigned long l_KeyDataLength = m_Encryptor->KeyDataSize();
    CipePacketDescriptor l_Datagram;

    m_GenerateKey = FALSE;
    m_TickCount = GetTickCount();
    m_CRC = htonl (crc32 (l_KeyData, l_KeyDataLength));

    memcpy (l_Datagram.nk.ind.keydata, l_KeyData, l_KeyDataLength);
    memcpy (l_Datagram.nk.ind.keydata + l_KeyDataLength, (unsigned char *) &m_CRC, sizeof (m_CRC));

    l_Datagram.nk.type = NK_IND;

    l_Datagram.length =
       (
        sizeof (l_Datagram.mac) +
        sizeof (l_Datagram.nk.type) +
        sizeof (l_Datagram.nk.ind) +
        l_KeyDataLength +
        sizeof (l_Datagram.nk.ack.crc)
       );

    Send (l_Datagram, NK_KEY_EXCHANGE, KEY_STATIC_PRIMARY);
   }

void CipeSocketIO::RequestPing()
   {
    CipePacketDescriptor l_Datagram;

    m_RequestPing = FALSE;

    l_Datagram.nk.type = CT_PING;
    l_Datagram.length = sizeof (l_Datagram.mac) + sizeof (l_Datagram.nk.type);

    DbgPrint ("[%s] Sending CT_PING message\n", Name().c_str());

    Send (l_Datagram, NK_KEY_EXCHANGE, KEY_STATIC_PRIMARY);

    m_PingTicks = GetTickCount();
   }

// Converts a textual address to the value returned by inet_addr().
//
// p_Address - dotted-quad text or a host name.
//
// Returns the inet_addr() result when the text parses as an address. Otherwise
// the name is looked up with gethostbyname() and the first entry of its address
// list is returned. INADDR_NONE is returned when the lookup fails as well.
unsigned long CipeSocketIO::GetHostAddress (string &p_Address)
   {
    const char *l_Text = p_Address.c_str();
    unsigned long l_Address = inet_addr (l_Text);

    if (l_Address == INADDR_NONE)
       {
        // Not a literal address: fall back to a name lookup
        struct hostent *l_Entry = gethostbyname (l_Text);

        if (l_Entry != NULL)
           l_Address = *((unsigned long *) l_Entry->h_addr_list [0]);
       }

   //  DbgPrint ("[%s] resolved [%s] to [%lx]\n", Name().c_str(), p_Address.c_str(), l_Address);

    return l_Address;
   }

//========================================================================================
//                                   End of Source
//========================================================================================
