fixed pickle recursion depth error and updated depth of nodes recursively in apply_move

This commit is contained in:
2024-01-29 19:25:35 +01:00
parent d43899ecda
commit b9761e1e2b
5 changed files with 21 additions and 9 deletions

View File

@@ -60,12 +60,16 @@ class BayesianMctsNode(IMctsNode):
return best_child return best_child
def update_depth(self, depth: int) -> None:
self.depth = depth
for c in self.children:
c.update_depth(depth + 1)
def select(self) -> IMctsNode: def select(self) -> IMctsNode:
if len(self.children) == 0: if len(self.children) == 0 or self.board.is_game_over():
return self return self
elif not self.board.is_game_over():
return self._select_best_child().select() return self._select_best_child().select()
return self
def expand(self) -> IMctsNode: def expand(self) -> IMctsNode:
if self.visits == 0: if self.visits == 0:
@@ -159,6 +163,7 @@ class BayesianMcts(IMcts):
self.root = child self.root = child
child.depth = 0 child.depth = 0
self.root.parent = None self.root.parent = None
self.root.update_depth(0)
return return
# if no child node contains the move, initialize a new tree. # if no child node contains the move, initialize a new tree.

View File

@@ -111,7 +111,6 @@ class BayesMctsEngine(Engine):
min(moves.items(), key=lambda x: x[1])[0]) min(moves.items(), key=lambda x: x[1])[0])
class ClassicMctsEngine(Engine): class ClassicMctsEngine(Engine):
def __init__(self, board: chess.Board, color: chess.Color, strategy: IStrategy): def __init__(self, board: chess.Board, color: chess.Color, strategy: IStrategy):
super().__init__(board, color, strategy) super().__init__(board, color, strategy)

View File

@@ -51,6 +51,13 @@ class IMctsNode(ABC):
""" """
pass pass
def update_depth(self, depth: int) -> None:
"""
Recursively updates the depth the current node and all it's children
:param depth: new depth for current node
:return:
"""
class IMcts(ABC): class IMcts(ABC):
def __init__(self, board: chess.Board, strategy: IStrategy, seed: int | None): def __init__(self, board: chess.Board, strategy: IStrategy, seed: int | None):

View File

@@ -5,6 +5,7 @@ import chess.pgn
from typing import Tuple, List from typing import Tuple, List
from enum import Enum from enum import Enum
from dataclasses import dataclass from dataclasses import dataclass
from chesspp.i_strategy import IStrategy
from chesspp.engine import Engine, Limit from chesspp.engine import Engine, Limit
@@ -18,7 +19,7 @@ class Winner(Enum):
@dataclass @dataclass
class EvaluationResult: class EvaluationResult:
winner: Winner winner: Winner
game: chess.pgn.Game game: str
def simulate_game(white: Engine, black: Engine, limit: Limit, board: chess.Board) -> chess.pgn.Game: def simulate_game(white: Engine, black: Engine, limit: Limit, board: chess.Board) -> chess.pgn.Game:
@@ -48,7 +49,7 @@ class Evaluation:
return pool.map(Evaluation._test_simulate, args) return pool.map(Evaluation._test_simulate, args)
@staticmethod @staticmethod
def _test_simulate(arg: Tuple[Engine.__class__, Engine.__class__, Limit]) -> EvaluationResult: def _test_simulate(arg: Tuple[Engine.__class__, IStrategy, Engine.__class__, IStrategy, Limit]) -> EvaluationResult:
engine_a, strategy_a, engine_b, strategy_b, limit = arg engine_a, strategy_a, engine_b, strategy_b, limit = arg
flip_engines = bool(random.getrandbits(1)) flip_engines = bool(random.getrandbits(1))
@@ -73,4 +74,4 @@ class Evaluation:
case (chess.BLACK, False): case (chess.BLACK, False):
result = Winner.Engine_B result = Winner.Engine_B
return EvaluationResult(result, game) return EvaluationResult(result, str(game))