33#include < algorithm>
44#include < typeinfo>
55#include < cstdlib>
6+ #include < cmath>
67// For External Library
78#include < torch/torch.h>
89// For Original Header
@@ -30,8 +31,8 @@ torch::Tensor StochasticDepthImpl::forward(torch::Tensor x){
3031 float p_bar;
3132 torch::Tensor mask, out;
3233
33- if (!this ->is_training () || p < eps) return x;
34- else if (p > 1.0 - eps) return torch::zeros_like (x);
34+ if (!this ->is_training () || this -> p < eps) return x;
35+ else if (this -> p > 1.0 - eps) return torch::zeros_like (x);
3536
3637 p_bar = 1.0 - this ->p ;
3738 mask = torch::bernoulli (torch::full ({x.size (0 ), 1 , 1 , 1 }, p_bar, x.options ()));
@@ -54,10 +55,10 @@ void StochasticDepthImpl::pretty_print(std::ostream& stream) const{
5455// ----------------------------------------------------------------------
5556// struct{Conv2dNormActivationImpl}(nn::Module) -> constructor
5657// ----------------------------------------------------------------------
57- Conv2dNormActivationImpl::Conv2dNormActivationImpl (const size_t in_nc, const size_t out_nc, const size_t kernel_size, const size_t stride, const size_t padding, const size_t groups, const bool SiLU){
58+ Conv2dNormActivationImpl::Conv2dNormActivationImpl (const size_t in_nc, const size_t out_nc, const size_t kernel_size, const size_t stride, const size_t padding, const size_t groups, const float eps, const float momentum, const bool SiLU){
5859 this ->model = nn::Sequential (
5960 nn::Conv2d (nn::Conv2dOptions (/* in_channels=*/ in_nc, /* out_channels=*/ out_nc, /* kernel_size=*/ kernel_size).stride (stride).padding (padding).groups (groups).bias (false )),
60- nn::BatchNorm2d (out_nc)
61+ nn::BatchNorm2d (nn::BatchNormOptions ( out_nc). eps (eps). momentum (momentum) )
6162 );
6263 if (SiLU) this ->model ->push_back (nn::SiLU ());
6364 register_module (" Conv2dNormActivation" , this ->model );
@@ -113,16 +114,16 @@ torch::Tensor SqueezeExcitationImpl::forward(torch::Tensor x){
113114// ----------------------------------------------------------------------
114115// struct{MBConvImpl}(nn::Module) -> constructor
115116// ----------------------------------------------------------------------
116- MBConvImpl::MBConvImpl (const size_t in_nc, const size_t out_nc, const size_t kernel_size, const size_t stride, const size_t exp, const float dropconnect){
117+ MBConvImpl::MBConvImpl (const size_t in_nc, const size_t out_nc, const size_t kernel_size, const size_t stride, const size_t exp, const float eps, const float momentum, const float dropconnect){
117118
118119 constexpr size_t reduce = 4 ;
119120 size_t mid = in_nc * exp;
120121 this ->residual = ((stride == 1 ) && (in_nc == out_nc));
121122
122- if (exp != 1 ) this ->block ->push_back (Conv2dNormActivation (in_nc, mid, /* kernel_size=*/ 1 , /* stride=*/ 1 , /* padding=*/ 0 , /* groups=*/ 1 , /* SiLU=*/ true ));
123- this ->block ->push_back (Conv2dNormActivation (mid, mid, /* kernel_size=*/ kernel_size, /* stride=*/ stride, /* padding=*/ kernel_size / 2 , /* groups=*/ mid, /* SiLU=*/ true ));
123+ if (exp != 1 ) this ->block ->push_back (Conv2dNormActivation (in_nc, mid, /* kernel_size=*/ 1 , /* stride=*/ 1 , /* padding=*/ 0 , /* groups=*/ 1 , /* eps= */ eps, /* momentum= */ momentum, /* SiLU=*/ true ));
124+ this ->block ->push_back (Conv2dNormActivation (mid, mid, /* kernel_size=*/ kernel_size, /* stride=*/ stride, /* padding=*/ kernel_size / 2 , /* groups=*/ mid, /* eps= */ eps, /* momentum= */ momentum, /* SiLU=*/ true ));
124125 this ->block ->push_back (SqueezeExcitation (mid, std::max (1 , int (in_nc / reduce))));
125- this ->block ->push_back (Conv2dNormActivation (mid, out_nc, /* kernel_size=*/ 1 , /* stride=*/ 1 , /* padding=*/ 0 , /* groups=*/ 1 , /* SiLU=*/ false ));
126+ this ->block ->push_back (Conv2dNormActivation (mid, out_nc, /* kernel_size=*/ 1 , /* stride=*/ 1 , /* padding=*/ 0 , /* groups=*/ 1 , /* eps= */ eps, /* momentum= */ momentum, /* SiLU=*/ false ));
126127 register_module (" block" , this ->block );
127128
128129 this ->sd = StochasticDepth (dropconnect);
@@ -170,7 +171,7 @@ size_t MC_EfficientNetImpl::round_filters(size_t c, double width_mul){
170171// struct{MC_EfficientNetImpl}(nn::Module) -> function{round_repeats}
171172// ----------------------------------------------------------------------
172173size_t MC_EfficientNetImpl::round_repeats (size_t r, double depth_mul){
173- return std::max (1 , int (std::round (r * depth_mul)));
174+ return std::max (1 , int (std::ceil (r * depth_mul)));
174175}
175176
176177
@@ -186,16 +187,16 @@ MC_EfficientNetImpl::MC_EfficientNetImpl(po::variables_map &vm){
186187
187188 // (0.a) Setting for network's config
188189 std::string network = vm[" network" ].as <std::string>();
189- if (network == " B0" ) this ->cfg = {1.0 , 1.0 , 224 , 0.2 };
190- else if (network == " B1" ) this ->cfg = {1.0 , 1.1 , 240 , 0.2 };
191- else if (network == " B2" ) this ->cfg = {1.1 , 1.2 , 260 , 0.3 };
192- else if (network == " B3" ) this ->cfg = {1.2 , 1.4 , 300 , 0.3 };
193- else if (network == " B4" ) this ->cfg = {1.4 , 1.8 , 380 , 0.4 };
194- else if (network == " B5" ) this ->cfg = {1.6 , 2.2 , 456 , 0.4 };
195- else if (network == " B6" ) this ->cfg = {1.8 , 2.6 , 528 , 0.5 };
196- else if (network == " B7" ) this ->cfg = {2.0 , 3.1 , 600 , 0.5 };
197- else if (network == " B8" ) this ->cfg = {2.2 , 3.6 , 672 , 0.5 };
198- else if (network == " L2" ) this ->cfg = {4.3 , 5.3 , 800 , 0.5 };
190+ if (network == " B0" ) this ->cfg = {1.0 , 1.0 , 224 , 0.2 , 1e-5 , 0.1 , 0.2 };
191+ else if (network == " B1" ) this ->cfg = {1.0 , 1.1 , 240 , 0.2 , 1e-5 , 0.1 , 0.2 };
192+ else if (network == " B2" ) this ->cfg = {1.1 , 1.2 , 260 , 0.3 , 1e-5 , 0.1 , 0.2 };
193+ else if (network == " B3" ) this ->cfg = {1.2 , 1.4 , 300 , 0.3 , 1e-5 , 0.1 , 0.2 };
194+ else if (network == " B4" ) this ->cfg = {1.4 , 1.8 , 380 , 0.4 , 1e-5 , 0.1 , 0.2 };
195+ else if (network == " B5" ) this ->cfg = {1.6 , 2.2 , 456 , 0.4 , 0.001 , 0.01 , 0.2 };
196+ else if (network == " B6" ) this ->cfg = {1.8 , 2.6 , 528 , 0.5 , 0.001 , 0.01 , 0.2 };
197+ else if (network == " B7" ) this ->cfg = {2.0 , 3.1 , 600 , 0.5 , 0.001 , 0.01 , 0.2 };
198+ else if (network == " B8" ) this ->cfg = {2.2 , 3.6 , 672 , 0.5 , 0.001 , 0.01 , 0.2 };
199+ else if (network == " L2" ) this ->cfg = {4.3 , 5.3 , 800 , 0.5 , 0.001 , 0.01 , 0.2 };
199200 else {
200201 std::cerr << " Error : The type of network is " << network << ' .' << std::endl;
201202 std::cerr << " Error : Please choose B0, B1, B2, B3, B4, B5, B6, B7, B8 or L2." << std::endl;
@@ -209,7 +210,7 @@ MC_EfficientNetImpl::MC_EfficientNetImpl(po::variables_map &vm){
209210
210211 // (1) Stem layer
211212 stem_nc = this ->round_filters (stem_feature, this ->cfg .width_mul );
212- this ->features ->push_back (Conv2dNormActivation (vm[" nc" ].as <size_t >(), stem_nc, /* kernel_size=*/ 3 , /* stride=*/ 2 , /* padding=*/ 1 , /* groups=*/ 1 , /* SiLU=*/ true ));
213+ this ->features ->push_back (Conv2dNormActivation (vm[" nc" ].as <size_t >(), stem_nc, /* kernel_size=*/ 3 , /* stride=*/ 2 , /* padding=*/ 1 , /* groups=*/ 1 , /* eps= */ this -> cfg . eps , /* momentum= */ this -> cfg . momentum , /* SiLU=*/ true ));
213214
214215 // (2.a) Bone layer
215216 total_blocks = 0 ;
@@ -225,16 +226,16 @@ MC_EfficientNetImpl::MC_EfficientNetImpl(po::variables_map &vm){
225226 out_nc = this ->round_filters (bcfg[i].c , this ->cfg .width_mul );
226227 for (size_t j = 0 ; j < repeats; j++){
227228 stride = (j == 0 ) ? bcfg[i].s : 1 ;
228- dropconnect = this ->cfg .dropout * (double )block_idx / (double )std::max (1 , int (total_blocks));
229- this ->features ->push_back (MBConvImpl (in_nc, out_nc, bcfg[i].k , stride, bcfg[i].exp , dropconnect));
229+ dropconnect = this ->cfg .stochastic_depth_prob * (double )block_idx / (double )std::max (1 , int (total_blocks));
230+ this ->features ->push_back (MBConvImpl (in_nc, out_nc, bcfg[i].k , stride, bcfg[i].exp , this -> cfg . eps , this -> cfg . momentum , dropconnect));
230231 in_nc = out_nc;
231232 block_idx++;
232233 }
233234 }
234235
235236 // (3) Head layer
236237 head_nc = this ->round_filters (head_feature, this ->cfg .width_mul );
237- this ->features ->push_back (Conv2dNormActivation (in_nc, head_nc, /* kernel_size=*/ 1 , /* stride=*/ 1 , /* padding=*/ 0 , /* groups=*/ 1 , /* SiLU=*/ true ));
238+ this ->features ->push_back (Conv2dNormActivation (in_nc, head_nc, /* kernel_size=*/ 1 , /* stride=*/ 1 , /* padding=*/ 0 , /* groups=*/ 1 , /* eps= */ this -> cfg . eps , /* momentum= */ this -> cfg . momentum , /* SiLU=*/ true ));
238239 register_module (" features" , this ->features );
239240
240241 // (4) Global Average Pooling