# Contents

- [Contents](#contents)
- [TernaryBERT Description](#ternarybert-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
        - [Train](#train)
        - [Eval](#eval)
    - [Options and Parameters](#options-and-parameters)
        - [Parameters](#parameters)
    - [Training Process](#training-process)
        - [Training](#training)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
            - [evaluation on STS-B dataset](#evaluation-on-sts-b-dataset)
            - [evaluation on QNLI dataset](#evaluation-on-qnli-dataset)
            - [evaluation on MNLI dataset](#evaluation-on-mnli-dataset)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Training Performance](#training-performance)
        - [Inference Performance](#inference-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [TernaryBERT Description](#contents)

[TernaryBERT](https://arxiv.org/abs/2009.12812) ternarizes the weights in a fine-tuned [BERT](https://arxiv.org/abs/1810.04805) or [TinyBERT](https://arxiv.org/abs/1909.10351) model and achieves competitive performance on natural language processing tasks. TernaryBERT outperforms other BERT quantization methods, and even achieves performance comparable to the full-precision model while being 14.9x smaller.

[Paper](https://arxiv.org/abs/2009.12812): Wei Zhang, Lu Hou, Yichun Yin, Lifeng Shang, Xiao Chen, Xin Jiang and Qun Liu. [TernaryBERT: Distillation-aware Ultra-low Bit BERT](https://arxiv.org/abs/2009.12812). arXiv preprint arXiv:2009.12812.
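The weight ternarization itself is easy to state in isolation. Below is a minimal NumPy sketch of a TWN-style (Ternary Weight Networks) approximation, one of the ternarization methods discussed in the paper; it is an illustration only, not the repository's `src/quant.py`, and the 0.7 threshold factor is the TWN heuristic.

```python
import numpy as np

def twn_ternarize(w):
    """TWN-style approximation: replace w with alpha * t, where
    t has entries in {-1, 0, +1} and alpha is a positive scale."""
    delta = 0.7 * np.abs(w).mean()          # TWN threshold heuristic
    mask = np.abs(w) > delta                # entries that survive ternarization
    alpha = np.abs(w[mask]).mean() if mask.any() else 0.0
    t = np.sign(w) * mask                   # ternary tensor in {-1, 0, +1}
    return alpha, t

w = np.random.randn(4, 4).astype(np.float32)
alpha, t = twn_ternarize(w)
print(alpha)
print(t)  # each weight is now stored as alpha * {-1, 0, +1}
```

Storing only `t` (2 bits per weight, matching the default `weight_bits` and `embedding_bits` below) plus one scale per tensor is what enables the roughly 14.9x size reduction.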
# [Model Architecture](#contents)

The backbone of TernaryBERT is the transformer: the network contains six encoder modules, each encoder contains a self-attention module, and each self-attention module contains an attention module.
# [Dataset](#contents)

- Download the GLUE dataset for task distillation. To convert the dataset files from JSON format to TFRecord format, please refer to run_classifier.py in the [BERT](https://github.com/google-research/bert) repository.
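Once converted, the TFRecord files can be loaded with MindSpore's dataset API. The snippet below is a minimal sketch; the file path and the column names (`input_ids`, `input_mask`, `segment_ids`, `label_ids`) are assumptions for illustration, and `src/dataset.py` defines the columns the scripts actually use.

```python
import mindspore.dataset as ds

# Hypothetical path to a converted GLUE task file; adjust to your data_dir.
data_files = ["/home/xxx/data_dir/sts-b/train.tf_record"]

# Assumed BERT-style feature columns; see src/dataset.py for the real names.
columns = ["input_ids", "input_mask", "segment_ids", "label_ids"]

dataset = ds.TFRecordDataset(data_files, columns_list=columns, shuffle=True)
dataset = dataset.batch(16, drop_remainder=True)

for batch in dataset.create_dict_iterator():
    print({name: value.shape for name, value in batch.items()})
    break
```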
# [Environment Requirements](#contents)

- Hardware (GPU)
    - Prepare a hardware environment with a GPU processor.
- Framework
    - [MindSpore](https://gitee.com/mindspore/mindspore)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
- Software:
    - sklearn
# [Quick Start](#contents)

After installing MindSpore via the official website, you can start training and evaluation as follows:

```bash
# run training example
# (please set `task_name`, `teacher_model_dir`, `student_model_dir` and `data_dir` in scripts/run_train.sh first)
sh scripts/run_train.sh

# run evaluation example
# (please set `task_name`, `model_dir` and `data_dir` in scripts/run_eval.sh first)
sh scripts/run_eval.sh
```
# [Script Description](#contents)

## [Script and Sample Code](#contents)

```text
.
└─bert
  ├─README.md
  ├─scripts
  │ ├─run_train.sh              # shell script for training phase
  │ └─run_eval.sh               # shell script for evaluation phase
  ├─src
  │ ├─__init__.py
  │ ├─assessment_method.py      # assessment method for evaluation
  │ ├─cell_wrapper.py           # cell for training
  │ ├─config.py                 # parameter configuration for training and evaluation phase
  │ ├─dataset.py                # data processing
  │ ├─quant.py                  # function for quantization
  │ ├─tinybert_model.py         # backbone code of network
  │ └─utils.py                  # util function
  ├─__init__.py
  ├─train.py                    # train net for task distillation
  └─eval.py                     # evaluate net after task distillation
```
## [Script Parameters](#contents)

### Train

```text
usage: train.py [--h]
                [--device_target {GPU,Ascend}]
                [--do_eval {true,false}]
                [--epoch_size EPOCH_SIZE]
                [--device_id DEVICE_ID]
                [--do_shuffle {true,false}]
                [--enable_data_sink {true,false}]
                [--save_ckpt_step SAVE_CKPT_STEP]
                [--eval_ckpt_step EVAL_CKPT_STEP]
                [--max_ckpt_num MAX_CKPT_NUM]
                [--data_sink_steps DATA_SINK_STEPS]
                [--teacher_model_dir TEACHER_MODEL_DIR]
                [--student_model_dir STUDENT_MODEL_DIR]
                [--data_dir DATA_DIR]
                [--output_dir OUTPUT_DIR]
                [--task_name {sts-b,qnli,mnli}]
                [--dataset_type DATASET_TYPE]
                [--seed SEED]
                [--train_batch_size TRAIN_BATCH_SIZE]
                [--eval_batch_size EVAL_BATCH_SIZE]

options:
    --device_target        Device where the code will be implemented: "GPU" | "Ascend", default is "GPU"
    --do_eval              Do eval task during training or not: "true" | "false", default is "true"
    --epoch_size           Epoch size for train phase: N, default is 3
    --device_id            Device id: N, default is 0
    --do_shuffle           Enable shuffle for train dataset: "true" | "false", default is "true"
    --enable_data_sink     Enable data sink: "true" | "false", default is "true"
    --save_ckpt_step       If do_eval is false, a checkpoint will be saved every save_ckpt_step steps: N, default is 50
    --eval_ckpt_step       If do_eval is true, the evaluation will be run every eval_ckpt_step steps: N, default is 50
    --max_ckpt_num         The number of saved checkpoints will not be larger than max_ckpt_num: N, default is 50
    --data_sink_steps      Sink steps for each epoch: N, default is 1
    --teacher_model_dir    The checkpoint directory of the teacher model: PATH, default is ""
    --student_model_dir    The checkpoint directory of the student model: PATH, default is ""
    --data_dir             Data directory: PATH, default is ""
    --output_dir           The output checkpoint directory: PATH, default is "./"
    --task_name            The name of the task to train: "sts-b" | "qnli" | "mnli", default is "sts-b"
    --dataset_type         The type of the dataset: "tfrecord" | "mindrecord", default is "tfrecord"
    --seed                 The random seed: N, default is 1
    --train_batch_size     Batch size for training: N, default is 16
    --eval_batch_size      Eval batch size in callback: N, default is 32
```
### Eval

```text
usage: eval.py [--h]
               [--device_target {GPU,Ascend}]
               [--device_id DEVICE_ID]
               [--model_dir MODEL_DIR]
               [--data_dir DATA_DIR]
               [--task_name {sts-b,qnli,mnli}]
               [--dataset_type DATASET_TYPE]
               [--batch_size BATCH_SIZE]

options:
    --device_target    Device where the code will be implemented: "GPU" | "Ascend", default is "GPU"
    --device_id        Device id: N, default is 0
    --model_dir        The checkpoint directory of the model: PATH, default is ""
    --data_dir         Data directory: PATH, default is ""
    --task_name        The name of the task to evaluate: "sts-b" | "qnli" | "mnli", default is "sts-b"
    --dataset_type     The type of the dataset: "tfrecord" | "mindrecord", default is "tfrecord"
    --batch_size       Batch size for evaluating: N, default is 32
```
## [Options and Parameters](#contents)

### Parameters

`config.py` contains the parameters of the GLUE tasks, training, the optimizer, evaluation, the teacher BERT model and the student BERT model.
```text
Parameters for glue task:
    num_labels                      the number of labels: N
    seq_length                      length of input sequence: N
    task_type                       the type of task: "classification" | "regression"
    metrics                         the eval metric for the task: Accuracy | F1 | Pearsonr | Matthews

Parameters for train:
    batch_size                      batch size of input dataset: N, default is 16
    loss_scale_value                initial value of loss scale: N, default is 2^16
    scale_factor                    factor used to update loss scale: N, default is 2
    scale_window                    steps for one update of loss scale: N, default is 50

Parameters for optimizer:
    learning_rate                   value of learning rate: Q, default is 5e-5
    end_learning_rate               value of end learning rate: Q, must be positive, default is 1e-14
    power                           power: Q, default is 1.0
    weight_decay                    weight decay: Q, default is 1e-4
    eps                             term added to the denominator to improve numerical stability: Q, default is 1e-6
    warmup_ratio                    the ratio of warmup steps to total steps: Q, default is 0.1

Parameters for eval:
    batch_size                      batch size of input dataset: N, default is 32

Parameters for teacher bert network:
    seq_length                      length of input sequence: N, default is 128
    vocab_size                      size of each embedding vector: N, must be consistent with the dataset you use, default is 30522
    hidden_size                     size of bert encoder layers: N
    num_hidden_layers               number of hidden layers: N
    num_attention_heads             number of attention heads: N, default is 12
    intermediate_size               size of intermediate layer: N
    hidden_act                      activation function used: ACTIVATION, default is "gelu"
    hidden_dropout_prob             dropout probability for BertOutput: Q
    attention_probs_dropout_prob    dropout probability for BertAttention: Q
    max_position_embeddings         maximum length of sequences: N, default is 512
    save_ckpt_step                  step interval for saving checkpoints: N, default is 100
    max_ckpt_num                    maximum number of saved checkpoints: N, default is 1
    type_vocab_size                 size of token type vocab: N, default is 2
    initializer_range               initialization value of TruncatedNormal: Q, default is 0.02
    use_relative_positions          use relative positions or not: True | False, default is False
    dtype                           data type of input: mstype.float16 | mstype.float32, default is mstype.float32
    compute_type                    compute type in BertTransformer: mstype.float16 | mstype.float32, default is mstype.float32

Parameters for student bert network:
    seq_length                      length of input sequence: N, default is 128
    vocab_size                      size of each embedding vector: N, must be consistent with the dataset you use, default is 30522
    hidden_size                     size of bert encoder layers: N
    num_hidden_layers               number of hidden layers: N
    num_attention_heads             number of attention heads: N, default is 12
    intermediate_size               size of intermediate layer: N
    hidden_act                      activation function used: ACTIVATION, default is "gelu"
    hidden_dropout_prob             dropout probability for BertOutput: Q
    attention_probs_dropout_prob    dropout probability for BertAttention: Q
    max_position_embeddings         maximum length of sequences: N, default is 512
    save_ckpt_step                  step interval for saving checkpoints: N, default is 100
    max_ckpt_num                    maximum number of saved checkpoints: N, default is 1
    type_vocab_size                 size of token type vocab: N, default is 2
    initializer_range               initialization value of TruncatedNormal: Q, default is 0.02
    use_relative_positions          use relative positions or not: True | False, default is False
    dtype                           data type of input: mstype.float16 | mstype.float32, default is mstype.float32
    compute_type                    compute type in BertTransformer: mstype.float16 | mstype.float32, default is mstype.float32
    do_quant                        do activation quantization or not: True | False, default is True
    embedding_bits                  the quantization bits of the embedding: N, default is 2
    weight_bits                     the quantization bits of the weights: N, default is 2
    cls_dropout_prob                dropout probability for BertModelCLS: Q
    activation_init                 initialization value of activation quantization: Q, default is 2.5
    is_lgt_fit                      use label ground truth loss or not: True | False, default is False
```
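The optimizer parameters above describe a schedule with a linear warmup over the first `warmup_ratio` of steps, followed by polynomial decay from `learning_rate` down to `end_learning_rate` with exponent `power`. As a reading aid, here is a minimal plain-Python sketch of that schedule; the training script builds the real one internally, so treat this as an assumption about how the parameters combine.

```python
def lr_at_step(step, total_steps, lr=5e-5, end_lr=1e-14, power=1.0, warmup_ratio=0.1):
    """Linear warmup followed by polynomial decay, using the defaults above."""
    warmup_steps = int(total_steps * warmup_ratio)
    if step < warmup_steps:
        return lr * step / warmup_steps                        # linear warmup
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return (lr - end_lr) * (1.0 - progress) ** power + end_lr  # polynomial decay

print(lr_at_step(50, 1000))    # mid-warmup: 2.5e-05
print(lr_at_step(100, 1000))   # warmup done: 5e-05
print(lr_at_step(550, 1000))   # halfway through decay: 2.5e-05 (power=1.0 is linear)
print(lr_at_step(1000, 1000))  # end of training: ~1e-14
```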
## [Training Process](#contents)

### Training

Before running the command below, please check that `teacher_model_dir`, `student_model_dir` and `data_dir` have been set. Please set each path as an absolute path, e.g. "/home/xxx/model_dir/".

```bash
# run training example by python command
python train.py --task_name='sts-b' --teacher_model_dir='/home/xxx/model_dir/' --student_model_dir='/home/xxx/model_dir/' --data_dir='/home/xxx/data_dir/'

# or run training example by shell script
sh scripts/run_train.sh
```
The shell command above will run in the background; you can view the results in the file log.txt. The python command will run in the console; you can view the results on the interface. After training, you will get some checkpoint files under the script folder by default. The eval metric values will be reported as follows:
```text
step: 50, Pearsonr 72.50008506516072, best_Pearsonr 72.50008506516072
step 100, Pearsonr 81.3580301181608, best_Pearsonr 81.3580301181608
step 150, Pearsonr 83.60461724688754, best_Pearsonr 83.60461724688754
step 200, Pearsonr 82.23210161651377, best_Pearsonr 83.60461724688754
...
step 1050, Pearsonr 87.5606067964618332, best_Pearsonr 87.58388835685436
```
## [Evaluation Process](#contents)

### Evaluation

After training, you can continue with evaluation on the corresponding dataset as follows.
#### evaluation on STS-B dataset

```bash
# run evaluation example by python command
python eval.py --task_name='sts-b' --model_dir='/home/xxx/model_dir/' --data_dir='/home/xxx/data_dir/'

# or run evaluation example by shell script
sh scripts/run_eval.sh
```
The shell command above will run in the background; you can view the results in the file log.txt. The python command will run in the console; you can view the results on the interface. The metric value on the test dataset will be as follows:
```text
eval step: 0, Pearsonr: 96.91109003302263
eval step: 1, Pearsonr: 95.6800637493701
eval step: 2, Pearsonr: 94.23823082886167
...
The best Pearsonr: 87.58388835685437
```
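The Pearsonr values above are the Pearson correlation between predicted and gold similarity scores, scaled by 100. A minimal NumPy sketch of that metric (the repository's implementation lives in `src/assessment_method.py` and may differ in detail; the scores below are made up for illustration):

```python
import numpy as np

def pearsonr_percent(preds, labels):
    """Pearson correlation coefficient * 100, as reported in the logs."""
    return np.corrcoef(preds, labels)[0, 1] * 100

# Hypothetical predictions and gold STS-B similarity scores.
preds = np.array([4.8, 1.2, 3.5, 0.4])
labels = np.array([5.0, 1.0, 3.0, 0.0])
print(pearsonr_percent(preds, labels))  # close to 100 for well-aligned predictions
```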
#### evaluation on QNLI dataset

```bash
# run evaluation example by python command
python eval.py --task_name='qnli' --model_dir='/home/xxx/model_dir/' --data_dir='/home/xxx/data_dir/'

# or run evaluation example by shell script
sh scripts/run_eval.sh
```
The shell command above will run in the background; you can view the results in the file log.txt. The python command will run in the console; you can view the results on the interface. The metric value on the test dataset will be as follows:
```text
eval step: 0, Accuracy: 96.875
eval step: 1, Accuracy: 89.0625
eval step: 2, Accuracy: 89.58333333333334
...
The best Accuracy: 90.426505583013
```
#### evaluation on MNLI dataset

```bash
# run evaluation example by python command
python eval.py --task_name='mnli' --model_dir='/home/xxx/model_dir/' --data_dir='/home/xxx/data_dir/'

# or run evaluation example by shell script
sh scripts/run_eval.sh
```
The shell command above will run in the background; you can view the results in the file log.txt. The python command will run in the console; you can view the results on the interface. The metric value on the test dataset will be as follows:
```text
eval step: 0, Accuracy: 90.625
eval step: 1, Accuracy: 81.25
eval step: 2, Accuracy: 79.16666666666666
...
The best Accuracy: 83.70860927152319
```
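For QNLI and MNLI the reported metric is plain classification accuracy as a percentage. A minimal sketch using the listed sklearn dependency; the logits and labels here are made up for illustration:

```python
import numpy as np
from sklearn.metrics import accuracy_score

# Hypothetical logits for a 3-example batch of a 2-class task (e.g. QNLI).
logits = np.array([[0.2, 1.3], [2.1, 0.4], [0.3, 0.9]])
labels = np.array([1, 0, 0])

preds = logits.argmax(axis=1)                # highest-scoring class per example
print(accuracy_score(labels, preds) * 100)   # 66.66...: 2 of 3 correct
```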
# [Model Description](#contents)

## [Performance](#contents)

### Training Performance
| Parameters        | GPU                                                    |
| ----------------- | :----------------------------------------------------- |
| Model Version     | TernaryBERT                                            |
| Resource          | NV SMX2 V100-32G                                       |
| Uploaded Date     | 08/20/2020                                             |
| MindSpore Version | 1.1.0                                                  |
| Dataset           | STS-B, QNLI, MNLI                                      |
| batch_size        | 16, 16, 16                                             |
| Metric value      | 87.58388835685437, 90.426505583013, 83.70860927152319  |
| Speed             |                                                        |
| Total time        |                                                        |
### Inference Performance

| Parameters        | GPU                                                    |
| ----------------- | :----------------------------------------------------- |
| Model Version     | TernaryBERT                                            |
| Resource          | NV SMX2 V100-32G                                       |
| Uploaded Date     | 08/20/2020                                             |
| MindSpore Version | 1.1.0                                                  |
| Dataset           | STS-B, QNLI, MNLI                                      |
| batch_size        | 32, 32, 32                                             |
| Metric value      | 87.58388835685437, 90.426505583013, 83.70860927152319  |
| Speed             |                                                        |
| Total time        |                                                        |
# [Description of Random Situation](#contents)

In train.py, we set do_shuffle to shuffle the dataset.

In config.py, we set hidden_dropout_prob, attention_probs_dropout_prob and cls_dropout_prob to drop out some network nodes.
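If you need reproducible runs despite the shuffle and dropout above, fix the random seed that train.py exposes (`--seed`, default 1). Below is a minimal sketch of seeding a MindSpore script; whether this pins every source of randomness depends on the ops involved:

```python
import numpy as np
from mindspore.common import set_seed

seed = 1
set_seed(seed)        # seeds MindSpore's global generators (initializers, ops)
np.random.seed(seed)  # seeds NumPy-side randomness such as data preprocessing
```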
# [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).