diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py
index be3d543a9..489b0655b 100644
--- a/bilby/core/prior/dict.py
+++ b/bilby/core/prior/dict.py
@@ -5,6 +5,7 @@
 from io import open as ioopen
 
 import numpy as np
+from scipy.stats.qmc import Halton
 
 from .analytical import DeltaFunction
 from .base import Prior, Constraint
@@ -14,6 +15,7 @@
     check_directory_exists_and_if_not_mkdir,
     BilbyJsonEncoder,
     decode_bilby_json,
+    random
 )
 
 
@@ -440,6 +442,10 @@ def fixed_keys(self):
     def constraint_keys(self):
         return [k for k, p in self.items() if isinstance(p, Constraint)]
 
+    @property
+    def has_constraint(self):
+        return len(self.constraint_keys) > 0
+
     def sample_subset_constrained(self, keys=iter([]), size=None):
         if size is None or size == 1:
             while True:
@@ -465,39 +471,140 @@ def sample_subset_constrained(self, keys=iter([]), size=None):
             }
             return all_samples
 
-    def normalize_constraint_factor(
-        self, keys, min_accept=10000, sampling_chunk=50000, nrepeats=10
-    ):
+    def _integrate_normalization_factor_from_samples(self, keys, sampling_chunk):
+        samples = self.sample_subset(keys=keys, size=sampling_chunk)
+        factor = np.mean(self.evaluate_constraints(samples))
+        return factor
+
+    def _integrate_normalization_factor_from_qmc_quad(self, keys, sampling_chunk):
+        qrng = Halton(len(keys), seed=random.rng)
+        theta = qrng.random(sampling_chunk).T
+        samples = self.rescale(keys=keys, theta=theta)
+        samples = {key: samps for key, samps in zip(keys, samples)}
+        factor = np.mean(self.evaluate_constraints(samples))
+        return factor
+
+    def normalize_constraint_factor(self, keys, nrepeats=10, sampling_chunk=10000, rel_error_target=0.01,
+                                    max_trials=False, method="qmc_quad", **kwargs):
+        """Estimates the probability normalization factor for constrained priors from (quasi) Monte Carlo integration.
+
+        Parameters
+        ==========
+        keys: list, tuple
+            The set of keys in the prior dict to perform the integration for. Must contain all keys that the constraint
+            depends on. For joint priors, the full distribution is sampled. Joint prior keys not present in 'keys' are
+            marginalized.
+        nrepeats: int
+            Number of repeated Monte Carlo integrations before convergence is checked. Higher numbers improve the
+            estimation of the sampling error.
+        sampling_chunk: int
+            Number of samples drawn per Monte Carlo integration. Higher numbers improve the integral estimation.
+        rel_error_target: float
+            The relative error targeted by the Monte Carlo integration. The relative error of the integral is estimated
+            from the standard deviation between repeated runs. The algorithm repeats 'nrepeats' iterations until
+            rel_error_target is reached.
+        max_trials: False, int
+            Second termination criterion. The integration stops when the number of samples evaluated in the next
+            iteration would exceed 'max_trials', even if rel_error_target has not been reached.
+        method: ['qmc_quad', 'from_samples']
+            Method to use for the Monte Carlo integration, effectively choosing between quasi Monte Carlo integration
+            and "normal" Monte Carlo integration. Due to computational overhead, the latter method ('from_samples')
+            will be faster in most cases. However, 'qmc_quad' is expected to yield lower errors against the ground
+            truth for cases with few repeats but a high sampling_chunk.
+
+        Returns
+        =======
+        factor_rounded: float
+            The normalization factor, rounded to the number of significant digits based on the standard deviation of
+            the integration estimate.
+
+        Notes
+        =====
+        .. seealso::
+            :py:class:`scipy.stats.qmc.Halton`
+                Documentation of the quasi-random number generator used for the quasi-Monte Carlo-based method
+                to integrate the normalization factor.
+            :py:func:`scipy.integrate.qmc_quad`
+                Documentation of the scipy-native quasi-Monte Carlo integration scheme.
+                Its implementation, particularly the error estimate, motivates the implementation here.
+                (The error estimate was re-implemented to also apply for the 'from_samples' method.)
+
+        .. versionchanged:: 4.0
+            The estimation of the normalization factor is now by default based on quasi-Monte Carlo integration.
+            Further, the default stopping criterion is now based on a target for the estimated relative integration
+            error.
+        """
+        keys = tuple(keys)
         if keys in self._cached_normalizations.keys():
             return self._cached_normalizations[keys]
+
+        sample_keys = tuple(keys)
+
+        # check if the constraint can be applied to the selection of keys
+        try:
+            sample = self.sample_subset(keys, size=1)
+            self.conversion_function(sample)
+        except KeyError:
+            raise ValueError("'keys' does not contain all parameters needed to evaluate the constraint.")
+
+        for key in sample_keys:
+            if isinstance(self[key], JointPrior):
+                dist_keys = set(self[key].dist.names)
+                missing_keys = tuple(dist_keys - set(sample_keys))
+                sample_keys += missing_keys
+                for key in missing_keys:
+                    self.sample_subset(missing_keys, 1)
+
+        if method == "from_samples":
+            integrator = self._integrate_normalization_factor_from_samples
+        elif method == "qmc_quad":
+            integrator = self._integrate_normalization_factor_from_qmc_quad
+            try:
+                theta = np.random.uniform(0, 1, size=(len(sample_keys), 1))
+                samples = self.rescale(sample_keys, theta)
+                none_keys = []
+                for i, key in enumerate(sample_keys):
+                    if samples[i] is None:
+                        none_keys.append(key)
+                if len(none_keys) > 0:
+                    raise NotImplementedError(f"The rescale method returns 'None' for key(s) {none_keys}.")
+            except NotImplementedError as e:
+                logger.info(f"The rescaling step fails with message:\n{e}")
+                logger.info("Switching to method 'from_samples'")
+                integrator = self._integrate_normalization_factor_from_samples
+                method = "from_samples"
         else:
-            factor_estimates = [
-                self._estimate_normalization(keys, min_accept, sampling_chunk)
-                for _ in range(nrepeats)
-            ]
-            factor = np.mean(factor_estimates)
-            if np.std(factor_estimates) > 0:
-                decimals = int(-np.floor(np.log10(3 * np.std(factor_estimates))))
-                factor_rounded = np.round(factor, decimals)
-            else:
-                factor_rounded = factor
+            raise ValueError(f"Integration method {method} not understood.\n" +
+                             "Available options are ('from_samples', 'qmc_quad').")
+
+        trials = 0
+        estimates = []
+        while True:
+            for i in range(nrepeats):
+                integral = integrator(sample_keys, sampling_chunk, **kwargs)
+                trials += sampling_chunk
+                estimates.append(integral)
+
+            standard_error = np.std(estimates, ddof=1) / np.sqrt(len(estimates))
+            cumulative_integral = np.mean(estimates)
+
+            # compute the factor and relative error (as given by the standard deviation)
+            factor = 1 / cumulative_integral
+            rel_error = factor * standard_error
+
+            if rel_error < rel_error_target:
+                break
+            elif max_trials and ((trials + sampling_chunk * nrepeats) > max_trials):
+                break
+
+        if rel_error > 0:
+            decimals = int(-np.floor(np.log10(3 * rel_error)))
+            factor_rounded = np.round(factor, decimals)
             self._cached_normalizations[keys] = factor_rounded
-        return factor_rounded
+        else:
+            self._cached_normalizations[keys] = factor
 
-    def _estimate_normalization(self, keys, min_accept, sampling_chunk):
-        samples = self.sample_subset(keys=keys, size=sampling_chunk)
-        keep = np.atleast_1d(self.evaluate_constraints(samples))
-        if len(keep) == 1:
-            self._cached_normalizations[keys] = 1
-            return 1
-        all_samples = {key: np.array([]) for key in keys}
-        while np.count_nonzero(keep) < min_accept:
-            samples = self.sample_subset(keys=keys, size=sampling_chunk)
-            for key in samples:
-                all_samples[key] = np.hstack([all_samples[key], samples[key].flatten()])
-            keep = np.array(self.evaluate_constraints(all_samples), dtype=bool)
-        factor = len(keep) / np.count_nonzero(keep)
-        return factor
+        return self._cached_normalizations[keys]
 
     def prob(self, sample, **kwargs):
         """
@@ -519,6 +626,8 @@ def prob(self, sample, **kwargs):
         return self.check_prob(sample, prob)
 
     def check_prob(self, sample, prob):
+        if not self.has_constraint:
+            return prob
         ratio = self.normalize_constraint_factor(tuple(sample.keys()))
         if np.all(prob == 0.0):
             return prob * ratio
@@ -558,10 +667,10 @@ def ln_prob(self, sample, axis=None, normalized=True):
             normalized=normalized)
 
     def check_ln_prob(self, sample, ln_prob, normalized=True):
-        if normalized:
+        if normalized and self.has_constraint:
             ratio = self.normalize_constraint_factor(tuple(sample.keys()))
         else:
-            ratio = 1
+            return ln_prob
         if np.all(np.isinf(ln_prob)):
             return ln_prob
         else:
@@ -600,18 +709,23 @@ def rescale(self, keys, theta):
         ==========
         keys: list
             List of prior keys to be rescaled
-        theta: list
-            List of randomly drawn values on a unit cube associated with the prior keys
+        theta: dict or array-like
+            Randomly drawn values on a unit cube associated with the prior keys
 
         Returns
         =======
-        list: List of floats containing the rescaled sample
+        list:
+            If theta is 1D, returns list of floats containing the rescaled sample.
+            If theta is 2D, returns list of lists containing the rescaled samples.
         """
-        theta = list(theta)
+        theta = [theta[key] for key in keys] if isinstance(theta, dict) else list(theta)
         samples = []
         for key, units in zip(keys, theta):
             samps = self[key].rescale(units)
-            samples += list(np.asarray(samps).flatten())
+            samples.append(samps)
+        for i, samps in enumerate(samples):
+            # turns 0d-arrays into scalars
+            samples[i] = np.squeeze(samps).tolist()
         return samples
 
     def test_redundancy(self, key, disable_logging=False):
@@ -832,28 +946,28 @@ def rescale(self, keys, theta):
         ==========
         keys: list
             List of prior keys to be rescaled
-        theta: list
-            List of randomly drawn values on a unit cube associated with the prior keys
+        theta: dict or array-like
+            Randomly drawn values on a unit cube associated with the prior keys
 
         Returns
         =======
-        list: List of floats containing the rescaled sample
+        list:
+            If theta is float for each key, returns list of floats containing the rescaled sample.
+            If theta is array-like for each key, returns list of lists containing the rescaled samples.
""" keys = list(keys) - theta = list(theta) + theta = [theta[key] for key in keys] if isinstance(theta, dict) else list(theta) self._check_resolved() self._update_rescale_keys(keys) result = dict() - for key, index in zip( - self.sorted_keys_without_fixed_parameters, self._rescale_indexes - ): - result[key] = self[key].rescale( - theta[index], **self.get_required_variables(key) - ) + for key, index in zip(self.sorted_keys_without_fixed_parameters, self._rescale_indexes): + result[key] = self[key].rescale(theta[index], **self.get_required_variables(key)) self[key].least_recently_sampled = result[key] samples = [] for key in keys: - samples += list(np.asarray(result[key]).flatten()) + # turns 0d-arrays into scalars + res = np.squeeze(result[key]).tolist() + samples.append(res) return samples def _update_rescale_keys(self, keys): diff --git a/bilby/core/prior/joint.py b/bilby/core/prior/joint.py index 43c8913e3..7e6655ee3 100644 --- a/bilby/core/prior/joint.py +++ b/bilby/core/prior/joint.py @@ -63,8 +63,9 @@ def __init__(self, names, bounds=None): self.requested_parameters = dict() self.reset_request() - # a dictionary of the rescaled parameters - self.rescale_parameters = dict() + # a dictionary of the rescale(d) parameters + self._rescale_parameters = dict() + self._rescaled_parameters = dict() self.reset_rescale() # a list of sampled parameters @@ -94,7 +95,12 @@ def filled_rescale(self): Check if all the rescaled parameters have been filled. """ - return not np.any([val is None for val in self.rescale_parameters.values()]) + return not np.any([val is None for val in self._rescale_parameters.values()]) + + def set_rescale(self, key, values): + values = np.array(values) + self._rescale_parameters[key] = values + self._rescaled_parameters[key] = np.atleast_1d(np.ones_like(values)) * np.nan def reset_rescale(self): """ @@ -102,7 +108,11 @@ def reset_rescale(self): """ for name in self.names: - self.rescale_parameters[name] = None + self._rescale_parameters[name] = None + self._rescaled_parameters[name] = None + + def get_rescaled(self, key): + return self._rescaled_parameters[key] def get_instantiation_dict(self): subclass_args = infer_args_from_method(self.__init__) @@ -303,10 +313,11 @@ def rescale(self, value, **kwargs): Parameters ========== - value: array - A 1d vector sample (one for each parameter) drawn from a uniform + value: array or None + If given, a 1d vector sample (one for each parameter) drawn from a uniform distribution between 0 and 1, or a 2d NxM array of samples where N is the number of samples and M is the number of parameters. + If None, values previously set using BaseJointPriorDist.set_rescale() are used. kwargs: dict All keyword args that need to be passed to _rescale method, these keyword args are called in the JointPrior rescale methods for each parameter @@ -317,7 +328,11 @@ def rescale(self, value, **kwargs): An vector sample drawn from the multivariate Gaussian distribution. 
""" - samp = np.array(value) + if value is None: + samp = np.array(list(self._rescale_parameters.values())).T + else: + samp = np.array(value) + if len(samp.shape) == 1: samp = samp.reshape(1, self.num_vars) @@ -327,6 +342,11 @@ def rescale(self, value, **kwargs): raise ValueError("Array is the wrong shape") samp = self._rescale(samp, **kwargs) + if value is None: + for i, key in enumerate(self.names): + output = self.get_rescaled(key) + # update in-place for proper handling in PriorDict-instances + output[:] = samp[:, i] return np.squeeze(samp) def _rescale(self, samp, **kwargs): @@ -790,19 +810,23 @@ def rescale(self, val, **kwargs): all kwargs passed to the dist.rescale method Returns ======= - float: - A sample from the prior parameter. + np.ndarray: + The samples from the prior parameter. If not all names in "dist" have been filled, + the array contains only np.nan. *This* specific array instance will be filled with + the rescaled value once all parameters have been requested """ - self.dist.rescale_parameters[self.name] = val + self.dist.set_rescale(self.name, val) if self.dist.filled_rescale(): - values = np.array(list(self.dist.rescale_parameters.values())).T - samples = self.dist.rescale(values, **kwargs) + self.dist.rescale(value=None, **kwargs) + output = self.dist.get_rescaled(self.name) self.dist.reset_rescale() - return samples else: - return [] # return empty list + output = self.dist.get_rescaled(self.name) + + # have to return raw output to conserve in-place modifications + return output def sample(self, size=1, **kwargs): """ diff --git a/test/core/prior/dict_test.py b/test/core/prior/dict_test.py index d6e6239f2..8c2f702e0 100644 --- a/test/core/prior/dict_test.py +++ b/test/core/prior/dict_test.py @@ -294,6 +294,95 @@ def test_sample_subset_constrained_as_array(self): self.assertTrue(isinstance(out, np.ndarray)) self.assertTrue(out.shape == (len(keys), size)) + def test_constrained_normalization(self): + prior_x = bilby.prior.Uniform(-1, 1, name="x") + prior_y = bilby.prior.Uniform(-1, 1, name="y") + + def radius_squared(sample_dict): + sample_dict["constraint"] = np.array(sample_dict["x"]) ** 2 + np.array(sample_dict["y"]) ** 2 + return sample_dict + + r = 1 + prior_constraint = bilby.prior.Constraint(minimum=0, maximum=r**2, name="constraint") + + prior_dict = bilby.prior.PriorDict( + { + "x": prior_x, + "y": prior_y, + "constraint": prior_constraint, + }, + conversion_function=radius_squared, + ) + + factor = prior_dict.normalize_constraint_factor(keys=("x", "y")) + truth = 4 / (np.pi * r**2) + self.assertAlmostEqual(truth, factor, delta=truth * 0.01) + + r = 2 + prior_constraint_2 = bilby.prior.Constraint(minimum=0, maximum=r**2, name="constraint") + + prior_dict_2 = bilby.prior.PriorDict( + { + "x": prior_x, + "y": prior_y, + "constraint": prior_constraint_2, + }, + conversion_function=radius_squared, + ) + + factor_2 = prior_dict_2.normalize_constraint_factor(keys=("x", "y")) + truth_2 = 1 + self.assertAlmostEqual(truth_2, factor_2, delta=truth_2 * 0.01) + + prior_dist_1 = bilby.prior.MultivariateGaussianDist( + names=["x", "y", "z"], + mus=[0, 0, 0], + sigmas=[1, 1, 1], + ) + + prior_dist_2 = bilby.prior.MultivariateGaussianDist( + names=["x", "y"], + mus=[0, 0], + sigmas=[1, 1], + ) + + joint_prior_x = bilby.prior.JointPrior(dist=prior_dist_1, name="x") + joint_prior_y = bilby.prior.JointPrior(dist=prior_dist_1, name="y") + joint_prior_z = bilby.prior.JointPrior(dist=prior_dist_1, name="z") + + r = 1 + prior_constraint_3 = bilby.prior.Constraint(minimum=0, 
maximum=r**2, name="constraint") + + joint_prior_x_2 = bilby.prior.JointPrior(dist=prior_dist_2, name="x") + joint_prior_y_2 = bilby.prior.JointPrior(dist=prior_dist_2, name="y") + prior_constraint_4 = bilby.prior.Constraint(minimum=0, maximum=r**2, name="constraint") + prior_dict_4 = bilby.prior.PriorDict( + { + "x": joint_prior_x_2, + "y": joint_prior_y_2, + "constraint": prior_constraint_4, + }, + conversion_function=radius_squared, + ) + prior_dict_3 = bilby.prior.PriorDict( + { + "x": joint_prior_x, + "y": joint_prior_y, + "z": joint_prior_z, + "constraint": prior_constraint_3, + }, + conversion_function=radius_squared, + ) + + factor_3 = prior_dict_3.normalize_constraint_factor(keys=("x", "y")) + factor_4 = prior_dict_4.normalize_constraint_factor(keys=("x", "y")) + + from scipy.stats import chi2 + + truth_3 = 1 / (chi2(df=2).cdf(1)) + self.assertAlmostEqual(truth_3, factor_3, delta=truth_3 * 0.01) + self.assertAlmostEqual(truth_3, factor_4, delta=truth_3 * 0.01) + def test_sample(self): size = 7 bilby.core.utils.random.seed(42)