|
|
|
@@ -77,6 +77,9 @@ ValueNodePtr CreateValueNode(const AnfNodePtr &node) { |
|
|
|
|
|
|
|
kernel::KernelBuildInfoPtr CreateKernelBuildInfo() { |
|
|
|
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; |
|
|
|
builder.SetKernelType(TBE_KERNEL); |
|
|
|
builder.SetFusionType(kernel::OPAQUE); |
|
|
|
builder.SetProcessor(kernel::AICORE); |
|
|
|
builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT}); |
|
|
|
builder.SetOutputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT}); |
|
|
|
builder.SetInputsDeviceType({kNumberTypeFloat16, kNumberTypeFloat16}); |
|
|
|
@@ -129,10 +132,12 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod |
|
|
|
new_cnode->add_input(indices_const); |
|
|
|
MS_EXCEPTION_IF_NULL(supported_checker_); |
|
|
|
if (!supported_checker_->CheckAiCoreSupported(new_cnode, CreateKernelBuildInfo())) { |
|
|
|
MS_LOG(INFO) << "split topk failed, check to aicpu."; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
if (kernel_graph != nullptr) { |
|
|
|
MS_LOG(INFO) << "split topk success. use tbe aicore."; |
|
|
|
kernel_graph->AddValueNodeToGraph(indices_const); |
|
|
|
} |
|
|
|
|
|
|
|
|