Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions prune.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
parser.add_argument("--multi-gpu", action="store_true", help="Multi-GPU use")
parser.add_argument("--gpu", default=0, type=int, help="GPU id to use")
parser.add_argument(
"--resume", type=str, default="", help="Input checkpoint directory name",
"--resume",
type=str,
default="",
help="Input checkpoint directory name",
)
parser.add_argument(
"--wlog", dest="wlog", action="store_true", help="Turns on wandb logging"
Expand All @@ -26,6 +29,7 @@
default="config/prune/simplenet_kd.py",
help="Configuration path",
)
parser.add_argument("--test-weight", default="", help="Weight filepath to test")
parser.set_defaults(multi_gpu=False)
parser.set_defaults(log=False)
args = parser.parse_args()
Expand All @@ -49,4 +53,9 @@
wandb_init_params=wandb_init_params,
device=device,
)
pruner.run(args.resume)
if args.test_weight:
if not args.test_weight.startswith(args.resume):
raise Exception(f"{args.test_weight} from {args.resume} ?")
pruner.test(args.test_weight)
else:
pruner.run(args.resume)
4 changes: 3 additions & 1 deletion src/augmentation/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ class SequentialAugmentation(Augmentation):
"""Sequential augmentation class."""

def __init__(
self, policies: List[Tuple[str, float, int]], n_level: int = 10,
self,
policies: List[Tuple[str, float, int]],
n_level: int = 10,
) -> None:
"""Initialize."""
super(SequentialAugmentation, self).__init__(n_level)
Expand Down
8 changes: 6 additions & 2 deletions src/augmentation/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,9 @@ def autoaugment_train_cifar100_riair() -> transforms.Compose:


def randaugment_train_cifar100(
n_select: int = 2, level: int = 14, n_level: int = 31,
n_select: int = 2,
level: int = 14,
n_level: int = 31,
) -> transforms.Compose:
"""Random augmentation policy for training CIFAR100."""
operators = [
Expand Down Expand Up @@ -156,7 +158,9 @@ def randaugment_train_cifar100(


def randaugment_train_cifar100_224(
n_select: int = 2, level: int = 14, n_level: int = 31,
n_select: int = 2,
level: int = 14,
n_level: int = 31,
) -> transforms.Compose:
operators = [
"Identity",
Expand Down
4 changes: 3 additions & 1 deletion src/criterions.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,9 @@ def add_label_smoothing(self, target: torch.Tensor) -> torch.Tensor:


def get_criterion(
criterion_name: str, criterion_params: Dict[str, Any], device: torch.device,
criterion_name: str,
criterion_params: Dict[str, Any],
device: torch.device,
) -> nn.Module:
"""Create loss class."""
return eval(criterion_name)(device, **criterion_params)
6 changes: 3 additions & 3 deletions src/models/adjmodule_getter.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,12 @@ def backward_search(var: Any) -> None:
graph[var].append(next_var[0])
backward_search(next_var[0])

backward_search(out.grad_fn)
backward_search(out.grad_fn) # type: ignore
return graph

def _hook_fn(self, module: nn.Module, inp: torch.Tensor, out: torch.Tensor) -> None:
self.module_ahead[out.grad_fn] = module
self.op_behind[module] = out.grad_fn
self.module_ahead[out.grad_fn] = module # type: ignore
self.op_behind[module] = out.grad_fn # type: ignore

if type(module) == nn.Flatten: # type: ignore
self.last_conv_shape = inp[0].size()[-1]
6 changes: 5 additions & 1 deletion src/models/densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ class Bottleneck(nn.Module):
"""Bottleneck block for DenseNet."""

def __init__(
self, inplanes: int, expansion: int, growthRate: int, efficient: bool,
self,
inplanes: int,
expansion: int,
growthRate: int,
efficient: bool,
) -> None:
"""Initialize."""
super(Bottleneck, self).__init__()
Expand Down
9 changes: 7 additions & 2 deletions src/models/mixnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,10 @@ def __init__(

self.stem = nn.Sequential(
ConvBN(
in_channels=3, out_channels=stem, kernel_size=3, stride=stem_stride,
in_channels=3,
out_channels=stem,
kernel_size=3,
stride=stem_stride,
),
HSwish(inplace=True),
)
Expand Down Expand Up @@ -178,7 +181,9 @@ def __init__(
if head:
self.head = nn.Sequential(
ConvBN(
in_channels=last_out_channels, out_channels=head, kernel_size=1,
in_channels=last_out_channels,
out_channels=head,
kernel_size=1,
),
HSwish(inplace=True),
)
Expand Down
6 changes: 5 additions & 1 deletion src/models/quant_densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@ class QuantizableBottleneck(Bottleneck):
"""Quantizable Bottleneck layer."""

def __init__(
self, inplanes: int, expansion: int, growthRate: int, efficient: bool,
self,
inplanes: int,
expansion: int,
growthRate: int,
efficient: bool,
) -> None:
"""Initialize."""
super(QuantizableBottleneck, self).__init__(
Expand Down
3 changes: 2 additions & 1 deletion src/models/quant_mixnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def __init__(self, **kwargs: bool) -> None:

self.se = (
QuantizableSqueezeExcitation(
in_channels=self.out_channels, se_ratio=self.se_ratio,
in_channels=self.out_channels,
se_ratio=self.se_ratio,
)
if self.has_se
else Identity()
Expand Down
3 changes: 2 additions & 1 deletion src/models/quant_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@ def get_model(model_type: str, num_classes: int, pretrained: bool = False) -> nn
"""Constructs a ResNet model."""
assert model_type in ["resnet18", "resnet50", "resnext101_32x8d"]
return getattr(
__import__("torchvision.models.quantization", fromlist=[""]), model_type,
__import__("torchvision.models.quantization", fromlist=[""]),
model_type,
)(pretrained=pretrained, num_classes=num_classes)
7 changes: 4 additions & 3 deletions src/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def get_model(model_type: str, num_classes: int, pretrained: bool = False) -> nn
"wide_resnet50_2",
"wide_resnet101_2",
]
return getattr(__import__("torchvision.models", fromlist=[""]), model_type,)(
pretrained=pretrained, num_classes=num_classes
)
return getattr(
__import__("torchvision.models", fromlist=[""]),
model_type,
)(pretrained=pretrained, num_classes=num_classes)
12 changes: 8 additions & 4 deletions src/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ def get_masks(model: nn.Module) -> Dict[str, torch.Tensor]:
def dummy_pruning(params_all: Tuple[Tuple[nn.Module, str], ...]) -> None:
"""Conduct fake pruning."""
prune.global_unstructured(
params_all, pruning_method=prune.L1Unstructured, amount=0.0,
params_all,
pruning_method=prune.L1Unstructured,
amount=0.0,
)


Expand Down Expand Up @@ -246,9 +248,11 @@ def wlog_weight(model: nn.Module) -> None:
named_buffers = eval(
"model." + dot2bracket(layer_name) + ".named_buffers()"
)
mask: Tuple[str, torch.Tensor] = next(
x for x in list(named_buffers) if x[0] == "weight_mask"
)[1].cpu().data.numpy()
mask: Tuple[str, torch.Tensor] = (
next(x for x in list(named_buffers) if x[0] == "weight_mask")[1]
.cpu()
.data.numpy()
)
masked_weight = weight[np.where(mask == 1.0)]
wlog.update({w_name: wandb.Histogram(masked_weight)})
wandb.log(wlog, commit=False)
Expand Down
10 changes: 6 additions & 4 deletions src/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ class PruneStat(NamedTuple):
class Plotter:
"""Plotter for models.

Currently, it only plots sparsity information of each layer of the model,
but it can be utilized for plotting all sort of infomration.
Currently, it only plots sparsity information of each layer of the model,
but it can be utilized for plotting all sorts of information.
"""

def __init__(self, wandb_log: bool) -> None:
Expand Down Expand Up @@ -129,7 +129,7 @@ def _plot_pruned_stats(
) -> None:
"""Plot pruned parameters for each layers."""
# extract type save_path: 'path+type.png'
stat_type = save_path.rsplit(".", 3)[0].rsplit("/", 1)[1]
stat_type = save_path.rsplit(".", 1)[0].rsplit("/", 1)[1]

fig, ax = self._get_fig(x_names)
x = np.arange(len(x_names))
Expand Down Expand Up @@ -200,7 +200,9 @@ def _plot_pruned_stats(
)

def _annotate_on_bar(
self, ax: matplotlib.axes.Axes, bars: List[matplotlib.axes.Axes.bar],
self,
ax: matplotlib.axes.Axes,
bars: List[matplotlib.axes.Axes.bar],
) -> None:
"""Attach a text label above each bar in rects, displaying its height."""
for _, bar in enumerate(bars):
Expand Down
3 changes: 2 additions & 1 deletion src/regularizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def forward(self, model: nn.Module) -> float:


def get_regularizer(
regularizer_name: str, regularizer_params: Dict[str, Any],
regularizer_name: str,
regularizer_params: Dict[str, Any],
) -> nn.Module:
"""Create regularizer class."""
if not regularizer_params:
Expand Down
27 changes: 21 additions & 6 deletions src/runners/pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import copy
import itertools
import os
import time
from typing import Any, Callable, Dict, List, Set, Tuple, cast

import torch
Expand Down Expand Up @@ -87,7 +88,9 @@ def get_params_to_prune(self) -> Tuple[Tuple[nn.Module, str], ...]:
raise NotImplementedError

def reset(
self, prune_iter: int, resumed: bool = False,
self,
prune_iter: int,
resumed: bool = False,
) -> Tuple[int, List[Tuple[str, float, Callable[[float], str]]]]:
"""Reset the processes for pruning or pretraining.

Expand Down Expand Up @@ -214,6 +217,18 @@ def resume(self) -> int:

return last_iter

def test(self, weight_path: str) -> None:
    """Evaluate the model using weights loaded from *weight_path*.

    Temporarily swaps the model's current weights for the checkpoint's,
    runs one test epoch via the trainer, prints the inference time and
    metrics, then restores the original weights.

    Args:
        weight_path: Path to a checkpoint file containing a
            ``"state_dict"`` entry (the format saved by the trainer).
    """
    # Snapshot the live weights so they can be restored after testing.
    original_state_dict = copy.deepcopy(self.model.state_dict())

    # NOTE(review): torch.load without map_location assumes the devices
    # recorded in the checkpoint are available here — confirm this holds
    # for CPU-only evaluation runs.
    state_dict = torch.load(weight_path)["state_dict"]
    self.model.load_state_dict(state_dict)
    t0 = time.time()
    avg_loss, acc = self.trainer.test_one_epoch_model(self.model)
    print(f"Inference time: {time.time() - t0:.2f} sec")
    print(f"Average loss: {avg_loss:.4f}\t Accuracy: {acc}")
    # Put the pre-test weights back so subsequent runs are unaffected.
    self.model.load_state_dict(original_state_dict)

def run(self, resume_info_path: str = "") -> None:
"""Run pruning."""
# resume pruner if needed
Expand Down Expand Up @@ -418,16 +433,16 @@ def update_masks(self, adjmodule_getter: AdjModuleGetter) -> None:
ch_buffers = {name: buf for name, buf in channelrepr.named_buffers()}
ch_mask = ch_buffers["weight_mask"].detach().clone()
if "bias_mask" in ch_buffers:
ch_buffers["bias_mask"].set_(ch_mask)
ch_buffers["bias_mask"].set_(ch_mask) # type: ignore

# Copy channel weight_mask to bn weight_mask, bias_mask
bn_buffers = {name: buf for name, buf in bn.named_buffers()}
bn_buffers["weight_mask"].set_(ch_mask)
bn_buffers["bias_mask"].set_(ch_mask)
bn_buffers["weight_mask"].set_(ch_mask) # type: ignore
bn_buffers["bias_mask"].set_(ch_mask) # type: ignore

conv_buffers = {name: buf for name, buf in conv.named_buffers()}
if "bias_mask" in conv_buffers:
conv_buffers["bias_mask"].set_(ch_mask)
conv_buffers["bias_mask"].set_(ch_mask) # type: ignore

# conv2d - batchnorm - activation (CBA)
# bn_mask: [out], conv: [out, in, h, w]
Expand All @@ -445,7 +460,7 @@ def update_masks(self, adjmodule_getter: AdjModuleGetter) -> None:
# ch_mask: [out, in, h, w]
ch_mask = ch_mask.repeat(1, i, 1, 1)

conv_buffers["weight_mask"].set_(ch_mask)
conv_buffers["weight_mask"].set_(ch_mask) # type: ignore

# Update fc layer mask
fc_modules: Dict[str, nn.Linear] = dict()
Expand Down
7 changes: 5 additions & 2 deletions src/runners/shrinker.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,10 @@ def _reshape_fcs(
self._set_layer(new_model, fc_name, reshaped_fc)

def _generate_reshaped_conv(
self, in_mask: Optional[torch.Tensor], out_mask: torch.Tensor, conv: nn.Conv2d,
self,
in_mask: Optional[torch.Tensor],
out_mask: torch.Tensor,
conv: nn.Conv2d,
) -> nn.Conv2d:
"""Generate new conv given old conv and masks(in and out or out only)."""
# Shrink both input, output channel of conv, and extract weight(orig, mask)
Expand All @@ -238,7 +241,7 @@ def _generate_reshaped_conv(
if in_mask is not None:
# make masking matrix[o, i]: in_mask.T * out_mask
# mask_flattened : [o*i]
mask_flattened = in_mask.unsqueeze(1).T * out_mask.unsqueeze(1)
mask_flattened = in_mask.unsqueeze(1).T * out_mask.unsqueeze(1) # type: ignore
mask_flattened = mask_flattened.reshape(-1)
mask_idx = (mask_flattened == 1).nonzero().view(-1, 1, 1).repeat(1, h, w)

Expand Down
18 changes: 14 additions & 4 deletions src/runners/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,17 @@ def setup_train_configuration(self, config: Dict[str, Any]) -> None:
# transform the training dataset for CutMix augmentation
if "CUTMIX" in config:
trainset = CutMix(
trainset, config["MODEL_PARAMS"]["num_classes"], **config["CUTMIX"],
trainset,
config["MODEL_PARAMS"]["num_classes"],
**config["CUTMIX"],
)

# get dataloaders
self.trainloader, self.testloader = utils.get_dataloader(
trainset, testset, config["BATCH_SIZE"], config["N_WORKERS"],
trainset,
testset,
config["BATCH_SIZE"],
config["N_WORKERS"],
)
logger.info("Dataloader prepared")

Expand All @@ -132,7 +137,8 @@ def setup_train_configuration(self, config: Dict[str, Any]) -> None:

# learning rate scheduler
self.lr_scheduler = get_lr_scheduler(
config["LR_SCHEDULER"], config["LR_SCHEDULER_PARAMS"],
config["LR_SCHEDULER"],
config["LR_SCHEDULER_PARAMS"],
)

def reset(self, checkpt_dir: str) -> None:
Expand Down Expand Up @@ -289,7 +295,11 @@ def warmup_one_iter(self) -> None:
return None

def save_params(
self, model_path: str, filename: str, epoch: int, record_path: bool = True,
self,
model_path: str,
filename: str,
epoch: int,
record_path: bool = True,
) -> None:
"""Save model."""
params = {
Expand Down
Loading