#include "general.h"

void Connection_alloc(void) {

   if(debug)
      err_msg("Connection_alloc: allocating new structure, conn_size=%d", conn_size);

   conn = realloc(conn, (conn_size + 1) * sizeof(Connection));

   if(conn == NULL)
      err_sys("Connection_alloc: Memory allocation error");

   memset(&conn[conn_size], 0, sizeof(Connection));

   conn[conn_size].rbuf = malloc(sndbufsz);
   conn[conn_size].lbuf = malloc(sndbufsz);

   if((conn[conn_size].lbuf == NULL) || (conn[conn_size].rbuf == NULL))
      err_sys("Connection_alloc: Memory allocation error");

   conn[conn_size].status = ST_NOTUSED;
   conn_size++;
}

/*
 * Register a newly obtained local socket in the connection table.
 *
 * Reuses the first ST_NOTUSED slot of `conn`; when none is free the table
 * is grown by one slot via Connection_alloc(). localfd is stored in the
 * slot, switched to non-blocking mode, and maxfd is raised to cover it.
 *
 * Returns the index of the slot now holding localfd. The slot's status is
 * not changed here (it is still ST_NOTUSED); presumably the caller advances
 * it -- confirm at call sites.
 */
int Connection_add(int localfd) {
   int i;
   int flags;   /* fcntl() F_GETFL/F_SETFL use int; a long through its
                   variadic argument is not portable */

   if(debug)
      err_msg("Connection_add: Adding Connection structure for fd=%d", localfd);

   if(conn == NULL)
      Connection_alloc();

   for(i = 0; i < conn_size; i++)
      if(conn[i].status == ST_NOTUSED)
         break;

   /* No free slot: Connection_alloc() appends exactly one ST_NOTUSED slot,
      which lands at index i (the old conn_size). */
   if(i == conn_size)
      Connection_alloc();

   conn[i].localfd = localfd;
   if((flags = fcntl(localfd, F_GETFL, 0)) == -1)
      err_sys("Connection_add: fcntl F_GETFL error");

   if(fcntl(localfd, F_SETFL, flags | O_NONBLOCK) == -1)
      err_sys("Connection_add: fcntl F_SETFL error");

   maxfd = MAX(maxfd, conn[i].localfd);
   return i;
}

/*
 * Release a connection slot so Connection_add() can hand it out again.
 *
 * The teardown proceeds in four steps:
 *  - Close localfd and remotefd. A value of 0 is treated as "no socket".
 *  - Remove both descriptors from the rdset/wrset select() sets.
 *  - Zero the four ArcfourContext cipher states and free any DH context
 *    still attached.
 *  - Reset the fd/position bookkeeping and mark the slot ST_NOTUSED.
 *
 * The rbuf/lbuf buffers stay allocated for the slot's next owner.
 */
void Connection_del(Connection *con) {
   int fds[2];
   int k;

   if(debug)
      err_msg("Connection_del: Deleting Connection structure for localfd=%d, remotefd=%d", con->localfd, con->remotefd);

   fds[0] = con->localfd;
   fds[1] = con->remotefd;
   for(k = 0; k < 2; k++) {
      if(fds[k] == 0)
         continue;          /* 0 marks "no descriptor" in this file */
      close(fds[k]);
      FD_CLR(fds[k], &rdset);
      FD_CLR(fds[k], &wrset);
   }

   /* Scrub the cipher states before the slot is recycled. */
   memset(&con->e, 0, sizeof(ArcfourContext));
   memset(&con->d, 0, sizeof(ArcfourContext));
   memset(&con->e2, 0, sizeof(ArcfourContext));
   memset(&con->d2, 0, sizeof(ArcfourContext));

   free(con->dh);           /* free(NULL) is a no-op */
   con->dh = NULL;

   con->localfd = 0;
   con->remotefd = 0;
   con->lpos = 0;
   con->rpos = 0;
   con->status = ST_NOTUSED;
}

/*
 * Open a non-blocking TCP connection for con to the proxy at
 * proxyhost:proxyport. Both values are copied into sockaddr_in as-is, so
 * they are presumably already in network byte order.
 *
 * Outcomes:
 *  - Immediate success: the socket becomes con->remotefd and
 *    CheckConnection(con) is called right away.
 *  - EINPROGRESS: the socket becomes con->remotefd, con enters ST_INPROGR
 *    and the socket is armed in wrset so the select() loop sees completion.
 *  - Any other connect() error: the new socket is closed and the whole
 *    connection is deleted.
 */
void Connection_connect(Connection *con) {
   struct sockaddr_in sin;
   int sd, n;
   int flags;   /* fcntl() F_GETFL/F_SETFL take and return int */

   if((sd = socket(AF_INET, SOCK_STREAM, 0)) == -1)
      err_sys("Connection_process: socket error");

   if((flags = fcntl(sd, F_GETFL, 0)) == -1)
      err_sys("Connection_process: fcntl F_GETFL error");

   if(fcntl(sd, F_SETFL, flags | O_NONBLOCK) == -1)
      err_sys("Connection_process: fcntl F_SETFL error");

   memset(&sin, 0, sizeof(sin));
   sin.sin_family = AF_INET;
   sin.sin_addr = proxyhost;
   sin.sin_port = proxyport;

   if((n = connect(sd, (struct sockaddr *)&sin, sizeof(sin))) == -1) {
      if(errno != EINPROGRESS) {
         err_ret("Connection_process: connect didnt return EINPROGRESS");
         /* sd is not stored in con->remotefd yet, so Connection_del()
            would not close it; close it here to avoid leaking the fd. */
         close(sd);
         Connection_del(con);
         return;
      }
   } else if(n == 0) {  /* Connection succeeded */
      con->remotefd = sd;
      maxfd = MAX(maxfd, sd);
      CheckConnection(con);
      return;
   }

   if(debug)
      err_msg("Connection_process: Processed localfd=%d, remotefd=%d", con->localfd, sd);

   /* Non-blocking connect in progress: completion is reported as
      writability on the socket. */
   con->remotefd = sd;
   maxfd = MAX(maxfd, sd);
   con->status = ST_INPROGR;
   FD_SET(sd, &wrset);
}

/*
 * Start a Diffie-Hellman exchange on con.
 *
 * A fresh DH context is created in con->dh and its public value
 * (DH_VAL_LEN bytes) is staged in con->rbuf. rpos is reset and remotefd is
 * armed in wrset so Connection_DH_senddata() can push DH_DATA_LEN bytes.
 *
 *   with == FOR_NORMAL   -> ST_DHSEND1; rbuf is not encrypted here.
 *   with == FOR_ENDPOINT -> ST_DHSEND2; rbuf is first encrypted in place
 *                           with the existing con->e stream.
 *   any other value      -> DH context created, rbuf filled, nothing armed.
 *
 * If InitDH() fails the connection is deleted.
 */
void Connection_DH_init(Connection *con, int with) {
   R_RANDOM_STRUCT rnd;

   InitRandomStruct(&rnd);

   con->dh = InitDH(&rnd, 0, NULL);
   if(con->dh == NULL) {
      err_msg("Connection_DH_init: InitDH() failed");
      Connection_del(con);
      return;
   }

   memcpy(con->rbuf, con->dh->publicValue, DH_VAL_LEN);

   if((with != FOR_NORMAL) && (with != FOR_ENDPOINT))
      return;

   con->status = (with == FOR_NORMAL) ? ST_DHSEND1 : ST_DHSEND2;
   con->rpos = 0;
   FD_SET(con->remotefd, &wrset);

   if(with == FOR_NORMAL) {
      if(debug)
         err_msg("Connection_DH_init: Generated DH data for remotefd=%d", con->remotefd);
   } else {
      arcfour_encrypt(&con->e, con->rbuf, con->rbuf, DH_DATA_LEN);
      if(debug)
         err_msg("Connection_DH_init: Generated DH data for endpoint, remotefd=%d", con->remotefd);
   }
}

/*
 * Write the staged DH data (con->rbuf[rpos .. DH_DATA_LEN)) to remotefd.
 *
 * Called when remotefd is writable in state ST_DHSEND1 or ST_DHSEND2.
 *
 * Outcomes:
 *  - Partial write: rpos advances.
 *  - Complete write: rpos is reset, remotefd moves from wrset to rdset, and
 *    the state becomes ST_DHRECV1 (from ST_DHSEND1) or ST_DHRECV2
 *    (otherwise).
 *  - Transient "try again" errors on the non-blocking socket: the call
 *    waits for the next writable notification.
 *  - Any other error, or a zero-length write: the connection is deleted.
 */
void Connection_DH_senddata(Connection *con) {
   ssize_t n;

   n = write(con->remotefd, (con->rbuf + con->rpos), (DH_DATA_LEN - con->rpos));

   /* remotefd is non-blocking; these are not fatal, just retry later. */
   if((n == -1) && ((errno == EAGAIN) || (errno == EWOULDBLOCK) || (errno == EINTR)))
      return;

   switch(n) {
      case -1:
         if(debug)
            err_ret("Connection_DH_senddata: write error to remotefd=%d", con->remotefd);
         Connection_del(con);
         break;
      case 0:   /* Not expected for a non-zero count on a socket; treat as failure.
                   errno is not meaningful here, so err_msg rather than err_ret. */
         if(debug)
            err_msg("Connection_DH_senddata: write to remotefd=%d returned zero", con->remotefd);
         Connection_del(con);
         break;
      default:
         if((con->rpos + n) == DH_DATA_LEN) { /* We send all DH data, time to read client DH data */
            con->rpos = 0;
            FD_CLR(con->remotefd, &wrset);
            FD_SET(con->remotefd, &rdset);
            if(con->status == ST_DHSEND1) {
               con->status = ST_DHRECV1;
               if(debug)
                  err_msg("Connection_DH_senddata: Send out all server DH data to remotefd=%d", con->remotefd);
            } else {   /* ST_DHSEND2 */
               con->status = ST_DHRECV2;
               if(debug)
                  err_msg("Connection_DH_senddata: Send out all server DH data to remotefd=%d for endpoint", con->remotefd);
            }
         } else {
            /* Report the bytes written by this call (previously n - rpos). */
            if(debug)
               err_msg("Connection_DH_senddata: Send out %d bytes of DH data to remotefd=%d", (int)n, con->remotefd);
            con->rpos += n;
         }
         break;
   }
}

/*
 * Read the peer's DH data from remotefd into con->rbuf until DH_DATA_LEN
 * bytes have arrived, then hand off to Connection_DH_process().
 *
 * The two states are handled differently:
 *  - ST_DHRECV2 (endpoint exchange through a proxy chain): reading goes
 *    through Readdata(con, FROM_REMOTE, ...). Completion calls
 *    Connection_DH_process(con, FOR_ENDPOINT).
 *  - Otherwise: a raw read() on remotefd. Completion calls
 *    Connection_DH_process(con, FOR_NORMAL).
 *
 * Transient "try again" errors are ignored. Linux select() may report a
 * socket readable even though a read would block. Other errors and
 * premature EOF delete the connection.
 */
void Connection_DH_readdata(Connection *con) {
   ssize_t n;

   if(con->status == ST_DHRECV2) {
      int max = DH_DATA_LEN - con->rpos;
      Readdata(con, FROM_REMOTE, max);
      /* NOTE(review): if Readdata() tore the connection down (via
         Connection_del) the slot is free now; do not touch it further. */
      if(con->status == ST_NOTUSED)
         return;
      FD_CLR(con->localfd, &wrset); /* Set by Readdata() */
      if(con->rpos == DH_DATA_LEN) {
         if(debug)
            err_msg("Connection_DH_readdata: Read in all DH data from endpoint from remotefd=%d", con->remotefd);
         con->rpos = 0;
         Connection_DH_process(con, FOR_ENDPOINT);
      }
      return;
   }

   n = read(con->remotefd, (con->rbuf + con->rpos), (DH_DATA_LEN - con->rpos));

   /* remotefd is non-blocking; these are not fatal, just retry later. */
   if((n == -1) && ((errno == EAGAIN) || (errno == EWOULDBLOCK) || (errno == EINTR)))
      return;

   switch(n) {
      case -1:
         if(debug)
            err_ret("Connection_DH_readdata: read error from remotefd=%d", con->remotefd);
         Connection_del(con);
         break;
      case 0:   /* EOF: errno is not meaningful, so err_msg rather than err_ret */
         if(debug)
            err_msg("Connection_DH_readdata: premature end of file from remotefd=%d", con->remotefd);
         Connection_del(con);
         break;
      default:
         con->rpos += n;
         if(con->rpos == DH_DATA_LEN) {
            if(debug)
               err_msg("Connection_DH_readdata: Read in all client DH data from remotefd=%d", con->remotefd);
            con->rpos = 0;
            Connection_DH_process(con, FOR_NORMAL);
         } else if(debug) {
            err_msg("Connection_DH_readdata: Read in %d bytes of DH data from remotefd=%d", (int)n, con->remotefd);
         }
         break;
   }
}

/*
 * Zero len bytes at buf through a volatile pointer. A plain memset() on a
 * buffer that is about to go out of scope (or be freed) is a dead store the
 * optimiser may legally remove; these stores cannot be removed.
 */
static void wipe_secret(void *buf, size_t len) {
   volatile unsigned char *p = buf;

   while(len--)
      *p++ = 0;
}

/*
 * Finish a Diffie-Hellman exchange on con.
 *
 * The peer's DH data in con->rbuf and the local context con->dh are passed
 * to ComputeDH(), which yields a 64-byte (512 bit) session key. Then:
 *  - with == FOR_NORMAL: the key seeds con->e, con->d is copied from it,
 *    and Connection_initinfo() runs.
 *  - with == FOR_ENDPOINT: the key seeds con->e2, con->d2 is copied from
 *    it, con enters ST_CONNECT, and localfd is armed for reading.
 *
 * con->dh->server is set when chaining (cntrl_num > 1) for FOR_NORMAL, or
 * when skip_verification is on. Presumably it selects how ComputeDH()
 * treats ServerPublicKey -- confirm.
 *
 * The DH context is released after use and the on-stack key copy is
 * wiped. If ComputeDH() fails the connection is deleted.
 */
void Connection_DH_process(Connection *con, int with) {
   char key[64+1];     /* 512 bit key is exchanged */
   int num;

   memset(key, 0, sizeof key);

   if(((cntrl_num > 1) && (with == FOR_NORMAL)) || skip_verification)
      con->dh->server = 1;
   else
      con->dh->server = 0;

   if((num = ComputeDH(con->rbuf, con->dh, &ServerPublicKey, key)) != 0) {
      err_msg("Connection_DH_process: ComputeDH failed: %d", num);
      wipe_secret(key, sizeof key);
      Connection_del(con);
      return;
   }

   if(debug) {
      /* NOTE(review): this logs the live session key; debug builds only. */
      char *tmp = base64_encode(key, 64);
      err_msg("Connection_DH_process: DH negotiation succesful");
      err_msg("Connection_DH_process: Agreed key = %s", tmp ? tmp : "(null)");
      free(tmp);
   }

   /* Not needed anymore. It presumably holds the private DH value, so
      scrub it before handing the memory back. */
   wipe_secret(con->dh, sizeof *con->dh);
   free(con->dh);
   con->dh = NULL;

   if(with == FOR_NORMAL) {
      arcfour_init(&con->e, key, 64);
      /* NOTE(review): encrypt and decrypt start from the same RC4 state,
         so both directions reuse one keystream. Fixing that needs a
         protocol change on both ends. */
      con->d = con->e;    /* Saves a bit of processing time */

      Connection_initinfo(con);
   } else if(with == FOR_ENDPOINT) {
      arcfour_init(&con->e2, key, 64);
      con->d2 = con->e2;
      con->status = ST_CONNECT;
      FD_SET(con->localfd, &rdset);
   }

   wipe_secret(key, sizeof key);
}

/*
 * Queue the control headers for the remote side once the first DH
 * exchange is done.
 *
 * With no control headers (cntrl_num == 0), con goes straight to
 * ST_CONNECT and localfd is armed for reading.
 *
 * Otherwise:
 *  - The cntrl_num conninit records from cntrl are copied into con->lbuf,
 *    encrypted in place with con->e, and lpos is set to their size.
 *  - remotefd is switched from rdset to wrset and the state becomes
 *    ST_SENDINF.
 *
 * If the headers do not fit in lbuf (sndbufsz bytes, see Connection_alloc)
 * the connection is deleted instead of overflowing the buffer.
 */
void Connection_initinfo(Connection *con) {
   size_t len;

   if(!cntrl_num) {
      con->status = ST_CONNECT;
      FD_SET(con->localfd, &rdset);
      return;
   }

   len = cntrl_num * sizeof(conninit);
   if(len > (size_t)sndbufsz) {
      err_msg("Connection_initinfo: control headers (%lu bytes) do not fit in buffer of %lu bytes",
              (unsigned long)len, (unsigned long)sndbufsz);
      Connection_del(con);
      return;
   }

   memcpy(con->lbuf, cntrl, len);
   con->lpos = len;
   arcfour_encrypt(&con->e, con->lbuf, con->lbuf, len);
   FD_CLR(con->remotefd, &rdset);
   FD_SET(con->remotefd, &wrset);
   con->status = ST_SENDINF;
}

/*
 * Flush queued control headers (con->lbuf / lpos) to remotefd via
 * Writedata().
 *
 * While data remains, localfd is kept out of rdset. Once everything is
 * written:
 *  - With a chain of proxies (cntrl_num > 1), a second DH exchange with
 *    the endpoint is started via Connection_DH_init(con, FOR_ENDPOINT).
 *  - With a single proxy, con enters ST_CONNECT and both sockets are armed
 *    for reading.
 */
void Connection_sendinfo(Connection *con) {
   Writedata(con, TO_REMOTE);

   /* NOTE(review): if Writedata() tore the connection down (via
      Connection_del), lpos is 0 and the slot is free. Without this check
      the slot would be resurrected below as a DH exchange on fd 0. */
   if(con->status == ST_NOTUSED)
      return;

   if(con->lpos) {
      FD_CLR(con->localfd, &rdset);   /* Set by Writedata() */
      return;
   }

   /* Wrote out all control data */
   if(debug)
      err_msg("Connection_sendinfo: Wrote out all control headers to remotefd=%d", con->remotefd);

   if(cntrl_num > 1) {   /* Talking with chain of proxies, have to do dh with endpoint */
      FD_CLR(con->localfd, &rdset);   /* Set by Writedata() */
      Connection_DH_init(con, FOR_ENDPOINT);
      return;
   }

   /* Talking with 1 proxy in standalone mode */
   con->status = ST_CONNECT;
   FD_CLR(con->remotefd, &wrset);
   FD_SET(con->remotefd, &rdset);
   FD_SET(con->localfd, &rdset);
}

/*
 * Handle end-of-file on one side of a connection.
 *
 * who == EOF_LOCAL (localfd hit EOF):
 *  - If con->lpos bytes are still pending, remotefd is shut down for
 *    reading.
 *  - localfd is dropped from both sets and remotefd from rdset; remotefd
 *    stays armed in wrset.
 *  - con enters ST_FINISHL so the backlog can drain.
 *  - With nothing pending, the connection is deleted.
 *
 * who == EOF_REMOT: the mirror image using rpos, localfd and ST_FINISHR.
 *
 * Any other value of who is ignored.
 */
void Connection_eof(Connection *con, int who) {
   if(who == EOF_LOCAL) {
      if(!con->lpos) {
         if(debug)
            err_msg("Connection_eof: EOF for localfd=%d, no more data to write out", con->localfd);
         Connection_del(con);
         return;
      }

      if(debug)
         err_msg("Connection_eof: EOF for localfd=%d, %d bytes left to write out, entering ST_FINISHL state",
                 con->localfd, con->lpos);

      /* No further data is wanted from the remote side. */
      if(shutdown(con->remotefd, SHUT_RD))
         err_ret("Connection_eof: Shutdown error for remotefd=%d", con->remotefd);

      FD_CLR(con->localfd, &rdset);
      FD_CLR(con->localfd, &wrset);
      FD_CLR(con->remotefd, &rdset);
      FD_SET(con->remotefd, &wrset);
      con->status = ST_FINISHL;
      return;
   }

   if(who != EOF_REMOT)
      return;

   if(!con->rpos) {
      if(debug)
         err_msg("Connection_eof: EOF for remotefd=%d, no more data to write out", con->remotefd);
      Connection_del(con);
      return;
   }

   if(debug)
      err_msg("Connection_eof: EOF for remotefd=%d, %d bytes left to write out, entering ST_FINISHR state",
              con->remotefd, con->rpos);

   /* No further data is wanted from the local side. */
   if(shutdown(con->localfd, SHUT_RD))
      err_ret("Connection_eof: Shutdown error for localfd=%d", con->localfd);

   FD_CLR(con->remotefd, &rdset);
   FD_CLR(con->remotefd, &wrset);
   FD_CLR(con->localfd, &rdset);
   FD_SET(con->localfd, &wrset);
   con->status = ST_FINISHR;
}
