add limit to engine
This commit is contained in:
3
main.py
3
main.py
@@ -63,7 +63,8 @@ def analyze_results(moves: dict):
|
|||||||
def test_evaluation():
|
def test_evaluation():
|
||||||
a = engine.ClassicMctsEngine
|
a = engine.ClassicMctsEngine
|
||||||
b = engine.RandomEngine
|
b = engine.RandomEngine
|
||||||
evaluator = simulation.Evaluation(a, b)
|
limit = engine.Limit(time=0.5)
|
||||||
|
evaluator = simulation.Evaluation(a, b, limit)
|
||||||
results = evaluator.run(4)
|
results = evaluator.run(4)
|
||||||
a_results = len(list(filter(lambda x: x.winner == simulation.Winner.Engine_A, results))) / len(results) * 100
|
a_results = len(list(filter(lambda x: x.winner == simulation.Winner.Engine_A, results))) / len(results) * 100
|
||||||
b_results = len(list(filter(lambda x: x.winner == simulation.Winner.Engine_B, results))) / len(results) * 100
|
b_results = len(list(filter(lambda x: x.winner == simulation.Winner.Engine_B, results))) / len(results) * 100
|
||||||
|
|||||||
@@ -2,9 +2,44 @@ from abc import ABC, abstractmethod
|
|||||||
import chess
|
import chess
|
||||||
import chess.engine
|
import chess.engine
|
||||||
import random
|
import random
|
||||||
|
import time
|
||||||
from chesspp.classic_mcts import ClassicMcts
|
from chesspp.classic_mcts import ClassicMcts
|
||||||
|
|
||||||
|
class Limit:
|
||||||
|
""" Class to determine when to stop searching for moves """
|
||||||
|
|
||||||
|
time: float|None
|
||||||
|
""" Search for `time` seconds """
|
||||||
|
|
||||||
|
nodes: int|None
|
||||||
|
""" Search for a limited number of `nodes`"""
|
||||||
|
|
||||||
|
def __init__(self, time: float|None = None, nodes: int|None = None):
|
||||||
|
self.time = time
|
||||||
|
self.nodes = nodes
|
||||||
|
|
||||||
|
def run(self, func, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Run `func` until the limit condition is reached
|
||||||
|
:param func: the func that performs one search iteration
|
||||||
|
:param *args: are passed to `func`
|
||||||
|
:param **kwargs: are passed to `func`
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self.nodes:
|
||||||
|
self._run_nodes(func, *args, **kwargs)
|
||||||
|
elif self.time:
|
||||||
|
self._run_time(func, *args, **kwargs)
|
||||||
|
|
||||||
|
def _run_nodes(self, func, *args, **kwargs):
|
||||||
|
for _ in range(self.nodes):
|
||||||
|
func(*args, **kwargs)
|
||||||
|
|
||||||
|
def _run_time(self, func, *args, **kwargs):
|
||||||
|
start = time.perf_counter_ns()
|
||||||
|
while (time.perf_counter_ns()-start)/1e9 < self.time:
|
||||||
|
func(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class Engine(ABC):
|
class Engine(ABC):
|
||||||
color: chess.Color
|
color: chess.Color
|
||||||
@@ -14,10 +49,11 @@ class Engine(ABC):
|
|||||||
self.color = color
|
self.color = color
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def play(self, board: chess.Board) -> chess.engine.PlayResult:
|
def play(self, board: chess.Board, limit: Limit) -> chess.engine.PlayResult:
|
||||||
"""
|
"""
|
||||||
Return the next action the engine chooses based on the given board
|
Return the next action the engine chooses based on the given board
|
||||||
:param board: the chess board
|
:param board: the chess board
|
||||||
|
:param limit: a limit specifying when to stop searching
|
||||||
:return: the engine's PlayResult
|
:return: the engine's PlayResult
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
@@ -40,9 +76,9 @@ class ClassicMctsEngine(Engine):
|
|||||||
def get_name() -> str:
|
def get_name() -> str:
|
||||||
return "ClassicMctsEngine"
|
return "ClassicMctsEngine"
|
||||||
|
|
||||||
def play(self, board: chess.Board) -> chess.engine.PlayResult:
|
def play(self, board: chess.Board, limit: Limit) -> chess.engine.PlayResult:
|
||||||
mcts_root = ClassicMcts(board, self.color)
|
mcts_root = ClassicMcts(board, self.color)
|
||||||
mcts_root.build_tree()
|
limit.run(lambda: mcts_root.build_tree(samples=1))
|
||||||
best_move = max(mcts_root.children, key=lambda x: x.score).move if board.turn == chess.WHITE else (
|
best_move = max(mcts_root.children, key=lambda x: x.score).move if board.turn == chess.WHITE else (
|
||||||
min(mcts_root.children, key=lambda x: x.score).move)
|
min(mcts_root.children, key=lambda x: x.score).move)
|
||||||
return chess.engine.PlayResult(move=best_move, ponder=None)
|
return chess.engine.PlayResult(move=best_move, ponder=None)
|
||||||
@@ -56,6 +92,6 @@ class RandomEngine(Engine):
|
|||||||
def get_name() -> str:
|
def get_name() -> str:
|
||||||
return "Random"
|
return "Random"
|
||||||
|
|
||||||
def play(self, board: chess.Board) -> chess.engine.PlayResult:
|
def play(self, board: chess.Board, limit: Limit) -> chess.engine.PlayResult:
|
||||||
move = random.choice(list(board.legal_moves))
|
move = random.choice(list(board.legal_moves))
|
||||||
return chess.engine.PlayResult(move=move, ponder=None)
|
return chess.engine.PlayResult(move=move, ponder=None)
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from typing import Tuple, List
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from chesspp.engine import Engine
|
from chesspp.engine import Engine, Limit
|
||||||
|
|
||||||
|
|
||||||
class Winner(Enum):
|
class Winner(Enum):
|
||||||
@@ -21,12 +21,12 @@ class EvaluationResult:
|
|||||||
game: chess.pgn.Game
|
game: chess.pgn.Game
|
||||||
|
|
||||||
|
|
||||||
def simulate_game(white: Engine, black: Engine) -> chess.pgn.Game:
|
def simulate_game(white: Engine, black: Engine, limit: Limit) -> chess.pgn.Game:
|
||||||
board = chess.Board()
|
board = chess.Board()
|
||||||
|
|
||||||
is_white_playing = True
|
is_white_playing = True
|
||||||
while not board.is_game_over():
|
while not board.is_game_over():
|
||||||
play_result = white.play(board) if is_white_playing else black.play(board)
|
play_result = white.play(board, limit) if is_white_playing else black.play(board, limit)
|
||||||
board.push(play_result.move)
|
board.push(play_result.move)
|
||||||
is_white_playing = not is_white_playing
|
is_white_playing = not is_white_playing
|
||||||
|
|
||||||
@@ -37,25 +37,26 @@ def simulate_game(white: Engine, black: Engine) -> chess.pgn.Game:
|
|||||||
|
|
||||||
|
|
||||||
class Evaluation:
|
class Evaluation:
|
||||||
def __init__(self, engine_a: Engine.__class__, engine_b: Engine.__class__):
|
def __init__(self, engine_a: Engine.__class__, engine_b: Engine.__class__, limit: Limit):
|
||||||
self.engine_a = engine_a
|
self.engine_a = engine_a
|
||||||
self.engine_b = engine_b
|
self.engine_b = engine_b
|
||||||
|
self.limit = limit
|
||||||
|
|
||||||
def run(self, n_games=100) -> List[EvaluationResult]:
|
def run(self, n_games=100) -> List[EvaluationResult]:
|
||||||
with mp.Pool(mp.cpu_count()) as pool:
|
with mp.Pool(mp.cpu_count()) as pool:
|
||||||
args = [(self.engine_a, self.engine_b) for i in range(n_games)]
|
args = [(self.engine_a, self.engine_b, self.limit) for i in range(n_games)]
|
||||||
return pool.map(Evaluation._test_simulate, args)
|
return pool.map(Evaluation._test_simulate, args)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _test_simulate(arg: Tuple[Engine.__class__, Engine.__class__]) -> EvaluationResult:
|
def _test_simulate(arg: Tuple[Engine.__class__, Engine.__class__, Limit]) -> EvaluationResult:
|
||||||
engine_a, engine_b = arg
|
engine_a, engine_b, limit = arg
|
||||||
flip_engines = bool(random.getrandbits(1))
|
flip_engines = bool(random.getrandbits(1))
|
||||||
if flip_engines:
|
if flip_engines:
|
||||||
black, white = engine_a(chess.BLACK), engine_b(chess.WHITE)
|
black, white = engine_a(chess.BLACK), engine_b(chess.WHITE)
|
||||||
else:
|
else:
|
||||||
white, black = engine_a(chess.WHITE), engine_b(chess.BLACK)
|
white, black = engine_a(chess.WHITE), engine_b(chess.BLACK)
|
||||||
|
|
||||||
game = simulate_game(white, black)
|
game = simulate_game(white, black, limit)
|
||||||
winner = game.end().board().outcome().winner
|
winner = game.end().board().outcome().winner
|
||||||
|
|
||||||
result = Winner.Draw
|
result = Winner.Draw
|
||||||
|
|||||||
@@ -30,21 +30,22 @@ class Simulate:
|
|||||||
self.white = engine_white
|
self.white = engine_white
|
||||||
self.black = engine_black
|
self.black = engine_black
|
||||||
|
|
||||||
def run(self):
|
def run(self, limit: engine.Limit):
|
||||||
board = chess.Board()
|
board = chess.Board()
|
||||||
|
|
||||||
is_white_playing = True
|
is_white_playing = True
|
||||||
while not board.is_game_over():
|
while not board.is_game_over():
|
||||||
play_result = self.white.play(board) if is_white_playing else self.black.play(board)
|
play_result = self.white.play(board, limit) if is_white_playing else self.black.play(board, limit)
|
||||||
board.push(play_result.move)
|
board.push(play_result.move)
|
||||||
yield board
|
yield board
|
||||||
is_white_playing = not is_white_playing
|
is_white_playing = not is_white_playing
|
||||||
|
|
||||||
|
|
||||||
class WebInterface:
|
class WebInterface:
|
||||||
def __init__(self, white_engine: engine.Engine.__class__, black_engine: engine.Engine.__class__):
|
def __init__(self, white_engine: engine.Engine.__class__, black_engine: engine.Engine.__class__, limit: engine.Limit):
|
||||||
self.white = white_engine
|
self.white = white_engine
|
||||||
self.black = black_engine
|
self.black = black_engine
|
||||||
|
self.limit = limit
|
||||||
|
|
||||||
|
|
||||||
async def handle_index(self, request) -> web.Response:
|
async def handle_index(self, request) -> web.Response:
|
||||||
@@ -70,7 +71,7 @@ class WebInterface:
|
|||||||
|
|
||||||
async def turns():
|
async def turns():
|
||||||
""" Simulates the game and sends the response to the client """
|
""" Simulates the game and sends the response to the client """
|
||||||
runner = Simulate(self.white(chess.WHITE), self.black(chess.BLACK)).run()
|
runner = Simulate(self.white(chess.WHITE), self.black(chess.BLACK)).run(limit)
|
||||||
def sim():
|
def sim():
|
||||||
return next(runner, None)
|
return next(runner, None)
|
||||||
|
|
||||||
@@ -98,4 +99,5 @@ class WebInterface:
|
|||||||
web.run_app(app)
|
web.run_app(app)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
WebInterface(engine.ClassicMctsEngine, engine.ClassicMctsEngine).run_app()
|
limit = engine.Limit(time=0.5)
|
||||||
|
WebInterface(engine.ClassicMctsEngine, engine.ClassicMctsEngine, limit).run_app()
|
||||||
|
|||||||
Reference in New Issue
Block a user