Merge pull request !30735 from hezhenhao1/fix_cgfeature/build-system-rewrite
| @@ -21,7 +21,7 @@ from ..linalg import solve_triangular | |||
| from ..linalg import cho_factor, cho_solve | |||
| from ..utils import _normalize_matvec, _to_tensor, _safe_normalize, _eps, _norm, _type_check, _value_check, \ | |||
| _sparse_check | |||
| from ..utils_const import _raise_value_error, _raise_type_error | |||
| from ..utils_const import _raise_value_error, _raise_type_error, _nullable_const | |||
| def gram_schmidt(Q, q): | |||
| @@ -323,6 +323,45 @@ class CG(nn.Cell): | |||
| return x, F.select(_norm(r) > atol_, k, _INT_ZERO) | |||
class CGv2(nn.Cell):
    """
    Variant of the CG cell in which every operand, including the linear operator and the
    preconditioner, is handed to `construct`, so all parameters live inside one graph.
    """

    def __init__(self):
        super(CGv2, self).__init__()

    def construct(self, A, M, b, x0, tol, atol, maxiter):
        """
        Run preconditioned conjugate-gradient iterations starting from `x0`.

        Args:
            A: Linear operator; normalized with `_normalize_matvec` before use.
            M: Preconditioner; normalized with `_normalize_matvec` before use.
            b: Right hand side of the linear system.
            x0: Starting guess for the solution.
            tol: Relative tolerance, scaled by the norm of `b`.
            atol: Absolute tolerance.
            maxiter: Upper bound on the number of iterations.

        Returns:
            Tuple `(x, info)`. `info` is the iteration count when the residual norm is still
            above the threshold on exit, and zero otherwise.
        """
        # Constant tensor which avoids loop unrolling
        zero = _to_tensor(0)
        matvec = _normalize_matvec(A)
        precond = _normalize_matvec(M)

        # The loop stops once the residual norm drops to max(atol, tol * ||b||).
        threshold = mnp.maximum(atol, tol * _norm(b))

        x = x0
        r = b - matvec(x0)
        z = precond(r)
        p = z
        rho = mnp.dot(r, z)
        k = zero
        while k < maxiter and _norm(r) > threshold:
            q = matvec(p)
            alpha = rho / mnp.dot(p, q)
            x = x + alpha * p
            r = r - alpha * q
            z = precond(r)
            rho_next = mnp.dot(r, z)
            beta = rho_next / rho
            p = z + beta * p
            rho = rho_next
            k += 1

        return x, F.select(_norm(r) > threshold, k, zero)
| def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None, callback=None): | |||
| """Use Conjugate Gradient iteration to solve the linear system: | |||
| @@ -343,7 +382,7 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None, callback=None | |||
| - `cg` is not supported on Windows platform yet. | |||
| Args: | |||
| A (Union[Tensor, function]): 2D Tensor or function that calculates the linear | |||
| A (Union[Tensor, CSRTensor, function]): 2D Tensor, CSRTensor or function that calculates the linear | |||
| map (matrix-vector product) :math:`Ax` when called like :math:`A(x)`. | |||
| As function, `A` must return Tensor with the same structure and shape as its input matrix. | |||
| b (Tensor): Right hand side of the linear system representing a single vector. Can be | |||
| @@ -372,8 +411,8 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None, callback=None | |||
| TypeError: If `atol` is not float. | |||
| TypeError: If `maxiter` is not int. | |||
| ValueError: If `callback` is not None. | |||
| TypeError: If `A` is not Tensor or Function. | |||
| TypeError: If `M` is not None, Tensor or Function. | |||
| TypeError: If `A` is not Tensor, CSRTensor, or Function. | |||
| TypeError: If `M` is not None, Tensor, CSRTensor, or Function. | |||
| TypeError: If `b` is not Tensor. | |||
| TypeError: If `x0` is not None or Tensor. | |||
| ValueError: If `b` is not 1 or 2 dimension. | |||
| @@ -411,9 +450,12 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None, callback=None | |||
| _type_check(func_name, atol, float, 'atol') | |||
| _type_check(func_name, maxiter, int, 'maxiter') | |||
| _value_check(func_name, callback, None, 'callback', op='is', fmt='todo') | |||
| _sparse_check(func_name, A, M, b, x0) | |||
| A, M, b, x0 = _sparse_check(func_name, A, M, b, x0) | |||
| x, info = CG(A, M)(b, x0, tol, atol, maxiter) | |||
| if not _nullable_const(A): | |||
| x, info = CG(A, M)(b, x0, tol, atol, maxiter) | |||
| else: | |||
| x, info = CGv2()(A, M, b, x0, tol, atol, maxiter) | |||
| return x, info | |||
| @@ -17,7 +17,7 @@ from .. import ops | |||
| from .. import numpy as mnp | |||
| from ..numpy import where, zeros_like, dot, greater | |||
| from ..ops import functional as F | |||
| from ..common import Tensor | |||
| from ..common import Tensor, CSRTensor | |||
| from ..common import dtype as mstype | |||
| from .utils_const import _type_convert, _raise_value_error, _callable_const, _super_check, pack | |||
| from ..ops.composite import GradOperation | |||
| @@ -85,19 +85,22 @@ def _safe_normalize(x, threshold=None): | |||
| return normalized_x, norm | |||
def sparse_dot(a, b):
    """
    Returns the product of CSRTensor `a` and generic Tensor `b`.

    `b` is reshaped to two dimensions that keep its leading axis before `F.csr_mv` is applied;
    the product is then reshaped to `a.shape[:-1] + b.shape[1:]`.
    """
    out_shape = a.shape[:-1] + b.shape[1:]
    product = F.csr_mv(a, F.reshape(b, (b.shape[0], -1)))
    return F.reshape(product, out_shape)
| def _normalize_matvec(f): | |||
| """Normalize an argument for computing matrix-vector products.""" | |||
| if _callable_const(F.typeof(f)): | |||
| return f | |||
| if isinstance(f, Tensor): | |||
| if f.ndim != 2 or f.shape[0] != f.shape[1]: | |||
| _raise_value_error( | |||
| 'linear operator must be a square matrix, but has shape: ', f.shape, ".") | |||
| return F.partial(dot, f) | |||
| _raise_value_error( | |||
| 'linear operator must be either a function or Tensor: but got ', F.typeof(f), ".") | |||
| if isinstance(f, CSRTensor): | |||
| return F.partial(sparse_dot, f) | |||
| return f | |||
| @@ -119,11 +122,11 @@ def _nd_transpose(a): | |||
def _value_check(func_name, arg1, arg2, arg_name='', attr_name='', op="in", fmt="attr", msg=None):
    """
    Check the value `arg1` against `arg2` by delegating to `_super_check`.

    By default the comparison is `op="in"` and the message format is `fmt="attr"`.
    `func_name`, `arg_name` and `attr_name` are forwarded for building the error message.
    """
    # Only the `pack(...)` form is kept: the stale tuple-based return that preceded it made
    # this line unreachable.
    # NOTE(review): the trailing True presumably selects ValueError over TypeError - confirm
    # against `_super_check`.
    return _super_check(pack(arg1, arg2), (func_name, arg_name, attr_name), op, fmt, msg, True)
def _type_check(func_name, arg1, arg2, arg_name='', op="isinstance", fmt="type", msg=None):
    """
    Check the type of `arg1` against `arg2` by delegating to `_super_check`.

    By default the comparison is `op="isinstance"` and the message format is `fmt="type"`.
    `func_name` and `arg_name` are forwarded for building the error message.
    """
    # Only the `pack(...)` form is kept: the stale tuple-based return that preceded it made
    # this line unreachable.
    # NOTE(review): the trailing False presumably selects TypeError over ValueError - confirm
    # against `_super_check`.
    return _super_check(pack(arg1, arg2), (func_name, arg_name), op, fmt, msg, False)
| def _mstype_check(func_name, arg, arg_mstype, arg_name='a'): | |||
| @@ -132,24 +135,25 @@ def _mstype_check(func_name, arg, arg_mstype, arg_name='a'): | |||
def _dtype_check(func_name, arg, arg_dtype, arg_name='a'):
    """
    Check that the data type of `arg` is one of `arg_dtype`, delegating to `_super_check`
    with the "in" comparison.
    """
    # The dtype is read through `F.dtype`; the stale return based on `arg.dtype` that preceded
    # this line made it unreachable.
    return _super_check((F.dtype(arg), arg_dtype), (func_name, arg_name, "data type"), "in", "attr", None, False)
def _square_check(func_name, arg, arg_name='a'):
    """
    Check that `arg` is a 2-dimensional square matrix.

    Two `_super_check` calls are made: one comparing the number of dimensions with 2, and one
    using the 'square' format on the shape itself.

    Returns:
        `arg`, unchanged.
    """
    arg_shape = arg.shape
    _super_check((len(arg_shape), 2), (func_name, arg_name, 'dimension'), '==', 'attr', None, True)
    _super_check(arg_shape, (func_name, arg_name), '==', 'square', None, True)
    # Return the checked argument; the stale `return func_name` that preceded this line made
    # it unreachable.
    return arg
def _solve_check(func_name, arg1, arg2, arg1_name='a', arg2_name='b', sparse=False):
    """
    Check that `arg1` and `arg2` form a consistent linear system.

    Checks performed, in order: `arg1` is a square matrix, `arg2` has 1 or 2 dimensions, the
    two shapes pass the 'solve' check, and the two data types match.

    Returns:
        Tuple `(arg1, arg2)`, unchanged.
    """
    # Data types are read through `F.dtype` only; the duplicated `.dtype`-based assignments
    # were immediately overwritten and have been dropped.
    arg1_shape, arg1_dtype = arg1.shape, F.dtype(arg1)
    arg2_shape, arg2_dtype = arg2.shape, F.dtype(arg2)
    _square_check(func_name, arg1, arg1_name)
    _super_check((len(arg2_shape), (1, 2)), (func_name, arg2_name, 'dimension'), 'in', 'attr', None, True)
    _super_check((arg1_shape, arg2_shape), (func_name, arg1_name, arg2_name, sparse), 'solve', 'solve', None, True)
    _super_check((arg1_dtype, arg2_dtype), (func_name, arg1_name, arg2_name, 'data type'), '==', 'match', None, False)
    # Return the checked operands; the stale `return func_name` that preceded this line made
    # it unreachable.
    return arg1, arg2
| def _sparse_check(func_name, a, m, b, x0): | |||
| """Used for cg, bicgstab and gmres method.""" | |||
| @@ -163,14 +167,28 @@ def _sparse_check(func_name, a, m, b, x0): | |||
| if b.ndim != 1 or (b.ndim == 2 and b.shape[1] != 1): | |||
| _raise_value_error( | |||
| "For: '", func_name, "', the shape of b should be like (N,) or (N, 1), bug got ", b.shape, ".") | |||
| _super_check((b.dtype, [mstype.int32, mstype.int64, mstype.float32, mstype.float64]), | |||
| (func_name, 'b', "data type"), "in", "attr", None, False) | |||
| _dtype_check(func_name, b, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], 'b') | |||
| _super_check((b.dtype, x0.dtype), (func_name, 'b', 'x0', 'data type'), '==', 'match', None, True) | |||
| _super_check((b.shape, x0.shape), (func_name, 'b', 'x0', 'shape'), '==', 'match', None, True) | |||
| if not _callable_const(F.typeof(a)): | |||
| _solve_check(func_name, a, b, 'A', 'b', True) | |||
| def _check(arg, arg_name): | |||
| if _callable_const(F.typeof(arg)): | |||
| return arg | |||
| if not _callable_const(F.typeof(m)): | |||
| _solve_check(func_name, m, b, 'M', 'b', True) | |||
| return func_name | |||
| _solve_check(func_name, arg, b, arg_name, 'b', True) | |||
| if isinstance(arg, CSRTensor): | |||
| _dtype_check(func_name, arg.indptr, [mstype.int32], arg_name) | |||
| _dtype_check(func_name, arg.indices, [mstype.int32], arg_name) | |||
| _dtype_check(func_name, arg.values, [mstype.float32], arg_name) | |||
| else: | |||
| _dtype_check(func_name, arg, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], arg_name) | |||
| if F.dtype(arg) in (mstype.int32, mstype.int64): | |||
| arg = F.cast(arg, mstype.float64) | |||
| return arg | |||
| a = _check(a, 'A') | |||
| m = _check(m, 'M') | |||
| if F.dtype(b) in (mstype.int32, mstype.int64): | |||
| b = F.cast(b, mstype.float64) | |||
| x0 = F.cast(x0, mstype.float64) | |||
| return a, m, b, x0 | |||
| @@ -26,6 +26,15 @@ def _callable_const(x): | |||
| return isinstance(x, mstype.function_type) | |||
@constexpr
def _nullable_const(x):
    """
    Returns True if `x` is None.

    Its purpose is to detect whether the call happens inside a MindSpore graph: in graph mode,
    a MindSpore variable is seen as None inside a constexpr function.
    """
    return x is None
| @constexpr | |||
| def _type_convert(new_type, obj): | |||
| """ | |||
| @@ -18,10 +18,11 @@ import numpy as onp | |||
| import scipy as osp | |||
| import scipy.sparse.linalg | |||
| import mindspore.nn as nn | |||
| import mindspore.scipy as msp | |||
| from mindspore import context | |||
| from mindspore.common import Tensor | |||
| from tests.st.scipy_st.utils import create_sym_pos_matrix, create_full_rank_matrix | |||
| from mindspore.common import Tensor, CSRTensor | |||
| from tests.st.scipy_st.utils import create_sym_pos_matrix, create_full_rank_matrix, create_sym_pos_sparse_matrix | |||
| def _fetch_preconditioner(preconditioner, A): | |||
| @@ -46,38 +47,39 @@ def _fetch_preconditioner(preconditioner, A): | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| @pytest.mark.parametrize('dtype_tol', [(onp.float32, 1e-5), (onp.float64, 1e-12)]) | |||
| @pytest.mark.parametrize('dtype, tol', [(onp.float32, 1e-5), (onp.float64, 1e-12)]) | |||
| @pytest.mark.parametrize('shape', [(4, 4), (7, 7)]) | |||
| @pytest.mark.parametrize('preconditioner', [None, 'identity', 'exact', 'random']) | |||
| @pytest.mark.parametrize('maxiter', [1, 3]) | |||
| def test_cg_against_scipy(dtype_tol, shape, preconditioner, maxiter): | |||
| def test_cg_against_scipy(dtype, tol, shape, preconditioner, maxiter): | |||
| """ | |||
| Feature: ALL TO ALL | |||
| Description: test cases for cg | |||
| Expectation: the result match scipy | |||
| """ | |||
| onp.random.seed(0) | |||
| dtype, tol = dtype_tol | |||
| A = create_sym_pos_matrix(shape, dtype) | |||
| a = create_sym_pos_matrix(shape, dtype) | |||
| b = onp.random.random(shape[:1]).astype(dtype) | |||
| M = _fetch_preconditioner(preconditioner, A) | |||
| osp_res = scipy.sparse.linalg.cg(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0] | |||
| m = _fetch_preconditioner(preconditioner, a) | |||
| osp_res = scipy.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol) | |||
| A = Tensor(A) | |||
| a = Tensor(a) | |||
| b = Tensor(b) | |||
| M = Tensor(M) if M is not None else M | |||
| m = Tensor(m) if m is not None else m | |||
| # using PYNATIVE MODE | |||
| context.set_context(mode=context.PYNATIVE_MODE) | |||
| msp_res_dyn = msp.sparse.linalg.cg(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0] | |||
| msp_res_dyn = msp.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol) | |||
| # using GRAPH MODE | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| msp_res_sta = msp.sparse.linalg.cg(A, b, M=M, maxiter=maxiter, atol=tol, tol=tol)[0] | |||
| msp_res_sta = msp.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol) | |||
| kw = {"atol": tol, "rtol": tol} | |||
| onp.testing.assert_allclose(osp_res, msp_res_dyn.asnumpy(), **kw) | |||
| onp.testing.assert_allclose(osp_res, msp_res_sta.asnumpy(), **kw) | |||
| onp.testing.assert_allclose(osp_res[0], msp_res_dyn[0].asnumpy(), **kw) | |||
| onp.testing.assert_allclose(osp_res[0], msp_res_sta[0].asnumpy(), **kw) | |||
| assert osp_res[1] == msp_res_dyn[1].asnumpy().item() | |||
| assert osp_res[1] == msp_res_sta[1].asnumpy().item() | |||
| @pytest.mark.level0 | |||
| @@ -93,23 +95,97 @@ def test_cg_against_numpy(dtype, shape): | |||
| Expectation: the result match numpy | |||
| """ | |||
| onp.random.seed(0) | |||
| A = create_sym_pos_matrix(shape, dtype) | |||
| a = create_sym_pos_matrix(shape, dtype) | |||
| b = onp.random.random(shape[:1]).astype(dtype) | |||
| expected = onp.linalg.solve(A, b) | |||
| expected = onp.linalg.solve(a, b) | |||
| # using PYNATIVE MODE | |||
| context.set_context(mode=context.PYNATIVE_MODE) | |||
| actual_dyn, _ = msp.sparse.linalg.cg(Tensor(A), Tensor(b)) | |||
| actual_dyn, _ = msp.sparse.linalg.cg(Tensor(a), Tensor(b)) | |||
| # using GRAPH MODE | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| actual_sta, _ = msp.sparse.linalg.cg(Tensor(A), Tensor(b)) | |||
| actual_sta, _ = msp.sparse.linalg.cg(Tensor(a), Tensor(b)) | |||
| kw = {"atol": 1e-5, "rtol": 1e-5} | |||
| onp.testing.assert_allclose(expected, actual_dyn.asnumpy(), **kw) | |||
| onp.testing.assert_allclose(expected, actual_sta.asnumpy(), **kw) | |||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype, tol', [(onp.float32, 1e-5), (onp.float64, 1e-12)])
@pytest.mark.parametrize('shape', [(7, 7)])
@pytest.mark.parametrize('preconditioner', [None, 'identity', 'exact', 'random'])
@pytest.mark.parametrize('maxiter', [3])
def test_cg_against_scipy_graph(dtype, tol, shape, preconditioner, maxiter):
    """
    Feature: ALL TO ALL
    Description: test cases for cg called from inside a Cell object
    Expectation: the solution and the info code match scipy
    """
    context.set_context(mode=context.GRAPH_MODE)

    class TestNet(nn.Cell):
        def construct(self, a, b, m, maxiter, tol):
            return msp.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol)

    onp.random.seed(0)

    # Reference result from scipy on the numpy operands.
    dense_a = create_sym_pos_matrix(shape, dtype)
    rhs = onp.random.random(shape[:1]).astype(dtype)
    precond = _fetch_preconditioner(preconditioner, dense_a)
    expected = scipy.sparse.linalg.cg(dense_a, rhs, M=precond, maxiter=maxiter, atol=tol, tol=tol)

    # Same system solved through the Cell wrapper.
    ms_a = Tensor(dense_a)
    ms_rhs = Tensor(rhs)
    ms_precond = None if precond is None else Tensor(precond)
    actual = TestNet()(ms_a, ms_rhs, ms_precond, maxiter, tol)

    onp.testing.assert_allclose(expected[0], actual[0].asnumpy(), atol=tol, rtol=tol)
    assert expected[1] == actual[1].asnumpy().item()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype, tol', [(onp.float32, 1e-5)])
@pytest.mark.parametrize('shape', [(7, 7)])
@pytest.mark.parametrize('preconditioner', [None, 'identity', 'random'])
@pytest.mark.parametrize('maxiter', [3])
def test_cg_against_scipy_sparse(dtype, tol, shape, preconditioner, maxiter):
    """
    Feature: ALL TO ALL
    Description: test cases for cg with a CSRTensor operator
    Expectation: the solution and the info code match scipy
    """
    context.set_context(mode=context.GRAPH_MODE)

    class TestNet(nn.Cell):
        def construct(self, a, b, m, maxiter, tol):
            return msp.sparse.linalg.cg(a, b, M=m, maxiter=maxiter, atol=tol, tol=tol)

    onp.random.seed(0)

    # Reference result from scipy on the CSR matrix.
    csr_a = create_sym_pos_sparse_matrix(shape, dtype)
    rhs = onp.random.random(shape[:1]).astype(dtype)
    precond = _fetch_preconditioner(preconditioner, csr_a)
    expected = scipy.sparse.linalg.cg(csr_a, rhs, M=precond, maxiter=maxiter, atol=tol, tol=tol)

    # Same system with the operator rebuilt as a MindSpore CSRTensor.
    ms_a = CSRTensor(Tensor(csr_a.indptr), Tensor(csr_a.indices), Tensor(csr_a.data), shape)
    ms_rhs = Tensor(rhs)
    ms_precond = None if precond is None else Tensor(precond)
    actual = TestNet()(ms_a, ms_rhs, ms_precond, maxiter, tol)

    onp.testing.assert_allclose(expected[0], actual[0].asnumpy(), atol=tol, rtol=tol)
    assert expected[1] == actual[1].asnumpy().item()
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @@ -17,6 +17,7 @@ from typing import List | |||
| from functools import cmp_to_key | |||
| import numpy as onp | |||
| import scipy.sparse.linalg | |||
| from mindspore import Tensor | |||
| import mindspore.ops as ops | |||
| import mindspore.numpy as mnp | |||
| @@ -98,6 +99,18 @@ def create_sym_pos_matrix(shape, dtype): | |||
| return (onp.matmul(x, x.T) + onp.eye(n)).astype(dtype) | |||
def create_sym_pos_sparse_matrix(shape, dtype, indice_dtype=onp.int32):
    """
    Build a random diagonal matrix in scipy CSR format.

    Args:
        shape (tuple): Shape of the matrix; must describe a square 2-D matrix.
        dtype: Data type of the stored values.
        indice_dtype: Data type used to build `indptr` and `indices`. Default: onp.int32.

    Returns:
        scipy.sparse.csr_matrix with one stored value per row, placed on the main diagonal and
        drawn from `onp.random.random`.

    Raises:
        ValueError: If `shape` is not 2-dimensional or is not square.
    """
    if len(shape) != 2 or shape[0] != shape[1]:
        # Format the shape into the message: passing it as a second argument makes the
        # exception render as a tuple instead of a readable sentence.
        raise ValueError(
            'Symmetric positive definite matrix must be a square matrix, but has shape: {}'.format(shape))
    n = shape[-1]
    # Row i owns exactly one entry, stored in column i.
    indptr = onp.arange(n + 1).astype(indice_dtype)
    indices = onp.arange(n).astype(indice_dtype)
    values = onp.random.random(n).astype(dtype)
    return scipy.sparse.csr_matrix((values, indices, indptr), shape=shape)
| def gradient_check(x, net, epsilon=1e-3, enumerate_fn=onp.ndenumerate): | |||
| # some utils | |||
| def _tensor_to_numpy(arg: List[Tensor]) -> List[onp.ndarray]: | |||