@@ -201,10 +201,11 @@ class Parameter(Tensor_):
         return (Tensor, data)
 
     def __str__(self):
-        return f'Parameter (name={self._param_info.name})'
+        return f'Parameter (name={self.name}, shape={self.shape}, dtype={self.dtype}, ' \
+               f'requires_grad={self.requires_grad})'
 
     def __repr__(self):
-        return f'Parameter (name={self._param_info.name})'
+        return self.__str__()
 
     def __parameter__(self):
         """For parse check."""
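A usage sketch (not part of the patch), assuming a MindSpore build that carries the new `__str__`/`__repr__`; the exact shape and dtype rendering follows the underlying Tensor:

import numpy as np
from mindspore import Parameter, Tensor

# requires_grad defaults to True.
w = Parameter(Tensor(np.ones((2, 3), dtype=np.float32)), name='w')

print(str(w))   # e.g. Parameter (name=w, shape=(2, 3), dtype=Float32, requires_grad=True)
print(repr(w))  # __repr__ now delegates to __str__, so both lines match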
@@ -242,7 +243,6 @@ class Parameter(Tensor_):
         """
         return self._inited_param
 
-
     @property
     def name(self):
         """Get the name of the parameter."""
@@ -501,10 +501,8 @@ class Parameter(Tensor_):
             Parameter, the `Parameter` after initializing data. If current `Parameter` was already initialized before,
             returns the same initialized `Parameter`.
         """
-        if self.is_default_input_init:
-            is_current_in_parallel = _is_in_parallel_mode()
-            if self.is_in_parallel != is_current_in_parallel:
-                raise RuntimeError("Must set or change parallel mode before any Tensor created.")
+        if self.is_default_input_init and self.is_in_parallel != _is_in_parallel_mode():
+            raise RuntimeError("Must set or change parallel mode before any Tensor created.")
         if self.init_mode is None:
             return self
         if self.inited_param is not None:
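An equivalence note (not part of the patch): the flattened guard relies on Python's short-circuiting `and`, so `_is_in_parallel_mode()` is still evaluated only when `self.is_default_input_init` holds, exactly as in the nested form it replaces. A standalone sketch with stand-in values:

def check_parallel_mode(is_default_input_init, is_in_parallel, currently_parallel):
    # Short-circuit `and`: the mode comparison runs only when the first flag
    # is set, matching the behavior of the removed nested `if`.
    if is_default_input_init and is_in_parallel != currently_parallel:
        raise RuntimeError("Must set or change parallel mode before any Tensor created.")

check_parallel_mode(True, False, False)  # modes agree: no error
check_parallel_mode(False, False, True)  # guard skipped entirely: no error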
@@ -512,29 +510,21 @@ class Parameter(Tensor_):
         if _is_role_worker() and self.cache_enable:
             global_seed, op_seed = _get_global_and_op_seed()
             _insert_weight_init_info(self.name, global_seed, op_seed)
 
+        init_data_args = ()
         if layout is not None:
             if not isinstance(layout, tuple):
-                raise TypeError("The layout should be tuple! layout is {}.".format(layout))
+                raise TypeError("The layout should be tuple, but got layout is {}.".format(layout))
             if len(layout) < 3:
-                raise ValueError("The length of layout must be larger than 3! layout is {}.".format(layout))
+                raise ValueError("The length of layout must be larger than 2, but got layout is {}.".format(layout))
             slice_index = int(_get_slice_index(layout[0], layout[1]))
-            if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, Tensor)
-                    and self.init_mode.init is not None):
-                if _is_role_worker() or _is_role_sched():
-                    data = self.init_mode.init_data(0, [1])
-                else:
-                    data = self.init_mode.init_data(slice_index, layout[2], layout[5])
-            else:
-                data = self.init_mode.init_data(slice_index, layout[2], layout[5])
+            init_data_args += (slice_index, layout[2], layout[5])
+
+        if self.init_in_server and self.is_param_ps and isinstance(self.init_mode, Tensor) and \
+                self.init_mode.init is not None and (_is_role_worker() or _is_role_sched()):
+            data = self.init_mode.init_data(0, [1])
         else:
-            if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, Tensor)
-                    and self.init_mode.init is not None):
-                if _is_role_worker() or _is_role_sched():
-                    data = self.init_mode.init_data(0, [1])
-                else:
-                    data = self.init_mode.init_data()
-            else:
-                data = self.init_mode.init_data()
+            data = self.init_mode.init_data(*init_data_args)
+
         obj = self._update_tensor_data(data)
         if id(obj) != id(self):
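A sketch of the refactor's pattern (hypothetical names, not the MindSpore API): the four near-duplicate `init_data(...)` call sites collapse into one call by building the slice arguments as a tuple that stays empty when no layout is given, then unpacking it with `*`:

class FakeInit:
    # Stand-in for self.init_mode: accepts both the full- and sliced-init signatures.
    def init_data(self, slice_index=None, shape=None, opt_shard_group=None):
        if slice_index is None:
            return 'full tensor'
        return f'slice {slice_index} with shape {shape} (group={opt_shard_group})'

def init_data(init_mode, layout=None):
    init_data_args = ()
    if layout is not None:
        # Collect the slice arguments once; the single call below covers both branches.
        init_data_args += (layout[0], layout[2], layout[5])
    return init_mode.init_data(*init_data_args)

print(init_data(FakeInit()))                                    # full tensor
print(init_data(FakeInit(), layout=(0, 1, (2, 3), 3, 4, 'g')))  # sliced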