#include "search.h"

#include <stdint.h>
#include <stdlib.h>
#include <string.h>

/* Maximum number of children a single parent may contribute to the next
 * layer of the beam (enforced by consider_candidate). */
#define MAX_CHILDREN_PER_NODE 64

/* One beam entry: a board state together with the move sequence that
 * produced it from the start board. */
typedef struct {
    Board board;    /* board after applying path[0 .. path_len-1] to the start board */
    Move path[256]; /* fixed capacity of 256 moves: path_len must never exceed 256 */
    int path_len;   /* number of valid entries in path */
    int score;      /* ranking key computed by move_score(); higher is better */
} Node;

/* Heuristic rank of a beam node: each pair is worth 1000 points, and each
 * move already spent costs 5, so shorter paths win among equal pair counts. */
static int move_score(const Board *b, int path_len) {
    const int pair_term = b->pair_count * 1000;
    const int depth_penalty = path_len * 5;
    return pair_term - depth_penalty;
}

/* qsort comparator ordering Nodes best-first: by descending score, then by
 * descending board pair count.
 *
 * Uses (x < y) - (x > y) instead of subtraction so the result can never
 * overflow a signed int, whatever the magnitude of the keys. */
static int cmp_node_desc(const void *a, const void *b) {
    const Node *na = a;
    const Node *nb = b;
    if (na->score != nb->score) {
        return (na->score < nb->score) - (na->score > nb->score);
    }
    return (na->board.pair_count < nb->board.pair_count) -
           (na->board.pair_count > nb->board.pair_count);
}

/*
 * Enumerate every square sub-block move on a size x size board.
 *
 * A move (x, y, n) is generated for every n in 2..min(size, n_limit) and
 * every origin with 0 <= x, y <= size - n, ordered by n, then y, then x.
 *
 * On success stores a heap array in out->moves (release it with
 * free_move_list) and its length in out->count, then returns 0. When no move
 * can be generated (size < 2 or n_limit < 2) the list is empty:
 * out->moves == NULL and out->count == 0.
 *
 * Returns -1, leaving *out untouched, on allocation failure or when size
 * exceeds UINT8_MAX (the fields are filled through uint8_t casts, so larger
 * boards would produce silently truncated coordinates).
 */
int build_move_list(int size, int n_limit, MoveList *out) {
    const int n_max = n_limit < size ? n_limit : size;

    if (n_max >= 2 && size > UINT8_MAX) {
        return -1;
    }

    /* Exact count: block size n has (size - n + 1)^2 origins. Computing it
     * exactly avoids the int overflow of size*size*(n_limit-1) for large
     * n_limit, and a zero-byte (or negative-size) malloc when nothing fits. */
    size_t total = 0;
    for (int n = 2; n <= n_max; ++n) {
        const size_t side = (size_t)(size - n + 1);
        total += side * side;
    }

    if (total == 0) {
        out->moves = NULL;
        out->count = 0;
        return 0;
    }

    Move *moves = malloc(total * sizeof *moves);
    if (!moves) {
        return -1;
    }

    int count = 0;
    for (int n = 2; n <= n_max; ++n) {
        const int limit = size - n;
        for (int y = 0; y <= limit; ++y) {
            for (int x = 0; x <= limit; ++x) {
                moves[count++] = (Move){(uint8_t)x, (uint8_t)y, (uint8_t)n};
            }
        }
    }

    out->moves = moves;
    out->count = count;
    return 0;
}

/* Release the move array owned by ml and reset it to an empty list
 * (moves == NULL, count == 0), so a repeated call is harmless. */
void free_move_list(MoveList *ml) {
    void *owned = ml->moves;
    ml->moves = NULL;
    ml->count = 0;
    free(owned);
}

/* Offer move m (whose resulting board is b_after) to a bounded candidate set.
 *
 * While fewer than MAX_CHILDREN_PER_NODE candidates are stored, m is simply
 * appended. Once the set is full, m replaces the weakest stored candidate
 * (lowest pair count; the earliest such slot on ties), but only if m's pair
 * count is strictly higher. */
static void consider_candidate(const Move *m, const Board *b_after, Move *c_moves, int *c_pairs, int *c_count) {
    const int stored = *c_count;

    if (stored >= MAX_CHILDREN_PER_NODE) {
        int min_slot = 0;
        for (int k = 1; k < stored; k++) {
            if (c_pairs[k] < c_pairs[min_slot]) {
                min_slot = k;
            }
        }
        if (b_after->pair_count > c_pairs[min_slot]) {
            c_moves[min_slot] = *m;
            c_pairs[min_slot] = b_after->pair_count;
        }
        return;
    }

    c_moves[stored] = *m;
    c_pairs[stored] = b_after->pair_count;
    *c_count = stored + 1;
}

/*
 * Layered beam search starting from `start`.
 *
 * Each layer expands every node of the beam with every move from
 * build_move_list(start->size, n_limit). Moves whose resulting board is
 * rejected by ht_should_prune() are skipped. At most MAX_CHILDREN_PER_NODE
 * children per parent are kept, selected by resulting pair_count. The
 * children of the whole layer are then ranked with cmp_node_desc() and cut
 * to beam_width.
 *
 * beam_width and max_depth below 1, and n_limit below 2, are raised to those
 * minimums. max_depth is also capped at the capacity of Node.path, so a path
 * can never be written past the end of that fixed array.
 *
 * The search stops early when:
 *   - a layer produces no children;
 *   - board_max_pairs(start) is reached; or
 *   - stagnation_limit consecutive layers bring no new best
 *     (stagnation_limit <= 0 disables this check).
 *
 * Returns 0 on success and writes the best move sequence found to *out. If no
 * board ever beat start->pair_count, only out->len (= 0) and out->pair_count
 * (= start->pair_count) are written. Returns -1 if the move list, the hash
 * table, or the node buffers cannot be set up; *out is then left untouched.
 */
int solve_beam(const Board *start,
               int beam_width,
               int max_depth,
               int n_limit,
               int stagnation_limit,
               uint64_t hash_capacity,
               Solution *out) {
    /* Node.path is a fixed-size array; each layer adds one move to the path,
     * so searching deeper than its length would overflow it. (sizeof does not
     * evaluate its operand, so the null pointer is never dereferenced.) */
    const int path_capacity = (int)(sizeof(((Node *)0)->path) / sizeof(((Node *)0)->path[0]));

    if (beam_width < 1) beam_width = 1;
    if (max_depth < 1) max_depth = 1;
    if (max_depth > path_capacity) max_depth = path_capacity;
    if (n_limit < 2) n_limit = 2;

    MoveList ml = {0};
    if (build_move_list(start->size, n_limit, &ml) != 0) {
        return -1;
    }

    HashTable ht = {0};
    if (ht_init(&ht, hash_capacity) != 0) {
        free_move_list(&ml);
        return -1;
    }

    /* Every parent contributes at most MAX_CHILDREN_PER_NODE children per layer. */
    const size_t next_cap = (size_t)beam_width * MAX_CHILDREN_PER_NODE;
    Node *current = malloc((size_t)beam_width * sizeof *current);
    Node *next = malloc(next_cap * sizeof *next);
    if (!current || !next) {
        free(current);
        free(next);
        free_move_list(&ml);
        ht_free(&ht);
        return -1;
    }

    current[0].board = *start;
    current[0].path_len = 0;
    current[0].score = move_score(&current[0].board, 0);
    int cur_count = 1;

    int best_pairs = start->pair_count;
    Solution best = {.len = 0, .pair_count = best_pairs};
    const int max_pairs = board_max_pairs(start);
    int stagnation = 0;

    for (int depth = 0; depth < max_depth; ++depth) {
        int next_count = 0;
        int improved_this_layer = 0;

        for (int i = 0; i < cur_count; ++i) {
            Node *node = &current[i];

            /* Pick the strongest non-pruned moves for this parent. */
            Move cand_moves[MAX_CHILDREN_PER_NODE];
            int cand_pairs[MAX_CHILDREN_PER_NODE];
            int cand_count = 0;

            for (int mi = 0; mi < ml.count; ++mi) {
                Move m = ml.moves[mi];
                Board nb;
                board_copy(&nb, &node->board);
                board_apply_move(&nb, &m);

                if (ht_should_prune(&ht, nb.hash, (uint16_t)nb.pair_count, (uint16_t)(node->path_len + 1))) {
                    continue;
                }

                consider_candidate(&m, &nb, cand_moves, cand_pairs, &cand_count);
            }

            /* Materialize the selected candidates as children in the next layer. */
            for (int c = 0; c < cand_count; ++c) {
                if ((size_t)next_count >= next_cap) break;
                Node *child = &next[next_count++];
                board_copy(&child->board, &node->board);
                board_apply_move(&child->board, &cand_moves[c]);
                child->path_len = node->path_len + 1;
                memcpy(child->path, node->path, sizeof(Move) * node->path_len);
                child->path[child->path_len - 1] = cand_moves[c];
                child->score = move_score(&child->board, child->path_len);

                /* path_len <= max_depth <= path_capacity (256), so the copy into
                 * best.moves stays within the same bound as before. */
                if (child->board.pair_count > best_pairs) {
                    best_pairs = child->board.pair_count;
                    best.len = child->path_len;
                    best.pair_count = child->board.pair_count;
                    memcpy(best.moves, child->path, sizeof(Move) * child->path_len);
                    improved_this_layer = 1;
                    if (best_pairs >= max_pairs) goto done;
                }
            }
        }

        if (next_count == 0) break;
        if (next_count > beam_width) {
            qsort(next, (size_t)next_count, sizeof(Node), cmp_node_desc);
            next_count = beam_width;
        }
        memcpy(current, next, (size_t)next_count * sizeof(Node));
        cur_count = next_count;

        if (improved_this_layer) {
            stagnation = 0;
        } else {
            stagnation++;
            if (stagnation_limit > 0 && stagnation >= stagnation_limit) {
                break;
            }
        }
    }

done:
    free_move_list(&ml);
    ht_free(&ht);
    free(next);
    free(current);

    /* best.len becomes non-zero exactly when best_pairs first improves, so
     * len == 0 means "nothing beat the start board". */
    if (best.len == 0) {
        out->len = 0;
        out->pair_count = start->pair_count;
    } else {
        *out = best;
    }
    return 0;
}
