| @@ -0,0 +1,54 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "triangle_matrix_copy_impl.cuh" | |||||
| template <typename T> | |||||
| __global__ void TriangleMatrixCopyKernel(const T *input, T *output, cublasFillMode_t uplo, | |||||
| const size_t count, const size_t ldb, const size_t m) { | |||||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { | |||||
| size_t batchIdx = i / (ldb * m); | |||||
| size_t row = (i - batchIdx * ldb * m) / m; | |||||
| size_t col = (i - batchIdx * ldb * m) % m; | |||||
| // If fill mode is 'CUBLAS_FILL_MODE_UPPER', the upper half of the matrix should be all 0; | |||||
| // If fill mode is 'CUBLAS_FILL_MODE_LOWER', the lower half of the matrix should be all 0; | |||||
| if (uplo == CUBLAS_FILL_MODE_UPPER) { | |||||
| if (col > row) { | |||||
| output[i] = 0; | |||||
| } else { | |||||
| output[i] = input[i]; | |||||
| } | |||||
| } else if (uplo == CUBLAS_FILL_MODE_LOWER) { | |||||
| if (col < row) { | |||||
| output[i] = 0; | |||||
| } else { | |||||
| output[i] = input[i]; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| template <typename T> | |||||
| void TriangleMatrixCopy(const T *input, T *output, cublasFillMode_t uplo, | |||||
| const size_t count, const size_t ldb, const size_t m, cudaStream_t cuda_stream) { | |||||
| TriangleMatrixCopyKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, uplo, count, ldb, m); | |||||
| return; | |||||
| } | |||||
| template void TriangleMatrixCopy<float>(const float *input, float *output, cublasFillMode_t uplo, const size_t count, | |||||
| const size_t ldb, const size_t m, cudaStream_t cuda_stream); | |||||
| template void TriangleMatrixCopy<half>(const half *input, half *output, cublasFillMode_t uplo, const size_t count, | |||||
| const size_t ldb, const size_t m, cudaStream_t cuda_stream); | |||||
| @@ -0,0 +1,24 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TRIANGLEMATRIXCOPYIMPL_H_ | |||||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TRIANGLEMATRIXCOPYIMPL_H_ | |||||
| #include "runtime/device/gpu/cuda_common.h" | |||||
| template <typename T> | |||||
| void TriangleMatrixCopy(const T *input, T *output, cublasFillMode_t uplo, | |||||
| const size_t count, const size_t ldb, const size_t m, cudaStream_t cuda_stream); | |||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TRIANGLEMATRIXCOPYIMPL_H_ | |||||
| @@ -0,0 +1,23 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| MS_REG_GPU_KERNEL_ONE(Cholesky, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| CholeskyGpuKernel, float) | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,247 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CHOLESKY_SOLVE_GPU_KERNEL_H | |||||
| #define MINDSPORE_CHOLESKY_SOLVE_GPU_KERNEL_H | |||||
| #include <cublas_v2.h> | |||||
| #include <cuda_runtime_api.h> | |||||
| #include <vector> | |||||
| #include <algorithm> | |||||
| #include "backend/kernel_compiler/gpu/cuda_impl/identity_impl.cuh" | |||||
| #include "backend/kernel_compiler/gpu/cuda_impl/matrix_split_impl.cuh" | |||||
| #include "backend/kernel_compiler/gpu/cuda_impl/triangle_matrix_copy_impl.cuh" | |||||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||||
| #include "backend/kernel_compiler/gpu/kernel_constants.h" | |||||
| #include "utils/convert_utils.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| template <typename T> | |||||
| class CholeskyGpuKernel : public GpuKernel { | |||||
| public: | |||||
| CholeskyGpuKernel() : batch_(0), m_(0), lda_(0), is_null_input_(false), handle_(nullptr) {} | |||||
| ~CholeskyGpuKernel() = default; | |||||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||||
| if (is_null_input_) { | |||||
| return true; | |||||
| } | |||||
| auto input1_addr = GetDeviceAddress<T>(inputs, 0); | |||||
| auto output_addr = GetDeviceAddress<T>(outputs, 0); | |||||
| auto d_array_addr = GetDeviceAddress<T *>(workspace, 0); | |||||
| auto d_identity_addr = GetDeviceAddress<T *>(workspace, 1); | |||||
| if (!use_split_matrix) { | |||||
| auto d_info_array_addr = GetDeviceAddress<int>(workspace, 2); | |||||
| for (size_t i = 0; i < batch_; i++) { | |||||
| h_array[i] = input1_addr + i * lda_ * m_; | |||||
| h_identity[i] = output_addr + i * ldb_ * m_; | |||||
| CHECK_CUDA_RET_WITH_ERROR( | |||||
| cudaMemcpyAsync(output_addr + i * ldb_ * m_, h_identity_data.data(), sizeof(T) * ldb_ * m_, | |||||
| cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), | |||||
| "cuda memcopy Fail"); | |||||
| } | |||||
| CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(d_array_addr, h_array.data(), sizeof(T *) * batch_, | |||||
| cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), | |||||
| "cuda memcopy Fail"); | |||||
| CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(d_identity_addr, h_identity.data(), sizeof(T *) * batch_, | |||||
| cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), | |||||
| "cuda memcopy Fail"); | |||||
| CHECK_CUSOLVER_RET_WITH_EXCEPT( | |||||
| cusolverDnSpotrfBatched(handle_, uplo, m_, d_array_addr, lda_, d_info_array_addr, batch_), | |||||
| "cusolver cholesky batched Fail"); | |||||
| TriangleMatrixCopy(input1_addr, output_addr, uplo, outputs[0]->size / sizeof(T), ldb_, m_, | |||||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| } else { | |||||
| auto d_info_array_addr = GetDeviceAddress<int>(workspace, 2); | |||||
| auto d_batch_input_addr = GetDeviceAddress<T>(workspace, 3); | |||||
| for (size_t i = 0; i < batch_; i++) { | |||||
| h_array[i] = d_batch_input_addr + i * lda_ * m_; | |||||
| h_identity[i] = output_addr + i * ldb_ * m_; | |||||
| } | |||||
| Identity(batch_ * split_dim * split_dim, split_dim, output_addr, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| MatrixSplit(batch_ * split_dim * split_dim, split_dim, width, input1_addr, d_batch_input_addr, | |||||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(d_array_addr, h_array.data(), sizeof(T *) * batch_, | |||||
| cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), | |||||
| "cuda memcopy Fail"); | |||||
| CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(d_identity_addr, h_identity.data(), sizeof(T *) * batch_, | |||||
| cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), | |||||
| "cuda memcopy Fail"); | |||||
| CHECK_CUSOLVER_RET_WITH_EXCEPT( | |||||
| cusolverDnSpotrfBatched(handle_, uplo, m_, d_array_addr, lda_, d_info_array_addr, batch_), | |||||
| "cusolver cholesky batched Fail"); | |||||
| TriangleMatrixCopy(d_batch_input_addr, output_addr, uplo, outputs[0]->size / sizeof(T), ldb_, m_, | |||||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool Init(const CNodePtr &kernel_node) override { | |||||
| handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCusolverDnHandle(); | |||||
| blas_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle(); | |||||
| auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||||
| split_dim = static_cast<int>(GetAttr<int64_t>(kernel_node, "split_dim")); | |||||
| if (split_dim == 0) { | |||||
| InitNoSpltDim(in_shape); | |||||
| } else { | |||||
| InitSpltDim(in_shape); | |||||
| } | |||||
| return true; | |||||
| } | |||||
| protected: | |||||
| void InitNoSpltDim(const std::vector<size_t> &in_shape) { | |||||
| use_split_matrix = false; | |||||
| if (in_shape.size() == 2) { | |||||
| batch_ = 1; | |||||
| if (in_shape[0] != in_shape[1]) { | |||||
| MS_LOG(ERROR) << "Cholesky need square matrix as input."; | |||||
| } | |||||
| } else if (in_shape.size() == 3) { | |||||
| batch_ = SizeToInt(in_shape[0]); | |||||
| if (in_shape[1] != in_shape[2]) { | |||||
| MS_LOG(ERROR) << "Cholesky need square matrix as input."; | |||||
| } | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Input Only support Rank 2 OR 3"; | |||||
| } | |||||
| m_ = SizeToInt(in_shape[1]); | |||||
| lda_ = m_; | |||||
| ldb_ = m_; | |||||
| h_array.resize(batch_); | |||||
| h_identity.resize(batch_); | |||||
| h_identity_data.resize(m_ * m_); | |||||
| for (size_t i = 0; i < m_; i++) { | |||||
| for (size_t j = 0; j < m_; j++) { | |||||
| if (i == j) { | |||||
| h_identity_data[i * m_ + j] = 1; | |||||
| } else { | |||||
| h_identity_data[i * m_ + j] = 0; | |||||
| } | |||||
| } | |||||
| } | |||||
| InitSizeLists(); | |||||
| } | |||||
| void InitSpltDim(const std::vector<size_t> &in_shape) { | |||||
| if (in_shape.size() != 2) { | |||||
| MS_LOG(ERROR) << "Cholesky Split Matrix Need Input Rank as 2."; | |||||
| } | |||||
| height = in_shape[0]; | |||||
| width = in_shape[1]; | |||||
| if (height != width) { | |||||
| MS_LOG(ERROR) << "Cholesky Split Matrix Need Square Matrix as Input."; | |||||
| } | |||||
| if (SizeToInt(height) <= split_dim) { | |||||
| use_split_matrix = false; | |||||
| batch_ = 1; | |||||
| m_ = SizeToInt(in_shape[1]); | |||||
| lda_ = m_; | |||||
| ldb_ = m_; | |||||
| h_array.resize(batch_); | |||||
| h_identity.resize(batch_); | |||||
| h_identity_data.resize(m_ * m_); | |||||
| for (size_t i = 0; i < m_; i++) { | |||||
| for (size_t j = 0; j < m_; j++) { | |||||
| if (i == j) { | |||||
| h_identity_data[i * m_ + j] = 1; | |||||
| } else { | |||||
| h_identity_data[i * m_ + j] = 0; | |||||
| } | |||||
| } | |||||
| } | |||||
| InitSizeLists(); | |||||
| } else { | |||||
| use_split_matrix = true; | |||||
| int batch = SizeToInt(in_shape[1]) / split_dim; | |||||
| res_dim = in_shape[1] - batch * split_dim; | |||||
| if (res_dim == 0) { | |||||
| batch_ = batch; | |||||
| } else { | |||||
| batch_ = batch + 1; | |||||
| } | |||||
| m_ = split_dim; | |||||
| lda_ = m_; | |||||
| ldb_ = m_; | |||||
| h_array.resize(batch_); | |||||
| h_identity.resize(batch_); | |||||
| h_identity_data.resize(m_ * m_); | |||||
| for (size_t i = 0; i < m_; i++) { | |||||
| for (size_t j = 0; j < m_; j++) { | |||||
| if (i == j) { | |||||
| h_identity_data[i * m_ + j] = 1; | |||||
| } else { | |||||
| h_identity_data[i * m_ + j] = 0; | |||||
| } | |||||
| } | |||||
| } | |||||
| InitSizeLists(); | |||||
| } | |||||
| } | |||||
| void InitSizeLists() override { | |||||
| size_t unit_size = sizeof(T); | |||||
| size_t input_size; | |||||
| size_t workspace_size; | |||||
| if (!use_split_matrix) { | |||||
| input_size = batch_ * m_ * lda_ * unit_size; | |||||
| } else { | |||||
| input_size = height * width * unit_size; | |||||
| workspace_size = batch_ * m_ * lda_ * unit_size; | |||||
| workspace_size_list_.push_back(workspace_size); | |||||
| } | |||||
| input_size_list_.push_back(input_size); | |||||
| size_t output_size = batch_ * m_ * lda_ * unit_size; | |||||
| output_size_list_.push_back(output_size); | |||||
| workspace_size = batch_ * sizeof(T *); | |||||
| workspace_size_list_.insert(workspace_size_list_.begin(), workspace_size); | |||||
| workspace_size = batch_ * sizeof(T *); | |||||
| workspace_size_list_.insert(workspace_size_list_.begin(), workspace_size); | |||||
| workspace_size = batch_ * sizeof(int); | |||||
| workspace_size_list_.insert(workspace_size_list_.begin(), workspace_size); | |||||
| } | |||||
| private: | |||||
| size_t batch_; | |||||
| size_t m_; | |||||
| size_t lda_; | |||||
| size_t ldb_; | |||||
| int res_dim; | |||||
| int split_dim; | |||||
| bool is_null_input_; | |||||
| bool use_split_matrix; | |||||
| size_t height; | |||||
| size_t width; | |||||
| cusolverDnHandle_t handle_; | |||||
| cublasHandle_t blas_handle_; | |||||
| cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER; | |||||
| std::vector<T *> h_array; | |||||
| std::vector<T *> h_identity; | |||||
| std::vector<T> h_identity_data; | |||||
| std::vector<size_t> input_size_list_; | |||||
| std::vector<size_t> output_size_list_; | |||||
| std::vector<size_t> workspace_size_list_; | |||||
| }; | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif | |||||
| @@ -88,7 +88,7 @@ from .other_ops import (Assign, InplaceAssign, IOU, BoundingBoxDecode, BoundingB | |||||
| from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft, | from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft, | ||||
| CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314, | CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314, | ||||
| CusMatMulCubeDenseRight, | CusMatMulCubeDenseRight, | ||||
| CusMatMulCubeFraczLeftCast, Im2Col, UpdateThorGradient, CholeskyTrsm, DetTriangle) | |||||
| CusMatMulCubeFraczLeftCast, Im2Col, UpdateThorGradient, Cholesky, CholeskyTrsm, DetTriangle) | |||||
| from .sparse_ops import SparseToDense | from .sparse_ops import SparseToDense | ||||
| from ._cache_ops import CacheSwapHashmap, SearchCacheIdx, CacheSwapTable, UpdateCache, MapCacheIdx | from ._cache_ops import CacheSwapHashmap, SearchCacheIdx, CacheSwapTable, UpdateCache, MapCacheIdx | ||||
| @@ -608,9 +608,42 @@ class UpdateThorGradient(PrimitiveWithInfer): | |||||
| return x2_dtype | return x2_dtype | ||||
| class Cholesky(PrimitiveWithInfer): | |||||
| """ | |||||
| Inner API for positive-definite matrix Cholesky decomposition GPU backend. | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self, split_dim=0): | |||||
| self.init_prim_io_names(inputs=['x1'], outputs=['y']) | |||||
| self.split_dim = split_dim | |||||
| self.add_prim_attr('split_dim', self.split_dim) | |||||
| def infer_shape(self, x1_shape): | |||||
| if self.split_dim != 0: | |||||
| assert len(x1_shape) == 2 | |||||
| height = x1_shape[0] | |||||
| width = x1_shape[1] | |||||
| assert height == width | |||||
| if height <= self.split_dim: | |||||
| out_shape = [1, height, width] | |||||
| else: | |||||
| batch = height // self.split_dim | |||||
| if height != batch * self.split_dim: | |||||
| batch += 1 | |||||
| out_shape = [batch, self.split_dim, self.split_dim] | |||||
| else: | |||||
| out_shape = x1_shape | |||||
| return out_shape | |||||
| def infer_dtype(self, x1_dtype): | |||||
| validator.check_tensor_dtype_valid('x1', x1_dtype, [mstype.float32], self.name) | |||||
| return x1_dtype | |||||
| class CholeskyTrsm(PrimitiveWithInfer): | class CholeskyTrsm(PrimitiveWithInfer): | ||||
| """ | """ | ||||
| Inner API for resnet50 THOR GPU backend | |||||
| Inner API for resnet50 THOR GPU backend. | |||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| @@ -643,7 +676,23 @@ class CholeskyTrsm(PrimitiveWithInfer): | |||||
| class DetTriangle(PrimitiveWithInfer): | class DetTriangle(PrimitiveWithInfer): | ||||
| """ | """ | ||||
| Calculate the determinant of triangle matrices | |||||
| Calculate the determinant of triangle matrices. | |||||
| Args: | |||||
| fill_mode (tuple): The target shape to broadcast. | |||||
| Inputs: | |||||
| - **input_x** (Tensor) - The input tensor. | |||||
| Outputs: | |||||
| Tensor, with the given `shape` and the same data type as `input_x`. | |||||
| Examples: | |||||
| >>> shape = (2, 3) | |||||
| >>> input_x = Tensor(np.array([1, 2, 3]).astype(np.float32)) | |||||
| >>> broadcast_to = P.BroadcastTo(shape) | |||||
| >>> broadcast_to(input_x) | |||||
| [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]] | |||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| @@ -0,0 +1,44 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.common import dtype as mstype | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||||
| class NetCholesky(nn.Cell): | |||||
| def __init__(self): | |||||
| super(NetCholesky, self).__init__() | |||||
| self.cholesky = P.Cholesky() | |||||
| def construct(self, x): | |||||
| return self.cholesky(x) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_cholesky_fp32(): | |||||
| cholesky = NetCholesky() | |||||
| x = np.array([[4, 12, -16], [12, 37, -43], [-16, -43, 98]]).astype(np.float32) | |||||
| output = cholesky(Tensor(x, dtype=mstype.float32)) | |||||
| expect = np.linalg.cholesky(x) | |||||
| tol = 1e-6 | |||||
| assert (np.abs(output.asnumpy() - expect) < tol).all() | |||||