GitOrigin-RevId: 258c03ee34
tags/v1.0.0-rc1
| @@ -235,14 +235,35 @@ class Tensor: | |||
| return self.__val.dtype | |||
| return self._symvar.dtype | |||
| def set_dtype(self, dtype: str = None): | |||
| @dtype.setter | |||
| def dtype(self, dtype: str = None): | |||
| r"""Set the data type of the tensor. | |||
| """ | |||
| if self.__val is not None: | |||
| self.__val = mgb.make_shared(self.device, value=self.astype(dtype).numpy()) | |||
| elif self.__sym_override is not None: | |||
| self.__sym_override = self.__sym_override.astype(dtype) | |||
| elif self.__sym is not None: | |||
| self.__sym = self.__sym.astype(dtype) | |||
| @property | |||
| def name(self): | |||
| r"""Get the tensor name, does not support Parameter and Buffer. | |||
| """ | |||
| return self._symvar.name | |||
| @name.setter | |||
| def name(self, name: str = None): | |||
| r"""Set the tensor name, does not support Parameter and Buffer. | |||
| """ | |||
| if self.__val is not None: | |||
| raise ValueError("name setting is not available for Parameter or Buffer.") | |||
| if self.__sym_override is not None: | |||
| self.__sym_override = self.__sym_override.rename(name) | |||
| if self.__sym is not None: | |||
| assert not self.__val | |||
| self.__sym = self.__sym.rename(name) | |||
| @property | |||
| def _comp_node(self): | |||
| if self.__val is not None: | |||
| @@ -436,6 +436,7 @@ class trace: | |||
| arg_names=None, | |||
| append=False, | |||
| optimize_for_inference=False, | |||
| output_names=None, | |||
| **kwargs | |||
| ): | |||
| """ | |||
| @@ -446,6 +447,8 @@ class trace: | |||
| :param append: whether output is appended to ``fpath``. | |||
| :param optimize_for_inference: whether to enable optimize_for_inference | |||
| pass before dump. | |||
| :param output_names: names of the output tensors in the traced function, | |||
| will use the default name if does not specify. | |||
| :param enable_io16xc32: whether to use float16 for I/O between oprs and use | |||
| float32 as internal computation precision. Note the output var would be | |||
| @@ -488,6 +491,17 @@ class trace: | |||
| len(self._args), len(arg_names) | |||
| ) | |||
| ) | |||
| if isinstance(output_names, str): | |||
| output_names = [output_names] | |||
| if output_names is None: | |||
| output_names = [var.name for var in self._sym_outputs] | |||
| elif len(output_names) != len(self._sym_outputs): | |||
| raise ValueError( | |||
| "len(output_names) should be {}, got {}".format( | |||
| len(self._sym_outputs), len(output_names) | |||
| ) | |||
| ) | |||
| optimize_for_inference_args_map = { | |||
| "enable_io16xc32": "f16_io_f32_comp", | |||
| "enable_ioc16": "f16_io_comp", | |||
| @@ -541,6 +555,8 @@ class trace: | |||
| sym_outputs = mgb.optimize_for_inference( | |||
| sym_outputs, **optimize_for_inference_kwargs | |||
| ) | |||
| for var, name in zip(sym_outputs, output_names): | |||
| var.rename(name) | |||
| mgb.serialize_comp_graph_to_file(fpath, sym_outputs, append=append) | |||
| def get_profile(self): | |||
| @@ -464,7 +464,7 @@ class Module(metaclass=ABCMeta): | |||
| # For quantized dtype, the initialized dtype | |||
| # scale/zero_points maybe invalid, use pretrained dtype instead. | |||
| if is_quantize(to_be_load.dtype) and is_quantize(var.dtype): | |||
| var.set_dtype(to_be_load.dtype) | |||
| var.dtype = to_be_load.dtype | |||
| var.set_value(to_be_load) | |||
| loaded.append(k) | |||
| @@ -46,29 +46,46 @@ def test_tensor_set_dtype(): | |||
| ) | |||
| t = mge.Parameter(np.ones((3, 4), dtype="float32")) | |||
| t.set_dtype(mgb.dtype.qint8(0.1)) | |||
| t.dtype = mgb.dtype.qint8(0.1) | |||
| check_dtype_value(t, 0.1, 10) | |||
| t = mge.Parameter(np.ones((3, 4), dtype=mgb.dtype.qint8(1))) | |||
| t.set_dtype(mgb.dtype.qint8(0.3)) | |||
| t.dtype = mgb.dtype.qint8(0.3) | |||
| check_dtype_value(t, 0.3, 3) | |||
| t = mge.Buffer(np.ones((3, 4), dtype="float32")) | |||
| t.set_dtype(mgb.dtype.qint8(0.1)) | |||
| t.dtype = mgb.dtype.qint8(0.1) | |||
| check_dtype_value(t, 0.1, 10) | |||
| t = mge.Buffer(np.ones((3, 4), dtype=mgb.dtype.qint8(1))) | |||
| t.set_dtype(mgb.dtype.qint8(0.3)) | |||
| t.dtype = mgb.dtype.qint8(0.3) | |||
| check_dtype_value(t, 0.3, 3) | |||
| t = mge.Buffer(np.ones((3, 4), dtype="float32")) | |||
| s = t + 1 | |||
| s.set_dtype(mgb.dtype.qint8(0.2)) | |||
| s.dtype = mgb.dtype.qint8(0.2) | |||
| check_dtype_value(s, 0.2, 10) | |||
| t.set_dtype(mgb.dtype.qint8(0.3)) | |||
| t.dtype = mgb.dtype.qint8(0.3) | |||
| s = t + 1 | |||
| s.set_dtype(mgb.dtype.qint8(0.1)) | |||
| s.dtype = mgb.dtype.qint8(0.1) | |||
| check_dtype_value(s, 0.1, 18) | |||
| s.set_dtype("float32") | |||
| s.dtype = "float32" | |||
| check_dtype_value(s, 0, 1.8) | |||
| def test_tensor_name(): | |||
| p = mge.Parameter(np.ones((3, 4), dtype="float32")) | |||
| assert "shared" in p.name | |||
| with pytest.raises(ValueError): | |||
| p.name = "Parameter0" | |||
| b = mge.Buffer(np.ones((3, 4), dtype="float32")) | |||
| assert "shared" in b.name | |||
| with pytest.raises(ValueError): | |||
| b.name = "Buffer0" | |||
| s = b + 1 | |||
| assert "ADD" in s.name | |||
| s.name = "WeightAdd1" | |||
| assert s.name == "WeightAdd1" | |||