Merge pull request !2297 from jiangjinsheng/vm_matrixdiag (tag: v0.6.0-beta)
@@ -124,7 +124,10 @@ static std::map<string, string> tbe_func_adapter_map = {
  {"a_cos_grad", "acos_grad"},
  {"histogram_fixed_width", "histogram_fixed_width_d"},
  {"broadcast_to", "broadcast_to_d"},
  {"inplace_update", "inplace_update_d"}};
  {"inplace_update", "inplace_update_d"},
  {"matrix_diag", "matrix_diag_d"},
  {"matrix_diag_part", "matrix_diag_part_d"},
  {"matrix_set_diag", "matrix_set_diag_d"}};

void TbeAdapter::NormalizeFuncName(std::string *func_name) {
  if (func_name == nullptr) {
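The new entries route the snake_case op names to their "_d" TBE kernel variants during kernel lookup. As a rough Python sketch of that lookup step only (NormalizeFuncName also performs CamelCase-to-snake_case conversion, which is not shown here; the helper name below is hypothetical):

tbe_func_adapter_map = {
    "matrix_diag": "matrix_diag_d",
    "matrix_diag_part": "matrix_diag_part_d",
    "matrix_set_diag": "matrix_set_diag_d",
}

def normalize_func_name(func_name: str) -> str:
    # Fall back to the original name when no "_d" variant is registered.
    return tbe_func_adapter_map.get(func_name, func_name)

assert normalize_func_name("matrix_diag") == "matrix_diag_d"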
@@ -31,9 +31,12 @@ from mindspore.ops import _selected_ops
from ..cell import Cell
from .activation import get_activation
from ..._checkparam import Validator as validator
from ..._checkparam import Rel

__all__ = ['Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', 'Pad', 'Unfold']
__all__ = ['Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', 'Pad', 'Unfold',
           'MatrixDiag', 'MatrixDiagPart', 'MatrixSetDiag']

class Dropout(Cell):
    r"""
@@ -527,3 +530,112 @@ class Unfold(Cell):
        ret = self.extract_image_patches(x_transpose)
        ret_transpose = self.transpose(ret, self.format_NCHW)
        return ret_transpose
@constexpr
def _get_matrix_diag_assist(x_shape, x_dtype):
    validator.check_integer("x rank", len(x_shape), 1, Rel.GE, "_get_matrix_diag_assist")
    base_eye = np.eye(x_shape[-1], x_shape[-1]).reshape(-1)
    assist = np.tile(base_eye, x_shape[:-1]).reshape(x_shape + (x_shape[-1],))
    return Tensor(assist, x_dtype)


@constexpr
def _get_matrix_diag_part_assist(x_shape, x_dtype):
    validator.check_integer("x rank", len(x_shape), 2, Rel.GE, "_get_matrix_diag_part_assist")
    base_eye = np.eye(x_shape[-2], x_shape[-1]).reshape(-1)
    assist = np.tile(base_eye, x_shape[:-2]).reshape(x_shape)
    return Tensor(assist, x_dtype)
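A quick NumPy illustration of what these assist helpers produce, using the same expressions as above (the example shapes are chosen arbitrarily):

import numpy as np

# For x of shape (3, 2), _get_matrix_diag_assist tiles a 2x2 identity along
# the batch dimension, yielding a (3, 2, 2) stack of identity matrices.
x_shape = (3, 2)
base_eye = np.eye(x_shape[-1], x_shape[-1]).reshape(-1)        # [1. 0. 0. 1.]
assist = np.tile(base_eye, x_shape[:-1]).reshape(x_shape + (x_shape[-1],))
print(assist.shape)                                            # (3, 2, 2)

# For x of shape (3, 2, 2), _get_matrix_diag_part_assist tiles an eye(2, 2)
# pattern so that the assist has exactly the same shape as x.
part_shape = (3, 2, 2)
part_eye = np.eye(part_shape[-2], part_shape[-1]).reshape(-1)
part_assist = np.tile(part_eye, part_shape[:-2]).reshape(part_shape)
print(part_assist.shape)                                       # (3, 2, 2)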
class MatrixDiag(Cell):
    """
    Returns a batched diagonal tensor with the given batched diagonal values.

    Inputs:
        - **x** (Tensor) - The diagonal values. It can be one of the following data types:
          float32, float16, int32, int8, uint8.

    Outputs:
        Tensor, same type as input `x`. The shape is x.shape + (x.shape[-1],).

    Examples:
        >>> x = Tensor(np.array([1, -1]), mstype.float32)
        >>> matrix_diag = nn.MatrixDiag()
        >>> result = matrix_diag(x)
        [[1. 0.]
         [0. -1.]]
    """
    def __init__(self):
        super(MatrixDiag, self).__init__()
        self.matrix_diag = inner.MatrixDiag()
        self.dtype = P.DType()

    def construct(self, input_x):
        x_shape = F.shape(input_x)
        x_dtype = self.dtype(input_x)
        assist = _get_matrix_diag_assist(x_shape, x_dtype)
        out_matrix_diag = self.matrix_diag(input_x, assist)
        return out_matrix_diag
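For reference, a hedged NumPy equivalent of what nn.MatrixDiag is expected to return for the docstring example (reference math only, not the operator implementation):

import numpy as np

x = np.array([1., -1.], dtype=np.float32)
# Each value of x becomes the main diagonal of a trailing square matrix,
# so the output shape is x.shape + (x.shape[-1],).
reference = x[..., None] * np.eye(x.shape[-1], dtype=np.float32)
print(reference)
# [[ 1.  0.]
#  [ 0. -1.]]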
class MatrixDiagPart(Cell):
    r"""
    Returns the batched diagonal part of a batched tensor.

    Inputs:
        - **x** (Tensor) - The batched tensor. It can be one of the following data types:
          float32, float16, int32, int8, uint8.

    Outputs:
        Tensor, same type as input `x`. The shape is x.shape[:-2] + [min(x.shape[-2:])].

    Examples:
        >>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
        >>> matrix_diag_part = nn.MatrixDiagPart()
        >>> result = matrix_diag_part(x)
        [[-1., 1.], [-1., 1.], [-1., 1.]]
    """
    def __init__(self):
        super(MatrixDiagPart, self).__init__()
        self.matrix_diag_part = inner.MatrixDiagPart()
        self.dtype = P.DType()

    def construct(self, input_x):
        x_shape = F.shape(input_x)
        x_dtype = self.dtype(input_x)
        assist = _get_matrix_diag_part_assist(x_shape, x_dtype)
        out_matrix_diag_part = self.matrix_diag_part(input_x, assist)
        return out_matrix_diag_part
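The same kind of hedged NumPy reference for nn.MatrixDiagPart (a sketch of the expected result, not the TBE kernel):

import numpy as np

x = np.array([[[-1, 0], [0, 1]],
              [[-1, 0], [0, 1]],
              [[-1, 0], [0, 1]]], dtype=np.float32)
# Extract the main diagonal of each trailing 2-D matrix in the batch.
reference = np.diagonal(x, axis1=-2, axis2=-1)
print(reference)
# [[-1.  1.]
#  [-1.  1.]
#  [-1.  1.]]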
class MatrixSetDiag(Cell):
    r"""
    Modifies the batched diagonal part of a batched tensor.

    Inputs:
        - **x** (Tensor) - The batched tensor. It can be one of the following data types:
          float32, float16, int32, int8, uint8.
        - **diagonal** (Tensor) - The diagonal values.

    Outputs:
        Tensor, same type and shape as input `x`.

    Examples:
        >>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
        >>> diagonal = Tensor([[-1., 2.], [-1., 1.], [-1., 1.]], mindspore.float32)
        >>> matrix_set_diag = nn.MatrixSetDiag()
        >>> result = matrix_set_diag(x, diagonal)
        [[[-1, 0], [0, 2]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]]
    """
    def __init__(self):
        super(MatrixSetDiag, self).__init__()
        self.matrix_set_diag = inner.MatrixSetDiag()
        self.dtype = P.DType()

    def construct(self, input_x, diagonal):
        x_shape = F.shape(input_x)
        x_dtype = self.dtype(input_x)
        assist = _get_matrix_diag_part_assist(x_shape, x_dtype)
        out_matrix_set_diag = self.matrix_set_diag(input_x, diagonal, assist)
        return out_matrix_set_diag
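And a hedged NumPy reference for nn.MatrixSetDiag, which overwrites the main diagonal of each batched matrix with the given values (again a sketch of the semantics, not the kernel):

import numpy as np

x = np.array([[[-1, 0], [0, 1]],
              [[-1, 0], [0, 1]],
              [[-1, 0], [0, 1]]], dtype=np.float32)
diagonal = np.array([[-1., 2.], [-1., 1.], [-1., 1.]], dtype=np.float32)

out = x.copy()
idx = np.arange(x.shape[-1])
out[..., idx, idx] = diagonal     # write the new diagonal in place
print(out[0])
# [[-1.  0.]
#  [ 0.  2.]]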
@@ -264,3 +264,6 @@ from .inplace_update import _inplace_update_tbe
from .splitv import _split_v_tbe
from .in_top_k import _in_top_k_tbe
from .lin_space import _lin_space_tbe
from .matrix_diag import _matrix_diag_tbe
from .matrix_diag_part import _matrix_diag_part_tbe
from .matrix_set_diag import _matrix_set_diag_tbe
@@ -0,0 +1,45 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""MatrixDiagD op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

matrix_diag_d_op_info = TBERegOp("MatrixDiag") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("matrix_diag_d.so") \
    .compute_cost(10) \
    .kernel_name("matrix_diag_d") \
    .partial_flag(True) \
    .input(0, "x", False, "required", "all") \
    .input(1, "assist", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \
    .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD) \
    .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
    .get_op_info()


@op_info_register(matrix_diag_d_op_info)
def _matrix_diag_tbe():
    """MatrixDiagD TBE register"""
    return
@@ -0,0 +1,45 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""MatrixDiagPartD op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

matrix_diag_part_d_op_info = TBERegOp("MatrixDiagPart") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("matrix_diag_part_d.so") \
    .compute_cost(10) \
    .kernel_name("matrix_diag_part_d") \
    .partial_flag(True) \
    .input(0, "x", False, "required", "all") \
    .input(1, "assist", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \
    .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD) \
    .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
    .get_op_info()


@op_info_register(matrix_diag_part_d_op_info)
def _matrix_diag_part_tbe():
    """MatrixDiagPartD TBE register"""
    return
@@ -0,0 +1,46 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""MatrixSetDiagD op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
| matrix_diag_d_op_info = TBERegOp("MatrixSetDiag") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("matrix_diag_d.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("matrix_diag_d") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .input(1, "diagonal", False, "required", "all") \ | |||
| .input(2, "assist", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \ | |||
| .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD) \ | |||
| .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(matrix_diag_d_op_info) | |||
| def _matrix_set_diag_tbe(): | |||
| """MatrixSetDiagD TBE register""" | |||
| return | |||
@@ -367,3 +367,144 @@ class LinSpace(PrimitiveWithInfer):
        args = {"assist": assist, "start": start, "stop": stop}
        validator.check_tensor_type_same(args, (mstype.float32,), self.name)
        return assist
class MatrixDiag(PrimitiveWithInfer):
    """
    Returns a batched diagonal tensor with the given batched diagonal values.

    Inputs:
        - **x** (Tensor) - A tensor to be multiplied element-wise by `assist`. It can be one of the
          following data types: float32, float16, int32, int8, uint8.
        - **assist** (Tensor) - An eye tensor of the same type as `x`. Its rank must be greater than
          or equal to 2, and its last dimension must equal its second-to-last dimension.

    Outputs:
        Tensor, with the same type and shape as input `assist`.

    Examples:
        >>> x = Tensor(np.array([1, -1]), mstype.float32)
        >>> assist = Tensor(np.arange(-12, 0).reshape(3, 2, 2), mindspore.float32)
        >>> matrix_diag = P.MatrixDiag()
        >>> result = matrix_diag(x, assist)
        [[[-12.  11.]
          [-10.   9.]]
         [[ -8.   7.]
          [ -6.   5.]]
         [[ -4.   3.]
          [ -2.   1.]]]
    """
    @prim_attr_register
    def __init__(self):
        """init MatrixDiag"""

    def infer_dtype(self, x_dtype, assist_dtype):
        valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
        args = {"x": x_dtype, "assist": assist_dtype}
        validator.check_tensor_type_same(args, valid_type, self.name)
        return x_dtype

    def infer_shape(self, x_shape, assist_shape):
        validator.check_integer("assist rank", len(assist_shape), 2, Rel.GE, self.name)
        validator.check('rank of x', len(x_shape) + 1,
                        'rank of assist', len(assist_shape), Rel.LE, self.name)
        validator.check('assist\'s penultimate dimension', assist_shape[-2], 'assist\'s last dimension',
                        assist_shape[-1], Rel.EQ, self.name)

        r_end_dim = -len(x_shape)
        r_idx = -1
        while r_idx >= r_end_dim:
            if x_shape[r_idx] != 1:
                validator.check("reverse x dim %d" % r_idx, x_shape[r_idx], "reverse assist dim %d" %
                                assist_shape[r_idx - 1], assist_shape[r_idx - 1], Rel.EQ, self.name)
            r_idx = r_idx - 1
        return assist_shape
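The broadcasting rule enforced by this infer_shape can be restated in plain Python roughly as follows (an editorial sketch, not the production validator; diag_shapes_compatible is a hypothetical helper name):

def diag_shapes_compatible(x_shape, assist_shape):
    """Mirror MatrixDiag.infer_shape: assist is square in its last two dims,
    and x broadcasts (from the right) against assist's leading dims."""
    if len(assist_shape) < 2 or assist_shape[-2] != assist_shape[-1]:
        return False
    if len(x_shape) + 1 > len(assist_shape):
        return False
    for r_idx in range(-1, -len(x_shape) - 1, -1):
        if x_shape[r_idx] != 1 and x_shape[r_idx] != assist_shape[r_idx - 1]:
            return False
    return True

assert diag_shapes_compatible((2,), (3, 2, 2))      # the docstring example
assert diag_shapes_compatible((1, 2), (3, 2, 2))    # size-1 dims broadcast
assert not diag_shapes_compatible((3,), (3, 2, 2))  # 3 does not match assist dim 2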
class MatrixDiagPart(PrimitiveWithInfer):
    r"""
    Returns the batched diagonal part of a batched tensor.

    Inputs:
        - **x** (Tensor) - The batched tensor. It can be one of the following data types:
          float32, float16, int32, int8, uint8.
        - **assist** (Tensor) - An eye tensor of the same type and shape as `x`.

    Outputs:
        Tensor, with the same data type as input `x`. The shape is x.shape[:-2] + [min(x.shape[-2:])].

    Examples:
        >>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
        >>> assist = Tensor(np.arange(-12, 0).reshape(3, 2, 2), mindspore.float32)
        >>> matrix_diag_part = P.MatrixDiagPart()
        >>> result = matrix_diag_part(x, assist)
        [[12., -9.], [8., -5.], [4., -1.]]
    """
    @prim_attr_register
    def __init__(self):
        """init MatrixDiagPart"""

    def infer_dtype(self, x_dtype, assist_dtype):
        valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
        args = {"x": x_dtype, "assist": assist_dtype}
        validator.check_tensor_type_same(args, valid_type, self.name)
        return x_dtype

    def infer_shape(self, x_shape, assist_shape):
        validator.check_integer("x rank", len(x_shape), 2, Rel.GE, self.name)
        validator.check("x shape", x_shape, "assist shape", assist_shape, Rel.EQ, self.name)

        if assist_shape[-2] < assist_shape[-1]:
            out_shape = assist_shape[:-1]
        else:
            out_shape = assist_shape[:-2] + assist_shape[-1:]
        return out_shape
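The output-shape rule above amounts to "keep the batch dims and take min(rows, cols) diagonal entries"; a short restatement with a hypothetical helper name:

def diag_part_out_shape(x_shape):
    # Equivalent to the branch above: the shorter of the last two dims survives.
    return tuple(x_shape[:-2]) + (min(x_shape[-2], x_shape[-1]),)

assert diag_part_out_shape((3, 2, 2)) == (3, 2)
assert diag_part_out_shape((4, 2, 3)) == (4, 2)   # wide matrices
assert diag_part_out_shape((4, 3, 2)) == (4, 2)   # tall matrices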
class MatrixSetDiag(PrimitiveWithInfer):
    r"""
    Modifies the batched diagonal part of a batched tensor.

    Inputs:
        - **x** (Tensor) - The batched tensor. It can be one of the following data types:
          float32, float16, int32, int8, uint8.
        - **diagonal** (Tensor) - The diagonal values.
        - **assist** (Tensor) - An eye tensor of the same type and shape as `x`.

    Outputs:
        Tensor, with the same data type and shape as input `x`.

    Examples:
        >>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
        >>> diagonal = Tensor([[-1., 2.], [-1., 1.], [-1., 1.]], mindspore.float32)
        >>> matrix_set_diag = P.MatrixSetDiag()
        >>> result = matrix_set_diag(x, diagonal)
        [[[-1, 0], [0, 2]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]]
    """
    @prim_attr_register
    def __init__(self):
        """init MatrixSetDiag"""

    def infer_dtype(self, x_dtype, diagonal_dtype, assist_dtype):
        valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
        args = {"x": x_dtype, "diagonal": diagonal_dtype, "assist": assist_dtype}
        validator.check_tensor_type_same(args, valid_type, self.name)
        return x_dtype

    def infer_shape(self, x_shape, diagonal_shape, assist_shape):
        validator.check_integer("x rank", len(x_shape), 2, Rel.GE, self.name)
        validator.check("x shape", x_shape, "assist shape", assist_shape, Rel.EQ, self.name)

        if x_shape[-2] < x_shape[-1]:
            validator.check("x shape excluding the last dimension", x_shape[:-1], "diagonal shape",
                            diagonal_shape, Rel.EQ, self.name)
        else:
            validator.check("x shape excluding the second to last dimension", x_shape[:-2] + x_shape[-1:],
                            "diagonal shape", diagonal_shape, Rel.EQ, self.name)
        return assist_shape
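Analogously, the shape required of `diagonal` follows the same min(rows, cols) rule; a hedged restatement with a hypothetical helper name:

def set_diag_diagonal_shape(x_shape):
    # The diagonal carries one value per diagonal element of each batched matrix.
    if x_shape[-2] < x_shape[-1]:
        return tuple(x_shape[:-1])
    return tuple(x_shape[:-2]) + (x_shape[-1],)

assert set_diag_diagonal_shape((3, 2, 2)) == (3, 2)
assert set_diag_diagonal_shape((4, 2, 3)) == (4, 2)
assert set_diag_diagonal_shape((4, 3, 2)) == (4, 2)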
@@ -370,6 +370,7 @@ def test_conv2d_same_primitive():
            super(Conv2DSameNet, self).__init__()
            self.conv1 = nn.Conv2d(16, 64, (1, 41), (1, 4), "same", 0, 1, has_bias=True)
            self.conv2 = nn.Conv2d(16, 64, (1, 41), (1, 4), "same", 0, 1, has_bias=True)

        def construct(self, x, y):
            r1 = self.conv1(x)
            r2 = self.conv2(y)
@@ -576,6 +577,22 @@ test_cases = [
        Tensor(np.ones([1, 3, 4, 4], np.float32)),
        Tensor(np.ones(3, np.float32))],
    }),
    ('MatrixDiag', {
        'block': nn.MatrixDiag(),
        'desc_inputs': [Tensor(np.array([1, 2, 3]).astype(np.float32))],
        'skip': ['backward']
    }),
    ('MatrixDiagPart', {
        'block': nn.MatrixDiagPart(),
        'desc_inputs': [Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32))],
        'skip': ['backward']
    }),
    ('MatrixSetDiag', {
        'block': nn.MatrixSetDiag(),
        'desc_inputs': [Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32)),
                        Tensor(np.array([1, 2]).astype(np.float32))],
        'skip': ['backward']
    }),
]

test_cases_for_verify_exception = [
@@ -1612,6 +1612,25 @@ test_case_array_ops = [
        Tensor(5, mstype.int32)],
        'skip': ['backward'],
    }),
    ('MatrixDiag', {
        'block': inner.MatrixDiag(),
        'desc_inputs': [Tensor(np.array([1, -1]), mstype.float32),
                        Tensor(np.arange(-12, 0).reshape(3, 2, 2), mstype.float32)],
        'skip': ['backward'],
    }),
    ('MatrixDiagPart', {
        'block': inner.MatrixDiagPart(),
        'desc_inputs': [Tensor(np.arange(12).reshape(3, 2, 2), mstype.float32),
                        Tensor(np.arange(-12, 0).reshape(3, 2, 2), mstype.float32)],
        'skip': ['backward'],
    }),
    ('MatrixSetDiag', {
        'block': inner.MatrixSetDiag(),
        'desc_inputs': [Tensor(np.arange(12).reshape(3, 2, 2), mstype.float32),
                        Tensor(np.arange(6).reshape(3, 2), mstype.float32),
                        Tensor(np.arange(-12, 0).reshape(3, 2, 2), mstype.float32)],
        'skip': ['backward'],
    }),
]

test_case_other_ops = [