From db8f4e3e6fedea204162d9a36e365442869c4b6b Mon Sep 17 00:00:00 2001 From: DarkCider <52292032+DarkCider@users.noreply.github.com> Date: Thu, 1 Feb 2024 13:06:22 +0100 Subject: [PATCH] Reworked posterior calculation in baysianMCTS --- chesspp/engine_factory.py | 4 ++-- chesspp/mcts/baysian_mcts_node.py | 23 +++++++++++++++++------ 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/chesspp/engine_factory.py b/chesspp/engine_factory.py index 9a97cc7..2956f1e 100644 --- a/chesspp/engine_factory.py +++ b/chesspp/engine_factory.py @@ -77,8 +77,8 @@ class EngineFactory: return ClassicMctsEngine(chess.Board(), color, strategy) @staticmethod - def classic_mcts_v2(color: chess.Color, strategy: IStrategy, board: chess.Board | None = chess.Board()) -> Engine: - return ClassicMctsEngineV2(board, color, strategy) + def classic_mcts_v2(color: chess.Color, strategy: IStrategy) -> Engine: + return ClassicMctsEngineV2(chess.Board(), color, strategy) @staticmethod def _get_random_strategy(rollout_depth: int) -> IStrategy: diff --git a/chesspp/mcts/baysian_mcts_node.py b/chesspp/mcts/baysian_mcts_node.py index 16c5fe3..acdecce 100644 --- a/chesspp/mcts/baysian_mcts_node.py +++ b/chesspp/mcts/baysian_mcts_node.py @@ -17,7 +17,9 @@ class BayesianMctsNode(IMctsNode): self.color = color # Color of the player whose turn it is self.visits = visits self.result = inherit_result if inherit_result is not None else 0 - self._set_mu_sigma() + # set priors + self.mu = self.result + self.sigma = 1 self.depth = depth def _create_child(self, move: chess.Move) -> IMctsNode: @@ -26,10 +28,6 @@ class BayesianMctsNode(IMctsNode): return BayesianMctsNode(copied_board, self.strategy, not self.color, self, move, self.random_state, self.result, self.depth + 1) - def _set_mu_sigma(self) -> None: - self.mu = self.result - self.sigma = 1 - def _is_new_ucb1_better(self, current, new) -> bool: if self.color == chess.WHITE: # maximize ucb1 @@ -116,7 +114,20 @@ class BayesianMctsNode(IMctsNode): if 
len(self.children) == 0: # leaf node - self._set_mu_sigma() + # prior + mu_pri = self.mu + sig_pri = self.sigma + + # likelihood + mu_li = self.result + sig_li = 1 + + # posterior (product of two Gaussians: variance shrinks with evidence) + sig_pos = math.sqrt((sig_pri**2 * sig_li**2) / (sig_pri**2 + sig_li**2)) + mu_pos = (sig_pri**2 * mu_li + sig_li**2 * mu_pri) / (sig_pri**2 + sig_li**2) + + self.mu = mu_pos + self.sigma = sig_pos else: # interior node shuffled_children = self.random_state.sample(self.children, len(self.children))