Skip to content

Commit cb54409

Browse files
authored
Merge pull request #1176 from rhayes777/feature/jax-search-logging
Make search logging JAX-aware
2 parents b7eff1e + 380f73b commit cb54409

2 files changed

Lines changed: 22 additions & 6 deletions

File tree

autofit/non_linear/search/abstract_search.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,20 @@ class represented by model M and gives a score for their fitness.
470470
"""
471471
self.check_model(model=model)
472472

473-
logger.info(f"Starting non-linear search with {self.number_of_cores} cores.")
473+
if getattr(analysis, "_use_jax", False):
474+
try:
475+
import jax
476+
devices = jax.devices()
477+
device = devices[0]
478+
backend = device.platform.upper()
479+
device_name = getattr(device, "device_kind", backend)
480+
logger.info(
481+
f"Starting non-linear search with JAX ({backend}: {device_name})."
482+
)
483+
except Exception:
484+
logger.info("Starting non-linear search with JAX.")
485+
else:
486+
logger.info(f"Starting non-linear search with {self.number_of_cores} cores.")
474487
self._log_process_state()
475488

476489
model = analysis.modify_model(model)

autofit/non_linear/search/nest/nautilus/search.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -216,11 +216,14 @@ def fit_x1_cpu(self, fitness, model, analysis):
216216
the log likelihood the search maximizes.
217217
"""
218218

219-
self.logger.info(
220-
"""
221-
Running search where parallelization is disabled.
222-
"""
223-
)
219+
if analysis._use_jax:
220+
self.logger.info(
221+
"Running search with JAX vectorization (parallelization handled by JAX)."
222+
)
223+
else:
224+
self.logger.info(
225+
"Running search where parallelization is disabled."
226+
)
224227

225228
config_dict = self.config_dict_search
226229
try:

0 commit comments

Comments
 (0)