precompute ops in forward to reduce saved tensor size
GitOrigin-RevId: d67043ba82
tags/v1.2.0
| @@ -11,6 +11,7 @@ | |||
| #include "./grad.h" | |||
| #include "megbrain/imperative/proxy_graph_detail.h" | |||
| #include "megbrain/imperative/backward_graph_opt.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| #include "megbrain/imperative/ops/utility.h" | |||
| #include "megbrain/utils/mempool.h" | |||
| @@ -32,14 +33,14 @@ struct GradSlotWeakPtr { | |||
| size_t idx; | |||
| }; | |||
| struct BackwardGraphCache : std::unordered_map<uint64_t, std::shared_ptr<BackwardGraphResult>>, CompNodeDepedentObject { | |||
| struct BackwardGraphCache : std::unordered_map<uint64_t, std::shared_ptr<OptimizedBackwardGraphResult>>, CompNodeDepedentObject { | |||
| std::shared_ptr<void> on_comp_node_finalize() override { | |||
| clear(); | |||
| return {}; | |||
| } | |||
| } backward_graph_cache; | |||
| std::shared_ptr<BackwardGraphResult> make_backward_graph( | |||
| std::shared_ptr<OptimizedBackwardGraphResult> make_backward_graph( | |||
| ApplyContext& ctx, const apply_result_t& outputs) { | |||
| // hash | |||
| static_assert(alignof(size_t) % alignof(bool) == 0); | |||
| @@ -72,23 +73,23 @@ std::shared_ptr<BackwardGraphResult> make_backward_graph( | |||
| inputs[i].layout.dtype = ctx.args[i]->dtype(); | |||
| input_requires_grad[i] = python::input_requires_grad(ctx, i); | |||
| } | |||
| auto result = std::make_shared<BackwardGraphResult>( | |||
| proxy_graph_detail::make_backward_graph( | |||
| *ctx.op, inputs, input_requires_grad, output_has_grad)); | |||
| if (!result->backward) { | |||
| result.reset(); | |||
| std::shared_ptr<OptimizedBackwardGraphResult> ret; | |||
| auto bg = proxy_graph_detail::make_backward_graph( | |||
| *ctx.op, inputs, input_requires_grad, output_has_grad); | |||
| if (bg.backward) { | |||
| ret = std::make_shared<OptimizedBackwardGraphResult>(bg); | |||
| } | |||
| backward_graph_cache.emplace(key, result); | |||
| return result; | |||
| backward_graph_cache.emplace(key, ret); | |||
| return ret; | |||
| } | |||
| struct BackwardGraphWithClosure { | |||
| std::shared_ptr<BackwardGraphResult> backward_graph; | |||
| std::shared_ptr<OptimizedBackwardGraphResult> backward_graph; | |||
| SmallVector<std::shared_ptr<Tensor>> closure; | |||
| size_t output_mask_offset; | |||
| size_t grad_mask_offset; | |||
| BackwardGraphWithClosure(std::shared_ptr<BackwardGraphResult> backward_graph_, | |||
| BackwardGraphWithClosure(std::shared_ptr<OptimizedBackwardGraphResult> backward_graph_, | |||
| ApplyContext& ctx, const apply_result_t& outputs) | |||
| : backward_graph(backward_graph_), | |||
| output_mask_offset(ctx.nargs), | |||
| @@ -107,9 +108,18 @@ struct BackwardGraphWithClosure { | |||
| // b.requires_grad == False, save_for_backward = [0, 1, 0, 1] | |||
| auto& save_for_backward = backward_graph->save_for_backward; | |||
| mgb_assert(save_for_backward.size() == ctx.nargs + 2 * outputs.size()); | |||
| closure.reserve(std::count_if(save_for_backward.begin(), | |||
| save_for_backward.end(), | |||
| ranges::identity{})); | |||
| size_t count = std::count_if(save_for_backward.begin(), | |||
| save_for_backward.end(), | |||
| ranges::identity{}); | |||
| if (backward_graph->precomp) { | |||
| auto&& irng = ranges::span(ctx.args, ctx.nargs); | |||
| auto&& orng = views::transform(outputs, [](auto&& i){return i.get();}); | |||
| auto precomp = apply(backward_graph->precomp, views::concat(irng, orng)); | |||
| closure.reserve(precomp.size() + count); | |||
| std::copy(precomp.begin(), precomp.end(), std::back_inserter(closure)); | |||
| } else { | |||
| closure.reserve(count); | |||
| } | |||
| for (size_t i = 0; i < ctx.nargs; ++i) { | |||
| if (save_for_backward[i]) { | |||
| closure.push_back(ctx.args[i]->shared_from_this()); | |||
| @@ -212,7 +212,7 @@ decltype(auto) resolve_arrow(T&& p) { | |||
| if constexpr (std::is_invocable_v<decltype(probe), decltype(p)>) { | |||
| return resolve_arrow(p.operator->()); | |||
| } else { | |||
| return p; | |||
| return std::forward<T>(p); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,114 @@ | |||
| /** | |||
| * \file imperative/src/impl/backward_graph_opt.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "megbrain/imperative/backward_graph_opt.h" | |||
| #include "megbrain/imperative/ops/backward_graph.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| using namespace mgb; | |||
| using namespace imperative; | |||
| // Builds the optimized result from the plain backward graph result `src`. | |||
| // | |||
| // When `src.backward` is a BackwardGraph, its expressions are split into a | |||
| // forward part (`precomp`) and a backward part (`backward`) so that fewer | |||
| // tensors need to be saved; `save_for_backward` is recomputed accordingly. | |||
| // Otherwise (a single op) `backward` and `save_for_backward` are copied | |||
| // from `src` and `precomp` stays null. | |||
| // | |||
| // `src.save_for_backward` is a mask laid out as [inputs, outputs, output | |||
| // grads]; `graph.inputs` holds one var per set mask bit, in mask order. | |||
| OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const BackwardGraphResult& src) | |||
| : input_has_grad(src.input_has_grad) { | |||
| if (!src.backward->same_type<BackwardGraph>()) { | |||
| // backward graph only contains a single op | |||
| backward = src.backward; | |||
| save_for_backward = src.save_for_backward; | |||
| return; | |||
| } | |||
| // start from "save nothing"; bits are re-enabled at the end for the | |||
| // vars that the backward part really consumes | |||
| save_for_backward.resize(src.save_for_backward.size(), false); | |||
| precomp.reset(new BackwardGraph); | |||
| backward.reset(new BackwardGraph); | |||
| auto&& graph = src.backward->cast_final_safe<BackwardGraph>().graph(); | |||
| auto&& mask = src.save_for_backward; | |||
| size_t input_size = src.input_has_grad.size(); | |||
| size_t output_size = (mask.size() - input_size) / 2; | |||
| mgb_assert(input_size + output_size * 2 == mask.size()); | |||
| auto& fgraph = precomp->cast_final<BackwardGraph>().graph(); | |||
| auto& bgraph = backward->cast_final<BackwardGraph>().graph(); | |||
| // optimization: move ops (e.g. GetVarShape) to forward to | |||
| // reduce memory footprint | |||
| struct VInfo { | |||
| bool appears_in_backward = false; | |||
| }; | |||
| std::unordered_map<size_t, VInfo> vinfo; | |||
| // step 1.1: ops not in whitelist must run in backward. | |||
| // mark their inputs as always appears in backward | |||
| for (auto&& [op, iv, ov] : graph.exprs) { | |||
| if (!op->same_type<GetVarShape>()) { | |||
| for (auto&& v : iv) { | |||
| vinfo[v].appears_in_backward = true; | |||
| } | |||
| } | |||
| } | |||
| // step 1.2: inputs only available in backward (i.e. grads) | |||
| // should be marked as always appears in backward | |||
| for (size_t i = 0, j = 0; i < mask.size(); ++i) { | |||
| if (!mask[i]) continue; | |||
| // grads occupy mask indices [input_size + output_size, mask.size()), | |||
| // so the comparison must be inclusive: with `>` the first grad was | |||
| // skipped and an op reading only that grad could be moved to forward, | |||
| // where the grad is not available. | |||
| if (i >= input_size + output_size) { | |||
| vinfo[graph.inputs[j]].appears_in_backward = true; | |||
| } | |||
| ++j; | |||
| } | |||
| // step 2: try to move ops to forward, if not all their inputs | |||
| // are marked always appears in backward (otherwise no memory saving) | |||
| for (auto&& expr : graph.exprs) { | |||
| auto&& [op, iv, ov] = expr; | |||
| if (std::all_of(iv.begin(), iv.end(), [&](auto&& v){return vinfo[v].appears_in_backward;})) { | |||
| bgraph.exprs.push_back(expr); | |||
| for (auto&& v : ov) { | |||
| vinfo[v].appears_in_backward = true; | |||
| } | |||
| // logically should also mark all inputs as appears in backward | |||
| // but clearly that's a no-op. | |||
| } else { | |||
| fgraph.exprs.push_back(expr); | |||
| for (auto&& v : ov) { | |||
| if (vinfo[v].appears_in_backward) { | |||
| // appears_in_backward won't change after this point | |||
| // so it is safe to set fgraph.outputs based on current value | |||
| fgraph.outputs.push_back(v); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| // initialize remaining parts | |||
| fgraph.constants = graph.constants; | |||
| // the forward graph takes one entry per forward input and output, | |||
| // whether or not the original mask selected it | |||
| fgraph.inputs.reserve(input_size + output_size); | |||
| for (size_t i = 0, j = 0; i < input_size + output_size; ++i) { | |||
| if (!mask[i]) { | |||
| // slot not referenced by the original graph: fill it with a | |||
| // placeholder id (assumed not to collide with real var ids | |||
| // -- TODO confirm) | |||
| fgraph.inputs.push_back(1000000000 + i); | |||
| continue; | |||
| } | |||
| fgraph.inputs.push_back(graph.inputs[j++]); | |||
| } | |||
| bgraph.constants = graph.constants; | |||
| bgraph.outputs = graph.outputs; | |||
| // backward inputs: precomputed values first, then the original inputs | |||
| // that are still consumed in backward | |||
| bgraph.inputs = fgraph.outputs; | |||
| for (size_t i = 0, j = 0; i < mask.size(); ++i) { | |||
| if (mask[i]) { | |||
| auto&& v = graph.inputs[j++]; | |||
| if (vinfo[v].appears_in_backward) { | |||
| save_for_backward[i] = true; | |||
| bgraph.inputs.push_back(v); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,25 @@ | |||
| /** | |||
| * \file imperative/src/include/megbrain/imperative/backward_graph_opt.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "./op_def.h" | |||
| #include <memory> | |||
| #include <vector> | |||
| namespace mgb::imperative { | |||
| // Backward graph that has been split into a part evaluated during forward | |||
| // and a part evaluated during backward, built from a BackwardGraphResult. | |||
| struct OptimizedBackwardGraphResult { | |||
| // graph to run in forward on the op's inputs and outputs; left null when | |||
| // the source backward op is not a BackwardGraph | |||
| std::shared_ptr<OpDef> precomp; | |||
| // op to run in backward; for a split graph its inputs are the results of | |||
| // `precomp` followed by the tensors selected by `save_for_backward` | |||
| std::shared_ptr<OpDef> backward; | |||
| // mask over [inputs, outputs, output grads]: which of them `backward` | |||
| // still needs | |||
| std::vector<bool> save_for_backward; | |||
| // copied unchanged from the source BackwardGraphResult | |||
| std::vector<bool> input_has_grad; | |||
| OptimizedBackwardGraphResult(const BackwardGraphResult& bgraph); | |||
| }; | |||
| } // namespace mgb::imperative | |||
| @@ -13,11 +13,68 @@ | |||
| #include "megbrain/opr/basic_arith.h" | |||
| #include "megbrain/opr/dnn/batch_norm.h" | |||
| #include "megbrain/imperative/ops/opr_attr.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| #include "megbrain/imperative/backward_graph_opt.h" | |||
| using namespace mgb; | |||
| using namespace cg; | |||
| using namespace imperative; | |||
| // Collects the tensors fed to the unoptimized backward graph: walks | |||
| // inputs, outputs and grads in that order and keeps every tensor whose | |||
| // position is set in bg.save_for_backward. | |||
| template <typename T> | |||
| T prepare_backward_graph_inputs(const BackwardGraphResult& bg, const T& inputs, const T& outputs, const T& grads) { | |||
| T selected; | |||
| // position in the mask; shared by the three groups | |||
| size_t mask_pos = 0; | |||
| auto append_saved = [&](const T& tensors) { | |||
| for (auto&& tensor : tensors) { | |||
| const bool saved = bg.save_for_backward[mask_pos]; | |||
| ++mask_pos; | |||
| if (saved) { | |||
| selected.push_back(tensor); | |||
| } | |||
| } | |||
| }; | |||
| append_saved(inputs); | |||
| append_saved(outputs); | |||
| append_saved(grads); | |||
| return selected; | |||
| } | |||
| // Scatters the compact list of computed grads back to one slot per input: | |||
| // slots whose bg.input_has_grad bit is unset stay default-constructed. | |||
| template <typename T, typename U> | |||
| T expand_grads(const U& bg, const T& outputs) { | |||
| const size_t slot_count = bg.input_has_grad.size(); | |||
| T expanded(slot_count); | |||
| size_t next_output = 0; | |||
| for (size_t slot = 0; slot < slot_count; ++slot) { | |||
| if (!bg.input_has_grad[slot]) { | |||
| continue; | |||
| } | |||
| expanded[slot] = outputs[next_output]; | |||
| ++next_output; | |||
| } | |||
| return expanded; | |||
| } | |||
| // Collects the tensors fed to the optimized backward graph: the precomputed | |||
| // values first, then every tensor from inputs, outputs and grads (in that | |||
| // order) whose position is set in bg.save_for_backward. | |||
| template <typename T> | |||
| T prepare_optimized_backward_inputs(const OptimizedBackwardGraphResult& bg, const T& precomp, const T& inputs, const T& outputs, const T& grads) { | |||
| T backward_inputs(precomp); | |||
| // position in the mask; shared by the three groups | |||
| size_t mask_pos = 0; | |||
| auto append_saved = [&](const T& tensors) { | |||
| for (auto&& tensor : tensors) { | |||
| const bool saved = bg.save_for_backward[mask_pos]; | |||
| ++mask_pos; | |||
| if (saved) { | |||
| backward_inputs.push_back(tensor); | |||
| } | |||
| } | |||
| }; | |||
| append_saved(inputs); | |||
| append_saved(outputs); | |||
| append_saved(grads); | |||
| return backward_inputs; | |||
| } | |||
| TEST(TestImperative, BackwardGraphBasic) { | |||
| HostTensorGenerator<> gen; | |||
| SmallVector<HostTensorND> hvs; | |||
| @@ -121,27 +178,65 @@ TEST(TestImperative, BackwardGraphIdentity) { | |||
| } | |||
| TEST(TestImperative, BatchNormGrad) { | |||
| auto cn = CompNode::load("xpux"); | |||
| using Param = opr::BatchNorm::Param; | |||
| size_t N=2, C=3, H=5, W=5; | |||
| LogicalTensorDesc inp{TensorLayout{{N, C, H, W}, dtype::Float32()}, cn}; | |||
| LogicalTensorDesc stat{TensorLayout{{C}, dtype::Float32()}, cn}; | |||
| { | |||
| auto op = OprAttr::make("BatchNorm"); | |||
| auto&& attr = op->cast_final_safe<OprAttr>(); | |||
| Param param; | |||
| param.fwd_mode = Param::FwdMode::TRAINING; | |||
| attr.param.write_pod(param); | |||
| OpDef::make_backward_graph(attr, {inp, stat, stat, stat, stat}, | |||
| {true, true ,true, false, false}, {false, false, false, false, true}); | |||
| } | |||
| { | |||
| auto op = OprAttr::make("BatchNorm"); | |||
| auto&& attr = op->cast_final_safe<OprAttr>(); | |||
| Param param; | |||
| param.fwd_mode = Param::FwdMode::TRAINING; | |||
| attr.param.write_pod(param); | |||
| OpDef::make_backward_graph(attr, {inp, stat, stat}, | |||
| {true, true ,true}, {false, false, true}); | |||
| } | |||
| auto cn = CompNode::load("xpux"); | |||
| using Param = opr::BatchNorm::Param; | |||
| size_t N=2, C=3, H=5, W=5; | |||
| LogicalTensorDesc inp{TensorLayout{{N, C, H, W}, dtype::Float32()}, cn}; | |||
| LogicalTensorDesc stat{TensorLayout{{C}, dtype::Float32()}, cn}; | |||
| { | |||
| auto op = OprAttr::make("BatchNorm"); | |||
| auto&& attr = op->cast_final_safe<OprAttr>(); | |||
| Param param; | |||
| param.fwd_mode = Param::FwdMode::TRAINING; | |||
| attr.param.write_pod(param); | |||
| OpDef::make_backward_graph(attr, {inp, stat, stat, stat, stat}, | |||
| {true, true ,true, false, false}, {false, false, false, false, true}); | |||
| } | |||
| { | |||
| auto op = OprAttr::make("BatchNorm"); | |||
| auto&& attr = op->cast_final_safe<OprAttr>(); | |||
| Param param; | |||
| param.fwd_mode = Param::FwdMode::TRAINING; | |||
| attr.param.write_pod(param); | |||
| OpDef::make_backward_graph(attr, {inp, stat, stat}, | |||
| {true, true ,true}, {false, false, true}); | |||
| } | |||
| } | |||
| /* Checks that the optimized (precomp + backward) graph of Elemwise ADD produces the same input grads as the unoptimized backward graph. */ TEST(TestImperative, OptimizedBackwardGraphBasic) { | |||
| auto cn = CompNode::load("xpux"); | |||
| LogicalTensorDesc desc = {TensorLayout(dtype::Float32()), cn}; | |||
| HostTensorGenerator<> gen; | |||
| auto op = std::shared_ptr<OpDef>(Elemwise::make(Elemwise::Mode::ADD)); | |||
| /* two inputs, both requiring grad; one output, which has a grad */ auto bg = OpDef::make_backward_graph(*op, {desc, desc}, {true, true}, {true}); | |||
| auto obg = OptimizedBackwardGraphResult(bg); | |||
| /* mask covers 2 inputs + 1 output + 1 output grad */ ASSERT_EQ(obg.save_for_backward.size(), 4); | |||
| /* neither input nor the output is expected to be saved for backward */ ASSERT_FALSE(obg.save_for_backward[0]); | |||
| ASSERT_FALSE(obg.save_for_backward[1]); | |||
| ASSERT_FALSE(obg.save_for_backward[2]); | |||
| auto a_hv = gen({42}); | |||
| auto b_hv = gen({5, 42}); | |||
| auto dc_hv = gen({5, 42}); | |||
| auto a_tn = Tensor::make(*a_hv); | |||
| auto b_tn = Tensor::make(*b_hv); | |||
| auto dc_tn = Tensor::make(*dc_hv); | |||
| /* forward result c = op(a, b) */ auto c_tn = OpDef::apply_on_physical_tensor(*op, {a_tn, b_tn})[0]; | |||
| /* reference grads from the unoptimized backward graph */ auto backward_graph_inputs = prepare_backward_graph_inputs<SmallVector<TensorPtr>>(bg, {a_tn, b_tn}, {c_tn}, {dc_tn}); | |||
| auto grads = expand_grads(bg, OpDef::apply_on_physical_tensor(*bg.backward, backward_graph_inputs)); | |||
| /* run the precomputed part on the forward inputs and output */ auto precomp = OpDef::apply_on_physical_tensor(*obg.precomp, {a_tn, b_tn, c_tn}); | |||
| /* expects two precomputed tensors, each 1-D with at most 2 elements (presumably the shapes of a and b -- confirm) */ ASSERT_EQ(precomp.size(), 2); | |||
| ASSERT_EQ(precomp[0]->shape().ndim, 1); | |||
| ASSERT_LE(precomp[0]->shape()[0], 2); | |||
| ASSERT_EQ(precomp[1]->shape().ndim, 1); | |||
| ASSERT_LE(precomp[1]->shape()[0], 2); | |||
| /* backward inputs = precomp results followed by the tensors still marked in save_for_backward */ auto backward_inputs = prepare_optimized_backward_inputs<SmallVector<TensorPtr>>(obg, precomp, {a_tn, b_tn}, {c_tn}, {dc_tn}); | |||
| auto grads2 = expand_grads(obg, OpDef::apply_on_physical_tensor(*obg.backward, backward_inputs)); | |||
| ASSERT_EQ(grads2.size(), 2); | |||
| /* optimized and reference grads must be equal */ MGB_ASSERT_TENSOR_EQ(grads[0]->get_value(), grads2[0]->get_value()); | |||
| MGB_ASSERT_TENSOR_EQ(grads[1]->get_value(), grads2[1]->get_value()); | |||
| } | |||