
loss_scale.py 14 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Loss scale cell for loss scale training."""
import mindspore.context as context
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
from ..cell import Cell
from ...common import Tensor, RowTensor
from ...common.parameter import Parameter
from ...ops import functional as F
from ...ops import composite as C
from ...ops import operations as P
from ...ops.operations import NPUGetFloatStatus, NPUAllocFloatStatus, NPUClearFloatStatus, ReduceSum, LessEqual, \
    ControlDepend
from ...common import dtype as mstype

_grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()


@_grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    return grad * F.cast(reciprocal(scale), F.dtype(grad))


@_grad_scale.register("Tensor", "RowTensor")
def tensor_grad_scale_row_tensor(scale, grad):
    return RowTensor(grad.indices,
                     grad.values * F.cast(reciprocal(scale), F.dtype(grad.values)),
                     grad.dense_shape)


_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
grad_overflow = P.FloatStatus()


@_grad_overflow.register("Tensor")
def _tensor_grad_overflow(grad):
    return grad_overflow(grad)
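

# Illustrative note, not part of the original file: the MultitypeFuncGraph helpers above
# undo the loss scaling after backpropagation and detect overflow in the gradients.
# For example, with a loss scale of 1024.0, a gradient of 2048.0 computed on the scaled
# loss is mapped back to 2048.0 * (1 / 1024.0) = 2.0. A host-side NumPy sketch of the
# same idea (names are assumptions used only for illustration):
#
#   import numpy as np
#   scale = np.float32(1024.0)
#   scaled_grad = np.float32(2048.0)
#   unscaled_grad = scaled_grad * (np.float32(1.0) / scale)   # -> 2.0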


class DynamicLossScaleUpdateCell(Cell):
    r"""
    Dynamic loss scale update cell.

    For loss scaling training, the initial loss scaling value is set to `loss_scale_value`.
    In each training step, the loss scaling value is updated to loss scaling value / `scale_factor`
    when there is an overflow, and to loss scaling value * `scale_factor` when there has been no
    overflow for `scale_window` consecutive steps. This cell is used for graph mode training, in which
    all logic is executed on the device side (the other training mode is the normal, non-sink mode,
    in which some logic is executed on the host).

    Args:
        loss_scale_value (float): Initial loss scale.
        scale_factor (int): Coefficient of increase and decrease.
        scale_window (int): Maximum number of consecutive training steps without overflow.

    Inputs:
        - **inputs** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.

    Outputs:
        Tensor, a scalar Tensor with shape :math:`()`.

    Examples:
        >>> net_with_loss = Net()
        >>> optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
        >>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
        >>> train_network.set_train()
        >>>
        >>> inputs = Tensor(np.ones([16, 16]).astype(np.float32))
        >>> label = Tensor(np.zeros([16, 16]).astype(np.float32))
        >>> output = train_network(inputs, label)
    """

    def __init__(self,
                 loss_scale_value,
                 scale_factor,
                 scale_window):
        super(DynamicLossScaleUpdateCell, self).__init__()

        self.scale_window = Tensor(scale_window, dtype=mstype.int32)
        self.scale_factor = Tensor(scale_factor, dtype=mstype.float32)
        self.loss_scale_value = loss_scale_value

        self.cur_iter = Parameter(Tensor(1, dtype=mstype.int32), name="current_iterator_step")
        self.last_overflow_iter = Parameter(Tensor(0, dtype=mstype.int32), name="last_overflow_iterator_step")
        self.select = P.Select()
        self.max = P.Maximum()
        self.minimum_loss_scale = Tensor(1.0, dtype=mstype.float32)
        self.reciprocal = P.Reciprocal()
        self.less_equal = P.LessEqual()
        self.logic_and = P.LogicalAnd()
        self.logic_not = P.LogicalNot()
        self.logic_or = P.LogicalOr()
        self.const_true = Tensor(True, dtype=mstype.bool_)

    def get_loss_scale(self):
        return self.loss_scale_value

    def construct(self, loss_scale, overflow):
        overflow_cond = overflow
        # on overflow, decrease the loss scale by `scale_factor`, but never below the minimum
        loss_scale_on_overflow = self.select(overflow_cond, self.max(loss_scale * self.reciprocal(self.scale_factor),
                                                                     self.minimum_loss_scale), loss_scale)
        # increase the loss scale once `scale_window` steps have passed since the last overflow/increase
        should_inc = self.less_equal(self.scale_window, self.cur_iter - self.last_overflow_iter)
        last_iter_cond = self.logic_or(overflow_cond, should_inc)
        last_overflow_iter = self.select(last_iter_cond, self.cur_iter, self.last_overflow_iter)
        assign_last_iter = F.assign(self.last_overflow_iter, last_overflow_iter)
        update_scale_cond = self.logic_and(should_inc, self.logic_not(overflow_cond))
        scale_mul_res = loss_scale_on_overflow * self.scale_factor
        scaled_loss_scale = self.select(update_scale_cond, scale_mul_res, loss_scale_on_overflow)
        assign_scaled_loss_scale = F.assign(loss_scale, scaled_loss_scale)
        inc_cur_iter = self.cur_iter + 1
        assign_cur_iter = F.assign(self.cur_iter, inc_cur_iter)
        t = (assign_last_iter, assign_scaled_loss_scale, assign_cur_iter)
        F.control_depend(assign_last_iter, assign_cur_iter)
        return F.depend(overflow, t)
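

# Illustrative note, not part of the original file: with the values used in the docstring
# example (loss_scale_value=2**12, scale_factor=2, scale_window=1000), the update rule in
# DynamicLossScaleUpdateCell.construct behaves as follows:
#   - on an overflow step:                            scale <- max(scale / 2, 1.0)
#   - after 1000 consecutive steps without overflow:  scale <- scale * 2
# A host-side sketch of the same rule in plain Python (names are illustrative only):
#
#   def update_scale(scale, overflow, steps_since_event, scale_factor=2.0,
#                    scale_window=1000, minimum=1.0):
#       if overflow:
#           return max(scale / scale_factor, minimum), 0
#       if steps_since_event + 1 >= scale_window:
#           return scale * scale_factor, 0
#       return scale, steps_since_event + 1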


class FixedLossScaleUpdateCell(Cell):
    """
    Static loss scale update cell; the loss scaling value is never updated.

    For usage, refer to `DynamicLossScaleUpdateCell`.

    Args:
        loss_scale_value (float): Initial loss scale.

    Examples:
        >>> net_with_loss = Net()
        >>> optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> manager = nn.FixedLossScaleUpdateCell(loss_scale_value=2**12)
        >>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
        >>> train_network.set_train()
        >>>
        >>> inputs = Tensor(np.ones([16, 16]).astype(np.float32))
        >>> label = Tensor(np.zeros([16, 16]).astype(np.float32))
        >>> output = train_network(inputs, label)
    """

    def __init__(self, loss_scale_value):
        super(FixedLossScaleUpdateCell, self).__init__()
        self.loss_scale_value = loss_scale_value

    def get_loss_scale(self):
        return self.loss_scale_value

    def construct(self, _, overflow):
        return overflow


class TrainOneStepWithLossScaleCell(Cell):
    r"""
    Network training with loss scaling.

    This is a training step with loss scaling. It takes a network, an optimizer and, optionally, a scale update
    cell as arguments. The loss scale value can be updated on either the host side or the device side.
    TrainOneStepWithLossScaleCell is compiled into a graph that takes `*inputs` as input data.
    A Tensor-type `scale_sense` acts directly as the loss scaling value; provide it if you want to update the
    value on the host side. If a Tensor-type `scale_sense` is not given, the loss scale update logic must be
    provided by a Cell-type `scale_sense`.

    Args:
        network (Cell): The training network. The network only supports single output.
        optimizer (Cell): Optimizer for updating the weights.
        scale_sense (Union[Tensor, Cell]): If this value is a Cell, it is the loss scaling update logic cell.
            If this value is a Tensor, it is a Tensor with shape :math:`()`. Default: None.

    Inputs:
        - **(*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.

    Outputs:
        Tuple of 3 Tensors: the loss, the overflow flag and the current loss scaling value.

        - **loss** (Tensor) - Tensor with shape :math:`()`.
        - **overflow** (Tensor) - Tensor with shape :math:`()`, type is bool.
        - **loss scaling value** (Tensor) - Tensor with shape :math:`()`.

    Examples:
        >>> net_with_loss = Net()
        >>> optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
        >>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
        >>> train_network.set_train()
        >>>
        >>> inputs = Tensor(np.ones([16, 16]).astype(np.float32))
        >>> label = Tensor(np.zeros([16, 16]).astype(np.float32))
        >>> output = train_network(inputs, label)
    """

    def __init__(self, network, optimizer, scale_sense=None):
        super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.network.add_flags(defer_inline=True)
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.hyper_map = C.HyperMap()
        # overflow detection: the GPU target checks the gradients with FloatStatus,
        # other targets use the NPU float status operators
        if context.get_context("device_target") == "GPU":
            self.gpu_target = True
            self.float_status = P.FloatStatus()
            self.addn = P.AddN()
            self.reshape = P.Reshape()
        else:
            self.gpu_target = False
            self.alloc_status = NPUAllocFloatStatus()
            self.get_status = NPUGetFloatStatus()
            self.clear_status = NPUClearFloatStatus()
            self.reduce_sum = ReduceSum(keep_dims=False)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = LessEqual()
        self.depend_parameter_use = ControlDepend(depend_mode=1)
        self.allreduce = P.AllReduce()
        self.parallel_mode = _get_parallel_mode()
        self.grad_reducer = F.identity
        self.reducer_flag = self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]
        if self.reducer_flag:
            mean = _get_gradients_mean()
            degree = _get_device_num()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
        self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE

        self.scale_sense = None
        self.loss_scaling_manager = None
        if isinstance(scale_sense, Cell):
            self.loss_scaling_manager = scale_sense
            self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32),
                                         name="scale_sense")
        if isinstance(scale_sense, Tensor):
            self.scale_sense = Parameter(scale_sense, name='scale_sense')

    @C.add_flags(has_effect=True)
    def construct(self, *inputs):
        weights = self.weights
        loss = self.network(*inputs)
        init = False
        if not self.gpu_target:
            # init overflow buffer
            init = self.alloc_status()
            # clear overflow buffer
            self.clear_status(init)

        scaling_sens = self.scale_sense
        scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
        grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
        grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)

        # get the overflow buffer
        if not self.gpu_target:
            self.get_status(init)
            # sum overflow buffer elements, 0: not overflow, >0: overflow
            flag_sum = self.reduce_sum(init, (0,))
        else:
            flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
            flag_sum = self.addn(flag_sum)
            # convert flag_sum to scalar
            flag_sum = self.reshape(flag_sum, (()))

        if self.is_distributed:
            # sum overflow flag over devices
            flag_reduce = self.allreduce(flag_sum)
            cond = self.less_equal(self.base, flag_reduce)
        else:
            cond = self.less_equal(self.base, flag_sum)

        overflow = cond
        if self.loss_scaling_manager is not None:
            overflow = self.loss_scaling_manager(self.scale_sense, cond)
        # if there is no overflow, do optimize
        if overflow:
            opt = False
        else:
            opt = self.optimizer(grads)
        ret = (loss, cond, scaling_sens)
        return F.depend(ret, opt)

    def set_sense_scale(self, sens):
        """
        If the loss scale value was set during training and needs to be reassigned, this function can be
        called to make the modification; `sens` must be a Tensor.
        """
        if self.scale_sense and isinstance(sens, Tensor):
            self.scale_sense.set_data(sens)
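

# Illustrative usage note, not part of the original file: when `scale_sense` is given as a
# Tensor, the loss scale is held as a Parameter and can be re-assigned from the host side
# between training steps via set_sense_scale. A hedged sketch, reusing the names from the
# docstring examples above; the halving policy below is purely illustrative:
#
#   scaling_sens = Tensor(np.full((1), 2.0 ** 12), dtype=mindspore.float32)
#   train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=scaling_sens)
#   train_network.set_train()
#   loss, overflow, sens = train_network(inputs, label)
#   if overflow:
#       train_network.set_sense_scale(Tensor(sens.asnumpy() / 2, dtype=mindspore.float32))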