|
|
|
@@ -34,9 +34,28 @@ bool IsPartial(const AnfNodePtr &node) { |
|
|
|
} |
|
|
|
|
|
|
|
// Check if node is a value node need to create a device tensor. |
|
|
|
bool IsFrontValueNode(const AnfNodePtr &node) { |
|
|
|
bool IsFrontValueNode(const KernelWithIndex &node_with_index) { |
|
|
|
const auto &node = node_with_index.first; |
|
|
|
size_t index = node_with_index.second; |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
return node->isa<ValueNode>() && (!IsValueNode<FuncGraph>(node)) && (!IsValueNode<Primitive>(node)); |
|
|
|
if (!node->isa<ValueNode>() || IsValueNode<FuncGraph>(node) || IsValueNode<Primitive>(node)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
if (!IsValueNode<ValueTuple>(node)) { |
|
|
|
return !HasAbstractMonad(node); |
|
|
|
} |
|
|
|
|
|
|
|
const auto &abstract = node->abstract(); |
|
|
|
MS_EXCEPTION_IF_NULL(abstract); |
|
|
|
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(tuple_abstract); |
|
|
|
const auto &sub_abstracts = tuple_abstract->elements(); |
|
|
|
if (sub_abstracts.size() <= index) { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid index:" << index << " for tuple value node:" << node->DebugString(); |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(sub_abstracts[index]); |
|
|
|
return !sub_abstracts[index]->isa<abstract::AbstractMonad>(); |
|
|
|
} |
|
|
|
|
|
|
|
// Get funcgraph in partial structure. |
|
|
|
@@ -599,7 +618,6 @@ std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) { |
|
|
|
} |
|
|
|
|
|
|
|
// 5 Other. |
|
|
|
size_t output_num = AnfAlgo::GetOutputNumByAbstract(abstract); |
|
|
|
if (AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimTupleGetItem)) { |
|
|
|
const auto &get_item_cnode = real_node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(get_item_cnode); |
|
|
|
@@ -625,9 +643,23 @@ std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) { |
|
|
|
real_indexs.begin(), real_indexs.end(), std::back_inserter(results), |
|
|
|
[&get_item_src_node](const auto &index) { return KernelWithIndex(get_item_src_node, index); }); |
|
|
|
} |
|
|
|
} else { |
|
|
|
for (size_t i = 0; i < output_num; ++i) { |
|
|
|
results.emplace_back(real_node, i); |
|
|
|
return results; |
|
|
|
} |
|
|
|
|
|
|
|
size_t output_num = AnfAlgo::GetOutputNumByAbstract(abstract); |
|
|
|
if (output_num == 1) { |
|
|
|
results.emplace_back(real_node, 0); |
|
|
|
return results; |
|
|
|
} |
|
|
|
|
|
|
|
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(tuple_abstract); |
|
|
|
const auto &sub_abstracts = tuple_abstract->elements(); |
|
|
|
size_t index = 0; |
|
|
|
for (const auto &sub_abstract : sub_abstracts) { |
|
|
|
MS_EXCEPTION_IF_NULL(sub_abstract); |
|
|
|
if (!sub_abstract->isa<abstract::AbstractMonad>()) { |
|
|
|
results.emplace_back(real_node, index++); |
|
|
|
} |
|
|
|
} |
|
|
|
return results; |
|
|
|
@@ -1011,8 +1043,13 @@ void ControlNodeParser::ParseDeviceContextForFuncGraph(const std::vector<AnfNode |
|
|
|
FuncGraphSet sub_graphs = root_func_graph_->manager()->func_graphs(); |
|
|
|
for (auto sub_graph : sub_graphs) { |
|
|
|
if (func_graph_to_device_contexts_.find(sub_graph) == func_graph_to_device_contexts_.end()) { |
|
|
|
func_graph_to_device_contexts_[sub_graph] = |
|
|
|
std::vector<const DeviceContext *>(sub_graph->parameters().size(), default_context); |
|
|
|
size_t output_num = 0; |
|
|
|
for (const auto ¶meter : sub_graph->parameters()) { |
|
|
|
const auto &abstract = parameter->abstract(); |
|
|
|
MS_EXCEPTION_IF_NULL(abstract); |
|
|
|
output_num += AnfAlgo::GetOutputNumByAbstract(abstract); |
|
|
|
} |
|
|
|
func_graph_to_device_contexts_[sub_graph] = std::vector<const DeviceContext *>(output_num, default_context); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -1277,8 +1314,7 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr |
|
|
|
|
|
|
|
for (const auto &formal_to_real_parameter : formal_to_real_parameters_) { |
|
|
|
for (const auto &real_parameter_with_index : formal_to_real_parameter.second) { |
|
|
|
const auto &real_parameter = real_parameter_with_index.first; |
|
|
|
if (!IsFrontValueNode(real_parameter)) { |
|
|
|
if (!IsFrontValueNode(real_parameter_with_index)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -1299,7 +1335,7 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr |
|
|
|
for (const auto &front_to_backend_parameters : front_to_backend_parameters_) { |
|
|
|
const auto &front_node = front_to_backend_parameters.first.first; |
|
|
|
MS_EXCEPTION_IF_NULL(front_node); |
|
|
|
if (IsFrontValueNode(front_node) && (!front_to_backend_parameters.second.empty())) { |
|
|
|
if (IsFrontValueNode(front_to_backend_parameters.first) && (!front_to_backend_parameters.second.empty())) { |
|
|
|
const auto &backend_parameter = front_to_backend_parameters.second.begin()->first; |
|
|
|
const auto &device_context = front_to_backend_parameters.second.begin()->second; |
|
|
|
CreateDeviceTensorForValueNode(front_to_backend_parameters.first, backend_parameter, device_context); |
|
|
|
@@ -1323,7 +1359,7 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr |
|
|
|
} |
|
|
|
for (size_t i = 0; i < input_with_indexs.size(); ++i) { |
|
|
|
const auto &input_with_index = input_with_indexs[i]; |
|
|
|
if (IsFrontValueNode(input_with_index.first) && |
|
|
|
if (IsFrontValueNode(input_with_index) && |
|
|
|
front_value_nodes_.find({input_with_index, iter->second[i]}) == front_value_nodes_.end()) { |
|
|
|
CreateDeviceTensorForFrontNode(input_with_index, iter->second[i]); |
|
|
|
front_value_nodes_.emplace(input_with_index, iter->second[i]); |
|
|
|
|