You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

numerical_diff.cpp 2.5 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. /**
  2. * \file test/src/numerical_diff.cpp
  3. *
  4. * This file is part of MegBrain, a deep learning framework developed by Megvii.
  5. *
  6. * \copyright Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  7. *
  8. */
#include "megbrain/test/numerical_diff.h"
#include "megbrain/common.h"
#include "megbrain/utils/timer.h"

#include <algorithm>
#include <cmath>
#include <limits>
  14. using namespace mgb;
  15. std::vector<HostTensorND> mgb::numerical_diff_pt2(
  16. const std::vector<HostTensorND*> &input,
  17. std::function<float()> cost,
  18. const std::vector<Maybe<float>> &eps) {
  19. std::vector<HostTensorND> result;
  20. if (!eps.empty())
  21. mgb_assert(eps.size() == input.size());
  22. for (size_t cur_inp_idx = 0; cur_inp_idx < input.size(); ++ cur_inp_idx)
  23. {
  24. result.emplace_back();
  25. if (!input[cur_inp_idx])
  26. continue;
  27. auto &&cur_inp = input[cur_inp_idx];
  28. auto &&dest = result.back();
  29. dest.comp_node(cur_inp->comp_node()).
  30. dtype(cur_inp->dtype()).
  31. resize(cur_inp->shape());
  32. auto dptr = dest.ptr<float>();
  33. mgb_assert(cur_inp->layout().is_contiguous());
  34. auto cur_inp_ptr = cur_inp->ptr<float>();
  35. mgb::RealTimer timer;
  36. double prev_record = 0.0;
  37. for (size_t i = 0, it = cur_inp->layout().total_nr_elems();
  38. i < it; ++ i) {
  39. auto orig = cur_inp_ptr[i];
  40. float delta;
  41. if (eps.empty() || !eps[cur_inp_idx].valid()) {
  42. delta = std::sqrt(std::numeric_limits<float>::epsilon()) *
  43. std::max<float>(std::fabs(orig), 1);
  44. } else {
  45. delta = eps[cur_inp_idx].val();
  46. }
  47. cur_inp_ptr[i] = orig - delta;
  48. auto c0 = cost();
  49. cur_inp_ptr[i] = orig + delta;
  50. auto c1 = cost();
  51. cur_inp_ptr[i] = orig;
  52. auto cur_time = timer.get_secs();
  53. if (cur_time - prev_record > 10) {
  54. prev_record = cur_time;
  55. mgb_log_warn(
  56. "numerical diff running for more than %.3f secs, "
  57. "consider to reduce the tensor size", cur_time);
  58. }
  59. dptr[i] = (c1 - c0) / (delta * 2);
  60. }
  61. }
  62. return result;
  63. }
namespace mgb {
// Explicit instantiation of Maybe<float> in this translation unit, so that the
// members of Maybe<float> (in particular the constructor Maybe::Maybe()) are
// emitted here; this avoids a link error for Maybe::Maybe().
template class Maybe<float>;
}  // namespace mgb
  68. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。