Skip to content

Commit a0e4fa9

Browse files
Jammy2211Jammy2211
authored andcommitted
aggregator works with FactorGraph
1 parent 9de61b8 commit a0e4fa9

3 files changed

Lines changed: 78 additions & 65 deletions

File tree

autolens/aggregator/tracer.py

Lines changed: 44 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,14 @@
22
from typing import List, Optional
33

44
import autofit as af
5+
import autoarray as aa
6+
import autogalaxy as ag
57

68
from autolens.lens.tracer import Tracer
79

10+
from autogalaxy.aggregator import agg_util
11+
from autolens.lens import tracer_util
12+
813
logger = logging.getLogger(__name__)
914

1015

@@ -18,16 +23,15 @@ def _tracer_from(
1823
attributes of the fit:
1924
2025
- The model and its best fit parameters (e.g. `model.json`).
21-
- The adapt images associated with adaptive galaxy features (`adapt` folder).
2226
2327
Each individual attribute can be loaded from the database via the `fit.value()` method.
2428
25-
This method combines all of these attributes and returns a `Tracer` object for a given non-linear search sample
26-
(e.g. the maximum likelihood model). This includes associating adapt images with their respective galaxies.
29+
This method combines this attributesand returns a `Tracer` object for a given non-linear search sample
30+
(e.g. the maximum likelihood model).
2731
28-
If multiple `Tracer` objects were fitted simultaneously via analysis summing, the `fit.child_values()` method
29-
is instead used to load lists of Tracers. This is necessary if each Tracer has different galaxies (e.g. certain
30-
parameters vary across each dataset and `Analysis` object).
32+
If multiple `Tracer` objects were fitted simultaneously via multiple analysis, the instance is iterated over as
33+
a list such that a list of `Tracer` objects with parameters updated for each analysis are returned. This means
34+
fits using a single analysis are wrapped in a list to prodcue a consistent API.
3135
3236
Parameters
3337
----------
@@ -39,41 +43,46 @@ def _tracer_from(
3943
randomly from the PDF).
4044
"""
4145

42-
if instance is not None:
46+
instance_list = agg_util.instance_list_from(fit=fit, instance=instance)
47+
48+
tracer_list = []
49+
50+
for instance in instance_list:
51+
4352
galaxies = instance.galaxies
4453

4554
if hasattr(instance, "extra_galaxies"):
46-
if fit.instance.extra_galaxies is not None:
47-
galaxies = galaxies + fit.instance.extra_galaxies
48-
49-
else:
50-
galaxies = fit.instance.galaxies
51-
52-
if hasattr(fit.instance, "extra_galaxies"):
53-
if fit.instance.extra_galaxies is not None:
54-
galaxies = galaxies + fit.instance.extra_galaxies
55-
56-
try:
57-
cosmology = instance.cosmology
58-
except AttributeError:
59-
cosmology = fit.value(name="cosmology")
60-
61-
tracer = Tracer(galaxies=galaxies, cosmology=cosmology)
62-
63-
if fit.children is not None:
64-
if len(fit.children) > 0:
65-
logger.info(
66-
"""
67-
Using database for a fit with multiple summed Analysis objects.
68-
69-
Tracer objects do not fully support this yet (e.g. model parameters which vary over analyses may be incorrect)
70-
so proceed with caution!
71-
"""
55+
if instance.extra_galaxies is not None:
56+
galaxies = galaxies + instance.extra_galaxies
57+
58+
try:
59+
cosmology = instance.cosmology
60+
except AttributeError:
61+
cosmology = fit.value(name="cosmology")
62+
63+
if cosmology is None:
64+
cosmology = ag.cosmo.Planck15()
65+
66+
# TODO : These are ugly as hell (>_<)
67+
68+
if hasattr(instance, "perturb"):
69+
galaxies.subhalo = instance.perturb
70+
71+
if hasattr(instance.galaxies, "subhalo"):
72+
subhalo_centre = tracer_util.grid_2d_at_redshift_from(
73+
galaxies=instance.galaxies,
74+
redshift=instance.galaxies.subhalo.redshift,
75+
grid=aa.Grid2DIrregular(values=[instance.galaxies.subhalo.mass.centre]),
76+
cosmology=cosmology,
7277
)
7378

74-
return [tracer] * len(fit.children)
79+
galaxies.subhalo.mass.centre = tuple(subhalo_centre.in_list[0])
80+
81+
tracer = Tracer(galaxies=galaxies, cosmology=cosmology)
82+
83+
tracer_list.append(tracer)
7584

76-
return [tracer]
85+
return tracer_list
7786

7887

7988
class TracerAgg(af.AggBase):

test_autolens/aggregator/test_aggregator_fit_imaging.py

Lines changed: 32 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -35,36 +35,38 @@ def test__fit_imaging_randomly_drawn_via_pdf_gen_from(
3535
clean(database_file=database_file)
3636

3737

38-
def test__fit_imaging_randomly_drawn_via_pdf_gen_from__analysis_multi(
39-
analysis_imaging_7x7, samples, model
40-
):
41-
agg = aggregator_from(
42-
database_file=database_file,
43-
analysis=analysis_imaging_7x7 + analysis_imaging_7x7,
44-
model=model,
45-
samples=samples,
46-
)
47-
48-
fit_agg = al.agg.FitImagingAgg(aggregator=agg)
49-
fit_pdf_gen = fit_agg.randomly_drawn_via_pdf_gen_from(total_samples=2)
50-
51-
i = 0
52-
53-
for fit_gen in fit_pdf_gen:
54-
for fit_list in fit_gen:
55-
i += 1
56-
57-
assert fit_list[0].tracer.galaxies[0].redshift == 0.5
58-
assert fit_list[0].tracer.galaxies[0].light.centre == (10.0, 10.0)
59-
assert fit_list[0].tracer.galaxies[1].redshift == 1.0
60-
61-
assert fit_list[1].tracer.galaxies[0].redshift == 0.5
62-
assert fit_list[1].tracer.galaxies[0].light.centre == (10.0, 10.0)
63-
assert fit_list[1].tracer.galaxies[1].redshift == 1.0
64-
65-
assert i == 2
66-
67-
clean(database_file=database_file)
38+
# TODO : These need to use FactorGraphModel
39+
40+
# def test__fit_imaging_randomly_drawn_via_pdf_gen_from__analysis_multi(
41+
# analysis_imaging_7x7, samples, model
42+
# ):
43+
# agg = aggregator_from(
44+
# database_file=database_file,
45+
# analysis=analysis_imaging_7x7 + analysis_imaging_7x7,
46+
# model=model,
47+
# samples=samples,
48+
# )
49+
#
50+
# fit_agg = al.agg.FitImagingAgg(aggregator=agg)
51+
# fit_pdf_gen = fit_agg.randomly_drawn_via_pdf_gen_from(total_samples=2)
52+
#
53+
# i = 0
54+
#
55+
# for fit_gen in fit_pdf_gen:
56+
# for fit_list in fit_gen:
57+
# i += 1
58+
#
59+
# assert fit_list[0].tracer.galaxies[0].redshift == 0.5
60+
# assert fit_list[0].tracer.galaxies[0].light.centre == (10.0, 10.0)
61+
# assert fit_list[0].tracer.galaxies[1].redshift == 1.0
62+
#
63+
# assert fit_list[1].tracer.galaxies[0].redshift == 0.5
64+
# assert fit_list[1].tracer.galaxies[0].light.centre == (10.0, 10.0)
65+
# assert fit_list[1].tracer.galaxies[1].redshift == 1.0
66+
#
67+
# assert i == 2
68+
#
69+
# clean(database_file=database_file)
6870

6971

7072
def test__fit_imaging_all_above_weight_gen(analysis_imaging_7x7, samples, model):

test_autolens/aggregator/test_aggregator_fit_interferometer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ def test__fit_interferometer_randomly_drawn_via_pdf_gen_from(
3535
clean(database_file=database_file)
3636

3737

38+
# TODO : These need to use FactorGraphModel
39+
3840
# def test__fit_interferometer_randomly_drawn_via_pdf_gen_from__analysis_multi(analysis_interferometer_7, samples, model):
3941
#
4042
# agg = aggregator_from(

0 commit comments

Comments
 (0)