diff --git a/feflow/protocols/nonequilibrium_cycling.py b/feflow/protocols/nonequilibrium_cycling.py
index f3b25c60..bde1691a 100644
--- a/feflow/protocols/nonequilibrium_cycling.py
+++ b/feflow/protocols/nonequilibrium_cycling.py
@@ -25,7 +25,12 @@
 from openff.units import unit
 from openff.units.openmm import to_openmm, from_openmm
 
-from ..utils.data import serialize, deserialize
+from ..utils.data import (
+    serialize,
+    deserialize,
+    serialize_and_compress,
+    decompress_and_deserialize,
+)
 
 # Specific instance of logger for this module
 # logger = logging.getLogger(__name__)
@@ -133,6 +138,34 @@ def _execute(self, ctx, *, state_a, state_b, mapping, settings, **inputs):
         from openfe.protocols.openmm_rfe import _rfe_utils
         from feflow.utils.hybrid_topology import HybridTopologyFactory
 
+        if extends_data := self.inputs.get("extends_data"):
+
+            def _write_xml(data, filename):
+                openmm_object = decompress_and_deserialize(data)
+                serialize(openmm_object, filename)
+                return filename
+
+            for replicate in range(settings.num_replicates):
+                replicate = str(replicate)
+                system_outfile = ctx.shared / f"system_{replicate}.xml.bz2"
+                state_outfile = ctx.shared / f"state_{replicate}.xml.bz2"
+                integrator_outfile = ctx.shared / f"integrator_{replicate}.xml.bz2"
+
+                extends_data["systems"][replicate] = _write_xml(
+                    extends_data["systems"][replicate],
+                    system_outfile,
+                )
+                extends_data["states"][replicate] = _write_xml(
+                    extends_data["states"][replicate],
+                    state_outfile,
+                )
+                extends_data["integrators"][replicate] = _write_xml(
+                    extends_data["integrators"][replicate],
+                    integrator_outfile,
+                )
+
+            return extends_data
+
         # Check compatibility between states (same receptor and solvent)
         self._check_states_compatibility(state_a, state_b)
 
@@ -342,10 +375,18 @@ def _execute(self, ctx, *, state_a, state_b, mapping, settings, **inputs):
         # Explicit cleanup for GPU resources
         del context, integrator
 
+        systems = dict()
+        states = dict()
+        integrators = dict()
+        for replicate_name in map(str, range(settings.num_replicates)):
+            systems[replicate_name] = system_outfile
+            states[replicate_name] = state_outfile
+            integrators[replicate_name] = integrator_outfile
+
         return {
-            "system": system_outfile,
-            "state": state_outfile,
-            "integrator": integrator_outfile,
+            "systems": systems,
+            "states": states,
+            "integrators": integrators,
             "phase": phase,
             "initial_atom_indices": hybrid_factory.initial_atom_indices,
             "final_atom_indices": hybrid_factory.final_atom_indices,
@@ -434,9 +475,9 @@ def _execute(self, ctx, *, setup, settings, **inputs):
         file_logger.addHandler(file_handler)
 
         # Get state, system, and integrator from setup unit
-        system = deserialize(setup.outputs["system"])
-        state = deserialize(setup.outputs["state"])
-        integrator = deserialize(setup.outputs["integrator"])
+        system = deserialize(setup.outputs["systems"][self.name])
+        state = deserialize(setup.outputs["states"][self.name])
+        integrator = deserialize(setup.outputs["integrators"][self.name])
         PeriodicNonequilibriumIntegrator.restore_interface(integrator)
 
         # Get atom indices for either end of the hybrid topology
@@ -687,7 +728,20 @@ def _execute(self, ctx, *, setup, settings, **inputs):
                 "reverse_neq_final": reverse_neq_new_path,
             }
         finally:
+            compressed_state = serialize_and_compress(
+                context.getState(getPositions=True),
+            )
+
+            compressed_system = serialize_and_compress(
+                context.getSystem(),
+            )
+
+            compressed_integrator = serialize_and_compress(
+                context.getIntegrator(),
+            )
+
             # Explicit cleanup for GPU resources
             del context, integrator
 
         return {
@@ -696,6 +750,9 @@ def _execute(self, ctx, *, setup, settings, **inputs):
             "trajectory_paths": trajectory_paths,
             "log": output_log_path,
             "timing_info": timing_info,
+            "system": compressed_system,
+            "state": compressed_state,
+            "integrator": compressed_integrator,
         }
 
 
@@ -890,10 +947,63 @@ def _create(
         # Handle parameters
         if mapping is None:
             raise ValueError("`mapping` is required for this Protocol")
+
         if "ligand" not in mapping:
             raise ValueError("'ligand' must be specified in `mapping` dict")
-        if extends:
-            raise NotImplementedError("Can't extend simulations yet")
+
+        extends_data = {}
+        if isinstance(extends, ProtocolDAGResult):
+
+            if not extends.ok():
+                raise ValueError("Cannot extend a ProtocolDAGResult that contains failed units")
+
+            setup = extends.protocol_units[0]
+            simulations = extends.protocol_units[1:-1]
+
+            r_setup = extends.protocol_unit_results[0]
+            r_simulations = extends.protocol_unit_results[1:-1]
+
+            # confirm the new inputs are consistent with the extended result
+            original_state_a = setup.inputs["state_a"].key
+            original_state_b = setup.inputs["state_b"].key
+            original_mapping = setup.inputs["mapping"]
+
+            if original_state_a != stateA.key:
+                raise ValueError(
+                    "'stateA' key does not match the 'state_a' key recorded in the 'extends' ProtocolDAGResult."
+                )
+
+            if original_state_b != stateB.key:
+                raise ValueError(
+                    "'stateB' key does not match the 'state_b' key recorded in the 'extends' ProtocolDAGResult."
+                )
+
+            if mapping is not None:
+                if original_mapping != mapping:
+                    raise ValueError(
+                        "'mapping' is not consistent with the mapping recorded in the 'extends' ProtocolDAGResult."
+                    )
+            else:
+                mapping = original_mapping
+
+            systems = {}
+            states = {}
+            integrators = {}
+
+            for r_simulation, simulation in zip(r_simulations, simulations):
+                sim_name = simulation.name
+                systems[sim_name] = r_simulation.outputs["system"]
+                states[sim_name] = r_simulation.outputs["state"]
+                integrators[sim_name] = r_simulation.outputs["integrator"]
+
+            extends_data = dict(
+                systems=systems,
+                states=states,
+                integrators=integrators,
+                phase=r_setup.outputs["phase"],
+                initial_atom_indices=r_setup.outputs["initial_atom_indices"],
+                final_atom_indices=r_setup.outputs["final_atom_indices"],
+            )
 
         # inputs to `ProtocolUnit.__init__` should either be `Gufe` objects
         # or JSON-serializable objects
@@ -905,6 +1015,7 @@ def _create(
             mapping=mapping,
             settings=self.settings,
             name="setup",
+            extends_data=extends_data,
         )
 
         simulations = [
diff --git a/feflow/tests/conftest.py b/feflow/tests/conftest.py
index d419e562..2f8f39c4 100644
--- a/feflow/tests/conftest.py
+++ b/feflow/tests/conftest.py
@@ -74,8 +74,8 @@ def short_settings():
     settings = NonEquilibriumCyclingProtocol.default_settings()
 
     settings.thermo_settings.temperature = 300 * unit.kelvin
-    settings.eq_steps = 25000
-    settings.neq_steps = 25000
+    settings.eq_steps = 1000
+    settings.neq_steps = 1000
     settings.work_save_frequency = 50
     settings.traj_save_frequency = 250
     settings.platform = "CPU"
diff --git a/feflow/tests/test_nonequilibrium_cycling.py b/feflow/tests/test_nonequilibrium_cycling.py
index e7e58f92..96820938 100644
--- a/feflow/tests/test_nonequilibrium_cycling.py
+++ b/feflow/tests/test_nonequilibrium_cycling.py
@@ -100,6 +100,98 @@ def test_terminal_units(self, protocol_dag_result):
         assert isinstance(finals[0], ProtocolUnitResult)
         assert finals[0].name == "result"
 
+    @pytest.mark.parametrize(
+        "protocol",
+        [
+            "protocol_short",
+            "protocol_short_multiple_cycles",
+        ],
+    )
+    def test_pdr_extend(
+        self,
+        protocol,
+        benzene_vacuum_system,
+        toluene_vacuum_system,
+        mapping_benzene_toluene,
+        tmpdir,
+        request,
+    ):
+
+        protocol = request.getfixturevalue(protocol)
+        dag = protocol.create(
+            stateA=benzene_vacuum_system,
+            stateB=toluene_vacuum_system,
+            name="Short vacuum transformation",
+            mapping={"ligand": mapping_benzene_toluene},
+        )
+
+        with tmpdir.as_cwd():
+
+            base_path = Path("original")
+
+            shared = base_path / "shared"
+            shared.mkdir(parents=True)
+
+            scratch = base_path / "scratch"
+            scratch.mkdir(parents=True)
+
+            pdr: ProtocolDAGResult = execute_DAG(
+                dag, shared_basedir=shared, scratch_basedir=scratch
+            )
+
+        setup = pdr.protocol_units[0]
+        r_setup = pdr.protocol_unit_results[0]
+
+        assert setup.inputs["extends_data"] == {}
+
+        end_states = {}
+        for simulation, r_simulation in zip(
+            pdr.protocol_units[1:-1], pdr.protocol_unit_results[1:-1]
+        ):
+            assert isinstance(r_simulation.outputs["system"], str)
+            assert isinstance(r_simulation.outputs["state"], str)
+            assert isinstance(r_simulation.outputs["integrator"], str)
+
+            end_states[simulation.name] = r_simulation.outputs["state"]
+
+        dag = protocol.create(
+            stateA=benzene_vacuum_system,
+            stateB=toluene_vacuum_system,
+            name="Short vacuum transformation, but extended",
+            mapping={"ligand": mapping_benzene_toluene},
+            extends=ProtocolDAGResult.from_dict(pdr.to_dict()),
+        )
+
+        with tmpdir.as_cwd():
+
+            base_path = Path("extended")
+
+            shared = base_path / "shared"
+            shared.mkdir(parents=True)
+
+            scratch = base_path / "scratch"
+            scratch.mkdir(parents=True)
+            pdr: ProtocolDAGResult = execute_DAG(
+                dag, shared_basedir=shared, scratch_basedir=scratch
+            )
+
+        r_setup = pdr.protocol_unit_results[0]
+
+        assert r_setup.inputs["extends_data"] != {}
+
+        for replicate in range(protocol.settings.num_replicates):
+            replicate = str(replicate)
+            assert isinstance(r_setup.inputs["extends_data"]["systems"][replicate], str)
+            assert isinstance(r_setup.inputs["extends_data"]["states"][replicate], str)
+            assert isinstance(
+                r_setup.inputs["extends_data"]["integrators"][replicate], str
+            )
+
+            assert (
+                r_setup.inputs["extends_data"]["states"][replicate]
+                == end_states[replicate]
+            )
+
     def test_dag_execute_failure(self, protocol_dag_broken):
         protocol, dag, dagfailure = protocol_dag_broken
diff --git a/feflow/utils/data.py b/feflow/utils/data.py
index f829346f..64665d08 100644
--- a/feflow/utils/data.py
+++ b/feflow/utils/data.py
@@ -1,5 +1,48 @@
 import os
 import pathlib
+import bz2
+import base64
+
+from openmm import XmlSerializer
+
+
+def serialize_and_compress(item) -> str:
+    """Serialize an OpenMM System, State, or Integrator and compress.
+
+    Parameters
+    ----------
+    item : System, State, or Integrator
+        The OpenMM object to serialize and compress.
+
+    Returns
+    -------
+    b64string : str
+        The compressed, serialized OpenMM object encoded as a Base64 string.
+    """
+    serialized = XmlSerializer.serialize(item).encode()
+    compressed = bz2.compress(serialized)
+    b64string = base64.b64encode(compressed).decode("ascii")
+    return b64string
+
+
+def decompress_and_deserialize(data: str):
+    """Recover an OpenMM object from a compressed serialization.
+
+    Parameters
+    ----------
+    data : str
+        String containing a Base64 encoded, bzip2 compressed XML serialization
+        of an OpenMM object.
+
+    Returns
+    -------
+    deserialized
+        The deserialized OpenMM object.
+ """ + compressed = base64.b64decode(data) + decompressed = bz2.decompress(compressed).decode("utf-8") + deserialized = XmlSerializer.deserialize(decompressed) + return deserialized def serialize(item, filename: pathlib.Path): @@ -13,7 +56,6 @@ def serialize(item, filename: pathlib.Path): filename : str The filename to serialize to """ - from openmm import XmlSerializer # Create parent directory if it doesn't exist filename_basedir = filename.parent