diff --git a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_mod.cc b/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_mod.cc index 5028c757f2..d18e0fbe6f 100644 --- a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_mod.cc +++ b/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_mod.cc @@ -111,6 +111,9 @@ bool AicpuOpKernelMod::Launch(const std::vector &inputs, const std:: CreateCpuKernelInfo(inputs, outputs); auto *stream = reinterpret_cast(stream_ptr); + if (node_name_ == "TopK") { + node_name_ = "TopKV2"; + } MS_LOG(INFO) << "Aicpu launch, node_so_:" << node_so_ << ", node name:" << node_name_ << ", args_size:" << args_.length(); if (rtCpuKernelLaunch(reinterpret_cast(node_so_.c_str()), @@ -137,6 +140,9 @@ vector AicpuOpKernelMod::GenTask(const std::vector &inp (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_data_addrs), [](const AddressPtr &output) -> void * { return output->addr; }); + if (node_name_ == "TopK") { + node_name_ = "TopKV2"; + } AicpuTaskInfoPtr task_info_ptr = make_shared( stream_id, node_so_, node_name_, node_def_str_, input_data_addrs, output_data_addrs); diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc index 63e0fb888d..c96f8c7813 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc @@ -568,6 +568,12 @@ void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector> parse_info_list; + + if (AnfAlgo::GetCNodeName(kernel_node) == kTopKOpName && AnfAlgo::GetNodeAttr(kernel_node, "sorted") == false) { + MS_LOG(INFO) << "will select aicpu topk."; + return; + } + std::string op_name = AnfAlgo::GetCNodeName(kernel_node); auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kTBE); if (op_info_ptr == nullptr) { diff --git a/mindspore/ops/_op_impl/aicpu/__init__.py b/mindspore/ops/_op_impl/aicpu/__init__.py index b0f90a629b..1b01a556cc 100644 --- a/mindspore/ops/_op_impl/aicpu/__init__.py +++ b/mindspore/ops/_op_impl/aicpu/__init__.py @@ -17,3 +17,4 @@ from .init_data_set_queue import _init_data_set_queue_aicpu from .dropout_genmask import _dropout_genmask_aicpu from .get_next import _get_next_aicpu from .print_tensor import _print_aicpu +from .topk import _top_k_aicpu diff --git a/mindspore/ops/_op_impl/aicpu/topk.py b/mindspore/ops/_op_impl/aicpu/topk.py new file mode 100644 index 0000000000..95bffbdb8c --- /dev/null +++ b/mindspore/ops/_op_impl/aicpu/topk.py @@ -0,0 +1,32 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""TopK op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +top_k_op_info = AiCPURegOp("TopK") \ + .fusion_type("OPAQUE") \ + .attr("sorted", "bool")\ + .input(0, "intput", "required") \ + .input(1, "k", "required") \ + .output(0, "values", "required") \ + .output(1, "indices", "required") \ + .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.I32_Default) \ + .get_op_info() + +@op_info_register(top_k_op_info) +def _top_k_aicpu(): + """TopK aicpu register""" + return