diff --git a/.gitignore b/.gitignore index f8c8a682..d9853b91 100644 --- a/.gitignore +++ b/.gitignore @@ -37,4 +37,4 @@ potential all_result.json # dflow debug folders generated by examples or local runs -**/dflow_debug/ +**/dflow_debug/ \ No newline at end of file diff --git a/apex/config.py b/apex/config.py index 28e7e457..7fd4aee5 100644 --- a/apex/config.py +++ b/apex/config.py @@ -67,6 +67,10 @@ class Config: vasp_run_command: str = None abacus_image_name: str = None abacus_run_command: str = None + lammps_header_retry_attempts: int = 2 + lammps_header_retry_delay: float = 5 + lammps_transient_retry_attempts: int = 1 + lammps_retry_group_size: int = None # common APEX config is_bohrium_dflow: bool = False @@ -256,7 +260,11 @@ def basic_config_dict(self): "vasp_image_name": self.vasp_image_name, "vasp_run_command": self.vasp_run_command, "abacus_image_name": self.abacus_image_name, - "abacus_run_command": self.abacus_run_command + "abacus_run_command": self.abacus_run_command, + "lammps_header_retry_attempts": self.lammps_header_retry_attempts, + "lammps_header_retry_delay": self.lammps_header_retry_delay, + "lammps_transient_retry_attempts": self.lammps_transient_retry_attempts, + "lammps_retry_group_size": self.lammps_retry_group_size } return basic_config diff --git a/apex/core/property/Gamma.py b/apex/core/property/Gamma.py index faed243c..28b941b8 100644 --- a/apex/core/property/Gamma.py +++ b/apex/core/property/Gamma.py @@ -36,6 +36,7 @@ def __init__(self, parameter, inter_param=None): self.reprod = parameter["reproduce"] if not self.reprod: if not ("init_from_suffix" in parameter and "output_suffix" in parameter): + parameter["cal_type"] = parameter.get("cal_type", "relaxation") parameter["plane_miller"] = parameter.get("plane_miller", None) self.plane_miller = parameter["plane_miller"] parameter["slip_direction"] = parameter.get("slip_direction", None) @@ -48,14 +49,18 @@ def __init__(self, parameter, inter_param=None): self.supercell_size = parameter["supercell_size"] parameter["vacuum_size"] = parameter.get("vacuum_size", 0) self.vacuum_size = parameter["vacuum_size"] - parameter["add_fix"] = parameter.get( - "add_fix", ["true", "true", "false"] - ) # standard method + if parameter["cal_type"] == "static" and "add_fix" not in parameter: + parameter["add_fix"] = None + else: + parameter["add_fix"] = parameter.get( + "add_fix", ["true", "true", "false"] + ) # standard method self.add_fix = parameter["add_fix"] parameter["n_steps"] = parameter.get("n_steps", 10) self.n_steps = parameter["n_steps"] self.atom_num = None - parameter["cal_type"] = parameter.get("cal_type", "relaxation") + else: + parameter["cal_type"] = parameter.get("cal_type", "relaxation") default_cal_setting = { "relax_pos": True, "relax_shape": False, @@ -157,7 +162,18 @@ def make_confs(self, path_to_work, path_to_equi, refine=False): equi_contcar = os.path.join(path_to_equi, CONTCAR) if not os.path.exists(equi_contcar): - raise RuntimeError("please do relaxation first") + raise RuntimeError( + "Gamma requires a baseline relaxation before PropsMake. " + "For static gamma scans, run a static baseline relaxation " + "with relax_pos=false, relax_shape=false, and relax_vol=false." + ) + equi_result = os.path.join(path_to_equi, "result.json") + if not os.path.exists(equi_result): + raise RuntimeError( + "Gamma post-processing requires relaxation/relax_task/result.json. " + "Please provide a static baseline relaxation result before " + "running gamma PropsMake." + ) # print("we now only support gamma line calculation for BCC FCC and HCP metals") # print( # f"supported slip systems are:\n{SlabSlipSystem.hint_string()}" @@ -451,6 +467,8 @@ def __inLammpes_fix(self, inLammps) -> None: ) with open(inLammps, "r") as fin1: contents = fin1.readlines() + lower_id = None + upper_id = None for ii in range(len(contents)): upper = contents[ii].split()[:3] == ["variable", "N", "equal"] lower = re.search("min_style cg", contents[ii]) @@ -460,6 +478,13 @@ def __inLammpes_fix(self, inLammps) -> None: elif upper: upper_id = ii # print(upper_id) + if lower_id is None or upper_id is None or lower_id >= upper_id: + raise RuntimeError( + "Gamma add_fix was requested, but in.lammps does not contain " + "a compatible minimization block to patch. For static gamma " + "calculations set add_fix to null, or use a relaxation-style " + "LAMMPS input with min_style cg and variable N equal markers." + ) del contents[lower_id + 1:upper_id - 1] contents.insert(lower_id + 1, add_fix_str) with open(inLammps, "w") as fin2: diff --git a/apex/core/property/Gruneisen.py b/apex/core/property/Gruneisen.py index 4ccedaa7..2bf0462d 100644 --- a/apex/core/property/Gruneisen.py +++ b/apex/core/property/Gruneisen.py @@ -794,8 +794,8 @@ def _ensure_mesh_yaml(self, task_dir: str) -> None: raise FileNotFoundError(f"POSCAR not found in {task_dir}") os.chdir(task_dir) cell_file = "POSCAR-unitcell" if self.inter_param["type"] == "vasp" else "POSCAR" - command = ( - 'phonopy --nomeshsym --dim="%s %s %s" -c %s band.conf' + command = Phonon.phonopy_command( + '--nomeshsym --dim="%s %s %s" -c %s band.conf' % (self.supercell_size[0], self.supercell_size[1], self.supercell_size[2], cell_file) ) subprocess.check_call(command, shell=True) @@ -868,8 +868,8 @@ def _ensure_vasp_volume_outputs( cwd = os.getcwd() try: os.chdir(helper_dir) - subprocess.check_call( - Phonon.phonopy_setup_command( + Phonon.run_first_success( + Phonon.phonopy_writefc_commands( '--dim="%s %s %s" -c POSCAR-unitcell --writefc' % ( self.supercell_size[0], @@ -877,7 +877,7 @@ def _ensure_vasp_volume_outputs( self.supercell_size[2], ) ), - shell=True, + required_file="FORCE_CONSTANTS", ) finally: os.chdir(cwd) @@ -888,8 +888,10 @@ def _ensure_vasp_volume_outputs( try: os.chdir(helper_dir) subprocess.check_call( - 'phonopy --dim="%s %s %s" -c POSCAR-unitcell band.conf' - % (self.supercell_size[0], self.supercell_size[1], self.supercell_size[2]), + Phonon.phonopy_command( + '--dim="%s %s %s" -c POSCAR-unitcell band.conf' + % (self.supercell_size[0], self.supercell_size[1], self.supercell_size[2]) + ), shell=True, ) self._write_band_dat() @@ -954,14 +956,14 @@ def _ensure_abacus_volume_outputs( # Pass phonopy_disp.yaml explicitly so phonopy reads the supercell from the yaml # rather than falling into old-style POSCAR mode (which has no DIM). if not os.path.isfile("FORCE_CONSTANTS"): - subprocess.check_call( - Phonon.phonopy_setup_command("phonopy_disp.yaml --writefc"), - shell=True, + Phonon.run_first_success( + Phonon.phonopy_writefc_commands("phonopy_disp.yaml --writefc"), + required_file="FORCE_CONSTANTS", ) if not os.path.isfile("FORCE_CONSTANTS"): raise FileNotFoundError(f"FORCE_CONSTANTS was not created in {helper_dir}") if not os.path.isfile("mesh.yaml"): - subprocess.check_call("phonopy band.conf", shell=True) + subprocess.check_call(Phonon.phonopy_command("band.conf"), shell=True) self._write_band_dat() finally: os.chdir(cwd) @@ -973,28 +975,7 @@ def _write_band_dat() -> None: if not os.path.isfile("band.yaml"): logging.warning("band.yaml was not created; skipping band.dat export") return - with open("band.dat", "w") as fp: - result = subprocess.run( - ["phonopy-bandplot", "--gnuplot", "band.yaml"], - stdout=fp, - stderr=subprocess.PIPE, - text=True, - ) - if result.returncode == 0: - return - if os.path.isfile("band.dat") and os.path.getsize("band.dat") > 0: - logging.warning( - "phonopy-bandplot exited with code %s after writing band.dat; continuing. stderr: %s", - result.returncode, - result.stderr.strip(), - ) - return - raise subprocess.CalledProcessError( - result.returncode, - ["phonopy-bandplot", "--gnuplot", "band.yaml"], - output=None, - stderr=result.stderr, - ) + Phonon.write_band_dat() def _attach_abacus_reference_energies( self, task_infos: List[dict], task_result_map: Dict[str, str] diff --git a/apex/core/property/Interstitial.py b/apex/core/property/Interstitial.py index 27ba7f5d..dfdf1038 100644 --- a/apex/core/property/Interstitial.py +++ b/apex/core/property/Interstitial.py @@ -294,6 +294,7 @@ def make_confs(self, path_to_work, path_to_equi, refine=False): with open("task.000000/POSCAR", "r") as fin: self.pos_line = fin.read().split("\n") + chl = None for idx, ii in enumerate(self.pos_line): ss = ii.split() if len(ss) > 3: @@ -303,11 +304,19 @@ def make_confs(self, path_to_work, path_to_equi, refine=False): and abs(0.14 / self.supercell[2] - float(ss[2])) < TOL ): chl = idx + if chl is None: + raise RuntimeError( + f"Could not locate the generated interstitial anchor site " + f"for {self.structure_type} special interstitial generation. " + "Check the relaxed structure, supercell, or set special_list=[] " + "to use the Voronoi generator." + ) shutil.rmtree("task.000000") os.chdir(cwd) # specify interstitial structures if self.structure_type == 'bcc': + center = None for idx, ii in enumerate(self.pos_line): ss = ii.split() if len(ss) > 3: @@ -317,6 +326,13 @@ def make_confs(self, path_to_work, path_to_equi, refine=False): and abs(0.5 / self.supercell[2] - float(ss[2])) < TOL ): center = idx + if center is None: + raise RuntimeError( + "Could not locate the BCC center atom required for " + "special interstitial dumbbell generation. Check the " + "relaxed structure/supercell or set special_list=[] " + "to use the Voronoi generator." + ) bcc_interstital_dict = { 'tetrahedral': {chl: [0.25, 0.5, 0]}, 'octahedral': {chl: [0.5, 0.5, 0]}, @@ -331,6 +347,8 @@ def make_confs(self, path_to_work, path_to_equi, refine=False): total_task = self.__gen_tasks(bcc_interstital_dict) elif self.structure_type == 'fcc': + face = None + corner = None for idx, ii in enumerate(self.pos_line): ss = ii.split() if len(ss) > 3: @@ -347,6 +365,13 @@ def make_confs(self, path_to_work, path_to_equi, refine=False): and abs(1 / self.supercell[2] - float(ss[2])) < TOL ): corner = idx + if face is None or corner is None: + raise RuntimeError( + "Could not locate the FCC face/corner atoms required " + "for special interstitial generation. Check the relaxed " + "structure/supercell or set special_list=[] to use the " + "Voronoi generator." + ) fcc_interstital_dict = { 'tetrahedral': {chl: [0.75, 0.25, 0.25]}, @@ -376,6 +401,7 @@ def make_confs(self, path_to_work, path_to_equi, refine=False): total_task = self.__gen_tasks(fcc_interstital_dict) elif self.structure_type == 'hcp': + center = None for idx, ii in enumerate(self.pos_line): ss = ii.split() if len(ss) > 3: @@ -385,6 +411,13 @@ def make_confs(self, path_to_work, path_to_equi, refine=False): and abs(0.25 / self.supercell[2] - float(ss[2])) < TOL ): center = idx + if center is None: + raise RuntimeError( + "Could not locate the HCP center atom required for " + "special interstitial generation. Check the relaxed " + "structure/supercell or set special_list=[] to use the " + "Voronoi generator." + ) hcp_interstital_dict = { 'O': {chl: [0, 0, 0.5]}, 'BO': {chl: [0, 0, 0.25]}, diff --git a/apex/core/property/Phonon.py b/apex/core/property/Phonon.py index d8fa15cc..1951b4f2 100644 --- a/apex/core/property/Phonon.py +++ b/apex/core/property/Phonon.py @@ -30,6 +30,60 @@ def phonopy_setup_command(arguments: str) -> str: executable = "phonopy-init" if shutil.which("phonopy-init") else "phonopy" return f"{executable} {arguments}" + @staticmethod + def phonopy_command(arguments: str) -> str: + return f"phonopy {arguments}" + + @staticmethod + def phonopy_writefc_commands(arguments: str) -> List[str]: + setup_command = Phonon.phonopy_setup_command(arguments) + phonopy_command = Phonon.phonopy_command(arguments) + if setup_command == phonopy_command: + return [setup_command] + return [setup_command, phonopy_command] + + @staticmethod + def run_first_success(commands: List[str], required_file: str | None = None) -> None: + errors = [] + for command in commands: + try: + subprocess.check_call(command, shell=True) + except subprocess.CalledProcessError as exc: + errors.append(exc) + continue + if required_file is None or os.path.isfile(required_file): + return + errors.append(FileNotFoundError(f"{required_file} was not created by: {command}")) + if errors: + raise errors[-1] + + @staticmethod + def write_band_dat() -> None: + if not os.path.isfile("band.yaml"): + raise FileNotFoundError("band.yaml was not created") + with open("band.dat", "w") as fp: + result = subprocess.run( + ["phonopy-bandplot", "--gnuplot", "band.yaml"], + stdout=fp, + stderr=subprocess.PIPE, + text=True, + ) + if result.returncode == 0: + return + if os.path.isfile("band.dat") and os.path.getsize("band.dat") > 0: + logging.warning( + "phonopy-bandplot exited with code %s after writing band.dat; continuing. stderr: %s", + result.returncode, + result.stderr.strip(), + ) + return + raise subprocess.CalledProcessError( + result.returncode, + ["phonopy-bandplot", "--gnuplot", "band.yaml"], + output=None, + stderr=result.stderr, + ) + def __init__(self, parameter, inter_param=None): parameter["reproduce"] = parameter.get("reproduce", False) self.reprod = parameter["reproduce"] @@ -529,7 +583,7 @@ def phonopy_band_string_2_band_list(band_str: str, band_label: str = None): @staticmethod def check_same_copy(src, dst): - if os.path.samefile(src, dst): + if os.path.exists(dst) and os.path.samefile(src, dst): return shutil.copyfile(src, dst) @@ -537,90 +591,110 @@ def _compute_lower(self, output_file, all_tasks, all_res): cwd = Path.cwd() work_path = Path(output_file).parent.absolute() output_file = os.path.abspath(output_file) + resolved_tasks = [] + for task in all_tasks: + task_path = Path(task) + if task_path.is_absolute(): + resolved_tasks.append(str(task_path)) + elif (cwd / task_path).exists(): + resolved_tasks.append(str((cwd / task_path).absolute())) + elif task_path.parent == Path("."): + resolved_tasks.append(str((work_path / task_path).absolute())) + else: + resolved_tasks.append(str((cwd / task_path).absolute())) + all_tasks = resolved_tasks res_data = {} ptr_data = os.path.dirname(output_file) + "\n" band_path = loadfn(os.path.join(work_path, "band_path.json")) - if not self.reprod: - os.chdir(work_path) - if self.inter_param["type"] == 'abacus': - self.check_same_copy("task.000000/band.conf", "band.conf") - self.check_same_copy("task.000000/STRU.ori", "STRU") - self.check_same_copy("task.000000/phonopy_disp.yaml", "phonopy_disp.yaml") - os.system(self.phonopy_setup_command("-f task.0*/OUT.ABACUS/running_scf.log")) - if os.path.exists("FORCE_SETS"): + try: + if not self.reprod: + os.chdir(work_path) + if self.inter_param["type"] == 'abacus': + self.check_same_copy("task.000000/band.conf", "band.conf") + self.check_same_copy("task.000000/STRU.ori", "STRU") + self.check_same_copy("task.000000/phonopy_disp.yaml", "phonopy_disp.yaml") + subprocess.check_call( + self.phonopy_setup_command("-f task.0*/OUT.ABACUS/running_scf.log"), + shell=True, + ) + if not os.path.exists("FORCE_SETS"): + raise FileNotFoundError("FORCE_SETS was not created") print('FORCE_SETS is created') - else: - logging.warning('FORCE_SETS can not be created') - os.system('phonopy band.conf --abacus') - os.system('phonopy-bandplot --gnuplot band.yaml > band.dat') + subprocess.check_call(self.phonopy_command("band.conf"), shell=True) + self.write_band_dat() - elif self.inter_param["type"] == 'vasp': - self.check_same_copy("task.000000/band.conf", "band.conf") - self.check_same_copy("task.000000/POSCAR-unitcell", "POSCAR-unitcell") + elif self.inter_param["type"] == 'vasp': + self.check_same_copy("task.000000/band.conf", "band.conf") + if not os.path.exists("POSCAR-unitcell"): + self.check_same_copy("task.000000/POSCAR-unitcell", "POSCAR-unitcell") + + if self.approach == "linear": + os.chdir(all_tasks[0]) + assert os.path.isfile('vasprun.xml'), "vasprun.xml not found" + subprocess.check_call(self.phonopy_setup_command("--fc vasprun.xml"), shell=True) + assert os.path.isfile('FORCE_CONSTANTS'), "FORCE_CONSTANTS not created" + subprocess.check_call(self.phonopy_command('--dim="%s %s %s" -c POSCAR-unitcell band.conf' % ( + self.supercell_size[0], + self.supercell_size[1], + self.supercell_size[2])), shell=True) + self.write_band_dat() + print('band.dat is created') + shutil.copyfile("band.dat", work_path/"band.dat") + + elif self.approach == "displacement": + self.check_same_copy("task.000000/band.conf", "band.conf") + self.check_same_copy("task.000000/phonopy_disp.yaml", "phonopy_disp.yaml") + subprocess.check_call(self.phonopy_setup_command("-f task.0*/vasprun.xml"), shell=True) + if not os.path.exists("FORCE_SETS"): + raise FileNotFoundError("FORCE_SETS was not created") + print('FORCE_SETS is created') + subprocess.check_call(self.phonopy_command('--dim="%s %s %s" -c POSCAR-unitcell band.conf' % ( + self.supercell_size[0], + self.supercell_size[1], + self.supercell_size[2])), shell=True) + self.write_band_dat() - if self.approach == "linear": + elif self.inter_param["type"] in LAMMPS_INTER_TYPE: os.chdir(all_tasks[0]) - assert os.path.isfile('vasprun.xml'), "vasprun.xml not found" - os.system(self.phonopy_setup_command("--fc vasprun.xml")) assert os.path.isfile('FORCE_CONSTANTS'), "FORCE_CONSTANTS not created" - os.system('phonopy --dim="%s %s %s" -c POSCAR-unitcell band.conf' % ( - self.supercell_size[0], - self.supercell_size[1], - self.supercell_size[2])) - os.system('phonopy-bandplot --gnuplot band.yaml > band.dat') - print('band.dat is created') + subprocess.check_call(self.phonopy_command('--dim="%s %s %s" -c POSCAR band.conf' % ( + self.supercell_size[0], self.supercell_size[1], self.supercell_size[2]) + ), shell=True) + self.write_band_dat() shutil.copyfile("band.dat", work_path/"band.dat") - elif self.approach == "displacement": - self.check_same_copy("task.000000/band.conf", "band.conf") - self.check_same_copy("task.000000/phonopy_disp.yaml", "phonopy_disp.yaml") - os.system(self.phonopy_setup_command("-f task.0*/vasprun.xml")) - if os.path.exists("FORCE_SETS"): - print('FORCE_SETS is created') - else: - logging.warning('FORCE_SETS can not be created') - os.system('phonopy --dim="%s %s %s" -c POSCAR-unitcell band.conf' % ( - self.supercell_size[0], - self.supercell_size[1], - self.supercell_size[2])) - os.system('phonopy-bandplot --gnuplot band.yaml > band.dat') - - elif self.inter_param["type"] in LAMMPS_INTER_TYPE: - os.chdir(all_tasks[0]) - assert os.path.isfile('FORCE_CONSTANTS'), "FORCE_CONSTANTS not created" - os.system('phonopy --dim="%s %s %s" -c POSCAR band.conf' % ( - self.supercell_size[0], self.supercell_size[1], self.supercell_size[2]) - ) - os.system('phonopy-bandplot --gnuplot band.yaml > band.dat') - shutil.copyfile("band.dat", work_path/"band.dat") - - else: - if "init_data_path" not in self.parameter: - raise RuntimeError("please provide the initial data path to reproduce") - init_data_path = os.path.abspath(self.parameter["init_data_path"]) - res_data, ptr_data = post_repro( - init_data_path, - self.parameter["init_from_suffix"], - all_tasks, - ptr_data, - self.parameter.get("reprod_last_frame", True), - ) - - os.chdir(work_path) - with open('band.dat', 'r') as f: - ptr_data = f.read() - - result_points = ptr_data.split('\n')[1][4:].split() - result_lines = ptr_data.split('\n')[2:] - unpacked_lines = self.unpack_band('\n'.join(result_lines)) - res_data['segment'] = result_points - res_data['band_path'] = band_path - res_data['band'] = unpacked_lines - - with open(output_file, "w") as fp: - json.dump(res_data, fp, indent=4) + else: + if "init_data_path" not in self.parameter: + raise RuntimeError("please provide the initial data path to reproduce") + init_data_path = os.path.abspath(self.parameter["init_data_path"]) + res_data, ptr_data = post_repro( + init_data_path, + self.parameter["init_from_suffix"], + all_tasks, + ptr_data, + self.parameter.get("reprod_last_frame", True), + ) - os.chdir(cwd) - return res_data, ptr_data + os.chdir(work_path) + if not os.path.isfile("band.dat"): + raise FileNotFoundError("band.dat was not created") + with open('band.dat', 'r') as f: + ptr_data = f.read() + + if len(ptr_data.split('\n')) < 2: + raise ValueError("band.dat is empty or malformed") + result_points = ptr_data.split('\n')[1][4:].split() + result_lines = ptr_data.split('\n')[2:] + unpacked_lines = self.unpack_band('\n'.join(result_lines)) + res_data['segment'] = result_points + res_data['band_path'] = band_path + res_data['band'] = unpacked_lines + + with open(output_file, "w") as fp: + json.dump(res_data, fp, indent=4) + + return res_data, ptr_data + finally: + os.chdir(cwd) diff --git a/apex/flow.py b/apex/flow.py index d63cab76..cd5e70d6 100644 --- a/apex/flow.py +++ b/apex/flow.py @@ -25,7 +25,7 @@ from apex.superop.RelaxationFlow import RelaxationFlow from apex.superop.SimplePropertySteps import SimplePropertySteps from apex.op.relaxation_ops import RelaxMake, RelaxPost -from apex.op.property_ops import PropsMake, PropsPost +from apex.op.property_ops import PropsMake, PropsPost, PropsRepairStatusCheck from apex.utils import json2dict, handle_prop_suffix from dflow.python import upload_packages @@ -46,6 +46,7 @@ def __init__( relax_post_op: Type[OP] = RelaxPost, props_make_op: Type[OP] = PropsMake, props_post_op: Type[OP] = PropsPost, + props_repair_op: Type[OP] = PropsRepairStatusCheck, group_size: Optional[int] = None, pool_size: Optional[int] = None, executor: Optional[DispatcherExecutor] = None, @@ -62,6 +63,7 @@ def __init__( self.relax_post_op = relax_post_op self.props_make_op = props_make_op self.props_post_op = props_post_op + self.props_repair_op = props_repair_op self.run_op = run_op self.make_image = make_image self.run_image = run_image @@ -1001,6 +1003,7 @@ def _set_props_flow( make_op=self.props_make_op, run_op=self.run_op, post_op=self.props_post_op, + repair_op=self.props_repair_op, make_image=self.make_image, run_image=self.run_image, post_image=self.post_image, @@ -1122,6 +1125,7 @@ def _set_props_tasks( make_op=self.props_make_op, run_op=self.run_op, post_op=self.props_post_op, + repair_op=self.props_repair_op, make_image=self.make_image, run_image=self.run_image, post_image=self.post_image, diff --git a/apex/main.py b/apex/main.py index 1d30565d..de049c0e 100644 --- a/apex/main.py +++ b/apex/main.py @@ -4,6 +4,7 @@ import os import datetime import time +import json from typing import List from dflow import ( @@ -23,6 +24,7 @@ from apex.archive import archive_from_args from apex.report import report_from_args from apex.utils import load_config_file +from apex.task_failure import classify_apex_task_status def parse_args(): @@ -902,11 +904,11 @@ def _is_transient_download_error(exc: Exception) -> bool: return any(marker in message for marker in _TRANSIENT_DOWNLOAD_MARKERS) -def _download_artifact_with_retry(artifact, path, retries: int = 3, delay: int = 10): +def _download_artifact_with_retry(artifact, path, retries: int = 3, delay: int = 10, **kwargs): last_exc = None for attempt in range(1, retries + 1): try: - return download_artifact(artifact=artifact, path=path) + return download_artifact(artifact=artifact, path=path, **kwargs) except Exception as exc: last_exc = exc if _is_missing_artifact_error(exc) or not _is_transient_download_error(exc): @@ -1019,6 +1021,94 @@ def _collect_step_with_children(wf_info, root_step): return all_steps +def _safe_read_text(path: str, limit: int = 200000) -> str: + try: + with open(path, "r", encoding="utf-8", errors="replace") as fp: + return fp.read(limit) + except Exception: + return "" + + +def _extract_failed_task_ids(artifact_root: str) -> list[str]: + task_ids = set() + for root, _dirs, files in os.walk(artifact_root): + for filename in files: + if filename not in {"main.log", "failed_lammps_tasks.json", "run_status_check.json"}: + continue + text = _safe_read_text(os.path.join(root, filename)) + for match in re.findall(r"task\.(\d{6})", text): + task_ids.add(match) + return sorted(task_ids) + + +def _is_diagnostic_file(filename: str) -> bool: + if filename in { + "apex_task_status.json", + ".debug.log", + "log.lammps", + "outlog", + "task.json", + "failed_lammps_tasks.json", + "run_status_check.json", + }: + return True + return filename.endswith(".json") and filename not in {"param.json", "result.json", "result_task.json"} + + +def _write_failed_artifact_summary(key: str, artifact_root: str) -> dict: + task_records = [] + diagnostic_files = [] + classifications = {} + for root, _dirs, files in os.walk(artifact_root): + for filename in files: + full_path = os.path.join(root, filename) + rel_path = os.path.relpath(full_path, artifact_root) + if _is_diagnostic_file(filename): + diagnostic_files.append(rel_path) + if filename != "apex_task_status.json": + continue + try: + with open(full_path, "r", encoding="utf-8", errors="replace") as fp: + status = json.load(fp) + except Exception as exc: + classified = { + "state": "failed", + "reason": "invalid_task_status", + "message": f"Could not parse apex_task_status.json: {exc}", + "exit_code": None, + } + else: + classified = classify_apex_task_status(status, root) + reason = classified.get("reason", "unknown_failure") + classifications[reason] = classifications.get(reason, 0) + 1 + task_records.append( + { + "task": os.path.basename(root), + "classification": reason, + "state": classified.get("state"), + "exit_code": classified.get("exit_code"), + "diagnostic_path": os.path.relpath(root, artifact_root), + "retry_reason": classified.get("retry_reason"), + "original_reason": classified.get("original_reason"), + } + ) + + failed_tasks = [record for record in task_records if record.get("state") != "succeeded"] + summary = { + "step_key": key, + "failed_task_count": len(failed_tasks), + "task_status_count": len(task_records), + "classifications": classifications, + "failed_tasks": failed_tasks, + "diagnostic_files": sorted(diagnostic_files), + } + os.makedirs(artifact_root, exist_ok=True) + summary_path = os.path.join(artifact_root, "summary.json") + with open(summary_path, "w", encoding="utf-8") as fp: + json.dump(summary, fp, indent=4) + return summary + + def _is_retrievable_result_step_key(key: str) -> bool: prefix = str(key).split("-")[0] return prefix in {"propertycal", "relaxcal"} or key == "relaxationcal" @@ -1041,6 +1131,40 @@ def _download_failure_artifacts_for_step(wf_info, root_step, key, work_dir): related_steps = _collect_step_with_children(wf_info, root_step) downloaded = 0 seen = set() + artifact_root = os.path.join( + work_dir, + ".failed-artifacts", + _sanitize_path_token(key), + ) + + def _target_dir(step_name, art_name, suffix=None): + parts = [ + artifact_root, + _sanitize_path_token(step_name), + _sanitize_path_token(art_name), + ] + if suffix: + parts.append(_sanitize_path_token(suffix)) + return os.path.join(*parts) + + def _download_one(step_name, step_id, art_name, artifact, target_dir, **kwargs): + nonlocal downloaded + os.makedirs(target_dir, exist_ok=True) + if _directory_has_entries(target_dir): + logging.info( + "Skip retrieving failure artifact %s for step %s (%s) " + "because %s already contains files.", + art_name, + step_name, + key, + target_dir, + ) + return + _download_artifact_with_retry(artifact=artifact, path=target_dir, **kwargs) + downloaded += 1 + + # Download failed step logs first; they often contain the failed task IDs + # needed for focused backward_dir slice retrieval. for step in related_steps: step_id = _safe_get(step, "id", "step") step_name = _safe_get(step, "displayName", _safe_get(step, "name", step_id)) @@ -1048,34 +1172,60 @@ def _download_failure_artifacts_for_step(wf_info, root_step, key, work_dir): for art_name, artifact in artifacts.items(): if art_name.startswith("dflow_"): continue - if art_name not in preferred_names: + if art_name not in {"main-logs", "main_logs"}: continue key_tuple = (str(step_id), str(art_name)) if key_tuple in seen: continue seen.add(key_tuple) - - target_dir = os.path.join( - work_dir, - ".failed-artifacts", - _sanitize_path_token(key), - _sanitize_path_token(step_name), - _sanitize_path_token(art_name), - ) - os.makedirs(target_dir, exist_ok=True) - if _directory_has_entries(target_dir): - logging.info( - "Skip retrieving failure artifact %s for step %s (%s) " - "because %s already contains files.", + try: + _download_one(step_name, step_id, art_name, artifact, _target_dir(step_name, art_name)) + except Exception as exc: + logging.warning( + "Failed to download artifact %s for step %s (%s): %s", art_name, step_name, key, - target_dir, + exc, ) + + failed_task_ids = _extract_failed_task_ids(artifact_root) + + for step in related_steps: + step_id = _safe_get(step, "id", "step") + step_name = _safe_get(step, "displayName", _safe_get(step, "name", step_id)) + artifacts = _get_step_artifacts(step) + for art_name, artifact in artifacts.items(): + if art_name.startswith("dflow_"): + continue + if art_name not in preferred_names or art_name in {"main-logs", "main_logs"}: + continue + key_tuple = (str(step_id), str(art_name)) + if key_tuple in seen: continue + seen.add(key_tuple) try: - _download_artifact_with_retry(artifact=artifact, path=target_dir) - downloaded += 1 + if art_name == "backward_dir": + if failed_task_ids: + for task_id in failed_task_ids: + _download_one( + step_name, + step_id, + art_name, + artifact, + _target_dir(step_name, art_name, f"task.{task_id}"), + slice=int(task_id), + remove_catalog=False, + ) + else: + logging.warning( + "Could not determine failed task slices for %s; " + "falling back to full backward_dir download.", + key, + ) + _download_one(step_name, step_id, art_name, artifact, _target_dir(step_name, art_name)) + else: + _download_one(step_name, step_id, art_name, artifact, _target_dir(step_name, art_name)) except Exception as exc: logging.warning( "Failed to download artifact %s for step %s (%s): %s", @@ -1084,6 +1234,13 @@ def _download_failure_artifacts_for_step(wf_info, root_step, key, work_dir): key, exc, ) + summary = _write_failed_artifact_summary(key, artifact_root) + logging.warning( + "Failure summary for %s: %s failed task(s), classifications=%s", + key, + summary.get("failed_task_count", 0), + summary.get("classifications", {}), + ) return downloaded diff --git a/apex/op/RunLAMMPS.py b/apex/op/RunLAMMPS.py index adbd4294..44c6c65f 100644 --- a/apex/op/RunLAMMPS.py +++ b/apex/op/RunLAMMPS.py @@ -1,5 +1,6 @@ import datetime import os, subprocess, logging, time +import re from pathlib import Path from monty.serialization import dumpfn, loadfn from dflow.python import ( @@ -9,6 +10,12 @@ Artifact, upload_packages ) +from apex.task_failure import ( + HEADER_ONLY_RETRY_REASON, + classify_lammps_exit_code, + is_header_only_lammps_failure, + is_lammps_header_only_log, +) upload_packages.append(__file__) @@ -66,53 +73,7 @@ def _utc_now(cls) -> str: @classmethod def _classify_exit_code(cls, exit_code: int) -> dict: - if exit_code == 0: - return { - "state": "succeeded", - "reason": "command_exit_zero", - "message": "Command completed successfully.", - } - if exit_code == 124: - return { - "state": "failed", - "reason": "timeout", - "message": "Command exited with timeout code 124.", - } - if exit_code == 126: - return { - "state": "failed", - "reason": "command_not_executable", - "message": "Command was found but could not be executed.", - } - if exit_code == 127: - return { - "state": "failed", - "reason": "command_not_found", - "message": "Command executable was not found.", - } - if exit_code in (130, 143): - return { - "state": "failed", - "reason": "terminated", - "message": f"Command was terminated by signal-like exit code {exit_code}.", - } - if exit_code == 137: - return { - "state": "failed", - "reason": "killed_or_oom", - "message": "Command was killed with exit code 137, commonly SIGKILL/OOM/preemption.", - } - if exit_code > 128: - return { - "state": "failed", - "reason": "signal_exit", - "message": f"Command exited with code {exit_code}, likely signal {exit_code - 128}.", - } - return { - "state": "failed", - "reason": "nonzero_exit", - "message": f"Command exited with non-zero code {exit_code}.", - } + return classify_lammps_exit_code(exit_code) @classmethod def _write_task_status( @@ -127,8 +88,15 @@ def _write_task_status( debug_log: str = ".debug.log", attempts: int = 1, retry_reason: str | None = None, + task_dir: Path | None = None, ): - status = cls._classify_exit_code(exit_code) + remote_startup = ( + retry_reason == HEADER_ONLY_RETRY_REASON + and exit_code != 0 + ) + if task_dir is not None: + remote_startup = remote_startup or is_header_only_lammps_failure(task_dir, exit_code) + status = classify_lammps_exit_code(exit_code, remote_startup=remote_startup) payload = { **status, "exit_code": int(exit_code), @@ -141,6 +109,10 @@ def _write_task_status( } if retry_reason: payload["retry_reason"] = retry_reason + payload["retry_classification"] = classify_lammps_exit_code( + 1, + remote_startup=(retry_reason == HEADER_ONLY_RETRY_REASON), + )["reason"] dumpfn( payload, status_file, @@ -168,6 +140,31 @@ def _safe_cmd(cls, cmd: str, timeout: int = 10) -> str: except Exception as exc: return f"" + @classmethod + def _command_env_value(cls, cmd: str, name: str): + match = re.search(rf"(?:^|\s){re.escape(name)}=([^\s]+)", str(cmd)) + return match.group(1) if match else None + + @classmethod + def _runtime_int_option(cls, cmd: str, name: str, default: int) -> int: + value = cls._command_env_value(cmd, name) + if value is None: + value = os.environ.get(name, str(default)) + try: + return int(value) + except (TypeError, ValueError): + return default + + @classmethod + def _runtime_float_option(cls, cmd: str, name: str, default: float) -> float: + value = cls._command_env_value(cmd, name) + if value is None: + value = os.environ.get(name, str(default)) + try: + return float(value) + except (TypeError, ValueError): + return default + @classmethod def _tail_file(cls, path: Path, n_lines: int = 80) -> str: if not path.exists(): @@ -252,28 +249,11 @@ def _log_candidates(cls, task_dir: Path) -> list[Path]: @classmethod def _is_lammps_header_only_log(cls, path: Path) -> bool: - if not path.is_file(): - return False - try: - lines = [ - line.strip() - for line in path.read_text(errors="replace").splitlines() - if line.strip() - ] - except Exception: - return False - return len(lines) == 1 and lines[0].startswith("LAMMPS (") + return is_lammps_header_only_log(path) @classmethod def _is_header_only_lammps_failure(cls, task_dir: Path, exit_code: int) -> bool: - if exit_code == 0: - return False - if any((task_dir / name).exists() for name in ["CONTCAR", "dump.relax", "stress_timeseries.txt"]): - return False - return ( - cls._is_lammps_header_only_log(task_dir / "log.lammps") - or cls._is_lammps_header_only_log(task_dir / "outlog") - ) + return is_header_only_lammps_failure(task_dir, exit_code) @classmethod def _archive_retry_file(cls, path: Path, attempt: int): @@ -380,6 +360,7 @@ def execute(self, op_in: OPIO) -> OPIO: elapsed=0.0, started_at=now, finished_at=now, + task_dir=task_dir, ) self._write_final_debug(debug_file, task_dir, 127, 0.0) self._cleanup_model_links(task_dir) @@ -389,14 +370,18 @@ def execute(self, op_in: OPIO) -> OPIO: start = time.time() retry_reason = None attempts = 1 - max_attempts = int(os.environ.get("APEX_LAMMPS_HEADER_RETRY", "2")) + max_attempts = self._runtime_int_option(cmd, "APEX_LAMMPS_HEADER_RETRY", 2) max_attempts = max(1, max_attempts) + retry_delay = max( + 0.0, + self._runtime_float_option(cmd, "APEX_LAMMPS_HEADER_RETRY_DELAY", 5.0), + ) exit_code = self._run_command(cmd, task_dir) while ( attempts < max_attempts and self._is_header_only_lammps_failure(task_dir, exit_code) ): - retry_reason = "header_only_lammps_log_after_nonzero_exit" + retry_reason = HEADER_ONLY_RETRY_REASON self._append_debug( debug_file, f"\n## Retry {attempts + 1}\n" @@ -404,7 +389,7 @@ def execute(self, op_in: OPIO) -> OPIO: "contains only the LAMMPS header.", ) self._prepare_retry(task_dir, attempts) - time.sleep(float(os.environ.get("APEX_LAMMPS_HEADER_RETRY_DELAY", "5"))) + time.sleep(retry_delay) attempts += 1 exit_code = self._run_command(cmd, task_dir) elapsed = time.time() - start @@ -418,6 +403,7 @@ def execute(self, op_in: OPIO) -> OPIO: finished_at=finished_at, attempts=attempts, retry_reason=retry_reason, + task_dir=task_dir, ) self._write_final_debug(debug_file, task_dir, exit_code, elapsed) if exit_code == 0: diff --git a/apex/op/property_ops.py b/apex/op/property_ops.py index 6d09fe01..bf4e3883 100644 --- a/apex/op/property_ops.py +++ b/apex/op/property_ops.py @@ -13,6 +13,10 @@ from apex.utils import recursive_search, apex_task_succeeded from apex.core.lib.utils import create_path from apex.core.calculator import LAMMPS_INTER_TYPE +from apex.task_failure import ( + REMOTE_LAMMPS_STARTUP_FAILURE, + classify_apex_task_status, +) upload_packages.append(__file__) @@ -34,7 +38,101 @@ def _load_task_status(status_path: Path): def _is_failed_task_status(status) -> bool: if status is None: return False - return status.get("state") != "succeeded" or status.get("exit_code") != 0 + classified = classify_apex_task_status(status) + return classified.get("state") != "succeeded" or classified.get("exit_code") != 0 + + +def _collect_lammps_status_failures(path_to_prop: Path): + failures = [] + for status_path in sorted(path_to_prop.glob("task.*/apex_task_status.json")): + status = _load_task_status(status_path) + classified = classify_apex_task_status(status, status_path.parent) + if classified.get("state") != "succeeded" or classified.get("exit_code") != 0: + failures.append( + { + "task": str(status_path.parent), + "state": classified.get("state"), + "reason": classified.get("reason"), + "exit_code": classified.get("exit_code"), + "message": classified.get("message"), + "retry_reason": classified.get("retry_reason"), + "original_reason": classified.get("original_reason"), + } + ) + return failures + + +class PropsRepairStatusCheck(OP): + """ + Lightweight status gate between LAMMPS run and property post. + + The actual bounded retry happens inside RunLAMMPS for the only transient + class we currently trust. This OP records which failures remain eligible + for that repair path and leaves deterministic errors for PropsPost to fail. + """ + + @classmethod + def get_input_sign(cls): + return OPIOSign({ + 'input_post': Artifact(Path, sub_path=False), + 'input_all': Artifact(Path), + 'task_names': List[str], + 'path_to_prop': str + }) + + @classmethod + def get_output_sign(cls): + return OPIOSign({ + 'checked_post': Artifact(Path, sub_path=False) + }) + + @OP.exec_sign_check + def execute(self, op_in: OPIO) -> OPIO: + cwd = os.getcwd() + input_post = op_in["input_post"] + input_all = op_in["input_all"] + task_names = op_in["task_names"] + path_to_prop = op_in["path_to_prop"] + + if len(task_names) == 0: + return OPIO({"checked_post": input_post}) + + try: + copy_dir_list_input = [path_to_prop.split('/')[0]] + os.chdir(input_all) + copy_dir_list = [] + for ii in copy_dir_list_input: + copy_dir_list.extend(glob.glob(ii)) + copy_dir_list = sorted(set(copy_dir_list)) + + os.chdir(input_post) + src_path = recursive_search(copy_dir_list) + if not src_path: + return OPIO({"checked_post": input_post}) + + prop_root = Path(src_path) / path_to_prop + failures = _collect_lammps_status_failures(prop_root) + if failures: + eligible = [ + item for item in failures + if item.get("reason") == REMOTE_LAMMPS_STARTUP_FAILURE + ] + dumpfn( + { + "failed_tasks": failures, + "retry_eligible_tasks": eligible, + "retry_policy": ( + "RunLAMMPS retries remote_lammps_startup_failure before " + "this status check; remaining failures are passed to PropsPost." + ), + }, + prop_root / "run_status_check.json", + indent=4, + ) + finally: + os.chdir(cwd) + + return OPIO({"checked_post": input_post}) class PropsMake(OP): @@ -224,19 +322,7 @@ def execute(self, op_in: OPIO) -> OPIO: inter_param = prop_param["cal_setting"]["overwrite_interaction"] abs_path_to_prop = Path.cwd() / path_to_prop - lammps_failures = [] - for status_path in sorted(abs_path_to_prop.glob("task.*/apex_task_status.json")): - status = _load_task_status(status_path) - if _is_failed_task_status(status): - lammps_failures.append( - { - "task": str(status_path.parent), - "state": status.get("state"), - "reason": status.get("reason"), - "exit_code": status.get("exit_code"), - "message": status.get("message"), - } - ) + lammps_failures = _collect_lammps_status_failures(abs_path_to_prop) if lammps_failures: dumpfn( {"failed_tasks": lammps_failures}, diff --git a/apex/submit.py b/apex/submit.py index 26e1df46..09a2116f 100644 --- a/apex/submit.py +++ b/apex/submit.py @@ -54,6 +54,19 @@ def validate_submit_paths(parameter_dicts: List[dict]) -> None: ) +def _with_lammps_retry_env(run_command: str, wf_config: Config) -> str: + if not run_command: + return run_command + if "APEX_LAMMPS_HEADER_RETRY" in run_command: + return run_command + retry_env = ( + f"APEX_LAMMPS_HEADER_RETRY={int(wf_config.lammps_header_retry_attempts)} " + f"APEX_LAMMPS_HEADER_RETRY_DELAY={float(wf_config.lammps_header_retry_delay)} " + f"APEX_LAMMPS_TRANSIENT_RETRY={int(wf_config.lammps_transient_retry_attempts)}" + ) + return f"{retry_env} {run_command}" + + def _infer_type_map_from_structure_file(structure_file: str) -> dict: structure_name = os.path.basename(structure_file) symbols = [] @@ -553,6 +566,8 @@ def submit_workflow( run_command = wf_config.basic_config_dict[f"{calculator}_run_command"] if not run_command: run_command = wf_config.basic_config_dict["run_command"] + if calculator == "lammps": + run_command = _with_lammps_retry_env(run_command, wf_config) lammps_run_command = wf_config.basic_config_dict["lammps_run_command"] phonolammps_run_command = wf_config.basic_config_dict["phonolammps_run_command"] post_image = make_image diff --git a/apex/superop/SimplePropertySteps.py b/apex/superop/SimplePropertySteps.py index ab40fb75..f77b837e 100644 --- a/apex/superop/SimplePropertySteps.py +++ b/apex/superop/SimplePropertySteps.py @@ -43,6 +43,7 @@ def __init__( pool_size: Optional[int] = None, executor: Optional[DispatcherExecutor] = None, upload_python_packages: Optional[List[os.PathLike]] = None, + repair_op: Type[OP] = None, ): self._input_parameters = { "flow_id": InputParameter(type=str, value=""), @@ -99,7 +100,8 @@ def __init__( group_size, pool_size, executor, - upload_python_packages + upload_python_packages, + repair_op ) @property @@ -137,6 +139,7 @@ def _build( pool_size: Optional[int] = None, executor: Optional[DispatcherExecutor] = None, upload_python_packages: Optional[List[os.PathLike]] = None, + repair_op: Type[OP] = None, ): # Step for property make make = Step( @@ -233,6 +236,29 @@ def _build( raise RuntimeError(f'Incorrect calculator type to initiate step: {calculator}') self.add(runcal) + post_input = runcal.outputs.artifacts["backward_dir"] + if calculator == 'lammps' and repair_op is not None: + repair = Step( + name="Props-run-status-check", + template=PythonOPTemplate( + repair_op, + image=post_image, + python_packages=upload_python_packages, + command=["python3"] + ), + artifacts={ + "input_post": runcal.outputs.artifacts["backward_dir"], + "input_all": make.outputs.artifacts["output_work_path"] + }, + parameters={ + "task_names": make.outputs.parameters["task_names"], + "path_to_prop": self.inputs.parameters["path_to_prop"] + }, + key=self.step_keys["run"] + '-status-check', + ) + self.add(repair) + post_input = repair.outputs.artifacts["checked_post"] + # Step for property post post = Step( name="Props-post", @@ -243,7 +269,7 @@ def _build( command=["python3"] ), artifacts={ - "input_post": runcal.outputs.artifacts["backward_dir"], + "input_post": post_input, "input_all": make.outputs.artifacts["output_work_path"] }, parameters={ diff --git a/apex/task_failure.py b/apex/task_failure.py new file mode 100644 index 00000000..76af509e --- /dev/null +++ b/apex/task_failure.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from monty.serialization import loadfn + + +HEADER_ONLY_RETRY_REASON = "header_only_lammps_log_after_nonzero_exit" +REMOTE_LAMMPS_STARTUP_FAILURE = "remote_lammps_startup_failure" + + +def is_lammps_header_only_text(text: str) -> bool: + lines = [line.strip() for line in str(text).splitlines() if line.strip()] + return len(lines) == 1 and lines[0].startswith("LAMMPS (") + + +def is_lammps_header_only_log(path: Path | str) -> bool: + log_path = Path(path) + if not log_path.is_file(): + return False + try: + return is_lammps_header_only_text(log_path.read_text(errors="replace")) + except Exception: + return False + + +def is_header_only_lammps_failure(task_dir: Path | str, exit_code: int | None) -> bool: + if exit_code in (None, 0): + return False + task_path = Path(task_dir) + if any((task_path / name).exists() for name in ["CONTCAR", "dump.relax", "stress_timeseries.txt"]): + return False + return ( + is_lammps_header_only_log(task_path / "log.lammps") + or is_lammps_header_only_log(task_path / "outlog") + ) + + +def classify_lammps_exit_code(exit_code: int | None, *, remote_startup: bool = False) -> dict[str, Any]: + if exit_code is None: + return { + "state": "failed", + "reason": "unknown_failure", + "message": "Command exit code is unavailable.", + } + if exit_code == 0: + return { + "state": "succeeded", + "reason": "command_exit_zero", + "message": "Command completed successfully.", + } + if remote_startup: + return { + "state": "failed", + "reason": REMOTE_LAMMPS_STARTUP_FAILURE, + "message": "LAMMPS exited non-zero after writing only the startup header.", + } + if exit_code == 124: + return { + "state": "failed", + "reason": "timeout", + "message": "Command exited with timeout code 124.", + } + if exit_code == 126: + return { + "state": "failed", + "reason": "command_not_executable", + "message": "Command was found but could not be executed.", + } + if exit_code == 127: + return { + "state": "failed", + "reason": "command_not_found", + "message": "Command executable was not found.", + } + if exit_code == 137: + return { + "state": "failed", + "reason": "killed_or_oom", + "message": "Command was killed with exit code 137, commonly SIGKILL/OOM/preemption.", + } + if exit_code in (130, 143): + return { + "state": "failed", + "reason": "terminated", + "message": f"Command was terminated by signal-like exit code {exit_code}.", + } + if exit_code > 128: + return { + "state": "failed", + "reason": "killed_or_oom", + "message": f"Command exited with code {exit_code}, likely signal {exit_code - 128}.", + } + return { + "state": "failed", + "reason": "nonzero_lammps_error", + "message": f"Command exited with non-zero code {exit_code}.", + } + + +def classify_apex_task_status(status: Any, task_dir: Path | str | None = None) -> dict[str, Any]: + if not isinstance(status, dict): + return { + "state": "failed", + "reason": "invalid_task_status", + "message": "apex_task_status.json is missing or is not a JSON object.", + "exit_code": None, + } + exit_code = status.get("exit_code") + try: + exit_code = int(exit_code) if exit_code is not None else None + except (TypeError, ValueError): + exit_code = None + remote_startup = ( + status.get("reason") == REMOTE_LAMMPS_STARTUP_FAILURE + or status.get("retry_reason") == HEADER_ONLY_RETRY_REASON + ) + if task_dir is not None: + remote_startup = remote_startup or is_header_only_lammps_failure(task_dir, exit_code) + classified = classify_lammps_exit_code(exit_code, remote_startup=remote_startup) + if status.get("state") != "succeeded" and classified["state"] == "succeeded": + classified = { + "state": "failed", + "reason": "invalid_task_status", + "message": "Task status is failed but exit_code is zero.", + } + return { + **classified, + "exit_code": exit_code, + "original_reason": status.get("reason"), + "retry_reason": status.get("retry_reason"), + } + + +def load_and_classify_task_status(status_path: Path | str) -> dict[str, Any]: + path = Path(status_path) + try: + status = loadfn(path) + except Exception as exc: + return { + "state": "failed", + "reason": "invalid_task_status", + "message": f"Could not parse apex_task_status.json: {exc}", + "exit_code": None, + } + return classify_apex_task_status(status, path.parent) diff --git a/tests/test_gamma.py b/tests/test_gamma.py index f1ef1e83..e2203a41 100644 --- a/tests/test_gamma.py +++ b/tests/test_gamma.py @@ -74,6 +74,14 @@ def test_task_type(self): def test_task_param(self): self.assertEqual(self.prop_param[0], self.gamma.task_param()) + def test_refine_style_gamma_defaults_cal_type(self): + gamma = Gamma({ + "type": "gamma", + "init_from_suffix": "00", + "output_suffix": "01", + }) + self.assertEqual(gamma.cal_type, "relaxation") + def test_make_confs_bcc(self): if not os.path.exists(os.path.join(self.equi_path, "CONTCAR")): with self.assertRaises(RuntimeError): @@ -82,6 +90,10 @@ def test_make_confs_bcc(self): os.path.join(self.source_path, "CONTCAR_Mo_bcc"), os.path.join(self.equi_path, "CONTCAR"), ) + with self.assertRaisesRegex(RuntimeError, "result.json"): + self.gamma.make_confs(self.target_path, self.equi_path) + with open(os.path.join(self.equi_path, "result.json"), "w", encoding="utf-8") as fp: + fp.write('{"energies": [-1.0], "atom_numbs": [1]}') task_list = self.gamma.make_confs(self.target_path, self.equi_path) dfm_dirs = glob.glob(os.path.join(self.target_path, "task.*")) self.assertEqual(len(dfm_dirs), self.gamma.n_steps + 1) @@ -107,6 +119,33 @@ def test_make_confs_bcc(self): z_coord = float(z_coord_str) self.assertTrue(z_coord <= 1) + def test_static_gamma_defaults_without_add_fix(self): + gamma = Gamma({ + "type": "gamma", + "cal_type": "static", + "plane_miller": [0, 0, 1], + "slip_direction": [1, 0, 0], + }) + self.assertIsNone(gamma.add_fix) + + def test_incompatible_lammps_add_fix_fails_clearly(self): + gamma = Gamma({ + "type": "gamma", + "cal_type": "static", + "plane_miller": [0, 0, 1], + "slip_direction": [1, 0, 0], + "add_fix": ["true", "true", "false"], + }) + task_dir = os.path.join(self.target_path, "task.000000") + os.makedirs(task_dir, exist_ok=True) + with open(os.path.join(task_dir, "inter.json"), "w", encoding="utf-8") as fp: + fp.write('{"type": "deepmd"}') + with open(os.path.join(task_dir, "in.lammps"), "w", encoding="utf-8") as fp: + fp.write("run 0\n") + + with self.assertRaisesRegex(RuntimeError, "add_fix was requested"): + gamma.post_process([task_dir]) + def test_compute_lower(self): cwd = os.getcwd() output_file = os.path.join(cwd, "output/gamma_00/result.json") diff --git a/tests/test_gruneisen.py b/tests/test_gruneisen.py index cf92165f..0d1ad82d 100644 --- a/tests/test_gruneisen.py +++ b/tests/test_gruneisen.py @@ -288,7 +288,7 @@ def fake_check_call(command, shell): Path("FORCE_SETS").write_text("fake force sets\n") elif command.startswith(Phonon.phonopy_setup_command("--dim=")) and "--writefc" in command: Path("FORCE_CONSTANTS").write_text("fake force constants\n") - elif command.startswith("phonopy --dim="): + elif command.startswith(Phonon.phonopy_command("--dim=")): strain = loadfn("volume.json")["strain"] if strain < 0: frequencies = [4.2, 8.4] @@ -408,7 +408,7 @@ def fake_check_call(command, shell): Path("FORCE_SETS").write_text("fake force sets\n") elif command == Phonon.phonopy_setup_command("phonopy_disp.yaml --writefc"): Path("FORCE_CONSTANTS").write_text("fake force constants\n") - elif command == "phonopy band.conf": + elif command == Phonon.phonopy_command("band.conf"): strain = loadfn("volume.json")["strain"] if strain < 0: frequencies = [4.2, 8.4] @@ -459,7 +459,8 @@ def fake_run(command, stdout, stderr, text): ]), 3, ) - self.assertEqual(len([cmd for _, cmd in calls if cmd == "phonopy band.conf"]), 3) + self.assertEqual(len([cmd for _, cmd in calls if cmd == Phonon.phonopy_command("band.conf")]), 3) + self.assertFalse(any(cmd == "phonopy band.conf --abacus" for _, cmd in calls)) self.assertTrue((work_dir / "volume.000000" / "mesh.yaml").is_file()) self.assertTrue((work_dir / "volume.000001" / "band.dat").is_file()) self.assertTrue("Temperature(K) SumGammaCv Sign" in ptr) diff --git a/tests/test_interstitial.py b/tests/test_interstitial.py index 9b25a574..ab651fff 100644 --- a/tests/test_interstitial.py +++ b/tests/test_interstitial.py @@ -2,7 +2,9 @@ import os import shutil import sys +import tempfile import unittest +from unittest.mock import patch import numpy as np from pymatgen.analysis.defects.core import Interstitial as pmg_Interstitial @@ -104,3 +106,59 @@ def test_make_confs_bcc(self): center = (inter_site1.coords + inter_site2.coords) / 2 self.assertTrue((center[0] - center[1]) < 1e-4) self.assertTrue((center[1] - center[2]) < 1e-4) + + def test_special_interstitial_generation_reports_missing_reference_sites(self): + class FakeGeneratedStructure: + distance_matrix = np.array([[0.0, 1.0], [1.0, 0.0]]) + + def __init__(self, coords): + self.coords = coords + + def to(self, _fmt, filename): + with open(filename, "w", encoding="utf-8") as fp: + fp.write("fake\n1.0\n1 0 0\n0 1 0\n0 0 1\nV\n") + fp.write(f"{len(self.coords)}\nDirect\n") + for coord in self.coords: + fp.write(f"{coord[0]} {coord[1]} {coord[2]} T T T\n") + + class FakeInterstitialDefect: + coords = [] + + def __init__(self, *_args, **_kwargs): + pass + + def get_supercell_structure(self, sc_mat): + return FakeGeneratedStructure(self.coords) + + anchor = [0.12, 0.13, 0.14] + cases = [ + ("bcc", [], "anchor site"), + ("bcc", [anchor], "BCC center atom"), + ("fcc", [anchor], "FCC face/corner atoms"), + ("hcp", [anchor], "HCP center atom"), + ] + + for lattice_type, coords, message in cases: + with self.subTest(lattice_type=lattice_type, message=message): + with tempfile.TemporaryDirectory() as tmpdir: + cwd = os.getcwd() + equi_path = os.path.join(tmpdir, "relaxation", "relax_task") + target_path = os.path.join(tmpdir, "interstitial_00") + os.makedirs(equi_path) + shutil.copy( + os.path.join(self.source_path, "CONTCAR_V_bcc"), + os.path.join(equi_path, "CONTCAR"), + ) + FakeInterstitialDefect.coords = coords + prop = Interstitial({ + "type": "interstitial", + "supercell": [1, 1, 1], + "insert_ele": ["V"], + "lattice_type": lattice_type, + }) + try: + with patch("apex.core.property.Interstitial.pmgInterstitial", FakeInterstitialDefect): + with self.assertRaisesRegex(RuntimeError, message): + prop.make_confs(target_path, equi_path) + finally: + os.chdir(cwd) diff --git a/tests/test_main_workflow_errors.py b/tests/test_main_workflow_errors.py index 4c469aab..0e3d8ac5 100644 --- a/tests/test_main_workflow_errors.py +++ b/tests/test_main_workflow_errors.py @@ -312,6 +312,229 @@ def test_download_failure_artifacts_skips_existing_target_dir(self): self.assertEqual(downloaded, 0) mocked_download.assert_not_called() + def test_download_failure_artifacts_prefers_failed_backward_slice_and_writes_summary(self): + class SliceStepInfo: + def get_step(self, parent_id=None, sort_by_generation=False, key=None): + if parent_id == "post-001": + return [ + { + "id": "run-004", + "displayName": "PropsLAMMPS-Cal", + "outputs": { + "artifacts": { + "backward_dir": "backward-artifact", + } + }, + } + ] + return [] + + root_step = { + "id": "post-001", + "displayName": "Props-post", + "outputs": { + "artifacts": { + "main-logs": "logs-artifact", + "dflow_internal": "ignored-dflow-artifact", + "debug-extra": "ignored-extra-artifact", + } + }, + } + + def fake_download(artifact, path, **kwargs): + os.makedirs(path, exist_ok=True) + if artifact == "logs-artifact": + with open(os.path.join(path, "main.log"), "w", encoding="utf-8") as fp: + fp.write("LAMMPS failed for property task(s): task.000004\n") + else: + self.assertEqual(kwargs.get("slice"), 4) + with open(os.path.join(path, "apex_task_status.json"), "w", encoding="utf-8") as fp: + fp.write( + '{"state": "failed", "reason": "nonzero_exit", "exit_code": 1, ' + '"retry_reason": "header_only_lammps_log_after_nonzero_exit"}' + ) + with open(os.path.join(path, "log.lammps"), "w", encoding="utf-8") as fp: + fp.write("LAMMPS (29 Aug 2024)\n") + + with tempfile.TemporaryDirectory() as tmpdir: + with mock.patch("apex.main.download_artifact", side_effect=fake_download) as mocked_download: + downloaded = apex_main._download_failure_artifacts_for_step( + wf_info=SliceStepInfo(), + root_step=root_step, + key="propertycal-confs-std-bcc-elastic-00", + work_dir=tmpdir, + ) + + summary_path = os.path.join( + tmpdir, + ".failed-artifacts", + "propertycal-confs-std-bcc-elastic-00", + "summary.json", + ) + self.assertTrue(os.path.isfile(summary_path)) + with open(summary_path, "r", encoding="utf-8") as fp: + summary = __import__("json").load(fp) + + self.assertEqual(downloaded, 2) + self.assertEqual(mocked_download.call_count, 2) + self.assertEqual(summary["failed_task_count"], 1) + self.assertEqual(summary["classifications"]["remote_lammps_startup_failure"], 1) + + def test_extract_failed_task_ids_ignores_unrelated_files(self): + with tempfile.TemporaryDirectory() as tmpdir: + with open(os.path.join(tmpdir, "notes.txt"), "w", encoding="utf-8") as fp: + fp.write("unrelated task.000001 text\n") + with open(os.path.join(tmpdir, "main.log"), "w", encoding="utf-8") as fp: + fp.write("LAMMPS failed for task.000002\n") + + self.assertEqual(apex_main._extract_failed_task_ids(tmpdir), ["000002"]) + + def test_download_failure_artifacts_continues_after_main_log_download_error(self): + root_step = { + "id": "post-001", + "displayName": "Props-post", + "outputs": { + "artifacts": { + "main-logs": "logs-artifact", + } + }, + } + + with tempfile.TemporaryDirectory() as tmpdir: + with mock.patch( + "apex.main._download_artifact_with_retry", + side_effect=RuntimeError("storage temporarily unavailable"), + ): + downloaded = apex_main._download_failure_artifacts_for_step( + wf_info=FakeStepInfo(), + root_step=root_step, + key="propertycal-confs-std-bcc-elastic-00", + work_dir=tmpdir, + ) + + self.assertEqual(downloaded, 0) + + def test_download_failure_artifacts_skips_duplicate_main_log_artifact(self): + class DuplicateStepInfo: + def get_step(self, parent_id=None, sort_by_generation=False, key=None): + if parent_id == "post-001": + return [ + { + "id": "post-001", + "displayName": "Props-post-duplicate", + "outputs": { + "artifacts": { + "main-logs": "duplicate-logs-artifact", + } + }, + } + ] + return [] # pragma: no cover + + root_step = { + "id": "post-001", + "displayName": "Props-post", + "outputs": { + "artifacts": { + "main-logs": "logs-artifact", + } + }, + } + + def fake_download(artifact, path, **_kwargs): + os.makedirs(path, exist_ok=True) + with open(os.path.join(path, "main.log"), "w", encoding="utf-8") as fp: + fp.write(f"{artifact}\n") + + with tempfile.TemporaryDirectory() as tmpdir: + with mock.patch("apex.main.download_artifact", side_effect=fake_download) as mocked_download: + downloaded = apex_main._download_failure_artifacts_for_step( + wf_info=DuplicateStepInfo(), + root_step=root_step, + key="propertycal-confs-std-bcc-elastic-00", + work_dir=tmpdir, + ) + + self.assertEqual(downloaded, 1) + mocked_download.assert_called_once() + + def test_failure_artifact_helpers_handle_invalid_status_and_fallback_downloads(self): + class RelatedStepInfo: + def get_step(self, parent_id=None, sort_by_generation=False, key=None): + if parent_id == "post-001": + return [ + { + "id": "run-001", + "displayName": "PropsLAMMPS-Cal", + "outputs": { + "artifacts": { + "backward_dir": "backward-artifact", + "output_work_path": "extra-artifact", + } + }, + }, + { + "id": "run-001", + "displayName": "PropsLAMMPS-Cal", + "outputs": { + "artifacts": { + "backward_dir": "duplicate-artifact", + } + }, + }, + ] + return [] + + root_step = { + "id": "post-001", + "displayName": "Props-post", + "outputs": { + "artifacts": { + "main-logs": "logs-artifact", + } + }, + } + calls = [] + + def fake_download(artifact, path, **kwargs): + calls.append((artifact, kwargs)) + os.makedirs(path, exist_ok=True) + if artifact == "logs-artifact": + with open(os.path.join(path, "main.log"), "w", encoding="utf-8") as fp: + fp.write("no task id here\n") + elif artifact == "backward-artifact": + self.assertNotIn("slice", kwargs) + with open(os.path.join(path, "apex_task_status.json"), "w", encoding="utf-8") as fp: + fp.write("{not-json") + else: + with open(os.path.join(path, ".debug.log"), "w", encoding="utf-8") as fp: + fp.write("diagnostic\n") + + with tempfile.TemporaryDirectory() as tmpdir: + missing_path = os.path.join(tmpdir, "does-not-exist.txt") + self.assertEqual(apex_main._safe_read_text(missing_path), "") + with mock.patch("apex.main.download_artifact", side_effect=fake_download): + downloaded = apex_main._download_failure_artifacts_for_step( + wf_info=RelatedStepInfo(), + root_step=root_step, + key="propertycal-confs-std-bcc-elastic-00", + work_dir=tmpdir, + ) + + summary_path = os.path.join( + tmpdir, + ".failed-artifacts", + "propertycal-confs-std-bcc-elastic-00", + "summary.json", + ) + with open(summary_path, "r", encoding="utf-8") as fp: + summary = __import__("json").load(fp) + + self.assertEqual(downloaded, 3) + self.assertEqual([call[0] for call in calls], ["logs-artifact", "backward-artifact", "extra-artifact"]) + self.assertEqual(summary["failed_task_count"], 1) + self.assertEqual(summary["classifications"]["invalid_task_status"], 1) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_ops.py b/tests/test_ops.py index e0463e4a..d1630f97 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -4,6 +4,8 @@ import glob import shutil import tempfile +from types import SimpleNamespace +from unittest.mock import patch from pathlib import Path from dflow.python import ( OP, @@ -15,8 +17,17 @@ from monty.serialization import loadfn from apex.op.relaxation_ops import RelaxMake, _check_relaxation_outputs -from apex.op.property_ops import PropsMake, _is_failed_task_status +from apex.op.property_ops import PropsMake, PropsPost, PropsRepairStatusCheck, _is_failed_task_status from apex.op.RunLAMMPS import RunLAMMPS +from apex.superop.SimplePropertySteps import SimplePropertySteps +from apex.task_failure import ( + REMOTE_LAMMPS_STARTUP_FAILURE, + classify_apex_task_status, + classify_lammps_exit_code, + is_header_only_lammps_failure, + is_lammps_header_only_log, + load_and_classify_task_status, +) from apex.utils import apex_task_succeeded, all_apex_task_status_succeeded try: from context import write_poscar @@ -28,6 +39,50 @@ class TestTaskStatusHelpers(unittest.TestCase): + def test_task_failure_helpers_cover_error_branches(self): + with tempfile.TemporaryDirectory() as tmpdir: + task_dir = Path(tmpdir) + (task_dir / "log.lammps").write_text("LAMMPS (29 Aug 2024)\n") + self.assertTrue(is_lammps_header_only_log(task_dir / "log.lammps")) + self.assertTrue(is_header_only_lammps_failure(task_dir, 1)) + + (task_dir / "CONTCAR").write_text("finished\n") + self.assertFalse(is_header_only_lammps_failure(task_dir, 1)) + + broken_status = task_dir / "broken_status.json" + broken_status.write_text("{not-json") + self.assertEqual( + load_and_classify_task_status(broken_status)["reason"], + "invalid_task_status", + ) + + valid_status = task_dir / "apex_task_status.json" + valid_status.write_text('{"state": "failed", "exit_code": "bad"}') + self.assertEqual( + load_and_classify_task_status(valid_status)["reason"], + "unknown_failure", + ) + + self.assertEqual(classify_lammps_exit_code(None)["reason"], "unknown_failure") + self.assertEqual(classify_lammps_exit_code(126)["reason"], "command_not_executable") + self.assertEqual(classify_lammps_exit_code(129)["reason"], "killed_or_oom") + self.assertEqual(classify_apex_task_status(None)["reason"], "invalid_task_status") + self.assertEqual( + classify_apex_task_status({"state": "failed", "exit_code": "not-int"})["reason"], + "unknown_failure", + ) + self.assertEqual( + classify_apex_task_status({"state": "failed", "exit_code": 0})["reason"], + "invalid_task_status", + ) + + def test_lammps_header_only_log_treats_read_error_as_not_header_only(self): + with tempfile.TemporaryDirectory() as tmpdir: + log_path = Path(tmpdir) / "log.lammps" + log_path.write_text("LAMMPS (29 Aug 2024)\n") + with patch("apex.task_failure.Path.read_text", side_effect=OSError("boom")): + self.assertFalse(is_lammps_header_only_log(log_path)) + def test_failed_status_uses_apex_task_status_fields(self): self.assertFalse(_is_failed_task_status({ "state": "succeeded", @@ -60,6 +115,183 @@ def test_rerun_finished_helpers_match_status_state(self): (task1 / "apex_task_status.json").write_text('{"state": "succeeded", "exit_code": 7}') self.assertTrue(all_apex_task_status_succeeded(work_dir)) + def test_props_repair_status_check_summarizes_remote_startup_failures(self): + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + input_all = root / "all" + input_post = root / "post" + prop_dir = input_post / "confs" / "std-bcc" / "elastic_00" + task_dir = prop_dir / "task.000003" + (input_all / "confs").mkdir(parents=True) + task_dir.mkdir(parents=True) + (task_dir / "apex_task_status.json").write_text( + '{"state": "failed", "reason": "nonzero_exit", "exit_code": 1, ' + '"retry_reason": "header_only_lammps_log_after_nonzero_exit"}' + ) + (task_dir / "log.lammps").write_text("LAMMPS (29 Aug 2024)\n") + + op = PropsRepairStatusCheck() + out = op.execute(OPIO({ + "input_post": input_post, + "input_all": input_all, + "task_names": ["confs/std-bcc/elastic_00/task.000003"], + "path_to_prop": "confs/std-bcc/elastic_00", + })) + + self.assertEqual(out["checked_post"], input_post) + summary = loadfn(prop_dir / "run_status_check.json") + self.assertEqual( + summary["retry_eligible_tasks"][0]["reason"], + REMOTE_LAMMPS_STARTUP_FAILURE, + ) + + def test_props_repair_status_check_short_circuits_empty_or_missing_inputs(self): + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + input_all = root / "all" + input_post = root / "post" + (input_all / "confs").mkdir(parents=True) + input_post.mkdir() + + op = PropsRepairStatusCheck() + empty_out = op.execute(OPIO({ + "input_post": input_post, + "input_all": input_all, + "task_names": [], + "path_to_prop": "confs/std-bcc/eos_00", + })) + self.assertEqual(empty_out["checked_post"], input_post) + + missing_src_out = op.execute(OPIO({ + "input_post": input_post, + "input_all": input_all, + "task_names": ["confs/std-bcc/eos_00/task.000000"], + "path_to_prop": "confs/std-bcc/eos_00", + })) + self.assertEqual(missing_src_out["checked_post"], input_post) + + def test_props_post_reports_lammps_status_failures_before_compute(self): + with tempfile.TemporaryDirectory() as tmpdir: + cwd = os.getcwd() + root = Path(tmpdir) + input_all = root / "all" + input_post = root / "post" + prop_dir = input_post / "confs" / "std-bcc" / "eos_00" + task_dir = prop_dir / "task.000000" + (input_all / "confs").mkdir(parents=True) + task_dir.mkdir(parents=True) + (task_dir / "apex_task_status.json").write_text( + '{"state": "failed", "reason": "nonzero_exit", "exit_code": 7}' + ) + + try: + with self.assertRaisesRegex(RuntimeError, "LAMMPS failed for property task"): + PropsPost().execute(OPIO({ + "input_post": input_post, + "input_all": input_all, + "prop_param": {"type": "eos"}, + "inter_param": {"type": "deepmd", "model": "model.pb"}, + "task_names": ["confs/std-bcc/eos_00/task.000000"], + "path_to_prop": "confs/std-bcc/eos_00", + })) + finally: + os.chdir(cwd) + + +class TestSimplePropertySteps(unittest.TestCase): + def test_lammps_repair_step_feeds_checked_post_to_post_step(self): + import apex.superop.SimplePropertySteps as simple_steps + + added_steps = [] + + class FakeTemplate: + def __init__(self, op, **kwargs): + self.op = op + self.kwargs = kwargs + + class FakeStep: + def __init__(self, name, template=None, artifacts=None, parameters=None, + with_param=None, key=None, executor=None): + self.name = name + self.template = template + self.artifacts = artifacts or {} + self.parameters = parameters or {} + self.with_param = with_param + self.key = key + self.executor = executor + self.outputs = SimpleNamespace( + artifacts={ + "task_paths": f"{name}-task_paths", + "output_work_path": f"{name}-output_work_path", + "backward_dir": f"{name}-backward_dir", + "checked_post": f"{name}-checked_post", + "retrieve_path": f"{name}-retrieve_path", + }, + parameters={ + "task_names": f"{name}-task_names", + "njobs": f"{name}-njobs", + }, + ) + + def fake_add(self, step): + added_steps.append(step) + + obj = SimplePropertySteps.__new__(SimplePropertySteps) + object.__setattr__(obj, "inputs", SimpleNamespace( + parameters={ + "prop_param": "prop-param", + "inter_param": "inter-param", + "do_refine": False, + "path_to_prop": "confs/std-bcc/eos_00", + }, + artifacts={"input_work_path": "input-work"}, + )) + object.__setattr__(obj, "outputs", SimpleNamespace( + artifacts={"retrieve_path": SimpleNamespace(_from=None)} + )) + object.__setattr__(obj, "step_keys", { + "make": "props-make", + "run": "props-run", + "post": "props-post", + }) + + with patch.object(simple_steps, "Step", FakeStep), \ + patch.object(simple_steps, "PythonOPTemplate", FakeTemplate), \ + patch.object(simple_steps, "Slices", lambda *args, **kwargs: ("slices", args, kwargs)), \ + patch.object(simple_steps, "argo_range", lambda value: f"range:{value}"), \ + patch.object(simple_steps, "argo_len", lambda value: f"len:{value}"), \ + patch.object(SimplePropertySteps, "add", fake_add): + obj._build( + "step", + make_op=object(), + run_op=object(), + post_op=object(), + make_image="make-image", + run_image="run-image", + post_image="post-image", + run_command="lmp -in in.lammps", + calculator="lammps", + upload_python_packages=[], + group_size=1, + pool_size=1, + executor=None, + repair_op=object(), + ) + + self.assertEqual( + [step.name for step in added_steps], + ["Props-make", "PropsLAMMPS-Cal", "Props-run-status-check", "Props-post"], + ) + post_step = added_steps[-1] + self.assertEqual( + post_step.artifacts["input_post"], + "Props-run-status-check-checked_post", + ) + self.assertEqual( + obj.outputs.artifacts["retrieve_path"]._from, + "Props-post-retrieve_path", + ) + class TestRunLAMMPSDebug(unittest.TestCase): def test_run_lammps_writes_debug_log_on_success(self): @@ -99,14 +331,40 @@ def test_run_lammps_writes_failed_status_with_debug_log(self): self.assertTrue((task_dir / ".debug.log").is_file()) status = loadfn(task_dir / "apex_task_status.json") self.assertEqual(status["state"], "failed") - self.assertEqual(status["reason"], "nonzero_exit") + self.assertEqual(status["reason"], "nonzero_lammps_error") self.assertEqual(status["exit_code"], 7) self.assertEqual(status["debug_log"], ".debug.log") def test_run_lammps_classifies_common_exit_codes(self): self.assertEqual(RunLAMMPS._classify_exit_code(127)["reason"], "command_not_found") + self.assertEqual(RunLAMMPS._classify_exit_code(124)["reason"], "timeout") self.assertEqual(RunLAMMPS._classify_exit_code(137)["reason"], "killed_or_oom") self.assertEqual(RunLAMMPS._classify_exit_code(143)["reason"], "terminated") + self.assertEqual( + RunLAMMPS._runtime_int_option( + "APEX_LAMMPS_HEADER_RETRY=3 lmp -in in.lammps", + "APEX_LAMMPS_HEADER_RETRY", + 2, + ), + 3, + ) + self.assertEqual( + RunLAMMPS._runtime_int_option( + "APEX_LAMMPS_HEADER_RETRY=bad lmp -in in.lammps", + "APEX_LAMMPS_HEADER_RETRY", + 2, + ), + 2, + ) + self.assertEqual( + RunLAMMPS._runtime_float_option( + "APEX_LAMMPS_HEADER_RETRY_DELAY=bad lmp -in in.lammps", + "APEX_LAMMPS_HEADER_RETRY_DELAY", + 5.0, + ), + 5.0, + ) + self.assertFalse(RunLAMMPS._is_lammps_header_only_log(Path("missing-log"))) def test_run_lammps_retries_header_only_failure(self): with tempfile.TemporaryDirectory() as tmpdir: @@ -124,10 +382,11 @@ def test_run_lammps_retries_header_only_failure(self): "Path('stress_timeseries.txt').write_text('0 0 0 0 0 0 0\\n')\n" ) op = RunLAMMPS() - op.execute(OPIO({ - "input_lammps": task_dir, - "run_command": f"{sys.executable} {script.name}", - })) + with patch.dict(os.environ, {"APEX_LAMMPS_HEADER_RETRY_DELAY": "0"}): + op.execute(OPIO({ + "input_lammps": task_dir, + "run_command": f"{sys.executable} {script.name}", + })) self.assertEqual((task_dir / "count.txt").read_text(), "2") self.assertTrue((task_dir / "log.lammps.attempt1").is_file()) @@ -135,6 +394,29 @@ def test_run_lammps_retries_header_only_failure(self): self.assertEqual(status["state"], "succeeded") self.assertEqual(status["attempts"], 2) self.assertEqual(status["retry_reason"], "header_only_lammps_log_after_nonzero_exit") + self.assertEqual(status["retry_classification"], REMOTE_LAMMPS_STARTUP_FAILURE) + + def test_run_lammps_classifies_persistent_header_only_failure(self): + with tempfile.TemporaryDirectory() as tmpdir: + task_dir = Path(tmpdir) + script = task_dir / "always_header_only.py" + script.write_text( + "from pathlib import Path\n" + "Path('log.lammps').write_text('LAMMPS (29 Aug 2024)\\n')\n" + "Path('outlog').write_text('LAMMPS (29 Aug 2024)\\n')\n" + "raise SystemExit(1)\n" + ) + op = RunLAMMPS() + with patch.dict(os.environ, {"APEX_LAMMPS_HEADER_RETRY_DELAY": "0"}): + op.execute(OPIO({ + "input_lammps": task_dir, + "run_command": f"{sys.executable} {script.name}", + })) + + status = loadfn(task_dir / "apex_task_status.json") + self.assertEqual(status["state"], "failed") + self.assertEqual(status["reason"], REMOTE_LAMMPS_STARTUP_FAILURE) + self.assertEqual(status["attempts"], 2) class TestMakeRelaxOPs(unittest.TestCase): diff --git a/tests/test_phonon.py b/tests/test_phonon.py index 82e8bc8f..b817df5e 100644 --- a/tests/test_phonon.py +++ b/tests/test_phonon.py @@ -1,6 +1,7 @@ import glob import os import shutil +import subprocess import sys import unittest from pathlib import Path @@ -11,6 +12,7 @@ from monty.serialization import loadfn from pymatgen.io.vasp import Incar from apex.core.property.Phonon import Phonon +from apex.core.calculator.calculator import LAMMPS_INTER_TYPE sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) __package__ = "tests" @@ -77,6 +79,10 @@ def test_phonopy_setup_command_prefers_v4_setup_tool(self): Phonon.phonopy_setup_command("-d --dim='2 2 2' -c POSCAR"), "phonopy-init -d --dim='2 2 2' -c POSCAR", ) + self.assertEqual( + Phonon.phonopy_command("band.conf"), + "phonopy band.conf", + ) with patch("apex.core.property.Phonon.shutil.which", return_value=None): self.assertEqual( @@ -84,6 +90,398 @@ def test_phonopy_setup_command_prefers_v4_setup_tool(self): "phonopy -d --dim='2 2 2' -c POSCAR", ) + def test_phonopy_writefc_commands_collapse_for_phonopy_2(self): + with patch("apex.core.property.Phonon.shutil.which", return_value=None): + self.assertEqual( + Phonon.phonopy_writefc_commands("phonopy_disp.yaml --writefc"), + ["phonopy phonopy_disp.yaml --writefc"], + ) + + def test_writefc_command_falls_back_to_phonopy(self): + work_dir = Path("output/phonopy_writefc_fallback") + shutil.rmtree(work_dir, ignore_errors=True) + work_dir.mkdir(parents=True) + calls = [] + cwd = os.getcwd() + + def fake_check_call(command, shell): + self.assertTrue(shell) + calls.append(command) + if command.startswith("phonopy-init"): + raise subprocess.CalledProcessError(2, command) + Path("FORCE_CONSTANTS").write_text("fake force constants\n") + + try: + os.chdir(work_dir) + with patch("apex.core.property.Phonon.shutil.which", return_value="/usr/bin/phonopy-init"), \ + patch("apex.core.property.Phonon.subprocess.check_call", side_effect=fake_check_call): + Phonon.run_first_success( + Phonon.phonopy_writefc_commands("phonopy_disp.yaml --writefc"), + required_file="FORCE_CONSTANTS", + ) + self.assertEqual( + calls, + [ + "phonopy-init phonopy_disp.yaml --writefc", + "phonopy phonopy_disp.yaml --writefc", + ], + ) + self.assertTrue(Path("FORCE_CONSTANTS").is_file()) + finally: + os.chdir(cwd) + shutil.rmtree(work_dir, ignore_errors=True) + + def test_run_first_success_raises_last_error_when_all_commands_fail(self): + with patch( + "apex.core.property.Phonon.subprocess.check_call", + side_effect=[ + subprocess.CalledProcessError(2, "phonopy-init --writefc"), + subprocess.CalledProcessError(2, "phonopy --writefc"), + ], + ): + with self.assertRaises(subprocess.CalledProcessError) as context: + Phonon.run_first_success(["phonopy-init --writefc", "phonopy --writefc"]) + self.assertEqual(context.exception.cmd, "phonopy --writefc") + + def test_run_first_success_requires_output_file_before_accepting_command(self): + calls = [] + + def fake_check_call(command, shell): + self.assertTrue(shell) + calls.append(command) + if command == "second": + Path("FORCE_CONSTANTS").write_text("created\n") + + work_dir = Path("output/phonopy_required_file") + shutil.rmtree(work_dir, ignore_errors=True) + work_dir.mkdir(parents=True) + cwd = os.getcwd() + try: + os.chdir(work_dir) + with patch("apex.core.property.Phonon.subprocess.check_call", side_effect=fake_check_call): + Phonon.run_first_success(["first", "second"], required_file="FORCE_CONSTANTS") + self.assertEqual(calls, ["first", "second"]) + finally: + os.chdir(cwd) + shutil.rmtree(work_dir, ignore_errors=True) + + def test_write_band_dat_accepts_nonzero_exit_with_output(self): + work_dir = Path("output/phonopy_bandplot_nonzero") + shutil.rmtree(work_dir, ignore_errors=True) + work_dir.mkdir(parents=True) + cwd = os.getcwd() + + def fake_run(command, stdout, stderr, text): + self.assertEqual(command, ["phonopy-bandplot", "--gnuplot", "band.yaml"]) + stdout.write("# distance frequency\n") + return subprocess.CompletedProcess(command, 1, stderr="warning") + + try: + os.chdir(work_dir) + Path("band.yaml").write_text("phonon: []\n") + with patch("apex.core.property.Phonon.subprocess.run", side_effect=fake_run): + Phonon.write_band_dat() + self.assertGreater(Path("band.dat").stat().st_size, 0) + finally: + os.chdir(cwd) + shutil.rmtree(work_dir, ignore_errors=True) + + def test_write_band_dat_accepts_zero_exit_and_raises_on_empty_output(self): + work_dir = Path("output/phonopy_bandplot_branches") + shutil.rmtree(work_dir, ignore_errors=True) + work_dir.mkdir(parents=True) + cwd = os.getcwd() + try: + os.chdir(work_dir) + Path("band.yaml").write_text("phonon: []\n") + + def successful_run(command, stdout, stderr, text): + stdout.write("# distance frequency\n") + return subprocess.CompletedProcess(command, 0, stderr="") + + with patch("apex.core.property.Phonon.subprocess.run", side_effect=successful_run): + Phonon.write_band_dat() + self.assertGreater(Path("band.dat").stat().st_size, 0) + + def empty_failed_run(command, stdout, stderr, text): + return subprocess.CompletedProcess(command, 1, stderr="empty") + + with patch("apex.core.property.Phonon.subprocess.run", side_effect=empty_failed_run): + with self.assertRaises(subprocess.CalledProcessError): + Phonon.write_band_dat() + finally: + os.chdir(cwd) + shutil.rmtree(work_dir, ignore_errors=True) + + def test_write_band_dat_requires_band_yaml(self): + work_dir = Path("output/phonopy_bandplot_missing_yaml") + shutil.rmtree(work_dir, ignore_errors=True) + work_dir.mkdir(parents=True) + cwd = os.getcwd() + try: + os.chdir(work_dir) + with self.assertRaises(FileNotFoundError): + Phonon.write_band_dat() + finally: + os.chdir(cwd) + shutil.rmtree(work_dir, ignore_errors=True) + + def _write_phonon_compute_common(self, work_dir): + work_dir.mkdir(parents=True, exist_ok=True) + (work_dir / "band_path.json").write_text("[]\n") + (work_dir / "band.conf").write_text("BAND = 0 0 0 0.5 0 0\n") + (work_dir / "phonopy_disp.yaml").write_text("displacements: []\n") + + def _write_band_dat_for_compute(self): + Path("band.dat").write_text("# phonopy bandplot\n# G X\n\n") + + def test_compute_lower_resolves_task_paths_and_restores_cwd_on_reproduce_error(self): + work_dir = Path("output/phonon_reproduce_missing_init") + shutil.rmtree(work_dir, ignore_errors=True) + self._write_phonon_compute_common(work_dir) + existing_rel = Path("output/phonon_existing_relative_task") + existing_rel.mkdir(parents=True, exist_ok=True) + abs_task = (work_dir / "abs_task").absolute() + abs_task.mkdir(parents=True) + cwd = Path.cwd() + + try: + phonon = Phonon( + {"type": "phonon", "reproduce": True, "init_from_suffix": "old", "output_suffix": "new"}, + inter_param={"type": "vasp"}, + ) + with self.assertRaisesRegex(RuntimeError, "initial data path"): + phonon._compute_lower( + str(work_dir / "result.json"), + [str(abs_task), str(existing_rel), "bare_missing_task", "nested/missing_task"], + [], + ) + self.assertEqual(Path.cwd(), cwd) + finally: + shutil.rmtree(work_dir, ignore_errors=True) + shutil.rmtree(existing_rel, ignore_errors=True) + + def test_compute_lower_reproduce_success_and_malformed_band_errors(self): + work_dir = Path("output/phonon_reproduce_success") + shutil.rmtree(work_dir, ignore_errors=True) + self._write_phonon_compute_common(work_dir) + init_dir = work_dir / "init" + init_dir.mkdir() + + def fake_post_repro(init_data_path, init_from_suffix, all_tasks, ptr_data, reprod_last_frame): + self.assertEqual(init_data_path, str(init_dir.absolute())) + self.assertEqual(init_from_suffix, "old") + self.assertTrue(reprod_last_frame) + (work_dir / "band.dat").write_text("# phonopy bandplot\n# G X\n\n") + return {"reproduced": True}, ptr_data + + try: + phonon = Phonon( + { + "type": "phonon", + "reproduce": True, + "init_from_suffix": "old", + "output_suffix": "new", + "init_data_path": str(init_dir), + }, + inter_param={"type": "vasp"}, + ) + with patch("apex.core.property.Phonon.post_repro", side_effect=fake_post_repro): + result, ptr = phonon._compute_lower(str(work_dir / "result.json"), [], []) + self.assertTrue(result["reproduced"]) + self.assertIn("G", result["segment"]) + + (work_dir / "band.dat").write_text("only one line") + with patch("apex.core.property.Phonon.post_repro", return_value=({}, "")): + with self.assertRaisesRegex(ValueError, "empty or malformed"): + phonon._compute_lower(str(work_dir / "result2.json"), [], []) + finally: + shutil.rmtree(work_dir, ignore_errors=True) + + def test_compute_lower_reproduce_requires_band_dat_from_post_repro(self): + work_dir = Path("output/phonon_reproduce_missing_band") + shutil.rmtree(work_dir, ignore_errors=True) + self._write_phonon_compute_common(work_dir) + init_dir = work_dir / "init" + init_dir.mkdir() + + try: + phonon = Phonon( + { + "type": "phonon", + "reproduce": True, + "init_from_suffix": "old", + "output_suffix": "new", + "init_data_path": str(init_dir), + }, + inter_param={"type": "vasp"}, + ) + with patch("apex.core.property.Phonon.post_repro", return_value=({}, "")): + with self.assertRaisesRegex(FileNotFoundError, "band.dat was not created"): + phonon._compute_lower(str(work_dir / "result.json"), [], []) + finally: + shutil.rmtree(work_dir, ignore_errors=True) + + def test_compute_lower_abacus_uses_phonopy_init_for_forces_and_phonopy_for_band(self): + work_dir = Path("output/phonon_abacus_compute") + shutil.rmtree(work_dir, ignore_errors=True) + self._write_phonon_compute_common(work_dir) + task_dir = work_dir / "task.000000" + (task_dir / "OUT.ABACUS").mkdir(parents=True) + (task_dir / "band.conf").write_text((work_dir / "band.conf").read_text()) + (task_dir / "STRU.ori").write_text("STRU\n") + (work_dir / "STRU").write_text("STRU\n") + (task_dir / "phonopy_disp.yaml").write_text((work_dir / "phonopy_disp.yaml").read_text()) + (task_dir / "OUT.ABACUS" / "running_scf.log").write_text("force log\n") + calls = [] + + def fake_check_call(command, shell): + self.assertTrue(shell) + calls.append(command) + if command.startswith(Phonon.phonopy_setup_command("-f")): + Path("FORCE_SETS").write_text("fake force sets\n") + elif command == Phonon.phonopy_command("band.conf"): + Path("band.yaml").write_text("phonon: []\n") + + try: + phonon = Phonon({"type": "phonon"}, inter_param={"type": "abacus"}) + with patch("apex.core.property.Phonon.subprocess.check_call", side_effect=fake_check_call), \ + patch.object(Phonon, "write_band_dat", side_effect=self._write_band_dat_for_compute): + phonon._compute_lower(str(work_dir / "result.json"), [str(task_dir)], []) + self.assertEqual(calls[0], Phonon.phonopy_setup_command("-f task.0*/OUT.ABACUS/running_scf.log")) + self.assertEqual(calls[1], Phonon.phonopy_command("band.conf")) + self.assertFalse(any("--abacus" in command and command.startswith("phonopy band.conf") for command in calls)) + finally: + shutil.rmtree(work_dir, ignore_errors=True) + + def test_compute_lower_abacus_requires_force_sets_after_setup(self): + work_dir = Path("output/phonon_abacus_missing_force_sets") + shutil.rmtree(work_dir, ignore_errors=True) + self._write_phonon_compute_common(work_dir) + task_dir = work_dir / "task.000000" + (task_dir / "OUT.ABACUS").mkdir(parents=True) + (task_dir / "band.conf").write_text((work_dir / "band.conf").read_text()) + (task_dir / "STRU.ori").write_text("STRU\n") + (work_dir / "STRU").write_text("STRU\n") + (task_dir / "phonopy_disp.yaml").write_text((work_dir / "phonopy_disp.yaml").read_text()) + (task_dir / "OUT.ABACUS" / "running_scf.log").write_text("force log\n") + + try: + phonon = Phonon({"type": "phonon"}, inter_param={"type": "abacus"}) + with patch("apex.core.property.Phonon.subprocess.check_call", return_value=0): + with self.assertRaisesRegex(FileNotFoundError, "FORCE_SETS was not created"): + phonon._compute_lower(str(work_dir / "result.json"), [str(task_dir)], []) + finally: + shutil.rmtree(work_dir, ignore_errors=True) + + def test_compute_lower_vasp_linear_uses_setup_for_fc_and_phonopy_for_band(self): + work_dir = Path("output/phonon_vasp_linear_compute") + shutil.rmtree(work_dir, ignore_errors=True) + self._write_phonon_compute_common(work_dir) + task_dir = work_dir / "task.000000" + task_dir.mkdir(parents=True) + (task_dir / "band.conf").write_text((work_dir / "band.conf").read_text()) + (task_dir / "POSCAR-unitcell").write_text("POSCAR\n") + (task_dir / "vasprun.xml").write_text("\n") + calls = [] + + def fake_check_call(command, shell): + self.assertTrue(shell) + calls.append(command) + if command == Phonon.phonopy_setup_command("--fc vasprun.xml"): + Path("FORCE_CONSTANTS").write_text("fake force constants\n") + elif command == Phonon.phonopy_command('--dim="2 2 2" -c POSCAR-unitcell band.conf'): + Path("band.yaml").write_text("phonon: []\n") + + try: + phonon = Phonon({"type": "phonon", "supercell_size": [2, 2, 2]}, inter_param={"type": "vasp"}) + with patch("apex.core.property.Phonon.subprocess.check_call", side_effect=fake_check_call), \ + patch.object(Phonon, "write_band_dat", side_effect=self._write_band_dat_for_compute): + phonon._compute_lower(str(work_dir / "result.json"), [str(task_dir)], []) + self.assertEqual(calls[0], Phonon.phonopy_setup_command("--fc vasprun.xml")) + self.assertEqual(calls[1], Phonon.phonopy_command('--dim="2 2 2" -c POSCAR-unitcell band.conf')) + finally: + shutil.rmtree(work_dir, ignore_errors=True) + + def test_compute_lower_vasp_displacement_uses_setup_for_force_sets_and_phonopy_for_band(self): + work_dir = Path("output/phonon_vasp_displacement_compute") + shutil.rmtree(work_dir, ignore_errors=True) + self._write_phonon_compute_common(work_dir) + task_dir = work_dir / "task.000000" + task_dir.mkdir(parents=True) + (task_dir / "band.conf").write_text((work_dir / "band.conf").read_text()) + (task_dir / "phonopy_disp.yaml").write_text((work_dir / "phonopy_disp.yaml").read_text()) + (work_dir / "POSCAR-unitcell").write_text("POSCAR\n") + (task_dir / "vasprun.xml").write_text("\n") + calls = [] + + def fake_check_call(command, shell): + self.assertTrue(shell) + calls.append(command) + if command == Phonon.phonopy_setup_command("-f task.0*/vasprun.xml"): + Path("FORCE_SETS").write_text("fake force sets\n") + elif command == Phonon.phonopy_command('--dim="2 2 2" -c POSCAR-unitcell band.conf'): + Path("band.yaml").write_text("phonon: []\n") + + try: + phonon = Phonon( + {"type": "phonon", "supercell_size": [2, 2, 2], "approach": "displacement"}, + inter_param={"type": "vasp"}, + ) + with patch("apex.core.property.Phonon.subprocess.check_call", side_effect=fake_check_call), \ + patch.object(Phonon, "write_band_dat", side_effect=self._write_band_dat_for_compute): + phonon._compute_lower(str(work_dir / "result.json"), [str(task_dir)], []) + self.assertEqual(calls[0], Phonon.phonopy_setup_command("-f task.0*/vasprun.xml")) + self.assertEqual(calls[1], Phonon.phonopy_command('--dim="2 2 2" -c POSCAR-unitcell band.conf')) + finally: + shutil.rmtree(work_dir, ignore_errors=True) + + def test_compute_lower_vasp_displacement_requires_force_sets_after_setup(self): + work_dir = Path("output/phonon_vasp_displacement_missing_force_sets") + shutil.rmtree(work_dir, ignore_errors=True) + self._write_phonon_compute_common(work_dir) + task_dir = work_dir / "task.000000" + task_dir.mkdir(parents=True) + (task_dir / "band.conf").write_text((work_dir / "band.conf").read_text()) + (task_dir / "phonopy_disp.yaml").write_text((work_dir / "phonopy_disp.yaml").read_text()) + (work_dir / "POSCAR-unitcell").write_text("POSCAR\n") + (task_dir / "vasprun.xml").write_text("\n") + + try: + phonon = Phonon( + {"type": "phonon", "supercell_size": [2, 2, 2], "approach": "displacement"}, + inter_param={"type": "vasp"}, + ) + with patch("apex.core.property.Phonon.subprocess.check_call", return_value=0): + with self.assertRaisesRegex(FileNotFoundError, "FORCE_SETS was not created"): + phonon._compute_lower(str(work_dir / "result.json"), [str(task_dir)], []) + finally: + shutil.rmtree(work_dir, ignore_errors=True) + + def test_compute_lower_lammps_uses_phonopy_for_band(self): + work_dir = Path("output/phonon_lammps_compute") + shutil.rmtree(work_dir, ignore_errors=True) + self._write_phonon_compute_common(work_dir) + task_dir = work_dir / "task.000000" + task_dir.mkdir(parents=True) + (task_dir / "FORCE_CONSTANTS").write_text("fake force constants\n") + calls = [] + + def fake_check_call(command, shell): + self.assertTrue(shell) + calls.append(command) + if command == Phonon.phonopy_command('--dim="2 2 2" -c POSCAR band.conf'): + Path("band.yaml").write_text("phonon: []\n") + + try: + phonon = Phonon({"type": "phonon", "supercell_size": [2, 2, 2]}, inter_param={"type": LAMMPS_INTER_TYPE[0]}) + with patch("apex.core.property.Phonon.subprocess.check_call", side_effect=fake_check_call), \ + patch.object(Phonon, "write_band_dat", side_effect=self._write_band_dat_for_compute): + phonon._compute_lower(str(work_dir / "result.json"), [str(task_dir)], []) + self.assertEqual(calls, [Phonon.phonopy_command('--dim="2 2 2" -c POSCAR band.conf')]) + finally: + shutil.rmtree(work_dir, ignore_errors=True) + def test_make_phonon_conf(self): if not os.path.exists(os.path.join(self.equi_path, "CONTCAR")): with self.assertRaises(RuntimeError): diff --git a/tests/test_submit_path_validation.py b/tests/test_submit_path_validation.py index f9efa145..865ca4f1 100644 --- a/tests/test_submit_path_validation.py +++ b/tests/test_submit_path_validation.py @@ -2,8 +2,12 @@ import tempfile import os import json +from unittest.mock import patch +import apex.submit as submit_module from apex.submit import ( + _with_lammps_retry_env, + submit_workflow, validate_submit_paths, auto_fill_type_map_from_poscar, pack_upload_dir, @@ -11,6 +15,85 @@ class TestSubmitPathValidation(unittest.TestCase): + def test_with_lammps_retry_env_handles_empty_existing_and_new_commands(self): + class DummyConfig: + lammps_header_retry_attempts = 4 + lammps_header_retry_delay = 0.25 + lammps_transient_retry_attempts = 2 + + self.assertEqual(_with_lammps_retry_env("", DummyConfig()), "") + existing = "APEX_LAMMPS_HEADER_RETRY=9 lmp -in in.lammps" + self.assertEqual(_with_lammps_retry_env(existing, DummyConfig()), existing) + wrapped = _with_lammps_retry_env("lmp -in in.lammps", DummyConfig()) + self.assertTrue(wrapped.startswith("APEX_LAMMPS_HEADER_RETRY=4 ")) + self.assertIn("APEX_LAMMPS_HEADER_RETRY_DELAY=0.25", wrapped) + self.assertIn("APEX_LAMMPS_TRANSIENT_RETRY=2", wrapped) + self.assertTrue(wrapped.endswith("lmp -in in.lammps")) + + def test_submit_workflow_wraps_lammps_run_command(self): + captured = {} + + class DummyConfig: + remote_root = None + dispatcher_config_dict = {} + dflow_config_dict = {} + bohrium_config_dict = {} + dflow_s3_config_dict = {} + database_type = "local" + submit_only = False + flow_name = None + lammps_header_retry_attempts = 4 + lammps_header_retry_delay = 0.25 + lammps_transient_retry_attempts = 2 + + def __init__(self, **_kwargs): + self.basic_config_dict = { + "apex_image_name": "apex-image", + "lammps_image_name": "", + "run_image_name": "run-image", + "lammps_run_command": "", + "run_command": "lmp -in in.lammps", + "phonolammps_run_command": "", + "group_size": 1, + "pool_size": 1, + "upload_python_packages": [], + } + + @staticmethod + def config_dflow(_config): + return None + + @staticmethod + def config_bohrium(_config): + return None + + @staticmethod + def config_s3(_config): + return None + + def get_executor(self, _config): + return None + + class DummyFlow: + def __init__(self, **kwargs): + captured.update(kwargs) + + with tempfile.TemporaryDirectory() as work_dir, \ + patch.object(submit_module, "validate_submit_paths"), \ + patch.object(submit_module, "Config", DummyConfig), \ + patch.object( + submit_module, + "judge_flow", + return_value=(object(), "lammps", "props", None, {"properties": []}), + ), \ + patch.object(submit_module, "FlowGenerator", DummyFlow), \ + patch.object(submit_module, "submit") as mocked_submit: + submit_workflow([{}], {}, [work_dir], "props", submit_only=True) + + self.assertIn("APEX_LAMMPS_HEADER_RETRY=4", captured["run_command"]) + self.assertTrue(captured["run_command"].endswith("lmp -in in.lammps")) + mocked_submit.assert_called_once() + def test_accept_paths_without_dot(self): params = [ {