/* $Id: ssh-add.c,v 1.5 2001/02/11 03:35:08 tls Exp $ */

/*
 * Copyright (c) 2001 Jason R. Thorpe.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 * 3. The name of the author may not be used to endorse or promote products
 *    derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED
 * AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 */

/*
 * Utility for adding keys to the FreSSH authentication agent.
 *
 * TODO:
 *
 *	- Add support for the SSHv2 authentication agent protocol.
 */

#include <sys/types.h>
#include <sys/wait.h>
#include <errno.h>
#include <string.h>
#include <signal.h>
#include <unistd.h>

#include "sshd.h"
#include "ssh_buffer.h"
#include "ssh_authagent.h"
#include "ssh_environ.h"
#include "ssh_defines.h"
#include "ssh_global.h"
#include "ssh_parse.h"
#include "ssh_main.h"
#include "ssh_ui.h"
#include "pathnames.h"

#include "crypto/ssh_cipher.h"

/*
 * Program name used in the usage messages.  We'd like to use
 * __progname if we have it.
 */
#define	FRESSH_ADD	"fressh-add"

/* Operating modes; main() selects one from the command-line flags. */
enum {
	FRESSH_ADD_ADD = 0,		/* default: add the named key files */
	FRESSH_ADD_LIST = 1,		/* -l: list the agent's keys */
	FRESSH_ADD_DELETE = 2,		/* -d: delete the named keys */
	FRESSH_ADD_DELETEALL = 3	/* -D: delete all agent keys */
};

int	main(int argc, char *argv[]);
void	cleanup(void);
void	usage(void);

void	delete_all(void);
void	list_all(void);
int	use_keyfile(char *idfile, int mode);
int	add_key(ssh_RSA *key, char *keyname);
int	delete_key(ssh_RSA *key);

int	get_passphrase(const char *keyname, char **passphrase);
int	use_ssh_askpass(const char *prompt, char **passphrase);

/* Agent session opened by main() and closed by cleanup() at exit. */
struct authagent_state *authagent;

int
main(int argc, char *argv[])
{
	int ch, i, mode, nfiles, status;
	char *idfile;

	mode = FRESSH_ADD_ADD;
	status = 0;

	g_context.running_as_server = SSH_ROLE_NEITHER;

	/* Initialize logging, only to stderr. */
	/* XXX We only need this for library routines. */
	loginit("ssh-add", 2);

	/* Close the agent session (see cleanup()) however we exit. */
	if (atexit(cleanup) != 0) {
		fprintf(stderr, "Unable to register cleanup routine: %s\n",
		    strerror(errno));
		exit(1);
	}

	authagent = ssh_authagent_opensession(&g_context);
	if (authagent == NULL) {
		fprintf(stderr,
		    "Unable to open an authentication agent session.\n");
		exit(1);
	}

	/* At most one mode flag may be given. */
	while ((ch = getopt(argc, argv, "dDl")) != -1) {
		if (mode != FRESSH_ADD_ADD)
			usage();
		switch (ch) {
		case 'd':
			mode = FRESSH_ADD_DELETE;
			break;

		case 'D':
			mode = FRESSH_ADD_DELETEALL;
			break;

		case 'l':
			mode = FRESSH_ADD_LIST;
			break;

		default:
			usage();
		}
	}
	argc -= optind;
	argv += optind;

	/* -D and -l operate on the agent alone and take no operands. */
	switch (mode) {
	case FRESSH_ADD_DELETEALL:
		if (argc != 0)
			usage();
		delete_all();
		exit(0);

	case FRESSH_ADD_LIST:
		if (argc != 0)
			usage();
		list_all();
		exit(0);

	default:
		/* ADD or DELETE: handled below. */
		break;
	}

	/*
	 * Look up the invoking user.  Presumably this fills in
	 * g_context.pwent, which use_keyfile() reads to expand "~/".
	 */
	if (ssh_sys_lookupuid(&g_context, getuid()) < 0) {
		fprintf(stderr, "No user database entry for UID %lu.\n",
		    (unsigned long) getuid());
		exit(1);
	}

	/*
	 * Process each named key file, or the default identity if none
	 * were named.  use_keyfile() may free and replace the string it
	 * is given, so every call gets its own heap copy.  Keep going
	 * after a failure, but report it in the exit status.
	 */
	nfiles = (argc == 0) ? 1 : argc;
	for (i = 0; i < nfiles; i++) {
		idfile = strdup(argc == 0 ? def_ssh_IdentityFile : argv[i]);
		if (idfile == NULL) {
			fprintf(stderr, "Unable to allocate memory: %s\n",
			    strerror(errno));
			exit(1);
		}
		if (use_keyfile(idfile, mode) != 0)
			status = 1;
	}

	exit(status);
}

/*
 * usage:
 *
 *	Print SSHD_REV with a one-line description, followed by the
 *	command synopsis, to stderr; then exit with status 1.  Never
 *	returns.
 */
void
usage(void)
{
	const char *prog = FRESSH_ADD;

	(void) fprintf(stderr,
	    "%s Add keys to the Authentication Agent (SSHv1)\n", SSHD_REV);
	(void) fprintf(stderr, "Usage: %s [-d] [keyfile [...]]\n", prog);
	(void) fprintf(stderr, "       %s -D\n", prog);
	(void) fprintf(stderr, "       %s -l\n", prog);

	exit(1);
}

/*
 * cleanup:
 *
 *	atexit(3) handler registered by main(): close the authentication
 *	agent session, if one was opened.
 */
void
cleanup(void)
{

	if (authagent == NULL)
		return;		/* never opened, or the open failed */
	ssh_authagent_closesession(&g_context, authagent);
}

/*
 * delete_all:
 *
 *	Ask the authentication agent to remove all of its SSHv1 RSA
 *	identities.  Returns only if the agent replies SUCCESS.  On any
 *	other outcome the program exits with status 1, after printing a
 *	diagnostic (except when the request exchange itself fails).
 */
void
delete_all(void)
{
	struct ssh_buf *req;
	u_int8_t replcode;

	req = buf_alloc(NULL, 16);
	if (req == NULL) {
		fprintf(stderr, "Unable to allocate request buffer: %s\n",
		    strerror(errno));
		exit(1);
	}

	if (buf_put_byte(req, SSH_AGENT_V1_REMOVE_ALL_RSA_ID) != 0) {
		fprintf(stderr, "Unable to put request into buffer: %s\n",
		    strerror(errno));
		buf_cleanup(req);
		free(req);
		exit(1);
	}

	/*
	 * Once the exchange is over the request is no longer needed.
	 * buf_alloc(NULL, ...) presumably allocated the buffer header
	 * itself, so release it with free() after buf_cleanup(), the
	 * same way add_key() and delete_key() do.
	 */
	if (ssh_authagent_request(authagent->sock, req, &authagent->buf) != 0) {
		buf_cleanup(req);
		free(req);
		buf_cleanup(&authagent->buf);
		exit(1);
	}
	buf_cleanup(req);
	free(req);

	if (buf_get_int8(&authagent->buf, &replcode) != 0) {
		fprintf(stderr, "Unable to get auth agent reply code.\n");
		buf_cleanup(&authagent->buf);
		exit(1);
	}
	buf_cleanup(&authagent->buf);

	switch (replcode) {
	case SSH_AGENT_V1_FAILURE:
		fprintf(stderr, "Auth agent reports FAILURE.\n");
		exit(1);

	case SSH_AGENT_V1_SUCCESS:
		/* Hooray, worked. */
		return;

	default:
		fprintf(stderr,
		    "Auth agent sent unknown reply: 0x%02x\n", replcode);
		exit(1);
	}
	/* NOTREACHED */
}

/*
 * list_all:
 *
 *	Print to stdout the encode_public_keyfile() text of every SSHv1
 *	key held by the authentication agent.  If the agent reports a
 *	key count of zero, or the first key cannot be fetched, a "no
 *	keys" notice is printed instead.  Exits with status 1 if a key
 *	cannot be encoded.
 */
void
list_all(void)
{
	ssh_RSA *key = NULL;
	char *comment = NULL;
	char *line;

	if (ssh_authagent_v1_getcount(&g_context, authagent) == 0 ||
	    ssh_authagent_v1_getkey(&g_context, authagent, &key,
	    &comment) != 0) {
		printf("Authentication agent has no SSHv1 keys.\n");
		return;
	}

	for (;;) {
		if (encode_public_keyfile(&g_context, &line, key,
		    comment) < 0) {
			ssh_rsa_free(&key);
			free(comment);
			fprintf(stderr, "Unable to get public key string\n");
			exit(1);
		}
		fputs(line, stdout);
		free(line);

		/* Release this key before asking for the next one. */
		ssh_rsa_free(&key);
		key = NULL;
		free(comment);
		comment = NULL;

		/* Stop as soon as the agent hands back no further key. */
		if (ssh_authagent_v1_getkey(&g_context, authagent, &key,
		    &comment) != 0)
			break;
	}
}

/*
 * TILDE_EXPAND:
 *
 *	If the string (s) begins with "~/", replace it with a newly
 *	allocated copy in which the '~' is replaced by the home
 *	directory (c)->pwent.pw_dir, and free the original.  (s) must
 *	therefore be a modifiable lvalue holding a malloc'd string, and
 *	it is evaluated more than once.  If the new string cannot be
 *	allocated, (s) is left untouched.
 *
 *	Note that we are willing to expand a leading "~/" to the user's
 *	own home directory, but we're *not* willing to go trawling all
 *	over the place looking for ~foo.
 */
#define	TILDE_EXPAND(c, s)						\
do {									\
	if ((s)[0] == '~' && (s)[1] == '/') {				\
		/* Dropping '~' and adding the NUL cancel out. */	\
		size_t te_l = strlen((c)->pwent.pw_dir) + strlen(s);	\
		char *te_t = malloc(te_l);				\
									\
		if (te_t != NULL) {					\
			(void) snprintf(te_t, te_l, "%s%s",		\
			    (c)->pwent.pw_dir, (s) + 1);		\
			free((s));					\
			(s) = te_t;					\
		}							\
	}								\
} while (0)

/*
 * use_keyfile:
 *
 *	Perform an ADD or DELETE (according to "mode") using the key
 *	stored in the file "idfile".
 *
 *	A leading "~/" in the file name is expanded with TILDE_EXPAND.
 *	That may free "idfile" and replace it, so the caller must pass a
 *	heap-allocated string and must not use it afterwards.
 *
 *	If the key is encrypted, the user is prompted for a passphrase
 *	until the key decodes or an empty passphrase is entered.  The
 *	exception is a DELETE of a key whose public parts (e and n) are
 *	readable without one.
 *
 *	Returns 0 on success.  Returns 1 if the file cannot be read or
 *	decoded, no usable passphrase is obtained, or the agent request
 *	fails.
 */
int
use_keyfile(char *idfile, int mode)
{
	void *keydata = NULL;
	char *keyname = NULL;
	char *passphrase = NULL;
	ssh_RSA *key = NULL;
	off_t keysize = 0;
	int rv, ret = 1;

	TILDE_EXPAND(&g_context, idfile);

	keydata = ssh_sys_readin(idfile, &keysize);
	if (keydata == NULL) {
		fprintf(stderr, "Unable to read keyfile %s\n", idfile);
		goto out;
	}

 try_again:
	rv = decode_keyfile(&g_context, keydata, keysize, passphrase, 0,
	    &key, &keyname, NULL);
	switch (rv) {
	case DECODE_KEYFILE_OK:
		/* The passphrase has served its purpose; scrub it now. */
		if (passphrase != NULL) {
			memset(passphrase, 0, strlen(passphrase));
			free(passphrase);
			passphrase = NULL;
		}
		break;

	case DECODE_KEYFILE_PASSPHRASE:
		/*
		 * Encrypted key -- if we're deleting a key, and
		 * we have the epart and npart, then that's all
		 * we need.
		 */
		if (mode == FRESSH_ADD_DELETE &&
		    key != NULL &&
		    ssh_rsa_epart(key) != NULL &&
		    ssh_rsa_npart(key) != NULL)
			break;

		/*
		 * Discard any partial key before retrying.  Clear the
		 * pointer as well, so the cleanup at "out" cannot free
		 * it a second time.
		 */
		if (key != NULL) {
			ssh_rsa_free(&key);
			key = NULL;
		}

		/*
		 * Hm.  In the DELETE case, should we use the public
		 * key file?
		 */

		/*
		 * get_passphrase() scrubs and frees any previous
		 * (rejected) passphrase before prompting again.  It can
		 * report success with a NULL result (e.g. a failed
		 * strdup() in the askpass path), so check for that too.
		 */
		if (get_passphrase(keyname, &passphrase) != 0 ||
		    passphrase == NULL)
			goto out;
		if (passphrase[0] == '\0') {
			/* An empty passphrase ends the retry loop. */
			fprintf(stderr,
			    "Bad passphrase for key '%s'.\n", keyname);
			goto out;
		}
		free(keyname);
		keyname = NULL;
		goto try_again;

	default:
		fprintf(stderr,
		    "Unable to decode keyfile %s\n", idfile);
		goto out;
	}

	/* The key has been decoded; the raw file contents can go. */
	ssh_sys_readrelease(keydata, keysize);
	keydata = NULL;

	if (mode == FRESSH_ADD_ADD)
		rv = add_key(key, keyname);
	else
		rv = delete_key(key);

	if (rv == 0) {
		fprintf(stderr, "Identity %s (%s) %s.\n",
		    idfile, keyname,
		    mode == FRESSH_ADD_ADD ? "added" : "removed");
		ret = 0;
	}

 out:
	/* Single exit path: release whatever was acquired. */
	if (keydata != NULL)
		ssh_sys_readrelease(keydata, keysize);
	free(keyname);
	if (key != NULL)
		ssh_rsa_free(&key);
	if (passphrase != NULL) {
		memset(passphrase, 0, strlen(passphrase));
		free(passphrase);
	}
	return (ret);
}

/*
 * add_key:
 *
 *	Hand the RSA private key "key", labelled with the comment
 *	"keyname", to the authentication agent.  The request carries the
 *	modulus size in bits, then the n, e, d, iqmp, q and p components,
 *	then the comment.
 *
 *	Returns 0 if the agent replies SUCCESS.  Returns -1 on a local
 *	error, an agent FAILURE, or an unrecognized reply.
 */
int
add_key(ssh_RSA *key, char *keyname)
{
	struct ssh_buf *req;
	const char *errfmt = NULL;
	u_int8_t replcode;

	if ((req = buf_alloc(NULL, 16)) == NULL) {
		fprintf(stderr, "Unable to allocate request buffer: %s\n",
		    strerror(errno));
		return (-1);
	}

	/*
	 * Marshal the request.  Evaluation stops at the first step that
	 * fails, and that step's (constant) message is reported below.
	 */
	if (buf_put_byte(req, SSH_AGENT_V1_ADD_RSA_ID) != 0)
		errfmt = "Unable to put request into buffer: %s\n";
	else if (buf_put_int32(req,
	    bignum_num_bits(ssh_rsa_npart(key))) != 0)
		errfmt = "Unable to put keybits into buffer: %s\n";
	else if (buf_put_bignum(req, ssh_rsa_npart(key)) != 0)
		errfmt = "Unable to put npart into buffer: %s\n";
	else if (buf_put_bignum(req, ssh_rsa_epart(key)) != 0)
		errfmt = "Unable to put epart into buffer: %s\n";
	else if (buf_put_bignum(req, ssh_rsa_dpart(key)) != 0)
		errfmt = "Unable to put dpart into buffer: %s\n";
	else if (buf_put_bignum(req, ssh_rsa_iqmppart(key)) != 0)
		errfmt = "Unable to put iqmppart into buffer: %s\n";
	else if (buf_put_bignum(req, ssh_rsa_qpart(key)) != 0)
		errfmt = "Unable to put qpart into buffer: %s\n";
	else if (buf_put_bignum(req, ssh_rsa_ppart(key)) != 0)
		errfmt = "Unable to put ppart into buffer: %s\n";
	else if (buf_put_binstr(req, (u_int8_t *) keyname,
	    strlen(keyname)) != 0)
		errfmt = "Unable to put comment into buffer: %s\n";

	if (errfmt != NULL) {
		fprintf(stderr, errfmt, strerror(errno));
		goto bad;
	}

	if (ssh_authagent_request(authagent->sock, req, &authagent->buf) != 0) {
		buf_cleanup(&authagent->buf);
		goto bad;
	}

	if (buf_get_int8(&authagent->buf, &replcode) != 0) {
		fprintf(stderr, "Unable to get auth agent reply code.\n");
		buf_cleanup(&authagent->buf);
		goto bad;
	}
	buf_cleanup(&authagent->buf);

	if (replcode == SSH_AGENT_V1_SUCCESS) {
		/* Hooray, worked. */
		buf_cleanup(req);
		free(req);
		return (0);
	}

	if (replcode == SSH_AGENT_V1_FAILURE)
		fprintf(stderr, "Auth agent reports FAILURE.\n");
	else
		fprintf(stderr,
		    "Auth agent sent unknown reply: 0x%02x\n", replcode);

 bad:
	buf_cleanup(req);
	free(req);
	return (-1);
}

/*
 * delete_key:
 *
 *	Ask the authentication agent to forget the RSA key "key".  Only
 *	the public half is sent: the modulus size in bits, then the e
 *	and n components.
 *
 *	Returns 0 if the agent replies SUCCESS.  Returns 1 on a local
 *	error, an agent FAILURE, or an unrecognized reply.
 */
int
delete_key(ssh_RSA *key)
{
	struct ssh_buf *req;
	const char *errfmt = NULL;
	u_int8_t replcode;

	if ((req = buf_alloc(NULL, 16)) == NULL) {
		fprintf(stderr, "Unable to allocate request buffer: %s\n",
		    strerror(errno));
		return (1);
	}

	/*
	 * Marshal the request.  Evaluation stops at the first step that
	 * fails, and that step's (constant) message is reported below.
	 */
	if (buf_put_byte(req, SSH_AGENT_V1_REMOVE_RSA_ID) != 0)
		errfmt = "Unable to put request into buffer: %s\n";
	else if (buf_put_int32(req,
	    bignum_num_bits(ssh_rsa_npart(key))) != 0)
		errfmt = "Unable to put keybits into buffer: %s\n";
	else if (buf_put_bignum(req, ssh_rsa_epart(key)) != 0)
		errfmt = "Unable to put epart into buffer: %s\n";
	else if (buf_put_bignum(req, ssh_rsa_npart(key)) != 0)
		errfmt = "Unable to put npart into buffer: %s\n";

	if (errfmt != NULL) {
		fprintf(stderr, errfmt, strerror(errno));
		goto bad;
	}

	if (ssh_authagent_request(authagent->sock, req, &authagent->buf) != 0) {
		buf_cleanup(&authagent->buf);
		goto bad;
	}

	if (buf_get_int8(&authagent->buf, &replcode) != 0) {
		fprintf(stderr, "Unable to get auth agent reply code.\n");
		buf_cleanup(&authagent->buf);
		goto bad;
	}
	buf_cleanup(&authagent->buf);

	if (replcode == SSH_AGENT_V1_SUCCESS) {
		/* Hooray, worked. */
		buf_cleanup(req);
		free(req);
		return (0);
	}

	if (replcode == SSH_AGENT_V1_FAILURE)
		fprintf(stderr, "Auth agent reports FAILURE.\n");
	else
		fprintf(stderr,
		    "Auth agent sent unknown reply: 0x%02x\n", replcode);

 bad:
	buf_cleanup(req);
	free(req);
	return (1);
}

/*
 * get_passphrase:
 *
 *	Get the passphrase for the key named by "keyname" from the user.
 *
 *	If *passphrase is non-NULL on entry, it is taken to be a previous,
 *	rejected attempt.  It is scrubbed and freed, and the prompt is
 *	prefixed with "Bad passphrase.".
 *
 *	When stdin is not a terminal, the prompt goes to the external
 *	askpass program (use_ssh_askpass()); otherwise ssh_ui_prompt()
 *	is used.
 *
 *	Returns 0 with the new passphrase in *passphrase, or -1 with
 *	*passphrase set to NULL.
 */
int
get_passphrase(const char *keyname, char **passphrase)
{
	char *prompt;
	size_t len;
	int retry, rv;

#define	FIRST_PROMPT	"Enter passphrase for RSA key '%s': "
#define	NEXT_PROMPT	"Bad passphrase.  " FIRST_PROMPT

	/* Never hand NULL to strlen()/snprintf("%s"). */
	if (keyname == NULL)
		keyname = "";

	/*
	 * Decide which prompt to use *before* freeing the old
	 * passphrase: a pointer's value must not be inspected after it
	 * has been passed to free().
	 */
	retry = (*passphrase != NULL);
	if (retry) {
		memset(*passphrase, 0, strlen(*passphrase));
		free(*passphrase);
		*passphrase = NULL;
	}

	len = (retry ? sizeof(NEXT_PROMPT) : sizeof(FIRST_PROMPT)) +
	    strlen(keyname) + 1;
	prompt = malloc(len);
	if (prompt == NULL)
		return (-1);
	(void) snprintf(prompt, len, retry ? NEXT_PROMPT : FIRST_PROMPT,
	    keyname);

	if (isatty(fileno(stdin)) == 0) {
		/*
		 * Not an interactive session -- try to use an
		 * X11 passphrase getter.
		 */
		rv = use_ssh_askpass(prompt, passphrase);
	} else
		rv = ssh_ui_prompt(passphrase, prompt, 0);

	/*
	 * The prompt stays ours (the error path always freed it), so
	 * release it on success as well instead of leaking it.
	 */
	free(prompt);

	if (rv != 0) {
		*passphrase = NULL;
		return (-1);
	}
	return (0);
}

/*
 * use_ssh_askpass:
 *
 *	Invoke an external askpass program (e.g. an X11 passphrase
 *	getter) to get a passphrase.  The program named by the
 *	SSH_ENVVAR_ASKPASS_PROG environment variable is used, or
 *	_PATH_ASKPASS_PROG if that is unset.  It is run with "prompt" as
 *	its only argument.
 *
 *	The passphrase is the first line the program writes to its
 *	stdout, without the newline and truncated to sizeof(buf) - 1
 *	bytes.  If the program writes nothing, including when it could
 *	not be exec'd, the passphrase is "".
 *
 *	Returns 0 with a malloc'd string in *passphrase.  Returns -1 if
 *	the pipe or fork fails or memory runs out.
 */
int
use_ssh_askpass(const char *prompt, char **passphrase)
{
	char buf[128], *cp;
	const char *prog;
	pid_t pid;
	ssize_t nr;
	size_t len;
	int p[2], estat;

	prog = getenv(SSH_ENVVAR_ASKPASS_PROG);
	if (prog == NULL)
		prog = _PATH_ASKPASS_PROG;

	if (pipe(p) < 0) {
		fprintf(stderr,
		    "Unable to create askpass communication channel: %s\n",
		    strerror(errno));
		return (-1);
	}

	/* Flush pending output before forking. */
	fflush(stdout);

	switch ((pid = fork())) {
	case -1:
		fprintf(stderr,
		    "Unable to fork for askpass program: %s\n",
		    strerror(errno));
		(void) close(p[0]);
		(void) close(p[1]);
		return (-1);

	case 0:
		/* Child -- exec the askpass program, stdout on the pipe. */
		(void) close(p[0]);
		if (dup2(p[1], fileno(stdout)) < 0) {
			fprintf(stderr,
			    "Unable to connect stdout to communication "
			    "channel: %s\n", strerror(errno));
			_exit(1);
		}
		if (p[1] != fileno(stdout))
			(void) close(p[1]);
		/* Variadic sentinel must be a null *pointer* of char * type. */
		execlp(prog, prog, prompt, (char *) NULL);
		fprintf(stderr,
		    "Unable to exec %s: %s\n", prog, strerror(errno));
		_exit(1);

	default:
		/* Parent -- handled below. */
		break;
	}

	(void) close(p[1]);

	/*
	 * Collect output until a newline arrives, EOF, a read error,
	 * or the buffer (less room for the terminating NUL) is full.
	 * A pipe may deliver the line in several pieces, so keep
	 * reading rather than trusting a single read().
	 */
	len = 0;
	while (len < sizeof(buf) - 1) {
		nr = read(p[0], buf + len, sizeof(buf) - 1 - len);
		if (nr < 0) {
			if (errno == EINTR)
				continue;
			break;
		}
		if (nr == 0)
			break;
		cp = memchr(buf + len, '\n', (size_t) nr);
		len += (size_t) nr;
		if (cp != NULL)
			break;
	}
	buf[len] = '\0';
	(void) close(p[0]);

	while (waitpid(pid, &estat, 0) < 0) {
		if (errno != EINTR)
			break;
	}

	/* Keep only the first line. */
	cp = strchr(buf, '\n');
	if (cp != NULL)
		*cp = '\0';
	*passphrase = strdup(buf);
	memset(buf, 0, sizeof(buf));

	if (*passphrase == NULL) {
		fprintf(stderr,
		    "Unable to allocate memory for passphrase.\n");
		return (-1);
	}
	return (0);
}
