diff --git a/.gitignore b/.gitignore
index 5cb5580..e370921 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,4 +6,5 @@
 weights
 build/
 *.egg-info/
-gradio_cached_examples
\ No newline at end of file
+gradio_cached_examples
+/output/
diff --git a/Inference.py b/Inference.py
index 61b70fc..0b6d5dc 100644
--- a/Inference.py
+++ b/Inference.py
@@ -1,5 +1,7 @@
 import argparse
-from fastsam import FastSAM, FastSAMPrompt
+import pathlib
+
+from fastsam import FastSAM, FastSAMPrompt
 import ast
 import torch
 from PIL import Image
@@ -42,7 +44,8 @@ def parse_args():
         default="[0]",
         help="[1,0] 0:background, 1:foreground",
     )
-    parser.add_argument("--box_prompt", type=str, default="[[0,0,0,0]]", help="[[x,y,w,h],[x2,y2,w2,h2]] support multiple boxes")
+    parser.add_argument("--box_prompt", type=str, default="[[0,0,0,0]]",
+                        help="[[x,y,w,h],[x2,y2,w2,h2]] support multiple boxes")
     parser.add_argument(
         "--better_quality",
         type=str,
@@ -77,44 +80,55 @@ def main(args):
     args.point_prompt = ast.literal_eval(args.point_prompt)
     args.box_prompt = convert_box_xywh_to_xyxy(ast.literal_eval(args.box_prompt))
     args.point_label = ast.literal_eval(args.point_label)
-    input = Image.open(args.img_path)
-    input = input.convert("RGB")
-    everything_results = model(
-        input,
-        device=args.device,
-        retina_masks=args.retina,
-        imgsz=args.imgsz,
-        conf=args.conf,
-        iou=args.iou
+    img_path = pathlib.Path(args.img_path)
+    img_paths = []
+    # iterate through the entire folder if a directory is specified
+    if img_path.exists() and img_path.is_file():
+        img_paths.append(img_path)
+    else:
+        img_formats = ["*.jpg", "*.png", "*.bmp"]
+        for img_format in img_formats:
+            img_paths.extend(img_path.glob(img_format))
+
+    for img_path in img_paths:
+        input_image = Image.open(img_path)
+        input_image = input_image.convert("RGB")
+        input_image = input_image.resize((args.imgsz, args.imgsz))
+
+        everything_results = model(
+            input_image,
+            device=args.device,
+            retina_masks=args.retina,
+            imgsz=args.imgsz,
+            conf=args.conf,
+            iou=args.iou
     )
-    bboxes = None
-    points = None
-    point_label = None
-    prompt_process = FastSAMPrompt(input, everything_results, device=args.device)
-    if args.box_prompt[0][2] != 0 and args.box_prompt[0][3] != 0:
-        ann = prompt_process.box_prompt(bboxes=args.box_prompt)
-        bboxes = args.box_prompt
-    elif args.text_prompt != None:
-        ann = prompt_process.text_prompt(text=args.text_prompt)
-    elif args.point_prompt[0] != [0, 0]:
-        ann = prompt_process.point_prompt(
-            points=args.point_prompt, pointlabel=args.point_label
+        bboxes = None
+        points = None
+        point_label = None
+        prompt_process = FastSAMPrompt(input_image, everything_results, device=args.device)
+        if args.box_prompt[0][2] != 0 and args.box_prompt[0][3] != 0:
+            ann = prompt_process.box_prompt(bboxes=args.box_prompt)
+            bboxes = args.box_prompt
+        elif args.text_prompt != None:
+            ann = prompt_process.text_prompt(text=args.text_prompt)
+        elif args.point_prompt[0] != [0, 0]:
+            ann = prompt_process.point_prompt(
+                points=args.point_prompt, pointlabel=args.point_label
+            )
+            points = args.point_prompt
+            point_label = args.point_label
+        else:
+            ann = prompt_process.everything_prompt()
+        prompt_process.plot(
+            annotations=ann,
+            output_path=args.output + img_path.name,
+            bboxes=bboxes,
+            points=points,
+            point_label=point_label,
+            withContours=args.withContours,
+            better_quality=args.better_quality,
         )
-        points = args.point_prompt
-        point_label = args.point_label
-    else:
-        ann = prompt_process.everything_prompt()
-    prompt_process.plot(
-        annotations=ann,
-        output_path=args.output+args.img_path.split("/")[-1],
-        bboxes = bboxes,
-        points = points,
-        point_label = point_label,
-        withContours=args.withContours,
-        better_quality=args.better_quality,
-    )
-
-
 if __name__ == "__main__":
     args = parse_args()
     main(args)
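Note on the new image-collection logic in `main()`: the `--img_path` argument is now treated as either a single file or a directory, and the directory case globs only lowercase `*.jpg`, `*.png` and `*.bmp`, non-recursively. A minimal standalone sketch of that behaviour (the `collect_images` helper name is illustrative, not part of the patch):

```python
import pathlib


def collect_images(path_str):
    """Hypothetical helper mirroring the collection logic added to Inference.py.

    A single file is returned as-is; a directory is scanned non-recursively for
    *.jpg, *.png and *.bmp. Globbing is case-sensitive on Linux, so files such
    as IMG_0001.JPG or photo.jpeg would not be matched by these patterns.
    """
    img_path = pathlib.Path(path_str)
    if img_path.exists() and img_path.is_file():
        return [img_path]
    img_paths = []
    for img_format in ("*.jpg", "*.png", "*.bmp"):
        img_paths.extend(img_path.glob(img_format))
    return img_paths


if __name__ == "__main__":
    # e.g. the directory passed as --img_path ./images/
    print(collect_images("./images/"))
```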
diff --git a/fastsam/model.py b/fastsam/model.py
index a68f88c..12be2ce 100644
--- a/fastsam/model.py
+++ b/fastsam/model.py
@@ -8,6 +8,7 @@
     model = FastSAM('last.pt')
     results = model.predict('ultralytics/assets/bus.jpg')
 """
 
+import traceback
 from ultralytics.yolo.cfg import get_cfg
 from ultralytics.yolo.engine.exporter import Exporter
@@ -50,6 +51,8 @@ def predict(self, source=None, stream=False, **kwargs):
         try:
             return self.predictor(source, stream=stream)
         except Exception as e:
+            LOGGER.error("Failed to predict with: %s", e)
+            LOGGER.error(traceback.format_exc())
             return None
 
     def train(self, **kwargs):
diff --git a/requirements.txt b/requirements.txt
index b40e8ff..4316eb5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
 # Base-----------------------------------
-matplotlib>=3.2.2
+matplotlib>=3.2.2, <3.10.0
 opencv-python>=4.6.0
 Pillow>=7.1.2
 PyYAML>=5.3.1
@@ -13,7 +13,7 @@
 pandas>=1.1.4
 seaborn>=0.11.0
 gradio==3.35.2
-
+psutil>=6.0.0
 
 # Ultralytics-----------------------------------
 # ultralytics == 8.0.120
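For reference, the two `LOGGER.error` calls added to `FastSAM.predict` emit the exception message and the formatted traceback as separate records; the standard library can also attach both to a single record via `Logger.exception`. A small self-contained sketch of both patterns (the `demo` logger and the failing predictor are placeholders, not FastSAM code):

```python
import logging
import traceback

logging.basicConfig(level=logging.INFO)
LOGGER = logging.getLogger("demo")  # stand-in for ultralytics' LOGGER


def failing_predictor(source):
    raise RuntimeError(f"could not read {source}")  # placeholder failure


try:
    failing_predictor("bus.jpg")
except Exception as e:
    # Pattern used in the patch: message first, then the formatted traceback.
    LOGGER.error("Failed to predict with: %s", e)
    LOGGER.error(traceback.format_exc())
    # Equivalent single record with the traceback attached automatically:
    LOGGER.exception("Failed to predict with: %s", e)
```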
diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py
index 3c2ba06..bf47de7 100644
--- a/ultralytics/nn/tasks.py
+++ b/ultralytics/nn/tasks.py
@@ -7,11 +7,18 @@
 import torch
 import torch.nn as nn
 
+from torch.nn.modules.batchnorm import BatchNorm2d
+from torch.nn import Sequential, Conv2d, MaxPool2d, Upsample, ConvTranspose2d
+from ultralytics.nn.modules.block import C2f, DFL, Bottleneck
+from torch.nn.modules.container import ModuleList
+from torch.nn.modules.activation import SiLU
+from ultralytics.nn.modules.conv import Concat
 from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x,
                                     Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv, DWConvTranspose2d,
                                     Focus, GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3, RepConv,
-                                    RTDETRDecoder, Segment)
-from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
+                                    RTDETRDecoder, Segment, Proto)
+from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load, \
+    IterableSimpleNamespace
 from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_yaml
 from ultralytics.yolo.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8PoseLoss, v8SegmentationLoss
 from ultralytics.yolo.utils.plotting import feature_visualization
@@ -575,7 +582,17 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
 
 def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
     """Loads a single model weights."""
-    ckpt, weight = torch_safe_load(weight)  # load ckpt
+    with torch.serialization.safe_globals([SegmentationModel, Sequential, Conv, Conv2,
+            Conv2d, BatchNorm2d, SiLU, C2f, ModuleList, Bottleneck, SPPF, MaxPool2d,
+            Upsample, Concat, Segment,
+            DFL, Proto, ConvTranspose2d, AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottleneck,
+            BottleneckCSP, C2f, C3Ghost, C3x,
+            Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv,
+            DWConvTranspose2d,
+            Focus, GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3, RepConv,
+            RTDETRDecoder, Segment, getattr, IterableSimpleNamespace
+            ]):
+        ckpt, weight = torch_safe_load(weight)  # load ckpt
     args = {**DEFAULT_CFG_DICT, **(ckpt.get('train_args', {}))}  # combine model and default args, preferring model args
     model = (ckpt.get('ema') or ckpt['model']).to(device).float()  # FP32 model
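The `safe_globals` wrapper around `torch_safe_load` appears to target newer PyTorch releases, where `torch.load` defaults to `weights_only=True` (PyTorch 2.6+) and refuses to unpickle classes that are not explicitly allow-listed; that is presumably why every module class found in FastSAM checkpoints is enumerated above. A minimal sketch of the mechanism with a stand-in class rather than a real checkpoint (requires PyTorch 2.5+ for `torch.serialization.safe_globals`):

```python
from dataclasses import dataclass

import torch


@dataclass
class RunConfig:  # stand-in for the module/config classes stored in a real checkpoint
    imgsz: int = 1024


torch.save({"cfg": RunConfig()}, "demo.pt")  # pickles the class instance, not just tensors

try:
    torch.load("demo.pt", weights_only=True)  # rejected: RunConfig is not allow-listed
except Exception as e:
    print("blocked:", type(e).__name__)

# Allow-listing the pickled classes makes the load succeed, which is what the
# patch does for the module types inside FastSAM weights (Conv, C2f, Segment, ...).
with torch.serialization.safe_globals([RunConfig]):
    print(torch.load("demo.pt", weights_only=True))
```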