Reworked posterior calculation in BayesianMCTS

This commit is contained in:
DarkCider
2024-02-01 13:06:22 +01:00
parent 8d3325ee98
commit db8f4e3e6f
2 changed files with 19 additions and 8 deletions

View File

@@ -77,8 +77,8 @@ class EngineFactory:
return ClassicMctsEngine(chess.Board(), color, strategy) return ClassicMctsEngine(chess.Board(), color, strategy)
@staticmethod @staticmethod
def classic_mcts_v2(color: chess.Color, strategy: IStrategy, board: chess.Board | None = chess.Board()) -> Engine: def classic_mcts_v2(color: chess.Color, strategy: IStrategy) -> Engine:
return ClassicMctsEngineV2(board, color, strategy) return ClassicMctsEngineV2(chess.Board(), color, strategy)
@staticmethod @staticmethod
def _get_random_strategy(rollout_depth: int) -> IStrategy: def _get_random_strategy(rollout_depth: int) -> IStrategy:

View File

@@ -17,7 +17,9 @@ class BayesianMctsNode(IMctsNode):
self.color = color # Color of the player whose turn it is self.color = color # Color of the player whose turn it is
self.visits = visits self.visits = visits
self.result = inherit_result if inherit_result is not None else 0 self.result = inherit_result if inherit_result is not None else 0
self._set_mu_sigma() # set priors
self.mu = self.result
self.sigma = 1
self.depth = depth self.depth = depth
def _create_child(self, move: chess.Move) -> IMctsNode: def _create_child(self, move: chess.Move) -> IMctsNode:
@@ -26,10 +28,6 @@ class BayesianMctsNode(IMctsNode):
return BayesianMctsNode(copied_board, self.strategy, not self.color, self, move, self.random_state, self.result, return BayesianMctsNode(copied_board, self.strategy, not self.color, self, move, self.random_state, self.result,
self.depth + 1) self.depth + 1)
def _set_mu_sigma(self) -> None:
self.mu = self.result
self.sigma = 1
def _is_new_ucb1_better(self, current, new) -> bool: def _is_new_ucb1_better(self, current, new) -> bool:
if self.color == chess.WHITE: if self.color == chess.WHITE:
# maximize ucb1 # maximize ucb1
@@ -116,7 +114,20 @@ class BayesianMctsNode(IMctsNode):
if len(self.children) == 0: if len(self.children) == 0:
# leaf node # leaf node
self._set_mu_sigma() # prior
mu_pri = self.mu
sig_pri = self.sigma
# likelihood
mu_li = self.result
sig_li = 1
# posterior
sig_pos = math.sqrt(sig_pri**2 + sig_li**2)
mu_pos = (sig_pri**2 * mu_li + sig_li**2 * mu_pri) / (sig_pri**2 + sig_li**2)
self.mu = mu_pos
self.sigma = sig_pos
else: else:
# interior node # interior node
shuffled_children = self.random_state.sample(self.children, len(self.children)) shuffled_children = self.random_state.sample(self.children, len(self.children))