GitOrigin-RevId: 56d90be0e7
tags/v1.6.0-rc1
@@ -77,7 +77,7 @@ std::shared_ptr<OptimizedBackwardGraphResult> make_backward_graph(
     std::shared_ptr<OptimizedBackwardGraphResult> ret;
     auto bg = OpDef::make_backward_graph(
             *ctx.op, inputs, input_requires_grad, output_has_grad);
-    if (!bg.backward.empty()) {
+    if (!bg.graph.empty()) {
        ret = std::make_shared<OptimizedBackwardGraphResult>(bg);
    }
    backward_graph_cache.emplace(key, ret);
@@ -37,7 +37,7 @@ void init_imperative_rt(py::module m) {
            const SmallVector<bool>& input_requires_grad,
            const SmallVector<bool>& output_has_grad){
        auto result = OpDef::make_backward_graph(def, inputs, input_requires_grad, output_has_grad);
-       return std::make_tuple("backward_graph", result.save_for_backward, result.input_has_grad);
+       return std::make_tuple("backward_graph", result.input_mask, result.output_mask);
    };
    m.def("make_backward_graph", make_backward_graph);
 }
@@ -16,19 +16,19 @@
 using namespace mgb;
 using namespace imperative;

-OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const BackwardGraphResult& src)
-        : input_has_grad(src.input_has_grad) {
-    if (src.backward.exprs.size() <= 1) {
+OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const EncodedSubraph& src)
+        : input_has_grad(src.output_mask) {
+    if (src.graph.exprs.size() <= 1) {
         // backward graph only contains a single op
-        backward = src.backward;
-        save_for_backward = src.save_for_backward;
+        backward = src.graph;
+        save_for_backward = src.input_mask;
         return;
     }
-    save_for_backward.resize(src.save_for_backward.size(), false);
+    save_for_backward.resize(src.input_mask.size(), false);

-    auto&& graph = src.backward;
-    auto&& mask = src.save_for_backward;
-    size_t input_size = src.input_has_grad.size();
+    auto&& graph = src.graph;
+    auto&& mask = src.input_mask;
+    size_t input_size = src.output_mask.size();
     size_t output_size = (mask.size() - input_size) / 2;
     mgb_assert(input_size + output_size * 2 == mask.size());
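For reference (this reading comes from the constructor above and the `prepare_backward_graph_inputs` helper in the tests further down, not from the commit itself): the `input_mask` consumed here is laid out as [forward inputs | forward outputs | output gradients], which is exactly what the assert encodes. A binary op with one output, for instance, gives `input_size = 2` and `output_size = 1`, so the mask has 2 + 2*1 = 4 entries: the two forward inputs, the forward output, and that output's incoming gradient.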
@@ -80,7 +80,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> OpDef::infer_output_attrs_falli
     return def.trait()->infer_output_attrs_fallible(def, inputs);
 }

-BackwardGraphResult OpDef::make_backward_graph(
+EncodedSubraph OpDef::make_backward_graph(
         const OpDef& def,
         const SmallVector<LogicalTensorDesc>& inputs,
         const SmallVector<bool>& input_requires_grad,
@@ -668,14 +668,14 @@ struct ProxyGraph::GradGraph {
     cg::VarNode* grad;
 };

-BackwardGraphResult
+EncodedSubraph
 ProxyGraph::make_backward_graph(
         const OpDef& opdef,
         const SmallVector<LogicalTensorDesc>& input_descs,
         const SmallVector<bool>& input_requires_grad,
         const SmallVector<bool>& output_has_grad) {
     ThinHashMap<VarNode*, size_t> var2idx;
-    auto push = [&var2idx, cnt=0](VarNode* var) mutable {
+    auto push = [&var2idx, cnt=1](VarNode* var) mutable { // cnt is always non-zero
         auto&& ret = var2idx.emplace(var, cnt ++);
         mgb_assert(ret.second, "var %s has been already inserted", var->cname());
         return ret.first->second;
@@ -702,8 +702,8 @@ ProxyGraph::make_backward_graph(
     }
     auto* gfunc = cg::lookup_grad_func(fwd->dyn_typeinfo());

-    BackwardGraphResult result;
-    auto&& igraph = result.backward;
+    EncodedSubraph result;
+    auto&& igraph = result.graph;
     size_t nr_backward_graph_inputs = 0;
     auto gen_expr = [this, &var2idx, &igraph, &push, &fwd,
@@ -735,7 +735,7 @@ ProxyGraph::make_backward_graph(
     // set backward graph outputs
     cg::DepOprIter iter{gen_expr};
     iter.set_visited(fwd);
-    result.input_has_grad.resize(inputs.size());
+    result.output_mask.resize(inputs.size());

     VarNodeArray output_grads_with_unused_var;
     {
@@ -760,6 +760,7 @@ ProxyGraph::make_backward_graph(
         if (grad_results.valid()) {
             grad = grad_results.val()[i];
         } else {
+            mgb_assert(gfunc, "could not find grad function");
             auto res = (*gfunc)(fwd, i, output_grads_with_unused_var);
             if (res.from_single()) {
                 grad = res.single();
@@ -776,9 +777,9 @@ ProxyGraph::make_backward_graph(
                     fwd->dyn_typeinfo()->name, i);
             iter.add(grad);
             igraph.outputs.push_back(var2idx.at(grad));
-            result.input_has_grad[i] = true;
+            result.output_mask[i] = true;
         } else {
-            result.input_has_grad[i] = false;
+            result.output_mask[i] = false;
         }
     }
     if (igraph.outputs.empty()) {
@@ -787,15 +788,15 @@ ProxyGraph::make_backward_graph(
     // set backward graph inputs
     igraph.inputs.reserve(nr_backward_graph_inputs);
-    result.save_for_backward.reserve(nr_backward_graph_inputs);
+    result.input_mask.reserve(nr_backward_graph_inputs);
     auto write_inputs = [&igraph, &var2idx, &result](const VarNodeArray& vars) {
         for (auto&& i: vars) {
             auto&& iter = var2idx.find(i);
             if (iter != var2idx.end()) {
                 igraph.inputs.push_back(iter->second);
-                result.save_for_backward.push_back(true);
+                result.input_mask.push_back(true);
             } else {
-                result.save_for_backward.push_back(false);
+                result.input_mask.push_back(false);
             }
         }
     };
@@ -40,7 +40,7 @@ public:
            const SmallVector<Tensor*>& outputs,
            const SmallVector<Tensor*>& workspace);

-    BackwardGraphResult make_backward_graph(
+    EncodedSubraph make_backward_graph(
            const OpDef& opdef,
            const SmallVector<LogicalTensorDesc>& input_descs,
            const SmallVector<bool>& input_requires_grad,
@@ -133,7 +133,7 @@ size_t get_backward_graph_hash_key(const OpDef& def,
    return state.digest();
 }

-struct BackwardGraphCache : std::unordered_map<size_t, BackwardGraphResult>, CompNodeDepedentObject {
+struct BackwardGraphCache : std::unordered_map<size_t, EncodedSubraph>, CompNodeDepedentObject {
    std::shared_ptr<void> on_comp_node_finalize() override {
        clear();
        return {};
@@ -142,7 +142,7 @@ struct BackwardGraphCache : std::unordered_map<size_t, BackwardGraphResult>, Com
 } // anonymous namespace

-BackwardGraphResult
+EncodedSubraph
 make_backward_graph(const OpDef& def,
                     const SmallVector<LogicalTensorDesc>& inputs,
                     const SmallVector<bool>& input_requires_grad,
@@ -101,5 +101,26 @@ void Subgraph::replace_vars(
     }
 }

+std::string EncodedSubraph::repr() const {
+    std::string buffer;
+    buffer.push_back('|');
+    for (size_t i = 0; i < input_mask.size(); ++i) {
+        buffer.push_back(input_mask[i] ? '#' : ' ');
+    }
+    buffer.push_back('|');
+    buffer.push_back('\n');
+    buffer.append(graph.repr());
+    buffer.push_back('|');
+    for (size_t i = 0; i < output_mask.size(); ++i) {
+        buffer.push_back(output_mask[i] ? '#' : ' ');
+    }
+    buffer.push_back('|');
+    return buffer;
+}
+
+size_t EncodedSubraph::hash() const {
+    return std::hash<std::string>{}(repr());
+}
+
 } // namespace imperative
 } // namespace mgb
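To make the `repr()` format above concrete (derived from the code, not captured output): with `input_mask = {1, 1, 0}` and `output_mask = {1}`, the string opens with the line `|## |`, then embeds `graph.repr()`, and closes with `|#|`. `hash()` simply runs `std::hash<std::string>` over that string, so two encoded subgraphs with identical repr map to the same hash.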
@@ -19,7 +19,7 @@ struct OptimizedBackwardGraphResult {
     SmallVector<bool> save_for_backward;
     SmallVector<bool> input_has_grad;

-    OptimizedBackwardGraphResult(const BackwardGraphResult& bgraph);
+    OptimizedBackwardGraphResult(const EncodedSubraph& bgraph);
 };

 } // namespace mgb::imperative
@@ -29,12 +29,6 @@ enum DispatchMode {
 using SharedOp = std::shared_ptr<OpDef>;

-struct BackwardGraphResult {
-    Subgraph backward;
-    SmallVector<bool> save_for_backward;
-    SmallVector<bool> input_has_grad;
-};
-
 class OpDef : public Hashable,
               public NonCopyableObj,
               public std::enable_shared_from_this<OpDef> {
@@ -91,7 +85,7 @@ public:
             const SmallVector<TensorPtr>& inputs_tensors,
             const SmallVector<MemoryDesc>& inputs_mems);

-    static BackwardGraphResult make_backward_graph(
+    static EncodedSubraph make_backward_graph(
             const OpDef& def,
             const SmallVector<LogicalTensorDesc>& inputs,
             const SmallVector<bool>& input_requires_grad,
@@ -38,7 +38,7 @@ void exec(const OpDef& def,
           const SmallVector<TensorPtr>& inputs,
           const SmallVector<TensorPtr>& outputs);

-BackwardGraphResult
+EncodedSubraph
 make_backward_graph(const OpDef& def,
                     const SmallVector<LogicalTensorDesc>& inputs,
                     const SmallVector<bool>& input_requires_grad,
@@ -96,5 +96,185 @@ struct Subgraph {
     bool operator==(const Subgraph& rhs) const;
 };

+struct EncodedSubraph {
+    Subgraph graph;
+    SmallVector<bool> input_mask;
+    SmallVector<bool> output_mask;
+
+    template <typename TContainer>
+    TContainer encode_inputs(TContainer inputs) const {
+        TContainer encoded_inputs;
+        size_t index = 0;
+        for (auto&& input : inputs) {
+            mgb_assert(index < input_mask.size(), "index out of range");
+            if (input_mask[index++]) {
+                encoded_inputs.push_back(input);
+            }
+        }
+        mgb_assert(index == input_mask.size(), "mask size mismatch");
+        return encoded_inputs;
+    }
+
+    template <typename TContainer>
+    TContainer encode_outputs(TContainer outputs) const {
+        TContainer encoded_outputs;
+        size_t index = 0;
+        for (auto&& output : outputs) {
+            mgb_assert(index < output_mask.size(), "index out of range");
+            if (output_mask[index++]) {
+                encoded_outputs.push_back(output);
+            }
+        }
+        mgb_assert(index == output_mask.size(), "mask size mismatch");
+        return encoded_outputs;
+    }
+
+    template <typename TContainer>
+    TContainer decode_outputs(TContainer outputs) const {
+        TContainer decoded_outputs;
+        size_t index = 0;
+        for (size_t i = 0; i < output_mask.size(); i++) {
+            mgb_assert(index < output_mask.size(), "index out of range");
+            if (output_mask[i]) {
+                decoded_outputs.push_back(outputs[index++]);
+            } else {
+                decoded_outputs.emplace_back();
+            }
+        }
+        mgb_assert(decoded_outputs.size() == output_mask.size(),
+                   "mask size mismatch");
+        return decoded_outputs;
+    }
+
+    static EncodedSubraph make(Subgraph graph) {
+        EncodedSubraph result;
+        result.input_mask = graph.gen_input_mask();
+        result.output_mask = graph.gen_output_mask();
+        graph.inputs = result.encode_inputs(graph.inputs);
+        graph.outputs = result.encode_outputs(graph.outputs);
+        result.graph = graph;
+        return result;
+    }
+
+    static EncodedSubraph make_single(
+            std::shared_ptr<OpDef> op,
+            SmallVector<bool> input_mask,
+            SmallVector<bool> output_mask) {
+        EncodedSubraph result;
+        result.input_mask = input_mask;
+        result.output_mask = output_mask;
+        Subgraph::var_t last_var = 0;
+        for (auto&& mask: input_mask) {
+            if (mask) {
+                result.graph.inputs.push_back(++last_var);
+            }
+        }
+        for (auto&& mask: output_mask) {
+            if (mask) {
+                result.graph.outputs.push_back(++last_var);
+            }
+        }
+        result.graph.exprs = {Subgraph::expr_t{op, result.graph.inputs, result.graph.outputs}};
+        return result;
+    }
+
+    template <typename T, typename F, typename C>
+    SmallVector<T> apply(SmallVector<T> input_vars, F&& f, C&& c) const {
+        auto encoded_inputs = encode_inputs(input_vars);
+        auto encoded_outputs = graph.apply(encoded_inputs, std::forward<F>(f),
+                                           std::forward<C>(c));
+        return decode_outputs(encoded_outputs);
+    }
+
+    std::string repr() const;
+    size_t hash() const;
+};
+
+template <typename T>
+class GradContext {
+public:
+    using var_t = T;
+    using vars_t = SmallVector<var_t>;
+    using expr_t = Expr<T>;
+private:
+    std::unordered_map<var_t, var_t> m_grads;
+    std::unordered_set<var_t> m_vars_require_grad;
+    std::function<var_t(var_t, var_t)> m_accumulator;
+    std::vector<expr_t> m_exprs;
+public:
+    GradContext(std::function<var_t(var_t, var_t)> accumulator): m_accumulator{std::move(accumulator)}{}
+    SmallVector<bool> get_require_grads(vars_t dests) {
+        SmallVector<bool> mask;
+        for (auto&& dest: dests) {
+            mask.push_back(bool(m_vars_require_grad.count(dest)));
+        }
+        return mask;
+    }
+    SmallVector<bool> get_has_grads(vars_t dests) {
+        SmallVector<bool> mask;
+        for (auto&& dest: dests) {
+            mask.push_back(bool(m_grads.count(dest)));
+        }
+        return mask;
+    }
+    void mark_require_grads(vars_t dests) {
+        for (auto&& dest: dests) {
+            m_vars_require_grad.insert(dest);
+        }
+    }
+    var_t accumulate_grad(var_t dest, var_t grad) {
+        if (!m_grads.count(dest)) {
+            return m_grads[dest] = grad;
+        } else {
+            return m_grads[dest] = m_accumulator(m_grads[dest], grad);
+        }
+    }
+    void record_expr(std::shared_ptr<OpDef> op, vars_t inputs, vars_t outputs) {
+        bool require_grad = false;
+        for (auto&& input: inputs) {
+            if (m_vars_require_grad.count(input)) {
+                require_grad = true;
+                break;
+            }
+        }
+        if (require_grad) {
+            m_exprs.push_back({op, inputs, outputs});
+            mark_require_grads(outputs);
+        }
+    }
+    template <typename TFunctor>
+    void backward(vars_t outputs, vars_t output_grads, TFunctor functor) {
+        size_t nr_outputs = outputs.size();
+        for (size_t i = 0; i < nr_outputs; ++i) {
+            m_grads[outputs[i]] = output_grads[i];
+        }
+        auto exprs = m_exprs;
+        std::reverse(exprs.begin(), exprs.end());
+        for (const expr_t& expr: exprs) {
+            size_t nr_inputs = expr.inputs.size();
+            vars_t input_grads = functor(expr, get_grads(expr.outputs));
+            mgb_assert(input_grads.size() == nr_inputs, "input size mismatch");
+            for (size_t i = 0; i < nr_inputs; ++i) {
+                if (input_grads[i] && m_vars_require_grad.count(expr.inputs[i])) {
+                    accumulate_grad(expr.inputs[i], input_grads[i]);
+                }
+            }
+        }
+    }
+    var_t get_grad(var_t dest) {
+        if (m_grads.count(dest)) {
+            return m_grads.at(dest);
+        }
+        return 0;
+    }
+    vars_t get_grads(vars_t dests) {
+        vars_t grads;
+        for (auto&& dest: dests) {
+            grads.push_back(get_grad(dest));
+        }
+        return grads;
+    }
+};
+
 } // namespace imperative
 } // namespace mgb
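A small driver helps to see how the two additions above fit together. The sketch below is illustrative only: the include path is a guess, variable ids are plain integers, the `OpDef` pointer is left null because `GradContext` never dereferences it, and the gradient rule inside the functor is fabricated for the example.

// Illustrative driver; include path assumed, ids and the gradient rule are fake.
#include "megbrain/imperative/subgraph.h"

using namespace mgb;
using namespace mgb::imperative;

void gradcontext_sketch() {
    using var_t = size_t;
    // A real accumulator would emit an add op; keeping the newer grad id
    // is enough for this sketch.
    GradContext<var_t> ctx{[](var_t, var_t rhs) { return rhs; }};

    var_t x = 1, y = 2, z = 3;              // pretend z = f(x, y)
    ctx.mark_require_grads({x, y});
    ctx.record_expr(nullptr, {x, y}, {z});  // null OpDef: never dereferenced here

    // Seed dz and walk the recorded exprs in reverse; the functor stands in
    // for "apply the backward rule of expr.op" and forwards dz to every input.
    var_t dz = 7;
    ctx.backward({z}, {dz}, [](auto&& expr, auto&& output_grads) {
        return SmallVector<var_t>(expr.inputs.size(), output_grads[0]);
    });
    SmallVector<var_t> dxdy = ctx.get_grads({x, y});  // {7, 7}

    // EncodedSubraph::make_single wraps one op plus its masks; with
    // input_mask = {1, 0, 1} only the 1st and 3rd caller inputs are kept.
    EncodedSubraph single = EncodedSubraph::make_single(nullptr, {true, false, true}, {true});
    SmallVector<int> kept = single.encode_inputs<SmallVector<int>>({10, 20, 30});  // {10, 30}
}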
@@ -22,22 +22,22 @@ using namespace cg;
 using namespace imperative;

 template <typename T>
-T prepare_backward_graph_inputs(const BackwardGraphResult& bg, const T& inputs,
+T prepare_backward_graph_inputs(const EncodedSubraph& bg, const T& inputs,
                                 const T& outputs, const T& grads) {
     T ret;
     size_t i = 0;
     for (auto&& t : inputs) {
-        if (bg.save_for_backward[i++]) {
+        if (bg.input_mask[i++]) {
             ret.push_back(t);
         }
     }
     for (auto&& t : outputs) {
-        if (bg.save_for_backward[i++]) {
+        if (bg.input_mask[i++]) {
             ret.push_back(t);
         }
     }
     for (auto&& t : grads) {
-        if (bg.save_for_backward[i++]) {
+        if (bg.input_mask[i++]) {
             ret.push_back(t);
         }
     }
@@ -45,10 +45,10 @@ T prepare_backward_graph_inputs(const BackwardGraphResult& bg, const T& inputs,
 }

 template <typename T, typename U>
-T expand_grads(const U& bg, const T& outputs) {
-    T ret(bg.input_has_grad.size());
-    for (size_t i = 0, j = 0; i < bg.input_has_grad.size(); ++i) {
-        if (bg.input_has_grad[i]) {
+T expand_grads(const U& mask, const T& outputs) {
+    T ret(mask.size());
+    for (size_t i = 0, j = 0; i < mask.size(); ++i) {
+        if (mask[i]) {
             ret[i] = outputs[j++];
         }
     }
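With the new signature, `expand_grads` is keyed on a bare mask instead of the result struct: given `mask = {true, false, true}` and the compacted `outputs = {ga, gc}`, it returns `{ga, <default-constructed entry>, gc}`. That is why the two call sites near the end now pass `bg.output_mask` and `obg.input_has_grad` respectively.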
@@ -80,7 +80,7 @@ T prepare_optimized_backward_inputs(const OptimizedBackwardGraphResult& bg,
 }

 SmallVector<TensorPtr> apply_shared_on_physical_tensor(
-        std::shared_ptr<OpDef> def, SmallVector<TensorPtr> inputs) {
+        std::shared_ptr<OpDef> def, SmallVector<TensorPtr> inputs, size_t nr_outputs) {
     return OpDef::apply_on_physical_tensor(*def, inputs);
 }
@@ -104,8 +104,8 @@ TEST(TestImperative, BackwardGraphBasic) {
     }
     auto result = OpDef::make_backward_graph(*attr, input_descs, {true, true},
                                              {true});
-    auto&& save_for_backward = result.save_for_backward;
-    auto&& input_has_grad = result.input_has_grad;
+    auto&& save_for_backward = result.input_mask;
+    auto&& input_has_grad = result.output_mask;

     auto outputs = OpDef::apply_on_physical_tensor(*attr, inputs);
     inputs.push_back(outputs[0]);
@@ -124,7 +124,7 @@ TEST(TestImperative, BackwardGraphBasic) {
         }
     }
     inputs.clear();
-    auto input_grads = result.backward.apply(backward_graph_inputs,
+    auto input_grads = result.graph.apply(backward_graph_inputs,
                                              apply_shared_on_physical_tensor,
                                              [&](auto&& x) { return x; });
     mgb_assert(input_grads.size() == input_has_grad.size());
@@ -159,8 +159,8 @@ TEST(TestImperative, BackwardGraphIdentity) {
     input_descs.push_back({a->layout(), a->comp_node()});
     auto result =
             OpDef::make_backward_graph(*attr, input_descs, {true}, {true});
-    auto&& save_for_backward = result.save_for_backward;
-    auto&& input_has_grad = result.input_has_grad;
+    auto&& save_for_backward = result.input_mask;
+    auto&& input_has_grad = result.output_mask;

     auto outputs = OpDef::apply_on_physical_tensor(*attr, inputs);
     inputs.push_back(outputs[0]);
@@ -178,7 +178,7 @@ TEST(TestImperative, BackwardGraphIdentity) {
         }
     }
     inputs.clear();
-    auto input_grads = result.backward.apply(backward_graph_inputs,
+    auto input_grads = result.graph.apply(backward_graph_inputs,
                                              apply_shared_on_physical_tensor,
                                              [&](auto&& x) { return x; });
     mgb_assert(input_grads.size() == input_has_grad.size());
@@ -245,7 +245,7 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) {
             prepare_backward_graph_inputs<SmallVector<TensorPtr>>(
                     bg, {a_tn, b_tn}, {c_tn}, {dc_tn});
     auto grads =
-            expand_grads(bg, bg.backward.apply(backward_graph_inputs,
+            expand_grads(bg.output_mask, bg.graph.apply(backward_graph_inputs,
                                                apply_shared_on_physical_tensor,
                                                [&](auto&& x) { return x; }));
@@ -262,7 +262,7 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) {
             prepare_optimized_backward_inputs<SmallVector<TensorPtr>>(
                     obg, precomp, {a_tn, b_tn}, {c_tn}, {dc_tn});
     auto grads2 = expand_grads(
-            obg,
+            obg.input_has_grad,
             obg.backward.apply(backward_inputs, apply_shared_on_physical_tensor,
                                [&](auto&& x) { return x; }));