Adjust ClassicMcts so that it implements the IMcts interfaces

This commit is contained in:
2024-02-01 02:34:38 +01:00
parent f521c707d0
commit 8d3325ee98
3 changed files with 40 additions and 30 deletions

View File

@@ -1,24 +1,29 @@
import chess
from chesspp.i_strategy import IStrategy
from chesspp.mcts.classic_mcts_node_v2 import ClassicMctsNodeV2
from chesspp.mcts.i_mcts import IMcts
from chesspp.mcts.i_mcts_node import IMctsNode
class ClassicMctsV2:
def __init__(self, board: chess.Board, color: chess.Color, strategy: IStrategy):
self.board = board
class ClassicMctsV2(IMcts):
def __init__(self, board: chess.Board, color: chess.Color, strategy: IStrategy, seed: int | None = None):
super().__init__(board, strategy, seed)
self.color = color
self.strategy = strategy
self.root = ClassicMctsNodeV2(board, color, strategy)
self.root = ClassicMctsNodeV2(board, color, strategy, None, None, self.random_state)
def build_tree(self, samples: int = 1000):
def apply_move(self, move: chess.Move) -> None:
pass
def get_children(self) -> list[IMctsNode]:
return self.root.children
def sample(self, samples: int = 1000):
"""
Runs the MCTS with the given number of samples
:param samples: number of simulations
:return: best node containing the best move
"""
for i in range(samples):
node = self.root._select_leaf()
score = node._rollout()
node._backpropagate(score)
node = self.root.select().expand()
score = node.rollout()
node.backpropagate(score)