| @@ -6,7 +6,8 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "megdnn/internal/opr_header_prologue.h" | |||
| @@ -94,6 +95,42 @@ class PermutationRNG: public RNGBase { | |||
| void check_exec(const TensorLayout &dst, size_t workspace_in_bytes); | |||
| }; | |||
| class ShuffleRNGForward : public OperatorBase { | |||
| DEF_OPR_IMPL(ShuffleRNGForward, OperatorBase, 1, 2); | |||
| DEF_OPR_PARAM(ShuffleRNG); | |||
| public: | |||
| virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
| _megdnn_tensor_out indices, | |||
| _megdnn_workspace workspace) = 0; | |||
| void deduce_layout(const TensorLayout& src, TensorLayout& dst, | |||
| TensorLayout& indices); | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||
| const TensorLayout& dst, | |||
| const TensorLayout& indices) = 0; | |||
| protected: | |||
| void check_exec(const TensorLayout& src, const TensorLayout& dst, | |||
| const TensorLayout& indices, size_t workspace_in_bytes); | |||
| }; | |||
| using ShuffleRNG = ShuffleRNGForward; | |||
| class ShuffleRNGBackward : public OperatorBase { | |||
| DEF_OPR_IMPL(ShuffleRNGBackward, OperatorBase, 2, 1); | |||
| DEF_OPR_PARAM(ShuffleRNG); | |||
| public: | |||
| virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in indices, | |||
| _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout& diff, | |||
| const TensorLayout& indices, | |||
| const TensorLayout& grad) = 0; | |||
| protected: | |||
| void check_exec(const TensorLayout& diff, const TensorLayout& indices, | |||
| const TensorLayout& grad, size_t workspace_in_bytes); | |||
| }; | |||
| /*! | |||
| * \brief sleep for specific time on the computing device; useful for testing | |||
| * async problems | |||
| @@ -781,6 +781,9 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0) | |||
| 'Float32 are supported.'), | |||
| 'DTypeEnum::Int32')) | |||
| (pdef('ShuffleRNG'). | |||
| add_fields('uint64', 'seed', 0)) | |||
| (pdef('Flip'). | |||
| add_fields('bool', 'vertical', 'false', 'horizontal', 'false')) | |||
| @@ -165,6 +165,8 @@ private: | |||
| cb(BetaRNG) \ | |||
| cb(PoissonRNG) \ | |||
| cb(PermutationRNG) \ | |||
| cb(ShuffleRNGForward) \ | |||
| cb(ShuffleRNGBackward) \ | |||
| cb(SeparableConvForward) \ | |||
| cb(SeparableFilterForward) \ | |||
| cb(BNForward) \ | |||
| @@ -128,6 +128,8 @@ DEF(GammaRNG, 3, true, true); | |||
| DEF(BetaRNG, 3, true, true); | |||
| DEF(PoissonRNG, 2, true, true); | |||
| DEF(PermutationRNG, 1, true, true); | |||
| DEF(ShuffleRNGForward, 3, true, true); | |||
| DEF(ShuffleRNGBackward, 3, true, false); | |||
| DEF(ChecksumForward, 1, true, false); | |||
| DEF(CheckHasInf, 2, true, true); | |||
| DEF(LSQForward, 5, true, true); | |||
| @@ -15,6 +15,47 @@ | |||
| namespace megdnn { | |||
| void ShuffleRNGForward::deduce_layout(const TensorLayout& src, | |||
| TensorLayout& dst, | |||
| TensorLayout& indices) { | |||
| dst = src; | |||
| indices = TensorLayout(TensorShape({src.shape[0]}), dtype::Int32()); | |||
| } | |||
| void ShuffleRNGForward::check_exec(const TensorLayout& src, | |||
| const TensorLayout& dst, | |||
| const TensorLayout& indices, | |||
| size_t workspace_in_bytes) { | |||
| TensorLayout dst_expected, indices_expected; | |||
| megdnn_assert_contiguous(src); | |||
| deduce_layout(src, dst_expected, indices_expected); | |||
| megdnn_assert_eq_layout(dst_expected, dst); | |||
| megdnn_assert_eq_layout(indices_expected, indices); | |||
| megdnn_assert_contiguous(indices); | |||
| megdnn_assert(src.dtype == dst.dtype); | |||
| megdnn_assert(indices.dtype == dtype::Int32()); | |||
| auto required_workspace_in_bytes = | |||
| get_workspace_in_bytes(src, dst, indices); | |||
| megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||
| } | |||
| void ShuffleRNGBackward::check_exec(const TensorLayout& diff, | |||
| const TensorLayout& indices, | |||
| const TensorLayout& grad, | |||
| size_t workspace_in_bytes) { | |||
| megdnn_assert( | |||
| diff.shape[0] == indices.shape[0] && diff.dtype == grad.dtype && | |||
| indices.dtype == dtype::Int32{} && diff.is_contiguous() && | |||
| indices.is_contiguous() && grad.is_contiguous(), | |||
| "invalid layouts: diff=%s indices=%s grad=%s", | |||
| diff.to_string().c_str(), indices.to_string().c_str(), | |||
| grad.to_string().c_str()); | |||
| auto required_workspace_in_bytes = | |||
| get_workspace_in_bytes(diff, indices, grad); | |||
| megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||
| } | |||
| void PermutationRNG::check_exec( | |||
| const TensorLayout &dst, size_t workspace_in_bytes) { | |||
| megdnn_assert((dst.dtype == dtype::Float32() || | |||
| @@ -55,6 +55,42 @@ __global__ void permute_duplicate_keys_kernel(KeyType* keys, ValueType* indexs, | |||
| } | |||
| } | |||
| template <typename T> | |||
| __global__ void shuffle_fwd_kernel(uint32_t step, uint32_t src_size, const T* sptr, | |||
| T* dptr, const int* iptr) { | |||
| uint32_t idx = threadIdx.x + blockIdx.x * blockDim.x; | |||
| if (idx < src_size) { | |||
| uint32_t r = idx / step; | |||
| dptr[idx]=sptr[iptr[r] * step + idx % step]; | |||
| } | |||
| } | |||
| template <typename T> | |||
| void shuffle_forward(T* sptr, T* dptr, dt_int32* iptr, | |||
| size_t len, size_t step, cudaStream_t stream) { | |||
| uint32_t src_size = len * step; | |||
| shuffle_fwd_kernel<<<DIVUP(src_size, 512), 512, 0, stream>>>( | |||
| step, src_size, sptr, dptr, iptr); | |||
| after_kernel_launch(); | |||
| } | |||
| template <typename T> | |||
| __global__ void shuffle_bwd_kernel(uint32_t step, uint32_t src_size, T* sptr, | |||
| T* dptr, const int* iptr) { | |||
| uint32_t idx = threadIdx.x + blockIdx.x * blockDim.x; | |||
| if (idx < src_size) { | |||
| uint32_t r = idx / step; | |||
| sptr[iptr[r] * step + idx % step]=dptr[idx]; | |||
| } | |||
| } | |||
| template <typename T> | |||
| void shuffle_backward(T* dptr, dt_int32* iptr, T* sptr, | |||
| size_t len, size_t step, cudaStream_t stream) { | |||
| uint32_t src_size = len * step; | |||
| shuffle_bwd_kernel<<<DIVUP(src_size, 512), 512, 0, stream>>>( | |||
| step, src_size, sptr, dptr, iptr); | |||
| after_kernel_launch(); | |||
| } | |||
| uint32_t get_permutation_bits(size_t N) { | |||
| double uniq_rand_num_prob = 0.9; | |||
| double thresh = std::log(uniq_rand_num_prob) * 12; | |||
| @@ -156,6 +192,14 @@ INST_PERMUTATION(dt_int16) | |||
| INST_PERMUTATION(dt_float32) | |||
| #undef INST_PERMUTATION | |||
| #define INST_SHUFFLE(T) \ | |||
| template void shuffle_forward<T>(T* sptr, T* dptr, dt_int32* iptr,\ | |||
| size_t len, size_t step, cudaStream_t stream);\ | |||
| template void shuffle_backward(T* dptr, dt_int32* iptr, T* sptr,\ | |||
| size_t len, size_t step, cudaStream_t stream); | |||
| ARGSORT_FOREACH_CTYPE(INST_SHUFFLE) | |||
| #undef INST_SHUFFLE | |||
| } // namespace random | |||
| #define INST(_dtype) \ | |||
| @@ -253,6 +253,17 @@ void permutation_forward(ctype* dst, void* workspace, size_t size, uint64_t seed | |||
| size_t get_permutation_workspace_in_bytes(size_t N); | |||
| template<typename T> | |||
| void shuffle_forward(T* sptr, T* dptr, dt_int32* iptr, | |||
| size_t len, size_t step, cudaStream_t stream); | |||
| template<typename T> | |||
| void shuffle_backward(T* dptr, dt_int32* iptr, T* sptr, | |||
| size_t len, size_t step, cudaStream_t stream); | |||
| #define ARGSORT_FOREACH_CTYPE(cb) \ | |||
| cb(float) cb(int32_t) DNN_INC_FLOAT16(cb(dt_float16)) | |||
| } // namespace random | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| @@ -9,11 +9,11 @@ | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "./opr_impl.h" | |||
| #include "./kernel.cuh" | |||
| #include "src/common/utils.h" | |||
| #include "src/cuda/handle.h" | |||
| #include "src/cuda/utils.h" | |||
| #include "./opr_impl.h" | |||
| #include "./kernel.cuh" | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| @@ -261,5 +261,76 @@ size_t PermutationRNGImpl::get_workspace_in_bytes(const TensorLayout &layout){ | |||
| return random::get_permutation_workspace_in_bytes(size); | |||
| } | |||
| ShuffleRNGForwardImpl::ShuffleRNGForwardImpl(Handle* handle) | |||
| : ShuffleRNGForward(handle), | |||
| m_seed(0), | |||
| m_offset(0), | |||
| m_stream(cuda_stream(handle)) {} | |||
| void ShuffleRNGForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
| _megdnn_tensor_out indices, | |||
| _megdnn_workspace workspace) { | |||
| check_exec(src.layout, dst.layout, indices.layout, workspace.size); | |||
| ensure_seed(m_param.seed); | |||
| auto wk = workspace.ptr<void>(); | |||
| const auto len = indices.layout[0]; | |||
| random::permutation_forward<dt_int32>(indices.ptr<dt_int32>(), wk, len, | |||
| m_seed, m_offset, m_stream); | |||
| size_t step = 0; | |||
| for (size_t i = 1; i < src.layout.ndim; ++i) { | |||
| step += src.layout[i]; | |||
| } | |||
| if (step <= 0) | |||
| step = 1; | |||
| switch (src.layout.dtype.enumv()) { | |||
| #define cb(DType) \ | |||
| case DTypeTrait<DType>::enumv: \ | |||
| random::shuffle_forward<DTypeTrait<DType>::ctype>( \ | |||
| src.ptr<DTypeTrait<DType>::ctype>(), \ | |||
| dst.ptr<DTypeTrait<DType>::ctype>(), indices.ptr<dt_int32>(), \ | |||
| len, step, m_stream); \ | |||
| break; | |||
| ARGSORT_FOREACH_CTYPE(cb) | |||
| #undef cb | |||
| default : megdnn_throw("bad dtype"); | |||
| } | |||
| m_offset += 8; | |||
| } | |||
| size_t ShuffleRNGForwardImpl::get_workspace_in_bytes( | |||
| const TensorLayout&, const TensorLayout&, | |||
| const TensorLayout& indices) { | |||
| size_t size = indices.total_nr_elems(); | |||
| return random::get_permutation_workspace_in_bytes(size); | |||
| } | |||
| ShuffleRNGBackwardImpl::ShuffleRNGBackwardImpl(Handle* handle) | |||
| : ShuffleRNGBackward(handle), m_stream(cuda_stream(handle)) {} | |||
| void ShuffleRNGBackwardImpl::exec(_megdnn_tensor_in diff, | |||
| _megdnn_tensor_in indices, | |||
| _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace) { | |||
| const auto len = indices.layout[0]; | |||
| auto step = 0; | |||
| for (size_t i = 1; i < diff.layout.ndim; ++i) { | |||
| step += diff.layout[i]; | |||
| } | |||
| if (step <= 0) | |||
| step = 1; | |||
| switch (diff.layout.dtype.enumv()) { | |||
| #define cb(DType) \ | |||
| case DTypeTrait<DType>::enumv: \ | |||
| random::shuffle_backward<DTypeTrait<DType>::ctype>( \ | |||
| diff.ptr<DTypeTrait<DType>::ctype>(), indices.ptr<dt_int32>(), \ | |||
| grad.ptr<DTypeTrait<DType>::ctype>(), len, step, m_stream); \ | |||
| break; | |||
| ARGSORT_FOREACH_CTYPE(cb) | |||
| #undef cb | |||
| default: | |||
| megdnn_throw("bad dtype"); | |||
| } | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -6,7 +6,8 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| @@ -152,6 +153,45 @@ public: | |||
| } | |||
| }; | |||
| class ShuffleRNGForwardImpl : public ShuffleRNGForward { | |||
| uint64_t m_seed, m_offset; | |||
| cudaStream_t m_stream; | |||
| public: | |||
| using ShuffleRNGForward::ShuffleRNGForward; | |||
| ShuffleRNGForwardImpl(Handle* handle); | |||
| void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
| _megdnn_tensor_out indices, _megdnn_workspace workspace) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout& src, | |||
| const TensorLayout& dst, | |||
| const TensorLayout& indices) override; | |||
| void seed(uint64_t seed) { m_seed = seed; } | |||
| void ensure_seed(uint64_t seed) { | |||
| if (m_seed != seed) { | |||
| this->seed(seed); | |||
| } | |||
| } | |||
| }; | |||
| class ShuffleRNGBackwardImpl : public ShuffleRNGBackward { | |||
| cudaStream_t m_stream; | |||
| public: | |||
| using ShuffleRNGBackward::ShuffleRNGBackward; | |||
| ShuffleRNGBackwardImpl(Handle* handle); | |||
| void exec(_megdnn_tensor_in diff, _megdnn_tensor_in indices, | |||
| _megdnn_tensor_out grad, _megdnn_workspace workspace) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout&, | |||
| const TensorLayout&, | |||
| const TensorLayout&) override { | |||
| return 0; | |||
| } | |||
| }; | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -6,12 +6,13 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/naive/handle.h" | |||
| #include "src/common/utils.h" | |||
| #include "./opr_impl.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/naive/handle.h" | |||
| #include <cmath> | |||
| @@ -229,7 +230,29 @@ namespace { | |||
| } | |||
| } | |||
| } // anonymous namespace | |||
| template <typename T> | |||
| void shuffle_fwd(const T* __restrict sptr, T* __restrict dptr, | |||
| const dt_int32* iptr, const size_t len, | |||
| const size_t step) MEGDNN_NOEXCEPT { | |||
| for (size_t i = 0; i < len; ++i) { | |||
| for (size_t j = 0; j < step; ++j) { | |||
| dptr[i * step + j] = sptr[iptr[i] * step + j]; | |||
| } | |||
| } | |||
| } | |||
| template <typename T> | |||
| void shuffle_bwd(T* __restrict sptr, const T* __restrict dptr, | |||
| const dt_int32* iptr, const size_t len, | |||
| const size_t step) MEGDNN_NOEXCEPT { | |||
| for (size_t i = 0; i < len; ++i) { | |||
| for (size_t j = 0; j < step; ++j) { | |||
| sptr[iptr[i] * step + j] = dptr[i * step + j]; | |||
| } | |||
| } | |||
| } | |||
| } // anonymous namespace | |||
| uint64_t Splitmix64::operator() () { | |||
| uint64_t z = (m_s += UINT64_C(0x9E3779B97F4A7C15)); | |||
| @@ -394,5 +417,54 @@ void PermutationRNGImpl::exec( | |||
| } | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| void ShuffleRNGForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
| _megdnn_tensor_out indices, | |||
| _megdnn_workspace workspace) { | |||
| check_exec(src.layout, dst.layout, indices.layout, workspace.size); | |||
| const auto len = indices.layout[0]; | |||
| auto iptr = indices.ptr<dt_int32>(); | |||
| auto prng = &m_rng.ensure_seed(m_param.seed); | |||
| fill_permutation<dt_int32>(prng, iptr, len); | |||
| auto step = 0; | |||
| for (size_t i = 1; i < src.layout.ndim; ++i) { | |||
| step += src.layout[i]; | |||
| } | |||
| if (step <= 0) | |||
| step = 1; | |||
| #define cb(DType) \ | |||
| if (src.layout.dtype == DType()) { \ | |||
| using T = typename DTypeTrait<DType>::ctype; \ | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||
| shuffle_fwd<T>(src.ptr<T>(), dst.ptr<T>(), iptr, len, step)); \ | |||
| return; \ | |||
| } | |||
| MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | |||
| #undef cb | |||
| } | |||
| void ShuffleRNGBackwardImpl::exec(_megdnn_tensor_in diff, | |||
| _megdnn_tensor_in indices, | |||
| _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace) { | |||
| check_exec(diff.layout, indices.layout, grad.layout, workspace.size); | |||
| const auto len = indices.layout[0]; | |||
| auto iptr = indices.ptr<dt_int32>(); | |||
| auto step = 0; | |||
| for (size_t i = 1; i < diff.layout.ndim; ++i) { | |||
| step += diff.layout[i]; | |||
| } | |||
| if (step <= 0) | |||
| step = 1; | |||
| #define cb(DType) \ | |||
| if (diff.layout.dtype == DType()) { \ | |||
| using T = typename DTypeTrait<DType>::ctype; \ | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR(shuffle_bwd<T>( \ | |||
| grad.ptr<T>(), diff.ptr<T>(), iptr, len, step)); \ | |||
| return; \ | |||
| } | |||
| MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | |||
| #undef cb | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -128,6 +128,35 @@ public: | |||
| size_t get_workspace_in_bytes(const TensorLayout&) override { return 0; } | |||
| }; | |||
| class ShuffleRNGForwardImpl : public ShuffleRNGForward { | |||
| Xoroshiro128plus m_rng; | |||
| public: | |||
| using ShuffleRNGForward::ShuffleRNGForward; | |||
| void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
| _megdnn_tensor_out indices, _megdnn_workspace workspace) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout&, | |||
| const TensorLayout&, | |||
| const TensorLayout&) override { | |||
| return 0; | |||
| } | |||
| }; | |||
| class ShuffleRNGBackwardImpl : public ShuffleRNGBackward { | |||
| Xoroshiro128plus m_rng; | |||
| public: | |||
| using ShuffleRNGBackward::ShuffleRNGBackward; | |||
| void exec(_megdnn_tensor_in diff, _megdnn_tensor_in indices, | |||
| _megdnn_tensor_out grad, _megdnn_workspace workspace) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout&, | |||
| const TensorLayout&, | |||
| const TensorLayout&) override { | |||
| return 0; | |||
| } | |||
| }; | |||
| } // namespace naive | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -143,6 +143,60 @@ void run_permutation(Handle* handle) { | |||
| } | |||
| } | |||
| template <typename T> | |||
| void run_shuffle(Handle* handle, bool bwd_flag) { | |||
| using ctype = typename DTypeTrait<T>::ctype; | |||
| auto run = [&](TensorShape shape) { | |||
| auto opr = handle->create_operator<ShuffleRNGForward>(); | |||
| TensorLayout srclay{shape, T()}; | |||
| TensorLayout dstlay{shape, T()}; | |||
| TensorLayout indexlay{TensorShape{shape[0]}, dtype::Int32()}; | |||
| Tensor<dt_byte> workspace( | |||
| handle, {TensorShape{opr->get_workspace_in_bytes(srclay, dstlay, | |||
| indexlay)}, | |||
| dtype::Byte()}); | |||
| SyncedTensor<ctype> src(handle, srclay); | |||
| SyncedTensor<ctype> dst(handle, dstlay); | |||
| SyncedTensor<DTypeTrait<dt_int32>::ctype> index(handle, indexlay); | |||
| auto sptr = src.ptr_mutable_host(); | |||
| size_t size = src.layout().total_nr_elems(); | |||
| for (size_t j = 0; j < size; ++j) { | |||
| sptr[j] = j; | |||
| } | |||
| opr->exec(src.tensornd_dev(), dst.tensornd_dev(), index.tensornd_dev(), | |||
| {workspace.ptr(), workspace.layout().total_nr_elems()}); | |||
| auto dptr = dst.ptr_mutable_host(); | |||
| auto iptr = index.ptr_mutable_host(); | |||
| size_t len = index.layout().total_nr_elems(); | |||
| size_t step = size / len; | |||
| for (size_t i = 0; i < len; ++i) { | |||
| for (size_t j = 0; j < step; ++j) { | |||
| ASSERT_EQ(dptr[i * step + j], sptr[iptr[i] * step + j]); | |||
| } | |||
| } | |||
| if (bwd_flag) { | |||
| for (size_t j = 0; j < size; ++j) { | |||
| sptr[j] = 0; | |||
| } | |||
| auto oprbwd = handle->create_operator<ShuffleRNGBackward>(); | |||
| oprbwd->exec( | |||
| dst.tensornd_dev(), index.tensornd_dev(), | |||
| src.tensornd_dev(), | |||
| {workspace.ptr(), workspace.layout().total_nr_elems()}); | |||
| auto sptr_bwd = src.ptr_mutable_host(); | |||
| for (size_t i = 0; i < len; ++i) { | |||
| for (size_t j = 0; j < step; ++j) { | |||
| ASSERT_EQ(dptr[i * step + j], sptr_bwd[iptr[i] * step + j]); | |||
| } | |||
| } | |||
| } | |||
| }; | |||
| run({10}); | |||
| run({6, 3}); | |||
| } | |||
| } // anonymous namespace | |||
| TEST_F(CUDA, UNIFORM_RNG_F32) { | |||
| @@ -215,6 +269,30 @@ TEST_F(CUDA, PERMUTATION_RNG_INT16) { | |||
| run_permutation<dtype::Int16>(handle_cuda()); | |||
| } | |||
| TEST_F(CUDA, SHUFFLE_RNG_F32) { | |||
| run_shuffle<dtype::Float32>(handle_cuda(), false); | |||
| } | |||
| TEST_F(CUDA, SHUFFLE_RNG_INT32) { | |||
| run_shuffle<dtype::Int32>(handle_cuda(), false); | |||
| } | |||
| TEST_F(CUDA, SHUFFLE_RNG_F16) { | |||
| run_shuffle<dtype::Float16>(handle_cuda(), false); | |||
| } | |||
| TEST_F(CUDA, SHUFFLE_RNG_BWD_F32) { | |||
| run_shuffle<dtype::Float32>(handle_cuda(), true); | |||
| } | |||
| TEST_F(CUDA, SHUFFLE_RNG_BWD_INT32) { | |||
| run_shuffle<dtype::Int32>(handle_cuda(), true); | |||
| } | |||
| TEST_F(CUDA, SHUFFLE_RNG_BWD_F16) { | |||
| run_shuffle<dtype::Float16>(handle_cuda(), true); | |||
| } | |||
| } // namespace test | |||
| } // namespace megdnn | |||
| @@ -6,12 +6,13 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "megdnn.h" | |||
| #include "test/naive/fixture.h" | |||
| #include "test/naive/rng.h" | |||
| #include "megdnn.h" | |||
| #include "test/common/tensor.h" | |||
| #include "test/naive/fixture.h" | |||
| namespace megdnn { | |||
| @@ -181,7 +182,59 @@ namespace { | |||
| ASSERT_LE(std::abs(res[i] - ctype(i)), 1e-8); | |||
| } | |||
| } | |||
| } | |||
| template <typename T> | |||
| void run_shuffle(Handle* handle, bool bwd_flag) { | |||
| using ctype = typename DTypeTrait<T>::ctype; | |||
| auto run = [&](TensorShape shape) { | |||
| auto opr = handle->create_operator<ShuffleRNGForward>(); | |||
| TensorLayout srclay{shape, T()}; | |||
| TensorLayout dstlay{shape, T()}; | |||
| TensorLayout indexlay{TensorShape{shape[0]}, dtype::Int32()}; | |||
| Tensor<dt_byte> workspace( | |||
| handle, {TensorShape{opr->get_workspace_in_bytes(srclay, dstlay, | |||
| indexlay)}, | |||
| dtype::Byte()}); | |||
| Tensor<ctype> src(handle, srclay); | |||
| Tensor<ctype> dst(handle, dstlay); | |||
| Tensor<DTypeTrait<dt_int32>::ctype> index(handle, indexlay); | |||
| auto sptr = src.ptr(); | |||
| size_t size = src.layout().total_nr_elems(); | |||
| for (size_t j = 0; j < size; ++j) { | |||
| sptr[j] = j; | |||
| } | |||
| opr->exec(src.tensornd(), dst.tensornd(), index.tensornd(), | |||
| {workspace.ptr(), workspace.layout().total_nr_elems()}); | |||
| auto dptr = dst.ptr(); | |||
| auto iptr = index.ptr(); | |||
| size_t len = index.layout().total_nr_elems(); | |||
| size_t step = size / len; | |||
| for (size_t i = 0; i < len; ++i) { | |||
| for (size_t j = 0; j < step; ++j) { | |||
| ASSERT_EQ(dptr[i * step + j], sptr[iptr[i] * step + j]); | |||
| } | |||
| } | |||
| if (bwd_flag) { | |||
| for (size_t j = 0; j < size; ++j) { | |||
| sptr[j] = 0; | |||
| } | |||
| auto oprbwd = handle->create_operator<ShuffleRNGBackward>(); | |||
| oprbwd->exec( | |||
| dst.tensornd(), index.tensornd(), src.tensornd(), | |||
| {workspace.ptr(), workspace.layout().total_nr_elems()}); | |||
| for (size_t i = 0; i < len; ++i) { | |||
| for (size_t j = 0; j < step; ++j) { | |||
| ASSERT_EQ(dptr[i * step + j], sptr[iptr[i] * step + j]); | |||
| } | |||
| } | |||
| } | |||
| }; | |||
| run({10}); | |||
| run({6, 3}); | |||
| } | |||
| } // namespace | |||
| TEST_F(NAIVE, UNIFORM_RNG_F32) { | |||
| run_uniform<dtype::Float32>(handle()); | |||
| @@ -235,10 +288,31 @@ TEST_F(NAIVE, PERMUTATION_RNG_INT16) { | |||
| run_permutation<dtype::Int16>(handle()); | |||
| } | |||
| } // namespace test | |||
| } // namespace megdnn | |||
| TEST_F(NAIVE, SHUFFLE_RNG_FWD_F32) { | |||
| run_shuffle<dtype::Float32>(handle(), false); | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| TEST_F(NAIVE, SHUFFLE_RNG_FWD_INT32) { | |||
| run_shuffle<dtype::Int32>(handle(), false); | |||
| } | |||
| TEST_F(NAIVE, SHUFFLE_RNG_FWD_F16) { | |||
| run_shuffle<dtype::Float16>(handle(), false); | |||
| } | |||
| TEST_F(NAIVE, SHUFFLE_RNG_BWD_F32) { | |||
| run_shuffle<dtype::Float32>(handle(), true); | |||
| } | |||
| TEST_F(NAIVE, SHUFFLE_RNG_BWD_INT32) { | |||
| run_shuffle<dtype::Int32>(handle(), true); | |||
| } | |||
| TEST_F(NAIVE, SHUFFLE_RNG_BWD_F16) { | |||
| run_shuffle<dtype::Float16>(handle(), true); | |||
| } | |||
| } // namespace test | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -6,7 +6,7 @@ | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .rng import RNG, beta, gamma, normal, permutation, poisson, seed, uniform | |||
| from .rng import RNG, beta, gamma, normal, permutation, poisson, seed, shuffle, uniform | |||
| __all__ = [ | |||
| "RNG", | |||
| @@ -17,6 +17,7 @@ __all__ = [ | |||
| "poisson", | |||
| "seed", | |||
| "uniform", | |||
| "shuffle", | |||
| ] | |||
| # pylint: disable=undefined-variable | |||
| del rng # type: ignore[name-defined] | |||
| @@ -27,6 +27,7 @@ from ..core.ops.builtin import ( | |||
| GaussianRNG, | |||
| PermutationRNG, | |||
| PoissonRNG, | |||
| ShuffleRNG, | |||
| UniformRNG, | |||
| ) | |||
| from ..core.tensor import utils | |||
| @@ -41,6 +42,7 @@ __all__ = [ | |||
| "beta", | |||
| "poisson", | |||
| "permutation", | |||
| "shuffle", | |||
| ] | |||
| _rng = None | |||
| @@ -219,6 +221,13 @@ def _permutation(n: int, seed: int, device: str, handle: int, dtype: str) -> Ten | |||
| return output | |||
| def _shuffle(inp: Tensor, seed: int, handle: int) -> Tensor: | |||
| assert inp.size > 0, "size needs to be greater than 0" | |||
| op = ShuffleRNG(seed=seed, handle=handle) | |||
| output, _ = apply(op, inp) | |||
| inp._reset(output) | |||
| class RNG: | |||
| r""":class:`RNG` exposes a number of methods for generating random numbers. | |||
| @@ -581,6 +590,45 @@ class RNG: | |||
| n=n, seed=_seed, device=self._device, handle=self._handle, dtype=dtype | |||
| ) | |||
| def shuffle(self, inp: Tensor): | |||
| r"""Modify a sequence in-place by shuffling its contents. | |||
| This function only shuffles the Tensor along the first axis of a multi-dimensional Tensor. | |||
| The order of sub-Tensors is changed but their contents remains the same. | |||
| Args: | |||
| inp: input tensor. | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| import megengine as mge | |||
| import megengine.random as rand | |||
| x = mge.tensor(np.arange(10)) | |||
| rand.shuffle(x) | |||
| print(x.numpy()) | |||
| y = mge.tensor(np.arange(18)).reshape(6,3) | |||
| rand.shuffle(y) | |||
| print(y.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| :options: +SKIP | |||
| [7 9 3 0 8 2 4 5 6 1] | |||
| [[12. 13. 14.] | |||
| [ 3. 4. 5.] | |||
| [15. 16. 17.] | |||
| [ 0. 1. 2.] | |||
| [ 9. 10. 11.] | |||
| [ 6. 7. 8.]] | |||
| """ | |||
| _seed = self._seed() if callable(self._seed) else self._seed | |||
| _shuffle(inp=inp, seed=_seed, handle=self._handle) | |||
| def __del__(self): | |||
| if self._handle != 0: | |||
| _delete_rng_handle(self._handle) | |||
| @@ -599,6 +647,7 @@ gamma = _default_handle.gamma | |||
| beta = _default_handle.beta | |||
| poisson = _default_handle.poisson | |||
| permutation = _default_handle.permutation | |||
| shuffle = _default_handle.shuffle | |||
| def _random_seed_generator(): | |||
| @@ -18,6 +18,7 @@ from megengine.core._imperative_rt.ops import ( | |||
| get_global_rng_seed, | |||
| new_rng_handle, | |||
| ) | |||
| from megengine.core.autodiff.grad import Grad | |||
| from megengine.core.ops.builtin import ( | |||
| BetaRNG, | |||
| GammaRNG, | |||
| @@ -397,6 +398,45 @@ def test_PermutationRNG(): | |||
| assert sum_result(out, np.sort) == 1000 | |||
| @pytest.mark.skipif( | |||
| get_device_count("xpu") <= 1, reason="xpu counts need > 1", | |||
| ) | |||
| def test_ShuffleRNG(): | |||
| g = [] | |||
| def cb(grad): | |||
| g.append(grad) | |||
| n, m = 6, 3 | |||
| arr = np.arange(n * m) | |||
| out0 = Tensor(arr, dtype="float32") | |||
| grad = Grad().wrt(out0, callback=cb) | |||
| random.shuffle(out0) | |||
| grad(out0, F.ones_like(out0)) | |||
| m1 = RNG(seed=111, device="xpu0") | |||
| m2 = RNG(seed=111, device="xpu1") | |||
| m3 = RNG(seed=222, device="xpu0") | |||
| out1 = Tensor(arr, dtype="float32", device="xpu0") | |||
| out2 = Tensor(arr, dtype="float32", device="xpu1") | |||
| out3 = Tensor(arr, dtype="float32", device="xpu0") | |||
| m1.shuffle(out1) | |||
| m2.shuffle(out2) | |||
| m3.shuffle(out3) | |||
| np.testing.assert_equal(out1.numpy(), out2.numpy()) | |||
| assert out1.device == "xpu0" and out2.device == "xpu1" | |||
| assert not (out1.numpy() == out3.numpy()).all() | |||
| out = Tensor(arr, dtype="float32").reshape(n, m) | |||
| m1.shuffle(out) | |||
| out_shp = out.shape | |||
| if isinstance(out_shp, tuple): | |||
| assert out_shp == (n, m) | |||
| else: | |||
| assert all(out.shape.numpy() == np.array([n, m])) | |||
| def test_seed(): | |||
| set_global_seed(10) | |||
| out1 = uniform(size=[10, 10]) | |||
| @@ -6,7 +6,8 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "megbrain/imperative/ops/rng.h" | |||
| @@ -14,8 +15,8 @@ | |||
| #include "megbrain/graph/helper.h" | |||
| #include "megbrain/opr/rand.h" | |||
| #include "../op_trait.h" | |||
| #include "../dnn_op_helper.h" | |||
| #include "../op_trait.h" | |||
| namespace mgb::imperative::rng { | |||
| @@ -259,13 +260,27 @@ struct OpMeth<BetaRNG> { | |||
| } | |||
| }; | |||
| template <> | |||
| struct OpMeth<ShuffleRNG> { | |||
| using DnnOp = megdnn::ShuffleRNG; | |||
| using Param = DnnOp::Param; | |||
| using OpNode = mgb::opr::ShuffleRNG; | |||
| static Param make_param(const ShuffleRNG& rng) { | |||
| auto handle_seed = RNGDnnOpManager::get_seed(rng.handle); | |||
| mgb_assert(handle_seed == rng.seed, | |||
| "inconsistent rng seed: rng op: %lu handle: %lu", | |||
| handle_seed, rng.seed); | |||
| return {handle_seed}; | |||
| } | |||
| }; | |||
| template <bool> | |||
| struct _InferLayout; | |||
| template <int nr_in> | |||
| struct _RNGOprMaker; | |||
| template <int nr_in> | |||
| template <int nr_in, int nr_out> | |||
| struct _RNGOprInvoker; | |||
| template<> | |||
| @@ -316,50 +331,63 @@ struct _InferLayout<false> | |||
| return inp.layout; | |||
| } | |||
| }; | |||
| #define _INST_RNG_INVOLKER(DNN_NR_INPUTS) \ | |||
| template<> \ | |||
| struct _RNGOprInvoker<DNN_NR_INPUTS> { \ | |||
| template<typename Opr> \ | |||
| static void exec(Opr *dnn_op, const SmallVector<TensorPtr>& inputs,const TensorPtr& dest){ \ | |||
| size_t wk_size = 0; \ | |||
| wk_size = dnn_op->get_workspace_in_bytes(_FOR_EACH_IN(->layout())dest->layout()); \ | |||
| auto workspace = Blob::make(dest->comp_node(), wk_size); \ | |||
| megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size); \ | |||
| dnn_op->exec(_FOR_EACH_IN(->dev_tensor().as_megdnn()) \ | |||
| dest->dev_tensor().as_megdnn(), dnn_wk); \ | |||
| } \ | |||
| }; | |||
| #define _INST_RNG_INVOLKER(DNN_NR_INPUTS, DNN_NR_OUTPUTS) \ | |||
| template <> \ | |||
| struct _RNGOprInvoker<DNN_NR_INPUTS, DNN_NR_OUTPUTS> { \ | |||
| template <typename Opr> \ | |||
| static void exec(Opr* dnn_op, const SmallVector<TensorPtr>& inputs, \ | |||
| const SmallVector<TensorPtr>& outputs) { \ | |||
| size_t wk_size = 0; \ | |||
| wk_size = dnn_op->get_workspace_in_bytes( \ | |||
| _FOR_EACH_IN(->layout()) _FOR_EACH_OUT(->layout())); \ | |||
| auto workspace = Blob::make(outputs[0]->comp_node(), wk_size); \ | |||
| megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size); \ | |||
| dnn_op->exec(_FOR_EACH_IN(->dev_tensor().as_megdnn()) \ | |||
| _FOR_EACH_OUT(->dev_tensor().as_megdnn()), \ | |||
| dnn_wk); \ | |||
| } \ | |||
| }; | |||
| #define _INST_RNG_MAKER(MGB_NR_INPUTS) \ | |||
| template<> \ | |||
| struct _RNGOprMaker<MGB_NR_INPUTS> { \ | |||
| template<typename Op> \ | |||
| static SymbolVar make(const VarNodeArray& inputs, const Op& rng){ \ | |||
| auto param = OpMeth<Op>::make_param(rng); \ | |||
| OperatorNodeConfig config; \ | |||
| if (rng.handle) { \ | |||
| config = {rng.make_name(), RNGDnnOpManager::get_comp_node(rng.handle)}; \ | |||
| } else { \ | |||
| config = {rng.make_name()}; \ | |||
| } \ | |||
| return OpMeth<Op>::OpNode::make(_FOR_EACH_IN() param, config); \ | |||
| } \ | |||
| }; | |||
| #define _INST_RNG_MAKER(MGB_NR_INPUTS) \ | |||
| template <> \ | |||
| struct _RNGOprMaker<MGB_NR_INPUTS> { \ | |||
| template <typename Op> \ | |||
| static auto make(const VarNodeArray& inputs, const Op& rng) { \ | |||
| auto param = OpMeth<Op>::make_param(rng); \ | |||
| OperatorNodeConfig config; \ | |||
| if (rng.handle) { \ | |||
| config = {rng.make_name(), \ | |||
| RNGDnnOpManager::get_comp_node(rng.handle)}; \ | |||
| } else { \ | |||
| config = {rng.make_name()}; \ | |||
| } \ | |||
| return OpMeth<Op>::OpNode::make(_FOR_EACH_IN() param, config); \ | |||
| } \ | |||
| }; | |||
| #define _FOR_EACH_IN(subfix) | |||
| _INST_RNG_INVOLKER(0) | |||
| #define _FOR_EACH_IN(subfix) | |||
| #define _FOR_EACH_OUT(subfix) outputs[0] subfix | |||
| _INST_RNG_INVOLKER(0, 1) | |||
| #undef _FOR_EACH_OUT | |||
| #undef _FOR_EACH_IN | |||
| #define _FOR_EACH_IN(subfix) inputs[0] subfix, | |||
| _INST_RNG_INVOLKER(1) | |||
| #define _FOR_EACH_OUT(subfix) outputs[0] subfix | |||
| _INST_RNG_INVOLKER(1, 1) | |||
| #undef _FOR_EACH_OUT | |||
| #define _FOR_EACH_OUT(subfix) outputs[0] subfix, outputs[1] subfix | |||
| _INST_RNG_INVOLKER(1, 2) | |||
| _INST_RNG_MAKER(1) | |||
| #undef _FOR_EACH_OUT | |||
| #undef _FOR_EACH_IN | |||
| #define _FOR_EACH_IN(subfix) inputs[0] subfix, inputs[1] subfix, | |||
| _INST_RNG_INVOLKER(2) | |||
| #define _FOR_EACH_OUT(subfix) outputs[0] subfix | |||
| _INST_RNG_INVOLKER(2, 1) | |||
| _INST_RNG_MAKER(2) | |||
| #undef _FOR_EACH_OUT | |||
| #undef _FOR_EACH_IN | |||
| #undef _INST_RNG_INVOLKER | |||
| @@ -392,7 +420,9 @@ void exec(const OpDef& op, const SmallVector<TensorPtr>& inputs, | |||
| handle_seed, dnn_op->param().seed); | |||
| } | |||
| dnn_op->param() = OpMeth<Op>::make_param(rng); | |||
| _RNGOprInvoker<OpMeth<Op>::DnnOp::NR_INPUTS>::exec(dnn_op,inputs,dest); | |||
| _RNGOprInvoker<OpMeth<Op>::DnnOp::NR_INPUTS, | |||
| OpMeth<Op>::DnnOp::NR_OUTPUTS>::exec(dnn_op, inputs, | |||
| outputs); | |||
| } | |||
| template <typename Op> | |||
| @@ -420,24 +450,45 @@ SmallVector<LogicalTensorDesc> infer_output_attrs( | |||
| return {dest}; | |||
| } | |||
| template <typename Op> | |||
| std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc( | |||
| const OpDef& def, | |||
| const SmallVector<TensorPtr>& inputs_tensors, | |||
| const SmallVector<MemoryDesc>& inputs_mems) { | |||
| auto &&dest = infer_output_attrs<Op>(def, inputs_tensors); | |||
| SmallVector<MemoryDesc> outputs = {{dest[0].layout, 0, dest[0].comp_node, StorageIdentifier::make(1)}}; | |||
| return {outputs, {}}; | |||
| template <> | |||
| SmallVector<LogicalTensorDesc> infer_output_attrs<ShuffleRNG>( | |||
| const OpDef& op, const SmallVector<TensorPtr>& inputs) { | |||
| SmallVector<LogicalTensorDesc> dests(2); | |||
| auto&& rng = op.cast_final_safe<ShuffleRNG>(); | |||
| auto handle = rng.handle; | |||
| if (handle) { | |||
| dests[0].comp_node = RNGDnnOpManager::get_comp_node(handle); | |||
| dests[1].comp_node = RNGDnnOpManager::get_comp_node(handle); | |||
| } else { | |||
| dests[0].comp_node = inputs[0]->comp_node(); | |||
| dests[1].comp_node = inputs[0]->comp_node(); | |||
| } | |||
| dests[0].layout = TensorLayout(inputs[0]->layout()); | |||
| dests[0].layout.dtype = inputs[0]->layout().dtype; | |||
| dests[1].layout = | |||
| TensorLayout(TensorShape({inputs[0]->layout()[0]}), dtype::Int32()); | |||
| return dests; | |||
| } | |||
| template <typename Op> | |||
| std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> | |||
| infer_output_mem_desc(const OpDef& def, | |||
| const SmallVector<TensorPtr>& inputs_tensors, | |||
| const SmallVector<MemoryDesc>& inputs_mems) { | |||
| auto&& dests = infer_output_attrs<Op>(def, inputs_tensors); | |||
| SmallVector<MemoryDesc> outputs; | |||
| for (size_t i = 0; i < dests.size(); ++i) { | |||
| outputs.push_back({dests[i].layout, 0, dests[i].comp_node, | |||
| StorageIdentifier::make(i + 1)}); | |||
| } | |||
| return {outputs, {}}; | |||
| } | |||
| template <typename Op> | |||
| SmallVector<TensorPtr> apply_on_physical_tensor( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs) { | |||
| SmallVector<TensorPtr> outputs; | |||
| SmallVector<LogicalTensorDesc> desc; | |||
| desc = infer_output_attrs<Op>(def, inputs); | |||
| SmallVector<LogicalTensorDesc> desc = infer_output_attrs<Op>(def, inputs); | |||
| for (auto&& i : desc) { | |||
| outputs.push_back(Tensor::make(i.layout, i.comp_node)); | |||
| } | |||
| @@ -454,10 +505,8 @@ void execute( | |||
| exec<Op>(def, inputs, outputs, {}); | |||
| } | |||
| template<typename Op> | |||
| SymbolVar apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| template <typename Op, typename Output> | |||
| Output apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| size_t nr_inp = inputs.size(); | |||
| constexpr size_t dnn_nr_inp = OpMeth<Op>::DnnOp::NR_INPUTS; | |||
| auto&& rng = def.cast_final_safe<Op>(); | |||
| @@ -487,7 +536,21 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
| return {{dest}, true}; | |||
| } | |||
| } // anonymous namespace | |||
| template <> | |||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> | |||
| infer_output_attrs_fallible<ShuffleRNG>( | |||
| const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { | |||
| SmallVector<LogicalTensorDesc> dests(2); | |||
| dests[0].comp_node = inputs[0].comp_node; | |||
| dests[0].layout = TensorLayout(inputs[0].layout); | |||
| dests[0].layout.dtype = inputs[0].layout.dtype; | |||
| dests[1].comp_node = inputs[0].comp_node; | |||
| dests[1].layout = TensorLayout(TensorShape({inputs[0].layout.shape[0]}), | |||
| dtype::Int32()); | |||
| return {dests, true}; | |||
| } | |||
| } // anonymous namespace | |||
| Handle new_handle(CompNode comp_node, uint64_t seed) { | |||
| return RNGDnnOpManager::inst().new_handle(comp_node, seed); | |||
| @@ -509,23 +572,24 @@ CompNode get_rng_handle_compnode(Handle handle){ | |||
| return RNGDnnOpManager::get_comp_node(handle); | |||
| } | |||
| #define REG_RNG_OP(NAME)\ | |||
| namespace { \ | |||
| OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode) \ | |||
| .apply_on_var_node(apply_on_var_node<NAME>) \ | |||
| .apply_on_physical_tensor(apply_on_physical_tensor<NAME>) \ | |||
| .infer_output_attrs_fallible(infer_output_attrs_fallible<NAME>) \ | |||
| .infer_output_mem_desc(infer_output_mem_desc<NAME>) \ | |||
| .execute(execute<NAME>) \ | |||
| .fallback(); \ | |||
| } \ | |||
| REG_RNG_OP(UniformRNG) | |||
| REG_RNG_OP(GaussianRNG) | |||
| REG_RNG_OP(GammaRNG) | |||
| REG_RNG_OP(PermutationRNG) | |||
| REG_RNG_OP(PoissonRNG) | |||
| REG_RNG_OP(BetaRNG) | |||
| #define REG_RNG_OP(NAME, Output) \ | |||
| namespace { \ | |||
| OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode) \ | |||
| .apply_on_var_node(apply_on_var_node<NAME, Output>) \ | |||
| .apply_on_physical_tensor(apply_on_physical_tensor<NAME>) \ | |||
| .infer_output_attrs_fallible(infer_output_attrs_fallible<NAME>) \ | |||
| .infer_output_mem_desc(infer_output_mem_desc<NAME>) \ | |||
| .execute(execute<NAME>) \ | |||
| .fallback(); \ | |||
| } | |||
| REG_RNG_OP(UniformRNG, SymbolVar) | |||
| REG_RNG_OP(GaussianRNG, SymbolVar) | |||
| REG_RNG_OP(GammaRNG, SymbolVar) | |||
| REG_RNG_OP(PermutationRNG, SymbolVar) | |||
| REG_RNG_OP(PoissonRNG, SymbolVar) | |||
| REG_RNG_OP(BetaRNG, SymbolVar) | |||
| REG_RNG_OP(ShuffleRNG, SymbolVarArray) | |||
| #undef REG_RNG_OP | |||
| } // namespace mgb::imperative::rng | |||
| @@ -215,6 +215,19 @@ def PermutationRNG: MgbHashableOp<"PermutationRNG", [PermutationRNGParam]> { | |||
| let cmpFunction = [{return $0.handle == $1.handle && $0.dtype == $1.dtype;}]; | |||
| } | |||
| def ShuffleRNG: MgbHashableOp<"ShuffleRNG", [ShuffleRNGParam]> { | |||
| let extraArguments = (ins | |||
| MgbSizeTAddr:$handle | |||
| ); | |||
| let hashFunction = [{ | |||
| return mgb::hash_pair_combine( | |||
| mgb::hash($_self.dyn_typeinfo()), | |||
| mgb::hash($_self.handle) | |||
| ); | |||
| }]; | |||
| let cmpFunction = [{return $0.handle == $1.handle;}]; | |||
| } | |||
| def Linspace: MgbHashableOp<"Linspace", [LinspaceParam]> { | |||
| let extraArguments = (ins | |||
| MgbCompNodeAttr:$comp_node | |||
| @@ -192,6 +192,8 @@ template class RNGOprBase<::megdnn::GammaRNG>; | |||
| template class RNGOprBase<::megdnn::PermutationRNG>; | |||
| template class RNGOprBase<::megdnn::BetaRNG>; | |||
| template class RNGOprBase<::megdnn::PoissonRNG>; | |||
| template class RNGOprBase<::megdnn::ShuffleRNGForward>; | |||
| template class RNGOprBase<::megdnn::ShuffleRNGBackward>; | |||
| #if MGB_ENABLE_GRAD | |||
| IMPL(GaussianRNG); | |||
| IMPL(UniformRNG); | |||
| @@ -200,9 +202,87 @@ IMPL(PoissonRNG); | |||
| IMPL(PermutationRNG); | |||
| IMPL(BetaRNG); | |||
| #endif | |||
| } | |||
| } // namespace intl | |||
| } // namespace opr | |||
| } // namespace mgb | |||
| /* ================= ShuffleRNGForward ================= */ | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShuffleRNGForward); | |||
| ShuffleRNGForward::ShuffleRNGForward(VarNode* data, const Param& param, | |||
| const OperatorNodeConfig& config) | |||
| : Super({data->owner_graph(), config, "shuffle_rng", {data}}, param) { | |||
| add_input({data}); | |||
| add_output(None)->dtype(data->dtype()); | |||
| add_output(None)->dtype(dtype::Int32{}); | |||
| cg::add_workspace_output(this); | |||
| add_equivalence_component<ScalarHash<void*>>(this); | |||
| } | |||
| SymbolVarArray ShuffleRNGForward::make(SymbolVar in_tensor, const Param& param, | |||
| const OperatorNodeConfig& config) { | |||
| auto node = in_tensor.node()->owner_graph()->insert_opr( | |||
| std::make_unique<ShuffleRNGForward>(in_tensor.node(), param, | |||
| config)); | |||
| mgb_assert(node->output().size() == 3); | |||
| return {node->output(0), node->output(1)}; | |||
| } | |||
| void ShuffleRNGForward::init_output_static_infer_desc() { | |||
| using namespace cg::static_infer; | |||
| auto&& mgr = owner_graph()->static_infer_manager(); | |||
| mgr.register_shape_infer(output(0), | |||
| ShapeInferDesc::make_identity(input(0))); | |||
| auto infer_oshp1 = [this](TensorShape& dest, const InpVal& iv) { | |||
| TensorLayout o0, o1; | |||
| m_dnn_opr->deduce_layout({iv.val[0].shape(), input(0)->dtype()}, o0, | |||
| o1); | |||
| dest = o1; | |||
| return true; | |||
| }; | |||
| mgr.register_shape_infer( | |||
| output(1), | |||
| {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_oshp1}); | |||
| auto infer_wk = [this](TensorShape& dest, const InpVal& inp) { | |||
| ensure_megdnn_opr(); | |||
| dest.ndim = 1; | |||
| dest.shape[0] = m_dnn_opr->get_workspace_in_bytes( | |||
| {inp.val[0].shape(), input(0)->dtype()}, | |||
| {output(0)->shape(), output(0)->dtype()}, | |||
| {output(1)->shape(), output(1)->dtype()}); | |||
| return true; | |||
| }; | |||
| mgr.register_shape_infer( | |||
| output(2), | |||
| {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_wk}); | |||
| } | |||
| void ShuffleRNGForward::add_input_layout_constraint() { | |||
| input(0)->add_layout_constraint_contiguous(); | |||
| }; | |||
| void ShuffleRNGForward::scn_do_execute() { | |||
| m_dnn_opr->exec(input(0)->dev_tensor().as_megdnn(), | |||
| output(0)->dev_tensor().as_megdnn(), | |||
| output(1)->dev_tensor().as_megdnn(), | |||
| get_megdnn_workspace_from_var(output(2))); | |||
| } | |||
| #if MGB_ENABLE_GRAD | |||
| MGB_IMPL_OPR_GRAD(ShuffleRNGForward) { | |||
| mgb_assert(out_grad.size() == 3 && wrt_idx == 0 && !out_grad[2]); | |||
| if (!out_grad[0]) | |||
| return nullptr; | |||
| return ShuffleRNGBackward::make(out_grad[0], opr.output(1), opr.input(0)).node(); | |||
| } | |||
| #endif | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShuffleRNGBackward); | |||
| MEGDNN_OPR_INIT3(ShuffleRNGBackward, "shuffle_rng_bwd", 2, true) | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -6,7 +6,8 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "megbrain/opr/rand.h" | |||
| @@ -14,6 +15,23 @@ | |||
| namespace mgb { | |||
| namespace serialization { | |||
| template <> | |||
| struct OprMaker<opr::ShuffleRNG, 1> { | |||
| using Opr = opr::ShuffleRNG; | |||
| using Param = Opr::Param; | |||
| static cg::OperatorNodeBase* make(const Param& param, | |||
| const cg::VarNodeArray& inputs, | |||
| ComputingGraph& graph, | |||
| const OperatorNodeConfig& config) { | |||
| MGB_MARK_USED_VAR(graph); | |||
| auto out = Opr::make(inputs[0], param, config); | |||
| return out[0].node()->owner_opr(); | |||
| } | |||
| }; | |||
| } // namespace serialization | |||
| namespace opr { | |||
| using UniformRNGV1 = opr::UniformRNG; | |||
| @@ -24,9 +42,10 @@ MGB_SEREG_OPR(GammaRNG, 2); | |||
| MGB_SEREG_OPR(PoissonRNG, 1); | |||
| MGB_SEREG_OPR(PermutationRNG, 1); | |||
| MGB_SEREG_OPR(BetaRNG, 2); | |||
| MGB_SEREG_OPR(ShuffleRNG, 1); | |||
| MGB_SEREG_OPR(ShuffleRNGBackward, 3); | |||
| } // namespace opr | |||
| } // namespace mgb | |||
| } // namespace opr | |||
| } // namespace mgb | |||
| // vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -6,14 +6,15 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "megbrain/graph.h" | |||
| #include "megbrain/opr/internal/out_shape_by_sym_var.h" | |||
| #include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||
| #include "megbrain/opr/internal/out_shape_by_sym_var.h" | |||
| #include "megdnn/oprs.h" | |||
| namespace mgb { | |||
| @@ -41,22 +42,24 @@ MGB_DEFINE_CLS_WITH_SUPER(RNGOprBase, cg::SingleCNOperatorNodeBase) // { | |||
| }; | |||
| /* ================= RNG with shape ================= */ | |||
| #define _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(RNG) \ | |||
| MGB_DEFINE_OPR_CLASS(RNG,RNGOprBase<megdnn::RNG>) \ | |||
| cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override; \ | |||
| public: \ | |||
| RNG(VarNode *shape, const Param ¶m, const OperatorNodeConfig &config); \ | |||
| static SymbolVar make(SymbolVar shape, const Param ¶m = {}, \ | |||
| const OperatorNodeConfig &config = {}); \ | |||
| static SymbolVar make(ComputingGraph &graph, const TensorShape &shape, \ | |||
| const OperatorNodeConfig &config, \ | |||
| const Param ¶m = {}) { \ | |||
| return make(var_from_tensor_shape(graph, config, "rng", shape), \ | |||
| param, config); \ | |||
| } \ | |||
| void init_output_static_infer_desc() override; \ | |||
| void scn_do_execute() override; \ | |||
| }; | |||
| #define _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(RNG) \ | |||
| MGB_DEFINE_OPR_CLASS(RNG, RNGOprBase<megdnn::RNG>) \ | |||
| cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override; \ | |||
| \ | |||
| public: \ | |||
| RNG(VarNode* shape, const Param& param, const OperatorNodeConfig& config); \ | |||
| static SymbolVar make(SymbolVar shape, const Param& param = {}, \ | |||
| const OperatorNodeConfig& config = {}); \ | |||
| static SymbolVar make(ComputingGraph& graph, const TensorShape& shape, \ | |||
| const OperatorNodeConfig& config, \ | |||
| const Param& param = {}) { \ | |||
| return make(var_from_tensor_shape(graph, config, "rng", shape), param, \ | |||
| config); \ | |||
| } \ | |||
| void init_output_static_infer_desc() override; \ | |||
| void scn_do_execute() override; \ | |||
| } \ | |||
| ; | |||
| _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(UniformRNG) | |||
| _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(GaussianRNG) | |||
| @@ -71,7 +74,7 @@ MGB_DEFINE_OPR_CLASS(RNG, RNGOprBase<megdnn::RNG>) | |||
| public: \ | |||
| RNG(_INPUTS(VarNode*), const Param ¶m, \ | |||
| const OperatorNodeConfig &config); \ | |||
| static SymbolVar make(_INPUTS(SymbolVar),const Param ¶m = {}, \ | |||
| static _OUTPUTS make(_INPUTS(SymbolVar),const Param ¶m = {}, \ | |||
| const OperatorNodeConfig &config = {}); \ | |||
| void init_output_static_infer_desc() override; \ | |||
| void scn_do_execute() override; \ | |||
| @@ -79,17 +82,24 @@ MGB_DEFINE_OPR_CLASS(RNG, RNGOprBase<megdnn::RNG>) | |||
| /* ================= 1 input ================= */ | |||
| #define _INPUTS(preifx) preifx i0 | |||
| #define _OUTPUTS SymbolVar | |||
| _DEFINE_RNG_OPR_WITH_INPUT_CLASS(PoissonRNG) | |||
| #undef _OUTPUTS | |||
| #define _OUTPUTS SymbolVarArray | |||
| _DEFINE_RNG_OPR_WITH_INPUT_CLASS(ShuffleRNGForward) | |||
| #undef _OUTPUTS | |||
| #undef _INPUTS | |||
| /* ================= 2 input ================= */ | |||
| #define _INPUTS(preifx) preifx i0, preifx i1 | |||
| #define _OUTPUTS SymbolVar | |||
| _DEFINE_RNG_OPR_WITH_INPUT_CLASS(BetaRNG) | |||
| _DEFINE_RNG_OPR_WITH_INPUT_CLASS(GammaRNG) | |||
| #undef _OUTPUTS | |||
| #undef _INPUTS | |||
| #undef _DEFINE_RNG_OPR_WITH_INPUT_CLASS | |||
| } // intl | |||
| } // intl | |||
| using UniformRNG = intl::UniformRNG; | |||
| using GaussianRNG = intl::GaussianRNG; | |||
| @@ -97,9 +107,20 @@ using GammaRNG = intl::GammaRNG; | |||
| using PermutationRNG = intl::PermutationRNG; | |||
| using PoissonRNG = intl::PoissonRNG; | |||
| using BetaRNG = intl::BetaRNG; | |||
| } // namespace opr | |||
| } // namespace mgb | |||
| using ShuffleRNG = intl::ShuffleRNGForward; | |||
| MGB_DEFINE_OPR_CLASS(ShuffleRNGBackward, | |||
| intl::MegDNNOprWrapperBwd<megdnn::ShuffleRNGBackward>) //{ | |||
| public: | |||
| ShuffleRNGBackward(VarNode* out_diff, VarNode* indices, VarNode* result_shape, | |||
| const Param& param, const OperatorNodeConfig& config); | |||
| // vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| static SymbolVar make(SymbolVar out_diff, SymbolVar indices, | |||
| SymbolVar result_shape, const Param& param = {}, | |||
| const OperatorNodeConfig& config = {}); | |||
| }; | |||
| } // namespace opr | |||
| } // namespace mgb | |||
| // vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -333,6 +333,38 @@ TEST(TestOprRand, EmptyShape) { | |||
| } | |||
| TEST(TestOprRand, ShuffleForward) { | |||
| auto run = [&](TensorShape shape) { | |||
| std::shared_ptr<HostTensorND> src_host(new HostTensorND{ | |||
| CompNode::load("xpux"), shape, dtype::Float32()}); | |||
| auto sptr = src_host->ptr<dt_float32>(); | |||
| auto size = shape.total_nr_elems(); | |||
| for (size_t i = 0; i < size; ++i) { | |||
| sptr[i] = i; | |||
| } | |||
| auto graph = ComputingGraph::make(); | |||
| auto src_sym = opr::Host2DeviceCopy::make(*graph, src_host); | |||
| auto rec = opr::ShuffleRNG::make(src_sym, {10}); | |||
| HostTensorND host_y, host_index; | |||
| auto func = graph->compile({make_callback_copy(rec[0], host_y), | |||
| make_callback_copy(rec[1], host_index)}); | |||
| func->execute(); | |||
| auto dptr = host_y.ptr<dt_float32>(); | |||
| auto iptr = host_index.ptr<dt_int32>(); | |||
| size_t len = shape[0]; | |||
| size_t step = size / len; | |||
| for (size_t i = 0; i < len; ++i) { | |||
| for (size_t j = 0; j < step; ++j) { | |||
| assert(dptr[i * step + j] == sptr[iptr[i] * step + j]); | |||
| } | |||
| } | |||
| }; | |||
| run({10}); | |||
| run({6, 3}); | |||
| run({1, 1}); | |||
| } | |||
| TEST(TestOprRand, UniformReprod) { | |||
| static constexpr size_t SIZE = 123; | |||
| auto graph = ComputingGraph::make(); | |||
| @@ -114,6 +114,7 @@ union OperatorParam { | |||
| param.BetaRNG = 80, | |||
| param.SlidingWindowTranspose = 81, | |||
| param.Padding = 82, | |||
| param.ShuffleRNG = 83, | |||
| } | |||
| table Operator { | |||