GitOrigin-RevId: 1c728d6ab9
tags/v1.10.0
@@ -50,36 +50,36 @@ class autocast:
         self._origin_enabled = None
         self._origin_high = None
         self._origin_low = None
+        self._origin_compute_mode = None
         self._origin_configs = None

     def __enter__(self):
-        self._origin_enabled = amp._enabled
-        amp._enabled = self.enabled
-        amp._set_amp_dtype_autocast(self.enabled)
-        if not self.enabled:
-            return
+        if self.enabled:
+            self._origin_enabled = amp._enabled
+            self._origin_high = amp._get_amp_high_prec_dtype()
+            self._origin_low = amp._get_amp_low_prec_dtype()
+            amp._enabled = self.enabled
+            amp._set_amp_dtype_autocast(self.enabled)
+            amp._set_amp_high_prec_dtype(self.high_prec_dtype)
+            amp._set_amp_low_prec_dtype(self.low_prec_dtype)
-        self._origin_high = amp._get_amp_high_prec_dtype()
-        self._origin_low = amp._get_amp_low_prec_dtype()
-        amp._set_amp_high_prec_dtype(self.high_prec_dtype)
-        amp._set_amp_low_prec_dtype(self.low_prec_dtype)
-        self._origin_configs = _config._reset_execution_config(compute_mode="float32")
+            self._origin_configs = _config._reset_execution_config(
+                compute_mode="float32"
+            )

     def __exit__(self, *args):
-        amp._enabled = self._origin_enabled
-        amp._set_amp_dtype_autocast(self._origin_enabled)
-        if not self.enabled:
-            return
-        amp._set_amp_high_prec_dtype(self._origin_high)
-        amp._set_amp_low_prec_dtype(self._origin_low)
+        if self.enabled:
+            amp._enabled = self._origin_enabled
+            amp._set_amp_dtype_autocast(self._origin_enabled)
+            amp._set_amp_high_prec_dtype(self._origin_high)
+            amp._set_amp_low_prec_dtype(self._origin_low)
+            _config._reset_execution_config(*self._origin_configs)

     def __call__(self, func):
         @functools.wraps(func)
         def wrapper(*args, **kwargs):
-            if not self.enabled:
-                return func(*args, **kwargs)
             with self:
                 return func(*args, **kwargs)
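Editor's note: with this change, `autocast.__enter__`/`__exit__` touch the global AMP state only when `enabled` is true, so a disabled context (and the decorator form) becomes a plain pass-through. A minimal usage sketch of the `autocast` API touched above, assuming a MegEngine build that ships `megengine.amp.autocast`; tensor shapes are illustrative only.

    import numpy as np
    import megengine as mge
    import megengine.functional as F
    from megengine import amp

    x = mge.tensor(np.random.randn(1, 3, 8, 8).astype("float32"))
    w = mge.tensor(np.random.randn(4, 3, 1, 1).astype("float32"))

    # Context-manager form: the low/high precision dtypes apply only inside the block.
    with amp.autocast(enabled=True):
        y = F.conv2d(x, w)

    # Decorator form: __call__ simply wraps the function body in `with self:`.
    @amp.autocast(enabled=True)
    def forward(inp, weight):
        return F.conv2d(inp, weight)

    out = forward(x, w)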
@@ -10,6 +10,7 @@ from copy import deepcopy
 from .. import functional as F
 from ..module import Module
 from ..tensor import Tensor
+from ..core import _config


 def _is_nchw_format(param: Tensor):
@@ -26,10 +27,12 @@ def convert_tensor_format(x: Tensor, inplace: bool = True):
     else:
         raise ValueError("Unsupport tensor ndim {}".format(x.ndim))
     # TODO: use initialization from tensor after fixing format setting
-    if inplace:
-        x[...] = Tensor(x.numpy().transpose(*pattern), format="nhwc")
-    else:
-        x = Tensor(x.numpy().transpose(*pattern), format="nhwc")
+    if x.format != "nhwc":
+        if inplace:
+            data = x.numpy().transpose(*pattern)
+            x[...] = Tensor(data, format="nhwc")
+        else:
+            x = Tensor(x.numpy().transpose(*pattern), format="nhwc")
     return x
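The new `x.format != "nhwc"` guard makes `convert_tensor_format` idempotent: a tensor that is already NHWC is returned untouched instead of being transposed a second time. A small sketch of the conversion it performs for a 4-D NCHW tensor (names are illustrative):

    import numpy as np
    import megengine as mge

    x = mge.Tensor(np.arange(24, dtype="float32").reshape(1, 2, 3, 4))  # NCHW
    y = mge.amp.convert_tensor_format(x, inplace=False)
    assert y.format == "nhwc"
    # Under the hood this is roughly:
    #   Tensor(x.numpy().transpose(0, 2, 3, 1), format="nhwc")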
@@ -144,7 +144,9 @@ class GradScaler:
     def _check_gradients(self, grads, scale):
         if len(grads) == 0:
             return False
-        return _check_non_finite(grads, scale)
+        rst = _check_non_finite(grads, scale)
+        rst = rst.numpy()
+        return rst

     def update(self, new_scale: float = None):
         r"""Update the scale factor according to whether encountered overflow grad.
@@ -182,7 +182,6 @@ def _reset_execution_config(
     deterministic_kernel=None,
     async_level=None,
     compute_mode=None,
-    bn_format=None,
     auto_format_convert=None,
 ):
     global _benchmark_kernel, _deterministic_kernel, __compute_mode
@@ -234,11 +233,11 @@ def _override(
            def train():
     """
     orig_flags = _reset_execution_config(
-        benchmark_kernel,
-        deterministic_kernel,
-        async_level,
-        compute_mode,
-        auto_format_convert,
+        benchmark_kernel=benchmark_kernel,
+        deterministic_kernel=deterministic_kernel,
+        async_level=async_level,
+        compute_mode=compute_mode,
+        auto_format_convert=auto_format_convert,
     )
     try:
         yield
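Because `bn_format` is gone from `_reset_execution_config`, the re-entry call in `_override` now passes every flag by keyword, so argument order can no longer silently shift. The same context manager is what the tests below reach through the public `megengine.config` module; a short hedged sketch of that call pattern:

    import numpy as np
    import megengine as mge
    from megengine import config

    x = mge.Tensor(np.ones((1, 2, 3, 4), dtype="float32"))

    # Temporarily disable automatic NCHW<->NHWC conversion, e.g. to read the
    # raw storage shape of an NHWC tensor; previous flags are restored on exit.
    with config._override(auto_format_convert=False):
        print(x.shape)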
@@ -64,7 +64,9 @@ class Grad:
                 continue
             grad.suppress()
+        print("before backward")
         self._impl.backward(ys, dys)
+        print("after backward")
         for grad in group:
             if grad is self:
@@ -24,6 +24,7 @@ from .._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
 from .._imperative_rt.ops import jit_supported
 from .._wrap import as_device
 from ..autodiff.grad import Function
+from .. import _config
 from ..ops import builtin
 from .amp import _get_amp_high_prec_dtype, _get_amp_low_prec_dtype
 from .dtype import is_dtype_equal, is_quantize
@@ -1226,12 +1226,16 @@ def batch_norm(
     bias = make_full_if_none(bias, 0)

     if not training:
-        op = builtin.BatchNorm(fwd_mode=BatchNorm.FwdMode.INFERENCE, epsilon=eps)
+        op = builtin.BatchNorm(
+            fwd_mode=BatchNorm.FwdMode.INFERENCE, param_dim="dim_1c11", epsilon=eps
+        )
         ret = apply(op, inp, weight, bias, running_mean, running_var)[-1]
         return ret
     else:
-        op = builtin.BatchNorm(avg_factor=1 - momentum, epsilon=eps)
+        op = builtin.BatchNorm(
+            avg_factor=1 - momentum, param_dim="dim_1c11", epsilon=eps
+        )
         if has_mean or has_var:
             running_mean = make_full_if_none(running_mean, 0)
             running_var = make_full_if_none(running_var, 1)
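`param_dim="dim_1c11"` is now fixed at the functional layer instead of being plumbed through `_BatchNorm` (see the next hunks): the per-channel parameters and running statistics are laid out as `(1, C, 1, 1)`. A hedged sketch of calling `F.batch_norm` with that layout; shapes and values are illustrative.

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    C = 4
    inp = mge.Tensor(np.random.randn(2, C, 8, 8).astype("float32"))
    weight = mge.Tensor(np.ones((1, C, 1, 1), dtype="float32"))
    bias = mge.Tensor(np.zeros((1, C, 1, 1), dtype="float32"))
    running_mean = mge.Tensor(np.zeros((1, C, 1, 1), dtype="float32"))
    running_var = mge.Tensor(np.ones((1, C, 1, 1), dtype="float32"))

    out = F.batch_norm(
        inp,
        running_mean=running_mean,
        running_var=running_var,
        weight=weight,
        bias=bias,
        training=False,
        eps=1e-5,
    )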
@@ -19,7 +19,6 @@ class _BatchNorm(Module):
         affine=True,
         track_running_stats=True,
         freeze=False,
-        param_dim="dim_1c11",
         **kwargs
     ):
         super(_BatchNorm, self).__init__(**kwargs)
@@ -30,7 +29,6 @@ class _BatchNorm(Module):
         self.track_running_stats = track_running_stats
         self._track_running_stats_saved = track_running_stats
         self.freeze = freeze
-        self.param_dim = param_dim
         if self.freeze:
             assert (
                 self._track_running_stats_saved
@@ -104,7 +102,6 @@ class _BatchNorm(Module):
             or ((self.running_mean is None) and (self.running_var is None)),
             momentum=exponential_average_factor,
             eps=self.eps,
-            param_dim=self.param_dim,
         )
         return output
@@ -8,6 +8,7 @@ from typing import Union

 import numpy as np

+from ..core import _config
 from ..core._imperative_rt.core2 import (
     get_auto_format_convert,
     pop_scope,
@@ -96,7 +97,7 @@ class Optimizer(metaclass=ABCMeta):
                     "optimizer can only optimize Parameters, but one of the params is "
                     + str(type(param))
                 )
-            param._reset(Tensor(param.numpy(), no_cache=True, format=param.format))
+            param._reset(Tensor(param, no_cache=True))

         for name, default in self._defaults.items():
             if default is required and name not in param_group:
@@ -119,10 +120,11 @@ class Optimizer(metaclass=ABCMeta):
     def _add_state(self, param, state_name, initializer=None):
         if initializer is None:
-            initializer = np.zeros(param.shape, dtype=np.float32)
+            with _config._override(auto_format_convert=False):
+                initializer = np.zeros(param.shape, dtype=np.float32)
         state_dict = self._state.setdefault(param, {})
         assert state_name not in state_dict
-        state = Tensor(initializer, no_cache=True)
+        state = Tensor(initializer, no_cache=True, format=param.format)
         state_dict[state_name] = state

     @abstractmethod
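The `_override(auto_format_convert=False)` guard matters because an NHWC parameter reports a converted, NCHW-logical shape by default, while the optimizer state buffer has to match the raw storage layout; the state tensor is then created with the parameter's own format. A hedged illustration using helpers from this diff (shapes are illustrative):

    import numpy as np
    import megengine as mge
    from megengine import config

    p = mge.Parameter(np.zeros((1, 3, 2, 2), dtype="float32"))
    p = mge.amp.convert_tensor_format(p)      # now stored as NHWC

    with config._override(auto_format_convert=False):
        initializer = np.zeros(p.shape, dtype=np.float32)   # raw NHWC storage shape
    state = mge.Tensor(initializer, no_cache=True, format=p.format)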
@@ -5,6 +5,7 @@ from typing import Iterable, Union
 from ..functional.inplace import _inplace_add_
 from ..tensor import Parameter, tensor
 from .optimizer import Optimizer
+from ..core import _config


 class SGD(Optimizer):
@@ -10,7 +10,7 @@ import pytest
 import megengine.functional as F
 import megengine.module as M
-from megengine import Parameter, Tensor, amp, tensor
+from megengine import Parameter, Tensor, amp, config


 class MyModule(M.Module):
@@ -39,6 +39,22 @@ class MyModule(M.Module):
 @pytest.mark.parametrize("is_inplace", [False, True])
 def test_convert_module(is_inplace):
     m = MyModule()
+    expected_shape = {
+        "i.bn.weight": (1, 1, 1, 4),
+        "i.bn.bias": (1, 1, 1, 4),
+        "i.bn.running_mean": (1, 1, 1, 4),
+        "i.bn.running_var": (1, 1, 1, 4),
+        "conv.weight": (2, 2, 4, 4, 2),
+        "conv.bias": (1, 1, 1, 4),
+        "bn.weight": (1, 1, 1, 4),
+        "bn.bias": (1, 1, 1, 4),
+        "bn.running_mean": (1, 1, 1, 4),
+        "bn.running_var": (1, 1, 1, 4),
+        "param": (1, 1, 1, 3),
+        "buff": (1, 1, 1, 3),
+    }
     m = amp.convert_module_format(m, is_inplace)
     for name, param in m.named_tensors():
         assert param.format == "nhwc"
+        with config._override(auto_format_convert=False):
+            assert param.shape == expected_shape[name], name
@@ -3,6 +3,7 @@ import pytest

 import megengine as mge
 import megengine.functional as F
+import megengine.module as M
 from megengine import tensor
 from megengine.autodiff import GradManager
 from megengine.jit import trace
@@ -36,9 +37,9 @@ def _compare_nchw_nhwc(data, func, is_symbolic=None):
     x2 = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
     if is_symbolic is not None:
         func = trace(func, symbolic=is_symbolic)
-    # out1 = func(x1)
+    out1 = func(x1)
     out2 = func(x2)
-    # np.testing.assert_almost_equal(out1, out2, decimal=5)
+    np.testing.assert_almost_equal(out1, out2, decimal=5)


 @pytest.mark.parametrize("is_symbolic", [None])
@@ -322,30 +323,91 @@ def test_pooling2d(pooling, is_symbolic):
     _compare_nchw_nhwc(data, func, is_symbolic)


-@pytest.mark.parametrize("is_symbolic", [None])
-def test_backward(is_symbolic):
-    data = np.arange(0, 24).reshape((1, 2, 3, 4))
-    x = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
-    w = mge.tensor(np.ones((3, 1, 1, 2)), format="nhwc")
-    b = mge.tensor(np.ones((1, 1, 1, 3)), format="nhwc")
-    gm = GradManager().attach([w, b])
+def _compare_backward(inps, model, is_symbolic=None):
+    def func(*inps):
+        return model(*inps)

-    def func(x, w, b):
-        return F.conv2d(x, w, b)
+    if is_symbolic is not None:
+        func = trace(func, symbolic=is_symbolic)

+    gm = GradManager().attach(model.parameters())
     with gm:
-        if is_symbolic is not None:
-            func = trace(func, symbolic=is_symbolic)
-        x = func(x, w, b)
-        assert x.format == "nhwc"
-        # test manually convert to NHWC, usually used in detection head
-        x = x.transpose(0, 2, 3, 1).reshape(1, 18, 2)
-        gm.backward(x)
-        print("finish backward", x.format)
-    # backward grad has no format
-    np.testing.assert_equal(
-        w.grad.numpy(), np.array([66, 210, 66, 210, 66, 210]).reshape((3, 1, 1, 2)),
-    )
-    np.testing.assert_equal(
-        b.grad.numpy(), np.array([12, 12, 12]).reshape((1, 1, 1, 3))
-    )
+        rst = func(*inps)
+        gm.backward(rst)
+    expected_grads = [param.grad for param in model.parameters()]
+
+    inps = [mge.amp.convert_tensor_format(inp) for inp in inps]
+    model = mge.amp.convert_module_format(model)
+    gm = GradManager().attach(model.parameters())
+    with gm:
+        rst = func(*inps)
+        gm.backward(rst)
+    actual_grads = [param.grad for param in model.parameters()]
+
+    for expected, actual in zip(expected_grads, actual_grads):
+        # print(param.grad)
+        np.testing.assert_equal(expected.numpy(), actual.numpy())
+
+
+@pytest.mark.parametrize("is_symbolic", [None])
+def test_backward_conv2d_dimshuffle(is_symbolic):
+    class Net(M.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = M.Conv2d(2, 3, 1)
+
+        def forward(self, inp):
+            # test manually convert to NHWC, usually used in detection head
+            return F.transpose(self.conv(inp), (0, 2, 3, 1)).reshape(1, 18, 2)
+
+    inp = mge.tensor(np.arange(0, 24).reshape((1, 2, 3, 4)))
+    # x = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
+    # w = mge.tensor(np.ones((3, 1, 1, 2)), format="nhwc")
+    # b = mge.tensor(np.ones((1, 1, 1, 3)), format="nhwc")
+    # grads = [
+    #     np.array([66, 210, 66, 210, 66, 210]).reshape((3, 1, 1, 2)),
+    #     np.array([12, 12, 12]).reshape((1, 1, 1, 3)),
+    # ]
+    _compare_backward([inp], Net(), is_symbolic)
+
+
+@pytest.mark.parametrize("is_symbolic", [None])
+def test_backward_groupconv2d_bn(is_symbolic):
+    class Net(M.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = M.Conv2d(2, 2, 1, groups=2)
+            self.bn = M.BatchNorm2d(2)
+
+        def forward(self, inp):
+            # test manually convert to NHWC, usually used in detection head
+            return self.bn(self.conv(inp))
+
+    inp = mge.tensor(np.arange(0, 24).reshape((1, 2, 3, 4)))
+    _compare_backward([inp], Net(), is_symbolic)
+    # def func(x, w, b, bn_w, bn_b):
+    #     x = F.conv2d(x, w, b, groups=2)
+    #     x = F.batch_norm(
+    #         x,
+    #         running_mean=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"),
+    #         running_var=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"),
+    #         weight=bn_w,
+    #         bias=bn_b,
+    #         training=True,
+    #         inplace=True,
+    #     )
+    #     return x
+
+    # data = np.arange(0, 24).reshape((1, 2, 3, 4))
+    # x = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
+    # w = tensor(np.ones((2, 1, 1, 1, 1)), format="nhwc")
+    # b = tensor(np.ones((1, 1, 1, 2)), format="nhwc")
+    # bn_w = tensor(np.ones((1, 1, 1, 2)), format="nhwc")
+    # bn_b = tensor(np.ones((1, 1, 1, 2)), format="nhwc")
+    # grads = [
+    #     np.array([66, 210]).reshape((2, 1, 1, 1, 1)),
+    #     np.array([12, 12]).reshape((1, 1, 1, 2)),
+    #     np.array([12, 12]).reshape((1, 1, 1, 2)),
+    #     np.array([12, 12]).reshape((1, 1, 1, 2)),
+    # ]
+    # _compare_backward(x, func, [w, b, bn_w, bn_b], grads, is_symbolic)
| @@ -1,6 +1,8 @@ | |||||
| #include "megbrain/imperative/transformations/format.h" | #include "megbrain/imperative/transformations/format.h" | ||||
| #include "megbrain/imperative/transformations/grad.h" | |||||
| #include "megbrain/imperative/ops/autogen.h" | #include "megbrain/imperative/ops/autogen.h" | ||||
| #include "megbrain/imperative/ops/utility.h" | |||||
| namespace mgb { | namespace mgb { | ||||
| namespace imperative { | namespace imperative { | ||||
@@ -17,7 +19,12 @@ TypedValueRef<FormattedTensorValue> FormatTransformation::to(
         const std::string& scope) const {
     std::vector<int32_t> pattern;
     if (tensor.format() == FT::NHWC && target == FT::NCHW) {
-        pattern = {0, 3, 1, 2};
+        // FIXME(czh): temporary fast path for group conv 5D weight.
+        if (tensor.value().shape().cast<ShapeValue>().ndim == 5) {
+            pattern = {0, 1, 4, 2, 3};
+        } else {
+            pattern = {0, 3, 1, 2};
+        }
     } else if (tensor.format() == FT::NCHW && target == FT::NHWC) {
         pattern = {0, 2, 3, 1};
     } else {
@@ -65,12 +72,22 @@ inline ValueRefList FormatTransformation::wrap_outputs(

 namespace {

 ValueShape convert_nhwc2nchw_shape(const ValueShape& shape) {
-    mgb_assert(shape.ndim == 4);
     auto out = ValueShape(shape);
-    out[3] = shape[2];
-    out[2] = shape[1];
-    out[1] = shape[3];
-    return out;
+    if (shape.ndim == 4) {
+        out[1] = shape[3];
+        out[2] = shape[1];
+        out[3] = shape[2];
+        return out;
+    } else if (shape.ndim == 5) {
+        out[2] = shape[4];
+        out[3] = shape[2];
+        out[4] = shape[3];
+        return out;
+    } else {
+        mgb_throw(
+                MegBrainError, "Unsupported shape ndim %u in GetAttr(Shape).",
+                shape.ndim);
+    }
 }

 using FormatRule = std::function<ValueRefList(
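A worked example of the index permutation `convert_nhwc2nchw_shape` now implements, written as a small Python check for brevity; the 5-D case follows the group-conv weight shape exercised by the tests earlier in this diff, and the dimension names in the comments are an assumption.

    # 4-D activation shape: NHWC -> NCHW via out[1]=in[3], out[2]=in[1], out[3]=in[2].
    nhwc = (1, 3, 4, 2)                          # (N, H, W, C)
    nchw = (nhwc[0], nhwc[3], nhwc[1], nhwc[2])
    assert nchw == (1, 2, 3, 4)

    # 5-D group-conv weight: out[2]=in[4], out[3]=in[2], out[4]=in[3];
    # assuming (groups, out/grp, kh, kw, in/grp) -> (groups, out/grp, in/grp, kh, kw).
    w_nhwc = (2, 2, 4, 4, 2)
    w_nchw = (w_nhwc[0], w_nhwc[1], w_nhwc[4], w_nhwc[2], w_nhwc[3])
    assert w_nchw == (2, 2, 2, 4, 4)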
@@ -278,10 +295,10 @@ ValueRefList setsubtensor_rule(
 inline FT get_inputs_format(Span<ValueRef>& inputs, const FormatTransformation& t) {
     FT format(FT::DEFAULT);
     for (auto& inp : inputs) {
-        auto&& inp_ref = inp.as_ref(t.value_type());
-        if (inp_ref && inp_ref->format() != FT::DEFAULT) {
-            mgb_assert(format == FT::DEFAULT || inp_ref->format() == format);
-            format = inp_ref->format().type();
+        auto&& inp_format = inp.cast(t.value_type()).format();
+        if (inp_format != FT::DEFAULT) {
+            mgb_assert(format == FT::DEFAULT || inp_format == format);
+            format = inp_format.type();
         }
     }
     return format;
@@ -308,13 +325,6 @@ ValueRefList concat_rule(
             format);
 }

-ValueRefList elemwise_rule(
-        const Elemwise& op, Span<ValueRef>& inputs, const bool& auto_convert,
-        const FormatTransformation& t) {
-    FT format = get_inputs_format(inputs, t);
-    return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), format);
-}
-
 ValueRefList identity_rule_helper(
         const OpDef& op, const Span<ValueRef>& inputs, const FormatTransformation& t) {
     // mgb_assert(inputs.size() == 1);
@@ -336,24 +346,49 @@ ValueRefList batchnorm_rule(
     return identity_rule_helper(op, inputs, t);
 }

+ValueRefList checknonfinite_rule(
+        const CheckNonFinite& op, Span<ValueRef>& inputs, const bool& auto_convert,
+        const FormatTransformation& t) {
+    auto&& inputs_ = t.unwrap_inputs(inputs);
+    auto&& outputs_ = imperative::apply(op, inputs_);
+    return t.wrap_outputs(outputs_);
+}
+
 // clang-format off
-#define FOREACH_IDENTITY_OP(cb) \
-    cb(Copy) \
-    cb(FastpathCopy) \
-    cb(TypeCvt) \
-    cb(Dropout) \
+#define FOREACH_MULTI_INPS_NO_PARAM_OP(cb) \
+    cb(Elemwise)                           \
+    cb(CompiledOp)                         \
+    cb(SubgraphOp)
+
+#define FOREACH_IDENTITY_OP(cb) \
+    cb(Copy)                    \
+    cb(FastpathCopy)            \
+    cb(TypeCvt)                 \
+    cb(Dropout)                 \
     cb(Identity)

-#define FOREACH_FORMAT_OP(cb) \
-    cb(AdaptivePooling) \
-    cb(WarpAffine) \
+#define FOREACH_FORMAT_OP(cb)  \
+    cb(AdaptivePooling)        \
+    cb(WarpAffine)             \
     cb(Resize)

-#define FOREACH_FORMAT_POLICY_OP(cb)\
-    cb(Pooling) \
+#define FOREACH_FORMAT_POLICY_OP(cb) \
+    cb(Pooling)                      \
     cb(Convolution)
 // clang-format on

+// multi inputs op without params
+#define CREATE_MULTI_INPS_NO_PARAM_OP_RULE(Op)                               \
+    ValueRefList Op##_rule(                                                  \
+            const Op& _op, Span<ValueRef>& inputs, const bool& auto_convert, \
+            const FormatTransformation& t) {                                 \
+        FT format = get_inputs_format(inputs, t);                            \
+        return t.wrap_outputs(                                               \
+                imperative::apply(_op, t.unwrap_inputs(inputs)), format);    \
+    }
+FOREACH_MULTI_INPS_NO_PARAM_OP(CREATE_MULTI_INPS_NO_PARAM_OP_RULE)
+#undef CREATE_MULTI_INPS_NO_PARAM_OP_RULE
+
 // identity op
 #define CREATE_IDENTITY_OP_RULE(Op) \
     ValueRefList Op##_rule( \
@@ -409,8 +444,9 @@ struct FormatRuleRegistry {
         register_format_rule(setsubtensor_rule<SetSubtensor>);
         register_format_rule(setsubtensor_rule<IndexingSetMultiAxisVec>);
         register_format_rule(concat_rule);
-        register_format_rule(elemwise_rule);
         register_format_rule(batchnorm_rule);
+        register_format_rule(checknonfinite_rule);
+        FOREACH_MULTI_INPS_NO_PARAM_OP(REGISTER_OP_RULE)
         FOREACH_IDENTITY_OP(REGISTER_OP_RULE)
         FOREACH_FORMAT_OP(REGISTER_OP_RULE)
         FOREACH_FORMAT_POLICY_OP(REGISTER_OP_RULE)
@@ -455,27 +491,73 @@ ValueRefList FormatTransformation::apply_transformation(
             return imperative::apply(op, unwrap_inputs(inputs));
         }
     } else if (op.is<GetFormat>()) {
-        bool is_formatted_tensor = inputs.item().is(m_value_type);
-        if (is_formatted_tensor) {
-            return {FormatValue::make(inputs[0].cast(m_value_type).format())};
+        auto&& inp_ref = inputs[0].as_ref(m_value_type);
+        if (inp_ref) {
+            return {FormatValue::make(inp_ref->format())};
         } else {
             mgb_log_warn(
-                    "Not FormattedTensorValue input for GetFormat op: %s",
-                    inputs[0].to_string().c_str());
+                    "Not FormattedTensorValue input for GetFormat op: %s, %s",
+                    op.to_string().c_str(), inputs[0].to_string().c_str());
             return {FormatValue::make(FT::DEFAULT)};
         }
     } else if (op.is<Operator::IdentityLike>()) {
-        bool is_formatted_tensor = inputs.item().is(m_value_type);
-        if (is_formatted_tensor) {
-            auto&& format = inputs[0].cast(m_value_type).format();
+        auto&& inp_ref = inputs[0].as_ref(m_value_type);
+        if (inp_ref) {
+            auto&& format = inp_ref->format();
             return wrap_outputs(
                     imperative::apply(op, unwrap_inputs(inputs)), format.type());
         } else {
             mgb_log_warn(
-                    "Not FormattedTensorValue input for IdentityLike op: %s",
-                    inputs[0].to_string().c_str());
+                    "Not FormattedTensorValue input for IdentityLike op: %s, %s",
+                    op.to_string().c_str(), inputs[0].to_string().c_str());
             return imperative::apply(op, inputs);
         }
+    } else if (op.is<AttachGrad>()) {
+        auto&& inp_ref = inputs[0].as_ref(m_value_type);
+        if (inp_ref) {
+            auto format = inp_ref->format();
+            GenericFunction callback =
+                    (GenericFunction&)inputs[1].cast<FunctionValue>();
+            GenericFunction new_callback =
+                    [this, callback, format](Span<ValueRef> inputs_) -> ValueRefList {
+                auto wrapped_inputs = SmallVector<ValueRef>{
+                        this->value_type().make(inputs_.item(), format.type())};
+                auto ret = callback(wrapped_inputs);
+                return ret;
+            };
+            auto&& outputs = imperative::apply(
+                    op, inp_ref->value(), FunctionValue::make(new_callback));
+            return wrap_outputs(outputs, format.type());
+        } else {
+            mgb_log_warn(
+                    "Not FormattedTensorValue input for AttachGrad op: %s, %s",
+                    op.to_string().c_str(), inputs[0].to_string().c_str());
+            return imperative::apply(op, inputs);
+        }
+    } else if (auto* set_grad = op.as<SetGrad>()) {
+        size_t nr_inputs = set_grad->nr_inputs();
+        size_t nr_outputs = inputs.size() - nr_inputs;
+        Span<ValueRef> inputs_ = {inputs.data(), nr_inputs};
+        Span<ValueRef> outputs_ = {inputs.data() + nr_inputs, nr_outputs};
+        // Run the original apply first.
+        // Grads need not be unwrapped/re-wrapped here; GradTransformation unwraps them.
+        auto&& outputs = imperative::apply(op, unwrap_inputs(inputs));
+        // Handle the outputs' formats.
+        auto wrapped_outputs = ValueRefList(nr_outputs);
+        for (size_t i = 0; i < nr_outputs; ++i) {
+            if (auto output_ref = outputs_[i].as_ref(m_value_type)) {
+                wrapped_outputs[i] =
+                        m_value_type.make(outputs[i], output_ref->format().type());
+            } else {
+                mgb_log_warn(
+                        "Not FormattedTensorValue outputs for SetGrad op: %s, %s",
+                        op.to_string().c_str(), inputs_[i].to_string().c_str());
+                wrapped_outputs[i] = m_value_type.make(outputs[i], FT::DEFAULT);
+            }
+        }
+        return wrapped_outputs;
     } else {
         return imperative::apply(op, unwrap_inputs(inputs));
     }
@@ -47,7 +47,10 @@ public:
             const Operator& op, Span<ValueRef> inputs) override;

     ValueRef unwrap(ValueRef value) override {
-        mgb_assert(!value.is(m_value_type));
+        // mgb_assert(!value.is(m_value_type));
+        if (auto format_val = value.as_ref(m_value_type)) {
+            return format_val->value();
+        }
         return value;
     }
@@ -377,6 +377,8 @@ public:
     SetGrad(GenericFunction grad_fn, size_t nr_inputs)
             : m_grad_fn(grad_fn), m_nr_inputs(nr_inputs) {}

+    std::shared_ptr<GradKey> key() const { return m_key; }
+
     GenericFunction grad_fn() const { return m_grad_fn; }
     size_t nr_inputs() const { return m_nr_inputs; }