@@ -123,7 +123,7 @@ CNodePtr CreateSoftmaxCrossEntropyWithLogits(const FuncGraphPtr &graph, const CN
   std::vector<size_t> labels_shape = AnfAlgo::GetOutputInferShape(one_hot_node, 0);
   std::vector<size_t> loss_shape;
-  if (labels_shape.size() > 0) {
+  if (!labels_shape.empty()) {
     loss_shape.emplace_back(labels_shape[0]);
   } else {
     MS_LOG(EXCEPTION) << "one_hot output's shape is empty.";
@@ -320,7 +320,7 @@ CNodePtr CreateRealDiv(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax
   if (labels_shape.size() != 1) {
     MS_LOG(EXCEPTION) << "label's shape should be 1-D.";
   }
-  float y_value = static_cast<float>(labels_shape[0]);
+  auto y_value = static_cast<float>(labels_shape[0]);
   auto y = std::make_shared<tensor::Tensor>(y_value, kFloat32);
   auto y_node = CreateValueNode(y, kNumberTypeFloat32);
   MS_EXCEPTION_IF_NULL(y_node);
@@ -436,10 +436,11 @@ const BaseRef GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::DefinePattern(
   VarPtr x1 = std::make_shared<Var>();
   VarPtr x2 = std::make_shared<Var>();
   VarPtr x3 = std::make_shared<Var>();
-  VarPtr x4 = std::make_shared<Var>();
+  VectorRef sparse_softmax_cross_entropy_with_logits_grad({prim::kPrimSparseSoftmaxCrossEntropyWithLogits, x1, x2});
   VectorRef sparse_softmax_cross_entropy_with_logits({prim::kPrimSparseSoftmaxCrossEntropyWithLogits, x1, x2});
-  VectorRef depend({prim::kPrimDepend, sparse_softmax_cross_entropy_with_logits, x3});
-  return VectorRef({prim::kPrimMul, depend, x4});
+  VectorRef depend(
+    {prim::kPrimDepend, sparse_softmax_cross_entropy_with_logits_grad, sparse_softmax_cross_entropy_with_logits});
+  return VectorRef({prim::kPrimMul, depend, x3});
 }
 
 const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(const FuncGraphPtr &graph,
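The pattern change above is the heart of the diff: the pass used to match Mul(Depend(SparseSoftmaxCrossEntropyWithLogits(x1, x2), x3), x4), where the Depend's second input was a free variable; it now matches Mul(Depend(grad, forward), x3), requiring the Depend to connect the grad instance of SparseSoftmaxCrossEntropyWithLogits to the forward instance over the same logits (x1) and labels (x2). Reusing x1 and x2 in both VectorRefs is what ties the two nodes together. A toy, self-contained sketch of that requirement, with hypothetical names (Node, MatchesGradPattern) rather than MindSpore's real PatternEngine:

// Toy illustration (not MindSpore code) of what the new pattern demands:
// the Depend's two inputs must be the same primitive over the same operands.
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct Node {
  std::string op;                             // primitive name, e.g. "SparseSoftmaxCE"
  std::vector<std::shared_ptr<Node>> inputs;  // operand edges
};
using NodePtr = std::shared_ptr<Node>;

NodePtr MakeNode(std::string op, std::vector<NodePtr> inputs) {
  return std::make_shared<Node>(Node{std::move(op), std::move(inputs)});
}

// Mirrors the new DefinePattern(): Mul(Depend(grad, forward), dout), where
// grad and forward are SparseSoftmaxCE nodes sharing the same logits/labels.
bool MatchesGradPattern(const NodePtr &root) {
  if (root->op != "Mul" || root->inputs.size() != 2) return false;
  const NodePtr &depend = root->inputs[0];
  if (depend->op != "Depend" || depend->inputs.size() != 2) return false;
  const NodePtr &grad = depend->inputs[0];
  const NodePtr &forward = depend->inputs[1];
  return grad->op == "SparseSoftmaxCE" && forward->op == "SparseSoftmaxCE" &&
         grad->inputs == forward->inputs;  // x1/x2 reuse: shared logits/labels
}

int main() {
  auto logits = MakeNode("logits", {});
  auto labels = MakeNode("labels", {});
  auto dout = MakeNode("dout", {});
  auto forward = MakeNode("SparseSoftmaxCE", {logits, labels});
  auto grad = MakeNode("SparseSoftmaxCE", {logits, labels});  // is_grad instance
  auto depend = MakeNode("Depend", {grad, forward});
  assert(MatchesGradPattern(MakeNode("Mul", {depend, dout})));
  return 0;
}

Because the forward node is now captured by the pattern, Process can rewrite both nodes in a single pass, which is what the added GetSparseNode(depend_node, 2) and manager->Replace(sparse_softmax_node, reduce_node) calls in the hunks below do.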
@@ -455,6 +456,7 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(con
   }
 
   auto depend_node = GetDependNode(mul_node);
+  auto sparse_softmax_node = GetSparseNode(depend_node, 2);
   auto sparse_softmax_node_grad = GetSparseNode(depend_node, 1);
   if (sparse_softmax_node_grad->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) {
     MS_LOG(EXCEPTION) << "Op SparseSoftmaxCrossEntropyWithLogits's input not equal "
@@ -467,6 +469,7 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(con
 
   std::vector<AnfNodePtr> softmax_node_outputs;
   CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs);
+  auto reduce_node = CreateReduceMean(graph, sparse_softmax_node_grad, softmax_node_outputs[0]);
   auto tile_node = CreateTile(graph, sparse_softmax_node_grad, mul_node);
   CNodePtr real_div_node;
   if (tile_node == nullptr) {
@@ -484,16 +487,22 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(con
+
   auto manager = graph->manager();
   MS_EXCEPTION_IF_NULL(manager);
-  manager->Replace(sparse_softmax_node_grad, softmax_node_outputs[1]);
+  manager->Replace(sparse_softmax_node, reduce_node);
   manager->Replace(mul_node, new_mul_node);
+  std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
+                                    NewValueNode(MakeValue<bool>(true)), NewValueNode(MakeValue<bool>(true))};
+  auto new_depend = graph->NewCNode(inputs);
+  manager->Replace(sparse_softmax_node_grad, new_depend);
   return new_mul_node;
 }
 
 const BaseRef GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::DefinePattern() const {
   VarPtr x1 = std::make_shared<Var>();
   VarPtr x2 = std::make_shared<Var>();
-  VarPtr x3 = std::make_shared<Var>();
+  VectorRef sparse_softmax_cross_entropy_with_logits_grad({prim::kPrimSparseSoftmaxCrossEntropyWithLogits, x1, x2});
   VectorRef sparse_softmax_cross_entropy_with_logits({prim::kPrimSparseSoftmaxCrossEntropyWithLogits, x1, x2});
-  return VectorRef({prim::kPrimDepend, sparse_softmax_cross_entropy_with_logits, x3});
+  return VectorRef(
+    {prim::kPrimDepend, sparse_softmax_cross_entropy_with_logits_grad, sparse_softmax_cross_entropy_with_logits});
 }
 
 const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::Process(const FuncGraphPtr &graph,
@@ -504,6 +513,7 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::Process(c
 
   auto depend_node = node->cast<CNodePtr>();
   auto sparse_softmax_node_grad = GetSparseNode(depend_node, 1);
+  auto sparse_softmax_node = GetSparseNode(depend_node, 2);
 
   CNodePtr softmax_node;
   auto one_hot_node = CreateOneHot(graph, sparse_softmax_node_grad);
@@ -511,11 +521,12 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::Process(c
 
   std::vector<AnfNodePtr> softmax_node_outputs;
   CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs);
+  auto reduce_node = CreateReduceMean(graph, sparse_softmax_node_grad, softmax_node_outputs[0]);
   auto mul_node = CreateMul(graph, sparse_softmax_node_grad, softmax_node_outputs[1]);
 
   auto manager = graph->manager();
   MS_EXCEPTION_IF_NULL(manager);
   manager->Replace(sparse_softmax_node_grad, softmax_node_outputs[1]);
+  manager->Replace(sparse_softmax_node, reduce_node);
   return mul_node;
 }
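For reference, the numerical contract the rewritten graph preserves: with mean reduction, SparseSoftmaxCrossEntropyWithLogits produces loss = mean over the batch of -log(softmax(logits)[label]), and its gradient (up to the upstream scale applied by the final Mul) is (softmax(logits) - one_hot(labels)) / N. The ReduceMean created for the forward node and the RealDiv whose divisor is y_value = labels_shape[0] correspond to the mean and the 1/N factor. A minimal, self-contained reference computation (plain C++, not MindSpore code):

// Reference sketch of softmax cross-entropy with mean reduction and its grad.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // Toy batch: N = 2 samples, C = 3 classes.
  const std::vector<std::vector<double>> logits = {{2.0, 1.0, 0.1}, {0.5, 2.5, 0.3}};
  const std::vector<int> labels = {0, 1};
  const double n = static_cast<double>(labels.size());  // the "y_value" divisor

  double loss_sum = 0.0;
  std::vector<std::vector<double>> grad(logits.size());
  for (size_t i = 0; i < logits.size(); ++i) {
    // Softmax row for sample i.
    double denom = 0.0;
    for (double v : logits[i]) denom += std::exp(v);
    grad[i].resize(logits[i].size());
    for (size_t c = 0; c < logits[i].size(); ++c) grad[i][c] = std::exp(logits[i][c]) / denom;
    // Cross-entropy of the true class, then turn the softmax row into the gradient.
    loss_sum += -std::log(grad[i][labels[i]]);
    grad[i][labels[i]] -= 1.0;                                    // softmax - one_hot
    for (size_t c = 0; c < grad[i].size(); ++c) grad[i][c] /= n;  // the RealDiv by N
  }
  std::printf("mean loss = %f\n", loss_sum / n);  // the ReduceMean output
  std::printf("grad[0][0] = %f\n", grad[0][0]);
  return 0;
}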