#include "general.h"

/* This function allocates a new Connection structure and initialises
   it to the ST_NOTUSED
 */
void Connection_alloc(void) {
   
   if(debug)
      err_msg("Connection_alloc: allocating new structure, conn_size=%d", conn_size);

   if(conn == NULL)
      conn = calloc(1, sizeof(Connection));
   else
      conn = realloc(conn, (conn_size + 1) * sizeof(Connection));

   if(conn == NULL)
      err_sys("Connection_alloc: Memory allocation error");

   conn[conn_size].rbuf = malloc(sndbufsz);
   conn[conn_size].lbuf = malloc(sndbufsz);

   if((conn[conn_size].lbuf == NULL) || (conn[conn_size].rbuf == NULL))
      err_sys("Connection_alloc: Memory allocation error");

   conn[conn_size].status = ST_NOTUSED;
   conn_size++;
}


/* Connection_add - claim a Connection slot for a newly accepted socket.
 *
 * Finds the first conn[] entry whose status is ST_NOTUSED.  If there is
 * none, Connection_alloc() appends one.  The chosen slot gets:
 *   - localfd, which is switched to non-blocking mode;
 *   - time = now + timeout, used by the timeout handling.
 * maxfd is then raised to cover localfd if needed.
 *
 * Once the table holds max_simul_con entries and no other slot is free,
 * lfd is removed from rdset.  No further connections are accepted until
 * Connection_del() puts lfd back.
 *
 * NOTE(review): the slot's status is left as ST_NOTUSED.  Presumably the
 * caller sets it right away; confirm this, because a second call before
 * that would hand out the same slot.
 *
 * Returns the index of the slot in conn[].
 */
int Connection_add(int localfd) {
   int i, j;
   int flags;

   if(debug)
      err_msg("Connection_add: Adding Connection structure for fd=%d", localfd);

   for(i = 0; i < conn_size; i++)
      if(conn[i].status == ST_NOTUSED)
         break;

   /* No free slot (or an empty table): Connection_alloc() appends a
      ST_NOTUSED entry at index conn_size, which is exactly i here. */
   if(i == conn_size)
      Connection_alloc();

   conn[i].localfd = localfd;
   if((flags = fcntl(localfd, F_GETFL, 0)) == -1)
      err_sys("Connection_add: fcntl F_GETFL error");

   if(fcntl(localfd, F_SETFL, flags | O_NONBLOCK) == -1)
      err_sys("Connection_add: fcntl F_SETFL error");

   conn[i].time = time(NULL) + timeout;
   maxfd = MAX(maxfd, conn[i].localfd);

   /* Stop accepting only when the table is at its limit AND this was the
      last free slot; otherwise accepting would stall while slots remain. */
   if(conn_size == max_simul_con) {
      for(j = 0; j < conn_size; j++)
         if(j != i && conn[j].status == ST_NOTUSED)
            break;
      if(j == conn_size)
         FD_CLR(lfd, &rdset);	/* Accept no more connections until free slot */
   }

   return i;
}

/* This function deletes a Connection structure.
   It closes the sockets and sets the status to ST_NOTUSED.
 */

void Connection_del(Connection *con) {

   if(debug)
      err_msg("Connection_del: Deleting Connection structure for localfd=%d, remotefd=%d", con->localfd, con->remotefd);

   if(con->localfd) {
      close(con->localfd);
      FD_CLR(con->localfd, &rdset);
      FD_CLR(con->localfd, &wrset);
   }

   if(con->remotefd) {
      close(con->remotefd);
      FD_CLR(con->remotefd, &rdset);
      FD_CLR(con->remotefd, &wrset);
   }

   memset(&con->e, 0, sizeof(ArcfourContext));
   memset(&con->d, 0, sizeof(ArcfourContext));
   memset(&con->e2, 0, sizeof(ArcfourContext));
   memset(&con->d2, 0, sizeof(ArcfourContext));

   if(con->dh)
      free(con->dh);

   con->dh = NULL;
   con->status = ST_NOTUSED;
   memset(&con->init, 0, sizeof(conninit));
   con->localfd = 0;
   con->remotefd = 0;
   con->lpos = 0;
   con->rpos = 0;

   FD_SET(lfd, &rdset);		/* We just deleted a Connection structure there must be one avaiable to accept
				   a new connection on.
				 */
}

/* Connection_process - read the conninit header and start the outgoing
 * connection.
 *
 * First visit (CNTRL_ENDPO not set in con->init.control):
 *   1. sizeof(conninit) bytes are gathered from localfd into lbuf through
 *      Readdata().  The function returns early until the whole header has
 *      arrived.
 *   2. The header is removed from lbuf; any bytes after it are kept.
 *   3. Its magic is checked against MAGIC; a mismatch deletes the
 *      connection.
 *   4. The header is stored in con->init, with control converted to host
 *      byte order.
 *   5. If CNTRL_ENDPO is set, the endpoint DH exchange is started and the
 *      function returns without connecting.
 * Second visit as endpoint: the stored con->init is reused.
 *
 * Either way, a non-blocking TCP socket is then connected to
 * init.targetip:init.targetport.  Both values are copied straight into the
 * sockaddr_in, i.e. they are used in network byte order.  The outcomes are:
 *   - immediate success: handed to CheckConnection();
 *   - EINPROGRESS: the connection enters ST_INPROGR and the socket is
 *     watched in wrset;
 *   - any other failure: the connection is deleted.
 * maxfd is raised to cover the new socket.
 */
void Connection_process(Connection *con) {

   int n, sd, max;
   long flags;
   conninit init;
   struct sockaddr_in sin;

   if(con->init.control & CNTRL_ENDPO) {	/* 2nd time visit as endpoint */
      FD_CLR(con->localfd, &rdset);
      init = con->init;
   } else {
      n = con->lpos;
      max = sizeof(conninit) - n;

      Readdata(con, FROM_LOCAL, max);

      if(con->lpos < sizeof(conninit)) {	/* header incomplete, wait for more */
         if(debug)
            err_msg("Connection_process: Read in %d bytes of conninit structure from localfd=%d", con->lpos - n, con->localfd);
         FD_CLR(con->remotefd, &wrset);	/* Readdata() put con->remotefd in wrset */
         return;
      }

      FD_CLR(con->localfd, &rdset);
      FD_CLR(con->remotefd, &wrset);	/* Readdata() put con->remotefd in wrset */

      if(debug)
         err_msg("Connection_process: Read in conninit structure, processing for localfd=%d", con->localfd);

      /* Take the header out of lbuf, keeping any bytes that follow it. */
      memcpy(&init, con->lbuf, sizeof(conninit));
      memmove(con->lbuf, con->lbuf + sizeof(conninit), con->lpos - sizeof(conninit));
      con->lpos -= sizeof(conninit);

      if(ntohl(init.magic) != MAGIC) {
         err_msg("Connection_process: init.magic = 0x%x", init.magic);
         Connection_del(con);
         return;
      }

      con->init = init;
      con->init.control = ntohs(con->init.control);

      if(con->init.control & CNTRL_ENDPO) {	/* Endpoint: do the key exchange first */
         Connection_DH_init(con, AS_ENDPOINT);
         return;
      }
   }

   if((sd = socket(AF_INET, SOCK_STREAM, 0)) == -1) {
      /* e.g. EMFILE/ENFILE under load: drop this connection, keep serving */
      err_ret("Connection_process: socket error");
      Connection_del(con);
      return;
   }

   if((flags = fcntl(sd, F_GETFL, 0)) == -1)
      err_sys("Connection_process: fcntl F_GETFL error");

   if(fcntl(sd, F_SETFL, flags | O_NONBLOCK) == -1)
      err_sys("Connection_process: fcntl F_SETFL error");

   memset(&sin, 0, sizeof(sin));
   sin.sin_family = AF_INET;
   sin.sin_addr.s_addr = init.targetip;
   sin.sin_port = init.targetport;

   if(debug)
      err_msg("Connection_process: Connection request for localfd=%d to %s:%hu", con->localfd,
              inet_ntoa(sin.sin_addr), ntohs(sin.sin_port));

   if((n = connect(sd, (struct sockaddr *)&sin, sizeof(sin))) == -1) {
      if(errno != EINPROGRESS) {
         err_ret("Connection_process: connect didnt return EINPROGRESS");
         /* sd is not stored in con->remotefd yet, so Connection_del()
            would not close it: close it here to avoid leaking it. */
         close(sd);
         Connection_del(con);
         return;
      }
   } else if(n == 0) {	/* Connection succeeded */
      con->remotefd = sd;
      maxfd = MAX(maxfd, sd);
      CheckConnection(con);
      return;
   }

   if(debug)
      err_msg("Connection_process: Processed localfd=%d, remotefd=%d", con->localfd, sd);

   con->remotefd = sd;
   maxfd = MAX(maxfd, sd);
   con->status = ST_INPROGR;
   FD_SET(sd, &wrset);
}

/* Connection_DH_init - begin a Diffie-Hellman exchange on a connection.
 *
 * InitDH() creates the DH state, signed with ServerPrivateKey, and the
 * result is stored in con->dh.  If InitDH() fails, the connection is
 * deleted.  The data to send is placed at the start of con->rbuf: the
 * public value (DH_VAL_LEN bytes) followed by its signature
 * (DH_SIGNATURE_LEN bytes).  `how` selects the peer and the next state:
 *   AS_NORMAL   - ST_DHSEND1, localfd is watched for writing.
 *   AS_CHAIN    - ST_DHSEND2, remotefd is watched for writing.
 *   AS_ENDPOINT - ST_DHSEND3.  rpos is set to DH_DATA_LEN, the buffer is
 *                 encrypted in place with con->e, and localfd is watched
 *                 for writing.
 * Any other value leaves the status and the fd sets untouched.
 */
void Connection_DH_init(Connection *con, int how) {
   R_RANDOM_STRUCT rnd;

   InitRandomStruct(&rnd);

   con->dh = InitDH(&rnd, 1, &ServerPrivateKey);
   if(con->dh == NULL) {
      err_msg("Connection_DH_init: InitDH() failed");
      Connection_del(con);
      return;
   }

   memcpy(con->rbuf, con->dh->publicValue, DH_VAL_LEN);
   memcpy(con->rbuf + DH_VAL_LEN, con->dh->signature, DH_SIGNATURE_LEN);

   switch(how) {
   case AS_NORMAL:
      con->status = ST_DHSEND1;
      FD_SET(con->localfd, &wrset);
      if(debug)
         err_msg("Connection_DH_init: Generated DH data for localfd=%d", con->localfd);
      break;

   case AS_CHAIN:
      con->status = ST_DHSEND2;
      FD_SET(con->remotefd, &wrset);
      if(debug)
         err_msg("Connection_DH_init: Generated DH data for remotefd=%d", con->remotefd);
      break;

   case AS_ENDPOINT:
      con->status = ST_DHSEND3;
      con->rpos = DH_DATA_LEN;
      arcfour_encrypt(&con->e, con->rbuf, con->rbuf, DH_DATA_LEN);
      FD_SET(con->localfd, &wrset);
      if(debug)
         err_msg("Connection_DH_init: Generated DH data as endpoint for localfd=%d", con->localfd);
      break;

   default:
      break;
   }
}

/* Connection_DH_senddata - push the DH data held in con->rbuf to one peer.
 *
 * `to` selects the destination: TO_LOCAL writes to localfd, any other value
 * writes to remotefd.
 *
 * ST_DHSEND3 (endpoint, to == TO_LOCAL): the already-encrypted buffer is
 * written through Writedata().  Once rpos drops to 0, the connection moves
 * to ST_DHRECV3 and localfd is moved from wrset to rdset.
 *
 * Otherwise bytes rbuf[rpos .. DH_DATA_LEN) are written directly:
 *   - When all DH_DATA_LEN bytes are out, rpos is reset and the connection
 *     moves to ST_DHRECV1 (local) or ST_DHRECV2 (remote).  That fd is then
 *     watched for reading instead of writing.
 *   - A partial write just advances rpos.
 *
 * A transient failure on the non-blocking socket (EINTR, EAGAIN or
 * EWOULDBLOCK) changes nothing, so the write is retried on the next
 * writable notification.  Any other error, or a zero-length write, deletes
 * the connection.
 */
void Connection_DH_senddata(Connection *con, int to) {
   int n;
   const char *name;
   int fd;

   if(to == TO_LOCAL) {
      name = "localfd";
      fd = con->localfd;
      if(con->status == ST_DHSEND3) {
         Writedata(con, TO_LOCAL);
         FD_CLR(con->remotefd, &rdset);	/* Was set by Writedata() */
         if(!con->rpos) {
            if(debug)
               err_msg("Connection_DH_senddata: Send out all server DH data as endpoint to %s=%d", name, fd);
            con->status = ST_DHRECV3;
            FD_CLR(con->localfd, &wrset);
            FD_SET(con->localfd, &rdset);
         }
         return;
      }
   } else {
      name = "remotefd";
      fd = con->remotefd;
   }

   n = write(fd, con->rbuf + con->rpos, DH_DATA_LEN - con->rpos);

   switch(n) {
      case -1:
         if(errno == EINTR || errno == EAGAIN || errno == EWOULDBLOCK)
            break;	/* transient: fd stays in wrset, retry on next select() */
         if(debug)
            err_ret("Connection_DH_senddata: write error to %s=%d", name, fd);
         Connection_del(con);
         break;
      case 0:
         /* write() sets no errno here, so report without errno text */
         if(debug)
            err_msg("Connection_DH_senddata: write to %s=%d returned zero", name, fd);
         Connection_del(con);
         break;
      default:
         if((con->rpos + n) == DH_DATA_LEN) { /* All DH data sent, now read the client DH data */
            if(to == TO_LOCAL) {
               con->status = ST_DHRECV1;
               FD_CLR(con->localfd, &wrset);
               FD_SET(con->localfd, &rdset);
            } else {
               con->status = ST_DHRECV2;
               FD_CLR(con->remotefd, &wrset);
               FD_SET(con->remotefd, &rdset);
            }
            con->rpos = 0;
            if(debug)
               err_msg("Connection_DH_senddata: Send out all server DH data to %s=%d", name, fd);
         } else {
            /* n is the number of bytes written by this call */
            if(debug)
               err_msg("Connection_DH_senddata: Send out %d bytes of DH data to %s=%d", n, name, fd);
            con->rpos += n;
         }
   }
}
      
/* Connection_DH_readdata - collect the peer's DH data into con->lbuf.
 *
 * `from` selects the source: FROM_LOCAL reads localfd, any other value
 * reads remotefd.
 *
 * ST_DHRECV3 (endpoint, from == FROM_LOCAL): data is read through
 * Readdata().  Once lpos reaches DH_DATA_LEN, lpos is reset and
 * Connection_DH_process(con, AS_ENDPOINT) runs.
 *
 * Otherwise up to DH_DATA_LEN - lpos bytes are read directly.  When all
 * DH_DATA_LEN bytes have arrived, lpos is reset and Connection_DH_process()
 * runs with AS_NORMAL (local) or AS_CHAIN (remote).
 *
 * A transient failure on the non-blocking socket (EINTR, EAGAIN or
 * EWOULDBLOCK) changes nothing, so the read is retried on the next
 * readable notification.  Any other read error, or EOF, deletes the
 * connection.
 */
void Connection_DH_readdata(Connection *con, int from) {
   int n;
   const char *name;
   int fd;

   if(from == FROM_LOCAL) {
      name = "localfd";
      fd = con->localfd;
      if(con->status == ST_DHRECV3) {
         int max = DH_DATA_LEN - con->lpos;
         Readdata(con, FROM_LOCAL, max);
         FD_CLR(con->remotefd, &wrset);	/* Set by Readdata() */
         if(con->lpos == DH_DATA_LEN) {
            if(debug)
               err_msg("Connection_DH_readdata: Read in all client DH data as endpoint from %s=%d", name, fd);
            con->lpos = 0;
            Connection_DH_process(con, AS_ENDPOINT);
         }
         return;
      }
   } else {
      name = "remotefd";
      fd = con->remotefd;
   }

   n = read(fd, con->lbuf + con->lpos, DH_DATA_LEN - con->lpos);

   switch(n) {
      case -1:
         if(errno == EINTR || errno == EAGAIN || errno == EWOULDBLOCK)
            break;	/* transient: fd stays in rdset, retry on next select() */
         if(debug)
            err_ret("Connection_DH_readdata: read error from %s=%d", name, fd);
         Connection_del(con);
         break;
      case 0:
         /* EOF is not an error, so errno is meaningless: no errno text */
         if(debug)
            err_msg("Connection_DH_readdata: premature end of file from %s=%d", name, fd);
         Connection_del(con);
         break;
      default:
         con->lpos += n;
         if(con->lpos == DH_DATA_LEN) {
            if(debug)
               err_msg("Connection_DH_readdata: Read in all client DH data from %s=%d", name, fd);
            con->lpos = 0;
            if(from == FROM_LOCAL)
               Connection_DH_process(con, AS_NORMAL);
            else
               Connection_DH_process(con, AS_CHAIN);
         } else if(debug) {
            err_msg("Connection_DH_readdata: Read in %d bytes of DH data from %s=%d", n, name, fd);
         }
   }
}

/* Connection_DH_wipe - zero a buffer that held key material.
 * The stores go through a volatile pointer so the compiler cannot drop
 * them as dead stores just before the buffer goes out of scope.
 */
static void Connection_DH_wipe(void *buf, size_t len) {
   volatile unsigned char *p = buf;

   while(len-- > 0)
      *p++ = 0;
}

/* Connection_DH_process - finish a DH exchange and install the session key.
 *
 * ComputeDH() derives the 64-byte shared key from the peer data in
 * con->lbuf and the state in con->dh.  A failure deletes the connection.
 * con->dh is then freed.  The key seeds an arcfour context, and the same
 * state is copied to the matching d/d2 context:
 *   AS_NORMAL, CNTRL_ENDPO set - e2/d2; localfd leaves rdset; then
 *                                Connection_process() is re-entered.
 *   AS_NORMAL otherwise        - e/d; localfd moves from wrset to rdset;
 *                                status ST_WAITINF; timeout restarted.
 *   AS_CHAIN                   - e2/d2; localfd added to rdset; ST_CONNECT.
 *   AS_ENDPOINT                - e2/d2; then Connection_process().
 * The stack copy of the key is wiped on every path, before any re-entry
 * into Connection_process().
 */
void Connection_DH_process(Connection *con, int how) {
   char key[64+1];	/* 512 bit key is exchanged */
   int reenter = 0;	/* run Connection_process() after the key is wiped */

   memset(key, 0, sizeof(key));
   if(ComputeDH(con->lbuf, con->dh, NULL, key) != 0) {
      Connection_DH_wipe(key, sizeof(key));
      err_msg("Connection_DH_process: ComputeDH failed");
      Connection_del(con);
      return;
   }

   if(debug) {
      char *tmp = base64_encode(key, 64);
      err_msg("Connection_DH_process: DH negotiation succesful");
      if(tmp != NULL) {	/* never hand a NULL pointer to %s */
         err_msg("Connection_DH_process: Agreed key = %s", tmp);
         free(tmp);
      }
   }

   free(con->dh);	/* Not needed anymore */
   con->dh = NULL;

   if(how == AS_NORMAL) {
      if(con->init.control & CNTRL_ENDPO) {
         arcfour_init(&con->e2, key, 64);
         con->d2 = con->e2;
         FD_CLR(con->localfd, &rdset);
         reenter = 1;
      } else {
         arcfour_init(&con->e, key, 64);
         con->d = con->e;	/* Saves a bit of processing time */

         FD_CLR(con->localfd, &wrset);
         FD_SET(con->localfd, &rdset);
         con->status = ST_WAITINF;
         con->time = time(NULL) + timeout;
      }
   } else if(how == AS_CHAIN) {
      arcfour_init(&con->e2, key, 64);
      con->d2 = con->e2;

      FD_SET(con->localfd, &rdset);
      con->status = ST_CONNECT;
   } else if(how == AS_ENDPOINT) {
      arcfour_init(&con->e2, key, 64);
      con->d2 = con->e2;
      reenter = 1;
   }

   /* key is an automatic array, so nothing can legitimately use it after
      this function returns; clear it now so the secret does not linger on
      the stack, including during the nested Connection_process() call. */
   Connection_DH_wipe(key, sizeof(key));

   if(reenter)
      Connection_process(con);
}

/* Connection_eof - react to end-of-file on one side of a connection.
 *
 * who == EOF_LOCAL: localfd reached EOF.
 *   - If con->lpos is non-zero, there is still data waiting to be written
 *     out.  remotefd is shut down for reading, only remotefd stays watched
 *     (for writing), and the connection enters ST_FINISHL.
 *   - With nothing pending, the connection is deleted.
 * who == EOF_REMOT: the mirror image, using rpos, localfd and ST_FINISHR.
 * Any other value of who is ignored.
 *
 * A shutdown() failure is only reported; the state change still happens.
 */
void Connection_eof(Connection *con, int who) {

   if(who == EOF_LOCAL) {
      if(con->lpos == 0) {
         if(debug)
            err_msg("Connection_eof: EOF for localfd=%d, no more data to write out", con->localfd);
         Connection_del(con);
         return;
      }

      if(debug)
         err_msg("Connection_eof: EOF for localfd=%d, %d bytes left to write out, entering ST_FINISHL state",
                 con->localfd, con->lpos);

      /* No more data is wanted from the remote side */
      if(shutdown(con->remotefd, SHUT_RD) != 0)
         err_ret("Connection_eof: Shutdown error for remotefd=%d", con->remotefd);

      FD_CLR(con->localfd, &rdset);
      FD_CLR(con->localfd, &wrset);
      FD_CLR(con->remotefd, &rdset);
      FD_SET(con->remotefd, &wrset);
      con->status = ST_FINISHL;
      return;
   }

   if(who != EOF_REMOT)
      return;

   if(con->rpos == 0) {
      if(debug)
         err_msg("Connection_eof: EOF for remotefd=%d, no more data to write out", con->remotefd);
      Connection_del(con);
      return;
   }

   if(debug)
      err_msg("Connection_eof: EOF for remotefd=%d, %d bytes left to write out, entering ST_FINISHR state",
              con->remotefd, con->rpos);

   /* No more data is wanted from the local side */
   if(shutdown(con->localfd, SHUT_RD) != 0)
      err_ret("Connection_eof: Shutdown error for localfd=%d", con->localfd);

   FD_CLR(con->remotefd, &rdset);
   FD_CLR(con->remotefd, &wrset);
   FD_CLR(con->localfd, &rdset);
   FD_SET(con->localfd, &wrset);
   con->status = ST_FINISHR;
}
