
fix sparse_softmax_cross_entropy_with_logits

tags/v1.2.0-rc1
jjfeing, 4 years ago
parent commit 2adff83c99
2 changed files with 21 additions and 10 deletions
1. +20 -9  mindspore/ccsrc/backend/optimizer/ascend/mindir/sparse_softmax_cross_entropy_with_logits_unify_mindir.cc
2. +1 -1   mindspore/ccsrc/backend/session/ascend_session.cc

+20 -9  mindspore/ccsrc/backend/optimizer/ascend/mindir/sparse_softmax_cross_entropy_with_logits_unify_mindir.cc

@@ -123,7 +123,7 @@ CNodePtr CreateSoftmaxCrossEntropyWithLogits(const FuncGraphPtr &graph, const CN

  std::vector<size_t> labels_shape = AnfAlgo::GetOutputInferShape(one_hot_node, 0);
  std::vector<size_t> loss_shape;
-  if (labels_shape.size() > 0) {
+  if (!labels_shape.empty()) {
    loss_shape.emplace_back(labels_shape[0]);
  } else {
    MS_LOG(EXCEPTION) << "one_hot output's shape is empty.";
@@ -320,7 +320,7 @@ CNodePtr CreateRealDiv(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax
  if (labels_shape.size() != 1) {
    MS_LOG(EXCEPTION) << "label's shape should be 1-D.";
  }
-  float y_value = static_cast<float>(labels_shape[0]);
+  auto y_value = static_cast<float>(labels_shape[0]);
  auto y = std::make_shared<tensor::Tensor>(y_value, kFloat32);
  auto y_node = CreateValueNode(y, kNumberTypeFloat32);
  MS_EXCEPTION_IF_NULL(y_node);
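A note on this hunk: labels_shape[0] is the batch size N, so the scalar tensor y that CreateRealDiv builds is N itself, apparently used to divide by the batch size and so produce the mean loss. A worked reading, assuming a hypothetical batch of 32:

    // labels_shape == {32}  =>  y_value == 32.0f
    auto y_value = static_cast<float>(labels_shape[0]);            // N, the batch size
    auto y = std::make_shared<tensor::Tensor>(y_value, kFloat32);  // scalar divisor
    // RealDiv(summed_values, y) then averages over the batch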
@@ -436,10 +436,11 @@ const BaseRef GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::DefinePattern(
  VarPtr x1 = std::make_shared<Var>();
  VarPtr x2 = std::make_shared<Var>();
  VarPtr x3 = std::make_shared<Var>();
-  VarPtr x4 = std::make_shared<Var>();
+  VectorRef sparse_softmax_cross_entropy_with_logits_grad({prim::kPrimSparseSoftmaxCrossEntropyWithLogits, x1, x2});
  VectorRef sparse_softmax_cross_entropy_with_logits({prim::kPrimSparseSoftmaxCrossEntropyWithLogits, x1, x2});
-  VectorRef depend({prim::kPrimDepend, sparse_softmax_cross_entropy_with_logits, x3});
-  return VectorRef({prim::kPrimMul, depend, x4});
+  VectorRef depend(
+    {prim::kPrimDepend, sparse_softmax_cross_entropy_with_logits_grad, sparse_softmax_cross_entropy_with_logits});
+  return VectorRef({prim::kPrimMul, depend, x3});
}
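This pattern change is the heart of the fix for this pass: instead of leaving the Depend's second input and the Mul's second input as free wildcards, the rewritten pattern requires both Depend inputs to be SparseSoftmaxCrossEntropyWithLogits nodes over the same logits and labels. A sketch of the two matched subgraphs (the grad/forward reading follows the variable names in the hunk):

    // before: Mul(Depend(SparseSoftmaxCrossEntropyWithLogits(x1, x2), x3), x4)
    // after:  Mul(Depend(grad(x1, x2), forward(x1, x2)), x3)
    //         where grad and forward are both SparseSoftmaxCrossEntropyWithLogits
    //         nodes sharing logits x1 and labels x2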

const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(const FuncGraphPtr &graph,
@@ -455,6 +456,7 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(con
}

  auto depend_node = GetDependNode(mul_node);
+  auto sparse_softmax_node = GetSparseNode(depend_node, 2);
  auto sparse_softmax_node_grad = GetSparseNode(depend_node, 1);
  if (sparse_softmax_node_grad->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) {
    MS_LOG(EXCEPTION) << "Op SparseSoftmaxCrossEntropyWithLogits's input not equal "
@@ -467,6 +469,7 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(con

  std::vector<AnfNodePtr> softmax_node_outputs;
  CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs);
+  auto reduce_node = CreateReduceMean(graph, sparse_softmax_node_grad, softmax_node_outputs[0]);
  auto tile_node = CreateTile(graph, sparse_softmax_node_grad, mul_node);
  CNodePtr real_div_node;
  if (tile_node == nullptr) {
@@ -484,16 +487,22 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(con

  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
-  manager->Replace(sparse_softmax_node_grad, softmax_node_outputs[1]);
+  manager->Replace(sparse_softmax_node, reduce_node);
  manager->Replace(mul_node, new_mul_node);
+  std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
+                                    NewValueNode(MakeValue<bool>(true)), NewValueNode(MakeValue<bool>(true))};
+  auto new_depend = graph->NewCNode(inputs);
+  manager->Replace(sparse_softmax_node_grad, new_depend);
  return new_mul_node;
}

const BaseRef GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::DefinePattern() const {
  VarPtr x1 = std::make_shared<Var>();
  VarPtr x2 = std::make_shared<Var>();
  VarPtr x3 = std::make_shared<Var>();
+  VectorRef sparse_softmax_cross_entropy_with_logits_grad({prim::kPrimSparseSoftmaxCrossEntropyWithLogits, x1, x2});
  VectorRef sparse_softmax_cross_entropy_with_logits({prim::kPrimSparseSoftmaxCrossEntropyWithLogits, x1, x2});
-  return VectorRef({prim::kPrimDepend, sparse_softmax_cross_entropy_with_logits, x3});
+  return VectorRef(
+    {prim::kPrimDepend, sparse_softmax_cross_entropy_with_logits_grad, sparse_softmax_cross_entropy_with_logits});
}
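Two things change in the Process hunk above. First, the forward node (Depend input 2) is now replaced by reduce_node, rather than the grad node's users being wired straight to softmax_node_outputs[1]. Second, the grad node itself is detached by replacing it with a freshly built Depend over two dummy boolean ValueNodes. A commented reading of that placeholder, assuming the usual CNode layout (primitive node first, operands after):

    // inputs[0]: the Depend primitive
    // inputs[1], inputs[2]: ValueNode(true) dummies, so the node carries no real data
    auto new_depend = graph->NewCNode(inputs);
    manager->Replace(sparse_softmax_node_grad, new_depend);  // detach the old grad node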

const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::Process(const FuncGraphPtr &graph,
@@ -504,6 +513,7 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::Process(c

  auto depend_node = node->cast<CNodePtr>();
  auto sparse_softmax_node_grad = GetSparseNode(depend_node, 1);
+  auto sparse_softmax_node = GetSparseNode(depend_node, 2);

  CNodePtr softmax_node;
  auto one_hot_node = CreateOneHot(graph, sparse_softmax_node_grad);
@@ -511,11 +521,12 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::Process(c

  std::vector<AnfNodePtr> softmax_node_outputs;
  CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs);
+  auto reduce_node = CreateReduceMean(graph, sparse_softmax_node_grad, softmax_node_outputs[0]);
  auto mul_node = CreateMul(graph, sparse_softmax_node_grad, softmax_node_outputs[1]);

  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
-  manager->Replace(sparse_softmax_node_grad, softmax_node_outputs[1]);
+  manager->Replace(sparse_softmax_node, reduce_node);
  return mul_node;
}
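Both Process rewrites now split the fused op's two roles explicitly: the forward SparseSoftmaxCrossEntropyWithLogits is replaced by reduce_node, a mean over the per-sample loss (SoftmaxCrossEntropyWithLogits output 0), while the gradient path consumes output 1 through mul_node. A sketch of the resulting subgraph, as suggested by the hunks (helper names from the code above, scale factor unspecified here):

    // forward : loss    = ReduceMean(SoftmaxCrossEntropyWithLogits(logits, OneHot(labels))[0])
    // backward: dlogits = Mul(SoftmaxCrossEntropyWithLogits(logits, OneHot(labels))[1], <scale>)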



+1 -1  mindspore/ccsrc/backend/session/ascend_session.cc

@@ -585,9 +585,9 @@ void AscendSession::UnifyMindIR(const KernelGraphPtr &graph) {
  if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
    unify_mindir_pm->AddPass(std::make_shared<opt::DropoutGradUnifyMindIR>());
    unify_mindir_pm->AddPass(std::make_shared<opt::DropoutUnifyMindIR>());
-    unify_mindir_pm->AddPass(std::make_shared<opt::SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>());
    unify_mindir_pm->AddPass(std::make_shared<opt::GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>());
    unify_mindir_pm->AddPass(std::make_shared<opt::GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2>());
+    unify_mindir_pm->AddPass(std::make_shared<opt::SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>());
  } else {
    unify_mindir_pm->AddPass(std::make_shared<opt::DropoutUnifyMindIRPynative>());
    unify_mindir_pm->AddPass(std::make_shared<opt::DropoutGradUnifyMindIRPynative>());
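The only change in ascend_session.cc is ordering: the two Grad passes are now registered before SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR. This ordering is presumably what makes the new patterns reachable, since each pass rewrites the graph in place:

    // If the forward-only pass ran first, it would expand the raw
    // SparseSoftmaxCrossEntropyWithLogits nodes, and the Grad patterns above,
    // which match that primitive inside a Depend, could never fire.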

