You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

tbe_convert_utils.cc 4.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include "kernel/tbe/tbe_convert_utils.h"
#include <cstdint>
#include <map>
#include <string>
#include <unordered_map>
#include "session/anf_runtime_algorithm.h"
#include "common/utils.h"
  22. namespace mindspore {
  23. namespace kernel {
  24. namespace tbe {
  25. const std::unordered_map<std::string, TypeId> type_str_id_maps = {
  26. {"float", TypeId::kNumberTypeFloat32}, {"float16", TypeId::kNumberTypeFloat16},
  27. {"float32", TypeId::kNumberTypeFloat32}, {"float64", TypeId::kNumberTypeFloat64},
  28. {"int", TypeId::kNumberTypeInt}, {"int8", TypeId::kNumberTypeInt8},
  29. {"int16", TypeId::kNumberTypeInt16}, {"int32", TypeId::kNumberTypeInt32},
  30. {"int64", TypeId::kNumberTypeInt64}, {"uint", TypeId::kNumberTypeUInt},
  31. {"uint8", TypeId::kNumberTypeUInt8}, {"uint16", TypeId::kNumberTypeUInt16},
  32. {"uint32", TypeId::kNumberTypeUInt32}, {"uint64", TypeId::kNumberTypeUInt64},
  33. {"bool", TypeId::kNumberTypeBool},
  34. };
  35. const std::map<TypeId, std::string> type_id_str_maps = {
  36. {TypeId::kNumberTypeFloat32, "float32"}, {TypeId::kNumberTypeFloat16, "float16"},
  37. {TypeId::kNumberTypeFloat, "float"}, {TypeId::kNumberTypeFloat64, "float64"},
  38. {TypeId::kNumberTypeInt, "int"}, {TypeId::kNumberTypeInt8, "int8"},
  39. {TypeId::kNumberTypeInt16, "int16"}, {TypeId::kNumberTypeInt32, "int32"},
  40. {TypeId::kNumberTypeInt64, "int64"}, {TypeId::kNumberTypeUInt, "uint"},
  41. {TypeId::kNumberTypeUInt8, "uint8"}, {TypeId::kNumberTypeUInt16, "uint16"},
  42. {TypeId::kNumberTypeUInt32, "uint32"}, {TypeId::kNumberTypeUInt64, "uint64"},
  43. {TypeId::kNumberTypeBool, "bool"},
  44. };
  // Capitalised type name -> lowercase dtype name, consulted by DtypeToString().
  // "Bool_" is deliberately mapped to "int8", not "bool" (presumably the
  // representation expected on the TBE side — confirm before changing).
  const std::map<std::string, std::string> type_str_maps = {
    // floating point
    {"Float16", "float16"},
    {"Float32", "float32"},
    {"Float64", "float64"},
    // signed integers
    {"Int8", "int8"},
    {"Int16", "int16"},
    {"Int32", "int32"},
    {"Int64", "int64"},
    // unsigned integers
    {"UInt8", "uint8"},
    {"UInt16", "uint16"},
    {"UInt32", "uint32"},
    {"UInt64", "uint64"},
    // boolean
    {"Bool_", "int8"},
  };
  // The float widths below rely on the host using IEEE single/double precision.
  static_assert(sizeof(float) == 4, "float32 must map to a 4-byte host float");
  static_assert(sizeof(double) == 8, "float64 must map to an 8-byte host double");
  // Byte width of each dtype name, consulted by GetDtypeNbyte().
  // Widths are spelled with fixed-width types so they describe the dtype itself
  // instead of being derived from the host `int` (e.g. `sizeof(int) / 4` for int8
  // is only 1 where int happens to be 4 bytes).
  const std::unordered_map<std::string, size_t> type_nbyte_maps = {
    // floating point; there is no native 16-bit float type in C++17, so float16 is spelled out
    {"float16", 2},
    {"float32", sizeof(float)},
    {"float64", sizeof(double)},
    // signed integers
    {"int8", sizeof(std::int8_t)},
    {"int16", sizeof(std::int16_t)},
    {"int32", sizeof(std::int32_t)},
    {"int64", sizeof(std::int64_t)},
    // unsigned integers
    {"uint8", sizeof(std::uint8_t)},
    {"uint16", sizeof(std::uint16_t)},
    {"uint32", sizeof(std::uint32_t)},
    {"uint64", sizeof(std::uint64_t)},
    // boolean, stored as one byte
    {"bool", sizeof(char)},
  };
  56. const std::unordered_map<std::string, FusionType> fusion_type_maps = {
  57. {"CONVLUTION", FusionType::CONVLUTION}, {"ELEMWISE", FusionType::ELEMWISE}, {"COMMREDUCE", FusionType::COMMREDUCE},
  58. {"SEGMENT", FusionType::SEGMENT}, {"OPAQUE", FusionType::OPAQUE},
  59. };
  60. TypeId DtypeToTypeId(const std::string &dtypes) {
  61. auto iter = type_str_id_maps.find(dtypes);
  62. if (iter == type_str_id_maps.end()) {
  63. MS_LOG(EXCEPTION) << "Illegal input device dtype: " << dtypes;
  64. }
  65. return iter->second;
  66. }
  67. std::string DtypeToString(const std::string &dtypes) {
  68. auto iter = type_str_maps.find(dtypes);
  69. if (iter == type_str_maps.end()) {
  70. MS_LOG(EXCEPTION) << "Illegal input dtype: " << dtypes;
  71. }
  72. return iter->second;
  73. }
  74. std::string TypeIdToString(TypeId type_id) {
  75. auto iter = type_id_str_maps.find(type_id);
  76. if (iter == type_id_str_maps.end()) {
  77. MS_LOG(EXCEPTION) << "Illegal input dtype." << TypeIdLabel(type_id);
  78. }
  79. return iter->second;
  80. }
  81. size_t GetDtypeNbyte(const std::string &dtypes) {
  82. auto iter = type_nbyte_maps.find(dtypes);
  83. if (iter == type_nbyte_maps.end()) {
  84. MS_LOG(EXCEPTION) << "Illegal input dtype: " << dtypes;
  85. }
  86. return iter->second;
  87. }
  88. FusionType GetFusionType(const std::string &pattern) {
  89. auto iter = fusion_type_maps.find(pattern);
  90. if (iter == fusion_type_maps.end()) {
  91. MS_LOG(DEBUG) << "Illegal fusion pattern: " << pattern;
  92. return UNKNOWN_FUSION_TYPE;
  93. }
  94. return iter->second;
  95. }
  96. std::string GetProcessor(const AnfNodePtr &anf_node) {
  97. MS_EXCEPTION_IF_NULL(anf_node);
  98. std::string device;
  99. switch (AnfAlgo::GetProcessor(anf_node)) {
  100. case Processor::AICORE:
  101. device = kProcessorAiCore;
  102. break;
  103. default:
  104. MS_LOG(DEBUG) << "Unknown processor type." << anf_node->fullname_with_scope();
  105. break;
  106. }
  107. return device;
  108. }
  109. } // namespace tbe
  110. } // namespace kernel
  111. } // namespace mindspore