/**
 * \file src/gopt/test/misc.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "./helper.h"

#include "megbrain/gopt/basic_arith.h"
#include "megbrain/gopt/misc.h"
#include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/cond.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"

#include <set>
#include <thread>

using namespace mgb;

TEST_PASS(RemoveNonComputingOprPass, Simple) {
    auto x = mkvar("x");
    check(x, opr::MarkNoBroadcastElemwise::make(x));
}

TEST_PASS(RemoveNonComputingOprPass, Split) {
    auto a = mkvar("a"), b = mkvar("b"),
         loss = opr::reduce_sum(opr::Concat::make({a, b}, 0), a.make_scalar(1)),
         ga = cg::grad(loss, a),
         ga_exp = a.make_scalar(1.f).broadcast(ga.symshape());
    check(ga_exp, ga);
}

TEST_PASS(RemoveNonComputingOprPass, SplitImmOpt) {
    auto cns = load_multiple_xpus(2);
    HostTensorGenerator<> gen;
    auto cn0 = cns[0], cn1 = cns[1];
    auto host_x0 = gen({2, 3}, cn0),
         host_x1 = gen({2, 3}, cn1);
    auto graph = ComputingGraph::make();
    auto make1 = [&graph](SymbolVar var) {
        auto val = std::make_shared<HostTensorND>(
                var.node()->comp_node(), TensorShape{1}, dtype::Int32());
        val->ptr<int>()[0] = 1;
        return opr::Host2DeviceCopy::make(*graph, val);
    };
    auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0),
         x1 = opr::Host2DeviceCopy::make(*graph, host_x1);
    auto splt = opr::Split::make(
            x0.make_scalar(0.f).broadcast({2}),
            opr::Split::Options::make_partition(0, {make1(x0), make1(x1)}),
            OperatorNodeConfig{}.comp_node_arr({cn0, cn1}));
    auto y0 = x0 + splt[0], y1 = x1 + splt[1];
    HostTensorND host_y0, host_y1;
    auto func = graph->compile({make_callback_copy(y0, host_y0),
                                make_callback_copy(y1, host_y1)});
    func->execute();
    MGB_ASSERT_TENSOR_EQ(*host_x0, host_y0);
    MGB_ASSERT_TENSOR_EQ(*host_x1, host_y1);
}

TEST_PASS(DelayBroadcastPass, Basic) {
    auto x = mkvar("x", {1, 1, 3});
    auto y = mkvar("y", {1, 2, 3});
    auto z = mkvar("z", {2, 2, 3});
    auto relu_maker = [](SymbolVar x) -> SymbolVar {
        using Param = opr::Elemwise::Param;
        Param param;
        param.mode = Param::Mode::RELU;
        return opr::Elemwise::make({x}, param);
    };
    auto typecvt_maker = [](SymbolVar x, bool float16 = true) -> SymbolVar {
        if (float16)
            return opr::TypeCvt::make(x, dtype::Float16());
        else
            return opr::TypeCvt::make(x, dtype::Float32());
    };
    auto broadcast_maker = [](SymbolVar x, SymbolVar from) -> SymbolVar {
        return opr::Broadcast::make(x, opr::GetVarShape::make(from));
    };
    auto get_var_shp_maker = [](SymbolVar x) -> SymbolVar {
        return opr::GetVarShape::make(x);
    };

    // check the case where just two oprs need swapping
    check(broadcast_maker(relu_maker(x), y), relu_maker(broadcast_maker(x, y)));
    // check the case where multiple oprs need shifting
    check(broadcast_maker(typecvt_maker(relu_maker(x)), y),
          typecvt_maker(relu_maker(broadcast_maker(x, y))));
    // check opr::GetVarShape
    check(get_var_shp_maker(broadcast_maker(typecvt_maker(relu_maker(x)), y)),
          get_var_shp_maker(typecvt_maker(relu_maker(broadcast_maker(x, y)))));
    check(get_var_shp_maker(broadcast_maker(typecvt_maker(relu_maker(x)), y)),
          get_var_shp_maker(typecvt_maker(broadcast_maker(relu_maker(x), y))));
    check(typecvt_maker(get_var_shp_maker(broadcast_maker(relu_maker(x), y))),
          typecvt_maker(get_var_shp_maker(relu_maker(broadcast_maker(x, y)))));
    // remains unchanged after applying the pass
    check<false>(broadcast_maker(broadcast_maker(x, y), z),
                 broadcast_maker(broadcast_maker(x, y), z));
    // mixed case
    check(broadcast_maker(broadcast_maker(relu_maker(typecvt_maker(x)), y), z),
          relu_maker(broadcast_maker(typecvt_maker(broadcast_maker(x, y)), z)));
    // endpoint situation 1; see the comments in `DelayBroadcastPass::apply`
    check(y + broadcast_maker(relu_maker(x), z),
          y + relu_maker(broadcast_maker(x, z)));
    // the second replaced chain depends on another replaced chain
    check(broadcast_maker(typecvt_maker(broadcast_maker(typecvt_maker(x), y) +
                                                typecvt_maker(y),
                                        false),
                          z),
          typecvt_maker(broadcast_maker(typecvt_maker(broadcast_maker(x, y)) +
                                                typecvt_maker(y),
                                        z),
                        false));
    // the broadcast opr depends on another chain
    auto shape3 = mkvar("shape3", {2}).symshape() + 1;
    auto shape333 = opr::abs(opr::Broadcast::make(shape3, shape3));
    auto shape333_after = opr::Broadcast::make(opr::abs(shape3), shape3);
    check(broadcast_maker(relu_maker(x), shape333_after),
          relu_maker(broadcast_maker(x, shape333)));
}

TEST_PASS(DelayBroadcastPass, Const) {
    auto x = mkvar("x", {5, 3});
    check(x.make_scalar(-1).broadcast(x.symshape()),
          -x.make_scalar(1).broadcast(x.symshape()));
}

TEST_PASS(DelayBroadcastPass, ScalarInput) {
    auto x = mkvar("x", {1}).reshape({1}), y = mkvar("y", {3, 1});
    check((x - y).broadcast({3, 5}), x - y.broadcast({3, 5}));
}

TEST_PASS(DelayBroadcastPass, LongChain) {
    auto x = mkvar("x", {1, 1, 3});
    auto y = mkvar("y", {1, 2, 3});
    auto z = mkvar("z", {2, 2, 3});
    auto relu = [](SymbolVar x) -> SymbolVar {
        using Param = opr::Elemwise::Param;
        Param param;
        param.mode = Param::Mode::RELU;
        return opr::Elemwise::make({x}, param);
    };
    auto bcast = [](SymbolVar x, SymbolVar from) -> SymbolVar {
        return opr::Broadcast::make(x, opr::GetVarShape::make(from));
    };
    // Run graph optimization first, then construct the expected graph.
    // Note: do NOT call `check` directly here; the \p inp and \p expect
    // of `check` live in the same graph, so some problems would not be
    // exposed due to the cache mechanism.
    auto out = bcast(relu(bcast(relu(x), y)), z);
    out = gopt::GraphOptimizer{}
                  .add_pass<gopt::DelayBroadcastPass>()
                  .apply({{out}})
                  .endpoint_vars()[0];
    ASSERT_EQ(bcast(bcast(relu(relu(x)), y), z), out);
}

TEST_PASS(ExpandVirtualGradPass, Simple) {
    auto x = mkvar("x");
    check(x * 2,
          opr::VirtualGrad::make(opr::reduce_sum_sqr(x, x.make_scalar(1)), x));
}

TEST_PASS(ExpandVirtualGradPass, Dyncase) {
    auto x0 = mkvar("x"), x = opr::MarkDynamicVar::make(x0);
    check(opr::MarkDynamicVar::make(x * 2),
          opr::VirtualGrad::make(
                  opr::reduce_sum_sqr(x, x.make_scalar(1)), x0));
}

TEST_F(TestGoptExpandVirtualGradPass, GradWrt) {
    graph->options().graph_opt_level = 0;
    auto x = mkvar("x", {2, 3});
    SymbolVar wrt;
    auto get_grad = [&wrt](const opr::SetGrad& g) -> SymbolVar {
        auto w = gopt::GraphOptimizer::var_replace_lookup(wrt.node());
        return cg::grad(cg::current_grad_target(*g.owner_graph()), w, false);
    };
    wrt = opr::SetGrad::make(x * 2 + 1, get_grad) * 3 + 1;
    auto gx = opr::VirtualGrad::make(
            opr::reduce_sum(wrt, wrt.make_scalar(1)), x);
    SymbolVar gx_opt;
    unpack_vector(gopt::GraphOptimizer{}
                          .add_pass<gopt::ArithFusePass>()
                          .add_pass<gopt::ExpandVirtualGradPass>()
                          .verbosity(2)
                          .apply({{gx}})
                          .endpoint_vars(),
                  gx_opt);
    HostTensorND host_gx;
    auto func = graph->compile({make_callback_copy(gx_opt, host_gx)});
    func->execute();
    ASSERT_EQ(x.shape(), host_gx.shape());
    auto pgx = host_gx.ptr<float>();
    for (size_t i = 0, it = host_gx.shape().total_nr_elems(); i < it; ++i) {
        ASSERT_EQ(2.f, pgx[i]);
    }
}

TEST_F(TestGoptExpandVirtualGradPass, VarReplaceLookup) {
    HostTensorGenerator<> gen;
    auto graph = ComputingGraph::make();
    auto host_x = gen({1});
    auto x = opr::Host2DeviceCopy::make(*graph, host_x);
    SymbolVar y;
    auto grad_getter = [&](const opr::SetGrad&) { return y; };
    auto a = opr::SetGrad::make(x, grad_getter);
    int counter = 0;
    auto callback = [&](DeviceTensorND&) { counter++; };
    y = opr::CallbackInjector::make(a * a, callback);
    auto grad = opr::VirtualGrad::make(y, x);
    HostTensorND host_y, host_grad;
    auto func = graph->compile({make_callback_copy(y, host_y),
                                make_callback_copy(grad, host_grad)});
    func->execute();
    ASSERT_EQ(counter, 1);
}

TEST_PASS(RecompTypeCvtPass, Basic) {
    auto x = mkvar("x", {2, 3, 3});
    auto x_fp16 = opr::TypeCvt::make(x, dtype::Float16());
    auto sin_x = opr::sin(x_fp16);
    auto x_fp32 = opr::TypeCvt::make(sin_x, dtype::Float32());
    auto f = x_fp32;
    for (size_t i = 0; i < 20; ++i) {
        f = opr::sin(f);
    }
    auto for_pass = f + x_fp32;
    OperatorNodeConfig config = x_fp32.node()->owner_opr()->config();
    config.update_instance_id(for_pass.node()->owner_opr());
    auto expected = f + opr::TypeCvt::make(sin_x, dtype::Float32(), config);
    check(expected, for_pass, 0.1);
}

TEST_PASS(CombineAstypeAndReducePass, Grad) {
    auto data = mkvar("data", {10});
    auto x_fp16 = opr::relu(opr::TypeCvt::make(data, dtype::Float16()));
    auto x = opr::TypeCvt::make(x_fp16, dtype::Float32());
    SymbolVar tshp;
    using namespace opr;
    Reduce::Param param_i16_co32{Reduce::Mode::SUM, 0,
                                 Reduce::Param::DataType::FLOAT_O32xC32};
    Reduce::Param param_default{Reduce::Mode::SUM, 0,
                                Reduce::Param::DataType::DEFAULT};
    auto y0 = opr::Reduce::make(x_fp16, param_i16_co32, tshp);
    auto y1 = opr::Reduce::make(x, param_default, tshp);
    auto grad0 = cg::grad(y0, data);
    auto grad1 = cg::grad(y1, data);
    HostTensorND host_grad0, host_grad1;
    auto func0 = graph->compile({make_callback_copy(grad0, host_grad0)});
    func0->execute();
    auto func1 = graph->compile({make_callback_copy(grad1, host_grad1)});
    func1->execute();
    MGB_ASSERT_TENSOR_EQ(host_grad0, host_grad1);
}

TEST_PASS(CombineAstypeAndReducePass, Basic) {
    for (auto&& axis : {MEGDNN_MAX_NDIM, 0}) {
        auto x = mkvar("x", {2, 3, 3});
        auto x_fp16 = opr::relu(opr::TypeCvt::make(x, dtype::Float16()));
        x = opr::TypeCvt::make(x_fp16, dtype::Float32());
        SymbolVar tshp;
        if (axis == MEGDNN_MAX_NDIM) {
            tshp = mkvar("tshp", {1, 3, 2}).symshape();
        }
        using namespace opr;
        Reduce::Param param_i16_co32{Reduce::Mode::SUM, axis,
                                     Reduce::Param::DataType::FLOAT_O32xC32};
        Reduce::Param param_default{Reduce::Mode::SUM, axis,
                                    Reduce::Param::DataType::DEFAULT};
        auto expected = opr::Reduce::make(x_fp16, param_i16_co32, tshp);
        auto get = opr::Reduce::make(x, param_default, tshp);
        check(expected, get);
    }
}

#if MGB_ENABLE_COND_EXEC

TEST(TestCondExec, GoptRemoveConstMask) {
    using MergeMode = opr::CondExecMerge::Mode;
    HostTensorGenerator<> gen;
    auto host_x = gen({2, 3});
    auto run = [&](MergeMode merge_mode, int const_mask, int pred_mask,
                   bool expect_change) -> HostTensorND {
        auto host_pred0 = gen({1}), host_pred1 = gen({1});
        host_pred0->ptr<float>()[0] = pred_mask & 1;
        host_pred1->ptr<float>()[0] = pred_mask >> 1;
        auto graph = ComputingGraph::make();
        auto x = opr::Host2DeviceCopy::make(*graph, host_x);
        auto make_mark =
                [x, &graph](bool const_pred,
                            const std::shared_ptr<HostTensorND>& host_pred) {
                    SymbolVar pred;
                    if (const_pred) {
                        pred = opr::ImmutableTensor::make(*graph, *host_pred);
                    } else {
                        pred = opr::Host2DeviceCopy::make(*graph, host_pred);
                    }
                    SymbolVar ppv, ret;
                    unpack_vector(opr::CondExecPred::make(
                                          pred, {pred.make_scalar_dt(1)}),
                                  ppv);
                    unpack_vector(opr::CondExecMark::make(ppv, {x}), ret);
                    return ret;
                };
        SymbolVarArray merge_shp;
        if (merge_mode == MergeMode::SUM) {
            merge_shp.push_back(x.symshape());
        }
        auto xmark0 = make_mark(const_mask & 1, host_pred0) + 1.2f,
             xmark1 = make_mark(const_mask >> 1, host_pred1) * 2.3f,
             y = opr::CondExecMerge::make({xmark0, xmark1}, {1, merge_mode},
                                          merge_shp)[0];
        VarNodeArray y_opt_arr{y.node()};
        gopt::GraphOptimizer{}
                .add_pass<gopt::CondExecConstPredicateFolding>()
                .apply_inplace(y_opt_arr);
        SymbolVar y_opt = y_opt_arr[0];
        if (expect_change) {
            EXPECT_NE(y_opt.node(), y.node());
        } else {
            EXPECT_EQ(y_opt, y);
        }
        HostTensorND host_y;
        graph->options().graph_opt_level = 0;
        auto func = graph->compile({make_callback_copy(y_opt, host_y)});
        func->execute();
        return host_y;
    };
    for (size_t mode_num = 0;
         mode_num < opr::CondExecMerge::Param::MODE_NR_MEMBER; ++mode_num) {
        auto mode = static_cast<MergeMode>(mode_num);
        bool exact_one = (mode == MergeMode::EXACT_ONE ||
                          mode == MergeMode::EXACT_ONE_SAME_SHAPE);
        for (int pmask = 0; pmask < 4; ++pmask) {
            if (exact_one && (pmask & 1) + (pmask >> 1) != 1) {
                continue;
            }
            if (mode == MergeMode::SUM_COND_OUT && !pmask) {
                ASSERT_THROW(run(mode, 0b11, 0, false), GraphError);
                continue;
            }
            auto v0 = run(mode, 0b11, pmask, true);
            auto v1 = run(mode, 0b01, pmask, false);
            MGB_ASSERT_TENSOR_EQ(v0, v1);
        }
    }
}

#endif  // MGB_ENABLE_COND_EXEC

TEST_PASS(RemoveRedundantTypeCvtPass, Basic) {
    // declare x outside the #if block so the integer and quantized cases
    // below still compile when float16 support is disabled
    auto x = mkvar("x", {2, 3, 3});
#if !MEGDNN_DISABLE_FLOAT16
    auto x_fp16 = opr::TypeCvt::make(x, dtype::Float16());
    auto x_fp16_fp32 = opr::TypeCvt::make(x_fp16, dtype::Float32());
    auto x_fp16_fp32_fp16 = opr::TypeCvt::make(x_fp16_fp32, dtype::Float16());
    check(x_fp16, x_fp16_fp32_fp16);
#endif
    auto x_i32 = opr::TypeCvt::make(x, dtype::Int32());
    auto x_i32_i16 = opr::TypeCvt::make(x_i32, dtype::Int16());
    auto x_i32_i16_i8 = opr::TypeCvt::make(x_i32_i16, dtype::Int8());
    auto x_i8 = opr::TypeCvt::make(x, dtype::Int8());
    check(x_i8, x_i32_i16_i8);
    auto x_q8 = opr::TypeCvt::make(x, dtype::QuantizedS8(0.1f));
    auto x_q8_fp32 = opr::TypeCvt::make(x_q8, dtype::Float32());
    auto x_q8_fp32_q8 = opr::TypeCvt::make(x_q8_fp32, dtype::QuantizedS8(0.1f));
    auto x_q8_fp32_q8_ = opr::TypeCvt::make(x_q8_fp32, dtype::QuantizedS8(2.f));
    auto x_q8_q8 = opr::TypeCvt::make(x_q8, dtype::QuantizedS8(2.f));
    check(x_q8, x_q8_fp32_q8);
    check(x_q8_q8, x_q8_fp32_q8_);
}

TEST_PASS(RemoveRedundantCopyPass, Basic) {
    auto x = mkvar("x", {2, 3, 3}, CompNode::load("cpu0"));
    {
        auto x_cpu1 = opr::Copy::make(x, CompNode::load("cpu1"));
        auto x_cpu0 = opr::Copy::make(x_cpu1, CompNode::load("cpu0"));
        auto x_cpu2 = opr::Copy::make(x_cpu0, CompNode::load("cpu2"));
        auto x_expected = opr::Copy::make(x, CompNode::load("cpu2"));
        check(x, x_cpu0);
        check(x_expected, x_cpu2);
    }
    {
        auto x_cpu1 = opr::Copy::make(x, CompNode::load("cpu1"));
        auto x_cpu2 = opr::Copy::make(x_cpu1, CompNode::load("cpu2"));
        auto x_cpu3 = opr::Copy::make(x_cpu2, CompNode::load("cpu3"));
        auto x_expected = opr::Copy::make(x, CompNode::load("cpu3"));
        check(x_expected, x_cpu3);
    }
    {
        auto x_cpu1 = opr::Copy::make(x, CompNode::load("cpu0:1"));
        auto x_cpu2 = opr::Copy::make(x_cpu1, CompNode::load("cpu0:2"));
        auto x_cpu3 = opr::Copy::make(x_cpu2, CompNode::load("cpu0:3"));
        auto x_expected = opr::Copy::make(x, CompNode::load("cpu0:3"));
        check(x_expected, x_cpu3);
    }
    {
        auto x_cpu1 = opr::Copy::make(x, CompNode::load("cpu0:1"));
        auto x_mt = opr::Copy::make(x_cpu1, CompNode::load("multithread8:0"));
        auto x_cpu3 = opr::Copy::make(x_mt, CompNode::load("cpu0:3"));
        auto x_expected = opr::Copy::make(x, CompNode::load("cpu0:3"));
        check(x_expected, x_cpu3);
    }
#if MGB_ATLAS
    {
        auto x_atlas0 = opr::Copy::make(x, CompNode::load("atlas0"));
        auto x_cpu2 = opr::Copy::make(x_atlas0, CompNode::load("cpu0:2"));
        auto x_cpu3 = opr::Copy::make(x_cpu2, CompNode::load("cpu0:3"));
        auto x_expected = opr::Copy::make(x, CompNode::load("cpu0:3"));
        check(x_expected, x_cpu3);
    }
#endif
#if MGB_CUDA
    {
        auto x_cuda0 = opr::Copy::make(x, CompNode::load("gpu0"));
        auto x_cpu2 = opr::Copy::make(x_cuda0, CompNode::load("cpu0:2"));
        auto x_cpu3 = opr::Copy::make(x_cpu2, CompNode::load("cpu0:3"));
        auto x_expected = opr::Copy::make(x, CompNode::load("cpu0:3"));
        check(x_expected, x_cpu3);
    }
    {
        auto x_mt = opr::Copy::make(x, CompNode::load("multithread8:0"));
        auto x_gpu1 = opr::Copy::make(x_mt, CompNode::load("gpu0:1"));
        auto x_back = opr::Copy::make(x_gpu1, CompNode::load("multithread8:0"));
        auto x_expected = opr::Copy::make(x, CompNode::load("multithread8:0"));
        check(x_expected, x_back);
    }
#endif
}

#if MGB_ENABLE_OPR_MM
#include "megbrain/opr/collective_comm.h"
#include "../../opr-mm/test/mock_client.h"

TEST_PASS(PackAllReduceScanPass, Basic) {
    auto graph = ComputingGraph::make();
    graph->options().allreduce_pack_max_size = 5000;
    auto client = std::make_shared<test::MockGroupClient>();
    auto cn = CompNode::load("gpux");
    auto dev_x0 = std::make_shared<DeviceTensorND>(cn, TensorShape{3, 5});
    auto dev_x1 = std::make_shared<DeviceTensorND>(cn, TensorShape{4, 6});
    auto dev_y0 = std::make_shared<DeviceTensorND>(cn, TensorShape{1});
    auto dev_y1 = std::make_shared<DeviceTensorND>(cn, TensorShape{1});
    auto x0 = opr::SharedDeviceTensor::make(*graph, dev_x0);
    auto x1 = opr::VolatileSharedDeviceTensor::make(*graph, dev_x1);
    auto y0 = opr::SharedDeviceTensor::make(*graph, dev_y0);
    auto y1 = opr::VolatileSharedDeviceTensor::make(*graph, dev_y1);
    auto grad0 = opr::VirtualGrad::make(y0, x0);
    auto grad1 = opr::VirtualGrad::make(y0, x1);
    auto grad2 = opr::VirtualGrad::make(y1, x0);
    auto grad3 = opr::VirtualGrad::make(y1, x1);
    auto mode = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM;
    auto comm0 = opr::CollectiveComm::make({grad0}, graph.get(), "grad0", 2,
                                           false, 0, false, client, mode)[0];
    auto comm1 = opr::CollectiveComm::make({grad1}, graph.get(), "grad1", 2,
                                           false, 0, false, client, mode)[0];
    auto comm2 = opr::CollectiveComm::make({grad2}, graph.get(), "grad2", 2,
                                           false, 0, false, client, mode)[0];
    auto comm3 = opr::CollectiveComm::make({grad3}, graph.get(), "grad3", 2,
                                           false, 0, false, client, mode)[0];
    gopt::GraphOptimizer()
            .add_pass<gopt::PackAllReduceScanPass>()
            .apply({{comm0, comm1, comm2, comm3}});
    auto get_hash = [](const SymbolVar& symvar) {
        cg::OperatorNodeBase* opr = symvar.node()->owner_opr();
        return opr->cast_final_safe<opr::CollectiveComm>().pack_hash();
    };
    uint64_t hash0 = get_hash(comm0);
    uint64_t hash1 = get_hash(comm1);
    uint64_t hash2 = get_hash(comm2);
    uint64_t hash3 = get_hash(comm3);
    ASSERT_EQ(hash0, hash1);
    ASSERT_EQ(hash2, hash3);
    ASSERT_NE(hash0, hash2);
}

TEST_PASS(PackAllReduceReplacePass, CollectGroups) {
    REQUIRE_GPU(2);
    auto cns = load_multiple_xpus(2);
    auto graph = ComputingGraph::make();
    graph->options().graph_opt_level = 2;
    auto cli0 = std::make_shared<test::MockGroupClient>("mock_addr0");
    auto cli1 = std::make_shared<test::MockGroupClient>("mock_addr1");
    using GroupInfo = gopt::PackAllReduceReplacePass::GroupInfo;
    ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>> group_info;
    ThinHashMap<uint64_t, cg::OprNodeArray> groups;
    auto add_opr = [&](const CompNode& cn, TensorShape shape, const DType& dt,
                       std::shared_ptr<test::MockGroupClient> client,
                       uint64_t extra_hash) {
        auto dev0 = std::make_shared<DeviceTensorND>(cn, shape, dt);
        auto wrt = opr::SharedDeviceTensor::make(*graph, dev0);
        auto dev1 = std::make_shared<DeviceTensorND>(cn, TensorShape{1}, dt);
        auto target = opr::SharedDeviceTensor::make(*graph, dev1);
        auto grad = opr::VirtualGrad::make(target, wrt);
        auto comm =
                opr::CollectiveComm::make(
                        {grad}, graph.get(), "key", 2, false, 0, false, client,
                        opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM)[0]
                        .node()
                        ->owner_opr();
        comm->cast_final_safe<opr::CollectiveComm>().set_pack_hash(extra_hash);
        return gopt::PackAllReduceReplacePass::collect_groups(comm, group_info,
                                                              groups);
    };
    // vary one grouping factor at a time; only hash1 should collide with hash0
    uint64_t hash0 = add_opr(cns[0], TensorShape{1, 3}, dtype::Float32{}, cli0, 1);
    uint64_t hash1 = add_opr(cns[0], TensorShape{2, 4}, dtype::Float32{}, cli0, 1);  // same group
    uint64_t hash2 = add_opr(cns[1], TensorShape{3, 5}, dtype::Float32{}, cli0, 1);  // differs: comp_node
    uint64_t hash3 = add_opr(cns[0], TensorShape{4, 6}, dtype::Float16{}, cli0, 1);  // differs: dtype
    uint64_t hash4 = add_opr(cns[0], TensorShape{5, 7}, dtype::Float32{}, cli1, 1);  // differs: client
    uint64_t hash5 = add_opr(cns[0], TensorShape{6, 8}, dtype::Float32{}, cli0, 2);  // differs: extra_hash
    ASSERT_EQ(hash0, hash1);
    std::set<uint64_t> s{hash0, hash1, hash2, hash3, hash4, hash5};
    ASSERT_EQ(5, s.size());
    ASSERT_EQ(1, group_info.count(hash0));
    ASSERT_EQ(1, group_info.count(hash1));
    ASSERT_EQ(1, group_info.count(hash2));
    ASSERT_EQ(1, group_info.count(hash3));
    ASSERT_EQ(1, group_info.count(hash4));
    ASSERT_EQ(1, group_info.count(hash5));
    ASSERT_EQ(2, groups[hash0].size());
    ASSERT_EQ(2, groups[hash1].size());
    ASSERT_EQ(1, groups[hash2].size());
    ASSERT_EQ(1, groups[hash3].size());
    ASSERT_EQ(1, groups[hash4].size());
    ASSERT_EQ(1, groups[hash5].size());
}

TEST_PASS(PackAllReduceReplacePass, DividePacks) {
    auto cn = CompNode::load("gpux");
    auto graph = ComputingGraph::make();
    auto client = std::make_shared<test::MockGroupClient>();
    auto mode = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM;
    ThinHashMap<uint64_t, cg::OprNodeArray> groups;
    ThinHashMap<uint64_t, std::vector<cg::OprNodeArray>> packs;
    auto insert_opr = [&](size_t size) {
        auto dev = std::make_shared<DeviceTensorND>(
                cn, TensorShape{size / sizeof(float)});
        auto sd = opr::SharedDeviceTensor::make(*graph, dev);
        auto symvar = opr::CollectiveComm::make(
                {sd}, graph.get(), "key", 2, false, 0, false, client, mode)[0];
        auto opr = symvar.node()->owner_opr();
        auto& comm = opr->cast_final_safe<opr::CollectiveComm>();
        comm.set_pack_hash(1);
        return opr;
    };
    auto pack_size = [&](cg::OprNodeArray& pack) {
        size_t sum = 0;
        for (size_t i = 0; i < pack.size(); i++) {
            auto var = pack[i]->input(0);
            sum += var->dtype().size(var->shape().total_nr_elems());
        }
        return sum;
    };
    // divide_packs is called with max_size=1000 below; a pack is closed once
    // its accumulated size reaches that limit
    groups[0].push_back(insert_opr(100));  // group0, pack0, size=1100
    groups[0].push_back(insert_opr(300));  // group0, pack0, size=1100
    groups[0].push_back(insert_opr(400));  // group0, pack0, size=1100
    groups[0].push_back(insert_opr(300));  // group0, pack0, size=1100
    groups[0].push_back(insert_opr(500));  // group0, pack1, size=800
    groups[0].push_back(insert_opr(200));  // group0, pack1, size=800
    groups[0].push_back(insert_opr(100));  // group0, pack1, size=800
    groups[1].push_back(insert_opr(100));  // group1, pack0, size=900
    groups[1].push_back(insert_opr(400));  // group1, pack0, size=900
    groups[1].push_back(insert_opr(300));  // group1, pack0, size=900
    groups[1].push_back(insert_opr(100));  // group1, pack0, size=900
    gopt::PackAllReduceReplacePass::divide_packs(groups, packs, 1000);
    ASSERT_EQ(2, packs.size());
    ASSERT_EQ(2, packs[0].size());
    ASSERT_EQ(4, packs[0][0].size());
    ASSERT_EQ(1100, pack_size(packs[0][0]));
    ASSERT_EQ(3, packs[0][1].size());
    ASSERT_EQ(800, pack_size(packs[0][1]));
    ASSERT_EQ(1, packs[1].size());
    ASSERT_EQ(4, packs[1][0].size());
    ASSERT_EQ(900, pack_size(packs[1][0]));
}

TEST_PASS(PackAllReduceReplacePass, InsertPackedOprs) {
    auto cn = CompNode::load("gpux");
    auto graph = ComputingGraph::make();
    auto client = std::make_shared<test::MockGroupClient>();
    auto mode = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM;
    size_t nr_devices = 2;
    uint32_t rank = 0;
    using GroupInfo = gopt::PackAllReduceReplacePass::GroupInfo;
    ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>> group_info;
    ThinHashMap<uint64_t, cg::OprNodeArray> groups;
    auto insert_opr = [&](const TensorShape& shape) {
        auto dev = std::make_shared<DeviceTensorND>(cn, shape);
        auto sd = opr::SharedDeviceTensor::make(*graph, dev);
        auto symvar =
                opr::CollectiveComm::make({sd}, graph.get(), "key", nr_devices,
                                          false, rank, false, client, mode)[0];
        auto opr = symvar.node()->owner_opr();
        auto& comm = opr->cast_final_safe<opr::CollectiveComm>();
        comm.set_pack_hash(1);
        gopt::PackAllReduceReplacePass::collect_groups(opr, group_info, groups);
        return symvar;
    };
    auto shape_x = TensorShape{100, 200};
    auto shape_y = TensorShape{200, 400};
    auto x = insert_opr(shape_x);
    auto y = insert_opr(shape_y);
    ASSERT_EQ(1, group_info.size());
    ASSERT_EQ(1, groups.size());
    auto info = group_info.begin()->second;
    auto pack = groups.begin()->second;
    size_t pack_id = 0;
    ThinHashMap<VarNode*, VarNode*> replace_map;
    gopt::PackAllReduceReplacePass::insert_packed_oprs(pack_id, pack, info,
                                                       replace_map, -1);
    // manually build the expected packed graph: concat the flattened grads,
    // all-reduce once, then split and reshape back to the original shapes
    auto grad_x = SymbolVar(x.node()->owner_opr()->input(0));
    auto grad_y = SymbolVar(y.node()->owner_opr()->input(0));
    auto concat = opr::Concat::make({grad_x.flatten(), grad_y.flatten()}, 0);
    std::string key = ssprintf("grad_pack_%zu", pack_id);
    auto allreduce =
            opr::CollectiveComm::make({concat}, graph.get(), key, nr_devices,
                                      false, rank, false, client, mode)[0];
    std::vector<size_t> partition;
    partition.push_back(shape_x.total_nr_elems());
    partition.push_back(shape_y.total_nr_elems());
    auto splits = opr::Split::make(
            allreduce,
            opr::Split::Options::make_partition(allreduce, 0, partition));
    ASSERT_EQ(2, splits.size());
    auto dest_x = splits[0].reshape(shape_x);
    auto dest_y = splits[1].reshape(shape_y);
    ASSERT_EQ(2, replace_map.size());
    ASSERT_TRUE(replace_map.count(x.node()) > 0);
    ASSERT_EQ(replace_map.at(x.node()), dest_x.node());
    ASSERT_TRUE(replace_map.count(y.node()) > 0);
    ASSERT_EQ(replace_map.at(y.node()), dest_y.node());
}

TEST_PASS(PackAllReduceReplacePass, Equivalence) {
    REQUIRE_GPU(2);
    auto cns = load_multiple_xpus(2);
    auto client = std::make_shared<test::MockGroupClient>();
    auto build_graph = [&](uint32_t rank, std::shared_ptr<ComputingGraph> graph,
                           SymbolVarArray& array) {
        HostTensorGenerator<> gen;
        auto cn = cns[rank];
        auto host_x = gen({1, 1000});
        auto host_y = gen({1000, 1});
        auto dev_x = std::make_shared<DeviceTensorND>(cn);
        auto dev_y = std::make_shared<DeviceTensorND>(cn);
        dev_x->copy_from(*host_x).sync();
        dev_y->copy_from(*host_y).sync();
        auto x = opr::SharedDeviceTensor::make(*graph, dev_x);
        auto y = opr::VolatileSharedDeviceTensor::make(*graph, dev_y);
        auto loss = opr::MatrixMul::make(x, y).flatten();
        auto grad_x = opr::VirtualGrad::make(loss, x);
        auto grad_y = opr::VirtualGrad::make(loss, y);
        using Mode = opr::CollectiveComm::Param::Mode;
        bool is_root = (rank == 0);
        auto reduced_x = opr::CollectiveComm::make(
                                 {grad_x}, graph.get(), "x", 2, is_root, rank,
                                 false, client, Mode::ALL_REDUCE_SUM)[0] /
                         2;
        auto reduced_y = opr::CollectiveComm::make(
                                 {grad_y}, graph.get(), "y", 2, is_root, rank,
                                 false, client, Mode::ALL_REDUCE_SUM)[0] /
                         2;
        graph->options().allreduce_pack_max_size = 5000;
        graph->options().allreduce_pack_ignore_first = 0;
        auto dest_vars = gopt::GraphOptimizer{}
                                 .add_pass<gopt::PackAllReduceScanPass>()
                                 .add_pass<gopt::PackAllReduceReplacePass>()
                                 .apply({{reduced_x, reduced_y}})
                                 .endpoint_vars();
        array.emplace_back(reduced_x);
        array.emplace_back(reduced_y);
        array.emplace_back(dest_vars[0]);
        array.emplace_back(dest_vars[1]);
    };
    auto run = [&](uint32_t rank) {
        auto graph = ComputingGraph::make();
        SymbolVarArray array;
        build_graph(rank, graph, array);
        HostTensorND host_reduced_x, host_reduced_y, host_dest_0, host_dest_1;
        graph->options().allreduce_pack_max_size = 0;
        auto func =
                graph->compile({make_callback_copy(array[0], host_reduced_x),
                                make_callback_copy(array[1], host_reduced_y),
                                make_callback_copy(array[2], host_dest_0),
                                make_callback_copy(array[3], host_dest_1)});
        func->execute();
        MGB_ASSERT_TENSOR_EQ(host_reduced_x, host_dest_0);
        MGB_ASSERT_TENSOR_EQ(host_reduced_y, host_dest_1);
    };
    std::thread t0(run, 0);
    std::thread t1(run, 1);
    t0.join();
    t1.join();
}

#endif  // MGB_ENABLE_OPR_MM

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
