-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathDOS_mae_calculation.py
More file actions
156 lines (124 loc) · 6.54 KB
/
DOS_mae_calculation.py
File metadata and controls
156 lines (124 loc) · 6.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""
Node-level (per-atom) DOS MAE calculation.
Three MAEs are reported:
1. gen_vs_gt : DOS predicted from generated structures vs. ground truth DOS
2. gt_pred_vs_gt : DOS predicted from ground truth structures vs. ground truth DOS (predictor baseline)
3. gen_vs_gt_pred : DOS predicted from generated structures vs. predicted from ground truth structures
"""
import argparse
import numpy as np
from pathlib import Path
from glob import glob
from tqdm import tqdm
from omegaconf import OmegaConf
import torch
from torch_geometric.loader import DataLoader
from dosmatgen.diffusion.property import CSPProperty
from dosmatgen.dataset.datamodule import CrystalDataModule, worker_init_fn
# ── model loading ──────────────────────────────────────────────────────────────
def load_predictor(root_path, device='cuda'):
    """
    Load a CSPProperty predictor from a checkpoint directory.

    The directory must contain ``hparams.yaml`` and exactly one ``*.ckpt`` file.

    Args:
        root_path: directory holding the predictor's ``hparams.yaml`` and checkpoint.
        device: device the model is moved to (default ``'cuda'``, as before).

    Returns:
        The loaded predictor, moved to ``device`` and switched to eval mode.

    Raises:
        ValueError: if the directory does not contain exactly one ``*.ckpt`` file.
    """
    root_path = Path(root_path)
    config = OmegaConf.load(root_path / 'hparams.yaml')
    # Path.glob only interprets the '*.ckpt' pattern, so '[', ']', '*' or '?'
    # inside the directory name are taken literally; glob(str(root / '*.ckpt'))
    # would treat them as wildcards and could miss the checkpoint. Hidden files
    # are skipped to match glob.glob's default behaviour.
    ckpts = sorted(
        p for p in root_path.glob('*.ckpt') if not p.name.startswith('.')
    )
    if len(ckpts) != 1:
        raise ValueError(f"Expected 1 checkpoint in {root_path}, found {len(ckpts)}")
    predictor = CSPProperty.load_from_checkpoint(str(ckpts[0]), config=config)
    predictor.to(device)
    predictor.eval()
    return predictor
def load_data_module(model_path, test_path, batch_size):
    """
    Build a CrystalDataModule from the diffusion model's saved config, pointed at a new test set.

    Args:
        model_path: directory of the trained diffusion model; its ``hparams.yaml``
            supplies the config and the directory itself is passed as ``scaler_path``.
        test_path: dataset path written into the test-dataset config.
        batch_size: batch size written into the test batch-size config.

    Returns:
        The data module after ``setup(stage="test")``.
    """
    model_dir = Path(model_path)
    cfg = OmegaConf.load(model_dir / 'hparams.yaml')
    cfg.datamodule.batch_size.test = batch_size

    test_cfg = cfg.datamodule.datasets.test
    test_cfg.dataset_path = test_path
    # The preprocessed cache is stored beside the dataset, with a .pt suffix.
    test_cfg.save_path = str(Path(test_path).with_suffix('.pt'))

    dm = CrystalDataModule(cfg, scaler_path=str(model_dir))
    dm.setup(stage="test")
    return dm
# ── DOS extraction ─────────────────────────────────────────────────────────────
def predict_dos(loader, predictor, scaler, device='cuda'):
    """
    Run the predictor on every structure in ``loader`` and collect per-atom DOS.

    Args:
        loader: iterable of batches exposing ``structure_id`` (one id per
            structure) and ``num_atoms`` (atom count per structure).
        predictor: model whose ``infer(batch)`` returns a pair whose first
            element is the node-level prediction, one row per atom
            (presumably [total_atoms, dos_dim] — confirm against CSPProperty.infer).
        scaler: object whose ``inverse_transform`` maps scaled predictions back
            to DOS values.
        device: device each batch is moved to before inference
            (default ``'cuda'``, as before).

    Returns:
        {structure_id: np.ndarray [num_atoms, dos_dim]}. If a structure_id occurs
        more than once, the last prediction wins.
    """
    dos_dict = {}
    # Pure inference: without no_grad the forward pass records an autograd graph
    # (wasted GPU memory), and .numpy() raises on tensors that require grad.
    # detach() also covers an infer() that re-enables grad internally.
    with torch.no_grad():
        for batch in tqdm(loader, desc="Predicting DOS"):
            batch = batch.to(device)
            pred_node, _ = predictor.infer(batch)
            pred_node = scaler.inverse_transform(pred_node.detach().cpu())
            # Rows are laid out structure by structure, so consecutive slices of
            # num_atoms rows belong to consecutive structures in the batch.
            offset = 0
            for sid, n in zip(batch.structure_id, batch.num_atoms.tolist()):
                dos_dict[sid] = pred_node[offset:offset + n].numpy()
                offset += n
    return dos_dict
def get_gt_dos(test_loader, scaler):
    """
    Collect the ground-truth per-atom DOS stored in ``batch.y`` for every structure.

    Args:
        test_loader: iterable of batches exposing ``y`` (scaled node targets),
            ``structure_id`` and ``num_atoms``.
        scaler: object whose ``inverse_transform`` undoes the target scaling.

    Returns:
        {structure_id: np.ndarray [num_atoms, dos_dim]}.
    """
    gt_dict = {}
    for batch in tqdm(test_loader, desc="Loading ground truth DOS"):
        # Undo target scaling so values are comparable with predictor output.
        unscaled = scaler.inverse_transform(batch.y.cpu())
        start = 0
        for structure_id, atom_count in zip(batch.structure_id, batch.num_atoms.tolist()):
            stop = start + atom_count
            gt_dict[structure_id] = unscaled[start:stop].numpy()
            start = stop
    return gt_dict
# ── MAE ────────────────────────────────────────────────────────────────────────
def compute_mae(pred_dict, ref_dict):
    """
    Mean L1 error between two per-atom DOS collections, averaged over structures.

    For each structure_id present in both dicts whose arrays have identical
    shapes, the L1 loss averaged over all atoms and DOS bins is computed. The
    result is the unweighted mean of these per-structure losses, so every
    structure counts equally regardless of its atom count.

    Args:
        pred_dict: {structure_id: array [num_atoms, dos_dim]} predictions.
        ref_dict: {structure_id: array [num_atoms, dos_dim]} reference values.

    Returns:
        The mean per-structure MAE as a float, or ``nan`` if no structure
        could be compared.

    Warns:
        UserWarning if structures were skipped (id absent from ``ref_dict`` or
        shape mismatch), and if nothing could be compared at all.
    """
    import warnings

    loss_fn = torch.nn.L1Loss()
    maes = []
    n_missing = 0
    n_shape_mismatch = 0
    for sid, pred_arr in pred_dict.items():
        if sid not in ref_dict:
            n_missing += 1
            continue
        pred = torch.as_tensor(pred_arr)
        ref = torch.as_tensor(ref_dict[sid])
        if pred.shape != ref.shape:
            # e.g. a generated structure whose atom count differs from ground truth
            n_shape_mismatch += 1
            continue
        maes.append(loss_fn(pred, ref).item())

    if n_missing or n_shape_mismatch:
        warnings.warn(
            f"compute_mae: skipped {n_missing} structure(s) absent from the "
            f"reference and {n_shape_mismatch} with mismatched shapes; "
            f"compared {len(maes)}",
            stacklevel=2,
        )
    if not maes:
        # np.mean([]) would also yield nan, but with an opaque RuntimeWarning.
        warnings.warn(
            "compute_mae: no structures could be compared; returning nan",
            stacklevel=2,
        )
        return float('nan')
    return float(np.mean(maes))
# ── main ───────────────────────────────────────────────────────────────────────
def main(args):
    """
    Compute and print the three node-level DOS MAEs listed in the module docstring.

    Args:
        args: parsed CLI namespace with ``test_path``, ``gen_path``,
            ``model_path``, ``root_path`` and ``batch_size``.
    """
    # data module (config + scalers from diffusion model; test data from test_path)
    print("Loading data module...")
    data_module = load_data_module(args.model_path, args.test_path, args.batch_size)
    test_loader = data_module.test_dataloader()
    scaler = data_module.scaler

    # predictor model
    print("Loading predictor...")
    predictor = load_predictor(args.root_path)

    # ground truth DOS from test structures
    gt_dos = get_gt_dos(test_loader, scaler)
    # DOS predicted from ground truth structures (predictor quality baseline)
    gt_pred_dos = predict_dos(test_loader, predictor, scaler)

    # DOS predicted from generated structures
    print("Loading generated structures...")
    gen_loader = data_module.get_adhoc_dataloader(args.gen_path, batch_size=args.batch_size)
    gen_pred_dos = predict_dos(gen_loader, predictor, scaler)

    # MAEs
    mae_gen_vs_gt = compute_mae(gen_pred_dos, gt_dos)
    mae_gt_pred_vs_gt = compute_mae(gt_pred_dos, gt_dos)
    mae_gen_vs_gt_pred = compute_mae(gen_pred_dos, gt_pred_dos)

    n_gen = len(gen_pred_dos)
    n_gt = len(gt_dos)
    # compute_mae only compares structure_ids present in both dicts, so report
    # how many generated structures share an id with the ground truth;
    # otherwise the generated-vs-GT MAE cannot be interpreted.
    n_gen_matched = len(gen_pred_dos.keys() & gt_dos.keys())
    print(f"\nStructures — ground truth: {n_gt}, generated: {n_gen}")
    print(f"Generated structures sharing a structure_id with ground truth: {n_gen_matched}")
    print("\nPer-atom mean L1 MAE:")
    print(f" Gen pred vs. GT : {mae_gen_vs_gt:.6f}")
    print(f" GT pred vs. GT (baseline): {mae_gt_pred_vs_gt:.6f}")
    print(f" Gen pred vs. GT pred : {mae_gen_vs_gt_pred:.6f}")
if __name__ == '__main__':
    # Command-line entry point: all four path options are mandatory.
    cli = argparse.ArgumentParser(
        description="Compute node-level DOS MAE for generated vs. ground-truth structures."
    )
    path_options = (
        ('--test_path', "Path to test dataset (JSON or directory) with ground truth structures and DOSes"),
        ('--gen_path', "Path to directory of generated CIF structures"),
        ('--model_path', "Path to trained diffusion model (provides config and scalers)"),
        ('--root_path', "Path to CSPProperty predictor model checkpoint directory"),
    )
    for flag, help_text in path_options:
        cli.add_argument(flag, type=str, required=True, help=help_text)
    cli.add_argument('--batch_size', type=int, default=100)
    main(cli.parse_args())