Skip to content

Commit 0ce417b

Browse files
committed
Add new generator: PowerTermGenerator
The new generator produces a line of 1-D points according to the equation: y = A((x-xf)/a)^n + yf The parameter yf acts as a focus point around which the point density will be highest, decreasing as we move away from it. The generator takes the arguments axis, units, start, stop, focus, exponent (n), and divisor (a) - 'start' is used to find xf (where y(xf)=yf): xf = a * nth root(|start - focus|) - 'stop' determines the scan size: size = int(f^-1(stop)) - 'A' will be 1 or -1, determined from start, stop, focus and exponent arguments (the sign of 'a' is ignored) - If the exponent is even, it is assumed that we will pass through the focus point.
1 parent 5836341 commit 0ce417b

File tree

3 files changed

+286
-0
lines changed

3 files changed

+286
-0
lines changed

scanpointgenerator/generators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,4 @@
1717
from scanpointgenerator.generators.linegenerator import LineGenerator
1818
from scanpointgenerator.generators.lissajousgenerator import LissajousGenerator
1919
from scanpointgenerator.generators.spiralgenerator import SpiralGenerator
20+
from scanpointgenerator.generators.powertermgenerator import PowerTermGenerator
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
from scanpointgenerator.core import Generator
2+
from scanpointgenerator.compat import np
3+
4+
5+
@Generator.register_subclass("scanpointgenerator:generator/PowerTermGenerator:1.0")
6+
class PowerTermGenerator(Generator):
7+
"""
8+
Generate a line of points according to the function
9+
y = sign * ((x-xf)/a)**n + yf
10+
11+
sign is determined by start, stop, focus and exponent parameters
12+
"""
13+
14+
def __init__(self, axis, units, start, stop, focus, exponent, divisor):
15+
"""
16+
y = ((x-x_focus)/divisor)**exponent + focus
17+
Args:
18+
axis (str): The scannable axis e.g. "dcm_energy"
19+
units (str): The scannable units e.g. "keV"
20+
start (float): The first position to be generated.
21+
stop (float): Will determine scan size. The final generated position will not necessarily be this...
22+
focus (float): Point of interest which will be most finely sampled
23+
e.g. 7.112 (for Fe K edge)
24+
exponent (int): If exponent is even, it is assumed we pass through the focus point.
25+
divisor (float): Sign will be ignored.
26+
"""
27+
28+
if divisor == 0:
29+
raise ValueError("Divisor must be non-zero")
30+
31+
if exponent < 1 or exponent != int(exponent):
32+
raise ValueError("Exponent must be a positive integer")
33+
34+
self.sign = get_suitable_sign(start, stop, focus, exponent)
35+
self.exponent = exponent
36+
self.divisor = np.abs(divisor)
37+
self.focus = focus
38+
self.axes = [axis]
39+
self.units = {axis: units}
40+
self.start = start
41+
self.stop = stop
42+
43+
self.xf = self.find_xf()
44+
self.size = int(self._inv_fn(stop))+1
45+
46+
def prepare_arrays(self, index_array):
47+
arrays = dict()
48+
arrays[self.axes[0]] = self._fn(index_array)
49+
return arrays
50+
51+
def _fn(self, x):
52+
return self.sign * np.power((x - self.xf) / self.divisor, self.exponent) + self.focus
53+
54+
def _inv_fn(self, y):
55+
nth_root = np.power(np.abs(y-self.focus), 1./self.exponent)
56+
if not self.stop_after_focus():
57+
nth_root *= -1
58+
return self.divisor * nth_root + self.xf
59+
60+
def find_xf(self):
61+
x = self.divisor * np.power(np.abs(self.start-self.focus), 1./self.exponent)
62+
return x if self.start_before_focus() else -x
63+
64+
def start_before_focus(self):
65+
if self.exponent % 2 == 0:
66+
return True
67+
return self.sign * self.start < self.sign * self.focus
68+
69+
def stop_after_focus(self):
70+
if self.exponent % 2 == 0:
71+
return True
72+
return self.sign * self.stop > self.sign * self.focus
73+
74+
def to_dict(self):
75+
d = dict()
76+
d['typeid'] = self.typeid
77+
d['axes'] = self.axes
78+
d['units'] = self.units[self.axes[0]]
79+
d['start'] = self.start
80+
d['stop'] = self.stop
81+
d['focus'] = self.focus
82+
d['exponent'] = self.exponent
83+
d['divisor'] = self.divisor
84+
85+
return d
86+
87+
@classmethod
88+
def from_dict(cls, d):
89+
axes = d['axes']
90+
units = d['units']
91+
start = d['start']
92+
stop = d['stop']
93+
exponent = d['exponent']
94+
divisor = d['divisor']
95+
focus = d['focus']
96+
97+
return cls(axes, units, start, stop, focus, exponent, divisor)
98+
99+
100+
def get_suitable_sign(start, stop, focus, exponent):
101+
if exponent % 2 == 1:
102+
return 1 if start < stop else -1
103+
else:
104+
if focus <= start and focus <= stop:
105+
return 1
106+
if focus >= start and focus >= stop:
107+
return -1
108+
raise ValueError("For even exponents, focus point must be either lowest or highest value")
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
import os
2+
import sys
3+
import unittest
4+
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
5+
6+
from test_util import ScanPointGeneratorTest
7+
from scanpointgenerator import PowerTermGenerator
8+
9+
10+
def _get_gen(start, stop, focus, exponent):
11+
return PowerTermGenerator('x', 'mm', start, stop, focus, exponent, 1)
12+
13+
14+
class PowerGeneratorTest(ScanPointGeneratorTest):
15+
16+
def test_axis_and_units(self):
17+
gen = _get_gen(0, 100, 20, 3)
18+
self.assertEqual(gen.units, dict(x="mm"))
19+
20+
def test_array_positions(self):
21+
# We only need to test positions for one set of parameters
22+
# as long as we can reliably find the sign, xf and size
23+
gen = PowerTermGenerator("x", "mm", 260., 360., 280., 3, 10)
24+
25+
expected = [260., 262.12998637, 264.10310768, 265.92536394,
26+
267.60275514, 269.14128128, 270.54694237, 271.8257384,
27+
272.98366937, 274.02673528, 274.96093614, 275.79227194,
28+
276.52674269, 277.17034837, 277.729089, 278.20896458,
29+
278.61597509, 278.95612055, 279.23540095, 279.4598163,
30+
279.63536658, 279.76805182, 279.86387199, 279.92882711,
31+
279.96891717, 279.99014217, 279.99850211, 279.999997,
32+
280.00062683, 280.00639161, 280.02329133, 280.05732599,
33+
280.11449559, 280.20080014, 280.32223963, 280.48481406,
34+
280.69452344, 280.95736776, 281.27934702, 281.66646122,
35+
282.12471037, 282.66009446, 283.2786135, 283.98626747,
36+
284.78905639, 285.69298026, 286.70403906, 287.82823281,
37+
289.0715615, 290.44002514, 291.93962371, 293.57635724,
38+
295.3562257, 297.28522911, 299.36936746, 301.61464075,
39+
304.02704899, 306.61259216, 309.37727029, 312.32708335,
40+
315.46803136, 318.80611431, 322.3473322, 326.09768504,
41+
330.06317282, 334.24979554, 338.66355321, 343.31044582,
42+
348.19647337, 353.32763587, 358.7099333]
43+
44+
gen.prepare_positions()
45+
self.assertListAlmostEqual(gen.positions['x'], expected)
46+
47+
# The following 12 tests test that the equation sign, xf, and scan size
48+
# are correctly calculated in 12 distinct scenarios
49+
# (the first 6 with an odd exponent, the final 6 with an even exponent)
50+
51+
def test_params1(self):
52+
# 1) start < focus < stop
53+
gen = _get_gen(0, 100, 27, 3)
54+
self._check_params(gen, 1, 3, 8)
55+
56+
def test_params2(self):
57+
# 2) start < stop < focus
58+
gen = _get_gen(0, 26, 27, 3)
59+
self._check_params(gen, 1, 3, 3)
60+
61+
def test_params3(self):
62+
# 3) focus < start < stop
63+
gen = _get_gen(20, 85, 12, 3)
64+
self._check_params(gen, 1, -2, 3)
65+
66+
def test_params4(self):
67+
# 4) start > focus > stop
68+
gen = _get_gen(77, 0, 50, 3)
69+
self._check_params(gen, -1, 3, 7)
70+
71+
def test_params5(self):
72+
# 5) start > stop > focus
73+
gen = _get_gen(27, 1, 0, 3)
74+
self._check_params(gen, -1, 3, 3)
75+
76+
def test_params6(self):
77+
# 6) focus > start > stop
78+
gen = _get_gen(73, 0, 100, 3)
79+
self._check_params(gen, -1, -3, 2)
80+
81+
def test_params7(self):
82+
# 7) focus < start < stop
83+
gen = _get_gen(9, 8, 0, 2)
84+
self._check_params(gen, 1, 3, 6)
85+
86+
def test_params8(self):
87+
# 8) focus < stop < start
88+
gen = _get_gen(9, 12, 0, 2)
89+
self._check_params(gen, 1, 3, 7)
90+
91+
def test_params9(self):
92+
# 9) focus < start = stop
93+
gen = _get_gen(9, 9, 0, 2)
94+
self._check_params(gen, 1, 3, 7)
95+
96+
def test_params10(self):
97+
# 10) focus > start > stop
98+
gen = _get_gen(2, 0, 18, 2)
99+
self._check_params(gen, -1, 4, 9)
100+
101+
def test_params11(self):
102+
# 11) focus > stop > start
103+
gen = _get_gen(2, 4, 18, 2)
104+
self._check_params(gen, -1, 4, 8)
105+
106+
def test_params12(self):
107+
# 12) focus > stop = start
108+
gen = _get_gen(0, 0, 9, 2)
109+
self._check_params(gen, -1, 3, 7)
110+
111+
def _check_params(self, gen, sign, xf, size):
112+
self.assertEquals(gen.sign, sign)
113+
self.assertEquals(gen.xf, xf)
114+
self.assertEquals(gen.size, size)
115+
116+
def test_to_dict(self):
117+
g = PowerTermGenerator('energy', 'eV', 260., 350., 280., 3, 5)
118+
expected = dict()
119+
expected['typeid'] = "scanpointgenerator:generator/PowerTermGenerator:1.0"
120+
expected['axes'] = ['energy']
121+
expected['units'] = "eV"
122+
expected['start'] = 260.
123+
expected['stop'] = 350.
124+
expected['focus'] = 280.
125+
expected['exponent'] = 3
126+
expected['divisor'] = 5.
127+
128+
self.assertEquals(g.to_dict(), expected)
129+
130+
def test_from_dict(self):
131+
_dict = dict()
132+
_dict['axes'] = "x"
133+
_dict['units'] = "cm"
134+
_dict['start'] = 270.
135+
_dict['stop'] = 500.
136+
_dict['focus'] = 280.
137+
_dict['exponent'] = 3
138+
_dict['divisor'] = 20.5
139+
140+
units_dict = dict()
141+
units_dict['x'] = "cm"
142+
143+
gen = PowerTermGenerator.from_dict(_dict)
144+
145+
self.assertEqual(gen.axes, ["x"])
146+
self.assertEqual(gen.units, units_dict)
147+
self.assertEqual(gen.start, 270.)
148+
self.assertEqual(gen.stop, 500.)
149+
self.assertEqual(gen.focus, 280.)
150+
self.assertEqual(gen.exponent, 3)
151+
self.assertEqual(gen.divisor, 20.5)
152+
153+
# Argument validation tests
154+
def test_zero_divisor_raises(self):
155+
with self.assertRaises(ValueError):
156+
PowerTermGenerator('x', 'mm', 0, 10, 5, 3, 0.)
157+
158+
def test_fractional_exponent_raises(self):
159+
with self.assertRaises(ValueError):
160+
PowerTermGenerator('x', 'mm', 0, 10, 10, 3.5, 1)
161+
162+
def test_negative_exponent_raises(self):
163+
with self.assertRaises(ValueError):
164+
PowerTermGenerator('x', 'mm', 0, 10, 10, -5, 1)
165+
166+
def test_invalid_parameters1(self):
167+
# even exponent and start < focus < stop
168+
with self.assertRaises(ValueError):
169+
PowerTermGenerator('x', 'mm', 0, 100, 50, 2, 1)
170+
171+
def test_invalid_parameters2(self):
172+
# even exponent and start > focus > stop
173+
with self.assertRaises(ValueError):
174+
PowerTermGenerator('x', 'mm', 100, 0, 50, 2, 1)
175+
176+
if __name__ == "__main__":
177+
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)