You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

main.py 2.7 kB

7 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. import torch.nn.functional as F
  2. from fastNLP.core.preprocess import ClassPreprocess as Preprocess
  3. from fastNLP.core.trainer import ClassificationTrainer
  4. from fastNLP.loader.config_loader import ConfigLoader
  5. from fastNLP.loader.config_loader import ConfigSection
  6. from fastNLP.loader.dataset_loader import ClassDataSetLoader as Dataset_loader
  7. from fastNLP.models.base_model import BaseModel
  8. from fastNLP.modules.aggregator.self_attention import SelfAttention
  9. from fastNLP.modules.decoder.MLP import MLP
  10. from fastNLP.modules.encoder.embedding import Embedding as Embedding
  11. from fastNLP.modules.encoder.lstm import LSTM
  12. train_data_path = 'small_train_data.txt'
  13. dev_data_path = 'small_dev_data.txt'
  14. # emb_path = 'glove.txt'
  15. lstm_hidden_size = 300
  16. embeding_size = 300
  17. attention_unit = 350
  18. attention_hops = 10
  19. class_num = 5
  20. nfc = 3000
  21. ### data load ###
  22. train_dataset = Dataset_loader(train_data_path)
  23. train_data = train_dataset.load()
  24. dev_args = Dataset_loader(dev_data_path)
  25. dev_data = dev_args.load()
  26. ###### preprocess ####
  27. preprocess = Preprocess()
  28. word2index, label2index = preprocess.build_dict(train_data)
  29. train_data, dev_data = preprocess.run(train_data, dev_data)
  30. # emb = EmbedLoader(emb_path)
  31. # embedding = emb.load_embedding(emb_dim= embeding_size , emb_file= emb_path ,word_dict= word2index)
  32. ### construct vocab ###
  33. class SELF_ATTENTION_YELP_CLASSIFICATION(BaseModel):
  34. def __init__(self, args=None):
  35. super(SELF_ATTENTION_YELP_CLASSIFICATION,self).__init__()
  36. self.embedding = Embedding(len(word2index) ,embeding_size , init_emb= None )
  37. self.lstm = LSTM(input_size=embeding_size, hidden_size=lstm_hidden_size, bidirectional=True)
  38. self.attention = SelfAttention(lstm_hidden_size * 2 ,dim =attention_unit ,num_vec=attention_hops)
  39. self.mlp = MLP(size_layer=[lstm_hidden_size * 2*attention_hops ,nfc ,class_num ])
  40. def forward(self,x):
  41. x_emb = self.embedding(x)
  42. output = self.lstm(x_emb)
  43. after_attention, penalty = self.attention(output,x)
  44. after_attention =after_attention.view(after_attention.size(0),-1)
  45. output = self.mlp(after_attention)
  46. return output
  47. def loss(self, predict, ground_truth):
  48. print("predict:%s; g:%s" % (str(predict.size()), str(ground_truth.size())))
  49. print(ground_truth)
  50. return F.cross_entropy(predict, ground_truth)
  51. train_args = ConfigSection()
  52. ConfigLoader("good path").load_config('config.cfg',{"train": train_args})
  53. train_args['vocab'] = len(word2index)
  54. trainer = ClassificationTrainer(**train_args.data)
  55. # for k in train_args.__dict__.keys():
  56. # print(k, train_args[k])
  57. model = SELF_ATTENTION_YELP_CLASSIFICATION(train_args)
  58. trainer.train(model,train_data , dev_data)