| @@ -10,6 +10,7 @@ from .adadelta import Adadelta | |||
| from .adagrad import Adagrad | |||
| from .adam import Adam | |||
| from .adamw import AdamW | |||
| from .clip_grad import * | |||
| from .lr_scheduler import LRScheduler | |||
| from .multi_step_lr import MultiStepLR | |||
| from .optimizer import Optimizer | |||
| @@ -0,0 +1,72 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # pylint: disable=redefined-builtin | |||
| from typing import Iterable, Union | |||
| from ..core._imperative_rt.core2 import pop_scope, push_scope | |||
| from ..functional import clip, concat, minimum, norm | |||
| from ..tensor import Tensor | |||
| __all__ = ["clip_grad_norm", "clip_grad_value"] | |||
| def clip_grad_norm( | |||
| tensors: Union[Tensor, Iterable[Tensor]], max_norm: float, ord: float = 2.0, | |||
| ): | |||
| r"""Clips gradient norm of an iterable of parameters. | |||
| The norm is computed over all gradients together, as if they were | |||
| concatenated into a single vector. Gradients are modified in-place. | |||
| :param tensors: an iterable of Tensors or a single Tensor. | |||
| :param max_norm: max norm of the gradients. | |||
| :param ord: type of the used p-norm. Can be ``'inf'`` for infinity norm. | |||
| :return: total norm of the parameters (viewed as a single vector). | |||
| """ | |||
| push_scope("clip_grad_norm") | |||
| if isinstance(tensors, Tensor): | |||
| tensors = [tensors] | |||
| tensors = [t for t in tensors if t.grad is not None] | |||
| if len(tensors) == 0: | |||
| pop_scope("clip_grad_norm") | |||
| return Tensor(0.0) | |||
| norm_ = [norm(t.grad.flatten(), ord=ord) for t in tensors] | |||
| if len(norm_) > 1: | |||
| norm_ = norm(concat(norm_), ord=ord) | |||
| else: | |||
| norm_ = norm_[0] | |||
| scale = max_norm / (norm_ + 1e-6) | |||
| scale = minimum(scale, 1) | |||
| for tensor in tensors: | |||
| tensor.grad._reset(tensor.grad * scale) | |||
| pop_scope("clip_grad_norm") | |||
| return norm_ | |||
| def clip_grad_value( | |||
| tensors: Union[Tensor, Iterable[Tensor]], lower: float, upper: float | |||
| ): | |||
| r"""Clips gradient of an iterable of parameters to a specified lower and | |||
| upper. Gradients are modified in-place. | |||
| The gradients are clipped in the range: | |||
| .. math:: \left[\text{lower}, \text{upper}\right] | |||
| :param tensors: an iterable of Tensors or a single Tensor. | |||
| :param lower: minimum allowed value of the gradients. | |||
| :param upper: maximum allowed value of the gradients. | |||
| """ | |||
| push_scope("clip_grad_value") | |||
| if isinstance(tensors, Tensor): | |||
| tensors = [tensors] | |||
| for tensor in tensors: | |||
| if tensor.grad is None: | |||
| continue | |||
| tensor.grad._reset(clip(tensor.grad, lower, upper)) | |||
| pop_scope("clip_grad_value") | |||
| @@ -0,0 +1,120 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import itertools | |||
| import numpy as np | |||
| import pytest | |||
| import megengine as mge | |||
| import megengine.autodiff as ad | |||
| import megengine.functional as F | |||
| import megengine.optimizer as optim | |||
| from megengine import Tensor | |||
| from megengine.jit import trace | |||
| from megengine.module import Linear, Module | |||
| from megengine.optimizer import SGD | |||
| batch_size = 64 | |||
| data_shape = (batch_size, 2) | |||
| label_shape = (batch_size,) | |||
| def minibatch_generator(): | |||
| while True: | |||
| inp_data = np.zeros((batch_size, 2)) | |||
| label = np.zeros(batch_size, dtype=np.int32) | |||
| for i in range(batch_size): | |||
| # [x0, x1], sampled from U[-1, 1] | |||
| inp_data[i, :] = np.random.rand(2) * 2 - 1 | |||
| label[i] = 0 if np.prod(inp_data[i]) < 0 else 1 | |||
| yield inp_data.astype(np.float32), label.astype(np.int32) | |||
| def calculate_precision(data: np.ndarray, pred: np.ndarray) -> float: | |||
| """ Calculate precision for given data and prediction. | |||
| :type data: [[x, y], ...] | |||
| :param data: Input data | |||
| :type pred: [[x_pred, y_pred], ...] | |||
| :param pred: Network output data | |||
| """ | |||
| correct = 0 | |||
| assert len(data) == len(pred) | |||
| for inp_data, pred_output in zip(data, pred): | |||
| label = 0 if np.prod(inp_data) < 0 else 1 | |||
| pred_label = np.argmax(pred_output) | |||
| if pred_label == label: | |||
| correct += 1 | |||
| return float(correct) / len(data) | |||
| class XORNet(Module): | |||
| def __init__(self): | |||
| self.mid_layers = 14 | |||
| self.num_class = 2 | |||
| super().__init__() | |||
| self.fc0 = Linear(self.num_class, self.mid_layers, bias=True) | |||
| self.fc1 = Linear(self.mid_layers, self.mid_layers, bias=True) | |||
| self.fc2 = Linear(self.mid_layers, self.num_class, bias=True) | |||
| def forward(self, x): | |||
| x = self.fc0(x) | |||
| x = F.tanh(x) | |||
| x = self.fc1(x) | |||
| x = F.tanh(x) | |||
| x = self.fc2(x) | |||
| return x | |||
| def test_training_converge(): | |||
| net = XORNet() | |||
| opt = SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4) | |||
| gm = ad.GradManager().attach(net.parameters()) | |||
| @trace(symbolic=False) | |||
| def train(data, label): | |||
| with gm: | |||
| pred = net(data) | |||
| loss = F.nn.cross_entropy(pred, label) | |||
| gm.backward(loss) | |||
| optim.clip_grad_norm(net.parameters(), max_norm=0.2, ord=2.0) | |||
| return loss | |||
| def infer(data): | |||
| return net(data) | |||
| train_dataset = minibatch_generator() | |||
| losses = [] | |||
| for data, label in itertools.islice(train_dataset, 2000): | |||
| data = Tensor(data, dtype=np.float32) | |||
| label = Tensor(label, dtype=np.int32) | |||
| opt.clear_grad() | |||
| loss = train(data, label) | |||
| optim.clip_grad_value(net.parameters(), lower=-0.1, upper=0.1) | |||
| opt.step() | |||
| losses.append(loss.numpy()) | |||
| print(np.mean(losses[-100:])) | |||
| assert np.mean(losses[-100:]) < 0.1, "Final training Loss must be low enough" | |||
| ngrid = 10 | |||
| x = np.linspace(-1.0, 1.0, ngrid) | |||
| xx, yy = np.meshgrid(x, x) | |||
| xx = xx.reshape((ngrid * ngrid, 1)) | |||
| yy = yy.reshape((ngrid * ngrid, 1)) | |||
| data = np.concatenate((xx, yy), axis=1).astype(np.float32) | |||
| pred = infer(data).numpy() | |||
| precision = calculate_precision(data, pred) | |||
| print("precision=", precision) | |||
| assert precision == 1.0, "Test precision must be high enough, get {}".format( | |||
| precision | |||
| ) | |||
| @@ -0,0 +1,80 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import platform | |||
| import weakref | |||
| import numpy as np | |||
| import pytest | |||
| import megengine as mge | |||
| import megengine.autodiff as ad | |||
| import megengine.functional as F | |||
| import megengine.module as M | |||
| import megengine.optimizer as optim | |||
| class Net(M.Module): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.conv1 = M.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) | |||
| self.bn1 = M.BatchNorm2d(64) | |||
| self.avgpool = M.AvgPool2d(kernel_size=5, stride=5, padding=0) | |||
| self.fc = M.Linear(64, 10) | |||
| def forward(self, x): | |||
| x = self.conv1(x) | |||
| x = self.bn1(x) | |||
| x = F.relu(x) | |||
| x = self.avgpool(x) | |||
| x = F.avg_pool2d(x, 22) | |||
| x = F.flatten(x, 1) | |||
| x = self.fc(x) | |||
| return x | |||
| def save_grad_value(net): | |||
| for param in net.parameters(): | |||
| param.grad_backup = param.grad.numpy().copy() | |||
| def test_clip_grad_norm(): | |||
| net = Net() | |||
| x = mge.tensor(np.random.randn(10, 3, 224, 224)) | |||
| gm = ad.GradManager().attach(net.parameters()) | |||
| opt = optim.SGD(net.parameters(), 1e-3, momentum=0.9) | |||
| with gm: | |||
| loss = net(x).sum() | |||
| gm.backward(loss) | |||
| save_grad_value(net) | |||
| max_norm = 1.0 | |||
| original_norm = optim.clip_grad_norm(net.parameters(), max_norm=max_norm, ord=2) | |||
| scale = max_norm / original_norm | |||
| for param in net.parameters(): | |||
| np.testing.assert_almost_equal(param.grad.numpy(), param.grad_backup * scale) | |||
| opt.step().clear_grad() | |||
| def test_clip_grad_value(): | |||
| net = Net() | |||
| x = np.random.randn(10, 3, 224, 224).astype("float32") | |||
| gm = ad.GradManager().attach(net.parameters()) | |||
| opt = optim.SGD(net.parameters(), 1e-3, momentum=0.9) | |||
| with gm: | |||
| y = net(x) | |||
| y = y.mean() | |||
| gm.backward(y) | |||
| save_grad_value(net) | |||
| max_val = 5 | |||
| min_val = -2 | |||
| optim.clip_grad_value(net.parameters(), lower=min_val, upper=max_val) | |||
| for param in net.parameters(): | |||
| np.testing.assert_almost_equal( | |||
| param.grad.numpy(), | |||
| np.maximum(np.minimum(param.grad_backup, max_val), min_val), | |||
| ) | |||
| opt.step().clear_grad() | |||
| @@ -0,0 +1,58 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import platform | |||
| import weakref | |||
| import numpy as np | |||
| import pytest | |||
| import torch | |||
| import megengine as mge | |||
| import megengine.functional as F | |||
| import megengine.module as M | |||
| import megengine.optimizer as optim | |||
| def make_fake_params(): | |||
| shapes = [(1,), (3, 3), (5, 5, 5), (6, 7, 8, 9)] | |||
| params = [np.random.randn(*shape).astype("float32") for shape in shapes] | |||
| params_mge = [] | |||
| params_torch = [] | |||
| for param in params: | |||
| t = torch.ones(param.shape) | |||
| t.grad = torch.Tensor(param.copy()) | |||
| params_torch.append(t) | |||
| t = mge.functional.ones(param.shape) | |||
| t.grad = mge.tensor(param.copy()) | |||
| params_mge.append(t) | |||
| return params_mge, params_torch | |||
| def test_clip_grad_norm_torch(): | |||
| max_norm = 1.0 | |||
| params_mge, params_torch = make_fake_params() | |||
| norm_torch = torch.nn.utils.clip_grad_norm_(params_torch, max_norm, norm_type=2.0) | |||
| norm_mge = optim.clip_grad_norm(params_mge, max_norm=max_norm, ord=2.0) | |||
| np.testing.assert_allclose(norm_mge.numpy(), norm_torch.numpy(), atol=1e-4) | |||
| for i in range(len(params_mge)): | |||
| np.testing.assert_allclose( | |||
| params_mge[i].grad.numpy(), params_torch[i].grad.numpy(), atol=1e-7 | |||
| ) | |||
| def test_clip_grad_value_torch(): | |||
| max_val = 0.5 | |||
| min_val = -0.5 | |||
| params_mge, params_torch = make_fake_params() | |||
| torch.nn.utils.clip_grad_value_(params_torch, clip_value=max_val) | |||
| optim.clip_grad_value(params_mge, lower=min_val, upper=max_val) | |||
| for i in range(len(params_mge)): | |||
| np.testing.assert_allclose( | |||
| params_mge[i].grad.numpy(), params_torch[i].grad.numpy(), atol=1e-7 | |||
| ) | |||