#include "aesopd.h"

/* This function allocates a new Connection structure and initialises
   it to the ST_NOTUSED
 */
void Connection_alloc(void) {
   
#ifdef DEBUG
   err_msg("Connection_alloc: allocating new structure, conn_size=%d", conn_size);
#endif

   conn = Realloc(conn, (conn_size + 1) * sizeof(Connection));

   memset(&conn[conn_size], 0, sizeof(Connection));
   conn[conn_size].status = ST_NOTUSED;
   conn_size++;
}


/* This function searches for a free Connection structure, or allocates a new one.
   It also sets the time the socket is added, for the timeout functionality.
   It sets the socking to nonblocking mode. Then it eventually updates maxfd.
   It adds the socketdescriptor to the readset and the index is returned.
 */
int Connection_add(int localfd) {
   int i;

#ifdef DEBUG
   err_msg("Connection_add: Adding Connection structure for fd=%d", localfd);
#endif
   if(conn == NULL)
      Connection_alloc();

again:
   for(i = 0; i < conn_size; i++) {
      if(conn[i].status == ST_NOTUSED) {
	 long flags;
	 int val = 1;

	 conn[i].localfd = localfd;
	 conn[i].lbuf = Malloc(sndbufsz);
	 conn[i].rbuf = Malloc(sndbufsz);

	 if((flags = fcntl(localfd, F_GETFL, 0)) == -1)
	    err_sys("Connection_add: fcntl F_GETFL error");

	 if(fcntl(localfd, F_SETFL, flags | O_NONBLOCK) == -1)
	    err_sys("Connection_add: fcntl F_SETFL error");

	 if(setsockopt(localfd, SOL_SOCKET, SO_OOBINLINE, &val, sizeof(int)) == -1)
	    err_sys("Connection_add: Can not set SO_OOBINLINE");

	 conn[i].time = time(NULL) + timeout-1;
	 maxfd = MAX(maxfd, conn[i].localfd);

	 if(conn_size == max_simul_con)
	    FD_CLR(lfd, &rdset);	/* Accept no more connections until free slot)	*/

	 return i;
      }
   }

   /* Couldn't find any empty connection structures, allocate a new one
      and do the loop again, it will work this time. */   
   Connection_alloc();
   goto again;
}

/* Connection_del() tears down a Connection structure.
   Steps:
     - logs the end of the session;
     - closes whichever sockets are set (non-zero fds) and removes them from
       rdset and wrset;
     - frees every buffer the structure owns;
     - clears the structure and marks it ST_NOTUSED so Connection_add() can
       reuse the slot.
   Because a slot is now free, the listening socket lfd is put back into
   rdset.
 */
void Connection_del(Connection *con) {
   char *ret, *ret2;

#ifdef DEBUG
   err_msg("Connection_del: Deleting Connection structure for localfd=%d, remotefd=%d", con->localfd, con->remotefd);
#endif

   ret = print_sock(con->localfd);
   ret2 = print_sock(con->remotefd);
   aesoplog("Connection from %s to %s ended.", ret, ret2);
   free(ret);
   free(ret2);

   if(con->localfd) {
      close(con->localfd);
      FD_CLR(con->localfd, &rdset);
      FD_CLR(con->localfd, &wrset);
   }

   if(con->remotefd) {
      close(con->remotefd);
      FD_CLR(con->remotefd, &rdset);
      FD_CLR(con->remotefd, &wrset);
   }

   free(con->dh);	/* free(NULL) is a no-op, no guards needed */

   if(con->init != NULL) {
      /* The conninit copy contains the client's authentication data; clear
         it before releasing it, as Connection_process() does. */
      memset(con->init, 0, sizeof(conninit));
      free(con->init);
   }

   free(con->l_oobpos);
   free(con->r_oobpos);
   free(con->lbuf);
   free(con->rbuf);

   memset(con, 0, sizeof(Connection));
   /* Zeroing alone is not enough: Connection_alloc() sets ST_NOTUSED
      explicitly, so its value need not be 0. Connection_add() only reuses
      slots whose status equals ST_NOTUSED. */
   con->status = ST_NOTUSED;

   /* We just freed a slot, so there is room to accept a new connection. */
   FD_SET(lfd, &rdset);
}

/* Connection_process reads in the conninit structure and checks the magic value.
   Then it sets the socket to be nonblocking and starts a connection (which will
   return EINPROGRESS since the socket is nonblocking. After that maxfd is eventually
   updated with the new socket descriptor.
 */

void Connection_process(Connection *con) {
   
   int n, sd, max;
   int val = 1;
   char *tmp, *ret, *ret2;
   long flags;
   conninit init;
   struct sockaddr *sin;
   socklen_t sinlen;   

   if(con->control & CNTRL_ENDPO) {	/* 2nd time visit as endpoint */
       FD_CLR(con->localfd, &rdset);
       init = *(con->init);
   } else {
      n = con->lpos;
      max = ENC_CONN_SIZE - n;

      Readdata(con, FROM_LOCAL, max);

      if(con->lpos < ENC_CONN_SIZE) {
#ifdef DEBUG
	 err_msg("Connection_process: Read in %d bytes of conninit structure from localfd=%d", con->lpos - n, con->localfd);
#endif
	 FD_CLR(con->remotefd, &wrset);	/* Readdata() set con->remotefd on in wrset */
	 return;
      }

      FD_CLR(con->localfd, &rdset);
      FD_CLR(con->remotefd, &wrset);	/* Readdata() set con->remotefd on in wrset */

#ifdef DEBUG
      err_msg("Connection_process: Read in conninit structure, processing for localfd=%d", con->localfd);
#endif

      dec_conninit(&ServerPrivateKey, con->lbuf, (unsigned char *)&init);
      memmove(con->lbuf, con->lbuf + ENC_CONN_SIZE, con->lpos - ENC_CONN_SIZE);

      con->lpos -= ENC_CONN_SIZE;
 
      if(ntohl(init.magic) != MAGIC) {

#ifdef DEBUG
	 err_msg("Connection_process: init.magic = 0x%x", init.magic);
#endif

	 Connection_del(con);
	 return;
      }

      printf("init.authdata = %s\n", init.authdata);

      if(AuthData != NULL) {
	    char *ret;
	 tmp = MD5Hash((unsigned char *)init.authdata);
	 if(strcmp(AuthData, tmp)) {
#ifdef DEBUG
	    err_msg("Connection_process: AuthData from localfd=%d is invalid!", con->localfd);
#endif
	    aesoplog("Authentication for %s failed.", ret = print_sock(con->localfd));
	    free(ret);
	    Connection_del(con);
	    free(tmp);
	    return;
	 }
	 free(tmp);
#ifdef DEBUG
	 err_msg("Connection_process: AuthData from localfd=%d is valid", con->localfd);
#endif
	 aesoplog("%s authenticated succesfully.", ret = print_sock(con->localfd));
	 free(ret);
      }

      con->init = Malloc(sizeof(conninit));
      memcpy(con->init, &init, sizeof(conninit));
      con->init->control = ntohs(con->init->control);
      con->control = con->init->control;
      con->init->pkey.bits =  ntohl(con->init->pkey.bits);

      if((!ipv6support) && (con->control & CNTRL_INET6)) {
#ifdef DEBUG
	 err_msg("Connection_process: Got unsupported IPv6 request on fd=%d", con->localfd);
#endif
	 Connection_del(con);
	 return;
      }

      if(con->control & CNTRL_ENDPO) {	/* Endpoint */
	 Connection_DH_init(con, AS_ENDPOINT);
	 return;
      }
   }

   if((sin = w_setsockaddr(&init.targetip, init.targetport, &sinlen,
			   con->control & CNTRL_INET6)) == NULL) {
#ifdef DEBUG
      err_msg("Connection_process: w_setsockaddr returned NULL");
#endif

      Connection_del(con);
      return;
   }

#ifdef DEBUG
   err_msg("Connection_process: Connection request for localfd=%d to %s", con->localfd, ret = w_sock_ntop(sin));
   free(ret);
#endif
   aesoplog("%s wants to connect to %s", ret = print_sock(con->localfd), ret2 = w_sock_ntop(sin));
   free(ret); free(ret2);
   if((sd = socket(sin->sa_family, SOCK_STREAM, 0)) == -1) {
      if(errno == EAFNOSUPPORT) {
#ifdef DEBUG
	 err_msg("Connection_process: No AF support for localfd=%d", con->localfd);
#endif
	 free(sin);
	 Connection_del(con);
	 return;
      } else {
	 err_sys("Connection_process: socket error");
      }
   }

   if((flags = fcntl(sd, F_GETFL, 0)) == -1)
      err_sys("Connection_process: fcntl F_GETFL error");
   
   if(fcntl(sd, F_SETFL, flags | O_NONBLOCK) == -1)
      err_sys("Connection_process: fcntl F_SETFL error");

   if(setsockopt(sd, SOL_SOCKET, SO_OOBINLINE, &val, sizeof(int)) == -1)
      err_sys("Connection_add: Can not set SO_OOBINLINE");

   if((con->control & CNTRL_ENDPO) || (con->control & CNTRL_ALONE)) {
      if((con->control & CNTRL_APORT) || (con->control & CNTRL_MPORT)) {
	 char *ret;
	 int n = con->control & CNTRL_APORT ? 1 : 2;
	 aesoplog("%s requested %s sourceport specification for port %d.", ret = print_sock(con->localfd),
	      n == 1 ? "advisory" : "mandatory", ntohs(init.data));
	 free(ret);
	 if(ntohs(init.data) < sp_base) {
#ifdef DEBUG
	    err_msg("Connection_process: %s sourceport specification (%d) is not allowed",
		     n == 1 ? "Advisory" : "Mandatory", ntohs(init.data));
#endif   
	    if(n == 1) {
	       goto rest;
	    } else {
	       Connection_del(con);
	       return;
	    }
	 } else {
	    struct sockaddr *s;
	    socklen_t slen;
	    char data[16];

	    memset(data, 0, 16);
	    if((s = w_setsockaddr(data, init.data, &slen,
			   con->control & CNTRL_INET6)) == NULL) {
#ifdef DEBUG
	       err_msg("Connection_process: w_setsockaddr returned NULL");
#endif
	       if(n == 1) {
		  goto rest;
	       } else {
		  Connection_del(con);
		  return;
	       }
	    }

	    if(setsockopt(sd, SOL_SOCKET, SO_REUSEADDR, &val, sizeof(val)) == -1) {
#ifdef DEBUG
	       err_ret("Connection_process: Could not set SO_REUSEADDR");
#endif
	    }

	    if(bind(sd, s, slen) == -1) {
#ifdef DEBUG
	       err_ret("Connection_process: %s sourceport(%d) bind() failed",
			n == 1 ? "Advisory" : "Mandatory", ntohs(init.data));
#endif
	       if(n == 1) {
		  goto rest;
	       } else {
		  Connection_del(con);
		  return;
	       }
	    }
#ifdef DEBUG
	    err_msg("Connection_process: %s sourceport(%d) succesfully bound",
		     n == 1 ? "Advisory" : "Mandatory", ntohs(init.data));
#endif
	 }
      }
   }

rest:

   if(!(con->control & CNTRL_CHAIN)) {
      memset(con->init, 0, sizeof(conninit));
      free(con->init);
      con->init = NULL;
   }

   if((n = connect(sd, sin, sinlen)) == -1) {
      if(errno != EINPROGRESS) {
	 err_ret("Connection_process: connect didnt return EINPROGRESS");
	 close(sd);
	 Connection_del(con);
	 free(sin);
	 return;
      }
   } else if(n == 0) {	/* Connection succeeded */
      con->remotefd = sd;
      maxfd = MAX(maxfd, sd);
      free(sin);
      CheckConnection(con);
      return;
   }

#ifdef DEBUG
   err_msg("Connection_process: Processed localfd=%d, remotefd=%d", con->localfd, sd);
#endif

   free(sin);
   con->remotefd = sd;
   maxfd = MAX(maxfd, sd);
   con->status = ST_INPROGR;
   FD_SET(sd, &wrset);
}

/* Connection_DH_init() starts the server's half of a Diffie-Hellman
   exchange. InitDH() generates the server's DH values (presumably signed
   with ServerPrivateKey, which it is given). The public value and its
   signature are copied to the start of con->rbuf, to be written out later
   by Connection_DH_senddata().
   'how' selects the peer and the next state:
     AS_NORMAL   - the client on localfd; state ST_DHSEND1.
     AS_CHAIN    - the next server on remotefd; state ST_DHSEND2.
     AS_ENDPOINT - the client on localfd over an already encrypted link.
                   The buffer is encrypted with con->e and con->rpos is
                   set to DH_DATA_LEN pending bytes; state ST_DHSEND3.
   The chosen fd is added to wrset. If InitDH() fails, the connection is
   deleted.
 */
void Connection_DH_init(Connection *con, int how) {
   R_RANDOM_STRUCT rnd;

   InitRandomStruct(&rnd);

   con->dh = InitDH(&rnd, 1, &ServerPrivateKey);
   if(con->dh == NULL) {
#ifdef DEBUG
      err_msg("Connection_DH_init: InitDH() failed");
#endif
      Connection_del(con);
      return;
   }

   /* Outgoing DH data: public value immediately followed by its signature. */
   memcpy(con->rbuf, con->dh->publicValue, DH_VAL_LEN);
   memcpy(con->rbuf + DH_VAL_LEN, con->dh->signature, DH_SIGNATURE_LEN);

   if(how == AS_ENDPOINT) {
      con->status = ST_DHSEND3;
      con->rpos = DH_DATA_LEN;
      /* The link to the user is already encrypted, so the DH data must be too. */
      arcfour_encrypt(&con->e, con->rbuf, con->rbuf, DH_DATA_LEN);
      FD_SET(con->localfd, &wrset);
#ifdef DEBUG
      err_msg("Connection_DH_init: Generated DH data as endpoint for localfd=%d", con->localfd);
#endif
   } else if(how == AS_CHAIN) {
      con->status = ST_DHSEND2;
      FD_SET(con->remotefd, &wrset);
#ifdef DEBUG
      err_msg("Connection_DH_init: Generated DH data for remotefd=%d", con->remotefd);
#endif
   } else if(how == AS_NORMAL) {
      con->status = ST_DHSEND1;
      FD_SET(con->localfd, &wrset);
#ifdef DEBUG
      err_msg("Connection_DH_init: Generated DH data for localfd=%d", con->localfd);
#endif
   }
}

/* Connection_DH_senddata() is called whenever the socket chosen by 'to'
   (TO_LOCAL, or any other value for the remote socket) is writable while the
   server's DH data is pending in con->rbuf.
   It writes as much of the DH_DATA_LEN bytes as the socket accepts and
   tracks progress in con->rpos. Once everything is sent:
     - the connection moves to ST_DHRECV1 (local) or ST_DHRECV2 (remote);
     - the socket is moved from wrset to rdset so the peer's DH data can be
       read.
   In state ST_DHSEND3 (endpoint) the buffer was already encrypted by
   Connection_DH_init() and is sent through Writedata() instead, ending in
   ST_DHRECV3.
   A hard write error or a zero-byte write deletes the connection. EINTR and
   EAGAIN/EWOULDBLOCK leave everything as is, so the write is retried on
   the next call.
 */
void Connection_DH_senddata(Connection *con, int to) {
   int n, fd;
   const char *name;

   if(to == TO_LOCAL) {
      name = "localfd";
      fd = con->localfd;
      if(con->status == ST_DHSEND3) {
         Writedata(con, TO_LOCAL);
         FD_CLR(con->remotefd, &rdset);	/* Was set by Writedata() */
         if(!con->rpos) {
#ifdef DEBUG
            err_msg("Connection_DH_senddata: Send out all server DH data as endpoint to %s=%d", name, fd);
#endif
            con->status = ST_DHRECV3;
            FD_CLR(con->localfd, &wrset);
            FD_SET(con->localfd, &rdset);
         }
         return;
      }
      n = write(con->localfd, (con->rbuf + con->rpos), (DH_DATA_LEN - con->rpos));
   } else {
      name = "remotefd";
      fd = con->remotefd;
      n = write(con->remotefd, (con->rbuf + con->rpos), (DH_DATA_LEN - con->rpos));
   }

   /* The sockets are non-blocking (set in Connection_add() and
      Connection_process()): an interrupted or would-block write is not
      fatal. The fd stays in wrset and we try again on the next call. */
   if(n == -1 && (errno == EINTR || errno == EAGAIN || errno == EWOULDBLOCK))
      return;

   switch(n) {
      case -1:
#ifdef DEBUG
         err_ret("Connection_DH_senddata: write error to %s=%d", name, fd);
#endif
         Connection_del(con);
         break;
      case 0:	/* write() of a non-zero count should never return 0 */
#ifdef DEBUG
         /* err_msg, not err_ret: errno is not set when write() returns 0. */
         err_msg("Connection_DH_senddata: write to %s=%d returned zero", name, fd);
#endif
         Connection_del(con);
         break;
      default:
         if((con->rpos + n) == DH_DATA_LEN) { /* All DH data sent, time to read the peer's */
            if(to == TO_LOCAL) {
               con->status = ST_DHRECV1;
               FD_CLR(con->localfd, &wrset);
               FD_SET(con->localfd, &rdset);
            } else {
               con->status = ST_DHRECV2;
               FD_CLR(con->remotefd, &wrset);
               FD_SET(con->remotefd, &rdset);
            }
            con->rpos = 0;
#ifdef DEBUG
            err_msg("Connection_DH_senddata: Send out all server DH data to %s=%d", name, fd);
#endif
         } else {
#ifdef DEBUG
            err_msg("Connection_DH_senddata: Send out %d bytes of DH data to %s=%d", n, name, fd);
#endif
            con->rpos += n;
         }
         break;
   }
}

/* Connection_DH_readdata() collects the peer's DH data (DH_DATA_LEN bytes)
   into con->lbuf over as many calls as needed, tracking progress in
   con->lpos. 'from' is FROM_LOCAL, or any other value for the remote
   socket.
   In state ST_DHRECV3 the local data arrives over the already encrypted
   link, so it is read through Readdata(), which also decrypts it.
   Once all data is in, con->lpos is reset and the data is passed to
   Connection_DH_process():
     - AS_ENDPOINT in state ST_DHRECV3,
     - otherwise AS_NORMAL for the local socket,
     - AS_CHAIN for the remote socket.
   A read error or a premature EOF deletes the connection. EINTR and
   EAGAIN/EWOULDBLOCK leave everything as is, so the read is retried on the
   next call.
 */
void Connection_DH_readdata(Connection *con, int from) {
   int n, fd;
   const char *name;

   if(from == FROM_LOCAL) {
      name = "localfd";
      fd = con->localfd;
      if(con->status == ST_DHRECV3) {
         int max = DH_DATA_LEN - con->lpos;
         Readdata(con, FROM_LOCAL, max);
         FD_CLR(con->remotefd, &wrset);	/* Set by Readdata() */
         if(con->lpos == DH_DATA_LEN) {
#ifdef DEBUG
            err_msg("Connection_DH_readdata: Read in all client DH data as endpoint from %s=%d", name, fd);
#endif
            con->lpos = 0;
            Connection_DH_process(con, AS_ENDPOINT);
         }
         return;
      }
      n = read(con->localfd, (con->lbuf + con->lpos), (DH_DATA_LEN - con->lpos));
   } else {	/* FROM_REMOTE	*/
      name = "remotefd";
      fd = con->remotefd;
      n = read(con->remotefd, (con->lbuf + con->lpos), (DH_DATA_LEN - con->lpos));
   }

   /* Non-blocking sockets: an interrupted or would-block read is not fatal.
      The fd stays in rdset and the next call retries. */
   if(n == -1 && (errno == EINTR || errno == EAGAIN || errno == EWOULDBLOCK))
      return;

   switch(n) {
      case -1:
#ifdef DEBUG
         err_ret("Connection_DH_readdata: read error from %s=%d", name, fd);
#endif
         Connection_del(con);
         break;
      case 0:
#ifdef DEBUG
         /* err_msg, not err_ret: EOF does not set errno. */
         err_msg("Connection_DH_readdata: premature end of file from %s=%d", name, fd);
#endif
         Connection_del(con);
         break;
      default:
         con->lpos += n;
         if(con->lpos == DH_DATA_LEN) {
#ifdef DEBUG
            err_msg("Connection_DH_readdata: Read in all client DH data from %s=%d", name, fd);
#endif
            con->lpos = 0;
            if(from == FROM_LOCAL)
               Connection_DH_process(con, AS_NORMAL);
            else
               Connection_DH_process(con, AS_CHAIN);
         } else {
#ifdef DEBUG
            err_msg("Connection_DH_readdata: Read in %d bytes of DH data from %s=%d", n, name, fd);
#endif
         }
         break;
   }
}

/* dh_wipe() overwrites 'len' bytes at 'p' with zeroes through a volatile
   pointer. Unlike a plain memset() on memory that is about to be freed or
   go out of scope, these stores cannot be removed by the optimiser.
 */
static void dh_wipe(void *p, size_t len) {
   volatile unsigned char *vp = p;

   while(len--)
      *vp++ = 0;
}

/* Connection_DH_process() completes a Diffie-Hellman exchange.
   ComputeDH() derives a 64-byte (512 bit) key from the peer's DH data in
   con->lbuf. For AS_CHAIN the server takes the non-server role
   (dh->server = 0) and passes the public key from con->init->pkey to
   ComputeDH(), presumably to verify the peer's signature.
   The key then initialises the arcfour cipher state:
     AS_NORMAL   - con->e/con->d for the client link. The state becomes
                   ST_WAITINF and con->time is re-armed with 'timeout'.
     AS_CHAIN    - con->e2/con->d2 for the next hop. The state becomes
                   ST_CONNECT and con->init is cleared and freed.
     AS_ENDPOINT - con->e2/con->d2, then Connection_process() is re-entered
                   to open the connection to the target.
   The DH state and the local copy of the key are wiped once they are no
   longer needed. If ComputeDH() fails, the connection is deleted.
 */
void Connection_DH_process(Connection *con, int how) {
   unsigned char key[64+1];	/* 512 bit key is exchanged */
   R_RSA_PUBLIC_KEY *Pub = NULL;

   if(how == AS_CHAIN) {
      con->dh->server = 0;
      Pub = &con->init->pkey;
   }

   memset(key, 0, sizeof(key));
   if(ComputeDH(con->lbuf, con->dh, Pub, key) != 0) {
#ifdef DEBUG
      err_msg("Connection_DH_process: ComputeDH failed");
#endif
      dh_wipe(key, sizeof(key));
      Connection_del(con);
      return;
   }

#ifdef DEBUG
   if(debug) {
      char *tmp = base64_encode(key, 64);
      err_msg("Connection_DH_process: DH negotiation succesful");
      err_msg("Connection_DH_process: Agreed key = %s", tmp);
      free(tmp);
   }
#endif

   /* Not needed anymore. NOTE(review): presumably holds our private DH
      value, so clear it before handing the memory back. */
   dh_wipe(con->dh, sizeof(*con->dh));
   free(con->dh);
   con->dh = NULL;

   if(how == AS_NORMAL) {
      arcfour_init(&con->e, key, 64);
      con->d = con->e;	/* Saves a bit of processing time */
      FD_CLR(con->localfd, &wrset);
      FD_SET(con->localfd, &rdset);
      con->status = ST_WAITINF;
      con->time = time(NULL) + timeout;
   } else if(how == AS_CHAIN) {
      arcfour_init(&con->e2, key, 64);
      con->d2 = con->e2;
      FD_SET(con->localfd, &rdset);
      con->status = ST_CONNECT;
      memset(con->init, 0, sizeof(conninit));
      free(con->init);
      con->init = NULL;
   } else if(how == AS_ENDPOINT) {
      arcfour_init(&con->e2, key, 64);
      con->d2 = con->e2;
      Connection_process(con);
   }

   /* The cipher states now carry the key; the raw copy is no longer needed. */
   dh_wipe(key, sizeof(key));
}

/* Connection_eof() runs the half-close protocol when one side of a relayed
   connection reports end-of-file. 'who' is EOF_LOCAL or EOF_REMOT; any
   other value is ignored.
   First EOF on a side:
     - reading from that socket is shut down and it leaves rdset;
     - the state becomes ST_FINISHL (local) or ST_FINISHR (remote);
     - if data from that side is still buffered (con->lpos / con->rpos), the
       other socket is put into wrset so the data can be written out;
       otherwise the other socket is shut down for writing at once.
   Called again for the same side while in its FINISH state (per the debug
   text, after the buffered data was written out), the other socket is shut
   down for writing.
   An EOF on one side while the other side is already finished deletes the
   Connection.
 */
void Connection_eof(Connection *con, int who) {

   if(who == EOF_LOCAL) {
      if(con->status == ST_FINISHR) {
#ifdef DEBUG
         err_msg("Connection_eof: Both localfd=%d and remotefd=%d are finished, deleting structure",
                 con->localfd, con->remotefd);
#endif
         Connection_del(con);
         return;
      }

      if(con->status == ST_FINISHL) {
#ifdef DEBUG
         err_msg("Connection_eof: Wrote out all remaining data to remotefd=%d, closing end for writing",
                 con->remotefd);
#endif
         shutdown(con->remotefd, SHUT_WR);
         return;
      }

      /* First EOF seen from the local side. */
      if(con->lpos == 0) {
#ifdef DEBUG
         err_msg("Connection_eof: EOF for localfd=%d, entering ST_FINISHL state and closing other end for writing",
                 con->localfd);
#endif
         shutdown(con->remotefd, SHUT_WR);
      } else {
#ifdef DEBUG
         err_msg("Connection_eof: EOF for localfd=%d, %d bytes left to write out, entering ST_FINISHL state",
                 con->localfd, con->lpos);
#endif
         FD_SET(con->remotefd, &wrset);	/* still data to flush */
      }

      shutdown(con->localfd, SHUT_RD);	/* We dont want more data */
      FD_CLR(con->localfd, &rdset);
      con->status = ST_FINISHL;
      return;
   }

   if(who != EOF_REMOT)
      return;

   if(con->status == ST_FINISHL) {
#ifdef DEBUG
      err_msg("Connection_eof: Both localfd=%d and remotefd=%d are finished, deleting structure",
              con->localfd, con->remotefd);
#endif
      Connection_del(con);
      return;
   }

   if(con->status == ST_FINISHR) {
#ifdef DEBUG
      err_msg("Connection_eof: Wrote out all remaining data to localfd=%d, closing end for writing",
              con->localfd);
#endif
      shutdown(con->localfd, SHUT_WR);
      return;
   }

   /* First EOF seen from the remote side. */
   if(con->rpos == 0) {
#ifdef DEBUG
      err_msg("Connection_eof: EOF for remotefd=%d, entering ST_FINISHR state and closing other end for writing",
              con->remotefd);
#endif
      shutdown(con->localfd, SHUT_WR);
   } else {
#ifdef DEBUG
      err_msg("Connection_eof: EOF for remotefd=%d, %d bytes left to write out, entering ST_FINISHR state",
              con->remotefd, con->rpos);
#endif
      FD_SET(con->localfd, &wrset);	/* still data to flush */
   }

   shutdown(con->remotefd, SHUT_RD);	/* We dont want more data */
   FD_CLR(con->remotefd, &rdset);
   con->status = ST_FINISHR;
}

/* dec_conninit() RSA-decrypts an encrypted conninit structure.
   'data' holds ENC_CONN_SIZE bytes of ciphertext as consecutive 128-byte
   RSA blocks. Each block yields at most 117 bytes of plaintext (128 minus
   the 11-byte PKCS #1 v1.5 padding overhead). The plaintexts are laid out
   in 'dest' at a stride of 117 bytes; 'dest' must point to a conninit.
   On a decryption error the error is logged and the function returns,
   leaving 'dest' only partly written.

   Each block is decrypted into a scratch buffer first, and at most the bytes
   that still fit in the conninit are copied out. The ciphertext comes from
   the network and is produced with the server's public key, so a client
   could otherwise make the last block decrypt to more bytes than remain in
   'dest' and overrun the caller's buffer.
 */
void dec_conninit(R_RSA_PRIVATE_KEY *pkey, unsigned char *data, unsigned char *dest) {
   enum {
      RSA_BLOCK_LEN = 128,	/* ciphertext bytes per RSA block */
      RSA_PLAIN_LEN = 117,	/* max plaintext per block, 128 - 11 */
      /* RSAPrivateDecrypt() writes at most modulus length - 11 bytes; 512
         covers moduli up to 4096 bits. NOTE(review): confirm against
         MAX_RSA_MODULUS_LEN of the RSA library in use. */
      SCRATCH_LEN = 512
   };
   unsigned char block[SCRATCH_LEN];
   size_t room = sizeof(conninit);	/* bytes of 'dest' not yet passed */
   int i;
   int ret;

   for(i = ENC_CONN_SIZE; i > 0; i -= RSA_BLOCK_LEN) {
      unsigned int outputlen = 0, inputlen = RSA_BLOCK_LEN;
      size_t take;

      if((ret = RSAPrivateDecrypt(block, &outputlen, data, inputlen, pkey)) != 0) {
         err_msg("dec_conninit: RSAPrivateDecrypt() failure: %d", ret);
         return;
      }

      take = outputlen;
      if(take > (size_t)RSA_PLAIN_LEN)
         take = RSA_PLAIN_LEN;
      if(take > room)
         take = room;
      memcpy(dest, block, take);

      if(room <= (size_t)RSA_PLAIN_LEN)
         return;	/* conninit is full; nothing more can be stored */

      room -= RSA_PLAIN_LEN;
      dest += RSA_PLAIN_LEN;
      data += RSA_BLOCK_LEN;
   }
}
