Add hypothesis test
This commit is contained in:
@@ -1,9 +1,13 @@
|
|||||||
|
from typing import TypedDict
|
||||||
|
|
||||||
import chess
|
import chess
|
||||||
import chess.engine
|
import chess.engine
|
||||||
from stockfish import Stockfish
|
from stockfish import Stockfish
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
from scipy.stats import binomtest
|
||||||
|
|
||||||
|
|
||||||
def pick_move(board: chess.Board) -> chess.Move | None:
|
def pick_move(board: chess.Board) -> chess.Move | None:
|
||||||
"""
|
"""
|
||||||
@@ -77,3 +81,38 @@ def simulate_stockfish_prob(board: chess.Board, move: chess.Move, games: int = 1
|
|||||||
print(scores)
|
print(scores)
|
||||||
# TODO: return distribution here?
|
# TODO: return distribution here?
|
||||||
return np.array(scores).mean(), np.array(scores).std()
|
return np.array(scores).mean(), np.array(scores).std()
|
||||||
|
|
||||||
|
|
||||||
|
class HypothesisTestResult(TypedDict):
    """Outcome of :func:`hypothesis_test`."""

    # number of Bernoulli trials actually fed into the binomial test
    trials: int
    # two-sided p-value under the null hypothesis of equal strength
    pvalue: float
    # observed success rate (score / trials); NaN when there are no trials
    statistic: float


def hypothesis_test(wins: int, draws: int, losses: int) -> HypothesisTestResult:
    """
    Hypothesis test using Binomial distributions.

    Null Hypothesis: Both engines have the same strength, aka they win on average half of the games.
    Alternative Hypothesis: Both engines have different strength.

    A win counts as one success; every *pair* of draws counts as one success
    and one failure. A leftover single draw (odd number of draws) is dropped.

    :param wins: number of games won by the first engine (>= 0)
    :param draws: number of drawn games (>= 0)
    :param losses: number of games lost by the first engine (>= 0)
    :returns: dict with the keys ``trials``, ``pvalue`` and ``statistic``
    :raises ValueError: if any of the counts is negative
    """
    if wins < 0 or draws < 0 or losses < 0:
        raise ValueError(
            f"wins, draws and losses must be non-negative, got {wins}, {draws}, {losses}"
        )

    # wins give 1 point, and draws give 1/2 points (rounded down to whole points)
    score = wins + draws // 2

    # number of games
    trials = wins + draws + losses

    # due to rounding down the variable score, if draws are odd, the leftover
    # half point is discarded, so we have to reduce trials by one as well.
    if draws % 2 != 0:
        trials -= 1

    # binomtest requires at least one trial. With no usable games there is no
    # evidence against the null hypothesis, so report a p-value of 1 and an
    # undefined success rate instead of raising.
    if trials == 0:
        return {
            "trials": 0,
            "pvalue": 1.0,
            "statistic": float("nan"),
        }

    # we expect that if both engines have the same strength, that they "win" on 50% on average
    expected_success_rate = 0.5

    result = binomtest(score, trials, expected_success_rate, alternative='two-sided')

    return {
        "trials": trials,
        "pvalue": result.pvalue,
        "statistic": result.statistic,
    }
|
||||||
|
|||||||
30
main.py
30
main.py
@@ -1,18 +1,20 @@
|
|||||||
|
import argparse
|
||||||
|
import os
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import chess
|
import chess
|
||||||
import chess.engine
|
import chess.engine
|
||||||
import chess.pgn
|
import chess.pgn
|
||||||
from chesspp.mcts.classic_mcts import ClassicMcts
|
|
||||||
|
from chesspp import engine
|
||||||
|
from chesspp import simulation, eval
|
||||||
|
from chesspp import util
|
||||||
from chesspp.mcts.baysian_mcts import BayesianMcts
|
from chesspp.mcts.baysian_mcts import BayesianMcts
|
||||||
|
from chesspp.mcts.classic_mcts import ClassicMcts
|
||||||
from chesspp.random_strategy import RandomStrategy
|
from chesspp.random_strategy import RandomStrategy
|
||||||
from chesspp.stockfish_strategy import StockFishStrategy
|
from chesspp.stockfish_strategy import StockFishStrategy
|
||||||
from chesspp import engine
|
from chesspp.util import hypothesis_test
|
||||||
from chesspp import util
|
|
||||||
from chesspp import simulation, eval
|
|
||||||
import argparse
|
|
||||||
import os
|
|
||||||
|
|
||||||
|
|
||||||
def test_simulate():
|
def test_simulate():
|
||||||
@@ -44,7 +46,7 @@ def test_bayes_mcts():
|
|||||||
t1 = time.time_ns()
|
t1 = time.time_ns()
|
||||||
mcts.sample(1)
|
mcts.sample(1)
|
||||||
t2 = time.time_ns()
|
t2 = time.time_ns()
|
||||||
print ((t2 - t1)/1e6)
|
print((t2 - t1) / 1e6)
|
||||||
mcts.print()
|
mcts.print()
|
||||||
for move, score in mcts.get_moves().items():
|
for move, score in mcts.get_moves().items():
|
||||||
print("move (mcts):", move, " with score:", score)
|
print("move (mcts):", move, " with score:", score)
|
||||||
@@ -106,10 +108,15 @@ def test_evaluation():
|
|||||||
b_wins = len(list(filter(lambda x: x.winner == simulation.Winner.Engine_B, results)))
|
b_wins = len(list(filter(lambda x: x.winner == simulation.Winner.Engine_B, results)))
|
||||||
draws = len(list(filter(lambda x: x.winner == simulation.Winner.Draw, results)))
|
draws = len(list(filter(lambda x: x.winner == simulation.Winner.Draw, results)))
|
||||||
|
|
||||||
|
alpha = 0.001
|
||||||
|
test_result = hypothesis_test(a_wins, draws, b_wins)
|
||||||
|
reject_h0 = test_result['pvalue'] < alpha
|
||||||
|
|
||||||
print(f"{games_played} games played")
|
print(f"{games_played} games played")
|
||||||
print(f"Engine {a.get_name()} won {a_wins} games ({a_wins/games_played:.2%})")
|
print(f"Engine {a.get_name()} won {a_wins} games ({a_wins / games_played:.2%})")
|
||||||
print(f"Engine {b.get_name()} won {b_wins} games ({b_wins/games_played:.2%})")
|
print(f"Engine {b.get_name()} won {b_wins} games ({b_wins / games_played:.2%})")
|
||||||
print(f"{draws} games ({draws/games_played:.2%}) resulted in a draw")
|
print(f"{draws} games ({draws / games_played:.2%}) resulted in a draw")
|
||||||
|
print(f"Hypothesis test: trials={test_result['trials']}, pvalue={test_result['pvalue']:2.10f}, statistic={test_result['statistic']:2.4f}, reject_h0={reject_h0}")
|
||||||
|
|
||||||
|
|
||||||
def read_arguments():
|
def read_arguments():
|
||||||
@@ -118,7 +125,8 @@ def read_arguments():
|
|||||||
description='Compare two engines by playing multiple games against each other'
|
description='Compare two engines by playing multiple games against each other'
|
||||||
)
|
)
|
||||||
|
|
||||||
engines = {"ClassicMCTS": engine.ClassicMctsEngine, "BayesianMCTS": engine.BayesMctsEngine, "Random": engine.RandomEngine}
|
engines = {"ClassicMCTS": engine.ClassicMctsEngine, "BayesianMCTS": engine.BayesMctsEngine,
|
||||||
|
"Random": engine.RandomEngine}
|
||||||
strategies = {"Random": RandomStrategy, "Stockfish": StockFishStrategy}
|
strategies = {"Random": RandomStrategy, "Stockfish": StockFishStrategy}
|
||||||
|
|
||||||
if os.name == 'nt':
|
if os.name == 'nt':
|
||||||
|
|||||||
Reference in New Issue
Block a user