
Commit 9ee10fd

Authored by ashwinvaidya17, alexriedel1, and rajeshgangireddy
🚀 feat(model): Enable Patchcore Training Half Precision (#3047) (#3055)
* 🚀 feat(model): REFACTOR Enable Patchcore Training Half Precision (#3047)
* patchcore infer dtype from model dtype
* lock
* pre-commits

Signed-off-by: Alexander Riedel <alex.riedel@googlemail.com>
Signed-off-by: Ashwin Vaidya <ashwin.vaidya@intel.com>
Co-authored-by: Alexander Riedel <alex.riedel@googlemail.com>
Co-authored-by: Ashwin Vaidya <ashwin.vaidya@intel.com>
Co-authored-by: Rajesh Gangireddy <rajesh.gangireddy@intel.com>
1 parent c306931 commit 9ee10fd

File tree

6 files changed: +6373 additions, -4830 deletions

.github/actions/pytest/action.yaml
4 additions, 2 deletions

```diff
@@ -112,7 +112,7 @@ runs:
       uv venv --python "$(command -v python)" .venv
       source .venv/bin/activate
       # Install dependencies with dev extras into the uv-managed environment
-      uv pip install ".[dev]"
+      uv pip install ".[dev,cu124]" # TODO(ashwinvaidya17): See issue #3050
       uv pip install codecov

     # Determine which tests to run based on input
@@ -144,6 +144,7 @@ runs:
         INPUTS_MAX_TEST_TIME: ${{ inputs.max-test-time }}
         STEPS_TEST_SCOPE_OUTPUTS_PATH: ${{ steps.test-scope.outputs.path }}
       run: |
+        source .venv/bin/activate
         start_time=$(date +%s)
         set -o pipefail
@@ -155,7 +156,7 @@ runs:
         fi

         # Run pytest
-        PYTHONPATH=src uv run pytest "$STEPS_TEST_SCOPE_OUTPUTS_PATH" \
+        PYTHONPATH=src pytest "$STEPS_TEST_SCOPE_OUTPUTS_PATH" \
           --numprocesses=0 \
           --durations=10 \
           --durations-min=1.0 \
@@ -225,6 +226,7 @@ runs:
         INPUTS_PYTHON_VERSION: ${{ inputs.python-version }}

       run: |
+        source .venv/bin/activate
         uv run codecov --token "$INPUTS_CODECOV_TOKEN" \
           --file coverage.xml \
           --flags "$INPUTS_TEST_TYPE_py$INPUTS_PYTHON_VERSION" \
```
CHANGELOG.md
2 additions, 0 deletions

```diff
@@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

 ### Added

+- 🚀 **model**: Enable Patchcore Training Half Precision by @alexriedel1 in https://github.com/open-edge-platform/anomalib/pull/3047
+
 ### Removed

 ### Changed
```

pyproject.toml
18 additions, 6 deletions

```diff
@@ -100,28 +100,32 @@ test = [
 ]
 # PyTorch dependency groups
 cpu = [
-    "torch>=2.4.0",
+    "torch>=2.4.0,<=2.8.0",
     "torchvision>=0.19.0",
 ]
 cu118 = [
-    "torch>=2.4.0",
+    "torch>=2.4.0,<=2.8.0",
     "torchvision>=0.19.0",
 ]
 cu121 = [
-    "torch>=2.4.0",
+    "torch>=2.4.0,<=2.8.0",
     "torchvision>=0.19.0",
 ]
 cu124 = [
-    "torch>=2.4.0",
+    "torch>=2.4.0,<=2.8.0",
+    "torchvision>=0.19.0",
+]
+cu130 = [
+    "torch>=2.4.0", # Note: onnx export fails with torch>=2.9.0 and cuda 13 isn't supported by <=2.8.0
     "torchvision>=0.19.0",
 ]
 rocm = [
-    "torch>=2.4.0",
+    "torch>=2.4.0,<=2.8.0",
     "torchvision>=0.19.0",
     "pytorch-triton-rocm ; sys_platform == 'linux'",
 ]
 xpu = [
-    "torch>=2.4.0",
+    "torch>=2.4.0,<=2.8.0",
     "torchvision>=0.19.0",
     "pytorch-triton-xpu ; sys_platform == 'linux' or sys_platform == 'win32'",
 ]
@@ -158,6 +162,7 @@ conflicts = [
     { extra = "cu118" },
     { extra = "cu121" },
     { extra = "cu124" },
+    { extra = "cu130" },
     { extra = "rocm" },
     { extra = "xpu" },
 ],
@@ -184,6 +189,11 @@ name = "pytorch-cu124"
 url = "https://download.pytorch.org/whl/cu124"
 explicit = true

+[[tool.uv.index]]
+name = "pytorch-cu130"
+url = "https://download.pytorch.org/whl/cu130"
+explicit = true
+
 [[tool.uv.index]]
 name = "pytorch-rocm61"
 url = "https://download.pytorch.org/whl/rocm6.1"
@@ -201,6 +211,7 @@ torch = [
     { index = "pytorch-cu118", extra = "cu118" },
     { index = "pytorch-cu121", extra = "cu121" },
     { index = "pytorch-cu124", extra = "cu124" },
+    { index = "pytorch-cu130", extra = "cu130" },
     { index = "pytorch-rocm61", extra = "rocm" },
     { index = "pytorch-xpu", extra = "xpu" },
 ]
@@ -209,6 +220,7 @@ torchvision = [
     { index = "pytorch-cu118", extra = "cu118" },
     { index = "pytorch-cu121", extra = "cu121" },
     { index = "pytorch-cu124", extra = "cu124" },
+    { index = "pytorch-cu130", extra = "cu130" },
     { index = "pytorch-rocm61", extra = "rocm" },
     { index = "pytorch-xpu", extra = "xpu" },
 ]
```
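
These pins cap every existing extra at torch 2.8.0, while the new `cu130` extra stays uncapped per the inline note about ONNX export. A quick way to confirm what the pins resolve to in practice is to inspect the installed torch build. The sketch below is not part of this commit; it just prints the version, the CUDA toolkit the wheel targets, and whether a GPU is usable:

```python
# Sanity check (not part of this commit): inspect which torch build an
# environment resolved after installing one of the extras above.
import torch

print(torch.__version__)          # e.g. "2.8.0+cu124"
print(torch.version.cuda)         # CUDA toolkit the wheel targets, or None for CPU builds
print(torch.cuda.is_available())  # True only with a compatible driver and GPU
```

Every extra except `cu130` should report a torch version no newer than 2.8.0.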

src/anomalib/__init__.py
25 additions, 0 deletions

```diff
@@ -86,3 +86,28 @@ class TaskType(str, Enum):

     CLASSIFICATION = "classification"
     SEGMENTATION = "segmentation"
+
+
+class PrecisionType(str, Enum):
+    """Precision type defining the numerical precision used in model computations.
+
+    This enum defines the different precision types supported by anomalib models:
+
+    - ``FLOAT32``: Standard 32-bit floating point precision
+    - ``FLOAT16``: Half-precision 16-bit floating point for faster computation
+
+    Example:
+        >>> from anomalib import PrecisionType
+        >>> precision_type = PrecisionType.FLOAT16
+        >>> print(precision_type)
+        'float16'
+
+    Note:
+        The precision type affects:
+        - Memory usage during model training and inference
+        - Computational speed, especially on compatible hardware
+        - Numerical stability of model computations
+    """
+
+    FLOAT32 = "float32"
+    FLOAT16 = "float16"
```
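
One detail worth noting: because `PrecisionType` mixes in `str`, its members compare equal to their raw string values, which is what lets the `Patchcore` constructor below accept either the enum member or a plain `"float16"`. A self-contained illustration of the pattern (not part of the diff):

```python
from enum import Enum


# Standalone restatement of the pattern above, for illustration only.
class PrecisionType(str, Enum):
    FLOAT32 = "float32"
    FLOAT16 = "float16"


# str-mixin enum members compare equal to their raw string values...
assert PrecisionType.FLOAT16 == "float16"
assert "float32" == PrecisionType.FLOAT32

# ...and the enum constructor doubles as a validator.
print(PrecisionType("float16"))  # PrecisionType.FLOAT16
try:
    PrecisionType("bfloat16")    # unsupported value
except ValueError as err:
    print(err)                   # 'bfloat16' is not a valid PrecisionType
```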

src/anomalib/models/image/patchcore/lightning_model.py
17 additions, 2 deletions

```diff
@@ -56,7 +56,7 @@
 from torch import nn
 from torchvision.transforms.v2 import CenterCrop, Compose, Normalize, Resize

-from anomalib import LearningType
+from anomalib import LearningType, PrecisionType
 from anomalib.data import Batch
 from anomalib.metrics import Evaluator
 from anomalib.models.components import AnomalibModule, MemoryBankMixin
@@ -93,6 +93,9 @@ class Patchcore(MemoryBankMixin, AnomalibModule):
             subsample embeddings. Defaults to ``0.1``.
         num_neighbors (int, optional): Number of nearest neighbors to use.
             Defaults to ``9``.
+        precision (str, optional): Precision type for model computations.
+            Supported values are defined in :class:`PrecisionType`.
+            Defaults to ``PrecisionType.FLOAT32``.
         pre_processor (PreProcessor | bool, optional): Pre-processor instance or
             bool flag. Defaults to ``True``.
         post_processor (PostProcessor | bool, optional): Post-processor instance or
@@ -102,6 +105,7 @@ class Patchcore(MemoryBankMixin, AnomalibModule):
         visualizer (Visualizer | bool, optional): Visualizer instance or bool flag.
             Defaults to ``True``.

+
     Example:
         >>> from anomalib.data import MVTecAD
         >>> from anomalib.models import Patchcore
@@ -112,7 +116,8 @@ class Patchcore(MemoryBankMixin, AnomalibModule):
         >>> model = Patchcore(
         ...     backbone="wide_resnet50_2",
         ...     layers=["layer2", "layer3"],
-        ...     coreset_sampling_ratio=0.1
+        ...     coreset_sampling_ratio=0.1,
+        ...     precision="float32",
         ... )

         >>> # Train using the Engine
@@ -140,6 +145,7 @@ def __init__(
         pre_trained: bool = True,
         coreset_sampling_ratio: float = 0.1,
         num_neighbors: int = 9,
+        precision: str = PrecisionType.FLOAT32,
         pre_processor: nn.Module | bool = True,
         post_processor: nn.Module | bool = True,
         evaluator: Evaluator | bool = True,
@@ -160,6 +166,15 @@ def __init__(
         )
         self.coreset_sampling_ratio = coreset_sampling_ratio

+        if precision == PrecisionType.FLOAT16:
+            self.model = self.model.half()
+        elif precision == PrecisionType.FLOAT32:
+            self.model = self.model.float()
+        else:
+            msg = f"""Unsupported precision type: {precision}.
+            Supported types are: {PrecisionType.FLOAT16}, {PrecisionType.FLOAT32}."""
+            raise ValueError(msg)
+
     @classmethod
     def configure_pre_processor(
         cls,
```
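
Taken together, half-precision PatchCore training now only needs the new constructor argument: the wrapped torch model is cast up front with `.half()`, and, per the commit message, the inference dtype is inferred from the model dtype. A usage sketch extending the class docstring example above (default dataset settings assumed):

```python
from anomalib.data import MVTecAD
from anomalib.engine import Engine
from anomalib.models import Patchcore

# Half-precision PatchCore: __init__ casts the wrapped torch model with
# .half(); any precision other than "float16"/"float32" raises ValueError.
model = Patchcore(
    backbone="wide_resnet50_2",
    layers=["layer2", "layer3"],
    coreset_sampling_ratio=0.1,
    precision="float16",  # equivalent to PrecisionType.FLOAT16
)

datamodule = MVTecAD()  # default MVTec AD setup, as in the docstring example
engine = Engine()
engine.fit(model=model, datamodule=datamodule)
```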
