|
|
|
@@ -24,6 +24,7 @@ |
|
|
|
#include <memory> |
|
|
|
#include <string> |
|
|
|
#include <algorithm> |
|
|
|
#include <iterator> |
|
|
|
|
|
|
|
#include "kernel/kernel_fusion.h" |
|
|
|
#include "debug/anf_ir_dump.h" |
|
|
|
@@ -461,6 +462,36 @@ void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph, |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Copies ref-output bookkeeping from the original outputs of a fusion scope onto the fusion kernel.
//
// For every entry outputs_list[idx] a lookup key (node, output index) is built:
//   - if the entry is a TupleGetItem CNode, the key is the node returned by AnfAlgo::VisitKernel(entry, 0)
//     paired with the integer stored in the TupleGetItem's input 2;
//   - otherwise the key is (entry, 0).
// When kernel_graph reports that key as present in its ref output map, the pair (fusion_kernel, idx) is
// registered via AddRefCorrespondPairs against the same origin pair the key mapped to.
//
// kernel_graph   graph whose ref output map is queried and extended; must not be null.
// outputs_list   outputs of the fusion scope; position idx corresponds to output idx of fusion_kernel.
//                Entries must not be null.
// fusion_kernel  node used as the first element of every newly registered pair.
//
// Raises (through MS_EXCEPTION_IF_NULL) when kernel_graph, its manager, any entry of outputs_list, or the
// CNode cast of a TupleGetItem entry is null.
void SetFusionOpRefInfos(session::KernelGraph *kernel_graph, const std::vector<AnfNodePtr> &outputs_list,
                         const AnfNodePtr &fusion_kernel) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  // The manager is not used below; the check is kept so a graph without a manager is still rejected.
  auto manager = kernel_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);

  // Shared tail of both branches: if out_pair is a known ref output, make output `idx` of the fusion
  // kernel refer to the same origin.
  auto register_ref = [&kernel_graph, &fusion_kernel](const session::AnfWithOutIndex &out_pair, size_t idx) {
    if (!kernel_graph->IsInRefOutputMap(out_pair)) {
      return;
    }
    auto origin_pair = kernel_graph->GetRefCorrespondOutput(out_pair);
    session::AnfWithOutIndex fusion_final_pair(fusion_kernel, idx);
    kernel_graph->AddRefCorrespondPairs(fusion_final_pair, origin_pair);
  };

  for (size_t idx = 0; idx < outputs_list.size(); ++idx) {
    const auto &output = outputs_list[idx];
    // Guard the dereferences below; a null entry previously led to a null-pointer access.
    MS_EXCEPTION_IF_NULL(output);
    if (output->isa<CNode>() && AnfAlgo::GetCNodeName(output) == prim::kPrimTupleGetItem->name()) {
      auto real_output = AnfAlgo::VisitKernel(output, 0);
      auto output_cnode = output->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(output_cnode);
      // Input 2 of the TupleGetItem holds the output index used for the lookup key.
      auto input2 = output_cnode->input(2);
      auto output_idx = GetValue<int>(GetValueNode(input2));
      register_ref(session::AnfWithOutIndex(real_output.first, output_idx), idx);
    } else {
      register_ref(session::AnfWithOutIndex(output, 0), idx);
    }
  }
}
|
|
|
|
|
|
|
void MatchConvBnreduce(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, |
|
|
|
std::unordered_set<AnfNodePtr> *fused_set, FusedNodeRecord *candidate_fusion) { |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
@@ -708,7 +739,7 @@ bool BufferFusion::ReplaceFusionOp(std::unordered_map<int32_t, BufferFusionInfo_ |
|
|
|
} |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, buffer_fusion.get()); |
|
|
|
AnfAlgo::SetKernelMod(kernel_ptr, buffer_fusion.get()); |
|
|
|
// replace node |
|
|
|
SetFusionOpRefInfos(kernel_graph, buffer_fusion_info.outputs_list, buffer_fusion); |
|
|
|
ReplaceOldNode(buffer_fusion_infos, fusion_id, buffer_fusion, kernel_graph); |
|
|
|
return true; |
|
|
|
} |
|
|
|
|