GitOrigin-RevId: 9c69254866
tags/v1.3.1
| @@ -772,7 +772,8 @@ class trace: | |||||
| len(self._output_bindings) | len(self._output_bindings) | ||||
| ) | ) | ||||
| ) | ) | ||||
| if arg_names is None: | |||||
| without_arg_names = arg_names is None | |||||
| if without_arg_names: | |||||
| arg_names = ["arg_%d" % i for i in range(len(self._arg_bindings))] | arg_names = ["arg_%d" % i for i in range(len(self._arg_bindings))] | ||||
| if arg_names and not isinstance(arg_names, collections.abc.Sequence): | if arg_names and not isinstance(arg_names, collections.abc.Sequence): | ||||
| arg_names = (arg_names,) | arg_names = (arg_names,) | ||||
| @@ -802,7 +803,7 @@ class trace: | |||||
| dtype=info.dtype, | dtype=info.dtype, | ||||
| device=dumped_device(info), | device=dumped_device(info), | ||||
| shape=info.shape or (1,), | shape=info.shape or (1,), | ||||
| name=arg_names[i] if arg_names else None, | |||||
| name=info.name if without_arg_names and info.name else arg_names[i], | |||||
| ) | ) | ||||
| for k, h in self._kwarg_bindings.items(): | for k, h in self._kwarg_bindings.items(): | ||||
| info = self._tinfo[h] | info = self._tinfo[h] | ||||
| @@ -889,6 +890,7 @@ class trace: | |||||
| return | return | ||||
| h, info = self._new_handle() | h, info = self._new_handle() | ||||
| info.external = False | info.external = False | ||||
| info.name = x.c_name | |||||
| info.device = x.device | info.device = x.device | ||||
| info.dtype = x.dtype | info.dtype = x.dtype | ||||
| info.shape = x.numpy().shape | info.shape = x.numpy().shape | ||||
| @@ -203,14 +203,31 @@ def test_with_same_operators(symbolic): | |||||
| assert ops[-2].name == "simple.RELU[0]" | assert ops[-2].name == "simple.RELU[0]" | ||||
| def test_not_keep_opr_name(): | |||||
| @pytest.mark.parametrize("symbolic", [False, True]) | |||||
| def test_not_keep_opr_name(symbolic): | |||||
| def f(x): | def f(x): | ||||
| return 2 * x | return 2 * x | ||||
| op = _dump_and_load(f, True, False)[-1] | |||||
| op = _dump_and_load(f, symbolic, False)[-1] | |||||
| assert op.name == "MUL(x,const<2>[2])[4]" | assert op.name == "MUL(x,const<2>[2])[4]" | ||||
| @pytest.mark.parametrize("tensor_name, var_name", [("data", "data"), (None, "arg_0")]) | |||||
| def test_catch_input_name(tensor_name, var_name): | |||||
| def f(x): | |||||
| return 2 * x | |||||
| func = trace(f, symbolic=True, capture_as_const=True) | |||||
| x = Tensor(np.ones(shape=(2, 3)), name=tensor_name) | |||||
| func(x).numpy() | |||||
| file = io.BytesIO() | |||||
| func.dump(file, optimize_for_inference=False, keep_opr_name=True, keep_var_name=2) | |||||
| file.seek(0) | |||||
| *_, outputs = G.load_graph(file) | |||||
| op = cgtools.get_oprs_seq(outputs)[-1] | |||||
| assert op.inputs[0].name == var_name | |||||
| @pytest.mark.parametrize("symbolic", [False, True]) | @pytest.mark.parametrize("symbolic", [False, True]) | ||||
| def test_quantized_module_auto_naming(symbolic): | def test_quantized_module_auto_naming(symbolic): | ||||
| class Simple(M.Module): | class Simple(M.Module): | ||||