Fixed engines board init and classic mcts score function

This commit is contained in:
Theo Haslinger
2024-01-31 22:43:15 +01:00
parent c4d56f52a4
commit 628c8f2240
3 changed files with 19 additions and 14 deletions

View File

@@ -1,3 +1,5 @@
import math
import chess
import random
import numpy as np
@@ -7,7 +9,7 @@ from chesspp.i_strategy import IStrategy
class ClassicMcts:
def __init__(self, board: chess.Board, color: chess.Color, strategy: IStrategy, parent=None, move: chess.Move | None = None,
random_state: int | None = None):
random_state: int | None = None, depth: int = 0):
self.random = random.Random(random_state)
self.board = board
self.color = color
@@ -19,6 +21,7 @@ class ClassicMcts:
self.legal_moves = list(board.legal_moves)
self.untried_actions = self.legal_moves
self.score = 0
self.depth = depth
def _expand(self) -> 'ClassicMcts':
"""
@@ -29,7 +32,7 @@ class ClassicMcts:
self.untried_actions.remove(move)
next_board = self.board.copy()
next_board.push(move)
child_node = ClassicMcts(next_board, color=not self.color, strategy=self.strategy, parent=self, move=move)
child_node = ClassicMcts(next_board, color=not self.color, strategy=self.strategy, parent=self, move=move, depth=self.depth+1)
self.children.append(child_node)
return child_node
@@ -40,7 +43,7 @@ class ClassicMcts:
:return: the score of the rolled out game
"""
copied_board = self.board.copy()
steps = 1
steps = self.depth
for i in range(rollout_depth):
if copied_board.is_game_over():
break
@@ -49,7 +52,8 @@ class ClassicMcts:
copied_board.push(m)
steps += 1
return self.strategy.analyze_board(copied_board) // steps
steps = max(2, steps)
return int(self.strategy.analyze_board(copied_board) / math.log2(steps))
def _backpropagate(self, score: float) -> None:
"""