Browse Source

!2861 use addparam to replace setparam to reduce overhead

Merge pull request !2861 from xychow/optimize-setparam-to-addparam
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
056f9f6dc1
4 changed files with 59 additions and 28 deletions
  1. +2
    -8
      mindspore/ccsrc/ir/func_graph.cc
  2. +1
    -0
      mindspore/ccsrc/ir/func_graph.h
  3. +41
    -19
      mindspore/ccsrc/ir/manager.cc
  4. +15
    -1
      mindspore/ccsrc/ir/manager.h

+ 2
- 8
mindspore/ccsrc/ir/func_graph.cc View File

@@ -68,9 +68,7 @@ ParameterPtr FuncGraph::add_parameter() {


void FuncGraph::add_parameter(const ParameterPtr &p) { void FuncGraph::add_parameter(const ParameterPtr &p) {
if (manager_.lock()) { if (manager_.lock()) {
std::vector<AnfNodePtr> new_params = parameters_;
new_params.push_back(p);
manager_.lock()->SetParameters(shared_from_base<FuncGraph>(), new_params);
manager_.lock()->AddParameter(shared_from_base<FuncGraph>(), p);
} else { } else {
parameters_.push_back(p); parameters_.push_back(p);
} }
@@ -82,12 +80,8 @@ ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) {
p->set_name(name); p->set_name(name);
p->debug_info()->set_name(name); p->debug_info()->set_name(name);


std::vector<AnfNodePtr> new_params = parameters_;
// append parameter
new_params.push_back(p);

if (manager_.lock()) { if (manager_.lock()) {
manager_.lock()->SetParameters(shared_from_base<FuncGraph>(), new_params);
manager_.lock()->AddParameter(shared_from_base<FuncGraph>(), p);
} else { } else {
parameters_.push_back(p); parameters_.push_back(p);
} }


+ 1
- 0
mindspore/ccsrc/ir/func_graph.h View File

@@ -158,6 +158,7 @@ class FuncGraph : public FuncGraphBase {
const std::vector<AnfNodePtr> &parameters() const { return parameters_; } const std::vector<AnfNodePtr> &parameters() const { return parameters_; }
virtual ParameterPtr add_parameter(); virtual ParameterPtr add_parameter();
void add_parameter(const ParameterPtr &p); void add_parameter(const ParameterPtr &p);
void append_parameter(const ParameterPtr &p) { parameters_.push_back(p); }
void set_parameters(const std::vector<AnfNodePtr> &params) { parameters_ = params; } void set_parameters(const std::vector<AnfNodePtr> &params) { parameters_ = params; }
// add a weight parameter with specific name // add a weight parameter with specific name
ParameterPtr AddWeightParameter(const std::string &name); ParameterPtr AddWeightParameter(const std::string &name);


+ 41
- 19
mindspore/ccsrc/ir/manager.cc View File

@@ -420,6 +420,12 @@ void FuncGraphManager::SetParameters(const FuncGraphPtr &fg, const std::vector<A
tr.Commit(); tr.Commit();
} }


void FuncGraphManager::AddParameter(const FuncGraphPtr &fg, const AnfNodePtr &parameter) {
auto tr = Transact();
tr.AddParameter(fg, parameter);
tr.Commit();
}

bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
auto tr = Transact(); auto tr = Transact();
bool success = tr.Replace(old_node, new_node); bool success = tr.Replace(old_node, new_node);
@@ -532,25 +538,37 @@ void FuncGraphManager::ParseChanges(const std::vector<Change> &changes, EdgeTupl
for (auto &iter : changes) { for (auto &iter : changes) {
auto operation = iter.op; auto operation = iter.op;
auto args = iter.args; auto args = iter.args;
if (operation == Change::kTxSetEdge) {
auto edge = args.cast<ArgsOfSetEdge>();
auto old_node = edge.root_node->input(edge.index);
(*rm_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, old_node))] += 1;
(*add_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, edge.new_node))] += 1;
(*rms)[old_node] += 1;
(*adds)[edge.new_node] += 1;
edge.root_node->set_input(edge.index, edge.new_node);
} else if (operation == Change::kTxSetParams) {
auto param = args.cast<ArgsOfSetParams>();
MS_EXCEPTION_IF_NULL(param.func_graph);
auto old_parameters = param.func_graph->parameters();
for (auto &p : param.params) {
(*adds)[p] += 1;
}
for (auto &p : old_parameters) {
(*rms)[p] += 1;
}
param.func_graph->set_parameters(param.params);
switch (operation) {
case Change::kTxSetEdge: {
auto edge = args.cast<ArgsOfSetEdge>();
auto old_node = edge.root_node->input(edge.index);
(*rm_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, old_node))] += 1;
(*add_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, edge.new_node))] += 1;
(*rms)[old_node] += 1;
(*adds)[edge.new_node] += 1;
edge.root_node->set_input(edge.index, edge.new_node);
} break;
case Change::kTxSetParams: {
auto param = args.cast<ArgsOfSetParams>();
MS_EXCEPTION_IF_NULL(param.func_graph);
auto old_parameters = param.func_graph->parameters();
for (auto &p : param.params) {
(*adds)[p] += 1;
}
for (auto &p : old_parameters) {
(*rms)[p] += 1;
}
param.func_graph->set_parameters(param.params);
} break;
case Change::kTxAddParam: {
auto param = args.cast<ArgsOfAddParam>();
MS_EXCEPTION_IF_NULL(param.func_graph);
(*adds)[param.param] += 1;
auto param_node = param.param->cast<ParameterPtr>();
param.func_graph->append_parameter(param_node);
} break;
default:
break;
} }
} }
} }
@@ -599,6 +617,10 @@ void FuncGraphTransaction::SetParameters(FuncGraphPtr fg, const std::vector<AnfN
changes_.emplace_back(Change::kTxSetParams, ArgsOfSetParams{fg, params}); changes_.emplace_back(Change::kTxSetParams, ArgsOfSetParams{fg, params});
} }


void FuncGraphTransaction::AddParameter(FuncGraphPtr fg, const AnfNodePtr &param) {
changes_.emplace_back(Change::kTxAddParam, ArgsOfAddParam{fg, param});
}

bool FuncGraphTransaction::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { bool FuncGraphTransaction::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
MS_EXCEPTION_IF_NULL(old_node); MS_EXCEPTION_IF_NULL(old_node);
MS_EXCEPTION_IF_NULL(new_node); MS_EXCEPTION_IF_NULL(new_node);


+ 15
- 1
mindspore/ccsrc/ir/manager.h View File

@@ -310,6 +310,7 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
void KeepRoots(const std::vector<FuncGraphPtr> &roots = {}); void KeepRoots(const std::vector<FuncGraphPtr> &roots = {});
void RemoveRoots(); void RemoveRoots();
void SetParameters(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &parameters); void SetParameters(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &parameters);
void AddParameter(const FuncGraphPtr &fg, const AnfNodePtr &parameter);
void MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users = false); void MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users = false);
bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node); bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
void SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value); void SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value);
@@ -400,6 +401,7 @@ class FuncGraphTransaction {


// set parameters of a func graph // set parameters of a func graph
void SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr> &params); void SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr> &params);
void AddParameter(FuncGraphPtr fg, const AnfNodePtr &param);


// replace old_node with new_node // replace old_node with new_node
bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node); bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
@@ -427,6 +429,18 @@ struct ArgsOfSetParams {
} }
}; };


// args for add param
struct ArgsOfAddParam {
FuncGraphPtr func_graph;
AnfNodePtr param;
bool operator==(const ArgsOfAddParam &other) const { return &other == this; }

friend std::ostream &operator<<(std::ostream &os, const ArgsOfAddParam &) {
os << "[ArgsOfAddParam]";
return os;
}
};

// args for set edge // args for set edge
struct ArgsOfSetEdge { struct ArgsOfSetEdge {
CNodePtr root_node; CNodePtr root_node;
@@ -441,7 +455,7 @@ struct ArgsOfSetEdge {
}; };


struct Change { struct Change {
enum OpName { kTxSetParams, kTxSetEdge };
enum OpName { kTxSetParams, kTxSetEdge, kTxAddParam };
OpName op; OpName op;
Any args; Any args;
Change(OpName name, const Any &para) : op(name), args(para) {} Change(OpName name, const Any &para) : op(name), args(para) {}


Loading…
Cancel
Save