Browse Source

!25800 Suport solve complex64/ complex128 Eigen Value and eigen vectros

Merge pull request !25800 from wuwenbing/master
tags/v1.6.0
i-robot Gitee 4 years ago
parent
commit
4771a2006e
3 changed files with 141 additions and 23 deletions
  1. +62
    -17
      mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/eig_cpu_kernel.cc
  2. +2
    -4
      mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/eig_cpu_kernel.h
  3. +77
    -2
      tests/st/ops/cpu/test_solve_eigh_value_op.py

+ 62
- 17
mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/eig_cpu_kernel.cc View File

@@ -15,6 +15,7 @@
*/
#include "backend/kernel_compiler/cpu/eigen/eig_cpu_kernel.h"
#include <Eigen/Eigenvalues>
#include <type_traits>
#include "utils/ms_utils.h"

namespace mindspore {
@@ -34,9 +35,13 @@ using Eigen::Map;
using Eigen::MatrixBase;
using Eigen::RowMajor;
using Eigen::Upper;

template <typename T>
using MatrixSquare = Eigen::Matrix<T, Dynamic, Dynamic, RowMajor>;

template <typename T>
using ComplexMatrixSquare = Eigen::Matrix<std::complex<T>, Dynamic, Dynamic, RowMajor>;

template <typename T, typename C>
void EighCPUKernel<T, C>::InitKernel(const CNodePtr &kernel_node) {
MS_LOG(INFO) << "init eigen value kernel";
@@ -55,6 +60,50 @@ void EighCPUKernel<T, C>::InitKernel(const CNodePtr &kernel_node) {
m_ = A_shape[kDim0];
}

template <typename T, typename C>
bool SolveSelfAdjointMatrix(Map<MatrixSquare<T>> *A, Map<MatrixSquare<C>> *output, Map<MatrixSquare<C>> *outputv,
bool compute_eigen_vectors) {
Eigen::SelfAdjointEigenSolver<MatrixSquare<T>> solver(*A);
output->noalias() = solver.eigenvalues();
if (compute_eigen_vectors) {
outputv->noalias() = solver.eigenvectors();
}
return true;
}

template <typename T, typename C>
bool SolveGenericMatrix(Map<MatrixSquare<T>> *A, Map<MatrixSquare<C>> *output, Map<MatrixSquare<C>> *outputv,
bool compute_eigen_vectors) {
Eigen::EigenSolver<MatrixSquare<T>> solver(*A);
output->noalias() = solver.eigenvalues();
if (compute_eigen_vectors) {
outputv->noalias() = solver.eigenvectors();
}
return true;
}

template <typename T, typename C>
bool SolveRealMatrix(int symmetric_type, Map<MatrixSquare<T>> *A, Map<MatrixSquare<C>> *output,
Map<MatrixSquare<C>> *outputv, bool compute_eigen_vectors) {
if (symmetric_type != 0) {
return SolveSelfAdjointMatrix(A, output, outputv, compute_eigen_vectors);
} else {
// this is for none symmetric matrix eigenvalue and eigen vectors, it should support complex
return SolveGenericMatrix(A, output, outputv, compute_eigen_vectors);
}
}

template <typename T, typename C>
bool SolveComplexMatrix(Map<MatrixSquare<T>> *A, Map<MatrixSquare<C>> *output, Map<MatrixSquare<C>> *outputv,
bool compute_eigen_vectors) {
Eigen::ComplexEigenSolver<MatrixSquare<T>> solver(*A);
output->noalias() = solver.eigenvalues();
if (compute_eigen_vectors) {
outputv->noalias() = solver.eigenvectors();
}
return true;
}

template <typename T, typename C>
bool EighCPUKernel<T, C>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
@@ -69,24 +118,20 @@ bool EighCPUKernel<T, C>::Launch(const std::vector<AddressPtr> &inputs, const st
Map<MatrixSquare<T>> A(A_addr, m_, m_);
Map<MatrixSquare<C>> output(output_addr, m_, 1);
Map<MatrixSquare<C>> outputv(output_v_addr, m_, m_);
if (*symmetric_type != 0) {
if (*symmetric_type < 0) {
A = A.template selfadjointView<Lower>();
} else {
A = A.template selfadjointView<Upper>();
}
Eigen::SelfAdjointEigenSolver<MatrixSquare<T>> solver(A);
output.noalias() = solver.eigenvalues();
if (compute_eigen_vectors) {
outputv.noalias() = solver.eigenvectors();
}
// selfadjoint matrix
if (*symmetric_type < 0) {
A = A.template selfadjointView<Lower>();
} else if (*symmetric_type > 0) {
A = A.template selfadjointView<Upper>();
}
// Real scalar eigen solver
if constexpr (std::is_same_v<T, float>) {
SolveRealMatrix(*symmetric_type, &A, &output, &outputv, compute_eigen_vectors);
} else if constexpr (std::is_same_v<T, double>) {
SolveRealMatrix(*symmetric_type, &A, &output, &outputv, compute_eigen_vectors);
} else {
// this is for none symmetric matrix eigenvalue and eigen vectors, it should support complex
Eigen::EigenSolver<MatrixSquare<T>> solver(A);
output.noalias() = solver.eigenvalues();
if (compute_eigen_vectors) {
outputv.noalias() = solver.eigenvectors();
}
// complex eigen solver
SolveComplexMatrix(&A, &output, &outputv, compute_eigen_vectors);
}
return true;
}


+ 2
- 4
mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/eig_cpu_kernel.h View File

@@ -27,8 +27,6 @@ namespace kernel {

using float_complex = std::complex<float>;
using double_complex = std::complex<double>;
using c_float_complex = std::complex<float>;
using c_double_complex = std::complex<double>;

template <typename T, typename C>
class EighCPUKernel : public CPUKernel {
@@ -66,14 +64,14 @@ MS_REG_CPU_KERNEL_T_S(Eigh,
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64),
EighCPUKernel, float, c_float_complex);
EighCPUKernel, float_complex, float_complex);
MS_REG_CPU_KERNEL_T_S(Eigh,
KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128),
EighCPUKernel, double, c_double_complex);
EighCPUKernel, double_complex, double_complex);
} // namespace kernel
} // namespace mindspore



+ 77
- 2
tests/st/ops/cpu/test_solve_eigh_value_op.py View File

@@ -53,6 +53,18 @@ class Eigh(PrimitiveWithInfer):
'dtype': (msp.complex128, msp.complex128),
'value': None
}
elif A['dtype'] == msp.tensor_type(msp.dtype.complex64):
shape = {
'shape': ((A['shape'][0],), (A['shape'][0], A['shape'][0])),
'dtype': (msp.complex64, msp.complex64),
'value': None
}
elif A['dtype'] == msp.tensor_type(msp.dtype.complex128):
shape = {
'shape': ((A['shape'][0],), (A['shape'][0], A['shape'][0])),
'dtype': (msp.complex128, msp.complex128),
'value': None
}
return shape


@@ -90,7 +102,7 @@ def test_eigh_net(n: int, mode):
Description: test cases for eigen decomposition test cases for Ax= lambda * x /( A- lambda * E)X=0
Expectation: the result match to numpy
"""
context.set_context(mode=mode, device_target="CPU")
# test for real scalar float 32
rtol = 1e-4
atol = 1e-5
msp_eigh = EighNet(True)
@@ -99,6 +111,7 @@ def test_eigh_net(n: int, mode):
msp_w, msp_v = msp_eigh(tensor_a, -1)
assert np.allclose(A @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy()), np.zeros((n, n)), rtol, atol)

# test case for real scalar double 64
A = np.random.rand(n, n)
rtol = 1e-5
atol = 1e-8
@@ -109,8 +122,8 @@ def test_eigh_net(n: int, mode):

# Compare with scipy
# sp_w, sp_v = sp.linalg.eig(A.astype(np.float64))
# sp_wl, sp_vl = sp.linalg.eigh(np.tril(A).astype(np.float64), lower=True, eigvals_only=False)
# sp_wu, sp_vu = sp.linalg.eigh(A.astype(np.float64), lower=False, eigvals_only=False)
# p_wl, sp_vl = sp.linalg.eigh(np.tril(A).astype(np.float64), lower=True, eigvals_only=False)

sym_Al = (np.tril((np.tril(A) - np.tril(A).T)) + np.tril(A).T)
sym_Au = (np.triu((np.triu(A) - np.triu(A).T)) + np.triu(A).T)
@@ -119,3 +132,65 @@ def test_eigh_net(n: int, mode):
assert np.allclose(sym_Au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ np.diag(msp_wu.asnumpy()), np.zeros((n, n)), rtol,
atol)
assert np.allclose(A @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy()), np.zeros((n, n)), rtol, atol)

# test case for complex64
rtol = 1e-4
atol = 1e-5
A = np.array(np.random.rand(n, n), dtype=np.complex64)
for i in range(0, n):
for j in range(0, n):
if i == j:
A[i][j] = complex(np.random.rand(1, 1), 0)
else:
A[i][j] = complex(np.random.rand(1, 1), np.random.rand(1, 1))
msp_eigh = EighNet(True)
sym_Al = (np.tril((np.tril(A) - np.tril(A).T)) + np.tril(A).conj().T)
sym_Au = (np.triu((np.triu(A) - np.triu(A).T)) + np.triu(A).conj().T)
msp_w, msp_v = msp_eigh(Tensor(np.array(A).astype(np.complex64)), 0)
msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.complex64)), -1)
msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.complex64)), 1)
# Compare with scipy, scipy passed
# sp_w, sp_v = sp.linalg.eig(A.astype(np.complex128))
# sp_wl, sp_vl = sp.linalg.eigh(np.tril(A).astype(np.complex128), lower=True, eigvals_only=False)
# sp_wu, sp_vu = sp.linalg.eigh(A.astype(np.complex128), lower=False, eigvals_only=False)
# assert np.allclose(A @ sp_v - sp_v @ np.diag(sp_w), np.zeros((n, n)), rtol, atol)
# assert np.allclose(sym_Al @ sp_vl - sp_vl @ np.diag(sp_wl), np.zeros((n, n)), rtol, atol)
# assert np.allclose(sym_Au @ sp_vu - sp_vu @ np.diag(sp_wu), np.zeros((n, n)), rtol, atol)

# print(A @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy()))
assert np.allclose(sym_Al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ np.diag(msp_wl.asnumpy()), np.zeros((n, n)), rtol,
atol)
assert np.allclose(sym_Au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ np.diag(msp_wu.asnumpy()), np.zeros((n, n)), rtol,
atol)
assert np.allclose(A @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy()), np.zeros((n, n)), rtol, atol)

# test for complex128
rtol = 1e-5
atol = 1e-8
A = np.array(np.random.rand(n, n), dtype=np.complex128)
for i in range(0, n):
for j in range(0, n):
if i == j:
A[i][j] = complex(np.random.rand(1, 1), 0)
else:
A[i][j] = complex(np.random.rand(1, 1), np.random.rand(1, 1))
msp_eigh = EighNet(True)
sym_Al = (np.tril((np.tril(A) - np.tril(A).T)) + np.tril(A).conj().T)
sym_Au = (np.triu((np.triu(A) - np.triu(A).T)) + np.triu(A).conj().T)
msp_w, msp_v = msp_eigh(Tensor(np.array(A).astype(np.complex128)), 0)
msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.complex128)), -1)
msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.complex128)), 1)
# Compare with scipy, scipy passed
# sp_w, sp_v = sp.linalg.eig(A.astype(np.complex128))
# sp_wl, sp_vl = sp.linalg.eigh(np.tril(A).astype(np.complex128), lower=True, eigvals_only=False)
# sp_wu, sp_vu = sp.linalg.eigh(A.astype(np.complex128), lower=False, eigvals_only=False)
# assert np.allclose(A @ sp_v - sp_v @ np.diag(sp_w), np.zeros((n, n)), rtol, atol)
# assert np.allclose(sym_Al @ sp_vl - sp_vl @ np.diag(sp_wl), np.zeros((n, n)), rtol, atol)
# assert np.allclose(sym_Au @ sp_vu - sp_vu @ np.diag(sp_wu), np.zeros((n, n)), rtol, atol)

# print(A @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy()))
assert np.allclose(sym_Al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ np.diag(msp_wl.asnumpy()), np.zeros((n, n)), rtol,
atol)
assert np.allclose(sym_Au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ np.diag(msp_wu.asnumpy()), np.zeros((n, n)), rtol,
atol)
assert np.allclose(A @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy()), np.zeros((n, n)), rtol, atol)

Loading…
Cancel
Save