Browse Source

Add interface implementation

tags/v1.6.0
gaoyong10 4 years ago
parent
commit
ba28e554d9
16 changed files with 685 additions and 39 deletions
  1. +5
    -3
      mindspore/ccsrc/runtime/framework/actor/abstract_actor.cc
  2. +204
    -0
      mindspore/ccsrc/runtime/framework/actor/control_flow/control_actor.cc
  3. +17
    -14
      mindspore/ccsrc/runtime/framework/actor/control_flow/control_actor.h
  4. +145
    -0
      mindspore/ccsrc/runtime/framework/actor/control_flow/entrance_actor.cc
  5. +9
    -4
      mindspore/ccsrc/runtime/framework/actor/control_flow/entrance_actor.h
  6. +119
    -0
      mindspore/ccsrc/runtime/framework/actor/control_flow/exit_actor.cc
  7. +13
    -2
      mindspore/ccsrc/runtime/framework/actor/control_flow/exit_actor.h
  8. +34
    -1
      mindspore/ccsrc/runtime/framework/actor/control_flow/gather_actor.cc
  9. +7
    -4
      mindspore/ccsrc/runtime/framework/actor/control_flow/gather_actor.h
  10. +37
    -0
      mindspore/ccsrc/runtime/framework/actor/control_flow/stack_actor.cc
  11. +3
    -4
      mindspore/ccsrc/runtime/framework/actor/control_flow/stack_actor.h
  12. +86
    -6
      mindspore/ccsrc/runtime/framework/actor/control_flow/switch_actor.cc
  13. +1
    -1
      mindspore/ccsrc/runtime/framework/actor/control_flow/switch_actor.h
  14. +3
    -0
      mindspore/ccsrc/runtime/framework/actor/data_source_actor.h
  15. +1
    -0
      mindspore/ccsrc/runtime/framework/actor/kernel_actor.h
  16. +1
    -0
      mindspore/ccsrc/runtime/framework/actor/output_actor.h

+ 5
- 3
mindspore/ccsrc/runtime/framework/actor/abstract_actor.cc View File

@@ -96,8 +96,9 @@ void AbstractActor::SendOutput(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
// Must be the execution order: send data --> send control, avoid the illegal timing problem.
// 1.Send output data.
if ((output_data_arrows_.size() != output_data_.size()) ||
(output_data_arrows_.size() != output_data_nodes_.size())) {
if (((output_data_arrows_.size() != output_data_.size()) ||
(output_data_arrows_.size() != output_data_nodes_.size())) &&
(type_ < KernelTransformType::kSwitchActor)) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The size of output data arrows is not equal to the output data.");
}
size_t output_data_arrow_index = 0;
@@ -121,7 +122,8 @@ void AbstractActor::SendOutput(OpContext<DeviceTensor> *const context) {
SendRecorderInfo(context);

// No output.
if ((output_data_arrows_.size() == 0) && (output_control_arrows_.size() == 0)) {
if ((output_data_arrows_.size() == 0) && (output_control_arrows_.size() == 0) &&
(type_ < KernelTransformType::kSwitchActor)) {
SET_OPCONTEXT_SUCCESS_RET((*context));
}
}


+ 204
- 0
mindspore/ccsrc/runtime/framework/actor/control_flow/control_actor.cc View File

@@ -0,0 +1,204 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "runtime/framework/actor/control_flow/control_actor.h"

namespace mindspore {
namespace runtime {
ControlActor::ControlActor(const std::string &name, KernelTransformType type,
const std::vector<KernelWithIndex> &parameters, const AnfNodePtr &node)
: AbstractActor(name, type, nullptr), formal_parameters_(parameters), node_(node) {
input_partials_.resize(parameters.size());
input_device_tensors_.resize(parameters.size());
}

void ControlActor::Init() {
output_data_by_output_index_.resize(formal_parameters_.size());
for (auto &data_arrow : output_data_arrows_) {
MS_EXCEPTION_IF_NULL(data_arrow);
if (IntToSize(data_arrow->from_output_index_) >= formal_parameters_.size()) {
MS_LOG(EXCEPTION) << "The output index is out of range: " << GetAID();
}

auto data = std::make_unique<OpData<DeviceTensor>>(data_arrow->to_op_id_, nullptr, data_arrow->to_input_index_);
(void)output_data_by_output_index_[data_arrow->from_output_index_].emplace_back(data.get());
(void)output_data_.emplace_back(std::move(data));
}
}

size_t ControlActor::FetchNodePosition(const KernelWithIndex &node) const {
const auto &iter = find(formal_parameters_.begin(), formal_parameters_.end(), node);
if (iter == formal_parameters_.end()) {
MS_LOG(EXCEPTION) << "Invalid formal parameter:" << node.first->DebugString() << " for actor:" << GetAID();
}
return iter - formal_parameters_.begin();
}

void ControlActor::Run(OpContext<DeviceTensor> *const context) {
FetchInput(context);
EraseInput(context);
SendOutput(context);
}

void ControlActor::RunOpPartial(FuncGraph *func_graph, std::vector<DeviceTensor *> input_data, size_t position,
OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(context);
auto &sequential_num = context->sequential_num_;
input_op_partials_[sequential_num].emplace_back(position, OpPartial(func_graph, input_data));

if (CheckRunningCondition(context)) {
Run(context);
}
}

void ControlActor::RunBranchID(int branch_id, OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
auto &sequential_num = context->sequential_num_;
input_branch_ids_[sequential_num].push(branch_id);

if (CheckRunningCondition(context)) {
Run(context);
}
}

bool ControlActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) const {
MS_EXCEPTION_IF_NULL(context);

if (!AbstractActor::CheckRunningCondition(context)) {
return false;
}

if (input_partials_num_ != 0) {
const auto &partial_iter = input_op_partials_.find(context->sequential_num_);
if (partial_iter == input_op_partials_.end()) {
return false;
}
if (partial_iter->second.size() != input_partials_num_) {
return false;
}
}
return true;
}

void ControlActor::FetchInput(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);

// Fetch input device tensor from input data.
const auto &data_iter = input_op_datas_.find(context->sequential_num_);
if (data_iter != input_op_datas_.end()) {
for (auto &input_data : data_iter->second) {
MS_EXCEPTION_IF_NULL(input_data);
if (IntToSize(input_data->index_) >= input_device_tensors_.size()) {
MS_LOG(ERROR) << "Invalid index, need:" << input_data->index_ << " current:" << input_device_tensors_.size()
<< " for actor:" << GetAID();
}

input_device_tensors_[input_data->index_] = input_data->data_;
}
}

// Fetch input device tensor from device store.
for (auto &device_tensor_store_key : device_tensor_store_keys_) {
auto device_tensor = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second.get(),
device_contexts_[0]->GetDeviceAddressType());
if (device_tensor == nullptr) {
MS_LOG(ERROR) << GetAID() << " get device tensor store failed: " << device_tensor_store_key.second->DebugString();
}

if (device_tensor_store_key.first >= input_device_tensors_.size()) {
MS_LOG(ERROR) << "The input index is out of range, need:" << device_tensor_store_key.first
<< " current:" << input_device_tensors_.size() << " for actor:" << GetAID();
}
input_device_tensors_[device_tensor_store_key.first] = device_tensor;
}

for (size_t i = 0; i < output_data_by_output_index_.size(); ++i) {
if (output_data_by_output_index_[i].empty()) {
continue;
}
const auto &data = input_device_tensors_[i];
MS_EXCEPTION_IF_NULL(data);
for (auto &output_data : output_data_by_output_index_[i]) {
MS_EXCEPTION_IF_NULL(output_data);
output_data->data_ = data;
}
}

// Fetch input partial from input data.
const auto &partial_iter = input_op_partials_.find(context->sequential_num_);
if (partial_iter != input_op_partials_.end()) {
for (const auto &input_partial : partial_iter->second) {
MS_EXCEPTION_IF_NULL(input_partial.second.first);
input_partials_[input_partial.first] = input_partial.second;
}
}

// Fetch input partial from local partial.
for (const auto &local_partial : local_partials_) {
input_partials_[local_partial.first] = local_partial.second;
}

// Fetch branch id in stack.
auto iter = input_branch_ids_.find(context->sequential_num_);
if (iter != input_branch_ids_.end() && (!iter->second.empty())) {
output_branch_id_ = iter->second.top();
}
}

void ControlActor::EraseInput(const OpContext<DeviceTensor> *context) {
AbstractActor::EraseInput(context);
const auto &sequential_num = context->sequential_num_;
if (input_partials_num_ != 0) {
auto ret = input_op_partials_.erase(sequential_num);
if (ret == 0) {
std::string error_info = "Erase input partial failed: " + GetAID().Name();
// The sequential num may be invalid, can't set the promise value of context.
MS_LOG(ERROR) << error_info << ", sequential_num: " << sequential_num;
}
}

if (input_branch_ids_.find(sequential_num) != input_branch_ids_.end()) {
input_branch_ids_[sequential_num].pop();
if (input_branch_ids_[sequential_num].empty()) {
auto ret = input_branch_ids_.erase(sequential_num);
if (ret == 0) {
MS_LOG(ERROR) << "Erase input branch id failed: " << GetAID() << ", sequential_num: " << sequential_num;
return;
}
}
}
}

void ControlActor::SendOutput(OpContext<DeviceTensor> *const context) {
// Send branch id.
for (const auto &branch_id_arrow : output_branch_id_arrows_) {
Async(branch_id_arrow, &ControlActor::RunBranchID, output_branch_id_, context);
}

// Send data in base class.
AbstractActor::SendOutput(context);

// Send Partial.
for (const auto &partial_arrow : output_partial_arrows_) {
MS_EXCEPTION_IF_NULL(partial_arrow);
MS_EXCEPTION_IF_NULL(output_partial_.first);
Async(partial_arrow->to_op_id_, &ControlActor::RunOpPartial, output_partial_.first, output_partial_.second,
IntToSize(partial_arrow->to_input_index_), context);
}
}
} // namespace runtime
} // namespace mindspore

+ 17
- 14
mindspore/ccsrc/runtime/framework/actor/control_flow/control_actor.h View File

@@ -35,7 +35,7 @@ namespace runtime {
// parameters and the id of the caller.
using OpDataWithBranchID = std::pair<std::vector<DeviceTensor *>, int>;
// Op partial represents the partial structure, including a funcgraph and its real parameters.
using OpPartial = std::pair<FuncGraphPtr, std::vector<DeviceTensor *>>;
using OpPartial = std::pair<FuncGraph *, std::vector<DeviceTensor *>>;
// The control actor is the base class of control flow actor.
class ControlActor : public AbstractActor {
public:
@@ -43,9 +43,7 @@ class ControlActor : public AbstractActor {
const AnfNodePtr &node);
~ControlActor() override = default;

void Init() override {}

void RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) override {}
void Init() override;

// Receive partial.
virtual void RunOpPartial(FuncGraph *func_graph, std::vector<DeviceTensor *> input_data, size_t position,
@@ -54,16 +52,19 @@ class ControlActor : public AbstractActor {
// Receive branch id.
virtual void RunBranchID(int branch_id, OpContext<DeviceTensor> *const context);

const std::vector<DataArrowPtr> &output_partial_arrows() const { return output_partial_arrows_; }
const std::vector<AID> &output_branch_id_arrows() const { return output_branch_id_arrows_; }

protected:
// Get the position of node in the input.
size_t FetchNodePosition(const KernelWithIndex &node) const;

// Get all input, including data, partial, branchid.
virtual void FetchInput(OpContext<DeviceTensor> *const context);
void Run(OpContext<DeviceTensor> *const context);
bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const;
void SendOutput(OpContext<DeviceTensor> *const context);
void EraseInput(const OpContext<DeviceTensor> *context);
void Run(OpContext<DeviceTensor> *const context) override;
bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const override;
void SendOutput(OpContext<DeviceTensor> *const context) override;
void EraseInput(const OpContext<DeviceTensor> *context) override;

// Input data.
// 1.Input partial.
@@ -81,17 +82,19 @@ class ControlActor : public AbstractActor {
// Fetch data. After fetch input, all the input collected is saved here.
std::vector<OpPartial> input_partials_;
std::vector<DeviceTensor *> input_device_tensors_;
// The branch id is the unique identifier of the control actor. In the control flow, there are multiple control
// actors calling the same subgraph at the same time. At this time, the output of the subgraph needs to be returned
// to the calling place according to the branch id.
int branch_id_;

// Input num.
size_t input_partials_num_;
size_t input_partials_num_{0};

// Output Arrows.
std::vector<DataArrowPtr> output_partial_arrows_;
std::vector<DataArrowPtr> output_branch_id_arrows_;
OpPartial output_partial_;

std::vector<AID> output_branch_id_arrows_;
// The branch id is the unique identifier of the control actor. In the control flow, there are multiple control
// actors calling the same subgraph at the same time. At this time, the output of the subgraph needs to be returned
// to the calling place according to the branch id.
int output_branch_id_;

// Partial data in local. When partial is only funcgraph without real parameter, it is stored inside the actor.
std::unordered_map<size_t, OpPartial> local_partials_;


+ 145
- 0
mindspore/ccsrc/runtime/framework/actor/control_flow/entrance_actor.cc View File

@@ -0,0 +1,145 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "runtime/framework/actor/control_flow/entrance_actor.h"
#include "runtime/framework/actor/control_flow/exit_actor.h"

namespace mindspore {
namespace runtime {
constexpr size_t kEntranceInputStartPos = 1;

void EntranceActor::RunOpDataWithBranchID(std::vector<DeviceTensor *> input_data, int branch_id,
OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
auto &sequential_num = context->sequential_num_;
input_op_data_with_branch_id_[sequential_num].emplace(input_data, branch_id);

if (CheckRunningCondition(context)) {
Run(context);
}
}

void EntranceActor::Run(OpContext<DeviceTensor> *const context) {
FetchInput(context);
EraseInput(context);
SendOutput(context);
// The actor needs to be disabled after the actor is running, until no actor is running in the entire funcgraph.
is_actor_ready_ = false;
}

void EntranceActor::FetchInput(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
auto &sequential_num = context->sequential_num_;

// There are two kinds of run conditions for entrance actor:
// 1.Data comes from the data source actor, it is in the form of data arrow.
const auto &data_iter = input_op_datas_.find(sequential_num);
if (data_iter != input_op_datas_.end()) {
for (auto &input_data : data_iter->second) {
MS_EXCEPTION_IF_NULL(input_data);
if (IntToSize(input_data->index_) >= input_device_tensors_.size()) {
MS_LOG(ERROR) << "The input index is out of range, need:" << input_data->index_
<< " current:" << input_device_tensors_.size() << " for actor:" << GetAID();
}
MS_EXCEPTION_IF_NULL(input_data->data_);
input_device_tensors_[input_data->index_] = input_data->data_;
}
// If the data comes from the data source actor, use the default branch id.
output_branch_id_ = 0;
} else {
// 2.Data comes from the gather actor, it is in the form of data with branch id.
output_branch_id_ = input_op_data_with_branch_id_[sequential_num].front().second;
const auto &device_tensors = input_op_data_with_branch_id_[sequential_num].front().first;
if (device_tensors.size() != formal_parameters_.size()) {
MS_LOG(ERROR) << "Invalid input num, need:" << formal_parameters_.size() << " current:" << device_tensors.size();
}
input_device_tensors_ = device_tensors;
}

// Init the device tensor in output data.
for (size_t i = 0; i < output_data_by_output_index_.size(); ++i) {
if (output_data_by_output_index_[i].empty()) {
continue;
}
const auto &data = input_device_tensors_[i];
MS_EXCEPTION_IF_NULL(data);
for (auto &output_data : output_data_by_output_index_[i]) {
MS_EXCEPTION_IF_NULL(output_data);
output_data->data_ = data;
}
}
}

bool EntranceActor::CheckActorStatus(const OpContext<DeviceTensor> *const context) const {
if (is_actor_ready_) {
return true;
}
// During operation, entrance actor can be enabled only when receives all control arrows.
if (input_controls_num_ != 0) {
const auto &control_iter = input_op_controls_.find(context->sequential_num_);
if (control_iter != input_op_controls_.end() && control_iter->second.size() == input_controls_num_) {
return true;
}
}
return false;
}

bool EntranceActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) const {
MS_EXCEPTION_IF_NULL(context);

// When the entrance actor is in the disabled state, it cannot be run.
if (!CheckActorStatus(context)) {
return false;
}

// Data comes from the data source actor.
if (input_datas_num_ != 0) {
const auto &data_iter = input_op_datas_.find(context->sequential_num_);
if (data_iter != input_op_datas_.end() && data_iter->second.size() == input_datas_num_) {
return true;
}
}

// Data comes from the gather actor.
const auto &iter = input_op_data_with_branch_id_.find(context->sequential_num_);
if (iter == input_op_data_with_branch_id_.end() || iter->second.empty()) {
return false;
}
return true;
}

void EntranceActor::EraseInput(const OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
auto &sequential_num = context->sequential_num_;

const auto &data_iter = input_op_datas_.find(sequential_num);
if (data_iter != input_op_datas_.end()) {
input_op_datas_.erase(data_iter);
return;
}

const auto &iter = input_op_data_with_branch_id_.find(sequential_num);
if (iter == input_op_data_with_branch_id_.end() || iter->second.empty()) {
MS_LOG(ERROR) << "Cannot find input in batch op result for actor:" << GetAID();
}

iter->second.pop();
if (iter->second.empty()) {
input_op_data_with_branch_id_.erase(sequential_num);
}
}
} // namespace runtime
} // namespace mindspore

+ 9
- 4
mindspore/ccsrc/runtime/framework/actor/control_flow/entrance_actor.h View File

@@ -40,17 +40,22 @@ class EntranceActor : public ControlActor {
input_device_tensors_.resize(parameters.size());
}
~EntranceActor() override = default;
void RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context);
void RunOpDataWithBranchID(std::vector<DeviceTensor *> input_data, int branch_id,
OpContext<DeviceTensor> *const context);

protected:
void FetchInput(OpContext<DeviceTensor> *const context);
bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const;
void EraseInput(const OpContext<DeviceTensor> *const context);
void Run(OpContext<DeviceTensor> *const context) override;
void FetchInput(OpContext<DeviceTensor> *const context) override;
bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const override;
void EraseInput(const OpContext<DeviceTensor> *const context) override;

private:
friend class ControlNodeScheduler;

// Check if actor is enable. During operation, entrance actor can be enabled only when receives all control arrows.
bool CheckActorStatus(const OpContext<DeviceTensor> *const context) const;

// Is actor ready indicates whether the entrance actor can be executed. In the control flow, the subgraph is an
// atomic operation, and execution can only continue after the output of the corresponding exit actor is completed.
// At this time, the exit actor will notify the entrance actor to change the ready to true.


+ 119
- 0
mindspore/ccsrc/runtime/framework/actor/control_flow/exit_actor.cc View File

@@ -0,0 +1,119 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "runtime/framework/actor/control_flow/exit_actor.h"
#include "runtime/framework/actor/output_actor.h"

namespace mindspore {
namespace runtime {
void ExitActor::Init() {
// Init output data in base class.
ControlActor::Init();

// Init output data in each output branch.
for (size_t i = 0; i < output_branch_data_arrows_.size(); ++i) {
auto &output_branch_data_arrows = output_branch_data_arrows_[i];
for (auto &data_arrow : output_branch_data_arrows) {
MS_EXCEPTION_IF_NULL(data_arrow);
auto data = std::make_unique<OpData<DeviceTensor>>(data_arrow->to_op_id_, nullptr, data_arrow->to_input_index_);
output_branch_data_[i].emplace_back(data_arrow->from_output_index_, std::move(data));
}
}
}

void ExitActor::FetchInput(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
ControlActor::FetchInput(context);
CopyDeviceAddress();

auto data_iter = output_branch_data_.find(output_branch_id_);
if (data_iter != output_branch_data_.end()) {
for (auto &output_data : data_iter->second) {
MS_EXCEPTION_IF_NULL(output_data.second);
MS_EXCEPTION_IF_NULL(input_device_tensors_[output_data.first]);
output_data.second->data_ = input_device_tensors_[output_data.first];
}
}
}

void ExitActor::SendOutput(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);

// 1.Send output in base class.
ControlActor::SendOutput(context);

// 2.Send output control in output branch.
const auto &control_iter = output_branch_control_arrows_.find(output_branch_id_);
if (control_iter != output_branch_control_arrows_.end()) {
auto source_aid = const_cast<AID *>(&GetAID());
for (const auto &control_arrow : control_iter->second) {
Async(control_arrow, &OpActor::RunOpControl, source_aid, context);
}
}

// 2.Send output data in output branch.
const auto &branch_data_iter = output_branch_data_.find(output_branch_id_);
if (branch_data_iter != output_branch_data_.end()) {
for (const auto &output_data : branch_data_iter->second) {
MS_EXCEPTION_IF_NULL(output_data.second);
Async(output_data.second->op_id_, &OpActor::RunOpData, output_data.second.get(), context);
}
}
}

void ExitActor::CopyDeviceAddress() {
std::vector<DeviceTensor *> new_device_tensors;
for (size_t i = 0; i < input_device_tensors_.size(); ++i) {
auto input_device_tensor = input_device_tensors_[i];
MS_EXCEPTION_IF_NULL(input_device_tensor);
const KernelWithIndex &node_with_index = input_device_tensor->GetNodeIndex();
MS_EXCEPTION_IF_NULL(node_with_index.first);
if (!node_with_index.first->isa<CNode>()) {
continue;
}

MS_EXCEPTION_IF_NULL(device_contexts_[i]);
auto new_device_tensor =
device_contexts_[i]->CreateDeviceAddress(nullptr, input_device_tensors_[i]->GetSize(),
input_device_tensors_[i]->format(), input_device_tensors_[i]->type_id());
MS_EXCEPTION_IF_NULL(new_device_tensor);
new_device_tensor->set_ptr(input_device_tensor->GetMutablePtr());
new_device_tensor->set_from_mem_pool(input_device_tensor->from_mem_pool());
new_device_tensor->SetNodeIndex(node_with_index.first, node_with_index.second);
new_device_tensor->set_original_ref_count(SIZE_MAX);
new_device_tensor->ResetRefCount();
new_device_tensors.emplace_back(new_device_tensor.get());
created_device_tensors_.emplace_back(new_device_tensor);

input_device_tensor->set_ptr(nullptr);
input_device_tensor->set_from_mem_pool(false);
}
input_device_tensors_.swap(new_device_tensors);

for (size_t i = 0; i < output_data_by_output_index_.size(); ++i) {
if (output_data_by_output_index_[i].empty()) {
continue;
}
const auto &data = input_device_tensors_[i];
MS_EXCEPTION_IF_NULL(data);
for (auto &output_data : output_data_by_output_index_[i]) {
MS_EXCEPTION_IF_NULL(output_data);
output_data->data_ = data;
}
}
}
} // namespace runtime
} // namespace mindspore

+ 13
- 2
mindspore/ccsrc/runtime/framework/actor/control_flow/exit_actor.h View File

@@ -39,10 +39,21 @@ class ExitActor : public ControlActor {
}
~ExitActor() override = default;

void Init();
void Init() override;

const std::unordered_map<int, std::vector<AID>> &output_branch_control_arrows() const {
return output_branch_control_arrows_;
}
const std::unordered_map<int, std::vector<DataArrowPtr>> &output_branch_data_arrows() const {
return output_branch_data_arrows_;
}
const std::unordered_map<int, std::vector<DataArrowPtr>> &output_branch_partial_arrows() const {
return output_branch_partial_arrows_;
}

protected:
void FetchInput(OpContext<DeviceTensor> *const context);
void FetchInput(OpContext<DeviceTensor> *const context) override;
void SendOutput(OpContext<DeviceTensor> *const context) override;

private:
friend class ControlNodeScheduler;


+ 34
- 1
mindspore/ccsrc/runtime/framework/actor/control_flow/gather_actor.cc View File

@@ -15,7 +15,40 @@
*/

#include "runtime/framework/actor/control_flow/gather_actor.h"
#include "runtime/framework/actor/control_flow/entrance_actor.h"

namespace mindspore {
namespace runtime {} // namespace runtime
namespace runtime {
GatherActor::GatherActor(const std::string &name, const std::vector<KernelWithIndex> &parameters,
const AnfNodePtr &node)
: ControlActor(name, KernelTransformType::kGatherActor, parameters, node) {}

void GatherActor::FetchInput(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);

ControlActor::FetchInput(context);
output_partial_ = input_partials_[0];

// Put other real parameter in partial.
for (const auto &device_tensor : input_device_tensors_) {
if (device_tensor != nullptr) {
output_partial_.second.emplace_back(device_tensor);
}
}
}

void GatherActor::SendOutput(OpContext<DeviceTensor> *const context) {
// Send data with branch id.
const auto &iter = output_data_with_branch_id_arrows_.find(output_partial_.first);
if (iter != output_data_with_branch_id_arrows_.end()) {
for (const auto &data_with_branch_id_arrow : iter->second) {
Async(data_with_branch_id_arrow, &EntranceActor::RunOpDataWithBranchID, output_partial_.second, output_branch_id_,
context);
}
}

// Control arrow needs to be sent after the real parameter data and branch id.
ControlActor::SendOutput(context);
}
} // namespace runtime
} // namespace mindspore

+ 7
- 4
mindspore/ccsrc/runtime/framework/actor/control_flow/gather_actor.h View File

@@ -37,16 +37,19 @@ class GatherActor : public ControlActor {
public:
GatherActor(const std::string &name, const std::vector<KernelWithIndex> &parameters, const AnfNodePtr &node);
~GatherActor() override = default;
const std::unordered_map<FuncGraph *, std::vector<AID>> &output_data_with_branch_id_arrows() const {
return output_data_with_branch_id_arrows_;
}

protected:
void FetchInput(OpContext<DeviceTensor> *const context);
void SendOutput(OpContext<DeviceTensor> *const context);
void FetchInput(OpContext<DeviceTensor> *const context) override;
void SendOutput(OpContext<DeviceTensor> *const context) override;

private:
friend class ControlNodeScheduler;

// When the output data arrow needs to have a branch id, there will be multiple output branches.
std::unordered_map<int, std::vector<AID>> output_data_with_branch_id_arrows_;
// There will be multiple output branches for gather actor according the funcgraph in partial.
std::unordered_map<FuncGraph *, std::vector<AID>> output_data_with_branch_id_arrows_;
};

using GatherActorPtr = std::shared_ptr<GatherActor>;


+ 37
- 0
mindspore/ccsrc/runtime/framework/actor/control_flow/stack_actor.cc View File

@@ -0,0 +1,37 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "runtime/framework/actor/control_flow/stack_actor.h"
#include "runtime/framework/actor/memory_manager_actor.h"
#include "runtime/framework/control_node_parser.h"

namespace mindspore {
namespace runtime {
StackActor::StackActor(const std::string &name, const std::vector<KernelWithIndex> &parameters)
: ControlActor(name, KernelTransformType::kStackActor, parameters, nullptr) {
input_device_tensors_.resize(parameters.size());
}

bool StackActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) const {
MS_EXCEPTION_IF_NULL(context);
return false;
}

void StackActor::FetchInput(OpContext<DeviceTensor> *const context) { MS_EXCEPTION_IF_NULL(context); }

void StackActor::EraseInput(const OpContext<DeviceTensor> *const context) { MS_EXCEPTION_IF_NULL(context); }
} // namespace runtime
} // namespace mindspore

+ 3
- 4
mindspore/ccsrc/runtime/framework/actor/control_flow/stack_actor.h View File

@@ -39,10 +39,9 @@ class StackActor : public ControlActor {
~StackActor() override = default;

protected:
void RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context);
void FetchInput(OpContext<DeviceTensor> *const context);
bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const;
void EraseInput(const OpContext<DeviceTensor> *const context);
void FetchInput(OpContext<DeviceTensor> *const context) override;
bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const override;
void EraseInput(const OpContext<DeviceTensor> *const context) override;

private:
friend class ControlNodeScheduler;


+ 86
- 6
mindspore/ccsrc/runtime/framework/actor/control_flow/switch_actor.cc View File

@@ -15,13 +15,93 @@
*/

#include "runtime/framework/actor/control_flow/switch_actor.h"
#include "runtime/framework/actor/control_flow/gather_actor.h"
#include "runtime/framework/actor/output_actor.h"
#include "runtime/framework/actor/memory_manager_actor.h"
#include "mindrt/include/async/async.h"
#include "runtime/framework/actor/control_flow/entrance_actor.h"
#include "abstract/utils.h"
#include "utils/log_adapter.h"
#include "runtime/framework/actor/output_actor.h"

namespace mindspore {
namespace runtime {} // namespace runtime
namespace runtime {
constexpr size_t kMaxSwitchCondSize = 8;
constexpr size_t kSwitchDefaultOutputNum = 1;

SwitchActor::SwitchActor(const std::string &name, const std::vector<KernelWithIndex> &parameters,
const AnfNodePtr &node)
: ControlActor(name, KernelTransformType::kSwitchActor, parameters, node) {
device_contexts_.resize(parameters.size());
output_data_by_output_index_.resize(kSwitchDefaultOutputNum);
}

void SwitchActor::Init() {
// Init output data.
for (const auto &data_arrow : output_data_arrows_) {
if (data_arrow->from_output_index_ != 0) {
MS_LOG(ERROR) << "Invalid from index:" << data_arrow->from_output_index_ << " for actor:" << GetAID();
}
auto data = std::make_unique<OpData<DeviceTensor>>(data_arrow->to_op_id_, nullptr, data_arrow->to_input_index_);
MS_EXCEPTION_IF_NULL(data);
(void)output_data_.emplace_back(std::move(data));
(void)output_data_by_output_index_[data_arrow->from_output_index_].emplace_back(data.get());
}
}

void SwitchActor::FetchInput(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);

// Call the base class interface to get input data and input partial.
ControlActor::FetchInput(context);
size_t index = GetIndex(context);

if (!output_partial_arrows_.empty()) {
auto func_graph = input_partials_[index + kSwitchCondPos].first;
MS_EXCEPTION_IF_NULL(func_graph);
output_partial_ = input_partials_[index + kSwitchCondPos];
}

for (auto &output_data : output_data_by_output_index_[kSwitchDefaultOutputNum - 1]) {
MS_EXCEPTION_IF_NULL(output_data);
MS_EXCEPTION_IF_NULL(input_device_tensors_[index + kSwitchCondPos]);
output_data->data_ = input_device_tensors_[index + kSwitchCondPos];
}
}

size_t SwitchActor::GetIndex(const OpContext<DeviceTensor> *const context) const {
MS_EXCEPTION_IF_NULL(context);
MS_EXCEPTION_IF_NULL(input_device_tensors_[0]);

DeviceTensor *device_tensor = input_device_tensors_[0];
TypeId type_id = device_tensor->type_id();
size_t size = abstract::TypeIdSize(type_id);
if (size > sizeof(int64_t)) {
MS_LOG(ERROR) << "Index must be Int type.";
return 0;
}

int64_t index = 0;
char buf[kMaxSwitchCondSize] = {0};
ShapeVector host_shape;
if (!device_tensor->SyncDeviceToHost(host_shape, size, type_id, static_cast<void *>(buf))) {
MS_LOG(ERROR) << GetAID().Name() << " get index from device address failed, type id:" << std::to_string(type_id)
<< ", device type:" << std::to_string(static_cast<int>(device_contexts_[0]->GetDeviceAddressType()));
return 0;
}

if (type_id == TypeId::kNumberTypeInt32) {
index = static_cast<int64_t>((static_cast<int32_t *>(static_cast<void *>(buf)))[0]);
} else if (type_id == TypeId::kNumberTypeInt64) {
index = (static_cast<int64_t *>(static_cast<void *>(buf)))[0];
} else if (type_id == TypeId::kNumberTypeBool) {
bool cond = (static_cast<bool *>(static_cast<void *>(buf)))[0];
index = static_cast<int64_t>(cond ? 1 : 0);
} else {
MS_LOG(ERROR) << "Index must be Int type.";
return 0;
}

// SwitchLayer node support negative index range [-size, -1].
if (index < 0) {
index += SizeToInt(formal_parameters_.size() - 1);
}
return LongToSize(index);
}
} // namespace runtime
} // namespace mindspore

+ 1
- 1
mindspore/ccsrc/runtime/framework/actor/control_flow/switch_actor.h View File

@@ -40,7 +40,7 @@ class SwitchActor : public ControlActor {
void Init() override;

protected:
void FetchInput(OpContext<DeviceTensor> *const context);
void FetchInput(OpContext<DeviceTensor> *const context) override;

private:
friend class ControlNodeScheduler;


+ 3
- 0
mindspore/ccsrc/runtime/framework/actor/data_source_actor.h View File

@@ -50,6 +50,7 @@ class DataSourceActor : public DebugAwareActor {

protected:
friend class GraphScheduler;
friend class ControlNodeScheduler;

void Run(OpContext<DeviceTensor> *const context) override { FetchData(context); }

@@ -96,6 +97,7 @@ class DeviceQueueDataSourceActor : public DataSourceActor {

private:
friend class GraphScheduler;
friend class ControlNodeScheduler;

// Input data kernel(for example GetNext) fetches data from device queue.
CNodePtr data_kernel_{nullptr};
@@ -130,6 +132,7 @@ class HostQueueDataSourceActor : public DataSourceActor {

private:
friend class GraphScheduler;
friend class ControlNodeScheduler;

// Judge all the data_nodes_ is from the same device.
bool IsSameDeviceType() const;


+ 1
- 0
mindspore/ccsrc/runtime/framework/actor/kernel_actor.h View File

@@ -81,6 +81,7 @@ class KernelActor : public DebugAwareActor {

private:
friend class GraphScheduler;
friend class ControlNodeScheduler;

// Fetch the device tensor for launch.
void FetchInputDeviceTensor(OpContext<DeviceTensor> *const context);


+ 1
- 0
mindspore/ccsrc/runtime/framework/actor/output_actor.h View File

@@ -74,6 +74,7 @@ class OutputActor : public AbstractActor {

private:
friend class GraphScheduler;
friend class ControlNodeScheduler;

TensorPtr CreateOutputTensor(const AnfNodePtr &output_node, size_t output_index, size_t output_position);



Loading…
Cancel
Save