diff --git a/simpeg/directives/_directives.py b/simpeg/directives/_directives.py
index 034840a200..00340066e1 100644
--- a/simpeg/directives/_directives.py
+++ b/simpeg/directives/_directives.py
@@ -1,5 +1,5 @@
 from abc import ABCMeta, abstractmethod
-from typing import TYPE_CHECKING
+from typing import Iterable, TYPE_CHECKING
 from datetime import datetime
 import pathlib
@@ -2655,6 +2655,16 @@ def endIter(self):
         self.opt.xc = m
 
 
+def flatten(nested_iterable):
+    for item in nested_iterable:
+        if isinstance(item, list):
+            yield from flatten(item)
+        elif isinstance(item, np.ndarray):
+            yield from item.tolist()
+        else:
+            yield item
+
+
 class ScaleMisfitMultipliers(InversionDirective):
     """
     Scale the misfits by the relative chi-factors of multiple misfit functions.
@@ -2670,9 +2680,19 @@ class ScaleMisfitMultipliers(InversionDirective):
         Path to save the chi-factors log file.
     """
 
-    def __init__(self, path: pathlib.Path | None = None, **kwargs):
+    def __init__(
+        self,
+        path: pathlib.Path | None = None,
+        nesting: list[list] | None = None,
+        target_chi: float = 1.0,
+        headers: list[str] | None = None,
+        **kwargs,
+    ):
         self.last_beta = None
         self.chi_factors = None
+        self.target_chi = target_chi
+        self.nesting = nesting
+        self.headers = headers
 
         if path is None:
             path = pathlib.Path("./")
@@ -2681,45 +2701,112 @@ def __init__(self, path: pathlib.Path | None = None, **kwargs):
 
         super().__init__(**kwargs)
 
+        self._log_array: np.ndarray | None = None
+
+    @property
+    def log_array(self):
+        if self._log_array is None:
+            if self.headers is None:
+
+                def append_sub_indices(elements, header):
+                    values = []
+                    for ii, elem in enumerate(elements):
+                        heads = header + f"[{ii}]"
+                        if isinstance(elem, Iterable):
+                            values += append_sub_indices(elem, heads)
+                        else:
+                            values += [heads]
+                    return values
+
+                headers = []
+                for ii, elem in enumerate(self.misfit_tree_indices):
+                    headers += append_sub_indices(elem, f"[{ii}]")
+                self.headers = headers
+
+            dtype = np.dtype(
+                [("Iterations", np.int32)] + [(h, np.float32) for h in self.headers]
+            )
+            self._log_array = np.rec.fromrecords((), dtype=dtype)
+
+        return self._log_array
+
     def initialize(self):
         self.last_beta = self.invProb.beta
         self.multipliers = self.invProb.dmisfit.multipliers
-        self.scalings = np.ones_like(self.multipliers)
-        with open(self.filepath, "w", encoding="utf-8") as f:
-            f.write("Logging of [scaling * chi factor] per misfit function.\n\n")
-            f.write(
-                "Iterations\t"
-                + "\t".join(
-                    f"[{objfct.name}]" for objfct in self.invProb.dmisfit.objfcts
-                )
-            )
-            f.write("\n")
+        self.scalings = np.ones_like(self.multipliers)  # Everyone gets a fair chance
+        self.misfit_tree_indices = self.parse_by_nested_levels(self.nesting)
 
-    def endIter(self):
-        ratio = self.invProb.beta / self.last_beta
-        chi_factors = []
-        for residual in self.invProb.residuals:
-            phi_d = np.vdot(residual, residual)
-            chi_factors.append(phi_d / len(residual))
+        self.write_log()
 
-        self.chi_factors = np.asarray(chi_factors)
+    def scale_by_level(
+        self, nested_values, nested_indices, ratio, scaling_vector: np.ndarray | None
+    ):
+        """
+        Recursively compute scaling factors for each level of the nested misfit structure.
 
-        if np.all(self.chi_factors < 1) or ratio >= 1:
-            self.last_beta = self.invProb.beta
-            self.write_log()
-            return
+        The maximum chi-factor at each level is used to determine scaling factors
+        for the misfit functions at that level. The scaling factors are then propagated
+        down to the next level of the nested structure.
+
+        Parameters
+        ----------
+        nested_values : list
+            Nested list of misfit residuals.
+
+        nested_indices : list
+            Nested list of indices corresponding to the misfit residuals.
 
-        # Normalize scaling between [ratio, 1]
+        ratio : float
+            Ratio of current beta to last beta.
+
+        scaling_vector : np.ndarray, optional
+            Current scaling vector to be updated.
+        """
+        if scaling_vector is None:
+            scaling_vector = np.ones(len(self.invProb.dmisfit.multipliers))
+
+        chi_factors = []
+        flat_indices = []
+        for elem, indices in zip(nested_values, nested_indices):
+            flat_indices.append(np.asarray(list(flatten(indices))))
+            residuals = np.asarray(list(flatten(elem)))
+            phi_d = np.vdot(residuals, residuals)
+            chi_factors.append(phi_d / len(residuals))
+
+        chi_factors = np.hstack(chi_factors)
         scalings = (
-            1
-            - (1 - ratio)
-            * (self.chi_factors.max() - self.chi_factors)
-            / self.chi_factors.max()
+            1 - (1 - ratio) * (chi_factors.max() - chi_factors) / chi_factors.max()
         )
 
         # Force the ones that overshot target
-        scalings[self.chi_factors < 1] = (
-            ratio  # * self.chi_factors[self.chi_factors < 1]
+        scalings[chi_factors < self.target_chi] = ratio
+
+        for elem, indices, scale, group_ind in zip(
+            nested_values, nested_indices, scalings, flat_indices
+        ):
+            # Scale everything below same as super group
+            scaling_vector[group_ind] = np.maximum(
+                ratio, scale * scaling_vector[group_ind]
+            )
+
+            # Continue one level deeper if more nesting
+            if isinstance(indices, list) and (
+                len(indices) > 1 and isinstance(indices[0], list)
+            ):
+                scaling_vector = self.scale_by_level(
+                    elem, indices, ratio, scaling_vector
+                )
+
+        return scaling_vector
+
+    def endIter(self):
+        ratio = self.invProb.beta / self.last_beta
+        nested_residuals = self.parse_by_nested_levels(
+            self.nesting, self.invProb.residuals
+        )
+
+        scalings = self.scale_by_level(
+            nested_residuals, self.misfit_tree_indices, ratio, None
         )
 
         # Update the scaling
@@ -2728,22 +2815,76 @@ def endIter(self):
         # Normalize total phi_d with scalings
         self.invProb.dmisfit.multipliers = self.multipliers * self.scalings
         self.last_beta = self.invProb.beta
+
+        # Log the scaling factors
         self.write_log()
 
+    def parse_by_nested_levels(
+        self, nesting: list[Iterable], values: Iterable | None = None
+    ) -> Iterable:
+        """
+        Replace leaf elements of `nesting` with values from `values` (in order).
+        Assumes the number of leaf positions equals len(values).
+
+        Parameters:
+        - nesting: arbitrarily nested list structure; leaves are non-list values
+        - values: flat iterable whose values will fill the leaves in order
+
+        Returns:
+        - A new nested structure with leaves replaced by values from `values`.
+
+        Raises:
+        - ValueError if `values` has fewer or more elements than required by `nesting`.
+ """ + indices = np.arange(len(self.invProb.dmisfit.objfcts)) + if nesting is None: + if values is not None: + return values + return indices.tolist() + + it = iter(indices) + + def _fill(node: Iterable) -> Iterable: + if isinstance(node, list): + return [_fill(child) for child in node] + elif isinstance(node, dict): + return [_fill(child) for child in node.values()] + # leaf: consume a value + try: + if values is not None: + return values[next(it)] + return next(it) + except StopIteration: + raise ValueError("Not enough elements in `flat` to fill `nesting`.") + + result = _fill(nesting) + + # ensure no extra elements left + try: + next(it) + raise ValueError("Too many elements in `flat` for the given `nesting`.") + except StopIteration: + pass + + return result + def write_log(self): """ Write the scaling factors to the log file. """ - with open(self.filepath, "a", encoding="utf-8") as f: - f.write( - f"{self.opt.iter}\t" - + "\t".join( - f"{multi:.2e} * {chi:.2e}" - for multi, chi in zip( - self.invProb.dmisfit.multipliers, self.chi_factors - ) - ) - + "\n" + self._log_array = np.append( + self.log_array, + np.rec.fromrecords( + tuple([getattr(self.opt, "iter", 0)] + self.scalings.tolist()), + dtype=self.log_array.dtype, + ), + ) + with open(self.filepath, "w", encoding="utf-8") as f: + np.savetxt( + f, + self.log_array, + header="Iterations - Scaling per misfit", + fmt=["%d"] + ["%0.2e"] * (len(self._log_array.dtype) - 1), )