|
|
|
@@ -41,8 +41,12 @@ class ResizeBilinearGradGpuKernel : public GpuKernel { |
|
|
|
T *dx = GetDeviceAddress<T>(outputs, 0); |
|
|
|
float h_scale = Scaling(dx_h_, dy_h_, align_corners_); |
|
|
|
float w_scale = Scaling(dx_w_, dy_w_, align_corners_); |
|
|
|
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaMemset(interim, 0, workspace_size_), "cudaMemset dx_interim failed"); |
|
|
|
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaMemset(dx, 0, dx_size_), "cudaMemset dx failed"); |
|
|
|
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, |
|
|
|
cudaMemsetAsync(dx, 0, dx_size_, reinterpret_cast<cudaStream_t>(stream_ptr)), |
|
|
|
"cudaMemsetAsync dx failed"); |
|
|
|
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, |
|
|
|
cudaMemsetAsync(interim, 0, workspace_size_, reinterpret_cast<cudaStream_t>(stream_ptr)), |
|
|
|
"cudaMemsetAsync dx_interim failed"); |
|
|
|
CalResizeBilinearGrad(dy, n_, c_, dy_h_, dy_w_, dx_h_, dx_w_, h_scale, w_scale, dx, interim, |
|
|
|
reinterpret_cast<cudaStream_t>(stream_ptr)); |
|
|
|
return true; |
|
|
|
|