diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py index d40a4b4abe..4456782378 100644 --- a/mindspore/common/parameter.py +++ b/mindspore/common/parameter.py @@ -128,6 +128,12 @@ class Parameter(MetaTensor_): self.init_in_server = False self._unique = False self.is_in_parallel = _is_in_parallel_mode() + if isinstance(default_input, (MetaTensor, Tensor)): + MetaTensor_.__init__(self, default_input.dtype, default_input.shape) + elif isinstance(default_input, int): + MetaTensor_.__init__(self, mstype.int64, ()) + elif isinstance(default_input, float): + MetaTensor_.__init__(self, mstype.float32, ()) @staticmethod def _get_base_class(input_class):