From d47f4bb367d8ebfc8ee6bf98b4516bf295c5f3e4 Mon Sep 17 00:00:00 2001 From: Amadeusz Szymko Date: Fri, 6 Mar 2026 20:27:55 +0900 Subject: [PATCH 1/6] fix(PTv3): checkpoint load for testing Signed-off-by: Amadeusz Szymko --- projects/PTv3/engines/test.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/projects/PTv3/engines/test.py b/projects/PTv3/engines/test.py index 0bf7a6caf..f06e989bc 100644 --- a/projects/PTv3/engines/test.py +++ b/projects/PTv3/engines/test.py @@ -70,13 +70,12 @@ def build_model(self): checkpoint = torch.load(self.cfg.weight, weights_only=False) weight = OrderedDict() for key, value in checkpoint["state_dict"].items(): - if not key.startswith("module."): - key = "module." + key # xxx.xxx -> module.xxx.xxx - # Now all keys contain "module." no matter DDP or not. - if self.keywords in key: - key = key.replace(self.keywords, self.replacement, 1) - if comm.get_world_size() == 1: - key = key[7:] # module.xxx.xxx -> xxx.xxx + if key.startswith("module."): + if comm.get_world_size() == 1: + key = key[7:] # module.xxx.xxx -> xxx.xxx + else: + if comm.get_world_size() > 1: + key = "module." + key # xxx.xxx -> module.xxx.xxx weight[key] = value model.load_state_dict(weight, strict=True) self.logger.info("=> Loaded weight '{}' (epoch {})".format(self.cfg.weight, checkpoint["epoch"])) From 784c7e2474875bb8f939cf54a5befbca50befbc2 Mon Sep 17 00:00:00 2001 From: Amadeusz Szymko Date: Mon, 9 Mar 2026 12:45:34 +0900 Subject: [PATCH 2/6] fix(PTv3): set spconv2 Signed-off-by: Amadeusz Szymko --- projects/PTv3/engines/hooks/misc.py | 9 ++++++++- projects/PTv3/engines/test.py | 6 ++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/projects/PTv3/engines/hooks/misc.py b/projects/PTv3/engines/hooks/misc.py index 2c9e66da0..9e30de189 100644 --- a/projects/PTv3/engines/hooks/misc.py +++ b/projects/PTv3/engines/hooks/misc.py @@ -13,10 +13,11 @@ import torch import utils.comm as comm -from engines.test import TESTERS from utils.comm import is_main_process from utils.timer import Timer +from engines.test import TESTERS + from .builder import HOOKS from .default import HookBase @@ -213,6 +214,12 @@ def before_train(self): if comm.get_world_size() == 1: key = key[7:] # module.xxx.xxx -> xxx.xxx weight[key] = value + + # Preserve metadata and force spconv module version to 2 for layout-aware loading. + weight._metadata = getattr(checkpoint["state_dict"], "_metadata", OrderedDict()) + for meta in weight._metadata.values(): + if isinstance(meta, dict): + meta["version"] = 2 load_state_info = self.trainer.model.load_state_dict(weight, strict=self.strict) self.trainer.logger.info(f"Missing keys: {load_state_info[0]}") if self.trainer.cfg.resume: diff --git a/projects/PTv3/engines/test.py b/projects/PTv3/engines/test.py index f06e989bc..87c5ac430 100644 --- a/projects/PTv3/engines/test.py +++ b/projects/PTv3/engines/test.py @@ -77,6 +77,12 @@ def build_model(self): if comm.get_world_size() > 1: key = "module." + key # xxx.xxx -> module.xxx.xxx weight[key] = value + + # Keep state_dict metadata and force spconv module version to 2. + weight._metadata = getattr(checkpoint["state_dict"], "_metadata", OrderedDict()) + for meta in weight._metadata.values(): + if isinstance(meta, dict): + meta["version"] = 2 model.load_state_dict(weight, strict=True) self.logger.info("=> Loaded weight '{}' (epoch {})".format(self.cfg.weight, checkpoint["epoch"])) else: From 93f19bca5679616229cc5b82080f183b4d75b7c2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Mar 2026 03:49:24 +0000 Subject: [PATCH 3/6] ci(pre-commit): autofix --- projects/PTv3/engines/hooks/misc.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/projects/PTv3/engines/hooks/misc.py b/projects/PTv3/engines/hooks/misc.py index 9e30de189..d4fbbf872 100644 --- a/projects/PTv3/engines/hooks/misc.py +++ b/projects/PTv3/engines/hooks/misc.py @@ -13,11 +13,10 @@ import torch import utils.comm as comm +from engines.test import TESTERS from utils.comm import is_main_process from utils.timer import Timer -from engines.test import TESTERS - from .builder import HOOKS from .default import HookBase From 6b363a6a62c0b07b9e40b1f6fe69305518f30b72 Mon Sep 17 00:00:00 2001 From: Amadeusz Szymko Date: Mon, 9 Mar 2026 15:09:21 +0900 Subject: [PATCH 4/6] fix(PTv3): simplification & adjust comment Signed-off-by: Amadeusz Szymko --- projects/PTv3/engines/hooks/misc.py | 8 ++++---- projects/PTv3/engines/test.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/projects/PTv3/engines/hooks/misc.py b/projects/PTv3/engines/hooks/misc.py index d4fbbf872..ba404a3b4 100644 --- a/projects/PTv3/engines/hooks/misc.py +++ b/projects/PTv3/engines/hooks/misc.py @@ -214,11 +214,11 @@ def before_train(self): key = key[7:] # module.xxx.xxx -> xxx.xxx weight[key] = value - # Preserve metadata and force spconv module version to 2 for layout-aware loading. + # A missing `_metadata` dict, e.g. from checkpoints not generated by mmdet3d, causes SparseModule._load_from_state_dict to incorrectly permute + # the weights (it assumes spconv v1 format when `_metadata` is missing). + # The fix: force `_metadata` to be present to ensure compatibility with all modern non-mmdet3d model checkpoints using spconv. weight._metadata = getattr(checkpoint["state_dict"], "_metadata", OrderedDict()) - for meta in weight._metadata.values(): - if isinstance(meta, dict): - meta["version"] = 2 + weight._metadata["version"] = 2 load_state_info = self.trainer.model.load_state_dict(weight, strict=self.strict) self.trainer.logger.info(f"Missing keys: {load_state_info[0]}") if self.trainer.cfg.resume: diff --git a/projects/PTv3/engines/test.py b/projects/PTv3/engines/test.py index 87c5ac430..de44172a5 100644 --- a/projects/PTv3/engines/test.py +++ b/projects/PTv3/engines/test.py @@ -78,11 +78,11 @@ def build_model(self): key = "module." + key # xxx.xxx -> module.xxx.xxx weight[key] = value - # Keep state_dict metadata and force spconv module version to 2. + # A missing `_metadata` dict, e.g. from checkpoints not generated by mmdet3d, causes SparseModule._load_from_state_dict to incorrectly permute + # the weights (it assumes spconv v1 format when `_metadata` is missing). + # The fix: force `_metadata` to be present to ensure compatibility with all modern non-mmdet3d model checkpoints using spconv. weight._metadata = getattr(checkpoint["state_dict"], "_metadata", OrderedDict()) - for meta in weight._metadata.values(): - if isinstance(meta, dict): - meta["version"] = 2 + weight._metadata["version"] = 2 model.load_state_dict(weight, strict=True) self.logger.info("=> Loaded weight '{}' (epoch {})".format(self.cfg.weight, checkpoint["epoch"])) else: From 3578a47ae8d5d45acf15b43effb65853201c9706 Mon Sep 17 00:00:00 2001 From: Amadeusz Szymko Date: Mon, 9 Mar 2026 19:00:15 +0900 Subject: [PATCH 5/6] fix(PTv3): remove mmdetection3d's spconv patch by patch Signed-off-by: Amadeusz Szymko --- .patches/spconv.patch | 82 +++++++++++++++++++++++++++++ Dockerfile | 1 + projects/PTv3/engines/hooks/misc.py | 9 +--- projects/PTv3/engines/test.py | 6 --- 4 files changed, 85 insertions(+), 13 deletions(-) create mode 100644 .patches/spconv.patch diff --git a/.patches/spconv.patch b/.patches/spconv.patch new file mode 100644 index 000000000..112ee9a6f --- /dev/null +++ b/.patches/spconv.patch @@ -0,0 +1,82 @@ +--- /mmdet3d/models/layers/spconv/overwrite_spconv/write_spconv2.py ++++ /mmdet3d/models/layers/spconv/overwrite_spconv/write_spconv2.py.corrected +@@ -1,9 +1,6 @@ + # Copyright (c) OpenMMLab. All rights reserved. +-import itertools +-from typing import List, OrderedDict + + from mmengine.registry import MODELS +-from torch.nn.parameter import Parameter + + + def register_spconv2() -> bool: +@@ -44,69 +41,4 @@ def register_spconv2() -> bool: + MODELS._register_module(SubMConv3d, 'SubMConv3d', force=True) + MODELS._register_module(SubMConv4d, 'SubMConv4d', force=True) + SparseModule._version = 2 +- SparseModule._load_from_state_dict = _load_from_state_dict + return True +- +- +-def _load_from_state_dict(self, state_dict: OrderedDict, prefix: str, +- local_metadata: dict, strict: bool, +- missing_keys: List[str], unexpected_keys: List[str], +- error_msgs: List[str]) -> None: +- """Rewrite this func to compat the convolutional kernel weights between +- spconv 1.x in MMCV and 2.x in spconv2.x. +- +- Kernel weights in MMCV spconv has shape in (D,H,W,in_channel,out_channel) , +- while those in spcon2.x is in (out_channel,D,H,W,in_channel). +- """ +- version = local_metadata.get('version', None) +- for hook in self._load_state_dict_pre_hooks.values(): +- hook(state_dict, prefix, local_metadata, strict, missing_keys, +- unexpected_keys, error_msgs) +- +- local_name_params = itertools.chain(self._parameters.items(), +- self._buffers.items()) +- local_state = {k: v.data for k, v in local_name_params if v is not None} +- +- for name, param in local_state.items(): +- key = prefix + name +- if key in state_dict: +- input_param = state_dict[key] +- +- # Backward compatibility: loading 1-dim tensor from +- # 0.3.* to version 0.4+ +- if len(param.shape) == 0 and len(input_param.shape) == 1: +- input_param = input_param[0] +- if version != 2: +- dims = [len(input_param.shape) - 1] + list( +- range(len(input_param.shape) - 1)) +- input_param = input_param.permute(*dims) +- if input_param.shape != param.shape: +- # local shape should match the one in checkpoint +- error_msgs.append( +- f'size mismatch for {key}: copying a param with ' +- f'shape {key, input_param.shape} from checkpoint,' +- f'the shape in current model is {param.shape}.') +- continue +- +- if isinstance(input_param, Parameter): +- # backwards compatibility for serialized parameters +- input_param = input_param.data +- try: +- param.copy_(input_param) +- except Exception: +- error_msgs.append( +- f'While copying the parameter named "{key}", whose ' +- f'dimensions in the model are {param.size()} and whose ' +- f'dimensions in the checkpoint are {input_param.size()}.') +- elif strict: +- missing_keys.append(key) +- +- if strict: +- for key, input_param in state_dict.items(): +- if key.startswith(prefix): +- input_name = key[len(prefix):] +- input_name = input_name.split( +- '.', 1)[0] # get the name of param/buffer/child +- if input_name not in self._modules \ +- and input_name not in local_state: +- unexpected_keys.append(key) diff --git a/Dockerfile b/Dockerfile index 582582c76..c837c4904 100644 --- a/Dockerfile +++ b/Dockerfile @@ -77,6 +77,7 @@ COPY .patches/mmengine.patch /tmp/mmengine.patch RUN cd $(python -c "import site; print(site.getsitepackages()[0])") \ && git apply < /tmp/mmdet3d.patch \ && git apply < /tmp/mmengine.patch \ + && git apply < /tmp/spconv.patch \ && rm -rf /tmp/* \ && cd / diff --git a/projects/PTv3/engines/hooks/misc.py b/projects/PTv3/engines/hooks/misc.py index ba404a3b4..48509e204 100644 --- a/projects/PTv3/engines/hooks/misc.py +++ b/projects/PTv3/engines/hooks/misc.py @@ -13,10 +13,11 @@ import torch import utils.comm as comm -from engines.test import TESTERS from utils.comm import is_main_process from utils.timer import Timer +from engines.test import TESTERS + from .builder import HOOKS from .default import HookBase @@ -213,12 +214,6 @@ def before_train(self): if comm.get_world_size() == 1: key = key[7:] # module.xxx.xxx -> xxx.xxx weight[key] = value - - # A missing `_metadata` dict, e.g. from checkpoints not generated by mmdet3d, causes SparseModule._load_from_state_dict to incorrectly permute - # the weights (it assumes spconv v1 format when `_metadata` is missing). - # The fix: force `_metadata` to be present to ensure compatibility with all modern non-mmdet3d model checkpoints using spconv. - weight._metadata = getattr(checkpoint["state_dict"], "_metadata", OrderedDict()) - weight._metadata["version"] = 2 load_state_info = self.trainer.model.load_state_dict(weight, strict=self.strict) self.trainer.logger.info(f"Missing keys: {load_state_info[0]}") if self.trainer.cfg.resume: diff --git a/projects/PTv3/engines/test.py b/projects/PTv3/engines/test.py index de44172a5..f06e989bc 100644 --- a/projects/PTv3/engines/test.py +++ b/projects/PTv3/engines/test.py @@ -77,12 +77,6 @@ def build_model(self): if comm.get_world_size() > 1: key = "module." + key # xxx.xxx -> module.xxx.xxx weight[key] = value - - # A missing `_metadata` dict, e.g. from checkpoints not generated by mmdet3d, causes SparseModule._load_from_state_dict to incorrectly permute - # the weights (it assumes spconv v1 format when `_metadata` is missing). - # The fix: force `_metadata` to be present to ensure compatibility with all modern non-mmdet3d model checkpoints using spconv. - weight._metadata = getattr(checkpoint["state_dict"], "_metadata", OrderedDict()) - weight._metadata["version"] = 2 model.load_state_dict(weight, strict=True) self.logger.info("=> Loaded weight '{}' (epoch {})".format(self.cfg.weight, checkpoint["epoch"])) else: From acb32d22564bb1f825d221f6eef53178b3910b23 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Mar 2026 10:00:52 +0000 Subject: [PATCH 6/6] ci(pre-commit): autofix --- projects/PTv3/engines/hooks/misc.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/projects/PTv3/engines/hooks/misc.py b/projects/PTv3/engines/hooks/misc.py index 48509e204..2c9e66da0 100644 --- a/projects/PTv3/engines/hooks/misc.py +++ b/projects/PTv3/engines/hooks/misc.py @@ -13,11 +13,10 @@ import torch import utils.comm as comm +from engines.test import TESTERS from utils.comm import is_main_process from utils.timer import Timer -from engines.test import TESTERS - from .builder import HOOKS from .default import HookBase