You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

trainer.md 2.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. # Trainer使用教程
  2. Modelscope提供了众多预训练模型,你可以使用其中任意一个,利用公开数据集或者私有数据集针对特定任务进行模型训练,在本篇文章中将介绍如何使用Modelscope的`Trainer`模块进行Finetuning和评估。
  3. ## 环境准备
  4. 详细步骤可以参考 [快速开始](../quick_start.md)
  5. ### 准备数据集
  6. 在开始Finetuning前,需要准备一个数据集用以训练和评估,详细可以参考数据集使用教程。
  7. `临时写法`,我们通过数据集接口创建一个虚假的dataset
  8. ```python
  9. from datasets import Dataset
  10. dataset_dict = {
  11. 'sentence1': [
  12. 'This is test sentence1-1', 'This is test sentence2-1',
  13. 'This is test sentence3-1'
  14. ],
  15. 'sentence2': [
  16. 'This is test sentence1-2', 'This is test sentence2-2',
  17. 'This is test sentence3-2'
  18. ],
  19. 'label': [0, 1, 1]
  20. }
  21. train_dataset = MsDataset.from_hf_dataset(Dataset.from_dict(dataset_dict))
  22. eval_dataset = MsDataset.from_hf_dataset(Dataset.from_dict(dataset_dict))
  23. ```
  24. ### 训练
  25. ModelScope把所有训练相关的配置信息全部放到了模型仓库下的`configuration.json`中,因此我们只需要创建Trainer,加载配置文件,传入数据集即可完成训练。
  26. 首先,通过工厂方法创建Trainer, 需要传入模型仓库路径, 训练数据集对象,评估数据集对象,训练目录
  27. ```python
  28. kwargs = dict(
  29. model='damo/nlp_structbert_sentiment-classification_chinese-base',
  30. train_dataset=train_dataset,
  31. eval_dataset=eval_dataset,
  32. work_dir='work_dir')
  33. trainer = build_trainer(default_args=kwargs)
  34. ```
  35. 启动训练。
  36. ```python
  37. trainer.train()
  38. ```
  39. 如果需要调整训练参数,可以在模型仓库页面下载`configuration.json`文件到本地,修改参数后,指定配置文件路径,创建trainer
  40. ```python
  41. kwargs = dict(
  42. model='damo/nlp_structbert_sentiment-classification_chinese-base',
  43. train_dataset=train_dataset,
  44. eval_dataset=eval_dataset,
  45. cfg_file='你的配置文件路径'
  46. work_dir='work_dir')
  47. trainer = build_trainer(default_args=kwargs)
  48. trainer.train()
  49. ```
  50. ### 评估
  51. 训练过程中会定期使用验证集进行评估测试, Trainer模块也支持指定特定轮次保存的checkpoint路径,进行单次评估。
  52. ```python
  53. eval_results = trainer.evaluate('work_dir/epoch_10.pth')
  54. print(eval_results)
  55. ```