Browse Source

fix the output shape of the operator maxPoolGradGrad

tags/v1.2.0-rc1
hedongdong 5 years ago
parent
commit
81a5233c8a
2 changed files with 6 additions and 6 deletions
  1. +2
    -2
      mindspore/ops/operations/_grad_ops.py
  2. +4
    -4
      mindspore/ops/operations/array_ops.py

+ 2
- 2
mindspore/ops/operations/_grad_ops.py View File

@@ -976,12 +976,12 @@ class MaxPoolGradGrad(_PoolGrad):
super(MaxPoolGradGrad, self).__init__(kernel_size, strides, pad_mode) super(MaxPoolGradGrad, self).__init__(kernel_size, strides, pad_mode)


def infer_shape(self, x1_shape, x2_shape, grad_shape): def infer_shape(self, x1_shape, x2_shape, grad_shape):
return x1_shape
return x2_shape


def infer_dtype(self, x1_dtype, x2_dtype, grad_dtype): def infer_dtype(self, x1_dtype, x2_dtype, grad_dtype):
args = {'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'grad_dtype': grad_dtype} args = {'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'grad_dtype': grad_dtype}
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16], self.name) validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16], self.name)
return x1_dtype
return x2_dtype




def _get_max_pool3d_grad_pads_by_pad_mode(input_shape, kernel_size, strides, pad_mode): def _get_max_pool3d_grad_pads_by_pad_mode(input_shape, kernel_size, strides, pad_mode):


+ 4
- 4
mindspore/ops/operations/array_ops.py View File

@@ -3376,13 +3376,13 @@ class TensorScatterUpdate(PrimitiveWithInfer):
`indices`, with values from `update`. This operation is almost equivalent to using `indices`, with values from `update`. This operation is almost equivalent to using
ScatterNd, except that the updates are applied on `input_x` instead of a zero tensor. ScatterNd, except that the updates are applied on `input_x` instead of a zero tensor.


`indices` must have rank atleast 2, the last axis is the depth of each index
`indices` must have rank at least 2, the last axis is the depth of each index
vectors. For each index vector, there must be a corresponding value in `update`. If vectors. For each index vector, there must be a corresponding value in `update`. If
the depth of each index tensor matches the rank of `input_x`, then each index the depth of each index tensor matches the rank of `input_x`, then each index
vector corresponds to a scalar in `input_x` and each update updates a scalar. If vector corresponds to a scalar in `input_x` and each update updates a scalar. If
the depth of each index tensor is less than the rnak of `input_x`, then each index the depth of each index tensor is less than the rnak of `input_x`, then each index
vector corresponds to a slice in `input_x`, and each update updates a slice. vector corresponds to a slice in `input_x`, and each update updates a slice.
The order in which updates are applied is nondeterministic, meaning that if there The order in which updates are applied is nondeterministic, meaning that if there
are multiple index vectors in `indices` that correspond to the same position, the are multiple index vectors in `indices` that correspond to the same position, the
value of that position in the output will be nondeterministic. value of that position in the output will be nondeterministic.
@@ -3390,7 +3390,7 @@ class TensorScatterUpdate(PrimitiveWithInfer):
Inputs: Inputs:
- **input_x** (Tensor) - The target tensor. The dimension of input_x must be no less than indices.shape[-1]. - **input_x** (Tensor) - The target tensor. The dimension of input_x must be no less than indices.shape[-1].
- **indices** (Tensor) - The index of input tensor whose data type is int32 or int64. - **indices** (Tensor) - The index of input tensor whose data type is int32 or int64.
The rank must be atleast 2.
The rank must be at least 2.
- **update** (Tensor) - The tensor to update the input tensor, has the same type as input, - **update** (Tensor) - The tensor to update the input tensor, has the same type as input,
and update.shape = indices.shape[:-1] + input_x.shape[indices.shape[-1]:]. and update.shape = indices.shape[:-1] + input_x.shape[indices.shape[-1]:].


@@ -3520,7 +3520,7 @@ class ScatterNdUpdate(_ScatterNdOp):
- **indices** (Tensor) - The index of input tensor, with int32 data type. - **indices** (Tensor) - The index of input tensor, with int32 data type.
The rank of indices must be at least 2 and `indices_shape[-1] <= len(shape)`. The rank of indices must be at least 2 and `indices_shape[-1] <= len(shape)`.
- **updates** (Tensor) - The tensor to be updated to the input tensor, has the same type as input. - **updates** (Tensor) - The tensor to be updated to the input tensor, has the same type as input.
the shape is `indices_shape[:-1] + x_shape[indices_shape[-1]:]`.
The shape is `indices_shape[:-1] + x_shape[indices_shape[-1]:]`.


Outputs: Outputs:
Tensor, has the same shape and type as `input_x`. Tensor, has the same shape and type as `input_x`.


Loading…
Cancel
Save