You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

train_dpcnn.py 5.0 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. # 首先需要加入以下的路径到环境变量,因为当前只对内部测试开放,所以需要手动申明一下路径
  2. import torch.cuda
  3. from fastNLP.core.utils import cache_results
  4. from torch.optim import SGD
  5. from torch.optim.lr_scheduler import CosineAnnealingLR
  6. from fastNLP.core.trainer import Trainer
  7. from fastNLP import CrossEntropyLoss, AccuracyMetric
  8. from fastNLP.embeddings import StaticEmbedding
  9. from reproduction.text_classification.model.dpcnn import DPCNN
  10. from fastNLP.io.data_loader import YelpLoader
  11. from fastNLP.core.sampler import BucketSampler
  12. from fastNLP.core import LRScheduler
  13. from fastNLP.core.const import Const as C
  14. from fastNLP.core.vocabulary import VocabularyOption
  15. from fastNLP.core.dist_trainer import DistTrainer
  16. from utils.util_init import set_rng_seeds
  17. from fastNLP import logger
  18. import os
  19. # os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
  20. # os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
  21. os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  22. # hyper
  23. logger.add_file('log', 'INFO')
  24. print(logger.handlers)
  25. class Config():
  26. seed = 12345
  27. model_dir_or_name = "dpcnn-yelp-f"
  28. embedding_grad = True
  29. train_epoch = 30
  30. batch_size = 100
  31. task = "yelp_f"
  32. #datadir = 'workdir/datasets/SST'
  33. # datadir = 'workdir/datasets/yelp_polarity'
  34. datadir = 'workdir/datasets/yelp_full'
  35. #datafile = {"train": "train.txt", "dev": "dev.txt", "test": "test.txt"}
  36. datafile = {"train": "train.csv", "test": "test.csv"}
  37. lr = 1e-3
  38. src_vocab_op = VocabularyOption(max_size=100000)
  39. embed_dropout = 0.3
  40. cls_dropout = 0.1
  41. weight_decay = 1e-5
  42. def __init__(self):
  43. self.datadir = os.path.join(os.environ['HOME'], self.datadir)
  44. self.datapath = {k: os.path.join(self.datadir, v)
  45. for k, v in self.datafile.items()}
  46. ops = Config()
  47. set_rng_seeds(ops.seed)
  48. # print('RNG SEED: {}'.format(ops.seed))
  49. logger.info('RNG SEED %d'%ops.seed)
  50. # 1.task相关信息:利用dataloader载入dataInfo
  51. #datainfo=SSTLoader(fine_grained=True).process(paths=ops.datapath, train_ds=['train'])
  52. @cache_results(ops.model_dir_or_name+'-data-cache')
  53. def load_data():
  54. datainfo = YelpLoader(fine_grained=True, lower=True).process(
  55. paths=ops.datapath, train_ds=['train'], src_vocab_op=ops.src_vocab_op)
  56. for ds in datainfo.datasets.values():
  57. ds.apply_field(len, C.INPUT, C.INPUT_LEN)
  58. ds.set_input(C.INPUT, C.INPUT_LEN)
  59. ds.set_target(C.TARGET)
  60. return datainfo
  61. datainfo = load_data()
  62. embedding = StaticEmbedding(
  63. datainfo.vocabs['words'], model_dir_or_name='en-glove-6b-100d', requires_grad=ops.embedding_grad,
  64. normalize=False)
  65. embedding.embedding.weight.data /= embedding.embedding.weight.data.std()
  66. print(embedding.embedding.weight.data.mean(), embedding.embedding.weight.data.std())
  67. # 2.或直接复用fastNLP的模型
  68. # embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)])
  69. datainfo.datasets['train'] = datainfo.datasets['train'][:1000]
  70. datainfo.datasets['test'] = datainfo.datasets['test'][:1000]
  71. # print(datainfo)
  72. # print(datainfo.datasets['train'][0])
  73. logger.info(datainfo)
  74. model = DPCNN(init_embed=embedding, num_cls=len(datainfo.vocabs[C.TARGET]),
  75. embed_dropout=ops.embed_dropout, cls_dropout=ops.cls_dropout)
  76. # print(model)
  77. # 3. 声明loss,metric,optimizer
  78. loss = CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET)
  79. metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET)
  80. optimizer = SGD([param for param in model.parameters() if param.requires_grad == True],
  81. lr=ops.lr, momentum=0.9, weight_decay=ops.weight_decay)
  82. callbacks = []
  83. callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5)))
  84. # callbacks.append(
  85. # LRScheduler(LambdaLR(optimizer, lambda epoch: ops.lr if epoch <
  86. # ops.train_epoch * 0.8 else ops.lr * 0.1))
  87. # )
  88. # callbacks.append(
  89. # FitlogCallback(data=datainfo.datasets, verbose=1)
  90. # )
  91. device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
  92. # print(device)
  93. logger.info(device)
  94. # 4.定义train方法
  95. trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
  96. sampler=BucketSampler(num_buckets=50, batch_size=ops.batch_size),
  97. metrics=[metric], use_tqdm=False, save_path='save',
  98. dev_data=datainfo.datasets['test'], device=device,
  99. check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks,
  100. n_epochs=ops.train_epoch, num_workers=4)
  101. # trainer = DistTrainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
  102. # metrics=[metric],
  103. # dev_data=datainfo.datasets['test'], device='cuda',
  104. # batch_size_per_gpu=ops.batch_size, callbacks_all=callbacks,
  105. # n_epochs=ops.train_epoch, num_workers=4)
  106. if __name__ == "__main__":
  107. print(trainer.train())