Merge pull request !905 from yanzhenxiang2020/add_topkop_for_aicpu (tags/v0.3.0-alpha)
| @@ -111,6 +111,9 @@ bool AicpuOpKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std:: | |||||
| CreateCpuKernelInfo(inputs, outputs); | CreateCpuKernelInfo(inputs, outputs); | ||||
| auto *stream = reinterpret_cast<rtStream_t *>(stream_ptr); | auto *stream = reinterpret_cast<rtStream_t *>(stream_ptr); | ||||
| if (node_name_ == "TopK") { | |||||
| node_name_ = "TopKV2"; | |||||
| } | |||||
| MS_LOG(INFO) << "Aicpu launch, node_so_:" << node_so_ << ", node name:" << node_name_ | MS_LOG(INFO) << "Aicpu launch, node_so_:" << node_so_ << ", node name:" << node_name_ | ||||
| << ", args_size:" << args_.length(); | << ", args_size:" << args_.length(); | ||||
| if (rtCpuKernelLaunch(reinterpret_cast<const void *>(node_so_.c_str()), | if (rtCpuKernelLaunch(reinterpret_cast<const void *>(node_so_.c_str()), | ||||
| @@ -137,6 +140,9 @@ vector<TaskInfoPtr> AicpuOpKernelMod::GenTask(const std::vector<AddressPtr> &inp | |||||
| (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_data_addrs), | (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_data_addrs), | ||||
| [](const AddressPtr &output) -> void * { return output->addr; }); | [](const AddressPtr &output) -> void * { return output->addr; }); | ||||
| if (node_name_ == "TopK") { | |||||
| node_name_ = "TopKV2"; | |||||
| } | |||||
| AicpuTaskInfoPtr task_info_ptr = make_shared<ge::model_runner::AicpuTaskInfo>( | AicpuTaskInfoPtr task_info_ptr = make_shared<ge::model_runner::AicpuTaskInfo>( | ||||
| stream_id, node_so_, node_name_, node_def_str_, input_data_addrs, output_data_addrs); | stream_id, node_so_, node_name_, node_def_str_, input_data_addrs, output_data_addrs); | ||||
| @@ -568,6 +568,12 @@ void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<Ke | |||||
| MS_EXCEPTION_IF_NULL(kernel_node); | MS_EXCEPTION_IF_NULL(kernel_node); | ||||
| MS_EXCEPTION_IF_NULL(kernel_info_list); | MS_EXCEPTION_IF_NULL(kernel_info_list); | ||||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> parse_info_list; | std::vector<std::shared_ptr<kernel::KernelBuildInfo>> parse_info_list; | ||||
| if (AnfAlgo::GetCNodeName(kernel_node) == kTopKOpName && AnfAlgo::GetNodeAttr<bool>(kernel_node, "sorted") == false) { | |||||
| MS_LOG(INFO) << "will select aicpu topk."; | |||||
| return; | |||||
| } | |||||
| std::string op_name = AnfAlgo::GetCNodeName(kernel_node); | std::string op_name = AnfAlgo::GetCNodeName(kernel_node); | ||||
| auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kTBE); | auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kTBE); | ||||
| if (op_info_ptr == nullptr) { | if (op_info_ptr == nullptr) { | ||||
| @@ -17,3 +17,4 @@ from .init_data_set_queue import _init_data_set_queue_aicpu | |||||
| from .dropout_genmask import _dropout_genmask_aicpu | from .dropout_genmask import _dropout_genmask_aicpu | ||||
| from .get_next import _get_next_aicpu | from .get_next import _get_next_aicpu | ||||
| from .print_tensor import _print_aicpu | from .print_tensor import _print_aicpu | ||||
| from .topk import _top_k_aicpu | |||||
| @@ -0,0 +1,32 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """TopK op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | |||||
# Op-info record for the AiCPU implementation of the "TopK" operator.
# It declares one bool attribute ("sorted"), two required inputs, two required
# outputs and a single dtype/format combination.
# NOTE(review): the first input is registered under the name "intput", which
# looks like a typo for "input" -- confirm nothing matches on this name before
# renaming it.
top_k_op_info = (
    AiCPURegOp("TopK")
    .fusion_type("OPAQUE")
    .attr("sorted", "bool")
    .input(0, "intput", "required")
    .input(1, "k", "required")
    .output(0, "values", "required")
    .output(1, "indices", "required")
    .dtype_format(
        DataType.F16_Default,
        DataType.I32_Default,
        DataType.F16_Default,
        DataType.I32_Default,
    )
    .get_op_info()
)
@op_info_register(top_k_op_info)
def _top_k_aicpu():
    """Placeholder decorated with the TopK AiCPU op info; the body does nothing."""
    return None