| @@ -437,7 +437,7 @@ def _unwrap(x): | |||
| return x | |||
| def apply_normal_op(op: OpDef, *args: VarNode): | |||
| def apply_normal_varnode(op: OpDef, *args: VarNode): | |||
| outputs = _imperative_rt.invoke_op(op, _unwrap(args)) | |||
| return _wrap(outputs) | |||
| @@ -447,7 +447,7 @@ def apply_backward_varnode(op: BackwardGraph, *args: VarNode): | |||
| graph = args[0].graph | |||
| outputs = op.interpret( | |||
| op, | |||
| lambda op, args: apply_normal_op(op, *args), | |||
| lambda op, args: apply_normal_varnode(op, *args), | |||
| graph._make_const_for_backward, | |||
| args, | |||
| ) | |||
| @@ -41,7 +41,7 @@ from ..core._imperative_rt.ops import ( | |||
| ) | |||
| from ..core._trace_option import set_symbolic_shape | |||
| from ..core._wrap import device as as_device | |||
| from ..core.ops.builtin import OpDef | |||
| from ..core.ops.builtin import BackwardGraph, OpDef | |||
| from ..core.ops.special import Const | |||
| from ..core.tensor import megbrain_graph as G | |||
| from .sublinear_memory_config import SublinearMemoryConfig | |||
| @@ -372,6 +372,7 @@ class trace: | |||
| lazy_eval_graph() | |||
| for r, x in zip(readers, lazy_eval_tensors): | |||
| x()._handle = RawTensor(r.op.get_value())._handle | |||
| x()._reset_varnode() | |||
| @contextlib.contextmanager | |||
| def _setup(self): | |||
| @@ -580,9 +581,11 @@ class trace: | |||
| ivars.append(info.varnode) | |||
| ivars = [RawTensor(ivar) for ivar in ivars] | |||
| ovars = apply(op, *ivars) | |||
| ovars = [x._varnode for x in ovars] | |||
| if isinstance(op, BackwardGraph): | |||
| ovars = G.apply_backward_varnode(op, *ivars) | |||
| else: | |||
| ovars = G.apply_normal_varnode(op, *ivars) | |||
| if require_links and len(ovars) > 0: | |||
| io_links = (ovars[0],) | |||
| assert len(ovars) == len(ohandles) | |||
| @@ -768,11 +771,10 @@ class trace: | |||
| info.bound_data.numpy(), dtype=info.dtype, device=dumped_device | |||
| ) | |||
| ivars.append(h2v[h]) | |||
| ivars = [RawTensor(ivar) for ivar in ivars] | |||
| ovars = apply(op, *ivars) | |||
| ovars = [x._varnode for x in ovars] | |||
| ovars = G.apply_normal_varnode(op, *ivars) | |||
| assert len(ovars) == len(ohandles) | |||
| h2v.update(zip(ohandles, ovars)) | |||
| unset_tracing() | |||
| dest_vars = [] | |||
| for i, h in enumerate(self._output_bindings): | |||
| @@ -781,7 +783,6 @@ class trace: | |||
| v.name = output_names[i] | |||
| dest_vars.append(v) | |||
| dest_vars = [G.VarNode(var) for var in dest_vars] | |||
| if optimize_for_inference: | |||
| dest_vars = G.optimize_for_inference(dest_vars, **kwargs) | |||
| @@ -1007,7 +1008,6 @@ def assign_raw_tensor(lhs, rhs): | |||
| lhs.__init__(rhs) | |||
| # this hook turns RawTensor into LazyEvalTensor(varnode) | |||
| def apply_symbolic_mode(op: OpDef, *args: RawTensor): | |||
| graph = active_trace._lazy_eval_graph | |||
| ivars = [] | |||
| @@ -1038,13 +1038,11 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor): | |||
| ivars[0] = opnode.outputs[0] | |||
| active_trace._lazy_eval_links = (ivars[0],) | |||
| ivars = [ | |||
| RawTensor(ivar._node) if hasattr(ivar, "_node") else RawTensor(ivar) | |||
| for ivar in ivars | |||
| ] | |||
| unset_symbolic() | |||
| outputs = apply(op, *ivars) | |||
| set_symbolic() | |||
| if isinstance(op, BackwardGraph): | |||
| ovars = G.apply_backward_varnode(op, *ivars) | |||
| else: | |||
| ovars = G.apply_normal_varnode(op, *ivars) | |||
| outputs = [RawTensor(o) for o in ovars] | |||
| if require_links: | |||
| active_trace._lazy_eval_links = (outputs[0]._varnode,) | |||
| @@ -392,6 +392,10 @@ void TensorWrapper::reset(PyObject* tensor) { | |||
| m_tensor = t->m_tensor; | |||
| } | |||
| void TensorWrapper::reset_varnode() { | |||
| m_tensor->m_var = nullptr; | |||
| } | |||
| PyObject* TensorWrapper::detach() { | |||
| PyObject* self = wrap_t::pycast(this); | |||
| PyTypeObject* pytype = self->ob_type; | |||
| @@ -687,6 +691,7 @@ void init_tensor(py::module m) { | |||
| .def<&TensorWrapper::_swap_out>("_swap_out") | |||
| .def<&TensorWrapper::_swap_in>("_swap_in") | |||
| .def<&TensorWrapper::_drop>("_drop") | |||
| .def<&TensorWrapper::reset_varnode>("_reset_varnode") | |||
| .def_getset<&TensorWrapper::varnode>("_varnode") | |||
| .def_getset<&TensorWrapper::data_read, &TensorWrapper::set_data_read>("data_read") | |||
| .def_getset<&TensorWrapper::value_read, &TensorWrapper::set_value_read>("value_read") | |||
| @@ -155,6 +155,7 @@ struct TensorWrapper { | |||
| void _swap_out(); | |||
| void _drop(); | |||
| PyObject* varnode(); | |||
| void reset_varnode(); | |||
| PyObject* handle(); | |||
| void set_handle(PyObject *); | |||
| @@ -17,30 +17,9 @@ namespace py = pybind11; | |||
| namespace mgb::imperative::python { | |||
| apply_result_t apply_tensor_on_var_node(ApplyContext& ctx) { | |||
| apply_result_t outputs; | |||
| cg::VarNodeArray vinputs(ctx.nargs); | |||
| for (size_t i = 0; i < ctx.nargs; i++) { | |||
| vinputs[i] = ctx.args[i]->m_var; | |||
| } | |||
| auto ovars = OpDef::apply_on_var_node(*ctx.op, vinputs); | |||
| for (size_t i = 0; i < ovars.size(); i++) { | |||
| outputs.emplace_back(std::make_shared<Tensor>(ovars[i])); | |||
| } | |||
| return outputs; | |||
| } | |||
| apply_result_t apply_trace(ApplyContext& ctx) { | |||
| apply_result_t outputs; | |||
| bool run_apply_on_var_node = false; | |||
| for (size_t i = 0; i < ctx.nargs; i++) { | |||
| run_apply_on_var_node |= ((ctx.args[i]->m_handle.get() == nullptr) & (ctx.args[i]->m_var != nullptr)); | |||
| } | |||
| if (ctx.backward) { | |||
| // reach here when symbolic=True or compiled=True | |||
| // call megbrain_graph.py apply(BackwardGraph, *args) | |||
| @@ -63,10 +42,6 @@ apply_result_t apply_trace(ApplyContext& ctx) { | |||
| return outputs; | |||
| } | |||
| if (run_apply_on_var_node && !is_symbolic) { | |||
| return apply_tensor_on_var_node(ctx); | |||
| } | |||
| py::object pyf; | |||
| if (is_compiled) { | |||
| // run apply in compiled mode, step 2, 3, etc | |||
| @@ -112,7 +112,7 @@ def test_quint8_typecvt(): | |||
| data = np.random.random(shape).astype(np.float32) * 5 - 1 | |||
| def typecvt(x, dt=None): | |||
| (y,) = G.apply_normal_op(ops.TypeCvt(dtype=dt), x) | |||
| (y,) = G.apply_normal_varnode(ops.TypeCvt(dtype=dt), x) | |||
| return y | |||
| # convert to quint8 | |||
| @@ -193,7 +193,7 @@ def test_quint4_typecvt(): | |||
| data = np.random.random(shape).astype(np.float32) * 5 - 1 | |||
| def typecvt(x, dt=None): | |||
| (y,) = G.apply_normal_op(ops.TypeCvt(dtype=dt), x) | |||
| (y,) = G.apply_normal_varnode(ops.TypeCvt(dtype=dt), x) | |||
| return y | |||
| # convert to quint4 | |||
| @@ -72,7 +72,7 @@ def test_op(): | |||
| lambda: x, device=x.comp_node, dtype=x.dtype, graph=g | |||
| ) | |||
| neg = Elemwise(Elemwise.Mode.NEGATE) | |||
| v = mgb_graph.apply_normal_op(neg, v)[0] | |||
| v = mgb_graph.apply_normal_varnode(neg, v)[0] | |||
| y = Future() | |||
| v = mgb_graph.output_callback(y.set_result, v) | |||
| f = g.compile(v) | |||
| @@ -90,7 +90,7 @@ def test_exception(): | |||
| g = mgb_graph.Graph() | |||
| x, _ = mgb_graph.input_callback(throw_exc, device="xpux", dtype="float32", graph=g) | |||
| neg = Elemwise(Elemwise.Mode.NEGATE) | |||
| y = mgb_graph.OutputNode(mgb_graph.apply_normal_op(neg, x)[0]) | |||
| y = mgb_graph.OutputNode(mgb_graph.apply_normal_varnode(neg, x)[0]) | |||
| f = g.compile(y.outputs[0]) | |||
| try: | |||
| f.execute() | |||
| @@ -16,7 +16,7 @@ import megengine.module as M | |||
| import megengine.utils.comp_graph_tools as cgtools | |||
| from megengine.core.ops.builtin import Elemwise | |||
| from megengine.core.tensor import megbrain_graph as mgb_graph | |||
| from megengine.core.tensor.megbrain_graph import apply_normal_op | |||
| from megengine.core.tensor.megbrain_graph import apply_normal_varnode | |||
| from megengine.core.tensor.utils import astensor1d | |||
| from megengine.jit import trace | |||
| @@ -34,9 +34,9 @@ def test_replace_vars(): | |||
| const = g.make_const(1.234, device=device) | |||
| add_op = Elemwise(Elemwise.Mode.ADD) | |||
| mul_op = Elemwise(Elemwise.Mode.MUL) | |||
| a_plus_a = apply_normal_op(add_op, a.outputs[0], a.outputs[0])[0] | |||
| a_plus_a_mul_const = apply_normal_op(mul_op, a_plus_a, const)[0] | |||
| rst = apply_normal_op(add_op, a_plus_a_mul_const, a.outputs[0])[0] | |||
| a_plus_a = apply_normal_varnode(add_op, a.outputs[0], a.outputs[0])[0] | |||
| a_plus_a_mul_const = apply_normal_varnode(mul_op, a_plus_a, const)[0] | |||
| rst = apply_normal_varnode(add_op, a_plus_a_mul_const, a.outputs[0])[0] | |||
| (new,) = cgtools.replace_vars([rst._node], {const._node: a_plus_a._node}) | |||
| out = mgb_graph.OutputNode(mgb_graph.VarNode(new)) | |||
| func = g.compile(out.outputs[0]) | |||
| @@ -56,10 +56,10 @@ def test_replace_oprs(): | |||
| const = g.make_const(1.25, device=device) | |||
| add_op = Elemwise(Elemwise.Mode.ADD) | |||
| mul_op = Elemwise(Elemwise.Mode.MUL) | |||
| a_plus_a = apply_normal_op(add_op, a.outputs[0], a.outputs[0])[0] | |||
| a_plus_a = apply_normal_varnode(add_op, a.outputs[0], a.outputs[0])[0] | |||
| old_opr = a_plus_a.op | |||
| a_plus_a_mul_const = apply_normal_op(mul_op, a_plus_a, const)[0] | |||
| a_mul_a = apply_normal_op(mul_op, a.outputs[0], a.outputs[0])[0] | |||
| a_plus_a_mul_const = apply_normal_varnode(mul_op, a_plus_a, const)[0] | |||
| a_mul_a = apply_normal_varnode(mul_op, a.outputs[0], a.outputs[0])[0] | |||
| new_opr = a_mul_a.op | |||
| (new,) = cgtools.replace_oprs( | |||
| [a_plus_a_mul_const._node], {old_opr._node: new_opr._node} | |||
| @@ -163,6 +163,7 @@ def test_trace_profiler(): | |||
| assert out.get("profiler") | |||
| @pytest.mark.skip(reason="force opt_level=0 when building graph") | |||
| def test_goptions(): | |||
| @trace(symbolic=True, opt_level=0, capture_as_const=True) | |||
| def f(x): | |||
| @@ -181,6 +182,7 @@ def test_goptions(): | |||
| np.testing.assert_equal(g(d).numpy().item(), 1.0) | |||
| @pytest.mark.skip(reason="force opt_level=0 when building graph") | |||
| def test_goptions_log_sum_exp(): | |||
| @trace(symbolic=True, opt_level=0, capture_as_const=True) | |||
| def f(x, y): | |||