You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

blas.cpp 26 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665
  1. /**
  2. * \file src/opr/impl/blas.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megbrain/opr/blas.h"
  12. #include "megbrain/common.h"
  13. #include "megbrain/comp_node_env.h"
  14. #include "megbrain/graph/grad_impl.h"
  15. #include "megbrain/opr/basic_arith_wrapper.h"
  16. #include "megbrain/opr/indexing.h"
  17. #include "megbrain/opr/tensor_gen.h"
  18. #include "megbrain/opr/tensor_manip.h"
  19. #include "megbrain/opr/search_policy/algo_chooser.h"
  20. #include "megbrain/opr/search_policy/profiler.h"
  21. #include "./internal/megdnn_opr_wrapper.inl"
  22. #include "./search_policy/workspace_need_limit_getter.inl"
  23. #include "megdnn/oprs/linalg.h"
  24. using namespace mgb;
  25. using namespace opr;
  26. namespace {
  27. int get_mask_from_matmul(const megdnn::param::MatrixMul& param) {
  28. return static_cast<int>(param.transposeA) +
  29. (static_cast<int>(param.transposeB) * 2);
  30. }
  31. }
/* ================= MatrixMul ================= */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(MatrixMul);

//! matrix_mul operator node: wraps the megdnn MatrixMul opr initialized with
//! \p param; \p policy controls algorithm selection at workspace-query time
MatrixMul::MatrixMul(VarNode* a, VarNode* b, const Param& param,
                     const ExecutionPolicy& policy,
                     const OperatorNodeConfig& config)
        : Super{a->owner_graph(), config, "matrix_mul", {a, b}} {
    init_megdnn_opr(*this, param);
    m_policy = policy;
    add_input({a, b});
}
  42. SymbolVar MatrixMul::make(SymbolVar a, SymbolVar b, const Param& param,
  43. const ExecutionPolicy& policy,
  44. const OperatorNodeConfig& config) {
  45. return a.insert_single_output_opr<MatrixMul>(a.node(), b.node(), param,
  46. policy, config);
  47. }
  48. void MatrixMul::init_output_dtype() {
  49. DType output_dtype = config().output_dtype();
  50. megdnn_opr()->deduce_dtype(input(0)->dtype(), input(1)->dtype(),
  51. output_dtype);
  52. output(0)->dtype(output_dtype);
  53. }
  54. bool MatrixMul::check_layout(const TensorLayout& layout, int transpose) {
  55. mgb_assert(layout.ndim == 2, "input to MatrixMul must be 2-dim; got %s",
  56. layout.to_string().c_str());
  57. return layout.stride[0 ^ transpose] >=
  58. static_cast<ptrdiff_t>(layout.shape[1 ^ transpose]) &&
  59. layout.stride[1 ^ transpose] == 1;
  60. }
  61. void MatrixMul::add_input_layout_constraint() {
  62. auto check = [](const TensorLayout& ly) {
  63. return check_layout(ly, 0) || check_layout(ly, 1);
  64. };
  65. input(0)->add_layout_constraint(check);
  66. input(1)->add_layout_constraint(check);
  67. }
//! Return the max workspace over all four (transposeA, transposeB)
//! combinations, since scn_do_execute() may flip the flags to adapt to the
//! actual input layouts. The execution policy chosen for each combination is
//! cached in m_cadidate_execution_policies (indexed by get_mask_from_matmul)
//! for reuse at execution time.
size_t MatrixMul::get_workspace_size_bytes(
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
    // we may change transpose param in the impl, so get the max possible
    // workspace by trying all cases
    // current implementation in megdnn guarantees that workspaces in different
    // cases are on the same order of magnitude
    auto mo = megdnn_opr();
    auto&& tparam = mo->param();
    size_t a, b, c, d;  // workspace sizes of the four transpose combinations
    mgb_assert(input_shapes.size() == 2 && output_shapes.size() == 1);
    TensorLayout i0(input_shapes[0], input(0)->dtype()),
            i1(input_shapes[1], input(1)->dtype()),
            out(output_shapes[0], output(0)->dtype());
    // reinterpret dst as its transpose and flip the matching param flag, so
    // the computation described stays the same
    auto transpose = [](TensorLayout& dst, bool& param) {
        std::swap(dst.shape[0], dst.shape[1]);
        dst.stride[0] = dst[1];
        param ^= 1;
    };
    MGB_TRY {
        a = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out},
                                                       megdnn_opr(), this);
        //! Here we just want to save the execution policy got from setup_algo,
        //! while changing the declaration of get_workspace_in_bytes may cause
        //! many changes.
        const_cast<MatrixMul*>(this)
                ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
                megdnn_opr()->execution_policy();
        transpose(i0, tparam.transposeA);
        b = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out},
                                                       megdnn_opr(), this);
        const_cast<MatrixMul*>(this)
                ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
                megdnn_opr()->execution_policy();
        transpose(i1, tparam.transposeB);
        c = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out},
                                                       megdnn_opr(), this);
        const_cast<MatrixMul*>(this)
                ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
                megdnn_opr()->execution_policy();
        transpose(i0, tparam.transposeA);
        d = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out},
                                                       megdnn_opr(), this);
        const_cast<MatrixMul*>(this)
                ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
                megdnn_opr()->execution_policy();
    }
    // restore the operator param even if setup_algo throws
    MGB_FINALLY({ tparam = this->param(); });
    return std::max(std::max(a, b), std::max(c, d));
}
//! run the megdnn matmul: inputs whose layout does not directly fit the
//! current transpose flags are reinterpreted as their transpose (dims and
//! strides swapped) with the flag flipped, the execution policy cached for
//! the resulting flag combination is installed, then exec; the param is
//! restored afterwards even on exception
void MatrixMul::scn_do_execute() {
    auto inp0 = input(0)->dev_tensor().as_megdnn(),
         inp1 = input(1)->dev_tensor().as_megdnn(),
         out = output(0)->dev_tensor().as_megdnn();
    // if layout is not directly usable, view it as its transpose and flip the
    // corresponding transpose flag (layout validity guaranteed by the input
    // layout constraint)
    auto transpose = [](TensorLayout& layout, bool& trans) {
        if (!check_layout(layout, 0)) {
            mgb_assert(check_layout(layout, 1));
            std::swap(layout.shape[0], layout.shape[1]);
            std::swap(layout.stride[0], layout.stride[1]);
            trans ^= 1;
        }
    };
    auto&& tparam = megdnn_opr()->param();
    MGB_TRY {
        transpose(inp0.layout, tparam.transposeA);
        transpose(inp1.layout, tparam.transposeB);
        megdnn_opr()->execution_policy() =
                m_cadidate_execution_policies[get_mask_from_matmul(tparam)];
        megdnn_opr()->exec(inp0, inp1, out,
                           intl::get_megdnn_workspace_from_var(output(1)));
    }
    // restore the original transpose flags
    MGB_FINALLY({ tparam = this->param(); });
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(MatrixMul) {
    // grad of C = op(A) * op(B): express A' (resp. B') as another matmul of
    // the output grad with the other operand, choosing transpose flags so no
    // explicit transposition node is needed
    mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT,
               "only float data type supported for grad");
    SymbolVar grad, i0{opr.input(0)}, i1{opr.input(1)}, og{out_grad[0]};
    if (wrt_idx == 0) {
        // A * B = C, A' = C' * Bt
        if (opr.param().transposeA) {
            grad = MatrixMul::make(i1, og, {opr.param().transposeB, true});
        } else {
            grad = MatrixMul::make(og, i1, {false, !opr.param().transposeB});
        }
    } else {
        mgb_assert(wrt_idx == 1);
        // A * B = C, B' = At * C'
        if (opr.param().transposeB) {
            grad = MatrixMul::make(og, i0, {true, opr.param().transposeA});
        } else {
            grad = MatrixMul::make(i0, og, {!opr.param().transposeA, false});
        }
    }
    return grad.node();
}
#endif
/* ================= BatchedMatrixMul ================= */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(BatchedMatrixMul);

//! batched_matrix_mul operator node: same setup as MatrixMul, but operands
//! are 3-dim (batch, rows, cols) tensors
BatchedMatrixMul::BatchedMatrixMul(VarNode* a, VarNode* b, const Param& param,
                                   const ExecutionPolicy& policy,
                                   const OperatorNodeConfig& config)
        : Super{a->owner_graph(), config, "batched_matrix_mul", {a, b}} {
    init_megdnn_opr(*this, param);
    m_policy = policy;
    add_input({a, b});
}
  175. SymbolVar BatchedMatrixMul::make(SymbolVar a, SymbolVar b, const Param& param,
  176. const ExecutionPolicy& policy,
  177. const OperatorNodeConfig& config) {
  178. return a.insert_single_output_opr<BatchedMatrixMul>(a.node(), b.node(),
  179. param, policy, config);
  180. }
  181. void BatchedMatrixMul::add_input_layout_constraint() {
  182. auto check = [](const TensorLayout& ly) {
  183. mgb_assert(ly.ndim == 3,
  184. "input to BatchedMatrixMul must be 3-dim; got %s",
  185. ly.to_string().c_str());
  186. bool good_layout =
  187. ((ly.stride[0] >=
  188. static_cast<ptrdiff_t>(ly.shape[1] * ly.stride[1])) &&
  189. (ly.stride[0] >=
  190. static_cast<ptrdiff_t>(ly.shape[2] * ly.stride[2])));
  191. bool ret = good_layout &&
  192. (check_layout(ly, true) || check_layout(ly, false));
  193. return ret;
  194. };
  195. input(0)->add_layout_constraint(check);
  196. input(1)->add_layout_constraint(check);
  197. }
  198. void BatchedMatrixMul::init_output_dtype() {
  199. DType output_dtype = config().output_dtype();
  200. megdnn_opr()->deduce_dtype(input(0)->dtype(), input(1)->dtype(),
  201. output_dtype);
  202. output(0)->dtype(output_dtype);
  203. }
  204. bool BatchedMatrixMul::check_layout(const TensorLayout& layout,
  205. bool transpose) {
  206. int lhs = (transpose) ? 2 : 1, rhs = (transpose) ? 1 : 2;
  207. return (layout.stride[lhs] >= static_cast<ptrdiff_t>(layout.shape[rhs])) &&
  208. (layout.stride[rhs] == 1);
  209. }
//! same scheme as MatrixMul::get_workspace_size_bytes: take the max workspace
//! over all four (transposeA, transposeB) combinations, caching the execution
//! policy chosen for each combination for reuse in scn_do_execute()
size_t BatchedMatrixMul::get_workspace_size_bytes(
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
    // we may change transpose param in the impl, so get the max possible
    // workspace by trying all cases
    // current implementation in megdnn guarantees that workspaces in different
    // cases are on the same order of magnitude
    auto mo = megdnn_opr();
    auto&& tparam = mo->param();
    size_t a, b, c, d;  // workspace sizes of the four transpose combinations
    mgb_assert(input_shapes.size() == 2 && output_shapes.size() == 1);
    TensorLayout i0(input_shapes[0], input(0)->dtype()),
            i1(input_shapes[1], input(1)->dtype()),
            out(output_shapes[0], output(0)->dtype());
    // swap the two matrix dims (batch dim 0 untouched) and flip the flag
    auto transpose = [](TensorLayout& dst, bool& param) {
        std::swap(dst.shape[1], dst.shape[2]);
        dst.stride[1] = dst[2];
        param ^= 1;
    };
    MGB_TRY {
        a = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo(
                {i0, i1, out}, megdnn_opr(), this);
        //! cache the policy chosen by setup_algo for this transpose mask
        const_cast<BatchedMatrixMul*>(this)
                ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
                megdnn_opr()->execution_policy();
        transpose(i0, tparam.transposeA);
        b = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo(
                {i0, i1, out}, megdnn_opr(), this);
        const_cast<BatchedMatrixMul*>(this)
                ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
                megdnn_opr()->execution_policy();
        transpose(i1, tparam.transposeB);
        c = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo(
                {i0, i1, out}, megdnn_opr(), this);
        const_cast<BatchedMatrixMul*>(this)
                ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
                megdnn_opr()->execution_policy();
        transpose(i0, tparam.transposeA);
        d = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo(
                {i0, i1, out}, megdnn_opr(), this);
        const_cast<BatchedMatrixMul*>(this)
                ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
                megdnn_opr()->execution_policy();
    }
    // restore the operator param even if setup_algo throws
    MGB_FINALLY({ tparam = this->param(); });
    return std::max(std::max(a, b), std::max(c, d));
}
//! batched analogue of MatrixMul::scn_do_execute: fix up layouts by swapping
//! the two matrix dims (dims 1 and 2) and flipping the transpose flags,
//! install the cached execution policy, exec, then restore the param
void BatchedMatrixMul::scn_do_execute() {
    auto inp0 = input(0)->dev_tensor().as_megdnn(),
         inp1 = input(1)->dev_tensor().as_megdnn(),
         out = output(0)->dev_tensor().as_megdnn();
    auto transpose = [](TensorLayout& layout, bool& trans) {
        if (!check_layout(layout, false)) {
            mgb_assert(check_layout(layout, true));
            std::swap(layout.shape[1], layout.shape[2]);
            std::swap(layout.stride[1], layout.stride[2]);
            // after the swap the minor dim must be contiguous
            mgb_assert(layout.stride[2] == 1);
            trans ^= 1;
        }
    };
    auto&& tparam = megdnn_opr()->param();
    MGB_TRY {
        transpose(inp0.layout, tparam.transposeA);
        transpose(inp1.layout, tparam.transposeB);
        megdnn_opr()->execution_policy() =
                m_cadidate_execution_policies[get_mask_from_matmul(tparam)];
        megdnn_opr()->exec(inp0, inp1, out,
                           intl::get_megdnn_workspace_from_var(output(1)));
    }
    // restore the original transpose flags
    MGB_FINALLY({ tparam = this->param(); });
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(BatchedMatrixMul) {
    // same formulas as MatrixMul grad, applied per batch
    mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT,
               "only float data type supported for grad");
    // output(1) is the workspace var; it must receive no gradient
    mgb_assert(out_grad.size() == 2 && !out_grad[1]);
    SymbolVar grad, i0{opr.input(0)}, i1{opr.input(1)}, og{out_grad[0]};
    if (wrt_idx == 0) {
        // A * B = C, A' = C' * Bt
        if (opr.param().transposeA) {
            grad = BatchedMatrixMul::make(
                    i1, og, {opr.param().transposeB, true});
        } else {
            grad = BatchedMatrixMul::make(
                    og, i1, {false, !opr.param().transposeB});
        }
    } else {
        mgb_assert(wrt_idx == 1);
        // A * B = C, B' = At * C'
        if (opr.param().transposeB) {
            grad = BatchedMatrixMul::make(
                    og, i0, {true, opr.param().transposeA});
        } else {
            grad = BatchedMatrixMul::make(
                    i0, og, {!opr.param().transposeA, false});
        }
    }
    return grad.node();
}
#endif
/* ================= Dot ================= */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(Dot);

//! dot product of two 1-dim vectors; quantized dtypes are rejected.
//! Inputs are added with AddInputSortType::CUR_ADDED — presumably so that
//! operand order does not matter for operator deduplication (dot is
//! commutative)
Dot::Dot(VarNode* opr0, VarNode* opr1, const OperatorNodeConfig& config)
        : Super{opr0->owner_graph(), config, "dot", {opr0, opr1}} {
    init_megdnn_opr(*this, {});
    add_input({opr0, opr1}, AddInputSortType::CUR_ADDED);
    static_assert(std::is_empty<Param>::value, "Dot param should be empty");
    mgb_assert(opr0->dtype().category() != DTypeCategory::QUANTIZED &&
                       opr1->dtype().category() != DTypeCategory::QUANTIZED,
               "Dot does not support quantized input.");
}
//! output(0) is a constant {1} scalar; output(1) (the workspace) depends on
//! the input shapes: the layout passed to get_workspace_in_bytes uses the
//! larger of the two element counts, matching the scalar-broadcast handling
//! in scn_do_execute()
void Dot::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto&& mgr = owner_graph()->static_infer_manager();
    // the dot result is always a single-element tensor
    auto infer_shp = [](TensorShape& dest, const InpVal&) {
        dest = {1};
        return true;
    };
    auto infer_workspace = [this](TensorShape& dest, const InpVal& iv) {
        auto dtype = input(0)->dtype();
        TensorLayout ily(
                {std::max(iv.val[0].shape().total_nr_elems(),
                          iv.val[1].shape().total_nr_elems())},
                dtype);
        dest.ndim = 1;
        dest.shape[0] = megdnn_opr()->get_workspace_in_bytes(ily, ily,
                                                             {{1}, dtype});
        return true;
    };
    mgr.register_shape_infer(output(0), {SourceType::CONSTANT, {}, infer_shp});
    mgr.register_shape_infer(output(1),
                             {SourceType::DEP,
                              {{input(0), DepType::SHAPE},
                               {input(1), DepType::SHAPE}},
                              infer_workspace});
}
  347. void Dot::scn_do_execute() {
  348. auto i0 = input(0)->dev_tensor().as_megdnn(),
  349. i1 = input(1)->dev_tensor().as_megdnn();
  350. mgb_throw_if(i0.layout.ndim != 1 || i1.layout.ndim != 1, GraphError,
  351. "Invalid input shapes for Dot: %s",
  352. cg::dump_var_info(input()).c_str());
  353. if (i0.layout.shape[0] != i1.layout.shape[0]) {
  354. bool s0 = i0.layout.shape[0] == 1, s1 = i1.layout.shape[0] == 1;
  355. mgb_throw_if(!s0 && !s1, GraphError,
  356. "Invalid input shapes for Dot: %s",
  357. cg::dump_var_info(input()).c_str());
  358. if (s0) {
  359. i0.layout.shape[0] = i1.layout.shape[0];
  360. i0.layout.stride[0] = 0;
  361. }
  362. else {
  363. i1.layout.shape[0] = i0.layout.shape[0];
  364. i1.layout.stride[0] = 0;
  365. }
  366. }
  367. megdnn_opr()->exec(i0, i1, output(0)->dev_tensor().as_megdnn(),
  368. intl::get_megdnn_workspace_from_var(output(1)));
  369. }
  370. void Dot::add_input_layout_constraint() {
  371. auto check = [](const TensorLayout &ly) {
  372. mgb_throw_if(ly.ndim != 1, GraphError,
  373. "Dot input must be 1-dim; got %s", ly.to_string().c_str());
  374. return ly.stride[0] >= 0;
  375. };
  376. input(0)->add_layout_constraint(check);
  377. input(1)->add_layout_constraint(check);
  378. }
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Dot) {
    // d(dot(a, b)) wrt one input is og * other_input, broadcast to the larger
    // input shape and then reduce-summed back to the wrt input's own shape
    // (this handles the broadcast 1-element-vector case)
    auto other_input = opr.input(wrt_idx == 0 ? 1 : 0);
    auto ishp0 = opr::GetVarShape::make(opr.input(0)),
         ishp1 = opr::GetVarShape::make(opr.input(1));
    // GetVarShape on both inputs yields the broadcast (larger) shape
    auto max_ishp = opr::GetVarShape::make({opr.input(0), opr.input(1)});
    return reduce_sum(
                   Broadcast::make(mul(out_grad[0], other_input), max_ishp),
                   wrt_idx ? ishp1 : ishp0)
            .node();
}
#endif
  390. SymbolVar Dot::make(SymbolVar opr0, SymbolVar opr1,
  391. const OperatorNodeConfig &config) {
  392. return opr0.insert_single_output_opr<Dot>(opr0.node(), opr1.node(), config);
  393. }
void Dot::record_execute_deps(ExecDependencyArray& deps) {
    // register the wrapped megdnn operator as an execution dependency
    record_megdnn_opr(deps);
}
/* ================= MatrixInverse ================= */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(MatrixInverse);
MEGDNN_OPR_INIT1(MatrixInverse, "matrix_inv")

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(MatrixInverse) {
    SymbolVar a = opr.output(0);
    // TODO: use unified MatrixMul interface when we have it
    // reshape the inverse output (and the incoming grad) into a batch of
    // n x n matrices, then compute dX = -(A^-1)^T * dY * (A^-1)^T via two
    // batched matmuls; a_bnn below is the batched transpose of A^-1
    auto n = opr::Subtensor::make(
                 a.symshape(),
                 {opr::Subtensor::AxisIndexer::make_index(0,
                                                          a.make_scalar(-1))}),
         tshp = opr::Concat::make({a.make_scalar(0), n, n}, 0),
         // our hard disk is limited so derivation of the gradient is omitted:)
         a_bnn = opr::Dimshuffle::make(opr::Reshape::make(a, tshp, 0),
                                       {0, 2, 1}),
         dy = opr::Reshape::make(out_grad.at(0), tshp, 0),
         da = -BatchedMatrixMul::make(BatchedMatrixMul::make(a_bnn, dy),
                                      a_bnn);
    return da.reshape(a.symshape()).node();
}
#endif
  416. /* ================= SVD ================= */
  417. MGB_DYN_TYPE_OBJ_FINAL_IMPL(SVD);
  418. SVD::SVD(VarNode* src, const Param& param, const OperatorNodeConfig& config) :
  419. Super(OperatorNodeBaseCtorParam{src->owner_graph(),
  420. config, "svd", {src}}) {
  421. mgb_assert(src->dtype() == megdnn::dtype::Float32(),
  422. "Singular Value Decomposition on non-float32 tensors is "
  423. "not supoorted.");
  424. init_megdnn_opr(*this, param);
  425. add_input({src});
  426. if (!param.compute_uv) {
  427. output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
  428. .add_flag(VarNode::Flag::VOLATILE_CONTENT);
  429. output(2)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
  430. .add_flag(VarNode::Flag::VOLATILE_CONTENT);
  431. }
  432. }
#if MGB_ENABLE_GRAD
namespace {

/*!
 * \brief a wrapper similar to SymbolVar but can safely contain nullptr as zero
 *
 * Note: here we introduce a new class of SymbolVar representation, which allows
 * nullptr to represent zero values, and overload other C++ operators
 * accordingly. Therefore we can avoid testing nullptr values everywhere in SVD
 * grad.
 *
 * This is a general approach. It can be moved to some header file if we
 * encounter another operator that also has complex gradient computation.
 */
class SafeSymbolVar {
    VarNode* m_node;

public:
    explicit SafeSymbolVar(VarNode* node) : m_node{node} {}
    SafeSymbolVar(SymbolVar x) : m_node{x.node()} {}
    SafeSymbolVar() : m_node{nullptr} {}

    VarNode* node() const { return m_node; }

    //! view as a plain SymbolVar (the underlying node may be null)
    SymbolVar s() const { return m_node; }

// forward a SymbolVar method, propagating null (i.e. symbolic zero) unchanged
#define FWD(name)                                                   \
    template <typename... Args>                                     \
    SafeSymbolVar name(Args&&... args) {                            \
        if (!m_node)                                                \
            return {};                                              \
        return SymbolVar{m_node}.name(std::forward<Args>(args)...); \
    }
    FWD(reshape)
    FWD(broadcast)
#undef FWD
};

//! unwrap to SymbolVar; identity overload so templates below work uniformly
SymbolVar unsafe(SymbolVar x) {
    return x;
}
SymbolVar unsafe(SafeSymbolVar x) {
    return x.s();
}

//! reshape with unspecified axis 0 (the batch dim); null propagates as zero
template <typename T>
T reshape_anybatch(T x, SymbolVar tshp) {
    if (!x.node())
        return x;
    return opr::Reshape::make(unsafe(x), tshp, 0);
}

//! batched matrix transpose (swap dims 1 and 2); null propagates as zero
template <typename T>
T trans(T x) {
    if (!x.node())
        return x;
    return opr::Dimshuffle::make(unsafe(x), {0, 2, 1});
}

//! batched matmul; zero (null) times anything is zero (null)
template <typename T>
T matmul(T a, T b, const opr::BatchedMatrixMul::Param& param = {}) {
    if (!a.node() || !b.node())
        return {};
    return opr::BatchedMatrixMul::make(unsafe(a), unsafe(b), param);
}

//! matmul with both operands forced to SafeSymbolVar
SafeSymbolVar matmuls(SafeSymbolVar x, SafeSymbolVar y,
                      const opr::BatchedMatrixMul::Param& param = {}) {
    return matmul(x, y, param);
}

SafeSymbolVar operator-(SafeSymbolVar x) {
    if (x.node())
        return -x.s();
    return {};
}

// binary arithmetic treating null as zero: a_ / b_ are the results when
// a / b is null, respectively
#define OP(x, a_, b_)                                            \
    SafeSymbolVar operator x(SafeSymbolVar a, SafeSymbolVar b) { \
        if (!a.node())                                           \
            return a_;                                           \
        if (!b.node())                                           \
            return b_;                                           \
        return a.s() x b.s();                                    \
    }
OP(+, b, a)
OP(-, -b, a)
OP(*, {}, {})
#undef OP

}  // anonymous namespace
#endif
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(SVD) {
    /**
     * The formula is copied from
     * https://j-towns.github.io/papers/svd-derivative.pdf
     * It is hard to compare m, n here, so I do not refer this paper :
     * http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf
     */
    mgb_throw_if(!opr.param().compute_uv, MegBrainError,
                 "Singular value decomposition gradient computation depends "
                 "on U and V, please set compute_uv = True");
    SymbolVar a{opr.input(0)}, u_raw{opr.output(0)}, s_raw{opr.output(1)},
            vt_raw{opr.output(2)};
    // out_grad entries may be null (zero grad); SafeSymbolVar propagates that
    SafeSymbolVar grad_u_raw{out_grad[0]}, grad_s_raw{out_grad[1]},
            grad_vt_raw{out_grad[2]};
    // transpose-flag combinations for the batched matmuls below
    auto param10 = BatchedMatrixMul::Param{true, false},
         param00 = BatchedMatrixMul::Param{false, false},
         param01 = BatchedMatrixMul::Param{false, true};
    // m, n: last two dims of the input; r: number of singular values
    auto n = opr::Subtensor::make(a.symshape(),
                                  {opr::Subtensor::AxisIndexer::make_index(
                                          0, a.make_scalar(-1))}),
         m = opr::Subtensor::make(a.symshape(),
                                  {opr::Subtensor::AxisIndexer::make_index(
                                          0, a.make_scalar(-2))}),
         r = opr::Subtensor::make(s_raw.symshape(),
                                  {opr::Subtensor::AxisIndexer::make_index(
                                          0, s_raw.make_scalar(-1))});
    // reshape all operands into batched 3-dim tensors
    SymbolVar sshp = opr::Concat::make({a.make_scalar(0), r}, 0),
              ushp = opr::Concat::make({a.make_scalar(0), m, r}, 0),
              vtshp = opr::Concat::make({a.make_scalar(0), r, n}, 0),
              u = reshape_anybatch(u_raw, ushp),
              vt = reshape_anybatch(vt_raw, vtshp), v = trans(vt);
    SafeSymbolVar grad_u = reshape_anybatch(grad_u_raw, ushp),
            grad_vt = reshape_anybatch(grad_vt_raw, vtshp),
            grad_v = trans(grad_vt);
    auto batches = opr::Subtensor::make(
            u.symshape(),
            {opr::Subtensor::AxisIndexer::make_index(0, u.make_scalar(-3))});
    auto brr = opr::Concat::make({batches, r, r}, 0);
    // I_r: batched r x r identity; filter_matrix: ones with zero diagonal
    auto I_r = opr::Eye::make(r, {0, DTypeEnum::Float32})
                       .reshape(opr::Concat::make({a.make_scalar(1), r, r}, 0))
                       .broadcast(brr),
         filter_matrix = 1 - I_r;
    // sf / grad_sf: singular values (and their grads) broadcast along
    // columns to r x r
    auto sf = reshape_anybatch(s_raw, sshp)
                      .reshape(opr::Concat::make(
                              {batches, r, a.make_scalar(1)}, 0))
                      .broadcast(brr);
    auto grad_sf = reshape_anybatch(grad_s_raw, sshp)
                           .reshape(opr::Concat::make(
                                   {batches, r, a.make_scalar(1)}, 0))
                           .broadcast(brr);
    auto s = I_r * sf;          // diag(s)
    auto grad_s = I_r * grad_sf;
    // diag(1/s): adding filter_matrix avoids division by zero off-diagonal
    auto s_inv = 1 / (s + filter_matrix) - filter_matrix;
    // F[i][j] = 1 / (s[j]^2 - s[i]^2) off the diagonal, 0 on it; adding I_r
    // to the denominator avoids NaN on the (filtered-out) diagonal
    auto s_rhs = sf * sf, s_mid = trans(s_rhs) - s_rhs,
         s_avoid_nan = s_mid + I_r, f = filter_matrix / s_avoid_nan;
    auto I_m = opr::Eye::make(m, {0, DTypeEnum::Float32})
                       .reshape(opr::Concat::make({a.make_scalar(1), m, m}, 0))
                       .broadcast(opr::Concat::make({batches, m, m}, 0)),
         I_n = opr::Eye::make(n, {0, DTypeEnum::Float32})
                       .reshape(opr::Concat::make({a.make_scalar(1), n, n}, 0))
                       .broadcast(opr::Concat::make({batches, n, n}, 0));
    auto ut_du = matmuls(u, grad_u, param10),
         vt_dv = matmuls(v, grad_v, param10);
    // assemble dA from the U-, S- and V-grad contributions (see the paper
    // linked above); null (zero) grads vanish automatically
    auto ret =
            matmuls(matmuls(matmuls(u, f * (ut_du - trans(ut_du))), s,
                            param00) +
                            matmuls(matmuls(I_m - matmul(u, u, param01),
                                            grad_u),
                                    s_inv),
                    v, param01) +
            matmuls(matmuls(u, I_r * grad_s), v, param01) +
            matmuls(u, matmuls(matmuls(s, f * (vt_dv - trans(vt_dv)), param00),
                               v, param01) +
                               matmuls(matmuls(s_inv, grad_v, param01),
                                       I_n - matmul(v, v, param01)));
    return ret.reshape(a.symshape()).node();
}
#endif
  591. SymbolVarArray SVD::make(const SymbolVar& src, const Param& param,
  592. const OperatorNodeConfig& config) {
  593. auto&& out = src.node()
  594. ->owner_graph()
  595. ->insert_opr(std::make_unique<SVD>(src.node(), param,
  596. config))
  597. ->output();
  598. SymbolVarArray ret(out.size());
  599. for (size_t i = 0; i < ret.size(); i++) {
  600. ret[i] = out[i];
  601. }
  602. return ret;
  603. }
  604. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台