diff --git a/chesspp/engine.py b/chesspp/engine.py index 0364d21..c977407 100644 --- a/chesspp/engine.py +++ b/chesspp/engine.py @@ -174,7 +174,7 @@ class ClassicMctsEngineV2(Engine): def do(): nonlocal node_count - mcts.build_tree(1) + mcts.sample(1) node_count += 1 limit.run(do) diff --git a/chesspp/mcts/classic_mcts_node_v2.py b/chesspp/mcts/classic_mcts_node_v2.py index 91e9320..c173d64 100644 --- a/chesspp/mcts/classic_mcts_node_v2.py +++ b/chesspp/mcts/classic_mcts_node_v2.py @@ -1,19 +1,19 @@ import math import random +from typing import Self import chess import numpy as np from chesspp.i_strategy import IStrategy +from chesspp.mcts.i_mcts_node import IMctsNode -class ClassicMctsNodeV2: - def __init__(self, board: chess.Board, color: chess.Color, strategy: IStrategy, parent=None, move: chess.Move | None = None, - random_state: int | None = None, depth: int = 0): - self.random = random.Random(random_state) - self.board = board +class ClassicMctsNodeV2(IMctsNode): + def __init__(self, board: chess.Board, color: chess.Color, strategy: IStrategy, parent: Self | None, move: chess.Move | None, + random_state: random.Random, depth: int = 0): + super().__init__(board, strategy, parent, move, random_state) self.color = color - self.strategy = strategy self.parent = parent self.move = move self.children = [] @@ -23,20 +23,23 @@ class ClassicMctsNodeV2: self.score = 0 self.depth = depth - def _expand(self) -> 'ClassicMctsNodeV2': + def expand(self) -> Self: """ Expands the node, i.e., choose an action and apply it to the board :return: """ - move = self.random.choice(self.untried_actions) + if self.is_fully_expanded(): + return self + + move = self.random_state.choice(self.untried_actions) self.untried_actions.remove(move) next_board = self.board.copy() next_board.push(move) - child_node = ClassicMctsNodeV2(next_board, color=not self.color, strategy=self.strategy, parent=self, move=move, depth=self.depth+1) + child_node = ClassicMctsNodeV2(next_board, color=not self.color, strategy=self.strategy, parent=self, move=move, depth=self.depth+1, random_state=self.random_state) self.children.append(child_node) return child_node - def _rollout(self, rollout_depth: int = 4) -> int: + def rollout(self, rollout_depth: int = 4) -> int: """ Rolls out the node by simulating a game for a given depth. Sometimes this step is called 'simulation' or 'playout'. @@ -55,7 +58,7 @@ class ClassicMctsNodeV2: steps = max(2, steps) return int(self.strategy.analyze_board(copied_board) / math.log2(steps)) - def _backpropagate(self, score: float) -> None: + def backpropagate(self, score: float | None = None) -> None: """ Backpropagates the results of the rollout :param score: @@ -63,14 +66,17 @@ class ClassicMctsNodeV2: """ self.visits += 1 # TODO: maybe use score + num of moves together (a win in 1 move is better than a win in 20 moves) - self.score += score + + if score is not None: + self.score += score + if self.parent: - self.parent._backpropagate(score) + self.parent.backpropagate(score) def is_fully_expanded(self) -> bool: return len(self.untried_actions) == 0 - def _best_child(self) -> 'ClassicMctsNodeV2': + def _best_child(self) -> Self: """ Picks the best child according to our policy :return: the best child @@ -81,7 +87,7 @@ class ClassicMctsNodeV2: best_child_index = np.argmax(choices_weights) if self.color == chess.WHITE else np.argmin(choices_weights) return self.children[best_child_index] - def _select_leaf(self) -> 'ClassicMctsNodeV2': + def select(self) -> Self: """ Selects a leaf node. If the node is not expanded is will be expanded. @@ -90,8 +96,7 @@ class ClassicMctsNodeV2: current_node = self while not current_node.board.is_game_over(): if not current_node.is_fully_expanded(): - return current_node._expand() - else: - current_node = current_node._best_child() + return current_node + current_node = current_node._best_child() return current_node diff --git a/chesspp/mcts/classic_mcts_v2.py b/chesspp/mcts/classic_mcts_v2.py index 2dfe737..187b752 100644 --- a/chesspp/mcts/classic_mcts_v2.py +++ b/chesspp/mcts/classic_mcts_v2.py @@ -1,24 +1,29 @@ import chess from chesspp.i_strategy import IStrategy from chesspp.mcts.classic_mcts_node_v2 import ClassicMctsNodeV2 +from chesspp.mcts.i_mcts import IMcts +from chesspp.mcts.i_mcts_node import IMctsNode -class ClassicMctsV2: - def __init__(self, board: chess.Board, color: chess.Color, strategy: IStrategy): - self.board = board +class ClassicMctsV2(IMcts): + def __init__(self, board: chess.Board, color: chess.Color, strategy: IStrategy, seed: int | None = None): + super().__init__(board, strategy, seed) self.color = color - self.strategy = strategy - self.root = ClassicMctsNodeV2(board, color, strategy) + self.root = ClassicMctsNodeV2(board, color, strategy, None, None, self.random_state) - def build_tree(self, samples: int = 1000): + def apply_move(self, move: chess.Move) -> None: + pass + + def get_children(self) -> list[IMctsNode]: + return self.root.children + + def sample(self, samples: int = 1000): """ Runs the MCTS with the given number of samples :param samples: number of simulations :return: best node containing the best move """ for i in range(samples): - node = self.root._select_leaf() - score = node._rollout() - node._backpropagate(score) - - + node = self.root.select().expand() + score = node.rollout() + node.backpropagate(score)