You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

numerical_diff.cpp 2.5 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. /**
  2. * \file test/src/numerical_diff.cpp
  3. *
  4. * This file is part of MegBrain, a deep learning framework developed by Megvii.
  5. *
  6. * \copyright Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  7. *
  8. */
  9. #include "megbrain/test/numerical_diff.h"
  10. #include "megbrain/common.h"
  11. #include "megbrain/utils/timer.h"
  12. #include <cmath>
  13. #include <limits>
  14. using namespace mgb;
  15. std::vector<HostTensorND> mgb::numerical_diff_pt2(
  16. const std::vector<HostTensorND*>& input, std::function<float()> cost,
  17. const std::vector<Maybe<float>>& eps) {
  18. std::vector<HostTensorND> result;
  19. if (!eps.empty())
  20. mgb_assert(eps.size() == input.size());
  21. for (size_t cur_inp_idx = 0; cur_inp_idx < input.size(); ++cur_inp_idx) {
  22. result.emplace_back();
  23. if (!input[cur_inp_idx])
  24. continue;
  25. auto&& cur_inp = input[cur_inp_idx];
  26. auto&& dest = result.back();
  27. dest.comp_node(cur_inp->comp_node())
  28. .dtype(cur_inp->dtype())
  29. .resize(cur_inp->shape());
  30. auto dptr = dest.ptr<float>();
  31. mgb_assert(cur_inp->layout().is_contiguous() || cur_inp->layout().is_empty());
  32. auto cur_inp_ptr = cur_inp->ptr<float>();
  33. mgb::RealTimer timer;
  34. double prev_record = 0.0;
  35. for (size_t i = 0, it = cur_inp->layout().total_nr_elems(); i < it; ++i) {
  36. auto orig = cur_inp_ptr[i];
  37. float delta;
  38. if (eps.empty() || !eps[cur_inp_idx].valid()) {
  39. delta = std::sqrt(std::numeric_limits<float>::epsilon()) *
  40. std::max<float>(std::fabs(orig), 1);
  41. } else {
  42. delta = eps[cur_inp_idx].val();
  43. }
  44. cur_inp_ptr[i] = orig - delta;
  45. auto c0 = cost();
  46. cur_inp_ptr[i] = orig + delta;
  47. auto c1 = cost();
  48. cur_inp_ptr[i] = orig;
  49. auto cur_time = timer.get_secs();
  50. if (cur_time - prev_record > 10) {
  51. prev_record = cur_time;
  52. mgb_log_warn(
  53. "numerical diff running for more than %.3f secs, "
  54. "consider to reduce the tensor size",
  55. cur_time);
  56. }
  57. dptr[i] = (c1 - c0) / (delta * 2);
  58. }
  59. }
  60. return result;
  61. }
  62. namespace mgb {
  63. // explicit inst to avoid link error for Maybe::Maybe()
  64. template class Maybe<float>;
  65. } // namespace mgb
  66. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}