You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

parameter.py 25 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Parameter for cell."""
  16. from copy import copy
  17. import numbers
  18. import numpy as np
  19. from .._c_expression import ParamInfo
  20. from . import dtype as mstype
  21. from .. import context
  22. from ..parallel._utils import _get_parallel_mode
  23. from .initializer import initializer
  24. from .tensor import Tensor
  25. from .._checkparam import Validator
  26. from .._c_expression import Tensor as Tensor_
  27. from ..parallel._tensor import _get_slice_index
  28. from ..parallel._auto_parallel_context import auto_parallel_context
  29. from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched, _clone_hash_table
  30. from ..parallel._ps_context import _reinsert_hash_table_size
  31. from ..parallel._ps_context import _insert_weight_init_info, _insert_accumu_init_info
  32. from .seed import _get_global_and_op_seed
__all__ = ['Parameter', 'ParameterTuple']

# Name assigned to a Parameter when the user does not supply one.
PARAMETER_NAME_DEFAULT = "Parameter"
# Upper bound on the length of a parameter name (prefix included).
PARAMETER_NAME_PREFIX_MAX_LEN = 1024


def _is_in_parallel_mode():
    """Return True when the global auto-parallel context is in a parallel mode."""
    return auto_parallel_context().get_parallel_mode() in ["semi_auto_parallel", "auto_parallel"]
  39. def init_to_value(init):
  40. """Get value of initializer."""
  41. if isinstance(init, str):
  42. if init == 'zeros':
  43. return 0.0
  44. if init == 'ones':
  45. return 1.0
  46. raise ValueError("init should be one of values in 'zeros', 'ones'.")
  47. if isinstance(init, numbers.Number):
  48. return float(init)
  49. raise ValueError("init should be number or string")
class Parameter(Tensor_):
    """
    Parameter types of cell models.

    After initialized `Parameter` is a subtype of `Tensor`.

    In auto_parallel mode of "semi_auto_parallel" and "auto_parallel", if init `Parameter` by
    a `Tensor`, the type of Parameter will be `Tensor`. `Tensor`
    will save the shape and type info of a tensor with no memory usage. The shape can be changed while
    compiling for auto-parallel. Call `init_data` will return a Tensor Parameter with initialized data.

    Note:
        Each parameter of Cell is represented by Parameter class.
        A Parameter has to belong to a Cell.
        If there is an operator in the network that requires part of the inputs to be Parameter,
        then the Parameters as this part of the inputs are not allowed to be cast.
        It is recommended to use the default value of `name` when initialize a parameter as one attribute of a cell,
        otherwise, the parameter name may be different than expected.

    Args:
        default_input (Union[Tensor, int, float, numpy.ndarray, list]): Parameter data, to be set initialized.
        name (str): Name of the child parameter. Default: None.
        requires_grad (bool): True if the parameter requires gradient. Default: True.
        layerwise_parallel (bool): When layerwise_parallel is true in data parallel mode,
            broadcast and gradients communication would not be applied to parameters. Default: False.
        parallel_optimizer (bool): It is used to filter the weight shard operation in semi auto or auto parallel
            mode. It works only when enable parallel optimizer in `mindspore.context.set_auto_parallel_context()`.
            Default: True.

    Examples:
        >>> from mindspore import Parameter, Tensor
        >>> from mindspore.common import initializer as init
        >>> from mindspore.ops import operations as P
        >>> from mindspore.nn import Cell
        >>> import mindspore
        >>> import numpy as np
        >>> from mindspore import context
        >>>
        >>> class Net(Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.matmul = P.MatMul()
        ...         self.weight = Parameter(Tensor(np.ones((1, 2)), mindspore.float32), name="w", requires_grad=True)
        ...
        ...     def construct(self, x):
        ...         out = self.matmul(self.weight, x)
        ...         return out
        >>> net = Net()
        >>> x = Tensor(np.ones((2, 1)), mindspore.float32)
        >>> print(net(x))
        [[2.]]
        >>> _ = net.weight.set_data(Tensor(np.zeros((1, 2)), mindspore.float32))
        >>> print(net(x))
        [[0.]]
    """
    # Cache of dynamically generated `Parameter<TensorClass>` subclasses,
    # keyed by subclass name; populated lazily by `_get_base_class`.
    __base_type__ = {}
    def __new__(cls, default_input, *args, **kwargs):
        """Build the underlying tensor object with a dynamically derived type.

        The instance is created as `Parameter<input_class>`, a subclass of both
        Parameter and the concrete tensor class chosen for `default_input`.
        """
        # True when the input is a meta Tensor that carries an Initializer
        # instead of real data (data is materialized later by `init_data`).
        init_data_flag = bool(isinstance(default_input, Tensor) and default_input.has_init)
        input_class, *class_init_args = Parameter._get_parameter_new_args(default_input)
        new_type = Parameter._get_base_class(input_class)
        obj = input_class.__new__(new_type)
        input_class.__init__(obj, *class_init_args)
        # it's better to make the Initializer a kind of tensor.
        obj.init_mode = None
        obj.is_default_input_init = init_data_flag
        if obj.has_init:
            # Keep the original initializer so `init_data` can materialize it.
            obj.init_mode = default_input
        return obj
  113. def __reduce_ex__(self, _):
  114. data = self
  115. if self.init_mode is not None:
  116. data = self.init_mode
  117. else:
  118. # cast to break deep infinite loop while deepcopy
  119. data = Tensor(self)
  120. return (
  121. Parameter, (data, self.name, self.requires_grad, self.layerwise_parallel))
    def __init__(self, default_input, name=None, requires_grad=True, layerwise_parallel=False, parallel_optimizer=True):
        """Initialize the parameter's meta info; the data payload was built in `__new__`."""
        self.param_info = ParamInfo()
        self.init_in_server = False
        self.cache_enable = False
        self.name = name
        self.requires_grad = requires_grad
        self.layerwise_parallel = layerwise_parallel
        self.parallel_optimizer = parallel_optimizer
        # this flag for tensor copy data.
        self.init_flag = False
        # this flag is for ge variable copy data.
        self._is_init = False
        # Filled by `init_data` when a new initialized Parameter replaces this one.
        self._inited_param = None
        self._sliced = False
        self.is_param_ps = False
        self._cast_type = None
        self._unique = False
        # Record the parallel mode at creation time; `init_data` re-checks it.
        self.is_in_parallel = _is_in_parallel_mode()
        if isinstance(default_input, (Tensor_, Tensor)):
            # Meta-initialize with dtype/shape only; the actual data lives in
            # the object constructed by `__new__`.
            Tensor_.__init__(self, default_input.dtype, default_input.shape)
        elif isinstance(default_input, int):
            # NOTE(review): int scalars register as int64 here, while
            # `_get_parameter_new_args` builds them as int32 — confirm intent.
            Tensor_.__init__(self, mstype.int64, ())
        elif isinstance(default_input, float):
            Tensor_.__init__(self, mstype.float32, ())
        elif isinstance(default_input, (np.ndarray, list)):
            Tensor_.__init__(self, default_input)
        else:
            raise TypeError(f"Parameter input must be [`Tensor`, `int`, `float`, `numpy.ndarray`, `list`]."
                            f"But with type {type(default_input)}.")
    def __deepcopy__(self, memodict):
        """Deep copy by re-constructing from self; `__reduce_ex__`/`__init__` copy the data."""
        new_obj = Parameter(self)
        new_obj.name = self.name
        # Preserve the link to the initialized counterpart, if any.
        new_obj._inited_param = self._inited_param  # pylint: disable=W0212
        return new_obj
  156. @staticmethod
  157. def _get_base_class(input_class):
  158. input_class_name = f'Parameter{input_class.__name__}'
  159. if input_class_name in Parameter.__base_type__:
  160. new_type = Parameter.__base_type__[input_class_name]
  161. else:
  162. new_type = type(input_class_name, (Parameter, input_class), {})
  163. Parameter.__base_type__[input_class_name] = new_type
  164. return new_type
    @staticmethod
    def _get_parameter_new_args(data):
        """Map `data` to (tensor_class, *constructor_args) used by `__new__`."""
        if isinstance(data, bool):
            raise ValueError('Parameter data can not be `bool`')
        if isinstance(data, Tensor) and data.has_init:
            if _is_in_parallel_mode() or _is_role_worker() or _is_role_sched():
                # do not init data while in auto parallel.
                return (Tensor, None, data.dtype, data.shape, data.init)
            # Outside parallel/PS roles, materialize the initializer eagerly.
            data = data.init_data().asnumpy()
        elif isinstance(data, Tensor):
            # make a copy of Tensor to init the parameter
            return (Tensor, data.asnumpy(),)
        if isinstance(data, int):
            # NOTE(review): int scalars become int32 here, while `__init__`
            # registers the meta dtype as int64 — confirm which is intended.
            return (Tensor, data, mstype.int32)
        if isinstance(data, float):
            return (Tensor, data, mstype.float32)
        return (Tensor, data)
  183. def __str__(self):
  184. return f'Parameter (name={self.name}, shape={self.shape}, dtype={self.dtype}, ' \
  185. f'requires_grad={self.requires_grad})'
  186. def __repr__(self):
  187. return self.__str__()
    def __parameter__(self):
        """Marker method used by the parser to recognize a Parameter; intentionally empty."""
    def set_param_ps(self, init_in_server=False):
        """
        Set whether the trainable parameter is updated by parameter server and whether the
        trainable parameter is initialized on server.

        Note:
            It only works when a running task is in the parameter server mode.

        Args:
            init_in_server (bool): Whether trainable parameter updated by parameter server is
                initialized on server. Default: False.

        Raises:
            RuntimeError: If called outside parameter-server mode, or if
                `init_in_server` is requested for a non-embedding-table parameter.
        """
        if not(_is_role_worker() or _is_role_pserver() or _is_role_sched()):
            raise RuntimeError("Must complete following two steps before calling set_param_ps: \
                               1. set_ps_context(enable_ps=True) \
                               2. export MS_ROLE environment variable.")
        # Only embedding tables (sparse operators) support server-side init.
        if init_in_server and (not self.name.endswith("embedding_table")):
            raise RuntimeError("Can not initialize parameter '{}' in server, only parameters of "
                               "sparse operator support initialization in server.".format(self.name))
        self.is_param_ps = True
        self.init_in_server = init_in_server
        self.param_info.init_in_server = init_in_server
    @property
    def inited_param(self):
        """
        Get the new parameter after call the init_data.

        Default is None. If `self` is a Parameter without data, after call the
        `init_data` the initialized Parameter with data will be recorded here.
        """
        return self._inited_param
    @property
    def name(self):
        """Get the name of the parameter."""
        return self.param_info.name

    @name.setter
    def name(self, name_):
        """
        Define a name for the parameter.

        Args:
            name_ (`str` or `None`): The name of the parameter. When the parameter is None or an empty string,
                the default value `PARAMETER_NAME_DEFAULT` is used.

        Raises:
            ValueError: If the name is too long or not a `str`/`None`.
            RuntimeError: If a cache-enabled parameter is not 2-D on a worker.
        """
        if name_ is None:
            name_ = PARAMETER_NAME_DEFAULT
        elif isinstance(name_, str):
            name_ = name_.strip()
            if name_ == '':
                name_ = PARAMETER_NAME_DEFAULT
            if len(name_) > PARAMETER_NAME_PREFIX_MAX_LEN:
                raise ValueError("The length of the '{}' name should be less than {}.".
                                 format(name_, PARAMETER_NAME_PREFIX_MAX_LEN))
        else:
            raise ValueError("The type of the name should be `str` or `None`.")
        if _is_role_worker() and self.cache_enable:
            # Renaming a cache-enabled (embedding) parameter must re-register
            # its hash table under the new name on the worker.
            if len(self.shape) != 2:
                raise RuntimeError("The dims of parameter '{}' must be 2, but got {}."
                                   .format(self.name, len(self.shape)))
            _reinsert_hash_table_size(name_, self.param_info.name, self.shape[0], self.shape[1])
        self.param_info.name = name_
    @property
    def sliced(self):
        """Get slice status of the parameter."""
        return self._sliced

    @sliced.setter
    def sliced(self, sliced_):
        # Stored verbatim; no validation is performed on the value.
        self._sliced = sliced_
  254. @property
  255. def comm_fusion(self):
  256. """
  257. Get and Set the fusion type (int) for communication operators corresponding to this parameter.
  258. In `AUTO_PARALLEL` and `SEMI_AUTO_PARALLEL` mode, some communication operators used for parameters or
  259. gradients aggregation are inserted automatically. Set the fusion type for communication operators generated
  260. for this parameter. The value of fusion must be greater than or equal to 0. When the value of fusion is 0,
  261. operators will not be fused together.
  262. Only `Ascend` and `Graph` mode is supported.
  263. """
  264. return self.param_info.comm_fusion
  265. @comm_fusion.setter
  266. def comm_fusion(self, comm_fusion_):
  267. if context.get_context("mode") == context.PYNATIVE_MODE and "auto_parallel" in _get_parallel_mode():
  268. raise RuntimeError("`comm_fusion` does not support PYNATIVE_MODE")
  269. Validator.check_non_negative_int(comm_fusion_)
  270. self.param_info.comm_fusion = comm_fusion_
    @property
    def unique(self):
        """Whether the parameter is already unique or not."""
        return self._unique

    @unique.setter
    def unique(self, unique_):
        # Stored verbatim; no validation is performed on the value.
        self._unique = unique_
    @property
    def is_init(self):
        """
        Get the initialization status of the parameter.

        In GE backend, the Parameter need a "init graph" to sync the data from host to device.
        This flag indicates whether the data has been synced to the device.

        This flag only works in GE, and it will be set to False in other backends.
        """
        return self._is_init

    @is_init.setter
    def is_init(self, is_init_):
        """
        Set init status of the parameter.

        Args:
            is_init_ (bool): The init status of the parameter.
        """
        self._is_init = is_init_
    def clone(self, init='same'):
        """
        Clone the parameter.

        Args:
            init (Union[Tensor, str, numbers.Number]): Initialize the shape of the parameter.
                Default: 'same'.

        Returns:
            Parameter, a new parameter.
        """
        x = copy(self)
        # param_info must be cloned, not shared, so the clone can be renamed
        # and re-flagged independently of the original.
        x.param_info = self.param_info.clone()
        x.is_init = False
        x.init = self.init
        x.is_param_ps = self.is_param_ps
        x.init_in_server = self.init_in_server
        x.cache_enable = self.cache_enable
        if self.cache_shape:
            x.cache_shape = self.cache_shape
        if init != 'same':
            # Re-initialize the clone's data with the requested initializer,
            # keeping the original shape and dtype.
            shape = self.shape
            dtype = self.dtype
            x.set_data(initializer(init, shape=shape, dtype=dtype))
        return x
  318. @property
  319. def layerwise_parallel(self):
  320. return self.param_info.layerwise_parallel
  321. @layerwise_parallel.setter
  322. def layerwise_parallel(self, value=True):
  323. if not isinstance(value, bool):
  324. raise TypeError("`layerwise_parallel` parameter must be bool type")
  325. self.param_info.layerwise_parallel = value
  326. @property
  327. def parallel_optimizer(self):
  328. """Return whether the parameter requires weight shard for parallel optimizer."""
  329. return self.param_info.parallel_optimizer
  330. @parallel_optimizer.setter
  331. def parallel_optimizer(self, value=True):
  332. if not isinstance(value, bool):
  333. raise TypeError("`parallel_optimizer` parameter must be bool type")
  334. self.param_info.parallel_optimizer = value
  335. @property
  336. def cache_enable(self):
  337. """Return whether the parameter is cache enable."""
  338. return self.param_info.cache_enable
  339. @cache_enable.setter
  340. def cache_enable(self, value=True):
  341. if not isinstance(value, bool):
  342. raise TypeError("`cache_enable` parameter must be bool type")
  343. self.param_info.cache_enable = value
  344. @property
  345. def cache_shape(self):
  346. """Return the cache shape corresponding to the parameter if use cache."""
  347. return self.param_info.cache_shape
  348. @cache_shape.setter
  349. def cache_shape(self, value):
  350. if not isinstance(value, (tuple, list)):
  351. raise TypeError("`cache_shape` parameter must be tuple or list type")
  352. self.param_info.cache_shape = value
  353. @property
  354. def requires_grad(self):
  355. """Return whether the parameter requires gradient."""
  356. return self.param_info.requires_grad
  357. @requires_grad.setter
  358. def requires_grad(self, value=True):
  359. if not isinstance(value, bool):
  360. raise TypeError("`requires_grad` parameter must be bool type")
  361. self.param_info.requires_grad = value
    @property
    def data(self):
        """Return the parameter object itself."""
        return self
    def _update_tensor_data(self, data):
        """Update the parameter by a Tensor.

        If `self` is already a real Tensor, assign the value in place and return
        self; otherwise build a fresh Parameter that shares this param_info.
        """
        if isinstance(self, Tensor):
            self.init_flag = False
            self.init = None
            return self.assign_value(data)
        new_param = Parameter(data, self.name, self.requires_grad)
        # Share the meta info so name/flags stay consistent with the original.
        new_param.param_info = self.param_info
        return new_param
    def set_data(self, data, slice_shape=False):
        """
        Set `set_data` of current `Parameter`.

        Args:
            data (Union[Tensor, int, float]): new data.
            slice_shape (bool): If slice the parameter is set to true, the shape is not checked for consistency.
                Default: False.

        Returns:
            Parameter, the parameter after set data.
        """
        def raise_type_error(incoming):
            raise TypeError(f"Incoming Parameter dtype can not be converted to current dtype implicitly. "
                            f"Current dtype is {self.dtype}, and incoming is {incoming}. "
                            f"Use .set_dtype(xxx) to change the dtype.")

        if not isinstance(data, (Tensor, int, float)):
            raise TypeError(f"Parameter data must be [`Tensor`, `int`, `float`] or a kind of `Tensor` "
                            f"(like `Tensor`). But with type {type(data)}.")
        if isinstance(data, (int, float)):
            # Scalars are promoted to a Tensor of this parameter's dtype; a float
            # may not silently narrow into an integer parameter.
            if self.dtype in mstype.int_type and isinstance(data, float):
                raise_type_error(mstype.float_)
            data = Tensor(data, self.dtype)
        # both not init.
        incoming_tensor_is_init = isinstance(data, Tensor) and not data.has_init
        current_tensor_is_init = isinstance(self, Tensor) and not self.has_init
        if incoming_tensor_is_init and not current_tensor_is_init:
            raise TypeError("Parameter is a `Tensor` and not initializered, `data` for `set_data`"
                            "should be a Tensor. If you want to update it by Tensor, call method"
                            "`init_parameters_data` of `Cell` to init and replace all the Parameter of"
                            "network, then call this method.")
        if tuple(self.shape) != tuple(data.shape):
            # If Slice create Parameter shape can be change.
            if not slice_shape:
                raise ValueError(f"Can not change the shape of Parameter which has been initialized."
                                 f" Current shape is {self.shape}, and incoming is {data.shape}.")
        if self.dtype != data.dtype:
            # Only implicit widening is allowed; narrowing requires an explicit
            # set_dtype call from the user.
            if mstype.implicit_conversion_seq[self.dtype] < mstype.implicit_conversion_seq[data.dtype]:
                raise_type_error(data.dtype)
            else:
                from mindspore.ops import functional as F
                data = F.cast(data, self.dtype)
        if isinstance(data, Tensor) and data.has_init:
            # The parameter has been initializered, directly update by the data
            if current_tensor_is_init:
                self._update_tensor_data(data.init_data())
            else:
                # also update the related inited parameter data
                if self.inited_param is not None:
                    self.inited_param.set_data(data)
                self.init_mode = data
        elif incoming_tensor_is_init or current_tensor_is_init:
            self._update_tensor_data(data)
        else:
            raise ValueError(f"Not support to update the Parameter by {data}")
        self.sliced = slice_shape
        return self
  429. def init_data(self, layout=None, set_sliced=False):
  430. """
  431. Initialize the parameter data.
  432. Args:
  433. layout (Union[None, list(list(int))]): Parameter slice
  434. layout [dev_mat, tensor_map, slice_shape]. Default: None.
  435. - dev_mat (list(int)): Device matrix.
  436. - tensor_map (list(int)): Tensor map.
  437. - slice_shape (list(int)): Shape of slice.
  438. set_sliced (bool): True if the parameter is set sliced after initializing the data.
  439. Default: False.
  440. Raises:
  441. RuntimeError: If it is from Initializer, and parallel mode has changed after the Initializer created.
  442. Returns:
  443. Parameter, the `Parameter` after initializing data. If current `Parameter` was already initialized before,
  444. returns the same initialized `Parameter`.
  445. """
  446. if self.is_default_input_init and self.is_in_parallel != _is_in_parallel_mode():
  447. raise RuntimeError("Must set or change parallel mode before any Tensor created.")
  448. if self.init_mode is None:
  449. return self
  450. if self.inited_param is not None:
  451. return self.inited_param
  452. if _is_role_worker() and self.cache_enable:
  453. global_seed, op_seed = _get_global_and_op_seed()
  454. _insert_weight_init_info(self.name, global_seed, op_seed)
  455. init_data_args = ()
  456. if layout is not None:
  457. if not isinstance(layout, tuple):
  458. raise TypeError("The layout should be tuple, but got layout is {}.".format(layout))
  459. if len(layout) < 3:
  460. raise ValueError("The length of layout must be larger than 2, but got layout is {}.".format(layout))
  461. slice_index = int(_get_slice_index(layout[0], layout[1]))
  462. init_data_args += (slice_index, layout[2], layout[5])
  463. if self.init_in_server and self.is_param_ps and isinstance(self.init_mode, Tensor) and \
  464. self.init_mode.init is not None and (_is_role_worker() or _is_role_sched()):
  465. data = self.init_mode.init_data(0, [1])
  466. else:
  467. data = self.init_mode.init_data(*init_data_args)
  468. obj = self._update_tensor_data(data)
  469. if id(obj) != id(self):
  470. self._inited_param = obj
  471. obj.init_mode = None
  472. obj.sliced = set_sliced
  473. return obj
class ParameterTuple(tuple):
    """
    Class for storing tuple of parameters.

    Note:
        It is used to store the parameters of the network into the parameter tuple collection.
    """
    def __new__(cls, iterable):
        """Create instance object of ParameterTuple.

        Distinct Parameter objects that share the same name are renamed with
        numeric suffixes ("name_0", "name_1", ...) so names stay unique.
        """
        data = tuple(iterable)
        ids = set()
        orders = {}
        for x in data:
            if not isinstance(x, Parameter):
                # NOTE(review): the message reports the container's type, not the
                # offending element's; type(x) would be more informative — verify.
                raise TypeError(f"ParameterTuple input should be `Parameter` collection."
                                f"But got a {type(iterable)}, {iterable}")
            if id(x) not in ids:
                # The same object appearing twice is allowed and kept untouched;
                # only distinct objects with a clashing name get renamed.
                ids.add(id(x))
                if x.name not in orders.keys():
                    orders[x.name] = [0, x]
                else:
                    if isinstance(orders[x.name], list):
                        # Second distinct holder of this name: rename the first
                        # to `name_0`, this one to `name_1`, then keep a counter.
                        name = x.name
                        orders[name][1].name = name + "_" + str(0)
                        x.name = x.name + "_" + str(1)
                        orders[name] = 1
                    else:
                        # Third or later holder: bump the counter and suffix it.
                        orders[x.name] += 1
                        x.name = x.name + "_" + str(orders[x.name])
        return tuple.__new__(ParameterTuple, tuple(data))

    def clone(self, prefix, init='same'):
        """
        Clone the parameter.

        Args:
            prefix (str): Namespace of parameter.
            init (str): Initialize the shape of the parameter. Default: 'same'.

        Returns:
            Tuple, the new Parameter tuple.
        """
        Validator.check_str_by_regular(prefix)
        new = []
        for x in self:
            x1 = x.clone(init)
            x1.name = prefix + "." + x1.name
            new.append(x1)
            if not x1.cache_enable:
                continue
            # Cache is only supported for embedding tables (sparse operators).
            if not x1.name.endswith("embedding_table"):
                raise RuntimeError("Can not enable cache for parameter '{}', Only parameters of "
                                   "sparse operator support enable cache.".format(x1.name))
            if _is_role_worker():
                _clone_hash_table(x.name, x1.name)
                _insert_accumu_init_info(x1.name, init_to_value(init))
        return ParameterTuple(new)

    def __parameter_tuple__(self):
        """Marker method used by the parser to recognize ParameterTuple; intentionally empty."""

    def __del__(self):
        # NOTE(review): ParameterTuple never defines `param_info` itself; this
        # looks copied from Parameter — verify it is intentional.
        self.param_info = None