99from ase import io
1010from torch_geometric .data import Data , InMemoryDataset
1111from torch_geometric .utils import dense_to_sparse
12+ from torch_geometric .transforms import Compose
1213from tqdm import tqdm
1314
1415from matdeeplearn .preprocessor .helpers import (
1819 get_cutoff_distance_matrix ,
1920)
2021
22+ from matdeeplearn .preprocessor .transforms import TRANSFORM_REGISTRY
23+
2124
2225def process_data (dataset_config ):
2326 root_path = dataset_config ["src" ]
@@ -41,6 +44,8 @@ def process_data(dataset_config):
4144 r = cutoff_radius ,
4245 n_neighbors = n_neighbors ,
4346 edge_steps = edge_steps ,
47+ otf = dataset_config .get ("otf" , False ),
48+ transforms = dataset_config .get ("transforms" , []),
4449 data_format = data_format ,
4550 image_selfloop = image_selfloop ,
4651 self_loop = self_loop ,
@@ -61,6 +66,8 @@ def __init__(
6166 r : float ,
6267 n_neighbors : int ,
6368 edge_steps : int ,
69+ otf : bool = False ,
70+ transforms : list = [],
6471 data_format : str = "json" ,
6572 image_selfloop : bool = True ,
6673 self_loop : bool = True ,
@@ -132,6 +139,9 @@ def __init__(
132139 self .verbose = verbose
133140 self .device = device
134141
142+ self .otf = otf
143+ self .transforms = transforms
144+
135145 self .disable_tqdm = logging .root .level > logging .INFO
136146
137147 def src_check (self ):
@@ -153,14 +163,17 @@ def ase_wrap(self):
153163 dict_structures = []
154164 ase_structures = []
155165
156- logging .info ("Converting data to standardized form for downstream processing." )
166+ logging .info (
167+ "Converting data to standardized form for downstream processing." )
157168 for i , structure_id in enumerate (file_names ):
158- p = os .path .join (self .root_path , str (structure_id ) + "." + self .data_format )
169+ p = os .path .join (self .root_path , str (
170+ structure_id ) + "." + self .data_format )
159171 ase_structures .append (ase .io .read (p ))
160172
161173 for i , s in enumerate (tqdm (ase_structures , disable = self .disable_tqdm )):
162174 d = {}
163- pos = torch .tensor (s .get_positions (), device = self .device , dtype = torch .float )
175+ pos = torch .tensor (s .get_positions (),
176+ device = self .device , dtype = torch .float )
164177 cell = torch .tensor (
165178 np .array (s .get_cell ()), device = self .device , dtype = torch .float
166179 )
@@ -173,7 +186,8 @@ def ase_wrap(self):
173186
174187 # add additional attributes
175188 if self .additional_attributes :
176- attributes = self .get_csv_additional_attributes (d ["structure_id" ])
189+ attributes = self .get_csv_additional_attributes (
190+ d ["structure_id" ])
177191 for k , v in attributes .items ():
178192 d [k ] = v
179193
@@ -189,9 +203,12 @@ def get_csv_additional_attributes(self, structure_id):
189203 attributes = {}
190204
191205 for attr in self .additional_attributes :
192- p = os .path .join (self .root_path , structure_id + "_" + attr + ".csv" )
193- values = np .genfromtxt (p , delimiter = "," , dtype = float , encoding = None )
194- values = torch .tensor (values , device = self .device , dtype = torch .float )
206+ p = os .path .join (self .root_path , structure_id +
207+ "_" + attr + ".csv" )
208+ values = np .genfromtxt (
209+ p , delimiter = "," , dtype = float , encoding = None )
210+ values = torch .tensor (
211+ values , device = self .device , dtype = torch .float )
195212 attributes [attr ] = values
196213
197214 return attributes
@@ -212,13 +229,17 @@ def json_wrap(self):
212229
213230 dict_structures = []
214231 y = []
215- y_dim = len (original_structures [0 ]["y" ]) if isinstance (original_structures [0 ]["y" ], list ) else 1
232+ y_dim = len (original_structures [0 ]["y" ]) if isinstance (
233+ original_structures [0 ]["y" ], list ) else 1
216234
217- logging .info ("Converting data to standardized form for downstream processing." )
235+ logging .info (
236+ "Converting data to standardized form for downstream processing." )
218237 for i , s in enumerate (tqdm (original_structures , disable = self .disable_tqdm )):
219238 d = {}
220- pos = torch .tensor (s ["positions" ], device = self .device , dtype = torch .float )
221- cell = torch .tensor (s ["cell" ], device = self .device , dtype = torch .float )
239+ pos = torch .tensor (
240+ s ["positions" ], device = self .device , dtype = torch .float )
241+ cell = torch .tensor (
242+ s ["cell" ], device = self .device , dtype = torch .float )
222243 atomic_numbers = torch .LongTensor (s ["atomic_numbers" ])
223244
224245 d ["positions" ] = pos
@@ -268,6 +289,7 @@ def get_data_list(self, dict_structures, y):
268289 data_list = [Data () for _ in range (n_structures )]
269290
270291 logging .info ("Getting torch_geometric.data.Data() objects." )
292+
271293 for i , sdict in enumerate (tqdm (dict_structures , disable = self .disable_tqdm )):
272294 target_val = y [i ]
273295 data = data_list [i ]
@@ -312,7 +334,26 @@ def get_data_list(self, dict_structures, y):
312334 generate_node_features (data_list , self .n_neighbors , device = self .device )
313335
314336 logging .info ("Generating edge features..." )
315- generate_edge_features (data_list , self .edge_steps , self .r , device = self .device )
337+ generate_edge_features (data_list , self .edge_steps ,
338+ self .r , device = self .device )
339+
340+ logging .info ("Applying transforms..." )
341+
342+ # saving line graph attributes through transforms
343+ transforms_list = []
344+
345+ if not self .otf :
346+ for transform in self .transforms :
347+ if transform in TRANSFORM_REGISTRY :
348+ transforms_list .append (TRANSFORM_REGISTRY [transform ]())
349+ else :
350+ raise ValueError (
351+ "No such transform found for {transform}" )
352+
353+ composition = Compose (transforms_list )
354+
355+ for data in data_list :
356+ composition (data )
316357
317358 clean_up (data_list , ["edge_descriptor" ])
318359
0 commit comments