add limit to engine

This commit is contained in:
Stefan Steininger
2024-01-28 20:11:29 +01:00
parent 9893da5b58
commit db89d79902
4 changed files with 59 additions and 19 deletions

View File

@@ -63,7 +63,8 @@ def analyze_results(moves: dict):
def test_evaluation():
a = engine.ClassicMctsEngine
b = engine.RandomEngine
evaluator = simulation.Evaluation(a, b)
limit = engine.Limit(time=0.5)
evaluator = simulation.Evaluation(a, b, limit)
results = evaluator.run(4)
a_results = len(list(filter(lambda x: x.winner == simulation.Winner.Engine_A, results))) / len(results) * 100
b_results = len(list(filter(lambda x: x.winner == simulation.Winner.Engine_B, results))) / len(results) * 100

View File

@@ -2,9 +2,44 @@ from abc import ABC, abstractmethod
import chess
import chess.engine
import random
import time
from chesspp.classic_mcts import ClassicMcts
class Limit:
""" Class to determine when to stop searching for moves """
time: float|None
""" Search for `time` seconds """
nodes: int|None
""" Search for a limited number of `nodes`"""
def __init__(self, time: float|None = None, nodes: int|None = None):
self.time = time
self.nodes = nodes
def run(self, func, *args, **kwargs):
"""
Run `func` until the limit condition is reached
:param func: the func that performs one search iteration
:param *args: are passed to `func`
:param **kwargs: are passed to `func`
"""
if self.nodes:
self._run_nodes(func, *args, **kwargs)
elif self.time:
self._run_time(func, *args, **kwargs)
def _run_nodes(self, func, *args, **kwargs):
for _ in range(self.nodes):
func(*args, **kwargs)
def _run_time(self, func, *args, **kwargs):
start = time.perf_counter_ns()
while (time.perf_counter_ns()-start)/1e9 < self.time:
func(*args, **kwargs)
class Engine(ABC):
color: chess.Color
@@ -14,10 +49,11 @@ class Engine(ABC):
self.color = color
@abstractmethod
def play(self, board: chess.Board) -> chess.engine.PlayResult:
def play(self, board: chess.Board, limit: Limit) -> chess.engine.PlayResult:
"""
Return the next action the engine chooses based on the given board
:param board: the chess board
:param limit: a limit specifying when to stop searching
:return: the engine's PlayResult
"""
pass
@@ -40,9 +76,9 @@ class ClassicMctsEngine(Engine):
def get_name() -> str:
return "ClassicMctsEngine"
def play(self, board: chess.Board) -> chess.engine.PlayResult:
def play(self, board: chess.Board, limit: Limit) -> chess.engine.PlayResult:
mcts_root = ClassicMcts(board, self.color)
mcts_root.build_tree()
limit.run(lambda: mcts_root.build_tree(samples=1))
best_move = max(mcts_root.children, key=lambda x: x.score).move if board.turn == chess.WHITE else (
min(mcts_root.children, key=lambda x: x.score).move)
return chess.engine.PlayResult(move=best_move, ponder=None)
@@ -56,6 +92,6 @@ class RandomEngine(Engine):
def get_name() -> str:
return "Random"
def play(self, board: chess.Board) -> chess.engine.PlayResult:
def play(self, board: chess.Board, limit: Limit) -> chess.engine.PlayResult:
move = random.choice(list(board.legal_moves))
return chess.engine.PlayResult(move=move, ponder=None)

View File

@@ -6,7 +6,7 @@ from typing import Tuple, List
from enum import Enum
from dataclasses import dataclass
from chesspp.engine import Engine
from chesspp.engine import Engine, Limit
class Winner(Enum):
@@ -21,12 +21,12 @@ class EvaluationResult:
game: chess.pgn.Game
def simulate_game(white: Engine, black: Engine) -> chess.pgn.Game:
def simulate_game(white: Engine, black: Engine, limit: Limit) -> chess.pgn.Game:
board = chess.Board()
is_white_playing = True
while not board.is_game_over():
play_result = white.play(board) if is_white_playing else black.play(board)
play_result = white.play(board, limit) if is_white_playing else black.play(board, limit)
board.push(play_result.move)
is_white_playing = not is_white_playing
@@ -37,25 +37,26 @@ def simulate_game(white: Engine, black: Engine) -> chess.pgn.Game:
class Evaluation:
def __init__(self, engine_a: Engine.__class__, engine_b: Engine.__class__):
def __init__(self, engine_a: Engine.__class__, engine_b: Engine.__class__, limit: Limit):
self.engine_a = engine_a
self.engine_b = engine_b
self.limit = limit
def run(self, n_games=100) -> List[EvaluationResult]:
with mp.Pool(mp.cpu_count()) as pool:
args = [(self.engine_a, self.engine_b) for i in range(n_games)]
args = [(self.engine_a, self.engine_b, self.limit) for i in range(n_games)]
return pool.map(Evaluation._test_simulate, args)
@staticmethod
def _test_simulate(arg: Tuple[Engine.__class__, Engine.__class__]) -> EvaluationResult:
engine_a, engine_b = arg
def _test_simulate(arg: Tuple[Engine.__class__, Engine.__class__, Limit]) -> EvaluationResult:
engine_a, engine_b, limit = arg
flip_engines = bool(random.getrandbits(1))
if flip_engines:
black, white = engine_a(chess.BLACK), engine_b(chess.WHITE)
else:
white, black = engine_a(chess.WHITE), engine_b(chess.BLACK)
game = simulate_game(white, black)
game = simulate_game(white, black, limit)
winner = game.end().board().outcome().winner
result = Winner.Draw

View File

@@ -30,21 +30,22 @@ class Simulate:
self.white = engine_white
self.black = engine_black
def run(self):
def run(self, limit: engine.Limit):
board = chess.Board()
is_white_playing = True
while not board.is_game_over():
play_result = self.white.play(board) if is_white_playing else self.black.play(board)
play_result = self.white.play(board, limit) if is_white_playing else self.black.play(board, limit)
board.push(play_result.move)
yield board
is_white_playing = not is_white_playing
class WebInterface:
def __init__(self, white_engine: engine.Engine.__class__, black_engine: engine.Engine.__class__):
def __init__(self, white_engine: engine.Engine.__class__, black_engine: engine.Engine.__class__, limit: engine.Limit):
self.white = white_engine
self.black = black_engine
self.limit = limit
async def handle_index(self, request) -> web.Response:
@@ -70,7 +71,7 @@ class WebInterface:
async def turns():
""" Simulates the game and sends the response to the client """
runner = Simulate(self.white(chess.WHITE), self.black(chess.BLACK)).run()
runner = Simulate(self.white(chess.WHITE), self.black(chess.BLACK)).run(limit)
def sim():
return next(runner, None)
@@ -98,4 +99,5 @@ class WebInterface:
web.run_app(app)
if __name__ == '__main__':
WebInterface(engine.ClassicMctsEngine, engine.ClassicMctsEngine).run_app()
limit = engine.Limit(time=0.5)
WebInterface(engine.ClassicMctsEngine, engine.ClassicMctsEngine, limit).run_app()