Skip to content

Commit ab2d870

Browse files
Merge pull request #200 from KernelTuner/fix-json-errors
Fix json errors
2 parents 4e0a80e + 19c3435 commit ab2d870

File tree

2 files changed

+26
-31
lines changed

2 files changed

+26
-31
lines changed

kernel_tuner/file_utils.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def store_output_file(output_filename, results, tune_params, objective="time"):
131131
version, _ = output_file_schema("results")
132132
output_json = dict(results=output_data, schema_version=version)
133133
with open(output_filename, 'w+') as fh:
134-
json.dump(output_json, fh)
134+
json.dump(output_json, fh, cls=util.NpEncoder)
135135

136136

137137
def get_dependencies(package='kernel_tuner'):
@@ -184,15 +184,19 @@ def store_metadata_file(metadata_filename):
184184
metadata_filename = filename_ensure_json_extension(metadata_filename)
185185
metadata = {}
186186

187-
# lshw only works on Linux, this intentionally raises a FileNotFoundError when ran on systems that do not have it
188-
lshw_out = subprocess.run(["lshw", "-json"], capture_output=True)
187+
# lshw only works on Linux
188+
try:
189+
lshw_out = subprocess.run(["lshw", "-json"], capture_output=True)
189190

190-
# sometimes lshw outputs a list of length 1, sometimes just as a dict, schema wants a list
191-
lshw_string = lshw_out.stdout.decode('utf-8').strip()
192-
if lshw_string[0] == '{' and lshw_string[-1] == '}':
193-
lshw_string = '[' + lshw_string + ']'
191+
# sometimes lshw outputs a list of length 1, sometimes just as a dict, schema wants a list
192+
lshw_string = lshw_out.stdout.decode('utf-8').strip()
193+
if lshw_string[0] == '{' and lshw_string[-1] == '}':
194+
lshw_string = '[' + lshw_string + ']'
195+
hardware_desc = dict(lshw=json.loads(lshw_string))
196+
except:
197+
hardware_desc = dict(lshw=["lshw error"])
194198

195-
metadata["hardware"] = dict(lshw=json.loads(lshw_string))
199+
metadata["hardware"] = hardware_desc
196200

197201
# attempts to use nvidia-smi or rocm-smi if present
198202
device_query = {}

kernel_tuner/util.py

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,19 @@ class RuntimeFailedConfig(ErrorConfig):
4949
pass
5050

5151

52+
class NpEncoder(json.JSONEncoder):
53+
""" Class we use for dumping Numpy objects to JSON """
54+
55+
def default(self, obj):
56+
if isinstance(obj, np.integer):
57+
return int(obj)
58+
if isinstance(obj, np.floating):
59+
return float(obj)
60+
if isinstance(obj, np.ndarray):
61+
return obj.tolist()
62+
return super(NpEncoder, self).default(obj)
63+
64+
5265
class TorchPlaceHolder():
5366

5467
def __init__(self):
@@ -725,18 +738,6 @@ def compile_restrictions(restrictions: list, tune_params: dict):
725738
return func
726739

727740

728-
class NpEncoder(json.JSONEncoder):
729-
730-
def default(self, obj):
731-
if isinstance(obj, np.integer):
732-
return int(obj)
733-
if isinstance(obj, np.floating):
734-
return float(obj)
735-
if isinstance(obj, np.ndarray):
736-
return obj.tolist()
737-
return super(NpEncoder, self).default(obj)
738-
739-
740741
def process_cache(cache, kernel_options, tuning_options, runner):
741742
"""cache file for storing tuned configurations
742743
@@ -871,16 +872,6 @@ def close_cache(cache):
871872
def store_cache(key, params, tuning_options):
872873
""" stores a new entry (key, params) to the cachefile """
873874

874-
# create converter for dumping numpy objects to JSON
875-
def JSONconverter(obj):
876-
if isinstance(obj, np.integer):
877-
return int(obj)
878-
if isinstance(obj, np.floating):
879-
return float(obj)
880-
if isinstance(obj, np.ndarray):
881-
return obj.tolist()
882-
return obj.__str__()
883-
884875
#logging.debug('store_cache called, cache=%s, cachefile=%s' % (tuning_options.cache, tuning_options.cachefile))
885876
if isinstance(tuning_options.cache, dict):
886877
if not key in tuning_options.cache:
@@ -894,7 +885,7 @@ def JSONconverter(obj):
894885

895886
if tuning_options.cachefile:
896887
with open(tuning_options.cachefile, "a") as cachefile:
897-
cachefile.write("\n" + json.dumps({ key: output_params }, default=JSONconverter)[1:-1] + ",")
888+
cachefile.write("\n" + json.dumps({ key: output_params }, cls=NpEncoder)[1:-1] + ",")
898889

899890

900891
def dump_cache(obj: str, tuning_options):

0 commit comments

Comments
 (0)