Browse Source

Collect weights from the entrance actor in the subgraph.

feature/build-system-rewrite
gaoyong10 4 years ago
parent
commit
5cab727543
6 changed files with 89 additions and 101 deletions
  1. +38
    -58
      mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc
  2. +4
    -8
      mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.h
  3. +10
    -12
      mindspore/ccsrc/runtime/framework/control_node_parser.cc
  4. +1
    -1
      mindspore/ccsrc/runtime/framework/control_node_parser.h
  5. +35
    -22
      mindspore/ccsrc/runtime/framework/graph_scheduler.cc
  6. +1
    -0
      mindspore/ccsrc/runtime/framework/graph_scheduler.h

+ 38
- 58
mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc View File

@@ -320,6 +320,7 @@ void DataPrepareActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const contex

void DataPrepareActor::PrepareDataForDeviceTensorStore(const std::vector<std::vector<TensorPtr>> &input_tensors,
OpContext<DeviceTensor> *const context) {
const auto &parser = graph_compiler_info_->control_node_parser_;
for (size_t i = 0; i < graph_compiler_info_->graphs_.size(); ++i) {
const auto &graph = graph_compiler_info_->graphs_[i];
const auto &device_context = graph_compiler_info_->device_contexts_[i];
@@ -338,15 +339,16 @@ void DataPrepareActor::PrepareDataForDeviceTensorStore(const std::vector<std::ve
const auto &input_node = input_nodes[j];
const auto &input_tensor = tensors[j];
MS_EXCEPTION_IF_NULL(input_node);
if (!IsPersistentDeviceTensor(input_node)) {
const auto front_node = FetchFrontNodeByBackendNode(input_node, graph);
if (!IsPersistentDeviceTensor(input_node) ||
(parser != nullptr && parser->IsInited() && (!parser->IsRootGraphParameter(front_node)))) {
continue;
}
const auto front_node = FetchFrontNodeByBackendNode(input_node, graph);
PrepareDataForWeightNode(input_node, front_node, input_tensor, device_context, context);
}
}

PrepareDeviceTensorStoreForControlNode(graph_compiler_info_->control_node_parser_, input_tensors.back(), context);
PrepareDataForControlNode(graph_compiler_info_->control_node_parser_, input_tensors.back(), context);
}

void DataPrepareActor::PrepareDataForHostTensorQueue(const std::vector<std::vector<TensorPtr>> &input_tensors,
@@ -699,51 +701,14 @@ void DataPrepareActor::PrepareDataForWeightNode(const AnfNodePtr &backend_node,
}
}

// In control flow, all weight nodes associated with the host weight parameter need to use the same device tensor.
void DataPrepareActor::PrepareDataForControlWeightNode(const AnfNodePtr &node, const AnfNodePtr &front_node,
const TensorPtr &tensor, const DeviceContext *device_context,
const HostParameterToWeight &host_parameter_to_weights,
OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(front_node);
MS_EXCEPTION_IF_NULL(tensor);
MS_EXCEPTION_IF_NULL(device_context);

auto device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get());
bool need_update_device_tensor_store = (device_tensors.size() == 0) ? true : false;
for (auto &device_tensor : device_tensors) {
MS_EXCEPTION_IF_NULL(device_tensor);
// Different from CPU, GPU platform, the subgraph weight params device addr of Ascend platform
// has already been allocated during the compilation, so these weight params still need to be updated.
if (device_tensor->GetPtr() == nullptr || device_tensor->is_ptr_persisted()) {
need_update_device_tensor_store = true;
break;
}
}
if (need_update_device_tensor_store) {
PrepareDataForWeightNode(node, front_node, tensor, device_context, context);
}

const auto iter = host_parameter_to_weights.find(front_node);
if (iter == host_parameter_to_weights.end()) {
void DataPrepareActor::PrepareDataForControlNode(const ControlNodeParserPtr &control_node_parser,
const std::vector<TensorPtr> &tensors,
OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(control_node_parser);
if (!control_node_parser->IsInited()) {
return;
}

// Fetch all the device tensors of host weight node and insert as the weight of other nodes.
const auto &sub_front_nodes = host_parameter_to_weights.at(front_node);
device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get());
for (const auto &sub_front_node : sub_front_nodes) {
for (const auto &device_tensor : device_tensors) {
MS_EXCEPTION_IF_NULL(sub_front_node);
DeviceTensorStore::GetInstance().Insert(sub_front_node.get(), device_tensor);
}
}
}

void DataPrepareActor::PrepareDeviceTensorStoreForControlNode(const ControlNodeParserPtr &control_node_parser,
const std::vector<TensorPtr> &tensors,
OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(control_node_parser);
for (const auto &value_node_with_context : control_node_parser->front_value_nodes()) {
MS_EXCEPTION_IF_NULL(value_node_with_context.first.first);
if (AnfAlgo::OutputAddrExist(value_node_with_context.first.first, 0)) {
@@ -753,19 +718,34 @@ void DataPrepareActor::PrepareDeviceTensorStoreForControlNode(const ControlNodeP

const auto &control_node_parameters = control_node_parser->control_node_parameters();
for (size_t i = 0; i < control_node_parameters.size(); ++i) {
const auto &input_node = control_node_parameters[i];
const auto &input_tensor = tensors[i];
MS_EXCEPTION_IF_NULL(input_node);
if (IsPersistentDeviceTensor(input_node)) {
const auto &front_to_backend_parameters = control_node_parser->front_to_backend_parameters();
const auto &iter = front_to_backend_parameters.find({input_node, 0});
if (iter == front_to_backend_parameters.end() || iter->second.empty()) {
MS_LOG(EXCEPTION) << "Cannot find backend node for weight parameter:"
<< AnfAlgo::GetNodeDebugString(input_node);
}
const auto &node_with_context = iter->second.begin();
PrepareDataForControlWeightNode(node_with_context->first, input_node, input_tensor, node_with_context->second,
control_node_parser->host_parameter_to_weights(), context);
const auto &front_node = control_node_parameters[i];
MS_EXCEPTION_IF_NULL(front_node);
if ((!IsPersistentDeviceTensor(front_node)) || (!control_node_parser->IsRootGraphParameter(front_node))) {
continue;
}

const auto &front_to_backend_parameters = control_node_parser->front_to_backend_parameters();
const auto &iter = front_to_backend_parameters.find({front_node, 0});
if (iter == front_to_backend_parameters.end() || iter->second.empty()) {
MS_LOG(EXCEPTION) << "Cannot find backend node for weight parameter:" << AnfAlgo::GetNodeDebugString(front_node);
}
const auto &node_with_context = iter->second.begin();
const auto &backend_node = node_with_context->first;
const auto &device_context = node_with_context->second;
MS_EXCEPTION_IF_NULL(backend_node);
MS_EXCEPTION_IF_NULL(device_context);

auto device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get());
if (device_tensors.empty()) {
std::string error_info = "Failed to get device tensor for front node:" + front_node->DebugString();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(real_strategy_, (*context), error_info);
}

// Different from CPU, GPU platform, the subgraph weight params device addr of Ascend platform has already been
// allocated during the compilation, so these weight params still need to be updated.
if (device_tensors[0] != nullptr &&
(device_tensors[0]->GetPtr() == nullptr || device_tensors[0]->is_ptr_persisted())) {
PrepareDataForWeightNode(backend_node, front_node, tensors[i], device_context, context);
}
}
}


+ 4
- 8
mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.h View File

@@ -90,17 +90,13 @@ class DataPrepareActor : public DebugAwareActor {
const DeviceContext *device_context, OpContext<DeviceTensor> *const context);

// The data prepare in the control flow scene.
void PrepareDeviceTensorStoreForControlNode(const ControlNodeParserPtr &control_node_parser,
const std::vector<TensorPtr> &tensors,
OpContext<DeviceTensor> *const context);
// If the parameters in the root graph are only used by the control node, these parameters will not be initialized
// by the kernel graph, and addresses need to be specially allocated for these parameters.
void PrepareDataForControlNode(const ControlNodeParserPtr &control_node_parser, const std::vector<TensorPtr> &tensors,
OpContext<DeviceTensor> *const context);
void PrepareHostTensorQueueForControlNode(const std::vector<TensorPtr> &tensors,
std::vector<TensorPtr> *const host_tensors,
OpContext<DeviceTensor> *const context);
// In control flow, all weight nodes associated with the host weight parameter need to use the same device tensor.
void PrepareDataForControlWeightNode(const AnfNodePtr &node, const AnfNodePtr &front_node, const TensorPtr &tensor,
const DeviceContext *device_context,
const HostParameterToWeight &host_parameter_to_weights,
OpContext<DeviceTensor> *const context);
void PrepareDataForControlValueNode(const KernelWithIndex &node_with_index, const DeviceContext *device_context,
OpContext<DeviceTensor> *const context);



+ 10
- 12
mindspore/ccsrc/runtime/framework/control_node_parser.cc View File

@@ -945,7 +945,7 @@ void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, cons
ParseFirstControlNodeForFuncGraph(control_nodes);
}

bool ControlNodeParser::IsControlFlowDataArrow(const KernelGraphPtr &graph, const AnfNodePtr &node) {
bool ControlNodeParser::IsControlFlowDataArrow(const KernelGraphPtr &graph, const AnfNodePtr &backend_node) {
MS_EXCEPTION_IF_NULL(graph);
// Has no control flow node.
if (!IsInited()) {
@@ -956,26 +956,24 @@ bool ControlNodeParser::IsControlFlowDataArrow(const KernelGraphPtr &graph, cons
return true;
}

MS_EXCEPTION_IF_NULL(node);
if (!node->isa<Parameter>()) {
MS_EXCEPTION_IF_NULL(backend_node);
if (!backend_node->isa<Parameter>()) {
return false;
}
auto parameter_node = node->cast<ParameterPtr>();
auto parameter_node = backend_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(parameter_node);

// Parameter input should be linked to its entrance actor.
auto front_node = graph->GetFrontAnfByBackendAnf(node);
auto internal_node_with_index = graph->GetFrontNodeByInternalParameter(node);
auto front_node = graph->GetFrontAnfByBackendAnf(backend_node);
auto internal_node_with_index = graph->GetFrontNodeByInternalParameter(backend_node);
front_node = (front_node != nullptr ? front_node : internal_node_with_index.first);
if (front_node == nullptr) {
auto front_node_with_index = graph->GetElementInTupleBackendFrontIndexMap(node);
auto front_node_with_index = graph->GetElementInTupleBackendFrontIndexMap(backend_node);
front_node = front_node_with_index.first;
}

// If parameter is a weight node, it should be set to kernel actor directly.
if (AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>()) ||
(front_node != nullptr && front_node->isa<Parameter>() &&
AnfAlgo::IsParameterWeight(front_node->cast<ParameterPtr>()))) {
MS_EXCEPTION_IF_NULL(front_node);
// If parameter is a weight node in root funcgraph, it should be set to kernel actor directly.
if (IsRootGraphParameter(front_node) && AnfAlgo::IsParameterWeight(backend_node->cast<ParameterPtr>())) {
return false;
}



+ 1
- 1
mindspore/ccsrc/runtime/framework/control_node_parser.h View File

@@ -125,7 +125,7 @@ class ControlNodeParser {
// There are two situations:
// 1. In control flow, the parameter input needs to be connected to the entrance actor of the funcgraph.
// 2. In the kernel graph with call node input, the data arrow needs to be connected to the stack actor.
bool IsControlFlowDataArrow(const KernelGraphPtr &graph, const AnfNodePtr &node);
bool IsControlFlowDataArrow(const KernelGraphPtr &graph, const AnfNodePtr &backend_node);
bool IsRootGraphParameter(const AnfNodePtr &node);
bool IsRecursionCallNode(const AnfNodePtr &node);
// If there is a recursive call node in the input of the kernel graph, the graph is recursive.


+ 35
- 22
mindspore/ccsrc/runtime/framework/graph_scheduler.cc View File

@@ -1862,6 +1862,7 @@ void GraphScheduler::CheckActorValid(const ActorSet *actor_set) const {
}

void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler_info) {
const auto &parser = graph_compiler_info.control_node_parser_;
for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
const auto &graph = graph_compiler_info.graphs_[i];
const auto &device_context = graph_compiler_info.device_contexts_[i];
@@ -1882,24 +1883,18 @@ void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler

for (auto &input_node : graph->input_nodes()) {
MS_EXCEPTION_IF_NULL(input_node);
AnfNodePtr sub_front_node = nullptr;
AnfNodePtr front_node = nullptr;
if (IsInternalParameter(input_node, graph)) {
auto front_output_with_index = graph->GetFrontNodeByInternalParameter(input_node);
sub_front_node = front_output_with_index.first;
} else if (IsPersistentDeviceTensor(input_node) || HasAbstractRef(input_node)) {
sub_front_node = FetchFrontNodeByBackendNode(input_node, graph);
front_node = front_output_with_index.first;
} else if (IsPersistentDeviceTensor(input_node)) {
front_node = FetchFrontNodeByBackendNode(input_node, graph);
}
if (sub_front_node == nullptr) {
if (front_node == nullptr || (!IsPersistentDeviceTensor(front_node)) ||
(parser != nullptr && parser->IsInited() && (!parser->IsRootGraphParameter(front_node)))) {
continue;
}

// The sub front nodes share the device tensor store with the root front node.
MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
auto front_node = graph_compiler_info.control_node_parser_->FetchRootGraphFrontNodeBySubFrontNode(sub_front_node);
MS_EXCEPTION_IF_NULL(front_node);
MS_LOG(DEBUG) << "Graph id:" << graph->graph_id() << ", sub front node:" << sub_front_node->DebugString()
<< ", root front node:" << front_node->DebugString();

auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
MS_EXCEPTION_IF_NULL(device_tensor);
if (IsPersistentDeviceTensor(input_node)) {
@@ -1907,14 +1902,11 @@ void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler
AddDeviceTensorStore(front_node.get(), device_tensor);
}

// Share the weight in the host and device, then input_node is internal parameter and front_node is weight.
if (!IsPersistentDeviceTensor(front_node)) {
continue;
}
if (device_tensor->is_ptr_persisted()) {
device_tensor->SetNodeIndex(input_node, 0);
AddDeviceTensorStore(front_node.get(), device_tensor);
}

// If the device tensor store of this device type is not exist, then create the new device tensor of this type.
if (DeviceTensorStore::GetInstance().Fetch(front_node.get(), device_context->GetDeviceAddressType()) == nullptr) {
MS_LOG(INFO) << "Fetch no device tensor store by:" << front_node->fullname_with_scope()
@@ -1927,17 +1919,38 @@ void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler
}
}
}
PersistDeviceTensorForControlNode(graph_compiler_info);
}

void GraphScheduler::PersistDeviceTensorForControlNode(const GraphCompilerInfo &graph_compiler_info) {
const auto &parser = graph_compiler_info.control_node_parser_;
if (parser == nullptr) {
if (parser == nullptr || (!parser->IsInited())) {
return;
}
for (const auto &sub_front_node_to_root_front_node : parser->sub_front_node_to_root_front_node_) {
auto device_tensors = DeviceTensorStore::GetInstance().Fetch(sub_front_node_to_root_front_node.second.get());
for (const auto &device_tensor : device_tensors) {
MS_EXCEPTION_IF_NULL(device_tensor);
AddDeviceTensorStore(sub_front_node_to_root_front_node.first.get(), device_tensor);

const auto &control_node_parameters = parser->control_node_parameters();
for (size_t i = 0; i < control_node_parameters.size(); ++i) {
const auto &input_node = control_node_parameters[i];
MS_EXCEPTION_IF_NULL(input_node);
if ((!IsPersistentDeviceTensor(input_node)) || (!parser->IsRootGraphParameter(input_node))) {
continue;
}
const auto &front_to_backend_parameters = parser->front_to_backend_parameters();
const auto &iter = front_to_backend_parameters.find({input_node, 0});
if (iter == front_to_backend_parameters.end() || iter->second.empty()) {
MS_LOG(EXCEPTION) << "Cannot find backend node for weight parameter:" << input_node->DebugString();
}
const auto &node_with_context = iter->second.begin();
const auto &backend_node = node_with_context->first;
const auto &device_context = node_with_context->second;
MS_EXCEPTION_IF_NULL(backend_node);
MS_EXCEPTION_IF_NULL(device_context);
if (!DeviceTensorStore::GetInstance().Fetch(input_node.get()).empty()) {
continue;
}
auto device_tensor = AnfAlgo::GetMutableOutputAddr(backend_node, 0, false);
MS_EXCEPTION_IF_NULL(device_tensor);
AddDeviceTensorStore(input_node.get(), device_tensor);
}
}



+ 1
- 0
mindspore/ccsrc/runtime/framework/graph_scheduler.h View File

@@ -179,6 +179,7 @@ class GraphScheduler {

// Persist device tensors of graph's some nodes(such as weights and value nodes).
void PersistDeviceTensor(const GraphCompilerInfo &graph_compiler_info);
void PersistDeviceTensorForControlNode(const GraphCompilerInfo &graph_compiler_info);

// Display the actor information of corresponding kernel graph.
void DumpActor(const ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) const;


Loading…
Cancel
Save