| @@ -51,10 +51,12 @@ class CtcLossGpuKernel : public GpuKernel { | |||
| float *grads = GetDeviceAddress<float>(outputs, 1); | |||
| // Copy labels/input_lengths/label_length to host as cudnn7.x.x requires | |||
| void *labels_host = nullptr; | |||
| int *labels_host = nullptr; | |||
| int *no_blank_labels_host = nullptr; | |||
| void *input_lengths_host = nullptr; | |||
| void *label_lengths_host = nullptr; | |||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&labels_host, inputs[1]->size), "cudaMallocHost failed."); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&no_blank_labels_host, inputs[1]->size), "cudaMallocHost failed."); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&input_lengths_host, inputs[2]->size), "cudaMallocHost failed."); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&label_lengths_host, inputs[3]->size), "cudaMallocHost failed."); | |||
| cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_ptr); | |||
| @@ -68,12 +70,21 @@ class CtcLossGpuKernel : public GpuKernel { | |||
| "cudaMemcpyAsync failed."); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed."); | |||
| size_t j = 0; | |||
| for (size_t i = 0; i < inputs[1]->size / sizeof(int); i++) { | |||
| if (labels_host[i] != 0) { | |||
| no_blank_labels_host[j] = labels_host[i]; | |||
| j++; | |||
| } | |||
| } | |||
| size_t workspace_size = 0; | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnGetCTCLossWorkspaceSize(cudnn_handle_, probs_desc_, probs_desc_, reinterpret_cast<int *>(labels_host), | |||
| reinterpret_cast<int *>(label_lengths_host), | |||
| reinterpret_cast<int *>(input_lengths_host), CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, | |||
| ctcloss_desc_, &workspace_size), | |||
| cudnnGetCTCLossWorkspaceSize( | |||
| cudnn_handle_, probs_desc_, probs_desc_, reinterpret_cast<int *>(no_blank_labels_host), | |||
| reinterpret_cast<int *>(label_lengths_host), reinterpret_cast<int *>(input_lengths_host), | |||
| CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, ctcloss_desc_, &workspace_size), | |||
| "cudnnGetCTCLossWorkspaceSize failed."); | |||
| void *workspace = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(workspace_size); | |||
| if (workspace == nullptr) { | |||
| @@ -81,7 +92,7 @@ class CtcLossGpuKernel : public GpuKernel { | |||
| } | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnCTCLoss(cudnn_handle_, probs_desc_, probs, reinterpret_cast<int *>(labels_host), | |||
| cudnnCTCLoss(cudnn_handle_, probs_desc_, probs, reinterpret_cast<int *>(no_blank_labels_host), | |||
| reinterpret_cast<int *>(label_lengths_host), reinterpret_cast<int *>(input_lengths_host), costs, | |||
| probs_desc_, grads, CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, ctcloss_desc_, workspace, workspace_size), | |||
| "cudnnCtcLoss failed."); | |||
| @@ -91,6 +102,7 @@ class CtcLossGpuKernel : public GpuKernel { | |||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(label_lengths_host), "cudaFreeHost failed."); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(input_lengths_host), "cudaFreeHost failed."); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(labels_host), "cudaFreeHost failed."); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(no_blank_labels_host), "cudaFreeHost failed."); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||