From 8c8f46f7d2c23712375060df92b938c5a7d027c1 Mon Sep 17 00:00:00 2001 From: Hyungseok Shin Date: Tue, 30 Mar 2021 17:44:06 +0900 Subject: [PATCH 1/2] Add test option for pruned model - Can test pruned model using `--test-weight` option in `prune.py` - Formatting --- prune.py | 13 +++++++++++-- src/augmentation/methods.py | 4 +++- src/augmentation/policies.py | 8 ++++++-- src/criterions.py | 4 +++- src/models/adjmodule_getter.py | 6 +++--- src/models/densenet.py | 6 +++++- src/models/mixnet.py | 9 +++++++-- src/models/quant_densenet.py | 6 +++++- src/models/quant_mixnet.py | 3 ++- src/models/quant_resnet.py | 3 ++- src/models/resnet.py | 7 ++++--- src/models/utils.py | 12 ++++++++---- src/plotter.py | 10 ++++++---- src/regularizers.py | 3 ++- src/runners/pruner.py | 24 ++++++++++++++++++------ src/runners/shrinker.py | 7 +++++-- src/runners/trainer.py | 18 ++++++++++++++---- src/utils.py | 13 +++++++++---- 18 files changed, 113 insertions(+), 43 deletions(-) diff --git a/prune.py b/prune.py index 2828728..0b50deb 100644 --- a/prune.py +++ b/prune.py @@ -15,7 +15,10 @@ parser.add_argument("--multi-gpu", action="store_true", help="Multi-GPU use") parser.add_argument("--gpu", default=0, type=int, help="GPU id to use") parser.add_argument( - "--resume", type=str, default="", help="Input checkpoint directory name", + "--resume", + type=str, + default="", + help="Input checkpoint directory name", ) parser.add_argument( "--wlog", dest="wlog", action="store_true", help="Turns on wandb logging" @@ -26,6 +29,7 @@ default="config/prune/simplenet_kd.py", help="Configuration path", ) +parser.add_argument("--test-weight", default="", help="Weight filepath to test") parser.set_defaults(multi_gpu=False) parser.set_defaults(log=False) args = parser.parse_args() @@ -49,4 +53,9 @@ wandb_init_params=wandb_init_params, device=device, ) -pruner.run(args.resume) +if args.test_weight: + if not args.test_weight.startswith(args.resume): + raise Exception(f"{args.test_weight} from {args.resume} ?") + pruner.test(args.test_weight) +else: + pruner.run(args.resume) diff --git a/src/augmentation/methods.py b/src/augmentation/methods.py index 5aaa1de..675e45f 100644 --- a/src/augmentation/methods.py +++ b/src/augmentation/methods.py @@ -49,7 +49,9 @@ class SequentialAugmentation(Augmentation): """Sequential augmentation class.""" def __init__( - self, policies: List[Tuple[str, float, int]], n_level: int = 10, + self, + policies: List[Tuple[str, float, int]], + n_level: int = 10, ) -> None: """Initialize.""" super(SequentialAugmentation, self).__init__(n_level) diff --git a/src/augmentation/policies.py b/src/augmentation/policies.py index 56687de..0bc3f48 100644 --- a/src/augmentation/policies.py +++ b/src/augmentation/policies.py @@ -124,7 +124,9 @@ def autoaugment_train_cifar100_riair() -> transforms.Compose: def randaugment_train_cifar100( - n_select: int = 2, level: int = 14, n_level: int = 31, + n_select: int = 2, + level: int = 14, + n_level: int = 31, ) -> transforms.Compose: """Random augmentation policy for training CIFAR100.""" operators = [ @@ -156,7 +158,9 @@ def randaugment_train_cifar100( def randaugment_train_cifar100_224( - n_select: int = 2, level: int = 14, n_level: int = 31, + n_select: int = 2, + level: int = 14, + n_level: int = 31, ) -> transforms.Compose: operators = [ "Identity", diff --git a/src/criterions.py b/src/criterions.py index d1b0e98..1aea035 100644 --- a/src/criterions.py +++ b/src/criterions.py @@ -203,7 +203,9 @@ def add_label_smoothing(self, target: torch.Tensor) -> torch.Tensor: def get_criterion( - criterion_name: str, criterion_params: Dict[str, Any], device: torch.device, + criterion_name: str, + criterion_params: Dict[str, Any], + device: torch.device, ) -> nn.Module: """Create loss class.""" return eval(criterion_name)(device, **criterion_params) diff --git a/src/models/adjmodule_getter.py b/src/models/adjmodule_getter.py index 7b485e3..b524b8e 100644 --- a/src/models/adjmodule_getter.py +++ b/src/models/adjmodule_getter.py @@ -90,12 +90,12 @@ def backward_search(var: Any) -> None: graph[var].append(next_var[0]) backward_search(next_var[0]) - backward_search(out.grad_fn) + backward_search(out.grad_fn) # type: ignore return graph def _hook_fn(self, module: nn.Module, inp: torch.Tensor, out: torch.Tensor) -> None: - self.module_ahead[out.grad_fn] = module - self.op_behind[module] = out.grad_fn + self.module_ahead[out.grad_fn] = module # type: ignore + self.op_behind[module] = out.grad_fn # type: ignore if type(module) == nn.Flatten: # type: ignore self.last_conv_shape = inp[0].size()[-1] diff --git a/src/models/densenet.py b/src/models/densenet.py index 1c2ba0d..c3d3f36 100644 --- a/src/models/densenet.py +++ b/src/models/densenet.py @@ -24,7 +24,11 @@ class Bottleneck(nn.Module): """Bottleneck block for DenseNet.""" def __init__( - self, inplanes: int, expansion: int, growthRate: int, efficient: bool, + self, + inplanes: int, + expansion: int, + growthRate: int, + efficient: bool, ) -> None: """Initialize.""" super(Bottleneck, self).__init__() diff --git a/src/models/mixnet.py b/src/models/mixnet.py index 1b89f13..0bcf99c 100644 --- a/src/models/mixnet.py +++ b/src/models/mixnet.py @@ -147,7 +147,10 @@ def __init__( self.stem = nn.Sequential( ConvBN( - in_channels=3, out_channels=stem, kernel_size=3, stride=stem_stride, + in_channels=3, + out_channels=stem, + kernel_size=3, + stride=stem_stride, ), HSwish(inplace=True), ) @@ -178,7 +181,9 @@ def __init__( if head: self.head = nn.Sequential( ConvBN( - in_channels=last_out_channels, out_channels=head, kernel_size=1, + in_channels=last_out_channels, + out_channels=head, + kernel_size=1, ), HSwish(inplace=True), ) diff --git a/src/models/quant_densenet.py b/src/models/quant_densenet.py index 3585e04..8c27301 100644 --- a/src/models/quant_densenet.py +++ b/src/models/quant_densenet.py @@ -19,7 +19,11 @@ class QuantizableBottleneck(Bottleneck): """Quantizable Bottleneck layer.""" def __init__( - self, inplanes: int, expansion: int, growthRate: int, efficient: bool, + self, + inplanes: int, + expansion: int, + growthRate: int, + efficient: bool, ) -> None: """Initialize.""" super(QuantizableBottleneck, self).__init__( diff --git a/src/models/quant_mixnet.py b/src/models/quant_mixnet.py index 769d88b..31f65eb 100644 --- a/src/models/quant_mixnet.py +++ b/src/models/quant_mixnet.py @@ -55,7 +55,8 @@ def __init__(self, **kwargs: bool) -> None: self.se = ( QuantizableSqueezeExcitation( - in_channels=self.out_channels, se_ratio=self.se_ratio, + in_channels=self.out_channels, + se_ratio=self.se_ratio, ) if self.has_se else Identity() diff --git a/src/models/quant_resnet.py b/src/models/quant_resnet.py index d1fb461..27a3a53 100644 --- a/src/models/quant_resnet.py +++ b/src/models/quant_resnet.py @@ -15,5 +15,6 @@ def get_model(model_type: str, num_classes: int, pretrained: bool = False) -> nn """Constructs a ResNet model.""" assert model_type in ["resnet18", "resnet50", "resnext101_32x8d"] return getattr( - __import__("torchvision.models.quantization", fromlist=[""]), model_type, + __import__("torchvision.models.quantization", fromlist=[""]), + model_type, )(pretrained=pretrained, num_classes=num_classes) diff --git a/src/models/resnet.py b/src/models/resnet.py index ce5e9ad..570d8f1 100644 --- a/src/models/resnet.py +++ b/src/models/resnet.py @@ -23,6 +23,7 @@ def get_model(model_type: str, num_classes: int, pretrained: bool = False) -> nn "wide_resnet50_2", "wide_resnet101_2", ] - return getattr(__import__("torchvision.models", fromlist=[""]), model_type,)( - pretrained=pretrained, num_classes=num_classes - ) + return getattr( + __import__("torchvision.models", fromlist=[""]), + model_type, + )(pretrained=pretrained, num_classes=num_classes) diff --git a/src/models/utils.py b/src/models/utils.py index 8887304..6db45d6 100644 --- a/src/models/utils.py +++ b/src/models/utils.py @@ -125,7 +125,9 @@ def get_masks(model: nn.Module) -> Dict[str, torch.Tensor]: def dummy_pruning(params_all: Tuple[Tuple[nn.Module, str], ...]) -> None: """Conduct fake pruning.""" prune.global_unstructured( - params_all, pruning_method=prune.L1Unstructured, amount=0.0, + params_all, + pruning_method=prune.L1Unstructured, + amount=0.0, ) @@ -246,9 +248,11 @@ def wlog_weight(model: nn.Module) -> None: named_buffers = eval( "model." + dot2bracket(layer_name) + ".named_buffers()" ) - mask: Tuple[str, torch.Tensor] = next( - x for x in list(named_buffers) if x[0] == "weight_mask" - )[1].cpu().data.numpy() + mask: Tuple[str, torch.Tensor] = ( + next(x for x in list(named_buffers) if x[0] == "weight_mask")[1] + .cpu() + .data.numpy() + ) masked_weight = weight[np.where(mask == 1.0)] wlog.update({w_name: wandb.Histogram(masked_weight)}) wandb.log(wlog, commit=False) diff --git a/src/plotter.py b/src/plotter.py index 79c2892..e8b8c00 100644 --- a/src/plotter.py +++ b/src/plotter.py @@ -35,8 +35,8 @@ class PruneStat(NamedTuple): class Plotter: """Plotter for models. - Currently, it only plots sparsity information of each layer of the model, - but it can be utilized for plotting all sort of infomration. + Currently, it only plots sparsity information of each layer of the model, + but it can be utilized for plotting all sort of infomration. """ def __init__(self, wandb_log: bool) -> None: @@ -129,7 +129,7 @@ def _plot_pruned_stats( ) -> None: """Plot pruned parameters for each layers.""" # extract type save_path: 'path+type.png' - stat_type = save_path.rsplit(".", 3)[0].rsplit("/", 1)[1] + stat_type = save_path.rsplit(".", 1)[0].rsplit("/", 1)[1] fig, ax = self._get_fig(x_names) x = np.arange(len(x_names)) @@ -200,7 +200,9 @@ def _plot_pruned_stats( ) def _annotate_on_bar( - self, ax: matplotlib.axes.Axes, bars: List[matplotlib.axes.Axes.bar], + self, + ax: matplotlib.axes.Axes, + bars: List[matplotlib.axes.Axes.bar], ) -> None: """Attach a text label above each bar in rects, displaying its height.""" for _, bar in enumerate(bars): diff --git a/src/regularizers.py b/src/regularizers.py index e807c1b..72ee1fb 100644 --- a/src/regularizers.py +++ b/src/regularizers.py @@ -38,7 +38,8 @@ def forward(self, model: nn.Module) -> float: def get_regularizer( - regularizer_name: str, regularizer_params: Dict[str, Any], + regularizer_name: str, + regularizer_params: Dict[str, Any], ) -> nn.Module: """Create regularizer class.""" if not regularizer_params: diff --git a/src/runners/pruner.py b/src/runners/pruner.py index 3e4b18e..a63613f 100644 --- a/src/runners/pruner.py +++ b/src/runners/pruner.py @@ -87,7 +87,9 @@ def get_params_to_prune(self) -> Tuple[Tuple[nn.Module, str], ...]: raise NotImplementedError def reset( - self, prune_iter: int, resumed: bool = False, + self, + prune_iter: int, + resumed: bool = False, ) -> Tuple[int, List[Tuple[str, float, Callable[[float], str]]]]: """Reset the processes for pruning or pretraining. @@ -214,6 +216,16 @@ def resume(self) -> int: return last_iter + def test(self, weight_path: str) -> None: + """Test the model with saved weight.""" + original_state_dict = copy.deepcopy(self.model.state_dict()) + + state_dict = torch.load(weight_path)["state_dict"] + self.model.load_state_dict(state_dict) + avg_loss, acc = self.trainer.test_one_epoch_model(self.model) + print(f"Average loss: {avg_loss:.4f}\t Accuracy: {acc}") + self.model.load_state_dict(original_state_dict) + def run(self, resume_info_path: str = "") -> None: """Run pruning.""" # resume pruner if needed @@ -418,16 +430,16 @@ def update_masks(self, adjmodule_getter: AdjModuleGetter) -> None: ch_buffers = {name: buf for name, buf in channelrepr.named_buffers()} ch_mask = ch_buffers["weight_mask"].detach().clone() if "bias_mask" in ch_buffers: - ch_buffers["bias_mask"].set_(ch_mask) + ch_buffers["bias_mask"].set_(ch_mask) # type: ignore # Copy channel weight_mask to bn weight_mask, bias_mask bn_buffers = {name: buf for name, buf in bn.named_buffers()} - bn_buffers["weight_mask"].set_(ch_mask) - bn_buffers["bias_mask"].set_(ch_mask) + bn_buffers["weight_mask"].set_(ch_mask) # type: ignore + bn_buffers["bias_mask"].set_(ch_mask) # type: ignore conv_buffers = {name: buf for name, buf in conv.named_buffers()} if "bias_mask" in conv_buffers: - conv_buffers["bias_mask"].set_(ch_mask) + conv_buffers["bias_mask"].set_(ch_mask) # type: ignore # conv2d - batchnorm - activation (CBA) # bn_mask: [out], conv: [out, in, h, w] @@ -445,7 +457,7 @@ def update_masks(self, adjmodule_getter: AdjModuleGetter) -> None: # ch_mask: [out, in, h, w] ch_mask = ch_mask.repeat(1, i, 1, 1) - conv_buffers["weight_mask"].set_(ch_mask) + conv_buffers["weight_mask"].set_(ch_mask) # type: ignore # Update fc layer mask fc_modules: Dict[str, nn.Linear] = dict() diff --git a/src/runners/shrinker.py b/src/runners/shrinker.py index 5beff60..50c897f 100644 --- a/src/runners/shrinker.py +++ b/src/runners/shrinker.py @@ -228,7 +228,10 @@ def _reshape_fcs( self._set_layer(new_model, fc_name, reshaped_fc) def _generate_reshaped_conv( - self, in_mask: Optional[torch.Tensor], out_mask: torch.Tensor, conv: nn.Conv2d, + self, + in_mask: Optional[torch.Tensor], + out_mask: torch.Tensor, + conv: nn.Conv2d, ) -> nn.Conv2d: """Generate new conv given old conv and masks(in and out or out only).""" # Shrink both input, output channel of conv, and extract weight(orig, mask) @@ -238,7 +241,7 @@ def _generate_reshaped_conv( if in_mask is not None: # make masking matrix[o, i]: in_mask.T * out_mask # mask_flattened : [o*i] - mask_flattened = in_mask.unsqueeze(1).T * out_mask.unsqueeze(1) + mask_flattened = in_mask.unsqueeze(1).T * out_mask.unsqueeze(1) # type: ignore mask_flattened = mask_flattened.reshape(-1) mask_idx = (mask_flattened == 1).nonzero().view(-1, 1, 1).repeat(1, h, w) diff --git a/src/runners/trainer.py b/src/runners/trainer.py index 1e184cf..0770de0 100644 --- a/src/runners/trainer.py +++ b/src/runners/trainer.py @@ -100,12 +100,17 @@ def setup_train_configuration(self, config: Dict[str, Any]) -> None: # transform the training dataset for CutMix augmentation if "CUTMIX" in config: trainset = CutMix( - trainset, config["MODEL_PARAMS"]["num_classes"], **config["CUTMIX"], + trainset, + config["MODEL_PARAMS"]["num_classes"], + **config["CUTMIX"], ) # get dataloaders self.trainloader, self.testloader = utils.get_dataloader( - trainset, testset, config["BATCH_SIZE"], config["N_WORKERS"], + trainset, + testset, + config["BATCH_SIZE"], + config["N_WORKERS"], ) logger.info("Dataloader prepared") @@ -132,7 +137,8 @@ def setup_train_configuration(self, config: Dict[str, Any]) -> None: # learning rate scheduler self.lr_scheduler = get_lr_scheduler( - config["LR_SCHEDULER"], config["LR_SCHEDULER_PARAMS"], + config["LR_SCHEDULER"], + config["LR_SCHEDULER_PARAMS"], ) def reset(self, checkpt_dir: str) -> None: @@ -289,7 +295,11 @@ def warmup_one_iter(self) -> None: return None def save_params( - self, model_path: str, filename: str, epoch: int, record_path: bool = True, + self, + model_path: str, + filename: str, + epoch: int, + record_path: bool = True, ) -> None: """Save model.""" params = { diff --git a/src/utils.py b/src/utils.py index fe302b1..7a89413 100644 --- a/src/utils.py +++ b/src/utils.py @@ -51,7 +51,7 @@ def get_rand_bbox_coord( def to_onehot(labels: torch.Tensor, num_classes: int) -> torch.Tensor: """Convert index based labels into one-hot based labels. - If labels are one-hot based already(e.g. [0.9, 0.01, 0.03,...]), do nothing. + If labels are one-hot based already(e.g. [0.9, 0.01, 0.03,...]), do nothing. """ if len(labels.size()) == 1: return F.one_hot(labels, num_classes).float() @@ -71,10 +71,12 @@ def get_dataset( # preprocessing policies transform_train = getattr( - __import__("src.augmentation.policies", fromlist=[""]), transform_train, + __import__("src.augmentation.policies", fromlist=[""]), + transform_train, )(**transform_train_params) transform_test = getattr( - __import__("src.augmentation.policies", fromlist=[""]), transform_test, + __import__("src.augmentation.policies", fromlist=[""]), + transform_test, )(**transform_test_params) # pytorch dataset @@ -90,7 +92,10 @@ def get_dataset( def get_dataloader( - trainset: VisionDataset, testset: VisionDataset, batch_size: int, n_workers: int, + trainset: VisionDataset, + testset: VisionDataset, + batch_size: int, + n_workers: int, ) -> Tuple[data.DataLoader, data.DataLoader]: """Get dataloader for training and testing.""" trainloader = data.DataLoader( From a0a3cba6135b67a9bac0e0e356b45cf9efb6ccda Mon Sep 17 00:00:00 2001 From: Hyungseok Shin Date: Tue, 30 Mar 2021 18:04:34 +0900 Subject: [PATCH 2/2] Add printing inference time --- src/runners/pruner.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/runners/pruner.py b/src/runners/pruner.py index a63613f..0f98a84 100644 --- a/src/runners/pruner.py +++ b/src/runners/pruner.py @@ -9,6 +9,7 @@ import copy import itertools import os +import time from typing import Any, Callable, Dict, List, Set, Tuple, cast import torch @@ -222,7 +223,9 @@ def test(self, weight_path: str) -> None: state_dict = torch.load(weight_path)["state_dict"] self.model.load_state_dict(state_dict) + t0 = time.time() avg_loss, acc = self.trainer.test_one_epoch_model(self.model) + print(f"Inference time: {time.time() - t0:.2f} sec") print(f"Average loss: {avg_loss:.4f}\t Accuracy: {acc}") self.model.load_state_dict(original_state_dict)