From 01e9b2d94e727da33e47272dbad456686de56fb8 Mon Sep 17 00:00:00 2001 From: buxue Date: Sat, 30 May 2020 17:00:22 +0800 Subject: [PATCH] fix bprop of ExtractImagePatches --- mindspore/nn/layer/basic.py | 2 +- mindspore/ops/_grad/grad_nn_ops.py | 12 ++++-------- mindspore/ops/operations/_inner_ops.py | 4 ++-- 3 files changed, 7 insertions(+), 11 deletions(-) diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py index 24d547c8b4..8f4e468e0b 100644 --- a/mindspore/nn/layer/basic.py +++ b/mindspore/nn/layer/basic.py @@ -482,7 +482,7 @@ class Unfold(Cell): Inputs: - **input_x** (Tensor) - A 4-D tensor whose shape is [in_batch, in_depth, in_row, in_col] and - data type is int8, float16, uint8. + data type is number. Outputs: Tensor, a 4-D tensor whose data type is same as 'input_x', diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py index 227281420e..9f1ccdf5a9 100755 --- a/mindspore/ops/_grad/grad_nn_ops.py +++ b/mindspore/ops/_grad/grad_nn_ops.py @@ -14,13 +14,12 @@ # ============================================================================ """Define the grad rules of neural network related operations.""" -from mindspore.common import dtype as mstype +from .grad_base import bprop_getters from .. import functional as F from .. import operations as P +from ..composite.multitype_ops.zeros_like_impl import zeros_like from ..operations import _grad_ops as G from ..operations import _inner_ops as inner -from ..composite.multitype_ops.zeros_like_impl import zeros_like -from .grad_base import bprop_getters from ... import context @@ -73,7 +72,7 @@ def get_bprop_extract_image_patches(self): slice_op = P.Slice() transpose = P.Transpose() matmul = P.MatMul() - cast = P.Cast() + _, ksizes_row, ksizes_col, _ = self.ksizes def bprop(x, out, dout): @@ -82,16 +81,13 @@ def get_bprop_extract_image_patches(self): x_indices_num = x_row * x_col + 1 x_idx = F.tuple_to_array(range(1, x_indices_num)) x_idx = reshape(x_idx, (1, x_row, x_col, 1)) - x_idx = cast(x_idx, mstype.float16) x_idx_patch = extract_image_patches(x_idx) - x_idx_patch = transpose(x_idx_patch, (0, 3, 1, 2)) - x_idx_patch = cast(x_idx_patch, mstype.int32) out_shape = get_shape(out) _, out_row, out_col, _ = out_shape out_indices_num = out_row * out_col * ksizes_row * ksizes_col out_idx = F.tuple_to_array(range(out_indices_num)) - out_idx = reshape(out_idx, (1, ksizes_row * ksizes_col, out_row, out_col)) + out_idx = reshape(out_idx, (1, out_row, out_col, ksizes_row * ksizes_col)) idx_tensor = concat((expand_dims(x_idx_patch, -1), expand_dims(out_idx, -1))) idx_tensor = reshape(idx_tensor, (-1, 2)) diff --git a/mindspore/ops/operations/_inner_ops.py b/mindspore/ops/operations/_inner_ops.py index 632f9c0a20..38f399316a 100644 --- a/mindspore/ops/operations/_inner_ops.py +++ b/mindspore/ops/operations/_inner_ops.py @@ -41,7 +41,7 @@ class ExtractImagePatches(PrimitiveWithInfer): Inputs: - **input_x** (Tensor) - A 4-D tensor whose shape is [in_batch, in_row, in_col, in_depth] and - data type is int8, float16, uint8. + data type is number. Outputs: Tensor, a 4-D tensor whose data type is same as 'input_x', @@ -94,5 +94,5 @@ class ExtractImagePatches(PrimitiveWithInfer): def infer_dtype(self, input_x): """infer dtype""" - validator.check_tensor_type_same({"input_x": input_x}, (mstype.int8, mstype.float16, mstype.float32), self.name) + validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.name) return input_x