# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Cell_wrapper."""
from types import FunctionType, MethodType

from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
                                       _get_parallel_mode, _get_enable_parallel_optimizer)
from mindspore.context import ParallelMode
from mindspore._checkparam import Validator as validator
from mindspore import ops, nn
from ...common import dtype as mstype
from ...common.parameter import Parameter, ParameterTuple
from ...ops import composite as C
from ...ops import functional as F
from ...ops import operations as P
from ...ops.operations.comm_ops import _VirtualDataset
from ..cell import Cell
from .grad_reducer import DistributedGradReducer

_get_datatype = C.MultitypeFuncGraph("_get_datatype")


@_get_datatype.register("Tensor")
def _tensors_get_datatype(param):
    """
    Acquire parameter datatype.

    Args:
        param (Tensor): The parameter before operation.

    Returns:
        mstype, the datatype of parameter.
    """
    return F.dtype(param)


_cast_datatype = C.MultitypeFuncGraph("_cast_datatype")


@_cast_datatype.register("TypeType", "Tensor")
def _tensors_cast_datatype(datatype, param):
    """
    Cast the parameter to the destination datatype.

    Args:
        datatype (mstype): The destination datatype of the parameter.
        param (Tensor): The parameter before operation.

    Returns:
        Tensor, the parameter after operation.
    """
    return F.cast(param, datatype)
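

# Hedged illustration (not part of the original file): these two helpers pair
# up to record a tensor's dtype, cast it for communication, and restore it
# afterwards -- the same save/cast/restore pattern that _BroadCastCell.construct
# applies across a whole tuple of parameters via C.Map at the end of this file.
def _demo_dtype_roundtrip(tensor):
    """Sketch only: cast `tensor` to float32 and back to its original dtype."""
    saved = F.dtype(tensor)                   # record the original dtype
    as_fp32 = F.cast(tensor, mstype.float32)  # communicate/compute in float32
    return F.cast(as_fp32, saved)             # restore the recorded dtype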


class WithLossCell(Cell):
    r"""
    Cell with loss function.

    Wraps the network with a loss function. This Cell accepts data and label as inputs and
    returns the computed loss.

    Args:
        backbone (Cell): The target network to wrap.
        loss_fn (Cell): The loss function used to compute loss.

    Inputs:
        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.

    Outputs:
        Tensor, a tensor representing the loss value, whose shape is usually :math:`()`.

    Raises:
        TypeError: If dtype of `data` or `label` is neither float16 nor float32.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> net = Net()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
        >>> net_with_criterion = nn.WithLossCell(net, loss_fn)
        >>>
        >>> batch_size = 2
        >>> data = Tensor(np.ones([batch_size, 1, 32, 32]).astype(np.float32) * 0.01)
        >>> label = Tensor(np.ones([batch_size, 10]).astype(np.float32))
        >>>
        >>> output_data = net_with_criterion(data, label)
    """

    def __init__(self, backbone, loss_fn):
        super(WithLossCell, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._loss_fn = loss_fn

    def construct(self, data, label):
        out = self._backbone(data)
        return self._loss_fn(out, label)

    @property
    def backbone_network(self):
        """
        Get the backbone network.

        Returns:
            Cell, the backbone network.
        """
        return self._backbone


class WithGradCell(Cell):
    r"""
    Cell that returns the gradients.

    Wraps the network with a backward cell to compute gradients. A network together with a loss
    function is required. If `loss_fn` is None, the network must already be a wrapper of a network
    and a loss function. This Cell accepts '\*inputs' as inputs and returns gradients for each
    trainable parameter.

    Note:
        Run in PyNative mode.

    Args:
        network (Cell): The target network to wrap. The network only supports single output.
        loss_fn (Cell): Primitive loss function used to compute gradients. Default: None.
        sens (Union[None, Tensor, Scalar, Tuple ...]): The sensitivity (gradient with respect to output)
            for backpropagation; its type and shape must match the `network` output. If None, a tensor
            of ones with the same type and shape as the output value is used. Default: None.

    Inputs:
        - **(\*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.

    Outputs:
        list, a list of Tensors with the same shapes as the trainable weights.

    Raises:
        TypeError: If `sens` is not one of None, Tensor, Scalar or Tuple.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> # For a defined network Net without loss function
        >>> net = Net()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
        >>> grad_net = nn.WithGradCell(net, loss_fn)
        >>>
        >>> # For a network wrapped with loss function
        >>> net = Net()
        >>> net_with_criterion = nn.WithLossCell(net, loss_fn)
        >>> grad_net = nn.WithGradCell(net_with_criterion)
    """

    def __init__(self, network, loss_fn=None, sens=None):
        super(WithGradCell, self).__init__(auto_prefix=False)
        self.network = network
        self.loss_fn = loss_fn
        self.weights = ParameterTuple(network.trainable_params())
        self.grad = C.GradOperation(get_by_list=True, sens_param=(sens is not None))
        self.sens = sens
        if loss_fn is None:
            self.network_with_loss = network
        else:
            self.network_with_loss = WithLossCell(self.network, self.loss_fn)
        self.network_with_loss.set_train()

    def construct(self, *inputs):
        weights = self.weights
        if self.sens is None:
            grads = self.grad(self.network_with_loss, weights)(*inputs)
        else:
            grads = self.grad(self.network_with_loss, weights)(*inputs, self.sens)
        return grads


class ForwardValueAndGrad(Cell):
    r"""
    Network training package class.

    Includes the network and a gradient function. The resulting Cell is trained with input '\*inputs'.
    The backward graph is created in the gradient function to calculate the gradients.

    Args:
        network (Cell): The training network.
        weights (ParameterTuple): The parameters of the training network that need to calculate the gradient.
        get_all (bool): If True, get all the gradients with respect to inputs. Default: False.
        get_by_list (bool): If True, get all the gradients with respect to Parameter variables.
            If get_all and get_by_list are both False, get the gradient with respect to the first input.
            If get_all and get_by_list are both True, get the gradients with respect to inputs and Parameter
            variables at the same time, in the form of ((gradients with respect to inputs),
            (gradients with respect to parameters)). Default: False.
        sens_param (bool): Whether to append the sensitivity (gradient with respect to output) as an input.
            If sens_param is False, a 'ones_like(outputs)' sensitivity is attached automatically.
            If sens_param is True, the sensitivity (gradient with respect to output) needs to be passed
            as an input. Default: False.

    Inputs:
        - **(\*inputs)** (Tuple(Tensor...)) - Tuple of inputs with shape :math:`(N, \ldots)`.
        - **(sens)** - A sensitivity (gradient with respect to output) as the input of backpropagation.
          If the network has a single output, the sens is a tensor.
          If the network has multiple outputs, the sens is a tuple of tensors.

    Outputs:
        - **forward value** - The result of the network's forward run.
        - **gradients** (tuple(tensor)) - The gradients of the network parameters and inputs.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.weight = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="weight")
        ...         self.matmul = P.MatMul()
        ...
        ...     def construct(self, x):
        ...         out = self.matmul(x, self.weight)
        ...         return out
        ...
        >>> net = Net()
        >>> criterion = nn.SoftmaxCrossEntropyWithLogits()
        >>> net_with_criterion = nn.WithLossCell(net, criterion)
        >>> weight = ParameterTuple(net.trainable_params())
        >>> train_network = nn.ForwardValueAndGrad(net_with_criterion, weights=weight, get_all=True, get_by_list=True)
        >>> inputs = Tensor(np.ones([1, 2]).astype(np.float32))
        >>> labels = Tensor(np.zeros([1, 2]).astype(np.float32))
        >>> result = train_network(inputs, labels)
        >>> print(result)
        (Tensor(shape=[1], dtype=Float32, value=[0.00000000e+00]), ((Tensor(shape=[1, 2], dtype=Float32, value=
        [[1.00000000e+00, 1.00000000e+00]]), Tensor(shape=[1, 2], dtype=Float32, value=
        [[0.00000000e+00, 0.00000000e+00]])), (Tensor(shape=[2, 2], dtype=Float32, value=
        [[5.00000000e-01, 5.00000000e-01],
         [5.00000000e-01, 5.00000000e-01]]),)))
    """

    def __init__(self, network, weights=None, get_all=False, get_by_list=False, sens_param=False):
        super(ForwardValueAndGrad, self).__init__(auto_prefix=False)
        if not isinstance(network, (Cell, FunctionType, MethodType)):
            raise TypeError(f"For 'ForwardValueAndGrad', "
                            f"the argument 'network' should be cell, function type or method type, "
                            f"but got '{type(network)}'")
        if not isinstance(get_all, bool):
            raise TypeError(f"For 'ForwardValueAndGrad', "
                            f"the type of 'get_all' should be bool, but got '{type(get_all)}'")
        if not isinstance(get_by_list, bool):
            raise TypeError(f"For 'ForwardValueAndGrad', "
                            f"the type of 'get_by_list' should be bool, but got '{type(get_by_list)}'")
        if get_by_list and not isinstance(weights, ParameterTuple):
            raise TypeError(f"For 'ForwardValueAndGrad', "
                            f"when 'get_by_list' is set to True, the argument 'weights' should be "
                            f"ParameterTuple type, but got '{type(weights)}'")
        self.network = network
        if isinstance(network, Cell):
            self.network.set_grad()
        self.weights = weights
        self.get_all = get_all
        self.get_by_list = get_by_list
        self.sens_param = sens_param
        self.grad = C.GradOperation(get_all=self.get_all, get_by_list=self.get_by_list, sens_param=self.sens_param)

    def construct(self, *inputs):
        grad_inputs = inputs
        if self.sens_param:
            inputs = inputs[:-1]
        loss = self.network(*inputs)
        if self.get_by_list:
            grads = self.grad(self.network, self.weights)(*grad_inputs)
        else:
            grads = self.grad(self.network)(*grad_inputs)
        return loss, grads


class TrainOneStepCell(Cell):
    r"""
    Network training package class.

    Wraps the network with an optimizer. The resulting Cell is trained with input '\*inputs'.
    The backward graph is created in the construct function to update the parameters. Different
    parallel modes are available for training.

    Args:
        network (Cell): The training network. The network only supports single output.
        optimizer (Union[Cell]): Optimizer for updating the weights.
        sens (numbers.Number): The scaling number to be filled as the input of backpropagation.
            Default value is 1.0.

    Inputs:
        - **(\*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.

    Outputs:
        Tensor, a tensor representing the loss value, whose shape is usually :math:`()`.

    Raises:
        TypeError: If `sens` is not a number.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> net = Net()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
        >>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> # 1) Using the WithLossCell provided by MindSpore
        >>> loss_net = nn.WithLossCell(net, loss_fn)
        >>> train_net = nn.TrainOneStepCell(loss_net, optim)
        >>>
        >>> # 2) Using a user-defined WithLossCell
        >>> class MyWithLossCell(Cell):
        ...     def __init__(self, backbone, loss_fn):
        ...         super(MyWithLossCell, self).__init__(auto_prefix=False)
        ...         self._backbone = backbone
        ...         self._loss_fn = loss_fn
        ...
        ...     def construct(self, x, y, label):
        ...         out = self._backbone(x, y)
        ...         return self._loss_fn(out, label)
        ...
        ...     @property
        ...     def backbone_network(self):
        ...         return self._backbone
        ...
        >>> loss_net = MyWithLossCell(net, loss_fn)
        >>> train_net = nn.TrainOneStepCell(loss_net, optim)
    """

    def __init__(self, network, optimizer, sens=1.0):
        super(TrainOneStepCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.optimizer = optimizer
        self.weights = self.optimizer.parameters
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.reducer_flag = False
        self.grad_reducer = F.identity
        self.parallel_mode = _get_parallel_mode()
        self.reducer_flag = self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL)
        if self.reducer_flag:
            self.mean = _get_gradients_mean()
            self.degree = _get_device_num()
            self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree)

    def construct(self, *inputs):
        loss = self.network(*inputs)
        sens = F.fill(loss.dtype, loss.shape, self.sens)
        grads = self.grad(self.network, self.weights)(*inputs, sens)
        grads = self.grad_reducer(grads)
        loss = F.depend(loss, self.optimizer(grads))
        return loss
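

# Hedged usage sketch (not part of the original file): once a loss cell is
# wrapped in TrainOneStepCell, each call performs one forward pass, one
# backward pass, and one optimizer update. A bare training loop over any
# iterable of (data, label) Tensor pairs could look like the function below;
# `train_net` and `dataset` are illustrative names.
def _demo_train_loop(train_net, dataset, epochs=1):
    """Sketch only: run `epochs` passes, one optimizer step per batch."""
    train_net.set_train()
    loss = None
    for _ in range(epochs):
        for data, label in dataset:
            loss = train_net(data, label)  # forward + backward + update
    return loss  # loss from the final step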


class GetNextSingleOp(Cell):
    """
    Cell that fetches the next batch of data from the dataset queue.

    For detailed information, refer to `ops.operations.GetNext`.

    Args:
        dataset_types (list[:class:`mindspore.dtype`]): The types of the dataset.
        dataset_shapes (list[tuple[int]]): The shapes of the dataset.
        queue_name (str): Queue name to fetch the data.

    Inputs:
        No inputs.

    Outputs:
        tuple[Tensor], the data fetched from the dataset.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> train_dataset = create_custom_dataset()
        >>> dataset_helper = mindspore.DatasetHelper(train_dataset, dataset_sink_mode=True)
        >>> dataset = dataset_helper.iter.dataset
        >>> dataset_types, dataset_shapes = dataset_helper.types_shapes()
        >>> queue_name = dataset.__transfer_dataset__.queue_name
        >>> get_next_single_op_net = nn.GetNextSingleOp(dataset_types, dataset_shapes, queue_name)
        >>> data, label = get_next_single_op_net()
        >>> relu = P.ReLU()
        >>> result = relu(data).asnumpy()
        >>> print(result.shape)
        (32, 1, 32, 32)
    """

    def __init__(self, dataset_types, dataset_shapes, queue_name):
        super(GetNextSingleOp, self).__init__()
        self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)

    def construct(self):
        return self.get_next()


class _VirtualDatasetCell(Cell):
    """
    Wrap the network with a virtual dataset to convert the data parallel layout to the model parallel layout.

    _VirtualDataset is a virtual Primitive; it does not exist in the final executed graph. The inputs and
    outputs of _VirtualDataset are distributed in data parallel pattern, and tensor redistribution Primitives
    are inserted dynamically during the graph compilation process.

    Note:
        Only used in semi auto parallel and auto parallel mode.

    Args:
        backbone (Cell): The target network to wrap.

    Examples:
        >>> net = Net()
        >>> net = _VirtualDatasetCell(net)
    """

    def __init__(self, backbone):
        super(_VirtualDatasetCell, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._virtual_dataset = _VirtualDataset()

    def construct(self, *inputs):
        output = self._virtual_dataset(*inputs)
        return self._backbone(*output)


class _MicroBatch(Cell):
    """
    Transform a mini-batch into micro-batches for pipeline parallelism.

    Args:
        micro_size (int): The number of micro-batches.
    """

    def __init__(self, micro_size):
        super(_MicroBatch, self).__init__()
        self.shape = P.Shape()
        self.micro_size = micro_size
        self.strided_slice = P.StridedSlice()

    def construct(self, i, *inputs):
        """Slice the i-th micro-batch out of every input along the batch dimension."""
        micro_inputs = ()
        for each_input in inputs:
            input_shape = self.shape(each_input)
            micro_batch_begin = i * input_shape[0] // self.micro_size
            micro_batch_end = (i + 1) * input_shape[0] // self.micro_size
            strided_slice_begin = (micro_batch_begin,)
            strided_slice_strides = (1,)
            for _ in range(len(input_shape) - 1):
                strided_slice_begin += (0,)
                strided_slice_strides += (1,)
            strided_slice_end = (micro_batch_end,)
            strided_slice_end += input_shape[1:]
            micro_input = self.strided_slice(each_input, strided_slice_begin, strided_slice_end,
                                             strided_slice_strides)
            micro_inputs += (micro_input,)
        return micro_inputs
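

# Hedged illustration (not part of the original file): the slice bounds that
# _MicroBatch.construct computes are equivalent to this plain NumPy indexing.
# The shape and micro_size below are illustrative.
def _demo_micro_batch_slicing():
    """Sketch only: split an (8, 4) mini-batch into micro_size=4 micro-batches."""
    import numpy as np
    x = np.arange(32).reshape(8, 4)
    micro_size = 4
    slices = []
    for i in range(micro_size):
        begin = i * x.shape[0] // micro_size      # micro_batch_begin
        end = (i + 1) * x.shape[0] // micro_size  # micro_batch_end
        slices.append(x[begin:end])               # rows [begin, end), stride 1
    assert all(s.shape == (2, 4) for s in slices)
    return slices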


class MicroBatchInterleaved(Cell):
    """
    Wrap the network so that each mini-batch is split into `interleave_num` slices along the batch
    dimension; the network runs on each slice and the outputs are averaged.

    Args:
        network (Cell): The target network to wrap.
        interleave_num (int): The number of slices to split the batch size into. Default: 2.

    Examples:
        >>> net = Net()
        >>> net = MicroBatchInterleaved(net, 4)
    """

    def __init__(self, network, interleave_num=2):
        super(MicroBatchInterleaved, self).__init__(auto_prefix=False)
        self.network = network
        self.interleave_num = interleave_num
        self.interleave_inputs = nn.CellList()
        for _ in range(interleave_num):
            interleave_data = _MicroBatch(interleave_num)
            interleave_data.strided_slice.add_prim_attr("strided_slice_flag", True)
            self.interleave_inputs.append(interleave_data)

    def construct(self, *inputs):
        output = 0.0
        for i in range(self.interleave_num):
            interleave_input = self.interleave_inputs[i](i, *inputs)
            output += self.network(*interleave_input)
        return output / self.interleave_num
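

# Hedged illustration (not part of the original file): for a mean-style loss
# and equal-sized slices, averaging the per-slice results reproduces the
# full-batch result, which is why construct() divides by interleave_num.
def _demo_interleave_average():
    """Sketch only: per-slice means averaged over 2 slices equal the full mean."""
    import numpy as np
    batch = np.arange(8.0)
    full_mean = batch.mean()
    slice_means = [batch[0:4].mean(), batch[4:8].mean()]
    assert np.isclose(full_mean, sum(slice_means) / 2)
    return full_mean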


class PipelineCell(Cell):
    """
    Wrap the network so that each mini-batch is split into `micro_size` micro-batches for
    pipeline-parallel training; the per-micro-batch outputs are summed.

    Note:
        micro_size must be greater than or equal to the number of pipeline stages.

    Args:
        network (Cell): The target network to wrap.
        micro_size (int): The number of micro-batches.

    Examples:
        >>> net = Net()
        >>> net = PipelineCell(net, 4)
    """

    def __init__(self, network, micro_size):
        super(PipelineCell, self).__init__(auto_prefix=False)
        self.network = network
        self.micro_inputs = nn.CellList()
        self.micro_size = micro_size
        self.add_list = []
        for i in range(micro_size):
            micro_input = _MicroBatch(micro_size)
            self.micro_inputs.append(micro_input)
            self.add = P.Add().add_prim_attr("pipeline_end", i)
            self.add_list.append(self.add)

    def construct(self, *inputs):
        ret = None
        for i in range(self.micro_size):
            micro_input = self.micro_inputs[i](i, *inputs)
            output = self.network(*micro_input)
            if ret is not None:
                ret = self.add_list[i](ret, output)
            else:
                ret = output
        return ret


def _pipeline_clear_grad(accu_grad, grad):
    """Zero an accumulated gradient after the current gradient has been consumed."""
    accu_grad = F.depend(accu_grad, grad)  # ensure `grad` is produced before clearing
    zeros = F.tensor_mul(accu_grad, 0.0)
    return F.assign(accu_grad, zeros)


class _TrainPipelineAccuStepCell(TrainOneStepCell):
    """
    Wraps the network with an optimizer in pipeline mode.
    """

    def __init__(self, network, optimizer, sens=1.0):
        super(_TrainPipelineAccuStepCell, self).__init__(network, optimizer, sens)
        self.accu_grads = self.weights.clone(prefix="accu_grads", init="zeros")
        self.hyper_map = ops.HyperMap()
        self.opt_shard = _get_enable_parallel_optimizer()

    def construct(self, *inputs):
        weights = self.weights
        loss = self.network(*inputs)
        sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(*inputs, sens)
        accu_grads = ops.depend(self.accu_grads, grads)
        if self.opt_shard:
            succ = self.optimizer(grads)
        else:
            succ = self.optimizer(accu_grads)
        loss = ops.depend(loss, succ)
        # Clear the accumulated gradients only after the optimizer step has run.
        clear = self.hyper_map(_pipeline_clear_grad, accu_grads, grads)
        loss = ops.depend(loss, clear)
        return loss


class VirtualDatasetCellTriple(Cell):
    """
    Wrap the network with a virtual dataset to convert the data parallel layout to the model parallel layout.

    VirtualDatasetCellTriple is a virtual Primitive; it does not exist in the final executed graph. The inputs
    and outputs of VirtualDatasetCellTriple are distributed in data parallel pattern, and tensor redistribution
    Primitives are inserted dynamically during the graph compilation process.

    Note:
        Only used in semi auto parallel and auto parallel mode. It takes exactly three inputs, unlike
        _VirtualDatasetCell, which accepts a variable number of inputs.

    Args:
        backbone (Cell): The target network to wrap.

    Examples:
        >>> net = Net()
        >>> net = VirtualDatasetCellTriple(net)
    """

    def __init__(self, backbone):
        super(VirtualDatasetCellTriple, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._virtual_dataset = _VirtualDataset()

    def construct(self, a, b, c):
        a_, b_, c_ = self._virtual_dataset(a, b, c)
        return self._backbone(a_, b_, c_)


class WithEvalCell(Cell):
    r"""
    Cell that returns loss, output and label for evaluation.

    This Cell accepts a network and a loss function as arguments and computes the loss for the model.
    It returns loss, output and label for calculating the metrics.

    Args:
        network (Cell): The network Cell.
        loss_fn (Cell): The loss Cell.
        add_cast_fp32 (bool): Whether to cast the network output and label to float32. Default: False.

    Inputs:
        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.

    Outputs:
        Tuple, containing a scalar loss Tensor, a network output Tensor of shape :math:`(N, \ldots)`
        and a label Tensor of shape :math:`(N, \ldots)`.

    Raises:
        TypeError: If `add_cast_fp32` is not a bool.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> # For a defined network Net without loss function
        >>> net = Net()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
        >>> eval_net = nn.WithEvalCell(net, loss_fn)
    """

    def __init__(self, network, loss_fn, add_cast_fp32=False):
        super(WithEvalCell, self).__init__(auto_prefix=False)
        self._network = network
        self._loss_fn = loss_fn
        self.add_cast_fp32 = validator.check_value_type("add_cast_fp32", add_cast_fp32, [bool], self.cls_name)

    def construct(self, data, label):
        outputs = self._network(data)
        if self.add_cast_fp32:
            label = F.mixed_precision_cast(mstype.float32, label)
            outputs = F.cast(outputs, mstype.float32)
        loss = self._loss_fn(outputs, label)
        return loss, outputs, label


class ParameterUpdate(Cell):
    """
    Cell that updates a parameter.

    With this Cell, one can manually update `param` with the input `Tensor`.

    Args:
        param (Parameter): The parameter to be updated manually.

    Inputs:
        - **x** (Tensor) - A tensor whose shape and type are the same as `param`.

    Outputs:
        Tensor, the input `x`.

    Raises:
        TypeError: If `param` is not a Parameter.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> network = nn.Dense(3, 4)
        >>> param = network.parameters_dict()['weight']
        >>> update = nn.ParameterUpdate(param)
        >>> update.phase = "update_param"
        >>> weight = Tensor(np.arange(12).reshape((4, 3)), mindspore.float32)
        >>> output = update(weight)
    """

    def __init__(self, param):
        super(ParameterUpdate, self).__init__(auto_prefix=False)
        if not isinstance(param, Parameter):
            raise TypeError("For 'ParameterUpdate', 'param' must be 'Parameter', but got {}.".format(param))
        self._param = param

    def construct(self, x):
        F.assign(self._param, x)
        return x


class _BroadCastCell(Cell):
    """
    Broadcast the parameters from device 0 to the other devices.

    Args:
        params (list): The parameters of the Net.
    """

    def __init__(self, params):
        super(_BroadCastCell, self).__init__()
        from mindspore.communication.management import get_group_size, create_group
        from mindspore import context
        self.map_ = C.Map()
        self.params = tuple(params)
        if context.get_context("device_target") == "Ascend" and context.get_context("mode") != context.PYNATIVE_MODE:
            rank_list = list(range(get_group_size()))
            create_group("BroadcastWorldGroup", rank_list)
            self.broadcast = P.Broadcast(0, group="BroadcastWorldGroup")
        else:
            self.broadcast = P.Broadcast(0)

    def construct(self):
        # Save each parameter's dtype, broadcast in float32, then cast back.
        datatypes = self.map_(F.partial(_get_datatype), self.params)
        params = self.map_(F.partial(_cast_datatype, mstype.float32), self.params)
        params = self.broadcast(params)
        new_params = self.map_(F.partial(_cast_datatype), datatypes, params)
        return new_params
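

# Hedged usage sketch (not part of the original file): broadcast a network's
# parameters from rank 0 so every device starts from identical weights.
# `net` is an illustrative name; Parameter.set_data overwrites the local value.
def _demo_broadcast_params(net):
    """Sketch only: synchronize `net`'s trainable parameters from rank 0."""
    broadcast_cell = _BroadCastCell(net.trainable_params())
    new_params = broadcast_cell()
    for param, new_param in zip(net.trainable_params(), new_params):
        param.set_data(new_param)  # replace the local value with the rank-0 copy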