Browse Source

!22172 GPU buffer sample update

Merge pull request !22172 from VectorSL/buffer-sample-update3
tags/v1.5.0-rc1
i-robot Gitee 4 years ago
parent
commit
4221ff369e
4 changed files with 66 additions and 36 deletions
  1. +27
    -8
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/rl/rl_buffer_impl.cu
  2. +3
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/rl/rl_buffer_impl.cuh
  3. +33
    -26
      mindspore/ccsrc/backend/kernel_compiler/gpu/rl/buffer_sample_gpu_kernel.cc
  4. +3
    -0
      mindspore/ccsrc/backend/kernel_compiler/gpu/rl/buffer_sample_gpu_kernel.h

+ 27
- 8
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/rl/rl_buffer_impl.cu View File

@@ -16,6 +16,10 @@

#include "backend/kernel_compiler/gpu/cuda_impl/rl/rl_buffer_impl.cuh"

#include <thrust/device_ptr.h>
#include <thrust/sort.h>
#include <thrust/execution_policy.h>

__global__ void BufferAppendKernel(const int64_t capacity, const size_t size, const int *index, const int exp_batch,
unsigned char *buffer, const unsigned char *exp) {
size_t index_ = index[0];
@@ -84,19 +88,24 @@ __global__ void CheckBatchSizeKernel(const int *count, const int *head, const si
}
}

__global__ void BufferSampleKernel(const size_t size, const size_t one_element, const int *index,
// Gathers sampled experiences byte-by-byte: output byte `pos` belongs to
// sampled element `pos / one_element`; `index` maps that sample slot to the
// element's position in `buffer`. Grid-stride loop, so any launch config works.
__global__ void BufferSampleKernel(const size_t size, const size_t one_element, const unsigned int *index,
                                   const unsigned char *buffer, unsigned char *out) {
  const size_t stride = static_cast<size_t>(gridDim.x) * blockDim.x;
  for (size_t pos = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; pos < size; pos += stride) {
    const size_t sample_slot = pos / one_element;
    const size_t byte_offset = pos % one_element;
    out[pos] = buffer[index[sample_slot] * one_element + byte_offset];
  }
}

__global__ void SrandUniformFloat(const int size, curandState *globalState, const int seedc, float *out) {
// Initializes one curandState per buffer slot. Every state shares the same
// seed but uses its own subsequence (= i), so draws are reproducible for a
// fixed seed yet statistically independent across slots.
// NOTE(review): the diff view interleaved two removed lines from the old
// SrandUniformFloat kernel here (referencing undeclared `seedc`/`globalState`/
// `out`); they are dropped so the new kernel compiles.
__global__ void SetupKernel(const int seed, curandState *state, const int size) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
    curand_init(seed, i, 0, &state[i]);
  }
}

// Fills two parallel arrays used to build a random permutation:
//   out[i]   — one random uint drawn from this slot's curand state (sort key),
//   value[i] — the identity sequence [0, size) (payload to be permuted).
// Assumes `globalState` was populated by SetupKernel/RandInit beforehand.
// The trailing __syncthreads() the original carried was dead synchronization:
// the kernel uses no shared memory and nothing follows the barrier, and
// kernel completion already makes the global writes visible — removed.
__global__ void SrandUInt(const int size, curandState *globalState, unsigned int *value, unsigned int *out) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
    out[i] = curand(&globalState[i]);
    value[i] = i;
  }
}

void BufferAppend(const int64_t capacity, const size_t size, const int *index, const int exp_batch,
@@ -123,11 +132,21 @@ void CheckBatchSize(const int *count, const int *head, const size_t batch_size,
CheckBatchSizeKernel<<<1, 1, 0, cuda_stream>>>(count, head, batch_size, capacity);
}

void BufferSample(const size_t size, const size_t one_element, const int *index, const unsigned char *buffer,
// Host wrapper: launches BufferSampleKernel on `cuda_stream` to copy `size`
// bytes of sampled experience from `buffer` to `out`, using `index` to map
// each sample slot (of `one_element` bytes) to its source element.
// Asynchronous with respect to the host; no error check here — callers are
// expected to check the stream.
void BufferSample(const size_t size, const size_t one_element, const unsigned int *index, const unsigned char *buffer,
unsigned char *out, cudaStream_t cuda_stream) {
BufferSampleKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, one_element, index, buffer, out);
}

void RandomGen(const int size, curandState *globalState, const int &seedc, float *out, cudaStream_t stream) {
SrandUniformFloat<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(size, globalState, seedc, out);
// Host wrapper: seeds `size` curandState objects on `stream` via SetupKernel.
// Uses a fixed 256-thread block with ceil-div grid sizing so every state
// index in [0, size) is covered. Intended to run once (see states_init_ guard
// in BufferSampleKernel::Launch).
void RandInit(const int size, const int seed, curandState *state, cudaStream_t stream) {
SetupKernel<<<(size + 255) / 256, 256, 0, stream>>>(seed, state, size);
}

// Produces a uniformly random, duplicate-free permutation of [0, size) in
// `value`; `key` is scratch that receives the random sort keys.
// Assumes `globalState` was initialized by RandInit — TODO confirm at callers.
void RandomGen(const int size, curandState *globalState, unsigned int *value, unsigned int *key, cudaStream_t stream) {
// 1 Generate two lists: `key` gets one random uint per slot (SrandUInt's `out`
//   argument) and `value` gets the identity sequence [0, size).
SrandUInt<<<(size + 255) / 256, 256, 0, stream>>>(size, globalState, value, key);
auto policy = thrust::cuda::par.on(stream);
thrust::device_ptr<unsigned int> dev_data_ptr(value);
thrust::device_ptr<unsigned int> dev_key_ptr(key);
// 2 Sort by the random keys on the same stream; the permuted payload left in
//   `value` is the random index order.
thrust::sort_by_key(policy, dev_key_ptr, dev_key_ptr + size, dev_data_ptr);
}

+ 3
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/rl/rl_buffer_impl.cuh View File

@@ -27,7 +27,8 @@ void BufferGetItem(const size_t size, const int *index, const size_t one_exp_len
unsigned char *out, cudaStream_t cuda_stream);
void CheckBatchSize(const int *count, const int *head, const size_t batch_size, const int64_t capacity,
cudaStream_t cuda_stream);
void BufferSample(const size_t size, const size_t one_element, const int *index, const unsigned char *buffer,
void BufferSample(const size_t size, const size_t one_element, const unsigned int *index, const unsigned char *buffer,
unsigned char *out, cudaStream_t cuda_stream);
void RandomGen(const int size, curandState *globalState, const int &seedc, float *out, cudaStream_t stream);
void RandomGen(const int size, curandState *globalState, unsigned int *value, unsigned int *key, cudaStream_t stream);
void RandInit(const int size, const int seed, curandState *state, cudaStream_t stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAM_IMPL_H_

+ 33
- 26
mindspore/ccsrc/backend/kernel_compiler/gpu/rl/buffer_sample_gpu_kernel.cc View File

@@ -30,9 +30,14 @@
namespace mindspore {
namespace kernel {

BufferSampleKernel::BufferSampleKernel() : element_nums_(0), capacity_(0), batch_size_(0), seed_(0) {}
// Zero-initialize all scalar state. devStates_ MUST start as nullptr: the
// destructor compares it against nullptr before freeing, and without this
// initializer that read is of an indeterminate pointer (UB, and a potential
// bogus FreeTensorMem) whenever Init() never ran.
BufferSampleKernel::BufferSampleKernel()
    : element_nums_(0), capacity_(0), batch_size_(0), seed_(0), states_init_(false), devStates_(nullptr) {}

BufferSampleKernel::~BufferSampleKernel() {}
// Releases the persistent curandState buffer allocated in Init(). The null
// check guards the case where Init() never allocated it.
// NOTE(review): devStates_ is only safe to test here if the constructor
// initializes it to nullptr — confirm the initializer list does so.
BufferSampleKernel::~BufferSampleKernel() {
if (devStates_ != nullptr) {
device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(static_cast<void *>(devStates_));
}
}

void BufferSampleKernel::ReleaseResource() {}

@@ -50,6 +55,19 @@ bool BufferSampleKernel::Init(const CNodePtr &kernel_node) {
seed_ = GetAttr<int64_t>(kernel_node, "seed");
batch_size_ = LongToSize(GetAttr<int64_t>(kernel_node, "batch_size"));
element_nums_ = shapes.size();
// Set default seed, if seed == 0
if (seed_ == 0) {
generator_.seed(std::chrono::system_clock::now().time_since_epoch().count());
seed_ = generator_();
}
// Keep the device memory for curandstate
const size_t cap_state_size = sizeof(curandState) * capacity_;
void *dev_state = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(cap_state_size);
if (dev_state == nullptr) {
MS_LOG(EXCEPTION) << "Failed to alloc dev_state, size is " << cap_state_size;
}
devStates_ = reinterpret_cast<curandState *>(dev_state);

for (size_t i = 0; i < element_nums_; i++) {
auto element = shapes[i] * UnitSizeInBytes(types[i]->type_id());
exp_element_list.push_back(element);
@@ -59,10 +77,8 @@ bool BufferSampleKernel::Init(const CNodePtr &kernel_node) {
// count and head
input_size_list_.push_back(sizeof(int));
input_size_list_.push_back(sizeof(int));
workspace_size_list_.push_back(capacity_ * sizeof(curandState));
workspace_size_list_.push_back(capacity_ * sizeof(float));
workspace_size_list_.push_back(capacity_ * sizeof(int));
workspace_size_list_.push_back(capacity_ * sizeof(float));
workspace_size_list_.push_back(capacity_ * sizeof(unsigned int));
workspace_size_list_.push_back(capacity_ * sizeof(unsigned int));
return true;
}

@@ -74,29 +90,20 @@ bool BufferSampleKernel::Launch(const std::vector<AddressPtr> &inputs, const std
int *head_addr = GetDeviceAddress<int>(inputs, element_nums_ + 1);
auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);
CheckBatchSize(count_addr, head_addr, batch_size_, capacity_, cuda_stream);
int k_cut = 0;
int k_num = 0;
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(&k_cut, count_addr, sizeof(int), cudaMemcpyDeviceToHost, cuda_stream),
cudaMemcpyAsync(&k_num, count_addr, sizeof(int), cudaMemcpyDeviceToHost, cuda_stream),
"sync dev to host failed");
// 1 Generate random floats
auto States = GetDeviceAddress<void *>(workspaces, 0);
auto random_f = GetDeviceAddress<float>(workspaces, 1);
auto indexes = GetDeviceAddress<int>(workspaces, 2);
auto useless_out = GetDeviceAddress<float>(workspaces, 3);
int seedc = 0;
if (seed_ == 0) {
generator_.seed(std::chrono::system_clock::now().time_since_epoch().count());
seedc = generator_();
} else {
seedc = seed_;
// 1 Init curandState for the first time
if (!states_init_) {
RandInit(capacity_, seed_, devStates_, cuda_stream);
states_init_ = true;
}

float init_k = std::numeric_limits<float>::lowest();
curandState *devStates = reinterpret_cast<curandState *>(States);
RandomGen(k_cut, devStates, seedc, random_f, cuda_stream);
// 2 Sort the random floats, and get the sorted indexes as the random indexes
FastTopK(1, k_cut, random_f, k_cut, useless_out, indexes, init_k, cuda_stream);
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaDeviceSynchronize(), "cudaDeviceSync failed, sample-topk");
auto key = GetDeviceAddress<unsigned int>(workspaces, 0);
auto indexes = GetDeviceAddress<unsigned int>(workspaces, 1);
// 2 Generate random indexes by kernel
RandomGen(k_num, devStates_, indexes, key, cuda_stream);
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaDeviceSynchronize(), "cudaDeviceSync failed, random generate.");
for (size_t i = 0; i < element_nums_; i++) {
auto buffer_addr = GetDeviceAddress<unsigned char>(inputs, i);
auto out_addr = GetDeviceAddress<unsigned char>(outputs, i);


+ 3
- 0
mindspore/ccsrc/backend/kernel_compiler/gpu/rl/buffer_sample_gpu_kernel.h View File

@@ -17,6 +17,7 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RL_BUFFER_SAMPLE_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RL_BUFFER_SAMPLE_GPU_KERNEL_H_

#include <curand_kernel.h>
#include <memory>
#include <string>
#include <vector>
@@ -47,7 +48,9 @@ class BufferSampleKernel : public GpuKernel {
int64_t capacity_;
size_t batch_size_;
int64_t seed_;
bool states_init_;
std::mt19937 generator_;
curandState *devStates_;
std::vector<size_t> exp_element_list;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;


Loading…
Cancel
Save