Merge pull request !2297 from jiangjinsheng/vm_matrixdiagtags/v0.6.0-beta
| @@ -124,7 +124,10 @@ static std::map<string, string> tbe_func_adapter_map = { | |||||
| {"a_cos_grad", "acos_grad"}, | {"a_cos_grad", "acos_grad"}, | ||||
| {"histogram_fixed_width", "histogram_fixed_width_d"}, | {"histogram_fixed_width", "histogram_fixed_width_d"}, | ||||
| {"broadcast_to", "broadcast_to_d"}, | {"broadcast_to", "broadcast_to_d"}, | ||||
| {"inplace_update", "inplace_update_d"}}; | |||||
| {"inplace_update", "inplace_update_d"}, | |||||
| {"matrix_diag", "matrix_diag_d"}, | |||||
| {"matrix_diag_part", "matrix_diag_part_d"}, | |||||
| {"matrix_set_diag", "matrix_set_diag_d"}}; | |||||
| void TbeAdapter::NormalizeFuncName(std::string *func_name) { | void TbeAdapter::NormalizeFuncName(std::string *func_name) { | ||||
| if (func_name == nullptr) { | if (func_name == nullptr) { | ||||
| @@ -31,9 +31,12 @@ from mindspore.ops import _selected_ops | |||||
| from ..cell import Cell | from ..cell import Cell | ||||
| from .activation import get_activation | from .activation import get_activation | ||||
| from ..._checkparam import Validator as validator | from ..._checkparam import Validator as validator | ||||
| from ..._checkparam import Rel | |||||
| __all__ = ['Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', 'Pad', 'Unfold'] | |||||
| __all__ = ['Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', 'Pad', 'Unfold', | |||||
| 'MatrixDiag', 'MatrixDiagPart', 'MatrixSetDiag'] | |||||
| class Dropout(Cell): | class Dropout(Cell): | ||||
| r""" | r""" | ||||
| @@ -527,3 +530,112 @@ class Unfold(Cell): | |||||
| ret = self.extract_image_patches(x_transpose) | ret = self.extract_image_patches(x_transpose) | ||||
| ret_transpose = self.transpose(ret, self.format_NCHW) | ret_transpose = self.transpose(ret, self.format_NCHW) | ||||
| return ret_transpose | return ret_transpose | ||||
| @constexpr | |||||
| def _get_matrix_diag_assist(x_shape, x_dtype): | |||||
| validator.check_integer("x rank", len(x_shape), 1, Rel.GE, "_get_matrix_diag_assist") | |||||
| base_eye = np.eye(x_shape[-1], x_shape[-1]).reshape(-1) | |||||
| assist = np.tile(base_eye, x_shape[:-1]).reshape(x_shape + (x_shape[-1],)) | |||||
| return Tensor(assist, x_dtype) | |||||
| @constexpr | |||||
| def _get_matrix_diag_part_assist(x_shape, x_dtype): | |||||
| validator.check_integer("x rank", len(x_shape), 2, Rel.GE, "_get_matrix_diag_part_assist") | |||||
| base_eye = np.eye(x_shape[-2], x_shape[-1]).reshape(-1) | |||||
| assist = np.tile(base_eye, x_shape[:-2]).reshape(x_shape) | |||||
| return Tensor(assist, x_dtype) | |||||
| class MatrixDiag(Cell): | |||||
| """ | |||||
| Returns a batched diagonal tensor with a given batched diagonal values. | |||||
| Inputs: | |||||
| - **x** (Tensor) - The diagonal values. It can be of the following data types: | |||||
| float32, float16, int32, int8, uint8. | |||||
| Outputs: | |||||
| Tensor, same type as input `x`. The shape should be x.shape + (x.shape[-1], ). | |||||
| Examples: | |||||
| >>> x = Tensor(np.array([1, -1]), mstype.float32) | |||||
| >>> matrix_diag = nn.MatrixDiag() | |||||
| >>> result = matrix_diag(x) | |||||
| [[1. 0.] | |||||
| [0. -1.]] | |||||
| """ | |||||
| def __init__(self): | |||||
| super(MatrixDiag, self).__init__() | |||||
| self.matrix_diag = inner.MatrixDiag() | |||||
| self.dtype = P.DType() | |||||
| def construct(self, input_x): | |||||
| x_shape = F.shape(input_x) | |||||
| x_dtype = self.dtype(input_x) | |||||
| assist = _get_matrix_diag_assist(x_shape, x_dtype) | |||||
| out_matrix_diag = self.matrix_diag(input_x, assist) | |||||
| return out_matrix_diag | |||||
| class MatrixDiagPart(Cell): | |||||
| r""" | |||||
| Returns the batched diagonal part of a batched tensor. | |||||
| Inputs: | |||||
| - **x** (Tensor) - The batched tensor. It can be of the following data types: | |||||
| float32, float16, int32, int8, uint8. | |||||
| Outputs: | |||||
| Tensor, same type as input `x`. The shape should be x.shape[:-2] + [min(x.shape[-2:])]. | |||||
| Examples: | |||||
| >>> x = Tensor([[[-1, 0], [0, 1]], [-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32) | |||||
| >>> matrix_diag_part = nn.MatrixDiagPart() | |||||
| >>> result = matrix_diag_part(x) | |||||
| [[-1., 1.], [-1., 1.], [-1., 1.]] | |||||
| """ | |||||
| def __init__(self): | |||||
| super(MatrixDiagPart, self).__init__() | |||||
| self.matrix_diag_part = inner.MatrixDiagPart() | |||||
| self.dtype = P.DType() | |||||
| def construct(self, input_x): | |||||
| x_shape = F.shape(input_x) | |||||
| x_dtype = self.dtype(input_x) | |||||
| assist = _get_matrix_diag_part_assist(x_shape, x_dtype) | |||||
| out_matrix_diag_part = self.matrix_diag_part(input_x, assist) | |||||
| return out_matrix_diag_part | |||||
| class MatrixSetDiag(Cell): | |||||
| r""" | |||||
| Modify the batched diagonal part of a batched tensor. | |||||
| Inputs: | |||||
| - **x** (Tensor) - The batched tensor. It can be of the following data types: | |||||
| float32, float16, int32, int8, uint8. | |||||
| - **diagonal** (Tensor) - The diagonal values. | |||||
| Outputs: | |||||
| Tensor, same type as input `x`. The shape same as `x`. | |||||
| Examples: | |||||
| >>> x = Tensor([[[-1, 0], [0, 1]], [-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32) | |||||
| >>> diagonal = Tensor([[-1., 2.], [-1., 1.], [-1., 1.]], mindspore.float32) | |||||
| >>> matrix_set_diag = nn.MatrixSetDiag() | |||||
| >>> result = matrix_set_diag(x, diagonal) | |||||
| [[[-1, 0], [0, 2]], [-1, 0], [0, 1]], [[-1, 0], [0, 1]]] | |||||
| """ | |||||
| def __init__(self): | |||||
| super(MatrixSetDiag, self).__init__() | |||||
| self.matrix_set_diag = inner.MatrixSetDiag() | |||||
| self.dtype = P.DType() | |||||
| def construct(self, input_x, diagonal): | |||||
| x_shape = F.shape(input_x) | |||||
| x_dtype = self.dtype(input_x) | |||||
| assist = _get_matrix_diag_part_assist(x_shape, x_dtype) | |||||
| out_matrix_set_diag = self.matrix_set_diag(input_x, diagonal, assist) | |||||
| return out_matrix_set_diag | |||||
| @@ -264,3 +264,6 @@ from .inplace_update import _inplace_update_tbe | |||||
| from .splitv import _split_v_tbe | from .splitv import _split_v_tbe | ||||
| from .in_top_k import _in_top_k_tbe | from .in_top_k import _in_top_k_tbe | ||||
| from .lin_space import _lin_space_tbe | from .lin_space import _lin_space_tbe | ||||
| from .matrix_diag import _matrix_diag_tbe | |||||
| from .matrix_diag_part import _matrix_diag_part_tbe | |||||
| from .matrix_set_diag import _matrix_set_diag_tbe | |||||
| @@ -0,0 +1,45 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """MatrixDiagD op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| matrix_diag_d_op_info = TBERegOp("MatrixDiag") \ | |||||
| .fusion_type("ELEMWISE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("matrix_diag_d.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("matrix_diag_d") \ | |||||
| .partial_flag(True) \ | |||||
| .input(0, "x", False, "required", "all") \ | |||||
| .input(1, "assist", False, "required", "all") \ | |||||
| .output(0, "y", False, "required", "all") \ | |||||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(matrix_diag_d_op_info) | |||||
| def _matrix_diag_tbe(): | |||||
| """MatrixDiagD TBE register""" | |||||
| return | |||||
| @@ -0,0 +1,45 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """MatrixDiagPartD op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| matrix_diag_part_d_op_info = TBERegOp("MatrixDiagPart") \ | |||||
| .fusion_type("ELEMWISE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("matrix_diag_part_d.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("matrix_diag_part_d") \ | |||||
| .partial_flag(True) \ | |||||
| .input(0, "x", False, "required", "all") \ | |||||
| .input(1, "assist", False, "required", "all") \ | |||||
| .output(0, "y", False, "required", "all") \ | |||||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(matrix_diag_part_d_op_info) | |||||
| def _matrix_diag_part_tbe(): | |||||
| """MatrixDiagPartD TBE register""" | |||||
| return | |||||
| @@ -0,0 +1,46 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """MatrixSetDiagD op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| matrix_diag_d_op_info = TBERegOp("MatrixSetDiag") \ | |||||
| .fusion_type("ELEMWISE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("matrix_diag_d.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("matrix_diag_d") \ | |||||
| .partial_flag(True) \ | |||||
| .input(0, "x", False, "required", "all") \ | |||||
| .input(1, "diagonal", False, "required", "all") \ | |||||
| .input(2, "assist", False, "required", "all") \ | |||||
| .output(0, "y", False, "required", "all") \ | |||||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(matrix_diag_d_op_info) | |||||
| def _matrix_set_diag_tbe(): | |||||
| """MatrixSetDiagD TBE register""" | |||||
| return | |||||
| @@ -367,3 +367,144 @@ class LinSpace(PrimitiveWithInfer): | |||||
| args = {"assist": assist, "start": start, "stop": stop} | args = {"assist": assist, "start": start, "stop": stop} | ||||
| validator.check_tensor_type_same(args, (mstype.float32,), self.name) | validator.check_tensor_type_same(args, (mstype.float32,), self.name) | ||||
| return assist | return assist | ||||
| class MatrixDiag(PrimitiveWithInfer): | |||||
| """ | |||||
| Returns a batched diagonal tensor with a given batched diagonal values. | |||||
| Inputs: | |||||
| - **x** (Tensor) - A tensor which to be element-wise multi by `assist`. It can be of the following data types: | |||||
| float32, float16, int32, int8, uint8. | |||||
| - **assist** (Tensor) - A eye tensor of the same type as `x`. It's rank must greater than or equal to 2 and | |||||
| it's last dimension must equal to the second to last dimension. | |||||
| Outputs: | |||||
| Tensor, has the same type and shape as input `assist`. | |||||
| Examples: | |||||
| >>> x = Tensor(np.array([1, -1]), mstype.float32) | |||||
| >>> assist = Tensor(np.arange(-12, 0).reshape(3, 2, 2), mindspore.float32) | |||||
| >>> matrix_diag = P.MatrixDiag() | |||||
| >>> result = matrix_diag(x, assist) | |||||
| [[[-12. 11.] | |||||
| [-10. 9.]] | |||||
| [[ -8. 7.] | |||||
| [ -6. 5.]] | |||||
| [[ -4. 3.] | |||||
| [ -2. 1.]]] | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self): | |||||
| """init MatrixDiag""" | |||||
| def infer_dtype(self, x_dtype, assist_dtype): | |||||
| valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8] | |||||
| args = {"x": x_dtype, "assist": assist_dtype} | |||||
| validator.check_tensor_type_same(args, valid_type, self.name) | |||||
| return x_dtype | |||||
| def infer_shape(self, x_shape, assist_shape): | |||||
| validator.check_integer("assist rank", len(assist_shape), 2, Rel.GE, self.name) | |||||
| validator.check('rank of x', len(x_shape)+1, | |||||
| 'rank of assist', len(assist_shape), Rel.LE, self.name) | |||||
| validator.check('assist\'s penultimate dimension', assist_shape[-2], 'assist\'s last dimension', | |||||
| assist_shape[-1], Rel.EQ, self.name) | |||||
| r_end_dim = -len(x_shape) | |||||
| r_idx = -1 | |||||
| while r_idx >= r_end_dim: | |||||
| if x_shape[r_idx] != 1: | |||||
| validator.check("reverse x dim %d" % r_idx, x_shape[r_idx], "reverse assist dim %d" % | |||||
| assist_shape[r_idx-1], assist_shape[r_idx-1], Rel.EQ, self.name) | |||||
| r_idx = r_idx - 1 | |||||
| return assist_shape | |||||
| class MatrixDiagPart(PrimitiveWithInfer): | |||||
| r""" | |||||
| Returns the batched diagonal part of a batched tensor. | |||||
| Inputs: | |||||
| - **x** (Tensor) - The batched tensor. It can be of the following data types: | |||||
| float32, float16, int32, int8, uint8. | |||||
| - **assist** (Tensor) - A eye tensor of the same type as `x`. With shape same as `x`. | |||||
| Outputs: | |||||
| Tensor, data type same as input `x`. The shape should be x.shape[:-2] + [min(x.shape[-2:])]. | |||||
| Examples: | |||||
| >>> x = Tensor([[[-1, 0], [0, 1]], [-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32) | |||||
| >>> assist = Tensor(np.arange(-12, 0).reshape(3, 2, 2), mindspore.float32) | |||||
| >>> matrix_diag_part = P.MatrixDiagPart() | |||||
| >>> result = matrix_diag_part(x, assist) | |||||
| [[12., -9.], [8., -5.], [4., -1.]] | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self): | |||||
| """init MatrixDiagPart""" | |||||
| def infer_dtype(self, x_dtype, assist_dtype): | |||||
| valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8] | |||||
| args = {"x": x_dtype, "assist": assist_dtype} | |||||
| validator.check_tensor_type_same(args, valid_type, self.name) | |||||
| return x_dtype | |||||
| def infer_shape(self, x_shape, assist_shape): | |||||
| validator.check_integer("x rank", len(x_shape), 2, Rel.GE, self.name) | |||||
| validator.check("x shape", x_shape, "assist shape", assist_shape, Rel.EQ, self.name) | |||||
| if assist_shape[-2] < assist_shape[-1]: | |||||
| out_shape = assist_shape[:-1] | |||||
| else: | |||||
| out_shape = assist_shape[:-2] + assist_shape[-1:] | |||||
| return out_shape | |||||
| class MatrixSetDiag(PrimitiveWithInfer): | |||||
| r""" | |||||
| Modify the batched diagonal part of a batched tensor. | |||||
| Inputs: | |||||
| - **x** (Tensor) - The batched tensor. It can be of the following data types: | |||||
| float32, float16, int32, int8, uint8. | |||||
| - **assist** (Tensor) - A eye tensor of the same type as `x`. With shape same as `x`. | |||||
| - **diagonal** (Tensor) - The diagonal values. | |||||
| Outputs: | |||||
| Tensor, data type same as input `x`. The shape same as `x`. | |||||
| Examples: | |||||
| >>> x = Tensor([[[-1, 0], [0, 1]], [-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32) | |||||
| >>> diagonal = Tensor([[-1., 2.], [-1., 1.], [-1., 1.]], mindspore.float32) | |||||
| >>> matrix_set_diag = P.MatrixSetDiag() | |||||
| >>> result = matrix_set_diag(x, diagonal) | |||||
| [[[-1, 0], [0, 2]], [-1, 0], [0, 1]], [[-1, 0], [0, 1]]] | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self): | |||||
| """init MatrixSetDiag""" | |||||
| def infer_dtype(self, x_dtype, diagonal_dtype, assist_dtype): | |||||
| valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8] | |||||
| args = {"x": x_dtype, "diagonal": diagonal_dtype, "assist": assist_dtype} | |||||
| validator.check_tensor_type_same(args, valid_type, self.name) | |||||
| return x_dtype | |||||
| def infer_shape(self, x_shape, diagonal_shape, assist_shape): | |||||
| validator.check_integer("x rank", len(x_shape), 2, Rel.GE, self.name) | |||||
| validator.check("x shape", x_shape, "assist shape", assist_shape, Rel.EQ, self.name) | |||||
| if x_shape[-2] < x_shape[-1]: | |||||
| validator.check("x shape excluding the last dimension", x_shape[:-1], "diagnoal shape", | |||||
| diagonal_shape, Rel.EQ, self.name) | |||||
| else: | |||||
| validator.check("x shape excluding the second to last dimension", x_shape[:-2]+x_shape[-1:], | |||||
| "diagonal shape", diagonal_shape, Rel.EQ, self.name) | |||||
| return assist_shape | |||||
| @@ -370,6 +370,7 @@ def test_conv2d_same_primitive(): | |||||
| super(Conv2DSameNet, self).__init__() | super(Conv2DSameNet, self).__init__() | ||||
| self.conv1 = nn.Conv2d(16, 64, (1, 41), (1, 4), "same", 0, 1, has_bias=True) | self.conv1 = nn.Conv2d(16, 64, (1, 41), (1, 4), "same", 0, 1, has_bias=True) | ||||
| self.conv2 = nn.Conv2d(16, 64, (1, 41), (1, 4), "same", 0, 1, has_bias=True) | self.conv2 = nn.Conv2d(16, 64, (1, 41), (1, 4), "same", 0, 1, has_bias=True) | ||||
| def construct(self, x, y): | def construct(self, x, y): | ||||
| r1 = self.conv1(x) | r1 = self.conv1(x) | ||||
| r2 = self.conv2(y) | r2 = self.conv2(y) | ||||
| @@ -576,6 +577,22 @@ test_cases = [ | |||||
| Tensor(np.ones([1, 3, 4, 4], np.float32)), | Tensor(np.ones([1, 3, 4, 4], np.float32)), | ||||
| Tensor(np.ones(3, np.float32))], | Tensor(np.ones(3, np.float32))], | ||||
| }), | }), | ||||
| ('MatrixDiag', { | |||||
| 'block': nn.MatrixDiag(), | |||||
| 'desc_inputs': [Tensor(np.array([1, 2, 3]).astype(np.float32))], | |||||
| 'skip': ['backward'] | |||||
| }), | |||||
| ('MatrixDiagPart', { | |||||
| 'block': nn.MatrixDiagPart(), | |||||
| 'desc_inputs': [Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32))], | |||||
| 'skip': ['backward'] | |||||
| }), | |||||
| ('MatrixSetDiag', { | |||||
| 'block': nn.MatrixSetDiag(), | |||||
| 'desc_inputs': [Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32)), | |||||
| Tensor(np.array([1, 2]).astype(np.float32))], | |||||
| 'skip': ['backward'] | |||||
| }), | |||||
| ] | ] | ||||
| test_cases_for_verify_exception = [ | test_cases_for_verify_exception = [ | ||||
| @@ -1612,6 +1612,25 @@ test_case_array_ops = [ | |||||
| Tensor(5, mstype.int32)], | Tensor(5, mstype.int32)], | ||||
| 'skip': ['backward'], | 'skip': ['backward'], | ||||
| }), | }), | ||||
| ('MatrixDiag', { | |||||
| 'block': inner.MatrixDiag(), | |||||
| 'desc_inputs': [Tensor(np.array([1, -1]), mstype.float32), | |||||
| Tensor(np.arange(-12, 0).reshape(3, 2, 2), mstype.float32)], | |||||
| 'skip': ['backward'], | |||||
| }), | |||||
| ('MatrixDiagPart', { | |||||
| 'block': inner.MatrixDiagPart(), | |||||
| 'desc_inputs': [Tensor(np.arange(12).reshape(3, 2, 2), mstype.float32), | |||||
| Tensor(np.arange(-12, 0).reshape(3, 2, 2), mstype.float32)], | |||||
| 'skip': ['backward'], | |||||
| }), | |||||
| ('MatrixSetDiag', { | |||||
| 'block': inner.MatrixSetDiag(), | |||||
| 'desc_inputs': [Tensor(np.arange(12).reshape(3, 2, 2), mstype.float32), | |||||
| Tensor(np.arange(6).reshape(3, 2), mstype.float32), | |||||
| Tensor(np.arange(-12, 0).reshape(3, 2, 2), mstype.float32)], | |||||
| 'skip': ['backward'], | |||||
| }), | |||||
| ] | ] | ||||
| test_case_other_ops = [ | test_case_other_ops = [ | ||||