GitOrigin-RevId: 56d90be0e7
tags/v1.6.0-rc1
| @@ -77,7 +77,7 @@ std::shared_ptr<OptimizedBackwardGraphResult> make_backward_graph( | |||
| std::shared_ptr<OptimizedBackwardGraphResult> ret; | |||
| auto bg = OpDef::make_backward_graph( | |||
| *ctx.op, inputs, input_requires_grad, output_has_grad); | |||
| if (!bg.backward.empty()) { | |||
| if (!bg.graph.empty()) { | |||
| ret = std::make_shared<OptimizedBackwardGraphResult>(bg); | |||
| } | |||
| backward_graph_cache.emplace(key, ret); | |||
| @@ -37,7 +37,7 @@ void init_imperative_rt(py::module m) { | |||
| const SmallVector<bool>& input_requires_grad, | |||
| const SmallVector<bool>& output_has_grad){ | |||
| auto result = OpDef::make_backward_graph(def, inputs, input_requires_grad, output_has_grad); | |||
| return std::make_tuple("backward_graph", result.save_for_backward, result.input_has_grad); | |||
| return std::make_tuple("backward_graph", result.input_mask, result.output_mask); | |||
| }; | |||
| m.def("make_backward_graph", make_backward_graph); | |||
| } | |||
| @@ -16,19 +16,19 @@ | |||
| using namespace mgb; | |||
| using namespace imperative; | |||
| OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const BackwardGraphResult& src) | |||
| : input_has_grad(src.input_has_grad) { | |||
| if (src.backward.exprs.size() <= 1) { | |||
| OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const EncodedSubraph& src) | |||
| : input_has_grad(src.output_mask) { | |||
| if (src.graph.exprs.size() <= 1) { | |||
| // backward graph only contains a single op | |||
| backward = src.backward; | |||
| save_for_backward = src.save_for_backward; | |||
| backward = src.graph; | |||
| save_for_backward = src.input_mask; | |||
| return; | |||
| } | |||
| save_for_backward.resize(src.save_for_backward.size(), false); | |||
| save_for_backward.resize(src.input_mask.size(), false); | |||
| auto&& graph = src.backward; | |||
| auto&& mask = src.save_for_backward; | |||
| size_t input_size = src.input_has_grad.size(); | |||
| auto&& graph = src.graph; | |||
| auto&& mask = src.input_mask; | |||
| size_t input_size = src.output_mask.size(); | |||
| size_t output_size = (mask.size() - input_size) / 2; | |||
| mgb_assert(input_size + output_size * 2 == mask.size()); | |||
| @@ -80,7 +80,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> OpDef::infer_output_attrs_falli | |||
| return def.trait()->infer_output_attrs_fallible(def, inputs); | |||
| } | |||
| BackwardGraphResult OpDef::make_backward_graph( | |||
| EncodedSubraph OpDef::make_backward_graph( | |||
| const OpDef& def, | |||
| const SmallVector<LogicalTensorDesc>& inputs, | |||
| const SmallVector<bool>& input_requires_grad, | |||
| @@ -668,14 +668,14 @@ struct ProxyGraph::GradGraph { | |||
| cg::VarNode* grad; | |||
| }; | |||
| BackwardGraphResult | |||
| EncodedSubraph | |||
| ProxyGraph::make_backward_graph( | |||
| const OpDef& opdef, | |||
| const SmallVector<LogicalTensorDesc>& input_descs, | |||
| const SmallVector<bool>& input_requires_grad, | |||
| const SmallVector<bool>& output_has_grad) { | |||
| ThinHashMap<VarNode*, size_t> var2idx; | |||
| auto push = [&var2idx, cnt=0](VarNode* var) mutable { | |||
| auto push = [&var2idx, cnt=1](VarNode* var) mutable { //cnt is always greater non zero | |||
| auto&& ret = var2idx.emplace(var, cnt ++); | |||
| mgb_assert(ret.second, "var %s has been already inserted", var->cname()); | |||
| return ret.first->second; | |||
| @@ -702,8 +702,8 @@ ProxyGraph::make_backward_graph( | |||
| } | |||
| auto* gfunc = cg::lookup_grad_func(fwd->dyn_typeinfo()); | |||
| BackwardGraphResult result; | |||
| auto&& igraph = result.backward; | |||
| EncodedSubraph result; | |||
| auto&& igraph = result.graph; | |||
| size_t nr_backward_graph_inputs = 0; | |||
| auto gen_expr = [this, &var2idx, &igraph, &push, &fwd, | |||
| @@ -735,7 +735,7 @@ ProxyGraph::make_backward_graph( | |||
| // set backward graph outputs | |||
| cg::DepOprIter iter{gen_expr}; | |||
| iter.set_visited(fwd); | |||
| result.input_has_grad.resize(inputs.size()); | |||
| result.output_mask.resize(inputs.size()); | |||
| VarNodeArray output_grads_with_unused_var; | |||
| { | |||
| @@ -760,6 +760,7 @@ ProxyGraph::make_backward_graph( | |||
| if (grad_results.valid()) { | |||
| grad = grad_results.val()[i]; | |||
| } else { | |||
| mgb_assert(gfunc, "could not find grad function"); | |||
| auto res = (*gfunc)(fwd, i, output_grads_with_unused_var); | |||
| if (res.from_single()) { | |||
| grad = res.single(); | |||
| @@ -776,9 +777,9 @@ ProxyGraph::make_backward_graph( | |||
| fwd->dyn_typeinfo()->name, i); | |||
| iter.add(grad); | |||
| igraph.outputs.push_back(var2idx.at(grad)); | |||
| result.input_has_grad[i] = true; | |||
| result.output_mask[i] = true; | |||
| } else { | |||
| result.input_has_grad[i] = false; | |||
| result.output_mask[i] = false; | |||
| } | |||
| } | |||
| if (igraph.outputs.empty()) { | |||
| @@ -787,15 +788,15 @@ ProxyGraph::make_backward_graph( | |||
| // set backward graph inputs | |||
| igraph.inputs.reserve(nr_backward_graph_inputs); | |||
| result.save_for_backward.reserve(nr_backward_graph_inputs); | |||
| result.input_mask.reserve(nr_backward_graph_inputs); | |||
| auto write_inputs = [&igraph, &var2idx, &result](const VarNodeArray& vars) { | |||
| for (auto&& i: vars) { | |||
| auto&& iter = var2idx.find(i); | |||
| if (iter != var2idx.end()) { | |||
| igraph.inputs.push_back(iter->second); | |||
| result.save_for_backward.push_back(true); | |||
| result.input_mask.push_back(true); | |||
| } else { | |||
| result.save_for_backward.push_back(false); | |||
| result.input_mask.push_back(false); | |||
| } | |||
| } | |||
| }; | |||
| @@ -40,7 +40,7 @@ public: | |||
| const SmallVector<Tensor*>& outputs, | |||
| const SmallVector<Tensor*>& workspace); | |||
| BackwardGraphResult make_backward_graph( | |||
| EncodedSubraph make_backward_graph( | |||
| const OpDef& opdef, | |||
| const SmallVector<LogicalTensorDesc>& input_descs, | |||
| const SmallVector<bool>& input_requires_grad, | |||
| @@ -133,7 +133,7 @@ size_t get_backward_graph_hash_key(const OpDef& def, | |||
| return state.digest(); | |||
| } | |||
| struct BackwardGraphCache : std::unordered_map<size_t, BackwardGraphResult>, CompNodeDepedentObject { | |||
| struct BackwardGraphCache : std::unordered_map<size_t, EncodedSubraph>, CompNodeDepedentObject { | |||
| std::shared_ptr<void> on_comp_node_finalize() override { | |||
| clear(); | |||
| return {}; | |||
| @@ -142,7 +142,7 @@ struct BackwardGraphCache : std::unordered_map<size_t, BackwardGraphResult>, Com | |||
| } // anonymous namespace | |||
| BackwardGraphResult | |||
| EncodedSubraph | |||
| make_backward_graph(const OpDef& def, | |||
| const SmallVector<LogicalTensorDesc>& inputs, | |||
| const SmallVector<bool>& input_requires_grad, | |||
| @@ -101,5 +101,26 @@ void Subgraph::replace_vars( | |||
| } | |||
| } | |||
| std::string EncodedSubraph::repr() const { | |||
| std::string buffer; | |||
| buffer.push_back('|'); | |||
| for (size_t i = 0; i < input_mask.size(); ++i) { | |||
| buffer.push_back(input_mask[i] ? '#' : ' '); | |||
| } | |||
| buffer.push_back('|'); | |||
| buffer.push_back('\n'); | |||
| buffer.append(graph.repr()); | |||
| buffer.push_back('|'); | |||
| for (size_t i = 0; i < output_mask.size(); ++i) { | |||
| buffer.push_back(output_mask[i] ? '#' : ' '); | |||
| } | |||
| buffer.push_back('|'); | |||
| return buffer; | |||
| } | |||
| size_t EncodedSubraph::hash() const { | |||
| return std::hash<std::string>{}(repr()); | |||
| } | |||
| } // namespace imperative | |||
| } // namespace mgb | |||
| @@ -19,7 +19,7 @@ struct OptimizedBackwardGraphResult { | |||
| SmallVector<bool> save_for_backward; | |||
| SmallVector<bool> input_has_grad; | |||
| OptimizedBackwardGraphResult(const BackwardGraphResult& bgraph); | |||
| OptimizedBackwardGraphResult(const EncodedSubraph& bgraph); | |||
| }; | |||
| } // namespace mgb::imperative | |||
| @@ -29,12 +29,6 @@ enum DispatchMode { | |||
| using SharedOp = std::shared_ptr<OpDef>; | |||
| struct BackwardGraphResult { | |||
| Subgraph backward; | |||
| SmallVector<bool> save_for_backward; | |||
| SmallVector<bool> input_has_grad; | |||
| }; | |||
| class OpDef : public Hashable, | |||
| public NonCopyableObj, | |||
| public std::enable_shared_from_this<OpDef> { | |||
| @@ -91,7 +85,7 @@ public: | |||
| const SmallVector<TensorPtr>& inputs_tensors, | |||
| const SmallVector<MemoryDesc>& inputs_mems); | |||
| static BackwardGraphResult make_backward_graph( | |||
| static EncodedSubraph make_backward_graph( | |||
| const OpDef& def, | |||
| const SmallVector<LogicalTensorDesc>& inputs, | |||
| const SmallVector<bool>& input_requires_grad, | |||
| @@ -38,7 +38,7 @@ void exec(const OpDef& def, | |||
| const SmallVector<TensorPtr>& inputs, | |||
| const SmallVector<TensorPtr>& outputs); | |||
| BackwardGraphResult | |||
| EncodedSubraph | |||
| make_backward_graph(const OpDef& def, | |||
| const SmallVector<LogicalTensorDesc>& inputs, | |||
| const SmallVector<bool>& input_requires_grad, | |||
| @@ -96,5 +96,185 @@ struct Subgraph { | |||
| bool operator==(const Subgraph& rhs) const; | |||
| }; | |||
| struct EncodedSubraph { | |||
| Subgraph graph; | |||
| SmallVector<bool> input_mask; | |||
| SmallVector<bool> output_mask; | |||
| template <typename TContainer> | |||
| TContainer encode_inputs(TContainer inputs) const { | |||
| TContainer encoded_inputs; | |||
| size_t index = 0; | |||
| for (auto&& input : inputs) { | |||
| mgb_assert(index < input_mask.size(), "index out of range"); | |||
| if (input_mask[index++]) { | |||
| encoded_inputs.push_back(input); | |||
| } | |||
| } | |||
| mgb_assert(index == input_mask.size(), "mask size mismatch"); | |||
| return encoded_inputs; | |||
| } | |||
| template <typename TContainer> | |||
| TContainer encode_outputs(TContainer outputs) const { | |||
| TContainer encoded_outputs; | |||
| size_t index = 0; | |||
| for (auto&& output : outputs) { | |||
| mgb_assert(index < output_mask.size(), "index out of range"); | |||
| if (output_mask[index++]) { | |||
| encoded_outputs.push_back(output); | |||
| } | |||
| } | |||
| mgb_assert(index == output_mask.size(), "mask size mismatch"); | |||
| return encoded_outputs; | |||
| } | |||
| template <typename TContainer> | |||
| TContainer decode_outputs(TContainer outputs) const { | |||
| TContainer decoded_outputs; | |||
| size_t index = 0; | |||
| for (size_t i = 0; i < output_mask.size(); i++) { | |||
| mgb_assert(index < output_mask.size(), "index out of range"); | |||
| if (output_mask[i]) { | |||
| decoded_outputs.push_back(outputs[index++]); | |||
| } else { | |||
| decoded_outputs.emplace_back(); | |||
| } | |||
| } | |||
| mgb_assert(decoded_outputs.size() == output_mask.size(), | |||
| "mask size mismatch"); | |||
| return decoded_outputs; | |||
| } | |||
| static EncodedSubraph make(Subgraph graph) { | |||
| EncodedSubraph result; | |||
| result.input_mask = graph.gen_input_mask(); | |||
| result.output_mask = graph.gen_output_mask(); | |||
| graph.inputs = result.encode_inputs(graph.inputs); | |||
| graph.outputs = result.encode_outputs(graph.outputs); | |||
| result.graph = graph; | |||
| return result; | |||
| } | |||
| static EncodedSubraph make_single( | |||
| std::shared_ptr<OpDef> op, | |||
| SmallVector<bool> input_mask, | |||
| SmallVector<bool> output_mask) { | |||
| EncodedSubraph result; | |||
| result.input_mask = input_mask; | |||
| result.output_mask = output_mask; | |||
| Subgraph::var_t last_var = 0; | |||
| for (auto&& mask: input_mask) { | |||
| if (mask) { | |||
| result.graph.inputs.push_back(++last_var); | |||
| } | |||
| } | |||
| for (auto&& mask: output_mask) { | |||
| if (mask) { | |||
| result.graph.outputs.push_back(++last_var); | |||
| } | |||
| } | |||
| result.graph.exprs = {Subgraph::expr_t{op, result.graph.inputs, result.graph.outputs}}; | |||
| return result; | |||
| } | |||
| template <typename T, typename F, typename C> | |||
| SmallVector<T> apply(SmallVector<T> input_vars, F&& f, C&& c) const { | |||
| auto encoded_inputs = encode_inputs(input_vars); | |||
| auto encoded_outputs = graph.apply(encoded_inputs, std::forward<F>(f), | |||
| std::forward<C>(c)); | |||
| return decode_outputs(encoded_outputs); | |||
| } | |||
| std::string repr() const; | |||
| size_t hash() const; | |||
| }; | |||
| template <typename T> | |||
| class GradContext { | |||
| public: | |||
| using var_t = T; | |||
| using vars_t = SmallVector<var_t>; | |||
| using expr_t = Expr<T>; | |||
| private: | |||
| std::unordered_map<var_t, var_t> m_grads; | |||
| std::unordered_set<var_t> m_vars_require_grad; | |||
| std::function<var_t(var_t, var_t)> m_accumulator; | |||
| std::vector<expr_t> m_exprs; | |||
| public: | |||
| GradContext(std::function<var_t(var_t, var_t)> accumulator): m_accumulator{std::move(accumulator)}{} | |||
| SmallVector<bool> get_require_grads(vars_t dests) { | |||
| SmallVector<bool> mask; | |||
| for (auto&& dest: dests) { | |||
| mask.push_back(bool(m_vars_require_grad.count(dest))); | |||
| } | |||
| return mask; | |||
| } | |||
| SmallVector<bool> get_has_grads(vars_t dests) { | |||
| SmallVector<bool> mask; | |||
| for (auto&& dest: dests) { | |||
| mask.push_back(bool(m_grads.count(dest))); | |||
| } | |||
| return mask; | |||
| } | |||
| void mark_require_grads(vars_t dests) { | |||
| for (auto&& dest: dests) { | |||
| m_vars_require_grad.insert(dest); | |||
| } | |||
| } | |||
| var_t accumulate_grad(var_t dest, var_t grad) { | |||
| if (!m_grads.count(dest)) { | |||
| return m_grads[dest] = grad; | |||
| } else { | |||
| return m_grads[dest] = m_accumulator(m_grads[dest], grad); | |||
| } | |||
| } | |||
| void record_expr(std::shared_ptr<OpDef> op, vars_t inputs, vars_t outputs) { | |||
| bool require_grad = false; | |||
| for (auto&& input: inputs) { | |||
| if (m_vars_require_grad.count(input)) { | |||
| require_grad = true; | |||
| break; | |||
| } | |||
| } | |||
| if (require_grad) { | |||
| m_exprs.push_back({op, inputs, outputs}); | |||
| mark_require_grads(outputs); | |||
| } | |||
| } | |||
| template <typename TFunctor> | |||
| void backward(vars_t outputs, vars_t output_grads, TFunctor functor) { | |||
| size_t nr_outputs = outputs.size(); | |||
| for (size_t i = 0; i < nr_outputs; ++i) { | |||
| m_grads[outputs[i]] = output_grads[i]; | |||
| } | |||
| auto exprs = m_exprs; | |||
| std::reverse(exprs.begin(), exprs.end()); | |||
| for (const expr_t& expr: exprs) { | |||
| size_t nr_inputs = expr.inputs.size(); | |||
| vars_t input_grads = functor(expr, get_grads(expr.outputs)); | |||
| mgb_assert(input_grads.size() == nr_inputs, "input size mismatch"); | |||
| for (size_t i = 0; i < nr_inputs; ++i) { | |||
| if (input_grads[i] && m_vars_require_grad.count(expr.inputs[i])) { | |||
| accumulate_grad(expr.inputs[i], input_grads[i]); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| var_t get_grad(var_t dest) { | |||
| if (m_grads.count(dest)) { | |||
| return m_grads.at(dest); | |||
| } | |||
| return 0; | |||
| } | |||
| vars_t get_grads(vars_t dests) { | |||
| vars_t grads; | |||
| for (auto&& dest: dests) { | |||
| grads.push_back(get_grad(dest)); | |||
| } | |||
| return grads; | |||
| } | |||
| }; | |||
| } // namespace imperative | |||
| } // namespace mgb | |||
| @@ -22,22 +22,22 @@ using namespace cg; | |||
| using namespace imperative; | |||
| template <typename T> | |||
| T prepare_backward_graph_inputs(const BackwardGraphResult& bg, const T& inputs, | |||
| T prepare_backward_graph_inputs(const EncodedSubraph& bg, const T& inputs, | |||
| const T& outputs, const T& grads) { | |||
| T ret; | |||
| size_t i = 0; | |||
| for (auto&& t : inputs) { | |||
| if (bg.save_for_backward[i++]) { | |||
| if (bg.input_mask[i++]) { | |||
| ret.push_back(t); | |||
| } | |||
| } | |||
| for (auto&& t : outputs) { | |||
| if (bg.save_for_backward[i++]) { | |||
| if (bg.input_mask[i++]) { | |||
| ret.push_back(t); | |||
| } | |||
| } | |||
| for (auto&& t : grads) { | |||
| if (bg.save_for_backward[i++]) { | |||
| if (bg.input_mask[i++]) { | |||
| ret.push_back(t); | |||
| } | |||
| } | |||
| @@ -45,10 +45,10 @@ T prepare_backward_graph_inputs(const BackwardGraphResult& bg, const T& inputs, | |||
| } | |||
| template <typename T, typename U> | |||
| T expand_grads(const U& bg, const T& outputs) { | |||
| T ret(bg.input_has_grad.size()); | |||
| for (size_t i = 0, j = 0; i < bg.input_has_grad.size(); ++i) { | |||
| if (bg.input_has_grad[i]) { | |||
| T expand_grads(const U& mask, const T& outputs) { | |||
| T ret(mask.size()); | |||
| for (size_t i = 0, j = 0; i < mask.size(); ++i) { | |||
| if (mask[i]) { | |||
| ret[i] = outputs[j++]; | |||
| } | |||
| } | |||
| @@ -80,7 +80,7 @@ T prepare_optimized_backward_inputs(const OptimizedBackwardGraphResult& bg, | |||
| } | |||
| SmallVector<TensorPtr> apply_shared_on_physical_tensor( | |||
| std::shared_ptr<OpDef> def, SmallVector<TensorPtr> inputs) { | |||
| std::shared_ptr<OpDef> def, SmallVector<TensorPtr> inputs, size_t nr_outputs) { | |||
| return OpDef::apply_on_physical_tensor(*def, inputs); | |||
| } | |||
| @@ -104,8 +104,8 @@ TEST(TestImperative, BackwardGraphBasic) { | |||
| } | |||
| auto result = OpDef::make_backward_graph(*attr, input_descs, {true, true}, | |||
| {true}); | |||
| auto&& save_for_backward = result.save_for_backward; | |||
| auto&& input_has_grad = result.input_has_grad; | |||
| auto&& save_for_backward = result.input_mask; | |||
| auto&& input_has_grad = result.output_mask; | |||
| auto outputs = OpDef::apply_on_physical_tensor(*attr, inputs); | |||
| inputs.push_back(outputs[0]); | |||
| @@ -124,7 +124,7 @@ TEST(TestImperative, BackwardGraphBasic) { | |||
| } | |||
| } | |||
| inputs.clear(); | |||
| auto input_grads = result.backward.apply(backward_graph_inputs, | |||
| auto input_grads = result.graph.apply(backward_graph_inputs, | |||
| apply_shared_on_physical_tensor, | |||
| [&](auto&& x) { return x; }); | |||
| mgb_assert(input_grads.size() == input_has_grad.size()); | |||
| @@ -159,8 +159,8 @@ TEST(TestImperative, BackwardGraphIdentity) { | |||
| input_descs.push_back({a->layout(), a->comp_node()}); | |||
| auto result = | |||
| OpDef::make_backward_graph(*attr, input_descs, {true}, {true}); | |||
| auto&& save_for_backward = result.save_for_backward; | |||
| auto&& input_has_grad = result.input_has_grad; | |||
| auto&& save_for_backward = result.input_mask; | |||
| auto&& input_has_grad = result.output_mask; | |||
| auto outputs = OpDef::apply_on_physical_tensor(*attr, inputs); | |||
| inputs.push_back(outputs[0]); | |||
| @@ -178,7 +178,7 @@ TEST(TestImperative, BackwardGraphIdentity) { | |||
| } | |||
| } | |||
| inputs.clear(); | |||
| auto input_grads = result.backward.apply(backward_graph_inputs, | |||
| auto input_grads = result.graph.apply(backward_graph_inputs, | |||
| apply_shared_on_physical_tensor, | |||
| [&](auto&& x) { return x; }); | |||
| mgb_assert(input_grads.size() == input_has_grad.size()); | |||
| @@ -245,7 +245,7 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) { | |||
| prepare_backward_graph_inputs<SmallVector<TensorPtr>>( | |||
| bg, {a_tn, b_tn}, {c_tn}, {dc_tn}); | |||
| auto grads = | |||
| expand_grads(bg, bg.backward.apply(backward_graph_inputs, | |||
| expand_grads(bg.output_mask, bg.graph.apply(backward_graph_inputs, | |||
| apply_shared_on_physical_tensor, | |||
| [&](auto&& x) { return x; })); | |||
| @@ -262,7 +262,7 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) { | |||
| prepare_optimized_backward_inputs<SmallVector<TensorPtr>>( | |||
| obg, precomp, {a_tn, b_tn}, {c_tn}, {dc_tn}); | |||
| auto grads2 = expand_grads( | |||
| obg, | |||
| obg.input_has_grad, | |||
| obg.backward.apply(backward_inputs, apply_shared_on_physical_tensor, | |||
| [&](auto&& x) { return x; })); | |||