/**
 * \file imperative/python/src/grad_override.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "./grad.h"
#include "megbrain/imperative/ops/autogen.h"

namespace mgb::imperative::python {
namespace {

// Returns the shape of x as a tensor, via the GetVarShape op.
std::shared_ptr<Tensor> get_shape(Tensor* x) {
    static auto op = GetVarShape::make();
    return python::apply(op, x)[0];
}

// Reduces x to the shape described by tensor s; used to undo broadcasting
// when propagating gradients back to the inputs.
std::shared_ptr<Tensor> reduce_to(Tensor* x, Tensor* s) {
    static auto op = Reduce::make();
    return python::apply(op, x, s)[0];
}

// Custom gradient rule for Elemwise. Only ADD is handled here: the output
// gradient is reduced back to each input's original shape. All other modes
// fall back to the generic autodiff path via GradRuleFallback.
apply_result_t elemwise_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto& op = ctx.op->cast_final_safe<Elemwise>();
    if (op.mode == Elemwise::Mode::ADD) {
        mgb_assert(ctx.nargs == 2);
        std::array<std::shared_ptr<Tensor>, 2> input_shapes;
        for (size_t i = 0; i < 2; ++i) {
            if (input_requires_grad(ctx, i)) {
                input_shapes[i] = get_shape(ctx.args[i]);
            }
        }
        maker.output_size(1).output_captured(0, false);
        maker.backward([shapes=std::move(input_shapes)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
            mgb_assert(ngrads == 1);
            Tensor* grad = grads[0];
            apply_result_t ret(2);
            for (size_t i = 0; i < 2; ++i) {
                if (shapes[i]) {
                    ret[i] = reduce_to(grad, shapes[i].get());
                }
            }
            return ret;
        });
        return apply(ctx);
    }
    throw GradRuleFallback();
}

// Registers the Elemwise rule in the grad rule registry at static
// initialization time.
struct Init {
    Init() {
        auto& reg = grad_rule_registry();
        reg.emplace(Elemwise::typeinfo(), elemwise_grad_rule);
    }
} _;

} // namespace
} // namespace mgb::imperative::python
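
/*
 * Illustration only (not part of the original source): other elementwise modes
 * could be handled by adding further branches to elemwise_grad_rule before the
 * GradRuleFallback is thrown. A hypothetical branch for Elemwise::Mode::NEGATE,
 * which assumes Elemwise::make(mode) from the autogen ops and negates the
 * incoming gradient, might look like:
 *
 *     if (op.mode == Elemwise::Mode::NEGATE) {
 *         mgb_assert(ctx.nargs == 1);
 *         maker.output_size(1).output_captured(0, false);
 *         maker.backward([](BackwardContext&, Tensor*const* grads, size_t ngrads) {
 *             mgb_assert(ngrads == 1);
 *             apply_result_t ret(1);
 *             if (grads[0]) {
 *                 static auto neg = Elemwise::make(Elemwise::Mode::NEGATE);
 *                 ret[0] = python::apply(neg, grads[0])[0];
 *             }
 *             return ret;
 *         });
 *         return apply(ctx);
 *     }
 */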