GitOrigin-RevId: 9ecf6f2c5b
tags/v1.8.0
| @@ -92,7 +92,6 @@ BUILTIN_TENSOR_WRAP_METHOD = [ | |||||
| "dtype", | "dtype", | ||||
| "grad", | "grad", | ||||
| "item", | "item", | ||||
| "name", | |||||
| "ndim", | "ndim", | ||||
| "numpy", | "numpy", | ||||
| "qparams", | "qparams", | ||||
| @@ -152,6 +151,11 @@ class module_tracer: | |||||
| return self._active_scopes[-1] | return self._active_scopes[-1] | ||||
| return None | return None | ||||
| def top_scope(self): | |||||
| if self._active_scopes: | |||||
| return self._active_scopes[0] | |||||
| return None | |||||
| class NotExist: | class NotExist: | ||||
| pass | pass | ||||
| @@ -180,6 +180,25 @@ def _tensor_to_node(tensors): | |||||
| return nodes | return nodes | ||||
| def _name_setter(node: Node, new_name: str): | |||||
| surgery_mode = _set_graph_surgery_mode(False) | |||||
| graph = active_module_tracer().current_scope() | |||||
| if node.top_graph is not None: | |||||
| top_graph = active_module_tracer().top_scope() | |||||
| if node is top_graph._namespace.used_names.get(node._name, None): | |||||
| graph = top_graph | |||||
| else: | |||||
| graph = node.top_graph | |||||
| assert ( | |||||
| graph._namespace.used_names.get(new_name, None) is None | |||||
| ), "The name(%s) is already in use. Please try a different one again." % (new_name) | |||||
| graph._namespace.unassociate_name_with_obj(node) | |||||
| node._name = graph._namespace.create_unique_name(new_name, node) | |||||
| _set_graph_surgery_mode(surgery_mode) | |||||
| def _wrap_method_to_tensor_node(): | def _wrap_method_to_tensor_node(): | ||||
| def _any_method(name, func): | def _any_method(name, func): | ||||
| def _any(*args, **kwargs): | def _any(*args, **kwargs): | ||||
| @@ -213,6 +232,10 @@ def _wrap_method_to_tensor_node(): | |||||
| else: | else: | ||||
| patch.set_func(_any_method(method, patch.origin_fn)) | patch.set_func(_any_method(method, patch.origin_fn)) | ||||
| tensor_method_patch.append(patch) | tensor_method_patch.append(patch) | ||||
| patch = PatchedFn(Node, "name") | |||||
| patch.set_func(property(patch.origin_fn.fget, _name_setter)) | |||||
| tensor_method_patch.append(patch) | |||||
| return tensor_method_patch | return tensor_method_patch | ||||
| @@ -377,6 +377,33 @@ def test_set_node_name(): | |||||
| rename("output") | rename("output") | ||||
| np.testing.assert_equal(str(graph.outputs[0]), "output") | np.testing.assert_equal(str(graph.outputs[0]), "output") | ||||
| def add_1(x): | |||||
| x = x + 1 | |||||
| x.name = "func_add_1" | |||||
| return x | |||||
| class ModuleAdd_3(M.Module): | |||||
| def forward(self, x): | |||||
| x = x + 1 | |||||
| x.name = "module_add_1" | |||||
| x = x + 2 | |||||
| return x | |||||
| setattr(traced_module, "add_3", ModuleAdd_3()) | |||||
| self = graph.inputs[0] | |||||
| with graph.insert_exprs(): | |||||
| x = output_node + 1 | |||||
| x.name = "_add_1" | |||||
| x = add_1(x) | |||||
| x = self.add_3(x) | |||||
| graph.replace_node({output_node: x}) | |||||
| graph.compile() | |||||
| assert "_add_1" in graph._namespace.used_names | |||||
| assert "func_add_1" in graph._namespace.used_names | |||||
| assert "module_add_1" in traced_module.add_3.graph._namespace.used_names | |||||
| def test_set_graph_name(): | def test_set_graph_name(): | ||||
| traced_module, x, expect = _init_module() | traced_module, x, expect = _init_module() | ||||