
test_grad_accumulation.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore as ms
import mindspore.common.dtype as mstype
from mindspore import context, Tensor, Parameter
from mindspore.nn import Cell, Momentum, Norm
from mindspore.train import Model
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.initializer import initializer
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
# required by the reducer branch in TrainAccumulateStepsWithLossScaleCell
from mindspore.communication.management import get_group_size
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData


class Dataset(MindData):
    def __init__(self, predict, label, length=3):
        super(Dataset, self).__init__(size=length)
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.predict, self.label

    def reset(self):
        self.index = 0
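

# Dataset is a minimal in-memory mock: it yields the same (predict, label) pair `length`
# times per epoch, which is enough for Model.train to drive graph compilation in this UT.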


get_square_sum = C.MultitypeFuncGraph("get_square_sum")


@get_square_sum.register("Tensor")
def _get_square_sum(grad):
    norm = P.ReduceSum(False)(F.square(grad), ())
    norm = F.expand_dims(F.cast(norm, mstype.float32), 0)
    return norm


apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")


@apply_global_norm.register("Tensor", "Tensor", "Tensor")
def _apply_global_norm(clip_norm, global_norm, grad):
    grad = grad * clip_norm / global_norm
    return grad


class GlobalNorm(Cell):
    """
    Calculate the global norm value of given tensors
    """
    def __init__(self):
        super(GlobalNorm, self).__init__()
        self.norm = Norm()
        self.hyper_map = C.HyperMap()

    def construct(self, grads):
        square_sum = self.hyper_map(get_square_sum, grads)
        global_norms = F.sqrt(F.addn(square_sum) / F.scalar_to_array(len(square_sum)))
        return global_norms


class ClipByGlobalNorm(Cell):
    """
    Clip grads by global norm
    """
    def __init__(self, clip_norm=1.0):
        super(ClipByGlobalNorm, self).__init__()
        self.global_norm = GlobalNorm()
        self.clip_norm = Tensor([clip_norm], mstype.float32)
        self.hyper_map = C.HyperMap()

    def construct(self, grads):
        global_norm = self.global_norm(grads)
        cond = P.GreaterEqual()(global_norm, self.clip_norm)
        global_norm = F.select(cond, global_norm, self.clip_norm)
        grads = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), grads)
        return grads
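

# GlobalNorm above averages the squared norms before the square root, i.e. it computes
# sqrt(sum_i ||g_i||^2 / N) over the N gradient tensors rather than the plain
# sqrt(sum_i ||g_i||^2). ClipByGlobalNorm then rescales every gradient by
# clip_norm / max(global_norm, clip_norm), leaving the gradients unchanged whenever the
# computed norm is already below clip_norm. A rough NumPy sketch of the same arithmetic
# (illustrative only, not used by the test):
#
#     def np_clip_by_global_norm(grads, clip_norm=1.0):
#         global_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads) / len(grads))
#         return [g * clip_norm / max(global_norm, clip_norm) for g in grads]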


cast = P.Cast()
update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")


@update_accu_grads.register("Tensor", "Tensor")
def _update_accu_grads(accu_grad, grad):
    succ = True
    return F.depend(succ, F.assign_add(accu_grad, cast(grad, mstype.float32)))


zeroslike = P.ZerosLike()
reset_accu_grads = C.MultitypeFuncGraph("reset_accu_grads")


@reset_accu_grads.register("Tensor")
def _reset_accu_grads(accu_grad):
    succ = True
    return F.depend(succ, F.assign(accu_grad, zeroslike(accu_grad)))


grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()


@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    return grad * reciprocal(scale)
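

# tensor_grad_scale divides each gradient by the scale passed in from construct(), which
# there is loss_scale * degree * accumulation_steps, undoing the loss scaling, the
# reduced-gradient sum across devices and the N-step accumulation in one multiplication.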


class TrainAccumulateStepsWithLossScaleCell(Cell):
    """
    Encapsulation class of bert network training.

    Append an optimizer to the training network. After that, the construct function can be
    called to create the backward graph. To mimic a higher batch size, gradients are
    accumulated N times before the weight update.

    Args:
        network (Cell): The training network. Note that the loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        scale_update_cell (Cell): Cell to do the loss scale. Default: None.
        accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
            batch_size * accumulation_steps. Default: 4.
    """
    def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=4):
        super(TrainAccumulateStepsWithLossScaleCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.accumulation_steps = accumulation_steps
        self.one = Tensor(np.array([1]).astype(np.int32))
        self.zero = Tensor(np.array([0]).astype(np.int32))
        self.local_step = Parameter(initializer(0, [1], mstype.int32), name="local_step")
        self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros')
        self.accu_overflow = Parameter(initializer(0, [1], mstype.int32))
        self.accu_loss = Parameter(initializer(0, [1], mstype.float32))
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.reducer_flag = False
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = F.identity
        self.degree = 1
        if self.reducer_flag:
            self.degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.overflow_reducer = F.identity
        if self.is_distributed:
            self.overflow_reducer = P.AllReduce()
        self.cast = P.Cast()
        self.alloc_status = P.NPUAllocFloatStatus()
        self.get_status = P.NPUGetFloatStatus()
        self.clear_before_grad = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.logical_or = P.LogicalOr()
        self.not_equal = P.NotEqual()
        self.select = P.Select()
        self.reshape = P.Reshape()
        self.hyper_map = C.HyperMap()
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))

    @C.add_flags(has_effect=True)
    def construct(self, x, b, sens=None):
        """Defines the computation performed."""
        weights = self.weights
        loss = self.network(x, b)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens
        # update accumulation parameters
        is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
        self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one)
        self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss)
        mean_loss = self.accu_loss / self.local_step
        is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
        # alloc status and clear should be right before the grad operation
        init = self.alloc_status()
        self.clear_before_grad(init)
        grads = self.grad(self.network, weights)(x, b, self.cast(scaling_sens, mstype.float32))
        accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads)
        mean_loss = F.depend(mean_loss, accu_succ)
        self.get_status(init)
        flag_sum = self.reduce_sum(init, (0,))
        overflow = self.less_equal(self.base, flag_sum)
        overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow)
        accu_overflow = self.select(overflow, self.one, self.zero)
        self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero)
        is_accu_step = self.reshape(is_accu_step, (()))
        if is_accu_step:
            succ = False
        else:
            # apply grad reducer on grads
            grads = self.grad_reducer(self.accu_grads)
            scaling = scaling_sens * self.degree * self.accumulation_steps
            grads = self.hyper_map(F.partial(grad_scale, scaling), grads)
            grads = ClipByGlobalNorm()(grads)
            accu_overflow = self.overflow_reducer(accu_overflow)
            F.control_depend(grads, accu_overflow)
            overflow = self.less_equal(self.base, accu_overflow)
            accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads)
            overflow = F.depend(overflow, accu_succ)
            overflow = self.reshape(overflow, (()))
            if sens is None:
                overflow = self.loss_scaling_manager(self.loss_scale, overflow)
            if overflow:
                succ = False
            else:
                succ = self.optimizer(grads)
        ret = (mean_loss, overflow, scaling_sens)
        return F.depend(ret, succ)
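

# Per call, construct() accumulates the scaled gradients into accu_grads and folds any
# float-status overflow into accu_overflow. Only on every accumulation_steps-th step are
# the accumulated gradients passed through grad_reducer, divided by
# loss_scale * degree * accumulation_steps, clipped by global norm and handed to the
# optimizer, after which accu_grads is reset to zeros. If an overflow was detected, the
# optimizer step is skipped and, when no explicit sens is given, the dynamic loss scale
# is updated through scale_update_cell.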


class Net(Cell):
    def __init__(self, weight, strategy=None):
        super().__init__()
        self.mul = P.Mul().shard(strategy)
        self.weight = Parameter(weight, "w1")
        self.relu = P.ReLU()
        self.reduce_sum = P.ReduceSum(keep_dims=True)

    def construct(self, x, b):
        out = self.mul(x, self.weight)
        out = self.relu(out)
        out = self.reduce_sum(out)
        return out


# x, b and w1 share the same 1-D shape so that the element-wise Mul in Net and the
# ((2,), (2,)) shard strategy in the test line up.
_x = Tensor(np.ones([16]), dtype=ms.float32)
_b = Tensor(np.ones([16]), dtype=ms.float32)
_w1 = Tensor(np.ones([16]), dtype=ms.float32)


def compile_net(net, grad_accumulation_step):
    context.set_context(save_graphs=True)
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2
    dataset = Dataset(_x, _b)
    opt = Momentum(net.trainable_params(), learning_rate, momentum)
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=65536, scale_factor=2, scale_window=1000)
    net_wrap = TrainAccumulateStepsWithLossScaleCell(net, opt, scale_update_cell=update_cell,
                                                     accumulation_steps=grad_accumulation_step)
    model = Model(net_wrap)
    model.train(epoch_size, dataset, dataset_sink_mode=False)
    context.reset_auto_parallel_context()


def test_grad_accumulation():
    grad_accumulation_step = 4
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0,
                                      grad_accumulation_step=grad_accumulation_step)
    strategy = ((2,), (2,))
    net = Net(_w1, strategy)
    compile_net(net, grad_accumulation_step)
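

# In the MindSpore UT suite this file is collected by pytest (e.g. `pytest -sv
# test_grad_accumulation.py`). The guard below is only a convenience for running it
# directly and assumes a backend that supports the NPU float-status and parallel
# primitives used above.
if __name__ == "__main__":
    test_grad_accumulation()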