|
|
|
@@ -1259,10 +1259,65 @@ bool AnfRuntimeAlgorithm::IsScalarOutput(const CNodePtr &cnode, size_t index) { |
|
|
|
void AnfRuntimeAlgorithm::ReorderExecList(NotNull<std::vector<CNodePtr> *> node_list) { |
|
|
|
std::vector<CNodePtr> all_opt_list; |
|
|
|
std::vector<CNodePtr> non_opt_list; |
|
|
|
|
|
|
|
std::vector<CNodePtr> trans_list; |
|
|
|
std::vector<CNodePtr> transpose_list; |
|
|
|
std::vector<CNodePtr> cast_list; |
|
|
|
for (const auto &node : *node_list) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
if (kOptOperatorSet.find(AnfAlgo::GetCNodeName(node)) != kOptOperatorSet.end()) { |
|
|
|
auto trans_pose_func = [&](const CNodePtr &node) -> bool { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
if (AnfAlgo::GetCNodeName(node) == prim::kPrimTranspose->name()) { |
|
|
|
auto kernel_index = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(node, 0), 0); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_index.first); |
|
|
|
if (kernel_index.first->isa<CNode>() && kOptOperatorSet.find(AnfAlgo::GetCNodeName( |
|
|
|
kernel_index.first->cast<CNodePtr>())) != kOptOperatorSet.end()) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
} |
|
|
|
return false; |
|
|
|
}; |
|
|
|
|
|
|
|
auto trans_data_func = [&](const CNodePtr &node) -> bool { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
if (AnfAlgo::GetCNodeName(node) == prim::KPrimTransData->name()) { |
|
|
|
auto kernel_index = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(node, 0), 0); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_index.first); |
|
|
|
if (kernel_index.first->isa<CNode>() && kOptOperatorSet.find(AnfAlgo::GetCNodeName( |
|
|
|
kernel_index.first->cast<CNodePtr>())) != kOptOperatorSet.end()) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
if (!kernel_index.first->isa<CNode>()) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
return trans_pose_func(kernel_index.first->cast<CNodePtr>()); |
|
|
|
} |
|
|
|
return false; |
|
|
|
}; |
|
|
|
|
|
|
|
auto cast_func = [&](const CNodePtr &node) -> bool { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
if (AnfAlgo::GetCNodeName(node) == prim::kPrimCast->name()) { |
|
|
|
auto kernel_index = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(node, 0), 0); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_index.first); |
|
|
|
if (kernel_index.first->isa<CNode>() && kOptOperatorSet.find(AnfAlgo::GetCNodeName( |
|
|
|
kernel_index.first->cast<CNodePtr>())) != kOptOperatorSet.end()) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
if (!kernel_index.first->isa<CNode>()) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
return trans_data_func(kernel_index.first->cast<CNodePtr>()); |
|
|
|
} |
|
|
|
return false; |
|
|
|
}; |
|
|
|
|
|
|
|
if (trans_pose_func(node)) { |
|
|
|
transpose_list.emplace_back(node); |
|
|
|
} else if (trans_data_func(node)) { |
|
|
|
trans_list.emplace_back(node); |
|
|
|
} else if (cast_func(node)) { |
|
|
|
cast_list.emplace_back(node); |
|
|
|
} else if (kOptOperatorSet.find(AnfAlgo::GetCNodeName(node)) != kOptOperatorSet.end()) { |
|
|
|
all_opt_list.emplace_back(node); |
|
|
|
} else { |
|
|
|
non_opt_list.emplace_back(node); |
|
|
|
@@ -1271,6 +1326,9 @@ void AnfRuntimeAlgorithm::ReorderExecList(NotNull<std::vector<CNodePtr> *> node_ |
|
|
|
node_list->clear(); |
|
|
|
std::copy(non_opt_list.begin(), non_opt_list.end(), std::back_inserter(*node_list)); |
|
|
|
std::copy(all_opt_list.begin(), all_opt_list.end(), std::back_inserter(*node_list)); |
|
|
|
std::copy(transpose_list.begin(), transpose_list.end(), std::back_inserter(*node_list)); |
|
|
|
std::copy(trans_list.begin(), trans_list.end(), std::back_inserter(*node_list)); |
|
|
|
std::copy(cast_list.begin(), cast_list.end(), std::back_inserter(*node_list)); |
|
|
|
} |
|
|
|
|
|
|
|
TypeId AnfRuntimeAlgorithm::GetCNodeOutputPrecision(const AnfNodePtr &node) { |
|
|
|
|