GitOrigin-RevId: 5ced9e1a31
tags/v1.10.0
@@ -23,23 +23,17 @@ def convert_tensor_format(x: Tensor, inplace: bool = True):
     if not _is_nchw_format(x):
         return x
-    if x.ndim == 4:
-        pattern = (0, 2, 3, 1)
-    elif x.ndim == 5:
-        pattern = (0, 1, 3, 4, 2)
-    else:
+    if x.ndim != 4 and x.ndim != 5:
         raise ValueError("Unsupport tensor ndim {}".format(x.ndim))
+    # TODO: use initialization from tensor after fixing format setting
     if x.format != "nhwc":
+        # hostvalue should still be valid, so no d2h cost.
+        data = x.numpy()
         if inplace:
-            # hostvalue should still be valid, so no d2h cost.
-            data = x.numpy()
             # reset will destroy existed backward grad
             x[...] = Tensor(data, format="nhwc")
         else:
             # use mge interface to maintain grad
-            x = F.transpose(x, pattern)
-            x.format = "nhwc"
+            x = Tensor(data, format="nhwc")
     return x
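For reference, the two permutation patterns that the removed branch spelled out can be checked with plain numpy; this is only a sketch of the axis shuffling (the shapes below are invented), not of MegEngine's format metadata or of the new host-copy path:

import numpy as np

# NCHW -> NHWC for a 4D activation: the channel axis moves to the end.
x4 = np.zeros((2, 3, 8, 8))                          # N, C, H, W
assert np.transpose(x4, (0, 2, 3, 1)).shape == (2, 8, 8, 3)

# The 5D pattern (e.g. group-conv weights) moves axis 2 behind the spatial axes.
x5 = np.zeros((4, 6, 3, 5, 5))
assert np.transpose(x5, (0, 1, 3, 4, 2)).shape == (4, 6, 5, 5, 3)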
@@ -181,7 +181,6 @@ def _reset_execution_config(
     deterministic_kernel=None,
     async_level=None,
     compute_mode=None,
-    auto_format_convert=None,
 ):
     global _benchmark_kernel, _deterministic_kernel, __compute_mode
     orig_flags = (
@@ -189,7 +188,6 @@ def _reset_execution_config(
         _deterministic_kernel,
         get_option("async_level"),
         __compute_mode,
-        get_auto_format_convert(),
     )
     if benchmark_kernel is not None:
         _benchmark_kernel = benchmark_kernel
@@ -199,8 +197,6 @@ def _reset_execution_config(
         set_option("async_level", async_level)
     if compute_mode is not None:
         __compute_mode = compute_mode
-    if auto_format_convert is not None:
-        set_auto_format_convert(auto_format_convert)
     return orig_flags
@@ -211,7 +207,6 @@ def _override(
     deterministic_kernel=None,
     async_level=None,
     compute_mode=None,
-    auto_format_convert=None,
 ):
     r"""A context manager that users can opt in by attaching the decorator to set
     the config of the global variable.
@@ -227,7 +222,6 @@ def _override(
             deterministic_kernel = Fasle,
             async_level=2,
             compute_mode="float32",
-            auto_format_convert=True,
         )
         def train():
     """
@@ -236,7 +230,6 @@ def _override(
         deterministic_kernel=deterministic_kernel,
         async_level=async_level,
         compute_mode=compute_mode,
-        auto_format_convert=auto_format_convert,
     )
     try:
         yield
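The hunks above only drop the auto_format_convert flag; the snapshot/apply/restore shape of _reset_execution_config and _override is unchanged. A self-contained analogue of that pattern, with an invented _flags dict standing in for MegEngine's module-level globals:

import contextlib

_flags = {"benchmark_kernel": False, "compute_mode": "default"}  # stand-in globals

def _reset(**overrides):
    orig = dict(_flags)                     # snapshot the current values
    for key, value in overrides.items():
        if value is not None:               # None means "leave this flag alone"
            _flags[key] = value
    return orig

@contextlib.contextmanager
def _override(**overrides):
    orig = _reset(**overrides)              # apply overrides, keep the originals
    try:
        yield
    finally:
        _flags.update(orig)                 # restore on exit, even after an error

@_override(benchmark_kernel=True, compute_mode="float32")
def train():
    assert _flags["benchmark_kernel"] is True

train()
assert _flags["benchmark_kernel"] is False   # flags restored after the call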
@@ -1206,9 +1206,9 @@ def batch_norm(
         if x is None:
             x = Const(value, inp.dtype, inp.device)
-            x.format = inp.format
             shape = astensor1d(pshape, inp, dtype="int32", device=inp.device)
             (result,) = apply(builtin.Broadcast(), x, shape)
+            result.format = inp.format
             return result
         else:
             assert x_ndim == 1
| @@ -274,7 +274,6 @@ def full_like(inp: Tensor, value: Union[int, float]) -> Tensor: | |||||
| return x | return x | ||||
| # set x's format to use FormatTransformation rule for Broadcast. | # set x's format to use FormatTransformation rule for Broadcast. | ||||
| x.format = inp.format | |||||
| return broadcast_to(x, inp.shape) | return broadcast_to(x, inp.shape) | ||||
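Both hunks above adjust where (or whether) the inp.format tag is attached around a Const-plus-Broadcast / broadcast_to sequence; the shape mechanics themselves are ordinary broadcasting. A minimal numpy analogue of what full_like builds (the function name is invented for the sketch and the format bookkeeping is deliberately left out):

import numpy as np

def full_like_sketch(ref: np.ndarray, value: float) -> np.ndarray:
    x = np.asarray(value, dtype=ref.dtype)   # the 0-d "Const"
    return np.broadcast_to(x, ref.shape)     # the "Broadcast" to ref's shape

out = full_like_sketch(np.zeros((2, 3, 4, 4), dtype=np.float32), 1.5)
assert out.shape == (2, 3, 4, 4) and out[0, 0, 0, 0] == 1.5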
@@ -91,14 +91,13 @@ class Optimizer(metaclass=ABCMeta):
         else:
             param_group["params"] = list(param_group["params"])
 
-        with _config._override(auto_format_convert=False):
-            for param in param_group["params"]:
-                if not isinstance(param, Parameter):
-                    raise TypeError(
-                        "optimizer can only optimize Parameters, but one of the params is "
-                        + str(type(param))
-                    )
-                param._reset(Tensor(param.numpy(), no_cache=True, format=param.format))
+        for param in param_group["params"]:
+            if not isinstance(param, Parameter):
+                raise TypeError(
+                    "optimizer can only optimize Parameters, but one of the params is "
+                    + str(type(param))
+                )
+            param._reset(Tensor(param.numpy(), no_cache=True, format=param.format))
 
         for name, default in self._defaults.items():
             if default is required and name not in param_group:
@@ -121,8 +120,7 @@ class Optimizer(metaclass=ABCMeta):
     def _add_state(self, param, state_name, initializer=None):
         if initializer is None:
-            with _config._override(auto_format_convert=False):
-                initializer = np.zeros(param.shape, dtype=np.float32)
+            initializer = np.zeros(param.shape, dtype=np.float32)
         state_dict = self._state.setdefault(param, {})
         assert state_name not in state_dict
         state = Tensor(initializer, no_cache=True, format=param.format)
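With the _override(auto_format_convert=False) wrapper gone, the state-allocation pattern the optimizer keeps is simply: numpy zeros shaped like the parameter, re-wrapped as a Tensor that carries the parameter's format. A sketch of that step in isolation, built only from the calls visible in the hunk (make_state_like is an invented name; the real logic stays inside Optimizer._add_state):

import numpy as np
from megengine import Tensor

def make_state_like(param: Tensor) -> Tensor:
    # Same two steps as _add_state above: allocate on the host, then wrap,
    # keeping the parameter's memory format for the new state tensor.
    initializer = np.zeros(param.shape, dtype=np.float32)
    return Tensor(initializer, no_cache=True, format=param.format)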
@@ -10,7 +10,8 @@ import pytest
 import megengine.functional as F
 import megengine.module as M
-from megengine import Parameter, Tensor, amp, config
+from megengine import Parameter, Tensor, amp
+from megengine.core._config import set_auto_format_convert
 
 
 class MyModule(M.Module):
@@ -56,5 +57,6 @@ def test_convert_module(is_inplace):
     m = amp.convert_module_format(m, is_inplace)
     for name, param in m.named_tensors():
         assert param.format == "nhwc"
-        with config._override(auto_format_convert=False):
-            assert param.shape == expected_shape[name], name
+        set_auto_format_convert(False)
+        assert param.shape == expected_shape[name], name
+        set_auto_format_convert(True)
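Because the test now flips a process-wide flag instead of using the removed context-manager argument, a try/finally variant would keep the setting from leaking if the shape assertion fails. This is only a suggestion sketched from the set_auto_format_convert import shown above (the helper name is invented):

from megengine.core._config import set_auto_format_convert

def assert_logical_shape(param, expected, name):
    set_auto_format_convert(False)        # compare against the logical shape
    try:
        assert param.shape == expected, name
    finally:
        set_auto_format_convert(True)     # always restore the global default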
@@ -19,6 +19,9 @@ TypedValueRef<FormattedTensorValue> FormatTransformation::to(
         const std::string& scope) const {
     std::vector<int32_t> pattern;
     Format format = tensor.format();
+    if (format == target)
+        return as(tensor, target);
     if (format == FT::NHWC && (target == FT::NCHW || target == FT::DEFAULT)) {
         // FIXME(czh): temporary fast path for group conv 5D weight.
         if (tensor.value().shape().cast<ShapeValue>().ndim == 5) {
@@ -618,7 +621,7 @@ ValueRefList FormatTransformation::apply_transformation(
     } else if (auto* _op = op.as<SetFormat>()) {
         auto&& inp_ref = inputs[0].as_ref(m_value_type);
         mgb_assert(inp_ref, "Cannot set format for non-format Tensor.");
-        return {m_value_type.make(inp_ref->value(), _op->format())};
+        return {to(*inp_ref, _op->format().type(), "")};
     } else if (op.is<Operator::IdentityLike>()) {
         auto&& inp_ref = inputs[0].as_ref(m_value_type);
         if (inp_ref) {
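The SetFormat branch now routes through FormatTransformation::to() instead of re-wrapping the value with a new format tag, so assigning a format implies an actual layout pass rather than a relabel, and the new early return in to() makes a same-format assignment a no-op. A rough Python analogue of the before/after behaviour, using a plain numpy array plus a format string as a stand-in for the formatted value type:

import numpy as np

def set_format_old(data, fmt, target):
    # pre-change analogue: keep the bytes, swap the label
    return data, target

def set_format_new(data, fmt, target):
    # post-change analogue: no-op when already in the target format,
    # otherwise permute the layout for the NCHW/NHWC pair
    if fmt == target:
        return data, fmt
    if fmt == "nchw" and target == "nhwc":
        return np.transpose(data, (0, 2, 3, 1)), target
    if fmt == "nhwc" and target == "nchw":
        return np.transpose(data, (0, 3, 1, 2)), target
    raise ValueError("unsupported format pair: {} -> {}".format(fmt, target))

x = np.zeros((2, 3, 8, 8))                       # NCHW
y, f = set_format_new(x, "nchw", "nhwc")
assert y.shape == (2, 8, 8, 3) and f == "nhwc"   # data actually transposed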