|
|
|
@@ -357,24 +357,28 @@ void SetFusionOpRefInfos(session::KernelGraph *kernel_graph, const std::vector<A |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
bool CheckCircle(const session::KernelGraph &kernel_graph, const BufferFusionInfo_t &fusion_info) { |
|
|
|
bool has_circle = false; |
|
|
|
for (auto &inp : fusion_info.inputs_list) { |
|
|
|
MS_EXCEPTION_IF_NULL(inp); |
|
|
|
if (!inp->isa<CNode>() || AnfAlgo::CheckPrimitiveType(inp, prim::kPrimLoad)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
if (IsDepend(kernel_graph, inp, fusion_info.anf_nodes)) { |
|
|
|
has_circle = true; |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
return has_circle; |
|
|
|
} |
|
|
|
|
|
|
|
void RemoveCircle(const session::KernelGraph &kernel_graph, |
|
|
|
std::unordered_map<int64_t, BufferFusionInfo_t> *buffer_fusion_infos) { |
|
|
|
MS_EXCEPTION_IF_NULL(buffer_fusion_infos); |
|
|
|
std::vector<int64_t> fusion_ids; |
|
|
|
for (auto &[fusion_id, fusion_info] : *buffer_fusion_infos) { |
|
|
|
bool has_circle = false; |
|
|
|
for (auto &inp : fusion_info.inputs_list) { |
|
|
|
MS_EXCEPTION_IF_NULL(inp); |
|
|
|
if (!inp->isa<CNode>() || AnfAlgo::CheckPrimitiveType(inp, prim::kPrimLoad)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
if (IsDepend(kernel_graph, inp, fusion_info.anf_nodes)) { |
|
|
|
has_circle = true; |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
bool has_circle = CheckCircle(kernel_graph, fusion_info); |
|
|
|
if (has_circle) { |
|
|
|
fusion_ids.emplace_back(fusion_id); |
|
|
|
} |
|
|
|
@@ -435,7 +439,11 @@ bool UbPatternFusion::FuseBufferFusionPattern(session::KernelGraph *kernel_graph |
|
|
|
MS_LOG(DEBUG) << "fusion id: " << fusion_id << ", fusion op compiling failed"; |
|
|
|
continue; |
|
|
|
} |
|
|
|
change = ReplaceFusionOp(&buffer_fusion_infos, fusion_id, kernel_mods[fusion_id], kernel_graph); |
|
|
|
if (CheckCircle(*kernel_graph, buffer_fusion_infos[fusion_id])) { |
|
|
|
MS_LOG(DEBUG) << "fusion id: " << fusion_id << " will cause graph circle, pass this fusion."; |
|
|
|
} else { |
|
|
|
change = ReplaceFusionOp(&buffer_fusion_infos, fusion_id, kernel_mods[fusion_id], kernel_graph); |
|
|
|
} |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "End Buffer Fusion"; |
|
|
|
return change; |
|
|
|
|