Browse Source

Fixing some tiny mistake of InplaceAdd, InplaceSub and InplaceUpdate vm ops and apply new character of dynamic format

tags/v0.6.0-beta
liuwenhao4 5 years ago
parent
commit
5897793b01
10 changed files with 28 additions and 56 deletions
  1. +5
    -5
      mindspore/ops/_op_impl/tbe/accumulate_n_v2.py
  2. +2
    -4
      mindspore/ops/_op_impl/tbe/approximate_equal.py
  3. +2
    -4
      mindspore/ops/_op_impl/tbe/binary_cross_entropy.py
  4. +2
    -2
      mindspore/ops/_op_impl/tbe/lin_space.py
  5. +6
    -10
      mindspore/ops/_op_impl/tbe/mod.py
  6. +5
    -4
      mindspore/ops/_op_impl/tbe/reduce_mean_d.py
  7. +2
    -2
      mindspore/ops/_op_impl/tbe/softsign.py
  8. +1
    -22
      mindspore/ops/_op_impl/tbe/splitv.py
  9. +1
    -1
      mindspore/ops/operations/array_ops.py
  10. +2
    -2
      mindspore/ops/operations/math_ops.py

+ 5
- 5
mindspore/ops/_op_impl/tbe/accumulate_n_v2.py View File

@@ -27,11 +27,11 @@ accumulate_n_v2_op_info = TBERegOp("AccumulateNV2") \
.input(0, "x", False, "dynamic", "all") \ .input(0, "x", False, "dynamic", "all") \
.output(0, "y", False, "required", "all") \ .output(0, "y", False, "required", "all") \
.op_pattern("broadcast") \ .op_pattern("broadcast") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.F16_None, DataType.F16_None) \
.dtype_format(DataType.F32_None, DataType.F32_None) \
.dtype_format(DataType.I32_None, DataType.I32_None) \
.dtype_format(DataType.I8_None, DataType.I8_None) \
.dtype_format(DataType.U8_None, DataType.U8_None) \
.get_op_info() .get_op_info()






+ 2
- 4
mindspore/ops/_op_impl/tbe/approximate_equal.py View File

@@ -28,10 +28,8 @@ approximate_equal_op_info = TBERegOp("ApproximateEqual") \
.input(0, "x1", False, "required", "all") \ .input(0, "x1", False, "required", "all") \
.input(1, "x2", False, "required", "all") \ .input(1, "x2", False, "required", "all") \
.output(0, "y", False, "required", "all") \ .output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.BOOL_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.BOOL_5HD) \
.dtype_format(DataType.F16_None, DataType.F16_None, DataType.BOOL_None) \
.dtype_format(DataType.F32_None, DataType.F32_None, DataType.BOOL_None) \
.get_op_info() .get_op_info()






+ 2
- 4
mindspore/ops/_op_impl/tbe/binary_cross_entropy.py View File

@@ -28,10 +28,8 @@ binary_cross_entropy_op_info = TBERegOp("BinaryCrossEntropy") \
.input(1, "y", False, "required", "all") \ .input(1, "y", False, "required", "all") \
.input(2, "weight", False, "optional", "all") \ .input(2, "weight", False, "optional", "all") \
.output(0, "output", False, "required", "all") \ .output(0, "output", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None) \
.get_op_info() .get_op_info()






+ 2
- 2
mindspore/ops/_op_impl/tbe/lin_space.py View File

@@ -29,8 +29,8 @@ lin_space_op_info = TBERegOp("LinSpace") \
.input(2, "stop", False, "required", "all") \ .input(2, "stop", False, "required", "all") \
.input(3, "num", False, "required", "all") \ .input(3, "num", False, "required", "all") \
.output(0, "output", False, "required", "all") \ .output(0, "output", False, "required", "all") \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.I32_Default,
DataType.F32_Default,) \
.dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None, DataType.I32_None,
DataType.F32_None,) \
.get_op_info() .get_op_info()






+ 6
- 10
mindspore/ops/_op_impl/tbe/mod.py View File

@@ -26,16 +26,12 @@ mod_op_info = TBERegOp("Mod") \
.input(0, "x1", False, "required", "all") \ .input(0, "x1", False, "required", "all") \
.input(1, "x2", False, "required", "all") \ .input(1, "x2", False, "required", "all") \
.output(0, "y", False, "required", "all") \ .output(0, "y", False, "required", "all") \
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.op_pattern("broadcast") \
.dtype_format(DataType.I8_None, DataType.I8_None, DataType.I8_None) \
.dtype_format(DataType.U8_None, DataType.U8_None, DataType.U8_None) \
.dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \
.dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
.dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
.get_op_info() .get_op_info()






+ 5
- 4
mindspore/ops/_op_impl/tbe/reduce_mean_d.py View File

@@ -27,10 +27,11 @@ reduce_mean_d_op_info = TBERegOp("ReduceMeanD") \
.attr("keep_dims", "optional", "bool", "all") \ .attr("keep_dims", "optional", "bool", "all") \
.input(0, "x", False, "required", "all") \ .input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \ .output(0, "y", False, "required", "all") \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.op_pattern("reduce") \
.dtype_format(DataType.I8_None, DataType.I8_None) \
.dtype_format(DataType.U8_None, DataType.U8_None) \
.dtype_format(DataType.F16_None, DataType.F16_None) \
.dtype_format(DataType.F32_None, DataType.F32_None) \
.get_op_info() .get_op_info()






+ 2
- 2
mindspore/ops/_op_impl/tbe/softsign.py View File

@@ -26,8 +26,8 @@ softsign_op_info = TBERegOp("Softsign") \
.op_pattern("formatAgnostic") \ .op_pattern("formatAgnostic") \
.input(0, "x", False, "required", "all") \ .input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \ .output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F16_None, DataType.F16_None) \
.dtype_format(DataType.F32_None, DataType.F32_None) \
.get_op_info() .get_op_info()






+ 1
- 22
mindspore/ops/_op_impl/tbe/splitv.py View File

@@ -29,28 +29,7 @@ split_v_op_info = TBERegOp("SplitV") \
.input(0, "x", False, "required", "all") \ .input(0, "x", False, "required", "all") \
.output(0, "y", False, "dynamic", "all") \ .output(0, "y", False, "dynamic", "all") \
.op_pattern("dynamicFormat") \ .op_pattern("dynamicFormat") \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.dtype_format(DataType.BOOL_NHWC, DataType.BOOL_NHWC) \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I8_NHWC, DataType.I8_NHWC) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U8_NHWC, DataType.U8_NHWC) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I16_NHWC, DataType.I16_NHWC) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.U16_NHWC, DataType.U16_NHWC) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_NHWC, DataType.I32_NHWC) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U32_NHWC, DataType.U32_NHWC) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.I64_NHWC, DataType.I64_NHWC) \
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.U64_NHWC, DataType.U64_NHWC) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_NHWC, DataType.F16_NHWC) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC) \
.dtype_format(DataType.None_None, DataType.None_None) \
.get_op_info() .get_op_info()






+ 1
- 1
mindspore/ops/operations/array_ops.py View File

@@ -3055,7 +3055,7 @@ class InplaceUpdate(PrimitiveWithInfer):
raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {i}.') raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {i}.')
x_rank = len(x_shape) x_rank = len(x_shape)
for idx in range(x_rank)[1:]: for idx in range(x_rank)[1:]:
validator.check("x dim %d" % idx, x_shape[idx], 'v dim %d' % idx, v_shape[idx], Rel.EQ, self.name)
validator.check('v dim %d' % idx, v_shape[idx], "x dim %d" % idx, x_shape[idx], Rel.EQ, self.name)
return x_shape return x_shape






+ 2
- 2
mindspore/ops/operations/math_ops.py View File

@@ -947,7 +947,7 @@ class InplaceAdd(PrimitiveWithInfer):
raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {i}.') raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {i}.')
x_rank = len(x_shape) x_rank = len(x_shape)
for idx in range(x_rank)[1:]: for idx in range(x_rank)[1:]:
validator.check("x dim %d" % idx, x_shape[idx], 'v dim %d' % idx, v_shape[idx], Rel.EQ, self.name)
validator.check('v dim %d' % idx, v_shape[idx], "x dim %d" % idx, x_shape[idx], Rel.EQ, self.name)


return x_shape return x_shape


@@ -1005,7 +1005,7 @@ class InplaceSub(PrimitiveWithInfer):
raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {i}.') raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {i}.')
x_rank = len(x_shape) x_rank = len(x_shape)
for idx in range(x_rank)[1:]: for idx in range(x_rank)[1:]:
validator.check("x dim %d" % idx, x_shape[idx], 'v dim %d' % idx, v_shape[idx], Rel.EQ, self.name)
validator.check('v dim %d' % idx, v_shape[idx], "x dim %d" % idx, x_shape[idx], Rel.EQ, self.name)


return x_shape return x_shape




Loading…
Cancel
Save