You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

parameter.py 19 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Parameter for cell."""
  16. from copy import copy
  17. from .._c_expression import ParamInfo
  18. from .._c_expression import MetaTensor as MetaTensor_
  19. from . import dtype as mstype
  20. from .initializer import initializer
  21. from .tensor import Tensor, MetaTensor
  22. from .._checkparam import Validator
  23. from ..parallel._tensor import _get_slice_index
  24. from ..parallel._auto_parallel_context import auto_parallel_context
  25. from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched
  26. __all__ = ['Parameter', 'ParameterTuple']
  27. PARAMETER_NAME_DEFAULT = "Parameter"
  28. PARAMETER_NAME_PREFIX_MAX_LEN = 1024
  29. def _is_in_parallel_mode():
  30. """Get parallel mode."""
  31. return auto_parallel_context().get_parallel_mode() in ["semi_auto_parallel", "auto_parallel"]
class Parameter(MetaTensor_):
    """
    Parameter types of cell models.

    After initialization, a `Parameter` is a subtype of `Tensor`.
    In auto_parallel mode of "semi_auto_parallel" and "auto_parallel", if init `Parameter` by
    an `MetaTensor`, the type of Parameter will be `MetaTensor` not `Tensor`. `MetaTensor_`
    only saves the shape and type info of a tensor with no memory usage. The shape can be changed while
    compiling for auto-parallel. Call `init_data` will return a Tensor Parameter with initialized data.

    Note:
        Each parameter of Cell is represented by Parameter class.
        A Parameter has to belong to a Cell.
        If there is an operator in the network that requires part of the inputs to be Parameter,
        then the Parameters as this part of the inputs are not allowed to be cast.

    Args:
        default_input (Union[Tensor, MetaTensor, Number]): Parameter data, to be set initialized.
        name (str): Name of the child parameter.
        requires_grad (bool): True if the parameter requires gradient. Default: True.
        layerwise_parallel (bool): A kind of model parallel mode. When layerwise_parallel is true in parallel mode,
            broadcast and gradients communication would not be applied to parameters. Default: False.

    Example:
        >>> from mindspore import Parameter, Tensor
        >>> from mindspore.common import initializer as init
        >>> from mindspore.ops import operations as P
        >>> from mindspore.nn import Cell
        >>> import mindspore
        >>> import numpy as np
        >>> from mindspore import context
        >>>
        >>> class Net(Cell):
        >>>     def __init__(self):
        >>>         super(Net, self).__init__()
        >>>         self.matmul = P.MatMul()
        >>>         self.weight = Parameter(Tensor(np.ones((1,2))), name="w", requires_grad=True)
        >>>
        >>>     def construct(self, x):
        >>>         out = self.matmul(self.weight, x)
        >>>         return out
        >>> context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
        >>> net = Net()
        >>> x = Tensor(np.ones((2,1)))
        >>> net(x)
        [[2.]]
        >>> net.weight.set_data(Tensor(np.zeros((1,2))))
        >>> net(x)
        [[0.]]
    """
    # Cache of dynamically generated `Parameter<Tensor|MetaTensor_>` subclasses,
    # keyed by generated class name (see `_get_base_class`).
    __base_type__ = {}

    def __new__(cls, default_input, name, *args, **kwargs):
        # Pick the concrete storage class (Tensor or MetaTensor_) from the kind of
        # `default_input`, then build/reuse a Parameter subclass of it so the
        # resulting object is both a Parameter and a tensor.
        input_class, *class_init_args = Parameter._get_parameter_new_args(default_input)
        new_type = Parameter._get_base_class(input_class)
        obj = input_class.__new__(new_type)
        input_class.__init__(obj, *class_init_args)
        # it's better to make the Initializer a kind of metatensor.
        obj.init_mode = None
        obj.is_default_input_meta = False
        if isinstance(default_input, MetaTensor):
            obj.is_default_input_meta = True
        # When the object is not a real Tensor (data not materialized yet),
        # remember the initializer so `init_data` can materialize it later.
        if not isinstance(obj, Tensor):
            obj.init_mode = default_input
        return obj

    def __reduce_ex__(self, _):
        """Support pickling/deepcopy: rebuild from the initializer if data is not
        materialized, otherwise from a plain Tensor copy of the data."""
        data = self
        if self.init_mode is not None:
            data = self.init_mode
        else:
            # cast to break deep infinite loop while deepcopy
            data = Tensor(self)
        return (
            Parameter, (data, self.name, self.requires_grad, self.layerwise_parallel))

    def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False):
        # `default_input` was already consumed by `__new__`; here only the
        # bookkeeping/metadata is set up.
        self._param_info = ParamInfo()
        self.name = name
        self.requires_grad = requires_grad
        self.layerwise_parallel = layerwise_parallel
        # this flag for tensor copy data.
        self.init_flag = False
        # this flag is for ge variable copy data.
        self._is_init = False
        # Set by `init_data` when it produces a new Parameter object.
        self._inited_param = None
        self._sliced = False
        # Parameter-server related flags (see `set_param_ps`).
        self.is_param_ps = False
        self._cast_type = None
        self.init_in_server = False
        # Record the parallel mode at creation time, checked again in `init_data`.
        self.is_in_parallel = _is_in_parallel_mode()

    @staticmethod
    def _get_base_class(input_class):
        """Get (or create and cache) the Parameter subclass mixing in `input_class`."""
        input_class_name = f'Parameter{input_class.__name__}'
        if input_class_name in Parameter.__base_type__:
            new_type = Parameter.__base_type__[input_class_name]
        else:
            new_type = type(input_class_name, (Parameter, input_class), {})
            Parameter.__base_type__[input_class_name] = new_type
        return new_type

    @staticmethod
    def _get_parameter_new_args(data):
        """Map `data` to (storage class, *constructor args) for `__new__`."""
        if isinstance(data, bool):
            raise ValueError('Parameter data can not be `bool`')
        if isinstance(data, MetaTensor):
            if _is_in_parallel_mode() or _is_role_worker():
                # do not init data while in auto parallel.
                return (MetaTensor_, data.dtype, data.shape)
            data = data.to_tensor()
        if isinstance(data, Tensor):
            # make a copy of Tensor to init the parameter
            return (Tensor, data.asnumpy(),)
        if isinstance(data, int):
            return (Tensor, data, mstype.int32)
        if isinstance(data, float):
            return (Tensor, data, mstype.float32)
        return (Tensor, data)

    def __str__(self):
        value_str = MetaTensor.__str__(self)
        if isinstance(self, Tensor):
            value_str = Tensor.__str__(self)
        return f'Parameter (name={self._param_info.name}, value={value_str})'

    def __repr__(self):
        value_str = MetaTensor.__repr__(self)
        if isinstance(self, Tensor):
            value_str = Tensor.__repr__(self)
        return f'Parameter (name={self._param_info.name}, value={value_str})'

    def __parameter__(self):
        """For parse check."""

    def set_param_ps(self, init_in_server=False):
        """Mark this parameter to be stored/updated on a parameter server.

        Args:
            init_in_server (bool): Initialize the parameter on the server side.
                Only supported for names ending with "embedding_table".

        Raises:
            RuntimeError: If not running under a PS role, or `init_in_server`
                is requested for a non-embedding-table parameter.
        """
        if _is_role_worker() or _is_role_pserver() or _is_role_sched():
            if init_in_server and (not self.name.endswith("embedding_table")):
                raise RuntimeError("Can not initialize parameter '{}' in server, only parameters of \
sparse operator support initialization in server.".format(self.name))
            self.is_param_ps = True
            self.init_in_server = init_in_server
            self._param_info.init_in_server = init_in_server
        else:
            raise RuntimeError("Must complete following two steps before calling set_param_ps: \
1. set_ps_context(enable_ps=True) \
2. export MS_ROLE environment variable.")

    @property
    def inited_param(self):
        """
        Get the new parameter after call the init_data.

        Default is a None, If `self` is a Parameter without data, after call the
        `init_data` the initialized Parameter with data will be recorded here.
        """
        return self._inited_param

    @property
    def name(self):
        """Get the name of the parameter."""
        return self._param_info.name

    @name.setter
    def name(self, name_):
        """
        Define a name for the parameter.

        Args:
            name_ (`str` or `None`): The name of the parameter. When the parameter is None or an empty string,
                the default value `PARAMETER_NAME_DEFAULT` is used.
        """
        if name_ is None:
            name_ = PARAMETER_NAME_DEFAULT
        elif isinstance(name_, str):
            name_ = name_.strip()
            if name_ == '':
                name_ = PARAMETER_NAME_DEFAULT
            if len(name_) > PARAMETER_NAME_PREFIX_MAX_LEN:
                raise ValueError("The length of the '{}' name should be less than {}.".
                                 format(name_, PARAMETER_NAME_PREFIX_MAX_LEN))
        else:
            raise ValueError("The type of the name should be `str` or `None`.")
        self._param_info.name = name_

    @property
    def sliced(self):
        """Get slice status of the parameter."""
        return self._sliced

    @sliced.setter
    def sliced(self, sliced_):
        self._sliced = sliced_

    @property
    def is_init(self):
        """
        Get the initialization status of the parameter.

        In GE backend, the Parameter need a "init graph" to sync the data from host to device.
        This flag indicates whether the data has been synced to the device.
        This flag only works in GE, and it will be set to False in other backends.
        """
        return self._is_init

    @is_init.setter
    def is_init(self, is_init_):
        """
        Set init status of the parameter.

        Args:
            is_init_ (bool): The init status of the parameter.
        """
        self._is_init = is_init_

    def clone(self, prefix, init='same'):
        """
        Clone the parameter.

        Args:
            prefix (str): Namespace of parameter. The cloned Parameter name is
                combined of prefix and current name: `f"{prefix}.{self.name}"`.
            init (Union[Tensor, str, MetaTensor, numbers.Number]): Initialize the shape of the parameter.
                Default: 'same'.

        Returns:
            Parameter, a new parameter.
        """
        Validator.check_str_by_regular(prefix)
        x = copy(self)
        # pylint: disable=protected-access
        x._param_info = self._param_info.clone()
        x._param_info.name = prefix + '.' + self._param_info.name
        x.is_init = False
        x.is_param_ps = self.is_param_ps
        x.init_in_server = self.init_in_server
        if init != 'same':
            shape = self.shape
            dtype = self.dtype
            x.set_data(initializer(init, shape=shape, dtype=dtype))
        return x

    @property
    def layerwise_parallel(self):
        return self._param_info.layerwise_parallel

    @layerwise_parallel.setter
    def layerwise_parallel(self, value=True):
        if not isinstance(value, bool):
            raise TypeError("`layerwise_parallel` parameter must be bool type")
        self._param_info.layerwise_parallel = value

    @property
    def requires_grad(self):
        """Return whether the parameter requires gradient."""
        return self._param_info.requires_grad

    @requires_grad.setter
    def requires_grad(self, value=True):
        if not isinstance(value, bool):
            raise TypeError("`requires_grad` parameter must be bool type")
        self._param_info.requires_grad = value

    @property
    def data(self):
        return self

    def _update_tensor_data(self, data):
        "Update the parameter by a Tensor."
        if isinstance(self, Tensor):
            # for Tensor same shape:
            self.init_flag = False
            return self.assign_value(data)
        # create a new tensor
        return Parameter(data, self.name, self.requires_grad)

    def set_data(self, data, slice_shape=False):
        """
        Set `set_data` of current `Parameter`.

        Args:
            data (Union[Tensor, MetaTensor, int, float]): new data.
            slice_shape (bool): If slice the Parameter, will not check if shape is match. Default: False.

        Returns:
            Parameter, the parameter after set data.
        """
        def raise_type_error(incoming):
            raise TypeError(f"Incoming Parameter dtype can not be converted to current dtype implicitly. "
                            f"Current dtype is {self.dtype}, and incoming is {incoming}. "
                            f"Use .set_dtype(xxx) to change the dtype.")

        if not isinstance(data, (MetaTensor_, int, float)):
            raise TypeError(f"Parameter data must be [`MetaTensor`, `int`, `float`] or a kind of `MetaTensor_` "
                            f"(like `Tensor` or `MetaTensor_`). But with type {type(data)}.")
        if isinstance(data, (int, float)):
            # Scalars may only be narrowed into the current dtype, never widen it.
            if self.dtype in mstype.int_type and isinstance(data, float):
                raise_type_error(mstype.float_)
            data = Tensor(data, self.dtype)
        # both not init.
        is_incoming_tensor = isinstance(data, Tensor)
        is_current_tensor = isinstance(self, Tensor)
        if is_incoming_tensor and not is_current_tensor:
            raise TypeError("Parameter is a `MetaTensor_` and not initializered, `data` for `set_data`"
                            "should be a MetaTensor. If you want to update it by Tensor, call method"
                            "`init_parameters_data` of `Cell` to init and replace all the Parameter of"
                            "network, then call this method.")
        if tuple(self.shape) != tuple(data.shape):
            # If Slice create Parameter shape can be change.
            if not slice_shape:
                raise ValueError(f"Can not change the shape of Parameter which has been initialized."
                                 f" Current shape is {self.shape}, and incoming is {data.shape}.")
        if self.dtype != data.dtype:
            # Only implicit widening toward the current dtype is allowed.
            if mstype.implicit_conversion_seq[self.dtype] < mstype.implicit_conversion_seq[data.dtype]:
                raise_type_error(data.dtype)
            else:
                data = Tensor(data, self.dtype)
        if isinstance(data, MetaTensor):
            # The parameter has been initialized, directly update by the data
            if is_current_tensor:
                self._update_tensor_data(data.to_tensor())
            else:
                # also update the related inited parameter data
                if self.inited_param is not None:
                    self.inited_param.set_data(data)
                self.init_mode = data
        elif is_incoming_tensor or is_current_tensor:
            self._update_tensor_data(data)
        else:
            raise ValueError(f"Not support to update the Parameter by {data}")
        self.sliced = slice_shape
        return self

    def init_data(self, layout=None, set_sliced=False):
        """
        Initialize the parameter data.

        Args:
            layout (list[list[int]]): Parameter slice layout [dev_mat, tensor_map, slice_shape].

                - dev_mat (list[int]): Device matrix.
                - tensor_map (list[int]): Tensor map.
                - slice_shape (list[int]): Shape of slice.
            set_sliced (bool): True if the parameter is set sliced after initializing the data.
                Default: False.

        Raises:
            RuntimeError: If it is from Initializer, and parallel mode has changed after the Initializer created.

        Returns:
            Parameter, the `Parameter` after initializing data. If current `Parameter` was already initialized before,
            returns the same initialized `Parameter`.
        """
        if self.is_default_input_meta:
            is_current_in_parallel = _is_in_parallel_mode()
            if self.is_in_parallel != is_current_in_parallel:
                raise RuntimeError("Must set or change parallel mode before any MetaTensor created.")
        if self.init_mode is None:
            return self
        if self.inited_param is not None:
            return self.inited_param
        if layout is not None:
            if not isinstance(layout, tuple):
                raise TypeError("The layout should be tuple! layout is {}.".format(layout))
            if len(layout) < 3:
                # NOTE(review): the message says "larger than 3" but the check
                # actually requires len(layout) >= 3 — message looks misleading.
                raise ValueError("The length of layout must be larger than 3! layout is {}.".format(layout))
            slice_index = int(_get_slice_index(layout[0], layout[1]))
            if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, MetaTensor)):
                # On a worker with server-side init, only a 1-element placeholder
                # is materialized locally; the server holds the real data.
                if _is_role_worker():
                    data = self.init_mode.to_tensor(0, [1])
                else:
                    data = self.init_mode.to_tensor(slice_index, layout[2])
            else:
                data = self.init_mode.to_tensor(slice_index, layout[2])
        else:
            if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, MetaTensor)):
                if _is_role_worker():
                    data = self.init_mode.to_tensor(0, [1])
                else:
                    data = self.init_mode.to_tensor()
            else:
                data = self.init_mode.to_tensor()
        obj = self._update_tensor_data(data)
        # `_update_tensor_data` may return a new Parameter object when `self`
        # was still a MetaTensor_; record it so later calls can reuse it.
        if id(obj) != id(self):
            self._inited_param = obj
        obj.init_mode = None
        obj.sliced = set_sliced
        return obj
  379. class ParameterTuple(tuple):
  380. """
  381. Class for storing tuple of parameters.
  382. Note:
  383. It is used to store the parameters of the network into the parameter tuple collection.
  384. """
  385. def __new__(cls, iterable):
  386. """Create instance object of ParameterTuple."""
  387. data = tuple(iterable)
  388. for x in data:
  389. if not isinstance(x, Parameter):
  390. raise TypeError(f"ParameterTuple input should be `Parameter` collection."
  391. f"But got a {type(iterable)}, {iterable}")
  392. return tuple.__new__(ParameterTuple, tuple(data))
  393. def clone(self, prefix, init='same'):
  394. """
  395. Clone the parameter.
  396. Args:
  397. prefix (str): Namespace of parameter.
  398. init (str): Initialize the shape of the parameter. Default: 'same'.
  399. Returns:
  400. Tuple, the new Parameter tuple.
  401. """
  402. Validator.check_str_by_regular(prefix)
  403. new = []
  404. for x in self:
  405. x1 = x.clone(prefix, init)
  406. new.append(x1)
  407. return ParameterTuple(new)
  408. def __parameter_tuple__(self):
  409. """For parse check."""