You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_mix_precision.py 8.2 kB

4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """test_mix_precision"""
  16. import numpy as np
  17. import mindspore.common.dtype as mstype
  18. import mindspore.nn as nn
  19. from mindspore import Tensor, context
  20. from mindspore.common import ParameterTuple
  21. from mindspore.common.api import _cell_graph_executor
  22. from mindspore.common.parameter import Parameter
  23. from mindspore.nn import Momentum
  24. from mindspore.nn import TrainOneStepCell, WithLossCell
  25. from mindspore.ops import composite as C
  26. from mindspore.ops import operations as P
  27. from mindspore.ops import functional as F
  28. from mindspore.context import ParallelMode
  29. from tests.ops_common import convert
  30. from ....train_step_wrap import train_step_with_loss_warp
  31. class LeNet5(nn.Cell):
  32. """LeNet5"""
  33. def __init__(self):
  34. super(LeNet5, self).__init__()
  35. self.conv1 = nn.Conv2d(1, 6, 5, pad_mode='valid')
  36. self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
  37. self.fc1 = nn.Dense(16 * 5 * 5, 120)
  38. self.fc2 = nn.Dense(120, 84)
  39. self.fc3 = nn.Dense(84, 10)
  40. self.relu = nn.ReLU()
  41. self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
  42. self.flatten = P.Flatten()
  43. def construct(self, x):
  44. x = self.max_pool2d(self.relu(self.conv1(x)))
  45. x = self.max_pool2d(self.relu(self.conv2(x)))
  46. x = self.flatten(x)
  47. x = self.relu(self.fc1(x))
  48. x = self.relu(self.fc2(x))
  49. x = self.fc3(x)
  50. return x
  51. class NetForConcat(nn.Cell):
  52. def __init__(self):
  53. super(NetForConcat, self).__init__()
  54. self.concat = P.Concat()
  55. self.x1 = Tensor(np.zeros([1, 10]).astype(np.float32))
  56. self.x2 = Parameter(Tensor(np.zeros([1, 10]).astype(np.float32)), name='x2')
  57. def construct(self, x0):
  58. return self.concat((x0, self.x1, self.x2))
  59. def test_add_cast_flag():
  60. predict = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
  61. label = Tensor(np.zeros([1, 10]).astype(np.float32))
  62. net = LeNet5()
  63. net.to_float(mstype.float16)
  64. net.fc3.to_float(mstype.float32)
  65. net = train_step_with_loss_warp(net)
  66. net.set_train()
  67. net(predict, label)
  68. def test_add_cast_flag_tensor():
  69. x1 = Tensor(np.zeros([1, 10]).astype(np.float32))
  70. net = NetForConcat()
  71. net.add_flags_recursive(fp16=True)
  72. net.set_train()
  73. net(x1)
  74. def test_on_momentum():
  75. predict = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
  76. label = Tensor(np.zeros([1, 10]).astype(np.float32))
  77. net = LeNet5()
  78. net = train_step_with_loss_warp(net).to_float(mstype.float16)
  79. net.set_train()
  80. net(predict, label)
  81. def data_parallel_with_cast():
  82. """test_data_parallel_with_cast"""
  83. context.set_context(device_target='Ascend')
  84. context.reset_auto_parallel_context()
  85. context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, device_num=8)
  86. predict = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
  87. label = Tensor(np.zeros([1, 10]).astype(np.float32))
  88. net = LeNet5()
  89. net.to_float(mstype.float16)
  90. net.fc3.to_float(mstype.float32)
  91. loss_fn = nn.SoftmaxCrossEntropyWithLogits()
  92. optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
  93. learning_rate=0.1,
  94. momentum=0.9)
  95. net = WithLossCell(net, loss_fn)
  96. net = TrainOneStepCell(net, optimizer)
  97. _cell_graph_executor.compile(net, predict, label)
  98. context.reset_auto_parallel_context()
  99. class NetForPReLU(nn.Cell):
  100. def __init__(self):
  101. super(NetForPReLU, self).__init__()
  102. self.prelu = nn.PReLU()
  103. def construct(self, x):
  104. return self.prelu(x)
  105. def test_nn_prelu():
  106. x = Tensor(np.ones([1, 16, 10, 10]).astype(np.float32) * 0.01)
  107. net = NetForPReLU().set_train()
  108. net.add_flags_recursive(fp16=True)
  109. _cell_graph_executor.compile(net, x)
  110. class NetForCast(nn.Cell):
  111. def __init__(self):
  112. super(NetForCast, self).__init__()
  113. self.x1 = Tensor(1.0, mstype.float32)
  114. self.x2 = Parameter(Tensor(np.zeros([1, 10]).astype(np.float32)), name='x2')
  115. def construct(self, x0):
  116. x = self.x1 * x0 * self.x2
  117. return x
  118. def test_cast():
  119. x = Tensor(np.ones([1, 16, 10, 10]).astype(np.float32) * 0.01)
  120. net = NetForCast()
  121. net.add_flags_recursive(fp16=True)
  122. net(x)
  123. class IRBlockZ(nn.Cell):
  124. def __init__(self, inplanes, planes):
  125. super(IRBlockZ, self).__init__()
  126. self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=1, pad_mode="same", group=1, has_bias=False,
  127. dilation=1)
  128. self.act_layer = nn.PReLU(planes)
  129. def construct(self, x):
  130. out = self.conv1(x)
  131. return self.act_layer(out)
  132. class GetParamGrad(nn.Cell):
  133. def __init__(self, network):
  134. super(GetParamGrad, self).__init__(auto_prefix=False)
  135. self.network = network
  136. self.weights = ParameterTuple(network.trainable_params())
  137. self.grad = C.GradOperation(get_by_list=True,
  138. sens_param=True)
  139. def construct(self, data, sens):
  140. weights = self.weights
  141. return self.grad(self.network, weights)(data, sens)
  142. def test_grad_conv_prelu():
  143. shapes = [[64, 64, 112, 112]]
  144. outshape = [[64, 64, 112, 112]]
  145. net = IRBlockZ(inplanes=64, planes=64).add_flags_recursive(fp16=True)
  146. inputs = [convert(shp, dtype=np.float16) for shp in shapes]
  147. sens_shape = outshape[0]
  148. sens = convert(sens_shape, dtype=np.float16)
  149. all_inputs = inputs + [sens]
  150. net = GetParamGrad(net)
  151. net.set_train()
  152. net(*all_inputs)
  153. def test_dict_cast():
  154. class FirstNet(nn.Cell):
  155. def __init__(self):
  156. super(FirstNet, self).__init__()
  157. self.net = SecondNet()
  158. self.sub = P.Sub()
  159. def construct(self, tensor_a, tensor_b):
  160. a = F.mixed_precision_cast(mstype.float16, tensor_a)
  161. b = F.mixed_precision_cast(mstype.float16, tensor_b)
  162. c = self.sub(a, b)
  163. dictionary = {"key": a}
  164. result = self.net(c, key1=a, key2=dictionary)
  165. return result
  166. class SecondNet(nn.Cell):
  167. def __init__(self):
  168. super(SecondNet, self).__init__()
  169. self.add = P.Add()
  170. def construct(self, tensor_c, **kwargs):
  171. d = F.mixed_precision_cast(mstype.float16, tensor_c)
  172. dict_cast = F.mixed_precision_cast(mstype.float16, kwargs)
  173. e = self.add(d, dict_cast["key1"])
  174. f = self.add(e, dict_cast["key2"]["key"])
  175. return f
  176. x = Tensor(np.array([1, 2.5, 3.5]), mstype.float32)
  177. y = Tensor(np.array([4, 5.5, 6.5]), mstype.float32)
  178. net = FirstNet()
  179. net(x, y)
  180. def test_kwarg_cast():
  181. class FirstNet(nn.Cell):
  182. def __init__(self):
  183. super(FirstNet, self).__init__()
  184. self.net = SecondNet().add_flags_recursive(fp16=True)
  185. self.add = P.Add()
  186. def construct(self, tensor_a, tensor_b):
  187. tensor_c = self.add(tensor_a, tensor_b)
  188. dictionary = {"key": tensor_a}
  189. result = self.net(key1=tensor_c, key2=dictionary)
  190. return result
  191. class SecondNet(nn.Cell):
  192. def __init__(self):
  193. super(SecondNet, self).__init__()
  194. self.add = P.Add()
  195. def construct(self, key1=1, key2=2):
  196. tensor_d = self.add(key1, key2["key"])
  197. return tensor_d
  198. x = Tensor(np.array([1, 2.5, 3.5]), mstype.float32)
  199. y = Tensor(np.array([4, 5.5, 6.5]), mstype.float32)
  200. net = FirstNet()
  201. net(x, y)