You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

hccl_kernel_metadata.cc 2.3 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "kernel/hccl/hccl_kernel_metadata.h"
  17. #include <memory>
  18. #include "utils/utils.h"
  19. #include "kernel/hccl/hcom_util.h"
  20. #include "session/anf_runtime_algorithm.h"
  21. namespace mindspore {
  22. namespace kernel {
  23. void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
  24. MS_EXCEPTION_IF_NULL(kernel_info_list);
  25. MS_EXCEPTION_IF_NULL(kernel_node);
  26. std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
  27. if (op_name != kAllGather && op_name != kAllReduce && op_name != kBroadcast && op_name != kReduceScatter) {
  28. MS_LOG(DEBUG) << "Hccl does not have op [" << op_name << "]";
  29. return;
  30. }
  31. std::vector<TypeId> data_type_list{kNumberTypeFloat32, kNumberTypeFloat16, kNumberTypeInt8, kNumberTypeInt32};
  32. std::vector<std::string> input_format, output_format;
  33. std::vector<TypeId> input_type, output_type;
  34. for (const auto &data_type : data_type_list) {
  35. for (const auto &format : k4DSupportFormat) {
  36. auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
  37. input_format.clear();
  38. input_format.push_back(format);
  39. input_type.clear();
  40. input_type.push_back(data_type);
  41. output_format.clear();
  42. output_format.push_back(format);
  43. output_type.clear();
  44. output_type.push_back(data_type);
  45. builder->SetInputsFormat(input_format);
  46. builder->SetInputsDeviceType(input_type);
  47. builder->SetOutputsFormat(output_format);
  48. builder->SetOutputsDeviceType(output_type);
  49. builder->SetKernelType(HCCL_KERNEL);
  50. kernel_info_list->emplace_back(builder->Build());
  51. }
  52. }
  53. }
  54. } // namespace kernel
  55. } // namespace mindspore