GitOrigin-RevId: 5fc308f87a
tags/v1.5.0
| @@ -345,12 +345,12 @@ class GradManager: | |||||
| def __exit__(self, exc_type, exc_val, exc_tb): | def __exit__(self, exc_type, exc_val, exc_tb): | ||||
| self.release() | self.release() | ||||
| def __and__(self, other): | |||||
| def __or__(self, other): | |||||
| if isinstance(other, GradManager): | if isinstance(other, GradManager): | ||||
| return GradManagerGroup([self, other]) | return GradManagerGroup([self, other]) | ||||
| return NotImplemented | return NotImplemented | ||||
| __rand__ = __and__ | |||||
| __ror__ = __or__ | |||||
| class GradManagerGroup: | class GradManagerGroup: | ||||
| @@ -364,8 +364,6 @@ class GradManagerGroup: | |||||
| return NotImplemented | return NotImplemented | ||||
| return GradManagerGroup([*self._gms, *other._gms]) | return GradManagerGroup([*self._gms, *other._gms]) | ||||
| __and__ = merge_with | |||||
| __rand__ = merge_with | |||||
| __or__ = merge_with | __or__ = merge_with | ||||
| __ror__ = merge_with | __ror__ = merge_with | ||||
| @@ -468,7 +468,7 @@ PyObject* GradKeyWrapper::get_priority() { | |||||
| } | } | ||||
| void GradKeyWrapper::set_priority(pybind11::handle priority) { | void GradKeyWrapper::set_priority(pybind11::handle priority) { | ||||
| m_key->name = py::cast<int>(priority); | |||||
| m_key->priority = py::cast<int>(priority); | |||||
| } | } | ||||
| void GradKeyWrapper::attach(PyObject*const* args, size_t nargs) { | void GradKeyWrapper::attach(PyObject*const* args, size_t nargs) { | ||||
| @@ -535,7 +535,7 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr | |||||
| size_t priority_backup; | size_t priority_backup; | ||||
| CleanupGuard(GradKey* this_) : owner(this_) { | CleanupGuard(GradKey* this_) : owner(this_) { | ||||
| priority_backup = sm_min_priority; | priority_backup = sm_min_priority; | ||||
| sm_min_priority = owner->priority; | |||||
| sm_min_priority = owner->priority + 1; | |||||
| } | } | ||||
| ~CleanupGuard() { | ~CleanupGuard() { | ||||
| owner->cleanup(); | owner->cleanup(); | ||||
| @@ -636,7 +636,7 @@ PyObject* GradKeyWrapper::is_attached_to(PyObject*const* args, size_t nargs) { | |||||
| Py_RETURN_FALSE; | Py_RETURN_FALSE; | ||||
| } | } | ||||
| int GradKey::sm_min_priority = 0; | |||||
| int GradKey::sm_min_priority = std::numeric_limits<int>::min(); | |||||
| GradKey::~GradKey() { | GradKey::~GradKey() { | ||||
| cleanup(); | cleanup(); | ||||
| @@ -966,6 +966,7 @@ void init_tensor(py::module m) { | |||||
| .def<&GradKeyWrapper::attach>("attach") | .def<&GradKeyWrapper::attach>("attach") | ||||
| .def<&GradKeyWrapper::is_attached_to>("is_attached_to") | .def<&GradKeyWrapper::is_attached_to>("is_attached_to") | ||||
| .def_getset<&GradKeyWrapper::get_name, &GradKeyWrapper::set_name>("name") | .def_getset<&GradKeyWrapper::get_name, &GradKeyWrapper::set_name>("name") | ||||
| .def_getset<&GradKeyWrapper::get_priority, &GradKeyWrapper::set_priority>("priority") | |||||
| .finalize(); | .finalize(); | ||||
| if (!grad_key_type) throw py::error_already_set(); | if (!grad_key_type) throw py::error_already_set(); | ||||
| py::setattr(m, "GradKey", grad_key_type); | py::setattr(m, "GradKey", grad_key_type); | ||||
| @@ -279,3 +279,69 @@ def test_broadcast_grad(trace_mode): | |||||
| func() | func() | ||||
| worker() | worker() | ||||
| def test_2nd_grad_with_manager(): | |||||
| x_np = np.random.rand(10).astype("float32") | |||||
| x = mge.tensor(x_np) | |||||
| gm = GradManager().attach([x]) | |||||
| gm2 = GradManager().attach([x]) | |||||
| with gm: | |||||
| with gm2: | |||||
| y = F.cos(x) | |||||
| gm2.backward(y) | |||||
| np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) | |||||
| gm.backward(x.grad) | |||||
| np.testing.assert_almost_equal( | |||||
| x.grad.numpy(), -np.sin(x_np) - np.cos(x_np), decimal=5 | |||||
| ) | |||||
| def test_grad_manager_group(): | |||||
| x_np = np.random.rand(10).astype("float32") | |||||
| x = mge.tensor(x_np) | |||||
| gm = GradManager().attach([x]) | |||||
| gm2 = GradManager().attach([x]) | |||||
| with gm | gm2: | |||||
| y = F.cos(x) | |||||
| gm.backward(y) | |||||
| gm2.backward(y) | |||||
| np.testing.assert_almost_equal(x.grad.numpy(), -2 * np.sin(x_np), decimal=5) | |||||
| x.grad = None | |||||
| def test_grad_manager_group_visibility(): | |||||
| x_np = np.random.rand(10).astype("float32") | |||||
| x = mge.tensor(x_np) | |||||
| gm = GradManager().attach([x]) | |||||
| gm2 = GradManager().attach([x]) | |||||
| with gm | gm2: | |||||
| y = F.cos(x) | |||||
| gm2.backward(y) | |||||
| np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) | |||||
| gm.backward(x.grad) | |||||
| np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) | |||||
| def test_grad_manager_visibility_by_order(): | |||||
| x_np = np.random.rand(10).astype("float32") | |||||
| x = mge.tensor(x_np) | |||||
| gm = GradManager().attach([x]) | |||||
| gm2 = GradManager().attach([x]) | |||||
| with gm2: | |||||
| with gm: | |||||
| y = F.cos(x) | |||||
| gm2.backward(y) | |||||
| np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) | |||||
| gm.backward(x.grad) | |||||
| np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) | |||||
| @@ -126,7 +126,7 @@ def test_2nd_grad(): | |||||
| x.grad = None | x.grad = None | ||||
| grad2(z, ones) | grad2(z, ones) | ||||
| np.testing.assert_almost_equal(x.grad.numpy(), -np.cos(x_np)) | |||||
| np.testing.assert_almost_equal(x.grad.numpy(), -np.cos(x_np), decimal=5) | |||||
| def test_grad_with_tensor_wrapper(): | def test_grad_with_tensor_wrapper(): | ||||