You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

num_range_checker.cpp 3.1 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. #include "megbrain/plugin/num_range_checker.h"
  2. #include "megbrain/opr/basic_arith.h"
  3. #include "megbrain/opr/io.h"
  4. #include "megbrain/opr/loop.h"
  5. #include "megbrain/test/helper.h"
  6. using namespace mgb;
  7. TEST(TestNumRangeChecker, Simple) {
  8. HostTensorGenerator<> gen;
  9. auto graph = ComputingGraph::make();
  10. NumRangeChecker checker{graph.get(), 1e30f};
  11. auto av = gen({3}), bv = gen({3});
  12. auto a = opr::Host2DeviceCopy::make(*graph, av),
  13. b = opr::Host2DeviceCopy::make(*graph, bv), c = a / b;
  14. auto func = graph->compile({{c, {}}});
  15. auto pb = bv->ptr<float>();
  16. pb[0] = 2;
  17. pb[1] = -1;
  18. pb[2] = 3;
  19. func->execute();
  20. pb[1] = 0;
  21. ASSERT_THROW(func->execute(), NumRangeChecker::Error);
  22. }
  23. TEST(TestNumRangeChecker, MultiDType) {
  24. HostTensorGenerator<dtype::Int32> gen;
  25. auto graph = ComputingGraph::make();
  26. NumRangeChecker checker{graph.get(), 1e30f};
  27. auto av = gen({3});
  28. auto a = opr::Host2DeviceCopy::make(*graph, av), b = a + a,
  29. c = opr::TypeCvt::make(b, dtype::Float32());
  30. auto func = graph->compile({{c, {}}});
  31. func->execute();
  32. }
  33. TEST(TestNumRangeChecker, MultiShape) {
  34. HostTensorGenerator<> gen;
  35. auto graph = ComputingGraph::make();
  36. NumRangeChecker checker{graph.get(), 1e30f};
  37. auto av = gen({1, 3}), bv = gen({3, 1});
  38. auto a = opr::Host2DeviceCopy::make(*graph, av),
  39. b = opr::Host2DeviceCopy::make(*graph, bv), c = (a + 2) / (b - 4);
  40. auto func = graph->compile({{c, {}}});
  41. auto pb = bv->ptr<float>();
  42. pb[0] = 2;
  43. pb[1] = -1;
  44. pb[2] = 3;
  45. func->execute();
  46. pb[2] = 4;
  47. ASSERT_THROW(func->execute(), NumRangeChecker::Error);
  48. }
  49. TEST(TestNumRangeChecker, Loop) {
  50. HostTensorGenerator<> gen;
  51. auto graph = ComputingGraph::make();
  52. NumRangeChecker checker{graph.get(), 1e30f};
  53. auto av = gen({3}), bv = gen({3});
  54. auto a = opr::Host2DeviceCopy::make(*graph, av),
  55. b = opr::Host2DeviceCopy::make(*graph, bv);
  56. auto loop_cb = [&](opr::Loop::Desc& desc) {
  57. auto ai = desc.add_input(a), bi = desc.add_input(b);
  58. desc.set_loop_condition(desc.get_counter_var() < 0);
  59. auto out = ai + bi;
  60. desc.add_output(out, opr::Loop::Desc::OutputMode::LAST);
  61. out.node()->owner_graph()->options().extra_vardeps[out.node()].push_back(
  62. (ai / bi).node());
  63. };
  64. auto c = opr::Loop::make(loop_cb)[0];
  65. HostTensorND host_c;
  66. auto func = graph->compile({make_callback_copy(c, host_c)});
  67. auto pb = bv->ptr<float>();
  68. pb[0] = 2;
  69. pb[1] = -1;
  70. pb[2] = 3;
  71. func->execute();
  72. pb[1] = 0;
  73. ASSERT_THROW(func->execute(), NumRangeChecker::Error);
  74. }
  75. TEST(TestNumRangeChecker, MultiStreamDyn) {
  76. auto cns = load_multiple_xpus(2);
  77. HostTensorGenerator<> gen;
  78. auto graph = ComputingGraph::make();
  79. NumRangeChecker checker{graph.get(), 1e30f};
  80. auto xv = gen({3}, cns[0]);
  81. auto x = opr::Host2DeviceCopy::make(*graph, xv), y = opr::Copy::make(x, cns[1]);
  82. auto func = graph->compile({{y, {}}});
  83. func->execute();
  84. }
  85. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}