diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index 97035386e..c494e7e6b 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -50,6 +50,7 @@ settings, ChemicalSystem, LigandAtomMapping, Component, ComponentMapping, SmallMoleculeComponent, ProteinComponent, SolventComponent, ) +from gufe.storage import stagingregistry from .equil_rfe_settings import ( RelativeHybridTopologyProtocolSettings, SystemSettings, @@ -578,8 +579,10 @@ def __init__(self, *, ) def run(self, *, dry=False, verbose=True, - scratch_basepath=None, - shared_basepath=None) -> dict[str, Any]: + scratch_basepath: pathlib.Path, + shared_basepath: stagingregistry.StagingPath, + permanent_basepath: stagingregistry.StagingPath, + ) -> dict[str, Any]: """Run the relative free energy calculation. Parameters @@ -591,10 +594,12 @@ def run(self, *, dry=False, verbose=True, verbose : bool Verbose output of the simulation progress. Output is provided via INFO level logging. - scratch_basepath: Pathlike, optional - Where to store temporary files, defaults to current working directory - shared_basepath : Pathlike, optional - Where to run the calculation, defaults to current working directory + scratch_basepath: pathlib.Path + Where to store temporary files + shared_basepath : StagingPath + Where to run the calculation + permanent_basepath : StagingPath + Where to store files that must persist beyond the DAG Returns ------- @@ -609,11 +614,6 @@ def run(self, *, dry=False, verbose=True, """ if verbose: self.logger.info("Preparing the hybrid topology simulation") - if scratch_basepath is None: - scratch_basepath = pathlib.Path('.') - if shared_basepath is None: - # use cwd - shared_basepath = pathlib.Path('.') # 0. General setup and settings dependency resolution step @@ -664,11 +664,13 @@ def run(self, *, dry=False, verbose=True, else: ffcache = None + ffcache.register() + system_generator = system_creation.get_system_generator( forcefield_settings=forcefield_settings, thermo_settings=thermo_settings, system_settings=system_settings, - cache=ffcache, + cache=ffcache.as_path(), has_solvent=solvent_comp is not None, ) @@ -812,10 +814,18 @@ def run(self, *, dry=False, verbose=True, ) # a. Create the multistate reporter - nc = shared_basepath / sim_settings.output_filename + # TODO: Logic about keeping/not .nc files goes here + nc = (shared_basepath / sim_settings.output_filename) + checkpoint = (shared_basepath / sim_settings.checkpoint_storage) + real_time_analysis = (shared_basepath / "real_time_analysis.yaml") + # have to flag these files as being created so that they get brought back + nc.register() + checkpoint.register() + real_time_analysis.register() + chk = sim_settings.checkpoint_storage reporter = multistate.MultiStateReporter( - storage=nc, + storage=str(nc.as_path()), analysis_particle_indices=selection_indices, checkpoint_interval=sim_settings.checkpoint_interval.m, checkpoint_storage=chk, @@ -947,13 +957,12 @@ def run(self, *, dry=False, verbose=True, sampling_method=sampler_settings.sampler_method.lower(), result_units=unit.kilocalorie_per_mole, ) - analyzer.plot(filepath=shared_basepath, filename_prefix="") + analyzer.plot(filepath=permanent_basepath, filename_prefix="") analyzer.close() else: # clean up the reporter file - fns = [shared_basepath / sim_settings.output_filename, - shared_basepath / sim_settings.checkpoint_storage] + fns = [nc.as_path(), checkpoint.as_path()] for fn in fns: os.remove(fn) finally: @@ -981,35 +990,38 @@ def run(self, *, dry=False, verbose=True, if not dry: # pragma: no-cover return { 'nc': nc, - 'last_checkpoint': chk, + 'last_checkpoint': checkpoint, **analyzer.unit_results_dict } else: return {'debug': {'sampler': sampler}} @staticmethod - def analyse(where) -> dict: + def analyse(where: stagingregistry.StagingPath) -> dict: # don't put energy analysis in here, it uses the open file reporter # whereas structural stuff requires that the file handle is closed - ret = subprocess.run(['openfe_analysis', str(where)], + output = (where / 'results.json') + ret = subprocess.run(['openfe_analysis', 'RFE_analysis', + str(where.as_path()), + str(output.as_path())], stdout=subprocess.PIPE, stderr=subprocess.PIPE) if ret.returncode: return {'structural_analysis_error': ret.stderr} - data = json.loads(ret.stdout) + with open(output, 'r') as f: + data = json.load(f) - savedir = pathlib.Path(where) if d := data['protein_2D_RMSD']: fig = plotting.plot_2D_rmsd(d) - fig.savefig(savedir / "protein_2D_RMSD.png") + fig.savefig(where / "protein_2D_RMSD.png") plt.close(fig) f2 = plotting.plot_ligand_COM_drift(data['time(ps)'], data['ligand_wander']) - f2.savefig(savedir / "ligand_COM_drift.png") + f2.savefig(where / "ligand_COM_drift.png") plt.close(f2) f3 = plotting.plot_ligand_RMSD(data['time(ps)'], data['ligand_RMSD']) - f3.savefig(savedir / "ligand_RMSD.png") + f3.savefig(where / "ligand_RMSD.png") plt.close(f3) return {'structural_analysis': data} @@ -1020,7 +1032,8 @@ def _execute( log_system_probe(logging.INFO, paths=[ctx.scratch]) with without_oechem_backend(): outputs = self.run(scratch_basepath=ctx.scratch, - shared_basepath=ctx.shared) + shared_basepath=ctx.shared, + permanent_basepath=ctx.permanent) analysis_outputs = self.analyse(ctx.shared)