| @@ -64,6 +64,7 @@ | |||
| #include "toolchain/adx_datadump_server.h" | |||
| #ifdef ENABLE_DUMP_IR | |||
| #include "debug/rdr/running_data_recorder.h" | |||
| #include "runtime/device/ascend/ascend_bucket.h" | |||
| #endif | |||
| #if ENABLE_CPU && ENABLE_D | |||
| #include "ps/util.h" | |||
| @@ -258,6 +259,7 @@ GraphId AscendSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNode | |||
| // Construct the graph; on success, graph_sum_ is incremented by 1 | |||
| auto graph = ConstructKernelGraph(lst, outputs); | |||
| auto graph_id = graph->graph_id(); | |||
| InitAllBucket(graph); | |||
| MS_LOG(INFO) << "Compile graph " << graph_id << " success"; | |||
| return graph_id; | |||
| } | |||
| @@ -632,6 +634,13 @@ void AscendSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_inf | |||
| MS_EXCEPTION_IF_NULL(op_run_info); | |||
| BuildOpImpl(*op_run_info, graph_info, *input_tensors, tensors_mask); | |||
| EraseValueNodeTensor(tensors_mask, input_tensors); | |||
| // Wait for any pending AllReduce on the input tensors to finish on device | |||
| for (auto &tensor : *input_tensors) { | |||
| if (tensor->NeedWaitDevice()) { | |||
| tensor->WaitDevice(); | |||
| } | |||
| } | |||
| // Run op | |||
| auto graph = run_op_graphs_[graph_info]; | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| @@ -1510,5 +1519,9 @@ void AscendSession::SyncStream() { | |||
| MS_LOG(EXCEPTION) << "Sync stream error!"; | |||
| } | |||
| } | |||
| std::shared_ptr<device::Bucket> AscendSession::CreateBucket(uint32_t bucket_id, uint32_t bucket_size) { | |||
| return std::make_shared<device::ascend::AscendBucket>(bucket_id, bucket_size); | |||
| } | |||
| } // namespace session | |||
| } // namespace mindspore | |||
| @@ -61,6 +61,7 @@ class AscendSession : public SessionBasic { | |||
| void BuildOpsInGraph(const GraphId &graph_id, const std::map<AnfNodePtr, size_t> ¶meter_index, | |||
| const std::vector<tensor::TensorPtr> &graph_inputs, | |||
| const std::map<KernelWithIndex, size_t> &cnode_refcount) override; | |||
| std::string GetCommWorldGroup() override { return kHcclWorldGroup; } | |||
| private: | |||
| // compile child graph when session have multiple child graphs | |||
| @@ -123,6 +124,7 @@ class AscendSession : public SessionBasic { | |||
| const std::vector<tensor::TensorPtr> &graph_inputs, | |||
| const std::map<KernelWithIndex, OutputTensorInfo> &node_output_info, | |||
| InputTensorInfo *input_tensor_info); | |||
| std::shared_ptr<device::Bucket> CreateBucket(uint32_t bucket_id, uint32_t bucket_size) override; | |||
| // key is final_graph_id,value is child graph execute order of final graph | |||
| std::unordered_map<GraphId, std::vector<GraphId>> graph_execute_orders_; | |||
| // key is final_graph_id,value is the graph types of child graphs | |||
| @@ -16,6 +16,7 @@ | |||
| #include "backend/session/gpu_session.h" | |||
| #include <string> | |||
| #include <utility> | |||
| #include "backend/optimizer/common/helper.h" | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| #include "backend/optimizer/common/pass_manager.h" | |||
| @@ -63,6 +64,7 @@ | |||
| #include "runtime/device/kernel_runtime_manager.h" | |||
| #include "runtime/device/gpu/cuda_driver.h" | |||
| #include "runtime/device/gpu/distribution/collective_init.h" | |||
| #include "runtime/device/gpu/gpu_bucket.h" | |||
| #include "utils/ms_utils.h" | |||
| #include "utils/config_manager.h" | |||
| #include "utils/ms_context.h" | |||
| @@ -394,6 +396,8 @@ GraphId GPUSession::CompileGraphImpl(KernelGraphPtr graph) { | |||
| manager->AddFuncGraph(graph); | |||
| graph->set_manager(manager); | |||
| } | |||
| InitAllBucket(graph); | |||
| // Alloc memory in graph mode, including static memory and dynamic memory | |||
| if (!pynative_mode) { | |||
| AllocateMemory(graph.get()); | |||
| @@ -473,6 +477,12 @@ void GPUSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, | |||
| MS_EXCEPTION_IF_NULL(op_run_info); | |||
| BuildOpImpl(*op_run_info, graph_info, *input_tensors, tensors_mask); | |||
| EraseValueNodeTensor(tensors_mask, input_tensors); | |||
| // Wait for any pending AllReduce on the input tensors to finish on device | |||
| for (auto &tensor : *input_tensors) { | |||
| if (tensor->NeedWaitDevice()) { | |||
| tensor->WaitDevice(); | |||
| } | |||
| } | |||
| // run op | |||
| auto kernel_graph = run_op_graphs_[graph_info]; | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| @@ -548,6 +558,10 @@ void GPUSession::SyncStream() { | |||
| MS_LOG(EXCEPTION) << "Sync stream error!"; | |||
| } | |||
| } | |||
| std::shared_ptr<device::Bucket> GPUSession::CreateBucket(uint32_t bucket_id, uint32_t bucket_size) { | |||
| return std::make_shared<device::gpu::GPUBucket>(bucket_id, bucket_size); | |||
| } | |||
| } // namespace gpu | |||
| } // namespace session | |||
| } // namespace mindspore | |||
| @@ -19,6 +19,7 @@ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <algorithm> | |||
| #include <string> | |||
| #include "backend/session/session_basic.h" | |||
| #include "backend/session/kernel_graph.h" | |||
| #include "backend/session/session_factory.h" | |||
| @@ -44,6 +45,8 @@ class GPUSession : public SessionBasic { | |||
| const std::vector<int64_t> &tensors_mask) override; | |||
| void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, std::vector<tensor::TensorPtr> *input_tensors, | |||
| VectorRef *outputs, const std::vector<int64_t> &tensors_mask) override; | |||
| std::shared_ptr<device::Bucket> CreateBucket(uint32_t bucket_id, uint32_t bucket_size) override; | |||
| std::string GetCommWorldGroup() override { return kNcclWorldGroup; } | |||
| private: | |||
| void SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | |||
| @@ -40,6 +40,7 @@ | |||
| #include "debug/anf_ir_dump.h" | |||
| #include "debug/common.h" | |||
| #include "utils/trace_base.h" | |||
| #include "frontend/parallel/context.h" | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| #include "ps/ps_cache/ps_cache_manager.h" | |||
| #include "ps/constants.h" | |||
| @@ -556,10 +557,12 @@ void HandleOpInputs(const std::set<KernelWithIndex> &input_kernel, std::map<Kern | |||
| void HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs, | |||
| const std::map<KernelWithIndex, std::vector<std::vector<size_t>>> &output_indexes, | |||
| const std::map<KernelWithIndex, size_t> &ref_count, | |||
| std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map, VectorRef *outputs) { | |||
| std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map, VectorRef *outputs, | |||
| std::vector<TensorPtr> *runop_output_tensors) { | |||
| MS_EXCEPTION_IF_NULL(kernel); | |||
| MS_EXCEPTION_IF_NULL(op_output_map); | |||
| MS_EXCEPTION_IF_NULL(outputs); | |||
| MS_EXCEPTION_IF_NULL(runop_output_tensors); | |||
| auto output_tensors = TransformVectorRefToMultiTensor(op_outputs); | |||
| if (output_tensors.size() > op_outputs.size()) { | |||
| MS_LOG(EXCEPTION) << "Op output contains tuple, node = " << kernel->DebugString(); | |||
| @@ -592,6 +595,7 @@ void HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs, | |||
| } | |||
| BaseRef &tensor_ref = (*const_cast<VectorRef *>(cur_vector_ref))[ref_indexes.at(n)]; | |||
| tensor_ref = output_tensor; | |||
| runop_output_tensors->emplace_back(output_tensor); | |||
| } | |||
| } | |||
| } | |||
| @@ -2196,6 +2200,11 @@ void SessionBasic::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector< | |||
| GetRefCount(kernel_graph.get(), &cnode_refcount); | |||
| BuildOpsInGraph(graph_id, parameter_index, inputs, cnode_refcount); | |||
| // Clear bucket resources every step | |||
| if (kernel_graph->is_bprop()) { | |||
| ClearAllBucket(graph_id); | |||
| } | |||
| std::map<KernelWithIndex, tensor::TensorPtr> op_output_map; | |||
| for (const auto &kernel : kernel_graph->execution_order()) { | |||
| // Generate input tensors, tensor masks and input kernel with index | |||
| @@ -2212,9 +2221,15 @@ void SessionBasic::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector< | |||
| RunOpImpl(graph_info, &run_info, &input_tensor_info.input_tensors, &op_outputs, | |||
| input_tensor_info.input_tensors_mask); | |||
| std::vector<tensor::TensorPtr> new_output_tensors; | |||
| // Handle inputs and outputs of current op | |||
| HandleOpInputs(input_tensor_info.input_kernel, &cnode_refcount, &op_output_map); | |||
| HandleOpOutputs(kernel, op_outputs, output_indexes, cnode_refcount, &op_output_map, outputs); | |||
| HandleOpOutputs(kernel, op_outputs, output_indexes, cnode_refcount, &op_output_map, outputs, &new_output_tensors); | |||
| // Save grad tensors to Bucket | |||
| if (kernel_graph->is_bprop()) { | |||
| AddGradAddrToBucket(graph_id, new_output_tensors); | |||
| } | |||
| } | |||
| MS_LOG(INFO) << "Finish!"; | |||
| } | |||
| @@ -2287,6 +2302,137 @@ void SessionBasic::RunOpHideNopNode(const KernelGraphPtr &kernel_graph) const { | |||
| } | |||
| } | |||
| std::vector<uint32_t> SessionBasic::GetAllReduceSplitIndex() { | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| std::string group = GetCommWorldGroup(); | |||
| auto parallel_context = parallel::ParallelContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(parallel_context); | |||
| // PyNative does not support multi-group AllReduce | |||
| group += "sum1"; | |||
| return parallel_context->GetAllReduceFusionSplitIndices(group); | |||
| } | |||
| uint32_t GetBpropGraphGradsCount(const KernelGraphPtr &graph) { | |||
| return AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem}).size(); | |||
| } | |||
| void SetGraphBpropAttr(const KernelGraphPtr &graph) { | |||
| auto &execution_orders = graph->execution_order(); | |||
| if (std::any_of(execution_orders.begin(), execution_orders.end(), | |||
| [](const AnfNodePtr &node) { return node->scope()->name().rfind("Gradient", 0) == 0; })) { | |||
| graph->set_is_bprop(true); | |||
| MS_LOG(INFO) << "Match bprop graph"; | |||
| } else { | |||
| graph->set_is_bprop(false); | |||
| } | |||
| } | |||
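| // Map the AllReduce fusion split indices onto per-bucket gradient counts. For example, with 10 grads and | |||
| // split_index = {3, 9} the bucket sizes are {4, 6}; an empty split_index yields one bucket holding all grads. | |||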
| std::vector<uint32_t> GenerateBucketSizeList(const KernelGraphPtr &graph, const std::vector<uint32_t> &split_index) { | |||
| if (split_index.empty()) { | |||
| auto grads_count = GetBpropGraphGradsCount(graph); | |||
| if (grads_count == 0) { | |||
| MS_LOG(EXCEPTION) << "Bprop graph has no grad"; | |||
| } | |||
| return {grads_count}; | |||
| } | |||
| std::vector<uint32_t> bucket_size_list; | |||
| uint32_t old_index = 0; | |||
| for (auto &index : split_index) { | |||
| if (old_index == 0) { | |||
| bucket_size_list.emplace_back(index - old_index + 1); | |||
| } else { | |||
| bucket_size_list.emplace_back(index - old_index); | |||
| } | |||
| old_index = index; | |||
| } | |||
| return bucket_size_list; | |||
| } | |||
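| // Create and initialize one bucket per fusion segment; graphs not recognized as bprop graphs are skipped. | |||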
| void SessionBasic::InitAllBucket(const KernelGraphPtr &graph) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_LOG(INFO) << "Init Bucket start, graph_id:" << graph->graph_id(); | |||
| SetGraphBpropAttr(graph); | |||
| if (!graph->is_bprop()) { | |||
| return; | |||
| } | |||
| std::vector<std::shared_ptr<device::Bucket>> bucket_list; | |||
| // Create a bucket for each AllReduce fusion segment | |||
| auto split_index = GetAllReduceSplitIndex(); | |||
| auto bucket_size_list = GenerateBucketSizeList(graph, split_index); | |||
| uint32_t bucket_id = 0; | |||
| for (auto bucket_size : bucket_size_list) { | |||
| MS_LOG(INFO) << "Create new bucket:" << bucket_id; | |||
| auto bucket = CreateBucket(bucket_id++, bucket_size); | |||
| bucket->Init(); | |||
| bucket_list.emplace_back(bucket); | |||
| } | |||
| auto bucket_ret = bucket_map_.try_emplace(graph->graph_id(), bucket_list); | |||
| if (!bucket_ret.second) { | |||
| MS_LOG(EXCEPTION) << "Duplicate bucket_map_ graph key:" << graph->graph_id(); | |||
| } | |||
| // Set the initial free bucket index to 0 | |||
| auto free_bucket_ret = free_bucket_id_map_.try_emplace(graph->graph_id(), 0); | |||
| if (!free_bucket_ret.second) { | |||
| MS_LOG(EXCEPTION) << "Duplicate free_bucket_id_map_ graph key:" << graph->graph_id(); | |||
| } | |||
| MS_LOG(INFO) << "Init Bucket finish"; | |||
| } | |||
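| // Append freshly produced grad tensors to the current free bucket (data-parallel mode only); once a bucket | |||
| // is full its fused AllReduce is launched and the next bucket becomes the free one. | |||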
| void SessionBasic::AddGradAddrToBucket(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &grad_tensor) { | |||
| auto parallel_context = parallel::ParallelContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(parallel_context); | |||
| auto parallel_mode = parallel_context->parallel_mode(); | |||
| if (parallel_mode != parallel::DATA_PARALLEL) { | |||
| return; | |||
| } | |||
| auto iter = bucket_map_.find(graph_id); | |||
| if (iter == bucket_map_.end()) { | |||
| MS_LOG(EXCEPTION) << "unknown graph id:" << graph_id; | |||
| } | |||
| auto &bucket_list = iter->second; | |||
| auto free_bucket_iter = free_bucket_id_map_.find(graph_id); | |||
| if (free_bucket_iter == free_bucket_id_map_.end()) { | |||
| MS_LOG(EXCEPTION) << "unknown free graph id:" << graph_id; | |||
| } | |||
| auto free_bucket_index = free_bucket_iter->second; | |||
| for (auto &tensor : grad_tensor) { | |||
| if (free_bucket_index >= bucket_list.size()) { | |||
| MS_LOG(EXCEPTION) << "Invalid free bucket id:" << free_bucket_iter->second | |||
| << " total bucket num:" << bucket_list.size(); | |||
| } | |||
| auto &free_bucket = bucket_list[free_bucket_index]; | |||
| free_bucket->AddGradTensor(tensor); | |||
| if (free_bucket->full()) { | |||
| MS_LOG(INFO) << "bucket is full"; | |||
| free_bucket->Launch(); | |||
| free_bucket_index = ++free_bucket_iter->second; | |||
| MS_LOG(INFO) << "new free bucket:" << free_bucket_index; | |||
| } | |||
| } | |||
| } | |||
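| // Release per-step bucket resources and reset the free bucket index so the next step starts from bucket 0. | |||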
| void SessionBasic::ClearAllBucket(const GraphId &graph_id) { | |||
| auto iter = bucket_map_.find(graph_id); | |||
| if (iter != bucket_map_.end()) { | |||
| auto bucket_list = iter->second; | |||
| for (auto &bucket : bucket_list) { | |||
| MS_LOG(INFO) << "Clear bucket:" << bucket->id(); | |||
| bucket->Release(); | |||
| } | |||
| } | |||
| auto free_iter = free_bucket_id_map_.find(graph_id); | |||
| if (free_iter != free_bucket_id_map_.end()) { | |||
| free_iter->second = 0; | |||
| } | |||
| } | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| void SessionBasic::InitPsWorker(const KernelGraphPtr &kernel_graph) { | |||
| if (!ps::PSContext::instance()->is_worker()) { | |||
| @@ -32,6 +32,7 @@ | |||
| #include "utils/contract.h" | |||
| #include "runtime/device/kernel_info.h" | |||
| #include "utils/ms_context.h" | |||
| #include "runtime/device/bucket.h" | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| #include "debug/debugger/debugger.h" | |||
| #endif | |||
| @@ -224,12 +225,20 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||
| void UpdateAllGraphDynamicShapeAttr(const std::vector<KernelGraphPtr> &all_graphs); | |||
| void RunOpRemoveNopNode(const KernelGraphPtr &kernel_graph) const; | |||
| void RunOpHideNopNode(const KernelGraphPtr &kernel_graph) const; | |||
| virtual std::shared_ptr<device::Bucket> CreateBucket(uint32_t bucket_id, uint32_t bucket_size) { return nullptr; } | |||
| void InitAllBucket(const KernelGraphPtr &graph); | |||
| void AddGradAddrToBucket(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &grad_tensor); | |||
| void ClearAllBucket(const GraphId &graph_id); | |||
| std::vector<uint32_t> GetAllReduceSplitIndex(); | |||
| virtual std::string GetCommWorldGroup() { return std::string(); } | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| void CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) const; | |||
| void GetBatchElements(const AnfNodePtr &kernel_node) const; | |||
| void InitPsWorker(const KernelGraphPtr &kernel_graph); | |||
| #endif | |||
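| // bucket_map_: buckets per graph id; free_bucket_id_map_: index of the first not-yet-full bucket per graph id. | |||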
| std::map<uint32_t, std::vector<std::shared_ptr<device::Bucket>>> bucket_map_; | |||
| std::map<uint32_t, uint32_t> free_bucket_id_map_; | |||
| std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_; | |||
| std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_; | |||
| std::unordered_map<FuncGraphPtr, KernelGraphPtr> front_backend_graph_map_; | |||
| @@ -677,34 +677,9 @@ OpExecInfoPtr PynativeExecutor::GenerateOpExecInfo(const py::args &args) { | |||
| return op_exec_info; | |||
| } | |||
| AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks, | |||
| abstract::AbstractBasePtrList *args_spec_list) { | |||
| MS_EXCEPTION_IF_NULL(op_masks); | |||
| MS_EXCEPTION_IF_NULL(args_spec_list); | |||
| MS_EXCEPTION_IF_NULL(op_exec_info); | |||
| void PynativeExecutor::GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks, | |||
| std::vector<AnfNodePtr> *inputs, abstract::AbstractBasePtrList *args_spec_list) { | |||
| auto prim = op_exec_info->py_primitive; | |||
| std::vector<AnfNodePtr> inputs; | |||
| inputs.emplace_back(NewValueNode(prim)); | |||
| const auto &signature = prim->signatures(); | |||
| auto sig_size = signature.size(); | |||
| auto size = op_exec_info->op_inputs.size(); | |||
| // ignore monad signature | |||
| for (auto sig : signature) { | |||
| if (sig.default_value != nullptr && sig.default_value->isa<Monad>()) { | |||
| --sig_size; | |||
| } | |||
| } | |||
| if (sig_size > 0 && sig_size != size) { | |||
| MS_EXCEPTION(ValueError) << op_exec_info->op_name << " inputs size " << size << " does not match the required " | |||
| << "inputs size " << sig_size; | |||
| } | |||
| if (op_exec_info->op_name != prim::kPrimCast->name()) { | |||
| RunParameterAutoMixPrecisionCast(op_exec_info); | |||
| } | |||
| MS_LOG(DEBUG) << "Get op " << op_exec_info->op_name << " grad_flag_ " << grad_flag(); | |||
| for (size_t i = 0; i < op_exec_info->op_inputs.size(); i++) { | |||
| abstract::AbstractBasePtr abs = nullptr; | |||
| const auto &obj = op_exec_info->op_inputs[i]; | |||
| @@ -733,11 +708,42 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v | |||
| if (input_node->abstract() != nullptr) { | |||
| abs = input_node->abstract(); | |||
| } | |||
| inputs.emplace_back(input_node); | |||
| inputs->emplace_back(input_node); | |||
| } | |||
| } | |||
| (*args_spec_list).emplace_back(CheckConstValue(prim, obj, abs, id, i)); | |||
| } | |||
| } | |||
| AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks, | |||
| abstract::AbstractBasePtrList *args_spec_list) { | |||
| MS_EXCEPTION_IF_NULL(op_masks); | |||
| MS_EXCEPTION_IF_NULL(args_spec_list); | |||
| MS_EXCEPTION_IF_NULL(op_exec_info); | |||
| auto prim = op_exec_info->py_primitive; | |||
| std::vector<AnfNodePtr> inputs; | |||
| inputs.emplace_back(NewValueNode(prim)); | |||
| const auto &signature = prim->signatures(); | |||
| auto sig_size = signature.size(); | |||
| auto size = op_exec_info->op_inputs.size(); | |||
| // ignore monad signature | |||
| for (auto sig : signature) { | |||
| if (sig.default_value != nullptr && sig.default_value->isa<Monad>()) { | |||
| --sig_size; | |||
| } | |||
| } | |||
| if (sig_size > 0 && sig_size != size) { | |||
| MS_EXCEPTION(ValueError) << op_exec_info->op_name << " inputs size " << size << " does not match the required " | |||
| << "inputs size " << sig_size; | |||
| } | |||
| if (op_exec_info->op_name != prim::kPrimCast->name()) { | |||
| RunParameterAutoMixPrecisionCast(op_exec_info); | |||
| } | |||
| MS_LOG(DEBUG) << "Get op " << op_exec_info->op_name << " grad_flag_ " << grad_flag(); | |||
| GetArgsSpec(op_exec_info, op_masks, &inputs, args_spec_list); | |||
| CNodePtr cnode = nullptr; | |||
| if (need_construct_graph()) { | |||
| @@ -208,6 +208,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> { | |||
| PynativeStatusCode *const status); | |||
| AnfNodePtr GetObjNode(const py::object &obj, const std::string &obj_id); | |||
| AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id); | |||
| void GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks, std::vector<AnfNodePtr> *inputs, | |||
| abstract::AbstractBasePtrList *args_spec_list); | |||
| AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks, | |||
| abstract::AbstractBasePtrList *args_spec_list); | |||
| abstract::AbstractBasePtr CheckConstValue(const PrimitivePyPtr &prim, const py::object &obj, | |||
| @@ -1,5 +1,7 @@ | |||
| file(GLOB_RECURSE DEVICE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "common/*.cc" | |||
| "kernel_info.cc" "executor/dynamic_kernel.cc" "executor/executor_callback.cc" "kernel_runtime.cc" "memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc" | |||
| "kernel_info.cc" "executor/dynamic_kernel.cc" "executor/executor_callback.cc" "kernel_runtime.cc" | |||
| "memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc" | |||
| "bucket.cc" | |||
| ) | |||
| if(ENABLE_GPU) | |||
| @@ -0,0 +1,173 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "runtime/device/ascend/ascend_bucket.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "runtime/mem.h" | |||
| #include "external/hccl/hccl.h" | |||
| #include "runtime/device/ascend/ascend_memory_pool.h" | |||
| #include "backend/kernel_compiler/hccl/hcom_util.h" | |||
| #include "backend/kernel_compiler/hccl/hccl_context.h" | |||
| #include "runtime/device/memory_manager.h" | |||
| #include "runtime/device/kernel_runtime_manager.h" | |||
| #include "runtime/device/ascend/ascend_event.h" | |||
| #include "utils/profile.h" | |||
| #define CHECK_ASCEND_RT_WITH_EXCEPTION(expression, message) \ | |||
| { \ | |||
| rtError_t ret = (expression); \ | |||
| if (ret != RT_ERROR_NONE) { \ | |||
| MS_LOG(EXCEPTION) << message << ", error code: " << ret; \ | |||
| } \ | |||
| } | |||
| namespace mindspore::device::ascend { | |||
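| // Record each grad tensor's original device address as a memcpy source, then allocate contiguous | |||
| // input/output buffers from communication memory for the fused AllReduce. | |||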
| void AscendBucket::AllocateAllReduceAddr() { | |||
| // Check bucket is full | |||
| if (grad_tensor_list_.size() != bucket_size_) { | |||
| MS_LOG(EXCEPTION) << "grad tensor list size:" << grad_tensor_list_.size() | |||
| << " is not equal to bucket size:" << bucket_size_; | |||
| } | |||
| size_t total_size = 0; | |||
| std::vector<size_t> align_size_list; | |||
| std::vector<size_t> origin_size_list; | |||
| for (auto &tensor : grad_tensor_list_) { | |||
| MS_EXCEPTION_IF_NULL(tensor); | |||
| tensor_type_list_.emplace_back(tensor->data_type()); | |||
| DeviceAddressPtr device_address = std::dynamic_pointer_cast<DeviceAddress>(tensor->device_address()); | |||
| auto origin_size = device_address->GetSize(); | |||
| auto align_size = MemoryManager::GetCommonAlignSize(origin_size); | |||
| origin_size_list.emplace_back(origin_size); | |||
| align_size_list.emplace_back(align_size); | |||
| total_size += align_size; | |||
| memcpy_input_addrs_.emplace_back(std::make_shared<kernel::Address>( | |||
| static_cast<uint8_t *>(device_address->GetMutablePtr()), device_address->GetSize())); | |||
| } | |||
| total_size_ = total_size; | |||
| auto runtime_instance = device::KernelRuntimeManager::Instance().GetCurrentKernelRuntime(); | |||
| MS_EXCEPTION_IF_NULL(runtime_instance); | |||
| // AllReduce input and output buffers need to be cleared to zero | |||
| ar_input_addr_ = runtime_instance->MallocCommunicationMemFromMemPool(total_size); | |||
| ar_output_addr_ = runtime_instance->MallocCommunicationMemFromMemPool(total_size); | |||
| // Generate memcpy output addresses | |||
| uint8_t *memcpy_output = ar_input_addr_; | |||
| for (size_t i = 0; i < bucket_size_; ++i) { | |||
| memcpy_output_addrs_.emplace_back(std::make_shared<kernel::Address>(memcpy_output, origin_size_list[i])); | |||
| memcpy_output += align_size_list[i]; | |||
| } | |||
| // store output tensor addr | |||
| uint8_t *tensor_output = ar_output_addr_; | |||
| for (size_t i = 0; i < bucket_size_; ++i) { | |||
| new_tensor_output_addrs_.emplace_back(tensor_output); | |||
| tensor_output += align_size_list[i]; | |||
| } | |||
| } | |||
| void AscendBucket::FreeDeviceMem(void *dev_ptr) { AscendMemoryPool::GetInstance().FreeTensorMem(dev_ptr); } | |||
| void AscendBucket::FreeAllDeviceMem() { | |||
| if (ar_input_addr_ != nullptr) { | |||
| uint8_t *origin_dev_addr = ar_input_addr_ - kMemAlignSize; | |||
| FreeDeviceMem(origin_dev_addr); | |||
| ar_input_addr_ = nullptr; | |||
| } | |||
| if (ar_output_addr_ != nullptr) { | |||
| uint8_t *origin_dev_addr = ar_output_addr_ - kMemAlignSize; | |||
| FreeDeviceMem(origin_dev_addr); | |||
| ar_output_addr_ = nullptr; | |||
| } | |||
| } | |||
| void AscendBucket::CopyTensorToContiguousMemory() { | |||
| // Clean input addr | |||
| CHECK_ASCEND_RT_WITH_EXCEPTION(rtMemsetAsync(ar_input_addr_, total_size_, 0, total_size_, compute_stream_), | |||
| "Call rtMemsetAsync failed"); | |||
| for (size_t i = 0; i < bucket_size_; ++i) { | |||
| MS_EXCEPTION_IF_NULL(memcpy_input_addrs_[i]); | |||
| MS_EXCEPTION_IF_NULL(memcpy_output_addrs_[i]); | |||
| MS_LOG(DEBUG) << "MemcpyAsync dst size:" << memcpy_output_addrs_[i]->size | |||
| << " src size:" << memcpy_input_addrs_[i]->size; | |||
| if (memcpy_output_addrs_[i]->size < memcpy_input_addrs_[i]->size) { | |||
| MS_LOG(EXCEPTION) << "rtMemcpyAsync dst size < src size"; | |||
| } | |||
| CHECK_ASCEND_RT_WITH_EXCEPTION( | |||
| rtMemcpyAsync(memcpy_output_addrs_[i]->addr, memcpy_output_addrs_[i]->size, memcpy_input_addrs_[i]->addr, | |||
| memcpy_input_addrs_[i]->size, RT_MEMCPY_DEVICE_TO_DEVICE, compute_stream_), | |||
| "Call rtMemcpyAsync failed"); | |||
| } | |||
| } | |||
| void AscendBucket::LaunchAllReduce() { | |||
| if (tensor_type_list_.empty()) { | |||
| MS_LOG(EXCEPTION) << "No tesnor type found"; | |||
| } | |||
| // AllReduce input data types should all be the same | |||
| auto type = tensor_type_list_[0]; | |||
| if (std::any_of(tensor_type_list_.begin(), tensor_type_list_.end(), | |||
| [&type](TypeId tensor_type) { return type != tensor_type; })) { | |||
| MS_LOG(EXCEPTION) << "allreduce input have different dtype"; | |||
| } | |||
| auto iter = CONST_OP_HCOM_DATA_TYPE_MAP.find(type); | |||
| if (iter == CONST_OP_HCOM_DATA_TYPE_MAP.end()) { | |||
| MS_LOG(EXCEPTION) << "unknown data type:" << type; | |||
| } | |||
| uint32_t type_size; | |||
| if (!HcomUtil::GetHcomTypeSize(iter->second, &type_size)) { | |||
| MS_LOG(EXCEPTION) << "get hcom type size fialed"; | |||
| } | |||
| if (type_size == 0 || total_size_ % type_size != 0) { | |||
| MS_LOG(EXCEPTION) << "Total_size[" << total_size_ << "],Type_size[" << type_size << "] != 0, fail!"; | |||
| } | |||
| auto hccl_count = total_size_ / type_size; | |||
| HcclReduceOp op_type = HcclReduceOp::HCCL_REDUCE_SUM; | |||
| auto hccl_result = HcclAllReduce(ar_input_addr_, ar_output_addr_, hccl_count, iter->second, op_type, | |||
| kernel::HcclContext::GetInstance().hccl_comm(), stream_); | |||
| if (hccl_result != HCCL_SUCCESS) { | |||
| MS_LOG(EXCEPTION) << "HcclAllReduce faled, ret:" << hccl_result; | |||
| } | |||
| } | |||
| void AscendBucket::Init() { | |||
| pre_event_ = std::make_shared<AscendEvent>(); | |||
| post_event_ = std::make_shared<AscendEvent>(); | |||
| auto kernel_runtime = KernelRuntimeManager::Instance().GetCurrentKernelRuntime(); | |||
| MS_EXCEPTION_IF_NULL(kernel_runtime); | |||
| compute_stream_ = kernel_runtime->compute_stream(); | |||
| stream_ = kernel_runtime->communication_stream(); | |||
| MS_EXCEPTION_IF_NULL(pre_event_); | |||
| MS_EXCEPTION_IF_NULL(post_event_); | |||
| pre_event_->set_wait_stream(stream_); | |||
| pre_event_->set_record_stream(compute_stream_); | |||
| post_event_->set_wait_stream(compute_stream_); | |||
| post_event_->set_record_stream(stream_); | |||
| } | |||
| } // namespace mindspore::device::ascend | |||
| @@ -0,0 +1,38 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_BUCKET_H_ | |||
| #define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_BUCKET_H_ | |||
| #include "runtime/device/bucket.h" | |||
| namespace mindspore::device::ascend { | |||
| class AscendBucket : public Bucket { | |||
| public: | |||
| AscendBucket(uint32_t id, uint32_t bucket_size) : Bucket(id, bucket_size) {} | |||
| ~AscendBucket() override = default; | |||
| void Init() override; | |||
| private: | |||
| void AllocateAllReduceAddr() override; | |||
| void FreeAllDeviceMem() override; | |||
| void FreeDeviceMem(void *dev_ptr) override; | |||
| void CopyTensorToContiguousMemory() override; | |||
| void LaunchAllReduce() override; | |||
| }; | |||
| } // namespace mindspore::device::ascend | |||
| #endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_BUCKET_H_ | |||
| @@ -0,0 +1,60 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "runtime/device/ascend/ascend_event.h" | |||
| #include "runtime/event.h" | |||
| #include "runtime/stream.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore::device::ascend { | |||
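| // Thin wrapper over rtEvent_t used to order work between the compute stream and the communication stream. | |||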
| AscendEvent::AscendEvent() { | |||
| auto ret = rtEventCreate(&event_); | |||
| if (ret != RT_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "rtEventCreate failed, ret:" << ret; | |||
| event_ = nullptr; | |||
| } | |||
| } | |||
| AscendEvent::~AscendEvent() { | |||
| auto ret = rtEventDestroy(event_); | |||
| if (ret != RT_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "rtEventDestory failed, ret:" << ret; | |||
| } | |||
| } | |||
| void AscendEvent::RecordEvent() { | |||
| MS_EXCEPTION_IF_NULL(event_); | |||
| MS_EXCEPTION_IF_NULL(record_stream_); | |||
| auto ret = rtEventRecord(event_, record_stream_); | |||
| if (ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "rtEventRecord failed, ret:" << ret; | |||
| } | |||
| need_wait_ = true; | |||
| } | |||
| void AscendEvent::WaitEvent() { | |||
| MS_EXCEPTION_IF_NULL(event_); | |||
| MS_EXCEPTION_IF_NULL(wait_stream_); | |||
| auto ret = rtStreamWaitEvent(wait_stream_, event_); | |||
| if (ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "rtStreamWaitEvent failed, ret:" << ret; | |||
| } | |||
| need_wait_ = false; | |||
| } | |||
| bool AscendEvent::NeedWait() { return need_wait_; } | |||
| } // namespace mindspore::device::ascend | |||
| @@ -0,0 +1,41 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_ASCEND_EVENT_H | |||
| #define MINDSPORE_ASCEND_EVENT_H | |||
| #include "runtime/base.h" | |||
| #include "ir/device_event.h" | |||
| namespace mindspore::device::ascend { | |||
| class AscendEvent : public DeviceEvent { | |||
| public: | |||
| AscendEvent(); | |||
| ~AscendEvent() override; | |||
| void WaitEvent() override; | |||
| void RecordEvent() override; | |||
| bool NeedWait() override; | |||
| void set_wait_stream(rtStream_t wait_stream) override { wait_stream_ = wait_stream; } | |||
| void set_record_stream(rtStream_t record_stream) override { record_stream_ = record_stream; } | |||
| private: | |||
| rtEvent_t event_{nullptr}; | |||
| rtStream_t wait_stream_{nullptr}; | |||
| rtStream_t record_stream_{nullptr}; | |||
| bool need_wait_{false}; | |||
| }; | |||
| } // namespace mindspore::device::ascend | |||
| #endif // MINDSPORE_ASCEND_EVENT_H | |||
| @@ -718,6 +718,10 @@ bool AscendKernelRuntime::SyncStream() { | |||
| MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error."; | |||
| return false; | |||
| } | |||
| if (RT_ERROR_NONE != rtStreamSynchronize(communication_stream_)) { // sync communication stream used by bucket AllReduce | |||
| MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error on communication stream."; | |||
| return false; | |||
| } | |||
| FreeAndClearBufferPtrs(); | |||
| return true; | |||
| } | |||
| @@ -786,6 +790,10 @@ bool AscendKernelRuntime::InitDevice() { | |||
| if (ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rtStreamCreate, ret[" << ret << "]"; | |||
| } | |||
| ret = rtStreamCreate(&communication_stream_, 0); | |||
| if (ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "create communication stream failed, ret:" << ret; | |||
| } | |||
| return true; | |||
| } | |||
| @@ -799,6 +807,14 @@ bool AscendKernelRuntime::ResetDevice(uint32_t device_id) { | |||
| stream_ = nullptr; | |||
| } | |||
| if (communication_stream_ != nullptr) { | |||
| auto ret = rtStreamDestroy(communication_stream_); | |||
| if (ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rtStreamDestroy, ret[" << ret << "]"; | |||
| } | |||
| communication_stream_ = nullptr; | |||
| } | |||
| auto ret = rtDeviceReset(device_id); | |||
| if (ret != RT_ERROR_NONE) { | |||
| MS_EXCEPTION(DeviceProcessError) << "Call rtDeviceReset, ret[" << ret << "]"; | |||
| @@ -919,4 +935,5 @@ uint64_t AscendKernelRuntime::GetAvailableMemMaxSize() const { | |||
| auto ascend_mem_manager = dynamic_pointer_cast<AscendMemoryManager>(mem_manager_); | |||
| return ascend_mem_manager->GetDeviceMemSize(); | |||
| } | |||
| } // namespace mindspore::device::ascend | |||
| @@ -57,6 +57,8 @@ class AscendKernelRuntime : public KernelRuntime { | |||
| void *context() const override { return rt_context_; } | |||
| void PreInit() override; | |||
| uint64_t GetAvailableMemMaxSize() const override; | |||
| void *compute_stream() const override { return stream_; } | |||
| void *communication_stream() const override { return communication_stream_; } | |||
| protected: | |||
| DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, | |||
| @@ -162,6 +162,14 @@ void AscendMemoryManager::MallocSomasDynamicMem(const session::KernelGraph *grap | |||
| somas_reuse_util_ptr_->ConvertToProfilingNode(graph->graph_id()); | |||
| } | |||
| } | |||
| // Communication memory layout: [512-byte aligned padding + data + 512-byte aligned padding] | |||
| // Return a pointer to the start of the data region. | |||
| uint8_t *AscendMemoryManager::MallocCommunicationMemFromMemPool(size_t size) { | |||
| auto align_size = GetCommunicationAlignSize(size); | |||
| uint8_t *base_ptr = reinterpret_cast<uint8_t *>(AscendMemoryPool::GetInstance().AllocTensorMem(align_size)); | |||
| return base_ptr + kMemAlignSize; | |||
| } | |||
| } // namespace ascend | |||
| } // namespace device | |||
| } // namespace mindspore | |||
| @@ -33,6 +33,7 @@ class AscendMemoryManager : public MemoryManager { | |||
| void *MallocMemFromMemPool(size_t size) override; | |||
| uint64_t GetDeviceMemSize(); | |||
| void MallocSomasDynamicMem(const session::KernelGraph *graph); | |||
| uint8_t *MallocCommunicationMemFromMemPool(size_t size) override; | |||
| protected: | |||
| uint8_t *MallocStaticMem(size_t size, bool communication_mem, uint32_t graph_id = kInvalidGraphId) override; | |||
| @@ -0,0 +1,106 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "runtime/device/bucket.h" | |||
| #include <memory> | |||
| #include "runtime/device/kernel_runtime_manager.h" | |||
| #include "utils/profile.h" | |||
| namespace mindspore::device { | |||
| void Bucket::AddGradTensor(const tensor::TensorPtr &tensor) { | |||
| if (grad_tensor_list_.size() >= bucket_size_) { | |||
| MS_LOG(EXCEPTION) << "bucket is full"; | |||
| } | |||
| grad_tensor_list_.emplace_back(tensor); | |||
| if (grad_tensor_list_.size() > bucket_size_) { | |||
| MS_LOG(EXCEPTION) << "too many tensor add to the bucket, bucket_size_:" << bucket_size_ | |||
| << " total tensor size:" << grad_tensor_list_.size(); | |||
| } | |||
| MS_LOG(INFO) << "current bucket tensors size:" << grad_tensor_list_.size(); | |||
| // Bucket is full; the caller is expected to launch the fused AllReduce | |||
| if (grad_tensor_list_.size() == bucket_size_) { | |||
| full_ = true; | |||
| } | |||
| } | |||
| void Bucket::Launch() { | |||
| auto start = GetTime(); | |||
| if (grad_tensor_list_.size() != bucket_size_) { | |||
| MS_LOG(EXCEPTION) << "Bucket is not full, grad_tensor_list_ size:" << grad_tensor_list_.size() | |||
| << " bucket_size_:" << bucket_size_; | |||
| } | |||
| MS_LOG(INFO) << "Bucket is full, start to launch AllReduce"; | |||
| MS_EXCEPTION_IF_NULL(pre_event_); | |||
| MS_EXCEPTION_IF_NULL(post_event_); | |||
| AllocateAllReduceAddr(); | |||
| CopyTensorToContiguousMemory(); | |||
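| // pre_event_ is recorded on the compute stream and waited on by the communication stream (see the device | |||
| // Init() implementations), so AllReduce starts only after the copies finish; post_event_ is recorded on the | |||
| // communication stream and handed to the tensors below, so later reads wait for the AllReduce result. | |||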
| pre_event_->RecordEvent(); | |||
| pre_event_->WaitEvent(); | |||
| LaunchAllReduce(); | |||
| post_event_->RecordEvent(); | |||
| UpdateTensorAddr(); | |||
| // pass event to the tensor | |||
| for (auto &tensor : grad_tensor_list_) { | |||
| MS_EXCEPTION_IF_NULL(tensor); | |||
| tensor->SetDeviceEvent(post_event_); | |||
| } | |||
| MS_LOG(INFO) << "Bucket launch cost:" << (GetTime() - start) * 1e6 << " us"; | |||
| } | |||
| // TODO(caifubi): float16 grad cast to float32 grad | |||
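| // Point each grad tensor's device address at its slice of the fused AllReduce output; the old device | |||
| // memory is kept in tensor_old_addr_list_ and freed lazily by LazyDeleteOldAddr(). | |||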
| void Bucket::UpdateTensorAddr() { | |||
| if (grad_tensor_list_.size() != bucket_size_ || new_tensor_output_addrs_.size() != bucket_size_) { | |||
| MS_LOG(EXCEPTION) << "grad_tensor_list size:" << grad_tensor_list_.size() | |||
| << " tensor output addr size:" << new_tensor_output_addrs_.size() | |||
| << " bucket size:" << bucket_size_; | |||
| } | |||
| for (size_t i = 0; i < bucket_size_; ++i) { | |||
| auto &tensor = grad_tensor_list_[i]; | |||
| MS_EXCEPTION_IF_NULL(tensor); | |||
| auto device_address = std::dynamic_pointer_cast<DeviceAddress>(tensor->device_address()); | |||
| // Release the old address and let this Bucket manage the tensor address. | |||
| MS_EXCEPTION_IF_NULL(device_address); | |||
| auto origin_dev_ptr = device_address->GetMutablePtr(); | |||
| // Do not FreeDeviceMem(origin_dev_ptr) here; the old address is freed later by LazyDeleteOldAddr(). | |||
| tensor_old_addr_list_.emplace_back(origin_dev_ptr); | |||
| device_address->from_mem_pool_ = false; | |||
| device_address->set_ptr(new_tensor_output_addrs_[i]); | |||
| } | |||
| } | |||
| void Bucket::LazyDeleteOldAddr() { | |||
| MS_LOG(INFO) << "Lazy delete old grad address"; | |||
| for (auto old_addr : tensor_old_addr_list_) { | |||
| FreeDeviceMem(old_addr); | |||
| } | |||
| tensor_old_addr_list_.clear(); | |||
| } | |||
| void Bucket::Release() { | |||
| MS_LOG(INFO) << "Clear bucket:" << id_; | |||
| grad_tensor_list_.clear(); | |||
| new_tensor_output_addrs_.clear(); | |||
| memcpy_input_addrs_.clear(); | |||
| memcpy_output_addrs_.clear(); | |||
| tensor_type_list_.clear(); | |||
| LazyDeleteOldAddr(); | |||
| FreeAllDeviceMem(); | |||
| full_ = false; | |||
| } | |||
| } // namespace mindspore::device | |||
| @@ -0,0 +1,83 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_BUCKET_H_ | |||
| #define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_BUCKET_H_ | |||
| #include <vector> | |||
| #include <utility> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "ir/anf.h" | |||
| #include "ir/device_event.h" | |||
| #include "runtime/device/device_address.h" | |||
| #include "backend/session/kernel_graph.h" | |||
| namespace mindspore::device { | |||
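| // A Bucket collects a fixed number of gradient tensors, copies them into one contiguous buffer and launches | |||
| // a single fused AllReduce over it; device-specific subclasses provide memory management and the collective op. | |||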
| class Bucket { | |||
| public: | |||
| Bucket(uint32_t id, uint32_t bucket_size) | |||
| : id_(id), | |||
| bucket_size_(bucket_size), | |||
| full_(false), | |||
| stream_(nullptr), | |||
| compute_stream_(nullptr), | |||
| pre_event_(nullptr), | |||
| post_event_(nullptr), | |||
| total_size_(0), | |||
| ar_input_addr_(nullptr), | |||
| ar_output_addr_(nullptr) {} | |||
| virtual ~Bucket() = default; | |||
| uint32_t id() const { return id_; } | |||
| bool full() const { return full_; } | |||
| void Launch(); | |||
| void Release(); | |||
| void AddGradTensor(const tensor::TensorPtr &tensor); | |||
| virtual void Init() = 0; | |||
| protected: | |||
| uint32_t id_; | |||
| uint32_t bucket_size_; | |||
| bool full_; | |||
| void *stream_; | |||
| void *compute_stream_; | |||
| std::shared_ptr<DeviceEvent> pre_event_; | |||
| std::shared_ptr<DeviceEvent> post_event_; | |||
| size_t total_size_; | |||
| uint8_t *ar_input_addr_; | |||
| uint8_t *ar_output_addr_; | |||
| std::string group_; | |||
| std::vector<tensor::TensorPtr> grad_tensor_list_; | |||
| std::vector<uint8_t *> new_tensor_output_addrs_; | |||
| std::vector<kernel::AddressPtr> memcpy_input_addrs_; | |||
| std::vector<kernel::AddressPtr> memcpy_output_addrs_; | |||
| std::vector<TypeId> tensor_type_list_; | |||
| std::vector<void *> tensor_old_addr_list_; | |||
| virtual void AllocateAllReduceAddr() = 0; | |||
| void UpdateTensorAddr(); | |||
| virtual void LaunchAllReduce() = 0; | |||
| virtual void FreeAllDeviceMem() = 0; | |||
| virtual void FreeDeviceMem(void *dev_ptr) = 0; | |||
| virtual void CopyTensorToContiguousMemory() = 0; | |||
| void LazyDeleteOldAddr(); | |||
| }; | |||
| } // namespace mindspore::device | |||
| #endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_BUCKET_H_ | |||
| @@ -26,6 +26,7 @@ | |||
| namespace mindspore { | |||
| namespace device { | |||
| class Bucket; | |||
| namespace cpu { | |||
| class CPUSimpleMemPlan; | |||
| class CPUMemoryManager; | |||
| @@ -100,6 +101,7 @@ class DeviceAddress : public mindspore::DeviceSync { | |||
| friend class mindspore::device::ascend::AscendKernelRuntime; | |||
| friend class mindspore::device::ascend::AscendMemoryManager; | |||
| friend class mindspore::device::ascend::DataDumper; | |||
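| // Bucket directly updates from_mem_pool_ when it rebinds tensor addresses after the fused AllReduce. | |||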
| friend class mindspore::device::Bucket; | |||
| }; | |||
| using DeviceAddressPtr = std::shared_ptr<DeviceAddress>; | |||
| @@ -0,0 +1,177 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "runtime/device/gpu/gpu_bucket.h" | |||
| #include <cuda_runtime_api.h> | |||
| #include <nccl.h> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "abstract/utils.h" | |||
| #include "runtime/device/gpu/gpu_event.h" | |||
| #include "runtime/device/gpu/gpu_memory_allocator.h" | |||
| #include "runtime/device/gpu/gpu_device_manager.h" | |||
| #include "runtime/device/kernel_runtime_manager.h" | |||
| #include "runtime/device/gpu/distribution/collective_init.h" | |||
| #include "backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h" | |||
| #include "runtime/device/gpu/gpu_common.h" | |||
| namespace { | |||
| const size_t kCommunicationMemAlignSize = 16; | |||
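| // Round a size up to the 16-byte communication alignment; a zero-sized tensor still gets one aligned slot. | |||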
| size_t AlignMemorySize(size_t size) { | |||
| if (size == 0) { | |||
| return kCommunicationMemAlignSize; | |||
| } | |||
| return ((size + kCommunicationMemAlignSize - 1) / kCommunicationMemAlignSize) * kCommunicationMemAlignSize; | |||
| } | |||
| } // namespace | |||
| namespace mindspore::device::gpu { | |||
| GPUBucket::GPUBucket(uint32_t id, uint32_t bucket_size) : Bucket(id, bucket_size), collective_handle_(nullptr) { | |||
| group_ = kNcclWorldGroup; | |||
| } | |||
| void GPUBucket::AllocateAllReduceAddr() { | |||
| MS_LOG(INFO) << "start"; | |||
| if (grad_tensor_list_.size() != bucket_size_) { | |||
| MS_LOG(EXCEPTION) << "grad tensor list size:" << grad_tensor_list_.size() | |||
| << " is not equal to bucket size:" << bucket_size_; | |||
| } | |||
| size_t total_size = 0; | |||
| std::vector<size_t> size_list; | |||
| std::vector<size_t> align_size_list; | |||
| for (auto &tensor : grad_tensor_list_) { | |||
| MS_EXCEPTION_IF_NULL(tensor); | |||
| tensor_type_list_.emplace_back(tensor->data_type()); | |||
| DeviceAddressPtr device_address = std::dynamic_pointer_cast<DeviceAddress>(tensor->device_address()); | |||
| MS_EXCEPTION_IF_NULL(device_address); | |||
| auto origin_size = device_address->GetSize(); | |||
| auto align_size = AlignMemorySize(origin_size); | |||
| size_list.emplace_back(origin_size); | |||
| align_size_list.emplace_back(align_size); | |||
| total_size += align_size; | |||
| memcpy_input_addrs_.emplace_back( | |||
| std::make_shared<kernel::Address>(static_cast<uint8_t *>(device_address->GetMutablePtr()), origin_size)); | |||
| } | |||
| total_size_ = total_size; | |||
| ar_input_addr_ = static_cast<uint8_t *>(GPUMemoryAllocator::GetInstance().AllocTensorMem(total_size)); | |||
| ar_output_addr_ = static_cast<uint8_t *>(GPUMemoryAllocator::GetInstance().AllocTensorMem(total_size)); | |||
| uint8_t *memcpy_output = ar_input_addr_; | |||
| for (size_t i = 0; i < bucket_size_; ++i) { | |||
| memcpy_output_addrs_.emplace_back(std::make_shared<kernel::Address>(memcpy_output, size_list[i])); | |||
| memcpy_output += align_size_list[i]; | |||
| } | |||
| uint8_t *tensor_output = ar_output_addr_; | |||
| for (size_t i = 0; i < bucket_size_; ++i) { | |||
| new_tensor_output_addrs_.emplace_back(tensor_output); | |||
| tensor_output += align_size_list[i]; | |||
| } | |||
| MS_LOG(INFO) << "end"; | |||
| } | |||
| void GPUBucket::FreeDeviceMem(void *dev_ptr) { GPUMemoryAllocator::GetInstance().FreeTensorMem(dev_ptr); } | |||
| void GPUBucket::FreeAllDeviceMem() { | |||
| MS_LOG(INFO) << "start"; | |||
| if (ar_input_addr_ != nullptr) { | |||
| FreeDeviceMem(ar_input_addr_); | |||
| ar_input_addr_ = nullptr; | |||
| } | |||
| if (ar_output_addr_ != nullptr) { | |||
| FreeDeviceMem(ar_output_addr_); | |||
| ar_output_addr_ = nullptr; | |||
| } | |||
| MS_LOG(INFO) << "end"; | |||
| } | |||
| void GPUBucket::CopyTensorToContiguousMemory() { | |||
| MS_LOG(INFO) << "start"; | |||
| MS_EXCEPTION_IF_NULL(compute_stream_); | |||
| // Clean allreduce input | |||
| CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( | |||
| cudaMemsetAsync(ar_input_addr_, 0, total_size_, static_cast<cudaStream_t>(compute_stream_)), | |||
| "Call cudaMemsetAsync failed"); | |||
| for (size_t i = 0; i < bucket_size_; ++i) { | |||
| MS_EXCEPTION_IF_NULL(memcpy_output_addrs_[i]); | |||
| MS_EXCEPTION_IF_NULL(memcpy_input_addrs_[i]); | |||
| if (!GPUDeviceManager::GetInstance().CopyDeviceMemToDeviceAsync(memcpy_output_addrs_[i]->addr, | |||
| memcpy_input_addrs_[i]->addr, | |||
| memcpy_output_addrs_[i]->size, compute_stream_)) { | |||
| MS_LOG(EXCEPTION) << "Copy memory failed"; | |||
| } | |||
| } | |||
| MS_LOG(INFO) << "end"; | |||
| } | |||
| void GPUBucket::LaunchAllReduce() { | |||
| MS_LOG(INFO) << "start"; | |||
| collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle(); | |||
| auto all_reduce_funcptr = | |||
| reinterpret_cast<kernel::AllReduce>(dlsym(const_cast<void *>(collective_handle_), "AllReduce")); | |||
| MS_EXCEPTION_IF_NULL(all_reduce_funcptr); | |||
| MS_EXCEPTION_IF_NULL(stream_); | |||
| if (tensor_type_list_.empty()) { | |||
| MS_LOG(EXCEPTION) << "No tesnor type found"; | |||
| } | |||
| auto type = tensor_type_list_[0]; | |||
| if (std::any_of(tensor_type_list_.begin(), tensor_type_list_.end(), | |||
| [&type](TypeId tensor_type) { return type != tensor_type; })) { | |||
| MS_LOG(EXCEPTION) << "AllReduce input have different dtype"; | |||
| } | |||
| auto type_size = abstract::TypeIdSize(type); | |||
| if (type_size == 0) { | |||
| MS_LOG(EXCEPTION) << "Invalid type:" << type; | |||
| } | |||
| // Map TypeId to the NCCL data type | |||
| auto nccl_data_type_iter = kernel::kNcclDtypeMap.find(TypeIdLabel(type)); | |||
| if (nccl_data_type_iter == kernel::kNcclDtypeMap.end()) { | |||
| MS_LOG(EXCEPTION) << "Invalid type:" << type; | |||
| } | |||
| auto nccl_result = | |||
| (*all_reduce_funcptr)(ar_input_addr_, ar_output_addr_, total_size_ / type_size, nccl_data_type_iter->second, | |||
| ncclRedOp_t::ncclSum, static_cast<cudaStream_t>(stream_), group_); | |||
| if (nccl_result != ncclSuccess) { | |||
| MS_LOG(EXCEPTION) << "AllReduce failed, ret:" << nccl_result; | |||
| } | |||
| MS_LOG(INFO) << "end"; | |||
| } | |||
| void GPUBucket::Init() { | |||
| pre_event_ = std::make_shared<GpuEvent>(); | |||
| post_event_ = std::make_shared<GpuEvent>(); | |||
| auto kernel_runtime = KernelRuntimeManager::Instance().GetCurrentKernelRuntime(); | |||
| MS_EXCEPTION_IF_NULL(kernel_runtime); | |||
| stream_ = kernel_runtime->communication_stream(); | |||
| compute_stream_ = kernel_runtime->compute_stream(); | |||
| MS_EXCEPTION_IF_NULL(pre_event_); | |||
| MS_EXCEPTION_IF_NULL(post_event_); | |||
| pre_event_->set_record_stream(compute_stream_); | |||
| pre_event_->set_wait_stream(stream_); | |||
| post_event_->set_record_stream(stream_); | |||
| post_event_->set_wait_stream(compute_stream_); | |||
| } | |||
| } // namespace mindspore::device::gpu | |||
| @@ -0,0 +1,40 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_BUCKET_H_ | |||
| #define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_BUCKET_H_ | |||
| #include "runtime/device/bucket.h" | |||
| namespace mindspore::device::gpu { | |||
| class GPUBucket : public Bucket { | |||
| public: | |||
| GPUBucket(uint32_t id, uint32_t bucket_size); | |||
| ~GPUBucket() override = default; | |||
| void Init() override; | |||
| private: | |||
| void AllocateAllReduceAddr() override; | |||
| void FreeAllDeviceMem() override; | |||
| void FreeDeviceMem(void *dev_ptr) override; | |||
| void CopyTensorToContiguousMemory() override; | |||
| void LaunchAllReduce() override; | |||
| const void *collective_handle_; | |||
| }; | |||
| } // namespace mindspore::device::gpu | |||
| #endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_BUCKET_H_ | |||
| @@ -0,0 +1,46 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "runtime/device/gpu/gpu_event.h" | |||
| #include "runtime/device/gpu/gpu_common.h" | |||
| namespace mindspore::device::gpu { | |||
| GpuEvent::GpuEvent() { | |||
| auto ret = cudaEventCreate(&event_); | |||
| if (ret != cudaSuccess) { | |||
| MS_LOG(ERROR) << "cudaEventCreate failed, ret:" << ret; | |||
| event_ = nullptr; | |||
| } | |||
| } | |||
| GpuEvent::~GpuEvent() { CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaEventDestroy(event_), "cudaEventDestroy failed"); } | |||
| void GpuEvent::WaitEvent() { | |||
| MS_EXCEPTION_IF_NULL(wait_stream_); | |||
| MS_EXCEPTION_IF_NULL(event_); | |||
| CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamWaitEvent(wait_stream_, event_, 0), "cudaStreamWaitEvent failed"); | |||
| need_wait_ = false; | |||
| } | |||
| void GpuEvent::RecordEvent() { | |||
| MS_EXCEPTION_IF_NULL(event_); | |||
| MS_EXCEPTION_IF_NULL(record_stream_); | |||
| CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaEventRecord(event_, record_stream_), "cudaEventRecord failed"); | |||
| need_wait_ = true; | |||
| } | |||
| bool GpuEvent::NeedWait() { return need_wait_; } | |||
| } // namespace mindspore::device::gpu | |||
| @@ -0,0 +1,42 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_EVENT_H_ | |||
| #define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_EVENT_H_ | |||
| #include <cuda_runtime_api.h> | |||
| #include "ir/device_event.h" | |||
| namespace mindspore::device::gpu { | |||
| class GpuEvent : public DeviceEvent { | |||
| public: | |||
| GpuEvent(); | |||
| ~GpuEvent() override; | |||
| void WaitEvent() override; | |||
| void RecordEvent() override; | |||
| bool NeedWait() override; | |||
| void set_wait_stream(void *wait_stream) override { wait_stream_ = static_cast<cudaStream_t>(wait_stream); } | |||
| void set_record_stream(void *record_stream) override { record_stream_ = static_cast<cudaStream_t>(record_stream); } | |||
| private: | |||
| cudaEvent_t event_{nullptr}; | |||
| cudaStream_t wait_stream_{nullptr}; | |||
| cudaStream_t record_stream_{nullptr}; | |||
| bool need_wait_{false}; | |||
| }; | |||
| } // namespace mindspore::device::gpu | |||
| #endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_EVENT_H_ | |||
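To make the GpuEvent contract concrete, here is a minimal usage sketch: record on the producing (communication) stream, wait on the consuming (compute) stream. The stream setup and function name are illustrative; only the GpuEvent members shown above are taken from the patch.

// Hedged sketch of the intended record/wait pattern between two CUDA streams.
#include <cuda_runtime_api.h>
#include "runtime/device/gpu/gpu_event.h"

void OrderStreams() {
  cudaStream_t comm_stream = nullptr;
  cudaStream_t compute_stream = nullptr;
  cudaStreamCreate(&comm_stream);
  cudaStreamCreate(&compute_stream);

  mindspore::device::gpu::GpuEvent event;
  event.set_record_stream(comm_stream);   // the stream that produces the data
  event.set_wait_stream(compute_stream);  // the stream that consumes it

  // ... enqueue the communication kernel on comm_stream ...
  event.RecordEvent();                    // mark the point the consumer must not pass

  if (event.NeedWait()) {
    event.WaitEvent();                    // compute_stream now waits for the producer
  }

  cudaStreamDestroy(comm_stream);
  cudaStreamDestroy(compute_stream);
}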
| @@ -229,6 +229,11 @@ bool GPUKernelRuntime::InitDevice() { | |||
| MS_LOG(ERROR) << "No default CUDA stream found."; | |||
| return false; | |||
| } | |||
| GPUDeviceManager::GetInstance().CreateStream(&communication_stream_); | |||
| if (communication_stream_ == nullptr) { | |||
| MS_LOG(ERROR) << "Invalid communication stream"; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| @@ -1251,6 +1256,7 @@ session::KernelWithIndex GPUKernelRuntime::GetPrevNodeOutput(const AnfNodePtr &n | |||
| return addr_iter->second[i]; | |||
| } | |||
| } // namespace gpu | |||
| } // namespace device | |||
| } // namespace mindspore | |||
| @@ -47,6 +47,8 @@ class GPUKernelRuntime : public KernelRuntime { | |||
| bool Run(session::KernelGraph *graph, bool is_task_sink) override; | |||
| bool GenDynamicKernel(const session::KernelGraph *graph) override { return true; } | |||
| bool RunDynamicKernelAsync(const session::KernelGraph *graph) override { return true; } | |||
| void *compute_stream() const override { return stream_; } | |||
| void *communication_stream() const override { return communication_stream_; } | |||
| protected: | |||
| DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, | |||
| @@ -946,11 +946,13 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) { | |||
| AddressPtrList kernel_workspaces; | |||
| AddressPtrList kernel_outputs; | |||
| GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs); | |||
| auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_); | |||
| if (!ret) { | |||
| MS_LOG(ERROR) << "Launch kernel failed."; | |||
| return false; | |||
| } | |||
| KernelLaunchProfiling(kernels[i]->fullname_with_scope()); | |||
| } | |||
| } | |||
| @@ -33,6 +33,7 @@ | |||
| #include "utils/ms_context.h" | |||
| #include "runtime/device/memory_manager.h" | |||
| #include "runtime/device/executor/dynamic_kernel.h" | |||
| #include "ir/device_event.h" | |||
| using mindspore::tensor::Tensor; | |||
| using std::vector; | |||
| @@ -83,6 +84,9 @@ class KernelRuntime { | |||
| uint8_t *MallocMem(MemType type, size_t size, const DeviceAddressPtr &address) { | |||
| return mem_manager_->MallocMem(type, size, address); | |||
| } | |||
| uint8_t *MallocCommunicationMemFromMemPool(size_t size) { | |||
| return mem_manager_->MallocCommunicationMemFromMemPool(size); | |||
| } | |||
| static void GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const AnfNodePtr &kernel, | |||
| AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces, | |||
| AddressPtrList *kernel_outputs); | |||
| @@ -104,6 +108,8 @@ class KernelRuntime { | |||
| virtual uint64_t GetAvailableMemMaxSize() const { return 0; } | |||
| void AddBufferPtr(std::shared_ptr<char[]> ptr) { buffer_ptrs_.push_back(ptr); } | |||
| void FreeAndClearBufferPtrs() { buffer_ptrs_.clear(); } | |||
| virtual void *compute_stream() const { return nullptr; } | |||
| virtual void *communication_stream() const { return nullptr; } | |||
| protected: | |||
| virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, | |||
| @@ -149,7 +155,8 @@ class KernelRuntime { | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| std::shared_ptr<Debugger> debugger_; | |||
| #endif | |||
| void *stream_ = nullptr; | |||
| void *stream_{nullptr}; | |||
| void *communication_stream_{nullptr}; | |||
| std::shared_ptr<MemoryManager> mem_manager_{nullptr}; | |||
| std::map<uint32_t, std::vector<DynamicKernelPtr>> graph_dynamic_kernel_map_; | |||
| std::vector<std::shared_ptr<char[]>> buffer_ptrs_ = {}; | |||
| @@ -106,6 +106,14 @@ KernelRuntime *KernelRuntimeManager::GetKernelRuntime(const std::string &device_ | |||
| return kernel_runtime.get(); | |||
| } | |||
| KernelRuntime *KernelRuntimeManager::GetCurrentKernelRuntime() { | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID); | |||
| std::string device_name = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET); | |||
| return GetKernelRuntime(device_name, device_id); | |||
| } | |||
| void KernelRuntimeManager::ReleaseKernelRuntime(const std::string &device_name, uint32_t device_id) { | |||
| std::string runtime_key = GetDeviceKey(device_name, device_id); | |||
| std::lock_guard<std::mutex> guard(lock_); | |||
| @@ -38,6 +38,7 @@ class KernelRuntimeManager { | |||
| } | |||
| void Register(const std::string &device_name, KernelRuntimeCreator &&runtime_creator); | |||
| KernelRuntime *GetKernelRuntime(const std::string &device_name, uint32_t device_id); | |||
| KernelRuntime *GetCurrentKernelRuntime(); | |||
| KernelRuntime *GetSingleKernelRuntime(const std::string &device_name, uint32_t device_id); | |||
| void ReleaseKernelRuntime(const std::string &device_name, uint32_t device_id); | |||
| void ClearRuntimeResource(); | |||
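The new GetCurrentKernelRuntime() lets backend-agnostic code (such as the bucket machinery) look up the runtime for the current device context. A hedged sketch of a caller follows; KernelRuntimeManager::Instance() is assumed as the singleton accessor, since only the new getter itself appears in this diff, and the streams returned are the accessors added to KernelRuntime above.

// Hedged sketch: resolve the runtime for the current device context and read
// its compute/communication streams. Instance() is an assumption.
#include "runtime/device/kernel_runtime_manager.h"

void *GetStreamsForCurrentDevice(void **comm_stream_out) {
  auto &manager = mindspore::device::KernelRuntimeManager::Instance();  // assumed accessor
  auto *runtime = manager.GetCurrentKernelRuntime();
  if (runtime == nullptr) {
    return nullptr;
  }
  *comm_stream_out = runtime->communication_stream();
  return runtime->compute_stream();
}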
| @@ -27,7 +27,7 @@ using mindspore::memreuse::MemReuseUtilPtr; | |||
| namespace mindspore { | |||
| namespace device { | |||
| size_t MemoryManager::GetCommonAlignSize(size_t input_size) const { | |||
| size_t MemoryManager::GetCommonAlignSize(size_t input_size) { | |||
| return (input_size + kMemAlignSize + 31) / kMemAlignSize * kMemAlignSize; | |||
| } | |||
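A worked instance of the alignment formula above; kMemAlignSize is assumed to be 512 here purely for illustration (the real constant is defined elsewhere in memory_manager).

// Worked example of the align formula, under the assumption kMemAlignSize == 512.
#include <cstddef>

constexpr size_t kAssumedMemAlignSize = 512;

constexpr size_t AlignCommon(size_t input_size) {
  return (input_size + kAssumedMemAlignSize + 31) / kAssumedMemAlignSize * kAssumedMemAlignSize;
}

// 1000 bytes -> (1000 + 512 + 31) / 512 * 512 = 1536 bytes.
static_assert(AlignCommon(1000) == 1536, "1000 bytes round up to three 512-byte blocks");
// An exact multiple still gains an extra block because of the +31 padding:
// (1024 + 512 + 31) / 512 * 512 = 1536.
static_assert(AlignCommon(1024) == 1536, "alignment always reserves extra headroom");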
| @@ -53,13 +53,14 @@ class MemoryManager { | |||
| virtual bool MallocMemFromMemPool(const DeviceAddressPtr address, size_t size); | |||
| virtual void *MallocMemFromMemPool(size_t size); | |||
| virtual uint8_t *MallocCommunicationMemFromMemPool(size_t size) { return nullptr; } | |||
| virtual void FreeMemFromMemPool(const DeviceAddressPtr address); | |||
| virtual void FreeMemFromMemPool(void *device_ptr); | |||
| virtual bool MallocContinuousMemFromMemPool(const DeviceAddressPtrList addr_list, size_t total_size, | |||
| std::vector<size_t> size_list); | |||
| virtual std::vector<void *> MallocContinuousMemFromMemPool(size_t total_size, std::vector<size_t> size_list); | |||
| size_t GetCommonAlignSize(size_t input_size) const; | |||
| static size_t GetCommonAlignSize(size_t input_size); | |||
| size_t GetCommunicationAlignSize(size_t input_size) const; | |||
| protected: | |||
| @@ -273,6 +273,10 @@ constexpr auto kOneHotOpName = "OneHot"; | |||
| constexpr auto kSoftmaxCrossEntropyWithLogitsOpName = "SoftmaxCrossEntropyWithLogits"; | |||
| constexpr auto kUniformCandidateSamplerOpName = "UniformCandidateSampler"; | |||
| // Communication world group | |||
| constexpr auto kNcclWorldGroup = "nccl_world_group"; | |||
| constexpr auto kHcclWorldGroup = "hccl_world_group"; | |||
| // Hcom Op Type | |||
| constexpr auto kHcomOpTypeAllReduce = "HcomAllReduce"; | |||
| constexpr auto kHcomOpTypeAllGather = "HcomAllGather"; | |||
| @@ -331,12 +335,15 @@ constexpr auto kAttrSrTag = "sr_tag"; | |||
| constexpr auto kAttrRootRank = "root_rank"; | |||
| constexpr auto kAttrIsTraining = "is_training"; | |||
| constexpr auto kAttrFusionId = "fusion_id"; | |||
| constexpr auto kAttrBucketId = "bucket_id"; | |||
| constexpr auto kAttrGradOutputIndex = "grad_output_index"; | |||
| constexpr auto kAttrLabelIndex = "label_index"; | |||
| constexpr auto kAttrLabelSwitchList = "label_switch_list"; | |||
| constexpr auto kAttrNewAxisMask = "new_axis_mask"; | |||
| constexpr auto kAttrShrinkAxisMask = "shrink_axis_mask"; | |||
| constexpr auto kAttrDatadumpOriginalNames = "_datadump_original_names"; | |||
| constexpr auto kAttrDatadumpIsMultiop = "_datadump_is_multiop"; | |||
| constexpr auto kAttrNeedRecordEvent = "need_record_event"; | |||
| constexpr auto kAttrStreamId = "stream_id"; | |||
| constexpr auto kAttrRecordEvent = "record_event"; | |||
| constexpr auto kAttrWaitEvent = "wait_event"; | |||
| @@ -0,0 +1,31 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_IR_DEVICE_EVENT_H | |||
| #define MINDSPORE_CORE_IR_DEVICE_EVENT_H | |||
| namespace mindspore { | |||
| class DeviceEvent { | |||
| public: | |||
| virtual ~DeviceEvent() = default; | |||
| virtual void WaitEvent() = 0; | |||
| virtual void RecordEvent() = 0; | |||
| virtual bool NeedWait() = 0; | |||
| virtual void set_wait_stream(void *stream) = 0; | |||
| virtual void set_record_stream(void *stream) = 0; | |||
| }; | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_IR_DEVICE_EVENT_H | |||
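DeviceEvent is a pure interface; each backend supplies its own event type (GpuEvent earlier in this diff, and an Ascend counterpart built from ascend_event.cc in the CMake list below). As a hedged illustration of the contract only, a trivial host-only implementation might look like this; the class name and no-op behaviour are invented for the example.

// Minimal sketch of the DeviceEvent contract; not a real backend.
#include "ir/device_event.h"

namespace example {
class NoopDeviceEvent : public mindspore::DeviceEvent {
 public:
  void RecordEvent() override { need_wait_ = true; }  // producer marks the completion point
  void WaitEvent() override { need_wait_ = false; }   // consumer "waits" (no-op on the host)
  bool NeedWait() override { return need_wait_; }
  void set_wait_stream(void *) override {}
  void set_record_stream(void *) override {}

 private:
  bool need_wait_{false};
};
}  // namespace example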
| @@ -44,6 +44,7 @@ FuncGraph::FuncGraph() | |||
| kwonlyargs_count_(0), | |||
| hyper_param_count_(0), | |||
| is_generated_(false), | |||
| is_bprop_(false), | |||
| return_(nullptr), | |||
| manager_(std::weak_ptr<FuncGraphManager>()), | |||
| stub_(false), | |||
| @@ -217,6 +217,8 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder { | |||
| FuncGraphPtr GenerateGraph(const AbstractBasePtrList &args_spec_list); | |||
| void set_is_generate(bool generated) { is_generated_ = generated; } | |||
| bool is_generated() const { return is_generated_; } | |||
| void set_is_bprop(bool is_bprop) { is_bprop_ = is_bprop; } | |||
| bool is_bprop() const { return is_bprop_; } | |||
| std::unordered_map<std::string, ValuePtr> &attrs() { return attrs_; } | |||
| void set_attrs(const std::unordered_map<std::string, ValuePtr> &attrs) { | |||
| @@ -440,6 +442,8 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder { | |||
| // the argument input list for the graph used to generate this graph | |||
| bool is_generated_; | |||
| bool is_bprop_; | |||
| // the cnode that calls 'return' primitive | |||
| // we use shared pointer to manage it. | |||
| CNodePtr return_; | |||
| @@ -247,6 +247,47 @@ class TensorDataImpl : public TensorData { | |||
| } | |||
| private: | |||
| void OutputFloatDataString(std::ostringstream &ss, bool isScalar, const T &value) const { | |||
| if (isScalar) { | |||
| ss << value; | |||
| } else { | |||
| // The placeholder of float16 is fixed at 11, while float/double is fixed at 15. | |||
| const int width = std::is_same<T, float16>::value ? 11 : 15; | |||
| // The printing precision of float16 is fixed at 4, while float/double is fixed at 8. | |||
| const int precision = std::is_same<T, float16>::value ? 4 : 8; | |||
| ss << std::setw(width) << std::setprecision(precision) << std::setiosflags(std::ios::scientific | std::ios::right) | |||
| << value; | |||
| } | |||
| } | |||
| void OutputBoolDataString(std::ostringstream &ss, bool isScalar, const T &value) const { | |||
| if (isScalar) { | |||
| ss << (value ? "True" : "False"); | |||
| } else { | |||
| constexpr int bool_max_width = sizeof("False") - 1; | |||
| ss << std::setw(bool_max_width) << std::setiosflags(std::ios::right) << (value ? "True" : "False"); | |||
| } | |||
| } | |||
| void OutputOtherDataString(std::ostringstream &ss, bool isScalar, const T &value, int *max_width) const { | |||
| if (isScalar) { | |||
| ss << value; | |||
| } else { | |||
| // Add a padding string before the number, such as "###123", for subsequent replacement. | |||
| const int width = GetNumLength(value); | |||
| *max_width = std::max(*max_width, width); | |||
| std::string pad(width, '#'); | |||
| ss << pad; | |||
| if constexpr (std::is_same<T, uint8_t>::value) { | |||
| ss << static_cast<uint16_t>(value); | |||
| } else if constexpr (std::is_same<T, int8_t>::value) { | |||
| ss << static_cast<int16_t>(value); | |||
| } else { | |||
| ss << value; | |||
| } | |||
| } | |||
| } | |||
| void OutputDataString(std::ostringstream &ss, ssize_t cursor, ssize_t start, ssize_t end, bool use_comma, | |||
| int *max_width) const { | |||
| const bool isScalar = ndim_ == 0 && end - start == 1; | |||
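The float helper above pins the field width and precision so that multi-element printouts line up in columns. The standalone snippet below applies the same stream flags used for float/double (width 15, precision 8); the value is arbitrary and the printed text is indicative only.

// Illustration of the fixed-width scientific formatting in OutputFloatDataString.
#include <iomanip>
#include <iostream>
#include <sstream>

int main() {
  std::ostringstream ss;
  ss << std::setw(15) << std::setprecision(8)
     << std::setiosflags(std::ios::scientific | std::ios::right) << 3.14159f;
  std::cout << "[" << ss.str() << "]\n";  // e.g. [ 3.14159012e+00]
  return 0;
}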
| @@ -257,40 +298,11 @@ class TensorDataImpl : public TensorData { | |||
| for (ssize_t i = start; i < end && (cursor + i) < static_cast<ssize_t>(data_size_); i++) { | |||
| const auto value = data_[cursor + i]; | |||
| if constexpr (isFloat) { | |||
| if (isScalar) { | |||
| ss << value; | |||
| } else { | |||
| // The placeholder of float16 is fixed at 11, while float/double is fixed at 15. | |||
| const int width = std::is_same<T, float16>::value ? 11 : 15; | |||
| // The printing precision of float16 is fixed at 4, while float/double is fixed at 8. | |||
| const int precision = std::is_same<T, float16>::value ? 4 : 8; | |||
| ss << std::setw(width) << std::setprecision(precision) | |||
| << std::setiosflags(std::ios::scientific | std::ios::right) << value; | |||
| } | |||
| OutputFloatDataString(ss, isScalar, value); | |||
| } else if (isBool) { | |||
| if (isScalar) { | |||
| ss << (value ? "True" : "False"); | |||
| } else { | |||
| constexpr int bool_max_width = sizeof("False") - 1; | |||
| ss << std::setw(bool_max_width) << std::setiosflags(std::ios::right) << (value ? "True" : "False"); | |||
| } | |||
| OutputBoolDataString(ss, isScalar, value); | |||
| } else { | |||
| if (isScalar) { | |||
| ss << value; | |||
| } else { | |||
| // Add a padding string before the number, such as "###123", for subsequent replacement. | |||
| const int width = GetNumLength(value); | |||
| *max_width = std::max(*max_width, width); | |||
| std::string pad(width, '#'); | |||
| ss << pad; | |||
| if constexpr (std::is_same<T, uint8_t>::value) { | |||
| ss << static_cast<uint16_t>(value); | |||
| } else if constexpr (std::is_same<T, int8_t>::value) { | |||
| ss << static_cast<int16_t>(value); | |||
| } else { | |||
| ss << value; | |||
| } | |||
| } | |||
| OutputOtherDataString(ss, isScalar, value, max_width); | |||
| } | |||
| if (!isScalar && i != end - 1) { | |||
| if (use_comma) { | |||
| @@ -452,7 +464,8 @@ Tensor::Tensor(const Tensor &tensor) | |||
| cache_enable_(tensor.cache_enable_), | |||
| cache_tensor_ptr_(tensor.cache_tensor_ptr_), | |||
| hashmap_tensor_ptr_(tensor.hashmap_tensor_ptr_), | |||
| padding_type_(tensor.padding_type()) {} | |||
| padding_type_(tensor.padding_type()), | |||
| device_event_(tensor.device_event_) {} | |||
| Tensor::Tensor(const Tensor &tensor, TypeId data_type) | |||
| : MetaTensor(data_type, tensor.shape_), | |||
| @@ -465,7 +478,8 @@ Tensor::Tensor(const Tensor &tensor, TypeId data_type) | |||
| cache_enable_(tensor.cache_enable_), | |||
| cache_tensor_ptr_(tensor.cache_tensor_ptr_), | |||
| hashmap_tensor_ptr_(tensor.hashmap_tensor_ptr_), | |||
| padding_type_(tensor.padding_type()) {} | |||
| padding_type_(tensor.padding_type()), | |||
| device_event_(tensor.device_event_) {} | |||
| Tensor::Tensor(TypeId data_type, const ShapeVector &shape, TensorDataPtr data) | |||
| : MetaTensor(data_type, shape), data_(std::move(data)), id_(MakeId()) {} | |||
| @@ -527,6 +541,7 @@ Tensor &Tensor::AssignValue(const Tensor &tensor) { | |||
| event_ = tensor.event_; | |||
| sync_status_ = tensor.sync_status_; | |||
| padding_type_ = tensor.padding_type_; | |||
| device_event_ = tensor.device_event_; | |||
| } | |||
| return *this; | |||
| } | |||
| @@ -30,6 +30,7 @@ | |||
| #include "base/float16.h" | |||
| #include "utils/shape_utils.h" | |||
| #include "utils/ms_exception.h" | |||
| #include "ir/device_event.h" | |||
| // brief mindspore namespace. | |||
| // | |||
| @@ -326,6 +327,21 @@ class Tensor : public MetaTensor { | |||
| event_ = nullptr; | |||
| } | |||
| void SetDeviceEvent(const std::shared_ptr<DeviceEvent> &device_event) { device_event_ = device_event; } | |||
| void WaitDevice() { | |||
| if (device_event_ != nullptr) { | |||
| device_event_->WaitEvent(); | |||
| } | |||
| } | |||
| bool NeedWaitDevice() const { | |||
| if (device_event_ != nullptr) { | |||
| return device_event_->NeedWait(); | |||
| } | |||
| return false; | |||
| } | |||
| void set_sync_status(TensorSyncStatus sync_status) { sync_status_ = sync_status; } | |||
| TensorSyncStatus sync_status() const { return sync_status_; } | |||
| @@ -352,6 +368,7 @@ class Tensor : public MetaTensor { | |||
| std::shared_ptr<Tensor> hashmap_tensor_ptr_{nullptr}; | |||
| std::vector<Axis> padding_type_; | |||
| TypePtr cast_dtype_{nullptr}; | |||
| std::shared_ptr<DeviceEvent> device_event_{nullptr}; | |||
| }; | |||
| using TensorPtr = std::shared_ptr<Tensor>; | |||
| using TensorPtrList = std::vector<std::shared_ptr<Tensor>>; | |||
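On the producer side, the runtime can record an event on the communication stream and attach it to the output tensor via SetDeviceEvent, so that anything reading the tensor later can synchronize through NeedWaitDevice()/WaitDevice(). The sketch below is an assumption about how the pieces compose; the function name, stream plumbing and the choice of the GpuEvent backend are illustrative.

// Hedged sketch: attach a recorded device event to a tensor produced on the
// communication stream.
#include <memory>
#include "ir/tensor.h"
#include "runtime/device/gpu/gpu_event.h"

void AttachAllReduceEvent(const mindspore::tensor::TensorPtr &out_tensor,
                          void *communication_stream, void *compute_stream) {
  auto event = std::make_shared<mindspore::device::gpu::GpuEvent>();
  event->set_record_stream(communication_stream);
  event->set_wait_stream(compute_stream);
  // ... enqueue the communication kernel on communication_stream ...
  event->RecordEvent();
  out_tensor->SetDeviceEvent(event);  // consumers check NeedWaitDevice() before use
}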
| @@ -17,10 +17,11 @@ from mindspore import context | |||
| from mindspore.nn.cell import Cell | |||
| from mindspore.communication.management import GlobalComm, get_group_size | |||
| from mindspore.common.tensor import RowTensor | |||
| from mindspore.ops import functional as F, composite as C, operations as P | |||
| from mindspore.ops import functional as F, composite as C | |||
| from mindspore.ops.operations.comm_ops import AllReduce, AllGather | |||
| from mindspore.parallel._auto_parallel_context import auto_parallel_context | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore.common.tensor import Tensor | |||
| reduce_opt = C.MultitypeFuncGraph("reduce_opt") | |||
| @@ -45,7 +46,7 @@ def _init_allreduce_operators(length, split_indices): | |||
| return op_list | |||
| @reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "Tensor") | |||
| @reduce_opt.register("Tensor", "Bool", "Function", "Function", "Bool", "Tensor") | |||
| def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, grad): | |||
| """ | |||
| Apply allreduce on gradient. | |||
| @@ -64,13 +65,33 @@ def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, gra | |||
| if allreduce_filter: | |||
| grad = allreduce(grad) | |||
| if mean: | |||
| degree = F.scalar_cast(degree, F.dtype(grad)) | |||
| grad = F.tensor_mul(grad, F.cast(F.scalar_to_array(1.0 / degree), F.dtype(grad))) | |||
| grad = F.tensor_mul(grad, F.cast(degree, F.dtype(grad))) | |||
| return grad | |||
| return grad | |||
| @reduce_opt.register("Tensor", "Bool", "Bool", "Tensor") | |||
| def _tensors_allreduce_post(degree, mean, allreduce_filter, grad): | |||
| """ | |||
| Apply the mean scaling to an already-allreduced gradient in PyNative mode. | |||
| Args: | |||
| degree (Tensor): The reciprocal of the device number, used as the mean coefficient. | |||
| mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. | |||
| allreduce_filter (bool): When it is true, the mean scaling would apply. | |||
| grad (Tensor): The gradient tensor after allreduce. | |||
| Returns: | |||
| Tensor, the gradient tensor after operation. | |||
| """ | |||
| if allreduce_filter: | |||
| if mean: | |||
| grad = F.tensor_mul(grad, F.cast(degree, F.dtype(grad))) | |||
| return grad | |||
| return grad | |||
| @reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "Tensor", "Bool") | |||
| @reduce_opt.register("Tensor", "Bool", "Function", "Function", "Bool", "Tensor", "Bool") | |||
| def _tensors_allreduce_ps(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter): | |||
| """ | |||
| Apply allreduce on gradient. | |||
| @@ -93,15 +114,12 @@ def _tensors_allreduce_ps(degree, mean, allgather, allreduce, allreduce_filter, | |||
| if allreduce_filter: | |||
| grad = allreduce(grad) | |||
| if mean: | |||
| degree = F.scalar_cast(degree, F.dtype(grad)) | |||
| cast_op = P.Cast() | |||
| mul_op = P.Mul() | |||
| grad = mul_op(grad, cast_op(F.scalar_to_array(1.0/degree), F.dtype(grad))) | |||
| grad = F.tensor_mul(grad, F.cast(degree, F.dtype(grad))) | |||
| return grad | |||
| return grad | |||
| @reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "RowTensor") | |||
| @reduce_opt.register("Tensor", "Bool", "Function", "Function", "Bool", "RowTensor") | |||
| def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad): | |||
| """ | |||
| Apply allgather on gradient instead of allreduce for sparse feature. | |||
| @@ -122,15 +140,12 @@ def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce | |||
| indices = allgather(grad.indices) | |||
| dout = allgather(grad.values) | |||
| if mean: | |||
| degree = F.scalar_cast(degree, F.dtype(grad.values)) | |||
| cast_op = P.Cast() | |||
| mul_op = P.Mul() | |||
| dout = mul_op(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout))) | |||
| dout = F.tensor_mul(dout, F.cast(degree, F.dtype(dout))) | |||
| grad = RowTensor(indices, dout, grad.dense_shape) | |||
| return grad | |||
| @reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "RowTensor", "Bool") | |||
| @reduce_opt.register("Tensor", "Bool", "Function", "Function", "Bool", "RowTensor", "Bool") | |||
| def _tensors_allreduce_with_sparse_ps(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter): | |||
| """ | |||
| Apply allgather on gradient instead of allreduce for sparse feature. | |||
| @@ -155,10 +170,7 @@ def _tensors_allreduce_with_sparse_ps(degree, mean, allgather, allreduce, allred | |||
| indices = allgather(grad.indices) | |||
| dout = allgather(grad.values) | |||
| if mean: | |||
| degree = F.scalar_cast(degree, F.dtype(grad.values)) | |||
| cast_op = P.Cast() | |||
| mul_op = P.Mul() | |||
| dout = mul_op(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout))) | |||
| dout = F.tensor_mul(dout, F.cast(degree, F.dtype(dout))) | |||
| grad = RowTensor(indices, dout, grad.dense_shape) | |||
| return grad | |||
| @@ -329,6 +341,7 @@ class DistributedGradReducer(Cell): | |||
| if not isinstance(degree, int) or degree <= 0: | |||
| raise ValueError("Parameter 'degree' in DistributedGradReducer should large than 0 and be int") | |||
| self.degree = degree | |||
| self.degree = Tensor(1.0 / self.degree, mstype.float32) | |||
| self.mean = mean | |||
| self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters) | |||
| is_parallel_optimizer = context.get_auto_parallel_context("enable_parallel_optimizer") | |||
| @@ -343,6 +356,7 @@ class DistributedGradReducer(Cell): | |||
| ps_filter = lambda x: x.is_param_ps | |||
| self.ps_parameters = tuple(ps_filter(x) for x in parameters) | |||
| self.enable_parameter_server = any(self.ps_parameters) | |||
| self.mode = context.get_context("mode") | |||
| def construct(self, grads): | |||
| """ | |||
| @@ -358,7 +372,9 @@ class DistributedGradReducer(Cell): | |||
| """ | |||
| datatypes = self.map_(F.partial(_get_datatype), grads) | |||
| grads = self.map_(F.partial(_cast_datatype, mstype.float32), grads) | |||
| if self.split_fusion: | |||
| if self.mode == context.PYNATIVE_MODE: | |||
| new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean), self.allreduce_filter, grads) | |||
| elif self.split_fusion: | |||
| if self.enable_parameter_server: | |||
| new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather), | |||
| self.op_list, self.allreduce_filter, grads, self.ps_parameters) | |||
| @@ -97,10 +97,13 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| "../../../mindspore/ccsrc/runtime/device/memory_manager.cc" | |||
| "../../../mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc" | |||
| "../../../mindspore/ccsrc/runtime/device/kernel_info.cc" | |||
| "../../../mindspore/ccsrc/runtime/device/bucket.cc" | |||
| "../../../mindspore/ccsrc/runtime/device/ascend/profiling/*.cc" | |||
| "../../../mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc" | |||
| "../../../mindspore/ccsrc/runtime/device/ascend/kernel_select_graph_kernel.cc" | |||
| "../../../mindspore/ccsrc/runtime/device/convert_tensor_utils.cc" | |||
| "../../../mindspore/ccsrc/runtime/device/ascend/ascend_bucket.cc" | |||
| "../../../mindspore/ccsrc/runtime/device/ascend/ascend_event.cc" | |||
| "../../../mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.cc" | |||
| "../../../mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc" | |||
| "../../../mindspore/ccsrc/runtime/device/ascend/signal_util.cc" | |||
| @@ -160,3 +160,7 @@ RTS_API rtError_t rtSetMsprofReporterCallback(MsprofReporterCallback callback) { | |||
| RTS_API rtError_t rtRegTaskFailCallbackByModule(const char *moduleName, rtTaskFailCallback callback) { | |||
| return RT_ERROR_NONE; | |||
| } | |||
| RTS_API rtError_t rtMemsetAsync(void *ptr, uint64_t destMax, uint32_t value, uint64_t count, rtStream_t stream) { | |||
| return RT_ERROR_NONE; | |||
| } | |||