|
|
|
@@ -50,7 +50,8 @@ ValueNodePtr CreateValueNode(const ValuePtr &value_ptr, TypeId output_type) { |
|
|
|
return new_node; |
|
|
|
} |
|
|
|
|
|
|
|
CNodePtr CreateOneHot(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node, bool is_pynative = false) { |
|
|
|
CNodePtr CreateOneHot(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node, |
|
|
|
bool is_convert_const_to_attr = false) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
MS_EXCEPTION_IF_NULL(sparse_softmax_node); |
|
|
|
|
|
|
|
@@ -80,7 +81,7 @@ CNodePtr CreateOneHot(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_ |
|
|
|
one_hot_primitive->set_attr(kAttrOutputNames, MakeValue(output_names)); |
|
|
|
|
|
|
|
std::vector<AnfNodePtr> one_hot_inputs; |
|
|
|
if (is_pynative) { |
|
|
|
if (is_convert_const_to_attr) { |
|
|
|
one_hot_inputs = {NewValueNode(one_hot_primitive), sparse_softmax_node->input(2), value_on_node, value_off_node}; |
|
|
|
} else { |
|
|
|
auto depth_node = NewValueNode(depth); |
|
|
|
@@ -97,7 +98,7 @@ CNodePtr CreateOneHot(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_ |
|
|
|
std::vector<size_t> labels_shape = AnfAlgo ::GetPrevNodeOutputInferShape(sparse_softmax_node, 1); |
|
|
|
labels_shape.emplace_back(depth); |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {labels_shape}, one_hot_node.get()); |
|
|
|
if (is_pynative) { |
|
|
|
if (is_convert_const_to_attr) { |
|
|
|
AnfAlgo::SetNodeAttr(kAttrDepth, MakeValue(depth), one_hot_node); |
|
|
|
} |
|
|
|
return one_hot_node; |
|
|
|
@@ -252,7 +253,7 @@ CNodePtr CreateExpandDimsPynative(const FuncGraphPtr &graph, const CNodePtr &rea |
|
|
|
} |
|
|
|
|
|
|
|
CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node, const CNodePtr &mul_node, |
|
|
|
bool is_pynative = false) { |
|
|
|
bool is_convert_const_to_attr = false) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
MS_EXCEPTION_IF_NULL(sparse_softmax_node); |
|
|
|
MS_EXCEPTION_IF_NULL(mul_node); |
|
|
|
@@ -268,6 +269,9 @@ CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_no |
|
|
|
std::vector<int64_t> multiple_value; |
|
|
|
std::transform(labels_shape.begin(), labels_shape.end(), std::back_inserter(multiple_value), |
|
|
|
[](size_t label) { return static_cast<int64_t>(label); }); |
|
|
|
if (std::all_of(multiple_value.begin(), multiple_value.end(), [](int64_t value) { return value == 1; })) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto multiples = MakeValue(multiple_value); |
|
|
|
auto multiples_node = CreateValueNode(multiples, kNumberTypeInt64); |
|
|
|
MS_EXCEPTION_IF_NULL(multiples_node); |
|
|
|
@@ -279,7 +283,7 @@ CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_no |
|
|
|
tile_primitive->set_attr(kAttrOutputNames, MakeValue(output_names)); |
|
|
|
|
|
|
|
std::vector<AnfNodePtr> tile_inputs; |
|
|
|
if (is_pynative) { |
|
|
|
if (is_convert_const_to_attr) { |
|
|
|
tile_inputs = {NewValueNode(tile_primitive), mul_node->input(2)}; |
|
|
|
} else { |
|
|
|
auto kernel_graph = graph->cast<KernelGraphPtr>(); |
|
|
|
@@ -292,7 +296,7 @@ CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_no |
|
|
|
tile_node->set_scope(mul_node->scope()); |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetPrevNodeOutputInferDataType(mul_node, 1)}, {labels_shape}, |
|
|
|
tile_node.get()); |
|
|
|
if (is_pynative) { |
|
|
|
if (is_convert_const_to_attr) { |
|
|
|
AnfAlgo::SetNodeAttr(kAttrMultiples, MakeValue(multiples), tile_node); |
|
|
|
} |
|
|
|
// feature map set |
|
|
|
@@ -302,7 +306,7 @@ CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_no |
|
|
|
return tile_node; |
|
|
|
} |
|
|
|
|
|
|
|
CNodePtr CreateRealDiv(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node, const CNodePtr &tile_node) { |
|
|
|
CNodePtr CreateRealDiv(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node, const AnfNodePtr &tile_node) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
MS_EXCEPTION_IF_NULL(sparse_softmax_node); |
|
|
|
MS_EXCEPTION_IF_NULL(tile_node); |
|
|
|
@@ -464,16 +468,24 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(con |
|
|
|
std::vector<AnfNodePtr> softmax_node_outputs; |
|
|
|
CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs); |
|
|
|
auto tile_node = CreateTile(graph, sparse_softmax_node_grad, mul_node); |
|
|
|
auto real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, tile_node); |
|
|
|
CNodePtr real_div_node; |
|
|
|
if (tile_node == nullptr) { |
|
|
|
real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, mul_node->input(2)); |
|
|
|
} else { |
|
|
|
real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, tile_node); |
|
|
|
} |
|
|
|
auto expand_dims_node = CreateExpandDims(graph, real_div_node); |
|
|
|
|
|
|
|
mul_node->set_input(1, softmax_node_outputs[1]); |
|
|
|
mul_node->set_input(2, expand_dims_node); |
|
|
|
std::vector<AnfNodePtr> new_mul_inputs = {NewValueNode(std::make_shared<Primitive>(kMulOpName)), |
|
|
|
softmax_node_outputs[1], expand_dims_node}; |
|
|
|
auto new_mul_node = graph->NewCNode(new_mul_inputs); |
|
|
|
MS_EXCEPTION_IF_NULL(new_mul_node); |
|
|
|
new_mul_node->set_scope(mul_node->scope()); |
|
|
|
new_mul_node->set_abstract(mul_node->abstract()); |
|
|
|
|
|
|
|
auto manager = graph->manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
manager->Replace(sparse_softmax_node_grad, softmax_node_outputs[1]); |
|
|
|
return mul_node; |
|
|
|
return new_mul_node; |
|
|
|
} |
|
|
|
|
|
|
|
const BaseRef GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::DefinePattern() const { |
|
|
|
@@ -563,19 +575,26 @@ const AnfNodePtr PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Pro |
|
|
|
} |
|
|
|
|
|
|
|
CNodePtr softmax_node; |
|
|
|
auto one_hot_node = CreateOneHot(graph, sparse_softmax_node_grad, true); |
|
|
|
auto one_hot_node = CreateOneHot(graph, sparse_softmax_node_grad); |
|
|
|
softmax_node = CreateSoftmaxCrossEntropyWithLogits(graph, sparse_softmax_node_grad, one_hot_node); |
|
|
|
|
|
|
|
std::vector<AnfNodePtr> softmax_node_outputs; |
|
|
|
CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs); |
|
|
|
auto tile_node = CreateTile(graph, sparse_softmax_node_grad, mul_node, true); |
|
|
|
auto real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, tile_node); |
|
|
|
auto tile_node = CreateTile(graph, sparse_softmax_node_grad, mul_node); |
|
|
|
CNodePtr real_div_node; |
|
|
|
if (tile_node == nullptr) { |
|
|
|
real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, mul_node->input(2)); |
|
|
|
} else { |
|
|
|
real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, tile_node); |
|
|
|
} |
|
|
|
auto expand_dims_node = CreateExpandDimsPynative(graph, real_div_node); |
|
|
|
|
|
|
|
mul_node->set_input(1, softmax_node_outputs[1]); |
|
|
|
mul_node->set_input(2, expand_dims_node); |
|
|
|
|
|
|
|
return mul_node; |
|
|
|
std::vector<AnfNodePtr> new_mul_inputs = {NewValueNode(std::make_shared<Primitive>(kMulOpName)), |
|
|
|
softmax_node_outputs[1], expand_dims_node}; |
|
|
|
auto new_mul_node = graph->NewCNode(new_mul_inputs); |
|
|
|
MS_EXCEPTION_IF_NULL(new_mul_node); |
|
|
|
new_mul_node->set_scope(mul_node->scope()); |
|
|
|
new_mul_node->set_abstract(mul_node->abstract()); |
|
|
|
return new_mul_node; |
|
|
|
} |
|
|
|
} // namespace opt |
|
|
|
} // namespace mindspore |