|
1 | 1 | import argparse |
2 | 2 |
|
3 | 3 | import numpy as np |
| 4 | +from typing import List |
4 | 5 | import yaml |
5 | | -from tensorflow.keras.models import load_model |
6 | 6 |
|
7 | 7 | from fire_classifier.preprocessing_utilities import ( |
8 | 8 | read_img_from_path, |
9 | 9 | read_from_file, |
10 | 10 | ) |
11 | | -from fire_classifier.utils import download_model |
| 11 | +from fire_classifier.utils import download_model, load_model |
12 | 12 |
|
13 | 13 |
|
14 | 14 | class ImagePredictor: |
15 | | - def __init__(self, model_path, resize_size, targets): |
16 | | - self.model_path = model_path |
17 | | - self.model = load_model(self.model_path) |
| 15 | + def __init__( |
| 16 | + self, model_paths: List[str], resize_size: List[int], |
| 17 | + base_download_url: str, targets: List[str] |
| 18 | + ): |
| 19 | + self.model_paths = model_paths |
18 | 20 | self.resize_size = resize_size |
| 21 | + self.model = load_model(base_download_url, self.model_paths) |
19 | 22 | self.targets = targets |
20 | 23 |
|
21 | 24 | @classmethod |
22 | 25 | def init_from_config_path(cls, config_path): |
| 26 | + # load details for setting up the model |
23 | 27 | with open(config_path, "r") as f: |
24 | 28 | config = yaml.load(f, yaml.SafeLoader) |
| 29 | + # use the config data, to integrate the model into the new object |
25 | 30 | predictor = cls( |
26 | | - model_path=config["model_path"], |
| 31 | + model_paths=config["model_paths"], |
27 | 32 | resize_size=config["resize_shape"], |
| 33 | + base_download_url=config["base_model_url"], |
28 | 34 | targets=config["targets"], |
29 | 35 | ) |
30 | 36 | return predictor |
31 | 37 |
|
32 | 38 | @classmethod |
33 | 39 | def init_from_config_url(cls, config_path): |
34 | | - with open(config_path, "r") as f: |
35 | | - config = yaml.load(f, yaml.SafeLoader) |
| 40 | + # with open(config_path, "r") as f: |
| 41 | + # config = yaml.load(f, yaml.SafeLoader) |
36 | 42 |
|
37 | | - download_model( |
38 | | - config["model_url"], config["model_path"], config["model_sha256"] |
39 | | - ) |
| 43 | + # download_model( |
| 44 | + # config["model_file_urls"], config["model_paths"], config["model_sha256"] |
| 45 | + # ) |
40 | 46 |
|
41 | 47 | return cls.init_from_config_path(config_path) |
42 | 48 |
|
|
0 commit comments