|
|
|
@@ -128,6 +128,12 @@ class Parameter(MetaTensor_): |
|
|
|
self.init_in_server = False |
|
|
|
self._unique = False |
|
|
|
self.is_in_parallel = _is_in_parallel_mode() |
|
|
|
if isinstance(default_input, (MetaTensor, Tensor)): |
|
|
|
MetaTensor_.__init__(self, default_input.dtype, default_input.shape) |
|
|
|
elif isinstance(default_input, int): |
|
|
|
MetaTensor_.__init__(self, mstype.int64, ()) |
|
|
|
elif isinstance(default_input, float): |
|
|
|
MetaTensor_.__init__(self, mstype.float32, ()) |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def _get_base_class(input_class): |
|
|
|
|