GitOrigin-RevId: 1c728d6ab9
tags/v1.10.0
| @@ -50,36 +50,36 @@ class autocast: | |||
| self._origin_enabled = None | |||
| self._origin_high = None | |||
| self._origin_low = None | |||
| self._origin_compute_mode = None | |||
| self._origin_configs = None | |||
| def __enter__(self): | |||
| self._origin_enabled = amp._enabled | |||
| amp._enabled = self.enabled | |||
| amp._set_amp_dtype_autocast(self.enabled) | |||
| if not self.enabled: | |||
| return | |||
| if self.enabled: | |||
| self._origin_enabled = amp._enabled | |||
| self._origin_high = amp._get_amp_high_prec_dtype() | |||
| self._origin_low = amp._get_amp_low_prec_dtype() | |||
| amp._enabled = self.enabled | |||
| amp._set_amp_dtype_autocast(self.enabled) | |||
| amp._set_amp_high_prec_dtype(self.high_prec_dtype) | |||
| amp._set_amp_low_prec_dtype(self.low_prec_dtype) | |||
| self._origin_high = amp._get_amp_high_prec_dtype() | |||
| self._origin_low = amp._get_amp_low_prec_dtype() | |||
| amp._set_amp_high_prec_dtype(self.high_prec_dtype) | |||
| amp._set_amp_low_prec_dtype(self.low_prec_dtype) | |||
| self._origin_configs = _config._reset_execution_config(compute_mode="float32") | |||
| self._origin_configs = _config._reset_execution_config( | |||
| compute_mode="float32" | |||
| ) | |||
| def __exit__(self, *args): | |||
| amp._enabled = self._origin_enabled | |||
| amp._set_amp_dtype_autocast(self._origin_enabled) | |||
| if not self.enabled: | |||
| return | |||
| amp._set_amp_high_prec_dtype(self._origin_high) | |||
| amp._set_amp_low_prec_dtype(self._origin_low) | |||
| if self.enabled: | |||
| amp._enabled = self._origin_enabled | |||
| amp._set_amp_dtype_autocast(self._origin_enabled) | |||
| amp._set_amp_high_prec_dtype(self._origin_high) | |||
| amp._set_amp_low_prec_dtype(self._origin_low) | |||
| _config._reset_execution_config(*self._origin_compute_mode) | |||
| def __call__(self, func): | |||
| @functools.wraps(func) | |||
| def wrapper(*args, **kwargs): | |||
| if not self.enabled: | |||
| return func(*args, **kwargs) | |||
| with self: | |||
| return func(*args, **kwargs) | |||
| @@ -10,6 +10,7 @@ from copy import deepcopy | |||
| from .. import functional as F | |||
| from ..module import Module | |||
| from ..tensor import Tensor | |||
| from ..core import _config | |||
| def _is_nchw_format(param: Tensor): | |||
| @@ -26,10 +27,12 @@ def convert_tensor_format(x: Tensor, inplace: bool = True): | |||
| else: | |||
| raise ValueError("Unsupport tensor ndim {}".format(x.ndim)) | |||
| # TODO: use initialization from tensor after fixing format setting | |||
| if inplace: | |||
| x[...] = Tensor(x.numpy().transpose(*pattern), format="nhwc") | |||
| else: | |||
| x = Tensor(x.numpy().transpose(*pattern), format="nhwc") | |||
| if x.format != "nhwc": | |||
| if inplace: | |||
| data = x.numpy().transpose(*pattern) | |||
| x[...] = Tensor(data, format="nhwc") | |||
| else: | |||
| x = Tensor(x.numpy().transpose(*pattern), format="nhwc") | |||
| return x | |||
| @@ -144,7 +144,9 @@ class GradScaler: | |||
| def _check_gradients(self, grads, scale): | |||
| if len(grads) == 0: | |||
| return False | |||
| return _check_non_finite(grads, scale) | |||
| rst = _check_non_finite(grads, scale) | |||
| rst = rst.numpy() | |||
| return rst | |||
| def update(self, new_scale: float = None): | |||
| r"""Update the scale factor according to whether encountered overflow grad. | |||
| @@ -182,7 +182,6 @@ def _reset_execution_config( | |||
| deterministic_kernel=None, | |||
| async_level=None, | |||
| compute_mode=None, | |||
| bn_format=None, | |||
| auto_format_convert=None, | |||
| ): | |||
| global _benchmark_kernel, _deterministic_kernel, __compute_mode | |||
| @@ -234,11 +233,11 @@ def _override( | |||
| def train(): | |||
| """ | |||
| orig_flags = _reset_execution_config( | |||
| benchmark_kernel, | |||
| deterministic_kernel, | |||
| async_level, | |||
| compute_mode, | |||
| auto_format_convert, | |||
| benchmark_kernel=benchmark_kernel, | |||
| deterministic_kernel=deterministic_kernel, | |||
| async_level=async_level, | |||
| compute_mode=compute_mode, | |||
| auto_format_convert=auto_format_convert, | |||
| ) | |||
| try: | |||
| yield | |||
| @@ -64,7 +64,9 @@ class Grad: | |||
| continue | |||
| grad.suppress() | |||
| print("before backward") | |||
| self._impl.backward(ys, dys) | |||
| print("after backward") | |||
| for grad in group: | |||
| if grad is self: | |||
| @@ -24,6 +24,7 @@ from .._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder | |||
| from .._imperative_rt.ops import jit_supported | |||
| from .._wrap import as_device | |||
| from ..autodiff.grad import Function | |||
| from .. import _config | |||
| from ..ops import builtin | |||
| from .amp import _get_amp_high_prec_dtype, _get_amp_low_prec_dtype | |||
| from .dtype import is_dtype_equal, is_quantize | |||
| @@ -1226,12 +1226,16 @@ def batch_norm( | |||
| bias = make_full_if_none(bias, 0) | |||
| if not training: | |||
| op = builtin.BatchNorm(fwd_mode=BatchNorm.FwdMode.INFERENCE, epsilon=eps) | |||
| op = builtin.BatchNorm( | |||
| fwd_mode=BatchNorm.FwdMode.INFERENCE, param_dim="dim_1c11", epsilon=eps | |||
| ) | |||
| ret = apply(op, inp, weight, bias, running_mean, running_var)[-1] | |||
| return ret | |||
| else: | |||
| op = builtin.BatchNorm(avg_factor=1 - momentum, epsilon=eps) | |||
| op = builtin.BatchNorm( | |||
| avg_factor=1 - momentum, param_dim="dim_1c11", epsilon=eps | |||
| ) | |||
| if has_mean or has_var: | |||
| running_mean = make_full_if_none(running_mean, 0) | |||
| running_var = make_full_if_none(running_var, 1) | |||
| @@ -19,7 +19,6 @@ class _BatchNorm(Module): | |||
| affine=True, | |||
| track_running_stats=True, | |||
| freeze=False, | |||
| param_dim="dim_1c11", | |||
| **kwargs | |||
| ): | |||
| super(_BatchNorm, self).__init__(**kwargs) | |||
| @@ -30,7 +29,6 @@ class _BatchNorm(Module): | |||
| self.track_running_stats = track_running_stats | |||
| self._track_running_stats_saved = track_running_stats | |||
| self.freeze = freeze | |||
| self.param_dim = param_dim | |||
| if self.freeze: | |||
| assert ( | |||
| self._track_running_stats_saved | |||
| @@ -104,7 +102,6 @@ class _BatchNorm(Module): | |||
| or ((self.running_mean is None) and (self.running_var is None)), | |||
| momentum=exponential_average_factor, | |||
| eps=self.eps, | |||
| param_dim=self.param_dim, | |||
| ) | |||
| return output | |||
| @@ -8,6 +8,7 @@ from typing import Union | |||
| import numpy as np | |||
| from ..core import _config | |||
| from ..core._imperative_rt.core2 import ( | |||
| get_auto_format_convert, | |||
| pop_scope, | |||
| @@ -96,7 +97,7 @@ class Optimizer(metaclass=ABCMeta): | |||
| "optimizer can only optimize Parameters, but one of the params is " | |||
| + str(type(param)) | |||
| ) | |||
| param._reset(Tensor(param.numpy(), no_cache=True, format=param.format)) | |||
| param._reset(Tensor(param, no_cache=True)) | |||
| for name, default in self._defaults.items(): | |||
| if default is required and name not in param_group: | |||
| @@ -119,10 +120,11 @@ class Optimizer(metaclass=ABCMeta): | |||
| def _add_state(self, param, state_name, initializer=None): | |||
| if initializer is None: | |||
| initializer = np.zeros(param.shape, dtype=np.float32) | |||
| with _config._override(auto_format_convert=False): | |||
| initializer = np.zeros(param.shape, dtype=np.float32) | |||
| state_dict = self._state.setdefault(param, {}) | |||
| assert state_name not in state_dict | |||
| state = Tensor(initializer, no_cache=True) | |||
| state = Tensor(initializer, no_cache=True, format=param.format) | |||
| state_dict[state_name] = state | |||
| @abstractmethod | |||
| @@ -5,6 +5,7 @@ from typing import Iterable, Union | |||
| from ..functional.inplace import _inplace_add_ | |||
| from ..tensor import Parameter, tensor | |||
| from .optimizer import Optimizer | |||
| from ..core import _config | |||
| class SGD(Optimizer): | |||
| @@ -10,7 +10,7 @@ import pytest | |||
| import megengine.functional as F | |||
| import megengine.module as M | |||
| from megengine import Parameter, Tensor, amp, tensor | |||
| from megengine import Parameter, Tensor, amp, config | |||
| class MyModule(M.Module): | |||
| @@ -39,6 +39,22 @@ class MyModule(M.Module): | |||
| @pytest.mark.parametrize("is_inplace", [False, True]) | |||
| def test_convert_module(is_inplace): | |||
| m = MyModule() | |||
| expected_shape = { | |||
| "i.bn.weight": (1, 1, 1, 4), | |||
| "i.bn.bias": (1, 1, 1, 4), | |||
| "i.bn.running_mean": (1, 1, 1, 4), | |||
| "i.bn.running_var": (1, 1, 1, 4), | |||
| "conv.weight": (2, 2, 4, 4, 2), | |||
| "conv.bias": (1, 1, 1, 4), | |||
| "bn.weight": (1, 1, 1, 4), | |||
| "bn.bias": (1, 1, 1, 4), | |||
| "bn.running_mean": (1, 1, 1, 4), | |||
| "bn.running_var": (1, 1, 1, 4), | |||
| "param": (1, 1, 1, 3), | |||
| "buff": (1, 1, 1, 3), | |||
| } | |||
| m = amp.convert_module_format(m, is_inplace) | |||
| for name, param in m.named_tensors(): | |||
| assert param.format == "nhwc" | |||
| with config._override(auto_format_convert=False): | |||
| assert param.shape == expected_shape[name], name | |||
| @@ -3,6 +3,7 @@ import pytest | |||
| import megengine as mge | |||
| import megengine.functional as F | |||
| import megengine.module as M | |||
| from megengine import tensor | |||
| from megengine.autodiff import GradManager | |||
| from megengine.jit import trace | |||
| @@ -36,9 +37,9 @@ def _compare_nchw_nhwc(data, func, is_symbolic=None): | |||
| x2 = tensor(data.transpose(0, 2, 3, 1), format="nhwc") | |||
| if is_symbolic is not None: | |||
| func = trace(func, symbolic=is_symbolic) | |||
| # out1 = func(x1) | |||
| out1 = func(x1) | |||
| out2 = func(x2) | |||
| # np.testing.assert_almost_equal(out1, out2, decimal=5) | |||
| np.testing.assert_almost_equal(out1, out2, decimal=5) | |||
| @pytest.mark.parametrize("is_symbolic", [None]) | |||
| @@ -322,30 +323,91 @@ def test_pooling2d(pooling, is_symbolic): | |||
| _compare_nchw_nhwc(data, func, is_symbolic) | |||
| @pytest.mark.parametrize("is_symbolic", [None]) | |||
| def test_backward(is_symbolic): | |||
| data = np.arange(0, 24).reshape((1, 2, 3, 4)) | |||
| x = tensor(data.transpose(0, 2, 3, 1), format="nhwc") | |||
| w = mge.tensor(np.ones((3, 1, 1, 2)), format="nhwc") | |||
| b = mge.tensor(np.ones((1, 1, 1, 3)), format="nhwc") | |||
| gm = GradManager().attach([w, b]) | |||
| def _compare_backward(inps, model, is_symbolic=None): | |||
| def func(*inps): | |||
| return model(*inps) | |||
| def func(x, w, b): | |||
| return F.conv2d(x, w, b) | |||
| if is_symbolic is not None: | |||
| func = trace(func, symbolic=is_symbolic) | |||
| gm = GradManager().attach(model.parameters()) | |||
| with gm: | |||
| if is_symbolic is not None: | |||
| func = trace(func, symbolic=is_symbolic) | |||
| x = func(x, w, b) | |||
| assert x.format == "nhwc" | |||
| # test manually convert to NHWC, usually used in detection head | |||
| x = x.transpose(0, 2, 3, 1).reshape(1, 18, 2) | |||
| gm.backward(x) | |||
| print("finish backward", x.format) | |||
| # backward grad has no format | |||
| np.testing.assert_equal( | |||
| w.grad.numpy(), np.array([66, 210, 66, 210, 66, 210]).reshape((3, 1, 1, 2)), | |||
| ) | |||
| np.testing.assert_equal( | |||
| b.grad.numpy(), np.array([12, 12, 12]).reshape((1, 1, 1, 3)) | |||
| ) | |||
| rst = func(*inps) | |||
| gm.backward(rst) | |||
| expected_grads = [param.grad for param in model.parameters()] | |||
| inps = [mge.amp.convert_tensor_format(inp) for inp in inps] | |||
| model = mge.amp.convert_module_format(model) | |||
| gm = GradManager().attach(model.parameters()) | |||
| with gm: | |||
| rst = func(*inps) | |||
| gm.backward(rst) | |||
| actual_grads = [param.grad for param in model.parameters()] | |||
| for expected, actual in zip(expected_grads, actual_grads): | |||
| # print(param.grad) | |||
| np.testing.assert_equal(expected.numpy(), actual.numpy()) | |||
| @pytest.mark.parametrize("is_symbolic", [None]) | |||
| def test_backward_conv2d_dimshuffle(is_symbolic): | |||
| class Net(M.Module): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.conv = M.Conv2d(2, 3, 1) | |||
| def forward(self, inp): | |||
| # test manually convert to NHWC, usually used in detection head | |||
| return F.transpose(self.conv(inp), (0, 2, 3, 1)).reshape(1, 18, 2) | |||
| inp = mge.tensor(np.arange(0, 24).reshape((1, 2, 3, 4))) | |||
| # x = tensor(data.transpose(0, 2, 3, 1), format="nhwc") | |||
| # w = mge.tensor(np.ones((3, 1, 1, 2)), format="nhwc") | |||
| # b = mge.tensor(np.ones((1, 1, 1, 3)), format="nhwc") | |||
| # grads = [ | |||
| # np.array([66, 210, 66, 210, 66, 210]).reshape((3, 1, 1, 2)), | |||
| # np.array([12, 12, 12]).reshape((1, 1, 1, 3)), | |||
| # ] | |||
| _compare_backward([inp], Net(), is_symbolic) | |||
| @pytest.mark.parametrize("is_symbolic", [None]) | |||
| def test_backward_groupconv2d_bn(is_symbolic): | |||
| class Net(M.Module): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.conv = M.Conv2d(2, 2, 1, groups=2) | |||
| self.bn = M.BatchNorm2d(2) | |||
| def forward(self, inp): | |||
| # test manually convert to NHWC, usually used in detection head | |||
| return self.bn(self.conv(inp)) | |||
| inp = mge.tensor(np.arange(0, 24).reshape((1, 2, 3, 4))) | |||
| _compare_backward([inp], Net(), is_symbolic) | |||
| # def func(x, w, b, bn_w, bn_b): | |||
| # x = F.conv2d(x, w, b, groups=2) | |||
| # x = F.batch_norm( | |||
| # x, | |||
| # running_mean=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"), | |||
| # running_var=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"), | |||
| # weight=bn_w, | |||
| # bias=bn_b, | |||
| # training=True, | |||
| # inplace=True, | |||
| # ) | |||
| # return x | |||
| # data = np.arange(0, 24).reshape((1, 2, 3, 4)) | |||
| # x = tensor(data.transpose(0, 2, 3, 1), format="nhwc") | |||
| # w = tensor(np.ones((2, 1, 1, 1, 1)), format="nhwc") | |||
| # b = tensor(np.ones((1, 1, 1, 2)), format="nhwc") | |||
| # bn_w = tensor(np.ones((1, 1, 1, 2)), format="nhwc") | |||
| # bn_b = tensor(np.ones((1, 1, 1, 2)), format="nhwc") | |||
| # grads = [ | |||
| # np.array([66, 210]).reshape((2, 1, 1, 1, 1)), | |||
| # np.array([12, 12]).reshape((1, 1, 1, 2)), | |||
| # np.array([12, 12]).reshape((1, 1, 1, 2)), | |||
| # np.array([12, 12]).reshape((1, 1, 1, 2)), | |||
| # ] | |||
| # _compare_backward(x, func, [w, b, bn_w, bn_b], grads, is_symbolic) | |||
| @@ -1,6 +1,8 @@ | |||
| #include "megbrain/imperative/transformations/format.h" | |||
| #include "megbrain/imperative/transformations/grad.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| #include "megbrain/imperative/ops/utility.h" | |||
| namespace mgb { | |||
| namespace imperative { | |||
| @@ -17,7 +19,12 @@ TypedValueRef<FormattedTensorValue> FormatTransformation::to( | |||
| const std::string& scope) const { | |||
| std::vector<int32_t> pattern; | |||
| if (tensor.format() == FT::NHWC && target == FT::NCHW) { | |||
| pattern = {0, 3, 1, 2}; | |||
| // FIXME(czh): temporary fast path for group conv 5D weight. | |||
| if (tensor.value().shape().cast<ShapeValue>().ndim == 5) { | |||
| pattern = {0, 1, 4, 2, 3}; | |||
| } else { | |||
| pattern = {0, 3, 1, 2}; | |||
| } | |||
| } else if (tensor.format() == FT::NCHW && target == FT::NHWC) { | |||
| pattern = {0, 2, 3, 1}; | |||
| } else { | |||
| @@ -65,12 +72,22 @@ inline ValueRefList FormatTransformation::wrap_outputs( | |||
| namespace { | |||
| ValueShape convert_nhwc2nchw_shape(const ValueShape& shape) { | |||
| mgb_assert(shape.ndim == 4); | |||
| auto out = ValueShape(shape); | |||
| out[3] = shape[2]; | |||
| out[2] = shape[1]; | |||
| out[1] = shape[3]; | |||
| return out; | |||
| if (shape.ndim == 4) { | |||
| out[1] = shape[3]; | |||
| out[2] = shape[1]; | |||
| out[3] = shape[2]; | |||
| return out; | |||
| } else if (shape.ndim == 5) { | |||
| out[2] = shape[4]; | |||
| out[3] = shape[2]; | |||
| out[4] = shape[3]; | |||
| return out; | |||
| } else { | |||
| mgb_throw( | |||
| MegBrainError, "Unsupported shape ndim %u in GetAttr(Shape).", | |||
| shape.ndim); | |||
| } | |||
| } | |||
| using FormatRule = std::function<ValueRefList( | |||
| @@ -278,10 +295,10 @@ ValueRefList setsubtensor_rule( | |||
| inline FT get_inputs_format(Span<ValueRef>& inputs, const FormatTransformation& t) { | |||
| FT format(FT::DEFAULT); | |||
| for (auto& inp : inputs) { | |||
| auto&& inp_ref = inp.as_ref(t.value_type()); | |||
| if (inp_ref && inp_ref->format() != FT::DEFAULT) { | |||
| mgb_assert(format == FT::DEFAULT || inp_ref->format() == format); | |||
| format = inp_ref->format().type(); | |||
| auto&& inp_format = inp.cast(t.value_type()).format(); | |||
| if (inp_format != FT::DEFAULT) { | |||
| mgb_assert(format == FT::DEFAULT || inp_format == format); | |||
| format = inp_format.type(); | |||
| } | |||
| } | |||
| return format; | |||
| @@ -308,13 +325,6 @@ ValueRefList concat_rule( | |||
| format); | |||
| } | |||
| ValueRefList elemwise_rule( | |||
| const Elemwise& op, Span<ValueRef>& inputs, const bool& auto_convert, | |||
| const FormatTransformation& t) { | |||
| FT format = get_inputs_format(inputs, t); | |||
| return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), format); | |||
| } | |||
| ValueRefList identity_rule_helper( | |||
| const OpDef& op, const Span<ValueRef>& inputs, const FormatTransformation& t) { | |||
| // mgb_assert(inputs.size() == 1); | |||
| @@ -336,24 +346,49 @@ ValueRefList batchnorm_rule( | |||
| return identity_rule_helper(op, inputs, t); | |||
| } | |||
| ValueRefList checknonfinite_rule( | |||
| const CheckNonFinite& op, Span<ValueRef>& inputs, const bool& auto_convert, | |||
| const FormatTransformation& t) { | |||
| auto&& inputs_ = t.unwrap_inputs(inputs); | |||
| auto&& outputs_ = imperative::apply(op, inputs_); | |||
| return t.wrap_outputs(outputs_); | |||
| } | |||
| // clang-format off | |||
| #define FOREACH_IDENTITY_OP(cb) \ | |||
| cb(Copy) \ | |||
| cb(FastpathCopy) \ | |||
| cb(TypeCvt) \ | |||
| cb(Dropout) \ | |||
| #define FOREACH_MULTI_INPS_NO_PARAM_OP(cb) \ | |||
| cb(Elemwise) \ | |||
| cb(CompiledOp) \ | |||
| cb(SubgraphOp) | |||
| #define FOREACH_IDENTITY_OP(cb) \ | |||
| cb(Copy) \ | |||
| cb(FastpathCopy) \ | |||
| cb(TypeCvt) \ | |||
| cb(Dropout) \ | |||
| cb(Identity) | |||
| #define FOREACH_FORMAT_OP(cb) \ | |||
| cb(AdaptivePooling) \ | |||
| cb(WarpAffine) \ | |||
| #define FOREACH_FORMAT_OP(cb) \ | |||
| cb(AdaptivePooling) \ | |||
| cb(WarpAffine) \ | |||
| cb(Resize) | |||
| #define FOREACH_FORMAT_POLICY_OP(cb)\ | |||
| cb(Pooling) \ | |||
| #define FOREACH_FORMAT_POLICY_OP(cb) \ | |||
| cb(Pooling) \ | |||
| cb(Convolution) | |||
| // clang-format on | |||
| // multi inputs op without params | |||
| #define CREATE_MULTI_INPS_NO_PARAM_OP_RULE(Op) \ | |||
| ValueRefList Op##_rule( \ | |||
| const Op& _op, Span<ValueRef>& inputs, const bool& auto_convert, \ | |||
| const FormatTransformation& t) { \ | |||
| FT format = get_inputs_format(inputs, t); \ | |||
| return t.wrap_outputs( \ | |||
| imperative::apply(_op, t.unwrap_inputs(inputs)), format); \ | |||
| } | |||
| FOREACH_MULTI_INPS_NO_PARAM_OP(CREATE_MULTI_INPS_NO_PARAM_OP_RULE) | |||
| #undef CREATE_MULTI_INPS_NO_PARAM_OP_RULE | |||
| // identity op | |||
| #define CREATE_IDENTITY_OP_RULE(Op) \ | |||
| ValueRefList Op##_rule( \ | |||
| @@ -409,8 +444,9 @@ struct FormatRuleRegistry { | |||
| register_format_rule(setsubtensor_rule<SetSubtensor>); | |||
| register_format_rule(setsubtensor_rule<IndexingSetMultiAxisVec>); | |||
| register_format_rule(concat_rule); | |||
| register_format_rule(elemwise_rule); | |||
| register_format_rule(batchnorm_rule); | |||
| register_format_rule(checknonfinite_rule); | |||
| FOREACH_MULTI_INPS_NO_PARAM_OP(REGISTER_OP_RULE) | |||
| FOREACH_IDENTITY_OP(REGISTER_OP_RULE) | |||
| FOREACH_FORMAT_OP(REGISTER_OP_RULE) | |||
| FOREACH_FORMAT_POLICY_OP(REGISTER_OP_RULE) | |||
| @@ -455,27 +491,73 @@ ValueRefList FormatTransformation::apply_transformation( | |||
| return imperative::apply(op, unwrap_inputs(inputs)); | |||
| } | |||
| } else if (op.is<GetFormat>()) { | |||
| bool is_formatted_tensor = inputs.item().is(m_value_type); | |||
| if (is_formatted_tensor) { | |||
| return {FormatValue::make(inputs[0].cast(m_value_type).format())}; | |||
| auto&& inp_ref = inputs[0].as_ref(m_value_type); | |||
| if (inp_ref) { | |||
| return {FormatValue::make(inp_ref->format())}; | |||
| } else { | |||
| mgb_log_warn( | |||
| "Not FormattedTensorValue input for GetFormat op: %s", | |||
| inputs[0].to_string().c_str()); | |||
| "Not FormattedTensorValue input for GetFormat op: %s, %s", | |||
| op.to_string().c_str(), inputs[0].to_string().c_str()); | |||
| return {FormatValue::make(FT::DEFAULT)}; | |||
| } | |||
| } else if (op.is<Operator::IdentityLike>()) { | |||
| bool is_formatted_tensor = inputs.item().is(m_value_type); | |||
| if (is_formatted_tensor) { | |||
| auto&& format = inputs[0].cast(m_value_type).format(); | |||
| auto&& inp_ref = inputs[0].as_ref(m_value_type); | |||
| if (inp_ref) { | |||
| auto&& format = inp_ref->format(); | |||
| return wrap_outputs( | |||
| imperative::apply(op, unwrap_inputs(inputs)), format.type()); | |||
| } else { | |||
| mgb_log_warn( | |||
| "Not FormattedTensorValue input for IdentityLike op: %s", | |||
| inputs[0].to_string().c_str()); | |||
| "Not FormattedTensorValue input for IdentityLike op: %s, %s", | |||
| op.to_string().c_str(), inputs[0].to_string().c_str()); | |||
| return imperative::apply(op, inputs); | |||
| } | |||
| } else if (op.is<AttachGrad>()) { | |||
| auto&& inp_ref = inputs[0].as_ref(m_value_type); | |||
| if (inp_ref) { | |||
| auto format = inp_ref->format(); | |||
| GenericFunction callback = | |||
| (GenericFunction&)inputs[1].cast<FunctionValue>(); | |||
| GenericFunction new_callback = | |||
| [this, callback, format](Span<ValueRef> inputs_) -> ValueRefList { | |||
| auto wrapped_inputs = SmallVector<ValueRef>{ | |||
| this->value_type().make(inputs_.item(), format.type())}; | |||
| auto ret = callback(wrapped_inputs); | |||
| return ret; | |||
| }; | |||
| auto&& outputs = imperative::apply( | |||
| op, inp_ref->value(), FunctionValue::make(new_callback)); | |||
| return wrap_outputs(outputs, format.type()); | |||
| } else { | |||
| mgb_log_warn( | |||
| "Not FormattedTensorValue input for AttachGrad op: %s, %s", | |||
| op.to_string().c_str(), inputs[0].to_string().c_str()); | |||
| return imperative::apply(op, inputs); | |||
| } | |||
| } else if (auto* set_grad = op.as<SetGrad>()) { | |||
| size_t nr_inputs = set_grad->nr_inputs(); | |||
| size_t nr_outputs = inputs.size() - nr_inputs; | |||
| Span<ValueRef> inputs_ = {inputs.data(), nr_inputs}; | |||
| Span<ValueRef> outputs_ = {inputs.data() + nr_inputs, nr_outputs}; | |||
| // run original apply. | |||
| // grads needn't to unwrap and wrap, which will be unwrapped in GradTrans | |||
| auto&& outputs = imperative::apply(op, unwrap_inputs(inputs)); | |||
| // handle output's formats | |||
| auto wrapped_outputs = ValueRefList(nr_outputs); | |||
| for (size_t i = 0; i < nr_outputs; ++i) { | |||
| if (auto output_ref = outputs_[i].as_ref(m_value_type)) { | |||
| wrapped_outputs[i] = | |||
| m_value_type.make(outputs[i], output_ref->format().type()); | |||
| } else { | |||
| mgb_log_warn( | |||
| "Not FormattedTensorValue outputs for SetGrad op: %s, %s", | |||
| op.to_string().c_str(), inputs_[i].to_string().c_str()); | |||
| wrapped_outputs[i] = m_value_type.make(outputs[i], FT::DEFAULT); | |||
| } | |||
| } | |||
| return wrapped_outputs; | |||
| } else { | |||
| return imperative::apply(op, unwrap_inputs(inputs)); | |||
| } | |||
| @@ -47,7 +47,10 @@ public: | |||
| const Operator& op, Span<ValueRef> inputs) override; | |||
| ValueRef unwrap(ValueRef value) override { | |||
| mgb_assert(!value.is(m_value_type)); | |||
| //mgb_assert(!value.is(m_value_type)); | |||
| if (auto format_val = value.as_ref(m_value_type)) { | |||
| return format_val->value(); | |||
| } | |||
| return value; | |||
| } | |||
| @@ -377,6 +377,8 @@ public: | |||
| SetGrad(GenericFunction grad_fn, size_t nr_inputs) | |||
| : m_grad_fn(grad_fn), m_nr_inputs(nr_inputs) {} | |||
| std::shared_ptr<GradKey> key() const { return m_key; } | |||
| GenericFunction grad_fn() const { return m_grad_fn; } | |||
| size_t nr_inputs() const { return m_nr_inputs; } | |||