You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

main.py 2.8 kB

7 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. import os
  2. import torch.nn.functional as F
  3. from fastNLP.loader.dataset_loader import ClassDatasetLoader as Dataset_loader
  4. from fastNLP.loader.embed_loader import EmbedLoader as EmbedLoader
  5. from fastNLP.loader.config_loader import ConfigSection
  6. from fastNLP.loader.config_loader import ConfigLoader
  7. from fastNLP.models.base_model import BaseModel
  8. from fastNLP.core.preprocess import ClassPreprocess as Preprocess
  9. from fastNLP.core.trainer import ClassificationTrainer
  10. from fastNLP.modules.encoder.embedding import Embedding as Embedding
  11. from fastNLP.modules.encoder.lstm import Lstm
  12. from fastNLP.modules.aggregation.self_attention import SelfAttention
  13. from fastNLP.modules.decoder.MLP import MLP
# Paths to the training / development splits (plain text; format is
# whatever ClassDatasetLoader expects -- TODO confirm one example per line).
train_data_path = 'small_train_data.txt'
dev_data_path = 'small_dev_data.txt'
# emb_path = 'glove.txt'

# Model hyperparameters (read as module-level globals by the model class).
lstm_hidden_size = 300  # hidden units per LSTM direction (bidirectional below)
embeding_size = 300     # word-embedding dimension (NOTE: "embedding" misspelled; kept for compatibility)
attention_unit = 350    # self-attention hidden size (d_a in Lin et al. 2017)
attention_hops = 10     # number of attention hops / rows of the attention matrix (r)
class_num = 5           # number of output classes (Yelp star ratings)
nfc = 3000              # hidden width of the final MLP layer
### data load ###
# Load raw train/dev datasets from disk via the project loader.
train_dataset = Dataset_loader(train_data_path)
train_data = train_dataset.load()
dev_args = Dataset_loader(dev_data_path)  # NOTE(review): name suggests "args" but this is a dataset loader
dev_data = dev_args.load()

###### preprocess ####
# Build word/label vocabularies from the TRAINING split only, then map
# both splits to index sequences in place.
preprocess = Preprocess()
word2index, label2index = preprocess.build_dict(train_data)
train_data, dev_data = preprocess.run(train_data, dev_data)

# Optional pre-trained embedding loading (currently disabled):
# emb = EmbedLoader(emb_path)
# embedding = emb.load_embedding(emb_dim= embeding_size , emb_file= emb_path ,word_dict= word2index)
### construct vocab ###
  35. class SELF_ATTENTION_YELP_CLASSIFICATION(BaseModel):
  36. def __init__(self, args=None):
  37. super(SELF_ATTENTION_YELP_CLASSIFICATION,self).__init__()
  38. self.embedding = Embedding(len(word2index) ,embeding_size , init_emb= None )
  39. self.lstm = Lstm(input_size = embeding_size,hidden_size = lstm_hidden_size ,bidirectional = True)
  40. self.attention = SelfAttention(lstm_hidden_size * 2 ,dim =attention_unit ,num_vec=attention_hops)
  41. self.mlp = MLP(size_layer=[lstm_hidden_size * 2*attention_hops ,nfc ,class_num ])
  42. def forward(self,x):
  43. x_emb = self.embedding(x)
  44. output = self.lstm(x_emb)
  45. after_attention, penalty = self.attention(output,x)
  46. after_attention =after_attention.view(after_attention.size(0),-1)
  47. output = self.mlp(after_attention)
  48. return output
  49. def loss(self, predict, ground_truth):
  50. print("predict:%s; g:%s" % (str(predict.size()), str(ground_truth.size())))
  51. print(ground_truth)
  52. return F.cross_entropy(predict, ground_truth)
# Load the [train] section of config.cfg into train_args.
# NOTE(review): the "good path" string passed to ConfigLoader appears to be
# unused boilerplate -- confirm against the fastNLP ConfigLoader API.
train_args = ConfigSection()
ConfigLoader("good path").load_config('config.cfg',{"train": train_args})
train_args['vocab'] = len(word2index)  # inject vocabulary size for the trainer

# Trainer is configured entirely from the config section's key/value data.
trainer = ClassificationTrainer(**train_args.data)
# for k in train_args.__dict__.keys():
# print(k, train_args[k])
model = SELF_ATTENTION_YELP_CLASSIFICATION(train_args)
trainer.train(model,train_data , dev_data)