GitOrigin-RevId: 468a996bdd
tags/v1.11.0
| @@ -30,12 +30,22 @@ private: | |||||
| } | } | ||||
| public: | public: | ||||
| inline static WeakKeyMap<ValueWeakRef, py::object> module_trace_info_map; | |||||
| ModuleTraceTransformation(py::function hook_fn) : m_hook_fn(hook_fn) {} | ModuleTraceTransformation(py::function hook_fn) : m_hook_fn(hook_fn) {} | ||||
| ValueRefList apply_transformation( | ValueRefList apply_transformation( | ||||
| const Operator& op, Span<ValueRef> inputs) override { | const Operator& op, Span<ValueRef> inputs) override { | ||||
| if (op.is<ApplyOp>() && m_enabled > 0) { | if (op.is<ApplyOp>() && m_enabled > 0) { | ||||
| auto outputs = apply_module_trace_hook(op.cast<ApplyOp>().op(), inputs); | auto outputs = apply_module_trace_hook(op.cast<ApplyOp>().op(), inputs); | ||||
| return outputs; | return outputs; | ||||
| } else if (op.is<RenameValue>()) { | |||||
| auto outputs = imperative::apply(op, inputs); | |||||
| if (auto module_trace_info = module_trace_info_map.try_get(inputs[0])) { | |||||
| if (module_trace_info->ptr()) { | |||||
| auto node = module_trace_info.value(); | |||||
| module_trace_info_map[outputs[0]] = module_trace_info.value(); | |||||
| } | |||||
| } | |||||
| return outputs; | |||||
| } else { | } else { | ||||
| return imperative::apply(op, inputs); | return imperative::apply(op, inputs); | ||||
| } | } | ||||
| @@ -47,10 +47,6 @@ namespace views = ranges::views; | |||||
| namespace mgb::imperative::python { | namespace mgb::imperative::python { | ||||
| namespace { | |||||
| WeakKeyMap<ValueWeakRef, py::object> module_trace_info_map; | |||||
| } // namespace | |||||
| interpreter::Interpreter::Channel* interpreter_for_py = nullptr; | interpreter::Interpreter::Channel* interpreter_for_py = nullptr; | ||||
| PyTypeObject* py_tensor_type = nullptr; | PyTypeObject* py_tensor_type = nullptr; | ||||
| PyTypeObject* py_varnode_type = nullptr; | PyTypeObject* py_varnode_type = nullptr; | ||||
| @@ -594,7 +590,9 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { | |||||
| } | } | ||||
| PyObject* TensorWrapper::module_trace_info() { | PyObject* TensorWrapper::module_trace_info() { | ||||
| if (auto module_trace_info = module_trace_info_map.try_get(m_tensor->data())) { | |||||
| if (auto module_trace_info = | |||||
| ModuleTraceTransformation::module_trace_info_map.try_get( | |||||
| m_tensor->data())) { | |||||
| if (module_trace_info->ptr()) { | if (module_trace_info->ptr()) { | ||||
| return module_trace_info->inc_ref().ptr(); | return module_trace_info->inc_ref().ptr(); | ||||
| } | } | ||||
| @@ -608,7 +606,8 @@ PyObject* TensorWrapper::module_trace_info() { | |||||
| void TensorWrapper::set_module_trace_info(PyObject* obj) { | void TensorWrapper::set_module_trace_info(PyObject* obj) { | ||||
| // TODO: erase when obj == nullptr | // TODO: erase when obj == nullptr | ||||
| module_trace_info_map[m_tensor->data()] = py::reinterpret_borrow<py::object>(obj); | |||||
| ModuleTraceTransformation::module_trace_info_map[m_tensor->data()] = | |||||
| py::reinterpret_borrow<py::object>(obj); | |||||
| } | } | ||||
| void TensorWrapper::_set_format(PyObject* dest) { | void TensorWrapper::_set_format(PyObject* dest) { | ||||
| @@ -620,6 +619,7 @@ void TensorWrapper::_set_format(PyObject* dest) { | |||||
| void TensorWrapper::_set_name(PyObject* dest) { | void TensorWrapper::_set_name(PyObject* dest) { | ||||
| auto py_dest = py::reinterpret_borrow<py::object>(dest); | auto py_dest = py::reinterpret_borrow<py::object>(dest); | ||||
| auto name = py_dest.cast<std::string>(); | auto name = py_dest.cast<std::string>(); | ||||
| m_tensor->set_name(name); | m_tensor->set_name(name); | ||||
| } | } | ||||
| @@ -9,7 +9,7 @@ from megengine.core._imperative_rt.core2 import apply | |||||
| from megengine.core.ops import builtin | from megengine.core.ops import builtin | ||||
| from megengine.module import Module | from megengine.module import Module | ||||
| from megengine.traced_module import TracedModule, enable_expr_checker, trace_module | from megengine.traced_module import TracedModule, enable_expr_checker, trace_module | ||||
| from megengine.traced_module.expr import Apply, CallFunction, Constant | |||||
| from megengine.traced_module.expr import Apply, CallFunction, CallMethod, Constant | |||||
| class MyModule1(M.Module): | class MyModule1(M.Module): | ||||
| @@ -59,6 +59,14 @@ class MyModule4(M.Module): | |||||
| return self.add(x, y) | return self.add(x, y) | ||||
| class MyModule5(M.Module): | |||||
| def forward(self, x): | |||||
| a = x + x | |||||
| b = x * a | |||||
| b.name = "result" | |||||
| return b | |||||
| def test_trace_module(): | def test_trace_module(): | ||||
| enable_expr_checker() | enable_expr_checker() | ||||
| x = Tensor(1) | x = Tensor(1) | ||||
| @@ -157,3 +165,9 @@ def test_trace_module_2(): | |||||
| traced_model.graph._exprs[2].opdef, builtin.Elemwise | traced_model.graph._exprs[2].opdef, builtin.Elemwise | ||||
| ) | ) | ||||
| assert int(traced_model(Tensor([1, 2]))[0]) == 3 | assert int(traced_model(Tensor([1, 2]))[0]) == 3 | ||||
| def test_rename(): | |||||
| model = MyModule5() | |||||
| tm_model = trace_module(model, Tensor(1)) | |||||
| assert isinstance(tm_model.graph.outputs[0].expr, CallMethod) | |||||