You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

train_HAN.py 3.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
# First, add the following paths/variables to the environment. This is currently
# open to internal testing only, so the paths must be declared manually.
  2. import os
  3. import sys
  4. sys.path.append('../../')
  5. os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
  6. os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
  7. os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  8. from fastNLP.core.const import Const as C
  9. from fastNLP.core import LRScheduler
  10. import torch.nn as nn
  11. from fastNLP.io.dataset_loader import SSTLoader
  12. from reproduction.text_classification.data.yelpLoader import yelpLoader
  13. from reproduction.text_classification.model.HAN import HANCLS
  14. from fastNLP.modules.encoder.embedding import StaticEmbedding, CNNCharEmbedding, StackEmbedding
  15. from fastNLP import CrossEntropyLoss, AccuracyMetric
  16. from fastNLP.core.trainer import Trainer
  17. from torch.optim import SGD
  18. import torch.cuda
  19. from torch.optim.lr_scheduler import CosineAnnealingLR
  20. ##hyper
  21. class Config():
  22. model_dir_or_name = "en-base-uncased"
  23. embedding_grad = False,
  24. train_epoch = 30
  25. batch_size = 100
  26. num_classes = 5
  27. task = "yelp"
  28. #datadir = '/remote-home/lyli/fastNLP/yelp_polarity/'
  29. datadir = '/remote-home/ygwang/yelp_polarity/'
  30. datafile = {"train": "train.csv", "test": "test.csv"}
  31. lr = 1e-3
  32. def __init__(self):
  33. self.datapath = {k: os.path.join(self.datadir, v)
  34. for k, v in self.datafile.items()}
# Single module-level configuration instance used throughout the script.
ops = Config()

## 1. Task-related info: load the dataset bundle (dataInfo) via the dataloader.
datainfo = yelpLoader(fine_grained=True).process(paths=ops.datapath, train_ds=['train'])
print(len(datainfo.datasets['train']))
print(len(datainfo.datasets['test']))

# post process
  41. def make_sents(words):
  42. sents = [words]
  43. return sents
  44. for dataset in datainfo.datasets.values():
  45. dataset.apply_field(make_sents, field_name='words', new_field_name='input_sents')
  46. datainfo = datainfo
  47. datainfo.datasets['train'].set_input('input_sents')
  48. datainfo.datasets['test'].set_input('input_sents')
  49. datainfo.datasets['train'].set_target('target')
  50. datainfo.datasets['test'].set_target('target')
## 2. Or directly reuse a model from fastNLP.
vocab = datainfo.vocabs['words']
# embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)])
embedding = StaticEmbedding(vocab)
print(len(vocab))
print(len(datainfo.vocabs['target']))
# model = DPCNN(init_embed=embedding, num_cls=ops.num_classes)
model = HANCLS(init_embed=embedding, num_cls=ops.num_classes)
## 3. Declare loss, metric, optimizer.
loss = CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET)
metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET)
# Only optimize parameters that require gradients (frozen embeddings excluded).
optimizer = SGD([param for param in model.parameters() if param.requires_grad == True],
                lr=ops.lr, momentum=0.9, weight_decay=0)
callbacks = []
# Cosine-annealed LR over a 5-step period.
callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5)))
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)
# Add a sequence-length field and set the standard input/target fields
# (Const names) on every split.
for ds in datainfo.datasets.values():
    ds.apply_field(len, C.INPUT, C.INPUT_LEN)
    ds.set_input(C.INPUT, C.INPUT_LEN)
    ds.set_target(C.TARGET)
  72. ## 4.定义train方法
  73. def train(model, datainfo, loss, metrics, optimizer, num_epochs=ops.train_epoch):
  74. trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
  75. metrics=[metrics], dev_data=datainfo.datasets['test'], device=device,
  76. check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks,
  77. n_epochs=num_epochs)
  78. print(trainer.train())
if __name__ == "__main__":
    # Script entry point: train with the module-level model/data/loss objects.
    train(model, datainfo, loss, metric, optimizer)