Skip to content

Commit 8380a34

Browse files
committed
Fix fetching models from example-models repo
1 parent 2f3ffd4 commit 8380a34

File tree

1 file changed

+16
-18
lines changed

1 file changed

+16
-18
lines changed

hls4ml/utils/example_models.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,18 @@
66

77
from .config import create_config
88

9+
ORGANIZATION = 'fastmachinelearning'
10+
BRANCH = 'master'
11+
912

1013
def _load_data_config_avai(model_name):
1114
"""
1215
Check data and configuration availability for each model from this file:
1316
14-
https://github.com/hls-fpga-machine-learning/example-models/blob/master/available_data_config.json
17+
https://github.com/fastmachinelearning/example-models/blob/master/available_data_config.json
1518
"""
1619

17-
link_to_list = (
18-
'https://raw.githubusercontent.com/hls-fpga-machine-learning/example-models/master/available_data_config.json'
19-
)
20+
link_to_list = f'https://raw.githubusercontent.com/{ORGANIZATION}/example-models/{BRANCH}/available_data_config.json'
2021

2122
temp_file, _ = urlretrieve(link_to_list)
2223

@@ -73,12 +74,8 @@ def _load_example_data(model_name):
7374
input_file_name = filtered_name + "_input.dat"
7475
output_file_name = filtered_name + "_output.dat"
7576

76-
link_to_input = (
77-
'https://raw.githubusercontent.com/hls-fpga-machine-learning/example-models/master/data/' + input_file_name
78-
)
79-
link_to_output = (
80-
'https://raw.githubusercontent.com/hls-fpga-machine-learning/example-models/master/data/' + output_file_name
81-
)
77+
link_to_input = f'https://raw.githubusercontent.com/{ORGANIZATION}/example-models/{BRANCH}/data/' + input_file_name
78+
link_to_output = f'https://raw.githubusercontent.com/{ORGANIZATION}/example-models/{BRANCH}/data/' + output_file_name
8279

8380
urlretrieve(link_to_input, input_file_name)
8481
urlretrieve(link_to_output, output_file_name)
@@ -91,9 +88,7 @@ def _load_example_config(model_name):
9188

9289
config_name = filtered_name + "_config.yml"
9390

94-
link_to_config = (
95-
'https://raw.githubusercontent.com/hls-fpga-machine-learning/example-models/master/config-files/' + config_name
96-
)
91+
link_to_config = f'https://raw.githubusercontent.com/{ORGANIZATION}/example-models/{BRANCH}/config-files/' + config_name
9792

9893
# Load the configuration as dictionary from file
9994
urlretrieve(link_to_config, config_name)
@@ -110,7 +105,7 @@ def fetch_example_model(model_name, backend='Vivado'):
110105
Download an example model (and example data & configuration if available) from github repo to working directory,
111106
and return the corresponding configuration:
112107
113-
https://github.com/hls-fpga-machine-learning/example-models
108+
https://github.com/fastmachinelearning/example-models
114109
115110
Use fetch_example_list() to see all the available models.
116111
@@ -122,15 +117,18 @@ def fetch_example_model(model_name, backend='Vivado'):
122117
dict: Dictionary that stores the configuration to the model
123118
"""
124119

125-
# Initilize the download link and model type
126-
download_link = 'https://raw.githubusercontent.com/hls-fpga-machine-learning/example-models/master/'
120+
# Initialize the download link and model type
121+
download_link = f'https://raw.githubusercontent.com/{ORGANIZATION}/example-models/{BRANCH}/'
127122
model_type = None
128123
model_config = None
129124

130125
# Check for model's type to update link
131126
if '.json' in model_name:
132127
model_type = 'keras'
133128
model_config = 'KerasJson'
129+
elif '.h5' in model_name:
130+
model_type = 'keras'
131+
model_config = 'KerasH5'
134132
elif '.pt' in model_name:
135133
model_type = 'pytorch'
136134
model_config = 'PytorchModel'
@@ -162,7 +160,7 @@ def fetch_example_model(model_name, backend='Vivado'):
162160
config = _create_default_config(model_name, model_config, backend)
163161

164162
# If the model is a keras model then have to download its weight file as well
165-
if model_type == 'keras':
163+
if model_type == 'keras' and '.json' in model_name:
166164
model_weight_name = model_name[:-5] + "_weights.h5"
167165

168166
download_link_weight = download_link + model_type + '/' + model_weight_name
@@ -174,7 +172,7 @@ def fetch_example_model(model_name, backend='Vivado'):
174172

175173

176174
def fetch_example_list():
177-
link_to_list = 'https://raw.githubusercontent.com/hls-fpga-machine-learning/example-models/master/available_models.json'
175+
link_to_list = f'https://raw.githubusercontent.com/{ORGANIZATION}/example-models/{BRANCH}/available_models.json'
178176

179177
temp_file, _ = urlretrieve(link_to_list)
180178

0 commit comments

Comments
 (0)