Commit fc08f7f

Make ESMFold data prep scripts more generic
1 parent c3b8f6d commit fc08f7f

12 files changed: +76 -73 lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
@@ -126,7 +126,7 @@ repos:
     hooks:
       - id: codespell
         args:
-          - --skip=logs/**,data/**,*.ipynb,posebench/utils/data_utils.py,posebench/utils/residue_utils.py,posebench/data/components/esmfold_fasta_preparation.py,posebench/models/minimize_energy.py,posebench/data/components/create_casp15_ensemble_input_csv.py,posebench/analysis/casp15_ligand_scoring/casp_parser.py
+          - --skip=logs/**,data/**,*.ipynb,posebench/utils/data_utils.py,posebench/utils/residue_utils.py,posebench/data/components/protein_fasta_preparation.py,posebench/models/minimize_energy.py,posebench/data/components/create_casp15_ensemble_input_csv.py,posebench/analysis/casp15_ligand_scoring/casp_parser.py
           # - --ignore-words-list=abc,def

   # jupyter notebook cell output clearing

README.md

Lines changed: 6 additions & 6 deletions
@@ -235,18 +235,18 @@ tar xfz pdb100_2021Mar03.tar.gz
 cd ../../
 ```

-### Predicting apo protein structures using ESMFold
+### Predicting apo protein structures using ESMFold (optional, preprocessed data available)

 First create all the corresponding FASTA files for each protein sequence

 ```bash
-python3 posebench/data/components/esmfold_fasta_preparation.py dataset=posebusters_benchmark
-python3 posebench/data/components/esmfold_fasta_preparation.py dataset=astex_diverse
+python3 posebench/data/components/protein_fasta_preparation.py dataset=posebusters_benchmark
+python3 posebench/data/components/protein_fasta_preparation.py dataset=astex_diverse
 ```

 To generate the apo version of each protein structure,
 create ESMFold-ready versions of the combined FASTA files
-prepared above by the script `esmfold_fasta_preparation.py`
+prepared above by the script `protein_fasta_preparation.py`
 for the PoseBusters Benchmark and Astex Diverse sets, respectively

 ```bash
@@ -272,8 +272,8 @@ or Astex Diverse set, taking ligand conformations into account
 during each alignment

 ```bash
-python3 posebench/data/components/esmfold_apo_to_holo_alignment.py dataset=posebusters_benchmark num_workers=1
-python3 posebench/data/components/esmfold_apo_to_holo_alignment.py dataset=astex_diverse num_workers=1
+python3 posebench/data/components/protein_apo_to_holo_alignment.py dataset=posebusters_benchmark num_workers=1
+python3 posebench/data/components/protein_apo_to_holo_alignment.py dataset=astex_diverse num_workers=1
 ```

 **NOTE:** The preprocessed Astex Diverse, PoseBusters Benchmark, DockGen, and CASP15 data available via [Zenodo](https://doi.org/10.5281/zenodo.11477766) provide pre-holo-aligned predicted protein structures for these respective datasets.

configs/data/components/esmfold_apo_to_holo_alignment.yaml

Lines changed: 0 additions & 5 deletions
This file was deleted.

configs/data/components/protein_apo_to_holo_alignment.yaml

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+dataset: posebusters_benchmark # the dataset to use - NOTE: must be one of (`posebusters_benchmark`, `astex_diverse`, `dockgen`, `casp15`)
+data_dir: ${oc.env:PROJECT_ROOT}/data/${dataset}_set/ # where the processed datasets (e.g., PoseBusters Benchmark) are placed
+predicted_structures_dir: ${oc.env:PROJECT_ROOT}/data/${dataset}_set/${dataset}_predicted_structures # where the predicted protein structures are placed
+output_dir: ${oc.env:PROJECT_ROOT}/data/${dataset}_set/${dataset}_holo_aligned_predicted_structures # where the holo-aligned predicted apo structures should be stored
+num_workers: 1 # number of CPU workers for parallel processing
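
The new config derives every directory from two inputs: the `PROJECT_ROOT` environment variable (read through OmegaConf's `oc.env` resolver) and the `dataset` key. The following is a minimal sketch of the paths this produces, written in plain Python rather than with OmegaConf itself. The helper `resolve_alignment_paths` and the root `/opt/PoseBench` are hypothetical and exist only for illustration.

```python
import posixpath


def resolve_alignment_paths(project_root: str, dataset: str) -> dict:
    """Hand-expand the `${oc.env:PROJECT_ROOT}` and `${dataset}` interpolations.

    Hypothetical helper: the real values are resolved by Hydra/OmegaConf at runtime.
    """
    set_dir = posixpath.join(project_root, "data", f"{dataset}_set")
    return {
        # The YAML value for `data_dir` ends with a trailing slash.
        "data_dir": set_dir + "/",
        "predicted_structures_dir": posixpath.join(set_dir, f"{dataset}_predicted_structures"),
        "output_dir": posixpath.join(
            set_dir, f"{dataset}_holo_aligned_predicted_structures"
        ),
    }


# "/opt/PoseBench" is a made-up project root used only for this example.
paths = resolve_alignment_paths("/opt/PoseBench", "astex_diverse")
for key, value in paths.items():
    print(f"{key}: {value}")
```

Because the directory names no longer mention ESMFold, structures from any predictor can be placed under `<dataset>_predicted_structures` without touching the config.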

configs/data/components/esmfold_fasta_preparation.yaml renamed to configs/data/components/protein_fasta_preparation.yaml

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 dataset: posebusters_benchmark # the dataset to use - NOTE: must be one of (`posebusters_benchmark`, `astex_diverse`)
-data_dir: ${oc.env:PROJECT_ROOT}/data/${dataset}_set # where the processed PoseBusters Benchmark and Astex Diverse sets are placed
+data_dir: ${oc.env:PROJECT_ROOT}/data/${dataset}_set # where the processed datasets (e.g., PoseBusters Benchmark) are placed
 out_file: ${oc.env:PROJECT_ROOT}/data/${dataset}_set/${dataset}_sequences.fasta # the output FASTA file to produce

docs/source/configs/data.rst

Lines changed: 6 additions & 6 deletions
@@ -9,17 +9,17 @@ Input data components

 These data component configurations are used to modify how the input (apo) protein structures are predicted or aligned.

-ESMFold apo-to-holo alignment
+Protein apo-to-holo alignment
 ^^^^^^^^^^^^^^^^^^^^^^^^
-.. literalinclude:: ../../../configs/data/components/esmfold_apo_to_holo_alignment.yaml
+.. literalinclude:: ../../../configs/data/components/protein_apo_to_holo_alignment.yaml
    :language: yaml
-   :caption: :file:`data/components/esmfold_apo_to_holo_alignment.yaml`
+   :caption: :file:`data/components/protein_apo_to_holo_alignment.yaml`

-ESMFold FASTA preparation
+Protein FASTA preparation
 ^^^^^^^^^^^^^^^^^^^^^^^^
-.. literalinclude:: ../../../configs/data/components/esmfold_fasta_preparation.yaml
+.. literalinclude:: ../../../configs/data/components/protein_fasta_preparation.yaml
    :language: yaml
-   :caption: :file:`data/components/esmfold_fasta_preparation.yaml`
+   :caption: :file:`data/components/protein_fasta_preparation.yaml`

 ESMFold sequence preparation
 ^^^^^^^^^^^^^^^^^^^^^^^^

notebooks/adding_new_dataset_tutorial.ipynb

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@
     "\n",
     "1. Create a new directory under `data/` with the required suffix `_set` (e.g., `data/newest_set/`) and group your (ground-truth) data files by unique IDs within this new directory (e.g., `data/newest_set/1G9V_RQ4/1G9V_RQ4_{protein.pdb,ligand.sdf}`)\n",
     "2. Update the config files throughout `configs/analysis/`, `configs/data/`, and `configs/model/` to list your new dataset as a CLI argument (e.g., `dataset: newest`)\n",
-    "3. Predict `apo` protein structures for your new dataset using ESMFold by integrating parsing for your dataset into the ESMFold-related source code within `src/data/components/esmfold_fasta_preparation.py` and `src/data/components/esmfold_apo_to_holo_alignment.py`\n",
+    "3. Predict `apo` protein structures for your new dataset using a structure predictor of your choice (e.g., ESMFold) by integrating parsing for your dataset into the prediction-related source code within `src/data/components/protein_fasta_preparation.py` and `src/data/components/protein_apo_to_holo_alignment.py`\n",
     "4. Using `notebooks/posebusters_astex_inference_results_plotting.ipynb` as a template, add a new Jupyter notebook to `notebooks/` for plotting each method's results on your new dataset (after preparing each method's dataset inputs and running inference with each desired method)"
    ]
   },

posebench/analysis/complex_alignment.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@
 rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

 from posebench import register_custom_omegaconf_resolvers
-from posebench.data.components.esmfold_apo_to_holo_alignment import (
+from posebench.data.components.protein_apo_to_holo_alignment import (
     align_prediction,
     extract_receptor_structure,
     parse_pdb_from_path,

posebench/data/components/esmfold_apo_to_holo_alignment.py renamed to posebench/data/components/protein_apo_to_holo_alignment.py

Lines changed: 49 additions & 47 deletions
@@ -238,7 +238,7 @@ def extract_receptor_structure(
 def align_prediction(
     smoothing_factor: Optional[float],
     dataset_calpha_coords: np.ndarray,
-    esmfold_calpha_coords: np.ndarray,
+    predicted_calpha_coords: np.ndarray,
     dataset_ligand_coords: Optional[np.ndarray],
     return_rotation: bool = False,
 ) -> Union[Tuple[Rotation, np.ndarray, np.ndarray], float]:
@@ -247,42 +247,42 @@ def align_prediction(

     :param smoothing_factor: Smoothing factor controlling the alignment.
     :param dataset_calpha_coords: Array of Ca atom coordinates for a dataset's protein structure.
-    :param esmfold_calpha_coords: Array of Ca atom coordinates for a dataset's protein structure.
+    :param predicted_calpha_coords: Array of Ca atom coordinates for a dataset's protein structure.
     :param dataset_ligand_coords: Array of ligand coordinates from a dataset.
     :param return_rotation: Whether to return the rotation matrix and centroids (default: `False`).
     :return: If return_rotation is `True`, returns a tuple containing rotation matrix (`Rotation`), centroid of CA atoms for a dataset protein (`np.ndarray`),
-        and centroid of CA atoms for ESMFold (`np.ndarray`). If return_rotation is `False`, returns the inverse root mean square error of reciprocal distances (`float`).
+        and centroid of CA atoms for a prediction (`np.ndarray`). If return_rotation is `False`, returns the inverse root mean square error of reciprocal distances (`float`).
     """
     if dataset_ligand_coords is not None:
         dataset_dists = spa.distance.cdist(dataset_calpha_coords, dataset_ligand_coords)
         weights = np.exp(-1 * smoothing_factor * np.amin(dataset_dists, axis=1))
         dataset_calpha_centroid = np.sum(
             np.expand_dims(weights, axis=1) * dataset_calpha_coords, axis=0
         ) / np.sum(weights)
-        esmfold_calpha_centroid = np.sum(
-            np.expand_dims(weights, axis=1) * esmfold_calpha_coords, axis=0
+        predicted_calpha_centroid = np.sum(
+            np.expand_dims(weights, axis=1) * predicted_calpha_coords, axis=0
         ) / np.sum(weights)
     else:
         weights = None
         dataset_calpha_centroid = np.mean(dataset_calpha_coords, axis=0)
-        esmfold_calpha_centroid = np.mean(esmfold_calpha_coords, axis=0)
+        predicted_calpha_centroid = np.mean(predicted_calpha_coords, axis=0)
     centered_dataset_calpha_coords = dataset_calpha_coords - dataset_calpha_centroid
-    centered_esmfold_calpha_coords = esmfold_calpha_coords - esmfold_calpha_centroid
+    centered_predicted_calpha_coords = predicted_calpha_coords - predicted_calpha_centroid

     rotation, _ = spa.transform.Rotation.align_vectors(
-        centered_dataset_calpha_coords, centered_esmfold_calpha_coords, weights
+        centered_dataset_calpha_coords, centered_predicted_calpha_coords, weights
     )
     if return_rotation:
-        return rotation, dataset_calpha_centroid, esmfold_calpha_centroid
+        return rotation, dataset_calpha_centroid, predicted_calpha_centroid

     if dataset_ligand_coords is not None:
         centered_dataset_ligand_coords = dataset_ligand_coords - dataset_calpha_centroid
-        aligned_esmfold_calpha_coords = rotation.apply(centered_esmfold_calpha_coords)
-        aligned_esmfold_dataset_dists = spa.distance.cdist(
-            aligned_esmfold_calpha_coords, centered_dataset_ligand_coords
+        aligned_predicted_calpha_coords = rotation.apply(centered_predicted_calpha_coords)
+        aligned_predicted_dataset_dists = spa.distance.cdist(
+            aligned_predicted_calpha_coords, centered_dataset_ligand_coords
         )
         inv_r_rmse = np.sqrt(
-            np.mean(((1 / dataset_dists) - (1 / aligned_esmfold_dataset_dists)) ** 2)
+            np.mean(((1 / dataset_dists) - (1 / aligned_predicted_dataset_dists)) ** 2)
         )
     else:
         inv_r_rmse = np.nan
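
The ligand-aware branch of `align_prediction` weights each Ca atom by `exp(-smoothing_factor * d)`, where `d` is the distance to its nearest ligand atom, and uses those weights for both centroids and for the rotation fit. The sketch below reproduces only the weighting and the weighted centroid, in plain Python instead of NumPy/SciPy so that it runs without dependencies. The function names and the toy coordinates are assumptions made for illustration, not part of the commit.

```python
import math


def ligand_proximity_weights(calpha_coords, ligand_coords, smoothing_factor):
    """Weight each Ca atom by exp(-smoothing_factor * distance to its nearest ligand atom)."""
    weights = []
    for ca in calpha_coords:
        nearest = min(math.dist(ca, lig) for lig in ligand_coords)
        weights.append(math.exp(-smoothing_factor * nearest))
    return weights


def weighted_centroid(coords, weights):
    """Return the weight-averaged position of a list of 3D points."""
    total = sum(weights)
    return tuple(
        sum(w * point[axis] for w, point in zip(weights, coords)) / total
        for axis in range(3)
    )


# Toy data (made up): two Ca atoms, with a single ligand atom sitting on the first one.
calpha = [(0.0, 0.0, 0.0), (10.0, 0.0, 0.0)]
ligand = [(0.0, 0.0, 0.0)]

uniform = weighted_centroid(calpha, ligand_proximity_weights(calpha, ligand, 0.0))
focused = weighted_centroid(calpha, ligand_proximity_weights(calpha, ligand, 1.0))
print("smoothing_factor=0.0:", uniform)
print("smoothing_factor=1.0:", focused)
```

With a smoothing factor of 0 every atom has weight 1 and the centroid is the ordinary mean. As the factor grows, atoms far from the ligand contribute almost nothing, so the alignment is dominated by the binding pocket.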
@@ -292,23 +292,24 @@ def align_prediction(
 def get_alignment_rotation(
     pdb_id: str,
     dataset_protein_path: str,
-    esmfold_protein_path: str,
+    predicted_protein_path: str,
     dataset_path: str,
 ) -> Tuple[Optional[Rotation], Optional[np.ndarray], Optional[np.ndarray]]:
     """Calculate the alignment rotation between apo and holo protein structures and their ligand
     coordinates.

     :param pdb_id: PDB ID of the protein-ligand complex.
     :param dataset_protein_path: Filepath to the PDB file of the protein structure from a dataset.
-    :param esmfold_protein_path: Filepath to the PDB file of the protein structure from ESMFold.
+    :param predicted_protein_path: Filepath to the PDB file of the protein structure from a
+        structure predictor.
     :param dataset: Name of the dataset.
     :param dataset_path: Filepath to the PDB file containing ligand coordinates.
     :param lig_connection_radius: Radius for connecting ligand atoms.
     :param exclude_af2aa_excluded_ligs: Whether to exclude ligands excluded from the AF2-AA
         dataset.
     :param skip_parsed_ligands: Whether to skip parsing ligands if they have already been parsed.
     :return: A tuple containing rotation matrix (Optional[Rotation]), centroid of Ca atoms for a
-        dataset protein (Optional[np.ndarray]), and centroid of Ca atoms for ESMFold
+        dataset protein (Optional[np.ndarray]), and centroid of Ca atoms for a prediction
         (Optional[np.ndarray]).
     """
     try:
@@ -319,10 +320,10 @@ def get_alignment_rotation(
         )
         return None, None, None
     try:
-        esmfold_rec = parse_pdb_from_path(esmfold_protein_path)
+        predicted_rec = parse_pdb_from_path(predicted_protein_path)
     except Exception as e:
         logger.warning(
-            f"Unable to parse ESMFold protein structure for PDB ID {pdb_id} due to the error: {e}. Skipping..."
+            f"Unable to parse predicted protein structure for PDB ID {pdb_id} due to the error: {e}. Skipping..."
         )
         return None, None, None
     dataset_ligand = read_mols(dataset_path, pdb_id, remove_hs=True)[0]
@@ -337,12 +338,12 @@ def get_alignment_rotation(
         )
         return None, None, None
     try:
-        esmfold_calpha_coords = extract_receptor_structure(
-            esmfold_rec, dataset_ligand, filter_out_hetero_residues=True
+        predicted_calpha_coords = extract_receptor_structure(
+            predicted_rec, dataset_ligand, filter_out_hetero_residues=True
         )[2]
     except Exception as e:
         logger.warning(
-            f"Unable to extract ESMFold protein structure for PDB ID {pdb_id} due to the error: {e}. Skipping..."
+            f"Unable to extract predicted protein structure for PDB ID {pdb_id} due to the error: {e}. Skipping..."
         )
         return None, None, None
     try:
@@ -353,85 +354,86 @@ def get_alignment_rotation(
         )
         return None, None, None

-    if dataset_calpha_coords.shape != esmfold_calpha_coords.shape:
+    if dataset_calpha_coords.shape != predicted_calpha_coords.shape:
         logger.warning(
             f"Receptor structures differ for PDB ID {pdb_id}. Skipping due to shape mismatch:",
             dataset_calpha_coords.shape,
-            esmfold_calpha_coords.shape,
+            predicted_calpha_coords.shape,
         )
         return None, None, None

     res = minimize(
         align_prediction,
         [0.1],
         bounds=Bounds([0.0], [1.0]),
-        args=(dataset_calpha_coords, esmfold_calpha_coords, dataset_ligand_coords),
+        args=(dataset_calpha_coords, predicted_calpha_coords, dataset_ligand_coords),
         tol=1e-8,
     )

     smoothing_factor = res.x
-    rotation, dataset_calpha_centroid, esmfold_calpha_centroid = align_prediction(
+    rotation, dataset_calpha_centroid, predicted_calpha_centroid = align_prediction(
         smoothing_factor,
         dataset_calpha_coords,
-        esmfold_calpha_coords,
+        predicted_calpha_coords,
         dataset_ligand_coords,
         return_rotation=True,
     )

-    return rotation, dataset_calpha_centroid, esmfold_calpha_centroid
+    return rotation, dataset_calpha_centroid, predicted_calpha_centroid


 def align_apo_structure_to_holo_structure(
     cfg: DictConfig, filename: str, atom_df_name: str = "ATOM"
 ):
-    """Align a given ESMFold apo structure to its corresponding holo structure.
+    """Align a given predicted apo structure to its corresponding holo structure.

     :param cfg: Hydra config for the alignment.
-    :param filename: Filename of the ESMFold apo structure.
+    :param filename: Filename of the predicted apo structure.
     :param atom_df_name: Name of the atom DataFrame derived from the corresponding PDB file input.
     """
     pdb_id = "_".join(Path(filename).stem.split("_")[:2])
-    esm_protein_filename = os.path.join(cfg.esmfold_structures_dir, f"{pdb_id}.pdb")
+    predicted_protein_filename = os.path.join(cfg.predicted_structures_dir, f"{pdb_id}.pdb")
     processed_protein_name = f"{pdb_id}_protein.pdb"
     processed_protein_filename = os.path.join(cfg.data_dir, pdb_id, processed_protein_name)
-    esm_protein_output_filename = os.path.join(
-        cfg.output_dir, f"{pdb_id}_holo_aligned_esmfold_protein.pdb"
+    predicted_protein_output_filename = os.path.join(
+        cfg.output_dir, f"{pdb_id}_holo_aligned_predicted_protein.pdb"
     )

-    rotation, dataset_calpha_centroid, esmfold_calpha_centroid = get_alignment_rotation(
+    rotation, dataset_calpha_centroid, predicted_calpha_centroid = get_alignment_rotation(
         pdb_id=pdb_id,
         dataset_protein_path=processed_protein_filename,
-        esmfold_protein_path=esm_protein_filename,
+        predicted_protein_path=predicted_protein_filename,
         dataset_path=cfg.data_dir,
     )

     if any(
-        [item is None for item in [rotation, dataset_calpha_centroid, esmfold_calpha_centroid]]
+        [item is None for item in [rotation, dataset_calpha_centroid, predicted_calpha_centroid]]
     ):
         return

-    ppdb_esmfold = PandasPdb().read_pdb(esm_protein_filename)
-    ppdb_esmfold_pre_rot = (
-        ppdb_esmfold.df[atom_df_name][["x_coord", "y_coord", "z_coord"]]
+    ppdb_predicted = PandasPdb().read_pdb(predicted_protein_filename)
+    ppdb_predicted_pre_rot = (
+        ppdb_predicted.df[atom_df_name][["x_coord", "y_coord", "z_coord"]]
         .to_numpy()
         .squeeze()
         .astype(np.float32)
     )
-    ppdb_esmfold_aligned = (
-        rotation.apply(ppdb_esmfold_pre_rot - esmfold_calpha_centroid) + dataset_calpha_centroid
+    ppdb_predicted_aligned = (
+        rotation.apply(ppdb_predicted_pre_rot - predicted_calpha_centroid)
+        + dataset_calpha_centroid
     )

-    ppdb_esmfold.df[atom_df_name][["x_coord", "y_coord", "z_coord"]] = ppdb_esmfold_aligned
-    ppdb_esmfold.to_pdb(path=esm_protein_output_filename, records=[atom_df_name], gz=False)
+    ppdb_predicted.df[atom_df_name][["x_coord", "y_coord", "z_coord"]] = ppdb_predicted_aligned
+    ppdb_predicted.to_pdb(path=predicted_protein_output_filename, records=[atom_df_name], gz=False)


 @hydra.main(
     version_base="1.3",
     config_path="../../../configs/data/components",
-    config_name="esmfold_apo_to_holo_alignment.yaml",
+    config_name="protein_apo_to_holo_alignment.yaml",
 )
 def main(cfg: DictConfig):
-    """Align all ESMFold apo structures to their corresponding holo structures.
+    """Align all predicted apo structures to their corresponding holo structures.

     :param cfg: Hydra config for the alignments.
     """
@@ -441,9 +443,9 @@ def main(cfg: DictConfig):
     os.makedirs(output_dir, exist_ok=True)
     structure_file_inputs = [
         file
-        for file in os.listdir(cfg.esmfold_structures_dir)
+        for file in os.listdir(cfg.predicted_structures_dir)
         if not os.path.exists(
-            os.path.join(cfg.output_dir, f"{Path(file).stem}_holo_aligned_esmfold_protein.pdb")
+            os.path.join(cfg.output_dir, f"{Path(file).stem}_holo_aligned_predicted_protein.pdb")
         )
     ]
     pbar = tqdm(
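
The list comprehension in `main` makes the script resumable: a predicted structure is queued only if its `_holo_aligned_predicted_protein.pdb` output does not exist yet. Here is a self-contained sketch of that filter, exercised against a temporary directory. The helper names are hypothetical, the two file names are illustrative, and the `sorted` call is an addition for deterministic output that the source does not have.

```python
import os
import tempfile
from pathlib import Path


def pdb_id_from_filename(filename: str) -> str:
    """Keep the first two underscore-separated fields of the file stem, as the script does."""
    return "_".join(Path(filename).stem.split("_")[:2])


def pending_structure_files(predicted_dir: str, output_dir: str) -> list:
    """List predicted structures whose aligned output file has not been written yet."""
    return sorted(
        name
        for name in os.listdir(predicted_dir)
        if not os.path.exists(
            os.path.join(output_dir, f"{Path(name).stem}_holo_aligned_predicted_protein.pdb")
        )
    )


with tempfile.TemporaryDirectory() as root:
    predicted_dir = os.path.join(root, "predicted")
    output_dir = os.path.join(root, "aligned")
    os.makedirs(predicted_dir)
    os.makedirs(output_dir)
    # Two predicted structures, one of which has already been aligned.
    for name in ("1G9V_RQ4.pdb", "5S8I_2LY.pdb"):
        Path(predicted_dir, name).touch()
    Path(output_dir, "1G9V_RQ4_holo_aligned_predicted_protein.pdb").touch()
    pending = pending_structure_files(predicted_dir, output_dir)

print(pending)
```

One consequence worth noting: since completion is detected purely by file name, outputs written before this commit with the `_holo_aligned_esmfold_protein.pdb` suffix would presumably not be recognized as done and would be regenerated under the new name.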
@@ -461,7 +463,7 @@ def main(cfg: DictConfig):
         # wait for all tasks to complete
         for future in tqdm(
             futures,
-            desc="Aligning each ESMFold apo structure to its corresponding holo structure",
+            desc="Aligning each predicted apo structure to its corresponding holo structure",
             total=len(futures),
         ):
             future.result()