@@ -0,0 +1,100 @@
# LSTM Example

## Description

This example shows how to train and evaluate an LSTM network for sentiment classification on the aclImdb dataset, using pre-trained GloVe word embeddings.

## Requirements

- Install [MindSpore](https://www.mindspore.cn/install/en).
- Download the dataset aclImdb_v1.

> Unzip the aclImdb_v1 dataset to any path you want and the folder structure should be as follows:
> ```
> .
> ├── train  # train dataset
> └── test   # infer dataset
> ```

- Download the GloVe file.

> Unzip the glove.6B.zip to any path you want and the folder structure should be as follows:
> ```
> .
> ├── glove.6B.100d.txt
> ├── glove.6B.200d.txt
> ├── glove.6B.300d.txt  # we will use this one later.
> └── glove.6B.50d.txt
> ```
> Add the following line at the beginning of `glove.6B.300d.txt`, so that gensim can load it in word2vec format (a small helper sketch is shown right below).
> This header line says that the file contains 400,000 words, each represented by a 300-dimensional word vector.
> ```
> 400000 300
> ```
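
For convenience, here is a minimal Python sketch of this step. It assumes the unzipped file sits at `./glove/glove.6B.300d.txt`; adjust the path to wherever you extracted it.

```
# Illustrative helper, not part of this example's scripts.
glove_file = "./glove/glove.6B.300d.txt"  # assumed location, adjust as needed

with open(glove_file, "r", encoding="utf-8") as f:
    content = f.read()

# word2vec text-format header: 400,000 vocabulary entries, 300-dimensional vectors
if not content.startswith("400000 300"):
    with open(glove_file, "w", encoding="utf-8") as f:
        f.write("400000 300\n" + content)
```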

## Running the Example

### Training

```
python train.py --preprocess=true --aclimdb_path=your_imdb_path --glove_path=your_glove_path > out.train.log 2>&1 &
```

The python command above runs in the background; you can view the results through the file `out.train.log`.
`--preprocess=true` only needs to be set on the first run: it converts the aclImdb data to MindRecord files and saves the GloVe embedding table as `weight.txt` under the preprocess path, so later runs can reuse them.
After training, you'll get some checkpoint files under the script folder by default.

You will get the loss value as follows:

```
# grep "loss is " out.train.log
epoch: 1 step: 390, loss is 0.6003723
epoch: 2 step: 390, loss is 0.35312173
...
```

### Evaluation

```
python eval.py --ckpt_path=./lstm-20-390.ckpt > out.eval.log 2>&1 &
```

The above python command runs in the background; you can view the results through the file `out.eval.log`.

You will get the accuracy as follows:

```
# grep "acc" out.eval.log
result: {'acc': 0.83}
```

## Usage

### Training

```
usage: train.py [--preprocess {true,false}] [--aclimdb_path ACLIMDB_PATH]
                [--glove_path GLOVE_PATH] [--preprocess_path PREPROCESS_PATH]
                [--ckpt_path CKPT_PATH] [--device_target {GPU,CPU}]

parameters/options:
  --preprocess        whether to preprocess data.
  --aclimdb_path      path where the dataset is stored.
  --glove_path        path where the GloVe is stored.
  --preprocess_path   path where the pre-process data is stored.
  --ckpt_path         the path to save the checkpoint file.
  --device_target     the target device to run, support "GPU", "CPU".
```

### Evaluation

```
usage: eval.py [--preprocess {true,false}] [--aclimdb_path ACLIMDB_PATH]
               [--glove_path GLOVE_PATH] [--preprocess_path PREPROCESS_PATH]
               [--ckpt_path CKPT_PATH] [--device_target {GPU,CPU}]

parameters/options:
  --preprocess        whether to preprocess data.
  --aclimdb_path      path where the dataset is stored.
  --glove_path        path where the GloVe is stored.
  --preprocess_path   path where the pre-process data is stored.
  --ckpt_path         the checkpoint file path used to evaluate model.
  --device_target     the target device to run, support "GPU", "CPU".
```
@@ -0,0 +1,33 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting
"""
from easydict import EasyDict as edict

# LSTM CONFIG
lstm_cfg = edict({
    'num_classes': 2,
    'learning_rate': 0.1,
    'momentum': 0.9,
    'num_epochs': 20,
    'batch_size': 64,
    'embed_size': 300,
    'num_hiddens': 100,
    'num_layers': 2,
    'bidirectional': True,
    'save_checkpoint_steps': 390,
    'keep_checkpoint_max': 10
})
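
# Note (illustrative usage): train.py and eval.py consume this dict directly, e.g.
#   from config import lstm_cfg as cfg
#   cfg.batch_size   # -> 64
#   cfg.embed_size   # -> 300, must match the GloVe dimension used for the embedding table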
@@ -0,0 +1,92 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, will be used in train.py and eval.py
"""
import os

import numpy as np
from imdb import ImdbParser

import mindspore.dataset as ds
from mindspore.mindrecord import FileWriter


def create_dataset(data_home, batch_size, repeat_num=1, training=True):
    """Data operations."""
    ds.config.set_seed(1)
    data_dir = os.path.join(data_home, "aclImdb_train.mindrecord0")
    if not training:
        data_dir = os.path.join(data_home, "aclImdb_test.mindrecord0")

    data_set = ds.MindDataset(data_dir, columns_list=["feature", "label"], num_parallel_workers=4)

    # apply dataset operations: shuffle, batch and repeat the text samples
    data_set = data_set.shuffle(buffer_size=data_set.get_dataset_size())
    data_set = data_set.batch(batch_size=batch_size, drop_remainder=True)
    data_set = data_set.repeat(count=repeat_num)

    return data_set


def _convert_to_mindrecord(data_home, features, labels, weight_np=None, training=True):
    """
    convert imdb dataset to mindrecord dataset
    """
    if weight_np is not None:
        np.savetxt(os.path.join(data_home, 'weight.txt'), weight_np)

    # write mindrecord
    schema_json = {"id": {"type": "int32"},
                   "label": {"type": "int32"},
                   "feature": {"type": "int32", "shape": [-1]}}

    data_dir = os.path.join(data_home, "aclImdb_train.mindrecord")
    if not training:
        data_dir = os.path.join(data_home, "aclImdb_test.mindrecord")

    def get_imdb_data(features, labels):
        data_list = []
        for i, (label, feature) in enumerate(zip(labels, features)):
            data_json = {"id": i,
                         "label": int(label),
                         "feature": feature.reshape(-1)}
            data_list.append(data_json)
        return data_list
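
    # With shard_num=4, FileWriter splits the output into four files named
    # "<data_dir>0" ... "<data_dir>3"; create_dataset() above opens the "...mindrecord0"
    # file and MindDataset loads the whole sharded set from that single file name.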
    writer = FileWriter(data_dir, shard_num=4)
    data = get_imdb_data(features, labels)
    writer.add_schema(schema_json, "nlp_schema")
    writer.add_index(["id", "label"])
    writer.write_raw_data(data)
    writer.commit()


def convert_to_mindrecord(embed_size, aclimdb_path, preprocess_path, glove_path):
    """
    convert imdb dataset to mindrecord dataset
    """
    parser = ImdbParser(aclimdb_path, glove_path, embed_size)
    parser.parse()

    if not os.path.exists(preprocess_path):
        print(f"preprocess path {preprocess_path} does not exist, creating it")
        os.makedirs(preprocess_path)

    train_features, train_labels, train_weight_np = parser.get_datas('train')
    _convert_to_mindrecord(preprocess_path, train_features, train_labels, train_weight_np)

    test_features, test_labels, _ = parser.get_datas('test')
    _convert_to_mindrecord(preprocess_path, test_features, test_labels, training=False)
@@ -0,0 +1,81 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################eval lstm example on aclImdb########################
python eval.py --ckpt_path=./lstm-20-390.ckpt
"""
import argparse
import os

import numpy as np
from config import lstm_cfg as cfg
from dataset import create_dataset, convert_to_mindrecord

from mindspore import Tensor, nn, Model, context
from mindspore.model_zoo.lstm import SentimentNet
from mindspore.nn import Accuracy
from mindspore.train.callback import LossMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='MindSpore LSTM Example')
    parser.add_argument('--preprocess', type=str, default='false', choices=['true', 'false'],
                        help='whether to preprocess data.')
    parser.add_argument('--aclimdb_path', type=str, default="./aclImdb",
                        help='path where the dataset is stored.')
    parser.add_argument('--glove_path', type=str, default="./glove",
                        help='path where the GloVe is stored.')
    parser.add_argument('--preprocess_path', type=str, default="./preprocess",
                        help='path where the pre-process data is stored.')
    parser.add_argument('--ckpt_path', type=str, default=None,
                        help='the checkpoint file path used to evaluate model.')
    parser.add_argument('--device_target', type=str, default="GPU", choices=['GPU', 'CPU'],
                        help='the target device to run, support "GPU", "CPU". Default: "GPU".')
    args = parser.parse_args()

    context.set_context(
        mode=context.GRAPH_MODE,
        save_graphs=False,
        device_target=args.device_target)

    if args.preprocess == "true":
        print("============== Starting Data Pre-processing ==============")
        convert_to_mindrecord(cfg.embed_size, args.aclimdb_path, args.preprocess_path, args.glove_path)

    embedding_table = np.loadtxt(os.path.join(args.preprocess_path, "weight.txt")).astype(np.float32)
    network = SentimentNet(vocab_size=embedding_table.shape[0],
                           embed_size=cfg.embed_size,
                           num_hiddens=cfg.num_hiddens,
                           num_layers=cfg.num_layers,
                           bidirectional=cfg.bidirectional,
                           num_classes=cfg.num_classes,
                           weight=Tensor(embedding_table),
                           batch_size=cfg.batch_size)

    loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
    opt = nn.Momentum(network.trainable_params(), cfg.learning_rate, cfg.momentum)
    loss_cb = LossMonitor()

    model = Model(network, loss, opt, {'acc': Accuracy()})

    print("============== Starting Testing ==============")
    ds_eval = create_dataset(args.preprocess_path, cfg.batch_size, training=False)
    param_dict = load_checkpoint(args.ckpt_path)
    load_param_into_net(network, param_dict)
    if args.device_target == "CPU":
        acc = model.eval(ds_eval, dataset_sink_mode=False)
    else:
        acc = model.eval(ds_eval)
    print("============== Accuracy:{} ==============".format(acc))
@@ -0,0 +1,155 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
imdb dataset parser.
"""
import os
from itertools import chain

import gensim
import numpy as np


class ImdbParser():
    """
    parse aclImdb data to features and labels.
    sentence->tokenized->encoded->padding->features
    """

    def __init__(self, imdb_path, glove_path, embed_size=300):
        self.__segs = ['train', 'test']
        self.__label_dic = {'pos': 1, 'neg': 0}
        self.__imdb_path = imdb_path
        self.__glove_dim = embed_size
        self.__glove_file = os.path.join(glove_path, 'glove.6B.' + str(self.__glove_dim) + 'd.txt')

        # properties
        self.__imdb_datas = {}
        self.__features = {}
        self.__labels = {}
        self.__vacab = {}
        self.__word2idx = {}
        self.__weight_np = {}
        self.__wvmodel = None

    def parse(self):
        """
        parse imdb data to memory
        """
        self.__wvmodel = gensim.models.KeyedVectors.load_word2vec_format(self.__glove_file)

        for seg in self.__segs:
            self.__parse_imdb_datas(seg)
            self.__parse_features_and_labels(seg)
            self.__gen_weight_np(seg)

    def __parse_imdb_datas(self, seg):
        """
        load data from txt
        """
        data_lists = []
        for label_name, label_id in self.__label_dic.items():
            sentence_dir = os.path.join(self.__imdb_path, seg, label_name)
            for file in os.listdir(sentence_dir):
                with open(os.path.join(sentence_dir, file), mode='r', encoding='utf8') as f:
                    sentence = f.read().replace('\n', '')
                    data_lists.append([sentence, label_id])
        self.__imdb_datas[seg] = data_lists

    def __parse_features_and_labels(self, seg):
        """
        parse features and labels
        """
        features = []
        labels = []
        for sentence, label in self.__imdb_datas[seg]:
            features.append(sentence)
            labels.append(label)

        self.__features[seg] = features
        self.__labels[seg] = labels

        # update feature to tokenized
        self.__updata_features_to_tokenized(seg)
        # parse vacab
        self.__parse_vacab(seg)
        # encode feature
        self.__encode_features(seg)
        # padding feature
        self.__padding_features(seg)

    def __updata_features_to_tokenized(self, seg):
        tokenized_features = []
        for sentence in self.__features[seg]:
            tokenized_sentence = [word.lower() for word in sentence.split(" ")]
            tokenized_features.append(tokenized_sentence)
        self.__features[seg] = tokenized_features

    def __parse_vacab(self, seg):
        # vocab
        tokenized_features = self.__features[seg]
        vocab = set(chain(*tokenized_features))
        self.__vacab[seg] = vocab

        # word_to_idx: {'hello': 1, 'world': 111, ... '<unk>': 0}
        word_to_idx = {word: i + 1 for i, word in enumerate(vocab)}
        word_to_idx['<unk>'] = 0
        self.__word2idx[seg] = word_to_idx

    def __encode_features(self, seg):
        """ encode word to index """
        word_to_idx = self.__word2idx['train']
        encoded_features = []
        for tokenized_sentence in self.__features[seg]:
            encoded_sentence = []
            for word in tokenized_sentence:
                encoded_sentence.append(word_to_idx.get(word, 0))
            encoded_features.append(encoded_sentence)
        self.__features[seg] = encoded_features

    def __padding_features(self, seg, maxlen=500, pad=0):
        """ pad all features to the same length """
        padded_features = []
        for feature in self.__features[seg]:
            if len(feature) >= maxlen:
                padded_feature = feature[:maxlen]
            else:
                padded_feature = feature
                while len(padded_feature) < maxlen:
                    padded_feature.append(pad)
            padded_features.append(padded_feature)
        self.__features[seg] = padded_features

    def __gen_weight_np(self, seg):
        """
        generate weight by gensim
        """
        weight_np = np.zeros((len(self.__word2idx[seg]), self.__glove_dim), dtype=np.float32)
        for word, idx in self.__word2idx[seg].items():
            if word not in self.__wvmodel:
                continue
            word_vector = self.__wvmodel.get_vector(word)
            weight_np[idx, :] = word_vector

        self.__weight_np[seg] = weight_np

    def get_datas(self, seg):
        """
        return features, labels, and weight
        """
        features = np.array(self.__features[seg]).astype(np.int32)
        labels = np.array(self.__labels[seg]).astype(np.int32)
        weight = np.array(self.__weight_np[seg])
        return features, labels, weight
@@ -0,0 +1,83 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################train lstm example on aclImdb########################
python train.py --preprocess=true --aclimdb_path=your_imdb_path --glove_path=your_glove_path
"""
import argparse
import os

import numpy as np
from config import lstm_cfg as cfg
from dataset import convert_to_mindrecord
from dataset import create_dataset

from mindspore import Tensor, nn, Model, context
from mindspore.model_zoo.lstm import SentimentNet
from mindspore.nn import Accuracy
from mindspore.train.callback import LossMonitor, CheckpointConfig, ModelCheckpoint, TimeMonitor

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='MindSpore LSTM Example')
    parser.add_argument('--preprocess', type=str, default='false', choices=['true', 'false'],
                        help='whether to preprocess data.')
    parser.add_argument('--aclimdb_path', type=str, default="./aclImdb",
                        help='path where the dataset is stored.')
    parser.add_argument('--glove_path', type=str, default="./glove",
                        help='path where the GloVe is stored.')
    parser.add_argument('--preprocess_path', type=str, default="./preprocess",
                        help='path where the pre-process data is stored.')
    parser.add_argument('--ckpt_path', type=str, default="./",
                        help='the path to save the checkpoint file.')
    parser.add_argument('--device_target', type=str, default="GPU", choices=['GPU', 'CPU'],
                        help='the target device to run, support "GPU", "CPU". Default: "GPU".')
    args = parser.parse_args()

    context.set_context(
        mode=context.GRAPH_MODE,
        save_graphs=False,
        device_target=args.device_target)

    if args.preprocess == "true":
        print("============== Starting Data Pre-processing ==============")
        convert_to_mindrecord(cfg.embed_size, args.aclimdb_path, args.preprocess_path, args.glove_path)
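
    # weight.txt is written by the preprocessing step above: each row is the GloVe vector of one
    # vocabulary word (row 0 is the '<unk>' token), so its row count is used as the vocab size below.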
    embedding_table = np.loadtxt(os.path.join(args.preprocess_path, "weight.txt")).astype(np.float32)
    network = SentimentNet(vocab_size=embedding_table.shape[0],
                           embed_size=cfg.embed_size,
                           num_hiddens=cfg.num_hiddens,
                           num_layers=cfg.num_layers,
                           bidirectional=cfg.bidirectional,
                           num_classes=cfg.num_classes,
                           weight=Tensor(embedding_table),
                           batch_size=cfg.batch_size)

    loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
    opt = nn.Momentum(network.trainable_params(), cfg.learning_rate, cfg.momentum)
    loss_cb = LossMonitor()

    model = Model(network, loss, opt, {'acc': Accuracy()})

    print("============== Starting Training ==============")
    ds_train = create_dataset(args.preprocess_path, cfg.batch_size, repeat_num=cfg.num_epochs)
    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix="lstm", directory=args.ckpt_path, config=config_ck)
    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
    if args.device_target == "CPU":
        model.train(cfg.num_epochs, ds_train, callbacks=[time_cb, ckpoint_cb, loss_cb], dataset_sink_mode=False)
    else:
        model.train(cfg.num_epochs, ds_train, callbacks=[time_cb, ckpoint_cb, loss_cb])
    print("============== Training Success ==============")
@@ -0,0 +1,115 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LSTM."""
import math

import numpy as np

from mindspore import Parameter, Tensor, nn
from mindspore.common.initializer import initializer
from mindspore.ops import operations as P


def init_lstm_weight(
        input_size,
        hidden_size,
        num_layers,
        bidirectional,
        has_bias=True):
    """Initialize lstm weight."""
    num_directions = 1
    if bidirectional:
        num_directions = 2

    weight_size = 0
    gate_size = 4 * hidden_size
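    # All layers and directions share one flat weight tensor; the loop below adds up, per layer
    # and direction, the input-to-hidden weights, the hidden-to-hidden weights and, when
    # has_bias is set, two bias vectors of gate_size elements each.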
    for layer in range(num_layers):
        for _ in range(num_directions):
            input_layer_size = input_size if layer == 0 else hidden_size * num_directions
            weight_size += gate_size * input_layer_size
            weight_size += gate_size * hidden_size
            if has_bias:
                weight_size += 2 * gate_size

    stdv = 1 / math.sqrt(hidden_size)
    w_np = np.random.uniform(-stdv, stdv, (weight_size, 1, 1)).astype(np.float32)
    w = Parameter(initializer(Tensor(w_np), [weight_size, 1, 1]), name='weight')
    return w


# Initialize short-term memory (h) and long-term memory (c) to 0
def lstm_default_state(batch_size, hidden_size, num_layers, bidirectional):
    """init default input."""
    num_directions = 1
    if bidirectional:
        num_directions = 2

    h = Tensor(
        np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
    c = Tensor(
        np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
    return h, c


class SentimentNet(nn.Cell):
    """Sentiment network structure."""

    def __init__(self,
                 vocab_size,
                 embed_size,
                 num_hiddens,
                 num_layers,
                 bidirectional,
                 num_classes,
                 weight,
                 batch_size):
        super(SentimentNet, self).__init__()
        # Map words to vectors
        self.embedding = nn.Embedding(vocab_size,
                                      embed_size,
                                      embedding_table=weight)
        self.embedding.embedding_table.requires_grad = False
        self.trans = P.Transpose()
        self.perm = (1, 0, 2)
        self.encoder = nn.LSTM(input_size=embed_size,
                               hidden_size=num_hiddens,
                               num_layers=num_layers,
                               has_bias=True,
                               bidirectional=bidirectional,
                               dropout=0.0)
        w_init = init_lstm_weight(
            embed_size,
            num_hiddens,
            num_layers,
            bidirectional)
        self.encoder.weight = w_init
        self.h, self.c = lstm_default_state(batch_size, num_hiddens, num_layers, bidirectional)

        self.concat = P.Concat(1)
        if bidirectional:
            self.decoder = nn.Dense(num_hiddens * 4, num_classes)
        else:
            self.decoder = nn.Dense(num_hiddens * 2, num_classes)

    def construct(self, inputs):
        # inputs: (64, 500) word indices -> embeddings: (64, 500, 300)
        embeddings = self.embedding(inputs)
        embeddings = self.trans(embeddings, self.perm)
        output, _ = self.encoder(embeddings, (self.h, self.c))
        # states[i] size(64,200) -> encoding.size(64,400)
        encoding = self.concat((output[0], output[1]))
        outputs = self.decoder(encoding)
        return outputs