diff --git a/.patches/spconv.patch b/.patches/spconv.patch
new file mode 100644
index 000000000..112ee9a6f
--- /dev/null
+++ b/.patches/spconv.patch
@@ -0,0 +1,82 @@
+--- /mmdet3d/models/layers/spconv/overwrite_spconv/write_spconv2.py
++++ /mmdet3d/models/layers/spconv/overwrite_spconv/write_spconv2.py.corrected
+@@ -1,9 +1,6 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+-import itertools
+-from typing import List, OrderedDict
+ 
+ from mmengine.registry import MODELS
+-from torch.nn.parameter import Parameter
+ 
+ 
+ def register_spconv2() -> bool:
+@@ -44,69 +41,4 @@ def register_spconv2() -> bool:
+         MODELS._register_module(SubMConv3d, 'SubMConv3d', force=True)
+         MODELS._register_module(SubMConv4d, 'SubMConv4d', force=True)
+         SparseModule._version = 2
+-        SparseModule._load_from_state_dict = _load_from_state_dict
+         return True
+-
+-
+-def _load_from_state_dict(self, state_dict: OrderedDict, prefix: str,
+-                          local_metadata: dict, strict: bool,
+-                          missing_keys: List[str], unexpected_keys: List[str],
+-                          error_msgs: List[str]) -> None:
+-    """Rewrite this func to compat the convolutional kernel weights between
+-    spconv 1.x in MMCV and 2.x in spconv2.x.
+-
+-    Kernel weights in MMCV spconv has shape in (D,H,W,in_channel,out_channel) ,
+-    while those in spcon2.x is in (out_channel,D,H,W,in_channel).
+-    """
+-    version = local_metadata.get('version', None)
+-    for hook in self._load_state_dict_pre_hooks.values():
+-        hook(state_dict, prefix, local_metadata, strict, missing_keys,
+-             unexpected_keys, error_msgs)
+-
+-    local_name_params = itertools.chain(self._parameters.items(),
+-                                        self._buffers.items())
+-    local_state = {k: v.data for k, v in local_name_params if v is not None}
+-
+-    for name, param in local_state.items():
+-        key = prefix + name
+-        if key in state_dict:
+-            input_param = state_dict[key]
+-
+-            # Backward compatibility: loading 1-dim tensor from
+-            # 0.3.* to version 0.4+
+-            if len(param.shape) == 0 and len(input_param.shape) == 1:
+-                input_param = input_param[0]
+-            if version != 2:
+-                dims = [len(input_param.shape) - 1] + list(
+-                    range(len(input_param.shape) - 1))
+-                input_param = input_param.permute(*dims)
+-            if input_param.shape != param.shape:
+-                # local shape should match the one in checkpoint
+-                error_msgs.append(
+-                    f'size mismatch for {key}: copying a param with '
+-                    f'shape {key, input_param.shape} from checkpoint,'
+-                    f'the shape in current model is {param.shape}.')
+-                continue
+-
+-            if isinstance(input_param, Parameter):
+-                # backwards compatibility for serialized parameters
+-                input_param = input_param.data
+-            try:
+-                param.copy_(input_param)
+-            except Exception:
+-                error_msgs.append(
+-                    f'While copying the parameter named "{key}", whose '
+-                    f'dimensions in the model are {param.size()} and whose '
+-                    f'dimensions in the checkpoint are {input_param.size()}.')
+-        elif strict:
+-            missing_keys.append(key)
+-
+-    if strict:
+-        for key, input_param in state_dict.items():
+-            if key.startswith(prefix):
+-                input_name = key[len(prefix):]
+-                input_name = input_name.split(
+-                    '.', 1)[0]  # get the name of param/buffer/child
+-                if input_name not in self._modules \
+-                        and input_name not in local_state:
+-                    unexpected_keys.append(key)
diff --git a/Dockerfile b/Dockerfile
index 582582c76..c837c4904 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -77,6 +77,8 @@
 COPY .patches/mmengine.patch /tmp/mmengine.patch
+COPY .patches/spconv.patch /tmp/spconv.patch
 RUN cd $(python -c "import site; print(site.getsitepackages()[0])") \
     && git apply < /tmp/mmdet3d.patch \
     && git apply < /tmp/mmengine.patch \
+    && git apply < /tmp/spconv.patch \
     && rm -rf /tmp/* \
     && cd /
diff --git a/projects/PTv3/engines/test.py b/projects/PTv3/engines/test.py
index 0bf7a6caf..f06e989bc 100644
--- a/projects/PTv3/engines/test.py
+++ b/projects/PTv3/engines/test.py
@@ -70,13 +70,12 @@ def build_model(self):
         checkpoint = torch.load(self.cfg.weight, weights_only=False)
         weight = OrderedDict()
         for key, value in checkpoint["state_dict"].items():
-            if not key.startswith("module."):
-                key = "module." + key  # xxx.xxx -> module.xxx.xxx
-            # Now all keys contain "module." no matter DDP or not.
-            if self.keywords in key:
-                key = key.replace(self.keywords, self.replacement, 1)
-            if comm.get_world_size() == 1:
-                key = key[7:]  # module.xxx.xxx -> xxx.xxx
+            if key.startswith("module."):
+                if comm.get_world_size() == 1:
+                    key = key[7:]  # module.xxx.xxx -> xxx.xxx
+            else:
+                if comm.get_world_size() > 1:
+                    key = "module." + key  # xxx.xxx -> module.xxx.xxx
             weight[key] = value
         model.load_state_dict(weight, strict=True)
         self.logger.info("=> Loaded weight '{}' (epoch {})".format(self.cfg.weight, checkpoint["epoch"]))