@@ -49,6 +49,19 @@ class RuntimeFailedConfig(ErrorConfig):
4949 pass
5050
5151
52+ class NpEncoder (json .JSONEncoder ):
53+ """ Class we use for dumping Numpy objects to JSON """
54+
55+ def default (self , obj ):
56+ if isinstance (obj , np .integer ):
57+ return int (obj )
58+ if isinstance (obj , np .floating ):
59+ return float (obj )
60+ if isinstance (obj , np .ndarray ):
61+ return obj .tolist ()
62+ return super (NpEncoder , self ).default (obj )
63+
64+
5265class TorchPlaceHolder ():
5366
5467 def __init__ (self ):
@@ -725,18 +738,6 @@ def compile_restrictions(restrictions: list, tune_params: dict):
725738 return func
726739
727740
728- class NpEncoder (json .JSONEncoder ):
729-
730- def default (self , obj ):
731- if isinstance (obj , np .integer ):
732- return int (obj )
733- if isinstance (obj , np .floating ):
734- return float (obj )
735- if isinstance (obj , np .ndarray ):
736- return obj .tolist ()
737- return super (NpEncoder , self ).default (obj )
738-
739-
740741def process_cache (cache , kernel_options , tuning_options , runner ):
741742 """cache file for storing tuned configurations
742743
@@ -871,16 +872,6 @@ def close_cache(cache):
871872def store_cache (key , params , tuning_options ):
872873 """ stores a new entry (key, params) to the cachefile """
873874
874- # create converter for dumping numpy objects to JSON
875- def JSONconverter (obj ):
876- if isinstance (obj , np .integer ):
877- return int (obj )
878- if isinstance (obj , np .floating ):
879- return float (obj )
880- if isinstance (obj , np .ndarray ):
881- return obj .tolist ()
882- return obj .__str__ ()
883-
884875 #logging.debug('store_cache called, cache=%s, cachefile=%s' % (tuning_options.cache, tuning_options.cachefile))
885876 if isinstance (tuning_options .cache , dict ):
886877 if not key in tuning_options .cache :
@@ -894,7 +885,7 @@ def JSONconverter(obj):
894885
895886 if tuning_options .cachefile :
896887 with open (tuning_options .cachefile , "a" ) as cachefile :
897- cachefile .write ("\n " + json .dumps ({ key : output_params }, default = JSONconverter )[1 :- 1 ] + "," )
888+ cachefile .write ("\n " + json .dumps ({ key : output_params }, cls = NpEncoder )[1 :- 1 ] + "," )
898889
899890
900891def dump_cache (obj : str , tuning_options ):
0 commit comments