@@ -47,26 +47,72 @@ class TrsmGpuKernel : public GpuKernel {
     auto inputb_addr = GetDeviceAddress<T>(inputs, 1);
     auto output_addr = GetDeviceAddress<T>(outputs, 0);
 
-    const size_t batch = m_ * n_;
+    // if b is not a vector, solve b in the workspace
+    T *dst = nullptr;
+    if (n_ == 1) {
+      dst = output_addr;
+    } else {
+      dst = GetDeviceAddress<T>(workspace, 0);
+    }
 
-    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
-                               cudaMemcpyAsync(output_addr, inputb_addr, batch * sizeof(T), cudaMemcpyDeviceToDevice,
-                                               reinterpret_cast<cudaStream_t>(stream_ptr)),
-                               "cudaMemcpyAsync output_addr failed");
+    if (n_ == 1) {
+      const size_t batch = m_ * n_;
+      CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
+                                 cudaMemcpyAsync(dst, inputb_addr, batch * sizeof(T), cudaMemcpyDeviceToDevice,
+                                                 reinterpret_cast<cudaStream_t>(stream_ptr)),
+                                 "cudaMemcpyAsync dst failed");
+    } else {
+      T alpha = 1;
+      T beta = 0;
+      // in order to convert row major matrix b(m x n) to col major matrix b'(m x n),
+      // the following operation is equivalent to:
+      // b' = b.T.reshape(m, n)
+      if constexpr (std::is_same_v<T, float>) {
+        CHECK_CUBLAS_RET_WITH_EXCEPT(kernel_node_,
+                                     cublasSgeam(blas_handle_, CUBLAS_OP_T, CUBLAS_OP_T, m_, n_, &alpha, inputb_addr,
+                                                 n_, &beta, inputb_addr, n_, dst, m_),
+                                     "cublas transpose b Fail");
+      } else {
+        CHECK_CUBLAS_RET_WITH_EXCEPT(kernel_node_,
+                                     cublasDgeam(blas_handle_, CUBLAS_OP_T, CUBLAS_OP_T, m_, n_, &alpha, inputb_addr,
+                                                 n_, &beta, inputb_addr, n_, dst, m_),
+                                     "cublas transpose b Fail");
+      }
+    }
 
     T alpha = 1;
     if constexpr (std::is_same_v<T, float>) {
       CHECK_CUBLAS_RET_WITH_EXCEPT(kernel_node_,
                                    cublasStrsm(blas_handle_, CUBLAS_SIDE_LEFT, uplo_, trans_, unit_diagonal_, m_, n_,
-                                               &alpha, inputA_addr, lda_, output_addr, ldb_),
+                                               &alpha, inputA_addr, lda_, dst, ldb_),
                                    "cublas trsm Fail");
     } else {
       CHECK_CUBLAS_RET_WITH_EXCEPT(kernel_node_,
                                    cublasDtrsm(blas_handle_, CUBLAS_SIDE_LEFT, uplo_, trans_, unit_diagonal_, m_, n_,
-                                               &alpha, inputA_addr, lda_, output_addr, ldb_),
+                                               &alpha, inputA_addr, lda_, dst, ldb_),
                                    "cublas trsm Fail");
     }
 
+    // if x is not a vector, do transpose
+    if (n_ != 1) {
+      T alpha = 1;
+      T beta = 0;
+      // in order to convert col major matrix x'(m x n) to row major matrix x(m x n),
+      // the following operation is equivalent to:
+      // x = x'.reshape(n, m).T
+      if constexpr (std::is_same_v<T, float>) {
+        CHECK_CUBLAS_RET_WITH_EXCEPT(
+          kernel_node_,
+          cublasSgeam(blas_handle_, CUBLAS_OP_T, CUBLAS_OP_T, n_, m_, &alpha, dst, m_, &beta, dst, m_, output_addr, n_),
+          "cublas transpose x Fail");
+      } else {
+        CHECK_CUBLAS_RET_WITH_EXCEPT(
+          kernel_node_,
+          cublasDgeam(blas_handle_, CUBLAS_OP_T, CUBLAS_OP_T, n_, m_, &alpha, dst, m_, &beta, dst, m_, output_addr, n_),
+          "cublas transpose x Fail");
+      }
+    }
+
     return true;
   }
 
   bool Init(const CNodePtr &kernel_node) override {
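
cuBLAS assumes column-major storage while MindSpore tensors are row-major, which is why the matrix path above first transposes b into the workspace, lets `cublas<S|D>trsm` overwrite that copy with the solution, and then transposes the result back into `output_addr`. The single-geam trick works because the bytes of a row-major m x n matrix, reinterpreted as a column-major n x m matrix, are exactly b^T. A minimal host-side sketch of the index math (illustrative only, not part of the patch):

```cpp
#include <cstdio>

int main() {
  const int m = 2, n = 3;
  // b in row-major order: [[1, 2, 3], [4, 5, 6]]
  double b[m * n] = {1, 2, 3, 4, 5, 6};
  double b_col[m * n];  // b in column-major order, as cuBLAS expects
  // Reading b's buffer as a column-major n x m matrix yields b^T, so one
  // transpose (what cublas<S|D>geam does above with CUBLAS_OP_T, lda = n_,
  // ldc = m_) recovers b in column-major layout.
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      b_col[j * m + i] = b[i * n + j];
    }
  }
  for (int k = 0; k < m * n; ++k) printf("%g ", b_col[k]);  // 1 4 2 5 3 6
  printf("\n");
  return 0;
}
```

The same reasoning in reverse explains the final geam pair: the column-major solution x' (m x n), transposed as an n x m column-major matrix, lands in `output_addr` in row-major order.
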
@@ -97,9 +143,8 @@ class TrsmGpuKernel : public GpuKernel {
     if (b_shape.size() == kAVectorxDimNum || (b_shape.size() == kAMatrixDimNum && b_shape[kDim1] == 1)) {
       n_ = 1;
     } else {
-      MS_LOG(EXCEPTION) << "b as a matrix is currently not supported.";
       n_ = b_shape[kDim1];
     }
     m_ = b_shape[kDim0];
 
     lda_ = SizeToInt(m_);
     ldb_ = SizeToInt(m_);
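
With the exception removed, `Init()` now accepts a genuine matrix b and reads n_ from its second dimension; both leading dimensions stay m_ because A is m x m and the column-major copy of b has m rows. A standalone sketch of the dispatch, assuming kAVectorxDimNum = 1 and kAMatrixDimNum = 2 (their usual MindSpore values) and a hypothetical shape:

```cpp
#include <cstdio>
#include <vector>

int main() {
  const size_t kAVectorxDimNum = 1, kAMatrixDimNum = 2;
  std::vector<size_t> b_shape = {8, 3};  // hypothetical (m, n) right-hand side
  size_t n = 1;
  if (b_shape.size() == kAVectorxDimNum || (b_shape.size() == kAMatrixDimNum && b_shape[1] == 1)) {
    n = 1;  // b is a vector (or m x 1): solved directly into the output
  } else {
    n = b_shape[1];  // b is a genuine matrix: the workspace path is taken
  }
  size_t m = b_shape[0];
  printf("m = %zu, n = %zu, lda = ldb = %zu\n", m, n, m);  // m = 8, n = 3
  return 0;
}
```
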
@@ -137,8 +182,13 @@ class TrsmGpuKernel : public GpuKernel {
  protected:
   void InitSizeLists() override {
     size_t unit_size = sizeof(T);
-    input_size_list_ = {m_ * m_ * unit_size, m_ * n_ * unit_size};
-    output_size_list_ = {m_ * n_ * unit_size};
+    size_t A_size = m_ * m_ * unit_size;
+    size_t b_size = m_ * n_ * unit_size;
+    input_size_list_ = {A_size, b_size};
+    output_size_list_ = {b_size};
+    if (n_ != 1) {
+      workspace_size_list_ = {b_size};
+    }
   }
 
  private:
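
A quick sanity check of the new size math, assuming float elements and hypothetical shapes m_ = 4, n_ = 3 (illustrative only):

```cpp
#include <cstddef>
#include <cstdio>

int main() {
  const size_t unit_size = sizeof(float);
  const size_t m = 4, n = 3;
  const size_t A_size = m * m * unit_size;  // triangular A: 64 bytes
  const size_t b_size = m * n * unit_size;  // right-hand side b: 48 bytes
  // A matrix b (n != 1) also reserves a same-sized workspace for its
  // column-major copy, mirroring workspace_size_list_ in the kernel.
  const size_t workspace = (n != 1) ? b_size : 0;
  printf("A: %zu B, b: %zu B, workspace: %zu B\n", A_size, b_size, workspace);
  return 0;
}
```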