You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

train_char_cnn.py 7.9 kB

6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. import sys
  2. sys.path.append('../..')
  3. from fastNLP.core.const import Const as C
  4. import torch.nn as nn
  5. from fastNLP.io.pipe.classification import YelpFullPipe,YelpPolarityPipe,SST2Pipe,IMDBPipe
  6. from model.char_cnn import CharacterLevelCNN
  7. from fastNLP import CrossEntropyLoss, AccuracyMetric
  8. from fastNLP.core.trainer import Trainer
  9. from torch.optim import SGD
  10. from torch.autograd import Variable
  11. import torch
  12. from torch.optim.lr_scheduler import LambdaLR
  13. from fastNLP.core import LRScheduler
  14. ##hyper
  15. #todo 这里加入fastnlp的记录
  16. class Config():
  17. #seed=7777
  18. model_dir_or_name="en-base-uncased"
  19. embedding_grad= False,
  20. bert_embedding_larers= '4,-2,-1'
  21. train_epoch= 100
  22. num_classes=2
  23. task= "yelp_p"
  24. lr=0.01
  25. batch_size=128
  26. model_size="large"
  27. number_of_characters=69
  28. extra_characters=''
  29. max_length=1014
  30. weight_decay = 1e-5
  31. to_lower=True
  32. tokenizer = 'spacy' # 使用spacy进行分词
  33. char_cnn_config={
  34. "alphabet": {
  35. "en": {
  36. "lower": {
  37. "alphabet": "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}",
  38. "number_of_characters": 69
  39. },
  40. "both": {
  41. "alphabet": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}",
  42. "number_of_characters": 95
  43. }
  44. }
  45. },
  46. "model_parameters": {
  47. "small": {
  48. "conv": [
  49. #依次是channel,kennnel_size,maxpooling_size
  50. [256,7,3],
  51. [256,7,3],
  52. [256,3,-1],
  53. [256,3,-1],
  54. [256,3,-1],
  55. [256,3,3]
  56. ],
  57. "fc": [1024,1024]
  58. },
  59. "large":{
  60. "conv":[
  61. [1024, 7, 3],
  62. [1024, 7, 3],
  63. [1024, 3, -1],
  64. [1024, 3, -1],
  65. [1024, 3, -1],
  66. [1024, 3, 3]
  67. ],
  68. "fc": [2048,2048]
  69. }
  70. },
  71. "data": {
  72. "text_column": "SentimentText",
  73. "label_column": "Sentiment",
  74. "max_length": 1014,
  75. "num_of_classes": 2,
  76. "encoding": None,
  77. "chunksize": 50000,
  78. "max_rows": 100000,
  79. "preprocessing_steps": ["lower", "remove_hashtags", "remove_urls", "remove_user_mentions"]
  80. },
  81. "training": {
  82. "batch_size": 128,
  83. "learning_rate": 0.01,
  84. "epochs": 10,
  85. "optimizer": "sgd"
  86. }
  87. }
  88. ops=Config
  89. # set_rng_seeds(ops.seed)
  90. # print('RNG SEED: {}'.format(ops.seed))
  91. ##1.task相关信息:利用dataloader载入dataInfo
  92. #dataloader=SST2Loader()
  93. #dataloader=IMDBLoader()
  94. # dataloader=YelpLoader(fine_grained=True)
  95. # datainfo=dataloader.process(ops.datapath,char_level_op=True,split_dev_op=False)
  96. char_vocab=ops.char_cnn_config["alphabet"]["en"]["lower"]["alphabet"]
  97. ops.number_of_characters=len(char_vocab)
  98. ops.embedding_dim=ops.number_of_characters
  99. # load data set
  100. if ops.task == 'yelp_p':
  101. data_bundle = YelpPolarityPipe(lower=ops.to_lower, tokenizer=ops.tokenizer).process_from_file()
  102. elif ops.task == 'yelp_f':
  103. data_bundle = YelpFullPipe(lower=ops.to_lower, tokenizer=ops.tokenizer).process_from_file()
  104. elif ops.task == 'imdb':
  105. data_bundle = IMDBPipe(lower=ops.to_lower, tokenizer=ops.tokenizer).process_from_file()
  106. elif ops.task == 'sst-2':
  107. data_bundle = SST2Pipe(lower=ops.to_lower, tokenizer=ops.tokenizer).process_from_file()
  108. else:
  109. raise RuntimeError(f'NOT support {ops.task} task yet!')
  110. print(data_bundle)
  111. def wordtochar(words):
  112. chars = []
  113. #for word in words:
  114. #word = word.lower()
  115. for char in words:
  116. chars.append(char)
  117. #chars.append('')
  118. #chars.pop()
  119. return chars
  120. #chartoindex
  121. def chartoindex(chars):
  122. max_seq_len=ops.max_length
  123. zero_index=len(char_vocab)
  124. char_index_list=[]
  125. for char in chars:
  126. if char in char_vocab:
  127. char_index_list.append(char_vocab.index(char))
  128. else:
  129. #<unk>和<pad>均使用最后一个作为embbeding
  130. char_index_list.append(zero_index)
  131. if len(char_index_list) > max_seq_len:
  132. char_index_list = char_index_list[:max_seq_len]
  133. elif 0 < len(char_index_list) < max_seq_len:
  134. char_index_list = char_index_list+[zero_index]*(max_seq_len-len(char_index_list))
  135. elif len(char_index_list) == 0:
  136. char_index_list=[zero_index]*max_seq_len
  137. return char_index_list
  138. for dataset in data_bundle.datasets.values():
  139. dataset.apply_field(wordtochar, field_name="raw_words", new_field_name='chars')
  140. dataset.apply_field(chartoindex,field_name='chars',new_field_name='chars')
  141. # print(data_bundle.datasets['train'][0]['chars'])
  142. # print(data_bundle.datasets['train'][0]['raw_words'])
  143. data_bundle.datasets['train'].set_input('chars')
  144. data_bundle.datasets['test'].set_input('chars')
  145. data_bundle.datasets['train'].set_target('target')
  146. data_bundle.datasets['test'].set_target('target')
  147. ##2. 定义/组装模型,这里可以随意,就如果是fastNLP封装好的,类似CNNText就直接用初始化调用就好了,这里只是给出一个伪框架表示占位,在这里建立符合fastNLP输入输出规范的model
  148. class ModelFactory(nn.Module):
  149. """
  150. 用于拼装embedding,encoder,decoder 以及设计forward过程
  151. :param embedding: embbeding model
  152. :param encoder: encoder model
  153. :param decoder: decoder model
  154. """
  155. def __int__(self,embedding,encoder,decoder,**kwargs):
  156. super(ModelFactory,self).__init__()
  157. self.embedding=embedding
  158. self.encoder=encoder
  159. self.decoder=decoder
  160. def forward(self,x):
  161. return {C.OUTPUT:None}
  162. ## 2.或直接复用fastNLP的模型
  163. #vocab=datainfo.vocabs['words']
  164. vocab_label=data_bundle.vocabs['target']
  165. '''
  166. # emded_char=CNNCharEmbedding(vocab)
  167. # embed_word = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50', requires_grad=True)
  168. # embedding=StackEmbedding([emded_char, embed_word])
  169. # cnn_char_embed = CNNCharEmbedding(vocab)
  170. # lstm_char_embed = LSTMCharEmbedding(vocab)
  171. # embedding = StackEmbedding([cnn_char_embed, lstm_char_embed])
  172. '''
  173. #one-hot embedding
  174. embedding_weight= Variable(torch.zeros(len(char_vocab)+1, len(char_vocab)))
  175. for i in range(len(char_vocab)):
  176. embedding_weight[i][i]=1
  177. embedding=nn.Embedding(num_embeddings=len(char_vocab)+1,embedding_dim=len(char_vocab),padding_idx=len(char_vocab),_weight=embedding_weight)
  178. for para in embedding.parameters():
  179. para.requires_grad=False
  180. #CNNText太过于简单
  181. #model=CNNText(init_embed=embedding, num_classes=ops.num_classes)
  182. model=CharacterLevelCNN(ops,embedding)
  183. ## 3. 声明loss,metric,optimizer
  184. loss=CrossEntropyLoss
  185. metric=AccuracyMetric
  186. optimizer = SGD([param for param in model.parameters() if param.requires_grad == True],
  187. lr=ops.lr, momentum=0.9, weight_decay=ops.weight_decay)
  188. callbacks = []
  189. # callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5)))
  190. callbacks.append(
  191. LRScheduler(LambdaLR(optimizer, lambda epoch: ops.lr if epoch <
  192. ops.train_epoch * 0.8 else ops.lr * 0.1))
  193. )
  194. ## 4.定义train方法
  195. def train(model,datainfo,loss,metrics,optimizer,num_epochs=100):
  196. trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss(target='target'),batch_size=ops.batch_size,
  197. metrics=[metrics(target='target')], dev_data=datainfo.datasets['test'], device=[0,1,2], check_code_level=-1,
  198. n_epochs=num_epochs,callbacks=callbacks)
  199. print(trainer.train())
  200. if __name__=="__main__":
  201. train(model,data_bundle,loss,metric,optimizer,num_epochs=ops.train_epoch)