fixed pickle recursion depth error and updated depth of nodes recursively in apply_move
This commit is contained in:
@@ -60,12 +60,16 @@ class BayesianMctsNode(IMctsNode):
|
|||||||
|
|
||||||
return best_child
|
return best_child
|
||||||
|
|
||||||
|
def update_depth(self, depth: int) -> None:
|
||||||
|
self.depth = depth
|
||||||
|
for c in self.children:
|
||||||
|
c.update_depth(depth + 1)
|
||||||
|
|
||||||
def select(self) -> IMctsNode:
|
def select(self) -> IMctsNode:
|
||||||
if len(self.children) == 0:
|
if len(self.children) == 0 or self.board.is_game_over():
|
||||||
return self
|
return self
|
||||||
elif not self.board.is_game_over():
|
|
||||||
return self._select_best_child().select()
|
return self._select_best_child().select()
|
||||||
return self
|
|
||||||
|
|
||||||
def expand(self) -> IMctsNode:
|
def expand(self) -> IMctsNode:
|
||||||
if self.visits == 0:
|
if self.visits == 0:
|
||||||
@@ -159,6 +163,7 @@ class BayesianMcts(IMcts):
|
|||||||
self.root = child
|
self.root = child
|
||||||
child.depth = 0
|
child.depth = 0
|
||||||
self.root.parent = None
|
self.root.parent = None
|
||||||
|
self.root.update_depth(0)
|
||||||
return
|
return
|
||||||
|
|
||||||
# if no child node contains the move, initialize a new tree.
|
# if no child node contains the move, initialize a new tree.
|
||||||
|
|||||||
@@ -111,7 +111,6 @@ class BayesMctsEngine(Engine):
|
|||||||
min(moves.items(), key=lambda x: x[1])[0])
|
min(moves.items(), key=lambda x: x[1])[0])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ClassicMctsEngine(Engine):
|
class ClassicMctsEngine(Engine):
|
||||||
def __init__(self, board: chess.Board, color: chess.Color, strategy: IStrategy):
|
def __init__(self, board: chess.Board, color: chess.Color, strategy: IStrategy):
|
||||||
super().__init__(board, color, strategy)
|
super().__init__(board, color, strategy)
|
||||||
|
|||||||
@@ -51,6 +51,13 @@ class IMctsNode(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def update_depth(self, depth: int) -> None:
|
||||||
|
"""
|
||||||
|
Recursively updates the depth the current node and all it's children
|
||||||
|
:param depth: new depth for current node
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class IMcts(ABC):
|
class IMcts(ABC):
|
||||||
def __init__(self, board: chess.Board, strategy: IStrategy, seed: int | None):
|
def __init__(self, board: chess.Board, strategy: IStrategy, seed: int | None):
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import chess.pgn
|
|||||||
from typing import Tuple, List
|
from typing import Tuple, List
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from chesspp.i_strategy import IStrategy
|
||||||
|
|
||||||
from chesspp.engine import Engine, Limit
|
from chesspp.engine import Engine, Limit
|
||||||
|
|
||||||
@@ -18,7 +19,7 @@ class Winner(Enum):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class EvaluationResult:
|
class EvaluationResult:
|
||||||
winner: Winner
|
winner: Winner
|
||||||
game: chess.pgn.Game
|
game: str
|
||||||
|
|
||||||
|
|
||||||
def simulate_game(white: Engine, black: Engine, limit: Limit, board: chess.Board) -> chess.pgn.Game:
|
def simulate_game(white: Engine, black: Engine, limit: Limit, board: chess.Board) -> chess.pgn.Game:
|
||||||
@@ -48,7 +49,7 @@ class Evaluation:
|
|||||||
return pool.map(Evaluation._test_simulate, args)
|
return pool.map(Evaluation._test_simulate, args)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _test_simulate(arg: Tuple[Engine.__class__, Engine.__class__, Limit]) -> EvaluationResult:
|
def _test_simulate(arg: Tuple[Engine.__class__, IStrategy, Engine.__class__, IStrategy, Limit]) -> EvaluationResult:
|
||||||
engine_a, strategy_a, engine_b, strategy_b, limit = arg
|
engine_a, strategy_a, engine_b, strategy_b, limit = arg
|
||||||
flip_engines = bool(random.getrandbits(1))
|
flip_engines = bool(random.getrandbits(1))
|
||||||
|
|
||||||
@@ -73,4 +74,4 @@ class Evaluation:
|
|||||||
case (chess.BLACK, False):
|
case (chess.BLACK, False):
|
||||||
result = Winner.Engine_B
|
result = Winner.Engine_B
|
||||||
|
|
||||||
return EvaluationResult(result, game)
|
return EvaluationResult(result, str(game))
|
||||||
|
|||||||
Reference in New Issue
Block a user