Skip to content

Commit 36825f1

Browse files
committed
Merge with master: changed error handling
2 parents 5882b1a + ab2d870 commit 36825f1

File tree

2 files changed

+54
-56
lines changed

2 files changed

+54
-56
lines changed

kernel_tuner/file_utils.py

Lines changed: 40 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def store_output_file(output_filename: str, results, tune_params, objective="tim
134134
version, _ = output_file_schema("results")
135135
output_json = dict(results=output_data, schema_version=version)
136136
with open(output_filenamepath, "w+") as fh:
137-
json.dump(output_json, fh)
137+
json.dump(output_json, fh, cls=util.NpEncoder)
138138

139139

140140
def get_dependencies(package="kernel_tuner"):
@@ -185,41 +185,48 @@ def store_metadata_file(metadata_filename: str):
185185
metadata_filenamepath = Path(filename_ensure_json_extension(metadata_filename))
186186
make_filenamepath(metadata_filenamepath)
187187
metadata = {}
188+
supported_operating_systems = ["linux", "win32", "darwin"]
188189

189-
# differentiate between OSes, possible values: https://docs.python.org/3/library/sys.html#sys.platform
190-
if platform == "linux":
191-
os_string = "Linux"
192-
hardware_description_out = subprocess.run(["lshw", "-json"], capture_output=True)
193-
elif platform == "win32":
194-
os_string = "Windows"
195-
raise NotImplementedError(f"Hardware specification not yet implemented for Windows")
196-
elif platform == "darwin":
197-
os_string = "Mac"
198-
hardware_description_out = subprocess.run(
199-
[
200-
"system_profiler",
201-
"-json",
202-
"-detailLevel",
203-
"mini",
204-
"SPSoftwareDataType",
205-
"SPHardwareDataType",
206-
"SPiBridgeDataType",
207-
"SPPCIDataType",
208-
"SPMemoryDataType",
209-
"SPNVMeDataType",
210-
],
211-
capture_output=True,
212-
)
213-
else:
190+
if all(platform != supported for supported in supported_operating_systems):
214191
raise ValueError(f"Platform {platform} not supported for metadata collection")
215192

216-
# process the hardware description output
217-
hardware_description_string = hardware_description_out.stdout.decode("utf-8").strip()
218-
if hardware_description_string[0] == "{" and hardware_description_string[-1] == "}":
219-
# sometimes lshw outputs a list of length 1, sometimes just as a dict, schema wants a list
220-
hardware_description_string = "[" + hardware_description_string + "]"
221-
metadata["hardware"] = dict(hardware_description=json.loads(hardware_description_string))
222-
metadata["operating_system"] = os_string
193+
try:
194+
# differentiate between OSes, possible values: https://docs.python.org/3/library/sys.html#sys.platform
195+
if platform == "linux":
196+
os_string = "Linux"
197+
hardware_description_out = subprocess.run(["lshw", "-json"], capture_output=True)
198+
elif platform == "win32":
199+
os_string = "Windows"
200+
raise NotImplementedError("Hardware specification not yet implemented for Windows")
201+
elif platform == "darwin":
202+
os_string = "Mac"
203+
hardware_description_out = subprocess.run(
204+
[
205+
"system_profiler",
206+
"-json",
207+
"-detailLevel",
208+
"mini",
209+
"SPSoftwareDataType",
210+
"SPHardwareDataType",
211+
"SPiBridgeDataType",
212+
"SPPCIDataType",
213+
"SPMemoryDataType",
214+
"SPNVMeDataType",
215+
],
216+
capture_output=True,
217+
)
218+
else:
219+
raise ValueError("This code is supposed to be unreachable, the supported platform check has failed")
220+
221+
# process the hardware description output
222+
hardware_description_string = hardware_description_out.stdout.decode("utf-8").strip()
223+
if hardware_description_string[0] == "{" and hardware_description_string[-1] == "}":
224+
# sometimes lshw outputs a list of length 1, sometimes just as a dict, schema wants a list
225+
hardware_description_string = "[" + hardware_description_string + "]"
226+
metadata["hardware"] = dict(hardware_description=json.loads(hardware_description_string))
227+
metadata["operating_system"] = os_string
228+
except:
229+
metadata["hardware"] = "error retrieving hardware description"
223230

224231
# attempts to use nvidia-smi or rocm-smi if present
225232
device_query = {}

kernel_tuner/util.py

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,19 @@ class RuntimeFailedConfig(ErrorConfig):
4949
pass
5050

5151

52+
class NpEncoder(json.JSONEncoder):
53+
""" Class we use for dumping Numpy objects to JSON """
54+
55+
def default(self, obj):
56+
if isinstance(obj, np.integer):
57+
return int(obj)
58+
if isinstance(obj, np.floating):
59+
return float(obj)
60+
if isinstance(obj, np.ndarray):
61+
return obj.tolist()
62+
return super(NpEncoder, self).default(obj)
63+
64+
5265
class TorchPlaceHolder():
5366

5467
def __init__(self):
@@ -725,18 +738,6 @@ def compile_restrictions(restrictions: list, tune_params: dict):
725738
return func
726739

727740

728-
class NpEncoder(json.JSONEncoder):
729-
730-
def default(self, obj):
731-
if isinstance(obj, np.integer):
732-
return int(obj)
733-
if isinstance(obj, np.floating):
734-
return float(obj)
735-
if isinstance(obj, np.ndarray):
736-
return obj.tolist()
737-
return super(NpEncoder, self).default(obj)
738-
739-
740741
def process_cache(cache, kernel_options, tuning_options, runner):
741742
"""cache file for storing tuned configurations
742743
@@ -871,16 +872,6 @@ def close_cache(cache):
871872
def store_cache(key, params, tuning_options):
872873
""" stores a new entry (key, params) to the cachefile """
873874

874-
# create converter for dumping numpy objects to JSON
875-
def JSONconverter(obj):
876-
if isinstance(obj, np.integer):
877-
return int(obj)
878-
if isinstance(obj, np.floating):
879-
return float(obj)
880-
if isinstance(obj, np.ndarray):
881-
return obj.tolist()
882-
return obj.__str__()
883-
884875
#logging.debug('store_cache called, cache=%s, cachefile=%s' % (tuning_options.cache, tuning_options.cachefile))
885876
if isinstance(tuning_options.cache, dict):
886877
if not key in tuning_options.cache:
@@ -894,7 +885,7 @@ def JSONconverter(obj):
894885

895886
if tuning_options.cachefile:
896887
with open(tuning_options.cachefile, "a") as cachefile:
897-
cachefile.write("\n" + json.dumps({ key: output_params }, default=JSONconverter)[1:-1] + ",")
888+
cachefile.write("\n" + json.dumps({ key: output_params }, cls=NpEncoder)[1:-1] + ",")
898889

899890

900891
def dump_cache(obj: str, tuning_options):

0 commit comments

Comments
 (0)