Skip to content

Commit cae377f

Browse files
committed
Add PMTObserver
1 parent 7c6f709 commit cae377f

File tree

1 file changed

+83
-0
lines changed

1 file changed

+83
-0
lines changed

kernel_tuner/observers/pmt.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import numpy as np
2+
3+
from kernel_tuner.observers.observer import BenchmarkObserver
4+
5+
# check if pmt is installed
6+
try:
7+
import pmt
8+
except ImportError:
9+
pmt = None
10+
11+
12+
class PMTObserver(BenchmarkObserver):
13+
"""Observer that uses the PMT library to measure power
14+
15+
:param observables: One of:
16+
- A string specifying a single power meter to use
17+
- A list of string, specifying one or more power meters to use
18+
- A dictionary, specifying one or more power meters to use,
19+
including the device identifier. For arduino this should be for
20+
instance "/dev/ttyACM0". For nvml, it should correspond to the GPU
21+
id (e.g. '0', or '1'). For some sensors (such as rapl) the device
22+
id is not used, it should be 'None' in those cases.
23+
This observer will report "<platform>_energy>" and "<platform>_power" for
24+
all specified platforms.
25+
:type observables: string,list/dictionary
26+
27+
"""
28+
29+
def __init__(self, observable=None):
30+
if not pmt:
31+
raise ImportError("could not import pmt")
32+
33+
# User specifices a dictonary of platforms and corresponding device
34+
if type(observable) is dict:
35+
pass
36+
elif type(observable) is list:
37+
# user specifies a list of platforms as observable
38+
observable = dict([(obs, 0) for obs in observable])
39+
else:
40+
# User specifices a string (single platform) as observable
41+
observable = {observable: None}
42+
43+
print(observable)
44+
45+
supported = ["arduino", "jetson", "likwid", "nvml", "rapl", "rocm", "xilinx"]
46+
for obs in observable.keys():
47+
if not obs in supported:
48+
raise ValueError(f"Observable {obs} not in supported: {supported}")
49+
50+
self.pms = [pmt.get_pmt(obs[0], obs[1]) for obs in observable.items()]
51+
self.pm_names = list(observable.keys())
52+
53+
self.begin_states = [None] * len(self.pms)
54+
self.initialize_results(self.pm_names)
55+
56+
def initialize_results(self, pm_names):
57+
self.results = dict()
58+
for pm_name in pm_names:
59+
energy_result_name = f"{pm_name}_energy"
60+
power_result_name = f"{pm_name}_power"
61+
self.results[energy_result_name] = []
62+
self.results[power_result_name] = []
63+
64+
def after_start(self):
65+
self.begin_states = [pm.read() for pm in self.pms]
66+
67+
def after_finish(self):
68+
end_states = [pm.read() for pm in self.pms]
69+
for i in range(len(self.pms)):
70+
begin_state = self.begin_states[i]
71+
end_state = end_states[i]
72+
measured_energy = pmt.joules(begin_state, end_state)
73+
measured_power = pmt.watts(begin_state, end_state)
74+
pm_name = self.pm_names[i]
75+
energy_result_name = f"{pm_name}_energy"
76+
power_result_name = f"{pm_name}_power"
77+
self.results[energy_result_name].append(measured_energy)
78+
self.results[power_result_name].append(measured_power)
79+
80+
def get_results(self):
81+
averages = {key: np.average(values) for key, values in self.results.items()}
82+
self.initialize_results(self.pm_names)
83+
return averages

0 commit comments

Comments
 (0)