@@ -59,24 +59,36 @@ def __init__(self, in_channels, inter_channels=None, dimension=3, mode='embedded
5959 self .phi = None
6060 self .concat_project = None
6161
62- if mode in ['embedded_gaussian' , 'dot_product' , 'concatenation' ]:
63- self .theta = conv_nd (in_channels = self .in_channels , out_channels = self .inter_channels ,
64- kernel_size = 1 , stride = 1 , padding = 0 )
65- self .phi = conv_nd (in_channels = self .in_channels , out_channels = self .inter_channels ,
66- kernel_size = 1 , stride = 1 , padding = 0 )
67-
68- if mode == 'embedded_gaussian' :
69- self .operation_function = self ._embedded_gaussian
70- elif mode == 'dot_product' :
71- self .operation_function = self ._dot_product
72- elif mode == 'concatenation' :
73- self .operation_function = self ._concatenation
74- self .concat_project = nn .Sequential (
75- nn .Conv2d (self .inter_channels * 2 , 1 , 1 , 1 , 0 , bias = False ),
76- nn .ReLU ()
77- )
78- elif mode == 'gaussian' :
79- self .operation_function = self ._gaussian
62+ # if mode in ['embedded_gaussian', 'dot_product', 'concatenation']:
63+ self .theta = conv_nd (in_channels = self .in_channels , out_channels = self .inter_channels ,
64+ kernel_size = 1 , stride = 1 , padding = 0 )
65+
66+ self .phi = conv_nd (in_channels = self .in_channels , out_channels = self .inter_channels ,
67+ kernel_size = 1 , stride = 1 , padding = 0 )
68+ # elif mode == 'concatenation':
69+ self .concat_project = nn .Sequential (
70+ nn .Conv2d (self .inter_channels * 2 , 1 , 1 , 1 , 0 , bias = False ),
71+ nn .ReLU ()
72+ )
73+
74+ # if mode in ['embedded_gaussian', 'dot_product', 'concatenation']:
75+ # self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
76+ # kernel_size=1, stride=1, padding=0)
77+ # self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
78+ # kernel_size=1, stride=1, padding=0)
79+ #
80+ # if mode == 'embedded_gaussian':
81+ # self.operation_function = self._embedded_gaussian
82+ # elif mode == 'dot_product':
83+ # self.operation_function = self._dot_product
84+ # elif mode == 'concatenation':
85+ # self.operation_function = self._concatenation
86+ # self.concat_project = nn.Sequential(
87+ # nn.Conv2d(self.inter_channels * 2, 1, 1, 1, 0, bias=False),
88+ # nn.ReLU()
89+ # )
90+ # elif mode == 'gaussian':
91+ # self.operation_function = self._gaussian
8092
8193 if sub_sample :
8294 self .g = nn .Sequential (self .g , max_pool (kernel_size = 2 ))
@@ -91,7 +103,15 @@ def forward(self, x):
91103 :return:
92104 '''
93105
94- output = self .operation_function (x )
106+ if self .mode == 'embedded_gaussian' :
107+ output = self ._embedded_gaussian (x )
108+ elif mode == 'dot_product' :
109+ output = self ._dot_product (x )
110+ elif mode == 'concatenation' :
111+ output = self ._concatenation (x )
112+ elif mode == 'gaussian' :
113+ output = self ._gaussian (x )
114+
95115 return output
96116
97117 def _embedded_gaussian (self , x ):
0 commit comments