You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

train_char_cnn.py 7.4 kB


  # fastNLP's download server is currently only open for internal testing, so its
# base URL and cache directory have to be declared by hand; they are set here,
# before the fastNLP imports below. setdefault() keeps any value the user has
# already exported instead of silently overwriting it.
import os
os.environ.setdefault('FASTNLP_BASE_URL', 'http://10.141.222.118:8888/file/download/')
os.environ.setdefault('FASTNLP_CACHE_DIR', '/remote-home/hyan01/fastnlp_caches')
import sys
# Make the directory two levels above this script importable (presumably the
# fastNLP checkout). It is resolved from the script's own location rather than
# the current working directory, so launching from elsewhere still works.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..')))
  7. from fastNLP.core.const import Const as C
  8. import torch.nn as nn
  9. from data.yelpLoader import yelpLoader
  10. from data.sstLoader import sst2Loader
  11. from data.IMDBLoader import IMDBLoader
  12. from model.char_cnn import CharacterLevelCNN
  13. from fastNLP.core.vocabulary import Vocabulary
  14. from fastNLP.models.cnn_text_classification import CNNText
  15. from fastNLP.modules.encoder.embedding import CNNCharEmbedding,StaticEmbedding,StackEmbedding,LSTMCharEmbedding
  16. from fastNLP import CrossEntropyLoss, AccuracyMetric
  17. from fastNLP.core.trainer import Trainer
  18. from torch.optim import SGD
  19. from torch.autograd import Variable
  20. import torch
  21. from fastNLP import BucketSampler
  # Hyper-parameters.
# TODO: record these hyper-parameters through fastNLP's logging.
class Config():
    """Hyper-parameters, data paths and char-CNN settings for this experiment.

    The class object itself is used as a namespace: the script reads and
    writes attributes on the class, not on an instance.
    """

    model_dir_or_name = "en-base-uncased"
    # Must be a plain bool: a trailing comma here would silently turn it into
    # the tuple (False,), which is truthy.
    embedding_grad = False
    # (sic) misspelled name kept for backward compatibility; prefer the alias.
    bert_embedding_larers = '4,-2,-1'
    bert_embedding_layers = bert_embedding_larers
    train_epoch = 50
    num_classes = 2
    # NOTE(review): task says "IMDB" but the active datapath below points at
    # yelp_polarity -- confirm which one is intended.
    task = "IMDB"
    # yelp_p
    datapath = {"train": "/remote-home/ygwang/yelp_polarity/train.csv",
                "test": "/remote-home/ygwang/yelp_polarity/test.csv"}
    # IMDB
    # datapath = {"train": "/remote-home/ygwang/IMDB_data/train.csv",
    #             "test": "/remote-home/ygwang/IMDB_data/test.csv"}
    # sst
    # datapath = {"train": "/remote-home/ygwang/workspace/GLUE/SST-2/train.tsv",
    #             "dev": "/remote-home/ygwang/workspace/GLUE/SST-2/dev.tsv"}
    lr = 0.01
    batch_size = 128
    model_size = "large"
    number_of_characters = 69
    extra_characters = ''
    max_length = 1014
    char_cnn_config = {
        "alphabet": {
            "en": {
                # NOTE: '-' occurs twice in both alphabet strings, so each string
                # is one character longer than its number of distinct characters.
                "lower": {
                    "alphabet": "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}",
                    "number_of_characters": 69
                },
                "both": {
                    "alphabet": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}",
                    "number_of_characters": 95
                }
            }
        },
        "model_parameters": {
            "small": {
                # Each conv entry: [channels, kernel_size, maxpooling_size];
                # -1 presumably means "no pooling" -- confirm in CharacterLevelCNN.
                "conv": [
                    [256, 7, 3],
                    [256, 7, 3],
                    [256, 3, -1],
                    [256, 3, -1],
                    [256, 3, -1],
                    [256, 3, 3]
                ],
                "fc": [1024, 1024]
            },
            "large": {
                "conv": [
                    [1024, 7, 3],
                    [1024, 7, 3],
                    [1024, 3, -1],
                    [1024, 3, -1],
                    [1024, 3, -1],
                    [1024, 3, 3]
                ],
                "fc": [2048, 2048]
            }
        },
        "data": {
            "text_column": "SentimentText",
            "label_column": "Sentiment",
            "max_length": 1014,
            "num_of_classes": 2,
            "encoding": None,
            "chunksize": 50000,
            "max_rows": 100000,
            "preprocessing_steps": ["lower", "remove_hashtags", "remove_urls", "remove_user_mentions"]
        },
        "training": {
            "batch_size": 128,
            "learning_rate": 0.01,
            "epochs": 10,
            "optimizer": "sgd"
        }
    }
  # The Config class object is the options namespace for the rest of the script.
ops = Config

# 1. Task-specific information: load the DataInfo through a dataloader.
# Loaders for the other datasets:
#   dataloader = sst2Loader()
#   dataloader = IMDBLoader()
# NOTE(review): fine_grained=True while Config.num_classes == 2 and the datapath
# is yelp_polarity -- confirm which label set yelpLoader produces here.
dataloader = yelpLoader(fine_grained=True)
datainfo = dataloader.process(ops.datapath, char_level_op=True)

# Character alphabet used for indexing ("lower" variant).
char_vocab = ops.char_cnn_config["alphabet"]["en"]["lower"]["alphabet"]
# One-hot embedding: the embedding width equals the alphabet length.
ops.embedding_dim = ops.number_of_characters = len(char_vocab)
  110. #chartoindex
  111. def chartoindex(chars):
  112. max_seq_len=ops.max_length
  113. zero_index=len(char_vocab)
  114. char_index_list=[]
  115. for char in chars:
  116. if char in char_vocab:
  117. char_index_list.append(char_vocab.index(char))
  118. else:
  119. #<unk>和<pad>均使用最后一个作为embbeding
  120. char_index_list.append(zero_index)
  121. if len(char_index_list) > max_seq_len:
  122. char_index_list = char_index_list[:max_seq_len]
  123. elif 0 < len(char_index_list) < max_seq_len:
  124. char_index_list = char_index_list+[zero_index]*(max_seq_len-len(char_index_list))
  125. elif len(char_index_list) == 0:
  126. char_index_list=[zero_index]*max_seq_len
  127. return char_index_list
  # Turn every split's raw 'chars' field into fixed-length alphabet indices.
for dataset in datainfo.datasets.values():
    dataset.apply_field(chartoindex, field_name='chars', new_field_name='chars')

# Declare model inputs and targets on the splits used for training/evaluation.
for split in ('train', 'test'):
    datainfo.datasets[split].set_input('chars')
    datainfo.datasets[split].set_target('target')
  # 2. Define / assemble the model. Anything works here; a model already packaged
# by fastNLP (such as CNNText) can simply be instantiated instead. This class is
# only a placeholder skeleton showing where a model that follows fastNLP's
# input/output conventions would be built.
class ModelFactory(nn.Module):
    """Assemble an embedding, an encoder and a decoder, and define the forward pass.

    :param embedding: embedding module
    :param encoder: encoder module
    :param decoder: decoder module
    :param kwargs: accepted for interface compatibility; currently unused
    """

    def __init__(self, embedding, encoder, decoder, **kwargs):
        # This was previously misspelled ``__int__``, so it was never used as the
        # constructor. ``ModelFactory(e, enc, dec)`` then fell through to
        # ``nn.Module.__init__``, which rejected the extra arguments.
        super(ModelFactory, self).__init__()
        self.embedding = embedding
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        # Placeholder: returns fastNLP's output key with no prediction.
        return {C.OUTPUT: None}
  # 2. Or reuse a model that fastNLP already provides.
# vocab = datainfo.vocabs['words']
vocab_label = datainfo.vocabs['target']
# Alternative embeddings that were tried (they need a word vocabulary):
#   emded_char = CNNCharEmbedding(vocab)
#   embed_word = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50', requires_grad=True)
#   embedding = StackEmbedding([emded_char, embed_word])
#   cnn_char_embed = CNNCharEmbedding(vocab)
#   lstm_char_embed = LSTMCharEmbedding(vocab)
#   embedding = StackEmbedding([cnn_char_embed, lstm_char_embed])

# Fixed one-hot character embedding. Row i (i < len(char_vocab)) is the unit
# vector e_i. The extra last row, at index len(char_vocab), is all zeros; that
# is the index chartoindex emits for unknown characters and padding.
embedding_weight = torch.eye(len(char_vocab) + 1, len(char_vocab))
# from_pretrained(freeze=True) replaces the private `_weight` kwarg, the
# deprecated Variable wrapper, and the manual requires_grad=False loop.
embedding = nn.Embedding.from_pretrained(
    embedding_weight,
    freeze=True,
    padding_idx=len(char_vocab),
)
# CNNText is too simple for this task:
# model = CNNText(init_embed=embedding, num_classes=ops.num_classes)
model = CharacterLevelCNN(ops, embedding)
  # 3. Declare the loss, the metric and the optimizer.
# Loss and metric are kept as classes; train() instantiates them with target='target'.
loss = CrossEntropyLoss
metric = AccuracyMetric
# Optimise only parameters that still require gradients (the one-hot embedding is frozen).
optimizer = SGD([p for p in model.parameters() if p.requires_grad], lr=ops.lr)
  # 4. Define the training routine.
def train(model, datainfo, loss, metrics, optimizer, num_epochs=100, device=0):
    """Train *model* with fastNLP's Trainer, then print and return the result.

    :param model: the module to train
    :param datainfo: object whose ``datasets`` mapping has 'train' and 'test' splits
    :param loss: loss CLASS (e.g. CrossEntropyLoss); instantiated here with target='target'
    :param metrics: metric CLASS (e.g. AccuracyMetric); instantiated here with target='target'
    :param optimizer: optimizer handed to the Trainer
    :param num_epochs: number of training epochs
    :param device: passed to the Trainer as ``device``. Defaults to 0, the value
        that was previously hard-coded; pass another value on hosts without GPU 0
        (check fastNLP's Trainer docs for accepted values).
    :return: whatever ``Trainer.train()`` returns
    """
    trainer = Trainer(
        datainfo.datasets['train'],
        model,
        optimizer=optimizer,
        loss=loss(target='target'),
        metrics=[metrics(target='target')],
        # NOTE(review): the test split doubles as dev data, so any model
        # selection done by the Trainer sees test labels.
        dev_data=datainfo.datasets['test'],
        device=device,
        check_code_level=-1,  # presumably disables fastNLP's pre-training code check -- confirm
        n_epochs=num_epochs,
    )
    result = trainer.train()
    print(result)
    return result
  # Script entry point.
if __name__ == "__main__":
    # Debug helpers:
    # print(vocab_label)
    # print(datainfo.datasets["train"])
    train(model=model, datainfo=datainfo, loss=loss, metrics=metric,
          optimizer=optimizer, num_epochs=ops.train_epoch)