diff --git a/mindspore/ccsrc/kernel/hccl/hccl_kernel_metadata.cc b/mindspore/ccsrc/kernel/hccl/hccl_kernel_metadata.cc index 6c101c92bb..f0a0dda258 100755 --- a/mindspore/ccsrc/kernel/hccl/hccl_kernel_metadata.cc +++ b/mindspore/ccsrc/kernel/hccl/hccl_kernel_metadata.cc @@ -31,29 +31,26 @@ void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector data_type_list{kNumberTypeFloat32, kNumberTypeFloat16, kNumberTypeInt8, kNumberTypeInt32}; - std::vector input_format, output_format; - std::vector input_type, output_type; - for (const auto &data_type : data_type_list) { - for (const auto &format : kOpFormatList) { - auto builder = std::make_shared(); - input_format.clear(); - input_format.push_back(format); - input_type.clear(); - input_type.push_back(data_type); - output_format.clear(); - output_format.push_back(format); - output_type.clear(); - output_type.push_back(data_type); + std::vector inputs_format{}; + std::vector inputs_type{}; + for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { + inputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)); + inputs_type.push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index)); + } - builder->SetInputsFormat(input_format); - builder->SetInputsDeviceType(input_type); - builder->SetOutputsFormat(output_format); - builder->SetOutputsDeviceType(output_type); - builder->SetKernelType(HCCL_KERNEL); - kernel_info_list->emplace_back(builder->Build()); - } + std::vector outputs_format; + std::vector outputs_type; + for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { + outputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, output_index)); + outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index)); } + auto builder = KernelBuildInfo::KernelBuildInfoBuilder(); + builder.SetInputsFormat(inputs_format); + builder.SetInputsDeviceType(inputs_type); + builder.SetOutputsFormat(outputs_format); + builder.SetOutputsDeviceType(outputs_type); + builder.SetKernelType(HCCL_KERNEL); + kernel_info_list->push_back(builder.Build()); } } // namespace kernel } // namespace mindspore