Browse Source

fix bprop of ExtractImagePatches

tags/v0.5.0-beta
buxue 5 years ago
parent
commit
01e9b2d94e
3 changed files with 7 additions and 11 deletions
  1. +1
    -1
      mindspore/nn/layer/basic.py
  2. +4
    -8
      mindspore/ops/_grad/grad_nn_ops.py
  3. +2
    -2
      mindspore/ops/operations/_inner_ops.py

+ 1
- 1
mindspore/nn/layer/basic.py View File

@@ -482,7 +482,7 @@ class Unfold(Cell):

Inputs:
- **input_x** (Tensor) - A 4-D tensor whose shape is [in_batch, in_depth, in_row, in_col] and
data type is int8, float16, uint8.
data type is number.

Outputs:
Tensor, a 4-D tensor whose data type is same as 'input_x',


+ 4
- 8
mindspore/ops/_grad/grad_nn_ops.py View File

@@ -14,13 +14,12 @@
# ============================================================================

"""Define the grad rules of neural network related operations."""
from mindspore.common import dtype as mstype
from .grad_base import bprop_getters
from .. import functional as F
from .. import operations as P
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations import _grad_ops as G
from ..operations import _inner_ops as inner
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from .grad_base import bprop_getters
from ... import context


@@ -73,7 +72,7 @@ def get_bprop_extract_image_patches(self):
slice_op = P.Slice()
transpose = P.Transpose()
matmul = P.MatMul()
cast = P.Cast()
_, ksizes_row, ksizes_col, _ = self.ksizes

def bprop(x, out, dout):
@@ -82,16 +81,13 @@ def get_bprop_extract_image_patches(self):
x_indices_num = x_row * x_col + 1
x_idx = F.tuple_to_array(range(1, x_indices_num))
x_idx = reshape(x_idx, (1, x_row, x_col, 1))
x_idx = cast(x_idx, mstype.float16)
x_idx_patch = extract_image_patches(x_idx)
x_idx_patch = transpose(x_idx_patch, (0, 3, 1, 2))
x_idx_patch = cast(x_idx_patch, mstype.int32)

out_shape = get_shape(out)
_, out_row, out_col, _ = out_shape
out_indices_num = out_row * out_col * ksizes_row * ksizes_col
out_idx = F.tuple_to_array(range(out_indices_num))
out_idx = reshape(out_idx, (1, ksizes_row * ksizes_col, out_row, out_col))
out_idx = reshape(out_idx, (1, out_row, out_col, ksizes_row * ksizes_col))

idx_tensor = concat((expand_dims(x_idx_patch, -1), expand_dims(out_idx, -1)))
idx_tensor = reshape(idx_tensor, (-1, 2))


+ 2
- 2
mindspore/ops/operations/_inner_ops.py View File

@@ -41,7 +41,7 @@ class ExtractImagePatches(PrimitiveWithInfer):

Inputs:
- **input_x** (Tensor) - A 4-D tensor whose shape is [in_batch, in_row, in_col, in_depth] and
data type is int8, float16, uint8.
data type is number.

Outputs:
Tensor, a 4-D tensor whose data type is same as 'input_x',
@@ -94,5 +94,5 @@ class ExtractImagePatches(PrimitiveWithInfer):

def infer_dtype(self, input_x):
"""infer dtype"""
validator.check_tensor_type_same({"input_x": input_x}, (mstype.int8, mstype.float16, mstype.float32), self.name)
validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.name)
return input_x

Loading…
Cancel
Save