You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

grad_override.cpp 28 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710
  1. #include "./grad.h"
  2. #include "megbrain/imperative/ops/autogen.h"
  3. #include "megbrain/imperative/transformations/grad.h"
  4. namespace mgb::imperative::python {
  5. class CustomGradMaker {
  6. bool output_size_set = false, input_has_grad_initialized = false;
  7. CustomBackward& target;
  8. size_t nr_inputs;
  9. void init_input_has_grad() {
  10. if (!input_has_grad_initialized) {
  11. input_has_grad_initialized = true;
  12. target.m_input_has_grad.resize(nr_inputs, true);
  13. }
  14. }
  15. public:
  16. CustomGradMaker(CustomBackward& target, size_t nr_inputs)
  17. : target(target), nr_inputs(nr_inputs) {}
  18. CustomGradMaker& backward(CustomBackward::BackwardFn f) {
  19. mgb_assert(!target.m_backward);
  20. target.m_backward = f;
  21. return *this;
  22. }
  23. // mandatory
  24. CustomGradMaker& output_size(size_t sz) {
  25. mgb_assert(!output_size_set);
  26. output_size_set = true;
  27. target.m_output_attrs.resize(sz);
  28. return *this;
  29. }
  30. // optional, defaults to all true
  31. CustomGradMaker& input_has_grad(size_t i, bool v) {
  32. init_input_has_grad();
  33. target.m_input_has_grad.at(i) = v;
  34. return *this;
  35. }
  36. // optional, defaults to all true
  37. CustomGradMaker& output_requires_grad(size_t i, bool v) {
  38. target.m_output_attrs.at(i).requires_grad = v;
  39. return *this;
  40. }
  41. // optional, defaults to all true
  42. CustomGradMaker& output_captured(size_t i, bool v) {
  43. target.m_output_attrs.at(i).captured = v;
  44. return *this;
  45. }
  46. void finalize() {
  47. mgb_assert(output_size_set);
  48. init_input_has_grad();
  49. }
  50. };
  51. namespace {
  52. ValueRef get_shape(ValueRef x) {
  53. static auto op = GetVarShape::make();
  54. return imperative::apply(*op, x)[0];
  55. }
  56. ValueRef reduce_to(ValueRef x, ValueRef s) {
  57. static auto op = Reduce::make();
  58. return imperative::apply(*op, x, s)[0];
  59. }
  60. ValueRef reshape_to(ValueRef x, ValueRef s) {
  61. static auto op = Reshape::make();
  62. return imperative::apply(*op, x, s)[0];
  63. }
  64. ValueRef broadcast_to(ValueRef x, ValueRef s) {
  65. static auto op = Broadcast::make();
  66. return imperative::apply(*op, x, s)[0];
  67. }
  68. ValueRef make_empty_tensor(
  69. CompNodeValue::ref_t device, ValueRef shape, DTypeValue::ref_t dtype) {
  70. HostTensorStorage storage(*device);
  71. storage.ensure_size(dtype->size());
  72. std::memset(storage.ptr(), 0, dtype->size());
  73. auto t = imperative::apply(
  74. CreateTensor(CreateTensor::Unique, *device, *dtype, ValueShape()),
  75. HostStorage::make(storage))[0];
  76. auto res = broadcast_to(t, shape);
  77. return res;
  78. }
// Custom gradient rule for MatrixMul.
// inputs: [A, B]. The gradient of each operand is itself a MatrixMul
// of the incoming gradient with the *other* operand, which is why
// operand i is captured when the opposite input (i ^ 1) requires grad.
std::optional<ValueRefList> matrix_mul_grad_rule(
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    auto&& matmul = op.cast_final_safe<MatrixMul>();
    size_t dimA = matmul.dimA;
    size_t dimB = matmul.dimB;
    auto&& param = matmul.param();
    auto&& policy = matmul.policy();
    mgb_assert(inputs.size() == 2);
    std::array<ValueRef, 2> inps, input_shapes;
    for (size_t i = 0; i < 2; ++i) {
        // grad w.r.t. input (i ^ 1) needs the value of input i
        if (inputs_require_grad[i ^ 1]) {
            inps[i] = inputs[i];
            input_shapes[i] = get_shape(inputs[i]);
        }
    }
    auto maker = CustomGradMaker(backward, inputs.size());
    maker.output_size(1).output_captured(0, false);
    // NOTE(review): input_shapes_ is captured but never used below (the
    // batched variant uses it for reduce_to); presumably kept for
    // symmetry with batched_matrix_mul_grad_rule -- confirm before removing.
    maker.backward([inps_ = std::move(inps), input_shapes_ = std::move(input_shapes),
                    param, policy, dimA, dimB](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
        SmallVector<ValueRef> ret(2);
        if (!grad) {
            return ret;
        }
        // output rank is the larger of the two operand ranks
        size_t dimG = std::max(dimA, dimB);
        if (inps_[1]) {  // grad w.r.t. A; needs the value of B
            if (param.transposeA) {
                // A^T(2) @ B(2) = G(2), A'(2) = B'(2) @ G'^T(2) -> MatrixMul
                auto&& grad_op = MatrixMul::make(
                        param.transposeB, true, param.compute_mode, param.format,
                        policy.strategy, policy.workspace_limit, dimB, dimG);
                ret[0] = imperative::apply(*grad_op, inps_[1], grad)[0];
            } else {
                // A(>=2) @ B(2) = G(>=2), A'(>=2) = G'(>=2) @ B(2) -> MatrixMul
                auto&& grad_op = MatrixMul::make(
                        false, !param.transposeB, param.compute_mode, param.format,
                        policy.strategy, policy.workspace_limit, dimG, dimB);
                ret[0] = imperative::apply(*grad_op, grad, inps_[1])[0];
            }
        }
        if (inps_[0]) {  // grad w.r.t. B; needs the value of A
            if (param.transposeB) {
                // A(>=2) @ B^T(2) = G(>=2), B'(2) = G'^T(>=2) @ A(>=2) -> MatrixMul
                // (specialized)
                auto&& grad_op = MatrixMul::make(
                        true, param.transposeA, param.compute_mode, param.format,
                        policy.strategy, policy.workspace_limit, dimG, dimA);
                ret[1] = imperative::apply(*grad_op, grad, inps_[0])[0];
            } else {
                // A(>=2) @ B(2) = G(>=2), B'(2) = G'(>=2) @ A(>=2) -> MatrixMul
                // (specialized)
                auto&& grad_op = MatrixMul::make(
                        !param.transposeA, false, param.compute_mode, param.format,
                        policy.strategy, policy.workspace_limit, dimA, dimG);
                ret[1] = imperative::apply(*grad_op, inps_[0], grad)[0];
            }
        }
        return ret;
    });
    maker.finalize();
    // run the forward op unchanged; the backward closure is now registered
    return imperative::apply(ApplyOp(op), inputs);
}
  143. std::optional<ValueRefList> batched_matrix_mul_grad_rule(
  144. const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
  145. CustomBackward& backward) {
  146. auto&& bmm = op.cast_final_safe<BatchedMatrixMul>();
  147. size_t dimA = bmm.dimA;
  148. size_t dimB = bmm.dimB;
  149. auto&& param = bmm.param();
  150. auto&& policy = bmm.policy();
  151. mgb_assert(inputs.size() == 2);
  152. std::array<ValueRef, 2> inps, input_shapes;
  153. for (size_t i = 0; i < 2; ++i) {
  154. if (inputs_require_grad[i ^ 1]) {
  155. inps[i] = inputs[i];
  156. input_shapes[i] = get_shape(inputs[i]);
  157. }
  158. }
  159. auto maker = CustomGradMaker(backward, inputs.size());
  160. maker.output_size(1).output_captured(0, false);
  161. maker.backward([inps_ = std::move(inps), input_shapes_ = std::move(input_shapes),
  162. param, policy, dimA, dimB](Span<ValueRef> grads) {
  163. mgb_assert(grads.size() == 1);
  164. ValueRef grad = grads[0];
  165. SmallVector<ValueRef> ret(2);
  166. if (!grad) {
  167. return ret;
  168. }
  169. size_t dimG = std::max(dimA, dimB);
  170. if (inps_[1]) {
  171. if (param.transposeA) {
  172. auto&& grad_op = BatchedMatrixMul::make(
  173. param.transposeB, true, param.compute_mode, param.format,
  174. policy.strategy, policy.workspace_limit, dimB, dimG);
  175. ret[0] = imperative::apply(*grad_op, inps_[1], grad)[0];
  176. } else {
  177. auto&& grad_op = BatchedMatrixMul::make(
  178. false, !param.transposeB, param.compute_mode, param.format,
  179. policy.strategy, policy.workspace_limit, dimG, dimB);
  180. ret[0] = imperative::apply(*grad_op, grad, inps_[1])[0];
  181. }
  182. if (dimG != dimA) {
  183. ret[0] = reduce_to(ret[0], input_shapes_[0]);
  184. }
  185. }
  186. if (inps_[0]) {
  187. if (param.transposeB) {
  188. auto&& grad_op = BatchedMatrixMul::make(
  189. true, param.transposeA, param.compute_mode, param.format,
  190. policy.strategy, policy.workspace_limit, dimG, dimA);
  191. ret[1] = imperative::apply(*grad_op, grad, inps_[0])[0];
  192. } else {
  193. auto&& grad_op = BatchedMatrixMul::make(
  194. !param.transposeA, false, param.compute_mode, param.format,
  195. policy.strategy, policy.workspace_limit, dimA, dimG);
  196. ret[1] = imperative::apply(*grad_op, inps_[0], grad)[0];
  197. }
  198. if (dimG != dimB) {
  199. ret[1] = reduce_to(ret[1], input_shapes_[1]);
  200. }
  201. }
  202. return ret;
  203. });
  204. maker.finalize();
  205. return imperative::apply(ApplyOp(op), inputs);
  206. }
  207. std::optional<ValueRefList> elemwise_grad_rule(
  208. const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
  209. CustomBackward& backward) {
  210. auto& elemwise = op.cast_final_safe<Elemwise>();
  211. if (elemwise.mode != Elemwise::Mode::ADD) {
  212. return {};
  213. }
  214. mgb_assert(inputs.size() == 2);
  215. std::array<ValueRef, 2> input_shapes;
  216. for (size_t i = 0; i < 2; ++i) {
  217. if (inputs_require_grad[i]) {
  218. input_shapes[i] = get_shape(inputs[i]);
  219. }
  220. }
  221. auto maker = CustomGradMaker(backward, inputs.size());
  222. maker.output_size(1).output_captured(0, false);
  223. maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) {
  224. mgb_assert(grads.size() == 1);
  225. ValueRef grad = grads[0];
  226. SmallVector<ValueRef> ret(2);
  227. if (!grad) {
  228. return ret;
  229. }
  230. for (size_t i = 0; i < 2; ++i) {
  231. if (shapes[i]) {
  232. ret[i] = reduce_to(grad, shapes[i]);
  233. }
  234. }
  235. return ret;
  236. });
  237. maker.finalize();
  238. return imperative::apply(ApplyOp(op), inputs);
  239. }
  240. std::optional<ValueRefList> reshape_grad_rule(
  241. const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
  242. CustomBackward& backward) {
  243. mgb_assert(inputs.size() == 1 || inputs.size() == 2);
  244. size_t nr_inp = inputs.size();
  245. std::array<ValueRef, 2> input_shapes;
  246. for (size_t i = 0; i < nr_inp; ++i) {
  247. if (inputs_require_grad[i]) {
  248. input_shapes[i] = get_shape(inputs[i]);
  249. }
  250. }
  251. auto maker = CustomGradMaker(backward, inputs.size());
  252. maker.output_size(1).output_captured(0, false);
  253. maker.backward([shapes = std::move(input_shapes), nr_inp](Span<ValueRef> grads) {
  254. mgb_assert(grads.size() == 1);
  255. ValueRef grad = grads[0];
  256. SmallVector<ValueRef> ret(nr_inp);
  257. if (!grad) {
  258. return ret;
  259. }
  260. for (size_t i = 0; i < nr_inp; ++i) {
  261. if (shapes[i]) {
  262. ret[i] = reshape_to(grad, shapes[i]);
  263. }
  264. }
  265. return ret;
  266. });
  267. maker.finalize();
  268. return imperative::apply(ApplyOp(op), inputs);
  269. }
  270. std::optional<ValueRefList> broadcast_grad_rule(
  271. const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
  272. CustomBackward& backward) {
  273. mgb_assert(inputs.size() == 1 || inputs.size() == 2);
  274. size_t nr_inp = inputs.size();
  275. std::array<ValueRef, 2> input_shapes;
  276. for (size_t i = 0; i < nr_inp; ++i) {
  277. if (inputs_require_grad[i]) {
  278. input_shapes[i] = get_shape(inputs[i]);
  279. }
  280. }
  281. auto maker = CustomGradMaker(backward, inputs.size());
  282. maker.output_size(1).output_captured(0, false);
  283. maker.backward([shapes = std::move(input_shapes), nr_inp](Span<ValueRef> grads) {
  284. mgb_assert(grads.size() == 1);
  285. ValueRef grad = grads[0];
  286. SmallVector<ValueRef> ret(nr_inp);
  287. if (!grad) {
  288. return ret;
  289. }
  290. for (size_t i = 0; i < nr_inp; ++i) {
  291. if (shapes[i]) {
  292. ret[i] = reduce_to(grad, shapes[i]);
  293. }
  294. }
  295. return ret;
  296. });
  297. maker.finalize();
  298. return imperative::apply(ApplyOp(op), inputs);
  299. }
  300. std::optional<ValueRefList> subtensor_grad_rule(
  301. const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
  302. CustomBackward& backward) {
  303. auto&& subtensor = op.cast_final_safe<Subtensor>();
  304. auto&& grad_op = SetSubtensor::make(subtensor.items);
  305. SmallVector<ValueRef> inputs2;
  306. if (inputs_require_grad[0]) {
  307. inputs2.push_back(get_shape(inputs[0]));
  308. for (size_t i = 1; i < inputs.size(); ++i) {
  309. inputs2.push_back(inputs[i]);
  310. }
  311. }
  312. auto maker = CustomGradMaker(backward, inputs.size());
  313. maker.output_size(1).output_captured(0, false);
  314. maker.backward([inputs = std::move(inputs2),
  315. grad_op_ = std::move(grad_op)](Span<ValueRef> grads) {
  316. mgb_assert(grads.size() == 1);
  317. ValueRef grad = grads[0];
  318. SmallVector<ValueRef> ret(1);
  319. if (grad && inputs[0]) {
  320. ValueRefList args_(inputs.size() + 1);
  321. auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype());
  322. args_[0] = zeros;
  323. args_[1] = grad;
  324. for (size_t i = 1; i < inputs.size(); ++i) {
  325. args_[i + 1] = inputs[i];
  326. }
  327. ret[0] = imperative::apply(ApplyOp(*grad_op_), args_)[0];
  328. }
  329. return ret;
  330. });
  331. maker.finalize();
  332. return imperative::apply(ApplyOp(op), inputs);
  333. }
  334. std::optional<ValueRefList> indexingMultiAxisVec_grad_rule(
  335. const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
  336. CustomBackward& backward) {
  337. auto&& indexingMultiAxisVec = op.cast_final_safe<IndexingMultiAxisVec>();
  338. auto&& grad_op = IndexingIncrMultiAxisVec::make(indexingMultiAxisVec.items);
  339. SmallVector<ValueRef> inputs2;
  340. if (inputs_require_grad[0]) {
  341. inputs2.push_back(get_shape(inputs[0]));
  342. for (size_t i = 1; i < inputs.size(); ++i) {
  343. inputs2.push_back(inputs[i]);
  344. }
  345. }
  346. auto maker = CustomGradMaker(backward, inputs.size());
  347. maker.output_size(1).output_captured(0, false);
  348. maker.backward([inputs = std::move(inputs2),
  349. grad_op_ = std::move(grad_op)](Span<ValueRef> grads) {
  350. mgb_assert(grads.size() == 1);
  351. ValueRef grad = grads[0];
  352. SmallVector<ValueRef> ret(1);
  353. if (grad && inputs[0]) {
  354. ValueRefList args_(inputs.size() + 1);
  355. auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype());
  356. args_[0] = zeros;
  357. args_[1] = grad;
  358. for (size_t i = 1; i < inputs.size(); ++i) {
  359. args_[i + 1] = inputs[i];
  360. }
  361. ret[0] = imperative::apply(ApplyOp(*grad_op_), args_)[0];
  362. }
  363. return ret;
  364. });
  365. maker.finalize();
  366. return imperative::apply(ApplyOp(op), inputs);
  367. }
  368. std::optional<ValueRefList> reduce_grad_rule(
  369. const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
  370. CustomBackward& backward) {
  371. auto& reduce = op.cast_final_safe<Reduce>();
  372. if (reduce.mode != Reduce::Mode::SUM) {
  373. return {};
  374. }
  375. auto axis = reduce.axis;
  376. if (inputs.size() != 1 || axis == INT_MAX) {
  377. return {};
  378. }
  379. std::array<ValueRef, 1> input_shapes;
  380. if (inputs_require_grad[0]) {
  381. input_shapes[0] = get_shape(inputs[0]);
  382. }
  383. if (axis < 0) {
  384. axis = (*inputs[0].shape()).ndim + axis;
  385. }
  386. auto maker = CustomGradMaker(backward, inputs.size());
  387. auto keepdim = reduce.keepdim || axis == INT_MAX;
  388. maker.output_size(1).output_captured(0, false);
  389. maker.backward(
  390. [shapes = std::move(input_shapes), axis, keepdim](Span<ValueRef> grads) {
  391. mgb_assert(grads.size() == 1);
  392. ValueRef grad = grads[0];
  393. if (!keepdim && grad) {
  394. auto&& grad_op = AddAxis::make(std::vector<int32_t>({axis}));
  395. grad = imperative::apply(*grad_op, grad)[0];
  396. }
  397. SmallVector<ValueRef> ret(1);
  398. if (grad && shapes[0]) {
  399. ret[0] = broadcast_to(grad, shapes[0]);
  400. }
  401. return ret;
  402. });
  403. maker.finalize();
  404. return imperative::apply(ApplyOp(op), inputs);
  405. }
  406. std::optional<ValueRefList> addAxis_grad_rule(
  407. const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
  408. CustomBackward& backward) {
  409. auto&& addAxis = op.cast_final_safe<AddAxis>();
  410. mgb_assert(inputs.size() == 1);
  411. bool flag = inputs_require_grad[0];
  412. auto&& grad_op = RemoveAxis::make(addAxis.axis);
  413. std::sort(grad_op->axis.begin(), grad_op->axis.end(), std::greater<int32_t>());
  414. auto maker = CustomGradMaker(backward, inputs.size());
  415. maker.output_size(1).output_captured(0, false);
  416. maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) {
  417. mgb_assert(grads.size() == 1);
  418. ValueRef grad = grads[0];
  419. SmallVector<ValueRef> ret(1);
  420. if (grad && flag_) {
  421. ret[0] = imperative::apply(*grad_op_, grad)[0];
  422. }
  423. return ret;
  424. });
  425. maker.finalize();
  426. return imperative::apply(op, inputs);
  427. }
  428. std::optional<ValueRefList> removeAxis_grad_rule(
  429. const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
  430. CustomBackward& backward) {
  431. auto&& removeAxis = op.cast_final_safe<RemoveAxis>();
  432. mgb_assert(inputs.size() == 1);
  433. bool flag = inputs_require_grad[0];
  434. auto&& grad_op = AddAxis::make(removeAxis.axis);
  435. std::sort(grad_op->axis.begin(), grad_op->axis.end());
  436. auto maker = CustomGradMaker(backward, inputs.size());
  437. maker.output_size(1).output_captured(0, false);
  438. maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) {
  439. mgb_assert(grads.size() == 1);
  440. ValueRef grad = grads[0];
  441. SmallVector<ValueRef> ret(1);
  442. if (grad && flag_) {
  443. ret[0] = imperative::apply(*grad_op_, grad)[0];
  444. }
  445. return ret;
  446. });
  447. maker.finalize();
  448. return imperative::apply(op, inputs);
  449. }
  450. std::optional<ValueRefList> pixelShuffle_grad_rule(
  451. const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
  452. CustomBackward& backward) {
  453. auto&& pixelShuffle = op.cast_final_safe<PixelShuffle>();
  454. mgb_assert(inputs.size() == 1);
  455. bool flag = inputs_require_grad[0];
  456. auto&& grad_op = PixelShuffleBackward::make(pixelShuffle.factor);
  457. auto maker = CustomGradMaker(backward, inputs.size());
  458. maker.output_size(1).output_captured(0, false);
  459. maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) {
  460. mgb_assert(grads.size() == 1);
  461. ValueRef grad = grads[0];
  462. SmallVector<ValueRef> ret(1);
  463. if (grad && flag_) {
  464. ret[0] = imperative::apply(*grad_op_, grad)[0];
  465. }
  466. return ret;
  467. });
  468. maker.finalize();
  469. return imperative::apply(op, inputs);
  470. }
  471. std::optional<ValueRefList> indexing_grad_rule(
  472. const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
  473. CustomBackward& backward) {
  474. auto&& indexing = op.cast_final_safe<IndexingOneHot>();
  475. mgb_assert(inputs.size() == 2);
  476. bool flag = inputs_require_grad[0];
  477. auto&& grad_op = IndexingSetOneHot::make(indexing.axis, indexing.ndim);
  478. SmallVector<ValueRef> inputs2;
  479. if (flag) {
  480. inputs2.push_back(get_shape(inputs[0]));
  481. for (size_t i = 1; i < inputs.size(); ++i) {
  482. inputs2.push_back(inputs[i]);
  483. }
  484. }
  485. auto maker = CustomGradMaker(backward, inputs.size());
  486. maker.output_size(1).output_captured(0, false);
  487. maker.backward([inputs = std::move(inputs2),
  488. grad_op_ = std::move(grad_op)](Span<ValueRef> grads) {
  489. mgb_assert(grads.size() == 1);
  490. ValueRef grad = grads[0];
  491. SmallVector<ValueRef> ret(1);
  492. if (grad && inputs[0]) {
  493. ValueRefList args_(inputs.size() + 1);
  494. auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype());
  495. args_[0] = zeros;
  496. args_[1] = inputs[1];
  497. args_[2] = grads[0];
  498. ret[0] = imperative::apply(*grad_op_, args_)[0];
  499. }
  500. return ret;
  501. });
  502. maker.finalize();
  503. return imperative::apply(op, inputs);
  504. }
  505. std::optional<ValueRefList> indexing_set_one_hot_grad_rule(
  506. const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
  507. CustomBackward& backward) {
  508. auto&& indexingSetOneHot = op.cast_final_safe<IndexingSetOneHot>();
  509. mgb_assert(inputs.size() == 3);
  510. SmallVector<ValueRef> inputs2;
  511. inputs2.push_back(get_shape(inputs[0]));
  512. inputs2.push_back(inputs[1]);
  513. inputs2.push_back(get_shape(inputs[2]));
  514. auto maker = CustomGradMaker(backward, inputs.size());
  515. maker.output_size(1).output_captured(0, false);
  516. maker.backward([inputs = std::move(inputs2),
  517. &indexingSetOneHot](Span<ValueRef> grads) {
  518. mgb_assert(grads.size() == 1);
  519. ValueRef grad = grads[0];
  520. SmallVector<ValueRef> ret(3);
  521. if (!grad) {
  522. return ret;
  523. }
  524. if (inputs[0]) {
  525. auto&& grad_op = IndexingSetOneHot::make(
  526. indexingSetOneHot.axis, indexingSetOneHot.ndim);
  527. ValueRefList args_(inputs.size());
  528. auto&& zeros = make_empty_tensor(grad.device(), inputs[2], grad.dtype());
  529. args_[0] = grads[0];
  530. args_[1] = inputs[1];
  531. args_[2] = zeros;
  532. ret[0] = imperative::apply(*grad_op, args_)[0];
  533. }
  534. if (inputs[2]) {
  535. auto&& grad_op = IndexingOneHot::make(
  536. indexingSetOneHot.axis, indexingSetOneHot.ndim);
  537. ValueRefList args_(inputs.size() - 1);
  538. args_[0] = grads[0];
  539. args_[1] = inputs[1];
  540. ret[2] = imperative::apply(*grad_op, args_)[0];
  541. }
  542. return ret;
  543. });
  544. maker.finalize();
  545. return imperative::apply(op, inputs);
  546. }
  547. std::optional<ValueRefList> fastpathcopy_grad_rule(
  548. const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
  549. CustomBackward& backward) {
  550. mgb_assert(inputs.size() == 1);
  551. auto maker = CustomGradMaker(backward, inputs.size());
  552. maker.output_size(1).output_captured(0, false);
  553. maker.backward([](Span<ValueRef> grads) {
  554. mgb_assert(grads.size() == 1);
  555. ValueRef grad = grads[0];
  556. SmallVector<ValueRef> ret(1);
  557. if (grad) {
  558. ret[0] = grad;
  559. }
  560. return ret;
  561. });
  562. maker.finalize();
  563. return imperative::apply(op, inputs);
  564. }
  565. std::optional<ValueRefList> warp_affine_grad_rule(
  566. const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
  567. CustomBackward& backward) {
  568. auto&& warp_affine = op.cast_final_safe<WarpAffine>();
  569. auto&& param = warp_affine.param();
  570. mgb_assert(inputs.size() == 3);
  571. SmallVector<ValueRef> inps;
  572. if (inputs_require_grad[0] || inputs_require_grad[1]) {
  573. for (size_t i = 0; i < inputs.size(); ++i) {
  574. inps.push_back(inputs[i]);
  575. }
  576. }
  577. auto maker = CustomGradMaker(backward, inputs.size());
  578. maker.output_size(1).output_captured(0, false);
  579. maker.backward([inputs = std::move(inps), &warp_affine,
  580. param](Span<ValueRef> grads) {
  581. mgb_assert(grads.size() == 1);
  582. ValueRef grad = grads[0];
  583. SmallVector<ValueRef> ret(2);
  584. if (!grad) {
  585. return ret;
  586. }
  587. CompNodeValue::ref_t device = inputs[0].device();
  588. DTypeValue::ref_t dtype = inputs[0].dtype();
  589. HostTensorStorage storage(*device);
  590. storage.ensure_size(3 * (dtype->size()));
  591. auto* ptr = reinterpret_cast<dt_float32*>(storage.ptr());
  592. ptr[0] = 0;
  593. ptr[1] = 0;
  594. ptr[2] = 1;
  595. auto t = imperative::apply(
  596. CreateTensor(
  597. CreateTensor::Unique, *device, dtype::Float32(),
  598. ValueShape({1, 1, 3})),
  599. HostStorage::make(storage))[0];
  600. auto mat = inputs[1];
  601. auto&& concat = Concat::make();
  602. concat->axis = 1;
  603. mat = imperative::apply(*concat, inputs[1], t)[0];
  604. if (inputs[0]) {
  605. auto&& grad_op = WarpPerspectiveBackwardData::make(
  606. param.imode, param.border_mode, param.format, param.border_val);
  607. ValueRefList args_(inputs.size());
  608. args_[0] = mat;
  609. args_[1] = grads[0];
  610. args_[2] = inputs[0];
  611. ret[0] = imperative::apply(*grad_op, args_)[0];
  612. }
  613. if (inputs[1]) {
  614. auto&& grad_op = WarpPerspectiveBackwardMat::make(
  615. param.imode, param.border_mode, param.format, param.border_val);
  616. ValueRefList args_(inputs.size());
  617. args_[0] = inputs[0];
  618. args_[1] = mat;
  619. args_[2] = grads[0];
  620. ret[1] = imperative::apply(*grad_op, args_)[0];
  621. std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
  622. items.push_back(std::make_tuple(1, true, true, false, false));
  623. auto&& subtensor = Subtensor::make(items);
  624. CompNodeValue::ref_t device = inputs[0].device();
  625. DTypeValue::ref_t dtype = inputs[0].dtype();
  626. int start_idx = 0;
  627. int stop_idx = 2;
  628. auto get_subtensor_index = [&](int idx) {
  629. HostTensorStorage storage(*device);
  630. storage.ensure_size(dtype::Int32().size());
  631. auto* ptr = reinterpret_cast<dt_int32*>(storage.ptr());
  632. ptr[0] = idx;
  633. return imperative::apply(
  634. CreateTensor(
  635. CreateTensor::Unique, *device, dtype::Int32(),
  636. ValueShape({1})),
  637. HostStorage::make(storage))[0];
  638. };
  639. auto start = get_subtensor_index(start_idx);
  640. auto stop = get_subtensor_index(stop_idx);
  641. auto data = ret[1];
  642. mgb_assert(data);
  643. ret[1] = imperative::apply(*subtensor, data, start, stop)[0];
  644. }
  645. return ret;
  646. });
  647. maker.finalize();
  648. return imperative::apply(ApplyOp(op), inputs);
  649. }
// One-time registration of every custom grad rule defined above,
// keyed by the op's typeinfo. Runs during static initialization of
// this translation unit via the `_` instance below.
struct Init {
    Init() {
        CustomBackward::register_grad_rule(Elemwise::typeinfo(), elemwise_grad_rule);
        CustomBackward::register_grad_rule(Reshape::typeinfo(), reshape_grad_rule);
        CustomBackward::register_grad_rule(Broadcast::typeinfo(), broadcast_grad_rule);
        CustomBackward::register_grad_rule(Subtensor::typeinfo(), subtensor_grad_rule);
        CustomBackward::register_grad_rule(
                IndexingMultiAxisVec::typeinfo(), indexingMultiAxisVec_grad_rule);
        CustomBackward::register_grad_rule(Reduce::typeinfo(), reduce_grad_rule);
        CustomBackward::register_grad_rule(AddAxis::typeinfo(), addAxis_grad_rule);
        CustomBackward::register_grad_rule(
                RemoveAxis::typeinfo(), removeAxis_grad_rule);
        CustomBackward::register_grad_rule(
                IndexingOneHot::typeinfo(), indexing_grad_rule);
        CustomBackward::register_grad_rule(
                IndexingSetOneHot::typeinfo(), indexing_set_one_hot_grad_rule);
        CustomBackward::register_grad_rule(
                FastpathCopy::typeinfo(), fastpathcopy_grad_rule);
        CustomBackward::register_grad_rule(
                PixelShuffle::typeinfo(), pixelShuffle_grad_rule);
        CustomBackward::register_grad_rule(MatrixMul::typeinfo(), matrix_mul_grad_rule);
        CustomBackward::register_grad_rule(
                BatchedMatrixMul::typeinfo(), batched_matrix_mul_grad_rule);
        CustomBackward::register_grad_rule(
                WarpAffine::typeinfo(), warp_affine_grad_rule);
    }
} _;
  677. } // namespace
  678. } // namespace mgb::imperative::python