
Commit c81f7d2

Merge pull request #8 from UPstartDeveloper/code-quality
Improve code documentation.
2 parents 5d0b89a + b37317e commit c81f7d2

7 files changed: +124 additions, −77 deletions


app.yaml

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
+# Use this if you choose to deploy on the GCP App Engine!
 runtime: custom
 env: flex
 service: default

app/config.yaml

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 resize_shape: [256, 256]
 targets: ['Fire_Images']
-model_paths: ["fire_classifier_params.h5", "fire_classifier_layers.json"]
+model_file_paths: ["fire_classifier_params.h5", "fire_classifier_layers.json"]
 base_model_url: "https://github.com/UPstartDeveloper/Fire-Detection-API/releases/download/v0.0.2"
 model_sha256: "26f32ae0666bbb83e11968935db0ec2ab06623d1"
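The renamed key matters because ImagePredictor.__init__ now reads model_file_paths (see fire_classifier/predictor.py below). A minimal standalone sketch of loading these keys with PyYAML, assuming the repo root is the working directory so the ./app/config.yaml path resolves:

import yaml

# Hypothetical check of the config keys the predictor expects;
# the ./app/config.yaml path assumes you run this from the repo root.
with open("./app/config.yaml", "r") as f:
    config = yaml.safe_load(f)

print(config["model_file_paths"])  # ["fire_classifier_params.h5", "fire_classifier_layers.json"]
print(config["resize_shape"])      # [256, 256]
print(config["targets"])           # ['Fire_Images']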

app/main.py

Lines changed: 4 additions & 4 deletions
@@ -1,18 +1,18 @@
 from fastapi import FastAPI, File, UploadFile
-
 from fire_classifier.predictor import ImagePredictor

+# init API
 app = FastAPI(
     title="Fire Detection API",
     description="Informs the probability that an image contains fire.",
 )

+# init ML inference object
 predictor_config_path = "./app/config.yaml"
-
-predictor = ImagePredictor.init_from_config_url(predictor_config_path)
+predictor = ImagePredictor.init_from_config_path(predictor_config_path)


 @app.post("/classify-image/")
 def create_upload_file(file: UploadFile = File(...)):
-    """Predicts the possibility that a RBG image contains fire."""
+    '''Predicts the possibility that a RBG image contains fire.'''
     return predictor.predict_from_file(file.file)
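For reference, a hedged client-side sketch of calling the endpoint above. It assumes the API is served locally (e.g. via `uvicorn app.main:app`, which defaults to port 8000) and that a test image named fire.jpg exists; FastAPI maps the multipart field name to the `file` parameter of create_upload_file:

import requests

# Hypothetical request against a locally running instance;
# the URL, port, and fire.jpg are placeholders for your own setup.
with open("fire.jpg", "rb") as img:
    response = requests.post(
        "http://localhost:8000/classify-image/",
        files={"file": ("fire.jpg", img, "image/jpeg")},
    )

print(response.json())  # e.g. {"Fire_Images": 0.998}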

fire_classifier/predictor.py

Lines changed: 20 additions & 40 deletions
@@ -1,68 +1,48 @@
 import argparse

 import numpy as np
-from typing import List
+from typing import Dict
 import yaml

-from fire_classifier.preprocessing_utilities import (
-    read_img_from_path,
-    read_from_file,
-)
-from fire_classifier.utils import download_model, load_model
+from fire_classifier.preprocessing_utilities import read_from_file
+from fire_classifier.utils import load_model


 class ImagePredictor:
-    def __init__(
-        self, model_paths: List[str], resize_size: List[int],
-        base_download_url: str, targets: List[str]
-    ):
-        self.model_paths = model_paths
-        self.resize_size = resize_size
-        self.model = load_model(base_download_url, self.model_paths)
-        self.targets = targets
+    def __init__(self, config: Dict[str, int or str]):
+        self.model_paths = config["model_file_paths"]
+        self.resize_size = config["resize_shape"]
+        self.model = load_model(
+            config["base_model_url"], self.model_paths, config["model_sha256"]
+        )
+        self.targets = config["targets"]

     @classmethod
     def init_from_config_path(cls, config_path):
+        '''Parses the config file, and instantiates a new ImagePredictor'''
         # load details for setting up the model
         with open(config_path, "r") as f:
             config = yaml.load(f, yaml.SafeLoader)
-        # use the config data, to integrate the model into the new object
-        predictor = cls(
-            model_paths=config["model_paths"],
-            resize_size=config["resize_shape"],
-            base_download_url=config["base_model_url"],
-            targets=config["targets"],
-        )
+        # use the config data to integrate the model into the new instance
+        predictor = cls(config)
         return predictor

-    @classmethod
-    def init_from_config_url(cls, config_path):
-        # with open(config_path, "r") as f:
-        #     config = yaml.load(f, yaml.SafeLoader)
-
-        # download_model(
-        #     config["model_file_urls"], config["model_paths"], config["model_sha256"]
-        # )
-
-        return cls.init_from_config_path(config_path)
-
-    def predict_from_array(self, arr):
+    def predict_from_array(self, arr) -> Dict[str, float]:
+        '''Returns a prediction value the sample belongs to each class.'''
         pred = self.model.predict(arr[np.newaxis, ...]).ravel().tolist()
-        pred = [round(x, 3) for x in pred]
-        return {k: v for k, v in zip(self.targets, pred)}
-
-    def predict_from_path(self, path):
-        arr = read_img_from_path(path)
-        return self.predict_from_array(arr)
+        pred = [round(x, 3) for x in pred]  # values between 0-1
+        return {class_label: prob for class_label, prob in zip(self.targets, pred)}

     def predict_from_file(self, file_object):
+        '''Converts uploaded image to a NumPy array and classifies it.'''
         arr = read_from_file(file_object)
         return self.predict_from_array(arr)


 if __name__ == "__main__":
     """
-    python predictor.py --predictor_config "../example/predictor_config.yaml"
+    Test out the predictor class via the CLI:
+    python predictor.py --predictor_config "../example/predictor_config.yaml"

     """
     parser = argparse.ArgumentParser()
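With init_from_config_url and predict_from_path removed, local images can still be classified outside the API by handing an open binary file to predict_from_file, since it only needs an object with a .read() method. A rough sketch, assuming the repo root as the working directory and a placeholder image path:

from fire_classifier.predictor import ImagePredictor

# Assumes ./app/config.yaml resolves (run from the repo root);
# examples/fire.jpg is a placeholder for any local test image.
predictor = ImagePredictor.init_from_config_path("./app/config.yaml")

with open("examples/fire.jpg", "rb") as img:  # any object with .read() works
    scores = predictor.predict_from_file(img)

print(scores)  # e.g. {"Fire_Images": 0.998}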

fire_classifier/preprocessing_utilities.py

Lines changed: 12 additions & 4 deletions
@@ -2,12 +2,20 @@
 import numpy as np


-def read_img_from_path(path):
-    img = cv2.imread(path, cv2.IMREAD_COLOR)
-    return img
+def read_from_file(file_object):
+    """
+    Produces a 3D array representing a color image.

+    NumPy creates a new 1D array from the file object,
+    and then using OpenCV we convert it to the proper 3D array
+    that the model can run inference on.

-def read_from_file(file_object):
+    Args:
+        file_object(fastapi.UploadFile): the uploaded image
+
+    Returns:
+        img_np: array-like object
+    """
     arr = np.fromstring(file_object.read(), np.uint8)
     img_np = cv2.imdecode(arr, cv2.IMREAD_COLOR)
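One caveat worth noting: recent NumPy releases emit a DeprecationWarning when np.fromstring is given binary data, and np.frombuffer is the drop-in replacement. An untested variant of the function above with only that call swapped:

import cv2
import numpy as np

def read_from_file(file_object):
    # np.frombuffer replaces the deprecated np.fromstring for raw bytes
    arr = np.frombuffer(file_object.read(), np.uint8)  # 1D array of bytes
    img_np = cv2.imdecode(arr, cv2.IMREAD_COLOR)       # decode to an H x W x 3 array
    return img_np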

fire_classifier/utils.py

Lines changed: 85 additions & 28 deletions
@@ -4,6 +4,18 @@


 def get_hash(filename):
+    """
+    Computes the SHA256 hash of a given file.
+
+    This can then be used to ensure the model file(s) downloaded
+    in this codebase are not corrupted.
+
+    Args:
+        filename(str): the name of the file
+
+    Returns:
+        bytes-like object
+    """
     sha256_hash = hashlib.sha256()
     with open(filename, "rb") as f:
         for byte_block in iter(lambda: f.read(4096), b""):
@@ -13,32 +25,77 @@ def get_hash(filename)


 def download_model(url, file_paths, file_sha256=None):
-    params_file, layers_file = file_paths
-    params_url, layers_url = (
-        f"{url}/{params_file}",
-        f"{url}/{layers_file}"
-    )
-    if (os.path.exists(params_file) and os.path.exists(layers_file)
-        # and get_hash(layers_file) == file_sha256
-    ):
-        print("File already exists")
-    else:  # download the model
-        keras.utils.get_file(
-            origin=layers_url, fname=layers_file,
-            cache_dir='.', cache_subdir="./model"
-        )
-        keras.utils.get_file(
-            origin=params_url, fname=params_file,
-            cache_dir='.', cache_subdir="./model"
-        )
-
-def load_model(url, file_paths):
-    '''Model reconstruction using H5 + JSON'''
+    """
+    Downloads the model files in memory.
+
+    This will first check if the files are already present,
+    and not corrupted, before downloading from the address
+    specified in config.yaml.
+
+    Args:
+        url(str): the base url where the files are located
+        file_paths(List[str]): collection of all the files needed to
+                               eventually load the model
+        file_sha256(str): the supposed hash of one of the files
+                          we need to download. Checked against the
+                          one we may already have in the codebase.
+
+    Returns:
+        None
+    """
+    # Download only the model files that are needed
+    for model_file_path in file_paths:
+        if os.path.exists(model_file_path):
+            if get_hash(model_file_path) == file_sha256:
+                print(f"File already exists: {model_file_path}")
+        else:  # need to download the model
+            model_file_url = f"{url}/{model_file_path}"
+            keras.utils.get_file(
+                origin=model_file_url, fname=model_file_path,
+                cache_dir=".", cache_subdir="./model"
+            )
+
+
+def load_model(url, file_paths, file_sha256=None, format='composite'):
+    """
+    Model reconstruction.
+
+    This will first load the model in memory using the given files
+    and save format
+
+    Args:
+        url(str): the base url where the files are located
+        file_paths(List[str]): collection of all the files needed to
+                               eventually load the model
+        file_sha256(str): the supposed hash of one of the files
+                          we need to download. Checked against the
+                          one we may already have in the codebase.
+        format(str): currently this only supports 'composite'
+                     (which is for when the model is saved using a H5 + JSON)
+                     or 'h5' as the save format of the model.
+
+    Returns:
+        keras.Model object
+    """
+
+    def _model_from_composite_format():
+        '''Specific to using H5 + JSON as the save format'''
+        params_file, layers_file = file_paths
+        # load the model in memory
+        with open(f"./model/{layers_file}") as f:
+            model = keras.models.model_from_json(f.read())  # build the layers
+        model.load_weights(f"./model/{params_file}")  # load weights + biases
+        return model
+
+    def _model_from_h5():
+        '''Specific to using a single Hadoop(H5) file'''
+        params_file = file_paths[0]
+        return keras.models.load_model(params_file)
+
     # First download the model, if needed
-    download_model(url, file_paths)
-    params_file, layers_file = file_paths
-    # Model reconstruction
-    with open(f"./model/{layers_file}") as f:
-        model = keras.models.model_from_json(f.read())
-    model.load_weights(f"./model/{params_file}")
-    return model
+    download_model(url, file_paths, file_sha256)
+    # load the model in memory
+    if format == 'composite':
+        return _model_from_composite_format()
+    else:  # assuming a single H5
+        return _model_from_h5()
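Taken together, load_model drives the whole chain: download_model fetches any missing files (skipping ones whose get_hash matches the expected digest), then the layers are rebuilt from JSON and the weights loaded from H5. A rough usage sketch wired to the values in app/config.yaml; running from the repo root is an assumption, and downloads land under ./model because of the cache_subdir passed to keras.utils.get_file:

from fire_classifier.utils import load_model

# Values copied from app/config.yaml; adjust if the release or file names change.
base_url = "https://github.com/UPstartDeveloper/Fire-Detection-API/releases/download/v0.0.2"
file_paths = ["fire_classifier_params.h5", "fire_classifier_layers.json"]
sha256 = "26f32ae0666bbb83e11968935db0ec2ab06623d1"

# downloads the files if needed, then reconstructs the Keras model
model = load_model(base_url, file_paths, sha256, format="composite")
model.summary()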

setup.py

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@
 except:
     REQUIRED = []

+# Use this if you plan to turn this project into a PyPI package!
 setup(
     name="YOUR_API_NAMME",
     version="0.1.0",
