From b9337ece5bbc7e6cca08fb3f81063c44877d351b Mon Sep 17 00:00:00 2001 From: "Mr.Blue" Date: Wed, 12 Dec 2018 15:05:03 +0800 Subject: [PATCH] Correct mistakes. It seems there are some mistakes here. Correct me if I'm wrong. --- models/partialconv2d.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/models/partialconv2d.py b/models/partialconv2d.py index 9412dee..d08e0e5 100755 --- a/models/partialconv2d.py +++ b/models/partialconv2d.py @@ -64,20 +64,21 @@ def forward(self, input, mask=None): self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask) if self.update_mask.type() != input.type() or self.mask_ratio.type() != input.type(): - self.update_mask.to(input) - self.mask_ratio.to(input) + self.update_mask = self.update_mask.to(input) + self.mask_ratio = self.mask_ratio.to(input) raw_out = super(PartialConv2d, self).forward(input) if self.bias is not None: bias_view = self.bias.view(1, self.out_channels, 1, 1) output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view - output = torch.mul(output, self.update_mask) else: output = torch.mul(raw_out, self.mask_ratio) + + output = torch.mul(output, self.update_mask) if self.return_mask: return output, self.update_mask else: - return output \ No newline at end of file + return output