Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions .github/workflows/auto-label.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Auto-apply the "pre-commit" label to pull requests opened by the
# pre-commit.ci bot.  The release workflow's changelog builder excludes
# PRs carrying this label, keeping bot noise out of release notes.
name: Auto Label

on:
  pull_request:

jobs:
  label:
    runs-on: ubuntu-latest
    steps:
      - name: Label pre-commit PR
        # Only fires when the PR author is the pre-commit.ci bot account.
        if: github.actor == 'pre-commit-ci[bot]'
        uses: actions-ecosystem/action-add-labels@v1
        with:
          github_token: ${{ secrets.GITHUB_TOKEN }}
          labels: pre-commit
32 changes: 31 additions & 1 deletion .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,39 @@ jobs:
uses: actions/checkout@v4
# Build the release notes from merged PRs, grouped into sections by label.
# (Diff artifact fix: keep only the final pinned version of the action, v6;
# the superseded v4 line from the diff view is dropped.)
- name: Build Changelog
  id: changelog
  uses: mikepenz/release-changelog-builder-action@v6
  env:
    GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
  with:
    # Inline JSON configuration: category sections by PR label, an
    # "Other changes" fold-out for uncategorized PRs, and exclusion of
    # pre-commit bot PRs (labeled by the Auto Label workflow).
    configurationJson: |
      {
        "template": "# Changelog\n\n{{CHANGELOG}}\n\n<details>\n<summary>📦 Other changes</summary>\n\n{{UNCATEGORIZED}}\n\n</details>",
        "categories": [
          {
            "title": "🚀 Features",
            "labels": ["feature","enhancement"]
          },
          {
            "title": "🐛 Bug Fixes",
            "labels": ["bug","fix"]
          },
          {
            "title": "📚 Documentation",
            "labels": ["documentation","docs"]
          },
          {
            "title": "⬆️ Dependencies",
            "labels": ["dependencies"]
          },
          {
            "title": "🧰 Maintenance",
            "labels": ["chore","ci"]
          }
        ],
        "exclude_labels": [
          "pre-commit"
        ]
      }
- name: Create Release
id: create_release
uses: softprops/action-gh-release@v2
Expand Down
15 changes: 15 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,17 @@ Versioning <https://semver.org/spec/v2.0.0.html>`__.
Added
~~~~~

- Implement a generic `Registry` class and establish registries for backbones, models,
attention layers, preprocessors, augmenters, optimizers, schedulers, and losses.
- Add `forward_features` and `compute_features_output_shape` methods to all CNN
backbones to provide a standardized API for SSL and feature extraction.

Changed
~~~~~~~

- Refactor `ECG_CRNN`, `PreprocManager`, `AugmenterManager`, and `BaseTrainer` to
utilize the new registry system for dynamic component construction and decoupling.
- Enhance `SizeMixin` to support static shape inference for feature maps.
- Make the function `remove_spikes_naive` in `torch_ecg.utils.utils_signal`
support 2D and 3D input signals.
- Use `save_file` and `load_file` from the `safetensors` package for saving
Expand All @@ -36,6 +44,13 @@ Removed
Fixed
~~~~~

- Robustly handle dimension inference and initialization in `ECG_CRNN` models,
especially for cases with `None` or `Identity` modules.
- Address several compatibility issues for Python 3.13, including docstring
indentation and `NaN` comparisons in dataclasses.
- Improve error handling and encoding robustness in `CitationMixin` when
reading cache files.
- Resolve CodeQL warnings regarding incomplete URL substring sanitization in tests.
- Correctly update the `_df_metadata` attribute of the `PTBXL` database reader
classes after filtering records.
- Enhance the `save` method of the `torch_ecg.utils.utils_nn.CkptMixin` class:
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,10 @@ omit = [
"torch_ecg/models/grad_cam.py",
# temporarily ignore torch_ecg/ssl since it's not implemented
"torch_ecg/ssl/*",
# temporarily ignore models that are not implemented completely
"torch_ecg/models/cnn/darknet.py",
"torch_ecg/models/cnn/efficientnet.py",
"torch_ecg/models/cnn/ho_resnet.py",
]
exclude_also = [
"raise NotImplementedError",
Expand Down
82 changes: 82 additions & 0 deletions test/test_models/test_backbone_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""
Unit tests for the standardized Backbone API.
"""

from copy import deepcopy

import pytest
import torch

from torch_ecg.model_configs import ECG_CRNN_CONFIG
from torch_ecg.models.registry import BACKBONES

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _resolve_backbone_name(config_key, config_val):
    """Return the registered backbone name for a CNN config entry, or ``None``.

    Resolution order: an explicit ``name`` field in the config dict, then the
    config key itself if registered, then any registered name that prefixes
    the config key (e.g. ``resnet_nature_comm`` -> ``resnet``).
    """
    name = config_val.get("name")
    if name is not None:
        return name
    if config_key in BACKBONES:
        return config_key
    # Fall back to prefix matching against all registered backbone names.
    return next(
        (registered for registered in BACKBONES.list_all() if config_key.startswith(registered)),
        None,
    )


# Collect (backbone_name, config_key, config) triples for every dict-valued
# entry of the central CNN config that maps to a registered backbone.
BACKBONE_CONFIGS = [
    (_resolve_backbone_name(key, val), key, val)
    for key, val in ECG_CRNN_CONFIG.cnn.items()
    if isinstance(val, dict) and _resolve_backbone_name(key, val)
]


@pytest.mark.parametrize("backbone_name, config_key, config", BACKBONE_CONFIGS)
def test_backbone_api(backbone_name, config_key, config):
    """Check one registered backbone against the standardized feature API."""
    n_leads, batch_size, seq_len = 12, 2, 2000

    # Build the backbone.  Not-yet-implemented models (Phase 1.5) are skipped
    # rather than failing; any other construction error is a real failure.
    try:
        model = BACKBONES.build(backbone_name, in_channels=n_leads, **deepcopy(config)).to(DEVICE)
    except NotImplementedError:
        pytest.skip(f"Backbone {backbone_name} (config: {config_key}) is not implemented yet.")
    except Exception as e:
        pytest.fail(f"Failed to build backbone {backbone_name} with config {config_key}: {e}")

    model.eval()
    waveform = torch.randn(batch_size, n_leads, seq_len).to(DEVICE)

    # The standardized API requires a forward_features method ...
    assert hasattr(model, "forward_features"), f"Backbone {backbone_name} missing forward_features"

    # ... that yields a 3D (batch, channels, seq_len) feature map.
    features = model.forward_features(waveform)
    assert features.ndim == 3, f"Backbone {backbone_name} forward_features should return 3D tensor, got {features.ndim}D"
    assert features.shape[0] == batch_size

    # compute_features_output_shape uses the (seq_len, batch_size) signature
    # shared by all torch_ecg backbones; it must agree with the actual feature
    # map on channel count, and on sequence length when statically known.
    expected_shape = model.compute_features_output_shape(seq_len, batch_size)
    assert (
        features.shape[1] == expected_shape[1]
    ), f"Backbone {backbone_name} feature channels mismatch: {features.shape[1]} vs {expected_shape[1]}"
    if expected_shape[2] is not None:
        assert (
            features.shape[2] == expected_shape[2]
        ), f"Backbone {backbone_name} feature seq_len mismatch: {features.shape[2]} vs {expected_shape[2]}"

    # For a pure feature extractor, plain forward must equal forward_features.
    full_output = model(waveform)
    assert torch.allclose(full_output, features), f"Backbone {backbone_name} forward and forward_features results differ"


# Allow running this test module directly (outside a plain pytest invocation).
if __name__ == "__main__":
    pytest.main([__file__])
26 changes: 19 additions & 7 deletions torch_ecg/models/cnn/darknet.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,15 @@
5. Wang, C. Y., Bochkovskiy, A., & Liao, H. Y. M. (2020). Scaled-YOLOv4: Scaling Cross Stage Partial Network. arXiv preprint arXiv:2011.08036.
"""

from typing import List
from typing import List, Optional, Sequence, Union

from torch import nn
from torch import Tensor, nn

from ...models._nets import Conv_Bn_Activation, DownSample, GlobalContextBlock, NonLocalBlock, SEBlock # noqa: F401
from ...utils import CitationMixin, SizeMixin

__all__ = [
"DarkNet",
]


class DarkNet(nn.Sequential, SizeMixin, CitationMixin):
class DarkNet(SizeMixin, nn.Sequential, CitationMixin):
""" """

__name__ = "DarkNet"
Expand All @@ -32,6 +28,22 @@ def __init__(self, in_channels: int, **config) -> None:
super().__init__()
raise NotImplementedError

def compute_output_shape(
    self, seq_len: Optional[int] = None, batch_size: Optional[int] = None
) -> Sequence[Union[int, None]]:
    """Compute the output shape of the model.

    Raises
    ------
    NotImplementedError
        DarkNet is not implemented yet; this stub only reserves the
        standardized backbone API signature.
    """
    raise NotImplementedError

def forward_features(self, input: Tensor) -> Tensor:
    """Forward pass of the model to extract features.

    Raises
    ------
    NotImplementedError
        DarkNet is not implemented yet; this stub only reserves the
        standardized backbone API signature.
    """
    raise NotImplementedError

def compute_features_output_shape(
    self, seq_len: Optional[int] = None, batch_size: Optional[int] = None
) -> Sequence[Union[int, None]]:
    """Compute the output shape of the features.

    Raises
    ------
    NotImplementedError
        DarkNet is not implemented yet; this stub only reserves the
        standardized backbone API signature.
    """
    raise NotImplementedError

@property
def doi(self) -> List[str]:
    # Merge any DOIs supplied via the config with the default CVPR 2016 DOI.
    # NOTE(review): the set() round-trip deduplicates but makes the returned
    # list order nondeterministic — confirm no caller relies on ordering.
    return list(set(self.config.get("doi", []) + ["10.1109/CVPR.2016.91"]))
38 changes: 38 additions & 0 deletions torch_ecg/models/cnn/densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,44 @@ def compute_output_shape(
"""Compute the output shape of the network."""
return compute_sequential_output_shape(self, seq_len, batch_size)

def forward_features(self, input: Tensor) -> Tensor:
    """Extract the feature map of the input signal.

    This backbone is a pure feature extractor, so feature extraction
    simply delegates to :meth:`forward`.

    Parameters
    ----------
    input : torch.Tensor
        Input signal tensor,
        of shape ``(batch_size, channels, seq_len)``.

    Returns
    -------
    features : torch.Tensor
        Feature map tensor,
        of shape ``(batch_size, channels, seq_len)``.

    """
    features = self.forward(input)
    return features

def compute_features_output_shape(
    self, seq_len: Optional[int] = None, batch_size: Optional[int] = None
) -> Sequence[Union[int, None]]:
    """Compute the output shape of the features.

    Since the whole backbone acts as the feature extractor, the feature
    shape coincides with the network's output shape.

    Parameters
    ----------
    seq_len : int, optional
        Length of the input signal tensor.
    batch_size : int, optional
        Batch size of the input signal tensor.

    Returns
    -------
    output_shape : sequence
        Output shape of the features.

    """
    features_shape = self.compute_output_shape(seq_len, batch_size)
    return features_shape

@property
def in_channels(self) -> int:
    # Number of input channels expected by the network; read from the
    # name-mangled private attribute, presumably assigned in __init__
    # (not visible in this chunk — confirm against the full class).
    return self.__in_channels
Expand Down
40 changes: 30 additions & 10 deletions torch_ecg/models/cnn/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,13 @@

"""

from typing import List
from typing import List, Optional, Sequence, Union

from torch import nn
from torch import Tensor, nn

from ...models._nets import Conv_Bn_Activation, DownSample, GlobalContextBlock, NonLocalBlock, SEBlock # noqa: F401
from ...utils import CitationMixin, SizeMixin

__all__ = [
"EfficientNet",
]


class EfficientNet(nn.Module, SizeMixin, CitationMixin):
"""
Expand All @@ -38,10 +34,22 @@ def __init__(self, in_channels: int, **config) -> None:
super().__init__()
raise NotImplementedError

def forward(self):
def forward(self, input: Tensor) -> Tensor:
    """Forward pass (not implemented; EfficientNet is an API placeholder)."""
    raise NotImplementedError

def compute_output_shape(
    self, seq_len: Optional[int] = None, batch_size: Optional[int] = None
) -> Sequence[Union[int, None]]:
    """Compute the output shape of the model.

    Raises
    ------
    NotImplementedError
        EfficientNet is not implemented yet; this stub only reserves the
        standardized backbone API signature.
    """
    raise NotImplementedError

def forward_features(self, input: Tensor) -> Tensor:
    """Forward pass of the model to extract features.

    Raises
    ------
    NotImplementedError
        EfficientNet is not implemented yet; this stub only reserves the
        standardized backbone API signature.
    """
    raise NotImplementedError

def compute_output_shape(self):
def compute_features_output_shape(
    self, seq_len: Optional[int] = None, batch_size: Optional[int] = None
) -> Sequence[Union[int, None]]:
    """Compute the output shape of the features.

    Raises
    ------
    NotImplementedError
        EfficientNet is not implemented yet; this stub only reserves the
        standardized backbone API signature.
    """
    raise NotImplementedError

@property
Expand All @@ -67,8 +75,20 @@ def __init__(self, in_channels: int, **config) -> None:
super().__init__()
raise NotImplementedError

def forward(self):
def forward(self, input: Tensor) -> Tensor:
    """Forward pass (not implemented; EfficientNetV2 is an API placeholder)."""
    raise NotImplementedError

def compute_output_shape(
    self, seq_len: Optional[int] = None, batch_size: Optional[int] = None
) -> Sequence[Union[int, None]]:
    """Compute the output shape of the model.

    Raises
    ------
    NotImplementedError
        EfficientNetV2 is not implemented yet; this stub only reserves the
        standardized backbone API signature.
    """
    raise NotImplementedError

def forward_features(self, input: Tensor) -> Tensor:
    """Forward pass of the model to extract features.

    Raises
    ------
    NotImplementedError
        EfficientNetV2 is not implemented yet; this stub only reserves the
        standardized backbone API signature.
    """
    raise NotImplementedError

def compute_output_shape(self):
def compute_features_output_shape(
    self, seq_len: Optional[int] = None, batch_size: Optional[int] = None
) -> Sequence[Union[int, None]]:
    """Compute the output shape of the features.

    Raises
    ------
    NotImplementedError
        EfficientNetV2 is not implemented yet; this stub only reserves the
        standardized backbone API signature.
    """
    raise NotImplementedError
Loading
Loading