Browse Source

!4990 fix bug of EraseAssign

Merge pull request !4990 from wenchunjiang/adapte_to_resnet_second_optimize
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
c6767c0f3a
2 changed files with 12 additions and 11 deletions
  1. +10
    -10
      mindspore/ccsrc/backend/session/ascend_control_parser.cc
  2. +2
    -1
      mindspore/ccsrc/backend/session/ascend_control_parser.h

+ 10
- 10
mindspore/ccsrc/backend/session/ascend_control_parser.cc View File

@@ -261,17 +261,16 @@ void AscendControlParser::EraseParameter(NotNull<KernelGraphPtr> root_graph,
}
}

EraseAssign(all_nodes, para_to_written_node, root_graph);
root_graph->set_execution_order(exec_order);
EraseAssign(std::make_shared<ReferenceCounter>(parameter_count), all_nodes, para_to_written_node, root_graph);
}

void AscendControlParser::EraseAssign(const std::set<CNodePtr> &all_nodes,
void AscendControlParser::EraseAssign(std::shared_ptr<ReferenceCounter> parameter_count,
const std::set<CNodePtr> &all_nodes,
const std::map<AnfNodePtr, CNodePtr> &para_to_written_node,
NotNull<KernelGraphPtr> root_graph) {
std::vector<CNodePtr> exec_order = root_graph->execution_order();
ReferenceCounter parameter_count([](int32_t read, int32_t write) -> bool { return write == 1; });
while (parameter_count.HasValidElem()) {
auto [para, read, written] = parameter_count.GetOneValidElem();
while (parameter_count->HasValidElem()) {
auto [para, read, written] = parameter_count->GetOneValidElem();
MS_LOG(INFO) << para->DebugString() << " was read " << read << " times, written " << written << " times.";
auto assign_iter = para_to_written_node.find(para);
if (assign_iter == para_to_written_node.end()) {
@@ -280,7 +279,7 @@ void AscendControlParser::EraseAssign(const std::set<CNodePtr> &all_nodes,
auto &assign_node = assign_iter->second;
MS_EXCEPTION_IF_NULL(assign_node);
if (!IsPrimitiveCNode(assign_node, prim::kPrimAssign)) {
parameter_count.EraseElem(para);
parameter_count->EraseElem(para);
continue;
}
MS_LOG(INFO) << "Erase " << assign_node->DebugString(5);
@@ -288,10 +287,10 @@ void AscendControlParser::EraseAssign(const std::set<CNodePtr> &all_nodes,
auto source = assign_node->input(kCNodeAssignSource);
MS_EXCEPTION_IF_NULL(source);
auto visit_source = AnfAlgo::VisitKernelWithReturnType(source, 0).first;
parameter_count.AddWriteCount(para, -1);
parameter_count.AddReadCount(para, -1);
parameter_count->AddWriteCount(para, -1);
parameter_count->AddReadCount(para, -1);
if (visit_source->isa<Parameter>()) {
parameter_count.AddReadCount(visit_source, read - 1);
parameter_count->AddReadCount(visit_source, read - 1);
}
for (auto &node : all_nodes) {
for (size_t i = 0; i < node->size(); ++i) {
@@ -302,6 +301,7 @@ void AscendControlParser::EraseAssign(const std::set<CNodePtr> &all_nodes,
}
}
}
root_graph->set_execution_order(exec_order);
}

void AscendControlParser::EraseLabel(NotNull<KernelGraphPtr> root_graph) {


+ 2
- 1
mindspore/ccsrc/backend/session/ascend_control_parser.h View File

@@ -22,6 +22,7 @@
#include <tuple>
#include <utility>
#include <functional>
#include <memory>
#include "backend/session/kernel_graph.h"
#include "base/base_ref.h"
#include "utils/contract.h"
@@ -44,7 +45,7 @@ class AscendControlParser {
class ReferenceCounter;

static void EraseParameter(NotNull<KernelGraphPtr> root_graph, const std::set<KernelGraphPtr> &graph_list);
static void EraseAssign(const std::set<CNodePtr> &all_nodes,
static void EraseAssign(std::shared_ptr<ReferenceCounter> parameter_count, const std::set<CNodePtr> &all_nodes,
const std::map<AnfNodePtr, CNodePtr> &para_to_written_node,
NotNull<KernelGraphPtr> root_graph);
static void EraseLabel(NotNull<KernelGraphPtr> root_graph);


Loading…
Cancel
Save