Skip to content

Commit 85e20e4

Browse files
committed
Add in pre-training transform functionality
1 parent 2168d2d commit 85e20e4

File tree

26 files changed

+138605
-87
lines changed

26 files changed

+138605
-87
lines changed

configs/config_alignn.yml

Lines changed: 42 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,52 +3,53 @@ trainer: property
33

44
task:
55
# run_mode: train
6-
name: "alignn_first_training"
6+
identifier: "alignn_train_job"
77

8-
reprocess: "False"
8+
reprocess: False
99

10-
parallel: "True"
10+
11+
parallel: True
1112
seed: 0
1213
#seed=0 means random initialization
1314

14-
write_output: "True"
15-
parallel: "True"
15+
16+
write_output: True
17+
parallel: True
1618
#Training print out frequency (print per n number of epochs)
1719
verbosity: 1
1820

19-
#Ratios for train/val/test split out of a total of 1
20-
train_ratio: 0.85
21-
val_ratio: 0.05
22-
test_ratio: 0.10
2321

24-
model:
25-
name: "ALIGNN_GRAPHITE"
26-
load_model: "False"
27-
save_model: "True"
28-
model_path: "/global/cfs/projectdirs/m3641/Sidharth/MatDeepLearn_dev/testing/models/alignn_model_t1.pth"
2922

23+
model:
24+
name: ALIGNN
25+
load_model: False
26+
save_model: True
27+
model_path: "my_model.pth"
28+
edge_steps: 50
29+
self_loop: True
3030
#model attributes
31-
alignn_layers: 4
32-
gcn_layers: 4
33-
atom_input_features: 114
34-
edge_input_features: 50
35-
triplet_input_features: 40
36-
embedding_features: 32
37-
hidden_features: 64
38-
output_features: 1
39-
# min_edge_distance: 0.0,
40-
# max_edge_distance: 8.0,
41-
# min_angle: 0.0,
42-
# max_angle: torch.acos(torch.zeros(1)).item() * 2,
43-
link: "identity"
31+
dim1: 100
32+
dim2: 150
33+
pre_fc_count: 1
34+
gc_count: 4
35+
post_fc_count: 3
36+
pool: "global_mean_pool"
37+
pool_order: "early"
38+
batch_norm: True
39+
batch_track_stats: True
40+
act: "relu"
41+
dropout_rate: 0.0
4442

4543
optim:
46-
max_epochs: 300
44+
max_epochs: 250
4745
lr: 0.001
48-
#Loss functions (from pytorch) examples: l1_loss, mse_loss, binary_cross_entropy
49-
loss_fn: "mse_loss"
46+
#Either custom or from torch.nn.functional library. If from torch, loss_type is TorchLossWrapper
47+
loss:
48+
loss_type: "TorchLossWrapper"
49+
loss_args: {"loss_fn": "mse_loss"}
50+
5051
batch_size: 64
51-
52+
5253
optimizer:
5354
optimizer_type: "AdamW"
5455
optimizer_args: {"weight_decay": 0.00001}
@@ -67,16 +68,18 @@ dataset:
6768
target_path: "/global/cfs/projectdirs/m3641/Shared/Materials_datasets/MP_data_69K/targets.csv"
6869
#Path to save processed data.pt file (a directory path not filepath)
6970
pt_path: "/global/cfs/projectdirs/m3641/Sidharth/datasets/MP_data_69K/"
71+
otf: False
7072
transforms:
7173
- NumNodeTransform
7274
- LineGraphMod
7375
- ToFloat
7476
#Format of data files (limit to those supported by ASE)
7577
data_format: "json"
76-
#Method of obtaining atom dictionary: available:(one-hot)
78+
#Method of obtaining atom dictionary: available: (onehot)
7779
node_representation: "onehot"
80+
additional_attributes: []
7881
#Print out processing info
79-
verbose: "True"
82+
verbose: True
8083

8184
#Loading dataset params
8285
#Index of target column in targets.csv
@@ -85,4 +88,9 @@ dataset:
8588
#graph specific settings
8689
cutoff_radius : 8.0
8790
n_neighbors : 12
88-
edge_steps : 50
91+
edge_steps : 50
92+
93+
#Ratios for train/val/test split out of a total of 1
94+
train_ratio: 0.8
95+
val_ratio: 0.05
96+
test_ratio: 0.15

matdeeplearn/common/data.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def dataset_split(
5959

6060

6161
def get_dataset(
62-
data_path, target_index: int = 0, transform_list=[], large_dataset=False
62+
data_path, target_index: int = 0, transform_list=[], otf=False, large_dataset=False
6363
):
6464
"""
6565
get dataset according to data_path
@@ -81,15 +81,16 @@ def get_dataset(
8181
8282
transform_list: transformation function/classes to be applied
8383
"""
84-
84+
8585
transforms = [GetY(index=target_index)]
8686

8787
# set transform method
88-
for transform in transform_list:
89-
if transform in TRANSFORM_REGISTRY:
90-
transforms.append(TRANSFORM_REGISTRY[transform]())
91-
else:
92-
raise ValueError("No such transform found for {transform}")
88+
if otf:
89+
for transform in transform_list:
90+
if transform in TRANSFORM_REGISTRY:
91+
transforms.append(TRANSFORM_REGISTRY[transform]())
92+
else:
93+
raise ValueError("No such transform found for {transform}")
9394

9495
# check if large dataset is needed
9596
if large_dataset:
@@ -98,7 +99,7 @@ def get_dataset(
9899
Dataset = StructureDataset
99100

100101
transform = Compose(transforms)
101-
102+
102103
return Dataset(data_path, processed_data_path="", transform=transform)
103104

104105

matdeeplearn/models/alignn.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from torch_geometric.transforms import Compose
88
from matdeeplearn.common.registry import registry
99
from matdeeplearn.models.base_model import BaseModel
10-
from matdeeplearn.preprocessor.transforms import NumNodeTransform, LineGraphMod, ToFloat
1110
from typing import Optional, Literal
1211
import numpy as np
1312
import contextlib
@@ -362,14 +361,6 @@ def target_attr(self):
362361
return "y"
363362

364363
def forward(self, g: Data):
365-
# Compute OTF transform to generate attributes for L(g)
366-
367-
# with prof_ctx():
368-
369-
with torch.no_grad():
370-
otf = Compose([NumNodeTransform(), LineGraphMod(), ToFloat()])
371-
otf(g)
372-
373364
# initial node features
374365
node_feats = self.atom_embedding(g.x)
375366
# initial bond features

matdeeplearn/models/alignn_graphite.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,10 @@ def __init__(self, dim=100, num_interactions=6, num_species=3, cutoff=3.0, **kwa
116116
)
117117

118118
self.reset_parameters()
119+
120+
@property
121+
def target_attr(self):
122+
return "y"
119123

120124
def reset_parameters(self):
121125
self.embed_atm.reset_parameters()
@@ -127,14 +131,10 @@ def embed_ang(self, x_ang):
127131
cos_ang = torch.cos(x_ang)
128132
return gaussian(cos_ang, start=-1, end=1, num_basis=self.dim)
129133

130-
def forward(self, data: Data):
131-
with torch.no_grad():
132-
otf = Compose([NumNodeTransform(), LineGraphMod(), ToFloat()])
133-
otf(data)
134-
134+
def forward(self, data: Data):
135135
edge_index_G = data.edge_index
136136
edge_index_A = data.edge_index_lg
137-
h_atm = self.embed_atm(data.x)
137+
h_atm = self.embed_atm(data.x.type(torch.long))
138138
h_bnd = self.embed_bnd(data.edge_attr)
139139
h_ang = self.embed_ang(data.edge_attr_lg)
140140

matdeeplearn/preprocessor/helpers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,7 @@ def generate_node_features(input_data, n_neighbors, device):
279279

280280
for i, data in enumerate(input_data):
281281
input_data[i] = one_hot_degree(data, n_neighbors+1)
282+
282283

283284
def generate_edge_features(input_data, edge_steps, r, device):
284285
distance_gaussian = GaussianSmearing(0, 1, edge_steps, 0.2, device=device)
@@ -333,7 +334,7 @@ def compute_bond_angles(pos: torch.Tensor, offsets: torch.Tensor, edge_index: to
333334

334335
# Calculate triplets
335336
idx_i, idx_j, idx_k, idx_kj, idx_ji = triplets(
336-
edge_index, offsets, num_nodes)
337+
edge_index, offsets.to(device=edge_index.device), num_nodes)
337338

338339
# Calculate angles.
339340
pos_i = pos[idx_i]

matdeeplearn/preprocessor/processor.py

Lines changed: 53 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from ase import io
1010
from torch_geometric.data import Data, InMemoryDataset
1111
from torch_geometric.utils import dense_to_sparse
12+
from torch_geometric.transforms import Compose
1213
from tqdm import tqdm
1314

1415
from matdeeplearn.preprocessor.helpers import (
@@ -18,6 +19,8 @@
1819
get_cutoff_distance_matrix,
1920
)
2021

22+
from matdeeplearn.preprocessor.transforms import TRANSFORM_REGISTRY
23+
2124

2225
def process_data(dataset_config):
2326
root_path = dataset_config["src"]
@@ -41,6 +44,8 @@ def process_data(dataset_config):
4144
r=cutoff_radius,
4245
n_neighbors=n_neighbors,
4346
edge_steps=edge_steps,
47+
otf=dataset_config.get("otf", False),
48+
transforms=dataset_config.get("transforms", []),
4449
data_format=data_format,
4550
image_selfloop=image_selfloop,
4651
self_loop=self_loop,
@@ -61,6 +66,8 @@ def __init__(
6166
r: float,
6267
n_neighbors: int,
6368
edge_steps: int,
69+
otf: bool = False,
70+
transforms: list = [],
6471
data_format: str = "json",
6572
image_selfloop: bool = True,
6673
self_loop: bool = True,
@@ -132,6 +139,9 @@ def __init__(
132139
self.verbose = verbose
133140
self.device = device
134141

142+
self.otf = otf
143+
self.transforms = transforms
144+
135145
self.disable_tqdm = logging.root.level > logging.INFO
136146

137147
def src_check(self):
@@ -153,14 +163,17 @@ def ase_wrap(self):
153163
dict_structures = []
154164
ase_structures = []
155165

156-
logging.info("Converting data to standardized form for downstream processing.")
166+
logging.info(
167+
"Converting data to standardized form for downstream processing.")
157168
for i, structure_id in enumerate(file_names):
158-
p = os.path.join(self.root_path, str(structure_id) + "." + self.data_format)
169+
p = os.path.join(self.root_path, str(
170+
structure_id) + "." + self.data_format)
159171
ase_structures.append(ase.io.read(p))
160172

161173
for i, s in enumerate(tqdm(ase_structures, disable=self.disable_tqdm)):
162174
d = {}
163-
pos = torch.tensor(s.get_positions(), device=self.device, dtype=torch.float)
175+
pos = torch.tensor(s.get_positions(),
176+
device=self.device, dtype=torch.float)
164177
cell = torch.tensor(
165178
np.array(s.get_cell()), device=self.device, dtype=torch.float
166179
)
@@ -173,7 +186,8 @@ def ase_wrap(self):
173186

174187
# add additional attributes
175188
if self.additional_attributes:
176-
attributes = self.get_csv_additional_attributes(d["structure_id"])
189+
attributes = self.get_csv_additional_attributes(
190+
d["structure_id"])
177191
for k, v in attributes.items():
178192
d[k] = v
179193

@@ -189,9 +203,12 @@ def get_csv_additional_attributes(self, structure_id):
189203
attributes = {}
190204

191205
for attr in self.additional_attributes:
192-
p = os.path.join(self.root_path, structure_id + "_" + attr + ".csv")
193-
values = np.genfromtxt(p, delimiter=",", dtype=float, encoding=None)
194-
values = torch.tensor(values, device=self.device, dtype=torch.float)
206+
p = os.path.join(self.root_path, structure_id +
207+
"_" + attr + ".csv")
208+
values = np.genfromtxt(
209+
p, delimiter=",", dtype=float, encoding=None)
210+
values = torch.tensor(
211+
values, device=self.device, dtype=torch.float)
195212
attributes[attr] = values
196213

197214
return attributes
@@ -212,13 +229,17 @@ def json_wrap(self):
212229

213230
dict_structures = []
214231
y = []
215-
y_dim = len(original_structures[0]["y"]) if isinstance(original_structures[0]["y"], list) else 1
232+
y_dim = len(original_structures[0]["y"]) if isinstance(
233+
original_structures[0]["y"], list) else 1
216234

217-
logging.info("Converting data to standardized form for downstream processing.")
235+
logging.info(
236+
"Converting data to standardized form for downstream processing.")
218237
for i, s in enumerate(tqdm(original_structures, disable=self.disable_tqdm)):
219238
d = {}
220-
pos = torch.tensor(s["positions"], device=self.device, dtype=torch.float)
221-
cell = torch.tensor(s["cell"], device=self.device, dtype=torch.float)
239+
pos = torch.tensor(
240+
s["positions"], device=self.device, dtype=torch.float)
241+
cell = torch.tensor(
242+
s["cell"], device=self.device, dtype=torch.float)
222243
atomic_numbers = torch.LongTensor(s["atomic_numbers"])
223244

224245
d["positions"] = pos
@@ -268,6 +289,7 @@ def get_data_list(self, dict_structures, y):
268289
data_list = [Data() for _ in range(n_structures)]
269290

270291
logging.info("Getting torch_geometric.data.Data() objects.")
292+
271293
for i, sdict in enumerate(tqdm(dict_structures, disable=self.disable_tqdm)):
272294
target_val = y[i]
273295
data = data_list[i]
@@ -312,7 +334,26 @@ def get_data_list(self, dict_structures, y):
312334
generate_node_features(data_list, self.n_neighbors, device=self.device)
313335

314336
logging.info("Generating edge features...")
315-
generate_edge_features(data_list, self.edge_steps, self.r, device=self.device)
337+
generate_edge_features(data_list, self.edge_steps,
338+
self.r, device=self.device)
339+
340+
logging.info("Applying transforms...")
341+
342+
# saving line graph attributes through transforms
343+
transforms_list = []
344+
345+
if not self.otf:
346+
for transform in self.transforms:
347+
if transform in TRANSFORM_REGISTRY:
348+
transforms_list.append(TRANSFORM_REGISTRY[transform]())
349+
else:
350+
raise ValueError(
351+
"No such transform found for {transform}")
352+
353+
composition = Compose(transforms_list)
354+
355+
for data in data_list:
356+
composition(data)
316357

317358
clean_up(data_list, ["edge_descriptor"])
318359

0 commit comments

Comments
 (0)