Merge pull request !2861 from xychow/optimize-setparam-to-addparamtags/v0.6.0-beta
| @@ -68,9 +68,7 @@ ParameterPtr FuncGraph::add_parameter() { | |||||
| void FuncGraph::add_parameter(const ParameterPtr &p) { | void FuncGraph::add_parameter(const ParameterPtr &p) { | ||||
| if (manager_.lock()) { | if (manager_.lock()) { | ||||
| std::vector<AnfNodePtr> new_params = parameters_; | |||||
| new_params.push_back(p); | |||||
| manager_.lock()->SetParameters(shared_from_base<FuncGraph>(), new_params); | |||||
| manager_.lock()->AddParameter(shared_from_base<FuncGraph>(), p); | |||||
| } else { | } else { | ||||
| parameters_.push_back(p); | parameters_.push_back(p); | ||||
| } | } | ||||
| @@ -82,12 +80,8 @@ ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) { | |||||
| p->set_name(name); | p->set_name(name); | ||||
| p->debug_info()->set_name(name); | p->debug_info()->set_name(name); | ||||
| std::vector<AnfNodePtr> new_params = parameters_; | |||||
| // append parameter | |||||
| new_params.push_back(p); | |||||
| if (manager_.lock()) { | if (manager_.lock()) { | ||||
| manager_.lock()->SetParameters(shared_from_base<FuncGraph>(), new_params); | |||||
| manager_.lock()->AddParameter(shared_from_base<FuncGraph>(), p); | |||||
| } else { | } else { | ||||
| parameters_.push_back(p); | parameters_.push_back(p); | ||||
| } | } | ||||
| @@ -158,6 +158,7 @@ class FuncGraph : public FuncGraphBase { | |||||
| const std::vector<AnfNodePtr> ¶meters() const { return parameters_; } | const std::vector<AnfNodePtr> ¶meters() const { return parameters_; } | ||||
| virtual ParameterPtr add_parameter(); | virtual ParameterPtr add_parameter(); | ||||
| void add_parameter(const ParameterPtr &p); | void add_parameter(const ParameterPtr &p); | ||||
| void append_parameter(const ParameterPtr &p) { parameters_.push_back(p); } | |||||
| void set_parameters(const std::vector<AnfNodePtr> ¶ms) { parameters_ = params; } | void set_parameters(const std::vector<AnfNodePtr> ¶ms) { parameters_ = params; } | ||||
| // add a weight parameter with specific name | // add a weight parameter with specific name | ||||
| ParameterPtr AddWeightParameter(const std::string &name); | ParameterPtr AddWeightParameter(const std::string &name); | ||||
| @@ -420,6 +420,12 @@ void FuncGraphManager::SetParameters(const FuncGraphPtr &fg, const std::vector<A | |||||
| tr.Commit(); | tr.Commit(); | ||||
| } | } | ||||
| void FuncGraphManager::AddParameter(const FuncGraphPtr &fg, const AnfNodePtr ¶meter) { | |||||
| auto tr = Transact(); | |||||
| tr.AddParameter(fg, parameter); | |||||
| tr.Commit(); | |||||
| } | |||||
| bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { | bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { | ||||
| auto tr = Transact(); | auto tr = Transact(); | ||||
| bool success = tr.Replace(old_node, new_node); | bool success = tr.Replace(old_node, new_node); | ||||
| @@ -532,25 +538,37 @@ void FuncGraphManager::ParseChanges(const std::vector<Change> &changes, EdgeTupl | |||||
| for (auto &iter : changes) { | for (auto &iter : changes) { | ||||
| auto operation = iter.op; | auto operation = iter.op; | ||||
| auto args = iter.args; | auto args = iter.args; | ||||
| if (operation == Change::kTxSetEdge) { | |||||
| auto edge = args.cast<ArgsOfSetEdge>(); | |||||
| auto old_node = edge.root_node->input(edge.index); | |||||
| (*rm_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, old_node))] += 1; | |||||
| (*add_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, edge.new_node))] += 1; | |||||
| (*rms)[old_node] += 1; | |||||
| (*adds)[edge.new_node] += 1; | |||||
| edge.root_node->set_input(edge.index, edge.new_node); | |||||
| } else if (operation == Change::kTxSetParams) { | |||||
| auto param = args.cast<ArgsOfSetParams>(); | |||||
| MS_EXCEPTION_IF_NULL(param.func_graph); | |||||
| auto old_parameters = param.func_graph->parameters(); | |||||
| for (auto &p : param.params) { | |||||
| (*adds)[p] += 1; | |||||
| } | |||||
| for (auto &p : old_parameters) { | |||||
| (*rms)[p] += 1; | |||||
| } | |||||
| param.func_graph->set_parameters(param.params); | |||||
| switch (operation) { | |||||
| case Change::kTxSetEdge: { | |||||
| auto edge = args.cast<ArgsOfSetEdge>(); | |||||
| auto old_node = edge.root_node->input(edge.index); | |||||
| (*rm_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, old_node))] += 1; | |||||
| (*add_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, edge.new_node))] += 1; | |||||
| (*rms)[old_node] += 1; | |||||
| (*adds)[edge.new_node] += 1; | |||||
| edge.root_node->set_input(edge.index, edge.new_node); | |||||
| } break; | |||||
| case Change::kTxSetParams: { | |||||
| auto param = args.cast<ArgsOfSetParams>(); | |||||
| MS_EXCEPTION_IF_NULL(param.func_graph); | |||||
| auto old_parameters = param.func_graph->parameters(); | |||||
| for (auto &p : param.params) { | |||||
| (*adds)[p] += 1; | |||||
| } | |||||
| for (auto &p : old_parameters) { | |||||
| (*rms)[p] += 1; | |||||
| } | |||||
| param.func_graph->set_parameters(param.params); | |||||
| } break; | |||||
| case Change::kTxAddParam: { | |||||
| auto param = args.cast<ArgsOfAddParam>(); | |||||
| MS_EXCEPTION_IF_NULL(param.func_graph); | |||||
| (*adds)[param.param] += 1; | |||||
| auto param_node = param.param->cast<ParameterPtr>(); | |||||
| param.func_graph->append_parameter(param_node); | |||||
| } break; | |||||
| default: | |||||
| break; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -599,6 +617,10 @@ void FuncGraphTransaction::SetParameters(FuncGraphPtr fg, const std::vector<AnfN | |||||
| changes_.emplace_back(Change::kTxSetParams, ArgsOfSetParams{fg, params}); | changes_.emplace_back(Change::kTxSetParams, ArgsOfSetParams{fg, params}); | ||||
| } | } | ||||
| void FuncGraphTransaction::AddParameter(FuncGraphPtr fg, const AnfNodePtr ¶m) { | |||||
| changes_.emplace_back(Change::kTxAddParam, ArgsOfAddParam{fg, param}); | |||||
| } | |||||
| bool FuncGraphTransaction::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { | bool FuncGraphTransaction::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { | ||||
| MS_EXCEPTION_IF_NULL(old_node); | MS_EXCEPTION_IF_NULL(old_node); | ||||
| MS_EXCEPTION_IF_NULL(new_node); | MS_EXCEPTION_IF_NULL(new_node); | ||||
| @@ -310,6 +310,7 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> { | |||||
| void KeepRoots(const std::vector<FuncGraphPtr> &roots = {}); | void KeepRoots(const std::vector<FuncGraphPtr> &roots = {}); | ||||
| void RemoveRoots(); | void RemoveRoots(); | ||||
| void SetParameters(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> ¶meters); | void SetParameters(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> ¶meters); | ||||
| void AddParameter(const FuncGraphPtr &fg, const AnfNodePtr ¶meter); | |||||
| void MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users = false); | void MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users = false); | ||||
| bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node); | bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node); | ||||
| void SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value); | void SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value); | ||||
| @@ -400,6 +401,7 @@ class FuncGraphTransaction { | |||||
| // set parameters of a func graph | // set parameters of a func graph | ||||
| void SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr> ¶ms); | void SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr> ¶ms); | ||||
| void AddParameter(FuncGraphPtr fg, const AnfNodePtr ¶m); | |||||
| // replace old_node with new_node | // replace old_node with new_node | ||||
| bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node); | bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node); | ||||
| @@ -427,6 +429,18 @@ struct ArgsOfSetParams { | |||||
| } | } | ||||
| }; | }; | ||||
| // args for add param | |||||
| struct ArgsOfAddParam { | |||||
| FuncGraphPtr func_graph; | |||||
| AnfNodePtr param; | |||||
| bool operator==(const ArgsOfAddParam &other) const { return &other == this; } | |||||
| friend std::ostream &operator<<(std::ostream &os, const ArgsOfAddParam &) { | |||||
| os << "[ArgsOfAddParam]"; | |||||
| return os; | |||||
| } | |||||
| }; | |||||
| // args for set edge | // args for set edge | ||||
| struct ArgsOfSetEdge { | struct ArgsOfSetEdge { | ||||
| CNodePtr root_node; | CNodePtr root_node; | ||||
| @@ -441,7 +455,7 @@ struct ArgsOfSetEdge { | |||||
| }; | }; | ||||
| struct Change { | struct Change { | ||||
| enum OpName { kTxSetParams, kTxSetEdge }; | |||||
| enum OpName { kTxSetParams, kTxSetEdge, kTxAddParam }; | |||||
| OpName op; | OpName op; | ||||
| Any args; | Any args; | ||||
| Change(OpName name, const Any ¶) : op(name), args(para) {} | Change(OpName name, const Any ¶) : op(name), args(para) {} | ||||