#include <stdio.h>
/*
 * Matching, instantiation and reduction routines for the PSF term rewrite
 * system.
 */

#include "eqm.h"
#include "tiltype.h"
#include "eqm_local.h"
#include "psf_malloc.h"
#include "psf_exits.h"

#ifdef PRTILPSF
#include "prtilparts.h"
#else
#define print_ae_term(s)
#define print_term(s)
#endif
/*
 * Maximum nesting depth of check_condition; 0 selects the built-in
 * default of 1000.
 */
int conditionals_max_depth = 0;
/* Program name, used as the prefix of diagnostics; defined elsewhere. */
extern char *progname;

static term_t *term_reduce_top PROTO_ARGS((eqm_t *eqm, term_t *term));
static void term_reduce_children PROTO_ARGS((eqm_t *eqm, term_t *term));
static subst_t *add_to_sub PROTO_ARGS((struct indextype variable, term_t *value, subst_t *oldsub));

/*
 * Free lists used to recycle substitution elements and substitutions:
 * subst_free pushes onto them, the allocators in this file pop from them.
 */
static subst_elem_t *se_free_list = NULL;
static subst_t *sub_free_list = NULL;
#ifdef DEBUG_MEMUSE
/*
 * Allocation statistics: *_count = fresh allocations, *_countold = reuses
 * taken from a free list, *_dcount = releases back to a free list.
 */
static int sub_count = 0;
static int sub_countold = 0;
static int sub_dcount = 0;
static int se_count = 0;
static int se_countold = 0;
static int se_dcount = 0;
#endif /* DEBUG_MEMUSE */


#ifdef TERMTRACE
/* Non-zero enables tracing of conditions and rewrites to stdout. */
int term_trace = 0;
#endif
#if defined(TERMTRACE) || defined(VERBOSE)
/* Indentation level of the trace output. */
static int level = 0;
#endif

/*
 * Find the binding of `variable` in substitution `sub`.
 * Returns the bound term without touching its reference count, or NULL
 * when `sub` is NULL or holds no binding for the variable. The first
 * matching element in list order wins.
 */
static term_t
* lookup_in_sub(variable, sub)
    struct indextype variable;
    subst_t *sub;
{
    subst_elem_t *cur;

    if (sub == NULL)
	return NULL;

    for (cur = sub->sb_elems; cur != NULL; cur = cur->next_se) {
	if (var_equal(variable, cur->var))
	    break;
    }

    /* cur is NULL here exactly when no binding was found */
    return (cur != NULL) ? cur->value : NULL;
}

static int term_equal PROTO_ARGS((term_t *, term_t *));

/*
 * Bind `variable` to `value` in `oldsub`, extending it in place.
 *
 * If the variable is already bound nothing is added: `oldsub` is returned
 * when the existing binding is term_equal to `value`, NULL otherwise.
 * A new binding takes a reference on `value` (refcnt++) and is pushed on
 * the front of the element list. Its storage is recycled from se_free_list
 * when possible, otherwise obtained with PSF_MALLOC. `oldsub` is
 * dereferenced in that case, so it must not be NULL.
 */
static subst_t
* add_to_sub(variable, value, oldsub)
    struct indextype variable;
    term_t *value;
    subst_t *oldsub;
{
    term_t *bound;
    subst_elem_t *fresh;

    bound = lookup_in_sub(variable, oldsub);
    if (bound != NULL)
	return term_equal(value, bound) ? oldsub : NULL;

    if (se_free_list == NULL) {
	fresh = PSF_MALLOC(subst_elem_t);
#ifdef DEBUG_MEMUSE
	se_count ++;
#endif /* DEBUG_MEMUSE */
    } else {
	fresh = se_free_list;
	se_free_list = fresh->next_se;
#ifdef DEBUG_MEMUSE
	se_countold ++;
#endif /* DEBUG_MEMUSE */
    }

    fresh->var = variable;
    fresh->value = value;
    value->refcnt++;
    fresh->next_se = oldsub->sb_elems;
    oldsub->sb_elems = fresh;
    return oldsub;
}

/*
 * Match `pattern` (a term_t that may contain variables) against `term`,
 * recording variable bindings in `sub` in place.
 * Returns the extended substitution on success and NULL on mismatch.
 * Sons are matched from the last to the first, stopping at the first
 * failure.
 */
static subst_t
* term_match_internal(pattern, term, sub)
    term_t *pattern;
    term_t *term;
    subst_t *sub;
{
    int k;

    if (term_is_variable(pattern))
	return add_to_sub(pattern->ind, term, sub);

    if (op_equal(pattern, term)) {
	k = pattern->nsons;
	while (k > 0 && sub != NULL) {
	    k--;
	    sub = term_match_internal(pattern->sons[k], term->sons[k], sub);
	}
	return sub;
    }
    return NULL;
}

/*
 * Same as term_match_internal, but the pattern is an ae_term whose sons
 * are stored inline in ae_list (pattern->a of them). Bindings are added
 * to `sub` in place. Returns the extended substitution or NULL on mismatch.
 */
static subst_t
* ae_term_match_internal(pattern, term, sub)
    ae_term *pattern;
    term_t *term;
    subst_t *sub;
{
    int k;

    if (term_is_variable(pattern))
	return add_to_sub(pattern->ind, term, sub);

    if (op_equal(pattern, term)) {
	/* walk the sons right to left, stopping as soon as one fails */
	k = pattern->a;
	while (k > 0 && sub != NULL) {
	    k--;
	    sub = ae_term_match_internal(&pattern->ae_list[k],
					 term->sons[k], sub);
	}
	return sub;
    }
    return NULL;
}

/*
 * Match `pattern` against `term`, extending the caller's substitution `sub`
 * in place. `term` is converted to a temporary term_t for the match. That
 * temporary reference is dropped before returning; subterms bound in `sub`
 * keep the references add_to_sub took on them.
 * Returns the extended substitution, or NULL on mismatch. Unlike term_match,
 * `sub` is not released here on failure.
 */
subst_t
* term_match_sub(pattern, term, sub)
    ae_term *pattern, *term;
    subst_t *sub;
{
    term_t *conv;
    subst_t *outcome;

    conv = ae_term2term(term);
    outcome = ae_term_match_internal(pattern, conv, sub);

    if (--conv->refcnt == 0)
	term_free(conv);
    return outcome;
}

/*
 * Match the sons of `pattern` against the sons of `term`, starting from a
 * fresh, empty substitution.
 *
 * The top-level symbols are not compared here. The caller (term_reduce_top)
 * already selected the candidate equations through
 * eqm->fun_list[term->ind.key]. The sons are matched right to left.
 *
 * Returns the resulting substitution on success. When any son fails to
 * match, the fresh substitution is released with subst_free and NULL is
 * returned. A term without sons trivially matches and yields the empty
 * substitution. So does a pattern without sons, consistent with
 * ae_term_match_internal.
 */
static subst_t
* unroll_term_match(pattern, term)
    ae_term *pattern;
    term_t *term;
{
    subst_t *sub, *rsub;
    ae_term *pson;
    term_t **tson;
    int cnt;

    if (sub_free_list == NULL) {
	sub = PSF_MALLOC(subst_t);
#ifdef DEBUG_MEMUSE
	sub_count ++;
#endif /* DEBUG_MEMUSE */
    } else {
	sub = sub_free_list;
	sub_free_list = sub_free_list->next_sub;
#ifdef DEBUG_MEMUSE
	sub_countold ++;
#endif /* DEBUG_MEMUSE */
    }
    sub->f = 1;
    sub->sb_elems = NULL;

    if (term->nsons == 0)
	return sub;

    /*
     * rsub must have a defined value even when the pattern has no sons
     * and the matching loop below never runs.
     */
    rsub = sub;
    cnt = pattern->a;
    if (cnt > 0) {
	pson = pattern->ae_list + cnt;
	tson = term->sons + cnt;

	for (; rsub != NULL && cnt > 0; cnt--)
	    rsub = ae_term_match_internal(--pson, *--tson, rsub);
    }

    if (rsub == NULL)
	subst_free(sub);

#ifdef VERBOSE
    else {
	printf("%*s", level, "");
	print_ae_term(pattern);
	printf("<- match ->");
	print_term(term);
	putchar('\n');
    }
#endif

    return rsub;
}

subst_t
* term_match(pattern, term)
    ae_term *pattern, *term;
{
    subst_t *sub, *rsub;
    term_t *t;

    if (sub_free_list == NULL) {
	sub = PSF_MALLOC(subst_t);
#ifdef DEBUG_MEMUSE
	sub_count ++;
#endif /* DEBUG_MEMUSE */
    } else {
	sub = sub_free_list;
	sub_free_list = sub_free_list->next_sub;
#ifdef DEBUG_MEMUSE
	sub_countold ++;
#endif /* DEBUG_MEMUSE */
    }
    sub->f = 1;
    sub->sb_elems = NULL;

    rsub = ae_term_match_internal(pattern, t = ae_term2term(term), sub);

    if (--t->refcnt == 0)
	term_free(t);

    if (rsub == NULL)
	subst_free(sub);
    return rsub;
}

/*
 * Structural equality: the top symbols must satisfy op_equal and all sons
 * must be pairwise equal, compared from the last son to the first.
 * Returns 1 when equal, 0 otherwise. The son count of `term1` is used, so
 * op_equal is relied on to imply equal arity.
 */
static int term_equal(term1, term2)
    term_t *term1, *term2;
{
    int idx;

    if (op_equal(term1, term2)) {
	idx = term1->nsons;
	while (idx > 0) {
	    idx--;
	    if (!term_equal(term1->sons[idx], term2->sons[idx]))
		return 0;
	}
	return 1;
    }
    return 0;
}

/*
 * Substitution happens in place.
 */
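/*
 * term_reduce_top consumes its argument. It returns either the same term or
 * a new one; in the latter case the reference to the old term has already
 * been released. term_reduce_children inherits this behaviour for each son
 * (new term == old term || old term is disposed of).
 */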
/*
 * Per-arity free lists of term nodes, defined elsewhere.
 * term_free_list[n] (for 0 <= n < free_list_size) holds recycled nodes with
 * room for n sons, chained through free_next.
 */
extern int free_list_size;
extern term_t *term_free_list[];

#ifdef TRACEALLOC
/*
 * Count of fresh term-node allocations per arity; the last slot
 * (index free_list_size) aggregates all larger arities.
 */
extern int alloc_stats[];

#endif

/*
 * Build a new term_t for the non-variable `pattern` under substitution `sub`.
 *
 * Son handling:
 *  - A variable son is replaced by its binding, taking an extra reference.
 *  - An unbound variable son becomes a term_t conversion of the variable.
 *  - A compound son is built recursively by this same function.
 *
 * The node comes from term_free_list when a node of the right arity is
 * available, otherwise from psf_malloc. It starts with refcnt 1.
 *
 * When `eqm` is NULL, or the head symbol's table is ATM or PRO, the node is
 * returned as is. Otherwise it is passed to term_reduce_top, which may
 * return a different term.
 */
static term_t
* term_instantiate_children(pattern, sub, eqm)
    ae_term *pattern;
    subst_t *sub;
    eqm_t *eqm;
{
    int arity = pattern->a;
    int k;
    term_t *node;

    if (arity < free_list_size && term_free_list[arity] != NULL) {
	node = term_free_list[arity];
	term_free_list[arity] = node->free_next;
    } else {
	node = (term_t *) psf_malloc(sizeof(term_t) +
				     arity * sizeof(term_t *));
#ifdef TRACEALLOC
	alloc_stats[arity >= free_list_size ? free_list_size : arity]++;
#endif
    }
    node->ind = pattern->ind;
    node->nsons = arity;
    node->refcnt = 1;

    for (k = arity - 1; k >= 0; k--) {
	ae_term *son = &pattern->ae_list[k];

	if (term_is_variable(son)) {
	    term_t *val = lookup_in_sub(son->ind, sub);

	    if (val == NULL) {
		node->sons[k] = ae_term2term(son);
	    } else {
		val->refcnt++;
		node->sons[k] = val;
	    }
	} else {
	    node->sons[k] = term_instantiate_children(son, sub, eqm);
	}
    }

    if (eqm == NULL || node->ind.table == ATM || node->ind.table == PRO)
	return node;

    /* reduce current term at top level only */
    return term_reduce_top(eqm, node);
}

/*
 * Instantiate `pattern` under `sub` without reducing the result at top
 * level. Compound sons are still built, and possibly reduced, by
 * term_instantiate_children.
 *
 * A pattern that is just a variable is considered to be in normal form: no
 * further reductions are possible, so *finished is set to 1. The result is
 * then the variable's binding (with an extra reference), or a term_t
 * conversion of the variable when unbound.
 *
 * For any other pattern, *finished is set to 0 and a new node with
 * refcnt 1 is returned.
 */
static term_t
* term_instantiate_x(pattern, sub, eqm, finished)
    ae_term *pattern;
    subst_t *sub;
    eqm_t *eqm;
    int *finished;
{
    int arity;
    int k;
    term_t *node;

    if (term_is_variable(pattern)) {
	term_t *bound;

	*finished = 1;
	bound = lookup_in_sub(pattern->ind, sub);
	if (bound == NULL)
	    return ae_term2term(pattern);
	bound->refcnt++;
	return bound;
    }
    *finished = 0;

    arity = pattern->a;
    if (arity < free_list_size && term_free_list[arity] != NULL) {
	node = term_free_list[arity];
	term_free_list[arity] = node->free_next;
    } else {
	node = (term_t *) psf_malloc(sizeof(term_t) +
				     arity * sizeof(term_t *));
#ifdef TRACEALLOC
	alloc_stats[arity >= free_list_size ? free_list_size : arity]++;
#endif
    }
    node->ind = pattern->ind;
    node->nsons = arity;
    node->refcnt = 1;

    for (k = arity - 1; k >= 0; k--) {
	ae_term *son = &pattern->ae_list[k];

	if (term_is_variable(son)) {
	    term_t *val = lookup_in_sub(son->ind, sub);

	    if (val == NULL) {
		node->sons[k] = ae_term2term(son);
	    } else {
		val->refcnt++;
		node->sons[k] = val;
	    }
	} else {
	    node->sons[k] = term_instantiate_children(son, sub, eqm);
	}
    }
    return node;
}

/*
 * Instantiate `pattern` under `sub`, returning a term_t (not an ae_term).
 * A variable pattern yields its binding (with an extra reference), or a
 * term_t conversion of the variable when unbound. Any other pattern is
 * built by term_instantiate_children, which may reduce it when `eqm` is
 * non-NULL.
 */
static term_t
* my_term_instantiate(pattern, sub, eqm)
    ae_term *pattern;
    subst_t *sub;
    eqm_t *eqm;
{
    term_t *bound;

    if (term_is_variable(pattern)) {
	bound = lookup_in_sub(pattern->ind, sub);
	if (bound == NULL)
	    return ae_term2term(pattern);
	bound->refcnt++;
	return bound;
    }

    return term_instantiate_children(pattern, sub, eqm);
}

/*
 * Apply substitution `sub` to `pattern` and return the result as an ae_term.
 *
 *  - Bound variable: an ae_term converted from the binding.
 *  - Unbound variable: `pattern` itself is returned; no copy is made.
 *    NOTE(review): callers cannot tell a fresh result from their own
 *    pattern; confirm how callers release the returned value.
 *  - Otherwise: the term is built by term_instantiate_children (and reduced
 *    when `eqm` is non-NULL), then converted to an ae_term. The temporary
 *    term_t reference is dropped.
 */
ae_term
*term_instantiate(pattern, sub, eqm)
    ae_term *pattern;
    subst_t *sub;
    eqm_t *eqm;
{
    term_t *built;
    ae_term *out;

    if (term_is_variable(pattern)) {
	term_t *bound = lookup_in_sub(pattern->ind, sub);

	if (bound == NULL)
	    return pattern;
	return term2ae_term(bound);
    }

    built = term_instantiate_children(pattern, sub, eqm);
    out = term2ae_term(built);
    if (--built->refcnt == 0)
	term_free(built);
    return out;
}

/*
 * Release a substitution. Each binding drops the reference it holds on its
 * value, and values whose count reaches zero are freed with term_free. All
 * elements move to se_free_list and the substitution itself moves to
 * sub_free_list.
 *
 * NULL and already-released substitutions (f == 0) are ignored.
 * The element list is detached (sb_elems = NULL). A recycled substitution
 * therefore never carries a dangling pointer into se_free_list, even if
 * whoever recycles it does not reset the list.
 */
void subst_free(sub)
    subst_t *sub;
{
    subst_elem_t *elem;
    subst_elem_t *last = NULL;

    if (sub == NULL || !sub->f)
	return;

    for (elem = sub->sb_elems; elem != NULL; elem = elem->next_se) {
	if (--elem->value->refcnt == 0)
	    term_free(elem->value);
	last = elem;
#ifdef DEBUG_MEMUSE
	se_dcount ++;
#endif /* DEBUG_MEMUSE */
    }

    /* splice the whole element chain onto the element free list */
    if (last != NULL) {
	last->next_se = se_free_list;
	se_free_list = sub->sb_elems;
    }
    sub->sb_elems = NULL;

    sub->next_sub = sub_free_list;
    sub_free_list = sub;
#ifdef DEBUG_MEMUSE
    sub_dcount ++;
#endif /* DEBUG_MEMUSE */
    sub->f = 0;
}

/*
 * Return 1 if `term` contains a variable with no binding in `sub`,
 * 0 otherwise. Sons are inspected from the last to the first.
 */
static int has_uninstantiated(sub, term)
    subst_t *sub;
    ae_term *term;
{
    int k;

    if (term_is_variable(term))
	return lookup_in_sub(term->ind, sub) == NULL;

    k = term->a;
    while (k > 0) {
	k--;
	if (has_uninstantiated(sub, &term->ae_list[k]))
	    return 1;
    }
    return 0;
}
/*
 * Current nesting depth of check_condition. Checking a condition
 * instantiates and reduces terms, which may check further conditions.
 */
static int condition_depth = 0;

/*
 * Check one guard `eq` of a conditional equation under substitution `sub`.
 *
 * Both sides are instantiated (and reduced with `eqm`). Then:
 *  - only the left side has unbound variables: it is matched against the
 *    right side, which may extend `sub` with new bindings;
 *  - only the right side has unbound variables: the roles are swapped;
 *  - neither side has unbound variables: the sides must be term_equal;
 *  - both sides have unbound variables: the condition is invalid and is
 *    reported through EQMerror.
 *
 * Returns the (possibly extended) substitution on success. On failure `sub`
 * is released with subst_free and NULL is returned, so callers must not
 * use `sub` after a NULL result. Both instantiated sides are always
 * released before returning.
 *
 * If the nesting depth exceeds its limit, the program exits with
 * EXIT_HELP; see the note below.
 */
static subst_t
* check_condition(eqm, sub, eq)
    eqm_t *eqm;
    subst_t *sub;
    struct equation *eq;
{
    term_t *left, *right;
    int leftu, rightu;
    subst_t *rsub;

    condition_depth++;
    if (condition_depth >
	    (conditionals_max_depth ? conditionals_max_depth : 1000)) {
	fflush(stdout);
	fprintf(stderr,
		"%s: recursion depth for conditions exceeded.\n",
		progname);
	/* progname first, matching the "%s: ..." prefix of the line above */
	fprintf(stderr, "%s: %s %s\n", progname,
		"maybe there is a circular dependency",
		"in the conditional equations.");
	/*
	 * conditionals_max_depth is the maximal recursive descent in
	 * check_condition. If it is set to 0 it defaults to 1000. If set to
	 * non zero it will be that value. At initialisation it is set to zero.
	 * Mark 1-9-93
	 */
	exit(EXIT_HELP);
    }
    leftu = has_uninstantiated(sub, &eq->aet1);
    rightu = has_uninstantiated(sub, &eq->aet2);

#ifdef TERMTRACE
    if (term_trace) {
	/* unreduced instances (eqm == NULL), built for display only */
	term_t *tleft = my_term_instantiate(&eq->aet1, sub, (eqm_t *) 0);
	term_t *tright = my_term_instantiate(&eq->aet2, sub, (eqm_t *) 0);

	printf("%*s", 3 * level, "");
	printf("[condition: ");
	print_term(tleft);
	printf(" = ");
	print_term(tright);
	printf("]\n");

	/* drop the display copies so that tracing does not leak terms */
	if (--tleft->refcnt == 0)
	    term_free(tleft);
	if (--tright->refcnt == 0)
	    term_free(tright);
    }
#endif

    left = my_term_instantiate(&eq->aet1, sub, eqm);
    right = my_term_instantiate(&eq->aet2, sub, eqm);

    if (leftu && rightu) {
	/* neither side can act as the fixed term to match against */
	EQMerror("Invalid condition");
	if (--left->refcnt == 0)
	    term_free(left);
	if (--right->refcnt == 0)
	    term_free(right);
	subst_free(sub);
	condition_depth--;
	return NULL;
    }
    if (leftu)
	rsub = term_match_internal(left, right, sub);
    else if (rightu)
	rsub = term_match_internal(right, left, sub);
    else
	rsub = term_equal(left, right) ? sub : NULL;

#ifdef TERMTRACE
    if (term_trace)
	printf("%*s[%s]\n", 3 * level, "",
	       rsub == NULL ? "failure" : "success");
#endif

    if (--left->refcnt == 0)
	term_free(left);
    if (--right->refcnt == 0)
	term_free(right);
    if (rsub == NULL)
	subst_free(sub);
    condition_depth--;
    return rsub;
}

/*
 * Rewrite `term` at its top level until no equation of `eqm` applies.
 *
 * Candidate equations are found through eqm->fun_list[term->ind.key]. When
 * the term has sons, the list is selected by FUN1HASH of the first son's
 * index; otherwise fun->var_eq is used.
 *
 * For each candidate:
 *  - the left-hand side's sons are matched against the term
 *    (unroll_term_match);
 *  - every guard is checked (check_condition);
 *  - on success, the right-hand side is instantiated with
 *    term_instantiate_x, the old term and the substitution are released,
 *    and the search restarts for the new term.
 *
 * Ownership: the caller's reference to `term` is consumed. The result is
 * either `term` itself (no rule applied) or a new term. Variables, and
 * terms whose key has no equation list, are returned unchanged.
 */
static term_t
* term_reduce_top(eqm, term)
    eqm_t *eqm;
    term_t *term;
{
    subst_t *sub;
    eq_list_t *eq;
    eq_fun_list_t *fun;
    int i;
    term_t *result;

    if (term_is_variable(term))
	return term;

    fun = eqm->fun_list[term->ind.key];

    /* no equations registered for this key */
    if (fun == NULL)
	return term;
    /* pick the candidate list: by the first son's index, or var_eq if no sons */
    if (term->nsons > 0)
	eq = fun->eq_list[FUN1HASH(term->sons[0]->ind)];
    else
	eq = fun->var_eq;

#ifdef TERMTRACE
    if (term_trace)
	level++;
#endif

    while (eq != NULL) {
	int finished;

#ifdef EQM_STATS
	eq->ntries++;
#endif

	/* match the left-hand side's sons; NULL means this equation does not apply */
	if ((sub = unroll_term_match(&eq->this_eq.aet1, term)) != NULL) {
	    /* success, now check conditions */
	    if (eq->this_eq.a != 0) {
		for (i = 0; i < eq->this_eq.a; i++) {
		    if (!(sub = check_condition(eqm, sub,
						eq->this_eq.guard + i)))
			break;
		}
		/* a failing check_condition has already released sub */
		if (!sub) {
		    eq = eq->next_eq;
		    continue;
		}
	    }

#ifdef EQM_STATS
	    eq->nsuccess++;
#endif

#ifdef TERMTRACE
	    if (term_trace) {
		printf("%*s", 3 * level, "");
		print_term(term);
		putchar('\n');
	    }
#endif

	    /* substitute (instantiate & reduce) */
	    result = term_instantiate_x(&eq->this_eq.aet2, sub, eqm, &finished);

#ifdef TERMTRACE
	    if (term_trace) {
		printf("%*s", 3 * level, "");
		printf("-> ");
		print_term(result);
		putchar('\n');
	    }
#endif

	    /* the rewritten term and the bindings are no longer needed */
	    if (--term->refcnt == 0)
		term_free(term);
	    subst_free(sub);
	    term = result;
	    /* variable right-hand side: result is not rewritten further */
	    if (finished) {

#ifdef TERMTRACE
		if (term_trace)
		    level--;
#endif

		return term;
	    }
	    /* restart the equation search for the new term */
	    fun = eqm->fun_list[term->ind.key];

	    if (fun == NULL) {

#ifdef TERMTRACE
		if (term_trace)
		    level--;
#endif

		return term;
	    }
	    if (term->nsons > 0)
		eq = fun->eq_list[FUN1HASH(term->sons[0]->ind)];
	    else
		eq = fun->var_eq;
	} else
	    eq = eq->next_eq;
    }

#ifdef TERMTRACE
    if (term_trace)
	level--;
#endif

    return term;
}

/*
 * Reduce the sons of `term` innermost first, from the last son to the
 * first. Each son's own sons are reduced recursively; then the son is
 * rewritten at top level by term_reduce_top and stored back in its slot.
 * `term` itself is not rewritten at top level.
 */
static void term_reduce_children(eqm, term)
    eqm_t *eqm;
    term_t *term;
{
    term_t **slot = term->sons + term->nsons;

    while (slot != term->sons) {
	--slot;
	term_reduce_children(eqm, *slot);
	*slot = term_reduce_top(eqm, *slot);
    }
}

/*
 * Reduce `term` with the equations of `eqm`. The steps are:
 *  1. convert it to a term_t;
 *  2. reduce all subterms innermost first, then the whole term at top level;
 *  3. convert the result back to an ae_term.
 * The temporary term_t is released before returning.
 */
ae_term
*term_reduce(eqm, term)
    eqm_t *eqm;
    ae_term *term;
{
    term_t *work;
    ae_term *reduced;

#ifdef TERMTRACE
    if (term_trace)
	putchar('\n');
#endif

    work = ae_term2term(term);
    term_reduce_children(eqm, work);
    work = term_reduce_top(eqm, work);

    reduced = term2ae_term(work);
    if (--work->refcnt == 0)
	term_free(work);
    return reduced;
}

#ifdef SUMMERGE
/* the following is added to deal with the sum and merge in the simulator */

/* Return an empty, live substitution, recycled from sub_free_list if possible. */
static subst_t *summerge_new_sub()
{
    subst_t *fresh;

    if (sub_free_list == NULL) {
	fresh = PSF_MALLOC(subst_t);
#ifdef DEBUG_MEMUSE
	sub_count ++;
#endif /* DEBUG_MEMUSE */
    } else {
	fresh = sub_free_list;
	sub_free_list = fresh->next_sub;
#ifdef DEBUG_MEMUSE
	sub_countold ++;
#endif /* DEBUG_MEMUSE */
    }
    fresh->f = 1;
    fresh->sb_elems = NULL;
    return fresh;
}

/* Return an uninitialised element, recycled from se_free_list if possible. */
static subst_elem_t *summerge_new_elem()
{
    subst_elem_t *fresh;

    if (se_free_list != NULL) {
	fresh = se_free_list;
	se_free_list = fresh->next_se;
#ifdef DEBUG_MEMUSE
	se_countold ++;
#endif /* DEBUG_MEMUSE */
    } else {
	fresh = PSF_MALLOC(subst_elem_t);
#ifdef DEBUG_MEMUSE
	se_count ++;
#endif /* DEBUG_MEMUSE */
    }
    return fresh;
}

/*
 * Append to the empty substitution `dst`, in order, a copy of every binding
 * of `src`. Each value is duplicated with term_copy. A NULL `src` copies
 * nothing.
 */
static void summerge_copy_elems(dst, src)
    subst_t *dst;
    subst_t *src;
{
    subst_elem_t *elem, *copy;
    subst_elem_t **tail = &dst->sb_elems;

    if (src == NULL)
	return;

    for (elem = src->sb_elems; elem != NULL; elem = elem->next_se) {
	copy = summerge_new_elem();
	*tail = copy;
	copy->next_se = NULL;
	tail = &copy->next_se;
	copy->value = term_copy(elem->value);
	copy->var = elem->var;
    }
}

/*
 * Return a new substitution holding copies of all bindings of `oldsub`
 * (which may be NULL), plus a binding of *ind to `aet`, reduced at top level
 * with `eqm`. The new binding is put first, so lookup_in_sub finds it before
 * any older binding of the same variable.
 */
subst_t *copy_add_to_sub(oldsub, ind, aet, eqm)
    subst_t *oldsub;
    struct indextype *ind;
    ae_term *aet;
    eqm_t *eqm;
{
    subst_t *sub;
    subst_elem_t *binding;

    sub = summerge_new_sub();
    summerge_copy_elems(sub, oldsub);

    binding = summerge_new_elem();
    binding->next_se = sub->sb_elems;
    binding->value = term_reduce_top(eqm, ae_term2term(aet));
    binding->var = *ind;
    sub->sb_elems = binding;
    return (sub);
}

/*
 * Return a new substitution holding copies of all bindings of `oldsub`
 * (which may be NULL), preceded by the bindings of `newsub`. The elements
 * of `newsub` are moved, not copied. `newsub` is left empty so it can still
 * be released with subst_free.
 */
subst_t *copy_add_sub_to_sub(oldsub, newsub)
    subst_t *oldsub;
    subst_t *newsub;
{
    subst_t *sub;
    subst_elem_t *last;

    sub = summerge_new_sub();
    summerge_copy_elems(sub, oldsub);

    if (newsub != NULL && newsub->sb_elems != NULL) {
	for (last = newsub->sb_elems; last->next_se != NULL;
	     last = last->next_se)
	    ;
	last->next_se = sub->sb_elems;
	sub->sb_elems = newsub->sb_elems;
	newsub->sb_elems = NULL; /* empty it, so it can be freed */
    }
    return (sub);
}
#endif

#ifdef DEBUG_MEMUSE
/*
 * Print substitution and element allocation statistics to stderr.
 * The columns are:
 *  - fresh allocations;
 *  - reuses from the free list;
 *  - releases;
 *  - the number still live (allocations + reuses - releases).
 */
void print_sub_count()
{
    int sub_live = sub_count + sub_countold - sub_dcount;
    int se_live = se_count + se_countold - se_dcount;

    fprintf(stderr, "sub     : %11d %11d %11d %11d\n",
	    sub_count, sub_countold, sub_dcount, sub_live);
    fprintf(stderr, "elem    : %11d %11d %11d %11d\n",
	    se_count, se_countold, se_dcount, se_live);
}
#endif /* DEBUG_MEMUSE */

subst_t *get_sub()
{
    subst_t *sub;

    if (sub_free_list == NULL) {
	sub = PSF_MALLOC(subst_t);
#ifdef DEBUG_MEMUSE
	sub_count ++;
#endif /* DEBUG_MEMUSE */
    } else {
	sub = sub_free_list;
	sub_free_list = sub_free_list->next_sub;
#ifdef DEBUG_MEMUSE
	sub_countold ++;
#endif /* DEBUG_MEMUSE */
    }
    sub->f = 1;
    return(sub);
}
