From b0a66cc3a250dd84e56635cac85b6b03a0ab8884 Mon Sep 17 00:00:00 2001 From: liuxiao93 Date: Wed, 13 Jan 2021 15:06:09 +0800 Subject: [PATCH] fix unsupported transnode KMetaTypeNonexDefaultFormat of DynamicRNN and DynamiGRU. --- .../kernel_compiler/tbe/tbe_convert_utils.cc | 4 +-- .../device/ascend/kernel_select_ascend.cc | 4 --- mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py | 32 +++++++++++++++++++ mindspore/ops/_op_impl/tbe/dynamic_rnn.py | 10 ++++++ mindspore/ops/op_info_register.py | 1 + 5 files changed, 45 insertions(+), 6 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_convert_utils.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_convert_utils.cc index 806b06d6f2..9f91e7ff82 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_convert_utils.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_convert_utils.cc @@ -34,7 +34,7 @@ const std::unordered_map type_str_id_maps = { {"int64", TypeId::kNumberTypeInt64}, {"uint", TypeId::kNumberTypeUInt}, {"uint8", TypeId::kNumberTypeUInt8}, {"uint16", TypeId::kNumberTypeUInt16}, {"uint32", TypeId::kNumberTypeUInt32}, {"uint64", TypeId::kNumberTypeUInt64}, - {"bool", TypeId::kNumberTypeBool}, + {"bool", TypeId::kNumberTypeBool}, {"", TypeId::kMetaTypeNone}, }; const std::map type_id_str_maps = { @@ -45,7 +45,7 @@ const std::map type_id_str_maps = { {TypeId::kNumberTypeInt64, "int64"}, {TypeId::kNumberTypeUInt, "uint"}, {TypeId::kNumberTypeUInt8, "uint8"}, {TypeId::kNumberTypeUInt16, "uint16"}, {TypeId::kNumberTypeUInt32, "uint32"}, {TypeId::kNumberTypeUInt64, "uint64"}, - {TypeId::kNumberTypeBool, "int8"}, + {TypeId::kNumberTypeBool, "int8"}, {TypeId::kMetaTypeNone, ""}, }; const std::map type_str_maps = { diff --git a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc index 100841e6e5..a07fa0cc04 100644 --- a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc +++ b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc @@ -60,12 +60,8 @@ const std::map> kNextOpFormatList = { bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) { MS_EXCEPTION_IF_NULL(cnode); // Check input data type - auto name = AnfAlgo::GetCNodeName(cnode); for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) { TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index); - if ((name == kDynamicRNNOpName || name == kDynamicGRUV2OpName) && input_origin_type == kMetaTypeNone) { - continue; - } if (kernel_build_info.GetInputDeviceType(input_index) != input_origin_type) { return false; } diff --git a/mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py b/mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py index 8ef089e04a..1b61b9fdf2 100644 --- a/mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py +++ b/mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py @@ -46,6 +46,38 @@ dynamic_gru_v2_op_info = TBERegOp("DynamicGRUV2") \ .output(3, "reset", False, "optional", "all") \ .output(4, "new", False, "optional", "all") \ .output(5, "hidden_new", False, "optional", "all") \ + .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F32_Default, + DataType.F32_Default, DataType.None_Default, DataType.F32_FracNZ, DataType.F32_FracNZ, + DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, + DataType.F32_FracNZ) \ + .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F32_Default, + DataType.None_Default, DataType.None_Default, DataType.F32_FracNZ, DataType.F32_FracNZ, + DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, + DataType.F32_FracNZ) \ + .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.None_Default, + DataType.F32_Default, DataType.None_Default, DataType.F32_FracNZ, DataType.F32_FracNZ, + DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, + DataType.F32_FracNZ) \ + .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.None_Default, + DataType.None_Default, DataType.None_Default, DataType.F32_FracNZ, DataType.F32_FracNZ, + DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, + DataType.F32_FracNZ) \ + .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, + DataType.F16_Default, DataType.None_Default, DataType.F16_FracNZ, DataType.F16_FracNZ, + DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, + DataType.F16_FracNZ) \ + .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, + DataType.None_Default, DataType.None_Default, DataType.F16_FracNZ, DataType.F16_FracNZ, + DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, + DataType.F16_FracNZ) \ + .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.None_Default, + DataType.F16_Default, DataType.None_Default, DataType.F16_FracNZ, DataType.F16_FracNZ, + DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, + DataType.F16_FracNZ) \ + .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.None_Default, + DataType.None_Default, DataType.None_Default, DataType.F16_FracNZ, DataType.F16_FracNZ, + DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, + DataType.F16_FracNZ) \ .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, diff --git a/mindspore/ops/_op_impl/tbe/dynamic_rnn.py b/mindspore/ops/_op_impl/tbe/dynamic_rnn.py index 7d2a06b74b..da420657b3 100644 --- a/mindspore/ops/_op_impl/tbe/dynamic_rnn.py +++ b/mindspore/ops/_op_impl/tbe/dynamic_rnn.py @@ -52,6 +52,16 @@ dynamic_rnn_op_info = TBERegOp("DynamicRNN") \ .output(5, "f", False, "required", "all") \ .output(6, "o", False, "required", "all") \ .output(7, "tanhc", False, "required", "all") \ + .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNLSTM, DataType.F32_Default, DataType.None_Default, + DataType.F16_FracNZ, DataType.F32_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, + DataType.F16_FracNZ, DataType.U8_Default, DataType.F32_FracNZ, DataType.F16_FracNZ, + DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, + DataType.F32_FracNZ, DataType.F32_FracNZ) \ + .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNLSTM, DataType.F16_Default, DataType.None_Default, + DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, + DataType.F16_FracNZ, DataType.U8_Default, DataType.F16_FracNZ, DataType.F16_FracNZ, + DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, + DataType.F16_FracNZ, DataType.F16_FracNZ) \ .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNLSTM, DataType.F32_Default, DataType.I32_Default, DataType.F16_FracNZ, DataType.F32_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.U8_Default, DataType.F32_FracNZ, DataType.F16_FracNZ, diff --git a/mindspore/ops/op_info_register.py b/mindspore/ops/op_info_register.py index 907a74523d..88b9410503 100644 --- a/mindspore/ops/op_info_register.py +++ b/mindspore/ops/op_info_register.py @@ -522,6 +522,7 @@ class DataType: """ None_None = ("", "") + None_Default = ("", "DefaultFormat") BOOL_None = ("bool", "") BOOL_Default = ("bool", "DefaultFormat") BOOL_5HD = ("bool", "NC1HWC0")