You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_hook.py 4.2 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. import numpy as np
  2. import mindspore.nn as nn
  3. import mindspore.ops.operations as P
  4. from mindspore import context, Tensor, ParameterTuple
  5. from mindspore.common.initializer import TruncatedNormal
  6. from mindspore.nn import WithLossCell, Momentum
  7. from mindspore.ops import composite as C
  8. context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
  9. def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
  10. """weight initial for conv layer"""
  11. weight = weight_variable()
  12. return nn.Conv2d(in_channels, out_channels,
  13. kernel_size=kernel_size, stride=stride, padding=padding,
  14. weight_init=weight, has_bias=False, pad_mode="valid")
  15. def fc_with_initialize(input_channels, out_channels):
  16. """weight initial for fc layer"""
  17. weight = weight_variable()
  18. bias = weight_variable()
  19. return nn.Dense(input_channels, out_channels, weight, bias)
  20. def weight_variable():
  21. """weight initial"""
  22. return TruncatedNormal(0.02)
  23. def cell_hook_function(cell_id, grad_input, grad_output):
  24. print(cell_id)
  25. assert (grad_output[0].asnumpy().shape == (32, 6, 14, 14))
  26. assert (grad_input[0].asnumpy().shape == (32, 16, 10, 10))
  27. def var_hook_function(grad_out):
  28. print("grad:", grad_out)
  29. assert (grad_out[0].asnumpy().shape == (32, 120))
  30. class LeNet5(nn.Cell):
  31. """
  32. Lenet network
  33. Args:
  34. num_class (int): Num classes. Default: 10.
  35. Returns:
  36. Tensor, output tensor
  37. Examples:
  38. >>> LeNet(num_class=10)
  39. """
  40. def __init__(self, num_class=10):
  41. super(LeNet5, self).__init__()
  42. self.num_class = num_class
  43. self.batch_size = 32
  44. self.conv1 = conv(1, 6, 5)
  45. self.conv2 = conv(6, 16, 5)
  46. self.conv2.register_backward_hook(cell_hook_function)
  47. self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
  48. self.fc2 = fc_with_initialize(120, 84)
  49. self.fc3 = fc_with_initialize(84, self.num_class)
  50. self.relu = nn.ReLU()
  51. self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
  52. self.reshape = P.Reshape()
  53. self.hook = P.HookBackward(var_hook_function)
  54. def construct(self, x):
  55. x = self.conv1(x)
  56. x = self.relu(x)
  57. x = self.max_pool2d(x)
  58. x = self.conv2(x)
  59. x = self.relu(x)
  60. x = self.max_pool2d(x)
  61. x = self.reshape(x, (self.batch_size, -1))
  62. x = self.fc1(x)
  63. x = self.hook(x)
  64. x = self.relu(x)
  65. x = self.fc2(x)
  66. x = self.relu(x)
  67. x = self.fc3(x)
  68. return x
  69. class GradWrap(nn.Cell):
  70. """ GradWrap definition """
  71. def __init__(self, network):
  72. super(GradWrap, self).__init__(auto_prefix=False)
  73. self.network = network
  74. self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))
  75. def construct(self, x, label):
  76. weights = self.weights
  77. return C.GradOperation('get_by_list', get_by_list=True)(self.network, weights)(x, label)
  78. def test_hook():
  79. net = LeNet5()
  80. optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9)
  81. criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=False)
  82. net_with_criterion = WithLossCell(net, criterion)
  83. train_network = GradWrap(net_with_criterion)
  84. train_network.set_train()
  85. input_data = Tensor(np.ones([net.batch_size, 1, 32, 32]).astype(np.float32) * 0.01)
  86. label = Tensor(np.ones([net.batch_size, net.num_class]).astype(np.float32))
  87. output = net(Tensor(input_data))
  88. loss_output = criterion(output, label)
  89. grads = train_network(input_data, label)
  90. success = optimizer(grads)
  91. print(loss_output.asnumpy().shape)
  92. class MulAdd(nn.Cell):
  93. def __init__(self):
  94. super(MulAdd, self).__init__()
  95. def construct(self, x, y):
  96. return 2 * x + y
  97. def bprop(self, x, y, out, dout):
  98. assert (x == 1)
  99. assert (y == 2)
  100. assert (out == 4)
  101. assert (dout == 1)
  102. return 3 * dout, 2 * y
  103. def test_custom_bprop():
  104. mul_add = MulAdd()
  105. mul_add.bprop_debug = True
  106. assert C.grad_all(mul_add)(1, 2) == (3, 4)