
grad.cpp 21 kB

/**
 * \file imperative/python/src/grad.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "./grad.h"
#include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/utils/mempool.h"

#include "range/v3/all.hpp"

namespace py = pybind11;
namespace views = ranges::views;

namespace mgb::imperative::python {

using scoped_disable = ApplyContext::scoped_disable;
using Flags = Tensor::Flags;

namespace {

struct GradSlotWeakPtr {
    std::weak_ptr<GradFn> grad_fn;
    size_t idx;
};

struct BackwardGraphCache : std::unordered_map<uint64_t, std::shared_ptr<BackwardGraphResult>>, CompNodeDepedentObject {
    std::shared_ptr<void> on_comp_node_finalize() override {
        clear();
        return {};
    }
} backward_graph_cache;
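
// Build (or fetch from cache) the backward graph for the op in `ctx`.
// The cache key hashes the op itself plus, for every input, its dtype,
// its comp_node and whether it currently carries a grad_fn, so applications
// with differently-typed or differently-placed inputs get distinct entries.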
std::shared_ptr<BackwardGraphResult> make_backward_graph(
        ApplyContext& ctx, const apply_result_t& outputs) {
    // hash
    static_assert(alignof(size_t) % alignof(bool) == 0);
    size_t buf_size = (1 + ctx.nargs * 2) * sizeof(size_t) + ctx.nargs * sizeof(bool);
    alignas(alignof(size_t)) std::byte buf[buf_size];
    size_t* size_t_ptr = reinterpret_cast<size_t*>(buf);
    bool* bool_ptr = reinterpret_cast<bool*>(size_t_ptr + (1 + ctx.nargs * 2));
    bool* bool_ptr0 = bool_ptr;
    *(size_t_ptr++) = ctx.op->hash();
    for (size_t i = 0; i < ctx.nargs; ++i) {
        *(size_t_ptr++) = mgb::hash(ctx.args[i]->dtype().handle());
        *(size_t_ptr++) = mgb::hash(ctx.args[i]->comp_node());
        *(bool_ptr++) = bool(ctx.args[i]->m_grad_info.grad_fn);
    }
    mgb_assert(bool_ptr0 == reinterpret_cast<bool*>(size_t_ptr) &&
               bool_ptr == reinterpret_cast<bool*>(buf + buf_size));
    uint64_t key = XXHash{}.update(buf, buf_size).digest();

    auto&& iter = backward_graph_cache.find(key);
    if (iter != backward_graph_cache.end()) {
        return iter->second;
    }

    // slow path
    SmallVector<LogicalTensorDesc> inputs(ctx.nargs);
    SmallVector<bool> input_requires_grad(ctx.nargs, false);
    SmallVector<bool> output_has_grad(outputs.size(), true);
    for (size_t i = 0; i < ctx.nargs; ++i) {
        inputs[i].comp_node = ctx.args[i]->comp_node();
        inputs[i].layout.dtype = ctx.args[i]->dtype();
        input_requires_grad[i] = python::input_requires_grad(ctx, i);
    }
    auto result = std::make_shared<BackwardGraphResult>(
            proxy_graph_detail::make_backward_graph(
                    *ctx.op, inputs, input_requires_grad, output_has_grad));
    if (!result->backward) {
        result.reset();
    }
    backward_graph_cache.emplace(key, result);
    return result;
}
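
// Backward implementation backed by an auto-generated backward graph.
// The constructor captures exactly the inputs/outputs that save_for_backward
// marks as needed; operator() later feeds the captured tensors plus the
// incoming output grads to the backward op.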
struct BackwardGraphWithClosure {
    std::shared_ptr<BackwardGraphResult> backward_graph;
    SmallVector<std::shared_ptr<Tensor>> closure;
    size_t output_mask_offset;
    size_t grad_mask_offset;

    BackwardGraphWithClosure(std::shared_ptr<BackwardGraphResult> backward_graph_,
                             ApplyContext& ctx, const apply_result_t& outputs)
            : backward_graph(backward_graph_),
              output_mask_offset(ctx.nargs),
              grad_mask_offset(ctx.nargs + outputs.size()) {
        // save_for_backward[0:nargs]:
        //     whether input is kept for backward
        //
        // save_for_backward[nargs:nargs+outputs.size()]:
        //     whether output is kept for backward
        //
        // save_for_backward[-outputs.size():]:
        //     whether gradient of output can propagate to any input
        //
        // Example:
        //     perform c = a * b, with a.requires_grad == True and
        //     b.requires_grad == False, save_for_backward = [0, 1, 0, 1]
        auto& save_for_backward = backward_graph->save_for_backward;
        mgb_assert(save_for_backward.size() == ctx.nargs + 2 * outputs.size());
        closure.reserve(std::count_if(save_for_backward.begin(),
                                      save_for_backward.end(),
                                      ranges::identity{}));
        for (size_t i = 0; i < ctx.nargs; ++i) {
            if (save_for_backward[i]) {
                closure.push_back(ctx.args[i]->shared_from_this());
            }
        }
        for (size_t i = 0; i < outputs.size(); ++i) {
            if (save_for_backward[ctx.nargs + i]) {
                closure.push_back(outputs[i]);
            }
        }
    }
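
    // Assemble the backward op's arguments: captured closure tensors first,
    // then those incoming output grads that the mask marks as usable. If a
    // needed grad is missing (null), no input grad can be produced and the
    // whole backward application is skipped.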
    template <typename T, typename R>
    void operator()(BackwardContext&, T&& grads, R&& receiver) {
        Tensor* args[closure.size() + grads.size()];
        size_t nargs = 0;
        for (auto&& t : closure) {
            args[nargs++] = t.get();
        }
        bool null_grad = false;
        for (size_t i = 0; i < grads.size(); ++i) {
            if (backward_graph->save_for_backward[grad_mask_offset + i]) {
                if (grads[i]) {
                    if (null_grad) {
                        PyErr_SetString(PyExc_NotImplementedError, "report to devs");
                        throw py::error_already_set();
                    }
                    args[nargs++] = grads[i];
                } else {
                    null_grad = true;
                }
            }
        }
        if (null_grad) return;

        ApplyContext ctx;
        ctx.op = backward_graph->backward;
        ctx.flags = is_tracing ? Flags::TRACE : 0;
        ctx.nargs = nargs;
        ctx.args = args;
        for (size_t i = 0; i < nargs; ++i) {
            ctx.flags |= args[i]->m_flags;
            mgb_assert(args[i]);
        }

        auto igrads = apply(ctx);
        auto&& it = igrads.begin();
        for (auto [i, p] : views::enumerate(backward_graph->input_has_grad)) {
            if (p) {
                receiver(i, std::move(*it));
                ++it;
            }
        }
    }

    bool input_has_grad(size_t i) {
        return backward_graph->input_has_grad[i];
    }

    bool output_requires_grad(size_t i) {
        return backward_graph->save_for_backward[grad_mask_offset + i];
    }

    bool output_captured(size_t i) {
        return backward_graph->save_for_backward[output_mask_offset + i];
    }
};
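
// Backward implementation for ops whose grad rule is written in Python:
// wraps the incoming grads as tensors, calls the stored Python function,
// and forwards each non-None returned grad to the corresponding input slot.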
struct PythonBackward {
    py::object pyfunc;
    size_t input_size;

    PythonBackward(py::object f, size_t nin)
            : pyfunc(f), input_size(nin) {}

    template <typename T, typename R>
    void operator()(BackwardContext& ctx, T&& grads, R&& receiver) {
        auto args = py::tuple(grads.size());
        for (size_t i = 0; i < grads.size(); ++i) {
            auto&& g = grads[i];
            args[i] = g ? ctx.wrap_tensor(g) : py::none();
        }
        auto input_grads = py::reinterpret_steal<py::object>(PyObject_Call(pyfunc.ptr(), args.ptr(), nullptr));
        if (!input_grads) throw py::error_already_set();
        if (input_grads.is_none()) return;
        if (auto* tw = TensorWrapper::try_cast(input_grads.ptr())) {
            if (input_size != 1) {
                throw py::value_error("custom grad rule returned wrong number of grads");
            }
            if (!ctx.pytype) {
                ctx.pytype = Py_TYPE(input_grads.ptr());
            }
            receiver(0, tw->m_tensor);
            return;
        }
        if (py::len(input_grads) != input_size) {
            throw py::value_error("custom grad rule returned wrong number of grads");
        }
        for (auto [i, g] : views::enumerate(input_grads)) {
            if (g.is_none()) continue;
            auto* tw = TensorWrapper::try_cast(g.ptr());
            if (!tw) {
                throw py::type_error("custom grad rule returned non-tensor");
            }
            if (!ctx.pytype) {
                ctx.pytype = Py_TYPE(g.ptr());
            }
            receiver(i, tw->m_tensor);
        }
    }

    static constexpr bool input_has_grad(size_t) {return true;}
    static constexpr bool output_requires_grad(size_t) {return true;}
    static constexpr bool output_captured(size_t) {return true;}
};

} // namespace

struct GradProducerRecord : intrusive_list::Node<GradProducerRecord> {
    using Base = intrusive_list::Node<GradProducerRecord>;

    GradProducerRecord() = default;
    GradProducerRecord(GradProducerRecord::head_t& head) : Base(intrusive_list::after_t{}, head) {}

    // GradProducerRecord(GradProducerRecord&&) = default;
    // GradProducerRecord& operator=(GradProducerRecord&) = default;
    // GradProducerRecord& operator=(GradProducerRecord&&) = default;
};
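
// One gradient accumulation slot: the accumulated grad, an optional user
// callback, and the list of pending producers that still have to add to it.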
struct GradSlot {
    std::shared_ptr<Tensor> grad;
    py::object callback;
    GradProducerRecord::head_t producer_head;
};

struct GradSlotProducerPtr : GradSlotPtr {
    GradProducerRecord producer_record;

    GradSlotProducerPtr() = default;
    GradSlotProducerPtr(GradInfo& info) : GradSlotPtr(info), producer_record(info->producer_head) {}
};

struct GradFn : std::enable_shared_from_this<GradFn> {
    static MemPool<GradFn> pool;

    std::weak_ptr<GradKey> key;
    // slots for receiving and accumulating grads
    // same length as outputs (of forward op)
    SmallVector<GradSlot> slots;
    // where to send and accumulate grads
    // same length as inputs (of forward op)
    SmallVector<GradSlotProducerPtr> dsts;
    // encapsulates actual function to compute gradient
    std::variant<std::monostate, BackwardGraphWithClosure, PythonBackward, CustomBackward> backward;
    // a flag used during backward
    bool in_ref_keeper = false;

    static void deleter(GradFn* ptr) {
        pool.free(ptr);
    }

    std::shared_ptr<GradFn> make() {
        return std::shared_ptr<GradFn>(pool.alloc(), &deleter);
    }

    void clear() {
        key.reset();
        slots.clear();
        dsts.clear();
        backward.emplace<std::monostate>();
    }
};

GradSlotPtr::operator bool() const {
    return bool(grad_fn);
}

GradSlot* GradSlotPtr::operator->() {
    return &grad_fn->slots[idx];
}

namespace {

class GradFnHelper {
    std::shared_ptr<GradFn> grad_fn;

    GradFn* get() {
        if (!grad_fn) {
            grad_fn = std::make_shared<GradFn>();
        }
        return grad_fn.get();
    }

    friend apply_result_t imperative::python::apply_grad(ApplyContext&);

public:
    template<typename T, typename... Args>
    auto& emplace(Args&&... args) {
        return get()->backward.emplace<T>(std::forward<Args>(args)...);
    }

    void reset() { grad_fn = nullptr; }
};

apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
    auto outputs = apply(ctx);

    auto backward_graph = make_backward_graph(ctx, outputs);
    if (!backward_graph) {
        return outputs;
    }

    ret_grad_fn.emplace<BackwardGraphWithClosure>(std::move(backward_graph), ctx, outputs);

    return outputs;
}

apply_result_t python_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
    auto* op = ctx.op->try_cast_final<GenericPyOp>();
    py::tuple pyin(ctx.nargs);
    for (size_t i = 0; i < ctx.nargs; ++i) {
        pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this());
    }
    auto grad_rule = py::getattr(op->obj, "_grad_rule");
    auto pyret = py::reinterpret_steal<py::object>(PyObject_Call(grad_rule.ptr(), pyin.ptr(), nullptr));
    if (!pyret) throw py::error_already_set();
    auto [outputs, backward] = py::cast<std::tuple<py::object, py::function>>(pyret);
    ret_grad_fn.emplace<PythonBackward>(std::move(backward), ctx.nargs);
    if (auto* tw = TensorWrapper::try_cast(outputs.ptr())) {
        return {tw->m_tensor};
    }
    apply_result_t ret;
    ret.reserve(py::len(outputs));
    for (auto&& i : outputs) {
        auto* tw = TensorWrapper::try_cast(i.ptr());
        mgb_assert(tw);
        ret.push_back(tw->m_tensor);
    }
    return ret;
}

} // namespace
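
// Forward entry point for grad-tracked application of an op. Picks the
// GradKey shared by the inputs (rejecting mixed keys as unsupported
// second-order grad), runs the op with GRAD temporarily disabled through the
// most specific rule available (Python _grad_rule, registered CustomBackward
// rule, or the generic backward graph), then wires the resulting GradFn into
// the key's tape.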
apply_result_t apply_grad(ApplyContext& ctx) {
    std::shared_ptr<GradKey> grad_key;
    for (size_t i = 0; i < ctx.nargs; ++i) {
        auto* tensor = ctx.args[i];
        if (tensor->m_grad_info.grad_fn) {
            auto&& input_grad_key = tensor->m_grad_info.grad_fn->key.lock();
            // tensor is attached to a live GradKey
            if (input_grad_key && input_grad_key->active) {
                if (grad_key) {
                    if (grad_key != input_grad_key) {
                        PyErr_SetString(PyExc_NotImplementedError, "second order grad");
                        throw pyext17::py_err_set();
                    }
                } else {
                    grad_key = std::move(input_grad_key);
                }
            } else {
                // cleanup stale grad info
                // under what condition?
                tensor->m_grad_info = {};
                tensor->m_flags &= ~Flags::GRAD;
            }
        } else {
            tensor->m_flags &= ~Flags::GRAD;
        }
    }

    ctx.flags &= ~Flags::GRAD;

    if (!grad_key) {
        return apply(ctx);
    }

    GradFnHelper grad_fn_holder;
    auto outputs = [&]() {
        auto _ = scoped_disable(Flags::GRAD);
        if (ctx.op->same_type<GenericPyOp>()) {
            return python_grad_rule(ctx, grad_fn_holder);
        }
        auto&& registry = grad_rule_registry();
        auto&& it = registry.find(ctx.op->dyn_typeinfo());
        if (it != registry.end()) {
            auto&& maker = grad_fn_holder.emplace<CustomBackward>().maker(ctx);
            try {
                auto ret = it->second(ctx, maker);
                maker.finalize();
                return ret;
            } catch (GradRuleFallback&) {
                grad_fn_holder.reset();
            }
        }
        return backward_graph_grad_rule(ctx, grad_fn_holder);
    }();

    auto& grad_fn = grad_fn_holder.grad_fn;
    if (!grad_fn) {
        return outputs;
    }

    grad_fn->key = grad_key;
    grad_fn->slots.resize(outputs.size());
    grad_fn->dsts.reserve(ctx.nargs);

    std::visit([&](auto& backward) {
        using T = std::decay_t<decltype(backward)>;
        if constexpr (std::is_same_v<T, std::monostate>) {
            mgb_assert(0);
        } else {
            for (size_t i = 0; i < ctx.nargs; ++i) {
                if (backward.input_has_grad(i) && input_requires_grad(ctx, i)) {
                    auto& input_grad_info = ctx.args[i]->m_grad_info;
                    grad_fn->dsts.emplace_back(input_grad_info);
                    // register as grad producer
                    grad_fn->dsts.back().producer_record.insert_after(input_grad_info->producer_head);
                } else {
                    grad_fn->dsts.emplace_back();
                }
            }
            for (size_t i = 0; i < outputs.size(); ++i) {
                if (backward.output_requires_grad(i)) {
                    if (backward.output_captured(i)) {
                        // avoid reference cycle [Tensor <-> GradFn]
                        outputs[i] = outputs[i]->copy();
                    }
                    // populate grad info of output tensor
                    auto& grad_info = outputs[i]->m_grad_info;
                    grad_info.grad_fn = grad_fn;
                    grad_info.idx = i;
                    grad_info.insert_after(grad_key->free_vars_head);
                    outputs[i]->m_flags |= Flags::GRAD;
                }
            }
        }
    }, grad_fn->backward);

    // record forward history
    grad_key->tape.emplace_back(grad_fn);

    return outputs;
}

void GradKeyWrapper::attach(PyObject*const* args, size_t nargs) {
    if (nargs != 2) {
        throw py::type_error("expect 2 arguments");
    }
    auto* tw = TensorWrapper::try_cast(args[0]);
    if (!tw) {
        throw py::type_error("argument 1 must be Tensor");
    }
    auto* tensor = tw->m_tensor.get();
    py::object callback;
    if (args[1] != Py_None) {
        callback = py::reinterpret_borrow<py::object>(args[1]);
    }
    m_key->attach(tensor, std::move(callback));
}

//! GradKey is weakly referred by tensor->m_grad_info.grad_fn->key after attach
void GradKey::attach(Tensor* tensor, pybind11::object callback) {
    if (!active) {
        throw py::value_error("grad key finalized");
    }

    if (tensor->m_grad_info.grad_fn) {
        if (tensor->m_grad_info.grad_fn->key.lock().get() != this) {
            PyErr_SetString(PyExc_NotImplementedError, "second order grad");
            throw pyext17::py_err_set();
        }
        if (tensor->m_grad_info->callback) {
            throw py::value_error("callback already set on this tensor");
        }
    } else {
        tensor->m_grad_info.idx = 0;
        auto& grad_fn = tensor->m_grad_info.grad_fn;
        grad_fn = std::make_shared<GradFn>();
        grad_fn->key = shared_from_this();
        grad_fn->slots.resize(1);
        tensor->m_grad_info.insert_after(free_vars_head);
        tensor->m_flags |= Flags::GRAD;
    }
    tensor->m_grad_info.grad_fn->slots[0].callback = std::move(callback);
}
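
// Accumulate `delta` into `grad`: take ownership directly when `grad` is
// still empty, otherwise add the two with a shared Elemwise ADD op.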
template<typename T>
void accum_grad(std::shared_ptr<Tensor>& grad, T&& delta) {
    if (!grad) {
        grad = std::forward<T>(delta);
        return;
    }
    static std::shared_ptr<OpDef> op = std::shared_ptr<OpDef>(new Elemwise(Elemwise::Mode::ADD));
    grad = apply(op, grad, std::forward<T>(delta))[0];
}
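
// Run back-propagation for this key: seed the requested output grads, then
// walk the tape in reverse, invoking each GradFn's backward and accumulating
// into its destination slots; callbacks fire once a slot's last producer has
// run, and ref_keeper keeps downstream GradFns alive until the loop finishes.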
void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) {
    if (!active) {
        throw py::value_error("finalized");
    }
    if (tensors.size() != grads.size()) {
        throw py::value_error("tensor and grad size mismatch");
    }

    // this GradKey is marked inactive here
    active = false;
    struct CleanupGuard {
        GradKey* owner;
        CleanupGuard(GradKey* this_) : owner(this_) {}
        ~CleanupGuard() {owner->cleanup();}
    } _cleanup_guard(this);

    if (tape.empty()) return;

    BackwardContext bctx;
    if (!grads.empty()) {
        bctx.pytype = Py_TYPE(grads[0]->self().ptr());
    }

    for (size_t i = 0; i < tensors.size(); ++i) {
        auto& grad_info = tensors[i]->m_tensor->m_grad_info;
        if (grad_info.grad_fn && grad_info.grad_fn->key.lock().get() == this) {
            grad_info->grad = grads[i]->m_tensor;
        }
    }

    std::vector<std::shared_ptr<GradFn>> ref_keeper;
    ref_keeper.reserve(tape.size());
    // back-propagation in reverse order
    for (std::ptrdiff_t k = tape.size() - 1; k >= 0; --k) {
        auto&& grad_fn = tape[k].lock();
        if (!grad_fn) continue;

        auto grad_receiver = [&](size_t i, auto&& g) {
            auto& dst = grad_fn->dsts[i];
            if (dst) {
                accum_grad(dst->grad, std::forward<decltype(g)>(g));
            }
        };
        std::visit([&](auto&& backward) {
            using T = std::decay_t<decltype(backward)>;
            if constexpr (std::is_same_v<T, std::monostate>) {
                mgb_assert(0);
            } else {
                auto&& grads = views::transform(grad_fn->slots, [](auto&& slot) {return slot.grad.get();});
                backward(bctx, std::forward<decltype(grads)>(grads), grad_receiver);
            }
        }, grad_fn->backward);

        for (auto&& dst : grad_fn->dsts) {
            if (!dst.grad_fn) continue;
            if (!dst.grad_fn->in_ref_keeper) {
                // after grad_fn is cleared, refcnt of subsequent grad_fn
                // could drop to 0
                dst.grad_fn->in_ref_keeper = true;
                ref_keeper.push_back(dst.grad_fn);
            }
            if (!dst.producer_record.next && dst->callback && dst->grad) {
                // I'm the last grad producer, invoke callback
                dst->callback(bctx.wrap_tensor(dst->grad));
            }
        }
        grad_fn->clear();
    } // finish tape loop
}

void GradKey::cleanup() {
    active = false;
    tape.clear();
    for (intrusive_list::Iterator it(free_vars_head); it;) {
        it->grad_fn.reset();
        (it++)->unlink();
    }
}

void GradKeyWrapper::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) {
    m_key->backward(std::move(tensors), std::move(grads));
}

PyObject* GradKeyWrapper::get_name() {
    return py::cast(m_key->name).release().ptr();
}

void GradKeyWrapper::set_name(py::handle name) {
    m_key->name = py::cast<std::string>(name);
}

PyObject* GradKeyWrapper::is_attached_to(PyObject*const* args, size_t nargs) {
    if (nargs != 1) {
        PyErr_SetString(PyExc_TypeError, "expect 1 argument");
        return nullptr;
    }
    auto* tw = TensorWrapper::try_cast(args[0]);
    if (!tw) {
        PyErr_SetString(PyExc_TypeError, "expect Tensor");
        return nullptr;
    }
    auto&& grad_fn = tw->m_tensor->m_grad_info.grad_fn;
    if (grad_fn && grad_fn->key.lock() == m_key) {
        Py_RETURN_TRUE;
    }
    Py_RETURN_FALSE;
}

GradKey::~GradKey() {
    cleanup();
}

std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry() {
    static std::unordered_map<Typeinfo*, GradRuleFn> registry;
    return registry;
}

} // namespace mgb::imperative::python
