
example.py 5.3 kB

import time

import torch
import torch.nn as nn
import torch.optim as optim

import aggregation
import dataloader
import embedding
import encoder
import predict

from fastNLP.models.base_model import BaseModel
from fastNLP.action.trainer import BaseTrainer

# Hyper-parameters
WORD_NUM = 357361   # vocabulary size
WORD_SIZE = 100     # word embedding dimension
HIDDEN_SIZE = 300   # LSTM hidden size
D_A = 350           # self-attention hidden dimension (d_a)
R = 10              # number of attention hops (r)
MLP_HIDDEN = 2000   # hidden size of the MLP classifier
CLASSES_NUM = 5     # number of output classes

class MyNet(BaseModel):
    """The same pipeline expressed as a fastNLP BaseModel with encode/aggregate/decode steps."""

    def __init__(self):
        super(MyNet, self).__init__()
        self.embedding = embedding.Lookuptable(WORD_NUM, WORD_SIZE)
        self.encoder = encoder.Lstm(WORD_SIZE, HIDDEN_SIZE, 1, 0.5, True)
        self.aggregation = aggregation.Selfattention(2 * HIDDEN_SIZE, D_A, R)
        self.predict = predict.MLP(R * HIDDEN_SIZE * 2, MLP_HIDDEN, CLASSES_NUM)
        self.penalty = None

    def encode(self, x):
        # Embed the tokens, then run the LSTM encoder.
        return self.encoder(self.embedding(x))

    def aggregate(self, x):
        # Self-attention pooling; keep the penalization term for the loss.
        x, self.penalty = self.aggregation(x)
        return x

    def decode(self, x):
        return [self.predict(x), self.penalty]

class Net(nn.Module):
    """
    A model for sentiment analysis using an LSTM and self-attention.
    """
    def __init__(self):
        super(Net, self).__init__()
        self.embedding = embedding.Lookuptable(WORD_NUM, WORD_SIZE)
        self.encoder = encoder.Lstm(WORD_SIZE, HIDDEN_SIZE, 1, 0.5, True)
        self.aggregation = aggregation.Selfattention(2 * HIDDEN_SIZE, D_A, R)
        self.predict = predict.MLP(R * HIDDEN_SIZE * 2, MLP_HIDDEN, CLASSES_NUM)

    def forward(self, x):
        x = self.embedding(x)
        x = self.encoder(x)
        x, penalty = self.aggregation(x)
        x = self.predict(x)
        return x, penalty

class MyTrainer(BaseTrainer):
    def __init__(self, args):
        super(MyTrainer, self).__init__(args)
        self.optimizer = None

    def define_optimizer(self):
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.01, momentum=0.9)

    def define_loss(self):
        self.loss_func = nn.CrossEntropyLoss()

def train(model_dict=None, using_cuda=True, learning_rate=0.06,
          momentum=0.3, batch_size=32, epochs=5, coef=1.0, interval=10):
    """
    Training procedure.

    Args:
        model_dict: if given (a file path), training continues from the saved state dict;
            otherwise a new model is trained from scratch.
        using_cuda: if True, training runs on the GPU.
        learning_rate, momentum: parameters of the SGD optimizer.
        coef: coefficient weighting the penalization term relative to the cross-entropy loss.
        interval: reporting frequency, in batches.

    The result is saved as "model_dict_<current time>.dict" and can be used for further training.
    """
    if using_cuda:
        net = Net().cuda()
    else:
        net = Net()

    if model_dict is not None:
        net.load_state_dict(torch.load(model_dict))

    optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=momentum)
    criterion = nn.CrossEntropyLoss()
    dataset = dataloader.DataLoader("train_set.pkl", batch_size, using_cuda=using_cuda)

    # statistics
    loss_count = 0
    prepare_time = 0
    run_time = 0
    count = 0

    for epoch in range(epochs):
        print("epoch: %d" % epoch)
        for i, batch in enumerate(dataset):
            t1 = time.time()
            X = batch["feature"]
            y = batch["class"]

            t2 = time.time()
            y_pred, y_penl = net(X)
            # Cross-entropy loss plus the batch-averaged self-attention penalization term.
            loss = criterion(y_pred, y) + torch.sum(y_penl) / batch_size * coef
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(net.parameters(), 0.5)
            optimizer.step()
            t3 = time.time()

            loss_count += torch.sum(y_penl).item()
            prepare_time += (t2 - t1)
            run_time += (t3 - t2)
            p, idx = torch.max(y_pred.data, dim=1)
            count += torch.sum(torch.eq(idx.cpu(), y.data.cpu())).item()

            if (i + 1) % interval == 0:
                print("epoch: %d, iters: %d" % (epoch, i + 1))
                print("loss count:" + str(loss_count / (interval * batch_size)))
                print("accuracy:" + str(count / (interval * batch_size)))
                print("penalty:" + str(torch.sum(y_penl).item() / batch_size))
                print("prepare time:" + str(prepare_time))
                print("run time:" + str(run_time))
                prepare_time = 0
                run_time = 0
                loss_count = 0
                count = 0

    string = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())
    torch.save(net.state_dict(), "model_dict_%s.dict" % string)

def test(model_dict, using_cuda=True):
    if using_cuda:
        net = Net().cuda()
    else:
        net = Net()
    net.load_state_dict(torch.load(model_dict))
    net.eval()  # disable dropout for evaluation

    dataset = dataloader.DataLoader("test_set.pkl", batch_size=1, using_cuda=using_cuda)
    count = 0

    for i, batch in enumerate(dataset):
        X = batch["feature"]
        y = batch["class"]
        y_pred, _ = net(X)
        p, idx = torch.max(y_pred.data, dim=1)
        count += torch.sum(torch.eq(idx.cpu(), y.data.cpu())).item()

    print("accuracy: %f" % (count / dataset.num))


if __name__ == "__main__":
    train(using_cuda=torch.cuda.is_available())
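
For reference, a minimal usage sketch of the entry points above. It assumes "train_set.pkl" and "test_set.pkl" are present in the working directory and that example.py is importable as a module; the checkpoint filename shown is only illustrative of the "model_dict_<current time>.dict" name that train() produces.

# Minimal usage sketch (checkpoint filename below is hypothetical).
import torch
from example import train, test  # the functions defined in example.py above

use_cuda = torch.cuda.is_available()

# Train a new model from scratch (reads "train_set.pkl").
train(using_cuda=use_cuda, learning_rate=0.06, momentum=0.3, batch_size=32, epochs=5)

# Continue training from a previously saved state dict, then evaluate it on "test_set.pkl".
train(model_dict="model_dict_2018-07-01-12:00:00.dict", using_cuda=use_cuda)
test("model_dict_2018-07-01-12:00:00.dict", using_cuda=use_cuda)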

A lightweight Natural Language Processing (NLP) toolkit that aims to reduce the amount of engineering code in user projects, such as data-processing loops, training loops, and multi-GPU execution.