From 46b8ab3c4079768e50fa6eb614974d7245bb2c3c Mon Sep 17 00:00:00 2001
From: liuxiao93
Date: Wed, 16 Dec 2020 14:25:43 +0800
Subject: [PATCH] Adapt nn.Unfold and inner.ExtractImagePatches. Check input
 dims for nn.LSTM.

---
 .../ascend/ascend_backend_optimization.cc     |  2 -
 ...rt_reshape_for_extract_image_patches_op.cc | 65 ------------------
 ...ert_reshape_for_extract_image_patches_op.h | 41 ------------
 .../backend/session/anf_runtime_algorithm.cc  |  4 --
 mindspore/nn/layer/basic.py                   | 28 +++++---
 mindspore/nn/layer/lstm.py                    |  7 ++
 mindspore/ops/_grad/grad_nn_ops.py            | 35 +---------
 mindspore/ops/operations/_inner_ops.py        | 26 ++++----
 mindspore/ops/operations/nn_ops.py            |  2 +-
 9 files changed, 38 insertions(+), 172 deletions(-)
 delete mode 100644 mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_reshape_for_extract_image_patches_op.cc
 delete mode 100644 mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_reshape_for_extract_image_patches_op.h

diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
index baf08f7073..f449f69fe5 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
@@ -72,7 +72,6 @@
 #include "backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h"
 #include "backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h"
 #include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h"
-#include "backend/optimizer/ascend/format_type/insert_reshape_for_extract_image_patches_op.h"
 #include "backend/optimizer/ascend/format_type/convert_cast_format.h"
 #include "backend/optimizer/pass/getitem_tuple.h"
 #include "backend/optimizer/pass/optimize_dependence.h"
@@ -240,7 +239,6 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
   auto optimizer = std::make_shared<GraphOptimizer>();
   auto mixed_precision_pm = std::make_shared<PassManager>("cast_pm");
   mixed_precision_pm->AddPass(std::make_shared());
-  mixed_precision_pm->AddPass(std::make_shared<InsertReshapeForExtractImagePatchesOp>());
   mixed_precision_pm->AddPass(std::make_shared());
   mixed_precision_pm->AddPass(std::make_shared());
   mixed_precision_pm->AddPass(std::make_shared());
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_reshape_for_extract_image_patches_op.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_reshape_for_extract_image_patches_op.cc
deleted file mode 100644
index 2cfb9cfb3e..0000000000
--- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_reshape_for_extract_image_patches_op.cc
+++ /dev/null
@@ -1,65 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "backend/optimizer/ascend/format_type/insert_reshape_for_extract_image_patches_op.h"
-#include 
-#include "backend/optimizer/ascend/ascend_helper.h"
-#include "backend/session/anf_runtime_algorithm.h"
-#include "utils/utils.h"
-#include "base/core_ops.h"
-
-namespace mindspore {
-namespace opt {
-const BaseRef InsertReshapeForExtractImagePatchesOp::DefinePattern() const {
-  VarPtr Xs = std::make_shared<SeqVar>();
-  return VectorRef({prim::kPrimExtractImagePatches, Xs});
-}
-
-const AnfNodePtr InsertReshapeForExtractImagePatchesOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
-                                                                const EquivPtr &equiv) const {
-  MS_EXCEPTION_IF_NULL(func_graph);
-  MS_EXCEPTION_IF_NULL(equiv);
-  auto extract = CheckAnfNodeIfCNodeAndInputSize(node, 2);
-  MS_EXCEPTION_IF_NULL(extract);
-  auto in_node = extract->input(1);
-  MS_EXCEPTION_IF_NULL(in_node);
-  auto extract_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(extract);
-  auto in_node_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(in_node);
-  MS_EXCEPTION_IF_NULL(extract_kernel_build_info);
-  MS_EXCEPTION_IF_NULL(in_node_kernel_build_info);
-  std::vector<AnfNodePtr> reshape_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
-                                            in_node};
-  auto reshape_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
-  reshape_builder->SetInputsFormat({kOpFormat_NC1HWC0});
-  reshape_builder->SetOutputsFormat({kOpFormat_NC1HWC0});
-  reshape_builder->SetInputsDeviceType({AnfAlgo::GetOutputDeviceDataType(in_node, 0)});
-  reshape_builder->SetOutputsDeviceType({AnfAlgo::GetOutputDeviceDataType(in_node, 0)});
-  reshape_builder->SetKernelType(in_node_kernel_build_info->kernel_type());
-  reshape_builder->SetFusionType(in_node_kernel_build_info->fusion_type());
-  reshape_builder->SetProcessor(in_node_kernel_build_info->processor());
-
-  auto reshape = func_graph->NewCNode(reshape_inputs);
-  reshape->set_scope(in_node->scope());
-  auto shape_tmp = AnfAlgo::GetOutputInferShape(in_node, 0);
-  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputDeviceDataType(in_node, 0)},
-                                      {{shape_tmp[0], shape_tmp[2], shape_tmp[3], shape_tmp[1]}}, reshape.get());
-  AnfAlgo::SetSelectKernelBuildInfo(reshape_builder->Build(), reshape.get());
-  AnfAlgo::SetNodeAttr("nop_op", MakeValue(true), reshape);
-  AnfAlgo::SetNodeInput(extract, reshape, 0);
-  return extract;
-}
-}  // namespace opt
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_reshape_for_extract_image_patches_op.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_reshape_for_extract_image_patches_op.h
deleted file mode 100644
index 1344d7f2bc..0000000000
--- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_reshape_for_extract_image_patches_op.h
+++ /dev/null
@@ -1,41 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_RESHAPE_FOR_EXTRACT_IMAGE_PATCHES_OP_H
-#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_RESHAPE_FOR_EXTRACT_IMAGE_PATCHES_OP_H
-#include 
-#include 
-#include 
-#include 
-#include "ir/anf.h"
-#include "backend/optimizer/common/pattern_engine.h"
-#include "backend/optimizer/common/helper.h"
-#include "backend/optimizer/common/optimizer.h"
-
-namespace mindspore {
-namespace opt {
-class InsertReshapeForExtractImagePatchesOp : public PatternProcessPass {
- public:
-  explicit InsertReshapeForExtractImagePatchesOp(bool multigraph = true)
-      : PatternProcessPass("insert_reshape_for_extract_image_patches_op", multigraph) {}
-  ~InsertReshapeForExtractImagePatchesOp() override = default;
-  const BaseRef DefinePattern() const override;
-  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
-};
-}  // namespace opt
-}  // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_RESHAPE_FOR_EXTRACT_IMAGE_PATCHES_OP_H
diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
index 0a79c01f0f..cac9e3d92d 100644
--- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
+++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
@@ -563,10 +563,6 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx) {
   if (trans::IsNeedPadding(format, infer_shape.size())) {
     infer_shape = trans::PaddingShapeTo4d(infer_shape, GetInputReshapeType(node, input_idx));
   }
-  if (node->isa<CNode>() && GetCNodeName(node) == kExtractImagePatchesOpName) {
-    auto shape_tmp = {infer_shape[0], infer_shape[3], infer_shape[1], infer_shape[2]};
-    return trans::TransShapeToDevice(shape_tmp, format);
-  }
   return trans::TransShapeToDevice(infer_shape, format);
 }
diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py
index 1258667709..dda9f2707f 100644
--- a/mindspore/nn/layer/basic.py
+++ b/mindspore/nn/layer/basic.py
@@ -720,19 +720,27 @@ class Unfold(Cell):
 
     def __init__(self, ksizes, strides, rates, padding="valid"):
         super(Unfold, self).__init__()
+
+        def _check_tuple_or_list(arg_name, arg_val, prim_name):
+            Validator.check_value_type(f"{arg_name}s", arg_val, [tuple, list], prim_name)
+            if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1:
+                raise ValueError(f"For '{prim_name}', the format of {arg_name}s should be [1, {arg_name}_row, "
+                                 f"{arg_name}_col, 1], but got {arg_val}.")
+            if not isinstance(arg_val[1], int) or not isinstance(arg_val[2], int) or arg_val[1] < 1 or arg_val[2] < 1:
+                raise ValueError(f"For '{prim_name}', the {arg_name}_row and {arg_name}_col in {arg_name}s should be "
+                                 f"positive integers, but got {arg_name}_row is {arg_val[1]}, {arg_name}_col "
+                                 f"is {arg_val[2]}.")
+
+        _check_tuple_or_list("ksize", ksizes, self.cls_name)
+        _check_tuple_or_list("stride", strides, self.cls_name)
+        _check_tuple_or_list("rate", rates, self.cls_name)
+        ksizes = ksizes[0], ksizes[3], ksizes[1], ksizes[2]
+        strides = strides[0], strides[3], strides[1], strides[2]
+        rates = rates[0], rates[3], rates[1], rates[2]
         self.extract_image_patches = inner.ExtractImagePatches(ksizes, strides, rates, padding)
-        self.transpose = P.Transpose()
-        self.format_NHWC = (0, 2, 3, 1)
-        self.format_NCHW = (0, 3, 1, 2)
-        self.is_ge = context.get_context("enable_ge")
 
     def construct(self, input_x):
-        if self.is_ge:
-            x_transpose = self.transpose(input_x, self.format_NHWC)
-            ret = self.extract_image_patches(x_transpose)
-            result = self.transpose(ret, self.format_NCHW)
-        else:
-            result = self.extract_image_patches(input_x)
+        result = self.extract_image_patches(input_x)
         return result
diff --git a/mindspore/nn/layer/lstm.py b/mindspore/nn/layer/lstm.py
index 9ea1fce82f..d66aa11dea 100755
--- a/mindspore/nn/layer/lstm.py
+++ b/mindspore/nn/layer/lstm.py
@@ -41,6 +41,11 @@ def _create_sequence_length(shape):
 def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
     validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
 
+@constexpr
+def _check_input_3d(input_shape, param_name, func_name):
+    if len(input_shape) != 3:
+        raise ValueError(f"For '{func_name}', the {param_name} should be a 3-D tensor, but got shape {input_shape}.")
+
 class LSTM(Cell):
     r"""
     Stacked LSTM (Long Short-Term Memory) layers.
@@ -237,6 +242,8 @@ class LSTM(Cell):
             x = self.transpose(x, (1, 0, 2))
         h, c = hx
         if self.is_ascend:
+            _check_input_3d(F.shape(h), "h of hx", self.cls_name)
+            _check_input_3d(F.shape(c), "c of hx", self.cls_name)
             _check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name)
             _check_input_dtype(F.dtype(h), "h", [mstype.float32, mstype.float16], self.cls_name)
             _check_input_dtype(F.dtype(c), "c", [mstype.float32, mstype.float16], self.cls_name)
diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py
index 0ed77001a1..ba43e0527d 100755
--- a/mindspore/ops/_grad/grad_nn_ops.py
+++ b/mindspore/ops/_grad/grad_nn_ops.py
@@ -122,7 +122,7 @@ def get_bprop_extract_image_patches(self):
     cast = P.Cast()
     matmul = P.MatMul()
 
-    _, ksizes_row, ksizes_col, _ = self.ksizes
+    _, _, ksizes_row, ksizes_col = self.ksizes
 
     def bprop(x, out, dout):
         x_shape = get_shape(x)
@@ -155,39 +155,6 @@ def get_bprop_extract_image_patches(self):
         dx = transpose(dx, (2, 3, 0, 1))
         return (dx,)
 
-    def bprop_ge(x, out, dout):
-        x_shape = get_shape(x)
-        x_batch, x_row, x_col, x_depth = x_shape
-        x_indices_num = x_row * x_col + 1
-        x_idx = F.tuple_to_array(range(1, x_indices_num))
-        x_idx = reshape(x_idx, (1, x_row, x_col, 1))
-        x_idx_patch = extract_image_patches(x_idx)
-
-        out_shape = get_shape(out)
-        _, out_row, out_col, _ = out_shape
-        out_indices_num = out_row * out_col * ksizes_row * ksizes_col
-        out_idx = F.tuple_to_array(range(out_indices_num))
-        out_idx = reshape(out_idx, (1, out_row, out_col, ksizes_row * ksizes_col))
-
-        idx_tensor = concat((expand_dims(x_idx_patch, -1), expand_dims(out_idx, -1)))
-        idx_tensor = reshape(idx_tensor, (-1, 2))
-        sp_shape = (x_indices_num, out_indices_num)
-        sp_tensor = scatter_nd(idx_tensor, fill(dtype(dout), (out_indices_num,), 1), sp_shape)
-        sp_tensor = slice_op(sp_tensor, (1, 0), (x_indices_num - 1, out_indices_num))
-
-        grad = reshape(dout, (x_batch, out_row, out_col, ksizes_row, ksizes_col, x_depth))
-        grad = transpose(grad, (1, 2, 3, 4, 0, 5))
-        grad = reshape(grad, (-1, x_batch * x_depth))
-
-        jac = matmul(sp_tensor, grad)
-        dx = reshape(jac, (x_row, x_col, x_batch, x_depth))
-        dx = transpose(dx, (2, 0, 1, 3))
-
-        return (dx,)
-
-    if context.get_context("enable_ge"):
-        return bprop_ge
-
     return bprop
diff --git a/mindspore/ops/operations/_inner_ops.py b/mindspore/ops/operations/_inner_ops.py
index e3b39bcc75..f5f2e783b0 100644
--- a/mindspore/ops/operations/_inner_ops.py
+++ b/mindspore/ops/operations/_inner_ops.py
@@ -31,11 +31,11 @@ class ExtractImagePatches(PrimitiveWithInfer):
 
     Args:
         ksizes (Union[tuple[int], list[int]]): The size of sliding window, must be a tuple or a list of integers,
-            and the format is [1, ksize_row, ksize_col, 1].
+            and the format is [1, 1, ksize_row, ksize_col].
         strides (Union[tuple[int], list[int]]): Distance between the centers of the two consecutive patches,
-            must be a tuple or list of int, and the format is [1, stride_row, stride_col, 1].
+            must be a tuple or list of int, and the format is [1, 1, stride_row, stride_col].
         rates (Union[tuple[int], list[int]]): In each extracted patch, the gap between the corresponding dimension
-            pixel positions, must be a tuple or a list of integers, and the format is [1, rate_row, rate_col, 1].
+            pixel positions, must be a tuple or a list of integers, and the format is [1, 1, rate_row, rate_col].
         padding (str): The type of padding algorithm, is a string whose value is "same" or "valid",
             not case sensitive. Default: "valid".
@@ -58,30 +58,28 @@ class ExtractImagePatches(PrimitiveWithInfer):
 
         def _check_tuple_or_list(arg_name, arg_val, prim_name):
-            validator.check_value_type(f"{arg_name}s", ksizes, [tuple, list], self.name)
-            if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1:
-                raise ValueError(f"For \'{prim_name}\' the format of {arg_name}s should be [1, {arg_name}_row, "
-                                 f"{arg_name}_col, 1], but got {arg_val}.")
-            if not isinstance(arg_val[1], int) or not isinstance(arg_val[2], int) or arg_val[1] < 1 or arg_val[2] < 1:
-                raise ValueError(f"For '{prim_name}' the {arg_name}_row and {arg_name}_col in {arg_name}s should be an "
-                                 f"positive integer number, but got {arg_name}_row is {arg_val[1]}, {arg_name}_col "
-                                 f"is {arg_val[2]}")
+            validator.check_value_type(f"{arg_name}s", arg_val, [tuple, list], prim_name)
+            if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[1] != 1:
+                raise ValueError(f"For '{prim_name}', the format of {arg_name}s should be [1, 1, {arg_name}_row, "
+                                 f"{arg_name}_col], but got {arg_val}.")
+            if not isinstance(arg_val[2], int) or not isinstance(arg_val[3], int) or arg_val[2] < 1 or arg_val[3] < 1:
+                raise ValueError(f"For '{prim_name}', the {arg_name}_row and {arg_name}_col in {arg_name}s should be "
+                                 f"positive integers, but got {arg_name}_row is {arg_val[2]}, {arg_name}_col "
+                                 f"is {arg_val[3]}.")
 
         _check_tuple_or_list("ksize", ksizes, self.name)
         _check_tuple_or_list("stride", strides, self.name)
         _check_tuple_or_list("rate", rates, self.name)
         self.padding = validator.check_string(padding.upper(), ['VALID', 'SAME'], 'padding', self.name)
         self.add_prim_attr("padding", self.padding)
-        self.add_prim_attr("io_format", "NHWC")
+        self.add_prim_attr("io_format", "NCHW")
         self.is_ge = context.get_context("enable_ge")
 
     def infer_shape(self, input_x):
         """infer shape"""
         in_batch, in_depth, in_row, in_col = input_x
-        if self.is_ge:
-            in_batch, in_row, in_col, in_depth = input_x
-        _, ksize_row, ksize_col, _ = self.ksizes
-        _, stride_row, stride_col, _ = self.strides
-        _, rate_row, rate_col, _ = self.rates
+        _, _, ksize_row, ksize_col = self.ksizes
+        _, _, stride_row, stride_col = self.strides
+        _, _, rate_row, rate_col = self.rates
         if len(input_x) != 4:
             raise ValueError("The `input_x` should be a 4-D tensor, "
                              f"but got a {len(input_x)}-D tensor whose shape is {input_x}")
@@ -99,8 +97,6 @@ class ExtractImagePatches(PrimitiveWithInfer):
             out_col = (in_col - 1) // stride_col + 1
         out_shape = [out_batch, out_depth, out_row, out_col]
-        if self.is_ge:
-            out_shape = [out_batch, out_row, out_col, out_depth]
         return out_shape
 
     def infer_dtype(self, input_x):
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 95f086599d..90f1824a9c 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -6405,7 +6405,7 @@ class DynamicRNN(PrimitiveWithInfer):
         >>> b = Tensor(np.random.rand(128).astype(np.float16))
         >>> init_h = Tensor(np.random.rand(1, 16, 32).astype(np.float16))
         >>> init_c = Tensor(np.random.rand(1, 16, 32).astype(np.float16))
-        >>> dynamic_rnn = ops.DynamicRNNN()
+        >>> dynamic_rnn = ops.DynamicRNN()
         >>> output = dynamic_rnn(x, w, b, None, init_h, init_c)
         >>> print(output[0].shape)
         (2, 16, 32)
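-- 
Usage sketch (assumes a build with this patch applied; the shapes and values
below are illustrative, not taken from the change itself):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor

    # nn.Unfold keeps its user-facing NHWC-style attribute format
    # [1, row, col, 1]; after this patch it permutes those attributes to the
    # [1, 1, row, col] layout that inner.ExtractImagePatches expects, so
    # callers keep passing NCHW tensors unchanged.
    unfold = nn.Unfold(ksizes=[1, 2, 2, 1], strides=[1, 1, 1, 1], rates=[1, 1, 1, 1])
    x = Tensor(np.ones([1, 1, 4, 4]).astype(np.float16))  # NCHW input
    output = unfold(x)
    # out_depth = C * ksize_row * ksize_col = 4; out_row = out_col = 3
    print(output.shape)  # (1, 4, 3, 3) under the default "valid" padding

On Ascend, nn.LSTM now also rejects non-3-D h/c states up front via the new
_check_input_3d helper instead of failing inside the DynamicRNN kernel.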