1. implement double backward for batchnorm
2. fix grad attach in nested grad manager
3. pad empty tensor for unsatisfied output_has_grad
4. support double backward for jit subgraph
5. support double backward for autodiff.Function
6. readd debug flag MGE_LOG_OP_DISPATCH
GitOrigin-RevId: cd31ddc620
tags/v1.9.0
| @@ -212,10 +212,7 @@ class Function: | |||
| if self.__single_output: | |||
| outputs = (outputs,) | |||
| for grad in reversed(group): | |||
| if grad._impl is None: | |||
| continue | |||
| outputs = core2.set_grad(grad._impl, normalized_backward, args, outputs) | |||
| outputs = core2.set_grad(normalized_backward, args, outputs) | |||
| if self.__single_output: | |||
| (outputs,) = outputs | |||
| return outputs | |||
| @@ -209,7 +209,6 @@ def subgraph( | |||
| outputs = gen.send(None) | |||
| nr_outputs = len(outputs) | |||
| forward_fn = build(builder, outputs, [False] * nr_outputs) | |||
| output_grads = [builder.input() for _ in range(nr_outputs)] | |||
| input_grads = gen.send(output_grads) | |||
| assert len(input_grads) == nr_inputs | |||
| @@ -222,25 +221,49 @@ def subgraph( | |||
| ] | |||
| encoded_input_grads = [grad for grad in input_grads if grad is not None] | |||
| backward_fn = build( | |||
| builder, encoded_input_grads, [False] * len(encoded_input_grads) | |||
| builder, encoded_input_grads, [True] * len(encoded_input_grads) | |||
| ) | |||
| class SubgraphOp(Function): | |||
| def __init__(self): | |||
| self.inputs = None | |||
| self.output_shapes = None | |||
| def forward(self, *inputs): | |||
| self.inputs = inputs | |||
| return apply(forward_fn(), *inputs) | |||
| outputs = apply(forward_fn(), *inputs) | |||
| if len(outputs) > 1: | |||
| self.output_shapes = [output.shape for output in outputs] | |||
| return outputs | |||
| def backward(self, *output_grads): | |||
| inputs = self.inputs | |||
| self.inputs = None | |||
| encoded_input_grads = apply(backward_fn(), *inputs, *output_grads) | |||
| input_grads = [ | |||
| encoded_input_grads[i] if i is not None else None | |||
| for i in indices | |||
| ] | |||
| any_valid = False | |||
| all_valid = True | |||
| for output_grad in output_grads: | |||
| if output_grad is None: | |||
| all_valid = False | |||
| else: | |||
| any_valid = True | |||
| if not any_valid: | |||
| input_grads = [None] * len(indices) | |||
| else: | |||
| if not all_valid: | |||
| assert self.output_shapes is not None | |||
| from ...functional import zeros | |||
| output_grads = [ | |||
| zeros(self.output_shapes[i]) if grad is None else grad | |||
| for i, grad in enumerate(output_grads) | |||
| ] | |||
| self = None | |||
| encoded_input_grads = apply( | |||
| backward_fn(), *inputs, *output_grads | |||
| ) | |||
| input_grads = [ | |||
| encoded_input_grads[i] if i is not None else None | |||
| for i in indices | |||
| ] | |||
| return input_grads | |||
| gen.close() | |||
| @@ -896,7 +896,7 @@ def prelu(inp: Tensor, weight: Tensor) -> Tensor: | |||
| @lru_cache(maxsize=None) | |||
| def _get_leagk_relu_op(negative_slope, *, dtype=None, device=None): | |||
| def _get_leaky_relu_op(negative_slope, *, dtype=None, device=None): | |||
| @subgraph_fn( | |||
| "LeakyReLU", | |||
| dtype=dtype, | |||
| @@ -925,7 +925,7 @@ def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor: | |||
| Refer to :class:`~.LeakyReLU` for more information. | |||
| """ | |||
| leakyReLU = _get_leagk_relu_op(negative_slope, dtype=inp.dtype, device=inp.device) | |||
| leakyReLU = _get_leaky_relu_op(negative_slope, dtype=inp.dtype, device=inp.device) | |||
| (oup,) = leakyReLU(inp) | |||
| return oup | |||
| @@ -1399,7 +1399,7 @@ def _get_sync_bn_ops(device, dtype, eps_mode, ndim, channels): | |||
| f("fma3", input, inv_var_wt, | |||
| f("+", f("*", neg_channel_mean, inv_var_wt), | |||
| bias)) | |||
| return (outvar, channel_mean, channel_var, inv_var_wt), (True, False, False, False) | |||
| return (outvar, channel_mean, channel_var), (True, True, True) | |||
| @subgraph("SyncBnStage1Inference", dtype, device, 6) | |||
| def syncbn_stage1_inference(inputs, f, c): | |||
| @@ -1509,7 +1509,7 @@ def sync_batch_norm( | |||
| """ | |||
| _eps_mode = eps_mode.lower() | |||
| assert _eps_mode in {"max", "additive"}, "unknown eps_mode: {}".format(eps_mode) | |||
| if _eps_mode == "additive" and not (is_distributed() and training): | |||
| if _eps_mode == "additive" and not (is_distributed() or training): | |||
| return batch_norm( | |||
| inp, | |||
| running_mean, | |||
| @@ -121,13 +121,13 @@ void GradKeyWrapper::enter() { | |||
| m_key = m_transformation->key(); | |||
| m_key->name(m_name); | |||
| grad_key_map[m_key] = this; | |||
| TransformationManager::get_instance().register_at<TransformationManager::Grad>( | |||
| m_transformation); | |||
| m_transformation_guard = | |||
| TransformationManager::get_instance() | |||
| .register_at<TransformationManager::Grad>(m_transformation); | |||
| } | |||
| void GradKeyWrapper::exit() { | |||
| TransformationManager::get_instance().unregister<TransformationManager::Grad>( | |||
| m_transformation); | |||
| m_transformation_guard.reset(); | |||
| grad_key_map.erase(m_key); | |||
| m_key = {}; | |||
| m_transformation.reset(); | |||
| @@ -29,6 +29,7 @@ struct GradKeyWrapper : NonCopyableObj { | |||
| std::string m_name; | |||
| std::shared_ptr<GradKey> m_key; | |||
| std::shared_ptr<GradTransformation> m_transformation; | |||
| std::unique_ptr<CleanupGuard<>> m_transformation_guard; | |||
| GradKeyWrapper(); | |||
| @@ -449,15 +449,24 @@ void init_tensor(py::module m) { | |||
| interpreter::Interpreter::inst().create_channel()) | |||
| ->get(); | |||
| interpreter_for_py = channel; | |||
| transformations.register_at<Segment::Eval>( | |||
| std::make_shared<InterpreterTransformation>( | |||
| std::shared_ptr<Channel>(channel, [](Channel*) {}))); | |||
| transformations.register_at<Segment::Scalar>( | |||
| std::make_shared<ScalarTransformation>()); | |||
| transformations.register_at<Segment::DTypePromote>( | |||
| std::make_shared<DTypePromoteTransformation>()); | |||
| transformations.register_at<Segment::DimExpansion>( | |||
| std::make_shared<DimExpansionTransformation>()); | |||
| MGB_MARK_USED_VAR( | |||
| transformations | |||
| .register_at<Segment::Eval>( | |||
| std::make_shared<InterpreterTransformation>( | |||
| std::shared_ptr<Channel>(channel, [](Channel*) {}))) | |||
| .release()); | |||
| MGB_MARK_USED_VAR(transformations | |||
| .register_at<Segment::Scalar>( | |||
| std::make_shared<ScalarTransformation>()) | |||
| .release()); | |||
| MGB_MARK_USED_VAR(transformations | |||
| .register_at<Segment::DTypePromote>( | |||
| std::make_shared<DTypePromoteTransformation>()) | |||
| .release()); | |||
| MGB_MARK_USED_VAR(transformations | |||
| .register_at<Segment::DimExpansion>( | |||
| std::make_shared<DimExpansionTransformation>()) | |||
| .release()); | |||
| static py::exception<interpreter::AsyncError> py_async_error( | |||
| m, "AsyncError", PyExc_RuntimeError); | |||
| @@ -681,6 +690,9 @@ void init_tensor(py::module m) { | |||
| std::pair<size_t, std::shared_ptr<GraphProfiler>> profiler; | |||
| std::optional<TraceResult> trace_result; | |||
| std::function<bool(py::object, py::object)> array_comparator; | |||
| std::unique_ptr<CleanupGuard<>> tracing_guard; | |||
| std::unique_ptr<CleanupGuard<>> compiled_guard; | |||
| std::unique_ptr<CleanupGuard<>> lazy_eval_guard; | |||
| bool compare_value(ValueRef lhs, ValueRef rhs) { | |||
| auto lvalue = lhs.cast_ref<HostValue>(); | |||
| @@ -730,13 +742,16 @@ void init_tensor(py::module m) { | |||
| std::make_shared<GraphProfiler>(¤t_graph)); | |||
| } | |||
| } | |||
| transformations.register_at<Segment::Trace>(self.compiled); | |||
| compiled_guard = | |||
| transformations.register_at<Segment::Trace>(self.compiled); | |||
| // start execute because InputCallback depends | |||
| self.compiled->execute(); | |||
| } else if (self.tracing) { | |||
| transformations.register_at<Segment::Trace>(self.tracing); | |||
| tracing_guard = | |||
| transformations.register_at<Segment::Trace>(self.tracing); | |||
| if (self.lazy_eval) { | |||
| transformations.register_at<Segment::Eval>(self.lazy_eval); | |||
| lazy_eval_guard = | |||
| transformations.register_at<Segment::Eval>(self.lazy_eval); | |||
| } | |||
| } else { | |||
| mgb_throw(MegBrainError, "invalid state: neither tracing nor compiled"); | |||
| @@ -746,16 +761,16 @@ void init_tensor(py::module m) { | |||
| void exit() { | |||
| auto& self = *this; | |||
| if (self.tracing) { | |||
| transformations.unregister<Segment::Trace>(self.tracing); | |||
| tracing_guard.reset(); | |||
| self.trace_result = self.tracing->get_result(); | |||
| self.tracing.reset(); | |||
| if (self.lazy_eval) { | |||
| auto lazy_eval = std::move(self.lazy_eval); | |||
| transformations.unregister<Segment::Eval>(lazy_eval); | |||
| lazy_eval_guard.reset(); | |||
| lazy_eval->check_exception(); | |||
| } | |||
| } else if (self.compiled) { | |||
| transformations.unregister<Segment::Trace>(self.compiled); | |||
| compiled_guard.reset(); | |||
| self.compiled->wait(); | |||
| } else { | |||
| mgb_throw(MegBrainError, "invalid state: neither tracing nor compiled"); | |||
| @@ -829,17 +844,19 @@ void init_tensor(py::module m) { | |||
| [](Trace& self) { | |||
| mgb_assert(bool(self.tracing) ^ bool(self.compiled)); | |||
| if (self.tracing) { | |||
| transformations.unregister<Segment::Trace>(self.tracing); | |||
| self.tracing_guard.reset(); | |||
| } else if (self.compiled) { | |||
| transformations.unregister<Segment::Trace>(self.compiled); | |||
| self.compiled_guard.reset(); | |||
| } | |||
| }) | |||
| .def("end_excluded_region", [](Trace& self) { | |||
| mgb_assert(bool(self.tracing) ^ bool(self.compiled)); | |||
| if (self.tracing) { | |||
| transformations.register_at<Segment::Trace>(self.tracing); | |||
| self.tracing_guard = | |||
| transformations.register_at<Segment::Trace>(self.tracing); | |||
| } else if (self.compiled) { | |||
| transformations.register_at<Segment::Trace>(self.compiled); | |||
| self.compiled_guard = | |||
| transformations.register_at<Segment::Trace>(self.compiled); | |||
| } | |||
| }); | |||
| @@ -900,11 +917,8 @@ void init_tensor(py::module m) { | |||
| GradKeyWrapper::get(output.cast<GradKeyValue>()))); | |||
| }); | |||
| m.def("set_grad", [](py::object py_key, py::function backward_fn, | |||
| std::vector<py::object> inputs, | |||
| m.def("set_grad", [](py::function backward_fn, std::vector<py::object> inputs, | |||
| std::vector<py::object> outputs) { | |||
| mgb_assert(GradKeyWrapper::wrap_t::type().isinstance(py_key.ptr())); | |||
| auto* key = reinterpret_cast<GradKeyWrapper::wrap_t*>(py_key.ptr())->inst(); | |||
| GenericFunction generic_backward_fn = | |||
| [backward_fn](Span<ValueRef> output_grads) -> ValueRefList { | |||
| py::list output_grad_tws; | |||
| @@ -937,8 +951,8 @@ void init_tensor(py::module m) { | |||
| values[i + inputs.size()] = | |||
| outputs[i].cast<TensorWrapper>().m_tensor->data(); | |||
| } | |||
| auto wrapped_output_values = imperative::apply( | |||
| SetGrad(key->m_key, generic_backward_fn, inputs.size()), values); | |||
| auto wrapped_output_values = | |||
| imperative::apply(SetGrad(generic_backward_fn, inputs.size()), values); | |||
| std::vector<py::object> wrapped_outputs; | |||
| mgb_assert(wrapped_output_values.size() == outputs.size()); | |||
| for (auto&& output_value : wrapped_output_values) { | |||
| @@ -956,8 +970,10 @@ void init_tensor(py::module m) { | |||
| mgb_assert(module_trace_hook); | |||
| module_trace_transformation = | |||
| std::make_shared<ModuleTraceTransformation>(module_trace_hook); | |||
| transformations.register_at<Segment::ModuleTrace>( | |||
| module_trace_transformation); | |||
| MGB_MARK_USED_VAR(transformations | |||
| .register_at<Segment::ModuleTrace>( | |||
| module_trace_transformation) | |||
| .release()); | |||
| } | |||
| return module_trace_transformation; | |||
| }; | |||
| @@ -18,11 +18,13 @@ | |||
| #include "megbrain/imperative/dispatch.h" | |||
| #include "megbrain/imperative/transformation.h" | |||
| #include "megbrain/imperative/utils/helper.h" | |||
| #include "megbrain/imperative/value.h" | |||
| #include "megbrain/utils/small_vector.h" | |||
| namespace mgb::imperative::python { | |||
| struct TransformationManager { | |||
| public: | |||
| enum Segment { | |||
| ModuleTrace, | |||
| DTypePromote, | |||
| @@ -35,8 +37,21 @@ struct TransformationManager { | |||
| std::array<std::vector<std::shared_ptr<Transformation>>, 7> segments; | |||
| private: | |||
| template <Segment segment> | |||
| void unregister(std::shared_ptr<Transformation> transformation) noexcept { | |||
| mgb_assert(segment < segments.size()); | |||
| auto iter = std::find( | |||
| segments[segment].begin(), segments[segment].end(), transformation); | |||
| mgb_assert(iter != segments[segment].end()); | |||
| transformation->unregister(); | |||
| segments[segment].erase(iter); | |||
| } | |||
| public: | |||
| template <Segment segment> | |||
| void register_at(std::shared_ptr<Transformation> transformation) { | |||
| [[nodiscard]] std::unique_ptr<CleanupGuard<>> register_at( | |||
| std::shared_ptr<Transformation> transformation) { | |||
| mgb_assert(segment < segments.size()); | |||
| std::shared_ptr<Transformation> next; | |||
| for (size_t i = segment; i < segments.size(); ++i) { | |||
| @@ -51,16 +66,8 @@ struct TransformationManager { | |||
| transformation->register_at(next->pos()); | |||
| } | |||
| segments[segment].push_back(transformation); | |||
| } | |||
| template <Segment segment> | |||
| void unregister(std::shared_ptr<Transformation> transformation) noexcept { | |||
| mgb_assert(segment < segments.size()); | |||
| auto iter = std::find( | |||
| segments[segment].begin(), segments[segment].end(), transformation); | |||
| mgb_assert(iter != segments[segment].end()); | |||
| transformation->unregister(); | |||
| segments[segment].erase(iter); | |||
| return std::make_unique<CleanupGuard<>>( | |||
| [this, transformation]() { unregister<segment>(transformation); }); | |||
| } | |||
| static TransformationManager& get_instance() { | |||
| @@ -452,6 +452,8 @@ def test_2nd_grad_with_custom_gradient(): | |||
| return y | |||
| def backward(self, dy): | |||
| if dy is None: | |||
| return None | |||
| dx = -MySin()(self.inp) * dy | |||
| return dx | |||
| @@ -14,6 +14,7 @@ import pytest | |||
| import megengine as mge | |||
| import megengine.distributed as dist | |||
| import megengine.functional as F | |||
| import megengine.module as M | |||
| from megengine.core._imperative_rt import CompNode, TensorAttr, imperative | |||
| from megengine.core._imperative_rt.core2 import TensorWeakRef, apply, sync | |||
| from megengine.core.autodiff.grad import Grad | |||
| @@ -318,3 +318,41 @@ def test_throw_on_non_tensor_argument(): | |||
| func = NonTensorArg() | |||
| with pytest.raises(TypeError, match=r"op .* expect type Tensor as inputs"): | |||
| func(x, 1) | |||
| def test_multiple_grad(): | |||
| data_shape = (9, 2, 6) | |||
| av = np.random.random(data_shape).astype(np.float32) | |||
| class MulFunc(Function): | |||
| def forward(self, a): | |||
| self.a = a | |||
| return a * 10 | |||
| def backward(self, grad_o): | |||
| return grad_o * 20 | |||
| class Simple(Module): | |||
| def __init__(self, a): | |||
| super().__init__() | |||
| self.a = Parameter(a, dtype=np.float32) | |||
| self.layer1 = MulFunc() | |||
| def forward(self): | |||
| x = self.layer1(self.a) | |||
| return x | |||
| net = Simple(av) | |||
| gm = ad.GradManager().attach(net.parameters()) | |||
| gm2 = ad.GradManager().attach(net.parameters()) | |||
| opt = optimizer.SGD(net.parameters(), lr=1.0) | |||
| opt.clear_grad() | |||
| with gm: | |||
| with gm2: | |||
| loss = net() | |||
| gm.backward(loss.sum()) | |||
| opt.step() | |||
| np.testing.assert_almost_equal(loss.numpy(), (av * 10)) | |||
| np.testing.assert_almost_equal(net.a.numpy(), (av - 20)) | |||
| @@ -109,3 +109,46 @@ def test_subgraph(device, batch_size, channels, use_trace, symbolic, gopt_level, | |||
| _assert_allclose(out1.numpy(), out2.numpy()) | |||
| _assert_allclose(grad1.numpy(), grad2.numpy()) | |||
| @functools.lru_cache(maxsize=None) | |||
| def _get_mul_fn(dtype, device): | |||
| @subgraph_fn( | |||
| "Mul", | |||
| dtype=dtype, | |||
| device=device, | |||
| nr_inputs=2, | |||
| gopt_level=None, | |||
| jit_fusion=False, | |||
| custom_grad=True, | |||
| ) | |||
| def mul(inputs, f, c): | |||
| x, y = inputs[0:2] | |||
| z = f("*", x, y) | |||
| (dz,) = yield (z,) | |||
| dx = f("*", dz, y) | |||
| dy = f("*", dz, x) | |||
| yield (dx, dy) | |||
| return mul | |||
| def test_subgraph_jit_backward(): | |||
| x_np = np.random.rand(3, 4, 5).astype("float32") | |||
| x1 = megengine.Tensor(x_np) | |||
| x2 = megengine.Tensor(x_np) | |||
| mul = _get_mul_fn(x1.dtype, x1.device) | |||
| gm = GradManager() | |||
| gm.attach([x1, x2]) | |||
| with gm: | |||
| y1 = x1 * x1 | |||
| y2 = mul(x2, x2) | |||
| gm.backward(y1) | |||
| with gm: | |||
| y1 = x1 * x1 | |||
| y2 = mul(x2, x2) | |||
| gm.backward(y1 + y2) | |||
| with gm: | |||
| y1 = x1 * x1 | |||
| y2 = mul(x2, x2) | |||
| gm.backward(y2) | |||
| @@ -18,18 +18,44 @@ | |||
| namespace mgb { | |||
| namespace imperative { | |||
| namespace { | |||
| ValueRefList apply(const Operator& op, Span<ValueRef> inputs) { | |||
| ValueRefList apply_release(const Operator& op, Span<ValueRef> inputs) { | |||
| auto& context = Transformation::get_context(); | |||
| size_t& depth = context.next_transformation; | |||
| mgb_assert(depth < context.transformations.size()); | |||
| auto& transformation = *context.transformations[depth++]; | |||
| CleanupGuard _{[&] { --depth; }}; | |||
| return transformation.apply_transformation(op, inputs); | |||
| } | |||
| MGB_NOINLINE ValueRefList apply_debug(const Operator& op, Span<ValueRef> inputs) { | |||
| auto& context = Transformation::get_context(); | |||
| size_t& depth = context.next_transformation; | |||
| // TODO: add fallback transformation | |||
| bool fallback = depth >= context.transformations.size(); | |||
| if (mgb_unlikely(fallback)) { | |||
| return op.fallback(inputs); | |||
| mgb_assert(depth < context.transformations.size()); | |||
| static const char tabs[] = "\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t"; | |||
| const char* prefix = tabs + (sizeof(tabs) / sizeof(char)) - depth - 1; | |||
| mgb_log_debug( | |||
| "%s apply %s to %s", prefix, op.to_string().c_str(), | |||
| imperative::to_string(inputs).c_str()); | |||
| ValueRefList result; | |||
| auto& transformation = *context.transformations[depth++]; | |||
| CleanupGuard _{[&] { --depth; }}; | |||
| result = transformation.apply_transformation(op, inputs); | |||
| mgb_log_debug( | |||
| "%s returns %s", prefix, | |||
| imperative::to_string(Span<ValueRef>(result)).c_str()); | |||
| return result; | |||
| } | |||
| } // namespace | |||
| ValueRefList apply(const Operator& op, Span<ValueRef> inputs) { | |||
| static bool debug = MGB_GETENV("MGE_LOG_OP_DISPATCH"); | |||
| if (mgb_unlikely(debug)) { | |||
| return apply_debug(op, inputs); | |||
| } else { | |||
| auto& transformation = *context.transformations[depth++]; | |||
| CleanupGuard _{[&] { --depth; }}; | |||
| return transformation.apply_transformation(op, inputs); | |||
| return apply_release(op, inputs); | |||
| } | |||
| } | |||
| @@ -106,7 +106,8 @@ EncodedSubgraph OpDef::make_forward_graph( | |||
| } | |||
| std::string OpDef::to_string() const { | |||
| std::string builder = trait()->make_name(*this) + "{"; | |||
| std::string builder = trait()->name; | |||
| builder += "{"; | |||
| for (auto&& [name, value] : props(*this)) { | |||
| builder += name; | |||
| builder += ": "; | |||
| @@ -196,7 +197,7 @@ std::string Subgraph::repr() const { | |||
| if (auto* p = op->try_cast_final<OprAttr>()) { | |||
| buf << p->type; | |||
| } else { | |||
| buf << op->make_name(); | |||
| buf << op->to_string(); | |||
| } | |||
| for (size_t i : ins) { | |||
| buf << " "; | |||
| @@ -11,13 +11,94 @@ | |||
| #include "megbrain/opr/dnn/batch_norm.h" | |||
| #include "../op_trait.h" | |||
| #include "megbrain/imperative/graph_builder.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| #include "megbrain/imperative/ops/utility.h" | |||
| #include "megbrain/imperative/proxy_graph_detail.h" | |||
| #include "megbrain/imperative/subgraph_detail.h" | |||
| #include "megbrain/tensor.h" | |||
| namespace mgb { | |||
| namespace imperative { | |||
| namespace { | |||
| EncodedSubgraph generate_batchnorm_backward_graph(DType dtype, CompNode device) { | |||
| Subgraph::Builder<LogicalTensorDesc> builder{ | |||
| [](std::shared_ptr<OpDef> op, SmallVector<LogicalTensorDesc> inputs, | |||
| size_t nr_outputs) { | |||
| auto [outputs, validated] = | |||
| OpDef::infer_output_attrs_fallible(*op, inputs); | |||
| mgb_assert(outputs.size() == nr_outputs, "nr_outputs mismatch"); | |||
| return outputs; | |||
| }}; | |||
| auto f = [&](auto&& op, auto... args) { | |||
| return builder.write_expr( | |||
| op, Subgraph::vars_t({(Subgraph::var_t)args...}), 1)[0]; | |||
| }; | |||
| auto prod = Reduce::make(megdnn::param::Reduce(Reduce::Mode::PRODUCT, 0)); | |||
| auto sum = Reduce::make(megdnn::param::Reduce(Reduce::Mode::SUM)); | |||
| auto sub = Elemwise::make(Elemwise::Mode::SUB); | |||
| auto mul = Elemwise::make(Elemwise::Mode::MUL); | |||
| auto div = Elemwise::make(Elemwise::Mode::TRUE_DIV); | |||
| auto floor_div = Elemwise::make(Elemwise::Mode::FLOOR_DIV); | |||
| auto broadcast = Broadcast::make(); | |||
| auto c = [&](TensorPtr tensor, DType dtype) { | |||
| auto result = builder.write_constant( | |||
| tensor, {TensorLayout{tensor->dtype()}, tensor->comp_node()}); | |||
| if (tensor->dtype() != dtype) { | |||
| result = f(TypeCvt::make(dtype), result); | |||
| } | |||
| return result; | |||
| }; | |||
| auto ci = [&](megdnn::dt_int32 value) { | |||
| return c(Tensor::make_scalar(DTypeScalar(value), device), dtype::Int32()); | |||
| }; | |||
| auto cf = [&](megdnn::dt_float32 value) { | |||
| return c(Tensor::make_scalar(DTypeScalar(value), device), dtype); | |||
| }; | |||
| auto desc = LogicalTensorDesc{TensorLayout{dtype}, device}; | |||
| auto x = builder.write_input(desc); | |||
| auto y_grad = builder.write_input(desc); | |||
| auto save_mean = builder.write_input(desc); | |||
| auto save_invstd = builder.write_input(desc); | |||
| auto weight = builder.write_input(desc); | |||
| auto reserved = builder.write_input(desc); | |||
| MGB_MARK_USED_VAR(reserved); | |||
| // assert x.ndim == 4 | |||
| auto input_shape = f(GetVarShape::make(), x); | |||
| auto channels = f(GetVarShape::make(1), x); | |||
| auto reduce_shape = f(Concat::make(0, device), ci(1), channels, ci(1), ci(1)); | |||
| auto input_elems = f(prod, input_shape); | |||
| auto reduce_size = f(floor_div, input_elems, channels); | |||
| auto reduce_size_f = f(TypeCvt::make(dtype), reduce_size); | |||
| auto mean = f(broadcast, save_mean, input_shape); | |||
| auto invstd = save_invstd; | |||
| auto norm = f(div, cf(1), reduce_size_f); | |||
| auto output_grad_sum = f(sum, y_grad, reduce_shape); | |||
| auto dot_p = f(sum, f(mul, y_grad, f(sub, x, mean)), reduce_shape); | |||
| auto mean_grad = f(broadcast, f(mul, output_grad_sum, norm), input_shape); | |||
| auto proj_scale = | |||
| f(broadcast, f(mul, f(mul, dot_p, norm), f(mul, invstd, invstd)), | |||
| input_shape); | |||
| auto grad_scale = f( | |||
| mul, f(broadcast, invstd, input_shape), f(broadcast, weight, input_shape)); | |||
| auto proj = f(mul, f(sub, x, mean), proj_scale); | |||
| auto x_grad = f(mul, f(sub, f(sub, y_grad, proj), mean_grad), grad_scale); | |||
| auto weight_grad = f(mul, dot_p, invstd); | |||
| auto bias_grad = output_grad_sum; | |||
| builder.add_outputs({weight_grad, bias_grad, x_grad}); | |||
| auto bn_backward = builder.encode(); | |||
| return bn_backward; | |||
| } | |||
| namespace bn { | |||
| std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | |||
| auto* node = &node_->cast_final_safe<opr::BatchNorm>(); | |||
| return BatchNorm::make(node->param()); | |||
| @@ -72,8 +153,60 @@ OP_TRAIT_REG(BatchNorm, BatchNorm, opr::BatchNorm) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .infer_output_attrs_fallible(infer_output_attrs_fallible) | |||
| .fallback(); | |||
| } // anonymous namespace | |||
| } // namespace bn | |||
| namespace bn_backward { | |||
| std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | |||
| auto* node = &node_->cast_final_safe<opr::BatchNormBackward>(); | |||
| return BatchNormBackward::make(node->param()); | |||
| } | |||
| VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto& op = def.cast_final_safe<BatchNormBackward>(); | |||
| cg::SymbolVar x, y_grad, save_mean, save_variance, weight, reserve; | |||
| x = inputs[0]; | |||
| y_grad = inputs[1]; | |||
| save_mean = inputs[2]; | |||
| save_variance = inputs[3]; | |||
| weight = inputs[4]; | |||
| if (inputs.size() == 6) { | |||
| reserve = inputs[5]; | |||
| } | |||
| return opr::BatchNormBackward::make( | |||
| x, y_grad, save_mean, save_variance, weight, reserve, op.param())[0] | |||
| .node() | |||
| ->owner_opr() | |||
| ->usable_output(); | |||
| } | |||
| EncodedSubgraph make_backward_graph( | |||
| const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs, | |||
| const SmallVector<bool>& input_requires_grad, | |||
| const SmallVector<bool>& output_has_grad) { | |||
| def.cast_final_safe<BatchNormBackward>(); | |||
| size_t nr_inputs = 6; | |||
| size_t nr_outputs = 3; | |||
| mgb_assert(inputs.size() == nr_inputs); | |||
| mgb_assert(input_requires_grad.size() == nr_inputs); | |||
| mgb_assert(output_has_grad.size() == nr_outputs); | |||
| auto dtype = inputs[0].layout.dtype; | |||
| auto device = inputs[0].comp_node; | |||
| auto bn_backward = generate_batchnorm_backward_graph(dtype, device); | |||
| auto bn_double_backward = subgraph_detail::make_backward_graph_from_forward( | |||
| bn_backward, inputs, input_requires_grad, output_has_grad); | |||
| return bn_double_backward; | |||
| } | |||
| OP_TRAIT_REG(BatchNormBackward, BatchNormBackward, opr::BatchNormBackward) | |||
| .make_from_op_node(make_from_op_node) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .make_backward_graph(make_backward_graph) | |||
| .fallback(); | |||
| } // namespace bn_backward | |||
| } // anonymous namespace | |||
| } // namespace imperative | |||
| } // namespace mgb | |||
| @@ -762,7 +762,9 @@ EncodedSubgraph make_backward_graph( | |||
| const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs, | |||
| const SmallVector<bool>& input_requires_grad, | |||
| const SmallVector<bool>& output_has_grad) { | |||
| return {}; | |||
| return OpDef::make_backward_graph( | |||
| *def.cast_final_safe<JITFusionOp>().op, inputs, input_requires_grad, | |||
| output_has_grad); | |||
| } | |||
| OP_TRAIT_REG(JITFusionOp, JITFusionOp) | |||
| @@ -96,10 +96,11 @@ SmallVector<LayoutConstraintCallback> get_input_layout_constraint( | |||
| return res; | |||
| } | |||
| static EncodedSubgraph make_backward_graph_from_forward( | |||
| EncodedSubgraph make_backward_graph_from_forward( | |||
| const EncodedSubgraph& forward_graph, | |||
| const SmallVector<LogicalTensorDesc>& inputs, | |||
| const SmallVector<bool>& input_requires_grad, | |||
| const SmallVector<bool>& output_has_grad, EncodedSubgraph forward_graph) { | |||
| const SmallVector<bool>& output_has_grad) { | |||
| using namespace std::placeholders; | |||
| using var_t = Subgraph::var_t; | |||
| using vars_t = Subgraph::vars_t; | |||
| @@ -179,7 +180,7 @@ EncodedSubgraph make_backward_graph( | |||
| const SmallVector<bool>& output_has_grad) { | |||
| auto forward_graph = OpDef::make_forward_graph(def, inputs); | |||
| return make_backward_graph_from_forward( | |||
| inputs, input_requires_grad, output_has_grad, forward_graph); | |||
| forward_graph, inputs, input_requires_grad, output_has_grad); | |||
| } | |||
| } // namespace subgraph_detail | |||
| @@ -139,7 +139,7 @@ ValueRefList InterpreterTransformation::apply_transformation( | |||
| return {ValueRef()}; | |||
| } | |||
| } else { | |||
| return imperative::apply(op, inputs); | |||
| return op.fallback(inputs); | |||
| } | |||
| } | |||
| @@ -62,7 +62,8 @@ BackwardGraphWithClosure::BackwardGraphWithClosure( | |||
| std::shared_ptr<OpDef> op, Span<ValueRef> inputs, Span<ValueRef> outputs) | |||
| : backward_graph(backward_graph), | |||
| output_mask_offset(inputs.size()), | |||
| grad_mask_offset(inputs.size() + outputs.size()) { | |||
| grad_mask_offset(inputs.size() + outputs.size()), | |||
| op(op) { | |||
| auto& save_for_backward = backward_graph->save_for_backward; | |||
| mgb_assert(save_for_backward.size() == inputs.size() + 2 * outputs.size()); | |||
| size_t count = std::count_if( | |||
| @@ -92,6 +93,13 @@ BackwardGraphWithClosure::BackwardGraphWithClosure( | |||
| closure.push_back(outputs[i]); | |||
| } | |||
| } | |||
| if (outputs.size() > 1) { | |||
| output_descs.reserve(outputs.size()); | |||
| for (auto&& output : outputs) { | |||
| auto symbolic_shape = imperative::apply(*GetVarShape::make(), output)[0]; | |||
| output_descs.push_back({symbolic_shape, output.dtype(), output.device()}); | |||
| } | |||
| } | |||
| } | |||
| void BackwardGraphWithClosure::operator()( | |||
| Span<ValueRef> grads, std::function<void(size_t, ValueRef)> receiver) { | |||
| @@ -100,23 +108,46 @@ void BackwardGraphWithClosure::operator()( | |||
| for (auto&& value : closure) { | |||
| args[nargs++] = value; | |||
| } | |||
| bool null_grad = false; | |||
| size_t null_grad = 0; | |||
| size_t valid_grad = 0; | |||
| for (size_t i = 0; i < grads.size(); ++i) { | |||
| if (backward_graph->save_for_backward[grad_mask_offset + i]) { | |||
| if (grads[i]) { | |||
| mgb_assert(!null_grad, "null_grad"); | |||
| valid_grad++; | |||
| args[nargs++] = grads[i]; | |||
| } else { | |||
| null_grad = true; | |||
| null_grad++; | |||
| nargs++; | |||
| } | |||
| } | |||
| } | |||
| if (null_grad) { | |||
| if (valid_grad == 0) { | |||
| return; | |||
| } | |||
| auto igrads_ = imperative::apply(backward_graph->backward, Span(args, nargs)); | |||
| SmallVector<ValueRef> igrads = {igrads_.begin(), igrads_.end()}; | |||
| igrads_.clear(); | |||
| if (null_grad > 0) { | |||
| auto zeros_like = [](const OutputDesc& desc) { | |||
| HostTensorStorage storage(*desc.device); | |||
| storage.ensure_size(desc.dtype->size()); | |||
| std::memset(storage.ptr(), 0, desc.dtype->size()); | |||
| auto t = imperative::apply( | |||
| CreateTensor( | |||
| CreateTensor::Unique, *desc.device, *desc.dtype, | |||
| ValueShape()), | |||
| HostStorage::make(storage))[0]; | |||
| auto res = imperative::apply(*Broadcast::make(), t, desc.shape)[0]; | |||
| return res; | |||
| }; | |||
| nargs = closure.size(); | |||
| for (size_t i = 0; i < grads.size(); ++i) { | |||
| if (backward_graph->save_for_backward[grad_mask_offset + i]) { | |||
| if (!grads[i]) { | |||
| args[nargs] = zeros_like(output_descs[i]); | |||
| } | |||
| nargs++; | |||
| } | |||
| } | |||
| } | |||
| auto igrads = imperative::apply(backward_graph->backward, Span(args, nargs)); | |||
| auto&& iter = igrads.begin(); | |||
| for (auto [i, p] : ranges::views::enumerate(backward_graph->input_has_grad)) { | |||
| if (p) { | |||
| @@ -221,9 +252,11 @@ void GradKey::backward() { | |||
| if (!dest) { | |||
| continue; | |||
| } | |||
| if (!dest.m_producer_record.next && dest->callback && dest->m_grad) { | |||
| if (!dest.m_producer_record.next && dest->callback) { | |||
| // I'm the last grad producer, invoke callback | |||
| dest->callback(dest->m_grad); | |||
| if (dest->m_grad) { | |||
| dest->callback(dest->m_grad); | |||
| } | |||
| } | |||
| } | |||
| grad_fn->clear(); | |||
| @@ -394,16 +427,22 @@ ValueRefList GradTransformation::apply_transformation( | |||
| return imperative::apply(op, inputs); | |||
| } | |||
| if (auto* attach_grad = op.as<AttachGrad>()) { | |||
| if (!has_key(attach_grad->key())) { | |||
| auto& tensor = inputs[0]; | |||
| if (auto&& grad_value = tensor.as_ref(m_value_type)) { | |||
| mgb_assert(!has_key(attach_grad->key())); | |||
| auto output = fallback()[0]; | |||
| return record_grad(m_value_type.make(output, m_key, grad_value->slot())); | |||
| } else if (!has_key(attach_grad->key())) { | |||
| return fallback(); | |||
| } else { | |||
| GenericFunction callback = | |||
| (GenericFunction&)inputs[1].cast<FunctionValue>(); | |||
| auto output = attach_grad->key()->attach(tensor, [callback](ValueRef grad) { | |||
| auto ret = callback({&grad, 1}); | |||
| assert(ret.empty()); | |||
| }); | |||
| return {record_grad(output)}; | |||
| } | |||
| auto tensor = inputs[0]; | |||
| GenericFunction callback = (GenericFunction&)inputs[1].cast<FunctionValue>(); | |||
| auto output = attach_grad->key()->attach(tensor, [callback](ValueRef grad) { | |||
| auto ret = callback({&grad, 1}); | |||
| assert(ret.empty()); | |||
| }); | |||
| return {record_grad(output)}; | |||
| } else if (auto* grad_backward = op.as<GradBackward>()) { | |||
| if (!has_key(grad_backward->key())) { | |||
| return fallback(); | |||
| @@ -431,10 +470,10 @@ ValueRefList GradTransformation::apply_transformation( | |||
| mgb_assert(inputs.size() > nr_inputs); | |||
| size_t nr_outputs = inputs.size() - nr_inputs; | |||
| Span<ValueRef> inputs_ = {inputs.data(), nr_inputs}; | |||
| Span<ValueRef> outputs_ = {inputs.data() + nr_inputs, nr_outputs}; | |||
| backward.m_input_has_grad = SmallVector(nr_inputs, true); | |||
| backward.m_output_attrs = | |||
| SmallVector(nr_outputs, CustomBackward::OutputAttr{true, true}); | |||
| auto outputs_ = fallback(); | |||
| backward.m_input_has_grad.resize(nr_inputs, true); | |||
| backward.m_output_attrs.resize( | |||
| nr_outputs, CustomBackward::OutputAttr{true, true}); | |||
| backward.m_backward = [fn = set_grad->grad_fn()](Span<ValueRef> inputs) { | |||
| auto result = fn(inputs); | |||
| return SmallVector<ValueRef>(result.begin(), result.end()); | |||
| @@ -31,6 +31,7 @@ class Subgraph::Builder { | |||
| using infer_fn_t = std::function<descs_t(op_t, descs_t, size_t)>; | |||
| using encoded_graph_t = EncodedSubgraph; | |||
| using var_map_t = std::unordered_map<var_t, var_t>; | |||
| using mask_t = SmallVector<bool>; | |||
| vars_t m_inputs; | |||
| SmallVector<std::pair<var_t, TensorPtr>> m_constants; | |||
| vars_t m_outputs; | |||
| @@ -94,6 +95,7 @@ public: | |||
| descs_t get_descs(vars_t vars) { | |||
| descs_t descs; | |||
| for (auto&& var : vars) { | |||
| mgb_assert(var, "invalid var"); | |||
| descs.push_back(get_desc(var)); | |||
| } | |||
| return descs; | |||
| @@ -128,4 +130,4 @@ public: | |||
| expr_iter_t end() { return m_exprs.end(); } | |||
| }; | |||
| } // namespace imperative | |||
| } // namespace mgb | |||
| } // namespace mgb | |||
| @@ -38,7 +38,6 @@ struct ShapeInfer final : OpDefImplBase<ShapeInfer> { | |||
| std::shared_ptr<OpDef> op; | |||
| SmallVector<CompNode> devices; | |||
| SmallVector<DType> dtypes; | |||
| EncodedSubgraph graph; | |||
| ShapeInfer() = default; | |||
| ShapeInfer( | |||
| std::shared_ptr<OpDef> op, SmallVector<CompNode> devices, | |||
| @@ -39,6 +39,11 @@ EncodedSubgraph make_backward_graph( | |||
| SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs); | |||
| EncodedSubgraph make_backward_graph_from_forward( | |||
| const EncodedSubgraph& forward, const SmallVector<LogicalTensorDesc>& inputs, | |||
| const SmallVector<bool>& input_requires_grad, | |||
| const SmallVector<bool>& output_has_grad); | |||
| } // namespace subgraph_detail | |||
| } // namespace imperative | |||
| } // namespace mgb | |||
| } // namespace mgb | |||
| @@ -29,6 +29,15 @@ struct BackwardGraphWithClosure { | |||
| SmallVector<ValueRef> closure; | |||
| size_t output_mask_offset; | |||
| size_t grad_mask_offset; | |||
| std::shared_ptr<OpDef> op; | |||
| struct OutputDesc { | |||
| ValueRef shape; | |||
| DTypeValue::ref_t dtype; | |||
| CompNodeValue::ref_t device; | |||
| }; | |||
| SmallVector<OutputDesc> output_descs; | |||
| BackwardGraphWithClosure( | |||
| std::shared_ptr<OptimizedBackwardGraphResult> backward_graph, | |||
| @@ -356,20 +365,22 @@ public: | |||
| class SetGrad : public OperatorImpl<SetGrad> { | |||
| private: | |||
| std::shared_ptr<GradKey> m_key; | |||
| GenericFunction m_grad_fn; | |||
| size_t m_nr_inputs; | |||
| public: | |||
| SetGrad(std::shared_ptr<GradKey> key, GenericFunction grad_fn, size_t nr_inputs) | |||
| : m_key(key), m_grad_fn(grad_fn), m_nr_inputs(nr_inputs) {} | |||
| SetGrad(GenericFunction grad_fn, size_t nr_inputs) | |||
| : m_grad_fn(grad_fn), m_nr_inputs(nr_inputs) {} | |||
| GenericFunction grad_fn() const { return m_grad_fn; } | |||
| size_t nr_inputs() const { return m_nr_inputs; } | |||
| std::string to_string() const override { | |||
| return ssprintf("SetGradValue{key=%s}", m_key->name().c_str()); | |||
| std::string to_string() const override { return ssprintf("SetGradValue{}"); } | |||
| ValueRefList fallback(Span<ValueRef> inputs) const override { | |||
| auto outputs = inputs.sub(m_nr_inputs, inputs.size() - m_nr_inputs); | |||
| return {outputs.begin(), outputs.end()}; | |||
| } | |||
| }; | |||
| @@ -15,12 +15,14 @@ | |||
| #include <memory> | |||
| #include <sstream> | |||
| #include "megbrain/utils/metahelper.h" | |||
| namespace mgb { | |||
| namespace imperative { | |||
| template <typename T> | |||
| class CleanupGuard { | |||
| template <typename T = std::function<void()>> | |||
| class CleanupGuard : public NonCopyableObj { | |||
| private: | |||
| T m_callback; | |||
| @@ -37,4 +39,4 @@ inline std::string quoted(std::string str) { | |||
| } // namespace imperative | |||
| } // namespace mgb | |||
| } // namespace mgb | |||
| @@ -89,6 +89,8 @@ def SlidingWindowTranspose : MgbHashableOp<"SlidingWindowTranspose", [SlidingWin | |||
| def BatchNorm : MgbHashableOp<"BatchNorm", [BNParam]>; | |||
| def BatchNormBackward : MgbHashableOp<"BatchNormBackward", [BNParam]>; | |||
| def ROIAlign: MgbHashableOp<"ROIAlign", [ROIAlignParam]>; | |||
| def Correlation: MgbHashableOp<"Correlation", [CorrelationParam]>; | |||