| @@ -183,6 +183,26 @@ void DumpCopyActor(const CopyActor *actor, std::ofstream &ofs) { | |||
| ofs << "\n"; | |||
| } | |||
| // Dump the device tensors registered for the ref formal parameters of a control actor. | |||
| // Writes one summary line holding the number of formal parameter positions, followed by one line per device tensor | |||
| // with its position and the name/debug string of the node the tensor refers to. Writes nothing when there are none. | |||
| void DumpFormalParameterDeviceTensor(const ControlActor *actor, std::ofstream &ofs) { | |||
|   MS_EXCEPTION_IF_NULL(actor); | |||
|   const auto &tensors_by_position = actor->ref_formal_parameter_device_tensors(); | |||
|   if (tensors_by_position.empty()) { | |||
|     // Nothing registered, so the section header is omitted as well. | |||
|     return; | |||
|   } | |||
|   ofs << "\t\tref_formal_parameter_device_tensors:" << tensors_by_position.size() << "\n "; | |||
|   for (auto iter = tensors_by_position.cbegin(); iter != tensors_by_position.cend(); ++iter) { | |||
|     const auto &position = iter->first; | |||
|     for (const auto &tensor : iter->second) { | |||
|       MS_EXCEPTION_IF_NULL(tensor); | |||
|       const auto node_with_index = tensor->GetNodeIndex(); | |||
|       const auto &node = node_with_index.first; | |||
|       MS_EXCEPTION_IF_NULL(node); | |||
|       // The fields are streamed in the same order as a single chained expression would produce. | |||
|       ofs << "\t\t\tref_position:" << position; | |||
|       ofs << "\tref_node_name:" << node->fullname_with_scope(); | |||
|       ofs << "\tref_node_debug_name:" << node->DebugString() << "\n"; | |||
|     } | |||
|   } | |||
| } | |||
| void DumpControlActor(const ControlActor *actor, std::ofstream &ofs) { | |||
| MS_EXCEPTION_IF_NULL(actor); | |||
| DumpAbstractActor(actor, ofs); | |||
| @@ -229,6 +249,8 @@ void DumpControlActor(const ControlActor *actor, std::ofstream &ofs) { | |||
| ofs << "\t\t\tto_actor_name:" << aid.Name() << "\n"; | |||
| } | |||
| } | |||
| DumpFormalParameterDeviceTensor(actor, ofs); | |||
| } | |||
| void DumpSwitchActor(const SwitchActor *actor, std::ofstream &ofs) { | |||
| @@ -309,6 +331,14 @@ void DumpExitActor(const ExitActor *actor, std::ofstream &ofs) { | |||
| } | |||
| } | |||
| } | |||
| const auto &is_need_copy_device_tensors = actor->is_need_copy_device_tensors(); | |||
| if (is_need_copy_device_tensors.size() > 0) { | |||
| ofs << "\t\twhether_need_copy_device_tensors:" << is_need_copy_device_tensors.size() << "\n "; | |||
| for (size_t i = 0; i < is_need_copy_device_tensors.size(); ++i) { | |||
| ofs << "\t\t\tdevice_tensor_position:" << i << "\tis_need_copy:" << is_need_copy_device_tensors[i] << "\n"; | |||
| } | |||
| } | |||
| } | |||
| void DumpStackActor(const StackActor *actor, std::ofstream &ofs) { | |||
| @@ -449,11 +479,11 @@ void DumpCopyActors(const std::vector<CopyActorPtr> &actors, std::ofstream &ofs) | |||
| } | |||
| void DumpControlActors(const ControlActorSetPtr &control_actor_set, std::ofstream &ofs) { | |||
| ofs << "\n\n[Control actors]\n"; | |||
| if (control_actor_set == nullptr) { | |||
| return; | |||
| } | |||
| ofs << "\n\n[Control actors]\n"; | |||
| DumpEntranceActors(control_actor_set->entrance_actors_, ofs); | |||
| DumpSwitchActors(control_actor_set->switch_actors_, ofs); | |||
| DumpGatherActors(control_actor_set->gather_actors_, ofs); | |||
| @@ -333,6 +333,54 @@ void ControlActor::EraseInput(const OpContext<DeviceTensor> *context) { | |||
| } | |||
| } | |||
| // Share the address of an output formal parameter with the device tensors registered as refs of that position. | |||
| // Positions without registered ref device tensors are left untouched. A null address, a ref count other than | |||
| // SIZE_MAX, or a size/format/type mismatch with a registered device tensor is reported through `context`. | |||
| void ControlActor::UpdateOutputData(OpData<DeviceTensor> *const output_data, const DataArrowPtr &data_arrow, | |||
|                                     const AnfNodePtr &, OpContext<DeviceTensor> *const context) { | |||
|   MS_EXCEPTION_IF_NULL(output_data); | |||
|   MS_EXCEPTION_IF_NULL(data_arrow); | |||
|   MS_EXCEPTION_IF_NULL(context); | |||
|   const auto &data = output_data->data_; | |||
|   MS_EXCEPTION_IF_NULL(data); | |||
|   auto formal_parameter_position = data_arrow->from_output_index_; | |||
|   // Look the position up once; there is nothing to update when it has no ref formal parameter. | |||
|   const auto ref_iter = ref_formal_parameter_device_tensors_.find(formal_parameter_position); | |||
|   if (ref_iter == ref_formal_parameter_device_tensors_.end()) { | |||
|     return; | |||
|   } | |||
|   if (data->GetMutablePtr() == nullptr) { | |||
|     std::string error_info = | |||
|       "The address of the " + std::to_string(formal_parameter_position) + " position formal parameter is nullptr."; | |||
|     SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info); | |||
|   } | |||
|   if (data->ref_count() != SIZE_MAX) { | |||
|     std::string error_info = "The ref count of the " + std::to_string(formal_parameter_position) + | |||
|                              " position formal parameter is wrong:" + std::to_string(data->ref_count()); | |||
|     SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info); | |||
|   } | |||
|   // Point every registered device tensor at the address held by data. | |||
|   for (const auto &device_tensor : ref_iter->second) { | |||
|     MS_EXCEPTION_IF_NULL(device_tensor); | |||
|     // Skip tensors that are data itself or that already hold the same address. | |||
|     if ((device_tensor.get() == data) || (device_tensor->GetMutablePtr() == data->GetMutablePtr())) { | |||
|       continue; | |||
|     } | |||
|     auto real_parameter = device_tensor->GetNodeIndex(); | |||
|     MS_EXCEPTION_IF_NULL(real_parameter.first); | |||
|     // The address is only shared between tensors that agree on size, format and type. | |||
|     if ((device_tensor->GetSize() != data->GetSize()) || (device_tensor->format() != data->format()) || | |||
|         (device_tensor->type_id() != data->type_id())) { | |||
|       std::string error_info = | |||
|         "The address of the " + std::to_string(formal_parameter_position) + | |||
|         " position formal parameter can not be set to real parameter:" + real_parameter.first->DebugString(); | |||
|       SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info); | |||
|     } | |||
|     device_tensor->set_ptr(data->GetMutablePtr()); | |||
|     MS_LOG(DEBUG) << "Set the ptr: " << data->GetMutablePtr() | |||
|                   << " for the ref real parameter: " << real_parameter.first->DebugString() | |||
|                   << " in the actor: " << GetAID().Name(); | |||
|   } | |||
| } | |||
| void ControlActor::SendOutput(OpContext<DeviceTensor> *const context) { | |||
| // Send branch id. | |||
| for (const auto &branch_id_arrow : output_branch_id_arrows_) { | |||
| @@ -21,6 +21,7 @@ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <set> | |||
| #include <unordered_map> | |||
| #include <stack> | |||
| #include <queue> | |||
| @@ -70,6 +71,9 @@ class ControlActor : public MemoryAwareActor { | |||
| const std::unordered_map<size_t, OpPartialPtr> &local_partials() const { return local_partials_; } | |||
| const std::vector<AID> &input_partial_arrow_aids() const { return input_partial_arrow_aids_; } | |||
| const std::vector<AID> &input_branch_id_arrow_aids() const { return input_branch_id_arrow_aids_; } | |||
| const std::map<size_t, std::set<DeviceTensorPtr>> &ref_formal_parameter_device_tensors() const { | |||
| return ref_formal_parameter_device_tensors_; | |||
| } | |||
| size_t branch_id() const { return output_branch_id_; } | |||
| protected: | |||
| @@ -89,6 +93,8 @@ class ControlActor : public MemoryAwareActor { | |||
| virtual void FetchInput(OpContext<DeviceTensor> *const context); | |||
| void Run(OpContext<DeviceTensor> *const context) override; | |||
| bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const override; | |||
| void UpdateOutputData(OpData<DeviceTensor> *const output_data, const DataArrowPtr &data_arrow, | |||
| const AnfNodePtr &output_node, OpContext<DeviceTensor> *const context) override; | |||
| void SendOutput(OpContext<DeviceTensor> *const context) override; | |||
| void EraseInput(const OpContext<DeviceTensor> *context) override; | |||
| @@ -144,6 +150,10 @@ class ControlActor : public MemoryAwareActor { | |||
| // Formal parameters for control actor. | |||
| std::vector<KernelWithIndex> formal_parameters_; | |||
| // The device tensors of the backend input nodes corresponding to ref formal parameters; the key is the position | |||
| // index of the formal parameter. Used to update the ptr of these device tensors when the real parameters for ref | |||
| // nodes are received. | |||
| std::map<size_t, std::set<DeviceTensorPtr>> ref_formal_parameter_device_tensors_; | |||
| // local node for control actor, such as return node for exit actor, switch node for switch actor. | |||
| AnfNodePtr node_; | |||
| }; | |||
| @@ -51,6 +51,7 @@ class ExitActor : public ControlActor { | |||
| const mindspore::HashMap<int, std::vector<DataArrowPtr>> &output_branch_partial_arrows() const { | |||
| return output_branch_partial_arrows_; | |||
| } | |||
| const std::vector<bool> &is_need_copy_device_tensors() const { return is_need_copy_device_tensors_; } | |||
| protected: | |||
| void FetchInput(OpContext<DeviceTensor> *const context) override; | |||
| @@ -90,6 +90,26 @@ bool IsControlFlowArrow(const ControlNodeParserPtr &parser, const KernelGraphPtr | |||
| (from_node != nullptr && IsPersistentDeviceTensor(from_node)) || | |||
| (from_node != nullptr && parser->IsSameKernelGraphGroup(from_node, graph)); | |||
| } | |||
| // Decide whether the device tensor of the given backend node output needs to be copied. | |||
| // Returns false for a node that is not a CNode (e.g. a parameter), for a node with a ref abstract, and for an output | |||
| // recorded in the ref output map of its kernel graph; returns true otherwise. | |||
| bool is_need_copy_device_tensor(const AnfNodePtr &backend_node, size_t index) { | |||
|   MS_EXCEPTION_IF_NULL(backend_node); | |||
|   // Short-circuit keeps the original check order: node kind first, then the ref abstract. | |||
|   const bool is_plain_cnode = backend_node->isa<CNode>() && !HasAbstractRef(backend_node); | |||
|   if (!is_plain_cnode) { | |||
|     return false; | |||
|   } | |||
|   const auto graph = FetchKernelGraph(backend_node); | |||
|   MS_EXCEPTION_IF_NULL(graph); | |||
|   return !graph->IsInRefOutputMap({backend_node, index}); | |||
| } | |||
| } // namespace | |||
| ControlActorSetPtr ControlNodeScheduler::Build(const GraphCompilerInfo &graph_compiler_info, | |||
| @@ -274,7 +294,8 @@ std::vector<ExitActorPtr> ControlNodeScheduler::BuildExitActor(const GraphCompil | |||
| // Get the device contexts of the exit actor's cnode inputs. | |||
| const AnfNodePtr &backend_node = node_with_context.second.first.first; | |||
| MS_EXCEPTION_IF_NULL(backend_node); | |||
| is_need_copy_device_tensors.emplace_back(backend_node->isa<CNode>() ? true : false); | |||
| is_need_copy_device_tensors.emplace_back( | |||
| is_need_copy_device_tensor(backend_node, node_with_context.second.first.second)); | |||
| device_contexts.emplace_back(node_with_context.second.second); | |||
| } | |||
| @@ -1155,6 +1176,7 @@ void ControlNodeScheduler::LinkDataArrowByKernelGraph(const KernelGraphPtr &grap | |||
| to_index = super_kernel_actor->FetchInputNodePosition(input); | |||
| (void)sink_input_node_linked.insert(input); | |||
| } | |||
| AddFormalParameterDeviceTensor(from_actor, from_index, input); | |||
| LinkDataArrow(from_actor, to_actor, from_index, to_index); | |||
| } | |||
| } | |||
| @@ -1231,6 +1253,22 @@ void ControlNodeScheduler::LinkArrowForRootGraphEntranceActor(const GraphCompile | |||
| } | |||
| } | |||
| // Register the output device tensor of a ref backend input node on `from_actor`, keyed by the formal parameter | |||
| // position `from_index`. Input nodes without a ref abstract are ignored. | |||
| void ControlNodeScheduler::AddFormalParameterDeviceTensor(ControlActor *const from_actor, size_t from_index, | |||
|                                                           const AnfNodePtr &input_node) { | |||
|   MS_EXCEPTION_IF_NULL(from_actor); | |||
|   MS_EXCEPTION_IF_NULL(input_node); | |||
|   if (HasAbstractRef(input_node)) { | |||
|     // Only output 0 of the input node is considered. | |||
|     constexpr size_t kOutputIndex = 0; | |||
|     const auto ref_tensor = AnfAlgo::GetMutableOutputAddr(input_node, kOutputIndex, false); | |||
|     MS_EXCEPTION_IF_NULL(ref_tensor); | |||
|     auto &tensors_at_position = from_actor->ref_formal_parameter_device_tensors_[from_index]; | |||
|     (void)tensors_at_position.insert(ref_tensor); | |||
|     UpdateRefCount(ref_tensor.get(), true); | |||
|     // Remember which node the tensor belongs to so it can be named later on. | |||
|     ref_tensor->SetNodeIndex(input_node, kOutputIndex); | |||
|   } | |||
| } | |||
| void ControlNodeScheduler::LinkDataArrow(AbstractActor *const from_actor, AbstractActor *const to_actor, | |||
| size_t from_index, size_t to_index, const AnfNodePtr &from_kernel) { | |||
| MS_EXCEPTION_IF_NULL(from_actor); | |||
| @@ -25,6 +25,7 @@ | |||
| #include <map> | |||
| #include <set> | |||
| #include <algorithm> | |||
| #include <queue> | |||
| #include "runtime/framework/actor/actor_set.h" | |||
| #include "runtime/framework/graph_compiler.h" | |||
| @@ -107,6 +108,9 @@ class ControlNodeScheduler { | |||
| size_t to_index, int branch_id); | |||
| bool IsNoInputActor(const ControlActor *control_actor); | |||
| // Fill the device tensors of backend input nodes corresponding to ref formal parameters. | |||
| void AddFormalParameterDeviceTensor(ControlActor *const from_actor, size_t from_index, const AnfNodePtr &input_node); | |||
| // The id of memory manager actor. | |||
| AID memory_manager_aid_; | |||
| }; | |||