GitOrigin-RevId: 77ff909f23
tags/v1.10.0
| @@ -156,6 +156,7 @@ _atexit(_persistent_cache.flush) | |||||
| # subpackages | # subpackages | ||||
| import megengine.amp | import megengine.amp | ||||
| import megengine.autodiff | import megengine.autodiff | ||||
| import megengine.config | |||||
| import megengine.data | import megengine.data | ||||
| import megengine.distributed | import megengine.distributed | ||||
| import megengine.dtr | import megengine.dtr | ||||
| @@ -2,7 +2,13 @@ | |||||
| import os | import os | ||||
| from contextlib import contextmanager | from contextlib import contextmanager | ||||
| from ._imperative_rt.core2 import _clear_algorithm_cache, get_option, set_option | |||||
| from ._imperative_rt.core2 import ( | |||||
| _clear_algorithm_cache, | |||||
| get_auto_format_convert, | |||||
| get_option, | |||||
| set_auto_format_convert, | |||||
| set_option, | |||||
| ) | |||||
| __compute_mode = "default" | __compute_mode = "default" | ||||
| __conv_format = "default" | __conv_format = "default" | ||||
| @@ -24,8 +30,8 @@ __all__ = [ | |||||
| def benchmark_kernel(mod): | def benchmark_kernel(mod): | ||||
| r"""Whether or not run possible algorithms on real device to find the best one. The default option is false, | r"""Whether or not run possible algorithms on real device to find the best one. The default option is false, | ||||
| which means use heuristic to choose the fastest algorithm. | which means use heuristic to choose the fastest algorithm. | ||||
| Examples: | |||||
| Examples: | |||||
| .. code-block:: | .. code-block:: | ||||
| import megengine as mge | import megengine as mge | ||||
| @@ -47,8 +53,8 @@ def benchmark_kernel(mod, option: bool): | |||||
| def deterministic_kernel(mod): | def deterministic_kernel(mod): | ||||
| r"""Whether or not the fastest algorithm choosed is reproducible. The default option is false, | r"""Whether or not the fastest algorithm choosed is reproducible. The default option is false, | ||||
| which means the algorithm is not reproducible. | which means the algorithm is not reproducible. | ||||
| Examples: | |||||
| Examples: | |||||
| .. code-block:: | .. code-block:: | ||||
| import megengine as mge | import megengine as mge | ||||
| @@ -67,8 +73,8 @@ def deterministic_kernel(mod, option: bool): | |||||
| def async_level(mod) -> int: | def async_level(mod) -> int: | ||||
| r"""Get or set config whether raise error exactly when invoking op. The default level is 2, | r"""Get or set config whether raise error exactly when invoking op. The default level is 2, | ||||
| which means both device and user side errors are async. | which means both device and user side errors are async. | ||||
| Examples: | |||||
| Examples: | |||||
| .. code-block:: | .. code-block:: | ||||
| import megengine as mge | import megengine as mge | ||||
| @@ -108,8 +114,8 @@ def _compute_mode(mod): | |||||
| which means that no special requirements will be placed on. When set to 'float32', it | which means that no special requirements will be placed on. When set to 'float32', it | ||||
| would be used for accumulator and intermediate result, but only effective when input and | would be used for accumulator and intermediate result, but only effective when input and | ||||
| output are of float16 dtype. | output are of float16 dtype. | ||||
| Examples: | |||||
| Examples: | |||||
| .. code-block:: | .. code-block:: | ||||
| import megengine as mge | import megengine as mge | ||||
| @@ -137,8 +143,8 @@ def _conv_format(mod): | |||||
| ``NCHW88`` layout: ``{N, C/8, H, W, 8}`` | ``NCHW88`` layout: ``{N, C/8, H, W, 8}`` | ||||
| ``CHWN4`` layout: ``{C/4, H, W, N, 4}`` | ``CHWN4`` layout: ``{C/4, H, W, N, 4}`` | ||||
| ``NCHW64`` layout: ``{N, C/64, H, W, 64}`` | ``NCHW64`` layout: ``{N, C/64, H, W, 64}`` | ||||
| Examples: | |||||
| Examples: | |||||
| .. code-block:: | .. code-block:: | ||||
| import megengine as mge | import megengine as mge | ||||
| @@ -153,20 +159,41 @@ def _conv_format(mod, format: str): | |||||
| __conv_format = format | __conv_format = format | ||||
| @property | |||||
| def _auto_format_convert(mod): | |||||
| r"""Automatically convert indexing params' order for NCHW Tensor to NHWC order. | |||||
| The default value is False, which means no convert. | |||||
| Examples: | |||||
| .. code-block:: | |||||
| import megengine as mge | |||||
| mge.config._auto_format_convert = True | |||||
| """ | |||||
| return get_auto_format_convert() | |||||
| @_auto_format_convert.setter | |||||
| def _auto_format_convert(mod, option: bool): | |||||
| set_auto_format_convert(option) | |||||
| def _reset_execution_config( | def _reset_execution_config( | ||||
| benchmark_kernel=None, | benchmark_kernel=None, | ||||
| deterministic_kernel=None, | deterministic_kernel=None, | ||||
| async_level=None, | async_level=None, | ||||
| compute_mode=None, | compute_mode=None, | ||||
| conv_format=None, | conv_format=None, | ||||
| auto_format_convert=None, | |||||
| ): | ): | ||||
| global _benchmark_kernel, _deterministic_kernel, _async_level, __compute_mode, __conv_format | |||||
| global _benchmark_kernel, _deterministic_kernel, __compute_mode, __conv_format | |||||
| orig_flags = ( | orig_flags = ( | ||||
| _benchmark_kernel, | _benchmark_kernel, | ||||
| _deterministic_kernel, | _deterministic_kernel, | ||||
| get_option("async_level"), | get_option("async_level"), | ||||
| __compute_mode, | __compute_mode, | ||||
| __conv_format, | __conv_format, | ||||
| get_auto_format_convert(), | |||||
| ) | ) | ||||
| if benchmark_kernel is not None: | if benchmark_kernel is not None: | ||||
| _benchmark_kernel = benchmark_kernel | _benchmark_kernel = benchmark_kernel | ||||
| @@ -178,6 +205,8 @@ def _reset_execution_config( | |||||
| __compute_mode = compute_mode | __compute_mode = compute_mode | ||||
| if conv_format is not None: | if conv_format is not None: | ||||
| __conv_format = conv_format | __conv_format = conv_format | ||||
| if auto_format_convert is not None: | |||||
| set_auto_format_convert(auto_format_convert) | |||||
| return orig_flags | return orig_flags | ||||
| @@ -189,26 +218,33 @@ def _override( | |||||
| async_level=None, | async_level=None, | ||||
| compute_mode=None, | compute_mode=None, | ||||
| conv_format=None, | conv_format=None, | ||||
| auto_format_convert=None, | |||||
| ): | ): | ||||
| r"""A context manager that users can opt in by attaching the decorator to set | r"""A context manager that users can opt in by attaching the decorator to set | ||||
| the config of the global variable. | the config of the global variable. | ||||
| Examples: | |||||
| Examples: | |||||
| .. code-block:: | .. code-block:: | ||||
| import megengine as mge | import megengine as mge | ||||
| @mge.config._override( | @mge.config._override( | ||||
| benchmark_kernel = True, | benchmark_kernel = True, | ||||
| deterministic_kernel = Fasle, | deterministic_kernel = Fasle, | ||||
| async_level=2, | async_level=2, | ||||
| compute_mode="float32", | compute_mode="float32", | ||||
| conv_format="NHWC", | conv_format="NHWC", | ||||
| auto_format_convert=True, | |||||
| ) | ) | ||||
| def train(): | def train(): | ||||
| """ | """ | ||||
| orig_flags = _reset_execution_config( | orig_flags = _reset_execution_config( | ||||
| benchmark_kernel, deterministic_kernel, async_level, compute_mode, conv_format, | |||||
| benchmark_kernel, | |||||
| deterministic_kernel, | |||||
| async_level, | |||||
| compute_mode, | |||||
| conv_format, | |||||
| auto_format_convert, | |||||
| ) | ) | ||||
| try: | try: | ||||
| yield | yield | ||||
| @@ -564,7 +564,6 @@ def interpolate( | |||||
| if inp.dtype == np.float16: | if inp.dtype == np.float16: | ||||
| inp = inp.astype("float32") | inp = inp.astype("float32") | ||||
| conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format) | conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format) | ||||
| assert conv_format == "NCHW", "Currently resize only support NCHW mode" | |||||
| op = builtin.Resize(imode=mode_map[mode], format=conv_format) | op = builtin.Resize(imode=mode_map[mode], format=conv_format) | ||||
| shape = astensor1d(dsize, inp, dtype="int32", device=inp.device) | shape = astensor1d(dsize, inp, dtype="int32", device=inp.device) | ||||
| (ret,) = apply(op, inp, shape) | (ret,) = apply(op, inp, shape) | ||||
| @@ -4,6 +4,7 @@ from typing import Union | |||||
| import numpy as np | import numpy as np | ||||
| from .core._imperative_rt import CompNode | from .core._imperative_rt import CompNode | ||||
| from .core._imperative_rt.core2 import FormatType | |||||
| from .core._imperative_rt.core2 import Tensor as _Tensor | from .core._imperative_rt.core2 import Tensor as _Tensor | ||||
| from .core._imperative_rt.core2 import apply, set_py_tensor_type | from .core._imperative_rt.core2 import apply, set_py_tensor_type | ||||
| from .core._trace_option import use_symbolic_shape | from .core._trace_option import use_symbolic_shape | ||||
| @@ -45,6 +46,8 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||||
| is_const: Whether make it a ``ImutableTensor`` in tracing mode, refer to :class:`.jit.trace`. | is_const: Whether make it a ``ImutableTensor`` in tracing mode, refer to :class:`.jit.trace`. | ||||
| no_cache: Whether cache it for memory sharing. | no_cache: Whether cache it for memory sharing. | ||||
| name: Used to improve convenience in graph operation on dumped model. | name: Used to improve convenience in graph operation on dumped model. | ||||
| format: Used to indicate which memory format Tensor uses. It will not affect actual memory order or stride, | |||||
| but may affect some operators related to indexing and dimension. Only support "default", "nchw" and "nhwc". | |||||
| .. note:: | .. note:: | ||||
| @@ -73,6 +76,7 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||||
| is_const: bool = False, | is_const: bool = False, | ||||
| no_cache: bool = False, | no_cache: bool = False, | ||||
| name: str = None, | name: str = None, | ||||
| format: str = "default", | |||||
| ): | ): | ||||
| if name is None: | if name is None: | ||||
| name = "" | name = "" | ||||
| @@ -116,6 +120,10 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||||
| r"""Returns a :class:`numpy.dtype` object represents the data type of a :class:`~.Tensor`.""" | r"""Returns a :class:`numpy.dtype` object represents the data type of a :class:`~.Tensor`.""" | ||||
| return super().dtype | return super().dtype | ||||
| @property | |||||
| def format(self) -> str: | |||||
| return super().format | |||||
| @property | @property | ||||
| def qparams(self): | def qparams(self): | ||||
| r"""Returns a :class:`~.QParams` object containing quantization params of a :class:`~.Tensor`.""" | r"""Returns a :class:`~.QParams` object containing quantization params of a :class:`~.Tensor`.""" | ||||
| @@ -8,6 +8,7 @@ | |||||
| #include "megbrain/imperative/transformations/dim_expansion.h" | #include "megbrain/imperative/transformations/dim_expansion.h" | ||||
| #include "megbrain/imperative/transformations/dtype_promote.h" | #include "megbrain/imperative/transformations/dtype_promote.h" | ||||
| #include "megbrain/imperative/transformations/eval.h" | #include "megbrain/imperative/transformations/eval.h" | ||||
| #include "megbrain/imperative/transformations/format.h" | |||||
| #include "megbrain/imperative/transformations/lazy.h" | #include "megbrain/imperative/transformations/lazy.h" | ||||
| #include "megbrain/imperative/transformations/scalar.h" | #include "megbrain/imperative/transformations/scalar.h" | ||||
| #include "megbrain/imperative/transformations/symbol.h" | #include "megbrain/imperative/transformations/symbol.h" | ||||
| @@ -492,6 +493,9 @@ ssize_t name2idx(const char* name) { | |||||
| // name | // name | ||||
| case 'a': return compare_cstr<'m', 'e'>(ch) ? 5 : -1; | case 'a': return compare_cstr<'m', 'e'>(ch) ? 5 : -1; | ||||
| } | } | ||||
| case 'f': | |||||
| // format | |||||
| return compare_cstr<'o', 'r', 'm', 'a', 't'>(ch) ? 6 : -1; | |||||
| } | } | ||||
| // clang-format on | // clang-format on | ||||
| return -1; | return -1; | ||||
| @@ -508,6 +512,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { | |||||
| {"is_const", []() -> py::object { return py::bool_(false); }}, | {"is_const", []() -> py::object { return py::bool_(false); }}, | ||||
| {"no_cache", []() -> py::object { return py::bool_(false); }}, | {"no_cache", []() -> py::object { return py::bool_(false); }}, | ||||
| {"name", []() -> py::object { return py::none(); }}, | {"name", []() -> py::object { return py::none(); }}, | ||||
| {"format", []() -> py::object { return py::none(); }}, | |||||
| }, | }, | ||||
| name2idx}; | name2idx}; | ||||
| py::detail::loader_life_support life_sup; // FIXME!!!required to cast DType | py::detail::loader_life_support life_sup; // FIXME!!!required to cast DType | ||||
| @@ -518,19 +523,23 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { | |||||
| } else { | } else { | ||||
| tup = parse_args(tup, descs); | tup = parse_args(tup, descs); | ||||
| } | } | ||||
| mgb_assert(tup.size() == 6); | |||||
| mgb_assert(tup.size() == 7); | |||||
| if (auto* t = try_cast(tup[0].ptr())) { | if (auto* t = try_cast(tup[0].ptr())) { | ||||
| m_tensor = t->m_tensor->copy(); | m_tensor = t->m_tensor->copy(); | ||||
| } else { | } else { | ||||
| auto data = tup[0]; | auto data = tup[0]; | ||||
| DType dtype = tup[1].cast<DType>(); | DType dtype = tup[1].cast<DType>(); | ||||
| CompNode cn = as_comp_node(tup[2]); | |||||
| bool is_const = tup[3].cast<bool>(); | bool is_const = tup[3].cast<bool>(); | ||||
| bool no_cache = tup[4].cast<bool>(); | bool no_cache = tup[4].cast<bool>(); | ||||
| std::string name; | std::string name; | ||||
| if (!tup[5].is_none()) { | if (!tup[5].is_none()) { | ||||
| name = tup[5].cast<std::string>(); | name = tup[5].cast<std::string>(); | ||||
| } | } | ||||
| CompNode cn = as_comp_node(tup[2]); | |||||
| Format format; | |||||
| if (!tup[6].is_none()) { | |||||
| format = tup[6].cast<std::string>(); | |||||
| } | |||||
| { | { | ||||
| CreateTensor::Kind kind = is_const ? CreateTensor::Const | CreateTensor::Kind kind = is_const ? CreateTensor::Const | ||||
| @@ -544,7 +553,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { | |||||
| } else { | } else { | ||||
| auto&& hval = pyobj2hval(data, cn, dtype); | auto&& hval = pyobj2hval(data, cn, dtype); | ||||
| val = imperative::apply( | val = imperative::apply( | ||||
| CreateTensor(kind, cn, hval.dtype, hval.shape), | |||||
| CreateTensor(kind, cn, hval.dtype, hval.shape, format), | |||||
| hval.storage)[0]; | hval.storage)[0]; | ||||
| } | } | ||||
| m_tensor.emplace(val); | m_tensor.emplace(val); | ||||
| @@ -610,6 +619,10 @@ PyObject* TensorWrapper::device() { | |||||
| return py::cast(m_tensor->comp_node()).release().ptr(); | return py::cast(m_tensor->comp_node()).release().ptr(); | ||||
| } | } | ||||
| PyObject* TensorWrapper::format() { | |||||
| return py::cast(m_tensor->format().to_string()).release().ptr(); | |||||
| } | |||||
| PyObject* TensorWrapper::numpy() { | PyObject* TensorWrapper::numpy() { | ||||
| auto hv = m_tensor->numpy(); | auto hv = m_tensor->numpy(); | ||||
| if (!hv) { | if (!hv) { | ||||
| @@ -722,6 +735,7 @@ WRAP_FUNC_PY35(pixel_shuffle_cpp); | |||||
| void init_tensor(py::module m) { | void init_tensor(py::module m) { | ||||
| imperative::Tensor::static_initialize(); | imperative::Tensor::static_initialize(); | ||||
| // Transformations | |||||
| static auto& transformations = TransformationManager::get_instance(); | static auto& transformations = TransformationManager::get_instance(); | ||||
| using Segment = TransformationManager::Segment; | using Segment = TransformationManager::Segment; | ||||
| @@ -755,6 +769,9 @@ void init_tensor(py::module m) { | |||||
| .register_at<Segment::DimExpansion>( | .register_at<Segment::DimExpansion>( | ||||
| std::make_shared<DimExpansionTransformation>()) | std::make_shared<DimExpansionTransformation>()) | ||||
| .release()); | .release()); | ||||
| auto format_trans = std::make_shared<FormatTransformation>(); | |||||
| MGB_MARK_USED_VAR( | |||||
| transformations.register_at<Segment::Format>(format_trans).release()); | |||||
| static py::exception<interpreter::AsyncError> py_async_error( | static py::exception<interpreter::AsyncError> py_async_error( | ||||
| m, "AsyncError", PyExc_RuntimeError); | m, "AsyncError", PyExc_RuntimeError); | ||||
| @@ -788,12 +805,14 @@ void init_tensor(py::module m) { | |||||
| } | } | ||||
| }); | }); | ||||
| // Tensor | |||||
| auto* tensor_type = | auto* tensor_type = | ||||
| TensorWrapper::wrap_t::type() | TensorWrapper::wrap_t::type() | ||||
| .def<&TensorWrapper::numpy>("numpy") | .def<&TensorWrapper::numpy>("numpy") | ||||
| .def_getset<&TensorWrapper::shape>("shape") | .def_getset<&TensorWrapper::shape>("shape") | ||||
| .def_getset<&TensorWrapper::dtype>("dtype") | .def_getset<&TensorWrapper::dtype>("dtype") | ||||
| .def_getset<&TensorWrapper::device>("device") | .def_getset<&TensorWrapper::device>("device") | ||||
| .def_getset<&TensorWrapper::format>("format") | |||||
| .def<&TensorWrapper::reset>("_reset") | .def<&TensorWrapper::reset>("_reset") | ||||
| .def<&TensorWrapper::isscalar>("_isscalar") | .def<&TensorWrapper::isscalar>("_isscalar") | ||||
| .def<&TensorWrapper::detach>("detach") | .def<&TensorWrapper::detach>("detach") | ||||
| @@ -812,6 +831,11 @@ void init_tensor(py::module m) { | |||||
| if (!tensor_type) | if (!tensor_type) | ||||
| throw py::error_already_set(); | throw py::error_already_set(); | ||||
| py::setattr(m, "Tensor", tensor_type); | py::setattr(m, "Tensor", tensor_type); | ||||
| py::enum_<Format::Type>(m, "FormatType") | |||||
| .value("DEFAULT", Format::Type::DEFAULT) | |||||
| .value("NCHW", Format::Type::NCHW) | |||||
| .value("NHWC", Format::Type::NHWC) | |||||
| .export_values(); | |||||
| py::class_<TensorWeakRef>(m, "TensorWeakRef") | py::class_<TensorWeakRef>(m, "TensorWeakRef") | ||||
| .def(py::init<const TensorWrapper&>()) | .def(py::init<const TensorWrapper&>()) | ||||
| @@ -911,6 +935,7 @@ void init_tensor(py::module m) { | |||||
| sync_py_task_q(); | sync_py_task_q(); | ||||
| }); | }); | ||||
| // GradTransformation | |||||
| py::handle grad_key_type = | py::handle grad_key_type = | ||||
| GradKeyWrapper::wrap_t::type() | GradKeyWrapper::wrap_t::type() | ||||
| .def<&GradKeyWrapper::attach>("attach") | .def<&GradKeyWrapper::attach>("attach") | ||||
| @@ -1203,6 +1228,7 @@ void init_tensor(py::module m) { | |||||
| return wrapped_outputs; | return wrapped_outputs; | ||||
| }); | }); | ||||
| // ModuleTraceTransformation | |||||
| static py::function module_trace_hook; | static py::function module_trace_hook; | ||||
| static auto get_module_trace = [] { | static auto get_module_trace = [] { | ||||
| @@ -1309,6 +1335,12 @@ void init_tensor(py::module m) { | |||||
| m.def("_clear_algorithm_cache", [] { megdnn::AlgorithmCache::instance().clear(); }); | m.def("_clear_algorithm_cache", [] { megdnn::AlgorithmCache::instance().clear(); }); | ||||
| // FormatTransformation | |||||
| m.def("set_auto_format_convert", | |||||
| [format_trans](bool enabled) { format_trans->set_auto_convert(enabled); }); | |||||
| m.def("get_auto_format_convert", | |||||
| [format_trans]() { return format_trans->get_auto_convert(); }); | |||||
| py::register_exception<TraceError>(m, "TraceError"); | py::register_exception<TraceError>(m, "TraceError"); | ||||
| } | } | ||||
| @@ -1,10 +1,11 @@ | |||||
| #pragma once | #pragma once | ||||
| #pragma GCC diagnostic ignored "-Wmissing-field-initializers" | #pragma GCC diagnostic ignored "-Wmissing-field-initializers" | ||||
| #include <variant> | |||||
| #include <string> | #include <string> | ||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <variant> | |||||
| #include "megbrain/imperative/dispatch.h" | |||||
| #include "megbrain/imperative/interpreter.h" | #include "megbrain/imperative/interpreter.h" | ||||
| #include "pybind11/pybind11.h" | #include "pybind11/pybind11.h" | ||||
| @@ -57,6 +58,7 @@ public: | |||||
| } | } | ||||
| return *shape; | return *shape; | ||||
| } | } | ||||
| inline Format format() { return *data().format(); } | |||||
| inline HostValue::ref_t numpy() { return data().numpy(); } | inline HostValue::ref_t numpy() { return data().numpy(); } | ||||
| inline void reset(ValueRef value) { | inline void reset(ValueRef value) { | ||||
| m_data = value; | m_data = value; | ||||
| @@ -116,6 +118,7 @@ public: | |||||
| PyObject* shape(); | PyObject* shape(); | ||||
| PyObject* dtype(); | PyObject* dtype(); | ||||
| PyObject* device(); | PyObject* device(); | ||||
| PyObject* format(); | |||||
| PyObject* numpy(); | PyObject* numpy(); | ||||
| void reset(PyObject*); | void reset(PyObject*); | ||||
| PyObject* detach(); | PyObject* detach(); | ||||
| @@ -19,6 +19,7 @@ public: | |||||
| DTypePromote, | DTypePromote, | ||||
| DimExpansion, | DimExpansion, | ||||
| Grad, | Grad, | ||||
| Format, | |||||
| Scalar, | Scalar, | ||||
| Symbol, | Symbol, | ||||
| Trace, | Trace, | ||||
| @@ -2,7 +2,7 @@ from megengine import amp | |||||
| from megengine.core.tensor import amp as origin_amp | from megengine.core.tensor import amp as origin_amp | ||||
| def test_grad_scaler(): | |||||
| def test_autocast(): | |||||
| def check(enabled, low, high): | def check(enabled, low, high): | ||||
| assert amp.enabled == enabled | assert amp.enabled == enabled | ||||
| assert origin_amp._enabled == enabled | assert origin_amp._enabled == enabled | ||||
| @@ -0,0 +1,307 @@ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import megengine as mge | |||||
| import megengine.functional as F | |||||
| from megengine import tensor | |||||
| from megengine.autodiff import GradManager | |||||
| def test_basic(): | |||||
| a = tensor(np.arange(0, 24).reshape((1, 2, 3, 4)), dtype="float32", format="nhwc") | |||||
| assert a.format == "nhwc" | |||||
| b = tensor(a) | |||||
| assert b.format == "nhwc" | |||||
| # TODO: fix Tensor init bug for another Tensor | |||||
| # c = tensor(a, format="nchw") | |||||
| # assert c.format == "nchw" | |||||
| def _compare_nchw_nhwc(data, func): | |||||
| x1 = tensor(data, format="nchw") | |||||
| x2 = tensor(data.transpose(0, 2, 3, 1), format="nhwc") | |||||
| out1 = func(x1) | |||||
| with mge.config._override(auto_format_convert=True): | |||||
| out2 = func(x2) | |||||
| np.testing.assert_equal(out1, out2) | |||||
| def test_dimshuffle(): | |||||
| def func(x): | |||||
| out = F.transpose(x, [2, 3, 0, 1]) | |||||
| assert out.format == "default" | |||||
| return out.numpy() | |||||
| data = np.arange(0, 24).reshape((1, 2, 3, 4)) | |||||
| _compare_nchw_nhwc(data, func) | |||||
| def test_reshape(): | |||||
| # maintain NHWC format | |||||
| def func(x): | |||||
| out = F.reshape(x, (1, 2, 6, 2)) | |||||
| if x.format == "nhwc": | |||||
| assert out.format == "nhwc" | |||||
| return out.numpy() | |||||
| data = np.arange(0, 24).reshape((1, 2, 3, 4)) | |||||
| _compare_nchw_nhwc(data, func) | |||||
| # not maintain NHWC format | |||||
| def func2(x): | |||||
| out = F.reshape(x, (1, 24)) | |||||
| assert out.format == "default" | |||||
| return out.numpy() | |||||
| _compare_nchw_nhwc(data, func2) | |||||
| def test_flatten(): | |||||
| def func(x): | |||||
| return F.flatten(x).numpy() | |||||
| data = np.arange(0, 24).reshape((1, 2, 3, 4)) | |||||
| _compare_nchw_nhwc(data, func) | |||||
| def test_broadcast(): | |||||
| # maintain NHWC format | |||||
| def func(x): | |||||
| out = F.broadcast_to(x, (4, 3, 2, 3)) | |||||
| if x.format == "nhwc": | |||||
| assert out.format == "nhwc" | |||||
| return out.numpy() | |||||
| data = np.arange(0, 24).reshape((4, 3, 2, 1)) | |||||
| _compare_nchw_nhwc(data, func) | |||||
| # not maintain NHWC format | |||||
| def func2(x): | |||||
| out = F.broadcast_to(x, (3, 4, 3, 2, 1)) | |||||
| assert out.format == "default" | |||||
| return out.numpy() | |||||
| _compare_nchw_nhwc(data, func2) | |||||
| @pytest.mark.skip("repeat cannot maintain format yet") | |||||
| def test_repeat(): | |||||
| def func(x): | |||||
| rst = F.repeat(x, 3, axis=1) | |||||
| assert rst.format == x.format | |||||
| return rst.numpy() | |||||
| data = np.arange(0, 24).reshape((1, 2, 3, 4)) | |||||
| _compare_nchw_nhwc(data, func) | |||||
| def test_getshape(): | |||||
| def func(x): | |||||
| return x.shape | |||||
| data = np.arange(0, 24).reshape((1, 2, 3, 4)) | |||||
| _compare_nchw_nhwc(data, func) | |||||
| @pytest.mark.skip("symbolic shape is not supported yet") | |||||
| def test_get_symbolic_shape(): | |||||
| from megengine.core._trace_option import set_symbolic_shape | |||||
| origin_opt = set_symbolic_shape(True) | |||||
| def func(x): | |||||
| return x.shape.numpy() | |||||
| data = np.arange(0, 24).reshape((1, 2, 3, 4)) | |||||
| _compare_nchw_nhwc(data, func) | |||||
| set_symbolic_shape(origin_opt) | |||||
| def test_getvalue(): | |||||
| def func(x): | |||||
| return x.numpy() | |||||
| data = np.arange(0, 24).reshape((1, 2, 3, 4)) | |||||
| _compare_nchw_nhwc(data, func) | |||||
| def test_get_set_subtensor(): | |||||
| def get_subtensor(x): | |||||
| return x[:, :1, :2, :3].numpy() | |||||
| data = np.arange(0, 24).reshape((1, 2, 3, 4)) | |||||
| _compare_nchw_nhwc(data, get_subtensor) | |||||
| def set_subtensor(x): | |||||
| x[:, :1, :2, :3] = 0 | |||||
| return x.numpy() | |||||
| _compare_nchw_nhwc(data, set_subtensor) | |||||
| def test_get_set_advanced_indexing(): | |||||
| def get_advanced_indexing(x): | |||||
| x = x[:, : mge.tensor(2), : mge.tensor(2), [1, 2]].numpy() | |||||
| return x | |||||
| data = np.arange(0, 24).reshape((1, 2, 3, 4)) | |||||
| _compare_nchw_nhwc(data, get_advanced_indexing) | |||||
| def set_advanced_indexing(x): | |||||
| x[:, : mge.tensor(2), : mge.tensor([2]), [1,]] = 0 | |||||
| return x.numpy() | |||||
| _compare_nchw_nhwc(data, set_advanced_indexing) | |||||
| def test_typecvt(): | |||||
| def typecvt(x): | |||||
| return x.astype("float16").numpy() | |||||
| data = np.arange(0, 24).reshape((1, 2, 3, 4)) | |||||
| _compare_nchw_nhwc(data, typecvt) | |||||
| def test_elemwise(): | |||||
| def elemwise(x): | |||||
| return (x * 2 + x / 2).numpy() | |||||
| data = np.arange(0, 24).reshape((1, 2, 3, 4)) | |||||
| _compare_nchw_nhwc(data, elemwise) | |||||
| def test_concat(): | |||||
| def func(x): | |||||
| rst = F.concat([x / 2, x * 2], axis=1) | |||||
| assert rst.format == x.format | |||||
| return rst.numpy() | |||||
| data = np.arange(0, 24).reshape((1, 2, 3, 4)) | |||||
| _compare_nchw_nhwc(data, func) | |||||
| @pytest.mark.parametrize( | |||||
| "mode", ["bilinear", "nearest"], | |||||
| ) | |||||
| def test_interpolate(mode): | |||||
| def func(x): | |||||
| if x.format == "nhwc": | |||||
| with mge.config._override(conv_format="NHWC"): | |||||
| rst = F.vision.interpolate(x, scale_factor=3, mode=mode) | |||||
| assert rst.format == "nhwc" | |||||
| return rst.numpy() | |||||
| else: | |||||
| return F.vision.interpolate(x, scale_factor=3, mode=mode).numpy() | |||||
| # NHWC interpolate only suppoted channel is 1 or 3 | |||||
| data = np.arange(0, 48).reshape((1, 3, 4, 4)).astype("float32") | |||||
| _compare_nchw_nhwc(data, func) | |||||
| def test_conv2d(): | |||||
| def conv2d(x): | |||||
| if x.format == "nhwc": | |||||
| with mge.config._override(conv_format="NHWC"): | |||||
| x = F.conv2d( | |||||
| x, | |||||
| weight=mge.tensor(np.ones((3, 1, 1, 2)), format="nhwc"), | |||||
| bias=mge.tensor(np.ones((1, 1, 1, 3)), format="nhwc"), | |||||
| ) | |||||
| assert x.format == "nhwc" | |||||
| return x.numpy() | |||||
| else: | |||||
| return F.conv2d(x, F.ones((3, 2, 1, 1)), F.ones((1, 3, 1, 1))).numpy() | |||||
| data = np.arange(0, 24).reshape((1, 2, 3, 4)) | |||||
| _compare_nchw_nhwc(data, conv2d) | |||||
| def test_group_conv2d(): | |||||
| def conv2d(x): | |||||
| if x.format == "nhwc": | |||||
| with mge.config._override(conv_format="NHWC"): | |||||
| x = F.conv2d( | |||||
| x, | |||||
| weight=mge.tensor(np.ones((2, 2, 1, 1, 2)), format="nhwc"), | |||||
| bias=mge.tensor(np.ones((1, 1, 1, 4)), format="nhwc"), | |||||
| groups=2, | |||||
| ) | |||||
| assert x.format == "nhwc" | |||||
| return x.numpy() | |||||
| else: | |||||
| return F.conv2d( | |||||
| x, F.ones((2, 2, 2, 1, 1)), F.ones((1, 4, 1, 1)), groups=2 | |||||
| ).numpy() | |||||
| data = np.arange(0, 48).reshape((1, 4, 3, 4)) | |||||
| _compare_nchw_nhwc(data, conv2d) | |||||
| def test_bn(): | |||||
| def func(x): | |||||
| if x.format == "nhwc": | |||||
| with mge.config._override(bn_format="dim_111c"): | |||||
| oups = F.batch_norm( | |||||
| x.astype("float32"), | |||||
| running_mean=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"), | |||||
| running_var=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"), | |||||
| weight=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"), | |||||
| bias=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"), | |||||
| training=True, | |||||
| inplace=False, | |||||
| ) | |||||
| assert oups[0].format == "nhwc", "y's format is wrong" | |||||
| assert oups[1].format == "nhwc", "running_mean's format is wrong" | |||||
| assert oups[2].format == "nhwc", "running_var's format is wrong" | |||||
| return oups[0].numpy() | |||||
| else: | |||||
| return F.batch_norm( | |||||
| x.astype("float32"), | |||||
| running_mean=mge.tensor(np.ones((1, 2, 1, 1))), | |||||
| running_var=mge.tensor(np.ones((1, 2, 1, 1))), | |||||
| weight=mge.tensor(np.ones((1, 2, 1, 1))), | |||||
| bias=mge.tensor(np.ones((1, 2, 1, 1))), | |||||
| training=True, | |||||
| inplace=False, | |||||
| )[0].numpy() | |||||
| data = np.arange(0, 24).reshape((1, 2, 3, 4)) | |||||
| _compare_nchw_nhwc(data, func) | |||||
| @pytest.mark.parametrize( | |||||
| "pooling", | |||||
| [F.max_pool2d, F.avg_pool2d, F.adaptive_avg_pool2d, F.adaptive_max_pool2d], | |||||
| ) | |||||
| def test_pooling2d(pooling): | |||||
| def func(x): | |||||
| if x.format == "nhwc": | |||||
| with mge.config._override(conv_format="NHWC"): | |||||
| x = pooling(x.astype("float32"), 2) | |||||
| assert x.format == "nhwc" | |||||
| return x.numpy() | |||||
| else: | |||||
| return pooling(x.astype("float32"), 2).numpy() | |||||
| data = np.arange(0, 24).reshape((1, 2, 3, 4)) | |||||
| _compare_nchw_nhwc(data, func) | |||||
| def test_backward(): | |||||
| data = np.arange(0, 24).reshape((1, 2, 3, 4)) | |||||
| x = tensor(data.transpose(0, 2, 3, 1), format="nhwc") | |||||
| w = mge.tensor(np.ones((3, 1, 1, 2)), format="nhwc") | |||||
| b = mge.tensor(np.ones((1, 1, 1, 3)), format="nhwc") | |||||
| gm = GradManager().attach([w, b]) | |||||
| with gm: | |||||
| with mge.config._override(auto_format_convert=True, conv_format="NHWC"): | |||||
| x = F.conv2d(x, w, b) | |||||
| gm.backward(x) | |||||
| # TODO: backward grad has no format yet | |||||
| np.testing.assert_equal( | |||||
| w.grad.numpy(), | |||||
| np.array([66, 210, 66, 210, 66, 210]).reshape((3, 1, 1, 2)), | |||||
| ) | |||||
| np.testing.assert_equal( | |||||
| b.grad.numpy(), np.array([12, 12, 12]).reshape((1, 1, 1, 3)) | |||||
| ) | |||||
| @@ -33,14 +33,20 @@ std::string GetAttr::to_string() const { | |||||
| return ssprintf("GetAttr{attr=%s}", attr_name); | return ssprintf("GetAttr{attr=%s}", attr_name); | ||||
| } | } | ||||
| CreateTensor::CreateTensor(Kind kind, CompNode device, DType dtype, ValueShape shape) | |||||
| : m_kind(kind), m_device(device), m_dtype(dtype), m_shape(shape) {} | |||||
| CreateTensor::CreateTensor( | |||||
| Kind kind, CompNode device, DType dtype, ValueShape shape, Format format) | |||||
| : m_kind(kind), | |||||
| m_device(device), | |||||
| m_dtype(dtype), | |||||
| m_shape(shape), | |||||
| m_format(format) {} | |||||
| CreateTensor::CreateTensor(Kind kind, CompNode device, TensorLayout layout) | CreateTensor::CreateTensor(Kind kind, CompNode device, TensorLayout layout) | ||||
| : m_kind(kind), | : m_kind(kind), | ||||
| m_device(device), | m_device(device), | ||||
| m_dtype(layout.dtype), | m_dtype(layout.dtype), | ||||
| m_shape(ValueShape::from(layout)) { | |||||
| m_shape(ValueShape::from(layout)), | |||||
| m_format(Format::Type::DEFAULT) { | |||||
| mgb_assert( | mgb_assert( | ||||
| layout.is_contiguous() || layout.is_empty(), "layout should be contiguous"); | layout.is_contiguous() || layout.is_empty(), "layout should be contiguous"); | ||||
| } | } | ||||
| @@ -74,8 +80,9 @@ auto CreateTensor::parse(Span<ValueRef> inputs) const -> Args { | |||||
| std::string CreateTensor::to_string() const { | std::string CreateTensor::to_string() const { | ||||
| return ssprintf( | return ssprintf( | ||||
| "CreateTensor{kind=%d, device=%s, dtype=%s, shape=%s}", (int)m_kind, | |||||
| m_device.to_string().c_str(), m_dtype.name(), m_shape.to_string().c_str()); | |||||
| "CreateTensor{kind=%d, device=%s, dtype=%s, shape=%s, format=%s}", | |||||
| (int)m_kind, m_device.to_string().c_str(), m_dtype.name(), | |||||
| m_shape.to_string().c_str(), m_format.to_string().c_str()); | |||||
| } | } | ||||
| std::string DTRCommand::to_string() const { | std::string DTRCommand::to_string() const { | ||||
| @@ -0,0 +1,406 @@ | |||||
| #include "megbrain/imperative/transformations/format.h" | |||||
| #include "megbrain/imperative/ops/autogen.h" | |||||
| namespace mgb { | |||||
| namespace imperative { | |||||
| using FT = Format::Type; | |||||
| TypedValueRef<FormattedTensorValue> FormattedTensorValue::as(const FT& target) const { | |||||
| return FormattedTensorValue::make(m_value, target); | |||||
| } | |||||
| TypedValueRef<FormattedTensorValue> FormattedTensorValue::to( | |||||
| const FT& target, const std::string& scope) const { | |||||
| std::vector<int32_t> pattern; | |||||
| if (m_format == FT::NHWC && target == FT::NCHW) { | |||||
| pattern = {0, 3, 1, 2}; | |||||
| } else if (m_format == FT::NCHW && target == FT::NHWC) { | |||||
| pattern = {0, 2, 3, 1}; | |||||
| } else { | |||||
| mgb_throw( | |||||
| MegBrainError, "Unsupport format conversion from %s to %s", | |||||
| m_format.to_string().c_str(), Format(target).to_string().c_str()); | |||||
| } | |||||
| auto output = imperative::apply( | |||||
| *Dimshuffle::make(pattern, scope), std::vector<ValueRef>{m_value})[0]; | |||||
| return FormattedTensorValue::make(output, target); | |||||
| } | |||||
| namespace { | |||||
| ValueRef unwrap_input(const ValueRef& input) { | |||||
| if (auto format_input = input.as_ref<FormattedTensorValue>()) { | |||||
| return format_input->value(); | |||||
| } else { | |||||
| return input; | |||||
| } | |||||
| } | |||||
| std::vector<ValueRef> unwrap_inputs(const Span<ValueRef>& inputs) { | |||||
| std::vector<ValueRef> unwrapped_inputs; | |||||
| for (auto&& input : inputs) { | |||||
| unwrapped_inputs.push_back(unwrap_input(input)); | |||||
| } | |||||
| return unwrapped_inputs; | |||||
| } | |||||
| std::vector<ValueRef> wrap_outputs( | |||||
| const std::vector<ValueRef>& outputs, FT type = FT::DEFAULT) { | |||||
| std::vector<ValueRef> wrapped_outputs; | |||||
| for (auto&& output : outputs) { | |||||
| wrapped_outputs.push_back(FormattedTensorValue::make(output, type)); | |||||
| } | |||||
| return wrapped_outputs; | |||||
| } | |||||
| ValueShape convert_nhwc2nchw_shape(const ValueShape& shape) { | |||||
| mgb_assert(shape.ndim == 4); | |||||
| auto out = ValueShape(shape); | |||||
| out[3] = shape[2]; | |||||
| out[2] = shape[1]; | |||||
| out[1] = shape[3]; | |||||
| return out; | |||||
| } | |||||
| using FormatRule = std::function<std::vector<ValueRef>( | |||||
| const OpDef&, Span<ValueRef>&, const bool&)>; | |||||
| static std::unordered_map<Typeinfo*, FormatRule> format_rules; | |||||
| template <typename T> | |||||
| void register_format_rule( | |||||
| std::vector<ValueRef> (*rule)(const T&, Span<ValueRef>&, const bool&)) { | |||||
| format_rules[T::typeinfo()] = [rule](const OpDef& def, Span<ValueRef>& inputs, | |||||
| const bool& auto_convert) { | |||||
| return (*rule)(def.cast_final_safe<T>(), inputs, auto_convert); | |||||
| }; | |||||
| } | |||||
| auto convert_nchw2nhwc_pattern(const std::vector<int32_t>& pattern) { | |||||
| mgb_assert(pattern.size() == 4); | |||||
| auto nhwc_pattern = pattern; | |||||
| for (size_t idx = 0; idx < 4; ++idx) { | |||||
| auto dim = pattern[idx]; | |||||
| if (dim == 1) { | |||||
| nhwc_pattern[idx] = 3; | |||||
| } else if (dim == 2) { | |||||
| nhwc_pattern[idx] = 1; | |||||
| } else if (dim == 3) { | |||||
| nhwc_pattern[idx] = 2; | |||||
| } | |||||
| } | |||||
| return nhwc_pattern; | |||||
| } | |||||
| std::vector<ValueRef> dimshuffle_rule( | |||||
| const Dimshuffle& op, Span<ValueRef>& inputs, const bool& auto_convert) { | |||||
| mgb_assert(inputs.size() == 1); | |||||
| auto& src = inputs[0].cast<FormattedTensorValue>(); | |||||
| // Only support converting pattern from NCHW to NHWC currently. | |||||
| if (auto_convert && src.format() == FT::NHWC) { | |||||
| auto pattern = convert_nchw2nhwc_pattern(op.pattern); | |||||
| // dimshuffle will not maintain NHWC Format | |||||
| return wrap_outputs(imperative::apply( | |||||
| *Dimshuffle::make(std::move(pattern), op.scope()), | |||||
| unwrap_inputs(inputs))); | |||||
| } | |||||
| return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs))); | |||||
| } | |||||
| ValueRef convert_nchw2nhwc_tensornd(const HostTensorND& shape) { | |||||
| mgb_assert(shape.layout().total_nr_elems() == 4); | |||||
| auto* old_ptr = shape.ptr<dt_int32>(); | |||||
| auto cn = shape.comp_node(); | |||||
| auto layout = shape.layout(); | |||||
| auto nhwc_shape = HostTensorND(cn, layout); | |||||
| auto* new_ptr = nhwc_shape.ptr<dt_int32>(); | |||||
| new_ptr[0] = old_ptr[0]; | |||||
| new_ptr[1] = old_ptr[2]; | |||||
| new_ptr[2] = old_ptr[3]; | |||||
| new_ptr[3] = old_ptr[1]; | |||||
| auto hv = HostStorage::make(nhwc_shape.storage()); | |||||
| auto nhwc_shape_input = | |||||
| imperative::apply(CreateTensor(CreateTensor::Const, cn, layout), hv)[0]; | |||||
| return nhwc_shape_input; | |||||
| } | |||||
| std::vector<ValueRef> reshape_rule( | |||||
| const Reshape& op, Span<ValueRef>& inputs, const bool& auto_convert) { | |||||
| mgb_assert(inputs.size() == 2); | |||||
| auto& src = inputs[0].cast<FormattedTensorValue>(); | |||||
| if (auto_convert && src.format() == FT::NHWC) { | |||||
| auto shape = unwrap_input(inputs[1]).numpy().cast<HostValue>().as_nd(); | |||||
| if (shape.layout().total_nr_elems() == 4) { | |||||
| // output is still NHWC format | |||||
| auto nhwc_shape = convert_nchw2nhwc_tensornd(shape); | |||||
| auto outputs = imperative::apply( | |||||
| op, std::vector<ValueRef>{unwrap_input(inputs[0]), nhwc_shape}); | |||||
| return wrap_outputs(outputs, FT::NHWC); | |||||
| } else { | |||||
| // will not maintain src's format | |||||
| auto nchw_src = src.to(FT::NCHW, op.scope())->value(); | |||||
| auto outputs = imperative::apply( | |||||
| op, std::vector<ValueRef>{nchw_src, unwrap_input(inputs[1])}); | |||||
| return wrap_outputs(outputs); | |||||
| } | |||||
| } | |||||
| return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs))); | |||||
| } | |||||
| std::vector<ValueRef> broadcast_rule( | |||||
| const Broadcast& op, Span<ValueRef>& inputs, const bool& auto_convert) { | |||||
| mgb_assert(inputs.size() == 2); | |||||
| auto& src = inputs[0].cast<FormattedTensorValue>(); | |||||
| if (auto_convert && src.format() == FT::NHWC) { | |||||
| auto shape = unwrap_input(inputs[1]).numpy().cast<HostValue>().as_nd(); | |||||
| if (shape.layout().total_nr_elems() == 4) { | |||||
| // output is still NHWC format | |||||
| auto nhwc_shape = convert_nchw2nhwc_tensornd(shape); | |||||
| auto outputs = imperative::apply( | |||||
| op, std::vector<ValueRef>{unwrap_input(inputs[0]), nhwc_shape}); | |||||
| return wrap_outputs(outputs, FT::NHWC); | |||||
| } else { | |||||
| // will not maintain src's format | |||||
| auto nchw_src = src.to(FT::NCHW, op.scope())->value(); | |||||
| auto outputs = imperative::apply( | |||||
| op, std::vector<ValueRef>{nchw_src, unwrap_input(inputs[1])}); | |||||
| return wrap_outputs(outputs); | |||||
| } | |||||
| } | |||||
| return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs))); | |||||
| } | |||||
| bool is_reduce_ndim_idx_items( | |||||
| const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& items, | |||||
| const Span<ValueRef>& inputs) { | |||||
| for (auto i = 0; i < items.size(); ++i) { | |||||
| auto&& [axis, begin, end, step, idx] = items[i]; | |||||
| if (idx) { | |||||
| // if inputs[i] contains more than one value, ndim will not be reduced. | |||||
| return inputs[i].is_scalar(); | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| auto convert_nchw2nhwc_idx_items( | |||||
| const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& items) { | |||||
| auto nhwc_items = items; | |||||
| for (auto i = 0; i < nhwc_items.size(); ++i) { | |||||
| auto&& [axis, begin, end, step, idx] = nhwc_items[i]; | |||||
| if (axis == 2 || axis == 3) { | |||||
| nhwc_items[i] = {axis - 1, begin, end, step, idx}; | |||||
| } else if (axis == 1) { | |||||
| nhwc_items[i] = {3, begin, end, step, idx}; | |||||
| } | |||||
| } | |||||
| return nhwc_items; | |||||
| } | |||||
| template <typename T> | |||||
| std::vector<ValueRef> subtensor_rule( | |||||
| const T& op, Span<ValueRef>& inputs, const bool& auto_convert) { | |||||
| mgb_assert(inputs.size() >= 1); | |||||
| auto& src = inputs[0].cast<FormattedTensorValue>(); | |||||
| bool is_reduce_ndim = is_reduce_ndim_idx_items( | |||||
| op.items, {&inputs[1], &inputs[inputs.size() - 1]}); | |||||
| if (!is_reduce_ndim) { | |||||
| // only support NHWC2NCHW convert, otherwise maintain src's format | |||||
| if (!(auto_convert && src.format() == FT::NHWC)) { | |||||
| return {FormattedTensorValue::make( | |||||
| imperative::apply(op, unwrap_inputs(inputs))[0], src.format())}; | |||||
| } | |||||
| auto nhwc_items = convert_nchw2nhwc_idx_items(op.items); | |||||
| auto outputs = imperative::apply( | |||||
| *T::make(std::move(nhwc_items), op.scope()), unwrap_inputs(inputs)); | |||||
| return wrap_outputs(outputs, FT::NHWC); | |||||
| } | |||||
| return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs))); | |||||
| } | |||||
| template <typename T> | |||||
| std::vector<ValueRef> setsubtensor_rule( | |||||
| const T& op, Span<ValueRef>& inputs, const bool& auto_convert) { | |||||
| mgb_assert(inputs.size() >= 2); | |||||
| auto& src = inputs[0].cast<FormattedTensorValue>(); | |||||
| bool is_reduce_ndim = is_reduce_ndim_idx_items( | |||||
| op.items, {&inputs[2], &inputs[inputs.size() - 1]}); | |||||
| if (!is_reduce_ndim) { | |||||
| // only support NHWC2NCHW convert, otherwise maintain src's format | |||||
| if (!(auto_convert && src.format() == FT::NHWC)) { | |||||
| return {FormattedTensorValue::make( | |||||
| imperative::apply(op, unwrap_inputs(inputs))[0], src.format())}; | |||||
| } | |||||
| // value has been broadcasted to src's fake NCHW shape. | |||||
| auto& value = inputs[1].cast<FormattedTensorValue>(); | |||||
| auto& format = value.format(); | |||||
| auto nhwc_inputs = std::vector<ValueRef>(inputs.size()); | |||||
| if (format == FT::DEFAULT || format == FT::NCHW) { | |||||
| // value for setsubtensor should transpose to match shape. | |||||
| auto nhwc_value = value.as(FT::NCHW)->to(FT::NHWC); | |||||
| // make new inputs for setsubtensor | |||||
| nhwc_inputs[0] = src.value(); | |||||
| nhwc_inputs[1] = nhwc_value->value(); | |||||
| for (auto i = 2; i < inputs.size(); ++i) { | |||||
| nhwc_inputs[i] = inputs[i].as_ref<FormattedTensorValue>()->value(); | |||||
| } | |||||
| } else if (format != FT::NHWC) { | |||||
| mgb_throw( | |||||
| MegBrainError, "Unsupported format(%s) of value for setsubtensor.", | |||||
| format.to_string().c_str()); | |||||
| } | |||||
| auto nhwc_items = convert_nchw2nhwc_idx_items(op.items); | |||||
| auto outputs = imperative::apply( | |||||
| *T::make(std::move(nhwc_items), op.scope()), nhwc_inputs); | |||||
| return wrap_outputs(outputs, FT::NHWC); | |||||
| } | |||||
| return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs))); | |||||
| } | |||||
| FT get_inputs_format(Span<ValueRef>& inputs) { | |||||
| FT format(FT::DEFAULT); | |||||
| for (auto& inp : inputs) { | |||||
| auto& inp_format = inp.cast<FormattedTensorValue>().format(); | |||||
| if (inp_format != FT::DEFAULT) { | |||||
| mgb_assert(format == FT::DEFAULT || inp_format == format); | |||||
| format = inp_format.type(); | |||||
| } | |||||
| } | |||||
| return format; | |||||
| } | |||||
| std::vector<ValueRef> concat_rule( | |||||
| const Concat& op, Span<ValueRef>& inputs, const bool& auto_convert) { | |||||
| FT format = get_inputs_format(inputs); | |||||
| if (!(format == FT::NHWC && auto_convert)) { | |||||
| return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)), format); | |||||
| } | |||||
| // TODO: handle 5D NHWC Tensor from group conv | |||||
| auto axis = op.axis; | |||||
| if (axis == 2 || axis == 3) { | |||||
| axis = axis - 1; | |||||
| } else if (axis == 1) { | |||||
| axis = 3; | |||||
| } | |||||
| return wrap_outputs( | |||||
| imperative::apply( | |||||
| *Concat::make(axis, op.comp_node, op.scope()), | |||||
| unwrap_inputs(inputs)), | |||||
| format); | |||||
| } | |||||
| std::vector<ValueRef> elemwise_rule( | |||||
| const Elemwise& op, Span<ValueRef>& inputs, const bool& auto_convert) { | |||||
| FT format = get_inputs_format(inputs); | |||||
| return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)), format); | |||||
| } | |||||
| std::vector<ValueRef> identity_rule_helper( | |||||
| const OpDef& op, const Span<ValueRef>& inputs) { | |||||
| // mgb_assert(inputs.size() == 1); | |||||
| auto& src = inputs[0].cast<FormattedTensorValue>(); | |||||
| return wrap_outputs( | |||||
| imperative::apply(op, unwrap_inputs(inputs)), src.format().type()); | |||||
| } | |||||
| // clang-format off | |||||
| #define FOREACH_IDENTITY_OP(cb) \ | |||||
| cb(Copy) \ | |||||
| cb(FastpathCopy) \ | |||||
| cb(TypeCvt) \ | |||||
| cb(Pooling) \ | |||||
| cb(AdaptivePooling) \ | |||||
| cb(Dropout) \ | |||||
| cb(Convolution) \ | |||||
| cb(BatchNorm) \ | |||||
| cb(Resize) \ | |||||
| cb(Identity) | |||||
| // clang-format on | |||||
| #define CREATE_IDENTITY_OP_RULE(op) \ | |||||
| std::vector<ValueRef> op##_rule( \ | |||||
| const op& _op, Span<ValueRef>& inputs, const bool& auto_convert) { \ | |||||
| return identity_rule_helper(_op, inputs); \ | |||||
| } | |||||
| FOREACH_IDENTITY_OP(CREATE_IDENTITY_OP_RULE) | |||||
| #undef CREATE_IDENTITY_OP_RULE | |||||
| #define REGISTER_IDENTITY_OP_RULE(op) register_format_rule(op##_rule); | |||||
| struct FormatRuleRegistry { | |||||
| FormatRuleRegistry() { | |||||
| register_format_rule(dimshuffle_rule); | |||||
| register_format_rule(reshape_rule); | |||||
| register_format_rule(broadcast_rule); | |||||
| register_format_rule(subtensor_rule<Subtensor>); | |||||
| register_format_rule(subtensor_rule<IndexingMultiAxisVec>); | |||||
| register_format_rule(setsubtensor_rule<SetSubtensor>); | |||||
| register_format_rule(setsubtensor_rule<IndexingSetMultiAxisVec>); | |||||
| register_format_rule(concat_rule); | |||||
| register_format_rule(elemwise_rule); | |||||
| FOREACH_IDENTITY_OP(REGISTER_IDENTITY_OP_RULE) | |||||
| } | |||||
| } _; | |||||
| #undef REGISTER_IDENTITY_OP_RULE | |||||
| } // namespace | |||||
| std::vector<ValueRef> FormatTransformation::apply_transformation( | |||||
| const Operator& op, Span<ValueRef> inputs) { | |||||
| if (auto* apply_op = op.as<ApplyOp>()) { | |||||
| // all inputs should be FormattedTensorValue | |||||
| auto iter = format_rules.find(apply_op->op().dyn_typeinfo()); | |||||
| if (iter != format_rules.end()) { | |||||
| return iter->second(apply_op->op(), inputs, m_auto_convert); | |||||
| } else { | |||||
| return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs))); | |||||
| } | |||||
| } else if (auto* create_tensor = op.as<CreateTensor>()) { | |||||
| auto format = create_tensor->format(); | |||||
| return {FormattedTensorValue::make(imperative::apply(op, inputs)[0], format)}; | |||||
| } else if (auto* get_attr = op.as<GetAttr>()) { | |||||
| auto* src = inputs.as_array<1>()[0].as<FormattedTensorValue>(); | |||||
| if (!m_auto_convert || !src || src->format() != FT::NHWC) { | |||||
| return imperative::apply(op, unwrap_inputs(inputs)); | |||||
| } | |||||
| switch (get_attr->attr()) { | |||||
| case GetAttr::Shape: { | |||||
| auto output = imperative::apply(op, unwrap_inputs(inputs))[0]; | |||||
| auto shape = convert_nhwc2nchw_shape(output.cast<ShapeValue>()); | |||||
| return {ShapeValue::make(shape)}; | |||||
| } | |||||
| case GetAttr::Value: { | |||||
| auto nchw_src = unwrap_input(src->to(FT::NCHW, "")); | |||||
| return imperative::apply(op, std::vector<ValueRef>{nchw_src}); | |||||
| } | |||||
| default: | |||||
| return imperative::apply(op, unwrap_inputs(inputs)); | |||||
| } | |||||
| } else if (op.is<GetFormat>()) { | |||||
| bool is_formatted_tensor = inputs.as_array<1>()[0].is<FormattedTensorValue>(); | |||||
| if (is_formatted_tensor) { | |||||
| return {FormatValue::make(inputs[0].cast<FormattedTensorValue>().format())}; | |||||
| } else { | |||||
| mgb_log_warn( | |||||
| "Not FormattedTensorValue input for GetFormat op: %s", | |||||
| inputs[0].to_string().c_str()); | |||||
| return {FormatValue::make(FT::DEFAULT)}; | |||||
| } | |||||
| } else if (op.is<Operator::IdentityLike>()) { | |||||
| bool is_formatted_tensor = inputs.as_array<1>()[0].is<FormattedTensorValue>(); | |||||
| if (is_formatted_tensor) { | |||||
| auto& format = inputs[0].cast<FormattedTensorValue>().format(); | |||||
| return wrap_outputs( | |||||
| imperative::apply(op, unwrap_inputs(inputs)), format.type()); | |||||
| } else { | |||||
| mgb_log_warn( | |||||
| "Not FormattedTensorValue input for IdentityLike op: %s", | |||||
| inputs[0].to_string().c_str()); | |||||
| return imperative::apply(op, inputs); | |||||
| } | |||||
| } else { | |||||
| return imperative::apply(op, unwrap_inputs(inputs)); | |||||
| } | |||||
| }; | |||||
| } // namespace imperative | |||||
| } // namespace mgb | |||||
| @@ -58,6 +58,10 @@ TypedValueRef<DTypeValue> ValueRef::dtype() const { | |||||
| return imperative::apply(GetAttr(GetAttr::DType), *this)[0].cast_ref<DTypeValue>(); | return imperative::apply(GetAttr(GetAttr::DType), *this)[0].cast_ref<DTypeValue>(); | ||||
| } | } | ||||
| TypedValueRef<FormatValue> ValueRef::format() const { | |||||
| return imperative::apply(GetFormat(), *this)[0].as_ref<FormatValue>(); | |||||
| } | |||||
| TypedValueRef<StringValue> ValueRef::name() const { | TypedValueRef<StringValue> ValueRef::name() const { | ||||
| return imperative::apply(GetName(), *this)[0].cast_ref<StringValue>(); | return imperative::apply(GetName(), *this)[0].cast_ref<StringValue>(); | ||||
| } | } | ||||
| @@ -5,6 +5,7 @@ | |||||
| #include "megbrain/imperative/op_def.h" | #include "megbrain/imperative/op_def.h" | ||||
| #include "megbrain/imperative/operator.h" | #include "megbrain/imperative/operator.h" | ||||
| #include "megbrain/imperative/utils/data_format.h" | |||||
| #include "megbrain/imperative/utils/helper.h" | #include "megbrain/imperative/utils/helper.h" | ||||
| #include "megbrain/imperative/utils/value_shape.h" | #include "megbrain/imperative/utils/value_shape.h" | ||||
| @@ -82,9 +83,12 @@ private: | |||||
| CompNode m_device; | CompNode m_device; | ||||
| DType m_dtype; | DType m_dtype; | ||||
| ValueShape m_shape; | ValueShape m_shape; | ||||
| Format m_format; | |||||
| public: | public: | ||||
| CreateTensor(Kind kind, CompNode device, DType dtype, ValueShape shape); | |||||
| CreateTensor( | |||||
| Kind kind, CompNode device, DType dtype, ValueShape shape, | |||||
| Format format = Format::Type::DEFAULT); | |||||
| CreateTensor(Kind kind, CompNode device, TensorLayout layout); | CreateTensor(Kind kind, CompNode device, TensorLayout layout); | ||||
| /** | /** | ||||
| @@ -99,6 +103,7 @@ public: | |||||
| CompNode device() const { return m_device; } | CompNode device() const { return m_device; } | ||||
| DType dtype() const { return m_dtype; } | DType dtype() const { return m_dtype; } | ||||
| ValueShape shape() const { return m_shape; } | ValueShape shape() const { return m_shape; } | ||||
| Format format() const { return m_format; } | |||||
| std::string to_string() const override; | std::string to_string() const override; | ||||
| }; | }; | ||||
| @@ -157,6 +162,11 @@ public: | |||||
| std::string to_string() const override; | std::string to_string() const override; | ||||
| }; | }; | ||||
| class GetFormat final : public OperatorImpl<GetFormat, Operator::GetAttrLike> { | |||||
| public: | |||||
| std::string to_string() const override { return "GetFormat{}"; } | |||||
| }; | |||||
| class GetVarVal final : public OperatorImpl<GetVarVal, Operator::GetAttrLike> { | class GetVarVal final : public OperatorImpl<GetVarVal, Operator::GetAttrLike> { | ||||
| public: | public: | ||||
| std::string to_string() const override; | std::string to_string() const override; | ||||
| @@ -3,6 +3,7 @@ | |||||
| #include <future> | #include <future> | ||||
| #include <iomanip> | #include <iomanip> | ||||
| #include "megbrain/imperative/utils/data_format.h" | |||||
| #include "megbrain/imperative/utils/helper.h" | #include "megbrain/imperative/utils/helper.h" | ||||
| #include "megbrain/imperative/utils/value_shape.h" | #include "megbrain/imperative/utils/value_shape.h" | ||||
| #include "megbrain/imperative/value.h" | #include "megbrain/imperative/value.h" | ||||
| @@ -148,6 +149,13 @@ public: | |||||
| std::string to_string() const override; | std::string to_string() const override; | ||||
| }; | }; | ||||
| class FormatValue final : public PrimitiveValue<FormatValue, Format> { | |||||
| public: | |||||
| using PrimitiveValue::PrimitiveValue; | |||||
| std::string to_string() const override { return Format::to_string(); } | |||||
| }; | |||||
| class StringValue final : public PrimitiveValue<StringValue, std::string> { | class StringValue final : public PrimitiveValue<StringValue, std::string> { | ||||
| public: | public: | ||||
| using PrimitiveValue::PrimitiveValue; | using PrimitiveValue::PrimitiveValue; | ||||
| @@ -0,0 +1,70 @@ | |||||
| #pragma once | |||||
| #include "megbrain/imperative/basic_values.h" | |||||
| #include "megbrain/imperative/dispatch.h" | |||||
| #include "megbrain/imperative/ops/autogen.h" | |||||
| #include "megbrain/imperative/utils/data_format.h" | |||||
| namespace mgb::imperative { | |||||
| class FormattedTensorValue final : public ValueImpl<FormattedTensorValue> { | |||||
| private: | |||||
| ValueRef m_value; | |||||
| Format m_format; | |||||
| public: | |||||
| FormattedTensorValue(ValueRef value, Format format) | |||||
| : m_value(value), m_format(format) {} | |||||
| std::string to_string() const override { | |||||
| return ssprintf( | |||||
| "FormattedTensorValue{value=%s, format=%s}", | |||||
| m_value.to_string().c_str(), m_format.to_string().c_str()); | |||||
| } | |||||
| ValueRef value() const { return m_value; } | |||||
| const Format& format() const { return m_format; } | |||||
| TypedValueRef<FormattedTensorValue> as(const Format::Type& target) const; | |||||
| TypedValueRef<FormattedTensorValue> to( | |||||
| const Format::Type& target, const std::string& scope = "") const; | |||||
| void clear() override { | |||||
| m_value = {}; | |||||
| m_format = {}; | |||||
| } | |||||
| void on_watch() override { m_value.watch(); } | |||||
| void on_unwatch() override { m_value.unwatch(); } | |||||
| }; | |||||
| /** | |||||
| * \brief simulates scalar because megbrain graph system don't support scalar | |||||
| * | |||||
| * Assume that we has 'a = ScalarValue(b)', thus 'a.shape == []', 'b.shape == [1]'. | |||||
| * This transformation simulates scalars with a flag. If a value is ScalarValue, it is | |||||
| * scalar, vice versa. So there is not scalar down this layer. | |||||
| */ | |||||
| class FormatTransformation final : public Transformation { | |||||
| private: | |||||
| bool m_auto_convert = false; | |||||
| public: | |||||
| std::vector<ValueRef> apply_transformation( | |||||
| const Operator& op, Span<ValueRef> inputs) override; | |||||
| ValueRef unwrap(ValueRef value) override { | |||||
| mgb_assert(!value.is<FormattedTensorValue>()); | |||||
| return value; | |||||
| } | |||||
| std::string name() const override { | |||||
| return ssprintf("FormatTransformation{auto_convert=%d}", m_auto_convert); | |||||
| } | |||||
| void set_auto_convert(bool enabled) { m_auto_convert = enabled; } | |||||
| bool get_auto_convert() const { return m_auto_convert; } | |||||
| }; | |||||
| } // namespace mgb::imperative | |||||
| @@ -0,0 +1,56 @@ | |||||
| #pragma once | |||||
| #include "megbrain/tensor.h" | |||||
| namespace mgb::imperative { | |||||
| /** | |||||
| * \brief like TensorFormats, but only including common formats and DEFAULT. | |||||
| * | |||||
| */ | |||||
| class Format { | |||||
| public: | |||||
| enum class Type { | |||||
| DEFAULT = 0, | |||||
| NCHW = 1, ///< [N, C, H, W] | |||||
| NHWC = 2, ///< [N, H, W, C] | |||||
| }; | |||||
| std::string to_string() const { | |||||
| switch (m_type) { | |||||
| case Type::DEFAULT: | |||||
| return "default"; | |||||
| case Type::NCHW: | |||||
| return "nchw"; | |||||
| case Type::NHWC: | |||||
| return "nhwc"; | |||||
| default: | |||||
| mgb_throw(MegBrainError, "bad format type"); | |||||
| } | |||||
| } | |||||
| Format() : m_type(Type::DEFAULT) {} | |||||
| Format(std::string str) { | |||||
| if (str == "default") { | |||||
| m_type = Type::DEFAULT; | |||||
| } else if (str == "nchw") { | |||||
| m_type = Type::NCHW; | |||||
| } else if (str == "nhwc") { | |||||
| m_type = Type::NHWC; | |||||
| } else { | |||||
| mgb_throw( | |||||
| MegBrainError, | |||||
| "Invalid format type." | |||||
| " Only support \"default\", \"nchw\" and \"nhwc\""); | |||||
| } | |||||
| } | |||||
| Format(Type type) : m_type(type) {} | |||||
| Type type() const { return m_type; } | |||||
| bool operator==(const Format& b) const { return m_type == b.type(); } | |||||
| bool operator==(const Format::Type& b) const { return m_type == b; } | |||||
| bool operator!=(const Format& b) const { return m_type != b.type(); } | |||||
| bool operator!=(const Format::Type& b) const { return m_type != b; } | |||||
| private: | |||||
| Type m_type = Type::DEFAULT; | |||||
| }; | |||||
| } // namespace mgb::imperative | |||||
| @@ -31,6 +31,7 @@ class HostValue; | |||||
| class DeviceValue; | class DeviceValue; | ||||
| class ShapeValue; | class ShapeValue; | ||||
| class DTypeValue; | class DTypeValue; | ||||
| class FormatValue; | |||||
| class CompNodeValue; | class CompNodeValue; | ||||
| class StringValue; | class StringValue; | ||||
| class NodeValue; | class NodeValue; | ||||
| @@ -219,6 +220,7 @@ public: | |||||
| TypedValueRef<CompNodeValue> device() const; | TypedValueRef<CompNodeValue> device() const; | ||||
| TypedValueRef<ShapeValue> shape() const; | TypedValueRef<ShapeValue> shape() const; | ||||
| TypedValueRef<DTypeValue> dtype() const; | TypedValueRef<DTypeValue> dtype() const; | ||||
| TypedValueRef<FormatValue> format() const; | |||||
| TypedValueRef<StringValue> name() const; | TypedValueRef<StringValue> name() const; | ||||
| bool is_scalar() const; | bool is_scalar() const; | ||||
| @@ -431,9 +433,11 @@ inline const TypedValueRef<TValue>& ValueRef::cast_ref(const Type<TValue>& type) | |||||
| inline void ValueRef::on_cast_failure(const IType& type) const { | inline void ValueRef::on_cast_failure(const IType& type) const { | ||||
| // if this is ErrorValue, rethrow directly | // if this is ErrorValue, rethrow directly | ||||
| storage()->try_rethrow(); | storage()->try_rethrow(); | ||||
| mgb_assert( | |||||
| storage()->type() != type, "expect type %s, got %s", type.name().c_str(), | |||||
| to_string().c_str()); | |||||
| if (storage()->type() != type) { | |||||
| mgb_throw( | |||||
| MegBrainError, "Unable to cast ValueRef: expect type %s, got %s", | |||||
| type.name().c_str(), to_string().c_str()); | |||||
| } | |||||
| } | } | ||||
| /** | /** | ||||
| @@ -200,7 +200,7 @@ void BatchNormForward::get_output_var_shape( | |||||
| bias_c = inp_shape[2][channel_idx]; | bias_c = inp_shape[2][channel_idx]; | ||||
| mgb_assert( | mgb_assert( | ||||
| inp_c == scale_c && inp_c == bias_c, | inp_c == scale_c && inp_c == bias_c, | ||||
| "inconsistent channel size, input chennel: %zu, scale channel: %zu, bias " | |||||
| "inconsistent channel size, input channel: %zu, scale channel: %zu, bias " | |||||
| "channel: %zu", | "channel: %zu", | ||||
| inp_c, scale_c, bias_c); | inp_c, scale_c, bias_c); | ||||