
num_range_checker.cpp 4.3 kB

/**
 * \file src/plugin/impl/num_range_checker.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
 * express or implied.
 */

#include "megbrain/opr/io.h"
#include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/plugin/num_range_checker.h"
#include "megbrain/graph/exc_extra_info.h"

#include "megdnn/tensor_iter.h"

#include <cmath>

using namespace mgb;

void NumRangeChecker::Checker::init(VarNode* var, float range) {
    if (m_func)
        return;
    m_inp = std::make_shared<DeviceTensorND>(var->comp_node(), var->dtype());
    m_out = std::make_unique<HostTensorND>();
    // build a tiny dedicated graph computing reduce_min(abs(x) < range): the
    // result is 1 iff every element of x lies strictly within the range
    auto cg = ComputingGraph::make();
    cg->options().log_level = 0;
    auto vi = opr::VolatileSharedDeviceTensor::make(*cg, m_inp),
         chk = opr::abs(vi) < range,
         good = opr::reduce_min(chk, chk.make_scalar(1));
    auto cb = [d = m_out.get()](DeviceTensorND& dv) { d->copy_from(dv).sync(); };
    m_func = cg->compile({{good, cb}});
}

bool NumRangeChecker::Checker::check(VarNode* var) {
    auto&& val = var->dev_tensor();
    if (val.layout().is_contiguous()) {
        *m_inp = var->dev_tensor();
    } else {
        // the shared input tensor must be contiguous, so copy
        // non-contiguous values before running the checker graph
        *m_inp = {};
        m_inp->copy_from(val);
    }
    m_func->execute();
    mgb_assert(m_out->shape().is_scalar());
    return m_out->ptr<float>()[0] >= 0.5;
}

NumRangeChecker::NumRangeChecker(cg::ComputingGraph* graph, float range)
        : PluginBase(graph), m_range{range} {
    add_member_func_as_event_handler(&NumRangeChecker::on_kern_end);
    add_member_func_as_event_handler(&NumRangeChecker::on_subgraph_associated);
}

void NumRangeChecker::on_kern_end(const cg::event::OprExecKernelEnd& event) {
    // after each operator finishes, check every non-volatile float output on
    // the comp node that produced it
    for (VarNode* var : event.opr->output()) {
        if (!var->contain_flag(VarNode::Flag::VOLATILE_CONTENT) &&
            var->dtype().category() == DTypeCategory::FLOAT) {
            event.env->dispatch_on_comp_node(
                    var->comp_node(), [this, var]() { on_var_computed(var); });
        }
    }
}

void NumRangeChecker::on_subgraph_associated(
        const cg::event::SubgraphAssociated& event) {
    // propagate checking into subgraphs as they get associated
    mgb_assert(event.par_graph == m_owner_graph);
    m_sub_graph_checkers.emplace_back(
            std::make_unique<NumRangeChecker>(event.sub_graph, m_range));
}

void NumRangeChecker::on_var_computed(VarNode* var) {
    if (!var->dev_tensor_valid())
        return;
    // checkers are created lazily, one per (comp node, dtype) pair
    auto&& checker = m_cn2dt2checker[var->comp_node()][var->dtype().enumv()];
    checker.init(var, m_range);
    if (!checker.check(var)) {
        // copy to host to locate the offending element for the error message
        HostTensorND hv;
        hv.copy_from(var->dev_tensor()).sync();
        std::string msg{mgb_ssprintf_log(
                "float value out of range: var: %s\n",
                cg::dump_var_info({var}).c_str())};
        switch (hv.dtype().enumv()) {
#define cb(_dt)                                                 \
    case DTypeTrait<_dt>::enumv:                                \
        msg += format_msg<DTypeTrait<_dt>::ctype>(hv, m_range); \
        break;
            MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
#undef cb
            default:
                mgb_assert(0, "unexpected dtype");
        }
        mgb_throw_raw(cg::OperatorNodeExcExtraInfo::ExcMaker{var->owner_opr()}
                              .make<Error>(msg));
    }
}

template <typename ctype>
std::string NumRangeChecker::format_msg(const HostTensorND& hv, float range) {
    // scan on host for the first element that fails fabs(val) < range and
    // report its multi-dimensional index
    auto iter = megdnn::tensor_iter<ctype>(hv.as_megdnn()).begin();
    for (size_t i = 0, it = hv.shape().total_nr_elems(); i < it; ++i) {
        float val = static_cast<float>(*iter);
        if (!(std::fabs(val) < range)) {
            TensorShape idx_shp;
            idx_shp.ndim = hv.shape().ndim;
            std::copy(iter.idx(), iter.idx() + idx_shp.ndim, idx_shp.shape);
            return mgb_ssprintf_log(
                    " value=%g range=%g index=%s/%s", val, range,
                    idx_shp.to_string().c_str(),
                    hv.shape().to_string().c_str());
        }
        ++iter;
    }
    return mgb_cstr_log(" <error: range check passed on host>");
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
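
For context on how this plugin is used: a NumRangeChecker is simply constructed against a graph before compilation; its constructor registers the event handlers above, so every non-volatile float output gets checked after each kernel finishes, and an out-of-range value aborts execution with the formatted message. The following is a minimal sketch, not verified against a specific MegEngine version; the opr and helper names mirror the patterns in this file and common MegEngine test code.

    #include "megbrain/graph.h"
    #include "megbrain/opr/basic_arith_wrapper.h"
    #include "megbrain/opr/io.h"
    #include "megbrain/plugin/num_range_checker.h"

    using namespace mgb;

    void run_checked_graph() {
        auto cn = CompNode::load("cpu0");
        auto graph = ComputingGraph::make();
        // attach the plugin before compiling; it must outlive the func
        NumRangeChecker checker{graph.get(), 1e3f};

        auto host_x = std::make_shared<HostTensorND>(
                cn, TensorShape{4}, dtype::Float32());
        for (size_t i = 0; i < 4; ++i)
            host_x->ptr<float>()[i] = 1.f;

        auto x = opr::Host2DeviceCopy::make(*graph, host_x);
        auto y = x * 1e4f;  // |y| == 1e4 >= 1e3, so the check fails

        HostTensorND host_y;
        auto cb = [&](DeviceTensorND& dv) { host_y.copy_from(dv).sync(); };
        auto func = graph->compile({{y, cb}});
        func->execute();  // expected to throw "float value out of range: ..."
    }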

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build to choose between. To run GPU programs, make sure the machine has a GPU and that its driver is properly installed. If you would like to try deep-learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.
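
One way to confirm from C++ (the language of this repository) that a CUDA device is actually visible is to ask CompNode for the device count. This is a hedged sketch that assumes CompNode::get_device_count is available in this form; treat it as illustrative rather than canonical.

    #include "megbrain/comp_node.h"
    #include <cstdio>

    int main() {
        // assumption: get_device_count(DeviceType::CUDA) reports visible GPUs
        auto n = mgb::CompNode::get_device_count(
                mgb::CompNode::DeviceType::CUDA);
        std::printf("CUDA devices visible to MegEngine: %zu\n", n);
        return n > 0 ? 0 : 1;  // nonzero exit if no usable GPU was found
    }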