44
55
66def get_hash (filename ):
7+ """
8+ Computes the SHA256 hash of a given file.
9+
10+ This can then be used to ensure the model file(s) downloaded
11+ in this codebase are not corrupted.
12+
13+ Args:
14+ filename(str): the name of the file
15+
16+ Returns:
17+ bytes-like object
18+ """
719 sha256_hash = hashlib .sha256 ()
820 with open (filename , "rb" ) as f :
921 for byte_block in iter (lambda : f .read (4096 ), b"" ):
@@ -13,32 +25,77 @@ def get_hash(filename):
1325
1426
1527def download_model (url , file_paths , file_sha256 = None ):
16- params_file , layers_file = file_paths
17- params_url , layers_url = (
18- f"{ url } /{ params_file } " ,
19- f"{ url } /{ layers_file } "
20- )
21- if (os .path .exists (params_file ) and os .path .exists (layers_file )
22- # and get_hash(layers_file) == file_sha256
23- ):
24- print ("File already exists" )
25- else : # download the model
26- keras .utils .get_file (
27- origin = layers_url , fname = layers_file ,
28- cache_dir = '.' , cache_subdir = "./model"
29- )
30- keras .utils .get_file (
31- origin = params_url , fname = params_file ,
32- cache_dir = '.' , cache_subdir = "./model"
33- )
34-
35- def load_model (url , file_paths ):
36- '''Model reconstruction using H5 + JSON'''
28+ """
29+ Downloads the model files in memory.
30+
31+ This will first check if the files are already present,
32+ and not corrupted, before downloading from the address
33+ specified in config.yaml.
34+
35+ Args:
36+ url(str): the base url where the files are located
37+ file_paths(List[str]): collection of all the files needed to
38+ eventually load the model
39+ file_sha256(str): the supposed hash of one of the files
40+ we need to download. Checked against the
41+ one we may already have in the codebase.
42+
43+ Returns:
44+ None
45+ """
46+ # Download only the model files that are needed
47+ for model_file_path in file_paths :
48+ if os .path .exists (model_file_path ):
49+ if get_hash (model_file_path ) == file_sha256 :
50+ print (f"File already exists: { model_file_path } " )
51+ else : # need to download the model
52+ model_file_url = f"{ url } /{ model_file_path } "
53+ keras .utils .get_file (
54+ origin = model_file_url , fname = model_file_path ,
55+ cache_dir = "." , cache_subdir = "./model"
56+ )
57+
58+
59+ def load_model (url , file_paths , file_sha256 = None , format = 'composite' ):
60+ """
61+ Model reconstruction.
62+
63+ This will first load the model in memory using the given files
64+ and save format
65+
66+ Args:
67+ url(str): the base url where the files are located
68+ file_paths(List[str]): collection of all the files needed to
69+ eventually load the model
70+ file_sha256(str): the supposed hash of one of the files
71+ we need to download. Checked against the
72+ one we may already have in the codebase.
73+ format(str): currently this only supports 'composite'
74+ (which is for when the model is saved using a H5 + JSON)
75+ or 'h5' as the save format of the model.
76+
77+ Returns:
78+ keras.Model object
79+ """
80+
81+ def _model_from_composite_format ():
82+ '''Specific to using H5 + JSON as the save format'''
83+ params_file , layers_file = file_paths
84+ # load the model in memory
85+ with open (f"./model/{ layers_file } " ) as f :
86+ model = keras .models .model_from_json (f .read ()) # build the layers
87+ model .load_weights (f"./model/{ params_file } " ) # load weights + biases
88+ return model
89+
90+ def _model_from_h5 ():
91+ '''Specific to using a single Hadoop(H5) file'''
92+ params_file = file_paths [0 ]
93+ return keras .models .load_model (params_file )
94+
3795 # First download the model, if needed
38- download_model (url , file_paths )
39- params_file , layers_file = file_paths
40- # Model reconstruction
41- with open (f"./model/{ layers_file } " ) as f :
42- model = keras .models .model_from_json (f .read ())
43- model .load_weights (f"./model/{ params_file } " )
44- return model
96+ download_model (url , file_paths , file_sha256 )
97+ # load the model in memory
98+ if format == 'composite' :
99+ return _model_from_composite_format ()
100+ else : # assuming a single H5
101+ return _model_from_h5 ()
0 commit comments