Merge pull request !2144 from caojian05/ms_master_dev
@@ -21,8 +21,8 @@ import os
 import numpy as np
-from config import lstm_cfg as cfg
-from dataset import create_dataset, convert_to_mindrecord
+from src.config import lstm_cfg as cfg
+from src.dataset import lstm_create_dataset, convert_to_mindrecord
 from mindspore import Tensor, nn, Model, context
 from mindspore.model_zoo.lstm import SentimentNet
 from mindspore.nn import Accuracy
@@ -71,7 +71,7 @@ if __name__ == '__main__':
     model = Model(network, loss, opt, {'acc': Accuracy()})
     print("============== Starting Testing ==============")
-    ds_eval = create_dataset(args.preprocess_path, cfg.batch_size, training=False)
+    ds_eval = lstm_create_dataset(args.preprocess_path, cfg.batch_size, training=False)
     param_dict = load_checkpoint(args.ckpt_path)
     load_param_into_net(network, param_dict)
     if args.device_target == "CPU":
@@ -0,0 +1,14 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
@@ -19,12 +19,12 @@ import os
 import numpy as np
-from imdb import ImdbParser
 import mindspore.dataset as ds
 from mindspore.mindrecord import FileWriter
+from .imdb import ImdbParser
 
-def create_dataset(data_home, batch_size, repeat_num=1, training=True):
+def lstm_create_dataset(data_home, batch_size, repeat_num=1, training=True):
     """Data operations."""
     ds.config.set_seed(1)
     data_dir = os.path.join(data_home, "aclImdb_train.mindrecord0")
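Side note for reviewers: a minimal usage sketch of the renamed helper, assuming `convert_to_mindrecord` has already produced the MindRecord files (the path below is a placeholder, not part of this PR):

```python
# Hypothetical usage; "./preprocess" is a placeholder directory that must
# already contain the aclImdb MindRecord files written by convert_to_mindrecord.
from src.dataset import lstm_create_dataset

ds_train = lstm_create_dataset("./preprocess", batch_size=64, repeat_num=1, training=True)
print(ds_train.get_dataset_size())  # number of batches per epoch
```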
@@ -0,0 +1,93 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
| """LSTM.""" | |||
| import numpy as np | |||
| from mindspore import Tensor, nn, context | |||
| from mindspore.ops import operations as P | |||
| # Initialize short-term memory (h) and long-term memory (c) to 0 | |||
| def lstm_default_state(batch_size, hidden_size, num_layers, bidirectional): | |||
| """init default input.""" | |||
| num_directions = 1 | |||
| if bidirectional: | |||
| num_directions = 2 | |||
| if context.get_context("device_target") == "CPU": | |||
| h_list = [] | |||
| c_list = [] | |||
| i = 0 | |||
| while i < num_layers: | |||
| hi = Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)) | |||
| h_list.append(hi) | |||
| ci = Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)) | |||
| c_list.append(ci) | |||
| i = i + 1 | |||
| h = tuple(h_list) | |||
| c = tuple(c_list) | |||
| return h, c | |||
| h = Tensor( | |||
| np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32)) | |||
| c = Tensor( | |||
| np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32)) | |||
| return h, c | |||
| class SentimentNet(nn.Cell): | |||
| """Sentiment network structure.""" | |||
| def __init__(self, | |||
| vocab_size, | |||
| embed_size, | |||
| num_hiddens, | |||
| num_layers, | |||
| bidirectional, | |||
| num_classes, | |||
| weight, | |||
| batch_size): | |||
| super(SentimentNet, self).__init__() | |||
| # Mapp words to vectors | |||
| self.embedding = nn.Embedding(vocab_size, | |||
| embed_size, | |||
| embedding_table=weight) | |||
| self.embedding.embedding_table.requires_grad = False | |||
| self.trans = P.Transpose() | |||
| self.perm = (1, 0, 2) | |||
| self.encoder = nn.LSTM(input_size=embed_size, | |||
| hidden_size=num_hiddens, | |||
| num_layers=num_layers, | |||
| has_bias=True, | |||
| bidirectional=bidirectional, | |||
| dropout=0.0) | |||
| self.h, self.c = lstm_default_state(batch_size, num_hiddens, num_layers, bidirectional) | |||
| self.concat = P.Concat(1) | |||
| if bidirectional: | |||
| self.decoder = nn.Dense(num_hiddens * 4, num_classes) | |||
| else: | |||
| self.decoder = nn.Dense(num_hiddens * 2, num_classes) | |||
| def construct(self, inputs): | |||
| # input:(64,500,300) | |||
| embeddings = self.embedding(inputs) | |||
| embeddings = self.trans(embeddings, self.perm) | |||
| output, _ = self.encoder(embeddings, (self.h, self.c)) | |||
| # states[i] size(64,200) -> encoding.size(64,400) | |||
| encoding = self.concat((output[0], output[-1])) | |||
| outputs = self.decoder(encoding) | |||
| return outputs | |||
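For reviewers, a minimal instantiation sketch of `SentimentNet`; the sizes and the random embedding table below are illustrative placeholders, not values fixed by this PR:

```python
# Illustrative only: in the real pipeline the embedding table comes from
# pre-trained GloVe vectors, and vocab_size matches the IMDB vocabulary.
import numpy as np
from mindspore import Tensor

vocab_size, embed_size = 20000, 300
table = Tensor(np.random.uniform(-1, 1, (vocab_size, embed_size)).astype(np.float32))
net = SentimentNet(vocab_size=vocab_size, embed_size=embed_size, num_hiddens=100,
                   num_layers=2, bidirectional=True, num_classes=2,
                   weight=table, batch_size=64)
```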
@@ -21,9 +21,9 @@ import os
 import numpy as np
-from config import lstm_cfg as cfg
-from dataset import convert_to_mindrecord
-from dataset import create_dataset
+from src.config import lstm_cfg as cfg
+from src.dataset import convert_to_mindrecord
+from src.dataset import lstm_create_dataset
 from mindspore import Tensor, nn, Model, context
 from mindspore.model_zoo.lstm import SentimentNet
 from mindspore.nn import Accuracy
@@ -71,7 +71,7 @@ if __name__ == '__main__':
     model = Model(network, loss, opt, {'acc': Accuracy()})
     print("============== Starting Training ==============")
-    ds_train = create_dataset(args.preprocess_path, cfg.batch_size, cfg.num_epochs)
+    ds_train = lstm_create_dataset(args.preprocess_path, cfg.batch_size, cfg.num_epochs)
     config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
                                  keep_checkpoint_max=cfg.keep_checkpoint_max)
     ckpoint_cb = ModelCheckpoint(prefix="lstm", directory=args.ckpt_path, config=config_ck)
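A hypothetical launch of this training script; the flag names mirror the `args` attributes referenced in the hunks above, and the paths are placeholders:

```bash
python train.py --preprocess_path=./preprocess --ckpt_path=./ckpt --device_target=GPU
```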
@@ -98,7 +98,7 @@ parameters/options:
 ### Distribute Training
 
 ```
-Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATA_PATH]
+Usage: sh script/run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATA_PATH]
 
 parameters/options:
     MINDSPORE_HCCL_CONFIG_PATH   HCCL configuration file path.
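For example (both arguments are placeholders):

    sh script/run_distribute_train.sh /path/to/hccl_8p.json /path/to/cifar-10-batches-bin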
@@ -17,14 +17,15 @@
 python eval.py --data_path=$DATA_HOME --device_id=$DEVICE_ID
 """
 import argparse
 import mindspore.nn as nn
+from mindspore import context
 from mindspore.nn.optim.momentum import Momentum
 from mindspore.train.model import Model
-from mindspore import context
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.model_zoo.vgg import vgg16
-from config import cifar_cfg as cfg
-import dataset
+from src.config import cifar_cfg as cfg
+from src.dataset import vgg_create_dataset
+from src.vgg import vgg16
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Cifar10 classification')
@@ -47,6 +48,6 @@ if __name__ == '__main__':
     param_dict = load_checkpoint(args_opt.checkpoint_path)
     load_param_into_net(net, param_dict)
     net.set_train(False)
-    dataset = dataset.create_dataset(args_opt.data_path, 1, False)
+    dataset = vgg_create_dataset(args_opt.data_path, 1, False)
     res = model.eval(dataset)
     print("result: ", res)
@@ -15,39 +15,38 @@
 # ============================================================================
 
 if [ $# != 2 ]
 then
     echo "Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATA_PATH]"
     exit 1
 fi
 
 if [ ! -f $1 ]
 then
     echo "error: MINDSPORE_HCCL_CONFIG_PATH=$1 is not a file"
     exit 1
 fi
 
 if [ ! -d $2 ]
 then
     echo "error: DATA_PATH=$2 is not a directory"
     exit 1
 fi
 
 ulimit -u unlimited
 export DEVICE_NUM=8
 export RANK_SIZE=8
 export MINDSPORE_HCCL_CONFIG_PATH=$1
 
-for((i=0; i<${DEVICE_NUM}; i++))
+for((i=0;i<RANK_SIZE;i++))
 do
     export DEVICE_ID=$i
     export RANK_ID=$i
     rm -rf ./train_parallel$i
     mkdir ./train_parallel$i
     cp *.py ./train_parallel$i
     cp *.sh ./train_parallel$i
+    cp -r src ./train_parallel$i
     cd ./train_parallel$i || exit
     echo "start training for rank $RANK_ID, device $DEVICE_ID"
     env > env.log
     python train.py --data_path=$2 --device_id=$i &> log &
     cd ..
 done
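Each rank trains in its own working directory, so a quick health check (assuming the layout created by the loop above) is to follow one rank's log:

```bash
# Each rank i writes to ./train_parallel$i/log; follow rank 0 as a smoke test
tail -f train_parallel0/log
```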
@@ -0,0 +1,14 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
@@ -16,13 +16,15 @@
 Data operations, will be used in train.py and eval.py
 """
 import os
+import mindspore.common.dtype as mstype
 import mindspore.dataset as ds
 import mindspore.dataset.transforms.c_transforms as C
 import mindspore.dataset.transforms.vision.c_transforms as vision
-import mindspore.common.dtype as mstype
-from config import cifar_cfg as cfg
+from .config import cifar_cfg as cfg
 
-def create_dataset(data_home, repeat_num=1, training=True):
+def vgg_create_dataset(data_home, repeat_num=1, training=True):
     """Data operations."""
     ds.config.set_seed(1)
     data_dir = os.path.join(data_home, "cifar-10-batches-bin")
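As with the LSTM helper, a minimal usage sketch of the rename; the path is a placeholder and must contain the extracted `cifar-10-batches-bin` directory:

```python
# Hypothetical call; "./cifar10_data" is a placeholder data root.
from src.dataset import vgg_create_dataset

ds_train = vgg_create_dataset("./cifar10_data", repeat_num=1, training=True)
```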
@@ -0,0 +1,104 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
| """VGG.""" | |||
| import mindspore.nn as nn | |||
| from mindspore.common.initializer import initializer | |||
| import mindspore.common.dtype as mstype | |||
| def _make_layer(base, batch_norm): | |||
| """Make stage network of VGG.""" | |||
| layers = [] | |||
| in_channels = 3 | |||
| for v in base: | |||
| if v == 'M': | |||
| layers += [nn.MaxPool2d(kernel_size=2, stride=2)] | |||
| else: | |||
| weight_shape = (v, in_channels, 3, 3) | |||
| weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32).to_tensor() | |||
| conv2d = nn.Conv2d(in_channels=in_channels, | |||
| out_channels=v, | |||
| kernel_size=3, | |||
| padding=0, | |||
| pad_mode='same', | |||
| weight_init=weight) | |||
| if batch_norm: | |||
| layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU()] | |||
| else: | |||
| layers += [conv2d, nn.ReLU()] | |||
| in_channels = v | |||
| return nn.SequentialCell(layers) | |||
| class Vgg(nn.Cell): | |||
| """ | |||
| VGG network definition. | |||
| Args: | |||
| base (list): Configuration for different layers, mainly the channel number of Conv layer. | |||
| num_classes (int): Class numbers. Default: 1000. | |||
| batch_norm (bool): Whether to do the batchnorm. Default: False. | |||
| batch_size (int): Batch size. Default: 1. | |||
| Returns: | |||
| Tensor, infer output tensor. | |||
| Examples: | |||
| >>> Vgg([64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], | |||
| >>> num_classes=1000, batch_norm=False, batch_size=1) | |||
| """ | |||
| def __init__(self, base, num_classes=1000, batch_norm=False, batch_size=1): | |||
| super(Vgg, self).__init__() | |||
| _ = batch_size | |||
| self.layers = _make_layer(base, batch_norm=batch_norm) | |||
| self.flatten = nn.Flatten() | |||
| self.classifier = nn.SequentialCell([ | |||
| nn.Dense(512 * 7 * 7, 4096), | |||
| nn.ReLU(), | |||
| nn.Dense(4096, 4096), | |||
| nn.ReLU(), | |||
| nn.Dense(4096, num_classes)]) | |||
| def construct(self, x): | |||
| x = self.layers(x) | |||
| x = self.flatten(x) | |||
| x = self.classifier(x) | |||
| return x | |||
| cfg = { | |||
| '11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], | |||
| '13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], | |||
| '16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], | |||
| '19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], | |||
| } | |||
| def vgg16(num_classes=1000): | |||
| """ | |||
| Get Vgg16 neural network with batch normalization. | |||
| Args: | |||
| num_classes (int): Class numbers. Default: 1000. | |||
| Returns: | |||
| Cell, cell instance of Vgg16 neural network with batch normalization. | |||
| Examples: | |||
| >>> vgg16(num_classes=1000) | |||
| """ | |||
| net = Vgg(cfg['16'], num_classes=num_classes, batch_norm=True) | |||
| return net | |||
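A quick shape sanity check for the network above; the batch size and class count are illustrative:

```python
# Smoke-test sketch: the five 'M' pools reduce a 224x224 input to 7x7,
# matching the 512 * 7 * 7 input expected by the classifier head.
import numpy as np
from mindspore import Tensor

net = vgg16(num_classes=10)
x = Tensor(np.random.randn(1, 3, 224, 224).astype(np.float32))
out = net(x)  # expected shape: (1, 10)
```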
@@ -19,20 +19,24 @@ python train.py --data_path=$DATA_HOME --device_id=$DEVICE_ID
 import argparse
 import os
 import random
 import numpy as np
 import mindspore.nn as nn
 from mindspore import Tensor
+from mindspore import context
 from mindspore.communication.management import init
 from mindspore.nn.optim.momentum import Momentum
-from mindspore.train.model import Model, ParallelMode
-from mindspore import context
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
-from mindspore.model_zoo.vgg import vgg16
-from dataset import create_dataset
-from config import cifar_cfg as cfg
+from mindspore.train.model import Model, ParallelMode
+from src.config import cifar_cfg as cfg
+from src.dataset import vgg_create_dataset
+from src.vgg import vgg16
 
 random.seed(1)
 np.random.seed(1)
 
 def lr_steps(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None):
     """Set learning rate."""
     lr_each_step = []
@@ -72,12 +76,13 @@ if __name__ == '__main__':
                                           mirror_mean=True)
         init()
 
-    dataset = create_dataset(args_opt.data_path, cfg.epoch_size)
+    dataset = vgg_create_dataset(args_opt.data_path, cfg.epoch_size)
     batch_num = dataset.get_dataset_size()
 
     net = vgg16(num_classes=cfg.num_classes)
     lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, steps_per_epoch=batch_num)
-    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum, weight_decay=cfg.weight_decay)
+    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum,
+                   weight_decay=cfg.weight_decay)
     loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
     model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
                   amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None)