Reworked posterior calculation in baysianMCTS
This commit is contained in:
@@ -77,8 +77,8 @@ class EngineFactory:
|
|||||||
return ClassicMctsEngine(chess.Board(), color, strategy)
|
return ClassicMctsEngine(chess.Board(), color, strategy)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def classic_mcts_v2(color: chess.Color, strategy: IStrategy, board: chess.Board | None = chess.Board()) -> Engine:
|
def classic_mcts_v2(color: chess.Color, strategy: IStrategy) -> Engine:
|
||||||
return ClassicMctsEngineV2(board, color, strategy)
|
return ClassicMctsEngineV2(chess.Board(), color, strategy)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_random_strategy(rollout_depth: int) -> IStrategy:
|
def _get_random_strategy(rollout_depth: int) -> IStrategy:
|
||||||
|
|||||||
@@ -17,7 +17,9 @@ class BayesianMctsNode(IMctsNode):
|
|||||||
self.color = color # Color of the player whose turn it is
|
self.color = color # Color of the player whose turn it is
|
||||||
self.visits = visits
|
self.visits = visits
|
||||||
self.result = inherit_result if inherit_result is not None else 0
|
self.result = inherit_result if inherit_result is not None else 0
|
||||||
self._set_mu_sigma()
|
# set priors
|
||||||
|
self.mu = self.result
|
||||||
|
self.sigma = 1
|
||||||
self.depth = depth
|
self.depth = depth
|
||||||
|
|
||||||
def _create_child(self, move: chess.Move) -> IMctsNode:
|
def _create_child(self, move: chess.Move) -> IMctsNode:
|
||||||
@@ -26,10 +28,6 @@ class BayesianMctsNode(IMctsNode):
|
|||||||
return BayesianMctsNode(copied_board, self.strategy, not self.color, self, move, self.random_state, self.result,
|
return BayesianMctsNode(copied_board, self.strategy, not self.color, self, move, self.random_state, self.result,
|
||||||
self.depth + 1)
|
self.depth + 1)
|
||||||
|
|
||||||
def _set_mu_sigma(self) -> None:
|
|
||||||
self.mu = self.result
|
|
||||||
self.sigma = 1
|
|
||||||
|
|
||||||
def _is_new_ucb1_better(self, current, new) -> bool:
|
def _is_new_ucb1_better(self, current, new) -> bool:
|
||||||
if self.color == chess.WHITE:
|
if self.color == chess.WHITE:
|
||||||
# maximize ucb1
|
# maximize ucb1
|
||||||
@@ -116,7 +114,20 @@ class BayesianMctsNode(IMctsNode):
|
|||||||
|
|
||||||
if len(self.children) == 0:
|
if len(self.children) == 0:
|
||||||
# leaf node
|
# leaf node
|
||||||
self._set_mu_sigma()
|
# prior
|
||||||
|
mu_pri = self.mu
|
||||||
|
sig_pri = self.sigma
|
||||||
|
|
||||||
|
# likelyhood
|
||||||
|
mu_li = self.result
|
||||||
|
sig_li = 1
|
||||||
|
|
||||||
|
# posterior
|
||||||
|
sig_pos = math.sqrt(sig_pri**2 + sig_li**2)
|
||||||
|
mu_pos = (sig_pri**2 * mu_li + sig_li**2 * mu_pri) / (sig_pri**2 + sig_li**2)
|
||||||
|
|
||||||
|
self.mu = mu_pos
|
||||||
|
self.sigma = sig_pos
|
||||||
else:
|
else:
|
||||||
# interior node
|
# interior node
|
||||||
shuffled_children = self.random_state.sample(self.children, len(self.children))
|
shuffled_children = self.random_state.sample(self.children, len(self.children))
|
||||||
|
|||||||
Reference in New Issue
Block a user