You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_loss_scale.py 9.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """ test_loss_scale """
  16. import numpy as np
  17. import pytest
  18. import mindspore.nn as nn
  19. from mindspore import context
  20. from mindspore import Tensor, Parameter
  21. from mindspore.nn.wrap.cell_wrapper import WithLossCell
  22. from mindspore.nn.wrap.loss_scale import TrainOneStepWithLossScaleCell
  23. from mindspore.ops import operations as P
  24. from mindspore.nn.optim import Momentum, RMSProp
  25. from mindspore.ops import functional as F
  26. from mindspore.common import dtype as mstype
  27. from mindspore.train import Model
  28. from mindspore.nn.optim import Lamb
  29. from mindspore.train.loss_scale_manager import DynamicLossScaleManager
  30. def setup_module():
  31. context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
  32. class MindData:
  33. """ Stub for MindData """
  34. def __init__(self, size=None, batch_size=None, repeat_count=1,
  35. np_types=None, output_shapes=None, input_indexes=(), func_name=''):
  36. self._size = size
  37. self._batch_size = batch_size
  38. self._repeat_count = repeat_count
  39. self._np_types = np_types
  40. self._output_shapes = output_shapes
  41. self._input_indexes = input_indexes
  42. self._func_name = func_name
  43. self._iter_num = 0
  44. def get_dataset_size(self):
  45. return self._size
  46. def get_repeat_count(self):
  47. return self._repeat_count
  48. def get_batch_size(self):
  49. return self._batch_size
  50. def output_types(self):
  51. return self._np_types
  52. def output_shapes(self):
  53. return self._output_shapes
  54. def create_tuple_iterator(self, num_epochs=-1):
  55. return self
  56. @property
  57. def input_indexes(self):
  58. return self._input_indexes
  59. @property
  60. def func_name(self):
  61. return self._func_name
  62. def send(self):
  63. pass
  64. def __len__(self):
  65. return self._size
  66. def __iter__(self):
  67. return self
  68. def __next__(self):
  69. if self._size < self._iter_num:
  70. raise StopIteration
  71. self._iter_num += 1
  72. next_value = []
  73. for shape, typ in zip(self._output_shapes, self._np_types):
  74. next_value.append(Tensor(np.ndarray(shape, typ)))
  75. return tuple(next_value)
  76. def next(self):
  77. return self.__next__()
  78. def reset(self):
  79. self._iter_num = 0
  80. class MindDataSet(MindData):
  81. def __init__(self, dataset_types, dataset_shapes):
  82. super(MindDataSet, self).__init__(size=2, batch_size=32,
  83. np_types=dataset_types,
  84. output_shapes=dataset_shapes,
  85. input_indexes=(0, 1), func_name='')
  86. def __next__(self):
  87. if self._size < self._iter_num:
  88. raise StopIteration
  89. self._iter_num += 1
  90. res = []
  91. for shape, t in zip(self._output_shapes, self._np_types):
  92. res.append(Tensor(np.ones(shape).astype(t)))
  93. return tuple(res)
  94. class NetFP16(nn.Cell):
  95. def __init__(self, in_features, out_features):
  96. super(NetFP16, self).__init__()
  97. self.weight = Parameter(Tensor(np.ones([out_features, in_features]).astype(np.float32)), name="weight")
  98. self.bias = Parameter(Tensor(np.ones([out_features]).astype(np.float32)), name="bias")
  99. self.matmul = P.MatMul()
  100. self.add = P.TensorAdd()
  101. self.cast = P.Cast()
  102. def construct(self, x):
  103. output = self.cast(self.add(self.matmul(self.cast(x, mstype.float16),
  104. self.cast(self.weight, mstype.float16)),
  105. self.cast(self.bias, mstype.float16)), mstype.float32)
  106. return output
  107. def get_axis(x):
  108. shape_op = P.Shape()
  109. shape = shape_op(x)
  110. length = F.tuple_len(shape)
  111. perm = F.make_range(0, length)
  112. return perm
  113. class MSELoss(nn.Cell):
  114. def __init__(self):
  115. super(MSELoss, self).__init__()
  116. self.sum = P.ReduceSum()
  117. self.square = P.Square()
  118. self.reduce_mean = P.ReduceMean()
  119. def construct(self, data, label):
  120. diff = data - label
  121. return self.reduce_mean(self.square(diff), get_axis(diff))
  122. @pytest.mark.level0
  123. @pytest.mark.platform_arm_ascend_training
  124. @pytest.mark.platform_x86_ascend_training
  125. @pytest.mark.env_onecard
  126. def test_loss_scale_fp16_lr_overflow():
  127. inputs = Tensor(np.ones([16, 16]).astype(np.float32))
  128. label = Tensor(np.zeros([16, 16]).astype(np.float32))
  129. lr = Tensor(np.ones([1], np.float32) * 0.1)
  130. net = NetFP16(16, 16)
  131. net.set_train()
  132. loss = MSELoss()
  133. optimizer = Momentum(net.trainable_params(), learning_rate=lr, momentum=0.9)
  134. net_with_loss = WithLossCell(net, loss)
  135. train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer,
  136. scale_sense=Tensor(np.full((1), np.finfo(np.float32).max),
  137. dtype=mstype.float32))
  138. output_1 = train_network(inputs, label)
  139. output_2 = train_network(inputs, label)
  140. assert output_1[0].asnumpy() == output_2[0].asnumpy()
  141. assert output_1[1].asnumpy() == output_2[1].asnumpy() == True
  142. @pytest.mark.level0
  143. @pytest.mark.platform_arm_ascend_training
  144. @pytest.mark.platform_x86_ascend_training
  145. @pytest.mark.env_onecard
  146. def test_loss_scale_fp16_lr_overflow_set_sense_scale():
  147. inputs = Tensor(np.ones([16, 16]).astype(np.float32))
  148. label = Tensor(np.zeros([16, 16]).astype(np.float32))
  149. lr = Tensor(np.ones([1], np.float32) * 0.1)
  150. net = NetFP16(16, 16)
  151. net.set_train()
  152. loss = MSELoss()
  153. optimizer = Momentum(net.trainable_params(), learning_rate=lr, momentum=0.9)
  154. net_with_loss = WithLossCell(net, loss)
  155. train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer,
  156. scale_sense=Tensor(np.full((1), np.finfo(np.float32).max),
  157. dtype=mstype.float32))
  158. output_1 = train_network(inputs, label)
  159. train_network.set_sense_scale(Tensor(np.full((1), np.finfo(np.float32).max), dtype=mstype.float32))
  160. output_2 = train_network(inputs, label)
  161. assert output_1[0].asnumpy() == output_2[0].asnumpy()
  162. assert output_1[1].asnumpy() == output_2[1].asnumpy() == True
  163. @pytest.mark.level0
  164. @pytest.mark.platform_arm_ascend_training
  165. @pytest.mark.platform_x86_ascend_training
  166. @pytest.mark.env_onecard
  167. def test_loss_scale_fp16_model_train_overflow():
  168. dataset_types = (np.float32, np.float32)
  169. dataset_shapes = ((16, 16), (16, 16))
  170. dataset = MindDataSet(dataset_types, dataset_shapes)
  171. net = NetFP16(16, 16)
  172. net.set_train()
  173. loss = MSELoss()
  174. optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
  175. scale_manager = DynamicLossScaleManager(init_loss_scale=16, scale_factor=2, scale_window=2)
  176. model = Model(net, loss_fn=loss, optimizer=optimizer, metrics=None, loss_scale_manager=scale_manager)
  177. model.train(2, dataset, dataset_sink_mode=False)
  178. @pytest.mark.level0
  179. @pytest.mark.platform_arm_ascend_training
  180. @pytest.mark.platform_x86_ascend_training
  181. @pytest.mark.env_onecard
  182. def test_loss_scale_fp16_opt_rmsprop_overflow():
  183. inputs = Tensor(np.ones([16, 16]).astype(np.float32))
  184. label = Tensor(np.zeros([16, 16]).astype(np.float32))
  185. net = NetFP16(16, 16)
  186. net.set_train()
  187. loss = MSELoss()
  188. optimizer = RMSProp(net.trainable_params(), learning_rate=0.1)
  189. net_with_loss = WithLossCell(net, loss)
  190. train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer,
  191. scale_sense=Tensor(np.full(1, np.finfo(np.float32).max),
  192. dtype=mstype.float32))
  193. output_1 = train_network(inputs, label)
  194. output_2 = train_network(inputs, label)
  195. assert output_1[0].asnumpy() == output_2[0].asnumpy()
  196. assert output_1[1].asnumpy() == output_2[1].asnumpy() == True
  197. @pytest.mark.level0
  198. @pytest.mark.platform_arm_ascend_training
  199. @pytest.mark.platform_x86_ascend_training
  200. @pytest.mark.env_onecard
  201. def test_loss_scale_fp16_overflow():
  202. inputs = Tensor(np.ones([16, 16]).astype(np.float32))
  203. label = Tensor(np.zeros([16, 16]).astype(np.float32))
  204. net = NetFP16(16, 16)
  205. net.set_train()
  206. loss = MSELoss()
  207. optimizer = Lamb(net.trainable_params(), learning_rate=0.01)
  208. net_with_loss = WithLossCell(net, loss)
  209. net_with_loss.set_grad()
  210. train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer,
  211. scale_sense=Tensor(np.full((1), np.finfo(np.float32).max),
  212. dtype=mstype.float32))
  213. output_1 = train_network(inputs, label)
  214. output_2 = train_network(inputs, label)
  215. assert output_1[0].asnumpy() == output_2[0].asnumpy()
  216. assert output_1[1].asnumpy() == output_2[1].asnumpy() == True