Browse Source

add GPU trsm 2d matrix support

tags/v1.6.0
zhujingxuan 4 years ago
parent
commit
fb1805de30
3 changed files with 111 additions and 13 deletions
  1. +61
    -11
      mindspore/ccsrc/backend/kernel_compiler/gpu/math/trsm_solve_gpu_kernel.h
  2. +24
    -0
      tests/st/ops/cpu/test_solve_triangular_op.py
  3. +26
    -2
      tests/st/ops/gpu/test_solve_triangular_op.py

+ 61
- 11
mindspore/ccsrc/backend/kernel_compiler/gpu/math/trsm_solve_gpu_kernel.h View File

@@ -47,26 +47,72 @@ class TrsmGpuKernel : public GpuKernel {
auto inputb_addr = GetDeviceAddress<T>(inputs, 1);
auto output_addr = GetDeviceAddress<T>(outputs, 0);

const size_t batch = m_ * n_;
// if b is not a vector, solve b in the workspace
T *dst = nullptr;
if (n_ == 1) {
dst = output_addr;
} else {
dst = GetDeviceAddress<T>(workspace, 0);
}

CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(output_addr, inputb_addr, batch * sizeof(T), cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync output_addr failed");
if (n_ == 1) {
const size_t batch = m_ * n_;
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(dst, inputb_addr, batch * sizeof(T), cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync dst failed");
} else {
T alpha = 1;
T beta = 0;
// in order to convert row major matrix b(m x n) to col major matrix b'(m x n),
// the following operation is equivalent to:
// b' = b.T.reshape(m, n)
if constexpr (std::is_same_v<T, float>) {
CHECK_CUBLAS_RET_WITH_EXCEPT(kernel_node_,
cublasSgeam(blas_handle_, CUBLAS_OP_T, CUBLAS_OP_T, m_, n_, &alpha, inputb_addr,
n_, &beta, inputb_addr, n_, dst, m_),
"cublas transpose b Fail");
} else {
CHECK_CUBLAS_RET_WITH_EXCEPT(kernel_node_,
cublasDgeam(blas_handle_, CUBLAS_OP_T, CUBLAS_OP_T, m_, n_, &alpha, inputb_addr,
n_, &beta, inputb_addr, n_, dst, m_),
"cublas transpose b Fail");
}
}

T alpha = 1;
if constexpr (std::is_same_v<T, float>) {
CHECK_CUBLAS_RET_WITH_EXCEPT(kernel_node_,
cublasStrsm(blas_handle_, CUBLAS_SIDE_LEFT, uplo_, trans_, unit_diagonal_, m_, n_,
&alpha, inputA_addr, lda_, output_addr, ldb_),
&alpha, inputA_addr, lda_, dst, ldb_),
"cublas trsm Fail");
} else {
CHECK_CUBLAS_RET_WITH_EXCEPT(kernel_node_,
cublasDtrsm(blas_handle_, CUBLAS_SIDE_LEFT, uplo_, trans_, unit_diagonal_, m_, n_,
&alpha, inputA_addr, lda_, output_addr, ldb_),
&alpha, inputA_addr, lda_, dst, ldb_),
"cublas trsm Fail");
}

// if x is not a vector, do transpose
if (n_ != 1) {
T alpha = 1;
T beta = 0;
// in order to convert the col major matrix x'(m x n) to the row major matrix x(m x n),
// the following operation is equivalent to:
// x = x'.reshape(n, m).T
if constexpr (std::is_same_v<T, float>) {
CHECK_CUBLAS_RET_WITH_EXCEPT(
kernel_node_,
cublasSgeam(blas_handle_, CUBLAS_OP_T, CUBLAS_OP_T, n_, m_, &alpha, dst, m_, &beta, dst, m_, output_addr, n_),
"cublas transpose x Fail");
} else {
CHECK_CUBLAS_RET_WITH_EXCEPT(
kernel_node_,
cublasDgeam(blas_handle_, CUBLAS_OP_T, CUBLAS_OP_T, n_, m_, &alpha, dst, m_, &beta, dst, m_, output_addr, n_),
"cublas transpose x Fail");
}
}

return true;
}
bool Init(const CNodePtr &kernel_node) override {
@@ -97,9 +143,8 @@ class TrsmGpuKernel : public GpuKernel {
if (b_shape.size() == kAVectorxDimNum || (b_shape.size() == kAMatrixDimNum && b_shape[kDim1] == 1)) {
n_ = 1;
} else {
MS_LOG(EXCEPTION) << "b as a matrix is currently not supported.";
n_ = b_shape[kDim1];
}
m_ = b_shape[kDim0];

lda_ = SizeToInt(m_);
ldb_ = SizeToInt(m_);
@@ -137,8 +182,13 @@ class TrsmGpuKernel : public GpuKernel {
protected:
void InitSizeLists() override {
size_t unit_size = sizeof(T);
input_size_list_ = {m_ * m_ * unit_size, m_ * n_ * unit_size};
output_size_list_ = {m_ * n_ * unit_size};
size_t A_size = m_ * m_ * unit_size;
size_t b_size = m_ * n_ * unit_size;
input_size_list_ = {A_size, b_size};
output_size_list_ = {b_size};
if (n_ != 1) {
workspace_size_list_ = {b_size};
}
}

private:


+ 24
- 0
tests/st/ops/cpu/test_solve_triangular_op.py View File

@@ -118,3 +118,27 @@ def test_1D(n: int, dtype, lower: bool, unit_diagonal: bool, trans: str):
a = (np.random.random((n, n)) + np.eye(n)).astype(dtype)
b = np.random.random(n).astype(dtype)
match(a, b, lower=lower, unit_diagonal=unit_diagonal, trans=trans)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('shape', [(10, 20)])
@pytest.mark.parametrize('trans', ["N", "T"])
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
@pytest.mark.parametrize('lower', [False, True])
@pytest.mark.parametrize('unit_diagonal', [False, True])
def test_matrix(shape: tuple, dtype, lower: bool, unit_diagonal: bool, trans: str):
    """
    Feature: ALL TO ALL
    Description: test cases for solve_triangular with a [M x M] triangular matrix A
        and a [M x N] right-hand-side matrix b (2-D b, not just a vector)
    Expectation: the result matches scipy
    """
    # When solving the transposed system (A^T x = b), swap the roles of the two
    # shape entries so A stays square with side m and b keeps shape (m, n).
    if trans == 'T':
        n, m = shape
    else:
        m, n = shape
    # add Identity matrix to make matrix A non-singular
    a = (np.random.random((m, m)) + np.eye(m)).astype(dtype)
    b = np.random.random((m, n)).astype(dtype)
    match(a, b, lower=lower, unit_diagonal=unit_diagonal, trans=trans)

+ 26
- 2
tests/st/ops/gpu/test_solve_triangular_op.py View File

@@ -79,7 +79,7 @@ def match(a, b, lower, unit_diagonal, trans):


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [10, 20])
@pytest.mark.parametrize('trans', ["N", "T"])
@@ -99,7 +99,7 @@ def test_2D(n: int, dtype, lower: bool, unit_diagonal: bool, trans: str):


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('n', [10, 20])
@pytest.mark.parametrize('trans', ["N", "T"])
@@ -116,3 +116,27 @@ def test_1D(n: int, dtype, lower: bool, unit_diagonal: bool, trans: str):
a = (np.random.random((n, n)) + np.eye(n)).astype(dtype)
b = np.random.random(n).astype(dtype)
match(a, b, lower=lower, unit_diagonal=unit_diagonal, trans=trans)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('shape', [(4, 5), (10, 20)])
@pytest.mark.parametrize('trans', ["N", "T"])
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
@pytest.mark.parametrize('lower', [False, True])
@pytest.mark.parametrize('unit_diagonal', [False, True])
def test_matrix(shape: tuple, dtype, lower: bool, unit_diagonal: bool, trans: str):
    """
    Feature: ALL TO ALL
    Description: test cases for solve_triangular with a [M x M] triangular matrix A
        and a [M x N] right-hand-side matrix b (2-D b, exercising the new GPU path)
    Expectation: the result matches scipy
    """
    # When solving the transposed system (A^T x = b), swap the roles of the two
    # shape entries so A stays square with side m and b keeps shape (m, n).
    if trans == 'T':
        n, m = shape
    else:
        m, n = shape
    # add Identity matrix to make matrix A non-singular
    a = (np.random.random((m, m)) + np.eye(m)).astype(dtype)
    b = np.random.random((m, n)).astype(dtype)
    match(a, b, lower=lower, unit_diagonal=unit_diagonal, trans=trans)

Loading…
Cancel
Save