Skip to content

Commit adc5839

Browse files
authored
Merge pull request #969 from reyoung/feature/clean_gradient_machine_start
Remove not used params in GradientMachine::start
2 parents bbf3b47 + 56f2965 commit adc5839

File tree

12 files changed

+14
-23
lines changed

12 files changed

+14
-23
lines changed

paddle/gserver/gradientmachines/GradientMachine.h

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -212,11 +212,7 @@ class GradientMachine {
212212
* @note This function will only been implemented and used in a
213213
* multithreaded environment.
214214
*/
215-
virtual void start(const TrainerConfig& config,
216-
DataProviderPtr dataProvider) {
217-
(void)config;
218-
(void)dataProvider;
219-
}
215+
virtual void start() {}
220216

221217
/**
222218
* @brief check each work-thread whether is failed/error/finish,

paddle/gserver/gradientmachines/MultiGradientMachine.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -441,7 +441,7 @@ TrainerThread::TrainerThread(const ModelConfig& config,
441441
TrainerThread::~TrainerThread() { stop(); }
442442

443443
void TrainerThread::start() {
444-
gradientMachine_->start(*(TrainerConfig*)nullptr, (DataProviderPtr) nullptr);
444+
gradientMachine_->start();
445445

446446
computeThread_.reset(new std::thread([this]() { computeThread(); }));
447447

paddle/gserver/gradientmachines/MultiNetwork.cpp

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -109,10 +109,9 @@ void MultiNetwork::onPassEnd() {
109109
}
110110
}
111111

112-
void MultiNetwork::start(const TrainerConfig& config,
113-
DataProviderPtr dataProvider) {
112+
void MultiNetwork::start() {
114113
for (auto& subNetwork : subNetworks_) {
115-
subNetwork->start(config, dataProvider);
114+
subNetwork->start();
116115
}
117116
}
118117

paddle/gserver/gradientmachines/MultiNetwork.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -54,7 +54,7 @@ class MultiNetwork : public NeuralNetwork {
5454
return subNetworks_;
5555
}
5656

57-
virtual void start(const TrainerConfig& config, DataProviderPtr dataProvider);
57+
virtual void start();
5858

5959
virtual void finish();
6060

paddle/gserver/gradientmachines/ParallelNeuralNetwork.cpp

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -131,11 +131,7 @@ void ParallelNeuralNetwork::forwardBackward(const std::vector<Argument>& inArgs,
131131
backward(callback);
132132
}
133133

134-
void ParallelNeuralNetwork::start(const TrainerConfig& config,
135-
DataProviderPtr dataProvider) {
136-
(void)config;
137-
(void)dataProvider;
138-
134+
void ParallelNeuralNetwork::start() {
139135
for (auto& thread : threads_) {
140136
thread->start();
141137
}

paddle/gserver/gradientmachines/ParallelNeuralNetwork.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -56,7 +56,7 @@ class ParallelNeuralNetwork : public NeuralNetwork {
5656
PassType passType,
5757
const UpdateCallback &callback = NULL);
5858

59-
virtual void start(const TrainerConfig &config, DataProviderPtr dataProvider);
59+
virtual void start();
6060

6161
void addComputeThread(int deviceId);
6262

paddle/gserver/tests/test_NetworkCompare.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -114,7 +114,7 @@ void calcGradient(DataIn& in, DataOut& out, const std::string& configPath) {
114114
parameters[i]->getBuf(PARAMETER_VALUE)->copyFrom(*in.paraValues[i]);
115115
}
116116
}
117-
gradientMachine->start(trainer.getConfig(), nullptr);
117+
gradientMachine->start();
118118
gradientMachine->forward(in.inArgs, &outArgs, PASS_TRAIN);
119119
for (size_t i = 0; i < in.outGrads.size(); i++) {
120120
// If the all the layers in the config have no parameters, also

paddle/gserver/tests/test_RecurrentGradientMachine.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ class TrainerForTest : public paddle::Trainer {
2828
public:
2929
void startTrain() {
3030
GradientMachine& gm = *this->trainerInternal_.getGradientMachine();
31-
gm.start(this->getConfig(), dataProvider_);
31+
gm.start();
3232
}
3333

3434
void finishTrain() {

paddle/trainer/Tester.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -257,7 +257,7 @@ void Tester::test() {
257257
CHECK(testDataProvider_) << "TestData is not specified";
258258
testDataProvider_->setSkipShuffle();
259259
testDataProvider_->reset();
260-
gradientMachine_->start(*config_, testDataProvider_);
260+
gradientMachine_->start();
261261

262262
// For evaluation
263263
std::vector<std::string> modelList;

paddle/trainer/Trainer.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -308,7 +308,7 @@ static double genPerturbation(real* d, real* grad, size_t dim) {
308308
}
309309

310310
real Trainer::checkGradient() {
311-
trainerInternal_.getGradientMachine()->start(*config_, dataProvider_);
311+
trainerInternal_.getGradientMachine()->start();
312312
std::vector<ParameterPtr>& parameters =
313313
trainerInternal_.getGradientMachine()->getNonStaticParameters();
314314
DataBatch dataBatch;
@@ -390,7 +390,7 @@ void Trainer::startTrain() {
390390
dataProvider_->reset();
391391
}
392392

393-
trainerInternal_.getGradientMachine()->start(*config_, dataProvider_);
393+
trainerInternal_.getGradientMachine()->start();
394394
}
395395

396396
void Trainer::finishTrain() { trainerInternal_.getGradientMachine()->finish(); }

0 commit comments

Comments (0)