
test_gradient_accumulation.py

import os

import pytest

import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as CT
import mindspore.dataset.vision.c_transforms as CV
import mindspore.nn as nn
from mindspore import ParameterTuple
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.common.initializer import Normal
from mindspore.dataset.vision import Inter
from mindspore.nn import Cell
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.train.dataset_helper import DatasetHelper
from mindspore.train.serialization import save_checkpoint

_sum_op = C.MultitypeFuncGraph("grad_sum_op")
_clear_op = C.MultitypeFuncGraph("clear_op")

@_sum_op.register("Tensor", "Tensor")
def _cumulative_grad(grad_sum, grad):
    """Accumulate the incoming gradient into grad_sum."""
    add = P.AssignAdd()
    return add(grad_sum, grad)


@_clear_op.register("Tensor", "Tensor")
def _clear_grad_sum(grad_sum, zero):
    """Assign zeros to clear grad_sum."""
    success = True
    success = F.depend(success, F.assign(grad_sum, zero))
    return success
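
# `_sum_op` and `_clear_op` are MultitypeFuncGraph dispatchers: the
# implementations registered above are selected by argument types, and
# C.HyperMap applies them element-wise across tuples of parameters.
# Roughly, as a sketch:
#
#     hyper_map = C.HyperMap()
#     hyper_map(F.partial(_sum_op), grad_sum, grads)
#     # ~ AssignAdd(grad_sum[i], grads[i]) for every parameter i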

class LeNet5(nn.Cell):
    """
    LeNet network.

    Args:
        num_class (int): Num classes. Default: 10.
        num_channel (int): Num channels. Default: 1.

    Returns:
        Tensor, output tensor

    Examples:
        >>> LeNet5(num_class=10)
    """

    def __init__(self, num_class=10, num_channel=1):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.max_pool2d(self.relu(self.conv1(x)))
        x = self.max_pool2d(self.relu(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x
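
# Shape check for the 16 * 5 * 5 flatten size: a 1x32x32 input goes
# 32 -> 28 (conv1, 5x5 valid) -> 14 (pool) -> 10 (conv2) -> 5 (pool),
# leaving 16 channels of 5x5 feature maps, i.e. 400 inputs to fc1.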

class TrainForwardBackward(Cell):
    def __init__(self, network, optimizer, grad_sum, sens=1.0):
        super(TrainForwardBackward, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.network.add_flags(defer_inline=True)
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad_sum = grad_sum
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.hyper_map = C.HyperMap()

    def construct(self, *inputs):
        weights = self.weights
        loss = self.network(*inputs)
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(*inputs, sens)
        return F.depend(loss, self.hyper_map(F.partial(_sum_op), self.grad_sum, grads))
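
# This cell runs one forward/backward pass and only accumulates the gradients
# into grad_sum (via _sum_op above); it never calls the optimizer. F.depend
# ties the accumulation to the returned loss so the AssignAdd side effects are
# not pruned from the compiled graph.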

class TrainOptim(Cell):
    def __init__(self, optimizer, grad_sum):
        super(TrainOptim, self).__init__(auto_prefix=False)
        self.optimizer = optimizer
        self.grad_sum = grad_sum

    def construct(self):
        return self.optimizer(self.grad_sum)

class TrainClear(Cell):
    def __init__(self, grad_sum, zeros):
        super(TrainClear, self).__init__(auto_prefix=False)
        self.grad_sum = grad_sum
        self.zeros = zeros
        self.hyper_map = C.HyperMap()

    def construct(self):
        success = self.hyper_map(F.partial(_clear_op), self.grad_sum, self.zeros)
        return success
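
# Together the three cells split one optimizer step into three graphs:
# TrainForwardBackward accumulates gradients for each mini-batch, TrainOptim
# applies the accumulated sum once, and TrainClear resets the buffers to zero
# before the next accumulation window.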

class GradientAccumulation:
    def __init__(self, network, loss_fn, optimizer):
        self._network = network
        self._loss_fn = loss_fn
        self._optimizer = optimizer

        params = self._optimizer.parameters
        self._grad_sum = params.clone(prefix="grad_sum", init='zeros')
        self._zeros = params.clone(prefix="zeros", init='zeros')
        self._train_forward_backward = self._build_train_forward_backward_network()
        self._train_optim = self._build_train_optim()
        self._train_clear = self._build_train_clear()

    def _build_train_forward_backward_network(self):
        """Build forward and backward network"""
        network = self._network
        network = nn.WithLossCell(network, self._loss_fn)
        loss_scale = 1.0
        network = TrainForwardBackward(network, self._optimizer, self._grad_sum, loss_scale).set_train()
        return network

    def _build_train_optim(self):
        """Build optimizer network"""
        network = TrainOptim(self._optimizer, self._grad_sum).set_train()
        return network

    def _build_train_clear(self):
        """Build clear network"""
        network = TrainClear(self._grad_sum, self._zeros).set_train()
        return network
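
    # params.clone(prefix=..., init='zeros') creates one zero-initialized
    # Parameter per trainable weight, so grad_sum mirrors the optimizer's
    # parameter shapes and zeros provides matching reset values.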

    def train_process(self, epoch, train_dataset, mini_steps=None):
        """
        Training process. Data is passed to the network directly
        (dataset sink mode is disabled).
        """
        dataset_helper = DatasetHelper(train_dataset, dataset_sink_mode=False, epoch_num=epoch)
        for i in range(epoch):
            step = 0
            for k, next_element in enumerate(dataset_helper):
                loss = self._train_forward_backward(*next_element)
                # mini_steps must be provided by the caller (see the test below)
                if (k + 1) % mini_steps == 0:
                    step += 1
                    print("epoch:", i + 1, "step:", step, "loss is ", loss)
                    self._train_optim()
                    self._train_clear()
            train_dataset.reset()

        save_checkpoint(self._train_forward_backward, "gradient_accumulation.ckpt")
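
# Note on scale: with batch_size=32 and mini_steps=4 (the values used in the
# test below), each optimizer update sees gradients summed over 4 * 32 = 128
# samples. Since per-batch losses are mean-reduced but gradients are summed
# (not averaged) across mini-steps, an update is effectively mini_steps times
# larger than a single-batch step at the same learning rate.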

def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """
    Create dataset for train or test.
    """
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)

    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # define map operations
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)  # Bilinear mode
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = CT.TypeCast(mstype.int32)

    # apply map operations on images
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    # apply DatasetOps
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)  # 10000 as in LeNet train script
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds
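
# The two Rescale passes compose into the standard MNIST normalization:
# pixels are scaled to [0, 1], then shifted by mean 0.1307 and divided by
# std 0.3081, i.e. x -> (x / 255 - 0.1307) / 0.3081.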

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_gradient_accumulation():
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    ds_train = create_dataset(os.path.join("/home/workspace/mindspore_dataset/mnist", "train"), 32)

    network = LeNet5(10)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
    model = GradientAccumulation(network, net_loss, net_opt)

    print("============== Starting Training ==============")
    model.train_process(2, ds_train, mini_steps=4)
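
# Running this test requires an Ascend device and MNIST extracted under
# /home/workspace/mindspore_dataset/mnist/train; for GPU/CPU runs the
# device_target above would need to change. A sketch of a local invocation:
#
#     pytest test_gradient_accumulation.py::test_gradient_accumulation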