|
|
|
@@ -946,12 +946,13 @@ class ExecuteOrderGenerator { |
|
|
|
graph_->set_execution_order(std::move(execution_order)); |
|
|
|
} |
|
|
|
|
|
|
|
std::set<CNodePtr> GetAllNodes() { |
|
|
|
auto &all_graphs = context_.visited_graphs(); |
|
|
|
std::set<CNodePtr> GetAllNodes(std::set<CNodePtr> *search_list) { |
|
|
|
const auto &all_graphs = context_.visited_graphs(); |
|
|
|
std::set<CNodePtr> all_nodes; |
|
|
|
for (auto &graph : all_graphs) { |
|
|
|
auto out = graph->get_return(); |
|
|
|
MS_EXCEPTION_IF_NULL(out); |
|
|
|
search_list->insert(out->cast<CNodePtr>()); |
|
|
|
auto nodes = TopoSort(out); |
|
|
|
for (auto &node : nodes) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
@@ -971,26 +972,34 @@ class ExecuteOrderGenerator { |
|
|
|
return input; |
|
|
|
} |
|
|
|
|
|
|
|
// Erase redundant parameters and assign nodes. |
|
|
|
void EraseParameter() { |
|
|
|
// Copy out execution order list. |
|
|
|
auto exec_order = graph_->execution_order(); |
|
|
|
std::set<CNodePtr> all_nodes = GetAllNodes(); |
|
|
|
|
|
|
|
// Remove assigns that target and source are same. |
|
|
|
for (auto iter = exec_order.begin(); iter != exec_order.end();) { |
|
|
|
void RemoveSameInputsAssigns(std::vector<CNodePtr> *exec_order) { |
|
|
|
for (auto iter = exec_order->begin(); iter != exec_order->end();) { |
|
|
|
auto &node = *iter; |
|
|
|
auto &inputs = node->inputs(); |
|
|
|
if (IsPrimitiveCNode(node, prim::kPrimAssign) && |
|
|
|
(inputs.at(kAssignTargetIndex) == GetRealNode(inputs.at(kAssignSourceIndex)))) { |
|
|
|
iter = exec_order.erase(iter); |
|
|
|
iter = exec_order->erase(iter); |
|
|
|
} else { |
|
|
|
++iter; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Erase redundant parameters and assign nodes. |
|
|
|
void EraseParameter() { |
|
|
|
// Copy out execution order list. |
|
|
|
auto exec_order = graph_->execution_order(); |
|
|
|
std::set<CNodePtr> search_list(exec_order.begin(), exec_order.end()); |
|
|
|
|
|
|
|
// Remove assigns that target and source are same. |
|
|
|
RemoveSameInputsAssigns(&exec_order); |
|
|
|
|
|
|
|
// Get all nodes and all graphs |
|
|
|
std::set<CNodePtr> all_nodes = GetAllNodes(&search_list); |
|
|
|
auto &all_graphs = context_.visited_graphs(); |
|
|
|
|
|
|
|
// Count parameter write times by check all assign nodes. |
|
|
|
auto param_write_times = CountParameterAssigns(exec_order); |
|
|
|
auto param_write_times = CountParameterAssigns(search_list); |
|
|
|
|
|
|
|
// Erase redundant assigns. |
|
|
|
for (auto iter = exec_order.begin(); iter != exec_order.end();) { |
|
|
|
@@ -1008,6 +1017,14 @@ class ExecuteOrderGenerator { |
|
|
|
MS_EXCEPTION_IF_NULL(kg); |
|
|
|
kg->ReplaceNode(NOT_NULL(target), NOT_NULL(source)); |
|
|
|
|
|
|
|
// replace parameter in graph input |
|
|
|
for (auto &g : all_graphs) { |
|
|
|
auto child_graph_inputs = g->MutableInputs(); |
|
|
|
std::replace(child_graph_inputs->begin(), child_graph_inputs->end(), target, source); |
|
|
|
MS_LOG(DEBUG) << "Replace parameter " << target->DebugString() << " by " << source->DebugString() |
|
|
|
<< " in graph " << g->graph_id() << " inputs"; |
|
|
|
} |
|
|
|
|
|
|
|
// replace parameter in node |
|
|
|
for (auto &iter_node : all_nodes) { |
|
|
|
for (size_t i = 0; i < iter_node->size(); ++i) { |
|
|
|
@@ -1018,15 +1035,6 @@ class ExecuteOrderGenerator { |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// replace parameter in graph input |
|
|
|
auto &all_graphs = context_.visited_graphs(); |
|
|
|
for (auto &g : all_graphs) { |
|
|
|
auto child_graph_inputs = g->MutableInputs(); |
|
|
|
std::replace(child_graph_inputs->begin(), child_graph_inputs->end(), target, source); |
|
|
|
MS_LOG(DEBUG) << "Replace parameter " << target->DebugString() << " by " << source->DebugString() |
|
|
|
<< " in graph " << g->graph_id() << " inputs"; |
|
|
|
} |
|
|
|
iter = exec_order.erase(iter); |
|
|
|
continue; |
|
|
|
} |
|
|
|
@@ -1039,7 +1047,26 @@ class ExecuteOrderGenerator { |
|
|
|
} |
|
|
|
|
|
|
|
// Count parameter write times by check all assign nodes. |
|
|
|
std::map<AnfNodePtr, int> CountParameterAssigns(const std::vector<CNodePtr> &all_nodes) { |
|
|
|
std::map<AnfNodePtr, int> CountParameterAssigns(const std::set<CNodePtr> &search_list) { |
|
|
|
auto ref_map = graph_->GetRefMap(); |
|
|
|
std::multimap<AnfNodePtr, std::tuple<size_t, AnfNodePtr, size_t>> ref_multimap; |
|
|
|
std::set<AnfNodePtr> root_inputs(graph_->inputs().begin(), graph_->inputs().end()); |
|
|
|
std::transform(ref_map.begin(), ref_map.end(), std::inserter(ref_multimap, ref_multimap.end()), |
|
|
|
[](const std::pair<std::pair<AnfNodePtr, size_t>, std::pair<AnfNodePtr, size_t>> &p) |
|
|
|
-> std::pair<AnfNodePtr, std::tuple<size_t, AnfNodePtr, size_t>> { |
|
|
|
return {p.first.first, {p.first.second, p.second.first, p.second.second}}; |
|
|
|
}); |
|
|
|
auto validate_ref_parameter = [](AnfNodePtr node) -> AnfNodePtr { |
|
|
|
if (node->isa<CNode>() && AnfAlgo::CheckPrimitiveType(node, prim::KPrimTransData)) { |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
auto first_input = cnode->input(kFirstDataInputIndex); |
|
|
|
MS_EXCEPTION_IF_NULL(first_input); |
|
|
|
return first_input; |
|
|
|
} |
|
|
|
return node; |
|
|
|
}; |
|
|
|
|
|
|
|
// Find all graph input parameters. |
|
|
|
std::map<AnfNodePtr, int> param_write_times; |
|
|
|
const auto &all_graphs = context_.visited_graphs(); |
|
|
|
@@ -1051,16 +1078,24 @@ class ExecuteOrderGenerator { |
|
|
|
} |
|
|
|
} |
|
|
|
// Search all nodes for parameter write assigns. |
|
|
|
for (auto &node : all_nodes) { |
|
|
|
if (!IsPrimitiveCNode(node, prim::kPrimAssign)) { |
|
|
|
continue; |
|
|
|
for (auto &node : search_list) { |
|
|
|
std::set<AnfNodePtr> refed_parameters; |
|
|
|
for (auto [iter, end] = ref_multimap.equal_range(node); iter != end; ++iter) { |
|
|
|
refed_parameters.insert(validate_ref_parameter(std::get<1>(iter->second))); |
|
|
|
} |
|
|
|
auto &target = node->inputs().at(kAssignTargetIndex); |
|
|
|
MS_EXCEPTION_IF_NULL(target); |
|
|
|
auto iter = param_write_times.find(target); |
|
|
|
if (iter != param_write_times.end()) { |
|
|
|
// Found a parameter writer, count it. |
|
|
|
++(iter->second); |
|
|
|
for (auto &in : node->inputs()) { |
|
|
|
auto visit_node = AnfAlgo::VisitKernelWithReturnType(in, 0).first; |
|
|
|
visit_node = validate_ref_parameter(visit_node); |
|
|
|
if (!visit_node->isa<Parameter>() || root_inputs.find(visit_node) != root_inputs.end()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (refed_parameters.find(visit_node) != refed_parameters.end()) { |
|
|
|
auto iter = param_write_times.find(visit_node); |
|
|
|
if (iter != param_write_times.end()) { |
|
|
|
// Found a parameter writer, count it. |
|
|
|
++(iter->second); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return param_write_times; |
|
|
|
|