221 changes: 181 additions & 40 deletions simpeg/directives/_directives.py
@@ -1,5 +1,5 @@
from abc import ABCMeta, abstractmethod
from typing import Iterable, TYPE_CHECKING

from datetime import datetime
import pathlib
@@ -2655,6 +2655,16 @@ def endIter(self):
self.opt.xc = m


def flatten(nested_iterable):
for item in nested_iterable:
if isinstance(item, list):
yield from flatten(item)
elif isinstance(item, np.ndarray):
yield from item.tolist()
else:
yield item

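# Illustrative behavior (example values are assumed, not from this PR): `flatten`
# walks nested lists depth-first and expands numpy arrays element-wise, e.g.
#   list(flatten([[1, 2], np.array([3.0, 4.0]), 5])) -> [1, 2, 3.0, 4.0, 5]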

class ScaleMisfitMultipliers(InversionDirective):
"""
Scale the misfits by the relative chi-factors of multiple misfit functions.
@@ -2670,9 +2680,19 @@ class ScaleMisfitMultipliers(InversionDirective):
Path to save the chi-factors log file.
"""

def __init__(
self,
path: pathlib.Path | None = None,
nesting: list[list] | None = None,
target_chi: float = 1.0,
headers: list[str] | None = None,
**kwargs,
):
self.last_beta = None
self.chi_factors = None
self.target_chi = target_chi
self.nesting = nesting
self.headers = headers

if path is None:
path = pathlib.Path("./")
@@ -2681,45 +2701,112 @@ def __init__(self, path: pathlib.Path | None = None, **kwargs):

super().__init__(**kwargs)

self._log_array: np.ndarray | None = None

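# Usage sketch (hypothetical values, not taken from this PR): group the first two
# data misfits under one branch and keep the third on its own; the directive then
# rebalances the misfit multipliers from their relative chi-factors each iteration.
#
#   directive = ScaleMisfitMultipliers(
#       path=pathlib.Path("./logs"), nesting=[[0, 1], [2]], target_chi=1.0
#   )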
@property
def log_array(self):
"""Structured array accumulating the scaling factors logged at each iteration."""
if self._log_array is None:
if self.headers is None:

def append_sub_indices(elements, header):
values = []
for ii, elem in enumerate(elements):
heads = header + f"[{ii}]"
if isinstance(elem, Iterable):
values += append_sub_indices(elem, heads)
else:
values += [heads]
return values

headers = []
for ii, elem in enumerate(self.misfit_tree_indices):
headers += append_sub_indices(elem, f"[{ii}]")
self.headers = headers

dtype = np.dtype(
[("Iterations", np.int32)] + [(h, np.float32) for h in self.headers]
)
self._log_array = np.rec.fromrecords((), dtype=dtype)

return self._log_array

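# Example (assumed nesting, not from this PR): when no headers are supplied and
# misfit_tree_indices == [[0, 1], [2]], the generated column headers are
# ["[0][0]", "[0][1]", "[1][0]"], one column per misfit function in addition to
# the leading "Iterations" field.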
def initialize(self):
self.last_beta = self.invProb.beta
self.multipliers = self.invProb.dmisfit.multipliers
self.scalings = np.ones_like(self.multipliers) # Everyone gets a fair chance
self.misfit_tree_indices = self.parse_by_nested_levels(self.nesting)

def scale_by_level(
self, nested_values, nested_indices, ratio, scaling_vector: np.ndarray | None = None
):
"""
Recursively compute scaling factors for each level of the nested misfit structure.

The maximum chi-factor at each level is used to determine scaling factors
for the misfit functions at that level. The scaling factors are then propagated
down to the next level of the nested structure.

Parameters
----------
nested_values : list
Nested list of misfit residuals.

nested_indices : list
Nested list of indices corresponding to the misfit residuals.

ratio : float
Ratio of current beta to last beta.

scaling_vector : np.ndarray, optional
Current scaling vector to be updated.
"""
if scaling_vector is None:
scaling_vector = np.ones(len(self.invProb.dmisfit.multipliers))

chi_factors = []
flat_indices = []
for elem, indices in zip(nested_values, nested_indices):
flat_indices.append(np.asarray(list(flatten(indices))))
residuals = np.asarray(list(flatten(elem)))
phi_d = np.vdot(residuals, residuals)
chi_factors.append(phi_d / len(residuals))

chi_factors = np.hstack(chi_factors)
scalings = (
1 - (1 - ratio) * (chi_factors.max() - chi_factors) / chi_factors.max()
)

# Force the ones that overshot target
scalings[chi_factors < self.target_chi] = ratio

for elem, indices, scale, group_ind in zip(
nested_values, nested_indices, scalings, flat_indices
):
# Scale everything below this group by the same factor as the group itself
scaling_vector[group_ind] = np.maximum(
ratio, scale * scaling_vector[group_ind]
)

# Continue one level deeper if more nesting
if isinstance(indices, list) and (
len(indices) > 1 and isinstance(indices[0], list)
):
scaling_vector = self.scale_by_level(
elem, indices, ratio, scaling_vector
)

return scaling_vector

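# Worked example with assumed numbers (not from this PR): for ratio = 0.5 and
# chi-factors [4.0, 2.0, 0.5] at a single level,
#   scalings = 1 - (1 - ratio) * (chi_factors.max() - chi_factors) / chi_factors.max()
# gives [1.0, 0.75, 0.5625]; the last entry is then forced down to ratio (0.5)
# because its chi-factor is already below target_chi. Misfits that overshot the
# target are relaxed fastest, while the worst-fitting misfit keeps its multiplier.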
def endIter(self):
ratio = self.invProb.beta / self.last_beta
nested_residuals = self.parse_by_nested_levels(
self.nesting, self.invProb.residuals
)

scalings = self.scale_by_level(
nested_residuals, self.misfit_tree_indices, ratio, None
)

# Update the scaling
@@ -2728,22 +2815,76 @@ def endIter(self):
# Normalize total phi_d with scalings
self.invProb.dmisfit.multipliers = self.multipliers * self.scalings
self.last_beta = self.invProb.beta

# Log the scaling factors
self.write_log()

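# Summary of the update visible in this diff: residuals are regrouped to match
# `nesting`, scale_by_level turns the per-group chi-factors into scaling factors,
# and the multipliers captured at initialize are rescaled as
# self.multipliers * self.scalings before the next beta iteration.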
def parse_by_nested_levels(
self, nesting: list | None, values: Iterable | None = None
) -> Iterable:
"""
Replace leaf elements of `nesting` with values from `values` (in order).
Assumes the number of leaf positions equals len(values).

Parameters:
- nesting: arbitrarily nested list structure; leaves are non-list values
- values: flat iterable whose values will fill the leaves in order

Returns:
- A new nested structure with leaves replaced by values from `values`.

Raises:
- ValueError if `values` has fewer or more elements than required by `nesting`.
"""
indices = np.arange(len(self.invProb.dmisfit.objfcts))
if nesting is None:
if values is not None:
return values
return indices.tolist()

it = iter(indices)

def _fill(node: Iterable) -> Iterable:
if isinstance(node, list):
return [_fill(child) for child in node]
elif isinstance(node, dict):
return [_fill(child) for child in node.values()]
# leaf: consume a value
try:
if values is not None:
return values[next(it)]
return next(it)
except StopIteration:
raise ValueError("Not enough elements in `flat` to fill `nesting`.")

result = _fill(nesting)

# ensure no extra elements left
try:
next(it)
raise ValueError("Too many elements in `flat` for the given `nesting`.")
except StopIteration:
pass

return result

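# Illustrative calls (assumed values, not from this PR), with three misfit functions
# and residuals = [r0, r1, r2]:
#   self.parse_by_nested_levels([[0, 1], [2]])            -> [[0, 1], [2]]
#   self.parse_by_nested_levels([[0, 1], [2]], residuals) -> [[r0, r1], [r2]]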
def write_log(self):
"""
Write the scaling factors to the log file.
"""
with open(self.filepath, "a", encoding="utf-8") as f:
f.write(
f"{self.opt.iter}\t"
+ "\t".join(
f"{multi:.2e} * {chi:.2e}"
for multi, chi in zip(
self.invProb.dmisfit.multipliers, self.chi_factors
)
)
+ "\n"
self._log_array = np.append(
self.log_array,
np.rec.fromrecords(
tuple([getattr(self.opt, "iter", 0)] + self.scalings.tolist()),
dtype=self.log_array.dtype,
),
)
with open(self.filepath, "w", encoding="utf-8") as f:
np.savetxt(
f,
self.log_array,
header="Iterations - Scaling per misfit",
fmt=["%d"] + ["%0.2e"] * (len(self._log_array.dtype) - 1),
)
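# With assumed values, the text file written above looks roughly like:
#   # Iterations - Scaling per misfit
#   0 1.00e+00 1.00e+00 1.00e+00
#   1 1.00e+00 7.50e-01 5.00e-01
# i.e. the iteration number followed by one scaling factor per misfit function.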

