|
|
|
@@ -420,6 +420,12 @@ void FuncGraphManager::SetParameters(const FuncGraphPtr &fg, const std::vector<A |
|
|
|
tr.Commit(); |
|
|
|
} |
|
|
|
|
|
|
|
void FuncGraphManager::AddParameter(const FuncGraphPtr &fg, const AnfNodePtr ¶meter) { |
|
|
|
auto tr = Transact(); |
|
|
|
tr.AddParameter(fg, parameter); |
|
|
|
tr.Commit(); |
|
|
|
} |
|
|
|
|
|
|
|
bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { |
|
|
|
auto tr = Transact(); |
|
|
|
bool success = tr.Replace(old_node, new_node); |
|
|
|
@@ -532,25 +538,37 @@ void FuncGraphManager::ParseChanges(const std::vector<Change> &changes, EdgeTupl |
|
|
|
for (auto &iter : changes) { |
|
|
|
auto operation = iter.op; |
|
|
|
auto args = iter.args; |
|
|
|
if (operation == Change::kTxSetEdge) { |
|
|
|
auto edge = args.cast<ArgsOfSetEdge>(); |
|
|
|
auto old_node = edge.root_node->input(edge.index); |
|
|
|
(*rm_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, old_node))] += 1; |
|
|
|
(*add_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, edge.new_node))] += 1; |
|
|
|
(*rms)[old_node] += 1; |
|
|
|
(*adds)[edge.new_node] += 1; |
|
|
|
edge.root_node->set_input(edge.index, edge.new_node); |
|
|
|
} else if (operation == Change::kTxSetParams) { |
|
|
|
auto param = args.cast<ArgsOfSetParams>(); |
|
|
|
MS_EXCEPTION_IF_NULL(param.func_graph); |
|
|
|
auto old_parameters = param.func_graph->parameters(); |
|
|
|
for (auto &p : param.params) { |
|
|
|
(*adds)[p] += 1; |
|
|
|
} |
|
|
|
for (auto &p : old_parameters) { |
|
|
|
(*rms)[p] += 1; |
|
|
|
} |
|
|
|
param.func_graph->set_parameters(param.params); |
|
|
|
switch (operation) { |
|
|
|
case Change::kTxSetEdge: { |
|
|
|
auto edge = args.cast<ArgsOfSetEdge>(); |
|
|
|
auto old_node = edge.root_node->input(edge.index); |
|
|
|
(*rm_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, old_node))] += 1; |
|
|
|
(*add_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, edge.new_node))] += 1; |
|
|
|
(*rms)[old_node] += 1; |
|
|
|
(*adds)[edge.new_node] += 1; |
|
|
|
edge.root_node->set_input(edge.index, edge.new_node); |
|
|
|
} break; |
|
|
|
case Change::kTxSetParams: { |
|
|
|
auto param = args.cast<ArgsOfSetParams>(); |
|
|
|
MS_EXCEPTION_IF_NULL(param.func_graph); |
|
|
|
auto old_parameters = param.func_graph->parameters(); |
|
|
|
for (auto &p : param.params) { |
|
|
|
(*adds)[p] += 1; |
|
|
|
} |
|
|
|
for (auto &p : old_parameters) { |
|
|
|
(*rms)[p] += 1; |
|
|
|
} |
|
|
|
param.func_graph->set_parameters(param.params); |
|
|
|
} break; |
|
|
|
case Change::kTxAddParam: { |
|
|
|
auto param = args.cast<ArgsOfAddParam>(); |
|
|
|
MS_EXCEPTION_IF_NULL(param.func_graph); |
|
|
|
(*adds)[param.param] += 1; |
|
|
|
auto param_node = param.param->cast<ParameterPtr>(); |
|
|
|
param.func_graph->append_parameter(param_node); |
|
|
|
} break; |
|
|
|
default: |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -599,6 +617,10 @@ void FuncGraphTransaction::SetParameters(FuncGraphPtr fg, const std::vector<AnfN |
|
|
|
changes_.emplace_back(Change::kTxSetParams, ArgsOfSetParams{fg, params}); |
|
|
|
} |
|
|
|
|
|
|
|
void FuncGraphTransaction::AddParameter(FuncGraphPtr fg, const AnfNodePtr ¶m) { |
|
|
|
changes_.emplace_back(Change::kTxAddParam, ArgsOfAddParam{fg, param}); |
|
|
|
} |
|
|
|
|
|
|
|
bool FuncGraphTransaction::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { |
|
|
|
MS_EXCEPTION_IF_NULL(old_node); |
|
|
|
MS_EXCEPTION_IF_NULL(new_node); |
|
|
|
|