#include "search.h"

#include <stdlib.h>
#include <string.h>

// Capacity, in moves, of the path array stored in every Node.
#define MAX_PATH 256
// Maximum number of candidate moves kept per expanded node; solve_beam also
// sizes its per-level child pool as beam_width * MAX_CHILDREN_PER_NODE.
#define MAX_CHILDREN_PER_NODE 64

// One entry of the search beam: a board state plus how it was reached.
typedef struct {
    Board board;          // board state after the moves in `path` were applied
    Move path[MAX_PATH];  // moves applied to reach `board`, in order (capacity MAX_PATH)
    int path_len;         // number of moves applied to reach `board`
    int score;            // ranking value computed by move_score(); higher sorts first
} Node;

// Ranking heuristic for a search node: every pair on the board is worth
// PAIR_WEIGHT points and every move already spent costs one point, so among
// boards with equal pair counts the shorter path scores higher.
static int move_score(const Board *b, int path_len) {
    enum { PAIR_WEIGHT = 1000 };
    return b->pair_count * PAIR_WEIGHT - path_len;
}

// qsort comparator ordering Nodes from best to worst: higher score first,
// ties broken by higher board.pair_count first.
//
// The result is built from comparisons rather than a subtraction so that it
// can never overflow a signed int, whatever the operand values are; qsort only
// looks at the sign of the return value.
static int cmp_node_desc(const void *a, const void *b) {
    const Node *na = (const Node *)a;
    const Node *nb = (const Node *)b;
    if (na->score != nb->score) {
        return (nb->score > na->score) - (nb->score < na->score);
    }
    return (nb->board.pair_count > na->board.pair_count) -
           (nb->board.pair_count < na->board.pair_count);
}

// Offer move `m` (which produced board `b_after`) to a bounded candidate set.
//
// c_moves / c_pairs are parallel arrays with room for MAX_CHILDREN_PER_NODE
// entries and *c_count is the number currently in use. While there is room the
// move is appended. Once the set is full, the move replaces the entry with the
// lowest stored pair count, but only if it yields strictly more pairs than
// that entry; on ties for lowest, the earliest entry is the one replaced.
static void consider_candidate(const Move *m, const Board *b_after, Move *c_moves, int *c_pairs, int *c_count) {
    const int used = *c_count;

    // Still room: take the next free slot.
    if (used < MAX_CHILDREN_PER_NODE) {
        c_moves[used] = *m;
        c_pairs[used] = b_after->pair_count;
        *c_count = used + 1;
        return;
    }

    // Set is full: locate the entry with the fewest pairs.
    int low_idx = 0;
    int low_pairs = c_pairs[0];
    for (int k = 1; k < used; ++k) {
        if (c_pairs[k] < low_pairs) {
            low_idx = k;
            low_pairs = c_pairs[k];
        }
    }

    // Evict it only for a strictly better newcomer.
    if (b_after->pair_count > low_pairs) {
        c_moves[low_idx] = *m;
        c_pairs[low_idx] = b_after->pair_count;
    }
}

// Beam search for a move sequence that raises the board's pair_count.
//
// Every level expands each node in the beam: all moves (x, y, n) with
// n in 2..min(start->size, n_limit) and x, y in 0..start->size - n are applied
// to a copy of the node's board, and the MAX_CHILDREN_PER_NODE moves giving the
// highest pair_count become that node's children. The children of a level are
// ranked with cmp_node_desc and the best beam_width of them form the next beam.
// The search stops early once a board reaches board_max_pairs(start).
//
// Parameters:
//   start      - initial board; read only.
//   beam_width - nodes kept per level; values below 1 are treated as 1.
//   max_depth  - maximum number of moves in a solution; clamped to
//                1..MAX_PATH because a Node cannot record a longer path.
//   n_limit    - largest n tried; values below 2 are treated as 2.
//   out        - receives the result. If some sequence strictly improves on
//                start->pair_count, the whole Solution is overwritten with the
//                best one found; otherwise only out->len (0) and
//                out->pair_count (start->pair_count) are written.
//
// Returns 0 on success, or -1 (leaving *out untouched) if the working buffers
// cannot be allocated or beam_width is too large for their size to be
// represented.
//
// NOTE(review): coordinates are narrowed to uint8_t when building a Move, so
// this presumably expects start->size <= 255; it also assumes Solution.moves
// can hold MAX_PATH entries - confirm both against search.h.
int solve_beam(const Board *start, int beam_width, int max_depth, int n_limit, Solution *out) {
    if (beam_width < 1) beam_width = 1;
    if (max_depth < 1) max_depth = 1;
    // Levels deeper than MAX_PATH could never be recorded in Node.path, so
    // searching them is pure waste. With this clamp every child satisfies
    // path_len <= MAX_PATH and its path is always fully stored.
    if (max_depth > MAX_PATH) max_depth = MAX_PATH;
    if (n_limit < 2) n_limit = 2;

    // Refuse widths for which the child pool size would wrap around size_t.
    if ((size_t)beam_width > ((size_t)-1) / sizeof(Node) / MAX_CHILDREN_PER_NODE) {
        return -1;
    }
    const size_t pool_capacity = (size_t)beam_width * MAX_CHILDREN_PER_NODE;

    Node *current = (Node *)malloc((size_t)beam_width * sizeof(Node));
    Node *next = (Node *)malloc(pool_capacity * sizeof(Node));
    if (!current || !next) {
        free(current);
        free(next);
        return -1;
    }

    // Seed the beam with the unmodified starting board.
    current[0].board = *start;
    current[0].path_len = 0;
    current[0].score = move_score(&current[0].board, 0);
    int cur_count = 1;

    int best_pairs = start->pair_count;
    Solution best = {.len = 0, .pair_count = best_pairs};

    const int max_pairs = board_max_pairs(start);

    for (int depth = 0; depth < max_depth; ++depth) {
        size_t next_count = 0;
        for (int i = 0; i < cur_count; ++i) {
            Node *node = &current[i];

            Move cand_moves[MAX_CHILDREN_PER_NODE];
            int cand_pairs[MAX_CHILDREN_PER_NODE];
            int cand_count = 0;

            // Try every move on a scratch copy and keep only the most
            // promising ones for this node.
            for (int n = 2; n <= start->size && n <= n_limit; ++n) {
                const int limit = start->size - n;
                for (int y = 0; y <= limit; ++y) {
                    for (int x = 0; x <= limit; ++x) {
                        Move m = {(uint8_t)x, (uint8_t)y, (uint8_t)n};
                        Board nb;
                        board_copy(&nb, &node->board);
                        board_apply_move(&nb, &m);
                        consider_candidate(&m, &nb, cand_moves, cand_pairs, &cand_count);
                    }
                }
            }

            // Materialise the surviving candidates as child nodes.
            for (int c = 0; c < cand_count; ++c) {
                if (next_count >= pool_capacity) break;
                Node *child = &next[next_count++];
                board_copy(&child->board, &node->board);
                board_apply_move(&child->board, &cand_moves[c]);
                child->path_len = node->path_len + 1;
                // path_len <= max_depth <= MAX_PATH, so the parent's path plus
                // this move always fits, including the MAX_PATH-th move.
                memcpy(child->path, node->path, sizeof(Move) * (size_t)node->path_len);
                child->path[child->path_len - 1] = cand_moves[c];
                child->score = move_score(&child->board, child->path_len);

                if (child->board.pair_count > best_pairs) {
                    best_pairs = child->board.pair_count;
                    best.len = child->path_len;
                    best.pair_count = child->board.pair_count;
                    memcpy(best.moves, child->path, sizeof(Move) * (size_t)child->path_len);
                    if (best_pairs >= max_pairs) {
                        // Perfect board
                        goto done;
                    }
                }
            }
        }

        if (next_count == 0) break;
        // Keep only the beam_width best children for the next level.
        if (next_count > (size_t)beam_width) {
            qsort(next, next_count, sizeof(Node), cmp_node_desc);
            next_count = (size_t)beam_width;
        }
        memcpy(current, next, next_count * sizeof(Node));
        cur_count = (int)next_count;
    }

done:
    free(next);
    free(current);
    if (best.len == 0 && best_pairs == start->pair_count) {
        // No improvement found; return empty solution
        out->len = 0;
        out->pair_count = start->pair_count;
        return 0;
    }
    *out = best;
    return 0;
}
