diff --git a/CITATION.cff b/CITATION.cff index f8ebf23..cd54050 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -37,6 +37,11 @@ authors: email: rhen7213@uni.sydney.edu.au affiliation: University of Sydney orcid: 'https://orcid.org/0000-0002-0461-6625' + - given-names: Jonathan + family-names: Higham + email: j.higham4@lancaster.ac.uk + affiliation: Lancaster University + orcid: 'https://orcid.org/0000-0002-9779-9968' - given-names: Jas family-names: Kalayan email: jas.kalayan@stfc.ac.uk diff --git a/CodeEntropy/config/arg_config_manager.py b/CodeEntropy/config/arg_config_manager.py index b8682e5..dab98dc 100644 --- a/CodeEntropy/config/arg_config_manager.py +++ b/CodeEntropy/config/arg_config_manager.py @@ -82,11 +82,6 @@ def load_config(self, file_path): yaml_files = glob.glob(os.path.join(file_path, "*.yaml")) if not yaml_files: - logger.warning( - f"No YAML configuration files found in directory: {file_path}. " - "Expected a file with extension '.yaml'. " - "Proceeding with default configuration: {'run1': {}}." 
- ) return {"run1": {}} try: diff --git a/CodeEntropy/config/data_logger.py b/CodeEntropy/config/data_logger.py index 223a66c..4345e11 100644 --- a/CodeEntropy/config/data_logger.py +++ b/CodeEntropy/config/data_logger.py @@ -2,16 +2,23 @@ import logging import re -from tabulate import tabulate +import numpy as np +from rich.console import Console +from rich.table import Table + +from CodeEntropy.config.logging_config import LoggingConfig # Set up logger logger = logging.getLogger(__name__) +console = LoggingConfig.get_console() class DataLogger: - def __init__(self): + def __init__(self, console=None): + self.console = console or Console() self.molecule_data = [] self.residue_data = [] + self.group_labels = {} def save_dataframes_as_json(self, molecule_df, residue_df, output_file): """Save multiple DataFrames into a single JSON file with separate keys""" @@ -28,44 +35,75 @@ def clean_residue_name(self, resname): """Ensures residue names are stripped and cleaned before being stored""" return re.sub(r"[-–—]", "", str(resname)) - def add_results_data(self, resname, level, entropy_type, value): + def add_results_data(self, group_id, level, entropy_type, value): """Add data for molecule-level entries""" - resname = self.clean_residue_name(resname) - self.molecule_data.append((resname, level, entropy_type, value)) + self.molecule_data.append((group_id, level, entropy_type, value)) - def add_residue_data(self, resid, resname, level, entropy_type, value): + def add_residue_data( + self, group_id, resname, level, entropy_type, frame_count, value + ): """Add data for residue-level entries""" resname = self.clean_residue_name(resname) - self.residue_data.append([resid, resname, level, entropy_type, value]) + if isinstance(frame_count, np.ndarray): + frame_count = frame_count.tolist() + self.residue_data.append( + [group_id, resname, level, entropy_type, frame_count, value] + ) + + def add_group_label(self, group_id, label, residue_count=None, atom_count=None): + """Store a 
mapping from group ID to a descriptive label and metadata""" + self.group_labels[group_id] = { + "label": label, + "residue_count": residue_count, + "atom_count": atom_count, + } def log_tables(self): - """Log both tables at once""" - # Log molecule data + """Display rich tables in terminal""" + if self.molecule_data: - logger.info("Molecule Data Table:") - table_str = tabulate( - self.molecule_data, - headers=["Residue Name", "Level", "Type", "Result (J/mol/K)"], - tablefmt="grid", - numalign="center", - stralign="center", + table = Table( + title="Molecule Entropy Results", show_lines=True, expand=True ) - logger.info(f"\n{table_str}") + table.add_column("Group ID", justify="center", style="bold cyan") + table.add_column("Level", justify="center", style="magenta") + table.add_column("Type", justify="center", style="green") + table.add_column("Result (J/mol/K)", justify="center", style="yellow") + + for row in self.molecule_data: + table.add_row(*[str(cell) for cell in row]) + + console.print(table) - # Log residue data if self.residue_data: - logger.info("Residue Data Table:") - table_str = tabulate( - self.residue_data, - headers=[ - "Residue ID", - "Residue Name", - "Level", - "Type", - "Result (J/mol/K)", - ], - tablefmt="grid", - numalign="center", - stralign="center", + table = Table(title="Residue Entropy Results", show_lines=True, expand=True) + table.add_column("Group ID", justify="center", style="bold cyan") + table.add_column("Residue Name", justify="center", style="cyan") + table.add_column("Level", justify="center", style="magenta") + table.add_column("Type", justify="center", style="green") + table.add_column("Count", justify="center", style="green") + table.add_column("Result (J/mol/K)", justify="center", style="yellow") + + for row in self.residue_data: + table.add_row(*[str(cell) for cell in row]) + + console.print(table) + + if self.group_labels: + label_table = Table( + title="Group ID to Residue Label Mapping", show_lines=True, expand=True ) - 
logger.info(f"\n{table_str}") + label_table.add_column("Group ID", justify="center", style="bold cyan") + label_table.add_column("Residue Label", justify="center", style="green") + label_table.add_column("Residue Count", justify="center", style="magenta") + label_table.add_column("Atom Count", justify="center", style="yellow") + + for group_id, info in self.group_labels.items(): + label_table.add_row( + str(group_id), + info["label"], + str(info.get("residue_count", "")), + str(info.get("atom_count", "")), + ) + + console.print(label_table) diff --git a/CodeEntropy/config/logging_config.py b/CodeEntropy/config/logging_config.py index e652e3e..aea5f89 100644 --- a/CodeEntropy/config/logging_config.py +++ b/CodeEntropy/config/logging_config.py @@ -1,104 +1,164 @@ import logging -import logging.config import os +from rich.console import Console +from rich.logging import RichHandler + + +class ErrorFilter(logging.Filter): + """ + Logging filter that only allows records with level ERROR or higher. + + This ensures that the attached handler only processes error and critical logs, + filtering out all lower level messages such as DEBUG and INFO. 
+ """ + + def filter(self, record): + return record.levelno >= logging.ERROR + class LoggingConfig: - def __init__(self, folder, log_level=logging.INFO): - log_directory = os.path.join(folder, "logs") - os.makedirs(log_directory, exist_ok=True) - - self.LOGGING = { - "version": 1, - "disable_existing_loggers": False, - "formatters": { - "detailed": { - "format": "%(asctime)s " - "- %(levelname)s " - "- %(filename)s:%(lineno)d " - "- %(message)s", - }, - "simple": { - "format": "%(message)s", - }, - }, - "handlers": { - "console": { - "class": "logging.StreamHandler", - "formatter": "simple", - "level": logging.INFO, - }, - "stdout": { - "class": "logging.FileHandler", - "filename": os.path.join(log_directory, "program.out"), - "formatter": "simple", - "level": logging.INFO, - }, - "logfile": { - "class": "logging.FileHandler", - "filename": os.path.join(log_directory, "program.log"), - "formatter": "detailed", - "level": log_level, - }, - "errorfile": { - "class": "logging.FileHandler", - "filename": os.path.join(log_directory, "program.err"), - "formatter": "simple", - "level": logging.ERROR, - }, - "commandfile": { - "class": "logging.FileHandler", - "filename": os.path.join(log_directory, "program.com"), - "formatter": "simple", - "level": logging.INFO, - }, - "mdanalysis_log": { - "class": "logging.FileHandler", - "filename": os.path.join(log_directory, "mdanalysis.log"), - "formatter": "detailed", - "level": log_level, - }, - }, - "loggers": { - "": { - "handlers": ["console", "stdout", "logfile", "errorfile"], - "level": log_level, - }, - "MDAnalysis": { - "handlers": ["mdanalysis_log"], - "level": log_level, - "propagate": False, - }, - "commands": { - "handlers": ["commandfile"], - "level": logging.INFO, - "propagate": False, - }, - }, + """ + Configures logging with Rich console output and multiple file handlers. + Provides a single Rich Console instance that records all output for later export. 
+ + Attributes: + _console (Console): Shared Rich Console instance with output recording enabled. + log_dir (str): Directory path to store log files. + level (int): Logging level (e.g., logging.INFO). + console (Console): The Rich Console instance used for output and logging. + handlers (dict): Dictionary of logging handlers for console and files. + """ + + _console = None # Shared Console with recording enabled + + @classmethod + def get_console(cls): + """ + Get or create a singleton Rich Console instance with recording enabled. + + Returns: + Console: Rich Console instance that prints to terminal and records output. + """ + if cls._console is None: + # Create console that records output for later export + cls._console = Console(record=True) + return cls._console + + def __init__(self, folder, level=logging.INFO): + """ + Initialize the logging configuration. + + Args: + folder (str): Base folder where 'logs' directory will be created. + level (int): Logging level (default: logging.INFO). 
+ """ + self.log_dir = os.path.join(folder, "logs") + os.makedirs(self.log_dir, exist_ok=True) + self.level = level + + # Use the single recorded console instance + self.console = self.get_console() + + self._setup_handlers() + + def _setup_handlers(self): + paths = { + "main": os.path.join(self.log_dir, "program.log"), + "error": os.path.join(self.log_dir, "program.err"), + "command": os.path.join(self.log_dir, "program.com"), + "mdanalysis": os.path.join(self.log_dir, "mdanalysis.log"), } + formatter = logging.Formatter( + "%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s" + ) + + self.handlers = { + "rich": RichHandler( + console=self.console, + markup=True, + rich_tracebacks=True, + show_time=True, + show_level=True, + show_path=False, + ), + "main": logging.FileHandler(paths["main"]), + "error": logging.FileHandler(paths["error"]), + "command": logging.FileHandler(paths["command"]), + "mdanalysis": logging.FileHandler(paths["mdanalysis"]), + } + + self.handlers["rich"].setLevel(logging.INFO) + self.handlers["main"].setLevel(self.level) + self.handlers["error"].setLevel(logging.ERROR) + self.handlers["command"].setLevel(logging.INFO) + self.handlers["mdanalysis"].setLevel(self.level) + + for name, handler in self.handlers.items(): + if name != "rich": + handler.setFormatter(formatter) + + # Add filter to error handler to ensure only ERROR and above are logged + self.handlers["error"].addFilter(ErrorFilter()) + def setup_logging(self): - logging.config.dictConfig(self.LOGGING) - logging.getLogger("MDAnalysis") - logging.getLogger("commands") + """ + Configure the root logger and specific loggers with the prepared handlers. + + Returns: + logging.Logger: Logger instance for the current module (__name__). 
+ """ + root = logging.getLogger() + root.setLevel(self.level) + root.addHandler(self.handlers["rich"]) + root.addHandler(self.handlers["main"]) + root.addHandler(self.handlers["error"]) + + logging.getLogger("commands").addHandler(self.handlers["command"]) + logging.getLogger("commands").setLevel(logging.INFO) + logging.getLogger("commands").propagate = False + + logging.getLogger("MDAnalysis").addHandler(self.handlers["mdanalysis"]) + logging.getLogger("MDAnalysis").setLevel(self.level) + logging.getLogger("MDAnalysis").propagate = False + return logging.getLogger(__name__) def update_logging_level(self, log_level): - # Update the root logger level + """ + Update the logging level for the root logger and specific sub-loggers. + + Args: + log_level (int): New logging level (e.g., logging.DEBUG, logging.WARNING). + """ root_logger = logging.getLogger() root_logger.setLevel(log_level) for handler in root_logger.handlers: - handler.setLevel( - log_level if isinstance(handler, logging.FileHandler) else logging.INFO - ) + if isinstance(handler, logging.FileHandler): + handler.setLevel(log_level) + else: + # Keep RichHandler at INFO or higher for nicer console output + handler.setLevel(logging.INFO) - # Update all other loggers and their handlers - for logger_name in self.LOGGING["loggers"]: + for logger_name in ["commands", "MDAnalysis"]: logger = logging.getLogger(logger_name) logger.setLevel(log_level) for handler in logger.handlers: - handler.setLevel( - log_level - if isinstance(handler, logging.FileHandler) - else logging.INFO - ) + if isinstance(handler, logging.FileHandler): + handler.setLevel(log_level) + else: + handler.setLevel(logging.INFO) + + def save_console_log(self, filename="program_output.txt"): + """ + Save all recorded console output to a text file. + + Args: + filename (str): Name of the file to write console output to. + Defaults to 'program_output.txt' in the logs directory. 
+ """ + output_path = os.path.join(self.log_dir, filename) + os.makedirs(self.log_dir, exist_ok=True) + with open(output_path, "w", encoding="utf-8") as f: + f.write(self.console.export_text()) diff --git a/CodeEntropy/entropy.py b/CodeEntropy/entropy.py index e3bf58e..647f7cb 100644 --- a/CodeEntropy/entropy.py +++ b/CodeEntropy/entropy.py @@ -6,8 +6,18 @@ import pandas as pd import waterEntropy.recipes.interfacial_solvent as GetSolvent from numpy import linalg as la +from rich.progress import ( + BarColumn, + Progress, + SpinnerColumn, + TextColumn, + TimeElapsedColumn, +) + +from CodeEntropy.config.logging_config import LoggingConfig logger = logging.getLogger(__name__) +console = LoggingConfig.get_console() class EntropyManager: @@ -50,6 +60,11 @@ def execute(self): """ start, end, step = self._get_trajectory_bounds() number_frames = self._get_number_frames(start, end, step) + + console.print( + f"Analyzing a total of {number_frames} frames in this calculation." + ) + ve = VibrationalEntropy( self._run_manager, self._args, @@ -67,25 +82,47 @@ def execute(self): self._group_molecules, ) - self._handle_water_entropy(start, end, step) reduced_atom, number_molecules, levels, groups = self._initialize_molecules() - force_matrices, torque_matrices = self._level_manager.build_covariance_matrices( - self, - reduced_atom, - levels, - groups, - start, - end, - step, - number_frames, + water_atoms = self._universe.select_atoms("water") + water_resids = set(res.resid for res in water_atoms.residues) + + water_groups = { + gid: g + for gid, g in groups.items() + if any( + res.resid in water_resids + for mol in [self._universe.atoms.fragments[i] for i in g] + for res in mol.residues + ) + } + nonwater_groups = { + gid: g for gid, g in groups.items() if gid not in water_groups + } + + if self._args.water_entropy and water_groups: + self._handle_water_entropy(start, end, step, water_groups) + else: + nonwater_groups.update(water_groups) + + force_matrices, torque_matrices, 
frame_counts = ( + self._level_manager.build_covariance_matrices( + self, + reduced_atom, + levels, + nonwater_groups, + start, + end, + step, + number_frames, + ) ) states_ua, states_res = self._level_manager.build_conformational_states( self, reduced_atom, levels, - groups, + nonwater_groups, start, end, step, @@ -97,11 +134,12 @@ def execute(self): self._compute_entropies( reduced_atom, levels, - groups, + nonwater_groups, force_matrices, torque_matrices, states_ua, states_res, + frame_counts, number_frames, ve, ce, @@ -110,34 +148,39 @@ def execute(self): self._finalize_molecule_results() self._data_logger.log_tables() - def _handle_water_entropy(self, start, end, step): + def _handle_water_entropy(self, start, end, step, water_groups): """ - Compute and exclude water entropy from the system if applicable. + Compute water entropy for each water group, log data, and update selection + string to exclude water from further analysis. - If water molecules are present and water entropy calculation is enabled, - this method computes their entropy and updates the selection string to - exclude water from further analysis. + Args: + start (int): Start frame index + end (int): End frame index + step (int): Step size + water_groups (dict): {group_id: [molecule indices]} for water + """ + if not water_groups or not self._args.water_entropy: + return - Parameters: - start (int): Start frame index. - end (int): End frame index. - step (int): Step size for frame iteration.
- """ - has_water = self._universe.select_atoms("water").n_atoms > 0 - if has_water and self._args.water_entropy: - self._calculate_water_entropy(self._universe, start, end, step) - self._args.selection_string = ( - self._args.selection_string + " and not water" - if self._args.selection_string != "all" - else "not water" - ) - logger.debug( - "WaterEntropy: molecule_data: %s", self._data_logger.molecule_data - ) - logger.debug( - "WaterEntropy: residue_data: %s", self._data_logger.residue_data + for group_id, atom_indices in water_groups.items(): + + self._calculate_water_entropy( + universe=self._universe, + start=start, + end=end, + step=step, + group_id=group_id, ) + self._args.selection_string = ( + self._args.selection_string + " and not water" + if self._args.selection_string != "all" + else "not water" + ) + + logger.debug(f"WaterEntropy: molecule_data: {self._data_logger.molecule_data}") + logger.debug(f"WaterEntropy: residue_data: {self._data_logger.residue_data}") + def _initialize_molecules(self): """ Prepare the reduced universe and determine molecule-level configurations. @@ -164,6 +207,7 @@ def _compute_entropies( torque_matrices, states_ua, states_res, + frame_counts, number_frames, ve, ce, @@ -188,56 +232,102 @@ def _compute_entropies( torque_matrices (dict): Precomputed torque covariance matrices. states_ua (dict): Dictionary to store united-atom conformational states. states_res (list): List to store residue-level conformational states. + frame_counts (dict): Frame counts keyed by level ("ua", "res", "poly"). number_frames (int): Total number of trajectory frames to process.
""" - for group_id in groups.keys(): - mol = self._get_molecule_container(reduced_atom, groups[group_id][0]) - for level in levels[groups[group_id][0]]: - highest = level == levels[groups[group_id][0]][-1] - - if level == "united_atom": - self._process_united_atom_entropy( - group_id, - mol, - ve, - ce, - level, - force_matrices["ua"], - torque_matrices["ua"], - states_ua, - highest, - number_frames, - ) + with Progress( + SpinnerColumn(), + TextColumn("[bold blue]{task.fields[title]}", justify="right"), + BarColumn(), + "[progress.percentage]{task.percentage:>3.1f}%", + TimeElapsedColumn(), + ) as progress: + + task = progress.add_task( + "[green]Calculating Entropy...", + total=len(groups), + title="Starting...", + ) - elif level == "residue": - self._process_vibrational_entropy( - group_id, - number_frames, - ve, - level, - force_matrices["res"][group_id], - torque_matrices["res"][group_id], - highest, - ) + for group_id in groups.keys(): + mol = self._get_molecule_container(reduced_atom, groups[group_id][0]) - self._process_conformational_entropy( - group_id, - ce, - level, - states_res, - number_frames, - ) + residue_group = "_".join( + sorted(set(res.resname for res in mol.residues)) + ) + group_residue_count = len(groups[group_id]) + group_atom_count = 0 + for mol_id in groups[group_id]: + each_mol = self._get_molecule_container(reduced_atom, mol_id) + group_atom_count += len(each_mol.atoms) + self._data_logger.add_group_label( + group_id, residue_group, group_residue_count, group_atom_count + ) + + resname = mol.atoms[0].resname + resid = mol.atoms[0].resid + segid = mol.atoms[0].segid + + mol_label = f"{resname}_{resid} (segid {segid})" - elif level == "polymer": - self._process_vibrational_entropy( - group_id, - number_frames, - ve, - level, - force_matrices["poly"][group_id], - torque_matrices["poly"][group_id], - highest, + for level in levels[groups[group_id][0]]: + progress.update( + task, + title=f"Calculating entropy values | " + f"Molecule: 
{mol_label} | " + f"Level: {level}", ) + highest = level == levels[groups[group_id][0]][-1] + + if level == "united_atom": + self._process_united_atom_entropy( + group_id, + mol, + ve, + ce, + level, + force_matrices["ua"], + torque_matrices["ua"], + states_ua, + frame_counts["ua"], + highest, + number_frames, + ) + + elif level == "residue": + self._process_vibrational_entropy( + group_id, + mol, + number_frames, + ve, + level, + force_matrices["res"][group_id], + torque_matrices["res"][group_id], + highest, + ) + + self._process_conformational_entropy( + group_id, + mol, + ce, + level, + states_res, + number_frames, + ) + + elif level == "polymer": + self._process_vibrational_entropy( + group_id, + mol, + number_frames, + ve, + level, + force_matrices["poly"][group_id], + torque_matrices["poly"][group_id], + highest, + ) + + progress.advance(task) def _get_trajectory_bounds(self): """ @@ -264,16 +354,6 @@ def _get_number_frames(self, start, end, step): Returns: int: Total number of frames considered. """ - trajectory_length = len(self._universe.trajectory) - - if start == 0 and end == -1 and step == 1: - return trajectory_length - - if end == -1: - end = trajectory_length - else: - end += 1 - return math.floor((end - start) / step) def _get_reduced_universe(self): @@ -317,6 +397,7 @@ def _process_united_atom_entropy( force_matrix, torque_matrix, states, + frame_counts, highest, number_frames, ): @@ -332,6 +413,7 @@ def _process_united_atom_entropy( level (str): Granularity level (should be 'united_atom'). start, end, step (int): Trajectory frame parameters. n_frames (int): Number of trajectory frames. + frame_counts: Number of frames counted highest (bool): Whether this is the highest level of resolution for the molecule.
""" @@ -371,21 +453,50 @@ def _process_united_atom_entropy( S_conf += S_conf_res self._data_logger.add_residue_data( - residue_id, residue.resname, level, "Transvibrational", S_trans_res + group_id, + residue.resname, + level, + "Transvibrational", + frame_counts[key], + S_trans_res, ) self._data_logger.add_residue_data( - residue_id, residue.resname, level, "Rovibrational", S_rot_res + group_id, + residue.resname, + level, + "Rovibrational", + frame_counts[key], + S_rot_res, ) self._data_logger.add_residue_data( - residue_id, residue.resname, level, "Conformational", S_conf_res + group_id, + residue.resname, + level, + "Conformational", + frame_counts[key], + S_conf_res, ) self._data_logger.add_results_data(group_id, level, "Transvibrational", S_trans) self._data_logger.add_results_data(group_id, level, "Rovibrational", S_rot) self._data_logger.add_results_data(group_id, level, "Conformational", S_conf) + residue_group = "_".join( + sorted(set(res.resname for res in mol_container.residues)) + ) + + logger.debug(f"residue_group {residue_group}") + def _process_vibrational_entropy( - self, group_id, number_frames, ve, level, force_matrix, torque_matrix, highest + self, + group_id, + mol_container, + number_frames, + ve, + level, + force_matrix, + torque_matrix, + highest, ): """ Calculates vibrational entropy. @@ -397,6 +508,7 @@ def _process_vibrational_entropy( level (str): Current granularity level. force_matrix : Force covariance matrix torque_matrix : Torque covariance matrix + mol_container : Molecule whose residues and atoms label the group highest (bool): Flag indicating if this is the highest granularity level.
""" @@ -414,8 +526,17 @@ def _process_vibrational_entropy( self._data_logger.add_results_data(group_id, level, "Transvibrational", S_trans) self._data_logger.add_results_data(group_id, level, "Rovibrational", S_rot) + residue_group = "_".join( + sorted(set(res.resname for res in mol_container.residues)) + ) + residue_count = len(mol_container.residues) + atom_count = len(mol_container.atoms) + self._data_logger.add_group_label( + group_id, residue_group, residue_count, atom_count + ) + def _process_conformational_entropy( - self, group_id, ce, level, states, number_frames + self, group_id, mol_container, ce, level, states, number_frames ): """ Computes conformational entropy at the residue level (whole-molecule dihedral @@ -445,149 +566,191 @@ def _process_conformational_entropy( if contains_state_data else 0 ) - self._data_logger.add_results_data(group_id, level, "Conformational", S_conf) + residue_group = "_".join( + sorted(set(res.resname for res in mol_container.residues)) + ) + residue_count = len(mol_container.residues) + atom_count = len(mol_container.atoms) + self._data_logger.add_group_label( + group_id, residue_group, residue_count, atom_count + ) + def _finalize_molecule_results(self): """ - Aggregates and logs total entropy per molecule using residue_data grouped by - resid. + Aggregates and logs total entropy and frame counts per molecule. 
""" entropy_by_molecule = defaultdict(float) - - for mol_id, level, entropy_type, result in self._data_logger.molecule_data: - if level != "Molecule Total": + for ( + mol_id, + level, + entropy_type, + result, + ) in self._data_logger.molecule_data: + if level != "Group Total": try: entropy_by_molecule[mol_id] += float(result) except ValueError: logger.warning(f"Skipping invalid entry: {mol_id}, {result}") - for mol_id, total_entropy in entropy_by_molecule.items(): + for mol_id in entropy_by_molecule.keys(): + total_entropy = entropy_by_molecule[mol_id] + self._data_logger.molecule_data.append( - (mol_id, "Molecule Total", "Molecule Total Entropy", total_entropy) + ( + mol_id, + "Group Total", + "Group Total Entropy", + total_entropy, + ) ) - # Save to file self._data_logger.save_dataframes_as_json( pd.DataFrame( self._data_logger.molecule_data, - columns=["Molecule ID", "Level", "Type", "Result (J/mol/K)"], + columns=[ + "Group ID", + "Level", + "Type", + "Result (J/mol/K)", + ], ), pd.DataFrame( self._data_logger.residue_data, columns=[ - "Residue ID", + "Group ID", "Residue Name", "Level", "Type", + "Frame Count", "Result (J/mol/K)", ], ), self._args.output_file, ) - def _calculate_water_entropy(self, universe, start, end, step): + def _calculate_water_entropy(self, universe, start, end, step, group_id=None): """ - Calculates orientational and vibrational entropy for water molecules. - - Args: - universe: MDAnalysis Universe object. - start (int): Start frame. - end (int): End frame. - step (int): Step size. + Calculate and aggregate the entropy of water molecules in a simulation. + + This function computes orientational, translational, and rotational + entropy components for all water molecules, aggregates them per residue, + and maps all waters to a single group ID. It also logs the total results + and labels the water group in the data logger. + + Parameters + ---------- + universe : MDAnalysis.Universe + The simulation universe containing water molecules. 
+ start : int + The starting frame for analysis. + end : int + The ending frame for analysis. + step : int + Frame interval for analysis. + group_id : int or str, optional + The group ID to which all water molecules will be assigned. """ - Sorient_dict, _, vibrations, _, _ = ( + Sorient_dict, covariances, vibrations, _, water_count = ( GetSolvent.get_interfacial_water_orient_entropy( universe, start, end, step, self._args.temperature, parallel=True ) ) - # Log per-residue entropy using helper functions - self._calculate_water_orientational_entropy(Sorient_dict) - self._calculate_water_vibrational_translational_entropy(vibrations) - self._calculate_water_vibrational_rotational_entropy(vibrations) - - # Aggregate entropy components per molecule - results = {} - - for row in self._data_logger.residue_data: - mol_id = row[1] - entropy_type = row[3].split()[0] - value = float(row[4]) - - if mol_id not in results: - results[mol_id] = { - "Orientational": 0.0, - "Transvibrational": 0.0, - "Rovibrational": 0.0, - } - - results[mol_id][entropy_type] += value - - # Log per-molecule entropy components and total - for mol_id, components in results.items(): - total = 0.0 - for entropy_type in ["Orientational", "Transvibrational", "Rovibrational"]: - S_component = components[entropy_type] - self._data_logger.add_results_data( - mol_id, "water", entropy_type, S_component - ) - total += S_component + self._calculate_water_orientational_entropy(Sorient_dict, group_id) + self._calculate_water_vibrational_translational_entropy( + vibrations, group_id, covariances + ) + self._calculate_water_vibrational_rotational_entropy( + vibrations, group_id, covariances + ) + + water_selection = universe.select_atoms("resname WAT") + actual_water_residues = len(water_selection.residues) + residue_names = { + resname + for res_dict in Sorient_dict.values() + for resname in res_dict.keys() + if resname.upper() in water_selection.residues.resnames + } + + residue_group = 
"_".join(sorted(residue_names)) if residue_names else "WAT" + self._data_logger.add_group_label( + group_id, residue_group, actual_water_residues, len(water_selection.atoms) + ) - def _calculate_water_orientational_entropy(self, Sorient_dict): + def _calculate_water_orientational_entropy(self, Sorient_dict, group_id): """ - Logs orientational entropy values directly from Sorient_dict. + Aggregate orientational entropy for all water molecules into a single group. + + Parameters + ---------- + Sorient_dict : dict + Dictionary containing orientational entropy values per residue. + group_id : int or str + The group ID to which the water residues belong. + covariances : object + Covariance object. """ for resid, resname_dict in Sorient_dict.items(): for resname, values in resname_dict.items(): if isinstance(values, list) and len(values) == 2: Sor, count = values self._data_logger.add_residue_data( - resid, resname, "Water", "Orientational", Sor + group_id, resname, "Water", "Orientational", count, Sor ) - def _calculate_water_vibrational_translational_entropy(self, vibrations): + def _calculate_water_vibrational_translational_entropy( + self, vibrations, group_id, covariances + ): """ - Logs summed translational entropy values per residue-solvent pair. + Aggregate translational vibrational entropy for all water molecules. + + Parameters + ---------- + vibrations : object + Object containing translational entropy data (vibrations.translational_S). + group_id : int or str + The group ID for the water residues. + covariances : object + Covariance object. 
""" + for (solute_id, _), entropy in vibrations.translational_S.items(): if isinstance(entropy, (list, np.ndarray)): entropy = float(np.sum(entropy)) - if "_" in solute_id: - resname, resid_str = solute_id.rsplit("_", 1) - try: - resid = int(resid_str) - except ValueError: - resid = -1 - else: - resname = solute_id - resid = -1 - + count = covariances.counts.get((solute_id, "WAT"), 1) + resname = solute_id.rsplit("_", 1)[0] if "_" in solute_id else solute_id self._data_logger.add_residue_data( - resid, resname, "Water", "Transvibrational", entropy + group_id, resname, "Water", "Transvibrational", count, entropy ) - def _calculate_water_vibrational_rotational_entropy(self, vibrations): + def _calculate_water_vibrational_rotational_entropy( + self, vibrations, group_id, covariances + ): """ - Logs summed rotational entropy values per residue-solvent pair. + Aggregate rotational vibrational entropy for all water molecules. + + Parameters + ---------- + vibrations : object + Object containing rotational entropy data (vibrations.rotational_S). + group_id : int or str + The group ID for the water residues. + covariances : object + Covariance object. 
""" for (solute_id, _), entropy in vibrations.rotational_S.items(): if isinstance(entropy, (list, np.ndarray)): entropy = float(np.sum(entropy)) - if "_" in solute_id: - resname, resid_str = solute_id.rsplit("_", 1) - try: - resid = int(resid_str) - except ValueError: - resid = -1 - else: - resname = solute_id - resid = -1 + count = covariances.counts.get((solute_id, "WAT"), 1) + resname = solute_id.rsplit("_", 1)[0] if "_" in solute_id else solute_id self._data_logger.add_residue_data( - resid, resname, "Water", "Rovibrational", entropy + group_id, resname, "Water", "Rovibrational", count, entropy ) @@ -773,11 +936,11 @@ def assign_conformation( # get the values of the angle for the dihedral # dihedral angle values have a range from -180 to 180 - indices = list(range(start, end, step)) + indices = list(range(number_frames)) for timestep_index, _ in zip( indices, data_container.trajectory[start:end:step] ): - timestep_index = timestep_index - start + timestep_index = timestep_index value = dihedral.value() # we want postive values in range 0 to 360 to make the peak assignment # work using the fact that dihedrals have circular symetry diff --git a/CodeEntropy/group_molecules.py b/CodeEntropy/group_molecules.py index f8e43fb..d361691 100644 --- a/CodeEntropy/group_molecules.py +++ b/CodeEntropy/group_molecules.py @@ -50,7 +50,7 @@ def _by_none(self, universe): number_groups = len(molecule_groups) - logger.info(f"Number of molecule groups: {number_groups}") + logger.debug(f"Number of molecule groups: {number_groups}") logger.debug(f"Molecule groups are: {molecule_groups}") return molecule_groups @@ -85,7 +85,7 @@ def _by_molecules(self, universe): number_groups = len(molecule_groups) - logger.info(f"Number of molecule groups: {number_groups}") + logger.debug(f"Number of molecule groups: {number_groups}") logger.debug(f"Molecule groups are: {molecule_groups}") return molecule_groups diff --git a/CodeEntropy/levels.py b/CodeEntropy/levels.py index aa06c40..585e74f 
100644 --- a/CodeEntropy/levels.py +++ b/CodeEntropy/levels.py @@ -1,6 +1,13 @@ import logging import numpy as np +from rich.progress import ( + BarColumn, + Progress, + SpinnerColumn, + TextColumn, + TimeElapsedColumn, +) logger = logging.getLogger(__name__) @@ -784,33 +791,75 @@ def build_covariance_matrices( "res": [None] * number_groups, "poly": [None] * number_groups, } + + total_steps = len(reduced_atom.trajectory[start:end:step]) + total_items = ( + sum(len(levels[mol_id]) for mols in groups.values() for mol_id in mols) + * total_steps + ) + frame_counts = { "ua": {}, "res": np.zeros(number_groups, dtype=int), "poly": np.zeros(number_groups, dtype=int), } - indices = list(range(start, end, step)) - for time_index, _ in zip(indices, reduced_atom.trajectory[start:end:step]): + with Progress( + SpinnerColumn(), + TextColumn("[bold blue]{task.fields[title]}", justify="right"), + BarColumn(), + TextColumn("[progress.percentage]{task.percentage:>3.1f}%"), + TimeElapsedColumn(), + ) as progress: + + task = progress.add_task( + "[green]Processing...", + total=total_items, + title="Starting...", + ) - for group_id, molecules in groups.items(): - for mol_id in molecules: - mol = entropy_manager._get_molecule_container(reduced_atom, mol_id) - for level in levels[mol_id]: - self.update_force_torque_matrices( - entropy_manager, - mol, - group_id, - level, - levels[mol_id], - time_index - start, - number_frames, - force_avg, - torque_avg, - frame_counts, + indices = list(range(number_frames)) + for time_index, _ in zip(indices, reduced_atom.trajectory[start:end:step]): + for group_id, molecules in groups.items(): + for mol_id in molecules: + mol = entropy_manager._get_molecule_container( + reduced_atom, mol_id ) + for level in levels[mol_id]: + mol = entropy_manager._get_molecule_container( + reduced_atom, mol_id + ) - return force_avg, torque_avg + resname = mol.atoms[0].resname + resid = mol.atoms[0].resid + segid = mol.atoms[0].segid + + mol_label = f"{resname}_{resid} 
(segid {segid})" + + progress.update( + task, + title=f"Building covariance matrices | " + f"Timestep {time_index} | " + f"Molecule: {mol_label} | " + f"Level: {level}", + ) + + self.update_force_torque_matrices( + entropy_manager, + mol, + group_id, + level, + levels[mol_id], + time_index, + number_frames, + force_avg, + torque_avg, + frame_counts, + ) + + progress.advance(task) + + return force_avg, torque_avg, frame_counts def update_force_torque_matrices( self, @@ -926,6 +975,8 @@ def update_force_torque_matrices( force_avg[key][group_id] += (f_mat - force_avg[key][group_id]) / n torque_avg[key][group_id] += (t_mat - torque_avg[key][group_id]) / n + return frame_counts + def filter_zero_rows_columns(self, arg_matrix): """ function for removing rows and columns that contain only zeros from a matrix @@ -1011,28 +1062,78 @@ def build_conformational_states( states_ua = {} states_res = [None] * number_groups - for group_id in groups.keys(): - molecules = groups[group_id] - for mol_id in molecules: - mol = entropy_manager._get_molecule_container(reduced_atom, mol_id) - for level in levels[mol_id]: - if level == "united_atom": - for res_id, residue in enumerate(mol.residues): - key = (group_id, res_id) - - res_container = ( - entropy_manager._run_manager.new_U_select_atom( - mol, - f"index {residue.atoms.indices[0]}:" - f"{residue.atoms.indices[-1]}", + total_items = sum( + len(levels[mol_id]) for mols in groups.values() for mol_id in mols + ) + + with Progress( + SpinnerColumn(), + TextColumn("[bold blue]{task.fields[title]}", justify="right"), + BarColumn(), + TextColumn("[progress.percentage]{task.percentage:>3.1f}%"), + TimeElapsedColumn(), + ) as progress: + + task = progress.add_task( + "[green]Building Conformational States...", + total=total_items, + title="Starting...", + ) + + for group_id in groups.keys(): + molecules = groups[group_id] + for mol_id in molecules: + mol = entropy_manager._get_molecule_container(reduced_atom, mol_id) + + resname = 
mol.atoms[0].resname + resid = mol.atoms[0].resid + segid = mol.atoms[0].segid + + mol_label = f"{resname}_{resid} (segid {segid})" + + for level in levels[mol_id]: + progress.update( + task, + title=f"Building conformational states | " + f"Molecule: {mol_label} | " + f"Level: {level}", + ) + + if level == "united_atom": + for res_id, residue in enumerate(mol.residues): + key = (group_id, res_id) + + res_container = ( + entropy_manager._run_manager.new_U_select_atom( + mol, + f"index {residue.atoms.indices[0]}:" + f"{residue.atoms.indices[-1]}", + ) + ) + heavy_res = ( + entropy_manager._run_manager.new_U_select_atom( + res_container, "not name H*" + ) + ) + states = self.compute_dihedral_conformations( + heavy_res, + level, + number_frames, + bin_width, + start, + end, + step, + ce, ) - ) - heavy_res = entropy_manager._run_manager.new_U_select_atom( - res_container, "not name H*" - ) + if key in states_ua: + states_ua[key].append(states) + else: + states_ua[key] = states + + elif level == "res": states = self.compute_dihedral_conformations( - heavy_res, + mol, level, number_frames, bin_width, @@ -1042,27 +1143,12 @@ def build_conformational_states( ce, ) - if key in states_ua.keys(): - states_ua[key].append(states) + if states_res[group_id] is None: + states_res[group_id] = states else: - states_ua[key] = states - - if level == "res": - states = self.compute_dihedral_conformations( - mol, - level, - number_frames, - bin_width, - start, - end, - step, - ce, - ) + states_res[group_id] += states - if states_res[group_id] is None: - states_res[group_id] = states - else: - states_res[group_id] += states + progress.advance(task) logger.debug(f"states_ua {states_ua}") logger.debug(f"states_res {states_res}") diff --git a/CodeEntropy/run.py b/CodeEntropy/run.py index 77ca20e..38f3039 100644 --- a/CodeEntropy/run.py +++ b/CodeEntropy/run.py @@ -3,8 +3,18 @@ import pickle import MDAnalysis as mda +import requests +import yaml +from art import text2art from 
MDAnalysis.analysis.base import AnalysisFromFunction from MDAnalysis.coordinates.memory import MemoryReader +from rich.align import Align +from rich.console import Group +from rich.padding import Padding +from rich.panel import Panel +from rich.rule import Rule +from rich.table import Table +from rich.text import Text from CodeEntropy.config.arg_config_manager import ConfigManager from CodeEntropy.config.data_logger import DataLogger @@ -14,6 +24,7 @@ from CodeEntropy.levels import LevelManager logger = logging.getLogger(__name__) +console = LoggingConfig.get_console() class RunManager: @@ -85,6 +96,114 @@ def create_job_folder(): # Return the path of the newly created folder return new_folder_path + def load_citation_data(self): + """ + Load CITATION.cff from GitHub into memory. + Return empty dict if offline. + """ + url = ( + "https://raw.githubusercontent.com/CCPBioSim/" + "CodeEntropy/refs/heads/main/CITATION.cff" + ) + try: + response = requests.get(url, timeout=10) + response.raise_for_status() + return yaml.safe_load(response.text) + except requests.exceptions.RequestException: + return None + + def show_splash(self): + """Render splash screen with optional citation metadata.""" + citation = self.load_citation_data() + + if citation: + # ASCII Title + ascii_title = text2art(citation.get("title", "CodeEntropy")) + ascii_render = Align.center(Text(ascii_title, style="bold white")) + + # Metadata + version = citation.get("version", "?") + release_date = citation.get("date-released", "?") + url = citation.get("url", citation.get("repository-code", "")) + + version_text = Align.center( + Text(f"Version {version} | Released {release_date}", style="green") + ) + url_text = Align.center(Text(url, style="blue underline")) + + # Description block + abstract = citation.get("abstract", "No description available.") + description_title = Align.center( + Text("Description", style="bold magenta underline") + ) + description_body = Align.center( + Padding(Text(abstract, 
style="white", justify="left"), (0, 4)) + ) + + # Contributors table + contributors_title = Align.center( + Text("Contributors", style="bold magenta underline") + ) + + author_table = Table( + show_header=True, header_style="bold yellow", box=None, pad_edge=False + ) + author_table.add_column("Name", style="bold", justify="center") + author_table.add_column("Affiliation", justify="center") + + for author in citation.get("authors", []): + name = ( + f"{author.get('given-names', '')} {author.get('family-names', '')}" + ).strip() + affiliation = author.get("affiliation", "") + author_table.add_row(name, affiliation) + + contributors_table = Align.center(Padding(author_table, (0, 4))) + + # Full layout + splash_content = Group( + ascii_render, + Rule(style="cyan"), + version_text, + url_text, + Text(), + description_title, + description_body, + Text(), + contributors_title, + contributors_table, + ) + else: + # ASCII Title + ascii_title = text2art("CodeEntropy") + ascii_render = Align.center(Text(ascii_title, style="bold white")) + + splash_content = Group( + ascii_render, + ) + + splash_panel = Panel( + splash_content, + title="[bold bright_cyan]Welcome to CodeEntropy", + title_align="center", + border_style="bright_cyan", + padding=(1, 4), + expand=True, + ) + + console.print(splash_panel) + + def print_args_table(self, args): + table = Table(title="Run Configuration", expand=True) + + table.add_column("Argument", style="cyan", no_wrap=True) + table.add_column("Value", style="magenta") + + for arg in vars(args): + table.add_row(arg, str(getattr(args, arg))) + + console.print(table) + def run_entropy_workflow(self): """ Runs the entropy analysis workflow by setting up logging, loading configuration @@ -94,6 +213,7 @@ def run_entropy_workflow(self): """ try: logger = self._logging_config.setup_logging() + self.show_splash() current_directory = os.getcwd() @@ -122,10 +242,7 @@ def run_entropy_workflow(self): if not getattr(args, "selection_string", None): raise 
ValueError("Missing 'selection_string' argument.") - # Log all inputs for the current run - logger.info(f"All input for {run_name}") - for arg in vars(args): - logger.info(f" {arg}: {getattr(args, arg)}") + self.print_args_table(args) # Load MDAnalysis Universe tprfile = args.top_traj_file[0] @@ -153,6 +270,8 @@ def run_entropy_workflow(self): entropy_manager.execute() + self._logging_config.save_console_log() + except Exception as e: logger.error(f"RunManager encountered an error: {e}", exc_info=True) raise diff --git a/pyproject.toml b/pyproject.toml index a0972a6..8d02c51 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,8 +38,10 @@ dependencies = [ "psutil==5.9.5", "PyYAML==6.0.2", "python-json-logger==3.3.0", - "tabulate==0.9.0", - "waterEntropy==1.2.0" + "rich==14.0.0", + "art==6.5", + "waterEntropy==1.2.0", + "requests>=2.32.5", ] [project.urls] diff --git a/tests/test_CodeEntropy/test_data_logger.py b/tests/test_CodeEntropy/test_data_logger.py index 782b0b6..db0db9d 100644 --- a/tests/test_CodeEntropy/test_data_logger.py +++ b/tests/test_CodeEntropy/test_data_logger.py @@ -3,11 +3,12 @@ import shutil import tempfile import unittest -from unittest.mock import patch +import numpy as np import pandas as pd from CodeEntropy.config.data_logger import DataLogger +from CodeEntropy.config.logging_config import LoggingConfig from CodeEntropy.main import main @@ -57,7 +58,7 @@ def test_add_results_data(self): ) self.assertEqual( self.logger.molecule_data, - [("0", "united_atom", "Transvibrational", 653.4041220313459)], + [(0, "united_atom", "Transvibrational", 653.4041220313459)], ) def test_add_residue_data(self): @@ -65,11 +66,24 @@ def test_add_residue_data(self): Test that add_residue_data correctly appends a residue-level entry. 
""" self.logger.add_residue_data( - 0, "DA", "united_atom", "Transvibrational", 122.61216935211893 + 0, "DA", "united_atom", "Transvibrational", 10, 122.61216935211893 ) self.assertEqual( self.logger.residue_data, - [[0, "DA", "united_atom", "Transvibrational", 122.61216935211893]], + [[0, "DA", "united_atom", "Transvibrational", 10, 122.61216935211893]], + ) + + def test_add_residue_data_with_numpy_array(self): + """ + Test that add_residue_data correctly converts a NumPy array to a list. + """ + frame_array = np.array([10]) + self.logger.add_residue_data( + 1, "DT", "united_atom", "Transvibrational", frame_array, 98.123456789 + ) + self.assertEqual( + self.logger.residue_data, + [[1, "DT", "united_atom", "Transvibrational", [10], 98.123456789]], ) def test_save_dataframes_as_json(self): @@ -120,24 +134,24 @@ def test_save_dataframes_as_json(self): self.assertEqual(data["molecule_data"][0]["Type"], "Transvibrational (J/mol/K)") self.assertEqual(data["residue_data"][0]["Residue"], 0) - @patch("CodeEntropy.config.data_logger.logger") - def test_log_tables(self, mock_logger): - """ - Test that log_tables logs formatted molecule and residue tables using the - logger. 
- """ + def test_log_tables_rich_output(self): + console = LoggingConfig.get_console() + console.clear_live() + self.logger.add_results_data( 0, "united_atom", "Transvibrational", 653.4041220313459 ) self.logger.add_residue_data( - 0, "DA", "united_atom", "Transvibrational", 122.61216935211893 + 0, "DA", "united_atom", "Transvibrational", 10, 122.61216935211893 ) + self.logger.add_group_label(0, "DA", 10, 100) self.logger.log_tables() - calls = [call[0][0] for call in mock_logger.info.call_args_list] - self.assertTrue(any("Molecule Data Table:" in c for c in calls)) - self.assertTrue(any("Residue Data Table:" in c for c in calls)) + output = console.export_text() + assert "Molecule Entropy Results" in output + assert "Residue Entropy Results" in output + assert "Group ID to Residue Label Mapping" in output if __name__ == "__main__": diff --git a/tests/test_CodeEntropy/test_entropy.py b/tests/test_CodeEntropy/test_entropy.py index a5b6455..f7af3bb 100644 --- a/tests/test_CodeEntropy/test_entropy.py +++ b/tests/test_CodeEntropy/test_entropy.py @@ -3,7 +3,7 @@ import shutil import tempfile import unittest -from unittest.mock import MagicMock, call, patch +from unittest.mock import MagicMock, PropertyMock, call, patch import MDAnalysis as mda import numpy as np @@ -48,7 +48,7 @@ def tearDown(self): shutil.rmtree(self.test_dir) def test_execute_full_workflow(self): - # Setup universe and args as before + # Setup universe and args tprfile = os.path.join(self.test_data_dir, "md_A4_dna.tpr") trrfile = os.path.join(self.test_data_dir, "md_A4_dna_xf.trr") u = mda.Universe(tprfile, trrfile) @@ -83,7 +83,7 @@ def test_execute_full_workflow(self): return_value=(mock_reduced_atom, 3, mock_levels, mock_groups) ) entropy_manager._level_manager.build_covariance_matrices = MagicMock( - return_value=("force_matrices", "torque_matrices") + return_value=("force_matrices", "torque_matrices", "frame_counts") ) entropy_manager._level_manager.build_conformational_states = MagicMock( 
return_value=(["state_ua"], ["state_res"]) @@ -96,17 +96,23 @@ def test_execute_full_workflow(self): ve = MagicMock() ce = MagicMock() + # Patch both VibrationalEntropy, ConformationalEntropy AND u.atoms.fragments + mock_molecule = MagicMock() + mock_molecule.residues = [] + with ( patch("CodeEntropy.entropy.VibrationalEntropy", return_value=ve), patch("CodeEntropy.entropy.ConformationalEntropy", return_value=ce), + patch.object( + type(u.atoms), "fragments", new_callable=PropertyMock + ) as mock_fragments, ): - + mock_fragments.return_value = [mock_molecule] * 10 entropy_manager.execute() # Assert the key calls happened with expected arguments - ( - entropy_manager._level_manager.build_conformational_states - ).assert_called_once_with( + build_states = entropy_manager._level_manager.build_conformational_states + build_states.assert_called_once_with( entropy_manager, mock_reduced_atom, mock_levels, @@ -127,6 +133,7 @@ def test_execute_full_workflow(self): "torque_matrices", ["state_ua"], ["state_res"], + "frame_counts", 11, ve, ce, @@ -135,74 +142,179 @@ def test_execute_full_workflow(self): entropy_manager._finalize_molecule_results.assert_called_once() entropy_manager._data_logger.log_tables.assert_called_once() - def test_water_entropy_sets_selection_string_when_all(self): + def test_execute_triggers_handle_water_entropy_minimal(self): """ - Tests that when `selection_string` is initially 'all' and water entropy is - enabled, `_handle_water_entropy` sets `selection_string` to 'not water' after - calculating water entropy. + Minimal test to ensure _handle_water_entropy line is executed. 
""" - mock_universe = MagicMock() - mock_universe.select_atoms.return_value.n_atoms = 5 # Simulate water present + tprfile = os.path.join(self.test_data_dir, "md_A4_dna.tpr") + trrfile = os.path.join(self.test_data_dir, "md_A4_dna_xf.trr") + u = mda.Universe(tprfile, trrfile) - args = MagicMock(water_entropy=True, selection_string="all") - run_manager = MagicMock() - level_manager = MagicMock() + args = MagicMock( + bin_width=0.1, temperature=300, selection_string="all", water_entropy=True + ) + run_manager = RunManager("temp_folder") + level_manager = LevelManager() data_logger = DataLogger() group_molecules = MagicMock() + entropy_manager = EntropyManager( + run_manager, args, u, data_logger, level_manager, group_molecules + ) + entropy_manager._get_trajectory_bounds = MagicMock(return_value=(0, 10, 1)) + entropy_manager._get_number_frames = MagicMock(return_value=11) + entropy_manager._initialize_molecules = MagicMock( + return_value=(MagicMock(), 3, {}, {0: [0]}) + ) + entropy_manager._level_manager.build_covariance_matrices = MagicMock( + return_value=("force_matrices", "torque_matrices", "frame_counts") + ) + entropy_manager._level_manager.build_conformational_states = MagicMock( + return_value=(["state_ua"], ["state_res"]) + ) + entropy_manager._compute_entropies = MagicMock() + entropy_manager._finalize_molecule_results = MagicMock() + entropy_manager._data_logger.log_tables = MagicMock() + + with ( + patch("CodeEntropy.entropy.VibrationalEntropy", return_value=MagicMock()), + patch( + "CodeEntropy.entropy.ConformationalEntropy", return_value=MagicMock() + ), + patch.object( + type(u.atoms), "fragments", new_callable=PropertyMock + ) as mock_fragments, + patch.object(u, "select_atoms") as mock_select_atoms, + patch.object( + entropy_manager, "_handle_water_entropy" + ) as mock_handle_water_entropy, + ): + mock_fragments.return_value = [MagicMock(residues=[MagicMock(resid=1)])] + mock_select_atoms.return_value = MagicMock(residues=[MagicMock(resid=1)]) + + 
entropy_manager.execute() + + mock_handle_water_entropy.assert_called_once() + + def test_water_entropy_sets_selection_string_when_all(self): + """ + If selection_string is 'all' and water entropy is enabled, + _handle_water_entropy should update it to 'not water'. + """ + mock_universe = MagicMock() + args = MagicMock(water_entropy=True, selection_string="all") manager = EntropyManager( - run_manager, - args, - mock_universe, - data_logger, - level_manager, - group_molecules, + MagicMock(), args, mock_universe, DataLogger(), MagicMock(), MagicMock() ) - # Patch water entropy calculation manager._calculate_water_entropy = MagicMock() + manager._data_logger.add_group_label = MagicMock() - # Call _handle_water_entropy directly - manager._handle_water_entropy(0, 10, 1) + water_groups = {0: [0, 1, 2]} - manager._calculate_water_entropy.assert_called_once_with( - mock_universe, 0, 10, 1 - ) - self.assertEqual(args.selection_string, "not water") + manager._handle_water_entropy(0, 10, 1, water_groups) + + assert manager._args.selection_string == "not water" + manager._calculate_water_entropy.assert_called_once() def test_water_entropy_appends_to_custom_selection_string(self): """ - Tests that when `selection_string` is a custom value and water - entropy is enabled, `_handle_water_entropy` appends ' and not water' - to the existing selection string. + If selection_string is custom and water entropy is enabled, + _handle_water_entropy appends ' and not water'. 
""" mock_universe = MagicMock() - mock_universe.select_atoms.return_value.n_atoms = 5 # Simulate water present + args = MagicMock(water_entropy=True, selection_string="protein") + manager = EntropyManager( + MagicMock(), args, mock_universe, DataLogger(), MagicMock(), MagicMock() + ) + manager._calculate_water_entropy = MagicMock() + manager._data_logger.add_group_label = MagicMock() + + water_groups = {0: [0, 1, 2]} + + manager._handle_water_entropy(0, 10, 1, water_groups) + + manager._calculate_water_entropy.assert_called_once() + assert args.selection_string == "protein and not water" + + def test_handle_water_entropy_returns_early(self): + """ + Verifies that _handle_water_entropy returns immediately if: + 1. water_groups is empty + 2. water_entropy is disabled + """ + mock_universe = MagicMock() args = MagicMock(water_entropy=True, selection_string="protein") - run_manager = MagicMock() - level_manager = MagicMock() + manager = EntropyManager( + MagicMock(), args, mock_universe, DataLogger(), MagicMock(), MagicMock() + ) + + # Patch _calculate_water_entropy to track if called + manager._calculate_water_entropy = MagicMock() + + # Case 1: empty water_groups + result = manager._handle_water_entropy(0, 10, 1, {}) + assert result is None + manager._calculate_water_entropy.assert_not_called() + + # Case 2: water_entropy disabled + manager._args.water_entropy = False + result = manager._handle_water_entropy(0, 10, 1, {0: [0, 1, 2]}) + assert result is None + manager._calculate_water_entropy.assert_not_called() + + def test_initialize_molecules(self): + """ + Test _initialize_molecules returns expected tuple by mocking internal methods. + + - Ensures _get_reduced_universe is called and its return is used. + - Ensures _level_manager.select_levels is called with the reduced atom + selection. + - Ensures _group_molecules.grouping_molecules is called with the reduced atom + and grouping arg. + - Verifies the returned tuple matches the mocked values. 
+ """ + + args = MagicMock( + bin_width=0.1, temperature=300, selection_string="all", water_entropy=False + ) + run_manager = RunManager("temp_folder") + level_manager = LevelManager() data_logger = DataLogger() group_molecules = MagicMock() - manager = EntropyManager( - run_manager, - args, - mock_universe, - data_logger, - level_manager, - group_molecules, + run_manager, args, MagicMock(), data_logger, level_manager, group_molecules ) - manager._calculate_water_entropy = MagicMock() + # Mock dependencies + manager._get_reduced_universe = MagicMock(return_value="mock_reduced_atom") + manager._level_manager = MagicMock() + manager._level_manager.select_levels = MagicMock( + return_value=(5, ["level1", "level2"]) + ) + manager._group_molecules = MagicMock() + manager._group_molecules.grouping_molecules = MagicMock( + return_value=["groupA", "groupB"] + ) + manager._args = MagicMock() + manager._args.grouping = "custom_grouping" - # Call _handle_water_entropy directly - manager._handle_water_entropy(0, 10, 1) + # Call the method under test + result = manager._initialize_molecules() - manager._calculate_water_entropy.assert_called_once_with( - mock_universe, 0, 10, 1 + # Assert calls + manager._get_reduced_universe.assert_called_once() + manager._level_manager.select_levels.assert_called_once_with( + "mock_reduced_atom" ) - self.assertEqual(args.selection_string, "protein and not water") + manager._group_molecules.grouping_molecules.assert_called_once_with( + "mock_reduced_atom", "custom_grouping" + ) + + # Assert return value + expected = ("mock_reduced_atom", 5, ["level1", "level2"], ["groupA", "groupB"]) + self.assertEqual(result, expected) def test_get_trajectory_bounds(self): """ @@ -234,27 +346,29 @@ def test_get_trajectory_bounds(self): ) def test_get_number_frames(self, mock_args): """ - Test `_get_number_frames` when the end index is -1 (interpreted as no slicing). + Test `_get_number_frames` when the end index is -1. 
- Ensures that the function returns 0 frames when the trajectory bounds - result in an empty range. + Ensures that the function correctly counts all frames from start to + the end of the trajectory. """ config_manager = ConfigManager() - parser = config_manager.setup_argparse() args = parser.parse_args() + # Mock universe with a trajectory of 10 frames + mock_universe = MagicMock() + mock_universe.trajectory = range(10) + entropy_manager = EntropyManager( - MagicMock(), args, MagicMock(), MagicMock(), MagicMock(), MagicMock() - ) - entropy_manager._get_trajectory_bounds() - number_frames = entropy_manager._get_number_frames( - entropy_manager._args.start, - entropy_manager._args.end, - entropy_manager._args.step, + MagicMock(), args, mock_universe, MagicMock(), MagicMock(), MagicMock() ) - self.assertEqual(number_frames, 0) + # Use _get_trajectory_bounds to convert end=-1 into the actual last frame + start, end, step = entropy_manager._get_trajectory_bounds() + number_frames = entropy_manager._get_number_frames(start, end, step) + + # Expect all frames to be counted + self.assertEqual(number_frames, 10) @patch( "argparse.ArgumentParser.parse_args", @@ -272,21 +386,21 @@ def test_get_number_frames_sliced_trajectory(self, mock_args): when slicing from 0 to 20 with a step of 1, expecting 21 frames. 
""" config_manager = ConfigManager() - parser = config_manager.setup_argparse() args = parser.parse_args() + # Mock universe with 30 frames + mock_universe = MagicMock() + mock_universe.trajectory = range(30) + entropy_manager = EntropyManager( - MagicMock(), args, MagicMock(), MagicMock(), MagicMock(), MagicMock() - ) - entropy_manager._get_trajectory_bounds() - number_frames = entropy_manager._get_number_frames( - entropy_manager._args.start, - entropy_manager._args.end, - entropy_manager._args.step, + MagicMock(), args, mock_universe, MagicMock(), MagicMock(), MagicMock() ) - self.assertEqual(number_frames, 21) + start, end, step = entropy_manager._get_trajectory_bounds() + number_frames = entropy_manager._get_number_frames(start, end, step) + + self.assertEqual(number_frames, 20) @patch( "argparse.ArgumentParser.parse_args", @@ -298,59 +412,59 @@ def test_get_number_frames_sliced_trajectory(self, mock_args): ) def test_get_number_frames_sliced_trajectory_step(self, mock_args): """ - Test `_get_number_frames` with a step that skips all frames. + Test `_get_number_frames` with a step that skips frames. - Ensures that the function returns 0 when the step size results in - no frames being selected from the trajectory. + Ensures that the function correctly counts the number of frames + when a step size of 5 is applied. 
""" - config_manager = ConfigManager() - parser = config_manager.setup_argparse() args = parser.parse_args() + # Mock universe with 20 frames + mock_universe = MagicMock() + mock_universe.trajectory = range(20) + entropy_manager = EntropyManager( - MagicMock(), args, MagicMock(), MagicMock(), MagicMock(), MagicMock() - ) - entropy_manager._get_trajectory_bounds() - number_frames = entropy_manager._get_number_frames( - entropy_manager._args.start, - entropy_manager._args.end, - entropy_manager._args.step, + MagicMock(), args, mock_universe, MagicMock(), MagicMock(), MagicMock() ) - self.assertEqual(number_frames, 0) + start, end, step = entropy_manager._get_trajectory_bounds() + number_frames = entropy_manager._get_number_frames(start, end, step) - @patch( - "argparse.ArgumentParser.parse_args", - return_value=MagicMock( - selection_string="all", - ), - ) - def test_get_reduced_universe_all(self, mock_args): - """ - Test `_get_reduced_universe` with 'all' selection. + # Expect 20 frames divided by step of 5 = 4 frames + self.assertEqual(number_frames, 4) - Verifies that the full universe is returned when the selection string - is set to 'all', and the number of atoms remains unchanged. - """ - # Load MDAnalysis Universe - tprfile = os.path.join(self.test_data_dir, "md_A4_dna.tpr") - trrfile = os.path.join(self.test_data_dir, "md_A4_dna_xf.trr") - u = mda.Universe(tprfile, trrfile) + @patch( + "argparse.ArgumentParser.parse_args", + return_value=MagicMock( + selection_string="all", + ), + ) + def test_get_reduced_universe_all(self, mock_args): + """ + Test `_get_reduced_universe` with 'all' selection. - config_manager = ConfigManager() + Verifies that the full universe is returned when the selection string + is set to 'all', and the number of atoms remains unchanged. 
+ """ + # Load MDAnalysis Universe + tprfile = os.path.join(self.test_data_dir, "md_A4_dna.tpr") + trrfile = os.path.join(self.test_data_dir, "md_A4_dna_xf.trr") + u = mda.Universe(tprfile, trrfile) - parser = config_manager.setup_argparse() - args = parser.parse_args() + config_manager = ConfigManager() - entropy_manager = EntropyManager( - MagicMock(), args, u, MagicMock(), MagicMock(), MagicMock() - ) + parser = config_manager.setup_argparse() + args = parser.parse_args() + + entropy_manager = EntropyManager( + MagicMock(), args, u, MagicMock(), MagicMock(), MagicMock() + ) - entropy_manager._get_reduced_universe() + entropy_manager._get_reduced_universe() - self.assertEqual(entropy_manager._universe.atoms.n_atoms, 254) + self.assertEqual(entropy_manager._universe.atoms.n_atoms, 254) @patch( "argparse.ArgumentParser.parse_args", @@ -433,31 +547,27 @@ def test_get_molecule_container(self, mock_args): assert set(selected_indices) == set(expected_indices) assert len(mol_universe.atoms) == len(original_fragment) - def test_process_united_atom_level(self): + def test_process_united_atom_entropy(self): """ Tests that `_process_united_atom_entropy` correctly logs global and - residue-level entropy results for a known molecular system using MDAnalysis. + residue-level entropy results for a mocked molecular system. 
""" - - # Load a known test universe - tprfile = os.path.join(self.test_data_dir, "md_A4_dna.tpr") - trrfile = os.path.join(self.test_data_dir, "md_A4_dna_xf.trr") - u = mda.Universe(tprfile, trrfile) - # Setup managers and arguments args = MagicMock(bin_width=0.1, temperature=300, selection_string="all") - run_manager = RunManager("temp_folder") - level_manager = LevelManager() + run_manager = MagicMock() + level_manager = MagicMock() data_logger = DataLogger() group_molecules = MagicMock() manager = EntropyManager( - run_manager, args, u, data_logger, level_manager, group_molecules + run_manager, args, MagicMock(), data_logger, level_manager, group_molecules ) - # Prepare mock molecule container - reduced_atom = manager._get_reduced_universe() - mol_container = manager._get_molecule_container(reduced_atom, 0) - n_residues = len(mol_container.residues) + # Mock molecule container with residues and atoms + n_residues = 3 + mock_residues = [MagicMock(resname=f"RES{i}") for i in range(n_residues)] + mock_atoms_per_mol = 3 + mock_atoms = [MagicMock() for _ in range(mock_atoms_per_mol)] # per molecule + mol_container = MagicMock(residues=mock_residues, atoms=mock_atoms) # Create dummy matrices and states force_matrix = {(0, i): np.eye(3) for i in range(n_residues)} @@ -472,6 +582,14 @@ def test_process_united_atom_level(self): ) ce.conformational_entropy_calculation.return_value = 3.0 + # Manually add the group label so group_id=0 exists + data_logger.add_group_label( + 0, + "_".join(f"RES{i}" for i in range(n_residues)), # label + n_residues, # residue_count + len(mock_atoms) * n_residues, # total atoms for the group + ) + # Run the method manager._process_united_atom_entropy( group_id=0, @@ -484,24 +602,30 @@ def test_process_united_atom_level(self): states=states, highest=True, number_frames=10, + frame_counts={(0, i): 10 for i in range(n_residues)}, ) # Check molecule-level results df = data_logger.molecule_data - self.assertEqual(len(df), 3) # Trans, Rot, Conf + 
assert len(df) == 3 # Trans, Rot, Conf # Check residue-level results residue_df = data_logger.residue_data - self.assertEqual(len(residue_df), 3 * n_residues) # 3 types per residue + assert len(residue_df) == 3 * n_residues # 3 types per residue # Check that all expected types are present expected_types = {"Transvibrational", "Rovibrational", "Conformational"} - actual_types = set(entry[2] for entry in df) - self.assertSetEqual(actual_types, expected_types) + assert actual_types == expected_types residue_types = set(entry[3] for entry in residue_df) - self.assertSetEqual(residue_types, expected_types) + assert residue_types == expected_types + + # Check group label logging + group_label = data_logger.group_labels[0] # Access by group_id key + assert group_label["label"] == "_".join(f"RES{i}" for i in range(n_residues)) + assert group_label["residue_count"] == n_residues + assert group_label["atom_count"] == len(mock_atoms) * n_residues def test_process_vibrational_only_levels(self): """ @@ -541,6 +665,7 @@ def test_process_vibrational_only_levels(self): # Run the method manager._process_vibrational_entropy( group_id=0, + mol_container=mol_container, number_frames=10, ve=ve, level="Vibrational", @@ -584,6 +709,7 @@ def test_compute_entropies_polymer_branch(self): torque_matrices = {"poly": {0: np.eye(3) * 2}} states_ua = {} states_res = [] + frame_counts = 10 mol_mock = MagicMock() mol_mock.residues = [] @@ -604,6 +730,7 @@ def test_compute_entropies_polymer_branch(self): torque_matrices, states_ua, states_res, + frame_counts, number_frames, ve, ce, @@ -642,6 +769,7 @@ def test_process_conformational_residue_level(self): # Run the method manager._process_conformational_entropy( group_id=0, + mol_container=MagicMock(), ce=ce, level="residue", states=states, @@ -659,10 +787,185 @@ def test_process_conformational_residue_level(self): results = [entry[3] for entry in df] self.assertIn(3.33, results) + def test_compute_entropies_united_atom(self): + """ + Test that 
_process_united_atom_entropy is called correctly for 'united_atom' + level with highest=False when it's the only level. + """ + args = MagicMock(bin_width=0.1) + run_manager = MagicMock() + level_manager = MagicMock() + data_logger = DataLogger() + group_molecules = MagicMock() + manager = EntropyManager( + run_manager, args, MagicMock(), data_logger, level_manager, group_molecules + ) + + reduced_atom = MagicMock() + number_frames = 10 + groups = {0: [0]} + levels = [["united_atom"]] # single level + + force_matrices = {"ua": {0: "force_ua"}} + torque_matrices = {"ua": {0: "torque_ua"}} + states_ua = {} + states_res = [] + frame_counts = {"ua": {(0, 0): 10}} + + mol_mock = MagicMock() + mol_mock.residues = [] + manager._get_molecule_container = MagicMock(return_value=mol_mock) + manager._process_united_atom_entropy = MagicMock() + + ve = MagicMock() + ce = MagicMock() + + manager._compute_entropies( + reduced_atom, + levels, + groups, + force_matrices, + torque_matrices, + states_ua, + states_res, + frame_counts, + number_frames, + ve, + ce, + ) + + manager._process_united_atom_entropy.assert_called_once_with( + 0, + mol_mock, + ve, + ce, + "united_atom", + force_matrices["ua"], + torque_matrices["ua"], + states_ua, + frame_counts["ua"], + True, # highest is True since only level + number_frames, + ) + + def test_compute_entropies_residue(self): + """ + Test that _process_vibrational_entropy and _process_conformational_entropy + are called correctly for 'residue' level with highest=True when it's the + only level. 
+ """ + # Setup + args = MagicMock(bin_width=0.1) + run_manager = MagicMock() + level_manager = MagicMock() + data_logger = DataLogger() + group_molecules = MagicMock() + manager = EntropyManager( + run_manager, args, MagicMock(), data_logger, level_manager, group_molecules + ) + + reduced_atom = MagicMock() + number_frames = 10 + groups = {0: [0]} + levels = [["residue"]] # single level + + force_matrices = {"res": {0: "force_res"}} + torque_matrices = {"res": {0: "torque_res"}} + states_ua = {} + states_res = ["states_res"] + + # Frame counts for residue level + frame_counts = {"res": {(0, 0): 10}} + + # Mock molecule + mol_mock = MagicMock() + mol_mock.residues = [] + manager._get_molecule_container = MagicMock(return_value=mol_mock) + manager._process_vibrational_entropy = MagicMock() + manager._process_conformational_entropy = MagicMock() + + # Mock entropy calculators + ve = MagicMock() + ce = MagicMock() + + # Call the method under test + manager._compute_entropies( + reduced_atom, + levels, + groups, + force_matrices, + torque_matrices, + states_ua, + states_res, + frame_counts, + number_frames, + ve, + ce, + ) + + # Assert that the per-level processing methods were called + manager._process_vibrational_entropy.assert_called() + manager._process_conformational_entropy.assert_called() + + def test_compute_entropies_polymer(self): + args = MagicMock(bin_width=0.1) + run_manager = MagicMock() + level_manager = MagicMock() + data_logger = DataLogger() + group_molecules = MagicMock() + manager = EntropyManager( + run_manager, args, MagicMock(), data_logger, level_manager, group_molecules + ) + + reduced_atom = MagicMock() + number_frames = 10 + groups = {0: [0]} + levels = [["polymer"]] + + force_matrices = {"poly": {0: "force_poly"}} + torque_matrices = {"poly": {0: "torque_poly"}} + states_ua = {} + states_res = [] + + frame_counts = {"poly": {(0, 0): 10}} + + mol_mock = MagicMock() + mol_mock.residues = [] + manager._get_molecule_container = 
MagicMock(return_value=mol_mock) + manager._process_vibrational_entropy = MagicMock() + + ve = MagicMock() + ce = MagicMock() + + manager._compute_entropies( + reduced_atom, + levels, + groups, + force_matrices, + torque_matrices, + states_ua, + states_res, + frame_counts, + number_frames, + ve, + ce, + ) + + manager._process_vibrational_entropy.assert_called_once_with( + 0, + mol_mock, + number_frames, + ve, + "polymer", + force_matrices["poly"][0], + torque_matrices["poly"][0], + True, + ) + def test_finalize_molecule_results_aggregates_and_logs_total_entropy(self): """ Tests that `_finalize_molecule_results` correctly aggregates entropy values per - molecule from `molecule_data`, appends a 'Molecule Total' entry, and calls + molecule from `molecule_data`, appends a 'Group Total' entry, and calls `save_dataframes_as_json` with the expected DataFrame structure. """ # Setup @@ -686,7 +989,7 @@ def test_finalize_molecule_results_aggregates_and_logs_total_entropy(self): # Check that totals were added totals = [ - entry for entry in data_logger.molecule_data if entry[1] == "Molecule Total" + entry for entry in data_logger.molecule_data if entry[1] == "Group Total" ] self.assertEqual(len(totals), 2) @@ -729,7 +1032,7 @@ def test_finalize_molecule_results_skips_invalid_entries(self, mock_logger): # Check that only valid values were aggregated totals = [ - entry for entry in data_logger.molecule_data if entry[1] == "Molecule Total" + entry for entry in data_logger.molecule_data if entry[1] == "Group Total" ] self.assertEqual(len(totals), 1) self.assertEqual(totals[0][3], 3.0) # 1.0 + 2.0 @@ -1014,162 +1317,206 @@ def test_vibrational_entropy_polymer_torque(self): def test_calculate_water_orientational_entropy(self): """ Test that orientational entropy values are correctly extracted from Sorient_dict - and logged using add_residue_data. + and logged per residue. 
""" Sorient_dict = {1: {"mol1": [1.0, 2]}, 2: {"mol1": [3.0, 4]}} + group_id = 0 + + self.entropy_manager._data_logger = MagicMock() - self.entropy_manager._calculate_water_orientational_entropy(Sorient_dict) + self.entropy_manager._calculate_water_orientational_entropy( + Sorient_dict, group_id + ) + + expected_calls = [ + call(group_id, "mol1", "Water", "Orientational", 2, 1.0), + call(group_id, "mol1", "Water", "Orientational", 4, 3.0), + ] self.entropy_manager._data_logger.add_residue_data.assert_has_calls( - [ - call(1, "mol1", "Water", "Orientational", 1.0), - call(2, "mol1", "Water", "Orientational", 3.0), - ] + expected_calls, any_order=False ) + assert self.entropy_manager._data_logger.add_residue_data.call_count == 2 def test_calculate_water_vibrational_translational_entropy(self): - """ - Test that translational vibrational entropy values are correctly summed - and logged per residue using add_residue_data. Also verifies that the - molecule-level average is computed and logged using _log_result. 
- """ mock_vibrations = MagicMock() mock_vibrations.translational_S = { - ("res1", "mol1"): [1.0, 2.0], - ("resB_invalid", "mol1"): 4.0, - ("res2", "mol1"): 3.0, + ("res1", 10): [1.0, 2.0], + ("resB_invalid", 10): 4.0, + ("res2", 10): 3.0, } + mock_covariances = MagicMock() + mock_covariances.counts = { + ("res1", "WAT"): 10, + # resB_invalid and res2 will use default count = 1 + } + + group_id = 0 + self.entropy_manager._data_logger = MagicMock() self.entropy_manager._calculate_water_vibrational_translational_entropy( - mock_vibrations + mock_vibrations, group_id, mock_covariances ) + expected_calls = [ + call(group_id, "res1", "Water", "Transvibrational", 10, 3.0), + call(group_id, "resB", "Water", "Transvibrational", 1, 4.0), + call(group_id, "res2", "Water", "Transvibrational", 1, 3.0), + ] + self.entropy_manager._data_logger.add_residue_data.assert_has_calls( - [ - call(-1, "res1", "Water", "Transvibrational", 3.0), - call(-1, "resB", "Water", "Transvibrational", 4.0), - call(-1, "res2", "Water", "Transvibrational", 3.0), - ] + expected_calls, any_order=False ) + assert self.entropy_manager._data_logger.add_residue_data.call_count == 3 - def test_empty_vibrational_entropy_dicts(self): - """ - Test that no logging occurs when both translational and rotational - entropy dictionaries are empty. Ensures that the methods handle empty - input gracefully without errors or unnecessary logging. 
- """ - self.entropy_manager._log_residue_data = MagicMock() - self.entropy_manager._log_result = MagicMock() - + def test_calculate_water_vibrational_rotational_entropy(self): mock_vibrations = MagicMock() - mock_vibrations.translational_S = {} - mock_vibrations.rotational_S = {} + mock_vibrations.rotational_S = { + ("resA_101", 14): [2.0, 3.0], + ("resB_invalid", 14): 4.0, + ("resC", 14): 5.0, + } + mock_covariances = MagicMock() + mock_covariances.counts = {("resA_101", "WAT"): 14} + + group_id = 0 + self.entropy_manager._data_logger = MagicMock() - self.entropy_manager._calculate_water_vibrational_translational_entropy( - mock_vibrations - ) self.entropy_manager._calculate_water_vibrational_rotational_entropy( - mock_vibrations + mock_vibrations, group_id, mock_covariances ) - self.entropy_manager._log_residue_data.assert_not_called() - self.entropy_manager._log_result.assert_not_called() + expected_calls = [ + call(group_id, "resA", "Water", "Rovibrational", 14, 5.0), + call(group_id, "resB", "Water", "Rovibrational", 1, 4.0), + call(group_id, "resC", "Water", "Rovibrational", 1, 5.0), + ] - def test_calculate_water_vibrational_rotational_entropy(self): - """ - Test that rotational vibrational entropy values are correctly summed - and logged per residue using add_residue_data. Also verifies that the - residue ID parsing handles both valid and invalid formats. 
- """ + self.entropy_manager._data_logger.add_residue_data.assert_has_calls( + expected_calls, any_order=False + ) + assert self.entropy_manager._data_logger.add_residue_data.call_count == 3 + + def test_empty_vibrational_entropy_dicts(self): mock_vibrations = MagicMock() - mock_vibrations.rotational_S = { - ("resA_101", "mol1"): [2.0, 3.0], - ("resB_invalid", "mol1"): 4.0, - ("resC", "mol1"): 5.0, - } + mock_vibrations.translational_S = {} + mock_vibrations.rotational_S = {} + + group_id = 0 + mock_covariances = MagicMock() + mock_covariances.counts = {} + + self.entropy_manager._data_logger = MagicMock() + self.entropy_manager._calculate_water_vibrational_translational_entropy( + mock_vibrations, group_id, mock_covariances + ) self.entropy_manager._calculate_water_vibrational_rotational_entropy( - mock_vibrations + mock_vibrations, group_id, mock_covariances ) - self.entropy_manager._data_logger.add_residue_data.assert_has_calls( - [ - call(101, "resA", "Water", "Rovibrational", 5.0), - call(-1, "resB", "Water", "Rovibrational", 4.0), - call(-1, "resC", "Water", "Rovibrational", 5.0), - ] - ) + self.entropy_manager._data_logger.add_residue_data.assert_not_called() @patch( "waterEntropy.recipes.interfacial_solvent.get_interfacial_water_orient_entropy" ) def test_calculate_water_entropy(self, mock_get_entropy): - """ - Integration-style test that verifies _calculate_water_entropy correctly - delegates to the orientational and vibrational entropy methods and logs - the expected values. 
- """ mock_vibrations = MagicMock() mock_vibrations.translational_S = {("res1", "mol1"): 2.0} mock_vibrations.rotational_S = {("res1", "mol1"): 3.0} mock_get_entropy.return_value = ( - {1: {"mol1": [1.0, 5]}}, - None, + {1: {"mol1": [1.0, 5]}}, # orientational + MagicMock(counts={("res1", "WAT"): 1}), mock_vibrations, None, - None, + 1, ) mock_universe = MagicMock() - self.entropy_manager._calculate_water_entropy(mock_universe, 0, 10, 1) + self.entropy_manager._data_logger = MagicMock() + + self.entropy_manager._calculate_water_entropy(mock_universe, 0, 10, 5) + + expected_calls = [ + call(None, "mol1", "Water", "Orientational", 5, 1.0), + call(None, "res1", "Water", "Transvibrational", 1, 2.0), + call(None, "res1", "Water", "Rovibrational", 1, 3.0), + ] self.entropy_manager._data_logger.add_residue_data.assert_has_calls( - [ - call(1, "mol1", "Water", "Orientational", 1.0), - call(-1, "res1", "Water", "Transvibrational", 2.0), - call(-1, "res1", "Water", "Rovibrational", 3.0), - ] + expected_calls, any_order=False ) + assert self.entropy_manager._data_logger.add_residue_data.call_count == 3 @patch( "waterEntropy.recipes.interfacial_solvent.get_interfacial_water_orient_entropy" ) def test_calculate_water_entropy_minimal(self, mock_get_entropy): - """ - Verifies that _calculate_water_entropy correctly logs entropy components - and total for a single molecule with minimal data. 
- """ + mock_vibrations = MagicMock() + mock_vibrations.translational_S = {("ACE_1", "WAT"): 10.0} + mock_vibrations.rotational_S = {("ACE_1", "WAT"): 2.0} + mock_get_entropy.return_value = ( - {}, - None, - MagicMock( - translational_S={("ACE_1", "WAT"): 10.0}, - rotational_S={("ACE_1", "WAT"): 2.0}, - ), - None, + {}, # no orientational entropy + MagicMock(counts={("ACE_1", "WAT"): 1}), + mock_vibrations, None, + 1, ) - # Simulate residue-level results already collected - self.entropy_manager._data_logger.residue_data = [ - [1, "ACE", "Water", "Orientational", 5.0], - [1, "ACE_1", "Water", "Transvibrational", 10.0], - [1, "ACE_1", "Water", "Rovibrational", 2.0], - ] + mock_logger = MagicMock() + self.entropy_manager._data_logger = mock_logger + mock_residue = MagicMock(resnames=["WAT"]) + mock_selection = MagicMock(residues=mock_residue, atoms=[MagicMock()]) mock_universe = MagicMock() - self.entropy_manager._calculate_water_entropy(mock_universe, 0, 10, 1) + mock_universe.select_atoms.return_value = mock_selection - self.entropy_manager._data_logger.add_results_data.assert_has_calls( - [ - call("ACE", "water", "Orientational", 5.0), - call("ACE", "water", "Transvibrational", 0.0), - call("ACE", "water", "Rovibrational", 0.0), - call("ACE_1", "water", "Orientational", 0.0), - call("ACE_1", "water", "Transvibrational", 10.0), - call("ACE_1", "water", "Rovibrational", 2.0), - ] + self.entropy_manager._calculate_water_entropy( + mock_universe, 0, 10, 1, group_id=None + ) + + mock_logger.add_group_label.assert_called_once_with( + None, "WAT", len(mock_selection.residues), len(mock_selection.atoms) + ) + + @patch( + "waterEntropy.recipes.interfacial_solvent.get_interfacial_water_orient_entropy" + ) + def test_calculate_water_entropy_adds_resname(self, mock_get_entropy): + mock_vibrations = MagicMock() + mock_vibrations.translational_S = {("res1", "WAT"): 2.0} + mock_vibrations.rotational_S = {("res1", "WAT"): 3.0} + + mock_get_entropy.return_value = ( + {1: {"WAT": 
[1.0, 5]}}, # orientational + MagicMock(counts={("res1", "WAT"): 1}), + mock_vibrations, + None, + 1, + ) + + mock_water_selection = MagicMock() + mock_residues_group = MagicMock() + mock_residues_group.resnames = ["WAT"] + mock_water_selection.residues = mock_residues_group + mock_water_selection.atoms = [1, 2, 3] + mock_universe = MagicMock() + mock_universe.select_atoms.return_value = mock_water_selection + + group_id = 0 + self.entropy_manager._data_logger = MagicMock() + + self.entropy_manager._calculate_water_entropy( + mock_universe, start=0, end=1, step=1, group_id=group_id + ) + + self.entropy_manager._data_logger.add_group_label.assert_called_with( + group_id, + "WAT", + len(mock_water_selection.residues), + len(mock_water_selection.atoms), ) # TODO test for error handling on invalid inputs @@ -1280,6 +1627,37 @@ def test_assign_conformation(self): assert np.all(result >= 0) assert np.issubdtype(result.dtype, np.floating) + def test_conformational_entropy_calculation(self): + """ + Test `conformational_entropy_calculation` method to verify + correct entropy calculation from a simple discrete state array. 
+ """ + + # Setup managers and arguments + args = MagicMock(bin_width=0.1, temperature=300, selection_string="all") + run_manager = RunManager("temp_folder") + level_manager = LevelManager() + data_logger = DataLogger() + group_molecules = MagicMock() + + ce = ConformationalEntropy( + run_manager, args, MagicMock(), data_logger, level_manager, group_molecules + ) + + # Create a simple array of states with known counts + states = np.array([0, 0, 1, 1, 1, 2]) # 2x state 0, 3x state 1, 1x state 2 + number_frames = len(states) + + # Manually compute expected entropy + probs = np.array([2 / 6, 3 / 6, 1 / 6]) + expected_entropy = -np.sum(probs * np.log(probs)) * ce._GAS_CONST + + # Run the method under test + result = ce.conformational_entropy_calculation(states, number_frames) + + # Assert the result is close to expected entropy + self.assertAlmostEqual(result, expected_entropy, places=6) + class TestOrientationalEntropy(unittest.TestCase): """ diff --git a/tests/test_CodeEntropy/test_levels.py b/tests/test_CodeEntropy/test_levels.py index 2066592..adb30e5 100644 --- a/tests/test_CodeEntropy/test_levels.py +++ b/tests/test_CodeEntropy/test_levels.py @@ -81,7 +81,7 @@ def test_get_matrices(self): Ensures that the method returns correctly shaped matrices after filtering. """ - # Create a mock LevelManager instance + # Create a mock LevelManager level_manager level_manager = LevelManager() # Mock internal methods @@ -198,6 +198,50 @@ def test_get_matrices_torque_shape_mismatch(self): self.assertIn("Inconsistent torque matrix shape", str(context.exception)) + def test_get_matrices_torque_consistency(self): + """ + Test that get_matrices returns consistent torque and force matrices + when called multiple times with the same inputs. 
+ """ + level_manager = LevelManager() + + level_manager.get_beads = MagicMock(return_value=["bead1", "bead2"]) + level_manager.get_axes = MagicMock(return_value=("trans_axes", "rot_axes")) + level_manager.get_weighted_forces = MagicMock( + return_value=np.array([1.0, 2.0, 3.0]) + ) + level_manager.get_weighted_torques = MagicMock( + return_value=np.array([0.5, 1.5, 2.5]) + ) + level_manager.create_submatrix = MagicMock(return_value=np.identity(3)) + + data_container = MagicMock() + + initial_force_matrix = np.zeros((6, 6)) + initial_torque_matrix = np.zeros((6, 6)) + + force_matrix_1, torque_matrix_1 = level_manager.get_matrices( + data_container=data_container, + level="residue", + number_frames=2, + highest_level=True, + force_matrix=initial_force_matrix.copy(), + torque_matrix=initial_torque_matrix.copy(), + ) + + force_matrix_2, torque_matrix_2 = level_manager.get_matrices( + data_container=data_container, + level="residue", + number_frames=2, + highest_level=True, + force_matrix=initial_force_matrix.copy(), + torque_matrix=initial_torque_matrix.copy(), + ) + + # Check that repeated calls produce the same output + self.assertTrue(np.allclose(torque_matrix_1, torque_matrix_2, atol=1e-8)) + self.assertTrue(np.allclose(force_matrix_1, force_matrix_2, atol=1e-8)) + def test_get_dihedrals_united_atom(self): """ Test `get_dihedrals` for 'united_atom' level. @@ -257,6 +301,62 @@ def test_get_dihedrals_no_residue(self): # Should result in no resdies self.assertEqual(result, []) + def test_compute_dihedral_conformations(self): + """ + Test `compute_dihedral_conformations` to ensure it correctly calls + `assign_conformation` on each dihedral and returns the expected + list of conformation strings. 
+ """ + + # Setup + level_manager = LevelManager() + + # Mock selector (can be anything since we're mocking internals) + selector = MagicMock() + + # Mock dihedrals: pretend we have 3 dihedrals + mocked_dihedrals = ["d1", "d2", "d3"] + level_manager.get_dihedrals = MagicMock(return_value=mocked_dihedrals) + + # Mock the conformation entropy (ce) object with assign_conformation method + ce = MagicMock() + # For each dihedral, assign_conformation returns a numpy array of ints + ce.assign_conformation = MagicMock( + side_effect=[ + np.array([0, 1, 2]), + np.array([1, 0, 1]), + np.array([2, 2, 0]), + ] + ) + + number_frames = 3 + bin_width = 10 + start = 0 + end = 3 + step = 1 + level = "residue" + + # Call the method + states = level_manager.compute_dihedral_conformations( + selector, level, number_frames, bin_width, start, end, step, ce + ) + + # Expected states per frame + expected_states = [ + "012", # frame 0: d1=0, d2=1, d3=2 + "102", # frame 1: d1=1, d2=0, d3=2 + "210", # frame 2: d1=2, d2=1, d3=0 + ] + + # Verify the call count matches the number of dihedrals + self.assertEqual(ce.assign_conformation.call_count, len(mocked_dihedrals)) + + # Verify returned states are as expected + self.assertEqual(states, expected_states) + + # Verify get_dihedrals was called once with correct arguments + level_manager.get_dihedrals.assert_called_once_with(selector, level) + def test_compute_dihedral_conformations_no_dihedrals(self): """ Test `compute_dihedral_conformations` when no dihedrals are found. @@ -840,6 +940,227 @@ def test_create_submatrix_symmetric_result_when_data_equal(self): self.assertTrue(np.allclose(result, result.T)) # Check symmetry + def test_build_covariance_matrices_atomic(self): + """ + Test `build_covariance_matrices` to ensure it correctly orchestrates + calls and returns dictionaries with the expected structure. 
+ + This test mocks dependencies including the entropy_manager, reduced_atom + trajectory, levels, groups, and internal method + `update_force_torque_matrices`. + """ + + # Instantiate your class (replace YourClass with actual class name) + level_manager = LevelManager() + + # Mock entropy_manager and _get_molecule_container + entropy_manager = MagicMock() + + # Fake atom with minimal attributes + atom = MagicMock() + atom.resname = "RES" + atom.resid = 1 + atom.segid = "A" + + # Fake molecule with atoms list + fake_mol = MagicMock() + fake_mol.atoms = [atom] + + # Always return fake_mol from _get_molecule_container + entropy_manager._get_molecule_container = MagicMock(return_value=fake_mol) + + # Mock reduced_atom with trajectory yielding two timesteps + timestep1 = MagicMock() + timestep1.frame = 0 + timestep2 = MagicMock() + timestep2.frame = 1 + reduced_atom = MagicMock() + reduced_atom.trajectory.__getitem__.return_value = [timestep1, timestep2] + + # Setup groups and levels dictionaries + groups = {"ua": ["mol1", "mol2"]} + levels = {"mol1": ["level1", "level2"], "mol2": ["level1"]} + + # Mock update_force_torque_matrices to just track calls + level_manager.update_force_torque_matrices = MagicMock() + + # Call the method under test + force_matrices, torque_matrices, _ = level_manager.build_covariance_matrices( + entropy_manager=entropy_manager, + reduced_atom=reduced_atom, + levels=levels, + groups=groups, + start=0, + end=2, + step=1, + number_frames=2, + ) + + # Assert returned matrices are dictionaries with correct keys + self.assertIsInstance(force_matrices, dict) + self.assertIsInstance(torque_matrices, dict) + self.assertSetEqual(set(force_matrices.keys()), {"ua", "res", "poly"}) + self.assertSetEqual(set(torque_matrices.keys()), {"ua", "res", "poly"}) + + # Assert 'res' and 'poly' entries are lists of correct length + self.assertIsInstance(force_matrices["res"], list) + self.assertIsInstance(force_matrices["poly"], list) + 
self.assertEqual(len(force_matrices["res"]), len(groups)) + self.assertEqual(len(force_matrices["poly"]), len(groups)) + + # Check _get_molecule_container call count: 2 timesteps * 2 molecules = 4 calls + self.assertEqual(entropy_manager._get_molecule_container.call_count, 10) + + # Check update_force_torque_matrices call count: + self.assertEqual(level_manager.update_force_torque_matrices.call_count, 6) + + def test_update_force_torque_matrices_united_atom(self): + """ + Test that `update_force_torque_matrices` correctly updates force and torque + matrices for the 'united_atom' level, assigning per-residue matrices and + incrementing frame counts. + """ + level_manager = LevelManager() + entropy_manager = MagicMock() + run_manager = MagicMock() + entropy_manager._run_manager = run_manager + + mock_residue_group = MagicMock() + mock_residue_group.trajectory.__getitem__.return_value = None + run_manager.new_U_select_atom.return_value = mock_residue_group + + mock_residue1 = MagicMock() + mock_residue1.atoms.indices = [0, 2] + mock_residue2 = MagicMock() + mock_residue2.atoms.indices = [3, 5] + + mol = MagicMock() + mol.residues = [mock_residue1, mock_residue2] + + f_mat_mock = np.array([[1]]) + t_mat_mock = np.array([[2]]) + level_manager.get_matrices = MagicMock(return_value=(f_mat_mock, t_mat_mock)) + + force_avg = {"ua": {}, "res": [None], "poly": [None]} + torque_avg = {"ua": {}, "res": [None], "poly": [None]} + frame_counts = {"ua": {}, "res": [None], "poly": [None]} + + level_manager.update_force_torque_matrices( + entropy_manager=entropy_manager, + mol=mol, + group_id=0, + level="united_atom", + level_list=["residue", "united_atom"], + time_index=5, + num_frames=10, + force_avg=force_avg, + torque_avg=torque_avg, + frame_counts=frame_counts, + ) + + expected_keys = [(0, 0), (0, 1)] + for key in expected_keys: + np.testing.assert_array_equal(force_avg["ua"][key], f_mat_mock) + np.testing.assert_array_equal(torque_avg["ua"][key], t_mat_mock) + 
self.assertEqual(frame_counts["ua"][key], 1) + + def test_update_force_torque_matrices_residue(self): + """ + Test that `update_force_torque_matrices` correctly updates force and torque + matrices for the 'residue' level, assigning whole-molecule matrices and + incrementing frame counts. + """ + level_manager = LevelManager() + entropy_manager = MagicMock() + mol = MagicMock() + mol.trajectory.__getitem__.return_value = None + + f_mat_mock = np.array([[1]]) + t_mat_mock = np.array([[2]]) + level_manager.get_matrices = MagicMock(return_value=(f_mat_mock, t_mat_mock)) + + force_avg = {"ua": {}, "res": [None], "poly": [None]} + torque_avg = {"ua": {}, "res": [None], "poly": [None]} + frame_counts = {"ua": {}, "res": [None], "poly": [None]} + + level_manager.update_force_torque_matrices( + entropy_manager=entropy_manager, + mol=mol, + group_id=0, + level="residue", + level_list=["residue", "united_atom"], + time_index=3, + num_frames=10, + force_avg=force_avg, + torque_avg=torque_avg, + frame_counts=frame_counts, + ) + + np.testing.assert_array_equal(force_avg["res"][0], f_mat_mock) + np.testing.assert_array_equal(torque_avg["res"][0], t_mat_mock) + self.assertEqual(frame_counts["res"][0], 1) + + def test_update_force_torque_matrices_incremental_average(self): + """ + Test that `update_force_torque_matrices` correctly applies the incremental + mean formula when updating force and torque matrices over multiple frames. + + Ensures that float precision is maintained and no casting errors occur. 
+ """ + level_manager = LevelManager() + entropy_manager = MagicMock() + mol = MagicMock() + mol.trajectory.__getitem__.return_value = None + + # Ensure matrices are float64 to avoid casting errors + f_mat_1 = np.array([[1.0]], dtype=np.float64) + t_mat_1 = np.array([[2.0]], dtype=np.float64) + f_mat_2 = np.array([[3.0]], dtype=np.float64) + t_mat_2 = np.array([[4.0]], dtype=np.float64) + + level_manager.get_matrices = MagicMock( + side_effect=[(f_mat_1, t_mat_1), (f_mat_2, t_mat_2)] + ) + + force_avg = {"ua": {}, "res": [None], "poly": [None]} + torque_avg = {"ua": {}, "res": [None], "poly": [None]} + frame_counts = {"ua": {}, "res": [None], "poly": [None]} + + # First update + level_manager.update_force_torque_matrices( + entropy_manager=entropy_manager, + mol=mol, + group_id=0, + level="residue", + level_list=["residue", "united_atom"], + time_index=0, + num_frames=10, + force_avg=force_avg, + torque_avg=torque_avg, + frame_counts=frame_counts, + ) + + # Second update + level_manager.update_force_torque_matrices( + entropy_manager=entropy_manager, + mol=mol, + group_id=0, + level="residue", + level_list=["residue", "united_atom"], + time_index=1, + num_frames=10, + force_avg=force_avg, + torque_avg=torque_avg, + frame_counts=frame_counts, + ) + + expected_force = f_mat_1 + (f_mat_2 - f_mat_1) / 2 + expected_torque = t_mat_1 + (t_mat_2 - t_mat_1) / 2 + + np.testing.assert_array_almost_equal(force_avg["res"][0], expected_force) + np.testing.assert_array_almost_equal(torque_avg["res"][0], expected_torque) + self.assertEqual(frame_counts["res"][0], 2) + def test_filter_zero_rows_columns_no_zeros(self): """ Test that matrix with no zero-only rows or columns should return unchanged. 
diff --git a/tests/test_CodeEntropy/test_logging_config.py b/tests/test_CodeEntropy/test_logging_config.py index 08f50a4..9640950 100644 --- a/tests/test_CodeEntropy/test_logging_config.py +++ b/tests/test_CodeEntropy/test_logging_config.py @@ -2,6 +2,7 @@ import os import tempfile import unittest +from unittest.mock import MagicMock from CodeEntropy.config.logging_config import LoggingConfig @@ -14,6 +15,9 @@ def setUp(self): self.log_dir = os.path.join(self.temp_dir.name, "logs") self.logging_config = LoggingConfig(folder=self.temp_dir.name) + self.mock_text = "Test console output" + self.logging_config.console.export_text = MagicMock(return_value=self.mock_text) + def tearDown(self): self.temp_dir.cleanup() @@ -31,20 +35,17 @@ def test_expected_log_files_created(self): """Ensure log file paths are configured correctly in the logging config""" self.logging_config.setup_logging() - # Map actual output files to their corresponding handler keys + # Map expected filenames to the corresponding handler keys in LoggingConfig expected_handlers = { - "program.out": "stdout", - "program.log": "logfile", - "program.err": "errorfile", - "program.com": "commandfile", - "mdanalysis.log": "mdanalysis_log", + "program.log": "main", + "program.err": "error", + "program.com": "command", + "mdanalysis.log": "mdanalysis", } for filename, handler_key in expected_handlers.items(): - expected_path = os.path.join(self.log_dir, filename) - actual_path = self.logging_config.LOGGING["handlers"][handler_key][ - "filename" - ] + expected_path = os.path.join(self.logging_config.log_dir, filename) + actual_path = self.logging_config.handlers[handler_key].baseFilename self.assertEqual(actual_path, expected_path) def test_update_logging_level(self): @@ -67,9 +68,7 @@ def test_update_logging_level(self): def test_mdanalysis_and_command_loggers_exist(self): """Ensure specialized loggers are set up with correct configuration""" log_level = logging.DEBUG - self.logging_config = LoggingConfig( - 
folder=self.temp_dir.name, log_level=log_level - ) + self.logging_config = LoggingConfig(folder=self.temp_dir.name, level=log_level) self.logging_config.setup_logging() mda_logger = logging.getLogger("MDAnalysis") @@ -80,6 +79,26 @@ def test_mdanalysis_and_command_loggers_exist(self): self.assertFalse(mda_logger.propagate) self.assertFalse(cmd_logger.propagate) + def test_save_console_log_writes_file(self): + """ + Test that save_console_log creates a log file in the expected location + and writes the console's recorded output correctly. + """ + filename = "test_log.txt" + self.logging_config.save_console_log(filename) + + output_path = os.path.join(self.temp_dir.name, "logs", filename) + # Check file exists + self.assertTrue(os.path.exists(output_path)) + + # Read content and check it matches mocked export_text output + with open(output_path, "r", encoding="utf-8") as f: + content = f.read() + self.assertEqual(content, self.mock_text) + + # Ensure export_text was called once + self.logging_config.console.export_text.assert_called_once() + if __name__ == "__main__": unittest.main() diff --git a/tests/test_CodeEntropy/test_main.py b/tests/test_CodeEntropy/test_main.py index 9a16f35..1a60972 100644 --- a/tests/test_CodeEntropy/test_main.py +++ b/tests/test_CodeEntropy/test_main.py @@ -107,6 +107,10 @@ def test_main_entry_point_runs(self): with open(config_path, "w") as f: f.write("run1:\n" " end: 60\n" " selection_string: resid 1\n") + citation_path = os.path.join(self.test_dir, "CITATION.cff") + with open(citation_path, "w") as f: + f.write("run1:\n" " end: 60\n" " selection_string: resid 1\n") + result = subprocess.run( [ sys.executable, diff --git a/tests/test_CodeEntropy/test_run.py b/tests/test_CodeEntropy/test_run.py index 4303d9e..565e476 100644 --- a/tests/test_CodeEntropy/test_run.py +++ b/tests/test_CodeEntropy/test_run.py @@ -2,9 +2,13 @@ import shutil import tempfile import unittest -from unittest.mock import MagicMock, patch +from io import StringIO 
+from unittest.mock import MagicMock, mock_open, patch import numpy as np +import requests +import yaml +from rich.console import Console from CodeEntropy.run import RunManager @@ -20,6 +24,14 @@ def setUp(self): Set up a temporary directory as the working directory before each test. """ self.test_dir = tempfile.mkdtemp(prefix="CodeEntropy_") + self.config_file = os.path.join(self.test_dir, "CITATION.cff") + + # Create a mock config file + with patch("builtins.open", new_callable=mock_open) as mock_file: + self.setup_citation_file(mock_file) + with open(self.config_file, "w") as f: + f.write(mock_file.return_value.read()) + self._orig_dir = os.getcwd() os.chdir(self.test_dir) @@ -31,6 +43,18 @@ def tearDown(self): os.chdir(self._orig_dir) shutil.rmtree(self.test_dir) + def setup_citation_file(self, mock_file): + """ + Mock the contents of the CITATION.cff file. + """ + citation_content = """\ + authors: + - given-names: Alice + family-names: Smith + """ + + mock_file.return_value = mock_open(read_data=citation_content).return_value + @patch("os.makedirs") @patch("os.listdir") def test_create_job_folder_empty_directory(self, mock_listdir, mock_makedirs): @@ -102,6 +126,130 @@ def test_create_job_folder_with_invalid_job_suffix( self.assertEqual(new_folder_path, expected_path) mock_makedirs.assert_called_once_with(expected_path, exist_ok=True) + @patch("requests.get") + def test_load_citation_data_success(self, mock_get): + """Should return parsed dict when CITATION.cff loads successfully.""" + mock_yaml = """ + authors: + - given-names: Alice + family-names: Smith + title: TestProject + version: 1.0 + date-released: 2025-01-01 + """ + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.text = mock_yaml + mock_get.return_value = mock_response + + instance = RunManager("dummy") + data = instance.load_citation_data() + + self.assertIsInstance(data, dict) + self.assertEqual(data["title"], "TestProject") + 
self.assertEqual(data["authors"][0]["given-names"], "Alice") + + @patch("requests.get") + def test_load_citation_data_network_error(self, mock_get): + """Should return None if network request fails.""" + mock_get.side_effect = requests.exceptions.ConnectionError("Network down") + + instance = RunManager("dummy") + data = instance.load_citation_data() + + self.assertIsNone(data) + + @patch("requests.get") + def test_load_citation_data_http_error(self, mock_get): + """Should return None if HTTP response is non-200.""" + mock_response = MagicMock() + mock_response.status_code = 404 + mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError() + mock_get.return_value = mock_response + + instance = RunManager("dummy") + data = instance.load_citation_data() + + self.assertIsNone(data) + + @patch("requests.get") + def test_load_citation_data_invalid_yaml(self, mock_get): + """Should raise YAML error if file content is invalid YAML.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.text = "invalid: [oops" + mock_get.return_value = mock_response + + instance = RunManager("dummy") + + with self.assertRaises(yaml.YAMLError): + instance.load_citation_data() + + @patch.object(RunManager, "load_citation_data") + def test_show_splash_with_citation(self, mock_load): + """Should render full splash screen when citation data is present.""" + mock_load.return_value = { + "title": "TestProject", + "version": "1.0", + "date-released": "2025-01-01", + "url": "https://example.com", + "abstract": "This is a test abstract.", + "authors": [ + {"given-names": "Alice", "family-names": "Smith", "affiliation": "Uni"} + ], + } + + buf = StringIO() + test_console = Console(file=buf, force_terminal=False, width=80) + + instance = RunManager("dummy") + with patch("CodeEntropy.run.console", test_console): + instance.show_splash() + + output = buf.getvalue() + + self.assertIn("Version 1.0", output) + self.assertIn("2025-01-01", output) + 
self.assertIn("https://example.com", output) + self.assertIn("This is a test abstract.", output) + self.assertIn("Alice Smith", output) + + @patch.object(RunManager, "load_citation_data", return_value=None) + def test_show_splash_without_citation(self, mock_load): + """Should render minimal splash screen when no citation data.""" + buf = StringIO() + test_console = Console(file=buf, force_terminal=False, width=80) + + instance = RunManager("dummy") + with patch("CodeEntropy.run.console", test_console): + instance.show_splash() + + output = buf.getvalue() + + self.assertNotIn("Version", output) + self.assertNotIn("Contributors", output) + self.assertIn("Welcome to CodeEntropy", output) + + @patch.object(RunManager, "load_citation_data") + def test_show_splash_missing_fields(self, mock_load): + """Should gracefully handle missing optional fields in citation data.""" + mock_load.return_value = { + "title": "PartialProject", + # no version, no date, no authors, no abstract + } + + buf = StringIO() + test_console = Console(file=buf, force_terminal=False, width=80) + + instance = RunManager("dummy") + with patch("CodeEntropy.run.console", test_console): + instance.show_splash() + + output = buf.getvalue() + + self.assertIn("Version ?", output) + self.assertIn("No description available.", output) + def test_run_entropy_workflow(self): """ Test the run_entropy_workflow method to ensure it initializes and executes @@ -110,6 +258,7 @@ def test_run_entropy_workflow(self): run_manager = RunManager("folder") run_manager._logging_config = MagicMock() run_manager._config_manager = MagicMock() + run_manager.load_citation_data = MagicMock() run_manager._data_logger = MagicMock() run_manager.folder = self.test_dir @@ -125,6 +274,23 @@ def test_run_entropy_workflow(self): } } + run_manager.load_citation_data.return_value = { + "cff-version": "1.2.0", + "title": "CodeEntropy", + "message": ( + "If you use this software, please cite it using the " + "metadata from this file." 
+ ), + "type": "software", + "authors": [ + { + "given-names": "Forename", + "family-names": "Sirname", + "email": "test@email.ac.uk", + } + ], + } + mock_args = MagicMock() mock_args.output_file = "output.json" mock_args.verbose = True @@ -155,6 +321,7 @@ def test_run_configuration_warning(self): run_manager = RunManager("folder") run_manager._logging_config = MagicMock() run_manager._config_manager = MagicMock() + run_manager.load_citation_data = MagicMock() run_manager._data_logger = MagicMock() run_manager.folder = self.test_dir @@ -165,6 +332,23 @@ def test_run_configuration_warning(self): "invalid_run": "this_should_be_a_dict" } + run_manager.load_citation_data.return_value = { + "cff-version": "1.2.0", + "title": "CodeEntropy", + "message": ( + "If you use this software, please cite it using the " + "metadata from this file." + ), + "type": "software", + "authors": [ + { + "given-names": "Forename", + "family-names": "Sirname", + "email": "test@email.ac.uk", + } + ], + } + mock_args = MagicMock() mock_args.output_file = "output.json" mock_args.verbose = False @@ -186,6 +370,7 @@ def test_run_entropy_workflow_missing_traj_file(self): run_manager = RunManager("folder") run_manager._logging_config = MagicMock() run_manager._config_manager = MagicMock() + run_manager.load_citation_data = MagicMock() run_manager._data_logger = MagicMock() run_manager.folder = self.test_dir @@ -200,6 +385,23 @@ def test_run_entropy_workflow_missing_traj_file(self): } } + run_manager.load_citation_data.return_value = { + "cff-version": "1.2.0", + "title": "CodeEntropy", + "message": ( + "If you use this software, please cite it using the " + "metadata from this file." 
+ ), + "type": "software", + "authors": [ + { + "given-names": "Forename", + "family-names": "Sirname", + "email": "test@email.ac.uk", + } + ], + } + mock_args = MagicMock() mock_args.output_file = "output.json" mock_args.verbose = False @@ -220,6 +422,7 @@ def test_run_entropy_workflow_missing_selection_string(self): run_manager = RunManager("folder") run_manager._logging_config = MagicMock() run_manager._config_manager = MagicMock() + run_manager.load_citation_data = MagicMock() run_manager._data_logger = MagicMock() run_manager.folder = self.test_dir @@ -234,6 +437,23 @@ def test_run_entropy_workflow_missing_selection_string(self): } } + run_manager.load_citation_data.return_value = { + "cff-version": "1.2.0", + "title": "CodeEntropy", + "message": ( + "If you use this software, please cite it using the " + "metadata from this file." + ), + "type": "software", + "authors": [ + { + "given-names": "Forename", + "family-names": "Sirname", + "email": "test@email.ac.uk", + } + ], + } + mock_args = MagicMock() mock_args.output_file = "output.json" mock_args.verbose = False