/*
 * Copyright (C) 2000 Avaya Labs, Avaya Inc.
 * Copyright (C) 1999 Bell Labs, Lucent Technologies.
 * Copyright (C) Arash Baratloo, Timothy Tsai, and Navjot Singh.
 *
 * This file is part of the Libsafe library.
 * Libsafe version 2.x: protecting against stack smashing attacks.
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 * For more information, 
 *
 *   visit http://www.research.avayalabs.com/project/libsafe/index.html
 *   or email libsafe@research.avayalabs.com
 */

/* 
 * unsafe functions that are supported:
 *              strcpy(3), strcat(3), sprintf(3), vsprintf(3),
 *              getwd(3), gets(3), realpath(3),
 *              fscanf(3), scanf(3), sscanf(3)
 * safe but supported (as we must): 
 *              memcpy(3)
 * might be problematic, but I can't figure out why:
 *              getopt(3), getpass(3), index(3), streadd(?) 
 */

#include <stdio.h>		/* defines stdin */
#include <stdarg.h>		/* defines va_args */
#define __NO_STRING_INLINES 1	/* stops the inline expansion of strcpy() */
#define __USE_GNU 1		/* defines strnlen() */
#include <string.h>		/* defines strncat() */
#include <unistd.h>		/* defines getcwd(), readlink() */
#include <sys/param.h>		/* MAXPATHLEN for realpath() */
#include <limits.h>		/* PATH_MAX for getwd(); actually in */
				/*     linux/limits.h */
#include <pwd.h>		/* defines getpass() */
#include <errno.h>		/* defines errno */
#include <dlfcn.h>		/* defines dlsym() */
#include "util.h"
#include "log.h"

/*
 * -----------------------------------------------------------------
 * ----------------- system library prototypes ---------------------
 *
 * Function-pointer types mirroring the signatures of the library
 * functions that this file redefines.  The functions below cast the
 * void * returned by getLibraryFunction() to one of these types so
 * that the original implementation can be called.
 */
typedef void *(*memcpy_t) (void *dest, const void *src, size_t n);
typedef char *(*strcpy_t) (char *dest, const char *src);
typedef char *(*strcat_t) (char *dest, const char *src);
typedef char *(*getwd_t) (char *buf);
typedef char *(*gets_t) (char *s);
/* 'path' is non-const here to match the realpath() definition in this file. */
typedef char *(*realpath_t) (char *path, char resolved_path[]);
#ifdef ORIG_PRINTF_INTERCEPT
/*
 * Only compiled when ORIG_PRINTF_INTERCEPT is defined.
 * NOTE(review): printf_function and printf_arginfo_function are not declared
 * by any header visibly included in this file; presumably <printf.h> is
 * required when this option is enabled -- confirm.
 */
typedef int (*vfprintf_t) (FILE *s, const char *format, va_list ap);
typedef int (*register_printf_function) (int __spec, printf_function __func,
	printf_arginfo_function __arginfo);
#endif

/*
 * -----------------------------------------------------------------
 * ------------------- utility functions ---------------------------
 */
#ifndef __USE_GNU
/*
 * Fallback strnlen(): returns the number of bytes in 's' before the first
 * NUL, but never more than 'count'.  Compiled only when __USE_GNU is not
 * defined.
 *
 * NOTE(review): this file defines __USE_GNU unconditionally before including
 * <string.h>, so as written this fallback is never compiled.
 *
 * The body is 32-bit x86 inline assembly (AT&T syntax):
 *   - %0 (eax, the result) starts as a copy of 's' and is used as the cursor;
 *   - label 2 decrements 'count' (edx) and loops back to label 1 until the
 *     counter reaches -1, i.e. after 'count' bytes have been examined;
 *   - label 1 stops at a NUL byte, otherwise advances the cursor;
 *   - label 3 subtracts 's' from the cursor, leaving the length in %0.
 *
 * NOTE(review): edx is both an input operand ("d") and listed as a clobber
 * ("dx"); compilers may reject overlapping operand/clobber registers --
 * confirm before enabling this path.
 */
inline size_t strnlen(const char *s, size_t count)
{
    register int __res;
    __asm__ __volatile__("movl %1,%0\n\t"
			 "jmp 2f\n"
			 "1:\tcmpb $0,(%0)\n\t"
			 "je 3f\n\t"
			 "incl %0\n"
			 "2:\tdecl %2\n\t"
			 "cmpl $-1,%2\n\t"
			 "jne 1b\n" "3:\tsubl %1,%0":"=a"(__res)
			 :"c"(s), "d"(count)
			 :"dx");
    return __res;
}
#endif

/*
 * returns a pointer to the implementation of 'funcName' in
 * the libc library.  If not found, terminates the program.
 */
/*
 * Looks up the next definition of 'funcName' in the library search order
 * via dlsym(RTLD_NEXT, ...), i.e. the implementation that the definition
 * in this file shadows.
 *
 * funcName: NUL-terminated symbol name.
 * Returns:  the non-NULL symbol address.  If the lookup yields NULL, a
 *           diagnostic is written to stderr and the process ends through
 *           _exit(1), so callers never see a NULL result.
 */
static void *getLibraryFunction(const char *funcName)
{
    void *sym = dlsym(RTLD_NEXT, funcName);

    if (sym == NULL) {
	/*
	 * dlerror() returns NULL when no error is pending (dlsym() may
	 * legitimately yield NULL for a symbol whose value is NULL), and
	 * passing a NULL pointer for %s is undefined behaviour.
	 */
	const char *reason = dlerror();

	fprintf(stderr, "dlsym %s error:%s\n", funcName,
		(reason != NULL) ? reason : "symbol not found");
	_exit(1);
    }
    return sym;
}


/* Starting with version 2.0, we keep a single global copy of the pointer to
 * the real memcpy() function.  This allows us to call
 * getLibraryFunction("memcpy") just once instead of multiple times, since
 * memcpy() is needed in four different functions below.
 *
 * The pointer starts out NULL; every function that needs it checks for NULL
 * and resolves it through getLibraryFunction() before the first use.
 */
static memcpy_t real_memcpy = NULL;


/*
 * -------------- system library implementations -------------------
 * Here is the story: if a C source file includes <string.h> and is
 * compiled with -O, then by default strcpy() is expanded (to several
 * memcpy()'s and a strcpy()) just like a macro.  Thus, it is wise to
 * bounds-check memcpy().  Furthermore, because the string "strcpy(,)"
 * gets expanded even when the function is being declared, this code
 * will not compile if optimized unless __NO_STRING_INLINES is defined
 * (see the end of /usr/include/string.h).  This is obviously a
 * compiler/header-file specific thing.  I am using gcc version
 * egcs-2.91.66.
 */
char *strcpy(char *dest, const char *src)
{
    static strcpy_t real_strcpy = NULL;
    size_t max_size, len;

    if (!real_memcpy)
	real_memcpy = (memcpy_t) getLibraryFunction("memcpy");
    if (!real_strcpy)
	real_strcpy = (strcpy_t) getLibraryFunction("strcpy");

    if ((max_size = _libsafe_stackVariableP(dest)) == 0) {
	LOG(5, "strcpy(<heap var> , <src>)\n");
	return real_strcpy(dest, src);
    }

    LOG(4, "strcpy(<stack var> , <src>) stack limit=%d)\n", max_size);
    /*
     * Note: we can't use the standard strncpy()!  From the strncpy(3) manual
     * pages: In the case where the length of 'src' is less than that of
     * 'max_size', the remainder of 'dest' will be padded with nulls.  We do
     * not want null written all over the 'dest', hence, our own
     * implementation.
     */
    if ((len = strnlen(src, max_size)) == max_size)
	_libsafe_die("Overflow caused by strcpy()");
    real_memcpy(dest, src, len + 1);
    return dest;
}

/*
 * Interposed memcpy(3).  This is needed because an optimized strcpy() may
 * be expanded into memcpy() calls; see the comment preceding strcpy(). -ab.
 *
 * When _libsafe_stackVariableP(dest) reports a non-zero limit and 'n'
 * exceeds it, the overflow is reported through _libsafe_die().  In every
 * other case the copy is carried out by the real memcpy() and its result
 * is returned.
 */
void *memcpy(void *dest, const void *src, size_t n)
{
    size_t limit;

    if (real_memcpy == NULL)
	real_memcpy = (memcpy_t) getLibraryFunction("memcpy");

    limit = _libsafe_stackVariableP(dest);
    if (limit != 0) {
	LOG(4, "memcpy(<stack var> , <src>, %d) stack limit=%d)\n", n, limit);
	if (n > limit)
	    _libsafe_die("Overflow caused by memcpy()");
    } else {
	LOG(5, "memcpy(<heap var> , <src>, %d)\n", n);
    }

    return real_memcpy(dest, src, n);
}


char *strcat(char *dest, const char *src)
{
    static strcat_t real_strcat = NULL;
    size_t max_size;
    uint dest_len, src_len;

    if (!real_memcpy)
	real_memcpy = (memcpy_t) getLibraryFunction("memcpy");
    if (!real_strcat)
	real_strcat = (strcpy_t) getLibraryFunction("strcat");

    if ((max_size = _libsafe_stackVariableP(dest)) == 0) {
	LOG(5, "strcat(<heap var> , <src>)\n");
	return real_strcat(dest, src);
    }

    LOG(4, "strcat(<stack var> , <src>) stack limit=%d\n", max_size);
    dest_len = strnlen(dest, max_size + 1);
    src_len = strnlen(src, max_size);

    if (dest_len + src_len >= max_size)
	_libsafe_die("Overflow caused by strcat()");

    real_memcpy(dest + dest_len, src, src_len + 1);

    return dest;
}


#ifdef ORIG_PRINTF_INTERCEPT
int vfprintf(FILE *s, const char *format, va_list ap)
{
    static vfprintf_t real_vfprintf = NULL;

    if (!real_vfprintf)
	real_vfprintf = (vfprintf_t) getLibraryFunction("vfprintf");

    return real_vfprintf(s, format, ap);
}


int _IO_vfprintf(FILE *s, const char *format, va_list ap)
{
    static vfprintf_t real_vfprintf = NULL;

    if (!real_vfprintf)
	real_vfprintf = (vfprintf_t) getLibraryFunction("_IO_vfprintf");

    return real_vfprintf(s, format, ap);
}
#endif


/*
 * Interposed getwd(3).
 *
 * _libsafe_stackVariableP(buf) is asked for a size limit.  A result of 0
 * is treated as "not a stack variable" and the call is forwarded to the
 * real getwd().  Otherwise the working directory is fetched with getcwd(),
 * bounded by the smaller of that limit and PATH_MAX, so nothing is ever
 * written beyond the limit.
 *
 * Returns 'buf' on success and NULL when getcwd() fails (errno is left as
 * set by getcwd()).  If the path would have fit in PATH_MAX bytes but not
 * within the stack limit, the overflow is reported through _libsafe_die().
 */
char *getwd(char *buf)
{
    static getwd_t real_getwd = NULL;
    size_t max_size, limit;
    char *res;

    if (!real_getwd)
	real_getwd = (getwd_t) getLibraryFunction("getwd");

    if ((max_size = _libsafe_stackVariableP(buf)) == 0) {
	LOG(5, "getwd(<heap var>)\n");
	return real_getwd(buf);
    }

    LOG(4, "getwd(<stack var>) stack limit=%d\n", max_size);

    /*
     * Bound the write up front instead of letting getcwd() store up to
     * PATH_MAX bytes and checking the length afterwards.
     */
    limit = (max_size < (size_t) PATH_MAX) ? max_size : (size_t) PATH_MAX;
    res = getcwd(buf, limit);

    /*
     * ERANGE with a limit below PATH_MAX means the stack limit, not the
     * system path limit, was too small.  'buf' is not inspected after a
     * failure because its contents are then unspecified.
     */
    if (res == NULL && errno == ERANGE && limit < (size_t) PATH_MAX)
	_libsafe_die("Overflow caused by getwd()");

    return res;
}


/*
 * Interposed gets(3).
 *
 * _libsafe_stackVariableP(s) is asked for a size limit.  A result of 0 is
 * treated as "not a stack variable" and the call is forwarded to the real
 * gets().  Otherwise the line is read from stdin with fgets(), bounded by
 * that limit; a longer line is truncated and the rest stays unread.
 *
 * Returns 's' on success with the trailing newline removed, or NULL when
 * fgets() reports end-of-file or an error before any byte was stored.
 */
char *gets(char *s)
{
    static gets_t real_gets = NULL;
    size_t max_size, len;
    int limit;

    if (!real_gets)
	real_gets = (gets_t) getLibraryFunction("gets");

    if ((max_size = _libsafe_stackVariableP(s)) == 0) {
	LOG(5, "gets(<heap var>)\n");
	return real_gets(s);
    }

    LOG(4, "gets(<stack var>) stack limit=%d\n", max_size);

    /* fgets() takes an int; avoid a narrowing conversion of the limit. */
    limit = (max_size > (size_t) INT_MAX) ? INT_MAX : (int) max_size;

    /*
     * On end-of-file or error nothing usable was stored in 's', so it
     * must not be measured; report the failure the way gets(3) does.
     */
    if (fgets(s, limit, stdin) == NULL)
	return NULL;

    len = strlen(s);

    /* len can be 0 (e.g. the input starts with a NUL byte). */
    if (len > 0 && s[len - 1] == '\n')
	s[len - 1] = '\0';
    return s;
}


/*
 * Interposed realpath(3).
 *
 * _libsafe_stackVariableP(resolved_path) is asked for a size limit.  A
 * result of 0 is treated as "not a stack variable" and the call is
 * forwarded to the real realpath().  Otherwise the path is resolved into
 * a local buffer first and copied to 'resolved_path' only if it fits,
 * terminator included, within that limit; if it does not, the overflow is
 * reported through _libsafe_die().
 *
 * Returns 'resolved_path' when the real realpath() succeeded, else NULL.
 */
char *realpath(char *path, char resolved_path[])
{
    static realpath_t real_realpath = NULL;
    size_t max_size, len;
    char *res;
    char buf[MAXPATHLEN + 1];

    if (!real_memcpy)
	real_memcpy = (memcpy_t) getLibraryFunction("memcpy");
    if (!real_realpath)
	real_realpath = (realpath_t) getLibraryFunction("realpath");

    if ((max_size = _libsafe_stackVariableP(resolved_path)) == 0) {
	LOG(5, "realpath(<src>, <heap var>)\n");
	return real_realpath(path, resolved_path);
    }

    LOG(4, "realpath(<src>, <stack var>) stack limit=%d\n", max_size);

    /*
     * realpath(3) is documented to produce at most PATH_MAX bytes, which
     * is what 'buf' is sized for.  Start with an empty string: if the real
     * realpath() fails without writing anything, the length check and the
     * copy below would otherwise read uninitialised stack memory.
     */
    buf[0] = '\0';
    res = real_realpath(path, buf);
    if ((len = strnlen(buf, max_size)) == max_size)
	_libsafe_die("Overflow caused by realpath()");

    /* Copied on failure as well, so the caller sees what the real call left. */
    real_memcpy(resolved_path, buf, len + 1);
    return (res == NULL) ? NULL : resolved_path;
}



/*
 * -----------------------------------------------------------------
 * ------------- initializer and finalizer--------------------------
 */
/* Both hooks are registered through GCC function attributes. */
static void _intercept_init(void) __attribute__ ((constructor));
static void _intercept_fini(void) __attribute__ ((destructor));


/* Constructor hook: only emits a level-4 log line. */
static void _intercept_init(void)
{
    LOG(4, "beginning of _intercept_init()\n");
}


/* Destructor hook: only emits a level-4 log line. */
static void _intercept_fini(void)
{
    LOG(4, "end of _intercept_fini()\n");
}
