@@ -19,6 +19,23 @@ limitations under the License. */
1919
2020namespace paddle {
2121
22+ MatrixPtr CrossChannelNormLayer::createSampleMatrix (MatrixPtr data,
23+ size_t iter,
24+ size_t spatialDim) {
25+ return Matrix::create (data->getData () + iter * channels_ * spatialDim,
26+ channels_,
27+ spatialDim,
28+ false ,
29+ useGpu_);
30+ }
31+
32+ MatrixPtr CrossChannelNormLayer::createSpatialMatrix (MatrixPtr data,
33+ size_t iter,
34+ size_t spatialDim) {
35+ return Matrix::create (
36+ data->getData () + iter * spatialDim, 1 , spatialDim, false , useGpu_);
37+ }
38+
2239void CrossChannelNormLayer::forward (PassType passType) {
2340 Layer::forward (passType);
2441 MatrixPtr inV = getInputValue (0 );
@@ -40,25 +57,19 @@ void CrossChannelNormLayer::forward(PassType passType) {
4057 normBuffer_->addScalar (*normBuffer_, 1e-6 );
4158 inV->square2 (*dataBuffer_);
4259 for (size_t i = 0 ; i < batchSize; i++) {
43- MatrixPtr inTmp = Matrix::create (
44- inV->getData () + i * dataDim, channels_, spatialDim, false , useGpu_);
45- MatrixPtr dataTmp = Matrix::create (dataBuffer_->getData () + i * dataDim,
46- channels_,
47- spatialDim,
48- false ,
49- useGpu_);
50- MatrixPtr outTmp = Matrix::create (
51- outV->getData () + i * dataDim, channels_, spatialDim, false , useGpu_);
52- MatrixPtr normTmp = Matrix::create (
53- normBuffer_->getData () + i * spatialDim, 1 , spatialDim, false , useGpu_);
60+ const MatrixPtr inVTmp = createSampleMatrix (inV, i, spatialDim);
61+ const MatrixPtr dataTmp = createSampleMatrix (dataBuffer_, i, spatialDim);
62+ MatrixPtr outVTmp = createSampleMatrix (outV, i, spatialDim);
63+ MatrixPtr normTmp = createSpatialMatrix (normBuffer_, i, spatialDim);
64+
5465 // compute norm.
55- spatialBuffer_->sumCols (*dataTmp, 1 , 1 );
66+ spatialBuffer_->sumCols (*dataTmp, 1 , 0 );
5667 spatialBuffer_->sqrt2 (*spatialBuffer_);
5768 normTmp->copyFrom (*spatialBuffer_);
58- outTmp ->copyFrom (*inTmp );
59- outTmp ->divRowVector (*spatialBuffer_);
69+ outVTmp ->copyFrom (*inVTmp );
70+ outVTmp ->divRowVector (*spatialBuffer_);
6071 // scale the layer.
61- outTmp ->mulColVector (*scale_->getW ());
72+ outVTmp ->mulColVector (*scale_->getW ());
6273 }
6374}
6475
@@ -78,40 +89,31 @@ void CrossChannelNormLayer::backward(const UpdateCallback& callback) {
7889 Matrix::resizeOrCreate (sampleBuffer_, channels_, spatialDim, false , useGpu_);
7990 scaleDiff_->zeroMem ();
8091 for (size_t i = 0 ; i < batchSize; i++) {
81- // propagate to param.
82- MatrixPtr dataBufferTmp =
83- Matrix::create (dataBuffer_->getData () + i * dataDim,
84- channels_,
85- spatialDim,
86- false ,
87- useGpu_);
88- const MatrixPtr inValueTmp = Matrix::create (
89- inV->getData () + i * dataDim, channels_, spatialDim, false , useGpu_);
90- const MatrixPtr outGradTmp = Matrix::create (
91- outG->getData () + i * dataDim, channels_, spatialDim, false , useGpu_);
92- MatrixPtr inGradTmp = Matrix::create (
93- inG->getData () + i * dataDim, channels_, spatialDim, false , useGpu_);
94- const MatrixPtr normTmp = Matrix::create (
95- normBuffer_->getData () + i * spatialDim, 1 , spatialDim, false , useGpu_);
96- channelBuffer_->sumRows (*dataBufferTmp, 1 , 1 );
92+ MatrixPtr outGTmp = createSampleMatrix (outG, i, spatialDim);
93+ const MatrixPtr dataTmp = createSampleMatrix (dataBuffer_, i, spatialDim);
94+ const MatrixPtr inVTmp = createSampleMatrix (inV, i, spatialDim);
95+ const MatrixPtr inGTmp = createSampleMatrix (inG, i, spatialDim);
96+ const MatrixPtr normTmp = createSpatialMatrix (normBuffer_, i, spatialDim);
97+
98+ channelBuffer_->sumRows (*dataTmp, 1 , 0 );
9799 channelBuffer_->dotDiv (*channelBuffer_, *(scale_->getW ()));
98100 // store a / scale[i] in scaleDiff_ temporary
99101 scaleDiff_->add (*channelBuffer_, 1 .);
100102
101- sampleBuffer_->dotMul (*inValueTmp , *outGradTmp );
103+ sampleBuffer_->dotMul (*inVTmp , *outGTmp );
102104 spatialBuffer_->sumCols (*sampleBuffer_, 1 ., 1 .);
103105 // scale the grad
104- inGradTmp ->copyFrom (*inValueTmp );
105- inGradTmp ->mulRowVector (*spatialBuffer_);
106+ inGTmp ->copyFrom (*inVTmp );
107+ inGTmp ->mulRowVector (*spatialBuffer_);
106108 // divide by square of norm
107109 spatialBuffer_->dotMul (*normTmp, *normTmp);
108- inGradTmp ->divRowVector (*spatialBuffer_);
110+ inGTmp ->divRowVector (*spatialBuffer_);
109111 // subtract
110- inGradTmp ->add (*outGradTmp , -1 , 1 );
112+ inGTmp ->add (*outGTmp , -1 , 1 );
111113 // divide by norm
112- inGradTmp ->divRowVector (*normTmp);
114+ inGTmp ->divRowVector (*normTmp);
113115 // scale the diff
114- inGradTmp ->mulColVector (*scale_->getW ());
116+ inGTmp ->mulColVector (*scale_->getW ());
115117 }
116118 // updata scale
117119 if (scale_->getWGrad ()) scale_->getWGrad ()->copyFrom (*scaleDiff_);
0 commit comments