diff --git a/model_zoo/official/nlp/textrcnn/data_helpers.py b/model_zoo/official/nlp/textrcnn/data_helpers.py
new file mode 100644
index 0000000000..d0ac1599bf
--- /dev/null
+++ b/model_zoo/official/nlp/textrcnn/data_helpers.py
@@ -0,0 +1,60 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""dataset helpers api"""
+import argparse
+import os
+import numpy as np
+parser = argparse.ArgumentParser(description='textrcnn')
+parser.add_argument('--task', type=str, help='the data preprocess task, including dataset_split.')
+parser.add_argument('--data_dir', type=str, help='the source dataset directory.', default='./data_src')
+parser.add_argument('--out_dir', type=str, help='the target dataset directory.', default='./data')
+
+args = parser.parse_args()
+
+
+def dataset_split(label):
+    """Split the rt-polarity file for `label` ('pos' or 'neg') into train and test sets."""
+    samples = []
+    src_file = os.path.join(args.data_dir, "rt-polaritydata", "rt-polarity." + label)
+    with open(src_file, encoding='utf-8') as fhand:
+        samples += fhand.readlines()
+    # random 90% / 10% split of the sample indices
+    perm = np.random.permutation(len(samples))
+    perm_train = perm[0:int(len(samples) * 0.9)]
+    perm_test = perm[int(len(samples) * 0.9):]
+    samples_train = [samples[idx] for idx in perm_train]
+    samples_test = [samples[idx] for idx in perm_test]
+    with open(os.path.join(args.out_dir, 'train', label), "w", encoding='utf-8') as f:
+        f.write(''.join(samples_train))
+
+    with open(os.path.join(args.out_dir, 'test', label), "w", encoding='utf-8') as f:
+        f.write(''.join(samples_test))
+
+
+
+if __name__ == '__main__':
+ if args.task == "dataset_split":
+ dataset_split('pos')
+ dataset_split('neg')
+
diff --git a/model_zoo/official/nlp/textrcnn/eval.py b/model_zoo/official/nlp/textrcnn/eval.py
new file mode 100644
index 0000000000..fc2bbfd7f1
--- /dev/null
+++ b/model_zoo/official/nlp/textrcnn/eval.py
@@ -0,0 +1,61 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""model evaluation script"""
+import os
+import argparse
+import numpy as np
+
+import mindspore.nn as nn
+import mindspore.context as context
+from mindspore import Tensor
+from mindspore.train import Model
+from mindspore.nn.metrics import Accuracy
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore.train.callback import LossMonitor
+from mindspore.common import set_seed
+
+from src.config import textrcnn_cfg as cfg
+from src.dataset import create_dataset
+from src.textrcnn import textrcnn
+
+set_seed(1)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='textrcnn')
+ parser.add_argument('--ckpt_path', type=str)
+ args = parser.parse_args()
+ context.set_context(
+ mode=context.GRAPH_MODE,
+ save_graphs=False,
+ device_target="Ascend")
+
+ device_id = int(os.getenv('DEVICE_ID'))
+ context.set_context(device_id=device_id)
+
+ embedding_table = np.loadtxt(os.path.join(cfg.preprocess_path, "weight.txt")).astype(np.float32)
+ network = textrcnn(weight=Tensor(embedding_table), vocab_size=embedding_table.shape[0], \
+ cell=cfg.cell, batch_size=cfg.batch_size)
+ loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
+ opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
+ loss_cb = LossMonitor()
+ print("============== Starting Testing ==============")
+ ds_eval = create_dataset(cfg.preprocess_path, cfg.batch_size, 1, False)
+ param_dict = load_checkpoint(args.ckpt_path)
+ load_param_into_net(network, param_dict)
+ network.set_train(False)
+ model = Model(network, loss, opt, metrics={'acc': Accuracy()}, amp_level='O3')
+ acc = model.eval(ds_eval, dataset_sink_mode=False)
+ print("============== Accuracy:{} ==============".format(acc))
diff --git a/model_zoo/official/nlp/textrcnn/readme.md b/model_zoo/official/nlp/textrcnn/readme.md
new file mode 100644
index 0000000000..ba5d0b6c62
--- /dev/null
+++ b/model_zoo/official/nlp/textrcnn/readme.md
@@ -0,0 +1,144 @@
+# TextRCNN
+
+## Contents
+
+- [TextRCNN Description](#textrcnn-description)
+- [Model Architecture](#model-architecture)
+- [Dataset](#dataset)
+- [Environment Requirements](#environment-requirements)
+- [Quick Start](#quick-start)
+- [Script Description](#script-description)
+- [ModelZoo Homepage](#modelzoo-homepage)
+
+## [TextRCNN Description](#contents)
+
+TextRCNN is a text classification model proposed by the Chinese Academy of Sciences in 2015.
+It combines ideas from RNNs and CNNs: a bidirectional recurrent structure first captures the contextual semantic and syntactic information of the input text,
+then max pooling automatically selects the most informative features,
+and finally a fully connected layer performs the classification.
+
+The TextCNN network structure contains a convolutional layer and a pooling layer. In RCNN, the feature-extraction role of the convolutional layer is taken over by an RNN, so the overall structure is a recurrent layer followed by a pooling layer, hence the name RCNN.
+
+[Paper](https://www.aaai.org/ocs/index.php/AAAI/AAAI15/paper/download/9745/9552): Siwei Lai, Liheng Xu, Kang Liu, Jun Zhao: Recurrent Convolutional Neural Networks for Text Classification. AAAI 2015: 2267-2273
+
+## [Model Architecture](#contents)
+
+Specifically, TextRCNN is mainly composed of three parts: a recurrent structure layer, a max-pooling layer, and a fully connected layer. In the paper, the word vector size is $|e|=50$, the context vector size is $|c|=50$, the hidden layer size is $H=100$, the learning rate is $\alpha=0.01$, and the vocabulary size is $|V|$; the input is a sequence of words and the output is a vector of class scores.
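+
+In the paper's notation, the recurrent structure and the pooling step can be summarized as follows (a sketch of the paper's formulation; this implementation realizes the recurrent part with LSTM, GRU, or vanilla RNN cells, see `cell` in `config.py`):
+
+$$c_l(w_i) = f\left(W^{(l)} c_l(w_{i-1}) + W^{(sl)} e(w_{i-1})\right), \qquad c_r(w_i) = f\left(W^{(r)} c_r(w_{i+1}) + W^{(sr)} e(w_{i+1})\right)$$
+
+$$x_i = \left[c_l(w_i); e(w_i); c_r(w_i)\right], \qquad y_i^{(2)} = \tanh\left(W^{(2)} x_i + b^{(2)}\right), \qquad y^{(3)} = \max_{i} y_i^{(2)}$$
+
+where $e(w_i)$ is the embedding of word $w_i$, $c_l(w_i)$ and $c_r(w_i)$ are its left and right context vectors, and the element-wise max over all positions yields the fixed-length text representation fed to the output layer.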
+
+## [Dataset](#contents)
+
+Dataset used: [Sentence polarity dataset v1.0]()
+
+- Dataset size: 10662 movie review sentences in 2 classes, 9596 sentences for the training set and 1066 for the test set.
+- Data format: text files. The processed data is in `./data/`.
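+
+Before running the split described in the Quick Start, unpack the raw `rt-polaritydata` archive into `./data_src/rt-polaritydata/`, which is the default `--data_dir` expected by `data_helpers.py`.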
+
+## [Environment Requirements](#contents)
+
+- Hardware: Ascend
+- Framework: [MindSpore](https://www.mindspore.cn/install/en)
+- For more information, please check the resources below: [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html), [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html).
+
+## [Quick Start](#contents)
+
+- Preparing the environment
+
+```bash
+ # download the pretrained GoogleNews-vectors-negative300.bin, put it into /tmp
+ # you can download from https://code.google.com/archive/p/word2vec/,
+ # or from https://pan.baidu.com/s/1NC2ekA_bJ0uSL7BF3SjhIg, code: yk9a
+
+ mv /tmp/GoogleNews-vectors-negative300.bin ./word2vec/
+```
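+
+The `./word2vec` directory corresponds to `emb_path` in `config.py`; during preprocessing, `src/dataset.py` loads `GoogleNews-vectors-negative300.bin` from it with gensim to build the embedding table that is saved as `weight.txt`.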
+
+- Preparing data
+
+```bash
+ # split the dataset by the following scripts.
+ mkdir -p data/test && mkdir -p data/train
+ python data_helpers.py --task dataset_split --data_dir dataset_dir
+
+```
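+
+After the split, `data/train/pos`, `data/train/neg`, `data/test/pos` and `data/test/neg` each contain one review per line. As an optional sanity check (a small sketch assuming the default `--out_dir ./data`), the sizes of the roughly 90%/10% split can be verified with:
+
+```python
+# count the lines written by data_helpers.py for each split and label
+for split in ("train", "test"):
+    for label in ("pos", "neg"):
+        with open(f"data/{split}/{label}", encoding="utf-8") as f:
+            print(split, label, sum(1 for _ in f))
+```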
+
+- Modify the source code in `mindspore/train/model.py` (around line 173) so that the eval cell also covers the "O3" amp level:
+
+```python
+ self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level in ["O2", "O3"])
+```
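+
+This change is needed because both `train.py` and `eval.py` build the `Model` with `amp_level='O3'`; adding "O3" to the check lets the evaluation cell handle the float16 network outputs the same way it already does for "O2".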
+
+- Running on Ascend
+
+```bash
+# run training
+DEVICE_ID=7 python train.py
+# or you can use the shell script to train in background
+bash scripts/run_train.sh
+
+# run evaluating
+DEVICE_ID=7 python eval.py --ckpt_path ./ckpt/lstm-10_149.ckpt
+# or you can use the shell script to evaluate in background
+bash scripts/run_eval.sh
+```
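+
+After training, checkpoints named like `lstm-10_149.ckpt` (the prefix is the `cell` value from `config.py`) are written to `./ckpt`; the shell scripts run in the background and redirect their output to `./train.log` and `./eval.log`.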
+
+## [Script Description](#contents)
+
+### [Script and Sample Code](#contents)
+
+```text
+├── model_zoo
+ ├── README.md // descriptions about all the models
+ ├── textrcnn
+ ├── README.md // descriptions about TextRCNN
+ ├── data_src
+ │ ├──rt-polaritydata // directory to save the source data
+ │ ├──rt-polaritydata.README.1.0.txt // readme file of dataset
+ ├── scripts
+ │ ├──run_train.sh // shell script for train on Ascend
+ │ ├──run_eval.sh // shell script for evaluation on Ascend
+        │   ├──sample.txt            // example shell commands to run the two scripts above
+ ├── src
+ │ ├──dataset.py // creating dataset
+ │ ├──textrcnn.py // textrcnn architecture
+ │ ├──config.py // parameter configuration
+ ├── train.py // training script
+ ├── eval.py // evaluation script
+ ├── data_helpers.py // dataset split script
+ ├── sample.txt // the shell to train and eval the model without scripts
+```
+
+### [Script Parameters](#contents)
+
+Parameters for both training and evaluation can be set in `config.py`.
+
+- Config for TextRCNN on the Sentence polarity dataset v1.0:
+
+    ```python
+    'num_epochs': 10,                   # total training epochs
+    'batch_size': 64,                   # training batch size
+    'cell': 'lstm',                     # the RNN architecture, can be 'vanilla', 'gru' or 'lstm'
+    'opt': 'adam',                      # the optimizer strategy, can be 'adam' or 'momentum'
+    'ckpt_folder_path': './ckpt',       # the path to save the checkpoints
+    'preprocess_path': './preprocess',  # the directory to save the processed data
+    'preprocess': 'false',              # whether to preprocess the data
+    'data_path': './data/',             # the path to store the split data
+    'lr': 1e-3,                         # the training learning rate
+    'emb_path': './word2vec',           # the directory to save the embedding file
+    'embed_size': 300,                  # the dimension of the word embedding
+    'save_checkpoint_steps': 149,       # number of steps between checkpoint saves
+    'keep_checkpoint_max': 10,          # maximum number of checkpoints to keep
+    'momentum': 0.9                     # the momentum rate
+    ```
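+
+These settings are plain `EasyDict` attributes defined in `src/config.py`, so they can be inspected or overridden with attribute access; a minimal illustration (run from the model directory, affects only the current Python process):
+
+```python
+from src.config import textrcnn_cfg as cfg
+
+print(cfg.cell, cfg.batch_size, cfg.lr)  # lstm 64 0.001
+cfg.cell = 'gru'                         # e.g. switch the recurrent cell before building the network
+```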
+
+### Performance
+
+| Model | MindSpore + Ascend | TensorFlow+GPU |
+| -------------------------- | ----------------------------- | ------------------------- |
+| Resource | Ascend 910 | NV SMX2 V100-32G |
+| Framework Version          | 1.0.1                         | 1.4.0                     |
+| Dataset | Sentence polarity dataset v1.0 | Sentence polarity dataset v1.0 |
+| batch_size | 64 | 64 |
+| Accuracy | 0.78 | 0.78 |
+| Speed | 78ms/step | 89ms/step |
+
+## [ModelZoo Homepage](#contents)
+
+ Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
diff --git a/model_zoo/official/nlp/textrcnn/sample.txt b/model_zoo/official/nlp/textrcnn/sample.txt
new file mode 100644
index 0000000000..1f742cc341
--- /dev/null
+++ b/model_zoo/official/nlp/textrcnn/sample.txt
@@ -0,0 +1,2 @@
+DEVICE_ID=7 python train.py
+DEVICE_ID=7 python eval.py --ckpt_path ./ckpt/lstm-1_149.ckpt
\ No newline at end of file
diff --git a/model_zoo/official/nlp/textrcnn/scripts/run_eval.sh b/model_zoo/official/nlp/textrcnn/scripts/run_eval.sh
new file mode 100644
index 0000000000..19c38957c4
--- /dev/null
+++ b/model_zoo/official/nlp/textrcnn/scripts/run_eval.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+ulimit -u unlimited
+
+BASEPATH=$(cd "`dirname $0`" || exit; pwd)
+export PYTHONPATH=${BASEPATH}:$PYTHONPATH
+python ${BASEPATH}/../eval.py --ckpt_path $1 > ./eval.log 2>&1 &
diff --git a/model_zoo/official/nlp/textrcnn/scripts/run_train.sh b/model_zoo/official/nlp/textrcnn/scripts/run_train.sh
new file mode 100644
index 0000000000..5e87829a5b
--- /dev/null
+++ b/model_zoo/official/nlp/textrcnn/scripts/run_train.sh
@@ -0,0 +1,21 @@
+#!/bin/bash
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+ulimit -u unlimited
+
+BASEPATH=$(cd "`dirname $0`" || exit; pwd)
+export PYTHONPATH=${BASEPATH}:$PYTHONPATH
+python ${BASEPATH}/../train.py > ./train.log 2>&1 &
\ No newline at end of file
diff --git a/model_zoo/official/nlp/textrcnn/src/config.py b/model_zoo/official/nlp/textrcnn/src/config.py
new file mode 100644
index 0000000000..eacda81a19
--- /dev/null
+++ b/model_zoo/official/nlp/textrcnn/src/config.py
@@ -0,0 +1,38 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+network config
+"""
+from easydict import EasyDict as edict
+
+# TextRCNN config
+textrcnn_cfg = edict({
+ 'pos_dir': 'data/rt-polaritydata/rt-polarity.pos',
+ 'neg_dir': 'data/rt-polaritydata/rt-polarity.neg',
+ 'num_epochs': 10,
+ 'batch_size': 64,
+ 'cell': 'lstm',
+ 'opt': 'adam',
+ 'ckpt_folder_path': './ckpt',
+ 'preprocess_path': './preprocess',
+ 'preprocess': 'false',
+ 'data_path': './data/',
+ 'lr': 1e-3,
+ 'emb_path': './word2vec',
+ 'embed_size': 300,
+ 'save_checkpoint_steps': 149,
+ 'keep_checkpoint_max': 10,
+ 'momentum': 0.9
+})
diff --git a/model_zoo/official/nlp/textrcnn/src/dataset.py b/model_zoo/official/nlp/textrcnn/src/dataset.py
new file mode 100644
index 0000000000..a793cf794b
--- /dev/null
+++ b/model_zoo/official/nlp/textrcnn/src/dataset.py
@@ -0,0 +1,179 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""dataset api"""
+import os
+from itertools import chain
+import gensim
+import numpy as np
+
+from mindspore.mindrecord import FileWriter
+import mindspore.dataset as ds
+
+# preprocess part
+def encode_samples(tokenized_samples, word_to_idx):
+ """ encode word to index """
+ features = []
+ for sample in tokenized_samples:
+ feature = []
+ for token in sample:
+ if token in word_to_idx:
+ feature.append(word_to_idx[token])
+ else:
+ feature.append(0)
+ features.append(feature)
+ return features
+
+
+def pad_samples(features, maxlen=50, pad=0):
+ """ pad all features to the same length """
+ padded_features = []
+ for feature in features:
+ if len(feature) >= maxlen:
+ padded_feature = feature[:maxlen]
+ else:
+ padded_feature = feature
+ while len(padded_feature) < maxlen:
+ padded_feature.append(pad)
+ padded_features.append(padded_feature)
+ return padded_features
+
+
+def read_imdb(path, seg='train'):
+    """ read the polarity data: one review per line in <path>/<seg>/pos and <path>/<seg>/neg """
+    pos_or_neg = ['pos', 'neg']
+    data = []
+    for label in pos_or_neg:
+        f = os.path.join(path, seg, label)
+        with open(f, 'r', encoding='utf-8') as rf:
+            for line in rf:
+                line = line.strip()
+                if label == 'pos':
+                    data.append([line, 1])
+                elif label == 'neg':
+                    data.append([line, 0])
+
+    return data
+
+
+def tokenizer(text):
+ return [tok.lower() for tok in text.split(' ')]
+
+
+def collect_weight(glove_path, vocab, word_to_idx, embed_size):
+ """ collect weight """
+ vocab_size = len(vocab)
+ # wvmodel = gensim.models.KeyedVectors.load_word2vec_format(os.path.join(glove_path, 'glove.6B.300d.txt'),
+ # binary=False, encoding='utf-8')
+ wvmodel = gensim.models.KeyedVectors.load_word2vec_format(os.path.join(glove_path, \
+ 'GoogleNews-vectors-negative300.bin'), binary=True)
+ weight_np = np.zeros((vocab_size + 1, embed_size)).astype(np.float32)
+
+ idx_to_word = {i + 1: word for i, word in enumerate(vocab)}
+ idx_to_word[0] = ''
+
+ for i in range(len(wvmodel.index2word)):
+ try:
+ index = word_to_idx[wvmodel.index2word[i]]
+ except KeyError:
+ continue
+ weight_np[index, :] = wvmodel.get_vector(
+ idx_to_word[word_to_idx[wvmodel.index2word[i]]])
+ return weight_np
+
+
+def preprocess(data_path, glove_path, embed_size):
+ """ preprocess the train and test data """
+ train_data = read_imdb(data_path, 'train')
+ test_data = read_imdb(data_path, 'test')
+
+ train_tokenized = []
+ test_tokenized = []
+ for review, _ in train_data:
+ train_tokenized.append(tokenizer(review))
+ for review, _ in test_data:
+ test_tokenized.append(tokenizer(review))
+
+ vocab = set(chain(*train_tokenized))
+ vocab_size = len(vocab)
+ print("vocab_size: ", vocab_size)
+
+ word_to_idx = {word: i + 1 for i, word in enumerate(vocab)}
+ word_to_idx[''] = 0
+
+ train_features = np.array(pad_samples(encode_samples(train_tokenized, word_to_idx))).astype(np.int32)
+ train_labels = np.array([score for _, score in train_data]).astype(np.int32)
+ test_features = np.array(pad_samples(encode_samples(test_tokenized, word_to_idx))).astype(np.int32)
+ test_labels = np.array([score for _, score in test_data]).astype(np.int32)
+
+ weight_np = collect_weight(glove_path, vocab, word_to_idx, embed_size)
+ return train_features, train_labels, test_features, test_labels, weight_np, vocab_size
+
+
+def get_imdb_data(labels_data, features_data):
+ data_list = []
+ for i, (label, feature) in enumerate(zip(labels_data, features_data)):
+ data_json = {"id": i,
+ "label": int(label),
+ "feature": feature.reshape(-1)}
+ data_list.append(data_json)
+ return data_list
+
+
+def convert_to_mindrecord(embed_size, data_path, preprocess_path, glove_path):
+ """ convert imdb dataset to mindrecord """
+
+ num_shard = 4
+ train_features, train_labels, test_features, test_labels, weight_np, _ = \
+ preprocess(data_path, glove_path, embed_size)
+    np.savetxt(os.path.join(preprocess_path, 'weight.txt'), weight_np)
+
+ print("train_features.shape:", train_features.shape, "train_labels.shape:", train_labels.shape, "weight_np.shape:",\
+ weight_np.shape, "type:", train_labels.dtype)
+ # write mindrecord
+ schema_json = {"id": {"type": "int32"},
+ "label": {"type": "int32"},
+ "feature": {"type": "int32", "shape": [-1]}}
+
+    writer = FileWriter(os.path.join(preprocess_path, 'aclImdb_train.mindrecord'), num_shard)
+ data = get_imdb_data(train_labels, train_features)
+ writer.add_schema(schema_json, "nlp_schema")
+ writer.add_index(["id", "label"])
+ writer.write_raw_data(data)
+ writer.commit()
+
+    writer = FileWriter(os.path.join(preprocess_path, 'aclImdb_test.mindrecord'), num_shard)
+ data = get_imdb_data(test_labels, test_features)
+ writer.add_schema(schema_json, "nlp_schema")
+ writer.add_index(["id", "label"])
+ writer.write_raw_data(data)
+ writer.commit()
+
+
+def create_dataset(base_path, batch_size, num_epochs, is_train):
+ """Create dataset for training."""
+ columns_list = ["feature", "label"]
+ num_consumer = 4
+
+ if is_train:
+ path = os.path.join(base_path, 'aclImdb_train.mindrecord0')
+ else:
+ path = os.path.join(base_path, 'aclImdb_test.mindrecord0')
+
+ data_set = ds.MindDataset(path, columns_list, num_consumer)
+ ds.config.set_seed(1)
+ data_set = data_set.shuffle(buffer_size=data_set.get_dataset_size())
+ data_set = data_set.batch(batch_size, drop_remainder=True)
+ return data_set
diff --git a/model_zoo/official/nlp/textrcnn/src/textrcnn.py b/model_zoo/official/nlp/textrcnn/src/textrcnn.py
new file mode 100644
index 0000000000..6b3fb00e92
--- /dev/null
+++ b/model_zoo/official/nlp/textrcnn/src/textrcnn.py
@@ -0,0 +1,196 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""model textrcnn"""
+import numpy as np
+
+import mindspore.nn as nn
+from mindspore.ops import operations as P
+from mindspore.ops import functional as F
+from mindspore.common.parameter import Parameter
+from mindspore import Tensor
+from mindspore.common import dtype as mstype
+
+class textrcnn(nn.Cell):
+    """TextRCNN network: bidirectional recurrent context encoder, max pooling, and a dense classifier."""
+ def __init__(self, weight, vocab_size, cell, batch_size):
+ super(textrcnn, self).__init__()
+ self.num_hiddens = 512
+ self.embed_size = 300
+ self.num_classes = 2
+ self.batch_size = batch_size
+ k = (1 / self.num_hiddens) ** 0.5
+
+ self.embedding = nn.Embedding(vocab_size, self.embed_size, embedding_table=weight)
+ self.embedding.embedding_table.requires_grad = False
+ self.cell = cell
+
+ self.cast = P.Cast()
+
+ self.h1 = Tensor(np.zeros(shape=(self.batch_size, self.num_hiddens)).astype(np.float16))
+ self.c1 = Tensor(np.zeros(shape=(self.batch_size, self.num_hiddens)).astype(np.float16))
+
+ if cell == "lstm":
+ self.lstm = P.DynamicRNN(forget_bias=0.0)
+ self.w1_fw = Parameter(
+ np.random.uniform(-k, k, (self.embed_size + self.num_hiddens, 4 * self.num_hiddens)).astype(
+ np.float16), name="w1_fw")
+ self.b1_fw = Parameter(np.random.uniform(-k, k, (4 * self.num_hiddens)).astype(np.float16),
+ name="b1_fw")
+ self.w1_bw = Parameter(
+ np.random.uniform(-k, k, (self.embed_size + self.num_hiddens, 4 * self.num_hiddens)).astype(
+ np.float16), name="w1_bw")
+ self.b1_bw = Parameter(np.random.uniform(-k, k, (4 * self.num_hiddens)).astype(np.float16),
+ name="b1_bw")
+ self.h1 = Tensor(np.zeros(shape=(1, self.batch_size, self.num_hiddens)).astype(np.float16))
+ self.c1 = Tensor(np.zeros(shape=(1, self.batch_size, self.num_hiddens)).astype(np.float16))
+
+ if cell == "vanilla":
+ self.rnnW_fw = nn.Dense(self.num_hiddens, self.num_hiddens)
+ self.rnnU_fw = nn.Dense(self.embed_size, self.num_hiddens)
+ self.rnnW_bw = nn.Dense(self.num_hiddens, self.num_hiddens)
+ self.rnnU_bw = nn.Dense(self.embed_size, self.num_hiddens)
+
+ if cell == "gru":
+ self.rnnWr_fw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
+ self.rnnWz_fw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
+ self.rnnWh_fw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
+ self.rnnWr_bw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
+ self.rnnWz_bw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
+ self.rnnWh_bw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
+ self.ones = Tensor(np.ones(shape=(self.batch_size, self.num_hiddens)).astype(np.float16))
+
+ self.transpose = P.Transpose()
+ self.reduce_max = P.ReduceMax()
+ self.expand_dims = P.ExpandDims()
+ self.concat = P.Concat()
+
+ self.reshape = P.Reshape()
+ self.left_pad_tensor = Tensor(np.zeros((1, self.batch_size, self.num_hiddens)).astype(np.float16))
+ self.right_pad_tensor = Tensor(np.zeros((1, self.batch_size, self.num_hiddens)).astype(np.float16))
+ self.output_dense = nn.Dense(self.num_hiddens * 1, 2)
+ self.concat0 = P.Concat(0)
+ self.concat2 = P.Concat(2)
+ self.concat1 = P.Concat(1)
+ self.text_rep_dense = nn.Dense(2 * self.num_hiddens + self.embed_size, self.num_hiddens)
+ self.mydense = nn.Dense(self.num_hiddens, 2)
+ self.drop_out = nn.Dropout(keep_prob=0.7)
+ self.tanh = P.Tanh()
+ self.sigmoid = P.Sigmoid()
+ self.slice = P.Slice()
+
+ def construct(self, x):
+        """Construct the network: embedding -> bidirectional recurrence -> context concatenation -> max pooling -> dense."""
+ # x: bs, sl
+ output_fw = x
+ output_bw = x
+
+ if self.cell == "vanilla":
+ x = self.embedding(x) # bs, sl, emb_size
+ x = self.cast(x, mstype.float16)
+ x = self.transpose(x, (1, 0, 2)) # sl, bs, emb_size
+ x = self.drop_out(x) # sl,bs, emb_size
+
+ h1_fw = self.cast(self.h1, mstype.float16) # bs, num_hidden
+ h1_fw = self.tanh(self.rnnW_fw(h1_fw) + self.rnnU_fw(x[0, :, :])) # bs, num_hidden
+ output_fw = self.expand_dims(h1_fw, 0) # 1, bs, num_hidden
+
+ for i in range(1, F.shape(x)[0]):
+ h1_fw = self.tanh(self.rnnW_fw(h1_fw) + self.rnnU_fw(x[i, :, :])) # 1, bs, num_hidden
+ h1_after_expand_fw = self.expand_dims(h1_fw, 0)
+ output_fw = self.concat((output_fw, h1_after_expand_fw)) # 2/3/4.., bs, num_hidden
+ output_fw = self.cast(output_fw, mstype.float16) # sl, bs, num_hidden
+
+ h1_bw = self.cast(self.h1, mstype.float16) # bs, num_hidden
+ h1_bw = self.tanh(self.rnnW_bw(h1_bw) + self.rnnU_bw(x[F.shape(x)[0] - 1, :, :])) # bs, num_hidden
+ output_bw = self.expand_dims(h1_bw, 0) # 1, bs, num_hidden
+
+ for i in range(F.shape(x)[0] - 2, -1, -1):
+ h1_bw = self.tanh(self.rnnW_bw(h1_bw) + self.rnnU_bw(x[i, :, :])) # 1, bs, num_hidden
+ h1_after_expand_bw = self.expand_dims(h1_bw, 0)
+ output_bw = self.concat((h1_after_expand_bw, output_bw)) # 2/3/4.., bs, num_hidden
+ output_bw = self.cast(output_bw, mstype.float16) # sl, bs, num_hidden
+
+ if self.cell == "gru":
+ x = self.embedding(x) # bs, sl, emb_size
+ x = self.cast(x, mstype.float16)
+ x = self.transpose(x, (1, 0, 2)) # sl, bs, emb_size
+ x = self.drop_out(x) # sl,bs, emb_size
+
+ h_fw = self.cast(self.h1, mstype.float16)
+
+ h_x_fw = self.concat1((h_fw, x[0, :, :]))
+ r_fw = self.sigmoid(self.rnnWr_fw(h_x_fw))
+ z_fw = self.sigmoid(self.rnnWz_fw(h_x_fw))
+ h_tilde_fw = self.tanh(self.rnnWh_fw(self.concat1((r_fw * h_fw, x[0, :, :]))))
+ h_fw = (self.ones - z_fw) * h_fw + z_fw * h_tilde_fw
+ output_fw = self.expand_dims(h_fw, 0)
+
+ for i in range(1, F.shape(x)[0]):
+ h_x_fw = self.concat1((h_fw, x[i, :, :]))
+ r_fw = self.sigmoid(self.rnnWr_fw(h_x_fw))
+ z_fw = self.sigmoid(self.rnnWz_fw(h_x_fw))
+ h_tilde_fw = self.tanh(self.rnnWh_fw(self.concat1((r_fw * h_fw, x[i, :, :]))))
+ h_fw = (self.ones - z_fw) * h_fw + z_fw * h_tilde_fw
+ h_after_expand_fw = self.expand_dims(h_fw, 0)
+ output_fw = self.concat((output_fw, h_after_expand_fw))
+ output_fw = self.cast(output_fw, mstype.float16)
+
+ h_bw = self.cast(self.h1, mstype.float16) # bs, num_hidden
+
+ h_x_bw = self.concat1((h_bw, x[F.shape(x)[0] - 1, :, :]))
+ r_bw = self.sigmoid(self.rnnWr_bw(h_x_bw))
+ z_bw = self.sigmoid(self.rnnWz_bw(h_x_bw))
+ h_tilde_bw = self.tanh(self.rnnWh_bw(self.concat1((r_bw * h_bw, x[F.shape(x)[0] - 1, :, :]))))
+ h_bw = (self.ones - z_bw) * h_bw + z_bw * h_tilde_bw
+ output_bw = self.expand_dims(h_bw, 0)
+ for i in range(F.shape(x)[0] - 2, -1, -1):
+ h_x_bw = self.concat1((h_bw, x[i, :, :]))
+ r_bw = self.sigmoid(self.rnnWr_bw(h_x_bw))
+ z_bw = self.sigmoid(self.rnnWz_bw(h_x_bw))
+ h_tilde_bw = self.tanh(self.rnnWh_bw(self.concat1((r_bw * h_bw, x[i, :, :]))))
+ h_bw = (self.ones - z_bw) * h_bw + z_bw * h_tilde_bw
+ h_after_expand_bw = self.expand_dims(h_bw, 0)
+ output_bw = self.concat((h_after_expand_bw, output_bw))
+ output_bw = self.cast(output_bw, mstype.float16)
+ if self.cell == 'lstm':
+ x = self.embedding(x) # bs, sl, emb_size
+ x = self.cast(x, mstype.float16)
+ x = self.transpose(x, (1, 0, 2)) # sl, bs, emb_size
+ x = self.drop_out(x) # sl,bs, emb_size
+
+ h1_fw_init = self.h1 # bs, num_hidden
+ c1_fw_init = self.c1 # bs, num_hidden
+
+ _, output_fw, _, _, _, _, _, _ = self.lstm(x, self.w1_fw, self.b1_fw, None, h1_fw_init, c1_fw_init)
+ output_fw = self.cast(output_fw, mstype.float16) # sl, bs, num_hidden
+
+ h1_bw_init = self.h1 # bs, num_hidden
+ c1_bw_init = self.c1 # bs, num_hidden
+ _, output_bw, _, _, _, _, _, _ = self.lstm(x, self.w1_bw, self.b1_bw, None, h1_bw_init, c1_bw_init)
+ output_bw = self.cast(output_bw, mstype.float16) # sl, bs, hidden
+
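+        # combine left context, word embedding and right context: x_i = [c_l(w_i); e(w_i); c_r(w_i)];
+        # the forward/backward RNN outputs are shifted by one step and zero-padded at the sequence ends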
+ c_left = self.concat0((self.left_pad_tensor, output_fw[:F.shape(x)[0] - 1])) # sl, bs, num_hidden
+ c_right = self.concat0((output_bw[1:], self.right_pad_tensor)) # sl, bs, num_hidden
+ output = self.concat2((c_left, self.cast(x, mstype.float16), c_right)) # sl, bs, 2*num_hidden+emb_size
+ output = self.cast(output, mstype.float16)
+
+ output_flat = self.reshape(output, (F.shape(x)[0] * self.batch_size, 2 * self.num_hiddens + self.embed_size))
+ output_dense = self.text_rep_dense(output_flat) # sl*bs, num_hidden
+ output_dense = self.tanh(output_dense) # sl*bs, num_hidden
+ output = self.reshape(output_dense, (F.shape(x)[0], self.batch_size, self.num_hiddens)) # sl, bs, num_hidden
+ output = self.reduce_max(output, 0) # bs, num_hidden
+ outputs = self.cast(self.mydense(output), mstype.float16) # bs, num_classes
+ return outputs
diff --git a/model_zoo/official/nlp/textrcnn/train.py b/model_zoo/official/nlp/textrcnn/train.py
new file mode 100644
index 0000000000..ec74ef4a9f
--- /dev/null
+++ b/model_zoo/official/nlp/textrcnn/train.py
@@ -0,0 +1,74 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""model train script"""
+import os
+import shutil
+import numpy as np
+
+import mindspore.nn as nn
+import mindspore.context as context
+from mindspore import Tensor
+from mindspore.train import Model
+from mindspore.nn.metrics import Accuracy
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
+from mindspore.common import set_seed
+
+from src.config import textrcnn_cfg as cfg
+from src.dataset import create_dataset
+from src.dataset import convert_to_mindrecord
+from src.textrcnn import textrcnn
+
+
+set_seed(1)
+
+if __name__ == '__main__':
+
+ context.set_context(
+ mode=context.GRAPH_MODE,
+ save_graphs=False,
+ device_target="Ascend")
+
+ device_id = int(os.getenv('DEVICE_ID'))
+ context.set_context(device_id=device_id)
+
+ if cfg.preprocess == 'true':
+ print("============== Starting Data Pre-processing ==============")
+ if os.path.exists(cfg.preprocess_path):
+ shutil.rmtree(cfg.preprocess_path)
+ os.mkdir(cfg.preprocess_path)
+ convert_to_mindrecord(cfg.embed_size, cfg.data_path, cfg.preprocess_path, cfg.emb_path)
+
+ embedding_table = np.loadtxt(os.path.join(cfg.preprocess_path, "weight.txt")).astype(np.float32)
+
+ network = textrcnn(weight=Tensor(embedding_table), vocab_size=embedding_table.shape[0], \
+ cell=cfg.cell, batch_size=cfg.batch_size)
+
+ loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
+ if cfg.opt == "adam":
+ opt = nn.Adam(params=network.trainable_params(), learning_rate=cfg.lr)
+ elif cfg.opt == "momentum":
+ opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
+
+ loss_cb = LossMonitor()
+ model = Model(network, loss, opt, {'acc': Accuracy()}, amp_level="O3")
+
+ print("============== Starting Training ==============")
+ ds_train = create_dataset(cfg.preprocess_path, cfg.batch_size, cfg.num_epochs, True)
+ config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, \
+ keep_checkpoint_max=cfg.keep_checkpoint_max)
+ ckpoint_cb = ModelCheckpoint(prefix=cfg.cell, directory=cfg.ckpt_folder_path, config=config_ck)
+ model.train(cfg.num_epochs, ds_train, callbacks=[ckpoint_cb, loss_cb])
+ print("train success")
+
\ No newline at end of file