Skip to content

Commit 8c9281b

Browse files
committed
ENH: add demo of binding widgets to sliders
1 parent cab4a67 commit 8c9281b

File tree

2 files changed

+144
-4
lines changed

2 files changed

+144
-4
lines changed

data_prototype/containers.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,15 +165,19 @@ def _split(input_dict):
165165
self._xyfuncs = _split(xyfuncs) if xyfuncs is not None else {}
166166
self._cache: MutableMapping[Union[str, int], Any] = LFUCache(64)
167167

168+
def _query_hash(self, coord_transform, size):
169+
# TODO find a better way to compute the hash key, this is not sentative to
170+
# scale changes, only limit changes
171+
data_bounds = tuple(coord_transform.transform([[0, 0], [1, 1]]).flatten())
172+
hash_key = hash((data_bounds, size))
173+
return hash_key
174+
168175
def query(
169176
self,
170177
coord_transform: _MatplotlibTransform,
171178
size: Tuple[int, int],
172179
) -> Tuple[Dict[str, Any], Union[str, int]]:
173-
# TODO find a better way to compute the hash key, this is not sentative to
174-
# scale changes, only limit changes
175-
data_bounds = tuple(coord_transform.transform([[0, 0], [1, 1]]).flatten())
176-
hash_key = hash((data_bounds, size))
180+
hash_key = self._query_hash(coord_transform, size)
177181
if hash_key in self._cache:
178182
return self._cache[hash_key], hash_key
179183

examples/widgets.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
"""
2+
======
3+
Slider
4+
======
5+
6+
In this example, sliders are used to control the frequency and amplitude of
7+
a sine wave.
8+
9+
See :doc:`/gallery/widgets/slider_snap_demo` for an example of having
10+
the ``Slider`` snap to discrete values.
11+
12+
See :doc:`/gallery/widgets/range_slider` for an example of using
13+
a ``RangeSlider`` to define a range of values.
14+
"""
15+
import inspect
16+
17+
import numpy as np
18+
import matplotlib.pyplot as plt
19+
from matplotlib.widgets import Slider, Button
20+
21+
from data_prototype.wrappers import LineWrapper
22+
from data_prototype.containers import FuncContainer
23+
24+
25+
class SliderContainer(FuncContainer):
26+
def __init__(self, xfuncs, /, **sliders):
27+
self._sliders = sliders
28+
for slider in sliders.values():
29+
slider.on_changed(
30+
lambda x, sld=slider: sld.ax.figure.canvas.draw_idle(),
31+
)
32+
33+
def get_needed_keys(f, offset=1):
34+
return tuple(inspect.signature(f).parameters)[offset:]
35+
36+
super().__init__(
37+
{
38+
k: (
39+
s,
40+
# this line binds the correct sliders to the functions
41+
# and makes lambdas that match the API FuncContainer needs
42+
lambda x, keys=get_needed_keys(f), f=f: f(x, *(sliders[k].val for k in keys)),
43+
)
44+
for k, (s, f) in xfuncs.items()
45+
},
46+
)
47+
48+
def _query_hash(self, coord_transform, size):
49+
key = super()._query_hash(coord_transform, size)
50+
# inject the slider values into the hashing logic
51+
return hash((key, tuple(s.val for s in self._sliders.values())))
52+
53+
54+
# Define initial parameters
55+
init_amplitude = 5
56+
init_frequency = 3
57+
58+
# Create the figure and the line that we will manipulate
59+
fig, ax = plt.subplots()
60+
ax.set_xlim(0, 1)
61+
ax.set_ylim(-7, 7)
62+
63+
ax.set_xlabel("Time [s]")
64+
65+
# adjust the main plot to make room for the sliders
66+
fig.subplots_adjust(left=0.25, bottom=0.25, right=0.75)
67+
68+
# Make a horizontal slider to control the frequency.
69+
axfreq = fig.add_axes([0.25, 0.1, 0.65, 0.03])
70+
freq_slider = Slider(
71+
ax=axfreq,
72+
label="Frequency [Hz]",
73+
valmin=0.1,
74+
valmax=30,
75+
valinit=init_frequency,
76+
)
77+
78+
# Make a vertically oriented slider to control the amplitude
79+
axamp = fig.add_axes([0.1, 0.25, 0.0225, 0.63])
80+
amp_slider = Slider(
81+
ax=axamp,
82+
label="Amplitude",
83+
valmin=0,
84+
valmax=10,
85+
valinit=init_amplitude,
86+
orientation="vertical",
87+
)
88+
89+
# Make a vertically oriented slider to control the phase
90+
axphase = fig.add_axes([0.85, 0.25, 0.0225, 0.63])
91+
phase_slider = Slider(
92+
ax=axphase,
93+
label="Phase [rad]",
94+
valmin=-2 * np.pi,
95+
valmax=2 * np.pi,
96+
valinit=0,
97+
orientation="vertical",
98+
)
99+
100+
# pick a cyclic color map
101+
cmap = plt.get_cmap("twilight")
102+
103+
# set up the data container
104+
fc = SliderContainer(
105+
{
106+
# the x data does not need the sliders values
107+
"x": (("N",), lambda t: t),
108+
"y": (
109+
("N",),
110+
# the y data needs all three sliders
111+
lambda t, amplitude, frequency, phase: amplitude * np.sin(2 * np.pi * frequency * t + phase),
112+
),
113+
# the color data has to take the x (because reasons), but just
114+
# needs the phase
115+
"color": ((1,), lambda t, phase: phase),
116+
},
117+
# bind the sliders to the data container
118+
amplitude=amp_slider,
119+
frequency=freq_slider,
120+
phase=phase_slider,
121+
)
122+
lw = LineWrapper(
123+
fc,
124+
# color map phase (scaled to 2pi and wrapped to [0, 1])
125+
{"color": lambda color: cmap((color / (2 * np.pi)) % 1)},
126+
lw=5,
127+
)
128+
ax.add_artist(lw)
129+
130+
131+
# Create a `matplotlib.widgets.Button` to reset the sliders to initial values.
132+
resetax = fig.add_axes([0.8, 0.025, 0.1, 0.04])
133+
button = Button(resetax, "Reset", hovercolor="0.975")
134+
button.on_clicked(lambda event: [sld.reset() for sld in (freq_slider, amp_slider, phase_slider)])
135+
136+
plt.show()

0 commit comments

Comments
 (0)