# Provenance: commit 621e4fe5a868e76b49f0deffcc33ff614fe99c0b
# ("change 2dConv to 1d", Simon Levine, Fri, 7 Jan 2022). The commit appends
# this class to src/utils/tokenizer.py and switches src/text/cct.py from
# TextTokenizer to TextTokenizer1D, constructing it with
# n_input_channels=word_embedding_dim and n_output_channels=embedding_dim.


class TextTokenizer1D(nn.Module):
    """Convolutional tokenizer for embedded text built on ``nn.Conv1d``.

    The module applies one bias-free 1D convolution along the sequence axis,
    an optional activation and an optional max-pooling layer.

    Args:
        kernel_size: Kernel size of the convolution.
        stride: Stride of the convolution.
        padding: Zero padding of the convolution.
        pooling_kernel_size: Kernel size of the max-pooling layer.
        pooling_stride: Stride of the max-pooling layer.
        pooling_padding: Padding of the max-pooling layer.
        embedding_dim: Number of input channels of the convolution.
        n_output_channels: Number of output channels of the convolution.
        activation: Zero-argument callable (e.g. ``nn.ReLU``) returning the
            activation module, or ``None`` for no activation.
        max_pool: Whether to append the max-pooling layer.
        n_input_channels: Keyword-only alias for ``embedding_dim``, matching
            the keyword used by the caller in ``src/text/cct.py``. When given
            (not ``None``) it takes precedence over ``embedding_dim``.
        *args, **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 kernel_size, stride, padding,
                 pooling_kernel_size=3, pooling_stride=2, pooling_padding=1,
                 embedding_dim=300,
                 n_output_channels=128,
                 activation=None,
                 max_pool=True,
                 *args, n_input_channels=None, **kwargs):

        super().__init__()

        # The caller passes `n_input_channels`; without this it would vanish
        # into **kwargs and the convolution would always use `embedding_dim`.
        in_channels = embedding_dim if n_input_channels is None else n_input_channels

        self.max_pool = max_pool
        self.conv_layers = nn.Sequential(
            nn.Conv1d(in_channels=in_channels, out_channels=n_output_channels,
                      kernel_size=kernel_size,
                      stride=stride,
                      padding=padding, bias=False),
            nn.Identity() if activation is None else activation(),
            nn.MaxPool1d(
                kernel_size=pooling_kernel_size,
                stride=pooling_stride,
                padding=pooling_padding
            ) if max_pool else nn.Identity()
        )

        self.apply(self.init_weight)

    def seq_len(self, seq_len=32, embed_dim=300):
        """Return the output sequence length for an input of length ``seq_len``.

        Runs a zero tensor of shape ``(1, seq_len, embed_dim)`` through
        ``forward``; ``embed_dim`` must equal the convolution's input channels.
        """
        return self.forward(torch.zeros((1, seq_len, embed_dim)))[0].shape[1]

    def forward_mask(self, mask):
        """Downsample a ``(batch, seq_len)`` mask like the tokenizer does.

        An output position is ``True`` when at least one non-zero input
        position falls inside its convolution/pooling window.

        Returns:
            Boolean tensor of shape ``(batch, new_seq_len)``.
        """
        conv = self.conv_layers[0]
        new_mask = mask.unsqueeze(1).float()
        # An all-ones kernel counts the valid positions under each window.
        cnn_weight = torch.ones(
            (1, 1, conv.kernel_size[0]),
            device=mask.device,
            dtype=torch.float)
        new_mask = F.conv1d(
            new_mask, cnn_weight, None,
            conv.stride[0], conv.padding[0], 1, 1)
        if self.max_pool:
            pool = self.conv_layers[2]
            # nn.MaxPool1d keeps its arguments as passed (plain ints here), so
            # they must not be indexed; F.max_pool1d accepts ints or tuples.
            new_mask = F.max_pool1d(
                new_mask, pool.kernel_size,
                pool.stride, pool.padding, 1, False, False)
        new_mask = new_mask.squeeze(1)
        new_mask = (new_mask > 0)
        return new_mask

    def forward(self, x, mask=None):
        """Tokenize ``x`` of shape ``(batch, seq_len, channels)``.

        Args:
            x: Input tensor; transposed to channels-first for the convolution
                and back afterwards.
            mask: Optional ``(batch, seq_len)`` mask of valid positions.

        Returns:
            Tuple ``(tokens, mask)``. ``tokens`` has shape
            ``(batch, new_seq_len, n_output_channels)``. ``mask`` is ``None``
            when no mask was given, otherwise a float tensor of shape
            ``(batch, new_seq_len, 1)`` that has already been multiplied into
            ``tokens``.
        """
        x = self.conv_layers(x.transpose(1, 2))
        x = x.transpose(1, 2)
        if mask is not None:
            # The mask is 0/1, so applying it once is equivalent to twice.
            mask = self.forward_mask(mask).unsqueeze(-1).float()
            x = x * mask
        return x, mask

    @staticmethod
    def init_weight(m):
        """Apply Kaiming-normal initialization to ``nn.Conv1d`` weights."""
        if isinstance(m, nn.Conv1d):
            nn.init.kaiming_normal_(m.weight)