Skip to content

Commit 6a92562

Browse files
authored
Merge pull request #919 from vloncar/fetch_example
Fix fetching models from example-models repo
2 parents 2cd8333 + 23e73ef commit 6a92562

File tree

2 files changed

+49
-18
lines changed

2 files changed

+49
-18
lines changed

hls4ml/utils/example_models.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,18 @@
66

77
from .config import create_config
88

9+
ORGANIZATION = 'fastmachinelearning'
10+
BRANCH = 'master'
11+
912

1013
def _load_data_config_avai(model_name):
1114
"""
1215
Check data and configuration availability for each model from this file:
1316
14-
https://github.com/hls-fpga-machine-learning/example-models/blob/master/available_data_config.json
17+
https://github.com/fastmachinelearning/example-models/blob/master/available_data_config.json
1518
"""
1619

17-
link_to_list = (
18-
'https://raw.githubusercontent.com/hls-fpga-machine-learning/example-models/master/available_data_config.json'
19-
)
20+
link_to_list = f'https://raw.githubusercontent.com/{ORGANIZATION}/example-models/{BRANCH}/available_data_config.json'
2021

2122
temp_file, _ = urlretrieve(link_to_list)
2223

@@ -73,12 +74,8 @@ def _load_example_data(model_name):
7374
input_file_name = filtered_name + "_input.dat"
7475
output_file_name = filtered_name + "_output.dat"
7576

76-
link_to_input = (
77-
'https://raw.githubusercontent.com/hls-fpga-machine-learning/example-models/master/data/' + input_file_name
78-
)
79-
link_to_output = (
80-
'https://raw.githubusercontent.com/hls-fpga-machine-learning/example-models/master/data/' + output_file_name
81-
)
77+
link_to_input = f'https://raw.githubusercontent.com/{ORGANIZATION}/example-models/{BRANCH}/data/' + input_file_name
78+
link_to_output = f'https://raw.githubusercontent.com/{ORGANIZATION}/example-models/{BRANCH}/data/' + output_file_name
8279

8380
urlretrieve(link_to_input, input_file_name)
8481
urlretrieve(link_to_output, output_file_name)
@@ -91,9 +88,7 @@ def _load_example_config(model_name):
9188

9289
config_name = filtered_name + "_config.yml"
9390

94-
link_to_config = (
95-
'https://raw.githubusercontent.com/hls-fpga-machine-learning/example-models/master/config-files/' + config_name
96-
)
91+
link_to_config = f'https://raw.githubusercontent.com/{ORGANIZATION}/example-models/{BRANCH}/config-files/' + config_name
9792

9893
# Load the configuration as dictionary from file
9994
urlretrieve(link_to_config, config_name)
@@ -110,7 +105,7 @@ def fetch_example_model(model_name, backend='Vivado'):
110105
Download an example model (and example data & configuration if available) from github repo to working directory,
111106
and return the corresponding configuration:
112107
113-
https://github.com/hls-fpga-machine-learning/example-models
108+
https://github.com/fastmachinelearning/example-models
114109
115110
Use fetch_example_list() to see all the available models.
116111
@@ -122,15 +117,18 @@ def fetch_example_model(model_name, backend='Vivado'):
122117
dict: Dictionary that stores the configuration to the model
123118
"""
124119

125-
# Initilize the download link and model type
126-
download_link = 'https://raw.githubusercontent.com/hls-fpga-machine-learning/example-models/master/'
120+
# Initialize the download link and model type
121+
download_link = f'https://raw.githubusercontent.com/{ORGANIZATION}/example-models/{BRANCH}/'
127122
model_type = None
128123
model_config = None
129124

130125
# Check for model's type to update link
131126
if '.json' in model_name:
132127
model_type = 'keras'
133128
model_config = 'KerasJson'
129+
elif '.h5' in model_name:
130+
model_type = 'keras'
131+
model_config = 'KerasH5'
134132
elif '.pt' in model_name:
135133
model_type = 'pytorch'
136134
model_config = 'PytorchModel'
@@ -158,11 +156,12 @@ def fetch_example_model(model_name, backend='Vivado'):
158156

159157
if _config_is_available(model_name):
160158
config = _load_example_config(model_name)
159+
config[model_config] = model_name # Ensure that paths are correct
161160
else:
162161
config = _create_default_config(model_name, model_config, backend)
163162

164163
# If the model is a keras model then have to download its weight file as well
165-
if model_type == 'keras':
164+
if model_type == 'keras' and '.json' in model_name:
166165
model_weight_name = model_name[:-5] + "_weights.h5"
167166

168167
download_link_weight = download_link + model_type + '/' + model_weight_name
@@ -174,7 +173,7 @@ def fetch_example_model(model_name, backend='Vivado'):
174173

175174

176175
def fetch_example_list():
177-
link_to_list = 'https://raw.githubusercontent.com/hls-fpga-machine-learning/example-models/master/available_models.json'
176+
link_to_list = f'https://raw.githubusercontent.com/{ORGANIZATION}/example-models/{BRANCH}/available_models.json'
178177

179178
temp_file, _ = urlretrieve(link_to_list)
180179

test/pytest/test_fetch_example.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import ast
2+
import io
3+
from contextlib import redirect_stdout
4+
from pathlib import Path
5+
6+
import pytest
7+
8+
import hls4ml
9+
10+
test_root_path = Path(__file__).parent
11+
12+
13+
@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus'])
14+
def test_fetch_example_utils(backend):
15+
f = io.StringIO()
16+
with redirect_stdout(f):
17+
hls4ml.utils.fetch_example_list()
18+
out = f.getvalue()
19+
20+
model_list = ast.literal_eval(out) # Check if we indeed got a dictionary back
21+
22+
assert 'qkeras_mnist_cnn.json' in model_list['keras']
23+
24+
# This model has an example config that is also downloaded. Stored configurations don't set "Backend" value.
25+
config = hls4ml.utils.fetch_example_model('qkeras_mnist_cnn.json', backend=backend)
26+
config['KerasJson'] = 'qkeras_mnist_cnn.json'
27+
config['KerasH5']
28+
config['Backend'] = backend
29+
config['OutputDir'] = str(test_root_path / f'hls4mlprj_fetch_example_{backend}')
30+
31+
hls_model = hls4ml.converters.keras_to_hls(config)
32+
hls_model.compile() # For now, it is enough if it compiles, we're only testing downloading works as expected

0 commit comments

Comments
 (0)