#include "./proxy_graph.h"
#include "./blob_manager_impl.h"

#include "megbrain/graph.h"
#include "megbrain/graph/operator_node.h"
#include "megbrain/graph/static_infer.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/subgraph_detail.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"

EncodedSubgraph ProxyGraph::make_backward_graph(
        const OpDef& opdef, const SmallVector<LogicalTensorDesc>& input_descs,
        const SmallVector<bool>& input_requires_grad,
        const SmallVector<bool>& output_has_grad) {
    using op_t = OperatorNodeBase*;
    using var_t = VarNode*;
    using vars_t = VarNodeArray;
    // Trace the forward computation on placeholder inputs.
    auto inputs = make_input_place_holders(input_descs);
    auto outputs = OpDef::apply_on_var_node(opdef, inputs);
    // Output descs carry only dtype and comp_node: shapes are not yet known
    // while the backward graph is being traced.
    SmallVector<LogicalTensorDesc> output_descs;
    for (auto&& i : outputs) {
        output_descs.push_back({TensorLayout{i->dtype()}, i->comp_node()});
    }
    // Gradients that flow into the same var are accumulated with an
    // elementwise ADD.
    GradContext<op_t, var_t> grad_context{[&](VarNode* lhs, VarNode* rhs) -> VarNode* {
        auto add = opr::Elemwise::Mode::ADD;
        return opr::Elemwise::make(VarNodeArray{lhs, rhs}, add).node();
    }};
    // Record every operator reachable from the outputs so the context can
    // later replay the expressions in reverse order.
    cg::DepOprIter iter{[&](OperatorNodeBase* op) {
        grad_context.record_expr(op, op->input(), op->output());
    }};
    for (size_t i = 0; i < inputs.size(); ++i) {
        auto& input = inputs[i];
        iter.set_visited(input->owner_opr());
        if (input_requires_grad[i]) {
            grad_context.mark_require_grad(input);
        }
    }
    for (auto&& output : outputs) {
        iter.add(output);
    }
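    // Input placeholders are pre-marked as visited, so the traversal above
    // records exactly the operators created by this op's forward expansion.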
    // Placeholders for the incoming output gradients; outputs whose gradient
    // is not provided get a null entry.
    auto output_grads = make_input_place_holders(output_descs);
    mgb_assert(
            output_grads.size() == output_has_grad.size(), "%zu vs %zu",
            output_grads.size(), output_has_grad.size());
    bool any_output_has_grad = false;
    for (size_t i = 0; i < output_grads.size(); ++i) {
        if (!output_has_grad[i]) {
            output_grads[i] = nullptr;
        } else {
            any_output_has_grad = true;
        }
    }
    // Without any incoming gradient there is no backward graph to build.
    if (!any_output_has_grad) {
        return {};
    }
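    // For example, input_requires_grad = {true, false} together with
    // output_has_grad = {true} requests only the gradient w.r.t. the first
    // input, and a false entry in output_has_grad nulls the corresponding
    // placeholder so nothing flows back through that output.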
    // Per-operator backward rule: look up the registered grad func and
    // compute the gradient for each input, reusing the batched result when
    // the grad func yields all input gradients at once.
    auto compute_input_grads = [&](op_t op, vars_t inputs, vars_t outputs,
                                   vars_t output_grads) {
        auto* gfunc = cg::lookup_grad_func(op->dyn_typeinfo());
        vars_t input_grads(inputs.size(), nullptr);
        bool any_grad = false;
        for (auto&& output_grad : output_grads) {
            if (output_grad) {
                any_grad = true;
            }
        }
        if (!gfunc || !any_grad) {
            return input_grads;
        }
        Maybe<VarNodeArray> grad_results;
        auto&& input_requires_grad = grad_context.get_require_grads(inputs);
        for (size_t i = 0; i < inputs.size(); ++i) {
            VarNode* grad;
            if (grad_results.valid()) {
                grad = grad_results.val()[i];
            } else {
                mgb_assert(gfunc, "could not find grad function");
                auto res = (*gfunc)(op, i, output_grads);
                if (res.from_single()) {
                    grad = res.single();
                } else {
                    grad_results.emplace(res.all(op));
                    grad = grad_results.val()[i];
                }
            }
            if (grad && !grad->owner_opr()->same_type<opr::InvalidGrad>()) {
                if (input_requires_grad[i]) {
                    input_grads[i] = grad;
                }
            }
        }
        return input_grads;
    };
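    // Note: a grad func may return opr::InvalidGrad for an input whose
    // gradient is either not well defined or not implemented; such entries
    // are dropped rather than recorded as input gradients.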
    // Replay the recorded expressions in reverse, accumulating gradients for
    // every var that was marked as requiring grad.
    grad_context.backward(outputs, output_grads, compute_input_grads);
    auto input_grads = grad_context.get_grads(inputs);
    // The encoded backward graph consumes the concatenation
    // [inputs..., outputs..., output_grads...] and produces the input grads.
    VarNodeArray bgraph_inputs;
    bgraph_inputs.insert(bgraph_inputs.end(), inputs.begin(), inputs.end());
    bgraph_inputs.insert(bgraph_inputs.end(), outputs.begin(), outputs.end());
    bgraph_inputs.insert(bgraph_inputs.end(), output_grads.begin(), output_grads.end());
    auto graph = subgraph_detail::make_from_computing_graph(bgraph_inputs, input_grads);
    return graph;
}
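
// A minimal usage sketch (illustrative only; `proxy`, `op` and `descs` are
// hypothetical caller-side values, not part of this file):
//
//     auto bg = proxy->make_backward_graph(
//             op, descs,
//             /* input_requires_grad = */ {true, false},
//             /* output_has_grad = */ {true});
//
// The returned EncodedSubgraph takes [inputs..., outputs..., output_grads...]
// as its input sequence and yields one gradient per forward input that
// requires grad; an empty result means no gradient was requested or none is
// defined.
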
VarNodeArray ProxyGraph::make_input_place_holders(