Browse Source

!293 Register all tbe op info with new registration mode

Merge pull request !293 from zjun/Modify_all_tbe_op
tags/v0.2.0-alpha
mindspore-ci-bot Gitee 6 years ago
parent
commit
66377e4c92
100 changed files with 2288 additions and 7782 deletions
  1. +19
    -62
      mindspore/ops/_op_impl/tbe/add.py
  2. +25
    -53
      mindspore/ops/_op_impl/tbe/add_n.py
  3. +58
    -206
      mindspore/ops/_op_impl/tbe/apply_adam.py
  4. +34
    -104
      mindspore/ops/_op_impl/tbe/apply_momentum.py
  5. +17
    -62
      mindspore/ops/_op_impl/tbe/arg_max_with_value.py
  6. +17
    -62
      mindspore/ops/_op_impl/tbe/arg_min_with_value.py
  7. +34
    -84
      mindspore/ops/_op_impl/tbe/assign.py
  8. +25
    -71
      mindspore/ops/_op_impl/tbe/assign_add.py
  9. +18
    -56
      mindspore/ops/_op_impl/tbe/assign_sub.py
  10. +12
    -23
      mindspore/ops/_op_impl/tbe/atomic_addr_clean.py
  11. +21
    -80
      mindspore/ops/_op_impl/tbe/batch_matmul.py
  12. +37
    -166
      mindspore/ops/_op_impl/tbe/batchnorm.py
  13. +37
    -173
      mindspore/ops/_op_impl/tbe/batchnorm_grad.py
  14. +18
    -62
      mindspore/ops/_op_impl/tbe/bias_add.py
  15. +18
    -49
      mindspore/ops/_op_impl/tbe/bias_add_grad.py
  16. +16
    -52
      mindspore/ops/_op_impl/tbe/bn_training_reduce.py
  17. +24
    -126
      mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py
  18. +32
    -192
      mindspore/ops/_op_impl/tbe/bn_training_update.py
  19. +22
    -101
      mindspore/ops/_op_impl/tbe/bn_training_update_grad.py
  20. +34
    -61
      mindspore/ops/_op_impl/tbe/cast.py
  21. +20
    -82
      mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum.py
  22. +22
    -77
      mindspore/ops/_op_impl/tbe/clip_by_value.py
  23. +36
    -133
      mindspore/ops/_op_impl/tbe/concat.py
  24. +19
    -56
      mindspore/ops/_op_impl/tbe/confusion_softmax_grad.py
  25. +36
    -71
      mindspore/ops/_op_impl/tbe/confusion_transpose_d.py
  26. +22
    -106
      mindspore/ops/_op_impl/tbe/conv2d.py
  27. +19
    -81
      mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py
  28. +19
    -80
      mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py
  29. +23
    -62
      mindspore/ops/_op_impl/tbe/div.py
  30. +17
    -68
      mindspore/ops/_op_impl/tbe/dropout_do_mask.py
  31. +23
    -57
      mindspore/ops/_op_impl/tbe/equal.py
  32. +16
    -43
      mindspore/ops/_op_impl/tbe/exp.py
  33. +17
    -49
      mindspore/ops/_op_impl/tbe/expand_dims.py
  34. +18
    -55
      mindspore/ops/_op_impl/tbe/floor_div.py
  35. +30
    -85
      mindspore/ops/_op_impl/tbe/fused_mul_add.py
  36. +22
    -77
      mindspore/ops/_op_impl/tbe/fused_mul_add_n.py
  37. +35
    -129
      mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum.py
  38. +45
    -86
      mindspore/ops/_op_impl/tbe/gather_v2.py
  39. +21
    -43
      mindspore/ops/_op_impl/tbe/gelu.py
  40. +21
    -69
      mindspore/ops/_op_impl/tbe/gelu_grad.py
  41. +23
    -59
      mindspore/ops/_op_impl/tbe/greater.py
  42. +38
    -271
      mindspore/ops/_op_impl/tbe/lamb_next_mv.py
  43. +38
    -271
      mindspore/ops/_op_impl/tbe/lamb_next_mv_with_decay_v1.py
  44. +27
    -166
      mindspore/ops/_op_impl/tbe/lamb_update_with_lr.py
  45. +23
    -136
      mindspore/ops/_op_impl/tbe/lamb_update_with_lr_v2.py
  46. +30
    -102
      mindspore/ops/_op_impl/tbe/layer_norm.py
  47. +30
    -97
      mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py
  48. +26
    -115
      mindspore/ops/_op_impl/tbe/layer_norm_grad.py
  49. +28
    -93
      mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py
  50. +23
    -58
      mindspore/ops/_op_impl/tbe/less.py
  51. +25
    -58
      mindspore/ops/_op_impl/tbe/less_equal.py
  52. +16
    -43
      mindspore/ops/_op_impl/tbe/log.py
  53. +17
    -56
      mindspore/ops/_op_impl/tbe/logical_and.py
  54. +16
    -43
      mindspore/ops/_op_impl/tbe/logical_not.py
  55. +17
    -56
      mindspore/ops/_op_impl/tbe/logical_or.py
  56. +16
    -49
      mindspore/ops/_op_impl/tbe/logsoftmax.py
  57. +17
    -62
      mindspore/ops/_op_impl/tbe/logsoftmax_grad.py
  58. +21
    -81
      mindspore/ops/_op_impl/tbe/matmul.py
  59. +18
    -66
      mindspore/ops/_op_impl/tbe/max_pool.py
  60. +19
    -85
      mindspore/ops/_op_impl/tbe/max_pool_grad.py
  61. +20
    -87
      mindspore/ops/_op_impl/tbe/max_pool_grad_with_argmax.py
  62. +18
    -74
      mindspore/ops/_op_impl/tbe/max_pool_with_argmax.py
  63. +20
    -61
      mindspore/ops/_op_impl/tbe/maximum.py
  64. +30
    -104
      mindspore/ops/_op_impl/tbe/maximum_grad.py
  65. +19
    -65
      mindspore/ops/_op_impl/tbe/minimum.py
  66. +30
    -104
      mindspore/ops/_op_impl/tbe/minimum_grad.py
  67. +28
    -68
      mindspore/ops/_op_impl/tbe/mul.py
  68. +20
    -42
      mindspore/ops/_op_impl/tbe/neg.py
  69. +12
    -30
      mindspore/ops/_op_impl/tbe/npu_alloc_float_status.py
  70. +13
    -43
      mindspore/ops/_op_impl/tbe/npu_clear_float_status.py
  71. +13
    -43
      mindspore/ops/_op_impl/tbe/npu_get_float_status.py
  72. +27
    -88
      mindspore/ops/_op_impl/tbe/one_hot.py
  73. +19
    -49
      mindspore/ops/_op_impl/tbe/pad_d.py
  74. +18
    -56
      mindspore/ops/_op_impl/tbe/pow.py
  75. +17
    -55
      mindspore/ops/_op_impl/tbe/real_div.py
  76. +18
    -43
      mindspore/ops/_op_impl/tbe/reciprocal.py
  77. +21
    -55
      mindspore/ops/_op_impl/tbe/reduce_max.py
  78. +19
    -55
      mindspore/ops/_op_impl/tbe/reduce_mean.py
  79. +19
    -55
      mindspore/ops/_op_impl/tbe/reduce_mean_d.py
  80. +23
    -55
      mindspore/ops/_op_impl/tbe/reduce_min.py
  81. +17
    -55
      mindspore/ops/_op_impl/tbe/reduce_sum.py
  82. +20
    -45
      mindspore/ops/_op_impl/tbe/relu.py
  83. +23
    -59
      mindspore/ops/_op_impl/tbe/relu_grad.py
  84. +17
    -49
      mindspore/ops/_op_impl/tbe/reshape.py
  85. +25
    -59
      mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py
  86. +20
    -55
      mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_d.py
  87. +16
    -55
      mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad_d.py
  88. +18
    -43
      mindspore/ops/_op_impl/tbe/round.py
  89. +21
    -86
      mindspore/ops/_op_impl/tbe/rsqrt.py
  90. +20
    -63
      mindspore/ops/_op_impl/tbe/scatter_nd.py
  91. +20
    -62
      mindspore/ops/_op_impl/tbe/scatter_nd_d.py
  92. +25
    -86
      mindspore/ops/_op_impl/tbe/select.py
  93. +23
    -59
      mindspore/ops/_op_impl/tbe/sigmoid.py
  94. +18
    -56
      mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits.py
  95. +19
    -69
      mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad.py
  96. +18
    -56
      mindspore/ops/_op_impl/tbe/sigmoid_grad.py
  97. +25
    -91
      mindspore/ops/_op_impl/tbe/slice.py
  98. +19
    -49
      mindspore/ops/_op_impl/tbe/softmax.py
  99. +16
    -69
      mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits.py
  100. +37
    -63
      mindspore/ops/_op_impl/tbe/split_d.py

+ 19
- 62
mindspore/ops/_op_impl/tbe/add.py View File

@@ -14,71 +14,28 @@
# ============================================================================

"""Add op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE operator information for "Add" (elementwise addition), consumed by the
# @op_info_register decorator below.  Declares the kernel binary ("add.so"),
# the kernel entry point ("add"), two required inputs (x1, x2) and one output
# (y).  Each .dtype_format(...) row is one supported (x1, x2, y) combination:
# int32, float16 and float32, each in DefaultFormat and NC1HWC0 (5HD).
# Comments cannot be interleaved inside the chain: every line ends in a
# backslash continuation.
add_op_info = TBERegOp("Add") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("add.so") \
.compute_cost(10) \
.kernel_name("add") \
.partial_flag(True) \
.input(0, "x1", False, "required", "all") \
.input(1, "x2", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()

# The legacy JSON-string @op_info_register decorator that used to sit here has
# been removed: it stacked a second, divergent registration for the same op on
# top of the TBERegOp-based one (e.g. its dtype/format rows were ordered
# differently), so only the builder-based registration is kept.
@op_info_register(add_op_info)
def _add_tbe():
    """Register the Add TBE operator info defined in ``add_op_info``."""
    return

+ 25
- 53
mindspore/ops/_op_impl/tbe/add_n.py View File

@@ -14,61 +14,33 @@
# ============================================================================

"""AddN op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE operator information for "AddN" (sum of N tensors), consumed by the
# @op_info_register decorator below.  Declares kernel binary "add_n.so",
# kernel entry "add_n", a required int attribute "n", one dynamic input "x"
# and one output "y".  Each .dtype_format(...) row gives one supported (x, y)
# combination: float16/float32 in DefaultFormat, NC1HWC0, FracZ and
# FRACTAL_NZ, plus int32 in DefaultFormat, NC1HWC0 and FracZ (no FRACTAL_NZ
# row for int32).  No comments inside the chain: backslash continuations.
add_n_op_info = TBERegOp("AddN") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("add_n.so") \
.compute_cost(10) \
.kernel_name("add_n") \
.partial_flag(True) \
.attr("n", "required", "int", "all") \
.input(0, "x", False, "dynamic", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_5HD, DataType.I32_5HD) \
.dtype_format(DataType.I32_FracZ, DataType.I32_FracZ) \
.get_op_info()

# Legacy JSON-string @op_info_register decorator removed: it duplicated the
# TBERegOp-based registration above with stale, divergent metadata.
@op_info_register(add_n_op_info)
def _add_n_tbe():
    """Register the AddN TBE operator info defined in ``add_n_op_info``."""
    return

+ 58
- 206
mindspore/ops/_op_impl/tbe/apply_adam.py View File

@@ -14,214 +14,66 @@
# ============================================================================

"""ApplyAdam op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE operator information for "Adam" (ApplyAdam optimizer update), consumed
# by the @op_info_register decorator below.  Kernel binary "apply_adam.so",
# kernel entry "apply_adam".  Two optional bool attributes, use_locking and
# use_nesterov, both defaulting to "false".  Ten required inputs (var, m, v,
# beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad) and three outputs
# (var, m, v).  In every .dtype_format(...) row the scalar hyper-parameter
# inputs (beta1_power..epsilon) stay in DefaultFormat while the state tensors
# (var, m, v, grad) and outputs share the row's format: Default, C1HWNCoC0,
# 5HD or FracZ, for float16 and float32.  No comments inside the chain:
# backslash continuations.
apply_adam_op_info = TBERegOp("Adam") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("apply_adam.so") \
.compute_cost(10) \
.kernel_name("apply_adam") \
.partial_flag(True) \
.attr("use_locking", "optional", "bool", "true,false", "false") \
.attr("use_nesterov", "optional", "bool", "true,false", "false") \
.input(0, "var", False, "required", "all") \
.input(1, "m", False, "required", "all") \
.input(2, "v", False, "required", "all") \
.input(3, "beta1_power", False, "required", "all") \
.input(4, "beta2_power", False, "required", "all") \
.input(5, "lr", False, "required", "all") \
.input(6, "beta1", False, "required", "all") \
.input(7, "beta2", False, "required", "all") \
.input(8, "epsilon", False, "required", "all") \
.input(9, "grad", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.output(1, "m", False, "required", "all") \
.output(2, "v", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default) \
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0,
DataType.F16_C1HWNCoC0) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
DataType.F16_5HD) \
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ,
DataType.F16_FracZ) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default) \
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0,
DataType.F32_C1HWNCoC0) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ) \
.get_op_info()

# Legacy JSON-string @op_info_register decorator removed: it duplicated the
# TBERegOp-based registration above with stale, divergent metadata (its rows
# were ordered NC1HWC0-first rather than DefaultFormat-first).
@op_info_register(apply_adam_op_info)
def _apply_adam_tbe():
    """Register the ApplyAdam TBE operator info defined in ``apply_adam_op_info``."""
    return

+ 34
- 104
mindspore/ops/_op_impl/tbe/apply_momentum.py View File

@@ -14,112 +14,42 @@
# ============================================================================

"""ApplyMomentum op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE operator information for "ApplyMomentum" (momentum optimizer update),
# consumed by the @op_info_register decorator below.  Kernel binary
# "apply_momentum.so", kernel entry "apply_momentum".  One optional bool
# attribute use_nesterov (default "false").  Five required inputs (var, accum,
# lr, grad, momentum) and one output (var).  In every .dtype_format(...) row
# the scalar inputs lr and momentum stay in DefaultFormat while var, accum,
# grad and the output share the row's format: Default, 5HD, C1HWNCoC0 or
# FracZ, for float16 and float32.  No comments inside the chain: backslash
# continuations.
apply_momentum_op_info = TBERegOp("ApplyMomentum") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("apply_momentum.so") \
.compute_cost(10) \
.kernel_name("apply_momentum") \
.partial_flag(True) \
.attr("use_nesterov", "optional", "bool", "true,false", "false") \
.input(0, "var", False, "required", "all") \
.input(1, "accum", False, "required", "all") \
.input(2, "lr", False, "required", "all") \
.input(3, "grad", False, "required", "all") \
.input(4, "momentum", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_5HD,
DataType.F16_Default, DataType.F16_5HD) \
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_C1HWNCoC0,
DataType.F16_Default, DataType.F16_C1HWNCoC0) \
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_FracZ,
DataType.F16_Default, DataType.F16_FracZ) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD,
DataType.F32_Default, DataType.F32_5HD) \
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_C1HWNCoC0,
DataType.F32_Default, DataType.F32_C1HWNCoC0) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_FracZ,
DataType.F32_Default, DataType.F32_FracZ) \
.get_op_info()

# Legacy JSON-string @op_info_register decorator removed: it duplicated the
# TBERegOp-based registration above with stale, divergent metadata.
@op_info_register(apply_momentum_op_info)
def _apply_momentum_tbe():
    """Register the ApplyMomentum TBE operator info defined in ``apply_momentum_op_info``."""
    return

+ 17
- 62
mindspore/ops/_op_impl/tbe/arg_max_with_value.py View File

@@ -14,70 +14,25 @@
# ============================================================================

"""ArgMaxWithValue op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE operator information for "ArgMaxWithValue", consumed by the
# @op_info_register decorator below.  Kernel binary "arg_max_with_value.so",
# kernel entry "arg_max_with_value".  Required int attribute "axis", one
# required input "x", and two outputs: "indice" (always int32 per the rows
# below) and "values" (same dtype as x).  Supported rows: float16 and float32,
# DefaultFormat only.  No comments inside the chain: backslash continuations.
arg_max_with_value_op_info = TBERegOp("ArgMaxWithValue") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("arg_max_with_value.so") \
.compute_cost(10) \
.kernel_name("arg_max_with_value") \
.partial_flag(True) \
.attr("axis", "required", "int", "all") \
.input(0, "x", False, "required", "all") \
.output(0, "indice", False, "required", "all") \
.output(1, "values", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
.get_op_info()

# Legacy JSON-string @op_info_register decorator removed: it duplicated the
# TBERegOp-based registration above with stale metadata.
@op_info_register(arg_max_with_value_op_info)
def _arg_max_with_value_tbe():
    """Register the ArgMaxWithValue TBE operator info defined in ``arg_max_with_value_op_info``."""
    return

+ 17
- 62
mindspore/ops/_op_impl/tbe/arg_min_with_value.py View File

@@ -14,70 +14,25 @@
# ============================================================================

"""ArgMinWithValue op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE operator information for "ArgMinWithValue", consumed by the
# @op_info_register decorator below.  Kernel binary "arg_min_with_value.so",
# kernel entry "arg_min_with_value".  Required int attribute "axis", one
# required input "x", and two outputs: "indice" (always int32 per the rows
# below) and "values" (same dtype as x).  Supported rows: float16 and float32,
# DefaultFormat only.
# FIX: the original registered this op as TBERegOp("ArgMaxWithValue") — a
# copy/paste slip from arg_max_with_value.py; the binfile, kernel name and the
# file's legacy JSON metadata all say ArgMinWithValue, so the op name is
# corrected to match.
arg_min_with_value_op_info = TBERegOp("ArgMinWithValue") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("arg_min_with_value.so") \
    .compute_cost(10) \
    .kernel_name("arg_min_with_value") \
    .partial_flag(True) \
    .attr("axis", "required", "int", "all") \
    .input(0, "x", False, "required", "all") \
    .output(0, "indice", False, "required", "all") \
    .output(1, "values", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
    .get_op_info()

# Legacy JSON-string @op_info_register decorator removed: it duplicated the
# TBERegOp-based registration above with stale metadata.
@op_info_register(arg_min_with_value_op_info)
def _arg_min_with_value_tbe():
    """Register the ArgMinWithValue TBE operator info defined in ``arg_min_with_value_op_info``."""
    return

+ 34
- 84
mindspore/ops/_op_impl/tbe/assign.py View File

@@ -14,93 +14,43 @@
# ============================================================================

"""Assign op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE operator information for "Assign" (write `value` into `resource`),
# consumed elsewhere by @op_info_register.  Kernel binary "assign.so", kernel
# entry "assign".  Two required inputs (resource, value) and one output (y).
# Each .dtype_format(...) row is one supported (resource, value, y) combo:
# all of int8/uint8/int16/uint16/int32/uint32/int64/uint64/float16/float32 in
# DefaultFormat and NC1HWC0 (5HD), plus a float32 FRACTAL_NZ row.  No comments
# inside the chain: backslash continuations.
assign_op_info = TBERegOp("Assign") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("assign.so") \
.compute_cost(10) \
.kernel_name("assign") \
.partial_flag(True) \
.input(0, "resource", False, "required", "all") \
.input(1, "value", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD) \
.dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I16_5HD, DataType.I16_5HD, DataType.I16_5HD) \
.dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.U16_5HD, DataType.U16_5HD, DataType.U16_5HD) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
.dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U32_5HD, DataType.U32_5HD, DataType.U32_5HD) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.I64_5HD, DataType.I64_5HD, DataType.I64_5HD) \
.dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.U64_5HD, DataType.U64_5HD, DataType.U64_5HD) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \
.get_op_info()

@op_info_register("""{
"op_name": "Assign",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "assign.so",
"compute_cost": 10,
"kernel_name": "assign",
"partial_flag": true,
"attr": [

],
"inputs": [
{
"index": 0,
"dtype": [
"float16", "float16", "float16", "float16", "float", "float", "float", "float",
"int32", "int32", "int32", "int32", "uint32", "uint32", "uint32", "uint32", "int8",
"int8", "int8", "int8", "uint8", "uint8", "uint8", "uint8", "int16", "int16", "int16",
"int16", "uint16", "uint16", "uint16", "uint16", "int64", "int64", "int64", "int64",
"uint64", "uint64", "uint64", "uint64", "float"
],
"format": [
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat",
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat",
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat",
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "FRACTAL_NZ"
],
"name": "resource",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32", "int32",
"int32", "int32", "uint32", "uint32", "uint32", "uint32", "int8", "int8", "int8", "int8", "uint8",
"uint8", "uint8", "uint8", "int16", "int16", "int16", "int16", "uint16", "uint16", "uint16",
"uint16", "int64", "int64", "int64", "int64", "uint64", "uint64", "uint64", "uint64", "float"
],
"format": [
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat",
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat",
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat",
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "FRACTAL_NZ"
],
"name": "value",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32",
"int32", "int32", "int32", "uint32", "uint32", "uint32", "uint32", "int8", "int8", "int8",
"int8", "uint8", "uint8", "uint8", "uint8", "int16", "int16", "int16", "int16", "uint16",
"uint16", "uint16", "uint16", "int64", "int64", "int64", "int64",
"uint64", "uint64", "uint64", "uint64", "float"
],
"format": [
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat",
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat",
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat",
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "FRACTAL_NZ"
],
"name": "y",
"param_type": "required",
"shape": "all"
}
]
}""")
@op_info_register(assign_op_info)
def _assign_tbe():
"""Assign TBE register"""
return

+ 25
- 71
mindspore/ops/_op_impl/tbe/assign_add.py View File

@@ -14,80 +14,34 @@
# ============================================================================

"""AssignAdd op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# New-style (builder) TBE op-info registration for AssignAdd: int8/uint8/int32/
# int64/float16/float32, each in DefaultFormat and NC1HWC0.
assign_add_op_info = TBERegOp("AssignAdd") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("assignadd.so") \
.compute_cost(10) \
.kernel_name("assignadd") \
.partial_flag(True) \
.input(0, "ref", False, "required", "all") \
.input(1, "value", False, "required", "all") \
.output(0, "output_ref", False, "required", "all") \
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.I64_5HD, DataType.I64_5HD, DataType.I64_5HD) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()

# Legacy JSON-string registration, removed by this commit (shown here only
# because this diff rendering stacks removed and added lines together).
@op_info_register("""{
"op_name": "AssignAdd",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "assignadd.so",
"compute_cost": 10,
"kernel_name": "assignadd",
"partial_flag": true,
"attr": [

],
"inputs": [
{
"index": 0,
"dtype":[
"float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32", "int32",
"int32", "int32", "int8", "int8", "int8", "int8", "uint8", "uint8", "uint8", "uint8", "int64",
"int64", "int64", "int64"
],
"format":[
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat",
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
],
"name": "ref",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype":[
"float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32", "int32",
"int32", "int32", "int8", "int8", "int8", "int8", "uint8", "uint8", "uint8", "uint8", "int64",
"int64", "int64", "int64"
],
"format":[
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat",
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
],
"name": "value",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype":[
"float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32", "int32",
"int32", "int32", "int8", "int8", "int8", "int8", "uint8", "uint8", "uint8", "uint8", "int64",
"int64", "int64", "int64"
],
"format":[
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat",
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
],
"name": "output_ref",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")
@op_info_register(assign_add_op_info)
def _assign_add_tbe():
"""AssignAdd TBE register"""
return

+ 18
- 56
mindspore/ops/_op_impl/tbe/assign_sub.py View File

@@ -14,65 +14,27 @@
# ============================================================================

"""AssignSub op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# New-style (builder) TBE op-info registration for AssignSub: DefaultFormat
# only, for int8/uint8/int32/float16/float32.
assign_sub_op_info = TBERegOp("AssignSub") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("assign_sub.so") \
.compute_cost(10) \
.kernel_name("assign_sub") \
.partial_flag(True) \
.input(0, "var", False, "required", "all") \
.input(1, "value", False, "required", "all") \
.output(0, "output_ref", False, "required", "all") \
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()

# Legacy JSON-string registration, removed by this commit. NOTE(review): the
# legacy JSON names the output "out_ref" while the builder above uses
# "output_ref" — the rename looks deliberate, but confirm downstream kernel
# matching uses the new name.
@op_info_register("""{
"op_name": "AssignSub",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "assign_sub.so",
"compute_cost": 10,
"kernel_name": "assign_sub",
"partial_flag": true,
"attr": [

],
"inputs": [
{
"index": 0,
"dtype": [
"float16", "float", "int32", "int8", "uint8"
],
"format": [
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat"
],
"name": "var",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16", "float", "int32", "int8", "uint8"
],
"format": [
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat"
],
"name": "value",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16", "float", "int32", "int8", "uint8"
],
"format": [
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat"
],
"name": "out_ref",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")
@op_info_register(assign_sub_op_info)
def _assign_sub_tbe():
"""AssignSub TBE register"""
return

+ 12
- 23
mindspore/ops/_op_impl/tbe/atomic_addr_clean.py View File

@@ -14,31 +14,20 @@
# ============================================================================

"""AtomicAddrClean op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp

# New-style (builder) TBE op-info registration for AtomicAddrClean. No inputs
# or outputs — only a required listInt attribute.
# NOTE(review): "automic_add_mem_size" looks like a typo for "atomic...", but
# it matches the legacy JSON below, so presumably the TBE kernel expects this
# exact attr name — do not rename without confirming against the kernel.
atomic_addr_clean_op_info = TBERegOp("AtomicAddrClean") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("atomic_addr_clean.so") \
.compute_cost(10) \
.kernel_name("atomic_addr_clean") \
.partial_flag(True) \
.attr("automic_add_mem_size", "required", "listInt", "all") \
.get_op_info()

# Legacy JSON-string registration, removed by this commit.
@op_info_register("""{
"op_name": "AtomicAddrClean",
"imply_type": "TBE",
"fusion_type": "ELEMWISE",
"async_flag": false,
"binfile_name": "atomic_addr_clean.so",
"compute_cost": 10,
"kernel_name": "atomic_addr_clean",
"partial_flag": true,
"attr": [
{
"name": "automic_add_mem_size",
"param_type": "required",
"type": "listInt",
"value": "all"
}
],
"inputs": [
],
"outputs": [
]
}""")

@op_info_register(atomic_addr_clean_op_info)
def _atomic_addr_clean_tbe():
"""AtomicAddrClean TBE register"""
return

+ 21
- 80
mindspore/ops/_op_impl/tbe/batch_matmul.py View File

@@ -14,88 +14,29 @@
# ============================================================================

"""BatchMatMul op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# New-style (builder) TBE op-info registration for BatchMatMul. Inputs x1/x2,
# optional bias; note the FRACTAL_NZ combination keeps bias in DefaultFormat,
# matching the legacy JSON's per-index format lists below.
batch_matmul_op_info = TBERegOp("BatchMatMul") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("batch_matmul.so") \
.compute_cost(10) \
.kernel_name("batch_matmul") \
.attr("transpose_x1", "required", "bool", "all") \
.attr("transpose_x2", "required", "bool", "all") \
.partial_flag(True) \
.input(0, "x1", False, "required", "all") \
.input(1, "x2", False, "required", "all") \
.input(2, "bias", False, "optional", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.F16_FracNZ) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()

# Legacy JSON-string registration, removed by this commit.
@op_info_register("""{
"op_name": "BatchMatMul",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "batch_matmul.so",
"compute_cost": 10,
"kernel_name": "batch_matmul",
"partial_flag": true,
"attr": [
{
"name": "transpose_x1",
"param_type": "required",
"type": "bool",
"value": "all"
},
{
"name": "transpose_x2",
"param_type": "required",
"type": "bool",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16","float16","float16","float","float","int32","int32"
],
"format": [
"DefaultFormat","DefaultFormat","FRACTAL_NZ","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
],
"name": "x1",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16","float16","float16","float","float","int32","int32"
],
"format": [
"DefaultFormat","DefaultFormat","FRACTAL_NZ","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
],
"name": "x2",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float16","float16","float16","float","float","int32","int32"
],
"format": [
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
],
"name": "bias",
"need_compile": false,
"param_type": "optional",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16","float16","float16","float","float","int32","int32"
],
"format": [
"DefaultFormat","DefaultFormat","FRACTAL_NZ","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
],
"name": "y",
"param_type": "required",
"shape": "all"
}
]
}""")

@op_info_register(batch_matmul_op_info)
def _batch_matmul_tbe():
"""BatchMatMul TBE register"""
return

+ 37
- 166
mindspore/ops/_op_impl/tbe/batchnorm.py View File

@@ -14,174 +14,45 @@
# ============================================================================

"""BatchNorm op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# New-style (builder) TBE op-info registration for BatchNorm: 5 inputs
# (x, scale, offset, optional mean/variance) and 6 outputs. Each dtype_format
# row lists the 11 tensors in declaration order (inputs 0-4, outputs 0-5).
# NOTE(review): the builder marks epsilon/data_format/is_training "optional"
# while the legacy JSON below had them "required" — looks like a deliberate
# relaxation in this commit, but worth confirming.
batch_norm_op_info = TBERegOp("BatchNorm") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("batch_norm.so") \
.compute_cost(10) \
.kernel_name("batch_norm") \
.partial_flag(True) \
.attr("epsilon", "optional", "float", "all") \
.attr("data_format", "optional", "str", "all") \
.attr("is_training", "optional", "bool", "all") \
.input(0, "x", False, "required", "all") \
.input(1, "scale", False, "required", "all") \
.input(2, "offset", False, "required", "all") \
.input(3, "mean", False, "optional", "all") \
.input(4, "variance", False, "optional", "all") \
.output(0, "y", False, "required", "all") \
.output(1, "batch_mean", False, "required", "all") \
.output(2, "batch_variance", False, "required", "all") \
.output(3, "reserve_space_1", False, "optional", "all") \
.output(4, "reserve_space_2", False, "optional", "all") \
.output(5, "reserve_space_3", False, "optional", "all") \
.dtype_format(DataType.F16_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F16_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()

# Legacy JSON-string registration, removed by this commit.
@op_info_register("""{
"op_name": "BatchNorm",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "batch_norm.so",
"compute_cost": 10,
"kernel_name": "batch_norm",
"partial_flag": true,
"attr": [
{
"name": "epsilon",
"param_type": "required",
"type": "float",
"value": "all"
},
{
"name": "data_format",
"param_type": "required",
"type": "str",
"value": "all"
},
{
"name": "is_training",
"param_type": "required",
"type": "bool",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16","float16","float","float"
],
"format": [
"DefaultFormat","NC1HWC0", "DefaultFormat","NC1HWC0"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float","float","float","float"
],
"format": [
"DefaultFormat","NC1HWC0","DefaultFormat", "NC1HWC0"
],
"name": "scale",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float","float","float","float"
],
"format": [
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
],
"name": "offset",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 3,
"dtype": [
"float","float","float","float"
],
"format": [
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
],
"name": "mean",
"need_compile": false,
"param_type": "optional",
"shape": "all"
},
{
"index": 4,
"dtype": [
"float","float","float","float"
],
"format": [
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
],
"name": "variance",
"need_compile": false,
"param_type": "optional",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16", "float16", "float", "float"
],
"format": [
"DefaultFormat","NC1HWC0", "DefaultFormat","NC1HWC0"
],
"name": "y",
"param_type": "required"
},
{
"index": 1,
"dtype": [
"float","float","float","float"
],
"format": [
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
],
"name": "batch_mean",
"param_type": "required"
},
{
"index": 2,
"dtype": [
"float", "float", "float", "float"
],
"format": [
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
],
"name": "batch_variance",
"param_type": "required"
},
{
"index": 3,
"dtype": [
"float", "float", "float", "float"
],
"format": [
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
],
"name": "reserve_space_1",
"param_type": "optional"
},
{
"index": 4,
"dtype": [
"float", "float", "float", "float"
],
"format": [
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
],
"name": "reserve_space_2",
"param_type": "optional"
},
{
"index": 5,
"dtype": [
"float", "float", "float", "float"
],
"format": [
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
],
"name": "reserve_space_3",
"param_type": "optional"
}
]
}""")

@op_info_register(batch_norm_op_info)
def _batch_norm_tbe():
"""BatchNorm TBE register"""
return

+ 37
- 173
mindspore/ops/_op_impl/tbe/batchnorm_grad.py View File

@@ -14,181 +14,45 @@
# ============================================================================

"""BatchNormGrad op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# New-style (builder) TBE op-info registration for BatchNormGrad: 6 inputs
# (y_backprop, x, scale, reserve_space_1..3) and 5 outputs. Each dtype_format
# row lists the 11 tensors in declaration order (inputs 0-5, outputs 0-4).
batch_norm_grad_op_info = TBERegOp("BatchNormGrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("batchnormgrad.so") \
.compute_cost(10) \
.kernel_name("batchnormgrad") \
.partial_flag(True) \
.attr("epsilon", "optional", "float", "all") \
.attr("data_format", "optional", "str", "all") \
.attr("is_training", "optional", "bool", "all") \
.input(0, "y_backprop", False, "required", "all") \
.input(1, "x", False, "required", "all") \
.input(2, "scale", False, "required", "all") \
.input(3, "reserve_space_1", False, "required", "all") \
.input(4, "reserve_space_2", False, "required", "all") \
.input(5, "reserve_space_3", False, "required", "all") \
.output(0, "x_backprop", False, "required", "all") \
.output(1, "scale_backprop", False, "required", "all") \
.output(2, "offset_backprop", False, "required", "all") \
.output(3, "reserve_space_4", False, "optional", "all") \
.output(4, "reserve_space_5", False, "optional", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F16_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.F16_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()

# Legacy JSON-string registration, removed by this commit.
@op_info_register("""{
"op_name": "BatchNormGrad",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "batchnormgrad.so",
"compute_cost": 10,
"kernel_name": "batchnormgrad",
"partial_flag": true,
"attr": [
{
"name": "epsilon",
"param_type": "optional",
"type": "float",
"value": "all"
},
{
"name": "data_format",
"param_type": "optional",
"type": "str",
"value": "all"
},
{
"name": "is_training",
"param_type": "optional",
"type": "bool",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16","float16","float16","float","float","float"
],
"format": [
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
],
"name": "y_backprop",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16","float16","float16","float","float","float"
],
"format": [
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float","float","float","float","float","float"
],
"format": [
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
],
"name": "scale",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 3,
"dtype": [
"float","float","float","float","float","float"
],
"format": [
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
],
"name": "reserve_space_1",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 4,
"dtype": [
"float","float","float","float","float","float"
],
"format": [
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
],
"name": "reserve_space_2",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 5,
"dtype": [
"float","float","float","float","float","float"
],
"format": [
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
],
"name": "reserve_space_3",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16","float16","float16","float","float","float"
],
"format": [
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
],
"name": "x_backprop",
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float","float","float","float","float","float"
],
"format": [
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
],
"name": "scale_backprop",
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float","float","float","float","float","float"
],
"format": [
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
],
"name": "offset_backprop",
"param_type": "required",
"shape": "all"
},
{
"index": 3,
"dtype": [
"float","float","float","float","float","float"
],
"format": [
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
],
"name": "reserve_space_4",
"param_type": "optional",
"shape": "all"
},
{
"index": 4,
"dtype": [
"float","float","float","float","float","float"
],
"format": [
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
],
"name": "reserve_space_5",
"param_type": "optional",
"shape": "all"
}
]
}""")

@op_info_register(batch_norm_grad_op_info)
def _batch_norm_grad_tbe():
"""BatchNormGrad TBE register"""
return

+ 18
- 62
mindspore/ops/_op_impl/tbe/bias_add.py View File

@@ -14,70 +14,26 @@
# ============================================================================

"""BiasAdd op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# New-style (builder) TBE op-info registration for BiasAdd (int32/float16/
# float32, DefaultFormat only).
# FIX: the register-info variable was misnamed `bias_add_grad_op_info` — that
# name belongs to the BiasAddGrad op (bias_add_grad.py defines a variable of
# exactly that name), so the BiasAdd section was registering under a
# confusing, collision-prone identifier. Renamed to `bias_add_op_info` here
# and at the decorator below; registration content is unchanged.
bias_add_op_info = TBERegOp("BiasAdd") \
    .fusion_type("COMMREDUCE") \
    .async_flag(False) \
    .binfile_name("bias_add.so") \
    .compute_cost(10) \
    .kernel_name("bias_add") \
    .partial_flag(True) \
    .attr("data_format", "required", "str", "all") \
    .input(0, "x", False, "required", "all") \
    .input(1, "bias", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()

# Legacy JSON-string registration, removed by this commit (both forms appear
# together only because this is a diff rendering).
@op_info_register("""{
"op_name": "BiasAdd",
"imply_type": "TBE",
"fusion_type": "COMMREDUCE",
"async_flag": false,
"binfile_name": "bias_add.so",
"compute_cost": 10,
"kernel_name": "bias_add",
"partial_flag": true,
"attr": [
{
"name": "data_format",
"param_type": "required",
"type": "str",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"int32", "float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat", "DefaultFormat"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"int32", "float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat", "DefaultFormat"
],
"name": "bias",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"int32", "float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat", "DefaultFormat"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")

@op_info_register(bias_add_op_info)
def _bias_add_tbe():
    """BiasAdd TBE register"""
    return

+ 18
- 49
mindspore/ops/_op_impl/tbe/bias_add_grad.py View File

@@ -14,57 +14,26 @@
# ============================================================================

"""BiasAddGrad op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# New-style (builder) TBE op-info registration for BiasAddGrad: accepts the
# backprop tensor in DefaultFormat or FRACTAL_NZ, always producing a
# DefaultFormat output.
bias_add_grad_op_info = TBERegOp("BiasAddGrad") \
.fusion_type("COMMREDUCE") \
.async_flag(False) \
.binfile_name("biasaddgrad.so") \
.compute_cost(10) \
.kernel_name("biasaddgrad") \
.partial_flag(True) \
.attr("data_format", "required", "str", "all") \
.input(0, "output_backprop", False, "required", "all") \
.output(0, "output", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_FracNZ, DataType.F32_Default) \
.get_op_info()

# Legacy JSON-string registration, removed by this commit. NOTE(review): the
# legacy JSON names the input "out_backprop" while the builder above uses
# "output_backprop" — confirm the rename is intentional and matched downstream.
@op_info_register("""{
"op_name": "BiasAddGrad",
"imply_type": "TBE",
"fusion_type": "COMMREDUCE",
"async_flag": false,
"binfile_name": "biasaddgrad.so",
"compute_cost": 10,
"kernel_name": "biasaddgrad",
"partial_flag": true,
"attr": [
{
"name": "data_format",
"param_type": "required",
"type": "str",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16","float16","float","float"
],
"format": [
"FRACTAL_NZ","DefaultFormat","FRACTAL_NZ","DefaultFormat"
],
"name": "out_backprop",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16","float16","float","float"
],
"format": [
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
],
"name": "output",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")

@op_info_register(bias_add_grad_op_info)
def _bias_add_grad_tbe():
"""BiasAddGrad TBE register"""
return

+ 16
- 52
mindspore/ops/_op_impl/tbe/bn_training_reduce.py View File

@@ -14,60 +14,24 @@
# ============================================================================

"""BatchNorm op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# New-style (builder) TBE op-info registration for BNTrainingReduce: NC1HWC0
# only; takes x (fp16 or fp32) and emits float32 sum / square_sum.
bn_training_reduce_op_info = TBERegOp("BNTrainingReduce") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("bn_training_reduce.so") \
.compute_cost(10) \
.kernel_name("bn_training_reduce") \
.partial_flag(True) \
.input(0, "x", False, "required", "all") \
.output(0, "sum", False, "required", "all") \
.output(1, "square_sum", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()

# Legacy JSON-string registration, removed by this commit.
@op_info_register("""{
"op_name": "BNTrainingReduce",
"imply_type": "TBE",
"fusion_type": "ELEMWISE",
"async_flag": false,
"binfile_name": "bn_training_reduce.so",
"compute_cost": 10,
"kernel_name": "bn_training_reduce",
"partial_flag": true,
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"float16","float"
],
"format": [
"NC1HWC0", "NC1HWC0"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float","float"
],
"format": [
"NC1HWC0", "NC1HWC0"
],
"name": "sum",
"param_type": "required"
},
{
"index": 1,
"dtype": [
"float","float"
],
"format": [
"NC1HWC0", "NC1HWC0"
],
"name": "square_sum",
"param_type": "required"
}
]
}""")

@op_info_register(bn_training_reduce_op_info)
def _bn_training_reduce_tbe():
"""BNTrainingReduce TBE register"""
return

+ 24
- 126
mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py View File

@@ -14,134 +14,32 @@
# ============================================================================

"""BatchNormGrad op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# New-style (builder) TBE op-info registration for BNTrainingReduceGrad:
# 7 NC1HWC0 inputs and one output y; only grads and y may be float16, all
# intermediate statistics are float32.
bn_training_reduce_grad_op_info = TBERegOp("BNTrainingReduceGrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("bn_training_reduce_grad.so") \
.compute_cost(10) \
.kernel_name("bn_training_reduce_grad") \
.partial_flag(True) \
.attr("epsilon", "optional", "float", "all") \
.input(0, "grads", False, "required", "all") \
.input(1, "x_norm", False, "required", "all") \
.input(2, "diff_scale", False, "required", "all") \
.input(3, "diff_offset", False, "required", "all") \
.input(4, "scale", False, "required", "all") \
.input(5, "batch_mean", False, "required", "all") \
.input(6, "batch_variance", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()

# Legacy JSON-string registration, removed by this commit.
@op_info_register("""{
"op_name": "BNTrainingReduceGrad",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "bn_training_reduce_grad.so",
"compute_cost": 10,
"kernel_name": "bn_training_reduce_grad",
"partial_flag": true,
"attr": [
{
"name": "epsilon",
"param_type": "optional",
"type": "float",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16", "float"
],
"format": [
"NC1HWC0","NC1HWC0"
],
"name": "grads",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float", "float"
],
"format": [
"NC1HWC0","NC1HWC0"
],
"name": "x_norm",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float", "float"
],
"format": [
"NC1HWC0","NC1HWC0"
],
"name": "diff_scale",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 3,
"dtype": [
"float", "float"
],
"format": [
"NC1HWC0","NC1HWC0"
],
"name": "diff_offset",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 4,
"dtype": [
"float", "float"
],
"format": [
"NC1HWC0","NC1HWC0"
],
"name": "scale",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 5,
"dtype": [
"float", "float"
],
"format": [
"NC1HWC0","NC1HWC0"
],
"name": "batch_mean",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 6,
"dtype": [
"float", "float"
],
"format": [
"NC1HWC0","NC1HWC0"
],
"name": "batch_variance",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16", "float"
],
"format": [
"NC1HWC0","NC1HWC0"
],
"name": "y",
"param_type": "required",
"shape": "all"
}
]
}""")

@op_info_register(bn_training_reduce_grad_op_info)
def _bn_training_reduce_grad_tbe():
"""BNTrainingReduceGrad TBE register"""
return

+ 32
- 192
mindspore/ops/_op_impl/tbe/bn_training_update.py View File

@@ -14,200 +14,40 @@
# ============================================================================

"""BatchNormGrad op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

bn_training_update_op_info = TBERegOp("BNTrainingUpdate") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("bn_training_update.so") \
.compute_cost(10) \
.kernel_name("bn_training_update") \
.partial_flag(True) \
.attr("factor", "optional", "float", "all") \
.attr("epsilon", "optional", "float", "all") \
.attr("isRef", "optional", "bool", "all", "true") \
.input(0, "x", False, "required", "all") \
.input(1, "sum", False, "required", "all") \
.input(2, "square_sum", False, "required", "all") \
.input(3, "scale", False, "required", "all") \
.input(4, "offset", False, "required", "all") \
.input(5, "mean", False, "required", "all") \
.input(6, "variance", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.output(1, "mean", False, "required", "all") \
.output(2, "variance", False, "required", "all") \
.output(3, "batch_mean", False, "required", "all") \
.output(4, "batch_variance", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F16_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()

@op_info_register("""{
"op_name": "BNTrainingUpdate",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "bn_training_update.so",
"compute_cost": 10,
"kernel_name": "bn_training_update",
"partial_flag": true,
"attr": [
{
"name": "factor",
"param_type": "optional",
"type": "float",
"value": "all"
},
{
"name": "epsilon",
"param_type": "optional",
"type": "float",
"value": "all"
},
{
"name": "isRef",
"param_type": "optional",
"type": "bool",
"default_value":"true",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16", "float"
],
"format": [
"NC1HWC0","NC1HWC0"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float", "float"
],
"format": [
"NC1HWC0","NC1HWC0"
],
"name": "sum",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float", "float"
],
"format": [
"NC1HWC0","NC1HWC0"
],
"name": "square_sum",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 3,
"dtype": [
"float", "float"
],
"format": [
"NC1HWC0","NC1HWC0"
],
"name": "scale",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 4,
"dtype": [
"float", "float"
],
"format": [
"NC1HWC0","NC1HWC0"
],
"name": "offset",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 5,
"dtype": [
"float", "float"
],
"format": [
"NC1HWC0","NC1HWC0"
],
"name": "mean",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 6,
"dtype": [
"float", "float"
],
"format": [
"NC1HWC0","NC1HWC0"
],
"name": "variance",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16", "float"
],
"format": [
"NC1HWC0","NC1HWC0"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float", "float"
],
"format": [
"NC1HWC0","NC1HWC0"
],
"name": "mean",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float", "float"
],
"format": [
"NC1HWC0","NC1HWC0"
],
"name": "variance",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 3,
"dtype": [
"float", "float"
],
"format": [
"NC1HWC0","NC1HWC0"
],
"name": "batch_mean",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 4,
"dtype": [
"float", "float"
],
"format": [
"NC1HWC0","NC1HWC0"
],
"name": "batch_variance",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")

@op_info_register(bn_training_update_op_info)
def _bn_training_update_tbe():
"""BNTrainingUpdate TBE register"""
return

+ 22
- 101
mindspore/ops/_op_impl/tbe/bn_training_update_grad.py View File

@@ -14,109 +14,30 @@
# ============================================================================

"""BatchNormGrad op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

bn_training_update_grad_op_info = TBERegOp("BNTrainingUpdateGrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("bn_training_update_grad.so") \
.compute_cost(10) \
.kernel_name("bn_training_update_grad") \
.partial_flag(True) \
.attr("epsilon", "optional", "float", "all") \
.input(0, "grads", False, "required", "all") \
.input(1, "x", False, "required", "all") \
.input(2, "batch_mean", False, "required", "all") \
.input(3, "batch_variance", False, "required", "all") \
.output(0, "diff_scale", False, "required", "all") \
.output(1, "diff_offset", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()

@op_info_register("""{
"op_name": "BNTrainingUpdateGrad",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "bn_training_update_grad.so",
"compute_cost": 10,
"kernel_name": "bn_training_update_grad",
"partial_flag": true,
"attr": [
{
"name": "epsilon",
"param_type": "optional",
"type": "float",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16", "float"
],
"format": [
"NC1HWC0","NC1HWC0"
],
"name": "grads",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16", "float"
],
"format": [
"NC1HWC0","NC1HWC0"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float", "float"
],
"format": [
"NC1HWC0","NC1HWC0"
],
"name": "batch_mean",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 3,
"dtype": [
"float", "float"
],
"format": [
"NC1HWC0","NC1HWC0"
],
"name": "batch_variance",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float", "float"
],
"format": [
"NC1HWC0","NC1HWC0"
],
"name": "diff_scale",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float", "float"
],
"format": [
"NC1HWC0","NC1HWC0"
],
"name": "diff_offset",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")

@op_info_register(bn_training_update_grad_op_info)
def _bn_training_update_grad_tbe():
"""BNTrainingUpdateGrad TBE register"""
return

+ 34
- 61
mindspore/ops/_op_impl/tbe/cast.py View File

@@ -14,69 +14,42 @@
# ============================================================================

"""Cast op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

cast_op_info = TBERegOp("Cast") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("cast.so") \
.compute_cost(10) \
.kernel_name("cast") \
.partial_flag(True) \
.attr("dst_type", "required", "int", "all") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.BOOL_Default, DataType.F16_Default) \
.dtype_format(DataType.BOOL_Default, DataType.U8_Default) \
.dtype_format(DataType.BOOL_Default, DataType.F32_Default) \
.dtype_format(DataType.BOOL_Default, DataType.I32_Default) \
.dtype_format(DataType.I8_Default, DataType.F16_Default) \
.dtype_format(DataType.I8_Default, DataType.F32_Default) \
.dtype_format(DataType.I8_Default, DataType.I32_Default) \
.dtype_format(DataType.U8_Default, DataType.F16_Default) \
.dtype_format(DataType.U8_Default, DataType.F32_Default) \
.dtype_format(DataType.U8_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I32_Default, DataType.F16_Default) \
.dtype_format(DataType.I32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I8_Default) \
.dtype_format(DataType.I32_Default, DataType.U8_Default) \
.dtype_format(DataType.F16_Default, DataType.U8_Default) \
.dtype_format(DataType.F16_Default, DataType.F32_Default) \
.dtype_format(DataType.F16_Default, DataType.I32_Default) \
.dtype_format(DataType.F32_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default) \
.get_op_info()

@op_info_register("""{
"op_name": "Cast",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "cast.so",
"compute_cost": 10,
"kernel_name": "cast",
"partial_flag": true,
"attr": [
{
"name": "dst_type",
"param_type": "required",
"type": "int",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16", "float16", "float", "float",
"int32", "int32", "int32", "int32", "int32",
"int8", "int8", "int8", "uint8", "uint8", "uint8",
"bool", "bool", "bool", "bool", "float16"
],
"format": [
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat",
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat",
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat",
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float", "int32", "float16", "int32",
"float16", "float", "int8", "uint8", "bool",
"float16", "float", "int32", "float16", "float", "int32",
"float16", "float", "int32", "uint8", "uint8"
],
"format": [
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat",
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat",
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat",
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
],
"name": "y",
"need_compile": true,
"param_type": "required",
"shape": "all"
}
]
}""")

@op_info_register(cast_op_info)
def _cast_tbe():
"""Cast TBE register"""
return

+ 20
- 82
mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum.py View File

@@ -14,90 +14,28 @@
# ============================================================================

"""ClipByNormNoDivSum op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

clip_by_norm_no_div_sum_op_info = TBERegOp("ClipByNormNoDivSum") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("clip_by_norm_no_div_sum.so") \
.compute_cost(10) \
.kernel_name("clip_by_norm_no_div_sum") \
.partial_flag(True) \
.input(0, "input_x", False, "required", "all") \
.input(1, "input1", False, "required", "all") \
.input(2, "input2", False, "required", "all") \
.input(3, "input3", False, "required", "all") \
.output(0, "output_y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default) \
.get_op_info()

@op_info_register("""{
"op_name": "ClipByNormNoDivSum",
"imply_type": "TBE",
"fusion_type": "ELEMWISE",
"async_flag": false,
"binfile_name": "clip_by_norm_no_div_sum.so",
"compute_cost": 10,
"kernel_name": "clip_by_norm_no_div_sum",
"partial_flag": true,
"attr":[
],
"inputs": [
{
"index": 0,
"dtype": [
"float16", "float32"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "input_x",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16", "float32"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "input1",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float16", "float32"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "input2",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 3,
"dtype": [
"float16", "float32"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "input3",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16", "float32"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "output_y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")

@op_info_register(clip_by_norm_no_div_sum_op_info)
def _clip_by_norm_no_div_sum_tbe():
"""ClipByNormNoDivSum TBE register"""
return

+ 22
- 77
mindspore/ops/_op_impl/tbe/clip_by_value.py View File

@@ -14,85 +14,30 @@
# ============================================================================

"""ClipByValue op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

clip_by_value_op_info = TBERegOp("ClipByValue") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("clip_by_value.so") \
.compute_cost(10) \
.kernel_name("clip_by_value") \
.partial_flag(True) \
.attr("dst_type", "required", "int", "all") \
.input(0, "x", False, "required", "all") \
.input(1, "clip_value_min", False, "required", "all") \
.input(2, "clip_value_max", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()

@op_info_register("""{
"op_name": "ClipByValue",
"imply_type": "TBE",
"fusion_type": "ELEMWISE",
"async_flag": false,
"binfile_name": "clip_by_value.so",
"compute_cost": 10,
"kernel_name": "clip_by_value",
"partial_flag": true,
"attr":[
],
"inputs": [
{
"index": 0,
"dtype": [
"float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32",
"int32", "int32", "int32"
],
"format": [
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32",
"int32", "int32", "int32"
],
"format": [
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
],
"name": "clip_value_min",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32",
"int32", "int32", "int32"
],
"format": [
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
],
"name": "clip_value_max",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32",
"int32", "int32", "int32"
],
"format": [
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")

@op_info_register(clip_by_value_op_info)
def _clip_by_value_tbe():
"""ClipByValue TBE register"""
return

+ 36
- 133
mindspore/ops/_op_impl/tbe/concat.py View File

@@ -14,141 +14,44 @@
# ============================================================================

"""Concat op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

concat_op_info = TBERegOp("Concat") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("concat_d.so") \
.compute_cost(10) \
.kernel_name("concat_d") \
.partial_flag(True) \
.attr("axis", "required", "int", "all") \
.input(0, "input_values", False, "dynamic", "all") \
.output(0, "output_data", False, "required", "all") \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.dtype_format(DataType.BOOL_5HD, DataType.BOOL_5HD) \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I8_5HD, DataType.I8_5HD) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U8_5HD, DataType.U8_5HD) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I16_5HD, DataType.I16_5HD) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.U16_5HD, DataType.U16_5HD) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_5HD, DataType.I32_5HD) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U32_5HD, DataType.U32_5HD) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.I64_5HD, DataType.I64_5HD) \
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.U64_5HD, DataType.U64_5HD) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()

@op_info_register("""{
"op_name": "Concat",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "concat_d.so",
"compute_cost": 10,
"kernel_name": "concat_d",
"partial_flag": true,
"attr": [
{
"name": "axis",
"param_type": "required",
"type": "int",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16",
"float16",
"float",
"float",
"int32",
"int32",
"int8",
"int8",
"int16",
"int16",
"int64",
"int64",
"uint8",
"uint8",
"uint16",
"uint16",
"uint32",
"uint32",
"uint64",
"uint64",
"bool",
"bool"
],
"format": [
"DefaultFormat",
"NC1HWC0",
"DefaultFormat",
"NC1HWC0",
"DefaultFormat",
"NC1HWC0",
"DefaultFormat",
"NC1HWC0",
"DefaultFormat",
"NC1HWC0",
"DefaultFormat",
"NC1HWC0",
"DefaultFormat",
"NC1HWC0",
"DefaultFormat",
"NC1HWC0",
"DefaultFormat",
"NC1HWC0",
"DefaultFormat",
"NC1HWC0",
"DefaultFormat",
"NC1HWC0"
],
"name": "input_values",
"need_compile": false,
"param_type": "dynamic",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16",
"float16",
"float",
"float",
"int32",
"int32",
"int8",
"int8",
"int16",
"int16",
"int64",
"int64",
"uint8",
"uint8",
"uint16",
"uint16",
"uint32",
"uint32",
"uint64",
"uint64",
"bool",
"bool"
],
"format": [
"DefaultFormat",
"NC1HWC0",
"DefaultFormat",
"NC1HWC0",
"DefaultFormat",
"NC1HWC0",
"DefaultFormat",
"NC1HWC0",
"DefaultFormat",
"NC1HWC0",
"DefaultFormat",
"NC1HWC0",
"DefaultFormat",
"NC1HWC0",
"DefaultFormat",
"NC1HWC0",
"DefaultFormat",
"NC1HWC0",
"DefaultFormat",
"NC1HWC0",
"DefaultFormat",
"NC1HWC0"
],
"name": "output_data",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")

@op_info_register(concat_op_info)
def _concat_tbe():
"""Concat TBE register"""
return

+ 19
- 56
mindspore/ops/_op_impl/tbe/confusion_softmax_grad.py View File

@@ -14,65 +14,28 @@
# ============================================================================

"""ConfusionSoftmaxGrad op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

confusion_softmax_grad_op_info = TBERegOp("ConfusionSoftmaxGrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("confusion_softmax_grad.so") \
.compute_cost(10) \
.kernel_name("confusion_softmax_grad") \
.partial_flag(True) \
.input(0, "grad", False, "required", "all") \
.input(1, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \
.get_op_info()

@op_info_register("""{
"op_name": "ConfusionSoftmaxGrad",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "confusion_softmax_grad.so",
"compute_cost": 10,
"kernel_name": "confusion_softmax_grad",
"partial_flag": true,
"attr": [

],
"inputs": [
{
"index": 0,
"dtype": [
"float16", "float16", "float16", "float", "float", "float"
],
"format": [
"FRACTAL_NZ", "DefaultFormat", "NC1HWC0", "FRACTAL_NZ", "DefaultFormat", "NC1HWC0"
],
"name": "grad",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16", "float16", "float16", "float", "float", "float"
],
"format": [
"FRACTAL_NZ", "DefaultFormat", "NC1HWC0", "FRACTAL_NZ", "DefaultFormat", "NC1HWC0"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16", "float16", "float16", "float", "float", "float"
],
"format": [
"FRACTAL_NZ", "DefaultFormat", "NC1HWC0", "FRACTAL_NZ", "DefaultFormat", "NC1HWC0"
],
"name": "y",
"need_compile": true,
"param_type": "required",
"shape": "all"
}
]
}""")
@op_info_register(confusion_softmax_grad_op_info)
def _confusion_softmax_grad_tbe():
"""ConfusionSoftmaxGrad TBE register"""
return

+ 36
- 71
mindspore/ops/_op_impl/tbe/confusion_transpose_d.py View File

@@ -14,79 +14,44 @@
# ============================================================================

"""ConfusionTransposeD op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

confusion_transpose_d_op_info = TBERegOp("ConfusionTransposeD") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("confusion_transpose_d.so") \
.compute_cost(10) \
.kernel_name("confusion_transpose_d") \
.partial_flag(True) \
.attr("perm", "required", "listInt", "all") \
.attr("shape", "required", "listInt", "all") \
.attr("transpose_first", "required", "bool", "all") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.I8_FracNZ, DataType.I8_FracNZ) \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_FracNZ, DataType.U8_FracNZ) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.I16_FracNZ, DataType.I16_FracNZ) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.U16_FracNZ, DataType.U16_FracNZ) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.I32_FracNZ, DataType.I32_FracNZ) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.U32_FracNZ, DataType.U32_FracNZ) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.I64_FracNZ, DataType.I64_FracNZ) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U64_FracNZ, DataType.U64_FracNZ) \
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.get_op_info()

@op_info_register("""{
"op_name": "ConfusionTransposeD",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "confusion_transpose_d.so",
"compute_cost": 10,
"kernel_name": "confusion_transpose_d",
"partial_flag": true,
"attr":[
{
"name":"perm",
"param_type":"required",
"type":"listInt",
"value":"all"
},
{
"name":"shape",
"param_type":"required",
"type":"listInt",
"value":"all"
},
{
"name":"transpose_first",
"param_type":"required",
"type":"bool",
"value":"all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16", "float", "int8", "int16", "int32", "int64", "uint8", "uint16", "uint32",
"uint64", "float16", "float", "int8", "int16", "int32", "int64", "uint8", "uint16",
"uint32", "uint64"
],
"format": [
"FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ",
"FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ", "DefaultFormat", "DefaultFormat",
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat",
"DefaultFormat", "DefaultFormat", "DefaultFormat"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16", "float", "int8", "int16", "int32", "int64", "uint8", "uint16", "uint32",
"uint64", "float16", "float", "int8", "int16", "int32", "int64", "uint8", "uint16",
"uint32", "uint64"
],
"format": [
"FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ",
"FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ", "DefaultFormat", "DefaultFormat",
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat",
"DefaultFormat", "DefaultFormat", "DefaultFormat"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")

@op_info_register(confusion_transpose_d_op_info)
def _confusion_transpose_d_tbe():
"""ConfusionTransposeD TBE register"""
return

+ 22
- 106
mindspore/ops/_op_impl/tbe/conv2d.py View File

@@ -14,114 +14,30 @@
# ============================================================================

"""Conv2D op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

conv2d_op_info = TBERegOp("Conv2D") \
.fusion_type("CONVLUTION") \
.async_flag(False) \
.binfile_name("conv2d.so") \
.compute_cost(10) \
.kernel_name("conv2d") \
.partial_flag(True) \
.attr("stride", "required", "listInt", "all") \
.attr("pad_list", "required", "listInt", "all") \
.attr("dilation", "required", "listInt", "all") \
.attr("offset_a", "optional", "int", "all") \
.input(0, "x", False, "required", "all") \
.input(1, "filter", False, "required", "all") \
.input(2, "bias", False, "optional", "all") \
.input(3, "offset_w", False, "optional", "all") \
.output(0, "y", True, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_FracZ, DataType.F16_Default, DataType.I8_Default,
DataType.F16_5HD) \
.get_op_info()

@op_info_register("""{
"op_name": "Conv2D",
"imply_type": "TBE",
"fusion_type": "CONVLUTION",
"async_flag": false,
"binfile_name": "conv2d.so",
"compute_cost": 10,
"kernel_name": "conv2d",
"partial_flag": true,
"attr": [
{
"name": "stride",
"param_type": "required",
"type": "listInt",
"value": "all"
},
{
"name": "pad_list",
"param_type": "required",
"type": "listInt",
"value": "all"
},
{
"name": "dilation",
"param_type": "required",
"type": "listInt",
"value": "all"
},
{
"name": "offset_a",
"param_type": "optional",
"type": "int",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"NC1HWC0"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16"
],
"format": [
"FracZ"
],
"name": "filter",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float16"
],
"format": [
"DefaultFormat"
],
"name": "bias",
"need_compile": false,
"param_type": "optional",
"shape": "all"
},
{
"index": 3,
"dtype": [
"int8"
],
"format": [
"DefaultFormat"
],
"name": "offset_w",
"need_compile": false,
"param_type": "optional",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"NC1HWC0"
],
"name": "y",
"need_compile": true,
"param_type": "required",
"shape": "all"
}
]
}""")

@op_info_register(conv2d_op_info)
def _conv2d_tbe():
"""Conv2D TBE register"""
return

+ 19
- 81
mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py View File

@@ -14,89 +14,27 @@
# ============================================================================

"""Conv2DBackpropFilter op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

conv2d_backprop_filter_op_info = TBERegOp("Conv2DBackpropFilter") \
.fusion_type("CONVLUTION") \
.async_flag(False) \
.binfile_name("conv2d_backprop_filter_d.so") \
.compute_cost(10) \
.kernel_name("conv2d_backprop_filter_d") \
.partial_flag(True) \
.attr("filter_sizes", "required", "listInt", "all") \
.attr("stride", "required", "listInt", "all") \
.attr("pad_mode", "required", "str", "all") \
.attr("dilation", "required", "listInt", "all") \
.input(0, "out_backprop", False, "required", "all") \
.input(1, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F32_FracZ) \
.get_op_info()

# map to tbe kernel name conv2d_backprop_filter_d
@op_info_register("""{
"op_name": "Conv2DBackpropFilter",
"imply_type": "TBE",
"fusion_type": "CONVLUTION",
"async_flag": false,
"binfile_name": "conv2d_backprop_filter_d.so",
"compute_cost": 10,
"kernel_name": "conv2d_backprop_filter_d",
"partial_flag": true,
"attr": [
{
"name": "filter_sizes",
"param_type": "required",
"type": "listInt",
"value": "all"
},
{
"name": "stride",
"param_type": "required",
"type": "listInt",
"value": "all"
},
{
"name": "pad_mode",
"param_type": "required",
"type": "str",
"value": "all"
},
{
"name": "dilation",
"param_type": "required",
"type": "listInt",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"NC1HWC0"
],
"name": "out_backprop",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16"
],
"format": [
"NC1HWC0"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float32"
],
"format": [
"FracZ"
],
"name": "y",
"need_compile": true,
"param_type": "required",
"shape": "all"
}
]
}""")

@op_info_register(conv2d_backprop_filter_op_info)
def _conv2d_backprop_filter_tbe():
    """Conv2DBackpropFilter TBE register.

    The decorator performs the registration; the body is intentionally empty.
    """
    return

+ 19
- 80
mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py View File

@@ -14,88 +14,27 @@
# ============================================================================

"""Conv2DBackpropInput op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE op-info for Conv2DBackpropInput -> kernel "conv2d_backprop_input_d"
# (conv2d_backprop_input_d.so): fp16 5HD out_backprop + fp16 FracZ filter
# producing a fp16 5HD output; note the output is declared with
# need_compile=True (third positional arg), unlike its inputs.
# NOTE(review): "CONVLUTION" spelling is shared with the other conv
# registrations — presumably the exact string the backend expects; verify
# before changing.
conv2d_backprop_input_op_info = TBERegOp("Conv2DBackpropInput") \
    .fusion_type("CONVLUTION") \
    .async_flag(False) \
    .binfile_name("conv2d_backprop_input_d.so") \
    .compute_cost(10) \
    .kernel_name("conv2d_backprop_input_d") \
    .partial_flag(True) \
    .attr("input_sizes", "required", "listInt", "all") \
    .attr("stride", "required", "listInt", "all") \
    .attr("pad_mode", "required", "str", "all") \
    .attr("dilation", "required", "listInt", "all") \
    .input(0, "out_backprop", False, "required", "all") \
    .input(1, "filter", False, "required", "all") \
    .output(0, "y", True, "required", "all") \
    .dtype_format(DataType.F16_5HD, DataType.F16_FracZ, DataType.F16_5HD) \
    .get_op_info()

@op_info_register("""{
"op_name": "Conv2DBackpropInput",
"imply_type": "TBE",
"fusion_type": "CONVLUTION",
"async_flag": false,
"binfile_name": "conv2d_backprop_input_d.so",
"compute_cost": 10,
"kernel_name": "conv2d_backprop_input_d",
"partial_flag": true,
"attr": [
{
"name": "input_sizes",
"param_type": "required",
"type": "listInt",
"value": "all"
},
{
"name": "stride",
"param_type": "required",
"type": "listInt",
"value": "all"
},
{
"name": "pad_mode",
"param_type": "required",
"type": "str",
"value": "all"
},
{
"name": "dilation",
"param_type": "required",
"type": "listInt",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"NC1HWC0"
],
"name": "out_backprop",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16"
],
"format": [
"FracZ"
],
"name": "filter",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"NC1HWC0"
],
"name": "y",
"need_compile": true,
"param_type": "required",
"shape": "all"
}
]
}""")

@op_info_register(conv2d_backprop_input_op_info)
def _conv2d_backprop_input_tbe():
    """Conv2DBackpropInput TBE register.

    The decorator performs the registration; the body is intentionally empty.
    """
    return

+ 23
- 62
mindspore/ops/_op_impl/tbe/div.py View File

@@ -14,71 +14,32 @@
# ============================================================================

"""Div op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE op-info for Div: elementwise kernel "div" (div.so), inputs x1/x2 and
# output y. Each dtype_format line is one supported (x1, x2, y) combination:
# int8/uint8/int32/fp16/fp32, each in both DefaultFormat and NC1HWC0 (5HD).
div_op_info = TBERegOp("Div") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("div.so") \
    .compute_cost(10) \
    .kernel_name("div") \
    .partial_flag(True) \
    .input(0, "x1", False, "required", "all") \
    .input(1, "x2", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \
    .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
    .get_op_info()

@op_info_register("""{
"op_name": "Div",
"imply_type": "TBE",
"fusion_type": "ELEMWISE",
"async_flag": false,
"binfile_name": "div.so",
"compute_cost": 10,
"kernel_name": "div",
"partial_flag": true,
"attr": [

],
"inputs": [
{
"index": 0,
"dtype": [
"float16", "float16", "float", "float", "int32", "int32", "int8", "int8", "uint8", "uint8"
],
"format": [
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "NC1HWC0"
],
"name": "x1",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 0,
"dtype": [
"float16", "float16", "float", "float", "int32", "int32", "int8", "int8", "uint8", "uint8"
],
"format": [
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "NC1HWC0"
],
"name": "x2",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16", "float16", "float", "float", "int32", "int32", "int8", "int8", "uint8", "uint8"
],
"format": [
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "NC1HWC0"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")
@op_info_register(div_op_info)
def _div_tbe():
    """Div TBE register.

    The decorator performs the registration; the body is intentionally empty.
    """
    return

+ 17
- 68
mindspore/ops/_op_impl/tbe/dropout_do_mask.py View File

@@ -14,76 +14,25 @@
# ============================================================================

"""DropoutdoMask op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE op-info for DropoutDoMask -> kernel "drop_out_do_mask": inputs are the
# data tensor x, a uint8 mask, and keep_prob (same dtype as x). Supported
# combinations are fp16 and fp32, all in DefaultFormat.
drop_out_do_mask_op_info = TBERegOp("DropoutDoMask") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("drop_out_do_mask.so") \
    .compute_cost(10) \
    .kernel_name("drop_out_do_mask") \
    .partial_flag(True) \
    .input(0, "x", False, "required", "all") \
    .input(1, "mask", False, "required", "all") \
    .input(2, "keep_prob", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.U8_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.U8_Default, DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()

@op_info_register("""{
"op_name": "DropoutDoMask",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "drop_out_do_mask.so",
"compute_cost": 10,
"kernel_name": "drop_out_do_mask",
"partial_flag": true,
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"float16","float16","float16","float","float","float"
],
"format": [
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"uint8","uint8","uint8","uint8","uint8","uint8"
],
"format": [
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
],
"name": "mask",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float16","float16","float16","float","float","float"
],
"format": [
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
],
"name": "keep_prob",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16","float16","float16","float","float","float"
],
"format": [
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
],
"name": "y",
"param_type": "required",
"shape": "all"
}
]
}""")

@op_info_register(drop_out_do_mask_op_info)
def _dropout_do_mask_tbe():
    """DropoutDoMask TBE register.

    The decorator performs the registration; the body is intentionally empty.
    """
    return

+ 23
- 57
mindspore/ops/_op_impl/tbe/equal.py View File

@@ -14,66 +14,32 @@
# ============================================================================

"""Equal op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE op-info for Equal: elementwise comparison kernel "equal" (equal.so).
# Inputs x1/x2 share a dtype (int8/uint8/int32/fp16/fp32, in DefaultFormat or
# NC1HWC0); the output y is always bool in the matching format.
equal_op_info = TBERegOp("Equal") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("equal.so") \
    .compute_cost(10) \
    .kernel_name("equal") \
    .partial_flag(True) \
    .input(0, "x1", False, "required", "all") \
    .input(1, "x2", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.BOOL_Default) \
    .dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.BOOL_5HD) \
    .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.BOOL_Default) \
    .dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.BOOL_5HD) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \
    .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.BOOL_5HD) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.BOOL_5HD) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.BOOL_5HD) \
    .get_op_info()

@op_info_register("""{
"op_name": "Equal",
"imply_type": "TBE",
"fusion_type": "ELEMWISE",
"async_flag": false,
"binfile_name": "equal.so",
"compute_cost": 10,
"kernel_name": "equal",
"partial_flag": true,
"attr": [

],
"inputs": [
{
"index": 0,
"dtype": [
"float16","float16","float","float","int32","int32","int8","int8","uint8","uint8"
],
"format": [
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat",
"NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
],
"name": "x1",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16","float16","float","float","int32","int32","int8","int8","uint8","uint8"
],
"format": [
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat",
"NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
],
"name": "x2",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"bool","bool","bool","bool","bool","bool","bool","bool","bool","bool"
],
"format": [
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat",
"NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
],
"name": "y",
"param_type": "required"
}
]
}""")
@op_info_register(equal_op_info)
def _equal_tbe():
    """Equal TBE register.

    The decorator performs the registration; the body is intentionally empty.
    """
    return

+ 16
- 43
mindspore/ops/_op_impl/tbe/exp.py View File

@@ -14,52 +14,25 @@
# ============================================================================

"""Exp op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE op-info for Exp: unary elementwise kernel "exp" (exp.so), one input x
# and one output y; fp16 and fp32, each in DefaultFormat and NC1HWC0 (5HD).
exp_op_info = TBERegOp("Exp") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("exp.so") \
    .compute_cost(10) \
    .kernel_name("exp") \
    .partial_flag(True) \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
    .get_op_info()

@op_info_register("""{
"op_name": "Exp",
"imply_type": "TBE",
"fusion_type": "ELEMWISE",
"async_flag": false,
"binfile_name": "exp.so",
"compute_cost": 10,
"kernel_name": "exp",
"partial_flag": true,
"attr": [

],
"inputs": [
{
"index": 0,
"dtype": [
"float16", "float16", "float", "float"
],
"format": [
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16", "float16", "float", "float"
],
"format": [
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")
@op_info_register(exp_op_info)
def _exp_tbe():
    """Exp TBE register.

    The decorator performs the registration; the body is intentionally empty.
    """
    return

+ 17
- 49
mindspore/ops/_op_impl/tbe/expand_dims.py View File

@@ -14,57 +14,25 @@
# ============================================================================

"""ExpandDims op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE op-info for ExpandDims -> kernel "expand_dims": one input x, one output
# y, required "axis" attr (declared as listInt here), supporting
# int32/fp16/fp32 in DefaultFormat only.
expand_dims_op_info = TBERegOp("ExpandDims") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("expand_dims.so") \
    .compute_cost(10) \
    .kernel_name("expand_dims") \
    .partial_flag(True) \
    .attr("axis", "required", "listInt", "all") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()

@op_info_register("""{
"op_name": "ExpandDims",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "expand_dims.so",
"compute_cost": 10,
"kernel_name": "expand_dims",
"partial_flag": true,
"attr": [
{
"name": "axis",
"param_type": "required",
"type": "listInt",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16", "float32", "int32"
],
"format": [
"DefaultFormat", "DefaultFormat", "DefaultFormat"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16", "float32", "int32"
],
"format": [
"DefaultFormat", "DefaultFormat", "DefaultFormat"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")

@op_info_register(expand_dims_op_info)
def _expand_dims_tbe():
    """ExpandDims TBE register.

    The decorator performs the registration; the body is intentionally empty.
    """
    return

+ 18
- 55
mindspore/ops/_op_impl/tbe/floor_div.py View File

@@ -14,64 +14,27 @@
# ============================================================================

"""FloorDiv op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE op-info for FloorDiv: elementwise kernel "floordiv" (floordiv.so),
# inputs x1/x2 and output y; int8/uint8/int32/fp16/fp32, DefaultFormat only.
floordiv_op_info = TBERegOp("FloorDiv") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("floordiv.so") \
    .compute_cost(10) \
    .kernel_name("floordiv") \
    .partial_flag(True) \
    .input(0, "x1", False, "required", "all") \
    .input(1, "x2", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()

@op_info_register("""{
"op_name": "FloorDiv",
"imply_type": "TBE",
"fusion_type": "ELEMWISE",
"async_flag": false,
"binfile_name": "floordiv.so",
"compute_cost": 10,
"kernel_name": "floordiv",
"partial_flag": true,
"attr": [

],
"inputs": [
{
"index": 0,
"dtype": [
"float16","float","int32","int8","uint8"
],
"format": [
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
],
"name": "x1",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16","float","int32","int8","uint8"
],
"format": [
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
],
"name": "x2",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16","float","int32","int8","uint8"
],
"format": [
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
],
"name": "y",
"param_type": "required",
"shape": "all"
}
]
}""")
@op_info_register(floordiv_op_info)
def _floor_div_tbe():
    """FloorDiv TBE register.

    The decorator performs the registration; the body is intentionally empty.
    """
    return

+ 30
- 85
mindspore/ops/_op_impl/tbe/fused_mul_add.py View File

@@ -14,93 +14,38 @@
# ============================================================================

"""FusedMulAdd op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE op-info for FusedMulAdd -> kernel "fused_mul_add": three inputs
# x1/x2/x3 and one output y, all sharing one dtype/format per combination.
# Supported: int32/fp16/fp32, each in DefaultFormat, NC1HWC0 (5HD), FracZ,
# FRACTAL_NZ (FracNZ) and C1HWNCoC0.
fused_mul_add_op_info = TBERegOp("FusedMulAdd") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("fused_mul_add.so") \
    .compute_cost(10) \
    .kernel_name("fused_mul_add") \
    .partial_flag(True) \
    .input(0, "x1", False, "required", "all") \
    .input(1, "x2", False, "required", "all") \
    .input(2, "x3", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
    .dtype_format(DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ) \
    .dtype_format(DataType.I32_FracNZ, DataType.I32_FracNZ, DataType.I32_FracNZ, DataType.I32_FracNZ) \
    .dtype_format(DataType.I32_C1HWNCoC0, DataType.I32_C1HWNCoC0, DataType.I32_C1HWNCoC0, DataType.I32_C1HWNCoC0) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \
    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \
    .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
    .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \
    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
    .get_op_info()

@op_info_register("""{
"op_name": "FusedMulAdd",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "fused_mul_add.so",
"compute_cost": 10,
"kernel_name": "fused_mul_add",
"partial_flag": true,
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"int32", "int32", "int32", "int32", "int32",
"float16", "float16", "float16", "float16", "float16",
"float", "float", "float", "float", "float"
],
"format": [
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0",
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0",
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0"
],
"name": "x1",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"int32", "int32", "int32", "int32", "int32",
"float16", "float16", "float16", "float16", "float16",
"float", "float", "float", "float", "float"
],
"format": [
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0",
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0",
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0"
],
"name": "x2",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"int32", "int32", "int32", "int32", "int32",
"float16", "float16", "float16", "float16", "float16",
"float", "float", "float", "float", "float"
],
"format": [
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0",
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0",
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0"
],
"name": "x3",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"int32", "int32", "int32", "int32", "int32",
"float16", "float16", "float16", "float16", "float16",
"float", "float", "float", "float", "float"
],
"format": [
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0",
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0",
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0"
],
"name": "y",
"param_type": "required",
"shape": "all"
}
]
}""")

@op_info_register(fused_mul_add_op_info)
def _fused_mul_add_tbe():
    """FusedMulAdd TBE register.

    The decorator performs the registration; the body is intentionally empty.
    """
    return

+ 22
- 77
mindspore/ops/_op_impl/tbe/fused_mul_add_n.py View File

@@ -14,86 +14,31 @@
# ============================================================================

"""FusedMulAddN op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE op-info for FusedMulAddN -> kernel "fused_mul_add_n": inputs x1/x2
# follow the tensor's layout (5HD / C1HWNCoC0 / DefaultFormat / FracZ) while
# x3 is always DefaultFormat (it is the scalar-like addend); output matches
# x1/x2. Supported dtypes: fp16 and fp32.
fused_mul_add_n_op_info = TBERegOp("FusedMulAddN") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("fused_mul_add_n.so") \
    .compute_cost(10) \
    .kernel_name("fused_mul_add_n") \
    .partial_flag(True) \
    .input(0, "x1", False, "required", "all") \
    .input(1, "x2", False, "required", "all") \
    .input(2, "x3", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_5HD) \
    .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_C1HWNCoC0) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_FracZ) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD) \
    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_C1HWNCoC0) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_FracZ) \
    .get_op_info()

@op_info_register("""{
"op_name": "FusedMulAddN",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "fused_mul_add_n.so",
"compute_cost": 10,
"kernel_name": "fused_mul_add_n",
"partial_flag": true,
"attr": [

],
"inputs": [
{
"index": 0,
"dtype": [
"float16","float16","float16","float16",
"float","float","float","float"
],
"format": [
"NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ",
"NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ"
],
"name": "x1",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16","float16","float16","float16",
"float","float","float","float"
],
"format": [
"NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ",
"NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ"
],
"name": "x2",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float16","float16","float16","float16",
"float","float","float","float"
],
"format": [
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat",
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
],
"name": "x3",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16","float16","float16","float16",
"float","float","float","float"
],
"format": [
"NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ",
"NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")
@op_info_register(fused_mul_add_n_op_info)
def _fused_mul_add_n_tbe():
    """FusedMulAddN TBE register.

    The decorator performs the registration; the body is intentionally empty.
    """
    return

+ 35
- 129
mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum.py View File

@@ -14,137 +14,43 @@
# ============================================================================

"""FusedMulApplyMomentum op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE op-info for FusedMulApplyMomentum -> kernel "fused_mul_apply_momentum".
# Inputs: var/accum/x1/x2 follow the tensor layout (5HD / C1HWNCoC0 /
# DefaultFormat / FracZ) while the scalars lr and momentum are always
# DefaultFormat; the output "var" matches the tensor layout. fp16 and fp32.
# NOTE(review): the attr call has five arguments — the trailing "false"
# appears to be the default value for use_nesterov (the other ops use the
# four-arg form); confirm against the TBERegOp.attr signature.
fused_mul_apply_momentum_op_info = TBERegOp("FusedMulApplyMomentum") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("fused_mul_apply_momentum.so") \
    .compute_cost(10) \
    .kernel_name("fused_mul_apply_momentum") \
    .partial_flag(True) \
    .attr("use_nesterov", "optional", "bool", "true,false", "false") \
    .input(0, "var", False, "required", "all") \
    .input(1, "accum", False, "required", "all") \
    .input(2, "lr", False, "required", "all") \
    .input(3, "x1", False, "required", "all") \
    .input(4, "momentum", False, "required", "all") \
    .input(5, "x2", False, "required", "all") \
    .output(0, "var", False, "required", "all") \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_5HD,
                  DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_C1HWNCoC0,
                  DataType.F16_Default, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_FracZ,
                  DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD,
                  DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD) \
    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_C1HWNCoC0,
                  DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_FracZ,
                  DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ) \
    .get_op_info()

@op_info_register("""{
"op_name": "FusedMulApplyMomentum",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "fused_mul_apply_momentum.so",
"compute_cost": 10,
"kernel_name": "fused_mul_apply_momentum",
"partial_flag": true,
"attr": [
{
"name": "use_nesterov",
"param_type": "optional",
"type": "bool",
"value": "true,false",
"default_value":"false"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16","float16","float16","float16",
"float","float","float","float"
],
"format": [
"NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ",
"NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ"
],
"name": "var",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16","float16","float16","float16",
"float","float","float","float"
],
"format": [
"NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ",
"NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ"
],
"name": "accum",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float16","float16","float16","float16",
"float","float","float","float"
],
"format": [
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat",
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
],
"name": "lr",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 3,
"dtype": [
"float16","float16","float16","float16",
"float","float","float","float"
],
"format": [
"NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ",
"NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ"
],
"name": "x1",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 4,
"dtype": [
"float16","float16","float16","float16",
"float","float","float","float"
],
"format": [
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat",
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
],
"name": "momentum",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 5,
"dtype": [
"float16","float16","float16","float16",
"float","float","float","float"
],
"format": [
"NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ",
"NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ"
],
"name": "x2",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16","float16","float16","float16",
"float","float","float","float"
],
"format": [
"NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ",
"NC1HWC0","C1HWNCoC0","DefaultFormat","FracZ"
],
"name": "var",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")

@op_info_register(fused_mul_apply_momentum_op_info)
def _fused_mul_apply_momentum_tbe():
    """FusedMulApplyMomentum TBE register.

    The decorator performs the registration; the body is intentionally empty.
    """
    return

+ 45
- 86
mindspore/ops/_op_impl/tbe/gather_v2.py View File

@@ -14,94 +14,53 @@
# ============================================================================

"""AddN op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE op-info for GatherV2 -> kernel "gather_v2_d" (gather_v2_d.so), with an
# optional "axis" int attr. Each dtype_format line is one (x, indices, y)
# combination: x/y in int8/uint8/int32/fp16/fp32 with indices in int32 or
# int64, across DefaultFormat, NC1HWC0 (5HD) and FracZ; indices always share
# the data tensor's format.
gather_v2_op_info = TBERegOp("GatherV2") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("gather_v2_d.so") \
    .compute_cost(10) \
    .kernel_name("gather_v2_d") \
    .partial_flag(True) \
    .attr("axis", "optional", "int", "all") \
    .input(0, "x", False, "required", "all") \
    .input(1, "indices", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default) \
    .dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I8_Default) \
    .dtype_format(DataType.I8_5HD, DataType.I32_5HD, DataType.I8_5HD) \
    .dtype_format(DataType.I8_5HD, DataType.I64_5HD, DataType.I8_5HD) \
    .dtype_format(DataType.I8_FracZ, DataType.I32_FracZ, DataType.I8_FracZ) \
    .dtype_format(DataType.I8_FracZ, DataType.I64_FracZ, DataType.I8_FracZ) \
    .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default) \
    .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U8_Default) \
    .dtype_format(DataType.U8_5HD, DataType.I32_5HD, DataType.U8_5HD) \
    .dtype_format(DataType.U8_5HD, DataType.I64_5HD, DataType.U8_5HD) \
    .dtype_format(DataType.U8_FracZ, DataType.I32_FracZ, DataType.U8_FracZ) \
    .dtype_format(DataType.U8_FracZ, DataType.I64_FracZ, DataType.U8_FracZ) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \
    .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
    .dtype_format(DataType.I32_5HD, DataType.I64_5HD, DataType.I32_5HD) \
    .dtype_format(DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ) \
    .dtype_format(DataType.I32_FracZ, DataType.I64_FracZ, DataType.I32_FracZ) \
    .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_5HD, DataType.I32_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F16_5HD, DataType.I64_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F16_FracZ, DataType.I32_FracZ, DataType.F16_FracZ) \
    .dtype_format(DataType.F16_FracZ, DataType.I64_FracZ, DataType.F16_FracZ) \
    .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_5HD, DataType.I32_5HD, DataType.F32_5HD) \
    .dtype_format(DataType.F32_5HD, DataType.I64_5HD, DataType.F32_5HD) \
    .dtype_format(DataType.F32_FracZ, DataType.I32_FracZ, DataType.F32_FracZ) \
    .dtype_format(DataType.F32_FracZ, DataType.I64_FracZ, DataType.F32_FracZ) \
    .get_op_info()

@op_info_register("""{
"op_name": "GatherV2",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "gather_v2_d.so",
"compute_cost": 10,
"kernel_name": "gather_v2_d",
"partial_flag": true,
"attr": [
{
"name": "axis",
"param_type": "optional",
"type": "int",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16","float16","float16","float16","float16","float16",
"float","float","float","float","float","float",
"int32","int32","int32", "int32","int32","int32",
"uint8","uint8","uint8","uint8","uint8","uint8",
"int8","int8", "int8","int8","int8", "int8"
],
"format": [
"DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ",
"DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ",
"DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ",
"DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ",
"DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"int32","int32","int32","int64","int64","int64",
"int32","int32","int32","int64","int64","int64",
"int32","int32","int32","int64","int64","int64",
"int32","int32","int32","int64","int64","int64",
"int32","int32","int32","int64","int64","int64"
],
"format": [
"DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ",
"DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ",
"DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ",
"DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ",
"DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ"
],
"name": "indices",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16","float16","float16","float16","float16","float16",
"float","float","float","float","float","float",
"int32","int32","int32", "int32","int32","int32",
"uint8","uint8","uint8","uint8","uint8","uint8",
"int8","int8", "int8","int8","int8", "int8"
],
"format": [
"DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ",
"DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ",
"DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ",
"DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ",
"DefaultFormat","NC1HWC0","FracZ","DefaultFormat","NC1HWC0","FracZ"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")

@op_info_register(gather_v2_op_info)
def _gather_v2_tbe():
    """Register the GatherV2 TBE operator info."""

+ 21
- 43
mindspore/ops/_op_impl/tbe/gelu.py View File

@@ -14,51 +14,29 @@
# ============================================================================

"""Gelu op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE registration spec for the Gelu op: elementwise kernel "gelu" in
# gelu.so; single input x / output y sharing one dtype and layout,
# covering fp16/fp32 in Default, 5HD, FracZ and FracNZ formats.
gelu_op_info = TBERegOp("Gelu") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("gelu.so") \
    .compute_cost(10) \
    .kernel_name("gelu") \
    .partial_flag(True) \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \
    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \
    .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ) \
    .get_op_info()

@op_info_register("""{
"op_name": "Gelu",
"imply_type": "TBE",
"fusion_type": "ELEMWISE",
"async_flag": false,
"binfile_name": "gelu.so",
"compute_cost": 10,
"kernel_name": "gelu",
"partial_flag": true,
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"float16","float","float16","float","float16","float16","float16","float16","float","float","float","float"
],
"format": [
"FRACTAL_NZ","FRACTAL_NZ","FracZ","FracZ","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16","float","float16","float","float16","float16","float16","float16","float","float","float","float"
],
"format": [
"FRACTAL_NZ","FRACTAL_NZ","FracZ","FracZ","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat"
],
"name": "y",
"need_compile": true,
"param_type": "required",
"shape": "all"
}
]
}""")

@op_info_register(gelu_op_info)
def _gelu_tbe():
    """Register the Gelu TBE operator info."""

+ 21
- 69
mindspore/ops/_op_impl/tbe/gelu_grad.py View File

@@ -14,77 +14,29 @@
# ============================================================================

"""GeluGrad op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE registration spec for the GeluGrad op: elementwise kernel
# "gelu_grad"; inputs (dy, x, y) and output z all share one dtype and
# layout, covering fp16/fp32 in Default, 5HD and FracNZ formats.
gelu_grad_op_info = TBERegOp("GeluGrad") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("gelu_grad.so") \
    .compute_cost(10) \
    .kernel_name("gelu_grad") \
    .partial_flag(True) \
    .input(0, "dy", False, "required", "all") \
    .input(1, "x", False, "required", "all") \
    .input(2, "y", False, "required", "all") \
    .output(0, "z", True, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
    .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \
    .get_op_info()

@op_info_register("""{
"op_name": "GeluGrad",
"imply_type": "TBE",
"fusion_type": "ELEMWISE",
"async_flag": false,
"binfile_name": "gelu_grad.so",
"compute_cost": 10,
"kernel_name": "gelu_grad",
"partial_flag": true,
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"float16","float16","float16","float","float","float"
],
"format": [
"FRACTAL_NZ","DefaultFormat","NC1HWC0","FRACTAL_NZ","DefaultFormat","NC1HWC0"
],
"name": "dy",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16","float16","float16","float","float","float"
],
"format": [
"FRACTAL_NZ","DefaultFormat","NC1HWC0","FRACTAL_NZ","DefaultFormat","NC1HWC0"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float16","float16","float16","float","float","float"
],
"format": [
"FRACTAL_NZ","DefaultFormat","NC1HWC0","FRACTAL_NZ","DefaultFormat","NC1HWC0"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16","float16","float16","float","float","float"
],
"format": [
"FRACTAL_NZ","DefaultFormat","NC1HWC0","FRACTAL_NZ","DefaultFormat","NC1HWC0"
],
"name": "z",
"need_compile": true,
"param_type": "required",
"shape": "all"
}
]
}""")

@op_info_register(gelu_grad_op_info)
def _gelu_grad_tbe():
    """Register the GeluGrad TBE operator info."""

+ 23
- 59
mindspore/ops/_op_impl/tbe/greater.py View File

@@ -14,68 +14,32 @@
# ============================================================================

"""Greater op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE registration spec for the Greater op: kernel "greater"; compares
# two like-typed inputs (int8/uint8/int32/fp16/fp32 in Default or 5HD
# layout) and produces a bool output in the matching layout.
greater_op_info = TBERegOp("Greater") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("greater.so") \
    .compute_cost(10) \
    .kernel_name("greater") \
    .partial_flag(True) \
    .input(0, "x1", False, "required", "all") \
    .input(1, "x2", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.BOOL_Default) \
    .dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.BOOL_5HD) \
    .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.BOOL_Default) \
    .dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.BOOL_5HD) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \
    .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.BOOL_5HD) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.BOOL_5HD) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.BOOL_5HD) \
    .get_op_info()

@op_info_register("""{
"op_name": "Greater",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "greater.so",
"compute_cost": 10,
"kernel_name": "greater",
"partial_flag": true,
"attr": [

],
"inputs": [
{
"index": 0,
"dtype": [
"float16","float16","float","float","int32","int32","int8","int8","uint8","uint8"
],
"format": [
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat",
"NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
],
"name": "x1",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16","float16","float","float","int32","int32","int8","int8","uint8","uint8"
],
"format": [
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat",
"NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
],
"name": "x2",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"bool","bool","bool","bool","bool","bool","bool","bool","bool","bool"
],
"format": [
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat",
"NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")
@op_info_register(greater_op_info)
def _greater_tbe():
    """Register the Greater TBE operator info."""

+ 38
- 271
mindspore/ops/_op_impl/tbe/lamb_next_mv.py View File

@@ -14,279 +14,46 @@
# ============================================================================

"""LambNextMV op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE registration spec for the fused LambNextMV op: kernel
# "lamb_next_m_v"; 13 inputs (input1..input9, inputx0..inputx3) and 4
# outputs, all sharing one dtype (fp16 or fp32) in the default layout.
lamb_next_mv_op_info = TBERegOp("LambNextMV") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("lamb_next_m_v.so") \
    .compute_cost(10) \
    .kernel_name("lamb_next_m_v") \
    .partial_flag(True) \
    .input(0, "input1", False, "required", "all") \
    .input(1, "input2", False, "required", "all") \
    .input(2, "input3", False, "required", "all") \
    .input(3, "input4", False, "required", "all") \
    .input(4, "input5", False, "required", "all") \
    .input(5, "input6", False, "required", "all") \
    .input(6, "input7", False, "required", "all") \
    .input(7, "input8", False, "required", "all") \
    .input(8, "input9", False, "required", "all") \
    .input(9, "inputx0", False, "required", "all") \
    .input(10, "inputx1", False, "required", "all") \
    .input(11, "inputx2", False, "required", "all") \
    .input(12, "inputx3", False, "required", "all") \
    .output(0, "output1", False, "required", "all") \
    .output(1, "output2", False, "required", "all") \
    .output(2, "output3", False, "required", "all") \
    .output(3, "output4", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default) \
    .get_op_info()

@op_info_register("""{
"op_name":"LambNextMV",
"imply_type":"TBE",
"fusion_type":"ELEMWISE",
"async_flag":false,
"binfile_name":"lamb_next_m_v.so",
"compute_cost":10,
"kernel_name":"lamb_next_m_v",
"partial_flag":true,
"attr":[],
"inputs":[
{
"index":0,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"input1",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":1,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"input2",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":2,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"input3",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":3,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"input4",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":4,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"input5",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":5,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"input6",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":6,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"input7",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":7,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"input8",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":8,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"input9",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":9,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"inputx0",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":10,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"inputx1",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":11,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"inputx2",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":12,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"inputx3",
"need_compile":false,
"param_type":"required",
"shape":"all"
}
],
"outputs":[
{
"index":0,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"output1",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":1,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"output2",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":2,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"output3",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":3,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"output4",
"need_compile":false,
"param_type":"required",
"shape":"all"
}
]
}""")

@op_info_register(lamb_next_mv_op_info)
def _lamb_next_mv_tbe():
    """Register the LambNextMV TBE operator info."""

+ 38
- 271
mindspore/ops/_op_impl/tbe/lamb_next_mv_with_decay_v1.py View File

@@ -14,279 +14,46 @@
# ============================================================================

"""LambNextMVWithDecayV1 op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE registration spec for the fused LambNextMVWithDecayV1 op: kernel
# "lamb_next_m_v_with_decay_v1"; same 13-input / 4-output layout as
# LambNextMV, all operands sharing one dtype (fp16 or fp32), default format.
lamb_next_m_v_with_decay_v1_op_info = TBERegOp("LambNextMVWithDecayV1") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("lamb_next_m_v_with_decay_v1.so") \
    .compute_cost(10) \
    .kernel_name("lamb_next_m_v_with_decay_v1") \
    .partial_flag(True) \
    .input(0, "input1", False, "required", "all") \
    .input(1, "input2", False, "required", "all") \
    .input(2, "input3", False, "required", "all") \
    .input(3, "input4", False, "required", "all") \
    .input(4, "input5", False, "required", "all") \
    .input(5, "input6", False, "required", "all") \
    .input(6, "input7", False, "required", "all") \
    .input(7, "input8", False, "required", "all") \
    .input(8, "input9", False, "required", "all") \
    .input(9, "inputx0", False, "required", "all") \
    .input(10, "inputx1", False, "required", "all") \
    .input(11, "inputx2", False, "required", "all") \
    .input(12, "inputx3", False, "required", "all") \
    .output(0, "output1", False, "required", "all") \
    .output(1, "output2", False, "required", "all") \
    .output(2, "output3", False, "required", "all") \
    .output(3, "output4", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default) \
    .get_op_info()

@op_info_register("""{
"op_name":"LambNextMVWithDecayV1",
"imply_type":"TBE",
"fusion_type":"OPAQUE",
"async_flag":false,
"binfile_name":"lamb_next_m_v_with_decay_v1.so",
"compute_cost":10,
"kernel_name":"lamb_next_m_v_with_decay_v1",
"partial_flag":true,
"attr":[],
"inputs":[
{
"index":0,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"input1",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":1,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"input2",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":2,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"input3",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":3,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"input4",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":4,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"input5",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":5,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"input6",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":6,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"input7",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":7,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"input8",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":8,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"input9",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":9,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"inputx0",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":10,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"inputx1",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":11,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"inputx2",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":12,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"inputx3",
"need_compile":false,
"param_type":"required",
"shape":"all"
}
],
"outputs":[
{
"index":0,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"output1",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":1,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"output2",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":2,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"output3",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":3,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"output4",
"need_compile":false,
"param_type":"required",
"shape":"all"
}
]
}""")

@op_info_register(lamb_next_m_v_with_decay_v1_op_info)
def _lamb_next_mv_with_decay_v1_tbe():
    """Register the LambNextMVWithDecayV1 TBE operator info."""

+ 27
- 166
mindspore/ops/_op_impl/tbe/lamb_update_with_lr.py View File

@@ -14,174 +14,35 @@
# ============================================================================

"""LambUpdateWithLr op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE registration spec for the fused LambUpdateWithLR op (registered
# under the name "LambUpdateWithLR"): kernel "lamb_update_with_lr";
# 9 inputs and one output, all one dtype (fp16 or fp32), default format.
lamb_update_with_lr_op_info = TBERegOp("LambUpdateWithLR") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("lamb_update_with_lr.so") \
    .compute_cost(10) \
    .kernel_name("lamb_update_with_lr") \
    .partial_flag(True) \
    .input(0, "input1", False, "required", "all") \
    .input(1, "input2", False, "required", "all") \
    .input(2, "input3", False, "required", "all") \
    .input(3, "input4", False, "required", "all") \
    .input(4, "input5", False, "required", "all") \
    .input(5, "input6", False, "required", "all") \
    .input(6, "input7", False, "required", "all") \
    .input(7, "input8", False, "required", "all") \
    .input(8, "input9", False, "required", "all") \
    .output(0, "output_y", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()

@op_info_register("""{
"op_name":"LambUpdateWithLR",
"imply_type":"TBE",
"fusion_type":"ELEMWISE",
"async_flag":false,
"binfile_name":"lamb_update_with_lr.so",
"compute_cost":10,
"kernel_name":"lamb_update_with_lr",
"partial_flag":true,
"attr":[],
"inputs":[
{
"index":0,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"input1",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":1,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"input2",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":2,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"input3",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":3,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"input4",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":4,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"input5",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":5,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"input6",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":6,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"input7",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":7,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"input8",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":8,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"input9",
"need_compile":false,
"param_type":"required",
"shape":"all"
}
],
"outputs":[
{
"index":0,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"output_y",
"need_compile":false,
"param_type":"required",
"shape":"all"
}
]
}""")

@op_info_register(lamb_update_with_lr_op_info)
def _lamb_update_with_lr_tbe():
    """Register the LambUpdateWithLR TBE operator info."""

+ 23
- 136
mindspore/ops/_op_impl/tbe/lamb_update_with_lr_v2.py View File

@@ -14,144 +14,31 @@
# ============================================================================

"""LambUpdateWithLrV2 op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE registration spec for the fused LambUpdateWithLrV2 op: kernel
# "lamb_update_with_lr_v2"; inputs x1..x5 plus greater_y and select_e,
# one output y, all one dtype (fp16 or fp32) in the default format.
lamb_update_with_lr_v2_op_info = TBERegOp("LambUpdateWithLrV2") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("lamb_update_with_lr_v2.so") \
    .compute_cost(10) \
    .kernel_name("lamb_update_with_lr_v2") \
    .partial_flag(True) \
    .input(0, "x1", False, "required", "all") \
    .input(1, "x2", False, "required", "all") \
    .input(2, "x3", False, "required", "all") \
    .input(3, "x4", False, "required", "all") \
    .input(4, "x5", False, "required", "all") \
    .input(5, "greater_y", False, "required", "all") \
    .input(6, "select_e", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()

@op_info_register("""{
"op_name":"LambUpdateWithLrV2",
"imply_type":"TBE",
"fusion_type":"ELEMWISE",
"async_flag":false,
"binfile_name":"lamb_update_with_lr_v2.so",
"compute_cost":10,
"kernel_name":"lamb_update_with_lr_v2",
"partial_flag":true,
"attr":[],
"inputs":[
{
"index":0,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"x1",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":1,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"x2",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":2,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"x3",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":3,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"x4",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":4,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"x5",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":5,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"greater_y",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":6,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"select_e",
"need_compile":false,
"param_type":"required",
"shape":"all"
}
],
"outputs":[
{
"index":0,
"dtype":[
"float16",
"float32"
],
"format":[
"DefaultFormat",
"DefaultFormat"
],
"name":"y",
"need_compile":false,
"param_type":"required",
"shape":"all"
}
]
}""")

@op_info_register(lamb_update_with_lr_v2_op_info)
def _lamb_update_with_lr_v2_tbe():
    """Register the LambUpdateWithLrV2 TBE operator info."""

+ 30
- 102
mindspore/ops/_op_impl/tbe/layer_norm.py View File

@@ -14,111 +14,39 @@
# ============================================================================

"""LayerNorm op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE registration spec for the LayerNorm op: kernel "layer_norm" with
# required int attrs begin_norm_axis / begin_params_axis; inputs
# (x, gamma, beta) and outputs (y, mean, variance). In the FracNZ rows
# only x and y use FracNZ while gamma/beta/mean/variance stay Default.
layer_norm_op_info = TBERegOp("LayerNorm") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("layer_norm.so") \
    .compute_cost(10) \
    .kernel_name("layer_norm") \
    .partial_flag(True) \
    .attr("begin_norm_axis", "required", "int", "all") \
    .attr("begin_params_axis", "required", "int", "all") \
    .input(0, "x", False, "required", "all") \
    .input(1, "gamma", False, "required", "all") \
    .input(2, "beta", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .output(1, "mean", False, "required", "all") \
    .output(2, "variance", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
                  DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F16_FracNZ, DataType.F16_Default, DataType.F16_Default, DataType.F16_FracNZ,
                  DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
                  DataType.F32_5HD, DataType.F32_5HD) \
    .dtype_format(DataType.F32_FracNZ, DataType.F32_Default, DataType.F32_Default, DataType.F32_FracNZ,
                  DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()

@op_info_register("""{
"op_name": "LayerNorm",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "layer_norm.so",
"compute_cost": 10,
"kernel_name": "layer_norm",
"partial_flag": true,
"attr": [
{
"name": "begin_norm_axis",
"param_type": "required",
"type": "int",
"value": "all"
},
{
"name": "begin_params_axis",
"param_type": "required",
"type": "int",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16","float16","float16","float","float","float"
],
"format": [
"FRACTAL_NZ","DefaultFormat","NC1HWC0","FRACTAL_NZ","DefaultFormat","NC1HWC0"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16","float16","float16","float","float","float"
],
"format": [
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
],
"name": "gamma",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float16","float16","float16","float","float","float"
],
"format": [
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
],
"name": "beta",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16","float16","float16","float","float","float"
],
"format": [
"FRACTAL_NZ","DefaultFormat","NC1HWC0","FRACTAL_NZ","DefaultFormat","NC1HWC0"
],
"name": "y",
"param_type": "required"
},
{
"index": 1,
"dtype": [
"float16","float16","float16","float","float","float"
],
"format": [
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"

],
"name": "mean",
"param_type": "required"
},
{
"index": 2,
"dtype": [
"float16","float16","float16","float","float","float"
],
"format": [
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"

],
"name": "variance",
"param_type": "required"
}
]
}""")
@op_info_register(layer_norm_op_info)
def _layer_norm_tbe():
    """Register the LayerNorm TBE operator info."""

+ 30
- 97
mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py View File

@@ -14,105 +14,38 @@
# ============================================================================

"""LayerNormBetaGammaBackprop op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

layer_norm_beta_gamma_backprop_op_info = TBERegOp("LayerNormBetaGammaBackprop") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("layer_norm_beta_gamma_backprop.so") \
.compute_cost(10) \
.kernel_name("layer_norm_beta_gamma_backprop") \
.partial_flag(True) \
.attr("shape_gamma", "required", "listInt", "all") \
.input(0, "dy", False, "required", "all") \
.input(1, "x", False, "required", "all") \
.input(2, "variance", False, "required", "all") \
.input(3, "mean", False, "required", "all") \
.output(0, "pd_gamma", False, "required", "all") \
.output(1, "pd_beta", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.F16_Default,
DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default) \
.get_op_info()

@op_info_register("""{
"op_name": "LayerNormBetaGammaBackprop",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "layer_norm_beta_gamma_backprop.so",
"compute_cost": 10,
"kernel_name": "layer_norm_beta_gamma_backprop",
"partial_flag": true,
"attr": [
{
"name": "shape_gamma",
"param_type": "required",
"type": "listInt",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16","float16","float16","float","float","float"
],
"format": [
"FRACTAL_NZ","DefaultFormat","NC1HWC0","FRACTAL_NZ","DefaultFormat","NC1HWC0"
],
"name": "dy",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16","float16","float16","float","float","float"
],
"format": [
"FRACTAL_NZ","DefaultFormat","NC1HWC0","FRACTAL_NZ","DefaultFormat","NC1HWC0"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float16","float16","float16","float","float","float"
],
"format": [
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
],
"name": "variance",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 3,
"dtype": [
"float16","float16","float16","float","float","float"
],
"format": [
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
],
"name": "mean",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float","float","float","float","float","float"
],
"format": [
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
],
"name": "pd_gamma",
"param_type": "required"
},
{
"index": 1,
"dtype": [
"float","float","float","float","float","float"
],
"format": [
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
],
"name": "pd_beta",
"param_type": "required"
}
]
}""")

@op_info_register(layer_norm_beta_gamma_backprop_op_info)
def _layer_norm_beta_gamma_backprop_tbe():
"""LayerNormBetaGammaBackprop TBE register"""
return

+ 26
- 115
mindspore/ops/_op_impl/tbe/layer_norm_grad.py View File

@@ -14,124 +14,35 @@
# ============================================================================

"""LayerNormGrad op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

layer_norm_grad_op_info = TBERegOp("LayerNormGrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("layer_norm_grad.so") \
.compute_cost(10) \
.kernel_name("layer_norm_grad") \
.partial_flag(True) \
.input(0, "dy", False, "required", "all") \
.input(1, "x", False, "required", "all") \
.input(2, "variance", False, "required", "all") \
.input(3, "mean", False, "required", "all") \
.input(4, "gamma", False, "required", "all") \
.output(0, "pd_x", False, "required", "all") \
.output(1, "pd_gamma", False, "required", "all") \
.output(2, "pd_beta", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()

@op_info_register("""{
"op_name": "LayerNormGrad",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "layer_norm_grad.so",
"compute_cost": 10,
"kernel_name": "layer_norm_grad",
"partial_flag": true,
"attr": [

],
"inputs": [
{
"index": 0,
"dtype": [
"float16","float16","float","float"
],
"format": [
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
],
"name": "dy",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16","float16","float","float"
],
"format": [
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float16","float16","float","float"
],
"format": [
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
],
"name": "variance",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 3,
"dtype": [
"float16","float16","float","float"
],
"format": [
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
],
"name": "mean",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 4,
"dtype": [
"float16","float16","float","float"
],
"format": [
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
],
"name": "gamma",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16","float16","float","float"
],
"format": [
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
],
"name": "pd_x",
"param_type": "required"
},
{
"index": 1,
"dtype": [
"float16","float16","float","float"
],
"format": [
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
],
"name": "pd_gamma",
"param_type": "required"
},
{
"index": 2,
"dtype": [
"float16","float16","float","float"
],
"format": [
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
],
"name": "pd_beta",
"param_type": "required"
}
]
}""")
@op_info_register(layer_norm_grad_op_info)
def _layer_norm_grad_tbe():
"""LayerNormGrad TBE register"""
return

+ 28
- 93
mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py View File

@@ -14,102 +14,37 @@
# ============================================================================

"""LayerNormXBackprop op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

layer_norm_x_backprop_op_info = TBERegOp("LayerNormXBackprop") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("layer_norm_x_backprop.so") \
.compute_cost(10) \
.kernel_name("layer_norm_x_backprop") \
.partial_flag(True) \
.input(0, "dy", False, "required", "all") \
.input(1, "x", False, "required", "all") \
.input(2, "variance", False, "required", "all") \
.input(3, "mean", False, "required", "all") \
.input(4, "gamma", False, "required", "all") \
.output(0, "pd_x", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_FracNZ) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_FracNZ) \
.get_op_info()

@op_info_register("""{
"op_name": "LayerNormXBackprop",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "layer_norm_x_backprop.so",
"compute_cost": 10,
"kernel_name": "layer_norm_x_backprop",
"partial_flag": true,
"attr": [

],
"inputs": [
{
"index": 0,
"dtype": [
"float16","float16","float16","float","float","float"
],
"format": [
"FRACTAL_NZ","DefaultFormat","NC1HWC0","FRACTAL_NZ","DefaultFormat","NC1HWC0"
],
"name": "dy",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16","float16","float16","float","float","float"
],
"format": [
"FRACTAL_NZ","DefaultFormat","NC1HWC0","FRACTAL_NZ","DefaultFormat","NC1HWC0"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float16","float16","float16","float","float","float"
],
"format": [
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
],
"name": "variance",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 3,
"dtype": [
"float16","float16","float16","float","float","float"
],
"format": [
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
],
"name": "mean",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 4,
"dtype": [
"float16","float16","float16","float","float","float"
],
"format": [
"DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","NC1HWC0"
],
"name": "gamma",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16","float16","float16","float","float","float"
],
"format": [
"FRACTAL_NZ","DefaultFormat","NC1HWC0","FRACTAL_NZ","DefaultFormat","NC1HWC0"
],
"name": "pd_x",
"param_type": "required"
}
]
}""")
@op_info_register(layer_norm_x_backprop_op_info)
def _layer_norm_x_backprop_tbe():
"""LayerNormXBackprop TBE register"""
return

+ 23
- 58
mindspore/ops/_op_impl/tbe/less.py View File

@@ -14,67 +14,32 @@
# ============================================================================

"""Less op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

less_op_info = TBERegOp("Less") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("less.so") \
.compute_cost(10) \
.kernel_name("less") \
.partial_flag(True) \
.input(0, "x1", False, "required", "all") \
.input(1, "x2", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.BOOL_5HD) \
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.BOOL_Default) \
.dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.BOOL_5HD) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.BOOL_5HD) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.BOOL_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.BOOL_5HD) \
.get_op_info()

@op_info_register("""{
"op_name": "Less",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "less.so",
"compute_cost": 10,
"kernel_name": "less",
"partial_flag": true,
"attr": [

],
"inputs": [
{
"index": 0,
"dtype": [
"float16","float16","float","float","int32","int32","int8","int8","uint8","uint8"
],
"format": [
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat",
"NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
],
"name": "x1",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16","float16","float","float","int32","int32","int8","int8","uint8","uint8"
],
"format": [
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat",
"NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
],
"name": "x2",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"bool","bool","bool","bool","bool","bool","bool","bool","bool","bool"
],
"format": [
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat",
"NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
],
"name": "y",
"param_type": "required",
"shape": "all"
}
]
}""")
@op_info_register(less_op_info)
def _less_tbe():
"""Less TBE register"""
return

+ 25
- 58
mindspore/ops/_op_impl/tbe/less_equal.py View File

@@ -14,67 +14,34 @@
# ============================================================================

"""LessEqual op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# Op info for the TBE LessEqual kernel: elementwise x1 <= x2 -> bool.
# NOTE: LessEqual takes no attributes — the previous JSON registration had an
# empty "attr" list, and the sibling Less op registers none either. The
# "begin_norm_axis"/"begin_params_axis" attrs were copy-pasted from LayerNorm
# by mistake and are removed here.
less_equal_op_info = TBERegOp("LessEqual") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("less_equal.so") \
    .compute_cost(10) \
    .kernel_name("less_equal") \
    .partial_flag(True) \
    .input(0, "x1", False, "required", "all") \
    .input(1, "x2", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.BOOL_Default) \
    .dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.BOOL_5HD) \
    .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.BOOL_Default) \
    .dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.BOOL_5HD) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \
    .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.BOOL_5HD) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.BOOL_5HD) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.BOOL_5HD) \
    .get_op_info()

@op_info_register("""{
"op_name": "LessEqual",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "less_equal.so",
"compute_cost": 10,
"kernel_name": "less_equal",
"partial_flag": true,
"attr": [

],
"inputs": [
{
"index": 0,
"dtype": [
"float16","float16","float","float","int32","int32","int8","int8","uint8","uint8"
],
"format": [
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat",
"NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
],
"name": "x1",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16","float16","float","float","int32","int32","int8","int8","uint8","uint8"
],
"format": [
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat",
"NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
],
"name": "x2",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"bool","bool","bool","bool","bool","bool","bool","bool","bool","bool"
],
"format": [
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat",
"NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
],
"name": "y",
"param_type": "required",
"shape": "all"
}
]
}""")
@op_info_register(less_equal_op_info)
def _less_equal_tbe():
"""LessEqual TBE register"""
return

+ 16
- 43
mindspore/ops/_op_impl/tbe/log.py View File

@@ -14,52 +14,25 @@
# ============================================================================

"""Log op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

log_op_info = TBERegOp("Log") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("log.so") \
.compute_cost(10) \
.kernel_name("log") \
.partial_flag(True) \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()

@op_info_register("""{
"op_name": "Log",
"imply_type": "TBE",
"fusion_type": "ELEMWISE",
"async_flag": false,
"binfile_name": "log.so",
"compute_cost": 10,
"kernel_name": "log",
"partial_flag": true,
"attr": [

],
"inputs": [
{
"index": 0,
"dtype": [
"float16", "float16", "float", "float"
],
"format": [
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16", "float16", "float", "float"
],
"format": [
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")
@op_info_register(log_op_info)
def _log_tbe():
"""Log TBE register"""
return

+ 17
- 56
mindspore/ops/_op_impl/tbe/logical_and.py View File

@@ -14,65 +14,26 @@
# ============================================================================
"""LogicalAnd op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
logical_and_op_info = TBERegOp("LogicalAnd") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("logical_and.so") \
.compute_cost(10) \
.kernel_name("logical_and") \
.partial_flag(True) \
.input(0, "x1", False, "required", "all") \
.input(1, "x2", False, "required", "all") \
.output(0, "y", True, "required", "all") \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
.dtype_format(DataType.BOOL_FracZ, DataType.BOOL_FracZ, DataType.BOOL_FracZ) \
.dtype_format(DataType.BOOL_C1HWNCoC0, DataType.BOOL_C1HWNCoC0, DataType.BOOL_C1HWNCoC0) \
.dtype_format(DataType.BOOL_5HD, DataType.BOOL_5HD, DataType.BOOL_5HD) \
.get_op_info()
@op_info_register("""{
"op_name": "LogicalAnd",
"imply_type": "TBE",
"fusion_type": "ELEMWISE",
"async_flag": false,
"binfile_name": "logical_and.so",
"compute_cost": 10,
"kernel_name": "logical_and",
"partial_flag": true,
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"bool", "bool", "bool", "bool"
],
"format": [
"DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0"
],
"name": "x1",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"bool", "bool", "bool", "bool"
],
"format": [
"DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0"
],
"name": "x2",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"bool", "bool", "bool", "bool"
],
"format": [
"DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0"
],
"name": "y",
"need_compile": true,
"param_type": "required",
"shape": "all"
}
]
}""")
@op_info_register(logical_and_op_info)
def _logical_and_tbe():
"""LogicalAnd TBE register"""
return

+ 16
- 43
mindspore/ops/_op_impl/tbe/logical_not.py View File

@@ -14,52 +14,25 @@
# ============================================================================
"""LogicalNot op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
logical_not_op_info = TBERegOp("LogicalNot") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("logical_not.so") \
.compute_cost(10) \
.kernel_name("logical_not") \
.partial_flag(True) \
.input(0, "x", False, "required", "all") \
.output(0, "y", True, "required", "all") \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.dtype_format(DataType.BOOL_FracZ, DataType.BOOL_FracZ) \
.dtype_format(DataType.BOOL_C1HWNCoC0, DataType.BOOL_C1HWNCoC0) \
.dtype_format(DataType.BOOL_5HD, DataType.BOOL_5HD) \
.get_op_info()
@op_info_register("""{
"op_name": "LogicalNot",
"imply_type": "TBE",
"fusion_type": "ELEMWISE",
"async_flag": false,
"binfile_name": "logical_not.so",
"compute_cost": 10,
"kernel_name": "logical_not",
"partial_flag": true,
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"bool", "bool", "bool", "bool"
],
"format": [
"DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"bool", "bool", "bool", "bool"
],
"format": [
"DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0"
],
"name": "y",
"need_compile": true,
"param_type": "required",
"shape": "all"
}
]
}""")
@op_info_register(logical_not_op_info)
def _logical_not_tbe():
"""LogicalNot TBE register"""
return

+ 17
- 56
mindspore/ops/_op_impl/tbe/logical_or.py View File

@@ -14,65 +14,26 @@
# ============================================================================
"""LogicalOr op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
logical_or_op_info = TBERegOp("LogicalOr") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("logical_or.so") \
.compute_cost(10) \
.kernel_name("logical_or") \
.partial_flag(True) \
.input(0, "x1", False, "required", "all") \
.input(1, "x2", False, "required", "all") \
.output(0, "y", True, "required", "all") \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
.dtype_format(DataType.BOOL_FracZ, DataType.BOOL_FracZ, DataType.BOOL_FracZ) \
.dtype_format(DataType.BOOL_C1HWNCoC0, DataType.BOOL_C1HWNCoC0, DataType.BOOL_C1HWNCoC0) \
.dtype_format(DataType.BOOL_5HD, DataType.BOOL_5HD, DataType.BOOL_5HD) \
.get_op_info()
@op_info_register("""{
"op_name": "LogicalOr",
"imply_type": "TBE",
"fusion_type": "ELEMWISE",
"async_flag": false,
"binfile_name": "logical_or.so",
"compute_cost": 10,
"kernel_name": "logical_or",
"partial_flag": true,
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"bool", "bool", "bool", "bool"
],
"format": [
"DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0"
],
"name": "x1",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"bool", "bool", "bool", "bool"
],
"format": [
"DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0"
],
"name": "x2",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"bool", "bool", "bool", "bool"
],
"format": [
"DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0"
],
"name": "y",
"need_compile": true,
"param_type": "required",
"shape": "all"
}
]
}""")
@op_info_register(logical_or_op_info)
def _logical_or_tbe():
"""LogicalOr TBE register"""
return

+ 16
- 49
mindspore/ops/_op_impl/tbe/logsoftmax.py View File

@@ -14,57 +14,24 @@
# ============================================================================

"""LogSoftmax op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

log_softmax_op_info = TBERegOp("LogSoftmax") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("log_softmax.so") \
.compute_cost(10) \
.kernel_name("log_softmax") \
.partial_flag(True) \
.attr("axis", "optional", "listInt", "all") \
.input(0, "logits", False, "required", "all") \
.output(0, "logsoftmax", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.get_op_info()

@op_info_register("""{
"op_name": "LogSoftmax",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "log_softmax.so",
"compute_cost": 10,
"kernel_name": "log_softmax",
"partial_flag": true,
"attr": [
{
"name": "axis",
"param_type": "optional",
"type": "listInt",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "logits",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "logsoftmax",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")

@op_info_register(log_softmax_op_info)
def _logsoftmax_tbe():
    """LogSoftmax TBE register"""
    return

+ 17
- 62
mindspore/ops/_op_impl/tbe/logsoftmax_grad.py View File

@@ -14,70 +14,25 @@
# ============================================================================

"""LogSoftmaxGrad op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

log_softmax_grad_op_info = TBERegOp("LogSoftmaxGrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("log_softmax_grad.so") \
.compute_cost(10) \
.kernel_name("log_softmax_grad") \
.partial_flag(True) \
.attr("axis", "optional", "listInt", "all") \
.input(0, "x", False, "required", "all") \
.input(1, "grad", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()

@op_info_register("""{
"op_name": "LogSoftmaxGrad",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "log_softmax_grad.so",
"compute_cost": 10,
"kernel_name": "log_softmax_grad",
"partial_flag": true,
"attr": [
{
"name": "axis",
"param_type": "optional",
"type": "listInt",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "grad",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")

@op_info_register(log_softmax_grad_op_info)
def _logsoftmax_grad_tbe():
    """LogSoftmaxGrad TBE register"""
    return

+ 21
- 81
mindspore/ops/_op_impl/tbe/matmul.py View File

@@ -14,89 +14,29 @@
# ============================================================================

"""MatMul op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

matmul_op_info = TBERegOp("MatMul") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("matmul.so") \
.compute_cost(10) \
.kernel_name("matmul") \
.partial_flag(True) \
.attr("transpose_a", "required", "bool", "all") \
.attr("transpose_b", "required", "bool", "all") \
.input(0, "x1", False, "required", "all") \
.input(1, "x2", False, "required", "all") \
.input(2, "x3", False, "optional", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.F16_FracNZ) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F32_Default, DataType.F32_FracNZ) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()

@op_info_register("""{
"op_name": "MatMul",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "matmul.so",
"compute_cost": 10,
"kernel_name": "matmul",
"partial_flag": true,
"attr": [
{
"name": "transpose_a",
"param_type": "required",
"type": "bool",
"value": "all"
},
{
"name": "transpose_b",
"param_type": "required",
"type": "bool",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16","float16","float","int32"
],
"format": [
"FRACTAL_NZ","FRACTAL_NZ","DefaultFormat","DefaultFormat"
],
"name": "x1",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16","float16","float","int32"
],
"format": [
"FRACTAL_NZ","FRACTAL_NZ","DefaultFormat","DefaultFormat"
],
"name": "x2",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float16","float","float","int32"
],
"format": [
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
],
"name": "x3",
"need_compile": false,
"param_type": "optional",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16","float","float","int32"
],
"format": [
"FRACTAL_NZ","FRACTAL_NZ","DefaultFormat","DefaultFormat"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")

@op_info_register(matmul_op_info)
def _matmul_tbe():
    """MatMul TBE register"""
    return

+ 18
- 66
mindspore/ops/_op_impl/tbe/max_pool.py View File

@@ -14,74 +14,26 @@
# ============================================================================

"""MaxPool op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

max_pool_op_info = TBERegOp("MaxPool") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("max_pool.so") \
.compute_cost(10) \
.kernel_name("max_pool") \
.partial_flag(True) \
.attr("ksize", "required", "listInt", "all") \
.attr("strides", "required", "listInt", "all") \
.attr("padding", "required", "str", "all") \
.attr("data_format", "required", "str", "all") \
.input(0, "input_data", False, "required", "all") \
.output(0, "output_data", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.get_op_info()

@op_info_register("""{
"op_name": "MaxPool",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "max_pool.so",
"compute_cost": 10,
"kernel_name": "max_pool",
"partial_flag": true,
"attr": [
{
"name": "ksize",
"param_type": "required",
"type": "listInt",
"value": "all"
},
{
"name": "strides",
"param_type": "required",
"type": "listInt",
"value": "all"
},
{
"name": "padding",
"param_type": "required",
"type": "str",
"value": "all"
},
{
"name": "data_format",
"param_type": "required",
"type": "str",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"NC1HWC0"
],
"name": "input_data",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"NC1HWC0"
],
"name": "output_data",
"param_type": "required",
"shape": "all"
}
]
}""")

@op_info_register(max_pool_op_info)
def _max_pool_tbe():
"""MaxPool TBE register"""
return

+ 19
- 85
mindspore/ops/_op_impl/tbe/max_pool_grad.py View File

@@ -14,93 +14,27 @@
# ============================================================================

"""MaxPoolGrad op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

max_pool_grad_op_info = TBERegOp("MaxPoolGrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("max_pool_grad.so") \
.compute_cost(10) \
.kernel_name("max_pool_grad") \
.partial_flag(True) \
.attr("ksize", "required", "listInt", "all") \
.attr("strides", "required", "listInt", "all") \
.attr("padding", "required", "str", "all") \
.input(0, "x1", False, "required", "all") \
.input(1, "x2", False, "required", "all") \
.input(2, "grad", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.get_op_info()

@op_info_register("""{
"op_name": "MaxPoolGrad",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "max_pool_grad.so",
"compute_cost": 10,
"kernel_name": "max_pool_grad",
"partial_flag": true,
"attr": [
{
"name": "ksize",
"param_type": "required",
"type": "listInt",
"value": "all"
},
{
"name": "strides",
"param_type": "required",
"type": "listInt",
"value": "all"
},
{
"name": "padding",
"param_type": "required",
"type": "str",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"NC1HWC0"
],
"name": "x1",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16"
],
"format": [
"NC1HWC0"
],
"name": "x2",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float16"
],
"format": [
"NC1HWC0"
],
"name": "grad",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"NC1HWC0"
],
"name": "y",
"param_type": "required"
}
]
}""")

@op_info_register(max_pool_grad_op_info)
def _max_pool_grad_tbe():
"""MaxPoolGrad TBE register"""
return

+ 20
- 87
mindspore/ops/_op_impl/tbe/max_pool_grad_with_argmax.py View File

@@ -14,95 +14,28 @@
# ============================================================================

"""MaxPoolGradWithArgmax op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

max_pool_grad_with_argmax_op_info = TBERegOp("MaxPoolGradWithArgmax") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("max_pool_grad_with_argmax.so") \
.compute_cost(10) \
.kernel_name("max_pool_grad_with_argmax") \
.partial_flag(True) \
.attr("ksize", "required", "listInt", "all") \
.attr("strides", "required", "listInt", "all") \
.attr("padding", "required", "str", "all") \
.input(0, "x", False, "required", "all") \
.input(1, "grad", False, "required", "all") \
.input(2, "argmax", False, "optional", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.U16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.I64_5HD, DataType.F16_5HD) \
.get_op_info()

@op_info_register("""{
"op_name": "MaxPoolGradWithArgmax",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "max_pool_grad_with_argmax.so",
"compute_cost": 10,
"kernel_name": "max_pool_grad_with_argmax",
"partial_flag": true,
"attr": [
{
"name": "ksize",
"param_type": "required",
"type": "listInt",
"value": "all"
},
{
"name": "strides",
"param_type": "required",
"type": "listInt",
"value": "all"
},
{
"name": "padding",
"param_type": "required",
"type": "str",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16", "float16"
],
"format": [
"NC1HWC0", "NC1HWC0"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16", "float16"
],
"format": [
"NC1HWC0", "NC1HWC0"
],
"name": "grad",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"uint16", "int64"
],
"format": [
"NC1HWC0", "NC1HWC0"
],
"name": "argmax",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16", "float16"
],
"format": [
"NC1HWC0", "NC1HWC0"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")

@op_info_register(max_pool_grad_with_argmax_op_info)
def _max_pool_grad_with_argmax_tbe():
"""MaxPoolGradWithArgmax TBE register"""
return

+ 18
- 74
mindspore/ops/_op_impl/tbe/max_pool_with_argmax.py View File

@@ -14,82 +14,26 @@
# ============================================================================

"""MaxPoolWithArgmax op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

max_pool_with_argmax_op_info = TBERegOp("MaxPoolWithArgmax") \
.fusion_type("CONVLUTION") \
.async_flag(False) \
.binfile_name("max_pool_with_argmax.so") \
.compute_cost(10) \
.kernel_name("max_pool_with_argmax") \
.partial_flag(True) \
.attr("ksize", "required", "listInt", "all") \
.attr("strides", "required", "listInt", "all") \
.attr("padding", "required", "str", "all") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.output(1, "argmax", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.U16_5HD) \
.get_op_info()

@op_info_register("""{
"op_name": "MaxPoolWithArgmax",
"imply_type": "TBE",
"fusion_type": "CONVLUTION",
"async_flag": false,
"binfile_name": "max_pool_with_argmax.so",
"compute_cost": 10,
"kernel_name": "max_pool_with_argmax",
"partial_flag": true,
"attr": [
{
"name": "ksize",
"param_type": "required",
"type": "listInt",
"value": "all"
},
{
"name": "strides",
"param_type": "required",
"type": "listInt",
"value": "all"
},
{
"name": "padding",
"param_type": "required",
"type": "str",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"NC1HWC0"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"NC1HWC0"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"uint16"
],
"format": [
"NC1HWC0"
],
"name": "argmax",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")

@op_info_register(max_pool_with_argmax_op_info)
def _max_pool_with_argmax_tbe():
"""MaxPoolWithArgmax TBE register"""
return

+ 20
- 61
mindspore/ops/_op_impl/tbe/maximum.py View File

@@ -14,69 +14,28 @@
# ============================================================================

"""Maximum op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

maximum_op_info = TBERegOp("Maximum") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("maximum.so") \
.compute_cost(10) \
.kernel_name("maximum") \
.partial_flag(True) \
.input(0, "x1", False, "required", "all") \
.input(1, "x2", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()

@op_info_register("""{
"op_name":"Maximum",
"imply_type":"TBE",
"fusion_type":"ELEMWISE",
"async_flag":false,
"binfile_name":"maximum.so",
"compute_cost":10,
"kernel_name":"maximum",
"partial_flag":true,
"attr":[],
"inputs":[
{
"index":0,
"dtype":[
"float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32",
"int32", "int32", "int32"
],
"format":[
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
],
"name":"x1",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":1,
"dtype":[
"float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32",
"int32", "int32", "int32"
],
"format":[
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
],
"name":"x2",
"need_compile":false,
"param_type":"required",
"shape":"all"
}
],
"outputs":[
{
"index":0,
"dtype":[
"float16", "float16", "float16", "float16", "float", "float", "float", "float", "int32",
"int32", "int32", "int32"
],
"format":[
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
],
"name":"y",
"need_compile":false,
"param_type":"required",
"shape":"all"
}
]
}""")

@op_info_register(maximum_op_info)
def _maximum_tbe():
"""Maximum TBE register"""
return

+ 30
- 104
mindspore/ops/_op_impl/tbe/maximum_grad.py View File

@@ -14,112 +14,38 @@
# ============================================================================

"""MaximumGrad op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

maximum_grad_op_info = TBERegOp("MaximumGrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("maximum_grad.so") \
.compute_cost(10) \
.kernel_name("maximum_grad") \
.partial_flag(True) \
.attr("grad_x", "optional", "bool", "all") \
.attr("grad_y", "optional", "bool", "all") \
.input(0, "grads", False, "required", "all") \
.input(1, "x1", False, "required", "all") \
.input(2, "x2", False, "required", "all") \
.output(0, "y1", False, "required", "all") \
.output(1, "y2", False, "required", "all") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
DataType.I32_Default) \
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD,
DataType.I32_5HD) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD) \
.get_op_info()

@op_info_register("""{
"op_name":"MaximumGrad",
"imply_type":"TBE",
"fusion_type":"OPAQUE",
"async_flag":false,
"binfile_name":"maximum_grad.so",
"compute_cost":10,
"kernel_name":"maximum_grad",
"partial_flag":true,
"attr":[
{
"name":"grad_x",
"param_type":"optional",
"type":"bool",
"value":"all"
},
{
"name":"grad_y",
"param_type":"optional",
"type":"bool",
"value":"all"
}
],
"inputs":[
{
"index":0,
"dtype":[
"float16", "float16", "float16", "float16", "float", "float", "float", "float",
"int32", "int32", "int32", "int32"
],
"format":[
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
],
"name":"grads",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":1,
"dtype":[
"float16", "float16", "float16", "float16", "float", "float", "float", "float",
"int32", "int32", "int32", "int32"
],
"format":[
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
],
"name":"x1",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":2,
"dtype":[
"float16", "float16", "float16", "float16", "float", "float", "float", "float",
"int32", "int32", "int32", "int32"
],
"format":[
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
],
"name":"x2",
"need_compile":false,
"param_type":"required",
"shape":"all"
}
],
"outputs":[
{
"index":0,
"dtype":[
"float16", "float16", "float16", "float16", "float", "float", "float", "float",
"int32", "int32", "int32", "int32"
],
"format":[
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
],
"name":"y1",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":1,
"dtype":[
"float16", "float16", "float16", "float16", "float", "float", "float", "float",
"int32", "int32", "int32", "int32"
],
"format":[
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
],
"name":"y2",
"need_compile":false,
"param_type":"required",
"shape":"all"
}
]
}""")

@op_info_register(maximum_grad_op_info)
def _maximum_grad_tbe():
"""MaximumGrad TBE register"""
return

+ 19
- 65
mindspore/ops/_op_impl/tbe/minimum.py View File

@@ -15,74 +15,28 @@


"""Minimum op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

minimum_op_info = TBERegOp("Minimum") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("minimum.so") \
.compute_cost(10) \
.kernel_name("minimum") \
.partial_flag(True) \
.input(0, "x1", False, "required", "all") \
.input(1, "x2", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()

@op_info_register("""{
"op_name": "Minimum",
"imply_type": "TBE",
"fusion_type": "ELEMWISE",
"async_flag": false,
"binfile_name": "minimum.so",
"compute_cost": 10,
"kernel_name": "minimum",
"partial_flag": true,
"attr": [

],
"inputs": [
{
"index": 0,
"dtype": [
"float16", "float16", "float16", "float16", "float", "float", "float", "float",
"int32", "int32", "int32", "int32"
],
"format": [
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat",
"NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "DefaultFormat"
],
"name": "x1",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16", "float16", "float16", "float16", "float", "float", "float", "float",
"int32", "int32", "int32", "int32"
],
"format": [
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat",
"NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "DefaultFormat"
],
"name": "x2",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16", "float16", "float16", "float16", "float", "float", "float", "float",
"int32", "int32", "int32", "int32"
],
"format": [
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "DefaultFormat"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")
@op_info_register(minimum_op_info)
def _minimum_tbe():
"""Minimum TBE register"""
return

+ 30
- 104
mindspore/ops/_op_impl/tbe/minimum_grad.py View File

@@ -14,112 +14,38 @@
# ============================================================================

"""MinimumGrad op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

minimum_grad_op_info = TBERegOp("MinimumGrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("minimum_grad.so") \
.compute_cost(10) \
.kernel_name("minimum_grad") \
.partial_flag(True) \
.attr("grad_x", "optional", "bool", "all") \
.attr("grad_y", "optional", "bool", "all") \
.input(0, "grads", False, "required", "all") \
.input(1, "x1", False, "required", "all") \
.input(2, "x2", False, "required", "all") \
.output(0, "y1", False, "required", "all") \
.output(1, "y2", False, "required", "all") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
DataType.I32_Default) \
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD,
DataType.I32_5HD) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD) \
.get_op_info()

@op_info_register("""{
"op_name":"MinimumGrad",
"imply_type":"TBE",
"fusion_type":"OPAQUE",
"async_flag":false,
"binfile_name":"minimum_grad.so",
"compute_cost":10,
"kernel_name":"minimum_grad",
"partial_flag":true,
"attr":[
{
"name":"grad_x",
"param_type":"optional",
"type":"bool",
"value":"all"
},
{
"name":"grad_y",
"param_type":"optional",
"type":"bool",
"value":"all"
}
],
"inputs":[
{
"index":0,
"dtype":[
"float16", "float16", "float16", "float16", "float", "float", "float", "float",
"int32", "int32", "int32", "int32"
],
"format":[
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
],
"name":"grads",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":1,
"dtype":[
"float16", "float16", "float16", "float16", "float", "float", "float", "float",
"int32", "int32", "int32", "int32"
],
"format":[
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
],
"name":"x1",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":2,
"dtype":[
"float16", "float16", "float16", "float16", "float", "float", "float", "float",
"int32", "int32", "int32", "int32"
],
"format":[
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
],
"name":"x2",
"need_compile":false,
"param_type":"required",
"shape":"all"
}
],
"outputs":[
{
"index":0,
"dtype":[
"float16", "float16", "float16", "float16", "float", "float", "float", "float",
"int32", "int32", "int32", "int32"
],
"format":[
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
],
"name":"y1",
"need_compile":false,
"param_type":"required",
"shape":"all"
},
{
"index":1,
"dtype":[
"float16", "float16", "float16", "float16", "float", "float", "float", "float",
"int32", "int32", "int32", "int32"
],
"format":[
"DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "DefaultFormat", "DefaultFormat"
],
"name":"y2",
"need_compile":false,
"param_type":"required",
"shape":"all"
}
]
}""")

@op_info_register(minimum_grad_op_info)
def _minimum_grad_tbe():
"""MinimumGrad TBE register"""
return

+ 28
- 68
mindspore/ops/_op_impl/tbe/mul.py View File

@@ -14,77 +14,37 @@
# ============================================================================

"""Mul op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

mul_op_info = TBERegOp("Mul") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("mul.so") \
.compute_cost(10) \
.kernel_name("mul") \
.partial_flag(True) \
.input(0, "x", False, "required", "all") \
.input(1, "y", False, "required", "all") \
.output(0, "output", False, "required", "all") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
.dtype_format(DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ) \
.dtype_format(DataType.I32_FracNZ, DataType.I32_FracNZ, DataType.I32_FracNZ) \
.dtype_format(DataType.I32_C1HWNCoC0, DataType.I32_C1HWNCoC0, DataType.I32_C1HWNCoC0) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
.get_op_info()

@op_info_register("""{
"op_name": "Mul",
"imply_type": "TBE",
"fusion_type": "ELEMWISE",
"async_flag": false,
"binfile_name": "mul.so",
"compute_cost": 10,
"kernel_name": "mul",
"partial_flag": true,
"attr": [

],
"inputs": [
{
"index": 0,
"dtype": [
"int32", "int32", "int32", "int32", "int32",
"float16", "float16", "float16", "float16", "float16",
"float", "float", "float", "float", "float"
],
"format": [
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0",
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0",
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"int32", "int32", "int32", "int32", "int32",
"float16", "float16", "float16", "float16", "float16",
"float", "float", "float", "float", "float"
],
"format": [
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0",
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0",
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"int32", "int32", "int32", "int32", "int32",
"float16", "float16", "float16", "float16", "float16",
"float", "float", "float", "float","float"
],
"format": [
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0",
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0",
"FRACTAL_NZ", "DefaultFormat", "FracZ", "C1HWNCoC0", "NC1HWC0"
],
"name": "output",
"need_compile": true,
"param_type": "required",
"shape": "all"
}
]
}""")
@op_info_register(mul_op_info)
def _mul_tbe():
"""Mul TBE register"""
return

+ 20
- 42
mindspore/ops/_op_impl/tbe/neg.py View File

@@ -14,51 +14,29 @@
# ============================================================================

"""Neg op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

neg_op_info = TBERegOp("Neg") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("neg.so") \
.compute_cost(10) \
.kernel_name("neg") \
.partial_flag(True) \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I8_5HD, DataType.I8_5HD) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_5HD, DataType.I32_5HD) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()

@op_info_register("""{
"op_name": "Neg",
"imply_type": "TBE",
"fusion_type": "ELEMWISE",
"async_flag": false,
"binfile_name": "neg.so",
"compute_cost": 10,
"kernel_name": "neg",
"partial_flag": true,
"attr": [

],
"inputs": [
{
"index": 0,
"dtype": [
"float","float","float16","float16","int32","int32","int8","int8"
],
"format": [
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float","float","float16","float16","int32","int32","int8","int8"
],
"format": [
"DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0","DefaultFormat","NC1HWC0"
],
"name": "y",
"param_type": "required",
"shape": "all"
}
]
}""")
@op_info_register(neg_op_info)
def _neg_tbe():
"""Neg TBE register"""
return

+ 12
- 30
mindspore/ops/_op_impl/tbe/npu_alloc_float_status.py View File

@@ -14,39 +14,21 @@
# ============================================================================

"""NPUAllocFloatStatus op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

npu_alloc_float_status_op_info = TBERegOp("NPUAllocFloatStatus") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("n_p_u_alloc_float_status.so") \
.compute_cost(10) \
.kernel_name("n_p_u_alloc_float_status") \
.partial_flag(True) \
.output(0, "data", False, "required", "all") \
.dtype_format(DataType.F32_Default) \
.get_op_info()

@op_info_register("""{
"op_name": "NPUAllocFloatStatus",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "n_p_u_alloc_float_status.so",
"compute_cost": 10,
"kernel_name": "n_p_u_alloc_float_status",
"partial_flag": true,
"attr": [

],
"inputs": [

],
"outputs": [
{
"index": 0,
"dtype": [
"float"
],
"format": [
"DefaultFormat"
],
"name": "data",
"param_type": "required",
"shape": "all"
}
]
}""")
@op_info_register(npu_alloc_float_status_op_info)
def _npu_alloc_float_status_tbe():
"""NPUAllocFloatStatus TBE register"""
return

+ 13
- 43
mindspore/ops/_op_impl/tbe/npu_clear_float_status.py View File

@@ -14,52 +14,22 @@
# ============================================================================

"""NPUClearFloatStatus op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

npu_clear_float_status_op_info = TBERegOp("NPUClearFloatStatus") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("n_p_u_clear_float_status.so") \
.compute_cost(10) \
.kernel_name("n_p_u_clear_float_status") \
.partial_flag(True) \
.input(0, "addr", False, "required", "all") \
.output(0, "data", False, "required", "all") \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.get_op_info()

@op_info_register("""{
"op_name": "NPUClearFloatStatus",
"imply_type": "TBE",
"fusion_type": "ELEMWISE",
"async_flag": false,
"binfile_name": "n_p_u_clear_float_status.so",
"compute_cost": 10,
"kernel_name": "n_p_u_clear_float_status",
"partial_flag": true,
"attr": [

],
"inputs": [
{
"index": 0,
"dtype": [
"float"
],
"format": [
"DefaultFormat"
],
"name": "addr",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float"
],
"format": [
"DefaultFormat"
],
"name": "data",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")
@op_info_register(npu_clear_float_status_op_info)
def _npu_clear_float_status_tbe():
"""NPUClearFloatStatus TBE register"""
return

+ 13
- 43
mindspore/ops/_op_impl/tbe/npu_get_float_status.py View File

@@ -14,52 +14,22 @@
# ============================================================================

"""NPUGetFloatStatus op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

npu_get_float_status_op_info = TBERegOp("NPUGetFloatStatus") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("n_p_u_get_float_status.so") \
.compute_cost(10) \
.kernel_name("n_p_u_get_float_status") \
.partial_flag(True) \
.input(0, "addr", False, "required", "all") \
.output(0, "data", False, "required", "all") \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.get_op_info()

@op_info_register("""{
"op_name": "NPUGetFloatStatus",
"imply_type": "TBE",
"fusion_type": "ELEMWISE",
"async_flag": false,
"binfile_name": "n_p_u_get_float_status.so",
"compute_cost": 10,
"kernel_name": "n_p_u_get_float_status",
"partial_flag": true,
"attr": [

],
"inputs": [
{
"index": 0,
"dtype": [
"float"
],
"format": [
"DefaultFormat"
],
"name": "addr",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float"
],
"format": [
"DefaultFormat"
],
"name": "data",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")
@op_info_register(npu_get_float_status_op_info)
def _npu_get_float_status_tbe():
"""NPUGetFloatStatus TBE register"""
return

+ 27
- 88
mindspore/ops/_op_impl/tbe/one_hot.py View File

@@ -14,96 +14,35 @@
# ============================================================================

"""OneHot op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

one_hot_op_info = TBERegOp("OneHot") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("one_hot.so") \
.compute_cost(10) \
.kernel_name("one_hot") \
.partial_flag(True) \
.attr("depth", "required", "int", "all") \
.attr("axis", "required", "int", "all") \
.input(0, "x", False, "required", "all") \
.input(1, "on_value", False, "required", "all") \
.input(2, "off_value", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.U8_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.U8_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.U8_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I32_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()

@op_info_register("""{
"op_name": "OneHot",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "one_hot.so",
"compute_cost": 10,
"kernel_name": "one_hot",
"partial_flag": true,
"attr": [
{
"name": "depth",
"param_type": "required",
"type": "int",
"value": "all"
},
{
"name": "axis",
"param_type": "required",
"type": "int",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"int32","int32","int32","int32","int32",
"uint8","uint8","uint8","uint8","uint8"
],
"format": [
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat",
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16","float32","int32","int8","uint8",
"float16","float32","int32","int8","uint8"
],
"format": [
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat",
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
],
"name": "on_value",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float16","float32","int32","int8","uint8",
"float16","float32","int32","int8","uint8"
],
"format": [
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat",
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
],
"name": "off_value",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16","float32","int32","int8","uint8",
"float16","float32","int32","int8","uint8"
],
"format": [
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat",
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
],
"name": "y",
"param_type": "required",
"shape": "all"
}
]
}""")

@op_info_register(one_hot_op_info)
def _one_hot_tbe():
    """Register the OneHot TBE operator implementation."""
    return None

+ 19
- 49
mindspore/ops/_op_impl/tbe/pad_d.py View File

@@ -14,57 +14,27 @@
# ============================================================================

"""Pad op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE registration metadata for Pad (kernel binary pad_d.so):
# optional listListInt attr "paddings"; unary x -> y over
# int8/uint8/int32/float16/float32 in DefaultFormat.
pad_d_op_info = TBERegOp("Pad") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("pad_d.so") \
    .compute_cost(10) \
    .kernel_name("pad_d") \
    .partial_flag(True) \
    .attr("paddings", "optional", "listListInt", "all") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()

@op_info_register("""{
"op_name": "Pad",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "pad_d.so",
"compute_cost": 10,
"kernel_name": "pad_d",
"partial_flag": true,
"attr": [
{
"name": "paddings",
"param_type": "optional",
"type": "listListInt",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16","float","int8","uint8","int32"
],
"format": [
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16","float","int8","uint8","int32"
],
"format": [
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")

@op_info_register(pad_d_op_info)
def _pad_d_tbe():
    """Register the Pad TBE operator implementation."""
    return None

+ 18
- 56
mindspore/ops/_op_impl/tbe/pow.py View File

@@ -14,65 +14,27 @@
# ============================================================================

"""Pow op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE registration metadata for Pow (kernel binary pow.so): elementwise
# binary op x1, x2 -> y over int8/uint8/int32/float16/float32, DefaultFormat.
pow_op_info = TBERegOp("Pow") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("pow.so") \
    .compute_cost(10) \
    .kernel_name("pow") \
    .partial_flag(True) \
    .input(0, "x1", False, "required", "all") \
    .input(1, "x2", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()

@op_info_register("""{
"op_name": "Pow",
"imply_type": "TBE",
"fusion_type": "ELEMWISE",
"async_flag": false,
"binfile_name": "pow.so",
"compute_cost": 10,
"kernel_name": "pow",
"partial_flag": true,
"attr": [

],
"inputs": [
{
"index": 0,
"dtype": [
"float16", "float", "int32", "int8", "uint8"
],
"format": [
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat"
],
"name": "x1",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16", "float", "int32", "int8", "uint8"
],
"format": [
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat"
],
"name": "x2",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16", "float", "int32", "int8", "uint8"
],
"format": [
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")
@op_info_register(pow_op_info)
def _pow_tbe():
    """Register the Pow TBE operator implementation."""
    return None

+ 17
- 55
mindspore/ops/_op_impl/tbe/real_div.py View File

@@ -14,64 +14,26 @@
# ============================================================================

"""RealDiv op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE registration metadata for RealDiv (kernel binary realdiv.so):
# binary op x, y -> z over float16/float32 in DefaultFormat and NC1HWC0 (5HD).
realdiv_op_info = TBERegOp("RealDiv") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("realdiv.so") \
    .compute_cost(10) \
    .kernel_name("realdiv") \
    .partial_flag(True) \
    .input(0, "x", False, "required", "all") \
    .input(1, "y", False, "required", "all") \
    .output(0, "z", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
    .get_op_info()

@op_info_register("""{
"op_name": "RealDiv",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "realdiv.so",
"compute_cost": 10,
"kernel_name": "realdiv",
"partial_flag": true,
"attr": [

],
"inputs": [
{
"index": 0,
"dtype": [
"float16", "float16", "float", "float"
],
"format": [
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16", "float16", "float", "float"
],
"format": [
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16", "float16", "float", "float"
],
"format": [
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0"
],
"name": "z",
"param_type": "required",
"shape": "all"
}
]
}""")
@op_info_register(realdiv_op_info)
def _real_div_tbe():
    """Register the RealDiv TBE operator implementation."""
    return None

+ 18
- 43
mindspore/ops/_op_impl/tbe/reciprocal.py View File

@@ -14,52 +14,27 @@
# ============================================================================

"""Add op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE registration metadata for Reciprocal (kernel binary reciprocal.so):
# unary x -> y over float16/float32 in DefaultFormat, NC1HWC0 and NHWC.
reciprocal_op_info = TBERegOp("Reciprocal") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("reciprocal.so") \
    .compute_cost(10) \
    .kernel_name("reciprocal") \
    .partial_flag(True) \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F16_NHWC, DataType.F16_NHWC) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
    .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC) \
    .get_op_info()

@op_info_register("""{
"op_name": "Reciprocal",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "reciprocal.so",
"compute_cost": 10,
"kernel_name": "reciprocal",
"partial_flag": true,
"attr": [

],
"inputs": [
{
"index": 0,
"dtype": [
"float16", "float16", "float16", "float32", "float32", "float32"
],
"format": [
"DefaultFormat", "NC1HWC0", "NHWC", "DefaultFormat", "NC1HWC0", "NHWC"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16", "float16", "float16", "float32", "float32", "float32"
],
"format": [
"DefaultFormat", "NC1HWC0", "NHWC", "DefaultFormat", "NC1HWC0", "NHWC"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")
@op_info_register(reciprocal_op_info)
def _reciprocal_tbe():
    """Reciprocal TBE register"""
    # Docstring fixed: the original said "Add TBE register", a copy-paste
    # leftover from add.py; this function registers Reciprocal.
    return

+ 21
- 55
mindspore/ops/_op_impl/tbe/reduce_max.py View File

@@ -14,63 +14,29 @@
# ============================================================================

"""ReduceMax op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE registration metadata for ReduceMax (kernel binary reduce_max_d.so):
# attrs axis (listInt) and keep_dims (bool); unary x -> y over
# bool/int8/uint8/int32/float16/float32 in DefaultFormat.
# NOTE(review): both attrs are declared "optional" here, while the legacy
# JSON registration for this op marked them "required" — confirm intended.
reduce_max_d_op_info = TBERegOp("ReduceMax") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("reduce_max_d.so") \
    .compute_cost(10) \
    .kernel_name("reduce_max_d") \
    .partial_flag(True) \
    .attr("axis", "optional", "listInt", "all") \
    .attr("keep_dims", "optional", "bool", "all") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()

@op_info_register("""{
"op_name": "ReduceMax",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "reduce_max_d.so",
"compute_cost": 10,
"kernel_name": "reduce_max_d",
"partial_flag": true,
"attr": [
{
"name": "axis",
"param_type": "required",
"type": "listInt",
"value": "all"
},
{
"name": "keep_dims",
"param_type": "required",
"type": "bool",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16", "float", "int8", "uint8", "bool", "int32"
],
"format": [
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16", "float", "int8", "uint8", "bool", "int32"
],
"format": [
"DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")

@op_info_register(reduce_max_d_op_info)
def _reduce_max_tbe():
    """Register the ReduceMax TBE operator implementation."""
    return None

+ 19
- 55
mindspore/ops/_op_impl/tbe/reduce_mean.py View File

@@ -14,63 +14,27 @@
# ============================================================================

"""ReduceMean op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE registration metadata for ReduceMean (kernel binary reduce_mean.so):
# optional attrs axis (listInt) and keep_dims (bool); unary x -> y over
# int8/uint8/float16/float32 in DefaultFormat.
# NOTE(review): the legacy JSON registration also listed a float16 NC1HWC0
# combination that is absent here — confirm the drop is intended.
reduce_mean_op_info = TBERegOp("ReduceMean") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("reduce_mean.so") \
    .compute_cost(10) \
    .kernel_name("reduce_mean") \
    .partial_flag(True) \
    .attr("axis", "optional", "listInt", "all") \
    .attr("keep_dims", "optional", "bool", "all") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()

@op_info_register("""{
"op_name": "ReduceMean",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "reduce_mean.so",
"compute_cost": 10,
"kernel_name": "reduce_mean",
"partial_flag": true,
"attr": [
{
"name": "axis",
"param_type": "optional",
"type": "listInt",
"value": "all"
},
{
"name": "keep_dims",
"param_type": "optional",
"type": "bool",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16","float","float16","int8","uint8"
],
"format": [
"NC1HWC0","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16","float","float16","int8","uint8"
],
"format": [
"NC1HWC0","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")

@op_info_register(reduce_mean_op_info)
def _reduce_mean_tbe():
    """Register the ReduceMean TBE operator implementation."""
    return None

+ 19
- 55
mindspore/ops/_op_impl/tbe/reduce_mean_d.py View File

@@ -14,63 +14,27 @@
# ============================================================================

"""ReduceMeanD op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE registration metadata for ReduceMeanD (kernel binary reduce_mean_d.so):
# optional attrs axis (listInt) and keep_dims (bool); unary x -> y over
# int8/uint8/float16/float32 in DefaultFormat.
reduce_mean_d_op_info = TBERegOp("ReduceMeanD") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("reduce_mean_d.so") \
    .compute_cost(10) \
    .kernel_name("reduce_mean_d") \
    .partial_flag(True) \
    .attr("axis", "optional", "listInt", "all") \
    .attr("keep_dims", "optional", "bool", "all") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()

@op_info_register("""{
"op_name": "ReduceMeanD",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "reduce_mean_d.so",
"compute_cost": 10,
"kernel_name": "reduce_mean_d",
"partial_flag": true,
"attr": [
{
"name": "axis",
"param_type": "optional",
"type": "listInt",
"value": "all"
},
{
"name": "keep_dims",
"param_type": "optional",
"type": "bool",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float","float16","int8","uint8"
],
"format": [
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float","float16","int8","uint8"
],
"format": [
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")

@op_info_register(reduce_mean_d_op_info)
def _reduce_mean_d_tbe():
    """ReduceMeanD TBE register"""
    # Docstring fixed: the original said "Conv2D TBE register", a copy-paste
    # leftover from conv2d.py; this function registers ReduceMeanD.
    return

+ 23
- 55
mindspore/ops/_op_impl/tbe/reduce_min.py View File

@@ -14,63 +14,31 @@
# ============================================================================

"""ReduceMin op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE registration metadata for ReduceMin (kernel binary reduce_min_d.so):
# required attrs axis (listInt) and keep_dims (bool); unary x -> y over
# int8/uint8/float16/float32 in DefaultFormat and FracZ.
reduce_min_op_info = TBERegOp("ReduceMin") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("reduce_min_d.so") \
    .compute_cost(10) \
    .kernel_name("reduce_min_d") \
    .partial_flag(True) \
    .attr("axis", "required", "listInt", "all") \
    .attr("keep_dims", "required", "bool", "all") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.I8_FracZ, DataType.I8_FracZ) \
    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.U8_FracZ, DataType.U8_FracZ) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \
    .get_op_info()

@op_info_register("""{
"op_name": "ReduceMin",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "reduce_min_d.so",
"compute_cost": 10,
"kernel_name": "reduce_min_d",
"partial_flag": true,
"attr": [
{
"name": "axis",
"param_type": "required",
"type": "listInt",
"value": "all"
},
{
"name": "keep_dims",
"param_type": "required",
"type": "bool",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16", "float16", "float", "float", "int8", "int8", "uint8", "uint8"
],
"format": [
"DefaultFormat", "FracZ", "DefaultFormat", "FracZ", "DefaultFormat", "FracZ", "DefaultFormat", "FracZ"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16", "float16", "float", "float", "int8", "int8", "uint8", "uint8"
],
"format": [
"DefaultFormat", "FracZ", "DefaultFormat", "FracZ", "DefaultFormat", "FracZ", "DefaultFormat", "FracZ"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")

@op_info_register(reduce_min_op_info)
def _reduce_min_tbe():
    """Register the ReduceMin TBE operator implementation."""
    return None

+ 17
- 55
mindspore/ops/_op_impl/tbe/reduce_sum.py View File

@@ -14,63 +14,25 @@
# ============================================================================

"""ReduceSum op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE registration metadata for ReduceSum (kernel binary reduce_sum_d.so):
# optional attrs axis (listInt) and keep_dims (bool); unary x -> y over
# float16/float32 in DefaultFormat.
reduce_sum_op_info = TBERegOp("ReduceSum") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("reduce_sum_d.so") \
    .compute_cost(10) \
    .kernel_name("reduce_sum_d") \
    .partial_flag(True) \
    .attr("axis", "optional", "listInt", "all") \
    .attr("keep_dims", "optional", "bool", "all") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()

@op_info_register("""{
"op_name": "ReduceSum",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "reduce_sum_d.so",
"compute_cost": 10,
"kernel_name": "reduce_sum_d",
"partial_flag": true,
"attr": [
{
"name": "axis",
"param_type": "optional",
"type": "listInt",
"value": "all"
},
{
"name": "keep_dims",
"param_type": "optional",
"type": "bool",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16", "float"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")

@op_info_register(reduce_sum_op_info)
def _reduce_sum_tbe():
    """Register the ReduceSum TBE operator implementation."""
    return None

+ 20
- 45
mindspore/ops/_op_impl/tbe/relu.py View File

@@ -14,54 +14,29 @@
# ============================================================================

"""ReLU op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE registration metadata for ReLU (kernel binary relu.so): elementwise
# unary x -> y over int8/int32/float16/float32 in DefaultFormat and NC1HWC0.
relu_op_info = TBERegOp("ReLU") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("relu.so") \
    .compute_cost(10) \
    .kernel_name("relu") \
    .partial_flag(True) \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.I8_5HD, DataType.I8_5HD) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I32_5HD, DataType.I32_5HD) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
    .get_op_info()

@op_info_register("""{
"op_name": "ReLU",
"imply_type": "TBE",
"fusion_type": "ELEMWISE",
"async_flag": false,
"binfile_name": "relu.so",
"compute_cost": 10,
"kernel_name": "relu",
"partial_flag": true,
"attr": [

],
"inputs": [
{
"index": 0,
"dtype": [
"float16", "float16", "float", "float","int32", "int32", "int8", "int8"
],
"format": [
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "NC1HWC0"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16", "float16", "float", "float", "int32", "int32", "int8", "int8"
],
"format": [
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "NC1HWC0"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")
@op_info_register(relu_op_info)
def _relu_tbe():
    """Register the ReLU TBE operator implementation."""
    return None

+ 23
- 59
mindspore/ops/_op_impl/tbe/relu_grad.py View File

@@ -14,68 +14,32 @@
# ============================================================================

"""ReluGrad op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE registration metadata for ReluGrad (kernel binary relugrad.so):
# binary op gradients, features -> backprops (need_compile=True on output)
# over int8/uint8/int32/float16/float32 in DefaultFormat and NC1HWC0.
relugrad_op_info = TBERegOp("ReluGrad") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("relugrad.so") \
    .compute_cost(10) \
    .kernel_name("relugrad") \
    .partial_flag(True) \
    .input(0, "gradients", False, "required", "all") \
    .input(1, "features", False, "required", "all") \
    .output(0, "backprops", True, "required", "all") \
    .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \
    .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
    .get_op_info()

@op_info_register("""{
"op_name": "ReluGrad",
"imply_type": "TBE",
"fusion_type": "ELEMWISE",
"async_flag": false,
"binfile_name": "relugrad.so",
"compute_cost": 10,
"kernel_name": "relugrad",
"partial_flag": true,
"attr": [

],
"inputs": [
{
"index": 0,
"dtype": [
"float16", "float16", "float", "float", "int32", "int32", "int8", "int8", "uint8", "uint8"
],
"format": [
"DefaultFormat", "NC1HWC0","DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0"
],
"name": "gradients",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16", "float16", "float", "float", "int32", "int32", "int8", "int8", "uint8", "uint8"
],
"format": [
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0"
],
"name": "features",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16", "float16", "float", "float", "int32", "int32", "int8", "int8", "uint8", "uint8"
],
"format": [
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0",
"DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0"
],
"name": "backprops",
"need_compile": true,
"param_type": "required",
"shape": "all"
}
]
}""")
@op_info_register(relugrad_op_info)
def _relu_grad_tbe():
    """Register the ReluGrad TBE operator implementation."""
    return None

+ 17
- 49
mindspore/ops/_op_impl/tbe/reshape.py View File

@@ -14,57 +14,25 @@
# ============================================================================

"""Reshape op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE registration metadata for Reshape (kernel binary reshape.so):
# required listInt attr "shape"; unary x -> y over
# int32/float16/float32 in DefaultFormat.
reshape_op_info = TBERegOp("Reshape") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("reshape.so") \
    .compute_cost(10) \
    .kernel_name("reshape") \
    .partial_flag(True) \
    .attr("shape", "required", "listInt", "all") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()

@op_info_register("""{
"op_name": "Reshape",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "reshape.so",
"compute_cost": 10,
"kernel_name": "reshape",
"partial_flag": true,
"attr": [
{
"name": "shape",
"param_type": "required",
"type": "listInt",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16", "float32", "int32"
],
"format": [
"DefaultFormat", "DefaultFormat", "DefaultFormat"
],
"name": "x",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16", "float32", "int32"
],
"format": [
"DefaultFormat", "DefaultFormat", "DefaultFormat"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")

@op_info_register(reshape_op_info)
def _reshape_tbe():
    """Register the Reshape TBE operator implementation."""
    return None

+ 25
- 59
mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py View File

@@ -14,67 +14,33 @@
# ============================================================================

"""ResizeNearestNeighbor op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE registration metadata for ResizeNearestNeighbor
# (kernel binary resize_nearest_neighbor_d.so): required listInt attr "size",
# optional bool attr "align_corners"; input images -> output y
# (need_compile=True) over int8/uint8/int32/float16/float32 in both
# DefaultFormat and NC1HWC0.
resize_nearest_neighbor_op_info = TBERegOp("ResizeNearestNeighbor") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("resize_nearest_neighbor_d.so") \
    .compute_cost(10) \
    .kernel_name("resize_nearest_neighbor_d") \
    .partial_flag(True) \
    .attr("size", "required", "listInt", "all") \
    .attr("align_corners", "optional", "bool", "all") \
    .input(0, "images", False, "required", "all") \
    .output(0, "y", True, "required", "all") \
    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.I8_5HD, DataType.I8_5HD) \
    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.U8_5HD, DataType.U8_5HD) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I32_5HD, DataType.I32_5HD) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
    .get_op_info()

@op_info_register("""{
"op_name": "ResizeNearestNeighbor",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "resize_nearest_neighbor_d.so",
"compute_cost": 10,
"kernel_name": "resize_nearest_neighbor_d",
"partial_flag": true,
"attr": [
{
"name": "size",
"param_type": "required",
"type": "listInt",
"value": "all"
},
{
"name": "align_corners",
"param_type": "optional",
"type": "bool",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16","float","int32","int8","uint8",
"float16","float","int32","int8","uint8"
],
"format": [
"NC1HWC0","NC1HWC0","NC1HWC0","NC1HWC0","NC1HWC0",
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
],
"name": "images",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16","float","int32","int8","uint8",
"float16","float","int32","int8","uint8"
],
"format": [
"NC1HWC0","NC1HWC0","NC1HWC0","NC1HWC0","NC1HWC0",
"DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat","DefaultFormat"
],
"name": "y",
"need_compile": true,
"param_type": "required",
"shape": "all"
}
]
}""")

# NOTE(review): this function carries a "_d" suffix although it lives in the
# non-"_d" file and registers resize_nearest_neighbor_op_info; the same name
# also appears in resize_nearest_neighbor_d.py. Renaming would change the
# public identifier, so only flagging — confirm against the import site.
@op_info_register(resize_nearest_neighbor_op_info)
def _resize_nearest_neighbor_d_tbe():
    """ResizeNearestNeighbor TBE register"""
    return

+ 20
- 55
mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_d.py View File

@@ -14,63 +14,28 @@
# ============================================================================

"""ResizeNearestNeighbor op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE registration metadata for ResizeNearestNeighbor (explicit "_d" variant,
# kernel binary resize_nearest_neighbor_d.so): required listInt attr "size",
# optional bool attr "align_corners"; input images -> output y
# (need_compile=True) over int8/uint8/int32/float16/float32, NC1HWC0 only.
resize_nearest_neighbor_d_op_info = TBERegOp("ResizeNearestNeighbor") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("resize_nearest_neighbor_d.so") \
    .compute_cost(10) \
    .kernel_name("resize_nearest_neighbor_d") \
    .partial_flag(True) \
    .attr("size", "required", "listInt", "all") \
    .attr("align_corners", "optional", "bool", "all") \
    .input(0, "images", False, "required", "all") \
    .output(0, "y", True, "required", "all") \
    .dtype_format(DataType.I8_5HD, DataType.I8_5HD) \
    .dtype_format(DataType.U8_5HD, DataType.U8_5HD) \
    .dtype_format(DataType.I32_5HD, DataType.I32_5HD) \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
    .get_op_info()

@op_info_register("""{
"op_name": "ResizeNearestNeighbor",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "resize_nearest_neighbor_d.so",
"compute_cost": 10,
"kernel_name": "resize_nearest_neighbor_d",
"partial_flag": true,
"attr": [
{
"name": "size",
"param_type": "required",
"type": "listInt",
"value": "all"
},
{
"name": "align_corners",
"param_type": "optional",
"type": "bool",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16","float","int32","int8","uint8"
],
"format": [
"NC1HWC0","NC1HWC0","NC1HWC0","NC1HWC0","NC1HWC0"
],
"name": "images",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16","float","int32","int8","uint8"
],
"format": [
"NC1HWC0","NC1HWC0","NC1HWC0","NC1HWC0","NC1HWC0"
],
"name": "y",
"need_compile": true,
"param_type": "required",
"shape": "all"
}
]
}""")

@op_info_register(resize_nearest_neighbor_d_op_info)
def _resize_nearest_neighbor_d_tbe():
    """Register the ResizeNearestNeighbor TBE operator implementation."""
    return None

+ 16
- 55
mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad_d.py View File

@@ -14,63 +14,24 @@
# ============================================================================

"""ResizeNearestNeighborgrad op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE registration metadata for ResizeNearestNeighborGrad
# (kernel binary resize_nearest_neighbor_grad_d.so): required listInt attr
# "size", optional bool attr "align_corners"; input grads -> output y,
# float32 NC1HWC0 only.
resize_nearest_neighbor_grad_d_op_info = TBERegOp("ResizeNearestNeighborGrad") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("resize_nearest_neighbor_grad_d.so") \
    .compute_cost(10) \
    .kernel_name("resize_nearest_neighbor_grad_d") \
    .partial_flag(True) \
    .attr("size", "required", "listInt", "all") \
    .attr("align_corners", "optional", "bool", "all") \
    .input(0, "grads", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
    .get_op_info()

@op_info_register("""{
"op_name": "ResizeNearestNeighborGrad",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "resize_nearest_neighbor_grad_d.so",
"compute_cost": 10,
"kernel_name": "resize_nearest_neighbor_grad_d",
"partial_flag": true,
"attr": [
{
"name": "size",
"param_type": "required",
"type": "listInt",
"value": "all"
},
{
"name": "align_corners",
"param_type": "optional",
"type": "bool",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float"
],
"format": [
"NC1HWC0"
],
"name": "grads",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float"
],
"format": [
"NC1HWC0"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")

@op_info_register(resize_nearest_neighbor_grad_d_op_info)
def _resize_nearest_neighbor_grad_d_tbe():
"""ResizeNearestNeighborGrad TBE register"""
return

+ 18
- 43
mindspore/ops/_op_impl/tbe/round.py View File

@@ -14,52 +14,27 @@
# ============================================================================

"""Round op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE op-info for Round; maps to kernel round.
# NOTE(review): the removed legacy JSON declared fusion_type "ELEMWISE" while this
# registration uses "OPAQUE" — this keeps the new-style value; confirm intended.
round_op_info = TBERegOp("Round") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("round.so") \
    .compute_cost(10) \
    .kernel_name("round") \
    .partial_flag(True) \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \
    .get_op_info()


# Legacy JSON-string registration removed to avoid double-registering the op.
@op_info_register(round_op_info)
def _round_tbe():
    """Round TBE register"""
    return

+ 21
- 86
mindspore/ops/_op_impl/tbe/rsqrt.py View File

@@ -14,94 +14,29 @@
# ============================================================================

"""Rsqrt op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE op-info for Rsqrt; maps to kernel rsqrt.
# Supports float16/float32 in Default, NC1HWC0, FracZ and C1HWNCoC0 layouts.
rsqrt_op_info = TBERegOp("Rsqrt") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("rsqrt.so") \
    .compute_cost(10) \
    .kernel_name("rsqrt") \
    .partial_flag(True) \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \
    .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \
    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
    .get_op_info()


# Legacy JSON-string registration removed to avoid double-registering the op.
@op_info_register(rsqrt_op_info)
def _rsqrt_tbe():
    """Rsqrt TBE register"""
    return

+ 20
- 63
mindspore/ops/_op_impl/tbe/scatter_nd.py View File

@@ -14,71 +14,28 @@
# ============================================================================

"""ScatterNd op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE op-info for ScatterNd; maps to TBE kernel scatter_nd_d.
# indices are always int32; data input/output share dtype (int8/uint8/int32/fp16/fp32).
# NOTE(review): fusion_type is "ELEMWISE" here but "OPAQUE" in scatter_nd_d.py for
# the same kernel — confirm which is intended.
scatter_nd_op_info = TBERegOp("ScatterNd") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("scatter_nd_d.so") \
    .compute_cost(10) \
    .kernel_name("scatter_nd_d") \
    .partial_flag(True) \
    .attr("shape", "optional", "listInt", "all") \
    .input(0, "indices", False, "required", "all") \
    .input(1, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


# Legacy JSON-string registration removed to avoid double-registering the op.
@op_info_register(scatter_nd_op_info)
def _scatter_nd_tbe():
    """ScatterNd TBE register"""
    return

+ 20
- 62
mindspore/ops/_op_impl/tbe/scatter_nd_d.py View File

@@ -14,70 +14,28 @@
# ============================================================================

"""ScatterNdD op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE op-info for ScatterNdD; maps to TBE kernel scatter_nd_d.
# indices are always int32; data input/output share dtype (int8/uint8/int32/fp16/fp32).
scatter_nd_d_op_info = TBERegOp("ScatterNdD") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("scatter_nd_d.so") \
    .compute_cost(10) \
    .kernel_name("scatter_nd_d") \
    .partial_flag(True) \
    .attr("shape", "optional", "listInt", "all") \
    .input(0, "indices", False, "required", "all") \
    .input(1, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


# Legacy JSON-string registration removed to avoid double-registering the op.
@op_info_register(scatter_nd_d_op_info)
def _scatter_nd_d_tbe():
    """ScatterNdD TBE register"""
    return

+ 25
- 86
mindspore/ops/_op_impl/tbe/select.py View File

@@ -14,94 +14,33 @@
# ============================================================================

"""Select op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE op-info for Select; maps to kernel select.
# condition is bool; x1/x2/y share dtype, in Default or NC1HWC0 layout.
select_op_info = TBERegOp("Select") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("select.so") \
    .compute_cost(10) \
    .kernel_name("select") \
    .partial_flag(True) \
    .input(0, "condition", False, "required", "all") \
    .input(1, "x1", False, "required", "all") \
    .input(2, "x2", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.BOOL_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.BOOL_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.BOOL_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.BOOL_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.BOOL_5HD, DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \
    .dtype_format(DataType.BOOL_5HD, DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD) \
    .dtype_format(DataType.BOOL_5HD, DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
    .dtype_format(DataType.BOOL_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.BOOL_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
    .get_op_info()


# Legacy JSON-string registration removed to avoid double-registering the op.
@op_info_register(select_op_info)
def _select_tbe():
    """Select TBE register"""
    return

+ 23
- 59
mindspore/ops/_op_impl/tbe/sigmoid.py View File

@@ -14,67 +14,31 @@
# ============================================================================

"""Sigmoid op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE op-info for Sigmoid; maps to kernel sigmoid.
# float16/float32 in Default, NC1HWC0, FracZ, FRACTAL_NZ and C1HWNCoC0 layouts.
sigmoid_op_info = TBERegOp("Sigmoid") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("sigmoid.so") \
    .compute_cost(10) \
    .kernel_name("sigmoid") \
    .partial_flag(True) \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \
    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ) \
    .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \
    .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ) \
    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
    .get_op_info()


# Legacy JSON-string registration removed (it also carried a wrong binfile name
# "Sigmoid.so"); the op is registered once via sigmoid_op_info.
@op_info_register(sigmoid_op_info)
def _sigmoid_tbe():
    """Sigmoid TBE register"""
    return

+ 18
- 56
mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits.py View File

@@ -14,64 +14,26 @@
# ============================================================================

"""SigmoidCrossEntropyWithLogits op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE op-info for SigmoidCrossEntropyWithLogits; maps to kernel
# sigmoid_cross_entropy_with_logits. predict/target/loss share dtype and layout.
sigmoid_cross_entropy_with_logits_op_info = TBERegOp("SigmoidCrossEntropyWithLogits") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("sigmoid_cross_entropy_with_logits.so") \
    .compute_cost(10) \
    .kernel_name("sigmoid_cross_entropy_with_logits") \
    .partial_flag(True) \
    .input(0, "predict", False, "required", "all") \
    .input(1, "target", False, "required", "all") \
    .output(0, "loss", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
    .get_op_info()


# Legacy JSON-string registration removed to avoid double-registering the op.
@op_info_register(sigmoid_cross_entropy_with_logits_op_info)
def _sigmoid_cross_entropy_with_logits_tbe():
    """SigmoidCrossEntropyWithLogits TBE register"""
    return

+ 19
- 69
mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad.py View File

@@ -14,77 +14,27 @@
# ============================================================================

"""SigmoidCrossEntropyWithLogitsGrad op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE op-info for SigmoidCrossEntropyWithLogitsGrad; maps to kernel
# sigmoid_cross_entropy_with_logits_grad. All tensors share dtype and layout.
sigmoid_cross_entropy_with_logits_grad_op_info = TBERegOp("SigmoidCrossEntropyWithLogitsGrad") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("sigmoid_cross_entropy_with_logits_grad.so") \
    .compute_cost(10) \
    .kernel_name("sigmoid_cross_entropy_with_logits_grad") \
    .partial_flag(True) \
    .input(0, "predict", False, "required", "all") \
    .input(1, "target", False, "required", "all") \
    .input(2, "dout", False, "required", "all") \
    .output(0, "gradient", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
    .get_op_info()


# Legacy JSON-string registration removed to avoid double-registering the op.
@op_info_register(sigmoid_cross_entropy_with_logits_grad_op_info)
def _sigmoid_cross_entropy_with_logits_grad_tbe():
    """SigmoidCrossEntropyWithLogitsGrad TBE register"""
    return

+ 18
- 56
mindspore/ops/_op_impl/tbe/sigmoid_grad.py View File

@@ -14,64 +14,26 @@
# ============================================================================

"""SigmoidGrad op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE op-info for SigmoidGrad; maps to kernel sigmoid_grad.
# Renamed from the copy-pasted "sigmoid_cross_entropy_with_logits_op_info",
# which collided with the variable of the same name in the
# SigmoidCrossEntropyWithLogits registration.
sigmoid_grad_op_info = TBERegOp("SigmoidGrad") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("sigmoid_grad.so") \
    .compute_cost(10) \
    .kernel_name("sigmoid_grad") \
    .partial_flag(True) \
    .input(0, "x", False, "required", "all") \
    .input(1, "y", False, "required", "all") \
    .output(0, "z", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
    .get_op_info()


# Legacy JSON-string registration removed to avoid double-registering the op.
@op_info_register(sigmoid_grad_op_info)
def _sigmoid_grad_tbe():
    """SigmoidGrad TBE register"""
    return

+ 25
- 91
mindspore/ops/_op_impl/tbe/slice.py View File

@@ -14,99 +14,33 @@
# ============================================================================

"""Slice op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE op-info for Slice; maps to TBE kernel slice_d.
# All integer widths plus float16/float32, DefaultFormat only.
slice_op_info = TBERegOp("Slice") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("slice_d.so") \
    .compute_cost(10) \
    .kernel_name("slice_d") \
    .partial_flag(True) \
    .attr("begin", "required", "listInt", "all") \
    .attr("size", "required", "listInt", "all") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.I16_Default, DataType.I16_Default) \
    .dtype_format(DataType.U16_Default, DataType.U16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default) \
    .dtype_format(DataType.U32_Default, DataType.U32_Default) \
    .dtype_format(DataType.U64_Default, DataType.U64_Default) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


# Legacy JSON-string registration removed to avoid double-registering the op.
@op_info_register(slice_op_info)
def _slice_tbe():
    """Slice TBE register"""
    return

+ 19
- 49
mindspore/ops/_op_impl/tbe/softmax.py View File

@@ -14,57 +14,27 @@
# ============================================================================

"""Softmax op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE op-info for Softmax; maps to kernel softmax.
softmax_op_info = TBERegOp("Softmax") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("softmax.so") \
    .compute_cost(10) \
    .kernel_name("softmax") \
    .partial_flag(True) \
    .attr("axis", "optional", "listInt", "all") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ) \
    .get_op_info()


# Legacy JSON-string registration removed to avoid double-registering the op.
# NOTE(review): the removed JSON also listed float32 NC1HWC0, which this
# registration does not — confirm the narrowing is intended.
@op_info_register(softmax_op_info)
def _softmax_tbe():
    """Softmax TBE register"""
    return

+ 16
- 69
mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits.py View File

@@ -14,78 +14,25 @@
# ============================================================================

"""SoftmaxCrossEntropyWithLogits op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE op-info for SoftmaxCrossEntropyWithLogits; maps to kernel
# softmax_cross_entropy_with_logits. Two outputs (loss and backprop) with
# need_compile=True, matching the removed legacy JSON.
softmax_cross_entropy_with_logits_op_info = TBERegOp("SoftmaxCrossEntropyWithLogits") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("softmax_cross_entropy_with_logits.so") \
    .compute_cost(10) \
    .kernel_name("softmax_cross_entropy_with_logits") \
    .partial_flag(True) \
    .input(0, "input_features", False, "required", "all") \
    .input(1, "input_labels", False, "required", "all") \
    .output(0, "output_loss", True, "required", "all") \
    .output(1, "output_backprop", True, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


# Legacy JSON-string registration removed to avoid double-registering the op.
@op_info_register(softmax_cross_entropy_with_logits_op_info)
def _softmax_cross_entropy_with_logits_tbe():
    """SoftmaxCrossEntropyWithLogits TBE register"""
    return

+ 37
- 63
mindspore/ops/_op_impl/tbe/split_d.py View File

@@ -14,71 +14,45 @@
# ============================================================================

"""Add op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE op-info for Split; maps to TBE kernel split_d.
# Dynamic output ("dynamic" param_type) since the number of outputs is
# determined by the "output_num" attribute. All dtypes in DefaultFormat or NHWC.
split_d_op_info = TBERegOp("Split") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("split_d.so") \
    .compute_cost(10) \
    .kernel_name("split_d") \
    .partial_flag(True) \
    .attr("axis", "required", "int", "all") \
    .attr("output_num", "required", "int", "all") \
    .input(0, "value", False, "required", "all") \
    .output(0, "output", False, "dynamic", "all") \
    .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
    .dtype_format(DataType.BOOL_NHWC, DataType.BOOL_NHWC) \
    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.I8_NHWC, DataType.I8_NHWC) \
    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.U8_NHWC, DataType.U8_NHWC) \
    .dtype_format(DataType.I16_Default, DataType.I16_Default) \
    .dtype_format(DataType.I16_NHWC, DataType.I16_NHWC) \
    .dtype_format(DataType.U16_Default, DataType.U16_Default) \
    .dtype_format(DataType.U16_NHWC, DataType.U16_NHWC) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I32_NHWC, DataType.I32_NHWC) \
    .dtype_format(DataType.U32_Default, DataType.U32_Default) \
    .dtype_format(DataType.U32_NHWC, DataType.U32_NHWC) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default) \
    .dtype_format(DataType.I64_NHWC, DataType.I64_NHWC) \
    .dtype_format(DataType.U64_Default, DataType.U64_Default) \
    .dtype_format(DataType.U64_NHWC, DataType.U64_NHWC) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_NHWC, DataType.F16_NHWC) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC) \
    .get_op_info()


# Legacy JSON-string registration removed to avoid double-registering the op.
@op_info_register(split_d_op_info)
def _split_d_tbe():
    """Split TBE register"""
    return

Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save