GitOrigin-RevId: ce11fe5e09
tags/v1.8.0
| @@ -606,7 +606,8 @@ class Apply(Expr): | |||||
| def apply_module_trace_hook(cls, opdef, *inputs): | def apply_module_trace_hook(cls, opdef, *inputs): | ||||
| for i in inputs: | for i in inputs: | ||||
| node = NodeMixin.get(i, None) | node = NodeMixin.get(i, None) | ||||
| assert node is not None | |||||
| if node is None: # capture as constant | |||||
| NodeMixin.wrap_safe(i, Constant.make(i)) | |||||
| if isinstance(opdef, FakeQuant): | if isinstance(opdef, FakeQuant): | ||||
| inp_nodes = [NodeMixin.get(inputs[0])] | inp_nodes = [NodeMixin.get(inputs[0])] | ||||
| @@ -627,7 +628,6 @@ class Apply(Expr): | |||||
| unset_module_tracing() | unset_module_tracing() | ||||
| outputs = apply(opdef, *inputs) | outputs = apply(opdef, *inputs) | ||||
| outputs = list(map(Tensor, outputs)) | |||||
| set_module_tracing() | set_module_tracing() | ||||
| apply_node.add_outputs(outputs) | apply_node.add_outputs(outputs) | ||||
| @@ -741,12 +741,8 @@ class Constant(Expr): | |||||
| assert isinstance(c, (RawTensor, Module)) | assert isinstance(c, (RawTensor, Module)) | ||||
| if isinstance(c, Module): | if isinstance(c, Module): | ||||
| assert module_tracer.is_builtin(c) or c.is_qat | assert module_tracer.is_builtin(c) or c.is_qat | ||||
| if isinstance(c, RawTensor): | |||||
| if is_tracing_module(): | |||||
| unset_module_tracing() | |||||
| c = Tensor(c) | |||||
| set_module_tracing() | |||||
| else: | |||||
| if type(c) is RawTensor: | |||||
| with _exclude_from_trace(): | |||||
| c = Tensor(c) | c = Tensor(c) | ||||
| self.value = c | self.value = c | ||||
| self.name = name | self.name = name | ||||
| @@ -52,6 +52,12 @@ public: | |||||
| } | } | ||||
| } | } | ||||
| void enable() { m_enabled = 1; } | |||||
| void disable() { m_enabled = 0; } | |||||
| bool enabled() const { return m_enabled; } | |||||
| ValueRef unwrap(ValueRef value) override { return value; } | ValueRef unwrap(ValueRef value) override { return value; } | ||||
| std::string name() const override { return "ModuleTraceTransformation"; } | std::string name() const override { return "ModuleTraceTransformation"; } | ||||
| @@ -219,17 +219,19 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { | |||||
| PyObject* TensorWrapper::module_trace_info() { | PyObject* TensorWrapper::module_trace_info() { | ||||
| if (auto module_trace_info = module_trace_info_map.try_get(m_tensor->data())) { | if (auto module_trace_info = module_trace_info_map.try_get(m_tensor->data())) { | ||||
| return module_trace_info->inc_ref().ptr(); | |||||
| } else { | |||||
| PyErr_SetString( | |||||
| PyExc_AttributeError, | |||||
| "Has no attribute named \'_NodeMixin__node\', please " | |||||
| "set it first"); | |||||
| return nullptr; | |||||
| if (module_trace_info->ptr()) { | |||||
| return module_trace_info->inc_ref().ptr(); | |||||
| } | |||||
| } | } | ||||
| PyErr_SetString( | |||||
| PyExc_AttributeError, | |||||
| "Has no attribute named \'_NodeMixin__node\', please " | |||||
| "set it first"); | |||||
| return nullptr; | |||||
| } | } | ||||
| void TensorWrapper::set_module_trace_info(PyObject* obj) { | void TensorWrapper::set_module_trace_info(PyObject* obj) { | ||||
| // TODO: erase when obj == nullptr | |||||
| module_trace_info_map[m_tensor->data()] = py::reinterpret_borrow<py::object>(obj); | module_trace_info_map[m_tensor->data()] = py::reinterpret_borrow<py::object>(obj); | ||||
| } | } | ||||
| @@ -1031,29 +1033,23 @@ void init_tensor(py::module m) { | |||||
| static py::function module_trace_hook; | static py::function module_trace_hook; | ||||
| static std::shared_ptr<ModuleTraceTransformation> module_trace_transformation; | |||||
| static int module_tracing = 0; | |||||
| m.def("set_module_tracing", [=] { | |||||
| static auto get_module_trace = [] { | |||||
| static std::shared_ptr<ModuleTraceTransformation> module_trace_transformation; | |||||
| if (!module_trace_transformation) { | if (!module_trace_transformation) { | ||||
| mgb_assert(module_trace_hook); | mgb_assert(module_trace_hook); | ||||
| module_trace_transformation = | module_trace_transformation = | ||||
| std::make_shared<ModuleTraceTransformation>(module_trace_hook); | std::make_shared<ModuleTraceTransformation>(module_trace_hook); | ||||
| } | |||||
| if (++module_tracing == 1) { | |||||
| transformations.register_at<TransformationManager::ModuleTrace>( | |||||
| transformations.register_at<Segment::ModuleTrace>( | |||||
| module_trace_transformation); | module_trace_transformation); | ||||
| } | } | ||||
| }); | |||||
| return module_trace_transformation; | |||||
| }; | |||||
| m.def("unset_module_tracing", [=] { | |||||
| if (--module_tracing == 0) { | |||||
| transformations.unregister<TransformationManager::ModuleTrace>( | |||||
| module_trace_transformation); | |||||
| } | |||||
| }); | |||||
| m.def("set_module_tracing", [=] { get_module_trace()->enable(); }); | |||||
| m.def("unset_module_tracing", [=] { get_module_trace()->disable(); }); | |||||
| m.def("is_tracing_module", [=] { return module_tracing > 0; }); | |||||
| m.def("is_tracing_module", [=] { return get_module_trace()->enabled(); }); | |||||
| m.def("set_module_trace_hook", | m.def("set_module_trace_hook", | ||||
| [](py::function function) { module_trace_hook = function; }); | [](py::function function) { module_trace_hook = function; }); | ||||
| @@ -5,9 +5,11 @@ import numpy as np | |||||
| import megengine.functional as F | import megengine.functional as F | ||||
| import megengine.module as M | import megengine.module as M | ||||
| from megengine import Tensor | from megengine import Tensor | ||||
| from megengine.module.module import Module | |||||
| from megengine.core._imperative_rt.core2 import apply | |||||
| from megengine.core.ops import builtin | |||||
| from megengine.module import Module | |||||
| from megengine.traced_module import TracedModule, enable_expr_checker, trace_module | from megengine.traced_module import TracedModule, enable_expr_checker, trace_module | ||||
| from megengine.traced_module.expr import CallFunction | |||||
| from megengine.traced_module.expr import Apply, CallFunction, Constant | |||||
| class MyModule1(M.Module): | class MyModule1(M.Module): | ||||
| @@ -133,3 +135,25 @@ def test_trace_module(): | |||||
| tm6 = trace_module(MyModule5(), a, b) | tm6 = trace_module(MyModule5(), a, b) | ||||
| assert tm6.m1.argspec is None | assert tm6.m1.argspec is None | ||||
| assert tm6.m1._is_top is False | assert tm6.m1._is_top is False | ||||
| def test_trace_module_2(): | |||||
| class Model(M.Module): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| def forward(self, x): | |||||
| out = x.shape | |||||
| out = apply(builtin.Elemwise(mode="ADD"), out, Tensor(1)) | |||||
| return out | |||||
| traced_model = trace_module(Model(), Tensor(([1,]))) | |||||
| assert isinstance(traced_model.graph._exprs[0], Apply) and isinstance( | |||||
| traced_model.graph._exprs[0].opdef, builtin.GetVarShape | |||||
| ) | |||||
| assert isinstance(traced_model.graph._exprs[1], Constant) | |||||
| assert isinstance(traced_model.graph._exprs[2], Apply) and isinstance( | |||||
| traced_model.graph._exprs[2].opdef, builtin.Elemwise | |||||
| ) | |||||
| assert int(traced_model(Tensor([1, 2]))[0]) == 3 | |||||