Merge remote-tracking branch 'origin/main'
This commit is contained in:
@@ -57,20 +57,20 @@ class EngineFactory:
|
|||||||
return EngineFactory.lc0_engine(color, lc0_path)
|
return EngineFactory.lc0_engine(color, lc0_path)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def stockfish_engine(color: chess.Color, engine_path: str, stockfish_elo: int, board: chess.Board | None = chess.Board()) -> Engine:
|
def stockfish_engine(color: chess.Color, engine_path: str, stockfish_elo: int) -> Engine:
|
||||||
return StockFishEngine(board, color, stockfish_elo, engine_path)
|
return StockFishEngine(chess.Board(), color, stockfish_elo, engine_path)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def lc0_engine(color: chess.Color, engine_path: str, board: chess.Board | None = chess.Board()) -> Engine:
|
def lc0_engine(color: chess.Color, engine_path: str) -> Engine:
|
||||||
return Lc0Engine(board, color, engine_path)
|
return Lc0Engine(chess.Board(), color, engine_path)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def bayesian_mcts(color: chess.Color, strategy: IStrategy, board: chess.Board | None = chess.Board()) -> Engine:
|
def bayesian_mcts(color: chess.Color, strategy: IStrategy) -> Engine:
|
||||||
return BayesMctsEngine(board, color, strategy)
|
return BayesMctsEngine(chess.Board(), color, strategy)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def classic_mcts(color: chess.Color, strategy: IStrategy, board: chess.Board | None = chess.Board()) -> Engine:
|
def classic_mcts(color: chess.Color, strategy: IStrategy) -> Engine:
|
||||||
return ClassicMctsEngine(board, color, strategy)
|
return ClassicMctsEngine(chess.Board(), color, strategy)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_random_strategy(rollout_depth: int) -> IStrategy:
|
def _get_random_strategy(rollout_depth: int) -> IStrategy:
|
||||||
@@ -91,3 +91,4 @@ class EngineFactory:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_pesto_strategy(rollout_depth: int) -> IStrategy:
|
def _get_pesto_strategy(rollout_depth: int) -> IStrategy:
|
||||||
return PestoStrategy(rollout_depth)
|
return PestoStrategy(rollout_depth)
|
||||||
|
|
||||||
|
|||||||
@@ -97,8 +97,8 @@ class BayesianMctsNode(IMctsNode):
|
|||||||
copied_board.push(m)
|
copied_board.push(m)
|
||||||
steps += 1
|
steps += 1
|
||||||
|
|
||||||
steps = max(1, steps)
|
steps = max(2, steps)
|
||||||
score = int(self.strategy.analyze_board(copied_board) / (math.log2(steps) + 1))
|
score = int(self.strategy.analyze_board(copied_board) / math.log2(steps))
|
||||||
self.result = score
|
self.result = score
|
||||||
return score
|
return score
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
import math
|
||||||
|
|
||||||
import chess
|
import chess
|
||||||
import random
|
import random
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -7,7 +9,7 @@ from chesspp.i_strategy import IStrategy
|
|||||||
class ClassicMcts:
|
class ClassicMcts:
|
||||||
|
|
||||||
def __init__(self, board: chess.Board, color: chess.Color, strategy: IStrategy, parent=None, move: chess.Move | None = None,
|
def __init__(self, board: chess.Board, color: chess.Color, strategy: IStrategy, parent=None, move: chess.Move | None = None,
|
||||||
random_state: int | None = None):
|
random_state: int | None = None, depth: int = 0):
|
||||||
self.random = random.Random(random_state)
|
self.random = random.Random(random_state)
|
||||||
self.board = board
|
self.board = board
|
||||||
self.color = color
|
self.color = color
|
||||||
@@ -19,6 +21,7 @@ class ClassicMcts:
|
|||||||
self.legal_moves = list(board.legal_moves)
|
self.legal_moves = list(board.legal_moves)
|
||||||
self.untried_actions = self.legal_moves
|
self.untried_actions = self.legal_moves
|
||||||
self.score = 0
|
self.score = 0
|
||||||
|
self.depth = depth
|
||||||
|
|
||||||
def _expand(self) -> 'ClassicMcts':
|
def _expand(self) -> 'ClassicMcts':
|
||||||
"""
|
"""
|
||||||
@@ -29,7 +32,7 @@ class ClassicMcts:
|
|||||||
self.untried_actions.remove(move)
|
self.untried_actions.remove(move)
|
||||||
next_board = self.board.copy()
|
next_board = self.board.copy()
|
||||||
next_board.push(move)
|
next_board.push(move)
|
||||||
child_node = ClassicMcts(next_board, color=not self.color, strategy=self.strategy, parent=self, move=move)
|
child_node = ClassicMcts(next_board, color=not self.color, strategy=self.strategy, parent=self, move=move, depth=self.depth+1)
|
||||||
self.children.append(child_node)
|
self.children.append(child_node)
|
||||||
return child_node
|
return child_node
|
||||||
|
|
||||||
@@ -40,7 +43,7 @@ class ClassicMcts:
|
|||||||
:return: the score of the rolled out game
|
:return: the score of the rolled out game
|
||||||
"""
|
"""
|
||||||
copied_board = self.board.copy()
|
copied_board = self.board.copy()
|
||||||
steps = 1
|
steps = self.depth
|
||||||
for i in range(rollout_depth):
|
for i in range(rollout_depth):
|
||||||
if copied_board.is_game_over():
|
if copied_board.is_game_over():
|
||||||
break
|
break
|
||||||
@@ -49,7 +52,8 @@ class ClassicMcts:
|
|||||||
copied_board.push(m)
|
copied_board.push(m)
|
||||||
steps += 1
|
steps += 1
|
||||||
|
|
||||||
return self.strategy.analyze_board(copied_board) // steps
|
steps = max(2, steps)
|
||||||
|
return int(self.strategy.analyze_board(copied_board) / math.log2(steps))
|
||||||
|
|
||||||
def _backpropagate(self, score: float) -> None:
|
def _backpropagate(self, score: float) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user