| @@ -24,8 +24,6 @@ namespace kernel { | |||
| namespace { | |||
| constexpr size_t kInputsNum = 1; | |||
| constexpr size_t kOutputsNum = 2; | |||
| constexpr size_t kDefaultShape = 1; | |||
| constexpr auto kAMatrixDimNum = 2; | |||
| } // namespace | |||
| using Eigen::Dynamic; | |||
| @@ -45,12 +43,8 @@ using ComplexMatrixSquare = Eigen::Matrix<std::complex<T>, Dynamic, Dynamic, Row | |||
| template <typename T, typename C> | |||
| void EigCPUKernel<T, C>::InitKernel(const CNodePtr &kernel_node) { | |||
| dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); | |||
| compute_eigen_vectors = AnfAlgo::GetNodeAttr<bool>(kernel_node, C_EIEH_VECTOR); | |||
| auto A_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| CHECK_KERNEL_INPUTS_NUM(A_shape.size(), kAMatrixDimNum, AnfAlgo::GetCNodeName(kernel_node)); | |||
| if (A_shape.size() != kShape2dDims || A_shape[0] != A_shape[1]) { | |||
| MS_LOG(EXCEPTION) << "wrong array shape, A should be a matrix, but got [" << A_shape[0] << " X " << A_shape[1] | |||
| << "]"; | |||
| @@ -22,10 +22,8 @@ namespace mindspore { | |||
| namespace kernel { | |||
| namespace { | |||
| constexpr size_t kInputsNum = 2; | |||
| constexpr size_t kInputsNum = 1; | |||
| constexpr size_t kOutputsNum = 2; | |||
| constexpr size_t kDefaultShape = 1; | |||
| constexpr auto kAMatrixDimNum = 2; | |||
| } // namespace | |||
| using Eigen::Dynamic; | |||
| @@ -45,12 +43,9 @@ using ComplexMatrixSquare = Eigen::Matrix<std::complex<T>, Dynamic, Dynamic, Row | |||
| template <typename T> | |||
| void EighCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) { | |||
| dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); | |||
| compute_eigen_vectors = AnfAlgo::GetNodeAttr<bool>(kernel_node, C_EIEH_VECTOR); | |||
| compute_eigen_vectors_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, C_EIEH_VECTOR); | |||
| lower_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, LOWER); | |||
| auto A_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| CHECK_KERNEL_INPUTS_NUM(A_shape.size(), kAMatrixDimNum, AnfAlgo::GetCNodeName(kernel_node)); | |||
| if (A_shape.size() != kShape2dDims || A_shape[0] != A_shape[1]) { | |||
| MS_LOG(EXCEPTION) << "wrong array shape, A should be a matrix, but got [" << A_shape[0] << " X " << A_shape[1] | |||
| << "]"; | |||
| @@ -91,10 +86,8 @@ bool EighCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const std:: | |||
| const std::vector<AddressPtr> &outputs) { | |||
| CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_); | |||
| CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_); | |||
| auto A_addr = reinterpret_cast<T *>(inputs[0]->addr); | |||
| // whether the matrix is symmetric (0: general matrix, -1: lower triangle, 1: upper triangle) | |||
| auto symmetric_type = reinterpret_cast<bool *>(inputs[1]->addr); | |||
| // whether the lower (true) or upper (false) triangle of the matrix is referenced | |||
| auto output_addr = reinterpret_cast<T *>(outputs[0]->addr); | |||
| auto output_v_addr = reinterpret_cast<T *>(outputs[1]->addr); | |||
| Map<MatrixSquare<T>> A(A_addr, m_, m_); | |||
| @@ -102,19 +95,19 @@ bool EighCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const std:: | |||
| Map<MatrixSquare<T>> output(output_addr, m_, 1); | |||
| Map<MatrixSquare<T>> outputv(output_v_addr, m_, m_); | |||
| // selfadjoint matrix | |||
| if (*symmetric_type) { | |||
| if (lower_) { | |||
| A_ = A.template selfadjointView<Lower>(); | |||
| } else { | |||
| A_ = A.template selfadjointView<Upper>(); | |||
| } | |||
| // Real scalar eigen solver | |||
| if constexpr (std::is_same_v<T, float>) { | |||
| SolveSelfAdjointMatrix(A_, &output, &outputv, compute_eigen_vectors); | |||
| SolveSelfAdjointMatrix(A_, &output, &outputv, compute_eigen_vectors_); | |||
| } else if constexpr (std::is_same_v<T, double>) { | |||
| SolveSelfAdjointMatrix(A_, &output, &outputv, compute_eigen_vectors); | |||
| SolveSelfAdjointMatrix(A_, &output, &outputv, compute_eigen_vectors_); | |||
| } else { | |||
| // complex eigen solver | |||
| SolveComplexMatrix(A_, &output, &outputv, compute_eigen_vectors); | |||
| SolveComplexMatrix(A_, &output, &outputv, compute_eigen_vectors_); | |||
| } | |||
| return true; | |||
| } | |||
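For reference, the lower_/upper branch above has the same contract as scipy: only the selected triangle of A is read, and the full self-adjoint matrix is rebuilt from it before the solver runs. A minimal numpy sketch of that behavior (the helper name is made up for illustration; this is not the kernel code):

```python
import numpy as np

def selfadjoint_from_triangle(a, lower=True):
    """Numpy analogue of Eigen's selfadjointView<Lower/Upper>."""
    tri = np.tril(a) if lower else np.triu(a)
    # Mirror the selected triangle; the diagonal must not be counted twice.
    return tri + tri.conj().T - np.diag(np.diag(a).real)

a = np.random.rand(4, 4)
w, v = np.linalg.eigh(selfadjoint_from_triangle(a, lower=True))
assert np.allclose(selfadjoint_from_triangle(a) @ v, v @ np.diag(w))
```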
| @@ -46,36 +46,29 @@ class EighCPUKernel : public CPUKernel { | |||
| private: | |||
| size_t m_{1}; | |||
| bool compute_eigen_vectors{false}; | |||
| bool compute_eigen_vectors_{false}; | |||
| bool lower_{true}; | |||
| TypeId dtype_{kNumberTypeFloat32}; | |||
| }; | |||
| MS_REG_CPU_KERNEL_T(Eigh, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeBool) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| EighCPUKernel, float); | |||
| MS_REG_CPU_KERNEL_T(Eigh, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat64) | |||
| .AddInputAttr(kNumberTypeBool) | |||
| .AddOutputAttr(kNumberTypeFloat64) | |||
| .AddOutputAttr(kNumberTypeFloat64), | |||
| EighCPUKernel, double); | |||
| MS_REG_CPU_KERNEL_T( | |||
| Eigh, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| EighCPUKernel, float); | |||
| MS_REG_CPU_KERNEL_T( | |||
| Eigh, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), | |||
| EighCPUKernel, double); | |||
| MS_REG_CPU_KERNEL_T(Eigh, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeComplex64) | |||
| .AddInputAttr(kNumberTypeBool) | |||
| .AddOutputAttr(kNumberTypeComplex64) | |||
| .AddOutputAttr(kNumberTypeComplex64), | |||
| EighCPUKernel, float_complex); | |||
| MS_REG_CPU_KERNEL_T(Eigh, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeComplex128) | |||
| .AddInputAttr(kNumberTypeBool) | |||
| .AddOutputAttr(kNumberTypeComplex128) | |||
| .AddOutputAttr(kNumberTypeComplex128), | |||
| EighCPUKernel, double_complex); | |||
| @@ -18,6 +18,10 @@ | |||
| #include "transpose_impl.cuh" | |||
| #include "runtime/device/gpu/cuda_common.h" | |||
| #include "utils/complex.h" | |||
| template <typename T> | |||
| using Complex = mindspore::utils::Complex<T>; | |||
| template <typename T> | |||
| __global__ void Transpose(const size_t size, const T *input, const size_t *input_shape, const size_t *input_axis, | |||
| @@ -74,3 +78,9 @@ template void CalTranspose<int>(const size_t size, const int *input, const size_ | |||
| template void CalTranspose<int64_t>(const size_t size, const int64_t *input, const size_t *input_shape, | |||
| const size_t *input_axis, const size_t shape_size, int64_t *output, | |||
| cudaStream_t cuda_stream); | |||
| template void CalTranspose<Complex<float>>(const size_t size, const Complex<float> *input, const size_t *input_shape, | |||
| const size_t *input_axis, const size_t shape_size, Complex<float> *output, | |||
| cudaStream_t cuda_stream); | |||
| template void CalTranspose<Complex<double>>(const size_t size, const Complex<double> *input, const size_t *input_shape, | |||
| const size_t *input_axis, const size_t shape_size, Complex<double> *output, | |||
| cudaStream_t cuda_stream); | |||
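The two complex instantiations are added because the Eigh GPU kernels below use CalTranspose to move matrices between row-major and column-major layouts. For the 2-D case the kernel's shape/axis arguments reduce to a plain matrix transpose; a numpy equivalent of the call the kernels make:

```python
import numpy as np

m = 3
a = (np.arange(m * m) + 1j).astype(np.complex64).reshape(m, m)
input_shape, input_axis = (m, m), (1, 0)   # what the kernels pass for kShape2dDims
b = np.transpose(a, axes=input_axis)       # CalTranspose(m * m, a, shape, axis, 2, b)
assert np.array_equal(b, a.T)
```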
| @@ -21,14 +21,12 @@ namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE(Eigh, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeComplex64) | |||
| .AddInputAttr(kNumberTypeBool) | |||
| .AddOutputAttr(kNumberTypeComplex64) | |||
| .AddOutputAttr(kNumberTypeComplex64), | |||
| EighcGpuKernel, Complex<float>) | |||
| MS_REG_GPU_KERNEL_ONE(Eigh, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeComplex128) | |||
| .AddInputAttr(kNumberTypeBool) | |||
| .AddOutputAttr(kNumberTypeComplex128) | |||
| .AddOutputAttr(kNumberTypeComplex128), | |||
| EighcGpuKernel, Complex<double>); | |||
| @@ -32,10 +32,12 @@ | |||
| #include "utils/convert_utils.h" | |||
| #include "utils/complex.h" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/real_to_complex_impl.cuh" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| constexpr char C_EIEH_VECTOR[] = "compute_eigenvectors"; | |||
| constexpr char LOWER[] = "lower"; | |||
| template <typename T> | |||
| using Complex = mindspore::utils::Complex<T>; | |||
| @@ -61,6 +63,7 @@ class EighcGpuKernel : public GpuKernel { | |||
| dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); | |||
| auto A_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| compute_eigen_vectors_ = static_cast<bool>(GetAttr<bool>(kernel_node, C_EIEH_VECTOR)); | |||
| lower_ = static_cast<bool>(GetAttr<bool>(kernel_node, LOWER)); | |||
| if (compute_eigen_vectors_) { | |||
| jobz_ = CUSOLVER_EIG_MODE_VECTOR; | |||
| } else { | |||
| @@ -84,13 +87,7 @@ class EighcGpuKernel : public GpuKernel { | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||
| // matrix A, input or output(eigenvector) | |||
| auto inout_A_addr = GetDeviceAddress<T>(inputs, 0); | |||
| auto lower = GetDeviceAddress<bool>(inputs, 1); | |||
| bool h_lower{true}; | |||
| CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, | |||
| cudaMemcpyAsync(&h_lower, lower, sizeof(bool), cudaMemcpyDeviceToHost, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "copy lower do device failed"); | |||
| if (h_lower) { | |||
| if (lower_) { | |||
| uplo_ = CUBLAS_FILL_MODE_LOWER; | |||
| } else { | |||
| uplo_ = CUBLAS_FILL_MODE_UPPER; | |||
| @@ -105,24 +102,39 @@ class EighcGpuKernel : public GpuKernel { | |||
| // temp output: eigenvalues as real scalars | |||
| auto w_w_addr = GetDeviceAddress<D>(workspace, 0); | |||
| auto w_w_c_addr = GetDeviceAddress<T>(workspace, 1); | |||
| // temp eigenvector before transpose | |||
| auto w_v_addr = GetDeviceAddress<T>(workspace, 2); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, | |||
| cudaMemcpyAsync(output_v_addr, inout_A_addr, m_ * m_ * sizeof(T), | |||
| cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "copy input matrix failed"); | |||
| size_t input_shape[kShape2dDims] = {m_, m_}; | |||
| size_t input_axis[kShape2dDims] = {1, 0}; | |||
| size_t *dev_input_shape = nullptr; | |||
| cudaMalloc(reinterpret_cast<void **>(&dev_input_shape), kShape2dDims * sizeof(size_t)); | |||
| size_t *dev_input_axis = nullptr; | |||
| cudaMalloc(reinterpret_cast<void **>(&dev_input_axis), kShape2dDims * sizeof(size_t)); | |||
| cudaMemcpyAsync(dev_input_shape, input_shape, kShape2dDims * sizeof(size_t), cudaMemcpyHostToDevice, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| cudaMemcpyAsync(dev_input_axis, input_axis, kShape2dDims * sizeof(size_t), cudaMemcpyHostToDevice, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| CalTranspose(m_ * m_, output_v_addr, dev_input_shape, dev_input_axis, kShape2dDims, w_v_addr, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| int lwork = 0; | |||
| if constexpr (std::is_same_v<T, Complex<float>>) { | |||
| cusolverDnCheevd_bufferSize(cusolver_handle_, jobz_, uplo_, m_, reinterpret_cast<cuComplex *>(output_v_addr), | |||
| lda_, w_w_addr, &lwork); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaMalloc(reinterpret_cast<void **>(&d_work), sizeof(T) * lwork), | |||
| "cal eigenvalues workspace failed"); | |||
| cusolverDnCheevd(cusolver_handle_, jobz_, uplo_, m_, reinterpret_cast<cuComplex *>(output_v_addr), lda_, w_w_addr, | |||
| cusolverDnCheevd(cusolver_handle_, jobz_, uplo_, m_, reinterpret_cast<cuComplex *>(w_v_addr), lda_, w_w_addr, | |||
| reinterpret_cast<cuComplex *>(d_work), lwork, devInfo); | |||
| } else { | |||
| cusolverDnZheevd_bufferSize(cusolver_handle_, jobz_, uplo_, m_, | |||
| reinterpret_cast<cuDoubleComplex *>(output_v_addr), lda_, w_w_addr, &lwork); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaMalloc(reinterpret_cast<void **>(&d_work), sizeof(T) * lwork), | |||
| "cal eigenvalues workspace failed"); | |||
| cusolverDnZheevd(cusolver_handle_, jobz_, uplo_, m_, reinterpret_cast<cuDoubleComplex *>(output_v_addr), lda_, | |||
| cusolverDnZheevd(cusolver_handle_, jobz_, uplo_, m_, reinterpret_cast<cuDoubleComplex *>(w_v_addr), lda_, | |||
| w_w_addr, reinterpret_cast<cuDoubleComplex *>(d_work), lwork, devInfo); | |||
| } | |||
| CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, | |||
| @@ -131,6 +143,8 @@ class EighcGpuKernel : public GpuKernel { | |||
| "copy eigenvalue from workspace to host failed"); | |||
| RealToComplex(m_, reinterpret_cast<D *>(w_w_c_addr), reinterpret_cast<D *>(output_w_addr), | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| CalTranspose(m_ * m_, w_v_addr, dev_input_shape, dev_input_axis, kShape2dDims, output_v_addr, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| // convert real scalar to complex | |||
| if (d_work) { | |||
| cudaFree(d_work); | |||
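In outline, the complex path now: copies A into the eigenvector output buffer, transposes it into the new m_ * m_ scratch buffer (w_v_addr), runs cusolverDn{C,Z}heevd in place there, transposes the eigenvectors back to row-major, and widens the real eigenvalues to complex via RealToComplex. Unlike the real-valued kernel further down, the fill mode is not flipped here because the explicit pre-transpose already hands cuSOLVER a column-major matrix. A rough numpy sketch of the buffer flow (the function is illustrative only; it models the layout handling, not heevd itself):

```python
import numpy as np

def eigh_complex_dataflow(a):
    """Model the buffer flow only: row-major in, transpose, solve, transpose back."""
    scratch = a.T.copy()                 # CalTranspose into w_v_addr
    w_real, v = np.linalg.eigh(scratch)  # stands in for cusolverDnCheevd / Zheevd
    v_out = v.T.copy()                   # CalTranspose back into output_v_addr
    w_out = w_real.astype(a.dtype)       # RealToComplex widens eigenvalues
    return w_out, v_out
```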
| @@ -153,8 +167,6 @@ class EighcGpuKernel : public GpuKernel { | |||
| void InitSizeLists() override { | |||
| // in/out matrix, eigenvector | |||
| input_size_list_.push_back(m_ * m_ * sizeof(T)); | |||
| // uplo | |||
| input_size_list_.push_back(sizeof(bool)); | |||
| // eigenvalues: cuda outputs real scalars, which must be converted to complex<float32/64> | |||
| output_size_list_.push_back(m_ * sizeof(T)); | |||
| output_size_list_.push_back(m_ * m_ * sizeof(T)); | |||
| @@ -162,6 +174,7 @@ class EighcGpuKernel : public GpuKernel { | |||
| workspace_size_list_.push_back(m_ * sizeof(D)); | |||
| // for temp pre-transpose complex matrix | |||
| workspace_size_list_.push_back(m_ * sizeof(T)); | |||
| workspace_size_list_.push_back(m_ * m_ * sizeof(T)); | |||
| } | |||
| size_t m_{1}; | |||
| @@ -171,6 +184,7 @@ class EighcGpuKernel : public GpuKernel { | |||
| cublasFillMode_t uplo_ = CUBLAS_FILL_MODE_UPPER; | |||
| cusolverEigMode_t jobz_ = CUSOLVER_EIG_MODE_NOVECTOR; | |||
| bool compute_eigen_vectors_{false}; | |||
| bool lower_{true}; | |||
| std::vector<T *> h_array_{}; | |||
| std::vector<size_t> input_size_list_{}; | |||
| std::vector<size_t> output_size_list_{}; | |||
| @@ -18,19 +18,13 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE(Eigh, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeBool) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| EighGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(Eigh, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat64) | |||
| .AddInputAttr(kNumberTypeBool) | |||
| .AddOutputAttr(kNumberTypeFloat64) | |||
| .AddOutputAttr(kNumberTypeFloat64), | |||
| EighGpuKernel, double); | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| Eigh, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| EighGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| Eigh, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), | |||
| EighGpuKernel, double); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -30,10 +30,12 @@ | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| #include "backend/kernel_compiler/gpu/kernel_constants.h" | |||
| #include "utils/convert_utils.h" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| constexpr char C_EIEH_VECTOR[] = "compute_eigenvectors"; | |||
| constexpr char LOWER[] = "lower"; | |||
| template <typename T> | |||
| class EighGpuKernel : public GpuKernel { | |||
| public: | |||
| @@ -47,6 +49,7 @@ class EighGpuKernel : public GpuKernel { | |||
| dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); | |||
| auto A_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| compute_eigen_vectors_ = static_cast<bool>(GetAttr<bool>(kernel_node, C_EIEH_VECTOR)); | |||
| lower_ = static_cast<bool>(GetAttr<bool>(kernel_node, LOWER)); | |||
| if (compute_eigen_vectors_) { | |||
| jobz_ = CUSOLVER_EIG_MODE_VECTOR; | |||
| } else { | |||
| @@ -69,26 +72,23 @@ class EighGpuKernel : public GpuKernel { | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||
| // matrix A, input or output(eigenvector) | |||
| auto inout_A_addr = GetDeviceAddress<T>(inputs, 0); | |||
| auto lower = GetDeviceAddress<bool>(inputs, 1); | |||
| bool h_lower{true}; | |||
| CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, | |||
| cudaMemcpyAsync(&h_lower, lower, sizeof(bool), cudaMemcpyDeviceToHost, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "copy to lower to device failed"); | |||
| if (h_lower) { | |||
| uplo_ = CUBLAS_FILL_MODE_LOWER; | |||
| } else { | |||
| auto inout_A_addr = GetDeviceAddress<T>(inputs, kDim0); | |||
| // Note: cuSOLVER assumes column-major storage while MindSpore buffers are row-major, so the | |||
| // matrix cuSOLVER sees is the transpose: the host-side lower triangle is the device-side upper | |||
| // triangle and vice versa. Hence the fill mode is flipped below. For a real symmetric matrix the | |||
| // two triangles carry the same values, which is why this shortcut works here but not for complex. | |||
| if (lower_) { | |||
| uplo_ = CUBLAS_FILL_MODE_UPPER; | |||
| } else { | |||
| uplo_ = CUBLAS_FILL_MODE_LOWER; | |||
| } | |||
| auto output_addr = GetDeviceAddress<T>(outputs, 0); // output eigenvalues | |||
| auto output_v_addr = GetDeviceAddress<T>(outputs, 1); // output eigenvectors | |||
| auto output_addr = GetDeviceAddress<T>(outputs, kDim0); // output eigenvalues | |||
| auto output_v_addr = GetDeviceAddress<T>(outputs, kDim1); // output eigenvectors | |||
| auto w_v_addr = GetDeviceAddress<T>(workspace, kDim0); // temp eigenvector before transpose | |||
| CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, | |||
| cudaMemcpyAsync(output_v_addr, inout_A_addr, m_ * m_ * sizeof(T), | |||
| cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| cudaMemcpyAsync(w_v_addr, inout_A_addr, m_ * m_ * sizeof(T), cudaMemcpyDeviceToDevice, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "copy to input matrix failed"); | |||
| size_t lda_ = m_; | |||
| int lwork = 0; | |||
| if constexpr (std::is_same_v<T, float>) { | |||
| cusolverDnSsyevd_bufferSize(cusolver_handle_, jobz_, uplo_, m_, inout_A_addr, lda_, output_addr, &lwork); | |||
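The fill-mode flip above is the whole trick for the real path: a buffer written row-major and read column-major is the transpose, so its lower triangle is the original's upper triangle. A short numpy check of that claim:

```python
import numpy as np

a = np.arange(9.0).reshape(3, 3)                 # row-major, as MindSpore stores it
a_colmajor = a.ravel().reshape(3, 3, order='F')  # the same bytes as cuSOLVER reads them
assert np.array_equal(a_colmajor, a.T)
assert np.array_equal(np.tril(a_colmajor), np.triu(a).T)  # lower <-> upper swap
```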
| @@ -100,10 +100,22 @@ class EighGpuKernel : public GpuKernel { | |||
| T *d_work = nullptr; | |||
| cudaMalloc(reinterpret_cast<void **>(&d_work), sizeof(T) * lwork); | |||
| if constexpr (std::is_same_v<T, float>) { | |||
| cusolverDnSsyevd(cusolver_handle_, jobz_, uplo_, m_, output_v_addr, lda_, output_addr, d_work, lwork, devInfo); | |||
| cusolverDnSsyevd(cusolver_handle_, jobz_, uplo_, m_, w_v_addr, lda_, output_addr, d_work, lwork, devInfo); | |||
| } else if constexpr (std::is_same_v<T, double>) { | |||
| cusolverDnDsyevd(cusolver_handle_, jobz_, uplo_, m_, output_v_addr, lda_, output_addr, d_work, lwork, devInfo); | |||
| cusolverDnDsyevd(cusolver_handle_, jobz_, uplo_, m_, w_v_addr, lda_, output_addr, d_work, lwork, devInfo); | |||
| } | |||
| size_t input_shape[kShape2dDims] = {m_, m_}; | |||
| size_t input_axis[kShape2dDims] = {1, 0}; | |||
| size_t *dev_input_shape = nullptr; | |||
| cudaMalloc(reinterpret_cast<void **>(&dev_input_shape), kShape2dDims * sizeof(size_t)); | |||
| size_t *dev_input_axis = nullptr; | |||
| cudaMalloc(reinterpret_cast<void **>(&dev_input_axis), kShape2dDims * sizeof(size_t)); | |||
| cudaMemcpyAsync(dev_input_shape, input_shape, kShape2dDims * sizeof(size_t), cudaMemcpyHostToDevice, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| cudaMemcpyAsync(dev_input_axis, input_axis, kShape2dDims * sizeof(size_t), cudaMemcpyHostToDevice, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| CalTranspose(m_ * m_, w_v_addr, dev_input_shape, dev_input_axis, kShape2dDims, output_v_addr, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| if (d_work) { | |||
| cudaFree(d_work); | |||
| } | |||
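Because a real symmetric matrix equals its transpose, the input needs no pre-transpose in this kernel (only the fill-mode flip); it is the eigenvector result, which cuSOLVER writes column-major, that gets transposed back to row-major via the new CalTranspose call. In numpy terms (illustrative):

```python
import numpy as np

a = np.random.rand(4, 4)
a = a + a.T                   # real symmetric: row- and column-major views coincide
w, v = np.linalg.eigh(a)      # stands in for cusolverDnSsyevd / Dsyevd
# cuSOLVER fills the buffer with V column-major; read row-major that is V.T,
# hence the final CalTranspose(w_v_addr -> output_v_addr).
assert np.allclose(a @ v, v @ np.diag(w))
```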
| @@ -125,12 +137,11 @@ class EighGpuKernel : public GpuKernel { | |||
| void InitSizeLists() override { | |||
| // in/out matrix, eigenvector | |||
| input_size_list_.push_back(m_ * m_ * sizeof(T)); | |||
| // uplo | |||
| input_size_list_.push_back(sizeof(bool)); | |||
| // eigenvalues | |||
| output_size_list_.push_back(m_ * sizeof(T)); | |||
| // eigenvector | |||
| output_size_list_.push_back(m_ * m_ * sizeof(T)); | |||
| workspace_size_list_.push_back(m_ * m_ * sizeof(T)); | |||
| } | |||
| size_t m_{1}; | |||
| @@ -139,6 +150,7 @@ class EighGpuKernel : public GpuKernel { | |||
| cublasFillMode_t uplo_ = CUBLAS_FILL_MODE_UPPER; | |||
| cusolverEigMode_t jobz_ = CUSOLVER_EIG_MODE_NOVECTOR; | |||
| bool compute_eigen_vectors_{false}; | |||
| bool lower_{true}; | |||
| std::vector<T *> h_array_{}; | |||
| std::vector<size_t> input_size_list_{}; | |||
| std::vector<size_t> output_size_list_{}; | |||
| @@ -18,9 +18,10 @@ from .. import ops | |||
| from .ops import SolveTriangular | |||
| from .ops import CholeskySolver | |||
| from .ops import Cholesky | |||
| from .ops import EighNet | |||
| from ..ops import operations as P | |||
| __all__ = ['block_diag', 'solve_triangular', 'inv', 'cho_factor', 'cholesky', 'cho_solve'] | |||
| __all__ = ['block_diag', 'solve_triangular', 'inv', 'cho_factor', 'cholesky', 'cho_solve', 'eigh'] | |||
| def block_diag(*arrs): | |||
| @@ -318,3 +319,84 @@ def cho_solve(c_and_lower, b, overwrite_b=False, check_finite=True): | |||
| cholesky_solver_net = CholeskySolver(lower=lower) | |||
| x = cholesky_solver_net(c, b) | |||
| return x | |||
| def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False, | |||
| overwrite_b=False, turbo=True, eigvals=None, _type=1, | |||
| check_finite=True): | |||
| """ | |||
| Solve a standard or generalized eigenvalue problem for a complex | |||
| Hermitian or real symmetric matrix. | |||
| Find eigenvalues Tensor ``w`` and optionally eigenvectors Tensor ``v`` of | |||
| Tensor ``a``, where ``b`` is positive definite such that for every | |||
| eigenvalue λ (i-th entry of w) and its eigenvector ``vi`` (i-th column of | |||
| ``v``) satisfies:: | |||
| a @ vi = λ * b @ vi | |||
| vi.conj().T @ a @ vi = λ | |||
| vi.conj().T @ b @ vi = 1 | |||
| In the standard problem, ``b`` is assumed to be the identity matrix. | |||
| Args: | |||
| a (Tensor): (M, M) Tensor | |||
| A complex Hermitian or real symmetric matrix whose eigenvalues and | |||
| eigenvectors will be computed. | |||
| b (Tensor, optional): (M, M) Tensor | |||
| A complex Hermitian or real symmetric positive definite matrix. | |||
| If omitted, identity matrix is assumed. | |||
| lower (bool, optional): Whether the pertinent Tensor data is taken from | |||
| the lower or upper triangle of ``a`` and, if applicable, ``b``. (Default: lower) | |||
| eigvals_only (bool, optional): Whether to calculate only eigenvalues | |||
| and no eigenvectors. (Default: both are calculated) | |||
| _type (int, optional): For the generalized problems, this keyword specifies | |||
| the problem type to be solved for ``w`` and ``v`` (only takes 1, 2, 3 as possible | |||
| inputs):: | |||
| 1 => a @ v = w @ b @ v | |||
| 2 => a @ b @ v = w @ v | |||
| 3 => b @ a @ v = w @ v | |||
| This keyword is ignored for standard problems. | |||
| overwrite_a (bool, optional): Whether to overwrite data in ``a`` | |||
| (may improve performance). Default is False. | |||
| overwrite_b (bool, optional): Whether to overwrite data in ``b`` | |||
| (may improve performance). Default is False. | |||
| check_finite (bool, optional): Whether to check that the input matrices | |||
| contain only finite numbers. | |||
| Disabling may give a performance gain, but may result in problems | |||
| (crashes, non-termination) if the inputs do contain infinities or NaNs. | |||
| turbo (bool, optional): use divide and conquer algorithm (faster but | |||
| expensive in memory, only for generalized eigenvalue problem and | |||
| if full set of eigenvalues are requested.). Has no significant | |||
| effect if eigenvectors are not requested. | |||
| eigvals (tuple, optional): Indices of the smallest and largest (in ascending order) | |||
| eigenvalues and corresponding eigenvectors to be returned: 0 <= lo <= hi <= M-1. | |||
| If omitted, all eigenvalues and eigenvectors are returned. | |||
| Returns: | |||
| w (Tensor): (N,) Tensor, The N (1<=N<=M) selected eigenvalues, in ascending order, | |||
| each repeated according to its multiplicity. | |||
| v (Tensor): (M, N) Tensor, (if ``eigvals_only == False``) | |||
| Raises: | |||
| LinAlgError: If eigenvalue computation does not converge, an error occurred, or | |||
| b matrix is not positive definite. Note that if input matrices are | |||
| not symmetric or Hermitian, no error will be reported but results will | |||
| be wrong. | |||
| Supported Platforms: | |||
| ``CPU`` ``GPU`` | |||
| Examples: | |||
| >>> import numpy as onp | |||
| >>> from mindspore.common import Tensor | |||
| >>> from mindspore.scipy.linalg import eigh | |||
| >>> A = Tensor(onp.array([[6, 3, 1, 5], [3, 0, 5, 1], [1, 5, 6, 2], [5, 1, 2, 2]]).astype(onp.float64)) | |||
| >>> w, v = eigh(A) | |||
| >>> onp.allclose(A @ v - v @ onp.diag(w), onp.zeros((4, 4))) | |||
| True | |||
| """ | |||
| eigh_net = EighNet(not eigvals_only, lower=lower) | |||
| return eigh_net(a) | |||
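A quick usage sketch of the new keywords, mirroring the tests further down (imports as in the docstring example above; illustrative):

```python
import numpy as onp
from mindspore.common import Tensor
from mindspore.scipy.linalg import eigh

a = onp.random.rand(4, 4)
a = a + a.T                                               # symmetric input
w_l, v_l = eigh(Tensor(a), lower=True)                    # read the lower triangle
w_only = eigh(Tensor(a), lower=False, eigvals_only=True)  # eigenvalues only
```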
| @@ -194,42 +194,21 @@ class Eigh(PrimitiveWithInfer): | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, compute_eigenvectors): | |||
| def __init__(self, compute_eigenvectors=True, lower=True): | |||
| super().__init__(name="Eigh") | |||
| self.init_prim_io_names(inputs=['A', 's'], outputs=['output', 'output_v']) | |||
| self.init_prim_io_names(inputs=['A'], outputs=['output_w', 'output_v']) | |||
| self.compute_eigenvectors = validator.check_value_type( | |||
| "compute_eigenvectors", compute_eigenvectors, [bool], self.name) | |||
| self.lower = validator.check_value_type("lower", lower, [bool], self.name) | |||
| self.add_prim_attr('lower', self.lower) | |||
| self.add_prim_attr('compute_eigenvectors', self.compute_eigenvectors) | |||
| def __infer__(self, A, s): | |||
| def __infer__(self, A): | |||
| shape = { | |||
| 'shape': ((A['shape'][0],), (A['shape'][0], A['shape'][0])), | |||
| 'dtype': (A['dtype'], A['dtype']), | |||
| 'value': None | |||
| } | |||
| if A['dtype'] == mstype.tensor_type(mstype.float32): | |||
| shape = { | |||
| 'shape': ((A['shape'][0],), (A['shape'][0], A['shape'][0])), | |||
| 'dtype': (mstype.float32, mstype.float32), | |||
| 'value': None | |||
| } | |||
| elif A['dtype'] == mstype.tensor_type(mstype.float64): | |||
| shape = { | |||
| 'shape': ((A['shape'][0],), (A['shape'][0], A['shape'][0])), | |||
| 'dtype': (mstype.float64, mstype.float64), | |||
| 'value': None | |||
| } | |||
| elif A['dtype'] == mstype.tensor_type(mstype.complex64): | |||
| shape = { | |||
| 'shape': ((A['shape'][0],), (A['shape'][0], A['shape'][0])), | |||
| 'dtype': (A['dtype'], A['dtype']), | |||
| 'value': None | |||
| } | |||
| elif A['dtype'] == mstype.tensor_type(mstype.complex128): | |||
| shape = { | |||
| 'shape': ((A['shape'][0],), (A['shape'][0], A['shape'][0])), | |||
| 'dtype': (mstype.complex128, mstype.complex128), | |||
| 'value': None | |||
| } | |||
| return shape | |||
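The collapsed branch states the contract directly: for an (M, M) input the outputs are an (M,) eigenvalue tensor and an (M, M) eigenvector tensor, both in the input dtype, e.g.:

```python
# What the unified __infer__ returns for a hypothetical (4, 4) float32 input.
A = {'shape': (4, 4), 'dtype': 'float32', 'value': None}
shape = {
    'shape': ((A['shape'][0],), (A['shape'][0], A['shape'][0])),  # w: (4,), v: (4, 4)
    'dtype': (A['dtype'], A['dtype']),                            # both match the input
    'value': None,
}
assert shape['shape'] == ((4,), (4, 4))
```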
| @@ -238,16 +217,17 @@ class EighNet(nn.Cell): | |||
| Eigenvalue/eigenvector solver for a symmetric/Hermitian matrix | |||
| Ax = lambda * x | |||
| """ | |||
| def __init__(self, b): | |||
| def __init__(self, bv=True, lower=True): | |||
| super(EighNet, self).__init__() | |||
| self.b = b | |||
| self.eigh = Eigh(b) | |||
| self.bv = bv | |||
| self.eigh = Eigh(bv, lower) | |||
| def construct(self, A, s=True): | |||
| r = self.eigh(A, s) | |||
| if self.b: | |||
| def construct(self, A): | |||
| r = self.eigh(A) | |||
| if self.bv: | |||
| return (r[0], r[1]) | |||
| return (r[0],) | |||
| return r[0] | |||
| class Eig(PrimitiveWithInfer): | |||
| @@ -257,7 +237,7 @@ class Eig(PrimitiveWithInfer): | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, compute_eigenvectors): | |||
| def __init__(self, compute_eigenvectors=True): | |||
| super().__init__(name="Eig") | |||
| self.init_prim_io_names(inputs=['A'], outputs=['output', 'output_v']) | |||
| self.compute_eigenvectors = validator.check_value_type( | |||
| @@ -285,13 +265,14 @@ class EigNet(nn.Cell): | |||
| Eigenvalue/eigenvector solver for a generic matrix | |||
| Ax = lambda * x | |||
| """ | |||
| def __init__(self, b): | |||
| def __init__(self, bv=True): | |||
| super(EigNet, self).__init__() | |||
| self.b = b | |||
| self.eig = Eig(b) | |||
| self.bv = bv | |||
| self.eig = Eig(bv) | |||
| def construct(self, A): | |||
| r = self.eig(A) | |||
| if self.b: | |||
| if self.bv: | |||
| return (r[0], r[1]) | |||
| return (r[0],) | |||
| return r[0] | |||
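After this change both wrappers share one convention: constructed with bv=True they return the (w, v) pair, with bv=False just w, and Eigh additionally fixes lower at construction instead of taking it per call. For instance (module path taken from the linalg import above; illustrative):

```python
import numpy as onp
from mindspore.common import Tensor
from mindspore.scipy.ops import EighNet, EigNet

A = Tensor(onp.random.rand(4, 4).astype(onp.float32))
w, v = EighNet(True, lower=True)(A)       # eigenvalues and eigenvectors, lower triangle
w_only = EighNet(False, lower=False)(A)   # eigenvalues only, upper triangle
w_g, v_g = EigNet(True)(A)                # general (non-symmetric) eigen solver
```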
| @@ -31,11 +31,6 @@ def match(v, v_, error=0): | |||
| np.testing.assert_equal(v, v_) | |||
| def create_sym_pos_matrix(m, n, dtype): | |||
| a = (np.random.random((m, n)) + np.eye(m, n)).astype(dtype) | |||
| return np.dot(a, a.T) | |||
| @pytest.mark.parametrize('n', [4, 6, 9, 10]) | |||
| @pytest.mark.platform_x86_cpu | |||
| def test_eig_net(n: int): | |||
| @@ -48,13 +43,13 @@ def test_eig_net(n: int): | |||
| rtol = 1e-3 | |||
| atol = 1e-4 | |||
| msp_eig = EigNet(True) | |||
| A = create_sym_pos_matrix(n, n, np.float32) | |||
| A = np.array(np.random.rand(n, n), dtype=np.float32) | |||
| tensor_a = Tensor(np.array(A).astype(np.float32)) | |||
| msp_w, msp_v = msp_eig(tensor_a) | |||
| assert np.allclose(A @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy()), np.zeros((n, n)), rtol, atol) | |||
| # test case for real scalar double 64 | |||
| A = np.random.rand(n, n) | |||
| A = np.array(np.random.rand(n, n), dtype=np.float64) | |||
| rtol = 1e-5 | |||
| atol = 1e-8 | |||
| msp_eig = EigNet(True) | |||
| @@ -98,6 +93,7 @@ def test_eig_net(n: int): | |||
| # Compare with scipy, scipy passed | |||
| # sp_w, sp_v = sp.linalg.eig(A.astype(np.complex128)) | |||
| # assert np.allclose(A @ sp_v - sp_v @ np.diag(sp_w), np.zeros((n, n)), rtol, atol) | |||
| # print(A @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy())) | |||
| assert np.allclose(A @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy()), np.zeros((n, n)), rtol, atol) | |||
| msp_eig = EigNet(False) | |||
| msp_w0 = msp_eig(Tensor(np.array(A).astype(np.complex128))) | |||
| assert np.allclose(msp_w0.asnumpy() - msp_w.asnumpy(), np.zeros((n, n)), rtol, atol) | |||
| @@ -47,10 +47,12 @@ def test_eigh_net(n: int): | |||
| # test for real scalar float 32 | |||
| rtol = 1e-3 | |||
| atol = 1e-4 | |||
| msp_eigh = EighNet(True) | |||
| A = create_sym_pos_matrix(n, n, np.float32) | |||
| msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.float32)), True) | |||
| msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.float32)), False) | |||
| msp_eigh = EighNet(True, True) | |||
| msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.float32))) | |||
| msp_eigh = EighNet(True, False) | |||
| msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.float32))) | |||
| sym_Al = (np.tril((np.tril(A) - np.tril(A).T)) + np.tril(A).T) | |||
| sym_Au = (np.triu((np.triu(A) - np.triu(A).T)) + np.triu(A).T) | |||
| assert np.allclose(sym_Al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ np.diag(msp_wl.asnumpy()), np.zeros((n, n)), rtol, | |||
| @@ -62,19 +64,23 @@ def test_eigh_net(n: int): | |||
| A = np.random.rand(n, n) | |||
| rtol = 1e-5 | |||
| atol = 1e-8 | |||
| msp_eigh = EighNet(True) | |||
| msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.float64)), True) | |||
| msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.float64)), False) | |||
| # Compare with scipy | |||
| # sp_wl, sp_vl = sp.linalg.eigh(np.tril(A).astype(np.float64), lower=True, eigvals_only=False) | |||
| # sp_wu, sp_vu = sp.linalg.eigh(A.astype(np.float64), lower=False, eigvals_only=False) | |||
| msp_eigh = EighNet(True, True) | |||
| msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.float64))) | |||
| msp_eigh = EighNet(True, False) | |||
| msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.float64))) | |||
| sym_Al = (np.tril((np.tril(A) - np.tril(A).T)) + np.tril(A).T) | |||
| sym_Au = (np.triu((np.triu(A) - np.triu(A).T)) + np.triu(A).T) | |||
| assert np.allclose(sym_Al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ np.diag(msp_wl.asnumpy()), np.zeros((n, n)), rtol, | |||
| atol) | |||
| assert np.allclose(sym_Au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ np.diag(msp_wu.asnumpy()), np.zeros((n, n)), rtol, | |||
| atol) | |||
| # test for real scalar float64 no vector | |||
| msp_eigh = EighNet(False, True) | |||
| msp_wl0 = msp_eigh(Tensor(np.array(A).astype(np.float64))) | |||
| msp_eigh = EighNet(False, False) | |||
| msp_wu0 = msp_eigh(Tensor(np.array(A).astype(np.float64))) | |||
| assert np.allclose(msp_wl.asnumpy() - msp_wl0.asnumpy(), np.zeros((n, n)), rtol, atol) | |||
| assert np.allclose(msp_wu.asnumpy() - msp_wu0.asnumpy(), np.zeros((n, n)), rtol, atol) | |||
| # test case for complex64 | |||
| rtol = 1e-3 | |||
| @@ -86,18 +92,12 @@ def test_eigh_net(n: int): | |||
| A[i][j] = complex(np.random.rand(1, 1), 0) | |||
| else: | |||
| A[i][j] = complex(np.random.rand(1, 1), np.random.rand(1, 1)) | |||
| msp_eigh = EighNet(True) | |||
| sym_Al = (np.tril((np.tril(A) - np.tril(A).T)) + np.tril(A).conj().T) | |||
| sym_Au = (np.triu((np.triu(A) - np.triu(A).T)) + np.triu(A).conj().T) | |||
| msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.complex64)), True) | |||
| msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.complex64)), False) | |||
| # Compare with scipy, scipy passed | |||
| # sp_wl, sp_vl = sp.linalg.eigh(np.tril(A).astype(np.complex128), lower=True, eigvals_only=False) | |||
| # sp_wu, sp_vu = sp.linalg.eigh(A.astype(np.complex128), lower=False, eigvals_only=False) | |||
| # assert np.allclose(sym_Al @ sp_vl - sp_vl @ np.diag(sp_wl), np.zeros((n, n)), rtol, atol) | |||
| # assert np.allclose(sym_Au @ sp_vu - sp_vu @ np.diag(sp_wu), np.zeros((n, n)), rtol, atol) | |||
| # print(A @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy())) | |||
| msp_eigh = EighNet(True, True) | |||
| msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.complex64))) | |||
| msp_eigh = EighNet(True, False) | |||
| msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.complex64))) | |||
| assert np.allclose(sym_Al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ np.diag(msp_wl.asnumpy()), np.zeros((n, n)), rtol, | |||
| atol) | |||
| assert np.allclose(sym_Au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ np.diag(msp_wu.asnumpy()), np.zeros((n, n)), rtol, | |||
| @@ -113,19 +113,21 @@ def test_eigh_net(n: int): | |||
| A[i][j] = complex(np.random.rand(1, 1), 0) | |||
| else: | |||
| A[i][j] = complex(np.random.rand(1, 1), np.random.rand(1, 1)) | |||
| msp_eigh = EighNet(True) | |||
| sym_Al = (np.tril((np.tril(A) - np.tril(A).T)) + np.tril(A).conj().T) | |||
| sym_Au = (np.triu((np.triu(A) - np.triu(A).T)) + np.triu(A).conj().T) | |||
| msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.complex128)), True) | |||
| msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.complex128)), False) | |||
| # Compare with scipy, scipy passed | |||
| # sp_wl, sp_vl = sp.linalg.eigh(np.tril(A).astype(np.complex128), lower=True, eigvals_only=False) | |||
| # sp_wu, sp_vu = sp.linalg.eigh(A.astype(np.complex128), lower=False, eigvals_only=False) | |||
| # assert np.allclose(sym_Al @ sp_vl - sp_vl @ np.diag(sp_wl), np.zeros((n, n)), rtol, atol) | |||
| # assert np.allclose(sym_Au @ sp_vu - sp_vu @ np.diag(sp_wu), np.zeros((n, n)), rtol, atol) | |||
| # print(A @ msp_v.asnumpy() - msp_v.asnumpy() @ np.diag(msp_w.asnumpy())) | |||
| msp_eigh = EighNet(True, True) | |||
| msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.complex128))) | |||
| msp_eigh = EighNet(True, False) | |||
| msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.complex128))) | |||
| assert np.allclose(sym_Al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ np.diag(msp_wl.asnumpy()), np.zeros((n, n)), rtol, | |||
| atol) | |||
| assert np.allclose(sym_Au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ np.diag(msp_wu.asnumpy()), np.zeros((n, n)), rtol, | |||
| atol) | |||
| # test for complex128 no vector | |||
| msp_eigh = EighNet(False, True) | |||
| msp_wl0 = msp_eigh(Tensor(np.array(A).astype(np.complex128))) | |||
| msp_eigh = EighNet(False, False) | |||
| msp_wu0 = msp_eigh(Tensor(np.array(A).astype(np.complex128))) | |||
| assert np.allclose(msp_wl.asnumpy() - msp_wl0.asnumpy(), np.zeros((n, n)), rtol, atol) | |||
| assert np.allclose(msp_wu.asnumpy() - msp_wu0.asnumpy(), np.zeros((n, n)), rtol, atol) | |||
| @@ -47,26 +47,35 @@ def test_eigh_net(n: int): | |||
| # test for real scalar float 32 | |||
| rtol = 1e-3 | |||
| atol = 1e-4 | |||
| msp_eigh = EighNet(True) | |||
| A = create_sym_pos_matrix(n, n, np.float32) | |||
| msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.float32)), True) | |||
| msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.float32)), False) | |||
| assert np.allclose(A @ msp_vl.T.asnumpy() - msp_vl.T.asnumpy() @ np.diag(msp_wl.asnumpy()), np.zeros((n, n)), rtol, | |||
| msp_eigh = EighNet(True, True) | |||
| msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.float32))) | |||
| msp_eigh = EighNet(True, False) | |||
| msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.float32))) | |||
| assert np.allclose(A @ msp_vl.asnumpy() - msp_vl.asnumpy() @ np.diag(msp_wl.asnumpy()), np.zeros((n, n)), rtol, | |||
| atol) | |||
| assert np.allclose(A @ msp_vu.T.asnumpy() - msp_vu.T.asnumpy() @ np.diag(msp_wu.asnumpy()), np.zeros((n, n)), rtol, | |||
| assert np.allclose(A @ msp_vu.asnumpy() - msp_vu.asnumpy() @ np.diag(msp_wu.asnumpy()), np.zeros((n, n)), rtol, | |||
| atol) | |||
| # test case for real scalar double 64 | |||
| A = create_sym_pos_matrix(n, n, np.float64) | |||
| rtol = 1e-5 | |||
| atol = 1e-8 | |||
| msp_eigh = EighNet(True) | |||
| msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.float64)), True) | |||
| msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.float64)), False) | |||
| assert np.allclose(A @ msp_vl.T.asnumpy() - msp_vl.T.asnumpy() @ np.diag(msp_wl.asnumpy()), np.zeros((n, n)), rtol, | |||
| msp_eigh = EighNet(True, True) | |||
| msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.float64))) | |||
| msp_eigh = EighNet(True, False) | |||
| msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.float64))) | |||
| assert np.allclose(A @ msp_vl.asnumpy() - msp_vl.asnumpy() @ np.diag(msp_wl.asnumpy()), np.zeros((n, n)), rtol, | |||
| atol) | |||
| assert np.allclose(A @ msp_vu.T.asnumpy() - msp_vu.T.asnumpy() @ np.diag(msp_wu.asnumpy()), np.zeros((n, n)), rtol, | |||
| assert np.allclose(A @ msp_vu.asnumpy() - msp_vu.asnumpy() @ np.diag(msp_wu.asnumpy()), np.zeros((n, n)), rtol, | |||
| atol) | |||
| # test for real scalar float64 no vector | |||
| msp_eigh = EighNet(False, True) | |||
| msp_wl0 = msp_eigh(Tensor(np.array(A).astype(np.float64))) | |||
| msp_eigh = EighNet(False, False) | |||
| msp_wu0 = msp_eigh(Tensor(np.array(A).astype(np.float64))) | |||
| assert np.allclose(msp_wl.asnumpy() - msp_wl0.asnumpy(), np.zeros((n, n)), rtol, atol) | |||
| assert np.allclose(msp_wu.asnumpy() - msp_wu0.asnumpy(), np.zeros((n, n)), rtol, atol) | |||
| # test case for complex64 | |||
| rtol = 1e-3 | |||
| @@ -78,14 +87,15 @@ def test_eigh_net(n: int): | |||
| A[i][j] = complex(np.random.rand(1, 1), 0) | |||
| else: | |||
| A[i][j] = complex(np.random.rand(1, 1), np.random.rand(1, 1)) | |||
| msp_eigh = EighNet(True) | |||
| sym_Al = (np.tril((np.tril(A) - np.tril(A).T)) + np.tril(A).conj().T) | |||
| sym_Au = (np.triu((np.triu(A) - np.triu(A).T)) + np.triu(A).conj().T) | |||
| msp_wl, msp_vl = msp_eigh(Tensor(np.array(sym_Al).astype(np.complex64)), True) | |||
| msp_wu, msp_vu = msp_eigh(Tensor(np.array(sym_Au).astype(np.complex64)), False) | |||
| assert np.allclose(sym_Al @ msp_vl.asnumpy().conj().T - msp_vl.asnumpy().conj().T @ np.diag(msp_wl.asnumpy()), | |||
| msp_eigh = EighNet(True, True) | |||
| msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.complex64))) | |||
| msp_eigh = EighNet(True, False) | |||
| msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.complex64))) | |||
| assert np.allclose(sym_Al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ np.diag(msp_wl.asnumpy()), | |||
| np.zeros((n, n)), rtol, atol) | |||
| assert np.allclose(sym_Au @ msp_vu.asnumpy().conj().T - msp_vu.asnumpy().conj().T @ np.diag(msp_wu.asnumpy()), | |||
| assert np.allclose(sym_Au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ np.diag(msp_wu.asnumpy()), | |||
| np.zeros((n, n)), rtol, atol) | |||
| # test for complex128 | |||
| @@ -94,17 +104,24 @@ def test_eigh_net(n: int): | |||
| A = np.array(np.random.rand(n, n), dtype=np.complex128) | |||
| for i in range(0, n): | |||
| for j in range(0, n): | |||
| if i == j: | |||
| A[i][j] = complex(np.random.rand(1, 1), 0) | |||
| else: | |||
| A[i][j] = complex(np.random.rand(1, 1), np.random.rand(1, 1)) | |||
| msp_eigh = EighNet(True) | |||
| sym_Al = (np.tril((np.tril(A) - np.tril(A).T)) + np.tril(A).conj().T) | |||
| sym_Au = (np.triu((np.triu(A) - np.triu(A).T)) + np.triu(A).conj().T) | |||
| msp_wl, msp_vl = msp_eigh(Tensor(np.array(sym_Al).astype(np.complex128)), True) | |||
| msp_wu, msp_vu = msp_eigh(Tensor(np.array(sym_Au).astype(np.complex128)), False) | |||
| assert np.allclose(sym_Al @ msp_vl.asnumpy().conj().T - msp_vl.asnumpy().conj().T @ np.diag(msp_wl.asnumpy()), | |||
| msp_eigh = EighNet(True, True) | |||
| msp_wl, msp_vl = msp_eigh(Tensor(np.array(A).astype(np.complex128))) | |||
| msp_eigh = EighNet(True, False) | |||
| msp_wu, msp_vu = msp_eigh(Tensor(np.array(A).astype(np.complex128))) | |||
| assert np.allclose(sym_Al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ np.diag(msp_wl.asnumpy()), | |||
| np.zeros((n, n)), rtol, atol) | |||
| assert np.allclose(sym_Au @ msp_vu.asnumpy().conj().T - msp_vu.asnumpy().conj().T @ np.diag(msp_wu.asnumpy()), | |||
| assert np.allclose(sym_Au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ np.diag(msp_wu.asnumpy()), | |||
| np.zeros((n, n)), rtol, atol) | |||
| # test for complex128 no vector | |||
| msp_eigh = EighNet(False, True) | |||
| msp_wl0 = msp_eigh(Tensor(np.array(A).astype(np.complex128))) | |||
| msp_eigh = EighNet(False, False) | |||
| msp_wu0 = msp_eigh(Tensor(np.array(A).astype(np.complex128))) | |||
| assert np.allclose(msp_wl.asnumpy() - msp_wl0.asnumpy(), np.zeros((n, n)), rtol, atol) | |||
| assert np.allclose(msp_wu.asnumpy() - msp_wu0.asnumpy(), np.zeros((n, n)), rtol, atol) | |||
| @@ -139,3 +139,91 @@ def test_cholesky_solver(n: int, lower: bool, dtype): | |||
| # pre tensor_a has been inplace. | |||
| tensor_a = Tensor(a) | |||
| assert onp.allclose(onp.dot(a, osp_x), mnp.dot(tensor_a, msp_x).asnumpy()) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| @pytest.mark.parametrize('n', [4, 6, 9, 20]) | |||
| def test_eigh_solver(n: int): | |||
| """ | |||
| Feature: ALL TO ALL | |||
| Description: test cases for the eigenvalue/eigenvector solver for symmetric/Hermitian matrices [N, N] | |||
| Expectation: the result matches scipy.linalg.eigh | |||
| """ | |||
| # test for real scalar float 32 | |||
| rtol = 1e-3 | |||
| atol = 1e-4 | |||
| A = create_sym_pos_matrix([n, n], onp.float32) | |||
| msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(A).astype(onp.float32)), lower=True, eigvals_only=False) | |||
| msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(A).astype(onp.float32)), lower=False, eigvals_only=False) | |||
| assert onp.allclose(A @ msp_vl.asnumpy() - msp_vl.asnumpy() @ onp.diag(msp_wl.asnumpy()), onp.zeros((n, n)), | |||
| rtol, | |||
| atol) | |||
| assert onp.allclose(A @ msp_vu.asnumpy() - msp_vu.asnumpy() @ onp.diag(msp_wu.asnumpy()), onp.zeros((n, n)), | |||
| rtol, | |||
| atol) | |||
| # test case for real scalar double 64 | |||
| A = create_sym_pos_matrix([n, n], onp.float64) | |||
| rtol = 1e-5 | |||
| atol = 1e-8 | |||
| msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(A).astype(onp.float64)), lower=True, eigvals_only=False) | |||
| msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(A).astype(onp.float64)), lower=False, eigvals_only=False) | |||
| assert onp.allclose(A @ msp_vl.asnumpy() - msp_vl.asnumpy() @ onp.diag(msp_wl.asnumpy()), onp.zeros((n, n)), | |||
| rtol, | |||
| atol) | |||
| assert onp.allclose(A @ msp_vu.asnumpy() - msp_vu.asnumpy() @ onp.diag(msp_wu.asnumpy()), onp.zeros((n, n)), | |||
| rtol, | |||
| atol) | |||
| # test for real scalar float64 no vector | |||
| msp_wl0 = msp.linalg.eigh(Tensor(onp.array(A).astype(onp.float64)), lower=True, eigvals_only=True) | |||
| msp_wu0 = msp.linalg.eigh(Tensor(onp.array(A).astype(onp.float64)), lower=False, eigvals_only=True) | |||
| assert onp.allclose(msp_wl.asnumpy() - msp_wl0.asnumpy(), onp.zeros((n, n)), rtol, atol) | |||
| assert onp.allclose(msp_wu.asnumpy() - msp_wu0.asnumpy(), onp.zeros((n, n)), rtol, atol) | |||
| # test case for complex64 | |||
| rtol = 1e-3 | |||
| atol = 1e-4 | |||
| A = onp.array(onp.random.rand(n, n), dtype=onp.complex64) | |||
| for i in range(0, n): | |||
| for j in range(0, n): | |||
| if i == j: | |||
| A[i][j] = complex(onp.random.rand(1, 1), 0) | |||
| else: | |||
| A[i][j] = complex(onp.random.rand(1, 1), onp.random.rand(1, 1)) | |||
| sym_Al = (onp.tril((onp.tril(A) - onp.tril(A).T)) + onp.tril(A).conj().T) | |||
| sym_Au = (onp.triu((onp.triu(A) - onp.triu(A).T)) + onp.triu(A).conj().T) | |||
| msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(sym_Al).astype(onp.complex64)), lower=True, eigvals_only=False) | |||
| msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(sym_Au).astype(onp.complex64)), lower=False, eigvals_only=False) | |||
| assert onp.allclose(sym_Al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ onp.diag(msp_wl.asnumpy()), | |||
| onp.zeros((n, n)), rtol, atol) | |||
| assert onp.allclose(sym_Au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ onp.diag(msp_wu.asnumpy()), | |||
| onp.zeros((n, n)), rtol, atol) | |||
| # test for complex128 | |||
| rtol = 1e-5 | |||
| atol = 1e-8 | |||
| A = onp.array(onp.random.rand(n, n), dtype=onp.complex128) | |||
| for i in range(0, n): | |||
| for j in range(0, n): | |||
| if i == j: | |||
| A[i][j] = complex(onp.random.rand(1, 1), 0) | |||
| else: | |||
| A[i][j] = complex(onp.random.rand(1, 1), onp.random.rand(1, 1)) | |||
| sym_Al = (onp.tril((onp.tril(A) - onp.tril(A).T)) + onp.tril(A).conj().T) | |||
| sym_Au = (onp.triu((onp.triu(A) - onp.triu(A).T)) + onp.triu(A).conj().T) | |||
| msp_wl, msp_vl = msp.linalg.eigh(Tensor(onp.array(sym_Al).astype(onp.complex128)), lower=True, eigvals_only=False) | |||
| msp_wu, msp_vu = msp.linalg.eigh(Tensor(onp.array(sym_Au).astype(onp.complex128)), lower=False, eigvals_only=False) | |||
| assert onp.allclose(sym_Al @ msp_vl.asnumpy() - msp_vl.asnumpy() @ onp.diag(msp_wl.asnumpy()), | |||
| onp.zeros((n, n)), rtol, atol) | |||
| assert onp.allclose(sym_Au @ msp_vu.asnumpy() - msp_vu.asnumpy() @ onp.diag(msp_wu.asnumpy()), | |||
| onp.zeros((n, n)), rtol, atol) | |||
| # test for complex128 no vector | |||
| msp_wl0 = msp.linalg.eigh(Tensor(onp.array(sym_Al).astype(onp.complex128)), lower=True, eigvals_only=True) | |||
| msp_wu0 = msp.linalg.eigh(Tensor(onp.array(sym_Au).astype(onp.complex128)), lower=False, eigvals_only=True) | |||
| assert onp.allclose(msp_wl.asnumpy() - msp_wl0.asnumpy(), onp.zeros((n, n)), rtol, atol) | |||
| assert onp.allclose(msp_wu.asnumpy() - msp_wu0.asnumpy(), onp.zeros((n, n)), rtol, atol) | |||