Browse Source

add solve_triangular

tags/v1.6.0
zhujingxuan 4 years ago
parent
commit
ec809b6134
4 changed files with 172 additions and 87 deletions
  1. +68
    -1
      mindspore/scipy/linalg.py
  2. +98
    -0
      mindspore/scipy/ops.py
  3. +3
    -43
      tests/st/ops/cpu/test_solve_triangular_op.py
  4. +3
    -43
      tests/st/ops/gpu/test_solve_triangular_op.py

+ 68
- 1
mindspore/scipy/linalg.py View File

@@ -15,8 +15,9 @@
"""Linear algebra submodule"""
from .. import numpy as mnp
from .. import ops
from .ops import SolveTriangular

__all__ = ['block_diag']
__all__ = ['block_diag', 'solve_triangular']


def block_diag(*arrs):
@@ -82,3 +83,69 @@ def block_diag(*arrs):
accum = ops.Pad(((0, 0), (0, c)))(accum)
accum = mnp.concatenate([accum, arr], axis=0)
return accum


def solve_triangular(A, b, trans=0, lower=False, unit_diagonal=False,
overwrite_b=False, debug=None, check_finite=True):
"""
Solve the equation `a x = b` for `x`, assuming a is a triangular matrix.

Args:
A (Tensor): A triangular matrix of shape :math:`(N, N)`.
b (Tensor): A tensor of shape :math:`(M,)` or :math:`(M, N)`.
Right-hand side matrix in :math:`A x = b`.
lower (bool, optional): Use only data contained in the lower triangle of `a`.
Default is to use upper triangle.
trans (0, 1, 2, 'N', 'T', 'C', optional):
Type of system to solve:

======== =========
trans system
======== =========
0 or 'N' a x = b
1 or 'T' a^T x = b
2 or 'C' a^H x = b
======== =========
unit_diagonal (bool, optional): If True, diagonal elements of :math:`A` are assumed to be 1 and
will not be referenced.
overwrite_b (bool, optional): Allow overwriting data in :math:`b` (may enhance performance)
check_finite (bool, optional): Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.

Returns:
x (Tensor): A tensor of shape :math:`(M,)` or :math:`(M, N)`,
which is the solution to the system :math:`A x = b`.
Shape of :math:`x` matches :math:`b`.

Raises:
LinAlgError: If :math:`A` is singular

Supported Platforms:
``CPU`` ``GPU``

Examples:
Solve the lower triangular system :math:`A x = b`, where:

[3 0 0 0] [4]
A = [2 1 0 0] b = [2]
[1 0 1 0] [4]
[1 1 1 1] [2]

>>> import numpy as onp
>>> from mindspore.common import Tensor
>>> import mindspore.numpy as mnp
>>> from mindspore.scipy.linalg import solve_triangular
>>> A = Tensor(onp.array([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]], onp.float64))
>>> b = Tensor(onp.array([4, 2, 4, 2], onp.float64))
>>> x = solve_triangular(A, b, lower=True, unit_diagonal=False, trans='N')
>>> x
Tensor(shape=[4], dtype=Float32, value= [ 1.33333337e+00, -6.66666746e-01, 2.66666651e+00, -1.33333313e+00])
>>> mnp.dot(A, x) # Check the result
Tensor(shape=[4], dtype=Float32, value= [ 4.00000000e+00, 2.00000000e+00, 4.00000000e+00, 2.00000000e+00])
"""
if isinstance(trans, int):
trans_table = ['N', 'T', 'C']
trans = trans_table[trans]
solve = SolveTriangular(lower, unit_diagonal, trans)
return solve(A, b)

+ 98
- 0
mindspore/scipy/ops.py View File

@@ -0,0 +1,98 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Operators for scipy submodule"""
from ..ops import PrimitiveWithInfer, prim_attr_register
from .._checkparam import Validator as validator
from ..common import dtype as mstype


class SolveTriangular(PrimitiveWithInfer):
"""
SolveTriangular op frontend implementation.

Args:
lower (bool): The input Matrix :math:`A` is lower triangular matrix or not.
unit_diagonal (bool): If True, diagonal elements of :math:`A` are assumed to be 1 and
will not be referenced.
trans (0, 1, 2, 'N', 'T', 'C', optional):
Type of system to solve:

======== =========
trans system
======== =========
0 or 'N' a x = b
1 or 'T' a^T x = b
2 or 'C' a^H x = b
======== =========

Inputs:
- **A** (Tensor) - A triangular matrix of shape :math:`(N, N)`.
- **b** (Tensor) - A tensor of shape :math:`(M,)` or :math:`(M, N)`. Right-hand side matrix in :math:`A x = b`.

Returns:
- **x** (Tensor) - A tensor of shape :math:`(M,)` or :math:`(M, N)`,
which is the solution to the system :math:`A x = b`.
Shape of :math:`x` matches :math:`b`.

Supported Platforms:
``GPU`` ``CPU``

Examples:
Solve the lower triangular system :math:`A x = b`, where:

[3 0 0 0] [4]
A = [2 1 0 0] b = [2]
[1 0 1 0] [4]
[1 1 1 1] [2]

>>> import numpy as onp
>>> from mindspore.common import Tensor
>>> import mindspore.numpy as mnp
>>> from mindspore.scipy.ops import SolveTriangular
>>> A = Tensor(onp.array([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]], onp.float64))
>>> b = Tensor(onp.array([4, 2, 4, 2], onp.float64))
>>> solve_triangular = SolveTriangular(lower=True, unit_diagonal=False, trans='N')
>>> x = solve_triangular(A, b)
>>> x
Tensor(shape=[4], dtype=Float64, value= [ 1.33333333e+00, -6.66666667e-01, 2.66666667e+00, -1.33333333e+00])
>>> mnp.dot(A, x) # Check the result
Tensor(shape=[4], dtype=Float64, value= [ 4.00000000e+00, 2.00000000e+00, 4.00000000e+00, 2.00000000e+00])
"""

@prim_attr_register
def __init__(self, lower: bool, unit_diagonal: bool, trans: str):
"""Initialize SolveTriangular"""
super(SolveTriangular, self).__init__("SolveTriangular")
self.lower = validator.check_value_type(
"lower", lower, [bool], self.name)
self.unit_diagonal = validator.check_value_type(
"unit_diagonal", unit_diagonal, [bool], self.name)
self.trans = validator.check_value_type(
"trans", trans, [str], self.name)

self.init_prim_io_names(inputs=['A', 'b'], outputs=['output'])

def __infer__(self, A, b):
out_shapes = b['shape']
return {
'shape': tuple(out_shapes),
'dtype': A['dtype'],
'value': None
}

def infer_dtype(self, x_dtype):
validator.check_tensor_dtype_valid(x_dtype, [mstype.float32, mstype.float64],
self.name, True)
return x_dtype

+ 3
- 43
tests/st/ops/cpu/test_solve_triangular_op.py View File

@@ -20,52 +20,12 @@ import numpy as np
from scipy.linalg import solve_triangular
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import PrimitiveWithInfer, prim_attr_register
from mindspore._checkparam import Validator as validator
from mindspore.common import dtype as mstype
from mindspore.scipy.linalg import solve_triangular as mind_solve

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
np.random.seed(0)


class SolveTriangular(PrimitiveWithInfer):
"""
SolveTriangular op frontend implementation
"""

@prim_attr_register
def __init__(self, lower: bool, unit_diagonal: bool, trans: str):
"""Initialize SolveTriangular"""
self.lower = validator.check_value_type(
"lower", lower, [bool], self.name)
self.unit_diagonal = validator.check_value_type(
"unit_diagonal", unit_diagonal, [bool], self.name)
self.trans = validator.check_value_type(
"trans", trans, [str], self.name)

self.init_prim_io_names(inputs=['A', 'b'], outputs=['output'])

def __infer__(self, A, b):
out_shapes = b['shape']
return {
'shape': tuple(out_shapes),
'dtype': A['dtype'],
'value': None
}

def infer_dtype(self, x_dtype):
validator.check_tensor_dtype_valid(x_dtype, [mstype.float32, mstype.float64],
self.name, True)
return x_dtype


def mind_solve(a, b, trans="N", lower=False, unit_diagonal=False,
overwrite_b=False, debug=None, check_finite=True):
solve = SolveTriangular(
lower=lower, unit_diagonal=unit_diagonal, trans=trans)
return solve(a, b)


def match(a, b, lower, unit_diagonal, trans):
sci_x = solve_triangular(
a, b, lower=lower, unit_diagonal=unit_diagonal, trans=trans)
@@ -88,7 +48,7 @@ def match(a, b, lower, unit_diagonal, trans):
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
@pytest.mark.parametrize('lower', [False, True])
@pytest.mark.parametrize('unit_diagonal', [False])
def test_2D(n: int, dtype, lower: bool, unit_diagonal: bool, trans: str):
def test_2d(n: int, dtype, lower: bool, unit_diagonal: bool, trans: str):
"""
Feature: ALL TO ALL
Description: test cases for [N x N] X [N X 1]
@@ -108,7 +68,7 @@ def test_2D(n: int, dtype, lower: bool, unit_diagonal: bool, trans: str):
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
@pytest.mark.parametrize('lower', [False, True])
@pytest.mark.parametrize('unit_diagonal', [False, True])
def test_1D(n: int, dtype, lower: bool, unit_diagonal: bool, trans: str):
def test_1d(n: int, dtype, lower: bool, unit_diagonal: bool, trans: str):
"""
Feature: ALL TO ALL
Description: test cases for [N x N] X [N]


+ 3
- 43
tests/st/ops/gpu/test_solve_triangular_op.py View File

@@ -20,52 +20,12 @@ import numpy as np
from scipy.linalg import solve_triangular
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import PrimitiveWithInfer, prim_attr_register
from mindspore._checkparam import Validator as validator
from mindspore.common import dtype as mstype
from mindspore.scipy.linalg import solve_triangular as mind_solve

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
np.random.seed(0)


class SolveTriangular(PrimitiveWithInfer):
"""
SolveTriangular op frontend implementation
"""

@prim_attr_register
def __init__(self, lower: bool, unit_diagonal: bool, trans: str):
"""Initialize SolveTriangular"""
self.lower = validator.check_value_type(
"lower", lower, [bool], self.name)
self.unit_diagonal = validator.check_value_type(
"unit_diagonal", unit_diagonal, [bool], self.name)
self.trans = validator.check_value_type(
"trans", trans, [str], self.name)

self.init_prim_io_names(inputs=['A', 'b'], outputs=['output'])

def __infer__(self, A, b):
out_shapes = b['shape']
return {
'shape': tuple(out_shapes),
'dtype': A['dtype'],
'value': None
}

def infer_dtype(self, x_dtype):
validator.check_tensor_dtype_valid(x_dtype, [mstype.float32, mstype.float64],
self.name, True)
return x_dtype


def mind_solve(a, b, trans="N", lower=False, unit_diagonal=False,
overwrite_b=False, debug=None, check_finite=True):
solve = SolveTriangular(
lower=lower, unit_diagonal=unit_diagonal, trans=trans)
return solve(a, b)


def match(a, b, lower, unit_diagonal, trans):
sci_x = solve_triangular(
a, b, lower=lower, unit_diagonal=unit_diagonal, trans=trans)
@@ -86,7 +46,7 @@ def match(a, b, lower, unit_diagonal, trans):
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
@pytest.mark.parametrize('lower', [False, True])
@pytest.mark.parametrize('unit_diagonal', [False])
def test_2D(n: int, dtype, lower: bool, unit_diagonal: bool, trans: str):
def test_2d(n: int, dtype, lower: bool, unit_diagonal: bool, trans: str):
"""
Feature: ALL TO ALL
Description: test cases for [N x N] X [N X 1]
@@ -106,7 +66,7 @@ def test_2D(n: int, dtype, lower: bool, unit_diagonal: bool, trans: str):
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
@pytest.mark.parametrize('lower', [False, True])
@pytest.mark.parametrize('unit_diagonal', [False, True])
def test_1D(n: int, dtype, lower: bool, unit_diagonal: bool, trans: str):
def test_1d(n: int, dtype, lower: bool, unit_diagonal: bool, trans: str):
"""
Feature: ALL TO ALL
Description: test cases for [N x N] X [N]


Loading…
Cancel
Save