| @@ -21,6 +21,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "device/ascend/kernel_select_ascend.h" | #include "device/ascend/kernel_select_ascend.h" | ||||
| #include "kernel/kernel_query.h" | #include "kernel/kernel_query.h" | ||||
| #include "kernel/tbe/tbe_kernel_select.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| @@ -36,6 +37,16 @@ class KernelSelect { | |||||
| }; | }; | ||||
| using KernelSelectPtr = std::shared_ptr<KernelSelect>; | using KernelSelectPtr = std::shared_ptr<KernelSelect>; | ||||
| class SupportedChecker { | |||||
| public: | |||||
| SupportedChecker() = default; | |||||
| virtual ~SupportedChecker() = default; | |||||
| virtual bool CheckSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) { | |||||
| return kernel::CheckSupported(anf_node, select_kernel_build_info); | |||||
| } | |||||
| }; | |||||
| using SupportedCheckerPtr = std::shared_ptr<SupportedChecker>; | |||||
| class KernelQuery { | class KernelQuery { | ||||
| public: | public: | ||||
| KernelQuery() = default; | KernelQuery() = default; | ||||
| @@ -16,6 +16,9 @@ | |||||
| #include "pre_activate/ascend/ir_fission/topk_split.h" | #include "pre_activate/ascend/ir_fission/topk_split.h" | ||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include <unordered_set> | |||||
| #include "pre_activate/common/helper.h" | |||||
| #include "kernel/kernel_build_info.h" | |||||
| #include "utils/utils.h" | #include "utils/utils.h" | ||||
| #include "session/kernel_graph.h" | #include "session/kernel_graph.h" | ||||
| #include "session/anf_runtime_algorithm.h" | #include "session/anf_runtime_algorithm.h" | ||||
| @@ -25,6 +28,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| constexpr size_t kFloat16Len = 2; // size of float16; | constexpr size_t kFloat16Len = 2; // size of float16; | ||||
| constexpr size_t kTopkIndexK = 1; | |||||
| namespace { | namespace { | ||||
| tensor::TensorPtr CreateTensor(const AnfNodePtr &node) { | tensor::TensorPtr CreateTensor(const AnfNodePtr &node) { | ||||
| // 1 create tensor | // 1 create tensor | ||||
| @@ -70,37 +74,68 @@ ValueNodePtr CreateValueNode(const AnfNodePtr &node) { | |||||
| AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), indices_const.get()); | AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), indices_const.get()); | ||||
| return indices_const; | return indices_const; | ||||
| } | } | ||||
| kernel::KernelBuildInfoPtr CreateKernelBuildInfo() { | |||||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; | |||||
| builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT}); | |||||
| builder.SetOutputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT}); | |||||
| builder.SetInputsDeviceType({kNumberTypeFloat16, kNumberTypeFloat16}); | |||||
| builder.SetOutputsDeviceType({kNumberTypeFloat16, kNumberTypeInt32}); | |||||
| return builder.Build(); | |||||
| } | |||||
| } // namespace | } // namespace | ||||
| const BaseRef TopKSplit::DefinePattern() const { | const BaseRef TopKSplit::DefinePattern() const { | ||||
| VarPtr X = std::make_shared<Var>(); | |||||
| MS_EXCEPTION_IF_NULL(X); | |||||
| VarPtr X1 = std::make_shared<Var>(); | |||||
| VarPtr X2 = std::make_shared<Var>(); | |||||
| auto prim = std::make_shared<Primitive>(kTopKOpName); | auto prim = std::make_shared<Primitive>(kTopKOpName); | ||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| return VectorRef({prim, X}); | |||||
| return VectorRef({prim, X1, X2}); | |||||
| } | } | ||||
| const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { | const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { | ||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| auto kernel_graph = func_graph->cast<KernelGraphPtr>(); | auto kernel_graph = func_graph->cast<KernelGraphPtr>(); | ||||
| auto indices_const = CreateValueNode(node); | |||||
| // set value node as topk's input | // set value node as topk's input | ||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| MS_LOG(INFO) << "already has input size: " << cnode->inputs().size(); | |||||
| cnode->add_input(indices_const); | |||||
| // Copy a new node to check supported. | |||||
| std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(kTopKOpName))}; | |||||
| new_inputs.insert(new_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end()); | |||||
| CNodePtr new_cnode = func_graph->NewCNode(new_inputs); | |||||
| MS_EXCEPTION_IF_NULL(new_cnode); | |||||
| new_cnode->set_abstract(cnode->abstract()); | |||||
| new_cnode->set_scope(cnode->scope()); | |||||
| AnfAlgo::CopyNodeAttrs(cnode, new_cnode); | |||||
| CheckCNodeInputSize(new_cnode, kTopkInputNum); | |||||
| // Convert the tensor input to scalar and convert it to attr | |||||
| auto input_k = new_cnode->input(kTopkIndexK + 1); | |||||
| MS_EXCEPTION_IF_NULL(input_k); | |||||
| if (!IsValueNode<tensor::Tensor>(input_k)) { | |||||
| return nullptr; | |||||
| } | |||||
| ValuePtr value = GetValueNode(input_k); | |||||
| MS_EXCEPTION_IF_NULL(value); | |||||
| auto tensor = value->cast<tensor::TensorPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(tensor); | |||||
| int32_t *data = reinterpret_cast<int32_t *>(tensor->data_c()); | |||||
| MS_EXCEPTION_IF_NULL(data); | |||||
| auto new_value_node = std::make_shared<ValueNode>(MakeValue(*data)); | |||||
| new_cnode->set_input(kTopkIndexK + 1, new_value_node); | |||||
| std::unordered_set<size_t> attr_index{kTopkIndexK}; | |||||
| ConstInputToAttr(new_cnode, attr_index); | |||||
| auto indices_const = CreateValueNode(new_cnode); | |||||
| new_cnode->add_input(indices_const); | |||||
| MS_EXCEPTION_IF_NULL(supported_checker_); | |||||
| if (!supported_checker_->CheckSupported(new_cnode, CreateKernelBuildInfo())) { | |||||
| return nullptr; | |||||
| } | |||||
| if (kernel_graph != nullptr) { | if (kernel_graph != nullptr) { | ||||
| kernel_graph->AddValueNodeToGraph(indices_const); | kernel_graph->AddValueNodeToGraph(indices_const); | ||||
| } | } | ||||
| CNodePtr new_cnode = nullptr; | |||||
| if (kernel_graph == nullptr) { | |||||
| new_cnode = std::make_shared<CNode>(*cnode); | |||||
| } else { | |||||
| new_cnode = kernel_graph->NewCNode(cnode); | |||||
| } | |||||
| MS_EXCEPTION_IF_NULL(new_cnode); | |||||
| return new_cnode; | return new_cnode; | ||||
| } | } | ||||
| } // namespace opt | } // namespace opt | ||||
| @@ -16,15 +16,22 @@ | |||||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TOPK_SPLIT_H_ | #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TOPK_SPLIT_H_ | ||||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TOPK_SPLIT_H_ | #define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TOPK_SPLIT_H_ | ||||
| #include <memory> | |||||
| #include "pre_activate/common/optimizer.h" | #include "pre_activate/common/optimizer.h" | ||||
| #include "pre_activate/ascend/ascend_helper.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| class TopKSplit : public PatternProcessPass { | class TopKSplit : public PatternProcessPass { | ||||
| public: | public: | ||||
| explicit TopKSplit(bool multigraph = true) : PatternProcessPass("topk_split", multigraph) {} | |||||
| explicit TopKSplit(bool multigraph = true) | |||||
| : PatternProcessPass("topk_split", multigraph), supported_checker_(std::make_shared<SupportedChecker>()) {} | |||||
| ~TopKSplit() override = default; | ~TopKSplit() override = default; | ||||
| const BaseRef DefinePattern() const override; | const BaseRef DefinePattern() const override; | ||||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | ||||
| private: | |||||
| SupportedCheckerPtr supported_checker_; | |||||
| }; | }; | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -422,5 +422,47 @@ AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePt | |||||
| AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, tuple_getitem.get()); | AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, tuple_getitem.get()); | ||||
| return tuple_getitem; | return tuple_getitem; | ||||
| } | } | ||||
| void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &input_attrs) { | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| std::vector<AnfNodePtr> new_inputs; | |||||
| std::vector<std::string> new_input_names; | |||||
| auto primitive = AnfAlgo::GetCNodePrimitive(cnode); | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto input_names = primitive->GetAttr(kAttrInputNames); | |||||
| if (input_names == nullptr) { | |||||
| MS_LOG(DEBUG) << "input_names are nullptr in cnode[" + cnode->DebugString() + "]"; | |||||
| return; | |||||
| } | |||||
| auto input_names_vec = GetValue<std::vector<std::string>>(input_names); | |||||
| auto inputs = cnode->inputs(); | |||||
| new_inputs.push_back(inputs[0]); | |||||
| bool need_update = false; | |||||
| for (size_t i = 0; i < inputs.size() - 1; ++i) { | |||||
| auto input_node = inputs[i + 1]; | |||||
| MS_EXCEPTION_IF_NULL(input_node); | |||||
| if (input_attrs.find(i) != input_attrs.end() && input_node->isa<ValueNode>()) { | |||||
| auto value_node = input_node->cast<ValueNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(value_node); | |||||
| MS_LOG(DEBUG) << "start erase input[" << i << "] of cnode[" + cnode->DebugString() + "]"; | |||||
| if (i >= input_names_vec.size()) { | |||||
| MS_LOG(EXCEPTION) << "index " << i << " is larger than input names size [" << input_names_vec.size() << "]"; | |||||
| } | |||||
| primitive->set_attr(input_names_vec[i], value_node->value()); | |||||
| need_update = true; | |||||
| } else { | |||||
| new_inputs.push_back(input_node); | |||||
| if (i < input_names_vec.size()) { | |||||
| new_input_names.push_back(input_names_vec[i]); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (need_update) { | |||||
| // Update cnode's inputs | |||||
| cnode->set_inputs(new_inputs); | |||||
| // Update cnode's input_names attr | |||||
| primitive->set_attr(kAttrInputNames, MakeValue(new_input_names)); | |||||
| } | |||||
| } | |||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,6 +19,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| #include <unordered_set> | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "session/kernel_graph.h" | #include "session/kernel_graph.h" | ||||
| #include "common/utils.h" | #include "common/utils.h" | ||||
| @@ -86,6 +87,7 @@ constexpr size_t kAdamApplyOneOutputNum = 3; | |||||
| constexpr size_t kBackendTransDataInputNum = 2; | constexpr size_t kBackendTransDataInputNum = 2; | ||||
| constexpr size_t kApplyMomentumInputNum = 6; | constexpr size_t kApplyMomentumInputNum = 6; | ||||
| constexpr size_t kBiasAddInputNum = 3; | constexpr size_t kBiasAddInputNum = 3; | ||||
| constexpr size_t kTopkInputNum = 3; | |||||
| enum FusedBatchNormInput { | enum FusedBatchNormInput { | ||||
| kX = 1, | kX = 1, | ||||
| @@ -150,6 +152,8 @@ void RemoveNopNode(session::KernelGraph *const graph); | |||||
| AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx); | AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx); | ||||
| bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node); | bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node); | ||||
| void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &input_attrs); | |||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_ | #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_ | ||||
| @@ -52,7 +52,6 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() { | |||||
| Register(kFlattenGradOpName, {1}); | Register(kFlattenGradOpName, {1}); | ||||
| Register(kExpandDimsOpName, {1}); | Register(kExpandDimsOpName, {1}); | ||||
| Register(kSplitOpName, {0}); | Register(kSplitOpName, {0}); | ||||
| Register(kTopKOpName, {1}); | |||||
| Register(kErfOpName, {1}); | Register(kErfOpName, {1}); | ||||
| Register(kSparseApplyAdagradOpName, {2}); | Register(kSparseApplyAdagradOpName, {2}); | ||||
| Register(kResizeNearestNeighborGrad, {1}); | Register(kResizeNearestNeighborGrad, {1}); | ||||
| @@ -18,10 +18,10 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <string> | #include <string> | ||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <unordered_set> | |||||
| #include <memory> | #include <memory> | ||||
| #include "pre_activate/pass/const_input_to_attr_registry.h" | #include "pre_activate/pass/const_input_to_attr_registry.h" | ||||
| #include "pre_activate/common/helper.h" | |||||
| #include "utils/utils.h" | #include "utils/utils.h" | ||||
| #include "utils/context/ms_context.h" | #include "utils/context/ms_context.h" | ||||
| #include "operator/ops.h" | #include "operator/ops.h" | ||||
| @@ -29,50 +29,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| namespace { | |||||
| void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &input_attrs) { | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| std::vector<AnfNodePtr> new_inputs; | |||||
| std::vector<std::string> new_input_names; | |||||
| auto primitive = AnfAlgo::GetCNodePrimitive(cnode); | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto input_names = primitive->GetAttr(kAttrInputNames); | |||||
| if (input_names == nullptr) { | |||||
| MS_LOG(DEBUG) << "input_names are nullptr in cnode[" + cnode->DebugString() + "]"; | |||||
| return; | |||||
| } | |||||
| auto input_names_vec = GetValue<std::vector<std::string>>(input_names); | |||||
| auto inputs = cnode->inputs(); | |||||
| new_inputs.push_back(inputs[0]); | |||||
| bool need_update = false; | |||||
| for (size_t i = 0; i < inputs.size() - 1; ++i) { | |||||
| auto input_node = inputs[i + 1]; | |||||
| MS_EXCEPTION_IF_NULL(input_node); | |||||
| if (input_attrs.find(i) != input_attrs.end() && input_node->isa<ValueNode>()) { | |||||
| auto value_node = input_node->cast<ValueNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(value_node); | |||||
| MS_LOG(DEBUG) << "start erase input[" << i << "] of cnode[" + cnode->DebugString() + "]"; | |||||
| if (i >= input_names_vec.size()) { | |||||
| MS_LOG(EXCEPTION) << "index " << i << " is larger than input names size [" << input_names_vec.size() << "]"; | |||||
| } | |||||
| primitive->set_attr(input_names_vec[i], value_node->value()); | |||||
| need_update = true; | |||||
| } else { | |||||
| new_inputs.push_back(input_node); | |||||
| if (i < input_names_vec.size()) { | |||||
| new_input_names.push_back(input_names_vec[i]); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (need_update) { | |||||
| // Update cnode's inputs | |||||
| cnode->set_inputs(new_inputs); | |||||
| // Update cnode's input_names attr | |||||
| primitive->set_attr(kAttrInputNames, MakeValue(new_input_names)); | |||||
| } | |||||
| } | |||||
| } // namespace | |||||
| const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const AnfNodePtr &node, | const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const AnfNodePtr &node, | ||||
| const EquivPtr &) const { | const EquivPtr &) const { | ||||
| if (node == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) { | if (node == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) { | ||||
| @@ -17,8 +17,13 @@ | |||||
| #include "common/backend_common_test.h" | #include "common/backend_common_test.h" | ||||
| #include "common/py_func_graph_fetcher.h" | #include "common/py_func_graph_fetcher.h" | ||||
| #include "device/kernel_info.h" | #include "device/kernel_info.h" | ||||
| #include "pre_activate/ascend/ir_fission/topk_split.h" | |||||
| #include "pre_activate/pass/convert_const_input_to_attr.h" | |||||
| #include "debug/anf_ir_dump.h" | #include "debug/anf_ir_dump.h" | ||||
| #define private public | |||||
| #define protected public | |||||
| #include "pre_activate/ascend/ir_fission/topk_split.h" | |||||
| #undef private | |||||
| #undef protected | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| @@ -30,6 +35,15 @@ class TestHWTopKSplit : public BackendCommon { | |||||
| UT::PyFuncGraphFetcher get_py_fun_; | UT::PyFuncGraphFetcher get_py_fun_; | ||||
| }; | }; | ||||
| class MockSupportedChecker : public SupportedChecker { | |||||
| public: | |||||
| MockSupportedChecker() = default; | |||||
| ~MockSupportedChecker() override = default; | |||||
| bool CheckSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override { | |||||
| return true; | |||||
| } | |||||
| }; // namespace opt | |||||
| TEST_F(TestHWTopKSplit, test_topk_split) { | TEST_F(TestHWTopKSplit, test_topk_split) { | ||||
| /* | /* | ||||
| * def before(input): | * def before(input): | ||||
| @@ -40,19 +54,25 @@ TEST_F(TestHWTopKSplit, test_topk_split) { | |||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_topk_split", "before"); | FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_topk_split", "before"); | ||||
| std::vector<int> shp{4, 4}; | std::vector<int> shp{4, 4}; | ||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | ||||
| g->parameters()[0]->set_abstract(x_abstract); | |||||
| auto ret = g->get_return(); | |||||
| EXPECT_NE(ret, nullptr); | |||||
| auto tuple_getitem = ret->input(1); | |||||
| EXPECT_NE(tuple_getitem, nullptr); | |||||
| auto topk = tuple_getitem->cast<CNodePtr>()->input(1); | |||||
| topk->set_abstract(x_abstract); | |||||
| AbstractBasePtrList args_spec_list{x_abstract}; | |||||
| auto kernel_graph = GetKernelGraph(g, args_spec_list); | |||||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | auto optimizer = std::make_shared<opt::GraphOptimizer>(); | ||||
| auto pm = std::make_shared<opt::PassManager>(); | auto pm = std::make_shared<opt::PassManager>(); | ||||
| pm->AddPass(std::make_shared<opt::TopKSplit>()); | |||||
| pm->AddPass(std::make_shared<opt::ConvertConstInputToAttr>()); | |||||
| auto topk_split = std::make_shared<opt::TopKSplit>(); | |||||
| topk_split->supported_checker_ = std::make_shared<MockSupportedChecker>(); | |||||
| pm->AddPass(topk_split); | |||||
| optimizer->AddPassManager(pm); | optimizer->AddPassManager(pm); | ||||
| FuncGraphPtr new_graph = optimizer->Optimize(g); | |||||
| FuncGraphPtr new_graph = optimizer->Optimize(kernel_graph); | |||||
| auto ret = new_graph->get_return(); | |||||
| EXPECT_NE(ret, nullptr); | |||||
| auto make_tuple = ret->input(1); | |||||
| EXPECT_NE(make_tuple, nullptr); | |||||
| auto tuple_getitem = make_tuple->cast<CNodePtr>()->input(1); | |||||
| EXPECT_NE(tuple_getitem, nullptr); | |||||
| auto topk = tuple_getitem->cast<CNodePtr>()->input(1); | |||||
| auto topk_cnode = topk->cast<CNodePtr>(); | auto topk_cnode = topk->cast<CNodePtr>(); | ||||
| EXPECT_EQ(topk_cnode->inputs().size(), 3); | EXPECT_EQ(topk_cnode->inputs().size(), 3); | ||||
| EXPECT_TRUE(topk_cnode->input(2)->isa<ValueNode>()); | EXPECT_TRUE(topk_cnode->input(2)->isa<ValueNode>()); | ||||
| @@ -35,7 +35,7 @@ def test_topk_split(tag): | |||||
| @fns | @fns | ||||
| def before(input): | def before(input): | ||||
| topk = TopK(input) | |||||
| topk = TopK(input, 2) | |||||
| output = tuple_getitem(topk, 0) | output = tuple_getitem(topk, 0) | ||||
| return output | return output | ||||