From 263f9280f8a973145dbcad14bbec5a6c4efc88a9 Mon Sep 17 00:00:00 2001 From: rain-Brian Date: Wed, 3 Jun 2026 14:23:16 -0700 Subject: [PATCH 1/2] DRAFT: remove migrated framework package from the hub, reframe as umbrella Completes the repo split by removing from the Biodiversity hub the framework code that now lives in dedicated repos, and reframes the hub homepage as the ecosystem umbrella. DRAFT: do not merge until the PytorchWildlife PyPI release is cut from microsoft/Pytorch-Wildlife (so pip install resolves there). Removed (verified migrated + parity-checked): - PytorchWildlife/ core package + setup.py + version.txt + MANIFEST.in -> microsoft/Pytorch-Wildlife (complete superset after that repo's release-readiness PR). - PW_Bioacoustics/ -> microsoft/MegaDetector-Acoustic (byte-for-byte identical, no hub drift). - docs/base/ + docs/fine_tuning_modules/ (mkdocstrings API reference) -> now hosted on the Pytorch-Wildlife docs site; removed the mkdocstrings plugin + deps + the Reference nav section. Reframed: - docs/index.md is now the Microsoft Biodiversity umbrella homepage (was 'Welcome to PyTorch-Wildlife'); it points the framework to microsoft/Pytorch-Wildlife and leads with the ecosystem. Top nav group renamed to Microsoft Biodiversity. Left in place and FLAGGED for a team decision (NOT safe to auto-delete): - PW_FT_classification/ has diverged from microsoft/MegaDetector-Classifier and carries hub-only fixes (security #628, GPU-crash #629); reconcile before removing. - PW_FT_detection/ has no dedicated repo yet; needs a destination. - demo/, requirements.txt, Dockerfile, Brewfile, archive/ left for review. Verified: mkdocs build --strict passes; hub sitemap 57 -> 27 (the ~30 API pages moved to the PW site, which grew 6 -> 40); homepage em-dash clean. --- MANIFEST.in | 3 - PW_Bioacoustics/README.md | 208 --- PW_Bioacoustics/__init__.py | 11 - PW_Bioacoustics/demo/README.md | 78 - PW_Bioacoustics/demo/bioacoustics_demo.ipynb | 1271 ----------------- PW_Bioacoustics/inference.py | 430 ------ PW_Bioacoustics/prepare_dataset.py | 365 ----- PW_Bioacoustics/template.yaml | 59 - PW_Bioacoustics/train.py | 380 ----- PytorchWildlife/__init__.py | 18 - PytorchWildlife/data/__init__.py | 3 - PytorchWildlife/data/bioacoustics/__init__.py | 12 - .../bioacoustics/bioacoustics_annotations.py | 367 ----- .../data/bioacoustics/bioacoustics_configs.py | 221 --- .../bioacoustics/bioacoustics_datasets.py | 499 ------- .../bioacoustics/bioacoustics_spectrograms.py | 304 ---- .../data/bioacoustics/bioacoustics_windows.py | 449 ------ PytorchWildlife/data/datasets.py | 204 --- PytorchWildlife/data/transforms.py | 156 -- PytorchWildlife/models/__init__.py | 3 - .../models/bioacoustics/__init__.py | 13 - .../models/bioacoustics/base_bioacoustics.py | 97 -- .../models/bioacoustics/resnet_classifier.py | 588 -------- .../models/classification/__init__.py | 3 - .../models/classification/base_classifier.py | 28 - .../classification/resnet_base/__init__.py | 5 - .../classification/resnet_base/amazon.py | 112 -- .../resnet_base/base_classifier.py | 190 --- .../resnet_base/custom_weights.py | 64 - .../classification/resnet_base/opossum.py | 70 - .../classification/resnet_base/serengeti.py | 82 -- .../models/classification/timm_base/DFNE.py | 56 - .../classification/timm_base/Deepfaune.py | 47 - .../classification/timm_base/__init__.py | 3 - .../timm_base/base_classifier.py | 210 --- PytorchWildlife/models/detection/__init__.py | 4 - .../models/detection/base_detector.py | 100 -- .../models/detection/localization/Herdnet.md | 26 - .../models/detection/localization/OWL_C.py | 307 ---- .../models/detection/localization/OWL_T.py | 301 ---- .../models/detection/localization/__init__.py | 3 - .../localization/animaloc/__init__.py | 14 - .../localization/animaloc/data/__init__.py | 17 - .../localization/animaloc/data/patches.py | 187 --- .../localization/animaloc/data/types.py | 128 -- .../localization/animaloc/eval/__init__.py | 17 - .../localization/animaloc/eval/lmds.py | 276 ---- .../localization/animaloc/eval/stitchers.py | 256 ---- .../models/detection/localization/dla.py | 590 -------- .../models/detection/localization/herdnet.py | 313 ---- .../models/detection/localization/model.py | 150 -- .../detection/localization/model_owl_c.py | 97 -- .../detection/localization/model_owl_t.py | 414 ------ .../detection/rtdetr_apache/__init__.py | 2 - .../rtdetr_apache/megadetectorv6_apache.py | 44 - .../rtdetr_apache/rtdetr_apache_base.py | 225 --- .../rtdetrv2_pytorch/__init__.py | 0 .../rtdetrv2_pytorch/configs/__init__.py | 0 .../configs/dataset/__init__.py | 0 .../dataset/megadetector_detection.yml | 3 - .../configs/rtdetrv2/__init__.py | 0 .../configs/rtdetrv2/include/__init__.py | 0 .../rtdetrv2/include/rtdetrv2_r50vd.yml | 83 -- .../rtdetrv2_r101vd_6x_megadetector.yml | 18 - .../rtdetrv2_r18vd_120e_megadetector.yml | 21 - .../rtdetrv2_pytorch/src/__init__.py | 6 - .../rtdetrv2_pytorch/src/backbone/__init__.py | 4 - .../rtdetrv2_pytorch/src/backbone/common.py | 86 -- .../rtdetrv2_pytorch/src/backbone/presnet.py | 244 ---- .../rtdetrv2_pytorch/src/core/__init__.py | 7 - .../rtdetrv2_pytorch/src/core/_config.py | 76 - .../rtdetrv2_pytorch/src/core/workspace.py | 171 --- .../rtdetrv2_pytorch/src/core/yaml_config.py | 38 - .../rtdetrv2_pytorch/src/core/yaml_utils.py | 97 -- .../rtdetrv2_pytorch/src/rtdetr/__init__.py | 8 - .../rtdetrv2_pytorch/src/rtdetr/box_ops.py | 22 - .../rtdetrv2_pytorch/src/rtdetr/denoising.py | 99 -- .../src/rtdetr/hybrid_encoder.py | 330 ----- .../rtdetrv2_pytorch/src/rtdetr/rtdetr.py | 44 - .../src/rtdetr/rtdetr_postprocessor.py | 93 -- .../src/rtdetr/rtdetrv2_decoder.py | 608 -------- .../rtdetrv2_pytorch/src/rtdetr/utils.py | 127 -- .../detection/ultralytics_based/Deepfaune.py | 43 - .../detection/ultralytics_based/__init__.py | 6 - .../ultralytics_based/megadetectorv5.py | 58 - .../ultralytics_based/megadetectorv6.py | 53 - .../megadetectorv6_distributed.py | 53 - .../ultralytics_based/yolov5_base.py | 183 --- .../ultralytics_based/yolov8_base.py | 220 --- .../ultralytics_based/yolov8_distributed.py | 232 --- .../models/detection/yolo_mit/__init__.py | 2 - .../detection/yolo_mit/megadetectorv6_mit.py | 44 - .../detection/yolo_mit/yolo/__init__.py | 15 - .../models/detection/yolo_mit/yolo/config.py | 168 --- .../detection/yolo_mit/yolo/model/__init__.py | 0 .../detection/yolo_mit/yolo/model/module.py | 414 ------ .../detection/yolo_mit/yolo/model/yolo.py | 180 --- .../detection/yolo_mit/yolo/tools/__init__.py | 0 .../yolo_mit/yolo/tools/data_augmentation.py | 55 - .../yolo_mit/yolo/tools/data_loader.py | 231 --- .../yolo/tools/dataset_preparation.py | 51 - .../detection/yolo_mit/yolo/utils/__init__.py | 0 .../yolo_mit/yolo/utils/bounding_box_utils.py | 197 --- .../yolo_mit/yolo/utils/dataset_utils.py | 116 -- .../yolo_mit/yolo/utils/model_utils.py | 29 - .../detection/yolo_mit/yolo_mit_base.py | 220 --- PytorchWildlife/utils/__init__.py | 2 - PytorchWildlife/utils/misc.py | 53 - PytorchWildlife/utils/post_process.py | 532 ------- docs-requirements.txt | 2 - docs/base/data/datasets.md | 3 - docs/base/data/transforms.md | 3 - .../models/classification/base_classifier.md | 3 - .../classification/resnet_base/amazon.md | 3 - .../resnet_base/base_classifier.md | 3 - .../resnet_base/custom_weights.md | 3 - .../classification/resnet_base/opossum.md | 3 - .../classification/resnet_base/serengeti.md | 3 - .../models/classification/timm_base/DFNE.md | 3 - .../classification/timm_base/Deepfaune.md | 3 - .../timm_base/base_classifier.md | 3 - docs/base/models/detection/base_detector.md | 3 - docs/base/models/detection/herdnet.md | 3 - .../herdnet/animaloc/data/patches.md | 3 - .../detection/herdnet/animaloc/data/types.md | 3 - .../detection/herdnet/animaloc/eval/lmds.md | 3 - .../herdnet/animaloc/eval/stitchers.md | 3 - docs/base/models/detection/herdnet/dla.md | 3 - docs/base/models/detection/herdnet/model.md | 3 - .../detection/ultralytics_based/Deepfaune.md | 3 - .../ultralytics_based/megadetectorv5.md | 3 - .../ultralytics_based/megadetectorv6.md | 3 - .../megadetectorv6_distributed.md | 3 - .../ultralytics_based/yolov5_base.md | 3 - .../ultralytics_based/yolov8_base.md | 3 - .../ultralytics_based/yolov8_distributed.md | 3 - docs/base/overview.md | 37 - docs/base/utils/misc.md | 1 - docs/base/utils/post_process.md | 1 - .../classification/overview.md | 10 - .../fine_tuning_modules/detection/overview.md | 10 - docs/fine_tuning_modules/overview.md | 10 - docs/index.md | 71 +- mkdocs.yml | 68 +- setup.py | 49 - version.txt | 1 - 146 files changed, 32 insertions(+), 15728 deletions(-) delete mode 100644 MANIFEST.in delete mode 100644 PW_Bioacoustics/README.md delete mode 100644 PW_Bioacoustics/__init__.py delete mode 100644 PW_Bioacoustics/demo/README.md delete mode 100644 PW_Bioacoustics/demo/bioacoustics_demo.ipynb delete mode 100644 PW_Bioacoustics/inference.py delete mode 100644 PW_Bioacoustics/prepare_dataset.py delete mode 100644 PW_Bioacoustics/template.yaml delete mode 100644 PW_Bioacoustics/train.py delete mode 100644 PytorchWildlife/__init__.py delete mode 100644 PytorchWildlife/data/__init__.py delete mode 100644 PytorchWildlife/data/bioacoustics/__init__.py delete mode 100644 PytorchWildlife/data/bioacoustics/bioacoustics_annotations.py delete mode 100644 PytorchWildlife/data/bioacoustics/bioacoustics_configs.py delete mode 100644 PytorchWildlife/data/bioacoustics/bioacoustics_datasets.py delete mode 100644 PytorchWildlife/data/bioacoustics/bioacoustics_spectrograms.py delete mode 100644 PytorchWildlife/data/bioacoustics/bioacoustics_windows.py delete mode 100644 PytorchWildlife/data/datasets.py delete mode 100644 PytorchWildlife/data/transforms.py delete mode 100644 PytorchWildlife/models/__init__.py delete mode 100644 PytorchWildlife/models/bioacoustics/__init__.py delete mode 100644 PytorchWildlife/models/bioacoustics/base_bioacoustics.py delete mode 100644 PytorchWildlife/models/bioacoustics/resnet_classifier.py delete mode 100644 PytorchWildlife/models/classification/__init__.py delete mode 100644 PytorchWildlife/models/classification/base_classifier.py delete mode 100644 PytorchWildlife/models/classification/resnet_base/__init__.py delete mode 100644 PytorchWildlife/models/classification/resnet_base/amazon.py delete mode 100644 PytorchWildlife/models/classification/resnet_base/base_classifier.py delete mode 100644 PytorchWildlife/models/classification/resnet_base/custom_weights.py delete mode 100644 PytorchWildlife/models/classification/resnet_base/opossum.py delete mode 100644 PytorchWildlife/models/classification/resnet_base/serengeti.py delete mode 100644 PytorchWildlife/models/classification/timm_base/DFNE.py delete mode 100644 PytorchWildlife/models/classification/timm_base/Deepfaune.py delete mode 100644 PytorchWildlife/models/classification/timm_base/__init__.py delete mode 100644 PytorchWildlife/models/classification/timm_base/base_classifier.py delete mode 100644 PytorchWildlife/models/detection/__init__.py delete mode 100644 PytorchWildlife/models/detection/base_detector.py delete mode 100644 PytorchWildlife/models/detection/localization/Herdnet.md delete mode 100644 PytorchWildlife/models/detection/localization/OWL_C.py delete mode 100644 PytorchWildlife/models/detection/localization/OWL_T.py delete mode 100644 PytorchWildlife/models/detection/localization/__init__.py delete mode 100644 PytorchWildlife/models/detection/localization/animaloc/__init__.py delete mode 100644 PytorchWildlife/models/detection/localization/animaloc/data/__init__.py delete mode 100644 PytorchWildlife/models/detection/localization/animaloc/data/patches.py delete mode 100644 PytorchWildlife/models/detection/localization/animaloc/data/types.py delete mode 100644 PytorchWildlife/models/detection/localization/animaloc/eval/__init__.py delete mode 100644 PytorchWildlife/models/detection/localization/animaloc/eval/lmds.py delete mode 100644 PytorchWildlife/models/detection/localization/animaloc/eval/stitchers.py delete mode 100644 PytorchWildlife/models/detection/localization/dla.py delete mode 100644 PytorchWildlife/models/detection/localization/herdnet.py delete mode 100644 PytorchWildlife/models/detection/localization/model.py delete mode 100644 PytorchWildlife/models/detection/localization/model_owl_c.py delete mode 100644 PytorchWildlife/models/detection/localization/model_owl_t.py delete mode 100644 PytorchWildlife/models/detection/rtdetr_apache/__init__.py delete mode 100644 PytorchWildlife/models/detection/rtdetr_apache/megadetectorv6_apache.py delete mode 100644 PytorchWildlife/models/detection/rtdetr_apache/rtdetr_apache_base.py delete mode 100644 PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/__init__.py delete mode 100644 PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/__init__.py delete mode 100644 PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/dataset/__init__.py delete mode 100644 PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/dataset/megadetector_detection.yml delete mode 100644 PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/rtdetrv2/__init__.py delete mode 100644 PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/rtdetrv2/include/__init__.py delete mode 100644 PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/rtdetrv2/include/rtdetrv2_r50vd.yml delete mode 100644 PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/rtdetrv2/rtdetrv2_r101vd_6x_megadetector.yml delete mode 100644 PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/rtdetrv2/rtdetrv2_r18vd_120e_megadetector.yml delete mode 100644 PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/__init__.py delete mode 100644 PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/backbone/__init__.py delete mode 100644 PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/backbone/common.py delete mode 100644 PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/backbone/presnet.py delete mode 100644 PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/__init__.py delete mode 100644 PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/_config.py delete mode 100644 PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/workspace.py delete mode 100644 PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/yaml_config.py delete mode 100644 PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/yaml_utils.py delete mode 100644 PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/__init__.py delete mode 100644 PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/box_ops.py delete mode 100644 PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/denoising.py delete mode 100644 PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/hybrid_encoder.py delete mode 100644 PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/rtdetr.py delete mode 100644 PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/rtdetr_postprocessor.py delete mode 100644 PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/rtdetrv2_decoder.py delete mode 100644 PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/utils.py delete mode 100644 PytorchWildlife/models/detection/ultralytics_based/Deepfaune.py delete mode 100644 PytorchWildlife/models/detection/ultralytics_based/__init__.py delete mode 100644 PytorchWildlife/models/detection/ultralytics_based/megadetectorv5.py delete mode 100644 PytorchWildlife/models/detection/ultralytics_based/megadetectorv6.py delete mode 100644 PytorchWildlife/models/detection/ultralytics_based/megadetectorv6_distributed.py delete mode 100644 PytorchWildlife/models/detection/ultralytics_based/yolov5_base.py delete mode 100644 PytorchWildlife/models/detection/ultralytics_based/yolov8_base.py delete mode 100644 PytorchWildlife/models/detection/ultralytics_based/yolov8_distributed.py delete mode 100644 PytorchWildlife/models/detection/yolo_mit/__init__.py delete mode 100644 PytorchWildlife/models/detection/yolo_mit/megadetectorv6_mit.py delete mode 100644 PytorchWildlife/models/detection/yolo_mit/yolo/__init__.py delete mode 100644 PytorchWildlife/models/detection/yolo_mit/yolo/config.py delete mode 100644 PytorchWildlife/models/detection/yolo_mit/yolo/model/__init__.py delete mode 100644 PytorchWildlife/models/detection/yolo_mit/yolo/model/module.py delete mode 100644 PytorchWildlife/models/detection/yolo_mit/yolo/model/yolo.py delete mode 100644 PytorchWildlife/models/detection/yolo_mit/yolo/tools/__init__.py delete mode 100644 PytorchWildlife/models/detection/yolo_mit/yolo/tools/data_augmentation.py delete mode 100644 PytorchWildlife/models/detection/yolo_mit/yolo/tools/data_loader.py delete mode 100644 PytorchWildlife/models/detection/yolo_mit/yolo/tools/dataset_preparation.py delete mode 100644 PytorchWildlife/models/detection/yolo_mit/yolo/utils/__init__.py delete mode 100644 PytorchWildlife/models/detection/yolo_mit/yolo/utils/bounding_box_utils.py delete mode 100644 PytorchWildlife/models/detection/yolo_mit/yolo/utils/dataset_utils.py delete mode 100644 PytorchWildlife/models/detection/yolo_mit/yolo/utils/model_utils.py delete mode 100644 PytorchWildlife/models/detection/yolo_mit/yolo_mit_base.py delete mode 100644 PytorchWildlife/utils/__init__.py delete mode 100644 PytorchWildlife/utils/misc.py delete mode 100644 PytorchWildlife/utils/post_process.py delete mode 100644 docs/base/data/datasets.md delete mode 100644 docs/base/data/transforms.md delete mode 100644 docs/base/models/classification/base_classifier.md delete mode 100644 docs/base/models/classification/resnet_base/amazon.md delete mode 100644 docs/base/models/classification/resnet_base/base_classifier.md delete mode 100644 docs/base/models/classification/resnet_base/custom_weights.md delete mode 100644 docs/base/models/classification/resnet_base/opossum.md delete mode 100644 docs/base/models/classification/resnet_base/serengeti.md delete mode 100644 docs/base/models/classification/timm_base/DFNE.md delete mode 100644 docs/base/models/classification/timm_base/Deepfaune.md delete mode 100644 docs/base/models/classification/timm_base/base_classifier.md delete mode 100644 docs/base/models/detection/base_detector.md delete mode 100644 docs/base/models/detection/herdnet.md delete mode 100644 docs/base/models/detection/herdnet/animaloc/data/patches.md delete mode 100644 docs/base/models/detection/herdnet/animaloc/data/types.md delete mode 100644 docs/base/models/detection/herdnet/animaloc/eval/lmds.md delete mode 100644 docs/base/models/detection/herdnet/animaloc/eval/stitchers.md delete mode 100644 docs/base/models/detection/herdnet/dla.md delete mode 100644 docs/base/models/detection/herdnet/model.md delete mode 100644 docs/base/models/detection/ultralytics_based/Deepfaune.md delete mode 100644 docs/base/models/detection/ultralytics_based/megadetectorv5.md delete mode 100644 docs/base/models/detection/ultralytics_based/megadetectorv6.md delete mode 100644 docs/base/models/detection/ultralytics_based/megadetectorv6_distributed.md delete mode 100644 docs/base/models/detection/ultralytics_based/yolov5_base.md delete mode 100644 docs/base/models/detection/ultralytics_based/yolov8_base.md delete mode 100644 docs/base/models/detection/ultralytics_based/yolov8_distributed.md delete mode 100644 docs/base/overview.md delete mode 100644 docs/base/utils/misc.md delete mode 100644 docs/base/utils/post_process.md delete mode 100644 docs/fine_tuning_modules/classification/overview.md delete mode 100644 docs/fine_tuning_modules/detection/overview.md delete mode 100644 docs/fine_tuning_modules/overview.md delete mode 100644 setup.py delete mode 100644 version.txt diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index 31f8ae988..000000000 --- a/MANIFEST.in +++ /dev/null @@ -1,3 +0,0 @@ -global-include *.yml -include version.txt -include README.md \ No newline at end of file diff --git a/PW_Bioacoustics/README.md b/PW_Bioacoustics/README.md deleted file mode 100644 index 864d53d52..000000000 --- a/PW_Bioacoustics/README.md +++ /dev/null @@ -1,208 +0,0 @@ -# PW_Bioacoustics - -Companion module for bioacoustics experiments using the PytorchWildlife core library. - -## Overview - -This module provides CLI scripts for training, inference, and dataset preparation for bioacoustics classification. The core functionality (models, datasets, utilities) is provided by the `PytorchWildlife` library. - -## Quick Start - -### 1. Installation - -```bash -# Install PyTorch-Wildlife with bioacoustics dependencies -pip install -e . -pip install librosa soundfile pyyaml torchmetrics -``` - -### 2. Configuration - -Create a YAML config file for your domain (see `template.yaml` as reference): - -```yaml -name: "my_domain" -datasets: - - "dataset_name_1" - -class_names: - 0: "noise" - 1: "target_class" - -paths: - data_root: "${DATA_ROOT}" - output_root: "${OUTPUT_ROOT}" - spectrograms_dir: "${OUTPUT_ROOT}/mel_spectrograms" - annotations_file: "annotations.json" - -audio: - sample_rate: 48000 - window_size_sec: 5.0 - overlap_sec: 4.0 - -spectrogram: - n_fft: 2048 - hop_length: 512 - n_mels: 224 - -training: - batch_size: 32 - lr: 0.0001 - epochs: 50 - backbone: "resnet18" -``` - -### 3. Prepare Dataset - -```bash -# Full pipeline (stats, windows, spectrograms, splits) -python prepare_dataset.py --config config/my_domain.yaml - -# Or run specific steps -python prepare_dataset.py --config config/my_domain.yaml --steps windows spectrograms -``` - -### 4. Train Model - -```bash -# Binary classification -python train.py --config config/my_domain.yaml \ - --train_csv train_split.csv \ - --val_csv val_split.csv \ - --test_csv test_split.csv - -# Multiclass classification -python train.py --config config/my_domain.yaml \ - --train_csv train_split.csv \ - --test_csv test_split.csv \ - --num_classes 4 -``` - -### 5. Run Inference - -```bash -python inference.py --config config/my_domain.yaml \ - --checkpoint model.ckpt \ - --audios_source /path/to/audio/folder \ - --dataset my_inference -``` - -## Demo - -The recommended way to get started is the **end-to-end demo notebook** at [`demo/bioacoustics_demo.ipynb`](demo/bioacoustics_demo.ipynb). It walks through the full pipeline using real bird recordings from the [PteroSet](https://zenodo.org/records/19137071) dataset: - -1. **Data Exploration** — annotation counts, species distribution -2. **Inference** — download `MD_AudioBirds_V1.onnx` from Zenodo, run ONNX inference on all 5 recordings, visualise predictions vs. ground-truth -3. **Train** - - **3.0 Build COCO Annotations** — `PteroSetReader` converts Raven Pro TSV → COCO-like JSON - - **3.1 Binary Classification** — AVEVOC vs. noise training with `ResNetClassifier` - - **3.2 Multiclass Classification** — top-4 species vs. noise, species analysis bar chart, trains separate model - -See [`demo/README.md`](demo/README.md) for setup instructions and expected runtimes. - -## Module Structure - -``` -PW_Bioacoustics/ -├── __init__.py -├── train.py # Training CLI script -├── inference.py # Inference CLI script -├── prepare_dataset.py # Dataset preparation pipeline -├── template.yaml # Template configuration file -└── demo/ - ├── bioacoustics_demo.ipynb # End-to-end demo notebook - ├── README.md - ├── data/ # Sample audio + annotations - └── config/ # Demo YAML configs -``` - -The CLI scripts (`train.py`, `prepare_dataset.py`, `inference.py`) can be used standalone or imported as modules (as the demo notebook does). `template.yaml` documents all configuration parameters. - -## Core Library (PytorchWildlife) - -This module uses the following components from `PytorchWildlife`: - -### Models (`PytorchWildlife.models.bioacoustics`) -- `ResNetClassifier`: PyTorch Lightning module for spectrogram classification (binary and multiclass) -- `BaseBioacousticsClassifier`: Base class for bioacoustics models -- `load_model_from_checkpoint()`: Load a trained model from a `.ckpt` file for inference - -### Datasets (`PytorchWildlife.data.bioacoustics.bioacoustics_datasets`) -- `BioacousticsDataset`: Training dataset for loading spectrograms from `.npy` files -- `BioacousticsInferenceDataset`: Inference dataset (no labels required) -- `SpectrogramAugmentations`: SpecAugment-style augmentations (time/frequency masking) -- `MixUpCollator`: Batch-level MixUp augmentation -- `PerSampleNormalize`, `ResizeTo`: Spectrogram transforms - -### Annotations (`PytorchWildlife.data.bioacoustics.bioacoustics_annotations`) -- `BaseReader`: Abstract base class for converting annotation formats to COCO-like JSON -- `AnnotationCreator`: Builds COCO-like annotation files from `BaseReader` subclasses - -### Configuration (`PytorchWildlife.data.bioacoustics.bioacoustics_configs`) -- `DomainConfig`: Nested dataclass for domain settings (paths, audio, spectrogram, training, splits) -- `load_config()`: YAML configuration loader with environment variable expansion -- `save_config()`: Serialize a `DomainConfig` back to YAML - -### Windows (`PytorchWildlife.data.bioacoustics.bioacoustics_windows`) -- `build_windows()`: Generate training windows from annotations (sliding, balanced, or customized strategies) -- `build_inference_windows()`: Generate sliding windows for inference on raw audio files - -### Spectrograms (`PytorchWildlife.data.bioacoustics.bioacoustics_spectrograms`) -- `compute_mel_spectrograms_gpu()`: GPU-accelerated mel spectrogram computation, saves `.npy` files - -## Projects Using PytorchWildlife Bioacoustics - -- **[PteroSet](https://github.com/microsoft/PteroSet)** — A machine learning pipeline for detecting and classifying tropical bird vocalizations from passive acoustic monitoring recordings. Built on the PytorchWildlife bioacoustics core library, it demonstrates the full workflow: COCO annotation creation from Raven Pro labels, mel spectrogram preparation, binary ResNet training, and leave-one-project-out cross-validation. -- **[CookInlet_Belugas](https://github.com/microsoft/CookInlet_Belugas)** — An end-to-end passive acoustic monitoring pipeline for endangered Cook Inlet beluga whales. Implements spectrogram generation, a two-stage deep learning architecture for cetacean signal detection and multi-species classification (beluga, humpback, killer whale), and an active-learning loop for domain adaptation to novel soundscapes. - -## Training Arguments - -``` ---config # YAML config file (recommended) ---train_csv # Path to training CSV ---val_csv # Path to validation CSV (optional) ---test_csv # Path to test CSV ---num_classes # 2 for binary, >2 for multiclass (default: 2) ---backbone # resnet18, resnet34, resnet50 (default: resnet18) ---batch_size # Batch size (default: 32) ---lr # Learning rate (default: 1e-4) ---epochs # Number of epochs (default: 5) ---use_specaug # Enable spectrogram augmentations ---normalize # Normalize spectrograms (default: True) ---freeze_backbone # none, all, early, layer1, layer2, layer3 -``` - -## Inference Arguments - -``` ---config # YAML config file (recommended) ---audios_source # Audio folder, JSON, or CSV with windows ---checkpoint # Model checkpoint file (.ckpt) ---num_classes # Number of classes (default: 2) ---class_names # Class names for output columns ---window_size_sec # Window size in seconds (default: 5.0) ---overlap_sec # Overlap between windows (default: 4.0) ---sample_rate # Target sample rate (default: 48000) ---batch_size # Batch size (default: 64) ---temperature # Temperature scaling (default: 1.0) ---dataset # Dataset name for output directory -``` - -## Output Formats - -### Binary Classification -Output CSV columns: `audio`, `start(s)`, `end(s)`, `prediction`, `probability`, `confidence` - -### Multiclass Classification -Output CSV columns: `audio`, `start(s)`, `end(s)`, `prediction`, `{ClassName}_prob` for each class - -## Requirements - -- Python 3.9+ -- PyTorch 2.0+ -- PyTorch Lightning -- librosa, torchaudio -- pandas, numpy -- Weights & Biases (optional, for experiment tracking) - -See the main `requirements.txt` for full dependencies. diff --git a/PW_Bioacoustics/__init__.py b/PW_Bioacoustics/__init__.py deleted file mode 100644 index 93d0ab081..000000000 --- a/PW_Bioacoustics/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -""" -PW_Bioacoustics - Companion module for bioacoustics experiments. - -This module provides CLI scripts for training, inference, and dataset -preparation using the PytorchWildlife bioacoustics core library. -""" - -__version__ = "0.1.0" diff --git a/PW_Bioacoustics/demo/README.md b/PW_Bioacoustics/demo/README.md deleted file mode 100644 index f17a38de8..000000000 --- a/PW_Bioacoustics/demo/README.md +++ /dev/null @@ -1,78 +0,0 @@ -# PW_Bioacoustics Demo - -End-to-end demo using 5 real bird recordings from the -[PteroSet](https://zenodo.org/records/19137071) dataset -(project PPA4, Putumayo, Colombia — CC BY 4.0). - -## Files - -``` -demo/ -├── bioacoustics_demo.ipynb ← main demo notebook -├── data/ -│ ├── audios/ ← 5 PteroSet WAV files (192 kHz, auto-downsampled to 48 kHz) -│ └── labels/ ← Raven Pro selection tables (.Table.1.selections.txt) -└── output/ ← generated at runtime - ├── spectrograms/ ← shared .npy mel spectrogram files - ├── inference/ ← ONNX inference predictions and spectrograms - ├── binary/ ← binary model, splits, logs, checkpoints - └── multiclass/ ← multiclass model, splits, logs, checkpoints -``` - -## Running the notebook - -```bash -# From the PW_Bioacoustics/ directory -cd PW_Bioacoustics/demo -jupyter notebook bioacoustics_demo.ipynb -``` - -The notebook must be run from `PW_Bioacoustics/demo/` — the first cell asserts this. - -## What the notebook demonstrates - -| Section | Description | -|---------|-------------| -| 0. Setup | Imports, path configuration, `%matplotlib inline` | -| 1. Data Exploration | Annotation counts, species distribution (derived automatically from data) | -| 2. Inference | Download `MD_AudioBirds_V1.onnx` from Zenodo, run ONNX inference on all 5 recordings, visualise predictions vs. ground-truth | -| 3. Train | — | -| 3.0 Build COCO Annotations | `PteroSetReader` converts Raven Pro TSV → COCO-like JSON (binary + multiclass) | -| 3.1 Binary Classification | AVEVOC vs. noise — `build_windows`, spectrograms, train, evaluate | -| 3.2 Multiclass Classification | Top-4 species vs. noise — species analysis bar chart, reuses spectrograms, trains separate model | - -Every code cell is preceded by a markdown cell explaining what it does and its expected output. - -## Pre-trained model - -Section 2 downloads the pre-trained **MD_AudioBirds_V1** ONNX model from Zenodo: - -``` -https://zenodo.org/records/18177050/files/MD_AudioBirds_V1.onnx?download=1 -``` - -The file is cached to `output/inference/MD_AudioBirds_V1.onnx` and skipped on subsequent runs. - -## Using your own data - -Swap in your own recordings by replacing the files in `data/audios/` and `data/labels/` -and updating `PPA4_FILES` in the Setup cell. For a different annotation format, subclass `BaseReader` -following the `PteroSetReader` pattern. - -## Expected runtime - -| Environment | Binary training | Multiclass training | -|-------------|----------------|---------------------| -| GPU (A100) | ~2 min | ~2 min | -| CPU (16-core) | ~20–40 min | ~20–40 min | - -The ONNX inference section (Section 2) runs in under a minute on CPU. - -Reduce `epochs` in the config cells to speed up the demo. - -## Data citation - -> Ruiz, D., Ulloa, J. S., Miao, Z., Betancourt, N., Hernández, A., Demuro, B., Barona Cortés, E., -> Toro Gómez, M. P., Mendoza-Henao, A. M., Sierra-Ricaurte, A. F., Pérez-Peña, S. C., Dodhia, R., -> Arbelaez, P., & Lavista Ferres, J. M. (2026). PteroSet [Data set]. Zenodo. -> https://doi.org/10.5281/zenodo.19137071 diff --git a/PW_Bioacoustics/demo/bioacoustics_demo.ipynb b/PW_Bioacoustics/demo/bioacoustics_demo.ipynb deleted file mode 100644 index 40d0910e8..000000000 --- a/PW_Bioacoustics/demo/bioacoustics_demo.ipynb +++ /dev/null @@ -1,1271 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# PW_Bioacoustics Demo\n", - "\n", - "End-to-end walkthrough of the PW_Bioacoustics pipeline using real bird recordings from the\n", - "[PteroSet](https://zenodo.org/records/19137071) dataset." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 0. Setup" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Set up the Python environment, verify the working directory, and configure all output paths." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%matplotlib inline\n", - "\n", - "import os\n", - "import sys\n", - "import json\n", - "import shutil\n", - "import warnings\n", - "from pathlib import Path\n", - "\n", - "import numpy as np\n", - "import pandas as pd\n", - "import matplotlib.pyplot as plt\n", - "import matplotlib.patches as mpatches\n", - "import soundfile as sf\n", - "import torchaudio\n", - "import torch\n", - "\n", - "warnings.filterwarnings(\"ignore\")\n", - "\n", - "# ── Verify working directory ────────────────────────────────────────────\n", - "DEMO_DIR = Path(os.getcwd())\n", - "if DEMO_DIR.name != \"demo\":\n", - " raise RuntimeError(\n", - " f\"Run this notebook from PW_Bioacoustics/demo/. Currently in: {DEMO_DIR}\\n\"\n", - " \" cd PW_Bioacoustics/demo && jupyter notebook bioacoustics_demo.ipynb\"\n", - " )\n", - "\n", - "PW_BIO_DIR = DEMO_DIR.parent # PW_Bioacoustics/\n", - "CAMERATRAP_DIR = PW_BIO_DIR.parent # CameraTraps/\n", - "\n", - "for p in [str(PW_BIO_DIR), str(CAMERATRAP_DIR)]:\n", - " if p not in sys.path:\n", - " sys.path.insert(0, p)\n", - "\n", - "# ── Directory layout ────────────────────────────────────────────\n", - "DATA_DIR = DEMO_DIR / \"data\"\n", - "AUDIOS_DIR = DATA_DIR / \"audios\"\n", - "LABELS_DIR = DATA_DIR / \"labels\"\n", - "OUTPUT_DIR = DEMO_DIR / \"output\"\n", - "SPEC_DIR = OUTPUT_DIR / \"spectrograms\" # shared by both modes\n", - "BINARY_DIR = OUTPUT_DIR / \"binary\"\n", - "MULTICLASS_DIR = OUTPUT_DIR / \"multiclass\"\n", - "\n", - "for d in [AUDIOS_DIR, LABELS_DIR, SPEC_DIR, BINARY_DIR, MULTICLASS_DIR]:\n", - " d.mkdir(parents=True, exist_ok=True)\n", - "\n", - "print(\"Environment ready\")\n", - "print(f\" PW_Bioacoustics : {PW_BIO_DIR}\")\n", - "print(f\" Data : {DATA_DIR}\")\n", - "print(f\" Output : {OUTPUT_DIR}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 1. Data Exploration" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Load metadata for the five PteroSet recordings (project PPA4, Puerto Asís, Putumayo, Colombia).\n", - "Each `.Table.1.selections.txt` file is a Raven Pro annotation table with `Begin Time (s)`,\n", - "`End Time (s)`, and species `Determination` columns.\n", - "\n", - "**Expected output:** a summary table with file duration, sample rate, and annotation counts." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "PPA4_FILES = [\n", - " \"G021_timelapse_20250623\",\n", - " \"G021_timelapse_20250622\",\n", - " \"G040_timelapse_20250625\",\n", - " \"G040_timelapse_20250629\",\n", - " \"G010_timelapse_20250629\",\n", - "]\n", - "\n", - "rows = []\n", - "for stem in PPA4_FILES:\n", - " txt_path = LABELS_DIR / f\"{stem}.Table.1.selections.txt\"\n", - " wav_path = AUDIOS_DIR / f\"{stem}.wav\"\n", - " df = pd.read_csv(txt_path, delimiter=\"\\t\")\n", - " info = sf.info(str(wav_path))\n", - " with_species = df[\"Determination\"].apply(\n", - " lambda x: isinstance(x, str) and x.strip() != \"\"\n", - " ).sum()\n", - " rows.append({\n", - " \"file\": stem,\n", - " \"duration (min)\": round(info.duration / 60, 1),\n", - " \"sample_rate (kHz)\": info.samplerate // 1000,\n", - " \"total_annotations\": len(df),\n", - " \"with_species_id\": int(with_species),\n", - " })\n", - "\n", - "summary = pd.DataFrame(rows)\n", - "print(summary.to_string(index=False))\n", - "print(f\"\\nTotal annotations : {summary['total_annotations'].sum()}\")\n", - "print(f\"Total duration : {summary['duration (min)'].sum():.0f} min\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2. Inference" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Run inference on all five recordings using our pre-trained model for birds binary classification (`MD_AudioBirds_V1`). This section does **not** require training — it demonstrates how to\n", - "apply our model to new audio data." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "1. Download the pre-trained `MD_AudioBirds_V1.onnx` model from Zenodo and set up the\n", - "inference output directories. The model file is cached locally and skipped on subsequent runs.\n", - "\n", - " **Expected output:** confirmation that the model file exists or has been downloaded." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import urllib.request\n", - "from PytorchWildlife.data.bioacoustics.bioacoustics_windows import build_inference_windows\n", - "from PytorchWildlife.data.bioacoustics.bioacoustics_spectrograms import compute_mel_spectrograms_gpu\n", - "from PytorchWildlife.data.bioacoustics.bioacoustics_datasets import BioacousticsInferenceDataset\n", - "from torch.utils.data import DataLoader\n", - "\n", - "MODEL_URL = \"https://zenodo.org/records/18177050/files/MD_AudioBirds_V1.onnx?download=1\"\n", - "INFER_DIR = OUTPUT_DIR / \"inference\"\n", - "INFER_SPEC = INFER_DIR / \"spectrograms\"\n", - "INFER_DIR.mkdir(parents=True, exist_ok=True)\n", - "INFER_SPEC.mkdir(parents=True, exist_ok=True)\n", - "\n", - "MODEL_PATH = INFER_DIR / \"MD_AudioBirds_V1.onnx\"\n", - "if not MODEL_PATH.exists():\n", - " print(\"Downloading MD_AudioBirds_V1.onnx from Zenodo...\")\n", - " urllib.request.urlretrieve(MODEL_URL, str(MODEL_PATH))\n", - " print(f\"Model saved to: {MODEL_PATH}\")\n", - "else:\n", - " print(f\"Model already cached: {MODEL_PATH}\")\n", - "\n", - "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", - "SR = 48000\n", - "WIN_SEC = 5.0\n", - "OVL_SEC = 4.0\n", - "print(f\"Device: {DEVICE}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "2. Build sliding inference windows over all five audio files and compute their mel spectrograms.\n", - "Windows already computed on disk are skipped automatically.\n", - "\n", - " **Expected output:** window count per file, spectrogram progress bars, and confirmation of saved files." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "all_wav_files = [str(AUDIOS_DIR / f\"{stem}.wav\") for stem in PPA4_FILES]\n", - "\n", - "# Build sliding windows for all 5 files\n", - "infer_windows = build_inference_windows(\n", - " audios_source=all_wav_files,\n", - " window_size_sec=WIN_SEC,\n", - " overlap_sec=OVL_SEC,\n", - " sample_rate=SR,\n", - ")\n", - "print(f\"Total inference windows: {len(infer_windows)}\")\n", - "\n", - "# Compute mel spectrograms (skips existing files automatically)\n", - "compute_mel_spectrograms_gpu(\n", - " windows=infer_windows,\n", - " sample_rate=SR,\n", - " n_fft=2048, hop_length=512, n_mels=224, top_db=80.0,\n", - " spectrograms_path=str(INFER_SPEC),\n", - " save_npy=True,\n", - " fill_highfreq=False,\n", - ")\n", - "\n", - "# Build DataFrame with full spectrogram paths\n", - "infer_df = pd.DataFrame(infer_windows)\n", - "infer_df[\"sound_stem\"] = infer_df[\"sound_path\"].apply(\n", - " lambda p: os.path.splitext(os.path.basename(p))[0]\n", - ")\n", - "infer_df[\"spec_name\"] = infer_df.apply(\n", - " lambda r: str(INFER_SPEC / f\"{r['sound_stem']}_{r['start']}_{r['end']}.npy\"), axis=1\n", - ")\n", - "print(f\"Spectrogram files ready: {len(infer_df)} windows across {infer_df['sound_path'].nunique()} files\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "3. Load the ONNX model with `onnxruntime` and run batch inference on all spectrogram windows.\n", - "\n", - " **Expected output:** model input/output shapes, per-batch progress, and final output shape." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import onnxruntime as ort\n", - "\n", - "TARGET_SIZE = [224, 469]\n", - "\n", - "infer_ds = BioacousticsInferenceDataset(\n", - " dataframe=infer_df, x_col=\"spec_name\",\n", - " target_size=TARGET_SIZE, normalize=False,\n", - ")\n", - "infer_dl = DataLoader(infer_ds, batch_size=32, shuffle=False, num_workers=0)\n", - "\n", - "# Load ONNX session\n", - "sess_opts = ort.SessionOptions()\n", - "sess_opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL\n", - "providers = ([\"CUDAExecutionProvider\", \"CPUExecutionProvider\"]\n", - " if DEVICE == \"cuda\" else [\"CPUExecutionProvider\"])\n", - "ort_session = ort.InferenceSession(str(MODEL_PATH), sess_opts, providers=providers)\n", - "input_name = ort_session.get_inputs()[0].name\n", - "output_name = ort_session.get_outputs()[0].name\n", - "\n", - "print(f\"Input : {input_name} shape={ort_session.get_inputs()[0].shape}\")\n", - "print(f\"Outputs: {[o.name for o in ort_session.get_outputs()]}\")\n", - "\n", - "# Run inference\n", - "all_probs = []\n", - "for batch in infer_dl:\n", - " x, _ = batch\n", - " x_np = x.numpy().astype(np.float32)\n", - " logits = ort_session.run([output_name], {input_name: x_np})[0]\n", - " probs = 1.0 / (1.0 + np.exp(-logits.reshape(-1, 1)))\n", - " all_probs.append(probs)\n", - "\n", - "all_probs = np.concatenate(all_probs, axis=0)\n", - "print(f\"\\nInference complete — output shape: {all_probs.shape}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "4. Attach the predicted probabilities to the inference DataFrame, save predictions to CSV,\n", - "and display the first few rows.\n", - "\n", - " **Expected output:** a CSV saved to `output/inference/onnx_predictions.csv` and a table preview." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "infer_df[\"start_s\"] = infer_df[\"start\"] / SR\n", - "infer_df[\"end_s\"] = infer_df[\"end\"] / SR\n", - "\n", - "if all_probs.shape[1] == 1:\n", - " infer_df[\"probability\"] = all_probs[:, 0]\n", - " infer_df[\"prediction\"] = (infer_df[\"probability\"] > 0.5).astype(int)\n", - "else:\n", - " n_cls = all_probs.shape[1]\n", - " for i in range(n_cls):\n", - " infer_df[f\"prob_class_{i}\"] = all_probs[:, i]\n", - " infer_df[\"prediction\"] = all_probs.argmax(axis=1)\n", - "\n", - "results_csv = str(INFER_DIR / \"onnx_predictions.csv\")\n", - "infer_df.to_csv(results_csv, index=False)\n", - "print(f\"Predictions saved to: {results_csv}\")\n", - "infer_df.head()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "5. Visualise the predicted bird-call probability over time for\n", - "each of the five audio files. Red shading marks the ground-truth annotation intervals.\n", - "\n", - " **Expected output:** a 5-panel figure — one subplot per recording." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "is_binary_onnx = \"probability\" in infer_df.columns\n", - "\n", - "fig, axes = plt.subplots(len(PPA4_FILES), 1,\n", - " figsize=(14, 3 * len(PPA4_FILES)), sharex=False)\n", - "\n", - "for ax, stem in zip(axes, PPA4_FILES):\n", - " wav_path = str(AUDIOS_DIR / f\"{stem}.wav\")\n", - " file_df = infer_df[infer_df[\"sound_path\"] == wav_path].copy()\n", - " t_mid = (file_df[\"start_s\"] + file_df[\"end_s\"]) / 2\n", - "\n", - " if is_binary_onnx:\n", - " ax.plot(t_mid, file_df[\"probability\"], color=\"#1976D2\", lw=0.8)\n", - " ax.axhline(0.5, color=\"grey\", lw=0.8, ls=\"--\", label=\"threshold\")\n", - " ax.set_ylabel(\"P(bird)\")\n", - " else:\n", - " prob_cols = [c for c in infer_df.columns if c.startswith(\"prob_class_\")]\n", - " for i, col in enumerate(prob_cols[1:], start=1):\n", - " ax.plot(t_mid, file_df[col], lw=0.8, label=f\"class {i}\")\n", - " ax.legend(fontsize=7, ncol=3)\n", - " ax.set_ylabel(\"Probability\")\n", - "\n", - " # Ground-truth annotation spans\n", - " txt_path = LABELS_DIR / f\"{stem}.Table.1.selections.txt\"\n", - " if txt_path.exists():\n", - " gt = pd.read_csv(txt_path, delimiter=\"\\t\")\n", - " for _, row in gt.iterrows():\n", - " ax.axvspan(row[\"Begin Time (s)\"], row[\"End Time (s)\"],\n", - " alpha=0.18, color=\"red\")\n", - "\n", - " ax.set_ylim(-0.05, 1.05)\n", - " ax.set_title(stem)\n", - " ax.set_xlabel(\"Time (s)\")\n", - "\n", - "plt.suptitle(\n", - " \"ONNX Inference — MD_AudioBirds_V1 (red = ground-truth annotations)\",\n", - " y=1.01, fontsize=11,\n", - ")\n", - "plt.tight_layout()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 3. Train" - ] - }, - { - "cell_type": "markdown", - "id": "e580007f", - "metadata": {}, - "source": [ - "### Build COCO Annotations" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "`PteroSetReader` follows the same `BaseReader` pattern as `data_reader.py` in the source dataset.\n", - "It converts Raven Pro selection tables to the COCO-like JSON that `build_windows()` expects.\n", - "\n", - "- **Binary mode** — every annotated event → `category_id=1` (`AVEVOC`)\n", - "- **Multiclass mode** — only the top-4 species → `category_id` 1–4; unlabeled events are skipped\n", - " and will fall into background windows during `build_windows()`" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from PytorchWildlife.data.bioacoustics.bioacoustics_annotations import BaseReader, AnnotationCreator\n", - "\n", - "class PteroSetReader(BaseReader):\n", - " \"\"\"Converts PteroSet Raven Pro annotations to COCO-like JSON.\n", - "\n", - " Parameters\n", - " ----------\n", - " demo_data_dir : str or Path\n", - " Path to demo/data/ (must contain audios/ and labels/ sub-directories).\n", - " mode : {\"binary\", \"multiclass\"}\n", - " - binary : every AVEVOC event -> category_id=1\n", - " - multiclass : top-4 species (CACCEL, CYAVIO, PSAANG, RAMCAR) -> category_id 1-4\n", - " \"\"\"\n", - "\n", - " TOP_SPECIES = [\"CACCEL\", \"CYAVIO\", \"PSAANG\", \"RAMCAR\"] # category_id 1-4\n", - "\n", - " def __init__(self, demo_data_dir, mode=\"binary\"):\n", - " super().__init__(str(demo_data_dir))\n", - " if mode not in (\"binary\", \"multiclass\"):\n", - " raise ValueError(\"mode must be 'binary' or 'multiclass'\")\n", - " self.mode = mode\n", - " self.audio_dir = Path(demo_data_dir) / \"audios\"\n", - " self.labels_dir = Path(demo_data_dir) / \"labels\"\n", - " self.output_path = str(Path(demo_data_dir) / f\"{mode}_annotations.json\")\n", - "\n", - " def add_dataset_info(self):\n", - " self.annotation_creator.add_info(\n", - " title=\"PteroSet — demo subset (5 recordings)\",\n", - " license=\"CC BY 4.0\",\n", - " description=(\n", - " \"5 recordings from project PPA4 (Puerto Asís, Putumayo, Colombia). \"\n", - " \"Annotated with Raven Pro. Part of the PteroSet dataset \"\n", - " \"(https://zenodo.org/records/19137071).\"\n", - " ),\n", - " )\n", - "\n", - " def add_sounds(self):\n", - " wav_files = sorted(f for f in os.listdir(self.audio_dir) if f.endswith(\".wav\"))\n", - " for sound_id, filename in enumerate(wav_files):\n", - " file_path = str(self.audio_dir / filename)\n", - " duration, sample_rate = self.annotation_creator._get_duration_and_sample_rate(file_path)\n", - " self.annotation_creator.add_sound(\n", - " id=sound_id,\n", - " file_name_path=file_path,\n", - " duration=duration,\n", - " sample_rate=sample_rate,\n", - " latitude=float(\"nan\"),\n", - " longitude=float(\"nan\"),\n", - " )\n", - "\n", - " def add_categories(self):\n", - " if self.mode == \"binary\":\n", - " cats = [\n", - " {\"id\": 0, \"name\": \"noise\", \"supercategory\": \"background\"},\n", - " {\"id\": 1, \"name\": \"AVEVOC\", \"supercategory\": \"BIO\"},\n", - " ]\n", - " else:\n", - " cats = [{\"id\": 0, \"name\": \"noise\", \"supercategory\": \"background\"}]\n", - " for i, sp in enumerate(self.TOP_SPECIES, start=1):\n", - " cats.append({\"id\": i, \"name\": sp, \"supercategory\": \"AVEVOC\"})\n", - " self.annotation_creator.data[\"categories\"] = cats\n", - "\n", - " def add_annotations(self):\n", - " wav_files = sorted(f for f in os.listdir(self.audio_dir) if f.endswith(\".wav\"))\n", - " anno_id = 0\n", - " for sound_id, wav_file in enumerate(wav_files):\n", - " stem = os.path.splitext(wav_file)[0]\n", - " txt_path = self.labels_dir / f\"{stem}.Table.1.selections.txt\"\n", - " if not txt_path.exists():\n", - " continue\n", - " df = pd.read_csv(txt_path, delimiter=\"\\t\")\n", - " for _, row in df.iterrows():\n", - " t_min = float(row[\"Begin Time (s)\"])\n", - " t_max = float(row[\"End Time (s)\"])\n", - " determination = str(row.get(\"Determination\", \"\")).strip()\n", - " if self.mode == \"binary\":\n", - " self.annotation_creator.add_annotation(\n", - " anno_id=anno_id, sound_id=sound_id,\n", - " category_id=1, category=\"AVEVOC\",\n", - " supercategory=\"BIO\", t_min=t_min, t_max=t_max,\n", - " )\n", - " anno_id += 1\n", - " else:\n", - " if determination in self.TOP_SPECIES:\n", - " cat_id = self.TOP_SPECIES.index(determination) + 1\n", - " self.annotation_creator.add_annotation(\n", - " anno_id=anno_id, sound_id=sound_id,\n", - " category_id=cat_id, category=determination,\n", - " supercategory=\"AVEVOC\", t_min=t_min, t_max=t_max,\n", - " )\n", - " anno_id += 1" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Run `PteroSetReader` in both `binary` and `multiclass` modes to produce two COCO-like\n", - "JSON annotation files in `data/`.\n", - "\n", - "**Expected output:** paths and annotation counts for `binary_annotations.json` and `multiclass_annotations.json`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for mode in (\"binary\", \"multiclass\"):\n", - " reader = PteroSetReader(DATA_DIR, mode=mode)\n", - " reader.process_dataset()\n", - " print()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Binary Classification" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Task:** detect any bird vocalization (`AVEVOC`) vs. background noise.\n", - "\n", - "Pipeline steps:\n", - "1. Write config to `pteroset_binary.yaml` and load it\n", - "2. `build_windows()` with `strategy=\"sliding\"` and `multiclass=False` → label ∈ {0, 1}\n", - "3. `compute_mel_spectrograms_gpu()` → `.npy` files\n", - "4. Create train / val / test splits\n", - "5. Train `ResNetClassifier` with `num_classes=2`" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Write the binary classification config to YAML and load it as a typed config object.\n", - "All hyperparameters (window size, spectrogram, training) are stored here for reproducibility.\n", - "\n", - "**Expected output:** printed config summary (sample rate, window size, backbone, etc.)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import yaml\n", - "from PytorchWildlife.data.bioacoustics.bioacoustics_configs import load_config\n", - "\n", - "BINARY_CONFIG_PATH = str(DEMO_DIR / \"config\" / \"pteroset_binary.yaml\")\n", - "os.makedirs(os.path.dirname(BINARY_CONFIG_PATH), exist_ok=True)\n", - "\n", - "binary_cfg_dict = {\n", - " \"name\": \"pteroset_binary\",\n", - " \"datasets\": [\"audios\"],\n", - " \"class_names\": {0: \"noise\", 1: \"AVEVOC\"},\n", - " \"paths\": {\n", - " \"data_root\": str(DATA_DIR),\n", - " \"output_root\": str(BINARY_DIR),\n", - " \"spectrograms_dir\": str(SPEC_DIR),\n", - " \"annotations_file\": \"binary_annotations.json\",\n", - " },\n", - " \"audio\": {\n", - " \"sample_rate\": 48000,\n", - " \"window_size_sec\": 5.0,\n", - " \"overlap_sec\": 4.0,\n", - " \"window_strategy\": \"sliding\",\n", - " \"multiclass\": False,\n", - " },\n", - " \"spectrogram\": {\n", - " \"n_fft\": 2048,\n", - " \"hop_length\": 512,\n", - " \"n_mels\": 128,\n", - " \"top_db\": 80.0,\n", - " \"fill_highfreq\": False,\n", - " \"noise_db_std\": 3.0,\n", - " \"storage_dtype\": \"float32\",\n", - " },\n", - " \"training\": {\n", - " \"batch_size\": 16,\n", - " \"num_workers\": 2,\n", - " \"lr\": 1e-4,\n", - " \"weight_decay\": 1e-4,\n", - " \"epochs\": 10,\n", - " \"backbone\": \"resnet18\",\n", - " \"num_classes\": 2,\n", - " \"target_size\": [128, 465],\n", - " \"normalize\": True,\n", - " \"use_specaug\": True,\n", - " \"pos_weight\": 2.0,\n", - " \"conf_threshold\": 0.5,\n", - " },\n", - " \"splits\": {\n", - " \"test_size\": 0.2,\n", - " \"val_size\": 0.2,\n", - " \"n_splits\": 3,\n", - " \"random_state\": 42,\n", - " },\n", - "}\n", - "\n", - "with open(BINARY_CONFIG_PATH, \"w\") as f:\n", - " yaml.dump(binary_cfg_dict, f, default_flow_style=False, sort_keys=False)\n", - "\n", - "binary_cfg = load_config(BINARY_CONFIG_PATH)\n", - "print(f\"Config loaded: {binary_cfg.name}\")\n", - "print(f\" sample_rate : {binary_cfg.audio.sample_rate} Hz\")\n", - "print(f\" window_size : {binary_cfg.audio.window_size_sec}s\")\n", - "print(f\" strategy : {binary_cfg.audio.window_strategy}\")\n", - "print(f\" num_classes : {binary_cfg.training.num_classes}\")\n", - "print(f\" backbone : {binary_cfg.training.backbone}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Print dataset statistics (annotation counts, class balance) and build the sliding window\n", - "list from the binary annotation JSON.\n", - "\n", - "**Expected output:** per-file annotation stats and total window count with label distribution." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import sys\n", - "from prepare_dataset import run_stats, run_windows\n", - "\n", - "run_stats(binary_cfg)\n", - "binary_windows = run_windows(binary_cfg)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Compute mel spectrograms for all binary windows. Windows are named\n", - "`{sound_filename}_{start}_{end}.npy` and stored in the shared `spectrograms/` directory.\n", - "Existing files are skipped automatically.\n", - "\n", - "**Expected output:** per-file progress bars and a count of new vs. cached spectrograms." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from PytorchWildlife.data.bioacoustics.bioacoustics_spectrograms import compute_mel_spectrograms_gpu\n", - "\n", - "with open(binary_cfg.paths.annotations_path) as f:\n", - " anno_data = json.load(f)\n", - "sounds_map = {s[\"id\"]: s for s in anno_data[\"sounds\"]}\n", - "\n", - "inf_windows = [\n", - " {\n", - " \"window_id\": w[\"window_id\"],\n", - " \"sound_path\": sounds_map[w[\"sound_id\"]][\"file_name_path\"],\n", - " \"start\": w[\"start\"],\n", - " \"end\": w[\"end\"],\n", - " }\n", - " for w in binary_windows\n", - " if w[\"sound_id\"] in sounds_map\n", - "]\n", - "\n", - "compute_mel_spectrograms_gpu(\n", - " windows=inf_windows,\n", - " sample_rate=binary_cfg.audio.sample_rate,\n", - " n_fft=binary_cfg.spectrogram.n_fft,\n", - " hop_length=binary_cfg.spectrogram.hop_length,\n", - " n_mels=binary_cfg.spectrogram.n_mels,\n", - " top_db=binary_cfg.spectrogram.top_db,\n", - " spectrograms_path=str(SPEC_DIR),\n", - " save_npy=True,\n", - " fill_highfreq=binary_cfg.spectrogram.fill_highfreq,\n", - " noise_db_std=binary_cfg.spectrogram.noise_db_std,\n", - " storage_dtype=binary_cfg.spectrogram.storage_dtype,\n", - ")\n", - "print(f\"Spectrograms saved to: {SPEC_DIR}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Split windows into train / val / test sets using `GroupShuffleSplit` (files stay together)\n", - "and `StratifiedGroupKFold` (class balance preserved). CSVs are saved to `output/binary/`.\n", - "\n", - "**Expected output:** split sizes and per-split label distributions." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from sklearn.model_selection import GroupShuffleSplit, StratifiedGroupKFold\n", - "\n", - "\n", - "def create_splits(windows, spectrograms_dir, sounds_list, output_dir, cfg):\n", - " \"\"\"Build train/val/test CSVs with spec_name = {sound_filename}_{start}_{end}.npy.\"\"\"\n", - " sounds_stems = {\n", - " s[\"id\"]: os.path.splitext(os.path.basename(s[\"file_name_path\"]))[0]\n", - " for s in sounds_list\n", - " }\n", - " df = pd.DataFrame(windows)\n", - " df[\"sound_filename\"] = df[\"sound_id\"].map(sounds_stems)\n", - " df[\"spec_name\"] = df.apply(\n", - " lambda r: f\"{r['sound_filename']}_{r['start']}_{r['end']}.npy\", axis=1\n", - " )\n", - " df[\"spec_exists\"] = df[\"spec_name\"].apply(\n", - " lambda x: os.path.exists(os.path.join(spectrograms_dir, x))\n", - " )\n", - " df = df[df[\"spec_exists\"]].drop(columns=[\"spec_exists\"])\n", - " print(f\"Windows with existing spectrograms: {len(df)}\")\n", - "\n", - " gss = GroupShuffleSplit(\n", - " n_splits=1, test_size=cfg.splits.test_size,\n", - " random_state=cfg.splits.random_state\n", - " )\n", - " trainval_idx, test_idx = next(gss.split(df, df[\"label\"], groups=df[\"sound_id\"]))\n", - " trainval_df = df.iloc[trainval_idx].copy()\n", - " test_df = df.iloc[test_idx].copy()\n", - "\n", - " sgkf = StratifiedGroupKFold(\n", - " n_splits=cfg.splits.n_splits, shuffle=True,\n", - " random_state=cfg.splits.random_state\n", - " )\n", - " train_idx, val_idx = next(\n", - " sgkf.split(trainval_df, trainval_df[\"label\"], trainval_df[\"sound_id\"])\n", - " )\n", - " train_df = trainval_df.iloc[train_idx].copy()\n", - " val_df = trainval_df.iloc[val_idx].copy()\n", - "\n", - " os.makedirs(output_dir, exist_ok=True)\n", - " train_df.to_csv(os.path.join(output_dir, \"train_split.csv\"), index=False)\n", - " val_df.to_csv(os.path.join(output_dir, \"val_split.csv\"), index=False)\n", - " test_df.to_csv(os.path.join(output_dir, \"test_split.csv\"), index=False)\n", - "\n", - " print(f\"\\nSplit sizes (train={len(train_df)} val={len(val_df)} test={len(test_df)})\")\n", - " for name, split in [(\"train\", train_df), (\"val\", val_df), (\"test\", test_df)]:\n", - " print(f\" {name}: label dist = {split['label'].value_counts().to_dict()}\")\n", - " return train_df, val_df, test_df\n", - "\n", - "\n", - "binary_train, binary_val, binary_test = create_splits(\n", - " binary_windows, str(SPEC_DIR), anno_data[\"sounds\"],\n", - " str(BINARY_DIR), binary_cfg\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Instantiate `ResNetClassifier` in binary mode (2 classes, BCEWithLogitsLoss) and train with\n", - "PyTorch Lightning. The best checkpoint (highest `val/f1`) is saved to `output/binary/checkpoints/`.\n", - "\n", - "**Expected output:** training progress bars, per-epoch metrics, and final test results." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import pytorch_lightning as pl\n", - "from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor\n", - "from pytorch_lightning.loggers import CSVLogger\n", - "\n", - "from PytorchWildlife.models.bioacoustics import ResNetClassifier\n", - "from train import SpectrogramDataModule, DataModuleConfig\n", - "\n", - "pl.seed_everything(42)\n", - "\n", - "binary_dm_cfg = DataModuleConfig(\n", - " train_csv=str(BINARY_DIR / \"train_split.csv\"),\n", - " val_csv =str(BINARY_DIR / \"val_split.csv\"),\n", - " test_csv =str(BINARY_DIR / \"test_split.csv\"),\n", - " root =str(SPEC_DIR),\n", - " target_size=binary_cfg.training.target_size,\n", - " batch_size =binary_cfg.training.batch_size,\n", - " num_workers=binary_cfg.training.num_workers,\n", - " use_specaug=binary_cfg.training.use_specaug,\n", - " normalize =binary_cfg.training.normalize,\n", - " num_classes=None,\n", - " use_mixup =True,\n", - " pin_memory =False,\n", - ")\n", - "binary_dm = SpectrogramDataModule(binary_dm_cfg)\n", - "binary_dm.setup()\n", - "\n", - "binary_model = ResNetClassifier(\n", - " num_classes =binary_dm.num_classes,\n", - " in_channels =binary_dm.in_channels,\n", - " backbone =binary_cfg.training.backbone,\n", - " lr =binary_cfg.training.lr,\n", - " weight_decay =binary_cfg.training.weight_decay,\n", - " T_max =binary_cfg.training.epochs,\n", - " batch_size =binary_cfg.training.batch_size,\n", - " pos_weight =binary_cfg.training.pos_weight,\n", - " conf_threshold =binary_cfg.training.conf_threshold,\n", - " class_names =list(binary_cfg.class_names.values()),\n", - ")\n", - "print(f\"Mode : {'Binary' if binary_dm.is_binary else 'Multiclass'}\")\n", - "print(f\"Classes : {binary_dm.num_classes}\")\n", - "print(f\"Channels : {binary_dm.in_channels}\")\n", - "\n", - "binary_logger = CSVLogger(str(BINARY_DIR), name=\"logs\")\n", - "binary_ckpt = ModelCheckpoint(\n", - " monitor=\"val/f1\", mode=\"max\", save_top_k=1,\n", - " dirpath=str(BINARY_DIR / \"checkpoints\"),\n", - " filename=\"binary-{epoch:02d}-{val/f1:.3f}\",\n", - ")\n", - "binary_trainer = pl.Trainer(\n", - " max_epochs =binary_cfg.training.epochs,\n", - " accelerator =\"auto\",\n", - " devices =1,\n", - " precision =\"32\",\n", - " gradient_clip_val =1.0,\n", - " log_every_n_steps =5,\n", - " callbacks =[binary_ckpt, LearningRateMonitor()],\n", - " logger =binary_logger,\n", - " enable_progress_bar=True,\n", - ")\n", - "\n", - "binary_trainer.fit(binary_model, datamodule=binary_dm)\n", - "binary_trainer.test(binary_model, datamodule=binary_dm, ckpt_path=\"best\")\n", - "print(f\"\\nBest checkpoint : {binary_ckpt.best_model_path}\")\n", - "print(f\"Best val/f1 : {binary_ckpt.best_model_score:.4f}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Plot loss and validation F1 curves from the CSVLogger output to inspect training dynamics.\n", - "`IPython.display` is used explicitly to ensure the figure renders inline after the trainer output.\n", - "\n", - "**Expected output:** two side-by-side plots — train/val loss and val F1 vs. epoch." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from IPython.display import display as ipy_display\n", - "\n", - "metrics_path = Path(binary_logger.log_dir) / \"metrics.csv\"\n", - "metrics = pd.read_csv(metrics_path)\n", - "\n", - "train_metrics = metrics[metrics[\"train/loss\"].notna()].copy()\n", - "val_metrics = metrics[metrics[\"val/loss\"].notna()].copy()\n", - "\n", - "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))\n", - "\n", - "ax1.plot(train_metrics[\"epoch\"], train_metrics[\"train/loss\"], label=\"train\")\n", - "ax1.plot(val_metrics[\"epoch\"], val_metrics[\"val/loss\"], label=\"val\")\n", - "ax1.set_xlabel(\"Epoch\")\n", - "ax1.set_ylabel(\"Loss\")\n", - "ax1.set_title(\"Binary — Loss\")\n", - "ax1.legend()\n", - "\n", - "ax2.plot(val_metrics[\"epoch\"], val_metrics[\"val/f1\"], color=\"#1976D2\")\n", - "ax2.set_xlabel(\"Epoch\")\n", - "ax2.set_ylabel(\"F1\")\n", - "ax2.set_title(\"Binary — Validation F1\")\n", - "\n", - "plt.tight_layout()\n", - "ipy_display(fig)\n", - "plt.close(fig)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Multiclass Classification" - ] - }, - { - "cell_type": "markdown", - "id": "a97276c2", - "metadata": {}, - "source": [ - "Count annotation occurrences per species across all five recordings and derive the\n", - "`TOP_SPECIES` list automatically from the data (the N most-annotated species).\n", - "A bar chart highlights which species are used in the multiclass demo.\n", - "\n", - "**Expected output:** a horizontal bar chart of the top-10 species and printed counts." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5a9a6d75", - "metadata": {}, - "outputs": [], - "source": [ - "species_counts = {}\n", - "for stem in PPA4_FILES:\n", - " df = pd.read_csv(LABELS_DIR / f\"{stem}.Table.1.selections.txt\", delimiter=\"\\t\")\n", - " for det in df[\"Determination\"].dropna():\n", - " det = str(det).strip()\n", - " if det:\n", - " species_counts[det] = species_counts.get(det, 0) + 1\n", - "\n", - "# Derive TOP_SPECIES automatically from the most frequent species in the data\n", - "N_TOP = 4\n", - "top_all = sorted(species_counts.items(), key=lambda x: -x[1])\n", - "TOP_SPECIES = [sp for sp, _ in top_all[:N_TOP]]\n", - "\n", - "top10 = top_all[:10]\n", - "labels_plot = [sp for sp, _ in top10]\n", - "counts_plot = [c for _, c in top10]\n", - "colors = [\"#1976D2\" if sp in TOP_SPECIES else \"#B0BEC5\" for sp in labels_plot]\n", - "\n", - "fig, ax = plt.subplots(figsize=(8, 4))\n", - "ax.barh(labels_plot[::-1], counts_plot[::-1], color=colors[::-1])\n", - "ax.set_xlabel(\"Annotation count (across 5 recordings)\")\n", - "ax.set_title(f\"Top 10 species — blue bars = top-{N_TOP} species used for multiclass demo\")\n", - "patch_blue = mpatches.Patch(color=\"#1976D2\", label=f\"Top-{N_TOP} multiclass species\")\n", - "patch_grey = mpatches.Patch(color=\"#B0BEC5\", label=\"Other / background\")\n", - "ax.legend(handles=[patch_blue, patch_grey], fontsize=8)\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "print(f\"TOP_SPECIES (derived from data): {TOP_SPECIES}\")\n", - "print(f\"Top-{N_TOP} species counts: { {sp: species_counts.get(sp, 0) for sp in TOP_SPECIES} }\")\n", - "total_pos = sum(species_counts.values())\n", - "print(f\"All annotated events (binary positives): {total_pos}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Task:** classify windows into one of 5 categories — `noise`, `CACCEL`, `CYAVIO`, `PSAANG`, `RAMCAR`.\n", - "\n", - "Key differences from binary:\n", - "- `build_windows(multiclass=True)` → `label = category_id` (0–4) instead of 0/1\n", - "- `ResNetClassifier(num_classes=5)` → CrossEntropyLoss + softmax instead of BCEWithLogitsLoss\n", - "- Spectrograms are **shared** — only new windows (if any) are computed" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Write the multiclass config to YAML and load it. The key differences vs. the binary config are\n", - "`multiclass=True` in the audio section and `num_classes=5` in training.\n", - "\n", - "**Expected output:** printed config summary confirming 5 classes and multiclass mode." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "MULTICLASS_CONFIG_PATH = str(DEMO_DIR / \"config\" / \"pteroset_multiclass.yaml\")\n", - "\n", - "multiclass_cfg_dict = {\n", - " \"name\": \"pteroset_multiclass\",\n", - " \"datasets\": [\"audios\"],\n", - " \"class_names\": {0: \"noise\", 1: \"CACCEL\", 2: \"CYAVIO\", 3: \"PSAANG\", 4: \"RAMCAR\"},\n", - " \"paths\": {\n", - " \"data_root\": str(DATA_DIR),\n", - " \"output_root\": str(MULTICLASS_DIR),\n", - " \"spectrograms_dir\": str(SPEC_DIR),\n", - " \"annotations_file\": \"multiclass_annotations.json\",\n", - " },\n", - " \"audio\": {\n", - " \"sample_rate\": 48000,\n", - " \"window_size_sec\": 5.0,\n", - " \"overlap_sec\": 4.0,\n", - " \"window_strategy\": \"sliding\",\n", - " \"multiclass\": True,\n", - " },\n", - " \"spectrogram\": {\n", - " \"n_fft\": 2048,\n", - " \"hop_length\": 512,\n", - " \"n_mels\": 128,\n", - " \"top_db\": 80.0,\n", - " \"fill_highfreq\": False,\n", - " \"noise_db_std\": 3.0,\n", - " \"storage_dtype\": \"float32\",\n", - " },\n", - " \"training\": {\n", - " \"batch_size\": 16,\n", - " \"num_workers\": 2,\n", - " \"lr\": 1e-4,\n", - " \"weight_decay\": 1e-4,\n", - " \"epochs\": 10,\n", - " \"backbone\": \"resnet18\",\n", - " \"num_classes\": 5,\n", - " \"target_size\": [128, 465],\n", - " \"normalize\": True,\n", - " \"use_specaug\": True,\n", - " },\n", - " \"splits\": {\n", - " \"test_size\": 0.2,\n", - " \"val_size\": 0.2,\n", - " \"n_splits\": 3,\n", - " \"random_state\": 42,\n", - " },\n", - "}\n", - "\n", - "with open(MULTICLASS_CONFIG_PATH, \"w\") as f:\n", - " yaml.dump(multiclass_cfg_dict, f, default_flow_style=False, sort_keys=False)\n", - "\n", - "multiclass_cfg = load_config(MULTICLASS_CONFIG_PATH)\n", - "print(f\"Config loaded: {multiclass_cfg.name}\")\n", - "print(f\" num_classes : {multiclass_cfg.training.num_classes}\")\n", - "print(f\" multiclass : {multiclass_cfg.audio.multiclass}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Build the multiclass sliding windows. Only annotations for the top-4 species receive a\n", - "positive label; all other windows are background (label=0).\n", - "\n", - "**Expected output:** total window count and label distribution across the 5 classes." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "multiclass_windows = run_windows(multiclass_cfg)\n", - "print(f\"\\nLabel distribution: {pd.Series([w['label'] for w in multiclass_windows]).value_counts().to_dict()}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Compute spectrograms for multiclass windows. Because the same 5-second window definition\n", - "is shared with the binary task, most files already exist and are skipped.\n", - "\n", - "**Expected output:** count of new vs. cached spectrograms." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "with open(multiclass_cfg.paths.annotations_path) as f:\n", - " mc_anno_data = json.load(f)\n", - "mc_sounds_map = {s[\"id\"]: s for s in mc_anno_data[\"sounds\"]}\n", - "\n", - "mc_inf_windows = [\n", - " {\n", - " \"window_id\": w[\"window_id\"],\n", - " \"sound_path\": mc_sounds_map[w[\"sound_id\"]][\"file_name_path\"],\n", - " \"start\": w[\"start\"],\n", - " \"end\": w[\"end\"],\n", - " }\n", - " for w in multiclass_windows\n", - " if w[\"sound_id\"] in mc_sounds_map\n", - "]\n", - "\n", - "compute_mel_spectrograms_gpu(\n", - " windows=mc_inf_windows,\n", - " sample_rate=multiclass_cfg.audio.sample_rate,\n", - " n_fft=multiclass_cfg.spectrogram.n_fft,\n", - " hop_length=multiclass_cfg.spectrogram.hop_length,\n", - " n_mels=multiclass_cfg.spectrogram.n_mels,\n", - " top_db=multiclass_cfg.spectrogram.top_db,\n", - " spectrograms_path=str(SPEC_DIR),\n", - " save_npy=True,\n", - " fill_highfreq=multiclass_cfg.spectrogram.fill_highfreq,\n", - " noise_db_std=multiclass_cfg.spectrogram.noise_db_std,\n", - " storage_dtype=multiclass_cfg.spectrogram.storage_dtype,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Split multiclass windows into train / val / test, reusing the same `create_splits` helper.\n", - "\n", - "**Expected output:** split sizes and per-split label distributions for all 5 classes." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "multiclass_train, multiclass_val, multiclass_test = create_splits(\n", - " multiclass_windows, str(SPEC_DIR), mc_anno_data[\"sounds\"],\n", - " str(MULTICLASS_DIR), multiclass_cfg\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Instantiate `ResNetClassifier` in multiclass mode (5 classes, CrossEntropyLoss) and train.\n", - "MixUp is disabled for multiclass to avoid mixing soft labels across 5 categories.\n", - "\n", - "**Expected output:** training progress bars, per-epoch metrics, and final test results." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "pl.seed_everything(42)\n", - "\n", - "mc_dm_cfg = DataModuleConfig(\n", - " train_csv =str(MULTICLASS_DIR / \"train_split.csv\"),\n", - " val_csv =str(MULTICLASS_DIR / \"val_split.csv\"),\n", - " test_csv =str(MULTICLASS_DIR / \"test_split.csv\"),\n", - " root =str(SPEC_DIR),\n", - " target_size=multiclass_cfg.training.target_size,\n", - " batch_size =multiclass_cfg.training.batch_size,\n", - " num_workers=multiclass_cfg.training.num_workers,\n", - " use_specaug=multiclass_cfg.training.use_specaug,\n", - " normalize =multiclass_cfg.training.normalize,\n", - " num_classes=multiclass_cfg.training.num_classes,\n", - " use_mixup =False,\n", - " pin_memory =False,\n", - ")\n", - "mc_dm = SpectrogramDataModule(mc_dm_cfg)\n", - "mc_dm.setup()\n", - "\n", - "mc_model = ResNetClassifier(\n", - " num_classes =multiclass_cfg.training.num_classes,\n", - " in_channels =mc_dm.in_channels,\n", - " backbone =multiclass_cfg.training.backbone,\n", - " lr =multiclass_cfg.training.lr,\n", - " weight_decay =multiclass_cfg.training.weight_decay,\n", - " T_max =multiclass_cfg.training.epochs,\n", - " batch_size =multiclass_cfg.training.batch_size,\n", - " class_names =list(multiclass_cfg.class_names.values()),\n", - ")\n", - "print(f\"Mode : {'Binary' if mc_dm.is_binary else 'Multiclass'}\")\n", - "print(f\"Classes : {mc_dm.num_classes}\")\n", - "\n", - "mc_logger = CSVLogger(str(MULTICLASS_DIR), name=\"logs\")\n", - "mc_ckpt = ModelCheckpoint(\n", - " monitor=\"val/f1\", mode=\"max\", save_top_k=1,\n", - " dirpath=str(MULTICLASS_DIR / \"checkpoints\"),\n", - " filename=\"multiclass-{epoch:02d}-{val/f1:.3f}\",\n", - ")\n", - "mc_trainer = pl.Trainer(\n", - " max_epochs =multiclass_cfg.training.epochs,\n", - " accelerator =\"auto\",\n", - " devices =1,\n", - " precision =\"32\",\n", - " gradient_clip_val =1.0,\n", - " log_every_n_steps =5,\n", - " callbacks =[mc_ckpt, LearningRateMonitor()],\n", - " logger =mc_logger,\n", - " enable_progress_bar=True,\n", - ")\n", - "\n", - "mc_trainer.fit(mc_model, datamodule=mc_dm)\n", - "mc_trainer.test(mc_model, datamodule=mc_dm, ckpt_path=\"best\")\n", - "print(f\"\\nBest checkpoint : {mc_ckpt.best_model_path}\")\n", - "print(f\"Best val/f1 : {mc_ckpt.best_model_score:.4f}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Plot loss and macro-F1 curves from the multiclass CSVLogger.\n", - "`IPython.display` is used explicitly to ensure the figure renders inline after the trainer output.\n", - "\n", - "**Expected output:** two side-by-side plots — train/val loss and val macro-F1 vs. epoch." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "mc_metrics = pd.read_csv(Path(mc_logger.log_dir) / \"metrics.csv\")\n", - "mc_train = mc_metrics[mc_metrics[\"train/loss\"].notna()].copy()\n", - "mc_val = mc_metrics[mc_metrics[\"val/loss\"].notna()].copy()\n", - "\n", - "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))\n", - "ax1.plot(mc_train[\"epoch\"], mc_train[\"train/loss\"], label=\"train\")\n", - "ax1.plot(mc_val[\"epoch\"], mc_val[\"val/loss\"], label=\"val\")\n", - "ax1.set_xlabel(\"Epoch\")\n", - "ax1.set_ylabel(\"Loss\")\n", - "ax1.set_title(\"Multiclass — Loss\")\n", - "ax1.legend()\n", - "\n", - "ax2.plot(mc_val[\"epoch\"], mc_val[\"val/f1\"], color=\"#388E3C\")\n", - "ax2.set_xlabel(\"Epoch\")\n", - "ax2.set_ylabel(\"F1 (macro)\")\n", - "ax2.set_title(\"Multiclass — Validation F1\")\n", - "\n", - "plt.tight_layout()\n", - "ipy_display(fig)\n", - "plt.close(fig)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "birds", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.19" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/PW_Bioacoustics/inference.py b/PW_Bioacoustics/inference.py deleted file mode 100644 index 26d227b51..000000000 --- a/PW_Bioacoustics/inference.py +++ /dev/null @@ -1,430 +0,0 @@ -""" -Unified inference script supporting both binary and multiclass classification. - -Usage: - # Binary inference (default) - python inference.py --checkpoint model.ckpt --audios_source /path/to/audio --dataset birds - - # Multiclass inference - python inference.py --checkpoint model.ckpt --audios_source /path/to/audio --dataset whales \ - --num_classes 4 --class_names "No Whale,Humpback,Orca,Beluga" - - # Using config file - python inference.py --config config/whales.yaml --checkpoint model.ckpt --audios_source /path/to/audio -""" - -import os -import argparse -import re -import json -import math -from pathlib import Path -from typing import Optional, List, Dict, Union - -import numpy as np -import pandas as pd -from tqdm import tqdm - -import torch -import torch.nn.functional as F -from torch.utils.data import DataLoader - -# Import from PytorchWildlife core library -from PytorchWildlife.models.bioacoustics import ResNetClassifier -from PytorchWildlife.models.bioacoustics.resnet_classifier import load_model_from_checkpoint -from PytorchWildlife.data.bioacoustics.bioacoustics_configs import load_config -from PytorchWildlife.data.bioacoustics.bioacoustics_datasets import BioacousticsInferenceDataset -from PytorchWildlife.data.bioacoustics.bioacoustics_windows import build_inference_windows -from PytorchWildlife.data.bioacoustics.bioacoustics_spectrograms import compute_mel_spectrograms_gpu - - -def run_inference_batch( - model: ResNetClassifier, - dataloader: DataLoader, - sample_rate: int, - num_classes: int = 2, - annotations_json: Optional[str] = None, - device: str = "cuda", - temperature: float = 1.0, -) -> Dict[str, np.ndarray]: - """ - Run inference on a batch of data. Supports both binary and multiclass. - """ - is_binary = (num_classes == 2) - model.eval() - all_paths = [] - all_logits = [] - - print(f"Running inference on {len(dataloader)} batches...") - print(f"Mode: {'binary' if is_binary else f'multiclass ({num_classes} classes)'}") - - with torch.no_grad(): - for batch_idx, batch in tqdm(enumerate(dataloader), total=len(dataloader)): - x, paths = batch - x = x.to(device) - - logits = model(x) - if is_binary: - logits = logits.squeeze(1) - all_logits.append(logits.cpu().numpy()) - all_paths.extend(paths) - - # Parse audio paths, starts, and ends - annotations = None - if annotations_json is not None: - with open(annotations_json, "r") as f: - annotations = json.load(f) - - audios = [] - starts = [] - ends = [] - for p in all_paths: - if "start" in p and "end" in p: - try: - sound_id = int(re.search(r'sid(\d+)_', p).group(1)) - if annotations: - audios.append(next(s["file_name_path"] for s in annotations["sounds"] if s["id"] == sound_id)) - else: - audios.append(f"sound_{sound_id}") - starts.append(float(re.search(r'start(\d+)_end', p).group(1)) / sample_rate) - ends.append(float(re.search(r'end(\d+)\_lab', p).group(1)) / sample_rate) - except (AttributeError, StopIteration): - basename = os.path.basename(p).replace(".npy", "") - parts = basename.split("_") - audios.append("_".join(parts[:-2])) - starts.append(int(parts[-2]) / sample_rate) - ends.append(int(parts[-1]) / sample_rate) - else: - basename = os.path.basename(p).replace(".npy", "") - parts = basename.split("_") - audios.append("_".join(parts[:-2])) - starts.append(int(parts[-2]) / sample_rate) - ends.append(int(parts[-1]) / sample_rate) - - all_logits = np.concatenate(all_logits) - - if is_binary: - scaled_logits = all_logits / temperature - probabilities = 1 / (1 + np.exp(-scaled_logits)) - predictions = (probabilities > 0.5).astype(int) - else: - logits_tensor = torch.tensor(all_logits) / temperature - probabilities = F.softmax(logits_tensor, dim=1).numpy() - predictions = probabilities.argmax(axis=1) - - return { - 'paths': all_paths, - 'audios': audios, - 'starts': starts, - 'ends': ends, - 'predictions': predictions, - 'probabilities': probabilities, - } - - -def process_inference_results_per_second(csv_path: str) -> pd.DataFrame: - """ - Process inference results CSV and obtain a prediction for each second, - averaging the predictions that overlap according to the start(s) and end(s) columns. - """ - df = pd.read_csv(csv_path) - unique_audios = df['audio'].unique() - - all_results = [] - - for audio in unique_audios: - audio_df = df[df['audio'] == audio].copy() - - min_start = int(np.floor(audio_df['start(s)'].min())) - max_end = int(np.ceil(audio_df['end(s)'].max())) - - for second in range(min_start, max_end): - overlapping = audio_df[ - ((audio_df['start(s)'] <= second) & (audio_df['end(s)'] > second)) | - ((audio_df['start(s)'] < second + 1) & (audio_df['end(s)'] >= second + 1)) - ] - - if len(overlapping) > 0: - weights = [] - for _, row in overlapping.iterrows(): - overlap_start = max(row['start(s)'], second) - overlap_end = min(row['end(s)'], second + 1) - overlap_duration = max(0, overlap_end - overlap_start) - weights.append(overlap_duration) - - weights = np.array(weights) - - if weights.sum() > 0: - weights = weights / weights.sum() - - avg_prediction = np.average(overlapping['prediction'], weights=weights) - avg_probability = np.average(overlapping['probability'], weights=weights) - avg_confidence = np.average(overlapping['confidence'], weights=weights) - - all_results.append({ - 'audio': audio, - 'second': second, - 'count_overlaps': len(overlapping), - 'prediction': 1 if avg_prediction >= 0.5 else 0, - 'avg_prediction': avg_prediction, - 'avg_probability': avg_probability, - 'avg_confidence': avg_confidence, - }) - - results_df = pd.DataFrame(all_results) - results_df = results_df.sort_values(['audio', 'second']).reset_index(drop=True) - - output_dir = os.path.dirname(csv_path) - output_path = os.path.join(output_dir, 'per_second_results.csv') - - results_df.to_csv(output_path, index=False) - print(f"Per-second results saved to: {output_path}") - - return results_df - - -def save_inference_results( - results: Dict, - output_path: str, - num_classes: int, - class_names: Optional[List[str]] = None, -) -> pd.DataFrame: - """Save inference results to CSV in appropriate format.""" - is_binary = (num_classes == 2) - - if is_binary: - results_df = pd.DataFrame({ - 'audio': results['audios'], - 'start(s)': results['starts'], - 'end(s)': results['ends'], - 'prediction': results['predictions'], - 'probability': results['probabilities'], - 'confidence': np.abs(results['probabilities'] - 0.5) * 2, - }) - results_df = results_df.sort_values('confidence', ascending=False) - else: - data = { - 'file_path': results['paths'], - 'audio': results['audios'], - 'start(s)': results['starts'], - 'end(s)': results['ends'], - 'prediction': results['predictions'], - } - - if class_names is None: - class_names = [f"class_{i}" for i in range(num_classes)] - - for i, name in enumerate(class_names): - col_name = name.replace(" ", "_") + "_prob" - data[col_name] = results['probabilities'][:, i] - - results_df = pd.DataFrame(data) - - results_df.to_csv(output_path, index=False) - print(f"Results saved to: {output_path}") - return results_df - - -def main(): - parser = argparse.ArgumentParser(description="Run inference on bioacoustic sounds") - - # Config file (optional) - parser.add_argument("--config", type=str, default=None, help="Path to YAML config file") - - # Audio source - parser.add_argument("--audios_source", type=str, required=False, - help="Path to folder, JSON, or CSV with windows") - - # Classification mode - parser.add_argument("--num_classes", type=int, default=2, - help="Number of classes (2=binary, >2=multiclass)") - parser.add_argument("--class_names", type=str, nargs="+", default=None, - help="Class names for multiclass") - - # Audio parameters - parser.add_argument("--window_size_sec", type=float, default=5.0) - parser.add_argument("--overlap_sec", type=float, default=4.0) - parser.add_argument("--sample_rate", type=int, default=48000) - - # Spectrogram parameters - parser.add_argument("--n_fft", type=int, default=2048) - parser.add_argument("--hop_length", type=int, default=512) - parser.add_argument("--n_mels", type=int, default=224) - parser.add_argument("--top_db", type=float, default=80.0) - - # Model and inference - parser.add_argument("--checkpoint", type=str, required=False) - parser.add_argument("--device", type=str, default="cuda") - parser.add_argument("--batch_size", type=int, default=64) - parser.add_argument("--num_workers", type=int, default=1) - parser.add_argument("--temperature", type=float, default=1.0) - - # Output - parser.add_argument("--dataset", type=str, help="Dataset name for output directory") - parser.add_argument("--normalize", action="store_true") - parser.add_argument("--spectrograms_path", type=str, default=None) - parser.add_argument("--annotations_json", type=str, default=None, - help="Annotations JSON for mapping sound IDs to paths") - - args = parser.parse_args() - - # Load config if provided - if args.config: - cfg = load_config(args.config) - - if args.num_classes == 2 and cfg.training.num_classes != 2: - args.num_classes = cfg.training.num_classes - if args.class_names is None and cfg.class_names: - args.class_names = list(cfg.class_names.values()) - if args.window_size_sec == 5.0: - args.window_size_sec = cfg.audio.window_size_sec - if args.overlap_sec == 4.0: - args.overlap_sec = cfg.audio.overlap_sec - if args.sample_rate == 48000: - args.sample_rate = cfg.audio.sample_rate - if args.hop_length == 512: - args.hop_length = cfg.spectrogram.hop_length - if args.n_mels == 224: - args.n_mels = cfg.spectrogram.n_mels - if args.n_fft == 2048: - args.n_fft = cfg.spectrogram.n_fft - if args.top_db == 80.0: - args.top_db = cfg.spectrogram.top_db - if not args.dataset: - args.dataset = cfg.name - - is_binary = (args.num_classes == 2) - print(f"Running {'binary' if is_binary else f'multiclass ({args.num_classes} classes)'} inference") - - # Build windows - if args.audios_source.endswith('.json'): - with open(args.audios_source, 'r') as in_file: - windows = json.load(in_file) - df = pd.DataFrame(windows) - elif args.audios_source.endswith('.csv'): - df = pd.read_csv(args.audios_source) - windows = df.to_dict('records') - else: - windows = build_inference_windows( - audios_source=args.audios_source, - window_size_sec=args.window_size_sec, - overlap_sec=args.overlap_sec, - sample_rate=args.sample_rate, - ) - df = pd.DataFrame(windows) - output_dir = os.path.join(".", "inference", args.dataset) - os.makedirs(output_dir, exist_ok=True) - windows_path = os.path.join(output_dir, f"{args.dataset}_windows.json") - with open(windows_path, 'w') as out_file: - json.dump(windows, out_file, indent=2) - print(f"Windows saved to: {windows_path}") - - # Setup output and spectrograms directories - output_dir = os.path.join(".", "inference", args.dataset) - os.makedirs(output_dir, exist_ok=True) - - if args.spectrograms_path: - spectrograms_path = args.spectrograms_path - else: - spectrograms_path = os.path.join(output_dir, "spectrograms") - os.makedirs(spectrograms_path, exist_ok=True) - compute_mel_spectrograms_gpu( - windows=windows, - sample_rate=args.sample_rate, - n_fft=args.n_fft, - hop_length=args.hop_length, - n_mels=args.n_mels, - top_db=args.top_db, - spectrograms_path=spectrograms_path, - save_npy=True, - fill_highfreq=True, - noise_db_mean=None, - noise_db_std=3.0, - storage_dtype="float32", - ) - - # Build spec_name column - if 'spec_name' not in df.columns and 'file_path' not in df.columns: - if 'sound_id' in df.columns and 'label' in df.columns: - df['spec_name'] = df.apply( - lambda row: os.path.join(spectrograms_path, - f"sid{row.sound_id}_idx{row.window_id}_start{row.start}_end{row.end}_lab{row.label}.npy"), - axis=1 - ) - elif 'sound_path' in df.columns: - df['spec_name'] = df.apply( - lambda row: os.path.join(spectrograms_path, - f"{os.path.splitext(os.path.basename(row['sound_path']))[0]}_{row['start']}_{row['end']}.npy"), - axis=1 - ) - - x_col = 'file_path' if 'file_path' in df.columns else 'spec_name' - - # Calculate target size - n_frames = int(np.ceil((args.window_size_sec * args.sample_rate - args.n_fft) / args.hop_length)) + 1 - target_size = (args.n_mels, n_frames) - print(f"Spectrogram size: {target_size}") - - # Create dataset - dataset = BioacousticsInferenceDataset( - dataframe=df, - x_col=x_col, - target_size=target_size, - normalize=args.normalize, - ) - print(f"Created dataset with {len(dataset)} samples") - - # Create dataloader - dataloader = DataLoader( - dataset, - batch_size=args.batch_size, - shuffle=False, - num_workers=args.num_workers, - pin_memory=True if args.device == "cuda" else False - ) - - # Check device availability - if args.device == "cuda" and not torch.cuda.is_available(): - print("CUDA not available, switching to CPU") - args.device = "cpu" - print(f"Using device: {args.device}") - - # Load model - try: - model = load_model_from_checkpoint(args.checkpoint, args.device) - print("Model loaded successfully") - except Exception as e: - print(f"Error loading model: {e}") - return - - # Run inference - try: - results = run_inference_batch( - model=model, - dataloader=dataloader, - sample_rate=args.sample_rate, - num_classes=args.num_classes, - annotations_json=args.annotations_json, - device=args.device, - temperature=args.temperature, - ) - print("Inference completed successfully") - except Exception as e: - print(f"Error during inference: {e}") - return - - # Save results - suffix = "binary" if is_binary else "multiclass" - results_path = os.path.join(output_dir, f"{suffix}_inference_results.csv") - save_inference_results( - results=results, - output_path=results_path, - num_classes=args.num_classes, - class_names=args.class_names, - ) - - print("Inference pipeline completed successfully!") - - -if __name__ == "__main__": - main() diff --git a/PW_Bioacoustics/prepare_dataset.py b/PW_Bioacoustics/prepare_dataset.py deleted file mode 100644 index 2ee8e52f8..000000000 --- a/PW_Bioacoustics/prepare_dataset.py +++ /dev/null @@ -1,365 +0,0 @@ -""" -Generic dataset preparation script for PW_Bioacoustics. - -Usage: - # Full pipeline - python prepare_dataset.py --config config/template.yaml - - # Run specific steps only - python prepare_dataset.py --config config/template.yaml --steps stats windows - - # Available steps: stats, windows, spectrograms, splits -""" - -import os -import argparse -import json -from pathlib import Path -from typing import List, Optional - -import pandas as pd - -# Import from PytorchWildlife core library -from PytorchWildlife.data.bioacoustics.bioacoustics_configs import load_config, DomainConfig -from PytorchWildlife.data.bioacoustics.bioacoustics_windows import build_windows - - -def count_window_labels(windows: List[dict]) -> dict: - """Count label distribution in windows.""" - counts = {} - for w in windows: - label = w.get('label', 0) - counts[label] = counts.get(label, 0) + 1 - return counts - - -def run_stats(config: DomainConfig) -> None: - """Load and display dataset statistics.""" - print(f"\n{'='*60}") - print(f"Step: Dataset Statistics") - print(f"{'='*60}") - - annotation_path = config.paths.annotations_path - print(f"Loading annotations from: {annotation_path}") - - if not os.path.exists(annotation_path): - print(f"Warning: Annotations file not found: {annotation_path}") - return - - with open(annotation_path, 'r') as f: - data = json.load(f) - - # Dataset info - if 'info' in data: - print(f"\nDataset Info:") - for key, value in data['info'].items(): - print(f" - {key}: {value}") - - # Sound statistics - sounds = data.get('sounds', []) - print(f"\nSounds: {len(sounds)}") - if sounds: - durations = [s.get('duration', 0) for s in sounds] - print(f" - Total duration: {sum(durations):.1f}s ({sum(durations)/3600:.2f}h)") - print(f" - Mean duration: {sum(durations)/len(durations):.1f}s") - print(f" - Min duration: {min(durations):.1f}s") - print(f" - Max duration: {max(durations):.1f}s") - - # Annotation statistics - annotations = data.get('annotations', []) - print(f"\nAnnotations: {len(annotations)}") - if annotations: - categories = {} - for ann in annotations: - cat_id = ann.get('category_id', 0) - categories[cat_id] = categories.get(cat_id, 0) + 1 - print(f" - By category: {categories}") - - # Category names - if 'categories' in data: - print(f"\nCategories:") - for cat in data['categories']: - print(f" - {cat.get('id', '?')}: {cat.get('name', 'Unknown')}") - - -def run_windows(config: DomainConfig) -> List[dict]: - """Build windows from annotations.""" - print(f"\n{'='*60}") - print(f"Step: Build Windows") - print(f"{'='*60}") - - annotation_path = config.paths.annotations_path - output_dir = config.paths.output_root - os.makedirs(output_dir, exist_ok=True) - - windows_output_path = os.path.join( - output_dir, - f"windows_mapping_{config.audio.overlap_sec}overlap.json" - ) - - if os.path.exists(windows_output_path): - print(f"Loading existing windows from: {windows_output_path}") - with open(windows_output_path, 'r') as f: - windows = json.load(f) - print(f"Loaded {len(windows)} windows") - else: - strategy = config.audio.window_strategy - print(f"Building windows with:") - print(f" - strategy: {strategy}") - print(f" - window_size: {config.audio.window_size_sec}s") - print(f" - overlap: {config.audio.overlap_sec}s") - print(f" - sample_rate: {config.audio.sample_rate}") - print(f" - datasets: {config.datasets}") - if strategy == "balanced": - print(f" - negative_proportion: {config.audio.negative_proportion}") - - windows = build_windows( - annotation_file=annotation_path, - window_size_sec=config.audio.window_size_sec, - overlap_sec=config.audio.overlap_sec, - sample_rate=config.audio.sample_rate, - datasets_names=config.datasets, - strategy=strategy, - negative_proportion=config.audio.negative_proportion, - ) - - with open(windows_output_path, 'w') as f: - json.dump(windows, f, indent=2) - print(f"Saved {len(windows)} windows to: {windows_output_path}") - - # Show label distribution - counts = count_window_labels(windows) - print(f"\nLabel distribution: {counts}") - - return windows - - -def run_spectrograms(config: DomainConfig, windows: List[dict]) -> None: - """Compute mel spectrograms using GPU.""" - # Import here to avoid loading torch unnecessarily - from inference import compute_mel_spectrograms_gpu - - print(f"\n{'='*60}") - print(f"Step: Compute Mel Spectrograms (GPU)") - print(f"{'='*60}") - - spectrograms_dir = config.paths.spectrograms_dir - os.makedirs(spectrograms_dir, exist_ok=True) - - print(f"Output directory: {spectrograms_dir}") - print(f"Spectrogram parameters:") - print(f" - n_fft: {config.spectrogram.n_fft}") - print(f" - hop_length: {config.spectrogram.hop_length}") - print(f" - n_mels: {config.spectrogram.n_mels}") - print(f" - top_db: {config.spectrogram.top_db}") - print(f" - fill_highfreq: {config.spectrogram.fill_highfreq}") - - # Load annotations to get audio file paths - with open(config.paths.annotations_path, 'r') as f: - annotations = json.load(f) - - sounds = {s['id']: s for s in annotations['sounds']} - - # Convert windows format to include sound_path - inference_windows = [] - for win in windows: - sound = sounds.get(win['sound_id']) - if sound: - inference_windows.append({ - 'window_id': win['window_id'], - 'sound_path': sound['file_name_path'], - 'start': win['start'], - 'end': win['end'], - }) - - compute_mel_spectrograms_gpu( - windows=inference_windows, - sample_rate=config.audio.sample_rate, - n_fft=config.spectrogram.n_fft, - hop_length=config.spectrogram.hop_length, - n_mels=config.spectrogram.n_mels, - top_db=config.spectrogram.top_db, - spectrograms_path=spectrograms_dir, - save_npy=True, - fill_highfreq=config.spectrogram.fill_highfreq, - noise_db_std=config.spectrogram.noise_db_std, - storage_dtype=config.spectrogram.storage_dtype, - ) - - print("Spectrogram computation complete!") - - -def run_splits(config: DomainConfig, windows: List[dict]) -> None: - """Create train/val/test splits using stratified group splitting.""" - from sklearn.model_selection import GroupShuffleSplit, StratifiedGroupKFold - - print(f"\n{'='*60}") - print(f"Step: Create Data Splits") - print(f"{'='*60}") - - spectrograms_dir = config.paths.spectrograms_dir - output_dir = config.paths.output_root - os.makedirs(output_dir, exist_ok=True) - - print(f"Spectrograms directory: {spectrograms_dir}") - print(f"Output directory: {output_dir}") - print(f"Split parameters:") - print(f" - test_size: {config.splits.test_size}") - print(f" - val_size: {config.splits.val_size}") - print(f" - n_splits: {config.splits.n_splits}") - print(f" - random_state: {config.splits.random_state}") - - # Build dataframe from windows - df = pd.DataFrame(windows) - - # Add spectrogram name column - df['spec_name'] = df.apply( - lambda row: f"sid{row['sound_id']}_idx{row['window_id']}_start{row['start']}_end{row['end']}_lab{row['label']}.npy", - axis=1 - ) - - # Check which spectrograms exist - df['spec_exists'] = df['spec_name'].apply( - lambda x: os.path.exists(os.path.join(spectrograms_dir, x)) - ) - - print(f"\nTotal windows: {len(df)}") - print(f"Existing spectrograms: {df['spec_exists'].sum()}") - - # Filter to existing spectrograms only - df = df[df['spec_exists']].drop(columns=['spec_exists']) - - # Step 1: Split into train+val vs test (grouped by sound_id, stratified by label) - gss = GroupShuffleSplit( - n_splits=1, - test_size=config.splits.test_size, - random_state=config.splits.random_state - ) - trainval_idx, test_idx = next(gss.split(df, df['label'], groups=df['sound_id'])) - - trainval_df = df.iloc[trainval_idx].copy() - test_df = df.iloc[test_idx].copy() - - # Step 2: Split trainval into train vs val (grouped by sound_id, stratified by label) - sgkf = StratifiedGroupKFold( - n_splits=config.splits.n_splits, - shuffle=True, - random_state=config.splits.random_state - ) - train_idx, val_idx = next(sgkf.split(trainval_df, trainval_df['label'], trainval_df['sound_id'])) - - train_df = trainval_df.iloc[train_idx].copy() - val_df = trainval_df.iloc[val_idx].copy() - - # Save splits - train_df.to_csv(os.path.join(output_dir, 'train_split.csv'), index=False) - val_df.to_csv(os.path.join(output_dir, 'val_split.csv'), index=False) - test_df.to_csv(os.path.join(output_dir, 'test_split.csv'), index=False) - - print(f"\nSplit sizes:") - print(f" - Train: {len(train_df)} samples ({train_df['sound_id'].nunique()} sounds)") - print(f" - Val: {len(val_df)} samples ({val_df['sound_id'].nunique()} sounds)") - print(f" - Test: {len(test_df)} samples ({test_df['sound_id'].nunique()} sounds)") - - # Show label distribution by split - print("\nLabel distribution by split:") - for name, split_df in [('Train', train_df), ('Val', val_df), ('Test', test_df)]: - label_counts = split_df['label'].value_counts().to_dict() - print(f" - {name}: {label_counts}") - - -def load_windows_if_exists(config: DomainConfig) -> Optional[List[dict]]: - """Load windows from file if they exist.""" - output_dir = config.paths.output_root - windows_output_path = os.path.join( - output_dir, - f"windows_mapping_{config.audio.overlap_sec}overlap.json" - ) - - if os.path.exists(windows_output_path): - with open(windows_output_path, 'r') as f: - return json.load(f) - return None - - -def main(): - parser = argparse.ArgumentParser( - description="Prepare dataset for training", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=""" -Examples: - # Full pipeline - python prepare_dataset.py --config config/template.yaml - - # Only compute statistics and build windows - python prepare_dataset.py --config config/template.yaml --steps stats windows - - # Only compute spectrograms (windows must already exist) - python prepare_dataset.py --config config/template.yaml --steps spectrograms - - # Only create splits (windows and spectrograms must already exist) - python prepare_dataset.py --config config/template.yaml --steps splits - """ - ) - - parser.add_argument( - "--config", type=str, required=True, - help="Path to YAML config file (e.g., config/template.yaml)" - ) - parser.add_argument( - "--steps", type=str, nargs="+", - default=["stats", "windows", "spectrograms", "splits"], - choices=["stats", "windows", "spectrograms", "splits"], - help="Steps to run (default: all)" - ) - - args = parser.parse_args() - - # Load configuration - print(f"Loading config from: {args.config}") - config = load_config(args.config) - - print(f"\nDomain: {config.name}") - print(f"Datasets: {config.datasets}") - print(f"Classes: {config.class_names}") - print(f"Data root: {config.paths.data_root}") - print(f"Output root: {config.paths.output_root}") - - # Track windows (needed for some steps) - windows = None - - # Run requested steps - if "stats" in args.steps: - run_stats(config) - - if "windows" in args.steps: - windows = run_windows(config) - elif "spectrograms" in args.steps or "splits" in args.steps: - windows = load_windows_if_exists(config) - if windows is None: - print("\nError: Windows not found. Run 'windows' step first.") - return - - if "spectrograms" in args.steps: - if windows is None: - windows = load_windows_if_exists(config) - if windows is None: - print("\nError: Windows not found. Run 'windows' step first.") - return - run_spectrograms(config, windows) - - if "splits" in args.steps: - if windows is None: - windows = load_windows_if_exists(config) - if windows is None: - print("\nError: Windows not found. Run 'windows' step first.") - return - run_splits(config, windows) - - print(f"\n{'='*60}") - print("Dataset preparation complete!") - print(f"{'='*60}") - - -if __name__ == "__main__": - main() diff --git a/PW_Bioacoustics/template.yaml b/PW_Bioacoustics/template.yaml deleted file mode 100644 index cf8b009f7..000000000 --- a/PW_Bioacoustics/template.yaml +++ /dev/null @@ -1,59 +0,0 @@ -# Template configuration file for bioacoustics experiments -# Copy this file and customize for your specific dataset - -name: "my_domain" - -datasets: - - "dataset_name_1" - - "dataset_name_2" - -class_names: - 0: "noise" - 1: "target_class" - -paths: - data_root: "${DATA_ROOT}" - output_root: "${OUTPUT_ROOT}" - spectrograms_dir: "${OUTPUT_ROOT}/mel_spectrograms" - annotations_file: "annotations.json" - -audio: - sample_rate: 48000 - window_size_sec: 5.0 - overlap_sec: 4.0 - window_strategy: "sliding" # or "balanced" - negative_proportion: 0.5 - -spectrogram: - n_fft: 2048 - hop_length: 512 - n_mels: 224 - top_db: 80.0 - fill_highfreq: true - noise_db_std: 3.0 - storage_dtype: "float32" - -training: - batch_size: 32 - num_workers: 4 - lr: 0.0001 - weight_decay: 0.0001 - epochs: 50 - backbone: "resnet18" - num_classes: 2 - label_smoothing: 0.0 - target_size: [224, 469] - x_col: "spec_name" - y_col: "label" - normalize: true - use_specaug: false - pos_weight: 1.0 - conf_threshold: 0.5 - freeze_backbone: "none" - backbone_lr_ratio: 1.0 - -splits: - test_size: 0.15 - val_size: 0.15 - n_splits: 5 - random_state: 42 diff --git a/PW_Bioacoustics/train.py b/PW_Bioacoustics/train.py deleted file mode 100644 index 8d7c1698f..000000000 --- a/PW_Bioacoustics/train.py +++ /dev/null @@ -1,380 +0,0 @@ -""" -Unified training script for PW_Bioacoustics. - -Supports both binary classification (num_classes=2) and multiclass (num_classes>2). - -Usage: - # Binary classification (default) - python train.py --train_csv train.csv --val_csv val.csv --test_csv test.csv - - # Multiclass classification - python train.py --train_csv train.csv --test_csv test.csv --num_classes 4 - - # With YAML config - python train.py --config config/template.yaml --train_csv train.csv --test_csv test.csv -""" - -import argparse -from dataclasses import dataclass, field -from typing import Optional, List - -import torch -from torch.utils.data import DataLoader -from torchinfo import summary - -import pytorch_lightning as pl -from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor - -# Import from PytorchWildlife core library -from PytorchWildlife.models.bioacoustics import ResNetClassifier -from PytorchWildlife.data.bioacoustics.bioacoustics_datasets import ( - BioacousticsDataset, - SpectrogramAugmentations, - MixUpCollator, -) -from PytorchWildlife.data.bioacoustics.bioacoustics_configs import load_config - - -@dataclass -class DataModuleConfig: - """Configuration for the SpectrogramDataModule.""" - train_csv: str = "train_split.csv" - val_csv: str = "val_split.csv" - test_csv: str = "test_split.csv" - root: Optional[str] = "mel_spectrograms" - x_col: str = "spec_name" - y_col: str = "label" - target_size: list = field(default_factory=lambda: [224, 469]) - batch_size: int = 32 - num_workers: int = 4 - pin_memory: bool = True - shuffle_train: bool = True - use_specaug: bool = False - normalize: bool = False - pcen: bool = False - num_classes: Optional[int] = None - use_mixup: bool = True - # Transform params - horizontal_shift_prob: float = 0.5 - horizontal_shift_range: float = 0.2 - vertical_shift_prob: float = 0.5 - vertical_shift_range: float = 0.1 - occlusion_prob: float = 0.5 - occlusion_max_lines: int = 3 - occlusion_line_width: float = 0.05 - noise_prob: float = 0.5 - noise_std: float = 0.02 - buffer_prob: float = 0.5 - buffer_max_ratio: float = 0.2 - mixup_prob: float = 0.5 - mixup_alpha: float = 0.2 - color_jitter_prob: float = 0.5 - color_jitter_brightness: float = 0.3 - color_jitter_contrast: float = 0.3 - - -class SpectrogramDataModule(pl.LightningDataModule): - """PyTorch Lightning DataModule for spectrogram classification.""" - - def __init__(self, cfg: DataModuleConfig): - super().__init__() - self.cfg = cfg - self.train_ds = None - self.val_ds = None - self.test_ds = None - self.eval_transform = None - - if cfg.use_specaug: - self.train_transform = SpectrogramAugmentations( - horizontal_shift_prob=self.cfg.horizontal_shift_prob, - horizontal_shift_range=self.cfg.horizontal_shift_range, - vertical_shift_prob=self.cfg.vertical_shift_prob, - vertical_shift_range=self.cfg.vertical_shift_range, - occlusion_prob=self.cfg.occlusion_prob, - occlusion_max_lines=self.cfg.occlusion_max_lines, - occlusion_line_width=self.cfg.occlusion_line_width, - noise_prob=self.cfg.noise_prob, - noise_std=self.cfg.noise_std, - buffer_prob=self.cfg.buffer_prob, - buffer_max_ratio=self.cfg.buffer_max_ratio, - color_jitter_prob=self.cfg.color_jitter_prob, - brightness=self.cfg.color_jitter_brightness, - contrast=self.cfg.color_jitter_contrast, - ) - else: - self.train_transform = None - - def setup(self, stage: Optional[str] = None): - dataset_kwargs = dict( - root=self.cfg.root, - x_col=self.cfg.x_col, - y_col=self.cfg.y_col, - target_size=self.cfg.target_size, - normalize=self.cfg.normalize, - ) - if hasattr(self.cfg, 'pcen'): - dataset_kwargs['pcen'] = self.cfg.pcen - if self.cfg.num_classes is not None: - dataset_kwargs['num_classes'] = self.cfg.num_classes - - if self.cfg.train_csv is not None: - self.train_ds = BioacousticsDataset( - csv_path=self.cfg.train_csv, - transform=self.train_transform, - is_training=True, - **dataset_kwargs - ) - if self.cfg.val_csv is not None: - self.val_ds = BioacousticsDataset( - csv_path=self.cfg.val_csv, - transform=self.eval_transform, - is_training=False, - **dataset_kwargs - ) - self.test_ds = BioacousticsDataset( - csv_path=self.cfg.test_csv, - transform=self.eval_transform, - is_training=False, - **dataset_kwargs - ) - - @property - def num_classes(self) -> int: - return self.test_ds.num_classes - - @property - def in_channels(self) -> int: - x0, _, _ = self.test_ds[0] - return x0.shape[0] - - @property - def is_binary(self) -> bool: - return self.num_classes == 2 - - def train_dataloader(self) -> DataLoader: - if self.is_binary and self.cfg.use_mixup: - collate_fn = MixUpCollator( - mixup_prob=self.cfg.mixup_prob, - mixup_alpha=self.cfg.mixup_alpha - ) - else: - collate_fn = None - - return DataLoader( - self.train_ds, - batch_size=self.cfg.batch_size, - shuffle=self.cfg.shuffle_train, - num_workers=self.cfg.num_workers, - pin_memory=self.cfg.pin_memory, - drop_last=False, - collate_fn=collate_fn, - ) - - def val_dataloader(self) -> DataLoader: - return DataLoader( - self.val_ds, - batch_size=self.cfg.batch_size, - shuffle=False, - num_workers=self.cfg.num_workers, - pin_memory=self.cfg.pin_memory, - ) - - def test_dataloader(self) -> DataLoader: - return DataLoader( - self.test_ds, - batch_size=self.cfg.batch_size, - shuffle=False, - num_workers=self.cfg.num_workers, - pin_memory=self.cfg.pin_memory, - ) - - -def main(): - pl.seed_everything(42) - - parser = argparse.ArgumentParser(description="Unified training for bioacoustics classification") - - # Config file (optional) - parser.add_argument("--config", type=str, default=None, help="Path to YAML config file") - - # Data arguments - parser.add_argument("--train_csv", type=str, default=None) - parser.add_argument("--val_csv", type=str, default=None) - parser.add_argument("--test_csv", type=str, default=None) - parser.add_argument("--root", type=str, default="") - parser.add_argument("--x_col", type=str, default="spec_name") - parser.add_argument("--target_size", type=int, nargs=2, default=[224, 469]) - - # Model arguments - parser.add_argument("--num_classes", type=int, default=2) - parser.add_argument("--class_names", type=str, nargs="+", default=None) - parser.add_argument("--backbone", type=str, default="resnet18", - choices=["resnet18", "resnet34", "resnet50"]) - - # Training arguments - parser.add_argument("--batch_size", type=int, default=32) - parser.add_argument("--num_workers", type=int, default=4) - parser.add_argument("--lr", type=float, default=1e-4) - parser.add_argument("--weight_decay", type=float, default=1e-4) - parser.add_argument("--label_smoothing", type=float, default=0.0) - parser.add_argument("--epochs", type=int, default=5) - parser.add_argument("--ckpt_path", type=str, default=None) - parser.add_argument("--monitor_metric", type=str, default="val/f1") - parser.add_argument("--finetune", type=lambda x: (str(x).lower() == 'true'), default=False) - - # Preprocessing - parser.add_argument("--normalize", type=lambda x: (str(x).lower() == 'true'), default=True) - parser.add_argument("--pcen", action="store_true") - - # Binary-specific - parser.add_argument("--pos_weight", type=float, default=1.0) - parser.add_argument("--conf_threshold", type=float, default=0.5) - parser.add_argument("--temperature", type=float, default=1.0) - - # Freezing - parser.add_argument("--freeze_backbone", type=str, default="none", - choices=["none", "all", "early", "layer1", "layer2", "layer3"]) - parser.add_argument("--backbone_lr_ratio", type=float, default=1.0) - - # Augmentation - parser.add_argument("--use_specaug", action="store_true") - parser.add_argument("--mixup_prob", type=float, default=0) - parser.add_argument("--mixup_alpha", type=float, default=0.2) - - args = parser.parse_args() - - # Load config file if provided - if args.config: - cfg = load_config(args.config) - # Apply config defaults where CLI args weren't explicitly set - if args.num_classes == 2 and cfg.training.num_classes != 2: - args.num_classes = cfg.training.num_classes - if args.x_col == "spec_name": - args.x_col = cfg.training.x_col - if args.target_size == [224, 469]: - args.target_size = cfg.training.target_size - if args.class_names is None: - args.class_names = list(cfg.class_names.values()) - if args.batch_size == 32: - args.batch_size = cfg.training.batch_size - if args.num_workers == 4: - args.num_workers = cfg.training.num_workers - if args.lr == 1e-4: - args.lr = cfg.training.lr - if args.weight_decay == 1e-4: - args.weight_decay = cfg.training.weight_decay - if args.epochs == 5: - args.epochs = cfg.training.epochs - if args.backbone == "resnet18": - args.backbone = cfg.training.backbone - - # Create DataModule config - dm_cfg = DataModuleConfig( - train_csv=args.train_csv, - val_csv=args.val_csv, - test_csv=args.test_csv, - root=args.root, - x_col=args.x_col, - target_size=args.target_size, - batch_size=args.batch_size, - num_workers=args.num_workers, - use_specaug=args.use_specaug, - normalize=args.normalize, - pcen=args.pcen, - num_classes=args.num_classes if args.num_classes != 2 else None, - use_mixup=(args.num_classes == 2), - mixup_prob=args.mixup_prob, - mixup_alpha=args.mixup_alpha, - ) - - dm = SpectrogramDataModule(dm_cfg) - dm.setup() - - num_classes = args.num_classes if args.num_classes != 2 else dm.num_classes - - model = ResNetClassifier( - num_classes=num_classes, - in_channels=dm.in_channels, - backbone=args.backbone, - lr=args.lr, - weight_decay=args.weight_decay, - label_smoothing=args.label_smoothing, - T_max=args.epochs, - batch_size=args.batch_size, - pos_weight=args.pos_weight, - conf_threshold=args.conf_threshold, - freeze_backbone=args.freeze_backbone, - backbone_lr_ratio=args.backbone_lr_ratio, - class_names=args.class_names, - ) - - print(f"\nClassification mode: {'Binary' if num_classes == 2 else f'Multiclass ({num_classes} classes)'}") - print(summary(model, input_size=(args.batch_size, dm.in_channels, *args.target_size))) - - # Callbacks & logging - mode = "min" if args.monitor_metric == "val/loss" else "max" - - ckpt_cb = ModelCheckpoint( - monitor=args.monitor_metric, - mode=mode, - save_top_k=1, - save_last=True, - filename="resnet-finetune-{epoch:02d}" if args.finetune else "resnet-{epoch:02d}", - ) - - early_cb = None - if not args.finetune: - early_cb = EarlyStopping(monitor=args.monitor_metric, mode=mode, patience=20) - - lr_cb = LearningRateMonitor(logging_interval="epoch") - - trainer = pl.Trainer( - max_epochs=args.epochs, - accelerator="gpu", - devices=[0], - precision="16-mixed", - gradient_clip_val=1.0, - log_every_n_steps=20, - callbacks=[cb for cb in [ckpt_cb, lr_cb, early_cb] if cb is not None], - logger=False, - ) - - if args.ckpt_path is None: - trainer.fit(model, datamodule=dm) - trainer.test(model, datamodule=dm, ckpt_path="best") - print("Best ckpt:", ckpt_cb.best_model_path) - print("Best score:", ckpt_cb.best_model_score) - else: - model = ResNetClassifier.load_from_checkpoint(args.ckpt_path) - - if args.temperature != 1.0: - model.temperature = torch.tensor(args.temperature, device=model.device) - print(f"Using manual temperature: {args.temperature}") - - if model.is_binary: - model.hparams.conf_threshold = args.conf_threshold - - if args.finetune: - model.hparams.lr = args.lr - model.hparams.weight_decay = args.weight_decay - model.hparams.label_smoothing = args.label_smoothing - model.hparams.T_max = args.epochs - model.hparams.batch_size = args.batch_size - model.hparams.freeze_backbone = args.freeze_backbone - model.hparams.backbone_lr_ratio = args.backbone_lr_ratio - - print(f"Finetuning from checkpoint: {args.ckpt_path}") - model._apply_freezing_strategy() - - trainer.fit(model, datamodule=dm) - trainer.test(model, datamodule=dm, ckpt_path='best') - print("Finetune completed.") - print("Best ckpt:", ckpt_cb.best_model_path) - print("Best score:", ckpt_cb.best_model_score) - else: - trainer.test(model, datamodule=dm) - print(f"Test completed from checkpoint {args.ckpt_path}") - - -if __name__ == "__main__": - main() diff --git a/PytorchWildlife/__init__.py b/PytorchWildlife/__init__.py deleted file mode 100644 index 026d73919..000000000 --- a/PytorchWildlife/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -import importlib.metadata as importlib_metadata -from pathlib import Path - -try: - # When installed (pip install [-e] .), setuptools writes the version - # from setup.py into the installed distribution metadata. - __version__ = importlib_metadata.version(__package__ or __name__) -except importlib_metadata.PackageNotFoundError: - # Source checkout without install — fall back to version.txt at repo root. - _version_file = Path(__file__).resolve().parent.parent / "version.txt" - try: - __version__ = _version_file.read_text().strip() - except FileNotFoundError: - __version__ = "development" - -from .data import * -from .models import * -from .utils import * \ No newline at end of file diff --git a/PytorchWildlife/data/__init__.py b/PytorchWildlife/data/__init__.py deleted file mode 100644 index cdc005beb..000000000 --- a/PytorchWildlife/data/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .datasets import * -from .transforms import * -from .bioacoustics import * \ No newline at end of file diff --git a/PytorchWildlife/data/bioacoustics/__init__.py b/PytorchWildlife/data/bioacoustics/__init__.py deleted file mode 100644 index a1946a543..000000000 --- a/PytorchWildlife/data/bioacoustics/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -""" -Bioacoustics infrastructure for PytorchWildlife. - -This module provides tools for bioacoustics data processing, including annotations, -spectrograms, windowing, datasets, and configuration management. -""" - -from .bioacoustics_annotations import AnnotationCreator, BaseReader -from .bioacoustics_configs import * -from .bioacoustics_datasets import * -from .bioacoustics_spectrograms import * -from .bioacoustics_windows import * diff --git a/PytorchWildlife/data/bioacoustics/bioacoustics_annotations.py b/PytorchWildlife/data/bioacoustics/bioacoustics_annotations.py deleted file mode 100644 index 00d4aecbd..000000000 --- a/PytorchWildlife/data/bioacoustics/bioacoustics_annotations.py +++ /dev/null @@ -1,367 +0,0 @@ -from datetime import datetime -from typing import Optional -import soundfile as sf -import pandas as pd -import requests -import json -import os - -class AnnotationCreator: - """ - A class to create and manage bioacustics annotations in a standard format. - - Attributes: - data (dict): A dictionary to store general info, sounds, and annotations of a bioacustics dataset. - """ - - def __init__(self): - """ - Initializes the AnnotationCreator with an empty dataset. - """ - self.data = { - "info": {}, - "categories": [], - "sounds": [], - "annotations": [] - } - - def _validate_date_format(self, date_str: str, date_format: str = "%Y%m%d"): - """ - Validates if the date string matches the given format and is not a future date. - - Args: - date_str (str): The date string to validate. - date_format (str): The expected format of the date string. - - Raises: - ValueError: If the date string does not match the format or is a future date. - """ - try: - date_obj = datetime.strptime(date_str, date_format) - except ValueError: - raise ValueError(f"Date '{date_str}' is not in the correct format. Expected format: {date_format}") - - if date_obj > datetime.now(): - raise ValueError("Future dates are invalid.") - - def _get_duration_and_sample_rate(self, file_path:str): - """ - Get the duration and sample rate of a sound file - - Args: - file_path (str): The path to the sound file - - Returns: - duration (float): The duration of the sound file in seconds - sample_rate (int): The sample rate of the sound file in Hz - """ - with sf.SoundFile(file_path) as sound_file: - duration = len(sound_file) / sound_file.samplerate - sample_rate = sound_file.samplerate - return duration, sample_rate - - def add_info(self, title:str=None, license:str=None, publication_date:Optional[datetime]=None, description:Optional[str]=None, creators:Optional[list]=None, version:Optional[float]=None, url:Optional[str]=None): - """ - Adds the general information about the dataset. - - Args: - title (str): The title of the bioacoustic dataset. - license (str): The name of the license that specifies the permissions and restrictions for using the bioacoustic dataset. - publication_date (Optional[datetime]): The date when the bioacoustic dataset was published in YYYY-MM-DD format. - description (Optional[str]): A brief summary of the dataset. - creators (Optional[dict]): List with creators information. - version (Optional[float]): The version number of the dataset. - url (Optional[str]): The web address where the dataset can be accessed or downloaded. If it's a Zenodo URL, metadata is fetched automatically. - - Raises: - ValueError: If the year is in the future or the date format is incorrect. - """ - if url is not None and "zenodo.org/records/" in url: - try: - record_id = url.split("zenodo.org/records/")[1] - response = requests.get(f"https://zenodo.org/api/records/{record_id}") - data = response.json() - metadata = data['metadata'] - title=metadata['title'] - license=metadata['license']['id'] - publication_date=datetime.strptime(metadata['publication_date'], "%Y-%m-%d").strftime("%Y%m%d") - description=metadata['description'] - creators=metadata['creators'] - if 'version' in metadata: - version=metadata['version'] - else: - version=metadata['relations']['version'] - - except requests.exceptions.RequestException as e: - raise ValueError(f"Failed to fetch metadata from Zenodo: {e}") - - if publication_date: - self._validate_date_format(publication_date) - - self.data["info"]["title"] = title - self.data["info"]["license"] = license - self.data["info"]["publication_date"] = publication_date - self.data["info"]["description"] = description - self.data["info"]["creators"] = creators - self.data["info"]["version"] = version - self.data["info"]["url"] = url - - def add_categories(self, categories_df:pd.DataFrame): - """ - Adds the categories ids and names to the dataset. - - Args: - categories (DataFrame): A pandas DataFrame containing the category information. - """ - sorted_df = categories_df.sort_values(by=categories_df.columns[0]) - sorted_df.reset_index(drop=True, inplace=True) - sorted_df.index.name = 'id' - categories_list = sorted_df.reset_index().to_dict(orient='records') - self.data['categories'] = categories_list - - def add_sound(self, id:int, file_name_path:str, duration:int, sample_rate:int, latitude:float, longitude:float, date_recorded:Optional[datetime]=None, project:Optional[str]=None): - """ - Adds a sound entry to the dataset. - - Args: - id (int): A unique identifier for a specific sound within the dataset. - file_name_path (str): The path of the audio file containing the bioacoustic recording. - duration (float): The length of the audio recording in seconds. - sample_rate (int): The number of samples of audio carried per second, measured in Hz. - latitude (float): The geographical latitude where the bioacoustic recording was taken. - longitude (float):The geographical longitude where the bioacoustic recording was taken. - date_recorded (Optional[str]): The datetime when the audio was recoded in YYYY-MM-DD format. - project (Optional[str]): The name of the project associated with the sound recording. - Raises: - ValueError: If duration, sample rate, latitude, or longitude are out of valid range, - or if the date format is incorrect. - """ - if duration <= 0: - raise ValueError("Duration must be a positive value.") - if sample_rate <= 0: - raise ValueError("Sample rate must be a positive value.") - if latitude is not None and not (isinstance(latitude, float) and pd.isna(latitude)): - if not (-90 <= latitude <= 90): - raise ValueError("Latitude must be between -90 and 90 degrees.") - if longitude is not None and not (isinstance(longitude, float) and pd.isna(longitude)): - if not (-180 <= longitude <= 180): - raise ValueError("Longitude must be between -180 and 180 degrees.") - if date_recorded: - self._validate_date_format(date_recorded) - - sound = { - "id": id, - "file_name_path": file_name_path, - "duration": duration, - "sample_rate": sample_rate, - "latitude": latitude, - "longitude": longitude, - "date_recorded": date_recorded, - "project": project, - } - self.data['sounds'].append(sound) - - def add_annotation( - self, - anno_id:int, - sound_id:int, - category_id:int, - category:str, - t_min:float, - t_max:float, - supercategory:Optional[str]=None, - f_min:Optional[float]=None, - f_max:Optional[float]=None, - ): - """ - Adds an annotation entry to the dataset. - - Args: - anno_id (int): A unique identifier for a specific annotation within the dataset. - sound_id (int): The identifier of the sound to which the annotation is linked. - category_id (int): The identifier of the category to which the sound belongs. - category (str): The name of the category of sounds, such as a particular species or type of call. - t_min (float): The starting time of the annotated sound within the recording, in seconds. - t_max (float): The ending time of the annotated sound within the recording, in seconds. - supercategory (Optional[str]): A higher-level grouping that the category belongs to. - f_min (Optional[float]): The lowest frequency of the annotated sound within the recording, in Hz. - f_max (Optional[float]): The highest frequency of the annotated sound within the recording, in Hz - Raises: - ValueError: If any of the provided values are out of valid range, or if the - time/frequency constraints are violated. - - """ - sound_dict = self.data['sounds'][sound_id] - if t_min < 0: - raise ValueError("t_min must be a positive value.") - if t_max < 0: - raise ValueError("t_max must be a positive value.") - if t_max > sound_dict["duration"]: - t_max = sound_dict["duration"] - if t_max < t_min: - raise ValueError( - f"t_max ({t_max:.4f}) must be greater than t_min ({t_min:.4f}). " - f"Sound duration is {sound_dict['duration']:.4f}s " - f"(anno_id={anno_id}, sound_id={sound_id})." - ) - if f_min and f_max: - if f_min < 0: - raise ValueError("f_min must be a positive value.") - if f_max < 0: - raise ValueError("f_max must be a positive value.") - if f_max < f_min: - raise ValueError("f_max must be greater than f_min.") - if f_max > sound_dict["sample_rate"] / 2: - raise ValueError("f_max must be less than half the sample rate of the sound.") - - annotation = { - "anno_id": anno_id, - "sound_id": sound_id, - "category_id": category_id, - "category": category, - "supercategory": supercategory, - "t_min": t_min, - "t_max": t_max, - "f_min": f_min, - "f_max": f_max, - } - self.data['annotations'].append(annotation) - - def convert_crowsetta_bbox_annotations(self, crowsetta_annotations:list): - """ - Adds annotations from Crowsetta to the dataset. - - Args: - crowsetta_annotations (list): A list of annotations in Crowsetta format. - """ - # Add categories - unique_labels = set() - for annot in crowsetta_annotations: - for bbox in annot.bboxes: - unique_labels.add(bbox.label) - - categories_df = pd.DataFrame(list(unique_labels), columns=['label']) - self.add_categories(categories_df) - - # Add sounds - for sound_id, annotation in enumerate(crowsetta_annotations): - duration, sample_rate = self._get_duration_and_sample_rate(annotation.notated_path) - self.add_sound(id=sound_id, - file_name_path=annotation.notated_path.name, - duration=duration, - sample_rate=sample_rate, - latitude=None, - longitude=None) - # Add annotations - for anno_id, bbox in enumerate(annotation.bboxes): - category_id = [category for category in self.data["categories"] if category["label"] == bbox.label][0]["id"] - self.add_annotation(anno_id=anno_id, - sound_id=sound_id, - category_id=category_id, - category=bbox.label, - t_min=float(bbox.onset), - t_max=float(bbox.offset), - f_min=float(bbox.low_freq), - f_max=float(bbox.high_freq)) - - def convert_crowsetta_seq_annotations(self, crowsetta_annotations:list): - """ - Adds annotations from Crowsetta to the dataset. - - Args: - crowsetta_annotations (list): A list of annotations in Crowsetta format. - """ - # Add categories - unique_labels = set() - for annot in crowsetta_annotations: - for segment in annot.seq.segments: - unique_labels.add(segment.label) - - categories_df = pd.DataFrame(list(unique_labels), columns=['label']) - self.add_categories(categories_df) - - # Add sounds - for sound_id, annotation in enumerate(crowsetta_annotations): - duration, sample_rate = self._get_duration_and_sample_rate(annotation.notated_path) - self.add_sound(id=sound_id, - file_name_path=annotation.notated_path.name, - duration=duration, - sample_rate=sample_rate, - latitude=None, - longitude=None) - # Add annotations - for anno_id, segment in enumerate(annotation.seq.segments): - category_id = [category for category in self.data["categories"] if category["label"] == segment.label][0]["id"] - self.add_annotation(anno_id=anno_id, - sound_id=sound_id, - category_id=category_id, - category=segment.label, - t_min=float(segment.onset_s), - t_max=float(segment.offset_s)) - - def save_to_file(self, filename): - """ - Saves the current dataset to a JSON file. - - Args: - filename (str): The name of the file to save the dataset to. - """ - with open(filename, 'w') as f: - json.dump(self.data, f, indent=4) - -class BaseReader: - def __init__(self, data_path): - self.data_path = data_path - self.output_path = os.path.join(data_path, "annotations.json") - self.annotation_creator = AnnotationCreator() - self.data = None - - def add_dataset_info(self): - """Method to add dataset metadata (to be implemented in subclasses).""" - raise NotImplementedError("This method should be implemented in a subclass.") - - def add_sounds(self): - """Method to add sounds (to be implemented in subclasses).""" - raise NotImplementedError("This method should be implemented in a subclass.") - - def add_categories(self): - """Method to add categories (to be implemented in subclasses).""" - raise NotImplementedError("This method should be implemented in a subclass.") - - def add_annotations(self): - """Method to add annotations (to be implemented in subclasses).""" - raise NotImplementedError("This method should be implemented in a subclass.") - - def save_dataset(self): - """Saves the processed dataset as a JSON file.""" - self.annotation_creator.save_to_file(self.output_path) - - def load_dataset(self): - """Loads the dataset from the JSON file.""" - with open(self.output_path, 'r', encoding='utf-8') as f: - self.data = json.load(f) - - self.categories = {cat["id"]: cat["name"] for cat in self.data["categories"]} - self.sounds = self.data["sounds"] - self.annotations = self.data["annotations"] - - def show_summary(self): - """Displays a general summary of the dataset.""" - total_duration = sum(sound['duration'] for sound in self.sounds) - total_hours = total_duration / 3600 - - print(f"Dataset: {self.data['info']['title']}") - print(f"Total categories: {len(self.categories)}") - print(f"Total audio recordings: {len(self.sounds)}") - print(f"Total annotations: {len(self.annotations)}") - print(f"Total duration: {total_hours:.2f} hours") - - def process_dataset(self): - """Executes the full dataset processing pipeline.""" - self.add_dataset_info() - self.add_sounds() - self.add_categories() - self.add_annotations() - self.save_dataset() - self.load_dataset() - self.show_summary() diff --git a/PytorchWildlife/data/bioacoustics/bioacoustics_configs.py b/PytorchWildlife/data/bioacoustics/bioacoustics_configs.py deleted file mode 100644 index 7fa54d82e..000000000 --- a/PytorchWildlife/data/bioacoustics/bioacoustics_configs.py +++ /dev/null @@ -1,221 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -""" -Bioacoustics configuration schema for PytorchWildlife. - -This module provides configuration dataclasses and loader/saver functions -for bioacoustics experiments. -""" - -import os -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional - -try: - import yaml -except ImportError: - yaml = None - - -__all__ = [ - "PathConfig", - "AudioConfig", - "SpectrogramConfig", - "TrainingConfig", - "SplitsConfig", - "DomainConfig", - "load_config", - "save_config", -] - - -def _expand_env_vars(value: Any) -> Any: - """Recursively expand environment variables in strings.""" - if isinstance(value, str): - return os.path.expandvars(value) - elif isinstance(value, dict): - return {k: _expand_env_vars(v) for k, v in value.items()} - elif isinstance(value, list): - return [_expand_env_vars(v) for v in value] - return value - - -@dataclass -class PathConfig: - """Paths configuration with environment variable expansion.""" - data_root: str = "" - output_root: str = "" - spectrograms_dir: str = "" - annotations_file: str = "annotations.json" - windows_json: str = "windows_annotations.json" - - def __post_init__(self): - """Expand environment variables and resolve paths.""" - self.data_root = os.path.expandvars(self.data_root) - self.output_root = os.path.expandvars(self.output_root) - self.spectrograms_dir = os.path.expandvars(self.spectrograms_dir) - - @property - def annotations_path(self) -> str: - """Full path to annotations file.""" - return os.path.join(self.data_root, self.annotations_file) - - -@dataclass -class AudioConfig: - """Audio processing parameters.""" - sample_rate: int = 48000 - window_size_sec: float = 5.0 - overlap_sec: float = 4.0 - window_strategy: str = "sliding" # "sliding", "balanced", or "customized" - negative_proportion: float = 0.5 # For "balanced" strategy - windows_csv: str = "" # Path to pre-built CSV for "customized" strategy - windows_json: str = "" # Filename for saving/loading the windows JSON file - multiclass: bool = False # Use category_id labels instead of binary 0/1 - min_overlap_sec: float = 0 # Minimum overlap (s) to label a window positive - - @property - def hop_size_sec(self) -> float: - """Hop size in seconds (window_size - overlap).""" - return self.window_size_sec - self.overlap_sec - - -@dataclass -class SpectrogramConfig: - """Mel spectrogram generation parameters.""" - n_fft: int = 2048 - hop_length: int = 512 - n_mels: int = 224 - top_db: float = 80.0 - f_min: float = 0.0 # Minimum frequency (Hz) for the mel filterbank - mono_channel: str = "left" # "left", "right", or "mean" for stereo→mono - fill_highfreq: bool = True - fill_mean_below_sr: bool = False # Fill with mean instead of noise when orig_sr < target_sr - noise_db_std: float = 3.0 - storage_dtype: str = "float32" - - -@dataclass -class TrainingConfig: - """Training hyperparameters.""" - batch_size: int = 32 - num_workers: int = 4 - lr: float = 1e-4 - weight_decay: float = 1e-4 - epochs: int = 50 - backbone: str = "resnet18" - num_classes: int = 2 # 2 = binary mode, >2 = multiclass - label_smoothing: float = 0.0 - target_size: List[int] = field(default_factory=lambda: [224, 469]) - x_col: str = "spec_name" - y_col: str = "label" - normalize: bool = True - use_specaug: bool = False - pos_weight: float = 1.0 # Binary only - conf_threshold: float = 0.5 # Binary only - freeze_backbone: str = "none" - backbone_lr_ratio: float = 1.0 - - -@dataclass -class SplitsConfig: - """Data split parameters.""" - test_size: float = 0.15 - val_size: float = 0.15 - n_splits: int = 5 - random_state: int = 42 - custom_splits_folder: Optional[str] = None - - -@dataclass -class DomainConfig: - """Complete domain-specific configuration.""" - name: str = "" - datasets: List[str] = field(default_factory=list) - class_names: Dict[int, str] = field(default_factory=dict) - paths: PathConfig = field(default_factory=PathConfig) - audio: AudioConfig = field(default_factory=AudioConfig) - spectrogram: SpectrogramConfig = field(default_factory=SpectrogramConfig) - training: TrainingConfig = field(default_factory=TrainingConfig) - splits: SplitsConfig = field(default_factory=SplitsConfig) - - @property - def is_binary(self) -> bool: - """Check if this is a binary classification task.""" - return self.training.num_classes == 2 - - -def load_config(config_path: str) -> DomainConfig: - """ - Load configuration from a YAML file. - - Environment variables in the format ${VAR_NAME} are expanded. - - Args: - config_path: Path to the YAML configuration file. - - Returns: - DomainConfig object with all settings. - - Example: - config = load_config("config/birds.yaml") - print(config.audio.sample_rate) # 48000 - """ - if yaml is None: - raise ImportError("PyYAML is required for config loading. Install with: pip install pyyaml") - - with open(config_path, 'r', encoding='utf-8') as f: - data = yaml.safe_load(f) - - # Expand environment variables - data = _expand_env_vars(data) - - # Build nested config objects - paths = PathConfig(**data.get('paths', {})) - audio = AudioConfig(**data.get('audio', {})) - # If windows_json is set in audio but not in paths, propagate it - if audio.windows_json and not data.get('paths', {}).get('windows_json'): - paths.windows_json = audio.windows_json - spectrogram = SpectrogramConfig(**data.get('spectrogram', {})) - training = TrainingConfig(**data.get('training', {})) - splits = SplitsConfig(**data.get('splits', {})) - - return DomainConfig( - name=data.get('name', ''), - datasets=data.get('datasets', []), - class_names=data.get('class_names', {}), - paths=paths, - audio=audio, - spectrogram=spectrogram, - training=training, - splits=splits, - ) - - -def save_config(config: DomainConfig, config_path: str) -> None: - """ - Save configuration to a YAML file. - - Args: - config: DomainConfig object to save. - config_path: Path to save the YAML file. - """ - if yaml is None: - raise ImportError("PyYAML is required for config saving. Install with: pip install pyyaml") - - from dataclasses import asdict - - data = { - 'name': config.name, - 'datasets': config.datasets, - 'class_names': config.class_names, - 'paths': asdict(config.paths), - 'audio': asdict(config.audio), - 'spectrogram': asdict(config.spectrogram), - 'training': asdict(config.training), - 'splits': asdict(config.splits), - } - - with open(config_path, 'w', encoding='utf-8') as f: - yaml.dump(data, f, default_flow_style=False, sort_keys=False) diff --git a/PytorchWildlife/data/bioacoustics/bioacoustics_datasets.py b/PytorchWildlife/data/bioacoustics/bioacoustics_datasets.py deleted file mode 100644 index 6b8b06fc2..000000000 --- a/PytorchWildlife/data/bioacoustics/bioacoustics_datasets.py +++ /dev/null @@ -1,499 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -""" -Bioacoustics datasets and data augmentation transforms for spectrogram classification. - -This module provides: -- BioacousticsDataset: Dataset for loading spectrograms from .npy files -- SpectrogramAugmentations: Spectrogram-specific augmentation techniques -- MixUpCollator: Batch-level MixUp augmentation -- Utility transforms for normalization and resizing -""" - -import os -import random -from typing import Callable, List, Optional - -import numpy as np -import pandas as pd -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.utils.data import Dataset - -try: - import librosa -except ImportError: - librosa = None - - -__all__ = [ - "BioacousticsDataset", - "BioacousticsInferenceDataset", - "SpectrogramAugmentations", - "MixUpCollator", - "PerSampleNormalize", - "ResizeTo", - "mixup_criterion", -] - - -# --------------- Transforms --------------- - -class PerSampleNormalize(nn.Module): - """ - Normalize each sample x:[C,H,W] to zero-mean / unit-std (per entire tensor). - Keeps scale stable without relying on dataset-wide stats. - """ - def forward(self, x: torch.Tensor) -> torch.Tensor: - mean = x.mean() - std = x.std().clamp_min(1e-6) - return (x - mean) / std - - -class ResizeTo: - """ - Resize a spectrogram tensor to (H, W). - - Args: - size_hw: Target size as [height, width]. - """ - def __init__(self, size_hw: List[int]): - self.size_hw = size_hw - - def __call__(self, x: torch.Tensor) -> torch.Tensor: - if x.ndim != 3: - raise ValueError(f"ResizeTo expects [C,H,W], got {tuple(x.shape)}") - h, w = self.size_hw - if (x.shape[-2], x.shape[-1]) == (h, w): - return x - # Crop width if needed (common for spectrograms) - return x[:, :, :w] - - -class SpectrogramAugmentations: - """ - Spectrogram-specific augmentation techniques. - - Includes horizontal/vertical shifts, random occlusions, Gaussian noise, - buffer simulation, and color jitter. - - Args: - horizontal_shift_prob: Probability of applying horizontal shift. - horizontal_shift_range: Fraction of time dimension to shift. - vertical_shift_prob: Probability of applying vertical shift. - vertical_shift_range: Fraction of frequency dimension to shift. - occlusion_prob: Probability of applying random occlusion. - occlusion_max_lines: Max number of occlusion lines. - occlusion_line_width: Width of occlusion lines as fraction. - noise_prob: Probability of adding Gaussian noise. - noise_std: Std of Gaussian noise. - buffer_prob: Probability of buffer corruption. - buffer_max_ratio: Max corruption ratio for buffer simulation. - color_jitter_prob: Probability of applying color jitter. - brightness: Brightness adjustment factor. - contrast: Contrast adjustment factor. - """ - - def __init__( - self, - horizontal_shift_prob: float = 0.1, - horizontal_shift_range: float = 0.1, - vertical_shift_prob: float = 0.1, - vertical_shift_range: float = 0.1, - occlusion_prob: float = 0.15, - occlusion_max_lines: int = 2, - occlusion_line_width: float = 0.05, - noise_prob: float = 0.1, - noise_std: float = 0.02, - buffer_prob: float = 0.05, - buffer_max_ratio: float = 0.1, - color_jitter_prob: float = 0.1, - brightness: float = 0.1, - contrast: float = 0.1, - ): - self.horizontal_shift_prob = horizontal_shift_prob - self.horizontal_shift_range = horizontal_shift_range - self.vertical_shift_prob = vertical_shift_prob - self.vertical_shift_range = vertical_shift_range - self.occlusion_prob = occlusion_prob - self.occlusion_max_lines = occlusion_max_lines - self.occlusion_line_width = occlusion_line_width - self.noise_prob = noise_prob - self.noise_std = noise_std - self.buffer_prob = buffer_prob - self.buffer_max_ratio = buffer_max_ratio - self.color_jitter_prob = color_jitter_prob - self.brightness = brightness - self.contrast = contrast - - def horizontal_shift(self, spec): - if torch.rand(1) < self.horizontal_shift_prob: - _, _, time_dim = spec.shape - shift_pixels = int(torch.randint( - -int(time_dim * self.horizontal_shift_range), - int(time_dim * self.horizontal_shift_range) + 1, (1,) - )) - if shift_pixels != 0: - mean_val = spec.mean() - spec = torch.roll(spec, shifts=shift_pixels, dims=2) - if shift_pixels > 0: - spec[:, :, :shift_pixels] = mean_val - else: - spec[:, :, shift_pixels:] = mean_val - return spec - - def vertical_shift(self, spec): - if torch.rand(1) < self.vertical_shift_prob: - _, freq_dim, _ = spec.shape - shift_bins = int(torch.randint( - -int(freq_dim * self.vertical_shift_range), - int(freq_dim * self.vertical_shift_range) + 1, (1,) - )) - if shift_bins != 0: - spec = torch.roll(spec, shifts=shift_bins, dims=1) - mean_val = spec.mean() - if shift_bins > 0: - spec[:, :shift_bins, :] = mean_val - else: - spec[:, shift_bins:, :] = mean_val - return spec - - def add_occlusions(self, spec): - if torch.rand(1) < self.occlusion_prob: - _, freq_dim, time_dim = spec.shape - num_lines = torch.randint(1, self.occlusion_max_lines + 1, (1,)).item() - mean_val = spec.mean() - for _ in range(num_lines): - if torch.rand(1) < 0.5: - freq_start = torch.randint(0, freq_dim, (1,)).item() - line_width = torch.randint(1, int(freq_dim * self.occlusion_line_width), (1,)).item() - freq_end = min(freq_start + line_width, freq_dim) - spec[:, freq_start:freq_end, :] = mean_val - else: - time_start = torch.randint(0, time_dim, (1,)).item() - line_width = torch.randint(1, int(time_dim * self.occlusion_line_width), (1,)).item() - time_end = min(time_start + line_width, time_dim) - spec[:, :, time_start:time_end] = mean_val - return spec - - def add_gaussian_noise(self, spec): - if torch.rand(1) < self.noise_prob: - noise = torch.randn_like(spec) * self.noise_std - spec = spec + noise - return spec - - def add_buffer_simulation(self, spec): - if torch.rand(1) < self.buffer_prob: - _, freq_dim, time_dim = spec.shape - downsample_factor = 1.0 - torch.rand(1) * self.buffer_max_ratio - new_time_dim = max(1, int(time_dim * downsample_factor)) - new_freq_dim = max(1, int(freq_dim * downsample_factor)) - spec_down = F.interpolate( - spec.unsqueeze(0), - size=(new_freq_dim, new_time_dim), - mode='bilinear', - align_corners=False - ).squeeze(0) - spec = F.interpolate( - spec_down.unsqueeze(0), - size=(freq_dim, time_dim), - mode='bilinear', - align_corners=False - ).squeeze(0) - return spec - - def color_jitter(self, spec): - if torch.rand(1) < self.color_jitter_prob: - if self.brightness > 0: - brightness_factor = 1.0 + (torch.rand(1).item() * 2 - 1) * self.brightness - spec = spec * brightness_factor - if self.contrast > 0: - mean_val = spec.mean() - contrast_factor = 1.0 + (torch.rand(1).item() * 2 - 1) * self.contrast - spec = (spec - mean_val) * contrast_factor + mean_val - return spec - - def __call__(self, spec, is_training=True): - if not is_training: - return spec - augmentations = [ - self.horizontal_shift, - self.vertical_shift, - self.add_occlusions, - self.add_gaussian_noise, - self.add_buffer_simulation, - self.color_jitter, - ] - num_to_apply = len(augmentations) - selected = random.sample(augmentations, num_to_apply) - random.shuffle(selected) - for aug in selected: - spec = aug(spec) - return spec - - -# --------------- MixUp --------------- - -class MixUpCollator: - """ - A collate function that applies MixUp augmentation at the batch level. - - Args: - mixup_prob: Probability of applying MixUp. - mixup_alpha: Alpha parameter for Beta distribution sampling. - """ - - def __init__(self, mixup_prob: float = 0.2, mixup_alpha: float = 0.2): - self.mixup_prob = mixup_prob - self.mixup_alpha = mixup_alpha - - def __call__(self, batch): - """ - Apply MixUp to a batch of (spectrogram, label, path) tuples. - - Returns: - specs: Tensor of mixed spectrograms [B, C, H, W]. - labels: Dict with info for MixUp loss calculation. - paths: List of original paths. - """ - specs, labels, paths = zip(*batch) - specs = torch.stack(specs) - labels = torch.tensor(labels) - batch_size = specs.size(0) - - if torch.rand(1) < self.mixup_prob and batch_size > 1: - noise_intensity = self.mixup_alpha - lam = torch.clamp( - 1.0 - torch.abs(torch.randn(1) * noise_intensity), 0.0, 1.0 - ).item() - - indices = torch.randperm(batch_size) - mixed_specs = lam * specs + (1 - lam) * specs[indices] - - mixed_labels = { - 'original_labels': labels, - 'shuffled_labels': labels[indices], - 'lambda': lam, - 'is_mixed': True - } - - return mixed_specs, mixed_labels, list(paths) - else: - mixed_labels = { - 'original_labels': labels, - 'shuffled_labels': None, - 'lambda': 1.0, - 'is_mixed': False - } - return specs, mixed_labels, list(paths) - - -def mixup_criterion(criterion, pred, targets): - """ - Compute the MixUp loss. - - Args: - criterion: Loss function (e.g., nn.BCEWithLogitsLoss()). - pred: Model predictions. - targets: Dictionary containing mixed label information. - - Returns: - Mixed loss value. - """ - if targets['is_mixed']: - y_a, y_b = targets['original_labels'].float(), targets['shuffled_labels'].float() - lam = targets['lambda'] - return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b) - else: - return criterion(pred, targets['original_labels'].float()) - - -# --------------- Dataset --------------- - -class BioacousticsDataset(Dataset): - """ - Dataset that reads spectrograms from .npy files whose paths are listed in a CSV. - - Args: - csv_path: Path to the CSV file. - root: Root folder to prepend to spec_name when it's a relative path. - x_col: Column containing the path to the .npy file. - y_col: Column containing the label. - target_size: [H, W] to resize spectrograms to; if None, keep original size. - transform: Callable applied for transformations. - is_training: Whether this is training mode (affects augmentations). - normalize: Whether to apply per-sample normalization. - pcen: Whether to apply PCEN transformation. - num_classes: Number of classes; if None, inferred from data. - """ - - def __init__( - self, - csv_path: str, - root: Optional[str] = None, - x_col: str = "spec_name", - y_col: str = "label", - target_size: Optional[List[int]] = None, - transform: Optional[Callable] = None, - is_training: bool = False, - normalize: bool = True, - pcen: bool = False, - num_classes: int = None, - ): - super().__init__() - self.df = pd.read_csv(csv_path) - self.root = root - self.x_col = x_col - self.y_col = y_col - self.transform = transform - self.is_training = is_training - self.pcen = pcen - - # Resolved paths - self.paths: List[str] = [] - for p in self.df[self.x_col].astype(str).tolist(): - self.paths.append(os.path.join(self.root, p) if self.root else p) - - self._resize = ResizeTo(target_size) if target_size is not None else None - self._normalize = PerSampleNormalize() if normalize else None - - if num_classes is not None: - self.num_classes = num_classes - else: - self.num_classes = int(self.df[self.y_col].max()) + 1 - - def __len__(self) -> int: - return len(self.df) - - def _load_npy(self, idx: int): - path = self.paths[idx] - try: - arr = np.load(path) - except (EOFError, ValueError) as e: - print(f"ERROR loading file at index {idx}: {path}") - print(f"File size: {os.path.getsize(path) if os.path.exists(path) else 'FILE NOT FOUND'} bytes") - raise e - return arr, path - - def _apply_pcen(self, S_db, sr=24000): - """Apply Per-Channel Energy Normalization.""" - if librosa is None: - raise ImportError("librosa is required for PCEN. Install with: pip install librosa") - - S_lin = librosa.db_to_power(S_db) - S_pcen = librosa.pcen( - S_lin * (2**31), - sr=sr, - gain=1.0, - bias=10.0, - power=0.5, - time_constant=0.3, - ) - S_pcen = librosa.power_to_db(S_pcen) - return S_pcen.astype(np.float32, copy=False) - - def __getitem__(self, idx: int): - arr, path = self._load_npy(idx) - arr = arr.astype(np.float32, copy=False) - - if self.pcen: - arr = self._apply_pcen(arr) - - # Shape to [C, H, W] - if arr.ndim == 2: - arr = arr[None, ...] # [1, H, W] - elif arr.ndim == 3: - if arr.shape[0] not in (1, 2, 3) and arr.shape[-1] in (1, 2, 3): - arr = np.moveaxis(arr, -1, 0) - else: - raise ValueError(f"Unexpected .npy shape {arr.shape} at index {idx}") - - x = torch.from_numpy(arr) - - if self._normalize is not None: - x = self._normalize(x) - - if self._resize is not None: - x = self._resize(x) - - if self.transform is not None: - x = self.transform(x, self.is_training) - - y = int(self.df.iloc[idx][self.y_col]) - - return x, y, path - - -class BioacousticsInferenceDataset(Dataset): - """ - Dataset that reads spectrograms from .npy files whose paths are listed in a dataframe. - - Unlike :class:`BioacousticsDataset`, this class does **not** require a - label column and returns ``(tensor, path)`` pairs suitable for inference. - - Parameters - ---------- - dataframe : pd.DataFrame - DataFrame containing at least the column specified by `x_col`. - x_col : str - Column containing the path to the .npy file (default: "spec_name"). - target_size : Optional[List[int]] - [H, W] to resize spectrograms to; if None, keep original size. - normalize : bool - Whether to apply per-sample normalization. - """ - - def __init__( - self, - dataframe: pd.DataFrame, - x_col: str = "spec_name", - target_size: Optional[List[int]] = None, - normalize: bool = True, - ): - super().__init__() - self.df = dataframe - self.x_col = x_col - self.paths = self.df[self.x_col].astype(str).tolist() - self._resize = ResizeTo(target_size) if target_size is not None else None - self._normalize = PerSampleNormalize() if normalize else None - - def __len__(self) -> int: - return len(self.df) - - def _load_npy(self, idx: int): - path = self.paths[idx] - try: - arr = np.load(path) - except Exception as e: - print(f"\n{'='*80}") - print(f"ERROR loading file at index {idx}:") - print(f"Path: {path}") - print(f"Exception: {e}") - print(f"{'='*80}\n") - raise - return arr, path - - def __getitem__(self, idx: int): - arr, path = self._load_npy(idx) - arr = arr.astype(np.float32, copy=False) - - # shape to [C,H,W] - if arr.ndim == 2: - arr = arr[None, ...] # [1,H,W] - elif arr.ndim == 3: - if arr.shape[0] not in (1, 2, 3) and arr.shape[-1] in (1, 2, 3): - arr = np.moveaxis(arr, -1, 0) - else: - raise ValueError(f"Unexpected .npy shape {arr.shape} at index {idx}") - - x = torch.from_numpy(arr) # [C,H,W] - - if self._normalize is not None: - x = self._normalize(x) - - if self._resize is not None: - x = self._resize(x) - - return x, path diff --git a/PytorchWildlife/data/bioacoustics/bioacoustics_spectrograms.py b/PytorchWildlife/data/bioacoustics/bioacoustics_spectrograms.py deleted file mode 100644 index 036990e66..000000000 --- a/PytorchWildlife/data/bioacoustics/bioacoustics_spectrograms.py +++ /dev/null @@ -1,304 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -""" -GPU-accelerated mel spectrogram computation for bioacoustics. - -This module provides a single, configurable function for computing mel -spectrograms from audio windows. It supports high-frequency fill strategies, -configurable mono-channel selection, and custom spectrogram naming. -""" - -import os -import math -import struct -from pathlib import Path -from collections import defaultdict -from typing import Callable, Dict, List, Optional - -import numpy as np -from tqdm import tqdm - -import torch -import torchaudio -import librosa -import soundfile as sf - - -__all__ = [ - "compute_mel_spectrograms_gpu", - "default_spectrogram_path", -] - - -def default_spectrogram_path(win: Dict, spectrograms_path: str) -> str: - """Return the default .npy path for a spectrogram window. - - Uses ``{basename(sound_path)}_{start}_{end}.npy``. - """ - sound_filename = os.path.splitext(os.path.basename(win["sound_path"]))[0] - return os.path.join( - spectrograms_path, - f"{sound_filename}_{int(win['start'])}_{int(win['end'])}.npy", - ) - - -def _read_wav_fallback(filepath: str): - """Read audio from a WAV file whose RIFF size field overflowed (>4 GB). - - Falls back to manual header parsing + raw PCM decoding so that - ``soundfile`` / ``libsndfile`` errors are bypassed. - - Returns - ------- - (data, sample_rate) matching ``sf.read(..., always_2d=True)`` contract: - *data* is float32 with shape ``(n_frames, n_channels)``. - """ - with open(filepath, "rb") as f: - f.read(4) # b'RIFF' - f.read(4) # file-size field (unreliable for >4 GB) - wave_tag = f.read(4) - if wave_tag != b"WAVE": - raise ValueError(f"Not a WAV file (no WAVE tag): {filepath}") - - channels = sample_rate = bits_per_sample = None - data_offset = data_size = None - - while True: - chunk_hdr = f.read(8) - if len(chunk_hdr) < 8: - break - chunk_id = chunk_hdr[:4] - chunk_sz = struct.unpack(" file_sz: - data_size = file_sz - data_offset - else: - data_size = chunk_sz - break - else: - f.seek(chunk_sz + (chunk_sz & 1), 1) - - if channels is None or data_offset is None: - raise ValueError(f"Could not parse WAV header: {filepath}") - - bytes_per_sample = bits_per_sample // 8 - frame_bytes = bytes_per_sample * channels - data_size = (data_size // frame_bytes) * frame_bytes - - f.seek(data_offset) - raw = f.read(data_size) - - if bits_per_sample == 16: - samples = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768.0 - elif bits_per_sample == 24: - n = len(raw) // 3 - b = np.frombuffer(raw[: n * 3], dtype=np.uint8).reshape(-1, 3) - i32 = ( - b[:, 2].astype(np.int32) << 24 - | b[:, 1].astype(np.int32) << 16 - | b[:, 0].astype(np.int32) << 8 - ) >> 8 - samples = i32.astype(np.float32) / 8388608.0 - elif bits_per_sample == 32: - samples = np.frombuffer(raw, dtype=np.int32).astype(np.float32) / 2147483648.0 - else: - raise ValueError(f"Unsupported bits_per_sample={bits_per_sample}: {filepath}") - - samples = samples.reshape(-1, channels) - return samples, sample_rate - - -def compute_mel_spectrograms_gpu( - windows: List[Dict], - sample_rate: int, - n_fft: int, - hop_length: Optional[int], - n_mels: int, - top_db: float, - spectrograms_path: str, - save_npy: bool = True, - f_min: float = 0.0, - mono_channel: str = "left", - fill_highfreq: bool = True, - fill_mean_below_sr: bool = False, - noise_db_mean: Optional[float] = None, - noise_db_std: float = 3.0, - storage_dtype: str = "float32", - spectrogram_path_fn: Optional[Callable[[Dict, str], str]] = None, -) -> None: - """GPU-accelerated mel spectrogram computation. - - Parameters - ---------- - windows : list of dict - Each dict must have ``sound_path``, ``start``, ``end`` keys. - sample_rate : int - Target sample rate. Audio is resampled if it differs. - n_fft, hop_length, n_mels, top_db - Mel spectrogram parameters. - spectrograms_path : str - Directory where ``.npy`` files are saved. - save_npy : bool - Whether to persist spectrograms to disk. - f_min : float - Minimum frequency (Hz) for the mel filterbank. - mono_channel : str - How to reduce stereo to mono: ``"left"``, ``"right"``, or ``"mean"``. - fill_highfreq : bool - Fill high-frequency bins above the original Nyquist when resampling. - fill_mean_below_sr : bool - If True use the mean of valid bands for fill; otherwise use 10th- - percentile noise floor. - noise_db_mean, noise_db_std : float - Parameters for noise-floor estimation when *fill_mean_below_sr* is - False. - storage_dtype : str - NumPy dtype for saved arrays (``"float16"`` or ``"float32"``). - spectrogram_path_fn : callable, optional - ``(win, spectrograms_path) -> str`` returning the full ``.npy`` path - for a window. Defaults to :func:`default_spectrogram_path`. - """ - if hop_length is None: - hop_length = n_fft // 4 - - if spectrogram_path_fn is None: - spectrogram_path_fn = default_spectrogram_path - - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - torch.set_grad_enabled(False) - if device.type == "cuda": - torch.backends.cudnn.benchmark = True - - Path(spectrograms_path).mkdir(parents=True, exist_ok=True) - - # Group windows by sound_path - by_sid = defaultdict(list) - for idx, win in enumerate(windows): - by_sid[win["sound_path"]].append((idx, win)) - - # Check for existing spectrograms - files_to_process = {} - total_windows = 0 - existing_windows = 0 - - print("Checking for existing spectrograms...") - for audio_file_path, items in tqdm(by_sid.items(), desc="Checking files"): - missing_items = [] - for idx, win in items: - npy_path = spectrogram_path_fn(win, spectrograms_path) - total_windows += 1 - - if not os.path.exists(npy_path): - missing_items.append((idx, win)) - else: - existing_windows += 1 - - if missing_items: - files_to_process[audio_file_path] = missing_items - - print(f"Found {existing_windows}/{total_windows} existing spectrograms") - print(f"Need to create {total_windows - existing_windows} spectrograms from {len(files_to_process)} audio files") - - if len(files_to_process) == 0: - print("All spectrograms already exist! Skipping computation.") - return - - for audio_file_path, items in tqdm(files_to_process.items(), desc="Processing files"): - # Decode on CPU - try: - y, orig_sr = sf.read(audio_file_path, dtype="float32", always_2d=True) - except sf.LibsndfileError: - y, orig_sr = _read_wav_fallback(audio_file_path) - print(f" [WAV fallback] {os.path.basename(audio_file_path)} " - f"({os.path.getsize(audio_file_path) / (1024**3):.1f} GB)") - if y.ndim == 2: - if y.shape[1] == 1: - y = y[:, 0] - elif mono_channel == "left": - y = y[:, 0] - elif mono_channel == "right": - y = y[:, 1] - else: # "mean" - y = y.mean(axis=1) - wav_cpu = torch.from_numpy(y).unsqueeze(0) - - # Resample if needed - if orig_sr != sample_rate: - wav_cpu = torchaudio.functional.resample(wav_cpu, orig_freq=orig_sr, new_freq=sample_rate) - - # Mel transform on GPU - mel_tf = torchaudio.transforms.MelSpectrogram( - sample_rate=sample_rate, - n_fft=n_fft, - hop_length=hop_length, - n_mels=n_mels, - f_min=f_min, - f_max=sample_rate / 2.0, - power=2.0, - center=False, - norm="slaney", - mel_scale="slaney", - ).to(device) - - to_db = torchaudio.transforms.AmplitudeToDB( - stype="power", top_db=top_db - ).to(device) - - for global_idx, win in tqdm(items): - start = int(win["start"]) - end = int(win["end"]) - npy_path = spectrogram_path_fn(win, spectrograms_path) - - if not os.path.exists(npy_path): - wav_win = wav_cpu[:, start:end].to(device) - S = mel_tf(wav_win).squeeze(0) - S_db = to_db(S) - - # Optional high-frequency fill - if fill_highfreq and orig_sr < sample_rate: - mel_freqs = librosa.mel_frequencies(n_mels=n_mels, fmin=f_min, fmax=sample_rate / 2.0) - nyq_orig = (float(orig_sr) / 2.0) - 2500 - noise_mask = torch.from_numpy((mel_freqs > nyq_orig).astype(np.bool_)).to(device) - if noise_mask.any(): - valid_mask = ~noise_mask - if fill_mean_below_sr: - vals = S_db[valid_mask, :] - if vals.numel() > 0: - mu = vals.mean().item() - else: - mu = -60.0 - S_db[noise_mask, :] = mu - else: - if noise_db_mean is None: - vals = S_db[valid_mask, :].reshape(-1) - if vals.numel() == 0: - mu = -60.0 - else: - v = vals.float().cpu() - k = max(1, int(math.ceil(0.10 * v.numel()))) - mu = torch.kthvalue(v, k).values.item() - else: - mu = float(noise_db_mean) - - S_db[noise_mask, :] = mu - S_db = torch.clamp(S_db, min=-top_db, max=20.0) - - if save_npy: - arr = S_db.detach().to("cpu").numpy() - if storage_dtype == "float16": - arr = arr.astype("float16", copy=False) - elif storage_dtype == "float32": - arr = arr.astype("float32", copy=False) - np.save(npy_path, arr) - - del wav_win, S, S_db - torch.cuda.empty_cache() diff --git a/PytorchWildlife/data/bioacoustics/bioacoustics_windows.py b/PytorchWildlife/data/bioacoustics/bioacoustics_windows.py deleted file mode 100644 index 77e69a66f..000000000 --- a/PytorchWildlife/data/bioacoustics/bioacoustics_windows.py +++ /dev/null @@ -1,449 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -""" -Window generation utilities for bioacoustics audio segmentation. - -This module provides functions for generating audio windows from annotated -audio files for training and inference. -""" - -import json -import math -import os -from typing import Dict, List, Union - -import numpy as np - -try: - import librosa -except ImportError: - librosa = None - - -__all__ = [ - "build_windows", - "build_inference_windows", - "count_window_labels", -] - - -def count_window_labels(windows: List[Dict]) -> Dict: - """Count label distribution in windows. - - Args: - windows: List of window dicts as returned by :func:`build_windows`. - - Returns: - Dictionary mapping each label value to its count. - """ - counts: Dict = {} - for w in windows: - label = w.get('label', 0) - counts[label] = counts.get(label, 0) + 1 - return counts - - -def build_windows( - annotation_file: str, - window_size_sec: float, - overlap_sec: float, - sample_rate: int, - datasets_names: List[str], - strategy: str = "sliding", - negative_proportion: float = 0.5, - multiclass: bool = False, - min_overlap_sec: float = 0, - custom_builder=None, -) -> List[Dict]: - """ - Build audio windows with their labels using the specified strategy. - - Args: - annotation_file: Path to the annotations JSON file. - window_size_sec: Window size in seconds. - overlap_sec: Overlap between windows in seconds. - sample_rate: Sample rate for calculations. - datasets_names: List of dataset names to match against file paths. - strategy: Window generation strategy: - - "sliding": Fixed overlap sliding windows across entire audio. - Labels based on annotation overlap. - - "balanced": Centers windows on annotations for positives, - then samples negatives to achieve desired proportion. - - "customized": Delegates to a user-supplied ``custom_builder`` - callable. - negative_proportion: Proportion of negatives for "balanced" strategy. - 0.5 means 50% negatives, 50% positives. - multiclass: If True, use the annotation's category_id as the label - instead of a binary 0/1. Defaults to False (binary). - min_overlap_sec: Minimum overlap in seconds between a window and an - annotation for the window to be labelled positive. Defaults to 0 - (any overlap counts). - custom_builder: Callable used when ``strategy="customized"``. It - receives ``(annotation_file, sample_rate, datasets_names)`` and - must return a list of window dicts. - - Returns: - List of window dicts with keys: 'window_id', 'dataset', 'sample_rate', - 'sound_id', 'start', 'end', 'label'. When *multiclass* is True an - extra 'ann_overlap' key is included with the overlap amount in samples. - """ - if strategy == "sliding": - return _build_windows_sliding( - annotation_file, window_size_sec, overlap_sec, - sample_rate, datasets_names, multiclass, min_overlap_sec - ) - elif strategy == "balanced": - return _build_windows_balanced( - annotation_file, window_size_sec, overlap_sec, - sample_rate, datasets_names, negative_proportion, - multiclass, min_overlap_sec - ) - elif strategy == "customized": - if custom_builder is None: - raise ValueError( - "The 'customized' strategy requires a 'custom_builder' callable." - ) - return custom_builder( - annotation_file, sample_rate, datasets_names - ) - else: - raise ValueError( - f"Unknown strategy: {strategy}. " - f"Use 'sliding', 'balanced', or 'customized'." - ) - - -def _build_windows_sliding( - annotation_file: str, - window_size_sec: float, - overlap_sec: float, - sample_rate: int, - datasets_names: List[str], - multiclass: bool = False, - min_overlap_sec: float = 0, -) -> List[Dict]: - """ - Sliding window strategy: generates windows with fixed overlap across entire audio. - Labels each window based on whether it overlaps with any annotation. - - When *multiclass* is False (default) the label is binary (0 or 1). - When *multiclass* is True the label is the annotation's ``category_id`` - and an extra ``ann_overlap`` field records the overlap in samples. - """ - with open(annotation_file, 'r') as f: - data = json.load(f) - sounds = {sound['id']: sound for sound in data['sounds']} - annotations = data['annotations'] - - window_size = int(window_size_sec * sample_rate) - hop_size = int((window_size_sec - overlap_sec) * sample_rate) - min_overlap_samples = int(min_overlap_sec * sample_rate) - - windows = [] - window_idx = 0 - - for sound_id, sound in sounds.items(): - duration_samples = int(sound['duration'] * sample_rate) - num_windows = math.ceil((duration_samples - window_size) / hop_size) + 1 - - # Filter annotations for this sound - sound_events = [] - for ev in annotations: - if ev['sound_id'] == sound_id: - sound_events.append(( - int(ev['t_min'] * sample_rate), - int(ev['t_max'] * sample_rate), - ev.get('category_id', 1) - )) - - dataset = None - for dataset_name in datasets_names: - if dataset_name in sound['file_name_path']: - dataset = dataset_name - - for i in range(num_windows): - start = i * hop_size - end = start + window_size - - if end > duration_samples: - continue - - # Check overlap with any event - label = 0 - ann_overlap = 0 - for event_start, event_end, category_id in sound_events: - if event_end > start and event_start < end: - overlap = min(end, event_end) - max(start, event_start) - if overlap > min_overlap_samples: - label = category_id if multiclass else 1 - ann_overlap = overlap - break - - win = { - 'window_id': window_idx, - 'dataset': dataset, - 'sample_rate': sound["sample_rate"], - 'sound_id': sound_id, - 'start': start, - 'end': end, - 'label': label, - } - if multiclass: - win['ann_overlap'] = ann_overlap - windows.append(win) - window_idx += 1 - - return windows - - -def _build_windows_balanced( - annotation_file: str, - window_size_sec: float, - overlap_sec: float, - sample_rate: int, - datasets_names: List[str], - negative_proportion: float = 0.5, - multiclass: bool = False, - min_overlap_sec: float = 0, -) -> List[Dict]: - """ - Balanced strategy: centers windows on annotations for positives, - then samples negatives to achieve the desired class proportion. - - Args: - negative_proportion: Final proportion of negative examples in the dataset. - 0.5 means 50% negatives, 50% positives (equal amounts). - 0.7 means 70% negatives, 30% positives. - multiclass: If True, use the annotation's category_id as the label. - min_overlap_sec: Minimum overlap in seconds for negative rejection. - """ - with open(annotation_file, 'r') as f: - data = json.load(f) - sounds = {sound['id']: sound for sound in data['sounds']} - annotations = data['annotations'] - - window_size = int(window_size_sec * sample_rate) - hop_size = int((window_size_sec - overlap_sec) * sample_rate) - min_overlap_samples = int(min_overlap_sec * sample_rate) - - window_idx = 0 - positive_windows = [] - all_positive_regions = {} - - # Step 1: Extract all positive examples (annotations) - for sound_id, sound in sounds.items(): - duration_samples = int(sound['duration'] * sample_rate) - - dataset = None - for dataset_name in datasets_names: - if dataset_name in sound['file_name_path']: - dataset = dataset_name - break - - sound_events = [] - for ev in annotations: - if ev['sound_id'] == sound_id: - sound_events.append(( - int(ev['t_min'] * sample_rate), - int(ev['t_max'] * sample_rate), - ev.get('category_id', 1) - )) - - positive_regions = [] - for event_start, event_end, category_id in sound_events: - # Center the window on the annotation - annotation_center = (event_start + event_end) // 2 - win_start = annotation_center - window_size // 2 - win_end = win_start + window_size - - # Adjust if window goes beyond audio boundaries - if win_start < 0: - win_start = 0 - win_end = window_size - elif win_end > duration_samples: - win_end = duration_samples - win_start = win_end - window_size - - # Only add if we have enough samples - if win_end - win_start == window_size and win_end <= duration_samples: - win = { - 'window_id': window_idx, - 'dataset': dataset, - 'sample_rate': sound["sample_rate"], - 'sound_id': sound_id, - 'start': win_start, - 'end': win_end, - 'label': category_id if multiclass else 1, - } - if multiclass: - overlap = ( - min(win_end, event_end) - max(win_start, event_start) - ) - win['ann_overlap'] = overlap - positive_windows.append(win) - positive_regions.append((win_start, win_end)) - window_idx += 1 - - all_positive_regions[sound_id] = positive_regions - - # Step 2: Sample negative examples based on proportion - num_positives = len(positive_windows) - # Calculate negatives needed: negatives / (positives + negatives) = negative_proportion - num_negatives_needed = int(num_positives * negative_proportion / (1 - negative_proportion)) - - print(f"Positive examples found: {num_positives}") - print(f"Negative examples needed: {num_negatives_needed}") - print(f"Desired proportion - Negatives: {negative_proportion:.1%}, Positives: {(1-negative_proportion):.1%}") - - negative_windows = [] - - # Generate candidate negative windows for each sound - for sound_id, sound in sounds.items(): - if sound_id not in all_positive_regions: - continue - - duration_samples = int(sound['duration'] * sample_rate) - positive_regions = all_positive_regions[sound_id] - - dataset = None - for dataset_name in datasets_names: - if dataset_name in sound['file_name_path']: - dataset = dataset_name - break - - # Filter annotations for this sound (needed for min_overlap check) - sound_events = [] - if min_overlap_samples > 0: - for ev in annotations: - if ev['sound_id'] == sound_id: - sound_events.append(( - int(ev['t_min'] * sample_rate), - int(ev['t_max'] * sample_rate), - )) - - # Generate all possible negative windows with overlap - candidates = [] - start = 0 - while start + window_size <= duration_samples: - end = start + window_size - - # Check if this window overlaps with any positive region - is_negative = True - for pos_start, pos_end in positive_regions: - if not (end <= pos_start or start >= pos_end): - # When min_overlap_sec > 0, only reject if actual - # annotation overlap exceeds the threshold - if min_overlap_samples > 0: - for ev_start, ev_end in sound_events: - ov = min(end, ev_end) - max(start, ev_start) - if ov > min_overlap_samples: - is_negative = False - break - else: - is_negative = False - break - - if is_negative: - candidates.append((start, end)) - - start += hop_size - - for start, end in candidates: - win = { - 'window_id': None, - 'dataset': dataset, - 'sample_rate': sound["sample_rate"], - 'sound_id': sound_id, - 'start': start, - 'end': end, - 'label': 0, - } - if multiclass: - win['ann_overlap'] = 0 - negative_windows.append(win) - - # Shuffle and select the required number of negative examples - np.random.shuffle(negative_windows) - print(f"Negative examples available: {len(negative_windows)}") - selected_negatives = negative_windows[:num_negatives_needed] - - # Assign window IDs to selected negatives - for neg_win in selected_negatives: - neg_win['window_id'] = window_idx - window_idx += 1 - - print(f"Negative examples selected: {len(selected_negatives)}") - final_total = len(selected_negatives) + num_positives - print(f"Final proportion - Negatives: {len(selected_negatives)/final_total:.1%}, Positives: {num_positives/final_total:.1%}") - - return positive_windows + selected_negatives - - -def build_inference_windows( - audios_source: Union[str, List[str]], - window_size_sec: float, - overlap_sec: float, - sample_rate: int, -) -> List[Dict]: - """ - Build inference windows with fixed overlap from audio files. - - Parameters - ---------- - audios_source : str or list of str - Path to a directory of audio files, or a list of audio file paths. - window_size_sec : float - Window size in seconds. - overlap_sec : float - Overlap between consecutive windows in seconds. - sample_rate : int - Target sample rate for computing window boundaries. - - Returns - ------- - list of dict - Each dict has keys: ``window_id``, ``sound_path``, ``start``, ``end``. - """ - if librosa is None: - raise ImportError("librosa is required for build_inference_windows. Install with: pip install librosa") - - window_size = int(window_size_sec * sample_rate) - hop_size = int((window_size_sec - overlap_sec) * sample_rate) - - windows = [] - window_idx = 0 - - if isinstance(audios_source, str): - wav_files = [ - os.path.join(audios_source, f) - for f in os.listdir(audios_source) - if f.lower().endswith(('.wav', '.flac', '.mp3', '.m4a', '.aac', '.ogg')) - and not f.startswith('.') - ] - elif isinstance(audios_source, list): - wav_files = audios_source - else: - raise TypeError("audios_source must be either a folder path (str) or a list of file paths (list[str])") - - for filename in wav_files: - sound_duration = librosa.get_duration(path=filename) - duration_samples = int(sound_duration * sample_rate) - num_windows = math.ceil((duration_samples - window_size) / hop_size) + 1 - - for i in range(num_windows): - start = i * hop_size - end = start + window_size - - if end > duration_samples: - continue - - windows.append({ - 'window_id': window_idx, - 'sound_path': filename, - 'start': start, - 'end': end, - }) - window_idx += 1 - - return windows diff --git a/PytorchWildlife/data/datasets.py b/PytorchWildlife/data/datasets.py deleted file mode 100644 index 39870a626..000000000 --- a/PytorchWildlife/data/datasets.py +++ /dev/null @@ -1,204 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import os -from glob import glob -from PIL import Image, ImageFile -import numpy as np -import supervision as sv -import torch -from torch.utils.data import Dataset - -# To handle truncated images during loading -ImageFile.LOAD_TRUNCATED_IMAGES = True - -# Making the DetectionImageFolder class available for import from this module -__all__ = [ - "ClassificationImageFolder", - "DetectionImageFolder", - ] - -# Define the allowed image extensions -IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp") - -def has_file_allowed_extension(filename: str, extensions: tuple) -> bool: - """Checks if a file is an allowed extension.""" - return filename.lower().endswith(extensions if isinstance(extensions, str) else tuple(extensions)) - -def is_image_file(filename: str) -> bool: - """Checks if a file is an allowed image extension.""" - return has_file_allowed_extension(filename, IMG_EXTENSIONS) - -class ImageFolder(Dataset): - """ - A PyTorch Dataset for loading images from a specified directory. - Each item in the dataset is a tuple containing the image data, - the image's path, and the original size of the image. - """ - - def __init__(self, image_dir, transform=None): - """ - Initializes the dataset. - - Parameters: - image_dir (str): Path to the directory containing the images. - transform (callable, optional): Optional transform to be applied on the image. - """ - super(ImageFolder, self).__init__() - self.image_dir = image_dir - self.transform = transform - self.images = [os.path.join(dp, f) for dp, dn, filenames in os.walk(image_dir) for f in filenames if is_image_file(f)] # dp: directory path, dn: directory name, f: filename - - def __getitem__(self, idx) -> tuple: - """ - Retrieves an image from the dataset. - - Parameters: - idx (int): Index of the image to retrieve. - - Returns: - tuple: Contains the image data, the image's path, and its original size. - """ - pass - - def __len__(self) -> int: - """ - Returns the total number of images in the dataset. - - Returns: - int: Total number of images. - """ - return len(self.images) - -class ClassificationImageFolder(ImageFolder): - """ - A PyTorch Dataset for loading images from a specified directory. - Each item in the dataset is a tuple containing the image data, - the image's path, and the original size of the image. - """ - - def __init__(self, image_dir, transform=None): - """ - Initializes the dataset. - - Parameters: - image_dir (str): Path to the directory containing the images. - transform (callable, optional): Optional transform to be applied on the image. - """ - super(ClassificationImageFolder, self).__init__(image_dir, transform) - - def __getitem__(self, idx) -> tuple: - """ - Retrieves an image from the dataset. - - Parameters: - idx (int): Index of the image to retrieve. - - Returns: - tuple: Contains the image data, the image's path, and its original size. - """ - # Get image filename and path - img_path = self.images[idx] - - # Load and convert image to RGB - img = Image.open(img_path).convert("RGB") - - # Apply transformation if specified - if self.transform: - img = self.transform(img) - - return img, img_path - - -class DetectionImageFolder(ImageFolder): - """ - A PyTorch Dataset for loading images from a specified directory. - Each item in the dataset is a tuple containing the image data, - the image's path, and the original size of the image. - """ - - def __init__(self, image_dir, transform=None): - """ - Initializes the dataset. - - Parameters: - image_dir (str): Path to the directory containing the images. - transform (callable, optional): Optional transform to be applied on the image. - """ - super(DetectionImageFolder, self).__init__(image_dir, transform) - - def __getitem__(self, idx) -> tuple: - """ - Retrieves an image from the dataset. - - Parameters: - idx (int): Index of the image to retrieve. - - Returns: - tuple: Contains the image data, the image's path, and its original size. - """ - # Get image filename and path - img_path = self.images[idx] - - # Load and convert image to RGB - img = Image.open(img_path).convert("RGB") - img_size_ori = img.size[::-1] - - # Apply transformation if specified - if self.transform: - img = self.transform(img) - - return img, img_path, torch.tensor(img_size_ori) - - -# TODO: Under development for efficiency improvement -class DetectionCrops(Dataset): - - def __init__(self, detection_results, transform=None, path_head=None, animal_cls_id=0): - - self.detection_results = detection_results - self.transform = transform - self.path_head = path_head - self.animal_cls_id = animal_cls_id # This determines which detection class id represents animals. - self.img_ids = [] - self.xyxys = [] - - self.load_detection_results() - - def load_detection_results(self): - for det in self.detection_results: - for xyxy, det_id in zip(det["detections"].xyxy, det["detections"].class_id): - # Only run recognition on animal detections - if det_id == self.animal_cls_id: - self.img_ids.append(det["img_id"]) - self.xyxys.append(xyxy) - - def __getitem__(self, idx) -> tuple: - """ - Retrieves an image from the dataset. - - Parameters: - idx (int): Index of the image to retrieve. - - Returns: - tuple: Contains the image data and the image's path. - """ - - # Get image path and corresponding bbox xyxy for cropping - img_id = self.img_ids[idx] - xyxy = self.xyxys[idx] - - img_path = os.path.join(self.path_head, img_id) if self.path_head else img_id - - # Load and crop image with supervision - img = sv.crop_image(np.array(Image.open(img_path).convert("RGB")), - xyxy=xyxy) - - # Apply transformation if specified - if self.transform: - img = self.transform(Image.fromarray(img)) - - return img, img_path - - def __len__(self) -> int: - return len(self.img_ids) \ No newline at end of file diff --git a/PytorchWildlife/data/transforms.py b/PytorchWildlife/data/transforms.py deleted file mode 100644 index c9bceab28..000000000 --- a/PytorchWildlife/data/transforms.py +++ /dev/null @@ -1,156 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import numpy as np -import torch -from torchvision import transforms -import torchvision.transforms as T -import torch.nn.functional as F -from PIL import Image - -# Making the provided classes available for import from this module -__all__ = [ - "MegaDetector_v5_Transform", - "Classification_Inference_Transform" -] - - -def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=False, scaleFill=False, scaleup=True, stride=32) -> torch.Tensor: - """ - Resize and pad an image to a desired shape while keeping the aspect ratio unchanged. - - This function is commonly used in object detection tasks to prepare images for models like YOLOv5. - It resizes the image to fit into the new shape with the correct aspect ratio and then pads the rest. - - Args: - im (PIL.Image.Image or torch.Tensor): The input image. It can be a PIL image or a PyTorch tensor. - new_shape (tuple, optional): The target size of the image, in the form (height, width). Defaults to (640, 640). - color (tuple, optional): The color used for padding. Defaults to (114, 114, 114). - auto (bool, optional): Adjust padding to ensure the padded image dimensions are a multiple of the stride. Defaults to False. - scaleFill (bool, optional): If True, scales the image to fill the new shape, ignoring the aspect ratio. Defaults to False. - scaleup (bool, optional): Allow the function to scale up the image. Defaults to True. - stride (int, optional): The stride used in the model. The padding is adjusted to be a multiple of this stride. Defaults to 32. - - Returns: - torch.Tensor: The transformed image with padding applied. - """ - - # Convert PIL Image to Torch Tensor - - if isinstance(im, Image.Image): - im = T.ToTensor()(im) - - # Original shape - shape = im.shape[1:] # shape = [height, width] - - # New shape - if isinstance(new_shape, int): - new_shape = (new_shape, new_shape) - - # Scale ratio (new / old) and compute padding - r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) - if not scaleup: - r = min(r, 1.0) - - new_unpad = (int(round(shape[1] * r)), int(round(shape[0] * r))) - dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] - - if auto: - dw, dh = dw % stride, dh % stride - elif scaleFill: - dw, dh = 0, 0 - new_unpad = new_shape - r = new_shape[1] / shape[1], new_shape[0] / shape[0] - - dw /= 2 - dh /= 2 - - # Resize image - if shape[::-1] != new_unpad: - resize_transform = T.Resize(new_unpad[::-1], interpolation=T.InterpolationMode.BILINEAR, - antialias=False) - im = resize_transform(im) - - # Pad image - padding = (int(round(dw - 0.1)), int(round(dw + 0.1)), int(round(dh + 0.1)), int(round(dh - 0.1))) - im = F.pad(im*255.0, padding, value=114)/255.0 - - return im - -class MegaDetector_v5_Transform: - """ - A transformation class to preprocess images for the MegaDetector v5 model. - This includes resizing, transposing, and normalization operations. - This is a required transformation for the YoloV5 model. - - """ - - def __init__(self, target_size=1280, stride=32): - """ - Initializes the transform. - - Args: - target_size (int): Desired size for the image's longest side after resizing. - stride (int): Stride value for resizing. - """ - self.target_size = target_size - self.stride = stride - - def __call__(self, np_img) -> torch.Tensor: - """ - Applies the transformation on the provided image. - - Args: - np_img (np.ndarray): Input image as a numpy array or PIL Image. - - Returns: - torch.Tensor: Transformed image. - """ - # Convert the image to a PyTorch tensor and normalize it - if isinstance(np_img, np.ndarray): - np_img = np_img.transpose((2, 0, 1)) - np_img = np.ascontiguousarray(np_img) - np_img = torch.from_numpy(np_img).float() - np_img /= 255.0 - - # Resize and pad the image using a customized letterbox function. - img = letterbox(np_img, new_shape=self.target_size, stride=self.stride, auto=False) - - return img - -class Classification_Inference_Transform: - """ - A transformation class to preprocess images for classification inference. - This includes resizing, normalization, and conversion to a tensor. - """ - # Normalization constants - mean = [0.485, 0.456, 0.406] - std = [0.229, 0.224, 0.225] - - def __init__(self, target_size=224, **kwargs): - """ - Initializes the transform. - - Args: - target_size (int): Desired size for the height and width after resizing. - """ - # Define the sequence of transformations - self.trans = transforms.Compose([ - # transforms.Resize((target_size, target_size)), - transforms.Resize((target_size, target_size), **kwargs), - transforms.ToTensor(), - transforms.Normalize(self.mean, self.std) - ]) - - def __call__(self, img) -> torch.Tensor: - """ - Applies the transformation on the provided image. - - Args: - img (PIL.Image.Image): Input image in PIL format. - - Returns: - torch.Tensor: Transformed image. - """ - img = self.trans(img) - return img diff --git a/PytorchWildlife/models/__init__.py b/PytorchWildlife/models/__init__.py deleted file mode 100644 index d835b0ae8..000000000 --- a/PytorchWildlife/models/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .classification import * -from .detection import * -from .bioacoustics import * \ No newline at end of file diff --git a/PytorchWildlife/models/bioacoustics/__init__.py b/PytorchWildlife/models/bioacoustics/__init__.py deleted file mode 100644 index df082e375..000000000 --- a/PytorchWildlife/models/bioacoustics/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -"""Bioacoustics models for PytorchWildlife.""" - -from .base_bioacoustics import BaseBioacousticsClassifier -from .resnet_classifier import ResNetClassifier, load_model_from_checkpoint - -__all__ = [ - "BaseBioacousticsClassifier", - "ResNetClassifier", - "load_model_from_checkpoint", -] diff --git a/PytorchWildlife/models/bioacoustics/base_bioacoustics.py b/PytorchWildlife/models/bioacoustics/base_bioacoustics.py deleted file mode 100644 index fcd7e6f18..000000000 --- a/PytorchWildlife/models/bioacoustics/base_bioacoustics.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -"""Base bioacoustics classifier class.""" - -import torch.nn as nn - - -__all__ = ["BaseBioacousticsClassifier"] - - -class BaseBioacousticsClassifier(nn.Module): - """ - Base class for bioacoustics classifiers. - - This class provides utility methods for loading models, generating results, - and performing single and batch audio classifications on spectrograms. - """ - - # Placeholder class-level attributes to be defined in derived classes - SAMPLE_RATE = None - WINDOW_SIZE_SEC = None - N_MELS = None - CLASS_NAMES = None - - def __init__(self, weights=None, device="cpu", url=None): - """ - Initialize the base bioacoustics classifier. - - Args: - weights (str, optional): Path to model weights. Defaults to None. - device (str, optional): Device for inference. Defaults to "cpu". - url (str, optional): URL to fetch model weights. Defaults to None. - """ - super(BaseBioacousticsClassifier, self).__init__() - self.device = device - - def _load_model(self, weights=None, device="cpu", url=None): - """ - Load model weights. - - Args: - weights (str, optional): Path to model weights. Defaults to None. - device (str, optional): Device for inference. Defaults to "cpu". - url (str, optional): URL to fetch model weights. Defaults to None. - - Raises: - Exception: If weights are not provided. - """ - pass - - def results_generation(self, preds, audio_id: str, id_strip: str = None) -> dict: - """ - Generate results for classification based on model predictions. - - Args: - preds: Model predictions (logits or probabilities). - audio_id (str): Audio identifier. - id_strip (str, optional): Strip specific characters from audio_id. - - Returns: - dict: Dictionary containing audio ID, predictions, and labels. - """ - pass - - def single_audio_classification( - self, spectrogram, audio_path=None, conf_threshold=0.5, id_strip=None - ) -> dict: - """ - Perform classification on a single spectrogram. - - Args: - spectrogram: Spectrogram tensor or ndarray. - audio_path (str, optional): Audio path or identifier. - conf_threshold (float, optional): Confidence threshold. Defaults to 0.5. - id_strip (str, optional): Characters to strip from audio_id. - - Returns: - dict: Classification results. - """ - pass - - def batch_audio_classification( - self, dataloader, conf_threshold: float = 0.5, id_strip: str = None - ) -> list[dict]: - """ - Perform classification on a batch of spectrograms. - - Args: - dataloader (DataLoader): DataLoader containing spectrogram batches. - conf_threshold (float, optional): Confidence threshold. Defaults to 0.5. - id_strip (str, optional): Characters to strip from audio_id. - - Returns: - list[dict]: List of classification results for all audio samples. - """ - pass diff --git a/PytorchWildlife/models/bioacoustics/resnet_classifier.py b/PytorchWildlife/models/bioacoustics/resnet_classifier.py deleted file mode 100644 index 3ba44280b..000000000 --- a/PytorchWildlife/models/bioacoustics/resnet_classifier.py +++ /dev/null @@ -1,588 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -""" -ResNet-based classifier for bioacoustics spectrogram classification. - -Supports both binary classification (num_classes=2) and multiclass (num_classes>2). -""" - -from typing import List, Optional - -import os - -import numpy as np -import pandas as pd -import torch -import torch.nn as nn -import torchvision.models as tvm - -import pytorch_lightning as pl - -# Binary metrics -from torchmetrics.classification import ( - BinaryAccuracy, BinaryConfusionMatrix, BinaryF1Score, - BinaryPrecision, BinaryRecall, BinaryAveragePrecision, BinaryPrecisionRecallCurve -) -# Multiclass metrics -from torchmetrics.classification import ( - MulticlassAccuracy, MulticlassF1Score, MulticlassPrecision, - MulticlassRecall, MulticlassConfusionMatrix -) - -from sklearn.metrics import average_precision_score - -try: - from torch_uncertainty.losses import BCEWithLogitsLSLoss -except ImportError: - BCEWithLogitsLSLoss = None - - -__all__ = ["ResNetClassifier"] - - -class ResNetClassifier(pl.LightningModule): - """ - Unified ResNet classifier supporting both binary and multiclass classification. - - When num_classes=2: Binary mode with BCEWithLogitsLoss, sigmoid activation. - When num_classes>2: Multiclass mode with CrossEntropyLoss, softmax activation. - - Args: - num_classes: Number of output classes (2 for binary, >2 for multiclass). - in_channels: Number of input channels (1 for grayscale spectrograms). - backbone: ResNet backbone variant ("resnet18", "resnet34", "resnet50"). - lr: Learning rate. - weight_decay: Weight decay for optimizer. - label_smoothing: Label smoothing factor. - T_max: Cosine annealing T_max parameter. - batch_size: Batch size for logging. - pos_weight: Positive class weight (binary only). - conf_threshold: Confidence threshold for predictions (binary only). - freeze_backbone: Freezing strategy ("none", "all", "early", "layer1-3"). - backbone_lr_ratio: Learning rate ratio for backbone vs classifier. - class_names: List of class names for multiclass labeling. - """ - - def __init__( - self, - num_classes: int = 2, - in_channels: int = 1, - backbone: str = "resnet18", - lr: float = 3e-4, - weight_decay: float = 1e-4, - label_smoothing: float = 0.0, - T_max: int = 100, - batch_size: int = 32, - pos_weight: float = 1.0, - conf_threshold: float = 0.5, - freeze_backbone: str = "none", - backbone_lr_ratio: float = 1.0, - class_names: Optional[List[str]] = None, - ): - super().__init__() - self.save_hyperparameters() - self.is_binary = (num_classes == 2) - - # Build backbone with ImageNet weights - if backbone == "resnet18": - net = tvm.resnet18(weights=tvm.ResNet18_Weights.IMAGENET1K_V1) - elif backbone == "resnet34": - net = tvm.resnet34(weights=tvm.ResNet34_Weights.IMAGENET1K_V1) - elif backbone == "resnet50": - net = tvm.resnet50(weights=tvm.ResNet50_Weights.IMAGENET1K_V1) - else: - raise ValueError(f"Unsupported backbone: {backbone}") - - # Adapt first conv to match spectrogram channels (typically 1) - old_conv = net.conv1 - if in_channels != 3: - new_conv = nn.Conv2d( - in_channels=in_channels, - out_channels=old_conv.out_channels, - kernel_size=old_conv.kernel_size, - stride=old_conv.stride, - padding=old_conv.padding, - bias=False, - ) - with torch.no_grad(): - new_conv.weight.data = old_conv.weight.data.mean(dim=1, keepdim=True) - net.conv1 = new_conv - - # Output head depends on classification mode - in_feats = net.fc.in_features - if self.is_binary: - net.fc = nn.Linear(in_feats, 1) - else: - net.fc = nn.Linear(in_feats, num_classes) - self.net = net - - # Apply freezing strategy - self._apply_freezing_strategy() - - # Initialize loss and metrics based on mode - if self.is_binary: - self._init_binary_loss_and_metrics() - else: - self._init_multiclass_loss_and_metrics(num_classes) - - # Storage for test predictions - self.test_logits = [] - self.test_targets = [] - self.test_paths = [] - self.test_preds = [] - - # Path to the test CSV; when set, on_test_epoch_end exports - # predictions alongside the original columns. - self.test_csv_path: Optional[str] = None - # Directory where the predictions CSV is saved. When None the - # file is written next to the original test CSV. - self.predictions_dir: Optional[str] = None - # When True, test_step skips metric updates and on_test_epoch_end - # skips metric computation. Useful when the test set labels are - # in a different label space than the model (e.g. running a - # 3-class model on a 4-class test set just to export predictions). - self.predict_only: bool = False - - def _init_binary_loss_and_metrics(self): - """Initialize loss and metrics for binary classification.""" - if BCEWithLogitsLSLoss is not None: - self.criterion = BCEWithLogitsLSLoss(label_smoothing=self.hparams.label_smoothing) - else: - self.criterion = nn.BCEWithLogitsLoss() - - self.train_acc = BinaryAccuracy() - self.val_acc = BinaryAccuracy() - self.test_acc = BinaryAccuracy() - - self.train_f1 = BinaryF1Score() - self.val_f1 = BinaryF1Score() - self.test_f1 = BinaryF1Score() - - self.train_auprc = BinaryAveragePrecision() - self.val_auprc = BinaryAveragePrecision() - self.test_auprc = BinaryAveragePrecision() - - self.test_prec = BinaryPrecision() - self.test_rec = BinaryRecall() - self.test_cm = BinaryConfusionMatrix() - self.test_prcurve = BinaryPrecisionRecallCurve(thresholds=None) - - def _init_multiclass_loss_and_metrics(self, num_classes: int): - """Initialize loss and metrics for multiclass classification.""" - self.criterion = nn.CrossEntropyLoss(label_smoothing=self.hparams.label_smoothing) - - self.train_acc = MulticlassAccuracy(num_classes=num_classes, average='micro') - self.val_acc = MulticlassAccuracy(num_classes=num_classes, average='micro') - self.test_acc = MulticlassAccuracy(num_classes=num_classes, average='micro') - - self.train_f1 = MulticlassF1Score(num_classes=num_classes, average='macro') - self.val_f1 = MulticlassF1Score(num_classes=num_classes, average='macro') - self.test_f1 = MulticlassF1Score(num_classes=num_classes, average='macro') - - self.test_prec = MulticlassPrecision(num_classes=num_classes, average='macro') - self.test_rec = MulticlassRecall(num_classes=num_classes, average='macro') - self.test_cm = MulticlassConfusionMatrix(num_classes=num_classes) - - self.test_f1_per_class = MulticlassF1Score(num_classes=num_classes, average=None) - self.test_prec_per_class = MulticlassPrecision(num_classes=num_classes, average=None) - self.test_rec_per_class = MulticlassRecall(num_classes=num_classes, average=None) - - def _apply_freezing_strategy(self): - """Apply layer freezing based on freeze_backbone parameter.""" - freeze_until = None - - if self.hparams.freeze_backbone == "none": - return - elif self.hparams.freeze_backbone == "all": - freeze_until = "fc" - elif self.hparams.freeze_backbone == "early": - freeze_until = "layer3" - elif self.hparams.freeze_backbone in ["layer1", "layer2", "layer3"]: - freeze_until = self.hparams.freeze_backbone - else: - raise ValueError(f"Invalid freeze_backbone: {self.hparams.freeze_backbone}") - - if freeze_until: - trainable_params = 0 - frozen_params = 0 - freeze = True - - for name, param in self.net.named_parameters(): - if freeze_until in name: - freeze = False - param.requires_grad = not freeze - if param.requires_grad: - trainable_params += param.numel() - else: - frozen_params += param.numel() - - print(f"\n{'='*60}") - print(f"Freezing strategy: {self.hparams.freeze_backbone}") - print(f"Frozen parameters: {frozen_params:,}") - print(f"Trainable parameters: {trainable_params:,}") - print(f"Frozen ratio: {frozen_params/(frozen_params+trainable_params)*100:.1f}%") - print(f"{'='*60}\n") - - def forward(self, x): - return self.net(x) - - def _compute_loss(self, logits, y): - """Compute loss handling both regular and MixUp batches.""" - # Import here to avoid circular dependency - from PytorchWildlife.data.bioacoustics.bioacoustics_datasets import mixup_criterion - - if self.is_binary: - logits = logits.squeeze(1) - if isinstance(y, dict) and y.get('is_mixed', False): - loss = mixup_criterion(self.criterion, logits, y) - targets = y['original_labels'].int() - else: - targets = y['original_labels'].int() if isinstance(y, dict) else y.int() - loss = self.criterion(logits, targets.float()) - preds = (logits > 0).int() - else: - targets = y.long() - loss = self.criterion(logits, targets) - preds = torch.argmax(logits, dim=1) - - return loss, preds, targets, logits - - def training_step(self, batch, batch_idx): - x, y, path = batch - logits = self(x) - loss, preds, targets, logits_processed = self._compute_loss(logits, y) - - self.train_acc.update(preds, targets) - self.train_f1.update(preds, targets) - if self.is_binary: - self.train_auprc.update(logits_processed, targets) - - self.log("train/loss", loss, batch_size=self.hparams.batch_size, prog_bar=True, on_step=False, on_epoch=True) - return loss - - def on_train_epoch_end(self): - acc = self.train_acc.compute() - f1 = self.train_f1.compute() - self.log("train/acc", acc, prog_bar=True) - self.log("train/f1", f1, prog_bar=True) - - if self.is_binary: - auprc = self.train_auprc.compute() - self.log("train/auprc", auprc, prog_bar=True) - self.train_auprc.reset() - - self.train_acc.reset() - self.train_f1.reset() - - def validation_step(self, batch, batch_idx): - x, y, path = batch - logits = self(x) - - if self.is_binary: - logits = logits.squeeze(1) - loss = self.criterion(logits, y.float()) - preds = (logits > 0).int() - targets = y.int() - else: - targets = y.long() - loss = self.criterion(logits, targets) - preds = torch.argmax(logits, dim=1) - - self.val_acc.update(preds, targets) - self.val_f1.update(preds, targets) - if self.is_binary: - self.val_auprc.update(logits, targets) - - self.log("val/loss", loss, batch_size=self.hparams.batch_size, prog_bar=True, on_step=False, on_epoch=True) - - def on_validation_epoch_end(self): - acc = self.val_acc.compute() - f1 = self.val_f1.compute() - self.log("val/acc", acc, prog_bar=True) - self.log("val/f1", f1, prog_bar=True) - - if self.is_binary: - auprc = self.val_auprc.compute() - self.log("val/auprc", auprc, prog_bar=True) - self.val_auprc.reset() - - self.val_acc.reset() - self.val_f1.reset() - - def test_step(self, batch, batch_idx): - x, y, path = batch - logits = self(x) - - if hasattr(self, "temperature"): - logits = logits / self.temperature - - if self.is_binary: - logits = logits.squeeze(1) - prob = torch.sigmoid(logits) - preds = (prob > self.hparams.conf_threshold).int() - targets = y.int() - else: - targets = y.long() - preds = torch.argmax(logits, dim=1) - - self.test_logits.append(logits.detach().cpu()) - self.test_targets.append(targets.detach().cpu()) - self.test_paths.extend(path) - self.test_preds.append(preds.detach().cpu()) - - if not self.predict_only: - if self.is_binary: - loss = self.criterion(logits, y.float()) - else: - loss = self.criterion(logits, targets) - - self.test_acc.update(preds, targets) - self.test_f1.update(preds, targets) - self.test_prec.update(preds, targets) - self.test_rec.update(preds, targets) - self.test_cm.update(preds, targets) - - if self.is_binary: - self.test_auprc.update(logits, targets) - self.test_prcurve.update(logits, targets) - else: - self.test_f1_per_class.update(preds, targets) - self.test_prec_per_class.update(preds, targets) - self.test_rec_per_class.update(preds, targets) - - self.log("test/loss", loss, batch_size=self.hparams.batch_size, on_step=False, on_epoch=True) - - def on_test_epoch_end(self): - if self.predict_only: - self._export_test_predictions() - self._reset_test_state() - return - if self.is_binary: - self._on_test_epoch_end_binary() - else: - self._on_test_epoch_end_multiclass() - - def _on_test_epoch_end_binary(self): - """Test epoch end processing for binary classification.""" - f1 = self.test_f1.compute().item() - auprc = self.test_auprc.compute().item() - prec = self.test_prec.compute().item() - rec = self.test_rec.compute().item() - - self.log("test/f1", f1) - self.log("test/auprc", auprc) - self.log("test/precision", prec) - self.log("test/recall", rec) - - cm = self.test_cm.compute().cpu().numpy() - acc_neg = cm[0, 0] / cm[0, :].sum() if cm[0, :].sum() > 0 else 0.0 - acc_pos = cm[1, 1] / cm[1, :].sum() if cm[1, :].sum() > 0 else 0.0 - self.log("test/acc_neg", acc_neg) - self.log("test/acc_pos", acc_pos) - - self._export_test_predictions() - self._reset_test_state() - - def _on_test_epoch_end_multiclass(self): - """Test epoch end processing for multiclass classification.""" - num_classes = self.hparams.num_classes - - logits = torch.cat(self.test_logits, dim=0) - targets = torch.cat(self.test_targets, dim=0).numpy() - - if hasattr(self, "temperature"): - temp = self.temperature.cpu() if self.temperature.is_cuda else self.temperature - probs = torch.softmax(logits / temp, dim=1).numpy() - else: - probs = torch.softmax(logits, dim=1).numpy() - - acc = self.test_acc.compute().item() - f1 = self.test_f1.compute().item() - prec = self.test_prec.compute().item() - rec = self.test_rec.compute().item() - - macro_ap = np.nan - if len(np.unique(targets)) == num_classes: - macro_ap = average_precision_score( - y_true=pd.get_dummies(targets), - y_score=probs, - average="macro" - ) - - self.log_dict({ - "test/acc_micro": acc, - "test/f1": f1, - "test/precision": prec, - "test/recall": rec, - "test/macro_average_precision": macro_ap, - }) - - self._export_test_predictions() - self._reset_test_state() - - def _export_test_predictions(self): - """Export per-sample predictions to CSV alongside original test data. - - Reads the CSV at ``self.test_csv_path``, appends prediction, - probability, confidence and prediction_type columns, and writes - a new file with the suffix ``_with_predictions``. - - Column order is arranged so that ``label`` appears right before - ``prediction``. - """ - if self.test_csv_path is None: - return - - logits = torch.cat(self.test_logits, dim=0) - targets = torch.cat(self.test_targets, dim=0).numpy() - preds = torch.cat(self.test_preds, dim=0).numpy() - - if self.is_binary: - probs = torch.sigmoid(logits).numpy() - confidence = np.abs(probs - 0.5) * 2 - else: - probs = torch.softmax(logits, dim=1).numpy() - confidence = probs.max(axis=1) - - df = pd.read_csv(self.test_csv_path) - - # Convert start/end from samples to mm:ss format - for col in ("start", "end"): - if col in df.columns: - sr_col = "sample_rate" if "sample_rate" in df.columns else None - if sr_col is not None: - secs = df[col] / df[sr_col] - else: - secs = df[col] - df[col] = secs.apply(lambda s: f"{int(s // 60):02d}:{int(s % 60):02d}") - - # Determine the label column present in the CSV - label_col = "label" - for candidate in ("label", "y", "target"): - if candidate in df.columns: - label_col = candidate - break - - # Build new columns in desired order: ..., label, prediction, probability/probs, confidence, prediction_type - new_cols = [c for c in df.columns if c != label_col] - insert_pos = len(new_cols) - new_cols.insert(insert_pos, label_col) - - df["prediction"] = preds - - if self.is_binary: - df["probability"] = probs - df["confidence"] = confidence - new_cols += ["prediction", "probability", "confidence"] - else: - class_names = self.hparams.get("class_names") or [ - f"class_{i}" for i in range(self.hparams.num_classes) - ] - new_cols.append("prediction") - for i, name in enumerate(class_names): - col = name.replace(" ", "_") + "_prob" - df[col] = probs[:, i] - new_cols.append(col) - df["confidence"] = confidence - new_cols.append("confidence") - - # Classify each prediction as TP, TN, FP or FN - if self.is_binary: - conditions = [ - (targets == 1) & (preds == 1), - (targets == 0) & (preds == 0), - (targets == 0) & (preds == 1), - (targets == 1) & (preds == 0), - ] - labels = ["TP", "TN", "FP", "FN"] - df["prediction_type"] = np.select(conditions, labels, default="") - else: - df["prediction_type"] = np.where( - targets == preds, "Correct", "Incorrect" - ) - new_cols.append("prediction_type") - - df = df[new_cols] - - base_name = os.path.basename(self.test_csv_path) - name, ext = os.path.splitext(base_name) - out_name = f"{name}_with_predictions{ext}" - - if self.predictions_dir is not None: - os.makedirs(self.predictions_dir, exist_ok=True) - output_path = os.path.join(self.predictions_dir, out_name) - else: - output_path = os.path.join(os.path.dirname(self.test_csv_path), out_name) - df.to_csv(output_path, index=False) - print(f"Test predictions saved to: {output_path}") - - def _reset_test_state(self): - """Reset test metrics and storage.""" - self.test_cm.reset() - self.test_f1.reset() - self.test_acc.reset() - self.test_prec.reset() - self.test_rec.reset() - - if self.is_binary: - self.test_auprc.reset() - self.test_prcurve.reset() - else: - self.test_f1_per_class.reset() - self.test_prec_per_class.reset() - self.test_rec_per_class.reset() - - self.test_logits.clear() - self.test_targets.clear() - self.test_paths.clear() - self.test_preds.clear() - - def configure_optimizers(self): - if self.hparams.backbone_lr_ratio != 1.0: - backbone_params = [] - classifier_params = [] - - for name, param in self.net.named_parameters(): - if not param.requires_grad: - continue - if 'fc' in name: - classifier_params.append(param) - else: - backbone_params.append(param) - - backbone_lr = self.hparams.lr * self.hparams.backbone_lr_ratio - classifier_lr = self.hparams.lr - - param_groups = [] - if backbone_params: - param_groups.append({'params': backbone_params, 'lr': backbone_lr}) - if classifier_params: - param_groups.append({'params': classifier_params, 'lr': classifier_lr}) - - print(f"Using discriminative LR: backbone={backbone_lr:.2e}, classifier={classifier_lr:.2e}") - opt = torch.optim.AdamW(param_groups, weight_decay=self.hparams.weight_decay) - else: - opt = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay) - - sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=self.hparams.T_max) - return {"optimizer": opt, "lr_scheduler": {"scheduler": sch, "interval": "epoch"}} - - -def load_model_from_checkpoint(checkpoint_path: str, device: str = "cuda") -> ResNetClassifier: - """Load a trained :class:`ResNetClassifier` from a Lightning checkpoint. - - The model is set to eval mode and frozen (no gradients). - - Args: - checkpoint_path: Path to the ``.ckpt`` file. - device: Target device (default ``"cuda"``). - - Returns: - The loaded model on *device*, ready for inference. - """ - print(f"Loading model from checkpoint: {checkpoint_path}") - model = ResNetClassifier.load_from_checkpoint(checkpoint_path) - model.eval() - model.freeze() - return model.to(device) diff --git a/PytorchWildlife/models/classification/__init__.py b/PytorchWildlife/models/classification/__init__.py deleted file mode 100644 index e450ab78b..000000000 --- a/PytorchWildlife/models/classification/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .resnet_base import * -from .timm_base import * -from .base_classifier import * \ No newline at end of file diff --git a/PytorchWildlife/models/classification/base_classifier.py b/PytorchWildlife/models/classification/base_classifier.py deleted file mode 100644 index c3a60ed36..000000000 --- a/PytorchWildlife/models/classification/base_classifier.py +++ /dev/null @@ -1,28 +0,0 @@ - -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import torch.nn as nn - -# Making the PlainResNetInference class available for import from this module -__all__ = ["BaseClassifierInference"] - -class BaseClassifierInference(nn.Module): - """ - Inference module for the PlainResNet Classifier. - """ - def __init__(self): - super(BaseClassifierInference, self).__init__() - pass - - def results_generation(self): - pass - - def forward(self): - pass - - def single_image_classification(self): - pass - - def batch_image_classification(self): - pass diff --git a/PytorchWildlife/models/classification/resnet_base/__init__.py b/PytorchWildlife/models/classification/resnet_base/__init__.py deleted file mode 100644 index 10c6d0ba2..000000000 --- a/PytorchWildlife/models/classification/resnet_base/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .base_classifier import * -from .opossum import * -from .amazon import * -from .serengeti import * -from .custom_weights import * \ No newline at end of file diff --git a/PytorchWildlife/models/classification/resnet_base/amazon.py b/PytorchWildlife/models/classification/resnet_base/amazon.py deleted file mode 100644 index 8b228b04e..000000000 --- a/PytorchWildlife/models/classification/resnet_base/amazon.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import torch -from .base_classifier import PlainResNetInference - -__all__ = [ - "AI4GAmazonRainforest" -] - - -class AI4GAmazonRainforest(PlainResNetInference): - """ - Amazon Ranforest Animal Classifier that inherits from PlainResNetInference. - This classifier is specialized for recognizing 36 different animals in the Amazon Rainforest. - """ - - # Image size for the Opossum classifier - IMAGE_SIZE = 224 - - # Class names for prediction - CLASS_NAMES = { - 0: 'Dasyprocta', - 1: 'Bos', - 2: 'Pecari', - 3: 'Mazama', - 4: 'Cuniculus', - 5: 'Leptotila', - 6: 'Human', - 7: 'Aramides', - 8: 'Tinamus', - 9: 'Eira', - 10: 'Crax', - 11: 'Procyon', - 12: 'Capra', - 13: 'Dasypus', - 14: 'Sciurus', - 15: 'Crypturellus', - 16: 'Tamandua', - 17: 'Proechimys', - 18: 'Leopardus', - 19: 'Equus', - 20: 'Columbina', - 21: 'Nyctidromus', - 22: 'Ortalis', - 23: 'Emballonura', - 24: 'Odontophorus', - 25: 'Geotrygon', - 26: 'Metachirus', - 27: 'Catharus', - 28: 'Cerdocyon', - 29: 'Momotus', - 30: 'Tapirus', - 31: 'Canis', - 32: 'Furnarius', - 33: 'Didelphis', - 34: 'Sylvilagus', - 35: 'Unknown' - } - - def __init__(self, weights=None, device="cpu", pretrained=True, version="v2"): - """ - Initialize the Amazon animal Classifier. - - Args: - weights (str, optional): Path to the model weights. Defaults to None. - device (str, optional): Device for model inference. Defaults to "cpu". - pretrained (bool, optional): Whether to use pretrained weights. Defaults to True. - version (str, optional): Version of the model to load. Default is 'v2'. - """ - - # If pretrained, use the provided URL to fetch the weights - if pretrained: - if version == 'v1': - url = "https://zenodo.org/records/10042023/files/AI4GAmazonClassification_v0.0.0.ckpt?download=1" - elif version == 'v2': - url = "https://zenodo.org/records/14252214/files/AI4GAmazonDeforestationv2?download=1" - else: - url = None - - super(AI4GAmazonRainforest, self).__init__(weights=weights, device=device, - num_cls=36, num_layers=50, url=url) - - def results_generation(self, logits: torch.Tensor, img_ids: list[str], id_strip: str = None) -> list[dict]: - """ - Generate results for classification. - - Args: - logits (torch.Tensor): Output tensor from the model. - img_ids (str): Image identifier. - id_strip (str): stiping string for better image id saving. - - Returns: - dict: Dictionary containing image ID, prediction, and confidence score. - """ - - probs = torch.softmax(logits, dim=1) - preds = probs.argmax(dim=1) - confs = probs.max(dim=1)[0] - confidences = probs[0].tolist() - result = [[self.CLASS_NAMES[i], confidence] for i, confidence in enumerate(confidences)] - - results = [] - for pred, img_id, conf in zip(preds, img_ids, confs): - r = {"img_id": str(img_id).strip(id_strip)} - r["prediction"] = self.CLASS_NAMES[pred.item()] - r["class_id"] = pred.item() - r["confidence"] = conf.item() - r["all_confidences"] = result - results.append(r) - - return results diff --git a/PytorchWildlife/models/classification/resnet_base/base_classifier.py b/PytorchWildlife/models/classification/resnet_base/base_classifier.py deleted file mode 100644 index a7448c182..000000000 --- a/PytorchWildlife/models/classification/resnet_base/base_classifier.py +++ /dev/null @@ -1,190 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import numpy as np -from PIL import Image -from tqdm import tqdm -from collections import OrderedDict - -import torch -import torch.nn as nn -from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet -from torch.hub import load_state_dict_from_url -from torch.utils.data import DataLoader - -from ..base_classifier import BaseClassifierInference -from ....data import transforms as pw_trans -from ....data import datasets as pw_data - -# Making the PlainResNetInference class available for import from this module -__all__ = ["PlainResNetInference"] - - -class ResNetBackbone(ResNet): - """ - Custom ResNet Backbone that extracts features from input images. - """ - def _forward_impl(self, x): - # Following the ResNet structure to extract features - x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) - x = self.maxpool(x) - - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) - - x = self.avgpool(x) - x = torch.flatten(x, 1) - return x - - -class PlainResNetClassifier(nn.Module): - """ - Basic ResNet Classifier that uses a custom ResNet backbone. - """ - name = "PlainResNetClassifier" - - def __init__(self, num_cls=1, num_layers=50): - super(PlainResNetClassifier, self).__init__() - self.num_cls = num_cls - self.num_layers = num_layers - self.feature = None - self.classifier = None - self.criterion_cls = None - # Initialize the network and weights - self.setup_net() - - def setup_net(self): - """ - Set up the ResNet classifier according to the specified number of layers. - """ - kwargs = {} - - if self.num_layers == 18: - block = BasicBlock - layers = [2, 2, 2, 2] - # ... [Missing weight URL definition for ResNet18] - elif self.num_layers == 50: - block = Bottleneck - layers = [3, 4, 6, 3] - # ... [Missing weight URL definition for ResNet50] - else: - raise Exception("ResNet Type not supported.") - - self.feature = ResNetBackbone(block, layers, **kwargs) - self.classifier = nn.Linear(512 * block.expansion, self.num_cls) - - def setup_criteria(self): - """ - Setup the criterion for classification. - """ - self.criterion_cls = nn.CrossEntropyLoss() - - def feat_init(self): - """ - Initialize the features using pretrained weights. - """ - init_weights = self.pretrained_weights.get_state_dict(progress=True) - init_weights = OrderedDict({k.replace("module.", "").replace("feature.", ""): init_weights[k] - for k in init_weights}) - self.feature.load_state_dict(init_weights, strict=False) - # Print missing and unused keys for debugging purposes - load_keys = set(init_weights.keys()) - self_keys = set(self.feature.state_dict().keys()) - missing_keys = self_keys - load_keys - unused_keys = load_keys - self_keys - print("missing keys:", sorted(list(missing_keys))) - print("unused_keys:", sorted(list(unused_keys))) - - -class PlainResNetInference(BaseClassifierInference): - """ - Inference module for the PlainResNet Classifier. - """ - IMAGE_SIZE = None - def __init__(self, num_cls=36, num_layers=50, weights=None, device="cpu", url=None, transform=None): - super(PlainResNetInference, self).__init__() - self.device = device - self.net = PlainResNetClassifier(num_cls=num_cls, num_layers=num_layers) - if weights: - clf_weights = torch.load(weights, map_location=torch.device(self.device)) - elif url: - clf_weights = load_state_dict_from_url(url, map_location=torch.device(self.device)) - else: - raise Exception("Need weights for inference.") - self.load_state_dict(clf_weights["state_dict"], strict=True) - self.eval() - self.net.to(self.device) - - if transform: - self.transform = transform - else: - self.transform = pw_trans.Classification_Inference_Transform(target_size=self.IMAGE_SIZE) - - def results_generation(self, logits: torch.Tensor, img_ids: list[str], id_strip: str = None) -> list[dict]: - """ - Process logits to produce final results. - - Args: - logits (torch.Tensor): Logits from the network. - img_ids (list[str]): List of image paths. - id_strip (str): Stripping string for better image ID saving. - - Returns: - list[dict]: List of dictionaries containing the results. - """ - pass - - def forward(self, img): - feats = self.net.feature(img) - logits = self.net.classifier(feats) - return logits - - def single_image_classification(self, img, img_id=None, id_strip=None): - if type(img) == str: - img = Image.open(img) - else: - img = Image.fromarray(img) - img = self.transform(img) - logits = self.forward(img.unsqueeze(0).to(self.device)) - return self.results_generation(logits.cpu(), [img_id], id_strip=id_strip)[0] - - def batch_image_classification(self, data_path=None, det_results=None, id_strip=None): - """ - Process a batch of images for classification. - """ - - if data_path: - dataset = pw_data.ClassificationImageFolder( - data_path, - transform=self.transform, - ) - elif det_results: - dataset = pw_data.DetectionCrops( - det_results, - transform=self.transform, - path_head='.' - ) - else: - raise Exception("Need data for inference.") - - dataloader = DataLoader(dataset, batch_size=32, shuffle=False, - pin_memory=True, num_workers=4, drop_last=False) - total_logits = [] - total_paths = [] - - with tqdm(total=len(dataloader)) as pbar: - for batch in dataloader: - imgs, paths = batch - imgs = imgs.to(self.device) - total_logits.append(self.forward(imgs)) - total_paths.append(paths) - pbar.update(1) - - total_logits = torch.cat(total_logits, dim=0).cpu() - total_paths = np.concatenate(total_paths, axis=0) - - return self.results_generation(total_logits, total_paths, id_strip=id_strip) diff --git a/PytorchWildlife/models/classification/resnet_base/custom_weights.py b/PytorchWildlife/models/classification/resnet_base/custom_weights.py deleted file mode 100644 index 7247abcf5..000000000 --- a/PytorchWildlife/models/classification/resnet_base/custom_weights.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import torch -from .base_classifier import PlainResNetInference - -__all__ = [ - "CustomWeights" -] - - -class CustomWeights(PlainResNetInference): - """ - Custom Weight Classifier that inherits from PlainResNetInference. - This classifier can load any model that was based on the PytorchWildlife finetuning tool. - """ - - # Image size for the classifier - IMAGE_SIZE = 224 - - - def __init__(self, weights=None, class_names=None, device="cpu"): - """ - Initialize the CustomWeights Classifier. - - Args: - weights (str, optional): Path to the model weights. Defaults to None. - class_names (list[str]): List of class names for the classifier. - device (str, optional): Device for model inference. Defaults to "cpu". - """ - self.CLASS_NAMES = class_names - self.num_cls = len(self.CLASS_NAMES) - super(CustomWeights, self).__init__(weights=weights, device=device, - num_cls=self.num_cls, num_layers=50, url=None) - - def results_generation(self, logits: torch.Tensor, img_ids: list[str], id_strip: str = None) -> list[dict]: - """ - Generate results for classification. - - Args: - logits (torch.Tensor): Output tensor from the model. - img_ids (list[str]): List of image identifiers. - id_strip (str): Stripping string for better image ID saving. - - Returns: - list[dict]: List of dictionaries containing image ID, prediction, and confidence score. - """ - - probs = torch.softmax(logits, dim=1) - preds = probs.argmax(dim=1) - confs = probs.max(dim=1)[0] - confidences = probs[0].tolist() - result = [[self.CLASS_NAMES[i], confidence] for i, confidence in enumerate(confidences)] - - results = [] - for pred, img_id, conf in zip(preds, img_ids, confs): - r = {"img_id": str(img_id).strip(id_strip)} - r["prediction"] = self.CLASS_NAMES[pred.item()] - r["class_id"] = pred.item() - r["confidence"] = conf.item() - r["all_confidences"] = result - results.append(r) - - return results diff --git a/PytorchWildlife/models/classification/resnet_base/opossum.py b/PytorchWildlife/models/classification/resnet_base/opossum.py deleted file mode 100644 index 453533394..000000000 --- a/PytorchWildlife/models/classification/resnet_base/opossum.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import torch -from .base_classifier import PlainResNetInference - -__all__ = [ - "AI4GOpossum" -] - - -class AI4GOpossum(PlainResNetInference): - """ - Opossum Classifier that inherits from PlainResNetInference. - This classifier is specialized for distinguishing between Opossums and Non-opossums. - """ - - # Image size for the Opossum classifier - IMAGE_SIZE = 224 - - # Class names for prediction - CLASS_NAMES = { - 0: "Non-opossum", - 1: "Opossum" - } - - def __init__(self, weights=None, device="cpu", pretrained=True): - """ - Initialize the Opossum Classifier. - - Args: - weights (str, optional): Path to the model weights. Defaults to None. - device (str, optional): Device for model inference. Defaults to "cpu". - pretrained (bool, optional): Whether to use pretrained weights. Defaults to True. - """ - - # If pretrained, use the provided URL to fetch the weights - if pretrained: - url = "https://zenodo.org/records/10023414/files/OpossumClassification_v0.0.0.ckpt?download=1" - else: - url = None - - super(AI4GOpossum, self).__init__(weights=weights, device=device, - num_cls=1, num_layers=50, url=url) - - def results_generation(self, logits: torch.Tensor, img_ids: list[str], id_strip: str = None) -> list[dict]: - """ - Generate results for classification. - - Args: - logits (torch.Tensor): Output tensor from the model. - img_ids (list): List of image identifier. - id_strip (str): stiping string for better image id saving. - - Returns: - dict: Dictionary containing image ID, prediction, and confidence score. - """ - - probs = torch.sigmoid(logits) - preds = (probs > 0.5).squeeze(1).numpy().astype(int) - - results = [] - for pred, img_id, prob in zip(preds, img_ids, probs): - r = {"img_id": str(img_id).strip(id_strip)} - r["prediction"] = self.CLASS_NAMES[pred] - r["class_id"] = pred - r["confidence"] = prob.item() if pred == 1 else (1 - prob.item()) - results.append(r) - - return results diff --git a/PytorchWildlife/models/classification/resnet_base/serengeti.py b/PytorchWildlife/models/classification/resnet_base/serengeti.py deleted file mode 100644 index e0820a02f..000000000 --- a/PytorchWildlife/models/classification/resnet_base/serengeti.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import torch -from .base_classifier import PlainResNetInference - -__all__ = [ - "AI4GSnapshotSerengeti" -] - - -class AI4GSnapshotSerengeti(PlainResNetInference): - """ - Snapshot Serengeti Animal Classifier that inherits from PlainResNetInference. - This classifier is specialized for recognizing 9 different animals and has 1 'other' class. - """ - - # Image size for the Opossum classifier - IMAGE_SIZE = 224 - - # Class names for prediction - CLASS_NAMES = { - 0: 'wildebeest', - 1: 'guineafowl', - 2: 'zebra', - 3: 'buffalo', - 4: 'gazellethomsons', - 5: 'gazellegrants', - 6: 'warthog', - 7: 'impala', - 8: 'hyenaspotted', - 9: 'other' - } - - def __init__(self, weights=None, device="cpu", pretrained=True): - """ - Initialize the Amazon animal Classifier. - - Args: - weights (str, optional): Path to the model weights. Defaults to None. - device (str, optional): Device for model inference. Defaults to "cpu". - pretrained (bool, optional): Whether to use pretrained weights. Defaults to True. - """ - - # If pretrained, use the provided URL to fetch the weights - if pretrained: - url = "https://zenodo.org/records/10456813/files/AI4GSnapshotSerengeti.ckpt?download=1" - else: - url = None - - super(AI4GSnapshotSerengeti, self).__init__(weights=weights, device=device, - num_cls=10, num_layers=18, url=url) - - def results_generation(self, logits: torch.Tensor, img_ids: list[str], id_strip: str = None) -> list[dict]: - """ - Generate results for classification. - - Args: - logits (torch.Tensor): Output tensor from the model. - img_ids (str): Image identifier. - id_strip (str): stiping string for better image id saving. - - Returns: - dict: Dictionary containing image ID, prediction, and confidence score. - """ - - probs = torch.softmax(logits, dim=1) - preds = probs.argmax(dim=1) - confs = probs.max(dim=1)[0] - confidences = probs[0].tolist() - result = [[self.CLASS_NAMES[i], confidence] for i, confidence in enumerate(confidences)] - - results = [] - for pred, img_id, conf in zip(preds, img_ids, confs): - r = {"img_id": str(img_id).strip(id_strip)} - r["prediction"] = self.CLASS_NAMES[pred.item()] - r["class_id"] = pred.item() - r["confidence"] = conf.item() - r["all_confidences"] = result - results.append(r) - - return results diff --git a/PytorchWildlife/models/classification/timm_base/DFNE.py b/PytorchWildlife/models/classification/timm_base/DFNE.py deleted file mode 100644 index 1cb6d71b0..000000000 --- a/PytorchWildlife/models/classification/timm_base/DFNE.py +++ /dev/null @@ -1,56 +0,0 @@ -""" -This is a Pytorch-Wildlife loader for the Deepfaune-New-England classifier. -The original model is available at: https://code.usgs.gov/vtcfwru/deepfaune-new-england/-/tree/main?ref_type=heads -Licence: CC0 1.0 Universal -Copyright USGS 2024 -laurence.clarfeld@uvm.edu -""" - -# Import libraries - -from .base_classifier import TIMM_BaseClassifierInference - -__all__ = [ - "DFNE" -] - -class DFNE(TIMM_BaseClassifierInference): - """ - Base detector class for dinov2 classifier. This class provides utility methods - for loading the model, performing single and batch image classifications, and - formatting results. Make sure the appropriate file for the model weights has been - downloaded to the "models" folder before running DFNE. - """ - BACKBONE = "vit_large_patch14_dinov2.lvd142m" - MODEL_NAME = "dfne_weights_v1_0.pth" - IMAGE_SIZE = 182 - CLASS_NAMES = { - 0: "American Marten", - 1: "Bird sp.", - 2: "Black Bear", - 3: "Bobcat", - 4: "Coyote", - 5: "Domestic Cat", - 6: "Domestic Cow", - 7: "Domestic Dog", - 8: "Fisher", - 9: "Gray Fox", - 10: "Gray Squirrel", - 11: "Human", - 12: "Moose", - 13: "Mouse sp.", - 14: "Opossum", - 15: "Raccoon", - 16: "Red Fox", - 17: "Red Squirrel", - 18: "Skunk", - 19: "Snowshoe Hare", - 20: "White-tailed Deer", - 21: "Wild Boar", - 22: "Wild Turkey", - 23: "no-species" - } - - def __init__(self, weights=None, device="cpu", transform=None): - url = 'https://prod-is-usgs-sb-prod-publish.s3.amazonaws.com/67ae17fcd34e3f09c0e0f002/dfne_weights_v1_0.pth' - super(DFNE, self).__init__(weights=weights, device=device, url=url, transform=transform, weights_key='model_state_dict') \ No newline at end of file diff --git a/PytorchWildlife/models/classification/timm_base/Deepfaune.py b/PytorchWildlife/models/classification/timm_base/Deepfaune.py deleted file mode 100644 index a024729b7..000000000 --- a/PytorchWildlife/models/classification/timm_base/Deepfaune.py +++ /dev/null @@ -1,47 +0,0 @@ -""" -This is a Pytorch-Wildlife loader for the Deepfaune classifier. -The original Deepfaune model is available at: https://www.deepfaune.cnrs.fr/en/ -Licence: CC BY-SA 4.0 -Copyright CNRS 2024 -simon.chamaille@cefe.cnrs.fr; vincent.miele@univ-lyon1.fr -""" - -# Import libraries - -from torchvision.transforms.functional import InterpolationMode -from .base_classifier import TIMM_BaseClassifierInference -from ....data import transforms as pw_trans - -__all__ = [ - "DeepfauneClassifier" -] - -class DeepfauneClassifier(TIMM_BaseClassifierInference): - """ - Base detector class for dinov2 classifier. This class provides utility methods - for loading the model, performing single and batch image classifications, and - formatting results. Make sure the appropriate file for the model weights has been - downloaded to the "models" folder before running DFNE. - """ - BACKBONE = "vit_large_patch14_dinov2.lvd142m" - MODEL_NAME = "deepfaune-vit_large_patch14_dinov2.lvd142m.v3.pt" - IMAGE_SIZE = 182 - CLASS_NAMES={ - 'fr': ['bison', 'blaireau', 'bouquetin', 'castor', 'cerf', 'chamois', 'chat', 'chevre', 'chevreuil', 'chien', 'daim', 'ecureuil', 'elan', 'equide', 'genette', 'glouton', 'herisson', 'lagomorphe', 'loup', 'loutre', 'lynx', 'marmotte', 'micromammifere', 'mouflon', 'mouton', 'mustelide', 'oiseau', 'ours', 'ragondin', 'raton laveur', 'renard', 'renne', 'sanglier', 'vache'], - 'en': ['bison', 'badger', 'ibex', 'beaver', 'red deer', 'chamois', 'cat', 'goat', 'roe deer', 'dog', 'fallow deer', 'squirrel', 'moose', 'equid', 'genet', 'wolverine', 'hedgehog', 'lagomorph', 'wolf', 'otter', 'lynx', 'marmot', 'micromammal', 'mouflon', 'sheep', 'mustelid', 'bird', 'bear', 'nutria', 'raccoon', 'fox', 'reindeer', 'wild boar', 'cow'], - 'it': ['bisonte', 'tasso', 'stambecco', 'castoro', 'cervo', 'camoscio', 'gatto', 'capra', 'capriolo', 'cane', 'daino', 'scoiattolo', 'alce', 'equide', 'genetta', 'ghiottone', 'riccio', 'lagomorfo', 'lupo', 'lontra', 'lince', 'marmotta', 'micromammifero', 'muflone', 'pecora', 'mustelide', 'uccello', 'orso', 'nutria', 'procione', 'volpe', 'renna', 'cinghiale', 'mucca'], - 'de': ['Bison', 'Dachs', 'Steinbock', 'Biber', 'Rothirsch', 'Gämse', 'Katze', 'Ziege', 'Rehwild', 'Hund', 'Damwild', 'Eichhörnchen', 'Elch', 'Equide', 'Ginsterkatze', 'Vielfraß', 'Igel', 'Lagomorpha', 'Wolf', 'Otter', 'Luchs', 'Murmeltier', 'Kleinsäuger', 'Mufflon', 'Schaf', 'Marder', 'Vogel', 'Bär', 'Nutria', 'Waschbär', 'Fuchs', 'Rentier', 'Wildschwein', 'Kuh'], - } - - - def __init__(self, weights=None, device="cpu", transform=None, class_name_lang='en'): - url = 'https://pbil.univ-lyon1.fr/software/download/deepfaune/v1.3/deepfaune-vit_large_patch14_dinov2.lvd142m.v3.pt' - self.CLASS_NAMES = {i: c for i, c in enumerate(self.CLASS_NAMES[class_name_lang])} - if transform is None: - transform = pw_trans.Classification_Inference_Transform(target_size=self.IMAGE_SIZE, - interpolation=InterpolationMode.BICUBIC, - max_size=None, - antialias=None) - super(DeepfauneClassifier, self).__init__(weights=weights, device=device, url=url, transform=transform, - weights_key='state_dict', weights_prefix='base_model.') - \ No newline at end of file diff --git a/PytorchWildlife/models/classification/timm_base/__init__.py b/PytorchWildlife/models/classification/timm_base/__init__.py deleted file mode 100644 index 1938f933a..000000000 --- a/PytorchWildlife/models/classification/timm_base/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .base_classifier import * -from .Deepfaune import * -from .DFNE import * \ No newline at end of file diff --git a/PytorchWildlife/models/classification/timm_base/base_classifier.py b/PytorchWildlife/models/classification/timm_base/base_classifier.py deleted file mode 100644 index 5a4268e04..000000000 --- a/PytorchWildlife/models/classification/timm_base/base_classifier.py +++ /dev/null @@ -1,210 +0,0 @@ -""" model class for loading the DFNE classifier. """ - -# Import libraries - -import os -import wget -import numpy as np -import pandas as pd -from tqdm import tqdm -from PIL import Image -from collections import OrderedDict - -import torch -from torch.utils.data import DataLoader - -import timm - -from ..base_classifier import BaseClassifierInference -from ....data import transforms as pw_trans -from ....data import datasets as pw_data - - -class TIMM_BaseClassifierInference(BaseClassifierInference): - """ - Base detector class for dinov2 classifier. This class provides utility methods - for loading the model, performing single and batch image classifications, and - formatting results. Make sure the appropriate file for the model weights has been - downloaded to the "models" folder before running DFNE. - """ - - BACKBONE = None - MODEL_NAME = None - IMAGE_SIZE = None - - def __init__(self, weights=None, device="cpu", url=None, transform=None, - weights_key='model_state_dict', weights_prefix=''): - """ - Initialize the model. - - Args: - weights (str, optional): - Path to the model weights. Defaults to None. - device (str, optional): - Device for model inference. Defaults to "cpu". - url (str, optional): - URL to fetch the model weights. Defaults to None. - weights_key (str, optional): - Key to fetch the model weights. Defaults to None. - weights_prefix (str, optional): - prefix of model weight keys. Defaults to None. - """ - super(TIMM_BaseClassifierInference, self).__init__() - self.device = device - - if transform: - self.transform = transform - else: - self.transform = pw_trans.Classification_Inference_Transform(target_size=self.IMAGE_SIZE) - - self._load_model(weights, url, weights_key, weights_prefix) - - def _load_model(self, weights=None, url=None, weights_key='model_state_dict', weights_prefix=''): - """ - Load TIMM based model weights - - Args: - weights (str, optional): - Path to the model weights. (defaults to None) - url (str, optional): - url to the model weights. (defaults to None) - """ - - self.predictor = timm.create_model( - self.BACKBONE, - pretrained = False, - num_classes = len(self.CLASS_NAMES), - dynamic_img_size = True - ) - - if url: - if not os.path.exists(os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME)): - os.makedirs(os.path.join(torch.hub.get_dir(), "checkpoints"), exist_ok=True) - weights = wget.download(url, out=os.path.join(torch.hub.get_dir(), "checkpoints")) - else: - weights = os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME) - elif weights is None: - raise Exception("Need weights for inference.") - - checkpoint = torch.load( - f = weights, - map_location = self.device, - weights_only = False - )[weights_key] - - checkpoint = OrderedDict({k.replace("{}".format(weights_prefix), ""): checkpoint[k] - for k in checkpoint}) - - self.predictor.load_state_dict(checkpoint) - print("Model loaded from {}".format(os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME))) - - self.predictor.to(self.device) - self.eval() - - def results_generation(self, logits: torch.Tensor, img_ids: list[str], id_strip: str = None) -> list[dict]: - """ - Generate results for classification. - - Args: - logits (torch.Tensor): Output tensor from the model. - img_ids (list[str]): List of image identifiers. - id_strip (str): Stripping string for better image ID saving. - - Returns: - list[dict]: List of dictionaries containing image ID, prediction, and confidence score. - """ - - probs = torch.softmax(logits, dim=1) - preds = probs.argmax(dim=1) - confs = probs.max(dim=1)[0] - confidences = probs[0].tolist() - result = [[self.CLASS_NAMES[i], confidence] for i, confidence in enumerate(confidences)] - - results = [] - for pred, img_id, conf in zip(preds, img_ids, confs): - r = {"img_id": str(img_id).strip(id_strip)} - r["prediction"] = self.CLASS_NAMES[pred.item()] - r["class_id"] = pred.item() - r["confidence"] = conf.item() - r["all_confidences"] = result - results.append(r) - - return results - - def single_image_classification(self, img, img_id=None, id_strip=None): - """ - Perform classification on a single image. - - Args: - img (str or ndarray): - Image path or ndarray of images. - img_id (str, optional): - Image path or identifier. - id_strip (str, optional): - Whether to strip stings in id. Defaults to None. - - Returns: - (dict): Classification results. - """ - if type(img) == str: - img = Image.open(img).convert("RGB") - else: - img = Image.fromarray(img) - img = self.transform(img) - - logits = self.predictor(img.unsqueeze(0).to(self.device)) - return self.results_generation(logits.cpu(), [img_id], id_strip=id_strip)[0] - - def batch_image_classification(self, data_path=None, det_results=None, id_strip=None, - batch_size=32, num_workers=0, **kwargs): - """ - Perform classification on a batch of images. - - Args: - data_path (str): - Path containing all images for inference. Defaults to None. - det_results (dict): - Dirct outputs from detectors. Defaults to None. - id_strip (str, optional): - Whether to strip stings in id. Defaults to None. - batch_size (int, optional): - Batch size for inference. Defaults to 32. - num_workers (int, optional): - Number of workers for dataloader. Defaults to 0. - - Returns: - (dict): Classification results. - """ - - if data_path: - dataset = pw_data.ClassificationImageFolder( - data_path, - transform=self.transform, - ) - elif det_results: - dataset = pw_data.DetectionCrops( - det_results, - transform=self.transform, - path_head='.' - ) - else: - raise Exception("Need data for inference.") - - dataloader = DataLoader(dataset=dataset, batch_size=batch_size, num_workers=num_workers, - shuffle=False, pin_memory=True, drop_last=False, **kwargs) - - total_logits = [] - total_paths = [] - - with tqdm(total=len(dataloader)) as pbar: - for batch in dataloader: - imgs, paths = batch - imgs = imgs.to(self.device) - total_logits.append(self.predictor(imgs)) - total_paths.append(paths) - pbar.update(1) - - total_logits = torch.cat(total_logits, dim=0).cpu() - total_paths = np.concatenate(total_paths, axis=0) - - return self.results_generation(total_logits, total_paths, id_strip=id_strip) diff --git a/PytorchWildlife/models/detection/__init__.py b/PytorchWildlife/models/detection/__init__.py deleted file mode 100644 index 1bc6ec11f..000000000 --- a/PytorchWildlife/models/detection/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .ultralytics_based import * -from .localization import * -from .yolo_mit import * -from .rtdetr_apache import * \ No newline at end of file diff --git a/PytorchWildlife/models/detection/base_detector.py b/PytorchWildlife/models/detection/base_detector.py deleted file mode 100644 index 00cbfc5cc..000000000 --- a/PytorchWildlife/models/detection/base_detector.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -""" Base detector class. """ - -# Importing basic libraries -from torch import nn - -class BaseDetector(nn.Module): - """ - Base detector class. This class provides utility methods for - loading the model, generating results, and performing single and batch image detections. - """ - - # Placeholder class-level attributes to be defined in derived classes - IMAGE_SIZE = None - STRIDE = None - CLASS_NAMES = None - TRANSFORM = None - - def __init__(self, weights=None, device="cpu", url=None): - """ - Initialize the base detector. - - Args: - weights (str, optional): - Path to the model weights. Defaults to None. - device (str, optional): - Device for model inference. Defaults to "cpu". - url (str, optional): - URL to fetch the model weights. Defaults to None. - """ - super(BaseDetector, self).__init__() - self.device = device - - - def _load_model(self, weights=None, device="cpu", url=None): - """ - Load model weights. - - Args: - weights (str, optional): - Path to the model weights. Defaults to None. - device (str, optional): - Device for model inference. Defaults to "cpu". - url (str, optional): - URL to fetch the model weights. Defaults to None. - Raises: - Exception: If weights are not provided. - """ - pass - - def results_generation(self, preds, img_id: str, id_strip: str = None) -> dict: - """ - Generate results for detection based on model predictions. - - Args: - preds (numpy.ndarray): Model predictions. - img_id (str): Image identifier. - id_strip (str, optional): Strip specific characters from img_id. Defaults to None. - - Returns: - dict: Dictionary containing image ID, detections, and labels. - """ - pass - - def single_image_detection(self, img, img_size=None, img_path=None, conf_thres=0.2, id_strip=None) -> dict: - """ - Perform detection on a single image. - - Args: - img (str or ndarray): - Image path or ndarray of images. - img_size (tuple): - Original image size. - img_path (str): - Image path or identifier. - conf_thres (float, optional): - Confidence threshold for predictions. Defaults to 0.2. - id_strip (str, optional): - Characters to strip from img_id. Defaults to None. - - Returns: - dict: Detection results. - """ - pass - - def batch_image_detection(self, dataloader, conf_thres: float = 0.2, id_strip: str = None) -> list[dict]: - """ - Perform detection on a batch of images. - - Args: - dataloader (DataLoader): DataLoader containing image batches. - conf_thres (float, optional): Confidence threshold for predictions. Defaults to 0.2. - id_strip (str, optional): Characters to strip from img_id. Defaults to None. - - Returns: - list[dict]: List of detection results for all images. - """ - pass diff --git a/PytorchWildlife/models/detection/localization/Herdnet.md b/PytorchWildlife/models/detection/localization/Herdnet.md deleted file mode 100644 index 0077e137a..000000000 --- a/PytorchWildlife/models/detection/localization/Herdnet.md +++ /dev/null @@ -1,26 +0,0 @@ -# HerdNet - -HerdNet is an advanced deep learning model designed for the accurate detection and counting of African mammals in aerial images. This model is introduced in the research paper ["From crowd to herd counting: How to precisely detect and count African mammals using aerial imagery and deep learning?"](https://www.sciencedirect.com/science/article/pii/S092427162300031X?via%3Dihub) by Alexandre Delplanque and colleagues. - -## Model Overview - -HerdNet is inspired by CenterNet, which is a neural network based on convolutional layers designed for object detection tasks. The architecture of HerdNet is tailored to handle the challenges of locating and counting dense herds in varied landscapes. It focuses on a localization head from CenterNet for detecting animal centers and includes a classification head for species identification. - -## Features - -- Optimized for speed vs. accuracy trade-off. -- Utilizes a modified encoder-decoder structure for efficiency. -- Employs a Local Maxima Detection Strategy (LMDS) for precise localization during testing. - -## Resources - -The original code repository and pretrained models are available at: -[https://github.com/Alexandre-Delplanque/HerdNet.git](https://github.com/Alexandre-Delplanque/HerdNet.git) - -## Citation - -If you use HerdNet in your research, please cite the original paper by Alexandre Delplanque and his team. - -## License - -Refer to the repository [link](https://github.com/Alexandre-Delplanque/HerdNet.git) for licensing information. diff --git a/PytorchWildlife/models/detection/localization/OWL_C.py b/PytorchWildlife/models/detection/localization/OWL_C.py deleted file mode 100644 index c21e970c1..000000000 --- a/PytorchWildlife/models/detection/localization/OWL_C.py +++ /dev/null @@ -1,307 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -from ..base_detector import BaseDetector -from ..localization.animaloc.eval import HerdNetStitcherLocBranch, HerdNetLMDSLocBranch -from ....data import datasets as pw_data -from .model_owl_c import OWLC_Architecture - -import torch -from torch.hub import load_state_dict_from_url -from torch.utils.data import DataLoader -import torchvision.transforms as transforms - -import numpy as np -from PIL import Image -from tqdm import tqdm -import supervision as sv -import os -import wget -import cv2 - -class ResizeIfSmaller: - def __init__(self, min_size, interpolation=Image.BILINEAR): - self.min_size = min_size - self.interpolation = interpolation - - def __call__(self, img): - if isinstance(img, np.ndarray): - img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) - assert isinstance(img, Image.Image), "Image should be a PIL Image" - width, height = img.size - if height < self.min_size or width < self.min_size: - ratio = max(self.min_size / height, self.min_size / width) - new_height = int(height * ratio) - new_width = int(width * ratio) - img = img.resize((new_width, new_height), self.interpolation) - return img - -class OWLC(BaseDetector): - """ - OWL-C (Overhead Wildlife Locator - CNN) detector class. This class provides utility methods for - loading the model, generating results, and performing single and batch image detections. - """ - - def __init__(self, weights=None, device="cpu", version='general', url="https://zenodo.org/records/18165116/files/herdnet_loc_branch_wildme.pth?download=1", transform=None): - """ - Initialize the HerdNet detector. - - Args: - weights (str, optional): - Path to the model weights. Defaults to None. - device (str, optional): - Device for model inference. Defaults to "cpu". - version (str, optional): - Version of the model to use. Defaults to 'general'. It should be either 'general' or 'caribou'. - url (str, optional): - URL to fetch the model weights. Defaults to "https://zenodo.org/records/18165116/files/herdnet_loc_branch_wildme.pth?download=1". - transform (torchvision.transforms.Compose, optional): - Image transformation for inference. Defaults to None. - """ - super(OWLC, self).__init__(weights=weights, device=device, url=url) - # Assert version is either 'general' or 'caribou' - version = version.lower() - assert version in ['general', 'caribou'], "Version should be either 'general' or 'caribou'." - if version == 'caribou': - url = "https://zenodo.org/records/18177050/files/caribou_herdnet_loc_branch.pth?download=1" - self._load_model(weights, device, url) - - self.stitcher = HerdNetStitcherLocBranch( # This module enables patch-based inference - model = self.model, - size = (512,512), - overlap = 160, - down_ratio = 2, - up = True, - reduction = 'mean', - device_name = device - ) - - self.lmds_kwargs: dict = {'kernel_size': (3, 3), 'adapt_ts': 0.2, 'neg_ts': 0.1} - self.lmds = HerdNetLMDSLocBranch(**self.lmds_kwargs) # Local Maxima Detection Strategy - - if not transform: - self.transforms = transforms.Compose([ - ResizeIfSmaller(512), - transforms.ToTensor(), - transforms.Normalize(mean=self.img_mean, std=self.img_std) - ]) - else: - self.transforms = transform - - def _load_model(self, weights=None, device="cpu", url=None): - """ - Load the HerdNet model weights. - - Args: - weights (str, optional): - Path to the model weights. Defaults to None. - device (str, optional): - Device for model inference. Defaults to "cpu". - url (str, optional): - URL to fetch the model weights. Defaults to None. - Raises: - Exception: If weights are not provided. - """ - if weights: - checkpoint = torch.load(weights, map_location=torch.device(device)) - elif url: - filename = url.split('/')[-1][:-11] # Splitting the URL to get the filename and removing the '?download=1' part - if not os.path.exists(os.path.join(torch.hub.get_dir(), "checkpoints", filename)): - os.makedirs(os.path.join(torch.hub.get_dir(), "checkpoints"), exist_ok=True) - weights = wget.download(url, out=os.path.join(torch.hub.get_dir(), "checkpoints")) - else: - weights = os.path.join(torch.hub.get_dir(), "checkpoints", filename) - checkpoint = torch.load(weights, map_location=torch.device(device)) - else: - raise Exception("Need weights for inference.") - - # Load the class names and other metadata from the checkpoint - self.CLASS_NAMES = {1: "animal"} # Animal overhead megadetector with single class - self.img_mean = checkpoint['mean'] - self.img_std = checkpoint['std'] - - # Load the model architecture - self.model = OWLC_Architecture(pretrained=False) - - # Load checkpoint into model - state_dict = checkpoint['model_state_dict'] - # Remove 'model.' prefix from the state_dict keys if the key starts with 'model.' - new_state_dict = {k.replace('model.', ''): v for k, v in state_dict.items() if k.startswith('model.')} - # Load the new state_dict - self.model.load_state_dict(new_state_dict, strict=True) - - print(f"Model loaded from {weights}") - - def results_generation(self, preds: np.ndarray, img: np.ndarray = None, img_id: str = None, id_strip: str = None) -> dict: - """ - Generate results for detection based on model predictions. - - Args: - preds (numpy.ndarray): Model predictions. - img (numpy.ndarray, optional): Image for inference. Defaults to None. - img_id (str, optional): Image identifier. Defaults to None. - id_strip (str, optional): Strip specific characters from img_id. Defaults to None. - - Returns: - dict: Dictionary containing image ID, detections, and labels. - """ - assert img is not None or img_id is not None, "Either img or img_id should be provided." - if img_id is not None: - img_id = str(img_id).strip(id_strip) if id_strip else str(img_id) - results = {"img_id": img_id} - elif img is not None: - results = {"img": img} - - results["detections"] = sv.Detections( - xyxy=preds[:, :4], - confidence=preds[:, 4], - class_id=preds[:, 5].astype(int) - ) - results["labels"] = [ - f"{self.CLASS_NAMES[class_id]} {confidence:0.2f}" - for confidence, class_id in zip(results["detections"].confidence, results["detections"].class_id) - ] - return results - - def single_image_detection(self, img, img_path=None, det_conf_thres=0.20, id_strip=None) -> dict: - """ - Perform detection on a single image. - - Args: - img (str or np.ndarray): - Image for inference. - img_path (str, optional): - Path to the image. Defaults to None. - det_conf_thres (float, optional): - Confidence threshold for detections. Defaults to 0.20. - id_strip (str, optional): - Characters to strip from img_id. Defaults to None. - - Returns: - dict: Detection results for the image. - """ - if isinstance(img, str): - img_path = img_path or img - img = np.array(Image.open(img_path).convert("RGB")) - if self.transforms: - img_tensor = self.transforms(img) - - preds = self.stitcher(img_tensor) - heatmap = preds - counts, locs, labels, dscores = self.lmds(heatmap) - preds_array = self.process_lmds_results(counts, locs, labels, dscores, det_conf_thres) - if img_path: - results_dict = self.results_generation(preds_array, img_id=img_path, id_strip=id_strip) - else: - results_dict = self.results_generation(preds_array, img=img) - return results_dict - - def batch_image_detection(self, data_path: str, det_conf_thres: float = 0.20, batch_size: int = 1, id_strip: str = None) -> list[dict]: - """ - Perform detection on a batch of images. - - Args: - data_path (str): Path containing all images for inference. - det_conf_thres (float, optional): Confidence threshold for detections. Defaults to 0.20. - batch_size (int, optional): Batch size for inference. Defaults to 1. - id_strip (str, optional): Characters to strip from img_id. Defaults to None. - - Returns: - list[dict]: List of detection results for all images. - """ - dataset = pw_data.DetectionImageFolder( - data_path, - transform=self.transforms - ) - # Creating a Dataloader for batching and parallel processing of the images - loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, - pin_memory=True, num_workers=0, drop_last=False) - - results = [] - - with tqdm(total=len(loader)) as pbar: - for batch_index, (imgs, paths, sizes) in enumerate(loader): - imgs = imgs.to(self.device) - predictions = self.stitcher(imgs[0]).detach().cpu() - heatmap = predictions - counts, locs, labels, dscores = self.lmds(heatmap) - preds_array = self.process_lmds_results(counts, locs, labels, dscores, det_conf_thres) - results_dict = self.results_generation(preds_array, img_id=paths[0], id_strip=id_strip) - pbar.update(1) - sizes = sizes.numpy() - normalized_coords = [[x1 / sizes[0][0], y1 / sizes[0][1], x2 / sizes[0][0], y2 / sizes[0][1]] for x1, y1, x2, y2 in preds_array[:, :4]] - results_dict['normalized_coords'] = normalized_coords - results.append(results_dict) - return results - - def process_lmds_results(self, counts: list, locs: list, labels: list, dscores: list, det_conf_thres: float = 0.2) -> np.ndarray: - """ - Process the results from the Local Maxima Detection Strategy. - - Args: - counts (list): Number of detections for each species. - locs (list): Locations of the detections. - labels (list): Labels of the detections. - dscores (list): Detection scores. - det_conf_thres (float, optional): Confidence threshold for detections. Defaults to 0.2. - - Returns: - numpy.ndarray: Processed detection results. - """ - # Flatten the lists since we know its a single image - #counts = counts[0] - locs = locs[0] - labels = labels[0] - dscores = dscores[0] - - # Calculate the total number of detections - total_detections = sum(counts) - - # Pre-allocate based on total possible detections - preds_array = np.empty((total_detections, 6)) #xyxy, confidence, class_id format - detection_idx = 0 - valid_detections_idx = 0 # Index for valid detections after applying the confidence threshold - # Loop through each species - for specie_idx in range(len(counts)): - count = counts[specie_idx] - if count == 0: - continue - - # Get the detections for this species - species_locs = np.array(locs[detection_idx : detection_idx + count]) - species_locs[:, [0, 1]] = species_locs[:, [1, 0]] # Swap x and y in species_locs - species_dscores = np.array(dscores[detection_idx : detection_idx + count]) - species_labels = np.array(labels[detection_idx : detection_idx + count]) - - # Apply the confidence threshold - valid_detections_by_det_score = species_dscores > det_conf_thres - valid_detections = valid_detections_by_det_score - valid_detections_count = np.sum(valid_detections) - valid_detections_idx += valid_detections_count - # Fill the preds_array with the valid detections - if valid_detections_count > 0: - preds_array[valid_detections_idx - valid_detections_count : valid_detections_idx, :2] = species_locs[valid_detections] - 1 - preds_array[valid_detections_idx - valid_detections_count : valid_detections_idx, 2:4] = species_locs[valid_detections] + 1 - preds_array[valid_detections_idx - valid_detections_count : valid_detections_idx, 4] = species_dscores[valid_detections] - preds_array[valid_detections_idx - valid_detections_count : valid_detections_idx, 5] = species_labels[valid_detections] - - detection_idx += count # Move to the next species - - preds_array = preds_array[:valid_detections_idx] # Remove the empty rows - - return preds_array - - def forward(self, input: torch.Tensor) -> torch.Tensor: - """ - Forward pass of the model. - - Args: - input (torch.Tensor): - Input tensor for the model. - - Returns: - torch.Tensor: Model output. - """ - # Call the forward method of the model in evaluation mode - self.model.eval() - return self.model(input) diff --git a/PytorchWildlife/models/detection/localization/OWL_T.py b/PytorchWildlife/models/detection/localization/OWL_T.py deleted file mode 100644 index 7fc5b94b8..000000000 --- a/PytorchWildlife/models/detection/localization/OWL_T.py +++ /dev/null @@ -1,301 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -from ..base_detector import BaseDetector -from ..localization.animaloc.eval import HerdNetStitcherLocBranch, HerdNetLMDSLocBranch -from ....data import datasets as pw_data -from .model_owl_t import OWLT_Architecture - -import torch -from torch.hub import load_state_dict_from_url -from torch.utils.data import DataLoader -import torchvision.transforms as transforms - -import numpy as np -from PIL import Image -from tqdm import tqdm -import supervision as sv -import os -import wget -import cv2 - -class ResizeIfSmaller: - def __init__(self, min_size, interpolation=Image.BILINEAR): - self.min_size = min_size - self.interpolation = interpolation - - def __call__(self, img): - if isinstance(img, np.ndarray): - img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) - assert isinstance(img, Image.Image), "Image should be a PIL Image" - width, height = img.size - if height < self.min_size or width < self.min_size: - ratio = max(self.min_size / height, self.min_size / width) - new_height = int(height * ratio) - new_width = int(width * ratio) - img = img.resize((new_width, new_height), self.interpolation) - return img - -class OWLT(BaseDetector): - """ - OWL-T (Overhead Wildlife Locator - Transformer) detector class. This class provides utility methods for - loading the model, generating results, and performing single and batch image detections. - """ - - def __init__(self, weights=None, device="cpu", url="https://zenodo.org/records/18177050/files/HerdNet_Hybrid_Multiscale_Residual_wildme.pth?download=1", transform=None): - """ - Initialize the HerdNet detector. - - Args: - weights (str, optional): - Path to the model weights. Defaults to None. - device (str, optional): - Device for model inference. Defaults to "cpu". - url (str, optional): - URL to fetch the model weights. Defaults to "https://zenodo.org/records/18177050/files/HerdNet_Hybrid_Multiscale_Residual_wildme.pth?download=1". - transform (torchvision.transforms.Compose, optional): - Image transformation for inference. Defaults to None. - """ - super(OWLT, self).__init__(weights=weights, device=device, url=url) - - self._load_model(weights, device, url) - - self.stitcher = HerdNetStitcherLocBranch( # This module enables patch-based inference - model = self.model, - size = (512,512), - overlap = 160, - down_ratio = 2, - up = True, - reduction = 'mean', - device_name = device - ) - - self.lmds_kwargs: dict = {'kernel_size': (3, 3), 'adapt_ts': 0.2, 'neg_ts': 0.1} - self.lmds = HerdNetLMDSLocBranch(**self.lmds_kwargs) # Local Maxima Detection Strategy - - if not transform: - self.transforms = transforms.Compose([ - ResizeIfSmaller(512), - transforms.ToTensor(), - transforms.Normalize(mean=self.img_mean, std=self.img_std) - ]) - else: - self.transforms = transform - - def _load_model(self, weights=None, device="cpu", url=None): - """ - Load the HerdNet model weights. - - Args: - weights (str, optional): - Path to the model weights. Defaults to None. - device (str, optional): - Device for model inference. Defaults to "cpu". - url (str, optional): - URL to fetch the model weights. Defaults to None. - Raises: - Exception: If weights are not provided. - """ - if weights: - checkpoint = torch.load(weights, map_location=torch.device(device)) - elif url: - filename = url.split('/')[-1][:-11] # Splitting the URL to get the filename and removing the '?download=1' part - if not os.path.exists(os.path.join(torch.hub.get_dir(), "checkpoints", filename)): - os.makedirs(os.path.join(torch.hub.get_dir(), "checkpoints"), exist_ok=True) - weights = wget.download(url, out=os.path.join(torch.hub.get_dir(), "checkpoints")) - else: - weights = os.path.join(torch.hub.get_dir(), "checkpoints", filename) - checkpoint = torch.load(weights, map_location=torch.device(device)) - else: - raise Exception("Need weights for inference.") - - # Load the class names and other metadata from the checkpoint - self.CLASS_NAMES = {1: "animal"} # Animal overhead megadetector with single class - self.img_mean = checkpoint['mean'] - self.img_std = checkpoint['std'] - - # Load the model architecture - self.model = OWLT_Architecture(pretrained_cnn=False) - - # Load checkpoint into model - state_dict = checkpoint['model_state_dict'] - # Remove 'model.' prefix from the state_dict keys if the key starts with 'model.' - new_state_dict = {k.replace('model.', ''): v for k, v in state_dict.items() if k.startswith('model.')} - # Load the new state_dict - self.model.load_state_dict(new_state_dict, strict=True) - - print(f"Model loaded from {weights}") - - def results_generation(self, preds: np.ndarray, img: np.ndarray = None, img_id: str = None, id_strip: str = None) -> dict: - """ - Generate results for detection based on model predictions. - - Args: - preds (numpy.ndarray): Model predictions. - img (numpy.ndarray, optional): Image for inference. Defaults to None. - img_id (str, optional): Image identifier. Defaults to None. - id_strip (str, optional): Strip specific characters from img_id. Defaults to None. - - Returns: - dict: Dictionary containing image ID, detections, and labels. - """ - assert img is not None or img_id is not None, "Either img or img_id should be provided." - if img_id is not None: - img_id = str(img_id).strip(id_strip) if id_strip else str(img_id) - results = {"img_id": img_id} - elif img is not None: - results = {"img": img} - - results["detections"] = sv.Detections( - xyxy=preds[:, :4], - confidence=preds[:, 4], - class_id=preds[:, 5].astype(int) - ) - results["labels"] = [ - f"{self.CLASS_NAMES[class_id]} {confidence:0.2f}" - for confidence, class_id in zip(results["detections"].confidence, results["detections"].class_id) - ] - return results - - def single_image_detection(self, img, img_path=None, det_conf_thres=0.20, id_strip=None) -> dict: - """ - Perform detection on a single image. - - Args: - img (str or np.ndarray): - Image for inference. - img_path (str, optional): - Path to the image. Defaults to None. - det_conf_thres (float, optional): - Confidence threshold for detections. Defaults to 0.20. - id_strip (str, optional): - Characters to strip from img_id. Defaults to None. - - Returns: - dict: Detection results for the image. - """ - if isinstance(img, str): - img_path = img_path or img - img = np.array(Image.open(img_path).convert("RGB")) - if self.transforms: - img_tensor = self.transforms(img) - - preds = self.stitcher(img_tensor) - heatmap = preds - counts, locs, labels, dscores = self.lmds(heatmap) - preds_array = self.process_lmds_results(counts, locs, labels, dscores, det_conf_thres) - if img_path: - results_dict = self.results_generation(preds_array, img_id=img_path, id_strip=id_strip) - else: - results_dict = self.results_generation(preds_array, img=img) - return results_dict - - def batch_image_detection(self, data_path: str, det_conf_thres: float = 0.20, batch_size: int = 1, id_strip: str = None) -> list[dict]: - """ - Perform detection on a batch of images. - - Args: - data_path (str): Path containing all images for inference. - det_conf_thres (float, optional): Confidence threshold for detections. Defaults to 0.20. - batch_size (int, optional): Batch size for inference. Defaults to 1. - id_strip (str, optional): Characters to strip from img_id. Defaults to None. - - Returns: - list[dict]: List of detection results for all images. - """ - dataset = pw_data.DetectionImageFolder( - data_path, - transform=self.transforms - ) - # Creating a Dataloader for batching and parallel processing of the images - loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, - pin_memory=True, num_workers=0, drop_last=False) - - results = [] - - with tqdm(total=len(loader)) as pbar: - for batch_index, (imgs, paths, sizes) in enumerate(loader): - imgs = imgs.to(self.device) - predictions = self.stitcher(imgs[0]).detach().cpu() - heatmap = predictions - counts, locs, labels, dscores = self.lmds(heatmap) - preds_array = self.process_lmds_results(counts, locs, labels, dscores, det_conf_thres) - results_dict = self.results_generation(preds_array, img_id=paths[0], id_strip=id_strip) - pbar.update(1) - sizes = sizes.numpy() - normalized_coords = [[x1 / sizes[0][0], y1 / sizes[0][1], x2 / sizes[0][0], y2 / sizes[0][1]] for x1, y1, x2, y2 in preds_array[:, :4]] # TODO: Check if this is correct due to xy swapping - results_dict['normalized_coords'] = normalized_coords - results.append(results_dict) - return results - - def process_lmds_results(self, counts: list, locs: list, labels: list, dscores: list, det_conf_thres: float = 0.2) -> np.ndarray: - """ - Process the results from the Local Maxima Detection Strategy. - - Args: - counts (list): Number of detections for each species. - locs (list): Locations of the detections. - labels (list): Labels of the detections. - dscores (list): Detection scores. - det_conf_thres (float, optional): Confidence threshold for detections. Defaults to 0.2. - - Returns: - numpy.ndarray: Processed detection results. - """ - # Flatten the lists since we know its a single image - #counts = counts[0] - locs = locs[0] - labels = labels[0] - dscores = dscores[0] - - # Calculate the total number of detections - total_detections = sum(counts) - - # Pre-allocate based on total possible detections - preds_array = np.empty((total_detections, 6)) #xyxy, confidence, class_id format - detection_idx = 0 - valid_detections_idx = 0 # Index for valid detections after applying the confidence threshold - # Loop through each species - for specie_idx in range(len(counts)): - count = counts[specie_idx] - if count == 0: - continue - - # Get the detections for this species - species_locs = np.array(locs[detection_idx : detection_idx + count]) - species_locs[:, [0, 1]] = species_locs[:, [1, 0]] # Swap x and y in species_locs - species_dscores = np.array(dscores[detection_idx : detection_idx + count]) - species_labels = np.array(labels[detection_idx : detection_idx + count]) - - # Apply the confidence threshold - valid_detections_by_det_score = species_dscores > det_conf_thres - valid_detections = valid_detections_by_det_score - valid_detections_count = np.sum(valid_detections) - valid_detections_idx += valid_detections_count - # Fill the preds_array with the valid detections - if valid_detections_count > 0: - preds_array[valid_detections_idx - valid_detections_count : valid_detections_idx, :2] = species_locs[valid_detections] - 1 - preds_array[valid_detections_idx - valid_detections_count : valid_detections_idx, 2:4] = species_locs[valid_detections] + 1 - preds_array[valid_detections_idx - valid_detections_count : valid_detections_idx, 4] = species_dscores[valid_detections] - preds_array[valid_detections_idx - valid_detections_count : valid_detections_idx, 5] = species_labels[valid_detections] - - detection_idx += count # Move to the next species - - preds_array = preds_array[:valid_detections_idx] # Remove the empty rows - - return preds_array - - def forward(self, input: torch.Tensor) -> torch.Tensor: - """ - Forward pass of the model. - - Args: - input (torch.Tensor): - Input tensor for the model. - - Returns: - torch.Tensor: Model output. - """ - # Call the forward method of the model in evaluation mode - self.model.eval() - return self.model(input) diff --git a/PytorchWildlife/models/detection/localization/__init__.py b/PytorchWildlife/models/detection/localization/__init__.py deleted file mode 100644 index 0f5d71b54..000000000 --- a/PytorchWildlife/models/detection/localization/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .herdnet import * -from .OWL_C import * -from .OWL_T import * \ No newline at end of file diff --git a/PytorchWildlife/models/detection/localization/animaloc/__init__.py b/PytorchWildlife/models/detection/localization/animaloc/__init__.py deleted file mode 100644 index 74c0b9052..000000000 --- a/PytorchWildlife/models/detection/localization/animaloc/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -__copyright__ = \ - """ - Copyright (C) 2024 University of Liège, Gembloux Agro-Bio Tech, Forest Is Life - All rights reserved. - - This source code is under the MIT License. - - Please contact the author Alexandre Delplanque (alexandre.delplanque@uliege.be) for any questions. - - Last modification: March 18, 2024 - """ -__author__ = "Alexandre Delplanque" -__license__ = "MIT License" -__version__ = "0.2.1" diff --git a/PytorchWildlife/models/detection/localization/animaloc/data/__init__.py b/PytorchWildlife/models/detection/localization/animaloc/data/__init__.py deleted file mode 100644 index 649204b5d..000000000 --- a/PytorchWildlife/models/detection/localization/animaloc/data/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -__copyright__ = \ - """ - Copyright (C) 2024 University of Liège, Gembloux Agro-Bio Tech, Forest Is Life - All rights reserved. - - This source code is under the MIT License. - - Please contact the author Alexandre Delplanque (alexandre.delplanque@uliege.be) for any questions. - - Last modification: March 18, 2024 - """ -__author__ = "Alexandre Delplanque" -__license__ = "MIT License" -__version__ = "0.2.1" - -from .patches import * -from .types import * diff --git a/PytorchWildlife/models/detection/localization/animaloc/data/patches.py b/PytorchWildlife/models/detection/localization/animaloc/data/patches.py deleted file mode 100644 index 2eb05ba12..000000000 --- a/PytorchWildlife/models/detection/localization/animaloc/data/patches.py +++ /dev/null @@ -1,187 +0,0 @@ -__copyright__ = \ - """ - Copyright (C) 2024 University of Liège, Gembloux Agro-Bio Tech, Forest Is Life - All rights reserved. - - This source code is under the MIT License. - - Please contact the author Alexandre Delplanque (alexandre.delplanque@uliege.be) for any questions. - - Last modification: March 18, 2024 - """ -__author__ = "Alexandre Delplanque" -__license__ = "MIT License" -__version__ = "0.2.1" - - -import os -import PIL -import torch -import pandas -import numpy -import matplotlib.pyplot as plt -import torchvision -from torchvision.utils import make_grid, save_image - -from typing import Union, Tuple - -from tqdm import tqdm - -from .types import BoundingBox - -__all__ = ['ImageToPatches'] - -class ImageToPatches: - ''' Class to make patches from a tensor image ''' - - def __init__( - self, - image: Union[PIL.Image.Image, torch.Tensor], - size: Tuple[int,int], - overlap: int = 0 - ) -> None: - ''' - Args: - image (PIL.Image.Image or torch.Tensor): image, if tensor: (C,H,W) - size (tuple): patches size (height, width), in pixels - overlap (int, optional): overlap between patches, in pixels. - Defaults to 0. - ''' - - assert isinstance(image, (PIL.Image.Image, torch.Tensor)), \ - 'image must be a PIL.Image.Image or a torch.Tensor instance' - - self.image = image - if isinstance(self.image, PIL.Image.Image): - self.image = torchvision.transforms.ToTensor()(self.image) - - self.size = size - self.overlap = overlap - - def make_patches(self) -> torch.Tensor: - ''' Make patches from the image - - When the image division is not perfect, a zero-padding is performed - so that the patches have the same size. - - Returns: - torch.Tensor: - patches of shape (B,C,H,W) - ''' - # patches' height & width - height = min(self.image.size(1),self.size[0]) - width = min(self.image.size(2),self.size[1]) - - # unfold on height - height_fold = self.image.unfold(1, height, height - self.overlap) - - # if non-perfect division on height - residual = self._img_residual(self.image.size(1), height, self.overlap) - if residual != 0: - # get the residual patch and add it to the fold - remaining_height = torch.zeros(3, 1, self.image.size(2), height) # padding - remaining_height[:,:,:,:residual] = self.image[:,-residual:,:].permute(0,2,1).unsqueeze(1) - - height_fold = torch.cat((height_fold,remaining_height),dim=1) - - # unfold on width - fold = height_fold.unfold(2, width, width - self.overlap) - - # if non-perfect division on width, the same - residual = self._img_residual(self.image.size(2), width, self.overlap) - if residual != 0: - remaining_width = torch.zeros(3, fold.shape[1], 1, height, width) # padding - remaining_width[:,:,:,:,:residual] = height_fold[:,:,-residual:,:].permute(0,1,3,2).unsqueeze(2) - - fold = torch.cat((fold,remaining_width),dim=2) - - self._nrow , self._ncol = fold.shape[2] , fold.shape[1] - - # reshaping - patches = fold.permute(1,2,0,3,4).reshape(-1,self.image.size(0),height,width) - - return patches - - def get_limits(self) -> dict: - ''' Get patches limits within the image frame - - When the image division is not perfect, the zero-padding is not - considered here. Hence, the limits are the true limits of patches - within the initial image. - - Returns: - dict: - a dict containing int as key and BoundingBox as value - ''' - - # patches' height & width - height = min(self.image.size(1),self.size[0]) - width = min(self.image.size(2),self.size[1]) - - # lists of pixels numbers - y_pixels = torch.tensor(list(range(0,self.image.size(1)+1))) - x_pixels = torch.tensor(list(range(0,self.image.size(2)+1))) - - # cut into patches to get limits - y_pixels_fold = y_pixels.unfold(0, height+1, height-self.overlap) - y_mina = [int(patch[0]) for patch in y_pixels_fold] - y_maxa = [int(patch[-1]) for patch in y_pixels_fold] - - x_pixels_fold = x_pixels.unfold(0, width+1, width-self.overlap) - x_mina = [int(patch[0]) for patch in x_pixels_fold] - x_maxa = [int(patch[-1]) for patch in x_pixels_fold] - - # if non-perfect division on height - residual = self._img_residual(self.image.size(1), height, self.overlap) - if residual != 0: - remaining_y = y_pixels[-residual-1:].unsqueeze(0)[0] - y_mina.append(int(remaining_y[0])) - y_maxa.append(int(remaining_y[-1])) - - # if non-perfect division on width - residual = self._img_residual(self.image.size(2), width, self.overlap) - if residual != 0: - remaining_x = x_pixels[-residual-1:].unsqueeze(0)[0] - x_mina.append(int(remaining_x[0])) - x_maxa.append(int(remaining_x[-1])) - - i = 0 - patches_limits = {} - for y_min , y_max in zip(y_mina,y_maxa): - for x_min , x_max in zip(x_mina,x_maxa): - patches_limits[i] = BoundingBox(x_min,y_min,x_max,y_max) - i += 1 - - return patches_limits - - def show(self) -> None: - ''' Show the grid of patches ''' - - grid = make_grid( - self.make_patches(), - padding=50, - nrow=self._nrow - ).permute(1,2,0).numpy() - - plt.imshow(grid) - - plt.show() - - return grid - - def _img_residual(self, ims: int, ks: int, overlap: int) -> int: - - ims, stride = int(ims), int(ks - overlap) - n = ims // stride - end = n * stride + overlap - - residual = ims % stride - - if end > ims: - n -= 1 - residual = ims - (n * stride) - - return residual - - def __len__(self) -> int: - return len(self.get_limits()) diff --git a/PytorchWildlife/models/detection/localization/animaloc/data/types.py b/PytorchWildlife/models/detection/localization/animaloc/data/types.py deleted file mode 100644 index bd650ac7b..000000000 --- a/PytorchWildlife/models/detection/localization/animaloc/data/types.py +++ /dev/null @@ -1,128 +0,0 @@ -__copyright__ = \ - """ - Copyright (C) 2024 University of Liège, Gembloux Agro-Bio Tech, Forest Is Life - All rights reserved. - - This source code is under the MIT License. - - Please contact the author Alexandre Delplanque (alexandre.delplanque@uliege.be) for any questions. - - Last modification: March 18, 2024 - """ -__author__ = "Alexandre Delplanque" -__license__ = "MIT License" -__version__ = "0.2.1" - -from typing import Union, Tuple - -__all__ = ['Point', 'BoundingBox'] - -class Point: - ''' Class to define a Point object in a 2D Cartesian - coordinate system. - ''' - - def __init__(self, x: Union[int,float], y: Union[int,float]) -> None: - ''' - Args: - x (int, float): x coordinate - y (int, float): y coordinate - ''' - - assert x >= 0 and y >= 0, f'Coordinates must be positives, got x={x} and y={y}' - - self.x = x - self.y = y - self.area = 1 # always 1 pixel - - # @property - # def area(self) -> int: - # ''' To get area ''' - # return 1 # always 1 pixel - - @property - def get_tuple(self) -> Tuple[Union[int,float],Union[int,float]]: - ''' To get point's coordinates in tuple ''' - return (self.x,self.y) - - @property - def atype(self) -> str: - ''' To get annotation type string ''' - return 'Point' - - def __repr__(self) -> str: - return f'Point(x: {self.x}, y: {self.y})' - - def __eq__(self, other) -> bool: - return all([ - self.x == other.x, - self.y == other.y - ]) - -class BoundingBox: - ''' Class to define a BoundingBox object in a 2D Cartesian - coordinate system. - ''' - - def __init__( - self, - x_min: Union[int,float], - y_min: Union[int,float], - x_max: Union[int,float], - y_max: Union[int,float] - ) -> None: - ''' - Args: - x_min (int, float): x bbox top-left coordinate - y_min (int, float): y bbox top-left coordinate - x_max (int, float): x bbox bottom-right coordinate - y_max (int, float): y bbox bottom-right coordinate - ''' - - assert all([c >= 0 for c in [x_min,y_min,x_max,y_max]]), \ - f'Coordinates must be positives, got x_min={x_min}, y_min={y_min}, ' \ - f'x_max={x_max} and y_max={y_max}' - - assert x_max >= x_min and y_max >= y_min, \ - 'Wrong bounding box coordinates.' - - self.x_min = x_min - self.y_min = y_min - self.x_max = x_max - self.y_max = y_max - - @property - def area(self) -> Union[int,float]: - ''' To get bbox area ''' - return max(0, self.width) * max(0, self.height) - - @property - def width(self) -> Union[int,float]: - ''' To get bbox width ''' - return max(0, self.x_max - self.x_min) - - @property - def height(self) -> Union[int,float]: - ''' To get bbox height ''' - return max(0, self.y_max - self.y_min) - - @property - def get_tuple(self) -> Tuple[Union[int,float],...]: - ''' To get bbox coordinates in tuple type ''' - return (self.x_min,self.y_min,self.x_max,self.y_max) - - @property - def atype(self) -> str: - ''' To get annotation type string ''' - return 'BoundingBox' - - def __repr__(self) -> str: - return f'BoundingBox(x_min: {self.x_min}, y_min: {self.y_min}, x_max: {self.x_max}, y_max: {self.y_max})' - - def __eq__(self, other) -> bool: - return all([ - self.x_min == other.x_min, - self.y_min == other.y_min, - self.x_max == other.x_max, - self.y_max == other.y_max - ]) \ No newline at end of file diff --git a/PytorchWildlife/models/detection/localization/animaloc/eval/__init__.py b/PytorchWildlife/models/detection/localization/animaloc/eval/__init__.py deleted file mode 100644 index cb7dc2dc4..000000000 --- a/PytorchWildlife/models/detection/localization/animaloc/eval/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -__copyright__ = \ - """ - Copyright (C) 2024 University of Liège, Gembloux Agro-Bio Tech, Forest Is Life - All rights reserved. - - This source code is under the MIT License. - - Please contact the author Alexandre Delplanque (alexandre.delplanque@uliege.be) for any questions. - - Last modification: March 18, 2024 - """ -__author__ = "Alexandre Delplanque" -__license__ = "MIT License" -__version__ = "0.2.1" - -from .stitchers import * -from .lmds import * diff --git a/PytorchWildlife/models/detection/localization/animaloc/eval/lmds.py b/PytorchWildlife/models/detection/localization/animaloc/eval/lmds.py deleted file mode 100644 index ff207425c..000000000 --- a/PytorchWildlife/models/detection/localization/animaloc/eval/lmds.py +++ /dev/null @@ -1,276 +0,0 @@ -__copyright__ = \ - """ - Copyright (C) 2024 University of Liège, Gembloux Agro-Bio Tech, Forest Is Life - All rights reserved. - - This source code is under the MIT License. - - Please contact the author Alexandre Delplanque (alexandre.delplanque@uliege.be) for any questions. - - Last modification: March 18, 2024 - """ -__author__ = "Alexandre Delplanque" -__license__ = "MIT License" -__version__ = "0.2.1" - -import torch -import numpy - -import torch.nn.functional as F - -from typing import Tuple, List - -__all__ = ['LMDS', 'HerdNetLMDS', 'HerdNetLMDSLocBranch'] - - -class LMDS: - ''' Local Maxima Detection Strategy - - Adapted and enhanced from https://github.com/dk-liang/FIDTM (author: dklinag) - available under the MIT license ''' - - def __init__( - self, - kernel_size: tuple = (3,3), - adapt_ts: float = 100.0/255.0, - neg_ts: float = 0.1 - ) -> None: - ''' - Args: - kernel_size (tuple, optional): size of the kernel used to select local - maxima. Defaults to (3,3) (as in the paper). - adapt_ts (float, optional): adaptive threshold to select final points - from candidates. Defaults to 100.0/255.0 (as in the paper). - neg_ts (float, optional): negative sample threshold used to define if - an image is a negative sample or not. Defaults to 0.1 (as in the paper). - ''' - - assert kernel_size[0] == kernel_size[1], \ - f'The kernel shape must be a square, got {kernel_size[0]}x{kernel_size[1]}' - assert not kernel_size[0] % 2 == 0, \ - f'The kernel size must be odd, got {kernel_size[0]}' - - self.kernel_size = tuple(kernel_size) - self.adapt_ts = adapt_ts - self.neg_ts = neg_ts - - def __call__(self, est_map: torch.Tensor) -> Tuple[list,list,list,list]: - ''' - Args: - est_map (torch.Tensor): the estimated FIDT map - - Returns: - Tuple[list,list,list,list] - counts, labels, scores and locations per batch - ''' - batch_size, classes = est_map.shape[:2] - - b_counts, b_labels, b_scores, b_locs = [], [], [], [] - for b in range(batch_size): - counts, labels, scores, locs = [], [], [], [] - - for c in range(classes): - count, loc, score = self._lmds(est_map[b][c]) - counts.append(count) - labels = [*labels, *[c+1]*count] - scores = [*scores, *score] - locs = [*locs, *loc] - - b_counts.append(counts) - b_labels.append(labels) - b_scores.append(scores) - b_locs.append(locs) - - return b_counts, b_locs, b_labels, b_scores - - def _local_max(self, est_map: torch.Tensor) -> torch.Tensor: - ''' Shape: est_map = [B,C,H,W] ''' - - pad = int(self.kernel_size[0] / 2) - keep = torch.nn.functional.max_pool2d(est_map, kernel_size=self.kernel_size, stride=1, padding=pad) - keep = (keep == est_map).float() - est_map = keep * est_map - - return est_map - - def _get_locs_and_scores( - self, - locs_map: torch.Tensor, - scores_map: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - ''' Shapes: locs_map = [H,W] and scores_map = [H,W] ''' - - locs_map = locs_map.data.cpu().numpy() - scores_map = scores_map.data.cpu().numpy() - locs = [] - scores = [] - for i, j in numpy.argwhere(locs_map == 1): - locs.append((i,j)) - scores.append(scores_map[i][j]) - - return torch.Tensor(locs), torch.Tensor(scores) - - def _lmds(self, est_map: torch.Tensor) -> Tuple[int, list, list]: - ''' Shape: est_map = [H,W] ''' - - est_map_max = torch.max(est_map).item() - - # local maxima - est_map = self._local_max(est_map.unsqueeze(0).unsqueeze(0)) - - # adaptive threshold for counting - est_map[est_map < self.adapt_ts * est_map_max] = 0 - scores_map = torch.clone(est_map) - est_map[est_map > 0] = 1 - - # negative sample - if est_map_max < self.neg_ts: - est_map = est_map * 0 - - # count - count = int(torch.sum(est_map).item()) - - # locations and scores - locs, scores = self._get_locs_and_scores( - est_map.squeeze(0).squeeze(0), - scores_map.squeeze(0).squeeze(0) - ) - - return count, locs.tolist(), scores.tolist() - -class HerdNetLMDS(LMDS): - - def __init__( - self, - up: bool = True, - kernel_size: tuple = (3,3), - adapt_ts: float = 0.3, - neg_ts: float = 0.1 - ) -> None: - ''' - Args: - up (bool, optional): set to False to disable class maps upsampling. - Defaults to True. - kernel_size (tuple, optional): size of the kernel used to select local - maxima. Defaults to (3,3) (as in the paper). - adapt_ts (float, optional): adaptive threshold to select final points - from candidates. Defaults to 0.3. - neg_ts (float, optional): negative sample threshold used to define if - an image is a negative sample or not. Defaults to 0.1 (as in the paper). - ''' - - super().__init__(kernel_size=kernel_size, adapt_ts=adapt_ts, neg_ts=neg_ts) - - self.up = up - - def __call__(self, outputs: List[torch.Tensor]) -> Tuple[list, list, list, list, list]: - """ - Args: - outputs (List[torch.Tensor]): Outputs of HerdNet, i.e., 2 tensors: - - heatmap: [B,1,H,W], - - class map: [B,C,H/16,W/16]. - - Returns: - Tuple[list, list, list, list, list]: - Counts, locations, labels, class scores, and detection scores per batch. - """ - - heatmap, clsmap = outputs - - # upsample class map - if self.up: - scale_factor = 16 - clsmap = F.interpolate(clsmap, scale_factor=scale_factor, mode='nearest') - - # softmax - cls_scores = torch.softmax(clsmap, dim=1)[:,1:,:,:] - - # cat to heatmap - outmaps = torch.cat([heatmap, cls_scores], dim=1) - - # LMDS - batch_size, channels = outmaps.shape[:2] - - b_counts, b_labels, b_scores, b_locs, b_dscores = [], [], [], [], [] - for b in range(batch_size): - - _, locs, _ = self._lmds(heatmap[b][0]) - - cls_idx = torch.argmax(clsmap[b,1:,:,:], dim=0) - classes = torch.add(cls_idx, 1) - - h_idx = torch.Tensor([l[0] for l in locs]).long() - w_idx = torch.Tensor([l[1] for l in locs]).long() - labels = classes[h_idx, w_idx].long().tolist() - - chan_idx = cls_idx[h_idx, w_idx].long().tolist() - scores = cls_scores[b, chan_idx, h_idx, w_idx].float().tolist() - - dscores = heatmap[b, 0, h_idx, w_idx].float().tolist() - - counts = [labels.count(i) for i in range(1, channels)] - - b_labels.append(labels) - b_scores.append(scores) - b_locs.append(locs) - b_counts.append(counts) - b_dscores.append(dscores) - - return b_counts, b_locs, b_labels, b_scores, b_dscores - -class HerdNetLMDSLocBranch(LMDS): - - def __init__( - self, - kernel_size: tuple = (3,3), - adapt_ts: float = 0.3, - neg_ts: float = 0.1 - ) -> None: - ''' - Args: - up (bool, optional): set to False to disable class maps upsampling. - Defaults to True. - kernel_size (tuple, optional): size of the kernel used to select local - maxima. Defaults to (3,3) (as in the paper). - adapt_ts (float, optional): adaptive threshold to select final points - from candidates. Defaults to 0.3. - neg_ts (float, optional): negative sample threshold used to define if - an image is a negative sample or not. Defaults to 0.1 (as in the paper). - ''' - - super().__init__(kernel_size=kernel_size, adapt_ts=adapt_ts, neg_ts=neg_ts) - - def __call__(self, outputs: List[torch.Tensor]) -> Tuple[list, list, list, list, list]: - ''' - Args: - outmaps (torch.Tensor): outputs of HerdNet localization branch, i.e. 1 tensors: - - heatmap: [B,1,H,W], - - Returns: - Tuple[list,list,list,list,list] - counts, locations, labels, class scores and detection scores per batch - ''' - heatmap = outputs - - # LMDS - batch_size, channels = heatmap.shape[:2] - - b_counts, b_labels, b_locs, b_dscores = [], [], [], [] - for b in range(batch_size): - - _, locs, _ = self._lmds(heatmap[b][0]) - - h_idx = torch.Tensor([l[0] for l in locs]).long() - w_idx = torch.Tensor([l[1] for l in locs]).long() - labels = [1] * len(locs) - - dscores = heatmap[b, 0, h_idx, w_idx].float().tolist() - - counts = len(locs) - - b_labels.append(labels) - b_locs.append(locs) - b_counts.append(counts) - b_dscores.append(dscores) - - return b_counts, b_locs, b_labels, b_dscores \ No newline at end of file diff --git a/PytorchWildlife/models/detection/localization/animaloc/eval/stitchers.py b/PytorchWildlife/models/detection/localization/animaloc/eval/stitchers.py deleted file mode 100644 index f53775c02..000000000 --- a/PytorchWildlife/models/detection/localization/animaloc/eval/stitchers.py +++ /dev/null @@ -1,256 +0,0 @@ -__copyright__ = \ - """ - Copyright (C) 2024 University of Liège, Gembloux Agro-Bio Tech, Forest Is Life - All rights reserved. - - This source code is under the MIT License. - - Please contact the author Alexandre Delplanque (alexandre.delplanque@uliege.be) for any questions. - - Last modification: March 18, 2024 - """ -__author__ = "Alexandre Delplanque" -__license__ = "MIT License" -__version__ = "0.2.1" - - -import torch -import torchvision - -import torch.nn.functional as F -import numpy as np - -from typing import List, Tuple -from torch.utils.data import TensorDataset, DataLoader, SequentialSampler - -from ..data import ImageToPatches - -class Stitcher(ImageToPatches): - ''' Class to stitch detections of patches into original image - coordinates system - - This algorithm works as follow: - 1) Cut original image into patches - 2) Make inference on each patches and harvest the detections - 3) Patch the detections maps into the coordinate system of the original image - Optional: - 4) Upsample the patched detection map - ''' - - def __init__( - self, - model: torch.nn.Module, - size: Tuple[int,int], - overlap: int = 100, - batch_size: int = 1, - down_ratio: int = 1, - up: bool = False, - reduction: str = 'sum', - device_name: str = 'cuda', - ) -> None: - ''' - Args: - model (torch.nn.Module): CNN detection model, that takes as inputs image and returns - output and dict (i.e. wrapped by LossWrapper) - size (tuple): patches size (height, width), in pixels - overlap (int, optional): overlap between patches, in pixels. - Defaults to 100. - batch_size (int, optional): batch size used for inference over patches. - Defaults to 1. - down_ratio (int, optional): downsample ratio. Set to 1 to get output of the same - size as input (i.e. no downsample). Defaults to 1. - up (bool, optional): set to True to upsample the patched map. Defaults to False. - reduction (str, optional): specifies the reduction to apply on overlapping areas. - Possible values are 'sum', 'mean', 'max'. Defaults to 'sum'. - device_name (str, optional): the device name on which tensors will be allocated - ('cpu' or 'cuda'). Defaults to 'cuda'. - ''' - - assert isinstance(model, torch.nn.Module), \ - 'model argument must be an instance of nn.Module()' - - assert reduction in ['sum', 'mean', 'max'], \ - 'reduction argument possible values are \'sum\', \'mean\' and \'max\' ' \ - f'got \'{reduction}\'' - - self.model = model - self.size = size - self.overlap = overlap - self.batch_size = batch_size - self.down_ratio = down_ratio - self.up = up - self.reduction = reduction - self.device = torch.device(device_name) - - self.model.to(self.device) - - def __call__( - self, - image: torch.Tensor - ) -> torch.Tensor: - ''' Apply the stitching algorithm to the image - - Args: - image (torch.Tensor): image of shape [C,H,W] - - Returns: - torch.Tensor - the detections into the coordinate system of the original image - ''' - - super(Stitcher, self).__init__(image, self.size, self.overlap) - - self.image = image.to(torch.device('cpu')) - - # step 1 - get patches and limits - patches = self.make_patches() - - # step 2 - inference to get maps - det_maps = self._inference(patches) - - # step 3 - patch the maps into initial coordinates system - patched_map = self._patch_maps(det_maps) - patched_map = self._reduce(patched_map) - - # (step 4 - upsample) - if self.up: - patched_map = F.interpolate(patched_map, scale_factor=self.down_ratio, - mode='bilinear', align_corners=True) - - return patched_map - - - @torch.no_grad() - def _inference(self, patches: torch.Tensor) -> List[torch.Tensor]: - - self.model.eval() - - dataset = TensorDataset(patches) - dataloader = DataLoader( - dataset, - batch_size=self.batch_size, - sampler=SequentialSampler(dataset) - ) - - maps = [] - for patch in dataloader: - patch = patch[0].to(self.device) - outputs, _ = self.model(patch) - maps = [*maps, *outputs.unsqueeze(0)] - - return maps - - def _patch_maps(self, maps: List[torch.Tensor]) -> torch.Tensor: - - _, h, w = self.image.shape - dh, dw = h // self.down_ratio, w // self.down_ratio - kernel_size = np.array(self.size) // self.down_ratio - stride = kernel_size - self.overlap // self.down_ratio - output_size = ( - self._ncol * kernel_size[0] - ((self._ncol-1) * self.overlap // self.down_ratio), - self._nrow * kernel_size[1] - ((self._nrow-1) * self.overlap // self.down_ratio) - ) - - maps = torch.cat(maps, dim=0) - - if self.reduction == 'max': - out_map = self._max_fold(maps, output_size=output_size, - kernel_size=tuple(kernel_size), stride=tuple(stride)) - else: - n_patches = maps.shape[0] - maps = maps.permute(1,2,3,0).contiguous().view(1, -1, n_patches) - out_map = F.fold(maps, output_size=output_size, - kernel_size=tuple(kernel_size), stride=tuple(stride)) - - out_map = out_map[:,:, 0:dh, 0:dw] - - return out_map - - def _reduce(self, map: torch.Tensor) -> torch.Tensor: - - dh = self.image.shape[1] // self.down_ratio - dw = self.image.shape[2] // self.down_ratio - ones = torch.ones(self.image.shape[0],dh,dw) - - if self.reduction == 'mean': - ones_patches = ImageToPatches(ones, - np.array(self.size)//self.down_ratio, - self.overlap//self.down_ratio - ).make_patches() - - ones_patches = [p.unsqueeze(0).unsqueeze(0) for p in ones_patches[:,1,:,:]] - norm_map = self._patch_maps(ones_patches) - - else: - norm_map = ones[1,:,:] - - return torch.div(map.to(self.device), norm_map.to(self.device)) - - def _max_fold(self, maps: torch.Tensor, output_size: tuple, - kernel_size: tuple, stride: tuple - ) -> torch.Tensor: - - output = torch.zeros((1, maps.shape[1], *output_size)) - - fn = lambda x: [[i, i+kernel_size[x]] for i in range(0, output_size[x], stride[x])][:-1] - locs = [[*h, *w] for h in fn(0) for w in fn(1)] - - for loc, m in zip(locs, maps): - patch = torch.zeros(output.shape) - patch[:,:, loc[0]:loc[1], loc[2]:loc[3]] = m - output = torch.max(output, patch) - - return output - -class HerdNetStitcher(Stitcher): - - @torch.no_grad() - def _inference(self, patches: torch.Tensor) -> List[torch.Tensor]: - - self.model.eval() - - dataset = TensorDataset(patches) - dataloader = DataLoader( - dataset, - batch_size=self.batch_size, - sampler=SequentialSampler(dataset) - ) - - maps = [] - for patch in dataloader: - patch = patch[0].to(self.device) - #outputs = self.model(patch)[0] - outputs = self.model(patch) # LossWrapper is not used - heatmap = outputs[0] - scale_factor = 16 - clsmap = F.interpolate(outputs[1], scale_factor=scale_factor, mode='nearest') - # cat - outmaps = torch.cat([heatmap, clsmap], dim=1) - maps = [*maps, *outmaps.unsqueeze(0)] - - return maps - -class HerdNetStitcherLocBranch(Stitcher): - - @torch.no_grad() - def _inference(self, patches: torch.Tensor) -> List[torch.Tensor]: - - self.model.eval() - - dataset = TensorDataset(patches) - dataloader = DataLoader( - dataset, - batch_size=self.batch_size, - sampler=SequentialSampler(dataset) - ) - - maps = [] - for patch in dataloader: - patch = patch[0].to(self.device) - #outputs = self.model(patch)[0] - outputs = self.model(patch) # LossWrapper is not used - heatmap = outputs - outmaps = heatmap.unsqueeze(0) - maps = [*maps, *outmaps] - - return maps diff --git a/PytorchWildlife/models/detection/localization/dla.py b/PytorchWildlife/models/detection/localization/dla.py deleted file mode 100644 index dd87cccff..000000000 --- a/PytorchWildlife/models/detection/localization/dla.py +++ /dev/null @@ -1,590 +0,0 @@ -__copyright__ = \ - """ - MIT License - - Copyright (c) 2019 Xingyi Zhou - All rights reserved. - - Permission is hereby granted, free of charge, to any person obtaining a copy - of this software and associated documentation files (the "Software"), to deal - in the Software without restriction, including without limitation the rights - to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - copies of the Software, and to permit persons to whom the Software is - furnished to do so, subject to the following conditions: - - The above copyright notice and this permission notice shall be included in all - copies or substantial portions of the Software. - """ -__authors__ = "Xingyi Zhou, Dequan Wang, Philipp Krähenbühl" -__license__ = "MIT" - - -import math -from os.path import join -from posixpath import basename - -import torch -from torch import nn -import torch.utils.model_zoo as model_zoo - -import numpy as np - -BatchNorm = nn.BatchNorm2d - -def get_model_url(data='imagenet', name='dla34', hash='ba72cf86'): - return join('http://dl.yf.io/dla/models', data, '{}-{}.pth'.format(name, hash)) - - -def conv3x3(in_planes, out_planes, stride=1): - "3x3 convolution with padding" - return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, - padding=1, bias=False) - - -class BasicBlock(nn.Module): - def __init__(self, inplanes, planes, stride=1, dilation=1): - super(BasicBlock, self).__init__() - self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, - stride=stride, padding=dilation, - bias=False, dilation=dilation) - self.bn1 = BatchNorm(planes) - self.relu = nn.ReLU(inplace=True) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, - stride=1, padding=dilation, - bias=False, dilation=dilation) - self.bn2 = BatchNorm(planes) - self.stride = stride - - def forward(self, x, residual=None): - if residual is None: - residual = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - - out += residual - out = self.relu(out) - - return out - - -class Bottleneck(nn.Module): - expansion = 2 - - def __init__(self, inplanes, planes, stride=1, dilation=1): - super(Bottleneck, self).__init__() - expansion = Bottleneck.expansion - bottle_planes = planes // expansion - self.conv1 = nn.Conv2d(inplanes, bottle_planes, - kernel_size=1, bias=False) - self.bn1 = BatchNorm(bottle_planes) - self.conv2 = nn.Conv2d(bottle_planes, bottle_planes, kernel_size=3, - stride=stride, padding=dilation, - bias=False, dilation=dilation) - self.bn2 = BatchNorm(bottle_planes) - self.conv3 = nn.Conv2d(bottle_planes, planes, - kernel_size=1, bias=False) - self.bn3 = BatchNorm(planes) - self.relu = nn.ReLU(inplace=True) - self.stride = stride - - def forward(self, x, residual=None): - if residual is None: - residual = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.relu(out) - - out = self.conv3(out) - out = self.bn3(out) - - out += residual - out = self.relu(out) - - return out - - -class BottleneckX(nn.Module): - expansion = 2 - cardinality = 32 - - def __init__(self, inplanes, planes, stride=1, dilation=1): - super(BottleneckX, self).__init__() - cardinality = BottleneckX.cardinality - bottle_planes = planes * cardinality // 32 - self.conv1 = nn.Conv2d(inplanes, bottle_planes, - kernel_size=1, bias=False) - self.bn1 = BatchNorm(bottle_planes) - self.conv2 = nn.Conv2d(bottle_planes, bottle_planes, kernel_size=3, - stride=stride, padding=dilation, bias=False, - dilation=dilation, groups=cardinality) - self.bn2 = BatchNorm(bottle_planes) - self.conv3 = nn.Conv2d(bottle_planes, planes, - kernel_size=1, bias=False) - self.bn3 = BatchNorm(planes) - self.relu = nn.ReLU(inplace=True) - self.stride = stride - - def forward(self, x, residual=None): - if residual is None: - residual = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.relu(out) - - out = self.conv3(out) - out = self.bn3(out) - - out += residual - out = self.relu(out) - - return out - - -class Root(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, residual): - super(Root, self).__init__() - self.conv = nn.Conv2d( - in_channels, out_channels, 1, - stride=1, bias=False, padding=(kernel_size - 1) // 2) - self.bn = BatchNorm(out_channels) - self.relu = nn.ReLU(inplace=True) - self.residual = residual - - def forward(self, *x): - children = x - x = self.conv(torch.cat(x, 1)) - x = self.bn(x) - if self.residual: - x += children[0] - x = self.relu(x) - - return x - - -class Tree(nn.Module): - def __init__(self, levels, block, in_channels, out_channels, stride=1, - level_root=False, root_dim=0, root_kernel_size=1, - dilation=1, root_residual=False): - super(Tree, self).__init__() - if root_dim == 0: - root_dim = 2 * out_channels - if level_root: - root_dim += in_channels - if levels == 1: - self.tree1 = block(in_channels, out_channels, stride, - dilation=dilation) - self.tree2 = block(out_channels, out_channels, 1, - dilation=dilation) - else: - self.tree1 = Tree(levels - 1, block, in_channels, out_channels, - stride, root_dim=0, - root_kernel_size=root_kernel_size, - dilation=dilation, root_residual=root_residual) - self.tree2 = Tree(levels - 1, block, out_channels, out_channels, - root_dim=root_dim + out_channels, - root_kernel_size=root_kernel_size, - dilation=dilation, root_residual=root_residual) - if levels == 1: - self.root = Root(root_dim, out_channels, root_kernel_size, - root_residual) - self.level_root = level_root - self.root_dim = root_dim - self.downsample = None - self.project = None - self.levels = levels - if stride > 1: - self.downsample = nn.MaxPool2d(stride, stride=stride) - if in_channels != out_channels: - self.project = nn.Sequential( - nn.Conv2d(in_channels, out_channels, - kernel_size=1, stride=1, bias=False), - BatchNorm(out_channels) - ) - - def forward(self, x, residual=None, children=None): - children = [] if children is None else children - bottom = self.downsample(x) if self.downsample else x - residual = self.project(bottom) if self.project else bottom - if self.level_root: - children.append(bottom) - x1 = self.tree1(x, residual) - if self.levels == 1: - x2 = self.tree2(x1) - x = self.root(x2, x1, *children) - else: - children.append(x1) - x = self.tree2(x1, children=children) - return x - - -class DLA(nn.Module): - def __init__(self, levels, channels, num_classes=1000, - block=BasicBlock, residual_root=False, return_levels=False, - pool_size=7, linear_root=False): - super(DLA, self).__init__() - self.channels = channels - self.return_levels = return_levels - self.num_classes = num_classes - self.base_layer = nn.Sequential( - nn.Conv2d(3, channels[0], kernel_size=7, stride=1, - padding=3, bias=False), - BatchNorm(channels[0]), - nn.ReLU(inplace=True)) - self.level0 = self._make_conv_level( - channels[0], channels[0], levels[0]) - self.level1 = self._make_conv_level( - channels[0], channels[1], levels[1], stride=2) - self.level2 = Tree(levels[2], block, channels[1], channels[2], 2, - level_root=False, - root_residual=residual_root) - self.level3 = Tree(levels[3], block, channels[2], channels[3], 2, - level_root=True, root_residual=residual_root) - self.level4 = Tree(levels[4], block, channels[3], channels[4], 2, - level_root=True, root_residual=residual_root) - self.level5 = Tree(levels[5], block, channels[4], channels[5], 2, - level_root=True, root_residual=residual_root) - - self.avgpool = nn.AvgPool2d(pool_size) - self.fc = nn.Conv2d(channels[-1], num_classes, kernel_size=1, - stride=1, padding=0, bias=True) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.data.normal_(0, math.sqrt(2. / n)) - elif isinstance(m, BatchNorm): - m.weight.data.fill_(1) - m.bias.data.zero_() - - def _make_level(self, block, inplanes, planes, blocks, stride=1): - downsample = None - if stride != 1 or inplanes != planes: - downsample = nn.Sequential( - nn.MaxPool2d(stride, stride=stride), - nn.Conv2d(inplanes, planes, - kernel_size=1, stride=1, bias=False), - BatchNorm(planes), - ) - - layers = [] - layers.append(block(inplanes, planes, stride, downsample=downsample)) - for i in range(1, blocks): - layers.append(block(inplanes, planes)) - - return nn.Sequential(*layers) - - def _make_conv_level(self, inplanes, planes, convs, stride=1, dilation=1): - modules = [] - for i in range(convs): - modules.extend([ - nn.Conv2d(inplanes, planes, kernel_size=3, - stride=stride if i == 0 else 1, - padding=dilation, bias=False, dilation=dilation), - BatchNorm(planes), - nn.ReLU(inplace=True)]) - inplanes = planes - return nn.Sequential(*modules) - - def forward(self, x): - y = [] - x = self.base_layer(x) - for i in range(6): - x = getattr(self, 'level{}'.format(i))(x) - y.append(x) - if self.return_levels: - return y - else: - x = self.avgpool(x) - x = self.fc(x) - x = x.view(x.size(0), -1) - - return x - - def load_pretrained_model(self, data='imagenet', name='dla34', hash='ba72cf86'): - fc = self.fc - if name.endswith('.pth'): - model_weights = torch.load(data + name) - else: - model_url = get_model_url(data, name, hash) - model_weights = model_zoo.load_url(model_url) - num_classes = len(model_weights[list(model_weights.keys())[-1]]) - self.fc = nn.Conv2d( - self.channels[-1], num_classes, - kernel_size=1, stride=1, padding=0, bias=True) - self.load_state_dict(model_weights) - self.fc = fc - - -def dla34(pretrained, **kwargs): # DLA-34 - model = DLA([1, 1, 1, 2, 2, 1], - [16, 32, 64, 128, 256, 512], - block=BasicBlock, **kwargs) - if pretrained: - model.load_pretrained_model(data='imagenet', name='dla34', hash='ba72cf86') - return model - - -def dla46_c(pretrained=None, **kwargs): # DLA-46-C - Bottleneck.expansion = 2 - model = DLA([1, 1, 1, 2, 2, 1], - [16, 32, 64, 64, 128, 256], - block=Bottleneck, **kwargs) - if pretrained is not None: - model.load_pretrained_model(pretrained, 'dla46_c') - return model - - -def dla46x_c(pretrained=None, **kwargs): # DLA-X-46-C - BottleneckX.expansion = 2 - model = DLA([1, 1, 1, 2, 2, 1], - [16, 32, 64, 64, 128, 256], - block=BottleneckX, **kwargs) - if pretrained is not None: - model.load_pretrained_model(pretrained, 'dla46x_c') - return model - - -def dla60x_c(pretrained, **kwargs): # DLA-X-60-C - BottleneckX.expansion = 2 - model = DLA([1, 1, 1, 2, 3, 1], - [16, 32, 64, 64, 128, 256], - block=BottleneckX, **kwargs) - if pretrained: - model.load_pretrained_model(data='imagenet', name='dla60x_c', hash='b870c45c') - return model - - -def dla60(pretrained=None, **kwargs): # DLA-60 - Bottleneck.expansion = 2 - model = DLA([1, 1, 1, 2, 3, 1], - [16, 32, 128, 256, 512, 1024], - block=Bottleneck, **kwargs) - if pretrained is not None: - model.load_pretrained_model(pretrained, 'dla60') - return model - - -def dla60x(pretrained=None, **kwargs): # DLA-X-60 - BottleneckX.expansion = 2 - model = DLA([1, 1, 1, 2, 3, 1], - [16, 32, 128, 256, 512, 1024], - block=BottleneckX, **kwargs) - if pretrained is not None: - model.load_pretrained_model(pretrained, 'dla60x') - return model - - -def dla102(pretrained=None, **kwargs): # DLA-102 - Bottleneck.expansion = 2 - model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024], - block=Bottleneck, residual_root=True, **kwargs) - if pretrained is not None: - model.load_pretrained_model(pretrained, 'dla102') - return model - - -def dla102x(pretrained=None, **kwargs): # DLA-X-102 - BottleneckX.expansion = 2 - model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024], - block=BottleneckX, residual_root=True, **kwargs) - if pretrained is not None: - model.load_pretrained_model(pretrained, 'dla102x') - return model - - -def dla102x2(pretrained=None, **kwargs): # DLA-X-102 64 - BottleneckX.cardinality = 64 - model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024], - block=BottleneckX, residual_root=True, **kwargs) - if pretrained is not None: - model.load_pretrained_model(pretrained, 'dla102x2') - return model - - -def dla169(pretrained=None, **kwargs): # DLA-169 - Bottleneck.expansion = 2 - model = DLA([1, 1, 2, 3, 5, 1], [16, 32, 128, 256, 512, 1024], - block=Bottleneck, residual_root=True, **kwargs) - if pretrained is not None: - model.load_pretrained_model(pretrained, 'dla169') - return model - - -class Identity(nn.Module): - def __init__(self): - super(Identity, self).__init__() - - def forward(self, x): - return x - - -def fill_up_weights(up): - w = up.weight.data - f = math.ceil(w.size(2) / 2) - c = (2 * f - 1 - f % 2) / (2. * f) - for i in range(w.size(2)): - for j in range(w.size(3)): - w[0, 0, i, j] = \ - (1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c)) - for c in range(1, w.size(0)): - w[c, 0, :, :] = w[0, 0, :, :] - - -class IDAUp(nn.Module): - def __init__(self, node_kernel, out_dim, channels, up_factors): - super(IDAUp, self).__init__() - self.channels = channels - self.out_dim = out_dim - for i, c in enumerate(channels): - if c == out_dim: - proj = Identity() - else: - proj = nn.Sequential( - nn.Conv2d(c, out_dim, - kernel_size=1, stride=1, bias=False), - BatchNorm(out_dim), - nn.ReLU(inplace=True)) - f = int(up_factors[i]) - if f == 1: - up = Identity() - else: - up = nn.ConvTranspose2d( - out_dim, out_dim, f * 2, stride=f, padding=f // 2, - output_padding=0, groups=out_dim, bias=False) - fill_up_weights(up) - setattr(self, 'proj_' + str(i), proj) - setattr(self, 'up_' + str(i), up) - - for i in range(1, len(channels)): - node = nn.Sequential( - nn.Conv2d(out_dim * 2, out_dim, - kernel_size=node_kernel, stride=1, - padding=node_kernel // 2, bias=False), - BatchNorm(out_dim), - nn.ReLU(inplace=True)) - setattr(self, 'node_' + str(i), node) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.data.normal_(0, math.sqrt(2. / n)) - elif isinstance(m, BatchNorm): - m.weight.data.fill_(1) - m.bias.data.zero_() - - def forward(self, layers): - assert len(self.channels) == len(layers), \ - '{} vs {} layers'.format(len(self.channels), len(layers)) - layers = list(layers) - for i, l in enumerate(layers): - upsample = getattr(self, 'up_' + str(i)) - project = getattr(self, 'proj_' + str(i)) - layers[i] = upsample(project(l)) - x = layers[0] - y = [] - for i in range(1, len(layers)): - node = getattr(self, 'node_' + str(i)) - x = node(torch.cat([x, layers[i]], 1)) - y.append(x) - return x, y - - -class DLAUp(nn.Module): - def __init__(self, channels, scales=(1, 2, 4, 8, 16), in_channels=None): - super(DLAUp, self).__init__() - if in_channels is None: - in_channels = channels - self.channels = channels - channels = list(channels) - scales = np.array(scales, dtype=int) - for i in range(len(channels) - 1): - j = -i - 2 - setattr(self, 'ida_{}'.format(i), - IDAUp(3, channels[j], in_channels[j:], - scales[j:] // scales[j])) - scales[j + 1:] = scales[j] - in_channels[j + 1:] = [channels[j] for _ in channels[j + 1:]] - - def forward(self, layers): - layers = list(layers) - assert len(layers) > 1 - for i in range(len(layers) - 1): - ida = getattr(self, 'ida_{}'.format(i)) - x, y = ida(layers[-i - 2:]) - layers[-i - 1:] = y - return x - -def fill_fc_weights(layers): - for m in layers.modules(): - if isinstance(m, nn.Conv2d): - nn.init.normal_(m.weight, std=0.001) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - -class DLASeg(nn.Module): - def __init__(self, base_name, heads, - pretrained=True, down_ratio=4, head_conv=256): - super(DLASeg, self).__init__() - self.heads = heads - self.first_level = int(np.log2(down_ratio)) - self.base = globals()[base_name]( - pretrained=pretrained, return_levels=True) - channels = self.base.channels - scales = [2 ** i for i in range(len(channels[self.first_level:]))] - self.dla_up = DLAUp(channels[self.first_level:], scales=scales) - - for head in self.heads: - classes = self.heads[head] - if head_conv > 0: - fc = nn.Sequential( - nn.Conv2d(channels[self.first_level], head_conv, - kernel_size=3, padding=1, bias=True), - nn.ReLU(inplace=True), - nn.Conv2d(head_conv, classes, - kernel_size=1, stride=1, - padding=0, bias=True)) - if 'hm' in head: - fc[-1].bias.data.fill_(-2.19) - else: - fill_fc_weights(fc) - else: - fc = nn.Conv2d(channels[self.first_level], classes, - kernel_size=1, stride=1, - padding=0, bias=True) - if 'hm' in head: - fc.bias.data.fill_(-2.19) - else: - fill_fc_weights(fc) - self.__setattr__(head, fc) - - def forward(self, x): - x = self.base(x) - x = self.dla_up(x[self.first_level:]) - # x = self.fc(x) - # y = self.softmax(self.up(x)) - ret = {} - for head in self.heads: - ret[head] = self.__getattr__(head)(x) - return [ret] - -def get_pose_net(num_layers, heads, head_conv=256, down_ratio=4): - model = DLASeg('dla{}'.format(num_layers), heads, - pretrained=True, - down_ratio=down_ratio, - head_conv=head_conv) - return model \ No newline at end of file diff --git a/PytorchWildlife/models/detection/localization/herdnet.py b/PytorchWildlife/models/detection/localization/herdnet.py deleted file mode 100644 index 2aa5fc97f..000000000 --- a/PytorchWildlife/models/detection/localization/herdnet.py +++ /dev/null @@ -1,313 +0,0 @@ -from ..base_detector import BaseDetector -from ..localization.animaloc.eval import HerdNetStitcher, HerdNetLMDS -from ....data import datasets as pw_data -from .model import HerdNet as HerdNetArch - -import torch -from torch.hub import load_state_dict_from_url -from torch.utils.data import DataLoader -import torchvision.transforms as transforms - -import numpy as np -from PIL import Image -from tqdm import tqdm -import supervision as sv -import os -import wget -import cv2 - -class ResizeIfSmaller: - def __init__(self, min_size, interpolation=Image.BILINEAR): - self.min_size = min_size - self.interpolation = interpolation - - def __call__(self, img): - if isinstance(img, np.ndarray): - img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) - assert isinstance(img, Image.Image), "Image should be a PIL Image" - width, height = img.size - if height < self.min_size or width < self.min_size: - ratio = max(self.min_size / height, self.min_size / width) - new_height = int(height * ratio) - new_width = int(width * ratio) - img = img.resize((new_width, new_height), self.interpolation) - return img - -class HerdNet(BaseDetector): - """ - HerdNet detector class. This class provides utility methods for - loading the model, generating results, and performing single and batch image detections. - """ - - def __init__(self, weights=None, device="cpu", version='general' ,url="https://zenodo.org/records/13899852/files/20220413_HerdNet_General_dataset_2022.pth?download=1", transform=None): - """ - Initialize the HerdNet detector. - - Args: - weights (str, optional): - Path to the model weights. Defaults to None. - device (str, optional): - Device for model inference. Defaults to "cpu". - version (str, optional): - Version name based on what dataset the model is trained on. It should be either 'general' or 'ennedi'. Defaults to 'general'. - url (str, optional): - URL to fetch the model weights. Defaults to "https://zenodo.org/records/13899852/files/20220413_HerdNet_General_dataset_2022.pth?download=1". - transform (torchvision.transforms.Compose, optional): - Image transformation for inference. Defaults to None. - """ - super(HerdNet, self).__init__(weights=weights, device=device, url=url) - # Assert that the dataset is either 'general' or 'ennedi' - version = version.lower() - assert version in ['general', 'ennedi'], "Dataset should be either 'general' or 'ennedi'" - if version == 'ennedi': - url = "https://zenodo.org/records/13914287/files/20220329_HerdNet_Ennedi_dataset_2023.pth?download=1" - self._load_model(weights, device, url) - - self.stitcher = HerdNetStitcher( # This module enables patch-based inference - model = self.model, - size = (512,512), - overlap = 160, - down_ratio = 2, - up = True, - reduction = 'mean', - device_name = device - ) - - self.lmds_kwargs: dict = {'kernel_size': (3, 3), 'adapt_ts': 0.2, 'neg_ts': 0.1} - self.lmds = HerdNetLMDS(up=False, **self.lmds_kwargs) # Local Maxima Detection Strategy - - if not transform: - self.transforms = transforms.Compose([ - ResizeIfSmaller(512), - transforms.ToTensor(), - transforms.Normalize(mean=self.img_mean, std=self.img_std) - ]) - else: - self.transforms = transform - - def _load_model(self, weights=None, device="cpu", url=None): - """ - Load the HerdNet model weights. - - Args: - weights (str, optional): - Path to the model weights. Defaults to None. - device (str, optional): - Device for model inference. Defaults to "cpu". - url (str, optional): - URL to fetch the model weights. Defaults to None. - Raises: - Exception: If weights are not provided. - """ - if weights: - checkpoint = torch.load(weights, map_location=torch.device(device)) - elif url: - filename = url.split('/')[-1][:-11] # Splitting the URL to get the filename and removing the '?download=1' part - if not os.path.exists(os.path.join(torch.hub.get_dir(), "checkpoints", filename)): - os.makedirs(os.path.join(torch.hub.get_dir(), "checkpoints"), exist_ok=True) - weights = wget.download(url, out=os.path.join(torch.hub.get_dir(), "checkpoints")) - else: - weights = os.path.join(torch.hub.get_dir(), "checkpoints", filename) - checkpoint = torch.load(weights, map_location=torch.device(device)) - else: - raise Exception("Need weights for inference.") - - # Load the class names and other metadata from the checkpoint - self.CLASS_NAMES = checkpoint["classes"] - self.num_classes = len(self.CLASS_NAMES) + 1 - self.img_mean = checkpoint['mean'] - self.img_std = checkpoint['std'] - - # Load the model architecture - self.model = HerdNetArch(num_classes=self.num_classes, pretrained=False) - - # Load checkpoint into model - state_dict = checkpoint['model_state_dict'] - # Remove 'model.' prefix from the state_dict keys if the key starts with 'model.' - new_state_dict = {k.replace('model.', ''): v for k, v in state_dict.items() if k.startswith('model.')} - # Load the new state_dict - self.model.load_state_dict(new_state_dict, strict=True) - - print(f"Model loaded from {weights}") - - def results_generation(self, preds: np.ndarray, img: np.ndarray = None, img_id: str = None, id_strip: str = None) -> dict: - """ - Generate results for detection based on model predictions. - - Args: - preds (numpy.ndarray): Model predictions. - img (numpy.ndarray, optional): Image for inference. Defaults to None. - img_id (str, optional): Image identifier. Defaults to None. - id_strip (str, optional): Strip specific characters from img_id. Defaults to None. - - Returns: - dict: Dictionary containing image ID, detections, and labels. - """ - assert img is not None or img_id is not None, "Either img or img_id should be provided." - if img_id is not None: - img_id = str(img_id).strip(id_strip) if id_strip else str(img_id) - results = {"img_id": img_id} - elif img is not None: - results = {"img": img} - - results["detections"] = sv.Detections( - xyxy=preds[:, :4], - confidence=preds[:, 4], - class_id=preds[:, 5].astype(int) - ) - results["labels"] = [ - f"{self.CLASS_NAMES[class_id]} {confidence:0.2f}" - for confidence, class_id in zip(results["detections"].confidence, results["detections"].class_id) - ] - return results - - def single_image_detection(self, img, img_path=None, det_conf_thres=0.2, clf_conf_thres=0.2, id_strip=None) -> dict: - """ - Perform detection on a single image. - - Args: - img (str or np.ndarray): - Image for inference. - img_path (str, optional): - Path to the image. Defaults to None. - det_conf_thres (float, optional): - Confidence threshold for detections. Defaults to 0.2. - clf_conf_thres (float, optional): - Confidence threshold for classification. Defaults to 0.2. - id_strip (str, optional): - Characters to strip from img_id. Defaults to None. - - Returns: - dict: Detection results for the image. - """ - if isinstance(img, str): - img_path = img_path or img - img = np.array(Image.open(img_path).convert("RGB")) - if self.transforms: - img_tensor = self.transforms(img) - - preds = self.stitcher(img_tensor) - heatmap, clsmap = preds[:,:1,:,:], preds[:,1:,:,:] - counts, locs, labels, scores, dscores = self.lmds((heatmap, clsmap)) - preds_array = self.process_lmds_results(counts, locs, labels, scores, dscores, det_conf_thres, clf_conf_thres) - if img_path: - results_dict = self.results_generation(preds_array, img_id=img_path, id_strip=id_strip) - else: - results_dict = self.results_generation(preds_array, img=img) - return results_dict - - def batch_image_detection(self, data_path: str, det_conf_thres: float = 0.2, clf_conf_thres: float = 0.2, batch_size: int = 1, id_strip: str = None) -> list[dict]: - """ - Perform detection on a batch of images. - - Args: - data_path (str): Path containing all images for inference. - det_conf_thres (float, optional): Confidence threshold for detections. Defaults to 0.2. - clf_conf_thres (float, optional): Confidence threshold for classification. Defaults to 0.2. - batch_size (int, optional): Batch size for inference. Defaults to 1. - id_strip (str, optional): Characters to strip from img_id. Defaults to None. - - Returns: - list[dict]: List of detection results for all images. - """ - dataset = pw_data.DetectionImageFolder( - data_path, - transform=self.transforms - ) - # Creating a Dataloader for batching and parallel processing of the images - loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, - pin_memory=True, num_workers=0, drop_last=False) # TODO: discuss. why is num_workers 0? - - results = [] - - with tqdm(total=len(loader)) as pbar: - for batch_index, (imgs, paths, sizes) in enumerate(loader): - imgs = imgs.to(self.device) - predictions = self.stitcher(imgs[0]).detach().cpu() - heatmap, clsmap = predictions[:,:1,:,:], predictions[:,1:,:,:] - counts, locs, labels, scores, dscores = self.lmds((heatmap, clsmap)) - preds_array = self.process_lmds_results(counts, locs, labels, scores, dscores, det_conf_thres, clf_conf_thres) - results_dict = self.results_generation(preds_array, img_id=paths[0], id_strip=id_strip) - pbar.update(1) - sizes = sizes.numpy() - normalized_coords = [[x1 / sizes[0][0], y1 / sizes[0][1], x2 / sizes[0][0], y2 / sizes[0][1]] for x1, y1, x2, y2 in preds_array[:, :4]] # TODO: Check if this is correct due to xy swapping - results_dict['normalized_coords'] = normalized_coords - results.append(results_dict) - return results - - def process_lmds_results(self, counts: list, locs: list, labels: list, scores: list, dscores: list, det_conf_thres: float = 0.2, clf_conf_thres: float = 0.2) -> np.ndarray: - """ - Process the results from the Local Maxima Detection Strategy. - - Args: - counts (list): Number of detections for each species. - locs (list): Locations of the detections. - labels (list): Labels of the detections. - scores (list): Scores of the detections. - dscores (list): Detection scores. - det_conf_thres (float, optional): Confidence threshold for detections. Defaults to 0.2. - clf_conf_thres (float, optional): Confidence threshold for classification. Defaults to 0.2. - - Returns: - numpy.ndarray: Processed detection results. - """ - # Flatten the lists since we know its a single image - counts = counts[0] - locs = locs[0] - labels = labels[0] - scores = scores[0] - dscores = dscores[0] - - # Calculate the total number of detections - total_detections = sum(counts) - - # Pre-allocate based on total possible detections - preds_array = np.empty((total_detections, 6)) #xyxy, confidence, class_id format - detection_idx = 0 - valid_detections_idx = 0 # Index for valid detections after applying the confidence threshold - # Loop through each species - for specie_idx in range(len(counts)): - count = counts[specie_idx] - if count == 0: - continue - - # Get the detections for this species - species_locs = np.array(locs[detection_idx : detection_idx + count]) - species_locs[:, [0, 1]] = species_locs[:, [1, 0]] # Swap x and y in species_locs - species_scores = np.array(scores[detection_idx : detection_idx + count]) - species_dscores = np.array(dscores[detection_idx : detection_idx + count]) - species_labels = np.array(labels[detection_idx : detection_idx + count]) - - # Apply the confidence threshold - valid_detections_by_clf_score = species_scores > clf_conf_thres - valid_detections_by_det_score = species_dscores > det_conf_thres - valid_detections = np.logical_and(valid_detections_by_clf_score, valid_detections_by_det_score) - valid_detections_count = np.sum(valid_detections) - valid_detections_idx += valid_detections_count - # Fill the preds_array with the valid detections - if valid_detections_count > 0: - preds_array[valid_detections_idx - valid_detections_count : valid_detections_idx, :2] = species_locs[valid_detections] - 1 - preds_array[valid_detections_idx - valid_detections_count : valid_detections_idx, 2:4] = species_locs[valid_detections] + 1 - preds_array[valid_detections_idx - valid_detections_count : valid_detections_idx, 4] = species_scores[valid_detections] - preds_array[valid_detections_idx - valid_detections_count : valid_detections_idx, 5] = species_labels[valid_detections] - - detection_idx += count # Move to the next species - - preds_array = preds_array[:valid_detections_idx] # Remove the empty rows - - return preds_array - - def forward(self, input: torch.Tensor) -> torch.Tensor: - """ - Forward pass of the model. - - Args: - input (torch.Tensor): - Input tensor for the model. - - Returns: - torch.Tensor: Model output. - """ - # Call the forward method of the model in evaluation mode - self.model.eval() - return self.model(input) diff --git a/PytorchWildlife/models/detection/localization/model.py b/PytorchWildlife/models/detection/localization/model.py deleted file mode 100644 index 8692a3007..000000000 --- a/PytorchWildlife/models/detection/localization/model.py +++ /dev/null @@ -1,150 +0,0 @@ -__copyright__ = \ - """ - Copyright (C) 2024 University of Liège, Gembloux Agro-Bio Tech, Forest Is Life - All rights reserved. - - This source code is under the MIT License. - - Please contact the author Alexandre Delplanque (alexandre.delplanque@uliege.be) for any questions. - - Last modification: March 18, 2024 - """ -__author__ = "Alexandre Delplanque" -__license__ = "MIT License" -__version__ = "0.2.1" - - -import torch - -import torch.nn as nn -import numpy as np -import torchvision.transforms as T - -from typing import Optional - -from . import dla as dla_modules - -class HerdNet(nn.Module): - ''' HerdNet architecture ''' - - def __init__( - self, - num_layers: int = 34, - num_classes: int = 2, - pretrained: bool = True, - down_ratio: Optional[int] = 2, - head_conv: int = 64 - ): - ''' - Args: - num_layers (int, optional): number of layers of DLA. Defaults to 34. - num_classes (int, optional): number of output classes, background included. - Defaults to 2. - pretrained (bool, optional): set False to disable pretrained DLA encoder parameters - from ImageNet. Defaults to True. - down_ratio (int, optional): downsample ratio. Possible values are 1, 2, 4, 8, or 16. - Set to 1 to get output of the same size as input (i.e. no downsample). - Defaults to 2. - head_conv (int, optional): number of supplementary convolutional layers at the end - of decoder. Defaults to 64. - ''' - - super(HerdNet, self).__init__() - - assert down_ratio in [1, 2, 4, 8, 16], \ - f'Downsample ratio possible values are 1, 2, 4, 8 or 16, got {down_ratio}' - - base_name = 'dla{}'.format(num_layers) - - self.down_ratio = down_ratio - self.num_classes = num_classes - self.head_conv = head_conv - - self.first_level = int(np.log2(down_ratio)) - - # backbone - base = dla_modules.__dict__[base_name](pretrained=pretrained, return_levels=True) - setattr(self, 'base_0', base) - setattr(self, 'channels_0', base.channels) - - channels = self.channels_0 - - scales = [2 ** i for i in range(len(channels[self.first_level:]))] - self.dla_up = dla_modules.DLAUp(channels[self.first_level:], scales=scales) - - # bottleneck conv - self.bottleneck_conv = nn.Conv2d( - channels[-1], channels[-1], - kernel_size=1, stride=1, - padding=0, bias=True - ) - - # localization head - self.loc_head = nn.Sequential( - nn.Conv2d(channels[self.first_level], head_conv, - kernel_size=3, padding=1, bias=True), - nn.ReLU(inplace=True), - nn.Conv2d( - head_conv, 1, - kernel_size=1, stride=1, - padding=0, bias=True - ), - nn.Sigmoid() - ) - - self.loc_head[-2].bias.data.fill_(0.00) - - # classification head - self.cls_head = nn.Sequential( - nn.Conv2d(channels[-1], head_conv, - kernel_size=3, padding=1, bias=True), - nn.ReLU(inplace=True), - nn.Conv2d( - head_conv, self.num_classes, - kernel_size=1, stride=1, - padding=0, bias=True - ) - ) - - self.cls_head[-1].bias.data.fill_(0.00) - - def forward(self, input: torch.Tensor): - - encode = self.base_0(input) - bottleneck = self.bottleneck_conv(encode[-1]) - encode[-1] = bottleneck - - decode_hm = self.dla_up(encode[self.first_level:]) - # decode_cls = self.cls_dla_up(encode[-3:]) - - heatmap = self.loc_head(decode_hm) - clsmap = self.cls_head(bottleneck) - # clsmap = self.cls_head(decode_cls) - - return heatmap, clsmap - - def freeze(self, layers: list) -> None: - ''' Freeze all layers mentioned in the input list ''' - for layer in layers: - self._freeze_layer(layer) - - def _freeze_layer(self, layer_name: str) -> None: - for param in getattr(self, layer_name).parameters(): - param.requires_grad = False - - def reshape_classes(self, num_classes: int) -> None: - ''' Reshape architecture according to a new number of classes. - - Arg: - num_classes (int): new number of classes - ''' - - self.cls_head[-1] = nn.Conv2d( - self.head_conv, num_classes, - kernel_size=1, stride=1, - padding=0, bias=True - ) - - self.cls_head[-1].bias.data.fill_(0.00) - - self.num_classes = num_classes \ No newline at end of file diff --git a/PytorchWildlife/models/detection/localization/model_owl_c.py b/PytorchWildlife/models/detection/localization/model_owl_c.py deleted file mode 100644 index 52560f4b9..000000000 --- a/PytorchWildlife/models/detection/localization/model_owl_c.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import torch - -import torch.nn as nn -import numpy as np -import torchvision.transforms as T - -from typing import Optional - -from . import dla as dla_modules - - -class OWLC_Architecture(nn.Module): - ''' OWL-C (Overhead Wildlife Locator - CNN) architecture using only the localization branch ''' - - def __init__( - self, - num_layers: int = 34, - pretrained: bool = True, - down_ratio: Optional[int] = 2, - head_conv: int = 64 - ): - ''' - Args: - num_layers (int, optional): number of layers of DLA. Defaults to 34. - pretrained (bool, optional): set False to disable pretrained DLA encoder parameters - from ImageNet. Defaults to True. - down_ratio (int, optional): downsample ratio. Possible values are 1, 2, 4, 8, or 16. - Set to 1 to get output of the same size as input (i.e. no downsample). - Defaults to 2. - head_conv (int, optional): number of supplementary convolutional layers at the end - of decoder. Defaults to 64. - ''' - super(OWLC_Architecture, self).__init__() - - assert down_ratio in [1, 2, 4, 8, 16], \ - f'Downsample ratio possible values are 1, 2, 4, 8 or 16, got {down_ratio}' - - base_name = 'dla{}'.format(num_layers) - - self.down_ratio = down_ratio - self.head_conv = head_conv - - self.first_level = int(np.log2(down_ratio)) - - # backbone - base = dla_modules.__dict__[base_name](pretrained=pretrained, return_levels=True) - setattr(self, 'base_0', base) - setattr(self, 'channels_0', base.channels) - - channels = self.channels_0 - - scales = [2 ** i for i in range(len(channels[self.first_level:]))] - self.dla_up = dla_modules.DLAUp(channels[self.first_level:], scales=scales) - - # bottleneck conv - self.bottleneck_conv = nn.Conv2d( - channels[-1], channels[-1], - kernel_size=1, stride=1, - padding=0, bias=True - ) - - # localization head - self.loc_head = nn.Sequential( - nn.Conv2d(channels[self.first_level], head_conv, - kernel_size=3, padding=1, bias=True), - nn.ReLU(inplace=True), - nn.Conv2d( - head_conv, 1, - kernel_size=1, stride=1, - padding=0, bias=True - ), - nn.Sigmoid() - ) - - self.loc_head[-2].bias.data.fill_(0.00) - - def forward(self, input: torch.Tensor): - encode = self.base_0(input) - bottleneck = self.bottleneck_conv(encode[-1]) - encode[-1] = bottleneck - decode_hm = self.dla_up(encode[self.first_level:]) - heatmap = self.loc_head(decode_hm) - - return heatmap - - def freeze(self, layers: list) -> None: - ''' Freeze all layers mentioned in the input list ''' - for layer in layers: - self._freeze_layer(layer) - - def _freeze_layer(self, layer_name: str) -> None: - for param in getattr(self, layer_name).parameters(): - param.requires_grad = False - \ No newline at end of file diff --git a/PytorchWildlife/models/detection/localization/model_owl_t.py b/PytorchWildlife/models/detection/localization/model_owl_t.py deleted file mode 100644 index 7ae22350d..000000000 --- a/PytorchWildlife/models/detection/localization/model_owl_t.py +++ /dev/null @@ -1,414 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import torch -import torch.nn as nn -import numpy as np -from typing import Optional, List, Union - -from . import dla as dla_modules - -# Swin Transformer Utilities -def window_partition(x, window_size: int): - """ - x: (B, H, W, C) - return: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) - return windows - - -def window_reverse(windows, window_size: int, H: int, W: int): - """ - windows: (num_windows*B, window_size, window_size, C) - return: (B, H, W, C) - """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - - -def _build_shifted_window_mask(Hp, Wp, window_size, shift_size, device, dtype): - """ - Build attention mask for SW-MSA. - Returns: (nW, ws*ws, ws*ws) with 0 or -inf values. - """ - if shift_size == 0: - return None - - img_mask = torch.zeros((1, Hp, Wp, 1), device=device, dtype=dtype) # (1, Hp, Wp, 1) - cnt = 0 - h_slices = (slice(0, -window_size), - slice(-window_size, -shift_size), - slice(-shift_size, None)) - w_slices = (slice(0, -window_size), - slice(-window_size, -shift_size), - slice(-shift_size, None)) - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, window_size) # (nW, ws, ws, 1) - mask_windows = mask_windows.view(-1, window_size * window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - return attn_mask - - -class DropPath(nn.Module): - """Stochastic depth: drop paths per sample on residual branches.""" - def __init__(self, drop_prob: float = 0.0): - super().__init__() - self.drop_prob = float(drop_prob) - - def forward(self, x): - if self.drop_prob == 0.0 or not self.training: - return x - keep_prob = 1.0 - self.drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) - random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) - random_tensor.floor_() - return x.div(keep_prob) * random_tensor - - -class WindowAttention(nn.Module): - """Window-based Multi-head Self-Attention with relative position bias.""" - def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.): - super().__init__() - self.dim = dim - self.window_size = window_size - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = head_dim ** -0.5 - - # Relative position bias table - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) - - # Compute relative position indices - coords_h = torch.arange(self.window_size[0]) - coords_w = torch.arange(self.window_size[1]) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += self.window_size[0] - 1 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 - relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - nn.init.trunc_normal_(self.relative_position_bias_table, std=.02) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask=None): - """ - x: (nW*B, N, C) - mask: (nW, N, N) or None - """ - B_, N, C = x.shape - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] - - q = q * self.scale - attn = (q @ k.transpose(-2, -1)) - - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], - self.window_size[0] * self.window_size[1], - -1 - ) - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, N, N - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - -class SwinTransformerBlock(nn.Module): - """ Swin Transformer Block (W-MSA / SW-MSA + MLP + DropPath) """ - def __init__(self, dim, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.1, - act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)" - - self.norm1 = norm_layer(dim) - self.attn = WindowAttention( - dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) - - self.drop_path = DropPath(drop_path) if drop_path and drop_path > 0.0 else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = nn.Sequential( - nn.Linear(dim, mlp_hidden_dim), - act_layer(), - nn.Dropout(drop), - nn.Linear(mlp_hidden_dim, dim), - nn.Dropout(drop) - ) - - def forward(self, x, mask_matrix: Optional[torch.Tensor] = None): - """ - x: (B, H, W, C) - mask_matrix: optional; if None and shift > 0, generated internally - """ - B, H, W, C = x.shape - shortcut = x - x = self.norm1(x) - - # Pad to multiples of window_size - pad_l = pad_t = 0 - pad_r = (self.window_size - W % self.window_size) % self.window_size - pad_b = (self.window_size - H % self.window_size) % self.window_size - x = nn.functional.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) - _, Hp, Wp, _ = x.shape - - # Cyclic shift and mask - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - attn_mask = mask_matrix - if attn_mask is None: - attn_mask = _build_shifted_window_mask( - Hp, Wp, self.window_size, self.shift_size, x.device, x.dtype - ) - else: - shifted_x = x - attn_mask = None - - # Window partition - x_windows = window_partition(shifted_x, self.window_size) # (nW*B, ws, ws, C) - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # (nW*B, N, C) - - # W-MSA / SW-MSA - attn_windows = self.attn(x_windows, mask=attn_mask) # (nW*B, N, C) - - # Merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) - shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # (B, H', W', C) - - # Reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - else: - x = shifted_x - - # Remove padding - if pad_r > 0 or pad_b > 0: - x = x[:, :H, :W, :].contiguous() - - # Residual connection + MLP - x = shortcut + self.drop_path(x) - x = x + self.drop_path(self.mlp(self.norm2(x))) - return x - - -class MultiscaleSwinTransformer(nn.Module): - """ - Multi-scale Swin Transformer with per-scale residual connections. - - Args: - channels_list: Channels per level (e.g., [16, 32, 64, 128, 256, 512]) - window_sizes: Window size per level (default: [8,8,8,4,4,4]) - num_heads_list: Attention heads per level (default: [1,1,2,4,8,16]) - num_layers_per_scale: Blocks per level, int or list (default: [0,1,2,2,2,3]) - """ - def __init__( - self, - channels_list: List[int], - window_sizes: Optional[List[int]] = None, - num_heads_list: Optional[List[int]] = None, - num_layers_per_scale: Optional[Union[int, List[int]]] = None, - drop_path_rate: float = 0.1 - ): - super().__init__() - self.num_scales = len(channels_list) - - # Default hyperparameters - if window_sizes is None: - window_sizes = [8, 8, 8, 4, 4, 4] - if num_heads_list is None: - num_heads_list = [1, 1, 2, 4, 8, 16] - if num_layers_per_scale is None: - num_layers_per_scale = [0, 1, 2, 2, 2, 3] - - # Allow single int to be replicated across scales - if isinstance(num_layers_per_scale, int): - num_layers_per_scale = [num_layers_per_scale] * self.num_scales - - assert len(channels_list) == len(window_sizes) == len(num_heads_list) == len(num_layers_per_scale), \ - "channels_list, window_sizes, num_heads_list, and num_layers_per_scale must have the same length" - - def _fit_heads(channels, prefer_heads): - # Adjust heads so that channels % heads == 0 - h = min(prefer_heads, channels) - while channels % h != 0 and h > 1: - h -= 1 - return max(h, 1) - - self.swin_blocks = nn.ModuleList() - for C, ws, h, nl in zip(channels_list, window_sizes, num_heads_list, num_layers_per_scale): - num_heads = _fit_heads(C, h) - # Build blocks alternating shift=0 and shift=ws//2 - scale_blocks = nn.ModuleList([ - SwinTransformerBlock( - dim=C, - num_heads=num_heads, - window_size=ws, - shift_size=0 if j % 2 == 0 else ws // 2, - mlp_ratio=4.0, - qkv_bias=True, - drop=0.0, - attn_drop=0.0, - drop_path=drop_path_rate - ) for j in range(nl) - ]) - self.swin_blocks.append(scale_blocks) - - def forward(self, feature_list: List[torch.Tensor]): - """ - Args: - feature_list: List of tensors (B, C, H, W) - Returns: - List of enhanced tensors (B, C, H, W) with per-scale residual - """ - enhanced_features = [] - for features, scale_blocks in zip(feature_list, self.swin_blocks): - if len(scale_blocks) == 0: - # No Swin at this scale, passthrough - enhanced_features.append(features) - continue - - x = features.permute(0, 2, 3, 1).contiguous() # (B, H, W, C) - for block in scale_blocks: - x = block(x, mask_matrix=None) - x = x.permute(0, 3, 1, 2).contiguous() # (B, C, H, W) - - # Per-scale residual connection - enhanced_x = features + x - enhanced_features.append(enhanced_x) - - return enhanced_features - - -class OWLT_Architecture(nn.Module): - """ - OWL-T hybrid architecture: - - DLA backbone with multi-level features - - Multi-scale Swin Transformer for feature refinement - - DLAUp decoder from first_level - - Localization head (heatmap output) - """ - - def __init__( - self, - num_layers: int = 34, - pretrained_cnn: bool = True, - down_ratio: Optional[int] = 2, - head_conv: int = 64, - swin_num_layers_per_scale: Optional[Union[int, List[int]]] = None, - swin_window_sizes: Optional[List[int]] = None, - swin_num_heads: Optional[List[int]] = None, - drop_path_rate: float = 0.1 - ): - super().__init__() - - assert down_ratio in [1, 2, 4, 8, 16], f"down_ratio must be 1/2/4/8/16, got {down_ratio}" - base_name = f'dla{num_layers}' - - self.down_ratio = down_ratio - self.head_conv = head_conv - self.first_level = int(np.log2(down_ratio)) - - # DLA backbone - base = dla_modules.__dict__[base_name](pretrained=pretrained_cnn, return_levels=True) - setattr(self, 'base_0', base) - setattr(self, 'channels_0', base.channels) - channels = self.channels_0 # e.g., [16, 32, 64, 128, 256, 512] - - # Multi-scale Swin Transformer - if swin_window_sizes is None: - swin_window_sizes = [8, 8, 8, 4, 4, 4] - if swin_num_heads is None: - swin_num_heads = [1, 1, 2, 4, 8, 16] - if swin_num_layers_per_scale is None: - swin_num_layers_per_scale = [0, 1, 2, 2, 2, 3] # L0→L5; L0=0 dado down_ratio=2 - - self.multiscale_swin = MultiscaleSwinTransformer( - channels_list=channels, - window_sizes=swin_window_sizes, - num_heads_list=swin_num_heads, - num_layers_per_scale=swin_num_layers_per_scale, - drop_path_rate=drop_path_rate - ) - - # DLAUp decoder from first_level - scales = [2 ** i for i in range(len(channels[self.first_level:]))] - self.dla_up = dla_modules.DLAUp(channels[self.first_level:], scales=scales) - - # Bottleneck convolution at last scale - self.bottleneck_conv = nn.Conv2d( - channels[-1], channels[-1], - kernel_size=1, stride=1, - padding=0, bias=True - ) - - # Localization head - self.loc_head = nn.Sequential( - nn.Conv2d(channels[self.first_level], head_conv, kernel_size=3, padding=1, bias=True), - nn.ReLU(inplace=True), - nn.Conv2d(head_conv, 1, kernel_size=1, stride=1, padding=0, bias=True), - nn.Sigmoid() - ) - #self.loc_head[-2].bias.data.fill_(0.00) - - def forward(self, input: torch.Tensor): - # Multi-scale features from DLA backbone - encode = self.base_0(input) - - # Bottleneck at last level - encode[-1] = self.bottleneck_conv(encode[-1]) - - # Refine with Swin (includes per-scale residual) - enhanced_encode = self.multiscale_swin(encode) - - # Decode from first_level - decode_hm = self.dla_up(enhanced_encode[self.first_level:]) - heatmap = self.loc_head(decode_hm) - return heatmap - - def freeze(self, layers: list) -> None: - """Freeze layers by attribute name.""" - for layer in layers: - self._freeze_layer(layer) - - def _freeze_layer(self, layer_name: str) -> None: - for param in getattr(self, layer_name).parameters(): - param.requires_grad = False diff --git a/PytorchWildlife/models/detection/rtdetr_apache/__init__.py b/PytorchWildlife/models/detection/rtdetr_apache/__init__.py deleted file mode 100644 index 77db993c5..000000000 --- a/PytorchWildlife/models/detection/rtdetr_apache/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .rtdetr_apache_base import * -from .megadetectorv6_apache import * \ No newline at end of file diff --git a/PytorchWildlife/models/detection/rtdetr_apache/megadetectorv6_apache.py b/PytorchWildlife/models/detection/rtdetr_apache/megadetectorv6_apache.py deleted file mode 100644 index a4b9151e7..000000000 --- a/PytorchWildlife/models/detection/rtdetr_apache/megadetectorv6_apache.py +++ /dev/null @@ -1,44 +0,0 @@ - -from .rtdetr_apache_base import RTDETRApacheBase - -__all__ = [ - 'MegaDetectorV6Apache' -] - -class MegaDetectorV6Apache(RTDETRApacheBase): - """ - MegaDetectorV6 is a specialized class derived from the RTDETRApacheBase class - that is specifically designed for detecting animals, persons, and vehicles. - - Attributes: - CLASS_NAMES (dict): Mapping of class IDs to their respective names. - """ - - CLASS_NAMES = { - 0: "animal", - 1: "person", - 2: "vehicle" - } - - def __init__(self, weights=None, device="cpu", pretrained=True, version='MDV6-apa-rtdetr-c'): - """ - Initializes the MegaDetectorV6 model with the option to load pretrained weights. - - Args: - weights (str, optional): Path to the weights file. - device (str, optional): Device to load the model on (e.g., "cpu" or "cuda"). Default is "cpu". - pretrained (bool, optional): Whether to load the pretrained model. Default is True. - version (str, optional): Version of the model to load. Default is 'MDV6-apa-rtdetr-c'. - """ - self.IMAGE_SIZE = 640 - - if version == "MDV6-apa-rtdetr-c": - url = "https://zenodo.org/records/15398270/files/MDV6-apa-rtdetr-c.pth?download=1" - self.MODEL_NAME = "MDV6-apa-rtdetr-c.pth" - elif version == "MDV6-apa-rtdetr-e": - url = "https://zenodo.org/records/15398270/files/MDV6-apa-rtdetr-e.pth?download=1" - self.MODEL_NAME = "MDV6-apa-rtdetr-e.pth" - else: - raise ValueError('Select a valid model version: MDV6-apa-rtdetr-c or MDV6-apa-rtdetr-e') - - super(MegaDetectorV6Apache, self).__init__(weights=weights, device=device, url=url) \ No newline at end of file diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetr_apache_base.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetr_apache_base.py deleted file mode 100644 index 4e5004e38..000000000 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetr_apache_base.py +++ /dev/null @@ -1,225 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -""" RT-DETR Apache base detector class. """ - -# Importing basic libraries -import os -import supervision as sv -import wget -import torch -import torch.nn as nn -import torchvision.transforms as T - -from ..base_detector import BaseDetector -from ....data import datasets as pw_data -from PIL import Image - -import sys -from pathlib import Path -project_root = Path(__file__).resolve().parent -sys.path.append(str(project_root)) -from rtdetrv2_pytorch.src.core import YAMLConfig - -class RTDETRApacheBase(BaseDetector): - """ - Base detector class for RTDETRApacheBase framework. This class provides utility methods for - loading the model, generating results, and performing single and batch image detections. - """ - def __init__(self, weights=None, device="cpu", url=None): - """ - Initialize the RT-DETR apache detector. - - Args: - weights (str, optional): - Path to the model weights. Defaults to None. - device (str, optional): - Device for model inference. Defaults to "cpu". - url (str, optional): - URL to fetch the model weights. Defaults to None. - """ - self.transform = T.Compose([ - T.Resize((640, 640)), - T.ToTensor(), - ]) - self.weights = weights - self.device = device - self.url = url - super(RTDETRApacheBase, self).__init__(weights=self.weights, device=self.device, url=self.url) - self._load_model(weights=self.weights, device=self.device, url=self.url) - - def _load_model(self, weights=None, device="cpu", url=None): - """ - Load the RT-DETR apache model weights. - - Args: - weights (str, optional): - Path to the model weights. Defaults to None. - device (str, optional): - Device for model inference. Defaults to "cpu". - url (str, optional): - URL to fetch the model weights. Defaults to None. - Raises: - Exception: If weights are not provided. - """ - if weights: - resume = weights - elif url: - if not os.path.exists(os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME)): - os.makedirs(os.path.join(torch.hub.get_dir(), "checkpoints"), exist_ok=True) - resume = wget.download(url, out=os.path.join(torch.hub.get_dir(), "checkpoints")) - else: - resume = os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME) - else: - raise Exception("Need weights for inference.") - - if self.MODEL_NAME == "MDV6-apa-rtdetr-c.pth": - config = os.path.join(project_root, "rtdetrv2_pytorch", "configs", "rtdetrv2", "rtdetrv2_r18vd_120e_megadetector.yml") - elif self.MODEL_NAME == "MDV6-apa-rtdetr-e.pth": - config = os.path.join(project_root, "rtdetrv2_pytorch", "configs", "rtdetrv2", "rtdetrv2_r101vd_6x_megadetector.yml") - else: - raise ValueError('Select a valid model version: MDV6-apa-rtdetr-c or MDV6-apa-rtdetr-e') - - cfg = YAMLConfig(config, resume=resume) - - checkpoint = torch.load(resume, map_location='cpu') - if 'ema' in checkpoint: - state = checkpoint['ema']['module'] - else: - state = checkpoint['model'] - - cfg.model.load_state_dict(state) - - class Model(nn.Module): - def __init__(self, ) -> None: - super().__init__() - self.model = cfg.model.deploy() - self.postprocessor = cfg.postprocessor.deploy() - - def forward(self, images, orig_target_sizes): - outputs = self.model(images) - outputs = self.postprocessor(outputs, orig_target_sizes) - return outputs - - self.model = Model().to(self.device) - - def results_generation(self, preds, img_id, id_strip=None): - """ - Generate results for detection based on model predictions. - - Args: - preds (List[torch.Tensor]): - Model predictions. - img_id (str): - Image identifier. - id_strip (str, optional): - Strip specific characters from img_id. Defaults to None. - - Returns: - dict: Dictionary containing image ID, detections, and labels. - """ - class_id = preds[0].cpu().numpy().astype(int) - xyxy = preds[1].detach().cpu().numpy() - confidence = preds[2].detach().cpu().numpy() - - results = {"img_id": str(img_id).strip(id_strip)} - results["detections"] = sv.Detections( - xyxy=xyxy, - confidence=confidence, - class_id=class_id - ) - - results["labels"] = [ - f"{self.CLASS_NAMES[class_id]} {confidence:0.2f}" - for _, _, confidence, class_id, _, _ in results["detections"] - ] - - return results - - - def single_image_detection(self, img, img_path=None, det_conf_thres=0.2, id_strip=None): - """ - Perform detection on a single image. - - Args: - img (str or ndarray): - Image path or ndarray of images. - img_path (str, optional): - Image path or identifier. - det_conf_thres (float, optional): - Confidence threshold for predictions. Defaults to 0.2. - id_strip (str, optional): - Characters to strip from img_id. Defaults to None. - - Returns: - dict: Detection results. - """ - if type(img) == str: - if img_path is None: - img_path = img - im_pil = Image.open(img_path).convert('RGB') - else: - im_pil = Image.fromarray(img) - - w, h = im_pil.size - orig_size = torch.tensor([w, h])[None].to(self.device) - im_data = self.transform(im_pil)[None].to(self.device) - labels, boxes, scores = self.model(im_data, orig_size) - - scr = scores[0] - lab = labels[0][scr > det_conf_thres] - box = boxes[0][scr > det_conf_thres] - scrs = scores[0][scr > det_conf_thres] - - return self.results_generation([lab, box, scrs], img_path, id_strip) - - def batch_image_detection(self, data_path, batch_size=16, det_conf_thres=0.2, id_strip=None): - """ - Perform detection on a batch of images. - - Args: - data_path (str): - Path containing all images for inference. - batch_size (int, optional): - Batch size for inference. Defaults to 16. - det_conf_thres (float, optional): - Confidence threshold for predictions. Defaults to 0.2. - id_strip (str, optional): - Characters to strip from img_id. Defaults to None. - extension (str, optional): - Image extension to search for. Defaults to "JPG" - - Returns: - list: List of detection results for all images. - """ - dataset = pw_data.DetectionImageFolder( - data_path, - transform=self.transform, - ) - - results = [] - for i in range(len(dataset)): - im_pil = Image.open(dataset.images[i]).convert('RGB') - w, h = im_pil.size - orig_size = torch.tensor([w, h])[None].to(self.device) - im_data = self.transform(im_pil)[None].to(self.device) - - labels, boxes, scores = self.model(im_data, orig_size) - - scr = scores[0] - lab = labels[0][scr > det_conf_thres] - box = boxes[0][scr > det_conf_thres] - scrs = scores[0][scr > det_conf_thres] - - res = self.results_generation([lab, box, scrs], dataset.images[i], id_strip) - - # Normalize the coordinates for timelapse compatibility - size = orig_size[0].cpu().numpy() - normalized_coords = [[x1 / size[1], y1 / size[0], x2 / size[1], y2 / size[0]] for x1, y1, x2, y2 in res["detections"].xyxy] - res["normalized_coords"] = normalized_coords - results.append(res) - - return results - - - diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/__init__.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/__init__.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/dataset/__init__.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/dataset/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/dataset/megadetector_detection.yml b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/dataset/megadetector_detection.yml deleted file mode 100644 index 0e4f046a6..000000000 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/dataset/megadetector_detection.yml +++ /dev/null @@ -1,3 +0,0 @@ -task: detection -num_classes: 3 -remap_mscoco_category: False \ No newline at end of file diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/rtdetrv2/__init__.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/rtdetrv2/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/rtdetrv2/include/__init__.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/rtdetrv2/include/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/rtdetrv2/include/rtdetrv2_r50vd.yml b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/rtdetrv2/include/rtdetrv2_r50vd.yml deleted file mode 100644 index a5c14909b..000000000 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/rtdetrv2/include/rtdetrv2_r50vd.yml +++ /dev/null @@ -1,83 +0,0 @@ -task: detection - -model: RTDETR -criterion: RTDETRCriterionv2 -postprocessor: RTDETRPostProcessor - - -use_focal_loss: True -eval_spatial_size: [640, 640] # h w - - -RTDETR: - backbone: PResNet - encoder: HybridEncoder - decoder: RTDETRTransformerv2 - - -PResNet: - depth: 50 - variant: d - freeze_at: 0 - return_idx: [1, 2, 3] - num_stages: 4 - freeze_norm: True - pretrained: True - - -HybridEncoder: - in_channels: [512, 1024, 2048] - feat_strides: [8, 16, 32] - - # intra - hidden_dim: 256 - use_encoder_idx: [2] - num_encoder_layers: 1 - nhead: 8 - dim_feedforward: 1024 - dropout: 0. - enc_act: 'gelu' - - # cross - expansion: 1.0 - depth_mult: 1 - act: 'silu' - - -RTDETRTransformerv2: - feat_channels: [256, 256, 256] - feat_strides: [8, 16, 32] - hidden_dim: 256 - num_levels: 3 - - num_layers: 6 - num_queries: 300 - - num_denoising: 100 - label_noise_ratio: 0.5 - box_noise_scale: 1.0 # 1.0 0.4 - - eval_idx: -1 - - # NEW - num_points: [4, 4, 4] # [3,3,3] [2,2,2] - cross_attn_method: default # default, discrete - query_select_method: default # default, agnostic - - -RTDETRPostProcessor: - num_top_queries: 300 - - -RTDETRCriterionv2: - weight_dict: {loss_vfl: 1, loss_bbox: 5, loss_giou: 2,} - losses: ['vfl', 'boxes', ] - alpha: 0.75 - gamma: 2.0 - - matcher: - type: HungarianMatcher - weight_dict: {cost_class: 2, cost_bbox: 5, cost_giou: 2} - alpha: 0.25 - gamma: 2.0 - diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/rtdetrv2/rtdetrv2_r101vd_6x_megadetector.yml b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/rtdetrv2/rtdetrv2_r101vd_6x_megadetector.yml deleted file mode 100644 index 780c90eb5..000000000 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/rtdetrv2/rtdetrv2_r101vd_6x_megadetector.yml +++ /dev/null @@ -1,18 +0,0 @@ -__include__: [ - '../dataset/megadetector_detection.yml', - './include/rtdetrv2_r50vd.yml', -] - - -PResNet: - depth: 101 - - -HybridEncoder: - # intra - hidden_dim: 384 - dim_feedforward: 2048 - - -RTDETRTransformerv2: - feat_channels: [384, 384, 384] diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/rtdetrv2/rtdetrv2_r18vd_120e_megadetector.yml b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/rtdetrv2/rtdetrv2_r18vd_120e_megadetector.yml deleted file mode 100644 index c703545f0..000000000 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/configs/rtdetrv2/rtdetrv2_r18vd_120e_megadetector.yml +++ /dev/null @@ -1,21 +0,0 @@ -__include__: [ - '../dataset/megadetector_detection.yml', - './include/rtdetrv2_r50vd.yml', -] - - -PResNet: - depth: 18 - freeze_at: -1 - freeze_norm: False - pretrained: True - - -HybridEncoder: - in_channels: [128, 256, 512] - hidden_dim: 256 - expansion: 0.5 - - -RTDETRTransformerv2: - num_layers: 3 \ No newline at end of file diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/__init__.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/__init__.py deleted file mode 100644 index 8b850c549..000000000 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Copyright(c) 2023 lyuwenyu. All Rights Reserved. -""" - -# for register purpose -from . import backbone -from . import rtdetr \ No newline at end of file diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/backbone/__init__.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/backbone/__init__.py deleted file mode 100644 index 53ab01265..000000000 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/backbone/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -"""Copyright(c) 2023 lyuwenyu. All Rights Reserved. -""" - -from .presnet import PResNet \ No newline at end of file diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/backbone/common.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/backbone/common.py deleted file mode 100644 index e3f54ea2e..000000000 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/backbone/common.py +++ /dev/null @@ -1,86 +0,0 @@ -"""Copyright(c) 2023 lyuwenyu. All Rights Reserved. -""" - -import torch -import torch.nn as nn - - -class FrozenBatchNorm2d(nn.Module): - """copy and modified from https://github.com/facebookresearch/detr/blob/master/models/backbone.py - BatchNorm2d where the batch statistics and the affine parameters are fixed. - Copy-paste from torchvision.misc.ops with added eps before rqsrt, - without which any other models than torchvision.models.resnet[18,34,50,101] - produce nans. - """ - def __init__(self, num_features, eps=1e-5): - super(FrozenBatchNorm2d, self).__init__() - n = num_features - self.register_buffer("weight", torch.ones(n)) - self.register_buffer("bias", torch.zeros(n)) - self.register_buffer("running_mean", torch.zeros(n)) - self.register_buffer("running_var", torch.ones(n)) - self.eps = eps - self.num_features = n - - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs): - num_batches_tracked_key = prefix + 'num_batches_tracked' - if num_batches_tracked_key in state_dict: - del state_dict[num_batches_tracked_key] - - super(FrozenBatchNorm2d, self)._load_from_state_dict( - state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs) - - def forward(self, x): - # move reshapes to the beginning - # to make it fuser-friendly - w = self.weight.reshape(1, -1, 1, 1) - b = self.bias.reshape(1, -1, 1, 1) - rv = self.running_var.reshape(1, -1, 1, 1) - rm = self.running_mean.reshape(1, -1, 1, 1) - scale = w * (rv + self.eps).rsqrt() - bias = b - rm * scale - return x * scale + bias - - def extra_repr(self): - return ( - "{num_features}, eps={eps}".format(**self.__dict__) - ) - -def get_activation(act: str, inplace: bool=True): - """get activation - """ - if act is None: - return nn.Identity() - - elif isinstance(act, nn.Module): - return act - - act = act.lower() - - if act == 'silu' or act == 'swish': - m = nn.SiLU() - - elif act == 'relu': - m = nn.ReLU() - - elif act == 'leaky_relu': - m = nn.LeakyReLU() - - elif act == 'silu': - m = nn.SiLU() - - elif act == 'gelu': - m = nn.GELU() - - elif act == 'hardsigmoid': - m = nn.Hardsigmoid() - - else: - raise RuntimeError('') - - if hasattr(m, 'inplace'): - m.inplace = inplace - - return m diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/backbone/presnet.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/backbone/presnet.py deleted file mode 100644 index 7401c0ef6..000000000 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/backbone/presnet.py +++ /dev/null @@ -1,244 +0,0 @@ -"""Copyright(c) 2023 lyuwenyu. All Rights Reserved. -""" -import torch -import torch.nn as nn -import torch.nn.functional as F - -from collections import OrderedDict - -from .common import get_activation, FrozenBatchNorm2d - -from ..core import register - - -__all__ = ['PResNet'] - - -ResNet_cfg = { - 18: [2, 2, 2, 2], - 34: [3, 4, 6, 3], - 50: [3, 4, 6, 3], - 101: [3, 4, 23, 3], -} - - -donwload_url = { - 18: 'https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet18_vd_pretrained_from_paddle.pth', - 34: 'https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet34_vd_pretrained_from_paddle.pth', - 50: 'https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet50_vd_ssld_v2_pretrained_from_paddle.pth', - 101: 'https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet101_vd_ssld_pretrained_from_paddle.pth', -} - - -class ConvNormLayer(nn.Module): - def __init__(self, ch_in, ch_out, kernel_size, stride, padding=None, bias=False, act=None): - super().__init__() - self.conv = nn.Conv2d( - ch_in, - ch_out, - kernel_size, - stride, - padding=(kernel_size-1)//2 if padding is None else padding, - bias=bias) - self.norm = nn.BatchNorm2d(ch_out) - self.act = get_activation(act) - - def forward(self, x): - return self.act(self.norm(self.conv(x))) - - -class BasicBlock(nn.Module): - expansion = 1 - - def __init__(self, ch_in, ch_out, stride, shortcut, act='relu', variant='b'): - super().__init__() - - self.shortcut = shortcut - - if not shortcut: - if variant == 'd' and stride == 2: - self.short = nn.Sequential(OrderedDict([ - ('pool', nn.AvgPool2d(2, 2, 0, ceil_mode=True)), - ('conv', ConvNormLayer(ch_in, ch_out, 1, 1)) - ])) - else: - self.short = ConvNormLayer(ch_in, ch_out, 1, stride) - - self.branch2a = ConvNormLayer(ch_in, ch_out, 3, stride, act=act) - self.branch2b = ConvNormLayer(ch_out, ch_out, 3, 1, act=None) - self.act = nn.Identity() if act is None else get_activation(act) - - - def forward(self, x): - out = self.branch2a(x) - out = self.branch2b(out) - if self.shortcut: - short = x - else: - short = self.short(x) - - out = out + short - out = self.act(out) - - return out - - -class BottleNeck(nn.Module): - expansion = 4 - - def __init__(self, ch_in, ch_out, stride, shortcut, act='relu', variant='b'): - super().__init__() - - if variant == 'a': - stride1, stride2 = stride, 1 - else: - stride1, stride2 = 1, stride - - width = ch_out - - self.branch2a = ConvNormLayer(ch_in, width, 1, stride1, act=act) - self.branch2b = ConvNormLayer(width, width, 3, stride2, act=act) - self.branch2c = ConvNormLayer(width, ch_out * self.expansion, 1, 1) - - self.shortcut = shortcut - if not shortcut: - if variant == 'd' and stride == 2: - self.short = nn.Sequential(OrderedDict([ - ('pool', nn.AvgPool2d(2, 2, 0, ceil_mode=True)), - ('conv', ConvNormLayer(ch_in, ch_out * self.expansion, 1, 1)) - ])) - else: - self.short = ConvNormLayer(ch_in, ch_out * self.expansion, 1, stride) - - self.act = nn.Identity() if act is None else get_activation(act) - - def forward(self, x): - out = self.branch2a(x) - out = self.branch2b(out) - out = self.branch2c(out) - - if self.shortcut: - short = x - else: - short = self.short(x) - - out = out + short - out = self.act(out) - - return out - - -class Blocks(nn.Module): - def __init__(self, block, ch_in, ch_out, count, stage_num, act='relu', variant='b'): - super().__init__() - - self.blocks = nn.ModuleList() - for i in range(count): - self.blocks.append( - block( - ch_in, - ch_out, - stride=2 if i == 0 and stage_num != 2 else 1, - shortcut=False if i == 0 else True, - variant=variant, - act=act) - ) - - if i == 0: - ch_in = ch_out * block.expansion - - def forward(self, x): - out = x - for block in self.blocks: - out = block(out) - return out - - -@register() -class PResNet(nn.Module): - def __init__( - self, - depth, - variant='d', - num_stages=4, - return_idx=[0, 1, 2, 3], - act='relu', - freeze_at=-1, - freeze_norm=True, - pretrained=False): - super().__init__() - - block_nums = ResNet_cfg[depth] - ch_in = 64 - if variant in ['c', 'd']: - conv_def = [ - [3, ch_in // 2, 3, 2, "conv1_1"], - [ch_in // 2, ch_in // 2, 3, 1, "conv1_2"], - [ch_in // 2, ch_in, 3, 1, "conv1_3"], - ] - else: - conv_def = [[3, ch_in, 7, 2, "conv1_1"]] - - self.conv1 = nn.Sequential(OrderedDict([ - (name, ConvNormLayer(cin, cout, k, s, act=act)) for cin, cout, k, s, name in conv_def - ])) - - ch_out_list = [64, 128, 256, 512] - block = BottleNeck if depth >= 50 else BasicBlock - - _out_channels = [block.expansion * v for v in ch_out_list] - _out_strides = [4, 8, 16, 32] - - self.res_layers = nn.ModuleList() - for i in range(num_stages): - stage_num = i + 2 - self.res_layers.append( - Blocks(block, ch_in, ch_out_list[i], block_nums[i], stage_num, act=act, variant=variant) - ) - ch_in = _out_channels[i] - - self.return_idx = return_idx - self.out_channels = [_out_channels[_i] for _i in return_idx] - self.out_strides = [_out_strides[_i] for _i in return_idx] - - if freeze_at >= 0: - self._freeze_parameters(self.conv1) - for i in range(min(freeze_at, num_stages)): - self._freeze_parameters(self.res_layers[i]) - - if freeze_norm: - self._freeze_norm(self) - - if pretrained: - if isinstance(pretrained, bool) or 'http' in pretrained: - state = torch.hub.load_state_dict_from_url(donwload_url[depth], map_location='cpu') - else: - state = torch.load(pretrained, map_location='cpu') - self.load_state_dict(state) - print(f'Load PResNet{depth} state_dict') - - def _freeze_parameters(self, m: nn.Module): - for p in m.parameters(): - p.requires_grad = False - - def _freeze_norm(self, m: nn.Module): - if isinstance(m, nn.BatchNorm2d): - m = FrozenBatchNorm2d(m.num_features) - else: - for name, child in m.named_children(): - _child = self._freeze_norm(child) - if _child is not child: - setattr(m, name, _child) - return m - - def forward(self, x): - conv1 = self.conv1(x) - x = F.max_pool2d(conv1, kernel_size=3, stride=2, padding=1) - outs = [] - for idx, stage in enumerate(self.res_layers): - x = stage(x) - if idx in self.return_idx: - outs.append(x) - return outs - - diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/__init__.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/__init__.py deleted file mode 100644 index e1078b225..000000000 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Copyright(c) 2023 lyuwenyu. All Rights Reserved. -""" - -from .workspace import * -from .yaml_utils import * -from ._config import BaseConfig -from .yaml_config import YAMLConfig diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/_config.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/_config.py deleted file mode 100644 index 91d720ca2..000000000 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/_config.py +++ /dev/null @@ -1,76 +0,0 @@ -"""Copyright(c) 2023 lyuwenyu. All Rights Reserved. -""" - -import torch.nn as nn -from torch.utils.data import Dataset, DataLoader -from torch.optim import Optimizer -from torch.optim.lr_scheduler import LRScheduler -from torch.cuda.amp.grad_scaler import GradScaler -from torch.utils.tensorboard import SummaryWriter -from typing import Callable - - -__all__ = ['BaseConfig', ] - - -class BaseConfig(object): - - def __init__(self) -> None: - super().__init__() - - self.task :str = None - - # instance / function - self._model :nn.Module = None - self._postprocessor :nn.Module = None - self._criterion :nn.Module = None - self._optimizer :Optimizer = None - self._lr_scheduler :LRScheduler = None - self._lr_warmup_scheduler: LRScheduler = None - self._train_dataloader :DataLoader = None - self._val_dataloader :DataLoader = None - self._ema :nn.Module = None - self._scaler :GradScaler = None - self._train_dataset :Dataset = None - self._val_dataset :Dataset = None - self._collate_fn :Callable = None - self._evaluator :Callable[[nn.Module, DataLoader, str], ] = None - self._writer: SummaryWriter = None - - # dataset - self.num_workers :int = 0 - self.batch_size :int = None - self._train_batch_size :int = None - self._val_batch_size :int = None - self._train_shuffle: bool = None - self._val_shuffle: bool = None - - # runtime - self.resume :str = None - self.tuning :str = None - - self.epoches :int = None - self.last_epoch :int = -1 - - self.use_amp :bool = False - self.use_ema :bool = False - self.ema_decay :float = 0.9999 - self.ema_warmups: int = 2000 - self.sync_bn :bool = False - self.clip_max_norm : float = 0. - self.find_unused_parameters :bool = None - - self.seed :int = None - self.print_freq :int = None - self.checkpoint_freq :int = 1 - self.output_dir :str = None - self.summary_dir :str = None - self.device : str = '' - - @property - def model(self, ) -> nn.Module: - return self._model - - @property - def postprocessor(self, ) -> nn.Module: - return self._postprocessor \ No newline at end of file diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/workspace.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/workspace.py deleted file mode 100644 index ea0e41af6..000000000 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/workspace.py +++ /dev/null @@ -1,171 +0,0 @@ -""""Copyright(c) 2023 lyuwenyu. All Rights Reserved. -""" - -import inspect -import importlib -import functools -import inspect -from collections import defaultdict -from typing import Any, Dict, Optional, List - - -GLOBAL_CONFIG = defaultdict(dict) - - -def register(dct :Any=GLOBAL_CONFIG, name=None, force=False): - """ - dct: - if dct is Dict, register foo into dct as key-value pair - if dct is Clas, register as modules attibute - force - whether force register. - """ - def decorator(foo): - register_name = foo.__name__ if name is None else name - if not force: - if inspect.isclass(dct): - assert not hasattr(dct, foo.__name__), \ - f'module {dct.__name__} has {foo.__name__}' - else: - assert foo.__name__ not in dct, \ - f'{foo.__name__} has been already registered' - - if inspect.isfunction(foo): - @functools.wraps(foo) - def wrap_func(*args, **kwargs): - return foo(*args, **kwargs) - if isinstance(dct, dict): - dct[foo.__name__] = wrap_func - elif inspect.isclass(dct): - setattr(dct, foo.__name__, wrap_func) - else: - raise AttributeError('') - return wrap_func - - elif inspect.isclass(foo): - dct[register_name] = extract_schema(foo) - - else: - raise ValueError(f'Do not support {type(foo)} register') - - return foo - - return decorator - - - -def extract_schema(module: type): - """ - Args: - module (type), - Return: - Dict, - """ - argspec = inspect.getfullargspec(module.__init__) - arg_names = [arg for arg in argspec.args if arg != 'self'] - num_defualts = len(argspec.defaults) if argspec.defaults is not None else 0 - num_requires = len(arg_names) - num_defualts - - schame = dict() - schame['_name'] = module.__name__ - schame['_pymodule'] = importlib.import_module(module.__module__) - schame['_inject'] = getattr(module, '__inject__', []) - schame['_share'] = getattr(module, '__share__', []) - schame['_kwargs'] = {} - for i, name in enumerate(arg_names): - if name in schame['_share']: - assert i >= num_requires, 'share config must have default value.' - value = argspec.defaults[i - num_requires] - - elif i >= num_requires: - value = argspec.defaults[i - num_requires] - - else: - value = None - - schame[name] = value - schame['_kwargs'][name] = value - - return schame - - -def create(type_or_name, global_cfg=GLOBAL_CONFIG, **kwargs): - """ - """ - assert type(type_or_name) in (type, str), 'create should be modules or name.' - - name = type_or_name if isinstance(type_or_name, str) else type_or_name.__name__ - - if name in global_cfg: - if hasattr(global_cfg[name], '__dict__'): - return global_cfg[name] - else: - raise ValueError('The module {} is not registered'.format(name)) - - cfg = global_cfg[name] - - if isinstance(cfg, dict) and 'type' in cfg: - _cfg: dict = global_cfg[cfg['type']] - # clean args - _keys = [k for k in _cfg.keys() if not k.startswith('_')] - for _arg in _keys: - del _cfg[_arg] - _cfg.update(_cfg['_kwargs']) # restore default args - _cfg.update(cfg) # load config args - _cfg.update(kwargs) # TODO recive extra kwargs - name = _cfg.pop('type') # pop extra key `type` (from cfg) - - return create(name, global_cfg) - - module = getattr(cfg['_pymodule'], name) - module_kwargs = {} - module_kwargs.update(cfg) - - # shared var - for k in cfg['_share']: - if k in global_cfg: - module_kwargs[k] = global_cfg[k] - else: - module_kwargs[k] = cfg[k] - - # inject - for k in cfg['_inject']: - _k = cfg[k] - - if _k is None: - continue - - if isinstance(_k, str): - if _k not in global_cfg: - raise ValueError(f'Missing inject config of {_k}.') - - _cfg = global_cfg[_k] - - if isinstance(_cfg, dict): - module_kwargs[k] = create(_cfg['_name'], global_cfg) - else: - module_kwargs[k] = _cfg - - elif isinstance(_k, dict): - if 'type' not in _k.keys(): - raise ValueError(f'Missing inject for `type` style.') - - _type = str(_k['type']) - if _type not in global_cfg: - raise ValueError(f'Missing {_type} in inspect stage.') - - _cfg: dict = global_cfg[_type] - _keys = [k for k in _cfg.keys() if not k.startswith('_')] - for _arg in _keys: - del _cfg[_arg] - _cfg.update(_cfg['_kwargs']) # restore default values - _cfg.update(_k) # load config args - name = _cfg.pop('type') # pop extra key (`type` from _k) - module_kwargs[k] = create(name, global_cfg) - - else: - raise ValueError(f'Inject does not support {_k}') - - module_kwargs = {k: v for k, v in module_kwargs.items() if not k.startswith('_')} - - return module(**module_kwargs) \ No newline at end of file diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/yaml_config.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/yaml_config.py deleted file mode 100644 index 8e0e02fd4..000000000 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/yaml_config.py +++ /dev/null @@ -1,38 +0,0 @@ -"""Copyright(c) 2023 lyuwenyu. All Rights Reserved. -""" - -import torch -import copy - -from ._config import BaseConfig -from .workspace import create -from .yaml_utils import load_config, merge_config, merge_dict - -class YAMLConfig(BaseConfig): - def __init__(self, cfg_path: str, **kwargs) -> None: - super().__init__() - - cfg = load_config(cfg_path) - cfg = merge_dict(cfg, kwargs) - - self.yaml_cfg = copy.deepcopy(cfg) - - for k in super().__dict__: - if not k.startswith('_') and k in cfg: - self.__dict__[k] = cfg[k] - - @property - def global_cfg(self, ): - return merge_config(self.yaml_cfg, inplace=False, overwrite=False) - - @property - def model(self, ) -> torch.nn.Module: - if self._model is None and 'model' in self.yaml_cfg: - self._model = create(self.yaml_cfg['model'], self.global_cfg) - return super().model - - @property - def postprocessor(self, ) -> torch.nn.Module: - if self._postprocessor is None and 'postprocessor' in self.yaml_cfg: - self._postprocessor = create(self.yaml_cfg['postprocessor'], self.global_cfg) - return super().postprocessor \ No newline at end of file diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/yaml_utils.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/yaml_utils.py deleted file mode 100644 index 1033bcf21..000000000 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/core/yaml_utils.py +++ /dev/null @@ -1,97 +0,0 @@ -""""Copyright(c) 2023 lyuwenyu. All Rights Reserved. -""" - -import os -import copy -import yaml -from typing import Any, Dict, Optional, List - -from .workspace import GLOBAL_CONFIG - -__all__ = [ - 'load_config', - 'merge_config', - 'merge_dict', -] - - -INCLUDE_KEY = '__include__' - - -def load_config(file_path, cfg=dict()): - """load config - """ - _, ext = os.path.splitext(file_path) - assert ext in ['.yml', '.yaml'], "only support yaml files" - - with open(file_path) as f: - file_cfg = yaml.load(f, Loader=yaml.Loader) - if file_cfg is None: - return {} - - if INCLUDE_KEY in file_cfg: - base_yamls = list(file_cfg[INCLUDE_KEY]) - for base_yaml in base_yamls: - if base_yaml.startswith('~'): - base_yaml = os.path.expanduser(base_yaml) - - if not base_yaml.startswith('/'): - base_yaml = os.path.join(os.path.dirname(file_path), base_yaml) - - with open(base_yaml) as f: - base_cfg = load_config(base_yaml, cfg) - merge_dict(cfg, base_cfg) - - return merge_dict(cfg, file_cfg) - - -def merge_dict(dct, another_dct, inplace=True) -> Dict: - """merge another_dct into dct - """ - def _merge(dct, another) -> Dict: - for k in another: - if (k in dct and isinstance(dct[k], dict) and isinstance(another[k], dict)): - _merge(dct[k], another[k]) - else: - dct[k] = another[k] - - return dct - - if not inplace: - dct = copy.deepcopy(dct) - - return _merge(dct, another_dct) - - -def merge_config(cfg, another_cfg=GLOBAL_CONFIG, inplace: bool=False, overwrite: bool=False): - """ - Merge another_cfg into cfg, return the merged config - - Example: - - cfg1 = load_config('./rtdetrv2_r18vd_6x_coco.yml') - cfg1 = merge_config(cfg, inplace=True) - - cfg2 = load_config('./rtdetr_r50vd_6x_coco.yml') - cfg2 = merge_config(cfg2, inplace=True) - - model1 = create(cfg1['model'], cfg1) - model2 = create(cfg2['model'], cfg2) - """ - def _merge(dct, another): - for k in another: - if k not in dct: - dct[k] = another[k] - - elif isinstance(dct[k], dict) and isinstance(another[k], dict): - _merge(dct[k], another[k]) - - elif overwrite: - dct[k] = another[k] - - return cfg - - if not inplace: - cfg = copy.deepcopy(cfg) - - return _merge(cfg, another_cfg) diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/__init__.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/__init__.py deleted file mode 100644 index 1df1f9669..000000000 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -"""Copyright(c) 2023 lyuwenyu. All Rights Reserved. -""" - - -from .rtdetr import RTDETR -from .hybrid_encoder import HybridEncoder -from .rtdetr_postprocessor import RTDETRPostProcessor -from .rtdetrv2_decoder import RTDETRTransformerv2 \ No newline at end of file diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/box_ops.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/box_ops.py deleted file mode 100644 index c45752a6c..000000000 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/box_ops.py +++ /dev/null @@ -1,22 +0,0 @@ -""" -# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -https://github.com/facebookresearch/detr/blob/main/util/box_ops.py -""" - -import torch -from torch import Tensor -from torchvision.ops.boxes import box_area - - -def box_cxcywh_to_xyxy(x: Tensor) -> Tensor: - x_c, y_c, w, h = x.unbind(-1) - b = [(x_c - 0.5 * w), (y_c - 0.5 * h), - (x_c + 0.5 * w), (y_c + 0.5 * h)] - return torch.stack(b, dim=-1) - - -def box_xyxy_to_cxcywh(x: Tensor) -> Tensor: - x0, y0, x1, y1 = x.unbind(-1) - b = [(x0 + x1) / 2, (y0 + y1) / 2, - (x1 - x0), (y1 - y0)] - return torch.stack(b, dim=-1) \ No newline at end of file diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/denoising.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/denoising.py deleted file mode 100644 index c50f214c6..000000000 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/denoising.py +++ /dev/null @@ -1,99 +0,0 @@ -"""Copyright(c) 2023 lyuwenyu. All Rights Reserved. -""" - -import torch - -from .utils import inverse_sigmoid -from .box_ops import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh - - -def get_contrastive_denoising_training_group(targets, - num_classes, - num_queries, - class_embed, - num_denoising=100, - label_noise_ratio=0.5, - box_noise_scale=1.0,): - """cnd""" - if num_denoising <= 0: - return None, None, None, None - - num_gts = [len(t['labels']) for t in targets] - device = targets[0]['labels'].device - - max_gt_num = max(num_gts) - if max_gt_num == 0: - return None, None, None, None - - num_group = num_denoising // max_gt_num - num_group = 1 if num_group == 0 else num_group - # pad gt to max_num of a batch - bs = len(num_gts) - - input_query_class = torch.full([bs, max_gt_num], num_classes, dtype=torch.int32, device=device) - input_query_bbox = torch.zeros([bs, max_gt_num, 4], device=device) - pad_gt_mask = torch.zeros([bs, max_gt_num], dtype=torch.bool, device=device) - - for i in range(bs): - num_gt = num_gts[i] - if num_gt > 0: - input_query_class[i, :num_gt] = targets[i]['labels'] - input_query_bbox[i, :num_gt] = targets[i]['boxes'] - pad_gt_mask[i, :num_gt] = 1 - # each group has positive and negative queries. - input_query_class = input_query_class.tile([1, 2 * num_group]) - input_query_bbox = input_query_bbox.tile([1, 2 * num_group, 1]) - pad_gt_mask = pad_gt_mask.tile([1, 2 * num_group]) - # positive and negative mask - negative_gt_mask = torch.zeros([bs, max_gt_num * 2, 1], device=device) - negative_gt_mask[:, max_gt_num:] = 1 - negative_gt_mask = negative_gt_mask.tile([1, num_group, 1]) - positive_gt_mask = 1 - negative_gt_mask - # contrastive denoising training positive index - positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask - dn_positive_idx = torch.nonzero(positive_gt_mask)[:, 1] - dn_positive_idx = torch.split(dn_positive_idx, [n * num_group for n in num_gts]) - # total denoising queries - num_denoising = int(max_gt_num * 2 * num_group) - - if label_noise_ratio > 0: - mask = torch.rand_like(input_query_class, dtype=torch.float) < (label_noise_ratio * 0.5) - # randomly put a new one here - new_label = torch.randint_like(mask, 0, num_classes, dtype=input_query_class.dtype) - input_query_class = torch.where(mask & pad_gt_mask, new_label, input_query_class) - - if box_noise_scale > 0: - known_bbox = box_cxcywh_to_xyxy(input_query_bbox) - diff = torch.tile(input_query_bbox[..., 2:] * 0.5, [1, 1, 2]) * box_noise_scale - rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0 - rand_part = torch.rand_like(input_query_bbox) - rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (1 - negative_gt_mask) - known_bbox += (rand_sign * rand_part * diff) - known_bbox = torch.clip(known_bbox, min=0.0, max=1.0) - input_query_bbox = box_xyxy_to_cxcywh(known_bbox) - input_query_bbox_unact = inverse_sigmoid(input_query_bbox) - - input_query_logits = class_embed(input_query_class) - - tgt_size = num_denoising + num_queries - attn_mask = torch.full([tgt_size, tgt_size], False, dtype=torch.bool, device=device) - # match query cannot see the reconstruction - attn_mask[num_denoising:, :num_denoising] = True - - # reconstruct cannot see each other - for i in range(num_group): - if i == 0: - attn_mask[max_gt_num * 2 * i: max_gt_num * 2 * (i + 1), max_gt_num * 2 * (i + 1): num_denoising] = True - if i == num_group - 1: - attn_mask[max_gt_num * 2 * i: max_gt_num * 2 * (i + 1), :max_gt_num * i * 2] = True - else: - attn_mask[max_gt_num * 2 * i: max_gt_num * 2 * (i + 1), max_gt_num * 2 * (i + 1): num_denoising] = True - attn_mask[max_gt_num * 2 * i: max_gt_num * 2 * (i + 1), :max_gt_num * 2 * i] = True - - dn_meta = { - "dn_positive_idx": dn_positive_idx, - "dn_num_group": num_group, - "dn_num_split": [num_denoising, num_queries] - } - - return input_query_logits, input_query_bbox_unact, attn_mask, dn_meta diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/hybrid_encoder.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/hybrid_encoder.py deleted file mode 100644 index 15e5acf7e..000000000 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/hybrid_encoder.py +++ /dev/null @@ -1,330 +0,0 @@ -"""Copyright(c) 2023 lyuwenyu. All Rights Reserved. -""" - -import copy -from collections import OrderedDict - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from .utils import get_activation - -from ..core import register - - -__all__ = ['HybridEncoder'] - - - -class ConvNormLayer(nn.Module): - def __init__(self, ch_in, ch_out, kernel_size, stride, padding=None, bias=False, act=None): - super().__init__() - self.conv = nn.Conv2d( - ch_in, - ch_out, - kernel_size, - stride, - padding=(kernel_size-1)//2 if padding is None else padding, - bias=bias) - self.norm = nn.BatchNorm2d(ch_out) - self.act = nn.Identity() if act is None else get_activation(act) - - def forward(self, x): - return self.act(self.norm(self.conv(x))) - - -class RepVggBlock(nn.Module): - def __init__(self, ch_in, ch_out, act='relu'): - super().__init__() - self.ch_in = ch_in - self.ch_out = ch_out - self.conv1 = ConvNormLayer(ch_in, ch_out, 3, 1, padding=1, act=None) - self.conv2 = ConvNormLayer(ch_in, ch_out, 1, 1, padding=0, act=None) - self.act = nn.Identity() if act is None else get_activation(act) - - def forward(self, x): - if hasattr(self, 'conv'): - y = self.conv(x) - else: - y = self.conv1(x) + self.conv2(x) - - return self.act(y) - - def convert_to_deploy(self): - if not hasattr(self, 'conv'): - self.conv = nn.Conv2d(self.ch_in, self.ch_out, 3, 1, padding=1) - - kernel, bias = self.get_equivalent_kernel_bias() - self.conv.weight.data = kernel - self.conv.bias.data = bias - - def get_equivalent_kernel_bias(self): - kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1) - kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2) - - return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1), bias3x3 + bias1x1 - - def _pad_1x1_to_3x3_tensor(self, kernel1x1): - if kernel1x1 is None: - return 0 - else: - return F.pad(kernel1x1, [1, 1, 1, 1]) - - def _fuse_bn_tensor(self, branch: ConvNormLayer): - if branch is None: - return 0, 0 - kernel = branch.conv.weight - running_mean = branch.norm.running_mean - running_var = branch.norm.running_var - gamma = branch.norm.weight - beta = branch.norm.bias - eps = branch.norm.eps - std = (running_var + eps).sqrt() - t = (gamma / std).reshape(-1, 1, 1, 1) - return kernel * t, beta - running_mean * gamma / std - - -class CSPRepLayer(nn.Module): - def __init__(self, - in_channels, - out_channels, - num_blocks=3, - expansion=1.0, - bias=None, - act="silu"): - super(CSPRepLayer, self).__init__() - hidden_channels = int(out_channels * expansion) - self.conv1 = ConvNormLayer(in_channels, hidden_channels, 1, 1, bias=bias, act=act) - self.conv2 = ConvNormLayer(in_channels, hidden_channels, 1, 1, bias=bias, act=act) - self.bottlenecks = nn.Sequential(*[ - RepVggBlock(hidden_channels, hidden_channels, act=act) for _ in range(num_blocks) - ]) - if hidden_channels != out_channels: - self.conv3 = ConvNormLayer(hidden_channels, out_channels, 1, 1, bias=bias, act=act) - else: - self.conv3 = nn.Identity() - - def forward(self, x): - x_1 = self.conv1(x) - x_1 = self.bottlenecks(x_1) - x_2 = self.conv2(x) - return self.conv3(x_1 + x_2) - - -# transformer -class TransformerEncoderLayer(nn.Module): - def __init__(self, - d_model, - nhead, - dim_feedforward=2048, - dropout=0.1, - activation="relu", - normalize_before=False): - super().__init__() - self.normalize_before = normalize_before - - self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout, batch_first=True) - - self.linear1 = nn.Linear(d_model, dim_feedforward) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_feedforward, d_model) - - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - self.activation = get_activation(activation) - - @staticmethod - def with_pos_embed(tensor, pos_embed): - return tensor if pos_embed is None else tensor + pos_embed - - def forward(self, src, src_mask=None, pos_embed=None) -> torch.Tensor: - residual = src - if self.normalize_before: - src = self.norm1(src) - q = k = self.with_pos_embed(src, pos_embed) - src, _ = self.self_attn(q, k, value=src, attn_mask=src_mask) - - src = residual + self.dropout1(src) - if not self.normalize_before: - src = self.norm1(src) - - residual = src - if self.normalize_before: - src = self.norm2(src) - src = self.linear2(self.dropout(self.activation(self.linear1(src)))) - src = residual + self.dropout2(src) - if not self.normalize_before: - src = self.norm2(src) - return src - - -class TransformerEncoder(nn.Module): - def __init__(self, encoder_layer, num_layers, norm=None): - super(TransformerEncoder, self).__init__() - self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(num_layers)]) - self.num_layers = num_layers - self.norm = norm - - def forward(self, src, src_mask=None, pos_embed=None) -> torch.Tensor: - output = src - for layer in self.layers: - output = layer(output, src_mask=src_mask, pos_embed=pos_embed) - - if self.norm is not None: - output = self.norm(output) - - return output - - -@register() -class HybridEncoder(nn.Module): - __share__ = ['eval_spatial_size', ] - - def __init__(self, - in_channels=[512, 1024, 2048], - feat_strides=[8, 16, 32], - hidden_dim=256, - nhead=8, - dim_feedforward = 1024, - dropout=0.0, - enc_act='gelu', - use_encoder_idx=[2], - num_encoder_layers=1, - pe_temperature=10000, - expansion=1.0, - depth_mult=1.0, - act='silu', - eval_spatial_size=None, - version='v2'): - super().__init__() - self.in_channels = in_channels - self.feat_strides = feat_strides - self.hidden_dim = hidden_dim - self.use_encoder_idx = use_encoder_idx - self.num_encoder_layers = num_encoder_layers - self.pe_temperature = pe_temperature - self.eval_spatial_size = eval_spatial_size - self.out_channels = [hidden_dim for _ in range(len(in_channels))] - self.out_strides = feat_strides - - # channel projection - self.input_proj = nn.ModuleList() - for in_channel in in_channels: - if version == 'v1': - proj = nn.Sequential( - nn.Conv2d(in_channel, hidden_dim, kernel_size=1, bias=False), - nn.BatchNorm2d(hidden_dim)) - elif version == 'v2': - proj = nn.Sequential(OrderedDict([ - ('conv', nn.Conv2d(in_channel, hidden_dim, kernel_size=1, bias=False)), - ('norm', nn.BatchNorm2d(hidden_dim)) - ])) - else: - raise AttributeError() - - self.input_proj.append(proj) - - # encoder transformer - encoder_layer = TransformerEncoderLayer( - hidden_dim, - nhead=nhead, - dim_feedforward=dim_feedforward, - dropout=dropout, - activation=enc_act) - - self.encoder = nn.ModuleList([ - TransformerEncoder(copy.deepcopy(encoder_layer), num_encoder_layers) for _ in range(len(use_encoder_idx)) - ]) - - # top-down fpn - self.lateral_convs = nn.ModuleList() - self.fpn_blocks = nn.ModuleList() - for _ in range(len(in_channels) - 1, 0, -1): - self.lateral_convs.append(ConvNormLayer(hidden_dim, hidden_dim, 1, 1, act=act)) - self.fpn_blocks.append( - CSPRepLayer(hidden_dim * 2, hidden_dim, round(3 * depth_mult), act=act, expansion=expansion) - ) - - # bottom-up pan - self.downsample_convs = nn.ModuleList() - self.pan_blocks = nn.ModuleList() - for _ in range(len(in_channels) - 1): - self.downsample_convs.append( - ConvNormLayer(hidden_dim, hidden_dim, 3, 2, act=act) - ) - self.pan_blocks.append( - CSPRepLayer(hidden_dim * 2, hidden_dim, round(3 * depth_mult), act=act, expansion=expansion) - ) - - self._reset_parameters() - - def _reset_parameters(self): - if self.eval_spatial_size: - for idx in self.use_encoder_idx: - stride = self.feat_strides[idx] - pos_embed = self.build_2d_sincos_position_embedding( - self.eval_spatial_size[1] // stride, self.eval_spatial_size[0] // stride, - self.hidden_dim, self.pe_temperature) - setattr(self, f'pos_embed{idx}', pos_embed) - #self.register_buffer(f'pos_embed{idx}', pos_embed) - - @staticmethod - def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.): - """ - """ - grid_w = torch.arange(int(w), dtype=torch.float32) - grid_h = torch.arange(int(h), dtype=torch.float32) - grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing='ij') - assert embed_dim % 4 == 0, \ - 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding' - pos_dim = embed_dim // 4 - omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim - omega = 1. / (temperature ** omega) - - out_w = grid_w.flatten()[..., None] @ omega[None] - out_h = grid_h.flatten()[..., None] @ omega[None] - - return torch.concat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)[None, :, :] - - def forward(self, feats): - assert len(feats) == len(self.in_channels) - proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)] - - # encoder - if self.num_encoder_layers > 0: - for i, enc_ind in enumerate(self.use_encoder_idx): - h, w = proj_feats[enc_ind].shape[2:] - # flatten [B, C, H, W] to [B, HxW, C] - src_flatten = proj_feats[enc_ind].flatten(2).permute(0, 2, 1) - if self.training or self.eval_spatial_size is None: - pos_embed = self.build_2d_sincos_position_embedding( - w, h, self.hidden_dim, self.pe_temperature).to(src_flatten.device) - else: - pos_embed = getattr(self, f'pos_embed{enc_ind}', None).to(src_flatten.device) - - memory :torch.Tensor = self.encoder[i](src_flatten, pos_embed=pos_embed) - proj_feats[enc_ind] = memory.permute(0, 2, 1).reshape(-1, self.hidden_dim, h, w).contiguous() - - # broadcasting and fusion - inner_outs = [proj_feats[-1]] - for idx in range(len(self.in_channels) - 1, 0, -1): - feat_heigh = inner_outs[0] - feat_low = proj_feats[idx - 1] - feat_heigh = self.lateral_convs[len(self.in_channels) - 1 - idx](feat_heigh) - inner_outs[0] = feat_heigh - upsample_feat = F.interpolate(feat_heigh, scale_factor=2., mode='nearest') - inner_out = self.fpn_blocks[len(self.in_channels)-1-idx](torch.concat([upsample_feat, feat_low], dim=1)) - inner_outs.insert(0, inner_out) - - outs = [inner_outs[0]] - for idx in range(len(self.in_channels) - 1): - feat_low = outs[-1] - feat_height = inner_outs[idx + 1] - downsample_feat = self.downsample_convs[idx](feat_low) - out = self.pan_blocks[idx](torch.concat([downsample_feat, feat_height], dim=1)) - outs.append(out) - - return outs diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/rtdetr.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/rtdetr.py deleted file mode 100644 index 77a23be5c..000000000 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/rtdetr.py +++ /dev/null @@ -1,44 +0,0 @@ -"""Copyright(c) 2023 lyuwenyu. All Rights Reserved. -""" - -import torch -import torch.nn as nn -import torch.nn.functional as F - -import random -import numpy as np -from typing import List - -from ..core import register - - -__all__ = ['RTDETR', ] - - -@register() -class RTDETR(nn.Module): - __inject__ = ['backbone', 'encoder', 'decoder', ] - - def __init__(self, \ - backbone: nn.Module, - encoder: nn.Module, - decoder: nn.Module, - ): - super().__init__() - self.backbone = backbone - self.decoder = decoder - self.encoder = encoder - - def forward(self, x, targets=None): - x = self.backbone(x) - x = self.encoder(x) - x = self.decoder(x, targets) - - return x - - def deploy(self, ): - self.eval() - for m in self.modules(): - if hasattr(m, 'convert_to_deploy'): - m.convert_to_deploy() - return self diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/rtdetr_postprocessor.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/rtdetr_postprocessor.py deleted file mode 100644 index 6d024fa45..000000000 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/rtdetr_postprocessor.py +++ /dev/null @@ -1,93 +0,0 @@ -"""Copyright(c) 2023 lyuwenyu. All Rights Reserved. -""" - -import torch -import torch.nn as nn -import torch.nn.functional as F - -import torchvision - -from ..core import register - - -__all__ = ['RTDETRPostProcessor'] - - -def mod(a, b): - out = a - a // b * b - return out - - -@register() -class RTDETRPostProcessor(nn.Module): - __share__ = [ - 'num_classes', - 'use_focal_loss', - 'num_top_queries', - 'remap_mscoco_category' - ] - - def __init__( - self, - num_classes=80, - use_focal_loss=True, - num_top_queries=300, - remap_mscoco_category=False - ) -> None: - super().__init__() - self.use_focal_loss = use_focal_loss - self.num_top_queries = num_top_queries - self.num_classes = int(num_classes) - self.remap_mscoco_category = remap_mscoco_category - self.deploy_mode = False - - def extra_repr(self) -> str: - return f'use_focal_loss={self.use_focal_loss}, num_classes={self.num_classes}, num_top_queries={self.num_top_queries}' - - # def forward(self, outputs, orig_target_sizes): - def forward(self, outputs, orig_target_sizes: torch.Tensor): - logits, boxes = outputs['pred_logits'], outputs['pred_boxes'] - # orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0) - - bbox_pred = torchvision.ops.box_convert(boxes, in_fmt='cxcywh', out_fmt='xyxy') - bbox_pred *= orig_target_sizes.repeat(1, 2).unsqueeze(1) - - if self.use_focal_loss: - scores = F.sigmoid(logits) - scores, index = torch.topk(scores.flatten(1), self.num_top_queries, dim=-1) - # TODO for older tensorrt - # labels = index % self.num_classes - labels = mod(index, self.num_classes) - index = index // self.num_classes - boxes = bbox_pred.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, bbox_pred.shape[-1])) - - else: - scores = F.softmax(logits)[:, :, :-1] - scores, labels = scores.max(dim=-1) - if scores.shape[1] > self.num_top_queries: - scores, index = torch.topk(scores, self.num_top_queries, dim=-1) - labels = torch.gather(labels, dim=1, index=index) - boxes = torch.gather(boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1])) - - # TODO for onnx export - if self.deploy_mode: - return labels, boxes, scores - - # TODO - if self.remap_mscoco_category: - from ...data.dataset import mscoco_label2category - labels = torch.tensor([mscoco_label2category[int(x.item())] for x in labels.flatten()])\ - .to(boxes.device).reshape(labels.shape) - - results = [] - for lab, box, sco in zip(labels, boxes, scores): - result = dict(labels=lab, boxes=box, scores=sco) - results.append(result) - - return results - - - def deploy(self, ): - self.eval() - self.deploy_mode = True - return self diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/rtdetrv2_decoder.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/rtdetrv2_decoder.py deleted file mode 100644 index 945fbfd49..000000000 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/rtdetrv2_decoder.py +++ /dev/null @@ -1,608 +0,0 @@ -"""Copyright(c) 2023 lyuwenyu. All Rights Reserved. -""" - -import math -import copy -import functools -from collections import OrderedDict - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.nn.init as init -from typing import List - -from .denoising import get_contrastive_denoising_training_group -from .utils import deformable_attention_core_func_v2, get_activation, inverse_sigmoid, bias_init_with_prob - -from ..core import register - -__all__ = ['RTDETRTransformerv2'] - - -class MLP(nn.Module): - def __init__(self, input_dim, hidden_dim, output_dim, num_layers, act='relu'): - super().__init__() - self.num_layers = num_layers - h = [hidden_dim] * (num_layers - 1) - self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) - self.act = get_activation(act) - - def forward(self, x): - for i, layer in enumerate(self.layers): - x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x) - return x - - -class MSDeformableAttention(nn.Module): - def __init__( - self, - embed_dim=256, - num_heads=8, - num_levels=4, - num_points=4, - method='default', - offset_scale=0.5, - ): - """Multi-Scale Deformable Attention - """ - super(MSDeformableAttention, self).__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.num_levels = num_levels - self.offset_scale = offset_scale - - if isinstance(num_points, list): - assert len(num_points) == num_levels, '' - num_points_list = num_points - else: - num_points_list = [num_points for _ in range(num_levels)] - - self.num_points_list = num_points_list - - num_points_scale = [1/n for n in num_points_list for _ in range(n)] - self.register_buffer('num_points_scale', torch.tensor(num_points_scale, dtype=torch.float32)) - - self.total_points = num_heads * sum(num_points_list) - self.method = method - - self.head_dim = embed_dim // num_heads - assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" - - self.sampling_offsets = nn.Linear(embed_dim, self.total_points * 2) - self.attention_weights = nn.Linear(embed_dim, self.total_points) - self.value_proj = nn.Linear(embed_dim, embed_dim) - self.output_proj = nn.Linear(embed_dim, embed_dim) - - self.ms_deformable_attn_core = functools.partial(deformable_attention_core_func_v2, method=self.method) - - self._reset_parameters() - - if method == 'discrete': - for p in self.sampling_offsets.parameters(): - p.requires_grad = False - - def _reset_parameters(self): - # sampling_offsets - init.constant_(self.sampling_offsets.weight, 0) - thetas = torch.arange(self.num_heads, dtype=torch.float32) * (2.0 * math.pi / self.num_heads) - grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) - grid_init = grid_init / grid_init.abs().max(-1, keepdim=True).values - grid_init = grid_init.reshape(self.num_heads, 1, 2).tile([1, sum(self.num_points_list), 1]) - scaling = torch.concat([torch.arange(1, n + 1) for n in self.num_points_list]).reshape(1, -1, 1) - grid_init *= scaling - self.sampling_offsets.bias.data[...] = grid_init.flatten() - - # attention_weights - init.constant_(self.attention_weights.weight, 0) - init.constant_(self.attention_weights.bias, 0) - - # proj - init.xavier_uniform_(self.value_proj.weight) - init.constant_(self.value_proj.bias, 0) - init.xavier_uniform_(self.output_proj.weight) - init.constant_(self.output_proj.bias, 0) - - - def forward(self, - query: torch.Tensor, - reference_points: torch.Tensor, - value: torch.Tensor, - value_spatial_shapes: List[int], - value_mask: torch.Tensor=None): - """ - Args: - query (Tensor): [bs, query_length, C] - reference_points (Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0,0), - bottom-right (1, 1), including padding area - value (Tensor): [bs, value_length, C] - value_spatial_shapes (List): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] - value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements - - Returns: - output (Tensor): [bs, Length_{query}, C] - """ - bs, Len_q = query.shape[:2] - Len_v = value.shape[1] - - value = self.value_proj(value) - if value_mask is not None: - value = value * value_mask.to(value.dtype).unsqueeze(-1) - - value = value.reshape(bs, Len_v, self.num_heads, self.head_dim) - - sampling_offsets: torch.Tensor = self.sampling_offsets(query) - sampling_offsets = sampling_offsets.reshape(bs, Len_q, self.num_heads, sum(self.num_points_list), 2) - - attention_weights = self.attention_weights(query).reshape(bs, Len_q, self.num_heads, sum(self.num_points_list)) - attention_weights = F.softmax(attention_weights, dim=-1).reshape(bs, Len_q, self.num_heads, sum(self.num_points_list)) - - if reference_points.shape[-1] == 2: - offset_normalizer = torch.tensor(value_spatial_shapes) - offset_normalizer = offset_normalizer.flip([1]).reshape(1, 1, 1, self.num_levels, 1, 2) - sampling_locations = reference_points.reshape(bs, Len_q, 1, self.num_levels, 1, 2) + sampling_offsets / offset_normalizer - elif reference_points.shape[-1] == 4: - # reference_points [8, 480, None, 1, 4] - # sampling_offsets [8, 480, 8, 12, 2] - num_points_scale = self.num_points_scale.to(dtype=query.dtype).unsqueeze(-1) - offset = sampling_offsets * num_points_scale * reference_points[:, :, None, :, 2:] * self.offset_scale - sampling_locations = reference_points[:, :, None, :, :2] + offset - else: - raise ValueError( - "Last dim of reference_points must be 2 or 4, but get {} instead.". - format(reference_points.shape[-1])) - - output = self.ms_deformable_attn_core(value, value_spatial_shapes, sampling_locations, attention_weights, self.num_points_list) - - output = self.output_proj(output) - - return output - - -class TransformerDecoderLayer(nn.Module): - def __init__(self, - d_model=256, - n_head=8, - dim_feedforward=1024, - dropout=0., - activation='relu', - n_levels=4, - n_points=4, - cross_attn_method='default'): - super(TransformerDecoderLayer, self).__init__() - - # self attention - self.self_attn = nn.MultiheadAttention(d_model, n_head, dropout=dropout, batch_first=True) - self.dropout1 = nn.Dropout(dropout) - self.norm1 = nn.LayerNorm(d_model) - - # cross attention - self.cross_attn = MSDeformableAttention(d_model, n_head, n_levels, n_points, method=cross_attn_method) - self.dropout2 = nn.Dropout(dropout) - self.norm2 = nn.LayerNorm(d_model) - - # ffn - self.linear1 = nn.Linear(d_model, dim_feedforward) - self.activation = get_activation(activation) - self.dropout3 = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_feedforward, d_model) - self.dropout4 = nn.Dropout(dropout) - self.norm3 = nn.LayerNorm(d_model) - - self._reset_parameters() - - def _reset_parameters(self): - init.xavier_uniform_(self.linear1.weight) - init.xavier_uniform_(self.linear2.weight) - - def with_pos_embed(self, tensor, pos): - return tensor if pos is None else tensor + pos - - def forward_ffn(self, tgt): - return self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) - - def forward(self, - target, - reference_points, - memory, - memory_spatial_shapes, - attn_mask=None, - memory_mask=None, - query_pos_embed=None): - # self attention - q = k = self.with_pos_embed(target, query_pos_embed) - - target2, _ = self.self_attn(q, k, value=target, attn_mask=attn_mask) - target = target + self.dropout1(target2) - target = self.norm1(target) - - # cross attention - target2 = self.cross_attn(\ - self.with_pos_embed(target, query_pos_embed), - reference_points, - memory, - memory_spatial_shapes, - memory_mask) - target = target + self.dropout2(target2) - target = self.norm2(target) - - # ffn - target2 = self.forward_ffn(target) - target = target + self.dropout4(target2) - target = self.norm3(target) - - return target - - -class TransformerDecoder(nn.Module): - def __init__(self, hidden_dim, decoder_layer, num_layers, eval_idx=-1): - super(TransformerDecoder, self).__init__() - self.layers = nn.ModuleList([copy.deepcopy(decoder_layer) for _ in range(num_layers)]) - self.hidden_dim = hidden_dim - self.num_layers = num_layers - self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx - - def forward(self, - target, - ref_points_unact, - memory, - memory_spatial_shapes, - bbox_head, - score_head, - query_pos_head, - attn_mask=None, - memory_mask=None): - dec_out_bboxes = [] - dec_out_logits = [] - ref_points_detach = F.sigmoid(ref_points_unact) - - output = target - for i, layer in enumerate(self.layers): - ref_points_input = ref_points_detach.unsqueeze(2) - query_pos_embed = query_pos_head(ref_points_detach) - - output = layer(output, ref_points_input, memory, memory_spatial_shapes, attn_mask, memory_mask, query_pos_embed) - - inter_ref_bbox = F.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points_detach)) - - if self.training: - dec_out_logits.append(score_head[i](output)) - if i == 0: - dec_out_bboxes.append(inter_ref_bbox) - else: - dec_out_bboxes.append(F.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points))) - - elif i == self.eval_idx: - dec_out_logits.append(score_head[i](output)) - dec_out_bboxes.append(inter_ref_bbox) - break - - ref_points = inter_ref_bbox - ref_points_detach = inter_ref_bbox.detach() - - return torch.stack(dec_out_bboxes), torch.stack(dec_out_logits) - - -@register() -class RTDETRTransformerv2(nn.Module): - __share__ = ['num_classes', 'eval_spatial_size'] - - def __init__(self, - num_classes=80, - hidden_dim=256, - num_queries=300, - feat_channels=[512, 1024, 2048], - feat_strides=[8, 16, 32], - num_levels=3, - num_points=4, - nhead=8, - num_layers=6, - dim_feedforward=1024, - dropout=0., - activation="relu", - num_denoising=100, - label_noise_ratio=0.5, - box_noise_scale=1.0, - learn_query_content=False, - eval_spatial_size=None, - eval_idx=-1, - eps=1e-2, - aux_loss=True, - cross_attn_method='default', - query_select_method='default'): - super().__init__() - assert len(feat_channels) <= num_levels - assert len(feat_strides) == len(feat_channels) - - for _ in range(num_levels - len(feat_strides)): - feat_strides.append(feat_strides[-1] * 2) - - self.hidden_dim = hidden_dim - self.nhead = nhead - self.feat_strides = feat_strides - self.num_levels = num_levels - self.num_classes = num_classes - self.num_queries = num_queries - self.eps = eps - self.num_layers = num_layers - self.eval_spatial_size = eval_spatial_size - self.aux_loss = aux_loss - - assert query_select_method in ('default', 'one2many', 'agnostic'), '' - assert cross_attn_method in ('default', 'discrete'), '' - self.cross_attn_method = cross_attn_method - self.query_select_method = query_select_method - - # backbone feature projection - self._build_input_proj_layer(feat_channels) - - # Transformer module - decoder_layer = TransformerDecoderLayer(hidden_dim, nhead, dim_feedforward, dropout, \ - activation, num_levels, num_points, cross_attn_method=cross_attn_method) - self.decoder = TransformerDecoder(hidden_dim, decoder_layer, num_layers, eval_idx) - - # denoising - self.num_denoising = num_denoising - self.label_noise_ratio = label_noise_ratio - self.box_noise_scale = box_noise_scale - if num_denoising > 0: - self.denoising_class_embed = nn.Embedding(num_classes+1, hidden_dim, padding_idx=num_classes) - init.normal_(self.denoising_class_embed.weight[:-1]) - - # decoder embedding - self.learn_query_content = learn_query_content - if learn_query_content: - self.tgt_embed = nn.Embedding(num_queries, hidden_dim) - self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, 2) - - # if num_select_queries != self.num_queries: - # layer = TransformerEncoderLayer(hidden_dim, nhead, dim_feedforward, activation='gelu') - # self.encoder = TransformerEncoder(layer, 1) - - self.enc_output = nn.Sequential(OrderedDict([ - ('proj', nn.Linear(hidden_dim, hidden_dim)), - ('norm', nn.LayerNorm(hidden_dim,)), - ])) - - if query_select_method == 'agnostic': - self.enc_score_head = nn.Linear(hidden_dim, 1) - else: - self.enc_score_head = nn.Linear(hidden_dim, num_classes) - - self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, 3) - - # decoder head - self.dec_score_head = nn.ModuleList([ - nn.Linear(hidden_dim, num_classes) for _ in range(num_layers) - ]) - self.dec_bbox_head = nn.ModuleList([ - MLP(hidden_dim, hidden_dim, 4, 3) for _ in range(num_layers) - ]) - - # init encoder output anchors and valid_mask - if self.eval_spatial_size: - anchors, valid_mask = self._generate_anchors() - self.register_buffer('anchors', anchors) - self.register_buffer('valid_mask', valid_mask) - - self._reset_parameters() - - def _reset_parameters(self): - bias = bias_init_with_prob(0.01) - init.constant_(self.enc_score_head.bias, bias) - init.constant_(self.enc_bbox_head.layers[-1].weight, 0) - init.constant_(self.enc_bbox_head.layers[-1].bias, 0) - - for _cls, _reg in zip(self.dec_score_head, self.dec_bbox_head): - init.constant_(_cls.bias, bias) - init.constant_(_reg.layers[-1].weight, 0) - init.constant_(_reg.layers[-1].bias, 0) - - init.xavier_uniform_(self.enc_output[0].weight) - if self.learn_query_content: - init.xavier_uniform_(self.tgt_embed.weight) - init.xavier_uniform_(self.query_pos_head.layers[0].weight) - init.xavier_uniform_(self.query_pos_head.layers[1].weight) - for m in self.input_proj: - init.xavier_uniform_(m[0].weight) - - def _build_input_proj_layer(self, feat_channels): - self.input_proj = nn.ModuleList() - for in_channels in feat_channels: - self.input_proj.append( - nn.Sequential(OrderedDict([ - ('conv', nn.Conv2d(in_channels, self.hidden_dim, 1, bias=False)), - ('norm', nn.BatchNorm2d(self.hidden_dim,))]) - ) - ) - - in_channels = feat_channels[-1] - - for _ in range(self.num_levels - len(feat_channels)): - self.input_proj.append( - nn.Sequential(OrderedDict([ - ('conv', nn.Conv2d(in_channels, self.hidden_dim, 3, 2, padding=1, bias=False)), - ('norm', nn.BatchNorm2d(self.hidden_dim))]) - ) - ) - in_channels = self.hidden_dim - - def _get_encoder_input(self, feats: List[torch.Tensor]): - # get projection features - proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)] - if self.num_levels > len(proj_feats): - len_srcs = len(proj_feats) - for i in range(len_srcs, self.num_levels): - if i == len_srcs: - proj_feats.append(self.input_proj[i](feats[-1])) - else: - proj_feats.append(self.input_proj[i](proj_feats[-1])) - - # get encoder inputs - feat_flatten = [] - spatial_shapes = [] - for i, feat in enumerate(proj_feats): - _, _, h, w = feat.shape - # [b, c, h, w] -> [b, h*w, c] - feat_flatten.append(feat.flatten(2).permute(0, 2, 1)) - # [num_levels, 2] - spatial_shapes.append([h, w]) - # [b, l, c] - feat_flatten = torch.concat(feat_flatten, 1) - return feat_flatten, spatial_shapes - - def _generate_anchors(self, - spatial_shapes=None, - grid_size=0.05, - dtype=torch.float32, - device='cpu'): - if spatial_shapes is None: - spatial_shapes = [] - eval_h, eval_w = self.eval_spatial_size - for s in self.feat_strides: - spatial_shapes.append([int(eval_h / s), int(eval_w / s)]) - - anchors = [] - for lvl, (h, w) in enumerate(spatial_shapes): - grid_y, grid_x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij') - grid_xy = torch.stack([grid_x, grid_y], dim=-1) - grid_xy = (grid_xy.unsqueeze(0) + 0.5) / torch.tensor([w, h], dtype=dtype) - wh = torch.ones_like(grid_xy) * grid_size * (2.0 ** lvl) - lvl_anchors = torch.concat([grid_xy, wh], dim=-1).reshape(-1, h * w, 4) - anchors.append(lvl_anchors) - - anchors = torch.concat(anchors, dim=1).to(device) - valid_mask = ((anchors > self.eps) * (anchors < 1 - self.eps)).all(-1, keepdim=True) - anchors = torch.log(anchors / (1 - anchors)) - anchors = torch.where(valid_mask, anchors, torch.inf) - - return anchors, valid_mask - - - def _get_decoder_input(self, - memory: torch.Tensor, - spatial_shapes, - denoising_logits=None, - denoising_bbox_unact=None): - - # prepare input for decoder - if self.training or self.eval_spatial_size is None: - anchors, valid_mask = self._generate_anchors(spatial_shapes, device=memory.device) - else: - anchors = self.anchors - valid_mask = self.valid_mask - - # memory = torch.where(valid_mask, memory, 0) - # TODO fix type error for onnx export - memory = valid_mask.to(memory.dtype) * memory - - output_memory :torch.Tensor = self.enc_output(memory) - enc_outputs_logits :torch.Tensor = self.enc_score_head(output_memory) - enc_outputs_coord_unact :torch.Tensor = self.enc_bbox_head(output_memory) + anchors - - enc_topk_bboxes_list, enc_topk_logits_list = [], [] - enc_topk_memory, enc_topk_logits, enc_topk_bbox_unact = \ - self._select_topk(output_memory, enc_outputs_logits, enc_outputs_coord_unact, self.num_queries) - - if self.training: - enc_topk_bboxes = F.sigmoid(enc_topk_bbox_unact) - enc_topk_bboxes_list.append(enc_topk_bboxes) - enc_topk_logits_list.append(enc_topk_logits) - - # if self.num_select_queries != self.num_queries: - # raise NotImplementedError('') - - if self.learn_query_content: - content = self.tgt_embed.weight.unsqueeze(0).tile([memory.shape[0], 1, 1]) - else: - content = enc_topk_memory.detach() - - enc_topk_bbox_unact = enc_topk_bbox_unact.detach() - - if denoising_bbox_unact is not None: - enc_topk_bbox_unact = torch.concat([denoising_bbox_unact, enc_topk_bbox_unact], dim=1) - content = torch.concat([denoising_logits, content], dim=1) - - return content, enc_topk_bbox_unact, enc_topk_bboxes_list, enc_topk_logits_list - - def _select_topk(self, memory: torch.Tensor, outputs_logits: torch.Tensor, outputs_coords_unact: torch.Tensor, topk: int): - if self.query_select_method == 'default': - _, topk_ind = torch.topk(outputs_logits.max(-1).values, topk, dim=-1) - - elif self.query_select_method == 'one2many': - _, topk_ind = torch.topk(outputs_logits.flatten(1), topk, dim=-1) - topk_ind = topk_ind // self.num_classes - - elif self.query_select_method == 'agnostic': - _, topk_ind = torch.topk(outputs_logits.squeeze(-1), topk, dim=-1) - - topk_ind: torch.Tensor - - topk_coords = outputs_coords_unact.gather(dim=1, \ - index=topk_ind.unsqueeze(-1).repeat(1, 1, outputs_coords_unact.shape[-1])) - - topk_logits = outputs_logits.gather(dim=1, \ - index=topk_ind.unsqueeze(-1).repeat(1, 1, outputs_logits.shape[-1])) - - topk_memory = memory.gather(dim=1, \ - index=topk_ind.unsqueeze(-1).repeat(1, 1, memory.shape[-1])) - - return topk_memory, topk_logits, topk_coords - - - def forward(self, feats, targets=None): - # input projection and embedding - memory, spatial_shapes = self._get_encoder_input(feats) - - # prepare denoising training - if self.training and self.num_denoising > 0: - denoising_logits, denoising_bbox_unact, attn_mask, dn_meta = \ - get_contrastive_denoising_training_group(targets, \ - self.num_classes, - self.num_queries, - self.denoising_class_embed, - num_denoising=self.num_denoising, - label_noise_ratio=self.label_noise_ratio, - box_noise_scale=self.box_noise_scale, ) - else: - denoising_logits, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None - - init_ref_contents, init_ref_points_unact, enc_topk_bboxes_list, enc_topk_logits_list = \ - self._get_decoder_input(memory, spatial_shapes, denoising_logits, denoising_bbox_unact) - - # decoder - out_bboxes, out_logits = self.decoder( - init_ref_contents, - init_ref_points_unact, - memory, - spatial_shapes, - self.dec_bbox_head, - self.dec_score_head, - self.query_pos_head, - attn_mask=attn_mask) - - if self.training and dn_meta is not None: - dn_out_bboxes, out_bboxes = torch.split(out_bboxes, dn_meta['dn_num_split'], dim=2) - dn_out_logits, out_logits = torch.split(out_logits, dn_meta['dn_num_split'], dim=2) - - out = {'pred_logits': out_logits[-1], 'pred_boxes': out_bboxes[-1]} - - if self.training and self.aux_loss: - out['aux_outputs'] = self._set_aux_loss(out_logits[:-1], out_bboxes[:-1]) - out['enc_aux_outputs'] = self._set_aux_loss(enc_topk_logits_list, enc_topk_bboxes_list) - out['enc_meta'] = {'class_agnostic': self.query_select_method == 'agnostic'} - - if dn_meta is not None: - out['dn_aux_outputs'] = self._set_aux_loss(dn_out_logits, dn_out_bboxes) - out['dn_meta'] = dn_meta - - return out - - - @torch.jit.unused - def _set_aux_loss(self, outputs_class, outputs_coord): - # this is a workaround to make torchscript happy, as torchscript - # doesn't support dictionary with non-homogeneous values, such - # as a dict having both a Tensor and a list. - return [{'pred_logits': a, 'pred_boxes': b} - for a, b in zip(outputs_class, outputs_coord)] diff --git a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/utils.py b/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/utils.py deleted file mode 100644 index 7653bf9e6..000000000 --- a/PytorchWildlife/models/detection/rtdetr_apache/rtdetrv2_pytorch/src/rtdetr/utils.py +++ /dev/null @@ -1,127 +0,0 @@ -"""Copyright(c) 2023 lyuwenyu. All Rights Reserved. -""" - -import math -from typing import List - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -def inverse_sigmoid(x: torch.Tensor, eps: float=1e-5) -> torch.Tensor: - x = x.clip(min=0., max=1.) - return torch.log(x.clip(min=eps) / (1 - x).clip(min=eps)) - - -def bias_init_with_prob(prior_prob=0.01): - """initialize conv/fc bias value according to a given probability value.""" - bias_init = float(-math.log((1 - prior_prob) / prior_prob)) - return bias_init - - -def deformable_attention_core_func_v2(\ - value: torch.Tensor, - value_spatial_shapes, - sampling_locations: torch.Tensor, - attention_weights: torch.Tensor, - num_points_list: List[int], - method='default'): - """ - Args: - value (Tensor): [bs, value_length, n_head, c] - value_spatial_shapes (Tensor|List): [n_levels, 2] - value_level_start_index (Tensor|List): [n_levels] - sampling_locations (Tensor): [bs, query_length, n_head, n_levels * n_points, 2] - attention_weights (Tensor): [bs, query_length, n_head, n_levels * n_points] - - Returns: - output (Tensor): [bs, Length_{query}, C] - """ - bs, _, n_head, c = value.shape - _, Len_q, _, _, _ = sampling_locations.shape - - split_shape = [h * w for h, w in value_spatial_shapes] - value_list = value.permute(0, 2, 3, 1).flatten(0, 1).split(split_shape, dim=-1) - - # sampling_offsets [8, 480, 8, 12, 2] - if method == 'default': - sampling_grids = 2 * sampling_locations - 1 - - elif method == 'discrete': - sampling_grids = sampling_locations - - sampling_grids = sampling_grids.permute(0, 2, 1, 3, 4).flatten(0, 1) - sampling_locations_list = sampling_grids.split(num_points_list, dim=-2) - - sampling_value_list = [] - for level, (h, w) in enumerate(value_spatial_shapes): - value_l = value_list[level].reshape(bs * n_head, c, h, w) - sampling_grid_l: torch.Tensor = sampling_locations_list[level] - - if method == 'default': - sampling_value_l = F.grid_sample( - value_l, - sampling_grid_l, - mode='bilinear', - padding_mode='zeros', - align_corners=False) - - elif method == 'discrete': - # n * m, seq, n, 2 - sampling_coord = (sampling_grid_l * torch.tensor([[w, h]], device=value.device) + 0.5).to(torch.int64) - - # FIX ME? for rectangle input - sampling_coord = sampling_coord.clamp(0, h - 1) - sampling_coord = sampling_coord.reshape(bs * n_head, Len_q * num_points_list[level], 2) - - s_idx = torch.arange(sampling_coord.shape[0], device=value.device).unsqueeze(-1).repeat(1, sampling_coord.shape[1]) - sampling_value_l: torch.Tensor = value_l[s_idx, :, sampling_coord[..., 1], sampling_coord[..., 0]] # n l c - - sampling_value_l = sampling_value_l.permute(0, 2, 1).reshape(bs * n_head, c, Len_q, num_points_list[level]) - - sampling_value_list.append(sampling_value_l) - - attn_weights = attention_weights.permute(0, 2, 1, 3).reshape(bs * n_head, 1, Len_q, sum(num_points_list)) - weighted_sample_locs = torch.concat(sampling_value_list, dim=-1) * attn_weights - output = weighted_sample_locs.sum(-1).reshape(bs, n_head * c, Len_q) - - return output.permute(0, 2, 1) - - -def get_activation(act: str, inpace: bool=True): - """get activation - """ - if act is None: - return nn.Identity() - - elif isinstance(act, nn.Module): - return act - - act = act.lower() - - if act == 'silu' or act == 'swish': - m = nn.SiLU() - - elif act == 'relu': - m = nn.ReLU() - - elif act == 'leaky_relu': - m = nn.LeakyReLU() - - elif act == 'silu': - m = nn.SiLU() - - elif act == 'gelu': - m = nn.GELU() - - elif act == 'hardsigmoid': - m = nn.Hardsigmoid() - - else: - raise RuntimeError('') - - if hasattr(m, 'inplace'): - m.inplace = inpace - - return m diff --git a/PytorchWildlife/models/detection/ultralytics_based/Deepfaune.py b/PytorchWildlife/models/detection/ultralytics_based/Deepfaune.py deleted file mode 100644 index 36c84592f..000000000 --- a/PytorchWildlife/models/detection/ultralytics_based/Deepfaune.py +++ /dev/null @@ -1,43 +0,0 @@ -""" -This is a Pytorch-Wildlife loader for the Deepfaune detector. -The original Deepfaune model is available at: https://www.deepfaune.cnrs.fr/en/ -Licence: CC BY-SA 4.0 -Copyright CNRS 2024 -simon.chamaille@cefe.cnrs.fr; vincent.miele@univ-lyon1.fr -""" - -from .yolov8_base import YOLOV8Base - -__all__ = [ - 'DeepfauneDetector', -] - -class DeepfauneDetector(YOLOV8Base): - """ - MegaDetectorV6 is a specialized class derived from the YOLOV8Base class - that is specifically designed for detecting animals, persons, and vehicles. - - Attributes: - CLASS_NAMES (dict): Mapping of class IDs to their respective names. - """ - - CLASS_NAMES = { - 0: "animal", - 1: "person", - 2: "vehicle" - } - - def __init__(self, weights=None, device="cpu"): - """ - Initializes the MegaDetectorV5 model with the option to load pretrained weights. - - Args: - weights (str, optional): Path to the weights file. - device (str, optional): Device to load the model on (e.g., "cpu" or "cuda"). Default is "cpu". - """ - self.IMAGE_SIZE = 960 - - url = "https://pbil.univ-lyon1.fr/software/download/deepfaune/v1.3/deepfaune-yolov8s_960.pt" - self.MODEL_NAME = "deepfaune-yolov8s_960.pt" - - super(DeepfauneDetector, self).__init__(weights=weights, device=device, url=url) \ No newline at end of file diff --git a/PytorchWildlife/models/detection/ultralytics_based/__init__.py b/PytorchWildlife/models/detection/ultralytics_based/__init__.py deleted file mode 100644 index cb9fd8b1e..000000000 --- a/PytorchWildlife/models/detection/ultralytics_based/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .yolov5_base import * -from .yolov8_base import * -from .megadetectorv5 import * -from .megadetectorv6 import * -from .megadetectorv6_distributed import * -from .Deepfaune import * diff --git a/PytorchWildlife/models/detection/ultralytics_based/megadetectorv5.py b/PytorchWildlife/models/detection/ultralytics_based/megadetectorv5.py deleted file mode 100644 index cbbbb6671..000000000 --- a/PytorchWildlife/models/detection/ultralytics_based/megadetectorv5.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the AGPL License. - -from .yolov5_base import YOLOV5Base - -__all__ = [ - 'MegaDetectorV5' -] - -class MegaDetectorV5(YOLOV5Base): - """ - MegaDetectorV5 is a specialized class derived from the YOLOV5Base class - that is specifically designed for detecting animals, persons, and vehicles. - - Attributes: - IMAGE_SIZE (int): The standard image size used during training. - STRIDE (int): Stride value used in the detector. - CLASS_NAMES (dict): Mapping of class IDs to their respective names. - """ - - IMAGE_SIZE = 1280 # image size used in training - STRIDE = 64 - CLASS_NAMES = { - 0: "animal", - 1: "person", - 2: "vehicle" - } - - def __init__(self, weights=None, device="cpu", pretrained=True, version="a"): - """ - Initializes the MegaDetectorV5 model with the option to load pretrained weights. - - Args: - weights (str, optional): Path to the weights file. - device (str, optional): Device to load the model on (e.g., "cpu" or "cuda"). Default is "cpu". - pretrained (bool, optional): Whether to load the pretrained model. Default is True. - version (str, optional): Version of the MegaDetectorV5 model to load. Default is "a". - """ - - if pretrained: - if version == "a": - url = "https://zenodo.org/records/13357337/files/md_v5a.0.0.pt?download=1" - elif version == "b": - url = "https://zenodo.org/records/10023414/files/MegaDetector_v5b.0.0.pt?download=1" - else: - url = None - - import site - import sys - sys.path.insert(0, site.getsitepackages()[0]+'/yolov5') - - super(MegaDetectorV5, self).__init__(weights=weights, device=device, url=url) - - - - - -# %% diff --git a/PytorchWildlife/models/detection/ultralytics_based/megadetectorv6.py b/PytorchWildlife/models/detection/ultralytics_based/megadetectorv6.py deleted file mode 100644 index be0dfb6e7..000000000 --- a/PytorchWildlife/models/detection/ultralytics_based/megadetectorv6.py +++ /dev/null @@ -1,53 +0,0 @@ - -from .yolov8_base import YOLOV8Base - -__all__ = [ - 'MegaDetectorV6' -] - -class MegaDetectorV6(YOLOV8Base): - """ - MegaDetectorV6 is a specialized class derived from the YOLOV8Base class - that is specifically designed for detecting animals, persons, and vehicles. - - Attributes: - CLASS_NAMES (dict): Mapping of class IDs to their respective names. - """ - - CLASS_NAMES = { - 0: "animal", - 1: "person", - 2: "vehicle" - } - - def __init__(self, weights=None, device="cpu", pretrained=True, version='MDV6-yolov9-c'): - """ - Initializes the MegaDetectorV5 model with the option to load pretrained weights. - - Args: - weights (str, optional): Path to the weights file. - device (str, optional): Device to load the model on (e.g., "cpu" or "cuda"). Default is "cpu". - pretrained (bool, optional): Whether to load the pretrained model. Default is True. - version (str, optional): Version of the model to load. Default is 'MDV6-yolov9-c'. - """ - self.IMAGE_SIZE = 1280 - - if version == 'MDV6-yolov9-c': - url = "https://zenodo.org/records/15398270/files/MDV6-yolov9-c.pt?download=1" - self.MODEL_NAME = "MDV6b-yolov9-c.pt" - elif version == 'MDV6-yolov9-e': - url = "https://zenodo.org/records/15398270/files/MDV6-yolov9-e-1280.pt?download=1" - self.MODEL_NAME = "MDV6-yolov9-e-1280.pt" - elif version == 'MDV6-yolov10-c': - url = "https://zenodo.org/records/15398270/files/MDV6-yolov10-c.pt?download=1" - self.MODEL_NAME = "MDV6-yolov10-c.pt" - elif version == 'MDV6-yolov10-e': - url = "https://zenodo.org/records/15398270/files/MDV6-yolov10-e-1280.pt?download=1" - self.MODEL_NAME = "MDV6-yolov10-e-1280.pt" - elif version == 'MDV6-rtdetr-c': - url = "https://zenodo.org/records/15398270/files/MDV6-rtdetr-c.pt?download=1" - self.MODEL_NAME = "MDV6b-rtdetr-c.pt" - else: - raise ValueError('Select a valid model version: MDV6-yolov9-c, MDV6-yolov9-e, MDV6-yolov10-c, MDV6-yolov10-e, MDV6-rtdetr-c') - - super(MegaDetectorV6, self).__init__(weights=weights, device=device, url=url) \ No newline at end of file diff --git a/PytorchWildlife/models/detection/ultralytics_based/megadetectorv6_distributed.py b/PytorchWildlife/models/detection/ultralytics_based/megadetectorv6_distributed.py deleted file mode 100644 index 1fc0f9539..000000000 --- a/PytorchWildlife/models/detection/ultralytics_based/megadetectorv6_distributed.py +++ /dev/null @@ -1,53 +0,0 @@ - -from .yolov8_distributed import YOLOV8_Distributed - -__all__ = [ - 'MegaDetectorV6_Distributed' -] - -class MegaDetectorV6_Distributed(YOLOV8_Distributed): - """ - MegaDetectorV6 is a specialized class derived from the YOLOV8Base class - that is specifically designed for detecting animals, persons, and vehicles. - - Attributes: - CLASS_NAMES (dict): Mapping of class IDs to their respective names. - """ - - CLASS_NAMES = { - 0: "animal", - 1: "person", - 2: "vehicle" - } - - def __init__(self, weights=None, device="cpu", pretrained=True, version='MDV6-yolov9-c'): - """ - Initializes the MegaDetectorV5 model with the option to load pretrained weights. - - Args: - weights (str, optional): Path to the weights file. - device (str, optional): Device to load the model on (e.g., "cpu" or "cuda"). Default is "cpu". - pretrained (bool, optional): Whether to load the pretrained model. Default is True. - version (str, optional): Version of the model to load. Default is 'MDV6-yolov9-c'. - """ - self.IMAGE_SIZE = 1280 - - if version == 'MDV6-yolov9-c': - url = "https://zenodo.org/records/15398270/files/MDV6-yolov9-c.pt?download=1" - self.MODEL_NAME = "MDV6b-yolov9-c.pt" - elif version == 'MDV6-yolov9-e': - url = "https://zenodo.org/records/15398270/files/MDV6-yolov9-e-1280.pt?download=1" - self.MODEL_NAME = "MDV6-yolov9-e-1280.pt" - elif version == 'MDV6-yolov10-c': - url = "https://zenodo.org/records/15398270/files/MDV6-yolov10-c.pt?download=1" - self.MODEL_NAME = "MDV6-yolov10-c.pt" - elif version == 'MDV6-yolov10-e': - url = "https://zenodo.org/records/15398270/files/MDV6-yolov10-e-1280.pt?download=1" - self.MODEL_NAME = "MDV6-yolov10-e-1280.pt" - elif version == 'MDV6-rtdetr-c': - url = "https://zenodo.org/records/15398270/files/MDV6-rtdetr-c.pt?download=1" - self.MODEL_NAME = "MDV6b-rtdetr-c.pt" - else: - raise ValueError('Select a valid model version: MDV6-yolov9-c, MDV6-yolov9-e, MDV6-yolov10-c, MDV6-yolov10-e, MDV6-rtdetr-c') - - super(MegaDetectorV6_Distributed, self).__init__(weights=weights, device=device, url=url) \ No newline at end of file diff --git a/PytorchWildlife/models/detection/ultralytics_based/yolov5_base.py b/PytorchWildlife/models/detection/ultralytics_based/yolov5_base.py deleted file mode 100644 index 917b6d93c..000000000 --- a/PytorchWildlife/models/detection/ultralytics_based/yolov5_base.py +++ /dev/null @@ -1,183 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the AGPL License. - -""" YoloV5 base detector class. """ - -# Importing basic libraries - -import numpy as np -from tqdm import tqdm -from PIL import Image -import supervision as sv - -import torch -from torch.utils.data import DataLoader -from torch.hub import load_state_dict_from_url - -from yolov5.utils.general import non_max_suppression, scale_boxes -from ..base_detector import BaseDetector -from ....data import transforms as pw_trans -from ....data import datasets as pw_data - - -class YOLOV5Base(BaseDetector): - """ - Base detector class for YOLO V5. This class provides utility methods for - loading the model, generating results, and performing single and batch image detections. - """ - def __init__(self, weights=None, device="cpu", url=None, transform=None): - """ - Initialize the YOLO V5 detector. - - Args: - weights (str, optional): - Path to the model weights. Defaults to None. - device (str, optional): - Device for model inference. Defaults to "cpu". - url (str, optional): - URL to fetch the model weights. Defaults to None. - transform (callable, optional): - Optional transform to be applied on the image. Defaults to None. - """ - self.transform = transform - super(YOLOV5Base, self).__init__(weights=weights, device=device, url=url) - self._load_model(weights, device, url) - - def _load_model(self, weights=None, device="cpu", url=None): - """ - Load the YOLO V5 model weights. - - Args: - weights (str, optional): - Path to the model weights. Defaults to None. - device (str, optional): - Device for model inference. Defaults to "cpu". - url (str, optional): - URL to fetch the model weights. Defaults to None. - Raises: - Exception: If weights are not provided. - """ - if weights: - checkpoint = torch.load(weights, map_location=torch.device(device)) - elif url: - checkpoint = load_state_dict_from_url(url, map_location=torch.device(self.device)) - else: - raise Exception("Need weights for inference.") - self.model = checkpoint["model"].float().fuse().eval().to(self.device) - - if not self.transform: - self.transform = pw_trans.MegaDetector_v5_Transform(target_size=self.IMAGE_SIZE, - stride=self.STRIDE) - - def results_generation(self, preds, img_id, id_strip=None) -> dict: - """ - Generate results for detection based on model predictions. - - Args: - preds (numpy.ndarray): - Model predictions. - img_id (str): - Image identifier. - id_strip (str, optional): - Strip specific characters from img_id. Defaults to None. - - Returns: - dict: Dictionary containing image ID, detections, and labels. - """ - results = {"img_id": str(img_id).strip(id_strip)} - results["detections"] = sv.Detections( - xyxy=preds[:, :4], - confidence=preds[:, 4], - class_id=preds[:, 5].astype(int) - ) - results["labels"] = [ - f"{self.CLASS_NAMES[class_id]} {confidence:0.2f}" - for confidence, class_id in zip(results["detections"].confidence, results["detections"].class_id) - ] - return results - - def single_image_detection(self, img, img_path=None, det_conf_thres=0.2, id_strip=None) -> dict: - """ - Perform detection on a single image. - - Args: - img (str or ndarray): - Image path or ndarray of images. - img_path (str, optional): - Image path or identifier. - det_conf_thres (float, optional): - Confidence threshold for predictions. Defaults to 0.2. - id_strip (str, optional): - Characters to strip from img_id. Defaults to None. - - Returns: - dict: Detection results. - """ - if type(img) == str: - if img_path is None: - img_path = img - img = np.array(Image.open(img_path).convert("RGB")) - img_size = img.shape - img = self.transform(img) - - if img_size is None: - img_size = img.permute((1, 2, 0)).shape # We need hwc instead of chw for coord scaling - preds = self.model(img.unsqueeze(0).to(self.device))[0] - preds = torch.cat(non_max_suppression(prediction=preds, conf_thres=det_conf_thres), axis=0).cpu().numpy() - # preds[:, :4] = scale_coords([self.IMAGE_SIZE] * 2, preds[:, :4], img_size).round() - preds[:, :4] = scale_boxes([self.IMAGE_SIZE] * 2, preds[:, :4], img_size).round() - res = self.results_generation(preds, img_path, id_strip) - - normalized_coords = [[x1 / img_size[1], y1 / img_size[0], x2 / img_size[1], y2 / img_size[0]] for x1, y1, x2, y2 in preds[:, :4]] - res["normalized_coords"] = normalized_coords - - return res - - def batch_image_detection(self, data_path, batch_size: int = 16, det_conf_thres: float = 0.2, id_strip: str = None) -> list[dict]: - """ - Perform detection on a batch of images. - - Args: - data_path (str): Path containing all images for inference. - batch_size (int, optional): Batch size for inference. Defaults to 16. - det_conf_thres (float, optional): Confidence threshold for predictions. Defaults to 0.2. - id_strip (str, optional): Characters to strip from img_id. Defaults to None. - - Returns: - list[dict]: List of detection results for all images. - """ - - dataset = pw_data.DetectionImageFolder( - data_path, - transform=self.transform, - ) - - # Creating a DataLoader for batching and parallel processing of the images - loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, - pin_memory=True, num_workers=0, drop_last=False) - - results = [] - with tqdm(total=len(loader)) as pbar: - for batch_index, (imgs, paths, sizes) in enumerate(loader): - imgs = imgs.to(self.device) - predictions = self.model(imgs)[0].detach().cpu() - predictions = non_max_suppression(predictions, conf_thres=det_conf_thres) - - batch_results = [] - for i, pred in enumerate(predictions): - if pred.size(0) == 0: - continue - pred = pred.numpy() - size = sizes[i].numpy() - path = paths[i] - original_coords = pred[:, :4].copy() - # pred[:, :4] = scale_coords([self.IMAGE_SIZE] * 2, pred[:, :4], size).round() - pred[:, :4] = scale_boxes([self.IMAGE_SIZE] * 2, pred[:, :4], size).round() - # Normalize the coordinates for timelapse compatibility - normalized_coords = [[x1 / size[1], y1 / size[0], x2 / size[1], y2 / size[0]] for x1, y1, x2, y2 in pred[:, :4]] - res = self.results_generation(pred, path, id_strip) - res["normalized_coords"] = normalized_coords - batch_results.append(res) - pbar.update(1) - results.extend(batch_results) - return results diff --git a/PytorchWildlife/models/detection/ultralytics_based/yolov8_base.py b/PytorchWildlife/models/detection/ultralytics_based/yolov8_base.py deleted file mode 100644 index 50d6cfbbb..000000000 --- a/PytorchWildlife/models/detection/ultralytics_based/yolov8_base.py +++ /dev/null @@ -1,220 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -""" YoloV8 base detector class. """ - -# Importing basic libraries - -import os -import wget -import numpy as np -from tqdm import tqdm -from PIL import Image -import supervision as sv - -import torch -from torch.utils.data import DataLoader - -from ultralytics.models import yolo, rtdetr - -from ..base_detector import BaseDetector -from ....data import transforms as pw_trans -from ....data import datasets as pw_data - - -class YOLOV8Base(BaseDetector): - """ - Base detector class for the new ultralytics YOLOV8 framework. This class provides utility methods for - loading the model, generating results, and performing single and batch image detections. - This base detector class is also compatible with all the new ultralytics models including YOLOV9, - RTDetr, and more. - """ - def __init__(self, weights=None, device="cpu", url=None, transform=None): - """ - Initialize the YOLOV8 detector. - - Args: - weights (str, optional): - Path to the model weights. Defaults to None. - device (str, optional): - Device for model inference. Defaults to "cpu". - url (str, optional): - URL to fetch the model weights. Defaults to None. - """ - super(YOLOV8Base, self).__init__(weights=weights, device=device, url=url) - self.transform = transform - self._load_model(weights, self.device, url) - - def _load_model(self, weights=None, device="cpu", url=None): - """ - Load the YOLOV8 model weights. - - Args: - weights (str, optional): - Path to the model weights. Defaults to None. - device (str, optional): - Device for model inference. Defaults to "cpu". - url (str, optional): - URL to fetch the model weights. Defaults to None. - Raises: - Exception: If weights are not provided. - """ - - if self.MODEL_NAME == 'MDV6b-rtdetrl.pt': - self.predictor = rtdetr.RTDETRPredictor() - else: - self.predictor = yolo.detect.DetectionPredictor() - # self.predictor.args.device = device # Will uncomment later - self.predictor.args.imgsz = self.IMAGE_SIZE - self.predictor.args.save = False # Will see if we want to use ultralytics native inference saving functions. - - if weights: - self.predictor.setup_model(weights) - elif url: - if not os.path.exists(os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME)): - os.makedirs(os.path.join(torch.hub.get_dir(), "checkpoints"), exist_ok=True) - weights = wget.download(url, out=os.path.join(torch.hub.get_dir(), "checkpoints")) - else: - weights = os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME) - self.predictor.setup_model(weights) - else: - raise Exception("Need weights for inference.") - - if not self.transform: - self.transform = pw_trans.MegaDetector_v5_Transform(target_size=self.IMAGE_SIZE, - stride=self.STRIDE) - - def results_generation(self, preds, img_id, id_strip=None) -> dict: - """ - Generate results for detection based on model predictions. - - Args: - preds (ultralytics.engine.results.Results): - Model predictions. - img_id (str): - Image identifier. - id_strip (str, optional): - Strip specific characters from img_id. Defaults to None. - - Returns: - dict: Dictionary containing image ID, detections, and labels. - """ - xyxy = preds.boxes.xyxy.cpu().numpy() - confidence = preds.boxes.conf.cpu().numpy() - class_id = preds.boxes.cls.cpu().numpy().astype(int) - - results = {"img_id": str(img_id).strip(id_strip)} - results["detections"] = sv.Detections( - xyxy=xyxy, - confidence=confidence, - class_id=class_id - ) - - results["labels"] = [ - f"{self.CLASS_NAMES[class_id]} {confidence:0.2f}" - for _, _, confidence, class_id, _, _ in results["detections"] - ] - - return results - - - def single_image_detection(self, img, img_path=None, det_conf_thres=0.2, id_strip=None) -> dict: - """ - Perform detection on a single image. - - Args: - img (str or ndarray): - Image path or ndarray of images. - img_path (str, optional): - Image path or identifier. - det_conf_thres (float, optional): - Confidence threshold for predictions. Defaults to 0.2. - id_strip (str, optional): - Characters to strip from img_id. Defaults to None. - - Returns: - dict: Detection results. - """ - - if type(img) == str: - if img_path is None: - img_path = img - img = np.array(Image.open(img_path).convert("RGB")) - img_size = img.shape - - self.predictor.args.batch = 1 - self.predictor.args.conf = det_conf_thres - - det_results = list(self.predictor.stream_inference([img])) - - res = self.results_generation(det_results[0], img_path, id_strip) - - normalized_coords = [[x1 / img_size[1], y1 / img_size[0], x2 / img_size[1], y2 / img_size[0]] - for x1, y1, x2, y2 in res["detections"].xyxy] - res["normalized_coords"] = normalized_coords - - return res - - def batch_image_detection(self, data_source, batch_size: int = 16, det_conf_thres: float = 0.2, id_strip: str = None) -> list[dict]: - """ - Perform detection on a batch of images. - - Args: - data_source (str or List[np.ndarray]): Either path containing images for inference or list of numpy arrays (RGB format, shape: H×W×3). - batch_size (int, optional): Batch size for inference. Defaults to 16. - det_conf_thres (float, optional): Confidence threshold for predictions. Defaults to 0.2. - id_strip (str, optional): Characters to strip from img_id. Defaults to None. - - Returns: - list[dict]: List of detection results for all images. - """ - self.predictor.args.batch = batch_size - self.predictor.args.conf = det_conf_thres - - # Handle numpy array input - if isinstance(data_source, (list, np.ndarray)): - results = [] - num_batches = (len(data_source) + batch_size - 1) // batch_size # Calculate total batches - - with tqdm(total=num_batches) as pbar: - for start_idx in range(0, len(data_source), batch_size): - batch_arrays = data_source[start_idx:start_idx + batch_size] - det_results = self.predictor.stream_inference(batch_arrays) - - for idx, preds in enumerate(det_results): - res = self.results_generation(preds, f"{start_idx + idx}", id_strip) - # Get size directly from numpy array - img_height, img_width = batch_arrays[idx].shape[:2] - normalized_coords = [[x1/img_width, y1/img_height, x2/img_width, y2/img_height] - for x1, y1, x2, y2 in res["detections"].xyxy] - res["normalized_coords"] = normalized_coords - results.append(res) - pbar.update(1) - return results - - # Handle image directory input - dataset = pw_data.DetectionImageFolder( - data_source, - transform=self.transform, - ) - - # Creating a DataLoader for batching and parallel processing of the images - loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, - pin_memory=True, num_workers=0, drop_last=False - ) - - results = [] - with tqdm(total=len(loader)) as pbar: - for batch_index, (imgs, paths, sizes) in enumerate(loader): - det_results = self.predictor.stream_inference(paths) - batch_results = [] - for idx, preds in enumerate(det_results): - res = self.results_generation(preds, paths[idx], id_strip) - size = preds.orig_shape - # Normalize the coordinates for timelapse compatibility - normalized_coords = [[x1 / size[1], y1 / size[0], x2 / size[1], y2 / size[0]] for x1, y1, x2, y2 in res["detections"].xyxy] - res["normalized_coords"] = normalized_coords - results.append(res) - pbar.update(1) - results.extend(batch_results) - return results diff --git a/PytorchWildlife/models/detection/ultralytics_based/yolov8_distributed.py b/PytorchWildlife/models/detection/ultralytics_based/yolov8_distributed.py deleted file mode 100644 index 688ea780d..000000000 --- a/PytorchWildlife/models/detection/ultralytics_based/yolov8_distributed.py +++ /dev/null @@ -1,232 +0,0 @@ -"""" -YoloV8 base detector class. -Modified to support PyTorch DDP framework -""" - - -import os -import time -from glob import glob -import supervision as sv -import numpy as np -import pandas as pd -from PIL import Image -import wget -import torch - -from ultralytics.models import yolo, rtdetr -from torch.utils.data import DataLoader -from tqdm import tqdm - -from ..base_detector import BaseDetector -from ....data import transforms as pw_trans -from ....data import datasets as pw_data - -class YOLOV8_Distributed(BaseDetector): - """ - Distributed YoloV8 detector class. - This class provides utility methods for loading the model, generating results, - and performing batch image detections. - """ - - def __init__(self, weights=None, device="cpu", url=None, transform=None): - """ - Initialize the YOLOV8 detector. - - Args: - weights (str, optional): - Path to the model weights. Defaults to None. - device (str, optional): - Device for model inference. Defaults to "cpu". - url (str, optional): - URL to fetch the model weights. Defaults to None. - """ - self.transform = transform - super(YOLOV8_Distributed, self).__init__(weights=weights, device=device, url=url) - self._load_model(weights, self.device, url) - - def _load_model(self, weights=None, device="cpu", url=None): - """ - Load the YOLOV8 model weights. - - Args: - weights (str, optional): - Path to the model weights. Defaults to None. - device (str, optional): - Device for model inference. Defaults to "cpu". - url (str, optional): - URL to fetch the model weights. Defaults to None. - Raises: - Exception: If weights are not provided. - """ - - if self.MODEL_NAME == 'MDV6b-rtdetrl.pt': - self.predictor = rtdetr.RTDETRPredictor() - else: - self.predictor = yolo.detect.DetectionPredictor() - # self.predictor.args.device = device # Will uncomment later - self.predictor.args.imgsz = self.IMAGE_SIZE - self.predictor.args.save = False # Will see if we want to use ultralytics native inference saving functions. - - if weights: - self.predictor.setup_model(weights) - elif url: - if not os.path.exists(os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME)): - os.makedirs(os.path.join(torch.hub.get_dir(), "checkpoints"), exist_ok=True) - weights = wget.download(url, out=os.path.join(torch.hub.get_dir(), "checkpoints")) - else: - weights = os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME) - self.predictor.setup_model(weights) - else: - raise Exception("Need weights for inference.") - - if not self.transform: - self.transform = pw_trans.MegaDetector_v5_Transform(target_size=self.IMAGE_SIZE, - stride=self.STRIDE) - - def results_generation(self, preds, img_id, id_strip=None) -> dict: - """ - Generate results for detection based on model predictions. - - Args: - preds (ultralytics.engine.results.Results): - Model predictions. - img_id (str): - Image identifier. - id_strip (str, optional): - Strip specific characters from img_id. Defaults to None. - - Returns: - dict: Dictionary containing image ID, detections, and labels. - """ - xyxy = preds.boxes.xyxy.cpu().numpy() - confidence = preds.boxes.conf.cpu().numpy() - class_id = preds.boxes.cls.cpu().numpy().astype(int) - - results = {"img_id": str(img_id).strip(id_strip)} - # results["detections"] = sv.Detections( - # xyxy=xyxy, - # confidence=confidence, - # class_id=class_id - # ) - results["detections_xyxy"] = xyxy - results["detections_confidence"] = confidence - results["detections_class_id"] = class_id - - # results["labels"] = [ - # f"{self.CLASS_NAMES[class_id]} {confidence:0.2f}" - # for _, _, confidence, class_id, _, _ in results["detections"] - # ] - - results["labels"] = [ - f"{self.CLASS_NAMES[cls_id]} {conf:0.2f}" - for cls_id, conf in zip(class_id, confidence) - ] - - results["n_animal_detected"] = np.sum(class_id == 0) - - return results - - def batch_image_detection(self, loader, batch_size, global_rank, local_rank, output_dir, det_conf_thres=0.2, checkpoint_frequency = 1000): - - """ - Perform batch image detection using the YOLOV8 model. - - Args: - loader (torch.utils.data.DataLoader): - DataLoader for input images. - batch_size (int): - Size of the batch for detection. - global_rank (int): - Global rank of the process. - local_rank (int): - Local rank of the process. - output_dir (str): - Directory to save detection results. - det_conf_thres (float, optional): - Confidence threshold for detections. Defaults to 0.2. - checkpoint_frequency (int, optional): - Frequency of saving intermediate results. Defaults to 1000. - """ - os.makedirs(output_dir, exist_ok=True) - self.predictor.args.batch = batch_size - self.predictor.args.conf = det_conf_thres - self.predictor.args.device = local_rank - - - # Create checkpoint directory - # Track batches and processed items - results = { - "img_id": [], - "detections_xyxy": [], - "detections_confidence": [], - "detections_class_id": [], - "labels": [], - "n_animal_detected": [], - "normalized_coords": [] - } - - checkpoint_dir = os.path.join(output_dir, f"checkpoints_rank{global_rank}") - os.makedirs(checkpoint_dir, exist_ok=True) - batch_counter = 0 - processed_count = 0 - start_time = time.time() - - for uuids, images in loader: - batch_counter += 1 - processed_count += len(images) - # images: tensor of shape [batch_size, 3, H, W] - # Assuming images are transformed & Standardized - det_results = self.predictor.stream_inference(images) - - for idx, preds in enumerate(det_results): - res = self.results_generation(preds, uuids[idx]) - - size = preds.orig_shape - normalized_coords = [[x1 / size[1], y1 / size[0], x2 / size[1], y2 / size[0]] for x1, y1, x2, y2 in res["detections_xyxy"]] - res["normalized_coords"] = normalized_coords - - #results.append(res) - results["img_id"].append(res["img_id"]) - results["detections_xyxy"].append(res["detections_xyxy"].tolist()) - results["detections_confidence"].append(res["detections_confidence"].tolist()) - results["detections_class_id"].append(res["detections_class_id"].tolist()) - results["labels"].append(res["labels"]) - results["n_animal_detected"].append(int(res["n_animal_detected"])) - results["normalized_coords"].append(res["normalized_coords"]) - - if batch_counter % checkpoint_frequency == 0: - elapsed = time.time() - start_time - print(f"[Rank {global_rank}] Processed {processed_count} images in {elapsed}") - - # Save intermediate results - checkpoint_path = os.path.join( - checkpoint_dir, - f"checkpoint_{batch_counter:06d}.parquet" - ) - - df = pd.DataFrame({ - "img_id": results["img_id"], - "n_animal_detected": results["n_animal_detected"] - }) - df.to_parquet(checkpoint_path, index=False) - print(f"[Rank {global_rank}] Saved checkpoint to {checkpoint_path}") - - # Save results to disk - os.makedirs(output_dir, exist_ok=True) - df = pd.DataFrame({ - "img_id": results["img_id"], - "n_animal_detected": results["n_animal_detected"] - }) - out_path = os.path.join(output_dir, f"predictions_rank{global_rank}.parquet") - df.to_parquet(out_path, index=False) - print(f"[rank {global_rank}] Saved predictions to {out_path}") - - return results - - - - - - - \ No newline at end of file diff --git a/PytorchWildlife/models/detection/yolo_mit/__init__.py b/PytorchWildlife/models/detection/yolo_mit/__init__.py deleted file mode 100644 index 6d410c7dd..000000000 --- a/PytorchWildlife/models/detection/yolo_mit/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .yolo_mit_base import * -from .megadetectorv6_mit import * \ No newline at end of file diff --git a/PytorchWildlife/models/detection/yolo_mit/megadetectorv6_mit.py b/PytorchWildlife/models/detection/yolo_mit/megadetectorv6_mit.py deleted file mode 100644 index 456bdcd7c..000000000 --- a/PytorchWildlife/models/detection/yolo_mit/megadetectorv6_mit.py +++ /dev/null @@ -1,44 +0,0 @@ - -from .yolo_mit_base import YOLOMITBase - -__all__ = [ - 'MegaDetectorV6MIT' -] - -class MegaDetectorV6MIT(YOLOMITBase): - """ - MegaDetectorV6 is a specialized class derived from the YOLOMITBase class - that is specifically designed for detecting animals, persons, and vehicles. - - Attributes: - CLASS_NAMES (dict): Mapping of class IDs to their respective names. - """ - - CLASS_NAMES = { - 0: "animal", - 1: "person", - 2: "vehicle" - } - - def __init__(self, weights=None, device="cpu", pretrained=True, version='MDV6-mit-yolov9-c'): - """ - Initializes the MegaDetectorV6 model with the option to load pretrained weights. - - Args: - weights (str, optional): Path to the weights file. - device (str, optional): Device to load the model on (e.g., "cpu" or "cuda"). Default is "cpu". - pretrained (bool, optional): Whether to load the pretrained model. Default is True. - version (str, optional): Version of the model to load. Default is 'MDV6-mit-yolov9-c'. - """ - self.IMAGE_SIZE = 640 - - if version == 'MDV6-mit-yolov9-c': - url = "https://zenodo.org/records/15398270/files/MDV6-mit-yolov9-c.ckpt?download=1" - self.MODEL_NAME = "MDV6-mit-yolov9-c.ckpt" - elif version == 'MDV6-mit-yolov9-e': - url = "https://zenodo.org/records/15398270/files/MDV6-mit-yolov9-e.ckpt?download=1" - self.MODEL_NAME = "MDV6-mit-yolov9-e.ckpt" - else: - raise ValueError('Select a valid model version: MDV6-mit-yolov9-c or MDV6-mit-yolov9-e') - - super(MegaDetectorV6MIT, self).__init__(weights=weights, device=device, url=url) \ No newline at end of file diff --git a/PytorchWildlife/models/detection/yolo_mit/yolo/__init__.py b/PytorchWildlife/models/detection/yolo_mit/yolo/__init__.py deleted file mode 100644 index 02d8f69c4..000000000 --- a/PytorchWildlife/models/detection/yolo_mit/yolo/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from yolo.model.yolo import create_model -from yolo.config import Config, NMSConfig -from yolo.tools.data_loader import AugmentationComposer, create_dataloader -from yolo.utils.model_utils import PostProcess -from yolo.utils.bounding_box_utils import create_converter - -all = [ - "create_model", - "Config", - "NMSConfig", - "AugmentationComposer" - "create_dataloader", - "PostProcess", - "create_converter", -] diff --git a/PytorchWildlife/models/detection/yolo_mit/yolo/config.py b/PytorchWildlife/models/detection/yolo_mit/yolo/config.py deleted file mode 100644 index b8b69f5d8..000000000 --- a/PytorchWildlife/models/detection/yolo_mit/yolo/config.py +++ /dev/null @@ -1,168 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Union -from torch import nn - - -@dataclass -class AnchorConfig: - strides: List[int] - reg_max: Optional[int] - anchor_num: Optional[int] - anchor: List[List[int]] - - -@dataclass -class LayerConfg: - args: Dict - source: Union[int, str, List[int]] - tags: str - - -@dataclass -class BlockConfig: - block: List[Dict[str, LayerConfg]] - - -@dataclass -class ModelConfig: - name: Optional[str] - anchor: AnchorConfig - model: Dict[str, BlockConfig] - - -@dataclass -class DownloadDetail: - url: str - file_size: int - - -@dataclass -class DownloadOptions: - details: Dict[str, DownloadDetail] - - -@dataclass -class DatasetConfig: - path: str - class_num: int - class_list: List[str] - auto_download: Optional[DownloadOptions] - - -@dataclass -class DataConfig: - shuffle: bool - batch_size: int - pin_memory: bool - cpu_num: int - image_size: List[int] - data_augment: Dict[str, int] - source: Optional[Union[str, int]] - dynamic_shape: Optional[bool] - - -@dataclass -class OptimizerArgs: - lr: float - weight_decay: float - momentum: float - - -@dataclass -class OptimizerConfig: - type: str - args: OptimizerArgs - - -@dataclass -class MatcherConfig: - iou: str - topk: int - factor: Dict[str, int] - - -@dataclass -class LossConfig: - objective: Dict[str, int] - aux: Union[bool, float] - matcher: MatcherConfig - - -@dataclass -class SchedulerConfig: - type: str - warmup: Dict[str, Union[int, float]] - args: Dict[str, Any] - - -@dataclass -class EMAConfig: - enable: bool - decay: float - - -@dataclass -class NMSConfig: - min_confidence: float - min_iou: float - max_bbox: int - - -@dataclass -class InferenceConfig: - task: str - nms: NMSConfig - data: DataConfig - fast_inference: Optional[None] - save_predict: bool - - -@dataclass -class ValidationConfig: - task: str - nms: NMSConfig - data: DataConfig - - -@dataclass -class TrainConfig: - task: str - epoch: int - data: DataConfig - optimizer: OptimizerConfig - loss: LossConfig - scheduler: SchedulerConfig - ema: EMAConfig - validation: ValidationConfig - - -@dataclass -class Config: - task: Union[TrainConfig, InferenceConfig, ValidationConfig] - dataset: DatasetConfig - model: ModelConfig - name: str - - device: Union[str, int, List[int]] - cpu_num: int - - image_size: List[int] - - out_path: str - exist_ok: bool - - lucky_number: 10 - use_wandb: bool - use_tensorboard: bool - - weight: Optional[str] - - -@dataclass -class YOLOLayer(nn.Module): - source: Union[int, str, List[int]] - output: bool - tags: str - layer_type: str - usable: bool - external: Optional[dict] \ No newline at end of file diff --git a/PytorchWildlife/models/detection/yolo_mit/yolo/model/__init__.py b/PytorchWildlife/models/detection/yolo_mit/yolo/model/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/PytorchWildlife/models/detection/yolo_mit/yolo/model/module.py b/PytorchWildlife/models/detection/yolo_mit/yolo/model/module.py deleted file mode 100644 index 87b211d43..000000000 --- a/PytorchWildlife/models/detection/yolo_mit/yolo/model/module.py +++ /dev/null @@ -1,414 +0,0 @@ -from typing import Any, Dict, List, Optional, Tuple, Union -import torch -import torch.nn.functional as F -from torch import Tensor, nn -from torch.nn.common_types import _size_2_t -import inspect - -# ----------- Utils ----------- # -def get_layer_map(): - """ - Dynamically generates a dictionary mapping class names to classes, - filtering to include only those that are subclasses of nn.Module, - ensuring they are relevant neural network layers. - """ - layer_map = {} - from yolo.model import module - - for name, obj in inspect.getmembers(module, inspect.isclass): - if issubclass(obj, nn.Module) and obj is not nn.Module: - layer_map[name] = obj - return layer_map - - -def auto_pad(kernel_size: _size_2_t, dilation: _size_2_t = 1, **kwargs) -> Tuple[int, int]: - """ - Auto Padding for the convolution blocks - """ - if isinstance(kernel_size, int): - kernel_size = (kernel_size, kernel_size) - if isinstance(dilation, int): - dilation = (dilation, dilation) - - pad_h = ((kernel_size[0] - 1) * dilation[0]) // 2 - pad_w = ((kernel_size[1] - 1) * dilation[1]) // 2 - return (pad_h, pad_w) - - -def create_activation_function(activation: str) -> nn.Module: - """ - Retrieves an activation function from the PyTorch nn module based on its name, case-insensitively. - """ - if not activation or activation.lower() in ["false", "none"]: - return nn.Identity() - - activation_map = { - name.lower(): obj - for name, obj in nn.modules.activation.__dict__.items() - if isinstance(obj, type) and issubclass(obj, nn.Module) - } - if activation.lower() in activation_map: - return activation_map[activation.lower()](inplace=True) - else: - raise ValueError(f"Activation function '{activation}' is not found in torch.nn") - - -def round_up(x: Union[int, Tensor], div: int = 1) -> Union[int, Tensor]: - """ - Rounds up `x` to the bigger-nearest multiple of `div`. - """ - return x + (-x % div) - - -# ----------- Basic Class ----------- # -class Conv(nn.Module): - """A basic convolutional block that includes convolution, batch normalization, and activation.""" - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: _size_2_t, - *, - activation: Optional[str] = "SiLU", - **kwargs, - ): - super().__init__() - kwargs.setdefault("padding", auto_pad(kernel_size, **kwargs)) - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, bias=False, **kwargs) - self.bn = nn.BatchNorm2d(out_channels, eps=1e-3, momentum=3e-2) - self.act = create_activation_function(activation) - - def forward(self, x: Tensor) -> Tensor: - return self.act(self.bn(self.conv(x))) - - -class Pool(nn.Module): - """A generic pooling block supporting 'max' and 'avg' pooling methods.""" - - def __init__(self, method: str = "max", kernel_size: _size_2_t = 2, **kwargs): - super().__init__() - kwargs.setdefault("padding", auto_pad(kernel_size, **kwargs)) - pool_classes = {"max": nn.MaxPool2d, "avg": nn.AvgPool2d} - self.pool = pool_classes[method.lower()](kernel_size=kernel_size, **kwargs) - - def forward(self, x: Tensor) -> Tensor: - return self.pool(x) - - -class Concat(nn.Module): - def __init__(self, dim=1): - super(Concat, self).__init__() - self.dim = dim - - def forward(self, x): - return torch.cat(x, self.dim) - - -# ----------- Detection Class ----------- # -class Detection(nn.Module): - """A single YOLO Detection head for detection models""" - - def __init__(self, in_channels: Tuple[int], num_classes: int, *, reg_max: int = 16, use_group: bool = True): - super().__init__() - - groups = 4 if use_group else 1 - anchor_channels = 4 * reg_max - - first_neck, in_channels = in_channels - anchor_neck = max(round_up(first_neck // 4, groups), anchor_channels, reg_max) - class_neck = max(first_neck, min(num_classes * 2, 128)) - - self.anchor_conv = nn.Sequential( - Conv(in_channels, anchor_neck, 3), - Conv(anchor_neck, anchor_neck, 3, groups=groups), - nn.Conv2d(anchor_neck, anchor_channels, 1, groups=groups), - ) - self.class_conv = nn.Sequential( - Conv(in_channels, class_neck, 3), Conv(class_neck, class_neck, 3), nn.Conv2d(class_neck, num_classes, 1) - ) - - self.anc2vec = Anchor2Vec(reg_max=reg_max) - - self.anchor_conv[-1].bias.data.fill_(1.0) - self.class_conv[-1].bias.data.fill_(-10) # TODO: math.log(5 * 4 ** idx / 80 ** 3) - - def forward(self, x: Tensor) -> Tuple[Tensor]: - anchor_x = self.anchor_conv(x) - class_x = self.class_conv(x) - anchor_x, vector_x = self.anc2vec(anchor_x) - return class_x, anchor_x, vector_x - - -class MultiheadDetection(nn.Module): - """Mutlihead Detection module for Dual detect or Triple detect""" - - def __init__(self, in_channels: List[int], num_classes: int, **head_kwargs): - super().__init__() - DetectionHead = Detection - - if head_kwargs.pop("version", None) == "v7": - DetectionHead = IDetection - - self.heads = nn.ModuleList( - [DetectionHead((in_channels[0], in_channel), num_classes, **head_kwargs) for in_channel in in_channels] - ) - - def forward(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]: - return [head(x) for x, head in zip(x_list, self.heads)] - - -class Anchor2Vec(nn.Module): - def __init__(self, reg_max: int = 16) -> None: - super().__init__() - reverse_reg = torch.arange(reg_max, dtype=torch.float32).view(1, reg_max, 1, 1, 1) - self.anc2vec = nn.Conv3d(in_channels=reg_max, out_channels=1, kernel_size=1, bias=False) - self.anc2vec.weight = nn.Parameter(reverse_reg, requires_grad=False) - - def forward(self, anchor_x: Tensor) -> Tensor: - #anchor_x = rearrange(anchor_x, "B (P R) h w -> B R P h w", P=4) - B, PR, h, w = anchor_x.shape - P = 4 - R = PR // P - anchor_x = anchor_x.reshape(B, P, R, h, w).permute(0, 2, 1, 3, 4) - vector_x = anchor_x.softmax(dim=1) - vector_x = self.anc2vec(vector_x)[:, 0] - return anchor_x, vector_x - - -# ----------- Backbone Class ----------- # -class RepConv(nn.Module): - """A convolutional block that combines two convolution layers (kernel and point-wise).""" - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: _size_2_t = 3, - *, - activation: Optional[str] = "SiLU", - **kwargs, - ): - super().__init__() - self.act = create_activation_function(activation) - self.conv1 = Conv(in_channels, out_channels, kernel_size, activation=False, **kwargs) - self.conv2 = Conv(in_channels, out_channels, 1, activation=False, **kwargs) - - def forward(self, x: Tensor) -> Tensor: - return self.act(self.conv1(x) + self.conv2(x)) - - -class Bottleneck(nn.Module): - """A bottleneck block with optional residual connections.""" - - def __init__( - self, - in_channels: int, - out_channels: int, - *, - kernel_size: Tuple[int, int] = (3, 3), - residual: bool = True, - expand: float = 1.0, - **kwargs, - ): - super().__init__() - neck_channels = int(out_channels * expand) - self.conv1 = RepConv(in_channels, neck_channels, kernel_size[0], **kwargs) - self.conv2 = Conv(neck_channels, out_channels, kernel_size[1], **kwargs) - self.residual = residual - - if residual and (in_channels != out_channels): - self.residual = False - - def forward(self, x: torch.Tensor) -> torch.Tensor: - y = self.conv2(self.conv1(x)) - return x + y if self.residual else y - - -class RepNCSP(nn.Module): - """RepNCSP block with convolutions, split, and bottleneck processing.""" - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int = 1, - *, - csp_expand: float = 0.5, - repeat_num: int = 1, - neck_args: Dict[str, Any] = {}, - **kwargs, - ): - super().__init__() - - neck_channels = int(out_channels * csp_expand) - self.conv1 = Conv(in_channels, neck_channels, kernel_size, **kwargs) - self.conv2 = Conv(in_channels, neck_channels, kernel_size, **kwargs) - self.conv3 = Conv(2 * neck_channels, out_channels, kernel_size, **kwargs) - - self.bottleneck = nn.Sequential( - *[Bottleneck(neck_channels, neck_channels, **neck_args) for _ in range(repeat_num)] - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x1 = self.bottleneck(self.conv1(x)) - x2 = self.conv2(x) - return self.conv3(torch.cat((x1, x2), dim=1)) - - -class ELAN(nn.Module): - """ELAN structure.""" - - def __init__( - self, - in_channels: int, - out_channels: int, - part_channels: int, - *, - process_channels: Optional[int] = None, - **kwargs, - ): - super().__init__() - - if process_channels is None: - process_channels = part_channels // 2 - - self.conv1 = Conv(in_channels, part_channels, 1, **kwargs) - self.conv2 = Conv(part_channels // 2, process_channels, 3, padding=1, **kwargs) - self.conv3 = Conv(process_channels, process_channels, 3, padding=1, **kwargs) - self.conv4 = Conv(part_channels + 2 * process_channels, out_channels, 1, **kwargs) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x1, x2 = self.conv1(x).chunk(2, 1) - x3 = self.conv2(x2) - x4 = self.conv3(x3) - x5 = self.conv4(torch.cat([x1, x2, x3, x4], dim=1)) - return x5 - - -class RepNCSPELAN(nn.Module): - """RepNCSPELAN block combining RepNCSP blocks with ELAN structure.""" - - def __init__( - self, - in_channels: int, - out_channels: int, - part_channels: int, - *, - process_channels: Optional[int] = None, - csp_args: Dict[str, Any] = {}, - csp_neck_args: Dict[str, Any] = {}, - **kwargs, - ): - super().__init__() - - if process_channels is None: - process_channels = part_channels // 2 - - self.conv1 = Conv(in_channels, part_channels, 1, **kwargs) - self.conv2 = nn.Sequential( - RepNCSP(part_channels // 2, process_channels, neck_args=csp_neck_args, **csp_args), - Conv(process_channels, process_channels, 3, padding=1, **kwargs), - ) - self.conv3 = nn.Sequential( - RepNCSP(process_channels, process_channels, neck_args=csp_neck_args, **csp_args), - Conv(process_channels, process_channels, 3, padding=1, **kwargs), - ) - self.conv4 = Conv(part_channels + 2 * process_channels, out_channels, 1, **kwargs) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x1, x2 = self.conv1(x).chunk(2, 1) - x3 = self.conv2(x2) - x4 = self.conv3(x3) - x5 = self.conv4(torch.cat([x1, x2, x3, x4], dim=1)) - return x5 - - -class AConv(nn.Module): - """Downsampling module combining average and max pooling with convolution for feature reduction.""" - - def __init__(self, in_channels: int, out_channels: int): - super().__init__() - mid_layer = {"kernel_size": 3, "stride": 2} - self.avg_pool = Pool("avg", kernel_size=2, stride=1) - self.conv = Conv(in_channels, out_channels, **mid_layer) - - def forward(self, x: Tensor) -> Tensor: - x = self.avg_pool(x) - x = self.conv(x) - return x - -class ADown(nn.Module): - """Downsampling module combining average and max pooling with convolution for feature reduction.""" - - def __init__(self, in_channels: int, out_channels: int): - super().__init__() - half_in_channels = in_channels // 2 - half_out_channels = out_channels // 2 - mid_layer = {"kernel_size": 3, "stride": 2} - self.avg_pool = Pool("avg", kernel_size=2, stride=1) - self.conv1 = Conv(half_in_channels, half_out_channels, **mid_layer) - self.max_pool = Pool("max", **mid_layer) - self.conv2 = Conv(half_in_channels, half_out_channels, kernel_size=1) - - def forward(self, x: Tensor) -> Tensor: - x = self.avg_pool(x) - x1, x2 = x.chunk(2, dim=1) - x1 = self.conv1(x1) - x2 = self.max_pool(x2) - x2 = self.conv2(x2) - return torch.cat((x1, x2), dim=1) - - -class CBLinear(nn.Module): - """Convolutional block that outputs multiple feature maps split along the channel dimension.""" - - def __init__(self, in_channels: int, out_channels: List[int], kernel_size: int = 1, **kwargs): - super(CBLinear, self).__init__() - kwargs.setdefault("padding", auto_pad(kernel_size, **kwargs)) - self.conv = nn.Conv2d(in_channels, sum(out_channels), kernel_size, **kwargs) - self.out_channels = list(out_channels) - - def forward(self, x: Tensor) -> Tuple[Tensor]: - x = self.conv(x) - return x.split(self.out_channels, dim=1) - -class CBFuse(nn.Module): - def __init__(self, index: List[int], mode: str = "nearest"): - super().__init__() - self.idx = index - self.mode = mode - - def forward(self, x_list: List[torch.Tensor]) -> List[Tensor]: - target = x_list[-1] - target_size = target.shape[2:] # Batch, Channel, H, W - - res = [F.interpolate(x[pick_id], size=target_size, mode=self.mode) for pick_id, x in zip(self.idx, x_list)] - out = torch.stack(res + [target]).sum(dim=0) - return out - -class SPPELAN(nn.Module): - """SPPELAN module comprising multiple pooling and convolution layers.""" - - def __init__(self, in_channels: int, out_channels: int, neck_channels: Optional[int] = None): - super(SPPELAN, self).__init__() - neck_channels = neck_channels or out_channels // 2 - - self.conv1 = Conv(in_channels, neck_channels, kernel_size=1) - self.pools = nn.ModuleList([Pool("max", 5, stride=1) for _ in range(3)]) - self.conv5 = Conv(4 * neck_channels, out_channels, kernel_size=1) - - def forward(self, x: Tensor) -> Tensor: - features = [self.conv1(x)] - for pool in self.pools: - features.append(pool(features[-1])) - return self.conv5(torch.cat(features, dim=1)) - - -class UpSample(nn.Module): - def __init__(self, **kwargs): - super().__init__() - self.UpSample = nn.Upsample(**kwargs) - - def forward(self, x): - return self.UpSample(x) \ No newline at end of file diff --git a/PytorchWildlife/models/detection/yolo_mit/yolo/model/yolo.py b/PytorchWildlife/models/detection/yolo_mit/yolo/model/yolo.py deleted file mode 100644 index 3834dd43f..000000000 --- a/PytorchWildlife/models/detection/yolo_mit/yolo/model/yolo.py +++ /dev/null @@ -1,180 +0,0 @@ -from collections import OrderedDict -from pathlib import Path -from typing import Dict, List, Optional, Union - -import torch -from omegaconf import ListConfig, OmegaConf -from torch import nn - -from yolo.config import ModelConfig, YOLOLayer -from yolo.tools.dataset_preparation import prepare_weight -from yolo.model.module import get_layer_map - - -class YOLO(nn.Module): - """ - A preliminary YOLO (You Only Look Once) model class still under development. - - Parameters: - model_cfg: Configuration for the YOLO model. Expected to define the layers, - parameters, and any other relevant configuration details. - """ - - def __init__(self, model_cfg: ModelConfig, class_num: int = 80): - super(YOLO, self).__init__() - self.num_classes = class_num - self.layer_map = get_layer_map() # Get the map Dict[str: Module] - self.model: List[YOLOLayer] = nn.ModuleList() - self.reg_max = getattr(model_cfg.anchor, "reg_max", 16) - self.build_model(model_cfg.model) - - def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]): - self.layer_index = {} - output_dim, layer_idx = [3], 1 - for arch_name in model_arch: - for layer_idx, layer_spec in enumerate(model_arch[arch_name], start=layer_idx): - layer_type, layer_info = next(iter(layer_spec.items())) - layer_args = layer_info.get("args", {}) - - # Get input source - source = self.get_source_idx(layer_info.get("source", -1), layer_idx) - - # Find in channels - if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "AConv", "CBLinear"]): - layer_args["in_channels"] = output_dim[source] - if any(module in layer_type for module in ["Detection", "Segmentation", "Classification"]): - if isinstance(source, list): - layer_args["in_channels"] = [output_dim[idx] for idx in source] - else: - layer_args["in_channel"] = output_dim[source] - layer_args["num_classes"] = self.num_classes - layer_args["reg_max"] = self.reg_max - - # create layers - layer = self.create_layer(layer_type, source, layer_info, **layer_args) - self.model.append(layer) - - if layer.tags: - if layer.tags in self.layer_index: - raise ValueError(f"Duplicate tag '{layer_info['tags']}' found.") - self.layer_index[layer.tags] = layer_idx - - out_channels = self.get_out_channels(layer_type, layer_args, output_dim, source) - output_dim.append(out_channels) - setattr(layer, "out_c", out_channels) - layer_idx += 1 - - def forward(self, x, external: Optional[Dict] = None, shortcut: Optional[str] = None): - y = {0: x, **(external or {})} - output = dict() - for index, layer in enumerate(self.model, start=1): - if isinstance(layer.source, list): - model_input = [y[idx] for idx in layer.source] - else: - model_input = y[layer.source] - - external_input = {source_name: y[source_name] for source_name in layer.external} - - x = layer(model_input, **external_input) - y[-1] = x - if layer.usable: - y[index] = x - if layer.output: - output[layer.tags] = x - if layer.tags == shortcut: - return output - return output - - def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]): - if hasattr(layer_args, "out_channels"): - return layer_args["out_channels"] - if layer_type == "CBFuse": - return output_dim[source[-1]] - if isinstance(source, int): - return output_dim[source] - if isinstance(source, list): - return sum(output_dim[idx] for idx in source) - - def get_source_idx(self, source: Union[ListConfig, str, int], layer_idx: int): - if isinstance(source, ListConfig): - return [self.get_source_idx(index, layer_idx) for index in source] - if isinstance(source, str): - source = self.layer_index[source] - if source < -1: - source += layer_idx - if source > 0: # Using Previous Layer's Output - self.model[source - 1].usable = True - return source - - def create_layer(self, layer_type: str, source: Union[int, list], layer_info: Dict, **kwargs) -> YOLOLayer: - if layer_type in self.layer_map: - layer = self.layer_map[layer_type](**kwargs) - setattr(layer, "layer_type", layer_type) - setattr(layer, "source", source) - setattr(layer, "in_c", kwargs.get("in_channels", None)) - setattr(layer, "output", layer_info.get("output", False)) - setattr(layer, "tags", layer_info.get("tags", None)) - setattr(layer, "external", layer_info.get("external", [])) - setattr(layer, "usable", 0) - return layer - else: - raise ValueError(f"Unsupported layer type: {layer_type}") - - def save_load_weights(self, weights: Union[Path, OrderedDict]): - """ - Update the model's weights with the provided weights. - - args: - weights: A OrderedDict containing the new weights. - """ - if isinstance(weights, Path): - weights = torch.load(weights, map_location=torch.device("cpu"), weights_only=False) - if "state_dict" in weights: - weights = weights["state_dict"] - - # Drop the prefix 'model.model.' from the keys - if "model.model." in list(weights.keys())[0]: - weights = {k.replace("model.model.", ""): v for k, v in weights.items()} - - model_state_dict = self.model.state_dict() - - # TODO1: autoload old version weight - # TODO2: weight transform if num_class difference - - error_dict = {"Mismatch": set(), "Not Found": set()} - - for model_key, model_weight in model_state_dict.items(): - if model_key not in weights: - error_dict["Not Found"].add(tuple(model_key.split(".")[:-2])) - continue - if model_weight.shape != weights[model_key].shape: - error_dict["Mismatch"].add(tuple(model_key.split(".")[:-2])) - continue - model_state_dict[model_key] = weights[model_key] - - self.model.load_state_dict(model_state_dict) - - -def create_model(model_cfg: ModelConfig, weight_path: Union[bool, Path] = True, class_num: int = 80) -> YOLO: - """Constructs and returns a model from a Dictionary configuration file. - - Args: - config_file (dict): The configuration file of the model. - - Returns: - YOLO: An instance of the model defined by the given configuration. - """ - OmegaConf.set_struct(model_cfg, False) - model = YOLO(model_cfg, class_num) - if weight_path: - if weight_path == True: - weight_path = Path("weights") / f"{model_cfg.name}.pt" - elif isinstance(weight_path, str): - weight_path = Path(weight_path) - - if not weight_path.exists(): - prepare_weight(weight_path=weight_path) - if weight_path.exists(): - model.save_load_weights(weight_path) - - return model diff --git a/PytorchWildlife/models/detection/yolo_mit/yolo/tools/__init__.py b/PytorchWildlife/models/detection/yolo_mit/yolo/tools/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/PytorchWildlife/models/detection/yolo_mit/yolo/tools/data_augmentation.py b/PytorchWildlife/models/detection/yolo_mit/yolo/tools/data_augmentation.py deleted file mode 100644 index a8003e135..000000000 --- a/PytorchWildlife/models/detection/yolo_mit/yolo/tools/data_augmentation.py +++ /dev/null @@ -1,55 +0,0 @@ -from typing import List - -import numpy as np -import torch -from PIL import Image -from torchvision.transforms import functional as TF - - -class AugmentationComposer: - """Composes several transforms together.""" - - def __init__(self, transforms, image_size: int = [640, 640], base_size: int = 640): - self.transforms = transforms - # TODO: handle List of image_size [640, 640] - self.pad_resize = PadAndResize(image_size) - self.base_size = base_size - - for transform in self.transforms: - if hasattr(transform, "set_parent"): - transform.set_parent(self) - - def __call__(self, image, boxes=torch.zeros(0, 5)): - for transform in self.transforms: - image, boxes = transform(image, boxes) - image, boxes, rev_tensor = self.pad_resize(image, boxes) - image = TF.to_tensor(image) - return image, boxes, rev_tensor - - -class PadAndResize: - def __init__(self, image_size, background_color=(114, 114, 114)): - """Initialize the object with the target image size.""" - self.target_width, self.target_height = image_size - self.background_color = background_color - - def set_size(self, image_size: List[int]): - self.target_width, self.target_height = image_size - - def __call__(self, image: Image, boxes): - img_width, img_height = image.size - scale = min(self.target_width / img_width, self.target_height / img_height) - new_width, new_height = int(img_width * scale), int(img_height * scale) - - resized_image = image.resize((new_width, new_height), Image.Resampling.LANCZOS) - - pad_left = (self.target_width - new_width) // 2 - pad_top = (self.target_height - new_height) // 2 - padded_image = Image.new("RGB", (self.target_width, self.target_height), self.background_color) - padded_image.paste(resized_image, (pad_left, pad_top)) - - boxes[:, [1, 3]] = (boxes[:, [1, 3]] * new_width + pad_left) / self.target_width - boxes[:, [2, 4]] = (boxes[:, [2, 4]] * new_height + pad_top) / self.target_height - - transform_info = torch.tensor([scale, pad_left, pad_top, pad_left, pad_top]) - return padded_image, boxes, transform_info \ No newline at end of file diff --git a/PytorchWildlife/models/detection/yolo_mit/yolo/tools/data_loader.py b/PytorchWildlife/models/detection/yolo_mit/yolo/tools/data_loader.py deleted file mode 100644 index 67a91e5e4..000000000 --- a/PytorchWildlife/models/detection/yolo_mit/yolo/tools/data_loader.py +++ /dev/null @@ -1,231 +0,0 @@ -from pathlib import Path -from statistics import mean -from typing import Generator, List, Tuple, Union - -import numpy as np -import torch -from PIL import Image -from rich.progress import track -from torch import Tensor -from torch.utils.data import DataLoader, Dataset - -from yolo.config import DataConfig, DatasetConfig -from yolo.tools.data_augmentation import AugmentationComposer -from yolo.tools.dataset_preparation import prepare_dataset -from yolo.utils.dataset_utils import ( - create_image_metadata, - locate_label_paths, - scale_segmentation, - tensorlize, -) - - -class YoloDataset(Dataset): - def __init__(self, data_cfg: DataConfig, dataset_cfg: DatasetConfig, phase: str = "train2017"): - augment_cfg = data_cfg.data_augment - self.image_size = data_cfg.image_size - phase_name = dataset_cfg.get(phase, phase) - self.batch_size = data_cfg.batch_size - self.dynamic_shape = getattr(data_cfg, "dynamic_shape", False) - self.base_size = mean(self.image_size) - - transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()] - self.transform = AugmentationComposer(transforms, self.image_size, self.base_size) - self.transform.get_more_data = self.get_more_data - self.img_paths, self.bboxes, self.ratios = tensorlize(self.load_data(Path(dataset_cfg.path), phase_name)) - - def load_data(self, dataset_path: Path, phase_name: str): - """ - Loads data from a cache or generates a new cache for a specific dataset phase. - - Parameters: - dataset_path (Path): The root path to the dataset directory. - phase_name (str): The specific phase of the dataset (e.g., 'train', 'test') to load or generate data for. - - Returns: - dict: The loaded data from the cache for the specified phase. - """ - cache_path = dataset_path / f"{phase_name}.cache" - - if not cache_path.exists(): - data = self.filter_data(dataset_path, phase_name, self.dynamic_shape) - torch.save(data, cache_path) - else: - try: - data = torch.load(cache_path, weights_only=False) - except Exception as e: - raise e - - return data - - def filter_data(self, dataset_path: Path, phase_name: str, sort_image: bool = False) -> list: - """ - Filters and collects dataset information by pairing images with their corresponding labels. - - Parameters: - images_path (Path): Path to the directory containing image files. - labels_path (str): Path to the directory containing label files. - sort_image (bool): If True, sorts the dataset by the width-to-height ratio of images in descending order. - - Returns: - list: A list of tuples, each containing the path to an image file and its associated segmentation as a tensor. - """ - images_path = dataset_path / "images" / phase_name - labels_path, data_type = locate_label_paths(dataset_path, phase_name) - images_list = sorted([p.name for p in Path(images_path).iterdir() if p.is_file()]) - if data_type == "json": - annotations_index, image_info_dict = create_image_metadata(labels_path) - - data = [] - valid_inputs = 0 - for image_name in track(images_list, description="Filtering data"): - if not image_name.lower().endswith((".jpg", ".jpeg", ".png")): - continue - image_id = Path(image_name).stem - - if data_type == "json": - image_info = image_info_dict.get(image_id, None) - if image_info is None: - continue - annotations = annotations_index.get(image_info["id"], []) - image_seg_annotations = scale_segmentation(annotations, image_info) - elif data_type == "txt": - label_path = labels_path / f"{image_id}.txt" - if not label_path.is_file(): - continue - with open(label_path, "r") as file: - image_seg_annotations = [list(map(float, line.strip().split())) for line in file] - else: - image_seg_annotations = [] - - labels = self.load_valid_labels(image_id, image_seg_annotations) - - img_path = images_path / image_name - if sort_image: - with Image.open(img_path) as img: - width, height = img.size - else: - width, height = 0, 1 - data.append((img_path, labels, width / height)) - valid_inputs += 1 - - data = sorted(data, key=lambda x: x[2], reverse=True) - - return data - - def load_valid_labels(self, label_path: str, seg_data_one_img: list) -> Union[Tensor, None]: - """ - Loads valid COCO style segmentation data (values between [0, 1]) and converts it to bounding box coordinates - by finding the minimum and maximum x and y values. - - Parameters: - label_path (str): The filepath to the label file containing annotation data. - seg_data_one_img (list): The actual list of annotations (in segmentation format) - - Returns: - Tensor or None: A tensor of all valid bounding boxes if any are found; otherwise, None. - """ - bboxes = [] - for seg_data in seg_data_one_img: - cls = seg_data[0] - points = self.adapt_labels_list(seg_data[1:]) - points = np.array(points).reshape(-1, 2) - if (points >= 0).all() and (points <= 1).all(): - valid_points = points[(points >= 0) & (points <= 1)].reshape(-1, 2) - bbox = torch.tensor([cls, *valid_points.min(axis=0), *valid_points.max(axis=0)]) - bboxes.append(bbox) - - if bboxes: - return torch.stack(bboxes) - else: - return torch.zeros((0, 5)) - - def adapt_labels(self, bboxes: Tensor) -> Tensor: - """ - Adapt bounding box labels using vectorized operations. - - Args: - bboxes (Tensor): Tensor of bounding boxes in the format [class_id, width, height, x_center, y_center]. - - Returns: - Tensor: Tensor of adapted bounding boxes in the format [class_id, xmin, ymin, xmax, ymax]. - """ - class_ids = bboxes[:, 0] - widths = bboxes[:, 1] - heights = bboxes[:, 2] - x_centers = bboxes[:, 3] - y_centers = bboxes[:, 4] - - xmins = x_centers - widths / 2 - ymins = y_centers - heights / 2 - xmaxs = x_centers + widths / 2 - ymaxs = y_centers + heights / 2 - - adapted_bboxes = torch.stack([class_ids, xmins, ymins, xmaxs, ymaxs], dim=1) - - return adapted_bboxes - - def adapt_labels_list(self, points): - - x_center = points[0] - y_center = points[1] - width = points[2] - height = points[3] - - xmin = x_center - width / 2 - ymin = y_center - height / 2 - xmax = x_center + width / 2 - ymax = y_center + height / 2 - - return [xmin, ymin, xmax, ymax] - - def get_data(self, idx): - img_path, bboxes = self.img_paths[idx], self.bboxes[idx] - valid_mask = bboxes[:, 0] != -1 - with Image.open(img_path) as img: - img = img.convert("RGB") - return img, torch.from_numpy(bboxes[valid_mask]), img_path - - def get_more_data(self, num: int = 1): - indices = torch.randint(0, len(self), (num,)) - return [self.get_data(idx)[:2] for idx in indices] - - def _update_image_size(self, idx: int) -> None: - """Update image size based on dynamic shape and batch settings.""" - batch_start_idx = (idx // self.batch_size) * self.batch_size - image_ratio = self.ratios[batch_start_idx].clip(1 / 3, 3) - shift = ((self.base_size / 32 * (image_ratio - 1)) // (image_ratio + 1)) * 32 - - self.image_size = [int(self.base_size + shift), int(self.base_size - shift)] - self.transform.pad_resize.set_size(self.image_size) - - def __getitem__(self, idx) -> Tuple[Image.Image, Tensor, Tensor, List[str]]: - img, bboxes, img_path = self.get_data(idx) - - if self.dynamic_shape: - self._update_image_size(idx) - - img, bboxes, rev_tensor = self.transform(img, bboxes) - bboxes[:, [1, 3]] *= self.image_size[0] - bboxes[:, [2, 4]] *= self.image_size[1] - return img, bboxes, rev_tensor, img_path - - def __len__(self) -> int: - return len(self.bboxes) - - -def create_dataloader(data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: str = "train"): - if task == "inference": - return StreamDataLoader(data_cfg) - - if getattr(dataset_cfg, "auto_download", False): - prepare_dataset(dataset_cfg, task) - dataset = YoloDataset(data_cfg, dataset_cfg, task) - - return DataLoader( - dataset, - batch_size=data_cfg.batch_size, - num_workers=data_cfg.cpu_num, - pin_memory=data_cfg.pin_memory, - collate_fn=collate_fn, - ) \ No newline at end of file diff --git a/PytorchWildlife/models/detection/yolo_mit/yolo/tools/dataset_preparation.py b/PytorchWildlife/models/detection/yolo_mit/yolo/tools/dataset_preparation.py deleted file mode 100644 index bd41c1ec4..000000000 --- a/PytorchWildlife/models/detection/yolo_mit/yolo/tools/dataset_preparation.py +++ /dev/null @@ -1,51 +0,0 @@ -from pathlib import Path -from typing import Optional - -import requests - -from yolo.config import DatasetConfig - - -def prepare_dataset(dataset_cfg: DatasetConfig, task: str): - """ - Prepares dataset by downloading and unzipping if necessary. - """ - data_dir = Path(dataset_cfg.path) - for data_type, settings in dataset_cfg.auto_download.items(): - base_url = settings["base_url"] - for dataset_type, dataset_args in settings.items(): - if dataset_type != "annotations" and dataset_cfg.get(task, task) != dataset_type: - continue - file_name = f"{dataset_args.get('file_name', dataset_type)}.zip" - url = f"{base_url}{file_name}" - local_zip_path = data_dir / file_name - extract_to = data_dir / data_type if data_type != "annotations" else data_dir - final_place = extract_to / dataset_type - - final_place.mkdir(parents=True, exist_ok=True) - if check_files(final_place, dataset_args.get("file_num")): - raise RuntimeError(f"Error verifying the {dataset_type} dataset after extraction.") - continue - - if not local_zip_path.exists(): - download_file(url, local_zip_path) - unzip_file(local_zip_path, extract_to) - - if not check_files(final_place, dataset_args.get("file_num")): - raise RuntimeError(f"Error verifying the {dataset_type} dataset after extraction.") - - -def prepare_weight(download_link: Optional[str] = None, weight_path: Path = Path("v9-c.pt")): - weight_name = weight_path.name - if download_link is None: - download_link = "https://github.com/MultimediaTechLab/YOLO/releases/download/v1.0-alpha/" - weight_link = f"{download_link}{weight_name}" - - if not weight_path.parent.is_dir(): - weight_path.parent.mkdir(parents=True, exist_ok=True) - - try: - download_file(weight_link, weight_path) - except requests.exceptions.RequestException as e: - raise RuntimeError(f"Failed to download the weight file: {e}") - diff --git a/PytorchWildlife/models/detection/yolo_mit/yolo/utils/__init__.py b/PytorchWildlife/models/detection/yolo_mit/yolo/utils/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/PytorchWildlife/models/detection/yolo_mit/yolo/utils/bounding_box_utils.py b/PytorchWildlife/models/detection/yolo_mit/yolo/utils/bounding_box_utils.py deleted file mode 100644 index b6e1c13e6..000000000 --- a/PytorchWildlife/models/detection/yolo_mit/yolo/utils/bounding_box_utils.py +++ /dev/null @@ -1,197 +0,0 @@ -from typing import List, Optional, Union - -import torch -from torch import Tensor, tensor -from torchvision.ops import batched_nms - -from yolo.config import AnchorConfig, NMSConfig -from yolo.model.yolo import YOLO - - -def generate_anchors(image_size: List[int], strides: List[int]): - """ - Find the anchor maps for each w, h. - - Args: - image_size List: the image size of augmented image size - strides List[8, 16, 32, ...]: the stride size for each predicted layer - - Returns: - all_anchors [HW x 2]: - all_scalers [HW]: The index of the best targets for each anchors - """ - W, H = image_size - anchors = [] - scaler = [] - for stride in strides: - anchor_num = W // stride * H // stride - scaler.append(torch.full((anchor_num,), stride)) - shift = stride // 2 - h = torch.arange(0, H, stride) + shift - w = torch.arange(0, W, stride) + shift - if torch.__version__ >= "2.3.0": - anchor_h, anchor_w = torch.meshgrid(h, w, indexing="ij") - else: - anchor_h, anchor_w = torch.meshgrid(h, w) - anchor = torch.stack([anchor_w.flatten(), anchor_h.flatten()], dim=-1) - anchors.append(anchor) - all_anchors = torch.cat(anchors, dim=0) - all_scalers = torch.cat(scaler, dim=0) - return all_anchors, all_scalers - - -class Vec2Box: - def __init__(self, model: YOLO, anchor_cfg: AnchorConfig, image_size, device): - self.device = device - - if hasattr(anchor_cfg, "strides"): - self.strides = anchor_cfg.strides - else: - self.strides = self.create_auto_anchor(model, image_size) - - anchor_grid, scaler = generate_anchors(image_size, self.strides) - self.image_size = image_size - self.anchor_grid, self.scaler = anchor_grid.to(device), scaler.to(device) - - def create_auto_anchor(self, model: YOLO, image_size): - W, H = image_size - dummy_input = torch.zeros(1, 3, H, W).to(self.device) - dummy_output = model(dummy_input) - strides = [] - for predict_head in dummy_output["Main"]: - _, _, *anchor_num = predict_head[2].shape - strides.append(W // anchor_num[1]) - return strides - - def update(self, image_size): - """ - image_size: W, H - """ - if self.image_size == image_size: - return - anchor_grid, scaler = generate_anchors(image_size, self.strides) - self.image_size = image_size - self.anchor_grid, self.scaler = anchor_grid.to(self.device), scaler.to(self.device) - - def __call__(self, predicts): - preds_cls, preds_anc, preds_box = [], [], [] - for layer_output in predicts: - pred_cls, pred_anc, pred_box = layer_output - B, C, h, w = pred_cls.shape - pred_cls = pred_cls.permute(0, 2, 3, 1).reshape(B, h * w, C) - preds_cls.append(pred_cls) - - B, A, R, h, w = pred_anc.shape - pred_anc = pred_anc.permute(0, 3, 4, 2, 1).reshape(B, h * w, R, A) - preds_anc.append(pred_anc) - - B, X, h, w = pred_box.shape - pred_box = pred_box.permute(0, 2, 3, 1).reshape(B, h * w, X) - preds_box.append(pred_box) - preds_cls = torch.concat(preds_cls, dim=1) - preds_anc = torch.concat(preds_anc, dim=1) - preds_box = torch.concat(preds_box, dim=1) - - pred_LTRB = preds_box * self.scaler.view(1, -1, 1) - lt, rb = pred_LTRB.chunk(2, dim=-1) - preds_box = torch.cat([self.anchor_grid - lt, self.anchor_grid + rb], dim=-1) - return preds_cls, preds_anc, preds_box - - -class Anc2Box: - def __init__(self, model: YOLO, anchor_cfg: AnchorConfig, image_size, device): - self.device = device - - if hasattr(anchor_cfg, "strides"): - self.strides = anchor_cfg.strides - else: - self.strides = self.create_auto_anchor(model, image_size) - - self.head_num = len(anchor_cfg.anchor) - self.anchor_grids = self.generate_anchors(image_size) - self.anchor_scale = tensor(anchor_cfg.anchor, device=device).view(self.head_num, 1, -1, 1, 1, 2) - self.anchor_num = self.anchor_scale.size(2) - self.class_num = model.num_classes - - def create_auto_anchor(self, model: YOLO, image_size): - W, H = image_size - dummy_input = torch.zeros(1, 3, H, W).to(self.device) - dummy_output = model(dummy_input) - strides = [] - for predict_head in dummy_output["Main"]: - _, _, *anchor_num = predict_head.shape - strides.append(W // anchor_num[1]) - return strides - - def generate_anchors(self, image_size: List[int]): - anchor_grids = [] - for stride in self.strides: - W, H = image_size[0] // stride, image_size[1] // stride - anchor_h, anchor_w = torch.meshgrid([torch.arange(H), torch.arange(W)], indexing="ij") - anchor_grid = torch.stack((anchor_w, anchor_h), 2).view((1, 1, H, W, 2)).float().to(self.device) - anchor_grids.append(anchor_grid) - return anchor_grids - - def update(self, image_size): - self.anchor_grids = self.generate_anchors(image_size) - - def __call__(self, predicts: List[Tensor]): - preds_box, preds_cls, preds_cnf = [], [], [] - for layer_idx, predict in enumerate(predicts): - B, LC, h, w = predict.shape - L = self.anchor_num - C = LC // L - predict = predict.view(B, L, C, h, w).permute(0, 1, 3, 4, 2) # B, L, h, w, C - - pred_box, pred_cnf, pred_cls = predict.split((4, 1, self.class_num), dim=-1) - pred_box = pred_box.sigmoid() - - pred_box[..., 0:2] = ( - (pred_box[..., 0:2] * 2.0 - 0.5 + self.anchor_grids[layer_idx]) * self.strides[layer_idx] - ) - pred_box[..., 2:4] = ( - (pred_box[..., 2:4] * 2) ** 2 * self.anchor_scale[layer_idx] - ) - - B, L, h, w, A = pred_box.shape - preds_box.append(pred_box.reshape(B, L * h * w, A)) - - B, L, h, w, C = pred_cls.shape - preds_cls.append(pred_cls.reshape(B, L * h * w, C)) - - preds_cnf.append(pred_cnf.reshape(B, L * h * w, C)) - - preds_box = torch.concat(preds_box, dim=1) - preds_cls = torch.concat(preds_cls, dim=1) - preds_cnf = torch.concat(preds_cnf, dim=1) - - preds_box = transform_bbox(preds_box, "xycwh -> xyxy") - return preds_cls, None, preds_box, preds_cnf.sigmoid() - - -def create_converter(model_version: str = "v9-c", *args, **kwargs) -> Union[Anc2Box, Vec2Box]: - if "v7" in model_version: # check model if v7 - converter = Anc2Box(*args, **kwargs) - else: - converter = Vec2Box(*args, **kwargs) - return converter - - -def bbox_nms(cls_dist: Tensor, bbox: Tensor, nms_cfg: NMSConfig, confidence: Optional[Tensor] = None): - cls_dist = cls_dist.sigmoid() * (1 if confidence is None else confidence) - - batch_idx, valid_grid, valid_cls = torch.where(cls_dist > nms_cfg.min_confidence) - valid_con = cls_dist[batch_idx, valid_grid, valid_cls] - valid_box = bbox[batch_idx, valid_grid] - - nms_idx = batched_nms(valid_box, valid_con, batch_idx + valid_cls * bbox.size(0), nms_cfg.min_iou) - predicts_nms = [] - for idx in range(cls_dist.size(0)): - instance_idx = nms_idx[idx == batch_idx[nms_idx]] - - predict_nms = torch.cat( - [valid_cls[instance_idx][:, None], valid_box[instance_idx], valid_con[instance_idx][:, None]], dim=-1 - ) - - predicts_nms.append(predict_nms[: nms_cfg.max_bbox]) - return predicts_nms diff --git a/PytorchWildlife/models/detection/yolo_mit/yolo/utils/dataset_utils.py b/PytorchWildlife/models/detection/yolo_mit/yolo/utils/dataset_utils.py deleted file mode 100644 index da1989d59..000000000 --- a/PytorchWildlife/models/detection/yolo_mit/yolo/utils/dataset_utils.py +++ /dev/null @@ -1,116 +0,0 @@ -import json -import os -from itertools import chain -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple -import numpy as np -import torch - -def discretize_categories(categories: List[Dict[str, int]]) -> Dict[int, int]: - """ - Maps each unique 'id' in the list of category dictionaries to a sequential integer index. - Indices are assigned based on the sorted 'id' values. - """ - sorted_categories = sorted(categories, key=lambda category: category["id"]) - return {category["id"]: index for index, category in enumerate(sorted_categories)} - - -def locate_label_paths(dataset_path: Path, phase_name: Path) -> Tuple[Path, Path]: - """ - Find the path to label files for a specified dataset and phase(e.g. training). - - Args: - dataset_path (Path): The path to the root directory of the dataset. - phase_name (Path): The name of the phase for which labels are being searched (e.g., "train", "val", "test"). - - Returns: - Tuple[Path, Path]: A tuple containing the path to the labels file and the file format ("json" or "txt"). - """ - json_labels_path = dataset_path / "annotations" / f"instances_{phase_name}.json" - - txt_labels_path = dataset_path / "labels" / phase_name - - if json_labels_path.is_file(): - return json_labels_path, "json" - - elif txt_labels_path.is_dir(): - txt_files = [f for f in os.listdir(txt_labels_path) if f.endswith(".txt")] - if txt_files: - return txt_labels_path, "txt" - - return [], None - - -def create_image_metadata(labels_path: str) -> Tuple[Dict[str, List], Dict[str, Dict]]: - """ - Create a dictionary containing image information and annotations indexed by image ID. - - Args: - labels_path (str): The path to the annotation json file. - - Returns: - - annotations_index: A dictionary where keys are image IDs and values are lists of annotations. - - image_info_dict: A dictionary where keys are image file names without extension and values are image information dictionaries. - """ - with open(labels_path, "r") as file: - labels_data = json.load(file) - id_to_idx = discretize_categories(labels_data.get("categories", [])) if "categories" in labels_data else None - annotations_index = organize_annotations_by_image(labels_data, id_to_idx) # check lookup is a good name? - image_info_dict = {Path(img["file_name"]).stem: img for img in labels_data["images"]} - return annotations_index, image_info_dict - - -def scale_segmentation( - annotations: List[Dict[str, Any]], image_dimensions: Dict[str, int] -) -> Optional[List[List[float]]]: - """ - Scale the segmentation data based on image dimensions and return a list of scaled segmentation data. - - Args: - annotations (List[Dict[str, Any]]): A list of annotation dictionaries. - image_dimensions (Dict[str, int]): A dictionary containing image dimensions (height and width). - - Returns: - Optional[List[List[float]]]: A list of scaled segmentation data, where each sublist contains category_id followed by scaled (x, y) coordinates. - """ - if annotations is None: - return None - - seg_array_with_cat = [] - h, w = image_dimensions["height"], image_dimensions["width"] - for anno in annotations: - category_id = anno["category_id"] - if "segmentation" in anno: - seg_list = [item for sublist in anno["segmentation"] for item in sublist] - elif "bbox" in anno: - x, y, width, height = anno["bbox"] - seg_list = [x, y, x + width, y, x + width, y + height, x, y + height] - - scaled_seg_data = ( - np.array(seg_list).reshape(-1, 2) / [w, h] - ).tolist() # make the list group in x, y pairs and scaled with image width, height - scaled_flat_seg_data = [category_id] + list(chain(*scaled_seg_data)) # flatten the scaled_seg_data list - seg_array_with_cat.append(scaled_flat_seg_data) - - return seg_array_with_cat - - -def tensorlize(data): - try: - img_paths, bboxes, img_ratios = zip(*data) - except ValueError as e: - # logger.error( - # ":rotating_light: This may be caused by using old cache or another version of YOLO's cache.\n" - # ":rotating_light: Please clean the cache and try running again." - # ) - raise e - max_box = max(bbox.size(0) for bbox in bboxes) - padded_bbox_list = [] - for bbox in bboxes: - padding = torch.full((max_box, 5), -1, dtype=torch.float32) - padding[: bbox.size(0)] = bbox - padded_bbox_list.append(padding) - bboxes = np.stack(padded_bbox_list) - img_paths = np.array(img_paths) - img_ratios = np.array(img_ratios) - return img_paths, bboxes, img_ratios diff --git a/PytorchWildlife/models/detection/yolo_mit/yolo/utils/model_utils.py b/PytorchWildlife/models/detection/yolo_mit/yolo/utils/model_utils.py deleted file mode 100644 index e03942e8e..000000000 --- a/PytorchWildlife/models/detection/yolo_mit/yolo/utils/model_utils.py +++ /dev/null @@ -1,29 +0,0 @@ -from typing import List, Optional, Union -from torch import Tensor -from yolo.config import NMSConfig -from yolo.model.yolo import YOLO -from yolo.utils.bounding_box_utils import Anc2Box, Vec2Box, bbox_nms - - -class PostProcess: - """ - TODO: function document - scale back the prediction and do nms for pred_bbox - """ - - def __init__(self, converter: Union[Vec2Box, Anc2Box], nms_cfg: NMSConfig) -> None: - self.converter = converter - self.nms = nms_cfg - - def __call__( - self, predict, rev_tensor: Optional[Tensor] = None, image_size: Optional[List[int]] = None - ) -> List[Tensor]: - if image_size is not None: - self.converter.update(image_size) - prediction = self.converter(predict["Main"]) - pred_class, _, pred_bbox = prediction[:3] - pred_conf = prediction[3] if len(prediction) == 4 else None - if rev_tensor is not None: - pred_bbox = (pred_bbox - rev_tensor[:, None, 1:]) / rev_tensor[:, 0:1, None] - pred_bbox = bbox_nms(pred_class, pred_bbox, self.nms, pred_conf) #pred_box: [cls, x1, y1, x2, y2, conf] - return pred_bbox \ No newline at end of file diff --git a/PytorchWildlife/models/detection/yolo_mit/yolo_mit_base.py b/PytorchWildlife/models/detection/yolo_mit/yolo_mit_base.py deleted file mode 100644 index 4b02a7dff..000000000 --- a/PytorchWildlife/models/detection/yolo_mit/yolo_mit_base.py +++ /dev/null @@ -1,220 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -""" Yolo mit base detector class. """ - -# Importing basic libraries - -import os -import supervision as sv -import numpy as np -from PIL import Image -import wget -import torch - -from ..base_detector import BaseDetector -from ....data import datasets as pw_data - -import sys -from pathlib import Path - -from lightning import Trainer -import yaml -from omegaconf import OmegaConf - -project_root = Path(__file__).resolve().parent -sys.path.append(str(project_root)) - -from yolo import create_model, create_converter, PostProcess, AugmentationComposer - -class YOLOMITBase(BaseDetector): - """ - Base detector class for YOLO MIT framework. This class provides utility methods for - loading the model, generating results, and performing single and batch image detections. - """ - def __init__(self, weights=None, device="cpu", url=None): - """ - Initialize the YOLO MIT detector. - - Args: - weights (str, optional): - Path to the model weights. Defaults to None. - device (str, optional): - Device for model inference. Defaults to "cpu". - url (str, optional): - URL to fetch the model weights. Defaults to None. - """ - - self.cfg = self._load_cfg() - self.transform = AugmentationComposer([], self.cfg.image_size, self.cfg.image_size[0]) - self.weights = weights - self.device = device - self.url = url - super(YOLOMITBase, self).__init__(weights=self.weights, device=self.device, url=self.url) - - def _load_cfg(self): - if self.MODEL_NAME == "MDV6-mit-yolov9-c.ckpt": - if not os.path.exists(os.path.join(torch.hub.get_dir(), "checkpoints", "config_v9s.yaml")): - os.makedirs(os.path.join(torch.hub.get_dir(), "checkpoints"), exist_ok=True) - url = "https://zenodo.org/records/15178680/files/config_v9s.yaml?download=1" - config_path = wget.download(url, out=os.path.join(torch.hub.get_dir(), "checkpoints")) - else: - config_path = os.path.join(torch.hub.get_dir(), "checkpoints", "config_v9s.yaml") - elif self.MODEL_NAME == "MDV6-mit-yolov9-e.ckpt": - if not os.path.exists(os.path.join(torch.hub.get_dir(), "checkpoints", "config_v9c.yaml")): - os.makedirs(os.path.join(torch.hub.get_dir(), "checkpoints"), exist_ok=True) - url = "https://zenodo.org/records/15178680/files/config_v9c.yaml?download=1" - config_path = wget.download(url, out=os.path.join(torch.hub.get_dir(), "checkpoints")) - else: - config_path = os.path.join(torch.hub.get_dir(), "checkpoints", "config_v9c.yaml") - - with open(config_path, 'r') as f: - cfg_dict = yaml.safe_load(f) - - return OmegaConf.create(cfg_dict) - - def _load_model(self, weights=None, device="cpu", url=None): - """ - Load the YOLO MIT model weights. - - Args: - weights (str, optional): - Path to the model weights. Defaults to None. - device (str, optional): - Device for model inference. Defaults to "cpu". - url (str, optional): - URL to fetch the model weights. Defaults to None. - Raises: - Exception: If weights are not provided. - """ - if url: - if not os.path.exists(os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME)): - os.makedirs(os.path.join(torch.hub.get_dir(), "checkpoints"), exist_ok=True) - weights = wget.download(url, out=os.path.join(torch.hub.get_dir(), "checkpoints")) - else: - weights = os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME) - else: - raise Exception("Need weights for inference.") - - self.cfg.image_size = [self.IMAGE_SIZE, self.IMAGE_SIZE] - self.model = create_model(self.cfg.model, weight_path=weights, class_num=3).to(self.device) - self.converter = create_converter(self.cfg.model.name, self.model, self.cfg.model.anchor, self.cfg.image_size, self.device) - self.post_proccess = PostProcess(self.converter, self.cfg.task.nms) - - def results_generation(self, preds, img_id, id_strip=None): - """ - Generate results for detection based on model predictions. - - Args: - preds (List[torch.Tensor]): - Model predictions. - img_id (str): - Image identifier. - id_strip (str, optional): - Strip specific characters from img_id. Defaults to None. - - Returns: - dict: Dictionary containing image ID, detections, and labels. - """ - #preds: [cls, x1, y1, x2, y2, conf] - class_id = preds[0][:,0].cpu().numpy().astype(int) - xyxy = preds[0][:,1:5].cpu().numpy() - confidence = preds[0][:,5].cpu().numpy() - - results = {"img_id": str(img_id).strip(id_strip)} - results["detections"] = sv.Detections( - xyxy=xyxy, - confidence=confidence, - class_id=class_id - ) - - results["labels"] = [ - f"{self.CLASS_NAMES[class_id]} {confidence:0.2f}" - for _, _, confidence, class_id, _, _ in results["detections"] - ] - results - return results - - - def single_image_detection(self, img, img_path=None, det_conf_thres=0.2, id_strip=None): - """ - Perform detection on a single image. - - Args: - img (str or ndarray): - Image path or ndarray of images. - img_path (str, optional): - Image path or identifier. - det_conf_thres (float, optional): - Confidence threshold for predictions. Defaults to 0.2. - id_strip (str, optional): - Characters to strip from img_id. Defaults to None. - - Returns: - dict: Detection results. - """ - self.cfg.task.data.source = img_path - self.cfg.task.nms.min_confidence = det_conf_thres - self._load_model(weights=self.weights, device=self.device, url=self.url) - - if type(img) == str: - if img_path is None: - img_path = img - im_pil = Image.open(img_path).convert('RGB') - else: - im_pil = Image.fromarray(img) - - image, bbox, rev_tensor = self.transform(im_pil) - image = image.to(self.device)[None] - rev_tensor = rev_tensor.to(self.device)[None] - - with torch.no_grad(): - predict = self.model(image) - det_results = self.post_proccess(predict, rev_tensor) #pred_box: [cls, x1, y1, x2, y2, conf] - - return self.results_generation(det_results, img_path, id_strip) - - def batch_image_detection(self, data_path, batch_size=16, det_conf_thres=0.2, id_strip=None): - """ - Perform detection on a batch of images. - - Args: - data_path (str): - Path containing all images for inference. - batch_size (int, optional): - Batch size for inference. Defaults to 16. - det_conf_thres (float, optional): - Confidence threshold for predictions. Defaults to 0.2. - id_strip (str, optional): - Characters to strip from img_id. Defaults to None. - extension (str, optional): - Image extension to search for. Defaults to "JPG" - - Returns: - list: List of detection results for all images. - """ - self.cfg.task.data.source = data_path - self.cfg.task.nms.min_confidence = det_conf_thres - self._load_model(weights=self.weights, device=self.device, url=self.url) - - dataset = pw_data.DetectionImageFolder( - data_path, - transform=self.transform, - ) - - results = [] - for i in range(len(dataset.images)): - res = self.single_image_detection(dataset.images[i], img_path=dataset.images[i], det_conf_thres=det_conf_thres, id_strip=id_strip) - # Upload the original image and get the size in the format (height, width) - img = Image.open(dataset.images[i]) - img = np.asarray(img) - size = img.shape[:2] - # Normalize the coordinates for timelapse compatibility - normalized_coords = [[x1 / size[1], y1 / size[0], x2 / size[1], y2 / size[0]] for x1, y1, x2, y2 in res["detections"].xyxy] - res["normalized_coords"] = normalized_coords - results.append(res) - - return results - - - diff --git a/PytorchWildlife/utils/__init__.py b/PytorchWildlife/utils/__init__.py deleted file mode 100644 index bca86797f..000000000 --- a/PytorchWildlife/utils/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .misc import * -from .post_process import * \ No newline at end of file diff --git a/PytorchWildlife/utils/misc.py b/PytorchWildlife/utils/misc.py deleted file mode 100644 index f516a66c2..000000000 --- a/PytorchWildlife/utils/misc.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -""" Miscellaneous functions.""" - -import numpy as np -from tqdm import tqdm -import cv2 -from typing import Callable -from supervision import VideoInfo, VideoSink, get_video_frames_generator - -__all__ = [ - "process_video" -] - - -def process_video( - source_path: str, - target_path: str, - callback: Callable[[np.ndarray, int], np.ndarray], - target_fps: int = 1, - codec: str = "mp4v" -) -> None: - """ - Process a video frame-by-frame, applying a callback function to each frame and saving the results - to a new video. This version includes a progress bar and allows codec selection. - - Args: - source_path (str): - Path to the source video file. - target_path (str): - Path to save the processed video. - callback (Callable[[np.ndarray, int], np.ndarray]): - A function that takes a video frame and its index as input and returns the processed frame. - codec (str, optional): - Codec used to encode the processed video. Default is "avc1". - """ - source_video_info = VideoInfo.from_video_path(video_path=source_path) - - if source_video_info.fps > target_fps: - stride = int(source_video_info.fps / target_fps) - source_video_info.fps = target_fps - else: - stride = 1 - - with VideoSink(target_path=target_path, video_info=source_video_info, codec=codec) as sink: - with tqdm(total=int(source_video_info.total_frames / stride)) as pbar: - for index, frame in enumerate( - get_video_frames_generator(source_path=source_path, stride=stride) - ): - result_frame = callback(frame, index) - sink.write_frame(frame=cv2.cvtColor(result_frame, cv2.COLOR_RGB2BGR)) - pbar.update(1) diff --git a/PytorchWildlife/utils/post_process.py b/PytorchWildlife/utils/post_process.py deleted file mode 100644 index b3e2a333b..000000000 --- a/PytorchWildlife/utils/post_process.py +++ /dev/null @@ -1,532 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -""" Post-processing functions.""" - -import os -import numpy as np -import json -import cv2 -from PIL import Image -import supervision as sv -import shutil -from pathlib import Path - -__all__ = [ - "save_detection_images", - "save_detection_images_dots", - "save_crop_images", - "save_detection_json", - "save_detection_json_as_dots", - "save_detection_classification_json", - "save_detection_timelapse_json", - "save_detection_classification_timelapse_json", - "detection_folder_separation" -] - - -def save_detection_images(results, output_dir, input_dir=None, overwrite=False): - """ - Save detected images with bounding boxes and labels annotated. - - Args: - results (list or dict): - Detection results containing image ID, detections, and labels. - output_dir (str): - Directory to save the annotated images. - input_dir (str): - Directory containing the input images. Default to None. - overwrite (bool): - Whether overwriting existing image folders. Default to False. - """ - box_annotator = sv.BoxAnnotator(thickness=4) - lab_annotator = sv.LabelAnnotator(text_color=sv.Color.BLACK, text_thickness=4, text_scale=2) - os.makedirs(output_dir, exist_ok=True) - - with sv.ImageSink(target_dir_path=output_dir, overwrite=overwrite) as sink: - if isinstance(results, list): - for entry in results: - annotated_img = lab_annotator.annotate( - scene=box_annotator.annotate( - scene=np.array(Image.open(entry["img_id"]).convert("RGB")), - detections=entry["detections"], - ), - detections=entry["detections"], - labels=entry["labels"], - ) - if input_dir: - relative_path = os.path.relpath(entry["img_id"], input_dir) - save_path = os.path.join(output_dir, relative_path) - os.makedirs(os.path.dirname(save_path), exist_ok=True) - image_name = relative_path - else: - image_name = os.path.basename(entry["img_id"]) - sink.save_image( - image=cv2.cvtColor(annotated_img, cv2.COLOR_RGB2BGR), image_name=image_name - ) - else: - annotated_img = lab_annotator.annotate( - scene=box_annotator.annotate( - scene=np.array(Image.open(results["img_id"]).convert("RGB")), - detections=results["detections"], - ), - detections=results["detections"], - labels=results["labels"], - ) - - sink.save_image( - image=cv2.cvtColor(annotated_img, cv2.COLOR_RGB2BGR), image_name=os.path.basename(results["img_id"]) - ) - -def save_detection_images_dots(results, output_dir, input_dir=None, overwrite=False, show_labels=True): - """ - Save detected images with dot annotations and optional labels. - - Args: - results (list or dict): - Detection results containing image ID, detections, and labels. - output_dir (str): - Directory to save the annotated images. - input_dir (str): - Directory containing the input images. Default to None. - overwrite (bool): - Whether overwriting existing image folders. Default to False. - show_labels (bool): - Whether to show text labels next to dots. Default to True. - """ - dot_annotator = sv.DotAnnotator(radius=6) - lab_annotator = sv.LabelAnnotator(text_position=sv.Position.BOTTOM_RIGHT) if show_labels else None - os.makedirs(output_dir, exist_ok=True) - - with sv.ImageSink(target_dir_path=output_dir, overwrite=overwrite) as sink: - if isinstance(results, list): - for i, entry in enumerate(results): - if "img_id" in entry: - scene = np.array(Image.open(entry["img_id"]).convert("RGB")) - image_name = os.path.basename(entry["img_id"]) - else: - scene = entry["img"] - image_name = f"output_image_{i}.jpg" # default name if no image id is provided - - annotated_img = dot_annotator.annotate( - scene=scene, - detections=entry["detections"], - ) - if lab_annotator: - annotated_img = lab_annotator.annotate( - scene=annotated_img, - detections=entry["detections"], - labels=entry["labels"], - ) - if input_dir: - relative_path = os.path.relpath(entry["img_id"], input_dir) - save_path = os.path.join(output_dir, relative_path) - os.makedirs(os.path.dirname(save_path), exist_ok=True) - image_name = relative_path - sink.save_image( - image=cv2.cvtColor(annotated_img, cv2.COLOR_RGB2BGR), image_name=image_name - ) - else: - if "img_id" in results: - scene = np.array(Image.open(results["img_id"]).convert("RGB")) - image_name = os.path.basename(results["img_id"]) - else: - scene = results["img"] - image_name = "output_image.jpg" # default name if no image id is provided - - annotated_img = dot_annotator.annotate( - scene=scene, - detections=results["detections"], - ) - if lab_annotator: - annotated_img = lab_annotator.annotate( - scene=annotated_img, - detections=results["detections"], - labels=results["labels"], - ) - sink.save_image( - image=cv2.cvtColor(annotated_img, cv2.COLOR_RGB2BGR), image_name=image_name - ) - - -# !!! Output paths need to be optimized !!! -def save_crop_images(results, output_dir, input_dir=None, overwrite=False): - """ - Save cropped images based on the detection bounding boxes. - - Args: - results (list): - Detection results containing image ID and detections. - output_dir (str): - Directory to save the cropped images. - input_dir (str): - Directory containing the input images. Default to None. - overwrite (bool): - Whether overwriting existing image folders. Default to False. - """ - - os.makedirs(output_dir, exist_ok=True) - - with sv.ImageSink(target_dir_path=output_dir, overwrite=overwrite) as sink: - if isinstance(results, list): - for entry in results: - for i, (xyxy, cat) in enumerate(zip(entry["detections"].xyxy, entry["detections"].class_id)): - cropped_img = sv.crop_image( - image=np.array(Image.open(entry["img_id"]).convert("RGB")), xyxy=xyxy - ) - if input_dir: - relative_path = os.path.relpath(entry["img_id"], input_dir) - save_path = os.path.join(output_dir, relative_path) - os.makedirs(os.path.dirname(save_path), exist_ok=True) - image_name = os.path.join(os.path.dirname(relative_path), "{}_{}_{}".format(int(cat), i, os.path.basename(entry["img_id"]))) - else: - image_name = "{}_{}_{}".format(int(cat), i, os.path.basename(entry["img_id"])) - sink.save_image( - image=cv2.cvtColor(cropped_img, cv2.COLOR_RGB2BGR), - image_name=image_name, - ) - else: - for i, (xyxy, cat) in enumerate(zip(results["detections"].xyxy, results["detections"].class_id)): - cropped_img = sv.crop_image( - image=np.array(Image.open(results["img_id"]).convert("RGB")), xyxy=xyxy - ) - sink.save_image( - image=cv2.cvtColor(cropped_img, cv2.COLOR_RGB2BGR), - image_name="{}_{}_{}".format(int(cat), i, os.path.basename(results["img_id"]), - )) - -def save_detection_json(det_results, output_dir, categories=None, exclude_category_ids=[], exclude_file_path=None): - """ - Save detection results to a JSON file. - - Args: - det_results (list): - Detection results containing image ID, bounding boxes, category, and confidence. - output_dir (str): - Path to save the output JSON file. - categories (list, optional): - List of categories for detected objects. Defaults to None. - exclude_category_ids (list, optional): - List of category IDs to exclude from the output. Defaults to []. Category IDs can be found in the definition of each models. - exclude_file_path (str, optional): - We can exclude the some path sections from the image ID. Defaults to None. - """ - json_results = {"annotations": [], "categories": categories} - - for det_r in det_results: - - # Category filtering - img_id = det_r["img_id"] - category = det_r["detections"].class_id - - bbox = det_r["detections"].xyxy.astype(int)[~np.isin(category, exclude_category_ids)] - confidence = det_r["detections"].confidence[~np.isin(category, exclude_category_ids)] - category = category[~np.isin(category, exclude_category_ids)] - - # if not all([x in exclude_category_ids for x in category]): - json_results["annotations"].append( - { - "img_id": img_id.replace(exclude_file_path + os.sep, '') if exclude_file_path else img_id, - "bbox": bbox.tolist(), - "category": category.tolist(), - "confidence": confidence.tolist(), - } - ) - - with open(output_dir, "w") as f: - json.dump(json_results, f, indent=4) - -def save_detection_json_as_dots(det_results, output_dir, categories=None, exclude_category_ids=[], exclude_file_path=None): - """ - Save detection results to a JSON file in dots format. - - Args: - det_results (list): - Detection results containing image ID, bounding boxes, category, and confidence. - output_dir (str): - Path to save the output JSON file. - categories (list, optional): - List of categories for detected objects. Defaults to None. - exclude_category_ids (list, optional): - List of category IDs to exclude from the output. Defaults to []. Category IDs can be found in the definition of each models. - exclude_file_path (str, optional): - We can exclude the some path sections from the image ID. Defaults to None. - """ - json_results = {"annotations": [], "categories": categories} - - for det_r in det_results: - - # Category filtering - img_id = det_r["img_id"] - category = det_r["detections"].class_id - - bbox = det_r["detections"].xyxy.astype(int)[~np.isin(category, exclude_category_ids)] - dot = np.array([[np.mean(row[::2]), np.mean(row[1::2])] for row in bbox]) - confidence = det_r["detections"].confidence[~np.isin(category, exclude_category_ids)] - category = category[~np.isin(category, exclude_category_ids)] - - # if not all([x in exclude_category_ids for x in category]): - json_results["annotations"].append( - { - "img_id": img_id.replace(exclude_file_path + os.sep, '') if exclude_file_path else img_id, - "dot": dot.tolist(), - "category": category.tolist(), - "confidence": confidence.tolist(), - } - ) - - with open(output_dir, "w") as f: - json.dump(json_results, f, indent=4) - - -def save_detection_timelapse_json( - det_results, output_dir, categories=None, - exclude_category_ids=[], exclude_file_path=None, info={"detector": "megadetector_v5"} - ): - """ - Save detection results to a JSON file. - - Args: - det_results (list): - Detection results containing image ID, bounding boxes, category, and confidence. - output_dir (str): - Path to save the output JSON file. - categories (list, optional): - List of categories for detected objects. Defaults to None. - exclude_category_ids (list, optional): - List of category IDs to exclude from the output. Defaults to []. Category IDs can be found in the definition of each models. - exclude_file_path (str, optional): - Some time, Timelapse has path issues. We can exclude the some path sections from the image ID. Defaults to None. - info (dict, optional): - Default Timelapse info. Defaults to {"detector": "megadetector_v5}. - """ - - json_results = { - "info": info, - "detection_categories": categories, - "images": [] - } - - for det_r in det_results: - - img_id = det_r["img_id"] - category_id_list = det_r["detections"].class_id - - bbox_list = det_r["detections"].xyxy.astype(int)[~np.isin(category_id_list, exclude_category_ids)] - confidence_list = det_r["detections"].confidence[~np.isin(category_id_list, exclude_category_ids)] - normalized_bbox_list = np.array(det_r["normalized_coords"])[~np.isin(category_id_list, exclude_category_ids)] - category_id_list = category_id_list[~np.isin(category_id_list, exclude_category_ids)] - - # if not all([x in exclude_category_ids for x in category_id_list]): - image_annotations = { - "file": img_id.replace(exclude_file_path + os.sep, '') if exclude_file_path else img_id, - "max_detection_conf": float(max(confidence_list)) if len(confidence_list) > 0 else '', - "detections": [] - } - for i in range(len(bbox_list)): - normalized_bbox = [float(y) for y in normalized_bbox_list[i]] - detection = { - "category": str(category_id_list[i]), - "conf": float(confidence_list[i]), - "bbox": [normalized_bbox[0], normalized_bbox[1], normalized_bbox[2]-normalized_bbox[0], normalized_bbox[3]-normalized_bbox[1]], - "classifications": [] - } - - image_annotations["detections"].append(detection) - - json_results["images"].append(image_annotations) - - with open(output_dir, "w") as f: - json.dump(json_results, f, indent=4) - - -def save_detection_classification_json( - det_results, clf_results, output_path, det_categories=None, clf_categories=None, exclude_file_path=None -): - """ - Save classification results to a JSON file. - - Args: - det_results (list): - Detection results containing image ID, bounding boxes, detection category, and confidence. - clf_results (list): - classification results containing image ID, classification category, and confidence. - output_path (str): - Path to save the output JSON file. - det_categories (list, optional): - List of categories for detected objects. Defaults to None. - clf_categories (list, optional): - List of categories for classified objects. Defaults to None. - exclude_file_path (str, optional): - We can exclude the some path sections from the image ID. Defaults to None. - """ - - json_results = { - "annotations": [], - "det_categories": det_categories, - "clf_categories": clf_categories, - } - - with open(output_path, "w") as f: - counter = 0 - for det_r in det_results: - clf_categories = [] - clf_confidence = [] - for i in range(counter, len(clf_results)): - clf_r = clf_results[i] - if clf_r["img_id"] == det_r["img_id"]: - clf_categories.append(clf_r["class_id"]) - clf_confidence.append(clf_r["confidence"]) - counter += 1 - else: - break - - json_results["annotations"].append( - { - "img_id": str(det_r["img_id"]).replace(exclude_file_path + os.sep, '') if exclude_file_path else str(det_r["img_id"]), - "bbox": [ - [int(x) for x in sublist] - for sublist in det_r["detections"].xyxy.astype(int).tolist() - ], - "det_category": [ - int(x) for x in det_r["detections"].class_id.tolist() - ], - "det_confidence": [ - float(x) for x in det_r["detections"].confidence.tolist() - ], - "clf_category": [int(x) for x in clf_categories], - "clf_confidence": [float(x) for x in clf_confidence], - } - ) - json.dump(json_results, f, indent=4) - - -def save_detection_classification_timelapse_json( - det_results, clf_results, output_path, det_categories=None, clf_categories=None, - exclude_file_path=None, info={"detector": "megadetector_v5"} -): - """ - Save detection and classification results to a JSON file in the specified format. - - Args: - det_results (list): - Detection results containing image ID, bounding boxes, detection category, and confidence. - clf_results (list): - Classification results containing image ID, classification category, and confidence. - output_path (str): - Path to save the output JSON file. - det_categories (dict, optional): - Dictionary of categories for detected objects. Defaults to None. - clf_categories (dict, optional): - Dictionary of categories for classified objects. Defaults to None. - exclude_file_path (str, optional): - We can exclude the some path sections from the image ID. Defaults to None. - """ - json_results = { - "info": info, - "detection_categories": det_categories, - "classification_categories": clf_categories, - "images": [] - } - - for det_r in det_results: - image_annotations = { - "file": str(det_r["img_id"]).replace(exclude_file_path + os.sep, '') if exclude_file_path else str(det_r["img_id"]), - "max_detection_conf": float(max(det_r["detections"].confidence)) if len(det_r["detections"].confidence) > 0 else '', - "detections": [] - } - - for i in range(len(det_r["detections"])): - det = det_r["detections"][i] - normalized_bbox = [float(y) for y in det_r["normalized_coords"][i]] - detection = { - "category": str(det.class_id[0]), - "conf": float(det.confidence[0]), - "bbox": [normalized_bbox[0], normalized_bbox[1], normalized_bbox[2]-normalized_bbox[0], normalized_bbox[3]-normalized_bbox[1]], - "classifications": [] - } - - # Find classifications for this detection - for clf_r in clf_results: - if clf_r["img_id"] == det_r["img_id"]: - detection["classifications"].append([str(clf_r["class_id"]), float(clf_r["confidence"])]) - - image_annotations["detections"].append(detection) - - json_results["images"].append(image_annotations) - - with open(output_path, "w") as f: - json.dump(json_results, f, indent=4) - - -def detection_folder_separation(json_file, img_path, destination_path, confidence_threshold): - """ - Processes detection data from a JSON file to sort images into 'Animal' or 'No_animal' directories - based on detection categories and confidence levels. - - This function reads a JSON formatted file containing annotations of image detections. - Each image is checked for detections with category '0' and a confidence level above the specified - threshold. If such detections are found, the image is categorized under 'Animal'. Images without - any category '0' detections above the threshold, including those with no detections at all, are - categorized under 'No_animal'. - - Parameters: - - json_file (str): Path to the JSON file containing detection data. - - destination_path (str): Base path where 'Animal' and 'No_animal' folders will be created - and into which images will be sorted and copied. - - source_images_directory (str): Path to the directory containing the source images to be processed. - - confidence_threshold (float): The confidence threshold to consider a detection as valid. - - Effects: - - Reads from the specified `json_file`. - - Copies files from `source_images_directory` to either `destination_path/Animal` or - `destination_path/No_animal` based on the detection data and confidence level. - - Note: - - The function assumes that the JSON file structure includes keys 'annotations', each containing - 'img_id', 'bbox', 'category', and 'confidence'. It does not handle missing keys or unexpected - JSON structures and may raise an exception in such cases. - - Directories `Animal` and `No_animal` are created if they do not already exist. - - Images are copied, not moved; original images remain in the source directory. - """ - - # Load JSON data from the file - with open(json_file, 'r') as file: - data = json.load(file) - - # Ensure the destination directories exist - os.makedirs(destination_path, exist_ok=True) - animal_path = os.path.join(destination_path, "Animal") - no_animal_path = os.path.join(destination_path, "No_animal") - os.makedirs(animal_path, exist_ok=True) - os.makedirs(no_animal_path, exist_ok=True) - - # Process each image detection - i = 0 - for item in data['annotations']: - i+=1 - img_id = item['img_id'] - categories = item['category'] - confidences = item['confidence'] - - # Check if there is any category '0' with confidence above the threshold - file_targeted_for_animal = False - for category, confidence in zip(categories, confidences): - if category == 0 and confidence > confidence_threshold: - file_targeted_for_animal = True - break - - if file_targeted_for_animal: - target_folder = animal_path - else: - target_folder = no_animal_path - - # Construct the source and destination file paths - src_file_path = os.path.join(img_path, img_id) - dest_file_path = os.path.join(target_folder, os.path.dirname(img_id)) - os.makedirs(dest_file_path, exist_ok=True) - - # Copy the file to the appropriate directory - shutil.copy(src_file_path, dest_file_path) - - return "{} files were successfully separated".format(i) diff --git a/docs-requirements.txt b/docs-requirements.txt index 8bd08954a..54cc33c4a 100644 --- a/docs-requirements.txt +++ b/docs-requirements.txt @@ -5,5 +5,3 @@ mkdocs-material-extensions mkdocs-callouts mkdocs-git-revision-date-localized-plugin pymdown-extensions -mkdocstrings -mkdocstrings-python diff --git a/docs/base/data/datasets.md b/docs/base/data/datasets.md deleted file mode 100644 index a16a16c8f..000000000 --- a/docs/base/data/datasets.md +++ /dev/null @@ -1,3 +0,0 @@ -# Datasets Module - -::: PytorchWildlife.data.datasets \ No newline at end of file diff --git a/docs/base/data/transforms.md b/docs/base/data/transforms.md deleted file mode 100644 index b9b264689..000000000 --- a/docs/base/data/transforms.md +++ /dev/null @@ -1,3 +0,0 @@ -# Transforms Module - -::: PytorchWildlife.data.transforms \ No newline at end of file diff --git a/docs/base/models/classification/base_classifier.md b/docs/base/models/classification/base_classifier.md deleted file mode 100644 index 6f729c9b1..000000000 --- a/docs/base/models/classification/base_classifier.md +++ /dev/null @@ -1,3 +0,0 @@ -# Base Classifier - -::: PytorchWildlife.models.classification.base_classifier \ No newline at end of file diff --git a/docs/base/models/classification/resnet_base/amazon.md b/docs/base/models/classification/resnet_base/amazon.md deleted file mode 100644 index 520bfb778..000000000 --- a/docs/base/models/classification/resnet_base/amazon.md +++ /dev/null @@ -1,3 +0,0 @@ -# Amazon - -::: PytorchWildlife.models.classification.resnet_base.amazon \ No newline at end of file diff --git a/docs/base/models/classification/resnet_base/base_classifier.md b/docs/base/models/classification/resnet_base/base_classifier.md deleted file mode 100644 index 68c1b339a..000000000 --- a/docs/base/models/classification/resnet_base/base_classifier.md +++ /dev/null @@ -1,3 +0,0 @@ -# ResNet Base - -::: PytorchWildlife.models.classification.resnet_base.base_classifier \ No newline at end of file diff --git a/docs/base/models/classification/resnet_base/custom_weights.md b/docs/base/models/classification/resnet_base/custom_weights.md deleted file mode 100644 index e41b4e85e..000000000 --- a/docs/base/models/classification/resnet_base/custom_weights.md +++ /dev/null @@ -1,3 +0,0 @@ -# Custom Weights - -::: PytorchWildlife.models.classification.resnet_base.custom_weights \ No newline at end of file diff --git a/docs/base/models/classification/resnet_base/opossum.md b/docs/base/models/classification/resnet_base/opossum.md deleted file mode 100644 index 4f6b11f1c..000000000 --- a/docs/base/models/classification/resnet_base/opossum.md +++ /dev/null @@ -1,3 +0,0 @@ -# Opossum - -::: PytorchWildlife.models.classification.resnet_base.opossum \ No newline at end of file diff --git a/docs/base/models/classification/resnet_base/serengeti.md b/docs/base/models/classification/resnet_base/serengeti.md deleted file mode 100644 index 6f7a3760b..000000000 --- a/docs/base/models/classification/resnet_base/serengeti.md +++ /dev/null @@ -1,3 +0,0 @@ -# Serengeti - -::: PytorchWildlife.models.classification.resnet_base.serengeti \ No newline at end of file diff --git a/docs/base/models/classification/timm_base/DFNE.md b/docs/base/models/classification/timm_base/DFNE.md deleted file mode 100644 index 01f11fff9..000000000 --- a/docs/base/models/classification/timm_base/DFNE.md +++ /dev/null @@ -1,3 +0,0 @@ -# DFNE - -::: PytorchWildlife.models.classification.timm_base.DFNE \ No newline at end of file diff --git a/docs/base/models/classification/timm_base/Deepfaune.md b/docs/base/models/classification/timm_base/Deepfaune.md deleted file mode 100644 index a943f5900..000000000 --- a/docs/base/models/classification/timm_base/Deepfaune.md +++ /dev/null @@ -1,3 +0,0 @@ -# Deepfaune - -::: PytorchWildlife.models.classification.timm_base.Deepfaune \ No newline at end of file diff --git a/docs/base/models/classification/timm_base/base_classifier.md b/docs/base/models/classification/timm_base/base_classifier.md deleted file mode 100644 index ef03c85cb..000000000 --- a/docs/base/models/classification/timm_base/base_classifier.md +++ /dev/null @@ -1,3 +0,0 @@ -# Timm Base - -::: PytorchWildlife.models.classification.timm_base.base_classifier \ No newline at end of file diff --git a/docs/base/models/detection/base_detector.md b/docs/base/models/detection/base_detector.md deleted file mode 100644 index 7fbb5f9e0..000000000 --- a/docs/base/models/detection/base_detector.md +++ /dev/null @@ -1,3 +0,0 @@ -# Base Detector - -::: PytorchWildlife.models.detection.base_detector \ No newline at end of file diff --git a/docs/base/models/detection/herdnet.md b/docs/base/models/detection/herdnet.md deleted file mode 100644 index 557246333..000000000 --- a/docs/base/models/detection/herdnet.md +++ /dev/null @@ -1,3 +0,0 @@ -# HerdNet - -::: PytorchWildlife.models.detection.localization.herdnet \ No newline at end of file diff --git a/docs/base/models/detection/herdnet/animaloc/data/patches.md b/docs/base/models/detection/herdnet/animaloc/data/patches.md deleted file mode 100644 index 34716a179..000000000 --- a/docs/base/models/detection/herdnet/animaloc/data/patches.md +++ /dev/null @@ -1,3 +0,0 @@ -# Patches - -::: PytorchWildlife.models.detection.localization.animaloc.data.patches \ No newline at end of file diff --git a/docs/base/models/detection/herdnet/animaloc/data/types.md b/docs/base/models/detection/herdnet/animaloc/data/types.md deleted file mode 100644 index 02a3e7e46..000000000 --- a/docs/base/models/detection/herdnet/animaloc/data/types.md +++ /dev/null @@ -1,3 +0,0 @@ -# Types - -::: PytorchWildlife.models.detection.localization.animaloc.data.types \ No newline at end of file diff --git a/docs/base/models/detection/herdnet/animaloc/eval/lmds.md b/docs/base/models/detection/herdnet/animaloc/eval/lmds.md deleted file mode 100644 index 8e3066fc5..000000000 --- a/docs/base/models/detection/herdnet/animaloc/eval/lmds.md +++ /dev/null @@ -1,3 +0,0 @@ -# LMDS - -::: PytorchWildlife.models.detection.localization.animaloc.eval.lmds \ No newline at end of file diff --git a/docs/base/models/detection/herdnet/animaloc/eval/stitchers.md b/docs/base/models/detection/herdnet/animaloc/eval/stitchers.md deleted file mode 100644 index e96b6148c..000000000 --- a/docs/base/models/detection/herdnet/animaloc/eval/stitchers.md +++ /dev/null @@ -1,3 +0,0 @@ -# Stitchers - -::: PytorchWildlife.models.detection.localization.animaloc.eval.stitchers \ No newline at end of file diff --git a/docs/base/models/detection/herdnet/dla.md b/docs/base/models/detection/herdnet/dla.md deleted file mode 100644 index 5c5349cb1..000000000 --- a/docs/base/models/detection/herdnet/dla.md +++ /dev/null @@ -1,3 +0,0 @@ -# DLA - -::: PytorchWildlife.models.detection.localization.dla \ No newline at end of file diff --git a/docs/base/models/detection/herdnet/model.md b/docs/base/models/detection/herdnet/model.md deleted file mode 100644 index 7330fc235..000000000 --- a/docs/base/models/detection/herdnet/model.md +++ /dev/null @@ -1,3 +0,0 @@ -# Model - -::: PytorchWildlife.models.detection.localization.model \ No newline at end of file diff --git a/docs/base/models/detection/ultralytics_based/Deepfaune.md b/docs/base/models/detection/ultralytics_based/Deepfaune.md deleted file mode 100644 index f432ea92f..000000000 --- a/docs/base/models/detection/ultralytics_based/Deepfaune.md +++ /dev/null @@ -1,3 +0,0 @@ -# Deepfaune - -::: PytorchWildlife.models.detection.ultralytics_based.Deepfaune \ No newline at end of file diff --git a/docs/base/models/detection/ultralytics_based/megadetectorv5.md b/docs/base/models/detection/ultralytics_based/megadetectorv5.md deleted file mode 100644 index acc45ca84..000000000 --- a/docs/base/models/detection/ultralytics_based/megadetectorv5.md +++ /dev/null @@ -1,3 +0,0 @@ -# MegaDetector v5 - -::: PytorchWildlife.models.detection.ultralytics_based.megadetectorv5 \ No newline at end of file diff --git a/docs/base/models/detection/ultralytics_based/megadetectorv6.md b/docs/base/models/detection/ultralytics_based/megadetectorv6.md deleted file mode 100644 index 62fc7b329..000000000 --- a/docs/base/models/detection/ultralytics_based/megadetectorv6.md +++ /dev/null @@ -1,3 +0,0 @@ -# MegaDetector v6 - -::: PytorchWildlife.models.detection.ultralytics_based.megadetectorv6 \ No newline at end of file diff --git a/docs/base/models/detection/ultralytics_based/megadetectorv6_distributed.md b/docs/base/models/detection/ultralytics_based/megadetectorv6_distributed.md deleted file mode 100644 index afab79d5c..000000000 --- a/docs/base/models/detection/ultralytics_based/megadetectorv6_distributed.md +++ /dev/null @@ -1,3 +0,0 @@ -# MegaDetector v6 Distributed - -::: PytorchWildlife.models.detection.ultralytics_based.megadetectorv6_distributed \ No newline at end of file diff --git a/docs/base/models/detection/ultralytics_based/yolov5_base.md b/docs/base/models/detection/ultralytics_based/yolov5_base.md deleted file mode 100644 index 57366f9b2..000000000 --- a/docs/base/models/detection/ultralytics_based/yolov5_base.md +++ /dev/null @@ -1,3 +0,0 @@ -# YOLOv5 Base - -::: PytorchWildlife.models.detection.ultralytics_based.yolov5_base \ No newline at end of file diff --git a/docs/base/models/detection/ultralytics_based/yolov8_base.md b/docs/base/models/detection/ultralytics_based/yolov8_base.md deleted file mode 100644 index b71b3ac0e..000000000 --- a/docs/base/models/detection/ultralytics_based/yolov8_base.md +++ /dev/null @@ -1,3 +0,0 @@ -# YOLOv8 Base - -::: PytorchWildlife.models.detection.ultralytics_based.yolov8_base \ No newline at end of file diff --git a/docs/base/models/detection/ultralytics_based/yolov8_distributed.md b/docs/base/models/detection/ultralytics_based/yolov8_distributed.md deleted file mode 100644 index 12feadc67..000000000 --- a/docs/base/models/detection/ultralytics_based/yolov8_distributed.md +++ /dev/null @@ -1,3 +0,0 @@ -# YOLOv8 Distributed - -::: PytorchWildlife.models.detection.ultralytics_based.yolov8_distributed \ No newline at end of file diff --git a/docs/base/overview.md b/docs/base/overview.md deleted file mode 100644 index d2833a1ef..000000000 --- a/docs/base/overview.md +++ /dev/null @@ -1,37 +0,0 @@ -# PytorchWildlife Base Module - -The `PytorchWildlife` base module is the core component of this repository, designed to facilitate wildlife detection and classification tasks using PyTorch. It provides utilities for data processing, model implementation, and post-processing. It is also what is currently packaged in our Python package. - -## Overview - -The module is structured into the following submodules: - -- **`data`**: Contains utilities for handling datasets and applying transformations. -- **`models`**: Includes implementations for classification and detection models. -- **`utils`**: Provides miscellaneous utilities for post-processing and other tasks. - -## Submodules - -### `data` -- `datasets.py`: Defines dataset classes for loading and preprocessing data. -- `transforms.py`: Implements data augmentation and transformation utilities. - -### `models` -- `classification/`: Contains classification model architectures. -- `detection/`: Includes detection model architectures. - -### `utils` -- `misc.py`: Provides helper functions for miscellaneous tasks. -- `post_process.py`: Implements post-processing utilities for model outputs. - -## Getting Started - -To use the `PytorchWildlife` module, import the required submodules as follows: - -```python -from PytorchWildlife.data import datasets, transforms -from PytorchWildlife.models import classification, detection -from PytorchWildlife.utils import misc, post_process -``` - -Refer to the specific submodule documentation for detailed usage instructions. \ No newline at end of file diff --git a/docs/base/utils/misc.md b/docs/base/utils/misc.md deleted file mode 100644 index 72139cdf8..000000000 --- a/docs/base/utils/misc.md +++ /dev/null @@ -1 +0,0 @@ -::: PytorchWildlife.utils.misc \ No newline at end of file diff --git a/docs/base/utils/post_process.md b/docs/base/utils/post_process.md deleted file mode 100644 index 02b513c1b..000000000 --- a/docs/base/utils/post_process.md +++ /dev/null @@ -1 +0,0 @@ -::: PytorchWildlife.utils.post_process \ No newline at end of file diff --git a/docs/fine_tuning_modules/classification/overview.md b/docs/fine_tuning_modules/classification/overview.md deleted file mode 100644 index e138ab00d..000000000 --- a/docs/fine_tuning_modules/classification/overview.md +++ /dev/null @@ -1,10 +0,0 @@ ---- -description: "PyTorch-Wildlife classification fine-tuning — adapt species classifiers to your own datasets and geographic regions." -tags: - - PyTorch-Wildlife - - classification fine-tuning - - species classification - - transfer learning ---- - -# In Construction diff --git a/docs/fine_tuning_modules/detection/overview.md b/docs/fine_tuning_modules/detection/overview.md deleted file mode 100644 index 36b6e84a8..000000000 --- a/docs/fine_tuning_modules/detection/overview.md +++ /dev/null @@ -1,10 +0,0 @@ ---- -description: "PyTorch-Wildlife detection fine-tuning — fine-tune MegaDetector on your own camera-trap data for improved local performance." -tags: - - PyTorch-Wildlife - - detection fine-tuning - - MegaDetector - - transfer learning ---- - -# In Construction \ No newline at end of file diff --git a/docs/fine_tuning_modules/overview.md b/docs/fine_tuning_modules/overview.md deleted file mode 100644 index 44d25c336..000000000 --- a/docs/fine_tuning_modules/overview.md +++ /dev/null @@ -1,10 +0,0 @@ ---- -description: "PyTorch-Wildlife fine-tuning modules — adapt MegaDetector and classification models to your own camera-trap datasets." -tags: - - PyTorch-Wildlife - - fine-tuning - - MegaDetector - - transfer learning ---- - -# In construction \ No newline at end of file diff --git a/docs/index.md b/docs/index.md index 7e4bc1d82..1403adbf5 100644 --- a/docs/index.md +++ b/docs/index.md @@ -13,67 +13,56 @@ tags: ![Microsoft Biodiversity banner showing wildlife monitored by AI across camera traps, bioacoustics, and aerial detection, powered by PyTorch-Wildlife and MegaDetector](assets/biodiversity-banner.png) -
+
Open-source AI for camera traps, bioacoustics, and wildlife monitoring

- - - - -

-## 👋 Welcome to PyTorch-Wildlife -**PyTorch-Wildlife** is an AI platform designed for the AI for Conservation community to create, modify, and share powerful AI conservation models. It allows users to directly load a variety of models including [MegaDetector](https://github.com/microsoft/MegaDetector), [DeepFaune](https://www.deepfaune.cnrs.fr/en/), and [HerdNet](https://github.com/Alexandre-Delplanque/HerdNet) from our ever expanding [model zoo](model_zoo/megadetector.md) for both animal detection and classification. +## Welcome to Microsoft Biodiversity -Our scope now spans well beyond camera-trap imagery, with active work in [MegaDetector-Acoustic](bioacoustics.md) for bioacoustic species identification, [MegaDetector-Overhead](model_zoo/other_detectors.md) for aerial wildlife detection, and edge computing for remote field deployments. +**Microsoft Biodiversity** is the home for the AI for Good Lab's open-source conservation AI: a family +of projects for finding wildlife in imagery and audio, deploying models in the field, and building on a +shared framework. This hub ties the projects together and points you to the right starting point. -**New here?** The [Microsoft Biodiversity ecosystem guide](ecosystem.md) walks through every open-source project in the family and helps you pick the right one for camera-trap images, audio, field hardware, or model development. See who is already using these tools on the [collaborators page](collaborators.md). +**Start here.** The [ecosystem guide](ecosystem.md) walks through every project and helps you choose by +what you already have (camera-trap images, audio recordings, a field device) or what you want to build. -> **Coming from an older version?** OWL is now **MegaDetector-Overhead**, the bioacoustics module is now **MegaDetector-Acoustic**, and the repo has moved from `microsoft/CameraTraps` to `microsoft/Biodiversity` (old links redirect automatically). See the [full naming changes](releases/release_notes.md#naming-changes) in the v1.3.0 release notes. +## The ecosystem +- **[MegaDetector](https://microsoft.github.io/MegaDetector/)** detects animals, people, and vehicles in + camera-trap images and filters out the blank frames. +- **[MegaDetector-Acoustic](https://microsoft.github.io/MegaDetector-Acoustic/)** classifies and identifies + terrestrial species from audio recordings. +- **[SPARROW](https://microsoft.github.io/SPARROW/)** is the solar-powered edge device that runs these + models on site in remote field deployments. +- **[PyTorch-Wildlife](https://microsoft.github.io/Pytorch-Wildlife/)** is the deep learning framework and + model zoo that ties the ecosystem together in Python. -## 🚀 Quick Start +## The PyTorch-Wildlife framework -👇 Here is a brief example on how to perform detection and classification on a single image using `PyTorch-Wildlife` -```python -import numpy as np -from PytorchWildlife.models import detection as pw_detection -from PytorchWildlife.models import classification as pw_classification +The PyTorch-Wildlife framework has its own home at +**[microsoft/Pytorch-Wildlife](https://github.com/microsoft/Pytorch-Wildlife)** +([documentation](https://microsoft.github.io/Pytorch-Wildlife/)). Install it with: -img = np.random.randn(3, 1280, 1280) - -# Detection -detection_model = pw_detection.MegaDetectorV6() # Model weights are automatically downloaded. -detection_result = detection_model.single_image_detection(img) - -#Classification -classification_model = pw_classification.AI4GAmazonRainforest() # Model weights are automatically downloaded. -classification_results = classification_model.single_image_classification(img) -``` - -## ⚙️ Install PyTorch-Wildlife ``` pip install PytorchWildlife ``` -Please refer to our [installation guide](installation.md) for more installation information. - -## 🖼️ Examples +See the [framework documentation](https://microsoft.github.io/Pytorch-Wildlife/) for quick-start examples, +the model zoo, and the full API reference. -### Image detection using `MegaDetector` -Camera trap photo with MegaDetector bounding box detecting an animal
-*Credits to Universidad de los Andes, Colombia.* +> **Coming from an older version?** OWL is now **MegaDetector-Overhead**, the bioacoustics module is now +> **MegaDetector-Acoustic**, and the repo moved from `microsoft/CameraTraps` to `microsoft/Biodiversity` +> (old links redirect automatically). See the [full naming changes](releases/release_notes.md#naming-changes) +> in the v1.3.0 release notes. -### Image classification with `MegaDetector` and `AI4GAmazonRainforest` -MegaDetector detection with AI4GAmazonRainforest species classification overlay
-*Credits to Universidad de los Andes, Colombia.* +## Who uses these tools -### Opossum ID with `MegaDetector` and `AI4GOpossum` -Opossum identified using MegaDetector and AI4GOpossum classification model
-*Credits to the Agency for Regulation and Control of Biosecurity and Quarantine for Galápagos (ABG), Ecuador.* +Conservation teams around the world build on the ecosystem. See the [collaborators page](collaborators.md) +for organizations already using these tools, and the +[ecosystem documentation standards](ecosystem-standards.md) if you are bringing a new project into the family. diff --git a/mkdocs.yml b/mkdocs.yml index 1d9715246..a8f0d01e5 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -71,7 +71,7 @@ theme: - navigation.footer nav: - - PyTorch-Wildlife: + - Microsoft Biodiversity: - Overview: index.md - The Ecosystem: ecosystem.md - Core Features: core_features.md @@ -103,64 +103,6 @@ nav: - Documentation Standards: ecosystem-standards.md - License: license.md - Developer Guide: build_mkdocs.md - - Reference - Code API: - - Base Module: - - Overview: base/overview.md - - Data: - - Datasets: base/data/datasets.md - - Transforms: base/data/transforms.md - - Models: - - Classification: - - Base Classifier: - - base/models/classification/base_classifier.md - - ResNet Base: - - base/models/classification/resnet_base/base_classifier.md - - base/models/classification/resnet_base/amazon.md - - base/models/classification/resnet_base/custom_weights.md - - base/models/classification/resnet_base/opossum.md - - base/models/classification/resnet_base/serengeti.md - - TIMM Base: - - base/models/classification/timm_base/base_classifier.md - - base/models/classification/timm_base/DFNE.md - - base/models/classification/timm_base/Deepfaune.md - - Detection: - - Base Detector: - - base/models/detection/base_detector.md - - HerdNet: - - base/models/detection/herdnet.md - - base/models/detection/herdnet/dla.md - - base/models/detection/herdnet/model.md - - Animaloc: - - Data: - - base/models/detection/herdnet/animaloc/data/patches.md - - base/models/detection/herdnet/animaloc/data/types.md - - Eval: - - base/models/detection/herdnet/animaloc/eval/lmds.md - - base/models/detection/herdnet/animaloc/eval/stitchers.md - - Ultralytics Based: - - base/models/detection/ultralytics_based/Deepfaune.md - - MegaDetector v5: - - base/models/detection/ultralytics_based/megadetectorv5.md - - MegaDetector v6: - - base/models/detection/ultralytics_based/megadetectorv6.md - - base/models/detection/ultralytics_based/megadetectorv6_distributed.md - - YOLOv5 Base: - - base/models/detection/ultralytics_based/yolov5_base.md - - YOLOv8 Base: - - base/models/detection/ultralytics_based/yolov8_base.md - - YOLOv8 Distributed: - - base/models/detection/ultralytics_based/yolov8_distributed.md - - Utils: - - base/utils/misc.md - - base/utils/post_process.md - - - Model Fine-tuning: - - Overview: fine_tuning_modules/overview.md - - Classification Fine-tuning: - - Overview: fine_tuning_modules/classification/overview.md - - Detection Fine-tuning: - - Overview: fine_tuning_modules/detection/overview.md - markdown_extensions: - admonition - pymdownx.details @@ -191,10 +133,4 @@ plugins: - git-revision-date-localized: type: date enable_creation_date: true - fallback_to_build_date: true - - mkdocstrings: - handlers: - python: - options: - docstring_style: google - + fallback_to_build_date: true \ No newline at end of file diff --git a/setup.py b/setup.py deleted file mode 100644 index 943198e66..000000000 --- a/setup.py +++ /dev/null @@ -1,49 +0,0 @@ -from pathlib import Path -from setuptools import setup, find_packages - -HERE = Path(__file__).parent -with open(HERE / 'README.md', encoding="utf8") as file: - long_description = file.read() -VERSION = (HERE / 'version.txt').read_text().strip() - -setup( - name='PytorchWildlife', - version=VERSION, - packages=find_packages(), - include_package_data=True, - package_data={"": ["*.yml"]}, - url='https://github.com/microsoft/Biodiversity/', - license='MIT', - author='Andres Hernandez, Zhongqi Miao, Daniela Ruiz Lopez, Isai Daniel Chacon Silva', - author_email='v-hernandres@microsoft.com, zhongqimiao@microsoft.com, v-druizlopez@microsoft.com, v-ichaconsil@microsoft.com', - description='a PyTorch Collaborative Deep Learning Framework for Conservation.', - long_description=long_description, - long_description_content_type='text/markdown', - install_requires=[ - 'torch', - 'torchvision', - 'torchaudio', - 'tqdm', - 'Pillow', - 'supervision==0.23.0', - 'gradio', - 'ultralytics', - 'chardet', - 'wget', - 'yolov5', - 'setuptools', - 'scikit-learn', - 'timm', - 'omegaconf', - 'lightning', - 'setuptools==68.2.2' - ], - classifiers=[ - 'Development Status :: 3 - Alpha', - 'Intended Audience :: Developers', - 'License :: OSI Approved :: MIT License', - 'Programming Language :: Python :: 3', - ], - keywords='pytorch_wildlife, pytorch, wildlife, megadetector, conservation, animal, detection, classification', - python_requires='>=3.8', -) diff --git a/version.txt b/version.txt deleted file mode 100644 index f0bb29e76..000000000 --- a/version.txt +++ /dev/null @@ -1 +0,0 @@ -1.3.0 From 8eb641ac7bf2460651327508cdf3efd184b28644 Mon Sep 17 00:00:00 2001 From: rain-Brian Date: Wed, 3 Jun 2026 15:06:19 -0700 Subject: [PATCH 2/2] docs(seo): redirect moved API URLs to PW instead of 404; omit private Acoustic Per SEO review: a hard removal of the ~32 hub API pages would 404 and leak link signals. Replace the deletion outcome with mkdocs-redirects: each old hub API URL (base/*, fine_tuning_modules/*) now serves a 0-second meta-refresh to its Pytorch-Wildlife equivalent (Google treats instant meta-refresh as a permanent redirect on static hosts). Redirect stubs are auto-excluded from the sitemap; slugs are preserved 1:1. Also omit the private MegaDetector-Acoustic link from the umbrella homepage until its Pages URL is stable. --- docs-requirements.txt | 1 + docs/index.md | 2 +- mkdocs.yml | 36 +++++++++++++++++++++++++++++++++++- 3 files changed, 37 insertions(+), 2 deletions(-) diff --git a/docs-requirements.txt b/docs-requirements.txt index 54cc33c4a..59ae937b4 100644 --- a/docs-requirements.txt +++ b/docs-requirements.txt @@ -5,3 +5,4 @@ mkdocs-material-extensions mkdocs-callouts mkdocs-git-revision-date-localized-plugin pymdown-extensions +mkdocs-redirects diff --git a/docs/index.md b/docs/index.md index 1403adbf5..5117f4c37 100644 --- a/docs/index.md +++ b/docs/index.md @@ -36,7 +36,7 @@ what you already have (camera-trap images, audio recordings, a field device) or - **[MegaDetector](https://microsoft.github.io/MegaDetector/)** detects animals, people, and vehicles in camera-trap images and filters out the blank frames. -- **[MegaDetector-Acoustic](https://microsoft.github.io/MegaDetector-Acoustic/)** classifies and identifies +- **MegaDetector-Acoustic** (documentation coming soon) classifies and identifies terrestrial species from audio recordings. - **[SPARROW](https://microsoft.github.io/SPARROW/)** is the solar-powered edge device that runs these models on site in remote field deployments. diff --git a/mkdocs.yml b/mkdocs.yml index a8f0d01e5..1ab027b28 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -133,4 +133,38 @@ plugins: - git-revision-date-localized: type: date enable_creation_date: true - fallback_to_build_date: true \ No newline at end of file + fallback_to_build_date: true + - redirects: + redirect_maps: + 'base/data/datasets.md': 'https://microsoft.github.io/Pytorch-Wildlife/base/data/datasets/' + 'base/data/transforms.md': 'https://microsoft.github.io/Pytorch-Wildlife/base/data/transforms/' + 'base/models/classification/base_classifier.md': 'https://microsoft.github.io/Pytorch-Wildlife/base/models/classification/base_classifier/' + 'base/models/classification/resnet_base/amazon.md': 'https://microsoft.github.io/Pytorch-Wildlife/base/models/classification/resnet_base/amazon/' + 'base/models/classification/resnet_base/base_classifier.md': 'https://microsoft.github.io/Pytorch-Wildlife/base/models/classification/resnet_base/base_classifier/' + 'base/models/classification/resnet_base/custom_weights.md': 'https://microsoft.github.io/Pytorch-Wildlife/base/models/classification/resnet_base/custom_weights/' + 'base/models/classification/resnet_base/opossum.md': 'https://microsoft.github.io/Pytorch-Wildlife/base/models/classification/resnet_base/opossum/' + 'base/models/classification/resnet_base/serengeti.md': 'https://microsoft.github.io/Pytorch-Wildlife/base/models/classification/resnet_base/serengeti/' + 'base/models/classification/timm_base/DFNE.md': 'https://microsoft.github.io/Pytorch-Wildlife/base/models/classification/timm_base/DFNE/' + 'base/models/classification/timm_base/Deepfaune.md': 'https://microsoft.github.io/Pytorch-Wildlife/base/models/classification/timm_base/Deepfaune/' + 'base/models/classification/timm_base/base_classifier.md': 'https://microsoft.github.io/Pytorch-Wildlife/base/models/classification/timm_base/base_classifier/' + 'base/models/detection/base_detector.md': 'https://microsoft.github.io/Pytorch-Wildlife/base/models/detection/base_detector/' + 'base/models/detection/herdnet.md': 'https://microsoft.github.io/Pytorch-Wildlife/base/models/detection/herdnet/' + 'base/models/detection/herdnet/animaloc/data/patches.md': 'https://microsoft.github.io/Pytorch-Wildlife/base/models/detection/herdnet/animaloc/data/patches/' + 'base/models/detection/herdnet/animaloc/data/types.md': 'https://microsoft.github.io/Pytorch-Wildlife/base/models/detection/herdnet/animaloc/data/types/' + 'base/models/detection/herdnet/animaloc/eval/lmds.md': 'https://microsoft.github.io/Pytorch-Wildlife/base/models/detection/herdnet/animaloc/eval/lmds/' + 'base/models/detection/herdnet/animaloc/eval/stitchers.md': 'https://microsoft.github.io/Pytorch-Wildlife/base/models/detection/herdnet/animaloc/eval/stitchers/' + 'base/models/detection/herdnet/dla.md': 'https://microsoft.github.io/Pytorch-Wildlife/base/models/detection/herdnet/dla/' + 'base/models/detection/herdnet/model.md': 'https://microsoft.github.io/Pytorch-Wildlife/base/models/detection/herdnet/model/' + 'base/models/detection/ultralytics_based/Deepfaune.md': 'https://microsoft.github.io/Pytorch-Wildlife/base/models/detection/ultralytics_based/Deepfaune/' + 'base/models/detection/ultralytics_based/megadetectorv5.md': 'https://microsoft.github.io/Pytorch-Wildlife/base/models/detection/ultralytics_based/megadetectorv5/' + 'base/models/detection/ultralytics_based/megadetectorv6.md': 'https://microsoft.github.io/Pytorch-Wildlife/base/models/detection/ultralytics_based/megadetectorv6/' + 'base/models/detection/ultralytics_based/megadetectorv6_distributed.md': 'https://microsoft.github.io/Pytorch-Wildlife/base/models/detection/ultralytics_based/megadetectorv6_distributed/' + 'base/models/detection/ultralytics_based/yolov5_base.md': 'https://microsoft.github.io/Pytorch-Wildlife/base/models/detection/ultralytics_based/yolov5_base/' + 'base/models/detection/ultralytics_based/yolov8_base.md': 'https://microsoft.github.io/Pytorch-Wildlife/base/models/detection/ultralytics_based/yolov8_base/' + 'base/models/detection/ultralytics_based/yolov8_distributed.md': 'https://microsoft.github.io/Pytorch-Wildlife/base/models/detection/ultralytics_based/yolov8_distributed/' + 'base/overview.md': 'https://microsoft.github.io/Pytorch-Wildlife/base/overview/' + 'base/utils/misc.md': 'https://microsoft.github.io/Pytorch-Wildlife/base/utils/misc/' + 'base/utils/post_process.md': 'https://microsoft.github.io/Pytorch-Wildlife/base/utils/post_process/' + 'fine_tuning_modules/classification/overview.md': 'https://microsoft.github.io/Pytorch-Wildlife/fine_tuning_modules/classification/overview/' + 'fine_tuning_modules/detection/overview.md': 'https://microsoft.github.io/Pytorch-Wildlife/fine_tuning_modules/detection/overview/' + 'fine_tuning_modules/overview.md': 'https://microsoft.github.io/Pytorch-Wildlife/fine_tuning_modules/overview/' \ No newline at end of file