diff --git a/model_zoo/official/nlp/textrcnn/data_helpers.py b/model_zoo/official/nlp/textrcnn/data_helpers.py
new file mode 100644
index 0000000000..d0ac1599bf
--- /dev/null
+++ b/model_zoo/official/nlp/textrcnn/data_helpers.py
@@ -0,0 +1,60 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""dataset helpers api"""
+import argparse
+import os
+import numpy as np
+parser = argparse.ArgumentParser(description='textrcnn')
+parser.add_argument('--task', type=str, help='the data preprocess task, including dataset_split.')
+parser.add_argument('--data_dir', type=str, help='the source dataset directory.', default='./data_src')
+parser.add_argument('--out_dir', type=str, help='the target dataset directory.', default='./data')
+
+args = parser.parse_args()
+
+
+def dataset_split(label):
+    """split the samples of the given label ('pos' or 'neg') into train and test sets"""
+    samples = []
+    data_file = os.path.join(args.data_dir, "rt-polaritydata", "rt-polarity." + label)
+    fhand = open(data_file, encoding='utf-8')
+    samples += fhand.readlines()
+    fhand.close()
+    # shuffle the samples, then use 90% for training and the remaining 10% for testing
+    perm = np.random.permutation(len(samples))
+    perm_train = perm[0:int(len(samples) * 0.9)]
+    perm_test = perm[int(len(samples) * 0.9):]
+    samples_train = []
+    samples_test = []
+    for pt in perm_train:
+        samples_train.append(samples[pt])
+    for pt in perm_test:
+        samples_test.append(samples[pt])
+    f = open(os.path.join(args.out_dir, 'train', label), "w")
+    f.write(''.join(samples_train))
+    f.close()
+
+    f = open(os.path.join(args.out_dir, 'test', label), "w")
+    f.write(''.join(samples_test))
+    f.close()
+
+
+if __name__ == '__main__':
+    if args.task == "dataset_split":
+        dataset_split('pos')
+        dataset_split('neg')
diff --git a/model_zoo/official/nlp/textrcnn/eval.py b/model_zoo/official/nlp/textrcnn/eval.py
new file mode 100644
index 0000000000..fc2bbfd7f1
--- /dev/null
+++ b/model_zoo/official/nlp/textrcnn/eval.py
@@ -0,0 +1,61 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""model evaluation script"""
+import os
+import argparse
+import numpy as np
+
+import mindspore.nn as nn
+import mindspore.context as context
+from mindspore import Tensor
+from mindspore.train import Model
+from mindspore.nn.metrics import Accuracy
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore.train.callback import LossMonitor
+from mindspore.common import set_seed
+
+from src.config import textrcnn_cfg as cfg
+from src.dataset import create_dataset
+from src.textrcnn import textrcnn
+
+set_seed(1)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='textrcnn')
+    parser.add_argument('--ckpt_path', type=str)
+    args = parser.parse_args()
+    context.set_context(
+        mode=context.GRAPH_MODE,
+        save_graphs=False,
+        device_target="Ascend")
+
+    device_id = int(os.getenv('DEVICE_ID'))
+    context.set_context(device_id=device_id)
+
+    embedding_table = np.loadtxt(os.path.join(cfg.preprocess_path, "weight.txt")).astype(np.float32)
+    network = textrcnn(weight=Tensor(embedding_table), vocab_size=embedding_table.shape[0], \
+                       cell=cfg.cell, batch_size=cfg.batch_size)
+    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
+    opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
+    loss_cb = LossMonitor()
+    print("============== Starting Testing ==============")
+    ds_eval = create_dataset(cfg.preprocess_path, cfg.batch_size, 1, False)
+    param_dict = load_checkpoint(args.ckpt_path)
+    load_param_into_net(network, param_dict)
+    network.set_train(False)
+    model = Model(network, loss, opt, metrics={'acc': Accuracy()}, amp_level='O3')
+    acc = model.eval(ds_eval, dataset_sink_mode=False)
+    print("============== Accuracy:{} ==============".format(acc))
diff --git a/model_zoo/official/nlp/textrcnn/readme.md b/model_zoo/official/nlp/textrcnn/readme.md
new file mode 100644
index 0000000000..ba5d0b6c62
--- /dev/null
+++ b/model_zoo/official/nlp/textrcnn/readme.md
@@ -0,0 +1,144 @@
+# TextRCNN
+
+## Contents
+
+- [TextRCNN Description](#textrcnn-description)
+- [Model Architecture](#model-architecture)
+- [Dataset](#dataset)
+- [Environment Requirements](#environment-requirements)
+- [Quick Start](#quick-start)
+- [Script Description](#script-description)
+- [ModelZoo Homepage](#modelzoo-homepage)
+
+## [TextRCNN Description](#contents)
+
+TextRCNN is a text classification model proposed by researchers from the Chinese Academy of Sciences in 2015.
+It combines ideas from RNNs and CNNs: a bidirectional RNN first captures the contextual semantic and syntactic information of the input text,
+max-pooling then automatically selects the most informative features,
+and a fully connected layer finally performs the classification.
+
+The TextCNN network structure contains a convolutional layer and a pooling layer. In RCNN, the feature extraction done by the convolutional layer is replaced by an RNN, so the overall structure consists of an RNN followed by a pooling layer, hence the name RCNN.
+
+[Paper](https://www.aaai.org/ocs/index.php/AAAI/AAAI15/paper/download/9745/9552): Siwei Lai, Liheng Xu, Kang Liu, Jun Zhao: Recurrent Convolutional Neural Networks for Text Classification. AAAI 2015: 2267-2273
+
+## [Model Architecture](#contents)
+
+Specifically, the TextRCNN is mainly composed of three parts: a recurrent structure layer, a max-pooling layer, and a fully connected layer.
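+Conceptually, the forward pass simply chains these three parts together. The snippet below is only a simplified NumPy sketch with illustrative names (the actual network is implemented in `src/textrcnn.py`):
+
+```python
+import numpy as np
+
+
+def textrcnn_forward(emb, c_left, c_right, w_rep, w_cls):
+    """Simplified sketch of the TextRCNN forward pass (illustrative only)."""
+    # emb: (seq_len, batch, embed_size) word embeddings e(w_i)
+    # c_left, c_right: (seq_len, batch, hidden) left/right contexts from the bidirectional RNN
+    x = np.concatenate([c_left, emb, c_right], axis=2)  # recurrent structure layer: [c_l; e; c_r]
+    y = np.tanh(x @ w_rep)                              # latent semantic representation
+    pooled = y.max(axis=0)                              # max-pooling over the whole sequence
+    return pooled @ w_cls                               # fully connected classification layer
+```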
+In the paper, the word vector length is $|e|=50$, the context vector length is $|c|=50$, the hidden layer size is $H=100$, the learning rate is $\alpha=0.01$, and the vocabulary size is $|V|$; the input is a sequence of words and the output is a vector containing the category probabilities.
+
+## [Dataset](#contents)
+
+Dataset used: [Sentence polarity dataset v1.0]()
+
+- Dataset size: 10662 movie comments in 2 classes, 9596 comments for the train set, 1066 comments for the test set.
+- Data format: text files. The processed data is in `./data/`.
+
+## [Environment Requirements](#contents)
+
+- Hardware: Ascend
+- Framework: [MindSpore](https://www.mindspore.cn/install/en)
+- For more information, please check the resources below: [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html), [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html).
+
+## [Quick Start](#contents)
+
+- Preparing the environment
+
+```bash
+  # download the pretrained GoogleNews-vectors-negative300.bin and put it into /tmp
+  # you can download it from https://code.google.com/archive/p/word2vec/,
+  # or from https://pan.baidu.com/s/1NC2ekA_bJ0uSL7BF3SjhIg, code: yk9a
+
+  mv /tmp/GoogleNews-vectors-negative300.bin ./word2vec/
+```
+
+- Preparing the data
+
+```bash
+  # split the dataset with the following commands.
+  mkdir -p data/test && mkdir -p data/train
+  python data_helpers.py --task dataset_split --data_dir dataset_dir
+```
+
+- Modify line 173 of `mindspore/train/model.py` in your MindSpore installation to add "O3":
+
+```python
+  self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level in ["O2", "O3"])
+```
+
+- Running on Ascend
+
+```bash
+# run training
+DEVICE_ID=7 python train.py
+# or you can use the shell script to train in the background
+bash scripts/run_train.sh
+
+# run evaluation
+DEVICE_ID=7 python eval.py --ckpt_path ./ckpt/lstm-10_149.ckpt
+# or you can use the shell script to evaluate in the background
+bash scripts/run_eval.sh
+```
+
+## [Script Description](#contents)
+
+### [Script and Sample Code](#contents)
+
+```text
+├── model_zoo
+    ├── README.md                          // descriptions about all the models
+    ├── textrcnn
+        ├── readme.md                      // descriptions about TextRCNN
+        ├── data_src
+        │   ├── rt-polaritydata            // directory to save the source data
+        │   ├── rt-polaritydata.README.1.0.txt  // readme file of the dataset
+        ├── scripts
+        │   ├── run_train.sh               // shell script for training on Ascend
+        │   ├── run_eval.sh                // shell script for evaluation on Ascend
+        ├── src
+        │   ├── dataset.py                 // dataset creation
+        │   ├── textrcnn.py                // textrcnn architecture
+        │   ├── config.py                  // parameter configuration
+        ├── train.py                       // training script
+        ├── eval.py                        // evaluation script
+        ├── data_helpers.py                // dataset split script
+        ├── sample.txt                     // example commands to train and evaluate the model without the scripts
+```
+
+### [Script Parameters](#contents)
+
+Parameters for both training and evaluation can be set in `config.py`.
+
+- config for TextRCNN, Sentence polarity dataset v1.0.
+
+  ```python
+  'num_epochs': 10,                  # total training epochs
+  'batch_size': 64,                  # training batch size
+  'cell': 'lstm',                    # the RNN architecture, can be 'vanilla', 'gru' or 'lstm'
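+                                     # note: checkpoint files are prefixed with this cell name, e.g. ./ckpt/lstm-10_149.ckpt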
+  'opt': 'adam',                     # the optimizer strategy, can be 'adam' or 'momentum'
+  'ckpt_folder_path': './ckpt',      # the path to save the checkpoints
+  'preprocess_path': './preprocess', # the directory to save the processed data
+  'preprocess': 'false',             # whether to preprocess the data
+  'data_path': './data/',            # the path to the split data
+  'lr': 1e-3,                        # the training learning rate
+  'emb_path': './word2vec',          # the directory of the embedding file
+  'embed_size': 300,                 # the dimension of the word embedding
+  'save_checkpoint_steps': 149,      # interval (in steps) between checkpoint saves
+  'keep_checkpoint_max': 10,         # maximum number of checkpoints to keep
+  'momentum': 0.9                    # the momentum rate
+  ```
+
+### Performance
+
+| Model      | MindSpore + Ascend             | TensorFlow + GPU               |
+| ---------- | ------------------------------ | ------------------------------ |
+| Resource   | Ascend 910                     | NV SMX2 V100-32G               |
+| Version    | 1.0.1                          | 1.4.0                          |
+| Dataset    | Sentence polarity dataset v1.0 | Sentence polarity dataset v1.0 |
+| batch_size | 64                             | 64                             |
+| Accuracy   | 0.78                           | 0.78                           |
+| Speed      | 78 ms/step                     | 89 ms/step                     |
+
+## [ModelZoo Homepage](#contents)
+
+Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
diff --git a/model_zoo/official/nlp/textrcnn/sample.txt b/model_zoo/official/nlp/textrcnn/sample.txt
new file mode 100644
index 0000000000..1f742cc341
--- /dev/null
+++ b/model_zoo/official/nlp/textrcnn/sample.txt
@@ -0,0 +1,2 @@
+DEVICE_ID=7 python train.py
+DEVICE_ID=7 python eval.py --ckpt_path ./ckpt/lstm-1_149.ckpt
\ No newline at end of file
diff --git a/model_zoo/official/nlp/textrcnn/scripts/run_eval.sh b/model_zoo/official/nlp/textrcnn/scripts/run_eval.sh
new file mode 100644
index 0000000000..19c38957c4
--- /dev/null
+++ b/model_zoo/official/nlp/textrcnn/scripts/run_eval.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+ulimit -u unlimited
+
+BASEPATH=$(cd "`dirname $0`" || exit; pwd)
+export PYTHONPATH=${BASEPATH}:$PYTHONPATH
+python ${BASEPATH}/../eval.py --ckpt_path $1 > ./eval.log 2>&1 &
diff --git a/model_zoo/official/nlp/textrcnn/scripts/run_train.sh b/model_zoo/official/nlp/textrcnn/scripts/run_train.sh
new file mode 100644
index 0000000000..5e87829a5b
--- /dev/null
+++ b/model_zoo/official/nlp/textrcnn/scripts/run_train.sh
@@ -0,0 +1,21 @@
+#!/bin/bash
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +ulimit -u unlimited + +BASEPATH=$(cd "`dirname $0`" || exit; pwd) +export PYTHONPATH=${BASEPATH}:$PYTHONPATH +python ${BASEPATH}/../train.py > ./train.log 2>&1 & \ No newline at end of file diff --git a/model_zoo/official/nlp/textrcnn/src/config.py b/model_zoo/official/nlp/textrcnn/src/config.py new file mode 100644 index 0000000000..eacda81a19 --- /dev/null +++ b/model_zoo/official/nlp/textrcnn/src/config.py @@ -0,0 +1,38 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +network config +""" +from easydict import EasyDict as edict + +# LSTM CONFIG +textrcnn_cfg = edict({ + 'pos_dir': 'data/rt-polaritydata/rt-polarity.pos', + 'neg_dir': 'data/rt-polaritydata/rt-polarity.neg', + 'num_epochs': 10, + 'batch_size': 64, + 'cell': 'lstm', + 'opt': 'adam', + 'ckpt_folder_path': './ckpt', + 'preprocess_path': './preprocess', + 'preprocess': 'false', + 'data_path': './data/', + 'lr': 1e-3, + 'emb_path': './word2vec', + 'embed_size': 300, + 'save_checkpoint_steps': 149, + 'keep_checkpoint_max': 10, + 'momentum': 0.9 +}) diff --git a/model_zoo/official/nlp/textrcnn/src/dataset.py b/model_zoo/official/nlp/textrcnn/src/dataset.py new file mode 100644 index 0000000000..a793cf794b --- /dev/null +++ b/model_zoo/official/nlp/textrcnn/src/dataset.py @@ -0,0 +1,179 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""dataset api""" +import os +from itertools import chain +import gensim +import numpy as np + +from mindspore.mindrecord import FileWriter +import mindspore.dataset as ds + +# preprocess part +def encode_samples(tokenized_samples, word_to_idx): + """ encode word to index """ + features = [] + for sample in tokenized_samples: + feature = [] + for token in sample: + if token in word_to_idx: + feature.append(word_to_idx[token]) + else: + feature.append(0) + features.append(feature) + return features + + +def pad_samples(features, maxlen=50, pad=0): + """ pad all features to the same length """ + padded_features = [] + for feature in features: + if len(feature) >= maxlen: + padded_feature = feature[:maxlen] + else: + padded_feature = feature + while len(padded_feature) < maxlen: + padded_feature.append(pad) + padded_features.append(padded_feature) + return padded_features + + +def read_imdb(path, seg='train'): + """ read imdb dataset """ + pos_or_neg = ['pos', 'neg'] + data = [] + for label in pos_or_neg: + + f = os.path.join(path, seg, label) + rf = open(f, 'r') + for line in rf: + line = line.strip() + if label == 'pos': + data.append([line, 1]) + elif label == 'neg': + data.append([line, 0]) + + return data + + +def tokenizer(text): + return [tok.lower() for tok in text.split(' ')] + + +def collect_weight(glove_path, vocab, word_to_idx, embed_size): + """ collect weight """ + vocab_size = len(vocab) + # wvmodel = gensim.models.KeyedVectors.load_word2vec_format(os.path.join(glove_path, 'glove.6B.300d.txt'), + # binary=False, encoding='utf-8') + wvmodel = gensim.models.KeyedVectors.load_word2vec_format(os.path.join(glove_path, \ + 'GoogleNews-vectors-negative300.bin'), binary=True) + weight_np = np.zeros((vocab_size + 1, embed_size)).astype(np.float32) + + idx_to_word = {i + 1: word for i, word in enumerate(vocab)} + idx_to_word[0] = '' + + for i in range(len(wvmodel.index2word)): + try: + index = word_to_idx[wvmodel.index2word[i]] + except KeyError: + continue + weight_np[index, :] = wvmodel.get_vector( + idx_to_word[word_to_idx[wvmodel.index2word[i]]]) + return weight_np + + +def preprocess(data_path, glove_path, embed_size): + """ preprocess the train and test data """ + train_data = read_imdb(data_path, 'train') + test_data = read_imdb(data_path, 'test') + + train_tokenized = [] + test_tokenized = [] + for review, _ in train_data: + train_tokenized.append(tokenizer(review)) + for review, _ in test_data: + test_tokenized.append(tokenizer(review)) + + vocab = set(chain(*train_tokenized)) + vocab_size = len(vocab) + print("vocab_size: ", vocab_size) + + word_to_idx = {word: i + 1 for i, word in enumerate(vocab)} + word_to_idx[''] = 0 + + train_features = np.array(pad_samples(encode_samples(train_tokenized, word_to_idx))).astype(np.int32) + train_labels = np.array([score for _, score in train_data]).astype(np.int32) + test_features = np.array(pad_samples(encode_samples(test_tokenized, word_to_idx))).astype(np.int32) + test_labels = np.array([score for _, score in test_data]).astype(np.int32) + + weight_np = collect_weight(glove_path, vocab, word_to_idx, embed_size) + return train_features, train_labels, test_features, test_labels, weight_np, vocab_size + + +def get_imdb_data(labels_data, features_data): + data_list = [] + for i, (label, feature) in enumerate(zip(labels_data, features_data)): + data_json = {"id": i, + "label": int(label), + "feature": feature.reshape(-1)} + data_list.append(data_json) + return 
data_list + + +def convert_to_mindrecord(embed_size, data_path, proprocess_path, glove_path): + """ convert imdb dataset to mindrecord """ + + num_shard = 4 + train_features, train_labels, test_features, test_labels, weight_np, _ = \ + preprocess(data_path, glove_path, embed_size) + np.savetxt(os.path.join(proprocess_path, 'weight.txt'), weight_np) + + print("train_features.shape:", train_features.shape, "train_labels.shape:", train_labels.shape, "weight_np.shape:",\ + weight_np.shape, "type:", train_labels.dtype) + # write mindrecord + schema_json = {"id": {"type": "int32"}, + "label": {"type": "int32"}, + "feature": {"type": "int32", "shape": [-1]}} + + writer = FileWriter(os.path.join(proprocess_path, 'aclImdb_train.mindrecord'), num_shard) + data = get_imdb_data(train_labels, train_features) + writer.add_schema(schema_json, "nlp_schema") + writer.add_index(["id", "label"]) + writer.write_raw_data(data) + writer.commit() + + writer = FileWriter(os.path.join(proprocess_path, 'aclImdb_test.mindrecord'), num_shard) + data = get_imdb_data(test_labels, test_features) + writer.add_schema(schema_json, "nlp_schema") + writer.add_index(["id", "label"]) + writer.write_raw_data(data) + writer.commit() + + +def create_dataset(base_path, batch_size, num_epochs, is_train): + """Create dataset for training.""" + columns_list = ["feature", "label"] + num_consumer = 4 + + if is_train: + path = os.path.join(base_path, 'aclImdb_train.mindrecord0') + else: + path = os.path.join(base_path, 'aclImdb_test.mindrecord0') + + data_set = ds.MindDataset(path, columns_list, num_consumer) + ds.config.set_seed(1) + data_set = data_set.shuffle(buffer_size=data_set.get_dataset_size()) + data_set = data_set.batch(batch_size, drop_remainder=True) + return data_set diff --git a/model_zoo/official/nlp/textrcnn/src/textrcnn.py b/model_zoo/official/nlp/textrcnn/src/textrcnn.py new file mode 100644 index 0000000000..6b3fb00e92 --- /dev/null +++ b/model_zoo/official/nlp/textrcnn/src/textrcnn.py @@ -0,0 +1,196 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""model textrcnn""" +import numpy as np + +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.common.parameter import Parameter +from mindspore import Tensor +from mindspore.common import dtype as mstype + +class textrcnn(nn.Cell): + """class textrcnn""" + def __init__(self, weight, vocab_size, cell, batch_size): + super(textrcnn, self).__init__() + self.num_hiddens = 512 + self.embed_size = 300 + self.num_classes = 2 + self.batch_size = batch_size + k = (1 / self.num_hiddens) ** 0.5 + + self.embedding = nn.Embedding(vocab_size, self.embed_size, embedding_table=weight) + self.embedding.embedding_table.requires_grad = False + self.cell = cell + + self.cast = P.Cast() + + self.h1 = Tensor(np.zeros(shape=(self.batch_size, self.num_hiddens)).astype(np.float16)) + self.c1 = Tensor(np.zeros(shape=(self.batch_size, self.num_hiddens)).astype(np.float16)) + + if cell == "lstm": + self.lstm = P.DynamicRNN(forget_bias=0.0) + self.w1_fw = Parameter( + np.random.uniform(-k, k, (self.embed_size + self.num_hiddens, 4 * self.num_hiddens)).astype( + np.float16), name="w1_fw") + self.b1_fw = Parameter(np.random.uniform(-k, k, (4 * self.num_hiddens)).astype(np.float16), + name="b1_fw") + self.w1_bw = Parameter( + np.random.uniform(-k, k, (self.embed_size + self.num_hiddens, 4 * self.num_hiddens)).astype( + np.float16), name="w1_bw") + self.b1_bw = Parameter(np.random.uniform(-k, k, (4 * self.num_hiddens)).astype(np.float16), + name="b1_bw") + self.h1 = Tensor(np.zeros(shape=(1, self.batch_size, self.num_hiddens)).astype(np.float16)) + self.c1 = Tensor(np.zeros(shape=(1, self.batch_size, self.num_hiddens)).astype(np.float16)) + + if cell == "vanilla": + self.rnnW_fw = nn.Dense(self.num_hiddens, self.num_hiddens) + self.rnnU_fw = nn.Dense(self.embed_size, self.num_hiddens) + self.rnnW_bw = nn.Dense(self.num_hiddens, self.num_hiddens) + self.rnnU_bw = nn.Dense(self.embed_size, self.num_hiddens) + + if cell == "gru": + self.rnnWr_fw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens) + self.rnnWz_fw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens) + self.rnnWh_fw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens) + self.rnnWr_bw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens) + self.rnnWz_bw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens) + self.rnnWh_bw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens) + self.ones = Tensor(np.ones(shape=(self.batch_size, self.num_hiddens)).astype(np.float16)) + + self.transpose = P.Transpose() + self.reduce_max = P.ReduceMax() + self.expand_dims = P.ExpandDims() + self.concat = P.Concat() + + self.reshape = P.Reshape() + self.left_pad_tensor = Tensor(np.zeros((1, self.batch_size, self.num_hiddens)).astype(np.float16)) + self.right_pad_tensor = Tensor(np.zeros((1, self.batch_size, self.num_hiddens)).astype(np.float16)) + self.output_dense = nn.Dense(self.num_hiddens * 1, 2) + self.concat0 = P.Concat(0) + self.concat2 = P.Concat(2) + self.concat1 = P.Concat(1) + self.text_rep_dense = nn.Dense(2 * self.num_hiddens + self.embed_size, self.num_hiddens) + self.mydense = nn.Dense(self.num_hiddens, 2) + self.drop_out = nn.Dropout(keep_prob=0.7) + self.tanh = P.Tanh() + self.sigmoid = P.Sigmoid() + self.slice = P.Slice() + # self.lstm = nn.LSTM(input_size=input_size,hidden_size=hidden_size,num_layers=num_layers,has_bias=has_bias, 
batch_first=batch_first, bidirectional=bidirectional, dropout=0.0) + + def construct(self, x): + """class construction""" + # x: bs, sl + output_fw = x + output_bw = x + + if self.cell == "vanilla": + x = self.embedding(x) # bs, sl, emb_size + x = self.cast(x, mstype.float16) + x = self.transpose(x, (1, 0, 2)) # sl, bs, emb_size + x = self.drop_out(x) # sl,bs, emb_size + + h1_fw = self.cast(self.h1, mstype.float16) # bs, num_hidden + h1_fw = self.tanh(self.rnnW_fw(h1_fw) + self.rnnU_fw(x[0, :, :])) # bs, num_hidden + output_fw = self.expand_dims(h1_fw, 0) # 1, bs, num_hidden + + for i in range(1, F.shape(x)[0]): + h1_fw = self.tanh(self.rnnW_fw(h1_fw) + self.rnnU_fw(x[i, :, :])) # 1, bs, num_hidden + h1_after_expand_fw = self.expand_dims(h1_fw, 0) + output_fw = self.concat((output_fw, h1_after_expand_fw)) # 2/3/4.., bs, num_hidden + output_fw = self.cast(output_fw, mstype.float16) # sl, bs, num_hidden + + h1_bw = self.cast(self.h1, mstype.float16) # bs, num_hidden + h1_bw = self.tanh(self.rnnW_bw(h1_bw) + self.rnnU_bw(x[F.shape(x)[0] - 1, :, :])) # bs, num_hidden + output_bw = self.expand_dims(h1_bw, 0) # 1, bs, num_hidden + + for i in range(F.shape(x)[0] - 2, -1, -1): + h1_bw = self.tanh(self.rnnW_bw(h1_bw) + self.rnnU_bw(x[i, :, :])) # 1, bs, num_hidden + h1_after_expand_bw = self.expand_dims(h1_bw, 0) + output_bw = self.concat((h1_after_expand_bw, output_bw)) # 2/3/4.., bs, num_hidden + output_bw = self.cast(output_bw, mstype.float16) # sl, bs, num_hidden + + if self.cell == "gru": + x = self.embedding(x) # bs, sl, emb_size + x = self.cast(x, mstype.float16) + x = self.transpose(x, (1, 0, 2)) # sl, bs, emb_size + x = self.drop_out(x) # sl,bs, emb_size + + h_fw = self.cast(self.h1, mstype.float16) + + h_x_fw = self.concat1((h_fw, x[0, :, :])) + r_fw = self.sigmoid(self.rnnWr_fw(h_x_fw)) + z_fw = self.sigmoid(self.rnnWz_fw(h_x_fw)) + h_tilde_fw = self.tanh(self.rnnWh_fw(self.concat1((r_fw * h_fw, x[0, :, :])))) + h_fw = (self.ones - z_fw) * h_fw + z_fw * h_tilde_fw + output_fw = self.expand_dims(h_fw, 0) + + for i in range(1, F.shape(x)[0]): + h_x_fw = self.concat1((h_fw, x[i, :, :])) + r_fw = self.sigmoid(self.rnnWr_fw(h_x_fw)) + z_fw = self.sigmoid(self.rnnWz_fw(h_x_fw)) + h_tilde_fw = self.tanh(self.rnnWh_fw(self.concat1((r_fw * h_fw, x[i, :, :])))) + h_fw = (self.ones - z_fw) * h_fw + z_fw * h_tilde_fw + h_after_expand_fw = self.expand_dims(h_fw, 0) + output_fw = self.concat((output_fw, h_after_expand_fw)) + output_fw = self.cast(output_fw, mstype.float16) + + h_bw = self.cast(self.h1, mstype.float16) # bs, num_hidden + + h_x_bw = self.concat1((h_bw, x[F.shape(x)[0] - 1, :, :])) + r_bw = self.sigmoid(self.rnnWr_bw(h_x_bw)) + z_bw = self.sigmoid(self.rnnWz_bw(h_x_bw)) + h_tilde_bw = self.tanh(self.rnnWh_bw(self.concat1((r_bw * h_bw, x[F.shape(x)[0] - 1, :, :])))) + h_bw = (self.ones - z_bw) * h_bw + z_bw * h_tilde_bw + output_bw = self.expand_dims(h_bw, 0) + for i in range(F.shape(x)[0] - 2, -1, -1): + h_x_bw = self.concat1((h_bw, x[i, :, :])) + r_bw = self.sigmoid(self.rnnWr_bw(h_x_bw)) + z_bw = self.sigmoid(self.rnnWz_bw(h_x_bw)) + h_tilde_bw = self.tanh(self.rnnWh_bw(self.concat1((r_bw * h_bw, x[i, :, :])))) + h_bw = (self.ones - z_bw) * h_bw + z_bw * h_tilde_bw + h_after_expand_bw = self.expand_dims(h_bw, 0) + output_bw = self.concat((h_after_expand_bw, output_bw)) + output_bw = self.cast(output_bw, mstype.float16) + if self.cell == 'lstm': + x = self.embedding(x) # bs, sl, emb_size + x = self.cast(x, mstype.float16) + x = self.transpose(x, (1, 0, 2)) # sl, bs, emb_size + x = 
self.drop_out(x) # sl,bs, emb_size + + h1_fw_init = self.h1 # bs, num_hidden + c1_fw_init = self.c1 # bs, num_hidden + + _, output_fw, _, _, _, _, _, _ = self.lstm(x, self.w1_fw, self.b1_fw, None, h1_fw_init, c1_fw_init) + output_fw = self.cast(output_fw, mstype.float16) # sl, bs, num_hidden + + h1_bw_init = self.h1 # bs, num_hidden + c1_bw_init = self.c1 # bs, num_hidden + _, output_bw, _, _, _, _, _, _ = self.lstm(x, self.w1_bw, self.b1_bw, None, h1_bw_init, c1_bw_init) + output_bw = self.cast(output_bw, mstype.float16) # sl, bs, hidden + + c_left = self.concat0((self.left_pad_tensor, output_fw[:F.shape(x)[0] - 1])) # sl, bs, num_hidden + c_right = self.concat0((output_bw[1:], self.right_pad_tensor)) # sl, bs, num_hidden + output = self.concat2((c_left, self.cast(x, mstype.float16), c_right)) # sl, bs, 2*num_hidden+emb_size + output = self.cast(output, mstype.float16) + + output_flat = self.reshape(output, (F.shape(x)[0] * self.batch_size, 2 * self.num_hiddens + self.embed_size)) + output_dense = self.text_rep_dense(output_flat) # sl*bs, num_hidden + output_dense = self.tanh(output_dense) # sl*bs, num_hidden + output = self.reshape(output_dense, (F.shape(x)[0], self.batch_size, self.num_hiddens)) # sl, bs, num_hidden + output = self.reduce_max(output, 0) # bs, num_hidden + outputs = self.cast(self.mydense(output), mstype.float16) # bs, num_classes + return outputs diff --git a/model_zoo/official/nlp/textrcnn/train.py b/model_zoo/official/nlp/textrcnn/train.py new file mode 100644 index 0000000000..ec74ef4a9f --- /dev/null +++ b/model_zoo/official/nlp/textrcnn/train.py @@ -0,0 +1,74 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""model train script""" +import os +import shutil +import numpy as np + +import mindspore.nn as nn +import mindspore.context as context +from mindspore import Tensor +from mindspore.train import Model +from mindspore.nn.metrics import Accuracy +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor +from mindspore.common import set_seed + +from src.config import textrcnn_cfg as cfg +from src.dataset import create_dataset +from src.dataset import convert_to_mindrecord +from src.textrcnn import textrcnn + + +set_seed(1) + +if __name__ == '__main__': + + context.set_context( + mode=context.GRAPH_MODE, + save_graphs=False, + device_target="Ascend") + + device_id = int(os.getenv('DEVICE_ID')) + context.set_context(device_id=device_id) + + if cfg.preprocess == 'true': + print("============== Starting Data Pre-processing ==============") + if os.path.exists(cfg.preprocess_path): + shutil.rmtree(cfg.preprocess_path) + os.mkdir(cfg.preprocess_path) + convert_to_mindrecord(cfg.embed_size, cfg.data_path, cfg.preprocess_path, cfg.emb_path) + + embedding_table = np.loadtxt(os.path.join(cfg.preprocess_path, "weight.txt")).astype(np.float32) + + network = textrcnn(weight=Tensor(embedding_table), vocab_size=embedding_table.shape[0], \ + cell=cfg.cell, batch_size=cfg.batch_size) + + loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True) + if cfg.opt == "adam": + opt = nn.Adam(params=network.trainable_params(), learning_rate=cfg.lr) + elif cfg.opt == "momentum": + opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) + + loss_cb = LossMonitor() + model = Model(network, loss, opt, {'acc': Accuracy()}, amp_level="O3") + + print("============== Starting Training ==============") + ds_train = create_dataset(cfg.preprocess_path, cfg.batch_size, cfg.num_epochs, True) + config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, \ + keep_checkpoint_max=cfg.keep_checkpoint_max) + ckpoint_cb = ModelCheckpoint(prefix=cfg.cell, directory=cfg.ckpt_folder_path, config=config_ck) + model.train(cfg.num_epochs, ds_train, callbacks=[ckpoint_cb, loss_cb]) + print("train success") + \ No newline at end of file