From b9761e1e2b9c01bbc5788152c024be4e70d37062 Mon Sep 17 00:00:00 2001 From: luk3k Date: Mon, 29 Jan 2024 19:25:35 +0100 Subject: [PATCH] fixed pickle recursion depth error and updated depth of nodes recursively in apply_move --- chesspp/baysian_mcts.py | 13 +++++++++---- chesspp/engine.py | 1 - chesspp/i_mcts.py | 7 +++++++ chesspp/simulation.py | 7 ++++--- main.py | 2 +- 5 files changed, 21 insertions(+), 9 deletions(-) diff --git a/chesspp/baysian_mcts.py b/chesspp/baysian_mcts.py index c58a699..025338b 100644 --- a/chesspp/baysian_mcts.py +++ b/chesspp/baysian_mcts.py @@ -60,12 +60,16 @@ class BayesianMctsNode(IMctsNode): return best_child + def update_depth(self, depth: int) -> None: + self.depth = depth + for c in self.children: + c.update_depth(depth + 1) + def select(self) -> IMctsNode: - if len(self.children) == 0: + if len(self.children) == 0 or self.board.is_game_over(): return self - elif not self.board.is_game_over(): - return self._select_best_child().select() - return self + + return self._select_best_child().select() def expand(self) -> IMctsNode: if self.visits == 0: @@ -159,6 +163,7 @@ class BayesianMcts(IMcts): self.root = child child.depth = 0 self.root.parent = None + self.root.update_depth(0) return # if no child node contains the move, initialize a new tree. diff --git a/chesspp/engine.py b/chesspp/engine.py index 582adab..da08b64 100644 --- a/chesspp/engine.py +++ b/chesspp/engine.py @@ -111,7 +111,6 @@ class BayesMctsEngine(Engine): min(moves.items(), key=lambda x: x[1])[0]) - class ClassicMctsEngine(Engine): def __init__(self, board: chess.Board, color: chess.Color, strategy: IStrategy): super().__init__(board, color, strategy) diff --git a/chesspp/i_mcts.py b/chesspp/i_mcts.py index 1df8085..a3a3856 100644 --- a/chesspp/i_mcts.py +++ b/chesspp/i_mcts.py @@ -51,6 +51,13 @@ class IMctsNode(ABC): """ pass + def update_depth(self, depth: int) -> None: + """ + Recursively updates the depth the current node and all it's children + :param depth: new depth for current node + :return: + """ + class IMcts(ABC): def __init__(self, board: chess.Board, strategy: IStrategy, seed: int | None): diff --git a/chesspp/simulation.py b/chesspp/simulation.py index b7b1d33..3662f4f 100644 --- a/chesspp/simulation.py +++ b/chesspp/simulation.py @@ -5,6 +5,7 @@ import chess.pgn from typing import Tuple, List from enum import Enum from dataclasses import dataclass +from chesspp.i_strategy import IStrategy from chesspp.engine import Engine, Limit @@ -18,7 +19,7 @@ class Winner(Enum): @dataclass class EvaluationResult: winner: Winner - game: chess.pgn.Game + game: str def simulate_game(white: Engine, black: Engine, limit: Limit, board: chess.Board) -> chess.pgn.Game: @@ -48,7 +49,7 @@ class Evaluation: return pool.map(Evaluation._test_simulate, args) @staticmethod - def _test_simulate(arg: Tuple[Engine.__class__, Engine.__class__, Limit]) -> EvaluationResult: + def _test_simulate(arg: Tuple[Engine.__class__, IStrategy, Engine.__class__, IStrategy, Limit]) -> EvaluationResult: engine_a, strategy_a, engine_b, strategy_b, limit = arg flip_engines = bool(random.getrandbits(1)) @@ -73,4 +74,4 @@ class Evaluation: case (chess.BLACK, False): result = Winner.Engine_B - return EvaluationResult(result, game) + return EvaluationResult(result, str(game)) diff --git a/main.py b/main.py index c0d14d9..b890209 100644 --- a/main.py +++ b/main.py @@ -96,7 +96,7 @@ def test_evaluation(): def main(): test_evaluation() - #test_simulate() + # test_simulate() # test_mcts() # test_stockfish() # test_stockfish_prob()