#include "aesoptunnel.h"

/* This function allocates a new Connection structure and initialises
   it to the ST_NOTUSED
 */
void Connection_alloc(void) {

#ifdef DEBUG
   err_msg("Connection_alloc: allocating new structure, conn_size=%d.", conn_size);
#endif

   conn = Realloc(conn, (conn_size + 1) * sizeof(Connection));

   memset(&conn[conn_size], 0, sizeof(Connection));
   conn[conn_size].status = ST_NOTUSED;
   conn_size++;
}

int Connection_add(int localfd) {
   int i;

#ifdef DEBUG
   err_msg("Connection_add: Adding Connection structure for fd=%d.", localfd);
#endif
   if(conn == NULL)
      Connection_alloc();

again:
   for(i = 0; i < conn_size; i++) {
      if(conn[i].status == ST_NOTUSED) {
	 long flags;
	 int val = 1;

	 conn[i].localfd = localfd;
	 conn[i].lbuf = Malloc(sndbufsz);
	 conn[i].rbuf = Malloc(sndbufsz);

	 if((flags = fcntl(localfd, F_GETFL, 0)) == -1)
	    err_sys("Connection_add: fcntl F_GETFL error.");

	 if(fcntl(localfd, F_SETFL, flags | O_NONBLOCK) == -1)
	    err_sys("Connection_add: fcntl F_SETFL error.");

	 if(setsockopt(localfd, SOL_SOCKET, SO_OOBINLINE, &val, sizeof(int)) == -1)
	    err_sys("Connection_add: Can not set SO_OOBINLINE.");

	 conn[i].time = time(NULL) + timeout-1;
	 maxfd = MAX(maxfd, conn[i].localfd);

	 if(need_desthdr) {
	    conn[i].status = ST_RD_DEST;
	    FD_SET(localfd, &rdset);
	 }

	 if(conn_size == max_simul_con)
	    FD_CLR(lfd, &rdset);        /* Accept no more connections until free slot)  */

	 return i;
      }
   }

   /* Couldn't find any empty connection structures, allocate a new one
      and do the loop again, it will work this time. */
   Connection_alloc();
   goto again;
}

/* Tears down a connection and returns its slot to the free pool.
 *
 * Logs the end of the connection, closes localfd/remotefd (a value of 0
 * is treated as "no descriptor") and removes them from the read and write
 * sets, releases the DH state, the OOB position data and both buffers,
 * clears the structure and marks it ST_NOTUSED.  Finally the listening
 * socket lfd is put back into the read set, because a free slot now
 * exists.
 *
 * con: slot to release; all of its fields are invalid afterwards.
 */
void Connection_del(Connection *con) {
   char *local_name, *remote_name;

#ifdef DEBUG
   err_msg("Connection_del: Deleting Connection structure for localfd=%d, remotefd=%d.", con->localfd, con->remotefd);
#endif

   /* print_sock() results are freed here, as the original code did. */
   local_name = print_sock(con->localfd);
   remote_name = print_sock(con->remotefd);
   aesoplog("Connection from %s to %s ended.", local_name, remote_name);
   free(local_name);
   free(remote_name);

   if(con->localfd) {
      close(con->localfd);
      FD_CLR(con->localfd, &rdset);
      FD_CLR(con->localfd, &wrset);
   }

   if(con->remotefd) {
      close(con->remotefd);
      FD_CLR(con->remotefd, &rdset);
      FD_CLR(con->remotefd, &wrset);
   }

   /* free(NULL) is a no-op, so no guards are needed. */
   free(con->dh);
   free(con->l_oobpos);
   free(con->r_oobpos);
   free(con->lbuf);
   free(con->rbuf);

   memset(con, 0, sizeof(Connection));

   /* Do not rely on ST_NOTUSED being 0: set it explicitly (as
      Connection_alloc does) so Connection_add can reuse this slot. */
   con->status = ST_NOTUSED;

   /* We just deleted a Connection structure, so there must be one
      available to accept a new connection on. */
   FD_SET(lfd, &rdset);
}

/* Connection_connect starts the outgoing connection to the proxy
 * described by proxyhost/proxyhostlen.
 *
 * A new stream socket is created, switched to non-blocking mode and set
 * to SO_OOBINLINE, then connect() is issued:
 *   - immediate success: remotefd/maxfd are updated and CheckConnection()
 *     is called;
 *   - EINPROGRESS: remotefd/maxfd are updated, the connection enters
 *     ST_INPROGR and the socket is watched for writability;
 *   - any other error: the new socket is closed and the connection is
 *     deleted.
 *
 * Failures of socket(), fcntl() or setsockopt() are reported through
 * err_sys().
 */
void Connection_connect(Connection *con) {
   int sd, n, val = 1;
   long flags;

   if((sd = socket(proxyhost->sa_family, SOCK_STREAM, 0)) == -1)
      err_sys("Connection_connect: socket error.");

   if((flags = fcntl(sd, F_GETFL, 0)) == -1)
      err_sys("Connection_connect: fcntl F_GETFL error.");

   if(fcntl(sd, F_SETFL, flags | O_NONBLOCK) == -1)
      err_sys("Connection_connect: fcntl F_SETFL error.");

   if(setsockopt(sd, SOL_SOCKET, SO_OOBINLINE, &val, sizeof(val)) == -1)
      err_sys("Connection_connect: Can not set SO_OOBINLINE.");

   if((n = connect(sd, proxyhost, proxyhostlen)) == -1) {
      if(errno != EINPROGRESS) {
#ifdef DEBUG
         err_ret("Connection_connect: connect didnt return EINPROGRESS.");
#endif
         /* sd is not stored in con->remotefd yet, so Connection_del()
            cannot close it; close it here to avoid leaking it. */
         close(sd);
         Connection_del(con);
         return;
      }
   }
#ifdef DEBUG
   err_msg("Connection_process: Processed localfd=%d, remotefd=%d.", con->localfd, sd);
#endif

   con->remotefd = sd;
   maxfd = MAX(maxfd, sd);

   if(n == 0) {   /* Connection succeeded directly */
      CheckConnection(con);
      return;
   }

   /* Completion is signalled by the socket becoming writable. */
   con->status = ST_INPROGR;
   FD_SET(sd, &wrset);
}

/* Prepares our side of a Diffie-Hellman exchange on con->remotefd.
 *
 * A DH context is created with InitDH() and stored in con->dh; DH_VAL_LEN
 * bytes of its publicValue are copied to the start of con->rbuf.
 *
 * with == FOR_NORMAL:   status becomes ST_DHSEND1, rpos is reset and
 *                       remotefd is watched for writing.
 * with == FOR_ENDPOINT: status becomes ST_DHSEND2, rpos is reset,
 *                       remotefd is watched for writing and DH_DATA_LEN
 *                       bytes of rbuf are encrypted in place with con->e.
 * Any other value leaves status and the fd sets untouched.
 *
 * NOTE(review): if err_quit() terminates the process, the cleanup after
 * it is unreachable -- confirm which behaviour is intended.
 * NOTE(review): DH_VAL_LEN bytes are copied but DH_DATA_LEN bytes are
 * encrypted; presumably the two constants agree -- confirm.
 */
void Connection_DH_init(Connection *con, int with) {
   R_RANDOM_STRUCT rnd;

   InitRandomStruct(&rnd);

   con->dh = InitDH(&rnd, 0, NULL);
   if(con->dh == NULL) {
      err_quit("Connection_DH_init: InitDH() failed.");
      Connection_del(con);
      return;
   }

   memcpy(con->rbuf, con->dh->publicValue, DH_VAL_LEN);

   if(with == FOR_NORMAL) {
      con->status = ST_DHSEND1;
      con->rpos = 0;
      FD_SET(con->remotefd, &wrset);
#ifdef DEBUG
      err_msg("Connection_DH_init: Generated DH data for remotefd=%d.", con->remotefd);
#endif
      return;
   }

   if(with != FOR_ENDPOINT)
      return;

   con->status = ST_DHSEND2;
   con->rpos = 0;
   FD_SET(con->remotefd, &wrset);
   /* Data for the endpoint goes through the already established cipher. */
   arcfour_encrypt(&con->e, con->rbuf, con->rbuf, DH_DATA_LEN);
#ifdef DEBUG
   err_msg("Connection_DH_init: Generated DH data for endpoint, remotefd=%d.", con->remotefd);
#endif
}

/* Writes the pending part of our DH data (con->rbuf, DH_DATA_LEN bytes in
 * total, con->rpos already sent) to con->remotefd.
 *
 *   - transient errors (EINTR, EAGAIN/EWOULDBLOCK): nothing changes, the
 *     socket stays in the write set and the write is retried later;
 *   - any other error, or a zero-byte write: the connection is deleted;
 *   - partial write: rpos advances by the number of bytes written;
 *   - complete: rpos is reset, remotefd moves from the write set to the
 *     read set and the state goes ST_DHSEND1 -> ST_DHRECV1, otherwise
 *     -> ST_DHRECV2.
 */
void Connection_DH_senddata(Connection *con) {
   int n;

   n = write(con->remotefd, (con->rbuf + con->rpos), (DH_DATA_LEN - con->rpos));

   switch(n) {
      case -1:
         /* The socket is non-blocking: these are not fatal. */
         if(errno == EINTR || errno == EAGAIN || errno == EWOULDBLOCK)
            break;
#ifdef DEBUG
         err_ret("Connection_DH_senddata: write error to remotefd=%d.", con->remotefd);
#endif
         Connection_del(con);
         break;
      case 0:   /* Huh?, when does this happen? */
#ifdef DEBUG
         err_ret("Connection_DH_senddata: write to remotefd=%d returned zero.", con->remotefd);
#endif
         Connection_del(con);
         break;
      default:
         if((con->rpos + n) == DH_DATA_LEN) { /* We sent all DH data, time to read the peer's DH data */
            con->rpos = 0;
            FD_CLR(con->remotefd, &wrset);
            FD_SET(con->remotefd, &rdset);
            if(con->status == ST_DHSEND1) {
               con->status = ST_DHRECV1;
#ifdef DEBUG
               err_msg("Connection_DH_senddata: Send out all server DH data to remotefd=%d.", con->remotefd);
#endif
            } else {   /* ST_DHSEND2 */
               con->status = ST_DHRECV2;
#ifdef DEBUG
               err_msg("Connection_DH_senddata: Send out all server DH data to remotefd=%d for endpoint.", con->remotefd);
#endif
            }
         } else {
#ifdef DEBUG
            /* n is the number of bytes written by this call. */
            err_msg("Connection_DH_senddata: Send out %d bytes of DH data to remotefd=%d.", n, con->remotefd);
#endif
            con->rpos += n;
         }
   }
}

/* Collects the peer's DH data (DH_DATA_LEN bytes) into con->rbuf and
 * hands it to Connection_DH_process() once complete.
 *
 * ST_DHRECV2 (endpoint exchange): the bytes are fetched through
 * Readdata(con, FROM_REMOTE, ...), limited to what is still missing; the
 * write interest on localfd that Readdata() sets is cleared again.
 * NOTE(review): if Readdata() can delete the connection on error, the
 * fields inspected afterwards belong to a cleared slot -- confirm.
 *
 * Otherwise (first-hop exchange) the bytes are read directly:
 *   - transient errors (EINTR, EAGAIN/EWOULDBLOCK): nothing changes and
 *     the read is retried when the socket is reported readable again;
 *   - any other error or EOF: the connection is deleted;
 *   - data: rpos advances; when all DH_DATA_LEN bytes are in, rpos is
 *     reset and processing continues with FOR_NORMAL.
 */
void Connection_DH_readdata(Connection *con) {
   int n;

   if(con->status == ST_DHRECV2) {
      int max = DH_DATA_LEN - con->rpos;
      Readdata(con, FROM_REMOTE, max);
      FD_CLR(con->localfd, &wrset); /* Set by Readdata() */
      if(con->rpos == DH_DATA_LEN) {
#ifdef DEBUG
         err_msg("Connection_DH_readdata: Read in all DH data from endpoint from remotefd=%d.", con->remotefd);
#endif
         con->rpos = 0;
         Connection_DH_process(con, FOR_ENDPOINT);
      }
      return;
   }

   n = read(con->remotefd, (con->rbuf + con->rpos), (DH_DATA_LEN - con->rpos));

   switch(n) {
      case -1:
         /* select() may report readiness spuriously; on a non-blocking
            socket that shows up as EAGAIN and is not an error. */
         if(errno == EINTR || errno == EAGAIN || errno == EWOULDBLOCK)
            break;
#ifdef DEBUG
         err_ret("Connection_DH_readdata: read error from remotefd=%d.", con->remotefd);
#endif
         Connection_del(con);
         break;
      case 0:
#ifdef DEBUG
         err_ret("Connection_DH_readdata: premature end of file from remotefd=%d.", con->remotefd);
#endif
         Connection_del(con);
         break;
      default:
         con->rpos += n;
         if(con->rpos == DH_DATA_LEN) {
#ifdef DEBUG
            err_msg("Connection_DH_readdata: Read in all client DH data from remotefd=%d.", con->remotefd);
#endif
            con->rpos = 0;
            Connection_DH_process(con, FOR_NORMAL);
            return;
         } else {
#ifdef DEBUG
            err_msg("Connection_DH_readdata: Read in %d bytes of DH data from remotefd=%d", n, con->remotefd);
#endif
         }
   }
   return;
}

/* Finishes a Diffie-Hellman exchange using the peer data in con->rbuf.
 *
 * con->dh->server is set to 1 when skip_verification is on, or when more
 * than one control header is configured and this is the FOR_NORMAL
 * exchange; otherwise it is 0.  ComputeDH() is run against
 * FirstPublicKey (FOR_NORMAL) or EndpointPublicKey (anything else) and
 * fills a 64-byte key.  On failure the connection is deleted.
 *
 * On success the DH context is released and
 *   - FOR_NORMAL:   con->e is keyed, con->d starts as a copy of it, and
 *                   Connection_initinfo() continues the handshake;
 *   - FOR_ENDPOINT: con->e2 is keyed, con->d2 starts as a copy, the
 *                   state becomes ST_CONNECT and localfd is watched for
 *                   reading.
 */
void Connection_DH_process(Connection *con, int with) {
   unsigned char key[64+1];     /* 512 bit key is exchanged */
   R_RSA_PUBLIC_KEY *peer_key;
   int rc;

   memset(key, 0, sizeof(key));

   con->dh->server = (skip_verification || ((cntrl_num > 1) && (with == FOR_NORMAL))) ? 1 : 0;

   peer_key = (with == FOR_NORMAL) ? &FirstPublicKey : &EndpointPublicKey;

   rc = ComputeDH(con->rbuf, con->dh, peer_key, key);
   if(rc != 0) {
#ifdef DEBUG
      err_msg("Connection_DH_process: ComputeDH failed: %d", rc);
#endif
      Connection_del(con);
      return;
   }

#ifdef DEBUG
   if(debug) {
      char *encoded = base64_encode(key, 64);
      err_msg("Connection_DH_process: DH negotiation succesful");
      err_msg("Connection_DH_process: Agreed key = %s", encoded);
      free(encoded);
   }
#endif

   /* The DH context is no longer needed once the key is derived. */
   free(con->dh);
   con->dh = NULL;

   if(with == FOR_NORMAL) {
      arcfour_init(&con->e, key, 64);
      con->d = con->e;    /* Copying the state saves a second key setup */
      Connection_initinfo(con);
      return;
   }

   if(with == FOR_ENDPOINT) {
      arcfour_init(&con->e2, key, 64);
      con->d2 = con->e2;
      con->status = ST_CONNECT;
      FD_SET(con->localfd, &rdset);
   }
}

/* Queues the control headers that must be sent to the first proxy.
 *
 * With no control headers (cntrl_num == 0) and no destination header the
 * connection goes straight to ST_CONNECT and localfd is watched for
 * reading.
 *
 * Otherwise con->lbuf is laid out as cntrl_num blocks of ENC_CONN_SIZE
 * bytes copied from ctrlheaders, followed -- when need_desthdr is set --
 * by the ENC_CONN_SIZE-byte block that was at the start of lbuf.  The
 * whole region is encrypted in place with con->e, con->lpos is set to its
 * length, remotefd moves from the read set to the write set and the
 * state becomes ST_SENDINF.
 */
void Connection_initinfo(Connection *con) {
   size_t hdr_bytes;

   if(cntrl_num == 0 && !need_desthdr) {
      con->status = ST_CONNECT;
      FD_SET(con->localfd, &rdset);
      return;
   }

   hdr_bytes = (cntrl_num * ENC_CONN_SIZE);

   /* Make room in front: shift the block already in lbuf behind the
      control headers (the regions may overlap, hence memmove). */
   if(need_desthdr)
      memmove(con->lbuf + hdr_bytes, con->lbuf, ENC_CONN_SIZE);
   memcpy(con->lbuf, ctrlheaders, hdr_bytes);

   con->lpos = hdr_bytes;
   if(need_desthdr)
      con->lpos += ENC_CONN_SIZE;

   arcfour_encrypt(&con->e, con->lbuf, con->lbuf, con->lpos);

   con->status = ST_SENDINF;
   FD_CLR(con->remotefd, &rdset);
   FD_SET(con->remotefd, &wrset);
}

/* Pushes the queued control headers to the proxy via
 * Writedata(con, TO_REMOTE) and decides what follows.
 *
 * Writedata() adds localfd to the read set (per the original comments);
 * that is undone here while the handshake is still running.
 *
 *   - data still pending (con->lpos != 0): stop watching localfd, wait
 *     for the next writable event;
 *   - all sent and a chain of proxies is in use (cntrl_num > 1, or
 *     need_desthdr together with cntrl_num != 0): stop watching localfd
 *     and start the DH exchange with the endpoint;
 *   - all sent, single proxy: enter ST_CONNECT, move remotefd from the
 *     write set to the read set and watch localfd for reading.
 */
void Connection_sendinfo(Connection *con) {
   Writedata(con, TO_REMOTE);

   if(con->lpos) {
      FD_CLR(con->localfd, &rdset);   /* Set by Writedata()   */
      return;
   }

   /* Wrote out all control data */
#ifdef DEBUG
   err_msg("Connection_sendinfo: Wrote out all control headers to remotefd=%d", con->remotefd);
#endif

   if((cntrl_num <= 1) && !(need_desthdr && cntrl_num)) {
      /* Talking with 1 proxy in standalone mode */
      con->status = ST_CONNECT;
      FD_CLR(con->remotefd, &wrset);
      FD_SET(con->remotefd, &rdset);
      FD_SET(con->localfd, &rdset);
      return;
   }

   /* Talking with chain of proxies, have to do dh with endpoint */
   FD_CLR(con->localfd, &rdset);   /* Set by Writedata()   */
   Connection_DH_init(con, FOR_ENDPOINT);
}

/* Handles end-of-stream bookkeeping for one direction of a connection.
 *
 * who == EOF_LOCAL:
 *   - status ST_FINISHL: shut down the write side of remotefd and stop
 *     watching localfd for reading;
 *   - status ST_FINISHR: both directions are done, delete the connection;
 *   - otherwise: if lbuf still holds data (lpos != 0) watch remotefd for
 *     writing, else shut down the write side of remotefd; then shut down
 *     the read side of localfd, stop watching it and enter ST_FINISHL.
 *
 * who == EOF_REMOT: the mirror image with the roles of local/remote and
 * ST_FINISHL/ST_FINISHR swapped, except that the ST_FINISHR case only
 * shuts down the write side of localfd.
 *
 * Any other value of who is ignored.
 */
void Connection_eof(Connection *con, int who) {

   if(who == EOF_LOCAL) {
      if(con->status == ST_FINISHL) {
#ifdef DEBUG
         err_msg("Connection_eof: Wrote out all remaining data to remotefd=%d, closing end for writing",
                     con->remotefd);
#endif
         shutdown(con->remotefd, SHUT_WR);
         FD_CLR(con->localfd, &rdset);
         return;
      }

      if(con->status == ST_FINISHR) {
#ifdef DEBUG
         err_msg("Connection_eof: Both localfd=%d and remotefd=%d are finished, deleting structure",
                     con->localfd, con->remotefd);
#endif
         Connection_del(con);
         return;
      }

      if(con->lpos) {        /* We have data left to write out */
#ifdef DEBUG
         err_msg("Connection_eof: EOF for localfd=%d, %d bytes left to write out, entering ST_FINISHL state",
                     con->localfd, con->lpos);
#endif
         FD_SET(con->remotefd, &wrset);
      } else {
#ifdef DEBUG
         err_msg("Connection_eof: EOF for localfd=%d, entering ST_FINISHL state and closing other end for writing",
                     con->localfd);
#endif
         shutdown(con->remotefd, SHUT_WR);
      }

      shutdown(con->localfd, SHUT_RD);   /* We dont want more data */

      FD_CLR(con->localfd, &rdset);
      con->status = ST_FINISHL;
      return;
   }

   if(who != EOF_REMOT)
      return;

   if(con->status == ST_FINISHR) {
#ifdef DEBUG
      err_msg("Connection_eof: Wrote out all remaining data to localfd=%d, closing end for writing",
                  con->localfd);
#endif
      shutdown(con->localfd, SHUT_WR);
      return;
   }

   if(con->status == ST_FINISHL) {
#ifdef DEBUG
      err_msg("Connection_eof: Both localfd=%d and remotefd=%d are finished, deleting structure",
                  con->localfd, con->remotefd);
#endif
      Connection_del(con);
      return;
   }

   if(con->rpos) {        /* We have data left to write out */
#ifdef DEBUG
      err_msg("Connection_eof: EOF for remotefd=%d, %d bytes left to write out, entering ST_FINISHR state",
                  con->remotefd, con->rpos);
#endif
      FD_SET(con->localfd, &wrset);
   } else {
#ifdef DEBUG
      err_msg("Connection_eof: EOF for remotefd=%d, entering ST_FINISHR state and closing other end for writing",
                  con->remotefd);
#endif
      shutdown(con->localfd, SHUT_WR);
   }

   shutdown(con->remotefd, SHUT_RD);   /* We dont want more data */

   FD_CLR(con->remotefd, &rdset);
   con->status = ST_FINISHR;
}

/* Reads the destination header sent by the local client and turns it
 * into the encrypted connection-init block for the endpoint.
 *
 * Bytes are accumulated in con->lbuf until sizeof(desthdr) have arrived:
 *   - transient read errors (EINTR, EAGAIN/EWOULDBLOCK) are ignored and
 *     the read is retried on the next readable event;
 *   - any other error, EOF, or a header whose magic (in network byte
 *     order) is not MAGIC deletes the connection.
 *
 * Once complete, a conninit record is built from the header:
 *   magic, targetport, targetip (16 bytes) are copied through,
 *   authdata is taken from AuthData,
 *   control = header control | APORT/MPORT flag when sp_spec is set
 *             (APORT for sp_spec == 1) | ENDPO when cntrl_num != 0,
 *             else ALONE,
 *   data    = sp_port in network byte order when sp_spec is set.
 * The record is encrypted into con->lbuf with EndpointPublicKey, lpos is
 * reset, localfd leaves the read set and Connection_connect() is started.
 */
void Connection_readdest(Connection *con) {
   desthdr *dest;
   conninit init;
   size_t authlen;
   int n;

   n = read(con->localfd, (con->lbuf + con->lpos), sizeof(desthdr) - con->lpos);

   if(n == -1) {
      /* localfd is non-blocking: these just mean "try again later". */
      if(errno == EINTR || errno == EAGAIN || errno == EWOULDBLOCK)
         return;
#ifdef DEBUG
      err_ret("Connection_readdest: read from localfd=%d return -1", con->localfd);
#endif
      Connection_del(con);
      return;
   }

   if(n == 0) {
#ifdef DEBUG
      err_msg("Connection_readdest: Premature EOF on destination header for localfd=%d", con->localfd);
#endif
      Connection_del(con);
      return;
   }

   con->lpos += n;
   if(con->lpos != sizeof(desthdr))
      return;   /* Header still incomplete */

   dest = (desthdr *)con->lbuf;
#ifdef DEBUG
   err_msg("Connection_readdest: Read in destination header from localfd=%d", con->localfd);
#endif
   if(ntohl(dest->magic) != MAGIC) {
#ifdef DEBUG
      err_msg("Connection_readdest: dest->magic = 0x%x", dest->magic);
#endif
      Connection_del(con);
      return;
   }

   /* Start from a fully zeroed record so that no stale stack bytes
      (unset fields, bytes after the auth string) end up in the
      block handed to enc_conninit(). */
   memset(&init, 0, sizeof(init));

   init.magic = dest->magic;
   init.targetport = dest->targetport;
   memcpy(init.targetip, dest->targetip, 16);

   /* Bounded copy of the auth string; the zero fill above guarantees the
      terminator.  An over-long AuthData is truncated instead of
      overrunning the field. */
   authlen = strlen(AuthData);
   if(authlen >= sizeof(init.authdata))
      authlen = sizeof(init.authdata) - 1;
   memcpy(init.authdata, AuthData, authlen);

   init.control = dest->control;
   if(sp_spec) {
      init.data = htons(sp_port);
      if(sp_spec == 1)
         init.control |= htons(CNTRL_APORT);
      else
         init.control |= htons(CNTRL_MPORT);
   }

   if(cntrl_num)
      init.control |= htons(CNTRL_ENDPO);
   else
      init.control |= htons(CNTRL_ALONE);

   enc_conninit(&EndpointPublicKey, (unsigned char *)&init, con->lbuf);

   con->lpos = 0;
   FD_CLR(con->localfd, &rdset);
   Connection_connect(con);
}
