  1. /**
  2. * \file src/opr/test/cond.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megbrain/test/helper.h"
  12. #include "megbrain/opr/basic_arith_wrapper.h"
  13. #include "megbrain/opr/cond.h"
  14. #include "megbrain/opr/io.h"
  15. #include "megbrain/opr/misc.h"
  16. #include "megbrain/opr/tensor_manip.h"
  17. #include "megbrain/opr/utility.h"
  18. #include "megbrain/utils/timer.h"
  19. #include <bitset>
  20. #if MGB_ENABLE_COND_EXEC
  21. using namespace mgb;
  22. namespace {
  23. using MergeMode = opr::CondExecMerge::Param::Mode;
  24. //! return y = (pred == 1 ? x : null)
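  //! pred is matched against nr_branch keys by CondExecPred (all keys are 0
  //! except the one at this_branch, which is 1), so the CondExecMark output
  //! forwards x exactly when pred == 1; grad_cond_out switches the mark to the
  //! SUM_COND_OUT grad mode, making the gradient wrt x conditional as well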
  25. SymbolVar make_one_cond(
  26. SymbolVar pred, SymbolVar x, size_t nr_branch = 1, size_t this_branch = 0,
  27. bool grad_cond_out = false) {
  28. SymbolVar xcond;
  29. SymbolVarArray keys(nr_branch, pred.make_scalar_dt(0));
  30. keys.at(this_branch) = pred.make_scalar_dt(1);
  31. auto masks = opr::CondExecPred::make(pred, keys);
  32. EXPECT_EQ(nr_branch, masks.size());
  33. using Param = opr::CondExecMark::Param;
  34. Param p;
  35. if (grad_cond_out) {
  36. p.grad_mode = Param::GradMode::SUM_COND_OUT;
  37. }
  38. unpack_vector(opr::CondExecMark::make(masks.at(this_branch), {x}, p), xcond);
  39. return xcond;
  40. }
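  //! wrap x in a CallbackInjector that bumps *cnt on every execution, so tests
  //! can count how many times a conditionally executed branch actually ran
  //! (invoke_for_static_infer is disabled so static inference does not count)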
  41. SymbolVar make_call_rec(SymbolVar x, int* cnt) {
  42. auto cb = [cnt](DeviceTensorND&) { ++*cnt; };
  43. opr::CallbackInjector::Param param{cb};
  44. param.invoke_for_static_infer = false;
  45. return opr::CallbackInjector::make(x, param);
  46. }
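  //! merge the branch outputs in inputs_orig with the given mode; nr_distractor
  //! extra vars are interleaved per branch so the underlying CondExecMerge opr
  //! produces nr_distractor + 1 outputs, and only the output at index
  //! nr_distractor / 2 (the one returned here) corresponds to inputs_orig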
  47. SymbolVar merge_one_out(
  48. const SymbolVarArray& inputs_orig, MergeMode mode, size_t nr_distractor = 0,
  49. const VarNodeArrayView& out_shapes = {}) {
  50. SymbolVarArray inputs;
  51. for (size_t i = 0; i < inputs_orig.size(); ++i) {
  52. for (size_t j = 0; j <= nr_distractor; ++j) {
  53. if (j == nr_distractor / 2) {
  54. inputs.push_back(inputs_orig[i]);
  55. } else {
  56. inputs.push_back(inputs_orig[i] + int(i * (nr_distractor + 1) + j + 1));
  57. }
  58. }
  59. }
  60. auto out = opr::CondExecMerge::make(
  61. inputs, {static_cast<uint32_t>(nr_distractor + 1), mode}, out_shapes);
  62. EXPECT_EQ(nr_distractor + 1, out.size());
  63. return out[nr_distractor / 2];
  64. }
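  //! exercise CondExecMerge: two inputs guarded by independent predicates are
  //! merged with the given mode, and callback counters check that only the
  //! active branches execute; when final_sum is set (SUM_COND_OUT only), the
  //! result is wrapped in an extra SUM merge to also check the ExecutionMask
  //! produced by the SUM_COND_OUT merge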
  65. void test_merge_opr(MergeMode mode, bool pred_dynamic, bool final_sum) {
  66. if (final_sum && mode != MergeMode::SUM_COND_OUT) {
  67. return;
  68. }
  69. auto graph = ComputingGraph::make();
  70. graph->options().graph_opt_level = 0;
  71. HostTensorGenerator<> gen;
  72. HostTensorGenerator<dtype::Int32> gen_int;
  73. auto host_inp0 = gen({2, 3}), host_inp1 = gen({2, 3}), host_pred0 = gen_int({1}),
  74. host_pred1 = gen_int({1});
  75. host_pred0->ptr<int>()[0] = 0;
  76. host_pred1->ptr<int>()[0] = 1;
  77. SymbolVar inp0 = opr::Host2DeviceCopy::make_no_fwd(*graph, host_inp0),
  78. inp1 = opr::Host2DeviceCopy::make_no_fwd(*graph, host_inp1),
  79. pred0 = opr::Host2DeviceCopy::make(*graph, host_pred0),
  80. pred1 = opr::Host2DeviceCopy::make(*graph, host_pred1);
  81. if (pred_dynamic) {
  82. pred0 = opr::MarkDynamicVar::make(pred0);
  83. pred1 = opr::MarkDynamicVar::make(pred1);
  84. }
  85. int call0 = 0, call1 = 0, call2 = 0, call3 = 0;
  86. SymbolVar inp0_cond = make_call_rec(make_one_cond(pred0, inp0, 3, 2) / 2, &call0),
  87. inp1_cond = make_call_rec(make_one_cond(pred1, inp1, 4, 1) * 3, &call1),
  88. merged = merge_one_out({inp0_cond, inp1_cond}, mode, 3), out;
  89. if (final_sum) {
  90. // check for ExecutionMask produced by CondExecMerge
  91. out = make_call_rec(merged, &call3);
  92. out = merge_one_out({out}, MergeMode::SUM, 2) - 1;
  93. out = make_call_rec(out, &call2);
  94. mode = MergeMode::SUM;
  95. } else {
  96. out = make_call_rec(merged, &call2) - 1;
  97. }
  98. auto make_expect = [&](int pred0, int pred1) {
  99. HostTensorND ret{host_inp0->comp_node(), host_inp0->shape()};
  100. auto p0 = host_inp0->ptr<float>(), p1 = host_inp1->ptr<float>(),
  101. pr = ret.ptr<float>();
  102. for (size_t i = 0, it = ret.shape().total_nr_elems(); i < it; ++i) {
  103. pr[i] = -1;
  104. if (pred0) {
  105. pr[i] += p0[i] / 2;
  106. }
  107. if (pred1) {
  108. pr[i] += p1[i] * 3;
  109. }
  110. }
  111. return ret;
  112. };
  113. // static inference helper
  114. auto updater_shp = cg::static_infer::StaticInferUpdater::make(),
  115. updater_val = cg::static_infer::StaticInferUpdater::make();
  116. using IDType = cg::static_infer::DepType;
  117. if (!pred_dynamic) {
  118. updater_shp->add_dest({out.node(), IDType::SHAPE});
  119. updater_val->add_dest({out.node(), IDType::VALUE});
  120. } else if (mode != MergeMode::EXACT_ONE) {
  121. updater_shp->add_dest({out.node(), IDType::SHAPE});
  122. }
  123. auto infer_shape = [&]() {
  124. updater_shp->update();
  125. return graph->static_infer_manager().infer_shape(out.node());
  126. };
  127. auto infer_value = [&]() {
  128. updater_val->update();
  129. auto val = graph->static_infer_manager().infer_value(out.node());
  130. HostTensorND ret;
  131. ret.copy_from(val);
  132. return ret;
  133. };
  134. HostTensorND host_out;
  135. auto func = graph->compile({make_callback_copy(out, host_out)});
  136. auto check_all = [&](int pred0, int pred1) {
  137. call0 = call1 = call2 = call3 = 0;
  138. auto expect = make_expect(pred0, pred1);
  139. if (mode != MergeMode::EXACT_ONE || !pred_dynamic) {
  140. ASSERT_EQ(expect.shape(), infer_shape());
  141. }
  142. if (!pred_dynamic) {
  143. MGB_ASSERT_TENSOR_NEAR(expect, infer_value(), 1e-5);
  144. }
  145. func->execute();
  146. MGB_ASSERT_TENSOR_NEAR(expect, host_out, 1e-5);
  147. ASSERT_EQ(pred0, call0);
  148. ASSERT_EQ(pred1, call1);
  149. ASSERT_EQ(1, call2);
  150. if (final_sum) {
  151. ASSERT_EQ(pred0 || pred1, call3);
  152. }
  153. };
  154. for (size_t casenum = 0; casenum < 4; ++casenum) {
  155. int pred0 = casenum >> 1, pred1 = casenum & 1;
  156. host_pred0->ptr<int>()[0] = pred0;
  157. host_pred1->ptr<int>()[0] = pred1;
  158. *host_inp0 = *gen({2 + casenum, 3});
  159. *host_inp1 = *gen({2 + casenum, 3});
  160. switch (mode) {
  161. case MergeMode::EXACT_ONE:
  162. case MergeMode::EXACT_ONE_SAME_SHAPE: {
  163. if (pred0 + pred1 == 1) {
  164. check_all(pred0, pred1);
  165. ASSERT_EQ(
  166. prev_dev_ptr(pred0 ? inp0_cond : inp1_cond),
  167. prev_dev_ptr(merged));
  168. } else {
  169. if (mode == MergeMode::EXACT_ONE) {
  170. if (!pred_dynamic) {
  171. ASSERT_THROW(infer_shape(), MegBrainError);
  172. }
  173. } else {
  174. ASSERT_EQ(host_inp0->shape(), infer_shape());
  175. }
  176. if (!pred_dynamic) {
  177. ASSERT_THROW(infer_value(), MegBrainError);
  178. }
  179. ASSERT_THROW(func->execute(), MegBrainError);
  180. }
  181. break;
  182. }
  183. case MergeMode::SUM:
  184. case MergeMode::SUM_COND_OUT: {
  185. if (pred0 || pred1 || mode == MergeMode::SUM) {
  186. check_all(pred0, pred1);
  187. } else {
  188. // no branch is active and mode is SUM_COND_OUT, so nothing should execute
  189. ASSERT_EQ(host_inp0->shape(), infer_shape());
  190. call0 = call1 = call2 = 0;
  191. func->execute();
  192. ASSERT_EQ(0, call0);
  193. ASSERT_EQ(0, call1);
  194. ASSERT_EQ(0, call2);
  195. }
  196. break;
  197. }
  198. default:
  199. mgb_trap();
  200. }
  201. }
  202. }
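  //! check gradients through a three-branch CASE predicate: x feeds branches 0
  //! and 1, y feeds branch 2, and dloss/dx, dloss/dy are verified for the taken
  //! branch; with grad_cond_out, the callback counters also verify that only
  //! the grad of the taken branch is actually computed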
  203. void test_simple_grad(bool grad_cond_out) {
  204. auto graph = ComputingGraph::make();
  205. HostTensorGenerator<> gen;
  206. auto host_x = gen({2, 3}), host_y = gen({2, 3}), host_pred = gen({1});
  207. host_pred->ptr<float>()[0] = 0;
  208. auto x = opr::Host2DeviceCopy::make(*graph, host_x).rename("x"),
  209. y = opr::Host2DeviceCopy::make(*graph, host_y).rename("y"),
  210. pred = opr::Host2DeviceCopy::make(*graph, host_pred);
  211. auto branches = opr::CondExecPred::make(
  212. pred,
  213. {pred.make_scalar(0.f), pred.make_scalar(1.f), pred.make_scalar(2.f)});
  214. using GradMode = opr::CondExecMark::Param::GradMode;
  215. auto get_marked = [&branches, grad_cond_out](SymbolVar x, size_t br) {
  216. SymbolVar ret;
  217. unpack_vector(
  218. opr::CondExecMark::make(
  219. branches.at(br), {x},
  220. {grad_cond_out ? GradMode::SUM_COND_OUT : GradMode::SUM}),
  221. ret);
  222. return ret;
  223. };
  224. int call_x = 0, call_y = 0;
  225. auto cond_x0 = get_marked(x, 0).rename("cx0"),
  226. cond_x1 = get_marked(x, 1).rename("cx1"),
  227. cond_y = get_marked(y, 2).rename("cy"),
  228. z = merge_one_out(
  229. {cond_x0 * 2, cond_x1 * 3, cond_y * 2.3f},
  230. MergeMode::EXACT_ONE_SAME_SHAPE)
  231. .rename("merged"),
  232. loss = opr::reduce_sum_sqr(z + y, z.make_scalar(1)),
  233. gx = make_call_rec(cg::grad(loss, x), &call_x),
  234. gy = make_call_rec(cg::grad(loss, y), &call_y);
  235. std::array<float, 3> kx_all{2.f, 3.f, 0.f}, ky_all{1.f, 1.f, 3.3f};
  236. auto make_expect = [&](float kx, float ky, int wrt) {
  237. HostTensorND ret{host_x->comp_node(), host_x->shape()};
  238. auto pr = ret.ptr<float>(), px = host_x->ptr<float>(),
  239. py = host_y->ptr<float>();
  240. for (size_t i = 0, it = ret.shape().total_nr_elems(); i < it; ++i) {
  241. float s = px[i] * kx + py[i] * ky, ls = 2 * s;
  242. pr[i] = ls * (wrt ? ky : kx);
  243. }
  244. return ret;
  245. };
  246. HostTensorND host_gx, host_gy;
  247. auto func = graph->compile(
  248. {make_callback_copy(gx, host_gx), make_callback_copy(gy, host_gy)});
  249. for (size_t i = 0; i < 6; ++i) {
  250. *host_x = *gen({i + 3, 3});
  251. *host_y = *gen({i + 3, 3});
  252. int br_num = i % 3;
  253. host_pred->ptr<float>()[0] = br_num;
  254. call_x = 0;
  255. call_y = 0;
  256. func->execute();
  257. float kx = kx_all[br_num], ky = ky_all[br_num];
  258. if (grad_cond_out) {
  259. ASSERT_EQ(br_num <= 1, call_x);
  260. ASSERT_EQ(br_num == 2, call_y);
  261. } else {
  262. ASSERT_EQ(1, call_x);
  263. ASSERT_EQ(1, call_y);
  264. if (br_num < 2) {
  265. MGB_ASSERT_TENSOR_EQ(make_expect(kx, ky, 1), host_gy);
  266. } else {
  267. MGB_ASSERT_TENSOR_EQ(make_expect(kx, ky, 0), host_gx);
  268. }
  269. }
  270. if (br_num < 2) {
  271. MGB_ASSERT_TENSOR_EQ(make_expect(kx, ky, 0), host_gx);
  272. } else {
  273. MGB_ASSERT_TENSOR_EQ(make_expect(kx, ky, 1), host_gy);
  274. }
  275. }
  276. }
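  //! nested conditional execution: pred is bisected at 0 by a PIECEWISE
  //! predicate, and each half is bisected again (at -1 resp. +1) using the
  //! marked predicate vars, giving four leaf branches that are merged back
  //! level by level via EXACT_ONE_SAME_SHAPE; counters verify that exactly one
  //! leaf runs per case, and check_grad also validates the gradient wrt x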
  277. void test_nested(bool check_grad) {
  278. using TwoVar = std::pair<SymbolVar, SymbolVar>;
  279. static auto make_bisect_pred = [](SymbolVar pred, float thresh) -> TwoVar {
  280. SymbolVar lt, ge;
  281. unpack_vector(
  282. opr::CondExecPred::make(
  283. pred, {pred.make_scalar_dt(thresh)},
  284. opr::CondExecPred::Mode::PIECEWISE),
  285. lt, ge);
  286. return {lt, ge};
  287. };
  288. static auto mark_two = [](SymbolVar x, TwoVar ppvs) -> TwoVar {
  289. SymbolVar a, b;
  290. unpack_vector(opr::CondExecMark::make(ppvs.first, {x}), a);
  291. unpack_vector(opr::CondExecMark::make(ppvs.second, {x}), b);
  292. return {a, b};
  293. };
  294. static auto make_bisect = [](SymbolVar x, SymbolVar pred, float thresh,
  295. int* call_lt, int* call_ge,
  296. TwoVar* pred_marked = nullptr) -> TwoVar {
  297. TwoVar pred_br;
  298. SymbolVar x_lt, x_ge;
  299. pred_br = make_bisect_pred(pred, thresh);
  300. std::tie(x_lt, x_ge) = mark_two(x, pred_br);
  301. if (pred_marked) {
  302. *pred_marked = mark_two(pred, pred_br);
  303. }
  304. return {make_call_rec(x_lt, call_lt), make_call_rec(x_ge, call_ge)};
  305. };
  306. auto graph = ComputingGraph::make();
  307. HostTensorGenerator<> gen;
  308. auto host_x = gen({2, 3}), host_pred = gen({1});
  309. int call_lt0, call_ge0;
  310. SymbolVar x = opr::Host2DeviceCopy::make(*graph, host_x).rename("x"),
  311. pred = opr::Host2DeviceCopy::make(*graph, host_pred).rename("pred"),
  312. x_lt_0, x_ge_0;
  313. TwoVar pred_th0;
  314. std::tie(x_lt_0, x_ge_0) = make_bisect(x, pred, 0, &call_lt0, &call_ge0, &pred_th0);
  315. x_lt_0 = x_lt_0.rename("lt0") / 2;
  316. x_ge_0 = x_ge_0.rename("ge0") * 2;
  317. int call_n0, call_n1, call_p0, call_p1;
  318. SymbolVar xn0, xn1, xp0, xp1;
  319. std::tie(xn0, xn1) = make_bisect(
  320. x_lt_0, pred_th0.first.rename("pred-neg"), -1, &call_n0, &call_n1);
  321. std::tie(xp0, xp1) = make_bisect(
  322. x_ge_0, pred_th0.second.rename("pred-pos"), 1, &call_p0, &call_p1);
  323. int call_xn, call_xp;
  324. auto xn_merge = make_call_rec(
  325. merge_one_out(
  326. {xn0.rename("xn0") - 3, xn1.rename("xn1") + 3},
  327. MergeMode::EXACT_ONE_SAME_SHAPE),
  328. &call_xn),
  329. xp_merge = make_call_rec(
  330. merge_one_out(
  331. {xp0.rename("xp0") - 4, xp1.rename("xp1") + 4},
  332. MergeMode::EXACT_ONE_SAME_SHAPE),
  333. &call_xp),
  334. out = merge_one_out({xn_merge, xp_merge}, MergeMode::EXACT_ONE_SAME_SHAPE);
  335. // value infer would fail because EXACT_ONE cannot be satisfied (our
  336. // inference system has no conditional execution)
  337. // so we only check shape inference here
  338. ASSERT_EQ(host_x->shape(), out.shape());
  339. HostTensorND host_out, host_gx;
  340. ComputingGraph::OutputSpec out_spec{make_callback_copy(out, host_out)};
  341. if (check_grad) {
  342. auto loss = opr::reduce_sum_sqr(out, out.make_scalar(1)),
  343. gx = cg::grad(loss, x);
  344. out_spec.emplace_back(make_callback_copy(gx, host_gx));
  345. }
  346. auto func = graph->compile(out_spec);
  347. func->to_json()->writeto_fpath(
  348. output_file(ssprintf("TestCondExec.nested-grad%d.json", check_grad)));
  349. std::array<float, 4> all_biases{-3.f, 3.f, -4.f, 4.f};
  350. for (size_t casenum = 0; casenum < 4; ++casenum) {
  351. host_pred->ptr<float>()[0] = -1.5 + casenum;
  352. call_lt0 = call_ge0 = call_n0 = call_n1 = call_p0 = call_p1 = call_xn =
  353. call_xp = 0;
  354. *host_x = *gen({casenum + 6, 4});
  355. float k = casenum < 2 ? 0.5f : 2.f, b = all_biases[casenum];
  356. HostTensorND expect, expect_gx;
  357. // init expect
  358. {
  359. auto ptr = expect.copy_from(*host_x).ptr<float>();
  360. for (size_t i = 0, it = expect.shape().total_nr_elems(); i < it; ++i) {
  361. ptr[i] = ptr[i] * k + b;
  362. }
  363. }
  364. // init expect_gx
  365. if (check_grad) {
  366. auto ptr = expect_gx.copy_from(*host_x).ptr<float>();
  367. for (size_t i = 0, it = expect.shape().total_nr_elems(); i < it; ++i) {
  368. auto x = ptr[i];
  369. ptr[i] = (k * x + b) * 2 * k;
  370. }
  371. }
  372. func->execute();
  373. MGB_ASSERT_TENSOR_EQ(expect, host_out);
  374. if (check_grad) {
  375. MGB_ASSERT_TENSOR_EQ(expect_gx, host_gx);
  376. }
  377. ASSERT_EQ(casenum < 2, call_lt0);
  378. ASSERT_EQ(casenum >= 2, call_ge0);
  379. ASSERT_EQ(1, call_n0 + call_n1 + call_p0 + call_p1);
  380. ASSERT_EQ(
  381. (call_n0 << 0) | (call_n1 << 1) | (call_p0 << 2) | (call_p1 << 3),
  382. 1 << casenum);
  383. ASSERT_EQ(call_lt0, call_xn);
  384. ASSERT_EQ(call_ge0, call_xp);
  385. }
  386. }
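  //! assert that the opr producing var waits exactly for the vars in to_wait to
  //! become device-ready on its comp node (or has no input waiting spec at all
  //! when to_wait is empty)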
  387. void check_waiting_spec(SymbolVar var, const VarNodeArrayView& to_wait) {
  388. auto&& spec = var.node()->owner_opr()->input_waiting_spec();
  389. if (to_wait.empty()) {
  390. ASSERT_TRUE(spec.empty());
  391. return;
  392. }
  393. ASSERT_EQ(1u, spec.size());
  394. ASSERT_EQ(var.node()->comp_node(), spec[0].comp_node);
  395. ThinHashSet<VarNode*> to_wait_set;
  396. for (auto i : to_wait) {
  397. to_wait_set.insert(i);
  398. }
  399. for (auto i : spec[0].dev_ready) {
  400. ASSERT_EQ(1u, to_wait_set.count(i)) << SymbolVar{i};
  401. }
  402. ASSERT_EQ(to_wait_set.size(), spec[0].dev_ready.size());
  403. }
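  //! a DeviceMemoryAllocator that counts live dynamic allocations, so tests can
  //! assert that dynamically allocated storage of conditionally executed vars is
  //! released and nothing leaks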
  404. class DynamicMemLeakChecker final : public cg::DeviceMemoryAllocator {
  405. std::atomic_size_t m_nr_alive{0};
  406. public:
  407. void alloc_dynamic(VarNode* var, DeviceTensorStorage& dest, size_t size) override {
  408. ASSERT_LT(dest.size(), size);
  409. ++m_nr_alive;
  410. auto ptr = dest.comp_node().alloc_device(size);
  411. auto del = [this, cn = dest.comp_node()](void* ptr) {
  412. cn.free_device(ptr);
  413. auto nr = m_nr_alive.fetch_sub(1);
  414. ASSERT_GT(nr, 0u);
  415. };
  416. dest.reset(dest.comp_node(), size, {static_cast<dt_byte*>(ptr), del});
  417. }
  418. size_t nr_alive() const { return m_nr_alive; }
  419. ~DynamicMemLeakChecker() { EXPECT_EQ(0u, nr_alive()); }
  420. };
  421. } // anonymous namespace
  422. TEST(TestCondExec, MarkSimple) {
  423. int nr_call = 0;
  424. auto graph = ComputingGraph::make();
  425. graph->options().graph_opt_level = 0;
  426. HostTensorGenerator<> gen;
  427. auto host_x = gen({2, 3}), host_pred = gen({1});
  428. auto x = opr::Host2DeviceCopy::make_no_fwd(*graph, host_x),
  429. pred = opr::Host2DeviceCopy::make(*graph, host_pred);
  430. SymbolVar xcond, ppv;
  431. unpack_vector(
  432. opr::CondExecPred::make(
  433. pred, {pred.make_scalar(0.f)},
  434. opr::CondExecPred::Param::Mode::CASE),
  435. ppv);
  436. ppv = opr::CondExecPredLogical::make({ppv}, opr::CondExecPredLogical::Mode::NAND);
  437. unpack_vector(opr::CondExecMark::make(ppv, {x}), xcond);
  438. {
  439. ASSERT_THROW(opr::CondExecMark::make(xcond, {x}), GraphError);
  440. // also test dedup
  441. auto tmp = opr::CondExecMark::mark_if_need(xcond, {x});
  442. ASSERT_EQ(xcond, tmp);
  443. ASSERT_EQ(ppv.node(), tmp.node()->owner_opr()->input().back());
  444. }
  445. auto y = make_call_rec(xcond + 2.3f, &nr_call);
  446. HostTensorND host_y;
  447. ASSERT_EQ(0u, y.node()->owner_opr()->node_prop().dep_map().count(ppv.node()));
  448. auto func = graph->compile({make_callback_copy(y, host_y)});
  449. // the dependency is added by the topo sorter
  450. ASSERT_EQ(
  451. y.node()->owner_opr()->node_prop().dep_map().at(ppv.node()),
  452. cg::OperatorNodeBase::NodeProp::DepType::DEV_COMP_ORDER);
  453. auto make_expect = [&host_x]() {
  454. auto graph = ComputingGraph::make();
  455. HostTensorND ret;
  456. auto x = opr::Host2DeviceCopy::make(*graph, host_x);
  457. graph->compile({make_callback_copy(x + 2.3f, ret)})->execute();
  458. return ret;
  459. };
  460. auto pp = host_pred->ptr<float>();
  461. pp[0] = 0;
  462. func->execute();
  463. ASSERT_EQ(0, nr_call);
  464. ASSERT_TRUE(host_y.empty());
  465. pp[0] = 1;
  466. func->execute();
  467. ASSERT_EQ(1, nr_call);
  468. MGB_ASSERT_TENSOR_EQ(make_expect(), host_y);
  469. host_y = {};
  470. *host_x = *gen({5, 8});
  471. pp[0] = 0;
  472. func->execute();
  473. ASSERT_EQ(1, nr_call);
  474. ASSERT_TRUE(host_y.empty());
  475. pp[0] = 1;
  476. func->execute();
  477. ASSERT_EQ(2, nr_call);
  478. MGB_ASSERT_TENSOR_EQ(make_expect(), host_y);
  479. ASSERT_EQ(prev_dev_ptr(x), prev_dev_ptr(xcond));
  480. }
  481. TEST(TestCondExec, MarkConst) {
  482. auto graph = ComputingGraph::make();
  483. HostTensorGenerator<> gen;
  484. auto host_pred = gen({1});
  485. host_pred->ptr<float>()[0] = 0;
  486. auto pred = opr::Host2DeviceCopy::make(*graph, host_pred),
  487. y0 = make_one_cond(pred, pred.make_scalar(2.3f)),
  488. y1 = make_one_cond(pred + 1, pred.make_scalar(3.2f)),
  489. z = merge_one_out({y0, y1}, MergeMode::EXACT_ONE);
  490. HostTensorND host_z;
  491. auto func = graph->compile({make_callback_copy(z, host_z)});
  492. func->execute();
  493. ASSERT_EQ(TensorShape{1}, host_z.shape());
  494. ASSERT_EQ(3.2f, host_z.ptr<float>()[0]);
  495. host_pred->ptr<float>()[0] = 1;
  496. func->execute();
  497. ASSERT_EQ(2.3f, host_z.ptr<float>()[0]);
  498. }
  499. TEST(TestCondExec, Merge) {
  500. for (int i = 0; i < 16; ++i) {
  501. int im = i >> 2, idyn = (i >> 1) & 1, final_sum = i & 1;
  502. test_merge_opr(static_cast<MergeMode>(im), idyn, final_sum);
  503. ASSERT_FALSE(Test::HasFailure()) << "failed for mode=" << im << " dyn=" << idyn
  504. << " final_sum=" << final_sum;
  505. }
  506. }
  507. TEST(TestCondExec, SimpleGrad) {
  508. test_simple_grad(false);
  509. }
  510. TEST(TestCondExec, SimpleGradCondOut) {
  511. test_simple_grad(true);
  512. }
  513. TEST(TestCondExec, PredMode) {
  514. using Mode = opr::CondExecPred::Mode;
  515. // each case is a pair containing [pred, [branch_result]]
  516. using CaseDesc = std::vector<std::pair<float, std::vector<bool>>>;
  517. // pred opr is constructed using keys {0, 1, 2}
  518. auto run = [](Mode mode, const CaseDesc& cases) {
  519. auto graph = ComputingGraph::make();
  520. auto make_hv = [](float val) {
  521. auto ret = std::make_shared<HostTensorND>(
  522. CompNode::load("xpux"), TensorShape{1});
  523. ret->ptr<float>()[0] = val;
  524. return ret;
  525. };
  526. auto host_pred = make_hv(0), host_x = make_hv(0);
  527. auto x = opr::Host2DeviceCopy::make(*graph, host_x),
  528. pred = opr::Host2DeviceCopy::make(*graph, host_pred);
  529. auto branches = opr::CondExecPred::make(
  530. pred,
  531. {pred.make_scalar(0.f), pred.make_scalar(1.f), pred.make_scalar(2.f)},
  532. mode);
  533. size_t nr_branch = cases[0].second.size();
  534. ASSERT_EQ(nr_branch, branches.size());
  535. SymbolVarArray branch_vars, branch_vars_dyn;
  536. auto x_dyn = opr::MarkDynamicVar::make(x);
  537. for (size_t i = 0; i < nr_branch; ++i) {
  538. SymbolVar ret;
  539. int delta = 1 << i;
  540. unpack_vector(opr::CondExecMark::make(branches.at(i), {x}), ret);
  541. branch_vars.emplace_back(ret + delta);
  542. unpack_vector(opr::CondExecMark::make(branches.at(i), {x_dyn}), ret);
  543. branch_vars_dyn.emplace_back(ret + delta);
  544. }
  545. auto y = merge_one_out(branch_vars, MergeMode::SUM),
  546. y_dyn = merge_one_out(branch_vars_dyn, MergeMode::SUM, 0, {x.symshape()});
  547. HostTensorND host_y;
  548. auto func = graph->compile({make_callback_copy(y_dyn, host_y)});
  549. auto updater = cg::static_infer::StaticInferUpdater::make();
  550. updater->add_dest({y.node(), cg::static_infer::DepType::VALUE});
  551. auto&& mgr = graph->static_infer_manager();
  552. for (auto&& i : cases) {
  553. host_pred->ptr<float>()[0] = i.first;
  554. updater->update();
  555. HostTensorND infer_val;
  556. infer_val.copy_from(mgr.infer_value(y.node())).sync();
  557. func->execute();
  558. ASSERT_EQ(TensorShape{1}, infer_val.shape());
  559. ASSERT_EQ(TensorShape{1}, host_y.shape());
  560. uint32_t vinfer = infer_val.ptr<float>()[0], vy = host_y.ptr<float>()[0];
  561. ASSERT_EQ(vinfer, vy)
  562. << "input=" << i.first << " vinfer=" << std::bitset<8>{vinfer}
  563. << " vy=" << std::bitset<8>{vy};
  564. auto v = vy;
  565. for (size_t br = 0; br < nr_branch; ++br) {
  566. ASSERT_EQ(i.second[br], v & 1)
  567. << "input=" << i.first << " branch=" << br
  568. << " val=" << std::bitset<8>{vy};
  569. v >>= 1;
  570. }
  571. }
  572. };
  573. run(Mode::CASE, {
  574. {0.f, {1, 0, 0}},
  575. {2.f, {0, 0, 1}},
  576. {2.1f, {0, 0, 0}},
  577. });
  578. ASSERT_FALSE(Test::HasFailure()) << "CASE mode failed";
  579. run(Mode::CASE_FALLBACK,
  580. {{0.f, {1, 0, 0, 0}}, {2.f, {0, 0, 1, 0}}, {2.1f, {0, 0, 0, 1}}});
  581. ASSERT_FALSE(Test::HasFailure()) << "CASE_FALLBACK mode failed";
  582. run(Mode::PIECEWISE, {{-1.f, {1, 0, 0, 0}},
  583. {-0.1f, {1, 0, 0, 0}},
  584. {0.f, {0, 1, 0, 0}},
  585. {0.1f, {0, 1, 0, 0}},
  586. {0.99f, {0, 1, 0, 0}},
  587. {1.f, {0, 0, 1, 0}},
  588. {1.01f, {0, 0, 1, 0}},
  589. {1.5f, {0, 0, 1, 0}},
  590. {2.f, {0, 0, 0, 1}},
  591. {2e3f, {0, 0, 0, 1}}});
  592. ASSERT_FALSE(Test::HasFailure()) << "PIECEWISE mode failed";
  593. static_assert(opr::CondExecPred::Param::MODE_NR_MEMBER == 3, "not all mode tested");
  594. }
  595. TEST(TestCondExec, PredLogicalMode) {
  596. using Mode = opr::CondExecPredLogical::Mode;
  597. using Checker = thin_function<bool(int nr_true)>;
  598. auto run = [](Mode mode, const size_t nr_input, Checker checker) {
  599. const size_t nr_case = 1 << nr_input;
  600. auto host_pred = std::make_shared<HostTensorND>(
  601. CompNode::load("xpux"), TensorShape{nr_case});
  602. auto host_x =
  603. std::make_shared<HostTensorND>(CompNode::load("xpux"), TensorShape{1});
  604. memset(host_pred->ptr<float>(), 0, sizeof(float) * nr_case);
  605. host_x->ptr<float>()[0] = 0;
  606. auto graph = ComputingGraph::make();
  607. auto x = opr::Host2DeviceCopy::make(*graph, host_x),
  608. pred = opr::Host2DeviceCopy::make(*graph, host_pred),
  609. pred_dyn = opr::MarkDynamicVar::make(pred);
  610. SymbolVarArray inputs, inputs_dyn;
  611. for (size_t i = 0; i < nr_input; ++i) {
  612. SymbolVar p, p_dyn, key = pred.make_scalar_dt(1);
  613. opr::Subtensor::IndexDesc idx{opr::indexing::AxisIndexer::make_index(
  614. 0, pred.make_scalar(static_cast<int>(i)))};
  615. auto sub = [&idx](SymbolVar x) { return opr::Subtensor::make(x, idx); };
  616. unpack_vector(opr::CondExecPred::make(sub(pred), {key}), p);
  617. unpack_vector(opr::CondExecPred::make(sub(pred_dyn), {key}), p_dyn);
  618. inputs.push_back(p);
  619. inputs_dyn.push_back(p_dyn);
  620. }
  621. SymbolVar logic_out = opr::CondExecPredLogical::make(inputs, mode),
  622. logic_out_dyn = opr::CondExecPredLogical::make(inputs_dyn, mode),
  623. x_mark, x_mark_dyn;
  624. unpack_vector(opr::CondExecMark::make(logic_out, {x}), x_mark);
  625. unpack_vector(opr::CondExecMark::make(logic_out_dyn, {x}), x_mark_dyn);
  626. auto y = merge_one_out({x_mark + 1}, MergeMode::SUM),
  627. y_dyn = merge_one_out({x_mark_dyn + 1}, MergeMode::SUM);
  628. HostTensorND host_y;
  629. auto func = graph->compile({make_callback_copy(y_dyn, host_y)});
  630. auto updater = cg::static_infer::StaticInferUpdater::make();
  631. updater->add_dest({y.node(), cg::static_infer::DepType::VALUE});
  632. auto&& mgr = graph->static_infer_manager();
  633. for (size_t i = 0; i < nr_case; ++i) {
  634. size_t nr_one = 0;
  635. for (size_t j = 0; j < nr_input; ++j) {
  636. auto cur = (i >> j) & 1;
  637. host_pred->ptr<float>()[j] = cur;
  638. nr_one += cur;
  639. }
  640. updater->update();
  641. int vinfer = mgr.infer_value(y.node()).ptr<float>()[0];
  642. func->execute();
  643. int vy = host_y.ptr<float>()[0];
  644. ASSERT_EQ(checker(nr_one), vy) << "case=" << i;
  645. ASSERT_EQ(vy, vinfer) << "case=" << i;
  646. }
  647. };
  648. for (int inp = 1; inp < 5; ++inp) {
  649. #define DO_RUN(mode, fn) \
  650. do { \
  651. run(Mode::mode, inp, fn); \
  652. ASSERT_FALSE(Test::HasFailure()) << "failed on " << #mode << " inp=" << inp; \
  653. } while (0)
  654. DO_RUN(OR, [](int n) { return n != 0; });
  655. DO_RUN(AND, [inp](int n) { return n == inp; });
  656. DO_RUN(XOR, [](int n) { return n & 1; });
  657. DO_RUN(NOR, [](int n) { return n == 0; });
  658. DO_RUN(NAND, [inp](int n) { return n != inp; });
  659. DO_RUN(XNOR, [](int n) { return !(n & 1); });
  660. #undef DO_RUN
  661. }
  662. static_assert(
  663. opr::CondExecPredLogical::Param::MODE_NR_MEMBER == 6,
  664. "not all mode tested");
  665. }
  666. TEST(TestCondExec, Nested) {
  667. test_nested(false);
  668. }
  669. TEST(TestCondExec, NestedGrad) {
  670. test_nested(true);
  671. }
  672. TEST(TestCondExec, MergeSumDyn) {
  673. auto graph = ComputingGraph::make();
  674. HostTensorGenerator<> gen;
  675. auto host_x = gen({2, 3}), host_pred = gen({1});
  676. auto x = opr::Host2DeviceCopy::make(*graph, host_x),
  677. pred = opr::Host2DeviceCopy::make(*graph, host_pred),
  678. cx0 = opr::MarkDynamicVar::make(make_one_cond(pred, x)) + 1.f,
  679. cx1 = opr::MarkDynamicVar::make(make_one_cond(pred - 1.f, x) + 2.f);
  680. ASSERT_THROW(merge_one_out({cx0, cx1}, MergeMode::SUM, 0, {}), GraphError);
  681. auto y = merge_one_out({cx0, cx1}, MergeMode::SUM, 0, {x.symshape()});
  682. HostTensorND host_y;
  683. auto func = graph->compile({make_callback_copy(y, host_y)});
  684. auto run = [&](float k, float bias) {
  685. host_pred->ptr<float>()[0] = bias;
  686. HostTensorND expect;
  687. expect.copy_from(*host_x);
  688. auto px = expect.ptr<float>();
  689. for (size_t i = 0, it = expect.shape().total_nr_elems(); i < it; ++i) {
  690. px[i] = (px[i] + bias) * k;
  691. }
  692. func->execute();
  693. MGB_ASSERT_TENSOR_EQ(expect, host_y);
  694. };
  695. run(1.f, 1.f);
  696. run(0.f, -1.f);
  697. run(1.f, 2.f);
  698. }
  699. TEST(TestCondExec, AddUpdateFwd) {
  700. auto graph = ComputingGraph::make();
  701. graph->options().graph_opt_level = 0;
  702. HostTensorGenerator<> gen;
  703. auto host_x = gen({2, 3}), host_pred = gen({1});
  704. auto dev_x = std::make_shared<DeviceTensorND>();
  705. dev_x->copy_from(*host_x);
  706. host_pred->ptr<float>()[0] = 1;
  707. auto x = opr::SharedDeviceTensor::make(*graph, dev_x),
  708. pred = opr::Host2DeviceCopy::make(*graph, host_pred),
  709. xmark0 = make_one_cond(pred, x, 1, 0, true),
  710. xmark1 = make_one_cond(pred - 1, x, 1, 0, true),
  711. xmerge = merge_one_out({xmark0 + 1, xmark1 + 2}, MergeMode::EXACT_ONE),
  712. loss = opr::reduce_sum_sqr(xmerge, x.make_scalar(1)), gx = cg::grad(loss, x),
  713. xud = opr::AddUpdate::make(x, gx);
  714. auto func = graph->compile({{xud, {}}});
  715. auto run = [&](float bias) {
  716. host_pred->ptr<float>()[0] = bias;
  717. dev_x->copy_from(*host_x);
  718. func->execute();
  719. HostTensorND got, expect;
  720. got.copy_from(*dev_x).sync();
  721. expect.copy_from(*host_x);
  722. auto px = expect.ptr<float>();
  723. for (size_t i = 0, it = expect.shape().total_nr_elems(); i < it; ++i) {
  724. px[i] += 2 * (px[i] + bias);
  725. }
  726. MGB_ASSERT_TENSOR_EQ(expect, got);
  727. if (bias == 1) {
  728. ASSERT_EQ(dev_x->raw_ptr(), prev_dev_ptr(xmark0));
  729. } else {
  730. ASSERT_EQ(dev_x->raw_ptr(), prev_dev_ptr(xmark1));
  731. }
  732. };
  733. run(1);
  734. run(2);
  735. }
  736. TEST(TestCondExec, CondAddUpdate) {
  737. auto graph = ComputingGraph::make();
  738. HostTensorGenerator<> gen;
  739. auto host_x = gen({2, 3}), host_pred = gen({1});
  740. auto dev_x = std::make_shared<DeviceTensorND>();
  741. dev_x->copy_from(*host_x);
  742. host_pred->ptr<float>()[0] = 1;
  743. auto x = opr::SharedDeviceTensor::make(*graph, dev_x),
  744. pred = opr::Host2DeviceCopy::make(*graph, host_pred),
  745. xmark = make_one_cond(pred, x), xud = opr::AddUpdate::make(x, xmark * 1.3f);
  746. auto func = graph->compile({{xud, {}}});
  747. auto run = [&](float pred) {
  748. host_pred->ptr<float>()[0] = pred;
  749. dev_x->copy_from(*host_x);
  750. func->execute();
  751. HostTensorND got, expect;
  752. got.copy_from(*dev_x).sync();
  753. expect.copy_from(*host_x);
  754. if (pred == 1.f) {
  755. auto px = expect.ptr<float>();
  756. for (size_t i = 0, it = expect.shape().total_nr_elems(); i < it; ++i) {
  757. px[i] *= 2.3f;
  758. }
  759. }
  760. MGB_ASSERT_TENSOR_EQ(expect, got);
  761. };
  762. run(3);
  763. run(1);
  764. run(2);
  765. }
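  // the following multi-comp-node tests check that CondExecMark / CondExecMerge
  // oprs correctly wait for predicates and inputs computed on other streams
  // before (conditionally) executing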
  766. TEST(TestCondExec, MultiCnMarkWaitPred) {
  767. auto graph = ComputingGraph::make();
  768. graph->options().var_sanity_check_first_run = false;
  769. HostTensorGenerator<> gen;
  770. auto host_x = gen({2, 3}), host_pred = gen({1});
  771. auto cn0 = host_x->comp_node(), cn1 = cn0.change_stream(1);
  772. auto x = opr::Host2DeviceCopy::make(*graph, host_x),
  773. pred = opr::Host2DeviceCopy::make(*graph, host_pred),
  774. pred_delayed = opr::Sleep::make(pred, 0.05, {}, cn1);
  775. SymbolVar ppv, y;
  776. unpack_vector(
  777. opr::CondExecPred::make(pred_delayed, {pred_delayed.make_scalar_dt(1.f)}),
  778. ppv);
  779. unpack_vector(opr::CondExecMark::make(ppv, {x}, {}, cn0), y);
  780. HostTensorND host_y;
  781. auto func = graph->compile({make_callback_copy(y, host_y)});
  782. host_pred->ptr<float>()[0] = 0;
  783. func->execute();
  784. ASSERT_TRUE(host_y.empty());
  785. host_pred->ptr<float>()[0] = 1;
  786. func->execute();
  787. MGB_ASSERT_TENSOR_EQ(*host_x, host_y);
  788. }
  789. TEST(TestCondExec, MultiCnMergeWaitPred) {
  790. auto graph = ComputingGraph::make();
  791. graph->options().var_sanity_check_first_run = false;
  792. graph->options().graph_opt_level = 0;
  793. HostTensorGenerator<> gen;
  794. auto host_x = gen({2, 3}), host_pred = gen({1});
  795. auto cn0 = host_x->comp_node(), cn1 = cn0.change_stream(1),
  796. cn2 = cn0.change_stream(2);
  797. SymbolVar x = opr::Host2DeviceCopy::make(*graph, host_x),
  798. pred = opr::Host2DeviceCopy::make(*graph, host_pred), ppv0, ppv1;
  799. auto make_marked = [cn1, pred](SymbolVar x, float pv, SymbolVar& ppv) {
  800. SymbolVar y;
  801. unpack_vector(
  802. opr::CondExecPred::make(
  803. opr::Sleep::make(pred, 0.05), {pred.make_scalar_dt(pv)}),
  804. ppv);
  805. unpack_vector(opr::CondExecMark::make(ppv, {x}, {}, cn1), y);
  806. return y;
  807. };
  808. SymbolVar y0 = make_marked(x, 1.f, ppv0) + 1.f, // cn1
  809. y1 = make_marked(x, 2.f, ppv1) + 2.f, // cn1
  810. z = opr::CondExecMerge::make({y0, y1}, {1, MergeMode::SUM_COND_OUT})[0];
  811. HostTensorND host_z;
  812. z.node()->comp_node(cn2); // change z to cn2
  813. z.node()->owner_opr()->on_output_comp_node_stream_changed();
  814. auto func = graph->compile({make_callback_copy(z, host_z)});
  815. SymbolVar z_ppv = z.node()->owner_opr()->input().back();
  816. check_waiting_spec(z_ppv, {ppv0});
  817. check_waiting_spec(z, {y0});
  818. host_pred->ptr<float>()[0] = 0;
  819. func->execute();
  820. ASSERT_TRUE(host_z.empty());
  821. auto run = [&](float bias) {
  822. host_pred->ptr<float>()[0] = bias;
  823. HostTensorND expect;
  824. expect.copy_from(*host_x);
  825. auto px = expect.ptr<float>();
  826. for (size_t i = 0, it = expect.shape().total_nr_elems(); i < it; ++i) {
  827. px[i] += bias;
  828. }
  829. func->execute();
  830. MGB_ASSERT_TENSOR_EQ(expect, host_z);
  831. };
  832. run(1);
  833. run(2);
  834. }
  835. TEST(TestCondExec, InputWaitingForMerge) {
  836. using Elemwise = opr::Elemwise;
  837. auto cn0 = CompNode::load("xpux"), cn1 = cn0.change_stream(1);
  838. HostTensorGenerator<> gen;
  839. auto host_pred = gen({1}, cn0), host_x = gen({2, 3}, cn0);
  840. host_pred->ptr<float>()[0] = 0;
  841. auto graph = ComputingGraph::make();
  842. graph->options().graph_opt_level = 0;
  843. graph->options().seq_opt.enable_seq_comp_node_opt = false;
  844. auto pred = opr::Host2DeviceCopy::make_no_value_infer(*graph, host_pred);
  845. auto make_delayed_pred = [pred](CompNode cn) {
  846. return opr::MarkDynamicVar::make(opr::Sleep::make(pred, 0.02, {}, cn));
  847. };
  848. auto make_marked = [](SymbolVar x, SymbolVar pred, float key) -> SymbolVar {
  849. SymbolVar ppv;
  850. unpack_vector(opr::CondExecPred::make(pred, {pred.make_scalar(key)}), ppv);
  851. SymbolVar xcond;
  852. unpack_vector(
  853. opr::CondExecMark::make(ppv, {x}, {}, {pred.node()->comp_node()}),
  854. xcond);
  855. return xcond;
  856. };
  857. auto make_merged = [cn0](const VarNodeArrayView& arr) -> SymbolVar {
  858. SymbolVar ret;
  859. for (size_t i = 0; i < arr.size(); ++i) {
  860. mgb_assert((i == 0) == (arr[i]->comp_node() == cn0));
  861. }
  862. unpack_vector(
  863. opr::CondExecMerge::make(arr, {1, MergeMode::SUM_COND_OUT}, {}, cn0),
  864. ret);
  865. return ret;
  866. };
  867. auto x = opr::Host2DeviceCopy::make_no_fwd(*graph, host_x),
  868. x1 = opr::Copy::make(x, cn1), pred0 = make_delayed_pred(cn0),
  869. pred1 = make_delayed_pred(cn1), y0 = make_marked(x, pred0, 1) + 1, // on cn0
  870. y10 = make_marked(x1, pred1, 2) + 2, // on cn1
  871. y11 = make_marked(x, pred1, 3) + 3, // on cn1
  872. ymgr = make_merged({y0, y10, y11}), // on cn0
  873. z = Elemwise::make(
  874. {x1, opr::Sleep::make(ymgr, 0.03)}, Elemwise::Mode::ADD,
  875. cn0); // (cn1, cn0) -> cn0
  876. HostTensorND host_z;
  877. auto func = graph->compile({make_callback_copy(z, host_z)});
  878. check_waiting_spec(ymgr, {y10});
  879. // it is provable that ymgr becomes ready later than x1, so z needs no waiting spec
  880. check_waiting_spec(z, {});
  881. auto run = [&](float pv) {
  882. *host_x = *gen({2 + static_cast<size_t>(pv), 5});
  883. host_pred->ptr<float>()[0] = pv;
  884. host_z = {};
  885. func->execute();
  886. if (pv < 1) {
  887. ASSERT_TRUE(host_z.empty());
  888. return;
  889. }
  890. HostTensorND expect;
  891. auto ptr = expect.copy_from(*host_x).ptr<float>();
  892. for (size_t i = 0, it = host_x->shape().total_nr_elems(); i < it; ++i) {
  893. ptr[i] = ptr[i] * 2 + pv;
  894. }
  895. MGB_ASSERT_TENSOR_EQ(expect, host_z);
  896. };
  897. run(2);
  898. run(1);
  899. run(3);
  900. run(2);
  901. run(-1);
  902. }
  903. TEST(TestCondExec, GradMultiReader) {
  904. // multiple readers of the grad wrt var, on multiple comp nodes
  905. auto cns = load_multiple_xpus(2);
  906. auto graph = ComputingGraph::make();
  907. HostTensorGenerator<> gen;
  908. auto host_pred = gen({1}, cns[0]), host_x = gen({2, 3}, cns[0]);
  909. host_pred->ptr<float>()[0] = 0;
  910. auto copy1 = [&cns](SymbolVar x) { return opr::Copy::make(x, cns[1]); };
  911. auto pred = opr::Host2DeviceCopy::make(*graph, host_pred),
  912. x = opr::Host2DeviceCopy::make(*graph, host_x),
  913. y0 = copy1(make_one_cond(pred, x, 1, 0, true)),
  914. y1 = copy1(make_one_cond(pred + 1, x * 2.f, 1, 0, true)),
  915. y2 = make_one_cond(copy1(pred) + 2, copy1(x) * 3.f, 1, 0, true),
  916. z = opr::Copy::make(merge_one_out({y0, y1, y2}, MergeMode::SUM), cns[0]),
  917. loss = opr::reduce_sum_sqr(z, z.make_scalar(1)), gx = cg::grad(loss, x);
  918. ASSERT_TRUE(cg::is_static_var_value(z.node()));
  919. HostTensorND host_gx;
  920. auto func = graph->compile({make_callback_copy(gx, host_gx)});
  921. auto run = [&](int pv, Maybe<float> coeff) {
  922. host_pred->ptr<float>()[0] = pv;
  923. host_gx = {};
  924. func->execute();
  925. if (!coeff.valid()) {
  926. ASSERT_TRUE(host_gx.empty());
  927. return;
  928. }
  929. HostTensorND expect;
  930. expect.copy_from(*host_x);
  931. auto ptr = expect.ptr<float>();
  932. auto c = coeff.val();
  933. c = c * c * 2;
  934. for (size_t i = 0, it = host_x->shape().total_nr_elems(); i < it; ++i) {
  935. ptr[i] = ptr[i] * c;
  936. }
  937. MGB_ASSERT_TENSOR_EQ(expect, host_gx);
  938. };
  939. run(-1, 3.f);
  940. run(0, 2.f);
  941. run(1, 1.f);
  942. run(2, None);
  943. }
  944. TEST(TestCondExec, SyncForMultiCN) {
  945. auto cns = load_multiple_xpus(2);
  946. auto graph = ComputingGraph::make();
  947. graph->options().var_sanity_check_first_run = false;
  948. HostTensorGenerator<> gen;
  949. auto host_pred = gen({1}, cns[0]), host_x = gen({2, 3}, cns[0]);
  950. host_pred->ptr<float>()[0] = 0;
  951. auto copy1 = [&cns](SymbolVar x) { return opr::Copy::make(x, cns[1]); };
  952. auto pred = opr::Host2DeviceCopy::make_no_value_infer(*graph, host_pred),
  953. x = opr::Host2DeviceCopy::make(*graph, host_x), y0 = make_one_cond(pred, x, 1),
  954. y1 = make_one_cond(copy1(pred) + 1, copy1(x) * 2.f),
  955. y2 = make_one_cond(copy1(pred) + 2, copy1(x) * 3.f),
  956. y12 = opr::Copy::make(
  957. merge_one_out({y1, y2}, MergeMode::SUM_COND_OUT), cns[0]),
  958. z = merge_one_out({y12, y0}, MergeMode::EXACT_ONE),
  959. loss = opr::reduce_sum_sqr(z, z.make_scalar(1)), gx = cg::grad(loss, x);
  960. ASSERT_FALSE(cg::is_static_var_value(z.node()));
  961. HostTensorND host_gx;
  962. auto func = graph->compile({make_callback_copy(gx, host_gx)});
  963. auto run = [&](int pv, Maybe<float> coeff) {
  964. host_pred->ptr<float>()[0] = pv;
  965. host_gx = {};
  966. opr::Sleep::sleep(cns[0], 0.1); // sleep to delay h2d copy
  967. func->execute();
  968. if (!coeff.valid()) {
  969. ASSERT_TRUE(host_gx.empty());
  970. return;
  971. }
  972. HostTensorND expect;
  973. expect.copy_from(*host_x);
  974. auto ptr = expect.ptr<float>();
  975. auto c = coeff.val();
  976. c = c * c * 2;
  977. for (size_t i = 0, it = host_x->shape().total_nr_elems(); i < it; ++i) {
  978. ptr[i] = ptr[i] * c;
  979. }
  980. MGB_ASSERT_TENSOR_EQ(expect, host_gx);
  981. };
  982. run(-1, 3.f);
  983. run(0, 2.f);
  984. run(1, 1.f);
  985. }
  986. TEST(TestCondExec, AsyncCondAccess) {
  987. constexpr float SLEEP_TIME = 0.2;
  988. auto graph = ComputingGraph::make();
  989. graph->options().var_sanity_check_first_run = false;
  990. graph->options().graph_opt_level = 0;
  991. auto allocator = std::make_shared<DynamicMemLeakChecker>();
  992. graph->set_device_memory_allocator(allocator);
  993. HostTensorGenerator<> gen;
  994. auto host_x = gen({2, 3}), host_pred = gen({1});
  995. auto cn1 = host_x->comp_node().change_stream(1);
  996. auto x = opr::Host2DeviceCopy::make_no_fwd(*graph, host_x),
  997. pred = opr::Host2DeviceCopy::make(*graph, host_pred),
  998. xmark = make_one_cond(pred, x),
  999. xmark_delay = opr::Sleep::make(xmark, SLEEP_TIME, {}, cn1),
  1000. xp1 = (x + 1).rename("xp1"),
  1001. y = opr::Elemwise::make(
  1002. {xmark_delay + 2.3f, xp1}, opr::Elemwise::Mode::ADD, cn1);
  1003. host_pred->ptr<float>()[0] = 0;
  1004. set_priority(xp1, 100);
  1005. HostTensorND host_y;
  1006. auto func = graph->compile({make_callback_copy(y, host_y)});
  1007. check_waiting_spec(y, {xp1});
  1008. ASSERT_FALSE(cg::is_static_var_storage(xp1.node()));
  1009. RealTimer timer;
  1010. func->execute().wait();
  1011. ASSERT_TRUE(host_y.empty());
  1012. // the sleep kernel on CUDA is easily affected by GPU frequency changes,
  1013. // so we just print a warning log instead of asserting; for more details see
  1014. // XPU-226
  1015. auto use_time = timer.get_secs();
  1016. if (use_time >= SLEEP_TIME / 2) {
  1017. mgb_log_warn(
  1018. "expect time [%f < %f], got %f", use_time, SLEEP_TIME / 2, use_time);
  1019. }
  1020. ASSERT_EQ(0u, allocator->nr_alive());
  1021. host_pred->ptr<float>()[0] = 1;
  1022. func->execute().wait();
  1023. use_time = timer.get_secs();
  1024. if (use_time <= SLEEP_TIME) {
  1025. mgb_log_warn("expect time [%f > %f], got %f", use_time, SLEEP_TIME, use_time);
  1026. }
  1027. HostTensorND expect;
  1028. graph->compile({make_callback_copy(x * 2 + 3.3f, expect)})->execute();
  1029. MGB_ASSERT_TENSOR_EQ(expect, host_y);
  1030. ASSERT_EQ(0u, allocator->nr_alive());
  1031. }
  1032. TEST(TestCondExec, VolatilePtr) {
  1033. auto graph = ComputingGraph::make();
  1034. HostTensorGenerator<> gen;
  1035. HostTensorGenerator<dtype::Int32> gen_int;
  1036. HostTensorND expect;
  1037. auto host_pred = gen_int({1});
  1038. auto dev_x = std::make_shared<DeviceTensorND>();
  1039. auto assign = [&](int br) {
  1040. host_pred->ptr<int>()[0] = br;
  1041. expect = *gen({2, 3});
  1042. auto hold = *dev_x;
  1043. *dev_x = {};
  1044. // ensure a different ptr
  1045. dev_x->copy_from(expect).sync();
  1046. auto p = expect.ptr<float>();
  1047. for (size_t i = 0; i < 6; ++i) {
  1048. p[i] = p[i] + (br == 0 ? 1.2f : 2.1f);
  1049. }
  1050. };
  1051. assign(0);
  1052. auto x = opr::VolatileSharedDeviceTensor::make(*graph, dev_x),
  1053. pred = opr::Host2DeviceCopy::make(*graph, host_pred),
  1054. xc0 = make_one_cond(pred + 1, x), xc1 = make_one_cond(pred, x),
  1055. y = merge_one_out({xc0 + 1.2f, xc1 + 2.1f}, MergeMode::EXACT_ONE_SAME_SHAPE);
  1056. HostTensorND host_y;
  1057. auto func = graph->compile({make_callback_copy(y, host_y)});
  1058. auto run = [&](int br) {
  1059. assign(br);
  1060. func->execute();
  1061. MGB_ASSERT_TENSOR_EQ(expect, host_y);
  1062. if (br == 0) {
  1063. ASSERT_EQ(dev_x->raw_ptr(), prev_dev_ptr(xc0));
  1064. } else {
  1065. ASSERT_EQ(dev_x->raw_ptr(), prev_dev_ptr(xc1));
  1066. }
  1067. };
  1068. run(0);
  1069. run(1);
  1070. run(1);
  1071. run(0);
  1072. }
  1073. TEST(TestCondExec, MultiShape) {
  1074. auto graph = ComputingGraph::make();
  1075. HostTensorGenerator<> gen;
  1076. auto host_x = gen({2}), host_d2 = gen({2}), host_d3 = gen({3});
  1077. //! return y conditioned on shape of \p x equaling \p shp
  1078. auto enable_if_shape = [](SymbolVar x, size_t shp) {
  1079. auto y = make_one_cond(x.symshape() - static_cast<int>(shp - 1), x);
  1080. // static shape inference is always performed regardless of cond
  1081. // exec mark, so we add a reshape here to hint the true shape of y, to
  1082. // ensure that shape inference of oprs depending on y could succeed
  1083. // TODO: remove this if static infer considers execution mask
  1084. y = y.reshape(TensorShape{shp});
  1085. return y;
  1086. };
  1087. SymbolVar x = opr::Host2DeviceCopy::make(*graph, host_x),
  1088. d2 = opr::Host2DeviceCopy::make(*graph, host_d2),
  1089. d3 = opr::Host2DeviceCopy::make(*graph, host_d3),
  1090. xc0 = enable_if_shape(x, 2) + d2, xc1 = enable_if_shape(x, 3) + d3,
  1091. merged = merge_one_out({xc0, xc1}, MergeMode::EXACT_ONE),
  1092. loss = opr::reduce_sum_sqr(merged, merged.make_scalar(1)),
  1093. gx = cg::grad(loss, x);
  1094. HostTensorND host_gx;
  1095. auto func = graph->compile({make_callback_copy(gx, host_gx)});
  1096. auto check = [&](const std::shared_ptr<HostTensorND>& host_delta) {
  1097. auto pd = host_delta->ptr<float>();
  1098. HostTensorND expect;
  1099. auto pe = expect.copy_from(*host_x).ptr<float>();
  1100. for (size_t i = 0, it = expect.shape().total_nr_elems(); i < it; ++i) {
  1101. pe[i] = 2 * (pe[i] + pd[i]);
  1102. }
  1103. func->execute();
  1104. MGB_ASSERT_TENSOR_EQ(expect, host_gx);
  1105. };
  1106. check(host_d2);
  1107. *host_x = *gen({3});
  1108. check(host_d3);
  1109. *host_x = *gen({3});
  1110. check(host_d3);
  1111. *host_x = *gen({2});
  1112. check(host_d2);
  1113. }
  1114. TEST(TestCondExec, EmptyShape) {
  1115. HostTensorGenerator<> gen;
  1116. auto host_pred = gen({1});
  1117. host_pred->ptr<float>()[0] = 0;
  1118. static auto empty_in_empty_out = [](SymbolVar x) { return x; };
  1119. static auto empty_in_scalar_out = [](SymbolVar x) {
  1120. return opr::Concat::make({x, x.make_scalar(1.f)}, 0);
  1121. };
  1122. static auto scalar_in_empty_out = [](SymbolVar x) {
  1123. return opr::CondTake::make(x, x, {})[0]; // default mode keeps elements equal to 0, so a nonzero scalar yields an empty output
  1124. };
  1125. { // EXACT_ONE
  1126. auto graph = ComputingGraph::make();
  1127. auto pred = opr::Host2DeviceCopy::make(*graph, host_pred),
  1128. empty = opr::ImmutableTensor::make(*graph, *gen({0})),
  1129. scalar = pred.make_scalar(1.f),
  1130. y0 = empty_in_empty_out(make_one_cond(pred + 1, empty)),
  1131. y1 = empty_in_scalar_out(make_one_cond(pred, empty)),
  1132. y2 = scalar_in_empty_out(make_one_cond(pred - 1, scalar)),
  1133. z = merge_one_out({y0, y1, y2}, MergeMode::EXACT_ONE);
  1134. HostTensorND host_z;
  1135. auto func = graph->compile({make_callback_copy(z, host_z)});
  1136. func->execute();
  1137. ASSERT_TRUE(host_z.layout().is_empty());
  1138. host_pred->ptr<float>()[0] = 1;
  1139. func->execute();
  1140. ASSERT_EQ(1.f, host_z.ptr<float>()[0]);
  1141. host_pred->ptr<float>()[0] = 2;
  1142. func->execute();
  1143. ASSERT_TRUE(host_z.layout().is_empty());
  1144. }
  1145. { // SUM
  1146. auto graph = ComputingGraph::make();
  1147. host_pred->ptr<float>()[0] = 1;
  1148. auto pred = opr::Host2DeviceCopy::make(*graph, host_pred),
  1149. empty = opr::ImmutableTensor::make(*graph, *gen({0})),
  1150. scalar = pred.make_scalar(1.f),
  1151. y0 = empty_in_empty_out(make_one_cond(pred, empty)),
  1152. y1 = scalar_in_empty_out(make_one_cond(pred, scalar)),
  1153. z = merge_one_out({y0, y1}, MergeMode::SUM);
  1154. HostTensorND host_z;
  1155. auto func = graph->compile({make_callback_copy(z, host_z)});
  1156. func->execute();
  1157. ASSERT_TRUE(host_z.layout().is_empty());
  1158. }
  1159. { // TAKE GRAD
  1160. auto graph = ComputingGraph::make();
  1161. host_pred->ptr<float>()[0] = 0;
  1162. auto pred = opr::Host2DeviceCopy::make(*graph, host_pred),
  1163. x = pred.make_scalar(1.2f),
  1164. y0 = opr::CondTake::make(make_one_cond(pred + 1, x), pred, {})[0],
  1165. y1 = make_one_cond(pred, x.make_scalar(3.4f)),
  1166. z = merge_one_out({y0, y1}, MergeMode::EXACT_ONE), g = cg::grad(z, x);
  1167. HostTensorND host_z, host_g;
  1168. auto func = graph->compile(
  1169. {make_callback_copy(z, host_z), make_callback_copy(g, host_g)});
  1170. func->execute();
  1171. ASSERT_EQ(1.2f, host_z.ptr<float>()[0]);
  1172. ASSERT_EQ(1.f, host_g.ptr<float>()[0]);
  1173. host_pred->ptr<float>()[0] = 1;
  1174. func->execute();
  1175. ASSERT_EQ(3.4f, host_z.ptr<float>()[0]);
  1176. ASSERT_EQ(0.f, host_g.ptr<float>()[0]);
  1177. }
  1178. }
  1179. #endif // MGB_ENABLE_COND_EXEC
  1180. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}