From 2adff83c99b075f6dc11f2be0497bae67f25de48 Mon Sep 17 00:00:00 2001 From: jjfeing Date: Mon, 25 Jan 2021 09:49:42 +0800 Subject: [PATCH] fix sparse_softmax_cross_entropy_with_logits --- ..._cross_entropy_with_logits_unify_mindir.cc | 29 +++++++++++++------ .../ccsrc/backend/session/ascend_session.cc | 2 +- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/ascend/mindir/sparse_softmax_cross_entropy_with_logits_unify_mindir.cc b/mindspore/ccsrc/backend/optimizer/ascend/mindir/sparse_softmax_cross_entropy_with_logits_unify_mindir.cc index d4929b85cd..f937963b15 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/mindir/sparse_softmax_cross_entropy_with_logits_unify_mindir.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/mindir/sparse_softmax_cross_entropy_with_logits_unify_mindir.cc @@ -123,7 +123,7 @@ CNodePtr CreateSoftmaxCrossEntropyWithLogits(const FuncGraphPtr &graph, const CN std::vector labels_shape = AnfAlgo::GetOutputInferShape(one_hot_node, 0); std::vector loss_shape; - if (labels_shape.size() > 0) { + if (!labels_shape.empty()) { loss_shape.emplace_back(labels_shape[0]); } else { MS_LOG(EXCEPTION) << "one_hot output's shape is empty."; @@ -320,7 +320,7 @@ CNodePtr CreateRealDiv(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax if (labels_shape.size() != 1) { MS_LOG(EXCEPTION) << "label's shape should be 1-D."; } - float y_value = static_cast(labels_shape[0]); + auto y_value = static_cast(labels_shape[0]); auto y = std::make_shared(y_value, kFloat32); auto y_node = CreateValueNode(y, kNumberTypeFloat32); MS_EXCEPTION_IF_NULL(y_node); @@ -436,10 +436,11 @@ const BaseRef GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::DefinePattern( VarPtr x1 = std::make_shared(); VarPtr x2 = std::make_shared(); VarPtr x3 = std::make_shared(); - VarPtr x4 = std::make_shared(); + VectorRef sparse_softmax_cross_entropy_with_logits_grad({prim::kPrimSparseSoftmaxCrossEntropyWithLogits, x1, x2}); VectorRef sparse_softmax_cross_entropy_with_logits({prim::kPrimSparseSoftmaxCrossEntropyWithLogits, x1, x2}); - VectorRef depend({prim::kPrimDepend, sparse_softmax_cross_entropy_with_logits, x3}); - return VectorRef({prim::kPrimMul, depend, x4}); + VectorRef depend( + {prim::kPrimDepend, sparse_softmax_cross_entropy_with_logits_grad, sparse_softmax_cross_entropy_with_logits}); + return VectorRef({prim::kPrimMul, depend, x3}); } const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(const FuncGraphPtr &graph, @@ -455,6 +456,7 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(con } auto depend_node = GetDependNode(mul_node); + auto sparse_softmax_node = GetSparseNode(depend_node, 2); auto sparse_softmax_node_grad = GetSparseNode(depend_node, 1); if (sparse_softmax_node_grad->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) { MS_LOG(EXCEPTION) << "Op SparseSoftmaxCrossEntropyWithLogits's input not equal " @@ -467,6 +469,7 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(con std::vector softmax_node_outputs; CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs); + auto reduce_node = CreateReduceMean(graph, sparse_softmax_node_grad, softmax_node_outputs[0]); auto tile_node = CreateTile(graph, sparse_softmax_node_grad, mul_node); CNodePtr real_div_node; if (tile_node == nullptr) { @@ -484,16 +487,22 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(con auto manager = graph->manager(); MS_EXCEPTION_IF_NULL(manager); - manager->Replace(sparse_softmax_node_grad, softmax_node_outputs[1]); + manager->Replace(sparse_softmax_node, reduce_node); + manager->Replace(mul_node, new_mul_node); + std::vector inputs = {NewValueNode(std::make_shared(prim::kPrimDepend->name())), + NewValueNode(MakeValue(true)), NewValueNode(MakeValue(true))}; + auto new_depend = graph->NewCNode(inputs); + manager->Replace(sparse_softmax_node_grad, new_depend); return new_mul_node; } const BaseRef GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::DefinePattern() const { VarPtr x1 = std::make_shared(); VarPtr x2 = std::make_shared(); - VarPtr x3 = std::make_shared(); + VectorRef sparse_softmax_cross_entropy_with_logits_grad({prim::kPrimSparseSoftmaxCrossEntropyWithLogits, x1, x2}); VectorRef sparse_softmax_cross_entropy_with_logits({prim::kPrimSparseSoftmaxCrossEntropyWithLogits, x1, x2}); - return VectorRef({prim::kPrimDepend, sparse_softmax_cross_entropy_with_logits, x3}); + return VectorRef( + {prim::kPrimDepend, sparse_softmax_cross_entropy_with_logits_grad, sparse_softmax_cross_entropy_with_logits}); } const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::Process(const FuncGraphPtr &graph, @@ -504,6 +513,7 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::Process(c auto depend_node = node->cast(); auto sparse_softmax_node_grad = GetSparseNode(depend_node, 1); + auto sparse_softmax_node = GetSparseNode(depend_node, 2); CNodePtr softmax_node; auto one_hot_node = CreateOneHot(graph, sparse_softmax_node_grad); @@ -511,11 +521,12 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::Process(c std::vector softmax_node_outputs; CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs); + auto reduce_node = CreateReduceMean(graph, sparse_softmax_node_grad, softmax_node_outputs[0]); auto mul_node = CreateMul(graph, sparse_softmax_node_grad, softmax_node_outputs[1]); auto manager = graph->manager(); MS_EXCEPTION_IF_NULL(manager); - manager->Replace(sparse_softmax_node_grad, softmax_node_outputs[1]); + manager->Replace(sparse_softmax_node, reduce_node); return mul_node; } diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index 452542091c..ec4d11e11d 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -585,9 +585,9 @@ void AscendSession::UnifyMindIR(const KernelGraphPtr &graph) { if (ms_context->get_param(MS_CTX_EXECUTION_MODE) == kGraphMode) { unify_mindir_pm->AddPass(std::make_shared()); unify_mindir_pm->AddPass(std::make_shared()); - unify_mindir_pm->AddPass(std::make_shared()); unify_mindir_pm->AddPass(std::make_shared()); unify_mindir_pm->AddPass(std::make_shared()); + unify_mindir_pm->AddPass(std::make_shared()); } else { unify_mindir_pm->AddPass(std::make_shared()); unify_mindir_pm->AddPass(std::make_shared());