Browse Source

use workspace memory for symmetric matrix eigenvalue/eigenvectors

split generic matrix decomposition and symmetric matrix decomposition into separate eig and eigh ops
tags/v1.6.0
wenbean 4 years ago
parent
commit
45f9c7bfb7
6 changed files with 428 additions and 85 deletions
  1. +17
    -40
      mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/eig_cpu_kernel.cc
  2. +24
    -22
      mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/eig_cpu_kernel.h
  3. +124
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/eigh_cpu_kernel.cc
  4. +86
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/eigh_cpu_kernel.h
  5. +161
    -0
      tests/st/ops/cpu/test_solve_eig_value_op.py
  6. +16
    -23
      tests/st/ops/cpu/test_solve_eigh_value_op.py

+ 17
- 40
mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/eig_cpu_kernel.cc View File

@@ -22,7 +22,7 @@ namespace mindspore {
namespace kernel {

namespace {
constexpr size_t kInputsNum = 2;
constexpr size_t kInputsNum = 1;
constexpr size_t kOutputsNum = 2;
constexpr size_t kDefaultShape = 1;
constexpr auto kAMatrixDimNum = 2;
@@ -43,8 +43,8 @@ template <typename T>
using ComplexMatrixSquare = Eigen::Matrix<std::complex<T>, Dynamic, Dynamic, RowMajor>;

template <typename T, typename C>
void EighCPUKernel<T, C>::InitKernel(const CNodePtr &kernel_node) {
MS_LOG(INFO) << "init eigen value kernel";
void EigCPUKernel<T, C>::InitKernel(const CNodePtr &kernel_node) {
MS_LOG(INFO) << "init eig cpu kernel";
MS_EXCEPTION_IF_NULL(kernel_node);
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);

@@ -54,27 +54,22 @@ void EighCPUKernel<T, C>::InitKernel(const CNodePtr &kernel_node) {
CHECK_KERNEL_INPUTS_NUM(A_shape.size(), kAMatrixDimNum, AnfAlgo::GetCNodeName(kernel_node));

if (A_shape[kDim0] != A_shape[kDim1]) {
MS_LOG(EXCEPTION) << "wrong array shape, A should be a matrix, but got [" << A_shape[kDim0] << " X "
MS_LOG(EXCEPTION) << "wrong array shape, A should be a matrix, but got [" << A_shape[kDim0] << " X "
<< A_shape[kDim1] << "]";
}
m_ = A_shape[kDim0];
}

template <typename T, typename C>
bool SolveSelfAdjointMatrix(Map<MatrixSquare<T>> *A, Map<MatrixSquare<C>> *output, Map<MatrixSquare<C>> *outputv,
bool compute_eigen_vectors) {
Eigen::SelfAdjointEigenSolver<MatrixSquare<T>> solver(*A);
output->noalias() = solver.eigenvalues();
if (compute_eigen_vectors) {
outputv->noalias() = solver.eigenvectors();
}
return true;
void EigCPUKernel<T, C>::InitInputOutputSize(const CNodePtr &kernel_node) {
CPUKernel::InitInputOutputSize(kernel_node);
(void)workspace_size_list_.template emplace_back(m_ * m_ * sizeof(T));
}

template <typename T, typename C>
bool SolveGenericMatrix(Map<MatrixSquare<T>> *A, Map<MatrixSquare<C>> *output, Map<MatrixSquare<C>> *outputv,
bool compute_eigen_vectors) {
Eigen::EigenSolver<MatrixSquare<T>> solver(*A);
bool SolveGenericRealScalaMatrix(const Map<MatrixSquare<T>> &A, Map<MatrixSquare<C>> *output,
Map<MatrixSquare<C>> *outputv, bool compute_eigen_vectors) {
Eigen::EigenSolver<MatrixSquare<T>> solver(A);
output->noalias() = solver.eigenvalues();
if (compute_eigen_vectors) {
outputv->noalias() = solver.eigenvectors();
@@ -83,20 +78,9 @@ bool SolveGenericMatrix(Map<MatrixSquare<T>> *A, Map<MatrixSquare<C>> *output, M
}

template <typename T, typename C>
bool SolveRealMatrix(int symmetric_type, Map<MatrixSquare<T>> *A, Map<MatrixSquare<C>> *output,
Map<MatrixSquare<C>> *outputv, bool compute_eigen_vectors) {
if (symmetric_type != 0) {
return SolveSelfAdjointMatrix(A, output, outputv, compute_eigen_vectors);
} else {
// this is for none symmetric matrix eigenvalue and eigen vectors, it should support complex
return SolveGenericMatrix(A, output, outputv, compute_eigen_vectors);
}
}

template <typename T, typename C>
bool SolveComplexMatrix(Map<MatrixSquare<T>> *A, Map<MatrixSquare<C>> *output, Map<MatrixSquare<C>> *outputv,
bool SolveComplexMatrix(const Map<MatrixSquare<T>> &A, Map<MatrixSquare<C>> *output, Map<MatrixSquare<C>> *outputv,
bool compute_eigen_vectors) {
Eigen::ComplexEigenSolver<MatrixSquare<T>> solver(*A);
Eigen::ComplexEigenSolver<MatrixSquare<T>> solver(A);
output->noalias() = solver.eigenvalues();
if (compute_eigen_vectors) {
outputv->noalias() = solver.eigenvectors();
@@ -105,33 +89,26 @@ bool SolveComplexMatrix(Map<MatrixSquare<T>> *A, Map<MatrixSquare<C>> *output, M
}

template <typename T, typename C>
bool EighCPUKernel<T, C>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
bool EigCPUKernel<T, C>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);

auto A_addr = reinterpret_cast<T *>(inputs[0]->addr);
// is the Matrix a symmetric matrix (0: all / general matrix, -1: lower triangle, 1: upper triangle)
auto symmetric_type = reinterpret_cast<int *>(inputs[1]->addr);
auto output_addr = reinterpret_cast<C *>(outputs[0]->addr);
auto output_v_addr = reinterpret_cast<C *>(outputs[1]->addr);
Map<MatrixSquare<T>> A(A_addr, m_, m_);
Map<MatrixSquare<C>> output(output_addr, m_, 1);
Map<MatrixSquare<C>> outputv(output_v_addr, m_, m_);
// selfadjoint matrix
if (*symmetric_type < 0) {
A = A.template selfadjointView<Lower>();
} else if (*symmetric_type > 0) {
A = A.template selfadjointView<Upper>();
}
// Real scalar eigen solver
if constexpr (std::is_same_v<T, float>) {
SolveRealMatrix(*symmetric_type, &A, &output, &outputv, compute_eigen_vectors);
SolveGenericRealScalaMatrix(A, &output, &outputv, compute_eigen_vectors);
} else if constexpr (std::is_same_v<T, double>) {
SolveRealMatrix(*symmetric_type, &A, &output, &outputv, compute_eigen_vectors);
SolveGenericRealScalaMatrix(A, &output, &outputv, compute_eigen_vectors);
} else {
// complex eigen solver
SolveComplexMatrix(&A, &output, &outputv, compute_eigen_vectors);
SolveComplexMatrix(A, &output, &outputv, compute_eigen_vectors);
}
return true;
}


+ 24
- 22
mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/eig_cpu_kernel.h View File

@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGH_CPU_KERNEL_H
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGH_CPU_KERNEL_H
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_EIG_CPU_KERNEL_H
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_EIG_CPU_KERNEL_H

#include <vector>
#include <complex>
@@ -28,51 +28,53 @@ namespace kernel {
using float_complex = std::complex<float>;
using double_complex = std::complex<double>;

/**
* this is for Generic matrix eigenvalues and eigenvectors
* @tparam T , input Type
* @tparam C , output Type, complex
*/
template <typename T, typename C>
class EighCPUKernel : public CPUKernel {
class EigCPUKernel : public CPUKernel {
public:
EighCPUKernel() = default;
~EighCPUKernel() override = default;
EigCPUKernel() = default;
~EigCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;

protected:
void InitInputOutputSize(const CNodePtr &kernel_node);

private:
size_t m_{1};
bool compute_eigen_vectors{false};
TypeId dtype_{kNumberTypeFloat32};
};

MS_REG_CPU_KERNEL_T_S(Eigh,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64),
EighCPUKernel, float, float_complex);
MS_REG_CPU_KERNEL_T_S(Eigh,
MS_REG_CPU_KERNEL_T_S(
Eig,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
EigCPUKernel, float, float_complex);
MS_REG_CPU_KERNEL_T_S(Eig,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128),
EighCPUKernel, double, double_complex);
EigCPUKernel, double, double_complex);

MS_REG_CPU_KERNEL_T_S(Eigh,
MS_REG_CPU_KERNEL_T_S(Eig,
KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64),
EighCPUKernel, float_complex, float_complex);
MS_REG_CPU_KERNEL_T_S(Eigh,
EigCPUKernel, float_complex, float_complex);
MS_REG_CPU_KERNEL_T_S(Eig,
KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128),
EighCPUKernel, double_complex, double_complex);
EigCPUKernel, double_complex, double_complex);
} // namespace kernel
} // namespace mindspore

#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGH_CPU_KERNEL_H
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_EIG_CPU_KERNEL_H

+ 124
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/eigh_cpu_kernel.cc View File

@@ -0,0 +1,124 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/eigen/eigh_cpu_kernel.h"
#include <Eigen/Eigenvalues>
#include <type_traits>
#include "utils/ms_utils.h"

namespace mindspore {
namespace kernel {

namespace {
constexpr size_t kInputsNum = 2;
constexpr size_t kOutputsNum = 2;
constexpr size_t kDefaultShape = 1;
constexpr auto kAMatrixDimNum = 2;

} // namespace
using Eigen::Dynamic;
using Eigen::EigenSolver;
using Eigen::Lower;
using Eigen::Map;
using Eigen::MatrixBase;
using Eigen::RowMajor;
using Eigen::Upper;

template <typename T>
using MatrixSquare = Eigen::Matrix<T, Dynamic, Dynamic, RowMajor>;

template <typename T>
using ComplexMatrixSquare = Eigen::Matrix<std::complex<T>, Dynamic, Dynamic, RowMajor>;

template <typename T, typename C>
void EighCPUKernel<T, C>::InitKernel(const CNodePtr &kernel_node) {
MS_LOG(INFO) << "init eigh cpu kernel";
MS_EXCEPTION_IF_NULL(kernel_node);
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);

compute_eigen_vectors = AnfAlgo::GetNodeAttr<bool>(kernel_node, C_EIEH_VECTOR);

auto A_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
CHECK_KERNEL_INPUTS_NUM(A_shape.size(), kAMatrixDimNum, AnfAlgo::GetCNodeName(kernel_node));

if (A_shape[kDim0] != A_shape[kDim1]) {
MS_LOG(EXCEPTION) << "wrong array shape, A should be a matrix, but got [" << A_shape[kDim0] << " X "
<< A_shape[kDim1] << "]";
}
m_ = A_shape[kDim0];
}

template <typename T, typename C>
void EighCPUKernel<T, C>::InitInputOutputSize(const CNodePtr &kernel_node) {
CPUKernel::InitInputOutputSize(kernel_node);
(void)workspace_size_list_.template emplace_back(m_ * m_ * sizeof(T));
}

template <typename T, typename C>
bool SolveSelfAdjointMatrix(const Map<MatrixSquare<T>> &A, Map<MatrixSquare<C>> *output, Map<MatrixSquare<C>> *outputv,
bool compute_eigen_vectors) {
Eigen::SelfAdjointEigenSolver<MatrixSquare<T>> solver(A);
output->noalias() = solver.eigenvalues();
if (compute_eigen_vectors) {
outputv->noalias() = solver.eigenvectors();
}
return true;
}

template <typename T, typename C>
bool SolveComplexMatrix(const Map<MatrixSquare<T>> &A, Map<MatrixSquare<C>> *output, Map<MatrixSquare<C>> *outputv,
bool compute_eigen_vectors) {
Eigen::ComplexEigenSolver<MatrixSquare<T>> solver(A);
output->noalias() = solver.eigenvalues();
if (compute_eigen_vectors) {
outputv->noalias() = solver.eigenvectors();
}
return true;
}

template <typename T, typename C>
bool EighCPUKernel<T, C>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);

auto A_addr = reinterpret_cast<T *>(inputs[0]->addr);
// which triangle of the symmetric matrix to use (true: lower triangle, false: upper triangle)
auto symmetric_type = reinterpret_cast<bool *>(inputs[1]->addr);
auto output_addr = reinterpret_cast<C *>(outputs[0]->addr);
auto output_v_addr = reinterpret_cast<C *>(outputs[1]->addr);
Map<MatrixSquare<T>> A(A_addr, m_, m_);
Map<MatrixSquare<T>> A_(A_addr, m_, m_);
Map<MatrixSquare<C>> output(output_addr, m_, 1);
Map<MatrixSquare<C>> outputv(output_v_addr, m_, m_);
// selfadjoint matrix
if (*symmetric_type) {
A_ = A.template selfadjointView<Lower>();
} else {
A_ = A.template selfadjointView<Upper>();
}
// Real scalar eigen solver
if constexpr (std::is_same_v<T, float>) {
SolveSelfAdjointMatrix(A_, &output, &outputv, compute_eigen_vectors);
} else if constexpr (std::is_same_v<T, double>) {
SolveSelfAdjointMatrix(A_, &output, &outputv, compute_eigen_vectors);
} else {
// complex eigen solver
SolveComplexMatrix(A_, &output, &outputv, compute_eigen_vectors);
}
return true;
}
} // namespace kernel
} // namespace mindspore

+ 86
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/eigh_cpu_kernel.h View File

@@ -0,0 +1,86 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_EIGH_CPU_KERNEL_H
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_EIGH_CPU_KERNEL_H

#include <vector>
#include <complex>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"

namespace mindspore {
namespace kernel {

using float_complex = std::complex<float>;
using double_complex = std::complex<double>;

/**
* this is for Symmetric matrix eigenvalues & eigenvectors, can decompress the lower/upper triangle matrix
* @tparam T , input Type
* @tparam C , output Type, complex
*/
template <typename T, typename C>
class EighCPUKernel : public CPUKernel {
public:
EighCPUKernel() = default;
~EighCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;

protected:
void InitInputOutputSize(const CNodePtr &kernel_node);

private:
size_t m_{1};
bool compute_eigen_vectors{false};
TypeId dtype_{kNumberTypeFloat32};
};

MS_REG_CPU_KERNEL_T_S(Eigh,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeBool)
.AddOutputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64),
EighCPUKernel, float, float_complex);
MS_REG_CPU_KERNEL_T_S(Eigh,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeBool)
.AddOutputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128),
EighCPUKernel, double, double_complex);

MS_REG_CPU_KERNEL_T_S(Eigh,
KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeBool)
.AddOutputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64),
EighCPUKernel, float_complex, float_complex);
MS_REG_CPU_KERNEL_T_S(Eigh,
KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeBool)
.AddOutputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128),
EighCPUKernel, double_complex, double_complex);
} // namespace kernel
} // namespace mindspore

#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_EIGH_CPU_KERNEL_H

+ 161
- 0
tests/st/ops/cpu/test_solve_eig_value_op.py View File

@@ -0,0 +1,161 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test for solve eigenvalues & eigen vectors"""

import pytest
import numpy as np
import mindspore as msp
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import PrimitiveWithInfer, prim_attr_register
from mindspore._checkparam import Validator as validator

np.random.seed(0)


class Eig(PrimitiveWithInfer):
"""
Eig decomposition,(generic matrix)
Ax = lambda * x
"""

@prim_attr_register
def __init__(self, compute_eigenvectors):
super().__init__(name="Eig")
self.init_prim_io_names(inputs=['A'], outputs=['output', 'output_v'])
self.compute_eigenvectors = validator.check_value_type(
"compute_eigenvectors", compute_eigenvectors, [bool], self.name)

def __infer__(self, A):
shape = {}
if A['dtype'] == msp.tensor_type(msp.dtype.float32):
shape = {
'shape': ((A['shape'][0],), (A['shape'][0], A['shape'][0])),
'dtype': (msp.complex64, msp.complex64),
'value': None
}
elif A['dtype'] == msp.tensor_type(msp.dtype.float64):
shape = {
'shape': ((A['shape'][0],), (A['shape'][0], A['shape'][0])),
'dtype': (msp.complex128, msp.complex128),
'value': None
}
elif A['dtype'] == msp.tensor_type(msp.dtype.complex64):
shape = {
'shape': ((A['shape'][0],), (A['shape'][0], A['shape'][0])),
'dtype': (msp.complex64, msp.complex64),
'value': None
}
elif A['dtype'] == msp.tensor_type(msp.dtype.complex128):
shape = {
'shape': ((A['shape'][0],), (A['shape'][0], A['shape'][0])),
'dtype': (msp.complex128, msp.complex128),
'value': None
}
return shape


class EigNet(nn.Cell):
def __init__(self, b):
super(EigNet, self).__init__()
self.b = b
self.eig = Eig(b)

def construct(self, A):
r = self.eig(A)
if self.b:
return (r[0], r[1])
return (r[0],)


def match(v, v_, error=0):
if error > 0:
np.testing.assert_almost_equal(v, v_, decimal=error)
else:
np.testing.assert_equal(v, v_)


def create_sym_pos_matrix(m, n, dtype):
a = (np.random.random((m, n)) + np.eye(m, n)).astype(dtype)
return np.dot(a, a.T)


@pytest.mark.parametrize('n', [4, 6, 9, 10])
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
@pytest.mark.platform_x86_cpu
def test_eig_net(n: int, mode):
"""
Feature: ALL To ALL
Description: test cases for eigen decomposition test cases for Ax= lambda * x /( A- lambda * E)X=0
Expectation: the result match to numpy
"""
# test for real scalar float 32
rtol = 1e-4
atol = 1e-5
msp_eig = EigNet(True)
A = create_sym_pos_matrix(n, n, np.float32)
tensor_a = Tensor(np.array(A).astype(np.float32))
msp_w, msp_v = msp_eig(tensor_a)
assert np.allclose(A @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy()), np.zeros((n, n)), rtol, atol)

# test case for real scalar double 64
A = np.random.rand(n, n)
rtol = 1e-5
atol = 1e-8
msp_eig = EigNet(True)
msp_w, msp_v = msp_eig(Tensor(np.array(A).astype(np.float64)))

# Compare with scipy
# sp_w, sp_v = sp.linalg.eig(A.astype(np.float64))
assert np.allclose(A @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy()), np.zeros((n, n)), rtol, atol)

# test case for complex64
rtol = 1e-4
atol = 1e-5
A = np.array(np.random.rand(n, n), dtype=np.complex64)
for i in range(0, n):
for j in range(0, n):
if i == j:
A[i][j] = complex(np.random.rand(1, 1), 0)
else:
A[i][j] = complex(np.random.rand(1, 1), np.random.rand(1, 1))
msp_eig = EigNet(True)
msp_w, msp_v = msp_eig(Tensor(np.array(A).astype(np.complex64)))
# Compare with scipy, scipy passed
# sp_w, sp_v = sp.linalg.eig(A.astype(np.complex128))
# assert np.allclose(A @ sp_v - sp_v @ np.diag(sp_w), np.zeros((n, n)), rtol, atol)

# print(A @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy()))
assert np.allclose(A @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy()), np.zeros((n, n)), rtol, atol)

# test for complex128
rtol = 1e-5
atol = 1e-8
A = np.array(np.random.rand(n, n), dtype=np.complex128)
for i in range(0, n):
for j in range(0, n):
if i == j:
A[i][j] = complex(np.random.rand(1, 1), 0)
else:
A[i][j] = complex(np.random.rand(1, 1), np.random.rand(1, 1))
msp_eig = EigNet(True)
msp_w, msp_v = msp_eig(Tensor(np.array(A).astype(np.complex128)))
# Compare with scipy, scipy passed
# sp_w, sp_v = sp.linalg.eig(A.astype(np.complex128))
# assert np.allclose(A @ sp_v - sp_v @ np.diag(sp_w), np.zeros((n, n)), rtol, atol)

# print(A @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy()))
assert np.allclose(A @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy()), np.zeros((n, n)), rtol, atol)

+ 16
- 23
tests/st/ops/cpu/test_solve_eigh_value_op.py View File

@@ -28,7 +28,7 @@ np.random.seed(0)

class Eigh(PrimitiveWithInfer):
"""
Eigh decomposition
Eigh decomposition(Symmetric matrix)
Ax = lambda * x
"""

@@ -74,7 +74,7 @@ class EighNet(nn.Cell):
self.b = b
self.eigh = Eigh(b)

def construct(self, A, s=0):
def construct(self, A, s=True):
r = self.eigh(A, s)
if self.b:
return (r[0], r[1])
@@ -107,31 +107,32 @@ def test_eigh_net(n: int, mode):
atol = 1e-5
msp_eigh = EighNet(True)
A = create_sym_pos_matrix(n, n, np.float32)
tensor_a = Tensor(np.array(A).astype(np.float32))
msp_w, msp_v = msp_eigh(tensor_a, -1)
assert np.allclose(A @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy()), np.zeros((n, n)), rtol, atol)
msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.float32)), True)
msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.float32)), False)
sym_Al = (np.tril((np.tril(A) - np.tril(A).T)) + np.tril(A).T)
sym_Au = (np.triu((np.triu(A) - np.triu(A).T)) + np.triu(A).T)
assert np.allclose(sym_Al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ np.diag(msp_wl.asnumpy()), np.zeros((n, n)), rtol,
atol)
assert np.allclose(sym_Au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ np.diag(msp_wu.asnumpy()), np.zeros((n, n)), rtol,
atol)

# test case for real scalar double 64
A = np.random.rand(n, n)
rtol = 1e-5
atol = 1e-8
msp_eigh = EighNet(True)
msp_w, msp_v = msp_eigh(Tensor(np.array(A).astype(np.float64)), 0)
msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.float64)), -1)
msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.float64)), 1)
msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.float64)), True)
msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.float64)), False)

# Compare with scipy
# sp_w, sp_v = sp.linalg.eig(A.astype(np.float64))
# sp_wl, sp_vl = sp.linalg.eigh(np.tril(A).astype(np.float64), lower=True, eigvals_only=False)
# sp_wu, sp_vu = sp.linalg.eigh(A.astype(np.float64), lower=False, eigvals_only=False)

sym_Al = (np.tril((np.tril(A) - np.tril(A).T)) + np.tril(A).T)
sym_Au = (np.triu((np.triu(A) - np.triu(A).T)) + np.triu(A).T)
assert np.allclose(sym_Al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ np.diag(msp_wl.asnumpy()), np.zeros((n, n)), rtol,
atol)
assert np.allclose(sym_Au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ np.diag(msp_wu.asnumpy()), np.zeros((n, n)), rtol,
atol)
assert np.allclose(A @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy()), np.zeros((n, n)), rtol, atol)

# test case for complex64
rtol = 1e-4
@@ -146,14 +147,11 @@ def test_eigh_net(n: int, mode):
msp_eigh = EighNet(True)
sym_Al = (np.tril((np.tril(A) - np.tril(A).T)) + np.tril(A).conj().T)
sym_Au = (np.triu((np.triu(A) - np.triu(A).T)) + np.triu(A).conj().T)
msp_w, msp_v = msp_eigh(Tensor(np.array(A).astype(np.complex64)), 0)
msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.complex64)), -1)
msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.complex64)), 1)
msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.complex64)), True)
msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.complex64)), False)
# Compare with scipy, scipy passed
# sp_w, sp_v = sp.linalg.eig(A.astype(np.complex128))
# sp_wl, sp_vl = sp.linalg.eigh(np.tril(A).astype(np.complex128), lower=True, eigvals_only=False)
# sp_wu, sp_vu = sp.linalg.eigh(A.astype(np.complex128), lower=False, eigvals_only=False)
# assert np.allclose(A @ sp_v - sp_v @ np.diag(sp_w), np.zeros((n, n)), rtol, atol)
# assert np.allclose(sym_Al @ sp_vl - sp_vl @ np.diag(sp_wl), np.zeros((n, n)), rtol, atol)
# assert np.allclose(sym_Au @ sp_vu - sp_vu @ np.diag(sp_wu), np.zeros((n, n)), rtol, atol)

@@ -162,7 +160,6 @@ def test_eigh_net(n: int, mode):
atol)
assert np.allclose(sym_Au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ np.diag(msp_wu.asnumpy()), np.zeros((n, n)), rtol,
atol)
assert np.allclose(A @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy()), np.zeros((n, n)), rtol, atol)

# test for complex128
rtol = 1e-5
@@ -177,14 +174,11 @@ def test_eigh_net(n: int, mode):
msp_eigh = EighNet(True)
sym_Al = (np.tril((np.tril(A) - np.tril(A).T)) + np.tril(A).conj().T)
sym_Au = (np.triu((np.triu(A) - np.triu(A).T)) + np.triu(A).conj().T)
msp_w, msp_v = msp_eigh(Tensor(np.array(A).astype(np.complex128)), 0)
msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.complex128)), -1)
msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.complex128)), 1)
msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.complex128)), True)
msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.complex128)), False)
# Compare with scipy, scipy passed
# sp_w, sp_v = sp.linalg.eig(A.astype(np.complex128))
# sp_wl, sp_vl = sp.linalg.eigh(np.tril(A).astype(np.complex128), lower=True, eigvals_only=False)
# sp_wu, sp_vu = sp.linalg.eigh(A.astype(np.complex128), lower=False, eigvals_only=False)
# assert np.allclose(A @ sp_v - sp_v @ np.diag(sp_w), np.zeros((n, n)), rtol, atol)
# assert np.allclose(sym_Al @ sp_vl - sp_vl @ np.diag(sp_wl), np.zeros((n, n)), rtol, atol)
# assert np.allclose(sym_Au @ sp_vu - sp_vu @ np.diag(sp_wu), np.zeros((n, n)), rtol, atol)

@@ -193,4 +187,3 @@ def test_eigh_net(n: int, mode):
atol)
assert np.allclose(sym_Au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ np.diag(msp_wu.asnumpy()), np.zeros((n, n)), rtol,
atol)
assert np.allclose(A @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy()), np.zeros((n, n)), rtol, atol)

Loading…
Cancel
Save