Conversation
There was a problem hiding this comment.
Pull Request Overview
This PR integrates the safetensors library as the default serialization mechanism for the CkptMixin class, replacing torch.save and torch.load. The change maintains backward compatibility while enabling more secure model serialization.
- Changes default behavior of the save method to use
safetensorsformat - Adds comprehensive support for both single-file and directory-based safetensors formats
- Implements fallback mechanisms for loading legacy PyTorch checkpoint formats
Reviewed Changes
Copilot reviewed 31 out of 32 changed files in this pull request and generated 7 comments.
Show a summary per file
| File | Description |
|---|---|
| torch_ecg/utils/utils_nn.py | Major update to CkptMixin with safetensors integration and type annotations |
| torch_ecg/utils/misc.py | Minor type annotation fix for make_serializable function |
| torch_ecg/utils/download.py | Bytes handling fix for path processing |
| torch_ecg/models/grad_cam.py | Tensor operation ordering fix (detach before cpu) |
| torch_ecg/models/ecg_crnn.py | Type annotation additions and import statement |
| torch_ecg/models/_nets.py | Type annotation improvements across various classes |
| torch_ecg/components/trainer.py | Updated save method calls and type annotations |
| pyproject.toml | Addition of safetensors dependency |
Comments suppressed due to low confidence (1)
torch_ecg/utils/utils_nn.py:1
- This else clause sets output_shape to [None] but this appears to be a fallback case that may not be reachable given the preceding conditions. Consider adding a comment explaining when this case would occur or remove if it's truly unreachable.
"""
Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.
torch_ecg/utils/utils_nn.py
Outdated
| "transposeconvolution", | ||
| ]: | ||
| out_channels = num_filters | ||
| else: |
There was a problem hiding this comment.
The else clause raises an exception after all layer type checks. However, this code appears to be unreachable since all valid layer types should be handled by the preceding conditions. Consider removing this else clause or adding a comment explaining when this condition could be reached.
| else: | |
| else: | |
| # This branch should be unreachable if all valid layer types are handled above. | |
| # Retained as a safeguard in case an unknown or misspelled layer type is passed. |
torch_ecg/components/trainer.py
Outdated
| save_filename = f"{self.save_prefix}_epoch{self.epoch}_{get_date_str()}_{save_suffix}.pth.tar" | ||
| save_path = self.train_config.checkpoints / save_filename | ||
| if self.train_config.keep_checkpoint_max != 0: | ||
| # save_filename = f"{self.save_prefix}_epoch{self.epoch}_{get_date_str()}_{save_suffix}.pth.tar" |
There was a problem hiding this comment.
The commented out line should be removed since the code has been updated to use folder-based saving instead of file-based saving. Keeping commented code can cause confusion about the intended behavior.
| # save_filename = f"{self.save_prefix}_epoch{self.epoch}_{get_date_str()}_{save_suffix}.pth.tar" |
| # save_filename = f"BestModel_{self.save_prefix}{self.best_epoch}_{get_date_str()}_{save_suffix}.pth.tar" | ||
| save_folder = f"BestModel_{self.save_prefix}{self.best_epoch}_{get_date_str()}_{save_suffix}" | ||
| save_path = self.train_config.model_dir / save_folder # type: ignore | ||
| # self.save_checkpoint(path=str(save_path)) | ||
| self._model.save(path=str(save_path), train_config=self.train_config) | ||
| self.log_manager.log_message(f"best model is saved at {save_path}") | ||
| elif self.train_config.monitor is None: | ||
| self.log_manager.log_message("no monitor is set, the last model is selected and saved as the best model") | ||
| self.log_manager.log_message(f"best model is saved at {save_path}") # type: ignore | ||
| elif self.train_config.monitor is None: # type: ignore | ||
| self.log_manager.log_message("no monitor is set, the last model is selected and saved as the best model") # type: ignore | ||
| self.best_state_dict = self._model.state_dict() | ||
| save_filename = f"BestModel_{self.save_prefix}{self.epoch}_{get_date_str()}.pth.tar" | ||
| save_path = self.train_config.model_dir / save_filename | ||
| # save_filename = f"BestModel_{self.save_prefix}{self.epoch}_{get_date_str()}.pth.tar" | ||
| save_folder = f"BestModel_{self.save_prefix}{self.epoch}_{get_date_str()}" | ||
| save_path = self.train_config.model_dir / save_folder # type: ignore | ||
| # self.save_checkpoint(path=str(save_path)) |
There was a problem hiding this comment.
Similar to the previous comment, this commented out line should be removed to avoid confusion about the current implementation approach.
| # save_filename = f"BestModel_{self.save_prefix}{self.best_epoch}_{get_date_str()}_{save_suffix}.pth.tar" | ||
| save_folder = f"BestModel_{self.save_prefix}{self.best_epoch}_{get_date_str()}_{save_suffix}" | ||
| save_path = self.train_config.model_dir / save_folder # type: ignore | ||
| # self.save_checkpoint(path=str(save_path)) | ||
| self._model.save(path=str(save_path), train_config=self.train_config) | ||
| self.log_manager.log_message(f"best model is saved at {save_path}") | ||
| elif self.train_config.monitor is None: | ||
| self.log_manager.log_message("no monitor is set, the last model is selected and saved as the best model") | ||
| self.log_manager.log_message(f"best model is saved at {save_path}") # type: ignore | ||
| elif self.train_config.monitor is None: # type: ignore | ||
| self.log_manager.log_message("no monitor is set, the last model is selected and saved as the best model") # type: ignore | ||
| self.best_state_dict = self._model.state_dict() | ||
| save_filename = f"BestModel_{self.save_prefix}{self.epoch}_{get_date_str()}.pth.tar" | ||
| save_path = self.train_config.model_dir / save_filename | ||
| # save_filename = f"BestModel_{self.save_prefix}{self.epoch}_{get_date_str()}.pth.tar" | ||
| save_folder = f"BestModel_{self.save_prefix}{self.epoch}_{get_date_str()}" | ||
| save_path = self.train_config.model_dir / save_folder # type: ignore | ||
| # self.save_checkpoint(path=str(save_path)) |
There was a problem hiding this comment.
Another commented out line that should be removed to maintain clean code without obsolete references to the old file-based approach.
| if not str(path).endswith(".pth.tar"): | ||
| path = Path(path).with_suffix(".pth.tar") # type: ignore |
There was a problem hiding this comment.
The type ignore comment suggests potential type mismatch issues. Consider using proper type handling instead of ignoring type checking, especially since Path.with_suffix() should work correctly with string inputs.
Add a check to ensure the model has at least one module before the loop Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Update the error message for failing to load a safetensors file Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## master #27 +/- ##
==========================================
+ Coverage 92.86% 93.52% +0.66%
==========================================
Files 138 134 -4
Lines 19020 18252 -768
==========================================
- Hits 17663 17071 -592
+ Misses 1357 1181 -176 ☔ View full report in Codecov by Sentry. |
|
|
||
| def side_effect_get(url, *args, **kwargs): | ||
| parsed = urllib.parse.urlparse(url) | ||
| if parsed.netloc.endswith("dropbox.com"): |
Check failure
Code scanning / CodeQL
Incomplete URL substring sanitization High test
Show autofix suggestion
Hide autofix suggestion
Copilot Autofix
AI 4 days ago
In general, to fix incomplete URL substring sanitization, you should parse the URL, extract the hostname, and compare it exactly to allowed hostnames or verify it matches a well‑defined subdomain pattern (for example, endswith(".dropbox.com") or equals "dropbox.com"), rather than using simple substring or naive suffix checks on the full URL or netloc.
For this specific code, we should change the logic in side_effect_get and mock_head to perform strict host checks. Since urllib.parse.urlparse(url) is already used, we can rely on parsed.hostname, which normalizes the hostname and excludes the port. We then compare hostname against explicit allowed domains. For Dropbox, only www.dropbox.com is used in tests, but to be robust we can allow both "dropbox.com" and "www.dropbox.com" using an equality check. For GitHub, we should similarly check hostname equality against "github.com" or "raw.githubusercontent.com". For the Google branch in side_effect_get, we can check hostname equality to "google.com" or the hosts actually used for Google Drive in the tests (they use drive.google.com URLs but those are handled by a separate Google Drive mock, so we can leave that logic as‑is or keep it as strict "google.com"). The key is to replace parsed.netloc.endswith(...) with explicit hostname comparisons.
Concretely:
- In
test_http_get, withinside_effect_get, replace:if parsed.netloc.endswith("dropbox.com"):with a check onhostnameequal to"dropbox.com"or"www.dropbox.com".elif parsed.netloc.endswith(("github.com", "raw.githubusercontent.com")):with equality checks on hostname in a tuple of allowed hosts.elif parsed.netloc.endswith("google.com"):with an equality checkhostname == "google.com"(or a similar precise condition).
- In
test_url_is_reachable, withinmock_head, replaceif parsed.netloc.endswith("dropbox.com"):with an equality check onparsed.hostnamein("dropbox.com", "www.dropbox.com").
No new imports are needed; urllib.parse is already imported and used.
| @@ -81,11 +81,12 @@ | ||
|
|
||
| def side_effect_get(url, *args, **kwargs): | ||
| parsed = urllib.parse.urlparse(url) | ||
| if parsed.netloc.endswith("dropbox.com"): | ||
| hostname = parsed.hostname or "" | ||
| if hostname in ("dropbox.com", "www.dropbox.com"): | ||
| return mock_get_dropbox() | ||
| elif parsed.netloc.endswith(("github.com", "raw.githubusercontent.com")): | ||
| elif hostname in ("github.com", "raw.githubusercontent.com"): | ||
| return mock_get_text() | ||
| elif parsed.netloc.endswith("google.com"): | ||
| elif hostname == "google.com": | ||
| # Let the google drive test fail naturally or handle it if it makes requests | ||
| return MockResponse(status_code=404) | ||
| return original_get(url, *args, **kwargs) | ||
| @@ -216,7 +215,8 @@ | ||
| self.status_code = status_code | ||
|
|
||
| parsed = urllib.parse.urlparse(url) | ||
| if parsed.netloc.endswith("dropbox.com"): | ||
| hostname = parsed.hostname or "" | ||
| if hostname in ("dropbox.com", "www.dropbox.com"): | ||
| return MockResponse(200) | ||
| return MockResponse(404) | ||
|
|
| return mock_get_dropbox() | ||
| elif parsed.netloc.endswith(("github.com", "raw.githubusercontent.com")): | ||
| return mock_get_text() | ||
| elif parsed.netloc.endswith("google.com"): |
Check failure
Code scanning / CodeQL
Incomplete URL substring sanitization High test
Show autofix suggestion
Hide autofix suggestion
Copilot Autofix
AI 4 days ago
In general, to fix incomplete URL substring sanitization when checking hosts, you should parse the URL, extract the hostname, and then compare it either to an explicit allowlist of exact hostnames or to controlled suffixes that include a leading dot to represent subdomains (e.g., .example.com). This avoids treating evil-example.com as if it were example.com and prevents accidental matches when the trusted domain appears as a substring or suffix in an unrelated hostname.
In this specific test code, the problematic pattern is parsed.netloc.endswith("google.com") on line 88. To preserve intended behavior while avoiding incomplete sanitization, we can change this to either (a) an exact match on "drive.google.com" (which is what the test URLs use) or (b) a safe subdomain check like host == "google.com" or host.endswith(".google.com"). Since the test only needs to recognize Google Drive URLs, the narrowest and most precise change is to check for "drive.google.com" explicitly. Concretely, we should replace the condition elif parsed.netloc.endswith("google.com"): with an exact equality check elif parsed.netloc == "drive.google.com":. No new imports or helper methods are required; we simply modify this single line within test_http_get in test/test_utils/test_download.py.
| @@ -85,7 +85,7 @@ | ||
| return mock_get_dropbox() | ||
| elif parsed.netloc.endswith(("github.com", "raw.githubusercontent.com")): | ||
| return mock_get_text() | ||
| elif parsed.netloc.endswith("google.com"): | ||
| elif parsed.netloc == "drive.google.com": | ||
| # Let the google drive test fail naturally or handle it if it makes requests | ||
| return MockResponse(status_code=404) | ||
| return original_get(url, *args, **kwargs) |
| self.status_code = status_code | ||
|
|
||
| parsed = urllib.parse.urlparse(url) | ||
| if parsed.netloc.endswith("dropbox.com"): |
Check failure
Code scanning / CodeQL
Incomplete URL substring sanitization High test
Show autofix suggestion
Hide autofix suggestion
Copilot Autofix
AI 4 days ago
In general, the fix is to avoid using substring/endswith checks directly on the host unless you ensure you are matching a whole hostname or a controlled suffix with a leading dot. For a single host, the simplest robust approach is to compare parsed.hostname directly to the expected hostname (for example "www.dropbox.com"), or, if subdomains should be allowed, ensure hostname == "example.com" or hostname.endswith(".example.com").
In this specific test, we only care about treating the concrete URL https://www.dropbox.com/... as reachable. We can safely tighten the condition to check for the exact hostname. Concretely, in test_url_is_reachable in test/test_utils/test_download.py, change the if parsed.netloc.endswith("dropbox.com"): line to:
if parsed.hostname == "www.dropbox.com":
return MockResponse(200)This keeps the test semantics (the tested URL has hostname www.dropbox.com) while avoiding an endswith on the raw netloc. No new imports are needed, as urllib.parse is already imported at the top of the file. No other lines need to change.
| @@ -216,7 +216,7 @@ | ||
| self.status_code = status_code | ||
|
|
||
| parsed = urllib.parse.urlparse(url) | ||
| if parsed.netloc.endswith("dropbox.com"): | ||
| if parsed.hostname == "www.dropbox.com": | ||
| return MockResponse(200) | ||
| return MockResponse(404) | ||
|
|
| match=f"Can't instantiate abstract class {_DataBase.__name__}", | ||
| ): | ||
| db = _DataBase() | ||
| db = _DataBase() # type: ignore[abstract] |
| match=f"Can't instantiate abstract class {PhysioNetDataBase.__name__}", | ||
| ): | ||
| db = PhysioNetDataBase() | ||
| db = PhysioNetDataBase() # type: ignore[abstract] |
| match=f"Can't instantiate abstract class {NSRRDataBase.__name__}", | ||
| ): | ||
| db = NSRRDataBase() | ||
| db = NSRRDataBase() # type: ignore[abstract] |
| match=f"Can't instantiate abstract class {CPSCDataBase.__name__}", | ||
| ): | ||
| db = CPSCDataBase() | ||
| db = CPSCDataBase() # type: ignore[abstract] |
| match=f"Can't instantiate abstract class {_DataBase.__name__}", | ||
| ): | ||
| db = _DataBase() | ||
| db = _DataBase() # type: ignore[abstract] |
| match=f"Can't instantiate abstract class {PhysioNetDataBase.__name__}", | ||
| ): | ||
| db = PhysioNetDataBase() | ||
| db = PhysioNetDataBase() # type: ignore[abstract] |
| match=f"Can't instantiate abstract class {NSRRDataBase.__name__}", | ||
| ): | ||
| db = NSRRDataBase() | ||
| db = NSRRDataBase() # type: ignore[abstract] |
This PR typically changes the default behavior of the save method of the
CkptMixinclass. Now it uses thesave_filemethod fromsafetensorsinstead oftorch.saveby default. See the comparison of the model saving mechanisms. Thesavemethod now has the following signatureThis change is backward compatible. One is also able to save the models in
pth/ptformat like previously, by explicitly settinguse_safetensors=False. Theloadmethod is able to loadpth/ptformat models correctly.