diff --git a/chesspp/baysian_mcts.py b/chesspp/baysian_mcts.py index 025338b..9551bca 100644 --- a/chesspp/baysian_mcts.py +++ b/chesspp/baysian_mcts.py @@ -9,10 +9,10 @@ from chesspp.util_gaussian import gaussian_ucb1, max_gaussian, min_gaussian class BayesianMctsNode(IMctsNode): def __init__(self, board: chess.Board, strategy: IStrategy, color: chess.Color, parent: Self | None, move: chess.Move | None, - random_state: random.Random, inherit_result: int | None = None, depth: int = 0): + random_state: random.Random, inherit_result: int | None = None, depth: int = 0, visits: int = 0): super().__init__(board, strategy, parent, move, random_state) self.color = color # Color of the player whose turn it is - self.visits = 0 + self.visits = visits self.result = inherit_result if inherit_result is not None else 0 self._set_mu_sigma() self.depth = depth @@ -140,8 +140,7 @@ class BayesianMcts(IMcts): def __init__(self, board: chess.Board, strategy: IStrategy, color: chess.Color, seed: int | None = None): super().__init__(board, strategy, seed) - self.root = BayesianMctsNode(board, strategy, color, None, None, self.random_state) - self.root.visits += 1 + self.root = BayesianMctsNode(board, strategy, color, None, None, self.random_state, visits=1) self.color = color def sample(self, runs: int = 1000) -> None: @@ -164,10 +163,11 @@ class BayesianMcts(IMcts): child.depth = 0 self.root.parent = None self.root.update_depth(0) + self.root.visits = 1 return # if no child node contains the move, initialize a new tree. - self.root = BayesianMctsNode(self.board, self.root.strategy, self.color, None, None, self.random_state) + self.root = BayesianMctsNode(self.board, self.root.strategy, self.color, None, None, self.random_state, visits=1) def get_children(self) -> list[IMctsNode]: return self.root.children