You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_converge_with_drop.py 3.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. # -*- coding: utf-8 -*-
  2. import itertools
  3. import numpy as np
  4. import megengine as mge
  5. import megengine.autodiff as ad
  6. import megengine.functional as F
  7. from megengine import Tensor
  8. from megengine.core import get_option, set_option
  9. from megengine.module import Linear, Module
  10. from megengine.optimizer import SGD
  11. batch_size = 64
  12. data_shape = (batch_size, 2)
  13. label_shape = (batch_size,)
  14. def minibatch_generator():
  15. while True:
  16. inp_data = np.zeros((batch_size, 2))
  17. label = np.zeros(batch_size, dtype=np.int32)
  18. for i in range(batch_size):
  19. # [x0, x1], sampled from U[-1, 1]
  20. inp_data[i, :] = np.random.rand(2) * 2 - 1
  21. label[i] = 0 if np.prod(inp_data[i]) < 0 else 1
  22. yield inp_data.astype(np.float32), label.astype(np.int32)
  23. def calculate_precision(data: np.ndarray, pred: np.ndarray) -> float:
  24. """ Calculate precision for given data and prediction.
  25. :type data: [[x, y], ...]
  26. :param data: Input data
  27. :type pred: [[x_pred, y_pred], ...]
  28. :param pred: Network output data
  29. """
  30. correct = 0
  31. assert len(data) == len(pred)
  32. for inp_data, pred_output in zip(data, pred):
  33. label = 0 if np.prod(inp_data) < 0 else 1
  34. pred_label = np.argmax(pred_output)
  35. if pred_label == label:
  36. correct += 1
  37. return float(correct) / len(data)
  38. class XORNet(Module):
  39. def __init__(self):
  40. self.mid_layers = 14
  41. self.num_class = 2
  42. super().__init__()
  43. self.fc0 = Linear(self.num_class, self.mid_layers, bias=True)
  44. self.fc1 = Linear(self.mid_layers, self.mid_layers, bias=True)
  45. self.fc2 = Linear(self.mid_layers, self.num_class, bias=True)
  46. def forward(self, x):
  47. y = self.fc0(x)
  48. x = F.tanh(y)
  49. y = self.fc1(x)
  50. x = F.tanh(y)
  51. x = self.fc2(x)
  52. y = (x + x) / 2 # in order to test drop()
  53. y._drop()
  54. return y
  55. def test_training_converge_with_drop():
  56. set_option("enable_drop", 1)
  57. net = XORNet()
  58. opt = SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
  59. gm = ad.GradManager().attach(net.parameters())
  60. def train(data, label):
  61. with gm:
  62. pred = net(data)
  63. loss = F.nn.cross_entropy(pred, label)
  64. gm.backward(loss)
  65. return loss
  66. def infer(data):
  67. return net(data)
  68. train_dataset = minibatch_generator()
  69. losses = []
  70. for data, label in itertools.islice(train_dataset, 2000):
  71. data = Tensor(data, dtype=np.float32)
  72. label = Tensor(label, dtype=np.int32)
  73. opt.clear_grad()
  74. loss = train(data, label)
  75. opt.step()
  76. losses.append(loss.numpy())
  77. assert np.mean(losses[-100:]) < 0.1, "Final training Loss must be low enough"
  78. ngrid = 10
  79. x = np.linspace(-1.0, 1.0, ngrid)
  80. xx, yy = np.meshgrid(x, x)
  81. xx = xx.reshape((ngrid * ngrid, 1))
  82. yy = yy.reshape((ngrid * ngrid, 1))
  83. data = mge.tensor(np.concatenate((xx, yy), axis=1).astype(np.float32))
  84. pred = infer(Tensor(data)).numpy()
  85. precision = calculate_precision(data.numpy(), pred)
  86. assert precision == 1.0, "Test precision must be high enough, get {}".format(
  87. precision
  88. )
  89. set_option("enable_drop", 0)