diff --git a/models/partialconv2d.py b/models/partialconv2d.py index 77a9a56..988d0c9 100755 --- a/models/partialconv2d.py +++ b/models/partialconv2d.py @@ -30,7 +30,7 @@ def __init__(self, *args, **kwargs): super(PartialConv2d, self).__init__(*args, **kwargs) if self.multi_channel: - self.weight_maskUpdater = torch.ones(self.out_channels, self.in_channels, self.kernel_size[0], self.kernel_size[1]) + self.weight_maskUpdater = torch.ones(1, self.in_channels, self.kernel_size[0], self.kernel_size[1]) else: self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0], self.kernel_size[1]) @@ -59,7 +59,9 @@ def forward(self, input, mask_in=None): mask = mask_in self.update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=1) - + + self.update_mask = self.update_mask.expand(self.update_mask.size(0), self.out_channels, self.update_mask.size(2), self.update_mask.size(3)) + # for mixed precision training, change 1e-8 to 1e-6 self.mask_ratio = self.slide_winsize/(self.update_mask + 1e-8) # self.mask_ratio = torch.max(self.update_mask)/(self.update_mask + 1e-8)