
misc.cpp 28 kB

/**
 * \file src/gopt/test/misc.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "./helper.h"

#include "megbrain/gopt/basic_arith.h"
#include "megbrain/gopt/misc.h"
#include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/cond.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"

using namespace mgb;
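
// RemoveNonComputingOprPass strips oprs that perform no actual computation;
// here the MarkNoBroadcastElemwise annotation should be replaced by its input.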
TEST_PASS(RemoveNonComputingOprPass, Simple) {
    auto x = mkvar("x");
    check(x, opr::MarkNoBroadcastElemwise::make(x));
}

TEST_PASS(RemoveNonComputingOprPass, Split) {
    auto a = mkvar("a"), b = mkvar("b"),
         loss = opr::reduce_sum(opr::Concat::make({a, b}, 0), a.make_scalar(1)),
         ga = cg::grad(loss, a),
         ga_exp = a.make_scalar(1.f).broadcast(ga.symshape());
    check(ga_exp, ga);
}

TEST_PASS(RemoveNonComputingOprPass, SplitImmOpt) {
    auto cns = load_multiple_xpus(2);
    HostTensorGenerator<> gen;
    auto cn0 = cns[0], cn1 = cns[1];
    auto host_x0 = gen({2, 3}, cn0), host_x1 = gen({2, 3}, cn1);
    auto graph = ComputingGraph::make();
    auto make1 = [&graph](SymbolVar var) {
        auto val = std::make_shared<HostTensorND>(
                var.node()->comp_node(), TensorShape{1}, dtype::Int32());
        val->ptr<int>()[0] = 1;
        return opr::Host2DeviceCopy::make(*graph, val);
    };
    auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0),
         x1 = opr::Host2DeviceCopy::make(*graph, host_x1);
    auto splt = opr::Split::make(
            x0.make_scalar(0.f).broadcast({2}),
            opr::Split::Options::make_partition(0, {make1(x0), make1(x1)}),
            OperatorNodeConfig{}.comp_node_arr({cn0, cn1}));
    auto y0 = x0 + splt[0], y1 = x1 + splt[1];
    HostTensorND host_y0, host_y1;
    auto func = graph->compile({make_callback_copy(y0, host_y0),
                                make_callback_copy(y1, host_y1)});
    func->execute();
    MGB_ASSERT_TENSOR_EQ(*host_x0, host_y0);
    MGB_ASSERT_TENSOR_EQ(*host_x1, host_y1);
}
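
// DelayBroadcastPass moves Broadcast oprs toward the graph outputs so that
// the elemwise / typecvt oprs in between operate on the smaller,
// pre-broadcast tensors; e.g. relu(broadcast(x)) becomes broadcast(relu(x)).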
TEST_PASS(DelayBroadcastPass, Basic) {
    auto x = mkvar("x", {1, 1, 3});
    auto y = mkvar("y", {1, 2, 3});
    auto z = mkvar("z", {2, 2, 3});
    auto relu_maker = [](SymbolVar x) -> SymbolVar {
        using Param = opr::Elemwise::Param;
        Param param;
        param.mode = Param::Mode::RELU;
        return opr::Elemwise::make({x}, param);
    };
    auto typecvt_maker = [](SymbolVar x, bool float16 = true) -> SymbolVar {
        if (float16)
            return opr::TypeCvt::make(x, dtype::Float16());
        else
            return opr::TypeCvt::make(x, dtype::Float32());
    };
    auto broadcast_maker = [](SymbolVar x, SymbolVar from) -> SymbolVar {
        return opr::Broadcast::make(x, opr::GetVarShape::make(from));
    };
    auto get_var_shp_maker = [](SymbolVar x) -> SymbolVar {
        return opr::GetVarShape::make(x);
    };

    // check just two oprs that need swapping
    check(broadcast_maker(relu_maker(x), y), relu_maker(broadcast_maker(x, y)));
    // check multiple oprs that need shifting
    check(broadcast_maker(typecvt_maker(relu_maker(x)), y),
          typecvt_maker(relu_maker(broadcast_maker(x, y))));
    // check opr::GetVarShape
    check(get_var_shp_maker(broadcast_maker(typecvt_maker(relu_maker(x)), y)),
          get_var_shp_maker(typecvt_maker(relu_maker(broadcast_maker(x, y)))));
    check(get_var_shp_maker(broadcast_maker(typecvt_maker(relu_maker(x)), y)),
          get_var_shp_maker(typecvt_maker(broadcast_maker(relu_maker(x), y))));
    check(typecvt_maker(get_var_shp_maker(broadcast_maker(relu_maker(x), y))),
          typecvt_maker(get_var_shp_maker(relu_maker(broadcast_maker(x, y)))));
    // remains the same after applying the pass
    check<false>(broadcast_maker(broadcast_maker(x, y), z),
                 broadcast_maker(broadcast_maker(x, y), z));
    // mixed case
    check(broadcast_maker(broadcast_maker(relu_maker(typecvt_maker(x)), y), z),
          relu_maker(broadcast_maker(typecvt_maker(broadcast_maker(x, y)), z)));
    // endpoint situation 1; see `DelayBroadcastPass::apply` comments
    check(y + broadcast_maker(relu_maker(x), z),
          y + relu_maker(broadcast_maker(x, z)));
    // second replaced chain depends on another replaced chain
    check(broadcast_maker(typecvt_maker(broadcast_maker(typecvt_maker(x), y) +
                                                typecvt_maker(y),
                                        false),
                          z),
          typecvt_maker(broadcast_maker(typecvt_maker(broadcast_maker(x, y)) +
                                                typecvt_maker(y),
                                        z),
                        false));
    // broadcast opr depends on another chain
    auto shape3 = mkvar("shape3", {2}).symshape() + 1;
    auto shape333 = opr::abs(opr::Broadcast::make(shape3, shape3));
    auto shape333_after = opr::Broadcast::make(opr::abs(shape3), shape3);
    check(broadcast_maker(relu_maker(x), shape333_after),
          relu_maker(broadcast_maker(x, shape333)));
}

TEST_PASS(DelayBroadcastPass, Const) {
    auto x = mkvar("x", {5, 3});
    check(x.make_scalar(-1).broadcast(x.symshape()),
          -x.make_scalar(1).broadcast(x.symshape()));
}

TEST_PASS(DelayBroadcastPass, ScalarInput) {
    auto x = mkvar("x", {1}).reshape({1}), y = mkvar("y", {3, 1});
    check((x - y).broadcast({3, 5}), x - y.broadcast({3, 5}));
}

TEST_PASS(DelayBroadcastPass, LongChain) {
    auto x = mkvar("x", {1, 1, 3});
    auto y = mkvar("y", {1, 2, 3});
    auto z = mkvar("z", {2, 2, 3});
    auto relu = [](SymbolVar x) -> SymbolVar {
        using Param = opr::Elemwise::Param;
        Param param;
        param.mode = Param::Mode::RELU;
        return opr::Elemwise::make({x}, param);
    };
    auto bcast = [](SymbolVar x, SymbolVar from) -> SymbolVar {
        return opr::Broadcast::make(x, opr::GetVarShape::make(from));
    };
    // Run graph optimization first, then construct the expected graph.
    // Note: DO NOT call `check` directly here; the \p inp and \p expect of
    // `check` live in the same graph, so some problems would not be exposed
    // due to the cache mechanism.
    auto out = bcast(relu(bcast(relu(x), y)), z);
    out = gopt::GraphOptimizer{}
                  .add_pass<gopt::DelayBroadcastPass>()
                  .apply({{out}})
                  .endpoint_vars()[0];
    ASSERT_EQ(bcast(bcast(relu(relu(x)), y), z), out);
}
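
// ExpandVirtualGradPass replaces the symbolic VirtualGrad placeholder oprs
// with the actual gradient subgraph computed by cg::grad.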
TEST_PASS(ExpandVirtualGradPass, Simple) {
    auto x = mkvar("x");
    check(x * 2,
          opr::VirtualGrad::make(opr::reduce_sum_sqr(x, x.make_scalar(1)), x));
}

TEST_PASS(ExpandVirtualGradPass, Dyncase) {
    auto x0 = mkvar("x"), x = opr::MarkDynamicVar::make(x0);
    check(opr::MarkDynamicVar::make(x * 2),
          opr::VirtualGrad::make(
                  opr::reduce_sum_sqr(x, x.make_scalar(1)), x0));
}

TEST_F(TestGoptExpandVirtualGradPass, GradWrt) {
    graph->options().graph_opt_level = 0;
    auto x = mkvar("x", {2, 3});
    SymbolVar wrt;
    auto get_grad = [&wrt](const opr::SetGrad& g) -> SymbolVar {
        auto w = gopt::GraphOptimizer::var_replace_lookup(wrt.node());
        return cg::grad(cg::current_grad_target(*g.owner_graph()), w, false);
    };
    wrt = opr::SetGrad::make(x * 2 + 1, get_grad) * 3 + 1;
    auto gx = opr::VirtualGrad::make(
            opr::reduce_sum(wrt, wrt.make_scalar(1)), x);
    SymbolVar gx_opt;
    unpack_vector(gopt::GraphOptimizer{}
                          .add_pass<gopt::ArithFusePass>()
                          .add_pass<gopt::ExpandVirtualGradPass>()
                          .verbosity(2)
                          .apply({{gx}})
                          .endpoint_vars(),
                  gx_opt);
    HostTensorND host_gx;
    auto func = graph->compile({make_callback_copy(gx_opt, host_gx)});
    func->execute();
    ASSERT_EQ(x.shape(), host_gx.shape());
    auto pgx = host_gx.ptr<float>();
    for (size_t i = 0, it = host_gx.shape().total_nr_elems(); i < it; ++i) {
        ASSERT_EQ(2.f, pgx[i]);
    }
}

TEST_F(TestGoptExpandVirtualGradPass, VarReplaceLookup) {
    HostTensorGenerator<> gen;
    auto graph = ComputingGraph::make();
    auto host_x = gen({1});
    auto x = opr::Host2DeviceCopy::make(*graph, host_x);
    SymbolVar y;
    auto grad_getter = [&](const opr::SetGrad&) { return y; };
    auto a = opr::SetGrad::make(x, grad_getter);
    int counter = 0;
    auto callback = [&](DeviceTensorND&) { counter++; };
    y = opr::CallbackInjector::make(a * a, callback);
    auto grad = opr::VirtualGrad::make(y, x);
    HostTensorND host_y, host_grad;
    auto func = graph->compile({make_callback_copy(y, host_y),
                                make_callback_copy(grad, host_grad)});
    func->execute();
    ASSERT_EQ(counter, 1);
}
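
// RecompTypeCvtPass: a value converted at the head of a long opr chain is
// expected to be recomputed next to its distant consumer (a fresh TypeCvt
// that differs from the original only in its instance id), so the converted
// tensor need not stay alive across the whole chain.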
TEST_PASS(RecompTypeCvtPass, Basic) {
    auto x = mkvar("x", {2, 3, 3});
    auto x_fp16 = opr::TypeCvt::make(x, dtype::Float16());
    auto sin_x = opr::sin(x_fp16);
    auto x_fp32 = opr::TypeCvt::make(sin_x, dtype::Float32());
    auto f = x_fp32;
    for (size_t i = 0; i < 20; ++i) {
        f = opr::sin(f);
    }
    auto for_pass = f + x_fp32;
    OperatorNodeConfig config = x_fp32.node()->owner_opr()->config();
    config.instance_id(for_pass.node()->owner_opr());
    auto expected = f + opr::TypeCvt::make(sin_x, dtype::Float32(), config);
    check(expected, for_pass, 0.1);
}
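
// CombineAstypeAndReducePass fuses a TypeCvt (astype) followed by a Reduce
// into a single Reduce whose DataType performs the conversion internally
// (here FLOAT_O32xC32: fp16 input, fp32 compute and output).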
TEST_PASS(CombineAstypeAndReducePass, Grad) {
    auto data = mkvar("data", {10});
    auto x_fp16 = opr::relu(opr::TypeCvt::make(data, dtype::Float16()));
    auto x = opr::TypeCvt::make(x_fp16, dtype::Float32());
    SymbolVar tshp;
    using namespace opr;
    Reduce::Param param_i16_co32{Reduce::Mode::SUM, 0,
                                 Reduce::Param::DataType::FLOAT_O32xC32};
    Reduce::Param param_default{Reduce::Mode::SUM, 0,
                                Reduce::Param::DataType::DEFAULT};
    auto y0 = opr::Reduce::make(x_fp16, param_i16_co32, tshp);
    auto y1 = opr::Reduce::make(x, param_default, tshp);
    auto grad0 = cg::grad(y0, data);
    auto grad1 = cg::grad(y1, data);
    HostTensorND host_grad0, host_grad1;
    auto func0 = graph->compile({make_callback_copy(grad0, host_grad0)});
    func0->execute();
    auto func1 = graph->compile({make_callback_copy(grad1, host_grad1)});
    func1->execute();
    MGB_ASSERT_TENSOR_EQ(host_grad0, host_grad1);
}

TEST_PASS(CombineAstypeAndReducePass, Basic) {
    for (auto&& axis : {MEGDNN_MAX_NDIM, 0}) {
        auto x = mkvar("x", {2, 3, 3});
        auto x_fp16 = opr::relu(opr::TypeCvt::make(x, dtype::Float16()));
        x = opr::TypeCvt::make(x_fp16, dtype::Float32());
        SymbolVar tshp;
        if (axis == MEGDNN_MAX_NDIM) {
            tshp = mkvar("tshp", {1, 3, 2}).symshape();
        }
        using namespace opr;
        Reduce::Param param_i16_co32{Reduce::Mode::SUM, axis,
                                     Reduce::Param::DataType::FLOAT_O32xC32};
        Reduce::Param param_default{Reduce::Mode::SUM, axis,
                                    Reduce::Param::DataType::DEFAULT};
        auto expected = opr::Reduce::make(x_fp16, param_i16_co32, tshp);
        auto get = opr::Reduce::make(x, param_default, tshp);
        check(expected, get);
    }
}

#if MGB_ENABLE_COND_EXEC
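// CondExecConstPredicateFolding folds conditional-execution branches whose
// predicates all come from compile-time constants (ImmutableTensor); the
// folded graph must produce the same result as the unoptimized one.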
TEST(TestCondExec, GoptRemoveConstMask) {
    using MergeMode = opr::CondExecMerge::Mode;
    HostTensorGenerator<> gen;
    auto host_x = gen({2, 3});
    auto run = [&](MergeMode merge_mode, int const_mask, int pred_mask,
                   bool expect_change) -> HostTensorND {
        auto host_pred0 = gen({1}), host_pred1 = gen({1});
        host_pred0->ptr<float>()[0] = pred_mask & 1;
        host_pred1->ptr<float>()[0] = pred_mask >> 1;
        auto graph = ComputingGraph::make();
        auto x = opr::Host2DeviceCopy::make(*graph, host_x);
        auto make_mark =
                [x, &graph](bool const_pred,
                            const std::shared_ptr<HostTensorND>& host_pred) {
                    SymbolVar pred;
                    if (const_pred) {
                        pred = opr::ImmutableTensor::make(*graph, *host_pred);
                    } else {
                        pred = opr::Host2DeviceCopy::make(*graph, host_pred);
                    }
                    SymbolVar ppv, ret;
                    unpack_vector(opr::CondExecPred::make(
                                          pred, {pred.make_scalar_dt(1)}),
                                  ppv);
                    unpack_vector(opr::CondExecMark::make(ppv, {x}), ret);
                    return ret;
                };
        SymbolVarArray merge_shp;
        if (merge_mode == MergeMode::SUM) {
            merge_shp.push_back(x.symshape());
        }
        auto xmark0 = make_mark(const_mask & 1, host_pred0) + 1.2f,
             xmark1 = make_mark(const_mask >> 1, host_pred1) * 2.3f,
             y = opr::CondExecMerge::make({xmark0, xmark1}, {1, merge_mode},
                                          merge_shp)[0];
        VarNodeArray y_opt_arr{y.node()};
        gopt::GraphOptimizer{}
                .add_pass<gopt::CondExecConstPredicateFolding>()
                .apply_inplace(y_opt_arr);
        SymbolVar y_opt = y_opt_arr[0];
        if (expect_change) {
            EXPECT_NE(y_opt.node(), y.node());
        } else {
            EXPECT_EQ(y_opt, y);
        }
        HostTensorND host_y;
        graph->options().graph_opt_level = 0;
        auto func = graph->compile({make_callback_copy(y_opt, host_y)});
        func->execute();
        return host_y;
    };
    for (size_t mode_num = 0;
         mode_num < opr::CondExecMerge::Param::MODE_NR_MEMBER; ++mode_num) {
        auto mode = static_cast<MergeMode>(mode_num);
        bool exact_one = (mode == MergeMode::EXACT_ONE ||
                          mode == MergeMode::EXACT_ONE_SAME_SHAPE);
        for (int pmask = 0; pmask < 4; ++pmask) {
            if (exact_one && (pmask & 1) + (pmask >> 1) != 1) {
                continue;
            }
            if (mode == MergeMode::SUM_COND_OUT && !pmask) {
                ASSERT_THROW(run(mode, 0b11, 0, false), GraphError);
                continue;
            }
            auto v0 = run(mode, 0b11, pmask, true);
            auto v1 = run(mode, 0b01, pmask, false);
            MGB_ASSERT_TENSOR_EQ(v0, v1);
        }
    }
}
#endif  // MGB_ENABLE_COND_EXEC
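
// RemoveRedundantTypeCvtPass collapses chains of TypeCvt oprs when the
// intermediate casts cannot change the final result, e.g. fp16 -> fp32 ->
// fp16 reduces to the first fp16 cast.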
TEST_PASS(RemoveRedundantTypeCvtPass, Basic) {
    auto x = mkvar("x", {2, 3, 3});
#if !MEGDNN_DISABLE_FLOAT16
    auto x_fp16 = opr::TypeCvt::make(x, dtype::Float16());
    auto x_fp16_fp32 = opr::TypeCvt::make(x_fp16, dtype::Float32());
    auto x_fp16_fp32_fp16 = opr::TypeCvt::make(x_fp16_fp32, dtype::Float16());
    check(x_fp16, x_fp16_fp32_fp16);
#endif
    auto x_i32 = opr::TypeCvt::make(x, dtype::Int32());
    auto x_i32_i16 = opr::TypeCvt::make(x_i32, dtype::Int16());
    auto x_i32_i16_i8 = opr::TypeCvt::make(x_i32_i16, dtype::Int8());
    auto x_i8 = opr::TypeCvt::make(x, dtype::Int8());
    check(x_i8, x_i32_i16_i8);
    auto x_q8 = opr::TypeCvt::make(x, dtype::QuantizedS8(0.1f));
    auto x_q8_fp32 = opr::TypeCvt::make(x_q8, dtype::Float32());
    auto x_q8_fp32_q8 = opr::TypeCvt::make(x_q8_fp32, dtype::QuantizedS8(0.1f));
    auto x_q8_fp32_q8_ = opr::TypeCvt::make(x_q8_fp32, dtype::QuantizedS8(2.f));
    auto x_q8_q8 = opr::TypeCvt::make(x_q8, dtype::QuantizedS8(2.f));
    check(x_q8, x_q8_fp32_q8);
    check(x_q8_q8, x_q8_fp32_q8_);
}

#if MGB_ENABLE_OPR_MM
#include "megbrain/opr/collective_comm.h"
#include "../../opr-mm/test/mock_client.h"
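
// PackAllReduceScanPass scans the graph for AllReduce (CollectiveComm) oprs
// fed by gradients and assigns each a pack hash; oprs sharing a hash (here,
// grads of the same target) become candidates to be packed together.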
TEST_PASS(PackAllReduceScanPass, Basic) {
    auto graph = ComputingGraph::make();
    graph->options().allreduce_pack_max_size = 5000;
    auto client = std::make_shared<test::MockGroupClient>();
    auto cn = CompNode::load("gpux");
    auto dev_x0 = std::make_shared<DeviceTensorND>(cn, TensorShape{3, 5});
    auto dev_x1 = std::make_shared<DeviceTensorND>(cn, TensorShape{4, 6});
    auto dev_y0 = std::make_shared<DeviceTensorND>(cn, TensorShape{1});
    auto dev_y1 = std::make_shared<DeviceTensorND>(cn, TensorShape{1});
    auto x0 = opr::SharedDeviceTensor::make(*graph, dev_x0);
    auto x1 = opr::VolatileSharedDeviceTensor::make(*graph, dev_x1);
    auto y0 = opr::SharedDeviceTensor::make(*graph, dev_y0);
    auto y1 = opr::VolatileSharedDeviceTensor::make(*graph, dev_y1);
    auto grad0 = opr::VirtualGrad::make(y0, x0);
    auto grad1 = opr::VirtualGrad::make(y0, x1);
    auto grad2 = opr::VirtualGrad::make(y1, x0);
    auto grad3 = opr::VirtualGrad::make(y1, x1);
    auto mode = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM;
    auto comm0 = opr::CollectiveComm::make({grad0}, graph.get(), "grad0", 2, 0,
                                           0, client, mode)[0];
    auto comm1 = opr::CollectiveComm::make({grad1}, graph.get(), "grad1", 2, 0,
                                           0, client, mode)[0];
    auto comm2 = opr::CollectiveComm::make({grad2}, graph.get(), "grad2", 2, 0,
                                           0, client, mode)[0];
    auto comm3 = opr::CollectiveComm::make({grad3}, graph.get(), "grad3", 2, 0,
                                           0, client, mode)[0];
    gopt::GraphOptimizer()
            .add_pass<gopt::PackAllReduceScanPass>()
            .apply({{comm0, comm1, comm2, comm3}});
    auto get_hash = [](const SymbolVar& symvar) {
        cg::OperatorNodeBase* opr = symvar.node()->owner_opr();
        return opr->cast_final_safe<opr::CollectiveComm>().pack_hash();
    };
    uint64_t hash0 = get_hash(comm0);
    uint64_t hash1 = get_hash(comm1);
    uint64_t hash2 = get_hash(comm2);
    uint64_t hash3 = get_hash(comm3);
    ASSERT_EQ(hash0, hash1);
    ASSERT_EQ(hash2, hash3);
    ASSERT_NE(hash0, hash2);
}
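
// collect_groups buckets AllReduce oprs by comp_node, dtype, group client and
// the extra pack hash; only oprs that agree on all of these share a group.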
TEST_PASS(PackAllReduceReplacePass, CollectGroups) {
    REQUIRE_GPU(2);
    auto cns = load_multiple_xpus(2);
    auto graph = ComputingGraph::make();
    graph->options().graph_opt_level = 2;
    auto cli0 = std::make_shared<test::MockGroupClient>("mock_addr0");
    auto cli1 = std::make_shared<test::MockGroupClient>("mock_addr1");
    using GroupInfo = gopt::PackAllReduceReplacePass::GroupInfo;
    ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>> group_info;
    ThinHashMap<uint64_t, cg::OprNodeArray> groups;
    auto add_opr = [&](const CompNode& cn, TensorShape shape, const DType& dt,
                       std::shared_ptr<test::MockGroupClient> client,
                       uint64_t extra_hash) {
        auto dev0 = std::make_shared<DeviceTensorND>(cn, shape, dt);
        auto wrt = opr::SharedDeviceTensor::make(*graph, dev0);
        auto dev1 = std::make_shared<DeviceTensorND>(cn, TensorShape{1}, dt);
        auto target = opr::SharedDeviceTensor::make(*graph, dev1);
        auto grad = opr::VirtualGrad::make(target, wrt);
        auto comm = opr::CollectiveComm::make(
                            {grad}, graph.get(), "key", 2, 0, 0, client,
                            opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM)[0]
                            .node()
                            ->owner_opr();
        comm->cast_final_safe<opr::CollectiveComm>().set_pack_hash(extra_hash);
        return gopt::PackAllReduceReplacePass::collect_groups(comm, group_info,
                                                              groups);
    };
    uint64_t hash0 = add_opr(cns[0], TensorShape{1, 3}, dtype::Float32{}, cli0, 1);
    uint64_t hash1 = add_opr(cns[0], TensorShape{2, 4}, dtype::Float32{}, cli0, 1);  // same
    uint64_t hash2 = add_opr(cns[1], TensorShape{3, 5}, dtype::Float32{}, cli0, 1);  // comp_node
    uint64_t hash3 = add_opr(cns[0], TensorShape{4, 6}, dtype::Float16{}, cli0, 1);  // dtype
    uint64_t hash4 = add_opr(cns[0], TensorShape{5, 7}, dtype::Float32{}, cli1, 1);  // client
    uint64_t hash5 = add_opr(cns[0], TensorShape{6, 8}, dtype::Float32{}, cli0, 2);  // extra_hash
    ASSERT_EQ(hash0, hash1);
    std::set<uint64_t> s;
    s.insert(hash0);
    s.insert(hash1);
    s.insert(hash2);
    s.insert(hash3);
    s.insert(hash4);
    s.insert(hash5);
    ASSERT_EQ(5, s.size());
    ASSERT_EQ(1, group_info.count(hash0));
    ASSERT_EQ(1, group_info.count(hash1));
    ASSERT_EQ(1, group_info.count(hash2));
    ASSERT_EQ(1, group_info.count(hash3));
    ASSERT_EQ(1, group_info.count(hash4));
    ASSERT_EQ(1, group_info.count(hash5));
    ASSERT_EQ(2, groups[hash0].size());
    ASSERT_EQ(2, groups[hash1].size());
    ASSERT_EQ(1, groups[hash2].size());
    ASSERT_EQ(1, groups[hash3].size());
    ASSERT_EQ(1, groups[hash4].size());
    ASSERT_EQ(1, groups[hash5].size());
}
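
// divide_packs greedily accumulates the oprs of each group into a pack,
// closing the pack once the accumulated input size reaches the byte limit
// (1000 bytes below) and starting a new one.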
TEST_PASS(PackAllReduceReplacePass, DividePacks) {
    auto cn = CompNode::load("gpux");
    auto graph = ComputingGraph::make();
    auto client = std::make_shared<test::MockGroupClient>();
    auto mode = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM;
    ThinHashMap<uint64_t, cg::OprNodeArray> groups;
    ThinHashMap<uint64_t, std::vector<cg::OprNodeArray>> packs;
    auto insert_opr = [&](size_t size) {
        auto dev = std::make_shared<DeviceTensorND>(
                cn, TensorShape{size / sizeof(float)});
        auto sd = opr::SharedDeviceTensor::make(*graph, dev);
        auto symvar = opr::CollectiveComm::make({sd}, graph.get(), "key", 2, 0,
                                                0, client, mode)[0];
        auto opr = symvar.node()->owner_opr();
        auto& comm = opr->cast_final_safe<opr::CollectiveComm>();
        comm.set_pack_hash(1);
        return opr;
    };
    auto pack_size = [&](cg::OprNodeArray& pack) {
        size_t sum = 0;
        for (size_t i = 0; i < pack.size(); i++) {
            auto var = pack[i]->input(0);
            sum += var->dtype().size(var->shape().total_nr_elems());
        }
        return sum;
    };
    groups[0].push_back(insert_opr(100));  // group0, pack0, size=1100
    groups[0].push_back(insert_opr(300));  // group0, pack0, size=1100
    groups[0].push_back(insert_opr(400));  // group0, pack0, size=1100
    groups[0].push_back(insert_opr(300));  // group0, pack0, size=1100
    groups[0].push_back(insert_opr(500));  // group0, pack1, size=800
    groups[0].push_back(insert_opr(200));  // group0, pack1, size=800
    groups[0].push_back(insert_opr(100));  // group0, pack1, size=800
    groups[1].push_back(insert_opr(100));  // group1, pack0, size=900
    groups[1].push_back(insert_opr(400));  // group1, pack0, size=900
    groups[1].push_back(insert_opr(300));  // group1, pack0, size=900
    groups[1].push_back(insert_opr(100));  // group1, pack0, size=900
    gopt::PackAllReduceReplacePass::divide_packs(groups, packs, 1000);
    ASSERT_EQ(2, packs.size());
    ASSERT_EQ(2, packs[0].size());
    ASSERT_EQ(4, packs[0][0].size());
    ASSERT_EQ(1100, pack_size(packs[0][0]));
    ASSERT_EQ(3, packs[0][1].size());
    ASSERT_EQ(800, pack_size(packs[0][1]));
    ASSERT_EQ(1, packs[1].size());
    ASSERT_EQ(4, packs[1][0].size());
    ASSERT_EQ(900, pack_size(packs[1][0]));
}
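
// insert_packed_oprs rewrites one pack: the inputs of the packed AllReduce
// oprs are flattened and concatenated, a single fused AllReduce runs on the
// concatenation, and its output is split and reshaped back; replace_map
// records the mapping from each original output to its packed replacement.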
TEST_PASS(PackAllReduceReplacePass, InsertPackedOprs) {
    auto cn = CompNode::load("gpux");
    auto graph = ComputingGraph::make();
    auto client = std::make_shared<test::MockGroupClient>();
    auto mode = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM;
    size_t nr_devices = 2;
    uint32_t rank = 0;
    uint32_t root = 0;
    using GroupInfo = gopt::PackAllReduceReplacePass::GroupInfo;
    ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>> group_info;
    ThinHashMap<uint64_t, cg::OprNodeArray> groups;
    auto insert_opr = [&](const TensorShape& shape) {
        auto dev = std::make_shared<DeviceTensorND>(cn, shape);
        auto sd = opr::SharedDeviceTensor::make(*graph, dev);
        auto symvar = opr::CollectiveComm::make(
                {sd}, graph.get(), "key", nr_devices, rank, root, client,
                mode)[0];
        auto opr = symvar.node()->owner_opr();
        auto& comm = opr->cast_final_safe<opr::CollectiveComm>();
        comm.set_pack_hash(1);
        gopt::PackAllReduceReplacePass::collect_groups(opr, group_info, groups);
        return symvar;
    };
    auto shape_x = TensorShape{100, 200};
    auto shape_y = TensorShape{200, 400};
    auto x = insert_opr(shape_x);
    auto y = insert_opr(shape_y);
    ASSERT_EQ(1, group_info.size());
    ASSERT_EQ(1, groups.size());
    auto info = group_info.begin()->second;
    auto pack = groups.begin()->second;
    size_t pack_id = 0;
    ThinHashMap<VarNode*, VarNode*> replace_map;
    gopt::PackAllReduceReplacePass::insert_packed_oprs(pack_id, pack, info,
                                                       replace_map, -1);
    auto grad_x = SymbolVar(x.node()->owner_opr()->input(0));
    auto grad_y = SymbolVar(y.node()->owner_opr()->input(0));
    auto concat = opr::Concat::make({grad_x.flatten(), grad_y.flatten()}, 0);
    std::string key = ssprintf("grad_pack_%zu", pack_id);
    auto allreduce = opr::CollectiveComm::make(
            {concat}, graph.get(), key, nr_devices, rank, root, client, mode)[0];
    std::vector<size_t> partition;
    partition.push_back(shape_x.total_nr_elems());
    partition.push_back(shape_y.total_nr_elems());
    auto splits = opr::Split::make(
            allreduce,
            opr::Split::Options::make_partition(allreduce, 0, partition));
    ASSERT_EQ(2, splits.size());
    auto dest_x = splits[0].reshape(shape_x);
    auto dest_y = splits[1].reshape(shape_y);
    ASSERT_EQ(2, replace_map.size());
    ASSERT_TRUE(replace_map.count(x.node()) > 0);
    ASSERT_EQ(replace_map.at(x.node()), dest_x.node());
    ASSERT_TRUE(replace_map.count(y.node()) > 0);
    ASSERT_EQ(replace_map.at(y.node()), dest_y.node());
}
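
// End-to-end check: with packing enabled, the packed AllReduce results must
// match the unpacked ones; two threads emulate the two ranks.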
TEST_PASS(PackAllReduceReplacePass, Equivalence) {
    REQUIRE_GPU(2);
    auto cns = load_multiple_xpus(2);
    auto client = std::make_shared<test::MockGroupClient>();
    auto build_graph = [&](uint32_t rank, std::shared_ptr<ComputingGraph> graph,
                           SymbolVarArray& array) {
        HostTensorGenerator<> gen;
        auto cn = cns[rank];
        auto host_x = gen({1, 1000});
        auto host_y = gen({1000, 1});
        auto dev_x = std::make_shared<DeviceTensorND>(cn);
        auto dev_y = std::make_shared<DeviceTensorND>(cn);
        dev_x->copy_from(*host_x).sync();
        dev_y->copy_from(*host_y).sync();
        auto x = opr::SharedDeviceTensor::make(*graph, dev_x);
        auto y = opr::VolatileSharedDeviceTensor::make(*graph, dev_y);
        auto loss = opr::MatrixMul::make(x, y).flatten();
        auto grad_x = opr::VirtualGrad::make(loss, x);
        auto grad_y = opr::VirtualGrad::make(loss, y);
        using Mode = opr::CollectiveComm::Param::Mode;
        bool is_root = (rank == 0);
        auto reduced_x = opr::CollectiveComm::make(
                                 {grad_x}, graph.get(), "x", 2, is_root, rank,
                                 client, Mode::ALL_REDUCE_SUM)[0] /
                         2;
        auto reduced_y = opr::CollectiveComm::make(
                                 {grad_y}, graph.get(), "y", 2, is_root, rank,
                                 client, Mode::ALL_REDUCE_SUM)[0] /
                         2;
        graph->options().allreduce_pack_max_size = 5000;
        graph->options().allreduce_pack_ignore_first = 0;
        auto dest_vars = gopt::GraphOptimizer{}
                                 .add_pass<gopt::PackAllReduceScanPass>()
                                 .add_pass<gopt::PackAllReduceReplacePass>()
                                 .apply({{reduced_x, reduced_y}})
                                 .endpoint_vars();
        array.emplace_back(reduced_x);
        array.emplace_back(reduced_y);
        array.emplace_back(dest_vars[0]);
        array.emplace_back(dest_vars[1]);
    };
    auto run = [&](uint32_t rank) {
        auto graph = ComputingGraph::make();
        SymbolVarArray array;
        build_graph(rank, graph, array);
        HostTensorND host_reduced_x, host_reduced_y, host_dest_0, host_dest_1;
        graph->options().allreduce_pack_max_size = 0;
        auto func = graph->compile(
                {make_callback_copy(array[0], host_reduced_x),
                 make_callback_copy(array[1], host_reduced_y),
                 make_callback_copy(array[2], host_dest_0),
                 make_callback_copy(array[3], host_dest_1)});
        func->execute();
        MGB_ASSERT_TENSOR_EQ(host_reduced_x, host_dest_0);
        MGB_ASSERT_TENSOR_EQ(host_reduced_y, host_dest_1);
    };
    std::thread t0(run, 0);
    std::thread t1(run, 1);
    t0.join();
    t1.join();
}
#endif  // MGB_ENABLE_OPR_MM

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build to choose from. To run GPU programs, make sure the machine has GPU hardware and the driver is properly installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit MegStudio.