Skip to content

Commit bbc535e

Browse files
Refactor classifier utils.
1 parent d91699a commit bbc535e

File tree

8 files changed

+135
-121
lines changed

8 files changed

+135
-121
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ __pycache__/
1111
.vscode
1212

1313
# Keras Models
14-
model/*
14+
*model/*
1515

1616
# Distribution / packaging
1717
.Python

app/main.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
from fastapi import FastAPI, File, UploadFile
22
from fire_classifier.predictor import ImagePredictor
3-
from fire_classifier.util.api import API
3+
from app.settings import API_SETTINGS
44

55
# A: init API
66
app = FastAPI(
7-
title=API["title"],
8-
description=API["description"],
9-
version=API["version"],
10-
openapi_tags=API["endpoints"],
7+
title=API_SETTINGS["title"],
8+
description=API_SETTINGS["description"],
9+
version=API_SETTINGS["version"],
10+
openapi_tags=API_SETTINGS["openapi_tags"],
1111
)
1212

1313
# B: init ML inference object, and the routes
14-
predictor_config_path = API["config_path"]
14+
predictor_config_path = API_SETTINGS["predictor_config_path"]
1515
predictor = ImagePredictor.init_from_config_path(predictor_config_path)
1616

1717

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,22 @@
1010
# These are intended shown on the UI, related only to a specific endpoint.
1111
# Note: the value in the "name" field should match what goes in the
1212
# "tags" parameter of the corresponding app route in main.py!!
13-
API_ENDPOINTS = (
13+
API_ENDPOINT_DATA = (
1414
{
1515
"name": "Detect Fire",
1616
"description": "Predicts the possibility that a color image contains fire.",
1717
},
1818
)
1919

2020
# Tells the app how to find the config.yaml (for running ML inference)
21-
BASE_DIR = Path(__file__).resolve().parent.parent.parent
21+
BASE_DIR = Path(__file__).resolve().parent.parent
2222
CONFIG_PATH = os.path.join(BASE_DIR, "app", "config.yaml")
2323

2424
# Wraps all the API metadata as one dictionary
25-
API = {
25+
API_SETTINGS = {
2626
"title": API_TITLE,
2727
"description": API_DESCRIPTION,
2828
"version": API_VERSION,
29-
"endpoints": API_ENDPOINTS,
30-
"config_path": CONFIG_PATH,
29+
"openapi_tags": API_ENDPOINT_DATA,
30+
"predictor_config_path": CONFIG_PATH,
3131
}

fire_classifier/predictor.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,15 @@
44
from typing import Dict
55
import yaml
66

7-
from fire_classifier.preprocessing_utilities import read_from_file
8-
from fire_classifier.utils import load_model
7+
from fire_classifier.util import preprocessing
8+
from fire_classifier.util.model import ModelUtility
99

1010

1111
class ImagePredictor:
1212
def __init__(self, config: Dict[str, int or str]):
1313
self.model_paths = config["model_file_paths"]
1414
self.resize_size = config["resize_shape"]
15-
self.model = load_model(
16-
config["base_model_url"], self.model_paths, config["model_sha256"]
17-
)
15+
self.model = ModelUtility.reconstruct_model(config)
1816
self.targets = config["targets"]
1917

2018
@classmethod
@@ -35,7 +33,7 @@ def predict_from_array(self, arr) -> Dict[str, float]:
3533

3634
def predict_from_file(self, file_object):
3735
"""Converts uploaded image to a NumPy array and classifies it."""
38-
arr = read_from_file(file_object)
36+
arr = preprocessing.read_from_file(file_object)
3937
return self.predict_from_array(arr)
4038

4139

fire_classifier/util/__init__.py

Whitespace-only changes.

fire_classifier/util/model.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
import hashlib
2+
import os
3+
from typing import Dict
4+
from tensorflow import keras
5+
6+
7+
class ModelUtility:
8+
def __init__(self, config: Dict[str, str]):
9+
"""
10+
Instaniates a new object using data needed to load in the model.
11+
12+
Args:
13+
config(dict): contains the following fields of interest:
14+
base_model_url(str): the base url where the files are located
15+
model_file_paths(list): collection of all the files needed to
16+
eventually load the model
17+
model_sha256(str): the supposed hash of one of the files
18+
we need to download. Checked against the
19+
one we may already have in the codebase.
20+
"""
21+
self.url = config["base_model_url"]
22+
self.file_paths = config["model_file_paths"]
23+
self.file_sha256 = None
24+
if config["model_sha256"] is not None:
25+
self.file_sha256 = config["model_sha256"]
26+
27+
@classmethod
28+
def reconstruct_model(cls, config):
29+
'''Make a new instance, and load in the model straightaway.'''
30+
model_utility = cls(config)
31+
# detect save format
32+
save_format = 'composite'
33+
if config["model_file_paths"] and len(config["model_file_paths"]) == 1:
34+
save_format = 'h5'
35+
# load the model
36+
return model_utility.load_model(save_format)
37+
38+
def get_hash(self, filename):
39+
"""
40+
Computes the SHA256 hash of a given file.
41+
42+
This can then be used to ensure the model file(s) downloaded
43+
in this codebase are not corrupted.
44+
45+
Args:
46+
filename(str): the name of the file
47+
48+
Returns:
49+
bytes-like object
50+
"""
51+
sha256_hash = hashlib.sha256()
52+
with open(filename, "rb") as f:
53+
for byte_block in iter(lambda: f.read(4096), b""):
54+
sha256_hash.update(byte_block)
55+
56+
return sha256_hash.hexdigest()
57+
58+
def download_model(self):
59+
"""
60+
Downloads the model files in memory.
61+
62+
This will first check if the files are already present,
63+
and not corrupted, before downloading from the address
64+
specified in config.yaml.
65+
66+
Returns:
67+
None
68+
"""
69+
# Download only the model files that are needed
70+
for model_file_path in self.file_paths:
71+
if os.path.exists(model_file_path):
72+
if self.get_hash(model_file_path) == self.file_sha256:
73+
print(f"File already exists: {model_file_path}")
74+
else: # need to download the model
75+
model_file_url = f"{self.url}/{model_file_path}"
76+
keras.utils.get_file(
77+
origin=model_file_url,
78+
fname=model_file_path,
79+
cache_dir=".",
80+
cache_subdir="./model",
81+
)
82+
83+
def load_model(self, format="composite"):
84+
"""
85+
Model reconstruction.
86+
87+
This will first load the model in memory using the given files
88+
and save format
89+
90+
Args:
91+
format(str): currently this only supports 'composite'
92+
(which is for when the model is saved using a H5 + JSON)
93+
or 'h5' as the save format of the model.
94+
95+
Returns:
96+
keras.Model object
97+
"""
98+
99+
def _model_from_composite_format():
100+
"""Specific to using H5 + JSON as the save format"""
101+
params_file, layers_file = self.file_paths
102+
# load the model in memory
103+
with open(f"./model/{layers_file}") as f:
104+
model = keras.models.model_from_json(f.read()) # build the layers
105+
model.load_weights(f"./model/{params_file}") # load weights + biases
106+
return model
107+
108+
def _model_from_h5():
109+
"""Specific to using a single Hadoop(H5) file"""
110+
params_file = self.file_paths[0]
111+
return keras.models.load_model(params_file)
112+
113+
# First download the model, if needed
114+
self.download_model()
115+
# load the model in memory
116+
if format == "composite":
117+
return _model_from_composite_format()
118+
else: # assuming a single H5
119+
return _model_from_h5()
File renamed without changes.

fire_classifier/utils.py

Lines changed: 0 additions & 103 deletions
This file was deleted.

0 commit comments

Comments
 (0)