Merge pull request !27283 from limingqi107/new_actor_runtimetags/v1.6.0
| @@ -107,8 +107,8 @@ class DeviceAddress : public mindspore::DeviceSync { | |||
| virtual DeviceAddressStatus status() const { return DeviceAddressStatus::kInDevice; } | |||
| virtual DeviceAddressType DeviceType() const { return DeviceAddressType::kUnknown; } | |||
| void *GetMutablePtr() const override { return ptr_; } | |||
| std::string DeviceName() const { return device_name_; } | |||
| uint32_t DeviceID() const { return device_id_; } | |||
| std::string device_name() const { return device_name_; } | |||
| uint32_t device_id() const { return device_id_; } | |||
| virtual void SetNodeIndex(const AnfNodePtr &node, size_t out_index) { node_index_ = {node, out_index}; } | |||
| KernelWithIndex GetNodeIndex() const { | |||
| @@ -116,6 +116,21 @@ class DeviceAddress : public mindspore::DeviceSync { | |||
| : KernelWithIndex{node_index_.first.lock(), node_index_.second}; | |||
| } | |||
| // The related interface of dynamic reference count operation. | |||
| void set_dynamic_ref_conut(int32_t dynamic_ref_conut) { dynamic_ref_conut_ = dynamic_ref_conut; } | |||
| int32_t dynamic_ref_conut() const { return dynamic_ref_conut_; } | |||
| void IncreaseDynamicRefCount() { | |||
| if (dynamic_ref_conut_ < INT32_MAX) { | |||
| dynamic_ref_conut_++; | |||
| } | |||
| } | |||
| void DecreaseDynamicRefCount() { | |||
| if (dynamic_ref_conut_ <= 0) { | |||
| MS_LOG(EXCEPTION) << "The dynamic reference count is invalid value:" << dynamic_ref_conut_; | |||
| } | |||
| dynamic_ref_conut_--; | |||
| } | |||
| virtual bool DumpMemToFile(const std::string &filepath, const std::string &host_fmt, const ShapeVector &host_shape, | |||
| TypeId host_type, bool trans_flag) const { | |||
| return true; | |||
| @@ -142,10 +157,13 @@ class DeviceAddress : public mindspore::DeviceSync { | |||
| // {node, out_index} | |||
| std::pair<AnfNodeWeakPtr, size_t> node_index_{AnfNodePtr(nullptr), 0}; | |||
| // The device address of the node that owns the device address cannot be updated and replaced. | |||
| // application scenario: set to true when the hardware execution mode requires that ptr cannot be changed during | |||
| // Application scenario: set to true when the hardware execution mode requires that ptr cannot be changed during | |||
| // execution. | |||
| bool is_ptr_persisted_{false}; | |||
| // The device address generated in the control flow scene uses dynamic_ref_conut_. | |||
| std::atomic_int32_t dynamic_ref_conut_{INT32_MAX}; | |||
| // The key of device context. | |||
| std::string device_name_{""}; | |||
| uint32_t device_id_{0}; | |||
| @@ -18,9 +18,9 @@ | |||
| namespace mindspore { | |||
| namespace runtime { | |||
| ControlActor::ControlActor(const std::string &name, KernelTransformType type, | |||
| ControlActor::ControlActor(const std::string &name, KernelTransformType type, const AID &memory_manager_aid, | |||
| const std::vector<KernelWithIndex> ¶meters, const AnfNodePtr &node) | |||
| : AbstractActor(name, type, nullptr), formal_parameters_(parameters), node_(node) { | |||
| : MemoryAwareActor(name, type, nullptr, memory_manager_aid), formal_parameters_(parameters), node_(node) { | |||
| for (size_t i = 0; i < parameters.size(); ++i) { | |||
| input_partials_.emplace_back(std::make_shared<OpPartial>()); | |||
| } | |||
| @@ -41,6 +41,59 @@ void ControlActor::Init() { | |||
| } | |||
| } | |||
| std::vector<DeviceTensor *> ControlActor::GetAllDeviceTensors(const OpPartialPtr &op_partial) { | |||
| MS_EXCEPTION_IF_NULL(op_partial); | |||
| std::vector<DeviceTensor *> ret; | |||
| for (auto &device_tensor : op_partial->device_tensors_) { | |||
| (void)ret.emplace_back(device_tensor.second); | |||
| } | |||
| // Traverse the op partials to fetch the device tensors. | |||
| for (auto &partial : op_partial->partials_) { | |||
| auto ret_inner = GetAllDeviceTensors(partial.second); | |||
| (void)std::copy(ret_inner.begin(), ret_inner.end(), std::back_inserter(ret)); | |||
| } | |||
| return ret; | |||
| } | |||
| std::vector<DeviceTensor *> ControlActor::GetAllDeviceTensors(const OpRealParameterWithBranchID &op_real_parameter) { | |||
| std::vector<DeviceTensor *> ret; | |||
| for (auto &device_tensor : op_real_parameter.device_tensors_) { | |||
| (void)ret.emplace_back(device_tensor.second); | |||
| } | |||
| // Traverse the op partials to fetch the device tensors. | |||
| for (auto &partial : op_real_parameter.partials_) { | |||
| auto ret_inner = GetAllDeviceTensors(partial.second); | |||
| (void)std::copy(ret_inner.begin(), ret_inner.end(), std::back_inserter(ret)); | |||
| } | |||
| return ret; | |||
| } | |||
| void ControlActor::IncreaseDynamicRefCount(const OpData<DeviceTensor> *op_data) { | |||
| MS_EXCEPTION_IF_NULL(op_data); | |||
| MS_EXCEPTION_IF_NULL(op_data->data_); | |||
| op_data->data_->IncreaseDynamicRefCount(); | |||
| } | |||
| void ControlActor::IncreaseDynamicRefCount(const OpPartialPtr &op_partial) { | |||
| MS_EXCEPTION_IF_NULL(op_partial); | |||
| auto partial_device_tensors = GetAllDeviceTensors(op_partial); | |||
| for (auto &partial_device_tensor : partial_device_tensors) { | |||
| MS_EXCEPTION_IF_NULL(partial_device_tensor); | |||
| partial_device_tensor->IncreaseDynamicRefCount(); | |||
| } | |||
| } | |||
| void ControlActor::IncreaseDynamicRefCount(const OpRealParameterWithBranchID &op_real_parameter) { | |||
| auto partial_device_tensors = GetAllDeviceTensors(op_real_parameter); | |||
| for (auto &partial_device_tensor : partial_device_tensors) { | |||
| MS_EXCEPTION_IF_NULL(partial_device_tensor); | |||
| partial_device_tensor->IncreaseDynamicRefCount(); | |||
| } | |||
| } | |||
| size_t ControlActor::FetchNodePosition(const KernelWithIndex &node) const { | |||
| const auto &iter = find(formal_parameters_.begin(), formal_parameters_.end(), node); | |||
| if (iter == formal_parameters_.end()) { | |||
| @@ -52,6 +105,13 @@ size_t ControlActor::FetchNodePosition(const KernelWithIndex &node) const { | |||
| void ControlActor::Run(OpContext<DeviceTensor> *const context) { | |||
| FetchInput(context); | |||
| // Note that IncreaseDynamicRefCounts must be in front of SendMemoryFreeReq. SendMemoryFreeReq will decrease the | |||
| // dynamic ref count. Avoid the illegal timing problem that the dynamic reference count is decremented and then | |||
| // incremented. | |||
| IncreaseDynamicRefCounts(context); | |||
| SendMemoryFreeReq(context); | |||
| EraseInput(context); | |||
| SendOutput(context); | |||
| } | |||
| @@ -197,9 +257,61 @@ void ControlActor::FetchInput(OpContext<DeviceTensor> *const context) { | |||
| } | |||
| } | |||
| void ControlActor::IncreaseDynamicRefCounts(OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| // Increase dynamic ref count by the output data. | |||
| for (auto &output_data : output_data_) { | |||
| MS_EXCEPTION_IF_NULL(output_data); | |||
| IncreaseDynamicRefCount(output_data.get()); | |||
| } | |||
| // Increase dynamic ref count by the output partial. | |||
| for (const auto &partial_arrow : output_partial_arrows_) { | |||
| MS_EXCEPTION_IF_NULL(partial_arrow); | |||
| if (IntToSize(partial_arrow->from_output_index_) >= input_partials_.size()) { | |||
| std::string error_info = "Invalid partial input:" + std::to_string(partial_arrow->from_output_index_) + | |||
| " current:" + std::to_string(input_partials_.size()) + " for actor:" + GetAID().Name(); | |||
| SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info); | |||
| } | |||
| auto output_partial = input_partials_[partial_arrow->from_output_index_]; | |||
| IncreaseDynamicRefCount(output_partial); | |||
| } | |||
| } | |||
| void ControlActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| const auto &sequential_num = context->sequential_num_; | |||
| // Collect the input device tensors. | |||
| std::vector<DeviceTensor *> memory_free_list; | |||
| if (input_op_datas_.count(sequential_num) > 0) { | |||
| for (auto &input_data : input_op_datas_[sequential_num]) { | |||
| MS_EXCEPTION_IF_NULL(input_data); | |||
| MS_EXCEPTION_IF_NULL(input_data->data_); | |||
| memory_free_list.emplace_back(input_data->data_); | |||
| } | |||
| } | |||
| if (input_op_partials_.count(sequential_num) > 0) { | |||
| for (auto &input_partial_pair : input_op_partials_[sequential_num]) { | |||
| auto partial_device_tensors = GetAllDeviceTensors(input_partial_pair.second); | |||
| (void)std::copy(partial_device_tensors.begin(), partial_device_tensors.end(), | |||
| std::back_inserter(memory_free_list)); | |||
| } | |||
| } | |||
| if (memory_free_list.size() > 0) { | |||
| memory_free_lists_.emplace_back(memory_free_list); | |||
| ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &(memory_free_lists_.back()), | |||
| device_contexts_[0], context); | |||
| } | |||
| } | |||
| void ControlActor::EraseInput(const OpContext<DeviceTensor> *context) { | |||
| AbstractActor::EraseInput(context); | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| const auto &sequential_num = context->sequential_num_; | |||
| AbstractActor::EraseInput(context); | |||
| if (input_partials_num_ != 0) { | |||
| auto ret = input_op_partials_.erase(sequential_num); | |||
| if (ret == 0) { | |||
| @@ -25,9 +25,11 @@ | |||
| #include <stack> | |||
| #include <queue> | |||
| #include <utility> | |||
| #include <algorithm> | |||
| #include "runtime/framework/actor/actor_common.h" | |||
| #include "runtime/framework/actor/abstract_actor.h" | |||
| #include "runtime/framework/actor/memory_aware_actor.h" | |||
| #include "runtime/framework/actor/memory_manager_actor.h" | |||
| namespace mindspore { | |||
| namespace runtime { | |||
| @@ -49,10 +51,10 @@ struct OpRealParameterWithBranchID { | |||
| int branch_id_; | |||
| }; | |||
| // The control actor is the base class of control flow actor. | |||
| class ControlActor : public AbstractActor { | |||
| class ControlActor : public MemoryAwareActor { | |||
| public: | |||
| ControlActor(const std::string &name, KernelTransformType type, const std::vector<KernelWithIndex> ¶meters, | |||
| const AnfNodePtr &node); | |||
| ControlActor(const std::string &name, KernelTransformType type, const AID &memory_manager_aid, | |||
| const std::vector<KernelWithIndex> ¶meters, const AnfNodePtr &node); | |||
| ~ControlActor() override = default; | |||
| void Init() override; | |||
| @@ -72,6 +74,14 @@ class ControlActor : public AbstractActor { | |||
| protected: | |||
| friend class ControlNodeScheduler; | |||
| // The basic interfaces for op partial and op real parameter. | |||
| std::vector<DeviceTensor *> GetAllDeviceTensors(const OpPartialPtr &op_partial); | |||
| std::vector<DeviceTensor *> GetAllDeviceTensors(const OpRealParameterWithBranchID &op_real_parameter); | |||
| void IncreaseDynamicRefCount(const OpData<DeviceTensor> *op_data); | |||
| void IncreaseDynamicRefCount(const OpPartialPtr &op_partial); | |||
| void IncreaseDynamicRefCount(const OpRealParameterWithBranchID &op_real_parameter); | |||
| // Get the position of node in the input. | |||
| size_t FetchNodePosition(const KernelWithIndex &node) const; | |||
| @@ -82,6 +92,11 @@ class ControlActor : public AbstractActor { | |||
| void SendOutput(OpContext<DeviceTensor> *const context) override; | |||
| void EraseInput(const OpContext<DeviceTensor> *context) override; | |||
| // Increase the dynamic ref count by the outputs. It corresponds to the SendOutput. | |||
| virtual void IncreaseDynamicRefCounts(OpContext<DeviceTensor> *const context); | |||
| // Free memory by decrementing the dynamic ref count. It corresponds to the EraseInput. | |||
| void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) override; | |||
| // Input data. | |||
| // 1.Input partial. | |||
| // Record the partial received by each step, the key of the pair indicates the location of the partial. | |||
| @@ -99,6 +114,9 @@ class ControlActor : public AbstractActor { | |||
| std::vector<OpPartialPtr> input_partials_; | |||
| std::vector<DeviceTensor *> input_device_tensors_; | |||
| // The lists of device tensors which need free by dynamic ref count, will be cleared at the end of step. | |||
| std::vector<std::vector<DeviceTensor *>> memory_free_lists_; | |||
| // Input num. | |||
| size_t input_partials_num_{0}; | |||
| size_t input_branch_ids_num_{0}; | |||
| @@ -70,6 +70,13 @@ void EntranceActor::Run(OpContext<DeviceTensor> *const context) { | |||
| is_loop_body_execution_ = true; | |||
| FetchInput(context); | |||
| // Note that IncreaseDynamicRefCount must be in front of SendMemoryFreeReq. SendMemoryFreeReq will decrease the | |||
| // dynamic ref count. Avoid the illegal timing problem that the dynamic reference count is decremented and then | |||
| // incremented. | |||
| IncreaseDynamicRefCounts(context); | |||
| SendMemoryFreeReq(context); | |||
| EraseInput(context); | |||
| SendOutput(context); | |||
| } | |||
| @@ -218,5 +225,36 @@ void EntranceActor::EraseInput(const OpContext<DeviceTensor> *const context) { | |||
| } | |||
| } | |||
| } | |||
| void EntranceActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| const auto &sequential_num = context->sequential_num_; | |||
| // Collect the input device tensors. | |||
| std::vector<DeviceTensor *> memory_free_list; | |||
| if (input_op_datas_.count(sequential_num) > 0) { | |||
| for (auto &input_data : input_op_datas_[sequential_num]) { | |||
| MS_EXCEPTION_IF_NULL(input_data); | |||
| MS_EXCEPTION_IF_NULL(input_data->data_); | |||
| memory_free_list.emplace_back(input_data->data_); | |||
| } | |||
| } | |||
| const auto &iter = real_parameters_with_branch_id_.find(sequential_num); | |||
| if (iter != real_parameters_with_branch_id_.end()) { | |||
| if (iter->second.empty()) { | |||
| SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The real parameter with branch id is empty."); | |||
| } | |||
| auto &real_parameters_with_branch_id = iter->second.front(); | |||
| auto partial_device_tensors = GetAllDeviceTensors(real_parameters_with_branch_id); | |||
| (void)std::copy(partial_device_tensors.begin(), partial_device_tensors.end(), std::back_inserter(memory_free_list)); | |||
| } | |||
| if (memory_free_list.size() > 0) { | |||
| memory_free_lists_.emplace_back(memory_free_list); | |||
| ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &(memory_free_lists_.back()), | |||
| device_contexts_[0], context); | |||
| } | |||
| } | |||
| } // namespace runtime | |||
| } // namespace mindspore | |||
| @@ -23,6 +23,7 @@ | |||
| #include <stack> | |||
| #include <queue> | |||
| #include <set> | |||
| #include <algorithm> | |||
| #include "utils/hash_map.h" | |||
| #include "runtime/framework/actor/actor_common.h" | |||
| #include "runtime/framework/actor/control_flow/control_actor.h" | |||
| @@ -33,9 +34,10 @@ namespace runtime { | |||
| // the data to the corresponding actor. It is the entry point for subgraph execution. | |||
| class EntranceActor : public ControlActor { | |||
| public: | |||
| EntranceActor(const std::string &name, const std::vector<KernelWithIndex> ¶meters, | |||
| EntranceActor(const std::string &name, const AID &memory_manager_aid, const std::vector<KernelWithIndex> ¶meters, | |||
| const std::set<KernelWithIndex> &call_nodes, const AnfNodePtr &node) | |||
| : ControlActor(name, KernelTransformType::kEntranceActor, parameters, node), call_nodes_(call_nodes) { | |||
| : ControlActor(name, KernelTransformType::kEntranceActor, memory_manager_aid, parameters, node), | |||
| call_nodes_(call_nodes) { | |||
| device_contexts_.resize(parameters.size()); | |||
| input_device_tensors_.resize(parameters.size()); | |||
| } | |||
| @@ -56,6 +58,7 @@ class EntranceActor : public ControlActor { | |||
| void FetchInput(OpContext<DeviceTensor> *const context) override; | |||
| bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const override; | |||
| void EraseInput(const OpContext<DeviceTensor> *const context) override; | |||
| void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) override; | |||
| private: | |||
| friend class ControlNodeScheduler; | |||
| @@ -91,6 +91,33 @@ void ExitActor::SendOutput(OpContext<DeviceTensor> *const context) { | |||
| } | |||
| } | |||
| void ExitActor::IncreaseDynamicRefCounts(OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| ControlActor::IncreaseDynamicRefCounts(context); | |||
| // Increase dynamic ref count by the output data in output branch. | |||
| if (output_branch_data_.count(output_branch_id_) > 0) { | |||
| for (auto &output_data : output_branch_data_[output_branch_id_]) { | |||
| MS_EXCEPTION_IF_NULL(output_data.second); | |||
| IncreaseDynamicRefCount(output_data.second.get()); | |||
| } | |||
| } | |||
| // Increase dynamic ref count by the output partial in output branch. | |||
| if (output_branch_partial_arrows_.count(output_branch_id_) > 0) { | |||
| for (const auto &partial_arrow : output_branch_partial_arrows_[output_branch_id_]) { | |||
| MS_EXCEPTION_IF_NULL(partial_arrow); | |||
| if (IntToSize(partial_arrow->from_output_index_) >= input_partials_.size()) { | |||
| std::string error_info = "Invalid partial input:" + std::to_string(partial_arrow->from_output_index_) + | |||
| " current:" + std::to_string(input_partials_.size()) + " for actor:" + GetAID().Name(); | |||
| SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info); | |||
| } | |||
| auto output_partial = input_partials_[partial_arrow->from_output_index_]; | |||
| IncreaseDynamicRefCount(output_partial); | |||
| } | |||
| } | |||
| } | |||
| void ExitActor::CopyDeviceAddress(OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| // If node is not empty, it is the exit of funcgraph, no need to create device address. | |||
| @@ -110,26 +137,45 @@ void ExitActor::CopyDeviceAddress(OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(input_device_tensor); | |||
| const KernelWithIndex &node_with_index = input_device_tensor->GetNodeIndex(); | |||
| MS_EXCEPTION_IF_NULL(node_with_index.first); | |||
| // If the address ptr can't be changed, does not need to copy a new device tensor. | |||
| if ((!is_need_copy_device_tensors_[i]) || input_device_tensor->is_ptr_persisted()) { | |||
| if (!is_need_copy_device_tensors_[i]) { | |||
| new_device_tensors.emplace_back(input_device_tensor); | |||
| continue; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(device_contexts_[i]); | |||
| auto new_device_tensor = | |||
| device_contexts_[i]->CreateDeviceAddress(nullptr, input_device_tensors_[i]->GetSize(), | |||
| input_device_tensors_[i]->format(), input_device_tensors_[i]->type_id()); | |||
| // Create the new device tensor to take over the input_device_tensors which are the outputs of kernel graphs. | |||
| auto new_device_tensor = device_contexts_[i]->CreateDeviceAddress( | |||
| nullptr, input_device_tensor->GetSize(), input_device_tensor->format(), input_device_tensor->type_id()); | |||
| MS_EXCEPTION_IF_NULL(new_device_tensor); | |||
| new_device_tensor->set_ptr(input_device_tensor->GetMutablePtr()); | |||
| new_device_tensor->set_from_mem_pool(input_device_tensor->from_mem_pool()); | |||
| created_device_tensors_.emplace_back(new_device_tensor); | |||
| new_device_tensors.emplace_back(new_device_tensor.get()); | |||
| new_device_tensor->SetNodeIndex(node_with_index.first, node_with_index.second); | |||
| new_device_tensor->set_from_persistent_mem(input_device_tensor->from_persistent_mem()); | |||
| // The device address which is created by actor uses the dynamic ref count. | |||
| new_device_tensor->set_dynamic_ref_conut(0); | |||
| new_device_tensor->set_original_ref_count(SIZE_MAX); | |||
| new_device_tensor->ResetRefCount(); | |||
| new_device_tensors.emplace_back(new_device_tensor.get()); | |||
| created_device_tensors_.emplace_back(new_device_tensor); | |||
| input_device_tensor->set_ptr(nullptr); | |||
| input_device_tensor->set_from_mem_pool(false); | |||
| // If the address ptr can't be changed, then alloc the new device memory and copy the data. | |||
| if (input_device_tensor->is_ptr_persisted()) { | |||
| if (!device_contexts_[i]->AllocateMemory(new_device_tensor.get(), new_device_tensor->GetSize())) { | |||
| SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(GraphExecutionStrategy::kPipeline, *context, *device_contexts_[i], | |||
| GetAID().Name(), new_device_tensor->GetSize()); | |||
| } | |||
| if (!new_device_tensor->SyncDeviceToDevice( | |||
| trans::GetRuntimePaddingShape(node_with_index.first, node_with_index.second), | |||
| input_device_tensor->GetSize(), input_device_tensor->type_id(), input_device_tensor->GetPtr(), | |||
| input_device_tensor->format())) { | |||
| SET_OPCONTEXT_FAIL_RET_WITH_ERROR(*context, "Sync device to device failed."); | |||
| } | |||
| } else { | |||
| // Move the device ptr from input_device_tensor to new_device_tensor. | |||
| new_device_tensor->set_ptr(input_device_tensor->GetMutablePtr()); | |||
| new_device_tensor->set_from_mem_pool(input_device_tensor->from_mem_pool()); | |||
| input_device_tensor->set_ptr(nullptr); | |||
| input_device_tensor->set_from_mem_pool(false); | |||
| } | |||
| } | |||
| input_device_tensors_.swap(new_device_tensors); | |||
| @@ -32,8 +32,9 @@ namespace runtime { | |||
| // device tensors in the data to the corresponding actor. It is the exit of the end of kernel graph execution. | |||
| class ExitActor : public ControlActor { | |||
| public: | |||
| ExitActor(const std::string &name, const std::vector<KernelWithIndex> ¶meters, const AnfNodePtr &node) | |||
| : ControlActor(name, KernelTransformType::kExitActor, parameters, node) { | |||
| ExitActor(const std::string &name, const AID &memory_manager_aid, const std::vector<KernelWithIndex> ¶meters, | |||
| const AnfNodePtr &node) | |||
| : ControlActor(name, KernelTransformType::kExitActor, memory_manager_aid, parameters, node) { | |||
| device_contexts_.resize(parameters.size()); | |||
| input_device_tensors_.resize(parameters.size()); | |||
| } | |||
| @@ -54,6 +55,7 @@ class ExitActor : public ControlActor { | |||
| protected: | |||
| void FetchInput(OpContext<DeviceTensor> *const context) override; | |||
| void SendOutput(OpContext<DeviceTensor> *const context) override; | |||
| void IncreaseDynamicRefCounts(OpContext<DeviceTensor> *const context) override; | |||
| private: | |||
| friend class ControlNodeScheduler; | |||
| @@ -19,9 +19,9 @@ | |||
| namespace mindspore { | |||
| namespace runtime { | |||
| GatherActor::GatherActor(const std::string &name, const std::vector<KernelWithIndex> ¶meters, | |||
| const AnfNodePtr &node) | |||
| : ControlActor(name, KernelTransformType::kGatherActor, parameters, node) { | |||
| GatherActor::GatherActor(const std::string &name, const AID &memory_manager_aid, | |||
| const std::vector<KernelWithIndex> ¶meters, const AnfNodePtr &node) | |||
| : ControlActor(name, KernelTransformType::kGatherActor, memory_manager_aid, parameters, node) { | |||
| device_contexts_.resize(parameters.size()); | |||
| } | |||
| @@ -50,34 +50,41 @@ void GatherActor::FetchInput(OpContext<DeviceTensor> *const context) { | |||
| } | |||
| } | |||
| void GatherActor::FetchOutput(OpRealParameterWithBranchID *const output, OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(output); | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| output->branch_id_ = output_branch_id_; | |||
| output->device_tensors_ = input_partials_[0]->device_tensors_; | |||
| output->partials_ = input_partials_[0]->partials_; | |||
| // The first input of gather actor is the target funcgraph, which will not be sent to the entrance actor as | |||
| a real parameter, so the subsequent index needs to be reduced by one. | |||
| for (auto &device_tensor : output->device_tensors_) { | |||
| if (device_tensor.first == 0) { | |||
| std::string error_info = | |||
| "Invalid device tensor index:" + std::to_string(device_tensor.first) + " for actor:" + GetAID().Name(); | |||
| SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info); | |||
| } | |||
| device_tensor.first--; | |||
| } | |||
| for (auto &partial : output->partials_) { | |||
| if (partial.first == 0) { | |||
| std::string error_info = | |||
| "Invalid partial index:" + std::to_string(partial.first) + " for actor:" + GetAID().Name(); | |||
| SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info); | |||
| } | |||
| partial.first--; | |||
| } | |||
| } | |||
| void GatherActor::SendOutput(OpContext<DeviceTensor> *const context) { | |||
| // Send data with branch id. | |||
| const auto &iter = output_data_with_branch_id_arrows_.find(input_partials_[0]->func_graph_); | |||
| if (iter != output_data_with_branch_id_arrows_.end()) { | |||
| // Build the output data struct. | |||
| OpRealParameterWithBranchID output; | |||
| output.branch_id_ = output_branch_id_; | |||
| output.device_tensors_ = input_partials_[0]->device_tensors_; | |||
| output.partials_ = input_partials_[0]->partials_; | |||
| FetchOutput(&output, context); | |||
| // The first input of gather actor is the target funcgraph, which will not be sent to the entrance actor as | |||
| // a real parameter, so the subsequent index needs to be reduced by one. | |||
| for (auto &device_tensor : output.device_tensors_) { | |||
| if (device_tensor.first == 0) { | |||
| std::string error_info = | |||
| "Invalid device tensor index:" + std::to_string(device_tensor.first) + " for actor:" + GetAID().Name(); | |||
| SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info); | |||
| } | |||
| device_tensor.first--; | |||
| } | |||
| for (auto &partial : output.partials_) { | |||
| if (partial.first == 0) { | |||
| std::string error_info = | |||
| "Invalid partial index:" + std::to_string(partial.first) + " for actor:" + GetAID().Name(); | |||
| SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info); | |||
| } | |||
| partial.first--; | |||
| } | |||
| for (const auto &data_with_branch_id_arrow : iter->second) { | |||
| ActorDispatcher::Send(data_with_branch_id_arrow, &EntranceActor::RunOpRealParameterWithBranchID, output, context); | |||
| } | |||
| @@ -86,5 +93,22 @@ void GatherActor::SendOutput(OpContext<DeviceTensor> *const context) { | |||
| // Control arrow needs to be sent after the real parameter data and branch id. | |||
| ControlActor::SendOutput(context); | |||
| } | |||
| void GatherActor::IncreaseDynamicRefCounts(OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| ControlActor::IncreaseDynamicRefCounts(context); | |||
| // Increase dynamic ref count by the output data with branch id. | |||
| const auto &iter = output_data_with_branch_id_arrows_.find(input_partials_[0]->func_graph_); | |||
| if (iter != output_data_with_branch_id_arrows_.end()) { | |||
| // Build the output data struct. | |||
| OpRealParameterWithBranchID output; | |||
| FetchOutput(&output, context); | |||
| for (size_t i = 0; i < iter->second.size(); ++i) { | |||
| IncreaseDynamicRefCount(output); | |||
| } | |||
| } | |||
| } | |||
| } // namespace runtime | |||
| } // namespace mindspore | |||
| @@ -34,7 +34,8 @@ namespace runtime { | |||
| // together and sent to the subgraph. | |||
| class GatherActor : public ControlActor { | |||
| public: | |||
| GatherActor(const std::string &name, const std::vector<KernelWithIndex> ¶meters, const AnfNodePtr &node); | |||
| GatherActor(const std::string &name, const AID &memory_manager_aid, const std::vector<KernelWithIndex> ¶meters, | |||
| const AnfNodePtr &node); | |||
| ~GatherActor() override = default; | |||
| const mindspore::HashMap<FuncGraph *, std::vector<AID>> &output_data_with_branch_id_arrows() const { | |||
| return output_data_with_branch_id_arrows_; | |||
| @@ -43,10 +44,13 @@ class GatherActor : public ControlActor { | |||
| protected: | |||
| void FetchInput(OpContext<DeviceTensor> *const context) override; | |||
| void SendOutput(OpContext<DeviceTensor> *const context) override; | |||
| void IncreaseDynamicRefCounts(OpContext<DeviceTensor> *const context) override; | |||
| private: | |||
| friend class ControlNodeScheduler; | |||
| void FetchOutput(OpRealParameterWithBranchID *const output, OpContext<DeviceTensor> *const context); | |||
| // There will be multiple output branches for gather actor according to the funcgraph in partial. | |||
| mindspore::HashMap<FuncGraph *, std::vector<AID>> output_data_with_branch_id_arrows_; | |||
| }; | |||
| @@ -20,8 +20,9 @@ | |||
| namespace mindspore { | |||
| namespace runtime { | |||
| StackActor::StackActor(const std::string &name, const std::vector<KernelWithIndex> ¶meters) | |||
| : ControlActor(name, KernelTransformType::kStackActor, parameters, nullptr) { | |||
| StackActor::StackActor(const std::string &name, const AID &memory_manager_aid, | |||
| const std::vector<KernelWithIndex> ¶meters) | |||
| : ControlActor(name, KernelTransformType::kStackActor, memory_manager_aid, parameters, nullptr) { | |||
| input_device_tensors_.resize(parameters.size()); | |||
| } | |||
| @@ -77,7 +78,7 @@ void StackActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<Dev | |||
| // The parameters from the inside of the subgraph need to be put into the stack. | |||
| if (IntToSize(input_data->index_) < input_stack_data_num_ + device_tensor_store_keys_.size() + | |||
| input_stack_partials_num_ + local_device_tensors_.size()) { | |||
| FillStack(input_data, context); | |||
| input_stack_data_[context->sequential_num_][input_data->index_].push(input_data->data_); | |||
| } else { | |||
| // The outputs of call nodes are placed directly in the input data. | |||
| input_op_datas_[context->sequential_num_].emplace_back(input_data); | |||
| @@ -129,47 +130,6 @@ void StackActor::RunOpPartial(OpPartialPtr partial, size_t position, OpContext<D | |||
| } | |||
| } | |||
| void StackActor::FillStack(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| MS_EXCEPTION_IF_NULL(input_data); | |||
| auto &input_device_tensor = input_data->data_; | |||
| MS_EXCEPTION_IF_NULL(input_device_tensor); | |||
| auto &sequential_num = context->sequential_num_; | |||
| size_t index = IntToSize(input_data->index_); | |||
| if (index >= device_contexts_.size()) { | |||
| SET_OPCONTEXT_FAIL_RET_WITH_ERROR(*context, "The index is out of range."); | |||
| } | |||
| // 1.If device context is empty, it means that the input is from a parameter and does not need copy new device tensor. | |||
| // 2.If the address ptr can be changed, it has been copied by exit actor and does not need copy a new device tensor. | |||
| if ((device_contexts_[index] == nullptr) || (!input_device_tensor->is_ptr_persisted())) { | |||
| input_stack_data_[sequential_num][input_data->index_].push(input_device_tensor); | |||
| } else { | |||
| const KernelWithIndex &node_with_index = input_device_tensor->GetNodeIndex(); | |||
| MS_EXCEPTION_IF_NULL(node_with_index.first); | |||
| // Create the new device tensor and copy the data from the input data. | |||
| auto new_device_tensor = device_contexts_[index]->CreateDeviceAddress( | |||
| nullptr, input_device_tensor->GetSize(), input_device_tensor->format(), input_device_tensor->type_id()); | |||
| MS_EXCEPTION_IF_NULL(new_device_tensor); | |||
| if (!device_contexts_[index]->AllocateMemory(new_device_tensor.get(), new_device_tensor->GetSize())) { | |||
| SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(GraphExecutionStrategy::kPipeline, *context, *device_contexts_[index], | |||
| GetAID().Name(), new_device_tensor->GetSize()); | |||
| } | |||
| if (!new_device_tensor->SyncDeviceToDevice( | |||
| trans::GetRuntimePaddingShape(node_with_index.first, node_with_index.second), input_device_tensor->GetSize(), | |||
| input_device_tensor->type_id(), input_device_tensor->GetPtr(), input_device_tensor->format())) { | |||
| SET_OPCONTEXT_FAIL_RET_WITH_ERROR(*context, "Sync device to device failed."); | |||
| } | |||
| new_device_tensor->SetNodeIndex(node_with_index.first, node_with_index.second); | |||
| new_device_tensor->set_original_ref_count(SIZE_MAX); | |||
| new_device_tensor->ResetRefCount(); | |||
| created_device_tensors_.emplace_back(new_device_tensor); | |||
| input_stack_data_[sequential_num][input_data->index_].push(new_device_tensor.get()); | |||
| } | |||
| } | |||
| bool StackActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) const { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| if (!ControlActor::CheckRunningCondition(context)) { | |||
| @@ -331,5 +291,52 @@ void StackActor::EraseInput(const OpContext<DeviceTensor> *const context) { | |||
| } | |||
| } | |||
| } | |||
| void StackActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| const auto &sequential_num = context->sequential_num_; | |||
| // Collect the input device tensors. | |||
| std::vector<DeviceTensor *> memory_free_list; | |||
| if (input_op_datas_.count(sequential_num) > 0) { | |||
| for (auto &input_data : input_op_datas_[sequential_num]) { | |||
| MS_EXCEPTION_IF_NULL(input_data); | |||
| MS_EXCEPTION_IF_NULL(input_data->data_); | |||
| memory_free_list.emplace_back(input_data->data_); | |||
| } | |||
| } | |||
| if (input_op_partials_.count(sequential_num) > 0) { | |||
| for (auto &input_partial_pair : input_op_partials_[sequential_num]) { | |||
| auto partial_device_tensors = GetAllDeviceTensors(input_partial_pair.second); | |||
| (void)std::copy(partial_device_tensors.begin(), partial_device_tensors.end(), | |||
| std::back_inserter(memory_free_list)); | |||
| } | |||
| } | |||
| if ((input_stack_data_num_ != 0) && (input_stack_data_.count(sequential_num) > 0)) { | |||
| for (auto &stack_data_pair : input_stack_data_[sequential_num]) { | |||
| if (!stack_data_pair.second.empty()) { | |||
| memory_free_list.emplace_back(stack_data_pair.second.top()); | |||
| } | |||
| } | |||
| } | |||
| if ((input_stack_partials_num_ != 0) && (input_stack_partials_.count(sequential_num) > 0)) { | |||
| for (auto &stack_partial_pair : input_stack_partials_[sequential_num]) { | |||
| if (!stack_partial_pair.second.empty()) { | |||
| auto partial_device_tensors = GetAllDeviceTensors(stack_partial_pair.second.top()); | |||
| (void)std::copy(partial_device_tensors.begin(), partial_device_tensors.end(), | |||
| std::back_inserter(memory_free_list)); | |||
| } | |||
| } | |||
| } | |||
| if (memory_free_list.size() > 0) { | |||
| memory_free_lists_.emplace_back(memory_free_list); | |||
| ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &(memory_free_lists_.back()), | |||
| device_contexts_[0], context); | |||
| } | |||
| } | |||
| } // namespace runtime | |||
| } // namespace mindspore | |||
| @@ -22,6 +22,7 @@ | |||
| #include <memory> | |||
| #include <stack> | |||
| #include <set> | |||
| #include <algorithm> | |||
| #include "utils/hash_map.h" | |||
| #include "runtime/framework/actor/actor_common.h" | |||
| #include "runtime/framework/actor/control_flow/control_actor.h" | |||
| @@ -36,7 +37,7 @@ namespace runtime { | |||
| // 4. Send output. | |||
| class StackActor : public ControlActor { | |||
| public: | |||
| StackActor(const std::string &name, const std::vector<KernelWithIndex> ¶meters); | |||
| StackActor(const std::string &name, const AID &memory_manager_aid, const std::vector<KernelWithIndex> ¶meters); | |||
| ~StackActor() override = default; | |||
| void Init() override; | |||
| @@ -50,15 +51,11 @@ class StackActor : public ControlActor { | |||
| void FetchInput(OpContext<DeviceTensor> *const context) override; | |||
| bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const override; | |||
| void EraseInput(const OpContext<DeviceTensor> *const context) override; | |||
| void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) override; | |||
| private: | |||
| friend class ControlNodeScheduler; | |||
| void FillStack(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context); | |||
| // The device tensors created and stored by the stack. | |||
| std::vector<DeviceTensorPtr> created_device_tensors_; | |||
| // The input data and partials records that the stack actor is copied from the input nodes and needs to be | |||
| // stored in the device tensor in the stack. | |||
| mindspore::HashMap<int, mindspore::HashMap<size_t, std::stack<DeviceTensor *>>> input_stack_data_; | |||
| @@ -24,9 +24,9 @@ namespace runtime { | |||
| constexpr size_t kMaxSwitchCondSize = 8; | |||
| constexpr size_t kSwitchDefaultOutputNum = 1; | |||
| SwitchActor::SwitchActor(const std::string &name, const std::vector<KernelWithIndex> ¶meters, | |||
| const AnfNodePtr &node) | |||
| : ControlActor(name, KernelTransformType::kSwitchActor, parameters, node) { | |||
| SwitchActor::SwitchActor(const std::string &name, const AID &memory_manager_aid, | |||
| const std::vector<KernelWithIndex> ¶meters, const AnfNodePtr &node) | |||
| : ControlActor(name, KernelTransformType::kSwitchActor, memory_manager_aid, parameters, node) { | |||
| device_contexts_.resize(parameters.size()); | |||
| output_data_by_output_index_.resize(kSwitchDefaultOutputNum); | |||
| } | |||
| @@ -33,7 +33,8 @@ using mindspore::session::KernelWithIndex; | |||
| // Switch and SwitchLayer node will be converted to switch actor. | |||
| class SwitchActor : public ControlActor { | |||
| public: | |||
| SwitchActor(const std::string &name, const std::vector<KernelWithIndex> ¶meters, const AnfNodePtr &node); | |||
| SwitchActor(const std::string &name, const AID &memory_manager_aid, const std::vector<KernelWithIndex> ¶meters, | |||
| const AnfNodePtr &node); | |||
| ~SwitchActor() override = default; | |||
| void Init() override; | |||
| @@ -17,11 +17,49 @@ | |||
| #include "runtime/framework/actor/memory_manager_actor.h" | |||
| #include "runtime/framework/actor/data_source_actor.h" | |||
| #include "runtime/framework/actor/kernel_actor.h" | |||
| #include "runtime/hardware/device_context_manager.h" | |||
| #include "mindrt/include/async/async.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| namespace runtime { | |||
| namespace { | |||
| void FreeMemoryInner(DeviceTensor *const device_tensor, const DeviceContext *device_context) { | |||
| MS_EXCEPTION_IF_NULL(device_tensor); | |||
| // The device context may be not accurate in the control flow scene, so need fetch by device name and device id. | |||
| if ((device_context == nullptr) || (device_context->GetDeviceAddressType() != device_tensor->DeviceType())) { | |||
| const auto &new_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext( | |||
| {device_tensor->device_name(), device_tensor->device_id()}); | |||
| MS_EXCEPTION_IF_NULL(new_device_context); | |||
| new_device_context->FreeMemory(device_tensor); | |||
| } else { | |||
| device_context->FreeMemory(device_tensor); | |||
| } | |||
| } | |||
| // Only one of the static and dynamic reference counts will take effect. | |||
| void FreeMemoryByRefCount(DeviceTensor *const device_tensor, const DeviceContext *device_context) { | |||
| MS_EXCEPTION_IF_NULL(device_tensor); | |||
| if (device_tensor->original_ref_count() != SIZE_MAX) { | |||
| // The static reference count is decremented to zero to free memory, and reset to the original count. | |||
| device_tensor->DecreaseRefCount(); | |||
| if (device_tensor->ref_count() == 0) { | |||
| if (device_tensor->GetPtr() != nullptr) { | |||
| FreeMemoryInner(device_tensor, device_context); | |||
| } | |||
| device_tensor->ResetRefCount(); | |||
| } | |||
| } else if (device_tensor->dynamic_ref_conut() != INT32_MAX) { | |||
| // The dynamic reference count is decremented to zero to free memory. | |||
| device_tensor->DecreaseDynamicRefCount(); | |||
| if ((device_tensor->dynamic_ref_conut() == 0) && (device_tensor->GetPtr() != nullptr)) { | |||
| MS_LOG(DEBUG) << "Free memory by the dynamic reference count, device address" << device_tensor->GetPtr(); | |||
| FreeMemoryInner(device_tensor, device_context); | |||
| } | |||
| } | |||
| } | |||
| } // namespace | |||
| void MemoryManagerActor::AllocateMemory(const std::vector<DeviceTensor *> *alloc_list, | |||
| const DeviceContext *device_context, OpContext<DeviceTensor> *const op_context, | |||
| const AID &from_aid) { | |||
| @@ -113,21 +151,8 @@ void MemoryManagerActor::AllocateBatchMemory(const std::vector<DeviceTensor *> * | |||
| void MemoryManagerActor::FreeMemory(const std::vector<DeviceTensor *> *free_list, const DeviceContext *device_context, | |||
| OpContext<DeviceTensor> *) { | |||
| MS_EXCEPTION_IF_NULL(free_list); | |||
| MS_EXCEPTION_IF_NULL(device_context); | |||
| for (auto &device_tensor : *free_list) { | |||
| MS_EXCEPTION_IF_NULL(device_tensor); | |||
| if (device_tensor->original_ref_count() == SIZE_MAX) { | |||
| continue; | |||
| } | |||
| // The reference count is decremented to zero to free memory, and reset to the original count. | |||
| device_tensor->DecreaseRefCount(); | |||
| if (device_tensor->ref_count() == 0) { | |||
| // Free memory through the device context. | |||
| if (device_tensor->GetPtr() != nullptr) { | |||
| device_context->FreeMemory(device_tensor); | |||
| } | |||
| device_tensor->ResetRefCount(); | |||
| } | |||
| FreeMemoryByRefCount(device_tensor, device_context); | |||
| } | |||
| } | |||
| @@ -145,20 +170,7 @@ void MemoryManagerActor::FreeBatchMemory(const std::vector<DeviceTensor *> *free | |||
| for (size_t i = 0; i < (*free_list).size(); ++i) { | |||
| auto &device_tensor = (*free_list)[i]; | |||
| auto &device_context = (*device_contexts)[i]; | |||
| MS_EXCEPTION_IF_NULL(device_tensor); | |||
| MS_EXCEPTION_IF_NULL(device_context); | |||
| if (device_tensor->original_ref_count() == SIZE_MAX) { | |||
| continue; | |||
| } | |||
| // The reference count is decremented to zero to free memory, and reset to the original count. | |||
| device_tensor->DecreaseRefCount(); | |||
| if (device_tensor->ref_count() == 0) { | |||
| // Free memory through the device context. | |||
| if (device_tensor->GetPtr() != nullptr) { | |||
| device_context->FreeMemory(device_tensor); | |||
| } | |||
| device_tensor->ResetRefCount(); | |||
| } | |||
| FreeMemoryByRefCount(device_tensor, device_context); | |||
| } | |||
| } | |||
| @@ -123,7 +123,7 @@ TensorPtr OutputActor::CreateOutputTensor(const AnfNodePtr &output_node, size_t | |||
| if (device_context->GetDeviceAddressType() != device_tensor->DeviceType()) { | |||
| auto old_device_context = device_context; | |||
| device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext( | |||
| {device_tensor->DeviceName(), device_tensor->DeviceID()}); | |||
| {device_tensor->device_name(), device_tensor->device_id()}); | |||
| MS_LOG(INFO) << "Update device context from:" << old_device_context->GetDeviceAddressType() | |||
| << " to:" << device_context->GetDeviceAddressType(); | |||
| } | |||
| @@ -169,6 +169,7 @@ void OutputActor::UpdateOutputDeviceAddress() { | |||
| tensor_device_address->ResetRefCount(); | |||
| auto node_with_index = device_tensor->GetNodeIndex(); | |||
| tensor_device_address->SetNodeIndex(node_with_index.first, node_with_index.second); | |||
| tensor_device_address->set_from_persistent_mem(device_tensor->from_persistent_mem()); | |||
| // The outputs may have the same output node, so need skip when the node has been done. | |||
| if (device_tensor->GetPtr() == nullptr) { | |||
| continue; | |||
| @@ -16,6 +16,7 @@ | |||
| #include "runtime/framework/actor/super_kernel_actor.h" | |||
| #include "runtime/framework/actor/output_actor.h" | |||
| #include "runtime/framework/actor/memory_manager_actor.h" | |||
| #include "mindrt/include/async/async.h" | |||
| #include "utils/log_adapter.h" | |||
| @@ -164,5 +165,28 @@ bool SuperKernelActor::CopyInputData(const OpContext<DeviceTensor> *context) { | |||
| return true; | |||
| } | |||
| void SuperKernelActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| const auto &sequential_num = context->sequential_num_; | |||
| // Collect the input device tensors. | |||
| std::vector<DeviceTensor *> memory_free_list; | |||
| if (input_op_datas_.count(sequential_num) > 0) { | |||
| for (auto &input_data : input_op_datas_[sequential_num]) { | |||
| MS_EXCEPTION_IF_NULL(input_data); | |||
| MS_EXCEPTION_IF_NULL(input_data->data_); | |||
| if (input_data->data_->dynamic_ref_conut() != INT32_MAX) { | |||
| memory_free_list.emplace_back(input_data->data_); | |||
| } | |||
| } | |||
| } | |||
| if (memory_free_list.size() > 0) { | |||
| memory_free_lists_.emplace_back(memory_free_list); | |||
| ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &(memory_free_lists_.back()), | |||
| device_contexts_[0], context); | |||
| } | |||
| } | |||
| } // namespace runtime | |||
| } // namespace mindspore | |||
| @@ -21,6 +21,7 @@ | |||
| #include <memory> | |||
| #include <map> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "runtime/framework/actor/debug_aware_actor.h" | |||
| #include "runtime/framework/actor/actor_common.h" | |||
| #include "runtime/hardware/device_context.h" | |||
| @@ -50,6 +51,8 @@ class SuperKernelActor : public DebugAwareActor { | |||
| protected: | |||
| void Run(OpContext<DeviceTensor> *const context) override; | |||
| // The input may come from the control actor, so need free the input memory by the dynamic ref count. | |||
| void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) override; | |||
| private: | |||
| friend class GraphScheduler; | |||
| @@ -59,6 +62,9 @@ class SuperKernelActor : public DebugAwareActor { | |||
| KernelGraphPtr graph_; | |||
| std::map<AnfNodePtr, DeviceAddress *> ref_node_addr_map_; | |||
| // The lists of device tensors which need free by dynamic ref count, will be cleared at the end of step. | |||
| std::vector<std::vector<DeviceTensor *>> memory_free_lists_; | |||
| }; | |||
| using SuperKernelActorPtr = std::shared_ptr<SuperKernelActor>; | |||
| @@ -83,12 +83,14 @@ bool IsControlFlowArrow(const ControlNodeParserPtr &parser, const KernelGraphPtr | |||
| } | |||
| } // namespace | |||
| ControlActorSetPtr ControlNodeScheduler::Build(const GraphCompilerInfo &graph_compiler_info) { | |||
| ControlActorSetPtr ControlNodeScheduler::Build(const GraphCompilerInfo &graph_compiler_info, | |||
| const AID &memory_manager_aid) { | |||
| const auto &control_nodes = graph_compiler_info.control_nodes_; | |||
| if (control_nodes.size() <= kSingleControlNode) { | |||
| return nullptr; | |||
| } | |||
| memory_manager_aid_ = memory_manager_aid; | |||
| ControlActorSetPtr control_actors = std::make_shared<ControlActorSet>(); | |||
| control_actors->switch_actors_ = BuildSwitchActor(graph_compiler_info); | |||
| control_actors->gather_actors_ = BuildGatherActor(graph_compiler_info); | |||
| @@ -108,7 +110,8 @@ std::vector<SwitchActorPtr> ControlNodeScheduler::BuildSwitchActor(const GraphCo | |||
| AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) { | |||
| const auto &actor_name = GetActorName(control_node); | |||
| const auto ¶meters = FetchInputNodeByCNode(control_node); | |||
| const auto &switch_actor = std::make_shared<SwitchActor>(actor_name, parameters, control_node); | |||
| const auto &switch_actor = | |||
| std::make_shared<SwitchActor>(actor_name, memory_manager_aid_, parameters, control_node); | |||
| switch_actors.emplace_back(switch_actor); | |||
| InsertActor(switch_actor.get()); | |||
| } | |||
| @@ -127,7 +130,8 @@ std::vector<GatherActorPtr> ControlNodeScheduler::BuildGatherActor(const GraphCo | |||
| if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimPartial) || AnfAlgo::IsCallNode(control_node)) { | |||
| const auto &actor_name = GetActorName(control_node); | |||
| const auto ¶meters = FetchInputNodeByCNode(control_node); | |||
| const auto &gather_actor = std::make_shared<GatherActor>(actor_name, parameters, control_node); | |||
| const auto &gather_actor = | |||
| std::make_shared<GatherActor>(actor_name, memory_manager_aid_, parameters, control_node); | |||
| gather_actors.emplace_back(gather_actor); | |||
| InsertActor(gather_actor.get()); | |||
| @@ -189,7 +193,7 @@ std::vector<EntranceActorPtr> ControlNodeScheduler::BuildEntranceActor(const Gra | |||
| call_nodes = iter->second; | |||
| } | |||
| const auto &entrance_actor = | |||
| std::make_shared<EntranceActor>(actor_name, formal_parameters, call_nodes, control_node); | |||
| std::make_shared<EntranceActor>(actor_name, memory_manager_aid_, formal_parameters, call_nodes, control_node); | |||
| auto context_iter = parser->func_graph_to_device_contexts_.find(func_graph); | |||
| if (context_iter == parser->func_graph_to_device_contexts_.end() || | |||
| context_iter->second.size() < formal_parameters.size()) { | |||
| @@ -202,7 +206,6 @@ std::vector<EntranceActorPtr> ControlNodeScheduler::BuildEntranceActor(const Gra | |||
| entrance_actor->device_contexts_.clear(); | |||
| entrance_actor->device_contexts_.insert(entrance_actor->device_contexts_.begin(), context_iter->second.begin(), | |||
| context_iter->second.begin() + formal_parameters.size()); | |||
| entrance_actors.emplace_back(entrance_actor); | |||
| InsertActor(entrance_actor.get()); | |||
| } | |||
| @@ -225,7 +228,7 @@ std::vector<ExitActorPtr> ControlNodeScheduler::BuildExitActor(const GraphCompil | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| const auto &actor_name = func_graph->ToString() + kExitActorNameSuffix; | |||
| const auto ¶meters = FetchInputNodeByCNode(control_node); | |||
| const auto &exit_actor = std::make_shared<ExitActor>(actor_name, parameters, control_node); | |||
| const auto &exit_actor = std::make_shared<ExitActor>(actor_name, memory_manager_aid_, parameters, control_node); | |||
| auto context_iter = parser->control_node_to_device_contexts_.find(control_node); | |||
| if (context_iter == parser->control_node_to_device_contexts_.end() || | |||
| context_iter->second.size() != parameters.size()) { | |||
| @@ -267,7 +270,7 @@ std::vector<ExitActorPtr> ControlNodeScheduler::BuildExitActor(const GraphCompil | |||
| } | |||
| const auto &actor_name = kernel_graph_group_info->group_name_ + kExitActorNameSuffix; | |||
| const auto &exit_actor = std::make_shared<ExitActor>(actor_name, formal_parameters, nullptr); | |||
| const auto &exit_actor = std::make_shared<ExitActor>(actor_name, memory_manager_aid_, formal_parameters, nullptr); | |||
| exit_actor->is_need_copy_device_tensors_.swap(is_need_copy_device_tensors); | |||
| exit_actor->device_contexts_.swap(device_contexts); | |||
| exit_actors.emplace_back(exit_actor); | |||
| @@ -305,7 +308,7 @@ std::vector<StackActorPtr> ControlNodeScheduler::BuildStackActor(const GraphComp | |||
| } | |||
| } | |||
| const auto &actor_name = kernel_graph_group_info->group_name_ + kStackActorNameSuffix; | |||
| const auto &stack_actor = std::make_shared<StackActor>(actor_name, formal_parameters); | |||
| const auto &stack_actor = std::make_shared<StackActor>(actor_name, memory_manager_aid_, formal_parameters); | |||
| stack_actors.emplace_back(stack_actor); | |||
| stack_actor->device_contexts_.swap(device_contexts); | |||
| stack_actor->input_stack_data_num_ = input_parameter_data_num; | |||
| @@ -379,7 +382,7 @@ void ControlNodeScheduler::BuildStackActorForControlNode(const GraphCompilerInfo | |||
| } | |||
| // Create stack actor. | |||
| const auto &stack_actor_name = GetActorName(need_stack_control_node) + kStackActorNameSuffix; | |||
| const auto &stack_actor = std::make_shared<StackActor>(stack_actor_name, formal_parameters); | |||
| const auto &stack_actor = std::make_shared<StackActor>(stack_actor_name, memory_manager_aid_, formal_parameters); | |||
| stack_actor->device_contexts_ = device_contexts; | |||
| stack_actor->input_stack_data_num_ = input_parameter_data_num; | |||
| stack_actor->input_stack_partials_num_ = input_parameter_partials_num; | |||
| @@ -422,8 +425,29 @@ void ControlNodeScheduler::ClearActorData(const ControlActorSet *control_actor_s | |||
| return; | |||
| } | |||
| for (auto &switch_actor : control_actor_set->switch_actors_) { | |||
| MS_EXCEPTION_IF_NULL(switch_actor); | |||
| switch_actor->memory_free_lists_.clear(); | |||
| } | |||
| for (auto &gather_actor : control_actor_set->gather_actors_) { | |||
| MS_EXCEPTION_IF_NULL(gather_actor); | |||
| gather_actor->memory_free_lists_.clear(); | |||
| } | |||
| for (auto &entrance_actor : control_actor_set->entrance_actors_) { | |||
| MS_EXCEPTION_IF_NULL(entrance_actor); | |||
| entrance_actor->memory_free_lists_.clear(); | |||
| } | |||
| for (auto &stack_actor : control_actor_set->stack_actors_) { | |||
| MS_EXCEPTION_IF_NULL(stack_actor); | |||
| stack_actor->memory_free_lists_.clear(); | |||
| } | |||
| for (auto &exit_actor : control_actor_set->exit_actors_) { | |||
| MS_EXCEPTION_IF_NULL(exit_actor); | |||
| exit_actor->memory_free_lists_.clear(); | |||
| exit_actor->created_device_tensors_.clear(); | |||
| } | |||
| } | |||
| @@ -37,7 +37,7 @@ class ControlNodeScheduler { | |||
| DISABLE_COPY_AND_ASSIGN(ControlNodeScheduler); | |||
| // Transform the control nodes to control actors. | |||
| ControlActorSetPtr Build(const GraphCompilerInfo &graph_compiler_info); | |||
| ControlActorSetPtr Build(const GraphCompilerInfo &graph_compiler_info, const AID &memory_manager_aid); | |||
| // Link control actors. | |||
| void Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info); | |||
| @@ -106,6 +106,9 @@ class ControlNodeScheduler { | |||
| void LinkPartialArrowForExitActor(ExitActor *const exit_actor, ControlActor *const to_actor, size_t from_index, | |||
| size_t to_index, int branch_id); | |||
| bool IsNoInputActor(const ControlActor *control_actor); | |||
| // The id of memory manager actor. | |||
| AID memory_manager_aid_; | |||
| }; | |||
| } // namespace runtime | |||
| } // namespace mindspore | |||
| @@ -217,6 +217,12 @@ void GraphScheduler::Clear() { | |||
| void GraphScheduler::ClearActorData(const ActorSet *actor_set) { | |||
| MS_EXCEPTION_IF_NULL(actor_set); | |||
| for (auto &super_kernel_actor : actor_set->super_kernel_actors_) { | |||
| MS_EXCEPTION_IF_NULL(super_kernel_actor); | |||
| super_kernel_actor->memory_free_lists_.clear(); | |||
| } | |||
| control_node_scheduler_.ClearActorData(actor_set->control_actors_.get()); | |||
| } | |||
| @@ -486,7 +492,7 @@ ActorSetPtr GraphScheduler::Build(const GraphCompilerInfo &graph_compiler_info) | |||
| actor_set->output_actor_ = BuildOutputActor(graph_compiler_info); | |||
| actor_set->data_prepare_actor_ = | |||
| BuildDataPrepareActor(graph_compiler_info, actor_set->data_source_actors_, host_queue); | |||
| actor_set->control_actors_ = control_node_scheduler_.Build(graph_compiler_info); | |||
| actor_set->control_actors_ = control_node_scheduler_.Build(graph_compiler_info, memory_manager_aid_); | |||
| return actor_set; | |||
| } | |||