added mcts and strategy base classes

This commit is contained in:
2024-01-26 18:02:44 +01:00
parent e4fa09bac3
commit 662da27f72
5 changed files with 29 additions and 33 deletions

View File

@@ -2,32 +2,13 @@ import chess
import random import random
import eval import eval
import engine import engine
import IStrategy
import numpy as np import numpy as np
from abc import ABC, abstractmethod
class IMcts(ABC): class ClassicMcts:
def __init__(self, board: chess.Board, strategy: IStrategy): def __init__(self, board: chess.Board, parent=None, move: chess.Move | None = None,
self.board = board random_state: int | None = None):
@abstractmethod
def sample(self, runs: int = 1000) -> None:
pass
@abstractmethod
def apply_move(self, move: chess.Move) -> None:
pass
@abstractmethod
def get_children(self) -> list['Mcts']:
pass
class MCTSNode:
def __init__(self, board: chess.Board, parent = None, move: chess.Move | None = None, random_state: int | None = None):
self.random = random.Random(random_state) self.random = random.Random(random_state)
self.board = board self.board = board
self.parent = parent self.parent = parent
@@ -38,7 +19,7 @@ class MCTSNode:
self.untried_actions = self.legal_moves self.untried_actions = self.legal_moves
self.score = 0 self.score = 0
def _expand(self) -> 'MCTSNode': def _expand(self) -> 'ClassicMcts':
""" """
Expands the node, i.e., choose an action and apply it to the board Expands the node, i.e., choose an action and apply it to the board
:return: :return:
@@ -47,7 +28,7 @@ class MCTSNode:
self.untried_actions.remove(move) self.untried_actions.remove(move)
next_board = self.board.copy() next_board = self.board.copy()
next_board.push(move) next_board.push(move)
child_node = MCTSNode(next_board, parent=self, move=move) child_node = ClassicMcts(next_board, parent=self, move=move)
self.children.append(child_node) self.children.append(child_node)
return child_node return child_node
@@ -84,7 +65,7 @@ class MCTSNode:
def is_fully_expanded(self) -> bool: def is_fully_expanded(self) -> bool:
return len(self.untried_actions) == 0 return len(self.untried_actions) == 0
def _best_child(self) -> 'MCTSNode': def _best_child(self) -> 'ClassicMcts':
""" """
Picks the best child according to our policy Picks the best child according to our policy
:return: the best child :return: the best child
@@ -94,7 +75,7 @@ class MCTSNode:
for c in self.children] for c in self.children]
return self.children[np.argmax(choices_weights)] return self.children[np.argmax(choices_weights)]
def _select_leaf(self) -> 'MCTSNode': def _select_leaf(self) -> 'ClassicMcts':
""" """
Selects a leaf node. Selects a leaf node.
If the node is not expanded is will be expanded. If the node is not expanded is will be expanded.
@@ -109,7 +90,7 @@ class MCTSNode:
return current_node return current_node
def build_tree(self, samples: int = 1000) -> 'MCTSNode': def build_tree(self, samples: int = 1000) -> 'ClassicMcts':
""" """
Runs the MCTS with the given number of samples Runs the MCTS with the given number of samples
:param samples: number of simulations :param samples: number of simulations

View File

@@ -134,7 +134,8 @@ def check_endgame(board: chess.Board) -> bool:
else: else:
minors_black += 1 minors_black += 1
return (queens_black == 0 and queens_white == 0) or ((queens_black >= 1 and minors_black <= 1) or (queens_white >= 1 and minors_white <= 1)) return (queens_black == 0 and queens_white == 0) or ((queens_black >= 1 >= minors_black) or (
queens_white >= 1 >= minors_white))
def score_manual(board: chess.Board) -> int: def score_manual(board: chess.Board) -> int:
@@ -177,6 +178,6 @@ def score_stockfish(board: chess.Board) -> chess.engine.PovScore:
:return: :return:
""" """
engine = chess.engine.SimpleEngine.popen_uci("./stockfish/stockfish-ubuntu-x86-64-avx2") engine = chess.engine.SimpleEngine.popen_uci("./stockfish/stockfish-ubuntu-x86-64-avx2")
info = engine.analyse(board, chess.engine.Limit(depth=2)) info = engine.analyse(board, chess.engine.Limit(depth=0))
engine.quit() engine.quit()
return info["score"] return info["score"]

View File

@@ -1,6 +1,6 @@
import chess import chess
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from IStrategy import IStrategy from i_strategy import IStrategy
class IMcts(ABC): class IMcts(ABC):
@@ -10,12 +10,26 @@ class IMcts(ABC):
@abstractmethod @abstractmethod
def sample(self, runs: int = 1000) -> None: def sample(self, runs: int = 1000) -> None:
"""
Run the MCTS simulation
:param runs: number of runs
:return:
"""
pass pass
@abstractmethod @abstractmethod
def apply_move(self, move: chess.Move) -> None: def apply_move(self, move: chess.Move) -> None:
"""
Apply the move to the chess board
:param move: move to apply
:return:
"""
pass pass
@abstractmethod @abstractmethod
def get_children(self) -> list['Mcts']: def get_children(self) -> list['Mcts']:
"""
Return the immediate children of the root node
:return: list of immediate children of mcts root
"""
pass pass

View File

@@ -9,7 +9,7 @@ class ProbStockfish(MinimalEngine):
moves = {} moves = {}
untried_moves = list(board.legal_moves) untried_moves = list(board.legal_moves)
for move in untried_moves: for move in untried_moves:
mean, std = engine.simulate_stockfish_prob(board, move, 10, 4) mean, std = engine.simulate_game(board, move, 10)
moves[move] = (mean, std) moves[move] = (mean, std)
return self.get_best_move(moves) return self.get_best_move(moves)

View File

@@ -1,6 +1,6 @@
import chess import chess
import chess.engine import chess.engine
from mcts import MCTSNode from classic_mcts import ClassicMcts
import engine import engine
import eval import eval
@@ -8,7 +8,7 @@ import eval
def test_mcts(): def test_mcts():
fools_mate = "rnbqkbnr/pppp1ppp/4p3/8/5PP1/8/PPPPP2P/RNBQKBNR b KQkq f3 0 2" fools_mate = "rnbqkbnr/pppp1ppp/4p3/8/5PP1/8/PPPPP2P/RNBQKBNR b KQkq f3 0 2"
board = chess.Board(fools_mate) board = chess.Board(fools_mate)
mcts_root = MCTSNode(board) mcts_root = ClassicMcts(board)
mcts_root.build_tree() mcts_root.build_tree()
sorted_moves = sorted(mcts_root.children, key=lambda x: x.move.uci()) sorted_moves = sorted(mcts_root.children, key=lambda x: x.move.uci())
for c in sorted_moves: for c in sorted_moves: