Skip to content

Commit a77f7ef

Browse files
Jammy2211Jammy2211
authored andcommitted
minor changes
1 parent b8f7a45 commit a77f7ef

6 files changed

Lines changed: 18 additions & 8 deletions

File tree

autofit/mapper/prior/abstract.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from typing import Union, Tuple, Optional, Dict
66

77
from autoconf import conf
8-
from autoconf import jax_wrapper
98

109
from autofit.mapper.prior.arithmetic import ArithmeticMixin
1110
from autofit.mapper.prior.constant import Constant

autofit/mapper/prior_model/abstract.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
import copy
22
import inspect
3-
import jax.numpy as jnp
4-
import jax
53
import json
64
import logging
75
import random

autofit/non_linear/analysis/analysis.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def method(*args, **kwargs):
5858

5959
return method
6060

61-
def compute_latent_samples(self, samples: Samples) -> Optional[Samples]:
61+
def compute_latent_samples(self, samples: Samples, batch_size : Optional[int] = None) -> Optional[Samples]:
6262
"""
6363
Compute latent variables from a model instance.
6464
@@ -91,11 +91,19 @@ def compute_latent_samples(self, samples: Samples) -> Optional[Samples]:
9191
`(intensity_total, magnitude, angle)`. Each entry may be NaN if the corresponding component
9292
of the model is not present.
9393
"""
94+
95+
if use_jax:
96+
xp = jnp
97+
else:
98+
xp = np
99+
100+
batch_size = batch_size or 10
101+
94102
try:
95103

96104
start_latent = time.time()
97105

98-
compute_latent_for_model = functools.partial(self.compute_latent_variables, model=samples.model)
106+
compute_latent_for_model = functools.partial(self.compute_latent_variables, model=samples.model, xp=xp)
99107

100108
if use_jax:
101109
start = time.time()
@@ -107,7 +115,6 @@ def batched_compute_latent(x):
107115
return np.array([compute_latent_for_model(xx) for xx in x])
108116

109117
parameter_array = np.array(samples.parameter_lists)
110-
batch_size = 50
111118
latent_samples = []
112119

113120
# process in batches

autofit/non_linear/fitness.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def __init__(
4343
convert_to_chi_squared: bool = False,
4444
store_history: bool = False,
4545
use_jax_vmap : bool = False,
46+
batch_size : Optional[int] = None,
4647
iterations_per_quick_update: Optional[int] = None,
4748
):
4849
"""
@@ -123,6 +124,7 @@ def __init__(
123124
if self.use_jax_vmap:
124125
self._call = self._vmap
125126

127+
self.batch_size = batch_size
126128
self.iterations_per_quick_update = iterations_per_quick_update
127129
self.quick_update_max_lh_parameters = None
128130
self.quick_update_max_lh = -xp.inf
@@ -152,7 +154,7 @@ def call(self, parameters):
152154
instance = self.model.instance_from_vector(vector=parameters)
153155

154156
# Evaluate log likelihood (must be side-effect free and exception-free)
155-
log_likelihood = self.analysis.log_likelihood_function(instance=instance)
157+
log_likelihood = self.analysis.log_likelihood_function(instance=instance, xp=xp)
156158

157159
# Penalize NaNs in the log-likelihood
158160
log_likelihood = xp.where(xp.isnan(log_likelihood), self.resample_figure_of_merit, log_likelihood)

autofit/non_linear/search/abstract_search.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -939,7 +939,10 @@ def perform_update(
939939

940940
latent_samples = samples_save
941941

942-
latent_samples = analysis.compute_latent_samples(latent_samples)
942+
latent_samples = analysis.compute_latent_samples(
943+
latent_samples,
944+
batch_size=fitness.batch_size
945+
)
943946

944947
if latent_samples:
945948
if not conf.instance["output"]["latent_draw_via_pdf"]:

autofit/non_linear/search/nest/nautilus/search.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ def _fit(self, model: AbstractPriorModel, analysis):
139139
fom_is_log_likelihood=True,
140140
resample_figure_of_merit=-1.0e99,
141141
use_jax_vmap=True,
142+
batch_size=self.config_dict_search["n_batch"],
142143
iterations_per_quick_update=self.iterations_per_quick_update
143144

144145
)

0 commit comments

Comments
 (0)