Browse Source

multinomial profiling

tags/v1.6.0
wilfChen 4 years ago
parent
commit
cf63527a15
4 changed files with 100 additions and 71 deletions
  1. +62
    -48
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/multinomial_impl.cu
  2. +3
    -4
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/multinomial_impl.cuh
  3. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/math/multinomial_gpu_kernel.cc
  4. +34
    -18
      mindspore/ccsrc/backend/kernel_compiler/gpu/math/multinomial_gpu_kernel.h

+ 62
- 48
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/multinomial_impl.cu View File

@@ -14,8 +14,28 @@
* limitations under the License.
*/
#include <random>
#include "multinomial_impl.cuh"
#include <algorithm>
// Integer floor division: number of whole `unit`s that fit into `num`.
// Used on the host to derive a block dimension from a shared-memory budget.
template <typename T, typename S>
inline T Floor(const T &num, const S &unit) {
  const auto whole_units = num / unit;
  return static_cast<T>(whole_units);
}
// Integer ceiling division: smallest count of `unit`s covering `num`.
// Used on the host to compute a grid dimension from a row count.
template <typename T, typename S>
inline T Ceil(const T &num, const S &unit) {
  const auto rounded_up = num + unit - 1;
  return static_cast<T>(rounded_up / unit);
}
// Initializes one curandState per slot in [0, num). `seed` is the base seed and
// the flat global thread index is used as the subsequence, so each slot gets an
// independent RNG stream. Grid-stride loop: correct for any launch configuration.
// Fix: loop counter was size_t while the bound `num` is int (signed/unsigned
// comparison); use a matching int counter.
__global__ void InitRandStateKernel(int seed, int num, curandState *state) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += blockDim.x * gridDim.x) {
    curand_init(seed, i, 0, &state[i]);
  }
}
// Host launcher: one thread per RNG state, 128 threads per block,
// grid sized by ceiling division so all `num` slots are covered.
void InitRandState(int seed, int num, curandState *state, cudaStream_t stream) {
  constexpr int threads_per_block = 128;
  const int blocks = (num + threads_per_block - 1) / threads_per_block;
  InitRandStateKernel<<<blocks, threads_per_block, 0, stream>>>(seed, num, state);
}
template <typename T>
__global__ void CheckZeroKernel(const size_t distributions, const size_t categories, const T *input, T *out) {
@@ -50,24 +70,6 @@ void CheckNonNeg(const size_t size, const T *input, T *output, cudaStream_t cuda
CheckNonNegKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, output);
}
// Normalizes each row of a (distributions x categories) cumulative-sum buffer
// in place: every element except the row's final one is divided by that final
// value (the row total). The final element itself is left untouched, so no
// thread writes a location another thread reads.
template <typename T>
__global__ void NormInputKernel(T *input, const size_t distributions, const size_t categories) {
  const size_t total = distributions * categories;
  const size_t stride = blockDim.x * gridDim.x;
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < total; pos += stride) {
    const bool is_row_total = ((pos + 1) % categories == 0);
    if (!is_row_total) {
      const size_t row_total_idx = (pos / categories) * categories + categories - 1;
      input[pos] /= input[row_total_idx];
    }
  }
}
// Host launcher for NormInputKernel over the whole distributions x categories
// buffer, using the project-wide GET_BLOCKS/GET_THREADS launch helpers.
template <typename T>
void NormInput(T *input, const size_t distributions, const size_t categories, cudaStream_t cuda_stream) {
  const int element_count = distributions * categories;
  NormInputKernel<<<GET_BLOCKS(element_count), GET_THREADS, 0, cuda_stream>>>(input, distributions, categories);
}
template <typename T>
__device__ int BinarySearchForMultinomial(T *start_addr, int size, T rand) {
int start = 0;
@@ -88,41 +90,53 @@ __device__ int BinarySearchForMultinomial(T *start_addr, int size, T rand) {
}
template <typename T>
__global__ void MultinomialKernel(int seed, T *input, int num_sample, curandState *globalState, int *output,
size_t distributions, size_t categories) {
int count = num_sample * distributions;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) {
int j = i / num_sample % distributions;
curand_init(seed, i, 0, &globalState[i]);
auto rand = curand_uniform(&globalState[i]);
int pick = BinarySearchForMultinomial(input + j * categories, categories, rand);
output[i] = pick;
__global__ void MultinomialKernel(int row, int col, T *probs, curandState *state, int64_t *num_sample, int *output) {
// Load the probs to shared memory.
extern __shared__ float accum_probs[];
int probs_base_index = (blockIdx.x * blockDim.x + threadIdx.x) * col;
if (probs_base_index > row * col) {
return;
}
return;
int shm_base_index = threadIdx.x * col;
accum_probs[shm_base_index] = probs[probs_base_index];
for (int i = 1; i < col; i++) {
probs_base_index++;
accum_probs[shm_base_index + i] = accum_probs[shm_base_index + i - 1] + probs[probs_base_index];
}
__syncthreads();
// Probs normalization.
float max_probs = accum_probs[shm_base_index + col - 1];
for (int i = 0; i < col; i++) {
accum_probs[shm_base_index + i] /= max_probs;
}
__syncthreads();
// Sample.
int output_base_index = (blockIdx.x * blockDim.x + threadIdx.x) * num_sample[0];
auto local_state = state[output_base_index];
for (int i = 0; i < num_sample[0]; i++) {
float rand = curand_uniform(&local_state);
output[output_base_index + i] = BinarySearchForMultinomial(&accum_probs[shm_base_index], col, rand);
}
state[output_base_index] = local_state;
}
template <typename T>
void Multinomial(int seed, int seed2, T *input, int num_sample, curandState *globalState, int *output,
size_t distributions, size_t categories, cudaStream_t cuda_stream) {
int RNG_seed = 0;
std::random_device rd;
if (seed2 != 0) {
RNG_seed = seed2;
} else if (seed != 0) {
RNG_seed = seed;
} else {
RNG_seed = static_cast<int>(rd());
}
int count = distributions * num_sample;
MultinomialKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(RNG_seed, input, num_sample, globalState,
output, distributions, categories);
return;
void Multinomial(int row, int col, T *probs, curandState *state, int64_t *num_sample, int *output,
cudaStream_t stream) {
// Every block process several rows. It depends on shared memory usage.
constexpr int max_shm_used_per_block = 256;
int block_dim = std::max(Floor(std::min(row, max_shm_used_per_block), col), 1);
int grid_dim = Ceil(row, block_dim);
int shm_size = block_dim * col * sizeof(float);
MultinomialKernel<<<grid_dim, block_dim, shm_size, stream>>>(row, col, probs, state, num_sample, output);
}
template void Multinomial<float>(int seed, int seed2, float *input, int num_sample, curandState *globalState,
int *output, size_t distributions, size_t categories, cudaStream_t cuda_stream);
template void Multinomial<float>(int row, int col, float *probs, curandState *state, int64_t *num_sample, int *output,
cudaStream_t stream);
template void CheckNonNeg<float>(const size_t size, const float *input, float *output, cudaStream_t cuda_stream);
template void CheckZero<float>(const size_t distributions, const size_t categories, const float *input, float *output,
cudaStream_t cuda_stream);
template void NormInput<float>(float *input, const size_t distributions, const size_t categories,
cudaStream_t cuda_stream);

+ 3
- 4
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/multinomial_impl.cuh View File

@@ -19,14 +19,13 @@
#include <curand_kernel.h>
#include "runtime/device/gpu/cuda_common.h"
void InitRandState(int seed, int num, curandState *state, cudaStream_t stream);
template <typename T>
void Multinomial(int seed, int seed2, T *input, int num_sample, curandState *globalState, int *output,
size_t distributions, size_t categories, cudaStream_t cuda_stream);
void Multinomial(int row, int col, T *probs, curandState *rand_state, int64_t *num_sample, int *output,
cudaStream_t stream);
template <typename T>
void CheckNonNeg(const size_t size, const T *input, T *output, cudaStream_t stream);
template <typename T>
void CheckZero(const size_t distributions, const size_t categories, const T *input, T *output, cudaStream_t stream);
template <typename T>
void NormInput(T *input, const size_t distributions, const size_t categories, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MULTINOMIAL_IMPL_CUH_

+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/math/multinomial_gpu_kernel.cc View File

@@ -20,7 +20,7 @@ namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
Multinomial,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
MultinomialGpuKernel, float)
} // namespace kernel
} // namespace mindspore

+ 34
- 18
mindspore/ccsrc/backend/kernel_compiler/gpu/math/multinomial_gpu_kernel.h View File

@@ -22,6 +22,8 @@
#include <vector>
#include <string>
#include <map>
#include <random>
#include "runtime/device/gpu/gpu_memory_allocator.h"
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/multinomial_impl.cuh"
@@ -36,10 +38,12 @@ class MultinomialGpuKernel : public GpuKernel {
: input_size_0_(0),
output_size_(0),
distributions_(0),
workspace_size_(sizeof(curandState)),
categories_{0},
seed_(0),
seed2_(0),
is_null_input_(false) {}
is_null_input_(false),
rand_state_init_(false),
rand_state_(nullptr) {}
~MultinomialGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -51,23 +55,30 @@ class MultinomialGpuKernel : public GpuKernel {
if (is_null_input_) {
return true;
}
void *workspace_addr = GetDeviceAddress<void *>(workspace, 1);
T *cum_sum_input = GetDeviceAddress<T>(workspace, 0);
curandState *devStates = reinterpret_cast<curandState *>(workspace_addr);
int *output_addr = GetDeviceAddress<int>(outputs, 0);
T *input_addr = GetDeviceAddress<T>(inputs, 0);
T *probs_addr = GetDeviceAddress<T>(inputs, 0);
int64_t *num_sample_addr = GetDeviceAddress<int64_t>(inputs, 1);
if (distributions_ == 0) {
MS_LOG(ERROR) << "Divide by zero. the distributions_ is 0.";
return false;
}
int categories = SizeToInt(inputs[0]->size / sizeof(T)) / distributions_;
int num_sample = SizeToInt(outputs[0]->size / sizeof(int)) / distributions_;
CumSum(input_addr, cum_sum_input, cum_sum_input, IntToSize(distributions_), IntToSize(categories), 1,
IntToSize(categories), 1, false, false, reinterpret_cast<cudaStream_t>(stream_ptr));
NormInput(cum_sum_input, IntToSize(distributions_), IntToSize(categories),
reinterpret_cast<cudaStream_t>(stream_ptr));
Multinomial(seed_, seed2_, cum_sum_input, num_sample, devStates, output_addr, IntToSize(distributions_),
IntToSize(categories), reinterpret_cast<cudaStream_t>(stream_ptr));
auto stream = reinterpret_cast<cudaStream_t>(stream_ptr);
if (!rand_state_init_) {
int rng_seed = 0;
std::random_device rd;
if (seed2_ != 0) {
rng_seed = seed2_;
} else if (seed_ != 0) {
rng_seed = seed_;
} else {
rng_seed = static_cast<int>(rd());
}
InitRandState(rng_seed, distributions_, rand_state_, stream);
rand_state_init_ = true;
}
Multinomial(distributions_, categories_, probs_addr, rand_state_, num_sample_addr, output_addr, stream);
return true;
}
@@ -93,8 +104,10 @@ class MultinomialGpuKernel : public GpuKernel {
}
if (input_shape_0.size() == 1) {
distributions_ = 1;
categories_ = input_shape_0[0];
} else {
distributions_ = input_shape_0[0];
categories_ = input_shape_0[1];
}
input_size_0_ = sizeof(T);
for (size_t i = 0; i < input_shape_0.size(); i++) {
@@ -105,11 +118,14 @@ class MultinomialGpuKernel : public GpuKernel {
for (size_t i = 0; i < output_shape.size(); i++) {
output_size_ *= output_shape[i];
}
workspace_size_ = output_size_ / sizeof(int) * sizeof(curandState);
auto prim = AnfAlgo::GetCNodePrimitive(kernel_node);
MS_EXCEPTION_IF_NULL(prim);
seed_ = static_cast<int>(GetValue<int64_t>(prim->GetAttr("seed")));
seed2_ = static_cast<int>(GetValue<int64_t>(prim->GetAttr("seed2")));
auto &allocator = device::gpu::GPUMemoryAllocator::GetInstance();
rand_state_ = static_cast<curandState *>(allocator.AllocTensorMem(sizeof(curandState) * distributions_));
InitSizeLists();
return true;
}
@@ -119,18 +135,18 @@ class MultinomialGpuKernel : public GpuKernel {
input_size_list_.push_back(input_size_0_);
input_size_list_.push_back(sizeof(int));
output_size_list_.push_back(output_size_);
workspace_size_list_.push_back(input_size_0_);
workspace_size_list_.push_back(workspace_size_);
}
private:
size_t input_size_0_;
size_t output_size_;
size_t distributions_;
size_t workspace_size_;
size_t categories_;
int seed_;
int seed2_;
bool is_null_input_;
bool rand_state_init_;
curandState *rand_state_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;


Loading…
Cancel
Save