Browse Source

fix-bot-warning

tags/v1.1.0
zhouyuanshen 5 years ago
parent
commit
88513a223a
3 changed files with 17 additions and 11 deletions
  1. +3
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/pack_gpu_kernel.h
  2. +3
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unpack_gpu_kernel.h
  3. +11
    -7
      mindspore/ccsrc/backend/kernel_compiler/gpu/random/random_categorical_gpu_kernel.h

+ 3
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/pack_gpu_kernel.h View File

@@ -41,8 +41,9 @@ class PackGpuFwdKernel : public GpuKernel {
for (size_t i = 0; i < inputs.size(); i++) {
inputs_host_[i] = GetDeviceAddress<T>(inputs, i);
}
-CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(inputs_array, inputs_host_.get(), sizeof(T *) * input_num_,
-cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(inputs_array, // NOLINT
+inputs_host_.get(), sizeof(T *) * input_num_, cudaMemcpyHostToDevice,
+reinterpret_cast<cudaStream_t>(stream_ptr)),
"Pack opt cudaMemcpyAsync inputs failed");
PackKernel(SizeToInt(output_size_), input_num_, dims_behind_axis_, inputs_array, output,
reinterpret_cast<cudaStream_t>(stream_ptr));


+ 3
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unpack_gpu_kernel.h View File

@@ -41,8 +41,9 @@ class UnpackGpuFwdKernel : public GpuKernel {
for (size_t i = 0; i < outputs.size(); i++) {
outputs_host_[i] = GetDeviceAddress<T>(outputs, i);
}
-CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(outputs_array, outputs_host_.get(), sizeof(T *) * output_num_,
-cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(outputs_array, // NOLINT
+outputs_host_.get(), sizeof(T *) * output_num_, cudaMemcpyHostToDevice,
+reinterpret_cast<cudaStream_t>(stream_ptr)),
"Unpack opt cudaMemcpyAsync outputs failed");
UnpackKernel(SizeToInt(input_size_), output_num_, dims_after_axis_, outputs_array, input,
reinterpret_cast<cudaStream_t>(stream_ptr));


+ 11
- 7
mindspore/ccsrc/backend/kernel_compiler/gpu/random/random_categorical_gpu_kernel.h View File

@@ -47,8 +47,9 @@ class RandomCategoricalGpuKernel : public GpuKernel {
host_cdf[i] = GetDeviceAddress<double>(workspaces, i);
}
double **dev_cdf = GetDeviceAddress<double *>(workspaces, batch_size_);
-CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(dev_cdf, host_cdf.get(), sizeof(double *) * batch_size_,
-cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(dev_cdf, // NOLINT
+host_cdf.get(), sizeof(double *) * batch_size_, cudaMemcpyHostToDevice,
+reinterpret_cast<cudaStream_t>(stream_ptr)),
"Random_categorica cudaMemcpyAsync dev_cdf failed");

std::unique_ptr<double *[]> host_rand;
@@ -59,19 +60,22 @@ class RandomCategoricalGpuKernel : public GpuKernel {

double **dev_rand = GetDeviceAddress<double *>(workspaces, batch_size_ * 2 + 1);
for (int i = 0; i < batch_size_; i++) {
-double *host_1d_rand = new double[num_samples_];
+std::unique_ptr<double[]> host_1d_rand;
+host_1d_rand = std::make_unique<double[]>(num_samples_);

std::default_random_engine rng(seed_);
std::uniform_real_distribution<> dist(0, 1);
for (int j = 0; j < num_samples_; j++) {
host_1d_rand[j] = dist(rng);
}
-CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(host_rand[i], host_1d_rand, sizeof(double) * num_samples_,
+CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(host_rand[i], // NOLINT
+host_1d_rand.get(), sizeof(double) * num_samples_,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"Random_categorica cudaMemcpyAsync host_1d_rand failed");
-delete[] host_1d_rand;
}
-CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(dev_rand, host_rand.get(), sizeof(double *) * batch_size_,
-cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(dev_rand, // NOLINT
+host_rand.get(), sizeof(double *) * batch_size_, cudaMemcpyHostToDevice,
+reinterpret_cast<cudaStream_t>(stream_ptr)),
"Random_categorica cudaMemcpyAsync dev_rand failed");

GetCdfKernel(logits_addr, dev_cdf, batch_size_, num_classes_, reinterpret_cast<cudaStream_t>(stream_ptr));


Loading…
Cancel
Save