You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

parameter.py 30 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
code check for master # Conflicts: # mindspore/common/initializer.py # mindspore/nn/cell.py # # 似乎您正在做一个拣选提交。如果不对,请删除文件 # .git/CHERRY_PICK_HEAD # 然后重试。 # 请为您的变更输入提交说明。以 '#' 开始的行将被忽略,而一个空的提交 # 说明将会终止提交。 # # 日期: Fri Aug 13 18:40:19 2021 +0800 # # 位于分支 code_review_r1.3 # 您的分支与上游分支 'ma/r1.3' 一致。 # # 您在执行拣选提交 ffda6be35c 的操作。 # # 要提交的变更: # 修改: mindspore/common/__init__.py # 修改: mindspore/common/_register_for_tensor.py # 修改: mindspore/common/api.py # 修改: mindspore/common/dtype.py # 修改: mindspore/common/initializer.py # 修改: mindspore/common/monad.py # 修改: mindspore/common/parameter.py # 修改: mindspore/common/seed.py # 修改: mindspore/common/tensor.py # 修改: mindspore/nn/cell.py # 修改: mindspore/nn/metrics/__init__.py # 修改: mindspore/nn/metrics/confusion_matrix.py # 修改: mindspore/nn/metrics/error.py # 修改: mindspore/nn/metrics/fbeta.py # 修改: mindspore/nn/metrics/loss.py # 修改: mindspore/nn/metrics/metric.py # 修改: mindspore/nn/metrics/precision.py # 修改: mindspore/nn/metrics/recall.py # 修改: mindspore/nn/metrics/topk.py # 修改: mindspore/train/callback/_checkpoint.py # 修改: mindspore/train/model.py # 修改: mindspore/train/serialization.py # # Conflicts: # mindspore/common/api.py # mindspore/common/initializer.py # mindspore/nn/metrics/confusion_matrix.py # # 似乎您正在做一个拣选提交。如果不对,请删除文件 # .git/CHERRY_PICK_HEAD # 然后重试。 # 请为您的变更输入提交说明。以 '#' 开始的行将被忽略,而一个空的提交 # 说明将会终止提交。 # # 日期: Fri Aug 13 18:40:19 2021 +0800 # # 位于分支 code_review_master # 您的分支与上游分支 'ma/master' 一致。 # # 您在执行拣选提交 743f9fbff3 的操作。 # # 要提交的变更: # 修改: mindspore/common/__init__.py # 修改: mindspore/common/_monad.py # 修改: mindspore/common/_register_for_tensor.py # 修改: mindspore/common/api.py # 修改: mindspore/common/dtype.py # 修改: mindspore/common/initializer.py # 修改: mindspore/common/parameter.py # 修改: mindspore/common/seed.py # 修改: mindspore/common/tensor.py # 修改: mindspore/nn/cell.py # 修改: mindspore/nn/metrics/__init__.py # 修改: mindspore/nn/metrics/confusion_matrix.py # 修改: mindspore/nn/metrics/error.py # 修改: mindspore/nn/metrics/fbeta.py # 
修改: mindspore/nn/metrics/loss.py # 修改: mindspore/nn/metrics/metric.py # 修改: mindspore/nn/metrics/precision.py # 修改: mindspore/nn/metrics/recall.py # 修改: mindspore/nn/metrics/topk.py # 修改: mindspore/train/callback/_checkpoint.py # 修改: mindspore/train/model.py # 修改: mindspore/train/serialization.py #
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Parameter for cell."""
  16. from copy import copy
  17. import numbers
  18. import numpy as np
  19. from mindspore import log as logger
  20. from .._c_expression import ParamInfo
  21. from . import dtype as mstype
  22. from .. import context
  23. from ..parallel._utils import _get_parallel_mode
  24. from .initializer import initializer
  25. from .tensor import Tensor
  26. from .._checkparam import Validator
  27. from .._c_expression import Tensor as Tensor_
  28. from ..parallel._tensor import _get_slice_index
  29. from ..parallel._auto_parallel_context import auto_parallel_context
  30. from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched, _clone_hash_table
  31. from ..parallel._ps_context import _reinsert_hash_table_size
  32. from ..parallel._ps_context import _insert_weight_init_info, _insert_accumu_init_info
  33. from .seed import _get_global_and_op_seed
__all__ = ['Parameter', 'ParameterTuple']

# Fallback name used when a Parameter is created without an explicit name;
# the framework later replaces it (e.g. with the attribute name in a Cell).
PARAMETER_NAME_DEFAULT = "Parameter"
# Upper bound on the length of a user-supplied parameter name.
PARAMETER_NAME_PREFIX_MAX_LEN = 1024
  37. def _is_in_parallel_mode():
  38. """Get parallel mode."""
  39. return auto_parallel_context().get_parallel_mode() in ["semi_auto_parallel", "auto_parallel"]
  40. def init_to_value(init):
  41. """
  42. Get value of initializer.
  43. Returns:
  44. Value of the initializer.
  45. Raises:
  46. ValueError: The value of the argument 'init' is not correct.
  47. """
  48. if isinstance(init, str):
  49. if init == 'zeros':
  50. return 0.0
  51. if init == 'ones':
  52. return 1.0
  53. raise ValueError("The argument 'init' should be one of values in ['zeros', 'ones'].")
  54. if isinstance(init, numbers.Number):
  55. return float(init)
  56. raise ValueError("The argument 'init' should be number or string, but got {}.".format(type(init)))
class Parameter(Tensor_):
    """
    An object holding weights of cells, after initialized `Parameter` is a subtype of `Tensor`.

    Note:
        In auto_parallel mode of "semi_auto_parallel" and "auto_parallel", if init `Parameter` by
        a `Tensor`, the type of Parameter will be `Tensor`. `Tensor`
        will save the shape and type info of a tensor with no memory usage. The shape can be changed while
        compiling for auto-parallel. Call `init_data` will return a Tensor Parameter with initialized data.
        If there is an operator in the network that requires part of the inputs to be Parameter,
        then the Parameters as this part of the inputs are not allowed to be cast.
        It is recommended to use the default value of `name` when initialize a parameter as one attribute of a cell,
        otherwise, the parameter name may be different from expected.

    Args:
        default_input (Union[Tensor, int, float, numpy.ndarray, list]): Parameter data,
            to initialize the parameter data.
        name (str): Name of the parameter. Default: None.

            1) If the parameter is not given a name, the default name is its variable name. For example, the name of
            param_a below is name_a, and the name of param_b is the variable name param_b.

            .. code-block::

                self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
                self.param_b = Parameter(Tensor([2], ms.float32))

            2) If parameter in list or tuple is not given a name, will give it a unique name. For example, the names of
            parameters below are Parameter$1 and Parameter$2.

            .. code-block::

                self.param_list = [Parameter(Tensor([3], ms.float32)),
                                   Parameter(Tensor([4], ms.float32))]

            3) If the parameter is given a name, and the same name exists between different parameters, an exception
            will be thrown. For example, "its name 'name_a' already exists." will be thrown.

            .. code-block::

                self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
                self.param_tuple = (Parameter(Tensor([5], ms.float32), name="name_a"),
                                    Parameter(Tensor([6], ms.float32)))

            4) If a parameter appear multiple times in list or tuple, check the name of the object only once. For
            example, the following example will not throw an exception.

            .. code-block::

                self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
                self.param_tuple = (self.param_a, self.param_a)

        requires_grad (bool): True if the parameter requires gradient. Default: True.
        layerwise_parallel (bool): When layerwise_parallel is true in data/hybrid parallel mode,
            broadcast and gradients communication would not be applied to parameters. Default: False.
        parallel_optimizer (bool): It is used to filter the weight shard operation in semi auto or auto parallel
            mode. It works only when enable parallel optimizer in `mindspore.context.set_auto_parallel_context()`.
            Default: True.

    Examples:
        >>> import numpy as np
        >>> from mindspore import Parameter, Tensor
        >>> import mindspore.ops as ops
        >>> import mindspore.nn as nn
        >>> import mindspore
        >>>
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.matmul = ops.MatMul()
        ...         self.weight = Parameter(Tensor(np.ones((1, 2)), mindspore.float32), name="w", requires_grad=True)
        ...
        ...     def construct(self, x):
        ...         out = self.matmul(self.weight, x)
        ...         return out
        >>> net = Net()
        >>> x = Tensor(np.ones((2, 1)), mindspore.float32)
        >>> print(net(x))
        [[2.]]
        >>> net.weight.set_data(Tensor(np.zeros((1, 2)), mindspore.float32))
        >>> print(net(x))
        [[0.]]
    """
    # Cache of dynamically-created subclasses, keyed by 'Parameter<BaseName>';
    # see _get_base_class.
    __base_type__ = {}

    def __new__(cls, default_input, *args, **kwargs):
        """Create the Parameter instance on a dynamically chosen base class.

        The actual base is derived from the type of `default_input` (see
        `_get_parameter_new_args`/`_get_base_class`), so the returned object
        is both a Parameter and a Tensor.
        """
        init_data_flag = bool(isinstance(default_input, Tensor) and default_input.has_init)
        input_class, *class_init_args = Parameter._get_parameter_new_args(default_input)
        new_type = Parameter._get_base_class(input_class)
        obj = input_class.__new__(new_type)
        input_class.__init__(obj, *class_init_args)
        # it's better to make the Initializer a kind of tensor.
        obj.init_mode = None
        obj.is_default_input_init = init_data_flag
        if obj.has_init:
            # Keep the original (uninitialized) input so init_data can
            # materialize the real values later.
            obj.init_mode = default_input
        return obj

    def __reduce_ex__(self, _):
        """Support pickling/deepcopy by reconstructing from (data, name, flags)."""
        data = self
        if self.init_mode is not None:
            data = self.init_mode
        else:
            # cast to break deep infinite loop while deepcopy
            data = Tensor(self)
        return (
            Parameter, (data, self.name, self.requires_grad, self.layerwise_parallel))

    def __init__(self, default_input, name=None, requires_grad=True, layerwise_parallel=False, parallel_optimizer=True):
        # ParamInfo is the C++-side metadata object; most properties below
        # delegate to it.
        self.param_info = ParamInfo()
        self.init_in_server = False
        self.cache_enable = False
        self.name = name
        self.requires_grad = requires_grad
        self.layerwise_parallel = layerwise_parallel
        self.parallel_optimizer = parallel_optimizer
        # this flag for tensor copy data.
        self.init_flag = False
        # this flag is for ge variable copy data.
        self.is_init = False
        self._inited_param = None
        self._sliced = False
        self.is_param_ps = False
        self.push_weight_to_server = False
        self.pull_weight_from_server = False
        self.requires_aggr = True
        self._cast_type = None
        self._unique = False
        # Remember the parallel mode at creation time; init_data checks it has
        # not changed since.
        self.is_in_parallel = _is_in_parallel_mode()
        self._pipeline_stage_list = []
        if isinstance(default_input, (Tensor_, Tensor)):
            Tensor_.__init__(self, default_input.dtype, default_input.shape)
        elif isinstance(default_input, int):
            Tensor_.__init__(self, mstype.int64, ())
        elif isinstance(default_input, float):
            Tensor_.__init__(self, mstype.float32, ())
        elif isinstance(default_input, (np.ndarray, list)):
            Tensor_.__init__(self, default_input)
        else:
            raise TypeError(f"The type of the argument 'default_input' must be in ['Tensor', 'int', 'float',"
                            f" 'numpy.ndarray', 'list']. But got type {type(default_input)}.")

    def __deepcopy__(self, memodict):
        """Deep-copy the parameter, preserving its name and inited-param link."""
        new_obj = Parameter(self)
        new_obj.name = self.name
        new_obj._inited_param = self._inited_param  # pylint: disable=W0212
        return new_obj

    @staticmethod
    def _get_base_class(input_class):
        """Return (and cache) a subclass of both Parameter and `input_class`."""
        input_class_name = f'Parameter{input_class.__name__}'
        if input_class_name in Parameter.__base_type__:
            new_type = Parameter.__base_type__[input_class_name]
        else:
            new_type = type(input_class_name, (Parameter, input_class), {})
            Parameter.__base_type__[input_class_name] = new_type
        return new_type

    @staticmethod
    def _get_parameter_new_args(data):
        """Set `set_data` of current `Parameter`."""
        if isinstance(data, bool):
            raise ValueError('Parameter data can not be `bool`')
        if isinstance(data, Tensor) and data.has_init:
            # NOTE(review): federated-learning server modes skip the
            # lazy-init path below — confirm against fl context docs.
            if context.get_fl_context('server_mode') not in ('FEDERATED_LEARNING', 'HYBRID_TRAINING'):
                if _is_in_parallel_mode() or _is_role_worker() or _is_role_sched() or _is_role_pserver():
                    # do not init data while in auto parallel.
                    return (Tensor, None, data.dtype, data.shape, data.init)
            data = data.init_data().asnumpy()
        elif isinstance(data, Tensor):
            # make a copy of Tensor to init the parameter
            return (Tensor, data.asnumpy(),)
        if isinstance(data, int):
            return (Tensor, data, mstype.int32)
        if isinstance(data, float):
            return (Tensor, data, mstype.float32)
        return (Tensor, data)

    def __str__(self):
        return f'Parameter (name={self.name}, shape={self.shape}, dtype={self.dtype}, ' \
               f'requires_grad={self.requires_grad})'

    def __repr__(self):
        return self.__str__()

    def __parameter__(self):
        """For parse check."""

    def set_param_ps(self, init_in_server=False):
        """
        Set whether the trainable parameter is updated by parameter server and whether the
        trainable parameter is initialized on server.

        Note:
            It only works when a running task is in the parameter server mode.

        Args:
            init_in_server (bool): Whether trainable parameter updated by parameter server is
                initialized on server. Default: False.
        """
        if not(_is_role_worker() or _is_role_pserver() or _is_role_sched()):
            raise RuntimeError("Must complete following two steps before calling set_param_ps: \n"
                               "1. context.set_ps_context(enable_ps=True) \n"
                               "2. export MS_ROLE environment variable \n"
                               "Please refer to the official website for detailed usage.")
        self.is_param_ps = True
        self.init_in_server = init_in_server
        self.param_info.init_in_server = init_in_server

    def set_param_fl(self, push_to_server=False, pull_from_server=False, requires_aggr=True):
        """
        Set the way of parameter and server interaction.

        Args:
            push_to_server (bool): Whether the parameter should be pushed to server. Default: False.
            pull_from_server (bool): Whether the parameter should be pulled from server. Default: False.
            requires_aggr (bool): Whether the parameter should be aggregated in the server. Default: True.
        """
        if push_to_server:
            self.push_weight_to_server = True
        if pull_from_server:
            self.pull_weight_from_server = True
        if not requires_aggr:
            self.requires_aggr = False
            self.param_info.requires_aggr = False

    @property
    def inited_param(self):
        """
        Get the new parameter after call the init_data.

        Default is a None, If `self` is a Parameter without data, after call the
        `init_data` the initialized Parameter with data will be recorded here.
        """
        return self._inited_param

    @property
    def name(self):
        """Get the name of the parameter."""
        return self.param_info.name

    @name.setter
    def name(self, name_):
        """
        Define a name for the parameter.

        Args:
            name_ (`str` or `None`): The name of the parameter. When the parameter is None or an empty string,
                the default value `PARAMETER_NAME_DEFAULT` is used.
        """
        if name_ is None:
            name_ = PARAMETER_NAME_DEFAULT
        elif isinstance(name_, str):
            name_ = name_.strip()
            if name_ == '':
                name_ = PARAMETER_NAME_DEFAULT
            if len(name_) > PARAMETER_NAME_PREFIX_MAX_LEN:
                raise ValueError("The length of the '{}' name should be less than {}.".
                                 format(name_, PARAMETER_NAME_PREFIX_MAX_LEN))
        else:
            raise ValueError("The type of the Parameter's name should be 'string' or 'None', "
                             "but got {}.".format(type(name_)))
        # Cache-enabled parameters on a worker keep a server-side hash table
        # keyed by name; renaming must re-register it.
        if _is_role_worker() and self.cache_enable:
            if len(self.shape) != 2:
                raise RuntimeError("The dims of parameter '{}' must be 2, but got {}."
                                   .format(self.name, len(self.shape)))
            _reinsert_hash_table_size(name_, self.param_info.name, self.shape[0], self.shape[1])
        if name_ == PARAMETER_NAME_DEFAULT:
            logger.warning("The parameter definition is deprecated.\n"
                           "Please set a unique name for the parameter '{}'.". format(self))
        self.param_info.name = name_

    @property
    def sliced(self):
        """Get slice status of the parameter."""
        return self._sliced

    @sliced.setter
    def sliced(self, sliced_):
        self._sliced = sliced_

    @property
    def comm_fusion(self):
        """
        Get the fusion type (int) for communication operators corresponding to this parameter.

        In `AUTO_PARALLEL` and `SEMI_AUTO_PARALLEL` mode, some communication operators used for parameters or
        gradients aggregation are inserted automatically. The value of fusion must be greater than or equal to 0.
        When the value of fusion is 0, operators will not be fused together.
        """
        return self.param_info.comm_fusion

    @comm_fusion.setter
    def comm_fusion(self, comm_fusion_):
        if context.get_context("mode") == context.PYNATIVE_MODE and "auto_parallel" in _get_parallel_mode():
            raise RuntimeError(
                "`comm_fusion` does not support PYNATIVE_MODE in AUTO_PARALLEL and SEMI_AUTO_PARALLEL mode.")
        Validator.check_non_negative_int(comm_fusion_)
        self.param_info.comm_fusion = comm_fusion_

    @property
    def parallel_optimizer_comm_recompute(self):
        """
        Get the communication recompute status(bool) of optimizer parallel for the parameter.

        In `AUTO_PARALLEL` and `SEMI_AUTO_PARALLEL` mode, when applying parallel optimizer, some AllGather operators
        used for parameters gathering are inserted automatically. It is used to control the recompute attr for those
        AllGather operators.

        Note:
            - Only `Graph` mode is supported.
            - It is recommended to use cell.recompute(parallel_optimizer_comm_recompute=True/False) to configure
              the AllGather operators introducing by parallel optimizer rather than using this interface directly.
        """
        return self.param_info.parallel_optimizer_comm_recompute

    @parallel_optimizer_comm_recompute.setter
    def parallel_optimizer_comm_recompute(self, parallel_optimizer_comm_recompute_):
        Validator.check_bool(parallel_optimizer_comm_recompute_)
        self.param_info.parallel_optimizer_comm_recompute = parallel_optimizer_comm_recompute_

    @property
    def unique(self):
        """whether the parameter is already unique or not."""
        return self._unique

    @unique.setter
    def unique(self, unique_):
        self._unique = unique_

    def clone(self, init='same'):
        """
        Clone the parameter.

        Args:
            init (Union[Tensor, str, numbers.Number]): Initialize the shape and dtype of the parameter.
                If `init` is a `Tensor` or `numbers.Number`, clone a new parameter with the same shape
                and dtype, and the data of the new parameter will be set according to `init`. If `init`
                is a `str`, the `init` should be the alias of the class inheriting from `Initializer`.
                For example, if `init` is 'same', clone a new parameter with the same data, shape, and
                dtype. Default: 'same'.

        Returns:
            Parameter, a new parameter.
        """
        x = copy(self)
        x.param_info = self.param_info.clone()
        x.is_init = False
        x.init = self.init
        x.is_param_ps = self.is_param_ps
        x.init_in_server = self.init_in_server
        x.cache_enable = self.cache_enable
        x.requires_aggr = self.requires_aggr
        if self.cache_shape:
            x.cache_shape = self.cache_shape
        if init != 'same':
            shape = self.shape
            dtype = self.dtype
            x.set_data(initializer(init, shape=shape, dtype=dtype))
        return x

    @property
    def layerwise_parallel(self):
        """
        Get the layerwise parallel status(bool) of the parameter.

        When layerwise_parallel is true in `DATA_PARALLEL` and `HYBRID_PARALLEL` parallel mode, broadcast and gradients
        communication would not be applied to parameters.
        """
        return self.param_info.layerwise_parallel

    @layerwise_parallel.setter
    def layerwise_parallel(self, value=True):
        if not isinstance(value, bool):
            raise TypeError("The argument `layerwise_parallel` must be bool type.")
        self.param_info.layerwise_parallel = value

    @property
    def parallel_optimizer(self):
        """
        Get the optimizer parallel status(bool) of the parameter.

        It is used to filter the weight shard operation in `AUTO_PARALLEL` and `SEMI_AUTO_PARALLEL` mode. It works only
        when enable parallel optimizer in `mindspore.context.set_auto_parallel_context()`.
        """
        return self.param_info.parallel_optimizer

    @parallel_optimizer.setter
    def parallel_optimizer(self, value=True):
        if not isinstance(value, bool):
            raise TypeError("The argument `parallel_optimizer` must be bool type.")
        self.param_info.parallel_optimizer = value

    @property
    def cache_enable(self):
        """Return whether the parameter is cache enable."""
        return self.param_info.cache_enable

    @cache_enable.setter
    def cache_enable(self, value=True):
        if not isinstance(value, bool):
            raise TypeError("The argument `cache_enable` must be bool type.")
        self.param_info.cache_enable = value

    @property
    def cache_shape(self):
        """Return the cache shape corresponding to the parameter if use cache."""
        return self.param_info.cache_shape

    @cache_shape.setter
    def cache_shape(self, value):
        if not isinstance(value, (tuple, list)):
            raise TypeError("The argument `cache_shape` must be tuple or list type.")
        self.param_info.cache_shape = value

    @property
    def requires_grad(self):
        """
        Return whether the parameter requires gradient.

        The main function of requires_grad is to tell auto grad to start recording operations on a Tensor.
        If a Tensor has requires_grad=False, then Tensor requires_grad will make auto grad start recording
        operations on the tensor.
        """
        return self.param_info.requires_grad

    @requires_grad.setter
    def requires_grad(self, value=True):
        if not isinstance(value, bool):
            raise TypeError("The argument `requires_grad` must be bool type")
        self.param_info.requires_grad = value

    @property
    def data(self):
        """Return the parameter object."""
        return self

    def _update_tensor_data(self, data):
        """Update the parameter by a Tensor."""
        if isinstance(self, Tensor):
            self.init_flag = False
            self.init = None
            # In-place update; returns self (assign_value's return value).
            return self.assign_value(data)
        new_param = Parameter(data, self.name, self.requires_grad)
        new_param.param_info = self.param_info
        return new_param

    def add_pipeline_stage(self, stage):
        """Record a pipeline stage (non-negative int) this parameter belongs to."""
        if not isinstance(stage, int) or stage < 0:
            raise TypeError("`stage` must be a positive number of int type")
        self._pipeline_stage_list.append(stage)

    def set_data(self, data, slice_shape=False):
        """
        Set Parameter's data.

        Args:
            data (Union[Tensor, int, float]): new data.
            slice_shape (bool): If slice the parameter is set to true, the shape is not checked for consistency.
                Default: False.

        Returns:
            Parameter, the parameter after set data.
        """
        def raise_type_error(incoming):
            raise TypeError(f"Incoming Parameter dtype can not be converted to current dtype implicitly. "
                            f"Current dtype is {self.dtype}, and incoming is {incoming}. "
                            f"Use .set_dtype(xxx) to change the dtype.")
        if not isinstance(data, (Tensor, int, float)):
            raise TypeError(f"Parameter data must be [`Tensor`, `int`, `float`] or a kind of `Tensor` "
                            f"(like `Tensor`). But with type {type(data)}.")
        if isinstance(data, (int, float)):
            if self.dtype in mstype.int_type and isinstance(data, float):
                raise_type_error(mstype.float_)
            data = Tensor(data, self.dtype)
        # both not init.
        incoming_tensor_is_init = isinstance(data, Tensor) and not data.has_init
        current_tensor_is_init = isinstance(self, Tensor) and not self.has_init
        if incoming_tensor_is_init and not current_tensor_is_init:
            raise TypeError("The original tensor data is initialized, but the argument 'data' is not initialized."
                            "Please initialize 'data' before call this method.")
        if tuple(self.shape) != tuple(data.shape):
            # If Slice create Parameter shape can be change.
            if not slice_shape:
                raise ValueError(f"Can not change the shape of Parameter which has been initialized."
                                 f" Current shape is {self.shape}, and incoming is {data.shape}.")
        if self.dtype != data.dtype:
            # Only widening casts along the implicit-conversion order are
            # performed silently; narrowing raises.
            if mstype.implicit_conversion_seq[self.dtype] < mstype.implicit_conversion_seq[data.dtype]:
                raise_type_error(data.dtype)
            else:
                from mindspore.ops import functional as F
                data = F.cast(data, self.dtype)
        if isinstance(data, Tensor) and data.has_init:
            # The parameter has been initialized, directly update by the data
            if current_tensor_is_init:
                self._update_tensor_data(data.init_data())
            else:
                # also update the related inited parameter data
                if self.inited_param is not None:
                    self.inited_param.set_data(data)
                self.init_mode = data
        elif incoming_tensor_is_init or current_tensor_is_init:
            self._update_tensor_data(data)
        self.sliced = slice_shape
        return self

    def init_data(self, layout=None, set_sliced=False):
        """
        Initialize the parameter's data.

        Args:
            layout (Union[None, tuple]): The parameter's layout info.
                layout [dev_mat, tensor_map, slice_shape, filed_size, uniform_split, opt_shard_group]. Default: None.
                It's not None only in 'SEMI_AUTO_PARALLEL' or 'AUTO_PARALLEL' mode.

                - dev_mat (list(int)): The parameter's device matrix.
                - tensor_map (list(int)): The parameter's tensor map.
                - slice_shape (list(int)): The parameter's slice shape.
                - filed_size (int): The parameter's filed size.
                - uniform_split (bool): Whether the parameter is split evenly.
                - opt_shard_group (str): The group of the parameter while running optimizer parallel.

            set_sliced (bool): True if the parameter is set sliced after initializing the data.
                Default: False.

        Raises:
            RuntimeError: If it is from Initializer, and parallel mode has changed after the Initializer created.
            ValueError: If the length of the layout is less than 6.
            TypeError: If `layout` is not tuple.

        Returns:
            Parameter, the `Parameter` after initializing data. If current `Parameter` was already initialized before,
            returns the same initialized `Parameter`.
        """
        if self.is_default_input_init and self.is_in_parallel != _is_in_parallel_mode():
            raise RuntimeError("Must set or change parallel mode before any Tensor created.")
        if self.init_mode is None:
            # Already holds real data; nothing to initialize.
            return self
        if self.inited_param is not None:
            return self.inited_param
        if _is_role_worker() and self.cache_enable:
            global_seed, op_seed = _get_global_and_op_seed()
            _insert_weight_init_info(self.name, global_seed, op_seed)
        init_data_args = ()
        if layout is not None:
            if not isinstance(layout, tuple):
                raise TypeError("The argument 'layout' should be tuple, but got {}.".format(type(layout)))
            if len(layout) < 6:
                raise ValueError("The length of 'layout' must be larger than 5, but got {}.".format(len(layout)))
            slice_index = int(_get_slice_index(layout[0], layout[1]))
            init_data_args += (slice_index, layout[2], layout[5])
        if _is_role_pserver():
            return self
        if self.init_in_server and self.is_param_ps and isinstance(self.init_mode, Tensor) and \
                self.init_mode.init is not None and (_is_role_worker() or _is_role_sched()):
            if self.cache_enable:
                data = self.init_mode.init_data(*init_data_args)
            else:
                # Server-side init: worker only needs a placeholder value.
                data = self.init_mode.init_data(0, [1])
        else:
            data = self.init_mode.init_data(*init_data_args)
        obj = self._update_tensor_data(data)
        if id(obj) != id(self):
            # A new object was created; remember it so later calls return it.
            self._inited_param = obj
        obj.init_mode = None
        obj.sliced = set_sliced
        return obj
class ParameterTuple(tuple):
    """
    Class for storing tuple of parameters.

    Note:
        It is used to store the parameters of the network into the parameter tuple collection.
    """
    def __new__(cls, iterable):
        """Create instance object of ParameterTuple."""
        data = tuple(iterable)
        ids = set()
        # Maps a base name either to [0, first_param] (seen once) or to the
        # highest numeric suffix already handed out for that name.
        orders = {}
        for x in data:
            if not isinstance(x, Parameter):
                raise TypeError(f"ParameterTuple input should be `Parameter` collection."
                                f"But got a {type(iterable)}, {iterable}")
            # The same Parameter object may appear multiple times; only the
            # first occurrence takes part in name deduplication.
            if id(x) not in ids:
                ids.add(id(x))
                if x.name not in orders.keys():
                    orders[x.name] = [0, x]
                else:
                    if isinstance(orders[x.name], list):
                        # Second occurrence of the name: retroactively rename
                        # the first parameter with suffix _0, this one gets _1.
                        name = x.name
                        orders[name][1].name = name + "_" + str(0)
                        x.name = x.name + "_" + str(1)
                        orders[name] = 1
                    else:
                        orders[x.name] += 1
                        x.name = x.name + "_" + str(orders[x.name])
        return tuple.__new__(ParameterTuple, tuple(data))

    def clone(self, prefix, init='same'):
        """
        Clone the parameters in ParameterTuple element-wisely to generate a new ParameterTuple.

        Args:
            prefix (str): Namespace of parameter.
            init (Union[Tensor, str, numbers.Number]): Initialize the shape and dtype of the parameters.
                The definition of `init` is the same as in `Parameter` API. If `init` is 'same', the
                parameters in the new parameter tuple are the same as those in the original parameter tuple.
                Default: 'same'.

        Returns:
            Tuple, the new Parameter tuple.
        """
        Validator.check_str_by_regular(prefix)
        new = []
        for x in self:
            x1 = x.clone(init)
            x1.name = prefix + "." + x1.name
            new.append(x1)
            if not x1.cache_enable:
                continue
            # Cache-enabled clones must mirror the server-side hash table of
            # the original parameter (parameter-server mode on a worker).
            if _is_role_worker():
                _clone_hash_table(x.name, x1.name)
                _insert_accumu_init_info(x1.name, init_to_value(init))
        return ParameterTuple(new)

    def __parameter_tuple__(self):
        """For parse check."""