diff --git a/arc/scheduler.py b/arc/scheduler.py index 6f4d81f39b..9be220b700 100644 --- a/arc/scheduler.py +++ b/arc/scheduler.py @@ -2763,7 +2763,6 @@ def switch_ts(self, label: str): logger.info(f'Switching a TS guess for {label}...') self.determine_most_likely_ts_conformer(label=label) # Look for a different TS guess. self.delete_all_species_jobs(label=label) # Delete other currently running jobs for this TS. - self.output[label]['geo'] = self.output[label]['freq'] = self.output[label]['sp'] = self.output[label]['composite'] = '' freq_path = os.path.join(self.project_directory, 'output', 'rxns', label, 'geometry', 'freq.out') if os.path.isfile(freq_path): os.remove(freq_path) @@ -3555,7 +3554,14 @@ def troubleshoot_ess(self, f'log file:\n"{job.job_status[1]["line"]}".' logger.warning(warning_message) if self.species_dict[label].is_ts and conformer is not None: - xyz = self.species_dict[label].ts_guesses[conformer].get_xyz() + tsg = next((t for t in self.species_dict[label].ts_guesses + if t.conformer_index == conformer), None) + if tsg is not None: + xyz = tsg.get_xyz() + else: + logger.warning(f'Could not find TS guess with index {conformer} for {label}; ' + f'skipping troubleshooting for this conformer.') + return None elif conformer is not None: xyz = self.species_dict[label].conformers[conformer] else: @@ -3705,6 +3711,33 @@ def delete_all_species_jobs(self, label: str): job.delete() self.running_jobs[label] = list() self.output[label]['paths'] = {key: '' if key != 'irc' else list() for key in self.output[label]['paths'].keys()} + for job_type in self.output[label]['job_types']: + self.output[label]['job_types'][job_type] = False + self.output[label]['convergence'] = None + self._pending_pipe_sp.discard(label) + self._pending_pipe_freq.discard(label) + self._pending_pipe_irc.discard((label, 'forward')) + self._pending_pipe_irc.discard((label, 'reverse')) + # Clean up any IRC species spawned from this TS. + if label in self.species_dict and self.species_dict[label].is_ts: + irc_labels_str = self.species_dict[label].irc_label + if irc_labels_str: + for irc_label in irc_labels_str.split(): + if irc_label in self.job_dict and irc_label in self.output: + self.delete_all_species_jobs(irc_label) + if irc_label in self.running_jobs: + del self.running_jobs[irc_label] + if irc_label in self.job_dict: + del self.job_dict[irc_label] + if irc_label in self.output: + del self.output[irc_label] + if irc_label in self.species_dict: + self.species_list = [spc for spc in self.species_list if spc.label != irc_label] + del self.species_dict[irc_label] + if irc_label in self.unique_species_labels: + self.unique_species_labels.remove(irc_label) + logger.info(f'Deleted IRC species {irc_label}.') + self.species_dict[label].irc_label = None def restore_running_jobs(self): """ diff --git a/arc/scheduler_test.py b/arc/scheduler_test.py index 3216a9f254..97cf61902f 100644 --- a/arc/scheduler_test.py +++ b/arc/scheduler_test.py @@ -6,6 +6,7 @@ """ import unittest +from unittest.mock import patch import os import shutil @@ -19,7 +20,7 @@ from arc.imports import settings from arc.reaction import ARCReaction from arc.species.converter import str_to_xyz -from arc.species.species import ARCSpecies +from arc.species.species import ARCSpecies, TSGuess default_levels_of_theory = settings['default_levels_of_theory'] @@ -757,6 +758,117 @@ def test_add_label_to_unique_species_labels(self): self.assertEqual(unique_label, 'new_species_15_1') self.assertEqual(self.sched2.unique_species_labels, ['methylamine', 'C2H6', 'CtripCO', 'new_species_15', 'new_species_15_0', 'new_species_15_1']) + @patch('arc.scheduler.Scheduler.run_opt_job') + def test_switch_ts_cleanup(self, mock_run_opt): + """Test that switch_ts resets job_types, convergence, cleans up IRC species, and clears pending pipes.""" + ts_xyz = str_to_xyz("""N 0.91779059 0.51946178 0.00000000 + H 1.81402049 1.03819414 0.00000000 + H 0.00000000 0.00000000 0.00000000 + H 0.91779059 1.22790192 0.72426890""") + + ts_spc = ARCSpecies(label='TS_test', is_ts=True, xyz=ts_xyz, multiplicity=1, charge=0, + compute_thermo=False) + # Create two TSGuess objects so determine_most_likely_ts_conformer can pick the 2nd after the 1st fails. + ts_spc.ts_guesses = [ + TSGuess(index=0, method='heuristics', success=True, energy=100.0, xyz=ts_xyz, + execution_time='0:00:01'), + TSGuess(index=1, method='heuristics', success=True, energy=110.0, xyz=ts_xyz, + execution_time='0:00:01'), + ] + ts_spc.ts_guesses[0].opt_xyz = ts_xyz + ts_spc.ts_guesses[0].imaginary_freqs = [-500.0] + ts_spc.ts_guesses[1].opt_xyz = ts_xyz + ts_spc.ts_guesses[1].imaginary_freqs = [-400.0] + # Simulate guess 0 already tried. + ts_spc.chosen_ts = 0 + ts_spc.chosen_ts_list = [0] + ts_spc.ts_guesses_exhausted = False + + project_directory = os.path.join(ARC_PATH, 'Projects', + 'arc_project_for_testing_delete_after_usage4') + self.addCleanup(shutil.rmtree, project_directory, ignore_errors=True) + sched = Scheduler(project='test_switch_ts', ess_settings=self.ess_settings, + species_list=[ts_spc], + opt_level=Level(repr=default_levels_of_theory['opt']), + freq_level=Level(repr=default_levels_of_theory['freq']), + sp_level=Level(repr=default_levels_of_theory['sp']), + ts_guess_level=Level(repr=default_levels_of_theory['ts_guesses']), + project_directory=project_directory, + testing=True, + job_types=self.job_types1, + ) + + ts_label = 'TS_test' + # Simulate state after guess 0 completed: freq/sp/opt marked done. + sched.output[ts_label]['job_types']['opt'] = True + sched.output[ts_label]['job_types']['freq'] = True + sched.output[ts_label]['job_types']['sp'] = True + sched.output[ts_label]['convergence'] = True + sched.job_dict[ts_label] = {'opt': {}, 'freq': {}, 'sp': {}} + sched.running_jobs[ts_label] = [] + + # Simulate IRC species spawned from guess 0. + irc_label_1 = 'IRC_TS_test_1' + irc_label_2 = 'IRC_TS_test_2' + irc_spc_1 = ARCSpecies(label=irc_label_1, xyz=ts_xyz, compute_thermo=False, + irc_label=ts_label) + irc_spc_2 = ARCSpecies(label=irc_label_2, xyz=ts_xyz, compute_thermo=False, + irc_label=ts_label) + ts_spc.irc_label = f'{irc_label_1} {irc_label_2}' + sched.species_dict[irc_label_1] = irc_spc_1 + sched.species_dict[irc_label_2] = irc_spc_2 + sched.species_list.extend([irc_spc_1, irc_spc_2]) + sched.unique_species_labels.extend([irc_label_1, irc_label_2]) + sched.running_jobs[irc_label_1] = ['opt_a100'] + sched.running_jobs[irc_label_2] = ['opt_a101'] + sched.job_dict[irc_label_1] = {'opt': {}} + sched.job_dict[irc_label_2] = {'opt': {}} + sched.initialize_output_dict(label=irc_label_1) + sched.initialize_output_dict(label=irc_label_2) + + # Simulate pending pipe entries from the old guess. + sched._pending_pipe_sp.add(ts_label) + sched._pending_pipe_freq.add(ts_label) + sched._pending_pipe_irc.add((ts_label, 'forward')) + sched._pending_pipe_irc.add((ts_label, 'reverse')) + + # Call switch_ts — should pick guess 1 and clean up all state from guess 0. + sched.switch_ts(ts_label) + + # Verify guess 1 was selected. + self.assertEqual(sched.species_dict[ts_label].chosen_ts, 1) + self.assertIn(1, sched.species_dict[ts_label].chosen_ts_list) + + # Verify IRC species from guess 0 fully removed. + self.assertNotIn(irc_label_1, sched.species_dict) + self.assertNotIn(irc_label_2, sched.species_dict) + self.assertNotIn(irc_label_1, sched.running_jobs) + self.assertNotIn(irc_label_2, sched.running_jobs) + self.assertNotIn(irc_label_1, sched.job_dict) + self.assertNotIn(irc_label_2, sched.job_dict) + self.assertNotIn(irc_label_1, sched.output) + self.assertNotIn(irc_label_2, sched.output) + self.assertNotIn(irc_label_1, sched.unique_species_labels) + self.assertNotIn(irc_label_2, sched.unique_species_labels) + self.assertIsNone(sched.species_dict[ts_label].irc_label) + + # Verify job_types reset and convergence cleared. + self.assertFalse(sched.output[ts_label]['job_types']['opt']) + self.assertFalse(sched.output[ts_label]['job_types']['freq']) + self.assertFalse(sched.output[ts_label]['job_types']['sp']) + self.assertIsNone(sched.output[ts_label]['convergence']) + + # Verify pending pipe entries cleared. + self.assertNotIn(ts_label, sched._pending_pipe_sp) + self.assertNotIn(ts_label, sched._pending_pipe_freq) + self.assertNotIn((ts_label, 'forward'), sched._pending_pipe_irc) + self.assertNotIn((ts_label, 'reverse'), sched._pending_pipe_irc) + + # Verify ts_checks were reset. + self.assertIsNone(sched.species_dict[ts_label].ts_checks['freq']) + self.assertIsNone(sched.species_dict[ts_label].ts_checks['NMD']) + self.assertIsNone(sched.species_dict[ts_label].ts_checks['E0']) + @classmethod def tearDownClass(cls): """