Browse Source

!11224 Fix unsupported transnode KMetaTypeNonexDefaultFormat.

From: @liu_xiao_93
Reviewed-by: @liangchenghui,@wuxuejian
Signed-off-by: @liangchenghui
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
728d66be53
5 changed files with 45 additions and 6 deletions
  1. +2
    -2
      mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_convert_utils.cc
  2. +0
    -4
      mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc
  3. +32
    -0
      mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py
  4. +10
    -0
      mindspore/ops/_op_impl/tbe/dynamic_rnn.py
  5. +1
    -0
      mindspore/ops/op_info_register.py

+ 2
- 2
mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_convert_utils.cc View File

@@ -34,7 +34,7 @@ const std::unordered_map<std::string, TypeId> type_str_id_maps = {
{"int64", TypeId::kNumberTypeInt64}, {"uint", TypeId::kNumberTypeUInt},
{"uint8", TypeId::kNumberTypeUInt8}, {"uint16", TypeId::kNumberTypeUInt16},
{"uint32", TypeId::kNumberTypeUInt32}, {"uint64", TypeId::kNumberTypeUInt64},
{"bool", TypeId::kNumberTypeBool},
{"bool", TypeId::kNumberTypeBool}, {"", TypeId::kMetaTypeNone},
};

const std::map<TypeId, std::string> type_id_str_maps = {
@@ -45,7 +45,7 @@ const std::map<TypeId, std::string> type_id_str_maps = {
{TypeId::kNumberTypeInt64, "int64"}, {TypeId::kNumberTypeUInt, "uint"},
{TypeId::kNumberTypeUInt8, "uint8"}, {TypeId::kNumberTypeUInt16, "uint16"},
{TypeId::kNumberTypeUInt32, "uint32"}, {TypeId::kNumberTypeUInt64, "uint64"},
{TypeId::kNumberTypeBool, "int8"},
{TypeId::kNumberTypeBool, "int8"}, {TypeId::kMetaTypeNone, ""},
};

const std::map<std::string, std::string> type_str_maps = {


+ 0
- 4
mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc View File

@@ -60,12 +60,8 @@ const std::map<std::string, std::vector<std::string>> kNextOpFormatList = {
bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) {
MS_EXCEPTION_IF_NULL(cnode);
// Check input data type
auto name = AnfAlgo::GetCNodeName(cnode);
for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) {
TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index);
if ((name == kDynamicRNNOpName || name == kDynamicGRUV2OpName) && input_origin_type == kMetaTypeNone) {
continue;
}
if (kernel_build_info.GetInputDeviceType(input_index) != input_origin_type) {
return false;
}


+ 32
- 0
mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py View File

@@ -46,6 +46,38 @@ dynamic_gru_v2_op_info = TBERegOp("DynamicGRUV2") \
.output(3, "reset", False, "optional", "all") \
.output(4, "new", False, "optional", "all") \
.output(5, "hidden_new", False, "optional", "all") \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F32_Default,
DataType.F32_Default, DataType.None_Default, DataType.F32_FracNZ, DataType.F32_FracNZ,
DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
DataType.F32_FracNZ) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F32_Default,
DataType.None_Default, DataType.None_Default, DataType.F32_FracNZ, DataType.F32_FracNZ,
DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
DataType.F32_FracNZ) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.None_Default,
DataType.F32_Default, DataType.None_Default, DataType.F32_FracNZ, DataType.F32_FracNZ,
DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
DataType.F32_FracNZ) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.None_Default,
DataType.None_Default, DataType.None_Default, DataType.F32_FracNZ, DataType.F32_FracNZ,
DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
DataType.F32_FracNZ) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default,
DataType.F16_Default, DataType.None_Default, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F16_FracNZ) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default,
DataType.None_Default, DataType.None_Default, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F16_FracNZ) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.None_Default,
DataType.F16_Default, DataType.None_Default, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F16_FracNZ) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.None_Default,
DataType.None_Default, DataType.None_Default, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F16_FracNZ) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F32_Default,
DataType.F32_Default, DataType.I32_Default, DataType.F32_FracNZ, DataType.F32_FracNZ,
DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,


+ 10
- 0
mindspore/ops/_op_impl/tbe/dynamic_rnn.py View File

@@ -52,6 +52,16 @@ dynamic_rnn_op_info = TBERegOp("DynamicRNN") \
.output(5, "f", False, "required", "all") \
.output(6, "o", False, "required", "all") \
.output(7, "tanhc", False, "required", "all") \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNLSTM, DataType.F32_Default, DataType.None_Default,
DataType.F16_FracNZ, DataType.F32_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F16_FracNZ, DataType.U8_Default, DataType.F32_FracNZ, DataType.F16_FracNZ,
DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
DataType.F32_FracNZ, DataType.F32_FracNZ) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNLSTM, DataType.F16_Default, DataType.None_Default,
DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F16_FracNZ, DataType.U8_Default, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F16_FracNZ, DataType.F16_FracNZ) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNLSTM, DataType.F32_Default, DataType.I32_Default,
DataType.F16_FracNZ, DataType.F32_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F16_FracNZ, DataType.U8_Default, DataType.F32_FracNZ, DataType.F16_FracNZ,


+ 1
- 0
mindspore/ops/op_info_register.py View File

@@ -522,6 +522,7 @@ class DataType:
"""

None_None = ("", "")
None_Default = ("", "DefaultFormat")
BOOL_None = ("bool", "")
BOOL_Default = ("bool", "DefaultFormat")
BOOL_5HD = ("bool", "NC1HWC0")


Loading…
Cancel
Save