Skip to content
129 changes: 120 additions & 9 deletions feflow/protocols/nonequilibrium_cycling.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@
from openff.units import unit
from openff.units.openmm import to_openmm, from_openmm

from ..utils.data import serialize, deserialize
from ..utils.data import (
serialize,
deserialize,
serialize_and_compress,
decompress_and_deserialize,
)

# Specific instance of logger for this module
# logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -133,6 +138,34 @@ def _execute(self, ctx, *, state_a, state_b, mapping, settings, **inputs):
from openfe.protocols.openmm_rfe import _rfe_utils
from feflow.utils.hybrid_topology import HybridTopologyFactory

if extends_data := self.inputs.get("extends_data"):

def _write_xml(data, filename):
openmm_object = decompress_and_deserialize(data)
serialize(openmm_object, filename)
return filename

for replicate in range(settings.num_replicates):
replicate = str(replicate)
system_outfile = ctx.shared / f"system_{replicate}.xml.bz2"
state_outfile = ctx.shared / f"state_{replicate}.xml.bz2"
integrator_outfile = ctx.shared / f"integrator_{replicate}.xml.bz2"

extends_data["systems"][replicate] = _write_xml(
extends_data["systems"][replicate],
system_outfile,
)
extends_data["states"][replicate] = _write_xml(
extends_data["states"][replicate],
state_outfile,
)
extends_data["integrators"][replicate] = _write_xml(
extends_data["integrators"][replicate],
integrator_outfile,
)

return extends_data

# Check compatibility between states (same receptor and solvent)
self._check_states_compatibility(state_a, state_b)

Expand Down Expand Up @@ -342,10 +375,18 @@ def _execute(self, ctx, *, state_a, state_b, mapping, settings, **inputs):
# Explicit cleanup for GPU resources
del context, integrator

systems = dict()
states = dict()
integrators = dict()
for replicate_name in map(str, range(settings.num_replicates)):
systems[replicate_name] = system_outfile
states[replicate_name] = state_outfile
integrators[replicate_name] = integrator_outfile

return {
"system": system_outfile,
"state": state_outfile,
"integrator": integrator_outfile,
"systems": systems,
"states": states,
"integrators": integrators,
"phase": phase,
"initial_atom_indices": hybrid_factory.initial_atom_indices,
"final_atom_indices": hybrid_factory.final_atom_indices,
Expand Down Expand Up @@ -434,9 +475,9 @@ def _execute(self, ctx, *, setup, settings, **inputs):
file_logger.addHandler(file_handler)

# Get state, system, and integrator from setup unit
system = deserialize(setup.outputs["system"])
state = deserialize(setup.outputs["state"])
integrator = deserialize(setup.outputs["integrator"])
system = deserialize(setup.outputs["systems"][self.name])
state = deserialize(setup.outputs["states"][self.name])
integrator = deserialize(setup.outputs["integrators"][self.name])
PeriodicNonequilibriumIntegrator.restore_interface(integrator)

# Get atom indices for either end of the hybrid topology
Expand Down Expand Up @@ -687,7 +728,20 @@ def _execute(self, ctx, *, setup, settings, **inputs):
"reverse_neq_final": reverse_neq_new_path,
}
finally:
compressed_state = serialize_and_compress(
context.getState(getPositions=True),
)

compressed_system = serialize_and_compress(
context.getSystem(),
)

compressed_integrator = serialize_and_compress(
context.getIntegrator(),
)

# Explicit cleanup for GPU resources

del context, integrator

return {
Expand All @@ -696,6 +750,9 @@ def _execute(self, ctx, *, setup, settings, **inputs):
"trajectory_paths": trajectory_paths,
"log": output_log_path,
"timing_info": timing_info,
"system": compressed_system,
"state": compressed_state,
"integrator": compressed_integrator,
}


Expand Down Expand Up @@ -890,10 +947,63 @@ def _create(
# Handle parameters
if mapping is None:
raise ValueError("`mapping` is required for this Protocol")

if "ligand" not in mapping:
raise ValueError("'ligand' must be specified in `mapping` dict")
if extends:
raise NotImplementedError("Can't extend simulations yet")

extends_data = {}
if isinstance(extends, ProtocolDAGResult):

if not extends.ok():
raise ValueError("Cannot extend protocols that failed")

setup = extends.protocol_units[0]
simulations = extends.protocol_units[1:-1]

r_setup = extends.protocol_unit_results[0]
r_simulations = extends.protocol_unit_results[1:-1]

# confirm consistency
original_state_a = setup.inputs["state_a"].key
original_state_b = setup.inputs["state_b"].key
original_mapping = setup.inputs["mapping"]

if original_state_a != stateA.key:
raise ValueError(
"'stateA' key is not the same as the key provided by the 'extends' ProtocolDAGResult."
)

if original_state_b != stateB.key:
raise ValueError(
"'stateB' key is not the same as the key provided by the 'extends' ProtocolDAGResult."
)

if mapping is not None:
if original_mapping != mapping:
raise ValueError(
"'mapping' is not consistent with the mapping provided by the 'extnds' ProtocolDAGResult."
)
else:
mapping = original_mapping

systems = {}
states = {}
integrators = {}

for r_simulation, simulation in zip(r_simulations, simulations):
sim_name = simulation.name
systems[sim_name] = r_simulation.outputs["system"]
states[sim_name] = r_simulation.outputs["state"]
integrators[sim_name] = r_simulation.outputs["integrator"]

extends_data = dict(
systems=systems,
states=states,
integrators=integrators,
phase=r_setup.outputs["phase"],
initial_atom_indices=r_setup.outputs["initial_atom_indices"],
final_atom_indices=r_setup.outputs["final_atom_indices"],
)

# inputs to `ProtocolUnit.__init__` should either be `Gufe` objects
# or JSON-serializable objects
Expand All @@ -905,6 +1015,7 @@ def _create(
mapping=mapping,
settings=self.settings,
name="setup",
extends_data=extends_data,
)

simulations = [
Expand Down
4 changes: 2 additions & 2 deletions feflow/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def short_settings():
settings = NonEquilibriumCyclingProtocol.default_settings()

settings.thermo_settings.temperature = 300 * unit.kelvin
settings.eq_steps = 25000
settings.neq_steps = 25000
settings.eq_steps = 1000
settings.neq_steps = 1000
settings.work_save_frequency = 50
settings.traj_save_frequency = 250
settings.platform = "CPU"
Expand Down
92 changes: 92 additions & 0 deletions feflow/tests/test_nonequilibrium_cycling.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,98 @@ def test_terminal_units(self, protocol_dag_result):
assert isinstance(finals[0], ProtocolUnitResult)
assert finals[0].name == "result"

@pytest.mark.parametrize(
    "protocol",
    [
        "protocol_short",
        "protocol_short_multiple_cycles",
    ],
)
def test_pdr_extend(
    self,
    protocol,
    benzene_vacuum_system,
    toluene_vacuum_system,
    mapping_benzene_toluene,
    tmpdir,
    request,
):
    """Round-trip test for extending a finished NonEquilibriumCycling DAG.

    Runs a short benzene->toluene vacuum transformation, then creates a
    second DAG with ``extends=`` pointing at the first run's
    ``ProtocolDAGResult`` and verifies that the serialized end-state data
    from the first run is threaded into the setup unit of the second run.

    Parameters
    ----------
    protocol : str
        Name of the protocol fixture to use (single- or multi-cycle);
        resolved at runtime via ``request.getfixturevalue``.
    benzene_vacuum_system, toluene_vacuum_system
        ChemicalSystem fixtures for the two end states.
    mapping_benzene_toluene
        Ligand atom-mapping fixture between the two end states.
    tmpdir
        pytest-provided temporary directory for scratch/shared storage.
    request
        pytest fixture request object, used to resolve the parametrized
        protocol fixture by name.
    """

    # Resolve the parametrized fixture name into the actual protocol object.
    protocol = request.getfixturevalue(protocol)
    dag = protocol.create(
        stateA=benzene_vacuum_system,
        stateB=toluene_vacuum_system,
        name="Short vacuum transformation",
        mapping={"ligand": mapping_benzene_toluene},
    )

    with tmpdir.as_cwd():

        base_path = Path("original")

        shared = base_path / "shared"
        shared.mkdir(parents=True)

        scratch = base_path / "scratch"
        scratch.mkdir(parents=True)

        pdr: ProtocolDAGResult = execute_DAG(
            dag, shared_basedir=shared, scratch_basedir=scratch
        )

        # First unit is the setup unit; a fresh (non-extended) run should
        # carry an empty extends_data payload.
        setup = pdr.protocol_units[0]
        r_setup = pdr.protocol_unit_results[0]

        assert setup.inputs["extends_data"] == {}

        # Simulation units sit between setup ([0]) and the result unit ([-1]).
        # Each should have produced Base64-string serializations of its final
        # system/state/integrator; remember the end states keyed by unit name
        # so we can compare against the extended run below.
        end_states = {}
        for simulation, r_simulation in zip(
            pdr.protocol_units[1:-1], pdr.protocol_unit_results[1:-1]
        ):
            assert isinstance(r_simulation.outputs["system"], str)
            assert isinstance(r_simulation.outputs["state"], str)
            assert isinstance(r_simulation.outputs["integrator"], str)

            end_states[simulation.name] = r_simulation.outputs["state"]

        # Build a second DAG that extends the first run.  The dict round-trip
        # (to_dict/from_dict) exercises (de)serialization of the PDR itself.
        dag = protocol.create(
            stateA=benzene_vacuum_system,
            stateB=toluene_vacuum_system,
            name="Short vacuum transformation, but extended",
            mapping={"ligand": mapping_benzene_toluene},
            extends=ProtocolDAGResult.from_dict(pdr.to_dict()),
        )

    with tmpdir.as_cwd():

        base_path = Path("extended")

        shared = base_path / "shared"
        shared.mkdir(parents=True)

        scratch = base_path / "scratch"
        scratch.mkdir(parents=True)
        pdr: ProtocolDAGResult = execute_DAG(
            dag, shared_basedir=shared, scratch_basedir=scratch
        )

        r_setup = pdr.protocol_unit_results[0]

        # The extended run's setup unit must have received non-empty
        # extends_data populated from the first run's outputs.
        assert r_setup.inputs["extends_data"] != {}

        # Per-replicate entries are keyed by stringified replicate index and
        # hold serialized payloads (strings); the carried-over end state must
        # match what the first run's simulation units produced.
        for replicate in range(protocol.settings.num_replicates):
            replicate = str(replicate)
            assert isinstance(r_setup.inputs["extends_data"]["systems"][replicate], str)
            assert isinstance(r_setup.inputs["extends_data"]["states"][replicate], str)
            assert isinstance(
                r_setup.inputs["extends_data"]["integrators"][replicate], str
            )

            assert (
                r_setup.inputs["extends_data"]["states"][replicate]
                == end_states[replicate]
            )

def test_dag_execute_failure(self, protocol_dag_broken):
protocol, dag, dagfailure = protocol_dag_broken

Expand Down
44 changes: 43 additions & 1 deletion feflow/utils/data.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,48 @@
import os
import pathlib
import bz2
import base64

from openmm import XmlSerializer


def serialize_and_compress(item) -> str:
    """Serialize an OpenMM object to XML, bzip2-compress, and Base64-encode it.

    Parameters
    ----------
    item : System, State, or Integrator
        The OpenMM object to serialize and compress.

    Returns
    -------
    b64string : str
        The compressed serialized OpenMM object encoded in a Base64 string.
    """
    xml_bytes = XmlSerializer.serialize(item).encode()
    return base64.b64encode(bz2.compress(xml_bytes)).decode("ascii")


def decompress_and_deserialize(data: str):
    """Recover an OpenMM object from a compressed Base64 string.

    Inverse of ``serialize_and_compress``: Base64-decode, bzip2-decompress,
    and deserialize the XML back into an OpenMM object.

    Parameters
    ----------
    data : str
        String containing a Base64 encoded bzip2 compressed XML serialization
        of an OpenMM object.

    Returns
    -------
    deserialized
        The deserialized OpenMM object.
    """
    xml_text = bz2.decompress(base64.b64decode(data)).decode("utf-8")
    return XmlSerializer.deserialize(xml_text)


def serialize(item, filename: pathlib.Path):
Expand All @@ -13,7 +56,6 @@ def serialize(item, filename: pathlib.Path):
filename : str
The filename to serialize to
"""
from openmm import XmlSerializer

# Create parent directory if it doesn't exist
filename_basedir = filename.parent
Expand Down