You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_hook.py 4.8 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. import numpy as np
  2. import mindspore.nn as nn
  3. import mindspore.ops.operations as P
  4. from mindspore import context, Tensor, ParameterTuple
  5. from mindspore.common.initializer import TruncatedNormal
  6. from mindspore.nn import WithLossCell, Momentum
  7. from mindspore.ops import composite as C
  # The hooks and custom bprops below are exercised in PyNative mode on a GPU target.
  context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

  # Flags flipped to True by the hook / custom-bprop callbacks defined below;
  # test_hook asserts that all three were set.
  cell_hook_done = var_hook_done = cell_bprop_done = False
  def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
      """Build a bias-free nn.Conv2d with "valid" padding.

      The weight initializer comes from weight_variable().
      """
      return nn.Conv2d(
          in_channels,
          out_channels,
          kernel_size=kernel_size,
          stride=stride,
          padding=padding,
          weight_init=weight_variable(),
          has_bias=False,
          pad_mode="valid",
      )
  def fc_with_initialize(input_channels, out_channels):
      """Build an nn.Dense layer whose weight and bias initializers both come from weight_variable().

      Each initializer is a separate weight_variable() call: weight first, then bias.
      """
      return nn.Dense(input_channels, out_channels, weight_variable(), weight_variable())
  def weight_variable():
      """Return a new TruncatedNormal(0.02) initializer; used by conv() and fc_with_initialize()."""
      init_scale = 0.02
      return TruncatedNormal(init_scale)
  def cell_hook_function(cell_id, grad_input, grad_output):
      """Backward hook registered on LeNet5's conv2 cell.

      Prints the cell id and sets the module-level cell_hook_done flag. It then
      checks the first entry of each gradient tuple:
      grad_output[0] must have shape (32, 6, 14, 14), and
      grad_input[0] must have shape (32, 16, 10, 10).

      NOTE(review): (32, 6, 14, 14) matches conv2's input and (32, 16, 10, 10)
      matches its output. So grad_output here appears to be the gradient w.r.t.
      the cell's input, and grad_input the gradient w.r.t. its output. Confirm
      against MindSpore's register_backward_hook docs.
      """
      global cell_hook_done
      print(cell_id)
      cell_hook_done = True
      out_shape = grad_output[0].asnumpy().shape
      assert out_shape == (32, 6, 14, 14)
      in_shape = grad_input[0].asnumpy().shape
      assert in_shape == (32, 16, 10, 10)
  def var_hook_function(grad_out):
      """Hook passed to the P.HookBackward op that wraps fc1's output in LeNet5.

      Prints the incoming gradient tuple and sets the module-level
      var_hook_done flag. It then checks that grad_out[0] has shape (32, 120).
      """
      global var_hook_done
      print("grad:", grad_out)
      var_hook_done = True
      first_grad = grad_out[0].asnumpy()
      assert first_grad.shape == (32, 120)
  class Block(nn.Cell):
      """ReLU cell with a hand-written bprop, used to check that custom bprops are invoked.

      NOTE(review): the custom gradient is out * dout, where out = relu(x). That is
      not the true ReLU derivative, which would be dout * (x > 0). test_hook only
      asserts that this bprop ran (via cell_bprop_done), so the value is never checked.
      """

      def __init__(self):
          super(Block, self).__init__()
          self.relu = nn.ReLU()

      def construct(self, x):
          return self.relu(x)

      def bprop(self, x, out, dout):
          """Set cell_bprop_done and return a 1-tuple holding Tensor(out * dout)."""
          global cell_bprop_done
          cell_bprop_done = True
          # Multiply on the host via NumPy, then wrap the result back into a Tensor.
          out_np = out.asnumpy()
          dout_np = dout.asnumpy()
          return (Tensor(out_np * dout_np),)
  class LeNet5(nn.Cell):
      """
      LeNet-5 style network instrumented with backward hooks.

      The hooks are wired as follows:
      - conv2 carries a cell backward hook (cell_hook_function).
      - fc1's output passes through a P.HookBackward op (var_hook_function).
      - A Block with a custom bprop sits between conv2 and the second pooling layer.

      The batch size is fixed at 32. construct() flattens with it, and the hook
      functions assert shapes whose leading dimension is 32.

      Args:
          num_class (int): Num classes. Default: 10.

      Returns:
          Tensor of shape (32, num_class).

      Examples:
          >>> LeNet5(num_class=10)
      """

      def __init__(self, num_class=10):
          super(LeNet5, self).__init__()
          self.num_class = num_class
          # Hard-wired: used by the reshape in construct() and by the hook shape checks.
          self.batch_size = 32
          self.conv1 = conv(1, 6, 5)
          self.conv2 = conv(6, 16, 5)
          # Cell-level backward hook; test_hook expects it to fire.
          self.conv2.register_backward_hook(cell_hook_function)
          # ReLU cell with a custom bprop; test_hook expects the bprop to fire.
          self.block = Block()
          # 16 * 5 * 5 assumes a 32x32 input: two conv(k=5, "valid") + 2x2 pool stages.
          self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
          self.fc2 = fc_with_initialize(120, 84)
          self.fc3 = fc_with_initialize(84, self.num_class)
          self.relu = nn.ReLU()
          self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
          self.reshape = P.Reshape()
          # Wraps fc1's output in construct(); test_hook expects var_hook_function to fire.
          self.hook = P.HookBackward(var_hook_function)

      def construct(self, x):
          # Shape comments assume a (32, 1, 32, 32) input, as built in test_hook.
          x = self.conv1(x)                            # (32, 6, 28, 28)
          x = self.relu(x)
          x = self.max_pool2d(x)                       # (32, 6, 14, 14)
          x = self.conv2(x)                            # (32, 16, 10, 10)
          x = self.block(x)
          x = self.max_pool2d(x)                       # (32, 16, 5, 5)
          x = self.reshape(x, (self.batch_size, -1))   # (32, 400)
          x = self.fc1(x)                              # (32, 120)
          x = self.hook(x)
          x = self.relu(x)
          x = self.fc2(x)                              # (32, 84)
          x = self.relu(x)
          x = self.fc3(x)                              # (32, num_class)
          return x
  class GradWrap(nn.Cell):
      """Wrap a loss network so that calling it returns gradients w.r.t. its trainable parameters."""

      def __init__(self, network):
          super(GradWrap, self).__init__(auto_prefix=False)
          self.network = network
          # Only parameters with requires_grad set are differentiated.
          trainable = (param for param in network.get_parameters() if param.requires_grad)
          self.weights = ParameterTuple(trainable)

      def construct(self, x, label):
          grad_by_list = C.GradOperation('get_by_list', get_by_list=True)
          return grad_by_list(self.network, self.weights)(x, label)
  def test_hook():
      """End-to-end check that the backward hooks and the custom bprop fire during training.

      Builds LeNet5 and wraps it in WithLossCell and GradWrap. It then runs:
      - one forward pass to compute the loss;
      - one gradient computation;
      - one Momentum optimizer step.
      Finally it asserts that all three module-level flags were set by their callbacks.
      """
      global cell_hook_done, var_hook_done, cell_bprop_done
      # Reset the flags so the assertions reflect this run rather than state left
      # over from an earlier call in the same process.
      cell_hook_done = False
      var_hook_done = False
      cell_bprop_done = False

      net = LeNet5()
      optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9)
      criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=False)
      net_with_criterion = WithLossCell(net, criterion)
      train_network = GradWrap(net_with_criterion)
      train_network.set_train()

      input_data = Tensor(np.ones([net.batch_size, 1, 32, 32]).astype(np.float32) * 0.01)
      label = Tensor(np.ones([net.batch_size, net.num_class]).astype(np.float32))

      # input_data is already a Tensor, so it is passed to the network directly.
      output = net(input_data)
      loss_output = criterion(output, label)
      grads = train_network(input_data, label)
      optimizer(grads)

      assert cell_hook_done
      assert var_hook_done
      assert cell_bprop_done
      print(loss_output.asnumpy().shape)
  class MulAdd(nn.Cell):
      """Cell computing 2 * x + y, with a deliberately non-analytic custom bprop.

      The bprop returns (3 * dout, 2 * y) instead of the analytic gradient
      (2 * dout, dout). test_custom_bprop expects (3, 4), which can only come
      from this custom rule.
      """

      def __init__(self):
          super(MulAdd, self).__init__()

      def construct(self, x, y):
          return 2 * x + y

      def bprop(self, x, y, out, dout):
          """Check the values seen for the call MulAdd(1, 2) and return (3 * dout, 2 * y)."""
          assert x == 1
          assert y == 2
          assert out == 4
          assert dout == 1
          grad_x = 3 * dout
          grad_y = 2 * y
          return grad_x, grad_y
  def test_custom_bprop():
      """grad_all on MulAdd at (1, 2) must yield the custom bprop's result (3, 4)."""
      cell = MulAdd()
      # NOTE(review): bprop_debug presumably makes MindSpore call the Python bprop
      # directly, so the asserts inside MulAdd.bprop run; confirm in the Cell docs.
      cell.bprop_debug = True
      grads = C.grad_all(cell)(1, 2)
      assert grads == (3, 4)