|
3 | 3 | import numpy as np |
4 | 4 | from typing import List |
5 | 5 | import yaml |
6 | | -from tensorflow.keras.models import load_model |
7 | 6 |
|
8 | 7 | from fire_classifier.preprocessing_utilities import ( |
9 | 8 | read_img_from_path, |
10 | 9 | read_from_file, |
11 | 10 | ) |
12 | | -from fire_classifier.utils import download_model |
| 11 | +from fire_classifier.utils import download_model, load_model |
13 | 12 |
|
14 | 13 |
|
15 | 14 | class ImagePredictor: |
16 | 15 | def __init__( |
17 | | - self, model_paths: List[str], resize_size: List[int], targets: List[str] |
| 16 | + self, model_paths: List[str], resize_size: List[int], |
| 17 | + base_download_url: str, targets: List[str] |
18 | 18 | ): |
19 | 19 | self.model_paths = model_paths |
20 | | - self.model = load_model(self.model_paths) |
21 | 20 | self.resize_size = resize_size |
| 21 | + self.model = load_model(base_download_url, self.model_paths) |
22 | 22 | self.targets = targets |
23 | 23 |
|
24 | 24 | @classmethod |
25 | 25 | def init_from_config_path(cls, config_path): |
| 26 | + # load details for setting up the model |
26 | 27 | with open(config_path, "r") as f: |
27 | 28 | config = yaml.load(f, yaml.SafeLoader) |
| 29 | + # use the config data, to integrate the model into the new object |
28 | 30 | predictor = cls( |
29 | 31 | model_paths=config["model_paths"], |
30 | 32 | resize_size=config["resize_shape"], |
| 33 | + base_download_url=config["base_model_url"], |
31 | 34 | targets=config["targets"], |
32 | 35 | ) |
33 | 36 | return predictor |
34 | 37 |
|
35 | 38 | @classmethod |
36 | 39 | def init_from_config_url(cls, config_path): |
37 | | - with open(config_path, "r") as f: |
38 | | - config = yaml.load(f, yaml.SafeLoader) |
| 40 | + # with open(config_path, "r") as f: |
| 41 | + # config = yaml.load(f, yaml.SafeLoader) |
39 | 42 |
|
40 | | - download_model( |
41 | | - config["model_file_urls"], config["model_paths"], config["model_sha256"] |
42 | | - ) |
| 43 | + # download_model( |
| 44 | + # config["model_file_urls"], config["model_paths"], config["model_sha256"] |
| 45 | + # ) |
43 | 46 |
|
44 | 47 | return cls.init_from_config_path(config_path) |
45 | 48 |
|
|
0 commit comments