You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

rng.cpp 3.3 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. /**
  2. * \file dnn/src/common/rng.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megdnn/oprs.h"
  12. #include "src/common/utils.h"
  13. namespace megdnn {
  14. void PermutationRNG::check_exec(
  15. const TensorLayout &dst, size_t workspace_in_bytes) {
  16. megdnn_assert((dst.dtype == dtype::Float32() ||
  17. dst.dtype == dtype::Int32() ||
  18. dst.dtype == dtype::Int16() ) &&
  19. dst.dtype.enumv() == param().dtype &&
  20. dst.is_contiguous());
  21. megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(dst));
  22. }
  23. void PoissonRNG::check_exec(const TensorLayout &lam, const TensorLayout &dst,
  24. size_t workspace_in_bytes){
  25. megdnn_assert( dst.dtype.category() == DTypeCategory::FLOAT &&
  26. lam.dtype == dst.dtype);
  27. megdnn_assert(dst.is_contiguous() && lam.is_contiguous());
  28. megdnn_assert(lam.total_nr_elems() == dst.total_nr_elems());
  29. megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(lam, dst));
  30. }
  31. void GammaRNG::check_exec(const TensorLayout &shape,const TensorLayout &scale,
  32. const TensorLayout &dst, size_t workspace_in_bytes){
  33. megdnn_assert(dst.dtype.category() == DTypeCategory::FLOAT &&
  34. shape.dtype == dst.dtype &&
  35. scale.dtype == dst.dtype);
  36. megdnn_assert(shape.is_contiguous() && scale.is_contiguous()
  37. && dst.is_contiguous());
  38. megdnn_assert(shape.total_nr_elems() == dst.total_nr_elems() &&
  39. scale.total_nr_elems() == dst.total_nr_elems());
  40. megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(shape,scale,dst));
  41. }
  42. void BetaRNG::check_exec(const TensorLayout &alpha,const TensorLayout &beta,
  43. const TensorLayout &dst, size_t workspace_in_bytes){
  44. megdnn_assert(dst.dtype.category() == DTypeCategory::FLOAT &&
  45. alpha.dtype == dst.dtype &&
  46. beta.dtype == dst.dtype);
  47. megdnn_assert(alpha.is_contiguous() && beta.is_contiguous()
  48. && dst.is_contiguous());
  49. megdnn_assert(alpha.total_nr_elems() == dst.total_nr_elems() &&
  50. beta.total_nr_elems() == dst.total_nr_elems());
  51. megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(alpha,beta, dst));
  52. }
  53. #define INST_CHECK_EXEC(RNG_NAME) \
  54. void RNG_NAME::check_exec( \
  55. const TensorLayout &dst, size_t workspace_in_bytes) { \
  56. megdnn_assert(dst.dtype.category() == DTypeCategory::FLOAT && \
  57. dst.dtype.enumv() == param().dtype && \
  58. dst.is_contiguous()); \
  59. megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(dst)); \
  60. }
  61. INST_CHECK_EXEC(UniformRNG)
  62. INST_CHECK_EXEC(GaussianRNG)
  63. #undef INST_CHECK_EXEC
  64. } // namespace megdnn
  65. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台