You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

host_static_calc.cpp 1.2 kB

1234567891011121314151617181920212223242526272829303132333435363738
  1. /**
  2. * \file test/src/host_static_calc.cpp
  3. *
  4. * This file is part of MegBrain, a deep learning framework developed by Megvii.
  5. *
  6. * \brief static calculating on host to check opr correctness
  7. *
  8. * \copyright Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  9. *
  10. */
  11. #include "megbrain/test/host_static_calc.h"
  12. void mgb::elemwise_static_calc(opr::Elemwise::Mode mode,
  13. HostTensorND &dest, const std::vector<HostTensorND>& inputs) {
  14. #if defined(ANDROID) || defined(IOS) || defined(__arm__)
  15. static opr::intl::UniqPtrWithCN<megdnn::Elemwise> opr_impl;
  16. static std::mutex mtx;
  17. MGB_LOCK_GUARD(mtx);
  18. #else
  19. static thread_local opr::intl::UniqPtrWithCN<megdnn::Elemwise> opr_impl;
  20. #endif
  21. auto cn = CompNode::default_cpu();
  22. if (!opr_impl) {
  23. opr_impl = opr::intl::create_megdnn_opr<megdnn::Elemwise>(cn);
  24. }
  25. DeviceTensorND dev_dest{cn};
  26. SmallVector<DeviceTensorND> dev_inp(inputs.size());
  27. for (size_t i = 0; i < inputs.size(); ++ i) {
  28. dev_inp[i].comp_node(cn).copy_from(inputs[i]);
  29. }
  30. opr::Elemwise::perform(mode, dev_dest, dev_inp, opr_impl);
  31. dest.copy_from(dev_dest);
  32. }
  33. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。