/*
 * STAN - Stream Analyser
 * Copyright (c) 2001 Konrad Rieck <kr@r0q.cx>
 * The Roqefellaz, http://www.r0q.cx/stan
 *
 * Implementation of a Treap.
 * This treap combines a binary tree for the patterns and a heap for the
 * counts. Depth is NOT guaranteed to be O(lg n), since the priorities are
 * occurrence counts rather than random numbers. If a pattern that already
 * exists is inserted, its count is increased; therefore nodes can only move
 * up in the tree during insert.
 * Several parts of the code have been borrowed from "Introduction to
 * Algorithms" by Cormen/Leiserson/Rivest. 
 */

#include <sys/types.h>
#include <stdlib.h>
#include <stdio.h>

#include <stan.h>
#include <treap.h>
#include <imath.h>
#include <string.h>

#ifdef DEBUG
/*
 * Debug-build invariant flags. They start out TRUE and are set to FALSE by
 * check_heap() (a child's count exceeds its parent's) and check_tree() (a
 * child is on the wrong side of its parent by memcmp order).
 */
size_t heap_okay = TRUE;
size_t tree_okay = TRUE;
#endif

/*
 * Table of treap roots allocated by init_treap(). Slot treap[patlen - 1] is
 * the root that tinsert()/tdelete()/rotate_*() update for operations called
 * with that patlen. A NULL slot is an empty treap.
 */
tnode_t **treap;

void init_treap(size_t patlen)
{
    /*
     * Allocate the table of treap roots: patlen slots, all NULL (empty).
     * Slot i is the root used by operations called with patlen == i + 1.
     * Allocation failure is fatal; continuing would only crash later on
     * the first access to treap[i].
     *
     * NOTE(review): a table from a previous call is not released here.
     */
    treap = calloc(patlen, sizeof *treap);

    /* calloc(0, ...) may legitimately return NULL, so only fail for patlen > 0. */
    if (!treap && patlen > 0) {
	fprintf(stderr, "init_treap: out of memory\n");
	exit(EXIT_FAILURE);
    }
}

void free_treap(size_t patlen)
{
    /*
     * Free every node of the first patlen treaps and reset their root slots
     * to NULL, so the table can be reused and a repeated call is harmless.
     * The traversal is postorder because tfree() releases the node it is
     * given: children must be visited before their parent is freed.
     *
     * NOTE(review): the root table itself is not freed here (the original
     * code never did); confirm whether a caller releases it.
     */
    size_t i;

    if (!treap)
	return;

    for (i = 0; i < patlen; i++) {
	dprintf(STRONG, ("freeing treap[%lu]\n", (unsigned long) i));
	ttraverse(&treap[i], tfree, postorder);
	treap[i] = NULL;	/* don't leave a dangling root behind */
    }
}

tnode_t **tlocate(tnode_t ** tnode, byte_t * pattern, size_t len)
{
    /*
     * Find the slot for pattern in the tree whose root slot is *tnode,
     * comparing len bytes with memcmp().
     *
     * Returns the address of the slot holding the matching node, or, if
     * the pattern is absent, the address of the empty (NULL) slot where
     * it would be linked in.
     *
     * Implemented as a loop rather than recursion because the treap depth
     * is not bounded by O(lg n) (priorities are counts, not random).
     */
    int cmp;

    for (;;) {
	dprintf(STRONG, ("tlocate: tnode 0x%p\n", (void *) tnode));

	if (!*tnode)
	    return tnode;

	dprintf(STRONG,
		("tlocate: tnode 0x%p exists, checking pattern\n",
		 (void *) tnode));

	/* Compare once per level instead of twice. */
	cmp = memcmp((*tnode)->pattern, pattern, len);
	if (cmp > 0)
	    tnode = &(*tnode)->left;
	else if (cmp < 0)
	    tnode = &(*tnode)->right;
	else
	    break;
    }

    dprintf(STRONG, ("tlocate: tnode 0x%p matches pattern\n", (void *) tnode));
    return tnode;
}

tnode_t **tmin(tnode_t ** tnode)
{
    /*
     * Return the slot of the leftmost (smallest) node below *tnode by
     * following left links. An empty slot is returned as is.
     */
    while (*tnode != NULL && (*tnode)->left != NULL)
	tnode = &(*tnode)->left;

    return tnode;
}

tnode_t **tmax(tnode_t ** tnode)
{
    /*
     * Return the slot of the rightmost (largest) node below *tnode by
     * following right links. An empty slot is returned as is.
     */
    while (*tnode != NULL && (*tnode)->right != NULL)
	tnode = &(*tnode)->right;

    return tnode;
}

size_t tdepth(tnode_t ** tnode)
{
    /*
     * Height of the subtree at *tnode, counted in edges: 0 for an empty
     * slot or a leaf, otherwise 1 + the height of the deeper child.
     */
    size_t ldepth, rdepth;

    if (!*tnode || (!(*tnode)->left && !(*tnode)->right))
	return 0;

    /*
     * Evaluate each subtree exactly once. The original passed the recursive
     * calls straight into max() from imath.h (not visible here); if that is
     * a function-like macro, the winning call is evaluated twice per level,
     * which is exponential on the skewed trees this treap can produce.
     */
    ldepth = tdepth(&(*tnode)->left);
    rdepth = tdepth(&(*tnode)->right);

    return (ldepth > rdepth ? ldepth : rdepth) + 1;
}

size_t thigh(tnode_t ** tnode)
{
    /*
     * Number of parent links between *tnode and the root (0 for the root
     * itself or for an empty slot).
     */
    size_t height = 0;
    tnode_t *node;

    for (node = *tnode; node != NULL && node->parent != NULL;
	 node = node->parent)
	height++;

    return height;
}

tnode_t **tsuccessor(tnode_t ** tnode)
{
    /*
     * In-order successor of *tnode (CLRS TREE-SUCCESSOR).
     *
     *  - An empty slot is returned unchanged.
     *  - With a right subtree: the slot of that subtree's minimum.
     *  - Otherwise climb while we are a right child. The result is then
     *    the address of some node's `parent` field, not a child slot of
     *    the tree. Read *result; do not assign through it.
     *  - No successor: a pointer to a NULL parent field, so callers can
     *    test *result == NULL.
     *
     * There is deliberately no special case for the root. The original
     * returned the root itself when it had no right subtree, making it
     * its own successor (an endless loop for iterating callers). The
     * CLRS algorithm reports "no successor" there, like every other
     * rightmost node.
     */
    tnode_t **parent;

    if (!*tnode)
	return tnode;

    if ((*tnode)->right)
	return tmin(&(*tnode)->right);

    parent = &(*tnode)->parent;
    while (*parent && *tnode == (*parent)->right) {
	tnode = parent;
	parent = &(*tnode)->parent;
    }
    return parent;
}

tnode_t **tpredecessor(tnode_t ** tnode)
{
    /*
     * In-order predecessor of *tnode (mirror of CLRS TREE-SUCCESSOR).
     *
     *  - An empty slot is returned unchanged.
     *  - With a left subtree: the slot of that subtree's maximum.
     *  - Otherwise climb while we are a left child. The result is then
     *    the address of some node's `parent` field, not a child slot of
     *    the tree. Read *result; do not assign through it.
     *  - No predecessor: a pointer to a NULL parent field, so callers can
     *    test *result == NULL.
     *
     * There is deliberately no special case for the root. The original
     * returned the root itself when it had no left subtree, making it
     * its own predecessor. Here it reports "no predecessor" like every
     * other leftmost node.
     */
    tnode_t **parent;

    if (!*tnode)
	return tnode;

    if ((*tnode)->left)
	return tmax(&(*tnode)->left);

    parent = &(*tnode)->parent;
    while (*parent && *tnode == (*parent)->left) {
	tnode = parent;
	parent = &(*tnode)->parent;
    }
    return parent;
}

size_t tsize_tnodes(tnode_t ** tnode)
{
    /* Number of nodes in the subtree at *tnode (0 for an empty slot). */
    size_t left_nodes, right_nodes;

    if (*tnode == NULL)
	return 0;

    left_nodes = tsize_tnodes(&(*tnode)->left);
    right_nodes = tsize_tnodes(&(*tnode)->right);

    return left_nodes + right_nodes + 1;
}

size_t tsize_total(tnode_t ** tnode)
{
    /*
     * Sum of the count fields of all nodes in the subtree at *tnode,
     * i.e. total insertions recorded there (0 for an empty slot).
     */
    size_t left_total, right_total;

    if (*tnode == NULL)
	return 0;

    left_total = tsize_total(&(*tnode)->left);
    right_total = tsize_total(&(*tnode)->right);

    return left_total + right_total + (*tnode)->count;
}

void tprint(tnode_t * tnode)
{
    /*
     * Print one node as "(pattern:count)", indented by five spaces per
     * parent link above it. tnode must be non-NULL and its pattern
     * NUL-terminated, since it is printed with %s.
     */
    size_t level;

    for (level = thigh(&tnode); level != 0; --level)
	printf("     ");

    printf("(%s:%u)\n", tnode->pattern, tnode->count);
}

void tfree(tnode_t * tnode)
{
    /*
     * Release one node together with the pattern buffer that tinsert()
     * malloc()s for every node (previously leaked). Children and links
     * are not touched. Pass this to ttraverse(..., postorder) to free a
     * whole subtree. A NULL node is ignored.
     */
    dprintf(STRONG, ("tfree: freeing tnode 0x%p\n", (void *) tnode));

    if (!tnode)
	return;

    free(tnode->pattern);
    free(tnode);
}

tnode_t **tdelete(tnode_t ** tnode, size_t patlen)
{
    /*
     * Unlink the node *tnode from treap[patlen - 1].
     *
     * The node is rotated down, always lifting the child with the larger
     * count so the heap order is kept, until it has at most one child. It
     * is then spliced out and replaced by that child.
     *
     * The removed node is NOT freed; its links are cleared so it no longer
     * points into the live tree. Ownership passes to the caller.
     *
     * Returns tnode, or NULL if tnode or *tnode is NULL.
     * NOTE: if tnode addresses a slot inside the tree (e.g. from
     * tlocate()), the rotations and the splice rewrite that slot, so
     * afterwards *tnode holds a different node (or NULL). Save *tnode
     * before the call if you need to free the removed node.
     */
    tnode_t *x, *child;

    if (!tnode || !*tnode)
	return NULL;

    dprintf(STRONG, ("deleting tnode 0x%p\n", (void *) tnode));

    x = *tnode;

    /* rotate_*() never assign through their argument, so x keeps naming
     * the node being deleted while it sinks. */
    while (x->left && x->right) {
	if (x->right->count > x->left->count)
	    rotate_left(&x, patlen);
	else
	    rotate_right(&x, patlen);
    }

    child = x->left ? x->left : x->right;

    if (child)
	child->parent = x->parent;

    if (!x->parent)
	treap[patlen - 1] = child;
    else if (x->parent->left == x)
	x->parent->left = child;
    else
	x->parent->right = child;

    /*
     * Detach: the stale links would otherwise still reach the live tree
     * (child in particular), so traversing or freeing the removed node
     * could corrupt or free nodes that are still in use.
     */
    x->parent = NULL;
    x->left = NULL;
    x->right = NULL;

    return tnode;
}

void tinsert(tnode_t ** tnode, byte_t * pattern, size_t patlen)
{
    /*
     * Insert pattern into the treap whose root slot is *tnode.
     *
     * The tree is ordered by memcmp() over patlen bytes. It is heap-ordered
     * by count: a parent's count is never smaller than its children's.
     *  - A new pattern becomes a leaf with count 1. This cannot violate the
     *    heap order, because counts start at 1 and are only incremented.
     *  - An existing pattern has its count incremented. It is then rotated
     *    upward while it outranks its parent. Rotations at the root update
     *    treap[patlen - 1].
     *
     * Each node owns a private copy of the pattern: exactly patlen bytes
     * are read from the caller and a NUL terminator is appended, so
     * tprint() can print it with %s. Allocation failure is fatal.
     */
    tnode_t *parent = NULL, *x;
    int cmp;

    /* Descend to the matching node or to the empty slot where it belongs. */
    while (*tnode) {
	cmp = memcmp((*tnode)->pattern, pattern, patlen);
	if (cmp == 0)
	    break;
	parent = *tnode;
	tnode = (cmp > 0) ? &(*tnode)->left : &(*tnode)->right;
    }

    if (!*tnode) {
	dprintf(STRONG,
		("tinsert: allocating memory for tnode 0x%p\n",
		 (void *) tnode));
	dprintf(STRONG,
		("tinsert: setting parent 0x%p for tnode 0x%p\n",
		 (void *) parent, (void *) tnode));

	x = calloc(1, sizeof *x);	/* zeroed: left/right start NULL */
	if (x)
	    x->pattern = malloc(sizeof(byte_t) * (patlen + 1));
	if (!x || !x->pattern) {
	    /* Continuing would dereference NULL or silently drop a count. */
	    fprintf(stderr, "tinsert: out of memory\n");
	    exit(EXIT_FAILURE);
	}

	/* Read only patlen bytes: the caller's buffer need not be longer. */
	memcpy(x->pattern, pattern, patlen);
	x->pattern[patlen] = '\0';
	x->count = 1;
	x->parent = parent;	/* NULL when the treap was empty */
	*tnode = x;
	return;
    }

    dprintf(STRONG,
	    ("tinsert: increment count for tnode 0x%p\n", (void *) tnode));
    x = *tnode;
    x->count++;

    /* Restore heap order: rotate x above its parent while it outranks it. */
    while (x->parent && x->count > x->parent->count) {
	dprintf(STRONG, ("tinsert: rotating tnode 0x%p\n", (void *) tnode));
	if (x == x->parent->left)
	    rotate_right(&x->parent, patlen);
	else
	    rotate_left(&x->parent, patlen);
    }
}

void ttraverse(tnode_t ** tnode, void (*func) (tnode_t *), int order)
{
    /*
     * Visit every node of the subtree at *tnode, calling func on each.
     * order is preorder, inorder or postorder; any other value visits
     * nothing. Use postorder when func frees the node (as free_treap()
     * does with tfree()), since children are then visited before it.
     */
    if (*tnode == NULL)
	return;

    if (order == preorder) {
	func(*tnode);
	ttraverse(&(*tnode)->left, func, order);
	ttraverse(&(*tnode)->right, func, order);
    } else if (order == inorder) {
	ttraverse(&(*tnode)->left, func, order);
	func(*tnode);
	ttraverse(&(*tnode)->right, func, order);
    } else if (order == postorder) {
	ttraverse(&(*tnode)->left, func, order);
	ttraverse(&(*tnode)->right, func, order);
	func(*tnode);
    }
}

#ifdef DEBUG
/*
 * Debug-only invariant checkers. They use #ifdef to match the guard around
 * heap_okay/tree_okay at the top of this file. The former "#if DEBUG"
 * fails to preprocess when DEBUG is defined with an empty value.
 */

/*
 * Clear heap_okay if tnode's count exceeds its parent's count (the heap
 * order on counts is violated). Suitable as a ttraverse() callback.
 */
void check_heap(tnode_t * tnode)
{
    if (tnode->parent && tnode->parent->count < tnode->count)
	heap_okay = FALSE;
}

/*
 * Clear tree_okay if a direct child of tnode lies on the wrong side of it
 * by memcmp() over patlen bytes. Only the immediate children are checked.
 * Because of the extra patlen argument, this cannot be passed directly to
 * ttraverse(), whose callback takes only a node.
 */
void check_tree(tnode_t * tnode, size_t patlen)
{
    if (tnode->left
	&& memcmp(tnode->left->pattern, tnode->pattern, patlen) > 0)
	tree_okay = FALSE;

    if (tnode->right
	&& memcmp(tnode->right->pattern, tnode->pattern, patlen) < 0)
	tree_okay = FALSE;
}
#endif

void rotate_left(tnode_t ** tnode, size_t patlen)
{
    /*
     * Left rotation around *tnode (CLRS LEFT-ROTATE):
     *
     *      node                  pivot
     *      /  \                  /   \
     *     a   pivot    ==>     node   c
     *         /  \             /  \
     *        b    c           a    b
     *
     * The right child (pivot) takes node's place and pivot's left subtree
     * (b) becomes node's right subtree. If node was the root,
     * treap[patlen - 1] is updated. The function does not assign through
     * tnode itself, so the address of a local copy may be passed. Does
     * nothing if *tnode is NULL or has no right child.
     */
    tnode_t *node = *tnode;
    tnode_t *pivot;
    tnode_t **slot;

    if (node == NULL || node->right == NULL)
	return;

    pivot = node->right;

    /* b moves from pivot's left to node's right. */
    node->right = pivot->left;
    if (node->right != NULL)
	node->right->parent = node;

    /* pivot is hung where node used to be. */
    pivot->parent = node->parent;
    if (node->parent == NULL) {
	dprintf(STRONG, ("rotate_left: tnode was root.\n"));
	slot = &treap[patlen - 1];
    } else {
	dprintf(STRONG, ("rotate_left: tnode wasn't root.\n"));
	slot = (node == node->parent->left) ? &node->parent->left
	    : &node->parent->right;
    }
    *slot = pivot;

    /* node becomes pivot's left child. */
    pivot->left = node;
    node->parent = pivot;
}


void rotate_right(tnode_t ** tnode, size_t patlen)
{
    /*
     * Right rotation around *tnode (mirror of rotate_left()):
     *
     *        node              pivot
     *        /  \              /   \
     *     pivot  c    ==>     a    node
     *     /  \                     /  \
     *    a    b                   b    c
     *
     * The left child (pivot) takes node's place and pivot's right subtree
     * (b) becomes node's left subtree. If node was the root,
     * treap[patlen - 1] is updated. The function does not assign through
     * tnode itself, so the address of a local copy may be passed. Does
     * nothing if *tnode is NULL or has no left child.
     */
    tnode_t *node = *tnode;
    tnode_t *pivot;
    tnode_t **slot;

    if (node == NULL || node->left == NULL)
	return;

    pivot = node->left;

    /* b moves from pivot's right to node's left. */
    node->left = pivot->right;
    if (node->left != NULL)
	node->left->parent = node;

    /* pivot is hung where node used to be. */
    pivot->parent = node->parent;
    if (node->parent == NULL) {
	dprintf(STRONG, ("rotate_right: tnode was root.\n"));
	slot = &treap[patlen - 1];
    } else {
	dprintf(STRONG, ("rotate_right: tnode wasn't root.\n"));
	slot = (node == node->parent->left) ? &node->parent->left
	    : &node->parent->right;
    }
    *slot = pivot;

    /* node becomes pivot's right child. */
    pivot->right = node;
    node->parent = pivot;
}
