// Copyright (c) 2023, Marvin Borner <dev@marvinborner.de>
// SPDX-License-Identifier: MIT

#include <stdlib.h>
#include <stdio.h>
#include <assert.h>

#include <log.h>
#include <term.h>
#include <map.h>

// Allocate a term node and fill in its header fields.
//
// type, hash and depth are stored verbatim; the reference count starts at 1.
// The type-specific payload `u` is NOT initialized here.
// An allocation failure is reported through fatal().
struct term *term_new(term_type_t type, hash_t hash, size_t depth)
{
	struct term *node = malloc(sizeof(*node));
	if (node == NULL)
		fatal("out of memory!\n");

	node->refs = 1;
	node->type = type;
	node->depth = depth;
	node->hash = hash;
	return node;
}

// Recursively write a textual form of `term` to stderr.
//
//   ABS -> "[" body "]"
//   APP -> "(" lhs " " rhs ")"
//   VAR -> the variable index
//
// Any other type tag is reported through fatal().
void term_print(struct term *term)
{
	if (term->type == ABS) {
		fputs("[", stderr);
		term_print(term->u.abs.term);
		fputs("]", stderr);
	} else if (term->type == APP) {
		fputs("(", stderr);
		term_print(term->u.app.lhs);
		fputs(" ", stderr);
		term_print(term->u.app.rhs);
		fputs(")", stderr);
	} else if (term->type == VAR) {
		fprintf(stderr, "%ld", term->u.var.index);
	} else {
		fatal("invalid type %d\n", term->type);
	}
}

// Produce the abstraction that results from giving `head` the body `term`.
//
// If head's current body already has the same hash as `term`, `head` is
// returned untouched. Otherwise a new hash is computed by calling hash() on
// head's type tag with term's hash as the third argument, and:
//   - if map_get() yields a term for that hash, that term is referred at
//     head's depth and returned;
//   - otherwise a new ABS node is allocated with `term` as its body, and
//     `term` is referred at head's depth + 1.
// In both of these cases `head` is passed to term_deref() before returning.
struct term *term_rehash_abs(struct term *head, struct term *term)
{
	// body hash unchanged: nothing to do
	if (head->u.abs.term->hash == term->hash)
		return head;

	hash_t rehashed =
		hash((uint8_t *)&head->type, sizeof(head->type), term->hash);

	assert(rehashed != head->hash);

	struct term *existing = map_get(rehashed);
	if (existing) {
		// reuse the term already known under this hash
		term_refer(existing, head->depth);
		term_deref(head);
		return existing;
	}

	// no such term yet: build it around the new body
	struct term *fresh = term_new(ABS, rehashed, head->depth);
	fresh->u.abs.term = term;
	term_refer(term, head->depth + 1);
	term_deref(head);
	return fresh;
}

// Produce the application that results from giving `head` the children
// `lhs` and `rhs`.
//
// If both of head's current children have the same hashes as `lhs` and
// `rhs`, `head` is returned untouched. Otherwise a new hash is computed in
// two hash() steps (head's type tag with lhs's hash, then that result with
// rhs's hash), and:
//   - if map_get() yields a term for that hash, that term is referred at
//     head's depth and returned;
//   - otherwise a new APP node is allocated with `lhs`/`rhs` as children;
//     a child is referred (at head's depth + 1) only when its hash differs
//     from the corresponding child of `head`.
// In both of these cases `head` is passed to term_deref() before returning.
struct term *term_rehash_app(struct term *head, struct term *lhs,
			     struct term *rhs)
{
	// neither child changed: nothing to do
	if (head->u.app.lhs->hash == lhs->hash &&
	    head->u.app.rhs->hash == rhs->hash)
		return head;

	hash_t rehashed =
		hash((uint8_t *)&head->type, sizeof(head->type), lhs->hash);
	rehashed = hash((uint8_t *)&rehashed, sizeof(rehashed), rhs->hash);

	assert(rehashed != head->hash);

	struct term *existing = map_get(rehashed);
	if (existing) {
		// reuse the term already known under this hash
		term_refer(existing, head->depth);
		term_deref(head);
		return existing;
	}

	// no such term yet: build it around the given children
	struct term *fresh = term_new(APP, rehashed, head->depth);
	fresh->u.app.lhs = lhs;
	fresh->u.app.rhs = rhs;
	if (head->u.app.lhs->hash != lhs->hash)
		term_refer(lhs, head->depth + 1);
	if (head->u.app.rhs->hash != rhs->hash)
		term_refer(rhs, head->depth + 1);
	term_deref(head);
	return fresh;
}

// Add one reference to `term` itself (children are not visited) and keep
// the smallest depth seen so far: the stored depth is only overwritten when
// the given `depth` is strictly lower.
void term_refer_head(struct term *term, size_t depth)
{
	term->refs += 1;

	// lower depths are more important
	if (term->depth > depth)
		term->depth = depth;
}

// Add one reference to `term` and to every term below it.
//
// Children are visited first, each at `depth + 1`; afterwards the node
// itself is updated through term_refer_head(). Nodes that are neither ABS
// nor APP have no children to visit.
void term_refer(struct term *term, size_t depth)
{
	size_t below = depth + 1;

	switch (term->type) {
	case ABS:
		term_refer(term->u.abs.term, below);
		break;
	case APP:
		term_refer(term->u.app.lhs, below);
		term_refer(term->u.app.rhs, below);
		break;
	default:
		break;
	}

	term_refer_head(term, depth);
}

// Drop one reference from `term` and from every term below it.
//
// Children are visited first (regardless of whether this node ends up being
// freed); then this node's count is decremented and the node is released
// with free() once the count reaches zero.
void term_deref(struct term *term)
{
	switch (term->type) {
	case ABS:
		term_deref(term->u.abs.term);
		break;
	case APP:
		term_deref(term->u.app.lhs);
		term_deref(term->u.app.rhs);
		break;
	default:
		break;
	}

	// TODO: remove from hashmap?
	term->refs -= 1;
	if (term->refs == 0)
		free(term);
}