![](https://www.mindspore.cn/static/img/logo_black.6a5c850d.png)

<!-- TOC -->

- [GNMT v2 For MindSpore](#gnmt-v2-for-mindspore)
- [Model Structure](#model-structure)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
    - [Platform](#platform)
    - [Software](#software)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Dataset Preparation](#dataset-preparation)
    - [Configuration File](#configuration-file)
    - [Training Process](#training-process)
    - [Inference Process](#inference-process)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Training Performance](#training-performance)
        - [Inference Performance](#inference-performance)
- [Random Situation Description](#random-situation-description)
- [Others](#others)
- [ModelZoo HomePage](#modelzoo-homepage)

<!-- /TOC -->
# [GNMT v2 For MindSpore](#contents)

The GNMT v2 model is similar to the model described in [Google's Neural Machine Translation System: Bridging the Gap between Human and Machine Translation](https://arxiv.org/abs/1609.08144), and is mainly used for corpus translation.

# [Model Structure](#contents)

The GNMT v2 model consists of an encoder, a decoder, and an attention mechanism, where the encoder and the decoder share a word embedding table.

- Encoder: consists of four long short-term memory (LSTM) layers. The first LSTM layer is bidirectional, while the other three layers are unidirectional.
- Decoder: consists of four unidirectional LSTM layers and a fully connected classifier. The output embedding dimension of the LSTM is 1024.
- Attention mechanism: uses the normalized Bahdanau attention mechanism. The output of the first decoder layer is used as the query of the attention mechanism, and the resulting context vector is concatenated with the input of the subsequent decoder LSTM layers. A minimal sketch of this attention computation is given below.
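
For illustration only, here is a minimal NumPy sketch of one normalized Bahdanau attention step, with hypothetical shapes and parameter names; it is not the repository's `src/gnmt_model/attention.py` implementation:

```python
import numpy as np

def normalized_bahdanau_attention(query, keys, W_q, W_k, v, g, b):
    """One attention step: query is a decoder state of shape (hidden,),
    keys are encoder outputs of shape (src_len, hidden)."""
    v_normed = g * v / np.linalg.norm(v)              # weight-normalized score vector
    features = np.tanh(query @ W_q + keys @ W_k + b)  # (src_len, hidden)
    scores = features @ v_normed                      # (src_len,)
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()                          # softmax over source positions
    return weights @ keys                             # context vector, (hidden,)

# Toy usage with random parameters.
rng = np.random.default_rng(0)
hidden, src_len = 8, 5
context = normalized_bahdanau_attention(
    rng.normal(size=hidden), rng.normal(size=(src_len, hidden)),
    rng.normal(size=(hidden, hidden)), rng.normal(size=(hidden, hidden)),
    rng.normal(size=hidden), 1.0, np.zeros(hidden))
print(context.shape)  # (8,)
```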
# [Dataset](#contents)

Note that you can run the scripts with the dataset mentioned in the original paper or one that is widely used in the relevant domain/network architecture. The following sections introduce how to run the scripts using the datasets below.

- WMT English-German for training.
- WMT newstest2014 for evaluation.
# [Environment Requirements](#contents)

## Platform

- Hardware (Ascend)
    - Prepare a hardware environment with an Ascend processor.
- Framework
    - Install [MindSpore](https://www.mindspore.cn/install/en).
- For more information, please check the resources below:
    - [MindSpore tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
## Software

```txt
numpy
sacrebleu==1.4.14
sacremoses==0.0.35
subword_nmt==0.3.7
```
# [Quick Start](#contents)

The process of performing a text translation task with GNMT v2 is as follows:

1. Download the WMT16 corpus and extract the dataset. For details, see the [Dataset](#dataset) section above.
2. Prepare the dataset and the configuration.
3. Train the model.
4. Run inference.

After dataset preparation, you can start training and evaluation as follows:

```bash
# run training example
cd ./scripts
sh run_standalone_train_ascend.sh PRE_TRAIN_DATASET

# run distributed training example
cd ./scripts
sh run_distributed_train_ascend.sh RANK_TABLE_ADDR PRE_TRAIN_DATASET

# run evaluation example
cd ./scripts
sh run_standalone_eval_ascend.sh TEST_DATASET EXISTED_CKPT_PATH \
  VOCAB_ADDR BPE_CODE_ADDR TEST_TARGET
```
# [Script Description](#contents)

The GNMT network scripts and code structure are as follows:

```text
├── gnmt
    ├── README.md                            // Introduction of the GNMT v2 model.
    ├── config
    │   ├── __init__.py                      // User interface.
    │   ├── config.py                        // Configuration instance definition.
    │   ├── config.json                      // Configuration file for pre-training or fine-tuning.
    │   ├── config_test.json                 // Configuration file for test.
    ├── src
    │   ├── __init__.py                      // User interface.
    │   ├── dataset
    │   │   ├── __init__.py                  // User interface.
    │   │   ├── base.py                      // Base class of the data loader.
    │   │   ├── bi_data_loader.py            // Bilingual data loader.
    │   │   ├── load_dataset.py              // Dataset loader that feeds data into the model.
    │   │   ├── schema.py                    // Defines the schema of the MindRecord files.
    │   │   ├── tokenizer.py                 // Tokenizer class.
    │   ├── gnmt_model
    │   │   ├── __init__.py                  // User interface.
    │   │   ├── attention.py                 // Bahdanau attention mechanism.
    │   │   ├── beam_search.py               // Beam search decoder for inference.
    │   │   ├── bleu_calculate.py            // Calculates the BLEU score.
    │   │   ├── components.py                // Components.
    │   │   ├── create_attention.py          // Recurrent attention.
    │   │   ├── create_attn_padding.py       // Creates attention paddings from input paddings.
    │   │   ├── decoder.py                   // GNMT decoder component.
    │   │   ├── decoder_beam_infer.py        // GNMT decoder component for beam search.
    │   │   ├── dynamic_rnn.py               // DynamicRNN.
    │   │   ├── embedding.py                 // Embedding component.
    │   │   ├── encoder.py                   // GNMT encoder component.
    │   │   ├── gnmt.py                      // GNMT model architecture.
    │   │   ├── gnmt_for_infer.py            // Uses GNMT for inference.
    │   │   ├── gnmt_for_train.py            // Uses GNMT for training.
    │   │   ├── grad_clip.py                 // Gradient clipping.
    │   ├── utils
    │   │   ├── __init__.py                  // User interface.
    │   │   ├── initializer.py               // Parameter initializer.
    │   │   ├── load_weights.py              // Loads weights from a checkpoint or NPZ file.
    │   │   ├── loss_moniter.py              // Callback for monitoring the loss during training.
    │   │   ├── lr_scheduler.py              // Learning rate scheduler.
    │   │   ├── optimizer.py                 // Optimizer.
    ├── scripts
    │   ├── run_distributed_train_ascend.sh  // Shell script for distributed training on Ascend.
    │   ├── run_standalone_eval_ascend.sh    // Shell script for standalone evaluation on Ascend.
    │   ├── run_standalone_train_ascend.sh   // Shell script for standalone training on Ascend.
    ├── create_dataset.py                    // Dataset preparation.
    ├── eval.py                              // Inference API entry.
    ├── export.py                            // Exports a checkpoint file into AIR models.
    ├── mindspore_hub_conf.py                // Hub config.
    ├── requirements.txt                     // Third-party package requirements.
    ├── train.py                             // Training API entry.
```
## Dataset Preparation

You may use this [shell script](https://github.com/NVIDIA/DeepLearningExamples/blob/master/TensorFlow/Translation/GNMT/scripts/wmt16_en_de.sh) to download and preprocess the WMT English-German dataset. Assuming you get the following files:

- train.tok.clean.bpe.32000.en
- train.tok.clean.bpe.32000.de
- vocab.bpe.32000
- bpe.32000
- newstest2014.en
- newstest2014.de

Convert the original data to MindRecord format for training and evaluation:

```bash
python create_dataset.py --src_folder /home/workspace/wmt16_de_en --output_folder /home/workspace/dataset_menu
```
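
As a side note, the `sacremoses` and `subword_nmt` packages pinned in the [Software](#software) section are the standard tools behind this preprocessing. The following is a minimal sketch, assuming the `bpe.32000` codes file from the list above, of how a raw sentence could be tokenized and split into subword units; the repository implements its own tokenizer in `src/dataset/tokenizer.py`:

```python
import codecs
from sacremoses import MosesTokenizer
from subword_nmt.apply_bpe import BPE

# Assumes the bpe.32000 codes file produced by the download script above.
tokenizer = MosesTokenizer(lang="en")
with codecs.open("bpe.32000", encoding="utf-8") as codes:
    bpe = BPE(codes)

sentence = "Machine translation bridges the gap."
tokenized = tokenizer.tokenize(sentence, return_str=True)  # Moses tokenization
subwords = bpe.process_line(tokenized)                     # apply BPE merges
print(subwords)  # subword units, e.g. "Mach@@ ine translation ..."
```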
## Configuration File

The JSON files in the `config/` directory are template configuration files. Almost all required options and parameters can be easily assigned there, including the training platform, model configuration, and optimizer parameters.

- config for GNMT v2

```python
'random_seed': 50                    # global random seed
'epochs': 6                          # total training epochs
'batch_size': 128                    # training batch size
'dataset_sink_mode': true            # whether to use dataset sink mode
'seq_length': 51                     # max length of source sentences
'vocab_size': 32320                  # vocabulary size
'hidden_size': 1024                  # last dimension of the DynamicRNN output
'initializer_range': 0.1             # initializer range
'max_decode_length': 50              # max decoding length
'lr': 2e-3                           # initial learning rate
'lr_scheduler': 'WarmupMultiStepLR'  # learning rate scheduler
'existed_ckpt': ""                   # absolute path of an existing checkpoint file (used for fine-tuning)
```

For more configuration details, please refer to the `config/config.py` script.
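
The `WarmupMultiStepLR` scheduler named above combines a linear warmup with stepwise decay. Below is a minimal sketch of such a schedule with hypothetical warmup and milestone values; the actual scheduler lives in `src/utils/lr_scheduler.py`:

```python
def warmup_multistep_lr(step, base_lr=2e-3, warmup_steps=200,
                        milestones=(4000, 8000), gamma=0.5):
    """Linear warmup to base_lr, then decay by gamma at each milestone.
    warmup_steps, milestones, and gamma are illustrative values only."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    decay_count = sum(step >= m for m in milestones)
    return base_lr * gamma ** decay_count

print(warmup_multistep_lr(0))     # small lr at the start of warmup
print(warmup_multistep_lr(200))   # base_lr = 2e-3 once warmup ends
print(warmup_multistep_lr(9000))  # 2e-3 * 0.5**2 after both milestones
```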
## Training Process

For a pre-trained model, configure the following options in the `config/config.json` file:

- Select an optimizer (`momentum`, `adam`, or `lamb` are available).
- Specify `ckpt_prefix` and `ckpt_path` in `checkpoint_path` to save the model file.
- Set other parameters, including dataset configuration and network configuration.
- If a pre-trained model already exists, assign `existed_ckpt` the path of the existing model for fine-tuning.

To start training on a single device, run the shell script `scripts/run_standalone_train_ascend.sh`:

```bash
cd ./scripts
sh run_standalone_train_ascend.sh PRE_TRAIN_DATASET
```

In this script, `PRE_TRAIN_DATASET` is the path of the training dataset.

To train the GNMT v2 model on multiple devices, run `scripts/run_distributed_train_ascend.sh`:

```bash
cd ./scripts
sh run_distributed_train_ascend.sh RANK_TABLE_ADDR PRE_TRAIN_DATASET
```

Note: `RANK_TABLE_ADDR` is the path of the HCCL rank table JSON file used for distributed training. Currently, inconsecutive device IDs are not supported in `scripts/run_distributed_train_ascend.sh`; the device IDs must start from 0 in the `RANK_TABLE_ADDR` file.
## Inference Process

To run inference with a trained model on a hardware platform such as Ascend 910, set the options in `config/config_test.json` and run the shell script `scripts/run_standalone_eval_ascend.sh`, which processes the output token IDs to produce the BLEU scores:

```bash
cd ./scripts
sh run_standalone_eval_ascend.sh TEST_DATASET EXISTED_CKPT_PATH \
  VOCAB_ADDR BPE_CODE_ADDR TEST_TARGET
```

Here, `TEST_DATASET` is the path of the inference dataset, and `EXISTED_CKPT_PATH` is the path of the model file generated during training. `VOCAB_ADDR` is the vocabulary path, `BPE_CODE_ADDR` is the path of the BPE codes file, and `TEST_TARGET` is the path of the reference answers.
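
For reference, the `sacrebleu` package pinned in the [Software](#software) section can score detokenized outputs directly. A minimal sketch with made-up sentences (the repository's own scoring is in `src/gnmt_model/bleu_calculate.py`):

```python
import sacrebleu

# Hypothetical detokenized model outputs and their references.
hypotheses = ["The cat sits on the mat ."]
references = [["The cat sat on the mat ."]]  # one reference stream

bleu = sacrebleu.corpus_bleu(hypotheses, references)
print(f"BLEU = {bleu.score:.2f}")
```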
# [Model Description](#contents)

## Performance

### Training Performance

| Parameters               | Ascend                                                          |
| ------------------------ | --------------------------------------------------------------- |
| Resource                 | Ascend 910; OS Euler2.8                                          |
| Uploaded Date            | 11/06/2020 (month/day/year)                                      |
| MindSpore Version        | 1.0.0                                                            |
| Dataset                  | WMT English-German for training                                  |
| Training Parameters      | epoch=6, batch_size=128                                          |
| Optimizer                | Adam                                                             |
| Loss Function            | Softmax Cross Entropy                                            |
| Outputs                  | probability                                                      |
| Speed                    | 344 ms/step (8 pcs)                                              |
| Total Time               | 7800 s (8 pcs)                                                   |
| Loss                     | 63.35                                                            |
| Params (M)               | 613                                                              |
| Checkpoint for inference | 1.8 GB (.ckpt file)                                              |
| Scripts                  | [gnmt_v2](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/gnmt_v2) |

### Inference Performance

| Parameters          | Ascend                      |
| ------------------- | --------------------------- |
| Resource            | Ascend 910; OS Euler2.8     |
| Uploaded Date       | 11/06/2020 (month/day/year) |
| MindSpore Version   | 1.0.0                       |
| Dataset             | WMT newstest2014            |
| batch_size          | 128                         |
| Total Time          | 1560 s                      |
| Outputs             | probability                 |
| Accuracy            | BLEU score = 24.05          |
| Model for inference | 1.8 GB (.ckpt file)         |
# [Random Situation Description](#contents)

There are three sources of randomness:

- Shuffling of the dataset.
- Initialization of some model weights.
- Dropout operations.

Some seeds have already been set in `train.py` to avoid the randomness of dataset shuffling and weight initialization. If you want to disable dropout, set the corresponding `dropout_prob` parameter to 0 in `config/config.json`.
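
A minimal sketch of how such seeds are typically fixed, using the `random_seed` value from `config/config.json` and assuming `set_seed` is available from `mindspore.common` as in MindSpore 1.0:

```python
import random
import numpy as np
from mindspore.common import set_seed  # assumed available in MindSpore 1.0

SEED = 50  # 'random_seed' from config/config.json
random.seed(SEED)
np.random.seed(SEED)
set_seed(SEED)  # fixes MindSpore's global seed, e.g. for weight initialization
```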
# [Others](#contents)

This model has been validated in the Ascend environment and has not been validated on CPU or GPU.

# [ModelZoo HomePage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).