
train_char_cnn.py 7.7 kB
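Training script for a character-level CNN text classifier on Yelp polarity, built on fastNLP; commented-out blocks switch the same pipeline to IMDB or SST-2.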

# First add the following paths to the environment. The download service is
# currently open to internal testing only, so the paths have to be declared
# manually.
import os
os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'

import sys
sys.path.append('../..')

import torch
import torch.nn as nn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR

from fastNLP.core.const import Const as C
from fastNLP.io.data_loader import YelpLoader
#from data.sstLoader import sst2Loader
from model.char_cnn import CharacterLevelCNN
from fastNLP import CrossEntropyLoss, AccuracyMetric
from fastNLP.core.trainer import Trainer
from fastNLP.core import LRScheduler


##hyper
#todo: add fastNLP run logging here
class Config():
    #seed = 7777
    model_dir_or_name = "en-base-uncased"
    embedding_grad = False
    bert_embedding_layers = '4,-2,-1'
    train_epoch = 50
    num_classes = 2
    task = "yelp_p"
    #yelp_p
    datapath = {"train": "/remote-home/ygwang/yelp_polarity/train.csv",
                "test": "/remote-home/ygwang/yelp_polarity/test.csv"}
    #IMDB
    #datapath = {"train": "/remote-home/ygwang/IMDB_data/train.csv",
    #            "test": "/remote-home/ygwang/IMDB_data/test.csv"}
    # sst
    # datapath = {"train": "/remote-home/ygwang/workspace/GLUE/SST-2/train.tsv",
    #             "dev": "/remote-home/ygwang/workspace/GLUE/SST-2/dev.tsv"}
    lr = 0.01
    batch_size = 128
    model_size = "large"
    number_of_characters = 69
    extra_characters = ''
    max_length = 1014
    weight_decay = 1e-5

    char_cnn_config = {
        "alphabet": {
            "en": {
                "lower": {
                    "alphabet": "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}",
                    "number_of_characters": 69
                },
                "both": {
                    "alphabet": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}",
                    "number_of_characters": 95
                }
            }
        },
        "model_parameters": {
            "small": {
                "conv": [
                    # each entry is [out_channels, kernel_size, max_pooling_size]
                    # (-1 appears to disable pooling for that layer)
                    [256, 7, 3],
                    [256, 7, 3],
                    [256, 3, -1],
                    [256, 3, -1],
                    [256, 3, -1],
                    [256, 3, 3]
                ],
                "fc": [1024, 1024]
            },
            "large": {
                "conv": [
                    [1024, 7, 3],
                    [1024, 7, 3],
                    [1024, 3, -1],
                    [1024, 3, -1],
                    [1024, 3, -1],
                    [1024, 3, 3]
                ],
                "fc": [2048, 2048]
            }
        },
        "data": {
            "text_column": "SentimentText",
            "label_column": "Sentiment",
            "max_length": 1014,
            "num_of_classes": 2,
            "encoding": None,
            "chunksize": 50000,
            "max_rows": 100000,
            "preprocessing_steps": ["lower", "remove_hashtags", "remove_urls", "remove_user_mentions"]
        },
        "training": {
            "batch_size": 128,
            "learning_rate": 0.01,
            "epochs": 10,
            "optimizer": "sgd"
        }
    }

ops = Config

# set_rng_seeds(ops.seed)
# print('RNG SEED: {}'.format(ops.seed))

##1. Task setup: load the dataInfo with a dataloader.
#dataloader = SST2Loader()
#dataloader = IMDBLoader()
dataloader = YelpLoader(fine_grained=True)
datainfo = dataloader.process(ops.datapath, char_level_op=True, split_dev_op=False)

char_vocab = ops.char_cnn_config["alphabet"]["en"]["lower"]["alphabet"]
ops.number_of_characters = len(char_vocab)
ops.embedding_dim = ops.number_of_characters

def chartoindex(chars):
    """Map a character sequence to a fixed-length list of alphabet indices."""
    max_seq_len = ops.max_length
    zero_index = len(char_vocab)
    char_index_list = []
    for char in chars:
        if char in char_vocab:
            char_index_list.append(char_vocab.index(char))
        else:
            # both <unk> and <pad> share the last embedding index
            char_index_list.append(zero_index)
    # truncate or pad to exactly max_seq_len
    if len(char_index_list) > max_seq_len:
        char_index_list = char_index_list[:max_seq_len]
    elif 0 < len(char_index_list) < max_seq_len:
        char_index_list = char_index_list + [zero_index] * (max_seq_len - len(char_index_list))
    elif len(char_index_list) == 0:
        char_index_list = [zero_index] * max_seq_len
    return char_index_list
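
# A quick sanity check of the indexing scheme (values assume the 69-character
# "lower" alphabet above, where 'h' -> 7, 'i' -> 8, '!' -> 40, and the shared
# <unk>/<pad> slot is index 69):
#   chartoindex("hi!") == [7, 8, 40] + [69] * (1014 - 3)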

for dataset in datainfo.datasets.values():
    dataset.apply_field(chartoindex, field_name='chars', new_field_name='chars')

datainfo.datasets['train'].set_input('chars')
datainfo.datasets['test'].set_input('chars')
datainfo.datasets['train'].set_target('target')
datainfo.datasets['test'].set_target('target')

##2. Define/assemble the model. Anything works here: if it is a model that
# fastNLP already wraps, such as CNNText, just instantiate it directly. The
# class below is only a placeholder skeleton for building a model that follows
# fastNLP's input/output conventions.
class ModelFactory(nn.Module):
    """
    Assembles the embedding, encoder and decoder, and defines the forward pass.

    :param embedding: embedding model
    :param encoder: encoder model
    :param decoder: decoder model
    """
    def __init__(self, embedding, encoder, decoder, **kwargs):
        super(ModelFactory, self).__init__()
        self.embedding = embedding
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        return {C.OUTPUT: None}

## 2. Or reuse a fastNLP model directly.
#vocab = datainfo.vocabs['words']
vocab_label = datainfo.vocabs['target']
'''
# embed_char = CNNCharEmbedding(vocab)
# embed_word = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50', requires_grad=True)
# embedding = StackEmbedding([embed_char, embed_word])
# cnn_char_embed = CNNCharEmbedding(vocab)
# lstm_char_embed = LSTMCharEmbedding(vocab)
# embedding = StackEmbedding([cnn_char_embed, lstm_char_embed])
'''

# One-hot embedding: row i is the one-hot vector for character i; the extra
# last row stays all zeros and serves as the shared <unk>/<pad> index.
embedding_weight = torch.zeros(len(char_vocab) + 1, len(char_vocab))
for i in range(len(char_vocab)):
    embedding_weight[i][i] = 1
embedding = nn.Embedding(num_embeddings=len(char_vocab) + 1, embedding_dim=len(char_vocab),
                         padding_idx=len(char_vocab), _weight=embedding_weight)
# freeze the one-hot lookup table
for para in embedding.parameters():
    para.requires_grad = False

# CNNText is too simple for this task
#model = CNNText(init_embed=embedding, num_classes=ops.num_classes)
model = CharacterLevelCNN(ops, embedding)
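
# A minimal sanity check of the frozen lookup (shapes follow from the
# definitions above):
#   idx = torch.tensor([[0, 69]])   # 'a' and the shared <unk>/<pad> slot
#   out = embedding(idx)            # shape (1, 2, 69)
#   out[0, 0] is one-hot at position 0; out[0, 1] is all zeros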

## 3. Declare the loss, metric and optimizer.
loss = CrossEntropyLoss
metric = AccuracyMetric
#optimizer = SGD([param for param in model.parameters() if param.requires_grad], lr=ops.lr)
optimizer = SGD([param for param in model.parameters() if param.requires_grad],
                lr=ops.lr, momentum=0.9, weight_decay=ops.weight_decay)

callbacks = []
# callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5)))
# LambdaLR multiplies the initial lr by the returned factor, so the factor is
# 1.0 for the first 80% of the epochs and 0.1 afterwards.
callbacks.append(
    LRScheduler(LambdaLR(optimizer,
                         lambda epoch: 1.0 if epoch < ops.train_epoch * 0.8 else 0.1))
)

## 4. Define the train routine.
def train(model, datainfo, loss, metrics, optimizer, num_epochs=100):
    trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer,
                      loss=loss(target='target'), batch_size=ops.batch_size,
                      metrics=[metrics(target='target')], dev_data=datainfo.datasets['test'],
                      device=[0, 1, 2], check_code_level=-1,
                      callbacks=callbacks, n_epochs=num_epochs)
    print(trainer.train())


if __name__ == "__main__":
    #print(vocab_label)
    #print(datainfo.datasets["train"])
    train(model, datainfo, loss, metric, optimizer, num_epochs=ops.train_epoch)
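
Note: the script hard-codes internal data paths (`/remote-home/...`) and an internal download mirror in `FASTNLP_BASE_URL`; to run it outside that environment, point `Config.datapath` at local copies of the Yelp polarity CSVs and adjust or drop those environment variables. Given the `model.char_cnn` import, it is presumably launched from the folder containing `model/char_cnn.py`, e.g. `python train_char_cnn.py`.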