| @@ -15,8 +15,9 @@ | |||
| """Linear algebra submodule""" | |||
| from .. import numpy as mnp | |||
| from .. import ops | |||
| from .ops import SolveTriangular | |||
| __all__ = ['block_diag'] | |||
| __all__ = ['block_diag', 'solve_triangular'] | |||
| def block_diag(*arrs): | |||
| @@ -82,3 +83,69 @@ def block_diag(*arrs): | |||
| accum = ops.Pad(((0, 0), (0, c)))(accum) | |||
| accum = mnp.concatenate([accum, arr], axis=0) | |||
| return accum | |||
def solve_triangular(A, b, trans=0, lower=False, unit_diagonal=False,
                     overwrite_b=False, debug=None, check_finite=True):
    """
    Solve the equation :math:`a x = b` for `x`, assuming `a` is a triangular matrix.

    Args:
        A (Tensor): A triangular matrix of shape :math:`(N, N)`.
        b (Tensor): A tensor of shape :math:`(M,)` or :math:`(M, N)`.
            Right-hand side matrix in :math:`A x = b`.
        trans (0, 1, 2, 'N', 'T', 'C', optional):
            Type of system to solve:

            ======== =========
            trans    system
            ======== =========
            0 or 'N' a x = b
            1 or 'T' a^T x = b
            2 or 'C' a^H x = b
            ======== =========
        lower (bool, optional): Use only data contained in the lower triangle of `a`.
            Default is to use upper triangle.
        unit_diagonal (bool, optional): If True, diagonal elements of :math:`A` are assumed to be 1 and
            will not be referenced.
        overwrite_b (bool, optional): Allow overwriting data in :math:`b` (may enhance performance).
            NOTE(review): accepted for scipy API compatibility but currently ignored here.
        debug (None, optional): Accepted for scipy API compatibility; unused.
        check_finite (bool, optional): Whether to check that the input matrices contain only finite numbers.
            Disabling may give a performance gain, but may result in problems
            (crashes, non-termination) if the inputs do contain infinities or NaNs.
            NOTE(review): accepted for scipy API compatibility but currently ignored here.

    Returns:
        x (Tensor): A tensor of shape :math:`(M,)` or :math:`(M, N)`,
            which is the solution to the system :math:`A x = b`.
            Shape of :math:`x` matches :math:`b`.

    Raises:
        ValueError: If `trans` is not one of 0, 1, 2, 'N', 'T', 'C'.
        LinAlgError: If :math:`A` is singular.

    Supported Platforms:
        ``CPU`` ``GPU``

    Examples:
        Solve the lower triangular system :math:`A x = b`, where:

                 [3 0 0 0]       [4]
            A =  [2 1 0 0]   b = [2]
                 [1 0 1 0]       [4]
                 [1 1 1 1]       [2]

        >>> import numpy as onp
        >>> from mindspore.common import Tensor
        >>> from mindspore.scipy.linalg import solve_triangular
        >>> A = Tensor(onp.array([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]], onp.float64))
        >>> b = Tensor(onp.array([4, 2, 4, 2], onp.float64))
        >>> x = solve_triangular(A, b, lower=True, unit_diagonal=False, trans='N')
    """
    # Normalize the integer trans codes (0/1/2) to the string form ('N'/'T'/'C')
    # expected by the SolveTriangular primitive. Rejecting anything else here
    # gives a clear ValueError instead of a bare IndexError (for e.g. trans=3)
    # or a late backend failure (for e.g. trans='X').
    trans_table = {0: 'N', 1: 'T', 2: 'C', 'N': 'N', 'T': 'T', 'C': 'C'}
    if trans not in trans_table:
        raise ValueError("trans must be one of 0, 1, 2, 'N', 'T', 'C', but got {}".format(trans))
    solve = SolveTriangular(lower, unit_diagonal, trans_table[trans])
    return solve(A, b)
| @@ -0,0 +1,98 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Operators for scipy submodule""" | |||
| from ..ops import PrimitiveWithInfer, prim_attr_register | |||
| from .._checkparam import Validator as validator | |||
| from ..common import dtype as mstype | |||
class SolveTriangular(PrimitiveWithInfer):
    """
    SolveTriangular op frontend implementation.

    Args:
        lower (bool): Whether the input matrix :math:`A` is lower triangular.
        unit_diagonal (bool): If True, diagonal elements of :math:`A` are assumed
            to be 1 and will not be referenced.
        trans (0, 1, 2, 'N', 'T', 'C', optional):
            Type of system to solve:

            ======== =========
            trans    system
            ======== =========
            0 or 'N' a x = b
            1 or 'T' a^T x = b
            2 or 'C' a^H x = b
            ======== =========

    Inputs:
        - **A** (Tensor) - A triangular matrix of shape :math:`(N, N)`.
        - **b** (Tensor) - A tensor of shape :math:`(M,)` or :math:`(M, N)`.
          Right-hand side matrix in :math:`A x = b`.

    Outputs:
        - **x** (Tensor) - A tensor of shape :math:`(M,)` or :math:`(M, N)`,
          the solution to :math:`A x = b`. Shape of :math:`x` matches :math:`b`.

    Supported Platforms:
        ``GPU`` ``CPU``

    Examples:
        >>> import numpy as onp
        >>> from mindspore.common import Tensor
        >>> from mindspore.scipy.ops import SolveTriangular
        >>> A = Tensor(onp.array([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]], onp.float64))
        >>> b = Tensor(onp.array([4, 2, 4, 2], onp.float64))
        >>> solve_triangular = SolveTriangular(lower=True, unit_diagonal=False, trans='N')
        >>> x = solve_triangular(A, b)
    """

    @prim_attr_register
    def __init__(self, lower: bool, unit_diagonal: bool, trans: str):
        """Initialize SolveTriangular"""
        super(SolveTriangular, self).__init__("SolveTriangular")
        # Validate and record the attributes that configure the backend kernel.
        self.lower = validator.check_value_type("lower", lower, [bool], self.name)
        self.unit_diagonal = validator.check_value_type("unit_diagonal", unit_diagonal,
                                                        [bool], self.name)
        self.trans = validator.check_value_type("trans", trans, [str], self.name)
        self.init_prim_io_names(inputs=['A', 'b'], outputs=['output'])

    def __infer__(self, A, b):
        # The solution x has the shape of b and the dtype of A.
        return {'shape': tuple(b['shape']),
                'dtype': A['dtype'],
                'value': None}

    def infer_dtype(self, x_dtype):
        # Only single and double precision floats are supported by the kernels.
        valid_dtypes = [mstype.float32, mstype.float64]
        validator.check_tensor_dtype_valid(x_dtype, valid_dtypes, self.name, True)
        return x_dtype
| @@ -20,52 +20,12 @@ import numpy as np | |||
| from scipy.linalg import solve_triangular | |||
| import mindspore.context as context | |||
| from mindspore import Tensor | |||
| from mindspore.ops import PrimitiveWithInfer, prim_attr_register | |||
| from mindspore._checkparam import Validator as validator | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.scipy.linalg import solve_triangular as mind_solve | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||
| np.random.seed(0) | |||
class SolveTriangular(PrimitiveWithInfer):
    """
    SolveTriangular op frontend implementation (local test copy).

    Attributes validated at construction:
        lower (bool): whether A is treated as lower triangular.
        unit_diagonal (bool): if True, A's diagonal is assumed to be all ones.
        trans (str): 'N', 'T' or 'C' — which system (a x = b, a^T x = b, a^H x = b) to solve.
    """

    @prim_attr_register
    def __init__(self, lower: bool, unit_diagonal: bool, trans: str):
        """Initialize SolveTriangular"""
        # Consistency fix: initialize the Primitive base class (which provides
        # self.name, read by the validator calls below) before using it —
        # matching the canonical implementation in mindspore.scipy.ops.
        super(SolveTriangular, self).__init__("SolveTriangular")
        self.lower = validator.check_value_type(
            "lower", lower, [bool], self.name)
        self.unit_diagonal = validator.check_value_type(
            "unit_diagonal", unit_diagonal, [bool], self.name)
        self.trans = validator.check_value_type(
            "trans", trans, [str], self.name)
        self.init_prim_io_names(inputs=['A', 'b'], outputs=['output'])

    def __infer__(self, A, b):
        # Output shape follows b; output dtype follows A.
        out_shapes = b['shape']
        return {
            'shape': tuple(out_shapes),
            'dtype': A['dtype'],
            'value': None
        }

    def infer_dtype(self, x_dtype):
        # Kernels support float32/float64 only.
        validator.check_tensor_dtype_valid(x_dtype, [mstype.float32, mstype.float64],
                                           self.name, True)
        return x_dtype
def mind_solve(a, b, trans="N", lower=False, unit_diagonal=False,
               overwrite_b=False, debug=None, check_finite=True):
    """Thin scipy-compatible wrapper that dispatches to the SolveTriangular primitive."""
    op = SolveTriangular(lower=lower, unit_diagonal=unit_diagonal, trans=trans)
    return op(a, b)
| def match(a, b, lower, unit_diagonal, trans): | |||
| sci_x = solve_triangular( | |||
| a, b, lower=lower, unit_diagonal=unit_diagonal, trans=trans) | |||
| @@ -88,7 +48,7 @@ def match(a, b, lower, unit_diagonal, trans): | |||
| @pytest.mark.parametrize('dtype', [np.float32, np.float64]) | |||
| @pytest.mark.parametrize('lower', [False, True]) | |||
| @pytest.mark.parametrize('unit_diagonal', [False]) | |||
| def test_2D(n: int, dtype, lower: bool, unit_diagonal: bool, trans: str): | |||
| def test_2d(n: int, dtype, lower: bool, unit_diagonal: bool, trans: str): | |||
| """ | |||
| Feature: ALL TO ALL | |||
| Description: test cases for [N x N] X [N X 1] | |||
| @@ -108,7 +68,7 @@ def test_2D(n: int, dtype, lower: bool, unit_diagonal: bool, trans: str): | |||
| @pytest.mark.parametrize('dtype', [np.float32, np.float64]) | |||
| @pytest.mark.parametrize('lower', [False, True]) | |||
| @pytest.mark.parametrize('unit_diagonal', [False, True]) | |||
| def test_1D(n: int, dtype, lower: bool, unit_diagonal: bool, trans: str): | |||
| def test_1d(n: int, dtype, lower: bool, unit_diagonal: bool, trans: str): | |||
| """ | |||
| Feature: ALL TO ALL | |||
| Description: test cases for [N x N] X [N] | |||
| @@ -20,52 +20,12 @@ import numpy as np | |||
| from scipy.linalg import solve_triangular | |||
| import mindspore.context as context | |||
| from mindspore import Tensor | |||
| from mindspore.ops import PrimitiveWithInfer, prim_attr_register | |||
| from mindspore._checkparam import Validator as validator | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.scipy.linalg import solve_triangular as mind_solve | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| np.random.seed(0) | |||
class SolveTriangular(PrimitiveWithInfer):
    """
    SolveTriangular op frontend implementation (local test copy).

    Attributes validated at construction:
        lower (bool): whether A is treated as lower triangular.
        unit_diagonal (bool): if True, A's diagonal is assumed to be all ones.
        trans (str): 'N', 'T' or 'C' — which system (a x = b, a^T x = b, a^H x = b) to solve.
    """

    @prim_attr_register
    def __init__(self, lower: bool, unit_diagonal: bool, trans: str):
        """Initialize SolveTriangular"""
        # Consistency fix: initialize the Primitive base class (which provides
        # self.name, read by the validator calls below) before using it —
        # matching the canonical implementation in mindspore.scipy.ops.
        super(SolveTriangular, self).__init__("SolveTriangular")
        self.lower = validator.check_value_type(
            "lower", lower, [bool], self.name)
        self.unit_diagonal = validator.check_value_type(
            "unit_diagonal", unit_diagonal, [bool], self.name)
        self.trans = validator.check_value_type(
            "trans", trans, [str], self.name)
        self.init_prim_io_names(inputs=['A', 'b'], outputs=['output'])

    def __infer__(self, A, b):
        # Output shape follows b; output dtype follows A.
        out_shapes = b['shape']
        return {
            'shape': tuple(out_shapes),
            'dtype': A['dtype'],
            'value': None
        }

    def infer_dtype(self, x_dtype):
        # Kernels support float32/float64 only.
        validator.check_tensor_dtype_valid(x_dtype, [mstype.float32, mstype.float64],
                                           self.name, True)
        return x_dtype
def mind_solve(a, b, trans="N", lower=False, unit_diagonal=False,
               overwrite_b=False, debug=None, check_finite=True):
    """Thin scipy-compatible wrapper that dispatches to the SolveTriangular primitive."""
    op = SolveTriangular(lower=lower, unit_diagonal=unit_diagonal, trans=trans)
    return op(a, b)
| def match(a, b, lower, unit_diagonal, trans): | |||
| sci_x = solve_triangular( | |||
| a, b, lower=lower, unit_diagonal=unit_diagonal, trans=trans) | |||
| @@ -86,7 +46,7 @@ def match(a, b, lower, unit_diagonal, trans): | |||
| @pytest.mark.parametrize('dtype', [np.float32, np.float64]) | |||
| @pytest.mark.parametrize('lower', [False, True]) | |||
| @pytest.mark.parametrize('unit_diagonal', [False]) | |||
| def test_2D(n: int, dtype, lower: bool, unit_diagonal: bool, trans: str): | |||
| def test_2d(n: int, dtype, lower: bool, unit_diagonal: bool, trans: str): | |||
| """ | |||
| Feature: ALL TO ALL | |||
| Description: test cases for [N x N] X [N X 1] | |||
| @@ -106,7 +66,7 @@ def test_2D(n: int, dtype, lower: bool, unit_diagonal: bool, trans: str): | |||
| @pytest.mark.parametrize('dtype', [np.float32, np.float64]) | |||
| @pytest.mark.parametrize('lower', [False, True]) | |||
| @pytest.mark.parametrize('unit_diagonal', [False, True]) | |||
| def test_1D(n: int, dtype, lower: bool, unit_diagonal: bool, trans: str): | |||
| def test_1d(n: int, dtype, lower: bool, unit_diagonal: bool, trans: str): | |||
| """ | |||
| Feature: ALL TO ALL | |||
| Description: test cases for [N x N] X [N] | |||