GitOrigin-RevId: d14a69424d
tags/v1.9.0
@@ -41,7 +41,6 @@ from ..distributed import WORLD, is_distributed
 from ..jit import exclude_from_trace
 from ..tensor import Tensor
 from ..utils.deprecation import deprecated_func
-from ..utils.tuple_function import _pair, _pair_nonzero, _triple, _triple_nonzero
 from .debug_param import get_execution_strategy
 from .distributed import all_reduce_sum
 from .elemwise import _elwise, exp, log, log1p, maximum, minimum

@@ -94,14 +93,15 @@ __all__ = [
 def expand_hw(x):
-    # NOTE: >1d array is accepted, as long as 1 <= size <= 2
-    try:
-        x = int(x)
-        return [x, x]
-    except (TypeError, ValueError):
-        pass
-    h, w = x
-    return int(h), int(w)
+    if isinstance(x, Sequence):
+        return int(x[0]), int(x[1])
+    return int(x), int(x)
+
+
+def expand_dhw(x):
+    if isinstance(x, Sequence):
+        return int(x[0]), int(x[1]), int(x[2])
+    return int(x), int(x), int(x)
 
 
 def linear(

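For reference, the new helpers behave as follows. This is a standalone sketch, not part of the diff; it re-declares both functions with an explicit import, whereas inside the module the Sequence name is assumed to be imported already:

from collections.abc import Sequence

def expand_hw(x):
    # scalar -> (x, x); sequence -> its first two entries as (h, w)
    if isinstance(x, Sequence):
        return int(x[0]), int(x[1])
    return int(x), int(x)

def expand_dhw(x):
    # scalar -> (x, x, x); sequence -> its first three entries as (d, h, w)
    if isinstance(x, Sequence):
        return int(x[0]), int(x[1]), int(x[2])
    return int(x), int(x), int(x)

print(expand_hw(3))       # (3, 3)
print(expand_hw((2, 4)))  # (2, 4)
print(expand_dhw(1))      # (1, 1, 1)

Unlike the old try/except version, the new code simply checks for Sequence and otherwise falls back to int(x).
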
@@ -177,11 +177,8 @@ def conv1d(
     if weight.dtype != dtype:
         weight = weight.astype(dtype)
-    inp = expand_dims(inp, 3)
-    weight = expand_dims(weight, 3)
     if bias is not None:
         assert bias.ndim == 3, "the bias dimension of conv1d should be 3"
-        bias = expand_dims(bias, 3)
     stride_h = stride
     pad_h = padding

@@ -206,7 +203,6 @@ def conv1d(
     (output,) = apply(op, inp, weight)
     if bias is not None:
         output += bias
-    output = squeeze(output, 3)
     return output

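The two conv1d hunks above drop the manual 1d-to-2d lifting that used to happen in Python. In terms of public ops, the old trick was roughly the following sketch (shapes are illustrative only; the real code built the 2d convolution op directly rather than calling conv2d):

import numpy as np
import megengine as mge
import megengine.functional as F

inp = mge.tensor(np.random.randn(2, 3, 16).astype("float32"))    # (N, C, L)
weight = mge.tensor(np.random.randn(4, 3, 5).astype("float32"))  # (O, I, K)

x = F.expand_dims(inp, 3)     # (N, C, L, 1)
w = F.expand_dims(weight, 3)  # (O, I, K, 1)
y = F.conv2d(x, w, stride=(1, 1), padding=(2, 0))
y = F.squeeze(y, 3)           # back to (N, C, L_out)
print(y.shape)                # (2, 4, 16)

With this commit the same lifting moves into the dispatch layer (the DimExpansionTransformation added further down), so the Python wrapper no longer reshapes anything.
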
@@ -314,9 +310,9 @@ def conv3d(
     D, H, W = 0, 1, 2
-    pad = _triple(padding)
-    stride = _triple_nonzero(stride)
-    dilate = _triple_nonzero(dilation)
+    pad = expand_dhw(padding)
+    stride = expand_dhw(stride)
+    dilate = expand_dhw(dilation)
     sparse_type = "dense" if groups == 1 else "group"
     op = builtin.Convolution3D(

@@ -572,9 +568,9 @@ def conv_transpose3d(
         output tensor.
     """
     D, H, W = 0, 1, 2
-    pad = _triple(padding)
-    stride = _triple_nonzero(stride)
-    dilate = _triple_nonzero(dilation)
+    pad = expand_dhw(padding)
+    stride = expand_dhw(stride)
+    dilate = expand_dhw(dilation)
     sparse_type = "dense" if groups == 1 else "group"
     op = builtin.Convolution3DBackwardData(

@@ -618,9 +614,9 @@ def max_pool2d(
     """
     if stride is None:
         stride = kernel_size
-    window_h, window_w = _pair_nonzero(kernel_size)
-    stride_h, stride_w = _pair_nonzero(stride)
-    padding_h, padding_w = _pair(padding)
+    window_h, window_w = expand_hw(kernel_size)
+    stride_h, stride_w = expand_hw(stride)
+    padding_h, padding_w = expand_hw(padding)
     conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
     op = builtin.Pooling(

@@ -662,9 +658,9 @@ def avg_pool2d(
     """
     if stride is None:
         stride = kernel_size
-    window_h, window_w = _pair_nonzero(kernel_size)
-    stride_h, stride_w = _pair_nonzero(stride)
-    padding_h, padding_w = _pair(padding)
+    window_h, window_w = expand_hw(kernel_size)
+    stride_h, stride_w = expand_hw(stride)
+    padding_h, padding_w = expand_hw(padding)
     conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
     op = builtin.Pooling(

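Nothing changes at the call site for the pooling functions; expand_hw keeps scalar and (h, w) arguments interchangeable. A short sketch against the public API, not part of the diff (standard megengine imports assumed, shapes for illustration):

import numpy as np
import megengine as mge
import megengine.functional as F

x = mge.tensor(np.random.randn(2, 3, 8, 8).astype("float32"))

y1 = F.max_pool2d(x, kernel_size=2, stride=2)                       # scalars normalized to (2, 2)
y2 = F.avg_pool2d(x, kernel_size=(2, 2), stride=(2, 2), padding=0)  # tuples pass through as-is
print(y1.shape, y2.shape)  # (2, 3, 4, 4) (2, 3, 4, 4)
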
@@ -1779,10 +1775,10 @@ def sliding_window(
         stride: stride of the window. Default: 1
         dilation: dilation of the window. Default: 1
     """
-    padding_h, padding_w = _pair(padding)
-    stride_h, stride_w = _pair_nonzero(stride)
-    dilation_h, dilation_w = _pair_nonzero(dilation)
-    window_h, window_w = _pair_nonzero(kernel_size)
+    padding_h, padding_w = expand_hw(padding)
+    stride_h, stride_w = expand_hw(stride)
+    dilation_h, dilation_w = expand_hw(dilation)
+    window_h, window_w = expand_hw(kernel_size)
     op = builtin.Images2Neibs(
         pad_h=padding_h,

@@ -1818,11 +1814,11 @@ def sliding_window_transpose(
         stride: stride of the window. Default: 1
         dilation: dilation of the window. Default: 1
     """
-    output_h, output_w = _pair_nonzero(output_size)
-    padding_h, padding_w = _pair(padding)
-    stride_h, stride_w = _pair_nonzero(stride)
-    dilation_h, dilation_w = _pair_nonzero(dilation)
-    window_h, window_w = _pair_nonzero(kernel_size)
+    output_h, output_w = expand_hw(output_size)
+    padding_h, padding_w = expand_hw(padding)
+    stride_h, stride_w = expand_hw(stride)
+    dilation_h, dilation_w = expand_hw(dilation)
+    window_h, window_w = expand_hw(kernel_size)
     expected_h = (
         output_h + 2 * padding_h - dilation_h * (window_h - 1) - 1

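The expected_h expression this hunk ends on is the usual sliding-window output-size relation; the hunk is cut off before the closing part, which is assumed here to be the standard floor division by the stride plus one. A small arithmetic sketch, not part of the diff:

def out_size(size, padding, dilation, window, stride):
    # floor((size + 2*padding - dilation*(window - 1) - 1) / stride) + 1
    return (size + 2 * padding - dilation * (window - 1) - 1) // stride + 1

print(out_size(10, 1, 1, 3, 2))  # 5 windows along a 10-wide axis
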
@@ -80,19 +80,6 @@ class _BatchNorm(Module):
             self.track_running_stats == False
         ), "track_running_stats can not be initilized to False and changed to True later"
 
-        inp_shape = inp.shape
-        _ndims = len(inp_shape)
-        if _ndims != 4:
-            origin_shape = inp_shape
-            if _ndims == 2:
-                n, c = inp_shape[0], inp_shape[1]
-                new_shape = (n, c, 1, 1)
-            elif _ndims == 3:
-                n, c, h = inp_shape[0], inp_shape[1], inp_shape[2]
-                new_shape = (n, c, h, 1)
-            inp = inp.reshape(new_shape)
-
         _weight = self.weight
         _bias = self.bias

@@ -130,9 +117,6 @@ class _BatchNorm(Module):
             param_dim=self.param_dim,
         )
 
-        if _ndims != 4:
-            output = output.reshape(origin_shape)
-
         return output
 
     def _module_info_string(self) -> str:

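The block deleted from _BatchNorm.forward is the Python-side padding of 2d/3d inputs out to 4d, plus the matching reshape of the output. Its shape logic amounts to the helper below; the name _expand_to_4d is invented for this sketch and does not exist in the codebase:

def _expand_to_4d(shape):
    # mirrors the removed branch: (N, C) -> (N, C, 1, 1), (N, C, H) -> (N, C, H, 1)
    if len(shape) == 2:
        n, c = shape
        return (n, c, 1, 1)
    if len(shape) == 3:
        n, c, h = shape
        return (n, c, h, 1)
    return shape

print(_expand_to_4d((8, 16)))      # (8, 16, 1, 1)
print(_expand_to_4d((8, 16, 10)))  # (8, 16, 10, 1)

As with conv1d, the expansion now happens in the new C++ DimExpansionTransformation (the bn1d_rule further down), so the module's Python forward only has to handle 4d inputs.
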
@@ -15,6 +15,7 @@
 #include "megbrain/imperative/ops/backward_graph.h"
 #include "megbrain/imperative/ops/utility.h"
 #include "megbrain/imperative/profiler.h"
+#include "megbrain/imperative/transformations/dim_expansion.h"
 #include "megbrain/imperative/transformations/dtype_promote.h"
 #include "megbrain/imperative/transformations/eval.h"
 #include "megbrain/imperative/transformations/lazy.h"

@@ -61,11 +62,13 @@ struct SymbolVarContext {
     std::shared_ptr<SymbolTransformation> symbol_tsf;
     std::shared_ptr<ScalarTransformation> scalar_tsf;
     std::shared_ptr<DTypePromoteTransformation> dtype_promote_tsf;
+    std::shared_ptr<DimExpansionTransformation> dim_expansion_tsf;
 
     SymbolVarContext(cg::ComputingGraph* graph) {
         symbol_tsf = std::make_shared<SymbolTransformation>(graph);
         scalar_tsf = std::make_shared<ScalarTransformation>();
         dtype_promote_tsf = std::make_shared<DTypePromoteTransformation>();
+        dim_expansion_tsf = std::make_shared<DimExpansionTransformation>();
         Transformation::swap_context(context);
     }

@@ -73,6 +76,7 @@ struct SymbolVarContext {
         symbol_tsf->register_at(Transformation::top());
         scalar_tsf->register_at(Transformation::top());
         dtype_promote_tsf->register_at(Transformation::top());
+        dim_expansion_tsf->register_at(Transformation::top());
     }
 
     ValueRef symvar2val(py::handle py_symbol_var) {

@@ -452,6 +456,8 @@ void init_tensor(py::module m) {
             std::make_shared<ScalarTransformation>());
     transformations.register_at<Segment::DTypePromote>(
             std::make_shared<DTypePromoteTransformation>());
+    transformations.register_at<Segment::DimExpansion>(
+            std::make_shared<DimExpansionTransformation>());
 
     static py::exception<interpreter::AsyncError> py_async_error(
             m, "AsyncError", PyExc_RuntimeError);

@@ -26,13 +26,14 @@ struct TransformationManager {
     enum Segment {
         ModuleTrace,
         DTypePromote,
+        DimExpansion,
         Grad,
         Scalar,
         Trace,
         Eval,
     };
 
-    std::array<std::vector<std::shared_ptr<Transformation>>, 6> segments;
+    std::array<std::vector<std::shared_ptr<Transformation>>, 7> segments;
 
     template <Segment segment>
     void register_at(std::shared_ptr<Transformation> transformation) {

@@ -91,7 +91,7 @@ class ResNet(M.Module):
 def run_dtr_resnet1202():
-    batch_size = 8
+    batch_size = 7
     resnet1202 = ResNet(BasicBlock, [200, 200, 200])
     opt = optim.SGD(resnet1202.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4)
     gm = GradManager().attach(resnet1202.parameters())

@@ -0,0 +1,95 @@
#include "megbrain/imperative/transformations/dim_expansion.h"
#include "megbrain/imperative/ops/autogen.h"

namespace mgb::imperative {
namespace {

using DimExpansionRule = std::function<ValueRefList(const OpDef&, Span<ValueRef>)>;
static std::unordered_map<Typeinfo*, DimExpansionRule> dim_expansion_rules;

template <typename T>
void register_dim_expansion_rules(const DimExpansionRule& rule) {
    dim_expansion_rules[T::typeinfo()] = [rule](const OpDef& def,
                                                Span<ValueRef> inputs) {
        return rule(def.cast_final_safe<T>(), inputs);
    };
}

ValueRefList conv1d_rule(const OpDef& op, Span<ValueRef> inputs) {
    bool need_expand = inputs.at(0).shape()->ndim == 3;
    if (!need_expand)
        return imperative::apply(op, inputs);

    ValueRefList converted(inputs.size());
    std::vector<int32_t> axis = {(int32_t)3};
    for (size_t i = 0; i < inputs.size(); ++i) {
        converted[i] = imperative::apply(ApplyOp(*AddAxis::make(axis)), inputs[i])[0];
    }

    auto outputs = imperative::apply(op, converted);
    outputs[0] = imperative::apply(ApplyOp(*RemoveAxis::make(axis)), outputs[0])[0];
    return outputs;
}

ValueRefList bn1d_rule(const OpDef& op, Span<ValueRef> inputs) {
    size_t ndim = inputs.at(0).shape()->ndim;
    bool need_expand = (ndim == 2 || ndim == 3);
    if (!need_expand)
        return imperative::apply(op, inputs);

    ValueRefList converted(inputs.size());
    std::vector<int32_t> axis = {(int32_t)3};
    if (ndim == 2) {
        axis.insert(axis.begin(), (int32_t)2);
    }

    converted[0] = imperative::apply(ApplyOp(*AddAxis::make(axis)), inputs[0])[0];
    for (size_t i = 1; i < inputs.size(); ++i) {
        converted[i] = inputs[i];
    }
    std::reverse(std::begin(axis), std::end(axis));

    auto outputs = imperative::apply(op, converted);
    size_t idx = outputs.size() - 1;
    outputs[idx] = imperative::apply(ApplyOp(*RemoveAxis::make(axis)), outputs[idx])[0];
    return outputs;
}

struct DimExpansionRuleRegistry {
    DimExpansionRuleRegistry() {
        register_dim_expansion_rules<Convolution>(conv1d_rule);
        register_dim_expansion_rules<BatchNorm>(bn1d_rule);
    }
} register_helper;

} // namespace

ValueRefList DimExpansionTransformation::apply_transformation(
        const Operator& op, Span<ValueRef> inputs) {
    if (auto apply_op = op.as<ApplyOp>()) {
        auto iter = dim_expansion_rules.find(apply_op->op().dyn_typeinfo());
        if (iter != dim_expansion_rules.end()) {
            return iter->second(apply_op->op(), inputs);
        } else {
            return imperative::apply(op, inputs);
        }
    }
    return imperative::apply(op, inputs);
}

ValueRef DimExpansionTransformation::unwrap(ValueRef value) {
    return value;
}

std::string DimExpansionTransformation::name() const {
    return "DimExpansionTransformation";
}

void DimExpansionTransformation::on_register() {
    // printf("DimExpansionTransformation has been registered\n");
}

void DimExpansionTransformation::on_unregister() noexcept {
    // printf("DimExpansionTransformation has been unregistered\n");
}

} // namespace mgb::imperative

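Taken together with the Python changes above, the intended effect is that low-dimensional inputs are expanded at dispatch time rather than inside each wrapper. An end-to-end sketch of the user-visible behavior, not part of the diff (standard megengine API assumed, shapes chosen for illustration):

import numpy as np
import megengine as mge
import megengine.functional as F
import megengine.module as M

# conv1d: the (N, C, L) input and (O, I, K) weight are lifted to 4d by the
# Convolution rule, and the result is squeezed back to (N, C, L_out).
x = mge.tensor(np.random.randn(2, 3, 16).astype("float32"))
w = mge.tensor(np.random.randn(4, 3, 5).astype("float32"))
print(F.conv1d(x, w, stride=1, padding=2).shape)  # (2, 4, 16)

# BatchNorm on a 3d input: the BatchNorm rule expands only the data tensor,
# leaves the statistics inputs alone, and restores the original output shape.
bn = M.BatchNorm1d(3)
print(bn(mge.tensor(np.random.randn(2, 3, 16).astype("float32"))).shape)  # (2, 3, 16)

Note the small trick in bn1d_rule: the axis list is reversed before RemoveAxis so the trailing axes are removed from the highest index down, keeping the remaining indices valid.
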
@@ -0,0 +1,19 @@
#pragma once

#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/value.h"

namespace mgb::imperative {

class DimExpansionTransformation final : public Transformation {
private:
public:
    ValueRefList apply_transformation(
            const Operator& op, Span<ValueRef> inputs) override;
    ValueRef unwrap(ValueRef value) override;
    std::string name() const override;
    void on_register() override;
    void on_unregister() noexcept override;
};

} // namespace mgb::imperative