
Fix bugs in the ops API.

tags/v0.6.0-beta
liuxiao93 5 years ago
commit 75c38a08a7
2 changed files with 9 additions and 9 deletions
  1. mindspore/ops/operations/array_ops.py (+3, -2)
  2. mindspore/ops/operations/nn_ops.py (+6, -7)

mindspore/ops/operations/array_ops.py (+3, -2)

@@ -435,7 +435,7 @@ class Squeeze(PrimitiveWithInfer):
         ValueError: If the corresponding dimension of the specified axis does not equal to 1.

     Args:
-        axis (int): Specifies the dimension indexes of shape to be removed, which will remove
+        axis (Union[int, tuple(int)]): Specifies the dimension indexes of shape to be removed, which will remove
             all the dimensions that are equal to 1. If specified, it must be int32 or int64.
             Default: (), an empty tuple.
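For context, a minimal usage sketch of the updated `Squeeze` signature; it is not part of this commit, and the PyNative mode setting and example shapes are illustrative assumptions:

import numpy as np
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE)

x = Tensor(np.ones((3, 2, 1), dtype=np.float32))
# A single int axis, as before.
out_int = P.Squeeze(2)(x)        # expected shape (3, 2)
# A tuple of int axes is now documented and accepted as well.
out_tuple = P.Squeeze((2,))(x)   # expected shape (3, 2)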


@@ -1427,7 +1427,8 @@ class UnsortedSegmentProd(PrimitiveWithInfer):
     Inputs:
         - **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`.
           With float16, float32 or int32 data type.
-        - **segment_ids** (Tensor) - A `1-D` tensor whose shape is :math:`(x_1)`. Data type must be int32.
+        - **segment_ids** (Tensor) - A `1-D` tensor whose shape is :math:`(x_1)`, the value should be >= 0.
+          Data type must be int32.
         - **num_segments** (int) - The value spcifies the number of distinct `segment_ids`,
           should be greater than 0.
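A minimal sketch of the documented constraint in practice, assuming a backend that implements `UnsortedSegmentProd`; the values and shapes are illustrative:

import numpy as np
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE)

input_x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]], dtype=np.float32))
# segment_ids must be int32 and every value must be >= 0 (and below num_segments).
segment_ids = Tensor(np.array([0, 1, 0], dtype=np.int32))
num_segments = 2
output = P.UnsortedSegmentProd()(input_x, segment_ids, num_segments)
# Rows 0 and 2 are multiplied into segment 0: [[4, 4, 3], [4, 5, 6]].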




mindspore/ops/operations/nn_ops.py (+6, -7)

@@ -3760,12 +3760,12 @@ class ApplyAdagradV2(PrimitiveWithInfer):
         update_slots (bool): If `True`, `accum` will be updated. Default: True.

     Inputs:
-        - **var** (Parameter) - Variable to be updated. With float32 or float16 data type.
+        - **var** (Parameter) - Variable to be updated. With float32 data type.
         - **accum** (Parameter) - Accum to be updated. The shape and dtype should be the same as `var`.
-          With float32 or float16 data type.
-        - **lr** (Union[Number, Tensor]) - The learning rate value, should be scalar. With float32 or float16 data type.
+          With float32 data type.
+        - **lr** (Union[Number, Tensor]) - The learning rate value, should be scalar. With float32 data type.
         - **grad** (Tensor) - A tensor for gradient. The shape and dtype should be the same as `var`.
-          With float32 or float16 data type.
+          With float32 data type.

     Outputs:
         Tuple of 2 Tensor, the updated parameters.
@@ -3817,9 +3817,8 @@ class ApplyAdagradV2(PrimitiveWithInfer):

     def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, grad_dtype):
         args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}
-        valid_types = [mstype.float16, mstype.float32]
-        validator.check_tensor_type_same(args, valid_types, self.name)
-        validator.check_scalar_or_tensor_type_same({'lr': lr_dtype}, valid_types, self.name)
+        validator.check_tensor_type_same(args, [mstype.float32], self.name)
+        validator.check_scalar_or_tensor_type_same({'lr': lr_dtype}, [mstype.float32], self.name)
         return var_dtype, accum_dtype
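A minimal sketch of calling `ApplyAdagradV2` under the tightened float32-only validation; the wrapper cell `AdagradV2Net`, the epsilon value, and all input values are illustrative assumptions, not part of this commit:

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter, context
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE)

class AdagradV2Net(nn.Cell):
    """Illustrative wrapper cell; every input uses float32, the only dtype
    accepted by infer_dtype after this change."""
    def __init__(self):
        super(AdagradV2Net, self).__init__()
        self.apply_adagrad_v2 = P.ApplyAdagradV2(epsilon=1e-6)
        self.var = Parameter(Tensor(np.ones((2, 2), dtype=np.float32)), name="var")
        self.accum = Parameter(Tensor(np.ones((2, 2), dtype=np.float32)), name="accum")

    def construct(self, lr, grad):
        # Returns the updated (var, accum) pair.
        return self.apply_adagrad_v2(self.var, self.accum, lr, grad)

net = AdagradV2Net()
lr = Tensor(0.001, mstype.float32)
grad = Tensor(np.full((2, 2), 0.1, dtype=np.float32))
var_out, accum_out = net(lr, grad)
# Supplying float16 var/accum/lr/grad would now fail the dtype check instead of being accepted.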





