Adjust ClassicMcts so that it implements the IMcts interfaces

This commit is contained in:
2024-02-01 02:34:38 +01:00
parent f521c707d0
commit 8d3325ee98
3 changed files with 40 additions and 30 deletions

View File

@@ -174,7 +174,7 @@ class ClassicMctsEngineV2(Engine):
def do(): def do():
nonlocal node_count nonlocal node_count
mcts.build_tree(1) mcts.sample(1)
node_count += 1 node_count += 1
limit.run(do) limit.run(do)

View File

@@ -1,19 +1,19 @@
import math import math
import random import random
from typing import Self
import chess import chess
import numpy as np import numpy as np
from chesspp.i_strategy import IStrategy from chesspp.i_strategy import IStrategy
from chesspp.mcts.i_mcts_node import IMctsNode
class ClassicMctsNodeV2: class ClassicMctsNodeV2(IMctsNode):
def __init__(self, board: chess.Board, color: chess.Color, strategy: IStrategy, parent=None, move: chess.Move | None = None, def __init__(self, board: chess.Board, color: chess.Color, strategy: IStrategy, parent: Self | None, move: chess.Move | None,
random_state: int | None = None, depth: int = 0): random_state: random.Random, depth: int = 0):
self.random = random.Random(random_state) super().__init__(board, strategy, parent, move, random_state)
self.board = board
self.color = color self.color = color
self.strategy = strategy
self.parent = parent self.parent = parent
self.move = move self.move = move
self.children = [] self.children = []
@@ -23,20 +23,23 @@ class ClassicMctsNodeV2:
self.score = 0 self.score = 0
self.depth = depth self.depth = depth
def _expand(self) -> 'ClassicMctsNodeV2': def expand(self) -> Self:
""" """
Expands the node, i.e., choose an action and apply it to the board Expands the node, i.e., choose an action and apply it to the board
:return: :return:
""" """
move = self.random.choice(self.untried_actions) if self.is_fully_expanded():
return self
move = self.random_state.choice(self.untried_actions)
self.untried_actions.remove(move) self.untried_actions.remove(move)
next_board = self.board.copy() next_board = self.board.copy()
next_board.push(move) next_board.push(move)
child_node = ClassicMctsNodeV2(next_board, color=not self.color, strategy=self.strategy, parent=self, move=move, depth=self.depth+1) child_node = ClassicMctsNodeV2(next_board, color=not self.color, strategy=self.strategy, parent=self, move=move, depth=self.depth+1, random_state=self.random_state)
self.children.append(child_node) self.children.append(child_node)
return child_node return child_node
def _rollout(self, rollout_depth: int = 4) -> int: def rollout(self, rollout_depth: int = 4) -> int:
""" """
Rolls out the node by simulating a game for a given depth. Rolls out the node by simulating a game for a given depth.
Sometimes this step is called 'simulation' or 'playout'. Sometimes this step is called 'simulation' or 'playout'.
@@ -55,7 +58,7 @@ class ClassicMctsNodeV2:
steps = max(2, steps) steps = max(2, steps)
return int(self.strategy.analyze_board(copied_board) / math.log2(steps)) return int(self.strategy.analyze_board(copied_board) / math.log2(steps))
def _backpropagate(self, score: float) -> None: def backpropagate(self, score: float | None = None) -> None:
""" """
Backpropagates the results of the rollout Backpropagates the results of the rollout
:param score: :param score:
@@ -63,14 +66,17 @@ class ClassicMctsNodeV2:
""" """
self.visits += 1 self.visits += 1
# TODO: maybe use score + num of moves together (a win in 1 move is better than a win in 20 moves) # TODO: maybe use score + num of moves together (a win in 1 move is better than a win in 20 moves)
self.score += score
if score is not None:
self.score += score
if self.parent: if self.parent:
self.parent._backpropagate(score) self.parent.backpropagate(score)
def is_fully_expanded(self) -> bool: def is_fully_expanded(self) -> bool:
return len(self.untried_actions) == 0 return len(self.untried_actions) == 0
def _best_child(self) -> 'ClassicMctsNodeV2': def _best_child(self) -> Self:
""" """
Picks the best child according to our policy Picks the best child according to our policy
:return: the best child :return: the best child
@@ -81,7 +87,7 @@ class ClassicMctsNodeV2:
best_child_index = np.argmax(choices_weights) if self.color == chess.WHITE else np.argmin(choices_weights) best_child_index = np.argmax(choices_weights) if self.color == chess.WHITE else np.argmin(choices_weights)
return self.children[best_child_index] return self.children[best_child_index]
def _select_leaf(self) -> 'ClassicMctsNodeV2': def select(self) -> Self:
""" """
Selects a leaf node. Selects a leaf node.
If the node is not expanded is will be expanded. If the node is not expanded is will be expanded.
@@ -90,8 +96,7 @@ class ClassicMctsNodeV2:
current_node = self current_node = self
while not current_node.board.is_game_over(): while not current_node.board.is_game_over():
if not current_node.is_fully_expanded(): if not current_node.is_fully_expanded():
return current_node._expand() return current_node
else: current_node = current_node._best_child()
current_node = current_node._best_child()
return current_node return current_node

View File

@@ -1,24 +1,29 @@
import chess import chess
from chesspp.i_strategy import IStrategy from chesspp.i_strategy import IStrategy
from chesspp.mcts.classic_mcts_node_v2 import ClassicMctsNodeV2 from chesspp.mcts.classic_mcts_node_v2 import ClassicMctsNodeV2
from chesspp.mcts.i_mcts import IMcts
from chesspp.mcts.i_mcts_node import IMctsNode
class ClassicMctsV2: class ClassicMctsV2(IMcts):
def __init__(self, board: chess.Board, color: chess.Color, strategy: IStrategy): def __init__(self, board: chess.Board, color: chess.Color, strategy: IStrategy, seed: int | None = None):
self.board = board super().__init__(board, strategy, seed)
self.color = color self.color = color
self.strategy = strategy self.root = ClassicMctsNodeV2(board, color, strategy, None, None, self.random_state)
self.root = ClassicMctsNodeV2(board, color, strategy)
def build_tree(self, samples: int = 1000): def apply_move(self, move: chess.Move) -> None:
pass
def get_children(self) -> list[IMctsNode]:
return self.root.children
def sample(self, samples: int = 1000):
""" """
Runs the MCTS with the given number of samples Runs the MCTS with the given number of samples
:param samples: number of simulations :param samples: number of simulations
:return: best node containing the best move :return: best node containing the best move
""" """
for i in range(samples): for i in range(samples):
node = self.root._select_leaf() node = self.root.select().expand()
score = node._rollout() score = node.rollout()
node._backpropagate(score) node.backpropagate(score)