
test_trace_dump.py

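"""Integration tests for MegEngine's trace/dump workflow.

Covers: training a small XOR classifier under @trace and dumping its traced
inference graph; verifying that dumping a graph with BatchNorm left in
training mode raises RuntimeError; and tracing a two-step training loop of a
minimal ViT-style model.
"""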
import contextlib
import os
import tempfile

import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
import megengine.module as M
import megengine.optimizer as optim
from megengine import tensor
from megengine.autodiff import GradManager
from megengine.jit import trace
from megengine.optimizer import SGD

@contextlib.contextmanager
def mkstemp():
    # Yield a fresh temporary file path and remove the file on exit.
    fd, path = tempfile.mkstemp()
    try:
        os.close(fd)
        yield path
    finally:
        os.remove(path)

def minibatch_generator(batch_size):
    # Endless stream of XOR-style minibatches: points drawn uniformly from
    # [-1, 1)^2, labeled 1 when the two coordinates have opposite signs.
    while True:
        inp_data = np.zeros((batch_size, 2))
        label = np.zeros(batch_size, dtype=np.int32)
        for i in range(batch_size):
            inp_data[i, :] = np.random.rand(2) * 2 - 1
            label[i] = 1 if np.prod(inp_data[i]) < 0 else 0
        yield {"data": inp_data.astype(np.float32), "label": label.astype(np.int32)}
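
# Example draw (shapes only; a usage sketch, not exercised by the tests):
#     batch = next(minibatch_generator(4))
#     batch["data"].shape   # (4, 2), float32
#     batch["label"].shape  # (4,), int32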

class XORNet(M.Module):
    # Small MLP (2 -> 14 -> 14 -> 2) with BatchNorm and tanh activations.
    def __init__(self):
        super().__init__()
        self.mid_dim = 14
        self.num_class = 2
        self.fc0 = M.Linear(self.num_class, self.mid_dim, bias=True)
        self.bn0 = M.BatchNorm1d(self.mid_dim)
        self.fc1 = M.Linear(self.mid_dim, self.mid_dim, bias=True)
        self.bn1 = M.BatchNorm1d(self.mid_dim)
        self.fc2 = M.Linear(self.mid_dim, self.num_class, bias=True)

    def forward(self, x):
        x = self.fc0(x)
        x = self.bn0(x)
        x = F.tanh(x)
        x = self.fc1(x)
        x = self.bn1(x)
        x = F.tanh(x)
        x = self.fc2(x)
        return x

def test_xornet_trace_dump():
    net = XORNet()
    opt = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
    gm = GradManager().attach(net.parameters())
    batch_size = 64
    train_dataset = minibatch_generator(batch_size)
    val_dataset = minibatch_generator(batch_size)

    @trace
    def train_fun(data, label):
        with gm:
            net.train()
            pred = net(data)
            loss = F.nn.cross_entropy(pred, label)
            gm.backward(loss)
        return pred, loss

    @trace
    def val_fun(data, label):
        net.eval()
        pred = net(data)
        loss = F.nn.cross_entropy(pred, label)
        return pred, loss

    @trace(symbolic=True, capture_as_const=True)
    def pred_fun(data):
        net.eval()
        pred = net(data)
        pred_normalized = F.softmax(pred)
        return pred_normalized

    train_loss = []
    val_loss = []
    for step, minibatch in enumerate(train_dataset):
        if step > 100:
            break
        data = tensor(minibatch["data"])
        label = tensor(minibatch["label"])
        opt.clear_grad()
        _, loss = train_fun(data, label)
        train_loss.append((step, loss.numpy()))
        if step % 50 == 0:
            # Periodically evaluate on a fresh validation minibatch.
            val_minibatch = next(val_dataset)
            val_data = tensor(val_minibatch["data"])
            val_label = tensor(val_minibatch["label"])
            _, loss = val_fun(val_data, val_label)
            val_loss.append((step, loss.numpy()))
        opt.step()
    test_data = np.array(
        [
            (0.5, 0.5),
            (0.3, 0.7),
            (0.1, 0.9),
            (-0.5, -0.5),
            (-0.3, -0.7),
            (-0.9, -0.1),
            (0.5, -0.5),
            (0.3, -0.7),
            (0.9, -0.1),
            (-0.5, 0.5),
            (-0.3, 0.7),
            (-0.1, 0.9),
        ]
    )
    data = tensor(test_data.astype(np.float32))
    # Run the traced function once so the graph is captured before dumping.
    out = pred_fun(data)
    with mkstemp() as fpath:
        pred_fun.dump(fpath, arg_names=["data"], output_names=["label"])
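        # The dumped .mge file could be reloaded for inspection; a minimal
        # sketch, assuming Network.load from megengine.utils.network is
        # available in this MegEngine version (not exercised by the test):
        #
        #     from megengine.utils.network import Network
        #     graph = Network.load(fpath)
        #     print([v.name for v in graph.output_vars])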

def test_dump_bn_train_mode():
    # Dumping a graph that still contains BatchNorm in training mode is
    # invalid, so dump() is expected to raise a RuntimeError.
    @trace(symbolic=True, capture_as_const=True)
    def bn_train(data):
        pred = M.BatchNorm2d(10)(data).sum()
        return pred

    data = mge.tensor(np.random.random((10, 10, 10, 10)))
    bn_train(data)
    with pytest.raises(RuntimeError):
        bn_train.dump("test.mge")

class ViTmode(M.Module):
    # Minimal ViT-style model: non-overlapping patch embedding via a strided
    # convolution, followed by a linear classification head per patch.
    def __init__(self, patch_size=16, in_chans=3, embed_dim=384):
        super().__init__()
        self.proj = M.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )
        self.head = M.Linear(embed_dim, 1000)

    def forward(self, x):
        x = self.proj(x)
        x = F.flatten(x, 2).transpose(0, 2, 1)  # (N, C, H, W) -> (N, H*W, C)
        x = self.head(x)
        return x

def test_ViTmode_trace_train():
    model = ViTmode(embed_dim=384)
    data = mge.random.normal(size=(1, 3, 224, 224))
    opt = SGD(model.parameters(), lr=0.01)
    gm = GradManager()
    gm.attach(model.parameters())

    @trace(symbolic=True, capture_as_const=True)
    def train():
        # Trace two full optimization steps as a single captured graph.
        for i in range(2):
            with gm:
                loss = model(data)
                gm.backward(loss)
            opt.step().clear_grad()

    train()
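
# Usage sketch (invocation assumed, not part of the file):
#     pytest -q test_trace_dump.py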