GitOrigin-RevId: 01b5324392
tags/v1.9.0
| @@ -28,6 +28,16 @@ void BNForward::check_exec( | |||
| const TensorLayout& variance, const TensorLayout& batch_mean, | |||
| const TensorLayout& batch_inv_variance, const TensorLayout& dst, | |||
| size_t workspace_in_bytes, size_t reserve_in_bytes) { | |||
| // moving some python assert to dnn to decrease the assert overhead | |||
| megdnn_assert( | |||
| src.ndim == 4, | |||
| "ndim of the input tensor for batch_norm should be 4, but you give %zu", | |||
| src.ndim); | |||
| megdnn_assert(bn_scale.ndim == 4, "expect 4, get %zu\n", bn_scale.ndim); | |||
| megdnn_assert(bn_bias.ndim == 4, "expect 4, get %zu\n", bn_bias.ndim); | |||
| megdnn_assert_eq_layout(bn_scale, bn_bias); | |||
| megdnn_assert_eq_layout(batch_mean, batch_inv_variance); | |||
| megdnn_assert_contiguous(src); | |||
| megdnn_assert_eq_layout(src, dst); | |||
| megdnn_assert_eq_layout(bn_scale, bn_bias); | |||
| @@ -58,16 +58,19 @@ class autocast: | |||
| self._origin_low = None | |||
| def __enter__(self): | |||
| self._origin_enabled, amp._enabled = amp._enabled, self.enabled | |||
| self._origin_high = amp._high_prec_dtype | |||
| amp._high_prec_dtype = self.high_prec_dtype | |||
| self._origin_low = amp._low_prec_dtype | |||
| amp._low_prec_dtype = self.low_prec_dtype | |||
| self._origin_enabled = amp._enabled | |||
| self._origin_high = amp._get_amp_high_prec_dtype() | |||
| self._origin_low = amp._get_amp_low_prec_dtype() | |||
| amp._enabled = self.enabled | |||
| amp._set_amp_dtype_autocast(self.enabled) | |||
| amp._set_amp_high_prec_dtype(self.high_prec_dtype) | |||
| amp._set_amp_low_prec_dtype(self.low_prec_dtype) | |||
| def __exit__(self, *args): | |||
| amp._enabled = self._origin_enabled | |||
| amp._high_prec_dtype = self._origin_high | |||
| amp._low_prec_dtype = self._origin_low | |||
| amp._set_amp_dtype_autocast(self._origin_enabled) | |||
| amp._set_amp_high_prec_dtype(self._origin_high) | |||
| amp._set_amp_low_prec_dtype(self._origin_low) | |||
| def __call__(self, func): | |||
| @functools.wraps(func) | |||
| @@ -5,9 +5,18 @@ | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .._imperative_rt.core2 import ( | |||
| _get_amp_dtype_autocast, | |||
| _get_amp_high_prec_dtype, | |||
| _get_amp_low_prec_dtype, | |||
| _set_amp_dtype_autocast, | |||
| _set_amp_high_prec_dtype, | |||
| _set_amp_low_prec_dtype, | |||
| ) | |||
| _enabled = False | |||
| _high_prec_dtype = "float32" | |||
| _low_prec_dtype = "float16" | |||
| _set_amp_dtype_autocast(_enabled) | |||
| @property | |||
| @@ -28,6 +37,7 @@ def enabled(mod): | |||
| def enabled(mod, enabled: bool): | |||
| global _enabled | |||
| _enabled = enabled | |||
| _set_amp_dtype_autocast(_enabled) | |||
| @property | |||
| @@ -42,13 +52,12 @@ def high_prec_dtype(mod): | |||
| import megengine as mge | |||
| mge.amp.high_prec_dtype = "float32" | |||
| """ | |||
| return _high_prec_dtype | |||
| return _get_amp_high_prec_dtype() | |||
| @high_prec_dtype.setter | |||
| def high_prec_dtype(mod, dtype: str): | |||
| global _high_prec_dtype | |||
| _high_prec_dtype = dtype | |||
| _set_amp_high_prec_dtype(dtype) | |||
| @property | |||
| @@ -63,10 +72,9 @@ def low_prec_dtype(mod): | |||
| import megengine as mge | |||
| mge.amp.low_prec_dtype = "float16" | |||
| """ | |||
| return _low_prec_dtype | |||
| return _get_amp_low_prec_dtype() | |||
| @low_prec_dtype.setter | |||
| def low_prec_dtype(mod, dtype: str): | |||
| global _low_prec_dtype | |||
| _low_prec_dtype = dtype | |||
| _set_amp_low_prec_dtype(dtype) | |||
| @@ -25,7 +25,6 @@ from .utils import ( | |||
| astensor1d, | |||
| astype, | |||
| cast_tensors, | |||
| convert_inputs, | |||
| make_shape_tuple, | |||
| subgraph, | |||
| ) | |||
| @@ -40,38 +39,6 @@ def _elwise_apply(args, mode): | |||
| def _elwise(*args, mode): | |||
| args = convert_inputs(*args) | |||
| if ( | |||
| mode | |||
| in ( | |||
| _ElwMod.TRUE_DIV, | |||
| _ElwMod.EXP, | |||
| _ElwMod.POW, | |||
| _ElwMod.LOG, | |||
| _ElwMod.EXPM1, | |||
| _ElwMod.LOG1P, | |||
| _ElwMod.ACOS, | |||
| _ElwMod.ASIN, | |||
| _ElwMod.ATAN2, | |||
| _ElwMod.COS, | |||
| _ElwMod.SIN, | |||
| _ElwMod.LOG_SUM_EXP, | |||
| ) | |||
| and ( | |||
| amp._enabled | |||
| or np.all([np.issubdtype(arg.dtype, np.integer) for arg in args]) | |||
| ) | |||
| or mode in (_ElwMod.TANH,) | |||
| and np.all([np.issubdtype(arg.dtype, np.integer) for arg in args]) | |||
| ): | |||
| # autocast to FP32 to maintain precision | |||
| # or to avoid op's not supporting all int args | |||
| args = cast_tensors(*args, promote=True) | |||
| if mode in (_ElwMod.CEIL, _ElwMod.FLOOR, _ElwMod.ROUND,) and np.issubdtype( | |||
| args[0].dtype, np.integer | |||
| ): | |||
| return args[0] | |||
| return _elwise_apply(args, mode) | |||
| @@ -504,10 +471,6 @@ def _remove_axis(inp: Tensor, axis) -> Tensor: | |||
| def _reduce(mode): | |||
| def f(self, axis=None, keepdims: bool = False): | |||
| data = self | |||
| if mode == "mean": | |||
| data = data.astype("float32") | |||
| elif self.dtype == np.bool_: | |||
| data = data.astype("int32") | |||
| if axis is None: | |||
| assert not keepdims, "can not set axis=None and keepdims=True" | |||
| result = _reduce_to_scalar(builtin.Reduce(mode=mode), data) | |||
| @@ -526,9 +489,6 @@ def _reduce(mode): | |||
| if not keepdims: | |||
| result = _remove_axis(result, axis) | |||
| if self.dtype == np.bool_: | |||
| if mode in ["min", "max"]: | |||
| result = result.astype("bool") | |||
| return result | |||
| return f | |||
| @@ -16,6 +16,8 @@ from .._imperative_rt import make_const | |||
| from .._imperative_rt.core2 import ( | |||
| SymbolVar, | |||
| Tensor, | |||
| _get_convert_inputs, | |||
| _set_convert_inputs, | |||
| apply, | |||
| dtype_promotion, | |||
| get_device, | |||
| @@ -27,15 +29,13 @@ from .._wrap import as_device | |||
| from ..autodiff.grad import Function | |||
| from ..ops import builtin | |||
| from ..ops.special import Const | |||
| from .amp import _high_prec_dtype, _low_prec_dtype | |||
| from .amp import _get_amp_high_prec_dtype, _get_amp_low_prec_dtype | |||
| from .dtype import is_dtype_equal, is_quantize | |||
| _enable_convert_inputs = True | |||
| def get_convert_inputs(): | |||
| r"""get the curerent state of `_enable_convert_inputs`""" | |||
| return _enable_convert_inputs | |||
| return _get_convert_inputs() | |||
| def set_convert_inputs(flag): | |||
| @@ -44,10 +44,7 @@ def set_convert_inputs(flag): | |||
| `_enable_convert_inputs` is set to `False`, otherwise enabled. This function is for | |||
| internal use only, and should be removed when the tensor-like system is refactored. | |||
| """ | |||
| global _enable_convert_inputs | |||
| backup = _enable_convert_inputs | |||
| _enable_convert_inputs = flag | |||
| return backup | |||
| return _set_convert_inputs(flag) | |||
| def concatenate(inputs, axis=0, *, device=None): | |||
| @@ -75,7 +72,7 @@ def convert_single_value(v, *, dtype=None, device=None): | |||
| def convert_inputs(*args, device=None): | |||
| if not _enable_convert_inputs: | |||
| if not _get_convert_inputs(): | |||
| return args | |||
| dtype = dtype_promotion(args) | |||
| @@ -109,9 +106,9 @@ def convert_inputs(*args, device=None): | |||
| def cast_tensors(*args, promote=False): | |||
| if promote: | |||
| dtype = _high_prec_dtype | |||
| dtype = _get_amp_high_prec_dtype() | |||
| else: | |||
| dtype = _low_prec_dtype | |||
| dtype = _get_amp_low_prec_dtype() | |||
| return tuple(arg.astype(dtype) if arg is not None else None for arg in args) | |||
| @@ -16,6 +16,7 @@ from ..core.tensor.array_method import _elwise | |||
| from ..core.tensor.utils import convert_inputs | |||
| from ..tensor import Tensor | |||
| from ..utils.deprecation import deprecated_func | |||
| from .tensor_cache import get_scalar_one | |||
| __all__ = [ | |||
| "abs", | |||
| @@ -359,7 +360,11 @@ def asin(x): | |||
| def atan(x): | |||
| r"""Element-wise `inverse tangent`.""" | |||
| return _elwise(x, 1, mode=Elemwise.Mode.ATAN2) | |||
| return _elwise( | |||
| x, | |||
| get_scalar_one("float32", x.device if isinstance(x, Tensor) else None), | |||
| mode=Elemwise.Mode.ATAN2, | |||
| ) | |||
| def atan2(y, x): | |||
| @@ -253,15 +253,6 @@ def conv2d( | |||
| conv_mode.lower() == "cross_correlation" | |||
| or conv_mode.name == "CROSS_CORRELATION" | |||
| ) | |||
| if amp._enabled: | |||
| compute_mode = "float32" | |||
| inp, weight, bias = cast_tensors(inp, weight, bias) | |||
| else: | |||
| dtype = dtype_promotion(inp, weight) | |||
| if inp.dtype != dtype: | |||
| inp = inp.astype(dtype) | |||
| if weight.dtype != dtype: | |||
| weight = weight.astype(dtype) | |||
| stride_h, stride_w = expand_hw(stride) | |||
| pad_h, pad_w = expand_hw(padding) | |||
| @@ -1328,29 +1319,32 @@ def batch_norm( | |||
| inplace: whether to update ``running_mean`` and ``running_var`` | |||
| inplace or return new tensors. Default: True | |||
| """ | |||
| if inp.ndim != 4: | |||
| raise NotImplementedError("batch_norm for ndim != 4") | |||
| if param_dim == "dim_1c11": | |||
| C = inp.shape[1] | |||
| pshape = (1, C, 1, 1) | |||
| elif param_dim == "dim_111c": | |||
| C = inp.shape[3] | |||
| pshape = (1, 1, 1, C) | |||
| else: | |||
| raise ValueError("Invalid param_dim {}".format(param_dim)) | |||
| def make_full_if_none(x, value): | |||
| x_ndim = None if x is None else x.ndim | |||
| # in general case, x will be returned here directly | |||
| if x_ndim is not None and x_ndim != 1: | |||
| return x | |||
| if param_dim == "dim_1c11": | |||
| C = inp.shape[1] | |||
| pshape = (1, C, 1, 1) | |||
| elif param_dim == "dim_111c": | |||
| C = inp.shape[3] | |||
| pshape = (1, 1, 1, C) | |||
| else: | |||
| raise ValueError("Invalid param_dim {}".format(param_dim)) | |||
| if x is None: | |||
| (x,) = Const(value, dtype=inp.dtype, device=inp.device)() | |||
| shape = astensor1d(pshape, inp, dtype="int32", device=inp.device) | |||
| (result,) = apply(builtin.Broadcast(), x, shape) | |||
| return result | |||
| elif x.ndim == 1: | |||
| else: | |||
| assert x_ndim == 1 | |||
| shape = astensor1d(pshape, inp, dtype="int32", device=inp.device) | |||
| (result,) = apply(builtin.Reshape(), x, shape) | |||
| return result | |||
| return x | |||
| has_mean = running_mean is not None | |||
| has_var = running_var is not None | |||
| @@ -1359,16 +1353,6 @@ def batch_norm( | |||
| assert has_mean, "running_mean must be provided in inference mode" | |||
| assert has_var, "running_var must be provided in inference mode" | |||
| if has_mean and running_mean.ndim != 4: | |||
| raise ValueError | |||
| if has_var and running_var.ndim != 4: | |||
| raise ValueError | |||
| if amp._enabled: | |||
| inp = inp.astype("float16") | |||
| weight, bias, running_mean, running_var = cast_tensors( | |||
| weight, bias, running_mean, running_var, promote=True | |||
| ) | |||
| weight = make_full_if_none(weight, 1) | |||
| bias = make_full_if_none(bias, 0) | |||
| @@ -0,0 +1,34 @@ | |||
| from ..core.ops.special import Const | |||
| from ..jit.tracing import is_tracing | |||
| small_tensor_cache = {} | |||
| def _get_scalar_tensor_with_value(value, dtype=None, device=None): | |||
| global small_tensor_cache | |||
| if is_tracing(): | |||
| (ret,) = Const(value, dtype=dtype, device=device)() | |||
| else: | |||
| cache_key = (value, dtype, device) | |||
| if cache_key not in small_tensor_cache: | |||
| (ret,) = Const(value, dtype=dtype, device=device)() | |||
| small_tensor_cache[cache_key] = ret | |||
| else: | |||
| ret = small_tensor_cache[cache_key] | |||
| return ret | |||
| def get_scalar_zero(dtype=None, device=None): | |||
| return _get_scalar_tensor_with_value(0, dtype, device) | |||
| def get_scalar_zero_point_five(dtype=None, device=None): | |||
| return _get_scalar_tensor_with_value(0.5, dtype, device) | |||
| def get_scalar_one(dtype=None, device=None): | |||
| return _get_scalar_tensor_with_value(1, dtype, device) | |||
| def get_scalar_two(dtype=None, device=None): | |||
| return _get_scalar_tensor_with_value(2, dtype, device) | |||
| @@ -15,6 +15,7 @@ | |||
| #include "megbrain/imperative/ops/backward_graph.h" | |||
| #include "megbrain/imperative/ops/utility.h" | |||
| #include "megbrain/imperative/profiler.h" | |||
| #include "megbrain/imperative/transformations/dtype_promote.h" | |||
| #include "megbrain/imperative/transformations/eval.h" | |||
| #include "megbrain/imperative/transformations/lazy.h" | |||
| #include "megbrain/imperative/transformations/scalar.h" | |||
| @@ -59,16 +60,19 @@ struct SymbolVarContext { | |||
| TransformationContext context; | |||
| std::shared_ptr<SymbolTransformation> symbol_tsf; | |||
| std::shared_ptr<ScalarTransformation> scalar_tsf; | |||
| std::shared_ptr<DTypePromoteTransformation> dtype_promote_tsf; | |||
| SymbolVarContext(cg::ComputingGraph* graph) { | |||
| symbol_tsf = std::make_shared<SymbolTransformation>(graph); | |||
| scalar_tsf = std::make_shared<ScalarTransformation>(); | |||
| dtype_promote_tsf = std::make_shared<DTypePromoteTransformation>(); | |||
| Transformation::swap_context(context); | |||
| } | |||
| void init() { | |||
| symbol_tsf->register_at(Transformation::top()); | |||
| scalar_tsf->register_at(Transformation::top()); | |||
| dtype_promote_tsf->register_at(Transformation::top()); | |||
| } | |||
| ValueRef symvar2val(py::handle py_symbol_var) { | |||
| @@ -110,6 +114,9 @@ REGISTE_APPLY_FUNC(cpp_astensor1d) | |||
| #undef REGISTE_APPLY_FUNC | |||
| PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs); | |||
| CompNode _get_device(PyObject* const* args, size_t nargs); | |||
| PyObject* py_apply( | |||
| PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */) { | |||
| try { | |||
| @@ -133,19 +140,59 @@ PyObject* py_apply( | |||
| auto op = py::handle(py_op).cast<std::shared_ptr<OpDef>>(); | |||
| SmallVector<ValueRef, 8> tensors(nargs); | |||
| bool is_symbol_var = (!TensorWrapper::try_cast(args[0])) && | |||
| py::isinstance<PySymbolVar>(py::handle(args[0])); | |||
| if (is_symbol_var) { | |||
| SmallVector<bool, 8> is_symbol_var(nargs, false); | |||
| ComputingGraph* cg = nullptr; | |||
| for (size_t i = 0; i < nargs; ++i) { | |||
| if ((!TensorWrapper::try_cast(args[i])) && | |||
| py::isinstance<PySymbolVar>(py::handle(args[i]))) { | |||
| is_symbol_var[i] = true; | |||
| ComputingGraph* cur_cg = | |||
| py::handle(args[i]).cast<PySymbolVar*>()->m_node->owner_graph(); | |||
| if (cg == nullptr) { | |||
| cg = cur_cg; | |||
| } else { | |||
| mgb_assert(cg == cur_cg); | |||
| } | |||
| } | |||
| } | |||
| mgb::CompNode target_cn; | |||
| mgb::DType target_dtype; | |||
| auto convert_pyinput_to_tensor = [&](size_t i) -> ValueRef { | |||
| if (!target_dtype.valid()) { | |||
| target_dtype = npy::dtype_np2mgb_descr(_dtype_promotion(args, nargs)); | |||
| target_cn = _get_device(args, nargs); | |||
| } | |||
| HostTensorND ht(target_cn); | |||
| ht = npy::np2tensor(args[i], npy::Meth::copy_into(&ht), target_dtype); | |||
| if (PyArray_Check(args[i])) { // non scaler | |||
| return imperative::apply( | |||
| CreateTensor(CreateTensor::Const, target_cn, ht.layout()), | |||
| HostStorage::make(ht.storage()))[0]; | |||
| } else { // scaler | |||
| return imperative::apply( | |||
| CreateTensor(CreateTensor::Const, target_cn, target_dtype, {}), | |||
| HostStorage::make(ht.storage()))[0]; | |||
| } | |||
| }; | |||
| if (cg != nullptr) { | |||
| // swap to a special context to reuse scalar handle | |||
| SymbolVarContext context( | |||
| py::handle(args[0]).cast<PySymbolVar*>()->m_node->owner_graph()); | |||
| size_t symbol_var_idx = 8; | |||
| SymbolVarContext context(cg); | |||
| context.init(); | |||
| for (size_t i = 0; i < nargs; ++i) { | |||
| tensors[i] = context.symvar2val(args[i]); | |||
| if (is_symbol_var[i]) { | |||
| symbol_var_idx = i; | |||
| tensors[i] = context.symvar2val(args[i]); | |||
| } else { | |||
| tensors[i] = convert_pyinput_to_tensor(i); | |||
| } | |||
| } | |||
| auto outputs = imperative::apply(*op, tensors); | |||
| auto ret = pybind11::tuple(outputs.size()); | |||
| auto typeobj = py::handle(args[0]).get_type(); | |||
| auto typeobj = py::handle(args[symbol_var_idx]).get_type(); | |||
| for (size_t i = 0; i < outputs.size(); ++i) { | |||
| ret[i] = context.val2symvar(typeobj, outputs[i]); | |||
| } | |||
| @@ -156,13 +203,7 @@ PyObject* py_apply( | |||
| if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) { | |||
| tensors[i] = tw->m_tensor->data(); | |||
| } else { | |||
| PyErr_SetString( | |||
| PyExc_TypeError, | |||
| ssprintf( | |||
| "op %s expect type Tensor as inputs, got %s actually", | |||
| op->make_name().c_str(), Py_TYPE(args[i])->tp_name) | |||
| .c_str()); | |||
| return nullptr; | |||
| tensors[i] = convert_pyinput_to_tensor(i); | |||
| } | |||
| } | |||
| @@ -616,6 +657,8 @@ void init_tensor(py::module m) { | |||
| std::shared_ptr<Channel>(channel, [](Channel*) {}))); | |||
| transformations.register_at<Segment::Scalar>( | |||
| std::make_shared<ScalarTransformation>()); | |||
| transformations.register_at<Segment::DTypePromote>( | |||
| std::make_shared<DTypePromoteTransformation>()); | |||
| static py::exception<interpreter::AsyncError> py_async_error( | |||
| m, "AsyncError", PyExc_RuntimeError); | |||
| @@ -1137,6 +1180,63 @@ void init_tensor(py::module m) { | |||
| m.def("reset_stats", [] { imperative::Stats::reset(); }); | |||
| m.def("_get_convert_inputs", | |||
| []() -> bool { return DTypePromoteCfg::convert_input_enabled; }); | |||
| m.def("_set_convert_inputs", [](bool flag) -> bool { | |||
| bool ret = DTypePromoteCfg::convert_input_enabled; | |||
| DTypePromoteCfg::convert_input_enabled = flag; | |||
| return ret; | |||
| }); | |||
| m.def("_get_amp_dtype_autocast", | |||
| []() -> bool { return DTypePromoteCfg::amp_dtype_autocast_enabled; }); | |||
| m.def("_set_amp_dtype_autocast", [](bool flag) -> bool { | |||
| bool ret = DTypePromoteCfg::amp_dtype_autocast_enabled; | |||
| DTypePromoteCfg::amp_dtype_autocast_enabled = flag; | |||
| return ret; | |||
| }); | |||
| static auto get_amp_prec_dtype = [](bool is_high) -> std::string { | |||
| DType& target = is_high ? DTypePromoteCfg::amp_high_prec_dtype | |||
| : DTypePromoteCfg::amp_low_prec_dtype; | |||
| mgb_assert(target.category() == DTypeCategory::FLOAT); | |||
| std::string ret = target.name(); | |||
| transform(ret.begin(), ret.end(), ret.begin(), ::tolower); | |||
| return ret; | |||
| }; | |||
| static auto set_amp_prec_dtype = [](bool is_high, | |||
| std::string dtype_name) -> std::string { | |||
| DType& target = is_high ? DTypePromoteCfg::amp_high_prec_dtype | |||
| : DTypePromoteCfg::amp_low_prec_dtype; | |||
| std::string ret = target.name(); | |||
| if (dtype_name == "float32") { | |||
| target = dtype::Float32(); | |||
| } else if (dtype_name == "float16") { | |||
| target = dtype::Float16(); | |||
| } else if (dtype_name == "bfloat16") { | |||
| target = dtype::BFloat16(); | |||
| } else { | |||
| mgb_assert( | |||
| false, "casted type of amp should be float, but you give %s\n", | |||
| dtype_name.c_str()); | |||
| } | |||
| transform(ret.begin(), ret.end(), ret.begin(), ::tolower); | |||
| return ret; | |||
| }; | |||
| m.def("_get_amp_high_prec_dtype", | |||
| []() -> std::string { return get_amp_prec_dtype(true); }); | |||
| m.def("_set_amp_high_prec_dtype", [](std::string dtype_name) -> std::string { | |||
| return set_amp_prec_dtype(true, dtype_name); | |||
| }); | |||
| m.def("_get_amp_low_prec_dtype", | |||
| []() -> std::string { return get_amp_prec_dtype(false); }); | |||
| m.def("_set_amp_low_prec_dtype", [](std::string dtype_name) -> std::string { | |||
| return set_amp_prec_dtype(false, dtype_name); | |||
| }); | |||
| py::register_exception<TraceError>(m, "TraceError"); | |||
| } | |||
| @@ -26,12 +26,13 @@ struct TransformationManager { | |||
| enum Segment { | |||
| ModuleTrace, | |||
| Grad, | |||
| DTypePromote, | |||
| Scalar, | |||
| Trace, | |||
| Eval, | |||
| }; | |||
| std::array<std::vector<std::shared_ptr<Transformation>>, 5> segments; | |||
| std::array<std::vector<std::shared_ptr<Transformation>>, 6> segments; | |||
| template <Segment segment> | |||
| void register_at(std::shared_ptr<Transformation> transformation) { | |||
| @@ -14,20 +14,20 @@ def test_grad_scaler(): | |||
| assert amp.enabled == enabled | |||
| assert origin_amp._enabled == enabled | |||
| assert amp.low_prec_dtype == low | |||
| assert origin_amp._low_prec_dtype == low | |||
| assert origin_amp._get_amp_low_prec_dtype() == low | |||
| assert amp.high_prec_dtype == high | |||
| assert origin_amp._high_prec_dtype == high | |||
| assert origin_amp._get_amp_high_prec_dtype() == high | |||
| origin_enabled = amp.enabled | |||
| origin_high = amp.high_prec_dtype | |||
| origin_low = amp.low_prec_dtype | |||
| with amp.autocast(low_prec_dtype="low", high_prec_dtype="high"): | |||
| check(True, "low", "high") | |||
| with amp.autocast(low_prec_dtype="float16", high_prec_dtype="float32"): | |||
| check(True, "float16", "float32") | |||
| check(origin_enabled, origin_low, origin_high) | |||
| amp.enabled = True | |||
| amp.high_prec_dtype = "high" | |||
| amp.low_prec_dtype = "low" | |||
| check(True, "low", "high") | |||
| amp.high_prec_dtype = "float32" | |||
| amp.low_prec_dtype = "float16" | |||
| check(True, "float16", "float32") | |||
| amp.enabled = origin_enabled | |||
| amp.high_prec_dtype = origin_high | |||
| amp.low_prec_dtype = origin_low | |||
| @@ -0,0 +1,251 @@ | |||
| #include "megbrain/imperative/transformations/dtype_promote.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| namespace mgb::imperative { | |||
| bool DTypePromoteCfg::convert_input_enabled = true; | |||
| bool DTypePromoteCfg::amp_dtype_autocast_enabled = false; | |||
| DType DTypePromoteCfg::amp_high_prec_dtype = dtype::Float32(); | |||
| DType DTypePromoteCfg::amp_low_prec_dtype = dtype::Float16(); | |||
| namespace { | |||
| // TODO: ScalarRule and DTypePromoteRule should be unified | |||
| using DTypePromoteRule = std::function<ValueRefList(const OpDef&, Span<ValueRef>)>; | |||
| static std::unordered_map<Typeinfo*, DTypePromoteRule> dtype_promotion_rules; | |||
| template <typename T> | |||
| void register_dtype_promote_rule(const DTypePromoteRule& rule) { | |||
| dtype_promotion_rules[T::typeinfo()] = [rule](const OpDef& def, | |||
| Span<ValueRef> inputs) { | |||
| return rule(def.cast_final_safe<T>(), inputs); | |||
| }; | |||
| } | |||
| bool is_quantized_dtype(const DType& dtype) { | |||
| return dtype.category() == DTypeCategory::QUANTIZED; | |||
| } | |||
| bool is_all_integer(const SmallVector<DType>& dtypes) { | |||
| for (size_t i = 0; i < dtypes.size(); ++i) { | |||
| if (dtypes[i].category() != DTypeCategory::INT) { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| SmallVector<DType> get_value_dtypes(const Span<ValueRef> inputs) { | |||
| SmallVector<DType> dtypes(inputs.size()); | |||
| for (size_t i = 0; i < inputs.size(); ++i) { | |||
| dtypes[i] = *(inputs[i].dtype()); | |||
| } | |||
| return dtypes; | |||
| } | |||
| mgb::DType get_promoted_dtype(const SmallVector<DType>& dtypes) { | |||
| if (dtypes.size() == 0) { | |||
| mgb_assert(false, "there is no input for operator, dtype promote failed"); | |||
| } | |||
| mgb::DType ret = dtypes[0]; | |||
| for (size_t i = 1; i < dtypes.size(); ++i) { | |||
| ret = mgb::dtype_promotion(ret, dtypes[i]); | |||
| } | |||
| return ret; | |||
| } | |||
| ValueRefList elemwise_rule(const OpDef& op, Span<ValueRef> inputs) { | |||
| auto&& elem_op = op.cast_final_safe<Elemwise>(); | |||
| SmallVector<DType> dtypes(inputs.size()); | |||
| for (size_t i = 0; i < inputs.size(); ++i) { | |||
| dtypes[i] = *(inputs[i].dtype()); | |||
| } | |||
| ValueRefList converted(inputs.size()); | |||
| mgb::DType target_dtype = get_promoted_dtype(dtypes); | |||
| // TODO: we can save the dtypes of inputs here and perform TypeCvt at the end of | |||
| // this function, rather than perform TypeCvt eagerly. But for the compatibility, we | |||
| // implement this function with the similar process as the python version and | |||
| // perform TypeCvt here, so we maybe do TypeCvt several times in these function | |||
| for (size_t i = 0; i < inputs.size(); ++i) { | |||
| if (!is_quantized_dtype(dtypes[i]) && dtypes[i] != target_dtype && | |||
| DTypePromoteCfg::convert_input_enabled) { | |||
| converted[i] = imperative::apply( | |||
| ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0]; | |||
| dtypes[i] = target_dtype; | |||
| } else { | |||
| converted[i] = inputs[i]; | |||
| } | |||
| } | |||
| static std::unordered_set<Elemwise::Mode> cast_case1 = { | |||
| Elemwise::Mode::TRUE_DIV, Elemwise::Mode::EXP, | |||
| Elemwise::Mode::POW, Elemwise::Mode::LOG, | |||
| Elemwise::Mode::EXPM1, Elemwise::Mode::LOG1P, | |||
| Elemwise::Mode::ACOS, Elemwise::Mode::ASIN, | |||
| Elemwise::Mode::ATAN2, Elemwise::Mode::COS, | |||
| Elemwise::Mode::SIN, Elemwise::Mode::LOG_SUM_EXP, | |||
| }; | |||
| static std::unordered_set<Elemwise::Mode> cast_case2 = { | |||
| Elemwise::Mode::TANH, | |||
| }; | |||
| auto cast_to_high_prec = [&]() { | |||
| for (size_t i = 0; i < dtypes.size(); ++i) { | |||
| if (dtypes[i] != DTypePromoteCfg::amp_high_prec_dtype) { | |||
| converted[i] = imperative::apply( | |||
| ApplyOp(*TypeCvt::make(DTypePromoteCfg::amp_high_prec_dtype)), | |||
| converted[i])[0]; | |||
| dtypes[i] = DTypePromoteCfg::amp_high_prec_dtype; | |||
| } | |||
| } | |||
| }; | |||
| if (cast_case1.find(elem_op.mode) != cast_case1.end()) { | |||
| if (DTypePromoteCfg::amp_dtype_autocast_enabled || is_all_integer(dtypes)) { | |||
| cast_to_high_prec(); | |||
| } | |||
| } | |||
| if (cast_case2.find(elem_op.mode) != cast_case2.end()) { | |||
| if (is_all_integer(dtypes)) { | |||
| cast_to_high_prec(); | |||
| } | |||
| } | |||
| static std::unordered_set<Elemwise::Mode> cast_case3 = { | |||
| Elemwise::Mode::CEIL, Elemwise::Mode::FLOOR, Elemwise::Mode::ROUND}; | |||
| if (cast_case3.find(elem_op.mode) != cast_case3.end()) { | |||
| if (dtypes[0].category() == DTypeCategory::INT) { | |||
| return converted; | |||
| } | |||
| } | |||
| return imperative::apply(op, converted); | |||
| } | |||
| ValueRefList reduce_rule(const OpDef& op, Span<ValueRef> inputs) { | |||
| auto&& reduce_op = op.cast_final_safe<Reduce>(); | |||
| DType org_dtype = *(inputs[0].dtype()); | |||
| DType target_dtype = org_dtype; | |||
| ValueRefList converted(inputs.begin(), inputs.end()); | |||
| if (reduce_op.mode == Reduce::Mode::MEAN) { | |||
| target_dtype = dtype::Float32(); | |||
| } else if (org_dtype.category() == DTypeCategory::BOOL) { | |||
| target_dtype = dtype::Int32(); | |||
| } | |||
| if (target_dtype != org_dtype) { | |||
| converted[0] = | |||
| imperative::apply(ApplyOp(*TypeCvt::make(target_dtype)), inputs[0])[0]; | |||
| } | |||
| ValueRefList ret = imperative::apply(op, converted); | |||
| if (org_dtype.category() == DTypeCategory::BOOL) { | |||
| if (reduce_op.mode == Reduce::Mode::MIN || | |||
| reduce_op.mode == Reduce::Mode::MAX) { | |||
| ret[0] = imperative::apply( | |||
| ApplyOp(*TypeCvt::make(dtype::Bool())), ret[0])[0]; | |||
| } | |||
| } | |||
| return ret; | |||
| } | |||
| ValueRefList convolution_rule(const OpDef& op, Span<ValueRef> inputs) { | |||
| auto&& conv_op = const_cast<Convolution&>(op.cast_final_safe<Convolution>()); | |||
| SmallVector<DType> dtypes = get_value_dtypes(inputs); | |||
| mgb::DType target_dtype; | |||
| if (DTypePromoteCfg::amp_dtype_autocast_enabled) { | |||
| conv_op.compute_mode = Convolution::ComputeMode::FLOAT32; | |||
| target_dtype = DTypePromoteCfg::amp_low_prec_dtype; | |||
| } else { | |||
| target_dtype = get_promoted_dtype(dtypes); | |||
| } | |||
| ValueRefList converted(inputs.size()); | |||
| for (size_t i = 0; i < inputs.size(); ++i) { | |||
| if (dtypes[i] != target_dtype) { | |||
| converted[i] = imperative::apply( | |||
| ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0]; | |||
| } else { | |||
| converted[i] = inputs[i]; | |||
| } | |||
| } | |||
| return imperative::apply(op, converted); | |||
| } | |||
| ValueRefList batch_norm_rule(const OpDef& op, Span<ValueRef> inputs) { | |||
| if (DTypePromoteCfg::amp_dtype_autocast_enabled) { | |||
| mgb_assert(inputs.size() > 0); | |||
| ValueRefList converted(inputs.size()); | |||
| converted[0] = imperative::apply( | |||
| ApplyOp(*TypeCvt::make(dtype::Float16())), inputs[0])[0]; | |||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||
| DType idtype = *(inputs[i].dtype()); | |||
| if (idtype != DTypePromoteCfg::amp_high_prec_dtype) { | |||
| converted[i] = imperative::apply( | |||
| ApplyOp(*TypeCvt::make(DTypePromoteCfg::amp_high_prec_dtype)), | |||
| inputs[i])[0]; | |||
| } else { | |||
| converted[i] = inputs[i]; | |||
| } | |||
| } | |||
| return imperative::apply(op, converted); | |||
| } | |||
| return imperative::apply(op, inputs); | |||
| } | |||
| struct DTypePromoteRuleRegistry { | |||
| DTypePromoteRuleRegistry() { | |||
| register_dtype_promote_rule<Elemwise>(elemwise_rule); | |||
| register_dtype_promote_rule<Reduce>(reduce_rule); | |||
| register_dtype_promote_rule<Convolution>(convolution_rule); | |||
| register_dtype_promote_rule<BatchNorm>(batch_norm_rule); | |||
| } | |||
| } register_helper; | |||
| } // namespace | |||
| ValueRefList DTypePromoteTransformation::apply_transformation( | |||
| const Operator& op, Span<ValueRef> inputs) { | |||
| if (auto apply_op = op.as<ApplyOp>()) { | |||
| auto iter = dtype_promotion_rules.find(apply_op->op().dyn_typeinfo()); | |||
| if (iter != dtype_promotion_rules.end()) { | |||
| return iter->second(apply_op->op(), inputs); | |||
| } else { | |||
| return imperative::apply(op, inputs); | |||
| } | |||
| } | |||
| return imperative::apply(op, inputs); | |||
| } | |||
| ValueRef DTypePromoteTransformation::unwrap(ValueRef value) { | |||
| return value; | |||
| } | |||
| std::string DTypePromoteTransformation::name() const { | |||
| return "DTypePromoteTransformation"; | |||
| } | |||
| void DTypePromoteTransformation::on_register() { | |||
| // printf("DTypePromoteTransformation has been registered\n"); | |||
| } | |||
| void DTypePromoteTransformation::on_unregister() noexcept { | |||
| // printf("DTypePromoteTransformation has been unregistered\n"); | |||
| } | |||
| } // namespace mgb::imperative | |||
| @@ -0,0 +1,26 @@ | |||
| #pragma once | |||
| #include "megbrain/imperative/dispatch.h" | |||
| #include "megbrain/imperative/value.h" | |||
| namespace mgb::imperative { | |||
| class DTypePromoteTransformation final : public Transformation { | |||
| private: | |||
| public: | |||
| ValueRefList apply_transformation( | |||
| const Operator& op, Span<ValueRef> inputs) override; | |||
| ValueRef unwrap(ValueRef value) override; | |||
| std::string name() const override; | |||
| void on_register() override; | |||
| void on_unregister() noexcept override; | |||
| }; | |||
| struct DTypePromoteCfg { | |||
| static bool convert_input_enabled; | |||
| static bool amp_dtype_autocast_enabled; | |||
| static DType amp_high_prec_dtype; | |||
| static DType amp_low_prec_dtype; | |||
| }; | |||
| } // namespace mgb::imperative | |||