#include "hash.h"

#include <stdint.h>
#include <stdlib.h>
#include <string.h>

/*
 * Allocate a zero-filled, open-addressing table with at least `capacity`
 * slots. The slot count is rounded up to a power of two so that
 * `key & mask` can be used as the slot index.
 *
 * Any previous allocation held by `ht` is NOT freed.
 *
 * Returns 0 on success, -1 on failure:
 *   - `capacity` is 0;
 *   - `capacity` exceeds 2^63 (the largest power of two in uint64_t);
 *   - the rounded size cannot be represented as a size_t allocation;
 *   - calloc fails.
 * Argument-validation failures leave `ht` untouched. On allocation failure
 * `ht->entries` is NULL and `ht->mask` is 0.
 */
int ht_init(HashTable *ht, uint64_t capacity) {
    /* Above 2^63 the doubling loop below would wrap cap to 0 and never end. */
    if (capacity == 0 || capacity > (UINT64_C(1) << 63)) {
        return -1;
    }

    /* Round up to the next power of two. */
    uint64_t cap = 1;
    while (cap < capacity) {
        cap <<= 1;
    }

    /* On 32-bit targets (size_t)cap could silently truncate, leaving `mask`
     * larger than the real buffer; calloc's own overflow check cannot catch
     * that, so reject it here. */
    if (cap > SIZE_MAX / sizeof *ht->entries) {
        return -1;
    }

    ht->entries = calloc((size_t)cap, sizeof *ht->entries);
    if (!ht->entries) {
        ht->mask = 0; /* keep entries/mask consistent on failure */
        return -1;
    }
    ht->mask = cap - 1;
    return 0;
}

/*
 * Release the slot array owned by `ht` and reset the table to its empty
 * state (NULL entries, zero mask). The HashTable object itself is not freed.
 */
void ht_free(HashTable *ht) {
    HashEntry *old_entries = ht->entries;

    ht->mask = 0;
    ht->entries = NULL;
    free(old_entries); /* free(NULL) is a no-op, so no guard needed */
}

/*
 * Look up `key` in the table using linear probing and decide whether the
 * caller should prune.
 *
 * Returns 1 if an entry for `key` already exists with `pairs` >= the given
 * pairs and `depth` <= the given depth. The stored entry is left unchanged.
 *
 * Otherwise returns 0 and records (pairs, depth) for `key`. It either fills
 * the first empty slot on the probe path or overwrites the existing entry
 * for `key`. If the table is full and `key` is absent, or if the table has
 * no storage (entries == NULL), nothing is recorded and 0 is returned.
 *
 * A stored key of 0 marks an empty slot, so key 0 is stored as 1. The same
 * normalised key is used for hashing and comparison, so repeated calls with
 * key 0 find their own entry. As a consequence, keys 0 and 1 share one entry.
 */
int ht_should_prune(HashTable *ht, uint64_t key, uint16_t pairs, uint16_t depth) {
    if (!ht->entries) {
        return 0; /* e.g. after ht_free(): mask is 0 and entries is NULL */
    }

    /* Normalise once. The original stored `key ? key : 1` but probed and
     * compared with the raw key, so key 0 never matched its own entry and
     * consumed a new slot on every call. */
    const uint64_t stored_key = key ? key : 1;
    const uint64_t start = stored_key & ht->mask;
    const uint64_t slots = ht->mask + 1;

    for (uint64_t i = 0; i < slots; ++i) {
        HashEntry *e = &ht->entries[(start + i) & ht->mask];

        if (e->key == 0) {
            /* Empty slot: key not present, claim it. */
            e->key = stored_key;
            e->pairs = pairs;
            e->depth = depth;
            return 0;
        }
        if (e->key == stored_key) {
            if (e->pairs >= pairs && e->depth <= depth) {
                return 1; /* already have equal/better */
            }
            e->pairs = pairs;
            e->depth = depth;
            return 0;
        }
    }
    return 0; /* table full and key absent: nothing recorded */
}
