Skip to content

Commit 8e4b8e2

Browse files
Jammy2211Jammy2211
authored andcommitted
jax solver tests reinstated
1 parent f241b35 commit 8e4b8e2

2 files changed

Lines changed: 77 additions & 135 deletions

File tree

autolens/point/solver/shape_solver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from autoarray.structures.triangles.shape import Shape
1010
from autofit.jax_wrapper import register_pytree_node_class
1111

12-
from autoarray.structures.triangles.coordinate_array.jax_coordinate_array import (
12+
from autoarray.structures.triangles.coordinate_array import (
1313
CoordinateArrayTriangles,
1414
)
1515
from autoarray.structures.triangles.abstract import AbstractTriangles
Lines changed: 76 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -1,141 +1,83 @@
1+
import numpy as np
2+
import pytest
13
import time
24
from typing import Tuple
35

4-
import pytest
5-
66
import autogalaxy as ag
77
import autofit as af
8-
import numpy as np
98
from autolens import PointSolver, Tracer
109

11-
#
12-
# try:
13-
# from autoarray.structures.triangles.coordinate_array.jax_coordinate_array import (
14-
# CoordinateArrayTriangles,
15-
# )
16-
#
17-
# except ImportError:
18-
# from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles
19-
#
20-
# from autolens.mock import NullTracer
21-
#
22-
# pytest.importorskip("jax")
23-
#
24-
#
25-
# @pytest.fixture(autouse=True)
26-
# def register(tracer):
27-
# af.Model.from_instance(tracer)
28-
#
29-
#
30-
# @pytest.fixture
31-
# def solver(grid):
32-
# return PointSolver.for_grid(
33-
# grid=grid,
34-
# pixel_scale_precision=0.01,
35-
# array_triangles_cls=CoordinateArrayTriangles,
36-
# )
37-
#
38-
#
39-
# def test_solver(solver):
40-
# mass_profile = ag.mp.Isothermal(
41-
# centre=(0.0, 0.0),
42-
# einstein_radius=1.0,
43-
# )
44-
# tracer = Tracer(
45-
# galaxies=[ag.Galaxy(redshift=0.5, mass=mass_profile)],
46-
# )
47-
# result = solver.solve(
48-
# tracer,
49-
# source_plane_coordinate=(0.0, 0.0),
50-
# )
51-
# print(result)
52-
# assert result
53-
#
54-
#
55-
# @pytest.mark.parametrize(
56-
# "source_plane_coordinate",
57-
# [
58-
# (0.0, 0.0),
59-
# (0.0, 1.0),
60-
# (1.0, 0.0),
61-
# (1.0, 1.0),
62-
# (0.5, 0.5),
63-
# (0.1, 0.1),
64-
# (-1.0, -1.0),
65-
# ],
66-
# )
67-
# def test_trivial(
68-
# source_plane_coordinate: Tuple[float, float],
69-
# grid,
70-
# solver,
71-
# ):
72-
# coordinates = solver.solve(
73-
# NullTracer(),
74-
# source_plane_coordinate=source_plane_coordinate,
75-
# )
76-
# coordinates = coordinates.array[~np.isnan(coordinates.array).any(axis=1)]
77-
# assert coordinates[0] == pytest.approx(source_plane_coordinate, abs=1.0e-1)
78-
#
79-
#
80-
# def test_real_example(grid, tracer):
81-
# solver = PointSolver.for_grid(
82-
# grid=grid,
83-
# pixel_scale_precision=0.001,
84-
# array_triangles_cls=CoordinateArrayTriangles,
85-
# )
86-
#
87-
# result = solver.solve(tracer, (0.07, 0.07))
88-
# assert len(result) == 5
89-
#
90-
#
91-
# def _test_jax(grid):
92-
# sizes = (5, 10, 15, 20, 25, 30, 35, 40, 45, 50)
93-
# run_times = []
94-
# init_times = []
95-
#
96-
# for size in sizes:
97-
# start = time.time()
98-
# solver = PointSolver.for_grid(
99-
# grid=grid,
100-
# pixel_scale_precision=0.001,
101-
# array_triangles_cls=CoordinateArrayTriangles,
102-
# max_containing_size=size,
103-
# )
104-
#
105-
# solver.solve(NullTracer(), (0.07, 0.07))
106-
#
107-
# repeats = 100
108-
#
109-
# done_init_time = time.time()
110-
# init_time = done_init_time - start
111-
# for _ in range(repeats):
112-
# _ = solver.solve(NullTracer(), (0.07, 0.07))
113-
#
114-
# # print(result)
115-
#
116-
# init_times.append(init_time)
117-
#
118-
# run_time = (time.time() - done_init_time) / repeats
119-
# run_times.append(run_time)
120-
#
121-
# print(f"Time taken for {size}: {run_time} ({init_time} to init)")
122-
#
123-
# from matplotlib import pyplot as plt
124-
#
125-
# plt.plot(sizes, run_times)
126-
# plt.show()
127-
#
128-
#
129-
# def test_real_example_jax(grid, tracer):
130-
# jax_solver = PointSolver.for_grid(
131-
# grid=grid,
132-
# pixel_scale_precision=0.001,
133-
# array_triangles_cls=CoordinateArrayTriangles,
134-
# )
135-
#
136-
# result = jax_solver.solve(
137-
# tracer=tracer,
138-
# source_plane_coordinate=(0.07, 0.07),
139-
# )
140-
#
141-
# assert len(result) == 5
10+
from autoarray.structures.triangles.coordinate_array import (
11+
CoordinateArrayTriangles,
12+
)
13+
14+
from autolens.mock import NullTracer
15+
16+
17+
@pytest.fixture(autouse=True)
18+
def register(tracer):
19+
af.Model.from_instance(tracer)
20+
21+
22+
@pytest.fixture
23+
def solver(grid):
24+
return PointSolver.for_grid(
25+
grid=grid,
26+
pixel_scale_precision=0.01,
27+
array_triangles_cls=CoordinateArrayTriangles,
28+
)
29+
30+
31+
def test_solver(solver):
32+
mass_profile = ag.mp.Isothermal(
33+
centre=(0.0, 0.0),
34+
einstein_radius=1.0,
35+
)
36+
tracer = Tracer(
37+
galaxies=[ag.Galaxy(redshift=0.5, mass=mass_profile)],
38+
)
39+
result = solver.solve(
40+
tracer,
41+
source_plane_coordinate=(0.0, 0.0),
42+
)
43+
print(result)
44+
assert result
45+
46+
47+
@pytest.mark.parametrize(
48+
"source_plane_coordinate",
49+
[
50+
(0.0, 0.0),
51+
(0.0, 1.0),
52+
(1.0, 0.0),
53+
(1.0, 1.0),
54+
(0.5, 0.5),
55+
(0.1, 0.1),
56+
(-1.0, -1.0),
57+
],
58+
)
59+
def test_trivial(
60+
source_plane_coordinate: Tuple[float, float],
61+
grid,
62+
solver,
63+
):
64+
coordinates = solver.solve(
65+
NullTracer(),
66+
source_plane_coordinate=source_plane_coordinate,
67+
)
68+
coordinates = coordinates.array[~np.isnan(coordinates.array).any(axis=1)]
69+
assert coordinates[0] == pytest.approx(source_plane_coordinate, abs=1.0e-1)
70+
71+
def test_real_example_jax(grid, tracer):
72+
jax_solver = PointSolver.for_grid(
73+
grid=grid,
74+
pixel_scale_precision=0.001,
75+
array_triangles_cls=CoordinateArrayTriangles,
76+
)
77+
78+
result = jax_solver.solve(
79+
tracer=tracer,
80+
source_plane_coordinate=(0.07, 0.07),
81+
)
82+
83+
assert len(result) == 15

0 commit comments

Comments
 (0)