#include "board.h"

#include <string.h>

/*
 * Zobrist hashing tables.
 *
 * zobrist[y][x][v] holds a pseudo-random 64-bit key for value v at cell (x, y).
 * board_compute_hash() XORs together the key of every cell's current value.
 * The third dimension is sized for the largest value count that
 * board_init_zobrist() fills (size * size on a MAX_SIZE board).
 * NOTE(review): assumes every cell value is < MAX_SIZE * MAX_SIZE — confirm.
 */
static uint64_t zobrist[MAX_SIZE][MAX_SIZE][MAX_SIZE * MAX_SIZE]; /* worst-case value count */
/* Board size the table was last filled for; starts at 0 (nothing built yet). */
static int zobrist_ready_for = 0;

/*
 * One SplitMix64 step: advance *state by the 64-bit golden-ratio increment,
 * then scramble the new state with two xor-shift/multiply rounds and a final
 * xor-shift to produce the output.
 */
static uint64_t splitmix64(uint64_t *state) {
    *state += 0x9e3779b97f4a7c15ULL;
    uint64_t out = *state;
    out ^= out >> 30;
    out *= 0xbf58476d1ce4e5b9ULL;
    out ^= out >> 27;
    out *= 0x94d049bb133111ebULL;
    out ^= out >> 31;
    return out;
}

/*
 * Fill the Zobrist key table for boards of side length `size`.
 *
 * Keys come from a SplitMix64 stream with a fixed seed, so the table is
 * deterministic: the same size always produces the same keys. A call with
 * the size that is already loaded returns immediately. A call with a
 * different size regenerates keys for that size. Hashes computed before the
 * switch should not be mixed with hashes computed after it.
 *
 * Sizes outside 1..(first dimension of `zobrist`) are ignored and leave the
 * table untouched. Filling for them would index past the end of the table.
 */
void board_init_zobrist(int size) {
    /* Derived from the array itself so the bound can never drift from its declaration. */
    const int max_side = (int)(sizeof zobrist / sizeof zobrist[0]);
    if (size <= 0 || size > max_side) return;
    if (zobrist_ready_for == size) return;

    uint64_t seed = 88172645463325252ULL;
    for (int y = 0; y < size; ++y) {
        for (int x = 0; x < size; ++x) {
            /* One key per possible value; size * size fits the third dimension
             * because size <= max_side. */
            for (int v = 0; v < size * size; ++v) {
                zobrist[y][x][v] = splitmix64(&seed);
            }
        }
    }
    zobrist_ready_for = size;
}

/*
 * Copy every byte of *src into *dst (cells plus any cached fields).
 * dst and src must not be the same or overlapping objects, as memcpy requires.
 */
void board_copy(Board *dst, const Board *src) {
    const size_t nbytes = sizeof *dst;
    memcpy(dst, src, nbytes);
}

/*
 * Count neighbouring cells that hold the same value.
 * Each equal horizontal neighbour pair and each equal vertical neighbour
 * pair contributes one to the total.
 */
int board_compute_pairs(const Board *b) {
    const int n = b->size;
    int count = 0;

    /* Horizontal pairs: (row, col) with (row, col + 1). */
    for (int row = 0; row < n; ++row) {
        for (int col = 0; col + 1 < n; ++col) {
            const uint8_t here = b->cells[row][col];
            if (b->cells[row][col + 1] == here) ++count;
        }
    }

    /* Vertical pairs: (row, col) with (row + 1, col). */
    for (int row = 0; row + 1 < n; ++row) {
        for (int col = 0; col < n; ++col) {
            const uint8_t here = b->cells[row][col];
            if (b->cells[row + 1][col] == here) ++count;
        }
    }

    return count;
}

/*
 * Half the number of cells on the board, rounded down.
 * NOTE(review): presumably the best achievable pair count when every value
 * appears exactly twice — confirm against the puzzle rules.
 */
int board_max_pairs(const Board *b) {
    const int cell_count = b->size * b->size;
    return cell_count / 2;
}

/*
 * Zobrist hash of the board: XOR of zobrist[y][x][value] over every cell.
 * The table must already be filled for b->size (see board_init_zobrist).
 * An empty board (size <= 0) hashes to 0.
 */
uint64_t board_compute_hash(const Board *b) {
    const int n = b->size;
    if (n <= 0) return 0;

    uint64_t key = 0;
    /* Walk the cells in row-major order with a single index; XOR is
     * order-independent, so the traversal order does not affect the result. */
    for (int idx = 0; idx < n * n; ++idx) {
        const int row = idx / n;
        const int col = idx % n;
        key ^= zobrist[row][col][b->cells[row][col]];
    }
    return key;
}

/*
 * Rotate the n x n sub-square whose top-left corner is (m->x, m->y) by 90
 * degrees clockwise. Then recompute the cached pair_count and hash.
 *
 * A move whose square does not fit is ignored and leaves the board
 * untouched. That covers a negative origin or size, a square extending past
 * b->size, or a square wider than MAX_SIZE. Applying it would read and write
 * outside the intended cells and overflow the scratch buffer.
 *
 * The hash is built from the Zobrist table, so board_init_zobrist(b->size)
 * is expected to have been called beforehand.
 */
void board_apply_move(Board *b, const Move *m) {
    const int n = m->n;
    const int x0 = m->x;
    const int y0 = m->y;

    /* Written as x0 > size - n (not x0 + n > size) to avoid signed overflow. */
    if (n < 0 || n > MAX_SIZE || x0 < 0 || y0 < 0 ||
        x0 > b->size - n || y0 > b->size - n) {
        return;
    }

    uint8_t temp[MAX_SIZE][MAX_SIZE];

    /* Snapshot the square first: the in-place rotation below would otherwise
     * read cells it has already overwritten. */
    for (int dy = 0; dy < n; ++dy) {
        memcpy(temp[dy], &b->cells[y0 + dy][x0], (size_t)n);
    }

    for (int dy = 0; dy < n; ++dy) {
        for (int dx = 0; dx < n; ++dx) {
            /* Clockwise rotation: destination row dy is source column dy,
             * read from the bottom row upward. */
            b->cells[y0 + dy][x0 + dx] = temp[n - 1 - dx][dy];
        }
    }

    /* Full recomputation keeps the cached fields consistent even if they
     * were stale before the move. */
    b->pair_count = board_compute_pairs(b);
    b->hash = board_compute_hash(b);
}
