Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
| @@ -252,7 +252,6 @@ if(NOT ENABLE_GE) | |||
| FILES | |||
| ${CMAKE_BINARY_DIR}/graphengine/metadef/graph/libgraph.so | |||
| ${CMAKE_BINARY_DIR}/graphengine/ge/common/libge_common.so | |||
| ${CMAKE_BINARY_DIR}/graphengine/ge/ge_runtime/libge_runtime.so | |||
| DESTINATION ${INSTALL_LIB_DIR} | |||
| COMPONENT mindspore | |||
| ) | |||
| @@ -309,7 +309,7 @@ if(ENABLE_D) | |||
| target_link_options(ms_profile PRIVATE -Wl,-init,common_log_init) | |||
| target_link_libraries(ms_profile -Wl,--start-group -Wl,--whole-archive ${PROFILING} -Wl,--no-whole-archive | |||
| mindspore::protobuf -Wl,--end-group) | |||
| target_link_libraries(mindspore ge_runtime ${CCE_LIB} ${RUNTIME_LIB} ${TSDCLIENT} ${HCCL} ${DATATRANSFER} | |||
| target_link_libraries(mindspore ${CCE_LIB} ${RUNTIME_LIB} ${TSDCLIENT} ${HCCL} ${DATATRANSFER} | |||
| ${HCCL_ADPTER} ${REGISTER} -Wl,--no-as-needed ${OPTILING} ${HCCL_BUILDER} | |||
| ${HCCL_RA} ${PLATFORM} ${ACL}) | |||
| target_link_libraries(mindspore -Wl,--start-group proto_input mindspore::protobuf -Wl,--end-group) | |||
| @@ -30,7 +30,7 @@ | |||
| #include "runtime/device/kernel_runtime.h" | |||
| #include "runtime/device/ascend/executor/host_dynamic_kernel.h" | |||
| using AicpuTaskInfoPtr = std::shared_ptr<ge::model_runner::AicpuTaskInfo>; | |||
| using AicpuTaskInfoPtr = std::shared_ptr<mindspore::ge::model_runner::AicpuTaskInfo>; | |||
| using AicpuDynamicKernel = mindspore::device::ascend::AiCpuDynamicKernel; | |||
| using HostDynamicKernel = mindspore::device::ascend::HostDynamicKernel; | |||
| @@ -193,9 +193,9 @@ std::vector<TaskInfoPtr> AicpuOpKernelMod::GenTask(const std::vector<AddressPtr> | |||
| node_name_ = kPack; | |||
| } | |||
| AicpuTaskInfoPtr task_info_ptr = | |||
| make_shared<ge::model_runner::AicpuTaskInfo>(kernel_name_, stream_id, node_so_, node_name_, node_def_str_, | |||
| ext_info_, input_data_addrs, output_data_addrs, NeedDump()); | |||
| AicpuTaskInfoPtr task_info_ptr = std::make_shared<mindspore::ge::model_runner::AicpuTaskInfo>( | |||
| kernel_name_, stream_id, node_so_, node_name_, node_def_str_, ext_info_, input_data_addrs, output_data_addrs, | |||
| NeedDump()); | |||
| MS_LOG(INFO) << "AicpuOpKernelMod GenTask end"; | |||
| return {task_info_ptr}; | |||
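This hunk and the similar ones below all make the same change: the ge_runtime types now live under `mindspore::ge::model_runner`, because the ge_runtime sources were moved into the MindSpore tree, so every `using` alias and `std::make_shared` call is re-qualified. A minimal sketch of the resulting alias style (names taken from the diff; the header path is the relocated one introduced later in this PR):

```cpp
// Sketch only: the re-qualified alias pattern used throughout this PR.
#include <memory>
#include "runtime/device/ascend/ge_runtime/task_info.h"

using AicpuTaskInfoPtr = std::shared_ptr<mindspore::ge::model_runner::AicpuTaskInfo>;

// Construction sites are fully qualified as well:
// auto task_info = std::make_shared<mindspore::ge::model_runner::AicpuTaskInfo>(/* ... */);
```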
| @@ -29,7 +29,7 @@ using std::fstream; | |||
| using std::map; | |||
| using std::mutex; | |||
| using std::string; | |||
| using TbeTaskInfoPtr = std::shared_ptr<ge::model_runner::TbeTaskInfo>; | |||
| using TbeTaskInfoPtr = std::shared_ptr<mindspore::ge::model_runner::TbeTaskInfo>; | |||
| using tbe::KernelManager; | |||
| constexpr uint32_t DEFAULT_BLOCK_DIM = 1; | |||
| /** | |||
| @@ -118,7 +118,7 @@ std::vector<TaskInfoPtr> AkgKernelMod::GenTask(const std::vector<AddressPtr> &in | |||
| MS_LOG(DEBUG) << "The block_dim is:" << block_dim; | |||
| TbeTaskInfoPtr task_info_ptr = make_shared<ge::model_runner::TbeTaskInfo>( | |||
| TbeTaskInfoPtr task_info_ptr = std::make_shared<mindspore::ge::model_runner::TbeTaskInfo>( | |||
| kernel_name_, stream_id, stub_func, block_dim, args, args_size, sm_desc, binary, binary_size, meta_data, | |||
| input_data_addrs, output_data_addrs, workspace_addrs, NeedDump()); | |||
| return {task_info_ptr}; | |||
| @@ -19,11 +19,11 @@ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "framework/ge_runtime/task_info.h" | |||
| #include "runtime/device/ascend/ge_runtime/task_info.h" | |||
| #include "backend/kernel_compiler/kernel.h" | |||
| #include "debug/data_dump/dump_json_parser.h" | |||
| using TaskInfoPtr = std::shared_ptr<ge::model_runner::TaskInfo>; | |||
| using TaskInfoPtr = std::shared_ptr<mindspore::ge::model_runner::TaskInfo>; | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| class AscendKernelMod : public KernelMod { | |||
| @@ -24,8 +24,8 @@ | |||
| #include "runtime/device/ascend/executor/hccl_dynamic_kernel.h" | |||
| #include "runtime/hccl_adapter/hccl_adapter.h" | |||
| using HcclTaskInfoPtr = std::shared_ptr<ge::model_runner::HcclTaskInfo>; | |||
| using ge::model_runner::HcclTaskInfo; | |||
| using HcclTaskInfoPtr = std::shared_ptr<mindspore::ge::model_runner::HcclTaskInfo>; | |||
| using mindspore::ge::model_runner::HcclTaskInfo; | |||
| namespace { | |||
| static std::map<std::string, std::string> kMsOpNameToHcomHcclType = { | |||
| @@ -18,7 +18,7 @@ | |||
| #include <memory> | |||
| #include "runtime/mem.h" | |||
| using ge::model_runner::MemcpyAsyncTaskInfo; | |||
| using mindspore::ge::model_runner::MemcpyAsyncTaskInfo; | |||
| using MemcpyAsyncTaskInfoPtr = std::shared_ptr<MemcpyAsyncTaskInfo>; | |||
| namespace mindspore { | |||
| @@ -20,7 +20,7 @@ | |||
| #include "framework/ge_runtime/task_info.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| using ge::model_runner::LabelGotoTaskInfo; | |||
| using mindspore::ge::model_runner::LabelGotoTaskInfo; | |||
| using LabelGotoTaskInfoPtr = std::shared_ptr<LabelGotoTaskInfo>; | |||
| namespace mindspore { | |||
| @@ -20,7 +20,7 @@ | |||
| #include "framework/ge_runtime/task_info.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| using ge::model_runner::LabelSetTaskInfo; | |||
| using mindspore::ge::model_runner::LabelSetTaskInfo; | |||
| using LabelSetTaskInfoPtr = std::shared_ptr<LabelSetTaskInfo>; | |||
| namespace mindspore { | |||
| @@ -21,7 +21,7 @@ | |||
| #include "framework/ge_runtime/task_info.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| using ge::model_runner::LabelSwitchTaskInfo; | |||
| using mindspore::ge::model_runner::LabelSwitchTaskInfo; | |||
| using LabelSwitchTaskInfoPtr = std::shared_ptr<LabelSwitchTaskInfo>; | |||
| namespace mindspore { | |||
| @@ -25,7 +25,7 @@ | |||
| #include "runtime/device/kernel_runtime.h" | |||
| #include "runtime/device/ascend/executor/rts/memcpy_rts_dynamic_kernel.h" | |||
| using ge::model_runner::MemcpyAsyncTaskInfo; | |||
| using mindspore::ge::model_runner::MemcpyAsyncTaskInfo; | |||
| using MemcpyAsyncTaskInfoPtr = std::shared_ptr<MemcpyAsyncTaskInfo>; | |||
| using AddressPtrList = std::vector<mindspore::kernel::AddressPtr>; | |||
| using mindspore::device::ascend::MemcpyRtsDynamicKernel; | |||
| @@ -23,7 +23,7 @@ | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "runtime/device/ascend/executor/rts/profiling_rts_dynamic_kernel.h" | |||
| using ProfilerTraceTaskInfo = ge::model_runner::ProfilerTraceTaskInfo; | |||
| using ProfilerTraceTaskInfo = mindspore::ge::model_runner::ProfilerTraceTaskInfo; | |||
| using mindspore::device::ascend::ProfilingRtsDynamicKernel; | |||
| using mindspore::device::ascend::ProfilingUtils; | |||
| @@ -23,7 +23,7 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| using ge::model_runner::EventWaitTaskInfo; | |||
| using mindspore::ge::model_runner::EventWaitTaskInfo; | |||
| using EventWaitTaskInfoPtr = std::shared_ptr<EventWaitTaskInfo>; | |||
| RecvKernel::RecvKernel() { event_id_ = 0; } | |||
| @@ -20,7 +20,7 @@ | |||
| #include "framework/ge_runtime/task_info.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| using ge::model_runner::EventRecordTaskInfo; | |||
| using mindspore::ge::model_runner::EventRecordTaskInfo; | |||
| using EventRecordTaskInfoPtr = std::shared_ptr<EventRecordTaskInfo>; | |||
| namespace mindspore { | |||
| @@ -20,7 +20,7 @@ | |||
| #include "framework/ge_runtime/task_info.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| using ge::model_runner::StreamActiveTaskInfo; | |||
| using mindspore::ge::model_runner::StreamActiveTaskInfo; | |||
| using StreamActiveTaskInfoPtr = std::shared_ptr<StreamActiveTaskInfo>; | |||
| namespace mindspore { | |||
| @@ -21,7 +21,7 @@ | |||
| #include "framework/ge_runtime/task_info.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| using ge::model_runner::StreamSwitchTaskInfo; | |||
| using mindspore::ge::model_runner::StreamSwitchTaskInfo; | |||
| using StreamSwitchTaskInfoPtr = std::shared_ptr<StreamSwitchTaskInfo>; | |||
| namespace mindspore { | |||
| @@ -24,7 +24,7 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| using TbeTaskInfoPtr = std::shared_ptr<ge::model_runner::TbeTaskInfo>; | |||
| using TbeTaskInfoPtr = std::shared_ptr<mindspore::ge::model_runner::TbeTaskInfo>; | |||
| using tbe::KernelManager; | |||
| using AddressPtrList = std::vector<mindspore::kernel::AddressPtr>; | |||
| bool TbeKernelMod::Launch(const std::vector<mindspore::kernel::AddressPtr> &inputs, | |||
| @@ -102,7 +102,7 @@ std::vector<TaskInfoPtr> TbeKernelMod::GenTask(const std::vector<AddressPtr> &in | |||
| MS_LOG(INFO) << "block_dim is:" << block_dim_; | |||
| TbeTaskInfoPtr task_info_ptr = make_shared<ge::model_runner::TbeTaskInfo>( | |||
| TbeTaskInfoPtr task_info_ptr = std::make_shared<mindspore::ge::model_runner::TbeTaskInfo>( | |||
| kernel_name_, stream_id, stub_func, block_dim_, args, 0, sm_desc, nullptr, 0, meta_data, input_data_addrs, | |||
| output_data_addrs, workspace_addrs, NeedDump()); | |||
| return {task_info_ptr}; | |||
| @@ -36,7 +36,7 @@ using mindspore::kernel::tbe::TbeUtils; | |||
| bool TbeOpParallelBuild(const std::vector<AnfNodePtr> &anf_nodes) { | |||
| auto build_manger = std::make_shared<ParallelBuildManager>(); | |||
| MS_EXCEPTION_IF_NULL(build_manger); | |||
| static set<std::string> processed_kernel; | |||
| static std::set<std::string> processed_kernel; | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| auto tune_mode = context_ptr->get_param<std::string>(MS_CTX_TUNE_MODE); | |||
| @@ -259,8 +259,8 @@ bool ParallelBuildManager::SearchInCache(const std::string &json_name, const std | |||
| } | |||
| KernelModPtr ParallelBuildManager::GenKernelMod(const string &json_name, const string &processor, | |||
| const vector<size_t> &input_size_list, | |||
| const vector<size_t> &output_size_list, | |||
| const std::vector<size_t> &input_size_list, | |||
| const std::vector<size_t> &output_size_list, | |||
| const mindspore::kernel::KernelPackPtr &kernel_pack) const { | |||
| MS_EXCEPTION_IF_NULL(kernel_pack); | |||
| auto kernel_json_info = kernel_pack->kernel_json_info(); | |||
| @@ -27,6 +27,7 @@ | |||
| #include "proto/tensor_shape.pb.h" | |||
| #include "proto/attr.pb.h" | |||
| #include "proto/node_def.pb.h" | |||
| #include "runtime/rt.h" | |||
| using mindspore::kernel::Address; | |||
| using AddressPtr = std::shared_ptr<Address>; | |||
| @@ -24,6 +24,7 @@ | |||
| #include "ps/ps_cache/ps_cache_basic.h" | |||
| #include "backend/kernel_compiler/aicpu/aicpu_kernel_mod.h" | |||
| #include "ir/dtype.h" | |||
| #include "runtime/base.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| @@ -79,3 +79,9 @@ list(REMOVE_ITEM D_SRC_LIST "ascend/profiling/profiling_callback_register.cc") | |||
| set_property(SOURCE ${DEVICE_SRC_LIST} ${D_SRC_LIST} ${CPU_SRC_LIST} | |||
| PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE) | |||
| add_library(_mindspore_runtime_device_obj OBJECT ${DEVICE_SRC_LIST} ${D_SRC_LIST} ${CPU_SRC_LIST} ${TDT_SRC_LIST}) | |||
| if(ENABLE_D) | |||
| file(GLOB_RECURSE GE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ascend/ge_runtime/*.cc") | |||
| set_property(SOURCE ${GE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_GE) | |||
| target_include_directories(_mindspore_runtime_device_obj PRIVATE ${CMAKE_BINARY_DIR}/proto/ge) | |||
| add_dependencies(_mindspore_runtime_device_obj graph) | |||
| endif() | |||
| @@ -28,9 +28,9 @@ | |||
| #include "utils/mpi/mpi_config.h" | |||
| #include "runtime/device/ascend/profiling/profiling_manager.h" | |||
| #include "common/trans.h" | |||
| #include "runtime/context.h" | |||
| #include "runtime/rt.h" | |||
| #include "runtime/device/ascend/ascend_stream_assign.h" | |||
| #include "framework/ge_runtime/model_runner.h" | |||
| #include "runtime/device/ascend/ge_runtime/model_runner.h" | |||
| #include "runtime/device/ascend/tasksink/task_generator.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "runtime/device/ascend/profiling/profiling_utils.h" | |||
| @@ -40,7 +40,6 @@ | |||
| #include "toolchain/adx_datadump_server.h" | |||
| #include "utils/trace_base.h" | |||
| #include "graphengine/inc/external/acl/error_codes/rt_error_codes.h" | |||
| #include "utils/runtime_error_codes.h" | |||
| #include "debug/anf_ir_dump.h" | |||
| #ifdef MEM_REUSE_DEBUG | |||
| #include "backend/optimizer/mem_reuse/mem_reuse_checker.h" | |||
| @@ -61,10 +60,10 @@ using mindspore::dataset::TdtHandle; | |||
| #include "debug/rdr/running_data_recorder.h" | |||
| #endif | |||
| using ge::model_runner::ModelRunner; | |||
| using mindspore::device::ascend::ProfilingManager; | |||
| using mindspore::device::ascend::ProfilingUtils; | |||
| using mindspore::device::ascend::tasksink::TaskGenerator; | |||
| using mindspore::ge::model_runner::ModelRunner; | |||
| using mindspore::kernel::tbe::TbeUtils; | |||
| using std::vector; | |||
| @@ -158,10 +157,7 @@ void AscendKernelRuntime::ClearGraphModelMap() { | |||
| graph_kernel_events_map_.clear(); | |||
| for (auto &iter : graph_model_map_) { | |||
| MS_LOG(INFO) << "Ge UnloadModel " << iter.first; | |||
| auto ret = ModelRunner::Instance().UnloadModel(iter.first); | |||
| if (!ret) { | |||
| MS_LOG(ERROR) << "UnloadModel failed"; | |||
| } | |||
| ModelRunner::Instance().UnloadModel(iter.first); | |||
| } | |||
| } | |||
| @@ -194,10 +190,7 @@ void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std | |||
| MS_LOG(DEBUG) << "Clear graph:" << graph_id << " runtime resource"; | |||
| if (auto model_iter = graph_model_map_.find(graph_id); model_iter != graph_model_map_.end()) { | |||
| MS_LOG(DEBUG) << "Ge UnloadModel " << graph_id; | |||
| auto ret = ModelRunner::Instance().UnloadModel(graph_id); | |||
| if (!ret) { | |||
| MS_LOG(ERROR) << "UnloadModel failed"; | |||
| } | |||
| ModelRunner::Instance().UnloadModel(graph_id); | |||
| graph_model_map_.erase(model_iter); | |||
| } else { | |||
| MS_LOG(DEBUG) << "GraphId:" << graph_id << " not found"; | |||
| @@ -482,10 +475,9 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { | |||
| << ", total label num:" << graph->label_num() | |||
| << ", wait_active_stream_list size:" << wait_active_stream_list.size() | |||
| << ", force_copy_stream_list size:" << force_copy_stream_list.size(); | |||
| std::vector<std::shared_ptr<ge::model_runner::OpInfo>> empty_list; | |||
| auto model = std::make_shared<ge::model_runner::DavinciModel>( | |||
| task_info_list, empty_list, empty_list, empty_list, empty_list, wait_active_stream_list, force_copy_stream_list, 0, | |||
| 0, 0, 0, 0, 0, resource_manager.get_cur_stream_num(), graph->label_num(), resource_manager.get_cur_event_num(), 0); | |||
| task_info_list, wait_active_stream_list, force_copy_stream_list, 0, 0, 0, 0, 0, 0, | |||
| resource_manager.get_cur_stream_num(), graph->label_num(), resource_manager.get_cur_event_num(), 0); | |||
| auto ret = graph_model_map_.insert(std::make_pair(graph->graph_id(), model)); | |||
| if (!ret.second) { | |||
| MS_LOG(EXCEPTION) << "Duplicate GraphId! Please check in ascend_session."; | |||
| @@ -514,24 +506,20 @@ bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) { | |||
| return false; | |||
| } | |||
| std::shared_ptr<ge::ModelListener> listener; | |||
| MS_LOG(INFO) << "LoadDavinciModel mode_id:" << model_iter->first; | |||
| bool status = | |||
| ModelRunner::Instance().LoadDavinciModel(device_id_, 0, model_iter->first, model_iter->second, listener); | |||
| if (!status) { | |||
| MS_LOG(EXCEPTION) << "Load Model Failed"; | |||
| } | |||
| ModelRunner::Instance().LoadDavinciModel(device_id_, 0, model_iter->first, model_iter->second); | |||
| std::function<void *()> model_handle = | |||
| std::bind(&ModelRunner::GetModelHandle, &ModelRunner::Instance(), model_iter->first); | |||
| DistributeDebugTask(NOT_NULL(graph), NOT_NULL(model_handle)); | |||
| status = ModelRunner::Instance().DistributeTask(model_iter->first); | |||
| if (!status) { | |||
| try { | |||
| ModelRunner::Instance().DistributeTask(model_iter->first); | |||
| } catch (const std::exception &e) { | |||
| #ifdef ENABLE_DUMP_IR | |||
| mindspore::RDR::TriggerAll(); | |||
| #endif | |||
| MS_LOG(EXCEPTION) << "Distribute Task Failed"; | |||
| MS_LOG(EXCEPTION) << "Distribute Task Failed, error: " << e.what(); | |||
| } | |||
| if (ProfilingManager::GetInstance().IsProfiling()) { | |||
| @@ -542,10 +530,7 @@ bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) { | |||
| LaunchDataDump(graph->graph_id()); | |||
| if (!ModelRunner::Instance().LoadModelComplete(model_iter->first)) { | |||
| MS_LOG(ERROR) << "Call ge runtime LoadModelComplete failed"; | |||
| return false; | |||
| } | |||
| ModelRunner::Instance().LoadModelComplete(model_iter->first); | |||
| return true; | |||
| } | |||
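The ge_runtime entry points (`LoadDavinciModel`, `DistributeTask`, `LoadModelComplete`, `UnloadModel`, `RunModel`) now report failure by throwing via `MS_LOG(EXCEPTION)` instead of returning `bool`, which is why the status checks above disappear. Only `DistributeTask` stays wrapped, so the running-data recorder can be flushed before the error is re-raised. A minimal distillation of that pattern:

```cpp
// The new error-handling contract: ge_runtime calls either succeed or throw,
// so callers add a try/catch only where extra diagnostics are needed.
try {
  ModelRunner::Instance().DistributeTask(model_id);
} catch (const std::exception &e) {
#ifdef ENABLE_DUMP_IR
  mindspore::RDR::TriggerAll();  // dump recorder data before failing
#endif
  MS_LOG(EXCEPTION) << "Distribute Task Failed, error: " << e.what();
}
```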
| @@ -730,8 +715,6 @@ bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| ge::InputData input_tensors = ge::InputData(); | |||
| ge::OutputData *output_tensors = nullptr; | |||
| if (GraphWithEmptyTaskList(graph)) { | |||
| MS_LOG(WARNING) << "RunTask end, no task info found"; | |||
| return true; | |||
| @@ -742,8 +725,9 @@ bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) { | |||
| return false; | |||
| } | |||
| bool status = ModelRunner::Instance().RunModel(graph->graph_id(), input_tensors, output_tensors); | |||
| if (!status) { | |||
| try { | |||
| ModelRunner::Instance().RunModel(graph->graph_id()); | |||
| } catch (const std::exception &) { | |||
| DumpTaskExceptionInfo(graph); | |||
| std::string file_name = "task_error_debug" + std::to_string(graph->graph_id()) + ".ir"; | |||
| auto graph_tmp = std::make_shared<session::KernelGraph>(*graph); | |||
| @@ -988,7 +972,7 @@ void AscendKernelRuntime::KernelLaunchProfiling(const std::string &kernel_name) | |||
| } | |||
| uint64_t AscendKernelRuntime::GetAvailableMemMaxSize() const { | |||
| auto ascend_mem_manager = dynamic_pointer_cast<AscendMemoryManager>(mem_manager_); | |||
| auto ascend_mem_manager = std::dynamic_pointer_cast<AscendMemoryManager>(mem_manager_); | |||
| return ascend_mem_manager->GetDeviceMemSize(); | |||
| } | |||
| @@ -25,15 +25,15 @@ | |||
| #include <unordered_set> | |||
| #include "runtime/device/kernel_runtime.h" | |||
| #include "runtime/context.h" | |||
| #include "framework/ge_runtime/davinci_model.h" | |||
| #include "runtime/device/ascend/ge_runtime/davinci_model.h" | |||
| #include "runtime/device/kernel_runtime_manager.h" | |||
| #include "backend/session/session_basic.h" | |||
| #include "runtime/device/ascend/dump/data_dumper.h" | |||
| using ge::model_runner::TaskInfo; | |||
| using std::unordered_map; | |||
| using std::vector; | |||
| namespace mindspore::device::ascend { | |||
| using ge::model_runner::TaskInfo; | |||
| class AscendKernelRuntime : public KernelRuntime { | |||
| public: | |||
| AscendKernelRuntime() = default; | |||
| @@ -16,6 +16,7 @@ | |||
| #include <algorithm> | |||
| #include "runtime/device/ascend/ascend_memory_pool.h" | |||
| #include "runtime/mem.h" | |||
| #include "runtime/device/ascend/ascend_kernel_runtime.h" | |||
| #include "utils/log_adapter.h" | |||
| @@ -0,0 +1,92 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_DAVINCI_MODEL_H_ | |||
| #define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_DAVINCI_MODEL_H_ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "runtime/device/ascend/ge_runtime/task_info.h" | |||
| namespace mindspore::ge::model_runner { | |||
| class DavinciModel { | |||
| public: | |||
| DavinciModel(const std::vector<std::shared_ptr<TaskInfo>> &task_info_list, | |||
| const std::vector<uint32_t> &wait_active_stream_list, | |||
| const std::vector<uint32_t> &force_copy_stream_list, uint64_t mem_size = 0, uint64_t weight_size = 0, | |||
| uint64_t var_size = 0, uintptr_t logic_mem_base = 0, uintptr_t logic_weight_base = 0, | |||
| uintptr_t logic_var_base = 0, uint32_t stream_num = 0, uint32_t batch_num = 0, uint32_t event_num = 0, | |||
| int32_t priority = 0) | |||
| : task_info_list_(task_info_list), | |||
| wait_active_stream_list_(wait_active_stream_list), | |||
| force_copy_stream_list_(force_copy_stream_list), | |||
| mem_size_(mem_size), | |||
| weight_size_(weight_size), | |||
| var_size_(var_size), | |||
| logic_mem_base_(logic_mem_base), | |||
| logic_weight_base_(logic_weight_base), | |||
| logic_var_base_(logic_var_base), | |||
| stream_num_(stream_num), | |||
| batch_num_(batch_num), | |||
| event_num_(event_num), | |||
| priority_(priority) {} | |||
| ~DavinciModel() {} | |||
| uint64_t GetMemSize() const { return mem_size_; } | |||
| uint64_t GetWeightSize() const { return weight_size_; } | |||
| uint64_t GetVarSize() const { return var_size_; } | |||
| uintptr_t GetLogicMemBase() const { return logic_mem_base_; } | |||
| uintptr_t GetLogicWeightBase() const { return logic_weight_base_; } | |||
| uintptr_t GetLogicVarBase() const { return logic_var_base_; } | |||
| uint32_t GetStreamNum() const { return stream_num_; } | |||
| uint32_t GetBatchNum() const { return batch_num_; } | |||
| uint32_t GetEventNum() const { return event_num_; } | |||
| const std::vector<uint32_t> &GetWaitActiveStreams() const { return wait_active_stream_list_; } | |||
| const std::vector<uint32_t> &GetForceCopyStreams() const { return force_copy_stream_list_; } | |||
| int32_t GetPriority() const { return priority_; } | |||
| const std::vector<std::shared_ptr<TaskInfo>> &GetTaskInfoList() const { return task_info_list_; } | |||
| private: | |||
| std::vector<std::shared_ptr<TaskInfo>> task_info_list_; | |||
| std::vector<uint32_t> wait_active_stream_list_; | |||
| std::vector<uint32_t> force_copy_stream_list_; | |||
| uint64_t mem_size_; | |||
| uint64_t weight_size_; | |||
| uint64_t var_size_; | |||
| uintptr_t logic_mem_base_; | |||
| uintptr_t logic_weight_base_; | |||
| uintptr_t logic_var_base_; | |||
| uint32_t stream_num_; | |||
| uint32_t batch_num_; | |||
| uint32_t event_num_; | |||
| int32_t priority_; | |||
| // Disable copy constructor and assignment operator | |||
| DavinciModel &operator=(const DavinciModel &) = delete; | |||
| DavinciModel(const DavinciModel &) = delete; | |||
| }; | |||
| } // namespace mindspore::ge::model_runner | |||
| #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_DAVINCI_MODEL_H_ | |||
| @@ -0,0 +1,59 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_MODEL_CONTEXT_H_ | |||
| #define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_MODEL_CONTEXT_H_ | |||
| #include <vector> | |||
| #include "runtime/rt_model.h" | |||
| namespace mindspore::ge::model_runner { | |||
| class ModelContext { | |||
| public: | |||
| ModelContext(uint32_t device_id, uint64_t session_id, int32_t priority, rtModel_t rt_model_handle, | |||
| rtStream_t rt_model_stream, const std::vector<rtStream_t> &stream_list, | |||
| const std::vector<rtLabel_t> &label_list, const std::vector<rtEvent_t> &event_list) | |||
| : device_id_(device_id), | |||
| session_id_(session_id), | |||
| priority_(priority), | |||
| rt_model_handle_(rt_model_handle), | |||
| rt_model_stream_(rt_model_stream), | |||
| stream_list_(stream_list), | |||
| label_list_(label_list), | |||
| event_list_(event_list) {} | |||
| ~ModelContext() {} | |||
| uint64_t device_id() const { return device_id_; } | |||
| uint64_t session_id() const { return session_id_; } | |||
| int32_t priority() const { return priority_; } | |||
| const rtModel_t &rt_model_handle() const { return rt_model_handle_; } | |||
| const rtStream_t &rt_model_stream() const { return rt_model_stream_; } | |||
| const std::vector<rtStream_t> &stream_list() const { return stream_list_; } | |||
| const std::vector<rtLabel_t> &label_list() const { return label_list_; } | |||
| const std::vector<rtEvent_t> &event_list() const { return event_list_; } | |||
| private: | |||
| uint32_t device_id_; | |||
| uint64_t session_id_; | |||
| int32_t priority_; | |||
| rtModel_t rt_model_handle_; | |||
| rtStream_t rt_model_stream_; | |||
| std::vector<rtStream_t> stream_list_; | |||
| std::vector<rtLabel_t> label_list_; | |||
| std::vector<rtEvent_t> event_list_; | |||
| }; | |||
| } // namespace mindspore::ge::model_runner | |||
| #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_MODEL_CONTEXT_H_ | |||
| @@ -0,0 +1,104 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "runtime/device/ascend/ge_runtime/model_runner.h" | |||
| #include "runtime/device/ascend/ge_runtime/runtime_model.h" | |||
| #include "runtime/device/ascend/ge_runtime/davinci_model.h" | |||
| #include "mindspore/core/utils/log_adapter.h" | |||
| namespace mindspore::ge::model_runner { | |||
| ModelRunner &ModelRunner::Instance() { | |||
| static ModelRunner instance; // Guaranteed to be destroyed. | |||
| return instance; | |||
| } | |||
| void ModelRunner::LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint32_t model_id, | |||
| const std::shared_ptr<DavinciModel> &davinci_model) { | |||
| std::shared_ptr<RuntimeModel> model = std::make_shared<RuntimeModel>(); | |||
| model->Load(device_id, session_id, davinci_model); | |||
| runtime_models_[model_id] = model; | |||
| } | |||
| void ModelRunner::DistributeTask(uint32_t model_id) { | |||
| auto model_iter = runtime_models_.find(model_id); | |||
| if (model_iter == runtime_models_.end()) { | |||
| MS_LOG(EXCEPTION) << "Model id " << model_id << " not found."; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(model_iter->second); | |||
| model_iter->second->DistributeTask(); | |||
| } | |||
| void ModelRunner::LoadModelComplete(uint32_t model_id) { | |||
| auto model_iter = runtime_models_.find(model_id); | |||
| if (model_iter == runtime_models_.end()) { | |||
| MS_LOG(EXCEPTION) << "Model id " << model_id << " not found."; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(model_iter->second); | |||
| model_iter->second->LoadComplete(); | |||
| } | |||
| const std::vector<uint32_t> &ModelRunner::GetTaskIdList(uint32_t model_id) const { | |||
| auto model_iter = runtime_models_.find(model_id); | |||
| if (model_iter == runtime_models_.end()) { | |||
| MS_LOG(EXCEPTION) << "Model id " << model_id << " not found."; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(model_iter->second); | |||
| return model_iter->second->GetTaskIdList(); | |||
| } | |||
| const std::vector<uint32_t> &ModelRunner::GetStreamIdList(uint32_t model_id) const { | |||
| auto model_iter = runtime_models_.find(model_id); | |||
| if (model_iter == runtime_models_.end()) { | |||
| MS_LOG(EXCEPTION) << "Model id " << model_id << " not found."; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(model_iter->second); | |||
| return model_iter->second->GetStreamIdList(); | |||
| } | |||
| const std::map<std::string, std::shared_ptr<RuntimeInfo>> &ModelRunner::GetRuntimeInfoMap(uint32_t model_id) const { | |||
| auto model_iter = runtime_models_.find(model_id); | |||
| if (model_iter == runtime_models_.end()) { | |||
| MS_LOG(EXCEPTION) << "Model id " << model_id << " not found."; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(model_iter->second); | |||
| return model_iter->second->GetRuntimeInfoMap(); | |||
| } | |||
| void *ModelRunner::GetModelHandle(uint32_t model_id) const { | |||
| auto model_iter = runtime_models_.find(model_id); | |||
| if (model_iter == runtime_models_.end()) { | |||
| MS_LOG(EXCEPTION) << "Model id " << model_id << " not found."; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(model_iter->second); | |||
| return model_iter->second->GetModelHandle(); | |||
| } | |||
| void ModelRunner::UnloadModel(uint32_t model_id) { | |||
| auto iter = runtime_models_.find(model_id); | |||
| if (iter != runtime_models_.end()) { | |||
| (void)runtime_models_.erase(iter); | |||
| } | |||
| } | |||
| void ModelRunner::RunModel(uint32_t model_id) { | |||
| auto model_iter = runtime_models_.find(model_id); | |||
| if (model_iter == runtime_models_.end()) { | |||
| MS_LOG(EXCEPTION) << "Model id " << model_id << " not found."; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(model_iter->second); | |||
| model_iter->second->Run(); | |||
| } | |||
| } // namespace mindspore::ge::model_runner | |||
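Taken together with the header below, this is the whole model lifecycle the Ascend runtime drives: load, distribute tasks, finish loading, run, and unload, with errors surfacing as exceptions. A hedged end-to-end sketch (the caller is assumed to have prepared `device_id`, `model_id`, and the `DavinciModel`):

```cpp
// Sketch of the lifecycle AscendKernelRuntime drives through ModelRunner.
#include <cstdint>
#include <memory>
#include "runtime/device/ascend/ge_runtime/davinci_model.h"
#include "runtime/device/ascend/ge_runtime/model_runner.h"

using mindspore::ge::model_runner::DavinciModel;
using mindspore::ge::model_runner::ModelRunner;

void RunOnce(uint32_t device_id, uint32_t model_id, const std::shared_ptr<DavinciModel> &davinci_model) {
  auto &runner = ModelRunner::Instance();
  runner.LoadDavinciModel(device_id, /*session_id=*/0, model_id, davinci_model);
  runner.DistributeTask(model_id);     // throws on failure
  runner.LoadModelComplete(model_id);  // rtModelLoadComplete
  runner.RunModel(model_id);           // rtModelExecute + rtStreamSynchronize
  runner.UnloadModel(model_id);        // drops the RuntimeModel, releasing rt resources
}
```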
| @@ -0,0 +1,60 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_MODEL_RUNNER_H_ | |||
| #define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_MODEL_RUNNER_H_ | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include <tuple> | |||
| #include <string> | |||
| #include "runtime/device/ascend/ge_runtime/davinci_model.h" | |||
| namespace mindspore::ge::model_runner { | |||
| class RuntimeModel; | |||
| using RuntimeInfo = std::tuple<uint32_t, uint32_t, void *>; | |||
| class ModelRunner { | |||
| public: | |||
| static ModelRunner &Instance(); | |||
| void LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint32_t model_id, | |||
| const std::shared_ptr<DavinciModel> &davinci_model); | |||
| void DistributeTask(uint32_t model_id); | |||
| void LoadModelComplete(uint32_t model_id); | |||
| const std::vector<uint32_t> &GetTaskIdList(uint32_t model_id) const; | |||
| const std::vector<uint32_t> &GetStreamIdList(uint32_t model_id) const; | |||
| const std::map<std::string, std::shared_ptr<RuntimeInfo>> &GetRuntimeInfoMap(uint32_t model_id) const; | |||
| void *GetModelHandle(uint32_t model_id) const; | |||
| void UnloadModel(uint32_t model_id); | |||
| void RunModel(uint32_t model_id); | |||
| private: | |||
| ModelRunner() = default; | |||
| ~ModelRunner() = default; | |||
| std::map<uint32_t, std::shared_ptr<RuntimeModel>> runtime_models_; | |||
| }; | |||
| } // namespace mindspore::ge::model_runner | |||
| #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_MODEL_RUNNER_H_ | |||
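`RuntimeInfo` is a `(task_id, stream_id, args pointer)` tuple keyed by task name; `DistributeTask` fills the map, and consumers (presumably the profiling and data-dump paths, which are outside this excerpt) read it back. A small sketch of unpacking an entry, assuming `model_id` is already loaded:

```cpp
// Sketch: read back the per-task runtime info collected during DistributeTask.
const auto &info_map = ModelRunner::Instance().GetRuntimeInfoMap(model_id);
for (const auto &entry : info_map) {
  const auto &runtime_info = *entry.second;        // std::tuple<uint32_t, uint32_t, void *>
  uint32_t task_id = std::get<0>(runtime_info);    // rt task id
  uint32_t stream_id = std::get<1>(runtime_info);  // rt stream id
  void *args = std::get<2>(runtime_info);          // kernel args / io address block
  (void)task_id; (void)stream_id; (void)args;
}
```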
| @@ -0,0 +1,292 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "runtime/device/ascend/ge_runtime/runtime_model.h" | |||
| #include <set> | |||
| #include "runtime/kernel.h" | |||
| #include "runtime/rt_model.h" | |||
| #include "graphengine/inc/external/runtime/rt_error_codes.h" | |||
| #include "runtime/device/ascend/ge_runtime/model_context.h" | |||
| #include "runtime/device/ascend/ge_runtime/task/task.h" | |||
| #include "runtime/device/ascend/ge_runtime/task/task_factory.h" | |||
| #include "mindspore/core/utils/log_adapter.h" | |||
| namespace mindspore::ge::model_runner { | |||
| RuntimeModel::~RuntimeModel() { | |||
| MS_LOG(INFO) << "RuntimeModel destructor start."; | |||
| // Unbind rtModel from all task related streams | |||
| RtModelUnbindStream(); | |||
| // Release tasks first, since HCCL tasks hold streams | |||
| task_list_.clear(); | |||
| // Release all task related streams | |||
| RtStreamDestory(); | |||
| // Release rtlabel resource | |||
| RtLabelDestory(); | |||
| // Release rtEvent resource | |||
| RtEventDestory(); | |||
| MS_LOG(INFO) << "Do RtModelDestroy"; | |||
| // Release all rt_model | |||
| RtModelDestory(); | |||
| } | |||
| void RuntimeModel::InitStream(const std::shared_ptr<DavinciModel> &davinci_model) { | |||
| MS_EXCEPTION_IF_NULL(davinci_model); | |||
| std::set<int64_t> wait_active_streams; | |||
| std::set<int64_t> force_copy_streams; | |||
| for (const auto &stream_id : davinci_model->GetWaitActiveStreams()) { | |||
| MS_LOG(INFO) << "Stream id " << stream_id << " is wait active stream."; | |||
| (void)wait_active_streams.insert(stream_id); | |||
| } | |||
| for (const auto &stream_id : davinci_model->GetForceCopyStreams()) { | |||
| MS_LOG(INFO) << "Stream id " << stream_id << " is force copy stream."; | |||
| (void)force_copy_streams.insert(stream_id); | |||
| } | |||
| MS_LOG(INFO) << "Total stream num " << davinci_model->GetStreamNum(); | |||
| for (uint32_t i = 0; i < davinci_model->GetStreamNum(); ++i) { | |||
| rtStream_t stream = nullptr; | |||
| uint32_t flag = (force_copy_streams.find(i) != force_copy_streams.end()) | |||
| ? (RT_STREAM_PERSISTENT | RT_STREAM_FORCE_COPY) | |||
| : (RT_STREAM_PERSISTENT); | |||
| rtError_t rt_ret = rtStreamCreateWithFlags(&stream, davinci_model->GetPriority(), flag); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rt api rtStreamCreate failed, ret: " << std::hex << rt_ret; | |||
| } | |||
| MS_LOG(INFO) << "rtStreamCreateWithFlags end."; | |||
| stream_list_.emplace_back(stream); | |||
| // Bind rt_model_handle_ to all task related streams | |||
| flag = (wait_active_streams.find(i) != wait_active_streams.end()) ? (static_cast<uint32_t>(RT_INVALID_FLAG)) | |||
| : (static_cast<uint32_t>(RT_HEAD_STREAM)); | |||
| rt_ret = rtModelBindStream(rt_model_handle_, stream, flag); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rt api rtModelBindStream failed, ret: " << std::hex << rt_ret; | |||
| } | |||
| MS_LOG(INFO) << "stream index: " << i << ", stream: " << std::hex << stream; | |||
| } | |||
| } | |||
| void RuntimeModel::InitEvent(uint32_t event_num) { | |||
| MS_LOG(INFO) << "Event number: " << event_num; | |||
| for (uint32_t i = 0; i < event_num; ++i) { | |||
| rtEvent_t rt_event; | |||
| rtError_t rt_ret = rtEventCreate(&rt_event); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rt api rtEventCreate failed, ret: " << std::hex << rt_ret; | |||
| } | |||
| event_list_.push_back(rt_event); | |||
| } | |||
| } | |||
| void RuntimeModel::InitLabel(const std::shared_ptr<DavinciModel> &davinci_model) { | |||
| MS_LOG(INFO) << "Label number: " << davinci_model->GetBatchNum(); | |||
| label_list_.resize(davinci_model->GetBatchNum()); | |||
| for (auto &task_info : davinci_model->GetTaskInfoList()) { | |||
| MS_EXCEPTION_IF_NULL(task_info); | |||
| if (task_info->type() != TaskInfoType::LABEL_SET) { | |||
| continue; | |||
| } | |||
| auto label_set_task_info = std::static_pointer_cast<LabelSetTaskInfo>(task_info); | |||
| if (label_set_task_info->stream_id() >= stream_list_.size()) { | |||
| MS_LOG(EXCEPTION) << "Invalid stream id " << label_set_task_info->stream_id() << " total stream num " | |||
| << stream_list_.size(); | |||
| } | |||
| rtLabel_t rt_label = nullptr; | |||
| rtError_t rt_ret = rtLabelCreateEx(&rt_label, stream_list_[label_set_task_info->stream_id()]); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rt api rtLabelCreate failed, ret: " << std::hex << rt_ret; | |||
| } | |||
| label_list_[label_set_task_info->label_id()] = rt_label; | |||
| } | |||
| } | |||
| void RuntimeModel::InitResource(const std::shared_ptr<DavinciModel> &davinci_model) { | |||
| MS_LOG(INFO) << "InitResource start"; | |||
| MS_EXCEPTION_IF_NULL(davinci_model); | |||
| rtError_t rt_ret = rtModelCreate(&rt_model_handle_, 0); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rt api rtModelCreate failed, ret: " << std::hex << rt_ret; | |||
| } | |||
| // Create rtStream for rt_model_handle_ | |||
| rt_ret = rtStreamCreate(&rt_model_stream_, davinci_model->GetPriority()); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rt api rtStreamCreate failed, ret: " << std::hex << rt_ret; | |||
| } | |||
| MS_LOG(INFO) << "rtStreamCreate end"; | |||
| InitStream(davinci_model); | |||
| InitEvent(davinci_model->GetEventNum()); | |||
| InitLabel(davinci_model); | |||
| MS_LOG(INFO) << "InitResource success"; | |||
| } | |||
| void RuntimeModel::GenerateTask(uint32_t device_id, uint64_t session_id, | |||
| const std::shared_ptr<DavinciModel> &davinci_model) { | |||
| MS_LOG(INFO) << "GenerateTask start."; | |||
| MS_EXCEPTION_IF_NULL(davinci_model); | |||
| auto task_infos = davinci_model->GetTaskInfoList(); | |||
| ModelContext model_context(device_id, session_id, davinci_model->GetPriority(), rt_model_handle_, rt_model_stream_, | |||
| stream_list_, label_list_, event_list_); | |||
| for (auto &task_info : task_infos) { | |||
| auto task = TaskFactory::GetInstance().Create(model_context, task_info); | |||
| task_list_.push_back(task); | |||
| } | |||
| MS_LOG(INFO) << "GenerateTask success."; | |||
| } | |||
| void RuntimeModel::LoadComplete() { | |||
| uint32_t task_id = 0; | |||
| uint32_t stream_id = 0; | |||
| auto rt_ret = rtModelGetTaskId(rt_model_handle_, &task_id, &stream_id); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rt api rtModelGetTaskId failed, ret: " << std::hex << rt_ret; | |||
| } | |||
| task_id_list_.push_back(task_id); | |||
| stream_id_list_.push_back(stream_id); | |||
| rt_ret = rtModelLoadComplete(rt_model_handle_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rt api rtModelLoadComplete failed, ret: " << std::hex << rt_ret; | |||
| } | |||
| } | |||
| void RuntimeModel::Load(uint32_t device_id, uint64_t session_id, const std::shared_ptr<DavinciModel> &davinci_model) { | |||
| InitResource(davinci_model); | |||
| GenerateTask(device_id, session_id, davinci_model); | |||
| } | |||
| void RuntimeModel::DistributeTask() { | |||
| MS_LOG(INFO) << "DistributeTask start."; | |||
| for (auto &task : task_list_) { | |||
| MS_EXCEPTION_IF_NULL(task); | |||
| task->Distribute(); | |||
| uint32_t task_id = 0; | |||
| uint32_t stream_id = 0; | |||
| rtError_t rt_ret = rtModelGetTaskId(rt_model_handle_, &task_id, &stream_id); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rt api rtModelGetTaskId failed, ret: " << std::hex << rt_ret; | |||
| } | |||
| task_id_list_.push_back(task_id); | |||
| stream_id_list_.push_back(stream_id); | |||
| if (task->Args() != nullptr) { | |||
| std::shared_ptr<RuntimeInfo> runtime_tuple = std::make_shared<RuntimeInfo>(task_id, stream_id, task->Args()); | |||
| auto emplace_ret = runtime_info_map_.emplace(task->task_name(), runtime_tuple); | |||
| if (!emplace_ret.second) { | |||
| MS_LOG(WARNING) << "Task name exist: " << task->task_name(); | |||
| } | |||
| } | |||
| } | |||
| if (task_list_.empty()) { | |||
| MS_LOG(EXCEPTION) << "Task list is empty"; | |||
| } | |||
| MS_LOG(INFO) << "DistributeTask success."; | |||
| } | |||
| void RuntimeModel::Run() { | |||
| MS_LOG(INFO) << "Davinci task run start."; | |||
| rtError_t ret = rtModelExecute(rt_model_handle_, rt_model_stream_, 0); | |||
| if (ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rt api rtModelLoadComplete failed, ret: " << std::hex << ret; | |||
| } | |||
| MS_LOG(INFO) << "Run rtModelExecute success, start to rtStreamSynchronize."; | |||
| ret = rtStreamSynchronize(rt_model_stream_); | |||
| if (ret != RT_ERROR_NONE) { | |||
| if (ret == ACL_ERROR_RT_END_OF_SEQUENCE) { | |||
| MS_LOG(INFO) << "Model stream ACL_ERROR_RT_END_OF_SEQUENCE signal received."; | |||
| return; | |||
| } | |||
| MS_LOG(EXCEPTION) << "Call rt api rtStreamSynchronize failed, ret: " << std::hex << ret; | |||
| } | |||
| MS_LOG(INFO) << "Davinci task run success."; | |||
| } | |||
| void RuntimeModel::RtModelUnbindStream() noexcept { | |||
| for (size_t i = 0; i < stream_list_.size(); i++) { | |||
| if (rtModelUnbindStream(rt_model_handle_, stream_list_[i]) != RT_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Unbind stream from model failed! Index: " << i; | |||
| return; | |||
| } | |||
| } | |||
| } | |||
| void RuntimeModel::RtStreamDestory() noexcept { | |||
| if (rtStreamDestroy(rt_model_stream_) != RT_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Destroy stream for rt_model failed!"; | |||
| return; | |||
| } | |||
| for (size_t i = 0; i < stream_list_.size(); i++) { | |||
| if (rtStreamDestroy(stream_list_[i]) != RT_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Destroy stream failed! Index: " << i; | |||
| return; | |||
| } | |||
| } | |||
| } | |||
| void RuntimeModel::RtLabelDestory() noexcept { | |||
| for (size_t i = 0; i < label_list_.size(); i++) { | |||
| if (label_list_[i] == nullptr) { | |||
| continue; | |||
| } | |||
| if (rtLabelDestroy(label_list_[i]) != RT_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Destroy label failed! Index: " << i; | |||
| return; | |||
| } | |||
| } | |||
| } | |||
| void RuntimeModel::RtModelDestory() noexcept { | |||
| rtError_t ret = rtModelDestroy(rt_model_handle_); | |||
| if (ret != RT_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Call rt api rtModelDestroy failed, ret: " << std::hex << ret; | |||
| return; | |||
| } | |||
| } | |||
| void RuntimeModel::RtEventDestory() noexcept { | |||
| for (size_t i = 0; i < event_list_.size(); i++) { | |||
| if (rtEventDestroy(event_list_[i]) != RT_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Destroy event failed! Index: " << i; | |||
| return; | |||
| } | |||
| } | |||
| } | |||
| const std::vector<uint32_t> &RuntimeModel::GetTaskIdList() const { return task_id_list_; } | |||
| const std::vector<uint32_t> &RuntimeModel::GetStreamIdList() const { return stream_id_list_; } | |||
| } // namespace mindspore::ge::model_runner | |||
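Every runtime call in this file uses the same check-and-throw idiom (`if (ret != RT_ERROR_NONE) MS_LOG(EXCEPTION) << ... << std::hex << ret;`), and the destructor tears resources down in reverse dependency order: unbind streams, drop tasks, then destroy streams, labels, events, and finally the model. Purely as an illustration of the idiom, not part of the PR, the check could be factored into a hypothetical helper:

```cpp
// Hypothetical helper showing the recurring error-check idiom in this file.
// MS_LOG(EXCEPTION) throws, so control never continues past a failed call.
#define RT_CHECK(expr)                                                                         \
  do {                                                                                         \
    rtError_t rt_check_ret = (expr);                                                           \
    if (rt_check_ret != RT_ERROR_NONE) {                                                       \
      MS_LOG(EXCEPTION) << "Call rt api " #expr " failed, ret: " << std::hex << rt_check_ret;  \
    }                                                                                          \
  } while (0)

// Usage equivalent to the explicit checks above:
// RT_CHECK(rtStreamCreate(&rt_model_stream_, davinci_model->GetPriority()));
```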
| @@ -0,0 +1,71 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_RUNTIME_MODEL_H_ | |||
| #define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_RUNTIME_MODEL_H_ | |||
| #include <map> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <tuple> | |||
| #include "runtime/base.h" | |||
| #include "runtime/rt_model.h" | |||
| #include "runtime/device/ascend/ge_runtime/davinci_model.h" | |||
| namespace mindspore::ge::model_runner { | |||
| using RuntimeInfo = std::tuple<uint32_t, uint32_t, void *>; | |||
| class Task; | |||
| class RuntimeModel { | |||
| public: | |||
| RuntimeModel() = default; | |||
| ~RuntimeModel(); | |||
| void Load(uint32_t device_id, uint64_t session_id, const std::shared_ptr<DavinciModel> &davinci_model); | |||
| void DistributeTask(); | |||
| void LoadComplete(); | |||
| const std::vector<uint32_t> &GetTaskIdList() const; | |||
| const std::vector<uint32_t> &GetStreamIdList() const; | |||
| const std::map<std::string, std::shared_ptr<RuntimeInfo>> &GetRuntimeInfoMap() const { return runtime_info_map_; } | |||
| rtModel_t GetModelHandle() const { return rt_model_handle_; } | |||
| void Run(); | |||
| private: | |||
| void InitResource(const std::shared_ptr<DavinciModel> &davinci_model); | |||
| void GenerateTask(uint32_t device_id, uint64_t session_id, const std::shared_ptr<DavinciModel> &davinci_model); | |||
| void InitStream(const std::shared_ptr<DavinciModel> &davinci_model); | |||
| void InitEvent(uint32_t event_num); | |||
| void InitLabel(const std::shared_ptr<DavinciModel> &davinci_model); | |||
| void RtModelUnbindStream() noexcept; | |||
| void RtStreamDestory() noexcept; | |||
| void RtModelDestory() noexcept; | |||
| void RtLabelDestory() noexcept; | |||
| void RtEventDestory() noexcept; | |||
| rtModel_t rt_model_handle_{}; | |||
| rtStream_t rt_model_stream_{}; | |||
| std::vector<rtStream_t> stream_list_{}; | |||
| std::vector<rtLabel_t> label_list_{}; | |||
| std::vector<rtEvent_t> event_list_{}; | |||
| std::vector<std::shared_ptr<Task>> task_list_{}; | |||
| std::vector<uint32_t> task_id_list_{}; | |||
| std::vector<uint32_t> stream_id_list_{}; | |||
| std::map<std::string, std::shared_ptr<RuntimeInfo>> runtime_info_map_; | |||
| }; | |||
| } // namespace mindspore::ge::model_runner | |||
| #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_RUNTIME_MODEL_H_ | |||
| @@ -0,0 +1,168 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "runtime/device/ascend/ge_runtime/task/aicpu_task.h" | |||
| #include <vector> | |||
| #include "runtime/mem.h" | |||
| #include "runtime/kernel.h" | |||
| #include "runtime/device/ascend/ge_runtime/task/task_factory.h" | |||
| #include "aicpu/common/aicpu_task_struct.h" | |||
| namespace mindspore::ge::model_runner { | |||
| AicpuTask::AicpuTask(const ModelContext &model_context, const std::shared_ptr<AicpuTaskInfo> &task_info) | |||
| : TaskRepeater<AicpuTaskInfo>(model_context, task_info), | |||
| task_info_(task_info), | |||
| stream_(nullptr), | |||
| args_(nullptr), | |||
| ext_info_(nullptr), | |||
| input_output_addr_(nullptr) { | |||
| MS_EXCEPTION_IF_NULL(task_info_); | |||
| auto stream_list = model_context.stream_list(); | |||
| if (stream_list.size() == 1) { | |||
| stream_ = stream_list[0]; | |||
| } else if (stream_list.size() > task_info_->stream_id()) { | |||
| stream_ = stream_list[task_info_->stream_id()]; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Index: " << task_info_->stream_id() << " >= stream_list.size(): " << stream_list.size(); | |||
| } | |||
| } | |||
| AicpuTask::~AicpuTask() { | |||
| ReleaseRtMem(&args_); | |||
| ReleaseRtMem(&ext_info_); | |||
| } | |||
| void AicpuTask::Distribute() { | |||
| MS_LOG(INFO) << "InitAicpuTask start."; | |||
| std::vector<void *> io_addrs; | |||
| io_addrs.insert(io_addrs.end(), task_info_->input_data_addrs().begin(), task_info_->input_data_addrs().end()); | |||
| io_addrs.insert(io_addrs.end(), task_info_->output_data_addrs().begin(), task_info_->output_data_addrs().end()); | |||
| auto io_addrs_num = static_cast<uint32_t>(io_addrs.size()); | |||
| auto io_addrs_size = static_cast<uint32_t>(io_addrs_num * sizeof(void *)); | |||
| constexpr uint32_t io_addr_offset = sizeof(aicpu::AicpuParamHead); | |||
| uint32_t node_def_len_offset = io_addr_offset + io_addrs_size; | |||
| uint32_t node_def_addr_offset = node_def_len_offset + sizeof(uint32_t); | |||
| uint32_t args_size = sizeof(aicpu::AicpuParamHead) + io_addrs_size + | |||
| static_cast<uint32_t>(task_info_->node_def().size()) + sizeof(uint32_t); | |||
| // Malloc device memory for args | |||
| rtError_t rt_ret = rtMalloc(&args_, args_size, RT_MEMORY_HBM); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rt api rtMalloc failed, ret: " << std::hex << rt_ret; | |||
| } | |||
| SetAicpuParamHead(args_size, io_addrs_num); | |||
| SetInputOutputAddrs(io_addrs, io_addr_offset); | |||
| SetNodeDef(node_def_len_offset, node_def_addr_offset); | |||
| // for data dump | |||
| input_output_addr_ = reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + io_addr_offset); | |||
| auto dump_flag = task_info_->dump_flag() ? RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT; | |||
| MS_LOG(INFO) << "Distribute AicpuTask start, args_size = " << args_size << ", io_addrs_num =" << io_addrs_num | |||
| << ", so_name = " << task_info_->so_name() << ", kernel_name = " << task_info_->kernel_name() | |||
| << ", dump_flag = " << dump_flag; | |||
| rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast<const void *>(task_info_->so_name().data()), | |||
| reinterpret_cast<const void *>(task_info_->kernel_name().data()), 1, args_, | |||
| args_size, nullptr, stream_, dump_flag); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rt api rtCpuKernelLaunchWithFlag failed, ret: " << std::hex << rt_ret; | |||
| } | |||
| MS_LOG(INFO) << "Distribute AicpuTask end."; | |||
| } | |||
| void AicpuTask::ReleaseRtMem(void **ptr) noexcept { | |||
| if (ptr == nullptr || *ptr == nullptr) { | |||
| return; | |||
| } | |||
| rtError_t rt_ret = rtFree(*ptr); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| return; | |||
| } | |||
| *ptr = nullptr; | |||
| } | |||
| void AicpuTask::SetAicpuParamHead(uint32_t args_size, uint32_t io_addrs_num) { | |||
| aicpu::AicpuParamHead aicpu_param_head; | |||
| aicpu_param_head.length = args_size; | |||
| aicpu_param_head.ioAddrNum = io_addrs_num; | |||
| const auto &ext_info = task_info_->ext_info(); | |||
| uint32_t ext_size = ext_info.size(); | |||
| if (ext_info.empty()) { | |||
| aicpu_param_head.extInfoLength = 0; | |||
| aicpu_param_head.extInfoAddr = 0; | |||
| } else { | |||
| rtError_t flag = rtMalloc(&ext_info_, ext_size, RT_MEMORY_HBM); | |||
| if (flag != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rt api rtMalloc failed, ret: " << std::hex << flag; | |||
| } | |||
| flag = rtMemcpy(ext_info_, ext_size, const_cast<void *>(reinterpret_cast<const void *>(ext_info.data())), ext_size, | |||
| RT_MEMCPY_HOST_TO_DEVICE); | |||
| if (flag != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rt api rtMemcpy failed, ret: " << std::hex << flag; | |||
| } | |||
| MS_LOG(INFO) << "ext info size: " << ext_size; | |||
| aicpu_param_head.extInfoLength = ext_size; | |||
| aicpu_param_head.extInfoAddr = reinterpret_cast<uintptr_t>(ext_info_); | |||
| } | |||
| // Memcpy AicpuParamHead | |||
| auto rt_ret = rtMemcpy(args_, sizeof(aicpu::AicpuParamHead), reinterpret_cast<void *>(&aicpu_param_head), | |||
| sizeof(aicpu::AicpuParamHead), RT_MEMCPY_HOST_TO_DEVICE); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rt api rtMemcpy failed, ret: " << std::hex << rt_ret; | |||
| } | |||
| } | |||
| void AicpuTask::SetInputOutputAddrs(const std::vector<void *> &io_addrs, uint32_t io_addr_offset) { | |||
| // Memcpy io addrs | |||
| if (!io_addrs.empty()) { | |||
| auto rt_ret = rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + io_addr_offset), | |||
| static_cast<uint32_t>(io_addrs.size()) * sizeof(void *), io_addrs.data(), | |||
| static_cast<uint32_t>(io_addrs.size()) * sizeof(void *), RT_MEMCPY_HOST_TO_DEVICE); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rt api rtMemcpy failed, ret: " << std::hex << rt_ret; | |||
| } | |||
| } | |||
| } | |||
| void AicpuTask::SetNodeDef(uint32_t node_def_len_offset, uint32_t node_def_addr_offset) { | |||
| // Memcpy node def length | |||
| auto size = task_info_->node_def().size(); | |||
| auto rt_ret = | |||
| rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + node_def_len_offset), sizeof(uint32_t), | |||
| reinterpret_cast<const void *>(&size), sizeof(uint32_t), RT_MEMCPY_HOST_TO_DEVICE); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rt api rtMemcpy failed, ret: " << std::hex << rt_ret; | |||
| } | |||
| // Memcpy node def | |||
| rt_ret = rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + node_def_addr_offset), | |||
| task_info_->node_def().size(), reinterpret_cast<const void *>(task_info_->node_def().data()), | |||
| task_info_->node_def().size(), RT_MEMCPY_HOST_TO_DEVICE); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rt api rtMemcpy failed, ret: " << std::hex << rt_ret; | |||
| } | |||
| } | |||
| REGISTER_TASK(TaskInfoType::AICPU, AicpuTask, AicpuTaskInfo); | |||
| } // namespace mindspore::ge::model_runner | |||
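`AicpuTask::Distribute` packs one contiguous device buffer: the `aicpu::AicpuParamHead`, the flattened input/output addresses, a `uint32_t` holding the serialized node_def length, and then the node_def bytes. A worked example of the offsets, assuming two inputs, one output, a 40-byte node_def, and `sizeof(void *) == 8` (all example values):

```cpp
// Worked offsets for the args buffer built in AicpuTask::Distribute (example sizes only).
constexpr uint32_t kIoAddrNum = 3;                              // 2 inputs + 1 output
constexpr uint32_t kIoAddrsSize = kIoAddrNum * sizeof(void *);  // 24 bytes on a 64-bit build
constexpr uint32_t kNodeDefSize = 40;                           // example serialized NodeDef size

const uint32_t io_addr_offset = sizeof(aicpu::AicpuParamHead);  // head | io addrs | len | node_def
const uint32_t node_def_len_offset = io_addr_offset + kIoAddrsSize;
const uint32_t node_def_addr_offset = node_def_len_offset + sizeof(uint32_t);
const uint32_t args_size =
    sizeof(aicpu::AicpuParamHead) + kIoAddrsSize + sizeof(uint32_t) + kNodeDefSize;
```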
| @@ -0,0 +1,51 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_AICPU_TASK_H_ | |||
| #define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_AICPU_TASK_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "runtime/device/ascend/ge_runtime/task/task.h" | |||
| namespace mindspore::ge::model_runner { | |||
| class AicpuTask : public TaskRepeater<AicpuTaskInfo> { | |||
| public: | |||
| AicpuTask(const ModelContext &model_context, const std::shared_ptr<AicpuTaskInfo> &task_info); | |||
| ~AicpuTask() override; | |||
| void Distribute() override; | |||
| void *Args() override { return input_output_addr_; } | |||
| std::string task_name() const override { return task_info_->op_name(); } | |||
| private: | |||
| static void ReleaseRtMem(void **ptr) noexcept; | |||
| void SetAicpuParamHead(uint32_t args_size, uint32_t io_addrs_num); | |||
| void SetInputOutputAddrs(const std::vector<void *> &io_addrs, uint32_t io_addr_offset); | |||
| void SetNodeDef(uint32_t node_def_len_offset, uint32_t node_def_addr_offset); | |||
| std::shared_ptr<AicpuTaskInfo> task_info_; | |||
| void *stream_; | |||
| void *args_; | |||
| void *ext_info_; | |||
| void *input_output_addr_; | |||
| }; | |||
| } // namespace mindspore::ge::model_runner | |||
| #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_AICPU_TASK_H_ | |||
| @@ -0,0 +1,54 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "runtime/device/ascend/ge_runtime/task/event_record_task.h" | |||
| #include "runtime/kernel.h" | |||
| #include "runtime/device/ascend/ge_runtime/task/task_factory.h" | |||
| namespace mindspore::ge::model_runner { | |||
| EventRecordTask::EventRecordTask(const ModelContext &model_context, | |||
| const std::shared_ptr<EventRecordTaskInfo> &task_info) | |||
| : TaskRepeater<EventRecordTaskInfo>(model_context, task_info), | |||
| task_info_(task_info), | |||
| stream_(nullptr), | |||
| event_(nullptr) { | |||
| MS_EXCEPTION_IF_NULL(task_info_); | |||
| auto stream_list = model_context.stream_list(); | |||
| auto event_list = model_context.event_list(); | |||
| uint32_t stream_id = task_info_->stream_id(); | |||
| uint32_t event_id = task_info_->event_id(); | |||
| if (stream_id >= stream_list.size() || event_id >= event_list.size()) { | |||
| MS_LOG(EXCEPTION) << "stream_list size: " << stream_list.size() << ", stream_id: " << stream_id | |||
| << ", event_list size: " << event_list.size() << ", event_id: " << event_id; | |||
| } | |||
| stream_ = stream_list[stream_id]; | |||
| event_ = event_list[event_id]; | |||
| } | |||
| EventRecordTask::~EventRecordTask() {} | |||
| void EventRecordTask::Distribute() { | |||
| MS_LOG(INFO) << "EventRecordTask Distribute start, stream: " << stream_ << ", event: " << event_ | |||
| << ", stream_id: " << task_info_->stream_id() << ", event_id: " << task_info_->event_id(); | |||
| rtError_t rt_ret = rtEventRecord(event_, stream_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rt api rtEventRecord failed, ret: " << std::hex << rt_ret; | |||
| } | |||
| MS_LOG(INFO) << "Distribute end."; | |||
| } | |||
| REGISTER_TASK(TaskInfoType::EVENT_RECORD, EventRecordTask, EventRecordTaskInfo); | |||
| } // namespace mindspore::ge::model_runner | |||
| @@ -0,0 +1,38 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_EVENT_RECORD_TASK_H_ | |||
| #define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_EVENT_RECORD_TASK_H_ | |||
| #include <memory> | |||
| #include "runtime/device/ascend/ge_runtime/task/task.h" | |||
| namespace mindspore::ge::model_runner { | |||
| class EventRecordTask : public TaskRepeater<EventRecordTaskInfo> { | |||
| public: | |||
| EventRecordTask(const ModelContext &model_context, const std::shared_ptr<EventRecordTaskInfo> &task_info); | |||
| ~EventRecordTask() override; | |||
| void Distribute() override; | |||
| private: | |||
| std::shared_ptr<EventRecordTaskInfo> task_info_; | |||
| rtStream_t stream_; | |||
| rtEvent_t event_; | |||
| }; | |||
| } // namespace mindspore::ge::model_runner | |||
| #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_EVENT_RECORD_TASK_H_ | |||
| @@ -0,0 +1,59 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "runtime/device/ascend/ge_runtime/task/event_wait_task.h" | |||
| #include "runtime/kernel.h" | |||
| #include "runtime/device/ascend/ge_runtime/task/task_factory.h" | |||
| namespace mindspore::ge::model_runner { | |||
| EventWaitTask::EventWaitTask(const ModelContext &model_context, const std::shared_ptr<EventWaitTaskInfo> &task_info) | |||
| : TaskRepeater<EventWaitTaskInfo>(model_context, task_info), | |||
| task_info_(task_info), | |||
| stream_(nullptr), | |||
| event_(nullptr) { | |||
| MS_EXCEPTION_IF_NULL(task_info_); | |||
| auto stream_list = model_context.stream_list(); | |||
| auto event_list = model_context.event_list(); | |||
| uint32_t stream_id = task_info_->stream_id(); | |||
| uint32_t event_id = task_info_->event_id(); | |||
| if (stream_id >= stream_list.size() || event_id >= event_list.size()) { | |||
| MS_LOG(EXCEPTION) << "stream_list size: " << stream_list.size() << ", stream_id: " << stream_id | |||
| << ", event_list size: " << event_list.size() << ", event_id: " << event_id; | |||
| } | |||
| stream_ = stream_list[stream_id]; | |||
| event_ = event_list[event_id]; | |||
| } | |||
| EventWaitTask::~EventWaitTask() {} | |||
| void EventWaitTask::Distribute() { | |||
| MS_LOG(INFO) << "EventWaitTask Distribute start, stream: " << stream_ << ", event: " << event_ | |||
| << ", stream_id: " << task_info_->stream_id() << ", event_id: " << task_info_->event_id(); | |||
| rtError_t rt_ret = rtStreamWaitEvent(stream_, event_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rt api rtStreamWaitEvent failed, ret: " << std::hex << rt_ret; | |||
| } | |||
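| // Reset the event after the wait completes so it can be recorded again on the next iteration. | |||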
| rt_ret = rtEventReset(event_, stream_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rt api rtEventReset failed, ret: " << std::hex << rt_ret; | |||
| } | |||
| MS_LOG(INFO) << "Distribute end."; | |||
| } | |||
| REGISTER_TASK(TaskInfoType::EVENT_WAIT, EventWaitTask, EventWaitTaskInfo); | |||
| } // namespace mindspore::ge::model_runner | |||
| @@ -0,0 +1,38 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_EVENT_WAIT_TASK_H_ | |||
| #define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_EVENT_WAIT_TASK_H_ | |||
| #include <memory> | |||
| #include "runtime/device/ascend/ge_runtime/task/task.h" | |||
| namespace mindspore::ge::model_runner { | |||
| class EventWaitTask : public TaskRepeater<EventWaitTaskInfo> { | |||
| public: | |||
| EventWaitTask(const ModelContext &model_context, const std::shared_ptr<EventWaitTaskInfo> &task_info); | |||
| ~EventWaitTask() override; | |||
| void Distribute() override; | |||
| private: | |||
| std::shared_ptr<EventWaitTaskInfo> task_info_; | |||
| rtStream_t stream_; | |||
| rtEvent_t event_; | |||
| }; | |||
| } // namespace mindspore::ge::model_runner | |||
| #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_EVENT_WAIT_TASK_H_ | |||
| @@ -0,0 +1,221 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "runtime/device/ascend/ge_runtime/task/hccl_task.h" | |||
| #include <algorithm> | |||
| #include "runtime/device/ascend/ge_runtime/task/task_factory.h" | |||
| #include "common/opskernel/ops_kernel_info_store.h" | |||
| #include "common/opskernel/ge_task_info.h" | |||
| namespace mindspore::ge::model_runner { | |||
| std::map<rtModel_t, std::map<uint32_t, std::vector<std::weak_ptr<HcclTask::StreamGuard>>>> | |||
| HcclTask::model_stream_mapping_; | |||
| std::mutex HcclTask::model_stream_mapping_mutex_; | |||
| HcclTask::HcclTask(const ModelContext &model_context, const std::shared_ptr<HcclTaskInfo> &task_info) | |||
| : TaskRepeater<HcclTaskInfo>(model_context, task_info), | |||
| task_info_(task_info), | |||
| stream_(nullptr), | |||
| workspace_mem_(nullptr), | |||
| rt_model_handle_(nullptr), | |||
| priority_(0), | |||
| secondary_stream_list_() { | |||
| MS_EXCEPTION_IF_NULL(task_info_); | |||
| priority_ = model_context.priority(); | |||
| rt_model_handle_ = model_context.rt_model_handle(); | |||
| auto stream_list = model_context.stream_list(); | |||
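| // Reuse the only stream when the model has a single stream; otherwise index the stream list by the task's stream id. | |||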
| if (stream_list.size() == 1) { | |||
| stream_ = stream_list[0]; | |||
| } else if (stream_list.size() > task_info_->stream_id()) { | |||
| stream_ = stream_list[task_info_->stream_id()]; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Index: " << task_info_->stream_id() << " >= stream_list.size(): " << stream_list.size(); | |||
| } | |||
| } | |||
| HcclTask::~HcclTask() {} | |||
| void HcclTask::Distribute() { | |||
| // Ops kernel info store | |||
| // Get privateDef and opsKernelStorePtr | |||
| MS_LOG(INFO) << "Distribute hccl task start."; | |||
| void *ops_kernel_store = task_info_->ops_kernel_store(); | |||
| ::ge::OpsKernelInfoStore *ops_kernel_info_store = reinterpret_cast<::ge::OpsKernelInfoStore *>(ops_kernel_store); | |||
| MS_EXCEPTION_IF_NULL(ops_kernel_info_store); | |||
| char *private_def = reinterpret_cast<char *>(const_cast<char unsigned *>(task_info_->private_def().data())); | |||
| auto private_def_len = static_cast<uint32_t>(task_info_->private_def().size()); | |||
| MS_LOG(INFO) << "The first address of the custom info, privateDef= " << private_def; | |||
| SetSecondaryStream(); | |||
| if (task_info_->workspace_size() > 0) { | |||
| workspace_mem_ = task_info_->workspace_addr(); | |||
| } | |||
| ::ge::GETaskInfo ge_task; | |||
| ge_task.id = 0; | |||
| ge_task.type = static_cast<uint16_t>(RT_MODEL_TASK_HCCL); | |||
| ge_task.stream = stream_; | |||
| ge_task.kernelHcclInfo = std::vector<::ge::GETaskKernelHcclInfo>(1); | |||
| ge_task.kernelHcclInfo[0].hccl_type = task_info_->hccl_type(); | |||
| ge_task.kernelHcclInfo[0].inputDataAddr = task_info_->input_data_addr(); | |||
| ge_task.kernelHcclInfo[0].outputDataAddr = task_info_->output_data_addr(); | |||
| ge_task.kernelHcclInfo[0].workSpaceAddr = workspace_mem_; | |||
| ge_task.kernelHcclInfo[0].workSpaceMemSize = task_info_->workspace_size(); | |||
| ge_task.kernelHcclInfo[0].count = task_info_->count(); | |||
| ge_task.kernelHcclInfo[0].dataType = static_cast<int32_t>(task_info_->data_type()); | |||
| ge_task.kernelHcclInfo[0].opType = static_cast<int32_t>(task_info_->op_type()); | |||
| ge_task.kernelHcclInfo[0].rootId = task_info_->root_id(); | |||
| std::vector<rtStream_t> secondary_stream_list; | |||
| std::transform(secondary_stream_list_.begin(), secondary_stream_list_.end(), | |||
| std::back_inserter(secondary_stream_list), | |||
| [](const std::shared_ptr<StreamGuard> &stream) -> rtStream_t { return stream->GetStream(); }); | |||
| ge_task.kernelHcclInfo[0].hcclStreamList = secondary_stream_list; | |||
| ge_task.privateDef = private_def; | |||
| ge_task.privateDefLen = private_def_len; | |||
| ge_task.opsKernelStorePtr = ops_kernel_store; | |||
| MS_LOG(INFO) << "Begin to call function LoadTask in hccl."; | |||
| auto result = ops_kernel_info_store->LoadTask(ge_task); | |||
| // tagHcclResult::HCCL_SUCCESS is 0 | |||
| if (result != 0) { | |||
| MS_LOG(EXCEPTION) << "davinci_model : load task fail, return ret: " << result; | |||
| } | |||
| MS_LOG(INFO) << "Call function LoadTask end."; | |||
| } | |||
| void HcclTask::SetSecondaryStream() { | |||
| const uint32_t master_stream_id = task_info_->stream_id(); | |||
| const int64_t hccl_secondary_stream_num = task_info_->hccl_stream_num(); | |||
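| // Secondary streams are cached per (model, primary stream id) in model_stream_mapping_ so tasks on the same primary stream can reuse them. | |||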
| std::lock_guard<std::mutex> lock(model_stream_mapping_mutex_); | |||
| // no model, create all secondary stream | |||
| auto model_iter = model_stream_mapping_.find(rt_model_handle_); | |||
| if (model_iter == model_stream_mapping_.end()) { | |||
| MS_LOG(INFO) << "Need to create map for rt_model_handle_: " << rt_model_handle_ << " with new mainstream " | |||
| << master_stream_id; | |||
| CreateStream(hccl_secondary_stream_num, master_stream_id); | |||
| MS_LOG(INFO) << "Initialize hccl secondary stream success, hccl_secondary_stream_num=" << hccl_secondary_stream_num; | |||
| return; | |||
| } | |||
| // has model, but no secondary stream before, create all secondary stream | |||
| auto &master_secondary_stream_map = model_iter->second; | |||
| auto iter = master_secondary_stream_map.find(master_stream_id); | |||
| if (iter == master_secondary_stream_map.end()) { | |||
| MS_LOG(INFO) << "Need to create secondary stream for " << task_info_->op_name() << " with new mainstream " | |||
| << master_stream_id; | |||
| CreateStream(hccl_secondary_stream_num, master_stream_id); | |||
| MS_LOG(INFO) << "Initialize hccl secondary stream success, hccl_secondary_stream_num=" << hccl_secondary_stream_num; | |||
| return; | |||
| } | |||
| // has model, has secondary stream, but number is not enough to be reuse | |||
| std::vector<std::weak_ptr<StreamGuard>> &secondary_stream_vec = iter->second; | |||
| if (static_cast<size_t>(hccl_secondary_stream_num) > secondary_stream_vec.size()) { | |||
| size_t created_stream_num = secondary_stream_vec.size(); | |||
| auto need_to_create_num = hccl_secondary_stream_num - created_stream_num; | |||
| MS_LOG(INFO) << "Need to reuse " << secondary_stream_vec.size() << " secondary stream and create " | |||
| << need_to_create_num << " new secondary stream."; | |||
| for (size_t i = 0; i < secondary_stream_vec.size(); ++i) { | |||
| secondary_stream_list_.push_back(GetSecondaryStream(&secondary_stream_vec, i)); | |||
| } | |||
| CreateStream(need_to_create_num, master_stream_id); | |||
| MS_LOG(INFO) << "Initialize hccl secondary stream success, hccl_secondary_stream_num=" << hccl_secondary_stream_num; | |||
| return; | |||
| } | |||
| // all can be reuse | |||
| MS_LOG(INFO) << "Number of secondary stream " << hccl_secondary_stream_num << " is enough to be reused."; | |||
| for (int64_t i = 0; i < hccl_secondary_stream_num; ++i) { | |||
| secondary_stream_list_.push_back(GetSecondaryStream(&secondary_stream_vec, i)); | |||
| } | |||
| MS_LOG(INFO) << "Initialize hccl secondary stream success, hccl_secondary_stream_num = " << hccl_secondary_stream_num; | |||
| } | |||
| void HcclTask::CreateStream(int64_t stream_num, int64_t master_stream_id) { | |||
| MS_LOG(INFO) << "Start to create " << stream_num << " hccl secondary stream."; | |||
| for (int64_t i = 0; i < stream_num; ++i) { | |||
| rtStream_t stream = nullptr; | |||
| CreateStream(rt_model_handle_, &stream); | |||
| auto shared_stream = std::make_shared<StreamGuard>(rt_model_handle_, stream); | |||
| SaveHcclSecondaryStream(master_stream_id, shared_stream); | |||
| secondary_stream_list_.push_back(shared_stream); | |||
| } | |||
| MS_LOG(INFO) << "CreateStream success."; | |||
| } | |||
| void HcclTask::CreateStream(rtModel_t model, rtStream_t *stream) const { | |||
| MS_EXCEPTION_IF_NULL(stream); | |||
| rtError_t rt_ret = rtStreamCreateWithFlags(stream, priority_, RT_STREAM_PERSISTENT | RT_STREAM_FORCE_COPY); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rt api rtEventRecord failed, ret: " << std::hex << rt_ret; | |||
| } | |||
| // Create secondary stream, inactive by default, activated by hccl | |||
| rt_ret = rtModelBindStream(model, *stream, RT_MODEL_WAIT_ACTIVE_STREAM); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rt api rtEventRecord failed, ret: " << std::hex << rt_ret; | |||
| } | |||
| } | |||
| void HcclTask::SaveHcclSecondaryStream(int64_t master_stream_id, const std::shared_ptr<StreamGuard> &stream) { | |||
| if (model_stream_mapping_.find(rt_model_handle_) == model_stream_mapping_.end()) { | |||
| model_stream_mapping_.emplace(rt_model_handle_, std::map<uint32_t, std::vector<std::weak_ptr<StreamGuard>>>()); | |||
| } | |||
| std::map<uint32_t, std::vector<std::weak_ptr<StreamGuard>>> &master_secondary_stream_map = | |||
| model_stream_mapping_.at(rt_model_handle_); | |||
| master_secondary_stream_map[master_stream_id].emplace_back(stream); | |||
| } | |||
| HcclTask::StreamGuard::~StreamGuard() { | |||
| rtError_t rt_ret = rtModelUnbindStream(model_, stream_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Call rt api rtModelUnbindStream failed, ret: " << std::hex << rt_ret; | |||
| return; | |||
| } | |||
| rt_ret = rtStreamDestroy(stream_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Call rt api rtStreamDestroy failed, ret: " << std::hex << rt_ret; | |||
| return; | |||
| } | |||
| } | |||
| std::shared_ptr<HcclTask::StreamGuard> HcclTask::GetSecondaryStream( | |||
| std::vector<std::weak_ptr<StreamGuard>> *secondary_streams, size_t index) { | |||
| MS_EXCEPTION_IF_NULL(secondary_streams); | |||
| if (index >= secondary_streams->size()) { | |||
| MS_LOG(EXCEPTION) << "Invalid stream index " << index << ", secondary streams size " << secondary_streams->size(); | |||
| } | |||
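| // If the cached stream has already been released, create a new one and refresh the cache entry. | |||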
| auto stream = secondary_streams->at(index).lock(); | |||
| if (stream == nullptr) { | |||
| rtStream_t new_stream = nullptr; | |||
| CreateStream(rt_model_handle_, &new_stream); | |||
| stream = std::make_shared<HcclTask::StreamGuard>(rt_model_handle_, new_stream); | |||
| (*secondary_streams)[index] = stream; | |||
| } | |||
| return stream; | |||
| } | |||
| REGISTER_TASK(TaskInfoType::HCCL, HcclTask, HcclTaskInfo); | |||
| } // namespace mindspore::ge::model_runner | |||
| @@ -0,0 +1,68 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_HCCL_TASK_H_ | |||
| #define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_HCCL_TASK_H_ | |||
| #include <memory> | |||
| #include <set> | |||
| #include <map> | |||
| #include <vector> | |||
| #include <mutex> | |||
| #include "runtime/device/ascend/ge_runtime/task/task.h" | |||
| namespace mindspore::ge::model_runner { | |||
| class HcclTask : public TaskRepeater<HcclTaskInfo> { | |||
| public: | |||
| HcclTask(const ModelContext &model_context, const std::shared_ptr<HcclTaskInfo> &task_info); | |||
| ~HcclTask() override; | |||
| void Distribute() override; | |||
| private: | |||
| class StreamGuard; | |||
| void SetSecondaryStream(); | |||
| void CreateStream(int64_t stream_num, int64_t master_stream_id); | |||
| void CreateStream(rtModel_t model, rtStream_t *stream) const; | |||
| void SaveHcclSecondaryStream(int64_t master_stream_id, const std::shared_ptr<StreamGuard> &stream); | |||
| std::shared_ptr<StreamGuard> GetSecondaryStream(std::vector<std::weak_ptr<StreamGuard>> *secondary_streams, | |||
| size_t index); | |||
| std::shared_ptr<HcclTaskInfo> task_info_; | |||
| void *stream_; | |||
| void *workspace_mem_; | |||
| rtModel_t rt_model_handle_; | |||
| int32_t priority_; | |||
| std::vector<std::shared_ptr<StreamGuard>> secondary_stream_list_; | |||
| // map<key: model pointer, value: map<key: primary stream id, value: vector<secondary stream pointer>>> | |||
| static std::map<rtModel_t, std::map<uint32_t, std::vector<std::weak_ptr<StreamGuard>>>> model_stream_mapping_; | |||
| static std::mutex model_stream_mapping_mutex_; | |||
| }; | |||
| class HcclTask::StreamGuard { | |||
| public: | |||
| StreamGuard(rtModel_t model, rtStream_t stream) : model_(model), stream_(stream) {} | |||
| ~StreamGuard(); | |||
| rtStream_t GetStream() const { return stream_; } | |||
| private: | |||
| rtModel_t model_; | |||
| rtStream_t stream_; | |||
| }; | |||
| } // namespace mindspore::ge::model_runner | |||
| #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_HCCL_TASK_H_ | |||
| @@ -0,0 +1,83 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "runtime/device/ascend/ge_runtime/task/label_goto_task.h" | |||
| #include "runtime/mem.h" | |||
| #include "runtime/device/ascend/ge_runtime/task/task_factory.h" | |||
| namespace mindspore::ge::model_runner { | |||
| LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info) | |||
| : TaskRepeater<LabelGotoTaskInfo>(model_context, task_info), | |||
| task_info_(task_info), | |||
| stream_(nullptr), | |||
| index_value_(nullptr) { | |||
| MS_EXCEPTION_IF_NULL(task_info_); | |||
| auto stream_list = model_context.stream_list(); | |||
| auto label_list = model_context.label_list(); | |||
| rt_model_handle_ = model_context.rt_model_handle(); | |||
| uint32_t stream_id = task_info_->stream_id(); | |||
| label_id_ = task_info_->label_id(); | |||
| MS_LOG(INFO) << "Stream list size: " << stream_list.size() << ", stream id: " << stream_id; | |||
| MS_LOG(INFO) << "Label list size: " << label_list.size() << ", label id: " << label_id_; | |||
| if (stream_id >= stream_list.size() || label_id_ >= label_list.size()) { | |||
| MS_LOG(EXCEPTION) << "Stream/Label id invalid."; | |||
| } | |||
| stream_ = stream_list[stream_id]; | |||
| label_manager_ = LabelManager::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(label_manager_); | |||
| label_info_ = label_manager_->GetLabelInfo(rt_model_handle_, {label_id_}, label_list); | |||
| MS_EXCEPTION_IF_NULL(label_info_); | |||
| } | |||
| LabelGotoTask::~LabelGotoTask() { | |||
| if (index_value_ != nullptr) { | |||
| rtError_t rt_ret = rtFree(index_value_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Call rtFree index_value_ failed, ret: " << std::hex << rt_ret; | |||
| } | |||
| index_value_ = nullptr; | |||
| } | |||
| } | |||
| void LabelGotoTask::Distribute() { | |||
| MS_LOG(INFO) << "LabelGotoTask Distribute start."; | |||
| MS_EXCEPTION_IF_NULL(stream_); | |||
| MS_EXCEPTION_IF_NULL(label_info_); | |||
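| // Unconditional goto: a device-side index fixed to 0 selects the single entry of the label table passed to rtLabelSwitchByIndex. | |||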
| if (index_value_ == nullptr) { | |||
| rtError_t rt_ret = rtMalloc(&index_value_, sizeof(uint64_t), RT_MEMORY_HBM); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rt api rtMalloc failed, ret: " << std::hex << rt_ret; | |||
| } | |||
| uint64_t index = 0; | |||
| rt_ret = rtMemcpy(index_value_, sizeof(uint64_t), &index, sizeof(index), RT_MEMCPY_HOST_TO_DEVICE); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rt api rtMemcpy failed, ret: " << std::hex << rt_ret; | |||
| } | |||
| } | |||
| void *label_info = label_info_->GetLabelInfo(); | |||
| rtError_t rt_ret = rtLabelSwitchByIndex(index_value_, 1, label_info, stream_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rt api rtLabelSwitchByIndex failed, ret: " << std::hex << rt_ret; | |||
| } | |||
| MS_LOG(INFO) << "DistributeTask end."; | |||
| } | |||
| REGISTER_TASK(TaskInfoType::LABEL_GOTO, LabelGotoTask, LabelGotoTaskInfo); | |||
| } // namespace mindspore::ge::model_runner | |||
| @@ -0,0 +1,46 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ | |||
| #define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <map> | |||
| #include <mutex> | |||
| #include "runtime/device/ascend/ge_runtime/task/task.h" | |||
| #include "runtime/device/ascend/ge_runtime/task/label_manager.h" | |||
| namespace mindspore::ge::model_runner { | |||
| class LabelGotoTask : public TaskRepeater<LabelGotoTaskInfo> { | |||
| public: | |||
| LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info); | |||
| ~LabelGotoTask() override; | |||
| void Distribute() override; | |||
| private: | |||
| std::shared_ptr<LabelGotoTaskInfo> task_info_; | |||
| void *stream_; | |||
| std::shared_ptr<LabelGuard> label_info_; | |||
| void *index_value_; | |||
| uint32_t label_id_; | |||
| rtModel_t rt_model_handle_; | |||
| std::shared_ptr<LabelManager> label_manager_; | |||
| }; | |||
| } // namespace mindspore::ge::model_runner | |||
| #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ | |||
| @@ -0,0 +1,116 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "runtime/device/ascend/ge_runtime/task/label_manager.h" | |||
| #include <algorithm> | |||
| #include <string> | |||
| #include "runtime/mem.h" | |||
| #include "runtime/rt_model.h" | |||
| #include "mindspore/core/utils/log_adapter.h" | |||
| namespace mindspore::ge::model_runner { | |||
| std::weak_ptr<LabelManager> LabelManager::instance_; | |||
| std::mutex LabelManager::instance_mutex_; | |||
| template <class T> | |||
| static std::string GetVectorString(const std::vector<T> &vec) { | |||
| std::string ret; | |||
| for (size_t i = 0; i < vec.size(); ++i) { | |||
| if (i != 0) { | |||
| ret.push_back(','); | |||
| } | |||
| ret += std::to_string(vec[i]); | |||
| } | |||
| return ret; | |||
| } | |||
| LabelGuard::~LabelGuard() { | |||
| void *label_info = GetLabelInfo(); | |||
| if (label_info != nullptr) { | |||
| rtError_t rt_ret = rtFree(label_info); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "rtFree label_info failed! ret: " << std::hex << rt_ret; | |||
| } | |||
| } | |||
| } | |||
| std::shared_ptr<LabelManager> LabelManager::GetInstance() { | |||
| std::lock_guard<std::mutex> lock(instance_mutex_); | |||
| auto instance = instance_.lock(); | |||
| if (instance != nullptr) { | |||
| return instance; | |||
| } | |||
| instance = std::make_shared<LabelManager>(); | |||
| instance_ = instance; | |||
| return instance; | |||
| } | |||
| std::shared_ptr<LabelGuard> LabelManager::GetLabelInfo(rtModel_t model, const std::vector<uint32_t> &label_ids, | |||
| const std::vector<void *> &all_label) { | |||
| std::lock_guard<std::mutex> lock(model_info_mapping_mutex_); | |||
| rtError_t rt_ret; | |||
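| // Label info device buffers are cached per model and keyed by the label id list; weak_ptr entries let unused buffers be released. | |||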
| auto model_iter = model_info_mapping_.find(model); | |||
| if (model_iter == model_info_mapping_.end()) { | |||
| model_info_mapping_.emplace(model, std::map<std::string, std::weak_ptr<LabelGuard>>()); | |||
| model_iter = model_info_mapping_.find(model); | |||
| } | |||
| std::string label_id_str = GetVectorString(label_ids); | |||
| auto &label_map = model_iter->second; | |||
| auto label_iter = label_map.find(label_id_str); | |||
| if (label_iter != label_map.end()) { | |||
| auto label_guard = label_iter->second.lock(); | |||
| if (label_guard != nullptr) { | |||
| MS_LOG(INFO) << "model " << model << " find same label id " << label_id_str; | |||
| return label_guard; | |||
| } | |||
| } | |||
| MS_LOG(INFO) << "Alloc label id " << label_id_str << " for model " << model; | |||
| void *label_info = nullptr; | |||
| std::vector<void *> label_list; | |||
| bool status = true; | |||
| std::transform(label_ids.begin(), label_ids.end(), std::back_inserter(label_list), | |||
| [&all_label, &status](uint32_t idx) -> void * { | |||
| if (idx >= all_label.size()) { | |||
| MS_LOG(ERROR) << "Invalid label id " << idx << " all label list size " << all_label.size(); | |||
| status = false; | |||
| return nullptr; | |||
| } | |||
| return all_label[idx]; | |||
| }); | |||
| if (!status) { | |||
| MS_LOG(ERROR) << "Get label info failed."; | |||
| return nullptr; | |||
| } | |||
| uint32_t label_info_size = sizeof(rtLabelDevInfo) * label_list.size(); | |||
| rt_ret = rtMalloc(&label_info, label_info_size, RT_MEMORY_HBM); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Call rt api rtMalloc failed, ret: " << std::hex << rt_ret; | |||
| return nullptr; | |||
| } | |||
| rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info, label_info_size); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Call rt api rtLabelListCpy failed, ret: " << std::hex << rt_ret; | |||
| return nullptr; | |||
| } | |||
| auto label_guard = std::make_shared<LabelGuard>(label_info); | |||
| label_map.emplace(label_id_str, label_guard); | |||
| return label_guard; | |||
| } | |||
| } // namespace mindspore::ge::model_runner | |||
| @@ -0,0 +1,51 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_MANAGER_H_ | |||
| #define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_MANAGER_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <mutex> | |||
| #include <map> | |||
| #include <string> | |||
| #include "runtime/base.h" | |||
| namespace mindspore::ge::model_runner { | |||
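| // RAII holder that frees the device-side label info buffer on destruction. | |||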
| class LabelGuard { | |||
| public: | |||
| explicit LabelGuard(void *label_info) : label_info_(reinterpret_cast<uintptr_t>(label_info)) {} | |||
| ~LabelGuard(); | |||
| void *GetLabelInfo() { return reinterpret_cast<void *>(label_info_); } | |||
| private: | |||
| uintptr_t label_info_; | |||
| }; | |||
| class LabelManager { | |||
| public: | |||
| static std::shared_ptr<LabelManager> GetInstance(); | |||
| std::shared_ptr<LabelGuard> GetLabelInfo(rtModel_t model, const std::vector<uint32_t> &label_ids, | |||
| const std::vector<void *> &all_label); | |||
| private: | |||
| std::mutex model_info_mapping_mutex_; | |||
| std::map<rtModel_t, std::map<std::string, std::weak_ptr<LabelGuard>>> model_info_mapping_; | |||
| static std::weak_ptr<LabelManager> instance_; | |||
| static std::mutex instance_mutex_; | |||
| }; | |||
| } // namespace mindspore::ge::model_runner | |||
| #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_MANAGER_H_ | |||
| @@ -0,0 +1,56 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "runtime/device/ascend/ge_runtime/task/label_set_task.h" | |||
| #include "runtime/device/ascend/ge_runtime/task/task_factory.h" | |||
| namespace mindspore::ge::model_runner { | |||
| LabelSetTask::LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info) | |||
| : TaskRepeater<LabelSetTaskInfo>(model_context, task_info), | |||
| task_info_(task_info), | |||
| stream_(nullptr), | |||
| label_(nullptr) { | |||
| MS_EXCEPTION_IF_NULL(task_info_); | |||
| auto stream_list = model_context.stream_list(); | |||
| auto label_list = model_context.label_list(); | |||
| uint32_t stream_id = task_info->stream_id(); | |||
| uint32_t label_id = task_info->label_id(); | |||
| MS_LOG(INFO) << "Stream list size: " << stream_list.size() << ", stream id: " << stream_id; | |||
| MS_LOG(INFO) << "Label list size: " << label_list.size() << ", label id: " << label_id; | |||
| if (stream_id >= stream_list.size() || label_id >= label_list.size()) { | |||
| MS_LOG(EXCEPTION) << "Stream/Label id invalid."; | |||
| } | |||
| stream_ = stream_list[stream_id]; | |||
| label_ = label_list[label_id]; | |||
| } | |||
| LabelSetTask::~LabelSetTask() {} | |||
| void LabelSetTask::Distribute() { | |||
| MS_LOG(INFO) << "LabelSetTask Distribute start."; | |||
| MS_EXCEPTION_IF_NULL(stream_); | |||
| MS_EXCEPTION_IF_NULL(label_); | |||
| rtError_t rt_ret = rtLabelSet(label_, stream_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rt api rtLabelSet failed, ret: " << std::hex << rt_ret; | |||
| } | |||
| MS_LOG(INFO) << "DistributeTask end."; | |||
| } | |||
| REGISTER_TASK(TaskInfoType::LABEL_SET, LabelSetTask, LabelSetTaskInfo); | |||
| } // namespace mindspore::ge::model_runner | |||
| @@ -0,0 +1,38 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ | |||
| #define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ | |||
| #include <memory> | |||
| #include "runtime/device/ascend/ge_runtime/task/task.h" | |||
| namespace mindspore::ge::model_runner { | |||
| class LabelSetTask : public TaskRepeater<LabelSetTaskInfo> { | |||
| public: | |||
| LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info); | |||
| ~LabelSetTask() override; | |||
| void Distribute() override; | |||
| private: | |||
| std::shared_ptr<LabelSetTaskInfo> task_info_; | |||
| void *stream_; | |||
| void *label_; | |||
| }; | |||
| } // namespace mindspore::ge::model_runner | |||
| #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ | |||
| @@ -0,0 +1,77 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "runtime/device/ascend/ge_runtime/task/label_switch_task.h" | |||
| #include "runtime/device/ascend/ge_runtime/task/task_factory.h" | |||
| namespace mindspore::ge::model_runner { | |||
| LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context, | |||
| const std::shared_ptr<LabelSwitchTaskInfo> &task_info) | |||
| : TaskRepeater<LabelSwitchTaskInfo>(model_context, task_info), | |||
| task_info_(task_info), | |||
| stream_(nullptr), | |||
| label_info_(nullptr) { | |||
| MS_EXCEPTION_IF_NULL(task_info); | |||
| rt_model_handle_ = model_context.rt_model_handle(); | |||
| auto all_label_resource = model_context.label_list(); | |||
| auto stream_list = model_context.stream_list(); | |||
| uint32_t stream_id = task_info->stream_id(); | |||
| MS_LOG(INFO) << "Stream list size: " << stream_list.size() << ", stream id: " << stream_id; | |||
| if (stream_id >= stream_list.size()) { | |||
| MS_LOG(EXCEPTION) << "Stream id invalid."; | |||
| } | |||
| stream_ = stream_list[stream_id]; | |||
| label_manager_ = LabelManager::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(label_manager_); | |||
| label_info_ = label_manager_->GetLabelInfo(rt_model_handle_, task_info_->label_list(), all_label_resource); | |||
| MS_EXCEPTION_IF_NULL(label_info_); | |||
| } | |||
| LabelSwitchTask::~LabelSwitchTask() {} | |||
| void LabelSwitchTask::Distribute() { | |||
| MS_LOG(INFO) << "LabelSwitchTask Distribute start."; | |||
| CheckParamValid(); | |||
| void *label_info = label_info_->GetLabelInfo(); | |||
| rtError_t rt_ret = rtLabelSwitchByIndex(task_info_->cond(), task_info_->label_size(), label_info, stream_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rt api rtLabelSwitchByIndex failed, ret: " << std::hex << rt_ret; | |||
| } | |||
| MS_LOG(INFO) << "DistributeTask end."; | |||
| } | |||
| void LabelSwitchTask::CheckParamValid() { | |||
| MS_EXCEPTION_IF_NULL(stream_); | |||
| if (task_info_->label_list().empty()) { | |||
| MS_LOG(EXCEPTION) << "label_list is empty."; | |||
| } | |||
| if (task_info_->label_size() != task_info_->label_list().size()) { | |||
| MS_LOG(EXCEPTION) << "label_list size " << task_info_->label_list().size() << " but label_size is " | |||
| << task_info_->label_size(); | |||
| } | |||
| if (task_info_->label_size() >= UINT32_MAX / sizeof(rtLabelDevInfo)) { | |||
| MS_LOG(EXCEPTION) << "label_size " << task_info_->label_size() << " will overflow."; | |||
| } | |||
| } | |||
| REGISTER_TASK(TaskInfoType::LABEL_SWITCH, LabelSwitchTask, LabelSwitchTaskInfo); | |||
| } // namespace mindspore::ge::model_runner | |||
| @@ -0,0 +1,43 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ | |||
| #define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ | |||
| #include <memory> | |||
| #include "runtime/device/ascend/ge_runtime/task/task.h" | |||
| #include "runtime/device/ascend/ge_runtime/task/label_manager.h" | |||
| namespace mindspore::ge::model_runner { | |||
| class LabelSwitchTask : public TaskRepeater<LabelSwitchTaskInfo> { | |||
| public: | |||
| LabelSwitchTask(const ModelContext &model_context, const std::shared_ptr<LabelSwitchTaskInfo> &task_info); | |||
| ~LabelSwitchTask() override; | |||
| void Distribute() override; | |||
| private: | |||
| void CheckParamValid(); | |||
| std::shared_ptr<LabelSwitchTaskInfo> task_info_; | |||
| void *stream_; | |||
| rtModel_t rt_model_handle_; | |||
| std::shared_ptr<LabelGuard> label_info_; | |||
| std::shared_ptr<LabelManager> label_manager_; | |||
| }; | |||
| } // namespace mindspore::ge::model_runner | |||
| #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ | |||
| @@ -0,0 +1,51 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "runtime/device/ascend/ge_runtime/task/memcpy_async_task.h" | |||
| #include "runtime/mem.h" | |||
| #include "runtime/device/ascend/ge_runtime/task/task_factory.h" | |||
| namespace mindspore::ge::model_runner { | |||
| MemcpyAsyncTask::MemcpyAsyncTask(const ModelContext &model_context, | |||
| const std::shared_ptr<MemcpyAsyncTaskInfo> &task_info) | |||
| : TaskRepeater<MemcpyAsyncTaskInfo>(model_context, task_info), task_info_(task_info), stream_(nullptr) { | |||
| MS_EXCEPTION_IF_NULL(task_info); | |||
| auto stream_list = model_context.stream_list(); | |||
| uint32_t stream_id = task_info->stream_id(); | |||
| MS_LOG(INFO) << "Stream list size: " << stream_list.size() << ", stream id: " << stream_id; | |||
| if (stream_id >= stream_list.size()) { | |||
| MS_LOG(EXCEPTION) << "Index: " << task_info->stream_id() << " >= stream_list.size(): " << stream_list.size(); | |||
| } | |||
| stream_ = stream_list[stream_id]; | |||
| } | |||
| MemcpyAsyncTask::~MemcpyAsyncTask() {} | |||
| void MemcpyAsyncTask::Distribute() { | |||
| MS_LOG(INFO) << "MemcpyAsyncTask Distribute start."; | |||
| MS_LOG(INFO) << "dst_max: " << task_info_->dst_max() << ", count: " << task_info_->count() | |||
| << ", kind: " << task_info_->kind(); | |||
| rtError_t rt_ret = rtMemcpyAsync(task_info_->dst(), task_info_->dst_max(), task_info_->src(), task_info_->count(), | |||
| static_cast<rtMemcpyKind_t>(task_info_->kind()), stream_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rt api rtMemcpyAsync failed, ret: " << std::hex << rt_ret; | |||
| } | |||
| MS_LOG(INFO) << "DistributeTask end"; | |||
| } | |||
| REGISTER_TASK(TaskInfoType::MEMCPY_ASYNC, MemcpyAsyncTask, MemcpyAsyncTaskInfo); | |||
| } // namespace mindspore::ge::model_runner | |||
| @@ -0,0 +1,37 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_MEMCPY_ASYNC_TASK_H_ | |||
| #define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_MEMCPY_ASYNC_TASK_H_ | |||
| #include <memory> | |||
| #include "runtime/device/ascend/ge_runtime/task/task.h" | |||
| namespace mindspore::ge::model_runner { | |||
| class MemcpyAsyncTask : public TaskRepeater<MemcpyAsyncTaskInfo> { | |||
| public: | |||
| MemcpyAsyncTask(const ModelContext &model_context, const std::shared_ptr<MemcpyAsyncTaskInfo> &task_info); | |||
| ~MemcpyAsyncTask() override; | |||
| void Distribute() override; | |||
| private: | |||
| std::shared_ptr<MemcpyAsyncTaskInfo> task_info_; | |||
| rtStream_t stream_; | |||
| }; | |||
| } // namespace mindspore::ge::model_runner | |||
| #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_MEMCPY_ASYNC_TASK_H_ | |||
| @@ -0,0 +1,47 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "runtime/device/ascend/ge_runtime/task/profiler_task.h" | |||
| #include "runtime/device/ascend/ge_runtime/task/task_factory.h" | |||
| namespace mindspore::ge::model_runner { | |||
| ProfilerTask::ProfilerTask(const ModelContext &model_context, const std::shared_ptr<ProfilerTraceTaskInfo> &task_info) | |||
| : TaskRepeater<ProfilerTraceTaskInfo>(model_context, task_info), task_info_(task_info), stream_(nullptr) { | |||
| MS_EXCEPTION_IF_NULL(task_info); | |||
| auto stream_list = model_context.stream_list(); | |||
| uint32_t stream_id = task_info->stream_id(); | |||
| MS_LOG(INFO) << "Stream list size: " << stream_list.size() << ", stream id: " << stream_id; | |||
| if (stream_id >= stream_list.size()) { | |||
| MS_LOG(EXCEPTION) << "Index: " << task_info->stream_id() << " >= stream_list.size(): " << stream_list.size(); | |||
| } | |||
| stream_ = stream_list[stream_id]; | |||
| } | |||
| ProfilerTask::~ProfilerTask() {} | |||
| void ProfilerTask::Distribute() { | |||
| MS_LOG(INFO) << "ProfilerTask Distribute start."; | |||
| MS_LOG(INFO) << "log id = " << task_info_->log_id() << ", notify = " << task_info_->notify() | |||
| << ", flat = " << task_info_->flat(); | |||
| rtError_t rt_ret = rtProfilerTrace(task_info_->log_id(), task_info_->notify(), task_info_->flat(), stream_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rt api rtProfilerTrace failed, ret: " << std::hex << rt_ret; | |||
| } | |||
| MS_LOG(INFO) << "DistributeTask end."; | |||
| } | |||
| REGISTER_TASK(TaskInfoType::PROFILER_TRACE, ProfilerTask, ProfilerTraceTaskInfo); | |||
| } // namespace mindspore::ge::model_runner | |||
| @@ -0,0 +1,37 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_PROFILER_TASK_H_ | |||
| #define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_PROFILER_TASK_H_ | |||
| #include <memory> | |||
| #include "runtime/device/ascend/ge_runtime/task/task.h" | |||
| namespace mindspore::ge::model_runner { | |||
| class ProfilerTask : public TaskRepeater<ProfilerTraceTaskInfo> { | |||
| public: | |||
| ProfilerTask(const ModelContext &model_context, const std::shared_ptr<ProfilerTraceTaskInfo> &task_info); | |||
| ~ProfilerTask() override; | |||
| void Distribute() override; | |||
| private: | |||
| std::shared_ptr<ProfilerTraceTaskInfo> task_info_; | |||
| rtStream_t stream_; | |||
| }; | |||
| } // namespace mindspore::ge::model_runner | |||
| #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_PROFILER_TASK_H_ | |||
| @@ -0,0 +1,56 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "runtime/device/ascend/ge_runtime/task/stream_active_task.h" | |||
| #include "runtime/kernel.h" | |||
| #include "runtime/device/ascend/ge_runtime/task/task_factory.h" | |||
| namespace mindspore::ge::model_runner { | |||
| StreamActiveTask::StreamActiveTask(const ModelContext &model_context, | |||
| const std::shared_ptr<StreamActiveTaskInfo> &task_info) | |||
| : TaskRepeater<StreamActiveTaskInfo>(model_context, task_info), | |||
| task_info_(task_info), | |||
| stream_(nullptr), | |||
| active_stream_(nullptr) { | |||
| MS_EXCEPTION_IF_NULL(task_info); | |||
| auto stream_list = model_context.stream_list(); | |||
| uint32_t stream_id = task_info->stream_id(); | |||
| uint32_t active_stream_id = task_info->active_stream_id(); | |||
| MS_LOG(INFO) << "Stream list size: " << stream_list.size() << ", stream id: " << stream_id | |||
| << ", active stream id: " << active_stream_id; | |||
| if (stream_id >= stream_list.size() || active_stream_id >= stream_list.size()) { | |||
| MS_LOG(EXCEPTION) << "Stream id invalid"; | |||
| } | |||
| stream_ = stream_list[stream_id]; | |||
| active_stream_ = stream_list[active_stream_id]; | |||
| } | |||
| StreamActiveTask::~StreamActiveTask() {} | |||
| void StreamActiveTask::Distribute() { | |||
| MS_LOG(INFO) << "Distribute start"; | |||
| MS_LOG(INFO) << "Stream " << task_info_->stream_id() << " active " << task_info_->active_stream_id(); | |||
| MS_EXCEPTION_IF_NULL(stream_); | |||
| MS_EXCEPTION_IF_NULL(active_stream_); | |||
| rtError_t rt_ret = rtStreamActive(active_stream_, stream_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rt api rtStreamActive failed, ret: " << std::hex << rt_ret; | |||
| } | |||
| MS_LOG(INFO) << "DistributeTask end."; | |||
| } | |||
| REGISTER_TASK(TaskInfoType::STREAM_ACTIVE, StreamActiveTask, StreamActiveTaskInfo); | |||
| } // namespace mindspore::ge::model_runner | |||
| @@ -0,0 +1,38 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_STREAM_ACTIVE_TASK_H_ | |||
| #define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_STREAM_ACTIVE_TASK_H_ | |||
| #include <memory> | |||
| #include "runtime/device/ascend/ge_runtime/task/task.h" | |||
| namespace mindspore::ge::model_runner { | |||
| class StreamActiveTask : public TaskRepeater<StreamActiveTaskInfo> { | |||
| public: | |||
| StreamActiveTask(const ModelContext &model_context, const std::shared_ptr<StreamActiveTaskInfo> &task_info); | |||
| ~StreamActiveTask() override; | |||
| void Distribute() override; | |||
| private: | |||
| std::shared_ptr<StreamActiveTaskInfo> task_info_; | |||
| rtStream_t stream_; | |||
| rtStream_t active_stream_; | |||
| }; | |||
| } // namespace mindspore::ge::model_runner | |||
| #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_STREAM_ACTIVE_TASK_H_ | |||
| @@ -0,0 +1,70 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "runtime/device/ascend/ge_runtime/task/stream_switch_task.h" | |||
| #include "runtime/kernel.h" | |||
| #include "runtime/device/ascend/ge_runtime/task/task_factory.h" | |||
| namespace mindspore::ge::model_runner { | |||
| StreamSwitchTask::StreamSwitchTask(const ModelContext &model_context, | |||
| const std::shared_ptr<StreamSwitchTaskInfo> &task_info) | |||
| : TaskRepeater<StreamSwitchTaskInfo>(model_context, task_info), | |||
| task_info_(task_info), | |||
| stream_(nullptr), | |||
| stream_list_() { | |||
| MS_EXCEPTION_IF_NULL(task_info); | |||
| stream_list_ = model_context.stream_list(); | |||
| if (stream_list_.size() == 1) { | |||
| stream_ = stream_list_[0]; | |||
| } else if (stream_list_.size() > task_info->stream_id()) { | |||
| stream_ = stream_list_[task_info->stream_id()]; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Index: " << task_info->stream_id() << " >= stream_list.size(): " << stream_list_.size(); | |||
| } | |||
| } | |||
| StreamSwitchTask::~StreamSwitchTask() {} | |||
| void StreamSwitchTask::Distribute() { | |||
| MS_LOG(INFO) << "Init StreamSwitchTask start."; | |||
| MS_LOG(INFO) << "Stream " << task_info_->stream_id() << " active " << task_info_->true_stream_id(); | |||
| MS_EXCEPTION_IF_NULL(stream_); | |||
| if (static_cast<uint64_t>(task_info_->true_stream_id()) >= stream_list_.size()) { | |||
| MS_LOG(EXCEPTION) << "true_stream_id " << task_info_->true_stream_id() << " must be less than stream_list_ size " | |||
| << stream_list_.size(); | |||
| } | |||
| void *input = task_info_->input_addr(); | |||
| rtCondition_t cond = static_cast<rtCondition_t>(task_info_->cond()); | |||
| void *value = task_info_->value_addr(); | |||
| rtStream_t true_stream = stream_list_[task_info_->true_stream_id()]; | |||
| rtSwitchDataType_t data_type = static_cast<rtSwitchDataType_t>(task_info_->data_type()); | |||
| MS_LOG(INFO) << "InitStreamSwitchTask, cond: " << cond << ", trueStream: " << true_stream | |||
| << ", trueStreamID: " << task_info_->true_stream_id() << ", datatype: " << task_info_->data_type(); | |||
| MS_LOG(INFO) << "StreamSwitchTask Distribute Start."; | |||
| rtError_t rt_ret = rtStreamSwitchEx(input, cond, value, true_stream, stream_, data_type); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rt api rtStreamSwitchEx failed, ret: " << std::hex << rt_ret; | |||
| } | |||
| MS_LOG(INFO) << "Distribute StreamSwitch success"; | |||
| } | |||
| REGISTER_TASK(TaskInfoType::STREAM_SWITCH, StreamSwitchTask, StreamSwitchTaskInfo); | |||
| } // namespace mindspore::ge::model_runner | |||
| @@ -0,0 +1,40 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_STREAM_SWITCH_TASK_H_ | |||
| #define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_STREAM_SWITCH_TASK_H_ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "runtime/device/ascend/ge_runtime/task/task.h" | |||
| namespace mindspore::ge::model_runner { | |||
| class StreamSwitchTask : public TaskRepeater<StreamSwitchTaskInfo> { | |||
| public: | |||
| StreamSwitchTask(const ModelContext &model_context, const std::shared_ptr<StreamSwitchTaskInfo> &task_info); | |||
| ~StreamSwitchTask() override; | |||
| void Distribute() override; | |||
| private: | |||
| std::shared_ptr<StreamSwitchTaskInfo> task_info_; | |||
| void *stream_; | |||
| std::vector<rtStream_t> stream_list_; | |||
| }; | |||
| } // namespace mindspore::ge::model_runner | |||
| #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_STREAM_SWITCH_TASK_H_ | |||
| @@ -0,0 +1,53 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TASK_H_ | |||
| #define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TASK_H_ | |||
| #include <memory> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include <string> | |||
| #include "runtime/device/ascend/ge_runtime/model_context.h" | |||
| #include "runtime/device/ascend/ge_runtime/task_info.h" | |||
| namespace mindspore::ge::model_runner { | |||
| class Task { | |||
| public: | |||
| Task() {} | |||
| virtual ~Task() {} | |||
| virtual void Distribute() = 0; | |||
| virtual void *Args() { return nullptr; } | |||
| virtual std::string task_name() const { return ""; } | |||
| }; | |||
| template <class T> | |||
| class TaskRepeater : public Task { | |||
| static_assert(std::is_base_of<TaskInfo, T>(), "Wrong TaskInfo Type!"); | |||
| public: | |||
| TaskRepeater(const ModelContext &model_context, const std::shared_ptr<T> &task_info) {} | |||
| virtual ~TaskRepeater() {} | |||
| virtual void Distribute() = 0; | |||
| }; | |||
| } // namespace mindspore::ge::model_runner | |||
| #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TASK_H_ | |||
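As a quick orientation for how this Task / TaskRepeater pair is meant to be used, a minimal sketch of a custom task follows. The EchoTask and EchoTaskInfo names are purely illustrative and are not part of this patch; the sketch assumes the usual ge_runtime includes (task.h, task_factory.h, log_adapter.h) and lives inside the mindspore::ge::model_runner namespace.

// Illustrative sketch only: EchoTaskInfo / EchoTask are hypothetical names, not part of this change.
class EchoTaskInfo : public TaskInfo {
 public:
  EchoTaskInfo(const std::string &op_name, uint32_t stream_id)
      // RESERVED is used here only as a placeholder type id for the sketch.
      : TaskInfo(op_name, stream_id, TaskInfoType::RESERVED, false) {}
  ~EchoTaskInfo() override {}
};

class EchoTask : public TaskRepeater<EchoTaskInfo> {
 public:
  EchoTask(const ModelContext &model_context, const std::shared_ptr<EchoTaskInfo> &task_info)
      : TaskRepeater<EchoTaskInfo>(model_context, task_info), task_info_(task_info) {}
  ~EchoTask() override {}
  // Distribute() is where a concrete task submits its work to the runtime; here it only logs.
  void Distribute() override { MS_LOG(INFO) << "Echo task: " << task_info_->op_name(); }

 private:
  std::shared_ptr<EchoTaskInfo> task_info_;
};

The static_assert in TaskRepeater guarantees at compile time that the paired info type really derives from TaskInfo, so a mismatch between a task class and its info class is caught before registration ever runs.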
| @@ -0,0 +1,84 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TASK_FACTORY_H_ | |||
| #define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TASK_FACTORY_H_ | |||
| #include <functional> | |||
| #include <map> | |||
| #include <memory> | |||
| #include <unordered_map> | |||
| #include "runtime/device/ascend/ge_runtime/task_info.h" | |||
| #include "mindspore/core/utils/log_adapter.h" | |||
| namespace mindspore::ge::model_runner { | |||
| class Task; | |||
| class ModelContext; | |||
| using TASK_CREATOR_FUN = std::function<std::shared_ptr<Task>(const ModelContext &, std::shared_ptr<TaskInfo>)>; | |||
| class TaskFactory { | |||
| private: | |||
| TaskFactory() {} | |||
| ~TaskFactory() {} | |||
| void RegisterCreator(const TaskInfoType &type, const TASK_CREATOR_FUN &func) { | |||
| if (creator_map_.find(type) != creator_map_.end()) { | |||
| MS_LOG(WARNING) << "Creator type " << type << " already exist."; | |||
| } | |||
| creator_map_[type] = func; | |||
| } | |||
| std::map<TaskInfoType, TASK_CREATOR_FUN> creator_map_; | |||
| public: | |||
| static TaskFactory &GetInstance() { | |||
| static TaskFactory instance; | |||
| return instance; | |||
| } | |||
| std::shared_ptr<Task> Create(const ModelContext &model_context, const std::shared_ptr<TaskInfo> &task_info) const { | |||
| if (task_info == nullptr) { | |||
| MS_LOG(ERROR) << "task_info is null."; | |||
| return nullptr; | |||
| } | |||
| auto iter = creator_map_.find(task_info->type()); | |||
| if (iter == creator_map_.end()) { | |||
| MS_LOG(ERROR) << "Unknown task type " << task_info->type(); | |||
| return nullptr; | |||
| } | |||
| return iter->second(model_context, task_info); | |||
| } | |||
| class Register { | |||
| public: | |||
| Register(const TaskInfoType &type, const TASK_CREATOR_FUN &func) { | |||
| MS_LOG(DEBUG) << "register type " << type; | |||
| TaskFactory::GetInstance().RegisterCreator(type, func); | |||
| } | |||
| ~Register() {} | |||
| }; | |||
| }; | |||
| #define REGISTER_TASK(type, task_clazz, task_info_clazz) \ | |||
| TaskFactory::Register g_##task_clazz##_register( \ | |||
| type, [](const ModelContext &model_context, const std::shared_ptr<TaskInfo> &task_info) -> std::shared_ptr<Task> { \ | |||
| std::shared_ptr<task_info_clazz> concrete_task_info = std::static_pointer_cast<task_info_clazz>(task_info); \ | |||
| return std::make_shared<task_clazz>(model_context, concrete_task_info); \ | |||
| }); | |||
| } // namespace mindspore::ge::model_runner | |||
| #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TASK_FACTORY_H_ | |||
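The REGISTER_TASK macro above is the only glue between a concrete task class and the factory: it instantiates a TaskFactory::Register object at static-initialization time whose lambda casts the generic TaskInfo down to the concrete info type. A short usage sketch follows, reusing the hypothetical EchoTask / EchoTaskInfo from the note after task.h; MakeEchoTask is likewise an illustrative helper, not part of this patch.

// Registration normally sits at the bottom of the task's .cc file, inside
// namespace mindspore::ge::model_runner:
//   REGISTER_TASK(TaskInfoType::RESERVED, EchoTask, EchoTaskInfo);

// Dispatch: Create() looks up the creator registered for info->type() and
// returns nullptr for null or unknown task infos.
std::shared_ptr<Task> MakeEchoTask(const ModelContext &model_context) {
  std::shared_ptr<TaskInfo> info = std::make_shared<EchoTaskInfo>("echo_op", /*stream_id=*/0);
  return TaskFactory::GetInstance().Create(model_context, info);
}

Because the creator map lives in a function-local static singleton, registration order across translation units does not matter; the singleton is constructed on the first GetInstance() call and each Register object simply adds its creator to it.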
| @@ -0,0 +1,97 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "runtime/device/ascend/ge_runtime/task/tbe_task.h" | |||
| #include <vector> | |||
| #include "runtime/mem.h" | |||
| #include "runtime/kernel.h" | |||
| #include "runtime/device/ascend/ge_runtime/task/task_factory.h" | |||
| namespace mindspore::ge::model_runner { | |||
| TbeTask::TbeTask(const ModelContext &model_context, const std::shared_ptr<TbeTaskInfo> &task_info) | |||
| : TaskRepeater<TbeTaskInfo>(model_context, task_info), | |||
| task_info_(task_info), | |||
| stream_(nullptr), | |||
| stub_func_(nullptr), | |||
| args_(nullptr) { | |||
| MS_EXCEPTION_IF_NULL(task_info); | |||
| auto stream_list = model_context.stream_list(); | |||
| if (stream_list.size() == 1) { | |||
| stream_ = stream_list[0]; | |||
| } else if (stream_list.size() > task_info->stream_id()) { | |||
| stream_ = stream_list[task_info->stream_id()]; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Index: " << task_info->stream_id() << " >= stream_list.size(): " << stream_list.size(); | |||
| } | |||
| } | |||
| TbeTask::~TbeTask() { | |||
| if (args_ != nullptr) { | |||
| rtError_t rt_ret = rtFree(args_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "Call rt api rtFree failed, ret: " << std::hex << rt_ret; | |||
| } | |||
| args_ = nullptr; | |||
| } | |||
| } | |||
| void TbeTask::Distribute() { | |||
| MS_LOG(INFO) << "InitTbeTask start."; | |||
| MS_EXCEPTION_IF_NULL(stream_); | |||
| // Get stub_func | |||
| if (task_info_->stub_func().empty()) { | |||
| MS_LOG(EXCEPTION) << "kernel_info->stub_func is empty!"; | |||
| } | |||
| rtError_t rt_ret = rtGetFunctionByName(const_cast<char *>(task_info_->stub_func().c_str()), &stub_func_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rt api rtGetFunctionByName failed, ret: " << std::hex << rt_ret; | |||
| } | |||
| MS_LOG(INFO) << "TbeTask: stub_func = " << task_info_->stub_func(); | |||
| // Get args | |||
| std::vector<void *> tensor_device_addrs; | |||
| tensor_device_addrs.insert(tensor_device_addrs.end(), task_info_->input_data_addrs().begin(), | |||
| task_info_->input_data_addrs().end()); | |||
| tensor_device_addrs.insert(tensor_device_addrs.end(), task_info_->output_data_addrs().begin(), | |||
| task_info_->output_data_addrs().end()); | |||
| tensor_device_addrs.insert(tensor_device_addrs.end(), task_info_->workspace_addrs().begin(), | |||
| task_info_->workspace_addrs().end()); | |||
| auto args_size = static_cast<uint32_t>(tensor_device_addrs.size() * sizeof(void *)); | |||
| rt_ret = rtMalloc(&args_, args_size, RT_MEMORY_HBM); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rt api rtMalloc failed, ret: " << std::hex << rt_ret << " mem size " << args_size; | |||
| } | |||
| rt_ret = rtMemcpy(args_, args_size, reinterpret_cast<void *>(tensor_device_addrs.data()), args_size, | |||
| RT_MEMCPY_HOST_TO_DEVICE); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rt api rtMemcpy failed, ret: " << std::hex << rt_ret; | |||
| } | |||
| MS_LOG(INFO) << "DistributeTbeTask start."; | |||
| auto dump_flag = task_info_->dump_flag() ? RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT; | |||
| rt_ret = rtKernelLaunchWithFlag(stub_func_, task_info_->block_dim(), args_, args_size, nullptr, stream_, dump_flag); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rt api rtKernelLaunch failed, ret: " << std::hex << rt_ret << " mem size " << args_size; | |||
| } | |||
| MS_LOG(INFO) << "[DataDump] task name: " << task_info_->op_name() << " dump_flag: " << dump_flag; | |||
| } | |||
| REGISTER_TASK(TaskInfoType::TBE, TbeTask, TbeTaskInfo); | |||
| } // namespace mindspore::ge::model_runner | |||
| @@ -0,0 +1,44 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TBE_TASK_H_ | |||
| #define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TBE_TASK_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include "runtime/device/ascend/ge_runtime/task/task.h" | |||
| namespace mindspore::ge::model_runner { | |||
| class TbeTask : public TaskRepeater<TbeTaskInfo> { | |||
| public: | |||
| TbeTask(const ModelContext &model_context, const std::shared_ptr<TbeTaskInfo> &task_info); | |||
| ~TbeTask() override; | |||
| void Distribute() override; | |||
| void *Args() override { return args_; } | |||
| std::string task_name() const override { return task_info_->op_name(); } | |||
| private: | |||
| std::shared_ptr<TbeTaskInfo> task_info_; | |||
| void *stream_; | |||
| void *stub_func_; | |||
| void *args_; | |||
| }; | |||
| } // namespace mindspore::ge::model_runner | |||
| #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TBE_TASK_H_ | |||
| @@ -0,0 +1,364 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_INFO_H_ | |||
| #define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_INFO_H_ | |||
| #include <stdint.h> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| namespace mindspore::ge::model_runner { | |||
| enum TaskInfoType { | |||
| CCE = 0, | |||
| TBE, | |||
| AICPU, | |||
| LABEL_SET, | |||
| LABEL_SWITCH, | |||
| LABEL_GOTO, | |||
| EVENT_RECORD, | |||
| EVENT_WAIT, | |||
| FUSION_START, | |||
| FUSION_END, | |||
| HCCL, | |||
| PROFILER_TRACE, | |||
| MEMCPY_ASYNC, | |||
| STREAM_SWITCH, | |||
| STREAM_ACTIVE, | |||
| // Insert new task type here | |||
| RESERVED = 23 | |||
| }; | |||
| class TaskInfo { | |||
| public: | |||
| virtual ~TaskInfo() {} | |||
| uint32_t stream_id() const { return stream_id_; } | |||
| TaskInfoType type() const { return type_; } | |||
| std::string op_name() const { return op_name_; } | |||
| bool dump_flag() const { return dump_flag_; } | |||
| protected: | |||
| TaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, bool dump_flag) | |||
| : op_name_(op_name), stream_id_(stream_id), type_(type), dump_flag_(dump_flag) {} | |||
| private: | |||
| std::string op_name_; | |||
| uint32_t stream_id_; | |||
| TaskInfoType type_; | |||
| bool dump_flag_; | |||
| }; | |||
| class TbeTaskInfo : public TaskInfo { | |||
| public: | |||
| TbeTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string &stub_func, uint32_t block_dim, | |||
| const std::vector<uint8_t> &args, uint32_t args_size, const std::vector<uint8_t> &sm_desc, void *binary, | |||
| uint32_t binary_size, const std::vector<uint8_t> &meta_data, const std::vector<void *> &input_data_addrs, | |||
| const std::vector<void *> &output_data_addrs, const std::vector<void *> &workspace_addrs, bool dump_flag) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::TBE, dump_flag), | |||
| stub_func_(stub_func), | |||
| block_dim_(block_dim), | |||
| args_(args), | |||
| args_size_(args_size), | |||
| sm_desc_(sm_desc), | |||
| binary_(binary), | |||
| binary_size_(binary_size), | |||
| meta_data_(meta_data), | |||
| input_data_addrs_(input_data_addrs), | |||
| output_data_addrs_(output_data_addrs), | |||
| workspace_addrs_(workspace_addrs) {} | |||
| ~TbeTaskInfo() override {} | |||
| const std::string &stub_func() const { return stub_func_; } | |||
| uint32_t block_dim() const { return block_dim_; } | |||
| const std::vector<uint8_t> &args() const { return args_; } | |||
| uint32_t args_size() const { return args_size_; } | |||
| const std::vector<uint8_t> &sm_desc() const { return sm_desc_; } | |||
| void *binary() const { return binary_; } | |||
| uint32_t binary_size() const { return binary_size_; } | |||
| const std::vector<uint8_t> &meta_data() const { return meta_data_; } | |||
| const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; } | |||
| const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; } | |||
| const std::vector<void *> &workspace_addrs() const { return workspace_addrs_; } | |||
| void SetBinary(void *binary, uint32_t binary_size) { | |||
| binary_ = binary; | |||
| binary_size_ = binary_size; | |||
| } | |||
| private: | |||
| std::string stub_func_; | |||
| uint32_t block_dim_; | |||
| std::vector<uint8_t> args_; | |||
| uint32_t args_size_; | |||
| std::vector<uint8_t> sm_desc_; | |||
| void *binary_; | |||
| uint32_t binary_size_; | |||
| std::vector<uint8_t> meta_data_; | |||
| std::vector<void *> input_data_addrs_; | |||
| std::vector<void *> output_data_addrs_; | |||
| std::vector<void *> workspace_addrs_; | |||
| }; | |||
| class AicpuTaskInfo : public TaskInfo { | |||
| public: | |||
| AicpuTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string &so_name, | |||
| const std::string &kernel_name, const std::string &node_def, const std::string &ext_info, | |||
| const std::vector<void *> &input_data_addrs, const std::vector<void *> &output_data_addrs, | |||
| bool dump_flag) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::AICPU, dump_flag), | |||
| so_name_(so_name), | |||
| kernel_name_(kernel_name), | |||
| node_def_(node_def), | |||
| ext_info_(ext_info), | |||
| input_data_addrs_(input_data_addrs), | |||
| output_data_addrs_(output_data_addrs) {} | |||
| ~AicpuTaskInfo() override {} | |||
| const std::string &so_name() const { return so_name_; } | |||
| const std::string &kernel_name() const { return kernel_name_; } | |||
| const std::string &node_def() const { return node_def_; } | |||
| const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; } | |||
| const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; } | |||
| const std::string &ext_info() const { return ext_info_; } | |||
| private: | |||
| std::string so_name_; | |||
| std::string kernel_name_; | |||
| std::string node_def_; | |||
| std::string ext_info_; | |||
| std::vector<void *> input_data_addrs_; | |||
| std::vector<void *> output_data_addrs_; | |||
| }; | |||
| class LabelSetTaskInfo : public TaskInfo { | |||
| public: | |||
| LabelSetTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SET, false), label_id_(label_id) {} | |||
| ~LabelSetTaskInfo() override {} | |||
| uint32_t label_id() const { return label_id_; } | |||
| private: | |||
| uint32_t label_id_; | |||
| }; | |||
| class LabelGotoTaskInfo : public TaskInfo { | |||
| public: | |||
| LabelGotoTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_GOTO, false), label_id_(label_id) {} | |||
| ~LabelGotoTaskInfo() override {} | |||
| uint32_t label_id() const { return label_id_; } | |||
| private: | |||
| uint32_t label_id_; | |||
| }; | |||
| class LabelSwitchTaskInfo : public TaskInfo { | |||
| public: | |||
| LabelSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_size, | |||
| const std::vector<uint32_t> &label_list, void *cond) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SWITCH, false), | |||
| label_size_(label_size), | |||
| label_list_(label_list), | |||
| cond_(cond) {} | |||
| ~LabelSwitchTaskInfo() override {} | |||
| uint32_t label_size() const { return label_size_; } | |||
| const std::vector<uint32_t> &label_list() const { return label_list_; } | |||
| void *cond() const { return cond_; } | |||
| private: | |||
| uint32_t label_size_; | |||
| std::vector<uint32_t> label_list_; | |||
| void *cond_; | |||
| }; | |||
| class EventTaskInfo : public TaskInfo { | |||
| public: | |||
| uint32_t event_id() const { return event_id_; } | |||
| protected: | |||
| EventTaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, uint32_t event_id) | |||
| : TaskInfo(op_name, stream_id, type, false), event_id_(event_id) {} | |||
| ~EventTaskInfo() override {} | |||
| uint32_t event_id_; | |||
| }; | |||
| class EventRecordTaskInfo : public EventTaskInfo { | |||
| public: | |||
| EventRecordTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id) | |||
| : EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_RECORD, event_id) {} | |||
| ~EventRecordTaskInfo() override {} | |||
| }; | |||
| class EventWaitTaskInfo : public EventTaskInfo { | |||
| public: | |||
| EventWaitTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id) | |||
| : EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_WAIT, event_id) {} | |||
| ~EventWaitTaskInfo() override {} | |||
| }; | |||
| class FusionStartTaskInfo : public TaskInfo { | |||
| public: | |||
| explicit FusionStartTaskInfo(const std::string &op_name, uint32_t stream_id) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::FUSION_START, false) {} | |||
| ~FusionStartTaskInfo() override {} | |||
| }; | |||
| class FusionEndTaskInfo : public TaskInfo { | |||
| public: | |||
| explicit FusionEndTaskInfo(const std::string &op_name, uint32_t stream_id) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::FUSION_END, false) {} | |||
| ~FusionEndTaskInfo() override {} | |||
| }; | |||
| class HcclTaskInfo : public TaskInfo { | |||
| public: | |||
| HcclTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string &hccl_type, void *input_data_addr, | |||
| void *output_data_addr, void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num, | |||
| const std::vector<uint8_t> &private_def, void *ops_kernel_store, int32_t count, int64_t root_id, | |||
| int64_t op_type, int64_t data_type, const std::string &group, bool dump_flag) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::HCCL, dump_flag), | |||
| hccl_type_(hccl_type), | |||
| input_data_addr_(input_data_addr), | |||
| output_data_addr_(output_data_addr), | |||
| workspace_addr_(workspace_addr), | |||
| workspace_size_(workspace_size), | |||
| hccl_stream_num_(hccl_stream_num), | |||
| private_def_(private_def), | |||
| ops_kernel_store_(ops_kernel_store), | |||
| count_(count), | |||
| root_id_(root_id), | |||
| op_type_(op_type), | |||
| data_type_(data_type), | |||
| group_(group) {} | |||
| ~HcclTaskInfo() override {} | |||
| const std::string &hccl_type() const { return hccl_type_; } | |||
| void *input_data_addr() const { return input_data_addr_; } | |||
| void *output_data_addr() const { return output_data_addr_; } | |||
| void *workspace_addr() const { return workspace_addr_; } | |||
| int64_t workspace_size() const { return workspace_size_; } | |||
| int64_t hccl_stream_num() const { return hccl_stream_num_; } | |||
| const std::vector<uint8_t> &private_def() const { return private_def_; } | |||
| void *ops_kernel_store() const { return ops_kernel_store_; } | |||
| int32_t count() const { return count_; } | |||
| int64_t root_id() const { return root_id_; } | |||
| int64_t op_type() const { return op_type_; } | |||
| int64_t data_type() const { return data_type_; } | |||
| const std::string &group() const { return group_; } | |||
| private: | |||
| std::string hccl_type_; | |||
| void *input_data_addr_; | |||
| void *output_data_addr_; | |||
| void *workspace_addr_; | |||
| int64_t workspace_size_; | |||
| int64_t hccl_stream_num_; | |||
| std::vector<uint8_t> private_def_; | |||
| void *ops_kernel_store_; | |||
| int32_t count_; | |||
| int64_t root_id_; | |||
| int64_t op_type_; | |||
| int64_t data_type_; | |||
| std::string group_; | |||
| }; | |||
| class ProfilerTraceTaskInfo : public TaskInfo { | |||
| public: | |||
| ProfilerTraceTaskInfo(const std::string &op_name, uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::PROFILER_TRACE, false), | |||
| log_id_(log_id), | |||
| notify_(notify), | |||
| flat_(flat) {} | |||
| ~ProfilerTraceTaskInfo() override {} | |||
| uint64_t log_id() const { return log_id_; } | |||
| bool notify() const { return notify_; } | |||
| uint32_t flat() const { return flat_; } | |||
| private: | |||
| uint64_t log_id_; | |||
| bool notify_; | |||
| uint32_t flat_; | |||
| }; | |||
| class MemcpyAsyncTaskInfo : public TaskInfo { | |||
| public: | |||
| MemcpyAsyncTaskInfo(const std::string &op_name, uint32_t stream_id, void *dst, uint64_t dst_max, void *src, | |||
| uint64_t count, uint32_t kind, bool dump_flag) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::MEMCPY_ASYNC, dump_flag), | |||
| dst_(dst), | |||
| dst_max_(dst_max), | |||
| src_(src), | |||
| count_(count), | |||
| kind_(kind) {} | |||
| ~MemcpyAsyncTaskInfo() override {} | |||
| void *dst() const { return dst_; } | |||
| uint64_t dst_max() const { return dst_max_; } | |||
| void *src() const { return src_; } | |||
| uint64_t count() const { return count_; } | |||
| uint32_t kind() const { return kind_; } | |||
| private: | |||
| void *dst_; | |||
| uint64_t dst_max_; | |||
| void *src_; | |||
| uint64_t count_; | |||
| uint32_t kind_; | |||
| }; | |||
| class StreamSwitchTaskInfo : public TaskInfo { | |||
| public: | |||
| StreamSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, int64_t true_stream_id, void *input_addr, | |||
| void *value_addr, int64_t cond, int64_t data_type) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::STREAM_SWITCH, false), | |||
| true_stream_id_(true_stream_id), | |||
| input_addr_(input_addr), | |||
| value_addr_(value_addr), | |||
| cond_(cond), | |||
| data_type_(data_type) {} | |||
| ~StreamSwitchTaskInfo() override {} | |||
| int64_t true_stream_id() const { return true_stream_id_; } | |||
| void *input_addr() const { return input_addr_; } | |||
| void *value_addr() const { return value_addr_; } | |||
| int64_t cond() const { return cond_; } | |||
| int64_t data_type() const { return data_type_; } | |||
| private: | |||
| int64_t true_stream_id_; | |||
| void *input_addr_; | |||
| void *value_addr_; | |||
| int64_t cond_; | |||
| int64_t data_type_; | |||
| }; | |||
| class StreamActiveTaskInfo : public TaskInfo { | |||
| public: | |||
| StreamActiveTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t active_stream_id) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::STREAM_ACTIVE, false), active_stream_id_(active_stream_id) {} | |||
| ~StreamActiveTaskInfo() override {} | |||
| uint32_t active_stream_id() const { return active_stream_id_; } | |||
| private: | |||
| uint32_t active_stream_id_; | |||
| }; | |||
| } // namespace mindspore::ge::model_runner | |||
| #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_INFO_H_ | |||
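Each TaskInfo subclass above is a plain value object that carries exactly the parameters its runtime call needs; the stream_id stored in the base class is later used by the matching task to index into the ModelContext stream list. A small sketch of assembling one of these infos (values are illustrative; MakeStreamActiveInfo is a hypothetical helper, and the full create-and-distribute flow is exercised by the unit tests further below):

// Illustrative only: build a StreamActiveTaskInfo ready to hand to TaskFactory::Create.
std::shared_ptr<TaskInfo> MakeStreamActiveInfo() {
  // stream_id 0 is the stream the task itself is issued on;
  // active_stream_id 1 is the stream it activates.
  return std::make_shared<StreamActiveTaskInfo>("stream_active_op", /*stream_id=*/0, /*active_stream_id=*/1);
}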
| @@ -25,7 +25,7 @@ | |||
| #include "runtime/device/kernel_runtime.h" | |||
| #include "ir/anf.h" | |||
| #include "backend/kernel_compiler/ascend_kernel_mod.h" | |||
| #include "framework/ge_runtime/task_info.h" | |||
| #include "runtime/device/ascend/ge_runtime/task_info.h" | |||
| namespace mindspore { | |||
| namespace device { | |||
| @@ -134,6 +134,7 @@ enum SubModuleId : int { | |||
| SM_HCCL_ADPT, // Hccl Adapter | |||
| SM_MINDQUANTUM, // MindQuantum | |||
| SM_RUNTIME_FRAMEWORK, // Runtime framework | |||
| SM_GE, // GraphEngine | |||
| NUM_SUBMODUES // number of submodules | |||
| }; | |||
| @@ -142,34 +143,35 @@ enum SubModuleId : int { | |||
| #endif | |||
| static const char *SUB_MODULE_NAMES[NUM_SUBMODUES] = { | |||
| "UNKNOWN", // SM_UNKNOWN | |||
| "CORE", // SM_CORE | |||
| "ANALYZER", // SM_ANALYZER | |||
| "COMMON", // SM_COMMON | |||
| "DEBUG", // SM_DEBUG | |||
| "OFFLINE_DEBUG", // SM_OFFLINE_DEBUG | |||
| "DEVICE", // SM_DEVICE | |||
| "GE_ADPT", // SM_GE_ADPT | |||
| "IR", // SM_IR | |||
| "KERNEL", // SM_KERNEL | |||
| "MD", // SM_MD | |||
| "ME", // SM_ME | |||
| "EXPRESS", // SM_EXPRESS | |||
| "OPTIMIZER", // SM_OPTIMIZER | |||
| "PARALLEL", // SM_PARALLEL | |||
| "PARSER", // SM_PARSER | |||
| "PIPELINE", // SM_PIPELINE | |||
| "PRE_ACT", // SM_PRE_ACT | |||
| "PYNATIVE", // SM_PYNATIVE | |||
| "SESSION", // SM_SESSION | |||
| "UTILS", // SM_UTILS | |||
| "VM", // SM_VM | |||
| "PROFILER", // SM_PROFILER | |||
| "PS", // SM_PS | |||
| "LITE", // SM_LITE | |||
| "HCCL_ADPT", // SM_HCCL_ADPT | |||
| "MINDQUANTUM", // SM_MINDQUANTUM | |||
| "RUNTIME_FRAMEWORK" // SM_RUNTIME_FRAMEWORK | |||
| "UNKNOWN", // SM_UNKNOWN | |||
| "CORE", // SM_CORE | |||
| "ANALYZER", // SM_ANALYZER | |||
| "COMMON", // SM_COMMON | |||
| "DEBUG", // SM_DEBUG | |||
| "OFFLINE_DEBUG", // SM_OFFLINE_DEBUG | |||
| "DEVICE", // SM_DEVICE | |||
| "GE_ADPT", // SM_GE_ADPT | |||
| "IR", // SM_IR | |||
| "KERNEL", // SM_KERNEL | |||
| "MD", // SM_MD | |||
| "ME", // SM_ME | |||
| "EXPRESS", // SM_EXPRESS | |||
| "OPTIMIZER", // SM_OPTIMIZER | |||
| "PARALLEL", // SM_PARALLEL | |||
| "PARSER", // SM_PARSER | |||
| "PIPELINE", // SM_PIPELINE | |||
| "PRE_ACT", // SM_PRE_ACT | |||
| "PYNATIVE", // SM_PYNATIVE | |||
| "SESSION", // SM_SESSION | |||
| "UTILS", // SM_UTILS | |||
| "VM", // SM_VM | |||
| "PROFILER", // SM_PROFILER | |||
| "PS", // SM_PS | |||
| "LITE", // SM_LITE | |||
| "HCCL_ADPT", // SM_HCCL_ADPT | |||
| "MINDQUANTUM", // SM_MINDQUANTUM | |||
| "RUNTIME_FRAMEWORK", // SM_RUNTIME_FRAMEWORK | |||
| "GE", // SM_GE | |||
| }; | |||
| #if defined(_WIN32) || defined(_WIN64) | |||
| @@ -23,6 +23,7 @@ include_directories(${CMAKE_SOURCE_DIR}/mindspore/core) | |||
| include_directories(${CMAKE_CURRENT_SOURCE_DIR}) | |||
| include_directories(${CMAKE_CURRENT_SOURCE_DIR}/stub/runtime/) | |||
| include_directories(${CMAKE_BINARY_DIR}) | |||
| include_directories(${CMAKE_BINARY_DIR}/proto/ge) | |||
| include_directories(${CUDA_INCLUDE_DIRS}) | |||
| MESSAGE("check ut_test ${CMAKE_BINARY_DIR}") | |||
| @@ -103,6 +104,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| "../../../mindspore/ccsrc/runtime/device/bucket.cc" | |||
| "../../../mindspore/ccsrc/runtime/device/launch_kernel.cc" | |||
| "../../../mindspore/ccsrc/runtime/device/ascend/profiling/*.cc" | |||
| "../../../mindspore/ccsrc/runtime/device/ascend/ge_runtime/*.cc" | |||
| "../../../mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc" | |||
| "../../../mindspore/ccsrc/runtime/device/ascend/ascend_launch_kernel.cc" | |||
| "../../../mindspore/ccsrc/runtime/device/ascend/ascend_launch_mul.cc" | |||
| @@ -0,0 +1,473 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <memory> | |||
| #include "common/common_test.h" | |||
| #define private public | |||
| #include "runtime/device/ascend/ge_runtime/model_runner.h" | |||
| #include "runtime/device/ascend/ge_runtime/runtime_model.h" | |||
| #include "runtime/device/ascend/ge_runtime/task/task_factory.h" | |||
| #include "runtime/device/ascend/ge_runtime/task/aicpu_task.h" | |||
| #include "runtime/device/ascend/ge_runtime/task/event_record_task.h" | |||
| #include "runtime/device/ascend/ge_runtime/task/event_wait_task.h" | |||
| #include "runtime/device/ascend/ge_runtime/task/hccl_task.h" | |||
| #include "runtime/device/ascend/ge_runtime/task/label_goto_task.h" | |||
| #include "runtime/device/ascend/ge_runtime/task/label_manager.h" | |||
| #include "runtime/device/ascend/ge_runtime/task/label_set_task.h" | |||
| #include "runtime/device/ascend/ge_runtime/task/label_switch_task.h" | |||
| #include "runtime/device/ascend/ge_runtime/task/memcpy_async_task.h" | |||
| #include "runtime/device/ascend/ge_runtime/task/profiler_task.h" | |||
| #include "runtime/device/ascend/ge_runtime/task/stream_active_task.h" | |||
| #include "runtime/device/ascend/ge_runtime/task/stream_switch_task.h" | |||
| #include "runtime/device/ascend/ge_runtime/task/tbe_task.h" | |||
| #undef private | |||
| #include "common/opskernel/ops_kernel_info_store.h" | |||
| using namespace mindspore::ge::model_runner; | |||
| using namespace testing; | |||
| class MockOpsKernelInfoStore : public ge::OpsKernelInfoStore { | |||
| public: | |||
| ge::Status Initialize(const std::map<std::string, std::string> &) override { return ge::SUCCESS; } | |||
| ge::Status Finalize() override { return ge::SUCCESS; } | |||
| void GetAllOpsKernelInfo(std::map<std::string, ge::OpInfo> &infos) const override {} | |||
| bool CheckSupported(const ge::OpDescPtr &opDescPtr, std::string &un_supported_reason) const override { return true; } | |||
| ge::Status LoadTask(ge::GETaskInfo &task) override { return ge::SUCCESS; } | |||
| }; | |||
| namespace mindspore { | |||
| class TestAscendGeRuntime : public UT::Common { | |||
| public: | |||
| TestAscendGeRuntime() {} | |||
| private: | |||
| void TearDown() override { | |||
| { | |||
| std::lock_guard<std::mutex> lock(HcclTask::model_stream_mapping_mutex_); | |||
| HcclTask::model_stream_mapping_.clear(); | |||
| } | |||
| } | |||
| }; | |||
| TEST_F(TestAscendGeRuntime, test_task_create_null_task_info_failed) { | |||
| ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), | |||
| {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(1)}, | |||
| {reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)}); | |||
| ASSERT_TRUE(TaskFactory::GetInstance().Create(model_context, nullptr) == nullptr); | |||
| } | |||
| TEST_F(TestAscendGeRuntime, test_aicpu_task_create_one_stream_success) { | |||
| ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), | |||
| {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)}, | |||
| {reinterpret_cast<rtEvent_t>(1)}); | |||
| std::shared_ptr<TaskInfo> aicpu_task_info = std::make_shared<AicpuTaskInfo>( | |||
| "op_name", 0, "so_name", "kernel_name", "node_def", "ext_info", std::vector<void *>{reinterpret_cast<void *>(1)}, | |||
| std::vector<void *>{reinterpret_cast<void *>(1)}, true); | |||
| std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, aicpu_task_info); | |||
| ASSERT_TRUE(std::dynamic_pointer_cast<AicpuTask>(task) != nullptr); | |||
| ASSERT_NO_THROW(task->Distribute()); | |||
| } | |||
| TEST_F(TestAscendGeRuntime, test_aicpu_task_create_multi_stream_success) { | |||
| ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), | |||
| {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(1)}, | |||
| {reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)}); | |||
| std::shared_ptr<TaskInfo> aicpu_task_info = std::make_shared<AicpuTaskInfo>( | |||
| "op_name", 0, "so_name", "kernel_name", "node_def", "", std::vector<void *>{reinterpret_cast<void *>(1)}, | |||
| std::vector<void *>{reinterpret_cast<void *>(1)}, true); | |||
| std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, aicpu_task_info); | |||
| ASSERT_TRUE(std::dynamic_pointer_cast<AicpuTask>(task) != nullptr); | |||
| ASSERT_NO_THROW(task->Distribute()); | |||
| } | |||
| TEST_F(TestAscendGeRuntime, test_aicpu_task_create_invalid_stream_id_failed) { | |||
| ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), | |||
| {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(1)}, | |||
| {reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)}); | |||
| std::shared_ptr<TaskInfo> aicpu_task_info = std::make_shared<AicpuTaskInfo>( | |||
| "op_name", 5, "so_name", "kernel_name", "node_def", "", std::vector<void *>{reinterpret_cast<void *>(1)}, | |||
| std::vector<void *>{reinterpret_cast<void *>(1)}, true); | |||
| ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, aicpu_task_info)); | |||
| } | |||
| TEST_F(TestAscendGeRuntime, test_event_record_task_create_success) { | |||
| ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), | |||
| {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)}, | |||
| {reinterpret_cast<rtEvent_t>(1)}); | |||
| std::shared_ptr<TaskInfo> event_record_task_info = std::make_shared<EventRecordTaskInfo>("op_name", 0, 0); | |||
| std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, event_record_task_info); | |||
| ASSERT_TRUE(std::dynamic_pointer_cast<EventRecordTask>(task) != nullptr); | |||
| ASSERT_NO_THROW(task->Distribute()); | |||
| } | |||
| TEST_F(TestAscendGeRuntime, test_event_record_task_create_invalid_event_id_failed) { | |||
| ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), | |||
| {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)}, | |||
| {reinterpret_cast<rtEvent_t>(1)}); | |||
| std::shared_ptr<TaskInfo> event_record_task_info = std::make_shared<EventRecordTaskInfo>("op_name", 0, 10); | |||
| ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, event_record_task_info)); | |||
| } | |||
| TEST_F(TestAscendGeRuntime, test_event_wait_task_create_success) { | |||
| ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), | |||
| {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)}, | |||
| {reinterpret_cast<rtEvent_t>(1)}); | |||
| std::shared_ptr<TaskInfo> event_record_task_info = std::make_shared<EventWaitTaskInfo>("op_name", 0, 0); | |||
| std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, event_record_task_info); | |||
| ASSERT_TRUE(std::dynamic_pointer_cast<EventWaitTask>(task) != nullptr); | |||
| ASSERT_NO_THROW(task->Distribute()); | |||
| } | |||
| TEST_F(TestAscendGeRuntime, test_event_wait_task_create_invalid_event_id_failed) { | |||
| ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), | |||
| {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)}, | |||
| {reinterpret_cast<rtEvent_t>(1)}); | |||
| std::shared_ptr<TaskInfo> event_record_task_info = std::make_shared<EventWaitTaskInfo>("op_name", 0, 10); | |||
| ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, event_record_task_info)); | |||
| } | |||
| TEST_F(TestAscendGeRuntime, test_hccl_task_create_success) { | |||
| MockOpsKernelInfoStore ops_kernel_info_store; | |||
| ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), | |||
| {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)}, | |||
| {reinterpret_cast<rtEvent_t>(1)}); | |||
| std::shared_ptr<TaskInfo> hccl_task_info = std::make_shared<HcclTaskInfo>( | |||
| "op_name", 0, "hccl_type", reinterpret_cast<void *>(1), reinterpret_cast<void *>(2), reinterpret_cast<void *>(3), 4, | |||
| 5, std::vector<uint8_t>(6, 7), reinterpret_cast<void *>(&ops_kernel_info_store), 9, 10, 11, 12, "group", true); | |||
| std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, hccl_task_info); | |||
| ASSERT_TRUE(std::dynamic_pointer_cast<HcclTask>(task) != nullptr); | |||
| ASSERT_NO_THROW(task->Distribute()); | |||
| } | |||
| TEST_F(TestAscendGeRuntime, test_hccl_task_create_stream_reuse_success) { | |||
| const rtModel_t model = reinterpret_cast<rtModel_t>(0x12345678); | |||
| const rtStream_t stream = reinterpret_cast<rtStream_t>(0x87654321); | |||
| constexpr uint32_t stream_id = 0; | |||
| constexpr int64_t task1_stream_num = 3; | |||
| constexpr int64_t task2_stream_num = 5; | |||
| constexpr int64_t task3_stream_num = 4; | |||
| MockOpsKernelInfoStore ops_kernel_info_store; | |||
| ModelContext model_context(0, 0, 0, model, reinterpret_cast<rtStream_t>(2), {stream}, | |||
| {reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)}); | |||
| std::shared_ptr<TaskInfo> hccl_task_info_1 = std::make_shared<HcclTaskInfo>( | |||
| "op_name", stream_id, "hccl_type", reinterpret_cast<void *>(1), reinterpret_cast<void *>(2), | |||
| reinterpret_cast<void *>(3), 4, task1_stream_num, std::vector<uint8_t>(6, 7), | |||
| reinterpret_cast<void *>(&ops_kernel_info_store), 9, 10, 11, 12, "group", true); | |||
| std::shared_ptr<TaskInfo> hccl_task_info_2 = std::make_shared<HcclTaskInfo>( | |||
| "op_name", stream_id, "hccl_type", reinterpret_cast<void *>(1), reinterpret_cast<void *>(2), | |||
| reinterpret_cast<void *>(3), 4, task2_stream_num, std::vector<uint8_t>(6, 7), | |||
| reinterpret_cast<void *>(&ops_kernel_info_store), 9, 10, 11, 12, "group", true); | |||
| std::shared_ptr<TaskInfo> hccl_task_info_3 = std::make_shared<HcclTaskInfo>( | |||
| "op_name", stream_id, "hccl_type", reinterpret_cast<void *>(1), reinterpret_cast<void *>(2), | |||
| reinterpret_cast<void *>(3), 4, task3_stream_num, std::vector<uint8_t>(6, 7), | |||
| reinterpret_cast<void *>(&ops_kernel_info_store), 9, 10, 11, 12, "group", true); | |||
| std::shared_ptr<Task> task_1 = TaskFactory::GetInstance().Create(model_context, hccl_task_info_1); | |||
| std::shared_ptr<Task> task_2 = TaskFactory::GetInstance().Create(model_context, hccl_task_info_2); | |||
| std::shared_ptr<Task> task_3 = TaskFactory::GetInstance().Create(model_context, hccl_task_info_3); | |||
| ASSERT_TRUE(std::dynamic_pointer_cast<HcclTask>(task_1) != nullptr); | |||
| ASSERT_TRUE(std::dynamic_pointer_cast<HcclTask>(task_2) != nullptr); | |||
| ASSERT_TRUE(std::dynamic_pointer_cast<HcclTask>(task_3) != nullptr); | |||
| ASSERT_NO_THROW(task_1->Distribute()); | |||
| ASSERT_NO_THROW(task_2->Distribute()); | |||
| ASSERT_NO_THROW(task_3->Distribute()); | |||
| { | |||
| std::lock_guard<std::mutex> lock(HcclTask::model_stream_mapping_mutex_); | |||
| auto model_iter = HcclTask::model_stream_mapping_.find(model); | |||
| ASSERT_NE(model_iter, HcclTask::model_stream_mapping_.end()); | |||
| auto stream_iter = model_iter->second.find(stream_id); | |||
| ASSERT_NE(stream_iter, model_iter->second.end()); | |||
| const auto &stream_vec = stream_iter->second; | |||
| ASSERT_EQ(stream_vec.size(), std::max(task1_stream_num, std::max(task2_stream_num, task3_stream_num))); | |||
| for (const auto &s : stream_vec) { | |||
| auto shared = s.lock(); | |||
| ASSERT_TRUE(shared != nullptr); | |||
| } | |||
| } | |||
| } | |||
| TEST_F(TestAscendGeRuntime, test_label_goto_task_create_success) { | |||
| ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), | |||
| {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)}, | |||
| {reinterpret_cast<rtEvent_t>(1)}); | |||
| std::shared_ptr<TaskInfo> label_goto_task_info = std::make_shared<LabelGotoTaskInfo>("op_name", 0, 0); | |||
| std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, label_goto_task_info); | |||
| auto label_goto_task = std::dynamic_pointer_cast<LabelGotoTask>(task); | |||
| ASSERT_TRUE(label_goto_task != nullptr); | |||
| ASSERT_NO_THROW(task->Distribute()); | |||
| label_goto_task->index_value_ = new uint8_t[5]; | |||
| } | |||
| TEST_F(TestAscendGeRuntime, test_label_goto_task_create_invalid_label_id_failed) { | |||
| ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), | |||
| {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)}, | |||
| {reinterpret_cast<rtEvent_t>(1)}); | |||
| std::shared_ptr<TaskInfo> label_goto_task_info = std::make_shared<LabelGotoTaskInfo>("op_name", 0, 1); | |||
| ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, label_goto_task_info)); | |||
| } | |||
| TEST_F(TestAscendGeRuntime, test_label_goto_task_reuse_success) { | |||
| ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), | |||
| {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)}, | |||
| {reinterpret_cast<rtEvent_t>(1)}); | |||
| std::shared_ptr<TaskInfo> label_goto_task_info = std::make_shared<LabelGotoTaskInfo>("op_name", 0, 0); | |||
| std::shared_ptr<Task> task1 = TaskFactory::GetInstance().Create(model_context, label_goto_task_info); | |||
| std::shared_ptr<Task> task2 = TaskFactory::GetInstance().Create(model_context, label_goto_task_info); | |||
| auto label_goto_task_1 = std::dynamic_pointer_cast<LabelGotoTask>(task1); | |||
| auto label_goto_task_2 = std::dynamic_pointer_cast<LabelGotoTask>(task2); | |||
| ASSERT_TRUE(label_goto_task_1 != nullptr); | |||
| ASSERT_NO_THROW(task1->Distribute()); | |||
| ASSERT_TRUE(label_goto_task_2 != nullptr); | |||
| ASSERT_NO_THROW(task2->Distribute()); | |||
| ASSERT_EQ(label_goto_task_1->label_info_, label_goto_task_2->label_info_); | |||
| } | |||
| TEST_F(TestAscendGeRuntime, test_label_set_task_create_success) { | |||
| ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), | |||
| {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)}, | |||
| {reinterpret_cast<rtEvent_t>(1)}); | |||
| std::shared_ptr<TaskInfo> label_set_task_info = std::make_shared<LabelSetTaskInfo>("op_name", 0, 0); | |||
| std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, label_set_task_info); | |||
| ASSERT_TRUE(std::dynamic_pointer_cast<LabelSetTask>(task) != nullptr); | |||
| ASSERT_NO_THROW(task->Distribute()); | |||
| } | |||
| TEST_F(TestAscendGeRuntime, test_label_set_task_create_invalid_label_id_failed) { | |||
| ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), | |||
| {reinterpret_cast<rtStream_t>(1)}, {reinterpret_cast<rtLabel_t>(1)}, | |||
| {reinterpret_cast<rtEvent_t>(1)}); | |||
| std::shared_ptr<TaskInfo> label_set_task_info = std::make_shared<LabelGotoTaskInfo>("op_name", 0, 1); | |||
| ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, label_set_task_info)); | |||
| } | |||
| TEST_F(TestAscendGeRuntime, test_label_switch_task_create_success) { | |||
| ModelContext model_context( | |||
| 0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)}, | |||
| {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)}); | |||
| std::shared_ptr<TaskInfo> label_switch_task_info = | |||
| std::make_shared<LabelSwitchTaskInfo>("op_name", 0, 2, std::vector<uint32_t>{0, 1}, reinterpret_cast<void *>(1)); | |||
| std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, label_switch_task_info); | |||
| ASSERT_TRUE(std::dynamic_pointer_cast<LabelSwitchTask>(task) != nullptr); | |||
| ASSERT_NO_THROW(task->Distribute()); | |||
| } | |||
| TEST_F(TestAscendGeRuntime, test_label_switch_task_create_invalid_stream_id_failed) { | |||
| ModelContext model_context( | |||
| 0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)}, | |||
| {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)}); | |||
| std::shared_ptr<TaskInfo> label_switch_task_info = | |||
| std::make_shared<LabelSwitchTaskInfo>("op_name", 1, 2, std::vector<uint32_t>{0, 1}, reinterpret_cast<void *>(1)); | |||
| ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, label_switch_task_info)); | |||
| } | |||
| TEST_F(TestAscendGeRuntime, test_label_switch_task_create_invalid_label_id_failed) { | |||
| ModelContext model_context( | |||
| 0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)}, | |||
| {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)}); | |||
| std::shared_ptr<TaskInfo> label_switch_task_info = | |||
| std::make_shared<LabelSwitchTaskInfo>("op_name", 0, 3, std::vector<uint32_t>{0, 1, 2}, reinterpret_cast<void *>(1)); | |||
| ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, label_switch_task_info)); | |||
| } | |||
| TEST_F(TestAscendGeRuntime, test_label_switch_task_reuse_success) { | |||
| ModelContext model_context( | |||
| 0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)}, | |||
| {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)}); | |||
| std::shared_ptr<TaskInfo> label_switch_task_info = | |||
| std::make_shared<LabelSwitchTaskInfo>("op_name", 0, 2, std::vector<uint32_t>{0, 1}, reinterpret_cast<void *>(1)); | |||
| std::shared_ptr<Task> task1 = TaskFactory::GetInstance().Create(model_context, label_switch_task_info); | |||
| std::shared_ptr<Task> task2 = TaskFactory::GetInstance().Create(model_context, label_switch_task_info); | |||
| auto label_switch_task_1 = std::dynamic_pointer_cast<LabelSwitchTask>(task1); | |||
| auto label_switch_task_2 = std::dynamic_pointer_cast<LabelSwitchTask>(task2); | |||
| ASSERT_TRUE(label_switch_task_1 != nullptr); | |||
| ASSERT_TRUE(label_switch_task_2 != nullptr); | |||
| ASSERT_NO_THROW(task1->Distribute()); | |||
| ASSERT_NO_THROW(task2->Distribute()); | |||
| ASSERT_EQ(label_switch_task_1->label_info_, label_switch_task_2->label_info_); | |||
| } | |||
| TEST_F(TestAscendGeRuntime, test_memcpy_async_task_create_success) { | |||
| ModelContext model_context( | |||
| 0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)}, | |||
| {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)}); | |||
| std::shared_ptr<TaskInfo> memcpy_task_info = std::make_shared<MemcpyAsyncTaskInfo>( | |||
| "op_name", 0, reinterpret_cast<void *>(1), 2, reinterpret_cast<void *>(3), 4, 5, true); | |||
| std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, memcpy_task_info); | |||
| ASSERT_TRUE(std::dynamic_pointer_cast<MemcpyAsyncTask>(task) != nullptr); | |||
| ASSERT_NO_THROW(task->Distribute()); | |||
| } | |||
| TEST_F(TestAscendGeRuntime, test_memcpy_async_task_create_invalid_stream_id_failed) { | |||
| ModelContext model_context( | |||
| 0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)}, | |||
| {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)}); | |||
| std::shared_ptr<TaskInfo> memcpy_task_info = std::make_shared<MemcpyAsyncTaskInfo>( | |||
| "op_name", 1, reinterpret_cast<void *>(1), 2, reinterpret_cast<void *>(3), 4, 5, true); | |||
| ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, memcpy_task_info)); | |||
| } | |||
| TEST_F(TestAscendGeRuntime, test_profiler_task_create_success) { | |||
| ModelContext model_context( | |||
| 0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)}, | |||
| {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)}); | |||
| std::shared_ptr<TaskInfo> profiler_task_info = std::make_shared<ProfilerTraceTaskInfo>("op_name", 0, 1, true, 2); | |||
| std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, profiler_task_info); | |||
| ASSERT_TRUE(std::dynamic_pointer_cast<ProfilerTask>(task) != nullptr); | |||
| ASSERT_NO_THROW(task->Distribute()); | |||
| } | |||
| TEST_F(TestAscendGeRuntime, test_profiler_task_create_invalid_stream_id_failed) { | |||
| ModelContext model_context( | |||
| 0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), {reinterpret_cast<rtStream_t>(1)}, | |||
| {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, {reinterpret_cast<rtEvent_t>(1)}); | |||
| std::shared_ptr<TaskInfo> profiler_task_info = std::make_shared<ProfilerTraceTaskInfo>("op_name", 1, 1, true, 2); | |||
| ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, profiler_task_info)); | |||
| } | |||
| TEST_F(TestAscendGeRuntime, test_stream_active_task_create_success) { | |||
| ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), | |||
| {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)}, | |||
| {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, | |||
| {reinterpret_cast<rtEvent_t>(1)}); | |||
| std::shared_ptr<TaskInfo> stream_active_task_info = std::make_shared<StreamActiveTaskInfo>("op_name", 0, 1); | |||
| std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, stream_active_task_info); | |||
| ASSERT_TRUE(std::dynamic_pointer_cast<StreamActiveTask>(task) != nullptr); | |||
| ASSERT_NO_THROW(task->Distribute()); | |||
| } | |||
| TEST_F(TestAscendGeRuntime, test_stream_active_task_create_invalid_active_stream_id_failed) { | |||
| ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), | |||
| {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)}, | |||
| {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, | |||
| {reinterpret_cast<rtEvent_t>(1)}); | |||
| std::shared_ptr<TaskInfo> stream_active_task_info = std::make_shared<StreamActiveTaskInfo>("op_name", 0, 2); | |||
| ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, stream_active_task_info)); | |||
| } | |||
| TEST_F(TestAscendGeRuntime, test_stream_switch_task_create_success) { | |||
| ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), | |||
| {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)}, | |||
| {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, | |||
| {reinterpret_cast<rtEvent_t>(1)}); | |||
| std::shared_ptr<TaskInfo> stream_switch_task_info = std::make_shared<StreamSwitchTaskInfo>( | |||
| "op_name", 0, 1, reinterpret_cast<void *>(2), reinterpret_cast<void *>(3), 4, 5); | |||
| std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, stream_switch_task_info); | |||
| ASSERT_TRUE(std::dynamic_pointer_cast<StreamSwitchTask>(task) != nullptr); | |||
| ASSERT_NO_THROW(task->Distribute()); | |||
| } | |||
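| // With an out-of-range true-stream index the StreamSwitchTask is still created, but the error only surfaces when Distribute runs. | |||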
| TEST_F(TestAscendGeRuntime, test_stream_switch_task_create_invalid_true_stream_id_failed) { | |||
| ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), | |||
| {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)}, | |||
| {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, | |||
| {reinterpret_cast<rtEvent_t>(1)}); | |||
| std::shared_ptr<TaskInfo> stream_switch_task_info = std::make_shared<StreamSwitchTaskInfo>( | |||
| "op_name", 0, 2, reinterpret_cast<void *>(2), reinterpret_cast<void *>(3), 4, 5); | |||
| std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, stream_switch_task_info); | |||
| ASSERT_TRUE(std::dynamic_pointer_cast<StreamSwitchTask>(task) != nullptr); | |||
| ASSERT_ANY_THROW(task->Distribute()); | |||
| } | |||
| TEST_F(TestAscendGeRuntime, test_stream_switch_task_create_invalid_stream_id_failed) { | |||
| ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), | |||
| {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)}, | |||
| {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, | |||
| {reinterpret_cast<rtEvent_t>(1)}); | |||
| std::shared_ptr<TaskInfo> stream_switch_task_info = std::make_shared<StreamSwitchTaskInfo>( | |||
| "op_name", 2, 1, reinterpret_cast<void *>(2), reinterpret_cast<void *>(3), 4, 5); | |||
| ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, stream_switch_task_info)); | |||
| } | |||
| TEST_F(TestAscendGeRuntime, test_tbe_task_create_success) { | |||
| ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), | |||
| {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)}, | |||
| {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, | |||
| {reinterpret_cast<rtEvent_t>(1)}); | |||
| std::shared_ptr<TaskInfo> tbe_task_info = std::make_shared<TbeTaskInfo>( | |||
| "op_name", 0, "stub_func", 1, std::vector<uint8_t>(100, 2), 100, std::vector<uint8_t>{5, 6}, | |||
| reinterpret_cast<void *>(7), 8, std::vector<uint8_t>{9, 10}, | |||
| std::vector<void *>{reinterpret_cast<void *>(11), reinterpret_cast<void *>(12)}, | |||
| std::vector<void *>{reinterpret_cast<void *>(13), reinterpret_cast<void *>(14)}, | |||
| std::vector<void *>{reinterpret_cast<void *>(15), reinterpret_cast<void *>(16)}, true); | |||
| std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, tbe_task_info); | |||
| auto tbe_task = std::dynamic_pointer_cast<TbeTask>(task); | |||
| ASSERT_TRUE(tbe_task != nullptr); | |||
| ASSERT_NO_THROW(task->Distribute()); | |||
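| // Hand args_ a fresh heap buffer after Distribute; presumably this exercises the TbeTask destructor's cleanup of args_. | |||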
| tbe_task->args_ = new uint8_t[5]; | |||
| } | |||
| TEST_F(TestAscendGeRuntime, test_tbe_task_create_invalid_stream_id_failed) { | |||
| ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), | |||
| {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)}, | |||
| {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, | |||
| {reinterpret_cast<rtEvent_t>(1)}); | |||
| std::shared_ptr<TaskInfo> tbe_task_info = std::make_shared<TbeTaskInfo>( | |||
| "op_name", 3, "stub_func", 1, std::vector<uint8_t>(100, 2), 100, std::vector<uint8_t>{5, 6}, | |||
| reinterpret_cast<void *>(7), 8, std::vector<uint8_t>{9, 10}, | |||
| std::vector<void *>{reinterpret_cast<void *>(11), reinterpret_cast<void *>(12)}, | |||
| std::vector<void *>{reinterpret_cast<void *>(13), reinterpret_cast<void *>(14)}, | |||
| std::vector<void *>{reinterpret_cast<void *>(15), reinterpret_cast<void *>(16)}, true); | |||
| ASSERT_ANY_THROW(TaskFactory::GetInstance().Create(model_context, tbe_task_info)); | |||
| } | |||
| TEST_F(TestAscendGeRuntime, test_tbe_task_create_empty_stub_func_failed) { | |||
| ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), | |||
| {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)}, | |||
| {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, | |||
| {reinterpret_cast<rtEvent_t>(1)}); | |||
| std::shared_ptr<TaskInfo> tbe_task_info = std::make_shared<TbeTaskInfo>( | |||
| "op_name", 0, "", 1, std::vector<uint8_t>(100, 2), 100, std::vector<uint8_t>{5, 6}, reinterpret_cast<void *>(7), 8, | |||
| std::vector<uint8_t>{9, 10}, std::vector<void *>{reinterpret_cast<void *>(11), reinterpret_cast<void *>(12)}, | |||
| std::vector<void *>{reinterpret_cast<void *>(13), reinterpret_cast<void *>(14)}, | |||
| std::vector<void *>{reinterpret_cast<void *>(15), reinterpret_cast<void *>(16)}, true); | |||
| std::shared_ptr<Task> task = TaskFactory::GetInstance().Create(model_context, tbe_task_info); | |||
| ASSERT_TRUE(std::dynamic_pointer_cast<TbeTask>(task) != nullptr); | |||
| ASSERT_ANY_THROW(task->Distribute()); | |||
| } | |||
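| // End-to-end ModelRunner flow: load a DavinciModel holding a TBE task and an AICPU task, distribute and run it, query its task/stream/runtime-info lists, then unload. | |||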
| TEST_F(TestAscendGeRuntime, test_model_runner_success) { | |||
| constexpr uint32_t model_id = 0; | |||
| ModelContext model_context(0, 0, 0, reinterpret_cast<rtModel_t>(1), reinterpret_cast<rtStream_t>(2), | |||
| {reinterpret_cast<rtStream_t>(1), reinterpret_cast<rtStream_t>(2)}, | |||
| {reinterpret_cast<rtLabel_t>(1), reinterpret_cast<rtLabel_t>(1)}, | |||
| {reinterpret_cast<rtEvent_t>(1)}); | |||
| std::shared_ptr<TaskInfo> tbe_task_info = std::make_shared<TbeTaskInfo>( | |||
| "op_name", 0, "stub_func", 1, std::vector<uint8_t>(100, 2), 100, std::vector<uint8_t>{5, 6}, | |||
| reinterpret_cast<void *>(7), 8, std::vector<uint8_t>{9, 10}, | |||
| std::vector<void *>{reinterpret_cast<void *>(11), reinterpret_cast<void *>(12)}, | |||
| std::vector<void *>{reinterpret_cast<void *>(13), reinterpret_cast<void *>(14)}, | |||
| std::vector<void *>{reinterpret_cast<void *>(15), reinterpret_cast<void *>(16)}, true); | |||
| std::shared_ptr<TaskInfo> aicpu_task_info = std::make_shared<AicpuTaskInfo>( | |||
| "op_name", 0, "so_name", "kernel_name", "node_def", "ext_info", std::vector<void *>{reinterpret_cast<void *>(1)}, | |||
| std::vector<void *>{reinterpret_cast<void *>(1)}, true); | |||
| auto davinci_model = | |||
| std::make_shared<DavinciModel>(std::vector<std::shared_ptr<TaskInfo>>{tbe_task_info, aicpu_task_info}, | |||
| std::vector<uint32_t>{}, std::vector<uint32_t>{}, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0); | |||
| ASSERT_NO_THROW(ModelRunner::Instance().LoadDavinciModel(0, 0, model_id, davinci_model)); | |||
| auto iter = ModelRunner::Instance().runtime_models_.find(model_id); | |||
| ASSERT_TRUE(iter != ModelRunner::Instance().runtime_models_.end()); | |||
| auto &task_list = iter->second->task_list_; | |||
| task_list.clear(); | |||
| ASSERT_NO_THROW(task_list.emplace_back(TaskFactory::GetInstance().Create(model_context, tbe_task_info))); | |||
| ASSERT_NO_THROW(task_list.emplace_back(TaskFactory::GetInstance().Create(model_context, aicpu_task_info))); | |||
| ASSERT_NO_THROW(ModelRunner::Instance().DistributeTask(model_id)); | |||
| ASSERT_NO_THROW(ModelRunner::Instance().LoadModelComplete(model_id)); | |||
| ASSERT_NO_THROW(ModelRunner::Instance().RunModel(model_id)); | |||
| ASSERT_FALSE(ModelRunner::Instance().GetTaskIdList(model_id).empty()); | |||
| ASSERT_FALSE(ModelRunner::Instance().GetStreamIdList(model_id).empty()); | |||
| ASSERT_FALSE(ModelRunner::Instance().GetRuntimeInfoMap(model_id).empty()); | |||
| ASSERT_NO_THROW(ModelRunner::Instance().GetModelHandle(model_id)); | |||
| ASSERT_NO_THROW(ModelRunner::Instance().UnloadModel(model_id)); | |||
| } | |||
| } // namespace mindspore | |||
| @@ -14,51 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include <vector> | |||
| #include "framework/ge_runtime/model_runner.h" | |||
| #include "runtime/hccl_adapter/hccl_adapter.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| ModelRunner &ModelRunner::Instance() { | |||
| static ModelRunner runner; | |||
| return runner; | |||
| } | |||
| bool ModelRunner::LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint32_t model_id, | |||
| std::shared_ptr<DavinciModel> ascend_model, | |||
| std::shared_ptr<ge::ModelListener> listener) { | |||
| return true; | |||
| } | |||
| bool ModelRunner::UnloadModel(uint32_t model_id) { return true; } | |||
| bool ModelRunner::LoadModelComplete(uint32_t model_id) { return true; } | |||
| bool ModelRunner::RunModel(uint32_t model_id, const ge::InputData &input_data, ge::OutputData *output_data) { | |||
| return true; | |||
| } | |||
| void *ModelRunner::GetModelHandle(uint32_t model_id) const { return nullptr; } | |||
| bool ModelRunner::DistributeTask(uint32_t model_id) { return true; } | |||
| const std::vector<uint32_t> &ModelRunner::GetTaskIdList(uint32_t model_id) const { | |||
| static std::vector<uint32_t> task_id_list; | |||
| return task_id_list; | |||
| } | |||
| const std::vector<uint32_t> &ModelRunner::GetStreamIdList(uint32_t model_id) const { | |||
| static std::vector<uint32_t> stream_id_list; | |||
| return stream_id_list; | |||
| } | |||
| const std::map<std::string, std::shared_ptr<RuntimeInfo>> &ModelRunner::GetRuntimeInfoMap(uint32_t model_id) const { | |||
| static std::map<std::string, std::shared_ptr<RuntimeInfo>> runtime_info_map; | |||
| return runtime_info_map; | |||
| } | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| namespace mindspore { | |||
| namespace hccl { | |||
| bool InitHccl(uint32_t, std::string_view, std::string_view) { return true; } | |||
| @@ -141,9 +141,9 @@ rtError_t rtGetFunctionByName(const char *stubName, void **stubFunc) { return RT | |||
| rtError_t rtSetTaskGenCallback(rtTaskGenCallback callback) { return RT_ERROR_NONE; } | |||
| RTS_API rtError_t rtProfilerStart(uint64_t profConfig, int32_t numsDev, uint32_t* deviceList) { return RT_ERROR_NONE; } | |||
| RTS_API rtError_t rtProfilerStart(uint64_t profConfig, int32_t numsDev, uint32_t *deviceList) { return RT_ERROR_NONE; } | |||
| RTS_API rtError_t rtProfilerStop(uint64_t profConfig, int32_t numsDev, uint32_t* deviceList) { return RT_ERROR_NONE; } | |||
| RTS_API rtError_t rtProfilerStop(uint64_t profConfig, int32_t numsDev, uint32_t *deviceList) { return RT_ERROR_NONE; } | |||
| int AdxDataDumpServerInit() { return 0; } | |||
| @@ -151,11 +151,13 @@ int AdxDataDumpServerUnInit() { return 0; } | |||
| RTS_API rtError_t rtGetTaskIdAndStreamID(uint32_t *taskid, uint32_t *streamid) { return RT_ERROR_NONE; } | |||
| RTS_API rtError_t rtSetTaskFailCallback(rtTaskFailCallback callback) {return RT_ERROR_NONE; } | |||
| RTS_API rtError_t rtSetTaskFailCallback(rtTaskFailCallback callback) { return RT_ERROR_NONE; } | |||
| RTS_API rtError_t rtRegDeviceStateCallback(const char *regName, rtDeviceStateCallback callback) {return RT_ERROR_NONE; } | |||
| RTS_API rtError_t rtRegDeviceStateCallback(const char *regName, rtDeviceStateCallback callback) { | |||
| return RT_ERROR_NONE; | |||
| } | |||
| RTS_API rtError_t rtSetMsprofReporterCallback(MsprofReporterCallback callback) {return RT_ERROR_NONE; } | |||
| RTS_API rtError_t rtSetMsprofReporterCallback(MsprofReporterCallback callback) { return RT_ERROR_NONE; } | |||
| RTS_API rtError_t rtRegTaskFailCallbackByModule(const char *moduleName, rtTaskFailCallback callback) { | |||
| return RT_ERROR_NONE; | |||
| @@ -168,3 +170,28 @@ RTS_API rtError_t rtDevBinaryUnRegister(void *handle) { return RT_ERROR_NONE; } | |||
| RTS_API rtError_t rtMemsetAsync(void *ptr, uint64_t destMax, uint32_t value, uint64_t count, rtStream_t stream) { | |||
| return RT_ERROR_NONE; | |||
| } | |||
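| // Additional runtime API stubs, presumably needed by the relocated ge_runtime code under test; each simply returns RT_ERROR_NONE. | |||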
| RTS_API rtError_t rtLabelListCpy(rtLabel_t *label, uint32_t labelNumber, void *dst, uint32_t dstMax) { | |||
| return RT_ERROR_NONE; | |||
| } | |||
| RTS_API rtError_t rtModelGetTaskId(rtModel_t model, uint32_t *taskid, uint32_t *streamid) { return RT_ERROR_NONE; } | |||
| RTS_API rtError_t rtLabelCreateEx(rtLabel_t *label, rtStream_t stream) { return RT_ERROR_NONE; } | |||
| RTS_API rtError_t rtCpuKernelLaunchWithFlag(const void *soName, const void *kernelName, uint32_t blockDim, | |||
| const void *args, uint32_t argsSize, rtSmDesc_t *smDesc, rtStream_t stream, | |||
| uint32_t flags) { | |||
| return RT_ERROR_NONE; | |||
| } | |||
| RTS_API rtError_t rtLabelSwitchByIndex(void *ptr, uint32_t max, void *labelInfoPtr, rtStream_t stream) { | |||
| return RT_ERROR_NONE; | |||
| } | |||
| RTS_API rtError_t rtProfilerTrace(uint64_t id, bool notify, uint32_t flags, rtStream_t stream) { return RT_ERROR_NONE; } | |||
| RTS_API rtError_t rtKernelLaunchWithFlag(const void *stubFunc, uint32_t blockDim, void *args, uint32_t argsSize, | |||
| rtSmDesc_t *smDesc, rtStream_t stream, uint32_t flags) { | |||
| return RT_ERROR_NONE; | |||
| } | |||