From 5645657fdd675b7c7ed6f3ac736ffbcc352eddb1 Mon Sep 17 00:00:00 2001 From: luk3k Date: Wed, 24 Jan 2024 11:57:25 +0100 Subject: [PATCH] added initial mcts --- .gitignore | 3 +- __pycache__/engine.cpython-310.pyc | Bin 0 -> 1233 bytes __pycache__/eval.cpython-310.pyc | Bin 0 -> 4861 bytes __pycache__/mcts.cpython-310.pyc | Bin 0 -> 3504 bytes engine.py | 33 --------- eval.py | 9 ++- main.py | 45 +++++++++++++ mcts.py | 103 +++++++++++++++++++++++++++++ requirements.txt | 3 +- 9 files changed, 156 insertions(+), 40 deletions(-) create mode 100644 __pycache__/engine.cpython-310.pyc create mode 100644 __pycache__/eval.cpython-310.pyc create mode 100644 __pycache__/mcts.cpython-310.pyc create mode 100644 main.py create mode 100644 mcts.py diff --git a/.gitignore b/.gitignore index 85c7f25..58660d9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /stockfish/ -.idea \ No newline at end of file +.idea +.venv \ No newline at end of file diff --git a/__pycache__/engine.cpython-310.pyc b/__pycache__/engine.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3088e8ed5e99adbff4b0144a629cc5a949b5fcdd GIT binary patch literal 1233 zcma)6J8u**5VrSq&xDYI0z~8ramD2%gb<1lLOeP+MM@V5MZUG~)|cJ&#`Yd0x>V5e zA95ufzsD^V`~pP9*n0<&puj7S$1|RXXFjjo+Ug=0+h4!%?;%28-RAm8VDk*7J_N!L z!#O%`R=i$=!|@T!7clibAO;z{MDySS{(u=i4OEW>N63U%(MG;|4! zq^E=uMN1|NQpg4O;)9ARS`Z)ifTS7My4f_^CXoX$T+m!FZ%3peV_9+{rlgeD6?8#G zPA7TeTDg?RgV=^SFKv_yZEP>+Gn${c@EWe^ld(yfNs;maRsmRiMydnL&&Q)o7JQV~ zr+ic?Ip?X-qpBLN`H|0c$jg~1`LL=+4YyyN+fD^{^)X!6v?IGogm>3hlMQqgjd03>^;QuuJduWv$1gJ2l2%ZU-2VR)HyXADn6qHC)!YDL^&hx zgvp4>IGKqBFaNuNO?%yqU3aDahp8{b0V(TZ!WEHI?^{DGYKYEcoij3lMAnHhLZ=4K zRXII(dAXGSa=d|!Q|eSIPNo1Fzx(aRUC&?n;{N_f8=0O?h0gwP!+KJerXIe3{Al>- z;gBxgKNx>=00EqAxus`gj{%!Bt+#zBUX}A#9uj~8S@H6uPKAxCT4y!}xXrlj3w`2Z zpFmBNjm~Oe-~iml#;-*ojOsh*iN~VneKfTLQj3Mf%EPlE?w?_|Hl^`yZ-Pkg0@-eN zaUZwgOYkkcjr)PR1L8qj-EdZYpZk83x^l3y@pJ$Bqc<)y57khV##>#B=^lsjf!j|C U$?BXxbq`HDQ5(5h!tlnE-${7M_3Mk;9PwjuO`__j7{Q*VMUy$cwADbrveJlbbd2r7;OL642~cq{U5; zHn%|LxDC?bIp5yN?`homGutUZn}@atZ2{U6v_)vk(3YT`g0>9pG_+ID&H%$KJccNM zG({1lB}yP|Q3jb4Qy?8N&8NkTp9B3IJcc+g=0xEcd&a;r$tsR5i>%VvvdJostsGfX zW6L3H8ZBT)#u?Q8oUiX(;5f5+J}WLhGWj`v{xfrDVd7f^-yFXHzDpCI)_$vf8T|A7 zBKY6tY)h*wjBvl}BM#EP1H0`18pVH6d}e@f>zY=%Bf-|bg&0|Y8{wf!6Se=meECwM zkIYm6t0qy5zJciION1!;97gS9$cX?#MspZ&P^ccHO-R#m7ZH+2fzzPTQb8sOWEow- zC{340C(+S$3>5#r$V}rbvvOzjIWG2JDE^HrovmQ-j zUlOzk3MhKo9*WXBNsvhag`;k#;ZEbDfe1{3OvRU3hdLmlQtMEQR?|@0WM3c$sSoAH znLJsKRvJMvYdgz)JqOh;s0~-gaTZTnyi^ausl!kj1qi2hT1Ayc zI0bn#AK{2p5o9zqF5wg(X&FHfwP-b6HP8eCJ4m%>@`gB3$LUN|EmnmpGai{LI^j<1 z15G3&2tqwFg+&!o3Q#xFC@`r#Ckk_>$M=te9i-QiTFGe|6;%~DZ3k79G_^-`1nL|p zpul8ebaV~hI5Q{+d=@T+yctJl^X73?f~m#Okm}GfZxWRRH2cYECS?)u#-7&`jgBAe zCDv}}NuJnJ!~+>Lk^cCD17KRN!4v%x1ZPJK9vvP79y8Y3>{r_FZXa1g?YTa*Ua%v3 zXmfp_$%~JRL%VGN!tChsJc9)XbfLcq8q5N4o4^Qdzkz%mOa2xS9R@ax)Um<__U+*BJ@fD%1s6WGfa4ChUb7iW z4%^_yp}QUSyy(#V$Zxg+Up$FA{w)GyhYyzB)i7?kk zL2sGHpn_2!;VD(}iPdb0C`xi$A3t3Aq_&kf)T80WNg>tZgT6?tj~~`*bpVImP$t%= z_tv*-iM_hMb??DuVsF&f@7~+4LQP*7xBBnj*X2(D_(}rf9l(bY_$q7~q^b4-QV$bm zz~q%7YqMBy8PuH$p0Lo^Qw-#IdXsW3Dk;Y~<31_@t% z{Lav97svf~K`((`#`KJ*hT1a*X{yK5=wH<08N@AW+*s6x@XTqD&^j=@hDhju~v2YDtvk2IBlJWp{U$}~>C15z&aZMV$%TWs2^p0X zV<90eJ-k8k0tA!dw(=zTwR^R-jgM9^#h>^=Af$W?K}K`{7X|K8V$pq(KgFIF?T);PRPUg` zj7`OlF0Ec@W_NM{(Hf zJo2N~>qq;_QIc_ZZMiz>ptPmI?jVQ^p!zy?f-B8Rn2&ilSDl4=b^qPZka8Y+y?X5G1JT8m^LlA`uF!mkLM-iZi>*HJ7Ak zxQ?}0pKPGLef!ilZ|Br8Z}+tIdyCuLd2Vs%nSJVU zc4BqiORTq!6OyZskIsGk)ITAFMLMDHZEg!^=yC_L#o3X?T>(8uR?p)t?mcgvuJL!c z&)1*3r#^p|Z(wvytoK~r=9?G|ct@=BHO}y5i*IAr2FBYxoA2;l^fs~j8s9^!gVr1T zO|-W7d;A9J;TBfinwWKBuhQybEmI*@MPg2YK7y=raBhW;~Z`%`O(U$kyU z0&bmHFYIHPmpXnjNHdPU_8zAyt(11NR8@K-6a6@g z=nP74jWInHJc^SlEpnw>Dk-GswzR86)-!75I@B4Rl1}!*k~?_)@3VtJF%k#a_@Ou` zWib>ru$<5c9;x`LWwZW!x}-f zco>&ub{?ixSQTc0ahpcfVSSLp=0n$BLfj4H8qT3TC{7b0JEkKA(N&Y|CIwJ*%dr|r z8>5nfSbNx04v^q5c+y866SDw=qTE+(`ti*fOS}# zxoon%k1%{R*G-_H1{K}rPXY5Psn_U2I&zuW* z6SMM-3gk?QG}q9BU}4sy^=fVA(^rteCjV&wdw`u)h`k#+!n5;GrK52cS83i4=kjfxZaZk3J%F`BASRiX^>PU0*RJiN8gRphrqD7jT;@j3RpwN$WR z(U@R@CM1!1vc2WS0ipd zQ*L4E6<&#B2S3B>XJ|}8^ALsMbIYdGMm;vIdnWkgO$xMi>0cNw}5LqrBg$9jf5H0v2*fw)tPOdK_;)<>-xrw$XzzkRvRuzQ-oS(CY06}{A z)U~yfiQGi_3sxIK$#o?#jF!-hQ4C#_u6`LVKFW60#8}RgP1LvRNk~KYNLRH9nC@#o(s1}?e4wuf-iBBE^LLBd-LxX z?7H<2`w{sTyQH%sJPe`4y|(ib{?oxeBicxv*cNkGb zkZ0NK_RGnK4a^apA~L7nznc}v7wR4wi-ldsq+@wN`+BkLkK!kD57S$9N>|Z$BJB?< zRc~#KFnM@mMfZI&fbK_u=Q8i*wmsQO|~?=8HMbx(lY z%tb~v=+X>{tB1+6lZKuXkx^9-<1iC(&(ytT0{mkgom5dMbLK@A*5v`Y5B=}bEDO&> z^KH4Xaf+`NN}CNq8+oV!mR}+M`Z7xu*zgQ*9(y8 z@Mj1AcfH9QOC(ojU;cod^_q%DWhT^>w0EVPeGILLYCi>z1McNnhBoYy^E-EBf#Y2<@7i#A}kK z4$j6%4N)b9_$y|b+}ZYg*T-$bw3uJ=wFZS4qX~^ zDwB~8A{0hG&ZymvqDQEO^O2Uhaq8AleM$csXVXrK`z<9D<4C&I%Zj*K9!|SwMUmB> zMow0rof7OyDKH6{BGi~hL1}KSzV!hMHug4mw>#~?34$A2wdrf}H#C=Saq{ORsN(7t j_$tp$UZ`(v%QWBU*Z$r5X!odgs5XRe8OMe5?m9mL6D2`$ literal 0 HcmV?d00001 diff --git a/engine.py b/engine.py index 01413ed..db50503 100644 --- a/engine.py +++ b/engine.py @@ -1,31 +1,6 @@ import chess import chess.engine import random -import eval - - -def main(): - fools_mate = "rnbqkbnr/pppp1ppp/4p3/8/5PP1/8/PPPPP2P/RNBQKBNR b KQkq f3 0 2" - board = chess.Board(fools_mate) - print(board, '\n') - moves = {} - for i in range(10): - move = pick_move(board) - if move is None: - break - - simulate_game(board, move, 100) - moves[move] = board - board = chess.Board(fools_mate) - - analyze_results(moves) - - -def analyze_results(moves: dict): - for m, b in moves.items(): - manual_score = eval.score_game(b) - engine_score = eval.analyze_with_stockfish(b) - print(f"score for move {m}: manual_score={manual_score}, engine_score={engine_score}") def pick_move(board: chess.Board) -> chess.Move | None: @@ -49,19 +24,11 @@ def simulate_game(board: chess.Board, move: chess.Move, depth: int): """ engine = chess.engine.SimpleEngine.popen_uci("./stockfish/stockfish-ubuntu-x86-64-avx2") board.push(move) - print(move) - print(board, '\n') for i in range(depth): if board.is_game_over(): engine.quit() return r = engine.play(board, chess.engine.Limit(depth=2)) - print(r) board.push(r.move) - print(board, '\n') engine.quit() - - -if __name__ == '__main__': - main() diff --git a/eval.py b/eval.py index cda3c56..e820fda 100644 --- a/eval.py +++ b/eval.py @@ -1,5 +1,6 @@ import chess import chess.engine +import sys # Eval constants for scoring chess boards # Evaluation metric inspired by Tomasz Michniewski: https://www.chessprogramming.org/Simplified_Evaluation_Function @@ -136,9 +137,7 @@ def check_endgame(board: chess.Board) -> bool: return (queens_black == 0 and queens_white == 0) or ((queens_black >= 1 and minors_black <= 1) or (queens_white >= 1 and minors_white <= 1)) - - -def score_game(board: chess.Board) -> float: +def score_manual(board: chess.Board) -> int: """ Calculate the score of the given board regarding the given color :param board: the chess board @@ -147,7 +146,7 @@ def score_game(board: chess.Board) -> float: outcome = board.outcome() if outcome is not None: if outcome.termination == chess.Termination.CHECKMATE: - return float('inf') if outcome.winner == chess.WHITE else float('-inf') + return sys.maxsize if outcome.winner == chess.WHITE else -sys.maxsize else: # draw return 0 @@ -171,7 +170,7 @@ def score_game(board: chess.Board) -> float: return score -def analyze_with_stockfish(board: chess.Board) -> chess.engine.PovScore: +def score_stockfish(board: chess.Board) -> chess.engine.PovScore: """ Calculate the score of the given board using stockfish :param board: diff --git a/main.py b/main.py new file mode 100644 index 0000000..9caad17 --- /dev/null +++ b/main.py @@ -0,0 +1,45 @@ +import chess +import chess.engine +from mcts import MCTSNode +import engine +import eval + + +def test_mcts(seed): + fools_mate = "rnbqkbnr/pppp1ppp/4p3/8/5PP1/8/PPPPP2P/RNBQKBNR b KQkq f3 0 2" + board = chess.Board(fools_mate) + mcts_root = MCTSNode(board) + mcts_root.build_tree() + sorted_moves = sorted(mcts_root.children, key=lambda x: x.move.uci()) + for c in sorted_moves: + print("move (mcts):", c.move, " with score:", c.score) + + +def test_stockfish(seed): + fools_mate = "rnbqkbnr/pppp1ppp/4p3/8/5PP1/8/PPPPP2P/RNBQKBNR b KQkq f3 0 2" + board = chess.Board(fools_mate) + moves = {} + untried_moves = list(board.legal_moves) + for move in untried_moves: + engine.simulate_game(board, move, 100) + moves[move] = board + board = chess.Board(fools_mate) + + sorted_moves = dict(sorted(moves.items(), key=lambda x: x[0].uci())) + analyze_results(sorted_moves) + + +def analyze_results(moves: dict): + for m, b in moves.items(): + manual_score = eval.score_manual(b) + engine_score = eval.score_stockfish(b).white() + print(f"score for move {m}: manual_score={manual_score}, engine_score={engine_score}") + + +def main(): + test_mcts(0) + test_stockfish(0) + + +if __name__ == '__main__': + main() diff --git a/mcts.py b/mcts.py new file mode 100644 index 0000000..7e56d48 --- /dev/null +++ b/mcts.py @@ -0,0 +1,103 @@ +import chess +import random +import eval +import engine +import numpy as np + + +class MCTSNode: + def __init__(self, board: chess.Board, parent = None, move: chess.Move | None = None, random_state: int | None = None): + self.random = random.Random(random_state) + self.board = board + self.parent = parent + self.move = move + self.children = [] + self.visits = 0 + self.legal_moves = list(board.legal_moves) + self.untried_actions = self.legal_moves + self.score = 0 + + def _expand(self) -> 'MCTSNode': + """ + Expands the node, i.e., choose an action and apply it to the board + :return: + """ + move = self.random.choice(self.untried_actions) + self.untried_actions.remove(move) + next_board = self.board.copy() + next_board.push(move) + child_node = MCTSNode(next_board, parent=self, move=move) + self.children.append(child_node) + return child_node + + def _rollout(self, rollout_depth: int = 100) -> float: + """ + Rolls out the node by simulating a game for a given depth. + Sometimes this step is called 'simulation' or 'playout'. + :return: the score of the rolled out game + """ + copied_board = self.board.copy() + for i in range(rollout_depth): + if copied_board.is_game_over(): + break + + m = engine.pick_move(copied_board) + copied_board.push(m) + + return eval.score_manual(copied_board) + + def _backpropagate(self, score: float) -> None: + """ + Backpropagates the results of the rollout + :param score: + :return: + """ + self.visits += 1 + # TODO: maybe use score + num of moves together (a win in 1 move is better than a win in 20 moves) + self.score += score + if self.parent: + self.parent._backpropagate(score) + + def is_fully_expanded(self) -> bool: + return len(self.untried_actions) == 0 + + def _best_child(self) -> 'MCTSNode': + """ + Picks the best child according to our policy + :return: the best child + """ + # NOTE: maybe clamp the score between [-1, +1] instead of [-inf, +inf] + choices_weights = [(c.score / c.visits) + np.sqrt(((2 * np.log(self.visits)) / c.visits)) + for c in self.children] + return self.children[np.argmax(choices_weights)] + + def _select_leaf(self) -> 'MCTSNode': + """ + Selects a leaf node. + If the node is not expanded is will be expanded. + :return: Leaf node + """ + current_node = self + while not current_node.board.is_game_over(): + if not current_node.is_fully_expanded(): + return current_node._expand() + else: + current_node = current_node._best_child() + + return current_node + + def build_tree(self, samples: int = 1000) -> 'MCTSNode': + """ + Runs the MCTS with the given number of samples + :param samples: number of simulations + :return: best node containing the best move + """ + for i in range(samples): + # selection & expansion + # rollout + # backpropagate score + node = self._select_leaf() + score = node._rollout() + node._backpropagate(score) + + return self._best_child() diff --git a/requirements.txt b/requirements.txt index a4adb10..68b140b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ -chess==1.10.0 \ No newline at end of file +chess==1.10.0 +numpy==1.26.3 \ No newline at end of file