@@ -280,15 +280,23 @@ class Parameter(MetaTensor):
         Set `default_input` of current `Parameter`.
 
         Args:
-            data (Union[Tensor, Initializer]): new data.
-            slice_shape (bool): If slice the Parameter. Default: False.
+            data (Union[Tensor, Initializer, int, float]): new data.
+            slice_shape (bool): If slice the Parameter, will not check if shape is match. Default: False.
 
         Retruns:
             Parameter, the parameter after set data.
         """
-        if not isinstance(data, (MetaTensor, Initializer)):
-            raise ValueError(f"Parameter data must be `Initializer` or a kind of `MetaTensor` "
-                             f"(like `Tensor` or `MetaTensor`). But with type {type(data)}.")
+        def raise_type_error(incoming):
+            raise TypeError(f"Can not change the Parameter dtype. Current dtype is {self.set_dtype}"
+                            f", and incoming is {incoming}. Use .set_dtype(xxx) to change the dtype.")
+
+        if not isinstance(data, (MetaTensor, Initializer, int, float)):
+            raise TypeError(f"Parameter data must be [`Initializer`, `int`, `float`] or a kind of `MetaTensor` "
+                            f"(like `Tensor` or `MetaTensor`). But with type {type(data)}.")
+        if isinstance(data, (int, float)):
+            if self.dtype in mstype.int_type and isinstance(data, float):
+                raise_type_error(mstype.float_)
+            data = Tensor(data, self.dtype)
         # both not init.
         is_incoming_tensor = isinstance(data, Tensor)
         is_current_tensor = isinstance(self, Tensor)
@@ -300,25 +308,25 @@ class Parameter(MetaTensor):
                             "network, then call this method.")
         if tuple(self.shape) != tuple(data.shape):
             # If Slice create Parameter shape can be change.
-            if slice_shape:
-                self._update_tensor_data(data)
-                self.sliced = True
-            else:
+            if not slice_shape:
                 raise ValueError(f"Can not change the shape of Parameter which has been initialized."
                                  f" Current shape is {self.shape}, and incoming is {data.shape}.")
         if self.dtype != data.dtype:
-            raise ValueError(f"Can not change the Parameter dtype. Current dtype is {self.set_dtype}"
-                             f", and incoming is {data.dtype}. Use .set_dtype(xxx) to change the dtype.")
+            raise_type_error(data.dtype)
         if isinstance(data, Initializer):
             # The parameter has been initializered, directly update by the data
             if is_current_tensor:
                 self._update_tensor_data(data.to_tensor())
             else:
+                # also update the related inited parameter data
+                if self.inited_param is not None:
+                    self.inited_param.set_parameter_data(data)
                 self.init_mode = data
         elif is_incoming_tensor or is_current_tensor:
             self._update_tensor_data(data)
         else:
             raise ValueError(f"Not support to update the Parameter by {data}")
+        self.sliced = slice_shape
         return self
 
     def init_data(self, layout=None, set_sliced=False):
@@ -340,8 +348,6 @@ class Parameter(MetaTensor):
         """
         if self.init_mode is None:
             return self
-        if self.inited_param is not None:
-            return self.inited_param
         if layout is not None:
             if not isinstance(layout, list):
                 raise TypeError("The layout should be list! layout is {}.".format(layout))
@@ -362,8 +368,7 @@ class Parameter(MetaTensor):
         if id(obj) != id(self):
             self._inited_param = obj
         obj.init_mode = None
-        if set_sliced:
-            obj.sliced = True
+        obj.sliced = set_sliced
         return obj
 
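
Usage note (illustrative, not part of the diff): a rough sketch of what the new int/float path in `set_parameter_data` allows, assuming the surrounding MindSpore-style `Parameter`/`Tensor` API; the scalar parameter names below are hypothetical.

    from mindspore import Tensor, Parameter
    from mindspore import dtype as mstype

    # A scalar (shape ()) float32 parameter, e.g. a loss-scale value.
    loss_scale = Parameter(Tensor(1024.0, mstype.float32), name="loss_scale")

    # A plain float is now accepted: it is wrapped into Tensor(data, self.dtype)
    # before the existing shape and dtype checks run.
    loss_scale.set_parameter_data(512.0)

    # A float pushed into an integer-typed parameter goes through raise_type_error
    # and fails with a TypeError instead of being silently truncated.
    global_step = Parameter(Tensor(0, mstype.int32), name="global_step")
    try:
        global_step.set_parameter_data(1.5)
    except TypeError as err:
        print(err)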