GitOrigin-RevId: 2c0c8f330d
tags/v1.9.0
@@ -11,10 +11,13 @@
 #include "./proxy_graph.h"
 #include "./blob_manager_impl.h"
+#include "megbrain/graph.h"
 #include "megbrain/graph/operator_node.h"
 #include "megbrain/graph/static_infer.h"
 #include "megbrain/imperative/ops/backward_graph.h"
 #include "megbrain/imperative/ops/opr_attr.h"
+#include "megbrain/imperative/subgraph_detail.h"
+#include "megbrain/opr/basic_arith.h"
 #include "megbrain/opr/internal/megdnn_opr_wrapper.h"
 #include "megbrain/opr/io.h"
 #include "megbrain/opr/tensor_manip.h"
@@ -486,139 +489,83 @@ EncodedSubgraph ProxyGraph::make_backward_graph(
         const OpDef& opdef, const SmallVector<LogicalTensorDesc>& input_descs,
         const SmallVector<bool>& input_requires_grad,
         const SmallVector<bool>& output_has_grad) {
-    ThinHashMap<VarNode*, size_t> var2idx;
-    auto push = [&var2idx,
-                 cnt = 1](VarNode* var) mutable {  // cnt is always greater non zero
-        auto&& ret = var2idx.emplace(var, cnt++);
-        mgb_assert(ret.second, "var %s has been already inserted", var->cname());
-        return ret.first->second;
-    };
+    using op_t = OperatorNodeBase*;
+    using var_t = VarNode*;
+    using vars_t = VarNodeArray;
     auto inputs = make_input_place_holders(input_descs);
-    auto fwd = OpDef::apply_on_var_node(opdef, inputs)[0]->owner_opr();
-    auto&& outputs = fwd->usable_output();
+    auto outputs = OpDef::apply_on_var_node(opdef, inputs);
     SmallVector<LogicalTensorDesc> output_descs;
     for (auto&& i : outputs) {
         output_descs.push_back({TensorLayout{i->dtype()}, i->comp_node()});
     }
+    GradContext<op_t, var_t> grad_context{[&](VarNode* lhs, VarNode* rhs) -> VarNode* {
+        auto add = opr::Elemwise::Mode::ADD;
+        return opr::Elemwise::make(VarNodeArray{lhs, rhs}, add).node();
+    }};
+    cg::DepOprIter iter{[&](OperatorNodeBase* op) {
+        grad_context.record_expr(op, op->input(), op->output());
+    }};
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        auto& input = inputs[i];
+        iter.set_visited(input->owner_opr());
+        if (input_requires_grad[i]) {
+            grad_context.mark_require_grad(input);
+        }
+    }
+    for (auto&& output : outputs) {
+        iter.add(output);
+    }
     auto output_grads = make_input_place_holders(output_descs);
-    mgb_assert(
-            output_grads.size() == output_has_grad.size(), "%d vs %d",
-            output_grads.size(), output_has_grad.size());
-    bool any_input_has_grad = false;
-    for (size_t i = 0; i < output_grads.size(); ++i) {
+    for (size_t i = 0; i < outputs.size(); ++i) {
         if (!output_has_grad[i]) {
             output_grads[i] = nullptr;
-        } else {
-            any_input_has_grad = true;
         }
     }
-    if (!any_input_has_grad) {
-        return {};
-    }
-    auto* gfunc = cg::lookup_grad_func(fwd->dyn_typeinfo());
-    EncodedSubgraph result;
-    auto&& igraph = result.graph;
-    size_t nr_backward_graph_inputs = 0;
-    auto gen_expr = [this, &var2idx, &igraph, &push, &fwd,
-                     &nr_backward_graph_inputs](cg::OperatorNodeBase* op) {
-        if (auto t = as_tensor(op)) {
-            mgb_assert(op->output().size() == 1);
-            igraph.constants.emplace_back(push(op->output(0)), std::move(t));
-        } else if (op->same_type<InputPlaceholder>()) {
-            ++nr_backward_graph_inputs;
-            push(op->output(0));
-        } else {
-            SmallVector<size_t> inputs, outputs;
-            for (auto&& i : op->input()) {
-                if (i->owner_opr() == fwd) {
-                    if (var2idx.find(i) == var2idx.end()) {
-                        ++nr_backward_graph_inputs;
-                        push(i);
-                    }
-                }
-                inputs.push_back(var2idx.at(i));
-            }
-            for (auto&& i : op->usable_output()) {
-                outputs.push_back(push(i));
-            }
-            igraph.exprs.push_back({OpDef::make_from_op_node(op), inputs, outputs});
-        }
-    };
-    // set backward graph outputs
-    cg::DepOprIter iter{gen_expr};
-    iter.set_visited(fwd);
-    result.output_mask.resize(inputs.size());
-    VarNodeArray output_grads_with_unused_var;
-    {
-        auto iter = output_grads.begin();
-        for (auto&& i : fwd->output()) {
-            if (i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
-                // the var node with VOLATILE_CONTENT(e.g. workspace
-                // or an empty var) would not be considered as a normal
-                // output, so its grad is always NULL
-                output_grads_with_unused_var.push_back(nullptr);
-            } else {
-                output_grads_with_unused_var.push_back(*iter);
-                ++iter;
-            }
-        }
-        mgb_assert(iter == output_grads.end());
-    }
-    Maybe<VarNodeArray> grad_results;
-    for (size_t i = 0; i < inputs.size(); ++i) {
-        VarNode* grad;
-        if (grad_results.valid()) {
-            grad = grad_results.val()[i];
-        } else {
-            mgb_assert(gfunc, "could not find grad function");
-            auto res = (*gfunc)(fwd, i, output_grads_with_unused_var);
-            if (res.from_single()) {
-                grad = res.single();
-            } else {
-                grad_results.emplace(res.all(fwd));
-                grad = grad_results.val()[i];
-            }
-        }
-        if (grad && !grad->owner_opr()->same_type<opr::InvalidGrad>() &&
-            input_requires_grad[i]) {
-            mgb_assert(
-                    !grad->owner_opr()->same_type<opr::InvalidGrad>(),
-                    "gradient of operator %s w.r.t. input #%lu is "
-                    "either not well defined or not implemented",
-                    fwd->dyn_typeinfo()->name, i);
-            iter.add(grad);
-            igraph.outputs.push_back(var2idx.at(grad));
-            result.output_mask[i] = true;
-        } else {
-            result.output_mask[i] = false;
-        }
-    }
-    if (igraph.outputs.empty()) {
-        return {};
-    }
-    // set backward graph inputs
-    auto write_inputs = [&igraph, &var2idx, &result](const VarNodeArray& vars) {
-        for (auto&& i : vars) {
-            auto&& iter = var2idx.find(i);
-            if (iter != var2idx.end()) {
-                igraph.inputs.push_back(iter->second);
-                result.input_mask.push_back(true);
-            } else {
-                result.input_mask.push_back(false);
-            }
-        }
-    };
-    write_inputs(inputs);
-    write_inputs(outputs);
-    write_inputs(output_grads);
-    mgb_assert(igraph.inputs.size() == nr_backward_graph_inputs);
-    return result;
+    auto compute_input_grads = [&](op_t op, vars_t inputs, vars_t outputs,
+                                   vars_t output_grads) {
+        auto* gfunc = cg::lookup_grad_func(op->dyn_typeinfo());
+        vars_t input_grads(inputs.size(), nullptr);
+        bool any_grad = false;
+        for (auto&& output_grad : output_grads) {
+            if (output_grad) {
+                any_grad = true;
+            }
+        }
+        if (!gfunc || !any_grad) {
+            return input_grads;
+        }
+        Maybe<VarNodeArray> grad_results;
+        auto&& input_requires_grad = grad_context.get_require_grads(inputs);
+        for (size_t i = 0; i < inputs.size(); ++i) {
+            VarNode* grad;
+            if (grad_results.valid()) {
+                grad = grad_results.val()[i];
+            } else {
+                mgb_assert(gfunc, "could not find grad function");
+                auto res = (*gfunc)(op, i, output_grads);
+                if (res.from_single()) {
+                    grad = res.single();
+                } else {
+                    grad_results.emplace(res.all(op));
+                    grad = grad_results.val()[i];
+                }
+            }
+            if (grad && !grad->owner_opr()->same_type<opr::InvalidGrad>()) {
+                if (input_requires_grad[i]) {
+                    input_grads[i] = grad;
+                }
+            }
+        }
+        return input_grads;
+    };
+    grad_context.backward(outputs, output_grads, compute_input_grads);
+    auto input_grads = grad_context.get_grads(inputs);
+    VarNodeArray bgraph_inputs;
+    bgraph_inputs.insert(bgraph_inputs.end(), inputs.begin(), inputs.end());
+    bgraph_inputs.insert(bgraph_inputs.end(), outputs.begin(), outputs.end());
+    bgraph_inputs.insert(bgraph_inputs.end(), output_grads.begin(), output_grads.end());
+    auto graph = subgraph_detail::make_from_computing_graph(bgraph_inputs, input_grads);
+    return graph;
 }
 
 VarNodeArray ProxyGraph::make_input_place_holders(
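Editor's note on the hunk above: the rewritten make_backward_graph no longer encodes the backward graph by hand. It records the forward exprs into a GradContext, seeds the output grads, and lets GradContext::backward walk the recorded exprs in reverse, asking compute_input_grads for per-op input grads and summing grads that reach the same variable with the ADD accumulator. The following self-contained sketch shows only that pattern; ToyGradContext, ToyExpr and double-valued grads are illustrative assumptions, not the megbrain API.

// Toy reverse-accumulation in the style of GradContext::backward.
#include <cstdio>
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>

using var_t = int;                  // toy stand-in for VarNode*
using vars_t = std::vector<var_t>;  // toy stand-in for VarNodeArray

struct ToyExpr {
    std::string op;
    vars_t inputs, outputs;
};

struct ToyGradContext {
    std::unordered_map<var_t, double> grads;  // var -> accumulated grad
    std::vector<ToyExpr> exprs;               // recorded in forward order

    void record_expr(std::string op, vars_t in, vars_t out) {
        exprs.push_back({std::move(op), std::move(in), std::move(out)});
    }
    // the accumulator: plays the role of the Elemwise ADD lambda
    void accum(var_t dest, double g) { grads[dest] += g; }

    // functor signature mirrors the new compute_input_grads:
    // (op, inputs, outputs, output_grads) -> input_grads
    void backward(
            const vars_t& ys, const std::vector<double>& dys,
            const std::function<std::vector<double>(
                    const std::string&, const vars_t&, const vars_t&,
                    const std::vector<double>&)>& functor) {
        for (size_t i = 0; i < ys.size(); ++i) {
            accum(ys[i], dys[i]);  // seed output grads
        }
        for (auto it = exprs.rbegin(); it != exprs.rend(); ++it) {
            std::vector<double> out_grads;
            for (auto o : it->outputs) {
                out_grads.push_back(grads[o]);
            }
            auto in_grads = functor(it->op, it->inputs, it->outputs, out_grads);
            for (size_t i = 0; i < it->inputs.size(); ++i) {
                accum(it->inputs[i], in_grads[i]);
            }
        }
    }
};

int main() {
    // forward: v2 = v0 * v1; v3 = v2 * v2 (recorded as DepOprIter would visit them)
    ToyGradContext ctx;
    std::unordered_map<var_t, double> value{{0, 3.0}, {1, 4.0}};
    value[2] = value[0] * value[1];
    ctx.record_expr("mul", {0, 1}, {2});
    value[3] = value[2] * value[2];
    ctx.record_expr("mul", {2, 2}, {3});

    // per-op "grad function": d(a*b) -> (b*dy, a*dy)
    auto grad_of_mul = [&](const std::string&, const vars_t& in, const vars_t&,
                           const std::vector<double>& dy) {
        return std::vector<double>{value[in[1]] * dy[0], value[in[0]] * dy[0]};
    };
    ctx.backward({3}, {1.0}, grad_of_mul);
    std::printf("d v3 / d v0 = %g\n", ctx.grads[0]);  // 2 * v0 * v1^2 = 96
    return 0;
}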
@@ -107,13 +107,16 @@ EncodedSubgraph make_backward_graph_from_forward(
     Subgraph::Builder<LogicalTensorDesc> builder(
             [](auto&& op, auto&& input_descs, size_t nr_outputs) {
                 auto [descs, _] = OpDef::infer_output_attrs_fallible(*op, input_descs);
+                mgb_assert(
+                        descs.size() == nr_outputs, "nr_outputs mismatch for %s",
+                        op->to_string().c_str());
                 return descs;
             });
     auto accum_grad = [&](var_t lhs, var_t rhs) {
         return builder.write_expr(
                 Elemwise::make(Elemwise::Mode::ADD), {lhs, rhs}, 1)[0];
     };
-    GradContext<var_t> grad_context{accum_grad};
+    GradContext<std::shared_ptr<OpDef>, var_t> grad_context{accum_grad};
     auto input_vars = builder.write_inputs(inputs);
     auto outputs = forward_graph.apply<var_t>(
             input_vars, std::bind(&decltype(builder)::write_expr, &builder, _1, _2, _3),
@@ -143,19 +146,17 @@ EncodedSubgraph make_backward_graph_from_forward(
     grad_context.backward(
             apply_mask(outputs, output_has_grad),
             apply_mask(output_grads, output_has_grad),
-            [&](Subgraph::expr_t expr, vars_t output_grads) {
+            [&](Subgraph::op_t op, vars_t inputs, vars_t outputs, vars_t output_grads) {
                 auto bg = OpDef::make_backward_graph(
-                        *expr.op, builder.get_descs(expr.inputs),
-                        grad_context.get_require_grads(expr.inputs),
-                        grad_context.get_has_grads(expr.outputs));
+                        *op, builder.get_descs(inputs),
+                        grad_context.get_require_grads(inputs),
+                        grad_context.get_has_grads(outputs));
                 if (bg.graph.empty()) {
-                    return vars_t(expr.inputs.size(), 0);
+                    return vars_t(inputs.size(), 0);
                 }
                 vars_t grad_inputs;
-                grad_inputs.insert(
-                        grad_inputs.end(), expr.inputs.begin(), expr.inputs.end());
-                grad_inputs.insert(
-                        grad_inputs.end(), expr.outputs.begin(), expr.outputs.end());
+                grad_inputs.insert(grad_inputs.end(), inputs.begin(), inputs.end());
+                grad_inputs.insert(grad_inputs.end(), outputs.begin(), outputs.end());
                 grad_inputs.insert(
                         grad_inputs.end(), output_grads.begin(), output_grads.end());
                 auto apply_functor =
@@ -183,6 +184,77 @@ EncodedSubgraph make_backward_graph(
             forward_graph, inputs, input_requires_grad, output_has_grad);
 }
 
+EncodedSubgraph make_from_computing_graph(
+        const VarNodeArray& inputs, const VarNodeArray& outputs) {
+    Subgraph subgraph;
+    std::unordered_map<VarNode*, size_t> var2idx;
+    size_t next_idx = 0;
+    var2idx[nullptr] = next_idx++;
+    for (auto&& input : inputs) {
+        if (input) {
+            var2idx[input] = next_idx++;
+        }
+    }
+    auto is_tensor_holder = [](cg::OperatorNodeBase* op) {
+        return op->input().empty();
+    };
+    auto as_tensor = [](VarNode* var) -> TensorPtr {
+        auto* opr = var->owner_opr();
+        if (auto* imm_tensor = opr->try_cast_final<opr::ImmutableTensor>()) {
+            auto&& dv = imm_tensor->value();
+            HostTensorND hv(dv.comp_node(), dv.shape(), dv.dtype());
+            // get host value
+            auto&& cpu_value = imm_tensor->host_value();
+            mgb_assert(cpu_value.comp_node() == CompNode::default_cpu());
+            // default_cpu is synchronous with respect to caller
+            hv.proxy_to_default_cpu().copy_from_fixlayout(cpu_value);
+            return Tensor::make(dv, hv);
+        } else if (
+                auto* shared_tensor = opr->try_cast_final<opr::SharedDeviceTensor>()) {
+            return Tensor::make(shared_tensor->get_dev_tensor());
+        } else {
+            mgb_assert(
+                    false, "unsupported tensor holder opr %s",
+                    opr->dyn_typeinfo()->name);
+        }
+    };
+    cg::DepOprIter iter{[&](cg::OperatorNodeBase* op) {
+        // TODO: implement make_backward_graph for mm ops
+        // mgb_assert(!op->node_prop().contain(cg::OperatorNodeProp::Flag::IMPURE_FUNC));
+        if (is_tensor_holder(op)) {
+            for (auto&& output : op->usable_output()) {
+                subgraph.constants.push_back(
+                        {var2idx[output] = next_idx++, as_tensor(output)});
+            }
+        } else {
+            Subgraph::vars_t inputs;
+            Subgraph::vars_t outputs;
+            for (auto&& input : op->input()) {
+                inputs.push_back(var2idx.at(input));
+            }
+            // NOTE: use usable_output
+            for (auto&& output : op->usable_output()) {
+                outputs.push_back(var2idx[output] = next_idx++);
+            }
+            auto opdef = OpDef::make_from_op_node(op);
+            subgraph.exprs.push_back({opdef, inputs, outputs});
+        }
+    }};
+    for (auto&& input : inputs) {
+        if (input) {
+            iter.set_visited(input->owner_opr());
+        }
+        subgraph.inputs.push_back(var2idx.at(input));
+    }
+    for (auto&& output : outputs) {
+        if (output) {
+            iter.add(output);
+        }
+        subgraph.outputs.push_back(var2idx.at(output));
+    }
+    return EncodedSubgraph::make(subgraph);
+}
+
 }  // namespace subgraph_detail
 }  // namespace imperative
 }  // namespace mgb
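Editor's note on the hunk above: the new make_from_computing_graph flattens a VarNode graph into the Subgraph encoding by giving every variable a small integer id. Index 0 is reserved for "no var" (a nullptr entry), the declared inputs are numbered first, and constants and op outputs get fresh ids as DepOprIter first visits them, so exprs can refer to operands by index. A minimal sketch of only that indexing convention follows; Node is a toy stand-in, not the real VarNode.

// Toy illustration of the var2idx convention used above.
#include <cstdio>
#include <unordered_map>
#include <vector>

struct Node {};  // toy stand-in for VarNode

int main() {
    Node na, nb, nc;
    Node *a = &na, *b = &nb, *c = &nc;

    std::unordered_map<Node*, size_t> var2idx;
    size_t next_idx = 0;
    var2idx[nullptr] = next_idx++;  // index 0 means "no var"
    std::vector<Node*> inputs{a, nullptr, b};
    for (Node* input : inputs) {    // a null input keeps its slot, mapped to 0
        if (input) {
            var2idx[input] = next_idx++;
        }
    }
    // subgraph.inputs would record {1, 0, 2} for the list above
    std::vector<size_t> encoded_inputs;
    for (Node* input : inputs) {
        encoded_inputs.push_back(var2idx.at(input));
    }
    // an op output discovered during traversal gets the next free index
    size_t c_idx = var2idx[c] = next_idx++;

    std::printf("inputs -> %zu %zu %zu, c -> %zu\n",
                encoded_inputs[0], encoded_inputs[1], encoded_inputs[2], c_idx);
    // prints: inputs -> 1 0 2, c -> 3
    return 0;
}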
@@ -189,12 +189,17 @@ struct EncodedSubgraph {
     size_t hash() const;
 };
 
-template <typename T>
+template <typename TOp, typename TVar>
 class GradContext {
 public:
-    using var_t = T;
+    using op_t = TOp;
+    using var_t = TVar;
     using vars_t = SmallVector<var_t>;
-    using expr_t = Expr<T>;
+    struct expr_t {
+        op_t op;
+        vars_t inputs;
+        vars_t outputs;
+    };
 
 private:
     std::unordered_map<var_t, var_t> m_grads;
@@ -219,6 +224,7 @@ public:
         }
         return mask;
     }
+    void mark_require_grad(var_t dest) { m_vars_require_grad.insert(dest); }
     void mark_require_grads(vars_t dests) {
         for (auto&& dest : dests) {
             m_vars_require_grad.insert(dest);
@@ -231,7 +237,7 @@ public:
             return m_grads[dest] = m_accumulator(m_grads[dest], grad);
         }
     }
-    void record_expr(std::shared_ptr<OpDef> op, vars_t inputs, vars_t outputs) {
+    void record_expr(op_t op, vars_t inputs, vars_t outputs) {
         bool require_grad = false;
         for (auto&& input : inputs) {
             if (m_vars_require_grad.count(input)) {
@@ -254,7 +260,8 @@ public:
         std::reverse(exprs.begin(), exprs.end());
         for (const expr_t& expr : exprs) {
             size_t nr_inputs = expr.inputs.size();
-            vars_t input_grads = functor(expr, get_grads(expr.outputs));
+            vars_t input_grads = functor(
+                    expr.op, expr.inputs, expr.outputs, get_grads(expr.outputs));
             mgb_assert(input_grads.size() == nr_inputs, "input size mismatch");
             for (size_t i = 0; i < nr_inputs; ++i) {
                 if (input_grads[i] && m_vars_require_grad.count(expr.inputs[i])) {
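Editor's note on the hunk above: GradContext now takes the op type as a second template parameter because its two callers record different op handles: the proxy-graph path records raw operator-node pointers, while the subgraph builder records shared OpDef handles. A compile-only sketch of that difference follows; GradContextLike, FakeOpNode and FakeOpDef are illustrative stand-ins, not the real classes, and the var types are simplified.

// Toy illustration of the two instantiations enabled by <TOp, TVar>.
#include <memory>
#include <vector>

template <typename TOp, typename TVar>
class GradContextLike {
public:
    using op_t = TOp;
    using var_t = TVar;
    using vars_t = std::vector<var_t>;
    struct expr_t {
        op_t op;
        vars_t inputs;
        vars_t outputs;
    };
    void record_expr(op_t op, vars_t inputs, vars_t outputs) {
        m_exprs.push_back({std::move(op), std::move(inputs), std::move(outputs)});
    }

private:
    std::vector<expr_t> m_exprs;
};

struct FakeOpNode {};  // stands in for cg::OperatorNodeBase
struct FakeOpDef {};   // stands in for OpDef

int main() {
    // proxy_graph.cpp style: ops and vars are node pointers
    GradContextLike<FakeOpNode*, FakeOpNode*> by_node;
    FakeOpNode node;
    by_node.record_expr(&node, {&node}, {&node});

    // subgraph_detail.cpp style: ops are shared OpDef handles, vars are plain ids here
    GradContextLike<std::shared_ptr<FakeOpDef>, size_t> by_def;
    by_def.record_expr(std::make_shared<FakeOpDef>(), {1, 2}, {3});
    return 0;
}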
@@ -43,6 +43,8 @@ EncodedSubgraph make_backward_graph_from_forward(
         const EncodedSubgraph& forward, const SmallVector<LogicalTensorDesc>& inputs,
         const SmallVector<bool>& input_requires_grad,
         const SmallVector<bool>& output_has_grad);
+EncodedSubgraph make_from_computing_graph(
+        const VarNodeArray& inputs, const VarNodeArray& outputs);
 
 }  // namespace subgraph_detail
 }  // namespace imperative