Browse Source

fix for async mem_init bilinearResize_grad

fix - typo
tags/v1.4.0
danishfarid 4 years ago
parent
commit
92d9bc7ccd
1 changed files with 6 additions and 2 deletions
  1. +6
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/resize_bilinear_grad_gpu_kernel.h

+ 6
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/resize_bilinear_grad_gpu_kernel.h View File

@@ -41,8 +41,12 @@ class ResizeBilinearGradGpuKernel : public GpuKernel {
T *dx = GetDeviceAddress<T>(outputs, 0);
float h_scale = Scaling(dx_h_, dy_h_, align_corners_);
float w_scale = Scaling(dx_w_, dy_w_, align_corners_);
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaMemset(interim, 0, workspace_size_), "cudaMemset dx_interim failed");
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaMemset(dx, 0, dx_size_), "cudaMemset dx failed");
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemsetAsync(dx, 0, dx_size_, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemsetAsync dx failed");
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemsetAsync(interim, 0, workspace_size_, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemsetAsync dx_interim failed");
CalResizeBilinearGrad(dy, n_, c_, dy_h_, dy_w_, dx_h_, dx_w_, h_scale, w_scale, dx, interim,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;


Loading…
Cancel
Save