
grad.cpp 22 kB

/**
 * \file imperative/python/src/grad.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "./grad.h"

#include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/imperative/backward_graph_opt.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/utils/mempool.h"

#include "range/v3/all.hpp"

namespace py = pybind11;
namespace views = ranges::views;

namespace mgb::imperative::python {

using scoped_disable = ApplyContext::scoped_disable;
using Flags = Tensor::Flags;

namespace {

struct GradSlotWeakPtr {
    std::weak_ptr<GradFn> grad_fn;
    size_t idx;
};

struct BackwardGraphCache : std::unordered_map<uint64_t, std::shared_ptr<OptimizedBackwardGraphResult>>, CompNodeDepedentObject {
    std::shared_ptr<void> on_comp_node_finalize() override {
        clear();
        return {};
    }
} backward_graph_cache;
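
// Build (or fetch from cache) the optimized backward graph for the op in
// `ctx`. The cache key hashes the op itself plus, for every input, its
// dtype, its comp_node and whether it currently carries grad info, so the
// same op applied to differently-typed inputs gets distinct cache entries.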
std::shared_ptr<OptimizedBackwardGraphResult> make_backward_graph(
        ApplyContext& ctx, const apply_result_t& outputs) {
    // hash
    static_assert(alignof(size_t) % alignof(bool) == 0);
    size_t buf_size = (1 + ctx.nargs * 2) * sizeof(size_t) + ctx.nargs * sizeof(bool);
    alignas(alignof(size_t)) std::byte buf[buf_size];
    size_t* size_t_ptr = reinterpret_cast<size_t*>(buf);
    bool* bool_ptr = reinterpret_cast<bool*>(size_t_ptr + (1 + ctx.nargs * 2));
    bool* bool_ptr0 = bool_ptr;
    *(size_t_ptr++) = ctx.op->hash();
    for (size_t i = 0; i < ctx.nargs; ++i) {
        *(size_t_ptr++) = mgb::hash(ctx.args[i]->dtype().handle());
        *(size_t_ptr++) = mgb::hash(ctx.args[i]->comp_node());
        *(bool_ptr++) = bool(ctx.args[i]->m_grad_info.grad_fn);
    }
    mgb_assert(bool_ptr0 == reinterpret_cast<bool*>(size_t_ptr) &&
               bool_ptr == reinterpret_cast<bool*>(buf + buf_size));
    uint64_t key = XXHash{}.update(buf, buf_size).digest();

    auto&& iter = backward_graph_cache.find(key);
    if (iter != backward_graph_cache.end()) {
        return iter->second;
    }

    // slow path
    SmallVector<LogicalTensorDesc> inputs(ctx.nargs);
    SmallVector<bool> input_requires_grad(ctx.nargs, false);
    SmallVector<bool> output_has_grad(outputs.size(), true);
    for (size_t i = 0; i < ctx.nargs; ++i) {
        inputs[i].comp_node = ctx.args[i]->comp_node();
        inputs[i].layout.dtype = ctx.args[i]->dtype();
        input_requires_grad[i] = python::input_requires_grad(ctx, i);
    }
    std::shared_ptr<OptimizedBackwardGraphResult> ret;
    auto bg = proxy_graph_detail::make_backward_graph(
            *ctx.op, inputs, input_requires_grad, output_has_grad);
    if (bg.backward) {
        ret = std::make_shared<OptimizedBackwardGraphResult>(bg);
    }
    backward_graph_cache.emplace(key, ret);
    return ret;
}
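
// Backward implementation driven by an auto-generated backward graph.
// The constructor captures ("closes over") exactly the inputs and outputs
// that save_for_backward marks as needed, plus any intermediates produced
// by the optional `precomp` graph; everything else is not retained.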
struct BackwardGraphWithClosure {
    std::shared_ptr<OptimizedBackwardGraphResult> backward_graph;
    SmallVector<std::shared_ptr<Tensor>> closure;
    size_t output_mask_offset;
    size_t grad_mask_offset;

    BackwardGraphWithClosure(std::shared_ptr<OptimizedBackwardGraphResult> backward_graph_,
                             ApplyContext& ctx, const apply_result_t& outputs)
            : backward_graph(backward_graph_),
              output_mask_offset(ctx.nargs),
              grad_mask_offset(ctx.nargs + outputs.size()) {
        // save_for_backward[0:nargs]:
        //     whether input is kept for backward
        //
        // save_for_backward[nargs:nargs+outputs.size()]:
        //     whether output is kept for backward
        //
        // save_for_backward[-outputs.size():]:
        //     whether gradient of output can propagate to any input
        //
        // Example:
        //     perform c = a * b, with a.requires_grad == True and
        //     b.requires_grad == False, save_for_backward = [0, 1, 0, 1]
        auto& save_for_backward = backward_graph->save_for_backward;
        mgb_assert(save_for_backward.size() == ctx.nargs + 2 * outputs.size());
        size_t count = std::count_if(save_for_backward.begin(),
                                     save_for_backward.end(),
                                     ranges::identity{});
        if (backward_graph->precomp) {
            auto&& irng = ranges::span(ctx.args, ctx.nargs);
            auto&& orng = views::transform(outputs, [](auto&& i){return i.get();});
            auto precomp = apply(backward_graph->precomp, views::concat(irng, orng));
            closure.reserve(precomp.size() + count);
            std::copy(precomp.begin(), precomp.end(), std::back_inserter(closure));
        } else {
            closure.reserve(count);
        }
        for (size_t i = 0; i < ctx.nargs; ++i) {
            if (save_for_backward[i]) {
                closure.push_back(ctx.args[i]->shared_from_this());
            }
        }
        for (size_t i = 0; i < outputs.size(); ++i) {
            if (save_for_backward[ctx.nargs + i]) {
                closure.push_back(outputs[i]);
            }
        }
    }
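
    // Invoked during GradKey::backward. Assembles the closure tensors plus
    // the incoming output grads (skipping outputs whose grad cannot reach
    // any input), runs the backward graph, and hands each produced input
    // grad to `receiver`.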
    template <typename T, typename R>
    void operator()(BackwardContext&, T&& grads, R&& receiver) {
        Tensor* args[closure.size() + grads.size()];
        size_t nargs = 0;
        for (auto&& t : closure) {
            args[nargs++] = t.get();
        }
        bool null_grad = false;
        for (size_t i = 0; i < grads.size(); ++i) {
            if (backward_graph->save_for_backward[grad_mask_offset + i]) {
                if (grads[i]) {
                    if (null_grad) {
                        PyErr_SetString(PyExc_NotImplementedError, "report to devs");
                        throw py::error_already_set();
                    }
                    args[nargs++] = grads[i];
                } else {
                    null_grad = true;
                }
            }
        }
        if (null_grad) return;

        auto igrads = apply(backward_graph->backward, args, nargs);
        auto&& it = igrads.begin();
        for (auto [i, p] : views::enumerate(backward_graph->input_has_grad)) {
            if (p) {
                receiver(i, std::move(*it));
                ++it;
            }
        }
    }

    bool input_has_grad(size_t i) {
        return backward_graph->input_has_grad[i];
    }

    bool output_requires_grad(size_t i) {
        return backward_graph->save_for_backward[grad_mask_offset + i];
    }

    bool output_captured(size_t i) {
        return backward_graph->save_for_backward[output_mask_offset + i];
    }
};
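
// Backward implementation backed by a user-supplied Python function (the
// second element returned by a GenericPyOp's _grad_rule). Output grads are
// wrapped into Python tensors, the function is called, and its return value
// (a single tensor or a sequence, with None meaning "no grad") is unpacked
// back into C++ tensors for `receiver`.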
struct PythonBackward {
    py::object pyfunc;
    size_t input_size;

    PythonBackward(py::object f, size_t nin)
            : pyfunc(f), input_size(nin) {}

    template <typename T, typename R>
    void operator()(BackwardContext& ctx, T&& grads, R&& receiver) {
        auto args = py::tuple(grads.size());
        for (size_t i = 0; i < grads.size(); ++i) {
            auto&& g = grads[i];
            args[i] = g ? ctx.wrap_tensor(g) : py::none();
        }
        auto input_grads = py::reinterpret_steal<py::object>(PyObject_Call(pyfunc.ptr(), args.ptr(), nullptr));
        if (!input_grads) throw py::error_already_set();
        if (input_grads.is_none()) return;
        if (auto* tw = TensorWrapper::try_cast(input_grads.ptr())) {
            if (input_size != 1) {
                throw py::value_error("custom grad rule returned wrong number of grads");
            }
            if (!ctx.pytype) {
                ctx.pytype = Py_TYPE(input_grads.ptr());
            }
            receiver(0, tw->m_tensor);
            return;
        }
        if (py::len(input_grads) != input_size) {
            throw py::value_error("custom grad rule returned wrong number of grads");
        }
        for (auto [i, g] : views::enumerate(input_grads)) {
            if (g.is_none()) continue;
            auto* tw = TensorWrapper::try_cast(g.ptr());
            if (!tw) {
                throw py::type_error("custom grad rule returned non-tensor");
            }
            if (!ctx.pytype) {
                ctx.pytype = Py_TYPE(g.ptr());
            }
            receiver(i, tw->m_tensor);
        }
    }

    static constexpr bool input_has_grad(size_t) {return true;}
    static constexpr bool output_requires_grad(size_t) {return true;}
    static constexpr bool output_captured(size_t) {return true;}
};

} // namespace
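
// Bookkeeping for the autograd graph: each recorded forward op gets a GradFn
// whose `slots` receive and accumulate grads for its outputs and whose `dsts`
// point back at the slots of the ops that produced its inputs. The intrusive
// producer list rooted at GradSlot::producer_head lets the backward pass
// detect when a slot's last pending producer has contributed, so the
// callback attached to a tensor is invoked only once all grads are in.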
struct GradProducerRecord : intrusive_list::Node<GradProducerRecord> {
    using Base = intrusive_list::Node<GradProducerRecord>;

    GradProducerRecord() = default;
    GradProducerRecord(GradProducerRecord::head_t& head) : Base(intrusive_list::after_t{}, head) {}

    // GradProducerRecord(GradProducerRecord&&) = default;
    // GradProducerRecord& operator=(GradProducerRecord&) = default;
    // GradProducerRecord& operator=(GradProducerRecord&&) = default;
};

struct GradSlot {
    std::shared_ptr<Tensor> grad;
    py::object callback;
    GradProducerRecord::head_t producer_head;
};

struct GradSlotProducerPtr : GradSlotPtr {
    GradProducerRecord producer_record;

    GradSlotProducerPtr() = default;
    GradSlotProducerPtr(GradInfo& info) : GradSlotPtr(info), producer_record(info->producer_head) {}
};

struct GradFn : std::enable_shared_from_this<GradFn> {
    static MemPool<GradFn> pool;

    std::weak_ptr<GradKey> key;
    // slots for receiving and accumulating grads
    // same length as outputs (of forward op)
    SmallVector<GradSlot> slots;
    // where to send and accumulate grads
    // same length as inputs (of forward op)
    SmallVector<GradSlotProducerPtr> dsts;
    // encapsulates actual function to compute gradient
    std::variant<std::monostate, BackwardGraphWithClosure, PythonBackward, CustomBackward> backward;
    // a flag used during backward
    bool in_ref_keeper = false;

    static void deleter(GradFn* ptr) {
        pool.free(ptr);
    }

    std::shared_ptr<GradFn> make() {
        return std::shared_ptr<GradFn>(pool.alloc(), &deleter);
    }

    void clear() {
        key.reset();
        slots.clear();
        dsts.clear();
        backward.emplace<std::monostate>();
    }
};

GradSlotPtr::operator bool() const {
    return bool(grad_fn);
}

GradSlot* GradSlotPtr::operator->() {
    return &grad_fn->slots[idx];
}

namespace {

class GradFnHelper {
    std::shared_ptr<GradFn> grad_fn;

    GradFn* get() {
        if (!grad_fn) {
            grad_fn = std::make_shared<GradFn>();
        }
        return grad_fn.get();
    }

    friend apply_result_t imperative::python::apply_grad(ApplyContext&);

public:
    template<typename T, typename... Args>
    auto& emplace(Args&&... args) {
        return get()->backward.emplace<T>(std::forward<Args>(args)...);
    }

    void reset() { grad_fn = nullptr; }
};
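
// The two fallback grad rules used by apply_grad below: one derives the
// backward from an auto-generated backward graph, the other delegates to the
// op's Python-level _grad_rule. Both run the forward computation themselves
// and fill `ret_grad_fn` with the matching backward implementation.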
apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
    auto outputs = apply(ctx);

    auto backward_graph = make_backward_graph(ctx, outputs);
    if (!backward_graph) {
        return outputs;
    }

    ret_grad_fn.emplace<BackwardGraphWithClosure>(std::move(backward_graph), ctx, outputs);

    return outputs;
}

apply_result_t python_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
    auto* op = ctx.op->try_cast_final<GenericPyOp>();
    py::tuple pyin(ctx.nargs);
    for (size_t i = 0; i < ctx.nargs; ++i) {
        pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this());
    }
    auto grad_rule = py::getattr(op->obj, "_grad_rule");
    auto pyret = py::reinterpret_steal<py::object>(PyObject_Call(grad_rule.ptr(), pyin.ptr(), nullptr));
    if (!pyret) throw py::error_already_set();
    auto [outputs, backward] = py::cast<std::tuple<py::object, py::function>>(pyret);
    ret_grad_fn.emplace<PythonBackward>(std::move(backward), ctx.nargs);
    if (auto* tw = TensorWrapper::try_cast(outputs.ptr())) {
        return {tw->m_tensor};
    }
    apply_result_t ret;
    ret.reserve(py::len(outputs));
    for (auto&& i : outputs) {
        auto* tw = TensorWrapper::try_cast(i.ptr());
        mgb_assert(tw);
        ret.push_back(tw->m_tensor);
    }
    return ret;
}

} // namespace
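
// Entry point taken by the dispatch path when any input of an op carries
// grad info. It selects a backward implementation in this order: the op's
// Python _grad_rule (for GenericPyOp), a CustomBackward rule registered in
// grad_rule_registry(), and finally the generic backward-graph rule. The
// resulting GradFn is recorded on the key's tape and the outputs are tagged
// with grad info so later ops can continue the chain.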
apply_result_t apply_grad(ApplyContext& ctx) {
    std::shared_ptr<GradKey> grad_key;
    for (size_t i = 0; i < ctx.nargs; ++i) {
        auto* tensor = ctx.args[i];
        if (tensor->m_grad_info.grad_fn) {
            auto&& input_grad_key = tensor->m_grad_info.grad_fn->key.lock();
            // tensor is attached to a live GradKey
            if (input_grad_key && input_grad_key->active) {
                if (grad_key) {
                    if (grad_key != input_grad_key) {
                        PyErr_SetString(PyExc_NotImplementedError, "second order grad");
                        throw pyext17::py_err_set();
                    }
                } else {
                    grad_key = std::move(input_grad_key);
                }
            } else {
                // cleanup stale grad info
                // under what condition?
                tensor->m_grad_info = {};
                tensor->m_flags &= ~Flags::GRAD;
            }
        } else {
            tensor->m_flags &= ~Flags::GRAD;
        }
    }

    ctx.flags &= ~Flags::GRAD;

    if (!grad_key) {
        return apply(ctx);
    }

    GradFnHelper grad_fn_holder;
    auto outputs = [&]() {
        auto _ = scoped_disable(Flags::GRAD);
        if (ctx.op->same_type<GenericPyOp>()) {
            return python_grad_rule(ctx, grad_fn_holder);
        }
        auto&& registry = grad_rule_registry();
        auto&& it = registry.find(ctx.op->dyn_typeinfo());
        if (it != registry.end()) {
            auto&& maker = grad_fn_holder.emplace<CustomBackward>().maker(ctx);
            try {
                auto ret = it->second(ctx, maker);
                maker.finalize();
                return ret;
            } catch (GradRuleFallback&) {
                grad_fn_holder.reset();
            }
        }
        return backward_graph_grad_rule(ctx, grad_fn_holder);
    }();

    auto& grad_fn = grad_fn_holder.grad_fn;
    if (!grad_fn) {
        return outputs;
    }

    grad_fn->key = grad_key;
    grad_fn->slots.resize(outputs.size());
    grad_fn->dsts.reserve(ctx.nargs);

    std::visit([&](auto& backward) {
        using T = std::decay_t<decltype(backward)>;
        if constexpr (std::is_same_v<T, std::monostate>) {
            mgb_assert(0);
        } else {
            for (size_t i = 0; i < ctx.nargs; ++i) {
                if (backward.input_has_grad(i) && input_requires_grad(ctx, i)) {
                    auto& input_grad_info = ctx.args[i]->m_grad_info;
                    grad_fn->dsts.emplace_back(input_grad_info);
                    // register as grad producer
                    grad_fn->dsts.back().producer_record.insert_after(input_grad_info->producer_head);
                } else {
                    grad_fn->dsts.emplace_back();
                }
            }
            for (size_t i = 0; i < outputs.size(); ++i) {
                if (backward.output_requires_grad(i)) {
                    if (backward.output_captured(i)) {
                        // avoid reference cycle [Tensor <-> GradFn]
                        outputs[i] = outputs[i]->copy();
                    }
                    // populate grad info of output tensor
                    auto& grad_info = outputs[i]->m_grad_info;
                    grad_info.grad_fn = grad_fn;
                    grad_info.idx = i;
                    grad_info.insert_after(grad_key->free_vars_head);
                    outputs[i]->m_flags |= Flags::GRAD;
                }
            }
        }
    }, grad_fn->backward);

    // record forward history
    grad_key->tape.emplace_back(grad_fn);

    return outputs;
}
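
// Python-facing attach: expects (tensor, callback_or_None) and forwards to
// GradKey::attach below.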
void GradKeyWrapper::attach(PyObject*const* args, size_t nargs) {
    if (nargs != 2) {
        throw py::type_error("expect 2 arguments");
    }
    auto* tw = TensorWrapper::try_cast(args[0]);
    if (!tw) {
        throw py::type_error("argument 1 must be Tensor");
    }
    auto* tensor = tw->m_tensor.get();
    py::object callback;
    if (args[1] != Py_None) {
        callback = py::reinterpret_borrow<py::object>(args[1]);
    }
    m_key->attach(tensor, std::move(callback));
}

//! GradKey is weakly referenced by tensor->m_grad_info.grad_fn->key after attach
void GradKey::attach(Tensor* tensor, pybind11::object callback) {
    if (!active) {
        throw py::value_error("grad key finalized");
    }

    if (tensor->m_grad_info.grad_fn) {
        if (tensor->m_grad_info.grad_fn->key.lock().get() != this) {
            PyErr_SetString(PyExc_NotImplementedError, "second order grad");
            throw pyext17::py_err_set();
        }
        if (tensor->m_grad_info->callback) {
            throw py::value_error("callback already set on this tensor");
        }
    } else {
        tensor->m_grad_info.idx = 0;
        auto& grad_fn = tensor->m_grad_info.grad_fn;
        grad_fn = std::make_shared<GradFn>();
        grad_fn->key = shared_from_this();
        grad_fn->slots.resize(1);
        tensor->m_grad_info.insert_after(free_vars_head);
        tensor->m_flags |= Flags::GRAD;
    }
    tensor->m_grad_info.grad_fn->slots[0].callback = std::move(callback);
}
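
// Accumulate `delta` into `grad`: the first contribution is moved in
// directly, later contributions are summed with a cached Elemwise ADD op.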
template<typename T>
void accum_grad(std::shared_ptr<Tensor>& grad, T&& delta) {
    if (!grad) {
        grad = std::forward<T>(delta);
        return;
    }
    static std::shared_ptr<OpDef> op = std::shared_ptr<OpDef>(new Elemwise(Elemwise::Mode::ADD));
    grad = apply(op, grad, std::forward<T>(delta))[0];
}
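
// Run back-propagation for this key: seed the grads of the requested output
// tensors, then walk the tape of GradFns in reverse order, invoking each
// one's backward implementation and accumulating the produced grads into the
// destination slots. Callbacks attached via attach() fire once a slot's last
// producer has contributed. The key is deactivated and cleaned up on exit.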
void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) {
    if (!active) {
        throw py::value_error("finalized");
    }
    if (tensors.size() != grads.size()) {
        throw py::value_error("tensor and grad size mismatch");
    }

    // this GradKey is marked inactive here
    active = false;
    struct CleanupGuard {
        GradKey* owner;
        CleanupGuard(GradKey* this_) : owner(this_) {}
        ~CleanupGuard() {owner->cleanup();}
    } _cleanup_guard(this);

    if (tape.empty()) return;

    BackwardContext bctx;
    if (!grads.empty()) {
        bctx.pytype = Py_TYPE(grads[0]->self().ptr());
    }

    for (size_t i = 0; i < tensors.size(); ++i) {
        auto& grad_info = tensors[i]->m_tensor->m_grad_info;
        if (grad_info.grad_fn && grad_info.grad_fn->key.lock().get() == this) {
            grad_info->grad = grads[i]->m_tensor;
        }
    }

    std::vector<std::shared_ptr<GradFn>> ref_keeper;
    ref_keeper.reserve(tape.size());
    // back-propagation in reverse order
    for (std::ptrdiff_t k = tape.size() - 1; k >= 0; --k) {
        auto&& grad_fn = tape[k].lock();
        if (!grad_fn) continue;

        auto grad_receiver = [&](size_t i, auto&& g) {
            auto& dst = grad_fn->dsts[i];
            if (dst) {
                accum_grad(dst->grad, std::forward<decltype(g)>(g));
            }
        };
        std::visit([&](auto&& backward) {
            using T = std::decay_t<decltype(backward)>;
            if constexpr (std::is_same_v<T, std::monostate>) {
                mgb_assert(0);
            } else {
                auto&& grads = views::transform(grad_fn->slots, [](auto&& slot) {return slot.grad.get();});
                backward(bctx, std::forward<decltype(grads)>(grads), grad_receiver);
            }
        }, grad_fn->backward);

        for (auto&& dst : grad_fn->dsts) {
            if (!dst.grad_fn) continue;
            if (!dst.grad_fn->in_ref_keeper) {
                // after grad_fn is cleared, refcnt of subsequent grad_fn
                // could drop to 0
                dst.grad_fn->in_ref_keeper = true;
                ref_keeper.push_back(dst.grad_fn);
            }
            if (!dst.producer_record.next && dst->callback && dst->grad) {
                // I'm the last grad producer, invoke callback
                dst->callback(bctx.wrap_tensor(dst->grad));
            }
        }
        grad_fn->clear();
    } // finish tape loop
}

void GradKey::cleanup() {
    active = false;
    tape.clear();
    for (intrusive_list::Iterator it(free_vars_head); it;) {
        it->grad_fn.reset();
        (it++)->unlink();
    }
}

void GradKeyWrapper::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) {
    m_key->backward(std::move(tensors), std::move(grads));
}

PyObject* GradKeyWrapper::get_name() {
    return py::cast(m_key->name).release().ptr();
}

void GradKeyWrapper::set_name(py::handle name) {
    m_key->name = py::cast<std::string>(name);
}

PyObject* GradKeyWrapper::is_attached_to(PyObject*const* args, size_t nargs) {
    if (nargs != 1) {
        PyErr_SetString(PyExc_TypeError, "expect 1 argument");
        return nullptr;
    }
    auto* tw = TensorWrapper::try_cast(args[0]);
    if (!tw) {
        PyErr_SetString(PyExc_TypeError, "expect Tensor");
        return nullptr;
    }
    auto&& grad_fn = tw->m_tensor->m_grad_info.grad_fn;
    if (grad_fn && grad_fn->key.lock() == m_key) {
        Py_RETURN_TRUE;
    }
    Py_RETURN_FALSE;
}

GradKey::~GradKey() {
    cleanup();
}

std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry() {
    static std::unordered_map<Typeinfo*, GradRuleFn> registry;
    return registry;
}

} // namespace mgb::imperative::python
