check input dims for nn.LSTM (tags/v1.1.0)
| @@ -72,7 +72,6 @@ | |||
| #include "backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h" | |||
| #include "backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h" | |||
| #include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h" | |||
| #include "backend/optimizer/ascend/format_type/insert_reshape_for_extract_image_patches_op.h" | |||
| #include "backend/optimizer/ascend/format_type/convert_cast_format.h" | |||
| #include "backend/optimizer/pass/getitem_tuple.h" | |||
| #include "backend/optimizer/pass/optimize_dependence.h" | |||
| @@ -240,7 +239,6 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap | |||
| auto optimizer = std::make_shared<GraphOptimizer>(); | |||
| auto mixed_precision_pm = std::make_shared<PassManager>("cast_pm"); | |||
| mixed_precision_pm->AddPass(std::make_shared<InsertCast>()); | |||
| mixed_precision_pm->AddPass(std::make_shared<InsertReshapeForExtractImagePatchesOp>()); | |||
| mixed_precision_pm->AddPass(std::make_shared<GetitemTuple>()); | |||
| mixed_precision_pm->AddPass(std::make_shared<CommonSubexpressionElimination>()); | |||
| mixed_precision_pm->AddPass(std::make_shared<EliminateRedundantOp>()); | |||
| @@ -1,65 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/ascend/format_type/insert_reshape_for_extract_image_patches_op.h" | |||
| #include <memory> | |||
| #include "backend/optimizer/ascend/ascend_helper.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "utils/utils.h" | |||
| #include "base/core_ops.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| const BaseRef InsertReshapeForExtractImagePatchesOp::DefinePattern() const { | |||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||
| return VectorRef({prim::kPrimExtractImagePatches, Xs}); | |||
| } | |||
| const AnfNodePtr InsertReshapeForExtractImagePatchesOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &equiv) const { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(equiv); | |||
| auto extract = CheckAnfNodeIfCNodeAndInputSize(node, 2); | |||
| MS_EXCEPTION_IF_NULL(extract); | |||
| auto in_node = extract->input(1); | |||
| MS_EXCEPTION_IF_NULL(in_node); | |||
| auto extract_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(extract); | |||
| auto in_node_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(in_node); | |||
| MS_EXCEPTION_IF_NULL(extract_kernel_build_info); | |||
| MS_EXCEPTION_IF_NULL(in_node_kernel_build_info); | |||
| std::vector<AnfNodePtr> reshape_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())), | |||
| in_node}; | |||
| auto reshape_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | |||
| reshape_builder->SetInputsFormat({kOpFormat_NC1HWC0}); | |||
| reshape_builder->SetOutputsFormat({kOpFormat_NC1HWC0}); | |||
| reshape_builder->SetInputsDeviceType({AnfAlgo::GetOutputDeviceDataType(in_node, 0)}); | |||
| reshape_builder->SetOutputsDeviceType({AnfAlgo::GetOutputDeviceDataType(in_node, 0)}); | |||
| reshape_builder->SetKernelType(in_node_kernel_build_info->kernel_type()); | |||
| reshape_builder->SetFusionType(in_node_kernel_build_info->fusion_type()); | |||
| reshape_builder->SetProcessor(in_node_kernel_build_info->processor()); | |||
| auto reshape = func_graph->NewCNode(reshape_inputs); | |||
| reshape->set_scope(in_node->scope()); | |||
| auto shape_tmp = AnfAlgo::GetOutputInferShape(in_node, 0); | |||
| AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputDeviceDataType(in_node, 0)}, | |||
| {{shape_tmp[0], shape_tmp[2], shape_tmp[3], shape_tmp[1]}}, reshape.get()); | |||
| AnfAlgo::SetSelectKernelBuildInfo(reshape_builder->Build(), reshape.get()); | |||
| AnfAlgo::SetNodeAttr("nop_op", MakeValue(true), reshape); | |||
| AnfAlgo::SetNodeInput(extract, reshape, 0); | |||
| return extract; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -1,41 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_RESHAPE_FOR_EXTRACT_IMAGE_PATCHES_OP_H | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_RESHAPE_FOR_EXTRACT_IMAGE_PATCHES_OP_H | |||
| #include <vector> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <memory> | |||
| #include "ir/anf.h" | |||
| #include "backend/optimizer/common/pattern_engine.h" | |||
| #include "backend/optimizer/common/helper.h" | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class InsertReshapeForExtractImagePatchesOp : public PatternProcessPass { | |||
| public: | |||
| explicit InsertReshapeForExtractImagePatchesOp(bool multigraph = true) | |||
| : PatternProcessPass("insert_reshape_for_extract_image_patches_op", multigraph) {} | |||
| ~InsertReshapeForExtractImagePatchesOp() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_RESHAPE_FOR_EXTRACT_IMAGE_PATCHES_OP_H | |||
| @@ -563,10 +563,6 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &n | |||
| if (trans::IsNeedPadding(format, infer_shape.size())) { | |||
| infer_shape = trans::PaddingShapeTo4d(infer_shape, GetInputReshapeType(node, input_idx)); | |||
| } | |||
| if (node->isa<CNode>() && GetCNodeName(node) == kExtractImagePatchesOpName) { | |||
| auto shape_tmp = {infer_shape[0], infer_shape[3], infer_shape[1], infer_shape[2]}; | |||
| return trans::TransShapeToDevice(shape_tmp, format); | |||
| } | |||
| return trans::TransShapeToDevice(infer_shape, format); | |||
| } | |||
| @@ -720,19 +720,27 @@ class Unfold(Cell): | |||
| def __init__(self, ksizes, strides, rates, padding="valid"): | |||
| super(Unfold, self).__init__() | |||
def _check_tuple_or_list(arg_name, arg_val, prim_name):
    """Validate one of the `ksizes` / `strides` / `rates` arguments.

    Args:
        arg_name (str): Singular argument name ("ksize", "stride" or "rate"), used in messages.
        arg_val (Union[tuple, list]): The value to validate; expected as [1, row, col, 1].
        prim_name (str): Name of the calling cell, used in messages.

    Raises:
        ValueError: If `arg_val` does not have four elements with 1 at both ends,
            or if the row/col entries are not integers >= 1.
    """
    # Check the value actually being validated. The previous code always checked the
    # captured `ksizes`, so a non-sequence `strides` or `rates` skipped the type check.
    Validator.check_value_type(f"{arg_name}s", arg_val, [tuple, list], self.cls_name)
    if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1:
        raise ValueError(f"For \'{prim_name}\' the format of {arg_name}s should be [1, {arg_name}_row, "
                         f"{arg_name}_col, 1], but got {arg_val}.")
    if not isinstance(arg_val[1], int) or not isinstance(arg_val[2], int) or arg_val[1] < 1 or arg_val[2] < 1:
        raise ValueError(f"For '{prim_name}' the {arg_name}_row and {arg_name}_col in {arg_name}s should be an "
                         f"positive integer number, but got {arg_name}_row is {arg_val[1]}, {arg_name}_col "
                         f"is {arg_val[2]}")
| _check_tuple_or_list("ksize", ksizes, self.cls_name) | |||
| _check_tuple_or_list("stride", strides, self.cls_name) | |||
| _check_tuple_or_list("rate", rates, self.cls_name) | |||
| ksizes = ksizes[0], ksizes[3], ksizes[1], ksizes[2] | |||
| strides = strides[0], strides[3], strides[1], strides[2] | |||
| rates = rates[0], rates[3], rates[1], rates[2] | |||
| self.extract_image_patches = inner.ExtractImagePatches(ksizes, strides, rates, padding) | |||
| self.transpose = P.Transpose() | |||
| self.format_NHWC = (0, 2, 3, 1) | |||
| self.format_NCHW = (0, 3, 1, 2) | |||
| self.is_ge = context.get_context("enable_ge") | |||
| def construct(self, input_x): | |||
| if self.is_ge: | |||
| x_transpose = self.transpose(input_x, self.format_NHWC) | |||
| ret = self.extract_image_patches(x_transpose) | |||
| result = self.transpose(ret, self.format_NCHW) | |||
| else: | |||
| result = self.extract_image_patches(input_x) | |||
| result = self.extract_image_patches(input_x) | |||
| return result | |||
| @@ -41,6 +41,11 @@ def _create_sequence_length(shape): | |||
def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
    """Delegate dtype validation of one named input to `validator.check_type_name`.

    Args:
        input_dtype: The dtype to check.
        param_name (str): Name of the parameter the dtype belongs to.
        allow_dtypes (list): The dtypes accepted for this parameter.
        cls_name (str): Name of the calling cell.
    """
    validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
@constexpr
def _check_input_3d(input_shape, param_name, func_name):
    """Ensure `input_shape` describes exactly three dimensions.

    Args:
        input_shape (tuple): Shape to inspect.
        param_name (str): Name of the parameter the shape belongs to, used in the message.
        func_name (str): Name of the caller, used in the message.

    Raises:
        ValueError: If the shape does not have three entries.
    """
    rank = len(input_shape)
    if rank == 3:
        return
    raise ValueError(f"{func_name} {param_name} should be 3d, but got shape {input_shape}")
| class LSTM(Cell): | |||
| r""" | |||
| Stacked LSTM (Long Short-Term Memory) layers. | |||
| @@ -237,6 +242,8 @@ class LSTM(Cell): | |||
| x = self.transpose(x, (1, 0, 2)) | |||
| h, c = hx | |||
| if self.is_ascend: | |||
| _check_input_3d(F.shape(h), "h of hx", self.cls_name) | |||
| _check_input_3d(F.shape(c), "c of hx", self.cls_name) | |||
| _check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name) | |||
| _check_input_dtype(F.dtype(h), "h", [mstype.float32, mstype.float16], self.cls_name) | |||
| _check_input_dtype(F.dtype(c), "c", [mstype.float32, mstype.float16], self.cls_name) | |||
| @@ -122,7 +122,7 @@ def get_bprop_extract_image_patches(self): | |||
| cast = P.Cast() | |||
| matmul = P.MatMul() | |||
| _, ksizes_row, ksizes_col, _ = self.ksizes | |||
| _, _, ksizes_row, ksizes_col = self.ksizes | |||
| def bprop(x, out, dout): | |||
| x_shape = get_shape(x) | |||
| @@ -155,39 +155,6 @@ def get_bprop_extract_image_patches(self): | |||
| dx = transpose(dx, (2, 3, 0, 1)) | |||
| return (dx,) | |||
def bprop_ge(x, out, dout):
    """Backward function that unpacks `x` as (batch, row, col, depth).

    Builds a 0/1 matrix that maps every output patch element back to the input
    position it was taken from, then multiplies it with the rearranged `dout`.

    Returns:
        tuple: A one-element tuple holding the gradient with respect to `x`.
    """
    x_batch, x_row, x_col, x_depth = get_shape(x)

    # Number the input positions starting at 1 and run them through the forward op,
    # so each output element carries the index of the input position it came from.
    # NOTE(review): index 0 presumably marks padded positions; that row is sliced off below.
    in_pos_count = x_row * x_col + 1
    in_pos = reshape(F.tuple_to_array(range(1, in_pos_count)), (1, x_row, x_col, 1))
    in_pos_patches = extract_image_patches(in_pos)

    _, out_row, out_col, _ = get_shape(out)
    window_size = ksizes_row * ksizes_col
    out_pos_count = out_row * out_col * ksizes_row * ksizes_col
    out_pos = reshape(F.tuple_to_array(range(out_pos_count)), (1, out_row, out_col, window_size))

    # Pair every (input position, output position) and scatter ones into a dense matrix.
    pairs = concat((expand_dims(in_pos_patches, -1), expand_dims(out_pos, -1)))
    pairs = reshape(pairs, (-1, 2))
    ones = fill(dtype(dout), (out_pos_count,), 1)
    mapping = scatter_nd(pairs, ones, (in_pos_count, out_pos_count))
    # Drop row 0 of the mapping matrix.
    mapping = slice_op(mapping, (1, 0), (in_pos_count - 1, out_pos_count))

    # Move batch next to depth so a single matmul covers the whole batch.
    dout_2d = reshape(dout, (x_batch, out_row, out_col, ksizes_row, ksizes_col, x_depth))
    dout_2d = transpose(dout_2d, (1, 2, 3, 4, 0, 5))
    dout_2d = reshape(dout_2d, (-1, x_batch * x_depth))

    product = matmul(mapping, dout_2d)
    dx = transpose(reshape(product, (x_row, x_col, x_batch, x_depth)), (2, 0, 1, 3))
    return (dx,)
| if context.get_context("enable_ge"): | |||
| return bprop_ge | |||
| return bprop | |||
| @@ -31,11 +31,11 @@ class ExtractImagePatches(PrimitiveWithInfer): | |||
| Args: | |||
| ksizes (Union[tuple[int], list[int]]): The size of sliding window, must be a tuple or a list of integers, | |||
| and the format is [1, ksize_row, ksize_col, 1]. | |||
| and the format is [1, 1, ksize_row, ksize_col]. | |||
| strides (Union[tuple[int], list[int]]): Distance between the centers of the two consecutive patches, | |||
| must be a tuple or list of int, and the format is [1, stride_row, stride_col, 1]. | |||
| must be a tuple or list of int, and the format is [1, 1, stride_row, stride_col]. | |||
| rates (Union[tuple[int], list[int]]): In each extracted patch, the gap between the corresponding dimension | |||
| pixel positions, must be a tuple or a list of integers, and the format is [1, rate_row, rate_col, 1]. | |||
| pixel positions, must be a tuple or a list of integers, and the format is [1, 1, rate_row, rate_col]. | |||
| padding (str): The type of padding algorithm, is a string whose value is "same" or "valid", | |||
| not case sensitive. Default: "valid". | |||
| @@ -58,30 +58,28 @@ class ExtractImagePatches(PrimitiveWithInfer): | |||
| def _check_tuple_or_list(arg_name, arg_val, prim_name): | |||
| validator.check_value_type(f"{arg_name}s", ksizes, [tuple, list], self.name) | |||
| if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1: | |||
| if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[1] != 1: | |||
| raise ValueError(f"For \'{prim_name}\' the format of {arg_name}s should be [1, {arg_name}_row, " | |||
| f"{arg_name}_col, 1], but got {arg_val}.") | |||
| if not isinstance(arg_val[1], int) or not isinstance(arg_val[2], int) or arg_val[1] < 1 or arg_val[2] < 1: | |||
| if not isinstance(arg_val[2], int) or not isinstance(arg_val[3], int) or arg_val[2] < 1 or arg_val[3] < 1: | |||
| raise ValueError(f"For '{prim_name}' the {arg_name}_row and {arg_name}_col in {arg_name}s should be an " | |||
| f"positive integer number, but got {arg_name}_row is {arg_val[1]}, {arg_name}_col " | |||
| f"is {arg_val[2]}") | |||
| f"positive integer number, but got {arg_name}_row is {arg_val[2]}, {arg_name}_col " | |||
| f"is {arg_val[3]}") | |||
| _check_tuple_or_list("ksize", ksizes, self.name) | |||
| _check_tuple_or_list("stride", strides, self.name) | |||
| _check_tuple_or_list("rate", rates, self.name) | |||
| self.padding = validator.check_string(padding.upper(), ['VALID', 'SAME'], 'padding', self.name) | |||
| self.add_prim_attr("padding", self.padding) | |||
| self.add_prim_attr("io_format", "NHWC") | |||
| self.add_prim_attr("io_format", "NCHW") | |||
| self.is_ge = context.get_context("enable_ge") | |||
| def infer_shape(self, input_x): | |||
| """infer shape""" | |||
| in_batch, in_depth, in_row, in_col = input_x | |||
| if self.is_ge: | |||
| in_batch, in_row, in_col, in_depth = input_x | |||
| _, ksize_row, ksize_col, _ = self.ksizes | |||
| _, stride_row, stride_col, _ = self.strides | |||
| _, rate_row, rate_col, _ = self.rates | |||
| _, _, ksize_row, ksize_col = self.ksizes | |||
| _, _, stride_row, stride_col = self.strides | |||
| _, _, rate_row, rate_col = self.rates | |||
| if len(input_x) != 4: | |||
| raise ValueError("The `input_x` should be a 4-D tensor, " | |||
| f"but got a {len(input_x)}-D tensor whose shape is {input_x}") | |||
| @@ -99,8 +97,6 @@ class ExtractImagePatches(PrimitiveWithInfer): | |||
| out_col = (in_col - 1) // stride_col + 1 | |||
| out_shape = [out_batch, out_depth, out_row, out_col] | |||
| if self.is_ge: | |||
| out_shape = [out_batch, out_row, out_col, out_depth] | |||
| return out_shape | |||
| def infer_dtype(self, input_x): | |||
| @@ -6405,7 +6405,7 @@ class DynamicRNN(PrimitiveWithInfer): | |||
| >>> b = Tensor(np.random.rand(128).astype(np.float16)) | |||
| >>> init_h = Tensor(np.random.rand(1, 16, 32).astype(np.float16)) | |||
| >>> init_c = Tensor(np.random.rand(1, 16, 32).astype(np.float16)) | |||
| >>> dynamic_rnn = ops.DynamicRNNN() | |||
| >>> dynamic_rnn = ops.DynamicRNN() | |||
| >>> output = dynamic_rnn(x, w, b, None, init_h, init_c) | |||
| >>> print(output[0].shape) | |||
| (2, 16, 32) | |||