cell_wrapper.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Cell_wrapper."""
from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
                                       _get_parallel_mode)
from mindspore.context import ParallelMode
from ...common import dtype as mstype
from ...common.parameter import Parameter, ParameterTuple
from ...ops import composite as C
from ...ops import functional as F
from ...ops import operations as P
from ...ops.operations.comm_ops import _VirtualDataset
from ..cell import Cell
from .grad_reducer import DistributedGradReducer

class WithLossCell(Cell):
    r"""
    Cell with loss function.

    Wraps the network with loss function. This Cell accepts data and label as inputs and
    the computed loss will be returned.

    Args:
        backbone (Cell): The target network to wrap.
        loss_fn (Cell): The loss function used to compute loss.

    Inputs:
        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.

    Outputs:
        Tensor, a scalar tensor with shape :math:`()`.

    Examples:
        >>> net = Net()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
        >>> net_with_criterion = nn.WithLossCell(net, loss_fn)
        >>>
        >>> batch_size = 2
        >>> data = Tensor(np.ones([batch_size, 3, 64, 64]).astype(np.float32) * 0.01)
        >>> label = Tensor(np.ones([batch_size, 1, 1, 1]).astype(np.int32))
        >>>
        >>> net_with_criterion(data, label)
    """

    def __init__(self, backbone, loss_fn):
        super(WithLossCell, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._loss_fn = loss_fn

    def construct(self, data, label):
        out = self._backbone(data)
        return self._loss_fn(out, label)

    @property
    def backbone_network(self):
        """
        Returns the backbone network.

        Returns:
            Cell, the backbone network.
        """
        return self._backbone
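
# A minimal sketch of recovering the wrapped backbone through the
# `backbone_network` property, assuming a hypothetical user-defined `Net`:
#
#     net = Net()
#     loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
#     net_with_loss = nn.WithLossCell(net, loss_fn)
#     assert net_with_loss.backbone_network is net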

class WithGradCell(Cell):
    r"""
    Cell that returns the gradients.

    Wraps the network with a backward cell to compute gradients. A network with a loss
    function is necessary as argument. If the loss function is None, the network must be
    a wrapper of a network and a loss function. This Cell accepts '*inputs' as inputs and
    returns gradients for each trainable parameter.

    Note:
        Run in PyNative mode.

    Args:
        network (Cell): The target network to wrap. The network only supports single output.
        loss_fn (Cell): Primitive loss function used to compute gradients. Default: None.
        sens (Union[None, Tensor, Scalar, Tuple ...]): The sensitivity (gradient with respect
            to the output) used for backpropagation. Its type and shape must be the same as
            the `network` output. If None, a tensor of ones with the same type and shape as
            the output value is used. Default: None.

    Inputs:
        - **(*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.

    Outputs:
        list, a list of Tensors with identical shapes as trainable weights.

    Examples:
        >>> # For a defined network Net without loss function
        >>> net = Net()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
        >>> grad_net = nn.WithGradCell(net, loss_fn)
        >>>
        >>> # For a network wrapped with loss function
        >>> net = Net()
        >>> net_with_criterion = nn.WithLossCell(net, loss_fn)
        >>> grad_net = nn.WithGradCell(net_with_criterion)
    """

    def __init__(self, network, loss_fn=None, sens=None):
        super(WithGradCell, self).__init__(auto_prefix=False)
        self.network = network
        self.loss_fn = loss_fn
        self.weights = ParameterTuple(network.trainable_params())
        self.grad = C.GradOperation(get_by_list=True, sens_param=(sens is not None))
        self.sens = sens
        if loss_fn is None:
            self.network_with_loss = network
        else:
            self.network_with_loss = WithLossCell(self.network, self.loss_fn)
        self.network_with_loss.set_train()

    def construct(self, *inputs):
        weights = self.weights
        if self.sens is None:
            grads = self.grad(self.network_with_loss, weights)(*inputs)
        else:
            grads = self.grad(self.network_with_loss, weights)(*inputs, self.sens)
        return grads
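
# A minimal sketch of consuming the gradients returned by WithGradCell, assuming
# a hypothetical `Net` and `data`/`label` Tensors. The result holds one gradient
# per trainable parameter, in the order of `net.trainable_params()`:
#
#     net = Net()
#     grad_net = nn.WithGradCell(net, nn.SoftmaxCrossEntropyWithLogits())
#     grads = grad_net(data, label)
#     for param, grad in zip(net.trainable_params(), grads):
#         print(param.name, grad.shape)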

class TrainOneStepCell(Cell):
    r"""
    Network training package class.

    Wraps the network with an optimizer. The resulting Cell is trained with input *inputs.
    The backward graph will be created in the construct function to update the parameters.
    Different parallel modes are available for training.

    Args:
        network (Cell): The training network. The network only supports single output.
        optimizer (Cell): Optimizer for updating the weights.
        sens (Number): The scaling number to be filled as the input of backpropagation.
            Default value is 1.0.

    Inputs:
        - **(*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.

    Outputs:
        Tensor, a scalar Tensor with shape :math:`()`.

    Examples:
        >>> net = Net()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
        >>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> # 1) Using the provided WithLossCell
        >>> loss_net = nn.WithLossCell(net, loss_fn)
        >>> train_net = nn.TrainOneStepCell(loss_net, optim)
        >>>
        >>> # 2) Using a user-defined WithLossCell
        >>> class MyWithLossCell(nn.Cell):
        >>>     def __init__(self, backbone, loss_fn):
        >>>         super(MyWithLossCell, self).__init__(auto_prefix=False)
        >>>         self._backbone = backbone
        >>>         self._loss_fn = loss_fn
        >>>
        >>>     def construct(self, x, y, label):
        >>>         out = self._backbone(x, y)
        >>>         return self._loss_fn(out, label)
        >>>
        >>> loss_net = MyWithLossCell(net, loss_fn)
        >>> train_net = nn.TrainOneStepCell(loss_net, optim)
    """

    def __init__(self, network, optimizer, sens=1.0):
        super(TrainOneStepCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.network.add_flags(defer_inline=True)
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.reducer_flag = False
        self.grad_reducer = F.identity
        self.parallel_mode = _get_parallel_mode()
        if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
            self.reducer_flag = True
        if self.reducer_flag:
            mean = _get_gradients_mean()
            degree = _get_device_num()
            self.grad_reducer = DistributedGradReducer(self.weights, mean, degree)

    def construct(self, *inputs):
        weights = self.weights
        loss = self.network(*inputs)
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(*inputs, sens)
        grads = self.grad_reducer(grads)
        return F.depend(loss, self.optimizer(grads))
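
# A minimal training-loop sketch around TrainOneStepCell, assuming a hypothetical
# `Net` and dataset iterator. Each call runs forward and backward (the loss
# sensitivity filled with `sens`), all-reduces gradients when a data parallel
# mode is active, applies one optimizer step, and returns the loss:
#
#     net = Net()
#     loss_net = nn.WithLossCell(net, nn.SoftmaxCrossEntropyWithLogits())
#     optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
#     train_net = nn.TrainOneStepCell(loss_net, optim)
#     train_net.set_train()
#     for data, label in dataset:
#         loss = train_net(data, label)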

class GetNextSingleOp(Cell):
    """
    Cell that runs the GetNext operation to fetch data from the dataset queue.

    Args:
        dataset_types (list[:class:`mindspore.dtype`]): The types of dataset.
        dataset_shapes (list[tuple[int]]): The shapes of dataset.
        queue_name (str): Queue name to fetch the data.

    For detailed information, refer to `ops.operations.GetNext`.
    """

    def __init__(self, dataset_types, dataset_shapes, queue_name):
        super(GetNextSingleOp, self).__init__()
        self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)

    def construct(self):
        return self.get_next()
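
# A hedged sketch of GetNextSingleOp; the types, shapes, and the "train_queue"
# name are illustrative and must match a queue created by an existing dataset
# pipeline:
#
#     get_next = GetNextSingleOp([mstype.float32, mstype.int32],
#                                [(2, 3, 64, 64), (2,)], "train_queue")
#     data, label = get_next()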

class _VirtualDatasetCell(Cell):
    """
    Wrap the network with a virtual dataset to convert the data parallel layout to the
    model parallel layout.

    _VirtualDataset is a virtual Primitive; it does not exist in the final executing graph.
    Inputs and outputs of _VirtualDataset are distributed in data parallel pattern, and
    tensor redistribution Primitives are inserted dynamically during the graph compile
    process.

    Note:
        Only used in semi auto parallel and auto parallel mode.

    Args:
        backbone (Cell): The target network to wrap.

    Examples:
        >>> net = Net()
        >>> net = _VirtualDatasetCell(net)
    """

    def __init__(self, backbone):
        super(_VirtualDatasetCell, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._virtual_dataset = _VirtualDataset()

    def construct(self, data, label):
        data_, label_ = self._virtual_dataset(data, label)
        return self._backbone(data_, label_)

class VirtualDatasetCellTriple(Cell):
    """
    Wrap the network with a virtual dataset to convert the data parallel layout to the
    model parallel layout.

    VirtualDatasetCellTriple is a virtual Primitive; it does not exist in the final
    executing graph. Inputs and outputs of VirtualDatasetCellTriple are distributed in
    data parallel pattern, and tensor redistribution Primitives are inserted dynamically
    during the graph compile process.

    Note:
        Only used in semi auto parallel and auto parallel mode. There are three inputs,
        as opposed to the two inputs of _VirtualDatasetCell.

    Args:
        backbone (Cell): The target network to wrap.

    Examples:
        >>> net = Net()
        >>> net = VirtualDatasetCellTriple(net)
    """

    def __init__(self, backbone):
        super(VirtualDatasetCellTriple, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._virtual_dataset = _VirtualDataset()

    def construct(self, a, b, c):
        a_, b_, c_ = self._virtual_dataset(a, b, c)
        return self._backbone(a_, b_, c_)

class WithEvalCell(Cell):
    r"""
    Cell that returns loss, output and label for evaluation.

    This Cell accepts a network and loss function as arguments and computes loss for model.
    It returns loss, output and label to calculate the metrics.

    Args:
        network (Cell): The network Cell.
        loss_fn (Cell): The loss Cell.
        add_cast_fp32 (bool): Whether to cast the network output and label to float32
            before computing the loss. Default: False.

    Inputs:
        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.

    Outputs:
        Tuple, containing a scalar loss Tensor, a network output Tensor of shape
        :math:`(N, \ldots)` and a label Tensor of shape :math:`(N, \ldots)`.

    Examples:
        >>> # For a defined network Net without loss function
        >>> net = Net()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
        >>> eval_net = nn.WithEvalCell(net, loss_fn)
    """

    def __init__(self, network, loss_fn, add_cast_fp32=False):
        super(WithEvalCell, self).__init__(auto_prefix=False)
        self._network = network
        self._loss_fn = loss_fn
        self.add_cast_fp32 = add_cast_fp32

    def construct(self, data, label):
        outputs = self._network(data)
        if self.add_cast_fp32:
            label = F.mixed_precision_cast(mstype.float32, label)
            outputs = F.cast(outputs, mstype.float32)
        loss = self._loss_fn(outputs, label)
        return loss, outputs, label
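
# A minimal evaluation-step sketch with WithEvalCell, assuming a hypothetical
# `Net`, `data`/`label` Tensors, and a metric object. The returned triple feeds
# metric updates directly, with outputs and label cast to float32 when
# `add_cast_fp32` is True:
#
#     eval_net = nn.WithEvalCell(Net(), nn.SoftmaxCrossEntropyWithLogits(),
#                                add_cast_fp32=True)
#     eval_net.set_train(False)
#     loss, outputs, label = eval_net(data, label)
#     metric.update(outputs, label)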

class ParameterUpdate(Cell):
    """
    Cell that updates parameters.

    With this Cell, one can manually update `param` with the input `Tensor`.

    Args:
        param (Parameter): The parameter to be updated manually.

    Raises:
        TypeError: If `param` is not an instance of Parameter.

    Examples:
        >>> network = Net()
        >>> param = network.parameters_dict()['learning_rate']
        >>> update = nn.ParameterUpdate(param)
        >>> update.phase = "update_param"
        >>> lr = Tensor(0.001, mindspore.float32)
        >>> update(lr)
    """

    def __init__(self, param):
        super(ParameterUpdate, self).__init__(auto_prefix=False)
        if not isinstance(param, Parameter):
            raise TypeError("`param` must be `Parameter`, but got {}".format(param))
        self._param = param

    def construct(self, x):
        F.assign(self._param, x)
        return x