precompute ops in forward to reduce saved tensor size
GitOrigin-RevId: d67043ba82
tags/v1.2.0
| @@ -11,6 +11,7 @@ | |||||
| #include "./grad.h" | #include "./grad.h" | ||||
| #include "megbrain/imperative/proxy_graph_detail.h" | #include "megbrain/imperative/proxy_graph_detail.h" | ||||
| #include "megbrain/imperative/backward_graph_opt.h" | |||||
| #include "megbrain/imperative/ops/autogen.h" | #include "megbrain/imperative/ops/autogen.h" | ||||
| #include "megbrain/imperative/ops/utility.h" | #include "megbrain/imperative/ops/utility.h" | ||||
| #include "megbrain/utils/mempool.h" | #include "megbrain/utils/mempool.h" | ||||
| @@ -32,14 +33,14 @@ struct GradSlotWeakPtr { | |||||
| size_t idx; | size_t idx; | ||||
| }; | }; | ||||
| struct BackwardGraphCache : std::unordered_map<uint64_t, std::shared_ptr<BackwardGraphResult>>, CompNodeDepedentObject { | |||||
| struct BackwardGraphCache : std::unordered_map<uint64_t, std::shared_ptr<OptimizedBackwardGraphResult>>, CompNodeDepedentObject { | |||||
| std::shared_ptr<void> on_comp_node_finalize() override { | std::shared_ptr<void> on_comp_node_finalize() override { | ||||
| clear(); | clear(); | ||||
| return {}; | return {}; | ||||
| } | } | ||||
| } backward_graph_cache; | } backward_graph_cache; | ||||
| std::shared_ptr<BackwardGraphResult> make_backward_graph( | |||||
| std::shared_ptr<OptimizedBackwardGraphResult> make_backward_graph( | |||||
| ApplyContext& ctx, const apply_result_t& outputs) { | ApplyContext& ctx, const apply_result_t& outputs) { | ||||
| // hash | // hash | ||||
| static_assert(alignof(size_t) % alignof(bool) == 0); | static_assert(alignof(size_t) % alignof(bool) == 0); | ||||
| @@ -72,23 +73,23 @@ std::shared_ptr<BackwardGraphResult> make_backward_graph( | |||||
| inputs[i].layout.dtype = ctx.args[i]->dtype(); | inputs[i].layout.dtype = ctx.args[i]->dtype(); | ||||
| input_requires_grad[i] = python::input_requires_grad(ctx, i); | input_requires_grad[i] = python::input_requires_grad(ctx, i); | ||||
| } | } | ||||
| auto result = std::make_shared<BackwardGraphResult>( | |||||
| proxy_graph_detail::make_backward_graph( | |||||
| *ctx.op, inputs, input_requires_grad, output_has_grad)); | |||||
| if (!result->backward) { | |||||
| result.reset(); | |||||
| std::shared_ptr<OptimizedBackwardGraphResult> ret; | |||||
| auto bg = proxy_graph_detail::make_backward_graph( | |||||
| *ctx.op, inputs, input_requires_grad, output_has_grad); | |||||
| if (bg.backward) { | |||||
| ret = std::make_shared<OptimizedBackwardGraphResult>(bg); | |||||
| } | } | ||||
| backward_graph_cache.emplace(key, result); | |||||
| return result; | |||||
| backward_graph_cache.emplace(key, ret); | |||||
| return ret; | |||||
| } | } | ||||
| struct BackwardGraphWithClosure { | struct BackwardGraphWithClosure { | ||||
| std::shared_ptr<BackwardGraphResult> backward_graph; | |||||
| std::shared_ptr<OptimizedBackwardGraphResult> backward_graph; | |||||
| SmallVector<std::shared_ptr<Tensor>> closure; | SmallVector<std::shared_ptr<Tensor>> closure; | ||||
| size_t output_mask_offset; | size_t output_mask_offset; | ||||
| size_t grad_mask_offset; | size_t grad_mask_offset; | ||||
| BackwardGraphWithClosure(std::shared_ptr<BackwardGraphResult> backward_graph_, | |||||
| BackwardGraphWithClosure(std::shared_ptr<OptimizedBackwardGraphResult> backward_graph_, | |||||
| ApplyContext& ctx, const apply_result_t& outputs) | ApplyContext& ctx, const apply_result_t& outputs) | ||||
| : backward_graph(backward_graph_), | : backward_graph(backward_graph_), | ||||
| output_mask_offset(ctx.nargs), | output_mask_offset(ctx.nargs), | ||||
| @@ -107,9 +108,18 @@ struct BackwardGraphWithClosure { | |||||
| // b.requires_grad == False, save_for_backward = [0, 1, 0, 1] | // b.requires_grad == False, save_for_backward = [0, 1, 0, 1] | ||||
| auto& save_for_backward = backward_graph->save_for_backward; | auto& save_for_backward = backward_graph->save_for_backward; | ||||
| mgb_assert(save_for_backward.size() == ctx.nargs + 2 * outputs.size()); | mgb_assert(save_for_backward.size() == ctx.nargs + 2 * outputs.size()); | ||||
| closure.reserve(std::count_if(save_for_backward.begin(), | |||||
| save_for_backward.end(), | |||||
| ranges::identity{})); | |||||
| size_t count = std::count_if(save_for_backward.begin(), | |||||
| save_for_backward.end(), | |||||
| ranges::identity{}); | |||||
| if (backward_graph->precomp) { | |||||
| auto&& irng = ranges::span(ctx.args, ctx.nargs); | |||||
| auto&& orng = views::transform(outputs, [](auto&& i){return i.get();}); | |||||
| auto precomp = apply(backward_graph->precomp, views::concat(irng, orng)); | |||||
| closure.reserve(precomp.size() + count); | |||||
| std::copy(precomp.begin(), precomp.end(), std::back_inserter(closure)); | |||||
| } else { | |||||
| closure.reserve(count); | |||||
| } | |||||
| for (size_t i = 0; i < ctx.nargs; ++i) { | for (size_t i = 0; i < ctx.nargs; ++i) { | ||||
| if (save_for_backward[i]) { | if (save_for_backward[i]) { | ||||
| closure.push_back(ctx.args[i]->shared_from_this()); | closure.push_back(ctx.args[i]->shared_from_this()); | ||||
| @@ -212,7 +212,7 @@ decltype(auto) resolve_arrow(T&& p) { | |||||
| if constexpr (std::is_invocable_v<decltype(probe), decltype(p)>) { | if constexpr (std::is_invocable_v<decltype(probe), decltype(p)>) { | ||||
| return resolve_arrow(p.operator->()); | return resolve_arrow(p.operator->()); | ||||
| } else { | } else { | ||||
| return p; | |||||
| return std::forward<T>(p); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,114 @@ | |||||
| /** | |||||
| * \file imperative/src/impl/backward_graph_opt.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #include "megbrain/imperative/backward_graph_opt.h" | |||||
| #include "megbrain/imperative/ops/backward_graph.h" | |||||
| #include "megbrain/imperative/ops/autogen.h" | |||||
| using namespace mgb; | |||||
| using namespace imperative; | |||||
| OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const BackwardGraphResult& src) | |||||
| : input_has_grad(src.input_has_grad) { | |||||
| if (!src.backward->same_type<BackwardGraph>()) { | |||||
| // backward graph only contains a single op | |||||
| backward = src.backward; | |||||
| save_for_backward = src.save_for_backward; | |||||
| return; | |||||
| } | |||||
| save_for_backward.resize(src.save_for_backward.size(), false); | |||||
| precomp.reset(new BackwardGraph); | |||||
| backward.reset(new BackwardGraph); | |||||
| auto&& graph = src.backward->cast_final_safe<BackwardGraph>().graph(); | |||||
| auto&& mask = src.save_for_backward; | |||||
| size_t input_size = src.input_has_grad.size(); | |||||
| size_t output_size = (mask.size() - input_size) / 2; | |||||
| mgb_assert(input_size + output_size * 2 == mask.size()); | |||||
| auto& fgraph = precomp->cast_final<BackwardGraph>().graph(); | |||||
| auto& bgraph = backward->cast_final<BackwardGraph>().graph(); | |||||
| // optimization: move ops (e.g. GetVarShape) to forward to | |||||
| // reduce memory footprint | |||||
| struct VInfo { | |||||
| bool appears_in_backward = false; | |||||
| }; | |||||
| std::unordered_map<size_t, VInfo> vinfo; | |||||
| // step 1.1: ops not in whitelist must run in backward. | |||||
| // mark their inputs as always appears in backward | |||||
| for (auto&& [op, iv, ov] : graph.exprs) { | |||||
| if (!op->same_type<GetVarShape>()) { | |||||
| for (auto&& v : iv) { | |||||
| vinfo[v].appears_in_backward = true; | |||||
| } | |||||
| } | |||||
| } | |||||
| // step 1.2: inputs only available in backward (i.e. grads) | |||||
| // should be marked as always appears in backward | |||||
| for (size_t i = 0, j = 0; i < mask.size(); ++i) { | |||||
| if (!mask[i]) continue; | |||||
| if (i > input_size + output_size) { | |||||
| vinfo[graph.inputs[j]].appears_in_backward = true; | |||||
| } | |||||
| ++j; | |||||
| } | |||||
| // step 2: try to move ops to forward, if not all their inputs | |||||
| // are marked always appears in backward (otherwise no memory saving) | |||||
| for (auto&& expr : graph.exprs) { | |||||
| auto&& [op, iv, ov] = expr; | |||||
| if (std::all_of(iv.begin(), iv.end(), [&](auto&& v){return vinfo[v].appears_in_backward;})) { | |||||
| bgraph.exprs.push_back(expr); | |||||
| for (auto&& v : ov) { | |||||
| vinfo[v].appears_in_backward = true; | |||||
| } | |||||
| // logically should also mark all inputs as appears in backward | |||||
| // but clearly that's a no-op. | |||||
| } else { | |||||
| fgraph.exprs.push_back(expr); | |||||
| for (auto&& v : ov) { | |||||
| if (vinfo[v].appears_in_backward) { | |||||
| // appears_in_backward won't change after this point | |||||
| // so it is safe to set fgraph.outputs based on current value | |||||
| fgraph.outputs.push_back(v); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| // initialize remaining parts | |||||
| fgraph.constants = graph.constants; | |||||
| fgraph.inputs.reserve(input_size + output_size); | |||||
| for (size_t i = 0, j = 0; i < input_size + output_size; ++i) { | |||||
| if (!mask[i]) { | |||||
| fgraph.inputs.push_back(1000000000 + i); | |||||
| continue; | |||||
| } | |||||
| fgraph.inputs.push_back(graph.inputs[j++]); | |||||
| } | |||||
| bgraph.constants = graph.constants; | |||||
| bgraph.outputs = graph.outputs; | |||||
| bgraph.inputs = fgraph.outputs; | |||||
| for (size_t i = 0, j = 0; i < mask.size(); ++i) { | |||||
| if (mask[i]) { | |||||
| auto&& v = graph.inputs[j++]; | |||||
| if (vinfo[v].appears_in_backward) { | |||||
| save_for_backward[i] = true; | |||||
| bgraph.inputs.push_back(v); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,25 @@ | |||||
| /** | |||||
| * \file imperative/src/include/megbrain/imperative/backward_graph_opt.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #include "./op_def.h" | |||||
| namespace mgb::imperative { | |||||
| struct OptimizedBackwardGraphResult { | |||||
| std::shared_ptr<OpDef> precomp; | |||||
| std::shared_ptr<OpDef> backward; | |||||
| std::vector<bool> save_for_backward; | |||||
| std::vector<bool> input_has_grad; | |||||
| OptimizedBackwardGraphResult(const BackwardGraphResult& bgraph); | |||||
| }; | |||||
| } // namespace mgb::imperative | |||||
| @@ -13,11 +13,68 @@ | |||||
| #include "megbrain/opr/basic_arith.h" | #include "megbrain/opr/basic_arith.h" | ||||
| #include "megbrain/opr/dnn/batch_norm.h" | #include "megbrain/opr/dnn/batch_norm.h" | ||||
| #include "megbrain/imperative/ops/opr_attr.h" | #include "megbrain/imperative/ops/opr_attr.h" | ||||
| #include "megbrain/imperative/ops/autogen.h" | |||||
| #include "megbrain/imperative/backward_graph_opt.h" | |||||
| using namespace mgb; | using namespace mgb; | ||||
| using namespace cg; | using namespace cg; | ||||
| using namespace imperative; | using namespace imperative; | ||||
| template <typename T> | |||||
| T prepare_backward_graph_inputs(const BackwardGraphResult& bg, const T& inputs, const T& outputs, const T& grads) { | |||||
| T ret; | |||||
| size_t i = 0; | |||||
| for (auto&& t : inputs) { | |||||
| if (bg.save_for_backward[i++]) { | |||||
| ret.push_back(t); | |||||
| } | |||||
| } | |||||
| for (auto&& t : outputs) { | |||||
| if (bg.save_for_backward[i++]) { | |||||
| ret.push_back(t); | |||||
| } | |||||
| } | |||||
| for (auto&& t : grads) { | |||||
| if (bg.save_for_backward[i++]) { | |||||
| ret.push_back(t); | |||||
| } | |||||
| } | |||||
| return ret; | |||||
| } | |||||
/*!
 * \brief scatter the compact list of computed grads back to one slot per
 *      forward input
 *
 * The result has bg.input_has_grad.size() entries; slots whose flag is
 * unset keep their value-initialized content, the others receive the
 * entries of \p outputs in order.
 */
template <typename T, typename U>
T expand_grads(const U& bg, const T& outputs) {
    const auto& has_grad = bg.input_has_grad;
    T expanded(has_grad.size());
    size_t next_output = 0;
    for (size_t slot = 0; slot < has_grad.size(); ++slot) {
        if (!has_grad[slot]) {
            continue;
        }
        expanded[slot] = outputs[next_output++];
    }
    return expanded;
}
| template <typename T> | |||||
| T prepare_optimized_backward_inputs(const OptimizedBackwardGraphResult& bg, const T& precomp, const T& inputs, const T& outputs, const T& grads) { | |||||
| T ret = precomp; | |||||
| size_t i = 0; | |||||
| for (auto&& t : inputs) { | |||||
| if (bg.save_for_backward[i++]) { | |||||
| ret.push_back(t); | |||||
| } | |||||
| } | |||||
| for (auto&& t : outputs) { | |||||
| if (bg.save_for_backward[i++]) { | |||||
| ret.push_back(t); | |||||
| } | |||||
| } | |||||
| for (auto&& t : grads) { | |||||
| if (bg.save_for_backward[i++]) { | |||||
| ret.push_back(t); | |||||
| } | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| TEST(TestImperative, BackwardGraphBasic) { | TEST(TestImperative, BackwardGraphBasic) { | ||||
| HostTensorGenerator<> gen; | HostTensorGenerator<> gen; | ||||
| SmallVector<HostTensorND> hvs; | SmallVector<HostTensorND> hvs; | ||||
| @@ -121,27 +178,65 @@ TEST(TestImperative, BackwardGraphIdentity) { | |||||
| } | } | ||||
| TEST(TestImperative, BatchNormGrad) { | TEST(TestImperative, BatchNormGrad) { | ||||
| auto cn = CompNode::load("xpux"); | |||||
| using Param = opr::BatchNorm::Param; | |||||
| size_t N=2, C=3, H=5, W=5; | |||||
| LogicalTensorDesc inp{TensorLayout{{N, C, H, W}, dtype::Float32()}, cn}; | |||||
| LogicalTensorDesc stat{TensorLayout{{C}, dtype::Float32()}, cn}; | |||||
| { | |||||
| auto op = OprAttr::make("BatchNorm"); | |||||
| auto&& attr = op->cast_final_safe<OprAttr>(); | |||||
| Param param; | |||||
| param.fwd_mode = Param::FwdMode::TRAINING; | |||||
| attr.param.write_pod(param); | |||||
| OpDef::make_backward_graph(attr, {inp, stat, stat, stat, stat}, | |||||
| {true, true ,true, false, false}, {false, false, false, false, true}); | |||||
| } | |||||
| { | |||||
| auto op = OprAttr::make("BatchNorm"); | |||||
| auto&& attr = op->cast_final_safe<OprAttr>(); | |||||
| Param param; | |||||
| param.fwd_mode = Param::FwdMode::TRAINING; | |||||
| attr.param.write_pod(param); | |||||
| OpDef::make_backward_graph(attr, {inp, stat, stat}, | |||||
| {true, true ,true}, {false, false, true}); | |||||
| } | |||||
| auto cn = CompNode::load("xpux"); | |||||
| using Param = opr::BatchNorm::Param; | |||||
| size_t N=2, C=3, H=5, W=5; | |||||
| LogicalTensorDesc inp{TensorLayout{{N, C, H, W}, dtype::Float32()}, cn}; | |||||
| LogicalTensorDesc stat{TensorLayout{{C}, dtype::Float32()}, cn}; | |||||
| { | |||||
| auto op = OprAttr::make("BatchNorm"); | |||||
| auto&& attr = op->cast_final_safe<OprAttr>(); | |||||
| Param param; | |||||
| param.fwd_mode = Param::FwdMode::TRAINING; | |||||
| attr.param.write_pod(param); | |||||
| OpDef::make_backward_graph(attr, {inp, stat, stat, stat, stat}, | |||||
| {true, true ,true, false, false}, {false, false, false, false, true}); | |||||
| } | |||||
| { | |||||
| auto op = OprAttr::make("BatchNorm"); | |||||
| auto&& attr = op->cast_final_safe<OprAttr>(); | |||||
| Param param; | |||||
| param.fwd_mode = Param::FwdMode::TRAINING; | |||||
| attr.param.write_pod(param); | |||||
| OpDef::make_backward_graph(attr, {inp, stat, stat}, | |||||
| {true, true ,true}, {false, false, true}); | |||||
| } | |||||
| } | |||||
| TEST(TestImperative, OptimizedBackwardGraphBasic) { | |||||
| auto cn = CompNode::load("xpux"); | |||||
| LogicalTensorDesc desc = {TensorLayout(dtype::Float32()), cn}; | |||||
| HostTensorGenerator<> gen; | |||||
| auto op = std::shared_ptr<OpDef>(Elemwise::make(Elemwise::Mode::ADD)); | |||||
| auto bg = OpDef::make_backward_graph(*op, {desc, desc}, {true, true}, {true}); | |||||
| auto obg = OptimizedBackwardGraphResult(bg); | |||||
| ASSERT_EQ(obg.save_for_backward.size(), 4); | |||||
| ASSERT_FALSE(obg.save_for_backward[0]); | |||||
| ASSERT_FALSE(obg.save_for_backward[1]); | |||||
| ASSERT_FALSE(obg.save_for_backward[2]); | |||||
| auto a_hv = gen({42}); | |||||
| auto b_hv = gen({5, 42}); | |||||
| auto dc_hv = gen({5, 42}); | |||||
| auto a_tn = Tensor::make(*a_hv); | |||||
| auto b_tn = Tensor::make(*b_hv); | |||||
| auto dc_tn = Tensor::make(*dc_hv); | |||||
| auto c_tn = OpDef::apply_on_physical_tensor(*op, {a_tn, b_tn})[0]; | |||||
| auto backward_graph_inputs = prepare_backward_graph_inputs<SmallVector<TensorPtr>>(bg, {a_tn, b_tn}, {c_tn}, {dc_tn}); | |||||
| auto grads = expand_grads(bg, OpDef::apply_on_physical_tensor(*bg.backward, backward_graph_inputs)); | |||||
| auto precomp = OpDef::apply_on_physical_tensor(*obg.precomp, {a_tn, b_tn, c_tn}); | |||||
| ASSERT_EQ(precomp.size(), 2); | |||||
| ASSERT_EQ(precomp[0]->shape().ndim, 1); | |||||
| ASSERT_LE(precomp[0]->shape()[0], 2); | |||||
| ASSERT_EQ(precomp[1]->shape().ndim, 1); | |||||
| ASSERT_LE(precomp[1]->shape()[0], 2); | |||||
| auto backward_inputs = prepare_optimized_backward_inputs<SmallVector<TensorPtr>>(obg, precomp, {a_tn, b_tn}, {c_tn}, {dc_tn}); | |||||
| auto grads2 = expand_grads(obg, OpDef::apply_on_physical_tensor(*obg.backward, backward_inputs)); | |||||
| ASSERT_EQ(grads2.size(), 2); | |||||
| MGB_ASSERT_TENSOR_EQ(grads[0]->get_value(), grads2[0]->get_value()); | |||||
| MGB_ASSERT_TENSOR_EQ(grads[1]->get_value(), grads2[1]->get_value()); | |||||
| } | } | ||||