Browse Source

!27283 unified runtime supports the dynamic ref count

Merge pull request !27283 from limingqi107/new_actor_runtime
tags/v1.6.0
i-robot Gitee 4 years ago
parent
commit
9ab1edc09e
20 changed files with 489 additions and 143 deletions
  1. +21
    -3
      mindspore/ccsrc/runtime/device/device_address.h
  2. +115
    -3
      mindspore/ccsrc/runtime/framework/actor/control_flow/control_actor.cc
  3. +21
    -3
      mindspore/ccsrc/runtime/framework/actor/control_flow/control_actor.h
  4. +38
    -0
      mindspore/ccsrc/runtime/framework/actor/control_flow/entrance_actor.cc
  5. +5
    -2
      mindspore/ccsrc/runtime/framework/actor/control_flow/entrance_actor.h
  6. +57
    -11
      mindspore/ccsrc/runtime/framework/actor/control_flow/exit_actor.cc
  7. +4
    -2
      mindspore/ccsrc/runtime/framework/actor/control_flow/exit_actor.h
  8. +48
    -24
      mindspore/ccsrc/runtime/framework/actor/control_flow/gather_actor.cc
  9. +5
    -1
      mindspore/ccsrc/runtime/framework/actor/control_flow/gather_actor.h
  10. +51
    -44
      mindspore/ccsrc/runtime/framework/actor/control_flow/stack_actor.cc
  11. +3
    -6
      mindspore/ccsrc/runtime/framework/actor/control_flow/stack_actor.h
  12. +3
    -3
      mindspore/ccsrc/runtime/framework/actor/control_flow/switch_actor.cc
  13. +2
    -1
      mindspore/ccsrc/runtime/framework/actor/control_flow/switch_actor.h
  14. +40
    -28
      mindspore/ccsrc/runtime/framework/actor/memory_manager_actor.cc
  15. +2
    -1
      mindspore/ccsrc/runtime/framework/actor/output_actor.cc
  16. +24
    -0
      mindspore/ccsrc/runtime/framework/actor/super_kernel_actor.cc
  17. +6
    -0
      mindspore/ccsrc/runtime/framework/actor/super_kernel_actor.h
  18. +33
    -9
      mindspore/ccsrc/runtime/framework/control_node_scheduler.cc
  19. +4
    -1
      mindspore/ccsrc/runtime/framework/control_node_scheduler.h
  20. +7
    -1
      mindspore/ccsrc/runtime/framework/graph_scheduler.cc

+ 21
- 3
mindspore/ccsrc/runtime/device/device_address.h View File

@@ -107,8 +107,8 @@ class DeviceAddress : public mindspore::DeviceSync {
virtual DeviceAddressStatus status() const { return DeviceAddressStatus::kInDevice; }
virtual DeviceAddressType DeviceType() const { return DeviceAddressType::kUnknown; }
void *GetMutablePtr() const override { return ptr_; }
std::string DeviceName() const { return device_name_; }
uint32_t DeviceID() const { return device_id_; }
std::string device_name() const { return device_name_; }
uint32_t device_id() const { return device_id_; }

virtual void SetNodeIndex(const AnfNodePtr &node, size_t out_index) { node_index_ = {node, out_index}; }
KernelWithIndex GetNodeIndex() const {
@@ -116,6 +116,21 @@ class DeviceAddress : public mindspore::DeviceSync {
: KernelWithIndex{node_index_.first.lock(), node_index_.second};
}

// The related interface of dynamic reference count operation.
// NOTE(review): "conut" is a typo for "count" in the identifiers below; it cannot be fixed here without
// breaking every caller of set_dynamic_ref_conut()/dynamic_ref_conut(), so it is only flagged.
// The member defaults to INT32_MAX (see the member declaration), and IncreaseDynamicRefCount saturates
// there — presumably INT32_MAX marks an address that does not participate in dynamic ref counting;
// confirm against ExitActor::CopyDeviceAddress, which sets the count to 0 for managed addresses.
// Setter for the dynamic reference count.
void set_dynamic_ref_conut(int32_t dynamic_ref_conut) { dynamic_ref_conut_ = dynamic_ref_conut; }
// Getter for the dynamic reference count.
int32_t dynamic_ref_conut() const { return dynamic_ref_conut_; }
// Increment the count; saturates at INT32_MAX (no-op once the sentinel value is reached).
void IncreaseDynamicRefCount() {
if (dynamic_ref_conut_ < INT32_MAX) {
dynamic_ref_conut_++;
}
}
// Decrement the count; raising an exception on underflow (count already <= 0) because that indicates
// a double free / mismatched increase-decrease pairing in the actor runtime.
void DecreaseDynamicRefCount() {
if (dynamic_ref_conut_ <= 0) {
MS_LOG(EXCEPTION) << "The dynamic reference count is invalid value:" << dynamic_ref_conut_;
}
dynamic_ref_conut_--;
}

virtual bool DumpMemToFile(const std::string &filepath, const std::string &host_fmt, const ShapeVector &host_shape,
TypeId host_type, bool trans_flag) const {
return true;
@@ -142,10 +157,13 @@ class DeviceAddress : public mindspore::DeviceSync {
// {node, out_index}
std::pair<AnfNodeWeakPtr, size_t> node_index_{AnfNodePtr(nullptr), 0};
// The device address of the node that owns the device address cannot be updated and replaced.
// application scenario: set to true when the hardware execution mode requires that ptr cannot be changed during
// Application scenario: set to true when the hardware execution mode requires that ptr cannot be changed during
// execution.
bool is_ptr_persisted_{false};

// The device address generated in the control flow scene uses dynamic_ref_conut_.
std::atomic_int32_t dynamic_ref_conut_{INT32_MAX};

// The key of device context.
std::string device_name_{""};
uint32_t device_id_{0};


+ 115
- 3
mindspore/ccsrc/runtime/framework/actor/control_flow/control_actor.cc View File

@@ -18,9 +18,9 @@

namespace mindspore {
namespace runtime {
ControlActor::ControlActor(const std::string &name, KernelTransformType type,
ControlActor::ControlActor(const std::string &name, KernelTransformType type, const AID &memory_manager_aid,
const std::vector<KernelWithIndex> &parameters, const AnfNodePtr &node)
: AbstractActor(name, type, nullptr), formal_parameters_(parameters), node_(node) {
: MemoryAwareActor(name, type, nullptr, memory_manager_aid), formal_parameters_(parameters), node_(node) {
for (size_t i = 0; i < parameters.size(); ++i) {
input_partials_.emplace_back(std::make_shared<OpPartial>());
}
@@ -41,6 +41,59 @@ void ControlActor::Init() {
}
}

std::vector<DeviceTensor *> ControlActor::GetAllDeviceTensors(const OpPartialPtr &op_partial) {
MS_EXCEPTION_IF_NULL(op_partial);
std::vector<DeviceTensor *> ret;
for (auto &device_tensor : op_partial->device_tensors_) {
(void)ret.emplace_back(device_tensor.second);
}

// Foreach the op partial to fetch the device tensors.
for (auto &partial : op_partial->partials_) {
auto ret_inner = GetAllDeviceTensors(partial.second);
(void)std::copy(ret_inner.begin(), ret_inner.end(), std::back_inserter(ret));
}

return ret;
}

std::vector<DeviceTensor *> ControlActor::GetAllDeviceTensors(const OpRealParameterWithBranchID &op_real_parameter) {
std::vector<DeviceTensor *> ret;
for (auto &device_tensor : op_real_parameter.device_tensors_) {
(void)ret.emplace_back(device_tensor.second);
}

// Foreach the op partial to fetch the device tensors.
for (auto &partial : op_real_parameter.partials_) {
auto ret_inner = GetAllDeviceTensors(partial.second);
(void)std::copy(ret_inner.begin(), ret_inner.end(), std::back_inserter(ret));
}
return ret;
}

// Increase the dynamic ref count of the single device tensor carried by an output data arrow.
void ControlActor::IncreaseDynamicRefCount(const OpData<DeviceTensor> *op_data) {
MS_EXCEPTION_IF_NULL(op_data);
MS_EXCEPTION_IF_NULL(op_data->data_);
op_data->data_->IncreaseDynamicRefCount();
}

void ControlActor::IncreaseDynamicRefCount(const OpPartialPtr &op_partial) {
MS_EXCEPTION_IF_NULL(op_partial);
auto partial_device_tensors = GetAllDeviceTensors(op_partial);
for (auto &partial_device_tensor : partial_device_tensors) {
MS_EXCEPTION_IF_NULL(partial_device_tensor);
partial_device_tensor->IncreaseDynamicRefCount();
}
}

void ControlActor::IncreaseDynamicRefCount(const OpRealParameterWithBranchID &op_real_parameter) {
auto partial_device_tensors = GetAllDeviceTensors(op_real_parameter);
for (auto &partial_device_tensor : partial_device_tensors) {
MS_EXCEPTION_IF_NULL(partial_device_tensor);
partial_device_tensor->IncreaseDynamicRefCount();
}
}

size_t ControlActor::FetchNodePosition(const KernelWithIndex &node) const {
const auto &iter = find(formal_parameters_.begin(), formal_parameters_.end(), node);
if (iter == formal_parameters_.end()) {
@@ -52,6 +105,13 @@ size_t ControlActor::FetchNodePosition(const KernelWithIndex &node) const {

// The main step entry of a control actor: fetch inputs, adjust dynamic ref counts, free consumed
// memory, then erase inputs and emit outputs. The statement order below is load-bearing.
void ControlActor::Run(OpContext<DeviceTensor> *const context) {
FetchInput(context);

// Note that IncreaseDynamicRefCounts must run before SendMemoryFreeReq, because SendMemoryFreeReq
// decreases the dynamic ref count. This avoids the illegal timing problem where the dynamic
// reference count is decremented (possibly releasing the memory) and only then incremented.
IncreaseDynamicRefCounts(context);
SendMemoryFreeReq(context);

EraseInput(context);
SendOutput(context);
}
@@ -197,9 +257,61 @@ void ControlActor::FetchInput(OpContext<DeviceTensor> *const context) {
}
}

// Increase the dynamic ref count once for every outgoing data arrow and once for every outgoing
// partial arrow, mirroring the consumption that SendOutput will trigger downstream.
void ControlActor::IncreaseDynamicRefCounts(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  // Increase dynamic ref count by the output data.
  for (auto &data : output_data_) {
    MS_EXCEPTION_IF_NULL(data);
    IncreaseDynamicRefCount(data.get());
  }

  // Increase dynamic ref count by the output partial.
  for (const auto &arrow : output_partial_arrows_) {
    MS_EXCEPTION_IF_NULL(arrow);
    const auto from_index = IntToSize(arrow->from_output_index_);
    // Guard against an arrow that points outside the received partials.
    if (from_index >= input_partials_.size()) {
      std::string error_info = "Invalid partial input:" + std::to_string(arrow->from_output_index_) +
                               " current:" + std::to_string(input_partials_.size()) + " for actor:" + GetAID().Name();
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
    IncreaseDynamicRefCount(input_partials_[from_index]);
  }
}

// Ask the memory manager actor to free (i.e. decrement the dynamic ref count of) every device tensor
// received in this step, both from plain input data and from input partials.
void ControlActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  const auto &sequential_num = context->sequential_num_;

  // Collect the input device tensors.
  std::vector<DeviceTensor *> memory_free_list;
  const auto &data_iter = input_op_datas_.find(sequential_num);
  if (data_iter != input_op_datas_.end()) {
    for (auto &input_data : data_iter->second) {
      MS_EXCEPTION_IF_NULL(input_data);
      MS_EXCEPTION_IF_NULL(input_data->data_);
      (void)memory_free_list.emplace_back(input_data->data_);
    }
  }

  // Collect every device tensor reachable from the input partials.
  const auto &partial_iter = input_op_partials_.find(sequential_num);
  if (partial_iter != input_op_partials_.end()) {
    for (auto &partial_pair : partial_iter->second) {
      for (auto &device_tensor : GetAllDeviceTensors(partial_pair.second)) {
        (void)memory_free_list.emplace_back(device_tensor);
      }
    }
  }

  // The freed list must outlive the asynchronous message, hence it is stored in memory_free_lists_.
  if (!memory_free_list.empty()) {
    (void)memory_free_lists_.emplace_back(memory_free_list);
    ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &(memory_free_lists_.back()),
                          device_contexts_[0], context);
  }
}

void ControlActor::EraseInput(const OpContext<DeviceTensor> *context) {
AbstractActor::EraseInput(context);
MS_EXCEPTION_IF_NULL(context);
const auto &sequential_num = context->sequential_num_;
AbstractActor::EraseInput(context);

if (input_partials_num_ != 0) {
auto ret = input_op_partials_.erase(sequential_num);
if (ret == 0) {


+ 21
- 3
mindspore/ccsrc/runtime/framework/actor/control_flow/control_actor.h View File

@@ -25,9 +25,11 @@
#include <stack>
#include <queue>
#include <utility>
#include <algorithm>
#include "runtime/framework/actor/actor_common.h"
#include "runtime/framework/actor/abstract_actor.h"
#include "runtime/framework/actor/memory_aware_actor.h"
#include "runtime/framework/actor/memory_manager_actor.h"

namespace mindspore {
namespace runtime {
@@ -49,10 +51,10 @@ struct OpRealParameterWithBranchID {
int branch_id_;
};
// The control actor is the base class of control flow actor.
class ControlActor : public AbstractActor {
class ControlActor : public MemoryAwareActor {
public:
ControlActor(const std::string &name, KernelTransformType type, const std::vector<KernelWithIndex> &parameters,
const AnfNodePtr &node);
ControlActor(const std::string &name, KernelTransformType type, const AID &memory_manager_aid,
const std::vector<KernelWithIndex> &parameters, const AnfNodePtr &node);
~ControlActor() override = default;

void Init() override;
@@ -72,6 +74,14 @@ class ControlActor : public AbstractActor {

protected:
friend class ControlNodeScheduler;

// The basic interfaces for op partial and op real parameter.
std::vector<DeviceTensor *> GetAllDeviceTensors(const OpPartialPtr &op_partial);
std::vector<DeviceTensor *> GetAllDeviceTensors(const OpRealParameterWithBranchID &op_real_parameter);
void IncreaseDynamicRefCount(const OpData<DeviceTensor> *op_data);
void IncreaseDynamicRefCount(const OpPartialPtr &op_partial);
void IncreaseDynamicRefCount(const OpRealParameterWithBranchID &op_real_parameter);

// Get the position of node in the input.
size_t FetchNodePosition(const KernelWithIndex &node) const;

@@ -82,6 +92,11 @@ class ControlActor : public AbstractActor {
void SendOutput(OpContext<DeviceTensor> *const context) override;
void EraseInput(const OpContext<DeviceTensor> *context) override;

// Increase the dynamic ref count by the outputs. It corresponds to the SendOutput.
virtual void IncreaseDynamicRefCounts(OpContext<DeviceTensor> *const context);
// Free memory by the dynamic ref count decremented. It corresponds to the EraseInput.
void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) override;

// Input data.
// 1.Input partial.
// Record the partial received by each step, the key of the pair indicates the location of the partial.
@@ -99,6 +114,9 @@ class ControlActor : public AbstractActor {
std::vector<OpPartialPtr> input_partials_;
std::vector<DeviceTensor *> input_device_tensors_;

// The lists of device tensors which need free by dynamic ref count, will be cleared at the end of step.
std::vector<std::vector<DeviceTensor *>> memory_free_lists_;

// Input num.
size_t input_partials_num_{0};
size_t input_branch_ids_num_{0};


+ 38
- 0
mindspore/ccsrc/runtime/framework/actor/control_flow/entrance_actor.cc View File

@@ -70,6 +70,13 @@ void EntranceActor::Run(OpContext<DeviceTensor> *const context) {
is_loop_body_execution_ = true;

FetchInput(context);

// Note that IncreaseDynamicRefCount must be in front of SendMemoryFreeReq. SendMemoryFreeReq will decreasing the
// dynamic ref count. Avoid the illegal timing problem that the dynamic reference count is decremented and then
// incremented.
IncreaseDynamicRefCounts(context);
SendMemoryFreeReq(context);

EraseInput(context);
SendOutput(context);
}
@@ -218,5 +225,36 @@ void EntranceActor::EraseInput(const OpContext<DeviceTensor> *const context) {
}
}
}

// Ask the memory manager actor to free the device tensors consumed in this step: the plain input
// data plus the tensors of the front real-parameter-with-branch-id (the one being executed now).
void EntranceActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
const auto &sequential_num = context->sequential_num_;

// Collect the input device tensors.
std::vector<DeviceTensor *> memory_free_list;
if (input_op_datas_.count(sequential_num) > 0) {
for (auto &input_data : input_op_datas_[sequential_num]) {
MS_EXCEPTION_IF_NULL(input_data);
MS_EXCEPTION_IF_NULL(input_data->data_);
memory_free_list.emplace_back(input_data->data_);
}
}

// Only the front element of the queue is consumed in this step, so only its tensors are freed here.
const auto &iter = real_parameters_with_branch_id_.find(sequential_num);
if (iter != real_parameters_with_branch_id_.end()) {
if (iter->second.empty()) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The real parameter with branch id is empty.");
}
auto &real_parameters_with_branch_id = iter->second.front();
auto partial_device_tensors = GetAllDeviceTensors(real_parameters_with_branch_id);
(void)std::copy(partial_device_tensors.begin(), partial_device_tensors.end(), std::back_inserter(memory_free_list));
}

// The list must stay alive until the asynchronous FreeMemory message is handled.
if (memory_free_list.size() > 0) {
memory_free_lists_.emplace_back(memory_free_list);
ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &(memory_free_lists_.back()),
device_contexts_[0], context);
}
}
} // namespace runtime
} // namespace mindspore

+ 5
- 2
mindspore/ccsrc/runtime/framework/actor/control_flow/entrance_actor.h View File

@@ -23,6 +23,7 @@
#include <stack>
#include <queue>
#include <set>
#include <algorithm>
#include "utils/hash_map.h"
#include "runtime/framework/actor/actor_common.h"
#include "runtime/framework/actor/control_flow/control_actor.h"
@@ -33,9 +34,10 @@ namespace runtime {
// the data to the corresponding actor. It is the entry point for subgraph execution.
class EntranceActor : public ControlActor {
public:
EntranceActor(const std::string &name, const std::vector<KernelWithIndex> &parameters,
EntranceActor(const std::string &name, const AID &memory_manager_aid, const std::vector<KernelWithIndex> &parameters,
const std::set<KernelWithIndex> &call_nodes, const AnfNodePtr &node)
: ControlActor(name, KernelTransformType::kEntranceActor, parameters, node), call_nodes_(call_nodes) {
: ControlActor(name, KernelTransformType::kEntranceActor, memory_manager_aid, parameters, node),
call_nodes_(call_nodes) {
device_contexts_.resize(parameters.size());
input_device_tensors_.resize(parameters.size());
}
@@ -56,6 +58,7 @@ class EntranceActor : public ControlActor {
void FetchInput(OpContext<DeviceTensor> *const context) override;
bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const override;
void EraseInput(const OpContext<DeviceTensor> *const context) override;
void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) override;

private:
friend class ControlNodeScheduler;


+ 57
- 11
mindspore/ccsrc/runtime/framework/actor/control_flow/exit_actor.cc View File

@@ -91,6 +91,33 @@ void ExitActor::SendOutput(OpContext<DeviceTensor> *const context) {
}
}

// Besides the base-class outputs, the exit actor also sends branch-specific outputs, so the dynamic
// ref counts of the currently selected output branch must be increased as well.
void ExitActor::IncreaseDynamicRefCounts(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
ControlActor::IncreaseDynamicRefCounts(context);

// Increase dynamic ref count by the output data in output branch.
if (output_branch_data_.count(output_branch_id_) > 0) {
for (auto &output_data : output_branch_data_[output_branch_id_]) {
MS_EXCEPTION_IF_NULL(output_data.second);
IncreaseDynamicRefCount(output_data.second.get());
}
}

// Increase dynamic ref count by the output partial in output branch.
if (output_branch_partial_arrows_.count(output_branch_id_) > 0) {
for (const auto &partial_arrow : output_branch_partial_arrows_[output_branch_id_]) {
MS_EXCEPTION_IF_NULL(partial_arrow);
// Guard against an arrow index outside the received partials.
if (IntToSize(partial_arrow->from_output_index_) >= input_partials_.size()) {
std::string error_info = "Invalid partial input:" + std::to_string(partial_arrow->from_output_index_) +
" current:" + std::to_string(input_partials_.size()) + " for actor:" + GetAID().Name();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
auto output_partial = input_partials_[partial_arrow->from_output_index_];
IncreaseDynamicRefCount(output_partial);
}
}
}

void ExitActor::CopyDeviceAddress(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
// If node is not empty, it is the exit of funcgraph, no need to create device address.
@@ -110,26 +137,45 @@ void ExitActor::CopyDeviceAddress(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(input_device_tensor);
const KernelWithIndex &node_with_index = input_device_tensor->GetNodeIndex();
MS_EXCEPTION_IF_NULL(node_with_index.first);
// If the address ptr can't be changed, does not need to copy a new device tensor.
if ((!is_need_copy_device_tensors_[i]) || input_device_tensor->is_ptr_persisted()) {
if (!is_need_copy_device_tensors_[i]) {
new_device_tensors.emplace_back(input_device_tensor);
continue;
}

MS_EXCEPTION_IF_NULL(device_contexts_[i]);
auto new_device_tensor =
device_contexts_[i]->CreateDeviceAddress(nullptr, input_device_tensors_[i]->GetSize(),
input_device_tensors_[i]->format(), input_device_tensors_[i]->type_id());
// Create the new device tensor to take over the input_device_tensors which are the outputs of kernel graphs.
auto new_device_tensor = device_contexts_[i]->CreateDeviceAddress(
nullptr, input_device_tensor->GetSize(), input_device_tensor->format(), input_device_tensor->type_id());
MS_EXCEPTION_IF_NULL(new_device_tensor);
new_device_tensor->set_ptr(input_device_tensor->GetMutablePtr());
new_device_tensor->set_from_mem_pool(input_device_tensor->from_mem_pool());
created_device_tensors_.emplace_back(new_device_tensor);
new_device_tensors.emplace_back(new_device_tensor.get());

new_device_tensor->SetNodeIndex(node_with_index.first, node_with_index.second);
new_device_tensor->set_from_persistent_mem(input_device_tensor->from_persistent_mem());
// The device address which is created by actor uses the dynamic ref count.
new_device_tensor->set_dynamic_ref_conut(0);
new_device_tensor->set_original_ref_count(SIZE_MAX);
new_device_tensor->ResetRefCount();
new_device_tensors.emplace_back(new_device_tensor.get());
created_device_tensors_.emplace_back(new_device_tensor);

input_device_tensor->set_ptr(nullptr);
input_device_tensor->set_from_mem_pool(false);
// If the address ptr can't be changed, then alloc the new device memory and copy the data.
if (input_device_tensor->is_ptr_persisted()) {
if (!device_contexts_[i]->AllocateMemory(new_device_tensor.get(), new_device_tensor->GetSize())) {
SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(GraphExecutionStrategy::kPipeline, *context, *device_contexts_[i],
GetAID().Name(), new_device_tensor->GetSize());
}
if (!new_device_tensor->SyncDeviceToDevice(
trans::GetRuntimePaddingShape(node_with_index.first, node_with_index.second),
input_device_tensor->GetSize(), input_device_tensor->type_id(), input_device_tensor->GetPtr(),
input_device_tensor->format())) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR(*context, "Sync device to device failed.");
}
} else {
// Move the device ptr from input_device_tensor to new_device_tensor.
new_device_tensor->set_ptr(input_device_tensor->GetMutablePtr());
new_device_tensor->set_from_mem_pool(input_device_tensor->from_mem_pool());
input_device_tensor->set_ptr(nullptr);
input_device_tensor->set_from_mem_pool(false);
}
}
input_device_tensors_.swap(new_device_tensors);



+ 4
- 2
mindspore/ccsrc/runtime/framework/actor/control_flow/exit_actor.h View File

@@ -32,8 +32,9 @@ namespace runtime {
// device tensors in the data to the corresponding actor. It is the exit of the end of kernel graph execution.
class ExitActor : public ControlActor {
public:
ExitActor(const std::string &name, const std::vector<KernelWithIndex> &parameters, const AnfNodePtr &node)
: ControlActor(name, KernelTransformType::kExitActor, parameters, node) {
ExitActor(const std::string &name, const AID &memory_manager_aid, const std::vector<KernelWithIndex> &parameters,
const AnfNodePtr &node)
: ControlActor(name, KernelTransformType::kExitActor, memory_manager_aid, parameters, node) {
device_contexts_.resize(parameters.size());
input_device_tensors_.resize(parameters.size());
}
@@ -54,6 +55,7 @@ class ExitActor : public ControlActor {
protected:
void FetchInput(OpContext<DeviceTensor> *const context) override;
void SendOutput(OpContext<DeviceTensor> *const context) override;
void IncreaseDynamicRefCounts(OpContext<DeviceTensor> *const context) override;

private:
friend class ControlNodeScheduler;


+ 48
- 24
mindspore/ccsrc/runtime/framework/actor/control_flow/gather_actor.cc View File

@@ -19,9 +19,9 @@

namespace mindspore {
namespace runtime {
GatherActor::GatherActor(const std::string &name, const std::vector<KernelWithIndex> &parameters,
const AnfNodePtr &node)
: ControlActor(name, KernelTransformType::kGatherActor, parameters, node) {
GatherActor::GatherActor(const std::string &name, const AID &memory_manager_aid,
const std::vector<KernelWithIndex> &parameters, const AnfNodePtr &node)
: ControlActor(name, KernelTransformType::kGatherActor, memory_manager_aid, parameters, node) {
device_contexts_.resize(parameters.size());
}

@@ -50,34 +50,41 @@ void GatherActor::FetchInput(OpContext<DeviceTensor> *const context) {
}
}

// Build the real-parameter-with-branch-id struct sent to the entrance actor from the first input
// partial. On any invalid index this writes the failure into the context and returns early.
void GatherActor::FetchOutput(OpRealParameterWithBranchID *const output, OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(output);
MS_EXCEPTION_IF_NULL(context);
output->branch_id_ = output_branch_id_;
output->device_tensors_ = input_partials_[0]->device_tensors_;
output->partials_ = input_partials_[0]->partials_;

// The first input of gather actor is the target funcgraph, which will not be sent to the entrance actor as
// an real parameter, so the subsequent index needs to be reduced by one.
// Index 0 must therefore never appear among the forwarded tensors/partials.
for (auto &device_tensor : output->device_tensors_) {
if (device_tensor.first == 0) {
std::string error_info =
"Invalid device tensor index:" + std::to_string(device_tensor.first) + " for actor:" + GetAID().Name();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
device_tensor.first--;
}
for (auto &partial : output->partials_) {
if (partial.first == 0) {
std::string error_info =
"Invalid partial index:" + std::to_string(partial.first) + " for actor:" + GetAID().Name();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
partial.first--;
}
}

void GatherActor::SendOutput(OpContext<DeviceTensor> *const context) {
// Send data with branch id.
const auto &iter = output_data_with_branch_id_arrows_.find(input_partials_[0]->func_graph_);
if (iter != output_data_with_branch_id_arrows_.end()) {
// Build the output data struct.
OpRealParameterWithBranchID output;
output.branch_id_ = output_branch_id_;
output.device_tensors_ = input_partials_[0]->device_tensors_;
output.partials_ = input_partials_[0]->partials_;
FetchOutput(&output, context);

// The first input of gather actor is the target funcgraph, which will not be sent to the entrance actor as
// an real parameter, so the subsequent index needs to be reduced by one.
for (auto &device_tensor : output.device_tensors_) {
if (device_tensor.first == 0) {
std::string error_info =
"Invalid device tensor index:" + std::to_string(device_tensor.first) + " for actor:" + GetAID().Name();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
device_tensor.first--;
}
for (auto &partial : output.partials_) {
if (partial.first == 0) {
std::string error_info =
"Invalid partial index:" + std::to_string(partial.first) + " for actor:" + GetAID().Name();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
partial.first--;
}
for (const auto &data_with_branch_id_arrow : iter->second) {
ActorDispatcher::Send(data_with_branch_id_arrow, &EntranceActor::RunOpRealParameterWithBranchID, output, context);
}
@@ -86,5 +93,22 @@ void GatherActor::SendOutput(OpContext<DeviceTensor> *const context) {
// Control arrow needs to be sent after the real parameter data and branch id.
ControlActor::SendOutput(context);
}

// Besides the base-class outputs, increase the dynamic ref counts once per destination entrance
// actor that will receive the real parameter with branch id.
void GatherActor::IncreaseDynamicRefCounts(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  ControlActor::IncreaseDynamicRefCounts(context);

  // Increase dynamic ref count by the output data with branch id.
  const auto &arrow_iter = output_data_with_branch_id_arrows_.find(input_partials_[0]->func_graph_);
  if (arrow_iter == output_data_with_branch_id_arrows_.end()) {
    return;
  }

  // Build the output data struct.
  OpRealParameterWithBranchID output;
  FetchOutput(&output, context);

  // One increment for every destination the output will be sent to.
  for (size_t destination_index = 0; destination_index < arrow_iter->second.size(); ++destination_index) {
    IncreaseDynamicRefCount(output);
  }
}
} // namespace runtime
} // namespace mindspore

+ 5
- 1
mindspore/ccsrc/runtime/framework/actor/control_flow/gather_actor.h View File

@@ -34,7 +34,8 @@ namespace runtime {
// together and sent to the subgraph.
class GatherActor : public ControlActor {
public:
GatherActor(const std::string &name, const std::vector<KernelWithIndex> &parameters, const AnfNodePtr &node);
GatherActor(const std::string &name, const AID &memory_manager_aid, const std::vector<KernelWithIndex> &parameters,
const AnfNodePtr &node);
~GatherActor() override = default;
const mindspore::HashMap<FuncGraph *, std::vector<AID>> &output_data_with_branch_id_arrows() const {
return output_data_with_branch_id_arrows_;
@@ -43,10 +44,13 @@ class GatherActor : public ControlActor {
protected:
void FetchInput(OpContext<DeviceTensor> *const context) override;
void SendOutput(OpContext<DeviceTensor> *const context) override;
void IncreaseDynamicRefCounts(OpContext<DeviceTensor> *const context) override;

private:
friend class ControlNodeScheduler;

void FetchOutput(OpRealParameterWithBranchID *const output, OpContext<DeviceTensor> *const context);

// There will be multiple output branches for gather actor according the funcgraph in partial.
mindspore::HashMap<FuncGraph *, std::vector<AID>> output_data_with_branch_id_arrows_;
};


+ 51
- 44
mindspore/ccsrc/runtime/framework/actor/control_flow/stack_actor.cc View File

@@ -20,8 +20,9 @@

namespace mindspore {
namespace runtime {
StackActor::StackActor(const std::string &name, const std::vector<KernelWithIndex> &parameters)
: ControlActor(name, KernelTransformType::kStackActor, parameters, nullptr) {
StackActor::StackActor(const std::string &name, const AID &memory_manager_aid,
const std::vector<KernelWithIndex> &parameters)
: ControlActor(name, KernelTransformType::kStackActor, memory_manager_aid, parameters, nullptr) {
input_device_tensors_.resize(parameters.size());
}

@@ -77,7 +78,7 @@ void StackActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<Dev
// The parameters from the inside of the subgraph need to be put into the stack.
if (IntToSize(input_data->index_) < input_stack_data_num_ + device_tensor_store_keys_.size() +
input_stack_partials_num_ + local_device_tensors_.size()) {
FillStack(input_data, context);
input_stack_data_[context->sequential_num_][input_data->index_].push(input_data->data_);
} else {
// The outputs of call nodes are placed directly in the input data.
input_op_datas_[context->sequential_num_].emplace_back(input_data);
@@ -129,47 +130,6 @@ void StackActor::RunOpPartial(OpPartialPtr partial, size_t position, OpContext<D
}
}

// NOTE(review): this helper is removed by this commit (the diff hunk deletes it); RunOpData now
// pushes input_data->data_ onto the stack directly. Documentation of the old behavior follows.
// Push an input device tensor onto the per-step stack, copying it into a freshly allocated device
// tensor first when the source address is persisted (its ptr cannot be handed over).
void StackActor::FillStack(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
MS_EXCEPTION_IF_NULL(input_data);
auto &input_device_tensor = input_data->data_;
MS_EXCEPTION_IF_NULL(input_device_tensor);
auto &sequential_num = context->sequential_num_;
size_t index = IntToSize(input_data->index_);
if (index >= device_contexts_.size()) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR(*context, "The index is out of range.");
}

// 1.If device context is empty, it means that the input is from a parameter and does not need copy new device tensor.
// 2.If the address ptr can be changed, it has been copied by exit actor and does not need copy a new device tensor.
if ((device_contexts_[index] == nullptr) || (!input_device_tensor->is_ptr_persisted())) {
input_stack_data_[sequential_num][input_data->index_].push(input_device_tensor);
} else {
const KernelWithIndex &node_with_index = input_device_tensor->GetNodeIndex();
MS_EXCEPTION_IF_NULL(node_with_index.first);
// Create the new device tensor and copy the data from the input data.
auto new_device_tensor = device_contexts_[index]->CreateDeviceAddress(
nullptr, input_device_tensor->GetSize(), input_device_tensor->format(), input_device_tensor->type_id());
MS_EXCEPTION_IF_NULL(new_device_tensor);

// Allocate backing memory for the copy; abort the step on allocation failure.
if (!device_contexts_[index]->AllocateMemory(new_device_tensor.get(), new_device_tensor->GetSize())) {
SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(GraphExecutionStrategy::kPipeline, *context, *device_contexts_[index],
GetAID().Name(), new_device_tensor->GetSize());
}
// Device-to-device copy of the persisted input into the new address.
if (!new_device_tensor->SyncDeviceToDevice(
trans::GetRuntimePaddingShape(node_with_index.first, node_with_index.second), input_device_tensor->GetSize(),
input_device_tensor->type_id(), input_device_tensor->GetPtr(), input_device_tensor->format())) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR(*context, "Sync device to device failed.");
}
new_device_tensor->SetNodeIndex(node_with_index.first, node_with_index.second);
new_device_tensor->set_original_ref_count(SIZE_MAX);
new_device_tensor->ResetRefCount();

// Keep the copy alive for the lifetime of the actor and push it onto the stack.
created_device_tensors_.emplace_back(new_device_tensor);
input_stack_data_[sequential_num][input_data->index_].push(new_device_tensor.get());
}
}

bool StackActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) const {
MS_EXCEPTION_IF_NULL(context);
if (!ControlActor::CheckRunningCondition(context)) {
@@ -331,5 +291,52 @@ void StackActor::EraseInput(const OpContext<DeviceTensor> *const context) {
}
}
}

// Ask the memory manager actor to free the device tensors consumed in this step: plain input data,
// input partials, and the top elements of the stacked data/partial inputs.
void StackActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
const auto &sequential_num = context->sequential_num_;

// Collect the input device tensors.
std::vector<DeviceTensor *> memory_free_list;
if (input_op_datas_.count(sequential_num) > 0) {
for (auto &input_data : input_op_datas_[sequential_num]) {
MS_EXCEPTION_IF_NULL(input_data);
MS_EXCEPTION_IF_NULL(input_data->data_);
memory_free_list.emplace_back(input_data->data_);
}
}

// Collect every device tensor reachable from the input partials.
if (input_op_partials_.count(sequential_num) > 0) {
for (auto &input_partial_pair : input_op_partials_[sequential_num]) {
auto partial_device_tensors = GetAllDeviceTensors(input_partial_pair.second);
(void)std::copy(partial_device_tensors.begin(), partial_device_tensors.end(),
std::back_inserter(memory_free_list));
}
}

// Only the top of each stack is consumed in this step, so only the top entries are freed.
if ((input_stack_data_num_ != 0) && (input_stack_data_.count(sequential_num) > 0)) {
for (auto &stack_data_pair : input_stack_data_[sequential_num]) {
if (!stack_data_pair.second.empty()) {
memory_free_list.emplace_back(stack_data_pair.second.top());
}
}
}

// Likewise for the stacked partials: free the tensors of the top partial of each stack.
if ((input_stack_partials_num_ != 0) && (input_stack_partials_.count(sequential_num) > 0)) {
for (auto &stack_partial_pair : input_stack_partials_[sequential_num]) {
if (!stack_partial_pair.second.empty()) {
auto partial_device_tensors = GetAllDeviceTensors(stack_partial_pair.second.top());
(void)std::copy(partial_device_tensors.begin(), partial_device_tensors.end(),
std::back_inserter(memory_free_list));
}
}
}

// The list must stay alive until the asynchronous FreeMemory message is handled.
if (memory_free_list.size() > 0) {
memory_free_lists_.emplace_back(memory_free_list);
ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &(memory_free_lists_.back()),
device_contexts_[0], context);
}
}
} // namespace runtime
} // namespace mindspore

+ 3
- 6
mindspore/ccsrc/runtime/framework/actor/control_flow/stack_actor.h View File

@@ -22,6 +22,7 @@
#include <memory>
#include <stack>
#include <set>
#include <algorithm>
#include "utils/hash_map.h"
#include "runtime/framework/actor/actor_common.h"
#include "runtime/framework/actor/control_flow/control_actor.h"
@@ -36,7 +37,7 @@ namespace runtime {
// 4. Send output.
class StackActor : public ControlActor {
public:
StackActor(const std::string &name, const std::vector<KernelWithIndex> &parameters);
StackActor(const std::string &name, const AID &memory_manager_aid, const std::vector<KernelWithIndex> &parameters);
~StackActor() override = default;

void Init() override;
@@ -50,15 +51,11 @@ class StackActor : public ControlActor {
void FetchInput(OpContext<DeviceTensor> *const context) override;
bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const override;
void EraseInput(const OpContext<DeviceTensor> *const context) override;
void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) override;

private:
friend class ControlNodeScheduler;

void FillStack(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context);

// The device tensors created and stored by the stack.
std::vector<DeviceTensorPtr> created_device_tensors_;

// The input data and partials records that the stack actor is copied from the input nodes and needs to be
// stored in the device tensor in the stack.
mindspore::HashMap<int, mindspore::HashMap<size_t, std::stack<DeviceTensor *>>> input_stack_data_;


+ 3
- 3
mindspore/ccsrc/runtime/framework/actor/control_flow/switch_actor.cc View File

@@ -24,9 +24,9 @@ namespace runtime {
constexpr size_t kMaxSwitchCondSize = 8;
constexpr size_t kSwitchDefaultOutputNum = 1;

SwitchActor::SwitchActor(const std::string &name, const std::vector<KernelWithIndex> &parameters,
const AnfNodePtr &node)
: ControlActor(name, KernelTransformType::kSwitchActor, parameters, node) {
SwitchActor::SwitchActor(const std::string &name, const AID &memory_manager_aid,
const std::vector<KernelWithIndex> &parameters, const AnfNodePtr &node)
: ControlActor(name, KernelTransformType::kSwitchActor, memory_manager_aid, parameters, node) {
device_contexts_.resize(parameters.size());
output_data_by_output_index_.resize(kSwitchDefaultOutputNum);
}


+ 2
- 1
mindspore/ccsrc/runtime/framework/actor/control_flow/switch_actor.h View File

@@ -33,7 +33,8 @@ using mindspore::session::KernelWithIndex;
// Switch and SwitchLayer node will be converted to switch actor.
class SwitchActor : public ControlActor {
public:
SwitchActor(const std::string &name, const std::vector<KernelWithIndex> &parameters, const AnfNodePtr &node);
SwitchActor(const std::string &name, const AID &memory_manager_aid, const std::vector<KernelWithIndex> &parameters,
const AnfNodePtr &node);
~SwitchActor() override = default;

void Init() override;


+ 40
- 28
mindspore/ccsrc/runtime/framework/actor/memory_manager_actor.cc View File

@@ -17,11 +17,49 @@
#include "runtime/framework/actor/memory_manager_actor.h"
#include "runtime/framework/actor/data_source_actor.h"
#include "runtime/framework/actor/kernel_actor.h"
#include "runtime/hardware/device_context_manager.h"
#include "mindrt/include/async/async.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace runtime {
namespace {
// Frees the memory of a device tensor through a device context that actually matches its device.
void FreeMemoryInner(DeviceTensor *const device_tensor, const DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(device_tensor);
  // In the control flow scene the passed-in device context may be absent or belong to another
  // device, so check it first and fall back to fetching the real context by name and id.
  const bool context_matched =
    (device_context != nullptr) && (device_context->GetDeviceAddressType() == device_tensor->DeviceType());
  if (context_matched) {
    device_context->FreeMemory(device_tensor);
    return;
  }
  const auto &real_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
    {device_tensor->device_name(), device_tensor->device_id()});
  MS_EXCEPTION_IF_NULL(real_device_context);
  real_device_context->FreeMemory(device_tensor);
}

// Decreases the reference count of a device tensor and frees its memory when the count hits zero.
// Only one of the static and dynamic reference counts takes effect for a given tensor:
// - static count (original_ref_count() != SIZE_MAX): decreased; on zero the memory is freed and
//   the count is reset to the original value for the next step.
// - dynamic count (dynamic_ref_conut() != INT32_MAX; the misspelling is the upstream API name):
//   decreased; on zero the memory is freed without any reset.
// A tensor with both sentinels set (SIZE_MAX and INT32_MAX) is not managed here at all.
void FreeMemoryByRefCount(DeviceTensor *const device_tensor, const DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(device_tensor);
  if (device_tensor->original_ref_count() != SIZE_MAX) {
    // The static reference count is decremented to zero to free memory, and reset to the original count.
    device_tensor->DecreaseRefCount();
    if (device_tensor->ref_count() == 0) {
      if (device_tensor->GetPtr() != nullptr) {
        FreeMemoryInner(device_tensor, device_context);
      }
      device_tensor->ResetRefCount();
    }
  } else if (device_tensor->dynamic_ref_conut() != INT32_MAX) {
    // The dynamic reference count is decremented to zero to free memory.
    device_tensor->DecreaseDynamicRefCount();
    if ((device_tensor->dynamic_ref_conut() == 0) && (device_tensor->GetPtr() != nullptr)) {
      // Fix: add the missing separator so the message does not run into the pointer value.
      MS_LOG(DEBUG) << "Free memory by the dynamic reference count, device address:" << device_tensor->GetPtr();
      FreeMemoryInner(device_tensor, device_context);
    }
  }
}
} // namespace

void MemoryManagerActor::AllocateMemory(const std::vector<DeviceTensor *> *alloc_list,
const DeviceContext *device_context, OpContext<DeviceTensor> *const op_context,
const AID &from_aid) {
@@ -113,21 +151,8 @@ void MemoryManagerActor::AllocateBatchMemory(const std::vector<DeviceTensor *> *
void MemoryManagerActor::FreeMemory(const std::vector<DeviceTensor *> *free_list, const DeviceContext *device_context,
OpContext<DeviceTensor> *) {
MS_EXCEPTION_IF_NULL(free_list);
MS_EXCEPTION_IF_NULL(device_context);
for (auto &device_tensor : *free_list) {
MS_EXCEPTION_IF_NULL(device_tensor);
if (device_tensor->original_ref_count() == SIZE_MAX) {
continue;
}
// The reference count is decremented to zero to free memory, and reset to the original count.
device_tensor->DecreaseRefCount();
if (device_tensor->ref_count() == 0) {
// Free memory through the device context.
if (device_tensor->GetPtr() != nullptr) {
device_context->FreeMemory(device_tensor);
}
device_tensor->ResetRefCount();
}
FreeMemoryByRefCount(device_tensor, device_context);
}
}

@@ -145,20 +170,7 @@ void MemoryManagerActor::FreeBatchMemory(const std::vector<DeviceTensor *> *free
for (size_t i = 0; i < (*free_list).size(); ++i) {
auto &device_tensor = (*free_list)[i];
auto &device_context = (*device_contexts)[i];
MS_EXCEPTION_IF_NULL(device_tensor);
MS_EXCEPTION_IF_NULL(device_context);
if (device_tensor->original_ref_count() == SIZE_MAX) {
continue;
}
// The reference count is decremented to zero to free memory, and reset to the original count.
device_tensor->DecreaseRefCount();
if (device_tensor->ref_count() == 0) {
// Free memory through the device context.
if (device_tensor->GetPtr() != nullptr) {
device_context->FreeMemory(device_tensor);
}
device_tensor->ResetRefCount();
}
FreeMemoryByRefCount(device_tensor, device_context);
}
}



+ 2
- 1
mindspore/ccsrc/runtime/framework/actor/output_actor.cc View File

@@ -123,7 +123,7 @@ TensorPtr OutputActor::CreateOutputTensor(const AnfNodePtr &output_node, size_t
if (device_context->GetDeviceAddressType() != device_tensor->DeviceType()) {
auto old_device_context = device_context;
device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
{device_tensor->DeviceName(), device_tensor->DeviceID()});
{device_tensor->device_name(), device_tensor->device_id()});
MS_LOG(INFO) << "Update device context from:" << old_device_context->GetDeviceAddressType()
<< " to:" << device_context->GetDeviceAddressType();
}
@@ -169,6 +169,7 @@ void OutputActor::UpdateOutputDeviceAddress() {
tensor_device_address->ResetRefCount();
auto node_with_index = device_tensor->GetNodeIndex();
tensor_device_address->SetNodeIndex(node_with_index.first, node_with_index.second);
tensor_device_address->set_from_persistent_mem(device_tensor->from_persistent_mem());
// The outputs may have the same output node, so need skip when the node has been done.
if (device_tensor->GetPtr() == nullptr) {
continue;


+ 24
- 0
mindspore/ccsrc/runtime/framework/actor/super_kernel_actor.cc View File

@@ -16,6 +16,7 @@

#include "runtime/framework/actor/super_kernel_actor.h"
#include "runtime/framework/actor/output_actor.h"
#include "runtime/framework/actor/memory_manager_actor.h"
#include "mindrt/include/async/async.h"
#include "utils/log_adapter.h"

@@ -164,5 +165,28 @@ bool SuperKernelActor::CopyInputData(const OpContext<DeviceTensor> *context) {

return true;
}

void SuperKernelActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  const auto &sequential_num = context->sequential_num_;

  // Pick out the inputs that are managed by the dynamic reference count (INT32_MAX marks a tensor
  // that is not dynamically counted); only those need an explicit free request here.
  std::vector<DeviceTensor *> memory_free_list;
  auto data_iter = input_op_datas_.find(sequential_num);
  if (data_iter != input_op_datas_.end()) {
    for (auto &input_data : data_iter->second) {
      MS_EXCEPTION_IF_NULL(input_data);
      auto *input_device_tensor = input_data->data_;
      MS_EXCEPTION_IF_NULL(input_device_tensor);
      if (input_device_tensor->dynamic_ref_conut() == INT32_MAX) {
        continue;  // Not dynamically counted, nothing to free from here.
      }
      (void)memory_free_list.emplace_back(input_device_tensor);
    }
  }

  if (memory_free_list.empty()) {
    return;
  }
  // The list must stay alive in the member container until the asynchronous free finishes.
  (void)memory_free_lists_.emplace_back(std::move(memory_free_list));
  ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &(memory_free_lists_.back()),
                        device_contexts_[0], context);
}
} // namespace runtime
} // namespace mindspore

+ 6
- 0
mindspore/ccsrc/runtime/framework/actor/super_kernel_actor.h View File

@@ -21,6 +21,7 @@
#include <memory>
#include <map>
#include <utility>
#include <vector>
#include "runtime/framework/actor/debug_aware_actor.h"
#include "runtime/framework/actor/actor_common.h"
#include "runtime/hardware/device_context.h"
@@ -50,6 +51,8 @@ class SuperKernelActor : public DebugAwareActor {

protected:
void Run(OpContext<DeviceTensor> *const context) override;
// The input may come from the control actor, so need free the input memory by the dynamic ref count.
void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) override;

private:
friend class GraphScheduler;
@@ -59,6 +62,9 @@ class SuperKernelActor : public DebugAwareActor {
KernelGraphPtr graph_;

std::map<AnfNodePtr, DeviceAddress *> ref_node_addr_map_;

// The lists of device tensors which need free by dynamic ref count, will be cleared at the end of step.
std::vector<std::vector<DeviceTensor *>> memory_free_lists_;
};

using SuperKernelActorPtr = std::shared_ptr<SuperKernelActor>;


+ 33
- 9
mindspore/ccsrc/runtime/framework/control_node_scheduler.cc View File

@@ -83,12 +83,14 @@ bool IsControlFlowArrow(const ControlNodeParserPtr &parser, const KernelGraphPtr
}
} // namespace

ControlActorSetPtr ControlNodeScheduler::Build(const GraphCompilerInfo &graph_compiler_info) {
ControlActorSetPtr ControlNodeScheduler::Build(const GraphCompilerInfo &graph_compiler_info,
const AID &memory_manager_aid) {
const auto &control_nodes = graph_compiler_info.control_nodes_;
if (control_nodes.size() <= kSingleControlNode) {
return nullptr;
}

memory_manager_aid_ = memory_manager_aid;
ControlActorSetPtr control_actors = std::make_shared<ControlActorSet>();
control_actors->switch_actors_ = BuildSwitchActor(graph_compiler_info);
control_actors->gather_actors_ = BuildGatherActor(graph_compiler_info);
@@ -108,7 +110,8 @@ std::vector<SwitchActorPtr> ControlNodeScheduler::BuildSwitchActor(const GraphCo
AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) {
const auto &actor_name = GetActorName(control_node);
const auto &parameters = FetchInputNodeByCNode(control_node);
const auto &switch_actor = std::make_shared<SwitchActor>(actor_name, parameters, control_node);
const auto &switch_actor =
std::make_shared<SwitchActor>(actor_name, memory_manager_aid_, parameters, control_node);
switch_actors.emplace_back(switch_actor);
InsertActor(switch_actor.get());
}
@@ -127,7 +130,8 @@ std::vector<GatherActorPtr> ControlNodeScheduler::BuildGatherActor(const GraphCo
if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimPartial) || AnfAlgo::IsCallNode(control_node)) {
const auto &actor_name = GetActorName(control_node);
const auto &parameters = FetchInputNodeByCNode(control_node);
const auto &gather_actor = std::make_shared<GatherActor>(actor_name, parameters, control_node);
const auto &gather_actor =
std::make_shared<GatherActor>(actor_name, memory_manager_aid_, parameters, control_node);
gather_actors.emplace_back(gather_actor);
InsertActor(gather_actor.get());

@@ -189,7 +193,7 @@ std::vector<EntranceActorPtr> ControlNodeScheduler::BuildEntranceActor(const Gra
call_nodes = iter->second;
}
const auto &entrance_actor =
std::make_shared<EntranceActor>(actor_name, formal_parameters, call_nodes, control_node);
std::make_shared<EntranceActor>(actor_name, memory_manager_aid_, formal_parameters, call_nodes, control_node);
auto context_iter = parser->func_graph_to_device_contexts_.find(func_graph);
if (context_iter == parser->func_graph_to_device_contexts_.end() ||
context_iter->second.size() < formal_parameters.size()) {
@@ -202,7 +206,6 @@ std::vector<EntranceActorPtr> ControlNodeScheduler::BuildEntranceActor(const Gra
entrance_actor->device_contexts_.clear();
entrance_actor->device_contexts_.insert(entrance_actor->device_contexts_.begin(), context_iter->second.begin(),
context_iter->second.begin() + formal_parameters.size());

entrance_actors.emplace_back(entrance_actor);
InsertActor(entrance_actor.get());
}
@@ -225,7 +228,7 @@ std::vector<ExitActorPtr> ControlNodeScheduler::BuildExitActor(const GraphCompil
MS_EXCEPTION_IF_NULL(func_graph);
const auto &actor_name = func_graph->ToString() + kExitActorNameSuffix;
const auto &parameters = FetchInputNodeByCNode(control_node);
const auto &exit_actor = std::make_shared<ExitActor>(actor_name, parameters, control_node);
const auto &exit_actor = std::make_shared<ExitActor>(actor_name, memory_manager_aid_, parameters, control_node);
auto context_iter = parser->control_node_to_device_contexts_.find(control_node);
if (context_iter == parser->control_node_to_device_contexts_.end() ||
context_iter->second.size() != parameters.size()) {
@@ -267,7 +270,7 @@ std::vector<ExitActorPtr> ControlNodeScheduler::BuildExitActor(const GraphCompil
}

const auto &actor_name = kernel_graph_group_info->group_name_ + kExitActorNameSuffix;
const auto &exit_actor = std::make_shared<ExitActor>(actor_name, formal_parameters, nullptr);
const auto &exit_actor = std::make_shared<ExitActor>(actor_name, memory_manager_aid_, formal_parameters, nullptr);
exit_actor->is_need_copy_device_tensors_.swap(is_need_copy_device_tensors);
exit_actor->device_contexts_.swap(device_contexts);
exit_actors.emplace_back(exit_actor);
@@ -305,7 +308,7 @@ std::vector<StackActorPtr> ControlNodeScheduler::BuildStackActor(const GraphComp
}
}
const auto &actor_name = kernel_graph_group_info->group_name_ + kStackActorNameSuffix;
const auto &stack_actor = std::make_shared<StackActor>(actor_name, formal_parameters);
const auto &stack_actor = std::make_shared<StackActor>(actor_name, memory_manager_aid_, formal_parameters);
stack_actors.emplace_back(stack_actor);
stack_actor->device_contexts_.swap(device_contexts);
stack_actor->input_stack_data_num_ = input_parameter_data_num;
@@ -379,7 +382,7 @@ void ControlNodeScheduler::BuildStackActorForControlNode(const GraphCompilerInfo
}
// Create stack actor.
const auto &stack_actor_name = GetActorName(need_stack_control_node) + kStackActorNameSuffix;
const auto &stack_actor = std::make_shared<StackActor>(stack_actor_name, formal_parameters);
const auto &stack_actor = std::make_shared<StackActor>(stack_actor_name, memory_manager_aid_, formal_parameters);
stack_actor->device_contexts_ = device_contexts;
stack_actor->input_stack_data_num_ = input_parameter_data_num;
stack_actor->input_stack_partials_num_ = input_parameter_partials_num;
@@ -422,8 +425,29 @@ void ControlNodeScheduler::ClearActorData(const ControlActorSet *control_actor_s
return;
}

for (auto &switch_actor : control_actor_set->switch_actors_) {
MS_EXCEPTION_IF_NULL(switch_actor);
switch_actor->memory_free_lists_.clear();
}

for (auto &gather_actor : control_actor_set->gather_actors_) {
MS_EXCEPTION_IF_NULL(gather_actor);
gather_actor->memory_free_lists_.clear();
}

for (auto &entrance_actor : control_actor_set->entrance_actors_) {
MS_EXCEPTION_IF_NULL(entrance_actor);
entrance_actor->memory_free_lists_.clear();
}

for (auto &stack_actor : control_actor_set->stack_actors_) {
MS_EXCEPTION_IF_NULL(stack_actor);
stack_actor->memory_free_lists_.clear();
}

for (auto &exit_actor : control_actor_set->exit_actors_) {
MS_EXCEPTION_IF_NULL(exit_actor);
exit_actor->memory_free_lists_.clear();
exit_actor->created_device_tensors_.clear();
}
}


+ 4
- 1
mindspore/ccsrc/runtime/framework/control_node_scheduler.h View File

@@ -37,7 +37,7 @@ class ControlNodeScheduler {
DISABLE_COPY_AND_ASSIGN(ControlNodeScheduler);

// Transform the control nodes to control actors.
ControlActorSetPtr Build(const GraphCompilerInfo &graph_compiler_info);
ControlActorSetPtr Build(const GraphCompilerInfo &graph_compiler_info, const AID &memory_manager_aid);
// Link control actors.
void Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info);

@@ -106,6 +106,9 @@ class ControlNodeScheduler {
void LinkPartialArrowForExitActor(ExitActor *const exit_actor, ControlActor *const to_actor, size_t from_index,
size_t to_index, int branch_id);
bool IsNoInputActor(const ControlActor *control_actor);

// The id of memory manager actor.
AID memory_manager_aid_;
};
} // namespace runtime
} // namespace mindspore


+ 7
- 1
mindspore/ccsrc/runtime/framework/graph_scheduler.cc View File

@@ -217,6 +217,12 @@ void GraphScheduler::Clear() {

void GraphScheduler::ClearActorData(const ActorSet *actor_set) {
  MS_EXCEPTION_IF_NULL(actor_set);

  // Drop the per-step memory free records held by the super kernel actors.
  for (const auto &actor : actor_set->super_kernel_actors_) {
    MS_EXCEPTION_IF_NULL(actor);
    actor->memory_free_lists_.clear();
  }

  // The control actors keep their own per-step data; delegate to the control node scheduler.
  control_node_scheduler_.ClearActorData(actor_set->control_actors_.get());
}

@@ -486,7 +492,7 @@ ActorSetPtr GraphScheduler::Build(const GraphCompilerInfo &graph_compiler_info)
actor_set->output_actor_ = BuildOutputActor(graph_compiler_info);
actor_set->data_prepare_actor_ =
BuildDataPrepareActor(graph_compiler_info, actor_set->data_source_actors_, host_queue);
actor_set->control_actors_ = control_node_scheduler_.Build(graph_compiler_info);
actor_set->control_actors_ = control_node_scheduler_.Build(graph_compiler_info, memory_manager_aid_);
return actor_set;
}



Loading…
Cancel
Save