|
|
|
@@ -31,29 +31,26 @@ void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<K |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<TypeId> data_type_list{kNumberTypeFloat32, kNumberTypeFloat16, kNumberTypeInt8, kNumberTypeInt32}; |
|
|
|
std::vector<std::string> input_format, output_format; |
|
|
|
std::vector<TypeId> input_type, output_type; |
|
|
|
for (const auto &data_type : data_type_list) { |
|
|
|
for (const auto &format : kOpFormatList) { |
|
|
|
auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>(); |
|
|
|
input_format.clear(); |
|
|
|
input_format.push_back(format); |
|
|
|
input_type.clear(); |
|
|
|
input_type.push_back(data_type); |
|
|
|
output_format.clear(); |
|
|
|
output_format.push_back(format); |
|
|
|
output_type.clear(); |
|
|
|
output_type.push_back(data_type); |
|
|
|
std::vector<std::string> inputs_format{}; |
|
|
|
std::vector<TypeId> inputs_type{}; |
|
|
|
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { |
|
|
|
inputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)); |
|
|
|
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index)); |
|
|
|
} |
|
|
|
|
|
|
|
builder->SetInputsFormat(input_format); |
|
|
|
builder->SetInputsDeviceType(input_type); |
|
|
|
builder->SetOutputsFormat(output_format); |
|
|
|
builder->SetOutputsDeviceType(output_type); |
|
|
|
builder->SetKernelType(HCCL_KERNEL); |
|
|
|
kernel_info_list->emplace_back(builder->Build()); |
|
|
|
} |
|
|
|
std::vector<std::string> outputs_format; |
|
|
|
std::vector<TypeId> outputs_type; |
|
|
|
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { |
|
|
|
outputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, output_index)); |
|
|
|
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index)); |
|
|
|
} |
|
|
|
auto builder = KernelBuildInfo::KernelBuildInfoBuilder(); |
|
|
|
builder.SetInputsFormat(inputs_format); |
|
|
|
builder.SetInputsDeviceType(inputs_type); |
|
|
|
builder.SetOutputsFormat(outputs_format); |
|
|
|
builder.SetOutputsDeviceType(outputs_type); |
|
|
|
builder.SetKernelType(HCCL_KERNEL); |
|
|
|
kernel_info_list->push_back(builder.Build()); |
|
|
|
} |
|
|
|
} // namespace kernel |
|
|
|
} // namespace mindspore |