1- import os
21import random
2+ import io
33from pathlib import Path
4- from concurrent .futures import ThreadPoolExecutor , as_completed
5- from typing import List , Optional
4+ from concurrent .futures import ProcessPoolExecutor , as_completed
5+ from typing import List
66import chess
77import chess .pgn
8- import io
98from tqdm import tqdm
109
1110
11+ # ------------------------------------------------------------
12+ # Game logic
13+ # ------------------------------------------------------------
1214
def is_capture_or_promotion(board: chess.Board, move: chess.Move) -> bool:
    """Return True if ``move`` is a capture on ``board`` or carries a promotion.

    Args:
        board: Position the move is evaluated against (queried through
            ``board.is_capture``).
        move: Candidate move; its ``promotion`` attribute is ``None`` when
            the move does not promote.
    """
    if board.is_capture(move):
        return True
    return move.promotion is not None
1617
1718
1819def extract_positions_from_game (game_data : str , n_positions : int = 5 ) -> List [str ]:
19- """Extract N sampled positions from a game, with original restrictions, labeled with final result (w/d/b)."""
2020 try :
2121 pgn_io = io .StringIO (game_data )
2222 game = chess .pgn .read_game (pgn_io )
2323 if not game :
2424 return []
2525
26- # Map result to label
2726 result = game .headers .get ("Result" , "" )
2827 if result == "1-0" :
2928 outcome = "w"
@@ -32,27 +31,22 @@ def extract_positions_from_game(game_data: str, n_positions: int = 5) -> List[st
3231 elif result == "1/2-1/2" :
3332 outcome = "d"
3433 else :
35- return [] # Skip unfinished/invalid games
34+ return []
3635
3736 board = game .board ()
3837 valid_positions = []
39- move_number = 0
4038
41- # Iterate through all moves in the game
4239 for node in game .mainline ():
4340 move = node .move
4441 if move :
4542 board .push (move )
46- move_number += 1
4743
48- # Skip if in check or no legal moves
4944 if board .is_check () or not any (board .legal_moves ):
5045 continue
5146
5247 next_node = node .next ()
5348 best_move = next_node .move if next_node else None
5449
55- # Skip if best move is capture or promotion
5650 if best_move and is_capture_or_promotion (board , best_move ):
5751 continue
5852
@@ -61,7 +55,6 @@ def extract_positions_from_game(game_data: str, n_positions: int = 5) -> List[st
6155 if not valid_positions :
6256 return []
6357
64- # Randomly select up to n_positions
6558 if len (valid_positions ) > n_positions :
6659 valid_positions = random .sample (valid_positions , n_positions )
6760
@@ -71,115 +64,130 @@ def extract_positions_from_game(game_data: str, n_positions: int = 5) -> List[st
7164 return []
7265
7366
74- def count_games_in_file (file_path : Path ) -> int :
75- try :
76- with open (file_path , "r" , encoding = "utf-8" ) as f :
77- content = f .read ()
78- return content .count ("[Event " )
79- except Exception :
80- return 0
def process_single_game(game_data: str, positions_per_game: int) -> List[str]:
    """Worker entry point: sample positions from one PGN game.

    Kept as a plain module-level function so it can be handed to
    ``executor.submit`` by the multiprocessing pipeline.

    Args:
        game_data: Text of a single PGN game.
        positions_per_game: Maximum number of positions to sample.

    Returns:
        Whatever ``extract_positions_from_game`` returns for this game
        (an empty list when nothing could be extracted).
    """
    sampled = extract_positions_from_game(game_data, n_positions=positions_per_game)
    return sampled
8169
8270
83- def process_pgn_file (file_path : Path , positions_per_game : int = 5 , pbar : Optional [tqdm ] = None ) -> List [str ]:
84- results = []
85- try :
86- with open (file_path , "r" , encoding = "utf-8" ) as f :
87- content = f .read ()
88-
89- games = []
90- current_game = []
91- for line in content .split ("\n " ):
92- if line .strip ().startswith ("[Event " ) and current_game :
93- games .append ("\n " .join (current_game ))
94- current_game = [line ]
71+ # ------------------------------------------------------------
72+ # PGN streaming utilities
73+ # ------------------------------------------------------------
74+
def stream_pgn_games(file_path: Path):
    """Lazily yield the text of each game in a PGN file.

    The file is read line by line, and a new game starts at every line
    beginning with ``[Event ``. Only one game is held in memory at a time.

    Chunks that contain nothing but whitespace (an empty file, or blank
    lines before the first ``[Event `` tag) are skipped, so they are not
    dispatched as games and do not inflate the progress count.

    Args:
        file_path: Path of the PGN file to read.

    Yields:
        The raw text of one game, including its header lines.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # "utf-8-sig" decodes plain UTF-8 unchanged but strips a leading BOM, so
    # the first game's text starts directly with its "[Event " tag.
    with open(file_path, "r", encoding="utf-8-sig") as f:
        game_lines = []
        for line in f:
            if line.startswith("[Event ") and game_lines:
                chunk = "".join(game_lines)
                if chunk.strip():
                    yield chunk
                game_lines = [line]
            else:
                game_lines.append(line)

        # Flush the final game, which has no following "[Event " line.
        tail = "".join(game_lines)
        if tail.strip():
            yield tail
11287
113- return results
88+
def count_games_in_file(file_path: Path) -> int:
    """Count the games in a PGN file without loading it into memory.

    A game is counted for every line that starts with ``[Event ``.

    Args:
        file_path: Path of the PGN file to scan.

    Returns:
        Number of ``[Event `` header lines found.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # "utf-8-sig" strips a leading byte-order mark; with plain "utf-8" the
    # BOM stays glued to the first "[Event " line and that game is missed.
    with open(file_path, "r", encoding="utf-8-sig") as f:
        return sum(1 for line in f if line.startswith("[Event "))
11496
11597
116- def extract_positions_from_folder (folder_path : str , positions_per_game : int = 5 , max_workers : int = 4 ) -> List [str ]:
98+ # ------------------------------------------------------------
99+ # Main multiprocessing pipeline
100+ # ------------------------------------------------------------
101+
def extract_positions_from_folder(
    folder_path: str,
    positions_per_game: int = 5,
    max_workers: int = 4,
    task_buffer: int = 2000,
) -> List[str]:
    """Sample positions from every ``*.pgn`` file in a folder using worker processes.

    Games are streamed one at a time from each file and submitted to a
    ``ProcessPoolExecutor``. Submitted tasks are collected in batches of
    ``task_buffer`` so that the number of outstanding futures (and the game
    text they hold) stays bounded.

    Args:
        folder_path: Directory that is searched (non-recursively) for PGN files.
        positions_per_game: Maximum number of positions sampled per game.
        max_workers: Number of worker processes.
        task_buffer: Number of submitted games after which results are
            drained; values below 1 are treated as 1.

    Returns:
        All position lines returned by the workers, in completion order.

    Raises:
        FileNotFoundError: If ``folder_path`` does not exist.
    """
    folder = Path(folder_path)
    if not folder.exists():
        raise FileNotFoundError(f"Folder does not exist: {folder}")

    pgn_files = list(folder.glob("*.pgn"))
    if not pgn_files:
        print("No PGN files found.")
        return []

    print(f"Found {len(pgn_files)} PGN files")

    # Count up front so the progress bar has a total. An unreadable file is
    # reported and contributes zero instead of aborting the whole run.
    total_games = 0
    for pgn_file in pgn_files:
        try:
            total_games += count_games_in_file(pgn_file)
        except (OSError, UnicodeDecodeError) as e:
            print(f"Error counting games in {pgn_file}: {e}")
    print(f"Total games to process: {total_games}")

    all_results: List[str] = []
    futures = []
    batch_size = max(1, task_buffer)

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        with tqdm(total=total_games, desc="Processing games", unit="games") as pbar:

            def drain() -> None:
                """Collect every outstanding future, then reset the buffer."""
                for future in as_completed(futures):
                    try:
                        all_results.extend(future.result())
                    except Exception as e:
                        # Boundary handler: one failed game must not discard
                        # the results already gathered from the others.
                        print(f"Error: {e}")
                    pbar.update(1)
                futures.clear()

            for pgn_file in pgn_files:
                try:
                    for game_data in stream_pgn_games(pgn_file):
                        futures.append(
                            executor.submit(
                                process_single_game,
                                game_data,
                                positions_per_game,
                            )
                        )
                        if len(futures) >= batch_size:
                            drain()
                except (OSError, UnicodeDecodeError) as e:
                    # Skip the rest of an unreadable file; keep going with
                    # the remaining files.
                    print(f"Error processing file {pgn_file}: {e}")

            # Collect whatever is left from the final, partially filled batch.
            drain()

    return all_results
148150
149151
150- def save_positions_to_file (positions : List [str ], output_file : str = "sampled_positions.txt" ):
152+ # ------------------------------------------------------------
153+ # Output
154+ # ------------------------------------------------------------
155+
def save_positions_to_file(positions: List[str], output_file: str = "sampled_positions.txt"):
    """Write one position per line to ``output_file`` and report the count.

    Args:
        positions: Position lines to write; each gets a trailing newline.
        output_file: Destination path, overwritten if it exists. The default
            keeps one-argument callers working.
    """
    with open(output_file, "w", encoding="utf-8") as f:
        f.writelines(f"{line}\n" for line in positions)
    print(f"Saved {len(positions)} positions to {output_file}")
155161
156162
163+ # ------------------------------------------------------------
164+ # CLI
165+ # ------------------------------------------------------------
166+
def _prompt_positive_int(prompt: str, default: int) -> int:
    """Read an integer >= 1 from stdin, falling back to ``default``.

    Empty input, non-numeric input and values below 1 all yield ``default``.
    """
    try:
        value = int(input(prompt) or default)
    except ValueError:
        return default
    return value if value >= 1 else default


def main():
    """Interactive entry point: prompt for settings, extract, and save.

    Prompts for the PGN folder, the number of positions per game and the
    number of worker processes, runs the extraction, and writes the result
    to ``sampled_positions_<count>.txt`` when any positions were produced.
    """
    folder_path = input("Enter the path to the folder containing PGN files: ").strip()

    # Both values must be >= 1: ProcessPoolExecutor rejects a non-positive
    # worker count, and a non-positive sample size cannot yield positions.
    positions_per_game = _prompt_positive_int(
        "Enter number of positions per game (default 5): ", 5
    )
    max_workers = _prompt_positive_int(
        "Enter maximum number of worker processes (default 4): ", 4
    )

    try:
        positions = extract_positions_from_folder(
            folder_path,
            positions_per_game=positions_per_game,
            max_workers=max_workers,
        )
    except FileNotFoundError as e:
        # A mistyped folder is a user error, not a crash: report and stop.
        print(f"Error: {e}")
        return

    if positions:
        output_file = f"sampled_positions_{len(positions)}.txt"
        save_positions_to_file(positions, output_file)
    else:
        print("No positions extracted.")
183191
184192
185193if __name__ == "__main__" :
@@ -188,7 +196,7 @@ def main():
188196 import chess .pgn
189197 from tqdm import tqdm
190198 except ImportError :
191- print ("Please install required dependencies:" )
199+ print ("Please install dependencies:" )
192200 print ("pip install python-chess tqdm" )
193201 exit (1 )
194202
0 commit comments