|
|
|
@@ -114,10 +114,14 @@ class EighcGpuKernel : public GpuKernel { |
|
|
|
cudaMalloc(reinterpret_cast<void **>(&dev_input_shape), kShape2dDims * sizeof(size_t)); |
|
|
|
size_t *dev_input_axis = nullptr; |
|
|
|
cudaMalloc(reinterpret_cast<void **>(&dev_input_axis), kShape2dDims * sizeof(size_t)); |
|
|
|
cudaMemcpyAsync(dev_input_shape, input_shape, kShape2dDims * sizeof(size_t), cudaMemcpyHostToDevice, |
|
|
|
reinterpret_cast<cudaStream_t>(stream_ptr)); |
|
|
|
cudaMemcpyAsync(dev_input_axis, input_axis, kShape2dDims * sizeof(size_t), cudaMemcpyHostToDevice, |
|
|
|
reinterpret_cast<cudaStream_t>(stream_ptr)); |
|
|
|
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, |
|
|
|
cudaMemcpyAsync(dev_input_shape, input_shape, kShape2dDims * sizeof(size_t), |
|
|
|
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), |
|
|
|
"malloc input shape workspace failed"); |
|
|
|
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, |
|
|
|
cudaMemcpyAsync(dev_input_axis, input_axis, kShape2dDims * sizeof(size_t), |
|
|
|
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), |
|
|
|
"malloc input shape workspace failed"); |
|
|
|
CalTranspose(m_ * m_, output_v_addr, dev_input_shape, dev_input_axis, kShape2dDims, w_v_addr, |
|
|
|
reinterpret_cast<cudaStream_t>(stream_ptr)); |
|
|
|
|
|
|
|
@@ -145,6 +149,12 @@ class EighcGpuKernel : public GpuKernel { |
|
|
|
reinterpret_cast<cudaStream_t>(stream_ptr)); |
|
|
|
CalTranspose(m_ * m_, w_v_addr, dev_input_shape, dev_input_axis, kShape2dDims, output_v_addr, |
|
|
|
reinterpret_cast<cudaStream_t>(stream_ptr)); |
|
|
|
if (dev_input_shape) { |
|
|
|
cudaFree(dev_input_shape); |
|
|
|
} |
|
|
|
if (dev_input_axis) { |
|
|
|
cudaFree(dev_input_axis); |
|
|
|
} |
|
|
|
// convert real scalar to complex |
|
|
|
if (d_work) { |
|
|
|
cudaFree(d_work); |
|
|
|
|