- refactor(mge): add support for optimizer.step().clear_grad() idiom
- refactor(mge): rename some methods of GradManager
- refactor(mge): remove tensor_nn and TensorDict
- refactor(mge): remove Buffer
- refactor(mge): remove requires_grad flag
- refactor(mge): add a default grad=None attribute to Tensor
- refactor(mge): deprecation for 1.0
GitOrigin-RevId: 3b723d9387
tags/v1.0.0-rc1
| @@ -74,8 +74,7 @@ from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func | |||
| from .device import * | |||
| from .logger import enable_debug_log, get_logger, set_log_file, set_log_level | |||
| from .serialization import load, save | |||
| from .tensor import Tensor, tensor | |||
| from .tensor_nn import Buffer, Parameter | |||
| from .tensor import Parameter, Tensor, tensor | |||
| from .version import __version__ | |||
| _set_fork_exec_path_for_timed_func( | |||
| @@ -22,7 +22,7 @@ class GradManager: | |||
| self._after_backward_callback = [] | |||
| self._gradients = dict() | |||
| def register(self, params, callbacks=None): | |||
| def attach(self, params, callbacks=None): | |||
| if callbacks is None: | |||
| callbacks = [] | |||
| if isinstance(callbacks, Callable): | |||
| @@ -62,7 +62,7 @@ class GradManager: | |||
| if isinstance(grad, Future): | |||
| grad = grad.get() | |||
| param = self._param_dict[p] | |||
| if getattr(param, "grad", None) is None: | |||
| if param.grad is None: | |||
| param.grad = grad | |||
| else: | |||
| param.grad += grad | |||
| @@ -70,9 +70,9 @@ class GradManager: | |||
| self._stop_record() | |||
| backwarding_grad_manager = cache | |||
| def __enter__(self): | |||
| def record(self): | |||
| if self._recording: | |||
| return self | |||
| raise RuntimeError("already recording") | |||
| grad = Grad() | |||
| self._recording = True | |||
| self._grad = grad | |||
| @@ -88,16 +88,22 @@ class GradManager: | |||
| grad.wrt(param_wrapper, callback=callback) | |||
| grad.__enter__() | |||
| return self | |||
| def __exit__(self, exc_type, exc_val, exc_tb): | |||
| def release(self): | |||
| if not self._recording: | |||
| raise RuntimeError("not recording") | |||
| self._stop_record() | |||
| record = __enter__ | |||
| def _stop_record(self): | |||
| if self._grad is not None: | |||
| self._grad.__exit__(None, None, None) | |||
| self._recording = False | |||
| self._grad = None | |||
| self._gradients = dict() | |||
| def __enter__(self): | |||
| self.record() | |||
| return self | |||
| def __exit__(self, exc_type, exc_val, exc_tb): | |||
| self._stop_record() | |||
| @@ -70,7 +70,7 @@ class Dimshuffle(PodOpVisitor): | |||
| return bytes(ctypes.c_uint32(0)) + bytes(self) | |||
| def __init__(self, pattern, ndim=0): | |||
| assert isinstance(pattern, collections.Iterable) | |||
| assert isinstance(pattern, collections.abc.Iterable) | |||
| assert len(pattern) <= TensorShape.MAX_NDIM | |||
| pattern_array = Dimshuffle.Pattern.Pattern_Array() | |||
| for idx, v in enumerate(pattern): | |||
| @@ -231,13 +231,13 @@ class OpNode: | |||
| def _wrap(x): | |||
| if isinstance(x, collections.Sequence): | |||
| if isinstance(x, collections.abc.Sequence): | |||
| return type(x)(map(_wrap, x)) | |||
| return x.graph._wrap(x) | |||
| def _unwrap(x): | |||
| if isinstance(x, collections.Sequence): | |||
| if isinstance(x, collections.abc.Sequence): | |||
| return type(x)(map(_unwrap, x)) | |||
| return x._node | |||
| @@ -166,7 +166,7 @@ def _reduce(mode): | |||
| op = builtin.Reduce(mode=mode, axis=0) | |||
| (result,) = apply(op, data) | |||
| elif isinstance(axis, collections.Iterable): | |||
| elif isinstance(axis, collections.abc.Iterable): | |||
| axis = list(axis) | |||
| axis.sort(reverse=True) | |||
| @@ -204,7 +204,9 @@ def _todo(*_): | |||
| def _expand_args(args): | |||
| if len(args) == 1: | |||
| if isinstance(args[0], (collections.Sequence, TensorBase, TensorWrapperBase)): | |||
| if isinstance( | |||
| args[0], (collections.abc.Sequence, TensorBase, TensorWrapperBase) | |||
| ): | |||
| args = args[0] | |||
| return args | |||
| @@ -143,7 +143,7 @@ def astensor1d(x, *reference, dtype=None, device=None): | |||
| (x,) = Const(x, dtype=dtype, device=device)(*reference) | |||
| return x | |||
| if not isinstance(x, collections.Sequence): | |||
| if not isinstance(x, collections.abc.Sequence): | |||
| raise TypeError | |||
| if any(isinstance(i, (TensorBase, TensorWrapperBase)) for i in x): | |||
| @@ -432,7 +432,7 @@ def argmin( | |||
| [0] | |||
| """ | |||
| if isinstance(axis, collections.Iterable): | |||
| if isinstance(axis, collections.abc.Iterable): | |||
| axis = list(axis) | |||
| axis.sort(reverse=True) | |||
| @@ -486,7 +486,7 @@ def argmax( | |||
| [5] | |||
| """ | |||
| if isinstance(axis, collections.Iterable): | |||
| if isinstance(axis, collections.abc.Iterable): | |||
| axis = list(axis) | |||
| axis.sort(reverse=True) | |||
| @@ -15,7 +15,7 @@ def get_ndtuple(value, *, n, allow_zero=True): | |||
| :type allow_zero: bool | |||
| :param allow_zero: whether to allow zero tuple value""" | |||
| if not isinstance(value, collections.Iterable): | |||
| if not isinstance(value, collections.abc.Iterable): | |||
| value = int(value) | |||
| value = tuple([value for i in range(n)]) | |||
| else: | |||
| @@ -502,7 +502,7 @@ class trace: | |||
| raise TypeError( | |||
| "cannot specify output_names when output is already in dict format" | |||
| ) | |||
| if output_names and not isinstance(output_names, collections.Sequence): | |||
| if output_names and not isinstance(output_names, collections.abc.Sequence): | |||
| output_names = (output_names,) | |||
| if output_names and len(output_names) != len(self._output_bindings): | |||
| raise ValueError( | |||
| @@ -510,7 +510,7 @@ class trace: | |||
| len(self._output_bindings) | |||
| ) | |||
| ) | |||
| if arg_names and not isinstance(arg_names, collections.Sequence): | |||
| if arg_names and not isinstance(arg_names, collections.abc.Sequence): | |||
| arg_names = (arg_names,) | |||
| if arg_names and len(arg_names) != len(self._arg_bindings): | |||
| raise ValueError( | |||
| @@ -646,9 +646,9 @@ class trace: | |||
| def _process_outputs(self, outputs): | |||
| output_names = None | |||
| if isinstance(outputs, collections.Mapping): | |||
| if isinstance(outputs, collections.abc.Mapping): | |||
| output_names, outputs = zip(*sorted(outputs.items())) | |||
| elif not isinstance(outputs, collections.Sequence): | |||
| elif not isinstance(outputs, collections.abc.Sequence): | |||
| outputs = (outputs,) | |||
| if not self._untraced: | |||
| @@ -18,7 +18,6 @@ from .embedding import Embedding | |||
| from .identity import Identity | |||
| from .linear import Linear | |||
| from .module import Module | |||
| from .parampack import ParamPack | |||
| from .pooling import AvgPool2d, MaxPool2d | |||
| from .quant_dequant import DequantStub, QuantStub | |||
| from .sequential import Sequential | |||
| @@ -9,7 +9,7 @@ | |||
| import numpy as np | |||
| from ..functional import leaky_relu, prelu, relu, sigmoid, softmax | |||
| from ..tensor_nn import Parameter | |||
| from ..tensor import Parameter | |||
| from .module import Module | |||
| @@ -12,7 +12,7 @@ import numpy as np | |||
| from ..distributed.group import WORLD, Group | |||
| from ..functional import batch_norm2d, sync_batch_norm | |||
| from ..tensor_nn import Buffer, Parameter, Tensor | |||
| from ..tensor import Parameter, Tensor | |||
| from . import init | |||
| from .module import Module | |||
| @@ -45,8 +45,8 @@ class _BatchNorm(Module): | |||
| tshape = (1, self.num_features, 1, 1) | |||
| if self.track_running_stats: | |||
| self.running_mean = Buffer(np.zeros(tshape, dtype=np.float32)) | |||
| self.running_var = Buffer(np.ones(tshape, dtype=np.float32)) | |||
| self.running_mean = Tensor(np.zeros(tshape, dtype=np.float32)) | |||
| self.running_var = Tensor(np.ones(tshape, dtype=np.float32)) | |||
| else: | |||
| self.running_mean = None | |||
| self.running_var = None | |||
| @@ -13,7 +13,7 @@ import numpy as np | |||
| from ..core.ops._internal import param_defs as P | |||
| from ..functional import conv2d, conv_transpose2d, local_conv2d, relu | |||
| from ..functional.types import _pair, _pair_nonzero | |||
| from ..tensor_nn import Parameter | |||
| from ..tensor import Parameter | |||
| from . import init | |||
| from .module import Module | |||
| @@ -11,7 +11,7 @@ from typing import Optional | |||
| import numpy as np | |||
| from ..functional import embedding as embedding_func | |||
| from ..tensor_nn import Parameter | |||
| from ..tensor import Parameter | |||
| from . import init | |||
| from .module import Module | |||
| @@ -72,6 +72,7 @@ class Embedding(Module): | |||
| max_norm: Optional[float] = None, | |||
| norm_type: Optional[float] = None, | |||
| initial_weight: Parameter = None, | |||
| freeze: bool = False, | |||
| ): | |||
| super().__init__() | |||
| if padding_idx is not None: | |||
| @@ -83,6 +84,7 @@ class Embedding(Module): | |||
| self.norm_type = norm_type | |||
| self.num_embeddings = num_embeddings | |||
| self.embedding_dim = embedding_dim | |||
| self.freeze = freeze | |||
| if initial_weight is None: | |||
| self.weight = Parameter( | |||
| np.random.uniform( | |||
| @@ -101,7 +103,11 @@ class Embedding(Module): | |||
| init.normal_(self.weight) | |||
| def forward(self, inputs): | |||
| return embedding_func(inputs, self.weight) | |||
| if self.freeze: | |||
| weight = self.weight.detach() | |||
| else: | |||
| weight = self.weight | |||
| return embedding_func(inputs, weight) | |||
| @classmethod | |||
| def from_pretrained( | |||
| @@ -166,6 +172,6 @@ class Embedding(Module): | |||
| padding_idx=padding_idx, | |||
| max_norm=max_norm, | |||
| norm_type=norm_type, | |||
| freeze=freeze, | |||
| ) | |||
| embedding.weight.requires_grad = not freeze | |||
| return embedding | |||
| @@ -23,7 +23,7 @@ def fill_(tensor: Tensor, val: Union[float, int]) -> None: | |||
| :param tensor: An n-dimentional tensor to be initialized | |||
| :param val: The value to be filled throughout the tensor | |||
| """ | |||
| tensor.set_value(full(shape=tensor.shape, value=val, dtype=tensor.dtype)) | |||
| tensor._reset(full(shape=tensor.shape, value=val, dtype=tensor.dtype)) | |||
| def zeros_(tensor: Tensor) -> None: | |||
| @@ -50,7 +50,7 @@ def uniform_(tensor: Tensor, a: float = 0.0, b: float = 1.0) -> None: | |||
| :param a: Lower bound of the sampling interval | |||
| :param b: Upper bound of the sampling interval | |||
| """ | |||
| tensor.set_value(uniform(tensor.shape, low=a, high=b).astype(tensor.dtype)) | |||
| tensor._reset(uniform(tensor.shape, low=a, high=b).astype(tensor.dtype)) | |||
| def normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> None: | |||
| @@ -61,7 +61,7 @@ def normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> None: | |||
| :param mean: The mean of the normal distribution | |||
| :param std: The standard deviation of the normal distribution | |||
| """ | |||
| tensor.set_value(gaussian(tensor.shape, mean=mean, std=std).astype(tensor.dtype)) | |||
| tensor._reset(gaussian(tensor.shape, mean=mean, std=std).astype(tensor.dtype)) | |||
| def calculate_gain( | |||
| @@ -8,7 +8,7 @@ | |||
| import numpy as np | |||
| from ..functional import linear | |||
| from ..tensor_nn import Parameter | |||
| from ..tensor import Parameter | |||
| from . import init | |||
| from .module import Module | |||
| @@ -5,6 +5,7 @@ | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import warnings | |||
| from abc import ABCMeta, abstractmethod | |||
| from collections import OrderedDict | |||
| from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union | |||
| @@ -14,8 +15,8 @@ import numpy as np | |||
| from ..core.tensor.dtype import is_quantize | |||
| from ..core.tensor.utils import make_shape_tuple | |||
| from ..logger import get_logger | |||
| from ..tensor import Tensor | |||
| from ..tensor_nn import Buffer, Parameter | |||
| from ..tensor import Parameter, Tensor | |||
| from ..utils.deprecation import deprecated | |||
| from ..utils.hook import HookHandler | |||
| logger = get_logger(__name__) | |||
| @@ -48,7 +49,7 @@ def _is_parameter(obj): | |||
| def _is_buffer(obj): | |||
| return isinstance(obj, Buffer) | |||
| return isinstance(obj, Tensor) and not isinstance(obj, Parameter) | |||
| def _is_module(obj): | |||
| @@ -163,49 +164,43 @@ class Module(metaclass=ABCMeta): | |||
| seen=seen, | |||
| ) | |||
| def parameters( | |||
| self, requires_grad: Optional[bool] = None, recursive: bool = True, **kwargs | |||
| ) -> Iterable[Parameter]: | |||
| def parameters(self, recursive: bool = True, **kwargs) -> Iterable[Parameter]: | |||
| r"""Returns an iterable for the :class:`~.Parameter` of the module. | |||
| :param requires_grad: Limitation over the :attr:`~.Parameter.requires_grad` | |||
| attribute of returned :class:`.Parameter`. ``None`` for no limitation. | |||
| :param recursive: If ``True``, returns all :class:`~.Parameter` within this | |||
| module, else only returns :class:`~.Parameter` that are direct attributes | |||
| of this module. | |||
| """ | |||
| if "requires_grad" in kwargs: | |||
| del kwargs["requires_grad"] | |||
| warnings.warn("passing requires_grad has no effect currently") | |||
| def predicate(obj) -> bool: | |||
| return _is_parameter(obj) and ( | |||
| requires_grad is None or obj.requires_grad == requires_grad | |||
| ) | |||
| return _is_parameter(obj) | |||
| yield from self._flatten( | |||
| with_key=False, predicate=predicate, recursive=recursive, **kwargs | |||
| ) | |||
| def named_parameters( | |||
| self, | |||
| requires_grad: Optional[bool] = None, | |||
| prefix: Optional[str] = None, | |||
| recursive: bool = True, | |||
| **kwargs | |||
| self, prefix: Optional[str] = None, recursive: bool = True, **kwargs | |||
| ) -> Iterable[Tuple[str, Parameter]]: | |||
| """Returns an iterable for key :class:`~.Parameter` pairs of the module, where | |||
| ``key`` is the dotted path from this module to the :class:`~.Parameter` . | |||
| :param requires_grad: Limitation over the :attr:`~.Parameter.requires_grad` | |||
| attribute of returned :class:`~.Parameter` . ``None`` for no limitation. | |||
| :param prefix: The prefix prepended to the keys. | |||
| :param recursive: If ``True``, returns all :class:`~.Parameter` within this | |||
| module, else only returns :class:`~.Parameter` that are direct attributes | |||
| of this module. | |||
| """ | |||
| if "requires_grad" in kwargs: | |||
| del kwargs["requires_grad"] | |||
| warnings.warn("passing requires_grad has no effect currently") | |||
| def predicate(obj) -> bool: | |||
| return _is_parameter(obj) and ( | |||
| requires_grad is None or obj.requires_grad == requires_grad | |||
| ) | |||
| return _is_parameter(obj) | |||
| yield from self._flatten( | |||
| with_key=True, | |||
| @@ -215,11 +210,13 @@ class Module(metaclass=ABCMeta): | |||
| **kwargs, | |||
| ) | |||
| def buffers(self, recursive: bool = True, **kwargs) -> Iterable[Buffer]: | |||
| """Returns an iterable for the :class:`~.Buffer` of the module. | |||
| def buffers(self, recursive: bool = True, **kwargs) -> Iterable[Tensor]: | |||
| """Returns an iterable for the buffers of the module. | |||
| :param recursive: If ``True``, returns all :class:`~.Buffer` within this | |||
| module, else only returns :class:`~.Buffer` that are direct attributes | |||
| Buffer is defined to be :class:`~.Tensor` excluding :class:`~.Parameter`. | |||
| :param recursive: If ``True``, returns all buffers within this | |||
| module, else only returns buffers that are direct attributes | |||
| of this module. | |||
| """ | |||
| yield from self._flatten( | |||
| @@ -228,13 +225,15 @@ class Module(metaclass=ABCMeta): | |||
| def named_buffers( | |||
| self, prefix: Optional[str] = None, recursive: bool = True, **kwargs | |||
| ) -> Iterable[Tuple[str, Buffer]]: | |||
| """Returns an iterable for key :class:`~.Buffer` pairs of the module, where | |||
| ``key`` is the dotted path from this module to the :class:`~.Buffer` . | |||
| ) -> Iterable[Tuple[str, Tensor]]: | |||
| """Returns an iterable for key buffer pairs of the module, where | |||
| ``key`` is the dotted path from this module to the buffer. | |||
| Buffer is defined to be :class:`~.Tensor` excluding :class:`~.Parameter`. | |||
| :param prefix: The prefix prepended to the keys. | |||
| :param recursive: If ``True``, returns all :class:`~.Buffer` within this | |||
| module, else only returns :class:`~.Buffer` that are direct attributes | |||
| :param recursive: If ``True``, returns all buffers within this | |||
| module, else only returns buffers that are direct attributes | |||
| of this module. | |||
| """ | |||
| yield from self._flatten( | |||
| @@ -297,6 +296,7 @@ class Module(metaclass=ABCMeta): | |||
| for it in self.modules(): | |||
| fn(it) | |||
| @deprecated(version="1.0") | |||
| def zero_grad(self) -> None: | |||
| """Set all parameters' grads to zero | |||
| """ | |||
| @@ -505,7 +505,7 @@ class Module(metaclass=ABCMeta): | |||
| # scale/zero_points maybe invalid, use pretrained dtype instead. | |||
| if is_quantize(to_be_load.dtype) and is_quantize(var.dtype): | |||
| var = var.astype(to_be_load.dtype) | |||
| var.set_value(to_be_load) | |||
| var._reset(to_be_load) | |||
| loaded.append(k) | |||
| return set(loaded), set(skipped) | |||
| @@ -1,156 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import collections | |||
| from typing import Callable, Iterable, Optional, Tuple | |||
| import numpy as np | |||
| from ..tensor_nn import Parameter, Tensor | |||
| from .module import Module | |||
| class ParamPack(Module): | |||
| r"""Pack module's parameters by gathering their memory to continuous address. | |||
| Using (device, dtype, requires_grad) as key, for example ('gpu0', float32, True), | |||
| parameters with same key will be packed togather. | |||
| It helps a lot for multimachine training by speeding up allreduce gradients. | |||
| :param model: the module you want to pack parameters. | |||
| :param nr_ignore_first: how many parameters will be unpacked at first. | |||
| :param max_size_per_group: upper bound of packed parameters' size in MB. | |||
| :param max_nr_params_per_group: upper bound of the number of parameters of each group. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| model: Module, | |||
| nr_ignore_first: int = 8, | |||
| max_size_per_group: int = 10, | |||
| max_nr_params_per_group: int = 100, | |||
| group_func: Callable = lambda name, param: 0, | |||
| ): | |||
| super().__init__() | |||
| self._model = model | |||
| self._nr_ignore_first = nr_ignore_first | |||
| self._max_size_per_group = max_size_per_group | |||
| self._max_nr_params_per_group = max_nr_params_per_group | |||
| self._group_func = group_func | |||
| self._grouped_params = [] | |||
| self._packed_params = [] | |||
| params = model.named_parameters() | |||
| self._pack_params(params) | |||
| def parameters(self, requires_grad: Optional[bool] = None) -> Iterable[Parameter]: | |||
| for param in self._packed_params: | |||
| if requires_grad is None or param.requires_grad == requires_grad: | |||
| yield param | |||
| def named_parameters( | |||
| self, requires_grad: Optional[bool] = None | |||
| ) -> Iterable[Tuple[str, Parameter]]: | |||
| for idx, param in enumerate(self._packed_params): | |||
| if requires_grad is None or param.requires_grad == requires_grad: | |||
| yield "packed_param_" + str(idx), param | |||
| def _pack_params(self, params: Iterable[Tuple[str, Parameter]]): | |||
| groups = collections.defaultdict(list) | |||
| ignored = 0 | |||
| param_id = 0 | |||
| for name, param in params: | |||
| if self._nr_ignore_first > ignored: | |||
| ignored += 1 | |||
| self._grouped_params.append([{"shape": param.shape, "id": param_id}]) | |||
| param.pack_group_key = self._group_func(name, param) | |||
| self._packed_params.append(param) | |||
| else: | |||
| key = ( | |||
| param.dtype, | |||
| param.device, | |||
| param.requires_grad, | |||
| self._group_func(name, param), | |||
| ) | |||
| groups[key].append({"tensor": param, "id": param_id}) | |||
| param_id += 1 | |||
| for (dtype, device, requires_grad, group_key) in groups.keys(): | |||
| dtype_sz = np.dtype(dtype).itemsize | |||
| align = device.mem_align | |||
| if align < dtype_sz: | |||
| align = 1 | |||
| else: | |||
| assert align % dtype_sz == 0 | |||
| align //= dtype_sz | |||
| group = groups[(dtype, device, requires_grad, group_key)] | |||
| while group: | |||
| aligned_pos = [] | |||
| offset = 0 | |||
| params = [] | |||
| idx = 0 | |||
| while idx < len(group): | |||
| param = group[idx] | |||
| assert param["tensor"].device == device | |||
| padding = (align - (offset & (align - 1))) & (align - 1) | |||
| offset += padding | |||
| aligned_pos.append(offset) | |||
| params.append(param) | |||
| offset += int(np.prod(param["tensor"].shape)) | |||
| idx += 1 | |||
| if ( | |||
| offset * dtype_sz >= self._max_size_per_group * 1024 * 1024 | |||
| or idx >= self._max_nr_params_per_group | |||
| ): | |||
| break | |||
| group = group[idx:] | |||
| if idx == 1: | |||
| # ignore param packs with only one item | |||
| params[0]["tensor"].pack_group_key = group_key | |||
| self._packed_params.append(params[0]["tensor"]) | |||
| self._grouped_params.append( | |||
| [{"shape": params[0]["tensor"].shape, "id": params[0]["id"]}] | |||
| ) | |||
| continue | |||
| packed_value = np.zeros((offset,), dtype=dtype) | |||
| for param, pos in zip(params, aligned_pos): | |||
| val = param["tensor"].numpy() | |||
| packed_value[pos : pos + val.size] = val.flatten() | |||
| new_param = Parameter( | |||
| value=packed_value, | |||
| device=device, | |||
| dtype=dtype, | |||
| requires_grad=requires_grad, | |||
| ) | |||
| new_param.pack_group_key = group_key | |||
| self._packed_params.append(new_param) | |||
| self._grouped_params.append( | |||
| [{"shape": i["tensor"].shape, "id": i["id"]} for i in params] | |||
| ) | |||
| def forward(self, *args, **kwargs): | |||
| replace_param = dict() | |||
| for i in range(len(self._packed_params)): | |||
| packed_param = self._packed_params[i] | |||
| grouped_params = self._grouped_params[i] | |||
| if len(grouped_params) == 1: | |||
| continue | |||
| split = param_pack_split( | |||
| packed_param._symvar, [i["shape"] for i in grouped_params] | |||
| ) | |||
| split = [ | |||
| Parameter(Tensor(i, requires_grad=packed_param.requires_grad)) | |||
| for i in split | |||
| ] | |||
| for j in range(len(split)): | |||
| replace_param[grouped_params[j]["id"]] = split[j] | |||
| self._model.replace_param(replace_param, 0) | |||
| return self._model.forward(*args, **kwargs) | |||
| @@ -12,7 +12,7 @@ import numpy as np | |||
| from ... import module as Float | |||
| from ...core.tensor import dtype | |||
| from ...functional import conv_bias_activation | |||
| from ...tensor_nn import Parameter | |||
| from ...tensor import Parameter | |||
| from ..qat import conv as QAT | |||
| from .module import QuantizedModule | |||
| @@ -5,7 +5,7 @@ | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from ...tensor_nn import Parameter | |||
| from ...tensor import Parameter | |||
| from ..qat import conv_bn as QAT | |||
| from .conv import Conv2d | |||
| @@ -9,7 +9,7 @@ import numpy as np | |||
| from ... import functional as F | |||
| from ...core.tensor import dtype | |||
| from ...tensor_nn import Parameter | |||
| from ...tensor import Parameter | |||
| from ..qat import linear as QAT | |||
| from .module import QuantizedModule | |||
| @@ -11,7 +11,7 @@ from typing import Iterable, Union | |||
| import numpy as np | |||
| from ..functional import sqrt | |||
| from ..tensor_nn import Parameter | |||
| from ..tensor import Parameter | |||
| from .optimizer import Optimizer | |||
| @@ -63,7 +63,7 @@ class Adadelta(Optimizer): | |||
| for param in param_group["params"]: | |||
| if not param.requires_grad or "grad" not in param.__dict__: | |||
| if param.grad is None: | |||
| continue | |||
| states = self._state[param] | |||
| @@ -11,7 +11,7 @@ from typing import Iterable, Union | |||
| import numpy as np | |||
| from ..functional import sqrt | |||
| from ..tensor_nn import Parameter | |||
| from ..tensor import Parameter | |||
| from .optimizer import Optimizer | |||
| @@ -62,7 +62,7 @@ class Adagrad(Optimizer): | |||
| for param in param_group["params"]: | |||
| if not param.requires_grad or "grad" not in param.__dict__: | |||
| if param.grad is None: | |||
| continue | |||
| states = self._state[param] | |||
| @@ -8,7 +8,7 @@ | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from typing import Iterable, Tuple, Union | |||
| from ..tensor_nn import Parameter | |||
| from ..tensor import Parameter | |||
| from .optimizer import Optimizer | |||
| @@ -59,7 +59,7 @@ class Adam(Optimizer): | |||
| for param in param_group["params"]: | |||
| if not param.requires_grad or "grad" not in param.__dict__: | |||
| if param.grad is None: | |||
| continue | |||
| grad = param.grad | |||
| @@ -7,7 +7,7 @@ | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from abc import ABCMeta, abstractmethod | |||
| from collections import Iterable | |||
| from collections.abc import Iterable | |||
| from contextlib import contextmanager | |||
| from typing import Dict | |||
| from typing import Iterable as Iter | |||
| @@ -15,8 +15,7 @@ from typing import Union | |||
| import numpy as np | |||
| from ..tensor import Tensor, TensorDict | |||
| from ..tensor_nn import Buffer, Parameter | |||
| from ..tensor import Parameter, Tensor | |||
| class _RequiredParameter: | |||
| @@ -37,7 +36,7 @@ class Optimizer(metaclass=ABCMeta): | |||
| def __init__( # pylint: disable=too-many-branches | |||
| self, params: Union[Iter[Parameter], dict], defaults: dict, | |||
| ): | |||
| self._state = TensorDict() | |||
| self._state = dict() | |||
| self._defaults = defaults | |||
| if isinstance(params, (Parameter, dict)): | |||
| @@ -93,10 +92,6 @@ class Optimizer(metaclass=ABCMeta): | |||
| "optimizer can only optimize Parameters, but one of the params is " | |||
| + type(param) | |||
| ) | |||
| if not param.requires_grad: | |||
| raise ValueError( | |||
| "optimizer can only optimize Parameters with requires_grad=True" | |||
| ) | |||
| for name, default in self._defaults.items(): | |||
| if default is required and name not in param_group: | |||
| @@ -122,7 +117,7 @@ class Optimizer(metaclass=ABCMeta): | |||
| initializer = np.zeros(param.shape, dtype=np.float32) | |||
| state_dict = self._state.setdefault(param, {}) | |||
| assert state_name not in state_dict | |||
| state = Buffer(initializer) | |||
| state = Tensor(initializer) | |||
| state_dict[state_name] = state | |||
| @abstractmethod | |||
| @@ -140,7 +135,7 @@ class Optimizer(metaclass=ABCMeta): | |||
| params.append(param) | |||
| return params | |||
| def step(self, clear_grad=False): | |||
| def step(self): | |||
| r"""Performs a single optimization step. | |||
| """ | |||
| @@ -152,8 +147,7 @@ class Optimizer(metaclass=ABCMeta): | |||
| "Please use a list instead." | |||
| ) | |||
| self._updates(group) | |||
| if clear_grad: | |||
| self.clear_grad() | |||
| return self | |||
| def clear_grad(self): | |||
| r"""Clear the grad buffer. | |||
| @@ -161,8 +155,7 @@ class Optimizer(metaclass=ABCMeta): | |||
| """ | |||
| for param_group in self.param_groups: | |||
| for param in param_group["params"]: | |||
| if getattr(param, "grad", None) is not None: | |||
| param.grad = None | |||
| param.grad = None | |||
| def state_dict(self) -> Dict: | |||
| r"""Export the optimizer state. | |||
| @@ -171,7 +164,7 @@ class Optimizer(metaclass=ABCMeta): | |||
| """ | |||
| param_groups = [] | |||
| state = dict() | |||
| param2id = TensorDict() | |||
| param2id = dict() | |||
| cur_id = 0 | |||
| for group in self.param_groups: | |||
| @@ -213,8 +206,9 @@ class Optimizer(metaclass=ABCMeta): | |||
| p = param_new | |||
| self._state[p] = state["state"][param_saved].copy() | |||
| for k, v in self._state[p].items(): | |||
| if isinstance(v, Buffer): | |||
| self._state[p][k] = Buffer(v.numpy()) | |||
| if isinstance(v, Tensor): | |||
| # TODO: maybe a more efficient way? | |||
| self._state[p][k] = Tensor(v.numpy()) | |||
| if set(group_new.keys()) != set(group_saved.keys()): | |||
| raise ValueError( | |||
| @@ -8,7 +8,7 @@ | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from typing import Iterable, Union | |||
| from ..tensor_nn import Parameter | |||
| from ..tensor import Parameter | |||
| from .optimizer import Optimizer | |||
| @@ -52,7 +52,7 @@ class SGD(Optimizer): | |||
| momentum = param_group["momentum"] | |||
| for param in param_group["params"]: | |||
| if not param.requires_grad or "grad" not in param.__dict__: | |||
| if param.grad is None: | |||
| continue | |||
| grad = param.grad | |||
| @@ -14,8 +14,7 @@ from .. import functional as F | |||
| from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype | |||
| from ..core.tensor.function import Function | |||
| from ..module import Module | |||
| from ..tensor import Tensor | |||
| from ..tensor_nn import Parameter | |||
| from ..tensor import Parameter, Tensor | |||
| from .utils import QuantMode, fake_quant_tensor, get_qparam_dict | |||
| @@ -13,7 +13,7 @@ import numpy as np | |||
| from .. import functional as F | |||
| from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype | |||
| from ..module import Module | |||
| from ..tensor_nn import Buffer | |||
| from ..tensor import Tensor | |||
| from .utils import QuantMode, Round, get_qparam_dict | |||
| @@ -82,8 +82,8 @@ class MinMaxObserver(Observer): | |||
| ): | |||
| super().__init__(dtype, narrow_range) | |||
| self.mode = mode | |||
| self.min_val = Buffer(np.finfo(np.float32).max, dtype=np.float32) | |||
| self.max_val = Buffer(np.finfo(np.float32).min, dtype=np.float32) | |||
| self.min_val = Tensor(np.finfo(np.float32).max, dtype=np.float32) | |||
| self.max_val = Tensor(np.finfo(np.float32).min, dtype=np.float32) | |||
| self.scale_limit = eps | |||
| def _calculate_qparams(self, inp_min_val, inp_max_val): | |||
| @@ -118,8 +118,8 @@ class MinMaxObserver(Observer): | |||
| # stop gradient | |||
| x = x_orig.detach() | |||
| # find max and min | |||
| self.min_val.set_value(F.minimum(self.min_val, x.min())) | |||
| self.max_val.set_value(F.maximum(self.max_val, x.max())) | |||
| self.min_val._reset(F.minimum(self.min_val, x.min())) | |||
| self.max_val._reset(F.maximum(self.max_val, x.max())) | |||
| return x_orig | |||
| @@ -133,22 +133,22 @@ class ExponentialMovingAverageObserver(MinMaxObserver): | |||
| narrow_range: bool = False, | |||
| ): | |||
| super().__init__(mode, eps, dtype, narrow_range) | |||
| self.momentum = Buffer(momentum) | |||
| self.runtime_momentum = Buffer(0.0) | |||
| self.momentum = Tensor(momentum) | |||
| self.runtime_momentum = Tensor(0.0) | |||
| def set_momentum(self, momentum): | |||
| self.momentum.set_value(momentum) | |||
| self.momentum._reset(momentum) | |||
| def forward(self, x_orig): | |||
| if self.enabled: | |||
| # stop gradient | |||
| x = x_orig.detach() | |||
| # Exponential Moving Average | |||
| self.min_val.set_value( | |||
| self.min_val._reset( | |||
| self.min_val * self.runtime_momentum | |||
| + (1 - self.runtime_momentum) * x.min() | |||
| ) | |||
| self.max_val.set_value( | |||
| self.max_val._reset( | |||
| self.max_val * self.runtime_momentum | |||
| + (1 - self.runtime_momentum) * x.max() | |||
| ) | |||
| @@ -171,7 +171,7 @@ class HistogramObserver(MinMaxObserver): | |||
| self.bins = bins | |||
| self.upsample_rate = upsample_rate | |||
| self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1 | |||
| self.histogram = Buffer([-1] + [0.0] * (bins - 1)) | |||
| self.histogram = Tensor([-1] + [0.0] * (bins - 1)) | |||
| def _non_linear_param_search(self): | |||
| r"""Non-linear parameter search. | |||
| @@ -395,9 +395,9 @@ class HistogramObserver(MinMaxObserver): | |||
| self.bins, | |||
| ) | |||
| self.histogram.set_value(new_histogram) | |||
| self.min_val.set_value(new_min) | |||
| self.max_val.set_value(new_max) | |||
| self.histogram._reset(new_histogram) | |||
| self.min_val._reset(new_min) | |||
| self.max_val._reset(new_max) | |||
| def forward(self, x_orig): | |||
| self.sideeffect_forward(x_orig) | |||
| @@ -14,10 +14,11 @@ from .core import Tensor as _Tensor | |||
| from .core.ops.builtin import Copy | |||
| from .core.tensor.core import apply | |||
| from .device import get_default_device | |||
| from .utils.deprecation import deprecated | |||
| class Tensor(_Tensor): | |||
| requires_grad = False | |||
| grad = None | |||
| dmap_callback = None | |||
| def __init__(self, data, dtype=None, device=None): | |||
| @@ -26,15 +27,32 @@ class Tensor(_Tensor): | |||
| self.q_dict = {"mode": None, "scale": None, "zero_point": None} | |||
| super().__init__(data, dtype=dtype, device=device) | |||
| @deprecated(version="1.0", reason="no need to reuse an existing tensor since 1.0") | |||
| def set_value(self, value): | |||
| self._reset(value) | |||
| @deprecated(version="1.0", reason="use *= 0 instead") | |||
| def reset_zero(self): | |||
| self *= 0 | |||
| def to(self, cn): | |||
| return apply(Copy(comp_node=cn), self)[0] | |||
| @property | |||
| def requires_grad(self): | |||
| raise AttributeError("requires_grad is reserved for future use") | |||
| @requires_grad.setter | |||
| def requires_grad(self, value): | |||
| raise AttributeError("requires_grad is reserved for future use") | |||
| @requires_grad.deleter | |||
| def requires_grad(self): | |||
| raise AttributeError("requires_grad is reserved for future use") | |||
| def __hash__(self): | |||
| return id(self) | |||
| def __getstate__(self): | |||
| r""" __getstate__ will be called for pickle serialization or deep copy | |||
| """ | |||
| @@ -73,53 +91,6 @@ class Tensor(_Tensor): | |||
| tensor = Tensor | |||
| class Dict(collections.MutableMapping): | |||
| def __init__(self, *args, key=None, **kwargs): | |||
| self.data = {} | |||
| if key: | |||
| self.keyfn = key | |||
| for i in args: | |||
| self.update(i) | |||
| self.update(**kwargs) | |||
| @staticmethod | |||
| def keyfn(key): # pylint: disable=method-hidden | |||
| return key | |||
| def __getitem__(self, key): | |||
| _, v = self.data[self.keyfn(key)] | |||
| return v | |||
| def __setitem__(self, key, value): | |||
| self.data[self.keyfn(key)] = key, value | |||
| def __delitem__(self, key): | |||
| del self.data[self.keyfn(key)] | |||
| def __iter__(self): | |||
| for _, (k, _) in self.data.items(): | |||
| yield k | |||
| def __len__(self): | |||
| return len(self.data) | |||
| class TensorDict(Dict): # pylint: disable=too-many-ancestors | |||
| class keyfn: | |||
| def __new__(cls, x: Tensor): | |||
| if not isinstance(x, Tensor): | |||
| return x | |||
| return super().__new__(cls) | |||
| def __init__(self, x: Tensor): | |||
| self._data = x # do not save id directly to make pickle work | |||
| def __hash__(self): | |||
| return id(self._data) | |||
| def __eq__(self, other): | |||
| # pylint: disable=undefined-variable | |||
| return isinstance(other, __class__) and id(self._data) == id(other._data) | |||
| def __init__(self, *args): | |||
| super().__init__(*args) | |||
| class Parameter(Tensor): | |||
| r"""A kind of Tensor that is to be considered a module parameter. | |||
| """ | |||
| @@ -1,20 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from . import Tensor, tensor | |||
| class Buffer(Tensor): | |||
| r"""A kind of Tensor with ``requires_grad=False``. | |||
| """ | |||
| class Parameter(Tensor): | |||
| r"""A kind of Tensor that is to be considered a module parameter. | |||
| """ | |||
| requires_grad = True | |||
| @@ -0,0 +1 @@ | |||
| from deprecated.sphinx import deprecated | |||
| @@ -15,7 +15,7 @@ def get_ndtuple(value, *, n, allow_zero=True): | |||
| :type allow_zero: bool | |||
| :param allow_zero: whether to allow zero tuple value""" | |||
| if not isinstance(value, collections.Iterable): | |||
| if not isinstance(value, collections.abc.Iterable): | |||
| value = int(value) | |||
| value = tuple([value for i in range(n)]) | |||
| else: | |||
| @@ -5,3 +5,4 @@ requests | |||
| tabulate | |||
| tqdm | |||
| redispy | |||
| deprecated | |||
| @@ -38,7 +38,7 @@ class Simple2(Module): | |||
| def test_advance_indexing(): | |||
| net = Simple() | |||
| gm = ad.GradManager().register(net.parameters()) | |||
| gm = ad.GradManager().attach(net.parameters()) | |||
| optim = optimizer.SGD(net.parameters(), lr=1.0) | |||
| optim.clear_grad() | |||
| @@ -48,7 +48,7 @@ def test_advance_indexing(): | |||
| data = tensor(raw_data) | |||
| mask = tensor(raw_mask) | |||
| answer = 1.0 - raw_data[raw_mask].sum() | |||
| with gm.record(): | |||
| with gm: | |||
| loss = net(data, mask).sum() | |||
| gm.backward(loss) | |||
| optim.step() | |||
| @@ -58,7 +58,7 @@ def test_advance_indexing(): | |||
| def test_advance_indexing_with_subtensor(): | |||
| net = Simple2() | |||
| gm = ad.GradManager().register(net.parameters()) | |||
| gm = ad.GradManager().attach(net.parameters()) | |||
| optim = optimizer.SGD(net.parameters(), lr=1.0) | |||
| optim.clear_grad() | |||
| @@ -66,7 +66,7 @@ def test_advance_indexing_with_subtensor(): | |||
| raw_data = np.arange(576).reshape(dshape).astype(np.float32) | |||
| data = tensor(raw_data) | |||
| answer = 1.0 - raw_data[1, ..., :, 0:4:2, 0:2].sum() | |||
| with gm.record(): | |||
| with gm: | |||
| loss = net(data).sum() | |||
| gm.backward(loss) | |||
| optim.step() | |||
| @@ -28,13 +28,13 @@ class Simple(Module): | |||
| def test_ai(): | |||
| net = Simple() | |||
| gm = ad.GradManager().register(net.parameters()) | |||
| gm = ad.GradManager().attach(net.parameters()) | |||
| optim = optimizer.SGD(net.parameters(), lr=1.0) | |||
| optim.clear_grad() | |||
| dshape = (10, 10) | |||
| data = tensor(np.ones(dshape).astype(np.float32)) | |||
| with gm.record(): | |||
| with gm: | |||
| loss = net(data).sum() | |||
| gm.backward(loss) | |||
| optim.step() | |||
| @@ -25,12 +25,12 @@ def test_frozen_bn(): | |||
| saved_wt = m.weight.numpy() | |||
| saved_bias = m.bias.numpy() | |||
| gm = ad.GradManager().register(m.parameters()) | |||
| gm = ad.GradManager().attach(m.parameters()) | |||
| optim = optimizer.SGD(m.parameters(), lr=1.0) | |||
| optim.clear_grad() | |||
| data = np.random.random((6, nchannel, 2, 2)).astype("float32") | |||
| with gm.record(): | |||
| with gm: | |||
| loss = m(data).mean() | |||
| gm.backward(loss) | |||
| optim.step() | |||
| @@ -46,12 +46,12 @@ def test_bn_no_track_stat(): | |||
| nchannel = 3 | |||
| m = BatchNorm2d(nchannel, track_running_stats=False) | |||
| gm = ad.GradManager().register(m.parameters()) | |||
| gm = ad.GradManager().attach(m.parameters()) | |||
| optim = optimizer.SGD(m.parameters(), lr=1.0) | |||
| optim.clear_grad() | |||
| data = np.random.random((6, nchannel, 2, 2)).astype("float32") | |||
| with gm.record(): | |||
| with gm: | |||
| loss = m(data).sum() | |||
| gm.backward(loss) | |||
| optim.step() | |||
| @@ -68,12 +68,12 @@ def test_bn_no_track_stat2(): | |||
| saved_mean = m.running_mean.numpy() | |||
| assert saved_mean is not None | |||
| gm = ad.GradManager().register(m.parameters()) | |||
| gm = ad.GradManager().attach(m.parameters()) | |||
| optim = optimizer.SGD(m.parameters(), lr=1.0) | |||
| optim.clear_grad() | |||
| data = np.random.random((6, nchannel, 2, 2)).astype("float32") | |||
| with gm.record(): | |||
| with gm: | |||
| loss = m(data).sum() | |||
| gm.backward(loss) | |||
| optim.step() | |||
| @@ -74,13 +74,11 @@ class XORNet(Module): | |||
| def test_training_converge(): | |||
| net = XORNet() | |||
| opt = SGD( | |||
| net.parameters(requires_grad=True), lr=0.01, momentum=0.9, weight_decay=5e-4 | |||
| ) | |||
| gm = ad.GradManager().register(net.parameters()) | |||
| opt = SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4) | |||
| gm = ad.GradManager().attach(net.parameters()) | |||
| def train(data, label): | |||
| with gm.record(): | |||
| with gm: | |||
| pred = net(data) | |||
| loss = F.cross_entropy_with_softmax(pred, label) | |||
| gm.backward(loss) | |||
| @@ -91,7 +91,7 @@ class MnistNet(Module): | |||
| def train(data, label, net, opt, gm): | |||
| with gm.record(): | |||
| with gm: | |||
| pred = net(data) | |||
| loss = F.cross_entropy_with_softmax(pred, label) | |||
| gm.backward(loss) | |||
| @@ -117,7 +117,7 @@ def update_model(model_path): | |||
| net.load_state_dict(checkpoint["net_init"]) | |||
| lr = checkpoint["sgd_lr"] | |||
| opt = SGD(net.parameters(), lr=lr) | |||
| gm = ad.GradManager().register(net.parameters()) | |||
| gm = ad.GradManager().attach(net.parameters()) | |||
| data = Tensor(checkpoint["data"], dtype=np.float32) | |||
| label = Tensor(checkpoint["label"], dtype=np.int32) | |||
| @@ -152,7 +152,7 @@ def run_train( | |||
| net.load_state_dict(checkpoint["net_init"]) | |||
| lr = checkpoint["sgd_lr"] | |||
| opt = SGD(net.parameters(), lr=lr) | |||
| gm = ad.GradManager().register(net.parameters()) | |||
| gm = ad.GradManager().attach(net.parameters()) | |||
| data = Tensor(checkpoint["data"], dtype=np.float32) | |||
| label = Tensor(checkpoint["label"], dtype=np.int32) | |||
| @@ -32,11 +32,11 @@ def test_detach(): | |||
| optim = optimizer.SGD(net.parameters(), lr=1.0) | |||
| optim.clear_grad() | |||
| gm = ad.GradManager().register(net.parameters()) | |||
| gm = ad.GradManager().attach(net.parameters()) | |||
| dshape = (10, 10) | |||
| data = tensor(np.ones(dshape).astype(np.float32)) | |||
| with gm.record(): | |||
| with gm: | |||
| loss = net(data).sum() | |||
| gm.backward(loss) | |||
| optim.step() | |||
| @@ -97,7 +97,7 @@ class MnistNet(Module): | |||
| def train(data, label, net, opt, gm): | |||
| opt.clear_grad() | |||
| with gm.record(): | |||
| with gm: | |||
| pred = net(data) | |||
| loss = F.cross_entropy_with_softmax(pred, label) | |||
| gm.backward(loss) | |||
| @@ -125,8 +125,7 @@ def update_model(model_path): | |||
| lr = checkpoint["sgd_lr"] | |||
| opt = SGD(net.parameters(), lr=lr) | |||
| gm = ad.GradManager() | |||
| gm.register( | |||
| gm = ad.GradManager().attach( | |||
| net.parameters(), callbacks=[dist.make_allreduce_cb("MEAN", dist.WORLD)] | |||
| ) | |||
| @@ -171,8 +170,7 @@ def run_test( | |||
| lr = checkpoint["sgd_lr"] | |||
| opt = SGD(net.parameters(), lr=lr) | |||
| gm = ad.GradManager() | |||
| gm.register( | |||
| gm = ad.GradManager().attach( | |||
| net.parameters(), callbacks=[dist.make_allreduce_cb("MEAN", dist.WORLD)] | |||
| ) | |||
| @@ -33,10 +33,10 @@ def test_hello_world(): | |||
| optim = optimizer.SGD(net.parameters(), lr=1.0) | |||
| optim.clear_grad() | |||
| gm = ad.GradManager().register(net.parameters()) | |||
| gm = ad.GradManager().attach(net.parameters()) | |||
| data = tensor([2.34]) | |||
| with gm.record(): | |||
| with gm: | |||
| loss = net(data) | |||
| gm.backward(loss) | |||
| optim.step() | |||
| @@ -13,7 +13,7 @@ import megengine.functional as F | |||
| from megengine import Parameter, optimizer | |||
| from megengine.jit import trace | |||
| from megengine.module import Linear, Module | |||
| from megengine.tensor import TensorDict, tensor | |||
| from megengine.tensor import tensor | |||
| class MLP(Module): | |||
| @@ -44,7 +44,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False): | |||
| net = Simple() | |||
| opt = getattr(optimizer, opt_str)(net.parameters(), **test_case) | |||
| check_func = check_class(net, **test_case) | |||
| gm = ad.GradManager().register(net.parameters()) | |||
| gm = ad.GradManager().attach(net.parameters()) | |||
| step = 0 | |||
| data_shape = (2, 28) | |||
| @@ -57,12 +57,12 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False): | |||
| data = tensor(np.random.random(data_shape).astype(np.float32)) | |||
| opt.clear_grad() | |||
| with gm.record(): | |||
| with gm: | |||
| pred = net(data) | |||
| loss = pred.sum() | |||
| gm.backward(loss) | |||
| ori_params = TensorDict() | |||
| ori_params = {} | |||
| for param in net.parameters(): | |||
| ori_params[param] = np.copy(param.numpy()) | |||
| opt.step() | |||
| @@ -75,7 +75,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False): | |||
| @trace(symbolic=symbolic) | |||
| def train_func(data, *, opt=None, gm=None): | |||
| opt.clear_grad() | |||
| with gm.record(): | |||
| with gm: | |||
| pred = net(data) | |||
| loss = pred.sum() | |||
| gm.backward(loss) | |||
| @@ -84,7 +84,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False): | |||
| # reset net and opt | |||
| net = Simple() | |||
| opt = getattr(optimizer, opt_str)(net.parameters(), **test_case) | |||
| gm = ad.GradManager().register(net.parameters()) | |||
| gm = ad.GradManager().attach(net.parameters()) | |||
| check_func = check_class(net, **test_case) | |||
| step = 0 | |||
| for i in range(iter_num): | |||
| @@ -93,7 +93,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False): | |||
| group["lr"] += 0.01 | |||
| check_func.lr += 0.01 | |||
| ori_params = TensorDict() | |||
| ori_params = {} | |||
| for param in net.parameters(): | |||
| ori_params[param] = np.copy(param.numpy()) | |||
| @@ -105,7 +105,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False): | |||
| def test_sgd(): | |||
| class CheckValue: | |||
| def __init__(self, net, **kwarg): | |||
| self.slots = TensorDict() | |||
| self.slots = {} | |||
| for param in net.parameters(): | |||
| self.slots[param] = np.zeros(param.shape).astype(np.float32) | |||
| for k, v in kwarg.items(): | |||
| @@ -134,8 +134,8 @@ def test_sgd(): | |||
| def test_adam(): | |||
| class CheckValue: | |||
| def __init__(self, net, **kwarg): | |||
| self.m_slots = TensorDict() | |||
| self.v_slots = TensorDict() | |||
| self.m_slots = {} | |||
| self.v_slots = {} | |||
| for param in net.parameters(): | |||
| self.m_slots[param] = np.zeros(param.shape).astype(np.float32) | |||
| self.v_slots[param] = np.zeros(param.shape).astype(np.float32) | |||
| @@ -175,7 +175,7 @@ def test_adam(): | |||
| def test_adagrad(): | |||
| class CheckValue: | |||
| def __init__(self, net, **kwarg): | |||
| self.s_slots = TensorDict() | |||
| self.s_slots = {} | |||
| for param in net.parameters(): | |||
| self.s_slots[param] = np.zeros(param.shape).astype(np.float32) | |||
| for k, v in kwarg.items(): | |||
| @@ -207,8 +207,8 @@ def test_adagrad(): | |||
| def test_adadelta(): | |||
| class CheckValue: | |||
| def __init__(self, net, **kwarg): | |||
| self.s_slots = TensorDict() | |||
| self.a_slots = TensorDict() | |||
| self.s_slots = {} | |||
| self.a_slots = {} | |||
| for param in net.parameters(): | |||
| self.s_slots[param] = np.zeros(param.shape).astype(np.float32) | |||
| self.a_slots[param] = np.zeros(param.shape).astype(np.float32) | |||
| @@ -23,11 +23,11 @@ def test_save_load(): | |||
| optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9) | |||
| optim.clear_grad() | |||
| gm = ad.GradManager().register(net.parameters()) | |||
| gm = ad.GradManager().attach(net.parameters()) | |||
| data = tensor([2.34]) | |||
| with gm.record(): | |||
| with gm: | |||
| loss = net(data) | |||
| gm.backward(loss) | |||
| @@ -55,7 +55,7 @@ def test_save_load(): | |||
| optim.load_state_dict(checkpoint["opt_state"]) | |||
| print("load done") | |||
| with gm.record(): | |||
| with gm: | |||
| loss = net([1.23]) | |||
| gm.backward(loss) | |||
| @@ -31,12 +31,12 @@ def test_sgd_momentum(): | |||
| optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9) | |||
| optim.clear_grad() | |||
| gm = ad.GradManager().register(net.parameters()) | |||
| gm = ad.GradManager().attach(net.parameters()) | |||
| data = tensor([2.34]) | |||
| # do a step of train | |||
| with gm.record(): | |||
| with gm: | |||
| loss = net(data) | |||
| gm.backward(loss) | |||
| optim.step() | |||
| @@ -51,7 +51,7 @@ def test_sgd_momentum(): | |||
| # do a step of train | |||
| optim.clear_grad() | |||
| with gm.record(): | |||
| with gm: | |||
| loss = net(data) | |||
| gm.backward(loss) | |||
| optim.step() | |||
| @@ -69,7 +69,7 @@ def test_sgd_momentum_trace(): | |||
| @trace(symbolic=symbolic) | |||
| def train_func(data, *, model=None, optim=None, gm=None): | |||
| optim.clear_grad() | |||
| with gm.record(): | |||
| with gm: | |||
| loss = net(data) | |||
| gm.backward(loss) | |||
| optim.step() | |||
| @@ -82,7 +82,7 @@ def test_sgd_momentum_trace(): | |||
| net = Simple() | |||
| optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9) | |||
| gm = ad.GradManager().register(net.parameters()) | |||
| gm = ad.GradManager().attach(net.parameters()) | |||
| data = tensor([2.34]) | |||
| train_func(data, model=net, optim=optim, gm=gm) | |||
| np.testing.assert_almost_equal( | |||
| @@ -61,15 +61,15 @@ class XORNet(M.Module): | |||
| def test_xornet_trace_dump(): | |||
| net = XORNet() | |||
| opt = optim.SGD(net.parameters(requires_grad=True), lr=0.01, momentum=0.9) | |||
| gm = GradManager().register(net.parameters(requires_grad=True)) | |||
| opt = optim.SGD(net.parameters(), lr=0.01, momentum=0.9) | |||
| gm = GradManager().attach(net.parameters()) | |||
| batch_size = 64 | |||
| train_dataset = minibatch_generator(batch_size) | |||
| val_dataset = minibatch_generator(batch_size) | |||
| @trace | |||
| def train_fun(data, label): | |||
| with gm.record(): | |||
| with gm: | |||
| net.train() | |||
| pred = net(data) | |||
| loss = F.cross_entropy_with_softmax(pred, label) | |||
| @@ -14,7 +14,7 @@ import pytest | |||
| import megengine.core.ops.builtin as builtin | |||
| import megengine.core.tensor.dtype as dtype | |||
| import megengine.functional as F | |||
| from megengine import Buffer, Parameter, is_cuda_available, tensor | |||
| from megengine import Parameter, Tensor, is_cuda_available, tensor | |||
| from megengine.core._trace_option import use_tensor_shape | |||
| from megengine.core.autodiff.grad import Grad | |||
| from megengine.core.tensor.utils import make_shape_tuple | |||
| @@ -330,7 +330,7 @@ def test_roi_pooling(): | |||
| def test_add_update(): | |||
| shape = (2, 3) | |||
| v = np.random.random(shape).astype(np.float32) | |||
| b = Buffer(v) | |||
| b = Tensor(v) | |||
| u = F.add_update(b, 1) | |||
| assertTensorClose(u.numpy(), v + 1) | |||
| @@ -347,7 +347,7 @@ def test_add_update(): | |||
| def test_add_update_params(): | |||
| b = np.random.random((2, 3)).astype(np.float32) | |||
| y = Buffer(b) | |||
| y = Tensor(b) | |||
| # @jit.trace | |||
| def f(x): | |||
| @@ -355,7 +355,7 @@ def test_add_update_params(): | |||
| f(np.zeros((2, 3)).astype(np.float32)) | |||
| z = Buffer(np.zeros((2, 3)).astype(np.float32)) | |||
| z = Tensor(np.zeros((2, 3)).astype(np.float32)) | |||
| F.add_update(y, z, beta=0.1) | |||
| res = f(np.ones((2, 3)).astype(np.float32)) | |||
| @@ -12,7 +12,7 @@ import numpy as np | |||
| import pytest | |||
| import megengine.functional as F | |||
| from megengine import Buffer, Parameter, is_cuda_available, tensor | |||
| from megengine import tensor | |||
| from megengine.core._trace_option import use_tensor_shape | |||
| from megengine.core.tensor.utils import astensor1d | |||
| from megengine.distributed.helper import get_device_count_by_fork | |||
| @@ -14,10 +14,9 @@ import pytest | |||
| import megengine as mge | |||
| import megengine.distributed as dist | |||
| from megengine import tensor | |||
| from megengine import Tensor | |||
| from megengine.core._trace_option import use_tensor_shape | |||
| from megengine.module import BatchNorm1d, BatchNorm2d, SyncBatchNorm | |||
| from megengine.tensor import Tensor | |||
| from megengine.test import assertTensorClose | |||
| @@ -45,10 +44,8 @@ def test_syncbn(): | |||
| return | |||
| dist.init_process_group("localhost", port, nr_ranks, rank, rank) | |||
| bn = SyncBatchNorm(nr_chan, momentum=momentum, eps=eps) | |||
| data_tensor = tensor([]) | |||
| for i in range(steps): | |||
| data_tensor.set_value(data[i]) | |||
| yv = bn(data_tensor) | |||
| yv = bn(Tensor(data[i])) | |||
| assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6) | |||
| assertTensorClose(running_mean, bn.running_mean.numpy(), max_err=5e-6) | |||
| @@ -105,7 +102,6 @@ def test_batchnorm(): | |||
| bn = BatchNorm1d(nr_chan, momentum=momentum) | |||
| running_mean = np.zeros((1, nr_chan, 1), dtype=np.float32) | |||
| running_var = np.ones((1, nr_chan, 1), dtype=np.float32) | |||
| data = tensor([]) | |||
| for i in range(3): | |||
| xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32) | |||
| mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True) | |||
| @@ -120,8 +116,7 @@ def test_batchnorm(): | |||
| running_mean = running_mean * momentum + mean * (1 - momentum) | |||
| running_var = running_var * momentum + var_unbiased * (1 - momentum) | |||
| data.set_value(xv) | |||
| yv = bn(data) | |||
| yv = bn(Tensor(xv)) | |||
| yv_expect = (xv - mean) / sd | |||
| assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6) | |||
| @@ -137,7 +132,7 @@ def test_batchnorm(): | |||
| var_backup = bn.running_var.numpy() | |||
| bn.training = False | |||
| xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32) | |||
| data.set_value(xv) | |||
| data = Tensor(xv) | |||
| yv1 = bn(data) | |||
| yv2 = bn(data) | |||
| assertTensorClose(yv1.numpy(), yv2.numpy(), max_err=0) | |||
| @@ -161,7 +156,6 @@ def test_syncbn1d(): | |||
| bn = SyncBatchNorm(nr_chan, momentum=momentum) | |||
| running_mean = np.zeros((1, nr_chan, 1), dtype=np.float32) | |||
| running_var = np.ones((1, nr_chan, 1), dtype=np.float32) | |||
| data = tensor([]) | |||
| for i in range(3): | |||
| xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32) | |||
| mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True) | |||
| @@ -176,8 +170,7 @@ def test_syncbn1d(): | |||
| running_mean = running_mean * momentum + mean * (1 - momentum) | |||
| running_var = running_var * momentum + var_unbiased * (1 - momentum) | |||
| data.set_value(xv) | |||
| yv = bn(data) | |||
| yv = bn(Tensor(xv)) | |||
| yv_expect = (xv - mean) / sd | |||
| assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6) | |||
| @@ -193,7 +186,7 @@ def test_syncbn1d(): | |||
| var_backup = bn.running_var.numpy() | |||
| bn.training = False | |||
| xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32) | |||
| data.set_value(xv) | |||
| data = Tensor(xv) | |||
| yv1 = bn(data) | |||
| yv2 = bn(data) | |||
| assertTensorClose(yv1.numpy(), yv2.numpy(), max_err=0) | |||
| @@ -210,7 +203,6 @@ def test_batchnorm2d(): | |||
| bn = BatchNorm2d(nr_chan, momentum=momentum) | |||
| running_mean = np.zeros((1, nr_chan, 1, 1), dtype=np.float32) | |||
| running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32) | |||
| data = tensor([]) | |||
| for i in range(3): | |||
| xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32) | |||
| xv_transposed = np.transpose(xv, [0, 2, 3, 1]).reshape( | |||
| @@ -226,8 +218,7 @@ def test_batchnorm2d(): | |||
| running_mean = running_mean * momentum + mean * (1 - momentum) | |||
| running_var = running_var * momentum + var_unbiased * (1 - momentum) | |||
| data.set_value(xv) | |||
| yv = bn(data) | |||
| yv = bn(Tensor(xv)) | |||
| yv_expect = (xv - mean) / sd | |||
| assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6) | |||
| @@ -239,7 +230,7 @@ def test_batchnorm2d(): | |||
| var_backup = bn.running_var.numpy() | |||
| bn.training = False | |||
| xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32) | |||
| data.set_value(xv) | |||
| data = Tensor(xv) | |||
| yv1 = bn(data) | |||
| yv2 = bn(data) | |||
| assertTensorClose(yv1.numpy(), yv2.numpy(), max_err=0) | |||
| @@ -263,7 +254,6 @@ def test_syncbn2d(): | |||
| bn = SyncBatchNorm(nr_chan, momentum=momentum) | |||
| running_mean = np.zeros((1, nr_chan, 1, 1), dtype=np.float32) | |||
| running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32) | |||
| data = tensor([]) | |||
| for i in range(3): | |||
| xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32) | |||
| xv_transposed = np.transpose(xv, [0, 2, 3, 1]).reshape( | |||
| @@ -279,8 +269,7 @@ def test_syncbn2d(): | |||
| running_mean = running_mean * momentum + mean * (1 - momentum) | |||
| running_var = running_var * momentum + var_unbiased * (1 - momentum) | |||
| data.set_value(xv) | |||
| yv = bn(data) | |||
| yv = bn(Tensor(xv)) | |||
| yv_expect = (xv - mean) / sd | |||
| assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6) | |||
| @@ -292,7 +281,7 @@ def test_syncbn2d(): | |||
| var_backup = bn.running_var.numpy() | |||
| bn.training = False | |||
| xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32) | |||
| data.set_value(xv) | |||
| data = Tensor(xv) | |||
| yv1 = bn(data) | |||
| yv2 = bn(data) | |||
| assertTensorClose(yv1.numpy(), yv2.numpy(), max_err=0) | |||
| @@ -306,7 +295,6 @@ def test_batchnorm_no_stats(): | |||
| nr_chan = 8 | |||
| data_shape = (3, nr_chan, 4) | |||
| bn = BatchNorm1d(8, track_running_stats=False) | |||
| data = tensor([]) | |||
| for i in range(4): | |||
| if i == 2: | |||
| bn.training = False | |||
| @@ -320,8 +308,7 @@ def test_batchnorm_no_stats(): | |||
| ).reshape((1, nr_chan, 1)) | |||
| sd = np.sqrt(var + bn.eps) | |||
| data.set_value(xv) | |||
| yv = bn(data) | |||
| yv = bn(Tensor(xv)) | |||
| yv_expect = (xv - mean) / sd | |||
| assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6) | |||
| @@ -338,7 +325,6 @@ def test_syncbn_no_stats(): | |||
| nr_chan = 8 | |||
| data_shape = (3, nr_chan, 4) | |||
| bn = SyncBatchNorm(8, track_running_stats=False) | |||
| data = tensor([]) | |||
| for i in range(4): | |||
| if i == 2: | |||
| bn.training = False | |||
| @@ -352,8 +338,7 @@ def test_syncbn_no_stats(): | |||
| ).reshape((1, nr_chan, 1)) | |||
| sd = np.sqrt(var + bn.eps) | |||
| data.set_value(xv) | |||
| yv = bn(data) | |||
| yv = bn(Tensor(xv)) | |||
| yv_expect = (xv - mean) / sd | |||
| assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6) | |||
| @@ -363,7 +348,6 @@ def test_batchnorm2d_no_stats(): | |||
| nr_chan = 8 | |||
| data_shape = (3, nr_chan, 16, 16) | |||
| bn = BatchNorm2d(8, track_running_stats=False) | |||
| data = tensor([]) | |||
| for i in range(4): | |||
| if i == 2: | |||
| bn.training = False | |||
| @@ -376,8 +360,7 @@ def test_batchnorm2d_no_stats(): | |||
| var = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1)) | |||
| sd = np.sqrt(var + bn.eps) | |||
| data.set_value(xv) | |||
| yv = bn(data) | |||
| yv = bn(Tensor(xv)) | |||
| yv_expect = (xv - mean) / sd | |||
| assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6) | |||
| @@ -394,7 +377,6 @@ def test_syncbn2d_no_stats(): | |||
| nr_chan = 8 | |||
| data_shape = (3, nr_chan, 16, 16) | |||
| bn = SyncBatchNorm(8, track_running_stats=False) | |||
| data = tensor([]) | |||
| for i in range(4): | |||
| if i == 2: | |||
| bn.training = False | |||
| @@ -407,8 +389,7 @@ def test_syncbn2d_no_stats(): | |||
| var = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1)) | |||
| sd = np.sqrt(var + bn.eps) | |||
| data.set_value(xv) | |||
| yv = bn(data) | |||
| yv = bn(Tensor(xv)) | |||
| yv_expect = (xv - mean) / sd | |||
| assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6) | |||
| @@ -12,7 +12,7 @@ import numpy as np | |||
| import pytest | |||
| import megengine as mge | |||
| from megengine import tensor | |||
| from megengine import Tensor | |||
| from megengine.module import Module | |||
| @@ -35,12 +35,12 @@ def test_cambricon_module(): | |||
| with open(model, "rb") as f: | |||
| data = f.read() | |||
| m = MyModule(data) | |||
| inputs = [] | |||
| inputs.append(tensor(data=[], dtype=np.float16, device="cambricon0")) | |||
| inputs[0].set_value(np.random.normal(size=(1, 64, 32, 32)).astype(np.float16)) | |||
| inp = Tensor( | |||
| np.random.normal((1, 64, 32, 32)).astype(np.float16), device="cambricon0" | |||
| ) | |||
| def inference(inps): | |||
| pred = m(inps) | |||
| return pred | |||
| pred = inference(inputs) | |||
| pred = inference([inp]) | |||
| @@ -16,7 +16,7 @@ import pytest | |||
| import megengine as mge | |||
| import megengine.functional as F | |||
| from megengine import Buffer, Parameter, Tensor, tensor | |||
| from megengine import Parameter, Tensor, tensor | |||
| from megengine.module import ( | |||
| BatchNorm1d, | |||
| BatchNorm2d, | |||
| @@ -196,7 +196,7 @@ class MyModule(Module): | |||
| self.i = self.InnerModule() | |||
| self.bn = BatchNorm2d(4) | |||
| self.param = Parameter(np.ones(1, dtype=np.float32)) | |||
| self.buff = Buffer(np.ones(1, dtype=np.float32)) | |||
| self.buff = Tensor(np.ones(1, dtype=np.float32)) | |||
| def forward(self, x): | |||
| x = self.i(x) | |||
| @@ -464,8 +464,7 @@ def test_sequential_named_children(): | |||
| def test_state_dict(): | |||
| data_shape = (2, 28) | |||
| data = tensor([]) | |||
| data.set_value(np.random.random(data_shape)) | |||
| data = tensor(np.random.random(data_shape)) | |||
| mlp = MLP() | |||
| pred0 = mlp(data) | |||
| @@ -542,8 +541,7 @@ def test_shared_param(): | |||
| def test_pickle_module(): | |||
| data_shape = (2, 28) | |||
| data = tensor([]) | |||
| data.set_value(np.random.random(data_shape)) | |||
| data = tensor(np.random.random(data_shape)) | |||
| mlp = MLP() | |||
| # pickle before forward | |||
| with BytesIO() as fout: | |||
| @@ -568,8 +566,7 @@ def test_pickle_module(): | |||
| @pytest.mark.skip(reason="under development") | |||
| def test_dump_model(): | |||
| data_shape = (2, 28) | |||
| data = tensor([]) | |||
| data.set_value(np.random.random(data_shape)) | |||
| data = Tensor(np.random.random(data_shape)) | |||
| mlp = MLP() | |||
| pred = mlp(data) | |||
| f = tempfile.NamedTemporaryFile(delete=False) | |||
| @@ -13,7 +13,7 @@ import pytest | |||
| import megengine as mge | |||
| import megengine.functional as F | |||
| from megengine import Buffer, Parameter | |||
| from megengine import Parameter, Tensor | |||
| from megengine.module import Conv2d | |||
| from megengine.test import assertTensorClose | |||
| @@ -33,7 +33,7 @@ def test_set_value(): | |||
| @pytest.mark.skip(reason="fill unsupported") | |||
| def test_fill(): | |||
| a = Buffer(np.zeros((2, 3), dtype=np.float32)) | |||
| a = Tensor(np.zeros((2, 3), dtype=np.float32)) | |||
| a.fill(3) | |||
| assertTensorClose(a.numpy(), np.full((2, 3), 3, dtype=np.float32)) | |||
| a.fill(124.568) | |||
| @@ -80,7 +80,7 @@ def test_fill(): | |||
| # def test_shape_warning(): | |||
| # with Graph() as cg: | |||
| # cg.set_option("eager_evaluation", False) | |||
| # b = Buffer(np.ones((2, 3)).astype(np.float32)) | |||
| # b = Tensor(np.ones((2, 3)).astype(np.float32)) | |||
| # with pytest.warns(None) as record: | |||
| # print(b.shape) | |||
| # if len(record) != 0: | |||
| @@ -42,11 +42,11 @@ def test_single_input(): | |||
| return x | |||
| net = Simple(av) | |||
| gm = ad.GradManager().register(net.parameters()) | |||
| gm = ad.GradManager().attach(net.parameters()) | |||
| opt = optimizer.SGD(net.parameters(), lr=1.0) | |||
| opt.clear_grad() | |||
| with gm.record(): | |||
| with gm: | |||
| loss = net() | |||
| gm.backward(loss.sum()) | |||
| opt.step() | |||
| @@ -81,11 +81,11 @@ def test_multi_input(): | |||
| return x | |||
| net = Simple(av, bv) | |||
| gm = ad.GradManager().register(net.parameters()) | |||
| gm = ad.GradManager().attach(net.parameters()) | |||
| opt = optimizer.SGD(net.parameters(), lr=1.0) | |||
| opt.clear_grad() | |||
| with gm.record(): | |||
| with gm: | |||
| loss = net() | |||
| gm.backward(loss.sum()) | |||
| opt.step() | |||
| @@ -121,11 +121,11 @@ def test_multi_output(): | |||
| return x + y | |||
| net = Simple(av, bv) | |||
| gm = ad.GradManager().register(net.parameters()) | |||
| gm = ad.GradManager().attach(net.parameters()) | |||
| opt = optimizer.SGD(net.parameters(), lr=1.0) | |||
| opt.clear_grad() | |||
| with gm.record(): | |||
| with gm: | |||
| loss = net() | |||
| gm.backward(loss.sum()) | |||
| opt.step() | |||
| @@ -163,9 +163,9 @@ def test_skip_invalid_grad(): | |||
| net = Simple(av, bv) | |||
| optim = optimizer.SGD(net.parameters(), lr=1.0) | |||
| gm = ad.GradManager().register(net.parameters()) | |||
| gm = ad.GradManager().attach(net.parameters()) | |||
| optim.clear_grad() | |||
| with gm.record(): | |||
| with gm: | |||
| loss = net().sum() | |||
| gm.backward(loss) | |||
| optim.step() | |||
| @@ -198,10 +198,10 @@ def test_ste(): | |||
| av = np.random.random(data_shape).astype(np.float32) | |||
| net = Simple(av) | |||
| optim = optimizer.SGD(net.parameters(), lr=1.0) | |||
| gm = ad.GradManager().register(net.parameters()) | |||
| gm = ad.GradManager().attach(net.parameters()) | |||
| optim.clear_grad() | |||
| with gm.record(): | |||
| with gm: | |||
| loss = net() | |||
| gm.backward(loss.sum()) | |||
| optim.step() | |||
| @@ -256,9 +256,9 @@ def test_none_in_out_grad(): | |||
| b = tensor(np.array([2.0], dtype=np.float32)) | |||
| net = Simple(a, b) | |||
| optim = optimizer.SGD(net.parameters(), lr=1.0) | |||
| gm = ad.GradManager().register(net.parameters()) | |||
| gm = ad.GradManager().attach(net.parameters()) | |||
| optim.clear_grad() | |||
| with gm.record(): | |||
| with gm: | |||
| loss, _ = net() | |||
| gm.backward(loss) | |||
| optim.step() | |||
| @@ -293,10 +293,10 @@ def test_zero_grad(): | |||
| a = tensor(np.array([1.0], dtype=np.float32)) | |||
| net = Simple(a) | |||
| optim = optimizer.SGD(net.parameters(), lr=1.0) | |||
| gm = ad.GradManager().register(net.parameters()) | |||
| gm = ad.GradManager().attach(net.parameters()) | |||
| optim.clear_grad() | |||
| with gm.record(): | |||
| with gm: | |||
| loss = net() | |||
| gm.backward(loss.sum()) | |||
| optim.step() | |||
| @@ -38,7 +38,7 @@ def cvt_to_shape_desc(val, inpvar, config=None): | |||
| if isinstance(val, RawTensor): | |||
| return as_tensor(val, device) | |||
| if not isinstance(val, collections.Iterable): | |||
| if not isinstance(val, collections.abc.Iterable): | |||
| val = [val] | |||
| components = [] | |||
| @@ -12,19 +12,18 @@ from tempfile import TemporaryFile | |||
| import numpy as np | |||
| import megengine as mge | |||
| from megengine import Buffer, Parameter, tensor | |||
| from megengine import Parameter, Tensor | |||
| def test_tensor_serialization(): | |||
| def tensor_eq(a, b): | |||
| assert a.dtype == b.dtype | |||
| assert a.device == b.device | |||
| assert a.requires_grad == b.requires_grad | |||
| np.testing.assert_equal(a.numpy(), b.numpy()) | |||
| with TemporaryFile() as f: | |||
| data = np.random.randint(low=0, high=7, size=[233]) | |||
| a = tensor(data, device="xpux", dtype=np.int32) | |||
| a = Tensor(data, device="xpux", dtype=np.int32) | |||
| pickle.dump(a, f) | |||
| f.seek(0) | |||
| b = pickle.load(f) | |||
| @@ -39,19 +38,19 @@ def test_tensor_serialization(): | |||
| np.testing.assert_equal(a.numpy(), b.numpy()) | |||
| with TemporaryFile() as f: | |||
| a = Buffer(np.random.random(size=(2, 233)).astype(np.float32)) | |||
| a = Tensor(np.random.random(size=(2, 233)).astype(np.float32)) | |||
| pickle.dump(a, f) | |||
| f.seek(0) | |||
| b = pickle.load(f) | |||
| assert isinstance(b, Buffer) | |||
| assert type(b) is Tensor | |||
| np.testing.assert_equal(a.numpy(), b.numpy()) | |||
| with TemporaryFile() as f: | |||
| a = Buffer(np.random.random(size=(2, 233)).astype(np.float32)) | |||
| a = Tensor(np.random.random(size=(2, 233)).astype(np.float32)) | |||
| mge.save(a, f) | |||
| f.seek(0) | |||
| b = mge.load(f, map_location="cpux") | |||
| assert isinstance(b, Buffer) | |||
| assert type(b) is Tensor | |||
| assert "cpu" in str(b.device) | |||
| np.testing.assert_equal(a.numpy(), b.numpy()) | |||
| @@ -59,12 +58,12 @@ def test_tensor_serialization(): | |||
| if mge.is_cuda_available(): | |||
| device_org = mge.get_default_device() | |||
| mge.set_default_device("gpu0") | |||
| a = Buffer(np.random.random(size=(2, 233)).astype(np.float32)) | |||
| a = Tensor(np.random.random(size=(2, 233)).astype(np.float32)) | |||
| mge.save(a, f) | |||
| f.seek(0) | |||
| mge.set_default_device("cpux") | |||
| b = mge.load(f, map_location={"gpu0": "cpu0"}) | |||
| assert isinstance(b, Buffer) | |||
| assert type(b) is Tensor | |||
| assert "cpu0" in str(b.device) | |||
| np.testing.assert_equal(a.numpy(), b.numpy()) | |||
| mge.set_default_device(device_org) | |||
| @@ -66,7 +66,7 @@ def main(): | |||
| mge.set_default_device("cpux") | |||
| net = XORNet() | |||
| opt = optim.SGD(net.parameters(requires_grad=True), lr=0.01, momentum=0.9) | |||
| opt = optim.SGD(net.parameters(), lr=0.01, momentum=0.9) | |||
| batch_size = 64 | |||
| train_dataset = minibatch_generator(batch_size) | |||
| val_dataset = minibatch_generator(batch_size) | |||