Merge pull request !7909 from VectorSL/nhwc-scripttags/v1.1.0
| @@ -133,17 +133,20 @@ sh run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [C | |||
| ├── run_distribute_train_gpu.sh # launch gpu distributed training(8 pcs) | |||
| ├── run_parameter_server_train_gpu.sh # launch gpu parameter server training(8 pcs) | |||
| ├── run_eval_gpu.sh # launch gpu evaluation | |||
| └── run_standalone_train_gpu.sh # launch gpu standalone training(1 pcs) | |||
| ├── run_standalone_train_gpu.sh # launch gpu standalone training(1 pcs) | |||
| └── run_gpu_resnet_benchmark.sh # GPU benchmark for resnet50 with imagenet2012(1 pcs) | |||
| ├── src | |||
| ├── config.py # parameter configuration | |||
| ├── dataset.py # data preprocessing | |||
| ├── CrossEntropySmooth.py # loss definition for ImageNet2012 dataset | |||
| ├── lr_generator.py # generate learning rate for each step | |||
| └── resnet.py # resnet backbone, including resnet50 and resnet101 and se-resnet50 | |||
| ├── resnet.py # resnet backbone, including resnet50 and resnet101 and se-resnet50 | |||
| └── resnet_gpu_benchmark.py # resnet50 for GPU benchmark | |||
| ├── export.py # export model for inference | |||
| ├── mindspore_hub_conf.py # mindspore hub interface | |||
| ├── eval.py # eval net | |||
| └── train.py # train net | |||
| ├── train.py # train net | |||
| └── gpu_resnet_benchmark.py # GPU benchmark for resnet50 | |||
| ``` | |||
| ## [Script Parameters](#contents) | |||
| @@ -272,6 +275,9 @@ sh run_standalone_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATA | |||
| # infer example | |||
| sh run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH] | |||
| # gpu benchmark example | |||
| sh run_gpu_resnet_benchmark.sh [IMAGENET_DATASET_PATH] [BATCH_SIZE](optional) | |||
| ``` | |||
| #### Running parameter server mode training | |||
| @@ -335,7 +341,22 @@ epoch: 4 step: 5004, loss is 3.5011306 | |||
| epoch: 5 step: 5004, loss is 3.3501816 | |||
| ... | |||
| ``` | |||
| - GPU Benchmark of ResNet50 with ImageNet2012 dataset | |||
| ``` | |||
| # ========START RESNET50 GPU BENCHMARK======== | |||
| step time: 22549.130 ms, fps: 11 img/sec. epoch: 1 step: 1, loss is 6.940182 | |||
| step time: 182.485 ms, fps: 1402 img/sec. epoch: 1 step: 2, loss is 7.078993 | |||
| step time: 175.263 ms, fps: 1460 img/sec. epoch: 1 step: 3, loss is 7.559594 | |||
| step time: 174.775 ms, fps: 1464 img/sec. epoch: 1 step: 4, loss is 8.020937 | |||
| step time: 175.564 ms, fps: 1458 img/sec. epoch: 1 step: 5, loss is 8.140132 | |||
| step time: 175.438 ms, fps: 1459 img/sec. epoch: 1 step: 6, loss is 8.021118 | |||
| step time: 175.760 ms, fps: 1456 img/sec. epoch: 1 step: 7, loss is 7.910158 | |||
| step time: 176.033 ms, fps: 1454 img/sec. epoch: 1 step: 8, loss is 7.940162 | |||
| step time: 175.995 ms, fps: 1454 img/sec. epoch: 1 step: 9, loss is 7.740654 | |||
| step time: 175.313 ms, fps: 1460 img/sec. epoch: 1 step: 10, loss is 7.956182 | |||
| ... | |||
| ``` | |||
| ## [Evaluation Process](#contents) | |||
| ### Usage | |||
| @@ -0,0 +1,160 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """train resnet.""" | |||
| import argparse | |||
| import time | |||
| import numpy as np | |||
| from mindspore import context | |||
| from mindspore import Tensor | |||
| from mindspore.nn.optim.momentum import Momentum | |||
| from mindspore.train.model import Model | |||
| from mindspore.train.callback import Callback, LossMonitor | |||
| from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits | |||
| from mindspore.train.loss_scale_manager import FixedLossScaleManager | |||
| from mindspore.common import set_seed | |||
| import mindspore.nn as nn | |||
| import mindspore.common.initializer as weight_init | |||
| import mindspore.common.dtype as mstype | |||
| import mindspore.dataset.engine as de | |||
| import mindspore.dataset.vision.c_transforms as C | |||
| import mindspore.dataset.transforms.c_transforms as C2 | |||
| from src.resnet_gpu_benchmark import resnet50 as resnet | |||
parser = argparse.ArgumentParser(description='Image classification')
# Numeric options are parsed as integers so that a malformed value is rejected
# by argparse with a usage message instead of failing later at int(...).
parser.add_argument('--batch_size', type=int, default=256, help='Batch_size: default 256.')
parser.add_argument('--epoch_size', type=int, default=2, help='Epoch_size: default 2')
# There is no usable default dataset location, so the option is mandatory.
parser.add_argument('--dataset_path', type=str, required=True, help='Imagenet dataset path')
args_opt = parser.parse_args()

# Seed the framework-wide random generators with a fixed value.
set_seed(1)
class MyTimeMonitor(Callback):
    """
    Callback that prints the wall-clock duration and throughput of each step.

    Args:
        batch_size (int): Number of images processed per step; used to turn
            the measured step time into images per second.
    """

    def __init__(self, batch_size):
        super(MyTimeMonitor, self).__init__()
        self.batch_size = batch_size
        # Initialised here so that step_end never hits a missing attribute
        # when it is invoked before the first step_begin.
        self.step_time = time.time()

    def step_begin(self, run_context):
        """Remember the time at which the current step started."""
        self.step_time = time.time()

    def step_end(self, run_context):
        """Print the elapsed step time in milliseconds and the images/sec rate."""
        step_mseconds = (time.time() - self.step_time) * 1000
        fps = self.batch_size / step_mseconds * 1000
        # end=" " leaves the line open, so whatever is printed next continues
        # on the same output line.
        print("step time: {:5.3f} ms, fps: {:d} img/sec.".format(step_mseconds, int(fps)), flush=True, end=" ")
def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="GPU"):
    """
    Build the image-folder input pipeline used by the benchmark.

    Args:
        dataset_path (str): Root directory handed to ImageFolderDataset.
        do_train (bool): Selects the training augmentations when true,
            otherwise the decode / resize / center-crop evaluation transforms.
        repeat_num (int): Argument passed to the final repeat operation. Default: 1.
        batch_size (int): Batch size; incomplete trailing batches are dropped. Default: 32.
        target (str): Accepted for interface compatibility; not used by this function.

    Returns:
        The mapped, batched and repeated dataset object.
    """
    data_set = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)

    image_size = 224
    # Per-channel statistics expressed on the 0-255 pixel scale.
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

    # define map operations
    if not do_train:
        image_ops = [
            C.Decode(),
            C.Resize(256),
            C.CenterCrop(image_size),
            C.Normalize(mean=mean, std=std),
        ]
    else:
        image_ops = [
            C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            C.RandomHorizontalFlip(prob=0.5),
            C.Normalize(mean=mean, std=std),
        ]
    label_cast_op = C2.TypeCast(mstype.int32)

    data_set = data_set.map(operations=image_ops, input_columns="image", num_parallel_workers=8)
    data_set = data_set.map(operations=label_cast_op, input_columns="label", num_parallel_workers=8)

    # Zero-pad every image to shape [224, 224, 4].
    # NOTE(review): presumably this matches the 4-channel first convolution of
    # the benchmark network - confirm against the model definition.
    pad_op = C2.PadEnd(pad_shape=[224, 224, 4], pad_value=0)
    data_set = data_set.map(operations=pad_op, input_columns="image", num_parallel_workers=8)

    # apply batch, then dataset repeat
    batched = data_set.batch(batch_size, drop_remainder=True)
    return batched.repeat(repeat_num)
def get_liner_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
    """
    Generate a per-step learning-rate schedule with linear warmup and linear decay.

    During the first ``warmup_epochs * steps_per_epoch`` steps the rate rises
    linearly from ``lr_init`` towards ``lr_max``; afterwards it falls linearly
    from ``lr_max`` towards ``lr_end`` over the remaining steps.

    Args:
        lr_init (float): Learning rate at step 0 of the warmup phase.
        lr_end (float): Value the decay phase heads towards.
        lr_max (float): Peak learning rate, reached when warmup finishes.
        warmup_epochs (int): Number of warmup epochs.
        total_epochs (int): Total number of epochs.
        steps_per_epoch (int): Number of steps in one epoch.

    Returns:
        numpy.ndarray, float32 array with one learning rate per training step.
    """
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs

    def _lr_at(step):
        # Warmup phase: linear ramp from lr_init towards lr_max.
        if step < warmup_steps:
            return lr_init + (lr_max - lr_init) * step / warmup_steps
        # Decay phase: linear descent from lr_max towards lr_end.
        return lr_max - (lr_max - lr_end) * (step - warmup_steps) / (total_steps - warmup_steps)

    schedule = [_lr_at(step) for step in range(total_steps)]
    return np.array(schedule).astype(np.float32)
if __name__ == '__main__':
    # Entry point: train ResNet50 on the GPU and print per-step time / throughput.
    dev = "GPU"
    epoch_size = int(args_opt.epoch_size)
    total_batch = int(args_opt.batch_size)

    # init context
    context.set_context(mode=context.GRAPH_MODE, device_target=dev, save_graphs=False)

    # create dataset
    dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=1,
                             batch_size=total_batch, target=dev)
    step_size = dataset.get_dataset_size()

    # define net
    net = resnet(class_num=1001)

    # init weight: Xavier-uniform for convolutions, truncated normal for dense layers
    for _, cell in net.cells_and_names():
        if isinstance(cell, nn.Conv2d):
            cell.weight.set_data(weight_init.initializer(weight_init.XavierUniform(),
                                                         cell.weight.shape,
                                                         cell.weight.dtype))
        if isinstance(cell, nn.Dense):
            cell.weight.set_data(weight_init.initializer(weight_init.TruncatedNormal(),
                                                         cell.weight.shape,
                                                         cell.weight.dtype))

    # init lr: no warmup, linear decay from 0.8 towards 0 over the whole run
    lr = get_liner_lr(lr_init=0, lr_end=0, lr_max=0.8, warmup_epochs=0, total_epochs=epoch_size,
                      steps_per_epoch=step_size)
    lr = Tensor(lr)

    # define loss, opt
    # NOTE: the optimizer receives every trainable parameter as one group, so
    # the 1e-4 weight decay is applied uniformly. A per-group parameter split
    # used to be built here but was never handed to the optimizer; that dead
    # code has been removed without changing training behaviour.
    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, 0.9, 1e-4, 1024)
    loss_scale = FixedLossScaleManager(1024, drop_overflow_update=False)

    # Mixed precision
    model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
                  amp_level="O2", keep_batchnorm_fp32=False)

    # define callbacks: the time monitor runs first so its output precedes the loss on each line
    time_cb = MyTimeMonitor(total_batch)
    loss_cb = LossMonitor()
    cb = [time_cb, loss_cb]

    # train model
    print("========START RESNET50 GPU BENCHMARK========")
    model.train(epoch_size, dataset, callbacks=cb, sink_size=step_size)
| @@ -0,0 +1,42 @@ | |||
| #!/bin/bash | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
# Exactly one or two arguments are accepted: the dataset path and an optional batch size.
if [ $# -lt 1 ] || [ $# -gt 2 ]; then
    echo "Usage: sh run_gpu_resnet_benchmark.sh [DATASET_PATH] [BATCH_SIZE](optional)"
    exit 1
fi
# Print the absolute form of the path given as $1.
# An absolute path is echoed unchanged; anything else is resolved against $PWD
# with `realpath -m`, which does not require the path to exist.
# Uses `case` instead of the bash-only ${1:0:1} substring test so the script
# also works when started as `sh run_gpu_resnet_benchmark.sh ...`, and quotes
# the path so directories containing spaces survive.
get_real_path(){
    case "$1" in
        /*) echo "$1" ;;
        *)  realpath -m "$PWD/$1" ;;
    esac
}
# Resolve the dataset path and the directory this script lives in.
DATAPATH=$(get_real_path "$1")
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")

# Launch the benchmark, forwarding the batch size only when it was supplied.
# `-eq` replaces the non-POSIX `==` so the test also works under plain `sh`;
# all expansions are quoted so paths containing spaces are passed intact.
if [ $# -eq 2 ]
then
    python "${self_path}/../gpu_resnet_benchmark.py" --dataset_path="$DATAPATH" --batch_size="$2"
else
    python "${self_path}/../gpu_resnet_benchmark.py" --dataset_path="$DATAPATH"
fi
| @@ -0,0 +1,258 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ResNet.""" | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore.ops import operations as P | |||
| from mindspore.common.tensor import Tensor | |||
| from scipy.stats import truncnorm | |||
def _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size):
    """
    Draw convolution weights from a fan-in scaled truncated normal distribution.

    Args:
        in_channel (int): Number of input channels.
        out_channel (int): Number of output channels.
        kernel_size (int): Side length of the square kernel.

    Returns:
        Tensor, float32 weights of shape (out_channel, kernel_size, kernel_size, in_channel).
    """
    fan_in = in_channel * kernel_size * kernel_size
    # NOTE(review): the divisor is presumably the std-dev correction for a
    # normal distribution truncated at +/-2 - confirm.
    stddev = ((1.0 / max(1., fan_in)) ** 0.5) / .87962566103423978
    sample_count = out_channel * in_channel * kernel_size * kernel_size
    samples = truncnorm(-2, 2, loc=0, scale=stddev).rvs(sample_count)
    weight = np.reshape(samples, (out_channel, kernel_size, kernel_size, in_channel))
    return Tensor(weight, dtype=mstype.float32)
def _weight_variable(shape, factor=0.01):
    """
    Create a weight Tensor filled with scaled standard-normal samples.

    Args:
        shape (tuple): Shape of the weight.
        factor (float): Multiplier applied to the float32 samples. Default: 0.01.

    Returns:
        Tensor, randomly initialised weight.
    """
    samples = np.random.randn(*shape).astype(np.float32)
    return Tensor(samples * factor)
def _conv_nhwc(in_channel, out_channel, kernel_size, stride, padding):
    """
    Build an NHWC Conv2d with explicit padding and randomly initialised weights.

    Shared by the 3x3 / 1x1 / 7x7 helpers, which differ only in kernel size
    and padding.

    Args:
        in_channel (int): Number of input channels.
        out_channel (int): Number of output channels.
        kernel_size (int): Side length of the square kernel.
        stride (int): Convolution stride.
        padding (int): Explicit padding used with pad_mode='pad'.

    Returns:
        Cell, the configured nn.Conv2d.
    """
    # The weight is laid out as (out_channel, kernel_h, kernel_w, in_channel).
    weight = _weight_variable((out_channel, kernel_size, kernel_size, in_channel))
    return nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride,
                     padding=padding, pad_mode='pad', weight_init=weight, data_format="NHWC")


def _conv3x3(in_channel, out_channel, stride=1):
    """Return a 3x3 NHWC convolution with padding 1."""
    return _conv_nhwc(in_channel, out_channel, kernel_size=3, stride=stride, padding=1)


def _conv1x1(in_channel, out_channel, stride=1):
    """Return a 1x1 NHWC convolution without padding."""
    return _conv_nhwc(in_channel, out_channel, kernel_size=1, stride=stride, padding=0)


def _conv7x7(in_channel, out_channel, stride=1):
    """Return a 7x7 NHWC convolution with padding 3."""
    return _conv_nhwc(in_channel, out_channel, kernel_size=7, stride=stride, padding=3)
def _bn(channel):
    """
    Return an NHWC BatchNorm2d for `channel` features with gamma initialised to 1.

    Args:
        channel (int): Number of feature channels.

    Returns:
        Cell, the configured nn.BatchNorm2d.
    """
    bn_options = dict(eps=1e-4, momentum=0.9, gamma_init=1, beta_init=0,
                      moving_mean_init=0, moving_var_init=1, data_format="NHWC")
    return nn.BatchNorm2d(channel, **bn_options)
def _bn_last(channel):
    """
    Return an NHWC BatchNorm2d for `channel` features with gamma initialised to 0.

    Identical to the regular batch-norm helper except for the zero gamma
    initialisation.

    Args:
        channel (int): Number of feature channels.

    Returns:
        Cell, the configured nn.BatchNorm2d.
    """
    bn_options = dict(eps=1e-4, momentum=0.9, gamma_init=0, beta_init=0,
                      moving_mean_init=0, moving_var_init=1, data_format="NHWC")
    return nn.BatchNorm2d(channel, **bn_options)
def _fc(in_channel, out_channel):
    """
    Return a biased Dense layer with random weights and zero-initialised bias.

    Args:
        in_channel (int): Number of input features.
        out_channel (int): Number of output features.

    Returns:
        Cell, the configured nn.Dense.
    """
    dense_weight = _weight_variable((out_channel, in_channel))
    return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=dense_weight, bias_init=0)
class ResidualBlock(nn.Cell):
    """
    ResNet V1 bottleneck residual block (1x1 -> 3x3 -> 1x1 convolutions).

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        stride (int): Stride of the 3x3 convolution and of the shortcut
            projection. Default: 1.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ResidualBlock(3, 256, stride=2)
    """
    expansion = 4

    def __init__(self,
                 in_channel,
                 out_channel,
                 stride=1):
        super(ResidualBlock, self).__init__()
        self.stride = stride
        # Width of the two inner convolutions.
        mid_channel = out_channel // self.expansion

        self.conv1 = _conv1x1(in_channel, mid_channel, stride=1)
        self.bn1 = _bn(mid_channel)

        self.conv2 = _conv3x3(mid_channel, mid_channel, stride=stride)
        self.bn2 = _bn(mid_channel)

        self.conv3 = _conv1x1(mid_channel, out_channel, stride=1)
        self.bn3 = _bn_last(out_channel)

        self.relu = nn.ReLU()

        # The shortcut needs a projection whenever the main path changes the
        # spatial size (stride != 1) or the channel count.
        self.down_sample = stride != 1 or in_channel != out_channel
        if self.down_sample:
            self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride),
                                                        _bn(out_channel)])
        else:
            self.down_sample_layer = None
        self.add = P.TensorAdd()

    def construct(self, x):
        """Run the bottleneck path, add the (optionally projected) shortcut, apply ReLU."""
        shortcut = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        if self.down_sample:
            shortcut = self.down_sample_layer(shortcut)

        merged = self.add(out, shortcut)
        return self.relu(merged)
class ResNet(nn.Cell):
    """
    ResNet architecture built from NHWC layers.

    Args:
        block (Cell): Block for network.
        layer_nums (list): Numbers of block in different layers.
        in_channels (list): Input channel in each layer.
        out_channels (list): Output channel in each layer.
        strides (list): Stride size in each layer.
        num_classes (int): The number of classes that the training images are belonging to.

    Returns:
        Tensor, output tensor.

    Raises:
        ValueError: If layer_nums, in_channels and out_channels are not all of length 4.

    Examples:
        >>> ResNet(ResidualBlock,
        >>>        [3, 4, 6, 3],
        >>>        [64, 256, 512, 1024],
        >>>        [256, 512, 1024, 2048],
        >>>        [1, 2, 2, 2],
        >>>        10)
    """

    def __init__(self,
                 block,
                 layer_nums,
                 in_channels,
                 out_channels,
                 strides,
                 num_classes):
        super(ResNet, self).__init__()

        lengths_ok = len(layer_nums) == len(in_channels) == len(out_channels) == 4
        if not lengths_ok:
            raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")

        # Stem: the first convolution takes 4 input channels.
        self.conv1 = _conv7x7(4, 64, stride=2)
        self.bn1 = _bn(64)
        self.relu = P.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same", data_format="NHWC")

        def build_stage(index):
            # Assemble stage `index` from the per-stage configuration lists.
            return self._make_layer(block,
                                    layer_nums[index],
                                    in_channel=in_channels[index],
                                    out_channel=out_channels[index],
                                    stride=strides[index])

        # The four stages are created in order, one after another.
        self.layer1 = build_stage(0)
        self.layer2 = build_stage(1)
        self.layer3 = build_stage(2)
        self.layer4 = build_stage(3)

        # Head: 7x7 average pooling, flatten, dense classifier.
        self.avg_pool = P.AvgPool(7, 1, data_format="NHWC")
        self.flatten = nn.Flatten()
        self.end_point = _fc(out_channels[3], num_classes)

    def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
        """
        Make stage network of ResNet.

        Args:
            block (Cell): Resnet block.
            layer_num (int): Layer number.
            in_channel (int): Input channel.
            out_channel (int): Output channel.
            stride (int): Stride size for the first block of the stage.

        Returns:
            SequentialCell, the output layer.

        Examples:
            >>> _make_layer(ResidualBlock, 3, 128, 256, 2)
        """
        # Only the first block changes the channel count and applies the stride.
        stage_blocks = [block(in_channel, out_channel, stride=stride)]
        for _ in range(1, layer_num):
            stage_blocks.append(block(out_channel, out_channel, stride=1))
        return nn.SequentialCell(stage_blocks)

    def construct(self, x):
        """Run the stem, the four stages and the classification head on `x`."""
        stem = self.relu(self.bn1(self.conv1(x)))
        feat = self.maxpool(stem)

        feat = self.layer1(feat)
        feat = self.layer2(feat)
        feat = self.layer3(feat)
        feat = self.layer4(feat)

        pooled = self.avg_pool(feat)
        flat = self.flatten(pooled)
        return self.end_point(flat)
def resnet50(class_num=1001):
    """
    Get ResNet50 neural network.

    Args:
        class_num (int): Class number. Default: 1001.

    Returns:
        Cell, cell instance of ResNet50 neural network.

    Examples:
        >>> net = resnet50(1001)
    """
    # Per-stage configuration of the four residual stages.
    layer_nums = [3, 4, 6, 3]
    in_channels = [64, 256, 512, 1024]
    out_channels = [256, 512, 1024, 2048]
    strides = [1, 2, 2, 2]
    return ResNet(ResidualBlock, layer_nums, in_channels, out_channels, strides, class_num)