

# TextRCNN

## Contents

- [TextRCNN Description](#textrcnn-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [ModelZoo Homepage](#modelzoo-homepage)
## [TextRCNN Description](#contents)

TextRCNN is a text classification model proposed by researchers at the Chinese Academy of Sciences in 2015.
TextRCNN combines an RNN and a CNN: it first uses a bidirectional RNN to capture the contextual semantic and syntactic information of the input text,
then applies max pooling to automatically select the most important features,
and finally attaches a fully connected layer for classification.
The TextCNN network structure contains a convolutional layer and a pooling layer. In RCNN, the feature-extraction role of the convolutional layer is taken over by an RNN, so the overall structure consists of an RNN followed by a pooling layer, hence the name RCNN.

[Paper](https://www.aaai.org/ocs/index.php/AAAI/AAAI15/paper/download/9745/9552): Siwei Lai, Liheng Xu, Kang Liu, Jun Zhao: Recurrent Convolutional Neural Networks for Text Classification. AAAI 2015: 2267-2273
## [Model Architecture](#contents)

TextRCNN is composed of three parts: a recurrent structure layer, a max-pooling layer, and a fully connected layer. In the paper, the word-vector size is $|e|=50$, the context-vector size is $|c|=50$, the hidden layer size is $H=100$, the learning rate is $\alpha=0.01$, and the vocabulary size is $|V|$; the input is a sequence of words and the output is a vector of class probabilities. A minimal forward-pass sketch is given below.
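To make the three parts concrete, here is a small NumPy sketch of the forward pass with the paper's dimensions. It is illustrative only: the repository's actual MindSpore implementation lives in `src/textrcnn.py`, and all weights below are random placeholders.

```python
import numpy as np

e, c, H, K = 50, 50, 100, 2                # |e|, |c|, hidden size, num classes
V, T = 10000, 8                            # vocabulary size |V|, sequence length
rng = np.random.default_rng(0)

E = rng.normal(size=(V, e))                            # word embedding table
W_l, W_r = rng.normal(size=(c, c)), rng.normal(size=(c, c))
W_sl, W_sr = rng.normal(size=(c, e)), rng.normal(size=(c, e))
W2, b2 = rng.normal(size=(H, c + e + c)), np.zeros(H)  # latent semantic projection
W4, b4 = rng.normal(size=(K, H)), np.zeros(K)          # output layer

words = rng.integers(V, size=T)            # a toy input sequence of word ids
emb = E[words]                             # (T, e)

# left and right context vectors of the recurrent structure
# (the boundary contexts are simplified to zeros here)
cl, cr = np.zeros((T, c)), np.zeros((T, c))
for i in range(1, T):
    cl[i] = np.tanh(W_l @ cl[i - 1] + W_sl @ emb[i - 1])
for i in range(T - 2, -1, -1):
    cr[i] = np.tanh(W_r @ cr[i + 1] + W_sr @ emb[i + 1])

x = np.concatenate([cl, emb, cr], axis=1)      # word representations, (T, c+e+c)
y2 = np.tanh(x @ W2.T + b2)                    # latent semantic vectors, (T, H)
y3 = y2.max(axis=0)                            # element-wise max pooling over time, (H,)
logits = W4 @ y3 + b4                          # fully connected output layer
probs = np.exp(logits) / np.exp(logits).sum()  # softmax over the classes
```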
## [Dataset](#contents)

Dataset used: [Sentence polarity dataset v1.0](http://www.cs.cornell.edu/people/pabo/movie-review-data/)

- Dataset size: 10662 movie reviews in 2 classes (5331 positive, 5331 negative); 9596 reviews in the train set and 1066 in the test set. A quick line-count check is sketched below.
- Data format: text files. The processed data is in `./data/`.
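The raw corpus ships as one file per class (`rt-polarity.pos` and `rt-polarity.neg`). A minimal sanity check of the counts, assuming the `data_src/rt-polaritydata` layout shown in the script description below:

```python
# count the sentences per class in the raw Movie Review data;
# the files are not UTF-8, so read them in binary mode
for name in ("rt-polarity.pos", "rt-polarity.neg"):
    with open(f"data_src/rt-polaritydata/{name}", "rb") as f:
        print(name, sum(1 for _ in f))  # expected: 5331 each, 10662 in total
```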
## [Environment Requirements](#contents)

- Hardware: Ascend
- Framework: [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below: [MindSpore tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html), [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html).
## [Quick Start](#contents)

- Preparing the environment

```shell
# download the pretrained GoogleNews-vectors-negative300.bin into /tmp;
# you can get it from https://code.google.com/archive/p/word2vec/
# or from https://pan.baidu.com/s/1NC2ekA_bJ0uSL7BF3SjhIg, code: yk9a
mv /tmp/GoogleNews-vectors-negative300.bin ./word2vec/
```
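Optionally, you can verify that the pretrained binary loads correctly. A minimal sketch, assuming `gensim` is installed (it is not required by the repository itself):

```python
from gensim.models import KeyedVectors

# load the binary word2vec file moved into ./word2vec/ above
w2v = KeyedVectors.load_word2vec_format(
    "./word2vec/GoogleNews-vectors-negative300.bin", binary=True)
print(w2v.vector_size)  # expected: 300, matching 'embed_size' in config.py
```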
- Preparing the data

```shell
# split the dataset with the following commands
mkdir -p data/test && mkdir -p data/train
python data_helpers.py --task dataset_split --data_dir dataset_dir
```
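The split sets aside one tenth of the 10662 samples (1066) for testing and keeps the remaining 9596 for training. A hypothetical re-implementation of such a shuffle split, for intuition only (the authoritative logic is in `data_helpers.py`):

```python
import random

# read both class files in binary mode to avoid encoding issues
with open("data_src/rt-polaritydata/rt-polarity.pos", "rb") as f:
    pos = [(line, 1) for line in f]
with open("data_src/rt-polaritydata/rt-polarity.neg", "rb") as f:
    neg = [(line, 0) for line in f]

samples = pos + neg
random.seed(0)
random.shuffle(samples)
n_test = len(samples) // 10                       # 1066 test samples
test, train = samples[:n_test], samples[n_test:]  # 1066 / 9596 split
```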
- Running on Ascend

```shell
# run training
DEVICE_ID=7 python train.py

# or use the shell script to train in the background
bash scripts/run_train.sh

# run evaluation
DEVICE_ID=7 python eval.py --ckpt_path {checkpoint path}

# or use the shell script to evaluate in the background
bash scripts/run_eval.sh
```
## [Script Description](#contents)

### [Script and Sample Code](#contents)

```text
├── model_zoo
    ├── README.md                          // descriptions about all the models
    ├── textrcnn
        ├── README.md                      // descriptions about TextRCNN
        ├── data_src
        │   ├── rt-polaritydata            // directory to save the source data
        │   ├── rt-polaritydata.README.1.0.txt  // readme file of the dataset
        ├── scripts
        │   ├── run_train.sh               // shell script for training on Ascend
        │   ├── run_eval.sh                // shell script for evaluation on Ascend
        │   ├── sample.txt                 // example commands to run the two scripts above
        ├── src
        │   ├── dataset.py                 // dataset creation
        │   ├── textrcnn.py                // TextRCNN architecture
        │   ├── config.py                  // parameter configuration
        ├── train.py                       // training script
        ├── export.py                      // export script
        ├── eval.py                        // evaluation script
        ├── data_helpers.py                // dataset split script
        ├── sample.txt                     // example commands to train and evaluate the model without the scripts
```
### [Script Parameters](#contents)

Parameters for both training and evaluation can be set in `config.py`.

- config for TextRCNN, Sentence polarity dataset v1.0

```python
'num_epochs': 10,                  # total training epochs
'lstm_num_epochs': 15,             # total training epochs when using lstm
'batch_size': 64,                  # training batch size
'cell': 'gru',                     # the RNN cell type, one of 'vanilla', 'gru', or 'lstm'
'ckpt_folder_path': './ckpt',      # path to save the checkpoints
'preprocess_path': './preprocess', # directory to save the preprocessed data
'preprocess': 'false',             # whether to preprocess the data
'data_path': './data/',            # path to the split data
'lr': 1e-3,                        # training learning rate
'lstm_lr_init': 2e-3,              # initial learning rate when using lstm
'lstm_lr_end': 5e-4,               # final learning rate when using lstm
'lstm_lr_max': 3e-3,               # maximum learning rate when using lstm
'lstm_lr_warm_up_epochs': 2,       # number of warm-up epochs when using lstm
'lstm_lr_adjust_epochs': 9,        # the lr is adjusted during these epochs, then held at lstm_lr_end, when using lstm
'emb_path': './word2vec',          # directory containing the embedding file
'embed_size': 300,                 # dimension of the word embeddings
'save_checkpoint_steps': 149,      # save a checkpoint every 149 steps
'keep_checkpoint_max': 10,         # maximum number of checkpoints to keep
```
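As an illustration of how the `lstm_lr_*` values fit together, here is a hedged sketch of a warm-up-then-decay schedule built from them. The exact curve computed in `train.py` may differ; the defaults below simply mirror the config keys above.

```python
def lstm_lr(epoch, lr_init=2e-3, lr_end=5e-4, lr_max=3e-3,
            warm_up_epochs=2, adjust_epochs=9):
    """Per-epoch learning rate: warm up to lr_max, decay to lr_end, then hold."""
    if epoch < warm_up_epochs:                        # linear warm-up
        return lr_init + (lr_max - lr_init) * epoch / warm_up_epochs
    if epoch < adjust_epochs:                         # linear decay
        frac = (epoch - warm_up_epochs) / (adjust_epochs - warm_up_epochs)
        return lr_max + (lr_end - lr_max) * frac
    return lr_end                                     # held constant afterwards

print([round(lstm_lr(ep), 5) for ep in range(15)])    # one value per lstm epoch
```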
### Performance

| Model             | MindSpore + Ascend             | TensorFlow + GPU               |
| ----------------- | ------------------------------ | ------------------------------ |
| Resource          | Ascend 910; OS Euler2.8        | NV SMX2 V100-32G               |
| Framework Version | 1.0.1                          | 1.4.0                          |
| Dataset           | Sentence polarity dataset v1.0 | Sentence polarity dataset v1.0 |
| batch_size        | 64                             | 64                             |
| Accuracy          | 0.78                           | 0.78                           |
| Speed             | 35 ms/step                     | 77 ms/step                     |
## [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).