GitOrigin-RevId: 4c05ebc266
tags/v1.5.0
| @@ -21,19 +21,77 @@ class RNGBase: public OperatorBase { | |||
| _megdnn_workspace workspace) = 0; | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout &dst) = 0; | |||
| protected: | |||
| void check_exec(const TensorLayout &dst, size_t workspace_in_bytes); | |||
| virtual void check_exec(const TensorLayout &dst, size_t workspace_in_bytes) = 0; | |||
| }; | |||
| //! sample from poisson distribution | |||
| class PoissonRNG: public OperatorBase { | |||
| DEF_OPR_IMPL(PoissonRNG, OperatorBase, 1, 1); | |||
| DEF_OPR_PARAM(PoissonRNG); | |||
| public: | |||
| virtual void exec(_megdnn_tensor_in lam, | |||
| _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) = 0; | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout &lam, | |||
| const TensorLayout &dst) = 0; | |||
| protected: | |||
| void check_exec(const TensorLayout &lam, const TensorLayout &dst, | |||
| size_t workspace_in_bytes); | |||
| }; | |||
| //! sample from beta distribution | |||
| class BetaRNG: public OperatorBase { | |||
| DEF_OPR_IMPL(BetaRNG, OperatorBase, 2, 1); | |||
| DEF_OPR_PARAM(BetaRNG); | |||
| public: | |||
| virtual void exec(_megdnn_tensor_in alpha, | |||
| _megdnn_tensor_in beta, | |||
| _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) = 0; | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout &alpha, | |||
| const TensorLayout &beta, const TensorLayout &dst) = 0; | |||
| protected: | |||
| void check_exec(const TensorLayout &alpha, const TensorLayout &beta, | |||
| const TensorLayout &dst, size_t workspace_in_bytes); | |||
| }; | |||
| //! sample from gamma distribution | |||
| class GammaRNG: public OperatorBase { | |||
| DEF_OPR_IMPL(GammaRNG, OperatorBase, 2, 1); | |||
| DEF_OPR_PARAM(GammaRNG); | |||
| public: | |||
| virtual void exec(_megdnn_tensor_in shape, | |||
| _megdnn_tensor_in scale, | |||
| _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) = 0; | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout &shape, | |||
| const TensorLayout &scale, const TensorLayout &dst) = 0; | |||
| protected: | |||
| void check_exec(const TensorLayout &shape, const TensorLayout &scale, | |||
| const TensorLayout &dst, size_t workspace_in_bytes); | |||
| }; | |||
| //! sample from uniform distribution on the interval (0, 1] | |||
| class UniformRNG: public RNGBase { | |||
| DEF_OPR_IMPL(UniformRNG, RNGBase, 0, 1); | |||
| DEF_OPR_PARAM(UniformRNG); | |||
| protected: | |||
| void check_exec(const TensorLayout &dst, size_t workspace_in_bytes); | |||
| }; | |||
| //! sample from gaussian distribution | |||
| class GaussianRNG: public RNGBase { | |||
| DEF_OPR_IMPL(GaussianRNG, RNGBase, 0, 1); | |||
| DEF_OPR_PARAM(GaussianRNG); | |||
| protected: | |||
| void check_exec(const TensorLayout &dst, size_t workspace_in_bytes); | |||
| }; | |||
| class PermutationRNG: public RNGBase { | |||
| DEF_OPR_IMPL(PermutationRNG, RNGBase, 0, 1); | |||
| DEF_OPR_PARAM(PermutationRNG); | |||
| protected: | |||
| void check_exec(const TensorLayout &dst, size_t workspace_in_bytes); | |||
| }; | |||
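For context, these operators follow the usual MegDNN calling convention: create the operator through a Handle, set its Param, query the workspace size, then call exec(). A minimal sketch, assuming a valid megdnn::Handle* and pre-filled TensorNDs (the variable names here are illustrative, not from the patch):

    // Hypothetical call sequence for the new PoissonRNG operator.
    auto opr = handle->create_operator<PoissonRNG>();
    opr->param().seed = 42;   // uint64 seed from the param def below
    size_t wk = opr->get_workspace_in_bytes(lam.layout, dst.layout);
    // ... allocate wk bytes and wrap them as a _megdnn_workspace ...
    opr->exec(lam, dst, workspace);   // check_exec() validates layouts first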
| /*! | |||
| @@ -735,11 +735,34 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0) | |||
| 'dtype', Doc('dtype', 'data type of output value'), | |||
| 'DTypeEnum::Float32')) | |||
| pdef('UniformRNG').add_fields('uint64', 'seed', 0) | |||
| (pdef('UniformRNG'). | |||
| add_fields('uint64', 'seed', 0). | |||
| add_fields( | |||
| 'dtype', Doc('dtype', 'The dtype of output Tensor. Only support Float32.'), | |||
| 'DTypeEnum::Float32')) | |||
| (pdef('GaussianRNG'). | |||
| add_fields('uint64', 'seed', 0). | |||
| add_fields('float32', 'mean', 0, 'std', 1)) | |||
| add_fields('float32', 'mean', 0, 'std', 1). | |||
| add_fields( | |||
| 'dtype', Doc('dtype', 'The dtype of output Tensor. Only support Float32.'), | |||
| 'DTypeEnum::Float32')) | |||
| (pdef('GammaRNG'). | |||
| add_fields('uint64', 'seed', 0)) | |||
| (pdef('BetaRNG'). | |||
| add_fields('uint64', 'seed', 0)) | |||
| (pdef('PoissonRNG'). | |||
| add_fields('uint64', 'seed', 0)) | |||
| (pdef('PermutationRNG'). | |||
| add_fields('uint64', 'seed', 0). | |||
| add_fields( | |||
| 'dtype', Doc('dtype', 'The dtype of output Tensor. Int32, Int16 and ' | |||
| 'Float32 are supported.'), | |||
| 'DTypeEnum::Int32')) | |||
| (pdef('Flip'). | |||
| add_fields('bool', 'vertical', 'false', 'horizontal', 'false')) | |||
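These pdef entries generate the Param structs the operators consume; the tests below set them as, e.g., opr->param().dtype = DTypeTrait<dtype::Float32>::enumv and opr->param().seed. Note the asymmetry: UniformRNG, GaussianRNG and PermutationRNG carry an explicit output dtype, while GammaRNG, BetaRNG and PoissonRNG carry only a seed, because their output dtype is dictated by their input tensors (see the check_exec implementations further down).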
| @@ -159,6 +159,10 @@ private: | |||
| cb(SleepForward) \ | |||
| cb(UniformRNG) \ | |||
| cb(GaussianRNG) \ | |||
| cb(GammaRNG) \ | |||
| cb(BetaRNG) \ | |||
| cb(PoissonRNG) \ | |||
| cb(PermutationRNG) \ | |||
| cb(SeparableConvForward) \ | |||
| cb(SeparableFilterForward) \ | |||
| cb(BNForward) \ | |||
| @@ -120,6 +120,10 @@ DEF(TQTBackward, 5, true, false); | |||
| DEF(PowC, 2, false, true); | |||
| DEF(UniformRNG, 1, true, true); | |||
| DEF(GaussianRNG, 1, true, true); | |||
| DEF(GammaRNG, 3, true, true); | |||
| DEF(BetaRNG, 3, true, true); | |||
| DEF(PoissonRNG, 2, true, true); | |||
| DEF(PermutationRNG, 1, true, true); | |||
| DEF(ChecksumForward, 1, true, false); | |||
| DEF(CheckHasInf, 2, true, true); | |||
| DEF(LSQForward, 5, true, true); | |||
| @@ -15,13 +15,62 @@ | |||
| namespace megdnn { | |||
| void RNGBase::check_exec( | |||
| void PermutationRNG::check_exec( | |||
| const TensorLayout &dst, size_t workspace_in_bytes) { | |||
| megdnn_assert(dst.dtype.category() == DTypeCategory::FLOAT && | |||
| dst.is_contiguous()); | |||
| megdnn_assert((dst.dtype == dtype::Float32() || | |||
| dst.dtype == dtype::Int32() || | |||
| dst.dtype == dtype::Int16() ) && | |||
| dst.dtype.enumv() == param().dtype && | |||
| dst.is_contiguous()); | |||
| megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(dst)); | |||
| } | |||
| void PoissonRNG::check_exec(const TensorLayout &lam, const TensorLayout &dst, | |||
| size_t workspace_in_bytes){ | |||
| megdnn_assert( dst.dtype.category() == DTypeCategory::FLOAT && | |||
| lam.dtype == dst.dtype); | |||
| megdnn_assert(dst.is_contiguous() && lam.is_contiguous()); | |||
| megdnn_assert(lam.total_nr_elems() == dst.total_nr_elems()); | |||
| megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(lam, dst)); | |||
| } | |||
| void GammaRNG::check_exec(const TensorLayout &shape,const TensorLayout &scale, | |||
| const TensorLayout &dst, size_t workspace_in_bytes){ | |||
| megdnn_assert(dst.dtype.category() == DTypeCategory::FLOAT && | |||
| shape.dtype == dst.dtype && | |||
| scale.dtype == dst.dtype); | |||
| megdnn_assert(shape.is_contiguous() && scale.is_contiguous() | |||
| && dst.is_contiguous()); | |||
| megdnn_assert(shape.total_nr_elems() == dst.total_nr_elems() && | |||
| scale.total_nr_elems() == dst.total_nr_elems()); | |||
| megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(shape,scale,dst)); | |||
| } | |||
| void BetaRNG::check_exec(const TensorLayout &alpha,const TensorLayout &beta, | |||
| const TensorLayout &dst, size_t workspace_in_bytes){ | |||
| megdnn_assert(dst.dtype.category() == DTypeCategory::FLOAT && | |||
| alpha.dtype == dst.dtype && | |||
| beta.dtype == dst.dtype); | |||
| megdnn_assert(alpha.is_contiguous() && beta.is_contiguous() | |||
| && dst.is_contiguous()); | |||
| megdnn_assert(alpha.total_nr_elems() == dst.total_nr_elems() && | |||
| beta.total_nr_elems() == dst.total_nr_elems()); | |||
| megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(alpha,beta, dst)); | |||
| } | |||
| #define INST_CHECK_EXEC(RNG_NAME) \ | |||
| void RNG_NAME::check_exec( \ | |||
| const TensorLayout &dst, size_t workspace_in_bytes) { \ | |||
| megdnn_assert(dst.dtype.category() == DTypeCategory::FLOAT && \ | |||
| dst.dtype.enumv() == param().dtype && \ | |||
| dst.is_contiguous()); \ | |||
| megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(dst)); \ | |||
| } | |||
| INST_CHECK_EXEC(UniformRNG) | |||
| INST_CHECK_EXEC(GaussianRNG) | |||
| #undef INST_CHECK_EXEC | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -49,23 +49,42 @@ bool use_segmented(uint32_t M, uint32_t /*N*/) { | |||
| return M >= 8; | |||
| } | |||
| template <typename KeyType> | |||
| MEGDNN_NOINLINE size_t cub_sort_pairs( | |||
| __global__ void kern_arange(int* dst, uint32_t n, uint32_t mod) { | |||
| uint32_t i = threadIdx.x + blockIdx.x * blockDim.x; | |||
| if (i < n) { | |||
| dst[i] = i % mod; | |||
| } | |||
| } | |||
| template <typename ctype> | |||
| size_t get_sort_workspace(uint32_t M, uint32_t N, bool is_ascending) { | |||
| if (use_bitonic(M, N)) { | |||
| return 0; | |||
| } | |||
| return argsort::cub_sort_pairs<ctype, int>(is_ascending, NULL, 0, NULL, NULL, NULL, NULL, | |||
| M, N, 0, sizeof(float)*8, NULL); | |||
| } | |||
| } // anonymous namespace | |||
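get_sort_workspace leans on CUB's two-phase convention: calling a radix-sort entry point with a null temp-storage pointer performs no sorting and instead writes the required byte count into the size reference, which is why the NULL-heavy call above is legal (the `if (!keys_in) return workspace_size;` early-outs inside cub_sort_pairs below implement the query path). A standalone illustration of the idiom, assuming device buffers d_keys_in etc. already exist:

    // CUB's size-query-then-run idiom (illustrative; requires <cub/cub.cuh>).
    size_t temp_bytes = 0;
    // Pass 1: null temp storage => CUB only writes the required size.
    cub::DeviceRadixSort::SortPairs(nullptr, temp_bytes, d_keys_in, d_keys_out,
                                    d_vals_in, d_vals_out, n);
    void* d_temp = nullptr;
    cudaMalloc(&d_temp, temp_bytes);
    // Pass 2: identical call with real storage performs the sort.
    cub::DeviceRadixSort::SortPairs(d_temp, temp_bytes, d_keys_in, d_keys_out,
                                    d_vals_in, d_vals_out, n);
    cudaFree(d_temp);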
| template <typename KeyType, typename ValueType> | |||
| MEGDNN_NOINLINE size_t argsort::cub_sort_pairs( | |||
| bool is_ascending, void* workspace, size_t workspace_size, | |||
| const KeyType* keys_in, KeyType* keys_out, const int* values_in, | |||
| int* values_out, uint32_t M, uint32_t N, cudaStream_t stream) { | |||
| const KeyType* keys_in, KeyType* keys_out, const ValueType* values_in, | |||
| ValueType* values_out, uint32_t M, uint32_t N, int begin_bit, int end_bit,cudaStream_t stream){ | |||
| cudaError_t err; | |||
| if (use_segmented(M, N)) { | |||
| if (is_ascending) { | |||
| err = cub::DeviceSegmentedRadixSort::SortPairs( | |||
| workspace, workspace_size, keys_in, keys_out, values_in, | |||
| values_out, N * M, M, StridedOffsetIterator(0, N), | |||
| StridedOffsetIterator(N, N), 0, sizeof(float) * 8, stream); | |||
| StridedOffsetIterator(N, N), begin_bit, end_bit, stream); | |||
| cuda_check(err); | |||
| } else { | |||
| err = cub::DeviceSegmentedRadixSort::SortPairsDescending( | |||
| workspace, workspace_size, keys_in, keys_out, values_in, | |||
| values_out, N * M, M, StridedOffsetIterator(0, N), | |||
| StridedOffsetIterator(N, N), 0, sizeof(float) * 8, stream); | |||
| StridedOffsetIterator(N, N), begin_bit, end_bit, stream); | |||
| cuda_check(err); | |||
| } | |||
| } else { | |||
| if (is_ascending) { | |||
| @@ -73,7 +92,7 @@ MEGDNN_NOINLINE size_t cub_sort_pairs( | |||
| err = cub::DeviceRadixSort::SortPairs( | |||
| workspace, workspace_size, keys_in + N * i, | |||
| keys_out + N * i, values_in + N * i, values_out + N * i, | |||
| N, 0, sizeof(float) * 8, stream); | |||
| N, begin_bit, end_bit, stream); | |||
| cuda_check(err); | |||
| if (!keys_in) { | |||
| return workspace_size; | |||
| @@ -84,7 +103,7 @@ MEGDNN_NOINLINE size_t cub_sort_pairs( | |||
| err = cub::DeviceRadixSort::SortPairsDescending( | |||
| workspace, workspace_size, keys_in + N * i, | |||
| keys_out + N * i, values_in + N * i, values_out + N * i, | |||
| N, 0, sizeof(float) * 8, stream); | |||
| N, begin_bit, end_bit, stream); | |||
| cuda_check(err); | |||
| if (!keys_in) { | |||
| return workspace_size; | |||
| @@ -95,23 +114,6 @@ MEGDNN_NOINLINE size_t cub_sort_pairs( | |||
| return workspace_size; | |||
| } | |||
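The substantive change in cub_sort_pairs is threading begin_bit/end_bit through to CUB, which restricts the radix sort to that bit range of the key. The permutation code introduced below sorts only the low `bits` bits of its masked random keys, saving radix passes; ordinary argsort keeps the old full-width behavior by passing 0 and sizeof(float) * 8 at its call site.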
| __global__ void kern_arange(int* dst, uint32_t n, uint32_t mod) { | |||
| uint32_t i = threadIdx.x + blockIdx.x * blockDim.x; | |||
| if (i < n) { | |||
| dst[i] = i % mod; | |||
| } | |||
| } | |||
| template <typename ctype> | |||
| size_t get_sort_workspace(uint32_t M, uint32_t N, bool is_ascending) { | |||
| if (use_bitonic(M, N)) { | |||
| return 0; | |||
| } | |||
| return cub_sort_pairs<ctype>(is_ascending, NULL, 0, NULL, NULL, NULL, NULL, | |||
| M, N, NULL); | |||
| } | |||
| } // anonymous namespace | |||
| size_t argsort::get_fwd_workspace_in_bytes(uint32_t M, uint32_t N, DType dtype, | |||
| bool is_ascending, | |||
| bool iptr_src_given) { | |||
| @@ -151,17 +153,28 @@ void argsort::forward(const dtype* sptr, dtype* dptr, int* iptr, | |||
| stream)); | |||
| } else { | |||
| cub_sort_pairs(is_ascending, workspace, wk_size, sptr, dptr, iptr_src, | |||
| iptr, M, N, stream); | |||
| iptr, M, N, 0, sizeof(float)*8, stream); | |||
| } | |||
| } | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| #define INST_CUB_SORT(dtype) \ | |||
| template MEGDNN_NOINLINE size_t argsort::cub_sort_pairs<dtype, dtype>(bool, \ | |||
| void*, size_t, const dtype*, dtype*, \ | |||
| const dtype*, dtype*, uint32_t, uint32_t,\ | |||
| int, int, cudaStream_t); | |||
| #define INST_FORWARD(dtype) \ | |||
| template void argsort::forward<dtype>(const dtype*, dtype*, int*, void*, \ | |||
| uint32_t, uint32_t, bool, \ | |||
| cudaStream_t, const int*); | |||
| template void argsort::forward<dtype>(const dtype*, dtype*, int*, void*, \ | |||
| uint32_t, uint32_t, bool, cudaStream_t, \ | |||
| const int*); | |||
| ARGSORT_FOREACH_CTYPE(INST_FORWARD) | |||
| INST_CUB_SORT(uint32_t) | |||
| INST_CUB_SORT(uint64_t) | |||
| #undef INST_CUB_SORT | |||
| #undef INST_FORWARD | |||
| } | |||
| } // namespace megdnn | |||
| @@ -24,6 +24,12 @@ size_t get_fwd_workspace_in_bytes(uint32_t M, uint32_t N, DType dtype, | |||
| bool is_ascending, | |||
| bool iptr_src_given = false); | |||
| template <typename KeyType, typename ValueType> | |||
| size_t cub_sort_pairs( | |||
| bool is_ascending, void* workspace, size_t workspace_size, | |||
| const KeyType* keys_in, KeyType* keys_out, const ValueType* values_in, | |||
| ValueType* values_out, uint32_t M, uint32_t N, int begin_bit, int end_bit,cudaStream_t stream); | |||
| /*! | |||
| * \param iptr_src pointer to indices; a range would be generated if it is null | |||
| */ | |||
| @@ -0,0 +1,174 @@ | |||
| /** | |||
| * \file dnn/src/cuda/rng/kernel.cu |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include <curand_kernel.h> | |||
| #include <device_launch_parameters.h> | |||
| #include "../argsort/argsort.cuh" | |||
| #include "./kernel.cuh" | |||
| #include "src/cuda/cuda_shfl_compat.cuh" | |||
| #include "src/cuda/utils.cuh" | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| namespace random { | |||
| template <typename KeyType, typename ValueType> | |||
| __global__ void permute_duplicate_keys_kernel(KeyType* keys, ValueType* indexs, | |||
| KeyType mask, size_t size, | |||
| uint64_t seed, uint64_t offset) { | |||
| uint32_t idx = threadIdx.x + blockDim.x * blockIdx.x; | |||
| if (idx >= size - 1) return; | |||
| uint32_t lane_idx = threadIdx.x & 0x1F; | |||
| KeyType cur_key = keys[idx] & mask; | |||
| KeyType r_key = __shfl_down(cur_key, 1, 32); | |||
| if (lane_idx == 31) r_key = keys[idx + 1] & mask; | |||
| if (cur_key != r_key) return; | |||
| KeyType l_key = __shfl_up(cur_key, 1, 32); | |||
| if (idx != 0 && lane_idx == 0) l_key = keys[idx - 1] & mask; | |||
| if (cur_key == l_key) return; | |||
| indexs += idx; | |||
| int32_t duplicate_size = 1; | |||
| for (; idx + duplicate_size < size && cur_key == (keys[idx + duplicate_size] & mask); | |||
| ++duplicate_size){}; | |||
| Philox state; | |||
| curand_init(seed, idx, offset, &state); | |||
| for (int32_t i = duplicate_size - 1; i > 0; --i) { | |||
| int32_t r = static_cast<int32_t>(curand(&state) & 0x7fffffff) % (i + 1); | |||
| if (i != r) { | |||
| ValueType tmp = indexs[i]; | |||
| indexs[i] = indexs[r]; | |||
| indexs[r] = tmp; | |||
| } | |||
| } | |||
| } | |||
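permute_duplicate_keys_kernel repairs a subtle bias: the radix sort is stable, so indices whose random keys collide would come out in ascending (identity) order. Each warp uses __shfl_up/__shfl_down against its neighbors to find the first element of every run of equal masked keys, and that single thread Fisher-Yates-shuffles the indices of the whole run. A sequential host-side sketch of the same repair (illustrative only, not the kernel):

    #include <cstddef>
    #include <random>
    #include <utility>
    // For each run of equal (masked) keys, reshuffle the index slice.
    template <typename K, typename V>
    void reshuffle_duplicates(const K* keys, V* idx, size_t n, K mask,
                              std::mt19937_64& rng) {
        for (size_t i = 0; i < n;) {
            size_t j = i + 1;
            while (j < n && (keys[j] & mask) == (keys[i] & mask)) ++j;  // run [i, j)
            for (size_t k = j - 1; k > i; --k) {  // Fisher-Yates within the run
                size_t r = i + static_cast<size_t>(rng() % (k - i + 1));
                std::swap(idx[k], idx[r]);
            }
            i = j;
        }
    }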
| uint32_t get_permutation_bits(size_t N) { | |||
| double uniq_rand_num_prob = 0.9; | |||
| double thresh = std::log(uniq_rand_num_prob) * 12; | |||
| double dN = static_cast<double>(N); | |||
| uint32_t bits = std::min(64, static_cast<int>(std::ceil(std::log2( | |||
| dN - (6 * dN * dN + 1) / thresh)))); | |||
| return bits; | |||
| } | |||
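get_permutation_bits picks the smallest key width for which N random keys are likely (probability uniq_rand_num_prob = 0.9) to be pairwise distinct. This appears to be a birthday-problem estimate: with $M = 2^{\text{bits}}$,

    $\Pr[\text{all } N \text{ keys distinct}] \approx e^{-N(N-1)/(2M)} \ge p \;\Rightarrow\; M \gtrsim N^2 / (2\ln(1/p)),$

and the extra linear and constant terms in the code refine that bound. Collisions that slip through anyway are repaired by the reshuffle kernel above, so 0.9 only needs to hold approximately.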
| size_t get_permutation_workspace_in_bytes(size_t size) { | |||
| uint32_t bits = get_permutation_bits(size); | |||
| size_t work_size = 0; | |||
| #define cb(KeyType, ValueType) \ | |||
| size_t random_src_size = size * sizeof(KeyType); \ | |||
| size_t indexs_size = size * sizeof(ValueType); \ | |||
| size_t sort_worksize = argsort::cub_sort_pairs<KeyType, ValueType>( \ | |||
| false, NULL, 0, NULL, NULL, NULL, NULL, 1, size, 0, bits, NULL); \ | |||
| work_size = 2 * random_src_size + 2 * indexs_size + \ | |||
| DIVUP(sort_worksize, sizeof(KeyType)) * sizeof(KeyType); | |||
| if (bits > 32) { | |||
| cb(uint64_t, uint64_t) | |||
| } else { | |||
| cb(uint32_t, uint32_t) | |||
| } | |||
| #undef cb | |||
| return work_size; | |||
| } | |||
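The returned size matches the split performed in permutation_cuda below: the workspace is carved into [keys_in | keys_out | values_in | values_out | CUB temp storage], each of the four arrays holding `size` elements, with the CUB segment rounded up to whole KeyType units via DIVUP.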
| template <bool is_32bit, typename ctype> | |||
| void permutation_cuda(ctype* dst, void* workspace, size_t size, uint64_t seed, | |||
| uint64_t offset, uint32_t bits, cudaStream_t stream) { | |||
| int threads = 512; | |||
| int blocks = DIVUP(size, threads); | |||
| using KeyType = typename std::conditional<is_32bit, uint32_t, uint64_t>::type; | |||
| using ValueType = KeyType; | |||
| // split workspace | |||
| KeyType* keys_in = static_cast<KeyType*>(workspace); | |||
| KeyType* keys_out = keys_in + size; | |||
| ValueType* values_in = static_cast<ValueType*>(keys_out + size); | |||
| ValueType* values_out = values_in + size; | |||
| void* extra_workspace = static_cast<void*>(values_out + size); | |||
| // initialize the index sequence 0..size-1 |||
| ElemwiseOpParamN<0> ele_param(size); | |||
| typedef RangeKernel<ValueType> rangeOp; | |||
| rangeOp range_op; | |||
| range_op.output = values_in; | |||
| run_elemwise<rangeOp, ValueType, 0>(ele_param, stream, range_op); | |||
| // generate random sample |||
| typedef RandomKernel<KeyType> randomOP; | |||
| randomOP random_op; | |||
| random_op.output = keys_in; | |||
| random_op.seed = seed; | |||
| random_op.offset = offset; | |||
| run_elemwise<randomOP, KeyType, 0>(ele_param, stream, random_op); | |||
| // argsort random sample | |||
| size_t wk_size = argsort::cub_sort_pairs<KeyType, ValueType>( | |||
| false, NULL, 0, NULL, NULL, NULL, NULL, 1, size, 0, bits, NULL); | |||
| argsort::cub_sort_pairs<KeyType, ValueType>( | |||
| false, extra_workspace, wk_size, keys_in, keys_out, values_in, | |||
| values_out, 1, size, 0, bits, stream); | |||
| // permute duplicate sample | |||
| KeyType mask = static_cast<KeyType>((1ULL << bits) - 1); | |||
| permute_duplicate_keys_kernel<KeyType, ValueType> | |||
| <<<blocks, threads, 0, stream>>>(keys_out, values_out, mask, size, | |||
| seed, offset); | |||
| after_kernel_launch(); | |||
| typedef AsTypeKernel<ValueType, ctype> asTypeOP; | |||
| asTypeOP as_type_op; | |||
| as_type_op.input = values_out; | |||
| as_type_op.output = dst; | |||
| run_elemwise<asTypeOP, ValueType, 0>(ele_param, stream, as_type_op); | |||
| } | |||
| template <typename ctype> | |||
| void permutation_forward(ctype* dst, void* workspace, size_t size, uint64_t seed, | |||
| uint64_t offset, cudaStream_t stream) { | |||
| uint32_t bits = get_permutation_bits(size); | |||
| if (bits <= 32) { | |||
| permutation_cuda<true, ctype>(dst, workspace, size, seed, offset, bits, | |||
| stream); | |||
| } else { | |||
| permutation_cuda<false, ctype>(dst, workspace, size, seed, offset, bits, | |||
| stream); | |||
| } | |||
| } | |||
| #define INST_PERMUTATION(T) \ | |||
| template void permutation_forward<T>(T*, void*, size_t, uint64_t, uint64_t, \ | |||
| cudaStream_t); |||
| INST_PERMUTATION(dt_int32) | |||
| INST_PERMUTATION(dt_int16) | |||
| INST_PERMUTATION(dt_float32) | |||
| #undef INST_PERMUTATION | |||
| } // namespace random | |||
| #define INST(_dtype) \ | |||
| INST_RUN_ELEMWISE(random::GammaKernel<DTypeTrait<_dtype>::ctype>, \ | |||
| DTypeTrait<_dtype>::ctype, 0); \ | |||
| INST_RUN_ELEMWISE(random::PoissonKernel<DTypeTrait<_dtype>::ctype>, \ | |||
| DTypeTrait<_dtype>::ctype, 0); \ | |||
| INST_RUN_ELEMWISE(random::BetaKernel<DTypeTrait<_dtype>::ctype>, \ | |||
| DTypeTrait<_dtype>::ctype, 0); \ | |||
| INST(megdnn::dtype::Float32) | |||
| INST(megdnn::dtype::Float16) | |||
| INST(megdnn::dtype::BFloat16) | |||
| #undef INST | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| @@ -0,0 +1,258 @@ | |||
| /** | |||
| * \file dnn/src/cuda/rng/kernel.cuh | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include <cuda_runtime_api.h> | |||
| #include <stdint.h> | |||
| #include <curand.h> | |||
| #include <curand_kernel.h> | |||
| #include "megdnn/dtype.h" | |||
| #include "src/cuda/elemwise_helper.cuh" | |||
| #include "src/cuda/utils.cuh" | |||
| #if MEGDNN_CC_HOST | |||
| #include "megdnn/oprs.h" | |||
| #endif | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| namespace random { | |||
| using Philox = curandStatePhilox4_32_10_t; | |||
| //! curand_uniform returns values in (0, 1]; this wrapper folds 1.0f back |||
| //! to 0.0f so the samplers below see a value in [0, 1). |||
| QUALIFIERS float _curand_uniform(Philox *state){ |||
| float r = curand_uniform(state); |||
| if (r >= 1.0f) { |||
| r = 0.0f; |||
| } |||
| return r; |||
| } |||
| template<typename ctype, typename = void> | |||
| struct RandomKernel; | |||
| template<typename ctype> | |||
| using enable_64bit = typename std::enable_if<std::is_integral<ctype>::value && ((sizeof(ctype)) == 8)>::type; | |||
| template<typename ctype> | |||
| using enable_32bit = typename std::enable_if<std::is_integral<ctype>::value && ((sizeof(ctype)) <= 4)>::type; | |||
| template<typename ctype> | |||
| struct RandomKernel<ctype, enable_64bit<ctype>>{ | |||
| ctype* output; | |||
| uint64_t seed, offset; | |||
| uint64_t mask = static_cast<uint64_t>(std::numeric_limits<ctype>::max()); | |||
| __device__ void operator()(uint32_t idx){ | |||
| Philox local_state; | |||
| curand_init(seed, idx, offset, &local_state); | |||
| uint4 rand = curand4(&local_state); | |||
| uint64_t val = (static_cast<uint64_t>(rand.x) << 32) | rand.y; | |||
| output[idx] = static_cast<ctype>(val & mask); | |||
| } | |||
| #if MEGDNN_CC_HOST | |||
| RandomKernel(const ctype* output, uint64_t seed, uint64_t offset) | |||
| : output{output}, | |||
| seed{seed}, | |||
| offset{offset}{} | |||
| #endif | |||
| }; | |||
| template<typename ctype> | |||
| struct RandomKernel<ctype, enable_32bit<ctype>>{ | |||
| ctype* output; | |||
| uint64_t seed, offset; | |||
| uint32_t mask = static_cast<uint32_t>(std::numeric_limits<ctype>::max()); | |||
| __device__ void operator()(uint32_t idx){ | |||
| Philox local_state; | |||
| curand_init(seed, idx, offset, &local_state); | |||
| uint32_t val = curand(&local_state); | |||
| output[idx] = static_cast<ctype>(val & mask); | |||
| } | |||
| #if MEGDNN_CC_HOST | |||
| RandomKernel(const ctype* output, uint64_t seed, uint64_t offset) | |||
| : output{output}, | |||
| seed{seed}, | |||
| offset{offset}{} | |||
| #endif | |||
| }; | |||
| template<typename ctype> | |||
| struct RangeKernel{ | |||
| ctype* output; | |||
| __device__ void operator()(uint32_t idx){ | |||
| output[idx] = static_cast<ctype>(idx); | |||
| } | |||
| #if MEGDNN_CC_HOST | |||
| RangeKernel(const ctype* output) | |||
| : output{output}{} | |||
| #endif | |||
| }; | |||
| template<typename ctype_src, typename ctype_dst> | |||
| struct AsTypeKernel{ | |||
| ctype_src* input; | |||
| ctype_dst* output; | |||
| using ctype_mask =typename std::conditional<std::is_integral<ctype_dst>::value, ctype_dst, ctype_src>::type; | |||
| ctype_src mask = static_cast<ctype_src>(std::numeric_limits<ctype_mask>::max()); | |||
| __device__ void operator()(uint32_t idx){ | |||
| output[idx] = static_cast<ctype_dst>(input[idx] & mask); | |||
| } | |||
| #if MEGDNN_CC_HOST | |||
| AsTypeKernel(const ctype_src* input, const ctype_dst* output) | |||
| : input{input}, output{output}{} | |||
| #endif | |||
| }; | |||
| template <typename ctype> | |||
| struct GammaKernel { | |||
| ctype* output; | |||
| ctype* shape; | |||
| ctype* scale; | |||
| uint64_t seed, offset; | |||
| static __device__ float sample_gamma(float a, float b, Philox* state){ | |||
| float scale = b; | |||
| if (a <= 0) | |||
| return 0.f; | |||
| if (a < 1.0f) { | |||
| scale *= powf(_curand_uniform(state), 1.0f / a); | |||
| a += 1.0f; | |||
| } | |||
| float d = a - 1.0f / 3.0f; | |||
| float c = 1.0f / sqrtf(9.0f * d); | |||
| while (1) { | |||
| float x, y; | |||
| x = curand_normal(state); | |||
| y = 1.0f + c * x; | |||
| if (y <= 0) | |||
| continue; | |||
| float v = y * y * y; | |||
| float u = _curand_uniform(state); | |||
| float xx = x * x; | |||
| if ((u < 1.0f - 0.0331f * xx * xx) || | |||
| logf(u) < 0.5f * xx + d * (1.0f - v + logf(v))) | |||
| return scale * d * v; | |||
| } | |||
| } | |||
| __device__ void operator()(uint32_t idx) { | |||
| Philox local_state; | |||
| curand_init(seed, idx, offset, &local_state); | |||
| float a = static_cast<float>(shape[idx]); | |||
| float b = static_cast<float>(scale[idx]); | |||
| output[idx] = static_cast<ctype>(sample_gamma(a, b, &local_state)); | |||
| } | |||
| #if MEGDNN_CC_HOST | |||
| GammaKernel(const TensorND& output, const TensorND& shape, | |||
| const TensorND& scale, uint64_t seed, uint64_t offset) | |||
| : output{output.ptr<ctype>()}, | |||
| shape{shape.ptr<ctype>()}, | |||
| scale{scale.ptr<ctype>()}, | |||
| seed{seed}, | |||
| offset{offset}{} | |||
| #endif | |||
| }; | |||
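sample_gamma is Marsaglia and Tsang's (2000) squeeze method: for a >= 1, set d = a - 1/3 and c = 1/sqrt(9d), draw x ~ N(0,1), form v = (1 + c*x)^3, and accept d*v when u < 1 - 0.0331*x^4 or ln u < x^2/2 + d(1 - v + ln v); for a < 1 it samples Gamma(a + 1) instead and multiplies by U^{1/a}, which the kernel folds into `scale` up front. A host-side reference sketch with <random>, for comparison only:

    #include <cmath>
    #include <random>
    // Host reference of Marsaglia-Tsang gamma sampling (illustrative only).
    double sample_gamma_ref(double a, double b, std::mt19937_64& rng) {
        std::normal_distribution<double> normal(0.0, 1.0);
        std::uniform_real_distribution<double> unif(0.0, 1.0);
        if (a <= 0) return 0.0;
        double boost = 1.0;
        if (a < 1.0) {  // Gamma(a) = Gamma(a + 1) * U^{1/a}
            boost = std::pow(unif(rng), 1.0 / a);
            a += 1.0;
        }
        double d = a - 1.0 / 3.0, c = 1.0 / std::sqrt(9.0 * d);
        for (;;) {
            double x = normal(rng), y = 1.0 + c * x;
            if (y <= 0.0) continue;
            double v = y * y * y, u = unif(rng), xx = x * x;
            if (u < 1.0 - 0.0331 * xx * xx ||                          // squeeze accept
                std::log(u) < 0.5 * xx + d * (1.0 - v + std::log(v)))  // full accept
                return b * boost * d * v;
        }
    }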
| template<typename ctype> | |||
| struct PoissonKernel{ | |||
| ctype* output; | |||
| ctype* lambda; | |||
| uint64_t seed, offset; | |||
| __device__ void operator()(uint32_t idx){ | |||
| Philox local_state; | |||
| curand_init(seed, idx, offset, &local_state); | |||
| float lam = static_cast<float>(lambda[idx]); | |||
| output[idx] = static_cast<ctype>(curand_poisson(&local_state, lam)); | |||
| } | |||
| #if MEGDNN_CC_HOST | |||
| PoissonKernel(const TensorND& output,const TensorND& lambda, | |||
| uint64_t seed, uint64_t offset) | |||
| : output{output.ptr<ctype>()}, | |||
| lambda{lambda.ptr<ctype>()}, | |||
| seed{seed}, | |||
| offset{offset}{} | |||
| #endif | |||
| }; | |||
| template<typename ctype> | |||
| struct BetaKernel{ | |||
| ctype* output; | |||
| ctype* alpha; | |||
| ctype* beta; | |||
| uint64_t seed, offset; | |||
| __device__ void operator()(uint32_t idx){ | |||
| Philox local_state; | |||
| curand_init(seed, idx, offset, &local_state); | |||
| float a = static_cast<float>(alpha[idx]); | |||
| float b = static_cast<float>(beta[idx]); | |||
| if(a <= 0 || b <= 0){ | |||
| output[idx] = 0; | |||
| return; | |||
| } | |||
| if( a < 1.0f && b < 1.0f){ | |||
| float u, v, x, y; | |||
| while (true) | |||
| { | |||
| u = _curand_uniform(&local_state); | |||
| v = _curand_uniform(&local_state); | |||
| x = powf(u, 1.0f / a); | |||
| y = powf(v, 1.0f / b); | |||
| if (x + y < 1.0f) { | |||
| if (x + y > 0) { | |||
| output[idx] = static_cast<ctype>(x / (x + y)); | |||
| return ; | |||
| } else { | |||
| float logx = logf(u) / a; | |||
| float logy = logf(v) / b; | |||
| float log_max = logx > logy ? logx : logy; | |||
| logx -= log_max; | |||
| logy -= log_max; | |||
| output[idx] = static_cast<ctype>(exp(logx - | |||
| log(exp(logx) + exp(logy)))); | |||
| return ; | |||
| } | |||
| } | |||
| } | |||
| }else{ | |||
| float ga = GammaKernel<float>::sample_gamma(a, 1.0f, &local_state); | |||
| float gb = GammaKernel<float>::sample_gamma(b, 1.0f, &local_state); | |||
| output[idx] = static_cast<ctype>(ga / ( ga + gb)); | |||
| return ; | |||
| } | |||
| } | |||
| #if MEGDNN_CC_HOST | |||
| BetaKernel(const TensorND& output, const TensorND& alpha, | |||
| const TensorND& beta, uint64_t seed, uint64_t offset) | |||
| : output{output.ptr<ctype>()}, | |||
| alpha{alpha.ptr<ctype>()}, | |||
| beta{beta.ptr<ctype>()}, | |||
| seed{seed}, | |||
| offset{offset}{} | |||
| #endif | |||
| }; | |||
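BetaKernel stitches together two classical samplers. When either parameter is >= 1 it uses the gamma-ratio identity: if G_a ~ Gamma(a, 1) and G_b ~ Gamma(b, 1) are independent, then G_a / (G_a + G_b) ~ Beta(a, b). When both a < 1 and b < 1 it uses Jöhnk's algorithm: draw u, v uniform, set x = u^{1/a}, y = v^{1/b}, and on x + y < 1 return x / (x + y); if x + y has underflowed to 0, the same ratio is re-evaluated in log space (the analogous branch in the naive backend below is discussed there).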
| template<typename ctype> | |||
| void permutation_forward(ctype* dst, void* workspace, size_t size, uint64_t seed, | |||
| uint64_t offset, cudaStream_t stream); | |||
| size_t get_permutation_workspace_in_bytes(size_t N); | |||
| } // namespace random | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| @@ -13,6 +13,7 @@ | |||
| #include "src/cuda/handle.h" | |||
| #include "src/cuda/utils.h" | |||
| #include "./opr_impl.h" | |||
| #include "./kernel.cuh" | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| @@ -122,5 +123,143 @@ size_t GaussianRNGImpl::get_workspace_in_bytes(const TensorLayout &layout) { | |||
| return 0; | |||
| } | |||
| GammaRNGImpl::GammaRNGImpl(Handle *handle): | |||
| GammaRNG(handle), | |||
| m_seed(0), | |||
| m_offset(0), | |||
| m_stream(cuda_stream(handle)) | |||
| { | |||
| } | |||
| void GammaRNGImpl::exec(_megdnn_tensor_in shape, _megdnn_tensor_in scale, | |||
| _megdnn_tensor_inout dst, _megdnn_workspace workspace) { | |||
| check_exec(shape.layout, scale.layout, dst.layout, workspace.size); |||
| auto size = dst.layout.total_nr_elems(); | |||
| megdnn_assert(size); | |||
| ensure_seed(m_param.seed); | |||
| ElemwiseOpParamN<0> ele_param(size); | |||
| switch (dst.layout.dtype.enumv()){ | |||
| #define cb(_dt) \ | |||
| case DTypeTrait<_dt>::enumv: \ | |||
| { \ | |||
| using ctype = DTypeTrait<_dt>::ctype; \ | |||
| run_elemwise<random::GammaKernel<ctype>, ctype, 0>(ele_param, m_stream, \ | |||
| {dst, shape, scale, m_seed, m_offset}); \ | |||
| break ; \ | |||
| } | |||
| MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||
| #undef cb | |||
| default: | |||
| megdnn_throw("bad dtype"); | |||
| } | |||
| m_offset += 16; | |||
| } | |||
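The m_offset bump (16 here, 20 for Poisson and 32 for Beta below) advances the Philox counter between exec() calls so that repeated calls with the same seed draw from disjoint portions of the random stream; the values appear to be upper bounds on how many counter increments one thread can consume per element, though the change does not spell the accounting out.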
| PoissonRNGImpl::PoissonRNGImpl(Handle *handle): | |||
| PoissonRNG(handle), | |||
| m_seed(0), | |||
| m_offset(0), | |||
| m_stream(cuda_stream(handle)) | |||
| { | |||
| } | |||
| void PoissonRNGImpl::exec(_megdnn_tensor_in lam, _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) { | |||
| check_exec(lam.layout, dst.layout, workspace.size); | |||
| auto size = dst.layout.total_nr_elems(); | |||
| megdnn_assert(size); | |||
| ensure_seed(m_param.seed); | |||
| ElemwiseOpParamN<0> ele_param(size); | |||
| switch (dst.layout.dtype.enumv()){ | |||
| #define cb(_dt) \ | |||
| case DTypeTrait<_dt>::enumv: \ | |||
| { \ | |||
| using ctype = DTypeTrait<_dt>::ctype; \ | |||
| run_elemwise<random::PoissonKernel<ctype>, ctype, 0>(ele_param, m_stream, \ | |||
| {dst, lam, m_seed, m_offset}); \ | |||
| break; \ | |||
| } | |||
| MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||
| #undef cb | |||
| default: | |||
| megdnn_throw("bad dtype"); | |||
| } | |||
| m_offset += 20; | |||
| } | |||
| BetaRNGImpl::BetaRNGImpl(Handle *handle): | |||
| BetaRNG(handle), | |||
| m_seed(0), | |||
| m_offset(0), | |||
| m_stream(cuda_stream(handle)) | |||
| { | |||
| } | |||
| void BetaRNGImpl::exec(_megdnn_tensor_in alpha, _megdnn_tensor_in beta, _megdnn_tensor_out dst, |||
| _megdnn_workspace workspace) { |||
| check_exec(alpha.layout, beta.layout, dst.layout, workspace.size); |||
| auto size = dst.layout.total_nr_elems(); | |||
| megdnn_assert(size); | |||
| ensure_seed(m_param.seed); | |||
| ElemwiseOpParamN<0> ele_param(size); | |||
| switch (dst.layout.dtype.enumv()){ | |||
| #define cb(_dt) \ | |||
| case DTypeTrait<_dt>::enumv: \ | |||
| { \ | |||
| using ctype = DTypeTrait<_dt>::ctype; \ | |||
| run_elemwise<random::BetaKernel<ctype>, ctype, 0>(ele_param, m_stream, \ | |||
| {dst, alpha, beta, m_seed, m_offset}); \ | |||
| break; \ | |||
| } | |||
| MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||
| #undef cb | |||
| default: | |||
| megdnn_throw("bad dtype"); | |||
| } | |||
| m_offset += 32; | |||
| } | |||
| PermutationRNGImpl::PermutationRNGImpl(Handle *handle): | |||
| PermutationRNG(handle), | |||
| m_seed(0), | |||
| m_offset(0), | |||
| m_stream(cuda_stream(handle)) | |||
| { | |||
| } | |||
| void PermutationRNGImpl::exec( | |||
| _megdnn_tensor_inout dst, _megdnn_workspace workspace) { | |||
| check_exec(dst.layout, workspace.size); | |||
| auto size = dst.layout.total_nr_elems(); | |||
| megdnn_assert(size); | |||
| ensure_seed(m_param.seed); | |||
| auto wk = workspace.ptr<void>(); | |||
| switch (dst.layout.dtype.enumv()){ | |||
| #define cb(_dt) \ | |||
| case DTypeTrait<_dt>::enumv: \ | |||
| { \ | |||
| using ctype = DTypeTrait<_dt>::ctype; \ | |||
| ctype max_size = DTypeTrait<_dt>::max() - 1; \ | |||
| megdnn_assert(ctype(size) < max_size); \ | |||
| random::permutation_forward<ctype>(dst.ptr<ctype>(), wk, size, m_seed, \ | |||
| m_offset, m_stream); \ | |||
| break; \ | |||
| } | |||
| cb(::megdnn::dtype::Float32) | |||
| cb(::megdnn::dtype::Int32) | |||
| cb(::megdnn::dtype::Int16) | |||
| #undef cb | |||
| default: | |||
| megdnn_throw("bad dtype"); | |||
| } | |||
| m_offset += 8; | |||
| } | |||
| size_t PermutationRNGImpl::get_workspace_in_bytes(const TensorLayout &layout){ | |||
| size_t size = layout.total_nr_elems(); | |||
| return random::get_permutation_workspace_in_bytes(size); | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -10,9 +10,9 @@ | |||
| */ | |||
| #pragma once | |||
| #include <curand.h> | |||
| #include "megdnn/oprs.h" | |||
| #include "src/cuda/handle.h" | |||
| #include <curand.h> | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| @@ -22,51 +22,136 @@ class CuRandHandle { | |||
| uint64_t m_seed; | |||
| CuRandHandle(const CuRandHandle&) = delete; | |||
| CuRandHandle& operator = (const CuRandHandle&) = delete; | |||
| CuRandHandle& operator=(const CuRandHandle&) = delete; | |||
| public: | |||
| CuRandHandle(cudaStream_t stream, uint64_t seed = 0); | |||
| ~CuRandHandle(); | |||
| public: | |||
| CuRandHandle(cudaStream_t stream, uint64_t seed = 0); | |||
| ~CuRandHandle(); | |||
| void seed(uint64_t seed); | |||
| void seed(uint64_t seed); | |||
| curandGenerator_t gen() const { | |||
| return m_gen; | |||
| } | |||
| curandGenerator_t gen() const { return m_gen; } | |||
| void ensure_seed(uint64_t seed) { | |||
| if (m_seed != seed) { | |||
| this->seed(seed); | |||
| } | |||
| void ensure_seed(uint64_t seed) { | |||
| if (m_seed != seed) { | |||
| this->seed(seed); | |||
| } | |||
| } | |||
| }; | |||
| class UniformRNGImpl: public UniformRNG { | |||
| class UniformRNGImpl : public UniformRNG { | |||
| CuRandHandle m_curand_handle; | |||
| public: | |||
| UniformRNGImpl(Handle *handle); | |||
| void exec(_megdnn_tensor_inout dst, _megdnn_workspace) override; | |||
| public: | |||
| UniformRNGImpl(Handle* handle); | |||
| void exec(_megdnn_tensor_inout dst, _megdnn_workspace) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout&) override { | |||
| return 0; | |||
| } | |||
| size_t get_workspace_in_bytes(const TensorLayout&) override { return 0; } | |||
| }; | |||
| class GaussianRNGImpl: public GaussianRNG { | |||
| class GaussianRNGImpl : public GaussianRNG { | |||
| CuRandHandle m_curand_handle; | |||
| public: | |||
| GaussianRNGImpl(Handle *handle); | |||
| public: | |||
| GaussianRNGImpl(Handle* handle); | |||
| void exec(_megdnn_tensor_inout dst, _megdnn_workspace) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout& layout) override; | |||
| }; | |||
| class GammaRNGImpl : public GammaRNG { | |||
| uint64_t m_seed, m_offset; | |||
| cudaStream_t m_stream; | |||
| public: | |||
| GammaRNGImpl(Handle* handle); | |||
| void exec(_megdnn_tensor_in shape,_megdnn_tensor_in scale, | |||
| _megdnn_tensor_out dst, _megdnn_workspace) override; | |||
| void exec(_megdnn_tensor_inout dst, _megdnn_workspace) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||
| const TensorLayout&) override { | |||
| return 0; | |||
| } | |||
| size_t get_workspace_in_bytes(const TensorLayout &layout) override; | |||
| void seed(uint64_t seed) { m_seed = seed; } | |||
| void ensure_seed(uint64_t seed) { | |||
| if (m_seed != seed) { | |||
| this->seed(seed); | |||
| } | |||
| } | |||
| }; | |||
| class BetaRNGImpl : public BetaRNG { | |||
| uint64_t m_seed, m_offset; | |||
| cudaStream_t m_stream; | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| public: | |||
| BetaRNGImpl(Handle* handle); | |||
| void exec(_megdnn_tensor_in alpha,_megdnn_tensor_in beta, | |||
| _megdnn_tensor_out dst, _megdnn_workspace) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||
| const TensorLayout&) override { | |||
| return 0; | |||
| } | |||
| void seed(uint64_t seed) { m_seed = seed; } | |||
| void ensure_seed(uint64_t seed) { | |||
| if (m_seed != seed) { | |||
| this->seed(seed); | |||
| } | |||
| } | |||
| }; | |||
| class PoissonRNGImpl : public PoissonRNG { | |||
| uint64_t m_seed, m_offset; | |||
| cudaStream_t m_stream; | |||
| public: | |||
| PoissonRNGImpl(Handle* handle); | |||
| void exec(_megdnn_tensor_in lam, _megdnn_tensor_out dst, | |||
| _megdnn_workspace) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout&, | |||
| const TensorLayout&) override { | |||
| return 0; | |||
| } | |||
| void seed(uint64_t seed) { m_seed = seed; } | |||
| void ensure_seed(uint64_t seed) { | |||
| if (m_seed != seed) { | |||
| this->seed(seed); | |||
| } | |||
| } | |||
| }; | |||
| class PermutationRNGImpl : public PermutationRNG { | |||
| uint64_t m_seed, m_offset; | |||
| cudaStream_t m_stream; | |||
| public: | |||
| PermutationRNGImpl(Handle* handle); | |||
| void exec(_megdnn_tensor_inout dst, _megdnn_workspace) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout& layout) override; | |||
| void seed(uint64_t seed) { m_seed = seed; } | |||
| void ensure_seed(uint64_t seed) { | |||
| if (m_seed != seed) { | |||
| this->seed(seed); | |||
| } | |||
| } | |||
| }; | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -78,6 +78,157 @@ namespace { | |||
| } | |||
| } | |||
| template<typename T> | |||
| T normal_sample(Xoroshiro128plus *rng){ | |||
| T v; | |||
| fill_gaussian<T>(rng, &v, 1, T(0.f), T(1.f)); | |||
| return v; | |||
| } | |||
| template<typename T> | |||
| T uniform_sample(Xoroshiro128plus *rng){ | |||
| return uniform_int2float<T>((*rng)()); | |||
| } | |||
| template<typename T, typename U> | |||
| void fill_gamma(Xoroshiro128plus *rng, U *dst, size_t size, | |||
| U* shape, U* scale){ | |||
| for(size_t i = 0; i < size; ++i){ | |||
| T a = static_cast<T>(shape[i]); | |||
| T b = static_cast<T>(scale[i]); | |||
| T scale = b; | |||
| bool a_less_one = a < 1.f ? true : false; | |||
| if (a <= 0) { | |||
| dst[i] = U(0.0f); | |||
| continue; | |||
| }; | |||
| T d = a + (a_less_one ? 2.0f / 3.0f : -1.0f / 3.0f); | |||
| T c = 1.0f / std::sqrt(9.0f * d); | |||
| while (true) | |||
| { | |||
| T x, y; | |||
| x = normal_sample<T>(rng); | |||
| y = 1.0f + c * x; | |||
| if ( y <= 0) continue; | |||
| T v = y * y * y; | |||
| T u = uniform_sample<T>(rng); | |||
| T xx = x * x; | |||
| if ((u < 1.0f - 0.0331f * xx * xx) || | |||
| std::log(u) < 0.5f * xx + d * (1.0f - v + std::log(v))) | |||
| { | |||
| dst[i] = U(scale * d * v); | |||
| if (a_less_one) dst[i] *= U(std::pow(uniform_sample<T>(rng), T(1.f / a))); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| template<typename T, typename U> | |||
| void fill_poisson(Xoroshiro128plus *rng, U *dst, U* lam, size_t size){ | |||
| for(size_t i = 0; i < size; ++i) { | |||
| T lambda = static_cast<T>(lam[i]); | |||
| T exp_neg_lambda = std::exp(-lambda); | |||
| T log_lambda = std::log(lambda), sqrt_lambda = std::sqrt(lambda); | |||
| T b = 0.931f + 2.53f * sqrt_lambda; | |||
| T a = -0.059f + 0.02483f * b; | |||
| T inv_alpha = 1.1239f + 1.1328f / ( b - 3.4f); | |||
| T vr = 0.9277f - 3.6224f / (b - 2.f); | |||
| T u, v, u_shifted, k; |||
| if( lambda == 0) { | |||
| dst[i] = U(0); | |||
| continue; | |||
| } | |||
| if ( lambda < 10){ | |||
| T prod = 1, x = 0; | |||
| u = 0; | |||
| while (true) | |||
| { | |||
| u = uniform_sample<T>(rng); | |||
| prod *= u; | |||
| if ( prod <= exp_neg_lambda ){ | |||
| dst[i] = U(x); | |||
| break; | |||
| } | |||
| x += 1; | |||
| } | |||
| continue; | |||
| } | |||
| while (true) | |||
| { | |||
| u = uniform_sample<T>(rng) - T(0.5f); | |||
| v = uniform_sample<T>(rng); | |||
| u_shifted = T(0.5f) - std::abs(u); | |||
| k = std::floor((T(2.f) * a / u_shifted + b) * u + lambda + T(0.43f)); | |||
| if ( u_shifted >= 0.07 && v < vr ){ | |||
| dst[i] = U(k); | |||
| break; | |||
| } | |||
| if (k < 0 || (u_shifted < T(0.013f) && v > u_shifted)) { | |||
| continue; | |||
| } | |||
| if ((std::log(v) + std::log(inv_alpha) - std::log(a / (u_shifted * u_shifted) + b)) <= | |||
| (-lambda + k * log_lambda - std::lgamma(k + 1))) { | |||
| dst[i] = U(k); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
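fill_poisson runs two regimes. For lambda < 10 it uses the classic product-of-uniforms method: multiply fresh uniforms until the product drops below e^{-lambda}; the number of factors minus one is the sample. For lambda >= 10 it is Hörmann's transformed rejection with squeeze (PTRS): the constants b = 0.931 + 2.53*sqrt(lambda), a = -0.059 + 0.02483*b, 1/alpha = 1.1239 + 1.1328/(b - 3.4) and v_r = 0.9277 - 3.6224/(b - 2) are the published PTRS coefficients, the u_shifted >= 0.07 branch is the quick accept, and the final test compares against the log-pmf -lambda + k*ln(lambda) - ln(k!) via std::lgamma.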
| template<typename T, typename U> | |||
| void fill_beta(Xoroshiro128plus *rng, U *dst, U* alpha,U* beta, size_t size){ | |||
| for (size_t i = 0; i < size; ++i) { | |||
| T a = static_cast<T>(alpha[i]), b = static_cast<T>(beta[i]); | |||
| if( a < 1.0f && b < 1.0f){ | |||
| T u,v,x,y; | |||
| while (true) | |||
| { | |||
| u = uniform_sample<T>(rng); | |||
| v = uniform_sample<T>(rng); | |||
| x = std::pow(u, 1.0f / a); | |||
| y = std::pow(v, 1.0f / b); | |||
| if (x + y < 1.0f) { | |||
| if (x + y > 0) { | |||
| dst[i] = static_cast<U>(x / (x + y)); | |||
| break; | |||
| }else { | |||
| T logx = std::log(u) / a; | |||
| T logy = std::log(v) / b; | |||
| T log_max = std::max(logx, logy); | |||
| logx -= log_max; | |||
| logy -= log_max; | |||
| dst[i] = static_cast<U> (std::exp(logx - | |||
| std::log(std::exp(logx) + std::exp(logy)))); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| }else{ | |||
| T ga, gb, one = 1; | |||
| fill_gamma<T,T>(rng, &ga, 1, &a, &one); | |||
| fill_gamma<T,T>(rng, &gb, 1, &b, &one); | |||
| dst[i] = static_cast<U>( ga / (ga + gb)); | |||
| } | |||
| } | |||
| } | |||
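The fallback branch of Jöhnk's accept (taken when both powers underflow so x + y == 0) evaluates the same ratio in log space: with logx = ln(u)/a and logy = ln(v)/b, x/(x + y) = exp(logx - ln(e^{logx} + e^{logy})), and subtracting log_max from both exponents first is the standard log-sum-exp trick that pins the larger term at e^0 = 1 so the sum can neither overflow nor vanish.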
| template<typename T> | |||
| void fill_permutation(Xoroshiro128plus *rng, T *dst, size_t size){ | |||
| const int64_t mask = std::numeric_limits<int64_t>::max(); | |||
| for (size_t i = 0; i < size; ++i) { | |||
| dst[i] = static_cast<T>(i); | |||
| } | |||
| for (int64_t i = size - 1; i > 0; --i) { | |||
| int64_t r = static_cast<int64_t>((*rng)()&mask) % (i + 1); | |||
| if (i != r) { | |||
| T tmp = dst[i]; | |||
| dst[i] = dst[r]; | |||
| dst[r] = tmp; | |||
| } | |||
| } | |||
| } | |||
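fill_permutation is a textbook Fisher-Yates shuffle over the identity sequence 0..size-1. The index draw ((*rng)() & mask) % (i + 1) has the usual modulo bias whenever i + 1 does not divide 2^63, but at these sizes the bias is far below anything the tests could detect.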
| } // anonymous namespace | |||
| uint64_t Splitmix64::operator() () { | |||
| @@ -150,5 +301,98 @@ void GaussianRNGImpl::exec( | |||
| } | |||
| } | |||
| void GammaRNGImpl::exec(_megdnn_tensor_in shape, _megdnn_tensor_in scale, | |||
| _megdnn_tensor_out dst, _megdnn_workspace workspace) { | |||
| check_exec(shape.layout, scale.layout, dst.layout, workspace.size); | |||
| auto size = dst.layout.total_nr_elems(); | |||
| auto prng = &m_rng.ensure_seed(m_param.seed); | |||
| switch (dst.layout.dtype.enumv()) { | |||
| #define cb(_dt) \ | |||
| case DTypeTrait<_dt>::enumv: \ | |||
| { \ | |||
| using ctype = DTypeTrait<_dt>::ctype; \ | |||
| auto ptr = dst.ptr<ctype>(); \ | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR({fill_gamma<float>(prng, ptr, \ | |||
| size, shape.ptr<ctype>(), scale.ptr<ctype>());};); \ | |||
| return; \ | |||
| } | |||
| MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||
| #undef cb | |||
| default: | |||
| megdnn_throw("bad dtype"); | |||
| } | |||
| } | |||
| void PoissonRNGImpl::exec(_megdnn_tensor_in lam, | |||
| _megdnn_tensor_inout dst, _megdnn_workspace workspace) { | |||
| check_exec(lam.layout, dst.layout, workspace.size); | |||
| auto size = dst.layout.total_nr_elems(); | |||
| auto prng = &m_rng.ensure_seed(m_param.seed); | |||
| switch (dst.layout.dtype.enumv()) { | |||
| #define cb(_dt) \ | |||
| case DTypeTrait<_dt>::enumv: \ | |||
| { \ | |||
| using ctype = DTypeTrait<_dt>::ctype; \ | |||
| auto dst_ptr = dst.ptr<ctype>(); \ | |||
| auto lam_ptr = lam.ptr<ctype>(); \ | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR({fill_poisson<float>(prng, dst_ptr, \ | |||
| lam_ptr, size );};); \ | |||
| return; \ | |||
| } | |||
| MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||
| #undef cb | |||
| default: | |||
| megdnn_throw("bad dtype"); | |||
| } | |||
| } | |||
| void BetaRNGImpl::exec(_megdnn_tensor_in alpha,_megdnn_tensor_in beta, | |||
| _megdnn_tensor_out dst, _megdnn_workspace workspace) { | |||
| check_exec(alpha.layout, beta.layout, dst.layout, workspace.size); | |||
| auto size = dst.layout.total_nr_elems(); | |||
| auto prng = &m_rng.ensure_seed(m_param.seed); | |||
| switch (dst.layout.dtype.enumv()) { | |||
| #define cb(_dt) \ | |||
| case DTypeTrait<_dt>::enumv: \ | |||
| { \ | |||
| using ctype = DTypeTrait<_dt>::ctype; \ | |||
| auto dst_ptr = dst.ptr<ctype>(); \ | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR({fill_beta<float>(prng, dst_ptr, \ | |||
| alpha.ptr<ctype>(),beta.ptr<ctype>(), size );};); \ | |||
| return; \ | |||
| } | |||
| MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||
| #undef cb | |||
| default: | |||
| megdnn_throw("bad dtype"); | |||
| } | |||
| } | |||
| void PermutationRNGImpl::exec( | |||
| _megdnn_tensor_inout dst, _megdnn_workspace workspace) { | |||
| check_exec(dst.layout, workspace.size); | |||
| auto size = dst.layout.total_nr_elems(); | |||
| auto prng = &m_rng.ensure_seed(m_param.seed); | |||
| switch (dst.layout.dtype.enumv()) { | |||
| #define cb(_dt) \ | |||
| case DTypeTrait<_dt>::enumv: \ | |||
| { \ | |||
| using ctype = DTypeTrait<_dt>::ctype; \ | |||
| ctype max_size = DTypeTrait<_dt>::max() - 1; \ | |||
| megdnn_assert((ctype(size) < max_size)); \ | |||
| auto ptr = dst.ptr<ctype>(); \ | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR({fill_permutation<ctype>(prng, ptr, \ | |||
| size);};); \ | |||
| return; \ | |||
| } | |||
| cb(::megdnn::dtype::Float32) | |||
| cb(::megdnn::dtype::Int32) | |||
| cb(::megdnn::dtype::Int16) | |||
| #undef cb | |||
| default: | |||
| megdnn_throw("bad dtype"); | |||
| } | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -10,8 +10,8 @@ | |||
| */ | |||
| #pragma once | |||
| #include "megdnn/oprs.h" | |||
| #include <cstdint> | |||
| #include "megdnn/oprs.h" | |||
| namespace megdnn { | |||
| namespace naive { | |||
| @@ -19,12 +19,11 @@ namespace naive { | |||
| //! see http://xoroshiro.di.unimi.it/splitmix64.c | |||
| class Splitmix64 { | |||
| uint64_t m_s; | |||
| public: | |||
| explicit Splitmix64(uint64_t seed = 0): | |||
| m_s{seed} | |||
| {} | |||
| uint64_t operator() (); | |||
| public: | |||
| explicit Splitmix64(uint64_t seed = 0) : m_s{seed} {} | |||
| uint64_t operator()(); | |||
| }; | |||
| /*! | |||
| @@ -36,51 +35,99 @@ class Xoroshiro128plus { | |||
| return (x << k) | (x >> (64 - k)); | |||
| } | |||
| public: | |||
| explicit Xoroshiro128plus(uint64_t seed = 0) { | |||
| public: | |||
| explicit Xoroshiro128plus(uint64_t seed = 0) { this->seed(seed); } | |||
| //! reset state if seed changed | |||
| Xoroshiro128plus& ensure_seed(uint64_t seed) { | |||
| if (seed != m_init_seed) { | |||
| this->seed(seed); | |||
| } | |||
| return *this; | |||
| } | |||
| //! reset state if seed changed | |||
| Xoroshiro128plus& ensure_seed(uint64_t seed) { | |||
| if (seed != m_init_seed) { | |||
| this->seed(seed); | |||
| } | |||
| return *this; | |||
| } | |||
| //! set seed | |||
| void seed(uint64_t seed); | |||
| uint64_t operator()(); | |||
| }; | |||
| //! set seed | |||
| void seed(uint64_t seed); | |||
| class UniformRNGImpl : public UniformRNG { | |||
| Xoroshiro128plus m_rng; | |||
| uint64_t operator() (); | |||
| public: | |||
| using UniformRNG::UniformRNG; | |||
| void exec(_megdnn_tensor_inout dst, _megdnn_workspace) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout&) override { return 0; } | |||
| }; | |||
| class UniformRNGImpl: public UniformRNG { | |||
| class GaussianRNGImpl : public GaussianRNG { | |||
| Xoroshiro128plus m_rng; | |||
| public: | |||
| using UniformRNG::UniformRNG; | |||
| void exec(_megdnn_tensor_inout dst, _megdnn_workspace) override; | |||
| public: | |||
| using GaussianRNG::GaussianRNG; | |||
| void exec(_megdnn_tensor_inout dst, _megdnn_workspace) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout&) override { | |||
| return 0; | |||
| } | |||
| size_t get_workspace_in_bytes(const TensorLayout&) override { return 0; } | |||
| }; | |||
| class GaussianRNGImpl: public GaussianRNG { | |||
| class GammaRNGImpl : public GammaRNG { | |||
| Xoroshiro128plus m_rng; | |||
| public: | |||
| using GaussianRNG::GaussianRNG; | |||
| void exec(_megdnn_tensor_inout dst, _megdnn_workspace) override; | |||
| public: | |||
| using GammaRNG::GammaRNG; | |||
| size_t get_workspace_in_bytes(const TensorLayout&) override { | |||
| return 0; | |||
| } | |||
| void exec(_megdnn_tensor_in shape,_megdnn_tensor_in scale, _megdnn_tensor_out dst, | |||
| _megdnn_workspace) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout&,const TensorLayout&, | |||
| const TensorLayout&) override { | |||
| return 0; | |||
| } | |||
| }; | |||
| class PoissonRNGImpl : public PoissonRNG { | |||
| Xoroshiro128plus m_rng; | |||
| public: | |||
| using PoissonRNG::PoissonRNG; | |||
| } // namespace naive | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| void exec(_megdnn_tensor_in lam, _megdnn_tensor_inout dst, | |||
| _megdnn_workspace) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout&, | |||
| const TensorLayout&) override { | |||
| return 0; | |||
| } | |||
| }; | |||
| class BetaRNGImpl : public BetaRNG { | |||
| Xoroshiro128plus m_rng; | |||
| public: | |||
| using BetaRNG::BetaRNG; | |||
| void exec(_megdnn_tensor_in alpha,_megdnn_tensor_in beta, _megdnn_tensor_out dst, | |||
| _megdnn_workspace) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout&, | |||
| const TensorLayout&, | |||
| const TensorLayout&) override { | |||
| return 0; | |||
| } | |||
| }; | |||
| class PermutationRNGImpl : public PermutationRNG { | |||
| Xoroshiro128plus m_rng; | |||
| public: | |||
| using PermutationRNG::PermutationRNG; | |||
| void exec(_megdnn_tensor_inout dst, _megdnn_workspace) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout&) override { return 0; } | |||
| }; | |||
| } // namespace naive | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -8,36 +8,165 @@ | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "megdnn/oprs.h" | |||
| #include "test/cuda/fixture.h" | |||
| #include "test/naive/rng.h" | |||
| #include "megdnn/oprs.h" | |||
| #include "test/common/tensor.h" | |||
| #include "test/cuda/fixture.h" | |||
| namespace megdnn { | |||
| namespace test { | |||
| namespace { | |||
| template <typename T> | |||
| void run_gamma(Handle* handle) { | |||
| using ctype = typename DTypeTrait<T>::ctype; | |||
| auto opr = handle->create_operator<GammaRNG>(); | |||
| TensorLayout ly{TensorShape{2000000 * 5}, T()}; | |||
| SyncedTensor<ctype> out(handle, ly); | |||
| SyncedTensor<ctype> shape(handle, ly); | |||
| SyncedTensor<ctype> scale(handle, ly); | |||
| auto shape_ptr = shape.ptr_mutable_host(); | |||
| auto scale_ptr = scale.ptr_mutable_host(); | |||
| for (int i = 0; i < 5; ++i) { | |||
| for (int j = 0; j < 2000000; ++j) { | |||
| shape_ptr[i * 2000000 + j] = 2 * 0.3 * i + 0.3; |||
| scale_ptr[i * 2000000 + j] = i * 0.2 + 0.1; | |||
| } | |||
| } | |||
| opr->exec(shape.tensornd_dev(), scale.tensornd_dev(), out.tensornd_dev(), | |||
| {}); | |||
| auto ptr = out.ptr_mutable_host(); | |||
| for (int i = 0; i < 5; ++i) { | |||
| float a = 2 * 0.3 * i + 0.3, b = i * 0.2 + 0.1; | |||
| float mean = a * b; |||
| float std = a * (b * b); | |||
| auto stat = get_mean_var(ptr + i * 2000000, 2000000, ctype(mean)); | |||
| ASSERT_LE(std::abs(stat.first - mean), 0.01); | |||
| ASSERT_LE(std::abs(stat.second - std), 0.01); | |||
| } | |||
| } | |||
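A note on the statistics checked here and in run_beta below: get_mean_var's second component is a variance estimate (the Gaussian tests compare it against 2.3 * 2.3, i.e. std squared), so the local variable named `std` actually holds the distribution's variance, Var[Gamma(a, b)] = a*b^2 and Var[Beta(a, b)] = ab / ((a + b)^2 (a + b + 1)), matching the formulas in the code.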
| template <typename T> | |||
| void run_poisson(Handle* handle) { | |||
| using ctype = typename DTypeTrait<T>::ctype; | |||
| auto opr = handle->create_operator<PoissonRNG>(); | |||
| TensorLayout ly{TensorShape{200000 * 5}, T()}; | |||
| SyncedTensor<ctype> out(handle, ly); | |||
| SyncedTensor<ctype> lam(handle, ly); | |||
| auto lam_ptr = lam.ptr_mutable_host(); | |||
| for (int i = 0; i < 5; ++i) { | |||
| for (int j = 0; j < 200000; ++j) { | |||
| lam_ptr[i * 200000 + j] = ctype(i + 1); | |||
| } | |||
| } | |||
| opr->exec(lam.tensornd_dev(), out.tensornd_dev(), {}); | |||
| auto ptr = out.ptr_mutable_host(); | |||
| for (int i = 0; i < 5; ++i) { | |||
| auto stat = get_mean_var(ptr + i * 200000, 200000, ctype(i + 1)); | |||
| ASSERT_LE(std::abs(stat.first - ctype(i + 1)), 0.01); | |||
| ASSERT_LE(std::abs(stat.second - ctype(i + 1)), 0.01); | |||
| } | |||
| } | |||
| template <typename T> | |||
| void run_beta(Handle* handle) { | |||
| using ctype = typename DTypeTrait<T>::ctype; | |||
| auto opr = handle->create_operator<BetaRNG>(); | |||
| TensorLayout ly{TensorShape{200000 * 5}, T()}; | |||
| SyncedTensor<ctype> out(handle, ly); | |||
| SyncedTensor<ctype> alpha(handle, ly); | |||
| SyncedTensor<ctype> beta(handle, ly); | |||
| auto alpha_ptr = alpha.ptr_mutable_host(); | |||
| auto beta_ptr = beta.ptr_mutable_host(); | |||
| for (int i = 0; i < 5; ++i) { | |||
| for (int j = 0; j < 200000; ++j) { | |||
| alpha_ptr[i * 200000 + j] = 0.3 * i + 0.1; | |||
| beta_ptr[i * 200000 + j] = 2 * i * 0.3 + 0.1; | |||
| } | |||
| } | |||
| opr->exec(alpha.tensornd_dev(), beta.tensornd_dev(), out.tensornd_dev(), | |||
| {}); | |||
| auto ptr = out.ptr_mutable_host(); | |||
| for (int i = 0; i < 5; ++i) { | |||
| float a = 0.3 * i + 0.1, b = 2 * i * 0.3 + 0.1; | |||
| float mean = a / (a + b); | |||
| float std = a * b / ((a + b) * (a + b) * (a + b + 1)); | |||
| auto stat = get_mean_var(ptr + i * 200000, 200000, ctype(mean)); | |||
| ASSERT_LE(std::abs(stat.first - mean), 0.01); | |||
| ASSERT_LE(std::abs(stat.second - std), 0.01); | |||
| } | |||
| } | |||
| template <typename T> | |||
| void run_permutation(Handle* handle) { | |||
| using ctype = typename DTypeTrait<T>::ctype; | |||
| size_t sample_num = | |||
| std::min(200000, static_cast<int>(DTypeTrait<T>::max()) - 10); | |||
| auto opr = handle->create_operator<PermutationRNG>(); | |||
| opr->param().dtype = DTypeTrait<T>::enumv; | |||
| TensorLayout ly{TensorShape{sample_num}, T()}; | |||
| Tensor<dt_byte> workspace( | |||
| handle, | |||
| {TensorShape{opr->get_workspace_in_bytes(ly)}, dtype::Byte()}); | |||
| SyncedTensor<ctype> t(handle, ly); | |||
| opr->exec(t.tensornd_dev(), | |||
| {workspace.ptr(), workspace.layout().total_nr_elems()}); | |||
| auto ptr = t.ptr_mutable_host(); | |||
| auto size = t.layout().total_nr_elems(); | |||
| std::vector<ctype> res(size); | |||
| int not_same = 0; | |||
| for (size_t i = 0; i < size; ++i) { | |||
| if ((ptr[i] - ctype(i)) >= ctype(1)) not_same++; | |||
| res[i] = ptr[i]; | |||
| } | |||
| ASSERT_GT(not_same, 5000); | |||
| std::sort(res.begin(), res.end()); | |||
| for (size_t i = 0; i < size; ++i) { | |||
| ASSERT_LE(std::abs(res[i] - ctype(i)), 1e-8); | |||
| } | |||
| } | |||
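run_permutation verifies two properties: sorting the output must reproduce 0..size-1 exactly (so the result is a genuine permutation), and at least 5000 of the entries must have left their identity position (so it was actually shuffled). sample_num is clamped to DTypeTrait<T>::max() - 10 because PermutationRNGImpl asserts size < max - 1 for the output dtype.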
| } // anonymous namespace | |||
| TEST_F(CUDA, UNIFORM_RNG_F32) { | |||
| auto opr = handle_cuda()->create_operator<UniformRNG>(); | |||
| opr->param().dtype = DTypeTrait<dtype::Float32>::enumv; | |||
| SyncedTensor<> t(handle_cuda(), {TensorShape{200000}, dtype::Float32()}); | |||
| opr->exec(t.tensornd_dev(), {}); | |||
| assert_uniform_correct(t.ptr_mutable_host(), | |||
| t.layout().total_nr_elems()); | |||
| assert_uniform_correct(t.ptr_mutable_host(), t.layout().total_nr_elems()); | |||
| } | |||
| TEST_F(CUDA, GAUSSIAN_RNG_F32) { | |||
| auto opr = handle_cuda()->create_operator<GaussianRNG>(); | |||
| opr->param().mean = 0.8; | |||
| opr->param().std = 2.3; | |||
| for (size_t size: {1, 200000, 200001}) { | |||
| opr->param().dtype = DTypeTrait<dtype::Float32>::enumv; | |||
| for (size_t size : {1, 200000, 200001}) { | |||
| TensorLayout ly{{size}, dtype::Float32()}; | |||
| Tensor<dt_byte> workspace(handle_cuda(), | |||
| {TensorShape{opr->get_workspace_in_bytes(ly)}, | |||
| dtype::Byte()}); | |||
| Tensor<dt_byte> workspace( | |||
| handle_cuda(), | |||
| {TensorShape{opr->get_workspace_in_bytes(ly)}, dtype::Byte()}); | |||
| SyncedTensor<> t(handle_cuda(), ly); | |||
| opr->exec(t.tensornd_dev(), | |||
| {workspace.ptr(), workspace.layout().total_nr_elems()}); | |||
| {workspace.ptr(), workspace.layout().total_nr_elems()}); | |||
| auto ptr = t.ptr_mutable_host(); | |||
| ASSERT_LE(std::abs(ptr[0] - 0.8), 2.3); | |||
| @@ -50,10 +179,43 @@ TEST_F(CUDA, GAUSSIAN_RNG_F32) { | |||
| } | |||
| } | |||
| } // namespace test | |||
| } // namespace megdnn | |||
| TEST_F(CUDA, GAMMA_RNG_F32) { | |||
| run_gamma<dtype::Float32>(handle_cuda()); | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| TEST_F(CUDA, GAMMA_RNG_F16) { | |||
| run_gamma<dtype::Float16>(handle_cuda()); | |||
| } | |||
| TEST_F(CUDA, POISSON_RNG_F32) { | |||
| run_poisson<dtype::Float32>(handle_cuda()); | |||
| } | |||
| TEST_F(CUDA, POISSON_RNG_F16) { | |||
| run_poisson<dtype::Float16>(handle_cuda()); | |||
| } | |||
| TEST_F(CUDA, BETA_RNG_F32) { | |||
| run_beta<dtype::Float32>(handle_cuda()); | |||
| } | |||
| TEST_F(CUDA, BETA_RNG_F16) { | |||
| run_beta<dtype::Float16>(handle_cuda()); | |||
| } | |||
| TEST_F(CUDA, PERMUTATION_RNG_F32) { | |||
| run_permutation<dtype::Float32>(handle_cuda()); | |||
| } | |||
| TEST_F(CUDA, PERMUTATION_RNG_INT32) { | |||
| run_permutation<dtype::Int32>(handle_cuda()); | |||
| } | |||
| TEST_F(CUDA, PERMUTATION_RNG_INT16) { | |||
| run_permutation<dtype::Int16>(handle_cuda()); | |||
| } | |||
| } // namespace test | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -32,6 +32,7 @@ namespace { | |||
| template<typename dtype> | |||
| void run_uniform(Handle *handle) { | |||
| auto opr = handle->create_operator<UniformRNG>(); | |||
| opr->param().dtype = DTypeTrait<dtype>::enumv; | |||
| Tensor<typename DTypeTrait<dtype>::ctype> t( | |||
| handle, {TensorShape{200000}, dtype()}); | |||
| opr->exec(t.tensornd(), {}); | |||
| @@ -44,6 +45,7 @@ namespace { | |||
| auto opr = handle->create_operator<GaussianRNG>(); | |||
| opr->param().mean = 0.8; | |||
| opr->param().std = 2.3; | |||
| opr->param().dtype = DTypeTrait<dtype>::enumv; | |||
| Tensor<ctype> t(handle, {TensorShape{200001}, dtype()}); | |||
| opr->exec(t.tensornd(), {}); | |||
| @@ -53,8 +55,131 @@ namespace { | |||
| ASSERT_LE(std::abs(ptr[i] - 0.8), ctype(15)); | |||
| } | |||
| auto stat = get_mean_var(ptr, size, ctype(0.8)); | |||
| ASSERT_LE(std::abs(stat.first - 0.8), 5e-3); | |||
| ASSERT_LE(std::abs(stat.second - 2.3 * 2.3), 5e-2); | |||
| } | |||
| template<typename dtype> | |||
| void run_gamma(Handle* handle){ | |||
| using ctype = typename DTypeTrait<dtype>::ctype; | |||
| auto opr = handle->create_operator<GammaRNG>(); | |||
| TensorLayout ly{TensorShape{2000000*5}, dtype()}; | |||
| Tensor<ctype> out(handle, ly); | |||
| Tensor<ctype> shape(handle, ly); | |||
| Tensor<ctype> scale(handle, ly); | |||
| auto shape_ptr = shape.ptr(); | |||
| auto scale_ptr = scale.ptr(); | |||
| for (int i = 0; i < 5; ++i) { | |||
| for (int j = 0; j < 2000000; ++j) { | |||
| shape_ptr[i * 2000000 + j] = 2 * 0.3 * i + 0.5; | |||
| scale_ptr[i * 2000000 + j] = i * 0.2 + 0.1; | |||
| } | |||
| } | |||
| opr->exec(shape.tensornd(), scale.tensornd(), out.tensornd(), {}); | |||
| auto ptr = out.ptr(); | |||
| for(int i = 0; i < 5 ; ++i){ | |||
| float a = 2 * 0.3 * i + 0.5, b = i * 0.2 + 0.1; | |||
| float mean = a * b; | |||
| float std = a * (b * b) ; | |||
| auto stat = get_mean_var(ptr + i * 2000000, 2000000, ctype(mean)); | |||
| ASSERT_LE(std::abs(stat.first - mean), 0.01); | |||
| ASSERT_LE(std::abs(stat.second - std), 0.01); | |||
| } | |||
| } | |||
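`run_gamma` validates against the moments of Gamma(shape=k, scale=theta): mean k*theta and variance k*theta^2 (as in the beta test, the local name `std` holds the variance). A quick NumPy analogue, for illustration only:

    import numpy as np

    k, theta = 1.1, 0.3  # the i = 1 case of the loop above
    samples = np.random.default_rng(0).gamma(k, theta, size=2_000_000)
    assert abs(samples.mean() - k * theta) < 0.01       # mean = k * theta
    assert abs(samples.var() - k * theta ** 2) < 0.01   # var  = k * theta^2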
| template<typename dtype> | |||
| void run_poisson(Handle* handle){ | |||
| using ctype = typename DTypeTrait<dtype>::ctype; | |||
| auto opr = handle->create_operator<PoissonRNG>(); | |||
| TensorLayout ly{TensorShape{200000*5}, dtype()}; | |||
| Tensor<ctype> out(handle, ly); | |||
| Tensor<ctype> lam(handle, ly); | |||
| auto lam_ptr = lam.ptr(); | |||
| for(int i = 0; i < 5; ++i){ | |||
| for(int j = 0; j <200000; ++j){ | |||
| lam_ptr[i*200000 + j] = ctype(i + 1); | |||
| } | |||
| } | |||
| opr->exec(lam.tensornd(), out.tensornd(), {}); | |||
| auto ptr = out.ptr(); | |||
| for(int i = 0; i < 5 ; ++i){ | |||
| auto stat = get_mean_var(ptr + i*200000, 200000, ctype(i + 1)); | |||
| ASSERT_LE(std::abs(stat.first - ctype(i + 1)), 0.01); | |||
| ASSERT_LE(std::abs(stat.second - ctype(i + 1)), 0.01); | |||
| } | |||
| } | |||
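`run_poisson` relies on the defining property that a Poisson(lam) variable has mean and variance both equal to lam. The same check sketched with NumPy (illustration only; tolerances here are looser than the test's):

    import numpy as np

    lam = 3.0  # the i = 2 case of the loop above
    samples = np.random.default_rng(0).poisson(lam, size=200_000)
    assert abs(samples.mean() - lam) < 0.05
    assert abs(samples.var() - lam) < 0.05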
| template<typename dtype> | |||
| void run_beta(Handle* handle){ | |||
| using ctype = typename DTypeTrait<dtype>::ctype; | |||
| auto opr = handle->create_operator<BetaRNG>(); | |||
| TensorLayout ly{TensorShape{200000*5}, dtype()}; | |||
| Tensor<ctype> out(handle, ly); | |||
| Tensor<ctype> alpha(handle, ly); | |||
| Tensor<ctype> beta(handle, ly); | |||
| auto alpha_ptr = alpha.ptr(); | |||
| auto beta_ptr = beta.ptr(); | |||
| for (int i = 0; i < 5; ++i) { | |||
| for (int j = 0; j < 200000; ++j) { | |||
| alpha_ptr[i * 200000 + j] = 0.3 * i + 0.1; | |||
| beta_ptr[i * 200000 + j] = 2 * i * 0.3 + 0.1; | |||
| } | |||
| } | |||
| opr->exec(alpha.tensornd(),beta.tensornd(), out.tensornd(), {}); | |||
| auto ptr = out.ptr(); | |||
| for(int i = 0; i < 5 ; ++i){ | |||
| float a = 0.3 * i + 0.1, b = 2 * i * 0.3 + 0.1; | |||
| float mean = a / (a + b); | |||
| float std = a * b / ((a + b) * (a + b) * (a + b + 1)); | |||
| auto stat = get_mean_var(ptr + i * 200000, 200000, ctype(mean)); | |||
| ASSERT_LE(std::abs(stat.first - mean), 0.01); | |||
| ASSERT_LE(std::abs(stat.second - std), 0.01); | |||
| } | |||
| } | |||
| template<typename dtype> | |||
| void run_permutation(Handle* handle){ | |||
| using ctype = typename DTypeTrait<dtype>::ctype; | |||
| size_t sample_num = std::min(200000, | |||
| static_cast<int>(DTypeTrait<dtype>::max()) - 10); | |||
| auto opr = handle->create_operator<PermutationRNG>(); | |||
| opr->param().dtype = DTypeTrait<dtype>::enumv; | |||
| TensorLayout ly{TensorShape{sample_num}, dtype()}; | |||
| Tensor<ctype> t(handle, ly); | |||
| opr->exec(t.tensornd(), {}); | |||
| auto ptr = t.ptr(); | |||
| auto size = t.layout().total_nr_elems(); | |||
| std::vector<ctype> res(size); | |||
| int not_same = 0; | |||
| for(size_t i = 0; i < size; ++i){ | |||
| if ((ptr[i] - ctype(i)) >= 1 ) not_same++; | |||
| res[i] = ptr[i]; | |||
| } | |||
| ASSERT_GT(not_same, 5000); | |||
| std::sort(res.begin(),res.end()); | |||
| for(size_t i = 0; i < size; ++i){ | |||
| ASSERT_LE(std::abs(res[i] - ctype(i)), 1e-8); | |||
| } | |||
| } | |||
| } // anonymous namespace | |||
| @@ -74,6 +199,42 @@ TEST_F(NAIVE, GAUSSIAN_RNG_F16) { | |||
| DNN_INC_FLOAT16(run_gaussian<dtype::Float16>(handle())); | |||
| } | |||
| TEST_F(NAIVE, GAMMA_RNG_F32) { | |||
| run_gamma<dtype::Float32>(handle()); | |||
| } | |||
| TEST_F(NAIVE, GAMMA_RNG_F16) { | |||
| DNN_INC_FLOAT16(run_gamma<dtype::Float16>(handle())); | |||
| } | |||
| TEST_F(NAIVE, POISSON_RNG_F32) { | |||
| run_poisson<dtype::Float32>(handle()); | |||
| } | |||
| TEST_F(NAIVE, POISSON_RNG_F16) { | |||
| DNN_INC_FLOAT16(run_poisson<dtype::Float16>(handle())); | |||
| } | |||
| TEST_F(NAIVE, BETA_RNG_F32) { | |||
| run_beta<dtype::Float32>(handle()); | |||
| } | |||
| TEST_F(NAIVE, BETA_RNG_F16) { | |||
| DNN_INC_FLOAT16(run_beta<dtype::Float16>(handle())); | |||
| } | |||
| TEST_F(NAIVE, PERMUTATION_RNG_F32) { | |||
| run_permutation<dtype::Float32>(handle()); | |||
| } | |||
| TEST_F(NAIVE, PERMUTATION_RNG_INT32) { | |||
| run_permutation<dtype::Int32>(handle()); | |||
| } | |||
| TEST_F(NAIVE, PERMUTATION_RNG_INT16) { | |||
| run_permutation<dtype::Int16>(handle()); | |||
| } | |||
| } // namespace test | |||
| } // namespace megdnn | |||
| @@ -6,8 +6,17 @@ | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .distribution import normal, uniform | |||
| from .rng import RNG, seed | |||
| from .rng import RNG, beta, gamma, normal, permutation, poisson, seed, uniform | |||
| __all__ = [ | |||
| "RNG", | |||
| "beta", | |||
| "gamma", | |||
| "normal", | |||
| "permutation", | |||
| "poisson", | |||
| "seed", | |||
| "uniform", | |||
| ] | |||
| # pylint: disable=undefined-variable | |||
| del distribution, rng # type: ignore[name-defined] | |||
| del rng # type: ignore[name-defined] | |||
| @@ -1,95 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from typing import Iterable, Optional | |||
| from .. import Tensor | |||
| from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed | |||
| from .rng import _normal, _uniform | |||
| __all__ = ["normal", "uniform"] | |||
| def normal( | |||
| mean: float = 0, std: float = 1, size: Optional[Iterable[int]] = None | |||
| ) -> Tensor: | |||
| r""" | |||
| Random variable with Gaussian distribution :math:`N(\mu, \sigma)`. | |||
| :param size: output tensor size. | |||
| :param mean: the mean or expectation of the distribution. | |||
| :param std: the standard deviation of the distribution (variance = :math:`\sigma ^ 2`). | |||
| :return: the output tensor. | |||
| Examples: | |||
| .. testcode:: | |||
| import megengine as mge | |||
| import megengine.random as rand | |||
| x = rand.normal(mean=0, std=1, size=(2, 2)) | |||
| print(x.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| :options: +SKIP | |||
| [[-0.20235455 -0.6959438 ] | |||
| [-1.4939808 -1.5824696 ]] | |||
| """ | |||
| return _normal( | |||
| mean=mean, | |||
| std=std, | |||
| size=size, | |||
| seed=_get_global_rng_seed(), | |||
| device=None, | |||
| handle=0, | |||
| ) | |||
| def uniform( | |||
| low: float = 0, high: float = 1, size: Optional[Iterable[int]] = None | |||
| ) -> Tensor: | |||
| r""" | |||
| Random variable with uniform distribution $U(0, 1)$. | |||
| :param size: output tensor size. | |||
| :param low: lower range. | |||
| :param high: upper range. | |||
| :return: the output tensor. | |||
| Examples: | |||
| .. testcode:: | |||
| import megengine as mge | |||
| import megengine.random as rand | |||
| x = rand.uniform(size=(2, 2)) | |||
| print(x.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| :options: +SKIP | |||
| [[0.76901674 0.70496535] | |||
| [0.09365904 0.62957656]] | |||
| """ | |||
| return _uniform( | |||
| low=low, | |||
| high=high, | |||
| size=size, | |||
| seed=_get_global_rng_seed(), | |||
| device=None, | |||
| handle=0, | |||
| ) | |||
| @@ -6,8 +6,9 @@ | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import collections | |||
| import time | |||
| from typing import Iterable, Optional | |||
| from typing import Iterable, Optional, Union | |||
| from numpy.random import MT19937 | |||
| @@ -15,15 +16,97 @@ from .. import Tensor | |||
| from ..core._imperative_rt.core2 import apply | |||
| from ..core._imperative_rt.ops import delete_rng_handle as _delete_rng_handle | |||
| from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed | |||
| from ..core._imperative_rt.ops import ( | |||
| get_rng_handle_compnode as _get_rng_handle_compnode, | |||
| ) | |||
| from ..core._imperative_rt.ops import new_rng_handle as _new_rng_handle | |||
| from ..core._imperative_rt.ops import set_global_rng_seed as _set_global_rng_seed | |||
| from ..core.ops.builtin import GaussianRNG, UniformRNG | |||
| from ..core.ops.builtin import ( | |||
| BetaRNG, | |||
| GammaRNG, | |||
| GaussianRNG, | |||
| PermutationRNG, | |||
| PoissonRNG, | |||
| UniformRNG, | |||
| ) | |||
| from ..core.tensor import utils | |||
| from ..device import get_default_device | |||
| __all__ = [ | |||
| "seed", | |||
| "RNG", | |||
| "uniform", | |||
| "normal", | |||
| "gamma", | |||
| "beta", | |||
| "poisson", | |||
| "permutation", | |||
| ] | |||
| _rng = None | |||
| def _infer_broadcasted_shape(inps: Iterable[Tensor]) -> tuple: | |||
| broadcasted_ndim = inps[0].ndim | |||
| broadcasted_shape = list(inps[0]._tuple_shape) | |||
| for i in range(1, len(inps)): | |||
| cur_ndim = inps[i].ndim | |||
| cur_shape = list(inps[i]._tuple_shape) | |||
| n_dim = max(cur_ndim, broadcasted_ndim) | |||
| for j in range(n_dim - 1, -1, -1): | |||
| cur_dim = cur_ndim + j - n_dim | |||
| broad_dim = broadcasted_ndim + j - n_dim | |||
| cur_size = cur_shape[cur_dim] if cur_dim >= 0 else 1 | |||
| broad_size = broadcasted_shape[broad_dim] if broad_dim >= 0 else 1 | |||
| assert cur_size == broad_size or cur_size == 1 or broad_size == 1, ( | |||
| "The size of inps[{}] ({}) must match the size ({}) at " | |||
| "dim {}".format(i, cur_size, broad_size, j) | |||
| ) | |||
| broad_size = max(cur_size, broad_size) | |||
| if broad_dim < 0: | |||
| broadcasted_shape = [broad_size] + broadcasted_shape | |||
| broadcasted_ndim += 1 | |||
| else: | |||
| broadcasted_shape[broad_dim] = broad_size | |||
| return tuple(broadcasted_shape) | |||
| def _broadcast_tensors_with_size( | |||
| inps: Iterable[Tensor], size: Iterable[int] | |||
| ) -> Iterable[Tensor]: | |||
| assert inps, "The inps cloud not be empty" | |||
| target_shape = _infer_broadcasted_shape(inps) | |||
| if isinstance(size, collections.abc.Iterable): | |||
| target_shape = tuple(size) + target_shape | |||
| target_ndim = len(target_shape) | |||
| for i in range(len(inps)): | |||
| if inps[i]._tuple_shape != target_shape: | |||
| inps[i] = ( | |||
| inps[i] | |||
| .reshape((1,) * (target_ndim - inps[i].ndim) + inps[i]._tuple_shape) | |||
| ._broadcast(target_shape) | |||
| ) | |||
| return inps | |||
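`_broadcast_tensors_with_size` first broadcasts the parameter tensors against each other with NumPy-style trailing-dimension rules (via `_infer_broadcasted_shape`), then prepends the requested `size`, so the output shape is `tuple(size) + broadcast(params).shape`. A sketch of the shape arithmetic, using NumPy purely to illustrate the semantics:

    import numpy as np

    shape_param = np.ones((2, 1))  # e.g. gamma's `shape` argument
    scale_param = np.ones((3,))    # e.g. gamma's `scale` argument

    # step 1: broadcast the parameters against each other
    broadcasted = np.broadcast_shapes(shape_param.shape, scale_param.shape)
    assert broadcasted == (2, 3)

    # step 2: an explicit size, e.g. (20, 30), is prepended
    assert (20, 30) + broadcasted == (20, 30, 2, 3)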
| def _uniform( | |||
| low: float, | |||
| high: float, | |||
| size: Optional[Iterable[int]], | |||
| seed: int, | |||
| device: str, | |||
| handle: int, | |||
| ) -> Tensor: | |||
| assert low < high, "Uniform is not defined when low >= high" | |||
| if size is None: | |||
| size = (1,) | |||
| op = UniformRNG(seed=seed, handle=handle, dtype="float32") | |||
| _ref = Tensor([], dtype="int32", device=device) | |||
| shape = utils.astensor1d(size, _ref, dtype="int32", device=device) | |||
| (output,) = apply(op, shape) | |||
| return low + (high - low) * output | |||
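`_uniform` draws from the unit interval with the `UniformRNG` op and remaps the samples onto the requested range with the affine transform on the last line. The same remapping sketched with NumPy standing in for the op (illustration only):

    import numpy as np

    low, high = -2.0, 5.0
    u = np.random.default_rng(0).random(100_000)  # unit-interval samples
    x = low + (high - low) * u                    # -> uniform on [low, high)
    assert low <= x.min() <= x.max() <= high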
| def _normal( | |||
| mean: float, | |||
| std: float, | |||
| @@ -34,63 +117,477 @@ def _normal( | |||
| ) -> Tensor: | |||
| if size is None: | |||
| size = (1,) | |||
| op = GaussianRNG(seed=seed, mean=mean, std=std, handle=handle) | |||
| op = GaussianRNG(seed=seed, mean=mean, std=std, handle=handle, dtype="float32") | |||
| _ref = Tensor([], dtype="int32", device=device) | |||
| shape = utils.astensor1d(size, _ref, dtype="int32", device=device) | |||
| (output,) = apply(op, shape) | |||
| return output | |||
| def _uniform( | |||
| low: float, | |||
| high: float, | |||
| def _gamma( | |||
| shape: Union[Tensor, float], | |||
| scale: Union[Tensor, float], | |||
| size: Optional[Iterable[int]], | |||
| seed: int, | |||
| device: str, | |||
| handle: int, | |||
| ) -> Tensor: | |||
| assert low < high, "Uniform is not defined when low >= high" | |||
| if size is None: | |||
| size = (1,) | |||
| op = UniformRNG(seed=seed, handle=handle) | |||
| handle_cn = None if handle == 0 else _get_rng_handle_compnode(handle) | |||
| if not isinstance(shape, Tensor): | |||
| assert shape > 0, "Gamma is not defined when shape <= 0" | |||
| shape = Tensor(shape, dtype="float32", device=handle_cn) | |||
| if not isinstance(scale, Tensor): | |||
| assert scale > 0, "Gamma is not defined when scale <= 0" | |||
| scale = Tensor(scale, dtype="float32", device=handle_cn) | |||
| assert ( | |||
| handle_cn is None or handle_cn == shape.device | |||
| ), "The shape ({}) must be the same device with handle ({})".format( | |||
| shape.device, handle_cn | |||
| ) | |||
| assert ( | |||
| handle_cn is None or handle_cn == scale.device | |||
| ), "The scale ({}) must be the same device with handle ({})".format( | |||
| scale.device, handle_cn | |||
| ) | |||
| if isinstance(size, int) and size != 0: | |||
| size = (size,) | |||
| shape, scale = _broadcast_tensors_with_size([shape, scale], size) | |||
| op = GammaRNG(seed=seed, handle=handle) | |||
| (output,) = apply(op, shape, scale) | |||
| return output | |||
| def _beta( | |||
| alpha: Union[Tensor, float], | |||
| beta: Union[Tensor, float], | |||
| size: Optional[Iterable[int]], | |||
| seed: int, | |||
| handle: int, | |||
| ) -> Tensor: | |||
| handle_cn = None if handle == 0 else _get_rng_handle_compnode(handle) | |||
| if not isinstance(alpha, Tensor): | |||
| assert alpha > 0, "Beta is not defined when alpha <= 0" | |||
| alpha = Tensor(alpha, dtype="float32", device=handle_cn) | |||
| if not isinstance(beta, Tensor): | |||
| assert beta > 0, "Beta is not defined when beta <= 0" | |||
| beta = Tensor(beta, dtype="float32", device=handle_cn) | |||
| assert ( | |||
| handle_cn is None or handle_cn == alpha.device | |||
| ), "The alpha ({}) must be the same device with handle ({})".format( | |||
| alpha.device, handle_cn | |||
| ) | |||
| assert ( | |||
| handle_cn is None or handle_cn == beta.device | |||
| ), "The beta ({}) must be the same device with handle ({})".format( | |||
| beta.device, handle_cn | |||
| ) | |||
| if isinstance(size, int) and size != 0: | |||
| size = (size,) | |||
| alpha, beta = _broadcast_tensors_with_size([alpha, beta], size) | |||
| op = BetaRNG(seed=seed, handle=handle) | |||
| (output,) = apply(op, alpha, beta) | |||
| return output | |||
| def _poisson( | |||
| lam: Union[Tensor, float], size: Optional[Iterable[int]], seed: int, handle: int | |||
| ) -> Tensor: | |||
| handle_cn = None if handle == 0 else _get_rng_handle_compnode(handle) | |||
| if not isinstance(lam, Tensor): | |||
| assert lam > 0, "Poisson is not defined when lam <= 0" | |||
| lam = Tensor(lam, dtype="float32", device=handle_cn) | |||
| if isinstance(size, int) and size != 0: | |||
| size = (size,) | |||
| assert ( | |||
| handle_cn is None or handle_cn == lam.device | |||
| ), "The lam ({}) must be the same device with handle ({})".format( | |||
| lam.device, handle_cn | |||
| ) | |||
| (lam,) = _broadcast_tensors_with_size([lam], size) | |||
| op = PoissonRNG(seed=seed, handle=handle) | |||
| (output,) = apply(op, lam) | |||
| return output | |||
| def _permutation(n: int, seed: int, device: str, handle: int, dtype: str) -> Tensor: | |||
| assert isinstance(n, int) and n > 0, "Permutation is not defined when n <= 0" | |||
| size = (n,) | |||
| op = PermutationRNG(seed=seed, handle=handle, dtype=dtype) | |||
| _ref = Tensor([], dtype="int32", device=device) | |||
| shape = utils.astensor1d(size, _ref, dtype="int32", device=device) | |||
| (output,) = apply(op, shape) | |||
| return low + (high - low) * output | |||
| return output | |||
| class RNG: | |||
| def __init__(self, seed=0, device=None): | |||
| self.seed = seed | |||
| self.device = device if device else get_default_device() | |||
| self.handle = _new_rng_handle(self.device, self.seed) | |||
| r""" | |||
| :class:`RNG` exposes a number of methods for generating random numbers. | |||
| :param seed: random seed used to initialize the pseudo-random number generator. | |||
| Default: None | |||
| :param device: the device of generated tensor. Default: None | |||
| Examples: | |||
| .. testcode:: | |||
| import megengine.random as rand | |||
| rng = rand.RNG(seed=100) | |||
| x = rng.uniform(size=(2, 2)) | |||
| print(x.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| :options: +SKIP | |||
| [[0.84811664 0.6147553 ] | |||
| [0.59429836 0.64727545]] | |||
| """ | |||
| def __init__(self, seed: int = None, device: str = None): | |||
| self._device = device if device else get_default_device() | |||
| if seed is not None: | |||
| self._seed = seed | |||
| self._handle = _new_rng_handle(self._device, self._seed) | |||
| else: | |||
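| # keep the getter itself rather than its value: methods call | |||
| # self._seed() lazily, so each draw uses the current global seed | |||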
| self._seed = _get_global_rng_seed | |||
| self._handle = 0 | |||
| self._device = None | |||
| def uniform( | |||
| self, low: float = 0, high: float = 1, size: Optional[Iterable[int]] = None | |||
| ): | |||
| r""" | |||
| Random variable with uniform distribution :math:`U(\text{low}, \text{high})`. | |||
| :param low: lower range. Default: 0 | |||
| :param high: upper range. Default: 1 | |||
| :param size: the size of output tensor. Default: None | |||
| :return: the output tensor. | |||
| Examples: | |||
| .. testcode:: | |||
| import megengine as mge | |||
| import megengine.random as rand | |||
| x = rand.uniform(size=(2, 2)) | |||
| print(x.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| :options: +SKIP | |||
| [[0.91600335 0.6680226 ] | |||
| [0.2046729 0.2769141 ]] | |||
| """ | |||
| _seed = self._seed() if callable(self._seed) else self._seed | |||
| return _uniform( | |||
| low=low, | |||
| high=high, | |||
| size=size, | |||
| seed=self.seed, | |||
| device=self.device, | |||
| handle=self.handle, | |||
| seed=_seed, | |||
| device=self._device, | |||
| handle=self._handle, | |||
| ) | |||
| def normal( | |||
| self, mean: float = 0, std: float = 1, size: Optional[Iterable[int]] = None | |||
| ): | |||
| r""" | |||
| Random variable with Gaussian distribution :math:`N(\mu, \sigma)`. | |||
| :param mean: the mean or expectation of the distribution. Default: 0 | |||
| :param std: the standard deviation of the distribution (variance = :math:`\sigma ^ 2`). | |||
| Default: 1 | |||
| :param size: the size of output tensor. Default: None | |||
| :return: the output tensor. | |||
| Examples: | |||
| .. testcode:: | |||
| import megengine as mge | |||
| import megengine.random as rand | |||
| x = rand.normal(mean=0, std=1, size=(2, 2)) | |||
| print(x.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| :options: +SKIP | |||
| [[-1.4010863 -0.9874344 ] | |||
| [ 0.56373274 0.79656655]] | |||
| """ | |||
| _seed = self._seed() if callable(self._seed) else self._seed | |||
| return _normal( | |||
| mean=mean, | |||
| std=std, | |||
| size=size, | |||
| seed=self.seed, | |||
| device=self.device, | |||
| handle=self.handle, | |||
| seed=_seed, | |||
| device=self._device, | |||
| handle=self._handle, | |||
| ) | |||
| def gamma( | |||
| self, | |||
| shape: Union[Tensor, float], | |||
| scale: Union[Tensor, float] = 1, | |||
| size: Optional[Iterable[int]] = None, | |||
| ): | |||
| r""" | |||
| Random variable with Gamma distribution :math:`\Gamma(k, \theta)`. | |||
| The corresponding probability density function is | |||
| .. math:: | |||
| p(x)=x^{k-1} \frac{e^{-x / \theta}}{\theta^{k} \Gamma(k)} | |||
| \quad \text { for } x>0 \quad k, \theta>0, | |||
| where :math:`\Gamma(k)` is the gamma function, | |||
| .. math:: | |||
| \Gamma(k)=(k-1) ! \quad \text { for integer } \quad k>0. | |||
| :param shape: the shape parameter (sometimes designated "k") of the distribution. | |||
| Must be positive. | |||
| :param scale: the scale parameter (sometimes designated "theta") of the distribution. | |||
| Must be positive. Default: 1 | |||
| :param size: the size of output tensor. If shape and scale are scalars and given size is, e.g., | |||
| `(m, n)`, then the output shape is `(m, n)`. If shape or scale is a Tensor and given size | |||
| is, e.g., `(m, n)`, then the output shape is `(m, n) + broadcast(shape, scale).shape`. | |||
| The broadcast rules are consistent with `numpy.broadcast`. Default: None | |||
| :return: the output tensor. | |||
| Examples: | |||
| .. testcode:: | |||
| import megengine as mge | |||
| import megengine.random as rand | |||
| x = rand.gamma(shape=2, scale=1, size=(2, 2)) | |||
| print(x.numpy()) | |||
| shape = mge.Tensor([[ 1], | |||
| [10]], dtype="float32") | |||
| scale = mge.Tensor([1,5], dtype="float32") | |||
| x = rand.gamma(shape=shape, scale=scale) | |||
| print(x.numpy()) | |||
| x = rand.gamma(shape=shape, scale=scale, size=2) | |||
| print(x.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| :options: +SKIP | |||
| [[1.5064533 4.0689363 ] | |||
| [0.71639484 1.4551026 ]] | |||
| [[ 0.4352188 11.399335 ] | |||
| [ 9.1888 52.009277 ]] | |||
| [[[ 1.1726005 3.9654975 ] | |||
| [13.656933 36.559006 ]] | |||
| [[ 0.25848487 2.5540342 ] | |||
| [11.960409 21.031536 ]]] | |||
| """ | |||
| _seed = self._seed() if callable(self._seed) else self._seed | |||
| return _gamma( | |||
| shape=shape, scale=scale, size=size, seed=_seed, handle=self._handle | |||
| ) | |||
| def beta( | |||
| self, | |||
| alpha: Union[Tensor, float], | |||
| beta: Union[Tensor, float], | |||
| size: Optional[Iterable[int]] = None, | |||
| ): | |||
| r""" | |||
| Random variable with Beta distribution :math:`\operatorname{Beta}(\alpha, \beta)`. | |||
| The corresponding probability density function is | |||
| .. math:: | |||
| p(x)=\frac{1}{\mathrm{~B}(\alpha, \beta)} x^{\alpha-1}(1-x)^{\beta-1} | |||
| \quad \text { for } \alpha, \beta>0, | |||
| where :math:`\mathrm{~B}(\alpha, \beta)` is the beta function, | |||
| .. math:: | |||
| \mathrm{~B}(\alpha, \beta)=\int_{0}^{1} t^{\alpha-1}(1-t)^{\beta-1} d t. | |||
| :param alpha: the alpha parameter of the distribution. Must be positive. | |||
| :param beta: the beta parameter of the distribution. Must be positive. | |||
| :param size: the size of output tensor. If alpha and beta are scalars and given size is, e.g., | |||
| `(m, n)`, then the output shape is `(m, n)`. If alpha or beta is a Tensor and given size | |||
| is, e.g., `(m, n)`, then the output shape is `(m, n) + broadcast(alpha, beta).shape`. | |||
| The broadcast rules are consistent with `numpy.broadcast`. Default: None | |||
| :return: the output tensor. | |||
| Examples: | |||
| .. testcode:: | |||
| import megengine as mge | |||
| import megengine.random as rand | |||
| x = rand.beta(alpha=2, beta=1, size=(2, 2)) | |||
| print(x.numpy()) | |||
| alpha = mge.Tensor([[0.5], | |||
| [ 3]], dtype="float32") | |||
| beta = mge.Tensor([0.5,5], dtype="float32") | |||
| x = rand.beta(alpha=alpha, beta=beta) | |||
| print(x.numpy()) | |||
| x = rand.beta(alpha=alpha, beta=beta, size=2) | |||
| print(x.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| :options: +SKIP | |||
| [[0.582565 0.91763186] | |||
| [0.86963767 0.6088103 ]] | |||
| [[0.41503012 0.16438372] | |||
| [0.90159506 0.47588003]] | |||
| [[[0.55195075 0.01111084] | |||
| [0.95298755 0.25048104]] | |||
| [[0.11680304 0.13859665] | |||
| [0.997879 0.43259275]]] | |||
| """ | |||
| _seed = self._seed() if callable(self._seed) else self._seed | |||
| return _beta(alpha=alpha, beta=beta, size=size, seed=_seed, handle=self._handle) | |||
| def poisson(self, lam: Union[float, Tensor], size: Optional[Iterable[int]] = None): | |||
| r""" | |||
| Random variable with Poisson distribution :math:`\operatorname{Poisson}(\lambda)`. | |||
| The corresponding probability density function is | |||
| .. math:: | |||
| f(k ; \lambda)=\frac{\lambda^{k} e^{-\lambda}}{k !}, | |||
| where k is the number of occurrences :math:`({\displaystyle k=0,1,2...})`. | |||
| :param lam: the lambda parameter of the distribution. Must be positive. | |||
| :param size: the size of output tensor. If lam is a scalar and given size is, e.g., `(m, n)`, | |||
| then the output shape is `(m, n)`. If lam is a Tensor with shape `(k, v)` and given | |||
| size is, e.g., `(m, n)`, then the output shape is `(m, n, k, v)`. Default: None. | |||
| :return: the output tensor. | |||
| Examples: | |||
| .. testcode:: | |||
| import megengine as mge | |||
| import megengine.random as rand | |||
| x = rand.poisson(lam=2., size=(1, 3)) | |||
| print(x.numpy()) | |||
| lam = mge.Tensor([[1.,1.], | |||
| [10,10]], dtype="float32") | |||
| x = rand.poisson(lam=lam) | |||
| print(x.numpy()) | |||
| x = rand.poisson(lam=lam, size=(1,3)) | |||
| print(x.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| :options: +SKIP | |||
| [[3. 1. 3.]] | |||
| [[ 2. 2.] | |||
| [12. 11.]] | |||
| [[[[ 1. 1.] | |||
| [11. 4.]] | |||
| [[ 0. 0.] | |||
| [ 9. 13.]] | |||
| [[ 0. 1.] | |||
| [ 7. 12.]]]] | |||
| """ | |||
| _seed = self._seed() if callable(self._seed) else self._seed | |||
| return _poisson(lam=lam, size=size, seed=_seed, handle=self._handle) | |||
| def permutation(self, n: int, *, dtype: str = "int32"): | |||
| r""" | |||
| Generates a random permutation of integers from :math:`0` to :math:`n - 1`. | |||
| :param n: the upper bound. Must be larger than 0. | |||
| :param dtype: the output data type. int32, int16 and float32 are | |||
| supported. Default: int32 | |||
| :return: the output tensor. | |||
| Examples: | |||
| .. testcode:: | |||
| import megengine as mge | |||
| import megengine.random as rand | |||
| x = rand.permutation(n=10, dtype="int32") | |||
| print(x.numpy()) | |||
| x = rand.permutation(n=10, dtype="float32") | |||
| print(x.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| :options: +SKIP | |||
| [4 5 0 7 3 8 6 1 9 2] | |||
| [3. 4. 9. 0. 6. 8. 7. 1. 5. 2.] | |||
| """ | |||
| _seed = self._seed() if callable(self._seed) else self._seed | |||
| return _permutation( | |||
| n=n, seed=_seed, device=self._device, handle=self._handle, dtype=dtype | |||
| ) | |||
| def __del__(self): | |||
| _delete_rng_handle(self.handle) | |||
| if self._handle != 0: | |||
| _delete_rng_handle(self._handle) | |||
| def _default_rng(): | |||
| r"""Default constructor for :class:`RNG`.""" | |||
| return RNG(seed=None, device=None) | |||
| _default_handle = _default_rng() | |||
| uniform = _default_handle.uniform | |||
| normal = _default_handle.normal | |||
| gamma = _default_handle.gamma | |||
| beta = _default_handle.beta | |||
| poisson = _default_handle.poisson | |||
| permutation = _default_handle.permutation | |||
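These aliases make the module-level samplers bound methods of a default `RNG` constructed with `seed=None`, so they run on handle 0 and track the global seed, while an explicit `RNG` instance owns a private handle. A short usage sketch based on the API defined above:

    import megengine.random as rand

    x = rand.uniform(size=(2, 2))                    # global handle and seed
    y = rand.gamma(shape=2.0, scale=1.0, size=(4,))  # likewise

    rng = rand.RNG(seed=123)                         # private handle and seed
    z = rng.poisson(lam=3.0, size=(5,))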
| def _random_seed_generator(): | |||
| @@ -476,4 +476,5 @@ void init_ops(py::module m) { | |||
| }, py::call_guard<py::gil_scoped_release>()); | |||
| m.def("set_global_rng_seed", &rng::set_global_rng_seed); | |||
| m.def("get_global_rng_seed", &rng::get_global_rng_seed); | |||
| m.def("get_rng_handle_compnode", &rng::get_rng_handle_compnode); | |||
| } | |||
| @@ -9,8 +9,8 @@ | |||
| import numpy as np | |||
| import pytest | |||
| import megengine | |||
| from megengine import is_cuda_available, tensor | |||
| import megengine.functional as F | |||
| from megengine import Tensor | |||
| from megengine.core._imperative_rt import CompNode | |||
| from megengine.core._imperative_rt.core2 import apply | |||
| from megengine.core._imperative_rt.ops import ( | |||
| @@ -18,10 +18,16 @@ from megengine.core._imperative_rt.ops import ( | |||
| get_global_rng_seed, | |||
| new_rng_handle, | |||
| ) | |||
| from megengine.core.ops.builtin import GaussianRNG, UniformRNG | |||
| from megengine.core.ops.builtin import ( | |||
| BetaRNG, | |||
| GammaRNG, | |||
| GaussianRNG, | |||
| PermutationRNG, | |||
| PoissonRNG, | |||
| UniformRNG, | |||
| ) | |||
| from megengine.distributed.helper import get_device_count_by_fork | |||
| from megengine.random import RNG | |||
| from megengine.random.rng import _normal, _uniform | |||
| @pytest.mark.skipif( | |||
| @@ -34,22 +40,24 @@ def test_gaussian_op(): | |||
| 11, | |||
| 12, | |||
| ) | |||
| shape = tensor(shape, dtype="int32") | |||
| op = GaussianRNG(seed=get_global_rng_seed(), mean=1.0, std=3.0) | |||
| shape = Tensor(shape, dtype="int32") | |||
| op = GaussianRNG(seed=get_global_rng_seed(), mean=1.0, std=3.0, dtype="float32") | |||
| (output,) = apply(op, shape) | |||
| assert np.fabs(output.numpy().mean() - 1.0) < 1e-1 | |||
| assert np.sqrt(output.numpy().var()) - 3.0 < 1e-1 | |||
| assert np.fabs(np.sqrt(output.numpy().var()) - 3.0) < 1e-1 | |||
| assert str(output.device) == str(CompNode("xpux")) | |||
| assert output.dtype == np.float32 | |||
| cn = CompNode("xpu2") | |||
| seed = 233333 | |||
| h = new_rng_handle(cn, seed) | |||
| op = GaussianRNG(seed=seed, mean=3.0, std=1.0, handle=h) | |||
| op = GaussianRNG(seed=seed, mean=3.0, std=1.0, dtype="float32", handle=h) | |||
| (output,) = apply(op, shape) | |||
| delete_rng_handle(h) | |||
| assert np.fabs(output.numpy().mean() - 3.0) < 1e-1 | |||
| assert np.sqrt(output.numpy().var()) - 1.0 < 1e-1 | |||
| assert np.fabs(np.sqrt(output.numpy().var()) - 1.0) < 1e-1 | |||
| assert str(output.device) == str(cn) | |||
| assert output.dtype == np.float32 | |||
| @pytest.mark.skipif( | |||
| @@ -62,20 +70,138 @@ def test_uniform_op(): | |||
| 11, | |||
| 12, | |||
| ) | |||
| shape = tensor(shape, dtype="int32") | |||
| op = UniformRNG(seed=get_global_rng_seed()) | |||
| shape = Tensor(shape, dtype="int32") | |||
| op = UniformRNG(seed=get_global_rng_seed(), dtype="float32") | |||
| (output,) = apply(op, shape) | |||
| assert np.fabs(output.numpy().mean() - 0.5) < 1e-1 | |||
| assert str(output.device) == str(CompNode("xpux")) | |||
| assert output.dtype == np.float32 | |||
| cn = CompNode("xpu2") | |||
| seed = 233333 | |||
| h = new_rng_handle(cn, seed) | |||
| op = UniformRNG(seed=seed, handle=h) | |||
| op = UniformRNG(seed=seed, dtype="float32", handle=h) | |||
| (output,) = apply(op, shape) | |||
| delete_rng_handle(h) | |||
| assert np.fabs(output.numpy().mean() - 0.5) < 1e-1 | |||
| assert str(output.device) == str(cn) | |||
| assert output.dtype == np.float32 | |||
| @pytest.mark.skipif( | |||
| get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2", | |||
| ) | |||
| def test_gamma_op(): | |||
| _shape, _scale = 2, 0.8 | |||
| _expected_mean, _expected_std = _shape * _scale, np.sqrt(_shape) * _scale | |||
| shape = F.full([8, 9, 11, 12], value=_shape, dtype="float32") | |||
| scale = F.full([8, 9, 11, 12], value=_scale, dtype="float32") | |||
| op = GammaRNG(seed=get_global_rng_seed(), handle=0) | |||
| (output,) = apply(op, shape, scale) | |||
| assert np.fabs(output.numpy().mean() - _expected_mean) < 1e-1 | |||
| assert np.fabs(np.sqrt(output.numpy().var()) - _expected_std) < 1e-1 | |||
| assert str(output.device) == str(CompNode("xpux")) | |||
| cn = CompNode("xpu2") | |||
| seed = 233333 | |||
| h = new_rng_handle(cn, seed) | |||
| shape = F.full([8, 9, 11, 12], value=_shape, dtype="float32", device="xpu2") | |||
| scale = F.full([8, 9, 11, 12], value=_scale, dtype="float32", device="xpu2") | |||
| op = GammaRNG(seed=seed, handle=h) | |||
| (output,) = apply(op, shape, scale) | |||
| delete_rng_handle(h) | |||
| assert np.fabs(output.numpy().mean() - _expected_mean) < 1e-1 | |||
| assert np.fabs(np.sqrt(output.numpy().var()) - _expected_std) < 1e-1 | |||
| assert str(output.device) == str(cn) | |||
| @pytest.mark.skipif( | |||
| get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2", | |||
| ) | |||
| def test_beta_op(): | |||
| _alpha, _beta = 2, 0.8 | |||
| _expected_mean = _alpha / (_alpha + _beta) | |||
| _expected_std = np.sqrt( | |||
| _alpha * _beta / ((_alpha + _beta) ** 2 * (_alpha + _beta + 1)) | |||
| ) | |||
| alpha = F.full([8, 9, 11, 12], value=_alpha, dtype="float32") | |||
| beta = F.full([8, 9, 11, 12], value=_beta, dtype="float32") | |||
| op = BetaRNG(seed=get_global_rng_seed()) | |||
| (output,) = apply(op, alpha, beta) | |||
| assert np.fabs(output.numpy().mean() - _expected_mean) < 1e-1 | |||
| assert np.fabs(np.sqrt(output.numpy().var()) - _expected_std) < 1e-1 | |||
| assert str(output.device) == str(CompNode("xpux")) | |||
| cn = CompNode("xpu2") | |||
| seed = 233333 | |||
| h = new_rng_handle(cn, seed) | |||
| alpha = F.full([8, 9, 11, 12], value=_alpha, dtype="float32", device=cn) | |||
| beta = F.full([8, 9, 11, 12], value=_beta, dtype="float32", device=cn) | |||
| op = BetaRNG(seed=seed, handle=h) | |||
| (output,) = apply(op, alpha, beta) | |||
| delete_rng_handle(h) | |||
| assert np.fabs(output.numpy().mean() - _expected_mean) < 1e-1 | |||
| assert np.fabs(np.sqrt(output.numpy().var()) - _expected_std) < 1e-1 | |||
| assert str(output.device) == str(cn) | |||
| @pytest.mark.skipif( | |||
| get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2", | |||
| ) | |||
| def test_poisson_op(): | |||
| lam = F.full([8, 9, 11, 12], value=2, dtype="float32") | |||
| op = PoissonRNG(seed=get_global_rng_seed()) | |||
| (output,) = apply(op, lam) | |||
| assert np.fabs(output.numpy().mean() - 2.0) < 1e-1 | |||
| assert np.fabs(np.sqrt(output.numpy().var()) - np.sqrt(2.0)) < 1e-1 | |||
| assert str(output.device) == str(CompNode("xpux")) | |||
| cn = CompNode("xpu2") | |||
| seed = 233333 | |||
| h = new_rng_handle(cn, seed) | |||
| lam = F.full([8, 9, 11, 12], value=2, dtype="float32", device=cn) | |||
| op = PoissonRNG(seed=seed, handle=h) | |||
| (output,) = apply(op, lam) | |||
| delete_rng_handle(h) | |||
| assert np.fabs(output.numpy().mean() - 2.0) < 1e-1 | |||
| assert np.fabs(np.sqrt(output.numpy().var()) - np.sqrt(2.0)) < 1e-1 | |||
| assert str(output.device) == str(cn) | |||
| @pytest.mark.skipif( | |||
| get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2", | |||
| ) | |||
| def test_permutation_op(): | |||
| n = 1000 | |||
| def test_permutation_op_dtype(dtype): | |||
| def sum_result(res, fun): | |||
| return sum([1 if i == v else 0 for i, v in enumerate(fun(res.numpy()))]) | |||
| shape = Tensor((n,), dtype="int32") | |||
| op = PermutationRNG(seed=get_global_rng_seed(), dtype=dtype) | |||
| (output,) = apply(op, shape) | |||
| assert sum_result(output, lambda x: x) < 500 | |||
| assert sum_result(output, np.sort) == n | |||
| assert str(output.device) == str(CompNode("xpux")) | |||
| assert output.dtype == dtype | |||
| cn = CompNode("xpu2") | |||
| seed = 233333 | |||
| h = new_rng_handle(cn, seed) | |||
| op = PermutationRNG(seed=seed, handle=h, dtype=dtype) | |||
| (output,) = apply(op, shape) | |||
| delete_rng_handle(h) | |||
| assert sum_result(output, lambda x: x) < 500 | |||
| assert sum_result(output, np.sort) == n | |||
| assert str(output.device) == str(cn) | |||
| assert output.dtype == dtype | |||
| test_permutation_op_dtype(np.float32) | |||
| test_permutation_op_dtype(np.int32) | |||
| test_permutation_op_dtype(np.int16) | |||
| @pytest.mark.skipif( | |||
| @@ -133,3 +259,131 @@ def test_NormalRNG(): | |||
| assert all(out.shape.numpy() == np.array([20, 30, 40])) | |||
| assert np.abs(out.mean().numpy() - mean) / std < 0.1 | |||
| assert np.abs(np.std(out.numpy()) - std) < 0.1 | |||
| @pytest.mark.skipif( | |||
| get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1", | |||
| ) | |||
| def test_GammaRNG(): | |||
| m1 = RNG(seed=111, device="xpu0") | |||
| m2 = RNG(seed=111, device="xpu1") | |||
| m3 = RNG(seed=222, device="xpu0") | |||
| out1 = m1.gamma(2, size=(100,)) | |||
| out1_ = m1.uniform(size=(100,)) | |||
| out2 = m2.gamma(2, size=(100,)) | |||
| out3 = m3.gamma(2, size=(100,)) | |||
| np.testing.assert_equal(out1.numpy(), out2.numpy()) | |||
| assert out1.device == "xpu0" and out2.device == "xpu1" | |||
| assert not (out1.numpy() == out3.numpy()).all() | |||
| assert not (out1.numpy() == out1_.numpy()).all() | |||
| shape = Tensor([[2, 3, 4], [9, 10, 11]], dtype=np.float32, device="xpu0") | |||
| scale = Tensor([0.5, 1, 1.5], dtype=np.float32, device="xpu0") | |||
| expected_mean = (shape * scale).numpy() | |||
| expected_std = (F.sqrt(shape) * scale).numpy() | |||
| out = m1.gamma(shape=shape, scale=scale, size=(20, 30, 40)) | |||
| out_shp = out.shape | |||
| if isinstance(out_shp, tuple): | |||
| assert out_shp == (20, 30, 40, 2, 3) | |||
| else: | |||
| assert all(out.shape.numpy() == np.array([20, 30, 40, 2, 3])) | |||
| assert ( | |||
| np.abs(out.mean(axis=(0, 1)).numpy() - expected_mean) / expected_std | |||
| ).mean() < 0.1 | |||
| assert (np.abs(np.std(out.numpy(), axis=(0, 1)) - expected_std)).mean() < 0.1 | |||
| @pytest.mark.skipif( | |||
| get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1", | |||
| ) | |||
| def test_BetaRNG(): | |||
| m1 = RNG(seed=111, device="xpu0") | |||
| m2 = RNG(seed=111, device="xpu1") | |||
| m3 = RNG(seed=222, device="xpu0") | |||
| out1 = m1.beta(2, 1, size=(100,)) | |||
| out1_ = m1.uniform(size=(100,)) | |||
| out2 = m2.beta(2, 1, size=(100,)) | |||
| out3 = m3.beta(2, 1, size=(100,)) | |||
| np.testing.assert_equal(out1.numpy(), out2.numpy()) | |||
| assert out1.device == "xpu0" and out2.device == "xpu1" | |||
| assert not (out1.numpy() == out3.numpy()).all() | |||
| assert not (out1.numpy() == out1_.numpy()).all() | |||
| alpha = Tensor([[2, 3, 4], [9, 10, 11]], dtype=np.float32, device="xpu0") | |||
| beta = Tensor([0.5, 1, 1.5], dtype=np.float32, device="xpu0") | |||
| expected_mean = (alpha / (alpha + beta)).numpy() | |||
| expected_std = ( | |||
| F.sqrt(alpha * beta / (F.pow(alpha + beta, 2) * (alpha + beta + 1))) | |||
| ).numpy() | |||
| out = m1.beta(alpha=alpha, beta=beta, size=(20, 30)) | |||
| out_shp = out.shape | |||
| if isinstance(out_shp, tuple): | |||
| assert out_shp == (20, 30, 2, 3) | |||
| else: | |||
| assert all(out.shape.numpy() == np.array([20, 30, 2, 3])) | |||
| assert ( | |||
| np.abs(out.mean(axis=(0, 1)).numpy() - expected_mean) / expected_std | |||
| ).mean() < 0.1 | |||
| assert (np.abs(np.std(out.numpy(), axis=(0, 1)) - expected_std)).mean() < 0.1 | |||
| @pytest.mark.skipif( | |||
| get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1", | |||
| ) | |||
| def test_PoissonRNG(): | |||
| m1 = RNG(seed=111, device="xpu0") | |||
| m2 = RNG(seed=111, device="xpu1") | |||
| m3 = RNG(seed=222, device="xpu0") | |||
| lam = Tensor([[2, 3, 4], [9, 10, 11]], dtype=np.float32) | |||
| out1 = m1.poisson(lam.to("xpu0"), size=(100,)) | |||
| out2 = m2.poisson(lam.to("xpu1"), size=(100,)) | |||
| out3 = m3.poisson(lam.to("xpu0"), size=(100,)) | |||
| np.testing.assert_equal(out1.numpy(), out2.numpy()) | |||
| assert out1.device == "xpu0" and out2.device == "xpu1" | |||
| assert not (out1.numpy() == out3.numpy()).all() | |||
| out = m1.poisson(lam.to("xpu0"), size=(20, 30)) | |||
| out_shp = out.shape | |||
| expected_shape = (20, 30) + lam._tuple_shape | |||
| if isinstance(out_shp, tuple): | |||
| assert out_shp == expected_shape | |||
| else: | |||
| assert all(out.shape.numpy() == np.array(expected_shape)) | |||
| lam = lam.numpy() | |||
| assert (np.abs(out.mean(axis=(0, 1)).numpy() - lam) / np.sqrt(lam)).mean() < 0.1 | |||
| assert np.abs(np.std(out.numpy(), axis=(0, 1)) - np.sqrt(lam)).mean() < 0.1 | |||
| @pytest.mark.skipif( | |||
| get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1", | |||
| ) | |||
| def test_PermutationRNG(): | |||
| m1 = RNG(seed=111, device="xpu0") | |||
| m2 = RNG(seed=111, device="xpu1") | |||
| m3 = RNG(seed=222, device="xpu0") | |||
| out1 = m1.permutation(n=1000) | |||
| out1_ = m1.uniform(size=(1000,)) | |||
| out2 = m2.permutation(n=1000) | |||
| out3 = m3.permutation(n=1000) | |||
| np.testing.assert_equal(out1.numpy(), out2.numpy()) | |||
| assert out1.device == "xpu0" and out2.device == "xpu1" | |||
| assert not (out1.numpy() == out3.numpy()).all() | |||
| assert not (out1.numpy() == out1_.numpy()).all() | |||
| out = m1.permutation(n=1000) | |||
| out_shp = out.shape | |||
| if isinstance(out_shp, tuple): | |||
| assert out_shp == (1000,) | |||
| else: | |||
| assert all(out.shape.numpy() == np.array([1000])) | |||
| def sum_result(res, fun): | |||
| return sum([1 if i == v else 0 for i, v in enumerate(fun(res.numpy()))]) | |||
| assert sum_result(out, lambda x: x) < 500 | |||
| assert sum_result(out, np.sort) == 1000 | |||
| @@ -176,6 +176,20 @@ struct OpMeth<UniformRNG> { | |||
| using Param = DnnOp::Param; | |||
| using OpNode = mgb::opr::UniformRNG; | |||
| static Param make_param(const UniformRNG& rng) { | |||
| auto handle_seed = RNGDnnOpManager::get_seed(rng.handle); | |||
| mgb_assert(handle_seed == rng.seed, | |||
| "inconsistent rng seed: rng op: %lu handle: %lu", | |||
| handle_seed, rng.seed); | |||
| return {handle_seed, rng.dtype.enumv()}; | |||
| } | |||
| }; | |||
| template <> | |||
| struct OpMeth<PoissonRNG> { | |||
| using DnnOp = megdnn::PoissonRNG; | |||
| using Param = DnnOp::Param; | |||
| using OpNode = mgb::opr::PoissonRNG; | |||
| static Param make_param(const PoissonRNG& rng) { | |||
| auto handle_seed = RNGDnnOpManager::get_seed(rng.handle); | |||
| mgb_assert(handle_seed == rng.seed, | |||
| "inconsistent rng seed: rng op: %lu handle: %lu", | |||
| @@ -194,16 +208,168 @@ struct OpMeth<GaussianRNG> { | |||
| mgb_assert(handle_seed == rng.seed, | |||
| "inconsistent rng seed: rng op: %lu handle: %lu", | |||
| handle_seed, rng.seed); | |||
| return {handle_seed, rng.mean, rng.std}; | |||
| return {handle_seed, rng.mean, rng.std, rng.dtype.enumv()}; | |||
| } | |||
| }; | |||
| template <> | |||
| struct OpMeth<GammaRNG> { | |||
| using DnnOp = megdnn::GammaRNG; | |||
| using Param = DnnOp::Param; | |||
| using OpNode = mgb::opr::GammaRNG; | |||
| static Param make_param(const GammaRNG& rng) { | |||
| auto handle_seed = RNGDnnOpManager::get_seed(rng.handle); | |||
| mgb_assert(handle_seed == rng.seed, | |||
| "inconsistent rng seed: rng op: %lu handle: %lu", | |||
| handle_seed, rng.seed); | |||
| return {handle_seed}; | |||
| } | |||
| }; | |||
| template <> | |||
| struct OpMeth<PermutationRNG> { | |||
| using DnnOp = megdnn::PermutationRNG; | |||
| using Param = DnnOp::Param; | |||
| using OpNode = mgb::opr::PermutationRNG; | |||
| static Param make_param(const PermutationRNG& rng) { | |||
| auto handle_seed = RNGDnnOpManager::get_seed(rng.handle); | |||
| mgb_assert(handle_seed == rng.seed, | |||
| "inconsistent rng seed: rng op: %lu handle: %lu", | |||
| handle_seed, rng.seed); | |||
| return {handle_seed, rng.dtype.enumv()}; | |||
| } | |||
| }; | |||
| template <> | |||
| struct OpMeth<BetaRNG> { | |||
| using DnnOp = megdnn::BetaRNG; | |||
| using Param = DnnOp::Param; | |||
| using OpNode = mgb::opr::BetaRNG; | |||
| static Param make_param(const BetaRNG& rng) { | |||
| auto handle_seed = RNGDnnOpManager::get_seed(rng.handle); | |||
| mgb_assert(handle_seed == rng.seed, | |||
| "inconsistent rng seed: rng op: %lu handle: %lu", | |||
| handle_seed, rng.seed); | |||
| return {handle_seed}; | |||
| } | |||
| }; | |||
| template <bool> | |||
| struct _InferLayout; | |||
| template <int nr_in> | |||
| struct _RNGOprMaker; | |||
| template <int nr_in> | |||
| struct _RNGOprInvoker; | |||
| template<> | |||
| struct _InferLayout<true> | |||
| { | |||
| template<typename Op> | |||
| static TensorLayout do_infer(const TensorPtr& inp, const Op& rng){ | |||
| TensorShape tshape; | |||
| auto hv = inp->get_value().proxy_to_default_cpu(); | |||
| cg::copy_tensor_value_to_shape(tshape, hv); | |||
| return TensorLayout(tshape, rng.dtype); | |||
| } | |||
| template<typename Op> | |||
| static TensorLayout do_infer(const LogicalTensorDesc& inp, const Op& rng){ | |||
| TensorLayout out_layout = inp.layout; | |||
| out_layout.dtype = rng.dtype; | |||
| if (inp.layout.ndim == 0 || inp.value.empty()) { | |||
| out_layout.ndim = 0; | |||
| return out_layout; | |||
| } | |||
| mgb_assert( | |||
| inp.layout.ndim == 1, | |||
| "target shape of %s expects ndim=1; got ndim=%lu actually", | |||
| rng.dyn_typeinfo()->name, | |||
| inp.layout.ndim); | |||
| size_t target_ndim = inp.layout.shape[0]; | |||
| out_layout.ndim = target_ndim; | |||
| auto* ptr = inp.value.ptr<dt_int32>(); | |||
| for (size_t i = 0; i < target_ndim; ++i) { | |||
| out_layout.shape[i] = ptr[i]; | |||
| } | |||
| return out_layout; | |||
| } | |||
| }; | |||
| template<> | |||
| struct _InferLayout<false> | |||
| { | |||
| template<typename Op> | |||
| static TensorLayout do_infer(const TensorPtr& inp, const Op& rng){ | |||
| return inp->layout(); | |||
| } | |||
| template<typename Op> | |||
| static TensorLayout do_infer(const LogicalTensorDesc& inp, const Op& rng){ | |||
| size_t size = inp.layout.total_nr_elems(); | |||
| mgb_assert( | |||
| size > 0, | |||
| "target size of %s expects size>0; got size=%lu actually", | |||
| rng.dyn_typeinfo()->name, | |||
| size); | |||
| return inp.layout; | |||
| } | |||
| }; | |||
| #define _INST_RNG_INVOLKER(DNN_NR_INPUTS) \ | |||
| template<> \ | |||
| struct _RNGOprInvoker<DNN_NR_INPUTS> { \ | |||
| template<typename Opr> \ | |||
| static void exec(Opr *dnn_op, const SmallVector<TensorPtr>& inputs,const TensorPtr& dest){ \ | |||
| size_t wk_size = 0; \ | |||
| wk_size = dnn_op->get_workspace_in_bytes(_FOR_EACH_IN(->layout())dest->layout()); \ | |||
| auto workspace = Blob::make(dest->comp_node(), wk_size); \ | |||
| megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size); \ | |||
| dnn_op->exec(_FOR_EACH_IN(->dev_tensor().as_megdnn()) \ | |||
| dest->dev_tensor().as_megdnn(), dnn_wk); \ | |||
| } \ | |||
| }; | |||
| #define _INST_RNG_MAKER(MGB_NR_INPUTS) \ | |||
| template<> \ | |||
| struct _RNGOprMaker<MGB_NR_INPUTS> { \ | |||
| template<typename Op> \ | |||
| static SymbolVar make(const VarNodeArray& inputs, const Op& rng){ \ | |||
| auto param = OpMeth<Op>::make_param(rng); \ | |||
| OperatorNodeConfig config; \ | |||
| if (rng.handle) { \ | |||
| config = {rng.make_name(), RNGDnnOpManager::get_comp_node(rng.handle)}; \ | |||
| } else { \ | |||
| config = {rng.make_name()}; \ | |||
| } \ | |||
| return OpMeth<Op>::OpNode::make(_FOR_EACH_IN() param, config); \ | |||
| } \ | |||
| }; | |||
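| // _FOR_EACH_IN(subfix) is redefined before each instantiation below to | |||
| // splice the op inputs into the call: empty for the shape-only RNGs | |||
| // (0 dnn inputs), "inputs[0] subfix," for one input, and so on for two. | |||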
| #define _FOR_EACH_IN(subfix) | |||
| _INST_RNG_INVOLKER(0) | |||
| #undef _FOR_EACH_IN | |||
| #define _FOR_EACH_IN(subfix) inputs[0] subfix, | |||
| _INST_RNG_INVOLKER(1) | |||
| _INST_RNG_MAKER(1) | |||
| #undef _FOR_EACH_IN | |||
| #define _FOR_EACH_IN(subfix) inputs[0] subfix, inputs[1] subfix, | |||
| _INST_RNG_INVOLKER(2) | |||
| _INST_RNG_MAKER(2) | |||
| #undef _FOR_EACH_IN | |||
| #undef _INST_RNG_INVOLKER | |||
| #undef _INST_RNG_MAKER | |||
| template <typename Op> | |||
| void exec(const OpDef& op, const SmallVector<TensorPtr>& inputs, | |||
| const SmallVector<TensorPtr>& outputs) { | |||
| auto&& rng = op.cast_final_safe<Op>(); | |||
| auto dest = outputs[0]; | |||
| auto cn = dest->comp_node(); | |||
| auto handle = rng.handle; | |||
| if (!handle) { | |||
| @@ -224,38 +390,40 @@ void exec(const OpDef& op, const SmallVector<TensorPtr>& inputs, | |||
| handle_seed, dnn_op->param().seed); | |||
| } | |||
| dnn_op->param() = OpMeth<Op>::make_param(rng); | |||
| // allocate workspace | |||
| size_t wk_size = dnn_op->get_workspace_in_bytes(dest->layout()); | |||
| auto workspace = Blob::make(cn, wk_size); | |||
| megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size); | |||
| dnn_op->exec(dest->dev_tensor().as_megdnn(), dnn_wk); | |||
| _RNGOprInvoker<OpMeth<Op>::DnnOp::NR_INPUTS>::exec(dnn_op,inputs,dest); | |||
| } | |||
| template <typename Op> | |||
| SmallVector<LogicalTensorDesc> infer_output_attrs( | |||
| const OpDef& op, const SmallVector<TensorPtr>& inputs) { | |||
| LogicalTensorDesc dest; | |||
| auto handle = op.cast_final_safe<Op>().handle; | |||
| auto&& rng = op.cast_final_safe<Op>(); | |||
| auto handle = rng.handle; | |||
| if (handle) { | |||
| dest.comp_node = RNGDnnOpManager::get_comp_node(handle); | |||
| } else { | |||
| dest.comp_node = inputs[0]->comp_node(); | |||
| } | |||
| auto hv = inputs[0]->get_value().proxy_to_default_cpu(); | |||
| TensorShape tshape; | |||
| cg::copy_tensor_value_to_shape(tshape, hv); | |||
| dest.layout = TensorLayout(tshape, dtype::Float32()); | |||
| constexpr bool rng_with_shape = OpMeth<Op>::DnnOp::NR_INPUTS == 0; | |||
| if(!rng_with_shape){ | |||
| for(int i = 0; i < inputs.size(); ++i){ | |||
| mgb_assert(inputs[i]->comp_node() == dest.comp_node, | |||
| "%s expects the device of inputs[%d] to be same as the device of handle; " | |||
| "got %s and %s actually", rng.dyn_typeinfo()->name, i, | |||
| inputs[i]->comp_node().to_string().c_str(), | |||
| dest.comp_node.to_string().c_str()); | |||
| } | |||
| } | |||
| dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], rng); | |||
| return {dest}; | |||
| } | |||
| template <typename Op> | |||
| SmallVector<TensorPtr> apply_on_physical_tensor( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs) { | |||
| auto desc = infer_output_attrs<Op>(def, inputs); | |||
| SmallVector<TensorPtr> outputs; | |||
| SmallVector<LogicalTensorDesc> desc; | |||
| desc = infer_output_attrs<Op>(def, inputs); | |||
| for (auto&& i : desc) { | |||
| outputs.push_back(Tensor::make(i.layout, i.comp_node)); | |||
| } | |||
| @@ -268,51 +436,32 @@ SymbolVar apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| size_t nr_inp = inputs.size(); | |||
| constexpr size_t dnn_nr_inp = OpMeth<Op>::DnnOp::NR_INPUTS; | |||
| auto&& rng = def.cast_final_safe<Op>(); | |||
| mgb_assert(nr_inp == 1, "%s expects 1 inputs; got %lu actually", | |||
| rng.dyn_typeinfo()->name, | |||
| nr_inp); | |||
| auto param = OpMeth<Op>::make_param(rng); | |||
| OperatorNodeConfig config; | |||
| if (rng.handle) { | |||
| config = {rng.make_name(), RNGDnnOpManager::get_comp_node(rng.handle)}; | |||
| } else { | |||
| config = {rng.make_name()}; | |||
| if(dnn_nr_inp == 0){ | |||
| mgb_assert(nr_inp == 1, "%s expects 1 inputs; got %lu actually", | |||
| rng.dyn_typeinfo()->name, | |||
| nr_inp); | |||
| } | |||
| return OpMeth<Op>::OpNode::make(inputs[0], param, config); | |||
| constexpr size_t mgb_nr_inp = dnn_nr_inp + !dnn_nr_inp; | |||
| return _RNGOprMaker<mgb_nr_inp>::make(inputs, rng); | |||
| } | |||
| template<typename T> | |||
| template<typename Op> | |||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
| const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { | |||
| auto&& xxx_rng_def = def.cast_final_safe<T>(); | |||
| LogicalTensorDesc dest; | |||
| auto&& xxx_rng_def = def.cast_final_safe<Op>(); | |||
| size_t nr_inp = inputs.size(); | |||
| mgb_assert(nr_inp == 1, "%s expects 1 inputs; got %lu actually", | |||
| xxx_rng_def.dyn_typeinfo()->name, | |||
| nr_inp); | |||
| auto&& tshp = inputs[0]; | |||
| TensorLayout out_layout = tshp.layout; | |||
| out_layout.dtype = dtype::Float32(); | |||
| if (tshp.layout.ndim == 0 || tshp.value.empty()) { | |||
| out_layout.ndim = 0; | |||
| return {{{out_layout, tshp.comp_node}}, true}; | |||
| constexpr bool rng_with_shape = OpMeth<Op>::DnnOp::NR_INPUTS == 0; | |||
| if (rng_with_shape){ | |||
| mgb_assert(nr_inp == 1, "%s expects 1 inputs; got %lu actually", | |||
| xxx_rng_def.dyn_typeinfo()->name, | |||
| nr_inp); | |||
| } | |||
| mgb_assert( | |||
| tshp.layout.ndim == 1, | |||
| "target shape of %s expects ndim=1; got ndim=%lu actually", | |||
| xxx_rng_def.dyn_typeinfo()->name, | |||
| tshp.layout.ndim); | |||
| size_t target_ndim = tshp.layout.shape[0]; | |||
| out_layout.ndim = target_ndim; | |||
| auto* ptr = tshp.value.ptr<dt_int32>(); | |||
| for (size_t i = 0; i < target_ndim; ++i) { | |||
| out_layout.shape[i] = ptr[i]; | |||
| } | |||
| return {{{out_layout, tshp.comp_node}}, true}; | |||
| dest.comp_node = inputs[0].comp_node; | |||
| dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], xxx_rng_def); | |||
| return {{dest}, true}; | |||
| } | |||
| } // anonymous namespace | |||
| @@ -333,6 +482,10 @@ uint64_t get_global_rng_seed() { | |||
| return RNGDnnOpManager::get_glob_default_seed(); | |||
| } | |||
| CompNode get_rng_handle_compnode(Handle handle){ | |||
| return RNGDnnOpManager::get_comp_node(handle); | |||
| } | |||
| #define REG_RNG_OP(NAME)\ | |||
| namespace { \ | |||
| OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode) \ | |||
| @@ -344,6 +497,11 @@ OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode) \ | |||
| REG_RNG_OP(UniformRNG) | |||
| REG_RNG_OP(GaussianRNG) | |||
| REG_RNG_OP(GammaRNG) | |||
| REG_RNG_OP(PermutationRNG) | |||
| REG_RNG_OP(PoissonRNG) | |||
| REG_RNG_OP(BetaRNG) | |||
| #undef REG_RNG_OP | |||
| } // namespace mgb::imperative::rng | |||
| @@ -22,5 +22,6 @@ Handle new_handle(CompNode comp_node, uint64_t seed); | |||
| size_t delete_handle(Handle handle); | |||
| void set_global_rng_seed(uint64_t seed); | |||
| uint64_t get_global_rng_seed(); | |||
| CompNode get_rng_handle_compnode(Handle handle); | |||
| } // namespace mgb::imperative::rng | |||
| @@ -42,14 +42,72 @@ void check_rng_basic(Args&& ...args) { | |||
| } | |||
| } | |||
| template<typename Op, typename ...Args> | |||
| void check_rng_with_input_basic(const CompNode &cn, | |||
| const SmallVector<TensorPtr> &inputs, Args&& ...args) { | |||
| Handle h = new_handle(cn, 123); | |||
| auto op = Op::make(std::forward<Args>(args)..., h); | |||
| auto outputs = OpDef::apply_on_physical_tensor(*op, inputs); | |||
| ASSERT_TRUE(outputs[0]->layout().eq_shape(inputs[0]->shape())); | |||
| ASSERT_TRUE(cn == outputs[0]->comp_node()); | |||
| // sync before delete handle | |||
| for (auto&& p: outputs) { | |||
| p->get_value(); | |||
| } | |||
| delete_handle(h); | |||
| } | |||
| TEST(TestImperative, PoissonRNGBasic) { | |||
| REQUIRE_XPU(2); | |||
| for (auto&& cn: {CompNode::load("xpu0"), CompNode::load("xpu1")}){ | |||
| TensorShape shape{5, 3000}; | |||
| HostTensorND lam{cn, shape, dtype::Float32()}; | |||
| auto lam_ptr = lam.ptr<float>(); | |||
| for (int i = 0; i < 5 * 3000; ++i) lam_ptr[i] = 2; | |||
| SmallVector<TensorPtr> inputs{Tensor::make(lam)}; | |||
| check_rng_with_input_basic<PoissonRNG>(cn, inputs, 123); | |||
| } | |||
| } | |||
| TEST(TestImperative, BetaRNGBasic) { | |||
| REQUIRE_XPU(2); | |||
| for (auto&& cn: {CompNode::load("xpu0"), CompNode::load("xpu1")}){ | |||
| TensorShape shape{5, 3000}; | |||
| HostTensorND alpha{cn, shape, dtype::Float32()}, | |||
| beta{cn, shape, dtype::Float32()}; | |||
| auto alpha_ptr = alpha.ptr<float>(), beta_ptr = beta.ptr<float>(); | |||
| for (int i = 0; i < 5 * 3000; ++i) alpha_ptr[i] = 2, beta_ptr[i] = 2; | |||
| SmallVector<TensorPtr> inputs{Tensor::make(alpha), Tensor::make(beta)}; | |||
| check_rng_with_input_basic<BetaRNG>(cn, inputs, 123); | |||
| } | |||
| } | |||
| TEST(TestImperative, GammaRNGBasic) { | |||
| REQUIRE_XPU(2); | |||
| for (auto&& cn: {CompNode::load("xpu0"), CompNode::load("xpu1")}){ | |||
| TensorShape size{5, 3000}; | |||
| HostTensorND shape{cn, size, dtype::Float32()}, | |||
| scale{cn, size, dtype::Float32()}; | |||
| auto shape_ptr = shape.ptr<float>(), scale_ptr = scale.ptr<float>(); | |||
| for (int i = 0; i < 5 * 3000; ++i) shape_ptr[i] = 2, scale_ptr[i] = 2; | |||
| SmallVector<TensorPtr> inputs{Tensor::make(shape), Tensor::make(scale)}; | |||
| check_rng_with_input_basic<GammaRNG>(cn, inputs, 123); | |||
| } | |||
| } | |||
| TEST(TestImperative, UniformRNGBasic) { | |||
| REQUIRE_XPU(2); | |||
| check_rng_basic<UniformRNG>(123); | |||
| check_rng_basic<UniformRNG>(123, dtype::Float32()); | |||
| } | |||
| TEST(TestImperative, GaussianRNGBasic) { | |||
| REQUIRE_XPU(2); | |||
| check_rng_basic<GaussianRNG>(123, 2.f, 3.f); | |||
| check_rng_basic<GaussianRNG>(123, 2.f, 3.f, dtype::Float32()); | |||
| } | |||
| TEST(TestImperative, PermutationRNGBasic) { | |||
| REQUIRE_XPU(2); | |||
| check_rng_basic<PermutationRNG>(123, dtype::Int32()); | |||
| } | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -123,9 +123,13 @@ def UniformRNG: MgbHashableOp<"UniformRNG", [UniformRNGParam]> { | |||
| let hashFunction = [{ | |||
| return mgb::hash_pair_combine( | |||
| mgb::hash($_self.dyn_typeinfo()), | |||
| mgb::hash($_self.handle)); | |||
| mgb::hash_pair_combine( | |||
| mgb::hash($_self.handle), | |||
| mgb::hash($_self.dtype.enumv()) | |||
| ) | |||
| ); | |||
| }]; | |||
| let cmpFunction = [{return $0.handle == $1.handle;}]; | |||
| let cmpFunction = [{return $0.handle == $1.handle && $0.dtype == $1.dtype;}]; | |||
| } | |||
| def GaussianRNG: MgbHashableOp<"GaussianRNG", [GaussianRNGParam]> { | |||
| @@ -139,11 +143,70 @@ def GaussianRNG: MgbHashableOp<"GaussianRNG", [GaussianRNGParam]> { | |||
| mgb::hash($_self.handle), | |||
| mgb::hash_pair_combine( | |||
| mgb::hash($_self.mean), | |||
| mgb::hash($_self.std)) | |||
| mgb::hash_pair_combine( | |||
| mgb::hash($_self.std), | |||
| mgb::hash($_self.dtype.enumv()) | |||
| ) | |||
| ) | |||
| ) | |||
| ); | |||
| }]; | |||
| let cmpFunction = [{return $0.handle == $1.handle && $0.mean == $1.mean && $0.std == $1.std && $0.dtype == $1.dtype;}]; | |||
| } | |||
| def GammaRNG: MgbHashableOp<"GammaRNG", [GammaRNGParam]> { | |||
| let extraArguments = (ins | |||
| MgbSizeTAddr:$handle | |||
| ); | |||
| let hashFunction = [{ | |||
| return mgb::hash_pair_combine( | |||
| mgb::hash($_self.dyn_typeinfo()), | |||
| mgb::hash($_self.handle) | |||
| ); | |||
| }]; | |||
| let cmpFunction = [{return $0.handle == $1.handle;}]; | |||
| } | |||
| def PoissonRNG: MgbHashableOp<"PoissonRNG", [PoissonRNGParam]> { | |||
| let extraArguments = (ins | |||
| MgbSizeTAddr:$handle | |||
| ); | |||
| let hashFunction = [{ | |||
| return mgb::hash_pair_combine( | |||
| mgb::hash($_self.dyn_typeinfo()), | |||
| mgb::hash($_self.handle) | |||
| ); | |||
| }]; | |||
| let cmpFunction = [{return $0.handle == $1.handle;}]; | |||
| } | |||
| def BetaRNG: MgbHashableOp<"BetaRNG", [BetaRNGParam]> { | |||
| let extraArguments = (ins | |||
| MgbSizeTAddr:$handle | |||
| ); | |||
| let hashFunction = [{ | |||
| return mgb::hash_pair_combine( | |||
| mgb::hash($_self.dyn_typeinfo()), | |||
| mgb::hash($_self.handle) | |||
| ); | |||
| }]; | |||
| let cmpFunction = [{return $0.handle == $1.handle;}]; | |||
| } | |||
| def PermutationRNG: MgbHashableOp<"PermutationRNG", [PermutationRNGParam]> { | |||
| let extraArguments = (ins | |||
| MgbSizeTAddr:$handle | |||
| ); | |||
| let hashFunction = [{ | |||
| return mgb::hash_pair_combine( | |||
| mgb::hash($_self.dyn_typeinfo()), | |||
| mgb::hash_pair_combine( | |||
| mgb::hash($_self.handle), | |||
| mgb::hash($_self.dtype.enumv()) | |||
| ) | |||
| ); | |||
| }]; | |||
| let cmpFunction = [{return $0.handle == $1.handle && $0.dtype == $1.dtype;}]; | |||
| } | |||
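In each def, `cmpFunction` and `hashFunction` have to cover the same fields, so that op defs that compare equal always hash equal. The nested `mgb::hash_pair_combine` calls amount to hashing the ordered field tuple; a flattened sketch of the PermutationRNG case, assuming `hash_pair_combine` just mixes two `size_t` hashes (`op` is a hypothetical instance):

```cpp
// Equivalent flattened form of the nested combine above (sketch only).
size_t h = mgb::hash_pair_combine(
        mgb::hash(op.dyn_typeinfo()),
        mgb::hash_pair_combine(mgb::hash(op.handle),
                               mgb::hash(op.dtype.enumv())));
```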
| def Linspace: MgbHashableOp<"Linspace", [LinspaceParam]> { | |||
| @@ -19,46 +19,21 @@ using namespace mgb; | |||
| using namespace opr; | |||
| using namespace intl; | |||
| namespace { | |||
| template<class MegDNNOpr> | |||
| struct RNGName; | |||
| template<> | |||
| struct RNGName<megdnn::UniformRNG> { | |||
| static constexpr const char* name = "uniform_rng"; | |||
| }; | |||
| template<> | |||
| struct RNGName<megdnn::GaussianRNG> { | |||
| static constexpr const char* name = "gaussian_rng"; | |||
| }; | |||
| } // anonymous namespace | |||
| RNGOprBase::RNGOprBase(const OperatorNodeBaseCtorParam &opr, VarNode *shape): | |||
| Super(opr) | |||
| template<typename MegDNNOpr> | |||
| RNGOprBase<MegDNNOpr>::RNGOprBase(const OperatorNodeBaseCtorParam &opr, const Param ¶m): | |||
| Super(opr),m_param(param) | |||
| { | |||
| add_input({shape}); | |||
| add_output(None)->dtype(dtype::Float32()); | |||
| cg::add_workspace_output(this); | |||
| // disable dedup | |||
| add_equivalence_component<ScalarHash<void*>>(this); | |||
| } | |||
| RNGOprBase::~RNGOprBase() { | |||
| } | |||
| cg::OperatorNodeBase::NodeProp* RNGOprBase::do_make_node_prop() const { | |||
| auto prop = Super::do_make_node_prop(); | |||
| prop->add_flag(NodeProp::Flag::IMPURE_FUNC); | |||
| prop->reset_dep_type(input(), {NodeProp::DepType::HOST_VALUE}); | |||
| return prop; | |||
| template<class MegDNNOpr> | |||
| UniqPtrWithCN<MegDNNOpr> RNGOprBase<MegDNNOpr>::create_megdnn_opr() { | |||
| auto opr = intl::create_megdnn_opr<MegDNNOpr>(comp_node()); | |||
| opr->param() = param(); | |||
| return opr; | |||
| } | |||
| void RNGOprBase::ensure_megdnn_opr() { | |||
| template<typename MegDNNOpr> | |||
| void RNGOprBase<MegDNNOpr>::ensure_megdnn_opr() { | |||
| if (!m_dnn_opr || m_dnn_opr.comp_node() != comp_node()) { | |||
| // activate comp_node for curandCreateGenerator in create_megdnn_opr | |||
| comp_node().activate(); | |||
| @@ -66,53 +41,120 @@ void RNGOprBase::ensure_megdnn_opr() { | |||
| } | |||
| } | |||
| void RNGOprBase::init_output_static_infer_desc() { | |||
| using namespace cg::static_infer; | |||
| auto &&mgr = owner_graph()->static_infer_manager(); | |||
| auto infer_out = [](TensorShape &dest, const InpVal &inp) { | |||
| cg::copy_tensor_value_to_shape(dest, inp.val.at(0).value()); | |||
| return true; | |||
| }; | |||
| auto infer_wk = [this](TensorShape &dest, const InpVal &inp) { | |||
| ensure_megdnn_opr(); | |||
| dest.ndim = 1; | |||
| dest.shape[0] = m_dnn_opr->get_workspace_in_bytes( | |||
| {inp.val.at(0).shape(), output(0)->dtype()}); | |||
| return true; | |||
| }; | |||
| mgr.register_shape_infer(output(0), | |||
| {SourceType::DEP, {{input(0), DepType::VALUE}}, infer_out}); | |||
| mgr.register_shape_infer(output(1), | |||
| {SourceType::DEP, {{output(0), DepType::SHAPE}}, infer_wk}); | |||
| /* ================= RNG with shape ================= */ | |||
| #define _INST_RNG_OPR_WITH_SHAPE(RNGOpr, name) \ | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(RNGOpr); \ | |||
| cg::OperatorNodeBase::NodeProp* RNGOpr::do_make_node_prop() const { \ | |||
| auto prop = Super::do_make_node_prop(); \ | |||
| prop->add_flag(NodeProp::Flag::IMPURE_FUNC); \ | |||
| prop->reset_dep_type(input(), {NodeProp::DepType::HOST_VALUE}); \ | |||
| return prop; \ | |||
| } \ | |||
| RNGOpr::RNGOpr(VarNode *shape, const Param ¶m, \ | |||
| const OperatorNodeConfig &config): \ | |||
| Super({shape->owner_graph(), config, (name), {shape}}, param) \ | |||
| { \ | |||
| DType dtype = DType::from_enum(param.dtype); \ | |||
| add_input({shape}); \ | |||
| add_output(None)->dtype(dtype); \ | |||
| cg::add_workspace_output(this); \ | |||
| add_equivalence_component<ScalarHash<void*>>(this); \ | |||
| } \ | |||
| SymbolVar RNGOpr::make(SymbolVar shape, const Param ¶m, \ | |||
| const OperatorNodeConfig &config){ \ | |||
| return shape.insert_single_output_opr<RNGOpr>(shape.node(), param, config); \ | |||
| } \ | |||
| void RNGOpr::init_output_static_infer_desc() { \ | |||
| using namespace cg::static_infer; \ | |||
| auto &&mgr = owner_graph()->static_infer_manager(); \ | |||
| auto infer_out = [](TensorShape &dest, const InpVal &inp) { \ | |||
| cg::copy_tensor_value_to_shape(dest, inp.val.at(0).value()); \ | |||
| return true; \ | |||
| }; \ | |||
| auto infer_wk = [this](TensorShape &dest, const InpVal &inp) { \ | |||
| ensure_megdnn_opr(); \ | |||
| dest.ndim = 1; \ | |||
| dest.shape[0] = m_dnn_opr->get_workspace_in_bytes( \ | |||
| {inp.val.at(0).shape(), output(0)->dtype()}); \ | |||
| return true; \ | |||
| }; \ | |||
| mgr.register_shape_infer(output(0), \ | |||
| {SourceType::DEP, {{input(0), DepType::VALUE}}, infer_out}); \ | |||
| mgr.register_shape_infer(output(1), \ | |||
| {SourceType::DEP, {{output(0), DepType::SHAPE}}, infer_wk}); \ | |||
| } \ | |||
| void RNGOpr::scn_do_execute() { \ | |||
| m_dnn_opr->exec(output(0)->dev_tensor().as_megdnn(), \ | |||
| get_megdnn_workspace_from_var(output(1))); \ | |||
| } | |||
| void RNGOprBase::scn_do_execute() { | |||
| m_dnn_opr->exec( | |||
| output(0)->dev_tensor().as_megdnn(), | |||
| get_megdnn_workspace_from_var(output(1))); | |||
| } | |||
| template<class MegDNNOpr> | |||
| RNGOpr<MegDNNOpr>::RNGOpr(VarNode *shape, const Param ¶m, | |||
| const OperatorNodeConfig &config): | |||
| Super({shape->owner_graph(), config, RNGName<MegDNNOpr>::name, {shape}}, | |||
| shape), | |||
| m_param(param) | |||
| { | |||
| } | |||
| template<class MegDNNOpr> | |||
| SymbolVar RNGOpr<MegDNNOpr>::make(SymbolVar shape, const Param ¶m, | |||
| const OperatorNodeConfig &config) { | |||
| return shape.insert_single_output_opr<RNGOpr>(shape.node(), param, config); | |||
| } | |||
| template<class MegDNNOpr> | |||
| UniqPtrWithCN<megdnn::RNGBase> RNGOpr<MegDNNOpr>::create_megdnn_opr() { | |||
| auto opr = intl::create_megdnn_opr<MegDNNOpr>(comp_node()); | |||
| opr->param() = param(); | |||
| return opr; | |||
| } | |||
| _INST_RNG_OPR_WITH_SHAPE(UniformRNG, "uniform_rng") | |||
| _INST_RNG_OPR_WITH_SHAPE(GaussianRNG, "gaussian_rng") | |||
| _INST_RNG_OPR_WITH_SHAPE(PermutationRNG, "permutation_rng") | |||
| #undef _INST_RNG_OPR_WITH_SHAPE | |||
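A brief usage sketch of the shape-driven oprs instantiated above, mirroring the tests later in this diff; the shape var is consumed by host value (see the `HOST_VALUE` dep in `do_make_node_prop`):

```cpp
auto graph = ComputingGraph::make();
auto shp = cg::var_from_tensor_shape(
        *graph, {CompNode::load("xpu0")}, "shp", {32, 32});
// seed 23, output dtype Float32 (the new param field added by this change)
auto u = opr::UniformRNG::make(shp, {23, DTypeEnum::Float32});
```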
| /* ================= RNG with input ================= */ | |||
| #define _AS_MEGDNN(idx) input((idx))->dev_tensor().as_megdnn() | |||
| #define _INFER_WK_DEPS(idx) {input((idx)), DepType::SHAPE} | |||
| #define _INFER_WK_ARGS(idx) {inp.val.at((idx)).shape(), input((idx))->dtype()} | |||
| #define _INST_RNG_OPR_WITH_INPUT(RNGOpr, name) \ | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(RNGOpr); \ | |||
| RNGOpr::RNGOpr(_INPUTS(VarNode*,), const Param ¶m, \ | |||
| const OperatorNodeConfig &config): \ | |||
| Super({i0->owner_graph(), config, (name), {_INPUTS(,)}}, param) \ | |||
| { \ | |||
| add_input({_INPUTS(,)}); \ | |||
| add_output(None)->dtype(i0->dtype()); \ | |||
| cg::add_workspace_output(this); \ | |||
| add_equivalence_component<ScalarHash<void*>>(this); \ | |||
| } \ | |||
| SymbolVar RNGOpr::make(_INPUTS(SymbolVar,), const Param ¶m, \ | |||
| const OperatorNodeConfig &config){ \ | |||
| return i0.insert_single_output_opr<RNGOpr>(_INPUTS(,.node()), param, config); \ | |||
| } \ | |||
| void RNGOpr::init_output_static_infer_desc() { \ | |||
| using namespace cg::static_infer; \ | |||
| auto &&mgr = owner_graph()->static_infer_manager(); \ | |||
| auto infer_wk = [this](TensorShape &dest, const InpVal &inp) { \ | |||
| ensure_megdnn_opr(); \ | |||
| dest.ndim = 1; \ | |||
| dest.shape[0] = m_dnn_opr->get_workspace_in_bytes( \ | |||
| _FOR_EACH(_INFER_WK_ARGS), \ | |||
| {output(0)->shape(), output(0)->dtype()}); \ | |||
| return true; \ | |||
| }; \ | |||
| mgr.register_shape_infer(output(0),ShapeInferDesc::make_identity(input(0))); \ | |||
| mgr.register_shape_infer(output(1),{SourceType::DEP, {_FOR_EACH(_INFER_WK_DEPS)}, \ | |||
| infer_wk}); \ | |||
| } \ | |||
| void RNGOpr::add_input_layout_constraint(){ \ | |||
| for (auto i : input()) i->add_layout_constraint_contiguous(); \ | |||
| }; \ | |||
| void RNGOpr::scn_do_execute() { \ | |||
| m_dnn_opr->exec(_FOR_EACH(_AS_MEGDNN),output(0)->dev_tensor().as_megdnn(), \ | |||
| get_megdnn_workspace_from_var(output(1))); \ | |||
| } | |||
| /* ================= 1 input ================= */ | |||
| #define _INPUTS(prefix, suffix) prefix i0 suffix | |||
| #define _FOR_EACH(cb) cb(0) | |||
| _INST_RNG_OPR_WITH_INPUT(PoissonRNG, "poisson_rng") | |||
| #undef _INPUTS | |||
| #undef _FOR_EACH | |||
| /* ================= 2 input ================= */ | |||
| #define _INPUTS(prefix, suffix) prefix i0 suffix, prefix i1 suffix | |||
| #define _FOR_EACH(cb) cb(0), cb(1) | |||
| _INST_RNG_OPR_WITH_INPUT(BetaRNG, "beta_rng") | |||
| _INST_RNG_OPR_WITH_INPUT(GammaRNG, "gamma_rng") | |||
| #undef _INPUTS | |||
| #undef _FOR_EACH | |||
| #undef _AS_MEGDNN | |||
| #undef _INFER_WK_DEPS | |||
| #undef _INFER_WK_ARGS | |||
| #undef _INST_RNG_OPR_WITH_INPUT | |||
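The input-driven oprs take the distribution parameters as device tensors, and the output inherits shape and dtype from the first input via `ShapeInferDesc::make_identity(input(0))`. A sketch using the same symbols as the tests below (`lam_host`, `alpha_sym`, etc. are the tests' names):

```cpp
auto lam_sym = opr::Host2DeviceCopy::make(*graph, lam_host);   // Float32 rates
auto poisson = opr::PoissonRNG::make(lam_sym, {10});           // seed 10
auto beta    = opr::BetaRNG::make(alpha_sym, beta_sym, {10});
auto gamma   = opr::GammaRNG::make(shape_sym, scale_sym, {10});
```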
| #define IMPL(_cls) \ | |||
| MGB_IMPL_OPR_GRAD(_cls) { \ | |||
| @@ -123,13 +165,21 @@ UniqPtrWithCN<megdnn::RNGBase> RNGOpr<MegDNNOpr>::create_megdnn_opr() { | |||
| namespace mgb { | |||
| namespace opr { | |||
| namespace intl { | |||
| template class RNGOpr<::megdnn::GaussianRNG>; | |||
| template class RNGOpr<::megdnn::UniformRNG>; | |||
| template class RNGOprBase<::megdnn::GaussianRNG>; | |||
| template class RNGOprBase<::megdnn::UniformRNG>; | |||
| template class RNGOprBase<::megdnn::GammaRNG>; | |||
| template class RNGOprBase<::megdnn::PermutationRNG>; | |||
| template class RNGOprBase<::megdnn::BetaRNG>; | |||
| template class RNGOprBase<::megdnn::PoissonRNG>; | |||
| #if MGB_ENABLE_GRAD | |||
| IMPL(GaussianRNG); | |||
| IMPL(UniformRNG); | |||
| IMPL(GammaRNG); | |||
| IMPL(PoissonRNG); | |||
| IMPL(PermutationRNG); | |||
| IMPL(BetaRNG); | |||
| #endif | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -17,6 +17,10 @@ namespace opr { | |||
| MGB_SEREG_OPR(UniformRNG, 1); | |||
| MGB_SEREG_OPR(GaussianRNG, 1); | |||
| MGB_SEREG_OPR(GammaRNG, 2); | |||
| MGB_SEREG_OPR(PoissonRNG, 1); | |||
| MGB_SEREG_OPR(PermutationRNG, 1); | |||
| MGB_SEREG_OPR(BetaRNG, 2); | |||
| } // namespace opr | |||
| } // namespace mgb | |||
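The numeric argument to MGB_SEREG_OPR records the operator's input arity for (de)serialization: one var (the shape) for the shape-driven RNGs, two parameter tensors for GammaRNG and BetaRNG.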
| @@ -14,7 +14,6 @@ | |||
| #include "megbrain/graph.h" | |||
| #include "megbrain/opr/internal/out_shape_by_sym_var.h" | |||
| #include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||
| #include "megdnn/oprs.h" | |||
| namespace mgb { | |||
| @@ -22,60 +21,81 @@ namespace opr { | |||
| namespace intl { | |||
| template<typename MegDNNOpr> | |||
| MGB_DEFINE_CLS_WITH_SUPER(RNGOprBase, cg::SingleCNOperatorNodeBase) // { | |||
| UniqPtrWithCN<megdnn::RNGBase> m_dnn_opr; | |||
| void ensure_megdnn_opr(); | |||
| void init_output_static_infer_desc() override; | |||
| void scn_do_execute() override final; | |||
| protected: | |||
| RNGOprBase(const OperatorNodeBaseCtorParam &opr, VarNode *shape); | |||
| ~RNGOprBase(); | |||
| NodeProp* do_make_node_prop() const override; | |||
| virtual UniqPtrWithCN<megdnn::RNGBase> create_megdnn_opr() = 0; | |||
| }; | |||
| template<class MegDNNOpr> | |||
| MGB_DEFINE_OPR_CLASS(RNGOpr, RNGOprBase) // { | |||
| public: | |||
| using Param = typename MegDNNOpr::Param; | |||
| RNGOpr(VarNode *shape, const Param ¶m, | |||
| const OperatorNodeConfig &config); | |||
| static SymbolVar make(SymbolVar shape, const Param ¶m = {}, | |||
| const OperatorNodeConfig &config = {}); | |||
| static SymbolVar make(ComputingGraph &graph, const TensorShape &shape, | |||
| const OperatorNodeConfig &config, | |||
| const Param ¶m = {}) { | |||
| return make(var_from_tensor_shape(graph, config, "rng", shape), | |||
| param, config); | |||
| } | |||
| const Param& param() const { | |||
| return m_param; | |||
| } | |||
| private: | |||
| Param m_param; | |||
| UniqPtrWithCN<megdnn::RNGBase> create_megdnn_opr() override; | |||
| UniqPtrWithCN<MegDNNOpr> create_megdnn_opr(); | |||
| protected: | |||
| ~RNGOprBase(){}; | |||
| RNGOprBase(const OperatorNodeBaseCtorParam &opr, const Param ¶m); | |||
| void ensure_megdnn_opr(); | |||
| UniqPtrWithCN<MegDNNOpr> m_dnn_opr; | |||
| }; | |||
| /* ================= RNG with shape ================= */ | |||
| #define _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(RNG) \ | |||
| MGB_DEFINE_OPR_CLASS(RNG,RNGOprBase<megdnn::RNG>) \ | |||
| cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override; \ | |||
| public: \ | |||
| RNG(VarNode *shape, const Param ¶m, const OperatorNodeConfig &config); \ | |||
| static SymbolVar make(SymbolVar shape, const Param ¶m = {}, \ | |||
| const OperatorNodeConfig &config = {}); \ | |||
| static SymbolVar make(ComputingGraph &graph, const TensorShape &shape, \ | |||
| const OperatorNodeConfig &config, \ | |||
| const Param ¶m = {}) { \ | |||
| return make(var_from_tensor_shape(graph, config, "rng", shape), \ | |||
| param, config); \ | |||
| } \ | |||
| void init_output_static_infer_desc() override; \ | |||
| void scn_do_execute() override; \ | |||
| }; | |||
| #undef _MGB_DYN_TYPE_OBJ_FINAL_IMPL_TPL | |||
| #define _MGB_DYN_TYPE_OBJ_FINAL_IMPL_TPL template<class MegDNNOpr> | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(RNGOpr<MegDNNOpr>); | |||
| #undef _MGB_DYN_TYPE_OBJ_FINAL_IMPL_TPL | |||
| #define _MGB_DYN_TYPE_OBJ_FINAL_IMPL_TPL | |||
| _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(UniformRNG) | |||
| _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(GaussianRNG) | |||
| _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(PermutationRNG) | |||
| #undef _DEFINE_RNG_OPR_WITH_SHAPE_CLASS | |||
| /* ================= RNG with input ================= */ | |||
| #define _DEFINE_RNG_OPR_WITH_INPUT_CLASS(RNG) \ | |||
| MGB_DEFINE_OPR_CLASS(RNG, RNGOprBase<megdnn::RNG>) \ | |||
| void add_input_layout_constraint() override; \ | |||
| public: \ | |||
| RNG(_INPUTS(VarNode*), const Param ¶m, \ | |||
| const OperatorNodeConfig &config); \ | |||
| static SymbolVar make(_INPUTS(SymbolVar),const Param ¶m = {}, \ | |||
| const OperatorNodeConfig &config = {}); \ | |||
| void init_output_static_infer_desc() override; \ | |||
| void scn_do_execute() override; \ | |||
| }; | |||
| } // intl | |||
| /* ================= 1 input ================= */ | |||
| #define _INPUTS(prefix) prefix i0 | |||
| _DEFINE_RNG_OPR_WITH_INPUT_CLASS(PoissonRNG) | |||
| #undef _INPUTS | |||
| using UniformRNG = intl::RNGOpr<megdnn::UniformRNG>; | |||
| using GaussianRNG = intl::RNGOpr<megdnn::GaussianRNG>; | |||
| /* ================= 2 input ================= */ | |||
| #define _INPUTS(prefix) prefix i0, prefix i1 | |||
| _DEFINE_RNG_OPR_WITH_INPUT_CLASS(BetaRNG) | |||
| _DEFINE_RNG_OPR_WITH_INPUT_CLASS(GammaRNG) | |||
| #undef _INPUTS | |||
| #undef _DEFINE_RNG_OPR_WITH_INPUT_CLASS | |||
| } // intl | |||
| using UniformRNG = intl::UniformRNG; | |||
| using GaussianRNG = intl::GaussianRNG; | |||
| using GammaRNG = intl::GammaRNG; | |||
| using PermutationRNG = intl::PermutationRNG; | |||
| using PoissonRNG = intl::PoissonRNG; | |||
| using BetaRNG = intl::BetaRNG; | |||
| } // namespace opr | |||
| } // namespace mgb | |||
| @@ -19,84 +19,76 @@ | |||
| using namespace mgb; | |||
| namespace { | |||
| struct BasicStat { | |||
| double mean, std, min, max; | |||
| static BasicStat make(const float *ptr, size_t size, | |||
| double mean_expect = 0) { | |||
| double sum = 0, sum2 = 0, | |||
| min = std::numeric_limits<double>::max(), | |||
| max = std::numeric_limits<double>::lowest(); | |||
| for (size_t i = 0; i < size; ++ i) { | |||
| double cur = ptr[i]; | |||
| min = std::min(min, cur); | |||
| max = std::max(max, cur); | |||
| cur -= mean_expect; | |||
| sum += cur; | |||
| sum2 += cur * cur; | |||
| } | |||
| double mean = sum / size + mean_expect, | |||
| std = sqrt((sum2 - sum * sum / size) / (size - 1)); | |||
| return {mean, std, min, max}; | |||
| struct BasicStat { | |||
| double mean, std, min, max; | |||
| static BasicStat make(const float* ptr, size_t size, | |||
| double mean_expect = 0) { | |||
| double sum = 0, sum2 = 0, min = std::numeric_limits<double>::max(), | |||
| max = std::numeric_limits<double>::lowest(); | |||
| for (size_t i = 0; i < size; ++i) { | |||
| double cur = ptr[i]; | |||
| min = std::min(min, cur); | |||
| max = std::max(max, cur); | |||
| cur -= mean_expect; | |||
| sum += cur; | |||
| sum2 += cur * cur; | |||
| } | |||
| }; | |||
| void check_reproducibility( | |||
| thin_function<SymbolVar(SymbolVar, uint64_t seed)> make) { | |||
| auto graph = ComputingGraph::make(); | |||
| constexpr size_t SIZE = 123; | |||
| // out[func][opr][run] | |||
| HostTensorND out[2][2][2]; | |||
| auto run = [&](int fid) { | |||
| SymbolVar | |||
| o0 = make(cg::var_from_tensor_shape(*graph, | |||
| {CompNode::load("xpu0")}, "shp0", {SIZE}), 0), | |||
| o1 = make(cg::var_from_tensor_shape(*graph, | |||
| {CompNode::load("xpu0")}, "shp0", {SIZE}), 1); | |||
| HostTensorND host_o0, host_o1; | |||
| auto func = graph->compile({ | |||
| make_callback_copy(o0, host_o0), | |||
| make_callback_copy(o1, host_o1)}); | |||
| for (int i = 0; i < 2; ++ i) { | |||
| func->execute(); | |||
| out[fid][0][i].copy_from(host_o0); | |||
| out[fid][1][i].copy_from(host_o1); | |||
| } | |||
| }; | |||
| run(0); | |||
| run(1); | |||
| for (int i = 0; i < 2; ++ i) { | |||
| for (int j = 0; j < 2; ++ j) | |||
| MGB_ASSERT_TENSOR_EQ(out[0][i][j], out[1][i][j]); | |||
| double mean = sum / size + mean_expect, | |||
| std = sqrt((sum2 - sum * sum / size) / (size - 1)); | |||
| return {mean, std, min, max}; | |||
| } | |||
| }; | |||
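`BasicStat::make` accumulates over data shifted by `mean_expect`, the classic trick for keeping single-pass variance numerically stable when the expected mean is far from zero. With $d_i = x_i - \mu_0$ it computes

$$\text{mean} = \frac{1}{n}\sum_i x_i, \qquad \text{std} = \sqrt{\frac{\sum_i d_i^2 - \big(\sum_i d_i\big)^2/n}{n-1}},$$

which the tests then compare against the closed-form moments of each distribution.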
| void check_reproducibility(std::shared_ptr<ComputingGraph> graph, size_t size, | |||
| thin_function<SymbolVar(uint64_t seed)> make) { | |||
| // out[func][opr][run] | |||
| HostTensorND out[2][2][2]; | |||
| auto run = [&](int fid) { | |||
| SymbolVar o0 = make(0), o1 = make(1); | |||
| HostTensorND host_o0, host_o1; | |||
| auto func = graph->compile({make_callback_copy(o0, host_o0), | |||
| make_callback_copy(o1, host_o1)}); | |||
| for (int i = 0; i < 2; ++i) { | |||
| func->execute(); | |||
| out[fid][0][i].copy_from(host_o0); | |||
| out[fid][1][i].copy_from(host_o1); | |||
| } | |||
| }; | |||
| run(0); | |||
| run(1); | |||
| auto max_diff = [&](int off0, int off1) { | |||
| float diff = 0; | |||
| auto p0 = out[0][off0 / 2][off0 % 2].ptr<float>(), | |||
| p1 = out[0][off1 / 2][off1 % 2].ptr<float>(); | |||
| for (size_t i = 0; i < SIZE; ++ i) { | |||
| update_max(diff, std::abs(p0[i] - p1[i])); | |||
| } | |||
| return diff; | |||
| }; | |||
| for (int i = 0; i < 4; ++ i) { | |||
| for (int j = i + 1; j < 4; ++ j) | |||
| ASSERT_GT(max_diff(i, j), 0.3) << i << " " << j; | |||
| for (int i = 0; i < 2; ++i) { | |||
| for (int j = 0; j < 2; ++j) | |||
| MGB_ASSERT_TENSOR_EQ(out[0][i][j], out[1][i][j]); | |||
| } | |||
| auto max_diff = [&](int off0, int off1) { | |||
| float diff = 0; | |||
| auto p0 = out[0][off0 / 2][off0 % 2].ptr<float>(), | |||
| p1 = out[0][off1 / 2][off1 % 2].ptr<float>(); | |||
| for (size_t i = 0; i < size; ++i) { | |||
| update_max(diff, std::abs(p0[i] - p1[i])); | |||
| } | |||
| return diff; | |||
| }; | |||
| for (int i = 0; i < 4; ++i) { | |||
| for (int j = i + 1; j < 4; ++j) | |||
| ASSERT_GT(max_diff(i, j), 0.3) << i << " " << j; | |||
| } | |||
| } | |||
| } // anonymous namespace | |||
| TEST(TestOprRand, Uniform) { | |||
| static constexpr size_t M = 128, N = 64; | |||
| auto graph = ComputingGraph::make(); | |||
| SymbolVar dev_out = opr::UniformRNG::make( | |||
| *graph, {M, N}, {CompNode::load("xpu0")}); | |||
| *graph, {M, N}, {CompNode::load("xpu0")}, {23, DTypeEnum::Float32}); | |||
| HostTensorND host_out; | |||
| auto func = graph->compile({make_callback_copy(dev_out, host_out)}); | |||
| @@ -115,9 +107,10 @@ TEST(TestOprRand, Gaussian) { | |||
| static constexpr size_t SIZE = 123451; | |||
| constexpr float MEAN = 1, STD = 2; | |||
| auto graph = ComputingGraph::make(); | |||
| auto y = opr::GaussianRNG::make( | |||
| SymbolVar::make_scalar(int(SIZE), *graph, {CompNode::load("xpu0")}), | |||
| {23, MEAN, STD}); | |||
| {23, MEAN, STD, DTypeEnum::Float32}); | |||
| HostTensorND host_y; | |||
| auto func = graph->compile({make_callback_copy(y, host_y)}); | |||
| @@ -130,17 +123,212 @@ TEST(TestOprRand, Gaussian) { | |||
| ASSERT_LT(fabs(stat.std - STD), 0.1); | |||
| } | |||
| TEST(TestOprRand, Gamma) { | |||
| std::shared_ptr<HostTensorND> shape_host(new HostTensorND{ | |||
| CompNode::load("xpux"), TensorShape{2000000*5}, dtype::Float32()}); | |||
| std::shared_ptr<HostTensorND> scale_host(new HostTensorND{ | |||
| CompNode::load("xpux"), TensorShape{2000000*5}, dtype::Float32()}); | |||
| auto shape_ptr = shape_host->ptr<float>(); | |||
| auto scale_ptr = scale_host->ptr<float>(); | |||
| for (int i = 0; i < 5; ++i) { | |||
| for (int j = 0; j < 2000000; ++j) { | |||
| shape_ptr[i * 2000000 + j] = 2 * 0.3 * i + 0.5; | |||
| scale_ptr[i * 2000000 + j] = i * 0.3 + 0.5; | |||
| } | |||
| } | |||
| auto graph = ComputingGraph::make(); | |||
| auto shape_sym = opr::Host2DeviceCopy::make(*graph, shape_host); | |||
| auto scale_sym = opr::Host2DeviceCopy::make(*graph, scale_host); | |||
| auto y = opr::GammaRNG::make(shape_sym, scale_sym, {10}); | |||
| HostTensorND host_y; | |||
| auto func = graph->compile({make_callback_copy(y, host_y)}); | |||
| func->execute(); | |||
| ASSERT_EQ(TensorShape({2000000*5}), host_y.shape()); | |||
| for (int i = 0; i < 5; ++i) { | |||
| float a = 2 * 0.3 * i + 0.5, b = i * 0.3 + 0.5; | |||
| float mean = a * b; | |||
| float var = a * (b * b); | |||
| auto stat = BasicStat::make(host_y.ptr<float>() + 2000000 * i, | |||
| 2000000, mean); | |||
| ASSERT_LT(fabs(stat.mean - mean), 0.01); | |||
| ASSERT_LT(fabs(stat.std - sqrt(var)), 0.01); | |||
| } | |||
| } | |||
| TEST(TestOprRand, Poisson) { | |||
| std::shared_ptr<HostTensorND> lam_host(new HostTensorND{ | |||
| CompNode::load("xpux"), TensorShape{200000*5}, dtype::Float32()}); | |||
| auto lam_ptr = lam_host->ptr<float>(); | |||
| for (int i = 0; i < 5; ++i) { | |||
| for (int j = 0; j < 200000; ++j) { | |||
| lam_ptr[i * 200000 + j] = i + 1; | |||
| } | |||
| } | |||
| auto graph = ComputingGraph::make(); | |||
| auto lam_sym = opr::Host2DeviceCopy::make(*graph, lam_host); | |||
| auto y = opr::PoissonRNG::make(lam_sym, {10}); | |||
| HostTensorND host_y; | |||
| auto func = graph->compile({make_callback_copy(y, host_y)}); | |||
| func->execute(); | |||
| ASSERT_EQ(TensorShape({200000*5}), host_y.shape()); | |||
| for (int i = 0; i < 5; ++i) { | |||
| float lambda = i + 1; | |||
| auto stat = BasicStat::make(host_y.ptr<float>() + 200000 * i, | |||
| 200000, lambda); | |||
| ASSERT_LT(fabs(stat.mean - lambda), 0.01); | |||
| ASSERT_LT(fabs(stat.std - sqrt(lambda)), 0.1); | |||
| } | |||
| } | |||
| TEST(TestOprRand, Beta) { | |||
| std::shared_ptr<HostTensorND> alpha_host(new HostTensorND{ | |||
| CompNode::load("xpux"), TensorShape{200000*5}, dtype::Float32()}); | |||
| std::shared_ptr<HostTensorND> beta_host(new HostTensorND{ | |||
| CompNode::load("xpux"), TensorShape{200000*5}, dtype::Float32()}); | |||
| auto alpha_ptr = alpha_host->ptr<float>(); | |||
| auto beta_ptr = beta_host->ptr<float>(); | |||
| for (int i = 0; i < 5; ++i) { | |||
| for (int j = 0; j < 200000; ++j) { | |||
| alpha_ptr[i * 200000 + j] = 0.3 * i + 0.1; | |||
| beta_ptr[i * 200000 + j] = 2 * i * 0.3 + 0.1; | |||
| } | |||
| } | |||
| auto graph = ComputingGraph::make(); | |||
| auto alpha_sym = opr::Host2DeviceCopy::make(*graph, alpha_host); | |||
| auto beta_sym = opr::Host2DeviceCopy::make(*graph, beta_host); | |||
| auto y = opr::BetaRNG::make(alpha_sym, beta_sym, {10}); | |||
| HostTensorND host_y; | |||
| auto func = graph->compile({make_callback_copy(y, host_y)}); | |||
| func->execute(); | |||
| ASSERT_EQ(TensorShape({200000*5}), host_y.shape()); | |||
| for (int i = 0; i < 5; ++i) { | |||
| float a = 0.3 * i + 0.1, b = 2 * i * 0.3 + 0.1; | |||
| float mean = a / (a + b); | |||
| float var = a * b / ((a + b) * (a + b) * (a + b + 1)); | |||
| auto stat = BasicStat::make(host_y.ptr<float>() + 200000 * i, | |||
| 200000, mean); | |||
| ASSERT_LT(fabs(stat.mean - mean), 0.01); | |||
| ASSERT_LT(fabs(stat.std - sqrt(var)), 0.01); | |||
| } | |||
| } | |||
| TEST(TestOprRand, PermutationRNG) { | |||
| static constexpr size_t SIZE = 123451; | |||
| auto graph = ComputingGraph::make(); | |||
| auto y = opr::PermutationRNG::make( | |||
| SymbolVar::make_scalar(int(SIZE), *graph, {CompNode::load("xpu0")}), | |||
| {23, DTypeEnum::Int32}); | |||
| HostTensorND host_y; | |||
| auto func = graph->compile({make_callback_copy(y, host_y)}); | |||
| func->execute(); | |||
| ASSERT_EQ(TensorShape({SIZE}), host_y.shape()); | |||
| auto ptr = host_y.ptr<int32_t>(); | |||
| std::vector<int32_t> res(SIZE); | |||
| int not_same = 0; | |||
| for (size_t i = 0; i < SIZE; ++i) { | |||
| if (ptr[i] != int32_t(i)) not_same++; | |||
| res[i] = ptr[i]; | |||
| } | |||
| ASSERT_GT(not_same, 5000); | |||
| std::sort(res.begin(), res.end()); | |||
| for (size_t i = 0; i < SIZE; ++i) { | |||
| ASSERT_EQ(res[i], int32_t(i)); | |||
| } | |||
| } | |||
| TEST(TestOprRand, UniformReprod) { | |||
| check_reproducibility([](SymbolVar shp, uint64_t seed) { | |||
| static constexpr size_t SIZE = 123; | |||
| auto graph = ComputingGraph::make(); | |||
| auto shp = cg::var_from_tensor_shape(*graph, {CompNode::load("xpu0")}, | |||
| "shp0", {SIZE}); | |||
| check_reproducibility(graph, SIZE, [&shp](uint64_t seed) { | |||
| return opr::UniformRNG::make(shp, {seed}); | |||
| }); | |||
| } | |||
| TEST(TestOprRand, GaussianReprod) { | |||
| check_reproducibility([](SymbolVar shp, uint64_t seed) { | |||
| static constexpr size_t SIZE = 123; | |||
| auto graph = ComputingGraph::make(); | |||
| auto shp = cg::var_from_tensor_shape(*graph, {CompNode::load("xpu0")}, | |||
| "shp0", {SIZE}); | |||
| check_reproducibility(graph, SIZE, [&shp](uint64_t seed) { | |||
| return opr::GaussianRNG::make(shp, {seed}); | |||
| }); | |||
| } | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| TEST(TestOprRand, GammaReprod) { | |||
| static constexpr size_t SIZE = 123; | |||
| std::shared_ptr<HostTensorND> shape_host(new HostTensorND{ | |||
| CompNode::load("xpux"), TensorShape{SIZE}, dtype::Float32()}); | |||
| std::shared_ptr<HostTensorND> scale_host(new HostTensorND{ | |||
| CompNode::load("xpux"), TensorShape{SIZE}, dtype::Float32()}); | |||
| auto shape_ptr = shape_host->ptr<float>(); | |||
| auto scale_ptr = scale_host->ptr<float>(); | |||
| for (size_t i = 0; i < SIZE; ++i){ | |||
| shape_ptr[i] = 0.5; | |||
| scale_ptr[i] = 1.2; | |||
| } | |||
| auto graph = ComputingGraph::make(); | |||
| auto shape_sym = opr::Host2DeviceCopy::make(*graph, shape_host); | |||
| auto scale_sym = opr::Host2DeviceCopy::make(*graph, scale_host); | |||
| check_reproducibility(graph, SIZE, [&shape_sym,&scale_sym](uint64_t seed) { | |||
| return opr::GammaRNG::make(shape_sym, scale_sym, {seed}); | |||
| }); | |||
| } | |||
| TEST(TestOprRand, PoissonReprod) { | |||
| static constexpr size_t SIZE = 123; | |||
| std::shared_ptr<HostTensorND> lam_host(new HostTensorND{ | |||
| CompNode::load("xpux"), TensorShape{SIZE}, dtype::Float32()}); | |||
| auto lam_ptr = lam_host->ptr<float>(); | |||
| for (size_t i = 0; i < SIZE; ++i) | |||
| lam_ptr[i] = 2; | |||
| auto graph = ComputingGraph::make(); | |||
| auto lam_sym = opr::Host2DeviceCopy::make(*graph, lam_host); | |||
| check_reproducibility(graph, SIZE, [&lam_sym](uint64_t seed) { | |||
| return opr::PoissonRNG::make(lam_sym, {seed}); | |||
| }); | |||
| } | |||
| TEST(TestOprRand, BetaReprod) { | |||
| static constexpr size_t SIZE = 123; | |||
| std::shared_ptr<HostTensorND> alpha_host(new HostTensorND{ | |||
| CompNode::load("xpux"), TensorShape{SIZE}, dtype::Float32()}); | |||
| std::shared_ptr<HostTensorND> beta_host(new HostTensorND{ | |||
| CompNode::load("xpux"), TensorShape{SIZE}, dtype::Float32()}); | |||
| auto alpha_ptr = alpha_host->ptr<float>(); | |||
| auto beta_ptr = beta_host->ptr<float>(); | |||
| for (size_t i = 0; i < SIZE; ++i){ | |||
| alpha_ptr[i] = 0.5; | |||
| beta_ptr[i] = 1.2; | |||
| } | |||
| auto graph = ComputingGraph::make(); | |||
| auto alpha_sym = opr::Host2DeviceCopy::make(*graph, alpha_host); | |||
| auto beta_sym = opr::Host2DeviceCopy::make(*graph, beta_host); | |||
| check_reproducibility(graph, SIZE, [&alpha_sym,&beta_sym](uint64_t seed) { | |||
| return opr::BetaRNG::make(alpha_sym, beta_sym, {seed}); | |||
| }); | |||
| } | |||
| TEST(TestOprRand, PermutationReprod) { | |||
| static constexpr size_t SIZE = 123; | |||
| auto graph = ComputingGraph::make(); | |||
| auto shp = cg::var_from_tensor_shape(*graph, {CompNode::load("xpu0")}, | |||
| "shp0", {SIZE}); | |||
| check_reproducibility(graph, SIZE, [&shp](uint64_t seed) { | |||
| return opr::PermutationRNG::make(shp, {seed, DTypeEnum::Float32}); | |||
| }); | |||
| } | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -108,6 +108,10 @@ union OperatorParam { | |||
| param.TQT = 74, | |||
| param.Correlation = 75, | |||
| param.LSQ = 76, | |||
| param.GammaRNG = 77, | |||
| param.PoissonRNG = 78, | |||
| param.PermutationRNG = 79, | |||
| param.BetaRNG = 80, | |||
| } | |||
| table Operator { | |||