Safetensors#27

Merged
wenh06 merged 33 commits into master from safetensors
Mar 11, 2026

@wenh06 wenh06 commented Sep 25, 2025

This PR mainly changes the default behavior of the save method of the CkptMixin class: by default it now uses the save_file method from safetensors instead of torch.save. See the comparison of the model saving mechanisms. The save method now has the following signature:

    def save(
        self,
        path: Union[str, bytes, os.PathLike],
        train_config: CFG,
        extra_items: Optional[dict] = None,
        use_safetensors: bool = True,
        safetensors_single_file: bool = True,
    ) -> None:
        """Save the model to disk.

        .. note::

            `safetensors` is used by default to save the model.
            If one wants to save the models in `.pth` or `.pt` format,
            he/she must explicitly set ``use_safetensors=False``.

        Parameters
        ----------
        path : `path-like`
            Path to save the model.
        train_config : CFG
            Config for training the model,
            used when one restores the model.
        extra_items : dict, optional
            Extra items to save along with the model.
            The values should be serializable: can be saved as a json file,
            or is a dict of torch tensors.

            .. versionadded:: 0.0.32
        use_safetensors : bool, default True
            Whether to use `safetensors` to save the model.
            This will be overridden by the suffix of `path`:
            if it is `.safetensors`, then `use_safetensors` is set to True;
            if it is `.pth` or `.pt`, then if `use_safetensors` is True,
            the suffix is changed to `.safetensors`, otherwise it is unchanged.

            .. versionadded:: 0.0.32
        safetensors_single_file : bool, default True
            Whether to save the metadata along with the state dict into one file.

            .. versionadded:: 0.0.32

        Returns
        -------
        None

        """
        ...

This change is backward compatible. Models can still be saved in pth/pt format as before by explicitly setting use_safetensors=False, and the load method still loads pth/pt format models correctly.
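
To make the suffix rule in the docstring above concrete, here is a minimal sketch of how the path and the `use_safetensors` flag could be reconciled. The helper name `resolve_save_target` is hypothetical and is not part of torch_ecg; leaving a path with any other suffix (or none) untouched is an assumption, since the docstring does not cover that case.

```python
import os
from pathlib import Path
from typing import Tuple, Union


def resolve_save_target(
    path: Union[str, bytes, os.PathLike], use_safetensors: bool = True
) -> Tuple[Path, bool]:
    """Return the (possibly renamed) path and the effective ``use_safetensors`` flag.

    Hypothetical helper: it only mirrors the rule stated in the docstring.
    """
    # os.fsdecode accepts str, bytes and os.PathLike and always returns str.
    target = Path(os.fsdecode(path))
    suffix = target.suffix.lower()
    if suffix == ".safetensors":
        # A ``.safetensors`` suffix always selects safetensors.
        return target, True
    if suffix in (".pth", ".pt") and use_safetensors:
        # Legacy suffix, but safetensors is requested: swap the suffix.
        return target.with_suffix(".safetensors"), True
    # Legacy suffix with ``use_safetensors=False``, or any other path
    # (assumption: left untouched).
    return target, use_safetensors
```

Under these assumptions, `resolve_save_target("model.pth")` yields `model.safetensors` with the flag set, while `resolve_save_target("model.pt", use_safetensors=False)` keeps the legacy file name.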

Copilot AI review requested due to automatic review settings September 25, 2025 14:09

Copilot AI left a comment

Pull Request Overview

This PR integrates the safetensors library as the default serialization mechanism for the CkptMixin class, replacing torch.save and torch.load. The change maintains backward compatibility while enabling more secure model serialization.

  • Changes default behavior of the save method to use safetensors format
  • Adds comprehensive support for both single-file and directory-based safetensors formats
  • Implements fallback mechanisms for loading legacy PyTorch checkpoint formats

Reviewed Changes

Copilot reviewed 31 out of 32 changed files in this pull request and generated 7 comments.

| File | Description |
| --- | --- |
| torch_ecg/utils/utils_nn.py | Major update to CkptMixin with safetensors integration and type annotations |
| torch_ecg/utils/misc.py | Minor type annotation fix for make_serializable function |
| torch_ecg/utils/download.py | Bytes handling fix for path processing |
| torch_ecg/models/grad_cam.py | Tensor operation ordering fix (detach before cpu) |
| torch_ecg/models/ecg_crnn.py | Type annotation additions and import statement |
| torch_ecg/models/_nets.py | Type annotation improvements across various classes |
| torch_ecg/components/trainer.py | Updated save method calls and type annotations |
| pyproject.toml | Addition of safetensors dependency |

Comments suppressed due to low confidence (1)

torch_ecg/utils/utils_nn.py:1

  • This else clause sets output_shape to [None] but this appears to be a fallback case that may not be reachable given the preceding conditions. Consider adding a comment explaining when this case would occur or remove if it's truly unreachable.
"""

        "transposeconvolution",
    ]:
        out_channels = num_filters
    else:

Copilot AI Sep 25, 2025

The else clause raises an exception after all layer type checks. However, this code appears to be unreachable since all valid layer types should be handled by the preceding conditions. Consider removing this else clause or adding a comment explaining when this condition could be reached.

Suggested change
-else:
+else:
+    # This branch should be unreachable if all valid layer types are handled above.
+    # Retained as a safeguard in case an unknown or misspelled layer type is passed.

Copilot uses AI. Check for mistakes.
-save_filename = f"{self.save_prefix}_epoch{self.epoch}_{get_date_str()}_{save_suffix}.pth.tar"
-save_path = self.train_config.checkpoints / save_filename
 if self.train_config.keep_checkpoint_max != 0:
+    # save_filename = f"{self.save_prefix}_epoch{self.epoch}_{get_date_str()}_{save_suffix}.pth.tar"

Copilot AI Sep 25, 2025

The commented out line should be removed since the code has been updated to use folder-based saving instead of file-based saving. Keeping commented code can cause confusion about the intended behavior.

Suggested change
-# save_filename = f"{self.save_prefix}_epoch{self.epoch}_{get_date_str()}_{save_suffix}.pth.tar"

Comment on lines +282 to 294
+    # save_filename = f"BestModel_{self.save_prefix}{self.best_epoch}_{get_date_str()}_{save_suffix}.pth.tar"
+    save_folder = f"BestModel_{self.save_prefix}{self.best_epoch}_{get_date_str()}_{save_suffix}"
+    save_path = self.train_config.model_dir / save_folder  # type: ignore
     # self.save_checkpoint(path=str(save_path))
     self._model.save(path=str(save_path), train_config=self.train_config)
-    self.log_manager.log_message(f"best model is saved at {save_path}")
-elif self.train_config.monitor is None:
-    self.log_manager.log_message("no monitor is set, the last model is selected and saved as the best model")
+    self.log_manager.log_message(f"best model is saved at {save_path}")  # type: ignore
+elif self.train_config.monitor is None:  # type: ignore
+    self.log_manager.log_message("no monitor is set, the last model is selected and saved as the best model")  # type: ignore
     self.best_state_dict = self._model.state_dict()
-    save_filename = f"BestModel_{self.save_prefix}{self.epoch}_{get_date_str()}.pth.tar"
-    save_path = self.train_config.model_dir / save_filename
+    # save_filename = f"BestModel_{self.save_prefix}{self.epoch}_{get_date_str()}.pth.tar"
+    save_folder = f"BestModel_{self.save_prefix}{self.epoch}_{get_date_str()}"
+    save_path = self.train_config.model_dir / save_folder  # type: ignore
     # self.save_checkpoint(path=str(save_path))

Copilot AI Sep 25, 2025

Similar to the previous comment, this commented out line should be removed to avoid confusion about the current implementation approach.

Comment on lines +282 to 294

Copilot AI Sep 25, 2025

Another commented out line that should be removed to maintain clean code without obsolete references to the old file-based approach.

Comment on lines +781 to +782
    if not str(path).endswith(".pth.tar"):
        path = Path(path).with_suffix(".pth.tar")  # type: ignore

Copilot AI Sep 25, 2025

The type ignore comment suggests potential type mismatch issues. Consider using proper type handling instead of ignoring type checking, especially since Path.with_suffix() should work correctly with string inputs.
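
One possible way to avoid the type ignore, sketched under assumptions: since the save method accepts `Union[str, bytes, os.PathLike]`, the mismatch may come from bytes paths, which `pathlib.Path` rejects with a `TypeError`. Decoding first with `os.fsdecode` sidesteps that. The helper name `ensure_pth_tar` is hypothetical and not part of torch_ecg.

```python
import os
from pathlib import Path
from typing import Union


def ensure_pth_tar(path: Union[str, bytes, os.PathLike]) -> Path:
    """Return ``path`` as a ``Path`` that ends with ``.pth.tar`` (hypothetical helper)."""
    # os.fsdecode turns bytes (and os.PathLike) into str, which Path accepts.
    path_str = os.fsdecode(path)
    if path_str.endswith(".pth.tar"):
        # Guard first: with_suffix only replaces the LAST suffix, so calling it
        # on "ckpt.pth.tar" would produce "ckpt.pth.pth.tar".
        return Path(path_str)
    return Path(path_str).with_suffix(".pth.tar")
```

The `endswith` guard is kept from the original snippet because `with_suffix` treats only `.tar` as the suffix of `ckpt.pth.tar`.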

wenh06 and others added 2 commits September 25, 2025 22:15
Add a check to ensure the model has at least one module before the loop

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Update the error message for failing to load a safetensors file

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@wenh06 wenh06 added the enhancement New feature or request label Sep 25, 2025
@wenh06 wenh06 self-assigned this Sep 25, 2025

codecov bot commented Sep 27, 2025

Codecov Report

❌ Patch coverage is 91.48230% with 77 lines in your changes missing coverage. Please review.
✅ Project coverage is 93.52%. Comparing base (3d8008e) to head (deb315e).
⚠️ Report is 43 commits behind head on master.
✅ All tests successful. No failed tests found.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| torch_ecg/utils/misc.py | 87.41% | 18 Missing ⚠️ |
| torch_ecg/components/trainer.py | 79.74% | 16 Missing ⚠️ |
| torch_ecg/utils/utils_nn.py | 91.21% | 13 Missing ⚠️ |
| torch_ecg/models/ecg_crnn.py | 88.05% | 8 Missing ⚠️ |
| torch_ecg/utils/download.py | 91.66% | 8 Missing ⚠️ |
| torch_ecg/databases/base.py | 84.78% | 7 Missing ⚠️ |
| torch_ecg/utils/utils_data.py | 86.20% | 4 Missing ⚠️ |
| torch_ecg/models/_nets.py | 96.07% | 2 Missing ⚠️ |
| torch_ecg/databases/physionet_databases/ludb.py | 96.87% | 1 Missing ⚠️ |

Additional details and impacted files
@@            Coverage Diff             @@
##           master      #27      +/-   ##
==========================================
+ Coverage   92.86%   93.52%   +0.66%     
==========================================
  Files         138      134       -4     
  Lines       19020    18252     -768     
==========================================
- Hits        17663    17071     -592     
+ Misses       1357     1181     -176     

@wenh06 wenh06 enabled auto-merge October 7, 2025 08:39

    def side_effect_get(url, *args, **kwargs):
        parsed = urllib.parse.urlparse(url)
        if parsed.netloc.endswith("dropbox.com"):

Check failure

Code scanning / CodeQL

Incomplete URL substring sanitization High test

The string dropbox.com may be at an arbitrary position in the sanitized URL.

Copilot Autofix

AI 4 days ago

In general, to fix incomplete URL substring sanitization, you should parse the URL, extract the hostname, and compare it exactly to allowed hostnames or verify it matches a well‑defined subdomain pattern (for example, endswith(".dropbox.com") or equals "dropbox.com"), rather than using simple substring or naive suffix checks on the full URL or netloc.

For this specific code, we should change the logic in side_effect_get and mock_head to perform strict host checks. Since urllib.parse.urlparse(url) is already used, we can rely on parsed.hostname, which normalizes the hostname and excludes the port. We then compare hostname against explicit allowed domains. For Dropbox, only www.dropbox.com is used in tests, but to be robust we can allow both "dropbox.com" and "www.dropbox.com" using an equality check. For GitHub, we should similarly check hostname equality against "github.com" or "raw.githubusercontent.com". For the Google branch in side_effect_get, we can check hostname equality to "google.com" or the hosts actually used for Google Drive in the tests (they use drive.google.com URLs but those are handled by a separate Google Drive mock, so we can leave that logic as‑is or keep it as strict "google.com"). The key is to replace parsed.netloc.endswith(...) with explicit hostname comparisons.

Concretely:

  • In test_http_get, within side_effect_get, replace:
    • if parsed.netloc.endswith("dropbox.com"): with a check on hostname equal to "dropbox.com" or "www.dropbox.com".
    • elif parsed.netloc.endswith(("github.com", "raw.githubusercontent.com")): with equality checks on hostname in a tuple of allowed hosts.
    • elif parsed.netloc.endswith("google.com"): with an equality check hostname == "google.com" (or a similar precise condition).
  • In test_url_is_reachable, within mock_head, replace if parsed.netloc.endswith("dropbox.com"): with an equality check on parsed.hostname in ("dropbox.com", "www.dropbox.com").

No new imports are needed; urllib.parse is already imported and used.
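
The general rule described above (exact hostname match, or a suffix that includes the leading dot) can be sketched as follows. The helper name `host_matches` is hypothetical and does not appear in the test file; `domain` is assumed to be given in lowercase.

```python
from urllib.parse import urlparse


def host_matches(url: str, domain: str) -> bool:
    """True if the URL's hostname is ``domain`` or one of its subdomains."""
    # ``hostname`` is lowercased and excludes userinfo and port; it is None
    # when the URL has no network location, hence the fallback to "".
    hostname = urlparse(url).hostname or ""
    # The leading dot keeps "evil-dropbox.com" from matching "dropbox.com".
    return hostname == domain or hostname.endswith("." + domain)
```

With this check, `https://www.dropbox.com/...` matches `dropbox.com`, while `https://evil-dropbox.com/...` does not, even though its netloc passes a bare `endswith("dropbox.com")` test.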


Suggested changeset 1
test/test_utils/test_download.py

Autofix patch

Run the following command in your local git repository to apply this patch
cat << 'EOF' | git apply
diff --git a/test/test_utils/test_download.py b/test/test_utils/test_download.py
--- a/test/test_utils/test_download.py
+++ b/test/test_utils/test_download.py
@@ -81,11 +81,12 @@
 
     def side_effect_get(url, *args, **kwargs):
         parsed = urllib.parse.urlparse(url)
-        if parsed.netloc.endswith("dropbox.com"):
+        hostname = parsed.hostname or ""
+        if hostname in ("dropbox.com", "www.dropbox.com"):
             return mock_get_dropbox()
-        elif parsed.netloc.endswith(("github.com", "raw.githubusercontent.com")):
+        elif hostname in ("github.com", "raw.githubusercontent.com"):
             return mock_get_text()
-        elif parsed.netloc.endswith("google.com"):
+        elif hostname == "google.com":
             # Let the google drive test fail naturally or handle it if it makes requests
             return MockResponse(status_code=404)
         return original_get(url, *args, **kwargs)
@@ -216,7 +215,8 @@
                 self.status_code = status_code
 
         parsed = urllib.parse.urlparse(url)
-        if parsed.netloc.endswith("dropbox.com"):
+        hostname = parsed.hostname or ""
+        if hostname in ("dropbox.com", "www.dropbox.com"):
             return MockResponse(200)
         return MockResponse(404)
 
EOF

Copilot is powered by AI and may make mistakes. Always verify output.

            return mock_get_dropbox()
        elif parsed.netloc.endswith(("github.com", "raw.githubusercontent.com")):
            return mock_get_text()
        elif parsed.netloc.endswith("google.com"):

Check failure

Code scanning / CodeQL

Incomplete URL substring sanitization High test

The string google.com may be at an arbitrary position in the sanitized URL.

Copilot Autofix

AI 4 days ago

In general, to fix incomplete URL substring sanitization when checking hosts, you should parse the URL, extract the hostname, and then compare it either to an explicit allowlist of exact hostnames or to controlled suffixes that include a leading dot to represent subdomains (e.g., .example.com). This avoids treating evil-example.com as if it were example.com and prevents accidental matches when the trusted domain appears as a substring or suffix in an unrelated hostname.

In this specific test code, the problematic pattern is parsed.netloc.endswith("google.com") on line 88. To preserve intended behavior while avoiding incomplete sanitization, we can change this to either (a) an exact match on "drive.google.com" (which is what the test URLs use) or (b) a safe subdomain check like host == "google.com" or host.endswith(".google.com"). Since the test only needs to recognize Google Drive URLs, the narrowest and most precise change is to check for "drive.google.com" explicitly. Concretely, we should replace the condition elif parsed.netloc.endswith("google.com"): with an exact equality check elif parsed.netloc == "drive.google.com":. No new imports or helper methods are required; we simply modify this single line within test_http_get in test/test_utils/test_download.py.
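
One caveat with an exact comparison on `parsed.netloc`: netloc keeps userinfo, the port, and the original letter case, whereas `parsed.hostname` is the lowercased host only. The sketch below illustrates the difference; the function name `is_google_drive_url` and the sample URLs are made up for illustration.

```python
from urllib.parse import urlparse


def is_google_drive_url(url: str) -> bool:
    """Exact hostname check (hypothetical helper, not from the test file)."""
    return urlparse(url).hostname == "drive.google.com"


# Hypothetical URL with userinfo, mixed case and an explicit port.
sample = "https://user@Drive.Google.com:443/uc?id=FILE_ID"
parsed = urlparse(sample)

netloc = parsed.netloc      # "user@Drive.Google.com:443"
hostname = parsed.hostname  # "drive.google.com"
```

For the plain `https://drive.google.com/...` URLs used in the tests both comparisons agree, so the choice only matters if ports or credentials ever appear in the URL.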

Suggested changeset 1
test/test_utils/test_download.py

Autofix patch

Run the following command in your local git repository to apply this patch
cat << 'EOF' | git apply
diff --git a/test/test_utils/test_download.py b/test/test_utils/test_download.py
--- a/test/test_utils/test_download.py
+++ b/test/test_utils/test_download.py
@@ -85,7 +85,7 @@
             return mock_get_dropbox()
         elif parsed.netloc.endswith(("github.com", "raw.githubusercontent.com")):
             return mock_get_text()
-        elif parsed.netloc.endswith("google.com"):
+        elif parsed.netloc == "drive.google.com":
             # Let the google drive test fail naturally or handle it if it makes requests
             return MockResponse(status_code=404)
         return original_get(url, *args, **kwargs)
EOF
            self.status_code = status_code

    parsed = urllib.parse.urlparse(url)
    if parsed.netloc.endswith("dropbox.com"):

Check failure

Code scanning / CodeQL

Incomplete URL substring sanitization High test

The string dropbox.com may be at an arbitrary position in the sanitized URL.

Copilot Autofix

AI 4 days ago

In general, the fix is to avoid using substring/endswith checks directly on the host unless you ensure you are matching a whole hostname or a controlled suffix with a leading dot. For a single host, the simplest robust approach is to compare parsed.hostname directly to the expected hostname (for example "www.dropbox.com"), or, if subdomains should be allowed, ensure hostname == "example.com" or hostname.endswith(".example.com").

In this specific test, we only care about treating the concrete URL https://www.dropbox.com/... as reachable. We can safely tighten the condition to check for the exact hostname. Concretely, in test_url_is_reachable in test/test_utils/test_download.py, change the if parsed.netloc.endswith("dropbox.com"): line to:

if parsed.hostname == "www.dropbox.com":
    return MockResponse(200)

This keeps the test semantics (the tested URL has hostname www.dropbox.com) while avoiding an endswith on the raw netloc. No new imports are needed, as urllib.parse is already imported at the top of the file. No other lines need to change.
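
For reference, a self-contained sketch of the tightened mock, runnable outside the test suite. `MockResponse` here is a minimal stand-in written for this example and may differ from the class defined in test/test_utils/test_download.py.

```python
from urllib.parse import urlparse


class MockResponse:
    """Minimal stand-in for a ``requests`` response (only ``status_code``)."""

    def __init__(self, status_code: int) -> None:
        self.status_code = status_code


def mock_head(url: str, *args, **kwargs) -> MockResponse:
    """Report only URLs hosted on www.dropbox.com as reachable."""
    if urlparse(url).hostname == "www.dropbox.com":
        return MockResponse(200)
    return MockResponse(404)
```

Lookalike hosts such as `notdropbox.com` or `www.dropbox.com.evil.example` now get a 404 from the mock; the first of these would have passed the original `endswith` check.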

Suggested changeset 1
test/test_utils/test_download.py

Autofix patch

Run the following command in your local git repository to apply this patch
cat << 'EOF' | git apply
diff --git a/test/test_utils/test_download.py b/test/test_utils/test_download.py
--- a/test/test_utils/test_download.py
+++ b/test/test_utils/test_download.py
@@ -216,7 +216,7 @@
                 self.status_code = status_code
 
         parsed = urllib.parse.urlparse(url)
-        if parsed.netloc.endswith("dropbox.com"):
+        if parsed.hostname == "www.dropbox.com":
             return MockResponse(200)
         return MockResponse(404)
 
EOF
         match=f"Can't instantiate abstract class {_DataBase.__name__}",
     ):
-        db = _DataBase()
+        db = _DataBase()  # type: ignore[abstract]

         match=f"Can't instantiate abstract class {PhysioNetDataBase.__name__}",
     ):
-        db = PhysioNetDataBase()
+        db = PhysioNetDataBase()  # type: ignore[abstract]

         match=f"Can't instantiate abstract class {NSRRDataBase.__name__}",
     ):
-        db = NSRRDataBase()
+        db = NSRRDataBase()  # type: ignore[abstract]

         match=f"Can't instantiate abstract class {CPSCDataBase.__name__}",
     ):
-        db = CPSCDataBase()
+        db = CPSCDataBase()  # type: ignore[abstract]

         match=f"Can't instantiate abstract class {_DataBase.__name__}",
     ):
-        db = _DataBase()
+        db = _DataBase()  # type: ignore[abstract]

         match=f"Can't instantiate abstract class {PhysioNetDataBase.__name__}",
     ):
-        db = PhysioNetDataBase()
+        db = PhysioNetDataBase()  # type: ignore[abstract]

         match=f"Can't instantiate abstract class {NSRRDataBase.__name__}",
     ):
-        db = NSRRDataBase()
+        db = NSRRDataBase()  # type: ignore[abstract]

@wenh06 wenh06 disabled auto-merge March 11, 2026 01:37
@wenh06 wenh06 merged commit 0004fb4 into master Mar 11, 2026
10 of 13 checks passed
@wenh06 wenh06 deleted the safetensors branch March 11, 2026 01:37

Labels

enhancement New feature or request

2 participants