|
|
|
@@ -24,15 +24,6 @@ |
|
|
|
namespace mindspore { |
|
|
|
namespace runtime { |
|
|
|
namespace { |
|
|
|
// Check whether node has a partial structure, a node is a partial structure whicih: |
|
|
|
// 1. a partial cnode. |
|
|
|
// 2. a funcgraph value node. |
|
|
|
bool IsPartial(const AnfNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
return (node->isa<ValueNode>() && IsValueNode<FuncGraph>(node)) || |
|
|
|
AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial); |
|
|
|
} |
|
|
|
|
|
|
|
// Check if node is a value node need to create a device tensor. |
|
|
|
bool IsFrontValueNode(const KernelWithIndex &node_with_index) { |
|
|
|
const auto &node = node_with_index.first; |
|
|
|
@@ -58,135 +49,6 @@ bool IsFrontValueNode(const KernelWithIndex &node_with_index) { |
|
|
|
return !sub_abstracts[index]->isa<abstract::AbstractMonad>(); |
|
|
|
} |
|
|
|
|
|
|
|
// Get funcgraph in partial structure. |
|
|
|
// Depth represents the number of layers of the call. When the first input of the call node is a call node, |
|
|
|
// the funcgraph in the return value of the inner call needs to be returned. |
|
|
|
FuncGraphPtr GetFuncGraphFromPartial(const AnfNodePtr &node, std::stack<size_t> *output_indexs) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
MS_EXCEPTION_IF_NULL(output_indexs); |
|
|
|
|
|
|
|
if (output_indexs->empty()) { |
|
|
|
if (node->isa<ValueNode>() && IsValueNode<FuncGraph>(node)) { |
|
|
|
// Value node of funcgraph. |
|
|
|
return GetValueNode<FuncGraphPtr>(node); |
|
|
|
} else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) { |
|
|
|
// Partial cnode. |
|
|
|
const auto &partial_inputs = node->cast<CNodePtr>()->inputs(); |
|
|
|
return GetValueNode<FuncGraphPtr>(partial_inputs[kPartialFuncGraphPos]); |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid partial construct node:" << node->DebugString(); |
|
|
|
} |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
// Get funcgraph in the output of inner call. |
|
|
|
FuncGraphPtr func_graph; |
|
|
|
if (node->isa<ValueNode>() && IsValueNode<FuncGraph>(node)) { |
|
|
|
func_graph = GetValueNode<FuncGraphPtr>(node); |
|
|
|
} else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) { |
|
|
|
const auto &partial_inputs = node->cast<CNodePtr>()->inputs(); |
|
|
|
const auto &func_graph_node = partial_inputs[kPartialFuncGraphPos]; |
|
|
|
if (!func_graph_node->isa<ValueNode>() || !IsValueNode<FuncGraph>(func_graph_node)) { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid partial node:" << node->DebugString(); |
|
|
|
} |
|
|
|
func_graph = GetValueNode<FuncGraphPtr>(func_graph_node); |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid partial node:" << node->DebugString(); |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
|
|
|
|
size_t index = output_indexs->top(); |
|
|
|
output_indexs->pop(); |
|
|
|
const auto &output = func_graph->output(); |
|
|
|
MS_EXCEPTION_IF_NULL(output); |
|
|
|
KernelWithIndex real_output_with_index = AnfAlgo::VisitKernelWithReturnType(output, 0); |
|
|
|
auto real_output = real_output_with_index.first; |
|
|
|
MS_EXCEPTION_IF_NULL(real_output); |
|
|
|
if (AnfAlgo::CheckPrimitiveType(real_output, prim::kPrimMakeTuple)) { |
|
|
|
const auto &cnode = real_output->cast<CNodePtr>(); |
|
|
|
const auto &inputs = cnode->inputs(); |
|
|
|
if (inputs.size() <= index + kMakeTupleInputStartPos) { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid output index:" << index << " for node:" << real_output->DebugString(); |
|
|
|
} |
|
|
|
real_output = inputs[index + kMakeTupleInputStartPos]; |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(real_output); |
|
|
|
return GetFuncGraphFromPartial(real_output, output_indexs); |
|
|
|
} |
|
|
|
|
|
|
|
// Find all funcgraphs that the call node will call. |
|
|
|
// The output index represents the index of the funcgraph in the graph output. When the funcgraph is passed through |
|
|
|
// the function return value, the index in the return value needs to be placed on the stack so that the index can be |
|
|
|
// obtained from the stack when the graph output is found. |
|
|
|
std::set<FuncGraphPtr> GetFuncGraphbyCallNode(const AnfNodePtr &node, std::stack<size_t> *output_indexs, |
|
|
|
AnfNodePtr make_tuple = nullptr) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
std::set<FuncGraphPtr> func_graphs; |
|
|
|
if (!node->isa<CNode>()) { |
|
|
|
return func_graphs; |
|
|
|
} |
|
|
|
|
|
|
|
const auto &cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
const auto &call_input0 = cnode->input(0); |
|
|
|
MS_EXCEPTION_IF_NULL(call_input0); |
|
|
|
const auto &call_input_with_index = AnfAlgo::VisitKernelWithReturnType(call_input0, 0); |
|
|
|
const auto &real_input = call_input_with_index.first; |
|
|
|
if (AnfAlgo::IsCallNode(real_input)) { |
|
|
|
output_indexs->push(call_input_with_index.second); |
|
|
|
return GetFuncGraphbyCallNode(real_input, output_indexs); |
|
|
|
} |
|
|
|
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(real_input, prim::kPrimSwitch)) { |
|
|
|
// First input node of call is switch node. |
|
|
|
const auto &input_cnode = real_input->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(input_cnode); |
|
|
|
const auto &switch_inputs = input_cnode->inputs(); |
|
|
|
for (size_t i = kSwitchTrueBranchPos; i < switch_inputs.size(); ++i) { |
|
|
|
MS_EXCEPTION_IF_NULL(switch_inputs[i]); |
|
|
|
std::stack<size_t> tmp_output_indexs = *output_indexs; |
|
|
|
(void)func_graphs.emplace(GetFuncGraphFromPartial(switch_inputs[i], &tmp_output_indexs)); |
|
|
|
} |
|
|
|
} else if (AnfAlgo::CheckPrimitiveType(real_input, prim::kPrimSwitchLayer)) { |
|
|
|
// First input node of call is switch layer node. |
|
|
|
const auto &input_cnode = real_input->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(input_cnode); |
|
|
|
const auto &tuple_node = input_cnode->input(kSwitchLayerBranchPos); |
|
|
|
if (AnfAlgo::CheckPrimitiveType(tuple_node, prim::kPrimMakeTuple)) { |
|
|
|
const auto &tuple_inputs = tuple_node->cast<CNodePtr>()->inputs(); |
|
|
|
for (size_t i = kMakeTupleInputStartPos; i < tuple_inputs.size(); ++i) { |
|
|
|
MS_EXCEPTION_IF_NULL(tuple_inputs[i]); |
|
|
|
std::stack<size_t> tmp_output_indexs = *output_indexs; |
|
|
|
(void)func_graphs.emplace(GetFuncGraphFromPartial(tuple_inputs[i], &tmp_output_indexs)); |
|
|
|
} |
|
|
|
} else if (tuple_node->isa<Parameter>()) { |
|
|
|
const auto &abstract = tuple_node->abstract(); |
|
|
|
MS_EXCEPTION_IF_NULL(abstract); |
|
|
|
if (!abstract->isa<abstract::AbstractTuple>()) { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid abstract:" << abstract->ToString() << " in node:" << tuple_node->DebugString() |
|
|
|
<< " for switch layer node:" << real_input->DebugString(); |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(make_tuple); |
|
|
|
auto partials = make_tuple->cast<CNodePtr>()->inputs(); |
|
|
|
for (const auto partial : partials) { |
|
|
|
if (IsPartial(partial)) { |
|
|
|
std::stack<size_t> tmp_output_indexs; |
|
|
|
(void)func_graphs.emplace(GetFuncGraphFromPartial(partial, &tmp_output_indexs)); |
|
|
|
} |
|
|
|
} |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid input tuple node:" << tuple_node->DebugString() |
|
|
|
<< " for switch layer node:" << cnode->DebugString(); |
|
|
|
} |
|
|
|
} else if (IsPartial(real_input)) { |
|
|
|
// First input node of call is partial node or value node of funcgraph. |
|
|
|
(void)func_graphs.emplace(GetFuncGraphFromPartial(real_input, output_indexs)); |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "Unable to identify call node" << real_input->DebugString(); |
|
|
|
} |
|
|
|
return func_graphs; |
|
|
|
} |
|
|
|
|
|
|
|
// Fetch real input node in maketuple. |
|
|
|
KernelWithIndex FetchRealInputNode(const KernelWithIndex &node_with_index) { |
|
|
|
const auto &node = node_with_index.first; |
|
|
|
@@ -456,7 +318,7 @@ std::vector<KernelWithIndex> FetchAllOutputWithIndex(const AnfNodePtr &node) { |
|
|
|
const auto &cnode = node_with_index.first->cast<CNodePtr>(); |
|
|
|
const auto &inputs = cnode->inputs(); |
|
|
|
const auto &tmp_list = FetchAllOutputWithIndex(inputs[i + kMakeTupleInputStartPos]); |
|
|
|
result.insert(result.end(), tmp_list.begin(), tmp_list.end()); |
|
|
|
(void)result.insert(result.end(), tmp_list.begin(), tmp_list.end()); |
|
|
|
} else if (AnfAlgo::CheckPrimitiveType(node_with_index.first, prim::kPrimSwitch) || |
|
|
|
AnfAlgo::CheckPrimitiveType(node_with_index.first, prim::kPrimSwitchLayer)) { |
|
|
|
} else if (AnfAlgo::IsCallNode(node_with_index.first)) { |
|
|
|
@@ -517,7 +379,6 @@ void CreateDeviceTensorForValueNode(const KernelWithIndex &front_node_with_index |
|
|
|
void CreateDeviceTensorForFrontNode(const KernelWithIndex &front_node_with_index, const DeviceContext *device_context) { |
|
|
|
MS_EXCEPTION_IF_NULL(device_context); |
|
|
|
const auto &node = front_node_with_index.first; |
|
|
|
MS_EXCEPTION_IF_NULL(device_context); |
|
|
|
|
|
|
|
TypeId type_id = AnfAlgo::GetOutputInferDataType(node, 0); |
|
|
|
if (node->kernel_info() == nullptr) { |
|
|
|
@@ -579,7 +440,7 @@ std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) { |
|
|
|
const auto &inputs = cnode->inputs(); |
|
|
|
for (size_t i = kMakeTupleInputStartPos; i < inputs.size(); ++i) { |
|
|
|
const auto &sub_results = FetchInputNodeByNode(inputs[i]); |
|
|
|
results.insert(results.end(), sub_results.begin(), sub_results.end()); |
|
|
|
(void)results.insert(results.end(), sub_results.begin(), sub_results.end()); |
|
|
|
} |
|
|
|
return results; |
|
|
|
} |
|
|
|
@@ -608,8 +469,9 @@ std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) { |
|
|
|
if (csr_tensor_inputs.size() <= kMakeCSRTensorInputNum) { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid make csr tensor node:" << cnode->DebugString(); |
|
|
|
} |
|
|
|
const auto &sub_results = FetchInputNodeByNode(csr_tensor_inputs[iter->second + kMakeCSRTensorInputStartPos]); |
|
|
|
results.insert(results.end(), sub_results.begin(), sub_results.end()); |
|
|
|
const auto &sub_results = |
|
|
|
FetchInputNodeByNode(csr_tensor_inputs[LongToSize(iter->second) + kMakeCSRTensorInputStartPos]); |
|
|
|
(void)results.insert(results.end(), sub_results.begin(), sub_results.end()); |
|
|
|
} else { |
|
|
|
// Csr node from parameter or call node. |
|
|
|
auto abstract = src_node->abstract(); |
|
|
|
@@ -637,7 +499,7 @@ std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(get_item_src_node); |
|
|
|
if (index_stack.empty()) { |
|
|
|
const auto &sub_results = FetchInputNodeByNode(get_item_src_node); |
|
|
|
results.insert(results.end(), sub_results.begin(), sub_results.end()); |
|
|
|
(void)results.insert(results.end(), sub_results.begin(), sub_results.end()); |
|
|
|
return results; |
|
|
|
} |
|
|
|
auto get_item_src_abstract = get_item_src_node->abstract(); |
|
|
|
@@ -698,7 +560,7 @@ bool IsFirstControlNode(const AnfNodePtr &node, std::set<AnfNodePtr> *checked_no |
|
|
|
if (!node->isa<CNode>() || checked_nodes->find(node) != checked_nodes->end()) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
checked_nodes->emplace(node); |
|
|
|
(void)checked_nodes->emplace(node); |
|
|
|
|
|
|
|
const auto &cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
@@ -722,7 +584,7 @@ size_t ParseControlNodeLevel(const AnfNodePtr &node, std::set<AnfNodePtr> *check |
|
|
|
if (!node->isa<CNode>() || checked_nodes->find(node) != checked_nodes->end()) { |
|
|
|
return 0; |
|
|
|
} |
|
|
|
checked_nodes->emplace(node); |
|
|
|
(void)checked_nodes->emplace(node); |
|
|
|
|
|
|
|
auto iter = node_to_level.find(node); |
|
|
|
if (iter != node_to_level.end()) { |
|
|
|
@@ -832,7 +694,7 @@ std::vector<KernelWithIndex> FetchInputNodeByCNode(const AnfNodePtr &node) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
const auto &sub_results = FetchInputNodeByNode(inputs[i]); |
|
|
|
results.insert(results.end(), sub_results.begin(), sub_results.end()); |
|
|
|
(void)results.insert(results.end(), sub_results.begin(), sub_results.end()); |
|
|
|
} |
|
|
|
return results; |
|
|
|
} |
|
|
|
@@ -1259,8 +1121,9 @@ void ControlNodeParser::ParseDeviceContextForReturnNode(const DeviceContext *def |
|
|
|
if (func_graph_iter == func_graph_to_device_contexts_.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Cannot find device context for funcgraph:" << func_graph->ToString(); |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph_iter->second[iter - func_graph->parameters().begin()]); |
|
|
|
(void)return_device_contexts.emplace_back(func_graph_iter->second[iter - func_graph->parameters().begin()]); |
|
|
|
size_t index = LongToSize(iter - func_graph->parameters().begin()); |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph_iter->second[index]); |
|
|
|
(void)return_device_contexts.emplace_back(func_graph_iter->second[index]); |
|
|
|
} else if (output_node.first->isa<ValueNode>()) { |
|
|
|
// If the output is parameter, used the default context type. |
|
|
|
MS_EXCEPTION_IF_NULL(default_context); |
|
|
|
@@ -1449,13 +1312,14 @@ void ControlNodeParser::ParseFormalToRealParameter(const std::vector<AnfNodePtr> |
|
|
|
const auto &func_graphs = FetchFuncGraphbyCallNode(node); |
|
|
|
for (const auto func_graph : func_graphs) { |
|
|
|
const auto ¶meters = func_graph->parameters(); |
|
|
|
for (size_t i = inputs.size() - 1, j = parameters.size() - 1; i >= kCallInputStartPos && j >= 0; --i, --j) { |
|
|
|
MS_EXCEPTION_IF_NULL(inputs[i]); |
|
|
|
MS_EXCEPTION_IF_NULL(parameters[j]); |
|
|
|
if (HasAbstractMonad(inputs[i])) { |
|
|
|
for (int i = inputs.size() - 1, j = parameters.size() - 1; i >= 1 && j >= 0; --i, --j) { |
|
|
|
MS_EXCEPTION_IF_NULL(inputs[IntToSize(i)]); |
|
|
|
MS_EXCEPTION_IF_NULL(parameters[IntToSize(j)]); |
|
|
|
if (HasAbstractMonad(inputs[IntToSize(i)])) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
AddFormalToRealParameter(parameters[j], inputs[i], call_node_to_func_graphs_, &formal_to_real_parameters); |
|
|
|
AddFormalToRealParameter(parameters[IntToSize(j)], inputs[IntToSize(i)], call_node_to_func_graphs_, |
|
|
|
&formal_to_real_parameters); |
|
|
|
} |
|
|
|
} |
|
|
|
} else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) { |
|
|
|
@@ -1968,7 +1832,7 @@ void ControlNodeParser::ParseNodeLevel(const std::vector<AnfNodePtr> &control_no |
|
|
|
kernel_graph_group_info->need_stack_ = true; |
|
|
|
kernel_graph_group_info->level_ = max_level; |
|
|
|
for (const auto &kernel_graph : kernel_graph_group_info->graphs_) { |
|
|
|
call_input_kernel_graphs_.emplace(kernel_graph.get()); |
|
|
|
(void)call_input_kernel_graphs_.emplace(kernel_graph.get()); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|