Browse Source

!32777 [MS][LITE] fix train bug

Merge pull request !32777 from jianghui58/train_r1.7
r1.7
i-robot Gitee 3 years ago
parent
commit
7a402ea110
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 59 additions and 11 deletions
  1. +2
    -2
      mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/batchnorm_fp32.c
  2. +4
    -4
      mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/batchnorm_fp32.h
  3. +3
    -1
      mindspore/lite/src/runtime/kernel/arm/fp32/fused_batchnorm_fp32.cc
  4. +48
    -3
      mindspore/lite/src/train/train_session.cc
  5. +2
    -1
      mindspore/lite/tools/benchmark_train/net_train.cc

+ 2
- 2
mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/batchnorm_fp32.c View File

@@ -100,10 +100,10 @@ void FusedBatchNormFp32(const float *input, const float *scale, const float *off
}

void FusedBatchNormFp32MeanVar(const float *input, float *run_mean, float *run_var, const BatchNormParameter *param,
float *save_mean, float *save_var) {
float *save_mean, float *save_var, bool isBatchNorm2d) {
const float N = (float)param->unit_;
const float VN = N;
const float VNUB = (N > 1.0f) ? (N - 1.0f) : 1.0f;
const float VNUB = (isBatchNorm2d == false) ? N : ((N > 1.0f) ? (N - 1.0f) : 1.0f);
const float momentum = (1.0f - param->momentum_);

for (int i = 0; i < param->unit_; i++) {


+ 4
- 4
mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/batchnorm_fp32.h View File

@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_NNACL_FP32_BATCHNORM_H_
#define MINDSPORE_NNACL_FP32_BATCHNORM_H_
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_BATCHNORM_FP32_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_BATCHNORM_FP32_H_

#include "nnacl/batchnorm_parameter.h"

@@ -30,9 +30,9 @@ void FusedBatchNormFp32(const float *input, const float *scale, const float *off
const float *variance, const BatchNormParameter *param, int task_id, float *output);

void FusedBatchNormFp32MeanVar(const float *input, float *run_mean, float *run_var, const BatchNormParameter *param,
float *save_mean, float *save_var);
float *save_mean, float *save_var, bool isBatchNorm2d);
#ifdef __cplusplus
}
#endif

#endif // MINDSPORE_NNACL_FUSED_BATCHNORM_H_
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_BATCHNORM_FP32_H_

+ 3
- 1
mindspore/lite/src/runtime/kernel/arm/fp32/fused_batchnorm_fp32.cc View File

@@ -161,6 +161,7 @@ int FusedBatchnormCPUKernel::Run() {
float *current_var = static_cast<float *>(variance_);
float *save_mean = static_cast<float *>(in_tensors_.at(FOURTH_INPUT)->data());
float *save_variance = static_cast<float *>(in_tensors_.at(FIFTH_INPUT)->data());
bool isBatch2d = true;
if (in == nullptr || scale == nullptr || offset == nullptr || current_mean == nullptr || current_var == nullptr ||
save_mean == nullptr || save_variance == nullptr) {
MS_LOG(ERROR) << "The input data is nullptr.";
@@ -168,8 +169,9 @@ int FusedBatchnormCPUKernel::Run() {
}
std::fill(current_mean, current_mean + in_tensors_.at(FOURTH_INPUT)->ElementsNum(), 0.f);
std::fill(current_var, current_var + in_tensors_.at(FIFTH_INPUT)->ElementsNum(), 0.f);
if (in_tensors_.at(FIRST_INPUT)->shape().size() == C2NUM) isBatch2d = false;
FusedBatchNormFp32MeanVar(in, current_mean, current_var, param, static_cast<float *>(save_mean),
static_cast<float *>(save_variance));
static_cast<float *>(save_variance), isBatch2d);

CHECK_NULL_RETURN(out_tensors_.at(SECOND_INPUT)->data());
CHECK_NULL_RETURN(out_tensors_.at(THIRD_INPUT)->data());


+ 48
- 3
mindspore/lite/src/train/train_session.cc View File

@@ -656,6 +656,11 @@ void TrainSession::CompileEvalOutputs() {
if (IsLossKernel(kernel) && !(IsGradKernel(kernel))) {
for (auto in_kernel : kernel->in_kernels()) {
if (IsLossKernel(in_kernel) || IsGradKernel(in_kernel)) continue;
bool is_loss = false;
for (auto in_in_kernel : in_kernel->in_kernels()) {
if (IsLossKernel(in_in_kernel)) is_loss = true;
}
if (is_loss) continue;
// insert if not already in
if (eval_output_node_map_.find(in_kernel->name()) == eval_output_node_map_.end()) {
auto *ms_tensor = in_kernel->out_tensors().at(0);
@@ -726,13 +731,53 @@ void TrainSession::BuildInferenceKernelsRecursive(kernel::LiteKernel *kernel, st

void TrainSession::CompileTrainKernels() {
train_kernels_.clear();
std::vector<kernel::LiteKernel *> train_kernels;
for (auto ori_kernel : kernels_) {
if (ori_kernel->subgraph_type() == kernel::kNotSubGraph) {
train_kernels_.push_back(ori_kernel);
train_kernels.push_back(ori_kernel);
} else {
auto sub_graph = reinterpret_cast<kernel::SubGraphKernel *>(ori_kernel);
for (auto kernel : sub_graph->nodes()) {
train_kernels_.push_back(kernel);
std::copy(sub_graph->nodes().begin(), sub_graph->nodes().end(), std::back_inserter(train_kernels));
}
}
// For LSTM, GPU operators are synchronized internally; hence we need to add a sync mechanism to the execution graph
for (auto k : train_kernels) {
if (k->type() == schema::PrimitiveType_LSTMGradWeight) {
// Find PrimitiveType_LSTMGradData that matches this PrimitiveType_LSTMGradWeight
for (auto mk : train_kernels) {
if (mk->type() == schema::PrimitiveType_LSTMGradData) {
if (k->in_tensors().at(C2NUM)->tensor_name() == mk->in_tensors().at(0)->tensor_name()) {
mk->AddOutKernel(k);
k->AddInKernel(mk);
}
}
}
}
}
std::queue<kernel::LiteKernel *> queue;
for (auto k : train_kernels) {
auto in_kernels = k->in_kernels();
if (in_kernels.size() == 0) {
queue.push(k);
}
}
std::unordered_map<kernel::LiteKernel *, int> map;
while (queue.size()) {
// pop first element
auto k = queue.front();
train_kernels_.push_back(k);
queue.pop();
for (auto &ok : k->out_kernels()) {
auto cnt_iter = map.find(ok);
if (cnt_iter == map.end()) {
int ref_cnt = ok->in_kernels().size();
map[ok] = ref_cnt;
}
cnt_iter = map.find(ok);
auto ref_cnt = cnt_iter->second - 1;
map[ok] = ref_cnt;
if (ref_cnt <= 0) {
queue.push(cnt_iter->first);
}
}
}


+ 2
- 1
mindspore/lite/tools/benchmark_train/net_train.cc View File

@@ -177,9 +177,10 @@ int NetTrain::CompareOutput(const session::LiteSession &lite_session) {
MS_LOG(ERROR) << "Cannot find output tensors, get model output failed";
return RET_ERROR;
}
std::map<std::string, mindspore::tensor::MSTensor *> ordered_outputs(tensors_list.begin(), tensors_list.end());
mindspore::tensor::MSTensor *tensor = nullptr;
int i = 1;
for (auto it = tensors_list.begin(); it != tensors_list.end(); ++it) {
for (auto it = ordered_outputs.begin(); it != ordered_outputs.end(); ++it) {
tensor = lite_session.GetOutputByTensorName(it->first);
std::cout << "output is tensor " << it->first << "\n";
auto outputs = tensor->data();


Loading…
Cancel
Save