|
|
|
@@ -18,6 +18,7 @@ |
|
|
|
#include <memory> |
|
|
|
#include <vector> |
|
|
|
#include <tuple> |
|
|
|
#include <string> |
|
|
|
|
|
|
|
#include "session/anf_runtime_algorithm.h" |
|
|
|
#include "common/utils.h" |
|
|
|
@@ -50,6 +51,8 @@ CNodePtr GenerateSquareSumV1(const FuncGraphPtr &graph, const CNodePtr &square, |
|
|
|
square_sumv1->set_scope(sum->scope()); |
|
|
|
AnfAlgo::CopyNodeAttr(kAttrAxis, sum, square_sumv1); |
|
|
|
AnfAlgo::CopyNodeAttr(kAttrKeepDims, sum, square_sumv1); |
|
|
|
auto names = MakeValue<std::vector<std::string>>({prim::kPrimSquare->name(), prim::kPrimReduceSum->name()}); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, names, square_sumv1); |
|
|
|
return square_sumv1; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -71,6 +74,8 @@ CNodePtr GenerateSquareSumV2(const FuncGraphPtr &graph, const CNodePtr &square, |
|
|
|
square_sumv2->set_scope(sum->scope()); |
|
|
|
AnfAlgo::CopyNodeAttr(kAttrAxis, sum, square_sumv2); |
|
|
|
AnfAlgo::CopyNodeAttr(kAttrKeepDims, sum, square_sumv2); |
|
|
|
auto names = MakeValue<std::vector<std::string>>({prim::kPrimSquare->name(), prim::kPrimReduceSum->name()}); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, names, square_sumv2); |
|
|
|
return square_sumv2; |
|
|
|
} |
|
|
|
|
|
|
|
|