Browse Source

!9429 update supported platforms for op EmbeddingLookup, ELU and fix bugs of op SparseGatherV2, ArgMaxWithValue

From: @lihongkang1
Reviewed-by: @liangchenghui,@youui
Signed-off-by: @liangchenghui
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
9605102dd1
5 changed files with 26 additions and 27 deletions
  1. +3
    -0
      mindspore/nn/layer/activation.py
  2. +1
    -1
      mindspore/nn/layer/embedding.py
  3. +1
    -1
      mindspore/nn/wrap/cell_wrapper.py
  4. +15
    -25
      mindspore/ops/_op_impl/tbe/sparse_gather_v2.py
  5. +6
    -0
      mindspore/ops/operations/array_ops.py

+ 3
- 0
mindspore/nn/layer/activation.py View File

@@ -151,6 +151,9 @@ class ELU(Cell):
Outputs:
Tensor, with the same type and shape as the `input_data`.

Supported Platforms:
``Ascend`` ``GPU``

Examples:
>>> input_x = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float32)
>>> elu = nn.ELU()


+ 1
- 1
mindspore/nn/layer/embedding.py View File

@@ -167,7 +167,7 @@ class EmbeddingLookup(Cell):
Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
``Ascend`` ``CPU``

Examples:
>>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32)


+ 1
- 1
mindspore/nn/wrap/cell_wrapper.py View File

@@ -205,7 +205,7 @@ class TrainOneStepCell(Cell):
>>> train_net = nn.TrainOneStepCell(loss_net, optim)
>>>
>>> #2) Using user-defined WithLossCell
>>>class MyWithLossCell(nn.Cell):
>>> class MyWithLossCell(nn.cell):
>>> def __init__(self, backbone, loss_fn):
>>> super(MyWithLossCell, self).__init__(auto_prefix=False)
>>> self._backbone = backbone


+ 15
- 25
mindspore/ops/_op_impl/tbe/sparse_gather_v2.py View File

@@ -27,36 +27,26 @@ sparse_gather_v2_op_info = TBERegOp("SparseGatherV2") \
.input(0, "x", False, "required", "all") \
.input(1, "indices", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default) \
.dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I8_Default) \
.dtype_format(DataType.I8_5HD, DataType.I32_5HD, DataType.I8_5HD) \
.dtype_format(DataType.I8_5HD, DataType.I64_5HD, DataType.I8_5HD) \
.dtype_format(DataType.I8_FracZ, DataType.I32_FracZ, DataType.I8_FracZ) \
.dtype_format(DataType.I8_FracZ, DataType.I64_FracZ, DataType.I8_FracZ) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default) \
.dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U8_Default) \
.dtype_format(DataType.U8_5HD, DataType.I32_5HD, DataType.U8_5HD) \
.dtype_format(DataType.U8_5HD, DataType.I64_5HD, DataType.U8_5HD) \
.dtype_format(DataType.U8_FracZ, DataType.I32_FracZ, DataType.U8_FracZ) \
.dtype_format(DataType.U8_FracZ, DataType.I64_FracZ, DataType.U8_FracZ) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
.dtype_format(DataType.I32_5HD, DataType.I64_5HD, DataType.I32_5HD) \
.dtype_format(DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ) \
.dtype_format(DataType.I32_FracZ, DataType.I64_FracZ, DataType.I32_FracZ) \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
.dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.U32_Default) \
.dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I16_Default) \
.dtype_format(DataType.U16_Default, DataType.I32_Default, DataType.U16_Default) \
.dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I64_Default) \
.dtype_format(DataType.U64_Default, DataType.I32_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.I32_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_5HD, DataType.I64_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_FracZ, DataType.I32_FracZ, DataType.F16_FracZ) \
.dtype_format(DataType.F16_FracZ, DataType.I64_FracZ, DataType.F16_FracZ) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.I32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_5HD, DataType.I64_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_FracZ, DataType.I32_FracZ, DataType.F32_FracZ) \
.dtype_format(DataType.F32_FracZ, DataType.I64_FracZ, DataType.F32_FracZ) \
.dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U8_Default) \
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \
.dtype_format(DataType.U32_Default, DataType.I64_Default, DataType.U32_Default) \
.dtype_format(DataType.I16_Default, DataType.I64_Default, DataType.I16_Default) \
.dtype_format(DataType.U16_Default, DataType.I64_Default, DataType.U16_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U64_Default, DataType.I64_Default, DataType.U64_Default) \
.get_op_info()




+ 6
- 0
mindspore/ops/operations/array_ops.py View File

@@ -833,6 +833,10 @@ class SparseGatherV2(GatherV2):
>>> input_indices = Tensor(np.array([1, 2]), mindspore.int32)
>>> axis = 1
>>> out = ops.SparseGatherV2()(input_params, input_indices, axis)
>>> print(out)
[[2. 7.]
[4. 54.]
[2. 55.]]
"""


@@ -1642,6 +1646,8 @@ class ArgMaxWithValue(PrimitiveWithInfer):
Examples:
>>> input_x = Tensor(np.random.rand(5), mindspore.float32)
>>> index, output = ops.ArgMaxWithValue()(input_x)
>>> print(index, output)
2 0.87173676
"""

@prim_attr_register


Loading…
Cancel
Save