/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "kernel/tbe/tbe_convert_utils.h" #include #include #include #include "session/anf_runtime_algorithm.h" #include "common/utils.h" namespace mindspore { namespace kernel { namespace tbe { const std::unordered_map type_str_id_maps = { {"float", TypeId::kNumberTypeFloat32}, {"float16", TypeId::kNumberTypeFloat16}, {"float32", TypeId::kNumberTypeFloat32}, {"float64", TypeId::kNumberTypeFloat64}, {"int", TypeId::kNumberTypeInt}, {"int8", TypeId::kNumberTypeInt8}, {"int16", TypeId::kNumberTypeInt16}, {"int32", TypeId::kNumberTypeInt32}, {"int64", TypeId::kNumberTypeInt64}, {"uint", TypeId::kNumberTypeUInt}, {"uint8", TypeId::kNumberTypeUInt8}, {"uint16", TypeId::kNumberTypeUInt16}, {"uint32", TypeId::kNumberTypeUInt32}, {"uint64", TypeId::kNumberTypeUInt64}, {"bool", TypeId::kNumberTypeBool}, }; const std::map type_id_str_maps = { {TypeId::kNumberTypeFloat32, "float32"}, {TypeId::kNumberTypeFloat16, "float16"}, {TypeId::kNumberTypeFloat, "float"}, {TypeId::kNumberTypeFloat64, "float64"}, {TypeId::kNumberTypeInt, "int"}, {TypeId::kNumberTypeInt8, "int8"}, {TypeId::kNumberTypeInt16, "int16"}, {TypeId::kNumberTypeInt32, "int32"}, {TypeId::kNumberTypeInt64, "int64"}, {TypeId::kNumberTypeUInt, "uint"}, {TypeId::kNumberTypeUInt8, "uint8"}, {TypeId::kNumberTypeUInt16, "uint16"}, {TypeId::kNumberTypeUInt32, "uint32"}, {TypeId::kNumberTypeUInt64, "uint64"}, {TypeId::kNumberTypeBool, "bool"}, }; const std::map type_str_maps = { {"Float32", "float32"}, {"Float16", "float16"}, {"Int8", "int8"}, {"Int16", "int16"}, {"UInt16", "uint16"}, {"UInt8", "uint8"}, {"Int32", "int32"}, {"UInt32", "uint32"}, {"Int64", "int64"}, {"UInt64", "uint64"}, {"Bool_", "int8"}, {"Float64", "float64"}, }; const std::unordered_map type_nbyte_maps = { {"float16", sizeof(float) / 2}, {"float32", sizeof(float)}, {"float64", sizeof(float) * 2}, {"int8", sizeof(int) / 4}, {"int16", sizeof(int) / 2}, {"int32", sizeof(int)}, {"int64", sizeof(int) * 2}, {"uint8", sizeof(int) / 4}, {"uint16", sizeof(int) / 2}, {"uint32", sizeof(int)}, {"uint64", sizeof(int) * 2}, {"bool", sizeof(char)}, }; const std::unordered_map fusion_type_maps = { {"CONVLUTION", FusionType::CONVLUTION}, {"ELEMWISE", FusionType::ELEMWISE}, {"COMMREDUCE", FusionType::COMMREDUCE}, {"SEGMENT", FusionType::SEGMENT}, {"OPAQUE", FusionType::OPAQUE}, }; TypeId DtypeToTypeId(const std::string &dtypes) { auto iter = type_str_id_maps.find(dtypes); if (iter == type_str_id_maps.end()) { MS_LOG(EXCEPTION) << "Illegal input device dtype: " << dtypes; } return iter->second; } std::string DtypeToString(const std::string &dtypes) { auto iter = type_str_maps.find(dtypes); if (iter == type_str_maps.end()) { MS_LOG(EXCEPTION) << "Illegal input dtype: " << dtypes; } return iter->second; } std::string TypeIdToString(TypeId type_id) { auto iter = type_id_str_maps.find(type_id); if (iter == type_id_str_maps.end()) { MS_LOG(EXCEPTION) << "Illegal input dtype." << TypeIdLabel(type_id); } return iter->second; } size_t GetDtypeNbyte(const std::string &dtypes) { auto iter = type_nbyte_maps.find(dtypes); if (iter == type_nbyte_maps.end()) { MS_LOG(EXCEPTION) << "Illegal input dtype: " << dtypes; } return iter->second; } FusionType GetFusionType(const std::string &pattern) { auto iter = fusion_type_maps.find(pattern); if (iter == fusion_type_maps.end()) { MS_LOG(DEBUG) << "Illegal fusion pattern: " << pattern; return UNKNOWN_FUSION_TYPE; } return iter->second; } std::string GetProcessor(const AnfNodePtr &anf_node) { MS_EXCEPTION_IF_NULL(anf_node); std::string device; switch (AnfAlgo::GetProcessor(anf_node)) { case Processor::AICORE: device = kProcessorAiCore; break; default: MS_LOG(DEBUG) << "Unknown processor type." << anf_node->fullname_with_scope(); break; } return device; } } // namespace tbe } // namespace kernel } // namespace mindspore