From: @liu_xiao_93 Reviewed-by: @liangchenghui,@wuxuejian Signed-off-by: @liangchenghuitags/v1.2.0-rc1
| @@ -34,7 +34,7 @@ const std::unordered_map<std::string, TypeId> type_str_id_maps = { | |||||
| {"int64", TypeId::kNumberTypeInt64}, {"uint", TypeId::kNumberTypeUInt}, | {"int64", TypeId::kNumberTypeInt64}, {"uint", TypeId::kNumberTypeUInt}, | ||||
| {"uint8", TypeId::kNumberTypeUInt8}, {"uint16", TypeId::kNumberTypeUInt16}, | {"uint8", TypeId::kNumberTypeUInt8}, {"uint16", TypeId::kNumberTypeUInt16}, | ||||
| {"uint32", TypeId::kNumberTypeUInt32}, {"uint64", TypeId::kNumberTypeUInt64}, | {"uint32", TypeId::kNumberTypeUInt32}, {"uint64", TypeId::kNumberTypeUInt64}, | ||||
| {"bool", TypeId::kNumberTypeBool}, | |||||
| {"bool", TypeId::kNumberTypeBool}, {"", TypeId::kMetaTypeNone}, | |||||
| }; | }; | ||||
| const std::map<TypeId, std::string> type_id_str_maps = { | const std::map<TypeId, std::string> type_id_str_maps = { | ||||
| @@ -45,7 +45,7 @@ const std::map<TypeId, std::string> type_id_str_maps = { | |||||
| {TypeId::kNumberTypeInt64, "int64"}, {TypeId::kNumberTypeUInt, "uint"}, | {TypeId::kNumberTypeInt64, "int64"}, {TypeId::kNumberTypeUInt, "uint"}, | ||||
| {TypeId::kNumberTypeUInt8, "uint8"}, {TypeId::kNumberTypeUInt16, "uint16"}, | {TypeId::kNumberTypeUInt8, "uint8"}, {TypeId::kNumberTypeUInt16, "uint16"}, | ||||
| {TypeId::kNumberTypeUInt32, "uint32"}, {TypeId::kNumberTypeUInt64, "uint64"}, | {TypeId::kNumberTypeUInt32, "uint32"}, {TypeId::kNumberTypeUInt64, "uint64"}, | ||||
| {TypeId::kNumberTypeBool, "int8"}, | |||||
| {TypeId::kNumberTypeBool, "int8"}, {TypeId::kMetaTypeNone, ""}, | |||||
| }; | }; | ||||
| const std::map<std::string, std::string> type_str_maps = { | const std::map<std::string, std::string> type_str_maps = { | ||||
| @@ -60,12 +60,8 @@ const std::map<std::string, std::vector<std::string>> kNextOpFormatList = { | |||||
| bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) { | bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) { | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| // Check input data type | // Check input data type | ||||
| auto name = AnfAlgo::GetCNodeName(cnode); | |||||
| for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) { | for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) { | ||||
| TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index); | TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index); | ||||
| if ((name == kDynamicRNNOpName || name == kDynamicGRUV2OpName) && input_origin_type == kMetaTypeNone) { | |||||
| continue; | |||||
| } | |||||
| if (kernel_build_info.GetInputDeviceType(input_index) != input_origin_type) { | if (kernel_build_info.GetInputDeviceType(input_index) != input_origin_type) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -46,6 +46,38 @@ dynamic_gru_v2_op_info = TBERegOp("DynamicGRUV2") \ | |||||
| .output(3, "reset", False, "optional", "all") \ | .output(3, "reset", False, "optional", "all") \ | ||||
| .output(4, "new", False, "optional", "all") \ | .output(4, "new", False, "optional", "all") \ | ||||
| .output(5, "hidden_new", False, "optional", "all") \ | .output(5, "hidden_new", False, "optional", "all") \ | ||||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F32_Default, | |||||
| DataType.F32_Default, DataType.None_Default, DataType.F32_FracNZ, DataType.F32_FracNZ, | |||||
| DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, | |||||
| DataType.F32_FracNZ) \ | |||||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F32_Default, | |||||
| DataType.None_Default, DataType.None_Default, DataType.F32_FracNZ, DataType.F32_FracNZ, | |||||
| DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, | |||||
| DataType.F32_FracNZ) \ | |||||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.None_Default, | |||||
| DataType.F32_Default, DataType.None_Default, DataType.F32_FracNZ, DataType.F32_FracNZ, | |||||
| DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, | |||||
| DataType.F32_FracNZ) \ | |||||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.None_Default, | |||||
| DataType.None_Default, DataType.None_Default, DataType.F32_FracNZ, DataType.F32_FracNZ, | |||||
| DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, | |||||
| DataType.F32_FracNZ) \ | |||||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, | |||||
| DataType.F16_Default, DataType.None_Default, DataType.F16_FracNZ, DataType.F16_FracNZ, | |||||
| DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, | |||||
| DataType.F16_FracNZ) \ | |||||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, | |||||
| DataType.None_Default, DataType.None_Default, DataType.F16_FracNZ, DataType.F16_FracNZ, | |||||
| DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, | |||||
| DataType.F16_FracNZ) \ | |||||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.None_Default, | |||||
| DataType.F16_Default, DataType.None_Default, DataType.F16_FracNZ, DataType.F16_FracNZ, | |||||
| DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, | |||||
| DataType.F16_FracNZ) \ | |||||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.None_Default, | |||||
| DataType.None_Default, DataType.None_Default, DataType.F16_FracNZ, DataType.F16_FracNZ, | |||||
| DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, | |||||
| DataType.F16_FracNZ) \ | |||||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F32_Default, | .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F32_Default, | ||||
| DataType.F32_Default, DataType.I32_Default, DataType.F32_FracNZ, DataType.F32_FracNZ, | DataType.F32_Default, DataType.I32_Default, DataType.F32_FracNZ, DataType.F32_FracNZ, | ||||
| DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, | DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, | ||||
| @@ -52,6 +52,16 @@ dynamic_rnn_op_info = TBERegOp("DynamicRNN") \ | |||||
| .output(5, "f", False, "required", "all") \ | .output(5, "f", False, "required", "all") \ | ||||
| .output(6, "o", False, "required", "all") \ | .output(6, "o", False, "required", "all") \ | ||||
| .output(7, "tanhc", False, "required", "all") \ | .output(7, "tanhc", False, "required", "all") \ | ||||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNLSTM, DataType.F32_Default, DataType.None_Default, | |||||
| DataType.F16_FracNZ, DataType.F32_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, | |||||
| DataType.F16_FracNZ, DataType.U8_Default, DataType.F32_FracNZ, DataType.F16_FracNZ, | |||||
| DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, | |||||
| DataType.F32_FracNZ, DataType.F32_FracNZ) \ | |||||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNLSTM, DataType.F16_Default, DataType.None_Default, | |||||
| DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, | |||||
| DataType.F16_FracNZ, DataType.U8_Default, DataType.F16_FracNZ, DataType.F16_FracNZ, | |||||
| DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, | |||||
| DataType.F16_FracNZ, DataType.F16_FracNZ) \ | |||||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNLSTM, DataType.F32_Default, DataType.I32_Default, | .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNLSTM, DataType.F32_Default, DataType.I32_Default, | ||||
| DataType.F16_FracNZ, DataType.F32_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, | DataType.F16_FracNZ, DataType.F32_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, | ||||
| DataType.F16_FracNZ, DataType.U8_Default, DataType.F32_FracNZ, DataType.F16_FracNZ, | DataType.F16_FracNZ, DataType.U8_Default, DataType.F32_FracNZ, DataType.F16_FracNZ, | ||||
| @@ -522,6 +522,7 @@ class DataType: | |||||
| """ | """ | ||||
| None_None = ("", "") | None_None = ("", "") | ||||
| None_Default = ("", "DefaultFormat") | |||||
| BOOL_None = ("bool", "") | BOOL_None = ("bool", "") | ||||
| BOOL_Default = ("bool", "DefaultFormat") | BOOL_Default = ("bool", "DefaultFormat") | ||||
| BOOL_5HD = ("bool", "NC1HWC0") | BOOL_5HD = ("bool", "NC1HWC0") | ||||