@@ -20,6 +20,9 @@ class AttachSpec:
     __slots__ = "tensor", "callbacks"
 
 
+_global_priority = 0
+
+
 class GradManager:
     r"""
     GradManager computes gradients or more generally, vector-Jacobian product, by reverse mode
@@ -118,6 +121,7 @@ class GradManager:
         self._grad = None
         self._after_backward_callback = []
         self._gradients = {}
+        self._priority = None
 
     def attach(self, tensors: Iterable[Tensor], callbacks=None):
         r"""
@@ -293,6 +297,7 @@ class GradManager:
         After this call, you will be able to call :meth:`backward`.
         """
+        global _global_priority
         if self._recording:
             raise RuntimeError("already recording")
         grad = Grad()
@@ -300,6 +305,9 @@ class GradManager:
         self._grad = grad
         for spec in self._attach_specs.values():
             self._do_record(spec)
+        if self._priority is None:
+            grad._priority = _global_priority
+            _global_priority -= 1
         grad.__enter__()
 
     def _do_record(self, spec):
@@ -321,11 +329,14 @@ class GradManager:
         After this call, you will not be able to call :meth:`backward`.
         """
+        global _global_priority
         if self._grad is not None:
             self._grad.__exit__(None, None, None)
             self._grad = None
         self._recording = False
         self._gradients = dict()
+        if self._priority is None:
+            _global_priority += 1
 
     def __enter__(self):
         self.record()
@@ -333,3 +344,41 @@ class GradManager:
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.release()
+
+    def __and__(self, other):
+        if isinstance(other, GradManager):
+            return GradManagerGroup([self, other])
+        return NotImplemented
+
+    __rand__ = __and__
+
+
+class GradManagerGroup:
+    def __init__(self, gms) -> None:
+        self._gms = list(gms)
+
+    def merge_with(self, other):
+        if isinstance(other, GradManager):
+            other = GradManagerGroup([other])
+        elif not isinstance(other, GradManagerGroup):
+            return NotImplemented
+        return GradManagerGroup([*self._gms, *other._gms])
+
+    __and__ = merge_with
+    __rand__ = merge_with
+    __or__ = merge_with
+    __ror__ = merge_with
+
+    def __enter__(self):
+        global _global_priority
+        _global_priority += 1
+        for gm in self._gms:
+            gm._priority = _global_priority
+            gm.record()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        global _global_priority
+        _global_priority -= 1
+        for gm in self._gms:
+            gm.release()
+            gm._priority = None
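Reviewer note on the block above: `GradManagerGroup` exists so that several managers can start recording at one shared priority level, instead of the staggered, decreasing priorities that plain `record()` hands out. A minimal usage sketch, assuming the public `megengine.autodiff.GradManager` API; the parameter, the loss, and both manager names below are hypothetical:

```python
# Sketch only: two managers joined with the new `&` operator record the
# same forward computation, and each can run its own backward pass.
import megengine as mge
from megengine.autodiff import GradManager

w = mge.Parameter([1.0, 2.0, 3.0])

gm_a = GradManager()
gm_a.attach([w])
gm_b = GradManager()
gm_b.attach([w])

# GradManagerGroup.__enter__ bumps _global_priority once, assigns the
# incremented value to each manager, and calls record() on each.
with gm_a & gm_b:
    loss = (w * w).sum()
    gm_a.backward(loss)  # writes w.grad
    gm_b.backward(loss)  # accumulates into w.grad again

print(w.grad)
```

Note that `GradManager` itself only gains `__and__`/`__rand__`; `|` becomes usable once one operand is already a `GradManagerGroup`, because the group aliases `merge_with` to `__or__` and `__ror__` as well.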
@@ -47,6 +47,14 @@ class Grad:
         self._impl = GradKey(name)
         _grad_manager_dict[self._name] = self
 
+    @property
+    def _priority(self):
+        return self._impl.priority
+
+    @_priority.setter
+    def _priority(self, priority):
+        self._impl.priority = priority
+
     @property
     def _name(self):
         return self._impl.name
@@ -54,7 +54,7 @@ std::shared_ptr<OptimizedBackwardGraphResult> make_backward_graph(
     for (size_t i = 0; i < ctx.nargs; ++i) {
         *(size_t_ptr++) = mgb::hash(ctx.args[i]->dtype().handle());
         *(size_t_ptr++) = mgb::hash(ctx.args[i]->comp_node());
-        *(bool_ptr++) = bool(ctx.args[i]->m_grad_info.grad_fn);
+        *(bool_ptr++) = !ctx.args[i]->m_grad_info_dict.empty();
     }
     mgb_assert(bool_ptr0 == reinterpret_cast<bool*>(size_t_ptr) &&
                bool_ptr == reinterpret_cast<bool*>(buf + buf_size));
@@ -321,7 +321,7 @@ apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
     for (size_t i = 0; i < ctx.nargs; ++i) {
         inputs_copy.push_back(python::apply(FastpathCopy::make(), ctx.args[i]->shared_from_this())[0]);
         inputs_copy_weak.push_back(inputs_copy.back().get());
-        inputs_copy.back()->m_grad_info = ctx.args[i]->m_grad_info;
+        inputs_copy.back()->m_grad_info_dict = ctx.args[i]->m_grad_info_dict;
     }
     ApplyContext ctx_dup = ctx;
     ctx_dup.args = inputs_copy_weak.data();
@@ -365,25 +365,19 @@ apply_result_t python_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
 } // namespace
 
 apply_result_t apply_grad(ApplyContext& ctx) {
-    std::shared_ptr<GradKey> grad_key;
+    std::unordered_set<std::shared_ptr<GradKey>> grad_keys;
     for (size_t i = 0; i < ctx.nargs; ++i) {
         auto* tensor = ctx.args[i];
-        if (tensor->m_grad_info.grad_fn) {
-            auto&& input_grad_key = tensor->m_grad_info.grad_fn->key.lock();
-            // tensor is attached to a live GradKey
-            if (input_grad_key && input_grad_key->active) {
-                if (grad_key) {
-                    if (grad_key != input_grad_key) {
-                        PyErr_SetString(PyExc_NotImplementedError, "second order grad");
-                        throw pyext17::py_err_set();
-                    }
-                } else {
-                    grad_key = std::move(input_grad_key);
+        if (!tensor->m_grad_info_dict.empty()) {
+            size_t grad_cnt = 0;
+            for (auto&& grad_info: tensor->m_grad_info_dict) {
+                auto input_grad_key = grad_info.grad_fn->key.lock();
+                if (input_grad_key && input_grad_key->active && !input_grad_key->is_blocked()) {
+                    grad_keys.insert(input_grad_key);
+                    grad_cnt++;
                 }
-            } else {
-                // cleanup stale grad info
-                // under what condition?
-                tensor->m_grad_info = {};
             }
+            if (!grad_cnt) {
                 tensor->m_flags &= ~Flags::GRAD;
             }
         } else {
@@ -393,7 +387,7 @@ apply_result_t apply_grad(ApplyContext& ctx) {
     ctx.flags &= ~Flags::GRAD;
 
-    if (!grad_key) {
+    if (grad_keys.empty()) {
         return apply(ctx);
     }
@@ -418,54 +412,65 @@ apply_result_t apply_grad(ApplyContext& ctx) {
         return backward_graph_grad_rule(ctx, grad_fn_holder);
     }();
 
-    auto& grad_fn = grad_fn_holder.grad_fn;
-    if (!grad_fn) {
+    if (!grad_fn_holder.grad_fn) {
         return outputs;
     }
 
-    grad_fn->key = grad_key;
-    grad_fn->slots.resize(outputs.size());
-    grad_fn->dsts.reserve(ctx.nargs);
+    for (auto&& grad_key: grad_keys) {
+        auto grad_fn = std::make_shared<GradFn>();
+        grad_fn->backward = grad_fn_holder.grad_fn->backward;
+        grad_fn->key = grad_key;
+        grad_fn->slots.resize(outputs.size());
+        grad_fn->dsts.reserve(ctx.nargs);
 
-    std::visit([&](auto& backward) {
-        using T = std::decay_t<decltype(backward)>;
-        if constexpr (std::is_same_v<T, std::monostate>) {
-            mgb_assert(0);
-        } else {
-            for (size_t i = 0; i < ctx.nargs; ++i) {
-                if (backward.input_has_grad(i) && input_requires_grad(ctx, i)) {
-                    auto& input_grad_info = ctx.args[i]->m_grad_info;
-                    grad_fn->dsts.emplace_back(input_grad_info);
-                    // register as grad producer
-                    grad_fn->dsts.back().producer_record.insert_after(input_grad_info->producer_head);
-                } else {
-                    grad_fn->dsts.emplace_back();
-                }
-            }
-            for (size_t i = 0; i < outputs.size(); ++i) {
-                if (backward.output_requires_grad(i)) {
-                    if (backward.output_captured(i)) {
-                        // avoid reference cycle [Tensor <-> GradFn]
-                        static std::shared_ptr<OpDef> op = std::shared_ptr<OpDef>(new FastpathCopy());
-                        outputs[i] = python::apply(op, outputs[i])[0];
-                    }
-                    // populate grad info of output tensor
-                    auto& grad_info = outputs[i]->m_grad_info;
-                    grad_info.grad_fn = grad_fn;
-                    grad_info.idx = i;
-                    grad_info.insert_after(grad_key->free_vars_head);
-                    outputs[i]->m_flags |= Flags::GRAD;
-                }
-            }
-        }
-    }, grad_fn->backward);
+        std::visit([&](auto& backward) {
+            using T = std::decay_t<decltype(backward)>;
+            if constexpr (std::is_same_v<T, std::monostate>) {
+                mgb_assert(0);
+            } else {
+                for (size_t i = 0; i < ctx.nargs; ++i) {
+                    if (backward.input_has_grad(i) && input_requires_grad(ctx, i) && ctx.args[i]->m_grad_info_dict.count(grad_key.get())) {
+                        auto& input_grad_info = ctx.args[i]->m_grad_info_dict.at(grad_key.get());
+                        grad_fn->dsts.emplace_back(input_grad_info);
+                        // register as grad producer
+                        grad_fn->dsts.back().producer_record.insert_after(input_grad_info->producer_head);
+                    } else {
+                        grad_fn->dsts.emplace_back();
+                    }
+                }
+                for (size_t i = 0; i < outputs.size(); ++i) {
+                    if (backward.output_requires_grad(i)) {
+                        if (backward.output_captured(i)) {
+                            // avoid reference cycle [Tensor <-> GradFn]
+                            static std::shared_ptr<OpDef> op = std::make_shared<FastpathCopy>();
+                            outputs[i] = python::apply(op, outputs[i])[0];
+                        }
+                        // populate grad info of output tensor
+                        auto& grad_info = outputs[i]->m_grad_info_dict[grad_key.get()];
+                        grad_info.grad_fn = grad_fn;
+                        grad_info.idx = i;
+                        grad_info.insert_after(grad_key->free_vars_head);
+                        outputs[i]->m_flags |= Flags::GRAD;
+                    }
+                }
+            }
+        }, grad_fn->backward);
 
-    // record forward history
-    grad_key->tape.emplace_back(grad_fn);
+        // record forward history
+        grad_key->tape.emplace_back(grad_fn);
+    }
 
     return outputs;
 }
+
+PyObject* GradKeyWrapper::get_priority() {
+    return py::cast(m_key->priority).release().ptr();
+}
+
+void GradKeyWrapper::set_priority(pybind11::handle priority) {
+    m_key->priority = py::cast<int>(priority);
+}
+
 void GradKeyWrapper::attach(PyObject*const* args, size_t nargs) {
     if (nargs != 2) {
         throw py::type_error("expect 2 arguments");
@@ -488,24 +493,21 @@ void GradKey::attach(Tensor* tensor, pybind11::object callback) {
         throw py::value_error("grad key finalized");
     }
 
-    if (tensor->m_grad_info.grad_fn) {
-        if (tensor->m_grad_info.grad_fn->key.lock().get() != this) {
-            PyErr_SetString(PyExc_NotImplementedError, "second order grad");
-            throw pyext17::py_err_set();
-        }
-        if (tensor->m_grad_info->callback) {
+    if (tensor->m_grad_info_dict.count(this)) {
+        if (tensor->m_grad_info_dict.at(this)->callback) {
             throw py::value_error("callback already set on this tensor");
         }
     } else {
-        tensor->m_grad_info.idx = 0;
-        auto& grad_fn = tensor->m_grad_info.grad_fn;
+        auto& grad_info = tensor->m_grad_info_dict[this];
+        grad_info.idx = 0;
+        auto& grad_fn = grad_info.grad_fn;
         grad_fn = std::make_shared<GradFn>();
         grad_fn->key = shared_from_this();
        grad_fn->slots.resize(1);
-        tensor->m_grad_info.insert_after(free_vars_head);
+        grad_info.insert_after(free_vars_head);
         tensor->m_flags |= Flags::GRAD;
     }
-    tensor->m_grad_info.grad_fn->slots[0].callback = std::move(callback);
+    tensor->m_grad_info_dict.at(this).grad_fn->slots[0].callback = std::move(callback);
 }
 
 template<typename T>
@@ -530,8 +532,15 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) {
     active = false;
 
     struct CleanupGuard {
         GradKey* owner;
-        CleanupGuard(GradKey* this_) : owner(this_) {}
-        ~CleanupGuard() {owner->cleanup();}
+        int priority_backup;
+        CleanupGuard(GradKey* this_) : owner(this_) {
+            priority_backup = sm_min_priority;
+            sm_min_priority = owner->priority;
+        }
+        ~CleanupGuard() {
+            owner->cleanup();
+            sm_min_priority = priority_backup;
+        }
     } _cleanup_guard(this);
 
     if (tape.empty()) return;
@@ -542,14 +551,16 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) {
     }
 
     for (size_t i = 0; i < tensors.size(); ++i) {
-        auto& grad_info = tensors[i]->m_tensor->m_grad_info;
-        if (grad_info.grad_fn && grad_info.grad_fn->key.lock().get() == this) {
-            grad_info->grad = grads[i]->m_tensor;
+        if (tensors[i]->m_tensor->m_grad_info_dict.count(this) == 0) {
+            continue;
         }
+        auto& grad_info = tensors[i]->m_tensor->m_grad_info_dict.at(this);
+        grad_info->grad = grads[i]->m_tensor;
     }
 
     std::vector<std::shared_ptr<GradFn>> ref_keeper;
     ref_keeper.reserve(tape.size());
 
     // back-propagation in reverse order
     for (std::ptrdiff_t k = tape.size() - 1; k >= 0; --k) {
         auto&& grad_fn = tape[k].lock();
@@ -619,13 +630,14 @@ PyObject* GradKeyWrapper::is_attached_to(PyObject*const* args, size_t nargs) {
         PyErr_SetString(PyExc_TypeError, "expect Tensor");
         return nullptr;
     }
-    auto&& grad_fn = tw->m_tensor->m_grad_info.grad_fn;
-    if (grad_fn && grad_fn->key.lock() == m_key) {
+    if (tw->m_tensor->m_grad_info_dict.count(m_key.get())) {
         Py_RETURN_TRUE;
     }
     Py_RETURN_FALSE;
 }
 
+int GradKey::sm_min_priority = 0;
+
 GradKey::~GradKey() {
     cleanup();
 }
@@ -635,4 +647,41 @@ std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry() {
     return registry;
 }
 
+void GradInfoCollection::_shrink() {
+    auto pred = [](GradInfo& info){ return !(info.grad_fn) || info.grad_fn->key.expired(); };
+    auto iter = std::remove_if(m_storage.begin(), m_storage.end(), pred);
+    m_storage.erase(iter, m_storage.end());
+}
+
+bool GradInfoCollection::contains(GradKey* key) {
+    _shrink();
+    for (auto&& grad_info: m_storage) {
+        if (grad_info.grad_fn->key.lock().get() == key) {
+            return true;
+        }
+    }
+    return false;
+}
+
+GradInfo& GradInfoCollection::operator[](GradKey* key) {
+    _shrink();
+    for (auto&& grad_info: m_storage) {
+        if (grad_info.grad_fn->key.lock().get() == key) {
+            return grad_info;
+        }
+    }
+    m_storage.emplace_back();
+    return m_storage.back();
+}
+
+GradInfo& GradInfoCollection::at(GradKey* key) {
+    _shrink();
+    for (auto&& grad_info: m_storage) {
+        if (grad_info.grad_fn->key.lock().get() == key) {
+            return grad_info;
+        }
+    }
+    mgb_assert(false);
+}
+
 } // namespace mgb::imperative::python
@@ -26,12 +26,18 @@ struct GradKey : std::enable_shared_from_this<GradKey>, NonCopyableObj {
     bool active = true;
     GradInfo::head_t free_vars_head;
     std::vector<std::weak_ptr<GradFn>> tape;
+    int priority = 0;
 
     ~GradKey();
 
     void attach(Tensor* tensor, pybind11::object callback);
     void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>);
     void cleanup();
+    bool is_blocked() const {
+        return priority < sm_min_priority;
+    }
+private:
+    static int sm_min_priority;
 };
 
 struct GradKeyWrapper {
@@ -44,6 +50,8 @@ struct GradKeyWrapper {
     PyObject* get_name();
     void set_name(pybind11::handle name);
+    PyObject* get_priority();
+    void set_priority(pybind11::handle priority);
     void attach(PyObject*const* args, size_t nargs);
     void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>);
     PyObject* is_attached_to(PyObject*const* args, size_t nargs);
@@ -150,7 +158,7 @@ using GradRuleFn = std::function<apply_result_t(ApplyContext&, CustomBackward::Maker)>;
 std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry();
 
 inline bool input_requires_grad(const ApplyContext& ctx, size_t i) {
-    return bool(ctx.args[i]->m_grad_info.grad_fn);
+    return !ctx.args[i]->m_grad_info_dict.empty();
 }
 
 struct GradRuleFallback : std::exception {};
@@ -15,6 +15,7 @@
 namespace mgb::imperative::python {
 
+struct GradKey;
 struct GradFn;
 struct GradSlot;
@@ -32,6 +33,10 @@ struct GradInfo : GradSlotPtr, intrusive_list::Node<GradInfo, intrusive_list::before_t> {
     GradInfo(GradInfo&&) = default;
     GradInfo& operator=(GradInfo&) = default;
     GradInfo& operator=(GradInfo&&) = default;
+    GradInfo(const GradInfo& rhs): GradInfo(const_cast<GradInfo&>(rhs)){}
+    GradInfo& operator=(const GradInfo& rhs) {
+        return *this = const_cast<GradInfo&>(rhs);
+    }
 };
 
 } // namespace mgb::imperative::python
@@ -182,7 +182,7 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */) {
         if (py::isinstance<PySymbolVar>(py::handle(args[0]))){
             SmallVector<cg::VarNode*> vinputs(nargs);
             for (size_t i = 0; i < nargs; ++i) {
-                vinputs[i] = py::handle(args[i]).cast<PySymbolVar*>()->m_node;
+                vinputs[i] = py::handle(args[i]).cast<PySymbolVar*>()->m_node;
             }
             auto op = ctx.op.get();
             auto rst = OpDef::apply_on_var_node(*op, vinputs);
@@ -17,6 +17,7 @@
 #include "megbrain/imperative/interpreter.h"
 #include "pybind11/pybind11.h"
 
 #include <string>
+#include <unordered_map>
 
 #include "./pyext17.h"
@@ -36,6 +37,8 @@ struct ObjectPtr : B {
 namespace mgb::imperative::python {
 
+struct GradKey;
+
 extern interpreter::Interpreter::Channel* interpreter_for_py;
 
 class SharedHandle {
@@ -58,6 +61,34 @@ public:
 };
 
+// impl in grad.cpp
+class GradInfoCollection {
+private:
+    SmallVector<GradInfo> m_storage;
+protected:
+    void _shrink();
+public:
+    bool contains(GradKey* key);
+    GradInfo& operator[](GradKey* key);
+    GradInfo& at(GradKey* key);
+    bool empty() {
+        _shrink();
+        return m_storage.empty();
+    }
+    auto begin() {
+        _shrink();
+        return m_storage.begin();
+    }
+    auto end() {
+        _shrink();
+        return m_storage.end();
+    }
+    size_t count(GradKey* key) {
+        return contains(key) ? 1 : 0;
+    }
+};
+
 struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
     using flags_t = uint64_t;
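Design note on `GradInfoCollection`: lookups are linear scans over a `SmallVector`, which stays cheap because a tensor is rarely attached to more than a couple of live grad keys, and every accessor calls `_shrink()` first to drop entries whose `GradKey` has expired; that is also why `count()` can only return 0 or 1. A rough Python model of this prune-on-access behaviour (illustrative only, not MegEngine API; all names here are hypothetical):

```python
# Toy model of a collection that lazily evicts entries for dead keys.
import weakref

class PruningCollection:
    def __init__(self):
        self._storage = []  # list of (weakref-to-key, info) pairs

    def _shrink(self):
        # Analogue of GradInfoCollection::_shrink: drop entries whose
        # key has been garbage-collected (cf. grad_fn->key.expired()).
        self._storage = [(k, v) for k, v in self._storage if k() is not None]

    def count(self, key):
        self._shrink()
        return 1 if any(k() is key for k, _ in self._storage) else 0

    def __getitem__(self, key):
        # Analogue of operator[]: return the live entry, or
        # default-construct a new one at the back of the storage.
        self._shrink()
        for k, v in self._storage:
            if k() is key:
                return v
        info = {}
        self._storage.append((weakref.ref(key), info))
        return info
```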
@@ -69,7 +100,7 @@ struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
     flags_t m_flags = 0;
 
-    GradInfo m_grad_info;
+    GradInfoCollection m_grad_info_dict;
     TraceInfo m_trace_info;
     SharedHandle m_handle;
     std::string user_custom_name;
@@ -88,7 +119,7 @@ struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
     inline std::shared_ptr<Tensor> copy() {
         auto ret = std::make_shared<Tensor>(m_handle);
         ret->m_flags = m_flags;
-        ret->m_grad_info = m_grad_info;
+        ret->m_grad_info_dict = m_grad_info_dict;
         ret->m_trace_info = m_trace_info;
         ret->m_var = m_var;
         return ret;
@@ -108,21 +108,24 @@ def test_grad_2():
     np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6)
 
 
-@pytest.mark.skip(reason="high order gradient was not implemented yet")
 def test_2nd_grad():
     x_np = np.random.rand(10).astype("float32")
     x = as_tensor(x_np)
     ones = as_tensor(np.ones_like(x_np))
 
     grad = Grad().wrt(x, callback=save_to(x))
+    grad._priority = -1
     grad2 = Grad().wrt(x, callback=save_to(x))
+    grad2._priority = 0
 
     y = cos(x)
 
     grad(y, ones)
+    z = x.grad
     np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5)
 
-    grad2(x.grad, ones)
+    x.grad = None
+    grad2(z, ones)
 
     np.testing.assert_almost_equal(x.grad.numpy(), -np.cos(x_np))
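Why the re-enabled test passes: `GradKey::backward` temporarily lowers `sm_min_priority` to the priority of the key being processed, and `apply_grad` skips any key for which `is_blocked()` holds (priority strictly below that minimum). `grad` runs its backward at priority -1 while `grad2` sits at priority 0, so `grad2` keeps recording through the first backward pass and can then differentiate `z = x.grad` a second time. The same ordering falls out of the automatic `_global_priority` bookkeeping when managers are simply nested; a speculative sketch (the nesting pattern and every name below are mine, not part of this patch):

```python
# Speculative sketch: second-order gradients via nested GradManagers.
# record() hands out decreasing priorities, so the outer manager (which
# records first) keeps watching while the inner one runs its backward.
import numpy as np
import megengine as mge
import megengine.functional as F
from megengine.autodiff import GradManager

x = mge.Parameter(np.random.rand(10).astype("float32"))

gm_outer = GradManager()
gm_outer.attach([x])
gm_inner = GradManager()
gm_inner.attach([x])

with gm_outer:                  # records first -> higher priority
    with gm_inner:              # records second -> lower priority
        y = F.cos(x)
        gm_inner.backward(y)    # x.grad == -sin(x), still seen by gm_outer
    z = x.grad
    x.grad = None
    gm_outer.backward(z.sum())  # x.grad == -cos(x)
```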
@@ -398,20 +398,6 @@ OP_TRAIT_REG(Copy, Copy)
         .fallback();
 }} // copy
 
-namespace { namespace identity {
-auto apply_on_var_node(
-        const OpDef& def,
-        const VarNodeArray& inputs) {
-    auto&& op = def.cast_final_safe<Identity>();
-    mgb_assert(inputs.size() == 1);
-    OperatorNodeConfig config{op.make_name()};
-    return opr::Identity::make(inputs[0], config);
-}
-
-OP_TRAIT_REG(Identity, Identity)
-    .apply_on_var_node(apply_on_var_node)
-    .fallback();
-}} // identity
-
 namespace { namespace assert_equal {
 auto apply_on_var_node(
         const OpDef& def,
@@ -9,6 +9,7 @@
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
 
+#include "megbrain/imperative/ops/autogen.h"
 #include "megbrain/imperative/ops/utility.h"
 #include "megbrain/imperative/ops/opr_attr.h"
 #include "megbrain/opr/utility.h"
@@ -32,4 +33,25 @@ OP_TRAIT_REG(FastpathCopy,FastpathCopy)
 
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(FastpathCopy);
 
+namespace { namespace identity {
+auto apply_on_var_node(
+        const OpDef& def,
+        const VarNodeArray& inputs) {
+    auto&& op = def.cast_final_safe<Identity>();
+    mgb_assert(inputs.size() == 1);
+    OperatorNodeConfig config{op.make_name()};
+    return opr::Identity::make(inputs[0], config);
+}
+
+auto apply_on_physical_tensor(
+        const OpDef& def,
+        const SmallVector<TensorPtr>& inputs) {
+    return SmallVector<TensorPtr>{inputs[0]};
+}
+
+OP_TRAIT_REG(Identity, Identity)
+    .apply_on_var_node(apply_on_var_node)
+    .apply_on_physical_tensor(apply_on_physical_tensor)
+    .fallback();
+}} // identity
+
 } // namespace mgb::imperative