diff --git a/eval.py b/eval.py index e820fda..3393a6e 100644 --- a/eval.py +++ b/eval.py @@ -177,6 +177,6 @@ def score_stockfish(board: chess.Board) -> chess.engine.PovScore: :return: """ engine = chess.engine.SimpleEngine.popen_uci("./stockfish/stockfish-ubuntu-x86-64-avx2") - info = engine.analyse(board, chess.engine.Limit(depth=20)) + info = engine.analyse(board, chess.engine.Limit(depth=2)) engine.quit() return info["score"] diff --git a/main.py b/main.py index 9caad17..e3b501d 100644 --- a/main.py +++ b/main.py @@ -5,7 +5,7 @@ import engine import eval -def test_mcts(seed): +def test_mcts(): fools_mate = "rnbqkbnr/pppp1ppp/4p3/8/5PP1/8/PPPPP2P/RNBQKBNR b KQkq f3 0 2" board = chess.Board(fools_mate) mcts_root = MCTSNode(board) @@ -15,7 +15,7 @@ def test_mcts(seed): print("move (mcts):", c.move, " with score:", c.score) -def test_stockfish(seed): +def test_stockfish(): fools_mate = "rnbqkbnr/pppp1ppp/4p3/8/5PP1/8/PPPPP2P/RNBQKBNR b KQkq f3 0 2" board = chess.Board(fools_mate) moves = {} @@ -37,8 +37,8 @@ def analyze_results(moves: dict): def main(): - test_mcts(0) - test_stockfish(0) + test_mcts() + test_stockfish() if __name__ == '__main__': diff --git a/mcts.py b/mcts.py index 7e56d48..df551d6 100644 --- a/mcts.py +++ b/mcts.py @@ -30,21 +30,23 @@ class MCTSNode: self.children.append(child_node) return child_node - def _rollout(self, rollout_depth: int = 100) -> float: + def _rollout(self, rollout_depth: int = 20) -> int: """ Rolls out the node by simulating a game for a given depth. Sometimes this step is called 'simulation' or 'playout'. :return: the score of the rolled out game """ copied_board = self.board.copy() + steps = 1 for i in range(rollout_depth): if copied_board.is_game_over(): break m = engine.pick_move(copied_board) copied_board.push(m) + steps += 1 - return eval.score_manual(copied_board) + return eval.score_manual(copied_board) // steps def _backpropagate(self, score: float) -> None: """