4 changes: 4 additions & 0 deletions docs/source/markdown/guides/reference/callbacks/index.md
@@ -49,6 +49,10 @@ Track and measure execution times during training.
 :show-inheritance:
 ```
 
+:::{note}
+ModelCheckpoint is automatically disabled when using `Engine(barebones=True)` for lightweight training without checkpoint overhead.
+:::
+
 (graph-logger)=
 
 ## {octicon}`graph` Graph Logger
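For readers skimming the docs change, a minimal sketch of the behavior the note describes — this assumes the `Engine` API as extended by this PR, where the `barebones` flag is forwarded to Lightning's `Trainer`; it is not part of the diff:

```python
# Minimal sketch of the note above; assumes the anomalib Engine API from this PR.
from anomalib.engine import Engine

# Default: a ModelCheckpoint callback is added automatically during setup.
engine = Engine(max_epochs=1)

# Barebones: no ModelCheckpoint is added, so training skips checkpoint I/O.
engine_fast = Engine(max_epochs=1, barebones=True)
```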
82 changes: 82 additions & 0 deletions examples/api/04_advanced/training_modes.py
@@ -0,0 +1,82 @@
"""Training modes and checkpoint control examples.
This example demonstrates different training modes in Anomalib:
1. Standard training with default checkpointing
2. Custom checkpoint callback configuration
3. Barebones mode for fast training/testing without checkpoint overhead
Note:
Under the hood, `Engine` uses Lightning's `Trainer` to manage the training
workflow. So, most Trainer arguments can be passed to the Engine constructor.
This includes parameters like
`max_epochs`, `enable_checkpointing`, `barebones`, `logger`, `callbacks`, etc.
For more details on available parameters, see:
https://lightning.ai/docs/pytorch/stable/common/trainer.html
"""

+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from anomalib.callbacks import ModelCheckpoint
+from anomalib.data import MVTecAD
+from anomalib.engine import Engine
+from anomalib.models import Fre
+
+# Initialize model and data
+model = Fre()
+datamodule = MVTecAD(category="bottle")
+
+print("1. Standard Training with Default Checkpointing")
+print("-" * 50)
+# 1. Standard training (checkpoints saved automatically)
+engine = Engine(max_epochs=5)
+engine.fit(model=model, datamodule=datamodule)
+
+print(f"Checkpoint saved at: {engine.best_model_path}")
+
+
+# 2. Custom checkpoint callback
+# Example: don't save any checkpoints (useful for quick tests)
+
+print("2. Custom Checkpoint Callback")
+print("-" * 50)
+checkpoint_callback = ModelCheckpoint(save_top_k=0)
+
+engine = Engine(max_epochs=5, callbacks=[checkpoint_callback])
+print("Training with custom checkpoint callback...")
+engine.fit(model=model, datamodule=datamodule)
+print(f"Checkpoint path: {engine.best_model_path}")
+print()
+
+
+# 3. Barebones Mode for Maximum Speed
+# Barebones mode: minimal overhead, no checkpointing, no model summary, etc.
+# Useful for benchmarking and quick experiments.
+# See Lightning docs: https://lightning.ai/docs/pytorch/stable/common/trainer.html#barebones
+
+print("3. Barebones Training Mode")
+print("-" * 50)
+
+# Initialize model and data
+model = Fre()
+datamodule = MVTecAD(category="bottle")
+
+# Create engine with barebones mode enabled
+engine = Engine(
+    max_epochs=5,
+    barebones=True,  # Minimal overhead, no checkpoint saving
+)
+
+# Train in barebones mode
+print("Training in barebones mode (fastest)...")
+engine.fit(model=model, datamodule=datamodule)
+
+# Metrics are still captured and returned even in barebones mode
+print("\nTesting and retrieving metrics...")
+results = engine.test(model=model, datamodule=datamodule)
+
+# Print results
+print("\nTest Results:")
+for metric, value in results[0].items():
+    print(f"{metric}: {value:.4f}")
21 changes: 18 additions & 3 deletions src/anomalib/engine/engine.py
@@ -310,9 +310,10 @@ def _setup_anomalib_callbacks(self) -> None:
"""Set up callbacks for the trainer."""
callbacks: list[Callback] = []

# Add ModelCheckpoint if it is not in the callbacks list.
# Add ModelCheckpoint if it is not in the callbacks list and barebones is not enabled.
has_checkpoint_callback = any(isinstance(c, ModelCheckpoint) for c in self._cache.args["callbacks"])
if has_checkpoint_callback is False:
is_barebones = self._cache.args.get("barebones", False)
if has_checkpoint_callback is False and not is_barebones:
callbacks.append(
ModelCheckpoint(
dirpath=self._cache.args["default_root_dir"] / "weights" / "lightning",
@@ -557,7 +558,21 @@ def test(
         if self._should_run_validation(model or self.model, ckpt_path):
             logger.info("Running validation before testing to collect normalization metrics and/or thresholds.")
             self.trainer.validate(model, dataloaders, None, verbose=False, datamodule=datamodule)
-        return self.trainer.test(model, dataloaders, ckpt_path, verbose, datamodule, weights_only=False)
+
+        results = self.trainer.test(model, dataloaders, ckpt_path, verbose, datamodule, weights_only=False)
+
+        # In barebones mode, PyTorch Lightning may return an empty results dict despite having logged metrics.
+        # Inject logged_metrics into results to ensure metrics are available in the return value.
+        if (
+            self.trainer.barebones
+            and results
+            and isinstance(results, list)
+            and not results[0]
+            and self.trainer.logged_metrics
+        ):
+            results[0] = {k: v.item() if hasattr(v, "item") else v for k, v in self.trainer.logged_metrics.items()}
+
+        return results
 
     def predict(
         self,
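The metric-injection workaround added to `test()` can be illustrated in isolation. A hedged sketch, with plain Python objects standing in for the real `Trainer.test()` output and `Trainer.logged_metrics` — the names below are illustrative, not anomalib API:

```python
# Standalone sketch of the logged-metrics injection shown above.
import torch

results: list[dict] = [{}]  # barebones Trainer.test() may return an empty dict
logged_metrics = {"image_AUROC": torch.tensor(0.95), "image_F1Score": torch.tensor(0.9)}

if results and not results[0] and logged_metrics:
    # Convert tensors to plain floats so the return value is easy to consume.
    results[0] = {k: v.item() if hasattr(v, "item") else v for k, v in logged_metrics.items()}

print(results[0])  # e.g. {'image_AUROC': 0.95..., 'image_F1Score': 0.9...}
```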
16 changes: 14 additions & 2 deletions src/anomalib/metrics/evaluator.py
@@ -147,9 +147,15 @@ def on_validation_epoch_end(
         pl_module: LightningModule,
     ) -> None:
         """Compute and log validation metrics."""
-        del trainer, pl_module  # Unused argument.
+        del pl_module  # Unused argument.
         for metric in self.val_metrics:
             self.log(metric.name, metric)
+            # In barebones mode, logging is disabled. We manually update trainer metrics
+            # to ensure they're available in both callback_metrics and the validate() return value.
+            if trainer.barebones:
+                metric_value = metric.compute()
+                trainer.callback_metrics[metric.name] = metric_value
+                trainer.logged_metrics[metric.name] = metric_value
 
     def on_test_batch_end(
         self,
@@ -171,9 +177,15 @@ def on_test_epoch_end(
         pl_module: LightningModule,
     ) -> None:
         """Compute and log test metrics."""
-        del trainer, pl_module  # Unused argument.
+        del pl_module  # Unused argument.
         for metric in self.test_metrics:
             self.log(metric.name, metric)
+            # In barebones mode, logging is disabled. We manually update trainer metrics
+            # to ensure they're available in both callback_metrics and the test() return value.
+            if trainer.barebones:
+                metric_value = metric.compute()
+                trainer.callback_metrics[metric.name] = metric_value
+                trainer.logged_metrics[metric.name] = metric_value
 
     def metrics_to_cpu(self, metrics: Metric | list[Metric] | ModuleList) -> None:
         """Set the compute_on_cpu attribute of the metrics to True."""
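To see why this manual bookkeeping recovers the values that `self.log()` would normally publish, here is a hedged sketch using `torchmetrics` directly, with plain dicts in place of `trainer.callback_metrics` and `trainer.logged_metrics`:

```python
# Sketch of the manual metric bookkeeping above; plain dicts stand in for
# trainer.callback_metrics / trainer.logged_metrics.
import torch
from torchmetrics.classification import BinaryAUROC

metric = BinaryAUROC()
metric.update(torch.tensor([0.1, 0.4, 0.8, 0.9]), torch.tensor([0, 0, 1, 1]))

callback_metrics: dict[str, torch.Tensor] = {}
logged_metrics: dict[str, torch.Tensor] = {}

# What the callback does when trainer.barebones is True and self.log() is a no-op:
value = metric.compute()
callback_metrics["image_AUROC"] = value
logged_metrics["image_AUROC"] = value

print(callback_metrics)  # {'image_AUROC': tensor(1.)}
```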
71 changes: 71 additions & 0 deletions tests/unit/engine/test_engine.py
@@ -12,6 +12,9 @@
 from anomalib.engine import Engine
 from anomalib.models import Padim
 
+# Tolerance threshold for comparing metric values between normal and barebones modes
+METRIC_TOLERANCE = 0.01
+
 
 class TestEngine:
     """Test Engine."""
@@ -119,3 +122,71 @@ def test_from_config(fxt_full_config_path: Path) -> None:
         engine, model, datamodule = Engine.from_config(config_path=fxt_full_config_path, **override_kwargs)
         assert datamodule.train_batch_size == 1
         assert datamodule.num_workers == 1
+
+    @staticmethod
+    def test_barebones_mode_metrics_and_checkpointing(tmp_path: Path) -> None:
+        """Test that barebones mode returns the same metrics and disables checkpointing.
+
+        This test verifies that:
+        1. Barebones mode and normal mode return the same metric values
+        2. Both modes return the same set of metric keys
+        3. Metrics are properly captured in barebones mode despite logging being disabled
+        4. Normal mode (barebones=False) creates checkpoint files
+        5. Barebones mode (barebones=True) does not create checkpoint files
+        """
+        from lightning import seed_everything
+
+        datamodule = MVTecAD(category="toothbrush")
+
+        # Test with normal mode
+        seed_everything(42, workers=True)
+        model_normal = Padim()
+        engine_normal = Engine(default_root_dir=tmp_path / "normal")
+        engine_normal.fit(model=model_normal, datamodule=datamodule)
+        results_normal = engine_normal.test(model=model_normal, datamodule=datamodule)
+
+        # Test with barebones mode
+        seed_everything(42, workers=True)
+        model_barebones = Padim()
+        engine_barebones = Engine(default_root_dir=tmp_path / "barebones", barebones=True)
+        engine_barebones.fit(model=model_barebones, datamodule=datamodule)
+        results_barebones = engine_barebones.test(model=model_barebones, datamodule=datamodule)
+
+        # Verify both modes return results
+        assert results_normal
+        assert results_barebones
+        assert len(results_normal) > 0
+        assert len(results_barebones) > 0
+
+        # Extract metrics
+        metrics_normal = results_normal[0]
+        metrics_barebones = results_barebones[0]
+
+        # Verify both have the same metric keys
+        assert set(metrics_normal.keys()) == set(metrics_barebones.keys())
+
+        # Verify expected metrics are present
+        expected_metrics = {"image_AUROC", "image_F1Score", "pixel_AUROC", "pixel_F1Score"}
+        assert expected_metrics.issubset(set(metrics_normal.keys()))
+
+        # Verify metric values are the same
+        for metric_name in metrics_normal:
+            value_normal = metrics_normal[metric_name]
+            value_barebones = metrics_barebones[metric_name]
+
+            if hasattr(value_normal, "item"):
+                value_normal = value_normal.item()
+            if hasattr(value_barebones, "item"):
+                value_barebones = value_barebones.item()
+
+            assert abs(value_normal - value_barebones) < METRIC_TOLERANCE
+
+        # Verify checkpoint behavior
+        normal_checkpoints = list((tmp_path / "normal").rglob("*.ckpt"))
+        barebones_checkpoints = list((tmp_path / "barebones").rglob("*.ckpt"))
+
+        # Verify normal mode (barebones=False) creates checkpoints
+        assert len(normal_checkpoints) > 0, "Normal mode (barebones=False) should create checkpoint files"
+
+        # Verify barebones mode (barebones=True) does not create checkpoints
+        assert len(barebones_checkpoints) == 0, "Barebones mode (barebones=True) should not create checkpoint files"