/*
 * Copyright (c) 2002 Jean-Baptiste Marchand, Hervé Schauer Consultants.  
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 * 3. All advertising materials mentioning features or use of this software
 *    must display the following acknowledgement:
 *      This product includes software developed by Jean-Baptiste Marchand
 *	at Hervé Schauer Consultants.
 * 4. The name of the author may not be used to endorse or promote products
 *    derived from this software without specific prior written permission
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

#include <windows.h>
#include <stdio.h>

#include "my_fltdefs.h"
#include "pktflt.h"
#include "filters.h"
#include "filter_stats.h"

#include "logging.h"

/* Path of the text log file that LoggingThread() opens and appends to. */
char Log_File[MAX_PATH];

/* Packet filter log buffer and its size in bytes (both defined elsewhere).
   LoggingThread() parses the frames stored in it, zeroes it and passes it
   back to PfSetLogBuffer(). */
extern DWORD log_buf_size;
extern char *log_buf;
/* Head of the interface list (defined elsewhere) that print_header()
   searches, by system_index, to name the interface a frame refers to. */
extern struct pf_interface *interfaces;

/*
 * Format the beginning of a log line for logFrame into buf (buf is
 * overwritten, not appended to):
 *   "MM/DD/YYYY hh:mm:ss.mmm <ifname> @0:<rule> <reason> "
 *
 * <rule>   is dwFilterRule for PFFT_FILTER frames and 0 otherwise.
 * <reason> is "b", "b-frag" or "b-spoof" according to the frame type,
 *          and "?" for any other type.
 * <ifname> is the name of the entry of the global `interfaces` list whose
 *          system_index equals the frame's dwIPIndex, or "?" if none does.
 *
 * The last successful interface lookup is cached in static variables, so
 * this function is not reentrant.
 */
void print_header(char *buf, PFLOGFRAME *logFrame)
{
	FILETIME localTime;
	SYSTEMTIME logDate;
	const char *block_reason = "?";
	const char *if_name = "?";
	DWORD filterRule = 0;	/* rule number is only meaningful for PFFT_FILTER */
	struct pf_interface *current;
	/* cache of the last successful lookup; NULL means nothing cached yet */
	static DWORD last_system_index;
	static struct pf_interface *last_interface = NULL;

	/* on conversion failure print a zeroed date rather than stack garbage */
	if (!FileTimeToLocalFileTime((FILETIME *) &logFrame->Timestamp, &localTime) ||
	    !FileTimeToSystemTime(&localTime, &logDate))
		memset(&logDate, 0, sizeof logDate);

	switch (logFrame->pfeTypeOfFrame) {
		case PFFT_FILTER:
			filterRule = logFrame->dwFilterRule;
			block_reason = "b";
			break;
		case PFFT_FRAG:
			block_reason = "b-frag";
			break;
		case PFFT_SPOOF:
			block_reason = "b-spoof";
			break;
		default:
			/* unknown frame type: keep rule 0 and reason "?" */
			break;
	}

	if (last_interface != NULL && last_system_index == logFrame->dwIPIndex) {
		current = last_interface;
	} else {
		for (current = interfaces; current != NULL; current = current->next)
			if (current->system_index == logFrame->dwIPIndex)
				break;
		/* only cache real matches; a miss must not leave a NULL to dereference */
		if (current != NULL) {
			last_interface = current;
			last_system_index = current->system_index;
		}
	}

	if (current != NULL)
		if_name = current->name;

	sprintf(buf, "%02u/%02u/%u %02u:%02u:%02u.%03u %s @0:%lu %s ", 
			(unsigned int) logDate.wMonth, (unsigned int) logDate.wDay,
			(unsigned int) logDate.wYear,
			(unsigned int) logDate.wHour, (unsigned int) logDate.wMinute,
			(unsigned int) logDate.wSecond,
			(unsigned int) logDate.wMilliseconds,
			if_name,
			(unsigned long) filterRule, block_reason);
}


/*
 * Append "len <header bytes> <total bytes> " to the string in buf.
 *
 * <header bytes> is the low nibble of ip_vhl multiplied by 4 (the IPv4
 * header length counted in 32-bit words), and <total bytes> is ip_len
 * converted from network byte order.
 */
void print_ip_len(char *buf, struct ip_hdr *iphdr)
{
	unsigned int hdr_bytes = (unsigned int) (iphdr->ip_vhl & 0x0f) << 2;
	unsigned int total_bytes = ntohs(iphdr->ip_len);
	char *tail = buf + strlen(buf);

	sprintf(tail, "len %u %u ", hdr_bytes, total_bytes);
}


/*
 * Append a description of an IP fragment to the string in buf:
 *   "len <hlen> (<iplen>) frag <id>:<payload>@<offset>[+] "
 *
 * <hlen>    header length in bytes (low nibble of ip_vhl times 4)
 * <iplen>   total length from ip_len
 * <id>      IP identification field
 * <payload> iplen - hlen (0 if ip_len is bogus and smaller than hlen)
 * <offset>  fragment offset in bytes (IP_OFFMASK bits times 8)
 * "+"       printed when IP_MF is set
 */
void print_frag(char *buf, struct ip_hdr *iphdr)
{
	unsigned int hlen = (unsigned int) (iphdr->ip_vhl & 0x0f) * 4;
	unsigned int ip_len = ntohs(iphdr->ip_len);
	unsigned int ip_off = ntohs(iphdr->ip_off);
	/* clamp instead of passing a negative int to %u */
	unsigned int payload = ip_len > hlen ? ip_len - hlen : 0;

	/* the trailing space keeps a following "IN"/"OUT" suffix separate,
	   like the one print_ip_len() leaves */
	sprintf(buf + strlen(buf), "len %u (%u) frag %u:%u@%u%s ",
			hlen,
			ip_len,
			(unsigned int) ntohs(iphdr->ip_id),
			payload,
			(ip_off & IP_OFFMASK) << 3,
			(ip_off & IP_MF) ? "+" : "");
}


void print_icmp(char *buf, struct ip_hdr *iphdr, char fragmented)
{
	char *ip_src, *ip_dst;

	struct icmp_hdr *icmphdr;
	char icmp_type_code[64];
	char *icmp_type = NULL;

	ip_src = strdup(inet_ntoa(iphdr->ip_src));
	ip_dst = strdup(inet_ntoa(iphdr->ip_dst));

	if (!fragmented) {
		sprintf(buf + strlen(buf), "%s -> %s PR icmp ", ip_src, ip_dst);
		print_ip_len(buf, iphdr);

		icmphdr = (struct icmp_hdr *) ((char *) iphdr + ((iphdr->ip_vhl & 0x0F) * 4));
			
		memset(icmp_type_code, 0, 64);
		if (icmphdr->icmp_type < ICMP_TYPES)
			icmp_type = icmp_types[icmphdr->icmp_type];
		if (icmp_type == NULL)
			sprintf(icmp_type_code, "icmptype(%d)/", icmphdr->icmp_type);
		else
			sprintf(icmp_type_code, "%s/", icmp_type);


		switch (icmphdr->icmp_type) {
			case ICMP_UNREACH:
					if (icmphdr->icmp_code >= ICMP_UNREACH_NAMES)
						sprintf(icmp_type_code + strlen(icmp_type_code), 
								"%d", icmphdr->icmp_code);
					else
						sprintf(icmp_type_code + strlen(icmp_type_code), 
								"%s", icmp_unreach_names[icmphdr->icmp_code]);
					break;
			
			case ICMP_REDIRECT: 
					if (icmphdr->icmp_code >= ICMP_REDIRECT_NAMES)
						sprintf(icmp_type_code + strlen(icmp_type_code), 
								"%d", icmphdr->icmp_code);
					else
						sprintf(icmp_type_code + strlen(icmp_type_code), 
								"%s", icmp_redirect_names[icmphdr->icmp_code]);
					break;
			case ICMP_TIMEX:
					if (icmphdr->icmp_code >= ICMP_TIMEX_NAMES)
						sprintf(icmp_type_code + strlen(icmp_type_code), 
								"%d", icmphdr->icmp_code);
					else
						sprintf(icmp_type_code + strlen(icmp_type_code), 
								"%s", icmp_timex_names[icmphdr->icmp_code]);
					break;
			case ICMP_PARAM_PROB:
					if (icmphdr->icmp_code >= ICMP_PARAM_PROB_NAMES)
						sprintf(icmp_type_code + strlen(icmp_type_code), 
								"%d", icmphdr->icmp_code);
					else
						sprintf(icmp_type_code + strlen(icmp_type_code), 
								"%s", icmp_param_prob_names[icmphdr->icmp_code]);
					break;
			default:
					sprintf(icmp_type_code + strlen(icmp_type_code),
							"%d", icmphdr->icmp_code);
					break;
		}

	}
	else {
		sprintf(buf + strlen(buf), "%s -> %s PR icmp ", ip_src, ip_dst);
		print_frag(buf, iphdr);
	}

	free(ip_src);
	free(ip_dst);
}


/*
 * Append the TCP part of a log line to the string in buf.
 *
 * Not fragmented:
 *   "<src>,<sport> -> <dst>,<dport> PR tcp len <hlen> <iplen> -<flags> "
 *   where <flags> holds, in this order, A (0x10), R (0x04), S (0x02),
 *   F (0x01), U (0x20) and P (0x08) for each bit set in th_flags.
 * Fragmented:
 *   "<src> -> <dst> PR tcp " followed by print_frag()'s output.
 *
 * inet_ntoa() returns a pointer to a static buffer, so each address is
 * formatted by its own sprintf() call. This replaces unchecked strdup()
 * copies that could hand a NULL pointer to %s.
 */
void print_tcp(char *buf, struct ip_hdr *iphdr, char fragmented)
{
	static const struct {
		unsigned char bit;
		char letter;
	} tcp_flag_names[] = {
		{ 0x10, 'A' }, { 0x04, 'R' }, { 0x02, 'S' },
		{ 0x01, 'F' }, { 0x20, 'U' }, { 0x08, 'P' },
	};
	struct tcp_hdr *tcphdr;
	/* leading '-', at most one letter per flag, terminating NUL */
	char stcp_flags[1 + sizeof tcp_flag_names / sizeof tcp_flag_names[0] + 1];
	size_t i, n;

	if (fragmented) {
		sprintf(buf + strlen(buf), "%s -> ", inet_ntoa(iphdr->ip_src));
		sprintf(buf + strlen(buf), "%s PR tcp ", inet_ntoa(iphdr->ip_dst));
		print_frag(buf, iphdr);
		return;
	}

	tcphdr = (struct tcp_hdr *) ((char *) iphdr + ((iphdr->ip_vhl & 0x0F) * 4));

	sprintf(buf + strlen(buf), "%s,%u -> ",
			inet_ntoa(iphdr->ip_src), (unsigned int) ntohs(tcphdr->th_sport));
	sprintf(buf + strlen(buf), "%s,%u PR tcp ",
			inet_ntoa(iphdr->ip_dst), (unsigned int) ntohs(tcphdr->th_dport));
	print_ip_len(buf, iphdr);

	n = 0;
	stcp_flags[n++] = '-';
	for (i = 0; i < sizeof tcp_flag_names / sizeof tcp_flag_names[0]; i++)
		if (tcphdr->th_flags & tcp_flag_names[i].bit)
			stcp_flags[n++] = tcp_flag_names[i].letter;
	stcp_flags[n] = '\0';

	sprintf(buf + strlen(buf), "%s ", stcp_flags);
}

/*
 * Append the UDP part of a log line to the string in buf.
 *
 * Not fragmented:
 *   "<src>,<sport> -> <dst>,<dport> PR udp len <hlen> <iplen> "
 * Fragmented:
 *   "<src> -> <dst> PR udp " followed by print_frag()'s output.
 *
 * inet_ntoa() returns a pointer to a static buffer, so each address is
 * formatted by its own sprintf() call. This replaces unchecked strdup()
 * copies that could hand a NULL pointer to %s.
 */
void print_udp(char *buf, struct ip_hdr *iphdr, char fragmented)
{
	struct udp_hdr *udphdr;

	if (fragmented) {
		sprintf(buf + strlen(buf), "%s -> ", inet_ntoa(iphdr->ip_src));
		sprintf(buf + strlen(buf), "%s PR udp ", inet_ntoa(iphdr->ip_dst));
		print_frag(buf, iphdr);
		return;
	}

	udphdr = (struct udp_hdr *) ((char *) iphdr + ((iphdr->ip_vhl & 0x0F) * 4));

	sprintf(buf + strlen(buf), "%s,%u -> ",
			inet_ntoa(iphdr->ip_src), (unsigned int) ntohs(udphdr->uh_sport));
	sprintf(buf + strlen(buf), "%s,%u PR udp ",
			inet_ntoa(iphdr->ip_dst), (unsigned int) ntohs(udphdr->uh_dport));
	print_ip_len(buf, iphdr);
}

/*
 * Append the protocol-specific part of a log line to the string in buf,
 * dispatching on the IP protocol number.
 *
 * Protocols 1 (ICMP), 6 (TCP) and 17 (UDP) go to their own formatters,
 * which receive the fragmented flag. Any other protocol is printed as
 *   "<src> -> <dst> PR <proto> len <hlen> <iplen> "
 * and the fragmented flag is not used for it.
 */
void print_ip(char *buf, struct ip_hdr *iphdr, char fragmented)
{
	switch (iphdr->ip_p) {
		case 1:		/* ICMP */
			print_icmp(buf, iphdr, fragmented);
			break;
		case 6:		/* TCP */
			print_tcp(buf, iphdr, fragmented);
			break;
		case 17:	/* UDP */
			print_udp(buf, iphdr, fragmented);
			break;
		default:
			/* inet_ntoa() uses a static buffer: one address per call */
			sprintf(buf + strlen(buf), "%s -> ", inet_ntoa(iphdr->ip_src));
			sprintf(buf + strlen(buf), "%s PR %u ",
					inet_ntoa(iphdr->ip_dst), (unsigned int) iphdr->ip_p);
			/* shared helper; also provides the trailing space that keeps
			   a following "IN"/"OUT" suffix separate */
			print_ip_len(buf, iphdr);
			break;
	}
}


/*
 * Logging thread: appends one text line per logged packet to Log_File.
 *
 * event - HANDLE of the event waited on before each pass (passed as PVOID).
 *
 * Each time the event is signalled:
 *  - the frames in log_buf are formatted with print_header()/print_ip() and
 *    written to the file, with "IN"/"OUT" derived from the matching rule;
 *  - log_buf is zeroed and passed again to PfSetLogBuffer();
 *  - if that call succeeds and reports lost packets, a '#' comment line with
 *    the counters is written;
 *  - the event is reset.
 *
 * Returns 1 if the log file cannot be opened or positioned, or if waiting on
 * the event fails. It does not return otherwise.
 *
 * NOTE(review): log_buf is read, cleared and re-registered in place. If the
 * filter can still append frames between the signal and PfSetLogBuffer(),
 * those frames are overwritten. Confirm against the PfSetLogBuffer contract.
 */
DWORD WINAPI LoggingThread(PVOID event) 
{
	HANDLE loggingEvent = (HANDLE) event;
	HANDLE logFile;
	DWORD status, written;
	DWORD logged, lost, old_size;
	/* bytes of a log frame that precede its packet data */
	DWORD frame_hdr_size = (DWORD) FIELD_OFFSET(PFLOGFRAME, bPacketData);
	DWORD remaining;

	PFLOGFRAME *cur_log_frame;
	/* the print_* formatters write with unbounded sprintf into this buffer */
#define LOG_LINE_SIZE 256
	char buf[LOG_LINE_SIZE];
	struct ip_hdr *iphdr;
	char *cur_log_buf;
	char *log_buf_limit = log_buf + log_buf_size;

	logFile = CreateFile(Log_File, GENERIC_WRITE, FILE_SHARE_READ, NULL, OPEN_ALWAYS, 0, NULL);
	if (logFile == INVALID_HANDLE_VALUE)
		return 1;

	/* append to an existing log instead of overwriting it */
	status = SetFilePointer(logFile, 0, NULL, FILE_END);
	if (status == INVALID_SET_FILE_POINTER)
		goto fail;

	for (;;) {
		status = WaitForSingleObject(loggingEvent, INFINITE);
		if (status == WAIT_FAILED)
			goto fail;

		cur_log_buf = log_buf;
		for (;;) {
			/* cur_log_buf never moves past log_buf_limit, so this is >= 0 */
			remaining = (DWORD) (log_buf_limit - cur_log_buf);

			/* the frame header must fit before any of its fields are read */
			if (remaining < frame_hdr_size)
				break;

			cur_log_frame = (PFLOGFRAME *) cur_log_buf;

			/* log_buf is zeroed before being handed back to the filter, so a
			   zero size marks the end of the logged frames (checked before
			   formatting, so an empty buffer produces no line) */
			if (cur_log_frame->dwTotalSizeUsed == 0)
				break;

			/* sanity check: a frame holds at least its header and ends
			   inside the buffer */
			if (cur_log_frame->dwTotalSizeUsed < frame_hdr_size ||
			    cur_log_frame->dwTotalSizeUsed > remaining)
				break;

			memset(buf, 0, LOG_LINE_SIZE);
			print_header(buf, cur_log_frame);

			iphdr = (struct ip_hdr *) cur_log_frame->bPacketData;
			/* MF set or a non-zero offset: the packet is a fragment */
			if (ntohs(iphdr->ip_off) & (IP_MF | IP_OFFMASK))
				print_ip(buf, iphdr, 1);
			else
				print_ip(buf, iphdr, 0);

			if (cur_log_frame->dwFilterRule != 0) {
				/* we can deduce the direction of packet 
				   with the matching rule number */
				if (cur_log_frame->dwFilterRule < OUT_FILTERS_START)
					strcat(buf, "IN");
				else
					strcat(buf, "OUT");
			}
			strcat(buf, "\n");

			/* best effort: a failed write only loses this line */
			WriteFile(logFile, buf, (DWORD) strlen(buf), &written, NULL);

			cur_log_buf += cur_log_frame->dwTotalSizeUsed;
		}

		memset(log_buf, 0, log_buf_size);
		logged = lost = old_size = 0;
		status = PfSetLogBuffer(log_buf, log_buf_size, 20, 1, &logged, &lost, &old_size);

		/* the counters are only meaningful if the call succeeded */
		if (status == NO_ERROR && lost > 0) {
			/* some packets were lost (i.e, not logged), warn the administrator */
			memset(buf, 0, LOG_LINE_SIZE);
			sprintf(buf, "# %lu packets logged, %lu packets lost, used buffer size was %lu, total buffer size is %lu\n", 
					(unsigned long) logged, (unsigned long) lost,
					(unsigned long) old_size, (unsigned long) log_buf_size);
			WriteFile(logFile, buf, (DWORD) strlen(buf), &written, NULL);
		}

		ResetEvent(loggingEvent);
	}

fail:
	/* do not leak the log file handle on the error paths */
	CloseHandle(logFile);
	return 1;
}