You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_hook.py 6.1 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import pytest
  17. import mindspore.nn as nn
  18. import mindspore.ops.operations as P
  19. from mindspore import context, Tensor, ParameterTuple
  20. from mindspore.common.initializer import TruncatedNormal
  21. from mindspore.nn import WithLossCell, Momentum
  22. from mindspore.ops import composite as C
  23. context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
  24. cell_hook_done = False
  25. var_hook_done = False
  26. cell_bprop_done = False
  27. grad_all = C.GradOperation(get_all=True)
  28. def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
  29. """weight initial for conv layer"""
  30. weight = weight_variable()
  31. return nn.Conv2d(in_channels, out_channels,
  32. kernel_size=kernel_size, stride=stride, padding=padding,
  33. weight_init=weight, has_bias=False, pad_mode="valid")
  34. def fc_with_initialize(input_channels, out_channels):
  35. """weight initial for fc layer"""
  36. weight = weight_variable()
  37. bias = weight_variable()
  38. return nn.Dense(input_channels, out_channels, weight, bias)
  39. def weight_variable():
  40. """weight initial"""
  41. return TruncatedNormal(0.02)
  42. def cell_hook_function(cell_id, grad_input, grad_output):
  43. print(cell_id)
  44. global cell_hook_done
  45. cell_hook_done = True
  46. assert (grad_output[0].asnumpy().shape == (32, 6, 14, 14))
  47. assert (grad_input[0].asnumpy().shape == (32, 16, 10, 10))
  48. def var_hook_function(grad_out):
  49. print("grad:", grad_out)
  50. global var_hook_done
  51. var_hook_done = True
  52. assert (grad_out[0].asnumpy().shape == (32, 120))
  53. class Block(nn.Cell):
  54. def __init__(self):
  55. super(Block, self).__init__()
  56. self.relu = nn.ReLU()
  57. def construct(self, x):
  58. x = self.relu(x)
  59. return x
  60. def bprop(self, x, out, dout):
  61. global cell_bprop_done
  62. cell_bprop_done = True
  63. grad = out.asnumpy() * dout.asnumpy()
  64. grad = Tensor(grad)
  65. return (grad,)
  66. class LeNet5(nn.Cell):
  67. """
  68. Lenet network
  69. Args:
  70. num_class (int): Num classes. Default: 10.
  71. Returns:
  72. Tensor, output tensor
  73. Examples:
  74. >>> LeNet(num_class=10)
  75. """
  76. def __init__(self, num_class=10):
  77. super(LeNet5, self).__init__()
  78. self.num_class = num_class
  79. self.batch_size = 32
  80. self.conv1 = conv(1, 6, 5)
  81. self.conv2 = conv(6, 16, 5)
  82. self.conv2.register_backward_hook(cell_hook_function)
  83. self.block = Block()
  84. self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
  85. self.fc2 = fc_with_initialize(120, 84)
  86. self.fc3 = fc_with_initialize(84, self.num_class)
  87. self.relu = nn.ReLU()
  88. self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
  89. self.reshape = P.Reshape()
  90. self.hook = P.HookBackward(var_hook_function)
  91. def construct(self, x):
  92. x = self.conv1(x)
  93. x = self.relu(x)
  94. x = self.max_pool2d(x)
  95. x = self.conv2(x)
  96. x = self.block(x)
  97. x = self.max_pool2d(x)
  98. x = self.reshape(x, (self.batch_size, -1))
  99. x = self.fc1(x)
  100. x = self.hook(x)
  101. x = self.relu(x)
  102. x = self.fc2(x)
  103. x = self.relu(x)
  104. x = self.fc3(x)
  105. return x
  106. class GradWrap(nn.Cell):
  107. """ GradWrap definition """
  108. def __init__(self, network):
  109. super(GradWrap, self).__init__(auto_prefix=False)
  110. self.network = network
  111. self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))
  112. def construct(self, x, label):
  113. weights = self.weights
  114. return C.GradOperation(get_by_list=True)(self.network, weights)(x, label)
  115. def test_hook():
  116. net = LeNet5()
  117. optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9)
  118. criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
  119. net_with_criterion = WithLossCell(net, criterion)
  120. train_network = GradWrap(net_with_criterion)
  121. train_network.set_train()
  122. input_data = Tensor(np.ones([net.batch_size, 1, 32, 32]).astype(np.float32) * 0.01)
  123. label = Tensor(np.ones([net.batch_size, net.num_class]).astype(np.float32))
  124. output = net(Tensor(input_data))
  125. loss_output = criterion(output, label)
  126. grads = train_network(input_data, label)
  127. success = optimizer(grads)
  128. assert cell_hook_done
  129. assert var_hook_done
  130. assert cell_bprop_done
  131. print(loss_output.asnumpy())
# Set to True by MulAdd.bprop; test_custom_bprop asserts it afterwards.
bprop_debug: bool = False
  133. class MulAdd(nn.Cell):
  134. def __init__(self):
  135. super(MulAdd, self).__init__()
  136. def construct(self, x, y):
  137. return 2 * x * x + y * y
  138. def bprop(self, x, y, out, dout):
  139. global bprop_debug
  140. bprop_debug = True
  141. return dout, 2 * y
  142. def test_custom_bprop():
  143. mul_add = MulAdd()
  144. mul_add.bprop_debug = True
  145. x = Tensor(np.array([1, 2, 3]).astype(np.int32))
  146. y = Tensor(np.array([2, 3, 4]).astype(np.int32))
  147. grad_all(mul_add)(x, y)
  148. assert bprop_debug
  149. class Net(nn.Cell):
  150. def __init__(self):
  151. super(Net, self).__init__()
  152. def construct(self, x, y):
  153. return 2 * x * x + y * y
  154. def test_grad_all():
  155. net = Net()
  156. x = Tensor(np.array([1, 2, 3]).astype(np.int32))
  157. y = Tensor(np.array([2, 3, 4]).astype(np.int32))
  158. res = grad_all(net)(x, y)
  159. print(res)
  160. def test_check_input():
  161. net = Net()
  162. x = np.array([1, 2, 3])
  163. y = np.array([2, 3, 4])
  164. with pytest.raises(TypeError):
  165. net(x, y)