|
|
|
@@ -14,8 +14,28 @@ |
|
|
|
* limitations under the License.
|
|
|
|
*/
|
|
|
|
|
|
|
|
#include <random>
|
|
|
|
#include "multinomial_impl.cuh"
|
|
|
|
#include <algorithm>
|
|
|
|
|
|
|
|
// Divides `num` by `unit` and converts the quotient back to T.
// For integral operands this is truncating division, which equals the
// mathematical floor when both operands are non-negative.
template <typename T, typename S>
inline T Floor(const T &num, const S &unit) {
  const auto quotient = num / unit;
  return static_cast<T>(quotient);
}
|
|
|
|
|
|
|
|
// Rounds `num / unit` up by adding `unit - 1` to the dividend before dividing,
// then converts the result back to T.
// The rounding-up identity holds for non-negative integral `num` and positive `unit`.
template <typename T, typename S>
inline T Ceil(const T &num, const S &unit) {
  const auto padded = num + unit - 1;
  return static_cast<T>(padded / unit);
}
|
|
|
|
|
|
|
|
// Seeds the first `num` cuRAND generator states in `state`.
//
// Uses a grid-stride loop, so any 1D launch configuration covers all `num`
// entries. Entry k is initialised with curand_init(seed, k, 0, ...): the same
// seed for every entry, the entry index as the sequence number, and offset 0.
__global__ void InitRandStateKernel(int seed, int num, curandState *state) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t step = blockDim.x * gridDim.x;
  while (idx < num) {
    curand_init(seed, idx, 0, &state[idx]);
    idx += step;
  }
}
|
|
|
|
|
|
|
|
// Enqueues InitRandStateKernel on `stream` to seed `num` cuRAND states.
//
// seed   - seed passed to every curand_init call.
// num    - number of entries of `state` to initialise; values <= 0 are a no-op.
// state  - array of generator states, passed straight to the kernel.
// stream - CUDA stream the kernel is launched on; the call does not synchronise.
void InitRandState(int seed, int num, curandState *state, cudaStream_t stream) {
  // A grid with zero blocks is an invalid launch configuration, and there is
  // nothing to seed anyway, so return before launching.
  if (num <= 0) {
    return;
  }
  constexpr int kThreadsPerBlock = 128;
  // Ceil-division written as (num - 1) / n + 1 so that it cannot overflow int
  // the way `num + 127` does when num is close to INT_MAX.
  const int block_count = (num - 1) / kThreadsPerBlock + 1;
  InitRandStateKernel<<<block_count, kThreadsPerBlock, 0, stream>>>(seed, num, state);
}
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
__global__ void CheckZeroKernel(const size_t distributions, const size_t categories, const T *input, T *out) {
|
|
|
|
@@ -50,24 +70,6 @@ void CheckNonNeg(const size_t size, const T *input, T *output, cudaStream_t cuda |
|
|
|
CheckNonNegKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, output);
|
|
|
|
}
|
|
|
|
|
|
|
|
// In-place row normalisation of a `distributions` x `categories` row-major array.
//
// Every element of a row except the last one is divided by that row's last
// element. The last element itself is left unchanged; it is the divisor that the
// other threads read.
// NOTE(review): presumably each row already holds a running sum, so the last
// element is the row total - confirm against the caller.
//
// Uses a grid-stride loop over all distributions * categories elements, so any
// 1D launch configuration is valid.
template <typename T>
__global__ void NormInputKernel(T *input, const size_t distributions, const size_t categories) {
  const size_t size = distributions * categories;
  // Widen before multiplying so the index arithmetic is done in size_t.
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t pos = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; pos < size; pos += stride) {
    // Index of the last element in the row that contains `pos`. Kept as size_t:
    // storing it in an int truncates the index once size exceeds INT_MAX.
    const size_t de_pos = (pos / categories + 1) * categories - 1;
    if (pos != de_pos) {
      input[pos] /= input[de_pos];
    }
  }
}
|
|
|
|
|
|
|
|
// Host-side launcher for NormInputKernel.
//
// Enqueues the kernel on `cuda_stream` with a grid sized by GET_BLOCKS for
// distributions * categories elements and GET_THREADS threads per block.
// The element count is narrowed to int before it is handed to GET_BLOCKS.
template <typename T>
void NormInput(T *input, const size_t distributions, const size_t categories, cudaStream_t cuda_stream) {
  const int element_count = static_cast<int>(distributions * categories);
  NormInputKernel<<<GET_BLOCKS(element_count), GET_THREADS, 0, cuda_stream>>>(input, distributions, categories);
}
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
__device__ int BinarySearchForMultinomial(T *start_addr, int size, T rand) {
|
|
|
|
int start = 0;
|
|
|
|
@@ -88,41 +90,53 @@ __device__ int BinarySearchForMultinomial(T *start_addr, int size, T rand) { |
|
|
|
}
|
|
|
|
|
|
|
|
// Seed-based overload: draws one category index for each of the
// num_sample * distributions output slots.
//
// Uses a grid-stride loop, so any 1D launch configuration is valid. For output
// slot i the kernel (re)initialises globalState[i] with curand_init(seed, i, 0),
// draws one uniform value from it, and stores the index returned by
// BinarySearchForMultinomial over row (i / num_sample) % distributions of `input`.
// NOTE(review): presumably each row of `input` holds cumulative probabilities,
// as required by a binary search - confirm against the caller.
template <typename T>
__global__ void MultinomialKernel(int seed, T *input, int num_sample, curandState *globalState, int *output,
                                  size_t distributions, size_t categories) {
  int count = num_sample * distributions;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) {
    int j = i / num_sample % distributions;
    curand_init(seed, i, 0, &globalState[i]);
    auto rand = curand_uniform(&globalState[i]);
    int pick = BinarySearchForMultinomial(input + j * categories, categories, rand);
    output[i] = pick;
  }
  return;
}

// Row/col overload: one thread samples num_sample[0] category indices for one row.
//
// Launch requirements:
//   - 1D grid with at least `row` threads in total; surplus threads exit early.
//   - blockDim.x * col * sizeof(float) bytes of dynamic shared memory.
//
// row        - number of probability rows in `probs`.
// col        - number of categories per row.
// probs      - row-major row x col weights, read from global memory.
// state      - cuRAND states; the thread for row r reads and writes back
//              state[r * num_sample[0]].
// num_sample - device pointer; element 0 is the number of draws per row.
// output     - receives num_sample[0] indices per row, rows stored back to back.
template <typename T>
__global__ void MultinomialKernel(int row, int col, T *probs, curandState *state, int64_t *num_sample, int *output) {
  // Load the probs to shared memory.
  extern __shared__ float accum_probs[];

  // Guard on the row index with >=. Comparing the element offset with
  // `> row * col` let the thread whose index equals `row` through, and that
  // thread read and wrote past the end of every buffer.
  const int row_index = blockIdx.x * blockDim.x + threadIdx.x;
  if (row_index >= row || col <= 0) {
    return;
  }

  // Build the running sum of this row in this thread's private shared-memory slice.
  const size_t probs_base_index = static_cast<size_t>(row_index) * col;
  const int shm_base_index = threadIdx.x * col;
  accum_probs[shm_base_index] = probs[probs_base_index];
  for (int i = 1; i < col; i++) {
    accum_probs[shm_base_index + i] = accum_probs[shm_base_index + i - 1] + probs[probs_base_index + i];
  }

  // No __syncthreads() is used: each thread only touches its own slice
  // [shm_base_index, shm_base_index + col), so there is no cross-thread
  // dependency, and a barrier placed after the divergent early return above
  // would not be reached by every thread of the block.

  // Probs normalization: divide by the row total so the last entry becomes 1.
  const float max_probs = accum_probs[shm_base_index + col - 1];
  for (int i = 0; i < col; i++) {
    accum_probs[shm_base_index + i] /= max_probs;
  }

  // Sample. Work on a register copy of the generator state and store it back once.
  const int64_t sample_count = num_sample[0];
  const int64_t output_base_index = static_cast<int64_t>(row_index) * sample_count;
  auto local_state = state[output_base_index];
  for (int64_t i = 0; i < sample_count; i++) {
    float rand = curand_uniform(&local_state);
    output[output_base_index + i] = BinarySearchForMultinomial(&accum_probs[shm_base_index], col, rand);
  }
  state[output_base_index] = local_state;
}
|
|
|
|
|
|
|
|
// Seed-based overload: launches the seed-based MultinomialKernel on `cuda_stream`.
//
// The RNG seed is `seed2` when it is non-zero, otherwise `seed` when that is
// non-zero, otherwise a value drawn from std::random_device. The grid is sized
// by GET_BLOCKS for distributions * num_sample output slots.
template <typename T>
void Multinomial(int seed, int seed2, T *input, int num_sample, curandState *globalState, int *output,
                 size_t distributions, size_t categories, cudaStream_t cuda_stream) {
  int RNG_seed = 0;
  if (seed2 != 0) {
    RNG_seed = seed2;
  } else if (seed != 0) {
    RNG_seed = seed;
  } else {
    // Only touch the system entropy source when no explicit seed was supplied.
    RNG_seed = static_cast<int>(std::random_device{}());
  }
  int count = distributions * num_sample;
  MultinomialKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(RNG_seed, input, num_sample, globalState,
                                                                        output, distributions, categories);
  return;
}

// Row/col overload: launches the row/col MultinomialKernel on `stream`.
//
// row        - number of probability rows; values <= 0 are a no-op.
// col        - number of categories per row; values <= 0 are a no-op.
// probs      - device pointer to row x col weights.
// state      - device pointer to cuRAND states, forwarded to the kernel.
// num_sample - device pointer whose first element is the draw count per row.
// output     - device pointer receiving the sampled indices.
template <typename T>
void Multinomial(int row, int col, T *probs, curandState *state, int64_t *num_sample, int *output,
                 cudaStream_t stream) {
  // Nothing to sample. This also avoids dividing by zero in Floor when col == 0
  // and launching a zero-block grid when row == 0.
  if (row <= 0 || col <= 0) {
    return;
  }

  // Each thread processes one row and each block processes several rows; how
  // many depends on the shared-memory budget of 256 floats per block.
  constexpr int max_shm_used_per_block = 256;
  int block_dim = std::max(Floor(std::min(row, max_shm_used_per_block), col), 1);
  int grid_dim = Ceil(row, block_dim);
  // NOTE(review): when col exceeds the budget, block_dim clamps to 1 and the
  // request becomes col * sizeof(float) bytes, which can exceed the default
  // per-block dynamic shared-memory limit for very large col - confirm callers
  // bound col.
  int shm_size = block_dim * col * sizeof(float);

  MultinomialKernel<<<grid_dim, block_dim, shm_size, stream>>>(row, col, probs, state, num_sample, output);
}
|
|
|
|
|
|
|
|
// Explicit float instantiations of the host-side entry points defined in this file.

// Multinomial, seed-based signature.
template void Multinomial<float>(int seed, int seed2, float *input, int num_sample, curandState *globalState,
                                 int *output, size_t distributions, size_t categories, cudaStream_t cuda_stream);

// Multinomial, row/col signature.
template void Multinomial<float>(int row, int col, float *probs, curandState *state, int64_t *num_sample, int *output,
                                 cudaStream_t stream);

// Input validation and normalisation helpers.
template void CheckNonNeg<float>(const size_t size, const float *input, float *output, cudaStream_t cuda_stream);
template void CheckZero<float>(const size_t distributions, const size_t categories, const float *input, float *output,
                               cudaStream_t cuda_stream);
template void NormInput<float>(float *input, const size_t distributions, const size_t categories,
                               cudaStream_t cuda_stream);
|