Merge remote-tracking branch 'origin/main' into refactor-mcts

# Conflicts:
#	main.py
This commit is contained in:
2024-01-31 18:16:10 +01:00
14 changed files with 678 additions and 117 deletions

52
main.py
View File

@@ -1,12 +1,16 @@
import argparse
import os
import random
import sys
import time
import chess
import chess.engine
import chess.pgn
from chesspp.classic_mcts import ClassicMcts
from chesspp.baysian_mcts import BayesianMcts
from chesspp.engine_factory import EngineEnum, StrategyEnum
from chesspp.random_strategy import RandomStrategy
from chesspp.stockfish_strategy import StockFishStrategy
from chesspp import engine
from chesspp import simulation, eval
from chesspp import util
@@ -89,19 +93,10 @@ def analyze_results(moves: dict):
def test_evaluation():
a, b, s1, s2, n, limit, stockfish_path, proc = read_arguments()
a, b, s1, s2, n, limit, stockfish_path, lc0_path, proc = read_arguments()
limit = engine.Limit(time=limit)
if s1 == StockFishStrategy:
strat1 = StockFishStrategy(stockfish_path)
else:
strat1 = s1()
if s2 == StockFishStrategy:
strat2 = StockFishStrategy(stockfish_path)
else:
strat2 = s1()
evaluator = simulation.Evaluation(a, strat1, b, strat2, limit)
evaluator = simulation.Evaluation(a, s1, b, s2, limit, stockfish_path, lc0_path)
results = evaluator.run(n, proc)
games_played = len(results)
a_wins = len(list(filter(lambda x: x.winner == simulation.Winner.Engine_A, results)))
@@ -125,20 +120,25 @@ def read_arguments():
description='Compare two engines by playing multiple games against each other'
)
engines = {"ClassicMCTS": engine.ClassicMctsEngine, "BayesianMCTS": engine.BayesMctsEngine,
"Random": engine.RandomEngine}
strategies = {"Random": RandomStrategy, "Stockfish": StockFishStrategy}
engines = {"ClassicMCTS": EngineEnum.ClassicMcts, "BayesianMCTS": EngineEnum.BayesianMcts,
"Random": EngineEnum.Random, "Stockfish": EngineEnum.Stockfish, "Lc0": EngineEnum.Lc0}
strategies = {"Random": StrategyEnum.Random, "Stockfish": StrategyEnum.Stockfish, "Lc0": StrategyEnum.Lc0,
"RandomStockfish": StrategyEnum.RandomStockfish, "PESTO": StrategyEnum.Pestos}
if os.name == 'nt':
stockfish_default = "../stockfish/stockfish-windows-x86-64-avx2"
stockfish_default = "stockfish/stockfish-windows-x86-64-avx2"
lc0_default = "lc0/lc0.exe"
else:
stockfish_default = "../stockfish/stockfish-ubuntu-x86-64-avx2"
stockfish_default = "stockfish/stockfish-ubuntu-x86-64-avx2"
lc0_default = "lc0/lc0"
parser.add_argument("--proc", default=2, help="Number of processors to use for simulation, default=1")
parser.add_argument("--time", default=0.5, help="Time limit for each simulation step, default=0.5")
parser.add_argument("-n", default=100, help="Number of games to simulate, default=100")
parser.add_argument("--stockfish", default=stockfish_default,
help=f"Path for stockfish executable, default='{stockfish_default}'")
parser.add_argument("--stockfish_path", default=stockfish_default,
help=f"Path for engine executable, default='{stockfish_default}'")
parser.add_argument("--lc0_path", default=lc0_default,
help=f"Path for engine executable, default='{stockfish_default}'")
parser.add_argument("--engine1", "--e1", help="Engine A for the simulation", choices=engines.keys(), required=True)
parser.add_argument("--engine2", "--e2", help="Engine B for the simulation", choices=engines.keys(), required=True)
parser.add_argument("--strategy1", "--s1", default=list(strategies.keys())[0],
@@ -146,7 +146,7 @@ def read_arguments():
choices=strategies.keys())
parser.add_argument("--strategy2", "--s2", default=list(strategies.keys())[0],
help="Strategy for engine B for the rollout",
choices=strategies)
choices=strategies.keys())
args = parser.parse_args()
engine1 = engines[args.engine1]
@@ -154,7 +154,10 @@ def read_arguments():
strategy1 = strategies[args.strategy1]
strategy2 = strategies[args.strategy2]
return engine1, engine2, strategy1, strategy2, int(args.n), float(args.time), args.stockfish, int(args.proc)
print(engine1, engine2, strategy1, strategy2, int(args.n), float(args.time), args.stockfish_path, args.lc0_path,
int(args.proc))
return engine1, engine2, strategy1, strategy2, int(args.n), float(
args.time), args.stockfish_path, args.lc0_path, int(args.proc)
def main():
@@ -168,3 +171,8 @@ def main():
if __name__ == '__main__':
main()
# Note: prevent endless wait on StockFish process
# by allowing for cleanup of objects (which closes stockfish)
import gc
gc.collect()