Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions .patches/spconv.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
--- /mmdet3d/models/layers/spconv/overwrite_spconv/write_spconv2.py
+++ /mmdet3d/models/layers/spconv/overwrite_spconv/write_spconv2.py.corrected
@@ -1,9 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
-import itertools
-from typing import List, OrderedDict

from mmengine.registry import MODELS
-from torch.nn.parameter import Parameter


def register_spconv2() -> bool:
@@ -44,69 +41,4 @@ def register_spconv2() -> bool:
MODELS._register_module(SubMConv3d, 'SubMConv3d', force=True)
MODELS._register_module(SubMConv4d, 'SubMConv4d', force=True)
SparseModule._version = 2
- SparseModule._load_from_state_dict = _load_from_state_dict
return True
-
-
-def _load_from_state_dict(self, state_dict: OrderedDict, prefix: str,
- local_metadata: dict, strict: bool,
- missing_keys: List[str], unexpected_keys: List[str],
- error_msgs: List[str]) -> None:
- """Rewrite this func to compat the convolutional kernel weights between
- spconv 1.x in MMCV and 2.x in spconv2.x.
-
- Kernel weights in MMCV spconv has shape in (D,H,W,in_channel,out_channel) ,
- while those in spcon2.x is in (out_channel,D,H,W,in_channel).
- """
- version = local_metadata.get('version', None)
- for hook in self._load_state_dict_pre_hooks.values():
- hook(state_dict, prefix, local_metadata, strict, missing_keys,
- unexpected_keys, error_msgs)
-
- local_name_params = itertools.chain(self._parameters.items(),
- self._buffers.items())
- local_state = {k: v.data for k, v in local_name_params if v is not None}
-
- for name, param in local_state.items():
- key = prefix + name
- if key in state_dict:
- input_param = state_dict[key]
-
- # Backward compatibility: loading 1-dim tensor from
- # 0.3.* to version 0.4+
- if len(param.shape) == 0 and len(input_param.shape) == 1:
- input_param = input_param[0]
- if version != 2:
- dims = [len(input_param.shape) - 1] + list(
- range(len(input_param.shape) - 1))
- input_param = input_param.permute(*dims)
- if input_param.shape != param.shape:
- # local shape should match the one in checkpoint
- error_msgs.append(
- f'size mismatch for {key}: copying a param with '
- f'shape {key, input_param.shape} from checkpoint,'
- f'the shape in current model is {param.shape}.')
- continue
-
- if isinstance(input_param, Parameter):
- # backwards compatibility for serialized parameters
- input_param = input_param.data
- try:
- param.copy_(input_param)
- except Exception:
- error_msgs.append(
- f'While copying the parameter named "{key}", whose '
- f'dimensions in the model are {param.size()} and whose '
- f'dimensions in the checkpoint are {input_param.size()}.')
- elif strict:
- missing_keys.append(key)
-
- if strict:
- for key, input_param in state_dict.items():
- if key.startswith(prefix):
- input_name = key[len(prefix):]
- input_name = input_name.split(
- '.', 1)[0] # get the name of param/buffer/child
- if input_name not in self._modules \
- and input_name not in local_state:
- unexpected_keys.append(key)
1 change: 1 addition & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ COPY .patches/mmengine.patch /tmp/mmengine.patch
RUN cd $(python -c "import site; print(site.getsitepackages()[0])") \
&& git apply < /tmp/mmdet3d.patch \
&& git apply < /tmp/mmengine.patch \
&& git apply < /tmp/spconv.patch \
&& rm -rf /tmp/* \
&& cd /

Expand Down
13 changes: 6 additions & 7 deletions projects/PTv3/engines/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,12 @@ def build_model(self):
checkpoint = torch.load(self.cfg.weight, weights_only=False)
weight = OrderedDict()
for key, value in checkpoint["state_dict"].items():
if not key.startswith("module."):
key = "module." + key # xxx.xxx -> module.xxx.xxx
# Now all keys contain "module." no matter DDP or not.
if self.keywords in key:
key = key.replace(self.keywords, self.replacement, 1)
if comm.get_world_size() == 1:
key = key[7:] # module.xxx.xxx -> xxx.xxx
if key.startswith("module."):
if comm.get_world_size() == 1:
key = key[7:] # module.xxx.xxx -> xxx.xxx
else:
if comm.get_world_size() > 1:
key = "module." + key # xxx.xxx -> module.xxx.xxx
weight[key] = value
model.load_state_dict(weight, strict=True)
self.logger.info("=> Loaded weight '{}' (epoch {})".format(self.cfg.weight, checkpoint["epoch"]))
Expand Down