| @@ -34,7 +34,7 @@ const std::unordered_map<std::string, TypeId> type_str_id_maps = { | |||
| {"int64", TypeId::kNumberTypeInt64}, {"uint", TypeId::kNumberTypeUInt}, | |||
| {"uint8", TypeId::kNumberTypeUInt8}, {"uint16", TypeId::kNumberTypeUInt16}, | |||
| {"uint32", TypeId::kNumberTypeUInt32}, {"uint64", TypeId::kNumberTypeUInt64}, | |||
| {"bool", TypeId::kNumberTypeBool}, | |||
| {"bool", TypeId::kNumberTypeBool}, {"", TypeId::kMetaTypeNone}, | |||
| }; | |||
| const std::map<TypeId, std::string> type_id_str_maps = { | |||
| @@ -45,7 +45,7 @@ const std::map<TypeId, std::string> type_id_str_maps = { | |||
| {TypeId::kNumberTypeInt64, "int64"}, {TypeId::kNumberTypeUInt, "uint"}, | |||
| {TypeId::kNumberTypeUInt8, "uint8"}, {TypeId::kNumberTypeUInt16, "uint16"}, | |||
| {TypeId::kNumberTypeUInt32, "uint32"}, {TypeId::kNumberTypeUInt64, "uint64"}, | |||
| {TypeId::kNumberTypeBool, "int8"}, | |||
| {TypeId::kNumberTypeBool, "int8"}, {TypeId::kMetaTypeNone, ""}, | |||
| }; | |||
| const std::map<std::string, std::string> type_str_maps = { | |||
| @@ -60,12 +60,8 @@ const std::map<std::string, std::vector<std::string>> kNextOpFormatList = { | |||
| bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| // Check input data type | |||
| auto name = AnfAlgo::GetCNodeName(cnode); | |||
| for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) { | |||
| TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index); | |||
| if ((name == kDynamicRNNOpName || name == kDynamicGRUV2OpName) && input_origin_type == kMetaTypeNone) { | |||
| continue; | |||
| } | |||
| if (kernel_build_info.GetInputDeviceType(input_index) != input_origin_type) { | |||
| return false; | |||
| } | |||
| @@ -46,6 +46,38 @@ dynamic_gru_v2_op_info = TBERegOp("DynamicGRUV2") \ | |||
| .output(3, "reset", False, "optional", "all") \ | |||
| .output(4, "new", False, "optional", "all") \ | |||
| .output(5, "hidden_new", False, "optional", "all") \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.None_Default, DataType.F32_FracNZ, DataType.F32_FracNZ, | |||
| DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, | |||
| DataType.F32_FracNZ) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F32_Default, | |||
| DataType.None_Default, DataType.None_Default, DataType.F32_FracNZ, DataType.F32_FracNZ, | |||
| DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, | |||
| DataType.F32_FracNZ) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.None_Default, | |||
| DataType.F32_Default, DataType.None_Default, DataType.F32_FracNZ, DataType.F32_FracNZ, | |||
| DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, | |||
| DataType.F32_FracNZ) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.None_Default, | |||
| DataType.None_Default, DataType.None_Default, DataType.F32_FracNZ, DataType.F32_FracNZ, | |||
| DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, | |||
| DataType.F32_FracNZ) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.None_Default, DataType.F16_FracNZ, DataType.F16_FracNZ, | |||
| DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, | |||
| DataType.F16_FracNZ) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, | |||
| DataType.None_Default, DataType.None_Default, DataType.F16_FracNZ, DataType.F16_FracNZ, | |||
| DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, | |||
| DataType.F16_FracNZ) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.None_Default, | |||
| DataType.F16_Default, DataType.None_Default, DataType.F16_FracNZ, DataType.F16_FracNZ, | |||
| DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, | |||
| DataType.F16_FracNZ) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.None_Default, | |||
| DataType.None_Default, DataType.None_Default, DataType.F16_FracNZ, DataType.F16_FracNZ, | |||
| DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, | |||
| DataType.F16_FracNZ) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.I32_Default, DataType.F32_FracNZ, DataType.F32_FracNZ, | |||
| DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, | |||
| @@ -52,6 +52,16 @@ dynamic_rnn_op_info = TBERegOp("DynamicRNN") \ | |||
| .output(5, "f", False, "required", "all") \ | |||
| .output(6, "o", False, "required", "all") \ | |||
| .output(7, "tanhc", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNLSTM, DataType.F32_Default, DataType.None_Default, | |||
| DataType.F16_FracNZ, DataType.F32_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, | |||
| DataType.F16_FracNZ, DataType.U8_Default, DataType.F32_FracNZ, DataType.F16_FracNZ, | |||
| DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, | |||
| DataType.F32_FracNZ, DataType.F32_FracNZ) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNLSTM, DataType.F16_Default, DataType.None_Default, | |||
| DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, | |||
| DataType.F16_FracNZ, DataType.U8_Default, DataType.F16_FracNZ, DataType.F16_FracNZ, | |||
| DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, | |||
| DataType.F16_FracNZ, DataType.F16_FracNZ) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNLSTM, DataType.F32_Default, DataType.I32_Default, | |||
| DataType.F16_FracNZ, DataType.F32_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, | |||
| DataType.F16_FracNZ, DataType.U8_Default, DataType.F32_FracNZ, DataType.F16_FracNZ, | |||
| @@ -522,6 +522,7 @@ class DataType: | |||
| """ | |||
| None_None = ("", "") | |||
| None_Default = ("", "DefaultFormat") | |||
| BOOL_None = ("bool", "") | |||
| BOOL_Default = ("bool", "DefaultFormat") | |||
| BOOL_5HD = ("bool", "NC1HWC0") | |||