Browse Source

!1438 GPU GetNext kernel fix

Merge pull request !1438 from chenweifeng/getnext
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
132ea0ad08
1 changed files with 5 additions and 2 deletions
  1. +5
    -2
      mindspore/ccsrc/kernel/gpu/data/dataset_iterator_kernel.cc

+ 5
- 2
mindspore/ccsrc/kernel/gpu/data/dataset_iterator_kernel.cc View File

@@ -64,7 +64,7 @@ bool DatasetIteratorKernel::Init(const CNodePtr &kernel_node) {
void DatasetIteratorKernel::InitSizeLists() { return; }

bool DatasetIteratorKernel::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *) {
const std::vector<AddressPtr> &outputs, void *stream) {
void *addr = nullptr;
size_t len = 0;

@@ -96,11 +96,14 @@ bool DatasetIteratorKernel::Launch(const std::vector<AddressPtr> &, const std::v
}

for (size_t i = 0; i < output_size_list_.size(); i++) {
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpy(outputs[i]->addr, addr, output_size_list_[i], cudaMemcpyDeviceToDevice),
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(outputs[i]->addr, addr, output_size_list_[i], cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream)),
"Cuda Memcpy Failed");
addr = reinterpret_cast<unsigned char *>(addr) + output_size_list_[i];
}

CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream)),
"cudaStreamSynchronize failed");
(void)GpuBufferMgr::GetInstance().Pop(handle_);
return true;
}


Loading…
Cancel
Save