|
|
|
@@ -157,7 +157,7 @@ std::set<FuncGraphPtr> GetFuncGraphbyCallNode(const AnfNodePtr &node, std::stack |
|
|
|
for (size_t i = kMakeTupleInputStartPos; i < tuple_inputs.size(); ++i) { |
|
|
|
MS_EXCEPTION_IF_NULL(tuple_inputs[i]); |
|
|
|
std::stack<size_t> tmp_output_indexs = *output_indexs; |
|
|
|
func_graphs.emplace(GetFuncGraphFromPartial(tuple_inputs[i], &tmp_output_indexs)); |
|
|
|
(void)func_graphs.emplace(GetFuncGraphFromPartial(tuple_inputs[i], &tmp_output_indexs)); |
|
|
|
} |
|
|
|
} else if (tuple_node->isa<Parameter>()) { |
|
|
|
const auto &abstract = tuple_node->abstract(); |
|
|
|
@@ -223,7 +223,7 @@ KernelWithIndex FetchRealInputNode(const KernelWithIndex &node_with_index) { |
|
|
|
} |
|
|
|
|
|
|
|
// Fetch all the output index in the sub-abstract of abstract. |
|
|
|
std::set<size_t> FetchRealIndexByAbstract(const AbstractBasePtr &abstract, std::vector<size_t> *indexes) { |
|
|
|
std::set<size_t> FetchRealIndexByAbstract(const AbstractBasePtr &abstract, std::vector<size_t> *const indexes) { |
|
|
|
MS_EXCEPTION_IF_NULL(abstract); |
|
|
|
MS_EXCEPTION_IF_NULL(indexes); |
|
|
|
AbstractBasePtr dst_abstract = abstract; |
|
|
|
@@ -232,7 +232,7 @@ std::set<size_t> FetchRealIndexByAbstract(const AbstractBasePtr &abstract, std:: |
|
|
|
if (indexes->empty()) { |
|
|
|
size_t output_num = AnfAlgo::GetOutputNumByAbstract(abstract); |
|
|
|
for (size_t i = 0; i < output_num; ++i) { |
|
|
|
output_indexs.emplace(i); |
|
|
|
(void)output_indexs.emplace(i); |
|
|
|
} |
|
|
|
return output_indexs; |
|
|
|
} |
|
|
|
@@ -286,13 +286,13 @@ std::set<size_t> FetchRealIndexByAbstract(const AbstractBasePtr &abstract, std:: |
|
|
|
// Fetch real output index. |
|
|
|
auto tmp_indexs = FetchRealIndexByAbstract(dst_abstract, indexes); |
|
|
|
for (auto tmp_index : tmp_indexs) { |
|
|
|
output_indexs.emplace(tmp_index + pre_abstract_num); |
|
|
|
(void)output_indexs.emplace(tmp_index + pre_abstract_num); |
|
|
|
} |
|
|
|
return output_indexs; |
|
|
|
} |
|
|
|
|
|
|
|
// Get all the real parameters corresponding to node. |
|
|
|
void FetchRealParameterByNode(const KernelWithIndex &node, std::set<KernelWithIndex> *real_parameters, |
|
|
|
void FetchRealParameterByNode(const KernelWithIndex &node, std::set<KernelWithIndex> *const real_parameters, |
|
|
|
std::set<KernelWithIndex> *invalid_call_nodes, |
|
|
|
const mindspore::HashMap<AnfNodePtr, std::set<FuncGraphPtr>> &call_node_to_func_graphs) { |
|
|
|
auto node_with_index = node; |
|
|
|
@@ -301,13 +301,13 @@ void FetchRealParameterByNode(const KernelWithIndex &node, std::set<KernelWithIn |
|
|
|
} |
|
|
|
if (node_with_index.first->isa<ValueNode>() || node_with_index.first->isa<Parameter>()) { |
|
|
|
// If node is a valuenode or parameter, the real parameter is itself. |
|
|
|
real_parameters->emplace(node_with_index); |
|
|
|
(void)real_parameters->emplace(node_with_index); |
|
|
|
} else if (AnfAlgo::IsCallNode(node_with_index.first)) { |
|
|
|
// If node is a call node, the real parameters are the outputs of funcgraph the node called. |
|
|
|
if (invalid_call_nodes->find(node_with_index) != invalid_call_nodes->end()) { |
|
|
|
return; |
|
|
|
} |
|
|
|
invalid_call_nodes->emplace(node_with_index); |
|
|
|
(void)invalid_call_nodes->emplace(node_with_index); |
|
|
|
const auto &iter = call_node_to_func_graphs.find(node_with_index.first); |
|
|
|
if (iter == call_node_to_func_graphs.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid call node:" << node_with_index.first->DebugString(); |
|
|
|
@@ -347,18 +347,18 @@ void FetchRealParameterByNode(const KernelWithIndex &node, std::set<KernelWithIn |
|
|
|
} |
|
|
|
} else { |
|
|
|
// If node is a kernel, the real parameter is itself. |
|
|
|
real_parameters->emplace(node_with_index); |
|
|
|
(void)real_parameters->emplace(node_with_index); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Fetch all the weight parameters related to node. It runs like this: |
|
|
|
// if we have a map like {{a, {b, c}}, {b, {d, e}}}, final we will get {{a, {b, c, d, e}}, {b, {c, d}}}. |
|
|
|
void FetchWeightbyHostParameter(const AnfNodePtr &node, std::set<AnfNodePtr> *dest_nodes, |
|
|
|
void FetchWeightbyHostParameter(const AnfNodePtr &node, std::set<AnfNodePtr> *const dest_nodes, |
|
|
|
const RealToFormalParameter &front_to_front_weight) { |
|
|
|
if (dest_nodes->find(node) != dest_nodes->end()) { |
|
|
|
return; |
|
|
|
} |
|
|
|
dest_nodes->emplace(node); |
|
|
|
(void)dest_nodes->emplace(node); |
|
|
|
auto iter = front_to_front_weight.find({node, 0}); |
|
|
|
if (iter == front_to_front_weight.end()) { |
|
|
|
return; |
|
|
|
@@ -385,15 +385,15 @@ AnfNodePtr FetchSourceNodeByAutoMonad(const AnfNodePtr &node) { |
|
|
|
} |
|
|
|
|
|
|
|
// Topologically sort all funcgraphs according to the function call relationship. |
|
|
|
std::vector<FuncGraphPtr> TopoSortForFuncGraph(const FuncGraphPtr &root, FuncGraphCallRelation *edges) { |
|
|
|
std::vector<FuncGraphPtr> TopoSortForFuncGraph(const FuncGraphPtr &root, FuncGraphCallRelation *const edges) { |
|
|
|
MS_EXCEPTION_IF_NULL(root->manager()); |
|
|
|
std::set<FuncGraphPtr> nodes; |
|
|
|
nodes.emplace(root); |
|
|
|
(void)nodes.emplace(root); |
|
|
|
|
|
|
|
FuncGraphSet subs = root->manager()->func_graphs(); |
|
|
|
for (auto sub : subs) { |
|
|
|
if (sub != root && root != nullptr) { |
|
|
|
nodes.emplace(sub); |
|
|
|
(void)nodes.emplace(sub); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -408,7 +408,7 @@ std::vector<FuncGraphPtr> TopoSortForFuncGraph(const FuncGraphPtr &root, FuncGra |
|
|
|
while (!que.empty()) { |
|
|
|
const auto node = que.front(); |
|
|
|
que.pop(); |
|
|
|
result.emplace_back(node); |
|
|
|
(void)result.emplace_back(node); |
|
|
|
for (auto iter = edges->begin(); iter != edges->end();) { |
|
|
|
auto &sub_edges = iter->second; |
|
|
|
for (auto sub_iter = sub_edges.begin(); sub_iter != sub_edges.end();) { |
|
|
|
@@ -444,7 +444,7 @@ std::vector<KernelWithIndex> FetchAllOutputWithIndex(const AnfNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(value_tuple); |
|
|
|
const auto tuple_value = value_tuple->value(); |
|
|
|
for (size_t i = 0; i < tuple_value.size(); ++i) { |
|
|
|
result.emplace_back(node, i); |
|
|
|
(void)result.emplace_back(node, i); |
|
|
|
} |
|
|
|
return result; |
|
|
|
} |
|
|
|
@@ -460,9 +460,9 @@ std::vector<KernelWithIndex> FetchAllOutputWithIndex(const AnfNodePtr &node) { |
|
|
|
} else if (AnfAlgo::CheckPrimitiveType(node_with_index.first, prim::kPrimSwitch) || |
|
|
|
AnfAlgo::CheckPrimitiveType(node_with_index.first, prim::kPrimSwitchLayer)) { |
|
|
|
} else if (AnfAlgo::IsCallNode(node_with_index.first)) { |
|
|
|
result.emplace_back(node_with_index.first, i); |
|
|
|
(void)result.emplace_back(node_with_index.first, i); |
|
|
|
} else { |
|
|
|
result.emplace_back(node_with_index); |
|
|
|
(void)result.emplace_back(node_with_index); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -539,13 +539,13 @@ void CreateDeviceTensorForFrontNode(const KernelWithIndex &front_node_with_index |
|
|
|
} |
|
|
|
|
|
|
|
// Fetch all funcgraph by a seed graph, if a calls b, b calls c, and c calls a, return a set of a, b, c. |
|
|
|
void FetchAllExecutionFunction(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *checked_funcgraphs, |
|
|
|
void FetchAllExecutionFunction(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *const checked_funcgraphs, |
|
|
|
const std::unordered_map<FuncGraphPtr, std::set<FuncGraphPtr>> &call_relation) { |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
if (checked_funcgraphs->find(func_graph) != checked_funcgraphs->end()) { |
|
|
|
return; |
|
|
|
} |
|
|
|
checked_funcgraphs->emplace(func_graph); |
|
|
|
(void)checked_funcgraphs->emplace(func_graph); |
|
|
|
auto iter = call_relation.find(func_graph); |
|
|
|
if (iter == call_relation.end()) { |
|
|
|
return; |
|
|
|
@@ -626,7 +626,7 @@ std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) { |
|
|
|
const auto &abstract = real_node->abstract(); |
|
|
|
if (abstract == nullptr) { |
|
|
|
MS_LOG(WARNING) << "Empty abstract for node:" << real_node->DebugString(); |
|
|
|
results.emplace_back(AnfAlgo::VisitKernelWithReturnType(real_node, real_index)); |
|
|
|
(void)results.emplace_back(AnfAlgo::VisitKernelWithReturnType(real_node, real_index)); |
|
|
|
return results; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -651,7 +651,7 @@ std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) { |
|
|
|
size_t output_num = AnfAlgo::GetOutputNumByAbstract(abstract); |
|
|
|
if (!abstract->isa<abstract::AbstractTuple>()) { |
|
|
|
for (size_t i = 0; i < output_num; ++i) { |
|
|
|
results.emplace_back(real_node, i); |
|
|
|
(void)results.emplace_back(real_node, i); |
|
|
|
} |
|
|
|
return results; |
|
|
|
} |
|
|
|
@@ -663,7 +663,7 @@ std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) { |
|
|
|
for (const auto &sub_abstract : sub_abstracts) { |
|
|
|
MS_EXCEPTION_IF_NULL(sub_abstract); |
|
|
|
if (!sub_abstract->isa<abstract::AbstractMonad>()) { |
|
|
|
results.emplace_back(real_node, index++); |
|
|
|
(void)results.emplace_back(real_node, index++); |
|
|
|
} |
|
|
|
} |
|
|
|
return results; |
|
|
|
@@ -672,7 +672,7 @@ std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) { |
|
|
|
// Add formal parameter and real parameter into realationship map. |
|
|
|
void AddFormalToRealParameter(const AnfNodePtr &formal_parameter, const AnfNodePtr &real_parameter, |
|
|
|
const CallNodeToFuncGraph &call_node_to_func_graphs, |
|
|
|
FormalToRealParameter *formal_to_real_parameters) { |
|
|
|
FormalToRealParameter *const formal_to_real_parameters) { |
|
|
|
MS_EXCEPTION_IF_NULL(formal_parameter); |
|
|
|
auto abstract = formal_parameter->abstract(); |
|
|
|
MS_EXCEPTION_IF_NULL(abstract); |
|
|
|
@@ -778,7 +778,7 @@ bool IsCsrNode(const AnfNodePtr &node) { |
|
|
|
AnfAlgo::CheckPrimitiveType(node, prim::kPrimCSRTensorGetDenseShape); |
|
|
|
} |
|
|
|
|
|
|
|
KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, KernelGraph *const graph) { |
|
|
|
KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, const KernelGraph *const graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
const auto &front_node = graph->GetFrontAnfByBackendAnf(backend_node); |
|
|
|
if (front_node != nullptr) { |
|
|
|
@@ -819,9 +819,9 @@ std::vector<KernelWithIndex> FetchInputNodeByCNode(const AnfNodePtr &node) { |
|
|
|
if (inputs.size() != kSwitchInputNum) { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid switch node:" << node->DebugString(); |
|
|
|
} |
|
|
|
results.emplace_back(AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchCondPos], 0)); |
|
|
|
results.emplace_back(AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchFalseBranchPos], 0)); |
|
|
|
results.emplace_back(AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchTrueBranchPos], 0)); |
|
|
|
(void)results.emplace_back(AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchCondPos], 0)); |
|
|
|
(void)results.emplace_back(AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchFalseBranchPos], 0)); |
|
|
|
(void)results.emplace_back(AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchTrueBranchPos], 0)); |
|
|
|
return results; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -1081,7 +1081,7 @@ void ControlNodeParser::ParseDeviceContextForFuncGraph(const std::vector<AnfNode |
|
|
|
const auto &abstract = parameter->abstract(); |
|
|
|
MS_EXCEPTION_IF_NULL(abstract); |
|
|
|
for (size_t i = 0; i < AnfAlgo::GetOutputNumByAbstract(abstract); ++i) { |
|
|
|
front_parameters.emplace_back(parameter, i); |
|
|
|
(void)front_parameters.emplace_back(parameter, i); |
|
|
|
} |
|
|
|
} |
|
|
|
std::vector<const DeviceContext *> parameter_device_contexts(front_parameters.size(), default_context); |
|
|
|
@@ -1162,7 +1162,7 @@ void ControlNodeParser::ParseDeviceContextForPartialNode(const std::vector<AnfNo |
|
|
|
std::vector<const DeviceContext *> device_contexts; |
|
|
|
// In partial node, the first input is always a partial, maybe a funcgraph or a partial node, so we need |
|
|
|
// to insert an empty device context for it. |
|
|
|
device_contexts.emplace_back(nullptr); |
|
|
|
(void)device_contexts.emplace_back(nullptr); |
|
|
|
for (size_t i = 0; i < inputs.size() - kPartialInputStartPos; ++i) { |
|
|
|
if (i >= iter->second.size()) { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid device context index:" << i << " for funcgraph:" << func_graph->ToString() |
|
|
|
@@ -1170,7 +1170,7 @@ void ControlNodeParser::ParseDeviceContextForPartialNode(const std::vector<AnfNo |
|
|
|
<< " for partial node:" << cnode->DebugString(); |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(iter->second[i]); |
|
|
|
device_contexts.emplace_back(iter->second[i]); |
|
|
|
(void)device_contexts.emplace_back(iter->second[i]); |
|
|
|
} |
|
|
|
control_node_to_device_contexts_[control_node] = device_contexts; |
|
|
|
} |
|
|
|
@@ -1198,7 +1198,7 @@ void ControlNodeParser::ParseDeviceContextForCallNode(const std::vector<AnfNodeP |
|
|
|
std::vector<const DeviceContext *> device_contexts; |
|
|
|
// In call node, the first input is always a partial, maybe a funcgraph or a partial node, so we need |
|
|
|
// to insert an empty device context for it. |
|
|
|
device_contexts.emplace_back(nullptr); |
|
|
|
(void)device_contexts.emplace_back(nullptr); |
|
|
|
const auto &cnode = control_node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
const auto &inputs = cnode->inputs(); |
|
|
|
@@ -1216,13 +1216,8 @@ void ControlNodeParser::ParseDeviceContextForCallNode(const std::vector<AnfNodeP |
|
|
|
|
|
|
|
// Fetch the device contexts for the real parameters on the call node. |
|
|
|
for (size_t i = iter->second.size() - call_input_num; i < iter->second.size(); ++i) { |
|
|
|
if (i >= iter->second.size()) { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid device context index:" << i << " for funcgraph:" << func_graph->ToString() |
|
|
|
<< " device context size:" << iter->second.size() |
|
|
|
<< " for partial node:" << cnode->DebugString(); |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(iter->second[i]); |
|
|
|
device_contexts.emplace_back(iter->second[i]); |
|
|
|
(void)device_contexts.emplace_back(iter->second[i]); |
|
|
|
} |
|
|
|
control_node_to_device_contexts_[control_node] = device_contexts; |
|
|
|
} |
|
|
|
@@ -1237,7 +1232,7 @@ void ControlNodeParser::ParseDeviceContextForReturnNode(const DeviceContext *def |
|
|
|
MS_EXCEPTION_IF_NULL(call_node); |
|
|
|
const auto &func_graph = call_node->func_graph(); |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
func_graph_call_relation[func_graph].emplace_back(call_node_to_func_graphs.second); |
|
|
|
(void)func_graph_call_relation[func_graph].emplace_back(call_node_to_func_graphs.second); |
|
|
|
} |
|
|
|
|
|
|
|
// Topologically sort all funcgraphs according to the function call relationship. |
|
|
|
@@ -1267,11 +1262,11 @@ void ControlNodeParser::ParseDeviceContextForReturnNode(const DeviceContext *def |
|
|
|
MS_LOG(EXCEPTION) << "Cannot find device context for funcgraph:" << func_graph->ToString(); |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph_iter->second[iter - func_graph->parameters().begin()]); |
|
|
|
return_device_contexts.emplace_back(func_graph_iter->second[iter - func_graph->parameters().begin()]); |
|
|
|
(void)return_device_contexts.emplace_back(func_graph_iter->second[iter - func_graph->parameters().begin()]); |
|
|
|
} else if (output_node.first->isa<ValueNode>()) { |
|
|
|
// If the output is parameter, used the default context type. |
|
|
|
MS_EXCEPTION_IF_NULL(default_context); |
|
|
|
return_device_contexts.emplace_back(default_context); |
|
|
|
(void)return_device_contexts.emplace_back(default_context); |
|
|
|
} else if (AnfAlgo::IsCallNode(output_node.first)) { |
|
|
|
// If the output is call node, get the device context type by the output of funcgraph. |
|
|
|
const auto &func_graphs = call_node_to_func_graphs_[output_node.first]; |
|
|
|
@@ -1295,10 +1290,10 @@ void ControlNodeParser::ParseDeviceContextForReturnNode(const DeviceContext *def |
|
|
|
<< " index:" << output_node.second; |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(call_device_contexts[output_node.second]); |
|
|
|
return_device_contexts.emplace_back(call_device_contexts[output_node.second]); |
|
|
|
(void)return_device_contexts.emplace_back(call_device_contexts[output_node.second]); |
|
|
|
} else if (AnfAlgo::CheckPrimitiveType(output_node.first, prim::kPrimPartial) || |
|
|
|
AnfAlgo::CheckPrimitiveType(output_node.first, prim::kPrimSwitch)) { |
|
|
|
return_device_contexts.emplace_back(default_context); |
|
|
|
(void)return_device_contexts.emplace_back(default_context); |
|
|
|
} else if (output_node.first->isa<CNode>()) { |
|
|
|
// If the output is a cnode, get the device context type by the kernel. |
|
|
|
const auto &iter = front_to_backend_kernels_.find(output_node); |
|
|
|
@@ -1306,7 +1301,7 @@ void ControlNodeParser::ParseDeviceContextForReturnNode(const DeviceContext *def |
|
|
|
MS_LOG(EXCEPTION) << "Cannot find backend kernel for cnode:" << output_node.first->DebugString(); |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(iter->second.second); |
|
|
|
return_device_contexts.emplace_back(iter->second.second); |
|
|
|
(void)return_device_contexts.emplace_back(iter->second.second); |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid node for return:" << output_node.first->DebugString(); |
|
|
|
} |
|
|
|
@@ -1384,7 +1379,7 @@ FuncGraphPtr ControlNodeParser::FetchFuncGraphByKernelGraph(const KernelGraph *c |
|
|
|
} |
|
|
|
|
|
|
|
void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &control_nodes, |
|
|
|
DeviceContext *default_context) { |
|
|
|
const DeviceContext *const default_context) { |
|
|
|
MS_EXCEPTION_IF_NULL(default_context); |
|
|
|
|
|
|
|
for (const auto &formal_to_real_parameter : formal_to_real_parameters_) { |
|
|
|
@@ -1395,11 +1390,11 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr |
|
|
|
|
|
|
|
const auto &iter = front_to_backend_parameters_.find(real_parameter_with_index); |
|
|
|
if (iter != front_to_backend_parameters_.end() && (!iter->second.empty())) { |
|
|
|
front_value_nodes_.emplace(real_parameter_with_index, iter->second.begin()->second); |
|
|
|
(void)front_value_nodes_.emplace(real_parameter_with_index, iter->second.begin()->second); |
|
|
|
CreateDeviceTensorForValueNode(real_parameter_with_index, iter->second.begin()->first, |
|
|
|
iter->second.begin()->second); |
|
|
|
} else { |
|
|
|
front_value_nodes_.emplace(real_parameter_with_index, default_context); |
|
|
|
(void)front_value_nodes_.emplace(real_parameter_with_index, default_context); |
|
|
|
CreateDeviceTensorForFrontNode(real_parameter_with_index, default_context); |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -1414,7 +1409,7 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr |
|
|
|
const auto &backend_parameter = front_to_backend_parameters.second.begin()->first; |
|
|
|
const auto &device_context = front_to_backend_parameters.second.begin()->second; |
|
|
|
CreateDeviceTensorForValueNode(front_to_backend_parameters.first, backend_parameter, device_context); |
|
|
|
front_value_nodes_.emplace(front_to_backend_parameters.first, device_context); |
|
|
|
(void)front_value_nodes_.emplace(front_to_backend_parameters.first, device_context); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -1437,7 +1432,7 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr |
|
|
|
if (IsFrontValueNode(input_with_index) && |
|
|
|
front_value_nodes_.find({input_with_index, iter->second[i]}) == front_value_nodes_.end()) { |
|
|
|
CreateDeviceTensorForFrontNode(input_with_index, iter->second[i]); |
|
|
|
front_value_nodes_.emplace(input_with_index, iter->second[i]); |
|
|
|
(void)front_value_nodes_.emplace(input_with_index, iter->second[i]); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -1502,9 +1497,9 @@ void ControlNodeParser::ParseFormalToRealParameter(const std::vector<AnfNodePtr> |
|
|
|
std::set<KernelWithIndex> invalid_real_parameter{formal_parameter}; |
|
|
|
ParseAllRealParameterByFormalParameter(real_parameter, formal_to_real_parameters, &total_real_parameters, |
|
|
|
&invalid_real_parameter); |
|
|
|
real_to_formal_parameters_[real_parameter].emplace(formal_parameter); |
|
|
|
(void)real_to_formal_parameters_[real_parameter].emplace(formal_parameter); |
|
|
|
} else { |
|
|
|
total_real_parameters.emplace(real_parameter); |
|
|
|
(void)total_real_parameters.emplace(real_parameter); |
|
|
|
} |
|
|
|
} |
|
|
|
std::swap(formal_to_real_parameters_[formal_parameter], total_real_parameters); |
|
|
|
@@ -1513,12 +1508,12 @@ void ControlNodeParser::ParseFormalToRealParameter(const std::vector<AnfNodePtr> |
|
|
|
|
|
|
|
void ControlNodeParser::ParseAllRealParameterByFormalParameter(const KernelWithIndex &formal_parameter, |
|
|
|
const FormalToRealParameter &formal_to_real_parameters, |
|
|
|
std::set<KernelWithIndex> *total_real_parameters, |
|
|
|
std::set<KernelWithIndex> *const total_real_parameters, |
|
|
|
std::set<KernelWithIndex> *invalid_real_parameter) { |
|
|
|
if (invalid_real_parameter->find(formal_parameter) != invalid_real_parameter->end()) { |
|
|
|
return; |
|
|
|
} |
|
|
|
invalid_real_parameter->emplace(formal_parameter); |
|
|
|
(void)invalid_real_parameter->emplace(formal_parameter); |
|
|
|
|
|
|
|
// Get all the actual parameters corresponding to parameter recursively. |
|
|
|
const auto &dst_iter = formal_to_real_parameters_.find(formal_parameter); |
|
|
|
@@ -1538,7 +1533,7 @@ void ControlNodeParser::ParseAllRealParameterByFormalParameter(const KernelWithI |
|
|
|
const auto &real_parameters = src_iter->second; |
|
|
|
for (const auto &real_parameter : real_parameters) { |
|
|
|
MS_EXCEPTION_IF_NULL(real_parameter.first); |
|
|
|
total_real_parameters->emplace(real_parameter); |
|
|
|
(void)total_real_parameters->emplace(real_parameter); |
|
|
|
if (real_parameter.first->isa<Parameter>()) { |
|
|
|
ParseAllRealParameterByFormalParameter(real_parameter, formal_to_real_parameters, total_real_parameters, |
|
|
|
invalid_real_parameter); |
|
|
|
@@ -1623,13 +1618,13 @@ void ControlNodeParser::ParseFrontToBackendParameter(const std::vector<KernelGra |
|
|
|
call_node_to_func_graphs_); |
|
|
|
for (const auto real_parameter : real_parameters) { |
|
|
|
if (real_parameter.first->isa<Parameter>() || real_parameter.first->isa<ValueNode>()) { |
|
|
|
front_to_backend_parameters_[real_parameter].emplace(parameter, device_context); |
|
|
|
(void)front_to_backend_parameters_[real_parameter].emplace(parameter, device_context); |
|
|
|
} |
|
|
|
} |
|
|
|
} else if (front_tuple_parameter_with_index.first != nullptr) { |
|
|
|
front_to_backend_parameters_[front_tuple_parameter_with_index].emplace(parameter, device_context); |
|
|
|
(void)front_to_backend_parameters_[front_tuple_parameter_with_index].emplace(parameter, device_context); |
|
|
|
} else { |
|
|
|
front_to_backend_parameters_[{front_node, 0}].emplace(parameter, device_context); |
|
|
|
(void)front_to_backend_parameters_[{front_node, 0}].emplace(parameter, device_context); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -1664,7 +1659,7 @@ void ControlNodeParser::ParseCallNodeToFuncGraph(const std::vector<AnfNodePtr> & |
|
|
|
|
|
|
|
auto func_graphs = func_graph_analyzer->GetCallerFuncGraphs(control_node); |
|
|
|
for (auto func_graph : func_graphs) { |
|
|
|
call_node_to_func_graphs_[control_node].emplace(func_graph); |
|
|
|
(void)call_node_to_func_graphs_[control_node].emplace(func_graph); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -1758,7 +1753,7 @@ void ControlNodeParser::ParseFirstControlNodeForFuncGraph(const std::vector<AnfN |
|
|
|
IsFirstControlNode(control_node, &checked_nodes, unrecursion_call_nodes_)) { |
|
|
|
const auto &func_graph = control_node->func_graph(); |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
func_graph_to_first_control_nodes_[func_graph].emplace(control_node); |
|
|
|
(void)func_graph_to_first_control_nodes_[func_graph].emplace(control_node); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -1784,7 +1779,7 @@ void ControlNodeParser::ParseUnRecursionCallNode() { |
|
|
|
FetchAllExecutionFunction(func_graph, &exexution_func_graphs, func_graph_call_relation); |
|
|
|
} |
|
|
|
if (exexution_func_graphs.find(dest_func_graph) == exexution_func_graphs.end()) { |
|
|
|
unrecursion_call_nodes_.emplace(call_node); |
|
|
|
(void)unrecursion_call_nodes_.emplace(call_node); |
|
|
|
MS_LOG(DEBUG) << "Add unrecursion call control node:" << call_node->DebugString(); |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -1809,7 +1804,7 @@ void ControlNodeParser::ParseNeedStackControlNode(const std::vector<AnfNodePtr> |
|
|
|
continue; |
|
|
|
} |
|
|
|
} |
|
|
|
need_stack_control_nodes_.emplace(control_node); |
|
|
|
(void)need_stack_control_nodes_.emplace(control_node); |
|
|
|
MS_LOG(DEBUG) << "Add need stack control node:" << control_node->DebugString(); |
|
|
|
break; |
|
|
|
} |
|
|
|
@@ -1834,13 +1829,13 @@ void ControlNodeParser::ParseNeedStackControlNode(const std::vector<AnfNodePtr> |
|
|
|
} |
|
|
|
|
|
|
|
if (call_input_num != 0 && (AnfAlgo::CheckPrimitiveType(inputs[kReturnInputPos], prim::kPrimDepend))) { |
|
|
|
need_stack_control_nodes_.emplace(control_node); |
|
|
|
(void)need_stack_control_nodes_.emplace(control_node); |
|
|
|
} |
|
|
|
} else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimPartial) || |
|
|
|
AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch) || |
|
|
|
AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) { |
|
|
|
if (!IsInputInSameLevel(control_node)) { |
|
|
|
need_stack_control_nodes_.emplace(control_node); |
|
|
|
(void)need_stack_control_nodes_.emplace(control_node); |
|
|
|
MS_LOG(DEBUG) << "Add need stack control node:" << control_node->DebugString(); |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -1865,7 +1860,7 @@ void ControlNodeParser::ParseNeedStackKernelGraph(const KernelGraphToDeviceConte |
|
|
|
MS_LOG(EXCEPTION) << "Failed to find device context for kernel graph:" << kernel_graph->ToString(); |
|
|
|
} |
|
|
|
// Collect kernel graphs in group. |
|
|
|
kernel_graph_group_info->graphs_.emplace(kernel_graph); |
|
|
|
(void)kernel_graph_group_info->graphs_.emplace(kernel_graph); |
|
|
|
|
|
|
|
// Collect inputs in group. |
|
|
|
const auto &real_parameters = kernel_graph->input_nodes(); |
|
|
|
@@ -1908,14 +1903,14 @@ void ControlNodeParser::ParseNeedStackKernelGraph(const KernelGraphToDeviceConte |
|
|
|
|
|
|
|
kernel_graphs_to_group_info_[kernel_graph] = kernel_graph_group_info; |
|
|
|
if (kernel_graph_group_info->need_stack_) { |
|
|
|
call_input_kernel_graphs_.emplace(kernel_graph.get()); |
|
|
|
(void)call_input_kernel_graphs_.emplace(kernel_graph.get()); |
|
|
|
} |
|
|
|
} |
|
|
|
kernel_graph_group_info->group_name_ = "kernel_graph"; |
|
|
|
for (const auto &graph : kernel_graph_group_info->graphs_) { |
|
|
|
kernel_graph_group_info->group_name_ += ("_" + std::to_string(graph->graph_id())); |
|
|
|
} |
|
|
|
kernel_graph_group_infos_.emplace(kernel_graph_group_info); |
|
|
|
(void)kernel_graph_group_infos_.emplace(kernel_graph_group_info); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -2011,13 +2006,13 @@ bool ControlNodeParser::IsInputInSameLevel(const AnfNodePtr &node) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
void ControlNodeParser::CreateDeviceTensorForRootGraphParameter(DeviceContext *default_context) { |
|
|
|
void ControlNodeParser::CreateDeviceTensorForRootGraphParameter(DeviceContext *const default_context) { |
|
|
|
MS_EXCEPTION_IF_NULL(default_context); |
|
|
|
for (const auto ¶meter : root_graph_parameters_) { |
|
|
|
KernelWithIndex parameter_with_index(parameter, 0); |
|
|
|
if (front_to_backend_parameters_.find(parameter_with_index) == front_to_backend_parameters_.end()) { |
|
|
|
CreateDeviceTensorForFrontNode(parameter_with_index, default_context); |
|
|
|
front_to_backend_parameters_[parameter_with_index].emplace(parameter, default_context); |
|
|
|
(void)front_to_backend_parameters_[parameter_with_index].emplace(parameter, default_context); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|