Skip to content

Commit b4babbc

Browse files
Jammy2211Jammy2211
authored andcommitted
fitness passed around due to JAX
1 parent 2d45f2d commit b4babbc

2 files changed

Lines changed: 6 additions & 7 deletions

File tree

autofit/non_linear/mock/mock_search.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def __call__(self, vector):
103103
fitness = Fitness(model.instance_from_vector, result=self.result)
104104
fitness([prior.mean for prior in model.priors_ordered_by_id])
105105

106-
return fitness.result
106+
return fitness.result, fitness
107107

108108
def _fit(self, model, analysis):
109109
if self.fit_fast:
@@ -141,9 +141,9 @@ def _fit(self, model, analysis):
141141
return analysis.make_result(
142142
samples_summary=samples_summary,
143143
paths=self.paths,
144-
)
144+
), None
145145

146-
def perform_update(self, model, analysis, during_analysis, search_internal=None):
146+
def perform_update(self, model, analysis, during_analysis, fitness=None, search_internal=None):
147147
if self.samples_summary is not None and not self.return_sensitivity_results:
148148
self.paths.save_samples_summary(self.samples_summary)
149149

autofit/non_linear/search/nest/nautilus/search.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from autofit.non_linear.search.nest import abstract_nest
1515
from autofit.non_linear.samples.sample import Sample
1616
from autofit.non_linear.samples.nest import SamplesNest
17-
from autogalaxy_workspace_test.jax_examples.func_grad import fitness
1817

1918
logger = logging.getLogger(__name__)
2019

@@ -225,7 +224,7 @@ def fit_x1_cpu(self, fitness, model, analysis):
225224
**self.config_dict_search,
226225
)
227226

228-
return self.call_search(search_internal=search_internal, model=model, analysis=analysis)
227+
return self.call_search(search_internal=search_internal, model=model, analysis=analysis, fitness=fitness)
229228

230229
def fit_multiprocessing(self, fitness, model, analysis):
231230
"""
@@ -259,9 +258,9 @@ def fit_multiprocessing(self, fitness, model, analysis):
259258
**self.config_dict_search,
260259
)
261260

262-
return self.call_search(search_internal=search_internal, model=model, analysis=analysis)
261+
return self.call_search(search_internal=search_internal, model=model, analysis=analysis, fitness=fitness)
263262

264-
def call_search(self, search_internal, model, analysis):
263+
def call_search(self, search_internal, model, analysis, fitness):
265264
"""
266265
The x1 CPU and multiprocessing searches both call this function to perform the non-linear search.
267266

0 commit comments

Comments
 (0)