|
|
|
@@ -16,12 +16,30 @@ |
|
|
|
|
|
|
|
#include "kernel/hccl/hccl_kernel_metadata.h" |
|
|
|
#include <memory> |
|
|
|
#include <set> |
|
|
|
#include "utils/utils.h" |
|
|
|
#include "kernel/hccl/hcom_util.h" |
|
|
|
#include "session/anf_runtime_algorithm.h" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace kernel { |
|
|
|
namespace { |
|
|
|
std::string GetKernelFormat(const CNodePtr &kernel_node, size_t index) { |
|
|
|
const std::set<std::string> kReduceNoSupportedSet = {kOpFormat_FRAC_Z, kOpFormat_FRACTAL_Z_C04, kOpFormat_C1HWNCoC0}; |
|
|
|
auto op_name = AnfAlgo::GetCNodeName(kernel_node); |
|
|
|
auto format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, index); |
|
|
|
if (op_name != kReduceScatter && op_name != kAllGatherOpName) { |
|
|
|
return format; |
|
|
|
} |
|
|
|
if (format == kOpFormat_FRAC_NZ && AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index).size() <= 2) { |
|
|
|
return kOpFormat_DEFAULT; |
|
|
|
} |
|
|
|
if (kReduceNoSupportedSet.find(format) != kReduceNoSupportedSet.end()) { |
|
|
|
return kOpFormat_DEFAULT; |
|
|
|
} |
|
|
|
return format; |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) { |
|
|
|
const std::vector<TypeId> kHcclSupportTypes = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16, |
|
|
|
kNumberTypeFloat32, kNumberTypeInt16}; |
|
|
|
@@ -36,13 +54,13 @@ void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<K |
|
|
|
std::vector<std::string> inputs_format{}; |
|
|
|
std::vector<TypeId> inputs_type{}; |
|
|
|
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { |
|
|
|
inputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)); |
|
|
|
inputs_format.emplace_back(GetKernelFormat(kernel_node, input_index)); |
|
|
|
inputs_type.push_back(type); |
|
|
|
} |
|
|
|
std::vector<std::string> outputs_format; |
|
|
|
std::vector<TypeId> outputs_type; |
|
|
|
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { |
|
|
|
outputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, output_index)); |
|
|
|
outputs_format.emplace_back(GetKernelFormat(kernel_node, output_index)); |
|
|
|
outputs_type.push_back(type); |
|
|
|
} |
|
|
|
auto builder = KernelBuildInfo::KernelBuildInfoBuilder(); |
|
|
|
|