| @@ -20,6 +20,31 @@ | |||
| namespace mindspore { | |||
| namespace runtime { | |||
| void AbstractActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| auto &sequential_num = context->sequential_num_; | |||
| (void)input_op_datas_[sequential_num].emplace_back(input_data); | |||
| auto is_run = CheckRunningCondition(context); | |||
| MS_LOG(DEBUG) << "Actor(" << GetAID().Name() << ") receive the input op data and check running condition:" << is_run; | |||
| if (is_run) { | |||
| Run(context); | |||
| } | |||
| } | |||
| void AbstractActor::RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| auto &sequential_num = context->sequential_num_; | |||
| (void)input_op_controls_[sequential_num].emplace_back(input_control); | |||
| auto is_run = CheckRunningCondition(context); | |||
| MS_LOG(DEBUG) << "Actor(" << GetAID().Name() | |||
| << ") receive the input op control and check running condition:" << is_run; | |||
| if (is_run) { | |||
| Run(context); | |||
| } | |||
| } | |||
| bool AbstractActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) const { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| if (input_datas_num_ != 0) { | |||
| @@ -67,30 +92,47 @@ void AbstractActor::EraseInput(const OpContext<DeviceTensor> *context) { | |||
| } | |||
| } | |||
| void AbstractActor::SendOutputResult(OpContext<DeviceTensor> *const context) const { | |||
| void AbstractActor::SendOutput(OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| // Must be the execution order: send result --> send data --> send control, avoid the illegal timing problem. | |||
| // 1.Send graph output result. | |||
| if (output_result_arrows_.size() != output_nodes_.size()) { | |||
| SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The size of output result arrows is not equal to the output nodes."); | |||
| } | |||
| size_t output_node_index = 0; | |||
| for (const auto &result_arrow : output_result_arrows_) { | |||
| MS_EXCEPTION_IF_NULL(result_arrow); | |||
| Async(result_arrow->to_op_id_, &OutputActor::CollectOutput, output_nodes_[output_node_index], | |||
| Async(result_arrow->to_op_id_, &OutputActor::CollectOutput, output_nodes_[output_node_index++], | |||
| result_arrow->from_output_index_, result_arrow->to_input_index_, context); | |||
| ++output_node_index; | |||
| } | |||
| } | |||
| void AbstractActor::SendOutputControl(OpContext<DeviceTensor> *const context) const { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| // 2.Send output data. | |||
| if (output_data_arrows_.size() != output_data_.size()) { | |||
| SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The size of output data arrows is not equal to the output data."); | |||
| } | |||
| size_t output_data_arrow_index = 0; | |||
| for (auto &output_data : output_data_) { | |||
| MS_EXCEPTION_IF_NULL(output_data); | |||
| UpdateOutputData(output_data.get(), output_data_arrows_[output_data_arrow_index++].get(), context); | |||
| Async(output_data->op_id_, &OpActor::RunOpData, output_data.get(), context); | |||
| } | |||
| // 3.Send output control. | |||
| if (output_control_arrows_.size() > 0) { | |||
| auto from_aid = const_cast<AID *>(&GetAID()); | |||
| for (auto &output_control : output_control_arrows_) { | |||
| Async(output_control, &OpActor::RunOpControl, from_aid, context); | |||
| } | |||
| } | |||
| // 4.Send recorder info. | |||
| SendRecorderInfo(context); | |||
| // No output. | |||
| if ((output_data_arrows_.size() == 0) && (output_control_arrows_.size() == 0) && | |||
| (output_result_arrows_.size() == 0)) { | |||
| SET_OPCONTEXT_SUCCESS_RET((*context)); | |||
| } | |||
| } | |||
| } // namespace runtime | |||
| } // namespace mindspore | |||
| @@ -45,6 +45,11 @@ class AbstractActor : public OpActor<DeviceTensor> { | |||
| bool IsActive(int msg_num) override { return msg_num >= running_dependent_msg_num_ ? true : false; } | |||
| // The actor run when receive the input data. | |||
| void RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) override; | |||
| // The actor run when receive the input control. | |||
| void RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) override; | |||
| // Get the position of node in the actor. | |||
| virtual size_t FetchNodePosition(const AnfNodePtr &node) const { return 0; } | |||
| @@ -53,12 +58,19 @@ class AbstractActor : public OpActor<DeviceTensor> { | |||
| // Check whether satisfy the actor running condition. | |||
| bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const; | |||
| // The actor run really when satisfy the actor running condition. | |||
| virtual void Run(OpContext<DeviceTensor> *const context) {} | |||
| // Erase input data and input controls when finish actor running. | |||
| void EraseInput(const OpContext<DeviceTensor> *const context); | |||
| // Send the output result by output_result_arrows_. | |||
| void SendOutputResult(OpContext<DeviceTensor> *const context) const; | |||
| // Send the output control by output_control_arrows_. | |||
| void SendOutputControl(OpContext<DeviceTensor> *const context) const; | |||
| void EraseInput(const OpContext<DeviceTensor> *context); | |||
| // Update the output data before send output data. | |||
| virtual void UpdateOutputData(OpData<DeviceTensor> *const output_data, const DataArrow *data_arrow, | |||
| OpContext<DeviceTensor> *const context) {} | |||
| // Send output to downstream actors to trigger running. | |||
| virtual void SendOutput(OpContext<DeviceTensor> *const context); | |||
| // Send recorder info to recorder actor. | |||
| virtual void SendRecorderInfo(OpContext<DeviceTensor> *const context) const {} | |||
| KernelTransformType type_; | |||
| @@ -68,6 +80,9 @@ class AbstractActor : public OpActor<DeviceTensor> { | |||
| // The id of recorder actor. Send message to it for recording info. | |||
| const AID *recorder_aid_; | |||
| // The output_data_ corresponds to the output_data_arrows_ one by one. | |||
| std::vector<OpDataUniquePtr<DeviceTensor>> output_data_; | |||
| // The output nodes and output result arrows of graph output. | |||
| std::vector<AnfNodePtr> output_nodes_; | |||
| std::vector<DataArrowPtr> output_result_arrows_; | |||
| @@ -40,17 +40,16 @@ class EntranceActor : public AbstractActor { | |||
| void Init() override; | |||
| // The entrance actor run when receive the input control. | |||
| void RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) override; | |||
| // The entrance actor run when receive the real parameter nodes and branch id. | |||
| void CollectRealParametersAndBranchId(const std::vector<KernelWithIndex> &real_parameters, int branch_id, | |||
| OpContext<DeviceTensor> *const context); | |||
| protected: | |||
| void Run(OpContext<DeviceTensor> *const context) override; | |||
| private: | |||
| friend class GraphScheduler; | |||
| void SendOutput(OpContext<DeviceTensor> *const context) const; | |||
| // Formal parameters of actor, which is the front node. | |||
| std::vector<KernelWithIndex> formal_parameters_; | |||
| @@ -44,8 +44,6 @@ class ExitActor : public AbstractActor { | |||
| private: | |||
| friend class GraphScheduler; | |||
| void SendOutput(OpContext<DeviceTensor> *const context) const; | |||
| // Formal parameters of actor, which is the front node. | |||
| std::vector<KernelWithIndex> formal_parameters_; | |||
| @@ -47,7 +47,6 @@ class GatherActor : public AbstractActor { | |||
| private: | |||
| friend class GraphScheduler; | |||
| void SendOutput(OpContext<DeviceTensor> *const context) const; | |||
| // Formal parameters of actor, which is the front node. | |||
| std::vector<KernelWithIndex> formal_parameters_; | |||
| @@ -44,7 +44,6 @@ class StackActor : public MemoryAwareActor { | |||
| private: | |||
| friend class GraphScheduler; | |||
| void SendOutput(OpContext<DeviceTensor> *const context) const; | |||
| // Formal parameters record the input front-end node, these nodes may be parameter, kernel, call node. | |||
| std::vector<KernelWithIndex> formal_parameters_; | |||
| @@ -45,26 +45,10 @@ void CopyActor::Init() { | |||
| } | |||
| } | |||
| void CopyActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) { | |||
| void CopyActor::Run(OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| auto &sequential_num = context->sequential_num_; | |||
| (void)input_op_datas_[sequential_num].emplace_back(input_data); | |||
| // When all the inputs are collected, then allocate memory and callback copy. | |||
| if (CheckRunningCondition(context)) { | |||
| FetchDeviceTensor(context); | |||
| SendMemoryAllocReq(context); | |||
| } | |||
| } | |||
| void CopyActor::RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| auto &sequential_num = context->sequential_num_; | |||
| (void)input_op_controls_[sequential_num].emplace_back(input_control); | |||
| // When all the inputs are collected, then allocate memory and callback copy. | |||
| if (CheckRunningCondition(context)) { | |||
| FetchDeviceTensor(context); | |||
| SendMemoryAllocReq(context); | |||
| } | |||
| FetchDeviceTensor(context); | |||
| SendMemoryAllocReq(context); | |||
| } | |||
| void CopyActor::SendMemoryAllocReq(OpContext<DeviceTensor> *const context) { | |||
| @@ -146,22 +130,10 @@ void CopyActor::FetchDeviceTensor(OpContext<DeviceTensor> *const context) { | |||
| } | |||
| } | |||
| void CopyActor::SendOutput(OpContext<DeviceTensor> *const context) const { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| // No output. | |||
| if ((output_data_arrows_.size() == 0) && (output_control_arrows_.size() == 0)) { | |||
| SET_OPCONTEXT_SUCCESS_RET((*context)); | |||
| } | |||
| // Send output data. | |||
| for (auto &output_data : output_data_) { | |||
| MS_EXCEPTION_IF_NULL(output_data); | |||
| output_data->data_ = output_device_tensor_[0]; | |||
| Async(output_data->op_id_, &OpActor::RunOpData, output_data.get(), context); | |||
| } | |||
| // Send output control. | |||
| SendOutputControl(context); | |||
| void CopyActor::UpdateOutputData(OpData<DeviceTensor> *const output_data, const DataArrow *, | |||
| OpContext<DeviceTensor> *const) { | |||
| MS_EXCEPTION_IF_NULL(output_data); | |||
| output_data->data_ = output_device_tensor_[0]; | |||
| } | |||
| } // namespace runtime | |||
| } // namespace mindspore | |||
| @@ -42,34 +42,28 @@ class CopyActor : public MemoryAwareActor { | |||
| void Init() override; | |||
| // The copy actor run when receive the input data. | |||
| void RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) override; | |||
| // The copy actor run when receive the input control. | |||
| void RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) override; | |||
| // The memory related operation interface. | |||
| void SendMemoryAllocReq(OpContext<DeviceTensor> *const context) override; | |||
| void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) override; | |||
| // The copy processing after memory alloc finished. | |||
| void OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) override; | |||
| protected: | |||
| void Run(OpContext<DeviceTensor> *const context) override; | |||
| void UpdateOutputData(OpData<DeviceTensor> *const output_data, const DataArrow *data_arrow, | |||
| OpContext<DeviceTensor> *const context) override; | |||
| private: | |||
| friend class GraphScheduler; | |||
| // Fetch the device tensor for copy. | |||
| void FetchDeviceTensor(OpContext<DeviceTensor> *const context); | |||
| // Send output data and output controls when finish copy. | |||
| void SendOutput(OpContext<DeviceTensor> *const context) const; | |||
| // The input device tensor is saved from the input data or fetched by device_tensor_store_keys_. | |||
| std::vector<DeviceTensor *> input_device_tensor_; | |||
| // The output device tensor is saved from the output or fetched by device_tensor_store_keys_. | |||
| std::vector<DeviceTensor *> output_device_tensor_; | |||
| // The output_data_ corresponds to the output_data_arrows_ one by one. | |||
| std::vector<OpDataUniquePtr<DeviceTensor>> output_data_; | |||
| // The output is created in the copy actor build, so can't be the raw pointer. | |||
| DeviceTensorPtr output_; | |||
| }; | |||
| @@ -124,7 +124,7 @@ void DataPrepareActor::Init() { | |||
| void DataPrepareActor::PrepareData(const std::vector<std::vector<TensorPtr>> &input_tensors, | |||
| OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| MS_LOG(INFO) << "Data prepare actor(" << GetAID().Name() << ") prepares data."; | |||
| MS_LOG(DEBUG) << "Data prepare actor(" << GetAID().Name() << ") prepares data."; | |||
| // Convert actor running data from input tensors. | |||
| if (input_tensors.size() > 0) { | |||
| @@ -175,22 +175,6 @@ void DataPrepareActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const contex | |||
| SendOutput(context); | |||
| } | |||
| void DataPrepareActor::SendOutput(OpContext<DeviceTensor> *const context) { | |||
| for (auto &data_source_aid : data_source_aids_) { | |||
| Async(data_source_aid, &DataSourceActor::FetchData, context); | |||
| } | |||
| auto source_aid = const_cast<AID *>(&GetAID()); | |||
| for (auto &kernel_aid : no_input_kernel_aids_) { | |||
| Async(kernel_aid, &OpActor::RunOpControl, source_aid, context); | |||
| } | |||
| // Trigger loop count actor running when there are no data source actor and kernel actor. | |||
| if ((data_source_aids_.size() + no_input_kernel_aids_.size() == 0) && (loop_count_aid_ != nullptr)) { | |||
| Async(*loop_count_aid_, &LoopCountActor::RunOpControl, source_aid, context); | |||
| } | |||
| } | |||
| void DataPrepareActor::PrepareDataForDeviceTensorStore(const std::vector<std::vector<TensorPtr>> &input_tensors, | |||
| OpContext<DeviceTensor> *const context) { | |||
| for (size_t i = 0; i < graph_compiler_info_->graphs_.size(); ++i) { | |||
| @@ -45,8 +45,7 @@ class DataPrepareActor : public DebugAwareActor { | |||
| graph_compiler_info_(graph_compiler_info), | |||
| strategy_(GraphExecutionStrategy::kPipeline), | |||
| host_data_source_actor_(host_data_source_actor), | |||
| host_tensor_queue_(host_tensor_queue), | |||
| loop_count_aid_(nullptr) {} | |||
| host_tensor_queue_(host_tensor_queue) {} | |||
| ~DataPrepareActor() override = default; | |||
| void Init() override; | |||
| @@ -65,9 +64,6 @@ class DataPrepareActor : public DebugAwareActor { | |||
| private: | |||
| friend class GraphScheduler; | |||
| // Send output controls when finish data prepare. | |||
| void SendOutput(OpContext<DeviceTensor> *const context); | |||
| void PrepareDataForDeviceTensorStore(const std::vector<std::vector<TensorPtr>> &input_tensors, | |||
| OpContext<DeviceTensor> *const context); | |||
| void PrepareDataForHostTensorQueue(const std::vector<std::vector<TensorPtr>> &input_tensors, | |||
| @@ -103,12 +99,6 @@ class DataPrepareActor : public DebugAwareActor { | |||
| HostQueueDSActorPtr host_data_source_actor_; | |||
| HostTensorQueuePtr host_tensor_queue_; | |||
| // The output controls contain the data source actors and the no input kernel actors. | |||
| std::vector<AID> data_source_aids_; | |||
| std::vector<AID> no_input_kernel_aids_; | |||
| // If has no data source actor and kernel actor, then need send to loop count actor. | |||
| const AID *loop_count_aid_; | |||
| // The nodes need continuous memory, which must allocate in the begin of step running. The first bool of pair | |||
| // expresses the inputs of node need continuous memory, the second bool of pair expresses the outputs of node need | |||
| // continuous memory. | |||
| @@ -58,43 +58,21 @@ void DataSourceActor::FetchData(OpContext<DeviceTensor> *const context) { | |||
| SendMemoryAllocReq(context); | |||
| } | |||
| void DataSourceActor::SendOutput(OpContext<DeviceTensor> *const context) { | |||
| void DataSourceActor::UpdateOutputData(OpData<DeviceTensor> *const output_data, const DataArrow *data_arrow, | |||
| OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(output_data); | |||
| MS_EXCEPTION_IF_NULL(data_arrow); | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| // No output. | |||
| if ((output_data_arrows_.size() == 0) && (output_control_arrows_.size() == 0) && | |||
| (output_result_arrows_.size() == 0)) { | |||
| SET_OPCONTEXT_SUCCESS_RET((*context)); | |||
| } | |||
| if (buffers_.size() == 0) { | |||
| SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The data queue is empty."); | |||
| } | |||
| // Must be the execution order: send result --> send data --> send control, avoid the illegal timing problem. | |||
| // 1.Send graph output result. | |||
| SendOutputResult(context); | |||
| // 2.Send output data. | |||
| const auto &output_device_tensors = buffers_.front(); | |||
| for (size_t i = 0; i < output_data_arrows_.size(); ++i) { | |||
| auto &data_arrow = output_data_arrows_[i]; | |||
| auto &output_data = output_data_[i]; | |||
| MS_EXCEPTION_IF_NULL(data_arrow); | |||
| MS_EXCEPTION_IF_NULL(output_data); | |||
| if (IntToSize(data_arrow->from_output_index_) >= output_device_tensors.size()) { | |||
| SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The output index is of range."); | |||
| } | |||
| output_data->data_ = output_device_tensors[data_arrow->from_output_index_]; | |||
| Async(data_arrow->to_op_id_, &OpActor::RunOpData, output_data.get(), context); | |||
| } | |||
| // 3.Send output control. | |||
| SendOutputControl(context); | |||
| // 4.Send recorder info. | |||
| if (recorder_aid_ != nullptr) { | |||
| SendRecorderInfo(context); | |||
| if (IntToSize(data_arrow->from_output_index_) >= output_device_tensors.size()) { | |||
| SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The output index is of range."); | |||
| } | |||
| output_data->data_ = output_device_tensors[data_arrow->from_output_index_]; | |||
| } | |||
| void DeviceQueueDataSourceActor::Init() { | |||
| @@ -180,6 +158,8 @@ void DeviceQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *co | |||
| return; | |||
| } | |||
| EraseInput(context); | |||
| // Note that SendMemoryFreeReq must be in front of SendOutput, because SendOutput will trigger SendMemoryAllocReq of | |||
| // the next actor and the actor is asynchronous execution. So it is necessary to ensure that SendMemoryFreeReq of | |||
| // the current actor is in front of SendMemoryAllocReq of the next actor. One is to reuse the memory more fully, | |||
| @@ -197,7 +177,7 @@ void DeviceQueueDataSourceActor::OnDebugFinish(OpContext<DeviceTensor> *const co | |||
| SendOutput(context); | |||
| } | |||
| void DeviceQueueDataSourceActor::SendRecorderInfo(OpContext<DeviceTensor> *const context) { | |||
| void DeviceQueueDataSourceActor::SendRecorderInfo(OpContext<DeviceTensor> *const context) const { | |||
| if (recorder_aid_ != nullptr) { | |||
| MS_EXCEPTION_IF_NULL(data_kernel_); | |||
| Async(*recorder_aid_, &RecorderActor::RecordInfo, data_kernel_->fullname_with_scope(), &launch_info_, | |||
| @@ -279,6 +259,8 @@ void HostQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *cons | |||
| } | |||
| host_queue_->Pop(); | |||
| EraseInput(context); | |||
| // Note that SendMemoryFreeReq must be in front of SendOutput, because SendOutput will trigger SendMemoryAllocReq of | |||
| // the next actor and the actor is asynchronous execution. So it is necessary to ensure that SendMemoryFreeReq of | |||
| // the current actor is in front of SendMemoryAllocReq of the next actor. One is to reuse the memory more fully, | |||
| @@ -48,27 +48,23 @@ class DataSourceActor : public DebugAwareActor { | |||
| void Init() override; | |||
| // The process entry of data processing. | |||
| void FetchData(OpContext<DeviceTensor> *const context); | |||
| protected: | |||
| friend class GraphScheduler; | |||
| void Run(OpContext<DeviceTensor> *const context) override { FetchData(context); } | |||
| // The process entry of data processing. | |||
| void FetchData(OpContext<DeviceTensor> *const context); | |||
| // Construct the device tensors and fill to device tensor buffer from the member nodes during the data fetching. | |||
| virtual void FillDataBuffer() = 0; | |||
| // Send recorder info to recorder actor, only the device queue data source actor need. | |||
| virtual void SendRecorderInfo(OpContext<DeviceTensor> *const context) {} | |||
| // Send output to downstream actors to trigger computing after fetching data finished. | |||
| void SendOutput(OpContext<DeviceTensor> *const context); | |||
| void UpdateOutputData(OpData<DeviceTensor> *const output_data, const DataArrow *data_arrow, | |||
| OpContext<DeviceTensor> *const context) override; | |||
| // The buffers store the device tensors. | |||
| std::queue<std::vector<DeviceTensor *>> buffers_; | |||
| size_t buffer_capacity_; | |||
| // The output_data_ corresponds to the output_data_arrows_ one by one. | |||
| std::vector<OpDataUniquePtr<DeviceTensor>> output_data_; | |||
| }; | |||
| // The class represents that the data source is device queue. | |||
| @@ -95,7 +91,7 @@ class DeviceQueueDataSourceActor : public DataSourceActor { | |||
| protected: | |||
| void FillDataBuffer() override; | |||
| void SendRecorderInfo(OpContext<DeviceTensor> *const context) override; | |||
| void SendRecorderInfo(OpContext<DeviceTensor> *const context) const override; | |||
| private: | |||
| friend class GraphScheduler; | |||
| @@ -74,59 +74,26 @@ void KernelActor::Init() { | |||
| auto device_address = output_device_tensors_[data_arrow->from_output_index_]; | |||
| auto data = | |||
| std::make_unique<OpData<DeviceTensor>>(data_arrow->to_op_id_, device_address, data_arrow->to_input_index_); | |||
| (void)output_data_.emplace_back(data.get()); | |||
| (void)output_data_by_output_index_[data_arrow->from_output_index_].emplace_back(std::move(data)); | |||
| (void)output_data_by_output_index_[data_arrow->from_output_index_].emplace_back(data.get()); | |||
| (void)output_data_.emplace_back(std::move(data)); | |||
| } | |||
| } | |||
| void KernelActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) { | |||
| void KernelActor::Run(OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| MS_EXCEPTION_IF_NULL(device_contexts_[0]); | |||
| auto &sequential_num = context->sequential_num_; | |||
| (void)input_op_datas_[sequential_num].emplace_back(input_data); | |||
| if (input_data->data_ == nullptr) { | |||
| std::string error_info = | |||
| "Input data of actor:" + GetAID().Name() + " num:" + std::to_string(input_data->index_) + " is empty"; | |||
| SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info); | |||
| } | |||
| // When all the inputs are collected, then allocate memory and callback launch. | |||
| if (CheckRunningCondition(context)) { | |||
| // Infer kernel shape and update abstract info for dynamic shape kernel. | |||
| if (is_dynamic_shape_) { | |||
| device_contexts_[0]->UpdateDynamicShape(kernel_); | |||
| } | |||
| FetchInputDeviceTensor(context); | |||
| FetchOutputDeviceTensor(); | |||
| if (memory_alloc_list_.size() > 0) { | |||
| SendMemoryAllocReq(context); | |||
| } else { | |||
| OnMemoryAllocFinish(context); | |||
| } | |||
| // Infer kernel shape and update abstract info for dynamic shape kernel. | |||
| if (is_dynamic_shape_) { | |||
| device_contexts_[0]->UpdateDynamicShape(kernel_); | |||
| } | |||
| } | |||
| void KernelActor::RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| MS_EXCEPTION_IF_NULL(device_contexts_[0]); | |||
| auto &sequential_num = context->sequential_num_; | |||
| (void)input_op_controls_[sequential_num].emplace_back(input_control); | |||
| // When all the inputs are collected, then allocate memory and callback launch. | |||
| if (CheckRunningCondition(context)) { | |||
| // Infer kernel shape and update abstract info for dynamic shape kernel. | |||
| if (is_dynamic_shape_) { | |||
| device_contexts_[0]->UpdateDynamicShape(kernel_); | |||
| } | |||
| FetchInputDeviceTensor(context); | |||
| FetchOutputDeviceTensor(); | |||
| if (memory_alloc_list_.size() > 0) { | |||
| SendMemoryAllocReq(context); | |||
| } else { | |||
| OnMemoryAllocFinish(context); | |||
| } | |||
| FetchInputDeviceTensor(context); | |||
| FetchOutputDeviceTensor(); | |||
| if (memory_alloc_list_.size() > 0) { | |||
| SendMemoryAllocReq(context); | |||
| } else { | |||
| OnMemoryAllocFinish(context); | |||
| } | |||
| } | |||
| @@ -410,40 +377,18 @@ void KernelActor::PostLaunchKernel(OpContext<DeviceTensor> *const context) { | |||
| if (memory_free_list_.size() > 0) { | |||
| SendMemoryFreeReq(context); | |||
| } | |||
| SendOutput(context); | |||
| } | |||
| void KernelActor::SendOutput(OpContext<DeviceTensor> *const context) const { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| MS_EXCEPTION_IF_NULL(kernel_); | |||
| if (strategy_ == GraphExecutionStrategy::kStep) { | |||
| return; | |||
| } | |||
| // Must be the execution order: send result --> send data --> send control, avoid the illegal timing problem. | |||
| // 1.Send graph output result. | |||
| SendOutputResult(context); | |||
| // 2.Send output data. | |||
| for (auto &output_data : output_data_) { | |||
| MS_EXCEPTION_IF_NULL(output_data); | |||
| Async(output_data->op_id_, &OpActor::RunOpData, output_data, context); | |||
| if (strategy_ == GraphExecutionStrategy::kPipeline) { | |||
| SendOutput(context); | |||
| } | |||
| } | |||
| // 3.Send output control. | |||
| SendOutputControl(context); | |||
| // 4.Send recorder info. | |||
| void KernelActor::SendRecorderInfo(OpContext<DeviceTensor> *const context) const { | |||
| if (recorder_aid_ != nullptr) { | |||
| MS_EXCEPTION_IF_NULL(kernel_); | |||
| Async(*recorder_aid_, &RecorderActor::RecordInfo, kernel_->fullname_with_scope(), &launch_info_, | |||
| device_contexts_[0], context); | |||
| } | |||
| // No output. | |||
| if ((output_data_arrows_.size() == 0) && (output_control_arrows_.size() == 0) && | |||
| (output_result_arrows_.size() == 0)) { | |||
| SET_OPCONTEXT_SUCCESS_RET((*context)); | |||
| } | |||
| } | |||
| } // namespace runtime | |||
| } // namespace mindspore | |||
| @@ -58,10 +58,6 @@ class KernelActor : public DebugAwareActor { | |||
| void Init() override; | |||
| // The kernel actor run when receive the input data. | |||
| void RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) override; | |||
| // The kernel actor run when receive the input control. | |||
| void RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) override; | |||
| // The kernel actor run when receive the input control and input tensors, used in step mode. | |||
| void RunOpControlWithInputTensor(AID *const input_control, OpContext<DeviceTensor> *const context, | |||
| const std::vector<TensorPtr> *input_tensors); | |||
| @@ -77,6 +73,10 @@ class KernelActor : public DebugAwareActor { | |||
| // The callback after debug finished. | |||
| void OnDebugFinish(OpContext<DeviceTensor> *const context) override; | |||
| protected: | |||
| void Run(OpContext<DeviceTensor> *const context) override; | |||
| void SendRecorderInfo(OpContext<DeviceTensor> *const context) const override; | |||
| private: | |||
| friend class GraphScheduler; | |||
| @@ -92,9 +92,6 @@ class KernelActor : public DebugAwareActor { | |||
| // The processing after kernel launch: 1.erase input, 2.free memory, 3.send output. | |||
| void PostLaunchKernel(OpContext<DeviceTensor> *const context); | |||
| // Send output data and output controls when finish kernel launch. | |||
| void SendOutput(OpContext<DeviceTensor> *const context) const; | |||
| // The info of kernel. | |||
| CNodePtr kernel_; | |||
| KernelInfo *kernel_info_; | |||
| @@ -127,10 +124,8 @@ class KernelActor : public DebugAwareActor { | |||
| // The kernel launch info is fetched by the device tensors. | |||
| KernelLaunchInfo launch_info_; | |||
| // Cache unique output data by output index to modify the output data effectively. | |||
| std::vector<std::vector<OpDataUniquePtr<DeviceTensor>>> output_data_by_output_index_; | |||
| // The output_data_ corresponds to the output_data_arrows_ one by one. | |||
| std::vector<OpData<DeviceTensor> *> output_data_; | |||
| // Cache output data by output index to modify the output data effectively. | |||
| std::vector<std::vector<OpData<DeviceTensor> *>> output_data_by_output_index_; | |||
| }; | |||
| using KernelActorPtr = std::shared_ptr<KernelActor>; | |||
| @@ -25,15 +25,11 @@ | |||
| namespace mindspore { | |||
| namespace runtime { | |||
| void LoopCountActor::RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) { | |||
| void LoopCountActor::Run(OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| auto sequential_num = context->sequential_num_; | |||
| (void)input_op_controls_[sequential_num].emplace_back(input_control); | |||
| if (CheckRunningCondition(context)) { | |||
| // Need wait MemoryManagerActor running finished to avoid the illegal memory timing problem before | |||
| // LoopCountActor exits, because other processors which are not in actor also will process device tensor. | |||
| Async(memory_manager_aid_, &MemoryManagerActor::Wait, context, GetAID()); | |||
| } | |||
| // Need wait MemoryManagerActor running finished to avoid the illegal memory timing problem before | |||
| // LoopCountActor exits, because other processors which are not in actor also will process device tensor. | |||
| Async(memory_manager_aid_, &MemoryManagerActor::Wait, context, GetAID()); | |||
| } | |||
| void LoopCountActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) { | |||
| @@ -43,9 +43,6 @@ class LoopCountActor : public DebugAwareActor { | |||
| ~LoopCountActor() override = default; | |||
| // The loop count actor run when receive the input control. | |||
| void RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) override; | |||
| // The callback waits for the memory manager actor to finish all the message processing. | |||
| void OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) override; | |||
| @@ -54,11 +51,14 @@ class LoopCountActor : public DebugAwareActor { | |||
| // The callback after debug finished. | |||
| void OnDebugFinish(OpContext<DeviceTensor> *const context) override; | |||
| protected: | |||
| void Run(OpContext<DeviceTensor> *const context) override; | |||
| void SendOutput(OpContext<DeviceTensor> *const context) override; | |||
| private: | |||
| friend class GraphScheduler; | |||
| void IncreaseLoopCount(OpContext<DeviceTensor> *const context); | |||
| void SendOutput(OpContext<DeviceTensor> *const context); | |||
| // The loop count is constant, the current count is increased after each step running finished. | |||
| size_t loop_count_; | |||
| @@ -76,6 +76,9 @@ class OutputActor : public AbstractActor { | |||
| size_t loop_count_; | |||
| size_t current_count_; | |||
| // The dependent input result arrow actors. | |||
| std::vector<AID> input_result_arrow_aids_; | |||
| // The outputs. | |||
| std::vector<TensorPtr> outputs_; | |||
| std::vector<KernelWithIndex> output_nodes_; | |||
| @@ -32,68 +32,28 @@ void SuperKernelActor::Init() { | |||
| running_dependent_msg_num_ = SizeToInt(input_datas_num_ + input_controls_num_); | |||
| } | |||
| void SuperKernelActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) { | |||
| void SuperKernelActor::Run(OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| MS_EXCEPTION_IF_NULL(graph_); | |||
| MS_EXCEPTION_IF_NULL(device_contexts_[0]); | |||
| MS_LOG(INFO) << "Super kernel actor(" << GetAID().Name() << ") launches graph: " << graph_->graph_id(); | |||
| auto &sequential_num = context->sequential_num_; | |||
| (void)input_op_datas_[sequential_num].emplace_back(input_data); | |||
| if (CheckRunningCondition(context)) { | |||
| MS_LOG(INFO) << "Super kernel actor(" << GetAID().Name() << ") launches graph: " << graph_->graph_id(); | |||
| try { | |||
| auto ret = device_contexts_[0]->LaunchGraph(graph_); | |||
| if (!ret) { | |||
| std::string error_info = "Launch graph failed, graph id: " + graph_->graph_id(); | |||
| SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info); | |||
| } | |||
| } catch (const std::exception &e) { | |||
| MsException::Instance().SetException(); | |||
| try { | |||
| auto ret = device_contexts_[0]->LaunchGraph(graph_); | |||
| if (!ret) { | |||
| std::string error_info = "Launch graph failed, graph id: " + graph_->graph_id(); | |||
| SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info); | |||
| } | |||
| // The input is invalid and needs to be erased when finish kernel launch. | |||
| EraseInput(context); | |||
| SendOutput(context); | |||
| } catch (const std::exception &e) { | |||
| MsException::Instance().SetException(); | |||
| std::string error_info = "Launch graph exception, graph id: " + graph_->graph_id(); | |||
| SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info); | |||
| } | |||
| } | |||
| void SuperKernelActor::RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| MS_EXCEPTION_IF_NULL(device_contexts_[0]); | |||
| auto &sequential_num = context->sequential_num_; | |||
| (void)input_op_controls_[sequential_num].emplace_back(input_control); | |||
| if (CheckRunningCondition(context)) { | |||
| MS_LOG(INFO) << "Super kernel actor(" << GetAID().Name() << ") launches graph: " << graph_->graph_id(); | |||
| try { | |||
| auto ret = device_contexts_[0]->LaunchGraph(graph_); | |||
| if (!ret) { | |||
| std::string error_info = "Launch graph failed, graph id: " + graph_->graph_id(); | |||
| SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info); | |||
| } | |||
| } catch (const std::exception &e) { | |||
| MsException::Instance().SetException(); | |||
| std::string error_info = "Launch graph failed, graph id: " + graph_->graph_id(); | |||
| SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info); | |||
| } | |||
| // The input is invalid and needs to be erased when finish kernel launch. | |||
| EraseInput(context); | |||
| SendOutput(context); | |||
| } | |||
| // The input is invalid and needs to be erased when finish kernel launch. | |||
| EraseInput(context); | |||
| SendOutput(context); | |||
| } | |||
| void SuperKernelActor::SendOutput(OpContext<DeviceTensor> *const context) const { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| SendOutputResult(context); | |||
| SendOutputControl(context); | |||
| // No output. | |||
| if ((output_data_arrows_.size() == 0) && (output_control_arrows_.size() == 0) && | |||
| (output_result_arrows_.size() == 0)) { | |||
| SET_OPCONTEXT_SUCCESS_RET((*context)); | |||
| } | |||
| } | |||
| } // namespace runtime | |||
| } // namespace mindspore | |||
| @@ -42,18 +42,12 @@ class SuperKernelActor : public DebugAwareActor { | |||
| void Init() override; | |||
| // The super kernel actor run when receive the input data. | |||
| void RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) override; | |||
| // The super kernel actor run when receive the input control. | |||
| void RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) override; | |||
| protected: | |||
| void Run(OpContext<DeviceTensor> *const context) override; | |||
| private: | |||
| friend class GraphScheduler; | |||
| // Send output data and output controls when finish kernel launch. | |||
| void SendOutput(OpContext<DeviceTensor> *const context) const; | |||
| KernelGraphPtr graph_; | |||
| }; | |||
| @@ -63,42 +63,42 @@ inline bool IsSingleOpActorSet(const ActorSet *actor_set) { | |||
| } | |||
| // Convert the actors vector by the actor set. | |||
| std::vector<ActorReference> CollectActors(const ActorSet *actor_set) { | |||
| std::vector<AbstractActorPtr> CollectActors(const ActorSet *actor_set) { | |||
| MS_EXCEPTION_IF_NULL(actor_set); | |||
| std::vector<ActorReference> actors; | |||
| std::vector<AbstractActorPtr> actors; | |||
| if (actor_set->data_prepare_actor_ != nullptr) { | |||
| (void)actors.emplace_back(static_cast<ActorReference>(actor_set->data_prepare_actor_)); | |||
| (void)actors.emplace_back(static_cast<AbstractActorPtr>(actor_set->data_prepare_actor_)); | |||
| } | |||
| for (auto &data_source_actor : actor_set->data_source_actors_) { | |||
| MS_EXCEPTION_IF_NULL(data_source_actor); | |||
| (void)actors.emplace_back(static_cast<ActorReference>(data_source_actor)); | |||
| (void)actors.emplace_back(static_cast<AbstractActorPtr>(data_source_actor)); | |||
| } | |||
| for (auto &kernel_actor : actor_set->kernel_actors_) { | |||
| MS_EXCEPTION_IF_NULL(kernel_actor); | |||
| (void)actors.emplace_back(static_cast<ActorReference>(kernel_actor)); | |||
| (void)actors.emplace_back(static_cast<AbstractActorPtr>(kernel_actor)); | |||
| } | |||
| for (auto &super_kernel_actor : actor_set->super_kernel_actors_) { | |||
| MS_EXCEPTION_IF_NULL(super_kernel_actor); | |||
| (void)actors.emplace_back(static_cast<ActorReference>(super_kernel_actor)); | |||
| (void)actors.emplace_back(static_cast<AbstractActorPtr>(super_kernel_actor)); | |||
| } | |||
| for (auto &switch_actor : actor_set->switch_actors_) { | |||
| MS_EXCEPTION_IF_NULL(switch_actor); | |||
| (void)actors.emplace_back(static_cast<ActorReference>(switch_actor)); | |||
| (void)actors.emplace_back(static_cast<AbstractActorPtr>(switch_actor)); | |||
| } | |||
| for (auto &gather_actor : actor_set->gather_actors_) { | |||
| MS_EXCEPTION_IF_NULL(gather_actor); | |||
| (void)actors.emplace_back(static_cast<ActorReference>(gather_actor)); | |||
| (void)actors.emplace_back(static_cast<AbstractActorPtr>(gather_actor)); | |||
| } | |||
| for (auto ©_actor : actor_set->copy_actors_) { | |||
| MS_EXCEPTION_IF_NULL(copy_actor); | |||
| (void)actors.emplace_back(static_cast<ActorReference>(copy_actor)); | |||
| (void)actors.emplace_back(static_cast<AbstractActorPtr>(copy_actor)); | |||
| } | |||
| if (actor_set->loop_count_actor_ != nullptr) { | |||
| (void)actors.emplace_back(static_cast<ActorReference>(actor_set->loop_count_actor_)); | |||
| (void)actors.emplace_back(static_cast<AbstractActorPtr>(actor_set->loop_count_actor_)); | |||
| } | |||
| if (actor_set->output_actor_ != nullptr) { | |||
| (void)actors.emplace_back(static_cast<ActorReference>(actor_set->output_actor_)); | |||
| (void)actors.emplace_back(static_cast<AbstractActorPtr>(actor_set->output_actor_)); | |||
| } | |||
| return actors; | |||
| @@ -294,8 +294,8 @@ ActorSet *GraphScheduler::Transform(const GraphCompilerInfo &graph_compiler_info | |||
| (void)actors_.emplace(actor_set->name_, actor_set); | |||
| DumpActor(actor_set.get(), graph_compiler_info); | |||
| if (!CheckActorValid(actor_set.get(), graph_compiler_info.strategy_)) { | |||
| MS_LOG(EXCEPTION) << "The actor set of " << graph_compiler_info.name_ << " is invalid."; | |||
| if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) { | |||
| CheckActorValid(actor_set.get()); | |||
| } | |||
| MS_LOG(INFO) << "Graph(" << graph_compiler_info.name_ << ") transforms actor end."; | |||
| @@ -1072,6 +1072,7 @@ void GraphScheduler::LinkDataArrowForCopyActor(AbstractActor *const from_actor, | |||
| // Link. | |||
| (void)from_actor->output_data_arrows_.emplace_back(op_arrow_to_copy); | |||
| copy_actor->input_datas_num_++; | |||
| (void)copy_actor->input_data_arrow_aids_.emplace_back(from_actor->GetAID()); | |||
| // Set the member of the copy actor. | |||
| auto to_kernel_mod = AnfAlgo::GetKernelMod(to_kernel_with_input_idx.first); | |||
| @@ -1093,6 +1094,7 @@ void GraphScheduler::LinkDataArrowForCopyActor(AbstractActor *const from_actor, | |||
| auto op_arrow_from_copy = std::make_shared<DataArrow>(0, to_actor->GetAID(), to_input_index); | |||
| (void)copy_actor->output_data_arrows_.emplace_back(op_arrow_from_copy); | |||
| to_actor->input_datas_num_++; | |||
| (void)to_actor->input_data_arrow_aids_.emplace_back(copy_actor->GetAID()); | |||
| UpdateRefCount(copy_actor->output_.get()); | |||
| } | |||
| @@ -1171,6 +1173,7 @@ void GraphScheduler::LinkControlArrowByAutoMonad(AbstractActor *to_actor, const | |||
| << ", to actor: " << to_actor->GetAID().Name(); | |||
| (void)from_actor->output_control_arrows_.emplace_back(to_actor->GetAID()); | |||
| to_actor->input_controls_num_++; | |||
| (void)to_actor->input_control_arrow_aids_.emplace_back(from_actor->GetAID()); | |||
| } | |||
| } | |||
| @@ -1190,6 +1193,7 @@ void GraphScheduler::LinkControlArrowBySkippedNode(AbstractActor *to_actor, cons | |||
| << ", from actor: " << from_actor->GetAID().Name() << ", to actor: " << to_actor->GetAID().Name(); | |||
| (void)from_actor->output_control_arrows_.emplace_back(to_aid); | |||
| to_actor->input_controls_num_++; | |||
| (void)to_actor->input_control_arrow_aids_.emplace_back(from_actor->GetAID()); | |||
| } | |||
| } | |||
| @@ -1216,16 +1220,19 @@ void GraphScheduler::LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph | |||
| if (input_actor != nullptr) { | |||
| (void)input_actor->output_control_arrows_.emplace_back(from_send_actor->GetAID()); | |||
| from_send_actor->input_controls_num_++; | |||
| (void)from_send_actor->input_control_arrow_aids_.emplace_back(input_actor->GetAID()); | |||
| } | |||
| } | |||
| // from_send_actor --> from_recv_actor | |||
| (void)from_send_actor->output_control_arrows_.emplace_back(from_recv_actor->GetAID()); | |||
| from_recv_actor->input_controls_num_++; | |||
| (void)from_recv_actor->input_control_arrow_aids_.emplace_back(from_send_actor->GetAID()); | |||
| // from_recv_actor --> to_allreduce_actor | |||
| (void)from_recv_actor->output_control_arrows_.emplace_back(to_allreduce_actor->GetAID()); | |||
| to_allreduce_actor->input_controls_num_++; | |||
| (void)to_allreduce_actor->input_control_arrow_aids_.emplace_back(from_recv_actor->GetAID()); | |||
| } | |||
| for (auto &to_iter : graph->allreduce_to_send_recv_pairs()) { | |||
| @@ -1246,10 +1253,12 @@ void GraphScheduler::LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph | |||
| // from_allreduce_actor --> to_send_actor | |||
| (void)from_allreduce_actor->output_control_arrows_.emplace_back(to_send_actor->GetAID()); | |||
| to_send_actor->input_controls_num_++; | |||
| (void)to_send_actor->input_control_arrow_aids_.emplace_back(from_allreduce_actor->GetAID()); | |||
| // to_send_actor --> to_recv_actor | |||
| (void)to_send_actor->output_control_arrows_.emplace_back(to_recv_actor->GetAID()); | |||
| to_recv_actor->input_controls_num_++; | |||
| (void)to_recv_actor->input_control_arrow_aids_.emplace_back(to_send_actor->GetAID()); | |||
| // to_recv_actor --> outputs of from_allreduce_actor | |||
| for (auto &output_data_arrow : from_allreduce_actor->output_data_arrows_) { | |||
| @@ -1257,6 +1266,7 @@ void GraphScheduler::LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph | |||
| if (output_actor != nullptr) { | |||
| (void)to_recv_actor->output_control_arrows_.emplace_back(output_actor->GetAID()); | |||
| output_actor->input_controls_num_++; | |||
| (void)output_actor->input_control_arrow_aids_.emplace_back(to_recv_actor->GetAID()); | |||
| } | |||
| } | |||
| @@ -1309,6 +1319,7 @@ void GraphScheduler::LinkControlArrowByCommunicationNode(const std::vector<CNode | |||
| MS_EXCEPTION_IF_NULL(to_actor); | |||
| (void)from_actor->output_control_arrows_.emplace_back(to_actor->GetAID()); | |||
| to_actor->input_controls_num_++; | |||
| (void)to_actor->input_control_arrow_aids_.emplace_back(from_actor->GetAID()); | |||
| } | |||
| // Ensure all actors execute orderly to optimize the execution performance in the multi device scenario currently. | |||
| @@ -1322,6 +1333,7 @@ void GraphScheduler::LinkControlArrowByCommunicationNode(const std::vector<CNode | |||
| if ((from_actor != nullptr) && (to_actor != nullptr)) { | |||
| (void)from_actor->output_control_arrows_.emplace_back(to_actor->GetAID()); | |||
| to_actor->input_controls_num_++; | |||
| (void)to_actor->input_control_arrow_aids_.emplace_back(from_actor->GetAID()); | |||
| } | |||
| } | |||
| } | |||
| @@ -1335,20 +1347,25 @@ void GraphScheduler::LinkControlArrowForDataPrepareActor(DataPrepareActor *data_ | |||
| // Data prepare actor --> data source actor. | |||
| for (auto &data_source_actor : actor_set->data_source_actors_) { | |||
| MS_EXCEPTION_IF_NULL(data_source_actor); | |||
| (void)data_prepare_actor->data_source_aids_.emplace_back(data_source_actor->GetAID()); | |||
| (void)data_prepare_actor->output_control_arrows_.emplace_back(data_source_actor->GetAID()); | |||
| data_source_actor->input_controls_num_++; | |||
| (void)data_source_actor->input_control_arrow_aids_.emplace_back(data_prepare_actor->GetAID()); | |||
| } | |||
| // Data prepare actor --> no input kernel actor. | |||
| for (auto &no_input_kernel_actor : actor_set->no_input_kernel_actors_) { | |||
| MS_EXCEPTION_IF_NULL(no_input_kernel_actor); | |||
| (void)data_prepare_actor->no_input_kernel_aids_.emplace_back(no_input_kernel_actor->GetAID()); | |||
| (void)data_prepare_actor->output_control_arrows_.emplace_back(no_input_kernel_actor->GetAID()); | |||
| no_input_kernel_actor->input_controls_num_++; | |||
| (void)no_input_kernel_actor->input_control_arrow_aids_.emplace_back(data_prepare_actor->GetAID()); | |||
| } | |||
| // Data prepare actor --> loop count actor. | |||
| if ((actor_set->data_source_actors_.size() + actor_set->no_input_kernel_actors_.size() == 0) && | |||
| (actor_set->loop_count_actor_ != nullptr)) { | |||
| data_prepare_actor->loop_count_aid_ = &(actor_set->loop_count_actor_->GetAID()); | |||
| (void)data_prepare_actor->output_control_arrows_.emplace_back(actor_set->loop_count_actor_->GetAID()); | |||
| actor_set->loop_count_actor_->input_controls_num_++; | |||
| (void)actor_set->loop_count_actor_->input_control_arrow_aids_.emplace_back(data_prepare_actor->GetAID()); | |||
| } | |||
| } | |||
| @@ -1392,6 +1409,7 @@ void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_coun | |||
| for (auto &no_output_actor : no_output_actors) { | |||
| (void)no_output_actor->output_control_arrows_.emplace_back(loop_count_actor->GetAID()); | |||
| loop_count_actor->input_controls_num_++; | |||
| (void)loop_count_actor->input_control_arrow_aids_.emplace_back(no_output_actor->GetAID()); | |||
| } | |||
| // Loop count actor --> data prepare actor. | |||
| @@ -1463,6 +1481,7 @@ void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor, | |||
| auto op_arrow = std::make_shared<DataArrow>(output_with_index.second, to_actor->GetAID(), output_position); | |||
| (void)from_actor->output_result_arrows_.emplace_back(op_arrow); | |||
| (void)from_actor->output_nodes_.emplace_back(output_with_index.first); | |||
| (void)to_actor->input_result_arrow_aids_.emplace_back(from_actor->GetAID()); | |||
| // Update the real compute node in the host data source actor. | |||
| if (kernel_type == KernelTransformType::kHostDataSourceActor) { | |||
| @@ -1525,6 +1544,7 @@ void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector<Ke | |||
| // Link from kernel actor to copy actor. | |||
| (void)kernel_actor->output_control_arrows_.emplace_back(copy_actor->GetAID()); | |||
| copy_actor->input_controls_num_++; | |||
| (void)copy_actor->input_control_arrow_aids_.emplace_back(kernel_actor->GetAID()); | |||
| } | |||
| } | |||
| } | |||
| @@ -1539,82 +1559,60 @@ void GraphScheduler::LinkDataArrowForSwitchActor(SwitchActor *from_actor, const | |||
| OpActor<DeviceTensor> *to_actor, const size_t to_index, | |||
| const size_t branch_index) {} | |||
| bool GraphScheduler::CheckActorValid(const ActorSet *actor_set, GraphExecutionStrategy strategy) const { | |||
| void GraphScheduler::CheckActorValid(const ActorSet *actor_set) const { | |||
| MS_EXCEPTION_IF_NULL(actor_set); | |||
| // Check the data source actors. | |||
| for (const auto &data_source_actor : actor_set->data_source_actors_) { | |||
| MS_EXCEPTION_IF_NULL(data_source_actor); | |||
| if (data_source_actor->output_data_arrows_.size() + data_source_actor->output_result_arrows_.size() + | |||
| data_source_actor->output_control_arrows_.size() == | |||
| 0) { | |||
| MS_LOG(ERROR) << data_source_actor->GetAID().Name() << " has no user."; | |||
| return false; | |||
| } | |||
| } | |||
| if (strategy == GraphExecutionStrategy::kStep) { | |||
| return true; | |||
| } | |||
| // Check the super kernel actors. | |||
| for (const auto &super_kernel_actor : actor_set->super_kernel_actors_) { | |||
| MS_EXCEPTION_IF_NULL(super_kernel_actor); | |||
| if (super_kernel_actor->output_data_arrows_.size() + super_kernel_actor->output_control_arrows_.size() == 0) { | |||
| MS_LOG(ERROR) << super_kernel_actor->GetAID().Name() << " has no user."; | |||
| return false; | |||
| } | |||
| } | |||
| // Check the kernel actors. | |||
| for (const auto &kernel_actor : actor_set->kernel_actors_) { | |||
| MS_EXCEPTION_IF_NULL(kernel_actor); | |||
| if (kernel_actor->output_data_arrows_.size() + kernel_actor->output_control_arrows_.size() == 0) { | |||
| MS_LOG(ERROR) << kernel_actor->GetAID().Name() << " has no user."; | |||
| return false; | |||
| } | |||
| auto input_num = AnfAlgo::GetInputTensorNum(kernel_actor->kernel_); | |||
| auto input_data_num = kernel_actor->input_datas_num_; | |||
| auto device_tensor_store_num = kernel_actor->device_tensor_store_keys_.size(); | |||
| if (input_data_num + device_tensor_store_num != input_num) { | |||
| MS_LOG(ERROR) << "The input building of " << AnfAlgo::GetNodeDebugString(kernel_actor->kernel_) | |||
| << " is wrong, input data num: " << input_data_num | |||
| << ", device tensor store num: " << device_tensor_store_num << ", total input num: " << input_num; | |||
| return false; | |||
| } | |||
| } | |||
| // Check the copy actors. | |||
| for (const auto ©_actor : actor_set->copy_actors_) { | |||
| MS_EXCEPTION_IF_NULL(copy_actor); | |||
| if (copy_actor->output_data_arrows_.size() + copy_actor->output_control_arrows_.size() == 0) { | |||
| MS_LOG(ERROR) << copy_actor->GetAID().Name() << " has no user."; | |||
| return false; | |||
| } | |||
| const size_t kCopyActorInputDataNum = 1; | |||
| auto input_data_num = copy_actor->input_datas_num_; | |||
| size_t device_tensor_store_num = copy_actor->device_tensor_store_keys_.size(); | |||
| if (input_data_num + device_tensor_store_num != kCopyActorInputDataNum) { | |||
| MS_LOG(ERROR) << "The input building of " << copy_actor->GetAID().Name() | |||
| << " is wrong, input data num: " << input_data_num | |||
| << ", device tensor store num: " << device_tensor_store_num | |||
| << ", total input num: " << kCopyActorInputDataNum; | |||
| return false; | |||
| auto actors = CollectActors(actor_set); | |||
| for (auto &actor : actors) { | |||
| MS_EXCEPTION_IF_NULL(actor); | |||
| if ((actor->input_datas_num_ != actor->input_data_arrow_aids_.size()) || | |||
| (actor->input_controls_num_ != actor->input_control_arrow_aids_.size())) { | |||
| MS_LOG(EXCEPTION) << "The input num of " << actor->GetAID().Name() | |||
| << " is wrong, expect data num: " << actor->input_datas_num_ | |||
| << ", actual data num: " << actor->input_data_arrow_aids_.size() | |||
| << ", expect control num: " << actor->input_controls_num_ | |||
| << ", actual control num: " << actor->input_control_arrow_aids_.size(); | |||
| } | |||
| if ((actor->type_ != KernelTransformType::kOutputActor) && (actor->type_ != KernelTransformType::kLoopCountActor) && | |||
| (actor->output_data_arrows_.size() == 0) && (actor->output_control_arrows_.size() == 0) && | |||
| (actor->output_result_arrows_.size() == 0)) { | |||
| MS_LOG(EXCEPTION) << actor->GetAID().Name() << " has no user."; | |||
| } | |||
| if ((actor->type_ != KernelTransformType::kOutputActor) && | |||
| (actor->type_ != KernelTransformType::kDataPrepareActor) && (actor->input_datas_num_ == 0) && | |||
| (actor->input_controls_num_ == 0)) { | |||
| MS_LOG(EXCEPTION) << actor->GetAID().Name() << " has no source."; | |||
| } | |||
| // Check the input of kernel actors and copy actors. | |||
| if ((actor->type_ == KernelTransformType::kKernelActor) || (actor->type_ == KernelTransformType::kCopyActor)) { | |||
| size_t expect_toal_input_num = 1; | |||
| if (actor->type_ == KernelTransformType::kKernelActor) { | |||
| auto kernel_actor = dynamic_cast<KernelActor *>(actor.get()); | |||
| MS_EXCEPTION_IF_NULL(kernel_actor); | |||
| expect_toal_input_num = AnfAlgo::GetInputTensorNum(kernel_actor->kernel_); | |||
| } | |||
| auto input_data_num = actor->input_datas_num_; | |||
| auto device_tensor_store_num = actor->device_tensor_store_keys_.size(); | |||
| if (input_data_num + device_tensor_store_num != expect_toal_input_num) { | |||
| MS_LOG(EXCEPTION) << "The input building of " << actor->GetAID().Name() | |||
| << " is wrong, input data num: " << input_data_num | |||
| << ", device tensor store num: " << device_tensor_store_num | |||
| << ", total input num: " << expect_toal_input_num; | |||
| } | |||
| } | |||
| } | |||
| // Check the loop count actor. | |||
| const auto &loop_count_actor = actor_set->loop_count_actor_; | |||
| if ((loop_count_actor != nullptr) && | |||
| (actor_set->data_source_actors_.size() + actor_set->kernel_actors_.size() + actor_set->copy_actors_.size() > 0)) { | |||
| if (loop_count_actor->input_controls_num_ == 0) { | |||
| MS_LOG(ERROR) << loop_count_actor->GetAID().Name() << " has no source."; | |||
| return false; | |||
| } | |||
| // Check the output actor. | |||
| auto output_actor = actor_set->output_actor_; | |||
| MS_EXCEPTION_IF_NULL(output_actor); | |||
| if (output_actor->input_result_arrow_aids_.size() + output_actor->device_tensor_store_keys_.size() != | |||
| output_actor->outputs_num_) { | |||
| MS_LOG(EXCEPTION) << "The outputs num of output actor is wrong, the total outputs num: " | |||
| << output_actor->outputs_num_ | |||
| << ", the input result arrows num: " << output_actor->input_result_arrow_aids_.size() | |||
| << ", the device tensor store num: " << output_actor->device_tensor_store_keys_.size(); | |||
| } | |||
| return true; | |||
| } | |||
| void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler_info) { | |||
| @@ -1819,14 +1817,6 @@ void GraphScheduler::DumpActor(const ActorSet *actor_set, const GraphCompilerInf | |||
| void GraphScheduler::DumpAbstractActor(const AbstractActor *actor, std::ofstream &ofs) const { | |||
| MS_EXCEPTION_IF_NULL(actor); | |||
| ofs << "\t\tdevice_contexts_num:" << actor->device_contexts_.size() | |||
| << "\tdevice_tensor_store_keys_num:" << actor->device_tensor_store_keys_.size() | |||
| << "\tinput_data_arrow_actors_num:" << actor->input_datas_num_ | |||
| << "\tinput_control_arrow_actors_num:" << actor->input_controls_num_ << "\n"; | |||
| ofs << "\t\toutput_data_arrows_num:" << actor->output_data_arrows_.size() | |||
| << "\toutput_control_arrows_num:" << actor->output_control_arrows_.size() | |||
| << "\toutput_result_arrows_num:" << actor->output_result_arrows_.size() << "\n"; | |||
| if (actor->device_contexts_.size() > 0) { | |||
| ofs << "\t\tdevice_contexts:" << actor->device_contexts_.size() << "\n "; | |||
| for (const auto &device_context : actor->device_contexts_) { | |||
| @@ -1903,14 +1893,6 @@ void GraphScheduler::DumpDataPrepareActor(const DataPrepareActor *actor, std::of | |||
| ofs << "\tactor_name:" << actor->GetAID().Name() << "\n"; | |||
| DumpAbstractActor(actor, ofs); | |||
| ofs << "\t\toutput_control_arrows:" << actor->data_source_aids_.size() + actor->no_input_kernel_aids_.size() << "\n "; | |||
| for (const auto &aid : actor->data_source_aids_) { | |||
| ofs << "\t\t\tto_actor_name:" << aid.Name() << "\n"; | |||
| } | |||
| for (const auto &aid : actor->no_input_kernel_aids_) { | |||
| ofs << "\t\t\tto_actor_name:" << aid.Name() << "\n"; | |||
| } | |||
| ofs << "\t\tcontinuous_memory_nodes:" << actor->continuous_memory_nodes_.size() << "\n "; | |||
| for (const auto &iter : actor->continuous_memory_nodes_) { | |||
| MS_EXCEPTION_IF_NULL(iter.first.first); | |||
| @@ -2023,7 +2005,13 @@ void GraphScheduler::DumpOutputActor(const OutputActor *actor, std::ofstream &of | |||
| MS_EXCEPTION_IF_NULL(actor); | |||
| ofs << "\tactor_name:" << actor->GetAID().Name() << "\tloop_count:" << actor->loop_count_ | |||
| << "\toutputs_num:" << actor->outputs_num_ << "\n"; | |||
| DumpAbstractActor(actor, ofs); | |||
| ofs << "\t\tinput_result_arrows:" << actor->input_result_arrow_aids_.size() << "\n "; | |||
| for (const auto &input_result_arrow_aid : actor->input_result_arrow_aids_) { | |||
| ofs << "\t\t\tfrom_actor_name:" << input_result_arrow_aid.Name() << "\n"; | |||
| } | |||
| } | |||
| void GraphScheduler::DumpCopyActor(const CopyActor *actor, std::ofstream &ofs) const { | |||
| @@ -2043,7 +2031,8 @@ void GraphScheduler::DumpCopyActor(const CopyActor *actor, std::ofstream &ofs) c | |||
| void GraphScheduler::DumpDeviceTensorStore(const GraphCompilerInfo &graph_compiler_info, std::ofstream &ofs) const { | |||
| for (const auto &graph : graph_compiler_info.graphs_) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| ofs << "\tgraph_id:" << graph->graph_id() << "\tis_sink:" << graph->is_sink() << "\n"; | |||
| ofs << "\tgraph_id:" << graph->graph_id() << "\tis_sink:" << graph->is_sink() | |||
| << "\texecution_strategy:" << graph_compiler_info.strategy_ << "\n"; | |||
| for (auto &value_node : graph->graph_value_nodes()) { | |||
| MS_EXCEPTION_IF_NULL(value_node); | |||
| @@ -207,8 +207,7 @@ class GraphScheduler { | |||
| const size_t branch_index = SIZE_MAX); | |||
| // Check whether the actor set is valid. | |||
| bool CheckActorValid(const ActorSet *actor_set, | |||
| GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline) const; | |||
| void CheckActorValid(const ActorSet *actor_set) const; | |||
| // Persist device tensors of graph's some nodes(such as weights and value nodes). | |||
| void PersistDeviceTensor(const GraphCompilerInfo &graph_compiler_info); | |||