diff --git a/CodeEntropy/config/arg_config_manager.py b/CodeEntropy/config/arg_config_manager.py index 373dd9d..0d809ed 100644 --- a/CodeEntropy/config/arg_config_manager.py +++ b/CodeEntropy/config/arg_config_manager.py @@ -64,6 +64,11 @@ "help": "If set to False, disables the calculation of water entropy", "default": True, }, + "grouping": { + "type": str, + "help": "How to group molecules for averaging", + "default": "each", + }, } diff --git a/CodeEntropy/entropy.py b/CodeEntropy/entropy.py index 21a7e62..9451bb3 100644 --- a/CodeEntropy/entropy.py +++ b/CodeEntropy/entropy.py @@ -16,7 +16,9 @@ class EntropyManager: molecular dynamics trajectory. """ - def __init__(self, run_manager, args, universe, data_logger, level_manager): + def __init__( + self, run_manager, args, universe, data_logger, level_manager, group_molecules + ): """ Initializes the EntropyManager with required components. @@ -32,44 +34,29 @@ def __init__(self, run_manager, args, universe, data_logger, level_manager): self._universe = universe self._data_logger = data_logger self._level_manager = level_manager + self._group_molecules = group_molecules self._GAS_CONST = 8.3144598484848 def execute(self): """ - Executes the full entropy computation workflow over selected molecules and - levels. This includes both vibrational and conformational entropy, recorded - per molecule and residue. + Run the full entropy computation workflow. + + This method orchestrates the entire entropy analysis pipeline, including: + - Handling water entropy if present. + - Initializing molecular structures and levels. + - Building force and torque covariance matrices. + - Computing vibrational and conformational entropies. + - Finalizing and logging results. """ start, end, step = self._get_trajectory_bounds() number_frames = self._get_number_frames(start, end, step) - - has_water = self._universe.select_atoms("water").n_atoms > 0 - if has_water and self._args.water_entropy: - self._calculate_water_entropy(self._universe, start, end, step) - - if self._args.selection_string != "all": - self._args.selection_string += " and not water" - else: - self._args.selection_string = "not water" - - logger.debug( - "WaterEntropy: molecule_data: %s", - self._data_logger.molecule_data, - ) - logger.debug( - "WaterEntropy: residue_data: %s", - self._data_logger.residue_data, - ) - - reduced_atom = self._get_reduced_universe() - number_molecules, levels = self._level_manager.select_levels(reduced_atom) - ve = VibrationalEntropy( self._run_manager, self._args, self._universe, self._data_logger, self._level_manager, + self._group_molecules, ) ce = ConformationalEntropy( self._run_manager, @@ -77,89 +64,181 @@ def execute(self): self._universe, self._data_logger, self._level_manager, + self._group_molecules, + ) + + self._handle_water_entropy(start, end, step) + reduced_atom, number_molecules, levels, groups = self._initialize_molecules() + + force_matrices, torque_matrices = self._level_manager.build_covariance_matrices( + self, + reduced_atom, + levels, + groups, + start, + end, + step, + number_frames, ) - for molecule_id in range(number_molecules): - mol_container = self._get_molecule_container(reduced_atom, molecule_id) + states_ua, states_res = self._level_manager.build_conformational_states( + self, + reduced_atom, + levels, + groups, + start, + end, + step, + number_frames, + self._args.bin_width, + ce, + ) + + self._compute_entropies( + reduced_atom, + levels, + groups, + force_matrices, + torque_matrices, + states_ua, + states_res, + number_frames, + ve, + ce, + ) + + self._finalize_molecule_results() + self._data_logger.log_tables() + + def _handle_water_entropy(self, start, end, step): + """ + Compute and exclude water entropy from the system if applicable. + + If water molecules are present and water entropy calculation is enabled, + this method computes their entropy and updates the selection string to + exclude water from further analysis. + + Parameters: + start (int): Start frame index. + end (int): End frame index. + step (int): Step size for frame iteration. + """ + has_water = self._universe.select_atoms("water").n_atoms > 0 + if has_water and self._args.water_entropy: + self._calculate_water_entropy(self._universe, start, end, step) + self._args.selection_string = ( + self._args.selection_string + " and not water" + if self._args.selection_string != "all" + else "not water" + ) + logger.debug( + "WaterEntropy: molecule_data: %s", self._data_logger.molecule_data + ) + logger.debug( + "WaterEntropy: residue_data: %s", self._data_logger.residue_data + ) + + def _initialize_molecules(self): + """ + Prepare the reduced universe and determine molecule-level configurations. + + Returns: + tuple: A tuple containing: + - reduced_atom (Universe): The reduced atom selection. + - number_molecules (int): Number of molecules in the system. + - levels (list): List of entropy levels per molecule. + """ + reduced_atom = self._get_reduced_universe() + number_molecules, levels = self._level_manager.select_levels(reduced_atom) + grouping = self._args.grouping + groups = self._group_molecules.grouping_molecules(reduced_atom, grouping) + + return reduced_atom, number_molecules, levels, groups + + def _compute_entropies( + self, + reduced_atom, + levels, + groups, + force_matrices, + torque_matrices, + states_ua, + states_res, + number_frames, + ve, + ce, + ): + """ + Compute vibrational and conformational entropies for all molecules and levels. + + This method iterates over each molecule and its associated entropy levels + (united_atom, residue, polymer), computing the corresponding entropy + contributions using force/torque matrices and dihedral conformations. + + For each level: + - "united_atom": Computes per-residue conformational states and entropy. + - "residue": Computes molecule-level conformational and vibrational entropy. + - "polymer": Computes only vibrational entropy. + + Parameters: + reduced_atom (Universe): The reduced atom selection from the trajectory. + number_molecules (int): Number of molecules in the system. + levels (list): List of entropy levels per molecule. + force_matrices (dict): Precomputed force covariance matrices. + torque_matrices (dict): Precomputed torque covariance matrices. + states_ua (dict): Dictionary to store united-atom conformational states. + states_res (list): List to store residue-level conformational states. + number_frames (int): Total number of trajectory frames to process. + """ + for group_id in groups.keys(): + mol = self._get_molecule_container(reduced_atom, groups[group_id][0]) + for level in levels[groups[group_id][0]]: + highest = level == levels[groups[group_id][0]][-1] - for level in levels[molecule_id]: - highest_level = level == levels[molecule_id][-1] if level == "united_atom": - self._process_united_atom_level( - molecule_id, - mol_container, + self._process_united_atom_entropy( + group_id, + mol, ve, ce, level, - start, - end, - step, + force_matrices["ua"], + torque_matrices["ua"], + states_ua, + highest, number_frames, - highest_level, ) - logger.debug( - "%s level: molecule_data: %s", - level, - self._data_logger.molecule_data, - ) - logger.debug( - "%s level: residue_data: %s", - level, - self._data_logger.residue_data, - ) - - elif level in ("polymer", "residue"): - self._process_vibrational_only_levels( - molecule_id, - mol_container, - ve, - level, - start, - end, - step, + elif level == "residue": + self._process_vibrational_entropy( + group_id, number_frames, - highest_level, - ) - - logger.debug( - "%s level: molecule_data: %s", - level, - self._data_logger.molecule_data, - ) - logger.debug( - "%s level: residue_data: %s", + ve, level, - self._data_logger.residue_data, + force_matrices["res"][group_id], + torque_matrices["res"][group_id], + highest, ) - if level == "residue": - self._process_conformational_residue_level( - molecule_id, - mol_container, + self._process_conformational_entropy( + group_id, ce, level, - start, - end, - step, + states_res, number_frames, ) - logger.debug( - "%s level: molecule_data: %s", - level, - self._data_logger.molecule_data, - ) - logger.debug( - "%s level: residue_data: %s", + elif level == "polymer": + self._process_vibrational_entropy( + group_id, + number_frames, + ve, level, - self._data_logger.residue_data, + force_matrices["poly"][group_id], + torque_matrices["poly"][group_id], + highest, ) - self._finalize_molecule_results() - - self._data_logger.log_tables() - def _get_trajectory_bounds(self): """ Returns the start, end, and step frame indices based on input arguments. @@ -228,8 +307,18 @@ def _get_molecule_container(self, universe, molecule_id): selection_string = f"index {frag.indices[0]}:{frag.indices[-1]}" return self._run_manager.new_U_select_atom(universe, selection_string) - def _process_united_atom_level( - self, mol_id, mol_container, ve, ce, level, start, end, step, n_frames, highest + def _process_united_atom_entropy( + self, + group_id, + mol_container, + ve, + ce, + level, + force_matrix, + torque_matrix, + states, + highest, + number_frames, ): """ Calculates translational, rotational, and conformational entropy at the @@ -246,31 +335,29 @@ def _process_united_atom_level( highest (bool): Whether this is the highest level of resolution for the molecule. """ - bin_width = self._args.bin_width S_trans, S_rot, S_conf = 0, 0, 0 + for residue_id, residue in enumerate(mol_container.residues): - res_container = self._run_manager.new_U_select_atom( - mol_container, - f"index {residue.atoms.indices[0]}:{residue.atoms.indices[-1]}", - ) - heavy_res = self._run_manager.new_U_select_atom( - res_container, "not name H*" - ) - force_matrix, torque_matrix = self._level_manager.get_matrices( - res_container, level, start, end, step, n_frames, highest - ) + key = (group_id, residue_id) + + f_matrix = force_matrix[key] + f_matrix = self._level_manager.filter_zero_rows_columns(f_matrix) + f_matrix = f_matrix / number_frames + + t_matrix = torque_matrix[key] + t_matrix = self._level_manager.filter_zero_rows_columns(t_matrix) + t_matrix = t_matrix / number_frames S_trans_res = ve.vibrational_entropy_calculation( - force_matrix, "force", self._args.temperature, highest + f_matrix, "force", self._args.temperature, highest ) S_rot_res = ve.vibrational_entropy_calculation( - torque_matrix, "torque", self._args.temperature, highest + t_matrix, "torque", self._args.temperature, highest ) - dihedrals = self._level_manager.get_dihedrals(heavy_res, level) S_conf_res = ce.conformational_entropy_calculation( - heavy_res, dihedrals, bin_width, start, end, step, n_frames + states[key], number_frames ) S_trans += S_trans_res @@ -278,61 +365,53 @@ def _process_united_atom_level( S_conf += S_conf_res self._data_logger.add_residue_data( - mol_id, residue.resname, level, "Transvibrational", S_trans_res + residue_id, residue.resname, level, "Transvibrational", S_trans_res ) self._data_logger.add_residue_data( - mol_id, residue.resname, level, "Rovibrational", S_rot_res + residue_id, residue.resname, level, "Rovibrational", S_rot_res ) self._data_logger.add_residue_data( - mol_id, residue.resname, level, "Conformational", S_conf_res + residue_id, residue.resname, level, "Conformational", S_conf_res ) - self._data_logger.add_results_data( - residue.resname, level, "Transvibrational", S_trans - ) - self._data_logger.add_results_data( - residue.resname, level, "Rovibrational", S_rot - ) - self._data_logger.add_results_data( - residue.resname, level, "Conformational", S_conf - ) + self._data_logger.add_results_data(group_id, level, "Transvibrational", S_trans) + self._data_logger.add_results_data(group_id, level, "Rovibrational", S_rot) + self._data_logger.add_results_data(group_id, level, "Conformational", S_conf) - def _process_vibrational_only_levels( - self, mol_id, mol_container, ve, level, start, end, step, n_frames, highest + def _process_vibrational_entropy( + self, group_id, number_frames, ve, level, force_matrix, torque_matrix, highest ): """ - Calculates vibrational entropy at levels where conformational entropy is - not considered. + Calculates vibrational entropy. Args: mol_id (int): Molecule ID. mol_container (Universe): Selected molecule's universe. ve: VibrationalEntropy object. - level (str): Current granularity level ('polymer' or 'residue'). - start, end, step (int): Trajectory frame parameters. - n_frames (int): Number of trajectory frames. + level (str): Current granularity level. + force_matrix : Force covariance matrix + torque_matrix : Torque covariance matrix highest (bool): Flag indicating if this is the highest granularity level. """ - force_matrix, torque_matrix = self._level_manager.get_matrices( - mol_container, level, start, end, step, n_frames, highest - ) + force_matrix = self._level_manager.filter_zero_rows_columns(force_matrix) + force_matrix = force_matrix / number_frames + + torque_matrix = self._level_manager.filter_zero_rows_columns(torque_matrix) + torque_matrix = torque_matrix / number_frames + S_trans = ve.vibrational_entropy_calculation( force_matrix, "force", self._args.temperature, highest ) S_rot = ve.vibrational_entropy_calculation( torque_matrix, "torque", self._args.temperature, highest ) - residue = mol_container.residues[mol_id] - self._data_logger.add_results_data( - residue.resname, level, "Transvibrational", S_trans - ) - self._data_logger.add_results_data( - residue.resname, level, "Rovibrational", S_rot - ) - def _process_conformational_residue_level( - self, mol_id, mol_container, ce, level, start, end, step, n_frames + self._data_logger.add_results_data(group_id, level, "Transvibrational", S_trans) + self._data_logger.add_results_data(group_id, level, "Rovibrational", S_rot) + + def _process_conformational_entropy( + self, group_id, ce, level, states, number_frames ): """ Computes conformational entropy at the residue level (whole-molecule dihedral @@ -346,15 +425,9 @@ def _process_conformational_residue_level( start, end, step (int): Frame bounds. n_frames (int): Number of frames used. """ - bin_width = self._args.bin_width - dihedrals = self._level_manager.get_dihedrals(mol_container, level) - S_conf = ce.conformational_entropy_calculation( - mol_container, dihedrals, bin_width, start, end, step, n_frames - ) - residue = mol_container.residues[mol_id] - self._data_logger.add_results_data( - residue.resname, level, "Conformational", S_conf - ) + S_conf = ce.conformational_entropy_calculation(states[group_id], number_frames) + + self._data_logger.add_results_data(group_id, level, "Conformational", S_conf) def _finalize_molecule_results(self): """ @@ -504,12 +577,16 @@ class VibrationalEntropy(EntropyManager): vibrational modes and thermodynamic properties. """ - def __init__(self, run_manager, args, universe, data_logger, level_manager): + def __init__( + self, run_manager, args, universe, data_logger, level_manager, group_molecules + ): """ Initializes the VibrationalEntropy manager with all required components and defines physical constants used in vibrational entropy calculations. """ - super().__init__(run_manager, args, universe, data_logger, level_manager) + super().__init__( + run_manager, args, universe, data_logger, level_manager, group_molecules + ) self._PLANCK_CONST = 6.62607004081818e-34 def frequency_calculation(self, lambdas, temp): @@ -628,12 +705,16 @@ class ConformationalEntropy(EntropyManager): analysis using statistical mechanics principles. """ - def __init__(self, run_manager, args, universe, data_logger, level_manager): + def __init__( + self, run_manager, args, universe, data_logger, level_manager, group_molecules + ): """ Initializes the ConformationalEntropy manager with all required components and sets the gas constant used in conformational entropy calculations. """ - super().__init__(run_manager, args, universe, data_logger, level_manager) + super().__init__( + run_manager, args, universe, data_logger, level_manager, group_molecules + ) def assign_conformation( self, data_container, dihedral, number_frames, bin_width, start, end, step @@ -718,9 +799,7 @@ def assign_conformation( return conformations - def conformational_entropy_calculation( - self, data_container, dihedrals, bin_width, start, end, step, number_frames - ): + def conformational_entropy_calculation(self, states, number_frames): """ Function to calculate conformational entropies using eq. (7) in Higham, S.-Y. Chou, F. Gräter and R. H. Henchman, Molecular Physics, 2018, 116, @@ -739,27 +818,6 @@ def conformational_entropy_calculation( S_conf_total = 0 - # For each dihedral, identify the conformation in each frame - num_dihedrals = len(dihedrals) - conformation = np.zeros((num_dihedrals, number_frames)) - index = 0 - for dihedral in dihedrals: - conformation[index] = self.assign_conformation( - data_container, dihedral, number_frames, bin_width, start, end, step - ) - index += 1 - - logger.debug(f"Conformation matrix: {conformation}") - - # For each frame, convert the conformation of all dihedrals into a - # state string - states = ["" for x in range(number_frames)] - for frame_index in range(number_frames): - for index in range(num_dihedrals): - states[frame_index] += str(conformation[index][frame_index]) - - logger.debug(f"States: {states}") - # Count how many times each state occurs, then use the probability # to get the entropy # entropy = sum over states p*ln(p) @@ -787,12 +845,16 @@ class OrientationalEntropy(EntropyManager): and orientational degrees of freedom. """ - def __init__(self, run_manager, args, universe, data_logger, level_manager): + def __init__( + self, run_manager, args, universe, data_logger, level_manager, group_molecules + ): """ Initializes the OrientationalEntropy manager with all required components and sets the gas constant used in orientational entropy calculations. """ - super().__init__(run_manager, args, universe, data_logger, level_manager) + super().__init__( + run_manager, args, universe, data_logger, level_manager, group_molecules + ) def orientational_entropy_calculation(self, neighbours_dict): """ diff --git a/CodeEntropy/group_molecules.py b/CodeEntropy/group_molecules.py new file mode 100644 index 0000000..f8e43fb --- /dev/null +++ b/CodeEntropy/group_molecules.py @@ -0,0 +1,91 @@ +import logging + +logger = logging.getLogger(__name__) + + +class GroupMolecules: + """ + Groups molecules for averaging. + """ + + def __init__(self): + """ + Initializes the class with relevant information. + + Args: + run_manager: Manager for universe and selection operations. + args: Argument namespace containing user parameters. + universe: MDAnalysis universe representing the simulation system. + data_logger: Logger for storing and exporting entropy data. + """ + self._molecule_groups = None + + def grouping_molecules(self, universe, grouping): + """ + Grouping molecules by desired level of detail. + """ + + molecule_groups = {} + + if grouping == "each": + molecule_groups = self._by_none(universe) + + if grouping == "molecules": + molecule_groups = self._by_molecules(universe) + + return molecule_groups + + def _by_none(self, universe): + """ + Don't group molecules. Every molecule is in its own group. + """ + + # fragments is MDAnalysis terminology for molecules + number_molecules = len(universe.atoms.fragments) + + molecule_groups = {} + + for molecule_i in range(number_molecules): + molecule_groups[molecule_i] = [molecule_i] + + number_groups = len(molecule_groups) + + logger.info(f"Number of molecule groups: {number_groups}") + logger.debug(f"Molecule groups are: {molecule_groups}") + + return molecule_groups + + def _by_molecules(self, universe): + """ + Group molecules by chemical type. + Based on number of atoms and atom names. + """ + + # fragments is MDAnalysis terminology for molecules + number_molecules = len(universe.atoms.fragments) + fragments = universe.atoms.fragments + + molecule_groups = {} + + for molecule_i in range(number_molecules): + names_i = fragments[molecule_i].names + number_atoms_i = len(names_i) + + for molecule_j in range(number_molecules): + names_j = fragments[molecule_j].names + number_atoms_j = len(names_j) + + if number_atoms_i == number_atoms_j and (names_i == names_j).all: + if molecule_j in molecule_groups.keys(): + molecule_groups[molecule_j].append(molecule_i) + else: + molecule_groups[molecule_j] = [] + molecule_groups[molecule_j].append(molecule_i) + break + + number_groups = len(molecule_groups) + + logger.info(f"Number of molecule groups: {number_groups}") + logger.debug(f"Molecule groups are: {molecule_groups}") + + return molecule_groups diff --git a/CodeEntropy/levels.py b/CodeEntropy/levels.py index a4be74d..6c58f2d 100644 --- a/CodeEntropy/levels.py +++ b/CodeEntropy/levels.py @@ -46,7 +46,7 @@ def select_levels(self, data_container): # fragments is MDAnalysis terminology for what chemists would call molecules number_molecules = len(data_container.atoms.fragments) - logger.debug("The number of molecules is {}.".format(number_molecules)) + logger.debug(f"The number of molecules is {number_molecules}.") fragments = data_container.atoms.fragments levels = [[] for _ in range(number_molecules)] @@ -70,26 +70,36 @@ def select_levels(self, data_container): return number_molecules, levels def get_matrices( - self, data_container, level, start, end, step, number_frames, highest_level + self, + data_container, + level, + number_frames, + highest_level, + force_matrix, + torque_matrix, ): """ - Function to create the force matrix needed for the transvibrational entropy - calculation and the torque matrix for the rovibrational entropy calculation. + Compute and accumulate force/torque covariance matrices for a given level. - Input - ----- - data_container : MDAnalysis universe type with the information on the - molecule of interest. - level : string, which of the polymer, residue, or united atom levels - are the matrices for. - start : int, starting frame, default 0 (first frame) - end : int, ending frame, default -1 (last frame) - step : int, step for going through trajectories, default 1 + Parameters + ---------- + data_container : MDAnalysis.Universe + Atom group for a molecule or residue. + level : str + 'polymer', 'residue', or 'united_atom'. + number_frames : int + Number of frames being processed. + highest_level : bool + Whether this is the top (polymer) level. + force_matrix, torque_matrix : np.ndarray or None + Accumulated matrices to add to. Returns ------- - force_matrix : force covariance matrix for transvibrational entropy - torque_matrix : torque convariance matrix for rovibrational entropy + force_matrix : np.ndarray + Accumulated force covariance matrix. + torque_matrix : np.ndarray + Accumulated torque covariance matrix. """ # Make beads @@ -99,81 +109,79 @@ def get_matrices( number_beads = len(list_of_beads) # initialize force and torque arrays - weighted_forces = [ - [0 for x in range(number_frames)] for y in range(number_beads) - ] - weighted_torques = [ - [0 for x in range(number_frames)] for y in range(number_beads) - ] + weighted_forces = [None for _ in range(number_beads)] + weighted_torques = [None for _ in range(number_beads)] # Calculate forces/torques for each bead for bead_index in range(number_beads): - for timestep in data_container.trajectory[start:end:step]: - # Set up axes - # translation and rotation use different axes - # how the axes are defined depends on the level - trans_axes, rot_axes = self.get_axes(data_container, level, bead_index) - - # Sort out coordinates, forces, and torques for each atom in the bead - timestep_index = timestep.frame - start - weighted_forces[bead_index][timestep_index] = self.get_weighted_forces( - data_container, list_of_beads[bead_index], trans_axes, highest_level - ) - weighted_torques[bead_index][timestep_index] = ( - self.get_weighted_torques( - data_container, list_of_beads[bead_index], rot_axes - ) - ) - - # Make covariance matrices - looping over pairs of beads - # list of pairs of indices - pair_list = [(i, j) for i in range(number_beads) for j in range(number_beads)] + # Set up axes + # translation and rotation use different axes + # how the axes are defined depends on the level + trans_axes, rot_axes = self.get_axes(data_container, level, bead_index) + + # Sort out coordinates, forces, and torques for each atom in the bead + weighted_forces[bead_index] = self.get_weighted_forces( + data_container, list_of_beads[bead_index], trans_axes, highest_level + ) + weighted_torques[bead_index] = self.get_weighted_torques( + data_container, list_of_beads[bead_index], rot_axes + ) + # Create covariance submatrices force_submatrix = [ - [0 for x in range(number_beads)] for y in range(number_beads) + [0 for _ in range(number_beads)] for _ in range(number_beads) ] torque_submatrix = [ - [0 for x in range(number_beads)] for y in range(number_beads) + [0 for _ in range(number_beads)] for _ in range(number_beads) ] - for i, j in pair_list: - # for each pair of beads - # reducing effort because the matrix for [i][j] is the transpose of the one - # for [j][i] - if i <= j: - # calculate the force covariance segment of the matrix - force_submatrix[i][j] = self.create_submatrix( + for i in range(number_beads): + for j in range(i, number_beads): + f_sub = self.create_submatrix( weighted_forces[i], weighted_forces[j], number_frames ) - force_submatrix[j][i] = np.transpose(force_submatrix[i][j]) - - # calculate the torque covariance segment of the matrix - torque_submatrix[i][j] = self.create_submatrix( + t_sub = self.create_submatrix( weighted_torques[i], weighted_torques[j], number_frames ) - torque_submatrix[j][i] = np.transpose(torque_submatrix[i][j]) + force_submatrix[i][j] = f_sub + force_submatrix[j][i] = f_sub.T + torque_submatrix[i][j] = t_sub + torque_submatrix[j][i] = t_sub.T - # use np.block to make submatrices into one matrix - force_matrix = np.block( + # Convert block matrices to full matrix + force_block = np.block( [ [force_submatrix[i][j] for j in range(number_beads)] for i in range(number_beads) ] ) - - torque_matrix = np.block( + torque_block = np.block( [ [torque_submatrix[i][j] for j in range(number_beads)] for i in range(number_beads) ] ) - # fliter zeros to remove any rows/columns that are all zero - force_matrix = self.filter_zero_rows_columns(force_matrix) - torque_matrix = self.filter_zero_rows_columns(torque_matrix) + # Enforce consistent shape before accumulation + if force_matrix is None: + force_matrix = np.zeros_like(force_block) + elif force_matrix.shape != force_block.shape: + raise ValueError( + f"Inconsistent force matrix shape: existing " + f"{force_matrix.shape}, new {force_block.shape}" + ) + else: + force_matrix += force_block - logger.debug(f"Force Matrix: {force_matrix}") - logger.debug(f"Torque Matrix: {torque_matrix}") + if torque_matrix is None: + torque_matrix = np.zeros_like(torque_block) + elif torque_matrix.shape != torque_block.shape: + raise ValueError( + f"Inconsistent torque matrix shape: existing " + f"{torque_matrix.shape}, new {torque_block.shape}" + ) + else: + torque_matrix += torque_block return force_matrix, torque_matrix @@ -253,6 +261,50 @@ def get_dihedrals(self, data_container, level): return dihedrals + def compute_dihedral_conformations( + self, + selector, + level, + number_frames, + bin_width, + start, + end, + step, + ce, + ): + """ + Compute dihedral conformations for a given selector and entropy level. + + Parameters: + selector (AtomGroup): Atom selection to compute dihedrals for. + level (str): Entropy level ("united_atom" or "residue"). + number_frames (int): Number of frames to process. + bin_width (float): Bin width for dihedral angle discretization. + start (int): Start frame index. + end (int): End frame index. + step (int): Step size for frame iteration. + + Returns: + tuple: A tuple containing: + - states (list): List of conformation strings per frame. + - dihedrals (list): List of dihedral angle definitions. + """ + dihedrals = self.get_dihedrals(selector, level) + num_dihedrals = len(dihedrals) + + conformation = np.zeros((num_dihedrals, number_frames)) + for i, dihedral in enumerate(dihedrals): + conformation[i] = ce.assign_conformation( + selector, dihedral, number_frames, bin_width, start, end, step + ) + + states = [ + "".join(str(int(conformation[d][f])) for d in range(num_dihedrals)) + for f in range(number_frames) + ] + + return states + def get_beads(self, data_container, level): """ Function to define beads depending on the level in the hierarchy. @@ -324,7 +376,7 @@ def get_axes(self, data_container, level, index=0): trans_axes = data_container.atoms.principal_axes() rot_axes = data_container.atoms.principal_axes() - if level == "residue": + elif level == "residue": # Translation # for residues use principal axes of whole molecule for translation trans_axes = data_container.atoms.principal_axes() @@ -356,7 +408,7 @@ def get_axes(self, data_container, level, index=0): # use spherical coordinates function to get rotational axes rot_axes = self.get_sphCoord_axes(vector) - if level == "united_atom": + elif level == "united_atom": # Translation # for united atoms use principal axes of residue for translation trans_axes = data_container.residues.principal_axes() @@ -367,15 +419,19 @@ def get_axes(self, data_container, level, index=0): f"not name H* and bonded index {index}" ) - # center at position of heavy atom - atom_group = data_container.select_atoms(f"index {index}") - center = atom_group.positions[0] + if len(atom_set) == 0: + # if no bonds to other residues use pricipal axes of residue + rot_axes = data_container.residues.principal_axes() + else: + # center at position of heavy atom + atom_group = data_container.select_atoms(f"index {index}") + center = atom_group.positions[0] - # get vector for average position of hydrogens - vector = self.get_avg_pos(atom_set, center) + # get vector for average position of hydrogens + vector = self.get_avg_pos(atom_set, center) - # use spherical coordinates function to get rotational axes - rot_axes = self.get_sphCoord_axes(vector) + # use spherical coordinates function to get rotational axes + rot_axes = self.get_sphCoord_axes(vector) logger.debug(f"Translational Axes: {trans_axes}") logger.debug(f"Rotational Axes: {rot_axes}") @@ -658,17 +714,137 @@ def create_submatrix(self, data_i, data_j, number_frames): # For each frame calculate the outer product (cross product) of the data from # the two beads and add the result to the submatrix - for frame in range(number_frames): - outer_product_matrix = np.outer(data_i[frame], data_j[frame]) - submatrix = np.add(submatrix, outer_product_matrix) - - # Divide by the number of frames to get the average - submatrix /= number_frames + outer_product_matrix = np.outer(data_i, data_j) + submatrix = np.add(submatrix, outer_product_matrix) logger.debug(f"Submatrix: {submatrix}") return submatrix + def build_covariance_matrices( + self, + entropy_manager, + reduced_atom, + levels, + groups, + start, + end, + step, + number_frames, + ): + """ + Construct force and torque covariance matrices for all molecules and levels. + + Parameters: + entropy_manager (EntropyManager): Instance of the EntropyManager + reduced_atom (Universe): The reduced atom selection. + number_molecules (int): Number of molecules in the system. + levels (list): List of entropy levels per molecule. + start (int): Start frame index. + end (int): End frame index. + step (int): Step size for frame iteration. + number_frames (int): Total number of frames to process. + + Returns: + tuple: A tuple containing: + - force_matrices (dict): Force covariance matrices by level. + - torque_matrices (dict): Torque covariance matrices by level. + """ + number_groups = len(groups) + force_matrices = { + "ua": {}, + "res": [None] * number_groups, + "poly": [None] * number_groups, + } + torque_matrices = { + "ua": {}, + "res": [None] * number_groups, + "poly": [None] * number_groups, + } + + for timestep in reduced_atom.trajectory[start:end:step]: + time_index = timestep.frame - start + + for group_id in groups.keys(): + molecules = groups[group_id] + for mol_id in molecules: + mol = entropy_manager._get_molecule_container(reduced_atom, mol_id) + for level in levels[mol_id]: + self.update_force_torque_matrices( + entropy_manager, + mol, + group_id, + level, + levels[mol_id], + time_index, + number_frames, + force_matrices, + torque_matrices, + ) + + return force_matrices, torque_matrices + + def update_force_torque_matrices( + self, + entropy_manager, + mol, + group_id, + level, + level_list, + time_index, + num_frames, + force_matrices, + torque_matrices, + ): + """ + Update force and torque matrices for a given molecule and entropy level. + + Parameters: + entropy_manager (EntropyManager): Instance of the EntropyManager + mol (AtomGroup): The molecule to process. + group_id (int): Index of the group. + level (str): Current entropy level ("united_atom", "residue", or "polymer"). + level_list (list): List of levels for the molecule. + time_index (int): Index of the current frame. + num_frames (int): Total number of frames. + force_matrices (dict): Dictionary of force matrices to update. + torque_matrices (dict): Dictionary of torque matrices to update. + """ + highest = level == level_list[-1] + + if level == "united_atom": + for res_id, residue in enumerate(mol.residues): + key = (group_id, res_id) + res = entropy_manager._run_manager.new_U_select_atom( + mol, f"index {residue.atoms.indices[0]}:{residue.atoms.indices[-1]}" + ) + res.trajectory[time_index] + + f_mat, t_mat = self.get_matrices( + res, + level, + num_frames, + highest, + force_matrices["ua"].get(key), + torque_matrices["ua"].get(key), + ) + force_matrices["ua"][key] = f_mat + torque_matrices["ua"][key] = t_mat + + elif level in ["residue", "polymer"]: + mol.trajectory[time_index] + key = "res" if level == "residue" else "poly" + f_mat, t_mat = self.get_matrices( + mol, + level, + num_frames, + highest, + force_matrices[key][group_id], + torque_matrices[key][group_id], + ) + force_matrices[key][group_id] = f_mat + torque_matrices[key][group_id] = t_mat + def filter_zero_rows_columns(self, arg_matrix): """ function for removing rows and columns that contain only zeros from a matrix @@ -718,3 +894,96 @@ def filter_zero_rows_columns(self, arg_matrix): logger.debug(f"arg_matrix: {arg_matrix}") return arg_matrix + + def build_conformational_states( + self, + entropy_manager, + reduced_atom, + levels, + groups, + start, + end, + step, + number_frames, + bin_width, + ce, + ): + """ + Construct the conformational states for each molecule at + relevant levels. + + Parameters: + entropy_manager (EntropyManager): Instance of the EntropyManager + reduced_atom (Universe): The reduced atom selection. + levels (list): List of entropy levels per molecule. + start (int): Start frame index. + end (int): End frame index. + step (int): Step size for frame iteration. + number_frames (int): Total number of frames to process. + + Returns: + tuple: A tuple containing: + - states_ua (dict): Conformational states at the united-atom level. + - states_res (list): Conformational states at the residue level. + """ + number_groups = len(groups) + states_ua = {} + states_res = [None] * number_groups + + for group_id in groups.keys(): + molecules = groups[group_id] + for mol_id in molecules: + mol = entropy_manager._get_molecule_container(reduced_atom, mol_id) + for level in levels[mol_id]: + if level == "united_atom": + for res_id, residue in enumerate(mol.residues): + key = (group_id, res_id) + + res_container = ( + entropy_manager._run_manager.new_U_select_atom( + mol, + f"index {residue.atoms.indices[0]}:" + f"{residue.atoms.indices[-1]}", + ) + ) + heavy_res = entropy_manager._run_manager.new_U_select_atom( + res_container, "not name H*" + ) + + states = self.compute_dihedral_conformations( + heavy_res, + level, + number_frames, + bin_width, + start, + end, + step, + ce, + ) + + if key in states_ua.keys(): + states_ua[key].append(states) + else: + states_ua[key] = states + + if level == "res": + states = self.compute_dihedral_conformations( + mol, + level, + number_frames, + bin_width, + start, + end, + step, + ce, + ) + + if states_res[group_id] is None: + states_res[group_id] = states + else: + states_res[group_id] += states + + logger.debug(f"states_ua {states_ua}") + logger.debug(f"states_res {states_res}") + + return states_ua, states_res diff --git a/CodeEntropy/run.py b/CodeEntropy/run.py index e89d259..f90f56d 100644 --- a/CodeEntropy/run.py +++ b/CodeEntropy/run.py @@ -10,6 +10,7 @@ from CodeEntropy.config.data_logger import DataLogger from CodeEntropy.config.logging_config import LoggingConfig from CodeEntropy.entropy import EntropyManager +from CodeEntropy.group_molecules import GroupMolecules from CodeEntropy.levels import LevelManager logger = logging.getLogger(__name__) @@ -140,6 +141,9 @@ def run_entropy_workflow(self): # Create LevelManager instance level_manager = LevelManager() + # Create GroupMolecules instance + group_molecules = GroupMolecules() + # Inject all dependencies into EntropyManager entropy_manager = EntropyManager( run_manager=self, @@ -147,6 +151,7 @@ def run_entropy_workflow(self): universe=u, data_logger=self._data_logger, level_manager=level_manager, + group_molecules=group_molecules, ) entropy_manager.execute() diff --git a/tests/test_CodeEntropy/test_entropy.py b/tests/test_CodeEntropy/test_entropy.py index dd31cf6..3336ce2 100644 --- a/tests/test_CodeEntropy/test_entropy.py +++ b/tests/test_CodeEntropy/test_entropy.py @@ -48,11 +48,7 @@ def tearDown(self): shutil.rmtree(self.test_dir) def test_execute_full_workflow(self): - """ - Tests that `execute` runs the full entropy workflow for a known system, - triggering all processing branches and logging expected results. - """ - # Load test universe + # Setup universe and args as before tprfile = os.path.join(self.test_data_dir, "md_A4_dna.tpr") trrfile = os.path.join(self.test_data_dir, "md_A4_dna_xf.trr") u = mda.Universe(tprfile, trrfile) @@ -63,126 +59,150 @@ def test_execute_full_workflow(self): run_manager = RunManager("temp_folder") level_manager = LevelManager() data_logger = DataLogger() + group_molecules = MagicMock() entropy_manager = EntropyManager( - run_manager, args, u, data_logger, level_manager + run_manager, args, u, data_logger, level_manager, group_molecules ) + # Mocks for trajectory and molecules entropy_manager._get_trajectory_bounds = MagicMock(return_value=(0, 10, 1)) entropy_manager._get_number_frames = MagicMock(return_value=11) - entropy_manager._get_reduced_universe = MagicMock( - return_value="reduced_universe" + entropy_manager._handle_water_entropy = MagicMock() + + mock_reduced_atom = MagicMock() + mock_reduced_atom.trajectory = [1] * 11 + + mock_groups = {0: [0], 1: [1], 2: [2]} + mock_levels = { + 0: ["united_atom", "polymer", "residue"], + 1: ["united_atom", "polymer", "residue"], + 2: ["united_atom", "polymer", "residue"], + } + + entropy_manager._initialize_molecules = MagicMock( + return_value=(mock_reduced_atom, 3, mock_levels, mock_groups) + ) + entropy_manager._level_manager.build_covariance_matrices = MagicMock( + return_value=("force_matrices", "torque_matrices") ) - entropy_manager._get_molecule_container = MagicMock( - return_value=MagicMock(residues=[1, 2, 3]) + entropy_manager._level_manager.build_conformational_states = MagicMock( + return_value=(["state_ua"], ["state_res"]) ) + entropy_manager._compute_entropies = MagicMock() entropy_manager._finalize_molecule_results = MagicMock() entropy_manager._data_logger.log_tables = MagicMock() - entropy_manager._level_manager.select_levels = MagicMock( - return_value=(1, [["united_atom", "polymer", "residue"]]) - ) - - # Patch entropy classes and processing methods + # Create mocks for VibrationalEntropy and ConformationalEntropy ve = MagicMock() ce = MagicMock() + with ( patch("CodeEntropy.entropy.VibrationalEntropy", return_value=ve), patch("CodeEntropy.entropy.ConformationalEntropy", return_value=ce), ): - entropy_manager._process_united_atom_level = MagicMock( - side_effect=lambda *args, **kwargs: data_logger.add_results_data( - "A", "united_atom", "Conformational", 1.0 - ) - ) - entropy_manager._process_vibrational_only_levels = MagicMock( - side_effect=lambda *args, **kwargs: data_logger.add_results_data( - "A", "polymer", "Transvibrational", 2.0 - ) - ) - entropy_manager._process_conformational_residue_level = MagicMock( - side_effect=lambda *args, **kwargs: data_logger.add_residue_data( - 0, "A", "residue", "Conformational", 3.0 - ) - ) - entropy_manager.execute() - # Assertions - entropy_manager._process_united_atom_level.assert_called_once() - self.assertEqual(entropy_manager._process_vibrational_only_levels.call_count, 2) - entropy_manager._process_conformational_residue_level.assert_called_once() - entropy_manager._finalize_molecule_results.assert_called_once() - entropy_manager._data_logger.log_tables.assert_called_once() + # Assert the key calls happened with expected arguments + ( + entropy_manager._level_manager.build_conformational_states + ).assert_called_once_with( + entropy_manager, + mock_reduced_atom, + mock_levels, + mock_groups, + 0, + 10, + 1, + 11, + args.bin_width, + ce, + ) - # Check molecule-level entropy types - molecule_types = set(entry[2] for entry in data_logger.molecule_data) - self.assertIn("Conformational", molecule_types) - self.assertIn("Transvibrational", molecule_types) + entropy_manager._compute_entropies.assert_called_once_with( + mock_reduced_atom, + mock_levels, + mock_groups, + "force_matrices", + "torque_matrices", + ["state_ua"], + ["state_res"], + 11, + ve, + ce, + ) - # Check residue-level entropy types - residue_types = set(entry[3] for entry in data_logger.residue_data) - self.assertIn("Conformational", residue_types) + entropy_manager._finalize_molecule_results.assert_called_once() + entropy_manager._data_logger.log_tables.assert_called_once() def test_water_entropy_sets_selection_string_when_all(self): """ Tests that when `selection_string` is initially 'all' and water entropy is - enabled, the `execute` method sets `selection_string` to 'not water' after + enabled, `_handle_water_entropy` sets `selection_string` to 'not water' after calculating water entropy. """ mock_universe = MagicMock() - mock_universe.select_atoms.return_value.n_atoms = 5 + mock_universe.select_atoms.return_value.n_atoms = 5 # Simulate water present args = MagicMock(water_entropy=True, selection_string="all") run_manager = MagicMock() level_manager = MagicMock() data_logger = DataLogger() + group_molecules = MagicMock() manager = EntropyManager( - run_manager, args, mock_universe, data_logger, level_manager + run_manager, + args, + mock_universe, + data_logger, + level_manager, + group_molecules, ) - manager._get_trajectory_bounds = MagicMock(return_value=(0, 10, 1)) - manager._get_number_frames = MagicMock(return_value=11) + + # Patch water entropy calculation manager._calculate_water_entropy = MagicMock() - manager._get_reduced_universe = MagicMock(return_value="reduced") - manager._level_manager.select_levels = MagicMock(return_value=(0, [])) - manager._finalize_molecule_results = MagicMock() - manager._data_logger.log_tables = MagicMock() - manager.execute() + # Call _handle_water_entropy directly + manager._handle_water_entropy(0, 10, 1) - manager._calculate_water_entropy.assert_called_once() - assert args.selection_string == "not water" + manager._calculate_water_entropy.assert_called_once_with( + mock_universe, 0, 10, 1 + ) + self.assertEqual(args.selection_string, "not water") def test_water_entropy_appends_to_custom_selection_string(self): """ Tests that when `selection_string` is a custom value and water - entropy is enabled, the `execute` method appends ' and not water' - to the existing selection string after calculating water entropy. + entropy is enabled, `_handle_water_entropy` appends ' and not water' + to the existing selection string. """ mock_universe = MagicMock() - mock_universe.select_atoms.return_value.n_atoms = 5 + mock_universe.select_atoms.return_value.n_atoms = 5 # Simulate water present args = MagicMock(water_entropy=True, selection_string="protein") run_manager = MagicMock() level_manager = MagicMock() data_logger = DataLogger() + group_molecules = MagicMock() manager = EntropyManager( - run_manager, args, mock_universe, data_logger, level_manager + run_manager, + args, + mock_universe, + data_logger, + level_manager, + group_molecules, ) - manager._get_trajectory_bounds = MagicMock(return_value=(0, 10, 1)) - manager._get_number_frames = MagicMock(return_value=11) + manager._calculate_water_entropy = MagicMock() - manager._get_reduced_universe = MagicMock(return_value="reduced") - manager._level_manager.select_levels = MagicMock(return_value=(0, [])) - manager._finalize_molecule_results = MagicMock() - manager._data_logger.log_tables = MagicMock() - manager.execute() + # Call _handle_water_entropy directly + manager._handle_water_entropy(0, 10, 1) - manager._calculate_water_entropy.assert_called_once() - assert args.selection_string == "protein and not water" + manager._calculate_water_entropy.assert_called_once_with( + mock_universe, 0, 10, 1 + ) + self.assertEqual(args.selection_string, "protein and not water") def test_get_trajectory_bounds(self): """ @@ -195,7 +215,7 @@ def test_get_trajectory_bounds(self): args, _ = parser.parse_known_args() entropy_manager = EntropyManager( - MagicMock(), args, MagicMock(), MagicMock(), MagicMock() + MagicMock(), args, MagicMock(), MagicMock(), MagicMock(), MagicMock() ) self.assertIsInstance(entropy_manager._args.start, int) @@ -225,7 +245,7 @@ def test_get_number_frames(self, mock_args): args = parser.parse_args() entropy_manager = EntropyManager( - MagicMock(), args, MagicMock(), MagicMock(), MagicMock() + MagicMock(), args, MagicMock(), MagicMock(), MagicMock(), MagicMock() ) entropy_manager._get_trajectory_bounds() number_frames = entropy_manager._get_number_frames( @@ -257,7 +277,7 @@ def test_get_number_frames_sliced_trajectory(self, mock_args): args = parser.parse_args() entropy_manager = EntropyManager( - MagicMock(), args, MagicMock(), MagicMock(), MagicMock() + MagicMock(), args, MagicMock(), MagicMock(), MagicMock(), MagicMock() ) entropy_manager._get_trajectory_bounds() number_frames = entropy_manager._get_number_frames( @@ -290,7 +310,7 @@ def test_get_number_frames_sliced_trajectory_step(self, mock_args): args = parser.parse_args() entropy_manager = EntropyManager( - MagicMock(), args, MagicMock(), MagicMock(), MagicMock() + MagicMock(), args, MagicMock(), MagicMock(), MagicMock(), MagicMock() ) entropy_manager._get_trajectory_bounds() number_frames = entropy_manager._get_number_frames( @@ -324,7 +344,9 @@ def test_get_reduced_universe_all(self, mock_args): parser = config_manager.setup_argparse() args = parser.parse_args() - entropy_manager = EntropyManager(MagicMock(), args, u, MagicMock(), MagicMock()) + entropy_manager = EntropyManager( + MagicMock(), args, u, MagicMock(), MagicMock(), MagicMock() + ) entropy_manager._get_reduced_universe() @@ -355,7 +377,9 @@ def test_get_reduced_universe_reduced(self, mock_args): parser = config_manager.setup_argparse() args = parser.parse_args() - entropy_manager = EntropyManager(run_manager, args, u, MagicMock(), MagicMock()) + entropy_manager = EntropyManager( + run_manager, args, u, MagicMock(), MagicMock(), MagicMock() + ) reduced_u = entropy_manager._get_reduced_universe() @@ -391,7 +415,9 @@ def test_get_molecule_container(self, mock_args): parser = config_manager.setup_argparse() args = parser.parse_args() - entropy_manager = EntropyManager(run_manager, args, u, MagicMock(), MagicMock()) + entropy_manager = EntropyManager( + run_manager, args, u, MagicMock(), MagicMock(), MagicMock() + ) # Call the method molecule_id = 0 @@ -409,8 +435,8 @@ def test_get_molecule_container(self, mock_args): def test_process_united_atom_level(self): """ - Tests that `_process_united_atom_level` correctly logs global and residue-level - entropy results for a known molecular system using MDAnalysis. + Tests that `_process_united_atom_entropy` correctly logs global and + residue-level entropy results for a known molecular system using MDAnalysis. """ # Load a known test universe @@ -423,43 +449,53 @@ def test_process_united_atom_level(self): run_manager = RunManager("temp_folder") level_manager = LevelManager() data_logger = DataLogger() - manager = EntropyManager(run_manager, args, u, data_logger, level_manager) + group_molecules = MagicMock() + manager = EntropyManager( + run_manager, args, u, data_logger, level_manager, group_molecules + ) + # Prepare mock molecule container reduced_atom = manager._get_reduced_universe() mol_container = manager._get_molecule_container(reduced_atom, 0) n_residues = len(mol_container.residues) - ve = VibrationalEntropy(run_manager, args, u, data_logger, level_manager) - ce = ConformationalEntropy(run_manager, args, u, data_logger, level_manager) + # Create dummy matrices and states + force_matrix = {(0, i): np.eye(3) for i in range(n_residues)} + torque_matrix = {(0, i): np.eye(3) * 2 for i in range(n_residues)} + states = {(0, i): np.ones((10, 3)) for i in range(n_residues)} + + # Mock entropy calculators + ve = MagicMock() + ce = MagicMock() + ve.vibrational_entropy_calculation.side_effect = lambda m, t, temp, high: ( + 1.0 if t == "force" else 2.0 + ) + ce.conformational_entropy_calculation.return_value = 3.0 - # Run the function - manager._process_united_atom_level( - mol_id=0, + # Run the method + manager._process_united_atom_entropy( + group_id=0, mol_container=mol_container, ve=ve, ce=ce, level="united_atom", - start=1, - end=1, - step=1, - n_frames=1, + force_matrix=force_matrix, + torque_matrix=torque_matrix, + states=states, highest=True, + number_frames=10, ) - # Check that results were logged for each entropy type + # Check molecule-level results df = data_logger.molecule_data self.assertEqual(len(df), 3) # Trans, Rot, Conf - # Check that residue-level results were logged + # Check residue-level results residue_df = data_logger.residue_data self.assertEqual(len(residue_df), 3 * n_residues) # 3 types per residue # Check that all expected types are present - expected_types = { - "Transvibrational", - "Rovibrational", - "Conformational", - } + expected_types = {"Transvibrational", "Rovibrational", "Conformational"} actual_types = set(entry[2] for entry in df) self.assertSetEqual(actual_types, expected_types) @@ -469,10 +505,9 @@ def test_process_united_atom_level(self): def test_process_vibrational_only_levels(self): """ - Tests that `_process_vibrational_only_levels` correctly logs vibrational + Tests that `_process_vibrational_entropy` correctly logs vibrational entropy results for a known molecular system using MDAnalysis. """ - # Load a known test universe tprfile = os.path.join(self.test_data_dir, "md_A4_dna.tpr") trrfile = os.path.join(self.test_data_dir, "md_A4_dna_xf.trr") @@ -483,29 +518,34 @@ def test_process_vibrational_only_levels(self): run_manager = RunManager("temp_folder") level_manager = LevelManager() data_logger = DataLogger() - manager = EntropyManager(run_manager, args, u, data_logger, level_manager) + group_molecules = MagicMock() + manager = EntropyManager( + run_manager, args, u, data_logger, level_manager, group_molecules + ) + # Prepare mock molecule container reduced_atom = manager._get_reduced_universe() mol_container = manager._get_molecule_container(reduced_atom, 0) - # Patch methods to isolate the test - manager._level_manager.get_matrices = MagicMock( - return_value=("mock_force", "mock_torque") - ) + # Simulate trajectory length + mol_container.trajectory = [None] * 10 # 10 frames - ve = VibrationalEntropy(run_manager, args, u, data_logger, level_manager) - ve.vibrational_entropy_calculation = MagicMock(side_effect=[1.11, 2.22]) + # Create dummy matrices + force_matrix = np.eye(3) + torque_matrix = np.eye(3) * 2 - # Run the function - manager._process_vibrational_only_levels( - mol_id=0, - mol_container=mol_container, + # Mock entropy calculator + ve = MagicMock() + ve.vibrational_entropy_calculation.side_effect = [1.11, 2.22] + + # Run the method + manager._process_vibrational_entropy( + group_id=0, + number_frames=10, ve=ve, level="Vibrational", - start=1, - end=1, - step=1, - n_frames=1, + force_matrix=force_matrix, + torque_matrix=torque_matrix, highest=True, ) @@ -513,25 +553,70 @@ def test_process_vibrational_only_levels(self): df = data_logger.molecule_data self.assertEqual(len(df), 2) # Transvibrational and Rovibrational - expected_types = { - "Transvibrational", - "Rovibrational", - } + expected_types = {"Transvibrational", "Rovibrational"} actual_types = set(entry[2] for entry in df) self.assertSetEqual(actual_types, expected_types) - # Check entropy values results = [entry[3] for entry in df] self.assertIn(1.11, results) self.assertIn(2.22, results) + def test_compute_entropies_polymer_branch(self): + """ + Test _compute_entropies triggers _process_vibrational_entropy for 'polymer' + level. + """ + args = MagicMock(bin_width=0.1) + run_manager = MagicMock() + level_manager = MagicMock() + data_logger = DataLogger() + group_molecules = MagicMock() + manager = EntropyManager( + run_manager, args, MagicMock(), data_logger, level_manager, group_molecules + ) + + reduced_atom = MagicMock() + number_frames = 5 + groups = {0: [0]} # One molecule only + levels = [["polymer"]] # One level for that molecule + + force_matrices = {"poly": {0: np.eye(3)}} + torque_matrices = {"poly": {0: np.eye(3) * 2}} + states_ua = {} + states_res = [] + + mol_mock = MagicMock() + mol_mock.residues = [] + manager._get_molecule_container = MagicMock(return_value=mol_mock) + manager._process_vibrational_entropy = MagicMock() + + ve = MagicMock() + ve.vibrational_entropy_calculation.side_effect = [1.11] + + ce = MagicMock() + ce.conformational_entropy_calculation.return_value = 3.33 + + manager._compute_entropies( + reduced_atom, + levels, + groups, + force_matrices, + torque_matrices, + states_ua, + states_res, + number_frames, + ve, + ce, + ) + + manager._process_vibrational_entropy.assert_called_once() + def test_process_conformational_residue_level(self): """ - Tests that `_process_conformational_residue_level` correctly logs conformational + Tests that `_process_conformational_entropy` correctly logs conformational entropy results at the residue level for a known molecular system using MDAnalysis. """ - # Load a known test universe tprfile = os.path.join(self.test_data_dir, "md_A4_dna.tpr") trrfile = os.path.join(self.test_data_dir, "md_A4_dna_xf.trr") @@ -542,28 +627,25 @@ def test_process_conformational_residue_level(self): run_manager = RunManager("temp_folder") level_manager = LevelManager() data_logger = DataLogger() - manager = EntropyManager(run_manager, args, u, data_logger, level_manager) - - reduced_atom = manager._get_reduced_universe() - mol_container = manager._get_molecule_container(reduced_atom, 0) + group_molecules = MagicMock() + manager = EntropyManager( + run_manager, args, u, data_logger, level_manager, group_molecules + ) - # Patch methods to isolate the test - mock_dihedrals = ["phi", "psi", "chi1"] - manager._level_manager.get_dihedrals = MagicMock(return_value=mock_dihedrals) + # Create dummy states + states = {0: np.ones((10, 3))} - ce = ConformationalEntropy(run_manager, args, u, data_logger, level_manager) - ce.conformational_entropy_calculation = MagicMock(return_value=3.33) + # Mock entropy calculator + ce = MagicMock() + ce.conformational_entropy_calculation.return_value = 3.33 - # Run the function - manager._process_conformational_residue_level( - mol_id=0, - mol_container=mol_container, + # Run the method + manager._process_conformational_entropy( + group_id=0, ce=ce, level="residue", - start=1, - end=1, - step=1, - n_frames=1, + states=states, + number_frames=10, ) # Check that results were logged @@ -574,7 +656,6 @@ def test_process_conformational_residue_level(self): actual_types = set(entry[2] for entry in df) self.assertSetEqual(actual_types, expected_types) - # Check entropy values results = [entry[3] for entry in df] self.assertIn(3.33, results) @@ -595,7 +676,7 @@ def test_finalize_molecule_results_aggregates_and_logs_total_entropy(self): ] data_logger.residue_data = [] - manager = EntropyManager(None, args, None, data_logger, None) + manager = EntropyManager(None, args, None, data_logger, None, None) # Patch save method data_logger.save_dataframes_as_json = MagicMock() @@ -638,7 +719,7 @@ def test_finalize_molecule_results_skips_invalid_entries(self, mock_logger): ] data_logger.residue_data = [] - manager = EntropyManager(None, args, None, data_logger, None) + manager = EntropyManager(None, args, None, data_logger, None, None) # Patch save method data_logger.save_dataframes_as_json = MagicMock() @@ -677,7 +758,7 @@ def setUp(self): os.chdir(self.test_dir) self.entropy_manager = EntropyManager( - MagicMock(), MagicMock(), MagicMock(), MagicMock(), MagicMock() + MagicMock(), MagicMock(), MagicMock(), MagicMock(), MagicMock(), MagicMock() ) def tearDown(self): @@ -705,9 +786,12 @@ def test_vibrational_entropy_init(self): run_manager = RunManager("temp_folder") level_manager = LevelManager() data_logger = DataLogger() + group_molecules = MagicMock() # Instantiate VibrationalEntropy - ve = VibrationalEntropy(run_manager, args, universe, data_logger, level_manager) + ve = VibrationalEntropy( + run_manager, args, universe, data_logger, level_manager, group_molecules + ) # Basic assertions to check initialization self.assertIsInstance(ve, VibrationalEntropy) @@ -727,7 +811,7 @@ def test_frequency_calculation_0(self): run_manager = RunManager("mock_folder") ve = VibrationalEntropy( - run_manager, MagicMock(), MagicMock(), MagicMock(), MagicMock() + run_manager, MagicMock(), MagicMock(), MagicMock(), MagicMock(), MagicMock() ) frequencies = ve.frequency_calculation(lambdas, temp) @@ -748,7 +832,7 @@ def test_frequency_calculation_positive(self): # Instantiate VibrationalEntropy with mocks ve = VibrationalEntropy( - run_manager, MagicMock(), MagicMock(), MagicMock(), MagicMock() + run_manager, MagicMock(), MagicMock(), MagicMock(), MagicMock(), MagicMock() ) # Call the method under test @@ -774,7 +858,7 @@ def test_frequency_calculation_negative(self): # Instantiate VibrationalEntropy with mocks ve = VibrationalEntropy( - run_manager, MagicMock(), MagicMock(), MagicMock(), MagicMock() + run_manager, MagicMock(), MagicMock(), MagicMock(), MagicMock(), MagicMock() ) # Assert that ValueError is raised due to negative eigenvalue @@ -798,7 +882,7 @@ def test_vibrational_entropy_calculation_force_not_highest(self): # Instantiate VibrationalEntropy with mocks ve = VibrationalEntropy( - run_manager, MagicMock(), MagicMock(), MagicMock(), MagicMock() + run_manager, MagicMock(), MagicMock(), MagicMock(), MagicMock(), MagicMock() ) # Patch frequency_calculation to return known frequencies @@ -843,7 +927,7 @@ def test_vibrational_entropy_polymer_force(self): run_manager = RunManager("mock_folder") ve = VibrationalEntropy( - run_manager, MagicMock(), MagicMock(), MagicMock(), MagicMock() + run_manager, MagicMock(), MagicMock(), MagicMock(), MagicMock(), MagicMock() ) S_vib = ve.vibrational_entropy_calculation( @@ -873,7 +957,7 @@ def test_vibrational_entropy_polymer_torque(self): run_manager = RunManager("mock_folder") ve = VibrationalEntropy( - run_manager, MagicMock(), MagicMock(), MagicMock(), MagicMock() + run_manager, MagicMock(), MagicMock(), MagicMock(), MagicMock(), MagicMock() ) S_vib = ve.vibrational_entropy_calculation( @@ -1086,10 +1170,11 @@ def test_confirmational_entropy_init(self): run_manager = RunManager("temp_folder") level_manager = LevelManager() data_logger = DataLogger() + group_molecules = MagicMock() # Instantiate ConformationalEntropy ce = ConformationalEntropy( - run_manager, args, universe, data_logger, level_manager + run_manager, args, universe, data_logger, level_manager, group_molecules ) # Basic assertions to check initialization @@ -1127,8 +1212,11 @@ def test_assign_conformation(self): run_manager = RunManager("temp_folder") level_manager = LevelManager() data_logger = DataLogger() + group_molecules = MagicMock() - ce = ConformationalEntropy(run_manager, args, u, data_logger, level_manager) + ce = ConformationalEntropy( + run_manager, args, u, data_logger, level_manager, group_molecules + ) result = ce.assign_conformation( data_container=data_container, @@ -1187,10 +1275,11 @@ def test_orientational_entropy_init(self): run_manager = RunManager("temp_folder") level_manager = LevelManager() data_logger = DataLogger() + group_molecules = MagicMock() # Instantiate OrientationalEntropy oe = OrientationalEntropy( - run_manager, args, universe, data_logger, level_manager + run_manager, args, universe, data_logger, level_manager, group_molecules ) # Basic assertions to check initialization @@ -1211,7 +1300,7 @@ def test_orientational_entropy_calculation(self): } # Create an instance of OrientationalEntropy with dummy dependencies - oe = OrientationalEntropy(None, None, None, None, None) + oe = OrientationalEntropy(None, None, None, None, None, None) # Run the method result = oe.orientational_entropy_calculation(neighbours_dict) @@ -1232,7 +1321,7 @@ def test_orientational_entropy_water_branch_is_covered(self): """ neighbours_dict = {"H2O": 1} # Matches the condition exactly - oe = OrientationalEntropy(None, None, None, None, None) + oe = OrientationalEntropy(None, None, None, None, None, None) result = oe.orientational_entropy_calculation(neighbours_dict) # Since the logic is skipped, total entropy should be 0.0 diff --git a/tests/test_CodeEntropy/test_group_molecules.py b/tests/test_CodeEntropy/test_group_molecules.py new file mode 100644 index 0000000..1914e5e --- /dev/null +++ b/tests/test_CodeEntropy/test_group_molecules.py @@ -0,0 +1,94 @@ +import os +import shutil +import tempfile +import unittest +from unittest.mock import MagicMock + +import numpy as np + +from CodeEntropy.group_molecules import GroupMolecules + + +class TestMain(unittest.TestCase): + """ + Unit tests for the functionality of GroupMolecules class. + """ + + def setUp(self): + """ + Set up a temporary directory as the working directory before each test. + Initialize GroupMolecules instance. + """ + self.test_dir = tempfile.mkdtemp(prefix="CodeEntropy_") + self._orig_dir = os.getcwd() + os.chdir(self.test_dir) + self.group_molecules = GroupMolecules() + + def tearDown(self): + """ + Clean up by removing the temporary directory and restoring the original working + directory. + """ + os.chdir(self._orig_dir) + shutil.rmtree(self.test_dir) + + def test_by_none_returns_individual_groups(self): + """ + Test _by_none returns each molecule in its own group when grouping is 'each'. + """ + mock_universe = MagicMock() + # Simulate universe.atoms.fragments has 3 molecules + mock_universe.atoms.fragments = [MagicMock(), MagicMock(), MagicMock()] + + groups = self.group_molecules._by_none(mock_universe) + expected = {0: [0], 1: [1], 2: [2]} + self.assertEqual(groups, expected) + + def test_by_molecules_groups_by_chemical_type(self): + """ + Test _by_molecules groups molecules with identical atom counts and names + together. + """ + mock_universe = MagicMock() + + fragment0 = MagicMock() + fragment0.names = np.array(["H", "O", "H"]) + fragment1 = MagicMock() + fragment1.names = np.array(["H", "O", "H"]) + fragment2 = MagicMock() + fragment2.names = np.array(["C", "C", "H", "H"]) + + mock_universe.atoms.fragments = [fragment0, fragment1, fragment2] + + groups = self.group_molecules._by_molecules(mock_universe) + + # Expect first two grouped, third separate + self.assertIn(0, groups) + self.assertIn(2, groups) + self.assertCountEqual(groups[0], [0, 1]) + self.assertEqual(groups[2], [2]) + + def test_grouping_molecules_dispatches_correctly(self): + """ + Test grouping_molecules method dispatches to correct grouping strategy. + """ + mock_universe = MagicMock() + mock_universe.atoms.fragments = [MagicMock()] # Just 1 molecule to keep simple + + # When grouping='each', calls _by_none + groups = self.group_molecules.grouping_molecules(mock_universe, "each") + self.assertEqual(groups, {0: [0]}) + + # When grouping='molecules', calls _by_molecules (mock to test call) + self.group_molecules._by_molecules = MagicMock(return_value={"mocked": [42]}) + groups = self.group_molecules.grouping_molecules(mock_universe, "molecules") + self.group_molecules._by_molecules.assert_called_once_with(mock_universe) + self.assertEqual(groups, {"mocked": [42]}) + + # If grouping unknown, should return empty dict + groups = self.group_molecules.grouping_molecules(mock_universe, "unknown") + self.assertEqual(groups, {}) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_CodeEntropy/test_levels.py b/tests/test_CodeEntropy/test_levels.py index 81166bd..734c7ef 100644 --- a/tests/test_CodeEntropy/test_levels.py +++ b/tests/test_CodeEntropy/test_levels.py @@ -108,11 +108,10 @@ def test_get_matrices(self): force_matrix, torque_matrix = level_manager.get_matrices( data_container=data_container, level="residue", - start=0, - end=2, - step=1, number_frames=2, - highest_level="polymer", + highest_level=True, + force_matrix=None, + torque_matrix=None, ) # Assertions @@ -123,11 +122,82 @@ def test_get_matrices(self): # Check that internal methods were called self.assertEqual(level_manager.get_beads.call_count, 1) - self.assertEqual(level_manager.get_axes.call_count, 4) # 2 beads × 2 frames + self.assertEqual(level_manager.get_axes.call_count, 2) # 2 beads self.assertEqual( level_manager.create_submatrix.call_count, 6 ) # 3 force + 3 torque + def test_get_matrices_force_shape_mismatch(self): + """ + Test that get_matrices raises a ValueError when the provided force_matrix + has a shape mismatch with the computed force block matrix. + """ + level_manager = LevelManager() + + # Mock internal methods + level_manager.get_beads = MagicMock(return_value=["bead1", "bead2"]) + level_manager.get_axes = MagicMock(return_value=("trans_axes", "rot_axes")) + level_manager.get_weighted_forces = MagicMock( + return_value=np.array([1.0, 2.0, 3.0]) + ) + level_manager.get_weighted_torques = MagicMock( + return_value=np.array([0.5, 1.5, 2.5]) + ) + level_manager.create_submatrix = MagicMock(return_value=np.identity(3)) + + data_container = MagicMock() + + # Incorrect shape for force matrix (should be 6x6 for 2 beads) + bad_force_matrix = np.zeros((3, 3)) + correct_torque_matrix = np.zeros((6, 6)) + + with self.assertRaises(ValueError) as context: + level_manager.get_matrices( + data_container=data_container, + level="residue", + number_frames=2, + highest_level=True, + force_matrix=bad_force_matrix, + torque_matrix=correct_torque_matrix, + ) + + self.assertIn("Inconsistent force matrix shape", str(context.exception)) + + def test_get_matrices_torque_shape_mismatch(self): + """ + Test that get_matrices raises a ValueError when the provided torque_matrix + has a shape mismatch with the computed torque block matrix. + """ + level_manager = LevelManager() + + # Mock internal methods + level_manager.get_beads = MagicMock(return_value=["bead1", "bead2"]) + level_manager.get_axes = MagicMock(return_value=("trans_axes", "rot_axes")) + level_manager.get_weighted_forces = MagicMock( + return_value=np.array([1.0, 2.0, 3.0]) + ) + level_manager.get_weighted_torques = MagicMock( + return_value=np.array([0.5, 1.5, 2.5]) + ) + level_manager.create_submatrix = MagicMock(return_value=np.identity(3)) + + data_container = MagicMock() + + correct_force_matrix = np.zeros((6, 6)) + bad_torque_matrix = np.zeros((3, 3)) # Incorrect shape + + with self.assertRaises(ValueError) as context: + level_manager.get_matrices( + data_container=data_container, + level="residue", + number_frames=2, + highest_level=True, + force_matrix=correct_force_matrix, + torque_matrix=bad_torque_matrix, + ) + + self.assertIn("Inconsistent torque matrix shape", str(context.exception)) + def test_get_dihedrals_united_atom(self): """ Test `get_dihedrals` for 'united_atom' level. @@ -247,6 +317,38 @@ def test_get_beads_united_atom_level(self): data_container.select_atoms.call_count, 4 ) # 1 for heavy_atoms + 3 beads + def test_get_axes_united_atom_no_bonds(self): + """ + Test `get_axes` for 'united_atom' level when no bonded atoms are found. + Ensures that rotational axes fall back to residues' principal axes. + """ + level_manager = LevelManager() + + data_container = MagicMock() + + # Mock principal axes for translation and rotation + mock_rot_axes = MagicMock(name="rot_axes") + + data_container.residues.principal_axes.return_value = mock_rot_axes + data_container.residues.principal_axes.return_value = mock_rot_axes + data_container.residues.principal_axes.return_value = mock_rot_axes # fallback + + # First select_atoms returns empty bonded atom set + atom_set = MagicMock() + atom_set.__len__.return_value = 0 # triggers fallback + + data_container.select_atoms.side_effect = [atom_set] + + trans_axes, rot_axes = level_manager.get_axes( + data_container=data_container, level="united_atom", index=5 + ) + + # Assertions + self.assertEqual(trans_axes, mock_rot_axes) + self.assertEqual(rot_axes, mock_rot_axes) + data_container.residues.principal_axes.assert_called() + self.assertEqual(data_container.select_atoms.call_count, 1) + def test_get_axes_polymer_level(self): """ Test `get_axes` for 'polymer' level. @@ -328,6 +430,8 @@ def test_get_axes_united_atom_level(self): data_container.residues.principal_axes.return_value = "trans_axes" atom_set = MagicMock() + atom_set.__len__.return_value = 1 + atom_group = MagicMock() atom_group.positions = [[1.0, 2.0, 3.0]] @@ -660,30 +764,31 @@ def test_get_weighted_torques_negative_moi_raises(self): "Negative value encountered for moment of inertia", str(context.exception) ) - def test_create_submatrix_basic_outer_product_average(self): + def test_create_submatrix_basic_outer_product(self): """ - Test with known vectors to verify correct average outer product. + Test with known vectors to verify correct outer product. """ level_manager = LevelManager() - data_i = [np.array([1, 0, 0]), np.array([0, 1, 0])] - data_j = [np.array([0, 1, 0]), np.array([1, 0, 0])] - number_frames = 2 - - expected = (np.outer(data_i[0], data_j[0]) + np.outer(data_i[1], data_j[1])) / 2 + data_i = np.array([1, 0, 0]) + data_j = np.array([0, 1, 0]) + number_frames = 1 # Not used in current implementation + expected = np.outer(data_i, data_j) result = level_manager.create_submatrix(data_i, data_j, number_frames) - np.testing.assert_array_almost_equal(result, expected) + + np.testing.assert_array_equal(result, expected) def test_create_submatrix_zero_vectors_returns_zero_matrix(self): """ - Test that all-zero input vectors should return a zero matrix. + Test that all-zero input vectors return a zero matrix. """ level_manager = LevelManager() - data_i = [np.zeros(3) for _ in range(3)] - data_j = [np.zeros(3) for _ in range(3)] - result = level_manager.create_submatrix(data_i, data_j, 3) + data_i = np.zeros(3) + data_j = np.zeros(3) + result = level_manager.create_submatrix(data_i, data_j, 1) + np.testing.assert_array_equal(result, np.zeros((3, 3))) def test_create_submatrix_single_frame(self): @@ -702,12 +807,13 @@ def test_create_submatrix_single_frame(self): def test_create_submatrix_symmetric_result_when_data_equal(self): """ - Test that if data_i == data_j, the result should be symmetric. + Test that if data_i == data_j, the result is symmetric. """ level_manager = LevelManager() - data = [np.array([1, 2, 3]), np.array([4, 5, 6])] - result = level_manager.create_submatrix(data, data, 2) + data = np.array([1, 2, 3]) + result = level_manager.create_submatrix(data, data, 1) + self.assertTrue(np.allclose(result, result.T)) # Check symmetry def test_filter_zero_rows_columns_no_zeros(self): @@ -752,3 +858,138 @@ def test_filter_zero_rows_columns_partial_zero_removal(self): expected = np.array([[1, 2, 3]]) result = level_manager.filter_zero_rows_columns(matrix) np.testing.assert_array_equal(result, expected) + + def test_build_conformational_states_united_atom_accumulates_states(self): + """ + Test that the 'build_conformational_states' method correctly accumulates + united atom level conformational states for multiple molecules within the + same group. + + Specifically, when called with two molecules in the same group, the method + should append the states returned for the second molecule to the list of + states for the first molecule, resulting in a nested list structure. + + Verifies: + - The states_ua dictionary accumulates states as a nested list. + - The compute_dihedral_conformations method is called once per molecule. + """ + level_manager = LevelManager() + entropy_manager = MagicMock() + reduced_atom = MagicMock() + ce = MagicMock() + + # Setup mock residue for molecules + residue = MagicMock() + residue.atoms.indices = [10, 11, 12] + + # Setup two mock molecules with the same residue + mol_0 = MagicMock() + mol_0.residues = [residue] + mol_1 = MagicMock() + mol_1.residues = [residue] + + # entropy_manager returns different molecules by mol_id + entropy_manager._get_molecule_container.side_effect = [mol_0, mol_1] + + # new_U_select_atom returns dummy selections twice per molecule call + dummy_sel_1 = MagicMock() + dummy_sel_2 = MagicMock() + # For mol_0: light then heavy + # For mol_1: light then heavy + entropy_manager._run_manager.new_U_select_atom.side_effect = [ + dummy_sel_1, + dummy_sel_2, + dummy_sel_1, + dummy_sel_2, + ] + + # Mock compute_dihedral_conformations to return different states for each call + state_1 = ["ua_state_1"] + state_2 = ["ua_state_2"] + level_manager.compute_dihedral_conformations = MagicMock( + side_effect=[state_1, state_2] + ) + + groups = {0: [0, 1]} # Group 0 contains molecule 0 and molecule 1 + levels = [["united_atom"], ["united_atom"]] + start, end, step = 0, 10, 1 + number_frames = 10 + bin_width = 0.1 + + states_ua, states_res = level_manager.build_conformational_states( + entropy_manager, + reduced_atom, + levels, + groups, + start, + end, + step, + number_frames, + bin_width, + ce, + ) + + assert states_ua[(0, 0)] == ["ua_state_1", ["ua_state_2"]] + + # Confirm compute_dihedral_conformations was called twice (once per molecule) + assert level_manager.compute_dihedral_conformations.call_count == 2 + + def test_build_conformational_states_residue_level_accumulates_states(self): + """ + Test that the 'build_conformational_states' method correctly accumulates + residue level conformational states for multiple molecules within the + same group. + + When called with multiple molecules assigned to the same group at residue level, + the method should concatenate the returned states into a single flat list. + + Verifies: + - The states_res list contains concatenated residue states from all molecules. + - The states_ua dictionary remains empty for residue level. + - compute_dihedral_conformations is called once per molecule. + """ + level_manager = LevelManager() + entropy_manager = MagicMock() + reduced_atom = MagicMock() + ce = MagicMock() + + # Setup molecule with no residues + mol = MagicMock() + mol.residues = [] + entropy_manager._get_molecule_container.return_value = mol + + # Setup return values for compute_dihedral_conformations + states_1 = ["res_state1"] + states_2 = ["res_state2"] + level_manager.compute_dihedral_conformations = MagicMock( + side_effect=[states_1, states_2] + ) + + # Setup inputs with 2 molecules in same group + groups = {0: [0, 1]} # Both mol 0 and mol 1 are in group 0 + levels = [["res"], ["res"]] + start, end, step = 0, 10, 1 + number_frames = 10 + bin_width = 0.1 + + # Run + states_ua, states_res = level_manager.build_conformational_states( + entropy_manager, + reduced_atom, + levels, + groups, + start, + end, + step, + number_frames, + bin_width, + ce, + ) + + # Confirm accumulation occurred + assert states_ua == {} + assert states_res[0] == ["res_state1", "res_state2"] + assert states_res == [["res_state1", "res_state2"]] + + # Assert both calls to compute_dihedral_conformations happened + assert level_manager.compute_dihedral_conformations.call_count == 2