diff --git a/chesspp/mcts/classic_mcts.py b/chesspp/mcts/classic_mcts.py index 068fc3c..ae2586b 100644 --- a/chesspp/mcts/classic_mcts.py +++ b/chesspp/mcts/classic_mcts.py @@ -29,7 +29,7 @@ class ClassicMcts: self.untried_actions.remove(move) next_board = self.board.copy() next_board.push(move) - child_node = ClassicMcts(next_board, color=self.color, strategy=self.strategy, parent=self, move=move) + child_node = ClassicMcts(next_board, color=not self.color, strategy=self.strategy, parent=self, move=move) self.children.append(child_node) return child_node diff --git a/chesspp/mcts/i_mcts.py b/chesspp/mcts/i_mcts.py index 1231495..6a0020c 100644 --- a/chesspp/mcts/i_mcts.py +++ b/chesspp/mcts/i_mcts.py @@ -40,13 +40,13 @@ class IMcts(ABC): """ pass - @abstractmethod - def get_moves(self) -> Dict[chess.Move, int]: - """ - Return all legal moves from this node with respective scores - :return: dictionary with moves as key and scores as values - """ - pass + #@abstractmethod + #def get_moves(self) -> Dict[chess.Move, int]: + # """ + # Return all legal moves from this node with respective scores + # :return: dictionary with moves as key and scores as values + # """ + # pass """ TODO: add score class: diff --git a/chesspp/mcts/i_mcts_node.py b/chesspp/mcts/i_mcts_node.py index ffb1f05..140e960 100644 --- a/chesspp/mcts/i_mcts_node.py +++ b/chesspp/mcts/i_mcts_node.py @@ -6,6 +6,7 @@ import chess from chesspp.i_strategy import IStrategy + class IMctsNode(ABC): def __init__(self, board: chess.Board, strategy: IStrategy, parent: Self | None, move: chess.Move | None, random_state: random.Random): @@ -16,7 +17,6 @@ class IMctsNode(ABC): self.move = move self.legal_moves = list(board.legal_moves) self.random_state = random_state - self.depth = 0 @abstractmethod def select(self) -> Self: