check input dims for nn.LSTM (tags/v1.1.0)
@@ -72,7 +72,6 @@
 #include "backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h"
 #include "backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h"
 #include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h"
-#include "backend/optimizer/ascend/format_type/insert_reshape_for_extract_image_patches_op.h"
 #include "backend/optimizer/ascend/format_type/convert_cast_format.h"
 #include "backend/optimizer/pass/getitem_tuple.h"
 #include "backend/optimizer/pass/optimize_dependence.h"
@@ -240,7 +239,6 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_graph
   auto optimizer = std::make_shared<GraphOptimizer>();
   auto mixed_precision_pm = std::make_shared<PassManager>("cast_pm");
   mixed_precision_pm->AddPass(std::make_shared<InsertCast>());
-  mixed_precision_pm->AddPass(std::make_shared<InsertReshapeForExtractImagePatchesOp>());
   mixed_precision_pm->AddPass(std::make_shared<GetitemTuple>());
   mixed_precision_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
   mixed_precision_pm->AddPass(std::make_shared<EliminateRedundantOp>());
@@ -1,65 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "backend/optimizer/ascend/format_type/insert_reshape_for_extract_image_patches_op.h"
-#include <memory>
-#include "backend/optimizer/ascend/ascend_helper.h"
-#include "backend/session/anf_runtime_algorithm.h"
-#include "utils/utils.h"
-#include "base/core_ops.h"
-
-namespace mindspore {
-namespace opt {
-const BaseRef InsertReshapeForExtractImagePatchesOp::DefinePattern() const {
-  VarPtr Xs = std::make_shared<SeqVar>();
-  return VectorRef({prim::kPrimExtractImagePatches, Xs});
-}
-
-const AnfNodePtr InsertReshapeForExtractImagePatchesOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
-                                                                const EquivPtr &equiv) const {
-  MS_EXCEPTION_IF_NULL(func_graph);
-  MS_EXCEPTION_IF_NULL(equiv);
-  auto extract = CheckAnfNodeIfCNodeAndInputSize(node, 2);
-  MS_EXCEPTION_IF_NULL(extract);
-  auto in_node = extract->input(1);
-  MS_EXCEPTION_IF_NULL(in_node);
-  auto extract_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(extract);
-  auto in_node_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(in_node);
-  MS_EXCEPTION_IF_NULL(extract_kernel_build_info);
-  MS_EXCEPTION_IF_NULL(in_node_kernel_build_info);
-  std::vector<AnfNodePtr> reshape_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
-                                            in_node};
-  auto reshape_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
-  reshape_builder->SetInputsFormat({kOpFormat_NC1HWC0});
-  reshape_builder->SetOutputsFormat({kOpFormat_NC1HWC0});
-  reshape_builder->SetInputsDeviceType({AnfAlgo::GetOutputDeviceDataType(in_node, 0)});
-  reshape_builder->SetOutputsDeviceType({AnfAlgo::GetOutputDeviceDataType(in_node, 0)});
-  reshape_builder->SetKernelType(in_node_kernel_build_info->kernel_type());
-  reshape_builder->SetFusionType(in_node_kernel_build_info->fusion_type());
-  reshape_builder->SetProcessor(in_node_kernel_build_info->processor());
-  auto reshape = func_graph->NewCNode(reshape_inputs);
-  reshape->set_scope(in_node->scope());
-  auto shape_tmp = AnfAlgo::GetOutputInferShape(in_node, 0);
-  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputDeviceDataType(in_node, 0)},
-                                      {{shape_tmp[0], shape_tmp[2], shape_tmp[3], shape_tmp[1]}}, reshape.get());
-  AnfAlgo::SetSelectKernelBuildInfo(reshape_builder->Build(), reshape.get());
-  AnfAlgo::SetNodeAttr("nop_op", MakeValue(true), reshape);
-  AnfAlgo::SetNodeInput(extract, reshape, 0);
-  return extract;
-}
-}  // namespace opt
-}  // namespace mindspore
@@ -1,41 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_RESHAPE_FOR_EXTRACT_IMAGE_PATCHES_OP_H
-#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_RESHAPE_FOR_EXTRACT_IMAGE_PATCHES_OP_H
-#include <vector>
-#include <string>
-#include <utility>
-#include <memory>
-#include "ir/anf.h"
-#include "backend/optimizer/common/pattern_engine.h"
-#include "backend/optimizer/common/helper.h"
-#include "backend/optimizer/common/optimizer.h"
-
-namespace mindspore {
-namespace opt {
-class InsertReshapeForExtractImagePatchesOp : public PatternProcessPass {
- public:
-  explicit InsertReshapeForExtractImagePatchesOp(bool multigraph = true)
-      : PatternProcessPass("insert_reshape_for_extract_image_patches_op", multigraph) {}
-  ~InsertReshapeForExtractImagePatchesOp() override = default;
-  const BaseRef DefinePattern() const override;
-  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
-};
-}  // namespace opt
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_RESHAPE_FOR_EXTRACT_IMAGE_PATCHES_OP_H
@@ -563,10 +563,6 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &node
   if (trans::IsNeedPadding(format, infer_shape.size())) {
     infer_shape = trans::PaddingShapeTo4d(infer_shape, GetInputReshapeType(node, input_idx));
   }
-  if (node->isa<CNode>() && GetCNodeName(node) == kExtractImagePatchesOpName) {
-    auto shape_tmp = {infer_shape[0], infer_shape[3], infer_shape[1], infer_shape[2]};
-    return trans::TransShapeToDevice(shape_tmp, format);
-  }
   return trans::TransShapeToDevice(infer_shape, format);
 }
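The two removed snippets hard-coded a matching pair of axis shuffles: the deleted pass rebuilt the Reshape's inferred shape as {s[0], s[2], s[3], s[1]} (NCHW to NHWC), while the branch above applied {s[0], s[3], s[1], s[2]} (NHWC back to NCHW). A quick sketch confirming the two index permutations are inverses; the dimension values are illustrative only:

```python
nchw = (8, 16, 32, 32)                       # (N, C, H, W)
nhwc = (nchw[0], nchw[2], nchw[3], nchw[1])  # indices {0, 2, 3, 1}: NCHW -> NHWC
back = (nhwc[0], nhwc[3], nhwc[1], nhwc[2])  # indices {0, 3, 1, 2}: NHWC -> NCHW
assert back == nchw                          # the shuffles undo each other
```

With the primitive's io_format switched to NCHW below, neither shuffle is needed anymore.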
@@ -720,19 +720,27 @@ class Unfold(Cell):
     def __init__(self, ksizes, strides, rates, padding="valid"):
         super(Unfold, self).__init__()
+
+        def _check_tuple_or_list(arg_name, arg_val, prim_name):
+            Validator.check_value_type(f"{arg_name}s", arg_val, [tuple, list], self.cls_name)
+            if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1:
+                raise ValueError(f"For '{prim_name}' the format of {arg_name}s should be [1, {arg_name}_row, "
+                                 f"{arg_name}_col, 1], but got {arg_val}.")
+            if not isinstance(arg_val[1], int) or not isinstance(arg_val[2], int) or arg_val[1] < 1 or arg_val[2] < 1:
+                raise ValueError(f"For '{prim_name}' the {arg_name}_row and {arg_name}_col in {arg_name}s should be a "
+                                 f"positive integer number, but got {arg_name}_row is {arg_val[1]}, {arg_name}_col "
+                                 f"is {arg_val[2]}")
+
+        _check_tuple_or_list("ksize", ksizes, self.cls_name)
+        _check_tuple_or_list("stride", strides, self.cls_name)
+        _check_tuple_or_list("rate", rates, self.cls_name)
+        ksizes = ksizes[0], ksizes[3], ksizes[1], ksizes[2]
+        strides = strides[0], strides[3], strides[1], strides[2]
+        rates = rates[0], rates[3], rates[1], rates[2]
         self.extract_image_patches = inner.ExtractImagePatches(ksizes, strides, rates, padding)
-        self.transpose = P.Transpose()
-        self.format_NHWC = (0, 2, 3, 1)
-        self.format_NCHW = (0, 3, 1, 2)
-        self.is_ge = context.get_context("enable_ge")

     def construct(self, input_x):
-        if self.is_ge:
-            x_transpose = self.transpose(input_x, self.format_NHWC)
-            ret = self.extract_image_patches(x_transpose)
-            result = self.transpose(ret, self.format_NCHW)
-        else:
-            result = self.extract_image_patches(input_x)
+        result = self.extract_image_patches(input_x)
         return result
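For reference, a small usage sketch of the resulting user-facing API: ksizes/strides/rates keep the documented [1, row, col, 1] layout and are reordered to [1, 1, row, col] internally before reaching the primitive. The concrete sizes are illustrative, and the default "valid" padding is assumed:

```python
import numpy as np
import mindspore as ms
import mindspore.nn as nn

net = nn.Unfold(ksizes=[1, 2, 2, 1], strides=[1, 1, 1, 1], rates=[1, 1, 1, 1])
x = ms.Tensor(np.ones((1, 1, 3, 3)), ms.float16)  # NCHW input
y = net(x)
print(y.shape)  # (1, 4, 2, 2): depth becomes C * ksize_row * ksize_col, 'valid' shrinks H and W
```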
@@ -41,6 +41,11 @@ def _create_sequence_length(shape):
 def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
     validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)

+
+@constexpr
+def _check_input_3d(input_shape, param_name, func_name):
+    if len(input_shape) != 3:
+        raise ValueError(f"{func_name} {param_name} should be 3d, but got shape {input_shape}")

 class LSTM(Cell):
     r"""
     Stacked LSTM (Long Short-Term Memory) layers.
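Because of @constexpr the check runs at graph-compile time, so a malformed hidden state fails before any Ascend kernel is launched. A plain-Python mirror of the helper and the message it produces (the decorator only moves the call to compile time, the logic is identical):

```python
def check_input_3d(input_shape, param_name, func_name):
    if len(input_shape) != 3:
        raise ValueError(f"{func_name} {param_name} should be 3d, but got shape {input_shape}")

check_input_3d((1, 3, 16), "h of hx", "LSTM")  # 3-D: passes silently
check_input_3d((3, 16), "h of hx", "LSTM")
# ValueError: LSTM h of hx should be 3d, but got shape (3, 16)
```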
@@ -237,6 +242,8 @@ class LSTM(Cell):
         x = self.transpose(x, (1, 0, 2))
         h, c = hx
         if self.is_ascend:
+            _check_input_3d(F.shape(h), "h of hx", self.cls_name)
+            _check_input_3d(F.shape(c), "c of hx", self.cls_name)
             _check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name)
             _check_input_dtype(F.dtype(h), "h", [mstype.float32, mstype.float16], self.cls_name)
             _check_input_dtype(F.dtype(c), "c", [mstype.float32, mstype.float16], self.cls_name)
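A sketch of the shapes the new checks accept. Per the nn.LSTM docstring, h and c are laid out as (num_directions * num_layers, batch_size, hidden_size); the concrete sizes here are illustrative:

```python
import numpy as np
import mindspore as ms
import mindspore.nn as nn

net = nn.LSTM(input_size=10, hidden_size=16, num_layers=1, batch_first=True)
x = ms.Tensor(np.ones((3, 5, 10)).astype(np.float32))    # (batch, seq_len, input_size)
h0 = ms.Tensor(np.zeros((1, 3, 16)).astype(np.float32))  # 3-D: passes _check_input_3d
c0 = ms.Tensor(np.zeros((1, 3, 16)).astype(np.float32))
output, (hn, cn) = net(x, (h0, c0))

# A 2-D hidden state, e.g. np.zeros((3, 16)), now fails fast on Ascend with
# "LSTM h of hx should be 3d, ..." instead of a late kernel error.
```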
@@ -122,7 +122,7 @@ def get_bprop_extract_image_patches(self):
     cast = P.Cast()
     matmul = P.MatMul()
-    _, ksizes_row, ksizes_col, _ = self.ksizes
+    _, _, ksizes_row, ksizes_col = self.ksizes

     def bprop(x, out, dout):
         x_shape = get_shape(x)
@@ -155,39 +155,6 @@
         dx = transpose(dx, (2, 3, 0, 1))
         return (dx,)

-    def bprop_ge(x, out, dout):
-        x_shape = get_shape(x)
-        x_batch, x_row, x_col, x_depth = x_shape
-        x_indices_num = x_row * x_col + 1
-        x_idx = F.tuple_to_array(range(1, x_indices_num))
-        x_idx = reshape(x_idx, (1, x_row, x_col, 1))
-        x_idx_patch = extract_image_patches(x_idx)
-        out_shape = get_shape(out)
-        _, out_row, out_col, _ = out_shape
-        out_indices_num = out_row * out_col * ksizes_row * ksizes_col
-        out_idx = F.tuple_to_array(range(out_indices_num))
-        out_idx = reshape(out_idx, (1, out_row, out_col, ksizes_row * ksizes_col))
-        idx_tensor = concat((expand_dims(x_idx_patch, -1), expand_dims(out_idx, -1)))
-        idx_tensor = reshape(idx_tensor, (-1, 2))
-        sp_shape = (x_indices_num, out_indices_num)
-        sp_tensor = scatter_nd(idx_tensor, fill(dtype(dout), (out_indices_num,), 1), sp_shape)
-        sp_tensor = slice_op(sp_tensor, (1, 0), (x_indices_num - 1, out_indices_num))
-        grad = reshape(dout, (x_batch, out_row, out_col, ksizes_row, ksizes_col, x_depth))
-        grad = transpose(grad, (1, 2, 3, 4, 0, 5))
-        grad = reshape(grad, (-1, x_batch * x_depth))
-        jac = matmul(sp_tensor, grad)
-        dx = reshape(jac, (x_row, x_col, x_batch, x_depth))
-        dx = transpose(dx, (2, 0, 1, 3))
-        return (dx,)
-
-    if context.get_context("enable_ge"):
-        return bprop_ge
-
     return bprop
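The removed bprop_ge built a 0/1 indicator matrix (sp_tensor) mapping input pixels to patch slots and recovered dx with a matmul; the surviving bprop computes the same quantity. Conceptually, the gradient of ExtractImagePatches is a scatter-add of each patch-slot gradient back onto the pixel it was read from. A toy NumPy sketch of that idea for stride 1, rate 1, "valid" padding; the (C, k_row, k_col) ordering of the patch axis is an assumption for illustration, not the kernel's actual layout:

```python
import numpy as np

def unfold_backward(dout, x_shape, k_row, k_col):
    """Scatter-add patch gradients back onto input pixels (stride 1, rate 1, 'valid')."""
    n, c, h, w = x_shape
    out_h, out_w = h - k_row + 1, w - k_col + 1
    # dout: (N, C * k_row * k_col, out_h, out_w) -> split the patch axis.
    dout = dout.reshape(n, c, k_row, k_col, out_h, out_w)
    dx = np.zeros(x_shape, dtype=dout.dtype)
    for i in range(out_h):
        for j in range(out_w):
            # Patch (i, j) read x[:, :, i:i+k_row, j:j+k_col]; route its gradient back.
            dx[:, :, i:i + k_row, j:j + k_col] += dout[:, :, :, :, i, j]
    return dx

dx = unfold_backward(np.ones((1, 4, 2, 2)), (1, 1, 3, 3), 2, 2)
print(dx[0, 0])  # corner pixels belong to 1 patch, edges to 2, the center to 4
```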
@@ -31,11 +31,11 @@ class ExtractImagePatches(PrimitiveWithInfer):
     Args:
         ksizes (Union[tuple[int], list[int]]): The size of sliding window, must be a tuple or a list of integers,
-            and the format is [1, ksize_row, ksize_col, 1].
+            and the format is [1, 1, ksize_row, ksize_col].
         strides (Union[tuple[int], list[int]]): Distance between the centers of the two consecutive patches,
-            must be a tuple or list of int, and the format is [1, stride_row, stride_col, 1].
+            must be a tuple or list of int, and the format is [1, 1, stride_row, stride_col].
         rates (Union[tuple[int], list[int]]): In each extracted patch, the gap between the corresponding dimension
-            pixel positions, must be a tuple or a list of integers, and the format is [1, rate_row, rate_col, 1].
+            pixel positions, must be a tuple or a list of integers, and the format is [1, 1, rate_row, rate_col].
         padding (str): The type of padding algorithm, is a string whose value is "same" or "valid",
             not case sensitive. Default: "valid".
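Construction of the inner primitive now uses the NCHW-style layout throughout. A sketch, assuming the usual internal alias for the inner-ops module (the import path is an assumption for this version):

```python
import numpy as np
import mindspore as ms
from mindspore.ops.operations import _inner_ops as inner  # import path is an assumption

# All three attributes now use [1, 1, row, col]; [1, row, col, 1] raises ValueError.
op = inner.ExtractImagePatches(ksizes=[1, 1, 2, 2], strides=[1, 1, 1, 1],
                               rates=[1, 1, 1, 1], padding="valid")
x = ms.Tensor(np.ones((1, 1, 3, 3)), ms.float16)
print(op(x).shape)  # (1, 4, 2, 2)
```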
@@ -58,30 +58,28 @@ class ExtractImagePatches(PrimitiveWithInfer):
         def _check_tuple_or_list(arg_name, arg_val, prim_name):
             validator.check_value_type(f"{arg_name}s", ksizes, [tuple, list], self.name)
-            if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1:
-                raise ValueError(f"For \'{prim_name}\' the format of {arg_name}s should be [1, {arg_name}_row, "
-                                 f"{arg_name}_col, 1], but got {arg_val}.")
-            if not isinstance(arg_val[1], int) or not isinstance(arg_val[2], int) or arg_val[1] < 1 or arg_val[2] < 1:
-                raise ValueError(f"For '{prim_name}' the {arg_name}_row and {arg_name}_col in {arg_name}s should be an "
-                                 f"positive integer number, but got {arg_name}_row is {arg_val[1]}, {arg_name}_col "
-                                 f"is {arg_val[2]}")
+            if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[1] != 1:
+                raise ValueError(f"For '{prim_name}' the format of {arg_name}s should be [1, 1, {arg_name}_row, "
+                                 f"{arg_name}_col], but got {arg_val}.")
+            if not isinstance(arg_val[2], int) or not isinstance(arg_val[3], int) or arg_val[2] < 1 or arg_val[3] < 1:
+                raise ValueError(f"For '{prim_name}' the {arg_name}_row and {arg_name}_col in {arg_name}s should be a "
+                                 f"positive integer number, but got {arg_name}_row is {arg_val[2]}, {arg_name}_col "
+                                 f"is {arg_val[3]}")

         _check_tuple_or_list("ksize", ksizes, self.name)
         _check_tuple_or_list("stride", strides, self.name)
         _check_tuple_or_list("rate", rates, self.name)
         self.padding = validator.check_string(padding.upper(), ['VALID', 'SAME'], 'padding', self.name)
         self.add_prim_attr("padding", self.padding)
-        self.add_prim_attr("io_format", "NHWC")
+        self.add_prim_attr("io_format", "NCHW")
         self.is_ge = context.get_context("enable_ge")

     def infer_shape(self, input_x):
         """infer shape"""
         in_batch, in_depth, in_row, in_col = input_x
-        if self.is_ge:
-            in_batch, in_row, in_col, in_depth = input_x
-        _, ksize_row, ksize_col, _ = self.ksizes
-        _, stride_row, stride_col, _ = self.strides
-        _, rate_row, rate_col, _ = self.rates
+        _, _, ksize_row, ksize_col = self.ksizes
+        _, _, stride_row, stride_col = self.strides
+        _, _, rate_row, rate_col = self.rates
         if len(input_x) != 4:
             raise ValueError("The `input_x` should be a 4-D tensor, "
                              f"but got a {len(input_x)}-D tensor whose shape is {input_x}")
@@ -99,8 +97,6 @@ class ExtractImagePatches(PrimitiveWithInfer):
             out_col = (in_col - 1) // stride_col + 1
         out_shape = [out_batch, out_depth, out_row, out_col]
-        if self.is_ge:
-            out_shape = [out_batch, out_row, out_col, out_depth]
         return out_shape

     def infer_dtype(self, input_x):
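Worked shape arithmetic for the "same" branch shown above; the "valid" branch (not visible in this hunk) subtracts the effective patch extent first, and out_depth follows the ExtractImagePatches contract of one output channel per patch element. Numbers are illustrative:

```python
# input NCHW = (1, 3, 10, 10), ksizes = [1, 1, 3, 3], strides = [1, 1, 2, 2], padding = "same"
in_batch, in_depth, in_row, in_col = 1, 3, 10, 10
ksize_row = ksize_col = 3
stride_row = stride_col = 2

out_batch = in_batch
out_depth = in_depth * ksize_row * ksize_col  # every patch pixel becomes a channel
out_row = (in_row - 1) // stride_row + 1      # 'same': ceil(in / stride)
out_col = (in_col - 1) // stride_col + 1

assert [out_batch, out_depth, out_row, out_col] == [1, 27, 5, 5]
```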
@@ -6405,7 +6405,7 @@ class DynamicRNN(PrimitiveWithInfer):
         >>> b = Tensor(np.random.rand(128).astype(np.float16))
         >>> init_h = Tensor(np.random.rand(1, 16, 32).astype(np.float16))
         >>> init_c = Tensor(np.random.rand(1, 16, 32).astype(np.float16))
-        >>> dynamic_rnn = ops.DynamicRNNN()
+        >>> dynamic_rnn = ops.DynamicRNN()
         >>> output = dynamic_rnn(x, w, b, None, init_h, init_c)
         >>> print(output[0].shape)
         (2, 16, 32)