Skip to content

Commit 65b8c55

Browse files
Reformatting + update API metadata.
1 parent dbfee0a commit 65b8c55

File tree

3 files changed

+28
-17
lines changed

3 files changed

+28
-17
lines changed

app/main.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,27 @@
11
from fastapi import FastAPI, File, UploadFile
22
from fire_classifier.predictor import ImagePredictor
33

4-
# init API
4+
# A: define description for the API endpoints (shown on UI)
5+
endpoint_metadata = [
6+
{
7+
"name": "classify-image",
8+
"description": "Predicts the possibility that a RBG image contains fire.",
9+
},
10+
]
11+
# B: init API
512
app = FastAPI(
6-
title="Fire Detection API",
7-
description="Informs the probability that an image contains fire.",
13+
title="DeepFire",
14+
description="A REST API for detecting the presence of fire in an image.",
15+
version="0.0.2",
16+
openapi_tags=endpoint_metadata,
817
)
918

10-
# init ML inference object
19+
# C: init ML inference object
1120
predictor_config_path = "./app/config.yaml"
1221
predictor = ImagePredictor.init_from_config_path(predictor_config_path)
1322

1423

15-
@app.post("/classify-image/")
24+
@app.post("/classify-image/", tags=["classify-image"])
1625
def create_upload_file(file: UploadFile = File(...)):
17-
'''Predicts the possibility that a RBG image contains fire.'''
26+
"""Predicts the possibility that a RBG image contains fire."""
1827
return predictor.predict_from_file(file.file)

fire_classifier/predictor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def __init__(self, config: Dict[str, int or str]):
1919

2020
@classmethod
2121
def init_from_config_path(cls, config_path):
22-
'''Parses the config file, and instantiates a new ImagePredictor'''
22+
"""Parses the config file, and instantiates a new ImagePredictor"""
2323
# load details for setting up the model
2424
with open(config_path, "r") as f:
2525
config = yaml.load(f, yaml.SafeLoader)
@@ -28,13 +28,13 @@ def init_from_config_path(cls, config_path):
2828
return predictor
2929

3030
def predict_from_array(self, arr) -> Dict[str, float]:
31-
'''Returns a prediction value the sample belongs to each class.'''
31+
"""Returns a prediction value the sample belongs to each class."""
3232
pred = self.model.predict(arr[np.newaxis, ...]).ravel().tolist()
3333
pred = [round(x, 3) for x in pred] # values between 0-1
3434
return {class_label: prob for class_label, prob in zip(self.targets, pred)}
3535

3636
def predict_from_file(self, file_object):
37-
'''Converts uploaded image to a NumPy array and classifies it.'''
37+
"""Converts uploaded image to a NumPy array and classifies it."""
3838
arr = read_from_file(file_object)
3939
return self.predict_from_array(arr)
4040

fire_classifier/utils.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,14 @@ def download_model(url, file_paths, file_sha256=None):
5151
else: # need to download the model
5252
model_file_url = f"{url}/{model_file_path}"
5353
keras.utils.get_file(
54-
origin=model_file_url, fname=model_file_path,
55-
cache_dir=".", cache_subdir="./model"
54+
origin=model_file_url,
55+
fname=model_file_path,
56+
cache_dir=".",
57+
cache_subdir="./model",
5658
)
5759

5860

59-
def load_model(url, file_paths, file_sha256=None, format='composite'):
61+
def load_model(url, file_paths, file_sha256=None, format="composite"):
6062
"""
6163
Model reconstruction.
6264
@@ -70,7 +72,7 @@ def load_model(url, file_paths, file_sha256=None, format='composite'):
7072
file_sha256(str): the supposed hash of one of the files
7173
we need to download. Checked against the
7274
one we may already have in the codebase.
73-
format(str): currently this only supports 'composite'
75+
format(str): currently this only supports 'composite'
7476
(which is for when the model is saved using a H5 + JSON)
7577
or 'h5' as the save format of the model.
7678
@@ -79,7 +81,7 @@ def load_model(url, file_paths, file_sha256=None, format='composite'):
7981
"""
8082

8183
def _model_from_composite_format():
82-
'''Specific to using H5 + JSON as the save format'''
84+
"""Specific to using H5 + JSON as the save format"""
8385
params_file, layers_file = file_paths
8486
# load the model in memory
8587
with open(f"./model/{layers_file}") as f:
@@ -88,14 +90,14 @@ def _model_from_composite_format():
8890
return model
8991

9092
def _model_from_h5():
91-
'''Specific to using a single Hadoop(H5) file'''
93+
"""Specific to using a single Hadoop(H5) file"""
9294
params_file = file_paths[0]
9395
return keras.models.load_model(params_file)
94-
96+
9597
# First download the model, if needed
9698
download_model(url, file_paths, file_sha256)
9799
# load the model in memory
98-
if format == 'composite':
100+
if format == "composite":
99101
return _model_from_composite_format()
100102
else: # assuming a single H5
101103
return _model_from_h5()

0 commit comments

Comments
 (0)