|
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677 |
- /**
- * \file dnn/src/common/rng.cpp
- * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- *
- * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- */
-
- #include "megdnn/oprs.h"
-
- #include "src/common/utils.h"
-
- namespace megdnn {
-
- void PermutationRNG::check_exec(
- const TensorLayout &dst, size_t workspace_in_bytes) {
- megdnn_assert((dst.dtype == dtype::Float32() ||
- dst.dtype == dtype::Int32() ||
- dst.dtype == dtype::Int16() ) &&
- dst.dtype.enumv() == param().dtype &&
- dst.is_contiguous());
- megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(dst));
- }
-
- void PoissonRNG::check_exec(const TensorLayout &lam, const TensorLayout &dst,
- size_t workspace_in_bytes){
- megdnn_assert( dst.dtype.category() == DTypeCategory::FLOAT &&
- lam.dtype == dst.dtype);
- megdnn_assert(dst.is_contiguous() && lam.is_contiguous());
- megdnn_assert(lam.total_nr_elems() == dst.total_nr_elems());
- megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(lam, dst));
- }
-
- void GammaRNG::check_exec(const TensorLayout &shape,const TensorLayout &scale,
- const TensorLayout &dst, size_t workspace_in_bytes){
- megdnn_assert(dst.dtype.category() == DTypeCategory::FLOAT &&
- shape.dtype == dst.dtype &&
- scale.dtype == dst.dtype);
- megdnn_assert(shape.is_contiguous() && scale.is_contiguous()
- && dst.is_contiguous());
- megdnn_assert(shape.total_nr_elems() == dst.total_nr_elems() &&
- scale.total_nr_elems() == dst.total_nr_elems());
- megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(shape,scale,dst));
- }
-
- void BetaRNG::check_exec(const TensorLayout &alpha,const TensorLayout &beta,
- const TensorLayout &dst, size_t workspace_in_bytes){
- megdnn_assert(dst.dtype.category() == DTypeCategory::FLOAT &&
- alpha.dtype == dst.dtype &&
- beta.dtype == dst.dtype);
- megdnn_assert(alpha.is_contiguous() && beta.is_contiguous()
- && dst.is_contiguous());
- megdnn_assert(alpha.total_nr_elems() == dst.total_nr_elems() &&
- beta.total_nr_elems() == dst.total_nr_elems());
- megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(alpha,beta, dst));
- }
-
- #define INST_CHECK_EXEC(RNG_NAME) \
- void RNG_NAME::check_exec( \
- const TensorLayout &dst, size_t workspace_in_bytes) { \
- megdnn_assert(dst.dtype.category() == DTypeCategory::FLOAT && \
- dst.dtype.enumv() == param().dtype && \
- dst.is_contiguous()); \
- megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(dst)); \
- }
-
- INST_CHECK_EXEC(UniformRNG)
- INST_CHECK_EXEC(GaussianRNG)
- #undef INST_CHECK_EXEC
-
- } // namespace megdnn
-
- // vim: syntax=cpp.doxygen
-
|