| @@ -96,8 +96,9 @@ void AbstractActor::SendOutput(OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| // Must be the execution order: send data --> send control, avoid the illegal timing problem. | |||
| // 1.Send output data. | |||
| if ((output_data_arrows_.size() != output_data_.size()) || | |||
| (output_data_arrows_.size() != output_data_nodes_.size())) { | |||
| if (((output_data_arrows_.size() != output_data_.size()) || | |||
| (output_data_arrows_.size() != output_data_nodes_.size())) && | |||
| (type_ < KernelTransformType::kSwitchActor)) { | |||
| SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The size of output data arrows is not equal to the output data."); | |||
| } | |||
| size_t output_data_arrow_index = 0; | |||
| @@ -121,7 +122,8 @@ void AbstractActor::SendOutput(OpContext<DeviceTensor> *const context) { | |||
| SendRecorderInfo(context); | |||
| // No output. | |||
| if ((output_data_arrows_.size() == 0) && (output_control_arrows_.size() == 0)) { | |||
| if ((output_data_arrows_.size() == 0) && (output_control_arrows_.size() == 0) && | |||
| (type_ < KernelTransformType::kSwitchActor)) { | |||
| SET_OPCONTEXT_SUCCESS_RET((*context)); | |||
| } | |||
| } | |||
| @@ -0,0 +1,204 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "runtime/framework/actor/control_flow/control_actor.h" | |||
| namespace mindspore { | |||
| namespace runtime { | |||
| ControlActor::ControlActor(const std::string &name, KernelTransformType type, | |||
| const std::vector<KernelWithIndex> ¶meters, const AnfNodePtr &node) | |||
| : AbstractActor(name, type, nullptr), formal_parameters_(parameters), node_(node) { | |||
| input_partials_.resize(parameters.size()); | |||
| input_device_tensors_.resize(parameters.size()); | |||
| } | |||
| void ControlActor::Init() { | |||
| output_data_by_output_index_.resize(formal_parameters_.size()); | |||
| for (auto &data_arrow : output_data_arrows_) { | |||
| MS_EXCEPTION_IF_NULL(data_arrow); | |||
| if (IntToSize(data_arrow->from_output_index_) >= formal_parameters_.size()) { | |||
| MS_LOG(EXCEPTION) << "The output index is out of range: " << GetAID(); | |||
| } | |||
| auto data = std::make_unique<OpData<DeviceTensor>>(data_arrow->to_op_id_, nullptr, data_arrow->to_input_index_); | |||
| (void)output_data_by_output_index_[data_arrow->from_output_index_].emplace_back(data.get()); | |||
| (void)output_data_.emplace_back(std::move(data)); | |||
| } | |||
| } | |||
| size_t ControlActor::FetchNodePosition(const KernelWithIndex &node) const { | |||
| const auto &iter = find(formal_parameters_.begin(), formal_parameters_.end(), node); | |||
| if (iter == formal_parameters_.end()) { | |||
| MS_LOG(EXCEPTION) << "Invalid formal parameter:" << node.first->DebugString() << " for actor:" << GetAID(); | |||
| } | |||
| return iter - formal_parameters_.begin(); | |||
| } | |||
| void ControlActor::Run(OpContext<DeviceTensor> *const context) { | |||
| FetchInput(context); | |||
| EraseInput(context); | |||
| SendOutput(context); | |||
| } | |||
| void ControlActor::RunOpPartial(FuncGraph *func_graph, std::vector<DeviceTensor *> input_data, size_t position, | |||
| OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| auto &sequential_num = context->sequential_num_; | |||
| input_op_partials_[sequential_num].emplace_back(position, OpPartial(func_graph, input_data)); | |||
| if (CheckRunningCondition(context)) { | |||
| Run(context); | |||
| } | |||
| } | |||
| void ControlActor::RunBranchID(int branch_id, OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| auto &sequential_num = context->sequential_num_; | |||
| input_branch_ids_[sequential_num].push(branch_id); | |||
| if (CheckRunningCondition(context)) { | |||
| Run(context); | |||
| } | |||
| } | |||
| bool ControlActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) const { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| if (!AbstractActor::CheckRunningCondition(context)) { | |||
| return false; | |||
| } | |||
| if (input_partials_num_ != 0) { | |||
| const auto &partial_iter = input_op_partials_.find(context->sequential_num_); | |||
| if (partial_iter == input_op_partials_.end()) { | |||
| return false; | |||
| } | |||
| if (partial_iter->second.size() != input_partials_num_) { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| void ControlActor::FetchInput(OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| // Fetch input device tensor from input data. | |||
| const auto &data_iter = input_op_datas_.find(context->sequential_num_); | |||
| if (data_iter != input_op_datas_.end()) { | |||
| for (auto &input_data : data_iter->second) { | |||
| MS_EXCEPTION_IF_NULL(input_data); | |||
| if (IntToSize(input_data->index_) >= input_device_tensors_.size()) { | |||
| MS_LOG(ERROR) << "Invalid index, need:" << input_data->index_ << " current:" << input_device_tensors_.size() | |||
| << " for actor:" << GetAID(); | |||
| } | |||
| input_device_tensors_[input_data->index_] = input_data->data_; | |||
| } | |||
| } | |||
| // Fetch input device tensor from device store. | |||
| for (auto &device_tensor_store_key : device_tensor_store_keys_) { | |||
| auto device_tensor = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second.get(), | |||
| device_contexts_[0]->GetDeviceAddressType()); | |||
| if (device_tensor == nullptr) { | |||
| MS_LOG(ERROR) << GetAID() << " get device tensor store failed: " << device_tensor_store_key.second->DebugString(); | |||
| } | |||
| if (device_tensor_store_key.first >= input_device_tensors_.size()) { | |||
| MS_LOG(ERROR) << "The input index is out of range, need:" << device_tensor_store_key.first | |||
| << " current:" << input_device_tensors_.size() << " for actor:" << GetAID(); | |||
| } | |||
| input_device_tensors_[device_tensor_store_key.first] = device_tensor; | |||
| } | |||
| for (size_t i = 0; i < output_data_by_output_index_.size(); ++i) { | |||
| if (output_data_by_output_index_[i].empty()) { | |||
| continue; | |||
| } | |||
| const auto &data = input_device_tensors_[i]; | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| for (auto &output_data : output_data_by_output_index_[i]) { | |||
| MS_EXCEPTION_IF_NULL(output_data); | |||
| output_data->data_ = data; | |||
| } | |||
| } | |||
| // Fetch input partial from input data. | |||
| const auto &partial_iter = input_op_partials_.find(context->sequential_num_); | |||
| if (partial_iter != input_op_partials_.end()) { | |||
| for (const auto &input_partial : partial_iter->second) { | |||
| MS_EXCEPTION_IF_NULL(input_partial.second.first); | |||
| input_partials_[input_partial.first] = input_partial.second; | |||
| } | |||
| } | |||
| // Fetch input partial from local partial. | |||
| for (const auto &local_partial : local_partials_) { | |||
| input_partials_[local_partial.first] = local_partial.second; | |||
| } | |||
| // Fetch branch id in stack. | |||
| auto iter = input_branch_ids_.find(context->sequential_num_); | |||
| if (iter != input_branch_ids_.end() && (!iter->second.empty())) { | |||
| output_branch_id_ = iter->second.top(); | |||
| } | |||
| } | |||
| void ControlActor::EraseInput(const OpContext<DeviceTensor> *context) { | |||
| AbstractActor::EraseInput(context); | |||
| const auto &sequential_num = context->sequential_num_; | |||
| if (input_partials_num_ != 0) { | |||
| auto ret = input_op_partials_.erase(sequential_num); | |||
| if (ret == 0) { | |||
| std::string error_info = "Erase input partial failed: " + GetAID().Name(); | |||
| // The sequential num may be invalid, can't set the promise value of context. | |||
| MS_LOG(ERROR) << error_info << ", sequential_num: " << sequential_num; | |||
| } | |||
| } | |||
| if (input_branch_ids_.find(sequential_num) != input_branch_ids_.end()) { | |||
| input_branch_ids_[sequential_num].pop(); | |||
| if (input_branch_ids_[sequential_num].empty()) { | |||
| auto ret = input_branch_ids_.erase(sequential_num); | |||
| if (ret == 0) { | |||
| MS_LOG(ERROR) << "Erase input branch id failed: " << GetAID() << ", sequential_num: " << sequential_num; | |||
| return; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| void ControlActor::SendOutput(OpContext<DeviceTensor> *const context) { | |||
| // Send branch id. | |||
| for (const auto &branch_id_arrow : output_branch_id_arrows_) { | |||
| Async(branch_id_arrow, &ControlActor::RunBranchID, output_branch_id_, context); | |||
| } | |||
| // Send data in base class. | |||
| AbstractActor::SendOutput(context); | |||
| // Send Partial. | |||
| for (const auto &partial_arrow : output_partial_arrows_) { | |||
| MS_EXCEPTION_IF_NULL(partial_arrow); | |||
| MS_EXCEPTION_IF_NULL(output_partial_.first); | |||
| Async(partial_arrow->to_op_id_, &ControlActor::RunOpPartial, output_partial_.first, output_partial_.second, | |||
| IntToSize(partial_arrow->to_input_index_), context); | |||
| } | |||
| } | |||
| } // namespace runtime | |||
| } // namespace mindspore | |||
| @@ -35,7 +35,7 @@ namespace runtime { | |||
| // parameters and the id of the caller. | |||
| using OpDataWithBranchID = std::pair<std::vector<DeviceTensor *>, int>; | |||
| // Op partial represents the partial structure, including a funcgraph and its real parameters. | |||
| using OpPartial = std::pair<FuncGraphPtr, std::vector<DeviceTensor *>>; | |||
| using OpPartial = std::pair<FuncGraph *, std::vector<DeviceTensor *>>; | |||
| // The control actor is the base class of control flow actor. | |||
| class ControlActor : public AbstractActor { | |||
| public: | |||
| @@ -43,9 +43,7 @@ class ControlActor : public AbstractActor { | |||
| const AnfNodePtr &node); | |||
| ~ControlActor() override = default; | |||
| void Init() override {} | |||
| void RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) override {} | |||
| void Init() override; | |||
| // Receive partial. | |||
| virtual void RunOpPartial(FuncGraph *func_graph, std::vector<DeviceTensor *> input_data, size_t position, | |||
| @@ -54,16 +52,19 @@ class ControlActor : public AbstractActor { | |||
| // Receive branch id. | |||
| virtual void RunBranchID(int branch_id, OpContext<DeviceTensor> *const context); | |||
| const std::vector<DataArrowPtr> &output_partial_arrows() const { return output_partial_arrows_; } | |||
| const std::vector<AID> &output_branch_id_arrows() const { return output_branch_id_arrows_; } | |||
| protected: | |||
| // Get the position of node in the input. | |||
| size_t FetchNodePosition(const KernelWithIndex &node) const; | |||
| // Get all input, including data, partial, branchid. | |||
| virtual void FetchInput(OpContext<DeviceTensor> *const context); | |||
| void Run(OpContext<DeviceTensor> *const context); | |||
| bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const; | |||
| void SendOutput(OpContext<DeviceTensor> *const context); | |||
| void EraseInput(const OpContext<DeviceTensor> *context); | |||
| void Run(OpContext<DeviceTensor> *const context) override; | |||
| bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const override; | |||
| void SendOutput(OpContext<DeviceTensor> *const context) override; | |||
| void EraseInput(const OpContext<DeviceTensor> *context) override; | |||
| // Input data. | |||
| // 1.Input partial. | |||
| @@ -81,17 +82,19 @@ class ControlActor : public AbstractActor { | |||
| // Fetch data. After fetch input, all the input collected is saved here. | |||
| std::vector<OpPartial> input_partials_; | |||
| std::vector<DeviceTensor *> input_device_tensors_; | |||
| // The branch id is the unique identifier of the control actor. In the control flow, there are multiple control | |||
| // actors calling the same subgraph at the same time. At this time, the output of the subgraph needs to be returned | |||
| // to the calling place according to the branch id. | |||
| int branch_id_; | |||
| // Input num. | |||
| size_t input_partials_num_; | |||
| size_t input_partials_num_{0}; | |||
| // Output Arrows. | |||
| std::vector<DataArrowPtr> output_partial_arrows_; | |||
| std::vector<DataArrowPtr> output_branch_id_arrows_; | |||
| OpPartial output_partial_; | |||
| std::vector<AID> output_branch_id_arrows_; | |||
| // The branch id is the unique identifier of the control actor. In the control flow, there are multiple control | |||
| // actors calling the same subgraph at the same time. At this time, the output of the subgraph needs to be returned | |||
| // to the calling place according to the branch id. | |||
| int output_branch_id_; | |||
| // Partial data in local. When partial is only funcgraph without real parameter, it is stored inside the actor. | |||
| std::unordered_map<size_t, OpPartial> local_partials_; | |||
| @@ -0,0 +1,145 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "runtime/framework/actor/control_flow/entrance_actor.h" | |||
| #include "runtime/framework/actor/control_flow/exit_actor.h" | |||
| namespace mindspore { | |||
| namespace runtime { | |||
| constexpr size_t kEntranceInputStartPos = 1; | |||
| void EntranceActor::RunOpDataWithBranchID(std::vector<DeviceTensor *> input_data, int branch_id, | |||
| OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| auto &sequential_num = context->sequential_num_; | |||
| input_op_data_with_branch_id_[sequential_num].emplace(input_data, branch_id); | |||
| if (CheckRunningCondition(context)) { | |||
| Run(context); | |||
| } | |||
| } | |||
| void EntranceActor::Run(OpContext<DeviceTensor> *const context) { | |||
| FetchInput(context); | |||
| EraseInput(context); | |||
| SendOutput(context); | |||
| // The actor needs to be disabled after the actor is running, until no actor is running in the entire funcgraph. | |||
| is_actor_ready_ = false; | |||
| } | |||
| void EntranceActor::FetchInput(OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| auto &sequential_num = context->sequential_num_; | |||
| // There are two kinds of run conditions for entrance actor: | |||
| // 1.Data comes from the data source actor, it is in the form of data arrow. | |||
| const auto &data_iter = input_op_datas_.find(sequential_num); | |||
| if (data_iter != input_op_datas_.end()) { | |||
| for (auto &input_data : data_iter->second) { | |||
| MS_EXCEPTION_IF_NULL(input_data); | |||
| if (IntToSize(input_data->index_) >= input_device_tensors_.size()) { | |||
| MS_LOG(ERROR) << "The input index is out of range, need:" << input_data->index_ | |||
| << " current:" << input_device_tensors_.size() << " for actor:" << GetAID(); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(input_data->data_); | |||
| input_device_tensors_[input_data->index_] = input_data->data_; | |||
| } | |||
| // If the data comes from the data source actor, use the default branch id. | |||
| output_branch_id_ = 0; | |||
| } else { | |||
| // 2.Data comes from the gather actor, it is in the form of data with branch id. | |||
| output_branch_id_ = input_op_data_with_branch_id_[sequential_num].front().second; | |||
| const auto &device_tensors = input_op_data_with_branch_id_[sequential_num].front().first; | |||
| if (device_tensors.size() != formal_parameters_.size()) { | |||
| MS_LOG(ERROR) << "Invalid input num, need:" << formal_parameters_.size() << " current:" << device_tensors.size(); | |||
| } | |||
| input_device_tensors_ = device_tensors; | |||
| } | |||
| // Init the device tensor in output data. | |||
| for (size_t i = 0; i < output_data_by_output_index_.size(); ++i) { | |||
| if (output_data_by_output_index_[i].empty()) { | |||
| continue; | |||
| } | |||
| const auto &data = input_device_tensors_[i]; | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| for (auto &output_data : output_data_by_output_index_[i]) { | |||
| MS_EXCEPTION_IF_NULL(output_data); | |||
| output_data->data_ = data; | |||
| } | |||
| } | |||
| } | |||
| bool EntranceActor::CheckActorStatus(const OpContext<DeviceTensor> *const context) const { | |||
| if (is_actor_ready_) { | |||
| return true; | |||
| } | |||
| // During operation, entrance actor can be enabled only when receives all control arrows. | |||
| if (input_controls_num_ != 0) { | |||
| const auto &control_iter = input_op_controls_.find(context->sequential_num_); | |||
| if (control_iter != input_op_controls_.end() && control_iter->second.size() == input_controls_num_) { | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| bool EntranceActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) const { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| // When the entrance actor is in the disabled state, it cannot be run. | |||
| if (!CheckActorStatus(context)) { | |||
| return false; | |||
| } | |||
| // Data comes from the data source actor. | |||
| if (input_datas_num_ != 0) { | |||
| const auto &data_iter = input_op_datas_.find(context->sequential_num_); | |||
| if (data_iter != input_op_datas_.end() && data_iter->second.size() == input_datas_num_) { | |||
| return true; | |||
| } | |||
| } | |||
| // Data comes from the gather actor. | |||
| const auto &iter = input_op_data_with_branch_id_.find(context->sequential_num_); | |||
| if (iter == input_op_data_with_branch_id_.end() || iter->second.empty()) { | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| void EntranceActor::EraseInput(const OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| auto &sequential_num = context->sequential_num_; | |||
| const auto &data_iter = input_op_datas_.find(sequential_num); | |||
| if (data_iter != input_op_datas_.end()) { | |||
| input_op_datas_.erase(data_iter); | |||
| return; | |||
| } | |||
| const auto &iter = input_op_data_with_branch_id_.find(sequential_num); | |||
| if (iter == input_op_data_with_branch_id_.end() || iter->second.empty()) { | |||
| MS_LOG(ERROR) << "Cannot find input in batch op result for actor:" << GetAID(); | |||
| } | |||
| iter->second.pop(); | |||
| if (iter->second.empty()) { | |||
| input_op_data_with_branch_id_.erase(sequential_num); | |||
| } | |||
| } | |||
| } // namespace runtime | |||
| } // namespace mindspore | |||
| @@ -40,17 +40,22 @@ class EntranceActor : public ControlActor { | |||
| input_device_tensors_.resize(parameters.size()); | |||
| } | |||
| ~EntranceActor() override = default; | |||
| void RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context); | |||
| void RunOpDataWithBranchID(std::vector<DeviceTensor *> input_data, int branch_id, | |||
| OpContext<DeviceTensor> *const context); | |||
| protected: | |||
| void FetchInput(OpContext<DeviceTensor> *const context); | |||
| bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const; | |||
| void EraseInput(const OpContext<DeviceTensor> *const context); | |||
| void Run(OpContext<DeviceTensor> *const context) override; | |||
| void FetchInput(OpContext<DeviceTensor> *const context) override; | |||
| bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const override; | |||
| void EraseInput(const OpContext<DeviceTensor> *const context) override; | |||
| private: | |||
| friend class ControlNodeScheduler; | |||
| // Check if actor is enable. During operation, entrance actor can be enabled only when receives all control arrows. | |||
| bool CheckActorStatus(const OpContext<DeviceTensor> *const context) const; | |||
| // Is actor ready indicates whether the entrance actor can be executed. In the control flow, the subgraph is an | |||
| // atomic operation, and execution can only continue after the output of the corresponding exit actor is completed. | |||
| // At this time, the exit actor will notify the entrance actor to change the ready to true. | |||
| @@ -0,0 +1,119 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "runtime/framework/actor/control_flow/exit_actor.h" | |||
| #include "runtime/framework/actor/output_actor.h" | |||
| namespace mindspore { | |||
| namespace runtime { | |||
| void ExitActor::Init() { | |||
| // Init output data in base class. | |||
| ControlActor::Init(); | |||
| // Init output data in each output branch. | |||
| for (size_t i = 0; i < output_branch_data_arrows_.size(); ++i) { | |||
| auto &output_branch_data_arrows = output_branch_data_arrows_[i]; | |||
| for (auto &data_arrow : output_branch_data_arrows) { | |||
| MS_EXCEPTION_IF_NULL(data_arrow); | |||
| auto data = std::make_unique<OpData<DeviceTensor>>(data_arrow->to_op_id_, nullptr, data_arrow->to_input_index_); | |||
| output_branch_data_[i].emplace_back(data_arrow->from_output_index_, std::move(data)); | |||
| } | |||
| } | |||
| } | |||
| void ExitActor::FetchInput(OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| ControlActor::FetchInput(context); | |||
| CopyDeviceAddress(); | |||
| auto data_iter = output_branch_data_.find(output_branch_id_); | |||
| if (data_iter != output_branch_data_.end()) { | |||
| for (auto &output_data : data_iter->second) { | |||
| MS_EXCEPTION_IF_NULL(output_data.second); | |||
| MS_EXCEPTION_IF_NULL(input_device_tensors_[output_data.first]); | |||
| output_data.second->data_ = input_device_tensors_[output_data.first]; | |||
| } | |||
| } | |||
| } | |||
| void ExitActor::SendOutput(OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| // 1.Send output in base class. | |||
| ControlActor::SendOutput(context); | |||
| // 2.Send output control in output branch. | |||
| const auto &control_iter = output_branch_control_arrows_.find(output_branch_id_); | |||
| if (control_iter != output_branch_control_arrows_.end()) { | |||
| auto source_aid = const_cast<AID *>(&GetAID()); | |||
| for (const auto &control_arrow : control_iter->second) { | |||
| Async(control_arrow, &OpActor::RunOpControl, source_aid, context); | |||
| } | |||
| } | |||
| // 2.Send output data in output branch. | |||
| const auto &branch_data_iter = output_branch_data_.find(output_branch_id_); | |||
| if (branch_data_iter != output_branch_data_.end()) { | |||
| for (const auto &output_data : branch_data_iter->second) { | |||
| MS_EXCEPTION_IF_NULL(output_data.second); | |||
| Async(output_data.second->op_id_, &OpActor::RunOpData, output_data.second.get(), context); | |||
| } | |||
| } | |||
| } | |||
| void ExitActor::CopyDeviceAddress() { | |||
| std::vector<DeviceTensor *> new_device_tensors; | |||
| for (size_t i = 0; i < input_device_tensors_.size(); ++i) { | |||
| auto input_device_tensor = input_device_tensors_[i]; | |||
| MS_EXCEPTION_IF_NULL(input_device_tensor); | |||
| const KernelWithIndex &node_with_index = input_device_tensor->GetNodeIndex(); | |||
| MS_EXCEPTION_IF_NULL(node_with_index.first); | |||
| if (!node_with_index.first->isa<CNode>()) { | |||
| continue; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(device_contexts_[i]); | |||
| auto new_device_tensor = | |||
| device_contexts_[i]->CreateDeviceAddress(nullptr, input_device_tensors_[i]->GetSize(), | |||
| input_device_tensors_[i]->format(), input_device_tensors_[i]->type_id()); | |||
| MS_EXCEPTION_IF_NULL(new_device_tensor); | |||
| new_device_tensor->set_ptr(input_device_tensor->GetMutablePtr()); | |||
| new_device_tensor->set_from_mem_pool(input_device_tensor->from_mem_pool()); | |||
| new_device_tensor->SetNodeIndex(node_with_index.first, node_with_index.second); | |||
| new_device_tensor->set_original_ref_count(SIZE_MAX); | |||
| new_device_tensor->ResetRefCount(); | |||
| new_device_tensors.emplace_back(new_device_tensor.get()); | |||
| created_device_tensors_.emplace_back(new_device_tensor); | |||
| input_device_tensor->set_ptr(nullptr); | |||
| input_device_tensor->set_from_mem_pool(false); | |||
| } | |||
| input_device_tensors_.swap(new_device_tensors); | |||
| for (size_t i = 0; i < output_data_by_output_index_.size(); ++i) { | |||
| if (output_data_by_output_index_[i].empty()) { | |||
| continue; | |||
| } | |||
| const auto &data = input_device_tensors_[i]; | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| for (auto &output_data : output_data_by_output_index_[i]) { | |||
| MS_EXCEPTION_IF_NULL(output_data); | |||
| output_data->data_ = data; | |||
| } | |||
| } | |||
| } | |||
| } // namespace runtime | |||
| } // namespace mindspore | |||
| @@ -39,10 +39,21 @@ class ExitActor : public ControlActor { | |||
| } | |||
| ~ExitActor() override = default; | |||
| void Init(); | |||
| void Init() override; | |||
| const std::unordered_map<int, std::vector<AID>> &output_branch_control_arrows() const { | |||
| return output_branch_control_arrows_; | |||
| } | |||
| const std::unordered_map<int, std::vector<DataArrowPtr>> &output_branch_data_arrows() const { | |||
| return output_branch_data_arrows_; | |||
| } | |||
| const std::unordered_map<int, std::vector<DataArrowPtr>> &output_branch_partial_arrows() const { | |||
| return output_branch_partial_arrows_; | |||
| } | |||
| protected: | |||
| void FetchInput(OpContext<DeviceTensor> *const context); | |||
| void FetchInput(OpContext<DeviceTensor> *const context) override; | |||
| void SendOutput(OpContext<DeviceTensor> *const context) override; | |||
| private: | |||
| friend class ControlNodeScheduler; | |||
| @@ -15,7 +15,40 @@ | |||
| */ | |||
| #include "runtime/framework/actor/control_flow/gather_actor.h" | |||
| #include "runtime/framework/actor/control_flow/entrance_actor.h" | |||
| namespace mindspore { | |||
| namespace runtime {} // namespace runtime | |||
| namespace runtime { | |||
| GatherActor::GatherActor(const std::string &name, const std::vector<KernelWithIndex> ¶meters, | |||
| const AnfNodePtr &node) | |||
| : ControlActor(name, KernelTransformType::kGatherActor, parameters, node) {} | |||
| void GatherActor::FetchInput(OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| ControlActor::FetchInput(context); | |||
| output_partial_ = input_partials_[0]; | |||
| // Put other real parameter in partial. | |||
| for (const auto &device_tensor : input_device_tensors_) { | |||
| if (device_tensor != nullptr) { | |||
| output_partial_.second.emplace_back(device_tensor); | |||
| } | |||
| } | |||
| } | |||
| void GatherActor::SendOutput(OpContext<DeviceTensor> *const context) { | |||
| // Send data with branch id. | |||
| const auto &iter = output_data_with_branch_id_arrows_.find(output_partial_.first); | |||
| if (iter != output_data_with_branch_id_arrows_.end()) { | |||
| for (const auto &data_with_branch_id_arrow : iter->second) { | |||
| Async(data_with_branch_id_arrow, &EntranceActor::RunOpDataWithBranchID, output_partial_.second, output_branch_id_, | |||
| context); | |||
| } | |||
| } | |||
| // Control arrow needs to be sent after the real parameter data and branch id. | |||
| ControlActor::SendOutput(context); | |||
| } | |||
| } // namespace runtime | |||
| } // namespace mindspore | |||
| @@ -37,16 +37,19 @@ class GatherActor : public ControlActor { | |||
| public: | |||
| GatherActor(const std::string &name, const std::vector<KernelWithIndex> ¶meters, const AnfNodePtr &node); | |||
| ~GatherActor() override = default; | |||
| const std::unordered_map<FuncGraph *, std::vector<AID>> &output_data_with_branch_id_arrows() const { | |||
| return output_data_with_branch_id_arrows_; | |||
| } | |||
| protected: | |||
| void FetchInput(OpContext<DeviceTensor> *const context); | |||
| void SendOutput(OpContext<DeviceTensor> *const context); | |||
| void FetchInput(OpContext<DeviceTensor> *const context) override; | |||
| void SendOutput(OpContext<DeviceTensor> *const context) override; | |||
| private: | |||
| friend class ControlNodeScheduler; | |||
| // When the output data arrow needs to have a branch id, there will be multiple output branches. | |||
| std::unordered_map<int, std::vector<AID>> output_data_with_branch_id_arrows_; | |||
| // There will be multiple output branches for gather actor according the funcgraph in partial. | |||
| std::unordered_map<FuncGraph *, std::vector<AID>> output_data_with_branch_id_arrows_; | |||
| }; | |||
| using GatherActorPtr = std::shared_ptr<GatherActor>; | |||
| @@ -0,0 +1,37 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "runtime/framework/actor/control_flow/stack_actor.h" | |||
| #include "runtime/framework/actor/memory_manager_actor.h" | |||
| #include "runtime/framework/control_node_parser.h" | |||
| namespace mindspore { | |||
| namespace runtime { | |||
| StackActor::StackActor(const std::string &name, const std::vector<KernelWithIndex> ¶meters) | |||
| : ControlActor(name, KernelTransformType::kStackActor, parameters, nullptr) { | |||
| input_device_tensors_.resize(parameters.size()); | |||
| } | |||
| bool StackActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) const { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| return false; | |||
| } | |||
| void StackActor::FetchInput(OpContext<DeviceTensor> *const context) { MS_EXCEPTION_IF_NULL(context); } | |||
| void StackActor::EraseInput(const OpContext<DeviceTensor> *const context) { MS_EXCEPTION_IF_NULL(context); } | |||
| } // namespace runtime | |||
| } // namespace mindspore | |||
| @@ -39,10 +39,9 @@ class StackActor : public ControlActor { | |||
| ~StackActor() override = default; | |||
| protected: | |||
| void RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context); | |||
| void FetchInput(OpContext<DeviceTensor> *const context); | |||
| bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const; | |||
| void EraseInput(const OpContext<DeviceTensor> *const context); | |||
| void FetchInput(OpContext<DeviceTensor> *const context) override; | |||
| bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const override; | |||
| void EraseInput(const OpContext<DeviceTensor> *const context) override; | |||
| private: | |||
| friend class ControlNodeScheduler; | |||
| @@ -15,13 +15,93 @@ | |||
| */ | |||
#include "runtime/framework/actor/control_flow/switch_actor.h"

#include <cstring>

#include "abstract/utils.h"
#include "mindrt/include/async/async.h"
#include "runtime/framework/actor/control_flow/entrance_actor.h"
#include "runtime/framework/actor/control_flow/gather_actor.h"
#include "runtime/framework/actor/memory_manager_actor.h"
#include "runtime/framework/actor/output_actor.h"
#include "runtime/framework/actor/output_actor.h"
#include "utils/log_adapter.h"
| namespace mindspore { | |||
| namespace runtime {} // namespace runtime | |||
| namespace runtime { | |||
// Scratch-buffer size (bytes) for reading the switch condition from the device;
// must hold the largest supported index type (int64_t, checked in GetIndex).
constexpr size_t kMaxSwitchCondSize = 8;
// A switch actor exposes exactly one data output: the selected branch's input.
constexpr size_t kSwitchDefaultOutputNum = 1;
// Constructs a switch actor registered as KernelTransformType::kSwitchActor.
// The first formal parameter is the condition; the rest are the branch inputs.
// NOTE(review): parameter layout inferred from kSwitchCondPos usage in FetchInput — confirm.
SwitchActor::SwitchActor(const std::string &name, const std::vector<KernelWithIndex> &parameters,
                         const AnfNodePtr &node)
    : ControlActor(name, KernelTransformType::kSwitchActor, parameters, node) {
  // One device context slot per formal parameter.
  device_contexts_.resize(parameters.size());
  // A switch actor has a single data output (kSwitchDefaultOutputNum == 1).
  output_data_by_output_index_.resize(kSwitchDefaultOutputNum);
}
| void SwitchActor::Init() { | |||
| // Init output data. | |||
| for (const auto &data_arrow : output_data_arrows_) { | |||
| if (data_arrow->from_output_index_ != 0) { | |||
| MS_LOG(ERROR) << "Invalid from index:" << data_arrow->from_output_index_ << " for actor:" << GetAID(); | |||
| } | |||
| auto data = std::make_unique<OpData<DeviceTensor>>(data_arrow->to_op_id_, nullptr, data_arrow->to_input_index_); | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| (void)output_data_.emplace_back(std::move(data)); | |||
| (void)output_data_by_output_index_[data_arrow->from_output_index_].emplace_back(data.get()); | |||
| } | |||
| } | |||
| void SwitchActor::FetchInput(OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| // Call the base class interface to get input data and input partial. | |||
| ControlActor::FetchInput(context); | |||
| size_t index = GetIndex(context); | |||
| if (!output_partial_arrows_.empty()) { | |||
| auto func_graph = input_partials_[index + kSwitchCondPos].first; | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| output_partial_ = input_partials_[index + kSwitchCondPos]; | |||
| } | |||
| for (auto &output_data : output_data_by_output_index_[kSwitchDefaultOutputNum - 1]) { | |||
| MS_EXCEPTION_IF_NULL(output_data); | |||
| MS_EXCEPTION_IF_NULL(input_device_tensors_[index + kSwitchCondPos]); | |||
| output_data->data_ = input_device_tensors_[index + kSwitchCondPos]; | |||
| } | |||
| } | |||
| size_t SwitchActor::GetIndex(const OpContext<DeviceTensor> *const context) const { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| MS_EXCEPTION_IF_NULL(input_device_tensors_[0]); | |||
| DeviceTensor *device_tensor = input_device_tensors_[0]; | |||
| TypeId type_id = device_tensor->type_id(); | |||
| size_t size = abstract::TypeIdSize(type_id); | |||
| if (size > sizeof(int64_t)) { | |||
| MS_LOG(ERROR) << "Index must be Int type."; | |||
| return 0; | |||
| } | |||
| int64_t index = 0; | |||
| char buf[kMaxSwitchCondSize] = {0}; | |||
| ShapeVector host_shape; | |||
| if (!device_tensor->SyncDeviceToHost(host_shape, size, type_id, static_cast<void *>(buf))) { | |||
| MS_LOG(ERROR) << GetAID().Name() << " get index from device address failed, type id:" << std::to_string(type_id) | |||
| << ", device type:" << std::to_string(static_cast<int>(device_contexts_[0]->GetDeviceAddressType())); | |||
| return 0; | |||
| } | |||
| if (type_id == TypeId::kNumberTypeInt32) { | |||
| index = static_cast<int64_t>((static_cast<int32_t *>(static_cast<void *>(buf)))[0]); | |||
| } else if (type_id == TypeId::kNumberTypeInt64) { | |||
| index = (static_cast<int64_t *>(static_cast<void *>(buf)))[0]; | |||
| } else if (type_id == TypeId::kNumberTypeBool) { | |||
| bool cond = (static_cast<bool *>(static_cast<void *>(buf)))[0]; | |||
| index = static_cast<int64_t>(cond ? 1 : 0); | |||
| } else { | |||
| MS_LOG(ERROR) << "Index must be Int type."; | |||
| return 0; | |||
| } | |||
| // SwitchLayer node support negative index range [-size, -1]. | |||
| if (index < 0) { | |||
| index += SizeToInt(formal_parameters_.size() - 1); | |||
| } | |||
| return LongToSize(index); | |||
| } | |||
| } // namespace runtime | |||
| } // namespace mindspore | |||
| @@ -40,7 +40,7 @@ class SwitchActor : public ControlActor { | |||
| void Init() override; | |||
| protected: | |||
| void FetchInput(OpContext<DeviceTensor> *const context); | |||
| void FetchInput(OpContext<DeviceTensor> *const context) override; | |||
| private: | |||
| friend class ControlNodeScheduler; | |||
| @@ -50,6 +50,7 @@ class DataSourceActor : public DebugAwareActor { | |||
| protected: | |||
| friend class GraphScheduler; | |||
| friend class ControlNodeScheduler; | |||
| void Run(OpContext<DeviceTensor> *const context) override { FetchData(context); } | |||
| @@ -96,6 +97,7 @@ class DeviceQueueDataSourceActor : public DataSourceActor { | |||
| private: | |||
| friend class GraphScheduler; | |||
| friend class ControlNodeScheduler; | |||
| // Input data kernel(for example GetNext) fetches data from device queue. | |||
| CNodePtr data_kernel_{nullptr}; | |||
| @@ -130,6 +132,7 @@ class HostQueueDataSourceActor : public DataSourceActor { | |||
| private: | |||
| friend class GraphScheduler; | |||
| friend class ControlNodeScheduler; | |||
| // Judge all the data_nodes_ is from the same device. | |||
| bool IsSameDeviceType() const; | |||
| @@ -81,6 +81,7 @@ class KernelActor : public DebugAwareActor { | |||
| private: | |||
| friend class GraphScheduler; | |||
| friend class ControlNodeScheduler; | |||
| // Fetch the device tensor for launch. | |||
| void FetchInputDeviceTensor(OpContext<DeviceTensor> *const context); | |||
| @@ -74,6 +74,7 @@ class OutputActor : public AbstractActor { | |||
| private: | |||
| friend class GraphScheduler; | |||
| friend class ControlNodeScheduler; | |||
| TensorPtr CreateOutputTensor(const AnfNodePtr &output_node, size_t output_index, size_t output_position); | |||