Skip to content

Commit d36d389

Browse files
Merge pull request #6 from UPstartDeveloper/lighter-model
Lighter model
2 parents 26f32ae + 3266299 commit d36d389

File tree

7 files changed

+76
-43
lines changed

7 files changed

+76
-43
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ __pycache__/
1111
.vscode
1212

1313
# Keras Models
14-
*.h5
14+
model/*
1515

1616
# Distribution / packaging
1717
.Python

LICENSE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
MIT License
22

3-
Copyright (c) 2020 Mansar Youness
3+
Copyright (c) 2021 Zain Raza
44

55
Permission is hereby granted, free of charge, to any person obtaining a copy
66
of this software and associated documentation files (the "Software"), to deal

app/config.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
resize_shape: [256, 256]
22
targets: ['Fire_Images']
3-
model_path: "fire_classifier.h5"
4-
model_url: "https://github.com/UPstartDeveloper/Fire-Detection-API/releases/download/v0.0.1/fire_classifier.h5"
5-
model_sha256: "b527ecea0f4786e45eb0242f263752c45fb55754"
3+
model_paths: ["fire_classifier_params.h5", "fire_classifier_layers.json"]
4+
base_model_url: "https://github.com/UPstartDeveloper/Fire-Detection-API/releases/download/v0.0.2"
5+
model_sha256: "26f32ae0666bbb83e11968935db0ec2ab06623d1"

app/main.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
from fire_classifier.predictor import ImagePredictor
44

55
app = FastAPI(
6-
title="Fire Detection API",
7-
description="Informs the probability that an image contains fire."
6+
title="Fire Detection API",
7+
description="Informs the probability that an image contains fire.",
88
)
99

1010
predictor_config_path = "./app/config.yaml"
@@ -14,5 +14,5 @@
1414

1515
@app.post("/classify-image/")
1616
def create_upload_file(file: UploadFile = File(...)):
17-
'''Predicts the possibility that a RBG image contains fire.'''
17+
"""Predicts the possibility that a RBG image contains fire."""
1818
return predictor.predict_from_file(file.file)

fire_classifier/predictor.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,48 @@
11
import argparse
22

33
import numpy as np
4+
from typing import List
45
import yaml
5-
from tensorflow.keras.models import load_model
66

77
from fire_classifier.preprocessing_utilities import (
88
read_img_from_path,
99
read_from_file,
1010
)
11-
from fire_classifier.utils import download_model
11+
from fire_classifier.utils import download_model, load_model
1212

1313

1414
class ImagePredictor:
15-
def __init__(self, model_path, resize_size, targets):
16-
self.model_path = model_path
17-
self.model = load_model(self.model_path)
15+
def __init__(
16+
self, model_paths: List[str], resize_size: List[int],
17+
base_download_url: str, targets: List[str]
18+
):
19+
self.model_paths = model_paths
1820
self.resize_size = resize_size
21+
self.model = load_model(base_download_url, self.model_paths)
1922
self.targets = targets
2023

2124
@classmethod
2225
def init_from_config_path(cls, config_path):
26+
# load details for setting up the model
2327
with open(config_path, "r") as f:
2428
config = yaml.load(f, yaml.SafeLoader)
29+
# use the config data, to integrate the model into the new object
2530
predictor = cls(
26-
model_path=config["model_path"],
31+
model_paths=config["model_paths"],
2732
resize_size=config["resize_shape"],
33+
base_download_url=config["base_model_url"],
2834
targets=config["targets"],
2935
)
3036
return predictor
3137

3238
@classmethod
3339
def init_from_config_url(cls, config_path):
34-
with open(config_path, "r") as f:
35-
config = yaml.load(f, yaml.SafeLoader)
40+
# with open(config_path, "r") as f:
41+
# config = yaml.load(f, yaml.SafeLoader)
3642

37-
download_model(
38-
config["model_url"], config["model_path"], config["model_sha256"]
39-
)
43+
# download_model(
44+
# config["model_file_urls"], config["model_paths"], config["model_sha256"]
45+
# )
4046

4147
return cls.init_from_config_path(config_path)
4248

fire_classifier/utils.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import hashlib
22
import os
3-
from tensorflow.keras.utils import get_file
3+
from tensorflow import keras
44

55

66
def get_hash(filename):
@@ -12,8 +12,33 @@ def get_hash(filename):
1212
return sha256_hash.hexdigest()
1313

1414

15-
def download_model(url, file_path, file_sha256):
16-
if os.path.exists(file_path) and get_hash(file_path) == file_sha256:
15+
def download_model(url, file_paths, file_sha256=None):
16+
params_file, layers_file = file_paths
17+
params_url, layers_url = (
18+
f"{url}/{params_file}",
19+
f"{url}/{layers_file}"
20+
)
21+
if (os.path.exists(params_file) and os.path.exists(layers_file)
22+
# and get_hash(layers_file) == file_sha256
23+
):
1724
print("File already exists")
18-
else:
19-
get_file(origin=url, fname=file_path, cache_dir=".", cache_subdir="")
25+
else: # download the model
26+
keras.utils.get_file(
27+
origin=layers_url, fname=layers_file,
28+
cache_dir='.', cache_subdir="./model"
29+
)
30+
keras.utils.get_file(
31+
origin=params_url, fname=params_file,
32+
cache_dir='.', cache_subdir="./model"
33+
)
34+
35+
def load_model(url, file_paths):
36+
'''Model reconstruction using H5 + JSON'''
37+
# First download the model, if needed
38+
download_model(url, file_paths)
39+
params_file, layers_file = file_paths
40+
# Model reconstruction
41+
with open(f"./model/{layers_file}") as f:
42+
model = keras.models.model_from_json(f.read())
43+
model.load_weights(f"./model/{params_file}")
44+
return model

setup.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,24 @@
1111
except:
1212
REQUIRED = []
1313

14-
setup(name='YOUR_API_NAMME',
15-
version='0.1.0',
16-
description='ADD_A_DESCRIPTION',
17-
author='YOUR_NAME',
18-
author_email='YOUR_EMAIL',
19-
url='YOUR_REPO_URL',
20-
license='MIT',
21-
install_requires=REQUIRED,
22-
classifiers=[
23-
'Intended Audience :: Developers',
24-
'Intended Audience :: Education',
25-
'Intended Audience :: Science/Research',
26-
'License :: OSI Approved :: MIT License',
27-
'Programming Language :: Python :: 3',
28-
'Programming Language :: Python :: 3.6',
29-
'Topic :: Software Development :: Libraries',
30-
'Topic :: Software Development :: Libraries :: Python Modules'
31-
],
32-
packages=find_packages(exclude=("example", "app", "data", "docker", "tests")))
14+
setup(
15+
name="YOUR_API_NAMME",
16+
version="0.1.0",
17+
description="ADD_A_DESCRIPTION",
18+
author="YOUR_NAME",
19+
author_email="YOUR_EMAIL",
20+
url="YOUR_REPO_URL",
21+
license="MIT",
22+
install_requires=REQUIRED,
23+
classifiers=[
24+
"Intended Audience :: Developers",
25+
"Intended Audience :: Education",
26+
"Intended Audience :: Science/Research",
27+
"License :: OSI Approved :: MIT License",
28+
"Programming Language :: Python :: 3",
29+
"Programming Language :: Python :: 3.6",
30+
"Topic :: Software Development :: Libraries",
31+
"Topic :: Software Development :: Libraries :: Python Modules",
32+
],
33+
packages=find_packages(exclude=("example", "app", "data", "docker", "tests")),
34+
)

0 commit comments

Comments
 (0)