|
|
|
@@ -14,8 +14,8 @@ |
|
|
|
* limitations under the License. |
|
|
|
*/ |
|
|
|
|
|
|
|
#ifndef MINDSPORE_CHOLESKY_SOLVE_GPU_KERNEL_H |
|
|
|
#define MINDSPORE_CHOLESKY_SOLVE_GPU_KERNEL_H |
|
|
|
#ifndef MINDSPORE_CHOLESKY_TRSM_SOLVE_GPU_KERNEL_H |
|
|
|
#define MINDSPORE_CHOLESKY_TRSM_SOLVE_GPU_KERNEL_H |
|
|
|
#include <cublas_v2.h> |
|
|
|
#include <cuda_runtime_api.h> |
|
|
|
#include <vector> |
|
|
|
@@ -29,10 +29,10 @@ |
|
|
|
namespace mindspore { |
|
|
|
namespace kernel { |
|
|
|
template <typename T> |
|
|
|
class CholeskyGpuKernel : public GpuKernel { |
|
|
|
class CholeskyTrsmGpuKernel : public GpuKernel { |
|
|
|
public: |
|
|
|
CholeskyGpuKernel() : batch_(0), m_(0), lda_(0), is_null_input_(false), handle_(nullptr) {} |
|
|
|
~CholeskyGpuKernel() = default; |
|
|
|
CholeskyTrsmGpuKernel() : batch_(0), m_(0), lda_(0), is_null_input_(false), handle_(nullptr) {} |
|
|
|
~CholeskyTrsmGpuKernel() = default; |
|
|
|
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } |
|
|
|
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } |
|
|
|
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } |
|
|
|
@@ -111,12 +111,12 @@ class CholeskyGpuKernel : public GpuKernel { |
|
|
|
if (in_shape.size() == 2) { |
|
|
|
batch_ = 1; |
|
|
|
if (in_shape[0] != in_shape[1]) { |
|
|
|
MS_LOG(ERROR) << "Cholesky need square matrix as input."; |
|
|
|
MS_LOG(ERROR) << "CholeskyTrsm need square matrix as input."; |
|
|
|
} |
|
|
|
} else if (in_shape.size() == 3) { |
|
|
|
batch_ = SizeToInt(in_shape[0]); |
|
|
|
if (in_shape[1] != in_shape[2]) { |
|
|
|
MS_LOG(ERROR) << "Cholesky need square matrix as input."; |
|
|
|
MS_LOG(ERROR) << "CholeskyTrsm need square matrix as input."; |
|
|
|
} |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "Input Only support Rank 2 OR 3"; |
|
|
|
@@ -140,12 +140,12 @@ class CholeskyGpuKernel : public GpuKernel { |
|
|
|
InitSizeLists(); |
|
|
|
} else { |
|
|
|
if (in_shape.size() != 2) { |
|
|
|
MS_LOG(ERROR) << "Cholesky Split Matrix Need Input Rank as 2."; |
|
|
|
MS_LOG(ERROR) << "CholeskyTrsm Split Matrix Need Input Rank as 2."; |
|
|
|
} |
|
|
|
height = in_shape[0]; |
|
|
|
width = in_shape[1]; |
|
|
|
if (height != width) { |
|
|
|
MS_LOG(ERROR) << "Cholesky Split Matrix Need Square Matrix as Input."; |
|
|
|
MS_LOG(ERROR) << "CholeskyTrsm Split Matrix Need Square Matrix as Input."; |
|
|
|
} |
|
|
|
if (SizeToInt(height) <= split_dim) { |
|
|
|
use_split_matrix = false; |