
train_HAN.py

# First, the following paths need to be added to the environment variables.
# The service is currently only open for internal testing, so the paths have
# to be declared manually.
import os
import sys

sys.path.append('../../')
os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

from fastNLP.core.const import Const as C
from fastNLP.core import LRScheduler
from fastNLP.io.data_loader import YelpLoader
from reproduction.text_classification.model.HAN import HANCLS
from fastNLP.embeddings import StaticEmbedding
from fastNLP import CrossEntropyLoss, AccuracyMetric
from fastNLP.core.trainer import Trainer
from torch.optim import SGD
import torch.cuda
from torch.optim.lr_scheduler import CosineAnnealingLR

## Hyperparameters
class Config():
    model_dir_or_name = "en-base-uncased"
    embedding_grad = False
    train_epoch = 30
    batch_size = 100
    num_classes = 5
    task = "yelp"
    # datadir = '/remote-home/lyli/fastNLP/yelp_polarity/'
    datadir = '/remote-home/ygwang/yelp_polarity/'
    datafile = {"train": "train.csv", "test": "test.csv"}
    lr = 1e-3

    def __init__(self):
        self.datapath = {k: os.path.join(self.datadir, v)
                         for k, v in self.datafile.items()}


ops = Config()

## 1. Task-related info: use the data loader to load the DataInfo
datainfo = YelpLoader(fine_grained=True).process(paths=ops.datapath, train_ds=['train'])
print(len(datainfo.datasets['train']))
print(len(datainfo.datasets['test']))
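
# process() returns a single DataInfo bundle: the tokenized DataSet objects
# under datainfo.datasets, plus the word and target vocabularies (built from
# the training split given train_ds=['train']) under datainfo.vocabs.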

# Post-processing
def make_sents(words):
    sents = [words]
    return sents


for dataset in datainfo.datasets.values():
    dataset.apply_field(make_sents, field_name='words', new_field_name='input_sents')
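
# HAN attends first over words within a sentence and then over sentences
# within a document, so it expects a list of sentences per example. The Yelp
# loader yields one flat token sequence per review, so make_sents wraps each
# review as a single-sentence "document" in the new 'input_sents' field.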
datainfo.datasets['train'].set_input('input_sents')
datainfo.datasets['test'].set_input('input_sents')
datainfo.datasets['train'].set_target('target')
datainfo.datasets['test'].set_target('target')

## 2. Or directly reuse one of fastNLP's models
vocab = datainfo.vocabs['words']
# embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)])
embedding = StaticEmbedding(vocab)

print(len(vocab))
print(len(datainfo.vocabs['target']))

# model = DPCNN(init_embed=embedding, num_cls=ops.num_classes)
model = HANCLS(init_embed=embedding, num_cls=ops.num_classes)
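
# HANCLS is the Hierarchical Attention Network with a classification head
# from this repo's reproduction models; the commented-out DPCNN line shows
# how another reproduction model could be swapped in with the same interface.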

## 3. Declare the loss, metric, and optimizer
loss = CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET)
metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET)
optimizer = SGD([param for param in model.parameters() if param.requires_grad],
                lr=ops.lr, momentum=0.9, weight_decay=0)

callbacks = []
callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5)))
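
# The LRScheduler callback steps the wrapped CosineAnnealingLR once per
# epoch, so the learning rate follows a cosine curve with period T_max=5.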

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)

for ds in datainfo.datasets.values():
    ds.apply_field(len, C.INPUT, C.INPUT_LEN)
    ds.set_input(C.INPUT, C.INPUT_LEN)
    ds.set_target(C.TARGET)
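
# Const supplies fastNLP's canonical field names (C.INPUT='words',
# C.INPUT_LEN='seq_len', C.TARGET='target', C.OUTPUT='pred'): this loop
# derives a 'seq_len' field from 'words' and flags the input and target
# fields that the Trainer feeds to the model.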

## 4. Define the train method
def train(model, datainfo, loss, metrics, optimizer, num_epochs=ops.train_epoch):
    trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
                      metrics=[metrics], dev_data=datainfo.datasets['test'], device=device,
                      check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks,
                      n_epochs=num_epochs)

    print(trainer.train())

if __name__ == "__main__":
    train(model, datainfo, loss, metric, optimizer)
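
# A minimal launch sketch, assuming the fastNLP repo layout implied by the
# imports above (the script living in reproduction/text_classification/) and
# that datadir points at an existing Yelp CSV directory:
#
#   cd reproduction/text_classification
#   python train_HAN.py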