Fixed engine board initialization and classic MCTS score function

This commit is contained in:
Theo Haslinger
2024-01-31 22:43:15 +01:00
parent c4d56f52a4
commit 628c8f2240
3 changed files with 19 additions and 14 deletions

View File

@@ -57,20 +57,20 @@ class EngineFactory:
return EngineFactory.lc0_engine(color, lc0_path) return EngineFactory.lc0_engine(color, lc0_path)
@staticmethod @staticmethod
def stockfish_engine(color: chess.Color, engine_path: str, stockfish_elo: int, board: chess.Board | None = chess.Board()) -> Engine: def stockfish_engine(color: chess.Color, engine_path: str, stockfish_elo: int) -> Engine:
return StockFishEngine(board, color, stockfish_elo, engine_path) return StockFishEngine(chess.Board(), color, stockfish_elo, engine_path)
@staticmethod @staticmethod
def lc0_engine(color: chess.Color, engine_path: str, board: chess.Board | None = chess.Board()) -> Engine: def lc0_engine(color: chess.Color, engine_path: str) -> Engine:
return Lc0Engine(board, color, engine_path) return Lc0Engine(chess.Board(), color, engine_path)
@staticmethod @staticmethod
def bayesian_mcts(color: chess.Color, strategy: IStrategy, board: chess.Board | None = chess.Board()) -> Engine: def bayesian_mcts(color: chess.Color, strategy: IStrategy) -> Engine:
return BayesMctsEngine(board, color, strategy) return BayesMctsEngine(chess.Board(), color, strategy)
@staticmethod @staticmethod
def classic_mcts(color: chess.Color, strategy: IStrategy, board: chess.Board | None = chess.Board()) -> Engine: def classic_mcts(color: chess.Color, strategy: IStrategy) -> Engine:
return ClassicMctsEngine(board, color, strategy) return ClassicMctsEngine(chess.Board(), color, strategy)
@staticmethod @staticmethod
def _get_random_strategy(rollout_depth: int) -> IStrategy: def _get_random_strategy(rollout_depth: int) -> IStrategy:
@@ -91,3 +91,4 @@ class EngineFactory:
@staticmethod @staticmethod
def _get_pesto_strategy(rollout_depth: int) -> IStrategy: def _get_pesto_strategy(rollout_depth: int) -> IStrategy:
return PestoStrategy(rollout_depth) return PestoStrategy(rollout_depth)

View File

@@ -97,8 +97,8 @@ class BayesianMctsNode(IMctsNode):
copied_board.push(m) copied_board.push(m)
steps += 1 steps += 1
steps = max(1, steps) steps = max(2, steps)
score = int(self.strategy.analyze_board(copied_board) / (math.log2(steps) + 1)) score = int(self.strategy.analyze_board(copied_board) / math.log2(steps))
self.result = score self.result = score
return score return score

View File

@@ -1,3 +1,5 @@
import math
import chess import chess
import random import random
import numpy as np import numpy as np
@@ -7,7 +9,7 @@ from chesspp.i_strategy import IStrategy
class ClassicMcts: class ClassicMcts:
def __init__(self, board: chess.Board, color: chess.Color, strategy: IStrategy, parent=None, move: chess.Move | None = None, def __init__(self, board: chess.Board, color: chess.Color, strategy: IStrategy, parent=None, move: chess.Move | None = None,
random_state: int | None = None): random_state: int | None = None, depth: int = 0):
self.random = random.Random(random_state) self.random = random.Random(random_state)
self.board = board self.board = board
self.color = color self.color = color
@@ -19,6 +21,7 @@ class ClassicMcts:
self.legal_moves = list(board.legal_moves) self.legal_moves = list(board.legal_moves)
self.untried_actions = self.legal_moves self.untried_actions = self.legal_moves
self.score = 0 self.score = 0
self.depth = depth
def _expand(self) -> 'ClassicMcts': def _expand(self) -> 'ClassicMcts':
""" """
@@ -29,7 +32,7 @@ class ClassicMcts:
self.untried_actions.remove(move) self.untried_actions.remove(move)
next_board = self.board.copy() next_board = self.board.copy()
next_board.push(move) next_board.push(move)
child_node = ClassicMcts(next_board, color=not self.color, strategy=self.strategy, parent=self, move=move) child_node = ClassicMcts(next_board, color=not self.color, strategy=self.strategy, parent=self, move=move, depth=self.depth+1)
self.children.append(child_node) self.children.append(child_node)
return child_node return child_node
@@ -40,7 +43,7 @@ class ClassicMcts:
:return: the score of the rolled out game :return: the score of the rolled out game
""" """
copied_board = self.board.copy() copied_board = self.board.copy()
steps = 1 steps = self.depth
for i in range(rollout_depth): for i in range(rollout_depth):
if copied_board.is_game_over(): if copied_board.is_game_over():
break break
@@ -49,7 +52,8 @@ class ClassicMcts:
copied_board.push(m) copied_board.push(m)
steps += 1 steps += 1
return self.strategy.analyze_board(copied_board) // steps steps = max(2, steps)
return int(self.strategy.analyze_board(copied_board) / math.log2(steps))
def _backpropagate(self, score: float) -> None: def _backpropagate(self, score: float) -> None:
""" """