Skip to content

Commit b612b31

Browse files
Merge pull request #12 from UPstartDeveloper/refactoring
Refactoring
2 parents a842c5c + 2a614c5 commit b612b31

File tree

8 files changed

+165
-126
lines changed

8 files changed

+165
-126
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ __pycache__/
1111
.vscode
1212

1313
# Keras Models
14-
model/*
14+
*model/*
1515

1616
# Distribution / packaging
1717
.Python

app/main.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,21 @@
11
from fastapi import FastAPI, File, UploadFile
2-
from fire_classifier.predictor import ImagePredictor
2+
from classifier.predictor import ImagePredictor
3+
from app.settings import API_SETTINGS
34

4-
# A: define description for the API endpoints (shown on UI)
5-
endpoint_metadata = [
6-
{
7-
"name": "classify-image",
8-
"description": "Predicts the possibility that a RBG image contains fire.",
9-
},
10-
]
11-
# B: init API
5+
# A: init API
126
app = FastAPI(
13-
title="DeepFire",
14-
description="A REST API for detecting the presence of fire in an image.",
15-
version="0.0.2",
16-
openapi_tags=endpoint_metadata,
7+
title=API_SETTINGS["title"],
8+
description=API_SETTINGS["description"],
9+
version=API_SETTINGS["version"],
10+
openapi_tags=API_SETTINGS["openapi_tags"],
1711
)
1812

19-
# C: init ML inference object
20-
predictor_config_path = "./app/config.yaml"
13+
# B: init ML inference object, and the routes
14+
predictor_config_path = API_SETTINGS["predictor_config_path"]
2115
predictor = ImagePredictor.init_from_config_path(predictor_config_path)
2216

2317

24-
@app.post("/classify-image/", tags=["classify-image"])
18+
@app.post("/classify-image/", tags=["Detect Fire in an Image"])
2519
def create_upload_file(file: UploadFile = File(...)):
2620
"""Predicts the possibility that a RBG image contains fire."""
2721
return predictor.predict_from_file(file.file)

app/settings.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import os
2+
from pathlib import Path
3+
4+
# General info about the API
5+
API_TITLE = "DeepFire"
6+
API_DESCRIPTION = "A REST API for detecting the presence of fire in an image."
7+
API_VERSION = "0.0.2"
8+
9+
# Info about what data users can request.
10+
# These are intended shown on the UI, related only to a specific endpoint.
11+
# Note: the value in the "name" field should match what goes in the
12+
# "tags" parameter of the corresponding app route in main.py!!
13+
API_ENDPOINT_DATA = (
14+
{
15+
"name": "Detect Fire in an Image",
16+
"description": "Predicts the possibility that a color image contains fire.",
17+
},
18+
)
19+
20+
# Tells the app how to find the config.yaml (for running ML inference)
21+
BASE_DIR = Path(__file__).resolve().parent.parent
22+
CONFIG_PATH = os.path.join(BASE_DIR, "app", "config.yaml")
23+
24+
# Wraps all the API metadata as one dictionary
25+
API_SETTINGS = {
26+
"title": API_TITLE,
27+
"description": API_DESCRIPTION,
28+
"version": API_VERSION,
29+
"openapi_tags": API_ENDPOINT_DATA,
30+
"predictor_config_path": CONFIG_PATH,
31+
}
Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,15 @@
44
from typing import Dict
55
import yaml
66

7-
from fire_classifier.preprocessing_utilities import read_from_file
8-
from fire_classifier.utils import load_model
7+
from classifier.util import preprocessing
8+
from classifier.util.model import ModelUtility
99

1010

1111
class ImagePredictor:
1212
def __init__(self, config: Dict[str, int or str]):
1313
self.model_paths = config["model_file_paths"]
1414
self.resize_size = config["resize_shape"]
15-
self.model = load_model(
16-
config["base_model_url"], self.model_paths, config["model_sha256"]
17-
)
15+
self.model = ModelUtility.reconstruct_model(config)
1816
self.targets = config["targets"]
1917

2018
@classmethod
@@ -35,7 +33,7 @@ def predict_from_array(self, arr) -> Dict[str, float]:
3533

3634
def predict_from_file(self, file_object):
3735
"""Converts uploaded image to a NumPy array and classifies it."""
38-
arr = read_from_file(file_object)
36+
arr = preprocessing.read_from_file(file_object)
3937
return self.predict_from_array(arr)
4038

4139

classifier/util/model.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
import hashlib
2+
import os
3+
from typing import Dict
4+
from tensorflow import keras
5+
6+
7+
class ModelUtility:
8+
def __init__(self, config: Dict[str, str]):
9+
"""
10+
Instaniates a new object using data needed to load in the model.
11+
12+
Args:
13+
config(dict): contains the following fields of interest:
14+
base_model_url(str): the base url where the files are located
15+
model_file_paths(list): collection of all the files needed to
16+
eventually load the model
17+
model_sha256(str): the supposed hash of one of the files
18+
we need to download. Checked against the
19+
one we may already have in the codebase.
20+
"""
21+
self.url = config["base_model_url"]
22+
self.file_paths = config["model_file_paths"]
23+
self.file_sha256 = None
24+
if config["model_sha256"] is not None:
25+
self.file_sha256 = config["model_sha256"]
26+
27+
@classmethod
28+
def reconstruct_model(cls, config):
29+
'''Make a new instance, and load in the model straightaway.'''
30+
model_utility = cls(config)
31+
# detect save format
32+
save_format = 'composite'
33+
if config["model_file_paths"] and len(config["model_file_paths"]) == 1:
34+
save_format = 'h5'
35+
# load the model
36+
return model_utility.load_model(save_format)
37+
38+
def get_hash(self, filename):
39+
"""
40+
Computes the SHA256 hash of a given file.
41+
42+
This can then be used to ensure the model file(s) downloaded
43+
in this codebase are not corrupted.
44+
45+
Args:
46+
filename(str): the name of the file
47+
48+
Returns:
49+
bytes-like object
50+
"""
51+
sha256_hash = hashlib.sha256()
52+
with open(filename, "rb") as f:
53+
for byte_block in iter(lambda: f.read(4096), b""):
54+
sha256_hash.update(byte_block)
55+
56+
return sha256_hash.hexdigest()
57+
58+
def download_model(self):
59+
"""
60+
Downloads the model files in memory.
61+
62+
This will first check if the files are already present,
63+
and not corrupted, before downloading from the address
64+
specified in config.yaml.
65+
66+
Returns:
67+
None
68+
"""
69+
# Download only the model files that are needed
70+
for model_file_path in self.file_paths:
71+
if os.path.exists(model_file_path):
72+
if self.get_hash(model_file_path) == self.file_sha256:
73+
print(f"File already exists: {model_file_path}")
74+
else: # need to download the model
75+
model_file_url = f"{self.url}/{model_file_path}"
76+
keras.utils.get_file(
77+
origin=model_file_url,
78+
fname=model_file_path,
79+
cache_dir=".",
80+
cache_subdir="./model",
81+
)
82+
83+
def load_model(self, format="composite"):
84+
"""
85+
Model reconstruction.
86+
87+
This will first load the model in memory using the given files
88+
and save format
89+
90+
Args:
91+
format(str): currently this only supports 'composite'
92+
(which is for when the model is saved using a H5 + JSON)
93+
or 'h5' as the save format of the model.
94+
95+
Returns:
96+
keras.Model object
97+
"""
98+
99+
def _model_from_composite_format():
100+
"""Specific to using H5 + JSON as the save format"""
101+
params_file, layers_file = self.file_paths
102+
# load the model in memory
103+
with open(f"./model/{layers_file}") as f:
104+
model = keras.models.model_from_json(f.read()) # build the layers
105+
model.load_weights(f"./model/{params_file}") # load weights + biases
106+
return model
107+
108+
def _model_from_h5():
109+
"""Specific to using a single Hadoop(H5) file"""
110+
params_file = self.file_paths[0]
111+
return keras.models.load_model(params_file)
112+
113+
# First download the model, if needed
114+
self.download_model()
115+
# load the model in memory
116+
if format == "composite":
117+
return _model_from_composite_format()
118+
else: # assuming a single H5
119+
return _model_from_h5()
File renamed without changes.

fire_classifier/utils.py

Lines changed: 0 additions & 103 deletions
This file was deleted.

0 commit comments

Comments
 (0)