@@ -27,6 +27,9 @@
 #include "ir/dtype/type.h"
 
 constexpr auto softmax_output_shape_size = 2;
+constexpr auto kAttrDepth = "depth";
+constexpr auto kAttrMultiples = "multiples";
+constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
 namespace mindspore {
 namespace opt {
 namespace {
@@ -47,12 +50,12 @@ ValueNodePtr CreateValueNode(const ValuePtr &value_ptr, TypeId output_type) {
   return new_node;
 }
 
-CNodePtr CreateOneHot(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node) {
+CNodePtr CreateOneHot(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node, bool is_pynative = false) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(sparse_softmax_node);
 
   std::vector<size_t> logits_shape = AnfAlgo::GetPrevNodeOutputInferShape(sparse_softmax_node, 0);
-  int64_t depth;
+  int64_t depth = 0;
   if (logits_shape.size() >= 1) {
     size_t index = logits_shape.size() - 1;
     depth = logits_shape[index];
@@ -66,33 +69,37 @@ CNodePtr CreateOneHot(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_
   auto value_off = std::make_shared<tensor::Tensor>(0.0, kFloat32);
   auto value_off_node = CreateValueNode(value_off, kNumberTypeFloat32);
   MS_EXCEPTION_IF_NULL(value_off_node);
 
   auto kernel_graph = graph->cast<KernelGraphPtr>();
   kernel_graph->AddValueNodeToGraph(value_on_node);
   kernel_graph->AddValueNodeToGraph(value_off_node);
 
-  auto depth_node = NewValueNode(depth);
-  MS_EXCEPTION_IF_NULL(depth_node);
-
-  auto depth_abstract = std::make_shared<abstract::AbstractScalar>();
-  depth_abstract->set_type(kInt64);
-  depth_node->set_abstract(depth_abstract);
-
   auto one_hot_primitive = std::make_shared<Primitive>(kOneHotOpName);
   std::vector<std::string> input_names = {"indices", "depth", "on_value", "off_value"};
   std::vector<std::string> output_names = {"output"};
   one_hot_primitive->set_attr(kAttrInputNames, MakeValue(input_names));
   one_hot_primitive->set_attr(kAttrOutputNames, MakeValue(output_names));
-  std::vector<AnfNodePtr> one_hot_inputs = {NewValueNode(one_hot_primitive), sparse_softmax_node->input(2), depth_node,
-                                            value_on_node, value_off_node};
+
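+  // PyNative mode passes depth through the kAttrDepth attribute (set below), so
+  // the OneHot inputs omit the depth value node; graph mode keeps it as an input.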
+  std::vector<AnfNodePtr> one_hot_inputs;
+  if (is_pynative) {
+    one_hot_inputs = {NewValueNode(one_hot_primitive), sparse_softmax_node->input(2), value_on_node, value_off_node};
+  } else {
+    auto depth_node = NewValueNode(depth);
+    MS_EXCEPTION_IF_NULL(depth_node);
+    auto depth_abstract = std::make_shared<abstract::AbstractScalar>();
+    depth_abstract->set_type(kInt64);
+    depth_node->set_abstract(depth_abstract);
+    one_hot_inputs = {NewValueNode(one_hot_primitive), sparse_softmax_node->input(2), depth_node, value_on_node,
+                      value_off_node};
+  }
   auto one_hot_node = graph->NewCNode(one_hot_inputs);
   MS_EXCEPTION_IF_NULL(one_hot_node);
 
   one_hot_node->set_scope(sparse_softmax_node->scope());
   std::vector<size_t> labels_shape = AnfAlgo::GetPrevNodeOutputInferShape(sparse_softmax_node, 1);
   labels_shape.emplace_back(depth);
   AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {labels_shape}, one_hot_node.get());
   AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(-1), one_hot_node);
+  if (is_pynative) {
+    AnfAlgo::SetNodeAttr(kAttrDepth, MakeValue(depth), one_hot_node);
+  }
   return one_hot_node;
 }
@@ -106,9 +113,6 @@ CNodePtr CreateSoftmaxCrossEntropyWithLogits(const FuncGraphPtr &graph, const CN
     MS_LOG(EXCEPTION) << "sparse_softmax_cross_entropy_with_logits's input size not equal "
                       << kSparseSoftmaxCrossEntropyWithLogitsInputNum;
   }
-  if (one_hot_node->size() != kOneHotInputNum) {
-    MS_LOG(EXCEPTION) << "ont_hot's input size not equal " << kOneHotInputNum;
-  }
 
   std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kSoftmaxCrossEntropyWithLogitsOpName)),
                                     sparse_softmax_node->input(1), one_hot_node};
@@ -131,7 +135,7 @@ CNodePtr CreateSoftmaxCrossEntropyWithLogits(const FuncGraphPtr &graph, const CN
   return softmax_node;
 }
 
-ValueNodePtr GetAxis(const AnfNodePtr &node) {
+std::vector<int64_t> GetAxis(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   std::vector<size_t> output_shape = AnfAlgo::GetOutputInferShape(node, 0);
   if (output_shape.empty()) {
@@ -141,13 +145,19 @@ ValueNodePtr GetAxis(const AnfNodePtr &node) {
   for (size_t i = 0; i < output_shape.size(); i++) {
     range.emplace_back(i);
   }
+  return range;
+}
+
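+// Wraps the reduction axes of |node| into an int64 value node; used as the
+// ReduceMean axis input in graph mode.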
+ValueNodePtr GetAxisNode(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto range = GetAxis(node);
   auto axis_node = CreateValueNode(MakeValue(range), kNumberTypeInt64);
   MS_EXCEPTION_IF_NULL(axis_node);
   return axis_node;
 }
 
 CNodePtr CreateReduceMean(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node,
-                          const AnfNodePtr &softmax_output_node) {
+                          const AnfNodePtr &softmax_output_node, bool is_pynative = false) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(sparse_softmax_node);
   MS_EXCEPTION_IF_NULL(softmax_output_node);
@@ -155,10 +165,10 @@ CNodePtr CreateReduceMean(const FuncGraphPtr &graph, const CNodePtr &sparse_soft
     MS_LOG(EXCEPTION) << "sparse_softmax_cross_entropy_with_logits's input size not equal "
                       << kSparseSoftmaxCrossEntropyWithLogitsInputNum;
   }
-  auto axis_node = GetAxis(softmax_output_node);
+
+  auto axis_value = GetAxis(softmax_output_node);
+  auto axis_node = GetAxisNode(softmax_output_node);
   MS_EXCEPTION_IF_NULL(axis_node);
-  auto kernel_graph = graph->cast<KernelGraphPtr>();
-  kernel_graph->AddValueNodeToGraph(axis_node);
 
   auto reduce_primitive = std::make_shared<Primitive>(kReduceMeanOpName);
   std::vector<std::string> input_names = {"x", "axis"};
@@ -166,14 +176,23 @@ CNodePtr CreateReduceMean(const FuncGraphPtr &graph, const CNodePtr &sparse_soft
   reduce_primitive->set_attr(kAttrInputNames, MakeValue(input_names));
   reduce_primitive->set_attr(kAttrOutputNames, MakeValue(output_names));
 
-  std::vector<AnfNodePtr> inputs = {NewValueNode(reduce_primitive), softmax_output_node, axis_node};
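+  // PyNative mode reads the axes from the kAttrAxis attribute (set below), so
+  // the axis value node is wired in as an input only in graph mode.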
+  auto kernel_graph = graph->cast<KernelGraphPtr>();
+  std::vector<AnfNodePtr> inputs;
+  if (is_pynative) {
+    inputs = {NewValueNode(reduce_primitive), softmax_output_node};
+  } else {
+    kernel_graph->AddValueNodeToGraph(axis_node);
+    inputs = {NewValueNode(reduce_primitive), softmax_output_node, axis_node};
+  }
   auto reduce_node = graph->NewCNode(inputs);
   MS_EXCEPTION_IF_NULL(reduce_node);
 
   reduce_node->set_scope(sparse_softmax_node->scope());
   auto reduce_abstract = softmax_output_node->abstract();
   reduce_abstract->set_shape(std::make_shared<abstract::Shape>());
   reduce_node->set_abstract(reduce_abstract);
+  if (is_pynative) {
+    AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis_value), reduce_node);
+  }
   return reduce_node;
 }
@@ -207,8 +226,33 @@ CNodePtr CreateExpandDims(const FuncGraphPtr &graph, const CNodePtr &real_div_no
                                       expand_dims_node.get());
   return expand_dims_node;
 }
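+// PyNative variant of CreateExpandDims: the expand axis is attached as the
+// kAttrAxis attribute instead of being passed as a value-node input.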
+CNodePtr CreateExpandDimsPynative(const FuncGraphPtr &graph, const CNodePtr &real_div_node) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(real_div_node);
+  if (real_div_node->size() != kRealDivInputNum) {
+    MS_LOG(EXCEPTION) << "Op real_div's input num not equal " << kRealDivInputNum;
+  }
+  int64_t axis = -1;
+  auto expand_dims_primitive = std::make_shared<Primitive>(kExpandDimsOpName);
+  std::vector<std::string> input_names = {"x"};
+  std::vector<std::string> output_names = {"output"};
+  expand_dims_primitive->set_attr(kAttrInputNames, MakeValue(input_names));
+  expand_dims_primitive->set_attr(kAttrOutputNames, MakeValue(output_names));
+  std::vector<AnfNodePtr> expand_dims_inputs = {NewValueNode(expand_dims_primitive), real_div_node};
+  auto expand_dims_node = graph->NewCNode(expand_dims_inputs);
+  MS_EXCEPTION_IF_NULL(expand_dims_node);
+
+  expand_dims_node->set_scope(real_div_node->scope());
+  std::vector<size_t> y_shape = AnfAlgo::GetOutputInferShape(real_div_node, 0);
+  y_shape.emplace_back(1);
+  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(real_div_node, 0)}, {y_shape},
+                                      expand_dims_node.get());
+  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis), expand_dims_node);
+  return expand_dims_node;
+}
 
-CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node, const CNodePtr &mul_node) {
+CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node, const CNodePtr &mul_node,
+                    bool is_pynative = false) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(sparse_softmax_node);
   MS_EXCEPTION_IF_NULL(mul_node);
@@ -224,24 +268,37 @@ CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_no
   std::vector<int64_t> multiple_value;
   std::transform(labels_shape.begin(), labels_shape.end(), std::back_inserter(multiple_value),
                  [](size_t label) { return static_cast<int64_t>(label); });
-  auto mutiples = MakeValue(multiple_value);
-  auto mutiples_node = CreateValueNode(mutiples, kNumberTypeInt64);
-  MS_EXCEPTION_IF_NULL(mutiples_node);
-  auto kernel_graph = graph->cast<KernelGraphPtr>();
-  kernel_graph->AddValueNodeToGraph(mutiples_node);
+  auto multiples = MakeValue(multiple_value);
+  auto multiples_node = CreateValueNode(multiples, kNumberTypeInt64);
+  MS_EXCEPTION_IF_NULL(multiples_node);
 
   auto tile_primitive = std::make_shared<Primitive>(kTileOpName);
   std::vector<std::string> input_names = {"x", "multiples"};
   std::vector<std::string> output_names = {"output"};
   tile_primitive->set_attr(kAttrInputNames, MakeValue(input_names));
   tile_primitive->set_attr(kAttrOutputNames, MakeValue(output_names));
-  std::vector<AnfNodePtr> tile_inputs = {NewValueNode(tile_primitive), mul_node->input(2), mutiples_node};
 
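+  // PyNative mode records the multiples on the kAttrMultiples attribute (set
+  // below); graph mode passes them through the multiples value node instead.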
+  std::vector<AnfNodePtr> tile_inputs;
+  if (is_pynative) {
+    tile_inputs = {NewValueNode(tile_primitive), mul_node->input(2)};
+  } else {
+    auto kernel_graph = graph->cast<KernelGraphPtr>();
+    kernel_graph->AddValueNodeToGraph(multiples_node);
+    tile_inputs = {NewValueNode(tile_primitive), mul_node->input(2), multiples_node};
+  }
 
   auto tile_node = graph->NewCNode(tile_inputs);
   MS_EXCEPTION_IF_NULL(tile_node);
 
   tile_node->set_scope(mul_node->scope());
   AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetPrevNodeOutputInferDataType(mul_node, 1)}, {labels_shape},
                                       tile_node.get());
+  if (is_pynative) {
+    AnfAlgo::SetNodeAttr(kAttrMultiples, MakeValue(multiples), tile_node);
+  }
+  // feature map set
+  std::vector<size_t> feature_map_input_indexs;
+  feature_map_input_indexs.push_back(0);
+  AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), tile_node);
   return tile_node;
 }
@@ -368,7 +425,6 @@ const AnfNodePtr SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(const F
   std::vector<AnfNodePtr> softmax_node_outputs;
   CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs);
   auto reduce_node = CreateReduceMean(graph, sparse_softmax_node, softmax_node_outputs[0]);
-
 
   return reduce_node;
 }
@@ -450,5 +506,76 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::Process(c
   manager->Replace(sparse_softmax_node_grad, softmax_node_outputs[1]);
   return mul_node;
 }
+
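+// PyNative forward pass: expands SparseSoftmaxCrossEntropyWithLogits(logits,
+// labels) into OneHot -> SoftmaxCrossEntropyWithLogits -> ReduceMean.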
+const AnfNodePtr PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(const FuncGraphPtr &graph,
+                                                                                 const AnfNodePtr &node,
+                                                                                 const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+
+  auto sparse_softmax_node = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(sparse_softmax_node);
+  if (sparse_softmax_node->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) {
+    MS_LOG(EXCEPTION) << "Op SparseSoftmaxCrossEntropyWithLogits's input not equal "
+                      << kSparseSoftmaxCrossEntropyWithLogitsInputNum;
+  }
+  if (AnfAlgo::HasNodeAttr(kAttrIsGrad, sparse_softmax_node) &&
+      AnfAlgo::GetNodeAttr<bool>(sparse_softmax_node, kAttrIsGrad)) {
+    return nullptr;
+  }
+
+  CNodePtr softmax_node;
+  auto one_hot_node = CreateOneHot(graph, sparse_softmax_node, true);
+  softmax_node = CreateSoftmaxCrossEntropyWithLogits(graph, sparse_softmax_node, one_hot_node);
+
+  std::vector<AnfNodePtr> softmax_node_outputs;
+  CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs);
+  auto reduce_node = CreateReduceMean(graph, sparse_softmax_node, softmax_node_outputs[0], true);
+  return reduce_node;
+}
+
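+// Matches Mul(SparseSoftmaxCrossEntropyWithLogits(x1, x2), x3), i.e. the grad
+// node scaled by the incoming gradient.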
+const BaseRef PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::DefinePattern() const {
+  VarPtr x1 = std::make_shared<Var>();
+  VarPtr x2 = std::make_shared<Var>();
+  VarPtr x3 = std::make_shared<Var>();
+  VectorRef sparse_softmax_cross_entropy_with_logits({prim::kPrimSparseSoftmaxCrossEntropyWithLogits, x1, x2});
+  return VectorRef({prim::kPrimMul, sparse_softmax_cross_entropy_with_logits, x3});
+}
+
+const AnfNodePtr PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(const FuncGraphPtr &graph,
+                                                                                     const AnfNodePtr &node,
+                                                                                     const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+
+  auto mul_node = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(mul_node);
+  if (mul_node->size() != kMulInputNum) {
+    MS_LOG(EXCEPTION) << "Op Mul's input not equal " << kMulInputNum;
+  }
+  auto sparse_softmax_node = mul_node->input(1);
+  auto sparse_softmax_node_grad = sparse_softmax_node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(sparse_softmax_node_grad);
+
+  if (sparse_softmax_node_grad->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) {
+    MS_LOG(EXCEPTION) << "Op SparseSoftmaxCrossEntropyWithLogits's input not equal "
+                      << kSparseSoftmaxCrossEntropyWithLogitsInputNum;
+  }
+
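+  // Rebuild the backward graph: softmax_node output[1] is the dlogits tensor
+  // (softmax(logits) minus one-hot labels); the incoming gradient is tiled,
+  // divided via RealDiv and expanded to match it before the final Mul.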
+  CNodePtr softmax_node;
+  auto one_hot_node = CreateOneHot(graph, sparse_softmax_node_grad, true);
+  softmax_node = CreateSoftmaxCrossEntropyWithLogits(graph, sparse_softmax_node_grad, one_hot_node);
+
+  std::vector<AnfNodePtr> softmax_node_outputs;
+  CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs);
+  auto tile_node = CreateTile(graph, sparse_softmax_node_grad, mul_node, true);
+  auto real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, tile_node);
+  auto expand_dims_node = CreateExpandDimsPynative(graph, real_div_node);
+
+  mul_node->set_input(1, softmax_node_outputs[1]);
+  mul_node->set_input(2, expand_dims_node);
+
+  return mul_node;
+}
 } // namespace opt
 } // namespace mindspore