GitOrigin-RevId: f7f6024034
tags/v1.8.0
| @@ -763,6 +763,7 @@ class Constant(Expr): | |||||
| current_graph = active_module_tracer().current_scope() | current_graph = active_module_tracer().current_scope() | ||||
| current_graph._namespace.auto_naming_for_outputs(expr) | current_graph._namespace.auto_naming_for_outputs(expr) | ||||
| current_graph._insert(expr) | current_graph._insert(expr) | ||||
| active_module_tracer().current_constant_cache().append(expr.value) | |||||
| return expr.outputs[0] | return expr.outputs[0] | ||||
| def interpret(self, *inputs): | def interpret(self, *inputs): | ||||
| @@ -131,6 +131,7 @@ class module_tracer: | |||||
| self._active_scopes = [] | self._active_scopes = [] | ||||
| self.checker = TracedModuleChecker(self) | self.checker = TracedModuleChecker(self) | ||||
| self.patcher = Patcher(wrap_fn) | self.patcher = Patcher(wrap_fn) | ||||
| self._activate_constant_cache = [] | |||||
| @classmethod | @classmethod | ||||
| def register_as_builtin(cls, mod): | def register_as_builtin(cls, mod): | ||||
| @@ -145,16 +146,28 @@ class module_tracer: | |||||
| def push_scope(self, scope): | def push_scope(self, scope): | ||||
| self._active_scopes.append(scope) | self._active_scopes.append(scope) | ||||
| self.checker.push_scope() | self.checker.push_scope() | ||||
| self._activate_constant_cache.append([]) | |||||
| def pop_scope(self): | def pop_scope(self): | ||||
| self._active_scopes.pop() | self._active_scopes.pop() | ||||
| self.checker.pop_scope() | self.checker.pop_scope() | ||||
| cache = self._activate_constant_cache.pop() | |||||
| for obj in cache: | |||||
| if hasattr(obj, "_NodeMixin__node"): | |||||
| delattr(obj, "_NodeMixin__node") | |||||
| def current_scope(self): | def current_scope(self): | ||||
| if self._active_scopes: | if self._active_scopes: | ||||
| return self._active_scopes[-1] | return self._active_scopes[-1] | ||||
| return None | return None | ||||
| def current_constant_cache(self): | |||||
| if self._activate_constant_cache: | |||||
| return self._activate_constant_cache[-1] | |||||
| return None | |||||
| def top_scope(self): | def top_scope(self): | ||||
| if self._active_scopes: | if self._active_scopes: | ||||
| return self._active_scopes[0] | return self._active_scopes[0] | ||||
| @@ -379,6 +379,11 @@ class NodeMixin(abc.ABC): | |||||
| if isinstance(value, NodeMixin): | if isinstance(value, NodeMixin): | ||||
| value._record_wrapped_nodes(node) | value._record_wrapped_nodes(node) | ||||
| @classmethod | |||||
| def clear_node(cls, value): | |||||
| if hasattr(value, "_NodeMixin__node"): | |||||
| delattr(value, "_NodeMixin__node") | |||||
| @classmethod | @classmethod | ||||
| def get(cls, value, *default): | def get(cls, value, *default): | ||||
| return getattr(value, "_NodeMixin__node", *default) | return getattr(value, "_NodeMixin__node", *default) | ||||
| @@ -1980,7 +1980,10 @@ class TracedModule(Module): | |||||
| assert ( | assert ( | ||||
| treedef in self.argdef_graph_map | treedef in self.argdef_graph_map | ||||
| ), "support input args kwargs format: \n{}, but get: \n{}".format( | ), "support input args kwargs format: \n{}, but get: \n{}".format( | ||||
| "\n ".join("forward({})".format(i._args_kwargs_repr()) for i in self.argdef_graph_map.keys()), | |||||
| "\n ".join( | |||||
| "forward({})".format(i._args_kwargs_repr()) | |||||
| for i in self.argdef_graph_map.keys() | |||||
| ), | |||||
| treedef._args_kwargs_repr(), | treedef._args_kwargs_repr(), | ||||
| ) | ) | ||||
| inputs = filter( | inputs = filter( | ||||
| @@ -2514,3 +2517,7 @@ def trace_module( | |||||
| set_symbolic_shape(use_sym_shape) | set_symbolic_shape(use_sym_shape) | ||||
| set_active_module_tracer(None) | set_active_module_tracer(None) | ||||
| unset_module_tracing() | unset_module_tracing() | ||||
| for t in mod.tensors(recursive=True): | |||||
| NodeMixin.clear_node(t) | |||||
| for t in inputs: | |||||
| NodeMixin.clear_node(t) | |||||