@@ -18,7 +18,6 @@ import numpy as np
 from .. import _imperative_rt
 from .._imperative_rt import GraphOptimizeOptions
 from .._imperative_rt.core2 import apply, set_cpp_apply_backward_varnode
-from .._imperative_rt.ops import BackwardGraph
 from .._wrap import device as as_device
 from ..ops.builtin import OpDef
 from .core import TensorBase
@@ -481,21 +480,6 @@ def apply_normal_varnode(op: OpDef, *args: VarNode):
     return _wrap(outputs)


-def apply_backward_varnode(op: BackwardGraph, *args: VarNode):
-    assert args
-    graph = args[0].graph
-    outputs = op.interpret(
-        op,
-        lambda op, args: apply_normal_varnode(op, *args),
-        graph._make_const_for_backward,
-        args,
-    )
-    return outputs
-
-
-set_cpp_apply_backward_varnode(apply_backward_varnode)
-
-
 def input_callback(callback, *args, device=None, dtype=None, shape=None, graph=None):
     outputs = _imperative_rt.input_callback(
         callback, as_device(device).to_c(), dtype, shape, _unwrap(args), graph=graph
@@ -32,7 +32,7 @@ from ..core._imperative_rt.ops import (
 )
 from ..core._trace_option import set_symbolic_shape
 from ..core._wrap import device as as_device
-from ..core.ops.builtin import BackwardGraph, BatchNorm, OpDef
+from ..core.ops.builtin import BatchNorm, OpDef
 from ..core.ops.special import Const
 from ..core.tensor import megbrain_graph as G
 from ..core.tensor.utils import setscalar
@@ -587,10 +587,7 @@ class trace:
                 ivars.append(info.varnode)

-            if isinstance(op, BackwardGraph):
-                ovars = G.apply_backward_varnode(op, *ivars)
-            else:
-                ovars = G.apply_normal_varnode(op, *ivars)
+            ovars = G.apply_normal_varnode(op, *ivars)

             if require_links and len(ovars) > 0:
                 io_links = (ovars[0],)
@@ -805,14 +802,11 @@ class trace:
                         name=info.name,
                     )
                 ivars.append(h2v[h])
-            if isinstance(op, BackwardGraph):
-                ovars = G.apply_backward_varnode(op, *ivars)
-            else:
-                if isinstance(op, BatchNorm):
-                    assert (
-                        op.fwd_mode == BatchNorm.FwdMode.INFERENCE
-                    ), "can not dump BatchNorm in training mode, maybe you forget to do model.eval()?"
-                ovars = G.apply_normal_varnode(op, *ivars)
+            if isinstance(op, BatchNorm):
+                assert (
+                    op.fwd_mode == BatchNorm.FwdMode.INFERENCE
+                ), "can not dump BatchNorm in training mode, maybe you forget to do model.eval()?"
+            ovars = G.apply_normal_varnode(op, *ivars)

             AutoNaming.record_opnode(ovars[0].op)
@@ -1088,10 +1082,7 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor):
             ivars[0] = opnode.outputs[0]
             active_trace._lazy_eval_links = (ivars[0],)

-    if isinstance(op, BackwardGraph):
-        ovars = G.apply_backward_varnode(op, *ivars)
-    else:
-        ovars = G.apply_normal_varnode(op, *ivars)
+    ovars = G.apply_normal_varnode(op, *ivars)

     outputs = [RawTensor(o) for o in ovars]

     if require_links:
@@ -75,9 +75,9 @@ std::shared_ptr<OptimizedBackwardGraphResult> make_backward_graph(
         input_requires_grad[i] = python::input_requires_grad(ctx, i);
     }
     std::shared_ptr<OptimizedBackwardGraphResult> ret;
-    auto bg = proxy_graph_detail::make_backward_graph(
+    auto bg = OpDef::make_backward_graph(
             *ctx.op, inputs, input_requires_grad, output_has_grad);
-    if (bg.backward) {
+    if (!bg.backward.empty()) {
         ret = std::make_shared<OptimizedBackwardGraphResult>(bg);
     }
     backward_graph_cache.emplace(key, ret);
@@ -112,7 +112,7 @@ struct BackwardGraphWithClosure {
         size_t count = std::count_if(save_for_backward.begin(),
                                      save_for_backward.end(),
                                      ranges::identity{});
-        if (backward_graph->precomp) {
+        if (!backward_graph->precomp.empty()) {
            auto&& irng = ranges::span(ctx.args, ctx.nargs);
            auto&& orng = views::transform(outputs, [](auto&& i){return i.get();});
            auto precomp = apply(backward_graph->precomp, views::concat(irng, orng));
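Note: `OptimizedBackwardGraphResult::precomp` and `backward` are now `Subgraph` values rather than `std::shared_ptr<OpDef>`, so presence can no longer be tested by pointer truthiness; these hunks switch to `empty()`. A minimal sketch of the convention, assuming the `Subgraph` type introduced in the op_def.h hunk further down:

    Subgraph g;    // default-constructed: no outputs, so g.empty() == true
    if (!g.empty()) {
        // ... interpret g via Subgraph::apply ...
    }
    g = {};        // clears the graph, replacing the old shared_ptr::reset()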
@@ -30,26 +30,14 @@ using namespace imperative;
 using namespace interpreter;

-namespace {
-
-std::optional<std::tuple<std::shared_ptr<OpDef>, std::vector<bool>, std::vector<bool>>>
-make_backward_graph(
-        const OpDef& opdef, std::vector<LogicalTensorDesc> inputs,
-        std::vector<bool> input_requires_grad,
-        std::vector<bool> output_has_grad) {
-    auto res = OpDef::make_backward_graph(opdef,
-            SmallVector<LogicalTensorDesc>(inputs.begin(), inputs.end()),
-            SmallVector<bool>(input_requires_grad.begin(), input_requires_grad.end()),
-            SmallVector<bool>(output_has_grad.begin(), output_has_grad.end()));
-    if (res.backward) {
-        return std::optional<std::tuple<std::shared_ptr<OpDef>, std::vector<bool>, std::vector<bool>>>{
-                std::in_place, res.backward, res.save_for_backward, res.input_has_grad};
-    } else {
-        return {};
-    }
-}
-
-} // namespace
-
 void init_imperative_rt(py::module m) {
-    m.def("make_backward_graph", &make_backward_graph);
+    auto make_backward_graph = [](
+            const OpDef& def,
+            const SmallVector<LogicalTensorDesc>& inputs,
+            const SmallVector<bool>& input_requires_grad,
+            const SmallVector<bool>& output_has_grad) {
+        auto result = OpDef::make_backward_graph(def, inputs, input_requires_grad, output_has_grad);
+        return std::make_tuple("backward_graph", result.save_for_backward, result.input_has_grad);
+    };
+    m.def("make_backward_graph", make_backward_graph);
 }
@@ -367,42 +367,6 @@ void _init_py_op_def(py::module m) {
 }

 /*********** begin of hand-write opdefs **************/
-
-PyOpDefBegin(BackwardGraph) // {{
-// };
-PyOpDefEnd(BackwardGraph)
-
-void _init_py_backward_graph(py::module m) {
-    using py_op = PyOp(BackwardGraph);
-    auto& py_type = PyOpType(BackwardGraph);
-    py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
-    py_type.tp_name = "megengine.core._imperative_rt.ops.BackwardGraph";
-    py_type.tp_basicsize = sizeof(PyOp(BackwardGraph));
-    py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
-    py_type.tp_doc = "BackwardGraph";
-    py_type.tp_base = &PyOpType(OpDef);
-    py_type.tp_dealloc = py_dealloc_generic<py_op>;
-    py_type.tp_new = py_new_generic<py_op>;
-    mgb_assert(PyType_Ready(&py_type) >= 0);
-    // FIXME: rewrite interpret function in cpython instead wrap directly by pybind11::cppfunction
-    auto interpret = py::cpp_function(
-        [](OpDef& self, py::object pyf, py::object pyc,
-           const mgb::SmallVector<py::object>& inputs) {
-            auto f = [pyf](OpDef& op, const mgb::SmallVector<py::object>& inputs) {
-                return py::cast<mgb::SmallVector<py::object>>(pyf(op.shared_from_this(), inputs));
-            };
-            auto c = [pyc](const TensorPtr& tensor) {
-                return pyc(tensor->dev_tensor());
-            };
-            return self.cast_final_safe<BackwardGraph>().graph().interpret<py::object>(f, c, inputs);
-        });
-    mgb_assert(PyDict_SetItemString(
-        py_type.tp_dict, "interpret", interpret.release().ptr()) >= 0);
-    PyType_Modified(&py_type);
-    m.add_object("BackwardGraph", reinterpret_cast<PyObject*>(&py_type));
-    mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(BackwardGraph::typeinfo(), &py_type).second);
-}
-
 struct PyOpBase : PyOpDef {
     static PyTypeObject py_type;
@@ -496,7 +460,6 @@ FOR_EACH_BIT_COMBINED_ENUM_PARAM(BIT_COMBINED_ENUM_CASTER_IMPL)

 void init_ops(py::module m) {
     _init_py_op_def(m);
-    _init_py_backward_graph(m);
     _init_py_op_base(m);

     INIT_ALL_OP(m)
@@ -156,9 +156,6 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje
     ctx.args = &tensors[0];
     ctx.nargs = nargs;
     ctx.pytype = pytype;
-    if (ctx.op->same_type<BackwardGraph>()) {
-        ctx.backward = true;
-    }

     if (py::isinstance<PySymbolVar>(py::handle(args[0]))){
         SmallVector<cg::VarNode*> vinputs(nargs);
@@ -248,31 +248,53 @@ apply_result_t apply(std::shared_ptr<OpDef> op, Args&&... args) {
     return apply(ctx);
 }

-template <typename T>
-auto apply(std::shared_ptr<OpDef> op, T&& tensors)
-        -> std::enable_if_t<std::is_same_v<decltype(resolve_arrow(tensors[0])), Tensor*>,
-                            apply_result_t> {
+inline auto apply(std::shared_ptr<OpDef> op, Tensor*const* args, size_t nargs) {
     ApplyContext ctx;
     ctx.op = std::move(op);
-    ctx.nargs = tensors.size();
-    Tensor* args[ctx.nargs];
+    ctx.nargs = nargs;
     ctx.args = args;
-    for (size_t i = 0; i < ctx.nargs; ++i) {
-        args[i] = resolve_arrow(tensors[i]);
+    for (size_t i = 0; i < nargs; ++i) {
         ctx.flags |= args[i]->m_flags;
     }
     return apply(ctx);
 }

-inline auto apply(std::shared_ptr<OpDef> op, Tensor*const* args, size_t nargs) {
-    ApplyContext ctx;
-    ctx.op = std::move(op);
-    ctx.nargs = nargs;
-    ctx.args = args;
+template <typename T>
+auto apply(std::shared_ptr<OpDef> op, T&& tensors)
+        -> std::enable_if_t<std::is_same_v<decltype(resolve_arrow(tensors[0])), Tensor*>,
+                            apply_result_t> {
+    size_t nargs = tensors.size();
+    Tensor* args[nargs];
     for (size_t i = 0; i < nargs; ++i) {
-        ctx.flags |= args[i]->m_flags;
+        args[i] = resolve_arrow(tensors[i]);
     }
-    return apply(ctx);
+    return apply(op, args, nargs);
+}
+
+inline auto apply(Subgraph graph, Tensor*const* args, size_t nargs) {
+    SmallVector<std::shared_ptr<Tensor>> inputs;
+    for (size_t i = 0; i < nargs; ++i) {
+        inputs.push_back(args[i]->shared_from_this());
+    }
+    auto apply_functor = [](std::shared_ptr<OpDef> op, SmallVector<std::shared_ptr<Tensor>> inputs) {
+        return apply(op, inputs);
+    };
+    auto const_functor = [](imperative::TensorPtr value) {
+        return std::make_shared<Tensor>(interpreter_for_py->put(value->dev_tensor()));
+    };
+    return graph.apply(inputs, apply_functor, const_functor);
+}
+
+template <typename T>
+auto apply(Subgraph graph, T&& tensors)
+        -> std::enable_if_t<std::is_same_v<decltype(tensors[0]), Tensor*>,
+                            apply_result_t> {
+    size_t nargs = tensors.size();
+    Tensor* args[nargs];
+    for (size_t i = 0; i < nargs; ++i) {
+        args[i] = resolve_arrow(tensors[i]);
+    }
+    return apply(graph, args, nargs);
 }

 void init_tensor(pybind11::module);
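The two `Subgraph` overloads replace the deleted Python round-trip (`apply_backward_varnode` / `BackwardGraph.interpret`): the graph is now walked in C++ by passing `Subgraph::apply` an apply-functor and a const-functor. The pattern generalizes to any value type; a sketch for a hypothetical symbolic value type `V` (`apply_on_var_node` and `make_const_var` are illustrative names, not APIs from this patch):

    template <typename V>
    SmallVector<V> interpret_symbolic(const Subgraph& graph, SmallVector<V> inputs) {
        auto apply_functor = [](std::shared_ptr<OpDef> op, SmallVector<V> ins) {
            return apply_on_var_node(op, std::move(ins));  // hypothetical symbolic apply
        };
        auto const_functor = [](TensorPtr value) {
            return make_const_var(value);                  // hypothetical const lifting
        };
        return graph.apply(inputs, apply_functor, const_functor);
    }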
@@ -22,7 +22,7 @@ apply_result_t apply_trace(ApplyContext& ctx) {
     apply_result_t outputs;

     if (ctx.backward) {
-        // call megbrain_graph.py apply(BackwardGraph, *args)
+        // reach here when compiled=True
         auto args = py::tuple(ctx.nargs + 1);
         args[0] = py::cast(ctx.op);
         for (size_t i = 0; i < ctx.nargs; i++) {
@@ -18,24 +18,22 @@ using namespace imperative;

 OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const BackwardGraphResult& src)
         : input_has_grad(src.input_has_grad) {
-    if (!src.backward->same_type<BackwardGraph>()) {
+    if (src.backward.exprs.size() <= 1) {
         // backward graph only contains a single op
         backward = src.backward;
         save_for_backward = src.save_for_backward;
         return;
     }
     save_for_backward.resize(src.save_for_backward.size(), false);

-    precomp.reset(new BackwardGraph);
-    backward.reset(new BackwardGraph);
-    auto&& graph = src.backward->cast_final_safe<BackwardGraph>().graph();
+    auto&& graph = src.backward;
     auto&& mask = src.save_for_backward;
     size_t input_size = src.input_has_grad.size();
     size_t output_size = (mask.size() - input_size) / 2;
     mgb_assert(input_size + output_size * 2 == mask.size());

-    auto& fgraph = precomp->cast_final<BackwardGraph>().graph();
-    auto& bgraph = backward->cast_final<BackwardGraph>().graph();
+    auto& fgraph = precomp;
+    auto& bgraph = backward;

     // optimization: move ops (e.g. GetVarShape) to forward to
     // reduce memory footprint
@@ -113,6 +111,6 @@ OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const BackwardGraphResult& src)
     }

     if (!fgraph.outputs.size()) {
-        precomp.reset();
+        precomp = {};
     }
 }
@@ -911,8 +911,7 @@ auto ChannelImpl::CommandBuffer::flush_pos_for(const Command& cmd) -> Handle {
                 op_type == RemoteSend::typeinfo() ||
                 op_type == CollectiveComm::typeinfo() ||
                 op_type == opr::InputCallback::typeinfo() ||
-                op_type == opr::OutputCallback::typeinfo() ||
-                op_type == BackwardGraph::typeinfo()) {
+                op_type == opr::OutputCallback::typeinfo()) {
                 return m_commands.end();
             }
         } else if constexpr (std::is_same_v<T, GetValue>) {
@@ -10,6 +10,9 @@
  */

 #include "megbrain/imperative/op_def.h"
+
+#include <sstream>
+
 #include "megbrain/imperative/ops/opr_attr.h"

 #include "./op_trait.h"
@@ -117,6 +120,67 @@ const std::string OpDef::make_name() const {
     return m_scope + "." + trait()->make_name(*this);
 }

+std::string Subgraph::repr() const {
+    std::ostringstream buf;
+    buf << "(";
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        if (i > 0) buf << ", ";
+        buf << "%" << inputs[i];
+    }
+    buf << ") => {\n";
+    auto fmt_const = [](size_t i, const TensorPtr& t) {
+        if (t->shape().ndim == 1 && t->shape()[0] == 1) {
+            auto&& v = t->get_value();
+            if (v.dtype() == dtype::Float32{}) {
+                return std::to_string(*v.ptr<dt_float32>());
+            } else if (v.dtype() == dtype::Int32{}) {
+                return std::to_string(*v.ptr<int32_t>());
+            }
+        }
+        return std::string("%c") + std::to_string(i);
+    };
+    std::unordered_map<size_t, std::string> const_reps;
+    for (auto&& [i, t] : constants) {
+        const_reps.emplace(i, fmt_const(i, t));
+    }
+    for (auto& [op, ins, outs] : exprs) {
+        buf << " ";
+        if (outs.size()) {
+            for (size_t i = 0; i < outs.size(); ++i) {
+                if (i > 0) buf << ", ";
+                buf << "%" << outs[i];
+            }
+            buf << " = ";
+        }
+        if (auto* p = op->try_cast_final<OprAttr>()) {
+            buf << p->type;
+        } else {
+            buf << op->dyn_typeinfo()->name;
+        }
+        for (size_t i : ins) {
+            buf << " ";
+            auto&& it = const_reps.find(i);
+            if (it != const_reps.end()) {
+                buf << it->second;
+            } else {
+                buf << "%" << i;
+            }
+        }
+        buf << "\n";
+    }
+    buf << " ";
+    if (outputs.size()) {
+        for (size_t i = 0; i < outputs.size(); ++i) {
+            if (i > 0) buf << ", ";
+            buf << "%" << outputs[i];
+        }
+    } else {
+        buf << "()";
+    }
+    buf << "\n}\n";
+    return buf.str();
+}
+
 } // namespace imperative
 } // namespace mgb
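`Subgraph::repr` (carried over from `BackwardGraph::InternalGraph::repr`) prints the graph as a small textual IR: `%i` names value slots, and constants render as `%ci`, or as their scalar value for 1-element Float32/Int32 tensors. For a graph computing `y = x * 2` with `x` in slot 0, the scalar constant in slot 1 and the product in slot 2, the output would look roughly like this ("Elemwise" only stands in for whatever `OprAttr::type` or the typeinfo name yields):

    (%0) => {
     %2 = Elemwise %0 2.000000
     %2
    }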
@@ -19,147 +19,6 @@
 namespace mgb {
 namespace imperative {

-SmallVector<TensorPtr>
-BackwardGraph::InternalGraph::apply(
-        const SmallVector<TensorPtr>& inputs) const {
-    return interpret<TensorPtr>(
-        &OpDef::apply_on_physical_tensor,
-        [](const TensorPtr& x) {return x;},
-        inputs);
-}
-
-std::tuple<SmallVector<LogicalTensorDesc>, bool> BackwardGraph::InternalGraph::infer_attrs(
-        const SmallVector<LogicalTensorDesc>& inputs) const {
-    using TensorAttr = LogicalTensorDesc;
-    ThinHashMap<size_t, TensorAttr> node2attr;
-    auto&& input_nodes = this->inputs;
-    auto&& output_nodes = this->outputs;
-    mgb_assert(inputs.size() == input_nodes.size());
-    for (size_t i = 0; i < inputs.size(); ++ i) {
-        node2attr[input_nodes[i]] = inputs[i];
-    }
-    for (auto &&i : constants) {
-        auto* value = i.second->try_get_value();
-        mgb_assert(value);
-        node2attr[i.first] = TensorAttr{
-            i.second->layout(), i.second->comp_node(),
-            value->proxy_to_default_cpu()};
-    }
-    bool validated = true;
-    for (size_t i = 0; i < exprs.size(); ++ i) {
-        auto&& [expr_op, expr_inps, expr_oups] = exprs[i];
-        SmallVector<TensorAttr> expr_input_descs;
-        for (auto &&inp : expr_inps) {
-            expr_input_descs.push_back(node2attr.at(inp));
-        }
-        auto [expr_output_descs, expr_validated] = OpDef::infer_output_attrs_fallible(
-            *expr_op, expr_input_descs);
-        validated = validated && expr_validated;
-        mgb_assert(expr_output_descs.size() == expr_oups.size());
-        for (size_t i = 0; i < expr_output_descs.size(); ++ i) {
-            node2attr[expr_oups[i]] = expr_output_descs[i];
-        }
-    }
-    SmallVector<TensorAttr> ret;
-    for (auto &&i : output_nodes) {
-        ret.push_back(node2attr.at(i));
-    }
-    return {ret, validated};
-}
-
-std::string BackwardGraph::InternalGraph::repr() {
-    std::ostringstream buf;
-    buf << "(";
-    for (size_t i = 0; i < inputs.size(); ++i) {
-        if (i > 0) buf << ", ";
-        buf << "%" << inputs[i];
-    }
-    buf << ") => {\n";
-    auto fmt_const = [](size_t i, TensorPtr& t) {
-        if (t->shape().ndim == 1 && t->shape()[0] == 1) {
-            auto&& v = t->get_value();
-            if (v.dtype() == dtype::Float32{}) {
-                return std::to_string(*v.ptr<dt_float32>());
-            } else if (v.dtype() == dtype::Int32{}) {
-                return std::to_string(*v.ptr<int32_t>());
-            }
-        }
-        return std::string("%c") + std::to_string(i);
-    };
-    std::unordered_map<size_t, std::string> const_reps;
-    for (auto&& [i, t] : constants) {
-        const_reps.emplace(i, fmt_const(i, t));
-    }
-    for (auto& [op, ins, outs] : exprs) {
-        buf << " ";
-        if (outs.size()) {
-            for (size_t i = 0; i < outs.size(); ++i) {
-                if (i > 0) buf << ", ";
-                buf << "%" << outs[i];
-            }
-            buf << " = ";
-        }
-        if (auto* p = op->try_cast_final<OprAttr>()) {
-            buf << p->type;
-        } else {
-            buf << op->dyn_typeinfo()->name;
-        }
-        for (size_t i : ins) {
-            buf << " ";
-            auto&& it = const_reps.find(i);
-            if (it != const_reps.end()) {
-                buf << it->second;
-            } else {
-                buf << "%" << i;
-            }
-        }
-        buf << "\n";
-    }
-    buf << " ";
-    if (outputs.size()) {
-        for (size_t i = 0; i < outputs.size(); ++i) {
-            if (i > 0) buf << ", ";
-            buf << "%" << outputs[i];
-        }
-    } else {
-        buf << "()";
-    }
-    buf << "\n}\n";
-    return buf.str();
-}
-
-MGB_DYN_TYPE_OBJ_FINAL_IMPL(BackwardGraph);
-
-namespace {
-SmallVector<TensorPtr> backward_impl(
-        const OpDef& backward_graph,
-        const SmallVector<TensorPtr>& tensors) {
-    return backward_graph.cast_final_safe<BackwardGraph>()
-            .graph().apply(tensors);
-}
-
-std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_tensor_attrs(
-        const OpDef& backward_graph,
-        const SmallVector<LogicalTensorDesc> inputs) {
-    return backward_graph.cast_final_safe<BackwardGraph>()
-            .graph().infer_attrs(inputs);
-}
-
-std::vector<std::pair<const char*, std::string>> props(
-        const OpDef& backward_graph) {
-    return {};
-}
-
-OP_TRAIT_REG(BackwardGraph, BackwardGraph)
-    .apply_on_physical_tensor(backward_impl)
-    .infer_output_attrs_fallible(infer_tensor_attrs)
-    .props(props)
-    .fallback();
-} // anonymous namespace
-
 } // namespace imperative
 } // namespace mgb
@@ -669,8 +669,7 @@ ProxyGraph::make_backward_graph(
     auto* gfunc = cg::lookup_grad_func(fwd->dyn_typeinfo());

     BackwardGraphResult result;
-    auto&& backward = BackwardGraph::make();
-    auto&& igraph = backward->cast_final_safe<BackwardGraph>().graph();
+    auto&& igraph = result.backward;
     size_t nr_backward_graph_inputs = 0;
     auto gen_expr = [this, &var2idx, &igraph, &push, &fwd,
@@ -682,7 +681,7 @@ ProxyGraph::make_backward_graph(
             ++ nr_backward_graph_inputs;
             push(op->output(0));
         } else {
-            std::vector<size_t> inputs, outputs;
+            SmallVector<size_t> inputs, outputs;
             for (auto &&i : op->input()) {
                 if (i->owner_opr() == fwd) {
                     if (var2idx.find(i) == var2idx.end()) {
@@ -695,7 +694,7 @@ ProxyGraph::make_backward_graph(
             for (auto &&i : op->usable_output()) {
                 outputs.push_back(push(i));
             }
-            igraph.exprs.emplace_back(OpDef::make_from_op_node(op), inputs, outputs);
+            igraph.exprs.push_back({OpDef::make_from_op_node(op), inputs, outputs});
         }
     };
@@ -770,36 +769,6 @@ ProxyGraph::make_backward_graph(
     write_inputs(outputs);
     write_inputs(output_grads);
     mgb_assert(igraph.inputs.size() == nr_backward_graph_inputs);
-
-    auto treat_as_single = [](auto&& igraph) {
-        if (igraph.exprs.size() != 1)
-            return false;
-        auto&& expr = igraph.exprs[0];
-        auto&& expr_inputs = std::get<1>(expr);
-        if (expr_inputs.size() != igraph.inputs.size()) {
-            return false;
-        }
-        for (size_t i = 0; i < expr_inputs.size(); ++ i) {
-            if (igraph.inputs[i] != expr_inputs[i]) {
-                return false;
-            }
-        }
-        auto&& expr_outputs = std::get<2>(expr);
-        if (expr_outputs.size() != igraph.outputs.size()) {
-            return false;
-        }
-        for (size_t i = 0; i < expr_outputs.size(); ++ i) {
-            if (igraph.outputs[i] != expr_outputs[i]) {
-                return false;
-            }
-        }
-        return true;
-    };
-    if (treat_as_single(igraph)) {
-        result.backward = std::get<0>(igraph.exprs[0]);
-    } else {
-        result.backward = backward;
-    }
     return result;
 }
@@ -65,7 +65,7 @@ private:
     class InputPlaceholder;
     struct ProxyGraphInst;
    struct GradGraph;
-    struct CurOprGuard;
+    class CurOprGuard;

     void reset();
@@ -15,7 +15,7 @@ namespace mgb::imperative::proxy_graph {
 // e.g. friend class mgb::imperative::proxy_graph::ProxyGraph
 struct ProxyGraph {
     struct InputPlaceholder;
-    struct MiniGraph;
+    class MiniGraph;
 };

 } // namespace mgb::imperative::proxy_graph
@@ -75,30 +75,7 @@ apply_on_physical_tensor(const OpDef& def,
     auto output_descs = infer_output_attrs(def, inputs);
     SmallVector<TensorPtr> outputs(output_descs.size(), {});
     for (size_t i = 0; i < outputs.size(); i++) {
-        auto& output = outputs[i];
-        auto& output_desc = output_descs[i];
-        if (def.same_type<Elemwise>()) {
-            for (size_t j = 0; j < inputs.size(); j++) {
-                // TODO: reindex inputs to support inplace exprs like 'y = x op x'.
-                auto& input = inputs[j];
-                // Because we pass inputs by value, if input and input->blob() are all unique,
-                // their ownerships are on the stack, thus we can reuse them safely.
-                // @see: interpreter::intl::ChannelImpl::process_one_task
-                if (input.unique() && input->blob().unique() && input->blob()->storage().unique() &&
-                    input->layout().dtype == output_desc.layout.dtype &&
-                    input->layout().eq_layout(output_desc.layout) &&
-                    input->comp_node() == output_desc.comp_node) {
-                    static std::atomic_llong inplace_count = 0;
-                    mgb_log_debug("do inplace for elemwise, layout: %s, count: %lld",
-                                  output_desc.layout.to_string().c_str(), ++inplace_count);
-                    output = Tensor::make(input->blob(), input->layout(), input->offset());
-                    break;
-                }
-            }
-        }
-        if (!output) {
-            output = Tensor::make(output_desc.layout, output_desc.comp_node);
-        }
+        outputs[i] = Tensor::make(output_descs[i].layout, output_descs[i].comp_node);
     }
     exec(def, inputs, outputs);
     auto async_error = ProxyGraph::get_async_error();
@@ -14,10 +14,10 @@
 namespace mgb::imperative {

 struct OptimizedBackwardGraphResult {
-    std::shared_ptr<OpDef> precomp;
-    std::shared_ptr<OpDef> backward;
-    std::vector<bool> save_for_backward;
-    std::vector<bool> input_has_grad;
+    Subgraph precomp;
+    Subgraph backward;
+    SmallVector<bool> save_for_backward;
+    SmallVector<bool> input_has_grad;

     OptimizedBackwardGraphResult(const BackwardGraphResult& bgraph);
 };
@@ -26,10 +26,60 @@ enum DispatchMode {
     KERNEL = 1
 };

+using SharedOp = std::shared_ptr<OpDef>;
+
+template <typename T>
+struct Expr {
+    std::shared_ptr<OpDef> op;
+    SmallVector<T> inputs;
+    SmallVector<T> outputs;
+};
+
+struct Subgraph {
+    SmallVector<size_t> inputs;
+    SmallVector<std::pair<size_t, TensorPtr>> constants;
+    SmallVector<size_t> outputs;
+    SmallVector<Expr<size_t>> exprs;
+
+    template <typename T, typename F, typename C>
+    SmallVector<T> apply(SmallVector<T> input_vars, F&& f, C&& c) const {
+        std::unordered_map<size_t, T> idx2var;
+        mgb_assert(inputs.size() == input_vars.size(), "input size mismatch");
+        for (size_t i = 0; i < inputs.size(); ++i) {
+            idx2var[inputs[i]] = input_vars[i];
+        }
+        for (auto&& [idx, val]: constants) {
+            idx2var[idx] = c(val);
+        }
+        for (auto& expr: exprs) {
+            SmallVector<T> expr_inputs;
+            for (auto idx: expr.inputs) {
+                expr_inputs.push_back(idx2var[idx]);
+            }
+            SmallVector<T> expr_outputs = f(expr.op, std::move(expr_inputs));
+            mgb_assert(expr_outputs.size() == expr.outputs.size(), "output size mismatch");
+            for (size_t i = 0; i < expr_outputs.size(); ++i) {
+                idx2var[expr.outputs[i]] = expr_outputs[i];
+            }
+        }
+        SmallVector<T> output_vars;
+        for (auto idx: outputs) {
+            output_vars.push_back(idx2var[idx]);
+        }
+        return output_vars;
+    }
+
+    bool empty() const {
+        return outputs.size() == 0;
+    }
+
+    std::string repr() const;
+};
+
 struct BackwardGraphResult {
-    std::shared_ptr<OpDef> backward;
-    std::vector<bool> save_for_backward;
-    std::vector<bool> input_has_grad;
+    Subgraph backward;
+    SmallVector<bool> save_for_backward;
+    SmallVector<bool> input_has_grad;
 };

 class OpDef : public Hashable,
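`Subgraph` is the load-bearing type of this refactor: an SSA-like encoding where every value slot is a `size_t` index, `constants` pins slots to tensors, and `apply` is a straight-line interpreter generic in the value type `T`. A self-contained sketch of the same algorithm over plain standard types (ints instead of `TensorPtr`, `std::function` instead of `OpDef`; all names hypothetical):

    #include <cassert>
    #include <functional>
    #include <unordered_map>
    #include <utility>
    #include <vector>

    struct MiniExpr {
        std::function<std::vector<int>(std::vector<int>)> op;
        std::vector<size_t> inputs, outputs;
    };

    struct MiniGraph {
        std::vector<size_t> inputs, outputs;
        std::vector<std::pair<size_t, int>> constants;
        std::vector<MiniExpr> exprs;

        // mirrors Subgraph::apply: seed input and constant slots, run each
        // expr in order, then read the output slots
        std::vector<int> apply(std::vector<int> input_vars) const {
            std::unordered_map<size_t, int> idx2var;
            assert(inputs.size() == input_vars.size());
            for (size_t i = 0; i < inputs.size(); ++i) idx2var[inputs[i]] = input_vars[i];
            for (auto&& [idx, val] : constants) idx2var[idx] = val;
            for (auto& e : exprs) {
                std::vector<int> ins;
                for (auto idx : e.inputs) ins.push_back(idx2var[idx]);
                auto outs = e.op(std::move(ins));
                assert(outs.size() == e.outputs.size());
                for (size_t i = 0; i < outs.size(); ++i) idx2var[e.outputs[i]] = outs[i];
            }
            std::vector<int> ret;
            for (auto idx : outputs) ret.push_back(idx2var[idx]);
            return ret;
        }
    };

    // usage: slot 0 = x, slot 1 = constant 2, slot 2 = x * 2
    // MiniGraph g{{0}, {2}, {{1, 2}}, {{
    //         [](std::vector<int> v) { return std::vector<int>{v[0] * v[1]}; },
    //         {0, 1}, {2}}}};
    // g.apply({21}) returns {42}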
@@ -15,92 +15,6 @@
 namespace mgb {
 namespace imperative {

-// a special OpDef used for taking gradient on physical tensor
-struct BackwardGraph final : public OpDefImplBase<BackwardGraph> {
-    MGB_DYN_TYPE_OBJ_FINAL_DECL;
-
-public:
-    struct InternalGraph {
-        // op, inputs, outputs
-        using Expr = std::tuple<std::shared_ptr<OpDef>,
-                std::vector<size_t>, std::vector<size_t>>;
-        std::vector<Expr> exprs;
-
-        // index array of input nodes
-        std::vector<size_t> inputs;
-
-        // index array of output nodes
-        std::vector<size_t> outputs;
-
-        // pair of (node index, correspending constant)
-        std::vector<std::pair<size_t, TensorPtr>> constants;
-
-        SmallVector<TensorPtr>
-        apply(const SmallVector<TensorPtr>& inputs) const;
-
-        std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_attrs(
-                const SmallVector<LogicalTensorDesc>& inputs) const;
-
-        template <typename T, typename F, typename C>
-        SmallVector<T> interpret(F&& f, C&& c, const SmallVector<T>& inputs) const {
-            ThinHashMap<size_t, T> node2tensor;
-            auto&& input_nodes = this->inputs;
-            mgb_assert(inputs.size() == input_nodes.size());
-            for (size_t i = 0; i < inputs.size(); ++ i) {
-                node2tensor[input_nodes[i]] = inputs[i];
-            }
-            for (auto &&i : constants) {
-                node2tensor[i.first] = c(i.second);
-            }
-            for (size_t i = 0; i < exprs.size(); ++ i) {
-                auto&& expr = exprs[i];
-                SmallVector<T> inputs;
-                for (auto &&in : std::get<1>(expr)) {
-                    inputs.push_back(node2tensor.at(in));
-                }
-                auto&& outputs = f(*std::get<0>(expr), std::move(inputs));
-                auto&& output_nodes = std::get<2>(expr);
-                mgb_assert(outputs.size() == output_nodes.size());
-                for (size_t i = 0; i < outputs.size(); ++ i) {
-                    node2tensor[output_nodes[i]] = std::move(outputs[i]);
-                }
-            }
-            SmallVector<T> ret;
-            for (auto &&i : outputs) {
-                ret.push_back(node2tensor.at(i));
-            }
-            return ret;
-        }
-
-        std::string repr();
-    };
-
-    const InternalGraph& graph() const {
-        return m_graph;
-    }
-
-    InternalGraph& graph() {
-        return m_graph;
-    }
-
-    bool is_same_st(const Hashable& rhs) const override {
-        if (!rhs.same_type<BackwardGraph>()) {
-            return false;
-        }
-        auto& other = rhs.cast_final_safe<BackwardGraph>();
-        if (this == &other) {
-            return true;
-        }
-        // FIXME
-        return false;
-    }
-
-    std::string repr() {return m_graph.repr();}
-
-private:
-    InternalGraph m_graph;
-};
-
 } // namespace imperative
 } // namespace mgb
@@ -29,7 +29,7 @@ struct GenericPyOp final : OpDefImplBase<GenericPyOp> {
     }

     bool is_same_st(const Hashable& rhs) const override {
-        return obj.equal(static_cast<const GenericPyOp&>(rhs).obj);
+        return obj.equal(rhs.cast_final<GenericPyOp>().obj);
     }

     MGB_DYN_TYPE_OBJ_FINAL_DECL;
@@ -75,6 +75,10 @@ T prepare_optimized_backward_inputs(const OptimizedBackwardGraphResult& bg, cons
     return ret;
 }

+SmallVector<TensorPtr> apply_shared_on_physical_tensor(std::shared_ptr<OpDef> def, SmallVector<TensorPtr> inputs) {
+    return OpDef::apply_on_physical_tensor(*def, inputs);
+}
+
 TEST(TestImperative, BackwardGraphBasic) {
     HostTensorGenerator<> gen;
     SmallVector<HostTensorND> hvs;
@@ -114,7 +118,11 @@ TEST(TestImperative, BackwardGraphBasic) {
         }
     }
     inputs.clear();
-    auto input_grads = OpDef::apply_on_physical_tensor(*(result.backward), backward_graph_inputs);
+    auto input_grads = result.backward.apply(
+            backward_graph_inputs,
+            apply_shared_on_physical_tensor,
+            [&](auto&& x){ return x; }
+    );
     mgb_assert(input_grads.size() == input_has_grad.size());
     for (size_t i = 0; i < input_has_grad.size(); ++ i) {
         mgb_assert(input_has_grad[i] == static_cast<bool>(input_grads[i]));
@@ -164,7 +172,11 @@ TEST(TestImperative, BackwardGraphIdentity) {
         }
     }
     inputs.clear();
-    auto input_grads = OpDef::apply_on_physical_tensor(*(result.backward), backward_graph_inputs);
+    auto input_grads = result.backward.apply(
+            backward_graph_inputs,
+            apply_shared_on_physical_tensor,
+            [&](auto&& x){ return x; }
+    );
     mgb_assert(input_grads.size() == input_has_grad.size());
     for (size_t i = 0; i < input_has_grad.size(); ++ i) {
         mgb_assert(input_has_grad[i] == static_cast<bool>(input_grads[i]));
@@ -224,9 +236,17 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) {
     auto c_tn = OpDef::apply_on_physical_tensor(*op, {a_tn, b_tn})[0];

     auto backward_graph_inputs = prepare_backward_graph_inputs<SmallVector<TensorPtr>>(bg, {a_tn, b_tn}, {c_tn}, {dc_tn});
-    auto grads = expand_grads(bg, OpDef::apply_on_physical_tensor(*bg.backward, backward_graph_inputs));
+    auto grads = expand_grads(bg, bg.backward.apply(
+            backward_graph_inputs,
+            apply_shared_on_physical_tensor,
+            [&](auto&& x){ return x; }
+    ));

-    auto precomp = OpDef::apply_on_physical_tensor(*obg.precomp, {a_tn, b_tn, c_tn});
+    auto precomp = obg.precomp.apply(
+            SmallVector<TensorPtr>{a_tn, b_tn, c_tn},
+            apply_shared_on_physical_tensor,
+            [&](auto&& x){ return x; }
+    );
     ASSERT_EQ(precomp.size(), 2);
     ASSERT_EQ(precomp[0]->shape().ndim, 1);
     ASSERT_LE(precomp[0]->shape()[0], 2);
@@ -234,7 +254,11 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) {
     ASSERT_LE(precomp[1]->shape()[0], 2);

     auto backward_inputs = prepare_optimized_backward_inputs<SmallVector<TensorPtr>>(obg, precomp, {a_tn, b_tn}, {c_tn}, {dc_tn});
-    auto grads2 = expand_grads(obg, OpDef::apply_on_physical_tensor(*obg.backward, backward_inputs));
+    auto grads2 = expand_grads(obg, obg.backward.apply(
+            backward_inputs,
+            apply_shared_on_physical_tensor,
+            [&](auto&& x){ return x; }
+    ));

     ASSERT_EQ(grads2.size(), 2);
     MGB_ASSERT_TENSOR_EQ(grads[0]->get_value(), grads2[0]->get_value());