You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

example.py 4.9 kB

7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. import time
  2. import aggregation
  3. import dataloader
  4. import embedding
  5. import encoder
  6. import predict
  7. import torch
  8. import torch.nn as nn
  9. import torch.optim as optim
  # Model hyper-parameters shared by MyNet and Net.
  # WORD_NUM: rows of the embedding lookup table (presumably the vocabulary size).
  # WORD_SIZE: embedding width, also used as the LSTM input size.
  WORD_NUM, WORD_SIZE = 357361, 100
  # HIDDEN_SIZE: LSTM hidden size (encoder output is treated as 2 * HIDDEN_SIZE wide).
  # D_A, R: arguments of the self-attention layer (presumably attention
  # hidden size and number of attention hops -- confirm in aggregation.py).
  HIDDEN_SIZE, D_A, R = 300, 350, 10
  # MLP_HIDDEN: hidden width of the MLP classifier; CLASSES_NUM: number of output classes.
  MLP_HIDDEN, CLASSES_NUM = 2000, 5
  17. from fastNLP.models.base_model import BaseModel
  class MyNet(BaseModel):
      """Sentiment model expressed through fastNLP's encode/aggregate/decode hooks.

      Pipeline: embedding lookup -> LSTM encoder -> self-attention -> MLP.
      NOTE(review): assumes BaseModel drives encode -> aggregate -> decode in
      that order (decode reads the penalty stored by aggregate) -- confirm
      against fastNLP.models.base_model.
      """

      def __init__(self):
          super(MyNet, self).__init__()
          self.embedding = embedding.Lookuptable(WORD_NUM, WORD_SIZE)
          self.encoder = encoder.Lstm(WORD_SIZE, HIDDEN_SIZE, 1, 0.5, True)
          self.aggregation = aggregation.Selfattention(2 * HIDDEN_SIZE, D_A, R)
          self.predict = predict.MLP(R * HIDDEN_SIZE * 2, MLP_HIDDEN, CLASSES_NUM)
          # Penalty produced by the most recent aggregate() call; returned by decode().
          self.penalty = None

      def encode(self, x):
          """Embed the input ``x`` and run it through the LSTM encoder."""
          # Must call the encoder sub-module, not this method (that recursed forever).
          return self.encoder(self.embedding(x))

      def aggregate(self, x):
          """Apply self-attention to the encoder output and remember its penalty."""
          # Must call the attention sub-module, not this method (that recursed forever).
          x, self.penalty = self.aggregation(x)
          return x

      def decode(self, x):
          """Return ``[classifier output, attention penalty]`` for aggregated ``x``."""
          return [self.predict(x), self.penalty]
  class Net(nn.Module):
      """Sentiment classifier: embedding -> LSTM -> self-attention -> MLP.

      ``forward(x)`` returns the pair ``(scores, penalty)``: the MLP output and
      the penalty term reported by the self-attention layer.
      """

      def __init__(self):
          super(Net, self).__init__()
          # Encoder output width as expected by the attention layer
          # (presumably a bidirectional LSTM, hence the factor 2 -- confirm).
          enc_width = 2 * HIDDEN_SIZE
          # Sub-modules are assigned in the original order so module
          # registration (and therefore state_dict key order) is unchanged.
          self.embedding = embedding.Lookuptable(WORD_NUM, WORD_SIZE)
          self.encoder = encoder.Lstm(WORD_SIZE, HIDDEN_SIZE, 1, 0.5, True)
          self.aggregation = aggregation.Selfattention(enc_width, D_A, R)
          # The MLP consumes R * enc_width features from the attention layer.
          self.predict = predict.MLP(R * enc_width, MLP_HIDDEN, CLASSES_NUM)

      def forward(self, x):
          attended, penalty = self.aggregation(self.encoder(self.embedding(x)))
          return self.predict(attended), penalty
  def train(model_dict=None, using_cuda=True, learning_rate=0.06,
            momentum=0.3, batch_size=32, epochs=5, coef=1.0, interval=10):
      """Train a ``Net`` on ``train_set.pkl`` with SGD and save its weights.

      Args:
          model_dict: path of a saved ``state_dict`` to resume training from,
              or ``None`` to train a freshly constructed model.
          using_cuda: if true, the model is moved to the GPU and the loader is
              asked for GPU batches.
          learning_rate, momentum: hyper-parameters of the SGD optimizer.
          batch_size: number of samples per batch requested from the loader.
          epochs: number of passes over the training set.
          coef: weight of the attention penalty relative to the cross-entropy
              loss (penalty is summed over the batch and divided by its size).
          interval: print statistics every ``interval`` batches.

      The final weights are saved as ``model_dict_<%Y-%m-%d-%H:%M:%S>.dict``
      in the working directory and can be passed back as ``model_dict``.

      Requires PyTorch >= 0.4 (``Tensor.item``, ``clip_grad_norm_``).
      """
      net = Net().cuda() if using_cuda else Net()
      if model_dict is not None:
          # Load tensors onto the CPU first so a checkpoint written by a GPU
          # run also loads on a CPU-only machine; load_state_dict then copies
          # the values onto whatever device ``net`` lives on.
          state = torch.load(model_dict, map_location=lambda storage, loc: storage)
          net.load_state_dict(state)
      optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=momentum)
      criterion = nn.CrossEntropyLoss()
      dataset = dataloader.DataLoader("train_set.pkl", batch_size, using_cuda=using_cuda)

      # Statistics for the current reporting window (reset after each report).
      loss_count = 0.0    # sum of per-batch training losses
      count = 0           # correctly classified samples
      n_samples = 0       # samples seen
      n_batches = 0       # batches seen
      prepare_time = 0.0  # time spent waiting for the data loader
      run_time = 0.0      # time spent in forward/backward/optimizer step
      for epoch in range(epochs):
          print("epoch: %d" % (epoch))
          # t1 is taken *before* the loader produces the next batch, so
          # prepare_time includes the actual batch-loading work.
          t1 = time.time()
          for i, batch in enumerate(dataset):
              X = batch["feature"]
              y = batch["class"]
              t2 = time.time()
              # Normalise by the real batch size: the last batch may be short.
              cur_size = y.size(0)
              y_pred, y_penl = net(X)
              penalty = torch.sum(y_penl) / cur_size
              loss = criterion(y_pred, y) + penalty * coef
              optimizer.zero_grad()
              loss.backward()
              nn.utils.clip_grad_norm_(net.parameters(), 0.5)
              optimizer.step()
              t3 = time.time()

              loss_count += loss.item()
              prepare_time += (t2 - t1)
              run_time += (t3 - t2)
              _, idx = torch.max(y_pred.data, dim=1)
              count += torch.sum(torch.eq(idx, y.data)).item()
              n_samples += cur_size
              n_batches += 1

              if (i + 1) % interval == 0:
                  print("epoch : %d, iters: %d" % (epoch, i + 1))
                  # Mean per-batch loss and per-sample accuracy over the window;
                  # float() avoids integer division under Python 2.
                  print("loss count:" + str(loss_count / n_batches))
                  print("acuracy:" + str(float(count) / n_samples))
                  print("penalty:" + str(penalty.item()))
                  print("prepare time:" + str(prepare_time))
                  print("run time:" + str(run_time))
                  prepare_time = 0.0
                  run_time = 0.0
                  loss_count = 0.0
                  count = 0
                  n_samples = 0
                  n_batches = 0
              t1 = time.time()
      string = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())
      torch.save(net.state_dict(), "model_dict_%s.dict" % (string))
  def test(model_dict, using_cuda=True, batch_size=1):
      """Evaluate a saved ``Net`` on ``test_set.pkl`` and print its accuracy.

      Args:
          model_dict: path of a ``state_dict`` file (e.g. one saved by ``train``).
          using_cuda: if true, the model is moved to the GPU and the loader is
              asked for GPU batches.
          batch_size: batch size requested from the loader (default 1, as before).

      Requires PyTorch >= 0.4 (``torch.no_grad``, ``Tensor.item``).
      """
      net = Net().cuda() if using_cuda else Net()
      # Load onto the CPU first so GPU-saved weights also load without CUDA.
      net.load_state_dict(torch.load(model_dict, map_location=lambda storage, loc: storage))
      # Evaluation mode: turns off training-only behaviour such as dropout
      # (the encoder is built with a 0.5 argument, presumably a dropout rate).
      net.eval()
      dataset = dataloader.DataLoader("test_set.pkl", batch_size=batch_size, using_cuda=using_cuda)
      count = 0
      total = 0
      with torch.no_grad():  # no autograd graph is needed for evaluation
          for batch in dataset:
              X = batch["feature"]
              y = batch["class"]
              y_pred, _ = net(X)
              _, idx = torch.max(y_pred.data, dim=1)
              count += torch.sum(torch.eq(idx, y.data)).item()
              total += y.size(0)
      # Divide by the samples actually evaluated; float() keeps this from being
      # integer division under Python 2, and an empty set reports 0 instead of
      # raising ZeroDivisionError.
      accuracy = float(count) / total if total else 0.0
      print("accuracy: %f" % (accuracy))
  if __name__ == "__main__":
      # Run training on the GPU only when CUDA is actually available.
      use_gpu = torch.cuda.is_available()
      train(using_cuda=use_gpu)