| @@ -21,6 +21,7 @@ | |||
| #include <vector> | |||
| #include "device/ascend/kernel_select_ascend.h" | |||
| #include "kernel/kernel_query.h" | |||
| #include "kernel/tbe/tbe_kernel_select.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| @@ -36,6 +37,16 @@ class KernelSelect { | |||
| }; | |||
| using KernelSelectPtr = std::shared_ptr<KernelSelect>; | |||
| class SupportedChecker { | |||
| public: | |||
| SupportedChecker() = default; | |||
| virtual ~SupportedChecker() = default; | |||
| virtual bool CheckSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) { | |||
| return kernel::CheckSupported(anf_node, select_kernel_build_info); | |||
| } | |||
| }; | |||
| using SupportedCheckerPtr = std::shared_ptr<SupportedChecker>; | |||
| class KernelQuery { | |||
| public: | |||
| KernelQuery() = default; | |||
| @@ -16,6 +16,9 @@ | |||
| #include "pre_activate/ascend/ir_fission/topk_split.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <unordered_set> | |||
| #include "pre_activate/common/helper.h" | |||
| #include "kernel/kernel_build_info.h" | |||
| #include "utils/utils.h" | |||
| #include "session/kernel_graph.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| @@ -25,6 +28,7 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| constexpr size_t kFloat16Len = 2;  // size of float16 | |||
| constexpr size_t kTopkIndexK = 1; | |||
| namespace { | |||
| tensor::TensorPtr CreateTensor(const AnfNodePtr &node) { | |||
| // 1 create tensor | |||
| @@ -70,37 +74,68 @@ ValueNodePtr CreateValueNode(const AnfNodePtr &node) { | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), indices_const.get()); | |||
| return indices_const; | |||
| } | |||
| kernel::KernelBuildInfoPtr CreateKernelBuildInfo() { | |||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; | |||
| builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT}); | |||
| builder.SetOutputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT}); | |||
| builder.SetInputsDeviceType({kNumberTypeFloat16, kNumberTypeFloat16}); | |||
| builder.SetOutputsDeviceType({kNumberTypeFloat16, kNumberTypeInt32}); | |||
| return builder.Build(); | |||
| } | |||
| } // namespace | |||
| const BaseRef TopKSplit::DefinePattern() const { | |||
| VarPtr X = std::make_shared<Var>(); | |||
| MS_EXCEPTION_IF_NULL(X); | |||
| VarPtr X1 = std::make_shared<Var>(); | |||
| VarPtr X2 = std::make_shared<Var>(); | |||
| auto prim = std::make_shared<Primitive>(kTopKOpName); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| return VectorRef({prim, X}); | |||
| return VectorRef({prim, X1, X2}); | |||
| } | |||
| const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto kernel_graph = func_graph->cast<KernelGraphPtr>(); | |||
| auto indices_const = CreateValueNode(node); | |||
| // set value node as topk's input | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| MS_LOG(INFO) << "already has input size: " << cnode->inputs().size(); | |||
| cnode->add_input(indices_const); | |||
| // Copy a new node to check supported. | |||
| std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(kTopKOpName))}; | |||
| new_inputs.insert(new_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end()); | |||
| CNodePtr new_cnode = func_graph->NewCNode(new_inputs); | |||
| MS_EXCEPTION_IF_NULL(new_cnode); | |||
| new_cnode->set_abstract(cnode->abstract()); | |||
| new_cnode->set_scope(cnode->scope()); | |||
| AnfAlgo::CopyNodeAttrs(cnode, new_cnode); | |||
| CheckCNodeInputSize(new_cnode, kTopkInputNum); | |||
| // Convert the tensor input to scalar and convert it to attr | |||
| auto input_k = new_cnode->input(kTopkIndexK + 1); | |||
| MS_EXCEPTION_IF_NULL(input_k); | |||
| if (!IsValueNode<tensor::Tensor>(input_k)) { | |||
| return nullptr; | |||
| } | |||
| ValuePtr value = GetValueNode(input_k); | |||
| MS_EXCEPTION_IF_NULL(value); | |||
| auto tensor = value->cast<tensor::TensorPtr>(); | |||
| MS_EXCEPTION_IF_NULL(tensor); | |||
| int32_t *data = reinterpret_cast<int32_t *>(tensor->data_c()); | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| auto new_value_node = std::make_shared<ValueNode>(MakeValue(*data)); | |||
| new_cnode->set_input(kTopkIndexK + 1, new_value_node); | |||
| std::unordered_set<size_t> attr_index{kTopkIndexK}; | |||
| ConstInputToAttr(new_cnode, attr_index); | |||
| auto indices_const = CreateValueNode(new_cnode); | |||
| new_cnode->add_input(indices_const); | |||
| MS_EXCEPTION_IF_NULL(supported_checker_); | |||
| if (!supported_checker_->CheckSupported(new_cnode, CreateKernelBuildInfo())) { | |||
| return nullptr; | |||
| } | |||
| if (kernel_graph != nullptr) { | |||
| kernel_graph->AddValueNodeToGraph(indices_const); | |||
| } | |||
| CNodePtr new_cnode = nullptr; | |||
| if (kernel_graph == nullptr) { | |||
| new_cnode = std::make_shared<CNode>(*cnode); | |||
| } else { | |||
| new_cnode = kernel_graph->NewCNode(cnode); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(new_cnode); | |||
| return new_cnode; | |||
| } | |||
| } // namespace opt | |||
| @@ -16,15 +16,22 @@ | |||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TOPK_SPLIT_H_ | |||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TOPK_SPLIT_H_ | |||
| #include <memory> | |||
| #include "pre_activate/common/optimizer.h" | |||
| #include "pre_activate/ascend/ascend_helper.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class TopKSplit : public PatternProcessPass { | |||
| public: | |||
| explicit TopKSplit(bool multigraph = true) : PatternProcessPass("topk_split", multigraph) {} | |||
| explicit TopKSplit(bool multigraph = true) | |||
| : PatternProcessPass("topk_split", multigraph), supported_checker_(std::make_shared<SupportedChecker>()) {} | |||
| ~TopKSplit() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| private: | |||
| SupportedCheckerPtr supported_checker_; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -422,5 +422,47 @@ AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePt | |||
| AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, tuple_getitem.get()); | |||
| return tuple_getitem; | |||
| } | |||
| void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &input_attrs) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| std::vector<AnfNodePtr> new_inputs; | |||
| std::vector<std::string> new_input_names; | |||
| auto primitive = AnfAlgo::GetCNodePrimitive(cnode); | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto input_names = primitive->GetAttr(kAttrInputNames); | |||
| if (input_names == nullptr) { | |||
| MS_LOG(DEBUG) << "input_names are nullptr in cnode[" + cnode->DebugString() + "]"; | |||
| return; | |||
| } | |||
| auto input_names_vec = GetValue<std::vector<std::string>>(input_names); | |||
| auto inputs = cnode->inputs(); | |||
| new_inputs.push_back(inputs[0]); | |||
| bool need_update = false; | |||
| for (size_t i = 0; i < inputs.size() - 1; ++i) { | |||
| auto input_node = inputs[i + 1]; | |||
| MS_EXCEPTION_IF_NULL(input_node); | |||
| if (input_attrs.find(i) != input_attrs.end() && input_node->isa<ValueNode>()) { | |||
| auto value_node = input_node->cast<ValueNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(value_node); | |||
| MS_LOG(DEBUG) << "start erase input[" << i << "] of cnode[" + cnode->DebugString() + "]"; | |||
| if (i >= input_names_vec.size()) { | |||
| MS_LOG(EXCEPTION) << "index " << i << " is larger than input names size [" << input_names_vec.size() << "]"; | |||
| } | |||
| primitive->set_attr(input_names_vec[i], value_node->value()); | |||
| need_update = true; | |||
| } else { | |||
| new_inputs.push_back(input_node); | |||
| if (i < input_names_vec.size()) { | |||
| new_input_names.push_back(input_names_vec[i]); | |||
| } | |||
| } | |||
| } | |||
| if (need_update) { | |||
| // Update cnode's inputs | |||
| cnode->set_inputs(new_inputs); | |||
| // Update cnode's input_names attr | |||
| primitive->set_attr(kAttrInputNames, MakeValue(new_input_names)); | |||
| } | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -19,6 +19,7 @@ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <unordered_set> | |||
| #include "ir/func_graph.h" | |||
| #include "session/kernel_graph.h" | |||
| #include "common/utils.h" | |||
| @@ -86,6 +87,7 @@ constexpr size_t kAdamApplyOneOutputNum = 3; | |||
| constexpr size_t kBackendTransDataInputNum = 2; | |||
| constexpr size_t kApplyMomentumInputNum = 6; | |||
| constexpr size_t kBiasAddInputNum = 3; | |||
| constexpr size_t kTopkInputNum = 3; | |||
| enum FusedBatchNormInput { | |||
| kX = 1, | |||
| @@ -150,6 +152,8 @@ void RemoveNopNode(session::KernelGraph *const graph); | |||
| AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx); | |||
| bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node); | |||
| void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &input_attrs); | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_ | |||
| @@ -52,7 +52,6 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() { | |||
| Register(kFlattenGradOpName, {1}); | |||
| Register(kExpandDimsOpName, {1}); | |||
| Register(kSplitOpName, {0}); | |||
| Register(kTopKOpName, {1}); | |||
| Register(kErfOpName, {1}); | |||
| Register(kSparseApplyAdagradOpName, {2}); | |||
| Register(kResizeNearestNeighborGrad, {1}); | |||
| @@ -18,10 +18,10 @@ | |||
| #include <vector> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <unordered_set> | |||
| #include <memory> | |||
| #include "pre_activate/pass/const_input_to_attr_registry.h" | |||
| #include "pre_activate/common/helper.h" | |||
| #include "utils/utils.h" | |||
| #include "utils/context/ms_context.h" | |||
| #include "operator/ops.h" | |||
| @@ -29,50 +29,6 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &input_attrs) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| std::vector<AnfNodePtr> new_inputs; | |||
| std::vector<std::string> new_input_names; | |||
| auto primitive = AnfAlgo::GetCNodePrimitive(cnode); | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto input_names = primitive->GetAttr(kAttrInputNames); | |||
| if (input_names == nullptr) { | |||
| MS_LOG(DEBUG) << "input_names are nullptr in cnode[" + cnode->DebugString() + "]"; | |||
| return; | |||
| } | |||
| auto input_names_vec = GetValue<std::vector<std::string>>(input_names); | |||
| auto inputs = cnode->inputs(); | |||
| new_inputs.push_back(inputs[0]); | |||
| bool need_update = false; | |||
| for (size_t i = 0; i < inputs.size() - 1; ++i) { | |||
| auto input_node = inputs[i + 1]; | |||
| MS_EXCEPTION_IF_NULL(input_node); | |||
| if (input_attrs.find(i) != input_attrs.end() && input_node->isa<ValueNode>()) { | |||
| auto value_node = input_node->cast<ValueNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(value_node); | |||
| MS_LOG(DEBUG) << "start erase input[" << i << "] of cnode[" + cnode->DebugString() + "]"; | |||
| if (i >= input_names_vec.size()) { | |||
| MS_LOG(EXCEPTION) << "index " << i << " is larger than input names size [" << input_names_vec.size() << "]"; | |||
| } | |||
| primitive->set_attr(input_names_vec[i], value_node->value()); | |||
| need_update = true; | |||
| } else { | |||
| new_inputs.push_back(input_node); | |||
| if (i < input_names_vec.size()) { | |||
| new_input_names.push_back(input_names_vec[i]); | |||
| } | |||
| } | |||
| } | |||
| if (need_update) { | |||
| // Update cnode's inputs | |||
| cnode->set_inputs(new_inputs); | |||
| // Update cnode's input_names attr | |||
| primitive->set_attr(kAttrInputNames, MakeValue(new_input_names)); | |||
| } | |||
| } | |||
| } // namespace | |||
| const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| if (node == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) { | |||
| @@ -17,8 +17,13 @@ | |||
| #include "common/backend_common_test.h" | |||
| #include "common/py_func_graph_fetcher.h" | |||
| #include "device/kernel_info.h" | |||
| #include "pre_activate/ascend/ir_fission/topk_split.h" | |||
| #include "pre_activate/pass/convert_const_input_to_attr.h" | |||
| #include "debug/anf_ir_dump.h" | |||
| #define private public | |||
| #define protected public | |||
| #include "pre_activate/ascend/ir_fission/topk_split.h" | |||
| #undef private | |||
| #undef protected | |||
| namespace mindspore { | |||
| namespace opt { | |||
| @@ -30,6 +35,15 @@ class TestHWTopKSplit : public BackendCommon { | |||
| UT::PyFuncGraphFetcher get_py_fun_; | |||
| }; | |||
| class MockSupportedChecker : public SupportedChecker { | |||
| public: | |||
| MockSupportedChecker() = default; | |||
| ~MockSupportedChecker() override = default; | |||
| bool CheckSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override { | |||
| return true; | |||
| } | |||
| };  // class MockSupportedChecker | |||
| TEST_F(TestHWTopKSplit, test_topk_split) { | |||
| /* | |||
| * def before(input): | |||
| @@ -40,19 +54,25 @@ TEST_F(TestHWTopKSplit, test_topk_split) { | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_topk_split", "before"); | |||
| std::vector<int> shp{4, 4}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||
| g->parameters()[0]->set_abstract(x_abstract); | |||
| auto ret = g->get_return(); | |||
| EXPECT_NE(ret, nullptr); | |||
| auto tuple_getitem = ret->input(1); | |||
| EXPECT_NE(tuple_getitem, nullptr); | |||
| auto topk = tuple_getitem->cast<CNodePtr>()->input(1); | |||
| topk->set_abstract(x_abstract); | |||
| AbstractBasePtrList args_spec_list{x_abstract}; | |||
| auto kernel_graph = GetKernelGraph(g, args_spec_list); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::TopKSplit>()); | |||
| pm->AddPass(std::make_shared<opt::ConvertConstInputToAttr>()); | |||
| auto topk_split = std::make_shared<opt::TopKSplit>(); | |||
| topk_split->supported_checker_ = std::make_shared<MockSupportedChecker>(); | |||
| pm->AddPass(topk_split); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(g); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(kernel_graph); | |||
| auto ret = new_graph->get_return(); | |||
| EXPECT_NE(ret, nullptr); | |||
| auto make_tuple = ret->input(1); | |||
| EXPECT_NE(make_tuple, nullptr); | |||
| auto tuple_getitem = make_tuple->cast<CNodePtr>()->input(1); | |||
| EXPECT_NE(tuple_getitem, nullptr); | |||
| auto topk = tuple_getitem->cast<CNodePtr>()->input(1); | |||
| auto topk_cnode = topk->cast<CNodePtr>(); | |||
| EXPECT_EQ(topk_cnode->inputs().size(), 3); | |||
| EXPECT_TRUE(topk_cnode->input(2)->isa<ValueNode>()); | |||
| @@ -35,7 +35,7 @@ def test_topk_split(tag): | |||
| @fns | |||
| def before(input): | |||
| topk = TopK(input) | |||
| topk = TopK(input, 2) | |||
| output = tuple_getitem(topk, 0) | |||
| return output | |||