
test_mix_precision.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test_mix_precision"""
import numpy as np

import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor, context
from mindspore.common import ParameterTuple
from mindspore.common.api import _executor
from mindspore.common.parameter import Parameter
from mindspore.nn import Momentum, TrainOneStepCell, WithLossCell
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.train.parallel_utils import ParallelMode

from tests.ops_common import convert
from ....train_step_wrap import train_step_with_loss_warp


class LeNet5(nn.Cell):
    """LeNet5"""

    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120)
        self.fc2 = nn.Dense(120, 84)
        self.fc3 = nn.Dense(84, 10)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = P.Flatten()

    def construct(self, x):
        x = self.max_pool2d(self.relu(self.conv1(x)))
        x = self.max_pool2d(self.relu(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x
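

# NetForConcat feeds Concat a graph input, a constant Tensor attribute and a
# Parameter, so compiling it with fp16 flags exercises cast insertion for all
# three kinds of operands.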
class NetForConcat(nn.Cell):
    def __init__(self):
        super(NetForConcat, self).__init__()
        self.concat = P.Concat()
        self.x1 = Tensor(np.zeros([1, 10]).astype(np.float32))
        self.x2 = Parameter(Tensor(np.zeros([1, 10]).astype(np.float32)), name='x2')

    def construct(self, x0):
        return self.concat((x0, self.x1, self.x2))
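

# Cell.to_float(mstype.float16) attaches cast flags so the cell computes in
# float16; keeping fc3 in float32 below checks that a per-cell override
# survives compilation of the wrapped training step.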
def test_add_cast_flag():
    predict = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
    label = Tensor(np.zeros([1, 10]).astype(np.float32))
    net = LeNet5()
    net.to_float(mstype.float16)
    net.fc3.to_float(mstype.float32)
    net = train_step_with_loss_warp(net)
    net.set_train()
    _executor.compile(net, predict, label)


def test_add_cast_flag_tensor():
    x1 = Tensor(np.zeros([1, 10]).astype(np.float32))
    net = NetForConcat()
    net.add_flags_recursive(fp16=True)
    net.set_train()
    _executor.compile(net, x1)


def test_on_momentum():
    predict = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
    label = Tensor(np.zeros([1, 10]).astype(np.float32))
    net = LeNet5()
    net = train_step_with_loss_warp(net).to_float(mstype.float16)
    net.set_train()
    _executor.compile(net, predict, label)
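

# Mixed precision must also compose with parallel training: the test below
# casts LeNet5 to float16 (fc3 kept in float32) and compiles the training step
# under an 8-device DATA_PARALLEL context with gradient mirror averaging.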
def test_data_parallel_with_cast():
    """test_data_parallel_with_cast"""
    predict = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
    label = Tensor(np.zeros([1, 10]).astype(np.float32))
    net = LeNet5()
    net.to_float(mstype.float16)
    net.fc3.to_float(mstype.float32)
    loss_fn = nn.SoftmaxCrossEntropyWithLogits()
    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
                         learning_rate=0.1,
                         momentum=0.9)
    net = WithLossCell(net, loss_fn)
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                      mirror_mean=True, device_num=8)
    net = TrainOneStepCell(net, optimizer)
    _executor.compile(net, predict, label)
    context.reset_auto_parallel_context()


class NetForPReLU(nn.Cell):
    def __init__(self):
        super(NetForPReLU, self).__init__()
        self.prelu = nn.PReLU()

    def construct(self, x):
        return self.prelu(x)


def test_nn_prelu():
    x = Tensor(np.ones([1, 16, 10, 10]).astype(np.float32) * 0.01)
    net = NetForPReLU().set_train()
    net.add_flags_recursive(fp16=True)
    _executor.compile(net, x)
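

# NetForCast holds a float32 scalar constant; multiplying it with an
# fp16-flagged input checks that a cast is inserted for constant attribute
# tensors as well.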
class NetForCast(nn.Cell):
    def __init__(self):
        super(NetForCast, self).__init__()
        self.concat = P.Concat()
        self.x1 = Tensor(1.0, mstype.float32)

    def construct(self, x0):
        return self.x1 * x0


def test_cast():
    x = Tensor(np.ones([1, 16, 10, 10]).astype(np.float32) * 0.01)
    net = NetForCast()
    net.add_flags_recursive(fp16=True)
    _executor.compile(net, x)


# Test the grad of PReLU, which used to make the AddN generated by grad fail.
class IRBlockZ(nn.Cell):
    def __init__(self, inplanes, planes):
        super(IRBlockZ, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=1, pad_mode="same",
                               group=1, has_bias=False, dilation=1)
        self.act_layer = nn.PReLU(planes)

    def construct(self, x):
        out = self.conv1(x)
        return self.act_layer(out)


class GetParamGrad(nn.Cell):
    def __init__(self, network):
        super(GetParamGrad, self).__init__(auto_prefix=False)
        self.network = network
        self.weights = ParameterTuple(network.trainable_params())
        self.grad = C.GradOperation('grad',
                                    get_by_list=True,
                                    sens_param=True)

    def construct(self, data, sens):
        weights = self.weights
        return self.grad(self.network, weights)(data, sens)
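

# get_by_list=True makes GradOperation differentiate with respect to the given
# parameter tuple, and sens_param=True requires an explicit sensitivity (the
# initial output gradient) to be passed alongside the network inputs.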
def test_grad_conv_prelu():
    shapes = [[64, 64, 112, 112]]
    # conv1 uses stride=1 with "same" padding and PReLU preserves shape, so the
    # sensitivity must match the input shape
    outshape = [[64, 64, 112, 112]]
    net = IRBlockZ(inplanes=64, planes=64).add_flags_recursive(fp16=True)
    inputs = [convert(shp, dtype=np.float16) for shp in shapes]
    sens_shape = outshape[0]
    sens = convert(sens_shape, dtype=np.float16)
    all_inputs = inputs + [sens]
    net = GetParamGrad(net)
    net.set_train()
    net(*all_inputs)