|
|
|
@@ -19,6 +19,7 @@ |
|
|
|
#include "device/kernel_info.h" |
|
|
|
#include "pre_activate/pass/convert_const_input_to_attr.h" |
|
|
|
#include "debug/anf_ir_dump.h" |
|
|
|
#include "session/anf_runtime_algorithm.h" |
|
|
|
#define private public |
|
|
|
#define protected public |
|
|
|
#include "pre_activate/ascend/ir_fission/topk_split.h" |
|
|
|
@@ -32,6 +33,21 @@ class TestHWTopKSplit : public BackendCommon { |
|
|
|
TestHWTopKSplit() : get_py_fun_("gtest_input.pre_activate.topk_split_test", true) {} |
|
|
|
~TestHWTopKSplit() override = default; |
|
|
|
|
|
|
|
CNodePtr GetTopkCNodeFromKernelGraph(const FuncGraphPtr &func_graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
auto ret = func_graph->get_return(); |
|
|
|
MS_EXCEPTION_IF_NULL(ret); |
|
|
|
auto make_tuple = ret->input(1); |
|
|
|
MS_EXCEPTION_IF_NULL(make_tuple); |
|
|
|
auto tuple_getitem = make_tuple->cast<CNodePtr>()->input(1); |
|
|
|
MS_EXCEPTION_IF_NULL(tuple_getitem); |
|
|
|
auto topk = tuple_getitem->cast<CNodePtr>()->input(1); |
|
|
|
MS_EXCEPTION_IF_NULL(topk); |
|
|
|
auto topk_cnode = topk->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(topk_cnode); |
|
|
|
return topk_cnode; |
|
|
|
} |
|
|
|
|
|
|
|
UT::PyFuncGraphFetcher get_py_fun_; |
|
|
|
}; |
|
|
|
|
|
|
|
@@ -39,7 +55,8 @@ class MockSupportedChecker : public SupportedChecker { |
|
|
|
public: |
|
|
|
MockSupportedChecker() = default; |
|
|
|
~MockSupportedChecker() override = default; |
|
|
|
bool CheckAiCoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override { |
|
|
|
bool CheckAiCoreSupported(const AnfNodePtr &anf_node, |
|
|
|
const kernel::KernelBuildInfoPtr &select_kernel_build_info) override { |
|
|
|
return true; |
|
|
|
} |
|
|
|
}; // namespace opt |
|
|
|
@@ -66,14 +83,7 @@ TEST_F(TestHWTopKSplit, test_topk_split) { |
|
|
|
optimizer->AddPassManager(pm); |
|
|
|
FuncGraphPtr new_graph = optimizer->Optimize(kernel_graph); |
|
|
|
|
|
|
|
auto ret = new_graph->get_return(); |
|
|
|
EXPECT_NE(ret, nullptr); |
|
|
|
auto make_tuple = ret->input(1); |
|
|
|
EXPECT_NE(make_tuple, nullptr); |
|
|
|
auto tuple_getitem = make_tuple->cast<CNodePtr>()->input(1); |
|
|
|
EXPECT_NE(tuple_getitem, nullptr); |
|
|
|
auto topk = tuple_getitem->cast<CNodePtr>()->input(1); |
|
|
|
auto topk_cnode = topk->cast<CNodePtr>(); |
|
|
|
auto topk_cnode = GetTopkCNodeFromKernelGraph(new_graph); |
|
|
|
EXPECT_EQ(topk_cnode->inputs().size(), 3); |
|
|
|
EXPECT_TRUE(topk_cnode->input(2)->isa<ValueNode>()); |
|
|
|
auto value_node = topk_cnode->input(2)->cast<ValueNodePtr>(); |
|
|
|
@@ -82,5 +92,39 @@ TEST_F(TestHWTopKSplit, test_topk_split) { |
|
|
|
EXPECT_EQ(tensor->shape().size(), 1); |
|
|
|
EXPECT_EQ(tensor->shape()[0], 4); |
|
|
|
} |
|
|
|
|
|
|
|
TEST_F(TestHWTopKSplit, test_topk_no_split) { |
|
|
|
/* |
|
|
|
* def before(input): |
|
|
|
* topk = TopKSplit(input) |
|
|
|
* output = tuple_getitem(topk, 0) |
|
|
|
* return output |
|
|
|
*/ |
|
|
|
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_topk_split", "before"); |
|
|
|
std::vector<int> shp{4, 4}; |
|
|
|
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); |
|
|
|
AbstractBasePtrList args_spec_list{x_abstract}; |
|
|
|
auto kernel_graph = GetKernelGraph(g, args_spec_list); |
|
|
|
|
|
|
|
CNodePtr topk_cnode = GetTopkCNodeFromKernelGraph(kernel_graph); |
|
|
|
EXPECT_EQ(topk_cnode->inputs().size(), 3); |
|
|
|
auto input_names_vec = AnfAlgo::GetNodeAttr<std::vector<std::string>>(topk_cnode, kAttrInputNames); |
|
|
|
EXPECT_EQ(input_names_vec.size(), 2); |
|
|
|
std::unordered_set<size_t> attr_index{1}; |
|
|
|
ConstInputToAttr(topk_cnode, attr_index); |
|
|
|
EXPECT_EQ(topk_cnode->inputs().size(), 2); |
|
|
|
input_names_vec = AnfAlgo::GetNodeAttr<std::vector<std::string>>(topk_cnode, kAttrInputNames); |
|
|
|
EXPECT_EQ(input_names_vec.size(), 1); |
|
|
|
|
|
|
|
auto optimizer = std::make_shared<opt::GraphOptimizer>(); |
|
|
|
auto pm = std::make_shared<opt::PassManager>(); |
|
|
|
pm->AddPass(std::make_shared<opt::ConvertConstInputToAttr>()); |
|
|
|
auto topk_split = std::make_shared<opt::TopKSplit>(); |
|
|
|
topk_split->supported_checker_ = std::make_shared<MockSupportedChecker>(); |
|
|
|
pm->AddPass(topk_split); |
|
|
|
optimizer->AddPassManager(pm); |
|
|
|
FuncGraphPtr new_graph = optimizer->Optimize(kernel_graph); |
|
|
|
EXPECT_EQ(topk_cnode, GetTopkCNodeFromKernelGraph(new_graph)); |
|
|
|
} |
|
|
|
} // namespace opt |
|
|
|
} // namespace mindspore |