From fddc9f01ca6e651227bf90358615b16805361f92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=B6=9B?= Date: Thu, 4 Mar 2021 23:09:35 +0800 Subject: [PATCH] =?UTF-8?q?=E5=9B=9E=E9=80=80=20'Pull=20Request=20!1189=20?= =?UTF-8?q?:=20multi-kernel=20modification'?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ge/hybrid/model/hybrid_model_builder.cc | 7 +-- ge/single_op/single_op_model.cc | 9 ++- tests/ut/ge/hybrid/ge_hybrid_unittest.cc | 80 ------------------------ 3 files changed, 6 insertions(+), 90 deletions(-) diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index 48558e83..7ea9e446 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -1131,22 +1131,19 @@ Status HybridModelBuilder::IndexTaskDefs(const ComputeGraphPtr &sub_graph, const op_index = task_def.kernel_ex().op_index(); } else if (task_type == RT_MODEL_TASK_HCCL) { op_index = task_def.kernel_hccl().op_index(); - } else if (task_type == RT_MODEL_TASK_ALL_KERNEL) { - op_index = task_def.kernel_with_handle().context().op_index(); } else { GELOGD("Skip task type: %d", static_cast(task_type)); continue; } - GELOGD("op_index = %u, task_type = %d", op_index, task_type); auto iter = node_map.find(op_index); if (iter == node_map.end()) { - GELOGE(INTERNAL_ERROR, "Failed to get node by op_index = %u", op_index); + GELOGE(INTERNAL_ERROR, "Failed to get node by index = %u", op_index); return INTERNAL_ERROR; } auto &node = iter->second; - if (task_type == RT_MODEL_TASK_KERNEL || task_type == RT_MODEL_TASK_ALL_KERNEL) { + if (task_type == RT_MODEL_TASK_KERNEL) { ge_model->GetTBEKernelStore().LoadTBEKernelBinToOpDesc(node->GetOpDesc()); } diff --git a/ge/single_op/single_op_model.cc b/ge/single_op/single_op_model.cc index 49dde9c4..43c47894 100755 --- a/ge/single_op/single_op_model.cc +++ b/ge/single_op/single_op_model.cc @@ -48,8 +48,7 @@ bool NeedHybridModel(GeModelPtr &ge_model) { auto tasks = ge_model->GetModelTaskDefPtr()->task(); int32_t kernel_task_num = 0; for (int i = 0; i < tasks.size(); ++i) { - auto task_type = static_cast(tasks[i].type()); - if (task_type == RT_MODEL_TASK_KERNEL || task_type == RT_MODEL_TASK_ALL_KERNEL) { + if (static_cast(tasks[i].type()) == RT_MODEL_TASK_KERNEL) { kernel_task_num++; if (kernel_task_num > 1) { return true; @@ -255,9 +254,9 @@ Status SingleOpModel::BuildTaskList(StreamResource *stream_resource, SingleOp &s GELOGI("[%s] Task[%d], type = %u, DebugString = %s", model_name_.c_str(), i, task_def.type(), task_def.DebugString().c_str()); auto task_type = static_cast(task_def.type()); - if (task_type == RT_MODEL_TASK_KERNEL || task_type == RT_MODEL_TASK_ALL_KERNEL) { - const auto &context = task_type == RT_MODEL_TASK_KERNEL ? task_def.kernel().context() : - task_def.kernel_with_handle().context(); + if (task_type == RT_MODEL_TASK_KERNEL) { + const domi::KernelDef &kernel_def = task_def.kernel(); + const auto &context = kernel_def.context(); auto kernel_type = static_cast(context.kernel_type()); if (kernel_type == ccKernelType::TE) { GELOGD("Building TBE task"); diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index 0b6ca271..97a36894 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -41,7 +41,6 @@ using namespace std; using namespace testing; using namespace ge; -using namespace hybrid; class UtestGeHybrid : public testing::Test { protected: @@ -111,83 +110,4 @@ TEST_F(UtestGeHybrid, task_update_tiling_info) { auto node = graph->AddNode(op_desc); optiling::OpRunInfo tiling_info; ASSERT_EQ(aicore_task->CalcTilingInfo(node, tiling_info), SUCCESS); -} - -TEST_F(UtestGeHybrid, index_taskdefs_failed) { - // build aicore task - domi::ModelTaskDef model_task_def; - - std::shared_ptr model_task_def_ptr = make_shared(model_task_def); - domi::TaskDef *task_def = model_task_def_ptr->add_task(); - GeModelPtr ge_model = make_shared(); - ge_model->SetModelTaskDef(model_task_def_ptr); - - auto aicore_task = std::unique_ptr(new(std::nothrow)hybrid::AiCoreOpTask()); - task_def->set_type(RT_MODEL_TASK_ALL_KERNEL); - domi::KernelDefWithHandle *kernel_with_handle = task_def->mutable_kernel_with_handle(); - kernel_with_handle->set_original_kernel_key(""); - kernel_with_handle->set_node_info(""); - kernel_with_handle->set_block_dim(32); - kernel_with_handle->set_args_size(64); - string args(64, '1'); - kernel_with_handle->set_args(args.data(), 64); - domi::KernelContext *context = kernel_with_handle->mutable_context(); - context->set_op_index(1); - context->set_kernel_type(2); // ccKernelType::TE - uint16_t args_offset[9] = {0}; - context->set_args_offset(args_offset, 9 * sizeof(uint16_t)); - - OpDescPtr op_desc = CreateOpDesc("Add", "Add"); - std::vector kernelBin; - TBEKernelPtr tbe_kernel = std::make_shared("name/Add", std::move(kernelBin)); - op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel); - std::string kernel_name("kernel/Add"); - AttrUtils::SetStr(op_desc, op_desc->GetName() + "_kernelname", kernel_name); - - ComputeGraphPtr graph = std::make_shared("test"); - GeRootModelPtr ge_root_model = make_shared(graph); - HybridModel hybrid_model(ge_root_model); - HybridModelBuilder hybrid_model_builder(hybrid_model); - - ASSERT_EQ(hybrid_model_builder.IndexTaskDefs(graph, ge_model), INTERNAL_ERROR); -} - -TEST_F(UtestGeHybrid, index_taskdefs_success) { - // build aicore task - domi::ModelTaskDef model_task_def; - - std::shared_ptr model_task_def_ptr = make_shared(model_task_def); - domi::TaskDef *task_def = model_task_def_ptr->add_task(); - GeModelPtr ge_model = make_shared(); - ge_model->SetModelTaskDef(model_task_def_ptr); - - auto aicore_task = std::unique_ptr(new(std::nothrow)hybrid::AiCoreOpTask()); - task_def->set_type(RT_MODEL_TASK_ALL_KERNEL); - domi::KernelDefWithHandle *kernel_with_handle = task_def->mutable_kernel_with_handle(); - kernel_with_handle->set_original_kernel_key(""); - kernel_with_handle->set_node_info(""); - kernel_with_handle->set_block_dim(32); - kernel_with_handle->set_args_size(64); - string args(64, '1'); - kernel_with_handle->set_args(args.data(), 64); - domi::KernelContext *context = kernel_with_handle->mutable_context(); - context->set_op_index(0); - context->set_kernel_type(2); // ccKernelType::TE - uint16_t args_offset[9] = {0}; - context->set_args_offset(args_offset, 9 * sizeof(uint16_t)); - - OpDescPtr op_desc = CreateOpDesc("Add", "Add"); - std::vector kernelBin; - TBEKernelPtr tbe_kernel = std::make_shared("name/Add", std::move(kernelBin)); - op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel); - std::string kernel_name("kernel/Add"); - AttrUtils::SetStr(op_desc, op_desc->GetName() + "_kernelname", kernel_name); - - ComputeGraphPtr graph = std::make_shared("test"); - NodePtr node = graph->AddNode(op_desc); - GeRootModelPtr ge_root_model = make_shared(graph); - HybridModel hybrid_model(ge_root_model); - HybridModelBuilder hybrid_model_builder(hybrid_model); - - ASSERT_EQ(hybrid_model_builder.IndexTaskDefs(graph, ge_model), SUCCESS); } \ No newline at end of file