Skip to content

Commit e07bf73

Browse files
fix: update precision parameter default to PrecisionType enum in AnomalyDINO and Patchcore
1 parent e015d7c commit e07bf73

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

src/anomalib/models/image/anomaly_dino/lightning_model.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ class AnomalyDINO(MemoryBankMixin, AnomalibModule):
9292
should we subsample. Defaults to ``0.1``
9393
precision (str | PrecisionType, optional): Precision type for model computations.
9494
Can be either a string (``"float32"``, ``"float16"``) or a :class:`PrecisionType` enum value.
95-
Defaults to ``"float32"``.
95+
Defaults to ``PrecisionType.FLOAT32``.
9696
pre_processor (PreProcessor | bool, optional): Pre-processor instance or
9797
bool flag to enable default preprocessing. Defaults to ``True``.
9898
post_processor (PostProcessor | bool, optional): Post-processor instance or
@@ -154,7 +154,7 @@ def __init__(
154154
masking: bool = False,
155155
coreset_subsampling: bool = False,
156156
sampling_ratio: float = 0.1,
157-
precision: str | PrecisionType = PrecisionType.FLOAT32.value,
157+
precision: str | PrecisionType = PrecisionType.FLOAT32,
158158
pre_processor: nn.Module | bool = True,
159159
post_processor: nn.Module | bool = True,
160160
evaluator: Evaluator | bool = True,
@@ -174,6 +174,10 @@ def __init__(
174174
sampling_ratio=sampling_ratio,
175175
)
176176

177+
# Convert string to PrecisionType enum if needed
178+
if isinstance(precision, str):
179+
precision = PrecisionType(precision.lower())
180+
177181
if precision == PrecisionType.FLOAT16:
178182
self.model = self.model.half()
179183
elif precision == PrecisionType.FLOAT32:

src/anomalib/models/image/patchcore/lightning_model.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ class Patchcore(MemoryBankMixin, AnomalibModule):
9595
Defaults to ``9``.
9696
precision (str | PrecisionType, optional): Precision type for model computations.
9797
Can be either a string (``"float32"``, ``"float16"``) or a :class:`PrecisionType` enum value.
98-
Defaults to ``"float32"``.
98+
Defaults to ``PrecisionType.FLOAT32``.
9999
pre_processor (PreProcessor | bool, optional): Pre-processor instance or
100100
bool flag. Defaults to ``True``.
101101
post_processor (PostProcessor | bool, optional): Post-processor instance or
@@ -145,7 +145,7 @@ def __init__(
145145
pre_trained: bool = True,
146146
coreset_sampling_ratio: float = 0.1,
147147
num_neighbors: int = 9,
148-
precision: str | PrecisionType = PrecisionType.FLOAT32.value,
148+
precision: str | PrecisionType = PrecisionType.FLOAT32,
149149
pre_processor: nn.Module | bool = True,
150150
post_processor: nn.Module | bool = True,
151151
evaluator: Evaluator | bool = True,
@@ -166,6 +166,10 @@ def __init__(
166166
)
167167
self.coreset_sampling_ratio = coreset_sampling_ratio
168168

169+
# Convert string to PrecisionType enum if needed
170+
if isinstance(precision, str):
171+
precision = PrecisionType(precision.lower())
172+
169173
if precision == PrecisionType.FLOAT16:
170174
self.model = self.model.half()
171175
elif precision == PrecisionType.FLOAT32:

0 commit comments

Comments
 (0)