/* autrace.c -- trace a program's system calls using audit rules
 * Copyright 2005 Red Hat Inc., Durham, North Carolina.
 * All Rights Reserved.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 *
 * Authors:
 *     Steve Grubb <sgrubb@redhat.com>
 */

#include "config.h"
#include <stdio.h>
#include <string.h>
#include <sys/wait.h>
#include <unistd.h>
#include <fcntl.h>
#include <stdlib.h>
#include <errno.h>
#include <time.h>
#include "libaudit.h"

/*
 * This program adds an audit rule that records every system call made
 * by a process, similar to strace, and then executes the program to be
 * traced in that process.
 */

/* Forward declarations of the helpers defined after main(). */
static int count_rules(void);	/* loaded rules + watches, or -1 on error */
static int count_em(int fd);	/* AUDIT_LIST replies read from fd, or -1 on error */

/* Write the one-line command synopsis to stderr. */
static void usage(void)
{
    fputs("usage: autrace program\n", stderr);
}

/*
 * Algorithm:
 * check that user is root
 * check that the program is executable
 * refuse to run if any audit rules or watches are already loaded
 * create a pipe and fork; the child blocks reading the pipe
 * parent loads a rule auditing all syscalls made by the child's pid
 * parent writes one byte to the pipe to release the child
 * child exec's the program once released; if the pipe is closed without
 *   a byte being written the child exits without exec'ing anything
 * parent waits for the child to exit, then deletes all audit rules
 *
 * argv[1] is the program to trace and argv[2..] are passed to it.
 * Returns 0 when the trace was run, 1 on a usage or setup error.
 */
int main(int argc, char *argv[])
{
	int fd[2];
	pid_t pid;

	if (argc < 2) {
		usage();
		return 1;
	}
	if (strcmp(argv[1], "-h") == 0) {
		usage();
		return 1;
	}
	if (getuid() != 0) {
		fprintf(stderr, "You must be root to run this program.\n");
		return 1;
	}
	/* NOTE(review): access() resolves argv[1] relative to the current
	 * directory while execvp() below searches PATH for names without a
	 * slash - confirm that callers are expected to pass a full path. */
	if (access(argv[1], X_OK)) {
		if (errno == ENOENT)
			fprintf(stderr, "Error - can't find: %s\n", argv[1]);
		else
			fprintf(stderr, "Error checking %s (%s)\n",
				argv[1], strerror(errno));
		return 1;
	}
	set_aumessage_mode(MSG_STDERR, DBG_NO);
	switch (count_rules())
	{
		case -1:
			fprintf(stderr, "Error - can't get rule count.\n");
			return 1;
		case 0:
			break;
		default:
			/* Cleanup below deletes every rule, so do not start
			 * when somebody else's rules would be wiped out. */
			fprintf(stderr,
			"autrace cannot be run with rules loaded.\n Please "
			"delete all rules using 'auditctl -D' if you really "
			"wanted to run this command.\n");
			return 1;
	}
	if (pipe(fd) != 0) {
		fprintf(stderr, "Error creating pipe.\n");
		return 1;
	}

	switch ((pid = fork()))
	{
		case -1:
			fprintf(stderr, "Error forking.\n");
			close(fd[0]);
			close(fd[1]);
			return 1;
		case 0: /* Child */
		{
			char go;
			ssize_t got;

			close(fd[1]);
			printf("Waiting to execute: %s\n", argv[1]);
			/* exec throws away unflushed stdio buffers, so the
			 * message would be lost when stdout is not a tty */
			fflush(stdout);
			do {
				got = read(fd[0], &go, 1);
			} while (got == -1 && errno == EINTR);
			close(fd[0]);
			/* EOF or error means the parent never released us;
			 * do not run the program untraced */
			if (got != 1)
				_exit(1);
			execvp(argv[1], &argv[1]);
			fprintf(stderr, "Failed to exec %s\n", argv[1]);
			/* _exit: do not run the parent's exit handlers */
			_exit(1);
		}
		default: /* Parent */
		{
			char rule[64];
			int rc, ok;
			ssize_t sent;

			close(fd[0]);
			/* keep the write end out of the auditctl children */
			fcntl(fd[1], F_SETFD, FD_CLOEXEC);
			snprintf(rule, sizeof(rule),
			"/sbin/auditctl -a entry,always -F pid=%d -S all",
				(int)pid);
			rc = system(rule);
			ok = (rc != -1 && WIFEXITED(rc) &&
					WEXITSTATUS(rc) == 0);
			if (!ok)
				fprintf(stderr,
					"Error - can't add the audit rule.\n");
			else {
				/* presumably gives the new rule time to
				 * take effect - TODO confirm */
				sleep(1);
				do {
					sent = write(fd[1], "1", 1);
				} while (sent == -1 && errno == EINTR);
				if (sent != 1) {
					fprintf(stderr,
					"Error - can't start the child.\n");
					ok = 0;
				}
			}
			/* On failure nothing was written, so closing the pipe
			 * makes the child see EOF and exit. */
			close(fd[1]);
			while (waitpid(pid, NULL, 0) == -1 && errno == EINTR)
				/* blank */ ;
			puts("Cleaning up...");
			/* keep our output ahead of anything auditctl prints */
			fflush(stdout);
			(void)system("/sbin/auditctl -D");
			if (!ok)
				return 1;
			printf("Trace complete. "
				"You can locate the records with "
				"\'ausearch -i -p %d\'\n",
				(int)pid);
			break;
		}
	}

	return 0;
}

/*
 * count_rules - count the audit rules and watches currently loaded.
 *
 * Opens an audit connection, requests the rule list and then the watch
 * list, and adds up the entries that count_em() reads for each request.
 * A watch list request that fails with -EINVAL counts as zero watches.
 *
 * Returns the combined number of rules and watches, or a negative value
 * if the connection, a request, or reading the replies failed.
 */
static int count_rules(void)
{
	int fd, total, count, rc;

	fd = audit_open();
	if (fd < 0)
		return -1;	/* nothing was opened, so nothing to close */

	total = -1;
	rc = audit_request_rules_list(fd);
	if (rc <= 0)
		goto done;
	total = count_em(fd);
	if (total < 0)
		goto done;

	rc = audit_request_watch_list(fd);
	if (rc > 0)
		count = count_em(fd);
	else if (rc == -EINVAL)
		count = 0; /* watches not supported in this kernel */
	else
		count = -1;

	if (count < 0)
		total = -1;
	else
		total += count;

done:
	close(fd);
	return total;
}

/*
 * count_em - count the AUDIT_LIST replies that arrive on an audit fd.
 *
 * fd is an open audit descriptor on which a list request has already
 * been sent. Each pass waits up to .1 second for the descriptor to become
 * readable and then tries a non-blocking audit_get_reply(). Every
 * AUDIT_LIST reply is counted and restarts the timeout.
 *
 * Returns the number of AUDIT_LIST replies seen when NLMSG_DONE arrives
 * or when roughly 4 seconds pass without progress, and -1 when an
 * NLMSG_ERROR reply carrying a non-zero error code is received.
 */
static int count_em(int fd)
{
	int i, retval, count = 0;
	int timeout = 40; /* each pass waits up to .1 - this is 4 seconds */
	struct audit_reply rep;

	for (i = 0; i < timeout; i++) {
		struct timeval t;
		fd_set read_mask;

		/* Wait before reading so that an idle socket costs .1 second
		 * per pass instead of spinning through the whole timeout.
		 * select() may overwrite both the set and the timeout, so
		 * they are rebuilt for every attempt. */
		do {
			FD_ZERO(&read_mask);
			FD_SET(fd, &read_mask);
			t.tv_sec  = 0;
			t.tv_usec = 100000; /* .1 second */
			retval = select(fd+1, &read_mask, NULL, NULL, &t);
		} while (retval < 0 && errno == EINTR);

		retval = audit_get_reply(fd, &rep, GET_REPLY_NONBLOCKING, 0);
		if (retval <= 0)
			continue; /* nothing usable yet, try again */

		switch (rep.type)
		{
			case NLMSG_DONE:
				return count;
			case AUDIT_LIST:
				i = 0; /* progress made, restart the timeout */
				count++;
				break;
			case NLMSG_ERROR:
				/* an error code of 0 is not a failure
				 * (netlink acknowledgement); skip it */
				if (rep.error->error == 0)
					break;
				return -1;
			default:
				break;
		}
	}
	return count;
}

