From f521c707d0afb085e7cca0e6f7ea47737633bdec Mon Sep 17 00:00:00 2001 From: Lukas Wieser Date: Thu, 1 Feb 2024 02:06:03 +0100 Subject: [PATCH] Create a new ClassicMcts, which is split into two files --- chesspp/engine.py | 31 ++++++++- chesspp/engine_factory.py | 9 ++- chesspp/mcts/classic_mcts_node_v2.py | 97 ++++++++++++++++++++++++++++ chesspp/mcts/classic_mcts_v2.py | 24 +++++++ main.py | 2 +- 5 files changed, 160 insertions(+), 3 deletions(-) create mode 100644 chesspp/mcts/classic_mcts_node_v2.py create mode 100644 chesspp/mcts/classic_mcts_v2.py diff --git a/chesspp/engine.py b/chesspp/engine.py index 5ce3b9a..0364d21 100644 --- a/chesspp/engine.py +++ b/chesspp/engine.py @@ -10,8 +10,11 @@ from stockfish import Stockfish from chesspp.mcts.baysian_mcts import BayesianMcts from chesspp.mcts.classic_mcts import ClassicMcts from chesspp.i_strategy import IStrategy + from typing import Dict +from chesspp.mcts.classic_mcts_v2 import ClassicMctsV2 + class Limit: """ Class to determine when to stop searching for moves """ @@ -156,6 +159,31 @@ class ClassicMctsEngine(Engine): return chess.engine.PlayResult(move=best_move, ponder=None) +class ClassicMctsEngineV2(Engine): + def __init__(self, board: chess.Board, color: chess.Color, strategy: IStrategy): + super().__init__(board, color, strategy) + self.node_counts = [] + + @staticmethod + def get_name() -> str: + return "ClassicMctsEngine V2" + + def play(self, board: chess.Board, limit: Limit) -> chess.engine.PlayResult: + mcts = ClassicMctsV2(board, self.color, self.strategy) + node_count = 0 + + def do(): + nonlocal node_count + mcts.build_tree(1) + node_count += 1 + + limit.run(do) + self.node_counts.append(node_count) + best_move = max(mcts.root.children, key=lambda x: x.score).move if board.turn == chess.WHITE else ( + min(mcts.root.children, key=lambda x: x.score).move) + return chess.engine.PlayResult(move=best_move, ponder=None) + + class RandomEngine(Engine): def __init__(self, board: chess.Board, color: 
chess.Color, strategy: IStrategy): super().__init__(board, color, strategy) @@ -170,7 +198,8 @@ class RandomEngine(Engine): class StockFishEngine(Engine): - def __init__(self, board: chess.Board, color: chess, stockfish_elo: int, path="../stockfish/stockfish-ubuntu-x86-64-avx2"): + def __init__(self, board: chess.Board, color: chess, stockfish_elo: int, + path="../stockfish/stockfish-ubuntu-x86-64-avx2"): super().__init__(board, color, None) self.stockfish = Stockfish(path) self.stockfish.set_elo_rating(stockfish_elo) diff --git a/chesspp/engine_factory.py b/chesspp/engine_factory.py index df31818..9a97cc7 100644 --- a/chesspp/engine_factory.py +++ b/chesspp/engine_factory.py @@ -16,6 +16,7 @@ class EngineEnum(Enum): Stockfish = 2 Lc0 = 3 Random = 4 + ClassicMctsV2 = 5 class StrategyEnum(Enum): @@ -47,6 +48,9 @@ class EngineFactory: case EngineEnum.ClassicMcts: return EngineFactory.classic_mcts(color, strategy) + case EngineEnum.ClassicMctsV2: + return EngineFactory.classic_mcts_v2(color, strategy) + case EngineEnum.BayesianMcts: return EngineFactory.bayesian_mcts(color, strategy) @@ -72,6 +76,10 @@ class EngineFactory: def classic_mcts(color: chess.Color, strategy: IStrategy) -> Engine: return ClassicMctsEngine(chess.Board(), color, strategy) + @staticmethod + def classic_mcts_v2(color: chess.Color, strategy: IStrategy, board: chess.Board | None = chess.Board()) -> Engine: + return ClassicMctsEngineV2(board, color, strategy) + @staticmethod def _get_random_strategy(rollout_depth: int) -> IStrategy: return RandomStrategy(random.Random(), rollout_depth) @@ -91,4 +99,3 @@ class EngineFactory: @staticmethod def _get_pesto_strategy(rollout_depth: int) -> IStrategy: return PestoStrategy(rollout_depth) - diff --git a/chesspp/mcts/classic_mcts_node_v2.py b/chesspp/mcts/classic_mcts_node_v2.py new file mode 100644 index 0000000..91e9320 --- /dev/null +++ b/chesspp/mcts/classic_mcts_node_v2.py @@ -0,0 +1,97 @@ +import math +import random + +import chess +import numpy as np 
+ +from chesspp.i_strategy import IStrategy + + +class ClassicMctsNodeV2: + def __init__(self, board: chess.Board, color: chess.Color, strategy: IStrategy, parent=None, move: chess.Move | None = None, + random_state: int | None = None, depth: int = 0): + self.random = random.Random(random_state) + self.board = board + self.color = color + self.strategy = strategy + self.parent = parent + self.move = move + self.children = [] + self.visits = 0 + self.legal_moves = list(board.legal_moves) + self.untried_actions = self.legal_moves + self.score = 0 + self.depth = depth + + def _expand(self) -> 'ClassicMctsNodeV2': + """ + Expands the node, i.e., choose an action and apply it to the board + :return: + """ + move = self.random.choice(self.untried_actions) + self.untried_actions.remove(move) + next_board = self.board.copy() + next_board.push(move) + child_node = ClassicMctsNodeV2(next_board, color=not self.color, strategy=self.strategy, parent=self, move=move, depth=self.depth+1) + self.children.append(child_node) + return child_node + + def _rollout(self, rollout_depth: int = 4) -> int: + """ + Rolls out the node by simulating a game for a given depth. + Sometimes this step is called 'simulation' or 'playout'. 
+ :return: the score of the rolled out game + """ + copied_board = self.board.copy() + steps = self.depth + for i in range(rollout_depth): + if copied_board.is_game_over(): + break + + m = self.strategy.pick_next_move(copied_board) + copied_board.push(m) + steps += 1 + + steps = max(2, steps) + return int(self.strategy.analyze_board(copied_board) / math.log2(steps)) + + def _backpropagate(self, score: float) -> None: + """ + Backpropagates the results of the rollout + :param score: + :return: + """ + self.visits += 1 + # TODO: maybe use score + num of moves together (a win in 1 move is better than a win in 20 moves) + self.score += score + if self.parent: + self.parent._backpropagate(score) + + def is_fully_expanded(self) -> bool: + return len(self.untried_actions) == 0 + + def _best_child(self) -> 'ClassicMctsNodeV2': + """ + Picks the best child according to our policy + :return: the best child + """ + # NOTE: maybe clamp the score between [-1, +1] instead of [-inf, +inf] + choices_weights = [(c.score / c.visits) + np.sqrt(((2 * np.log(self.visits)) / c.visits)) + for c in self.children] + best_child_index = np.argmax(choices_weights) if self.color == chess.WHITE else np.argmin(choices_weights) + return self.children[best_child_index] + + def _select_leaf(self) -> 'ClassicMctsNodeV2': + """ + Selects a leaf node. + If the node is not expanded it will be expanded. 
+ :return: Leaf node + """ + current_node = self + while not current_node.board.is_game_over(): + if not current_node.is_fully_expanded(): + return current_node._expand() + else: + current_node = current_node._best_child() + + return current_node diff --git a/chesspp/mcts/classic_mcts_v2.py b/chesspp/mcts/classic_mcts_v2.py new file mode 100644 index 0000000..2dfe737 --- /dev/null +++ b/chesspp/mcts/classic_mcts_v2.py @@ -0,0 +1,24 @@ +import chess +from chesspp.i_strategy import IStrategy +from chesspp.mcts.classic_mcts_node_v2 import ClassicMctsNodeV2 + + +class ClassicMctsV2: + def __init__(self, board: chess.Board, color: chess.Color, strategy: IStrategy): + self.board = board + self.color = color + self.strategy = strategy + self.root = ClassicMctsNodeV2(board, color, strategy) + + def build_tree(self, samples: int = 1000): + """ + Runs the MCTS with the given number of samples + :param samples: number of simulations + :return: None; the search tree rooted at self.root is grown in place + """ + for i in range(samples): + node = self.root._select_leaf() + score = node._rollout() + node._backpropagate(score) + + diff --git a/main.py b/main.py index 93d79c3..d1512e3 100644 --- a/main.py +++ b/main.py @@ -132,7 +132,7 @@ def read_arguments(): description='Compare two engines by playing multiple games against each other' ) - engines = {"ClassicMCTS": EngineEnum.ClassicMcts, "BayesianMCTS": EngineEnum.BayesianMcts, + engines = {"ClassicMCTS": EngineEnum.ClassicMcts, "BayesianMCTS": EngineEnum.BayesianMcts, "ClassicMCTSV2": EngineEnum.ClassicMctsV2, "Random": EngineEnum.Random, "Stockfish": EngineEnum.Stockfish, "Lc0": EngineEnum.Lc0} strategies = {"Random": StrategyEnum.Random, "Stockfish": StrategyEnum.Stockfish, "Lc0": StrategyEnum.Lc0, "RandomStockfish": StrategyEnum.RandomStockfish, "PESTO": StrategyEnum.Pestos}