import torch
import torch.nn as nn
import torch.optim as optim
import chess
import os
import chess.engine as eng
import torch.multiprocessing as mp
import random
from pathlib import Path


CONFIG = {
    "stockfish_path": "/Users/aaronvattay/Downloads/stockfish/stockfish-macos-m1-apple-silicon",
    "model_path": "chessy_model.pth",
    "backup_model_path": "chessy_modelt-1.pth",
    "device": torch.device("mps"),
    "learning_rate": 1e-4,
    "num_games": 30,
    "num_epochs": 10,
    "stockfish_time_limit": 1.0,
    "search_depth": 1,
    "epsilon": 4,
}
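# Configuration notes (assumptions, not guarantees):
# - "device" is hard-coded to Apple's MPS backend; on machines without MPS,
#   torch.device("cuda" if torch.cuda.is_available() else "cpu") is a drop-in swap.
# - "epsilon" is not an exploration rate here: it is the size of the candidate
#   subset sampled in game_gen() before the best of that subset is played.
# - "search_depth" is forwarded to search(), which currently only looks one ply
#   ahead, so the value has no effect.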
device = CONFIG["device"]


def board_to_tensor(board):
    piece_encoding = {
        'P': 1, 'N': 2, 'B': 3, 'R': 4, 'Q': 5, 'K': 6,
        'p': 7, 'n': 8, 'b': 9, 'r': 10, 'q': 11, 'k': 12
    }

    tensor = torch.zeros(64, dtype=torch.long)
    for square in chess.SQUARES:
        piece = board.piece_at(square)
        if piece:
            tensor[square] = piece_encoding[piece.symbol()]
        else:
            tensor[square] = 0

    return tensor.unsqueeze(0)
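# Sanity check (sketch): the starting position encodes to a (1, 64) LongTensor
# of piece codes, i.e. board_to_tensor(chess.Board()).shape == torch.Size([1, 64]).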
class NN1(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(13, 64)
        self.attention = nn.MultiheadAttention(embed_dim=64, num_heads=16)
        self.neu = 512
        self.neurons = nn.Sequential(
            nn.Linear(4096, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, 64),
            nn.ReLU(),
            nn.Linear(64, 4)
        )

    def forward(self, x):
        x = self.embedding(x)
        x = x.permute(1, 0, 2)
        attn_output, _ = self.attention(x, x, x)
        x = attn_output.permute(1, 0, 2).contiguous()
        x = x.view(x.size(0), -1)
        x = self.neurons(x)
        return x
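# Shape walk-through for NN1.forward (for reference): input is (batch, 64) piece
# codes -> embedding gives (batch, 64, 64) -> permuted to (64, batch, 64) for
# nn.MultiheadAttention (sequence-first by default) -> permuted back and
# flattened to (batch, 4096) -> MLP head gives (batch, 4); only element [0][0]
# is used as the scalar position value elsewhere in this file.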
class Policy(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(13, 32)
        self.attention = nn.MultiheadAttention(embed_dim=32, num_heads=16)
        self.neu = 256
        self.neurons = nn.Sequential(
            nn.Linear(64 * 32, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, 128),
            nn.ReLU(),
            nn.Linear(128, 29275),
        )

    def forward(self, fen):
        # The policy takes a FEN string, rebuilds the board, and scores every
        # entry of the fixed move vocabulary.
        board = chess.Board(fen)
        # Map the side to move to +1 / -1 so the logits are always from the
        # mover's perspective (board.turn is a bool, so using it directly as a
        # multiplier would zero out Black's output).
        color = 1.0 if board.turn == chess.WHITE else -1.0
        x = board_to_tensor(board).to(self.embedding.weight.device)
        x = self.embedding(x)
        x = x.permute(1, 0, 2)
        attn_output, _ = self.attention(x, x, x)
        x = attn_output.permute(1, 0, 2).contiguous()
        x = x.view(x.size(0), -1)
        return self.neurons(x) * color
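# Assumption: the 29275 policy outputs are indexed by the move vocabulary loaded
# below into uci_to_index (one logit per line of the moves file); the two sizes
# have to match for the probability lookup in search() to be meaningful.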
model = NN1().to(device)
optimizer = optim.Adam(model.parameters(), lr=CONFIG["learning_rate"])
policy = Policy().to(device)
polweight = torch.load("NeoChess/chessy_policy.pth", map_location=device, weights_only=False)
policy.load_state_dict(polweight)

try:
    model.load_state_dict(torch.load(CONFIG["model_path"], map_location=device))
    print(f"Loaded model from {CONFIG['model_path']}")
except FileNotFoundError:
    try:
        model.load_state_dict(torch.load(CONFIG["backup_model_path"], map_location=device))
        print(f"Loaded backup model from {CONFIG['backup_model_path']}")
    except FileNotFoundError:
        print("No model file found, starting from scratch.")

model.train()
criterion = nn.MSELoss()
engine = eng.SimpleEngine.popen_uci(CONFIG["stockfish_path"])
lim = eng.Limit(time=CONFIG["stockfish_time_limit"])
def get_evaluation(board):
    """
    Returns the evaluation of the board from the perspective of the current player.
    The model's output is from White's perspective.
    """
    tensor = board_to_tensor(board).to(device)
    with torch.no_grad():
        evaluation = model(tensor)[0][0].item()

    if board.turn == chess.WHITE:
        return evaluation
    else:
        return -evaluation
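# Perspective note: the raw model output is treated as White's score, so a
# position with Black to move returns the negated value (a negamax-style
# convention that search() below relies on).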
with open("/usr/local/python/3.12.1/lib/python3.12/site-packages/torchrl/envs/custom/san_moves.txt", "r") as f:
    uci_to_index = {line.strip(): i for i, line in enumerate(f)}
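# Assumption: every line of san_moves.txt is one move token whose line number
# matches the corresponding policy logit. str(chess.Move) yields UCI notation,
# so the lookup in search() only works if the file's tokens are UCI strings
# (despite the "san" in the filename).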
def search(board, depth, policy_net=policy, simulations=100, temperature=1.0):
    """
    Monte Carlo search using the policy network for move weighting
    and the value network via get_evaluation().
    `depth` is accepted for compatibility with callers but is currently unused.
    """
    with torch.no_grad():
        # Policy.forward expects a FEN string and returns raw logits over the
        # move vocabulary.
        logits = policy_net(board.fen()).squeeze(0)
        probs = torch.softmax(logits / temperature, dim=-1).cpu().numpy()

    move_scores = {move: 0.0 for move in board.legal_moves}

    for move in board.legal_moves:
        total_eval = 0.0
        for _ in range(simulations):
            board.push(move)
            # get_evaluation() scores the side to move, which is the opponent
            # after the push, so negate to score the move for the mover.
            total_eval -= get_evaluation(board)
            board.pop()
        move_scores[move] = total_eval / simulations

    # Weight each move's averaged evaluation by its policy probability.
    for move in move_scores:
        move_index = uci_to_index[str(move)]
        move_scores[move] *= probs[move_index]

    best_move = max(move_scores, key=move_scores.get)
    return best_move, move_scores
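# Usage sketch (assumes the module-level model and policy weights are loaded):
#   best, scores = search(chess.Board(), depth=CONFIG["search_depth"], simulations=1)
# returns the highest-weighted legal move plus the per-move score dictionary.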
def game_gen(engine_side):
    """Play one game (bot vs. Stockfish, or self-play when engine_side is None)
    and collect the FEN of every position the bot moved from."""
    data = []
    mc = 0
    board = chess.Board()
    while not board.is_game_over():
        is_bot_turn = board.turn != engine_side

        if is_bot_turn:
            # Score every legal move with the Monte Carlo policy/value search.
            _, move_scores = search(board, depth=CONFIG["search_depth"])
            if not move_scores:
                break

            keys = list(move_scores.keys())
            logits = torch.tensor([move_scores[m] for m in keys]).to(device)
            probs = torch.softmax(logits, dim=0)
            # Sample a small candidate set, then play its highest-scoring move.
            epsilon = min(CONFIG["epsilon"], len(keys))
            bests = torch.multinomial(probs, num_samples=epsilon, replacement=False)
            best_idx = bests[torch.argmax(logits[bests])]
            move = keys[best_idx.item()]
        else:
            result = engine.play(board, lim)
            move = result.move

        if is_bot_turn:
            data.append({
                'fen': board.fen(),
                'move_number': mc,
            })

        board.push(move)
        mc += 1

    result = board.result()
    c = 0.0
    if result == '1-0':
        c = 10.0
    elif result == '0-1':
        c = -10.0
    return data, c, mc
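# game_gen returns (data, c, mc): the bot's positions, a terminal score of
# +10 / -10 / 0 for a White win / Black win / draw, and the total ply count.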
def train(data, c, mc):
    for entry in data:
        tensor = board_to_tensor(chess.Board(entry['fen'])).to(device)
        target = torch.tensor(c * entry['move_number'] / mc, dtype=torch.float32).to(device)
        output = model(tensor)[0][0]
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Saving model to {CONFIG['model_path']}")
    torch.save(model.state_dict(), CONFIG["model_path"])
    return
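# Training target (as written above): each position's value target is the final
# score scaled by move_number / mc, i.e. it ramps linearly from 0 at the first
# move toward +/-10 at the end of a decisive game, and stays 0 for draws.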
def main():
    # Under the spawn start method every worker re-imports this module, so each
    # child loads its own copy of the model, policy and Stockfish engine.
    mp.set_start_method('spawn', force=True)
    num_games = CONFIG['num_games']
    num_instances = mp.cpu_count()

    for i in range(CONFIG["num_epochs"]):
        print(f"Saving backup model to {CONFIG['backup_model_path']}")
        torch.save(model.state_dict(), CONFIG["backup_model_path"])

        # Generate games in parallel: one third self-play, one third with the
        # engine as White, one third with the engine as Black.
        with mp.Pool(processes=num_instances) as pool:
            results_self = pool.starmap(game_gen, [(None,) for _ in range(num_games // 3)])
            results_white = pool.starmap(game_gen, [(chess.WHITE,) for _ in range(num_games // 3)])
            results_black = pool.starmap(game_gen, [(chess.BLACK,) for _ in range(num_games // 3)])

        results = []
        for s, w, b in zip(results_self, results_white, results_black):
            results.extend([s, w, b])

        for batch in results:
            data, c, mc = batch
            print(f"Saving backup model to {CONFIG['backup_model_path']}")
            torch.save(model.state_dict(), CONFIG["backup_model_path"])
            if data:
                train(data, c, mc)

    print("Training complete.")
    engine.quit()


if __name__ == "__main__":
    main()