@@ -50,8 +50,8 @@ class ArrayReduceGpuKernel : public GpuKernel {
     T *output_addr = GetDeviceAddress<T>(outputs, 0);
     T *workspace_addr = GetDeviceAddress<T>(workspace, 0);
 
-    const float alpha = 1;
-    const float beta = 0;
+    T alpha = static_cast<T>(1.0f);
+    T beta = static_cast<T>(0.0f);
     if (all_match_) {
       MS_LOG(DEBUG)
         << "The corresponding dimensions of the input and output tensors all match. No need to call cuDNN kernel.";
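Reviewer note, not part of the patch: cuDNN reads the alpha/beta blend factors through host pointers, and per the cuDNN documentation their storage type follows the tensor data type — double for double tensors, float for float and half tensors. Declaring them as `T` lets the double branch in the next hunk pass `&alpha`/`&beta` straight through, while the other types fall back to float copies. A minimal sketch of that rule as a type trait (the name `CudnnScalingType` is hypothetical, not from the patch):

template <typename T>
struct CudnnScalingType {
  using type = float;  // float and half tensors scale with float alpha/beta
};
template <>
struct CudnnScalingType<double> {
  using type = double;  // double tensors must scale with double alpha/beta
};

// Usage sketch: CudnnScalingType<double>::type alpha = 1.0, beta = 0.0;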
@@ -60,11 +60,21 @@ class ArrayReduceGpuKernel : public GpuKernel {
                                                  reinterpret_cast<cudaStream_t>(stream_ptr)),
                                  "cudaMemcpyAsync failed in ArrayReduceGpuKernel::Launch.");
     } else {
-      CHECK_CUDNN_RET_WITH_EXCEPT(
-        kernel_node_,
-        cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, workspace_addr, workspace_size_, &alpha,
-                          inputA_descriptor_, input_addr, &beta, outputC_descriptor_, output_addr),
-        "cudnnReduceTensor failed.");
+      if (data_type_ == CUDNN_DATA_DOUBLE) {
+        CHECK_CUDNN_RET_WITH_EXCEPT(
+          kernel_node_,
+          cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, workspace_addr, workspace_size_,
+                            &alpha, inputA_descriptor_, input_addr, &beta, outputC_descriptor_, output_addr),
+          "cudnnReduceTensor failed.");
+      } else {
+        const float alphaf = static_cast<float>(alpha);
+        const float betaf = static_cast<float>(beta);
+        CHECK_CUDNN_RET_WITH_EXCEPT(
+          kernel_node_,
+          cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, workspace_addr, workspace_size_,
+                            &alphaf, inputA_descriptor_, input_addr, &betaf, outputC_descriptor_, output_addr),
+          "cudnnReduceTensor failed.");
+      }
     }
     return true;
   }
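Again as reviewer context rather than patch content: the Launch branch above is easier to follow next to the bare cuDNN call sequence it wraps. The sketch below sums a float NCHW tensor over H and W with CUDNN_REDUCE_TENSOR_ADD; every name in it is local to the sketch, and error handling is collapsed to a single status check.

// Standalone sketch (not MindSpore code) of the cuDNN sequence the kernel wraps.
#include <cudnn.h>
#include <cuda_runtime.h>
#include <iostream>
#include <vector>

int main() {
  const int n = 2, c = 3, h = 4, w = 5;

  cudnnHandle_t handle;
  cudnnCreate(&handle);

  // Input (aDesc) and output (cDesc) descriptors; reduced axes get extent 1.
  cudnnTensorDescriptor_t a_desc, c_desc;
  cudnnCreateTensorDescriptor(&a_desc);
  cudnnCreateTensorDescriptor(&c_desc);
  cudnnSetTensor4dDescriptor(a_desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w);
  cudnnSetTensor4dDescriptor(c_desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, 1, 1);

  // Reduction descriptor: op, compute type, NaN policy, no argmin/argmax indices.
  cudnnReduceTensorDescriptor_t reduce_desc;
  cudnnCreateReduceTensorDescriptor(&reduce_desc);
  cudnnSetReduceTensorDescriptor(reduce_desc, CUDNN_REDUCE_TENSOR_ADD, CUDNN_DATA_FLOAT,
                                 CUDNN_NOT_PROPAGATE_NAN, CUDNN_REDUCE_TENSOR_NO_INDICES,
                                 CUDNN_32BIT_INDICES);

  // Workspace size depends on the (reduce_desc, a_desc, c_desc) combination.
  size_t workspace_bytes = 0;
  cudnnGetReductionWorkspaceSize(handle, reduce_desc, a_desc, c_desc, &workspace_bytes);

  float *d_in = nullptr, *d_out = nullptr;
  void *d_workspace = nullptr;
  cudaMalloc(reinterpret_cast<void **>(&d_in), sizeof(float) * n * c * h * w);
  cudaMalloc(reinterpret_cast<void **>(&d_out), sizeof(float) * n * c);
  cudaMalloc(&d_workspace, workspace_bytes);

  std::vector<float> host_in(n * c * h * w, 1.0f);  // all ones -> each output element is h * w
  cudaMemcpy(d_in, host_in.data(), sizeof(float) * host_in.size(), cudaMemcpyHostToDevice);

  // Float tensors, so the scaling factors are floats (double tensors would need doubles).
  const float alpha = 1.0f, beta = 0.0f;
  cudnnStatus_t status =
    cudnnReduceTensor(handle, reduce_desc, nullptr, 0, d_workspace, workspace_bytes,
                      &alpha, a_desc, d_in, &beta, c_desc, d_out);
  if (status != CUDNN_STATUS_SUCCESS) {
    std::cerr << "cudnnReduceTensor failed: " << cudnnGetErrorString(status) << std::endl;
  }

  cudaFree(d_workspace);
  cudaFree(d_out);
  cudaFree(d_in);
  cudnnDestroyReduceTensorDescriptor(reduce_desc);
  cudnnDestroyTensorDescriptor(c_desc);
  cudnnDestroyTensorDescriptor(a_desc);
  cudnnDestroy(handle);
  return 0;
}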
@@ -194,12 +204,12 @@ class ArrayReduceGpuKernel : public GpuKernel {
       MS_LOG(EXCEPTION) << "Array reduce kernel type " << kernel_name << " is not supported.";
     }
     reduce_tensor_op_ = iter->second;
 
-    CHECK_CUDNN_RET_WITH_EXCEPT(
-      kernel_node_,
-      cudnnSetReduceTensorDescriptor(reduce_tensor_descriptor_, reduce_tensor_op_, CUDNN_DATA_FLOAT, nan_prop_,
-                                     reduce_indices_, CUDNN_32BIT_INDICES),
-      "cudnnSetReduceTensorDescriptor failed");
+    // For float64 (double) tensors the reduction must also compute in double.
+    cudnnDataType_t comp_type = (data_type_ == CUDNN_DATA_DOUBLE) ? CUDNN_DATA_DOUBLE : CUDNN_DATA_FLOAT;
+    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
+                                cudnnSetReduceTensorDescriptor(reduce_tensor_descriptor_, reduce_tensor_op_, comp_type,
+                                                               nan_prop_, reduce_indices_, CUDNN_32BIT_INDICES),
+                                "cudnnSetReduceTensorDescriptor failed");
     return;
   }
 
   void InferInAndOutDesc(const std::vector<size_t> &input_shape, const std::vector<size_t> &output_shape) {
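The comp_type ternary exists because cuDNN pairs double tensors with a double computation type: per the cuDNN documentation, when A and C are double tensors, alpha, beta, and the reduction's compute enum are all taken to be double. Continuing the standalone sketch above for the double case (an assumption-labelled fragment, not MindSpore code; d_double_in/d_double_out are hypothetical double buffers, and the workspace would be re-queried for the new descriptors):

// Double variant of the sketch: tensor descriptors, the reduction compute type,
// and the alpha/beta scaling factors all switch to double together.
cudnnSetTensor4dDescriptor(a_desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_DOUBLE, n, c, h, w);
cudnnSetTensor4dDescriptor(c_desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_DOUBLE, n, c, 1, 1);
cudnnSetReduceTensorDescriptor(reduce_desc, CUDNN_REDUCE_TENSOR_ADD, CUDNN_DATA_DOUBLE,
                               CUDNN_NOT_PROPAGATE_NAN, CUDNN_REDUCE_TENSOR_NO_INDICES,
                               CUDNN_32BIT_INDICES);
cudnnGetReductionWorkspaceSize(handle, reduce_desc, a_desc, c_desc, &workspace_bytes);
const double alpha = 1.0, beta = 0.0;  // double scaling factors for double tensors
cudnnReduceTensor(handle, reduce_desc, nullptr, 0, d_workspace, workspace_bytes,
                  &alpha, a_desc, d_double_in, &beta, c_desc, d_double_out);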