added engine class and simulate_game function
This commit is contained in:
@@ -1,16 +1,17 @@
|
|||||||
import chess
|
import chess
|
||||||
import random
|
import random
|
||||||
import eval
|
import eval
|
||||||
import engine
|
import util
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
class ClassicMcts:
|
class ClassicMcts:
|
||||||
|
|
||||||
def __init__(self, board: chess.Board, parent=None, move: chess.Move | None = None,
|
def __init__(self, board: chess.Board, color: chess.Color, parent=None, move: chess.Move | None = None,
|
||||||
random_state: int | None = None):
|
random_state: int | None = None):
|
||||||
self.random = random.Random(random_state)
|
self.random = random.Random(random_state)
|
||||||
self.board = board
|
self.board = board
|
||||||
|
self.color = color
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.move = move
|
self.move = move
|
||||||
self.children = []
|
self.children = []
|
||||||
@@ -28,7 +29,7 @@ class ClassicMcts:
|
|||||||
self.untried_actions.remove(move)
|
self.untried_actions.remove(move)
|
||||||
next_board = self.board.copy()
|
next_board = self.board.copy()
|
||||||
next_board.push(move)
|
next_board.push(move)
|
||||||
child_node = ClassicMcts(next_board, parent=self, move=move)
|
child_node = ClassicMcts(next_board, color=self.color, parent=self, move=move)
|
||||||
self.children.append(child_node)
|
self.children.append(child_node)
|
||||||
return child_node
|
return child_node
|
||||||
|
|
||||||
@@ -44,7 +45,7 @@ class ClassicMcts:
|
|||||||
if copied_board.is_game_over():
|
if copied_board.is_game_over():
|
||||||
break
|
break
|
||||||
|
|
||||||
m = engine.pick_move(copied_board)
|
m = util.pick_move(copied_board)
|
||||||
copied_board.push(m)
|
copied_board.push(m)
|
||||||
steps += 1
|
steps += 1
|
||||||
|
|
||||||
@@ -73,7 +74,8 @@ class ClassicMcts:
|
|||||||
# NOTE: maybe clamp the score between [-1, +1] instead of [-inf, +inf]
|
# NOTE: maybe clamp the score between [-1, +1] instead of [-inf, +inf]
|
||||||
choices_weights = [(c.score / c.visits) + np.sqrt(((2 * np.log(self.visits)) / c.visits))
|
choices_weights = [(c.score / c.visits) + np.sqrt(((2 * np.log(self.visits)) / c.visits))
|
||||||
for c in self.children]
|
for c in self.children]
|
||||||
return self.children[np.argmax(choices_weights)]
|
best_child_index = np.argmax(choices_weights) if self.color == chess.WHITE else np.argmin(choices_weights)
|
||||||
|
return self.children[best_child_index]
|
||||||
|
|
||||||
def _select_leaf(self) -> 'ClassicMcts':
|
def _select_leaf(self) -> 'ClassicMcts':
|
||||||
"""
|
"""
|
||||||
|
|||||||
109
engine.py
109
engine.py
@@ -1,80 +1,45 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
import chess
|
import chess
|
||||||
import chess.engine
|
import chess.engine
|
||||||
import random
|
from classic_mcts import ClassicMcts
|
||||||
import eval
|
|
||||||
import numpy as np
|
|
||||||
from stockfish import Stockfish
|
|
||||||
|
|
||||||
|
|
||||||
def pick_move(board: chess.Board) -> chess.Move | None:
|
class Engine(ABC):
|
||||||
"""
|
|
||||||
Pick a random move
|
color: chess.Color
|
||||||
:param board: chess board
|
"""The side the engine plays (``chess.WHITE`` or ``chess.BLACK``)."""
|
||||||
:return: a valid move or None if no valid move available
|
|
||||||
"""
|
def __init__(self, color: chess.Color):
|
||||||
if len(list(board.legal_moves)) == 0:
|
self.color = color
|
||||||
return None
|
|
||||||
return random.choice(list(board.legal_moves))
|
@abstractmethod
|
||||||
|
def play(self, board: chess.Board) -> chess.engine.PlayResult:
|
||||||
|
"""
|
||||||
|
Return the next action the engine chooses based on the given board
|
||||||
|
:param board: the chess board
|
||||||
|
:return: the engine's PlayResult
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_name(self) -> str:
|
||||||
|
"""
|
||||||
|
Return the engine's name
|
||||||
|
:return: the engine's name
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def simulate_game(board: chess.Board, move: chess.Move, depth: int):
|
class ClassicMctsEngine(Engine):
|
||||||
"""
|
def __init__(self, color: chess.Color):
|
||||||
Simulate a game starting with the given move
|
super().__init__(color)
|
||||||
:param board: chess board
|
|
||||||
:param move: chosen move
|
|
||||||
:param depth: number of moves that should be simulated after playing the chosen move
|
|
||||||
:return: the score for the simulated game
|
|
||||||
"""
|
|
||||||
engine = chess.engine.SimpleEngine.popen_uci("./stockfish/stockfish-ubuntu-x86-64-avx2")
|
|
||||||
board.push(move)
|
|
||||||
for i in range(depth):
|
|
||||||
if board.is_game_over():
|
|
||||||
engine.quit()
|
|
||||||
return
|
|
||||||
r = engine.play(board, chess.engine.Limit(depth=2))
|
|
||||||
board.push(r.move)
|
|
||||||
|
|
||||||
engine.quit()
|
def get_name(self) -> str:
|
||||||
|
return "ClassicMctsEngine"
|
||||||
|
|
||||||
|
def play(self, board: chess.Board) -> chess.engine.PlayResult:
|
||||||
def simulate_stockfish_prob(board: chess.Board, move: chess.Move, games: int = 10, depth: int = 10) -> (float, float):
|
mcts_root = ClassicMcts(board, self.color)
|
||||||
"""
|
mcts_root.build_tree()
|
||||||
Simulate a game using
|
best_move = max(mcts_root.children, key=lambda x: x.score).move if board.turn == chess.WHITE else (
|
||||||
:param board: chess board
|
min(mcts_root.children, key=lambda x: x.score).move)
|
||||||
:param move: chosen move
|
return chess.engine.PlayResult(move=best_move, ponder=None)
|
||||||
:param games: number of games that should be simulated after playing the move
|
|
||||||
:param depth: simulation depth per game
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
board.push(move)
|
|
||||||
copied_board = board.copy()
|
|
||||||
scores = []
|
|
||||||
|
|
||||||
stockfish = Stockfish("./stockfish/stockfish-ubuntu-x86-64-avx2", depth=2, parameters={"Threads": 8, "Hash": 2048})
|
|
||||||
stockfish.set_elo_rating(1200)
|
|
||||||
stockfish.set_fen_position(board.fen())
|
|
||||||
|
|
||||||
def reset_game():
|
|
||||||
nonlocal scores, copied_board, board
|
|
||||||
score = eval.score_stockfish(copied_board).white().score(mate_score=100_000)
|
|
||||||
scores.append(score)
|
|
||||||
copied_board = board.copy()
|
|
||||||
stockfish.set_fen_position(board.fen())
|
|
||||||
|
|
||||||
for _ in range(games):
|
|
||||||
for d in range(depth):
|
|
||||||
if copied_board.is_game_over() or d == depth - 1:
|
|
||||||
reset_game()
|
|
||||||
break
|
|
||||||
|
|
||||||
if d == depth - 1:
|
|
||||||
reset_game()
|
|
||||||
|
|
||||||
top_moves = stockfish.get_top_moves(3)
|
|
||||||
chosen_move = random.choice(top_moves)['Move']
|
|
||||||
stockfish.make_moves_from_current_position([chosen_move])
|
|
||||||
copied_board.push(chess.Move.from_uci(chosen_move))
|
|
||||||
|
|
||||||
print(scores)
|
|
||||||
# TODO: return distribution here?
|
|
||||||
return np.array(scores).mean(), np.array(scores).std()
|
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ class IMcts(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_children(self) -> list['Mcts']:
|
def get_children(self) -> list['IMcts']:
|
||||||
"""
|
"""
|
||||||
Return the immediate children of the root node
|
Return the immediate children of the root node
|
||||||
:return: list of immediate children of mcts root
|
:return: list of immediate children of mcts root
|
||||||
|
|||||||
40
main.py
40
main.py
@@ -3,12 +3,37 @@ import chess.engine
|
|||||||
from classic_mcts import ClassicMcts
|
from classic_mcts import ClassicMcts
|
||||||
import engine
|
import engine
|
||||||
import eval
|
import eval
|
||||||
|
import util
|
||||||
|
|
||||||
|
|
||||||
|
def simulate_game(white: engine.Engine, black: engine.Engine) -> chess.pgn.Game:
|
||||||
|
board = chess.Board()
|
||||||
|
|
||||||
|
is_white_playing = True
|
||||||
|
while not board.is_game_over():
|
||||||
|
play_result = white.play(board) if is_white_playing else black.play(board)
|
||||||
|
board.push(play_result.move)
|
||||||
|
print(board)
|
||||||
|
print()
|
||||||
|
is_white_playing = not is_white_playing
|
||||||
|
|
||||||
|
game = chess.pgn.Game.from_board(board)
|
||||||
|
game.headers['White'] = white.get_name()
|
||||||
|
game.headers['Black'] = black.get_name()
|
||||||
|
return game
|
||||||
|
|
||||||
|
|
||||||
|
def test_simulate():
|
||||||
|
white = engine.ClassicMctsEngine(chess.WHITE)
|
||||||
|
black = engine.ClassicMctsEngine(chess.BLACK)
|
||||||
|
game = simulate_game(white, black)
|
||||||
|
print(game)
|
||||||
|
|
||||||
|
|
||||||
def test_mcts():
|
def test_mcts():
|
||||||
fools_mate = "rnbqkbnr/pppp1ppp/4p3/8/5PP1/8/PPPPP2P/RNBQKBNR b KQkq f3 0 2"
|
fools_mate = "rnbqkbnr/pppp1ppp/4p3/8/5PP1/8/PPPPP2P/RNBQKBNR b KQkq f3 0 2"
|
||||||
board = chess.Board(fools_mate)
|
board = chess.Board(fools_mate)
|
||||||
mcts_root = ClassicMcts(board)
|
mcts_root = ClassicMcts(board, chess.BLACK)
|
||||||
mcts_root.build_tree()
|
mcts_root.build_tree()
|
||||||
sorted_moves = sorted(mcts_root.children, key=lambda x: x.move.uci())
|
sorted_moves = sorted(mcts_root.children, key=lambda x: x.move.uci())
|
||||||
for c in sorted_moves:
|
for c in sorted_moves:
|
||||||
@@ -21,7 +46,7 @@ def test_stockfish():
|
|||||||
moves = {}
|
moves = {}
|
||||||
untried_moves = list(board.legal_moves)
|
untried_moves = list(board.legal_moves)
|
||||||
for move in untried_moves:
|
for move in untried_moves:
|
||||||
engine.simulate_game(board, move, 100)
|
util.simulate_game(board, move, 100)
|
||||||
moves[move] = board
|
moves[move] = board
|
||||||
board = chess.Board(fools_mate)
|
board = chess.Board(fools_mate)
|
||||||
|
|
||||||
@@ -35,7 +60,7 @@ def test_stockfish_prob():
|
|||||||
moves = {}
|
moves = {}
|
||||||
untried_moves = list(board.legal_moves)
|
untried_moves = list(board.legal_moves)
|
||||||
for move in untried_moves:
|
for move in untried_moves:
|
||||||
mean, std = engine.simulate_stockfish_prob(board, move, 10, 4)
|
mean, std = util.simulate_stockfish_prob(board, move, 10, 4)
|
||||||
moves[move] = (mean, std)
|
moves[move] = (mean, std)
|
||||||
board = chess.Board(fools_mate)
|
board = chess.Board(fools_mate)
|
||||||
|
|
||||||
@@ -47,14 +72,15 @@ def test_stockfish_prob():
|
|||||||
def analyze_results(moves: dict):
|
def analyze_results(moves: dict):
|
||||||
for m, b in moves.items():
|
for m, b in moves.items():
|
||||||
manual_score = eval.score_manual(b)
|
manual_score = eval.score_manual(b)
|
||||||
engine_score = eval.score_stockfish(b).white()
|
engine_score = eval.score_stockfish(b).white().score(mate_score=100_000)
|
||||||
print(f"score for move {m}: manual_score={manual_score}, engine_score={engine_score}")
|
print(f"score for move {m}: manual_score={manual_score}, engine_score={engine_score}")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
test_mcts()
|
test_simulate()
|
||||||
test_stockfish()
|
# test_mcts()
|
||||||
test_stockfish_prob()
|
# test_stockfish()
|
||||||
|
# test_stockfish_prob()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
79
util.py
Normal file
79
util.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
import chess
|
||||||
|
import chess.engine
|
||||||
|
from stockfish import Stockfish
|
||||||
|
import numpy as np
|
||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
|
def pick_move(board: chess.Board) -> chess.Move | None:
|
||||||
|
"""
|
||||||
|
Pick a random move
|
||||||
|
:param board: chess board
|
||||||
|
:return: a valid move or None if no valid move available
|
||||||
|
"""
|
||||||
|
if len(list(board.legal_moves)) == 0:
|
||||||
|
return None
|
||||||
|
return random.choice(list(board.legal_moves))
|
||||||
|
|
||||||
|
|
||||||
|
def simulate_game(board: chess.Board, move: chess.Move, depth: int):
|
||||||
|
"""
|
||||||
|
Simulate a game starting with the given move
|
||||||
|
:param board: chess board
|
||||||
|
:param move: chosen move
|
||||||
|
:param depth: number of moves that should be simulated after playing the chosen move
|
||||||
|
:return: the score for the simulated game
|
||||||
|
"""
|
||||||
|
engine = chess.engine.SimpleEngine.popen_uci("./stockfish/stockfish-ubuntu-x86-64-avx2")
|
||||||
|
board.push(move)
|
||||||
|
for i in range(depth):
|
||||||
|
if board.is_game_over():
|
||||||
|
engine.quit()
|
||||||
|
return
|
||||||
|
r = engine.play(board, chess.engine.Limit(depth=2))
|
||||||
|
board.push(r.move)
|
||||||
|
|
||||||
|
engine.quit()
|
||||||
|
|
||||||
|
|
||||||
|
def simulate_stockfish_prob(board: chess.Board, move: chess.Move, games: int = 10, depth: int = 10) -> (float, float):
|
||||||
|
"""
|
||||||
|
Simulate a game using
|
||||||
|
:param board: chess board
|
||||||
|
:param move: chosen move
|
||||||
|
:param games: number of games that should be simulated after playing the move
|
||||||
|
:param depth: simulation depth per game
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
board.push(move)
|
||||||
|
copied_board = board.copy()
|
||||||
|
scores = []
|
||||||
|
|
||||||
|
stockfish = Stockfish("./stockfish/stockfish-ubuntu-x86-64-avx2", depth=2, parameters={"Threads": 8, "Hash": 2048})
|
||||||
|
stockfish.set_elo_rating(1200)
|
||||||
|
stockfish.set_fen_position(board.fen())
|
||||||
|
|
||||||
|
def reset_game():
|
||||||
|
nonlocal scores, copied_board, board
|
||||||
|
score = eval.score_stockfish(copied_board).white().score(mate_score=100_000)
|
||||||
|
scores.append(score)
|
||||||
|
copied_board = board.copy()
|
||||||
|
stockfish.set_fen_position(board.fen())
|
||||||
|
|
||||||
|
for _ in range(games):
|
||||||
|
for d in range(depth):
|
||||||
|
if copied_board.is_game_over() or d == depth - 1:
|
||||||
|
reset_game()
|
||||||
|
break
|
||||||
|
|
||||||
|
if d == depth - 1:
|
||||||
|
reset_game()
|
||||||
|
|
||||||
|
top_moves = stockfish.get_top_moves(3)
|
||||||
|
chosen_move = random.choice(top_moves)['Move']
|
||||||
|
stockfish.make_moves_from_current_position([chosen_move])
|
||||||
|
copied_board.push(chess.Move.from_uci(chosen_move))
|
||||||
|
|
||||||
|
print(scores)
|
||||||
|
# TODO: return distribution here?
|
||||||
|
return np.array(scores).mean(), np.array(scores).std()
|
||||||
Reference in New Issue
Block a user