|
|
|
@@ -270,17 +270,9 @@ kernel::KernelBuildInfoPtr CreateFusionOpKernelInfo(const std::vector<AnfNodePtr |
|
|
|
std::vector<std::string> inputs_format; |
|
|
|
std::vector<TypeId> inputs_data_type; |
|
|
|
for (const auto &input : inputs_list) { |
|
|
|
if (input->isa<CNode>() && AnfAlgo::GetCNodeName(input) == prim::kPrimTupleGetItem->name()) { |
|
|
|
auto tuple_getitem = input->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(tuple_getitem); |
|
|
|
inputs_format.push_back(AnfAlgo::GetOutputFormat( |
|
|
|
tuple_getitem->input(1), IntToSize(GetValue<int>(GetValueNode(tuple_getitem->input(2)))))); |
|
|
|
inputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType( |
|
|
|
tuple_getitem->input(1), IntToSize(GetValue<int>(GetValueNode(tuple_getitem->input(2)))))); |
|
|
|
} else { |
|
|
|
inputs_format.push_back(AnfAlgo::GetOutputFormat(input, 0)); |
|
|
|
inputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType(input, 0)); |
|
|
|
} |
|
|
|
auto real_input = AnfAlgo::VisitKernel(input, 0); |
|
|
|
inputs_format.push_back(AnfAlgo::GetOutputFormat(real_input.first, real_input.second)); |
|
|
|
inputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType(real_input.first, real_input.second)); |
|
|
|
} |
|
|
|
// outputs format and data type |
|
|
|
std::vector<std::string> outputs_format; |
|
|
|
@@ -375,11 +367,10 @@ void GetFusionScopeComputeNodeList(session::KernelGraph *kernel_graph, |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void GetFusionScopeInputNodeList(session::KernelGraph *kernel_graph, |
|
|
|
void GetFusionScopeInputNodeList(const session::KernelGraph &kernel_graph, |
|
|
|
std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) { |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph); |
|
|
|
MS_EXCEPTION_IF_NULL(buffer_fusion_infos); |
|
|
|
auto manager = kernel_graph->manager(); |
|
|
|
auto manager = kernel_graph.manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
|
|
|
|
for (auto &buffer_fusion_info : *buffer_fusion_infos) { |
|
|
|
@@ -643,7 +634,7 @@ void BufferFusion::GetBufferFusionInfo(session::KernelGraph *kernel_graph, |
|
|
|
std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) const { |
|
|
|
MS_EXCEPTION_IF_NULL(buffer_fusion_infos); |
|
|
|
GetFusionScopeComputeNodeList(kernel_graph, buffer_fusion_infos); |
|
|
|
GetFusionScopeInputNodeList(kernel_graph, buffer_fusion_infos); |
|
|
|
GetFusionScopeInputNodeList(*kernel_graph, buffer_fusion_infos); |
|
|
|
GetFusionScopeOutputNodeList(kernel_graph, buffer_fusion_infos); |
|
|
|
for (auto &buffer_fusion_info : *buffer_fusion_infos) { |
|
|
|
buffer_fusion_info.second.kernel_build_info = |
|
|
|
|