|
|
|
@@ -261,23 +261,24 @@ CNodePtr CreateFusionOp(const std::vector<AnfNodePtr> &inputs_list, const std::v |
|
|
|
return buffer_fusion_kernel; |
|
|
|
} |
|
|
|
|
|
|
|
kernel::KernelBuildInfoPtr CreateFusionOpKernelInfo(const std::vector<AnfNodePtr> &inputs_list_in, |
|
|
|
const std::vector<AnfNodePtr> &inputs_list, |
|
|
|
kernel::KernelBuildInfoPtr CreateFusionOpKernelInfo(const std::vector<AnfNodePtr> &inputs_list, |
|
|
|
const std::vector<AnfNodePtr> &outputs_list) { |
|
|
|
MS_LOG(DEBUG) << "Start Create Kernel Info"; |
|
|
|
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; |
|
|
|
// inputs format and data type |
|
|
|
std::vector<std::string> inputs_format; |
|
|
|
std::vector<TypeId> inputs_data_type; |
|
|
|
for (auto node : inputs_list_in) { |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
auto &inputs = cnode->inputs(); |
|
|
|
for (size_t input_index = 1; input_index < inputs.size(); ++input_index) { |
|
|
|
if (std::find(inputs_list.begin(), inputs_list.end(), inputs[input_index]) != inputs_list.end()) { |
|
|
|
inputs_format.push_back(AnfAlgo::GetInputFormat(node, input_index - 1)); |
|
|
|
inputs_data_type.push_back(AnfAlgo::GetInputDeviceDataType(node, input_index - 1)); |
|
|
|
} |
|
|
|
for (const auto &input : inputs_list) { |
|
|
|
if (input->isa<CNode>() && AnfAlgo::GetCNodeName(input) == prim::kPrimTupleGetItem->name()) { |
|
|
|
auto tuple_getitem = input->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(tuple_getitem); |
|
|
|
inputs_format.push_back(AnfAlgo::GetOutputFormat( |
|
|
|
tuple_getitem->input(1), IntToSize(GetValue<int>(GetValueNode(tuple_getitem->input(2)))))); |
|
|
|
inputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType( |
|
|
|
tuple_getitem->input(1), IntToSize(GetValue<int>(GetValueNode(tuple_getitem->input(2)))))); |
|
|
|
} else { |
|
|
|
inputs_format.push_back(AnfAlgo::GetOutputFormat(input, 0)); |
|
|
|
inputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType(input, 0)); |
|
|
|
} |
|
|
|
} |
|
|
|
// outputs format and data type |
|
|
|
@@ -360,62 +361,6 @@ void ReplaceOldNode(std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusi |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void GetInputList(const CNodePtr &node, const int32_t cur_fusion_id, std::vector<AnfNodePtr> *inputs_list) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
MS_EXCEPTION_IF_NULL(inputs_list); |
|
|
|
auto &inputs = node->inputs(); |
|
|
|
for (size_t input_index = 1; input_index < inputs.size(); ++input_index) { |
|
|
|
auto input = inputs[input_index]; |
|
|
|
if (AnfAlgo::IsRealCNodeKernel(input)) { |
|
|
|
if (AnfAlgo::HasNodeAttr(kOpAttrFusionId, input)) { |
|
|
|
auto fusion_id = AnfAlgo::GetNodeAttr<int32_t>(input, kOpAttrFusionId); |
|
|
|
if (fusion_id != cur_fusion_id) { |
|
|
|
inputs_list->push_back(input); |
|
|
|
} |
|
|
|
} else { |
|
|
|
inputs_list->push_back(input); |
|
|
|
} |
|
|
|
} else if (input->isa<CNode>()) { |
|
|
|
for (auto &input_in : input->cast<CNodePtr>()->inputs()) { |
|
|
|
if (AnfAlgo::IsRealCNodeKernel(input_in)) { |
|
|
|
if (AnfAlgo::HasNodeAttr(kOpAttrFusionId, input_in)) { |
|
|
|
auto fusion_id = AnfAlgo::GetNodeAttr<int32_t>(input_in, kOpAttrFusionId); |
|
|
|
if (fusion_id != cur_fusion_id) { |
|
|
|
inputs_list->push_back(input); |
|
|
|
} |
|
|
|
} else { |
|
|
|
inputs_list->push_back(input); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} else { |
|
|
|
inputs_list->push_back(input); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void CheckCurrentNodeIsInput(const CNodePtr &node, const int32_t &cur_fusion_id, |
|
|
|
std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) { |
|
|
|
MS_EXCEPTION_IF_NULL(buffer_fusion_infos); |
|
|
|
if ((*buffer_fusion_infos).find(cur_fusion_id) == (*buffer_fusion_infos).end()) { |
|
|
|
BufferFusionInfo_t buffer_fusion_info; |
|
|
|
(*buffer_fusion_infos)[cur_fusion_id] = buffer_fusion_info; |
|
|
|
} |
|
|
|
std::vector<AnfNodePtr> inputs_list; |
|
|
|
GetInputList(node, cur_fusion_id, &inputs_list); |
|
|
|
if (!inputs_list.empty()) { |
|
|
|
if (!(*buffer_fusion_infos)[cur_fusion_id].inputs_list.empty()) { |
|
|
|
(void)(*buffer_fusion_infos)[cur_fusion_id].inputs_list.insert( |
|
|
|
(*buffer_fusion_infos)[cur_fusion_id].inputs_list.end(), inputs_list.begin(), inputs_list.end()); |
|
|
|
(void)(*buffer_fusion_infos)[cur_fusion_id].inputs_list_in.insert( |
|
|
|
(*buffer_fusion_infos)[cur_fusion_id].inputs_list_in.end(), node); |
|
|
|
} else { |
|
|
|
(*buffer_fusion_infos)[cur_fusion_id].inputs_list = inputs_list; |
|
|
|
(*buffer_fusion_infos)[cur_fusion_id].inputs_list_in.push_back(node); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void GetFusionScopeComputeNodeList(session::KernelGraph *kernel_graph, |
|
|
|
std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) { |
|
|
|
MS_EXCEPTION_IF_NULL(buffer_fusion_infos); |
|
|
|
@@ -429,6 +374,45 @@ void GetFusionScopeComputeNodeList(session::KernelGraph *kernel_graph, |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Extends the `inputs_list` of every fusion scope in *buffer_fusion_infos with the
// operands that enter the scope from outside.
//
// For each node in a scope's `anf_nodes`, every operand except operand 0 is resolved
// with AnfAlgo::VisitKernel(operand, 0). If the resolved node is not itself a member of
// `anf_nodes`, the original (unresolved) operand is appended to `inputs_list`, unless
// it is already present there.
//
// Throws (via MS_EXCEPTION_IF_NULL) when `kernel_graph`, `buffer_fusion_infos`, the
// graph's manager, or any scope node / its CNode cast is null.
void GetFusionScopeInputNodeList(session::KernelGraph *kernel_graph,
                                 std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
  // The manager is not used below; the check is kept so that a graph without a manager
  // is still rejected at this point, as before.
  auto manager = kernel_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);

  for (auto &buffer_fusion_info : *buffer_fusion_infos) {
    // Work on the map entry in place instead of copying it and re-looking it up by key.
    // Only `inputs_list` is modified, so iterating `anf_nodes` stays valid.
    auto &fusion_info = buffer_fusion_info.second;
    const auto &scope_nodes = fusion_info.anf_nodes;
    auto &scope_inputs = fusion_info.inputs_list;

    for (const auto &node : scope_nodes) {
      MS_EXCEPTION_IF_NULL(node);
      auto cnode = node->cast<CNodePtr>();
      // cast<> yields null for a non-CNode; fail loudly rather than dereference it.
      MS_EXCEPTION_IF_NULL(cnode);

      for (size_t idx = 1; idx < cnode->inputs().size(); ++idx) {
        const auto &operand = cnode->input(idx);
        auto real_input = AnfAlgo::VisitKernel(operand, 0);
        if (std::find(scope_nodes.begin(), scope_nodes.end(), real_input.first) != scope_nodes.end()) {
          continue;  // produced inside the scope, not an external input
        }
        if (std::find(scope_inputs.begin(), scope_inputs.end(), operand) == scope_inputs.end()) {
          scope_inputs.push_back(operand);
        }
      }
    }
  }
}
|
|
|
|
|
|
|
bool TupleGetitemNodeCompare(const AnfNodePtr &node1, const AnfNodePtr &node2) { |
|
|
|
MS_EXCEPTION_IF_NULL(node1); |
|
|
|
MS_EXCEPTION_IF_NULL(node2); |
|
|
|
auto getitem1 = node1->cast<CNodePtr>(); |
|
|
|
auto getitem2 = node2->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(getitem1); |
|
|
|
MS_EXCEPTION_IF_NULL(getitem2); |
|
|
|
auto output_idx1 = GetValue<int>(GetValueNode(getitem1->input(2))); |
|
|
|
auto output_idx2 = GetValue<int>(GetValueNode(getitem2->input(2))); |
|
|
|
return output_idx1 < output_idx2; |
|
|
|
} |
|
|
|
|
|
|
|
void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph, |
|
|
|
std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) { |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph); |
|
|
|
@@ -454,14 +438,7 @@ void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph, |
|
|
|
std::transform(manager->node_users()[node].begin(), manager->node_users()[node].end(), |
|
|
|
std::back_inserter(tuple_getitem_nodes), |
|
|
|
[](const std::pair<AnfNodePtr, int> &use_node) { return use_node.first; }); |
|
|
|
std::sort(tuple_getitem_nodes.begin(), tuple_getitem_nodes.end(), |
|
|
|
[](const AnfNodePtr &node1, const AnfNodePtr &node2) { |
|
|
|
auto getitem1 = node1->cast<CNodePtr>(); |
|
|
|
auto getitem2 = node2->cast<CNodePtr>(); |
|
|
|
auto output_idx1 = GetValue<int>(GetValueNode(getitem1->input(2))); |
|
|
|
auto output_idx2 = GetValue<int>(GetValueNode(getitem2->input(2))); |
|
|
|
return output_idx1 < output_idx2; |
|
|
|
}); |
|
|
|
std::sort(tuple_getitem_nodes.begin(), tuple_getitem_nodes.end(), TupleGetitemNodeCompare); |
|
|
|
for (auto getitem : tuple_getitem_nodes) { |
|
|
|
auto getitem_ptr = getitem->cast<CNodePtr>(); |
|
|
|
auto input2 = getitem_ptr->input(2); |
|
|
|
@@ -634,24 +611,12 @@ void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, std::unord |
|
|
|
// Fills *buffer_fusion_infos for the given kernel graph.
//
// As written here the function
//   1. walks TopoSort(kernel_graph->get_return()) and, for every real CNode kernel that
//      carries kOpAttrFusionId, calls CheckCurrentNodeIsInput with that id;
//   2. runs GetFusionScopeComputeNodeList, GetFusionScopeInputNodeList and
//      GetFusionScopeOutputNodeList on the graph;
//   3. stores the result of CreateFusionOpKernelInfo in `kernel_build_info` of every entry.
//
// Throws (via MS_EXCEPTION_IF_NULL) when `buffer_fusion_infos` is null or when a real
// kernel node cannot be cast to CNodePtr.
//
// NOTE(review): this text looks like a diff rendering with removed and added lines
// interleaved and no +/- markers. In particular step 1 and the two differently-shaped
// CreateFusionOpKernelInfo calls at the end are probably alternatives from two revisions
// rather than sequential statements. Confirm against the real file before relying on it.
// NOTE(review): `kernel_graph` is dereferenced by the TopoSort call without a null check
// in this function; only `buffer_fusion_infos` is checked first.
void BufferFusion::GetBufferFusionInfo(session::KernelGraph *kernel_graph, |
|
|
|
std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) const { |
|
|
|
MS_EXCEPTION_IF_NULL(buffer_fusion_infos); |
|
|
|
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph->get_return()); |
|
|
|
for (auto &node : node_list) { |
|
|
|
if (!AnfAlgo::IsRealCNodeKernel(node)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
if (AnfAlgo::HasNodeAttr(kOpAttrFusionId, cnode)) { |
|
|
|
auto cur_fusion_id = AnfAlgo::GetNodeAttr<int32_t>(cnode, kOpAttrFusionId); |
|
|
|
CheckCurrentNodeIsInput(cnode, cur_fusion_id, buffer_fusion_infos); |
|
|
|
} |
|
|
|
} |
|
|
|
GetFusionScopeComputeNodeList(kernel_graph, buffer_fusion_infos); |
|
|
|
GetFusionScopeInputNodeList(kernel_graph, buffer_fusion_infos); |
|
|
|
GetFusionScopeOutputNodeList(kernel_graph, buffer_fusion_infos); |
|
|
|
for (auto &buffer_fusion_info : *buffer_fusion_infos) { |
|
|
|
// NOTE(review): the assignment below receives the three-argument call; the two-argument
// call that follows has its result discarded. Likely old/new variants of one statement.
buffer_fusion_info.second.kernel_build_info = |
|
|
|
CreateFusionOpKernelInfo(buffer_fusion_info.second.inputs_list_in, buffer_fusion_info.second.inputs_list, |
|
|
|
buffer_fusion_info.second.outputs_list); |
|
|
|
CreateFusionOpKernelInfo(buffer_fusion_info.second.inputs_list, buffer_fusion_info.second.outputs_list); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|