|
|
|
@@ -17,6 +17,7 @@ |
|
|
|
|
|
|
|
#include <vector> |
|
|
|
#include <tuple> |
|
|
|
#include <utility> |
|
|
|
#include <unordered_set> |
|
|
|
#include <unordered_map> |
|
|
|
#include <deque> |
|
|
|
@@ -282,11 +283,17 @@ kernel::KernelBuildInfoPtr CreateFusionOpKernelInfo(const std::vector<AnfNodePtr |
|
|
|
// outputs format and data type |
|
|
|
std::vector<std::string> outputs_format; |
|
|
|
std::vector<TypeId> outputs_data_type; |
|
|
|
for (size_t index = 0; index < outputs_list.size(); ++index) { |
|
|
|
for (size_t idx = 0; idx < AnfAlgo::GetOutputTensorNum(outputs_list[index]); ++idx) { |
|
|
|
auto kernel_with_index = AnfAlgo::VisitKernel(outputs_list[index], idx); |
|
|
|
outputs_format.push_back(AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second)); |
|
|
|
outputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second)); |
|
|
|
for (const auto &output : outputs_list) { |
|
|
|
if (AnfAlgo::GetCNodeName(output) == prim::kPrimTupleGetItem->name()) { |
|
|
|
auto tuple_getitem = output->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(tuple_getitem); |
|
|
|
outputs_format.push_back(AnfAlgo::GetOutputFormat( |
|
|
|
tuple_getitem->input(1), IntToSize(GetValue<int>(GetValueNode(tuple_getitem->input(2)))))); |
|
|
|
outputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType( |
|
|
|
tuple_getitem->input(1), IntToSize(GetValue<int>(GetValueNode(tuple_getitem->input(2)))))); |
|
|
|
} else { |
|
|
|
outputs_format.push_back(AnfAlgo::GetOutputFormat(output, 0)); |
|
|
|
outputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType(output, 0)); |
|
|
|
} |
|
|
|
} |
|
|
|
builder.SetInputsFormat(inputs_format); |
|
|
|
@@ -320,32 +327,35 @@ AnfNodePtr CreateTupleGetItem(const AnfNodePtr &buffer_fusion_kernel, session::K |
|
|
|
return tuple_item; |
|
|
|
} |
|
|
|
|
|
|
|
void ReplaceOldNode(const std::vector<AnfNodePtr> &outputs_list, const AnfNodePtr &buffer_fusion_kernel, |
|
|
|
session::KernelGraph *kernel_graph) { |
|
|
|
void ReplaceInputNodeInOtherFusionScope(std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos, |
|
|
|
int32_t fusion_id, const AnfNodePtr &output_item, |
|
|
|
const AnfNodePtr &replace_item) { |
|
|
|
for (int32_t id = fusion_id + 1; id <= SizeToInt(buffer_fusion_infos->size()); ++id) { |
|
|
|
auto itr = std::find((*buffer_fusion_infos)[id].inputs_list.begin(), (*buffer_fusion_infos)[id].inputs_list.end(), |
|
|
|
output_item); |
|
|
|
if (itr != (*buffer_fusion_infos)[id].inputs_list.end()) { |
|
|
|
MS_LOG(DEBUG) << "replace input of other pattern, id = " << id; |
|
|
|
*itr = replace_item; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void ReplaceOldNode(std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos, int32_t fusion_id, |
|
|
|
const AnfNodePtr &buffer_fusion_kernel, session::KernelGraph *kernel_graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph); |
|
|
|
auto manager = kernel_graph->manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
if (outputs_list.size() == 1) { // single output |
|
|
|
(void)manager->Replace(outputs_list[0], buffer_fusion_kernel); |
|
|
|
auto buffer_fusion_info = (*buffer_fusion_infos)[fusion_id]; |
|
|
|
if (buffer_fusion_info.outputs_list.size() == 1) { // single output |
|
|
|
(void)manager->Replace(buffer_fusion_info.outputs_list[0], buffer_fusion_kernel); |
|
|
|
ReplaceInputNodeInOtherFusionScope(buffer_fusion_infos, fusion_id, buffer_fusion_info.outputs_list[0], |
|
|
|
buffer_fusion_kernel); |
|
|
|
} else { // multiple output |
|
|
|
size_t real_idx = 0; |
|
|
|
for (size_t index = 0; index < outputs_list.size(); ++index) { |
|
|
|
if (AnfAlgo::GetOutputTensorNum(outputs_list[index]) == 1) { |
|
|
|
auto tuple_item = CreateTupleGetItem(buffer_fusion_kernel, kernel_graph, real_idx++); |
|
|
|
(void)manager->Replace(outputs_list[index], tuple_item); |
|
|
|
} else { |
|
|
|
std::vector<AnfNodePtr> make_tuple_inputs; |
|
|
|
AbstractBasePtrList abstract_list; |
|
|
|
make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); |
|
|
|
for (size_t idx = 0; idx < AnfAlgo::GetOutputTensorNum(outputs_list[index]); ++idx) { |
|
|
|
auto tuple_item = CreateTupleGetItem(buffer_fusion_kernel, kernel_graph, real_idx++); |
|
|
|
abstract_list.push_back(tuple_item->abstract()); |
|
|
|
make_tuple_inputs.push_back(tuple_item); |
|
|
|
} |
|
|
|
AnfNodePtr make_tuple = kernel_graph->NewCNode(make_tuple_inputs); |
|
|
|
make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list)); |
|
|
|
(void)manager->Replace(outputs_list[index], make_tuple); |
|
|
|
} |
|
|
|
for (size_t index = 0; index < buffer_fusion_info.outputs_list.size(); ++index) { |
|
|
|
auto tuple_item = CreateTupleGetItem(buffer_fusion_kernel, kernel_graph, index); |
|
|
|
(void)manager->Replace(buffer_fusion_info.outputs_list[index], tuple_item); |
|
|
|
ReplaceInputNodeInOtherFusionScope(buffer_fusion_infos, fusion_id, buffer_fusion_info.outputs_list[index], |
|
|
|
tuple_item); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -406,38 +416,67 @@ void CheckCurrentNodeIsInput(const CNodePtr &node, const int32_t &cur_fusion_id, |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void InsertNode(const AnfNodePtr &node, std::vector<AnfNodePtr> *list) { |
|
|
|
MS_EXCEPTION_IF_NULL(list); |
|
|
|
if (std::find(list->begin(), list->end(), node) == list->end()) { |
|
|
|
(void)list->insert(list->end(), node); |
|
|
|
// Groups the real kernel nodes of `kernel_graph` by fusion id.
//
// The graph is walked in the order returned by TopoSort starting from the graph's return node.
// Each node that is a real CNode kernel and carries the kOpAttrFusionId attribute is appended to
// the anf_nodes list of the entry keyed by that fusion id (the entry is created on first use).
//
// kernel_graph:        graph to scan; must not be null.
// buffer_fusion_infos: output map from fusion id to fusion scope info; must not be null.
void GetFusionScopeComputeNodeList(session::KernelGraph *kernel_graph,
                                   std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) {
  // kernel_graph is dereferenced below, so reject a null graph up front instead of crashing.
  MS_EXCEPTION_IF_NULL(kernel_graph);
  MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
  auto nodes = TopoSort(kernel_graph->get_return());
  for (const auto &node : nodes) {
    MS_EXCEPTION_IF_NULL(node);
    if (AnfAlgo::IsRealCNodeKernel(node) && AnfAlgo::HasNodeAttr(kOpAttrFusionId, node)) {
      auto fusion_id = AnfAlgo::GetNodeAttr<int32_t>(node, kOpAttrFusionId);
      (*buffer_fusion_infos)[fusion_id].anf_nodes.push_back(node);
    }
  }
}
|
|
|
|
|
|
|
void CheckCurrentNodeIsOutput(const CNodePtr &node, const int32_t &cur_fusion_id, |
|
|
|
std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph, |
|
|
|
std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) { |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph); |
|
|
|
MS_EXCEPTION_IF_NULL(buffer_fusion_infos); |
|
|
|
for (auto &input : node->inputs()) { |
|
|
|
MS_EXCEPTION_IF_NULL(input); |
|
|
|
if (AnfAlgo::IsRealCNodeKernel(input) && AnfAlgo::HasNodeAttr(kOpAttrFusionId, input)) { |
|
|
|
auto fusion_id = AnfAlgo::GetNodeAttr<int32_t>(input, kOpAttrFusionId); |
|
|
|
if (buffer_fusion_infos->find(fusion_id) == buffer_fusion_infos->end()) { |
|
|
|
BufferFusionInfo_t buffer_fusion_info; |
|
|
|
(*buffer_fusion_infos)[fusion_id] = buffer_fusion_info; |
|
|
|
} |
|
|
|
if (fusion_id != cur_fusion_id) { |
|
|
|
InsertNode(input, &((*buffer_fusion_infos)[fusion_id].outputs_list)); |
|
|
|
} |
|
|
|
} else if (input->isa<CNode>()) { |
|
|
|
for (auto &input_in : input->cast<CNodePtr>()->inputs()) { |
|
|
|
if (AnfAlgo::IsRealCNodeKernel(input_in) && AnfAlgo::HasNodeAttr(kOpAttrFusionId, input_in)) { |
|
|
|
auto fusion_id = AnfAlgo::GetNodeAttr<int32_t>(input_in, kOpAttrFusionId); |
|
|
|
if (buffer_fusion_infos->find(fusion_id) == buffer_fusion_infos->end()) { |
|
|
|
BufferFusionInfo_t buffer_fusion_info; |
|
|
|
(*buffer_fusion_infos)[fusion_id] = buffer_fusion_info; |
|
|
|
auto manager = kernel_graph->manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
|
|
|
|
for (auto &buffer_fusion_info : *buffer_fusion_infos) { |
|
|
|
auto fusion_id = buffer_fusion_info.first; |
|
|
|
auto fusion_info = buffer_fusion_info.second; |
|
|
|
for (const auto &node : fusion_info.anf_nodes) { |
|
|
|
if (AnfAlgo::GetOutputTensorNum(node) == 1) { |
|
|
|
for (auto use_node : manager->node_users()[node]) { |
|
|
|
if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), use_node.first) == |
|
|
|
fusion_info.anf_nodes.end()) { |
|
|
|
(*buffer_fusion_infos)[fusion_id].outputs_list.push_back(node); |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
} else { |
|
|
|
int prev_idx = 0; |
|
|
|
std::vector<AnfNodePtr> tuple_getitem_nodes; |
|
|
|
std::transform(manager->node_users()[node].begin(), manager->node_users()[node].end(), |
|
|
|
std::back_inserter(tuple_getitem_nodes), |
|
|
|
[](const std::pair<AnfNodePtr, int> &use_node) { return use_node.first; }); |
|
|
|
std::sort(tuple_getitem_nodes.begin(), tuple_getitem_nodes.end(), |
|
|
|
[](const AnfNodePtr &node1, const AnfNodePtr &node2) { |
|
|
|
auto getitem1 = node1->cast<CNodePtr>(); |
|
|
|
auto getitem2 = node2->cast<CNodePtr>(); |
|
|
|
auto output_idx1 = GetValue<int>(GetValueNode(getitem1->input(2))); |
|
|
|
auto output_idx2 = GetValue<int>(GetValueNode(getitem2->input(2))); |
|
|
|
return output_idx1 < output_idx2; |
|
|
|
}); |
|
|
|
for (auto getitem : tuple_getitem_nodes) { |
|
|
|
auto getitem_ptr = getitem->cast<CNodePtr>(); |
|
|
|
auto input2 = getitem_ptr->input(2); |
|
|
|
auto output_idx = GetValue<int>(GetValueNode(input2)); |
|
|
|
for (int stub_idx = prev_idx; stub_idx < output_idx; ++stub_idx) { |
|
|
|
auto stub_node = CreateTupleGetItem(node, kernel_graph, IntToSize(stub_idx)); |
|
|
|
(*buffer_fusion_infos)[fusion_id].outputs_list.push_back(stub_node); |
|
|
|
} |
|
|
|
if (fusion_id != cur_fusion_id) { |
|
|
|
InsertNode(input_in, &((*buffer_fusion_infos)[fusion_id].outputs_list)); |
|
|
|
prev_idx = output_idx + 1; |
|
|
|
for (auto item_use_node : manager->node_users()[getitem]) { |
|
|
|
if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), item_use_node.first) == |
|
|
|
fusion_info.anf_nodes.end()) { |
|
|
|
(*buffer_fusion_infos)[fusion_id].outputs_list.push_back(getitem); |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -445,15 +484,72 @@ void CheckCurrentNodeIsOutput(const CNodePtr &node, const int32_t &cur_fusion_id |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void GetFusionScopeNodeList(const session::KernelGraph &kernel_graph, |
|
|
|
std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) { |
|
|
|
MS_EXCEPTION_IF_NULL(buffer_fusion_infos); |
|
|
|
auto nodes = TopoSort(kernel_graph.get_return()); |
|
|
|
for (auto &node : nodes) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
if (AnfAlgo::IsRealCNodeKernel(node) && AnfAlgo::HasNodeAttr(kOpAttrFusionId, node)) { |
|
|
|
auto fusion_id = AnfAlgo::GetNodeAttr<int32_t>(node, kOpAttrFusionId); |
|
|
|
(*buffer_fusion_infos)[fusion_id].anf_nodes.push_back(node); |
|
|
|
// Tries to pair `cnode` with a Conv2D producer feeding its first data input.
//
// When input 1 of `cnode` is a CNode named like prim::kPrimConv2D, the number of users the graph
// manager reports for that conv is stored on the conv as a one-element kAttrOutputUsedNum
// attribute, {cnode, conv} is appended to `candidate_fusion`, and both nodes are added to
// `fused_set`. Otherwise nothing is modified.
void MatchConvBnreduce(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
                       std::unordered_set<AnfNodePtr> *fused_set, FusedNodeRecord *candidate_fusion) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(fused_set);
  MS_EXCEPTION_IF_NULL(candidate_fusion);
  auto manager = kernel_graph.manager();
  MS_EXCEPTION_IF_NULL(manager);

  auto conv = cnode->input(1);
  const bool is_conv2d = conv->isa<CNode>() && AnfAlgo::GetCNodeName(conv) == prim::kPrimConv2D->name();
  if (!is_conv2d) {
    return;
  }

  // Single-output op: the attribute holds exactly one user count.
  const int conv_user_count = SizeToInt(manager->node_users()[conv].size());
  std::vector<int> output_used_num{conv_user_count};
  AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), conv);

  std::unordered_set<AnfNodePtr> record{cnode, conv};
  candidate_fusion->push_back(record);
  fused_set->insert(record.begin(), record.end());
}
|
|
|
|
|
|
|
// Tries to pair a relu node with the BNTrainingUpdate node that feeds it through a tuple getitem.
//
// `relu_input` is cast to a CNode and its input 1 is taken as the candidate producer. When that
// producer is a CNode named kBNTrainingUpdateOpName:
//   * for every user of the producer, input 2 of the user is read as an int output index and the
//     number of users of that user is stored at that index of a per-output counter vector
//     (outputs without a user keep the count 0);
//   * the counter vector is attached to the producer as kAttrOutputUsedNum;
//   * {cnode, bnupdate} is appended to `candidate_fusion` and both are added to `fused_set`.
// Otherwise nothing is modified.
//
// Raises (via the MindSpore exception macros) when a required pointer is null, when
// `relu_input` or a user of the producer is not a CNode, or when an output index read from the
// graph does not fit the producer's output count.
void MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph,
                       std::unordered_set<AnfNodePtr> *fused_set, FusedNodeRecord *candidate_fusion) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(relu_input);
  MS_EXCEPTION_IF_NULL(fused_set);
  MS_EXCEPTION_IF_NULL(candidate_fusion);
  auto manager = kernel_graph.manager();
  MS_EXCEPTION_IF_NULL(manager);
  // cast<> yields null when relu_input is not a CNode; check before dereferencing.
  auto getitem = relu_input->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(getitem);
  auto bnupdate = getitem->input(1);
  MS_EXCEPTION_IF_NULL(bnupdate);
  if (bnupdate->isa<CNode>() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) {
    std::vector<int> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0);
    for (auto out_getitem : manager->node_users()[bnupdate]) {
      auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(out_getitem_ptr);
      auto input2 = out_getitem_ptr->input(2);
      auto output_idx = GetValue<int>(GetValueNode(input2));
      // The index comes from the graph; never write outside the counter vector.
      if (output_idx < 0 || IntToSize(output_idx) >= output_used_num.size()) {
        MS_LOG(EXCEPTION) << "Output index " << output_idx << " is out of range, output num is "
                          << output_used_num.size();
      }
      output_used_num[IntToSize(output_idx)] = SizeToInt(manager->node_users()[out_getitem.first].size());
    }
    AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), bnupdate);
    std::unordered_set<AnfNodePtr> record{cnode, bnupdate};
    candidate_fusion->push_back(record);
    fused_set->insert(record.begin(), record.end());
  }
}
|
|
|
|
|
|
|
// Tries to match the chain BNTrainingUpdate -> tuple getitem -> add -> relu.
//
// `relu_input` is cast to a CNode (the add). When input 1 of the add is a CNode named like
// prim::kPrimTupleGetItem and input 1 of that getitem is a CNode named kBNTrainingUpdateOpName:
//   * for every user of the producer, input 2 of the user is read as an int output index and the
//     number of users of that user is stored at that index of a per-output counter vector
//     (outputs without a user keep the count 0);
//   * the counter vector is attached to the producer as kAttrOutputUsedNum;
//   * {cnode, relu_input, bnupdate} is appended to `candidate_fusion` and all three are added to
//     `fused_set`.
// Otherwise nothing is modified.
//
// Raises (via the MindSpore exception macros) when a required pointer is null, when
// `relu_input` or a user of the producer is not a CNode, or when an output index read from the
// graph does not fit the producer's output count.
void MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph,
                          std::unordered_set<AnfNodePtr> *fused_set, FusedNodeRecord *candidate_fusion) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(relu_input);
  MS_EXCEPTION_IF_NULL(fused_set);
  MS_EXCEPTION_IF_NULL(candidate_fusion);
  auto manager = kernel_graph.manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto add = relu_input->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(add);
  auto tuple_getitem = add->input(1);
  MS_EXCEPTION_IF_NULL(tuple_getitem);
  if (tuple_getitem->isa<CNode>() && AnfAlgo::GetCNodeName(tuple_getitem) == prim::kPrimTupleGetItem->name()) {
    auto getitem = tuple_getitem->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(getitem);
    auto bnupdate = getitem->input(1);
    MS_EXCEPTION_IF_NULL(bnupdate);
    if (bnupdate->isa<CNode>() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) {
      std::vector<int> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0);
      for (auto out_getitem : manager->node_users()[bnupdate]) {
        // cast<> yields null when the user is not a CNode; check before dereferencing.
        auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>();
        MS_EXCEPTION_IF_NULL(out_getitem_ptr);
        auto input2 = out_getitem_ptr->input(2);
        auto output_idx = GetValue<int>(GetValueNode(input2));
        // The index comes from the graph; never write outside the counter vector.
        if (output_idx < 0 || IntToSize(output_idx) >= output_used_num.size()) {
          MS_LOG(EXCEPTION) << "Output index " << output_idx << " is out of range, output num is "
                            << output_used_num.size();
        }
        output_used_num[IntToSize(output_idx)] = SizeToInt(manager->node_users()[out_getitem.first].size());
      }
      AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), bnupdate);
      std::unordered_set<AnfNodePtr> record{cnode, relu_input, bnupdate};
      candidate_fusion->push_back(record);
      fused_set->insert(record.begin(), record.end());
    }
  }
}
|
|
|
@@ -470,15 +566,14 @@ void MatchOpNamePattern(const session::KernelGraph &kernel_graph, std::unordered |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
if (AnfAlgo::GetCNodeName(cnode) == kBNTrainingReduceOpName) { |
|
|
|
auto conv = cnode->input(1); |
|
|
|
if (conv->isa<CNode>() && AnfAlgo::GetCNodeName(conv) == prim::kPrimConv2D->name()) { |
|
|
|
auto manager = kernel_graph.manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
auto &users = manager->node_users(); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(users[conv].size()), conv); |
|
|
|
std::unordered_set<AnfNodePtr> record({cnode, conv}); |
|
|
|
candidate_fusion->push_back(record); |
|
|
|
fused_set->insert(record.begin(), record.end()); |
|
|
|
MatchConvBnreduce(cnode, kernel_graph, fused_set, candidate_fusion); |
|
|
|
} else if (AnfAlgo::GetCNodeName(cnode) == kReluV2OpName || |
|
|
|
AnfAlgo::GetCNodeName(cnode) == prim::kPrimRelu->name()) { |
|
|
|
auto relu_input = cnode->input(1); |
|
|
|
if (relu_input->isa<CNode>() && AnfAlgo::GetCNodeName(relu_input) == prim::kPrimTensorAdd->name()) { |
|
|
|
MatchBnupdateAddRelu(cnode, relu_input, kernel_graph, fused_set, candidate_fusion); |
|
|
|
} else if (relu_input->isa<CNode>() && AnfAlgo::GetCNodeName(relu_input) == prim::kPrimTupleGetItem->name()) { |
|
|
|
MatchBnupdateRelu(cnode, relu_input, kernel_graph, fused_set, candidate_fusion); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -536,27 +631,23 @@ void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, std::unord |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
void BufferFusion::GetBufferFusionInfo(const session::KernelGraph &kernel_graph, |
|
|
|
void BufferFusion::GetBufferFusionInfo(session::KernelGraph *kernel_graph, |
|
|
|
std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) const { |
|
|
|
MS_EXCEPTION_IF_NULL(buffer_fusion_infos); |
|
|
|
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return()); |
|
|
|
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph->get_return()); |
|
|
|
for (auto &node : node_list) { |
|
|
|
if (!AnfAlgo::IsRealCNodeKernel(node)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
int32_t cur_fusion_id = -1; |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
if (AnfAlgo::HasNodeAttr(kOpAttrFusionId, cnode)) { |
|
|
|
cur_fusion_id = AnfAlgo::GetNodeAttr<int32_t>(cnode, kOpAttrFusionId); |
|
|
|
auto cur_fusion_id = AnfAlgo::GetNodeAttr<int32_t>(cnode, kOpAttrFusionId); |
|
|
|
CheckCurrentNodeIsInput(cnode, cur_fusion_id, buffer_fusion_infos); |
|
|
|
} |
|
|
|
// Check if current node is output |
|
|
|
CheckCurrentNodeIsOutput(cnode, cur_fusion_id, buffer_fusion_infos); |
|
|
|
} |
|
|
|
|
|
|
|
GetFusionScopeNodeList(kernel_graph, buffer_fusion_infos); |
|
|
|
GetFusionScopeComputeNodeList(kernel_graph, buffer_fusion_infos); |
|
|
|
GetFusionScopeOutputNodeList(kernel_graph, buffer_fusion_infos); |
|
|
|
for (auto &buffer_fusion_info : *buffer_fusion_infos) { |
|
|
|
buffer_fusion_info.second.kernel_build_info = |
|
|
|
CreateFusionOpKernelInfo(buffer_fusion_info.second.inputs_list_in, buffer_fusion_info.second.inputs_list, |
|
|
|
@@ -569,7 +660,7 @@ bool BufferFusion::FuseBufferFusionPattern(session::KernelGraph *kernel_graph) c |
|
|
|
bool change = false; |
|
|
|
std::unordered_map<int32_t, BufferFusionInfo_t> buffer_fusion_infos; |
|
|
|
buffer_fusion_infos.clear(); |
|
|
|
GetBufferFusionInfo(*kernel_graph, &buffer_fusion_infos); |
|
|
|
GetBufferFusionInfo(kernel_graph, &buffer_fusion_infos); |
|
|
|
|
|
|
|
std::vector<mindspore::kernel::FusionScopeInfo> fusion_scope_infos; |
|
|
|
for (auto &buffer_fusion_info : buffer_fusion_infos) { |
|
|
|
@@ -600,7 +691,7 @@ bool BufferFusion::FuseBufferFusionPattern(session::KernelGraph *kernel_graph) c |
|
|
|
MS_LOG(DEBUG) << "fusion id: " << fusion_id << ", fusion op compiling failed"; |
|
|
|
continue; |
|
|
|
} |
|
|
|
change = ReplaceFusionOp(buffer_fusion_infos[fusion_id], kernel_mods[fusion_id], kernel_graph); |
|
|
|
change = ReplaceFusionOp(&buffer_fusion_infos, fusion_id, kernel_mods[fusion_id], kernel_graph); |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "End Buffer Fusion"; |
|
|
|
return change; |
|
|
|
@@ -630,8 +721,10 @@ bool BufferFusion::MatchBufferFusionPattern(const session::KernelGraph &kernel_g |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
bool BufferFusion::ReplaceFusionOp(const BufferFusionInfo_t &buffer_fusion_info, const kernel::KernelModPtr &kernel_ptr, |
|
|
|
bool BufferFusion::ReplaceFusionOp(std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos, |
|
|
|
int32_t fusion_id, const kernel::KernelModPtr &kernel_ptr, |
|
|
|
session::KernelGraph *kernel_graph) const { |
|
|
|
auto buffer_fusion_info = (*buffer_fusion_infos)[fusion_id]; |
|
|
|
auto buffer_fusion = CreateFusionOp(buffer_fusion_info.inputs_list, buffer_fusion_info.outputs_list, |
|
|
|
buffer_fusion_info.anf_nodes, kernel_graph); |
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(buffer_fusion_info.kernel_build_info, buffer_fusion.get()); |
|
|
|
@@ -651,7 +744,7 @@ bool BufferFusion::ReplaceFusionOp(const BufferFusionInfo_t &buffer_fusion_info, |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, buffer_fusion.get()); |
|
|
|
AnfAlgo::SetKernelMod(kernel_ptr, buffer_fusion.get()); |
|
|
|
// replace node |
|
|
|
ReplaceOldNode(buffer_fusion_info.outputs_list, buffer_fusion, kernel_graph); |
|
|
|
ReplaceOldNode(buffer_fusion_infos, fusion_id, buffer_fusion, kernel_graph); |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
|