| @@ -16,6 +16,7 @@ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RL_BUFFER_SAMPLE_CPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RL_BUFFER_SAMPLE_CPU_KERNEL_H_ | |||
| #include <stdlib.h> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| @@ -27,13 +28,14 @@ namespace mindspore { | |||
| namespace kernel { | |||
| class BufferCPUSampleKernel : public CPUKernel { | |||
| public: | |||
| BufferCPUSampleKernel() : element_nums_(0), capacity_(0), batch_size_(0), exp_size_(0) {} | |||
| BufferCPUSampleKernel() : element_nums_(0), capacity_(0), batch_size_(0), exp_size_(0), seed_(0) {} | |||
| ~BufferCPUSampleKernel() override = default; | |||
| void Init(const CNodePtr &kernel_node) { | |||
| auto shapes = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "buffer_elements"); | |||
| auto types = AnfAlgo::GetNodeAttr<std::vector<TypePtr>>(kernel_node, "buffer_dtype"); | |||
| capacity_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "capacity"); | |||
| seed_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "seed"); | |||
| batch_size_ = LongToSize(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "batch_size")); | |||
| element_nums_ = shapes.size(); | |||
| for (size_t i = 0; i < element_nums_; i++) { | |||
| @@ -45,8 +47,6 @@ class BufferCPUSampleKernel : public CPUKernel { | |||
| output_size_list_.push_back(i * batch_size_); | |||
| exp_size_ += i; | |||
| } | |||
| // index | |||
| input_size_list_.push_back(sizeof(int) * batch_size_); | |||
| // count and head | |||
| input_size_list_.push_back(sizeof(int)); | |||
| input_size_list_.push_back(sizeof(int)); | |||
| @@ -54,18 +54,29 @@ class BufferCPUSampleKernel : public CPUKernel { | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &outputs) { | |||
| auto indexes_addr = GetDeviceAddress<int>(inputs, element_nums_); | |||
| auto count_addr = GetDeviceAddress<int>(inputs, element_nums_ + 1); | |||
| auto head_addr = GetDeviceAddress<int>(inputs, element_nums_ + 2); | |||
| auto count_addr = GetDeviceAddress<int>(inputs, element_nums_); | |||
| auto head_addr = GetDeviceAddress<int>(inputs, element_nums_ + 1); | |||
| if ((head_addr[0] > 0 && SizeToLong(batch_size_) > capacity_) || | |||
| (head_addr[0] == 0 && SizeToLong(batch_size_) > count_addr[0])) { | |||
| MS_LOG(ERROR) << "The batch size " << batch_size_ << " is larger than total buffer size " | |||
| << std::min(capacity_, IntToLong(count_addr[0])); | |||
| } | |||
| // Generate random indexes | |||
| std::vector<size_t> indexes; | |||
| for (size_t i = 0; i < IntToSize(count_addr[0]); ++i) { | |||
| indexes.push_back(i); | |||
| } | |||
| if (seed_ == 0) { | |||
| std::srand(time(nullptr)); | |||
| } else { | |||
| std::srand(seed_); | |||
| } | |||
| random_shuffle(indexes.begin(), indexes.end()); | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t j = start; j < end; j++) { | |||
| int64_t index = IntToSize(indexes_addr[j]); | |||
| size_t index = indexes[j]; | |||
| for (size_t i = 0; i < element_nums_; i++) { | |||
| auto buffer_addr = GetDeviceAddress<unsigned char>(inputs, i); | |||
| auto output_addr = GetDeviceAddress<unsigned char>(outputs, i); | |||
| @@ -92,6 +103,7 @@ class BufferCPUSampleKernel : public CPUKernel { | |||
| int64_t capacity_; | |||
| size_t batch_size_; | |||
| int64_t exp_size_; | |||
| int64_t seed_; | |||
| std::vector<size_t> exp_element_list; | |||
| }; | |||
| } // namespace kernel | |||
| @@ -91,6 +91,14 @@ __global__ void BufferSampleKernel(const size_t size, const size_t one_element, | |||
| } | |||
| } | |||
| __global__ void SrandUniformFloat(const int size, curandState *globalState, const int seedc, float *out) { | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { | |||
| curand_init(seedc, threadIdx.x, 0, &globalState[i]); | |||
| out[i] = curand_uniform(&globalState[i]); | |||
| } | |||
| __syncthreads(); | |||
| } | |||
| void BufferAppend(const int64_t capacity, const size_t size, const int *index, const int exp_batch, | |||
| unsigned char *buffer, const unsigned char *exp, cudaStream_t cuda_stream) { | |||
| BufferAppendKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(capacity, size, index, exp_batch, buffer, exp); | |||
| @@ -119,3 +127,7 @@ void BufferSample(const size_t size, const size_t one_element, const int *index, | |||
| unsigned char *out, cudaStream_t cuda_stream) { | |||
| BufferSampleKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, one_element, index, buffer, out); | |||
| } | |||
| void RandomGen(const int size, curandState *globalState, const int &seedc, float *out, cudaStream_t stream) { | |||
| SrandUniformFloat<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(size, globalState, seedc, out); | |||
| } | |||
| @@ -16,7 +16,7 @@ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RL_BUFFER_IMPL_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RL_BUFFER_IMPL_H_ | |||
| #include <curand_kernel.h> | |||
| #include "runtime/device/gpu/cuda_common.h" | |||
| void BufferAppend(const int64_t capacity, const size_t size, const int *index, const int exp_batch, | |||
| unsigned char *buffer, const unsigned char *exp, cudaStream_t cuda_stream); | |||
| @@ -29,5 +29,5 @@ void CheckBatchSize(const int *count, const int *head, const size_t batch_size, | |||
| cudaStream_t cuda_stream); | |||
| void BufferSample(const size_t size, const size_t one_element, const int *index, const unsigned char *buffer, | |||
| unsigned char *out, cudaStream_t cuda_stream); | |||
| void RandomGen(const int size, curandState *globalState, const int &seedc, float *out, cudaStream_t stream); | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAM_IMPL_H_ | |||
| @@ -19,15 +19,18 @@ | |||
| #include <string> | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include <limits> | |||
| #include <chrono> | |||
| #include "backend/kernel_compiler/common_utils.h" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/rl/rl_buffer_impl.cuh" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh" | |||
| #include "runtime/device/gpu/gpu_common.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| BufferSampleKernel::BufferSampleKernel() : element_nums_(0), capacity_(0), batch_size_(0) {} | |||
| BufferSampleKernel::BufferSampleKernel() : element_nums_(0), capacity_(0), batch_size_(0), seed_(0) {} | |||
| BufferSampleKernel::~BufferSampleKernel() {} | |||
| @@ -44,6 +47,7 @@ bool BufferSampleKernel::Init(const CNodePtr &kernel_node) { | |||
| auto shapes = GetAttr<std::vector<int64_t>>(kernel_node, "buffer_elements"); | |||
| auto types = GetAttr<std::vector<TypePtr>>(kernel_node, "buffer_dtype"); | |||
| capacity_ = GetAttr<int64_t>(kernel_node, "capacity"); | |||
| seed_ = GetAttr<int64_t>(kernel_node, "seed"); | |||
| batch_size_ = LongToSize(GetAttr<int64_t>(kernel_node, "batch_size")); | |||
| element_nums_ = shapes.size(); | |||
| for (size_t i = 0; i < element_nums_; i++) { | |||
| @@ -52,28 +56,52 @@ bool BufferSampleKernel::Init(const CNodePtr &kernel_node) { | |||
| input_size_list_.push_back(capacity_ * element); | |||
| output_size_list_.push_back(batch_size_ * element); | |||
| } | |||
| // index | |||
| input_size_list_.push_back(sizeof(int) * batch_size_); | |||
| // count and head | |||
| input_size_list_.push_back(sizeof(int)); | |||
| input_size_list_.push_back(sizeof(int)); | |||
| workspace_size_list_.push_back(capacity_ * sizeof(curandState)); | |||
| workspace_size_list_.push_back(capacity_ * sizeof(float)); | |||
| workspace_size_list_.push_back(capacity_ * sizeof(int)); | |||
| workspace_size_list_.push_back(capacity_ * sizeof(float)); | |||
| return true; | |||
| } | |||
| void BufferSampleKernel::InitSizeLists() { return; } | |||
| bool BufferSampleKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| bool BufferSampleKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces, | |||
| const std::vector<AddressPtr> &outputs, void *stream) { | |||
| int *index_addr = GetDeviceAddress<int>(inputs, element_nums_); | |||
| int *count_addr = GetDeviceAddress<int>(inputs, element_nums_ + 1); | |||
| int *head_addr = GetDeviceAddress<int>(inputs, element_nums_ + 2); | |||
| int *count_addr = GetDeviceAddress<int>(inputs, element_nums_); | |||
| int *head_addr = GetDeviceAddress<int>(inputs, element_nums_ + 1); | |||
| auto cuda_stream = reinterpret_cast<cudaStream_t>(stream); | |||
| CheckBatchSize(count_addr, head_addr, batch_size_, capacity_, cuda_stream); | |||
| int k_cut = 0; | |||
| CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, | |||
| cudaMemcpyAsync(&k_cut, count_addr, sizeof(int), cudaMemcpyDeviceToHost, cuda_stream), | |||
| "sync dev to host failed"); | |||
| // 1 Generate random floats | |||
| auto States = GetDeviceAddress<void *>(workspaces, 0); | |||
| auto random_f = GetDeviceAddress<float>(workspaces, 1); | |||
| auto indexes = GetDeviceAddress<int>(workspaces, 2); | |||
| auto useless_out = GetDeviceAddress<float>(workspaces, 3); | |||
| int seedc = 0; | |||
| if (seed_ == 0) { | |||
| generator_.seed(std::chrono::system_clock::now().time_since_epoch().count()); | |||
| seedc = generator_(); | |||
| } else { | |||
| seedc = seed_; | |||
| } | |||
| float init_k = std::numeric_limits<float>::lowest(); | |||
| curandState *devStates = reinterpret_cast<curandState *>(States); | |||
| RandomGen(k_cut, devStates, seedc, random_f, cuda_stream); | |||
| // 2 Sort the random floats, and get the sorted indexes as the random indexes | |||
| FastTopK(1, k_cut, random_f, k_cut, useless_out, indexes, init_k, cuda_stream); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaDeviceSynchronize(), "cudaDeviceSync failed, sample-topk"); | |||
| for (size_t i = 0; i < element_nums_; i++) { | |||
| auto buffer_addr = GetDeviceAddress<unsigned char>(inputs, i); | |||
| auto out_addr = GetDeviceAddress<unsigned char>(outputs, i); | |||
| size_t size = batch_size_ * exp_element_list[i]; | |||
| BufferSample(size, exp_element_list[i], index_addr, buffer_addr, out_addr, cuda_stream); | |||
| BufferSample(size, exp_element_list[i], indexes, buffer_addr, out_addr, cuda_stream); | |||
| } | |||
| return true; | |||
| } | |||
| @@ -20,6 +20,7 @@ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <random> | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| @@ -45,6 +46,8 @@ class BufferSampleKernel : public GpuKernel { | |||
| size_t element_nums_; | |||
| int64_t capacity_; | |||
| size_t batch_size_; | |||
| int64_t seed_; | |||
| std::mt19937 generator_; | |||
| std::vector<size_t> exp_element_list; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| @@ -36,12 +36,12 @@ class BufferSample(PrimitiveWithInfer): | |||
| batch_size (int64): The size of the sampled data, lessequal to `capacity`. | |||
| buffer_shape (tuple(shape)): The shape of an buffer. | |||
| buffer_dtype (tuple(type)): The type of an buffer. | |||
| seed (int64): Random seed for sample. Default: 0. If use the default seed, it will generate a ramdom | |||
| one in kernel. Set a number other than `0` to keep a specific seed. | |||
| Inputs: | |||
| - **data** (tuple(Parameter(Tensor))) - The tuple(Tensor) represents replaybuffer, | |||
| each tensor is described by the `buffer_shape` and `buffer_type`. | |||
| - **indexes** (tuple(int32)) - The position list in replaybuffer, | |||
| the size equal to `batch_size`. | |||
| - **count** (Parameter) - The count mean the real available size of the buffer, | |||
| data type: int32. | |||
| - **head** (Parameter) - The position of the first data in buffer, data type: int32. | |||
| @@ -69,8 +69,7 @@ class BufferSample(PrimitiveWithInfer): | |||
| Parameter(Tensor(np.ones((100, 1)).astype(np.int32)), name="reward"), | |||
| Parameter(Tensor(np.arange(100 * 4).reshape(100, 4).astype(np.float32)), name="state_")] | |||
| >>> buffer_sample = ops.BufferSample(capacity, batch_size, shapes, types) | |||
| >>> indexes = Parameter(Tensor([0, 2, 4, 3, 8], ms.int32), name="index") | |||
| >>> output = buffer_sample(buffer, indexes, count, head) | |||
| >>> output = buffer_sample(buffer, count, head) | |||
| >>> print(output) | |||
| (Tensor(shape=[5, 4], dtype=Float32, value= | |||
| [[ 0.00000000e+00, 1.00000000e+00, 2.00000000e+00, 3.00000000e+00], | |||
| @@ -99,7 +98,7 @@ class BufferSample(PrimitiveWithInfer): | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, capacity, batch_size, buffer_shape, buffer_dtype): | |||
| def __init__(self, capacity, batch_size, buffer_shape, buffer_dtype, seed=0): | |||
| """Initialize BufferSample.""" | |||
| self.init_prim_io_names(inputs=["buffer"], outputs=["sample"]) | |||
| validator.check_value_type("shape of init data", buffer_shape, [tuple, list], self.name) | |||
| @@ -110,6 +109,7 @@ class BufferSample(PrimitiveWithInfer): | |||
| self._n = len(buffer_shape) | |||
| validator.check_int(self._batch_size, capacity, Rel.LE, "batchsize", self.name) | |||
| self.add_prim_attr('capacity', capacity) | |||
| self.add_prim_attr('seed', seed) | |||
| buffer_elements = [] | |||
| for shape in buffer_shape: | |||
| buffer_elements.append(reduce(lambda x, y: x * y, shape)) | |||
| @@ -119,17 +119,16 @@ class BufferSample(PrimitiveWithInfer): | |||
| if context.get_context('device_target') == "Ascend": | |||
| self.add_prim_attr('device_target', "CPU") | |||
| def infer_shape(self, data_shape, index_shape, count_shape, head_shape): | |||
| def infer_shape(self, data_shape, count_shape, head_shape): | |||
| validator.check_value_type("shape of data", data_shape, [tuple, list], self.name) | |||
| out_shapes = [] | |||
| for i in range(self._n): | |||
| out_shapes.append((self._batch_size,) + self._buffer_shape[i]) | |||
| return tuple(out_shapes) | |||
| def infer_dtype(self, data_type, index_type, count_type, head_type): | |||
| def infer_dtype(self, data_type, count_type, head_type): | |||
| validator.check_type_name("count type", count_type, (mstype.int32), self.name) | |||
| validator.check_type_name("head type", head_type, (mstype.int32), self.name) | |||
| validator.check_type_name("index type", index_type, (mstype.int64, mstype.int32), self.name) | |||
| return tuple(self._buffer_dtype) | |||
| class BufferAppend(PrimitiveWithInfer): | |||
| @@ -45,9 +45,6 @@ class RLBuffer(nn.Cell): | |||
| self.buffer_get = P.BufferGetItem(self._capacity, shapes, types) | |||
| self.buffer_sample = P.BufferSample( | |||
| self._capacity, batch_size, shapes, types) | |||
| self.dummy_tensor = Tensor(np.ones(shape=[batch_size]), ms.bool_) | |||
| self.rnd_choice_mask = P.RandomChoiceWithMask(count=batch_size) | |||
| self.reshape = P.Reshape() | |||
| @ms_function | |||
| def append(self, exps): | |||
| @@ -59,9 +56,7 @@ class RLBuffer(nn.Cell): | |||
| @ms_function | |||
| def sample(self): | |||
| index, _ = self.rnd_choice_mask(self.dummy_tensor) | |||
| index = self.reshape(index, (self._batch_size,)) | |||
| return self.buffer_sample(self.buffer, index, self.count, self.head) | |||
| return self.buffer_sample(self.buffer, self.count, self.head) | |||
| s = Tensor(np.array([2, 2, 2, 2]), ms.float32) | |||
| @@ -55,16 +55,13 @@ class RLBufferSample(nn.Cell): | |||
| def __init__(self, capcity, batch_size, shapes, types): | |||
| super(RLBufferSample, self).__init__() | |||
| self._capacity = capcity | |||
| count = 5 | |||
| self.count = Parameter(Tensor(5, ms.int32), name="count") | |||
| self.head = Parameter(Tensor(0, ms.int32), name="head") | |||
| self.input_x = Tensor(np.ones(shape=[count]), ms.bool_) | |||
| self.buffer_sample = P.BufferSample(self._capacity, batch_size, shapes, types) | |||
| self.index = Parameter(Tensor([0, 2, 4], ms.int32), name="index") | |||
| @ms_function | |||
| def construct(self, buffer): | |||
| return self.buffer_sample(buffer, self.index, self.count, self.head) | |||
| return self.buffer_sample(buffer, self.count, self.head) | |||
| states = Tensor(np.arange(4*5).reshape(5, 4).astype(np.float32)/10.0) | |||
| @@ -93,14 +90,7 @@ def test_BufferSample(): | |||
| buffer_sample = RLBufferSample(capcity=5, batch_size=3, shapes=[(4,), (2,), (1,), (4,)], types=[ | |||
| ms.float32, ms.int32, ms.int32, ms.float32]) | |||
| ss, aa, rr, ss_ = buffer_sample(b) | |||
| expect_s = [[0, 0.1, 0.2, 0.3], [0.8, 0.9, 1.0, 1.1], [1.6, 1.7, 1.8, 1.9]] | |||
| expect_a = [[0, 1], [4, 5], [8, 9]] | |||
| expect_r = [[1], [1], [1]] | |||
| expect_s_ = [[0, 1, 2, 3], [8, 9, 10, 11], [16, 17, 18, 19]] | |||
| np.testing.assert_almost_equal(ss.asnumpy(), expect_s) | |||
| np.testing.assert_almost_equal(aa.asnumpy(), expect_a) | |||
| np.testing.assert_almost_equal(rr.asnumpy(), expect_r) | |||
| np.testing.assert_almost_equal(ss_.asnumpy(), expect_s_) | |||
| print(ss, aa, rr, ss_) | |||
| @ pytest.mark.level0 | |||
| @@ -56,9 +56,7 @@ class RLBuffer(nn.Cell): | |||
| @ms_function | |||
| def sample(self): | |||
| count = self.reshape(self.count, (1,)) | |||
| index = self.randperm(count) | |||
| return self.buffer_sample(self.buffer, index, self.count, self.head) | |||
| return self.buffer_sample(self.buffer, self.count, self.head) | |||
| s = Tensor(np.array([2, 2, 2, 2]), ms.float32) | |||
| @@ -55,17 +55,14 @@ class RLBufferSample(nn.Cell): | |||
| def __init__(self, capcity, batch_size, shapes, types): | |||
| super(RLBufferSample, self).__init__() | |||
| self._capacity = capcity | |||
| count = 5 | |||
| self.count = Parameter(Tensor(5, ms.int32), name="count") | |||
| self.head = Parameter(Tensor(0, ms.int32), name="head") | |||
| self.input_x = Tensor(np.ones(shape=[count]), ms.bool_) | |||
| self.buffer_sample = P.BufferSample( | |||
| self._capacity, batch_size, shapes, types) | |||
| self.index = Parameter(Tensor([0, 2, 4], ms.int32), name="index") | |||
| @ms_function | |||
| def construct(self, buffer): | |||
| return self.buffer_sample(buffer, self.index, self.count, self.head) | |||
| return self.buffer_sample(buffer, self.count, self.head) | |||
| states = Tensor(np.arange(4*5).reshape(5, 4).astype(np.float32)/10.0) | |||
| @@ -94,14 +91,7 @@ def test_BufferSample(): | |||
| buffer_sample = RLBufferSample(capcity=5, batch_size=3, shapes=[(4,), (2,), (1,), (4,)], types=[ | |||
| ms.float32, ms.int32, ms.int32, ms.float32]) | |||
| ss, aa, rr, ss_ = buffer_sample(b) | |||
| expect_s = [[0, 0.1, 0.2, 0.3], [0.8, 0.9, 1.0, 1.1], [1.6, 1.7, 1.8, 1.9]] | |||
| expect_a = [[0, 1], [4, 5], [8, 9]] | |||
| expect_r = [[1], [1], [1]] | |||
| expect_s_ = [[0, 1, 2, 3], [8, 9, 10, 11], [16, 17, 18, 19]] | |||
| np.testing.assert_almost_equal(ss.asnumpy(), expect_s) | |||
| np.testing.assert_almost_equal(aa.asnumpy(), expect_a) | |||
| np.testing.assert_almost_equal(rr.asnumpy(), expect_r) | |||
| np.testing.assert_almost_equal(ss_.asnumpy(), expect_s_) | |||
| print(ss, aa, rr, ss_) | |||
| @ pytest.mark.level0 | |||