From 7c4b7203b0c87b0e509b94370dcbc8a7656a5f4a Mon Sep 17 00:00:00 2001 From: buxue Date: Tue, 26 Jan 2021 12:21:45 +0800 Subject: [PATCH] add dtype shape and value in __str__ and __repr__ of Parameter --- mindspore/common/parameter.py | 40 +++++++------------ ...st_outermost_net_pass_non_tensor_inputs.py | 3 +- 2 files changed, 17 insertions(+), 26 deletions(-) diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py index 4779ed8107..cba3d9232b 100644 --- a/mindspore/common/parameter.py +++ b/mindspore/common/parameter.py @@ -201,10 +201,11 @@ class Parameter(Tensor_): return (Tensor, data) def __str__(self): - return f'Parameter (name={self._param_info.name})' + return f'Parameter (name={self.name}, shape={self.shape}, dtype={self.dtype}, ' \ + f'requires_grad={self.requires_grad})' def __repr__(self): - return f'Parameter (name={self._param_info.name})' + return self.__str__() def __parameter__(self): """For parse check.""" @@ -242,7 +243,6 @@ class Parameter(Tensor_): """ return self._inited_param - @property def name(self): """Get the name of the parameter.""" @@ -501,10 +501,8 @@ class Parameter(Tensor_): Parameter, the `Parameter` after initializing data. If current `Parameter` was already initialized before, returns the same initialized `Parameter`. """ - if self.is_default_input_init: - is_current_in_parallel = _is_in_parallel_mode() - if self.is_in_parallel != is_current_in_parallel: - raise RuntimeError("Must set or change parallel mode before any Tensor created.") + if self.is_default_input_init and self.is_in_parallel != _is_in_parallel_mode(): + raise RuntimeError("Must set or change parallel mode before any Tensor created.") if self.init_mode is None: return self if self.inited_param is not None: @@ -512,29 +510,21 @@ class Parameter(Tensor_): if _is_role_worker() and self.cache_enable: global_seed, op_seed = _get_global_and_op_seed() _insert_weight_init_info(self.name, global_seed, op_seed) + + init_data_args = () if layout is not None: if not isinstance(layout, tuple): - raise TypeError("The layout should be tuple! layout is {}.".format(layout)) + raise TypeError("The layout should be tuple, but got layout is {}.".format(layout)) if len(layout) < 3: - raise ValueError("The length of layout must be larger than 3! layout is {}.".format(layout)) + raise ValueError("The length of layout must be larger than 2, but got layout is {}.".format(layout)) slice_index = int(_get_slice_index(layout[0], layout[1])) - if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, Tensor) - and self.init_mode.init is not None): - if _is_role_worker() or _is_role_sched(): - data = self.init_mode.init_data(0, [1]) - else: - data = self.init_mode.init_data(slice_index, layout[2], layout[5]) - else: - data = self.init_mode.init_data(slice_index, layout[2], layout[5]) + init_data_args += (slice_index, layout[2], layout[5]) + + if self.init_in_server and self.is_param_ps and isinstance(self.init_mode, Tensor) and \ + self.init_mode.init is not None and (_is_role_worker() or _is_role_sched()): + data = self.init_mode.init_data(0, [1]) else: - if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, Tensor) - and self.init_mode.init is not None): - if _is_role_worker() or _is_role_sched(): - data = self.init_mode.init_data(0, [1]) - else: - data = self.init_mode.init_data() - else: - data = self.init_mode.init_data() + data = self.init_mode.init_data(*init_data_args) obj = self._update_tensor_data(data) if id(obj) != id(self): diff --git a/tests/ut/python/pipeline/parse/test_outermost_net_pass_non_tensor_inputs.py b/tests/ut/python/pipeline/parse/test_outermost_net_pass_non_tensor_inputs.py index aaede8a5a8..d1c2b8cebb 100644 --- a/tests/ut/python/pipeline/parse/test_outermost_net_pass_non_tensor_inputs.py +++ b/tests/ut/python/pipeline/parse/test_outermost_net_pass_non_tensor_inputs.py @@ -93,7 +93,8 @@ def test_outermost_net_pass_parameter(): assert "The inputs types of the outermost network support bool, int, float, tensor, " \ "mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \ "and tuple or list containing only these types, and dict whose values are these types, " \ - "but got 1th arg is Parameter (name=weight)" in str(err.value) + "but got 1th arg is Parameter (name=weight, shape=(2, 2), dtype=Float32, requires_grad=True)" \ + in str(err.value) def test_outermost_net_pass_tuple_including_parameter():