diff --git a/ge/single_op/single_op_model.cc b/ge/single_op/single_op_model.cc index 7d092091..eb9b5833 100755 --- a/ge/single_op/single_op_model.cc +++ b/ge/single_op/single_op_model.cc @@ -244,7 +244,7 @@ Status SingleOpModel::BuildTaskList(StreamResource *stream_resource, SingleOp &s if (kernel_type == ccKernelType::TE) { GELOGD("Building TBE task"); TbeOpTask *tbe_task = nullptr; - auto ret = BuildKernelTask(task_def.kernel(), &tbe_task); + auto ret = BuildKernelTask(task_def, &tbe_task); if (ret != SUCCESS) { return ret; } @@ -315,9 +315,11 @@ void SingleOpModel::ParseArgTable(OpTask *task, SingleOp &op) { } } -Status SingleOpModel::BuildKernelTask(const domi::KernelDef &kernel_def, TbeOpTask **task) { +Status SingleOpModel::BuildKernelTask(const domi::TaskDef &task_def, TbeOpTask **task) { GE_CHECK_NOTNULL(task); - const auto &context = kernel_def.context(); + auto task_type = static_cast(task_def.type()); + const auto &context = task_type == RT_MODEL_TASK_KERNEL ? task_def.kernel().context() : + task_def.kernel_with_handle().context(); auto iter = op_list_.find(context.op_index()); if (iter == op_list_.end()) { GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "op desc not found. op index = %u", context.op_index()); @@ -330,7 +332,7 @@ Status SingleOpModel::BuildKernelTask(const domi::KernelDef &kernel_def, TbeOpTa return ACL_ERROR_GE_MEMORY_ALLOCATION; } - auto builder = TbeTaskBuilder(model_name_, iter->second, kernel_def); + auto builder = TbeTaskBuilder(model_name_, iter->second, task_def); auto ret = builder.BuildTask(*tbe_task, model_params_); if (ret != SUCCESS) { delete tbe_task; @@ -401,13 +403,15 @@ Status SingleOpModel::BuildOp(StreamResource &resource, SingleOp &single_op) { } Status SingleOpModel::BuildModelTaskKernel(const TaskDef &task_def, DynamicSingleOp &single_op) { - const domi::KernelDef &kernel_def = task_def.kernel(); - const auto &context = kernel_def.context(); + auto task_type = static_cast(task_def.type()); + const auto &context = task_type == RT_MODEL_TASK_KERNEL ? task_def.kernel().context() : + task_def.kernel_with_handle().context(); + auto kernel_type = static_cast(context.kernel_type()); if (kernel_type == ccKernelType::TE) { GELOGD("Building TBE task"); TbeOpTask *tbe_task = nullptr; - GE_CHK_STATUS_RET_NOLOG(BuildKernelTask(task_def.kernel(), &tbe_task)); + GE_CHK_STATUS_RET_NOLOG(BuildKernelTask(task_def, &tbe_task)); tbe_task->SetModelArgs(model_name_, model_id_); single_op.op_task_.reset(tbe_task); } else if (kernel_type == ccKernelType::AI_CPU || kernel_type == ccKernelType::CUST_AI_CPU) { @@ -436,7 +440,7 @@ Status SingleOpModel::BuildTaskListForDynamicOp(DynamicSingleOp &single_op) { GELOGI("[%s] Task[%d], type = %u, DebugString = %s", model_name_.c_str(), i, task_def.type(), task_def.DebugString().c_str()); auto task_type = static_cast(task_def.type()); - if (task_type == RT_MODEL_TASK_KERNEL) { + if (task_type == RT_MODEL_TASK_KERNEL || task_type == RT_MODEL_TASK_ALL_KERNEL) { if (single_op.op_task_ != nullptr) { GELOGE(ACL_ERROR_GE_OP_TASK_TYPE_INVALID, "Do not support dynamic op with multiple tasks."); return ACL_ERROR_GE_OP_TASK_TYPE_INVALID; diff --git a/ge/single_op/single_op_model.h b/ge/single_op/single_op_model.h index 6637271c..684dab77 100755 --- a/ge/single_op/single_op_model.h +++ b/ge/single_op/single_op_model.h @@ -67,7 +67,7 @@ class SingleOpModel { Status BuildTaskList(StreamResource *stream_resource, SingleOp &single_op); Status BuildTaskListForDynamicOp(DynamicSingleOp &dynamic_single_op); - Status BuildKernelTask(const domi::KernelDef &kernel_def, TbeOpTask **task); + Status BuildKernelTask(const domi::TaskDef &task_def, TbeOpTask **task); Status BuildKernelExTask(const domi::KernelExDef &kernel_def, AiCpuTask **task, bool dynamic_flag, bool& depend_compute_flag, uint64_t kernel_id); Status BuildCpuKernelTask(const domi::KernelDef &kernel_def, OpTask **task, uint64_t kernel_id); diff --git a/ge/single_op/task/op_task.cc b/ge/single_op/task/op_task.cc index ff200806..4eb7d62d 100755 --- a/ge/single_op/task/op_task.cc +++ b/ge/single_op/task/op_task.cc @@ -93,6 +93,14 @@ void TbeOpTask::SetKernelArgs(std::unique_ptr &&args, size_t arg_size op_desc_ = op_desc; } +void TbeOpTask::SetKernelWithHandleArgs(std::unique_ptr &&args, size_t arg_size, uint32_t block_dim, + const OpDescPtr &op_desc, + const domi::KernelDefWithHandle &kernel_def_with_handle) { + SetKernelArgs(std::move(args), arg_size, block_dim, op_desc); + original_kernel_key_ = kernel_def_with_handle.original_kernel_key(); + node_info_ = kernel_def_with_handle.node_info(); +} + void TbeOpTask::SetSmDesc(void *sm_desc) { sm_desc_ = sm_desc; } void OpTask::SetModelArgs(std::string model_name, uint32_t model_id) { @@ -165,6 +173,10 @@ const std::string &TbeOpTask::GetStubName() const { return stub_name_; } uint32_t TbeOpTask::GetTaskType() const { return kTaskTypeAicore; } +void TbeOpTask::SetHandle(void *handle) { + this->handle_ = handle; +} + Status TbeOpTask::LaunchKernel(rtStream_t stream) { GELOGD("To invoke rtKernelLaunch. task = %s, block_dim = %u", this->stub_name_.c_str(), block_dim_); auto *sm_desc = reinterpret_cast(sm_desc_); @@ -204,6 +216,7 @@ Status TbeOpTask::UpdateRunInfo(const vector &input_desc, const ve } block_dim_ = run_info.block_dim; tiling_data_ = run_info.tiling_data.str(); + tiling_key_ = run_info.tiling_key; GELOGD("Done invoking OpParaCalculate successfully. block_dim = %u, tiling size = %zu", block_dim_, tiling_data_.size()); @@ -329,8 +342,17 @@ Status TbeOpTask::LaunchKernel(const vector &input_desc, } GELOGD("[%s] Start to invoke rtKernelLaunch", node_->GetName().c_str()); - GE_CHK_RT_RET(rtKernelLaunch(stub_func_, block_dim_, args_.get(), arg_size_, nullptr, stream)); - GELOGD("[%s] Done invoking rtKernelLaunch successfully", node_->GetName().c_str()); + if (original_kernel_key_.empty()) { + GE_CHK_RT_RET(rtKernelLaunch(stub_func_, block_dim_, args_.get(), arg_size_, nullptr, stream)); + GELOGD("[%s] Done invoking rtKernelLaunch successfully", node_->GetName().c_str()); + } else { + std::string dev_func = original_kernel_key_ + "_" + std::to_string(tiling_key_); + std::string kernel_info = node_info_ + "/" + std::to_string(tiling_key_); + GE_CHK_RT_RET(rtKernelLaunchWithHandle(handle_, dev_func.c_str(), block_dim_, args_.get(), arg_size_, nullptr, + stream, kernel_info.c_str())); + GELOGD("[%s] Done invoking rtKernelLaunchWithHandle successfully", node_->GetName().c_str()); + } + return SUCCESS; } diff --git a/ge/single_op/task/op_task.h b/ge/single_op/task/op_task.h index 78e1f6f0..be7f4aab 100644 --- a/ge/single_op/task/op_task.h +++ b/ge/single_op/task/op_task.h @@ -78,6 +78,8 @@ class TbeOpTask : public OpTask { void SetSmDesc(void *sm_desc); void SetStubFunc(const std::string &name, const void *stub_func); void SetKernelArgs(std::unique_ptr &&args, size_t arg_size, uint32_t block_dim, const OpDescPtr &op_desc); + void SetKernelWithHandleArgs(std::unique_ptr &&args, size_t arg_size, uint32_t block_dim, + const OpDescPtr &op_desc, const domi::KernelDefWithHandle& kernel_def_with_handle); Status UpdateRunInfo(const vector &input_desc, const vector &output_desc) override; @@ -87,6 +89,7 @@ class TbeOpTask : public OpTask { const std::string &GetStubName() const; void EnableDynamicSupport(const NodePtr &node, void *tiling_buffer, size_t max_tiling_size); uint32_t GetTaskType() const override; + void SetHandle(void *handle); private: friend class SingleOpModel; @@ -107,6 +110,11 @@ class TbeOpTask : public OpTask { std::string tiling_data_; std::vector workspaces_; NodePtr node_; + + uint32_t tiling_key_ = 0; + void* handle_ = nullptr; + std::string original_kernel_key_; + std::string node_info_; }; class AiCpuBaseTask : public OpTask { diff --git a/ge/single_op/task/tbe_task_builder.cc b/ge/single_op/task/tbe_task_builder.cc index 6eee61d0..9ce0bd0d 100644 --- a/ge/single_op/task/tbe_task_builder.cc +++ b/ge/single_op/task/tbe_task_builder.cc @@ -76,10 +76,12 @@ bool KernelBinRegistry::AddKernel(const std::string &stub_name, std::unique_ptr< return ret.second; } -TbeTaskBuilder::TbeTaskBuilder(const std::string &model_name, const NodePtr &node, const domi::KernelDef &kernel_def) +TbeTaskBuilder::TbeTaskBuilder(const std::string &model_name, const NodePtr &node, const domi::TaskDef &task_def) : node_(node), op_desc_(node->GetOpDesc()), - kernel_def_(kernel_def), + task_def_(task_def), + kernel_def_(task_def.kernel()), + kernel_def_with_handle_(task_def.kernel_with_handle()), stub_name_(model_name + "/" + node->GetName() + "_tvmbin") {} Status TbeTaskBuilder::DoRegisterBinary(const OpKernelBin &kernel_bin, void **bin_handle, @@ -89,9 +91,14 @@ Status TbeTaskBuilder::DoRegisterBinary(const OpKernelBin &kernel_bin, void **bi binary.data = kernel_bin.GetBinData(); binary.length = kernel_bin.GetBinDataSize(); binary.magic = param.core_type == 0 ? RT_DEV_BINARY_MAGIC_ELF : RT_DEV_BINARY_MAGIC_ELF_AIVEC; - auto ret = rtDevBinaryRegister(&binary, bin_handle); + Status ret = 0; + if (task_def_.type() == RT_MODEL_TASK_ALL_KERNEL) { + ret = rtRegisterAllKernel(&binary, bin_handle); + } else { + ret = rtDevBinaryRegister(&binary, bin_handle); + } if (ret != RT_ERROR_NONE) { - GELOGE(ret, "rtDevBinaryRegister failed, bin key = %s, core_type = %ld, rt ret = %d", stub_name_.c_str(), + GELOGE(ret, "DoRegisterBinary failed, bin key = %s, core_type = %ld, rt ret = %d", stub_name_.c_str(), param.core_type, static_cast(ret)); return ret; } @@ -128,14 +135,15 @@ Status TbeTaskBuilder::DoRegisterFunction(void *bin_handle, const char *stub_nam Status TbeTaskBuilder::DoRegisterKernel(const ge::OpKernelBin &tbe_kernel, const char *bin_file_key, void **bin_handle, const SingleOpModelParam ¶m) { - std::string kernel_name; - GetKernelName(op_desc_, kernel_name); - void *handle = nullptr; auto ret = DoRegisterBinary(tbe_kernel, &handle, param); if (ret != SUCCESS) { return ret; } + if (task_def_.type() == RT_MODEL_TASK_ALL_KERNEL) { + *bin_handle = handle; + return SUCCESS; + } ret = DoRegisterMeta(handle); if (ret != SUCCESS) { @@ -143,6 +151,8 @@ Status TbeTaskBuilder::DoRegisterKernel(const ge::OpKernelBin &tbe_kernel, const return ret; } + std::string kernel_name; + GetKernelName(op_desc_, kernel_name); ret = DoRegisterFunction(handle, bin_file_key, kernel_name.c_str()); if (ret != SUCCESS) { GE_CHK_RT(rtDevBinaryUnRegister(handle)); @@ -186,6 +196,7 @@ Status TbeTaskBuilder::RegisterKernel(TbeOpTask &task, const SingleOpModelParam void *bin_handle = nullptr; auto ret = DoRegisterKernel(*tbe_kernel, stub_func, &bin_handle, param); + handle_ = bin_handle; if (ret == SUCCESS) { holder->SetBinHandle(bin_handle); if (!registry.AddKernel(stub_name_, std::move(holder))) { @@ -200,6 +211,28 @@ Status TbeTaskBuilder::RegisterKernel(TbeOpTask &task, const SingleOpModelParam return SUCCESS; } +Status TbeTaskBuilder::RegisterKernelWithHandle(TbeOpTask &task, const SingleOpModelParam ¶m) { + GELOGD("RegisterKernelWithHandle begin"); + + auto tbe_kernel = GetTbeKernel(op_desc_); + if (tbe_kernel == nullptr) { + GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "OP EXT ATTR NAME TBE_KERNEL not found. op = %s", + op_desc_->GetName().c_str()); + return ACL_ERROR_GE_INTERNAL_ERROR; + } + + void *bin_handle = nullptr; + auto ret = DoRegisterKernel(*tbe_kernel, nullptr, &bin_handle, param); + handle_ = bin_handle; + if (ret != SUCCESS) { + // should not happen. only one thread can reach here + GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "[DoRegisterKernel] failed. stub name = %s", stub_name_.c_str()); + return ACL_ERROR_GE_INTERNAL_ERROR; + } + + return SUCCESS; +} + Status TbeTaskBuilder::GetSmDesc(void **sm_desc, const SingleOpModelParam ¶m) const { const std::string &sm_desc_str = kernel_def_.sm_desc(); if (sm_desc_str.empty()) { @@ -264,33 +297,77 @@ Status TbeTaskBuilder::SetKernelArgs(TbeOpTask &task, const SingleOpModelParam & return RT_ERROR_TO_GE_STATUS(rtRet); } } - task.SetKernelArgs(std::move(args), arg_size, kernel_def_.block_dim(), op_desc); + + return SUCCESS; +} + +Status TbeTaskBuilder::SetKernelWithHandleArgs(TbeOpTask &task, const SingleOpModelParam ¶m, + const OpDescPtr &op_desc) { + size_t arg_size = kernel_def_with_handle_.args_size(); + auto args = std::unique_ptr(new (std::nothrow) uint8_t[arg_size]); + GE_CHECK_NOTNULL(args); + + auto rtRet = rtMemcpy(args.get(), arg_size, kernel_def_with_handle_.args().data(), arg_size, RT_MEMCPY_HOST_TO_HOST); + if (rtRet != RT_ERROR_NONE) { + GELOGE(rtRet, "rtMemcpy args failed, size = %zu, ret = %d", arg_size, static_cast(rtRet)); + return rtRet; + } + + const domi::KernelContext &context = kernel_def_with_handle_.context(); + const auto *args_offset_tmp = reinterpret_cast(context.args_offset().data()); + uint16_t offset = *args_offset_tmp; + + bool is_dynamic = false; + (void)AttrUtils::GetBool(op_desc_, kAttrSupportDynamicShape, is_dynamic); + if (is_dynamic) { + GE_CHK_STATUS_RET_NOLOG(InitTilingInfo(task)); + } else { + // copy args + std::vector tensor_device_addr_vec = BuildTaskUtils::GetKernelArgs(op_desc_, param); + void *src_addr = reinterpret_cast(tensor_device_addr_vec.data()); + uint64_t src_len = sizeof(void *) * tensor_device_addr_vec.size(); + rtRet = rtMemcpy(args.get() + offset, arg_size - offset, src_addr, src_len, RT_MEMCPY_HOST_TO_HOST); + if (rtRet != RT_ERROR_NONE) { + GELOGE(rtRet, "rtMemcpy addresses failed, ret = %d", static_cast(rtRet)); + return rtRet; + } + } + task.SetKernelWithHandleArgs(std::move(args), arg_size, kernel_def_with_handle_.block_dim(), op_desc, + kernel_def_with_handle_); + return SUCCESS; } Status TbeTaskBuilder::BuildTask(TbeOpTask &task, const SingleOpModelParam ¶m) { GELOGD("Build tbe task begin"); - auto ret = SetKernelArgs(task, param, op_desc_); + auto task_type = static_cast(task_def_.type()); + auto ret = task_type == RT_MODEL_TASK_ALL_KERNEL ? SetKernelWithHandleArgs(task, param, op_desc_) : + SetKernelArgs(task, param, op_desc_); if (ret != SUCCESS) { return ret; } - ret = RegisterKernel(task, param); + ret = task_type == RT_MODEL_TASK_ALL_KERNEL ? RegisterKernelWithHandle(task, param) : + RegisterKernel(task, param); + task.SetHandle(handle_); if (ret != SUCCESS) { return ret; } + auto task_info = BuildTaskUtils::GetTaskInfo(op_desc_); GELOGI("[TASK_INFO] %s %s", stub_name_.c_str(), task_info.c_str()); - void *stub_func = nullptr; - auto rtRet = rtGetFunctionByName(stub_name_.c_str(), &stub_func); - if (rtRet != SUCCESS) { - GELOGE(rtRet, "rtGetFunctionByName failed."); - return RT_ERROR_TO_GE_STATUS(rtRet); + if (task_type != RT_MODEL_TASK_ALL_KERNEL) { + void *stub_func = nullptr; + auto rtRet = rtGetFunctionByName(stub_name_.c_str(), &stub_func); + if (rtRet != SUCCESS) { + GELOGE(rtRet, "rtGetFunctionByName failed."); + return RT_ERROR_TO_GE_STATUS(rtRet); + } + task.SetStubFunc(stub_name_, stub_func); } - task.SetStubFunc(stub_name_, stub_func); return SUCCESS; } diff --git a/ge/single_op/task/tbe_task_builder.h b/ge/single_op/task/tbe_task_builder.h index 5cd5c463..a5e01cfb 100755 --- a/ge/single_op/task/tbe_task_builder.h +++ b/ge/single_op/task/tbe_task_builder.h @@ -63,7 +63,7 @@ class KernelBinRegistry { class TbeTaskBuilder { public: - TbeTaskBuilder(const std::string &model_name, const NodePtr &node, const domi::KernelDef &kernel_def); + TbeTaskBuilder(const std::string &model_name, const NodePtr &node, const domi::TaskDef &task_def); ~TbeTaskBuilder() = default; Status BuildTask(TbeOpTask &task, const SingleOpModelParam ¶m); @@ -71,9 +71,11 @@ class TbeTaskBuilder { private: Status InitTilingInfo(TbeOpTask &task); Status SetKernelArgs(TbeOpTask &task, const SingleOpModelParam ¶m, const OpDescPtr &op_desc); + Status SetKernelWithHandleArgs(TbeOpTask &task, const SingleOpModelParam ¶m, const OpDescPtr &op_desc); Status GetSmDesc(void **sm_desc, const SingleOpModelParam ¶m) const; Status RegisterKernel(TbeOpTask &task, const SingleOpModelParam ¶m); + Status RegisterKernelWithHandle(TbeOpTask &task, const SingleOpModelParam ¶m); Status DoRegisterKernel(const OpKernelBin &kernel_bin, const char *bin_file_key, void **bin_handle, const SingleOpModelParam ¶m); Status DoRegisterBinary(const OpKernelBin &kernel_bin, void **bin_handle, const SingleOpModelParam ¶m) const; @@ -83,8 +85,11 @@ class TbeTaskBuilder { const NodePtr node_; const OpDescPtr op_desc_; + const domi::TaskDef &task_def_; const domi::KernelDef &kernel_def_; + const domi::KernelDefWithHandle &kernel_def_with_handle_; const std::string stub_name_; + void *handle_ = nullptr; }; } // namespace ge