You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_hook.py 4.3 kB

(line numbers 1–133)
  1. import numpy as np
  2. import mindspore.nn as nn
  3. import mindspore.ops.operations as P
  4. from mindspore import context
  5. from mindspore.ops import composite as C
  6. from mindspore.common import dtype as mstype
  7. from mindspore import context, Tensor, ParameterTuple
  8. from mindspore.common.initializer import TruncatedNormal
  9. from mindspore.nn import Dense, WithLossCell, SoftmaxCrossEntropyWithLogits, Momentum
  10. context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
  11. def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
  12. """weight initial for conv layer"""
  13. weight = weight_variable()
  14. return nn.Conv2d(in_channels, out_channels,
  15. kernel_size=kernel_size, stride=stride, padding=padding,
  16. weight_init=weight, has_bias=False, pad_mode="valid")
  17. def fc_with_initialize(input_channels, out_channels):
  18. """weight initial for fc layer"""
  19. weight = weight_variable()
  20. bias = weight_variable()
  21. return nn.Dense(input_channels, out_channels, weight, bias)
  22. def weight_variable():
  23. """weight initial"""
  24. return TruncatedNormal(0.02)
  25. def cell_hook_function(cell_id, grad_input, grad_output):
  26. print(cell_id)
  27. assert(grad_output.asnumpy().shape == (32, 6, 14, 14))
  28. assert(grad_input.asnumpy().shape == (32, 16, 10, 10))
  29. def var_hook_function(grad_out):
  30. print("grad:", grad_out)
  31. assert(grad_out.asnumpy().shape == (32, 120))
  32. class LeNet5(nn.Cell):
  33. """
  34. Lenet network
  35. Args:
  36. num_class (int): Num classes. Default: 10.
  37. Returns:
  38. Tensor, output tensor
  39. Examples:
  40. >>> LeNet(num_class=10)
  41. """
  42. def __init__(self, num_class=10):
  43. super(LeNet5, self).__init__()
  44. self.num_class = num_class
  45. self.batch_size = 32
  46. self.conv1 = conv(1, 6, 5)
  47. self.conv2 = conv(6, 16, 5)
  48. self.conv2.register_backward_hook(cell_hook_function)
  49. self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
  50. self.fc2 = fc_with_initialize(120, 84)
  51. self.fc3 = fc_with_initialize(84, self.num_class)
  52. self.relu = nn.ReLU()
  53. self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
  54. self.reshape = P.Reshape()
  55. self.hook = P.HookBackward(var_hook_function)
  56. def construct(self, x):
  57. x = self.conv1(x)
  58. x = self.relu(x)
  59. x = self.max_pool2d(x)
  60. x = self.conv2(x)
  61. x = self.relu(x)
  62. x = self.max_pool2d(x)
  63. x = self.reshape(x, (self.batch_size, -1))
  64. x = self.fc1(x)
  65. x = self.hook(x)
  66. x = self.relu(x)
  67. x = self.fc2(x)
  68. x = self.relu(x)
  69. x = self.fc3(x)
  70. return x
  71. class GradWrap(nn.Cell):
  72. """ GradWrap definition """
  73. def __init__(self, network):
  74. super(GradWrap, self).__init__(auto_prefix=False)
  75. self.network = network
  76. self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))
  77. def construct(self, x, label):
  78. weights = self.weights
  79. return C.GradOperation('get_by_list', get_by_list=True)(self.network, weights)(x, label)
  80. def test_hook():
  81. net = LeNet5()
  82. optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9)
  83. criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=False)
  84. net_with_criterion = WithLossCell(net, criterion)
  85. train_network = GradWrap(net_with_criterion)
  86. train_network.set_train()
  87. input_data = Tensor(np.ones([net.batch_size, 1, 32, 32]).astype(np.float32) * 0.01)
  88. label = Tensor(np.ones([net.batch_size, net.num_class]).astype(np.float32))
  89. output = net(Tensor(input_data))
  90. loss_output = criterion(output, label)
  91. grads = train_network(input_data, label)
  92. success = optimizer(grads)
  93. print(loss_output.asnumpy().shape)
  94. class MulAdd(nn.Cell):
  95. def __init__(self):
  96. super(MulAdd, self).__init__()
  97. def construct(self, x, y):
  98. return 2 * x + y
  99. def bprop(self, x, y, out, dout):
  100. assert(x == 1)
  101. assert(y == 2)
  102. assert(out == 4)
  103. assert(dout == 1)
  104. return 3 * dout, 2 * y
  105. def test_custom_bprop():
  106. mul_add = MulAdd()
  107. mul_add.bprop_debug = True
  108. assert C.grad_all(mul_add)(1, 2) == (3, 4)