Skip to content

Commit 3266299

Browse files
Fix error downloading model files.
1 parent d01236c commit 3266299

File tree

3 files changed

+36
-19
lines changed

3 files changed

+36
-19
lines changed

app/config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
resize_shape: [256, 256]
22
targets: ['Fire_Images']
33
model_paths: ["fire_classifier_params.h5", "fire_classifier_layers.json"]
4-
model_file_urls: ["https://github.com/UPstartDeveloper/Fire-Detection-API/releases/download/v0.0.2/fire_classifier_params.h5", "https://github.com/UPstartDeveloper/Fire-Detection-API/releases/download/v0.0.2/fire_classifier_layers.json"]
4+
base_model_url: "https://github.com/UPstartDeveloper/Fire-Detection-API/releases/download/v0.0.2"
55
model_sha256: "26f32ae0666bbb83e11968935db0ec2ab06623d1"

fire_classifier/predictor.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,43 +3,46 @@
33
import numpy as np
44
from typing import List
55
import yaml
6-
from tensorflow.keras.models import load_model
76

87
from fire_classifier.preprocessing_utilities import (
98
read_img_from_path,
109
read_from_file,
1110
)
12-
from fire_classifier.utils import download_model
11+
from fire_classifier.utils import download_model, load_model
1312

1413

1514
class ImagePredictor:
1615
def __init__(
17-
self, model_paths: List[str], resize_size: List[int], targets: List[str]
16+
self, model_paths: List[str], resize_size: List[int],
17+
base_download_url: str, targets: List[str]
1818
):
1919
self.model_paths = model_paths
20-
self.model = load_model(self.model_paths)
2120
self.resize_size = resize_size
21+
self.model = load_model(base_download_url, self.model_paths)
2222
self.targets = targets
2323

2424
@classmethod
2525
def init_from_config_path(cls, config_path):
26+
# load details for setting up the model
2627
with open(config_path, "r") as f:
2728
config = yaml.load(f, yaml.SafeLoader)
29+
# use the config data, to integrate the model into the new object
2830
predictor = cls(
2931
model_paths=config["model_paths"],
3032
resize_size=config["resize_shape"],
33+
base_download_url=config["base_model_url"],
3134
targets=config["targets"],
3235
)
3336
return predictor
3437

3538
@classmethod
3639
def init_from_config_url(cls, config_path):
37-
with open(config_path, "r") as f:
38-
config = yaml.load(f, yaml.SafeLoader)
40+
# with open(config_path, "r") as f:
41+
# config = yaml.load(f, yaml.SafeLoader)
3942

40-
download_model(
41-
config["model_file_urls"], config["model_paths"], config["model_sha256"]
42-
)
43+
# download_model(
44+
# config["model_file_urls"], config["model_paths"], config["model_sha256"]
45+
# )
4346

4447
return cls.init_from_config_path(config_path)
4548

fire_classifier/utils.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,33 @@ def get_hash(filename):
1212
return sha256_hash.hexdigest()
1313

1414

15-
def download_model(urls, file_paths, file_sha256):
15+
def download_model(url, file_paths, file_sha256=None):
1616
params_file, layers_file = file_paths
17-
params_url, layers_url = urls
18-
if (
19-
os.path.exists(params_file)
20-
and os.path.exists(layers_file)
21-
and get_hash(layers_file) == file_sha256
17+
params_url, layers_url = (
18+
f"{url}/{params_file}",
19+
f"{url}/{layers_file}"
20+
)
21+
if (os.path.exists(params_file) and os.path.exists(layers_file)
22+
# and get_hash(layers_file) == file_sha256
2223
):
2324
print("File already exists")
24-
else:
25+
else: # download the model
2526
keras.utils.get_file(
26-
origin=params_url, fname=params_file, cache_subdir=""
27+
origin=layers_url, fname=layers_file,
28+
cache_dir='.', cache_subdir="./model"
2729
)
2830
keras.utils.get_file(
29-
origin=layers_url, fname=layers_file, cache_subdir=""
31+
origin=params_url, fname=params_file,
32+
cache_dir='.', cache_subdir="./model"
3033
)
34+
35+
def load_model(url, file_paths):
36+
'''Model reconstruction using H5 + JSON'''
37+
# First download the model, if needed
38+
download_model(url, file_paths)
39+
params_file, layers_file = file_paths
40+
# Model reconstruction
41+
with open(f"./model/{layers_file}") as f:
42+
model = keras.models.model_from_json(f.read())
43+
model.load_weights(f"./model/{params_file}")
44+
return model

0 commit comments

Comments
 (0)