Browse Source

!9507 Fix bug in scenario that one tensor for multiple graph output in pynative bp graph

From: @HulkTang
Reviewed-by: @chujinjin
Signed-off-by: @chujinjin
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
0d21d5b570
1 changed files with 24 additions and 17 deletions
  1. +24
    -17
      mindspore/ccsrc/backend/session/ascend_session.cc

+ 24
- 17
mindspore/ccsrc/backend/session/ascend_session.cc View File

@@ -128,7 +128,7 @@ void InsertMakeTupleForOutput(NotNull<KernelGraphPtr> root_graph) {
BaseRef CreateNodeOutputPlaceholder(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph,
const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<size_t> &indexes,
std::map<KernelWithIndex, std::vector<size_t>> *output_indexes) {
std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) {
auto &node = node_output_pair.first;
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(graph);
@@ -152,7 +152,7 @@ BaseRef CreateNodeOutputPlaceholder(const session::KernelWithIndex &node_output_
}
MS_LOG(EXCEPTION) << "Parameter: " << node->DebugString() << " has no output addr";
}
(*output_indexes)[node_output_pair] = indexes;
(*output_indexes)[node_output_pair].emplace_back(indexes);
BaseRef output_placeholder = std::make_shared<BaseRef>();
return output_placeholder;
}
@@ -160,7 +160,7 @@ BaseRef CreateNodeOutputPlaceholder(const session::KernelWithIndex &node_output_
BaseRef CreateNodeOutputPlaceholder(const AnfNodePtr &anf, const KernelGraphPtr &graph,
const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<size_t> &indexes,
std::map<KernelWithIndex, std::vector<size_t>> *output_indexes) {
std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) {
MS_EXCEPTION_IF_NULL(anf);
MS_EXCEPTION_IF_NULL(output_indexes);
MS_LOG(INFO) << "Create placeholder for output[" << anf->DebugString() << "]";
@@ -189,7 +189,8 @@ BaseRef CreateNodeOutputPlaceholder(const AnfNodePtr &anf, const KernelGraphPtr
}

void CreateOutputPlaceholder(const KernelGraphPtr &kernel_graph, const std::vector<tensor::TensorPtr> &input_tensors,
VectorRef *outputs, std::map<KernelWithIndex, std::vector<size_t>> *output_indexes) {
VectorRef *outputs,
std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) {
MS_EXCEPTION_IF_NULL(kernel_graph);
MS_EXCEPTION_IF_NULL(outputs);
MS_EXCEPTION_IF_NULL(output_indexes);
@@ -333,7 +334,7 @@ void HandleOpInputs(const std::set<KernelWithIndex> &input_kernel, std::map<Kern
}

void HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs,
const std::map<KernelWithIndex, std::vector<size_t>> &output_indexes,
const std::map<KernelWithIndex, std::vector<std::vector<size_t>>> &output_indexes,
const std::map<KernelWithIndex, size_t> &ref_count,
std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map, VectorRef *outputs) {
auto output_tensors = TransformVectorRefToMultiTensor(op_outputs);
@@ -350,19 +351,25 @@ void HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs,
if (iter == output_indexes.end()) {
continue;
}
const std::vector<size_t> &ref_indexes = iter->second;
size_t n = 0;
const VectorRef *cur_vector_ref = outputs;
while (n != ref_indexes.size() - 1) {
size_t index = ref_indexes.at(n++);
const BaseRef &base_ref = (*cur_vector_ref)[index];
if (!utils::isa<VectorRef>(base_ref)) {
MS_LOG(EXCEPTION) << "Get none VectorRef by ref index, indexes: " << ref_indexes << "cur n: " << n - 1;
const std::vector<std::vector<size_t>> &multiple_ref_indexes = iter->second;
for (const auto &ref_indexes : multiple_ref_indexes) {
size_t n = 0;
const VectorRef *cur_vector_ref = outputs;
for (; n < ref_indexes.size() - 1; n += 1) {
size_t index = ref_indexes.at(n);
if (index >= cur_vector_ref->size()) {
MS_LOG(EXCEPTION) << "Get invalid output ref index: " << index << ", size of vertor ref is "
<< cur_vector_ref->size();
}
const BaseRef &base_ref = (*cur_vector_ref)[index];
if (!utils::isa<VectorRef>(base_ref)) {
MS_LOG(EXCEPTION) << "Get none VectorRef by ref index, index: " << index << "cur n: " << n;
}
cur_vector_ref = &utils::cast<VectorRef>(base_ref);
}
cur_vector_ref = &utils::cast<VectorRef>(base_ref);
BaseRef &tensor_ref = (*const_cast<VectorRef *>(cur_vector_ref))[ref_indexes.at(n)];
tensor_ref = output_tensor;
}
BaseRef &tensor_ref = (*const_cast<VectorRef *>(cur_vector_ref))[ref_indexes.at(n)];
tensor_ref = output_tensor;
}
}

@@ -725,7 +732,7 @@ void AscendSession::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector
auto kernel_graph = GetGraph(graph_id);
std::map<AnfNodePtr, size_t> parameter_index;
GetParameterIndex(kernel_graph.get(), inputs, &parameter_index);
std::map<KernelWithIndex, std::vector<size_t>> output_indexes;
std::map<KernelWithIndex, std::vector<std::vector<size_t>>> output_indexes;
CreateOutputPlaceholder(kernel_graph, inputs, outputs, &output_indexes);
std::map<KernelWithIndex, size_t> cnode_ref;
GetRefCount(kernel_graph.get(), &cnode_ref);


Loading…
Cancel
Save