/* $Id: ssh_threads.c,v 1.48 2001/02/11 03:35:15 tls Exp $ */

/*
 * Copyright 1999 RedBack Networks, Incorporated.
 * All rights reserved.
 *
 * This software is not in the public domain.  It is distributed
 * under the terms of the license in the file LICENSE in the
 * same directory as this file.  If you have received a copy of this
 * software without the LICENSE file (which means that whoever gave
 * you this software violated its license) you may obtain a copy from
 * http://www.panix.com/~tls/LICENSE.txt
 */

/*
 * Copyright (c) 2000, 2001 Andrew Brown, Eric Haszlakiewicz,
 * 	and Jason R. Thorpe.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 * 3. The name of the author may not be used to endorse or promote products
 *    derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED
 * AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 */

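/*
 * This file contains the send and receive thread loops that drive
 * an ssh connection: for an sshd server that has just received an
 * incoming connection and, on the paths guarded by
 * !running_as_server, for the client.
 */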

#include <errno.h>
#include <signal.h>
#include <string.h>
#include <sys/time.h>
#include <sys/types.h>
#include <unistd.h>
#include <sys/wait.h>
#include <netinet/in.h>

#include "options.h"

#include "sshd.h"

#include "ssh_buffer.h"
#include "ssh_event.h"
#include "ssh_v1_child.h"
#include "ssh_v1_proto.h"
#include "ssh_v1_messages.h"
#include "ssh_v2_proto.h"
#include "ssh_sys.h"
#include "ssh_transport.h"
#include "ssh_util.h"
#include "ssh_types.h"
#include "ssh_client.h"
#include "ssh_threads.h"

static void ssh_read_from_fd(ssh_context_t *, fd_set *);

void 
ssh_send_thread(struct ssh_context * context)
{
	FUNC_DECL(ssh_send_thread);
	fd_set readfds;
	struct ssh_event ev;
	int retval;
	int maxfd;
	int i;

	context->thread_id = SSH_THREAD_SEND;
	ssh_sys_collapse_eventq(context);

	SSH_DLOG(2, ("thread starting\n"));
	if (is_debug_level(6))
		sleep(6);
	/*
	 * Take data from:
	 *    1) Event queue
	 *    2) Local file descriptors
	 *
	 */

/*
 * XAX XXX
 * transport layer max packet size > max data to give to transport layer >=
 * per channel max packet size >= per channel size to read
 */
	while (1) {
		/* XXX Revisit this after defining channel data structures */
		if (context->v1_ctx.sent_exitstat &&
		    !context->v1_ctx.have_other_channels) {
			/* Primary channel exited and we have no other
			 * channels, so exit */
			ssh_exit(context, 0, EXIT_NOW);
		}
		FD_ZERO(&readfds);
		maxfd = 0;
		for (i = 0; i < context->nfds; i++) {
			FD_SET(context->fdi[i].fd, &readfds);
			if (context->fdi[i].fd > maxfd)
				maxfd = context->fdi[i].fd;
		}
		maxfd++;
		if (ssh_sys_select(context, maxfd, &readfds, NULL, NULL,
				   NULL, &ev) < 0) {
			SSH_ERROR("ssh_sys_select returned fatal error\n");
			ssh_exit(context, 1, EXIT_NOW);
		}
		if (ev.event_type != SSH_EVENT_NONE) {
			retval = ssh_dispatch_event(context, &ev);
			if (retval < 0) {
				SSH_ERROR("ssh_dispatch_event returned "
				    "fatal error\n");
				ssh_exit(context, 1, EXIT_NOW);
			}
			if (retval == RET_FAIL) {
				/* Our connection probably went away, exit */
				ssh_exit(context, 0, EXIT_NOW);
			}
			if (retval == RET_NEXT_STATE) {
				/* we're done here...let's just exit */
				ssh_sys_exit(0);
			}
		} else
			ssh_read_from_fd(context, &readfds);
	}
}

#ifdef WITH_PROTO_V1
/*
 * Act on a transport xmit_packet() return value.  A negative value is a
 * fatal transport error (exit status 1).  A value of 1 means the
 * connection went away (exit status 0).  Any other value returns
 * normally.
 */
static void
ssh_check_xmit_result(ssh_context_t *context, int retval)
{
	FUNC_DECL(ssh_check_xmit_result);

	if (retval < 0) {
		SSH_ERROR("xmit_packet returned fatal error\n");
		ssh_exit(context, 1, EXIT_NOW);
	} else if (retval == 1) {
		/* Connection went away. */
		ssh_exit(context, 0, EXIT_NOW);
	}
}

/*
 * Handle EOF on a V1 data descriptor.  This is treated as EOF on the
 * primary channel: the shell/command exited, so the exit status must be
 * sent.  If we do not have it yet, it is sent once it arrives.
 */
static void
ssh_note_primary_eof(ssh_context_t *context)
{
	FUNC_DECL(ssh_note_primary_eof);

	context->v1_ctx.primary_eof = YES;

	SSH_DLOG(5, ("exitstat %d %d %d", context->v1_ctx.have_exitstat,
	    context->v1_ctx.sent_exitstat, context->nfds));
	if (context->v1_ctx.have_exitstat && !context->v1_ctx.sent_exitstat) {
		/* Note: the exitstatus buf has the entire payload. */
		ssh_check_xmit_result(context,
		    (context->transport_ctx.xmit_packet)(context,
		    context->v1_ctx.exitstatus));
		/* Set flag so we exit if no other channels are open. */
		SSH_DLOG(5, ("sent_exitstat set to 1"));
		context->v1_ctx.sent_exitstat = YES;
	}
	/* Else, wait until we get the exitstatus. */
}

/*
 * Tell the peer that local stdin reached EOF by transmitting a one-byte
 * SSH_V1_CMSG_EOF message.  Allocation failure is fatal.
 */
static void
ssh_send_v1_eof(ssh_context_t *context)
{
	FUNC_DECL(ssh_send_v1_eof);
	struct ssh_buf eof_buf;
	int retval;

	memset(&eof_buf, 0, sizeof(eof_buf));
	if (buf_alloc(&eof_buf, 1) == NULL ||
	    buf_put_int8(&eof_buf, SSH_V1_CMSG_EOF) != 0) {
		SSH_ERROR("Unable to alloc eof buf: %s\n", strerror(errno));
		ssh_exit(context, 1, EXIT_NOW);
	}
	retval = (context->transport_ctx.xmit_packet)(context, &eof_buf);
	buf_cleanup(&eof_buf);
	ssh_check_xmit_result(context, retval);
}

/*
 * Close and forget context->fdi[idx].  This frees its transmit buffer,
 * detaches the fd from its channel and compacts the table by moving the
 * last entry into the vacated slot.  The old last slot is then zeroed,
 * so no stale copy of a live entry (which would share its fd and
 * xmitbuf) is left beyond context->nfds.
 */
static void
ssh_remove_fdinfo(ssh_context_t *context, int idx)
{
	FUNC_DECL(ssh_remove_fdinfo);
	ssh_fdinfo_t *fdi = context->fdi;
	int last = context->nfds - 1;

	SSH_DLOG(4, ("Removed fd: %d  index:%d\n", fdi[idx].fd, idx));
	close(fdi[idx].fd);
	free(fdi[idx].xmitbuf);
	fdi[idx].xmitbuf = NULL;
	fdi[idx].readbuf = NULL;
	if (fdi[idx].fd == fdi[idx].p_chan->fd_normal)
		fdi[idx].p_chan->fd_normal = -1;
	if (fdi[idx].fd == fdi[idx].p_chan->fd_ext_stderr)
		fdi[idx].p_chan->fd_ext_stderr = -1;
	if (last > idx)
		memcpy(&fdi[idx], &fdi[last], sizeof(ssh_fdinfo_t));
	memset(&fdi[last], 0, sizeof(ssh_fdinfo_t));
	context->nfds--;
	SSH_DLOG(4, ("%d fds remaining", context->nfds));
}

/*
 * Transmit the nread bytes just read into fdi->readbuf.  The first byte
 * of fdi->xmitbuf already holds the V1 message type for this fd (V2 has
 * a different format), so only the binary-string length needs to be
 * filled in before the packet is handed to the transport.
 */
static void
ssh_forward_fd_data(ssh_context_t *context, ssh_fdinfo_t *fdi, int nread)
{
	struct ssh_buf xbuf;

	/* Client with an escape char on a tty: prescan (may alter nread). */
	if (!context->running_as_server &&
	    context->client->EscapeChar != NO_ESCAPE_CHAR &&
	    context->client->StdinIsTty)
		prescan_input(context, fdi->readbuf, &nread);

	setbinstr_len(&fdi->xmitbuf[1], nread);
	buf_setbuf(&xbuf, fdi->xmitbuf, nread + fdi->header_size);
	ssh_check_xmit_result(context,
	    (context->transport_ctx.xmit_packet)(context, &xbuf));
}
#endif /* WITH_PROTO_V1 */

/*
 * Service every descriptor in context->fdi that select() marked readable
 * in readfds.
 *
 * - Data read is framed into the descriptor's xmitbuf and transmitted.
 * - EOF, or a read error other than EINTR, closes and removes the
 *   descriptor.  For stdin, SSH_V1_CMSG_EOF is sent first.  The pending
 *   exit status is sent once no descriptors remain.
 * - A descriptor of unknown type is fatal.
 */
static void 
ssh_read_from_fd(ssh_context_t *context, fd_set *readfds)
{
	FUNC_DECL(ssh_read_from_fd);
	ssh_fdinfo_t *fdi;
	int i;
	int nread;
	int read_errno;

	/*
	 * Removing an entry shrinks context->nfds and moves the last entry
	 * into slot i.  So re-check the bound on every pass, and advance i
	 * only when slot i was kept.
	 */
	i = 0;
	while (i < context->nfds) {
		fdi = &context->fdi[i];
		if (!FD_ISSET(fdi->fd, readfds)) {
			i++;
			continue;
		}
		switch (fdi->fd_type) {
		/* V1 only.  V2 uses CHANNEL_DATA */
#ifdef WITH_PROTO_V1
		case SSH_FD_STDDATA:
		case SSH_FD_STDINDATA:
			nread = read(fdi->fd, fdi->readbuf, fdi->max_read_size);
			read_errno = errno;	/* logging below may clobber errno */
			if (nread > 0) {
				ssh_forward_fd_data(context, fdi, nread);
				break;
			}
			if (nread < 0) {
				SSH_DLOG(1, ("read stddata failed: %d\n",
				    read_errno));
				if (read_errno == EINTR)
					break;	/* interrupted; keep the fd */
			} else
				ssh_note_primary_eof(context);

			/* EOF or hard read error: remove the fd. */
			if (fdi->fd_type == SSH_FD_STDINDATA)
				ssh_send_v1_eof(context);
			ssh_remove_fdinfo(context, i);
			SSH_DLOG(5, ("exitstat %d %d",
			    context->v1_ctx.have_exitstat,
			    context->v1_ctx.sent_exitstat));
			if (context->nfds == 0 &&
			    context->v1_ctx.have_exitstat &&
			    !context->v1_ctx.sent_exitstat) {
				SSH_DLOG(4, ("sending exit status at last"));
				ssh_check_xmit_result(context,
				    (context->transport_ctx.xmit_packet)(context,
				    context->v1_ctx.exitstatus));
				SSH_DLOG(5, ("sent_exitstat set to 1"));
				context->v1_ctx.sent_exitstat = YES;
			}
			/* Slot i now holds a different, unserviced entry. */
			continue;
#endif
		default:
			SSH_ERROR("Unknown fd type: %d\n", fdi->fd_type);
			ssh_exit(context, 1, EXIT_NOW);
		}
		i++;
	}
}

/* -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- */

/*
 * Called when context->client->sigwinch is set.  Clears the flag and
 * re-reads the tty size into context->win.  It then builds an
 * SSH_V1_CMSG_WINDOW_SIZE message (type byte, then row, col, xpixel and
 * ypixel as 32-bit values) and posts it as an SSH_EVENT_SEND event via
 * EVT_SEND.  If the buffer cannot be built, the resize is logged and
 * dropped rather than crashing.
 */
static void
ssh_send_window_change(ssh_context_t *context)
{
	FUNC_DECL(ssh_send_window_change);
	struct ssh_buf *mybuf;
	struct ssh_event ev;
	struct ssh_ev_send *evsend;

	context->client->sigwinch = 0;
	ssh_sys_get_tty_size(context->client->ttyfd, &context->win);

	/* type byte + four 32-bit values = 17 bytes */
	mybuf = buf_alloc(NULL, 1 + 4 * 4);
	if (mybuf == NULL) {
		SSH_ERROR("Unable to allocate window change buffer\n");
		return;
	}
	if (buf_put_int8(mybuf, SSH_V1_CMSG_WINDOW_SIZE) != 0 ||
	    buf_put_int32(mybuf, context->win.ws_row) != 0 ||
	    buf_put_int32(mybuf, context->win.ws_col) != 0 ||
	    buf_put_int32(mybuf, context->win.ws_xpixel) != 0 ||
	    buf_put_int32(mybuf, context->win.ws_ypixel) != 0) {
		SSH_ERROR("Unable to build window change message\n");
		buf_cleanup(mybuf);
		free(mybuf);
		return;
	}

	ev.event_type = SSH_EVENT_SEND;
	evsend = (struct ssh_ev_send *) ev.event_data;
	SSH_DLOG(1, ("sending window change to server"));
	evsend->dlen = buf_alllen(mybuf);
	memcpy(evsend->data, buf_alldata(mybuf), buf_alllen(mybuf));
	buf_cleanup(mybuf);
	free(mybuf);
	EVT_SEND(&ev, context);
}

/*
 * Main loop of the receive thread.
 *
 * Selects the conversation start routine for context->protocol_version
 * and runs it.  It then repeatedly reads packets with the transport's
 * read_packet and passes each complete message to dispatch_msg.
 * Returns only when dispatch_msg returns RET_NEXT_STATE (the client
 * path).  Errors terminate through ssh_exit().
 */
void 
ssh_recv_thread(ssh_context_t *context,
    int (*dispatch_msg)(ssh_context_t *, struct ssh_buf *, int))
{
	FUNC_DECL(ssh_recv_thread);

	struct ssh_transport *t_ctx;
	struct ssh_buf msg;
	int retval;
	int (*start_conv) (ssh_context_t *);
	int (*read_pkt) (ssh_context_t *, struct ssh_buf *);

	SSH_DLOG(2, ("thread starting\n"));
	/* Presumably a pause to allow attaching a debugger; verify. */
	if (is_debug_level(6))
		sleep(6);

	/*
	 * msg is handed to read_pkt empty (zeroed) and read_pkt fills it.
	 * Do not pre-allocate it here.
	 */
	memset(&msg, 0, sizeof(msg));

	t_ctx = &context->transport_ctx;
	read_pkt = t_ctx->read_packet;

	context->thread_id = SSH_THREAD_RECV;
	ssh_sys_collapse_eventq(context);

	switch (context->protocol_version) {
#ifdef WITH_PROTO_V1
	case SSH_V1:
		start_conv = start_v1_conversation;
		break;
#endif
#ifdef WITH_PROTO_V2
	case SSH_V2:
		start_conv = start_v2_conversation;
		break;
#endif
	default:
		SSH_ERROR("Invalid protocol version in context!");
		ssh_exit(context, 1, EXIT_NOW);
		return;	/* never call an unset start_conv */
	}

	/* Do anything necessary to start the conversation */
	if (start_conv(context) != 0) {
		SSH_ERROR("start_conv failed");
		ssh_exit(context, 1, EXIT_NOW);
	}

	if (!context->running_as_server && context->client->Verbosity)
		fprintf(stderr, "%s: Waiting for server public key.\r\n",
			context->client->localhost);

	/* Get data from the socket connection and dispatch it. */
	for (;;) {
		/* Client only: forward a pending terminal resize first. */
		if (!context->running_as_server && context->client->sigwinch)
			ssh_send_window_change(context);

		/*
		 * Read a packet from the transport layer.
		 * If this returns a positive size it means
		 * that we have at least enough decrypted data
		 * in msg to have an entire message.
		 */
		if ((retval = read_pkt(context, &msg)) < 0) {
			if (errno == EINTR)
				continue;	/* interrupted; retry */
			if (context->send_pid == -1)
				SSH_DLOG(1,
				  ("ssh_send_thread exited.  Exiting ssh_recv_thread\n"));
			else
				SSH_ERROR("read_pkt failed: %s\n",
				    strerror(errno));
			ssh_exit(context, 1, EXIT_NOW);
		}
		if (retval != 0) {
			/* Figure out what to do with the message */
			/* Note: this function may cause us to exit */
			if ((retval = dispatch_msg(context, &msg,
						   buf_len(&msg))) < 0) {
				/* Shouldn't happen! */
				SSH_ERROR("dispatch_msg failed: %s\n",
				    strerror(errno));
				ssh_exit(context, 1, EXIT_NOW);
			}
			if (retval == RET_NEXT_STATE)
				return;	/* clients go here */
		}
	}
}
