From: @wangrao124 · tags/v1.1.0
@@ -5,14 +5,12 @@
 - [Dataset](#dataset)
 - [Environment Requirements](#environment-requirements)
 - [Script Description](#script-description)
     - [Script and Sample Code](#script-and-sample-code)
     - [Training Process](#training-process)
     - [Evaluation Process](#evaluation-process)
-        - [Evaluation](#evaluation)
 - [Model Description](#model-description)
     - [Performance](#performance)
-        - [Training Performance](#evaluation-performance)
-        - [Inference Performance](#evaluation-performance)
+        - [Evaluation Performance](#evaluation-performance)
 - [Description of Random Situation](#description-of-random-situation)
 - [ModelZoo Homepage](#modelzoo-homepage)
@@ -22,7 +20,6 @@ TinyNets are a series of lightweight models obtained by twisting resolution, dep
 [Paper](https://arxiv.org/abs/2010.14819): Kai Han, Yunhe Wang, Qiulin Zhang, Wei Zhang, Chunjing Xu, Tong Zhang. Model Rubik's Cube: Twisting Resolution, Depth and Width for TinyNets. In NeurIPS 2020.
-Note: We have only released TinyNet-C for now, and will release other TinyNets soon.

 # [Model architecture](#contents)
 The overall network architecture of TinyNet is shown below:
@@ -33,53 +30,56 @@ The overall network architecture of TinyNet is shown below:
 Dataset used: [ImageNet 2012](http://image-net.org/challenges/LSVRC/2012/)

 - Dataset size:
     - Train: 1.2 million images in 1,000 classes
     - Test: 50,000 validation images in 1,000 classes
 - Data format: RGB images.
     - Note: Data will be processed in src/dataset.py
 # [Environment Requirements](#contents)

 - Hardware (GPU)
 - Framework
     - [MindSpore](https://www.mindspore.cn/install/en)
 - For more information, please check the resources below:
     - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
     - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
-# [Script description](#contents)
+# [Script Description](#contents)

-## [Script and sample code](#contents)
+## [Script and Sample Code](#contents)

-```
+```markdown
 .tinynet
-├── Readme.md                    # descriptions about tinynet
+├── README.md                    # descriptions about tinynet
 ├── script
 │   ├── eval.sh                  # evaluation script
 │   ├── train_1p_gpu.sh          # training script on single GPU
 │   └── train_distributed_gpu.sh # distributed training script on multiple GPUs
 ├── src
-│   ├── callback.py              # loss and checkpoint callbacks
-│   ├── dataset.py               # data processing
+│   ├── callback.py              # loss, ema, and checkpoint callbacks
+│   ├── dataset.py               # data preprocessing
 │   ├── loss.py                  # label-smoothing cross-entropy loss function
 │   ├── tinynet.py               # tinynet architecture
 │   └── utils.py                 # utility functions
 ├── eval.py                      # evaluation interface
 └── train.py                     # training interface
 ```
-## [Training process](#contents)
+### [Training Process](#contents)

-### Launch
+#### Launch

-```
+```bash
 # training on single GPU
 sh train_1p_gpu.sh
 # training on multiple GPUs, the number after -n indicates how many GPUs will be used for training
 sh train_distributed_gpu.sh -n 8
 ```

 Inside train_1p_gpu.sh and train_distributed_gpu.sh, there are hyperparameters that can be adjusted during training, for example:

-```
+```python
 --model tinynet_c           model to be used for training
 --drop 0.2                  dropout rate
 --drop-connect 0            drop connect rate
@@ -88,51 +88,55 @@ Inside train.sh, there are hyperparameters that can be adjusted during training,
 --lr 0.048                  learning rate
 --batch-size 128            batch size
 --decay-epochs 2.4          learning rate decays every 2.4 epoch
 --warmup-lr 1e-6            warm up learning rate
 --warmup-epochs 3           learning rate warm up epoch
 --decay-rate 0.97           learning rate decay rate
 --ema-decay 0.9999          decay factor for model weights moving average
 --weight-decay 1e-5         optimizer's weight decay
 --epochs 450                number of epochs to be trained
 --ckpt_save_epoch 1         checkpoint saving interval
 --workers 8                 number of processes for loading data
 --amp_level O0              training auto-mixed precision
 --opt rmsprop               optimizers, currently we support SGD and RMSProp
 --data_path /path_to_ImageNet/
 --GPU                       using GPU for training
 --dataset_sink              using sink mode
 ```

-The config above was used to train tinynets on ImageNet (change drop-connect to 0.2 for training tinynet-b)
+The config above was used to train tinynets on ImageNet (change drop-connect to 0.1 for training tinynet_b)

 > checkpoints will be saved in the ./device_{rank_id} folder (single GPU)
 > or ./device_parallel folder (multiple GPUs)
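
Read together, the warmup and decay flags above describe a linear warmup followed by staircase exponential decay. A minimal sketch of that schedule (a hypothetical helper for illustration, not the repository's actual scheduler):

```python
# Hypothetical illustration of the schedule the flags above suggest:
# linear warmup from --warmup-lr to --lr over --warmup-epochs, then the
# rate is multiplied by --decay-rate once per --decay-epochs interval.
def lr_at_epoch(epoch, lr=0.048, warmup_lr=1e-6, warmup_epochs=3,
                decay_epochs=2.4, decay_rate=0.97):
    if epoch < warmup_epochs:
        # linear interpolation from warmup_lr up to the base rate
        return warmup_lr + (lr - warmup_lr) * epoch / warmup_epochs
    periods = (epoch - warmup_epochs) // decay_epochs  # completed decay periods
    return lr * decay_rate ** periods

print(lr_at_epoch(0))    # 1e-06, start of warmup
print(lr_at_epoch(3))    # 0.048, warmup finished
print(lr_at_epoch(450))  # ~1.7e-4 by the end of training
```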
-## [Eval process](#contents)
+### [Evaluation Process](#contents)

-### Launch
+#### Launch

-```
+```bash
 # infer example
 sh eval.sh
 ```

 Inside eval.sh, there are configs that can be adjusted during inference, for example:
-```
+```python
 --num-classes 1000
 --batch-size 128
 --workers 8
 --data_path /path_to_ImageNet/
 --GPU
 --ckpt /path_to_EMA_checkpoint/
 --dataset_sink > tinynet_c_eval.log 2>&1 &
 ```

 > checkpoints are produced during the training process.
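
For orientation, here is a rough sketch of how a saved EMA checkpoint can be evaluated with standard MindSpore serialization APIs; eval.py's actual argument handling and network construction (stood in for by `net` and `val_dataset` below) may differ:

```python
import mindspore.nn as nn
from mindspore import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net

# `net` and `val_dataset` stand in for the TinyNet model and validation
# dataset that eval.py builds; their construction is omitted here.
param_dict = load_checkpoint("/path_to_EMA_checkpoint")  # EMA weights from training
load_param_into_net(net, param_dict)                     # copy them into the network
net.set_train(False)                                     # inference mode
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
model = Model(net, loss_fn=loss,
              metrics={"Top1-Acc": nn.Top1CategoricalAccuracy()})
print(model.eval(val_dataset, dataset_sink_mode=True))   # {'Top1-Acc': ...}
```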
 # [Model Description](#contents)

 ## [Performance](#contents)

-#### Evaluation Performance
+### Evaluation Performance

 | Model               | FLOPs | Latency* | ImageNet Top-1 |
 | ------------------- | ----- | -------- | -------------- |
@@ -149,6 +153,6 @@ Inside the eval.sh, there are configs that can be adjusted during inference, for
 We set the seed inside dataset.py. We also use a random seed in train.py.

-# [Model Zoo Homepage](#contents)
+# [ModelZoo Homepage](#contents)

 Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
@@ -36,7 +36,7 @@ def load_nparray_into_net(net, array_dict):
     for _, param in net.parameters_and_names():
         if param.name in array_dict:
             new_param = array_dict[param.name]
-            param.set_data(Parameter(new_param.copy(), name=param.name))
+            param.set_data(Parameter(Tensor(deepcopy(new_param)), name=param.name))
         else:
             param_not_load.append(param.name)
     return param_not_load
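
As used above, `load_nparray_into_net` copies a dict of numpy arrays into the matching parameters of a network and returns the names it could not match. A minimal usage sketch (the parameter name and shape below are made up for illustration):

```python
import numpy as np

# Fabricated shadow dict: parameter name -> numpy array with the same
# shape as the corresponding parameter of `net`.
shadow = {"features.0.weight": np.zeros((32, 3, 3, 3), dtype=np.float32)}
not_loaded = load_nparray_into_net(net, shadow)
print(not_loaded)  # names of parameters that had no entry in `shadow`
```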
@@ -48,8 +48,8 @@ class EmaEvalCallBack(Callback):
     the end of training epoch.

     Args:
-        model: Mindspore model instance.
-        ema_network: step-wise exponential moving average for ema_network.
+        network: tinynet network instance.
+        ema_network: step-wise exponential moving average of network.
         eval_dataset: the evaluation dataset.
         decay (float): ema decay.
         save_epoch (int): defines how often to save checkpoint.
@@ -57,9 +57,9 @@ class EmaEvalCallBack(Callback):
         start_epoch (int): which epoch to start/resume training.
     """

-    def __init__(self, model, ema_network, eval_dataset, loss_fn, decay=0.999,
+    def __init__(self, network, ema_network, eval_dataset, loss_fn, decay=0.999,
                  save_epoch=1, dataset_sink_mode=True, start_epoch=0):
-        self.model = model
+        self.network = network
         self.ema_network = ema_network
         self.eval_dataset = eval_dataset
         self.loss_fn = loss_fn
@@ -80,14 +80,12 @@ class EmaEvalCallBack(Callback):
     def begin(self, run_context):
         """Initialize the EMA parameters"""
-        cb_params = run_context.original_args()
-        for _, param in cb_params.network.parameters_and_names():
+        for _, param in self.network.parameters_and_names():
             self.shadow[param.name] = deepcopy(param.data.asnumpy())

     def step_end(self, run_context):
         """Update the EMA parameters"""
-        cb_params = run_context.original_args()
-        for _, param in cb_params.network.parameters_and_names():
+        for _, param in self.network.parameters_and_names():
             new_average = (1.0 - self.decay) * param.data.asnumpy().copy() + \
                 self.decay * self.shadow[param.name]
             self.shadow[param.name] = new_average
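
The update above is the usual EMA recurrence, `shadow = decay * shadow + (1 - decay) * param`. A quick numeric check of how slowly a 0.9999 decay forgets, in plain Python independent of MindSpore:

```python
decay, shadow, param = 0.9999, 1.0, 0.0
for _ in range(10000):  # 10,000 steps toward a constant new value of 0.0
    shadow = (1.0 - decay) * param + decay * shadow
print(round(shadow, 4))  # 0.3679, i.e. exp(-1): the shadow averages over ~1/(1-decay) steps
```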
@@ -98,24 +96,20 @@ class EmaEvalCallBack(Callback):
         cur_epoch = cb_params.cur_epoch_num + self._start_epoch - 1
         save_ckpt = (cur_epoch % self.save_epoch == 0)

-        acc = self.model.eval(
-            self.eval_dataset, dataset_sink_mode=self.dataset_sink_mode)
-        print("Model Accuracy:", acc)
         load_nparray_into_net(self.ema_network, self.shadow)
+        self.ema_network.set_train(False)
+        model = Model(self.network, loss_fn=self.loss_fn, metrics=self.eval_metrics)
         model_ema = Model(self.ema_network, loss_fn=self.loss_fn,
                           metrics=self.eval_metrics)
+        acc = model.eval(
+            self.eval_dataset, dataset_sink_mode=self.dataset_sink_mode)
         ema_acc = model_ema.eval(
             self.eval_dataset, dataset_sink_mode=self.dataset_sink_mode)
+        print("Model Accuracy:", acc)
         print("EMA-Model Accuracy:", ema_acc)
-        self.ema_accuracy[cur_epoch] = ema_acc["Top1-Acc"]
         output = [{"name": k, "data": Tensor(v)}
                   for k, v in self.shadow.items()]
+        self.ema_accuracy[cur_epoch] = ema_acc["Top1-Acc"]
         if self.best_ema_accuracy < ema_acc["Top1-Acc"]:
             self.best_ema_accuracy = ema_acc["Top1-Acc"]
             self.best_ema_epoch = cur_epoch
@@ -65,12 +65,12 @@ def create_dataset(batch_size, train_data_url='', workers=8, distributed=False,
                                                contrast=adjust_range,
                                                saturation=adjust_range)
     to_tensor = py_vision.ToTensor()
-    nromlize_op = py_vision.Normalize(
+    normalize_op = py_vision.Normalize(
         IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
     # assemble all the transforms
     image_ops = py_transforms.Compose([decode_op, random_resize_crop_bicubic,
-                                       random_horizontal_flip_op, random_color_jitter_op, to_tensor, nromlize_op])
+                                       random_horizontal_flip_op, random_color_jitter_op, to_tensor, normalize_op])

     rank_id = get_rank() if distributed else 0
     rank_size = get_group_size() if distributed else 1
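
For reference, `ToTensor` maps HWC uint8 pixels to CHW float32 in [0, 1] before `Normalize` standardizes each channel with the ImageNet statistics; the equivalent numpy arithmetic, as a sketch of the math rather than the library code:

```python
import numpy as np

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

img = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)  # fake HWC image
x = img.astype(np.float32) / 255.0   # ToTensor: scale to [0, 1] ...
x = x.transpose(2, 0, 1)             # ... and convert HWC -> CHW
mean = np.array(IMAGENET_DEFAULT_MEAN, np.float32).reshape(3, 1, 1)
std = np.array(IMAGENET_DEFAULT_STD, np.float32).reshape(3, 1, 1)
x = (x - mean) / std                 # Normalize: per-channel standardization
```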
@@ -125,11 +125,11 @@ def create_dataset_val(batch_size=128, val_data_url='', workers=8, distributed=F
     resize_op = py_vision.Resize(size=scale_size, interpolation=Inter.BICUBIC)
     center_crop = py_vision.CenterCrop(size=input_size)
     to_tensor = py_vision.ToTensor()
-    nromlize_op = py_vision.Normalize(
+    normalize_op = py_vision.Normalize(
         IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
     image_ops = py_transforms.Compose([decode_op, resize_op, center_crop,
-                                       to_tensor, nromlize_op])
+                                       to_tensor, normalize_op])

     dataset = dataset.map(input_columns=["label"], operations=type_cast_op,
                           num_parallel_workers=workers)
@@ -18,10 +18,12 @@ import re
 from copy import deepcopy

 import mindspore.nn as nn
+import mindspore.common.dtype as mstype
 from mindspore.ops import operations as P
 from mindspore.common.initializer import Normal, Zero, One, initializer, Uniform
 from mindspore import context, ms_function
 from mindspore.common.parameter import Parameter
+from mindspore import Tensor

 # Imagenet constant values
 IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
@@ -29,12 +31,14 @@ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
 # model structure configurations for TinyNets, values are
 # (resolution multiplier, channel multiplier, depth multiplier)
-# only tinynet-c is availiable for now, we will release other tinynet
-# models soon
 # codes are inspired and partially adapted from
 # https://github.com/rwightman/gen-efficientnet-pytorch
-TINYNET_CFG = {"c": (0.825, 0.54, 0.85)}
+TINYNET_CFG = {"a": (0.86, 1.0, 1.2),
+               "b": (0.84, 0.75, 1.1),
+               "c": (0.825, 0.54, 0.85),
+               "d": (0.68, 0.54, 0.695),
+               "e": (0.475, 0.51, 0.60)}

 relu = P.ReLU()
 sigmoid = P.Sigmoid()
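
Each entry maps a TinyNet variant to (resolution, width, depth) multipliers applied to an EfficientNet-B0-style baseline. A rough sketch of how such multipliers are typically applied; the helper and baseline values are illustrative, and the real code rounds channels to hardware-friendly multiples:

```python
import math

def scaled_dims(cfg, base_resolution=224, base_channels=32, base_depth=2):
    """Illustrative scaling of baseline dimensions by (r, w, d) multipliers."""
    r, w, d = cfg
    resolution = int(round(base_resolution * r))  # input image side length
    channels = int(round(base_channels * w))      # stage width
    depth = int(math.ceil(base_depth * d))        # blocks per stage, rounded up
    return resolution, channels, depth

print(scaled_dims((0.825, 0.54, 0.85)))  # tinynet_c: (185, 17, 2)
```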
@@ -524,13 +528,15 @@ class DropConnect(nn.Cell):
         self.dtype = P.DType()
         self.keep_prob = 1 - drop_connect_rate
         self.dropout = P.Dropout(keep_prob=self.keep_prob)
+        self.keep_prob_tensor = Tensor(self.keep_prob, dtype=mstype.float32)

     def construct(self, x):
         shape = self.shape(x)
         dtype = self.dtype(x)
         ones_tensor = P.Fill()(dtype, (shape[0], 1, 1, 1), 1)
-        _, mask_ = self.dropout(ones_tensor)
-        x = x * mask_
+        _, mask = self.dropout(ones_tensor)
+        x = x * mask
+        x = x / self.keep_prob_tensor
         return x
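
With the added division, the block implements inverted-scaling drop connect: each sample in the batch is zeroed with probability `drop_connect_rate`, and survivors are scaled by `1 / keep_prob` so the expected output equals the input. Equivalent numpy, as a sketch of the math rather than the MindSpore op:

```python
import numpy as np

def drop_connect(x, drop_connect_rate, training=True, seed=0):
    """Zero whole samples at random; rescale survivors so E[out] == x."""
    if not training or drop_connect_rate == 0.0:
        return x
    keep_prob = 1.0 - drop_connect_rate
    rng = np.random.default_rng(seed)
    # one Bernoulli draw per sample, broadcast over (C, H, W)
    mask = rng.random((x.shape[0], 1, 1, 1)) < keep_prob
    return x * mask / keep_prob  # inverted scaling, as in the fix above

x = np.ones((4, 3, 2, 2), dtype=np.float32)
out = drop_connect(x, drop_connect_rate=0.5)
# each sample is either all zeros or all 2.0, so the batch mean is 1.0 in expectation
print(out[:, 0, 0, 0])
```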
@@ -227,7 +227,7 @@ def main():
     net_ema.set_train(False)
     assert args.ema_decay > 0, "EMA should be used in tinynet training."

-    ema_cb = EmaEvalCallBack(model=model,
+    ema_cb = EmaEvalCallBack(network=net,
                              ema_network=net_ema,
                              loss_fn=loss,
                              eval_dataset=val_dataset,