Merge pull request !2861 from xychow/optimize-setparam-to-addparamtags/v0.6.0-beta
| @@ -68,9 +68,7 @@ ParameterPtr FuncGraph::add_parameter() { | |||
| void FuncGraph::add_parameter(const ParameterPtr &p) { | |||
| if (manager_.lock()) { | |||
| std::vector<AnfNodePtr> new_params = parameters_; | |||
| new_params.push_back(p); | |||
| manager_.lock()->SetParameters(shared_from_base<FuncGraph>(), new_params); | |||
| manager_.lock()->AddParameter(shared_from_base<FuncGraph>(), p); | |||
| } else { | |||
| parameters_.push_back(p); | |||
| } | |||
| @@ -82,12 +80,8 @@ ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) { | |||
| p->set_name(name); | |||
| p->debug_info()->set_name(name); | |||
| std::vector<AnfNodePtr> new_params = parameters_; | |||
| // append parameter | |||
| new_params.push_back(p); | |||
| if (manager_.lock()) { | |||
| manager_.lock()->SetParameters(shared_from_base<FuncGraph>(), new_params); | |||
| manager_.lock()->AddParameter(shared_from_base<FuncGraph>(), p); | |||
| } else { | |||
| parameters_.push_back(p); | |||
| } | |||
| @@ -158,6 +158,7 @@ class FuncGraph : public FuncGraphBase { | |||
| const std::vector<AnfNodePtr> ¶meters() const { return parameters_; } | |||
| virtual ParameterPtr add_parameter(); | |||
| void add_parameter(const ParameterPtr &p); | |||
| void append_parameter(const ParameterPtr &p) { parameters_.push_back(p); } | |||
| void set_parameters(const std::vector<AnfNodePtr> ¶ms) { parameters_ = params; } | |||
| // add a weight parameter with specific name | |||
| ParameterPtr AddWeightParameter(const std::string &name); | |||
| @@ -420,6 +420,12 @@ void FuncGraphManager::SetParameters(const FuncGraphPtr &fg, const std::vector<A | |||
| tr.Commit(); | |||
| } | |||
| void FuncGraphManager::AddParameter(const FuncGraphPtr &fg, const AnfNodePtr ¶meter) { | |||
| auto tr = Transact(); | |||
| tr.AddParameter(fg, parameter); | |||
| tr.Commit(); | |||
| } | |||
| bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { | |||
| auto tr = Transact(); | |||
| bool success = tr.Replace(old_node, new_node); | |||
| @@ -532,25 +538,37 @@ void FuncGraphManager::ParseChanges(const std::vector<Change> &changes, EdgeTupl | |||
| for (auto &iter : changes) { | |||
| auto operation = iter.op; | |||
| auto args = iter.args; | |||
| if (operation == Change::kTxSetEdge) { | |||
| auto edge = args.cast<ArgsOfSetEdge>(); | |||
| auto old_node = edge.root_node->input(edge.index); | |||
| (*rm_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, old_node))] += 1; | |||
| (*add_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, edge.new_node))] += 1; | |||
| (*rms)[old_node] += 1; | |||
| (*adds)[edge.new_node] += 1; | |||
| edge.root_node->set_input(edge.index, edge.new_node); | |||
| } else if (operation == Change::kTxSetParams) { | |||
| auto param = args.cast<ArgsOfSetParams>(); | |||
| MS_EXCEPTION_IF_NULL(param.func_graph); | |||
| auto old_parameters = param.func_graph->parameters(); | |||
| for (auto &p : param.params) { | |||
| (*adds)[p] += 1; | |||
| } | |||
| for (auto &p : old_parameters) { | |||
| (*rms)[p] += 1; | |||
| } | |||
| param.func_graph->set_parameters(param.params); | |||
| switch (operation) { | |||
| case Change::kTxSetEdge: { | |||
| auto edge = args.cast<ArgsOfSetEdge>(); | |||
| auto old_node = edge.root_node->input(edge.index); | |||
| (*rm_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, old_node))] += 1; | |||
| (*add_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, edge.new_node))] += 1; | |||
| (*rms)[old_node] += 1; | |||
| (*adds)[edge.new_node] += 1; | |||
| edge.root_node->set_input(edge.index, edge.new_node); | |||
| } break; | |||
| case Change::kTxSetParams: { | |||
| auto param = args.cast<ArgsOfSetParams>(); | |||
| MS_EXCEPTION_IF_NULL(param.func_graph); | |||
| auto old_parameters = param.func_graph->parameters(); | |||
| for (auto &p : param.params) { | |||
| (*adds)[p] += 1; | |||
| } | |||
| for (auto &p : old_parameters) { | |||
| (*rms)[p] += 1; | |||
| } | |||
| param.func_graph->set_parameters(param.params); | |||
| } break; | |||
| case Change::kTxAddParam: { | |||
| auto param = args.cast<ArgsOfAddParam>(); | |||
| MS_EXCEPTION_IF_NULL(param.func_graph); | |||
| (*adds)[param.param] += 1; | |||
| auto param_node = param.param->cast<ParameterPtr>(); | |||
| param.func_graph->append_parameter(param_node); | |||
| } break; | |||
| default: | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| @@ -599,6 +617,10 @@ void FuncGraphTransaction::SetParameters(FuncGraphPtr fg, const std::vector<AnfN | |||
| changes_.emplace_back(Change::kTxSetParams, ArgsOfSetParams{fg, params}); | |||
| } | |||
| void FuncGraphTransaction::AddParameter(FuncGraphPtr fg, const AnfNodePtr ¶m) { | |||
| changes_.emplace_back(Change::kTxAddParam, ArgsOfAddParam{fg, param}); | |||
| } | |||
| bool FuncGraphTransaction::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { | |||
| MS_EXCEPTION_IF_NULL(old_node); | |||
| MS_EXCEPTION_IF_NULL(new_node); | |||
| @@ -310,6 +310,7 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> { | |||
| void KeepRoots(const std::vector<FuncGraphPtr> &roots = {}); | |||
| void RemoveRoots(); | |||
| void SetParameters(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> ¶meters); | |||
| void AddParameter(const FuncGraphPtr &fg, const AnfNodePtr ¶meter); | |||
| void MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users = false); | |||
| bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node); | |||
| void SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value); | |||
| @@ -400,6 +401,7 @@ class FuncGraphTransaction { | |||
| // set parameters of a func graph | |||
| void SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr> ¶ms); | |||
| void AddParameter(FuncGraphPtr fg, const AnfNodePtr ¶m); | |||
| // replace old_node with new_node | |||
| bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node); | |||
| @@ -427,6 +429,18 @@ struct ArgsOfSetParams { | |||
| } | |||
| }; | |||
| // args for add param | |||
| struct ArgsOfAddParam { | |||
| FuncGraphPtr func_graph; | |||
| AnfNodePtr param; | |||
| bool operator==(const ArgsOfAddParam &other) const { return &other == this; } | |||
| friend std::ostream &operator<<(std::ostream &os, const ArgsOfAddParam &) { | |||
| os << "[ArgsOfAddParam]"; | |||
| return os; | |||
| } | |||
| }; | |||
| // args for set edge | |||
| struct ArgsOfSetEdge { | |||
| CNodePtr root_node; | |||
| @@ -441,7 +455,7 @@ struct ArgsOfSetEdge { | |||
| }; | |||
| struct Change { | |||
| enum OpName { kTxSetParams, kTxSetEdge }; | |||
| enum OpName { kTxSetParams, kTxSetEdge, kTxAddParam }; | |||
| OpName op; | |||
| Any args; | |||
| Change(OpName name, const Any ¶) : op(name), args(para) {} | |||