GitOrigin-RevId: 3d06e3db3c
tags/v1.0.0-rc1
| @@ -29,6 +29,7 @@ class GradManager: | |||
| def register_after_backward_callback(self, callback): | |||
| self._after_backward_callback.append(callback) | |||
| return self | |||
| def backward(self, ys, dys=None): | |||
| global backwarding_grad_manager | |||
| @@ -177,6 +177,13 @@ class Grad: | |||
| dys = aslist(dys) | |||
| assert len(ys) == len(dys) | |||
| ids = [i for i, y in enumerate(ys) if self in y._extra_data.keys()] | |||
| if len(ids) == 0: | |||
| return | |||
| ys = [y for i, y in enumerate(ys) if i in ids] | |||
| dys = [dy for i, dy in enumerate(dys) if i in ids] | |||
| # ys is changed to a list of VariableNode which contains more information | |||
| # such as OpNode, callback, etc. | |||
| ys = [i._extra_data[self].node for i in ys] | |||
| @@ -20,8 +20,8 @@ from ..core.autodiff.grad import ( | |||
| from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend | |||
| from ..core.tensor.core import apply | |||
| from ..core.tensor.tensor import Tensor, tensor_apply | |||
| from ..tensor import tensor | |||
| from ..device import get_default_device | |||
| from ..tensor import tensor | |||
| from .group import WORLD, Group, get_backend, get_client, get_mm_server_addr, get_rank | |||
| __all__ = [ | |||
| @@ -11,7 +11,7 @@ from typing import Iterable, Union | |||
| import numpy as np | |||
| from ..functional import sqrt | |||
| from ..tensor_nn import Buffer, Parameter | |||
| from ..tensor_nn import Parameter | |||
| from .optimizer import Optimizer | |||
| @@ -63,16 +63,7 @@ class Adadelta(Optimizer): | |||
| for param in param_group["params"]: | |||
| if param.__wrapped__ in self._grad_skip: | |||
| self._grad_skip.remove(param.__wrapped__) | |||
| continue | |||
| if not isinstance(param.grad, Buffer): | |||
| raise TypeError( | |||
| "grad must be a Buffer, maybe you forget to call backward()?" | |||
| ) | |||
| if not param.requires_grad: | |||
| if not param.requires_grad or "grad" not in param.__dict__: | |||
| continue | |||
| states = self._state[param] | |||
| @@ -91,5 +82,3 @@ class Adadelta(Optimizer): | |||
| acc_delta = rho * acc_delta + (1 - rho) * delta ** 2 | |||
| states["square_avg"]._reset(square_avg) | |||
| states["acc_delta"]._reset(acc_delta) | |||
| assert len(self._grad_skip) == 0 | |||
| @@ -11,7 +11,7 @@ from typing import Iterable, Union | |||
| import numpy as np | |||
| from ..functional import sqrt | |||
| from ..tensor_nn import Buffer, Parameter | |||
| from ..tensor_nn import Parameter | |||
| from .optimizer import Optimizer | |||
| @@ -62,16 +62,7 @@ class Adagrad(Optimizer): | |||
| for param in param_group["params"]: | |||
| if param.__wrapped__ in self._grad_skip: | |||
| self._grad_skip.remove(param.__wrapped__) | |||
| continue | |||
| if not isinstance(param.grad, Buffer): | |||
| raise TypeError( | |||
| "grad must be a Buffer, maybe you forget to call backward()?" | |||
| ) | |||
| if not param.requires_grad: | |||
| if not param.requires_grad or "grad" not in param.__dict__: | |||
| continue | |||
| states = self._state[param] | |||
| @@ -87,4 +78,3 @@ class Adagrad(Optimizer): | |||
| clr = lr / (1 + (step - 1) * lr_decay) | |||
| param -= clr * delta | |||
| assert len(self._grad_skip) == 0 | |||
| @@ -8,7 +8,7 @@ | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from typing import Iterable, Tuple, Union | |||
| from ..tensor_nn import Buffer, Parameter | |||
| from ..tensor_nn import Parameter | |||
| from .optimizer import Optimizer | |||
| @@ -59,18 +59,9 @@ class Adam(Optimizer): | |||
| for param in param_group["params"]: | |||
| if param.__wrapped__ in self._grad_skip: | |||
| self._grad_skip.remove(param.__wrapped__) | |||
| if not param.requires_grad or "grad" not in param.__dict__: | |||
| continue | |||
| if not param.requires_grad: | |||
| continue | |||
| if not isinstance(param.grad, Buffer): | |||
| raise TypeError( | |||
| "grad must be a Buffer, maybe you forget to call backward()?" | |||
| ) | |||
| grad = param.grad | |||
| if weight_decay != 0.0: | |||
| grad += param * weight_decay | |||
| @@ -91,5 +82,3 @@ class Adam(Optimizer): | |||
| # not inplace change, need to update underlying tensor handler in state | |||
| states["exp_avg"]._reset(exp_avg) | |||
| states["exp_avg_sq"]._reset(exp_avg_sq) | |||
| assert len(self._grad_skip) == 0 | |||
| @@ -8,7 +8,7 @@ | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from typing import Iterable, Union | |||
| from ..tensor_nn import Buffer, Parameter | |||
| from ..tensor_nn import Parameter | |||
| from .optimizer import Optimizer | |||
| @@ -52,7 +52,7 @@ class SGD(Optimizer): | |||
| momentum = param_group["momentum"] | |||
| for param in param_group["params"]: | |||
| if not param.requires_grad: | |||
| if not param.requires_grad or "grad" not in param.__dict__: | |||
| continue | |||
| grad = param.grad | |||
| @@ -9,6 +9,7 @@ | |||
| import numpy as np | |||
| import megengine | |||
| import megengine.autodiff as ad | |||
| import megengine.optimizer as optimizer | |||
| from megengine import Parameter, tensor | |||
| from megengine.module import Module | |||
| @@ -37,8 +38,9 @@ class Simple2(Module): | |||
| def test_advance_indexing(): | |||
| net = Simple() | |||
| gm = ad.GradManager().register(net.parameters()) | |||
| optim = optimizer.SGD(net.parameters(), lr=1.0) | |||
| optim.zero_grad() | |||
| optim.clear_grad() | |||
| dshape = (10, 10) | |||
| raw_data = np.arange(100).reshape(dshape).astype(np.float32) | |||
| @@ -46,9 +48,9 @@ def test_advance_indexing(): | |||
| data = tensor(raw_data) | |||
| mask = tensor(raw_mask) | |||
| answer = 1.0 - raw_data[raw_mask].sum() | |||
| with optim.record(): | |||
| with gm.record(): | |||
| loss = net(data, mask).sum() | |||
| optim.backward(loss) | |||
| gm.backward(loss) | |||
| optim.step() | |||
| np.testing.assert_almost_equal(net.a.numpy(), np.array([answer]).astype(np.float32)) | |||
| @@ -56,15 +58,16 @@ def test_advance_indexing(): | |||
| def test_advance_indexing_with_subtensor(): | |||
| net = Simple2() | |||
| gm = ad.GradManager().register(net.parameters()) | |||
| optim = optimizer.SGD(net.parameters(), lr=1.0) | |||
| optim.zero_grad() | |||
| optim.clear_grad() | |||
| dshape = (2, 3, 4, 3, 4, 2) | |||
| raw_data = np.arange(576).reshape(dshape).astype(np.float32) | |||
| data = tensor(raw_data) | |||
| answer = 1.0 - raw_data[1, ..., :, 0:4:2, 0:2].sum() | |||
| with optim.record(): | |||
| with gm.record(): | |||
| loss = net(data).sum() | |||
| optim.backward(loss) | |||
| gm.backward(loss) | |||
| optim.step() | |||
| np.testing.assert_almost_equal(net.a.numpy(), np.array([answer]).astype(np.float32)) | |||
| @@ -9,6 +9,7 @@ | |||
| import numpy as np | |||
| import megengine | |||
| import megengine.autodiff as ad | |||
| import megengine.optimizer as optimizer | |||
| from megengine import Parameter, tensor | |||
| from megengine.module import Module | |||
| @@ -27,14 +28,15 @@ class Simple(Module): | |||
| def test_ai(): | |||
| net = Simple() | |||
| gm = ad.GradManager().register(net.parameters()) | |||
| optim = optimizer.SGD(net.parameters(), lr=1.0) | |||
| optim.zero_grad() | |||
| optim.clear_grad() | |||
| dshape = (10, 10) | |||
| data = tensor(np.ones(dshape).astype(np.float32)) | |||
| with optim.record(): | |||
| with gm.record(): | |||
| loss = net(data).sum() | |||
| optim.backward(loss) | |||
| gm.backward(loss) | |||
| optim.step() | |||
| np.testing.assert_almost_equal( | |||
| net.a.numpy(), np.array([1.0 - dshape[0]]).astype(np.float32) | |||
| @@ -10,6 +10,7 @@ import numpy as np | |||
| import pytest | |||
| import megengine | |||
| import megengine.autodiff as ad | |||
| import megengine.optimizer as optimizer | |||
| from megengine import Parameter, tensor | |||
| from megengine.module import BatchNorm2d | |||
| @@ -24,13 +25,14 @@ def test_frozen_bn(): | |||
| saved_wt = m.weight.numpy() | |||
| saved_bias = m.bias.numpy() | |||
| gm = ad.GradManager().register(m.parameters()) | |||
| optim = optimizer.SGD(m.parameters(), lr=1.0) | |||
| optim.zero_grad() | |||
| optim.clear_grad() | |||
| data = np.random.random((6, nchannel, 2, 2)).astype("float32") | |||
| with optim.record(): | |||
| with gm.record(): | |||
| loss = m(data).mean() | |||
| optim.backward(loss) | |||
| gm.backward(loss) | |||
| optim.step() | |||
| np.testing.assert_equal(m.running_var.numpy(), saved_var) | |||
| @@ -44,13 +46,14 @@ def test_bn_no_track_stat(): | |||
| nchannel = 3 | |||
| m = BatchNorm2d(nchannel, track_running_stats=False) | |||
| gm = ad.GradManager().register(m.parameters()) | |||
| optim = optimizer.SGD(m.parameters(), lr=1.0) | |||
| optim.zero_grad() | |||
| optim.clear_grad() | |||
| data = np.random.random((6, nchannel, 2, 2)).astype("float32") | |||
| with optim.record(): | |||
| with gm.record(): | |||
| loss = m(data).sum() | |||
| optim.backward(loss) | |||
| gm.backward(loss) | |||
| optim.step() | |||
| @@ -65,13 +68,14 @@ def test_bn_no_track_stat2(): | |||
| saved_mean = m.running_mean.numpy() | |||
| assert saved_mean is not None | |||
| gm = ad.GradManager().register(m.parameters()) | |||
| optim = optimizer.SGD(m.parameters(), lr=1.0) | |||
| optim.zero_grad() | |||
| optim.clear_grad() | |||
| data = np.random.random((6, nchannel, 2, 2)).astype("float32") | |||
| with optim.record(): | |||
| with gm.record(): | |||
| loss = m(data).sum() | |||
| optim.backward(loss) | |||
| gm.backward(loss) | |||
| optim.step() | |||
| np.testing.assert_equal(m.running_var.numpy(), saved_var) | |||
| @@ -12,6 +12,7 @@ import numpy as np | |||
| import pytest | |||
| import megengine as mge | |||
| import megengine.autodiff as ad | |||
| import megengine.functional as F | |||
| from megengine import Tensor | |||
| from megengine.module import Linear, Module | |||
| @@ -76,12 +77,13 @@ def test_training_converge(): | |||
| opt = SGD( | |||
| net.parameters(requires_grad=True), lr=0.01, momentum=0.9, weight_decay=5e-4 | |||
| ) | |||
| gm = ad.GradManager().register(net.parameters()) | |||
| def train(data, label): | |||
| with opt.record(): | |||
| with gm.record(): | |||
| pred = net(data) | |||
| loss = F.cross_entropy_with_softmax(pred, label) | |||
| opt.backward(loss) | |||
| gm.backward(loss) | |||
| return loss | |||
| def infer(data): | |||
| @@ -93,7 +95,7 @@ def test_training_converge(): | |||
| for data, label in itertools.islice(train_dataset, 2000): | |||
| data = Tensor(data, dtype=np.float32) | |||
| label = Tensor(label, dtype=np.int32) | |||
| opt.zero_grad() | |||
| opt.clear_grad() | |||
| loss = train(data, label) | |||
| opt.step() | |||
| losses.append(loss.numpy()) | |||
| @@ -15,6 +15,7 @@ import numpy as np | |||
| import pytest | |||
| import megengine as mge | |||
| import megengine.autodiff as ad | |||
| import megengine.functional as F | |||
| from megengine import jit | |||
| from megengine.core._trace_option import set_tensor_shape | |||
| @@ -89,11 +90,11 @@ class MnistNet(Module): | |||
| return x | |||
| def train(data, label, net, opt): | |||
| with opt.record(): | |||
| def train(data, label, net, opt, gm): | |||
| with gm.record(): | |||
| pred = net(data) | |||
| loss = F.cross_entropy_with_softmax(pred, label) | |||
| opt.backward(loss) | |||
| gm.backward(loss) | |||
| return loss | |||
| @@ -116,12 +117,13 @@ def update_model(model_path): | |||
| net.load_state_dict(checkpoint["net_init"]) | |||
| lr = checkpoint["sgd_lr"] | |||
| opt = SGD(net.parameters(), lr=lr) | |||
| gm = ad.GradManager().register(net.parameters()) | |||
| data = Tensor(checkpoint["data"], dtype=np.float32) | |||
| label = Tensor(checkpoint["label"], dtype=np.int32) | |||
| opt.zero_grad() | |||
| loss = train(data, label, net=net, opt=opt) | |||
| opt.clear_grad() | |||
| loss = train(data, label, net, opt, gm) | |||
| opt.step() | |||
| xpu_name = get_xpu_name() | |||
| @@ -150,6 +152,7 @@ def run_train( | |||
| net.load_state_dict(checkpoint["net_init"]) | |||
| lr = checkpoint["sgd_lr"] | |||
| opt = SGD(net.parameters(), lr=lr) | |||
| gm = ad.GradManager().register(net.parameters()) | |||
| data = Tensor(checkpoint["data"], dtype=np.float32) | |||
| label = Tensor(checkpoint["label"], dtype=np.int32) | |||
| @@ -165,8 +168,8 @@ def run_train( | |||
| sublinear_memory_config=sublinear_memory_config, | |||
| ) | |||
| opt.zero_grad() | |||
| loss = train_func(data, label, net=net, opt=opt) | |||
| opt.clear_grad() | |||
| loss = train_func(data, label, net, opt, gm) | |||
| opt.step() | |||
| assertTensorClose(loss.numpy(), checkpoint["loss"], max_err=max_err) | |||
| @@ -9,6 +9,7 @@ | |||
| import numpy as np | |||
| import megengine | |||
| import megengine.autodiff as ad | |||
| import megengine.optimizer as optimizer | |||
| from megengine import Parameter, tensor | |||
| from megengine.module import Module | |||
| @@ -30,13 +31,14 @@ def test_detach(): | |||
| net = Simple() | |||
| optim = optimizer.SGD(net.parameters(), lr=1.0) | |||
| optim.zero_grad() | |||
| optim.clear_grad() | |||
| gm = ad.GradManager().register(net.parameters()) | |||
| dshape = (10, 10) | |||
| data = tensor(np.ones(dshape).astype(np.float32)) | |||
| with optim.record(): | |||
| with gm.record(): | |||
| loss = net(data).sum() | |||
| optim.backward(loss) | |||
| gm.backward(loss) | |||
| optim.step() | |||
| np.testing.assert_equal(net.a.numpy(), np.array([1.0]).astype(np.float32)) | |||
| np.testing.assert_equal( | |||
| @@ -18,6 +18,7 @@ import numpy as np | |||
| import pytest | |||
| import megengine as mge | |||
| import megengine.autodiff as ad | |||
| import megengine.distributed as dist | |||
| import megengine.functional as F | |||
| from megengine.device import get_default_device, set_default_device | |||
| @@ -94,11 +95,13 @@ class MnistNet(Module): | |||
| return x | |||
| def train(data, label, net, opt): | |||
| with opt.record(): | |||
| def train(data, label, net, opt, gm): | |||
| opt.clear_grad() | |||
| with gm.record(): | |||
| pred = net(data) | |||
| loss = F.cross_entropy_with_softmax(pred, label) | |||
| opt.backward(loss) | |||
| gm.backward(loss) | |||
| opt.step() | |||
| return loss | |||
| @@ -111,7 +114,7 @@ def update_model(model_path): | |||
| .. code-block:: python | |||
| from test_correctness import update_model | |||
| from test_dp_correctness import update_model | |||
| update_model('mnist_model_with_test.mge') # for gpu | |||
| update_model('mnist_model_with_test_cpu.mge') # for cpu | |||
| @@ -122,6 +125,11 @@ def update_model(model_path): | |||
| lr = checkpoint["sgd_lr"] | |||
| opt = SGD(net.parameters(), lr=lr) | |||
| gm = ad.GradManager() | |||
| gm.register( | |||
| net.parameters(), callbacks=[dist.make_allreduce_cb("MEAN", dist.WORLD)] | |||
| ) | |||
| data = Tensor(checkpoint["data"], dtype=np.float32) | |||
| label = Tensor(checkpoint["label"], dtype=np.int32) | |||
| @@ -158,24 +166,23 @@ def run_test( | |||
| def worker(rank, max_err): | |||
| dist.init_process_group("localhost", port, p_num, rank, rank) | |||
| set_default_device(device="gpu{}".format(dist.get_rank())) | |||
| net = MnistNet(has_bn=True) | |||
| net.load_state_dict(checkpoint["net_init"]) | |||
| lr = checkpoint["sgd_lr"] | |||
| opt = SGD(net.parameters(), reduce_method="mean", lr=lr) | |||
| opt = SGD(net.parameters(), lr=lr) | |||
| gm = ad.GradManager() | |||
| gm.register( | |||
| net.parameters(), callbacks=[dist.make_allreduce_cb("MEAN", dist.WORLD)] | |||
| ) | |||
| # use same data and label for all gpu's | |||
| # such that the result does not depend on number of gpu | |||
| data_train = Tensor(data) | |||
| label_train = Tensor(label) | |||
| train_func = train | |||
| opt.zero_grad() | |||
| loss = train_func(data_train, label_train, net=net, opt=opt) | |||
| opt.step() | |||
| loss = train(data_train, label_train, net, opt, gm) | |||
| print("{} loss {}".format(get_default_device(), loss.numpy()[0])) | |||
| assertTensorClose(loss.numpy(), checkpoint["loss"], max_err=max_err) | |||
| if dist.get_rank(): | |||
| @@ -12,6 +12,7 @@ import numpy as np | |||
| import pytest | |||
| import megengine | |||
| import megengine.autodiff as ad | |||
| import megengine.optimizer as optimizer | |||
| from megengine import Parameter, tensor | |||
| from megengine.module import Module | |||
| @@ -31,12 +32,13 @@ def test_hello_world(): | |||
| net = Simple() | |||
| optim = optimizer.SGD(net.parameters(), lr=1.0) | |||
| optim.zero_grad() | |||
| optim.clear_grad() | |||
| gm = ad.GradManager().register(net.parameters()) | |||
| data = tensor([2.34]) | |||
| with optim.record(): | |||
| with gm.record(): | |||
| loss = net(data) | |||
| optim.backward(loss) | |||
| gm.backward(loss) | |||
| optim.step() | |||
| np.testing.assert_almost_equal( | |||
| net.a.numpy(), np.array([1.23 - 2.34]).astype(np.float32) | |||
| @@ -8,6 +8,7 @@ | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import numpy as np | |||
| import megengine.autodiff as ad | |||
| import megengine.functional as F | |||
| from megengine import Parameter, optimizer | |||
| from megengine.jit import trace | |||
| @@ -43,6 +44,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False): | |||
| net = Simple() | |||
| opt = getattr(optimizer, opt_str)(net.parameters(), **test_case) | |||
| check_func = check_class(net, **test_case) | |||
| gm = ad.GradManager().register(net.parameters()) | |||
| step = 0 | |||
| data_shape = (2, 28) | |||
| @@ -54,11 +56,11 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False): | |||
| check_func.lr += 0.01 | |||
| data = tensor(np.random.random(data_shape).astype(np.float32)) | |||
| opt.zero_grad() | |||
| with opt.record(): | |||
| opt.clear_grad() | |||
| with gm.record(): | |||
| pred = net(data) | |||
| loss = pred.sum() | |||
| opt.backward(loss) | |||
| gm.backward(loss) | |||
| ori_params = TensorDict() | |||
| for param in net.parameters(): | |||
| @@ -1,6 +1,7 @@ | |||
| import numpy as np | |||
| import megengine as mge | |||
| import megengine.autodiff as ad | |||
| import megengine.optimizer as optimizer | |||
| from megengine import Parameter, tensor | |||
| from megengine.core.tensor.raw_tensor import RawTensor | |||
| @@ -21,13 +22,14 @@ def test_save_load(): | |||
| net = Simple() | |||
| optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9) | |||
| optim.zero_grad() | |||
| optim.clear_grad() | |||
| gm = ad.GradManager().register(net.parameters()) | |||
| data = tensor([2.34]) | |||
| with optim.record(): | |||
| with gm.record(): | |||
| loss = net(data) | |||
| optim.backward(loss) | |||
| gm.backward(loss) | |||
| optim.step() | |||
| @@ -53,9 +55,9 @@ def test_save_load(): | |||
| optim.load_state_dict(checkpoint["opt_state"]) | |||
| print("load done") | |||
| with optim.record(): | |||
| with gm.record(): | |||
| loss = net([1.23]) | |||
| optim.backward(loss) | |||
| gm.backward(loss) | |||
| optim.step() | |||
| # Restore device | |||
| @@ -9,6 +9,7 @@ | |||
| import numpy as np | |||
| import megengine | |||
| import megengine.autodiff as ad | |||
| import megengine.optimizer as optimizer | |||
| from megengine import Parameter, tensor | |||
| from megengine.jit import trace | |||
| @@ -29,14 +30,15 @@ def test_sgd_momentum(): | |||
| net = Simple() | |||
| optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9) | |||
| optim.zero_grad() | |||
| optim.clear_grad() | |||
| gm = ad.GradManager().register(net.parameters()) | |||
| data = tensor([2.34]) | |||
| # do a step of train | |||
| with optim.record(): | |||
| with gm.record(): | |||
| loss = net(data) | |||
| optim.backward(loss) | |||
| gm.backward(loss) | |||
| optim.step() | |||
| np.testing.assert_almost_equal(optim._state[net.a]["momentum_buffer"].numpy(), 2.34) | |||
| @@ -48,10 +50,10 @@ def test_sgd_momentum(): | |||
| np.testing.assert_almost_equal(optim._state[net.a]["momentum_buffer"].numpy(), 2.34) | |||
| # do a step of train | |||
| optim.zero_grad() | |||
| with optim.record(): | |||
| optim.clear_grad() | |||
| with gm.record(): | |||
| loss = net(data) | |||
| optim.backward(loss) | |||
| gm.backward(loss) | |||
| optim.step() | |||
| np.testing.assert_almost_equal(loss.numpy(), 2.34 * (1.23 - 2.34), 5) | |||
| @@ -9,6 +9,7 @@ import copy | |||
| import numpy as np | |||
| import megengine.autodiff as ad | |||
| import megengine.functional as F | |||
| import megengine.optimizer as optimizer | |||
| from megengine import Parameter | |||
| @@ -41,13 +42,14 @@ def test_single_input(): | |||
| return x | |||
| net = Simple(av) | |||
| optim = optimizer.SGD(net.parameters(), lr=1.0) | |||
| optim.zero_grad() | |||
| gm = ad.GradManager().register(net.parameters()) | |||
| opt = optimizer.SGD(net.parameters(), lr=1.0) | |||
| with optim.record(): | |||
| opt.clear_grad() | |||
| with gm.record(): | |||
| loss = net() | |||
| optim.backward(loss.sum()) | |||
| optim.step() | |||
| gm.backward(loss.sum()) | |||
| opt.step() | |||
| np.testing.assert_almost_equal(loss.numpy(), (av * 10)) | |||
| np.testing.assert_almost_equal(net.a.numpy(), (av - 10)) | |||
| @@ -79,13 +81,14 @@ def test_multi_input(): | |||
| return x | |||
| net = Simple(av, bv) | |||
| optim = optimizer.SGD(net.parameters(), lr=1.0) | |||
| optim.zero_grad() | |||
| gm = ad.GradManager().register(net.parameters()) | |||
| opt = optimizer.SGD(net.parameters(), lr=1.0) | |||
| with optim.record(): | |||
| opt.clear_grad() | |||
| with gm.record(): | |||
| loss = net() | |||
| optim.backward(loss.sum()) | |||
| optim.step() | |||
| gm.backward(loss.sum()) | |||
| opt.step() | |||
| np.testing.assert_almost_equal(loss.numpy(), (av * bv)) | |||
| np.testing.assert_almost_equal(net.a.numpy(), (av - 2 * bv)) | |||
| @@ -118,13 +121,14 @@ def test_multi_output(): | |||
| return x + y | |||
| net = Simple(av, bv) | |||
| optim = optimizer.SGD(net.parameters(), lr=1.0) | |||
| optim.zero_grad() | |||
| gm = ad.GradManager().register(net.parameters()) | |||
| opt = optimizer.SGD(net.parameters(), lr=1.0) | |||
| with optim.record(): | |||
| opt.clear_grad() | |||
| with gm.record(): | |||
| loss = net() | |||
| optim.backward(loss.sum()) | |||
| optim.step() | |||
| gm.backward(loss.sum()) | |||
| opt.step() | |||
| np.testing.assert_almost_equal(loss.numpy(), (av * bv + av + bv), decimal=6) | |||
| np.testing.assert_almost_equal(net.a.numpy(), (av - bv - 1), decimal=6) | |||