206 changes: 160 additions & 46 deletions bilby/core/prior/dict.py
@@ -5,6 +5,7 @@
from io import open as ioopen

import numpy as np
from scipy.stats.qmc import Halton

from .analytical import DeltaFunction
from .base import Prior, Constraint
@@ -14,6 +15,7 @@
check_directory_exists_and_if_not_mkdir,
BilbyJsonEncoder,
decode_bilby_json,
random
)


@@ -440,6 +442,10 @@ def fixed_keys(self):
def constraint_keys(self):
return [k for k, p in self.items() if isinstance(p, Constraint)]

@property
def has_constraint(self):
return len(self.constraint_keys) > 0

def sample_subset_constrained(self, keys=iter([]), size=None):
if size is None or size == 1:
while True:
@@ -465,39 +471,140 @@ def sample_subset_constrained(self, keys=iter([]), size=None):
}
return all_samples

def _integrate_normalization_factor_from_samples(self, keys, sampling_chunk):
samples = self.sample_subset(keys=keys, size=sampling_chunk)
factor = np.mean(self.evaluate_constraints(samples))
return factor

def _integrate_normalization_factor_from_qmc_quad(self, keys, sampling_chunk):
qrng = Halton(len(keys), seed=random.rng)
theta = qrng.random(sampling_chunk).T
samples = self.rescale(keys=keys, theta=theta)
samples = {key: samps for key, samps in zip(keys, samples)}
factor = np.mean(self.evaluate_constraints(samples))
return factor
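Reviewer note: the difference between the two integrators above can be sketched outside bilby. The snippet below is illustrative only; the constraint P(x + y < 1) and all names are made up for this example.

```python
import numpy as np
from scipy.stats.qmc import Halton

n = 10_000
rng = np.random.default_rng(42)

# 'from_samples' analogue: plain Monte Carlo with pseudo-random draws
x, y = rng.uniform(size=(2, n))
mc_estimate = np.mean(x + y < 1)

# 'qmc_quad' analogue: the same integral from a low-discrepancy Halton sequence
theta = Halton(d=2, seed=rng).random(n).T
qmc_estimate = np.mean(theta[0] + theta[1] < 1)

# both approximate P(x + y < 1) = 0.5, but the QMC error typically shrinks
# faster than the ~1/sqrt(n) rate of plain Monte Carlo
print(mc_estimate, qmc_estimate)
```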

def normalize_constraint_factor(self, keys, nrepeats=10, sampling_chunk=10000, rel_error_target=0.01,
max_trials=False, method="qmc_quad", **kwargs):
"""Estimates the probality normalization factor for constrained priors from (quasi) Monte Carlo integration.

Parameters
==========
keys: list, tuple
The set of keys in the prior dict to perform the integration for. Must contain all keys that the constraint
depends on. For joint priors, the full distribution is sampled. Joint prior keys not present in 'keys' are
marginalized.
nrepeats: int
Number of repeated Monte Carlo integrations before convergence is checked. Higher numbers improve the
estimation of the sampling error.
sampling_chunk: int
Number of samples drawn per Monte Carlo integration. Higher numbers improve the integral estimation.
rel_error_target: float
The relative error targeted by the Monte Carlo integration. The relative error of the integral is estimated
from the standard deviation between repeated runs. The algorithm keeps running batches of 'nrepeats'
integrations until rel_error_target is reached.
max_trials: False or int
Second termination criterion. The integration stops when the number of samples evaluated in the next
iteration would exceed 'max_trials', even if rel_error_target has not been reached.
method: ['qmc_quad', 'from_samples']
Method to use for the Monte Carlo integration, effectively choosing between quasi-Monte Carlo
integration and "normal" Monte Carlo integration. Due to the computational overhead of 'qmc_quad',
'from_samples' will be faster in most cases. However, 'qmc_quad' is expected to yield lower errors
against the ground truth for cases with few repeats but a high sampling_chunk.

Returns
=======
factor_rounded: float
The normalization factor, rounded to the number of significant digits based on the standard deviation of
the integration estimate.

Notes
=====
.. seealso::
:py:class:`scipy.stats.qmc.Halton`
Documentation of the quasi-random number generator used for the quasi-Monte Carlo based method
to integrate the normalization factor.
:py:func:`scipy.integrate.qmc_quad`
Documentation of the scipy-native quasi-Monte Carlo integration scheme whose implementation,
particularly its error estimate, motivated the implementation here.
(The error estimate was re-implemented so that it also applies to the 'from_samples' method.)

.. versionchanged:: 4.0
The estimation of the normalization factor is now based on quasi-Monte Carlo integration by default.
Further, the default stopping criterion is now a target for the estimated relative integration error.
"""
keys = tuple(keys)
if keys in self._cached_normalizations.keys():
return self._cached_normalizations[keys]

sample_keys = tuple(keys)

# check if the constraint can be applied to the selection of keys
try:
sample = self.sample_subset(keys, size=1)
self.conversion_function(sample)
except KeyError:
raise ValueError("'keys' does not contain all parameters needed to evaluate the constraint.")

for key in sample_keys:
if isinstance(self[key], JointPrior):
dist_keys = set(self[key].dist.names)
missing_keys = tuple(dist_keys - set(sample_keys))
sample_keys += missing_keys
for key in missing_keys:
self.sample_subset(missing_keys, 1)

if method == "from_samples":
integrator = self._integrate_normalization_factor_from_samples
elif method == "qmc_quad":
integrator = self._integrate_normalization_factor_from_qmc_quad
try:
theta = np.random.uniform(0, 1, size=(len(sample_keys), 1))
samples = self.rescale(sample_keys, theta)
none_keys = []
for i, key in enumerate(sample_keys):
if samples[i] is None:
none_keys.append(key)
if len(none_keys) > 0:
raise NotImplementedError(f"The rescale method returns 'None' for key(s) {none_keys}.")
except NotImplementedError as e:
logger.info(f"The rescaling step fails with message:\n{e}")
logger.info("Switching to method 'from_samples'")
integrator = self._integrate_normalization_factor_from_samples
method = "from_samples"
else:
raise ValueError(f"Integration method {method} not understood.\n" +
"Available options are ('from_samples', 'qmc_quad').")

trials = 0
estimates = []
while True:
for i in range(nrepeats):
integral = integrator(sample_keys, sampling_chunk, **kwargs)
trials += sampling_chunk
estimates.append(integral)

standard_error = np.std(estimates, ddof=1) / np.sqrt(len(estimates))
cumulative_integral = np.mean(estimates)

# compute the factor and its relative error (from the standard error of the repeated estimates)
factor = 1 / cumulative_integral
rel_error = factor * standard_error

if rel_error < rel_error_target:
break
elif max_trials and ((trials + sampling_chunk * nrepeats) > max_trials):
break

if rel_error > 0:
decimals = int(-np.floor(np.log10(3 * rel_error)))
factor_rounded = np.round(factor, decimals)
self._cached_normalizations[keys] = factor_rounded
else:
self._cached_normalizations[keys] = factor

return self._cached_normalizations[keys]
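A usage sketch for the new signature (the priors, the constraint x + y < 1, and the conversion function are illustrative, not part of this PR): the acceptance fraction on the unit square is 0.5, so the factor should converge to roughly 2.

```python
from bilby.core.prior import Constraint, PriorDict, Uniform

def convert(sample):
    # derived quantity that the Constraint prior is evaluated on
    sample["z"] = sample["x"] + sample["y"]
    return sample

priors = PriorDict(
    dict(
        x=Uniform(0, 1, name="x"),
        y=Uniform(0, 1, name="y"),
        z=Constraint(minimum=0, maximum=1, name="z"),
    ),
    conversion_function=convert,
)

# quasi-Monte Carlo estimate, iterating until the estimated relative
# error of the factor drops below one per cent (the new default)
factor = priors.normalize_constraint_factor(("x", "y"), rel_error_target=0.01)
print(factor)  # ~2.0, rounded according to the estimated error
```

For the rounding rule as coded: with factor = 2.013 and rel_error = 0.004, 3 * rel_error = 0.012, so decimals = int(-np.floor(np.log10(0.012))) = 2 and the cached value is 2.01.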

def prob(self, sample, **kwargs):
"""
@@ -519,6 +626,8 @@ def prob(self, sample, **kwargs):
return self.check_prob(sample, prob)

def check_prob(self, sample, prob):
if not self.has_constraint:
return prob
ratio = self.normalize_constraint_factor(tuple(sample.keys()))
if np.all(prob == 0.0):
return prob * ratio
@@ -558,10 +667,10 @@ def ln_prob(self, sample, axis=None, normalized=True):
normalized=normalized)

def check_ln_prob(self, sample, ln_prob, normalized=True):
if normalized and self.has_constraint:
ratio = self.normalize_constraint_factor(tuple(sample.keys()))
else:
return ln_prob
if np.all(np.isinf(ln_prob)):
return ln_prob
else:
@@ -600,18 +709,23 @@ def rescale(self, keys, theta):
==========
keys: list
List of prior keys to be rescaled
theta: dict or array-like
Randomly drawn values on a unit cube associated with the prior keys

Returns
=======
list:
If theta is 1D, returns list of floats containing the rescaled sample.
If theta is 2D, returns list of lists containing the rescaled samples.
"""
theta = [theta[key] for key in keys] if isinstance(theta, dict) else list(theta)
samples = []
for key, units in zip(keys, theta):
samps = self[key].rescale(units)
samples.append(samps)
for i, samps in enumerate(samples):
# turns 0d-arrays into scalars
samples[i] = np.squeeze(samps).tolist()
return samples
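A short sketch of the extended input handling (priors 'a' and 'b' are hypothetical; expected outputs in the comments follow from Uniform rescaling):

```python
import numpy as np
from bilby.core.prior import PriorDict, Uniform

priors = PriorDict(dict(
    a=Uniform(0, 1, name="a"),
    b=Uniform(0, 10, name="b"),
))

print(priors.rescale(["a", "b"], [0.5, 0.5]))            # [0.5, 5.0]
print(priors.rescale(["a", "b"], {"b": 0.5, "a": 0.5}))  # dict input, key order handled
theta = np.array([[0.1, 0.9], [0.2, 0.8]])               # one row of unit-cube values per key
print(priors.rescale(["a", "b"], theta))                 # [[0.1, 0.9], [2.0, 8.0]]
```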

def test_redundancy(self, key, disable_logging=False):
@@ -832,28 +946,28 @@ def rescale(self, keys, theta):
==========
keys: list
List of prior keys to be rescaled
theta: dict or array-like
Randomly drawn values on a unit cube associated with the prior keys

Returns
=======
list:
If theta is float for each key, returns list of floats containing the rescaled sample.
If theta is array-like for each key, returns list of lists containing the rescaled samples.
"""
keys = list(keys)
theta = [theta[key] for key in keys] if isinstance(theta, dict) else list(theta)
self._check_resolved()
self._update_rescale_keys(keys)
result = dict()
for key, index in zip(self.sorted_keys_without_fixed_parameters, self._rescale_indexes):
result[key] = self[key].rescale(theta[index], **self.get_required_variables(key))
self[key].least_recently_sampled = result[key]
samples = []
for key in keys:
# turns 0d-arrays into scalars
res = np.squeeze(result[key]).tolist()
samples.append(res)
return samples
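The conditional variant accepts the same dict input. A minimal sketch, assuming bilby's existing ConditionalUniform API (the condition and parameter names are made up):

```python
from bilby.core.prior import ConditionalPriorDict, ConditionalUniform, Uniform

def condition(reference_params, a):
    # the upper bound of b tracks the drawn value of a
    return dict(maximum=a)

priors = ConditionalPriorDict(dict(
    a=Uniform(0, 1, name="a"),
    b=ConditionalUniform(condition_func=condition, minimum=0, maximum=1, name="b"),
))

# b is rescaled within [0, a], so this prints [0.8, 0.4]
print(priors.rescale(["a", "b"], {"a": 0.8, "b": 0.5}))
```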

def _update_rescale_keys(self, keys):
52 changes: 38 additions & 14 deletions bilby/core/prior/joint.py
@@ -63,8 +63,9 @@ def __init__(self, names, bounds=None):
self.requested_parameters = dict()
self.reset_request()

# a dictionary of the rescale(d) parameters
self._rescale_parameters = dict()
self._rescaled_parameters = dict()
self.reset_rescale()

# a list of sampled parameters
@@ -94,15 +95,24 @@ def filled_rescale(self):
Check if all the rescaled parameters have been filled.
"""

return not np.any([val is None for val in self._rescale_parameters.values()])

def set_rescale(self, key, values):
values = np.array(values)
self._rescale_parameters[key] = values
self._rescaled_parameters[key] = np.atleast_1d(np.ones_like(values)) * np.nan

def reset_rescale(self):
"""
Reset the rescaled parameters to None.
"""

for name in self.names:
self.rescale_parameters[name] = None
self._rescale_parameters[name] = None
self._rescaled_parameters[name] = None

def get_rescaled(self, key):
return self._rescaled_parameters[key]

def get_instantiation_dict(self):
subclass_args = infer_args_from_method(self.__init__)
@@ -303,10 +313,11 @@ def rescale(self, value, **kwargs):

Parameters
==========
value: array or None
If given, a 1d vector sample (one for each parameter) drawn from a uniform
distribution between 0 and 1, or a 2d NxM array of samples where
N is the number of samples and M is the number of parameters.
If None, values previously set using BaseJointPriorDist.set_rescale() are used.
kwargs: dict
All keyword args that need to be passed to _rescale method, these keyword
args are called in the JointPrior rescale methods for each parameter
@@ -317,7 +328,11 @@
A vector sample drawn from the multivariate Gaussian
distribution.
"""
if value is None:
samp = np.array(list(self._rescale_parameters.values())).T
else:
samp = np.array(value)

if len(samp.shape) == 1:
samp = samp.reshape(1, self.num_vars)

@@ -327,6 +342,11 @@
raise ValueError("Array is the wrong shape")

samp = self._rescale(samp, **kwargs)
if value is None:
for i, key in enumerate(self.names):
output = self.get_rescaled(key)
# update in-place for proper handling in PriorDict-instances
output[:] = samp[:, i]
return np.squeeze(samp)

def _rescale(self, samp, **kwargs):
@@ -790,19 +810,23 @@ def rescale(self, val, **kwargs):
all kwargs passed to the dist.rescale method
Returns
=======
np.ndarray:
The samples from the prior parameter. If not all names in "dist" have been filled,
the array contains only np.nan. *This* specific array instance will be filled with
the rescaled values once all parameters have been requested.
"""

self.dist.set_rescale(self.name, val)

if self.dist.filled_rescale():
self.dist.rescale(value=None, **kwargs)
output = self.dist.get_rescaled(self.name)
self.dist.reset_rescale()
else:
output = self.dist.get_rescaled(self.name)

# have to return raw output to conserve in-place modifications
return output
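To make the in-place contract concrete, a sketch with bilby's existing multivariate Gaussian joint prior (the two-parameter setup is illustrative):

```python
import numpy as np
from bilby.core.prior import MultivariateGaussian, MultivariateGaussianDist

dist = MultivariateGaussianDist(
    names=["m1", "m2"], mus=[[1.0, 1.0]], covs=[np.eye(2)]
)
p1 = MultivariateGaussian(dist=dist, name="m1")
p2 = MultivariateGaussian(dist=dist, name="m2")

out1 = p1.rescale(0.3)  # placeholder array, still all-NaN
out2 = p2.rescale(0.7)  # last name filled: the joint rescale runs and fills both in place
print(out1, out2)       # out1 now holds the m1 sample rather than NaN
```

Returning the same array instance (rather than a copy) is what lets PriorDict.rescale above hand out per-key outputs before the joint distribution has been fully specified.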

def sample(self, size=1, **kwargs):
"""