Merge pull request !32777 from jianghui58/train_r1.7r1.7
| @@ -100,10 +100,10 @@ void FusedBatchNormFp32(const float *input, const float *scale, const float *off | |||||
| } | } | ||||
| void FusedBatchNormFp32MeanVar(const float *input, float *run_mean, float *run_var, const BatchNormParameter *param, | void FusedBatchNormFp32MeanVar(const float *input, float *run_mean, float *run_var, const BatchNormParameter *param, | ||||
| float *save_mean, float *save_var) { | |||||
| float *save_mean, float *save_var, bool isBatchNorm2d) { | |||||
| const float N = (float)param->unit_; | const float N = (float)param->unit_; | ||||
| const float VN = N; | const float VN = N; | ||||
| const float VNUB = (N > 1.0f) ? (N - 1.0f) : 1.0f; | |||||
| const float VNUB = (isBatchNorm2d == false) ? N : ((N > 1.0f) ? (N - 1.0f) : 1.0f); | |||||
| const float momentum = (1.0f - param->momentum_); | const float momentum = (1.0f - param->momentum_); | ||||
| for (int i = 0; i < param->unit_; i++) { | for (int i = 0; i < param->unit_; i++) { | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_NNACL_FP32_BATCHNORM_H_ | |||||
| #define MINDSPORE_NNACL_FP32_BATCHNORM_H_ | |||||
| #ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_BATCHNORM_FP32_H_ | |||||
| #define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_BATCHNORM_FP32_H_ | |||||
| #include "nnacl/batchnorm_parameter.h" | #include "nnacl/batchnorm_parameter.h" | ||||
| @@ -30,9 +30,9 @@ void FusedBatchNormFp32(const float *input, const float *scale, const float *off | |||||
| const float *variance, const BatchNormParameter *param, int task_id, float *output); | const float *variance, const BatchNormParameter *param, int task_id, float *output); | ||||
| void FusedBatchNormFp32MeanVar(const float *input, float *run_mean, float *run_var, const BatchNormParameter *param, | void FusedBatchNormFp32MeanVar(const float *input, float *run_mean, float *run_var, const BatchNormParameter *param, | ||||
| float *save_mean, float *save_var); | |||||
| float *save_mean, float *save_var, bool isBatchNorm2d); | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| } | } | ||||
| #endif | #endif | ||||
| #endif // MINDSPORE_NNACL_FUSED_BATCHNORM_H_ | |||||
| #endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_BATCHNORM_FP32_H_ | |||||
| @@ -161,6 +161,7 @@ int FusedBatchnormCPUKernel::Run() { | |||||
| float *current_var = static_cast<float *>(variance_); | float *current_var = static_cast<float *>(variance_); | ||||
| float *save_mean = static_cast<float *>(in_tensors_.at(FOURTH_INPUT)->data()); | float *save_mean = static_cast<float *>(in_tensors_.at(FOURTH_INPUT)->data()); | ||||
| float *save_variance = static_cast<float *>(in_tensors_.at(FIFTH_INPUT)->data()); | float *save_variance = static_cast<float *>(in_tensors_.at(FIFTH_INPUT)->data()); | ||||
| bool isBatch2d = true; | |||||
| if (in == nullptr || scale == nullptr || offset == nullptr || current_mean == nullptr || current_var == nullptr || | if (in == nullptr || scale == nullptr || offset == nullptr || current_mean == nullptr || current_var == nullptr || | ||||
| save_mean == nullptr || save_variance == nullptr) { | save_mean == nullptr || save_variance == nullptr) { | ||||
| MS_LOG(ERROR) << "The input data is nullptr."; | MS_LOG(ERROR) << "The input data is nullptr."; | ||||
| @@ -168,8 +169,9 @@ int FusedBatchnormCPUKernel::Run() { | |||||
| } | } | ||||
| std::fill(current_mean, current_mean + in_tensors_.at(FOURTH_INPUT)->ElementsNum(), 0.f); | std::fill(current_mean, current_mean + in_tensors_.at(FOURTH_INPUT)->ElementsNum(), 0.f); | ||||
| std::fill(current_var, current_var + in_tensors_.at(FIFTH_INPUT)->ElementsNum(), 0.f); | std::fill(current_var, current_var + in_tensors_.at(FIFTH_INPUT)->ElementsNum(), 0.f); | ||||
| if (in_tensors_.at(FIRST_INPUT)->shape().size() == C2NUM) isBatch2d = false; | |||||
| FusedBatchNormFp32MeanVar(in, current_mean, current_var, param, static_cast<float *>(save_mean), | FusedBatchNormFp32MeanVar(in, current_mean, current_var, param, static_cast<float *>(save_mean), | ||||
| static_cast<float *>(save_variance)); | |||||
| static_cast<float *>(save_variance), isBatch2d); | |||||
| CHECK_NULL_RETURN(out_tensors_.at(SECOND_INPUT)->data()); | CHECK_NULL_RETURN(out_tensors_.at(SECOND_INPUT)->data()); | ||||
| CHECK_NULL_RETURN(out_tensors_.at(THIRD_INPUT)->data()); | CHECK_NULL_RETURN(out_tensors_.at(THIRD_INPUT)->data()); | ||||
| @@ -656,6 +656,11 @@ void TrainSession::CompileEvalOutputs() { | |||||
| if (IsLossKernel(kernel) && !(IsGradKernel(kernel))) { | if (IsLossKernel(kernel) && !(IsGradKernel(kernel))) { | ||||
| for (auto in_kernel : kernel->in_kernels()) { | for (auto in_kernel : kernel->in_kernels()) { | ||||
| if (IsLossKernel(in_kernel) || IsGradKernel(in_kernel)) continue; | if (IsLossKernel(in_kernel) || IsGradKernel(in_kernel)) continue; | ||||
| bool is_loss = false; | |||||
| for (auto in_in_kernel : in_kernel->in_kernels()) { | |||||
| if (IsLossKernel(in_in_kernel)) is_loss = true; | |||||
| } | |||||
| if (is_loss) continue; | |||||
| // insert if not already in | // insert if not already in | ||||
| if (eval_output_node_map_.find(in_kernel->name()) == eval_output_node_map_.end()) { | if (eval_output_node_map_.find(in_kernel->name()) == eval_output_node_map_.end()) { | ||||
| auto *ms_tensor = in_kernel->out_tensors().at(0); | auto *ms_tensor = in_kernel->out_tensors().at(0); | ||||
| @@ -726,13 +731,53 @@ void TrainSession::BuildInferenceKernelsRecursive(kernel::LiteKernel *kernel, st | |||||
| void TrainSession::CompileTrainKernels() { | void TrainSession::CompileTrainKernels() { | ||||
| train_kernels_.clear(); | train_kernels_.clear(); | ||||
| std::vector<kernel::LiteKernel *> train_kernels; | |||||
| for (auto ori_kernel : kernels_) { | for (auto ori_kernel : kernels_) { | ||||
| if (ori_kernel->subgraph_type() == kernel::kNotSubGraph) { | if (ori_kernel->subgraph_type() == kernel::kNotSubGraph) { | ||||
| train_kernels_.push_back(ori_kernel); | |||||
| train_kernels.push_back(ori_kernel); | |||||
| } else { | } else { | ||||
| auto sub_graph = reinterpret_cast<kernel::SubGraphKernel *>(ori_kernel); | auto sub_graph = reinterpret_cast<kernel::SubGraphKernel *>(ori_kernel); | ||||
| for (auto kernel : sub_graph->nodes()) { | |||||
| train_kernels_.push_back(kernel); | |||||
| std::copy(sub_graph->nodes().begin(), sub_graph->nodes().end(), std::back_inserter(train_kernels)); | |||||
| } | |||||
| } | |||||
| // LSTM GPU operators are synchronized internally, hence we need to add a sync mechanism to the execution graph | |||||
| for (auto k : train_kernels) { | |||||
| if (k->type() == schema::PrimitiveType_LSTMGradWeight) { | |||||
| // Find PrimitiveType_LSTMGradData that matches this PrimitiveType_LSTMGradWeight | |||||
| for (auto mk : train_kernels) { | |||||
| if (mk->type() == schema::PrimitiveType_LSTMGradData) { | |||||
| if (k->in_tensors().at(C2NUM)->tensor_name() == mk->in_tensors().at(0)->tensor_name()) { | |||||
| mk->AddOutKernel(k); | |||||
| k->AddInKernel(mk); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| std::queue<kernel::LiteKernel *> queue; | |||||
| for (auto k : train_kernels) { | |||||
| auto in_kernels = k->in_kernels(); | |||||
| if (in_kernels.size() == 0) { | |||||
| queue.push(k); | |||||
| } | |||||
| } | |||||
| std::unordered_map<kernel::LiteKernel *, int> map; | |||||
| while (queue.size()) { | |||||
| // pop first element | |||||
| auto k = queue.front(); | |||||
| train_kernels_.push_back(k); | |||||
| queue.pop(); | |||||
| for (auto &ok : k->out_kernels()) { | |||||
| auto cnt_iter = map.find(ok); | |||||
| if (cnt_iter == map.end()) { | |||||
| int ref_cnt = ok->in_kernels().size(); | |||||
| map[ok] = ref_cnt; | |||||
| } | |||||
| cnt_iter = map.find(ok); | |||||
| auto ref_cnt = cnt_iter->second - 1; | |||||
| map[ok] = ref_cnt; | |||||
| if (ref_cnt <= 0) { | |||||
| queue.push(cnt_iter->first); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -177,9 +177,10 @@ int NetTrain::CompareOutput(const session::LiteSession &lite_session) { | |||||
| MS_LOG(ERROR) << "Cannot find output tensors, get model output failed"; | MS_LOG(ERROR) << "Cannot find output tensors, get model output failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| std::map<std::string, mindspore::tensor::MSTensor *> ordered_outputs(tensors_list.begin(), tensors_list.end()); | |||||
| mindspore::tensor::MSTensor *tensor = nullptr; | mindspore::tensor::MSTensor *tensor = nullptr; | ||||
| int i = 1; | int i = 1; | ||||
| for (auto it = tensors_list.begin(); it != tensors_list.end(); ++it) { | |||||
| for (auto it = ordered_outputs.begin(); it != ordered_outputs.end(); ++it) { | |||||
| tensor = lite_session.GetOutputByTensorName(it->first); | tensor = lite_session.GetOutputByTensorName(it->first); | ||||
| std::cout << "output is tensor " << it->first << "\n"; | std::cout << "output is tensor " << it->first << "\n"; | ||||
| auto outputs = tensor->data(); | auto outputs = tensor->data(); | ||||