From: @yanglf1121 Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -61,6 +61,7 @@ In order to facilitate developers to enjoy the benefits of MindSpore framework, | |||||
| - [GhostNet_Quant](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/ghostnet_quant/README.md) | |||||
| - [ResNet50-0.65x](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/resnet50_adv_pruning/README.md) | |||||
| - [SSD_GhostNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/ssd_ghostnet/README.md) | |||||
| - [TinyNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/tinynet/README.md) | |||||
| - [Natural Language Processing](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/nlp) | |||||
| - [DS-CNN](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/nlp/dscnn/README.md) | |||||
| - [Audio](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio) | |||||
| @@ -0,0 +1,154 @@ | |||||
| # Contents | |||||
| - [TinyNet Description](#tinynet-description) | |||||
| - [Model Architecture](#model-architecture) | |||||
| - [Dataset](#dataset) | |||||
| - [Environment Requirements](#environment-requirements) | |||||
| - [Script Description](#script-description) | |||||
| - [Script and Sample Code](#script-and-sample-code) | |||||
| - [Training Process](#training-process) | |||||
| - [Evaluation Process](#evaluation-process) | |||||
| - [Evaluation](#evaluation) | |||||
| - [Model Description](#model-description) | |||||
| - [Performance](#performance) | |||||
| - [Evaluation Performance](#evaluation-performance) | |||||
| - [Description of Random Situation](#description-of-random-situation) | |||||
| - [ModelZoo Homepage](#modelzoo-homepage) | |||||
| # [TinyNet Description](#contents) | |||||
| TinyNets are a series of lightweight models obtained by twisting resolution, depth and width with a data-driven tiny formula. TinyNet outperforms EfficientNet and MobileNetV3. | |||||
| [Paper](https://arxiv.org/abs/2010.14819): Kai Han, Yunhe Wang, Qiulin Zhang, Wei Zhang, Chunjing Xu, Tong Zhang. Model Rubik's Cube: Twisting Resolution, Depth and Width for TinyNets. In NeurIPS 2020. | |||||
| Note: We have only released TinyNet-C for now, and will release other TinyNets soon. | |||||
| # [Model architecture](#contents) | |||||
| The overall network architecture of TinyNet is shown below: | |||||
| [Link](https://arxiv.org/abs/2010.14819) | |||||
| # [Dataset](#contents) | |||||
| Dataset used: [ImageNet 2012](http://image-net.org/challenges/LSVRC/2012/) | |||||
| - Dataset size: | |||||
| - Train: 1.2 million images in 1,000 classes | |||||
| - Test: 50,000 validation images in 1,000 classes | |||||
| - Data format: RGB images. | |||||
| - Note: Data will be processed in src/dataset.py | |||||
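As a rough illustration of the validation preprocessing, the sketch below mirrors how `create_dataset_val` in src/dataset.py appears to derive the resize size from the network input size, using the crop ratio 0.875 defined there. The function name `eval_resize_size` is hypothetical and exists only for this example.

```python
import math

# Same value as DEFAULT_CROP_PCT in src/dataset.py.
DEFAULT_CROP_PCT = 0.875


def eval_resize_size(input_size, crop_pct=DEFAULT_CROP_PCT):
    """Return the size an image is resized to before the center crop."""
    if isinstance(input_size, tuple):
        assert len(input_size) == 2
        if input_size[0] == input_size[1]:
            # Square input: a single integer is enough.
            return int(math.floor(input_size[0] / crop_pct))
        # Non-square input: scale each side separately.
        return tuple(int(x / crop_pct) for x in input_size)
    return int(math.floor(input_size / crop_pct))


print(eval_resize_size(224))
print(eval_resize_size((224, 192)))
```

With this rule, an image is first resized to a slightly larger size and then center-cropped to the input size, so the crop keeps roughly the central 87.5% of each side.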
| # [Environment Requirements](#contents) | |||||
| - Hardware (GPU) | |||||
| - Framework | |||||
| - [MindSpore](https://www.mindspore.cn/install/en) | |||||
| - For more information, please check the resources below: | |||||
| - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html) | |||||
| - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) | |||||
| # [Script description](#contents) | |||||
| ## [Script and sample code](#contents) | |||||
| ``` | |||||
| .tinynet | |||||
| ├── Readme.md # descriptions about tinynet | |||||
| ├── script | |||||
| │ ├── eval.sh # evaluation script | |||||
| │ ├── train_1p_gpu.sh # training script on single GPU | |||||
| │ └── train_distributed_gpu.sh # distributed training script on multiple GPUs | |||||
| ├── src | |||||
| │ ├── callback.py # loss and checkpoint callbacks | |||||
| │ ├── dataset.py # data processing | |||||
| │ ├── loss.py # label-smoothing cross-entropy loss function | |||||
| │ ├── tinynet.py # tinynet architecture | |||||
| │ └── utils.py # utility functions | |||||
| ├── eval.py # evaluation interface | |||||
| └── train.py # training interface | |||||
| ``` | |||||
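The tree above lists src/loss.py as a label-smoothing cross-entropy loss. That file is not reproduced here, so the following NumPy sketch only illustrates the general technique under the assumption of uniform smoothing (each class receives `smooth_factor / num_classes`, and the true class receives the remaining mass). The function name and the exact smoothing variant are assumptions and may differ from the repository's implementation.

```python
import numpy as np


def label_smoothing_cross_entropy(logits, labels, smooth_factor=0.1):
    """Mean cross-entropy against uniformly smoothed one-hot targets."""
    num_classes = logits.shape[-1]
    # Numerically stable log-softmax.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Smoothed targets: every class gets smooth_factor / num_classes,
    # the true class additionally gets 1 - smooth_factor.
    targets = np.full_like(log_probs, smooth_factor / num_classes)
    targets[np.arange(len(labels)), labels] += 1.0 - smooth_factor
    return float(-(targets * log_probs).sum(axis=-1).mean())


logits = np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
labels = np.array([0, 2])
print(label_smoothing_cross_entropy(logits, labels))
```

Setting `smooth_factor=0.0` reduces this to the ordinary cross-entropy loss.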
| ## [Training process](#contents) | |||||
| ### Launch | |||||
| ``` | |||||
| # training on single GPU | |||||
| sh train_1p_gpu.sh | |||||
| # training on multiple GPUs, the number after -n indicates how many GPUs will be used for training | |||||
| sh train_distributed_gpu.sh -n 8 | |||||
| ``` | |||||
| Inside train_1p_gpu.sh and train_distributed_gpu.sh, there are hyperparameters that can be adjusted for training, for example: | |||||
| ``` | |||||
| --model tinynet_c model to be used for training | |||||
| --drop 0.2 dropout rate | |||||
| --drop-connect 0 drop connect rate | |||||
| --num-classes 1000 number of classes for training | |||||
| --opt-eps 0.001 optimizer's epsilon | |||||
| --lr 0.048 learning rate | |||||
| --batch-size 128 batch size | |||||
| --decay-epochs 2.4 learning rate decays every 2.4 epochs | |||||
| --warmup-lr 1e-6 warm up learning rate | |||||
| --warmup-epochs 3 learning rate warm up epoch | |||||
| --decay-rate 0.97 learning rate decay rate | |||||
| --ema-decay 0.9999 decay factor for model weights moving average | |||||
| --weight-decay 1e-5 optimizer's weight decay | |||||
| --epochs 450 number of epochs to be trained | |||||
| --ckpt_save_epoch 1 checkpoint saving interval | |||||
| --workers 8 number of processes for loading data | |||||
| --amp_level O0 training auto-mixed precision | |||||
| --opt rmsprop optimizers, currently we support SGD and RMSProp | |||||
| --data_path /path_to_ImageNet/ | |||||
| --GPU using GPU for training | |||||
| --dataset_sink using sink mode | |||||
| ``` | |||||
| The config above was used to train TinyNets on ImageNet (change drop-connect to 0.2 when training tinynet-b). | |||||
| > Checkpoints are saved in the ./device_{rank_id} folder (single GPU) | |||||
| or the ./device_parallel folder (multiple GPUs). | |||||
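To make the learning-rate options concrete, here is a minimal sketch of one schedule that is consistent with the flag descriptions above: a linear warm-up from `--warmup-lr` to `--lr` over `--warmup-epochs`, followed by a step decay by `--decay-rate` every `--decay-epochs`. The schedule actually built in train.py is not shown in this document, so both the function name `lr_at_epoch` and the exact shape of the schedule are assumptions.

```python
def lr_at_epoch(epoch, base_lr=0.048, warmup_lr=1e-6, warmup_epochs=3,
                decay_epochs=2.4, decay_rate=0.97):
    """Learning rate at a (possibly fractional) epoch index.

    Assumed schedule: linear warm-up, then multiply by decay_rate
    once for every decay_epochs that have elapsed.
    """
    if epoch < warmup_epochs:
        # Linear interpolation between warmup_lr and base_lr.
        return warmup_lr + (base_lr - warmup_lr) * epoch / warmup_epochs
    # Number of completed decay periods (floor division works on floats).
    num_decays = epoch // decay_epochs
    return base_lr * decay_rate ** num_decays


for epoch in (0, 1.5, 3, 10, 100):
    print(epoch, lr_at_epoch(epoch))
```

A per-step array of such values is what the `LossMonitor` callback in src/callback.py receives as `lr_array` and indexes by global step when logging.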
| ## [Evaluation process](#contents) | |||||
| ### Launch | |||||
| ``` | |||||
| # infer example | |||||
| sh eval.sh | |||||
| ``` | |||||
| Inside eval.sh, there are options that can be adjusted for inference, for example: | |||||
| ``` | |||||
| --num-classes 1000 | |||||
| --batch-size 128 | |||||
| --workers 8 | |||||
| --data_path /path_to_ImageNet/ | |||||
| --GPU | |||||
| --ckpt /path_to_EMA_checkpoint/ | |||||
| --dataset_sink > tinynet_c_eval.log 2>&1 & | |||||
| ``` | |||||
| > The checkpoint is produced during the training process. | |||||
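The `--ckpt` option above points at an EMA checkpoint. During training, `EmaEvalCallBack` in src/callback.py keeps a shadow copy of every parameter as NumPy arrays and updates it after each step with `shadow = (1 - decay) * param + decay * shadow`. The sketch below reproduces just that update rule on plain dictionaries; the helper name `ema_update` is hypothetical and no MindSpore objects are involved.

```python
import numpy as np


def ema_update(shadow, params, decay=0.9999):
    """Update the shadow weights in place with one EMA step."""
    for name, value in params.items():
        shadow[name] = (1.0 - decay) * value + decay * shadow[name]
    return shadow


# Toy example: a single weight vector that jumps from 0 to 1.
shadow = {"w": np.zeros(2)}
params = {"w": np.ones(2)}
for _ in range(3):
    ema_update(shadow, params, decay=0.9)
print(shadow["w"])
```

With a decay as large as 0.9999 the shadow weights change very slowly, which is why the averaged model is evaluated and saved separately from the raw model.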
| # [Model Description](#contents) | |||||
| ## [Performance](#contents) | |||||
| #### Evaluation Performance | |||||
| | Model | FLOPs | Latency* | ImageNet Top-1 | | |||||
| | ------------------- | ----- | -------- | -------------- | | |||||
| | EfficientNet-B0 | 387M | 99.85 ms | 76.7% | | |||||
| | TinyNet-A | 339M | 81.30 ms | 76.8% | | |||||
| | EfficientNet-B^{-4} | 24M | 11.54 ms | 56.7% | | |||||
| | TinyNet-E | 24M | 9.18 ms | 59.9% | | |||||
| *Latency is measured using MS Lite on Huawei P40 smartphone. | |||||
| *More details in [Paper](https://arxiv.org/abs/2010.14819). | |||||
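To read the table as relative numbers, the short script below computes the latency reduction and the top-1 gain of each TinyNet over the EfficientNet variant with comparable FLOPs. The figures are copied from the table above; the helper `compare` is only an illustration.

```python
# (FLOPs in millions, latency in ms, ImageNet top-1 in %), from the table above.
results = {
    "EfficientNet-B0": (387, 99.85, 76.7),
    "TinyNet-A": (339, 81.30, 76.8),
    "EfficientNet-B^{-4}": (24, 11.54, 56.7),
    "TinyNet-E": (24, 9.18, 59.9),
}


def compare(baseline, candidate):
    """Return (latency reduction in %, top-1 gain in percentage points)."""
    _, base_latency, base_acc = results[baseline]
    _, cand_latency, cand_acc = results[candidate]
    latency_cut = 100.0 * (base_latency - cand_latency) / base_latency
    return latency_cut, cand_acc - base_acc


for baseline, candidate in [("EfficientNet-B0", "TinyNet-A"),
                            ("EfficientNet-B^{-4}", "TinyNet-E")]:
    latency_cut, acc_gain = compare(baseline, candidate)
    print("%s vs %s: latency -%.1f%%, top-1 +%.1f points"
          % (candidate, baseline, latency_cut, acc_gain))
```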
| # [Description of Random Situation](#contents) | |||||
| The dataset seed is set in src/dataset.py. A random seed is also set in train.py. | |||||
| # [ModelZoo Homepage](#contents) | |||||
| Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). | |||||
| @@ -0,0 +1,101 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Inference Interface""" | |||||
| import sys | |||||
| import os | |||||
| import argparse | |||||
| from mindspore.train.model import Model | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| from mindspore.nn import Loss, Top1CategoricalAccuracy, Top5CategoricalAccuracy | |||||
| from mindspore import context | |||||
| from src.dataset import create_dataset_val | |||||
| from src.utils import count_params | |||||
| from src.loss import LabelSmoothingCrossEntropy | |||||
| from src.tinynet import tinynet | |||||
| parser = argparse.ArgumentParser(description='Evaluation') | |||||
| parser.add_argument('--data_path', type=str, default='/home/dataset/imagenet_jpeg/', | |||||
| metavar='DIR', help='path to dataset') | |||||
| parser.add_argument('--model', default='tinynet_c', type=str, metavar='MODEL', | |||||
| help='Name of model to evaluate (default: "tinynet_c")') | |||||
| parser.add_argument('--num-classes', type=int, default=1000, metavar='N', | |||||
| help='number of label classes (default: 1000)') | |||||
| parser.add_argument('--smoothing', type=float, default=0.1, | |||||
| help='label smoothing (default: 0.1)') | |||||
| parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N', | |||||
| help='input batch size for evaluation (default: 32)') | |||||
| parser.add_argument('-j', '--workers', type=int, default=4, metavar='N', | |||||
| help='how many data loading processes to use (default: 4)') | |||||
| parser.add_argument('--ckpt', type=str, default=None, | |||||
| help='model checkpoint to load') | |||||
| parser.add_argument('--GPU', action='store_true', default=True, | |||||
| help='Use GPU for training (default: True)') | |||||
| parser.add_argument('--dataset_sink', action='store_true', default=True) | |||||
| def main(): | |||||
| """Main entrance for evaluation""" | |||||
| args = parser.parse_args() | |||||
| print(sys.argv) | |||||
| context.set_context(mode=context.GRAPH_MODE) | |||||
| if args.GPU: | |||||
| context.set_context(device_target='GPU') | |||||
| # parse model argument | |||||
| assert args.model.startswith( | |||||
| "tinynet"), "Only Tinynet models are supported." | |||||
| _, sub_name = args.model.split("_") | |||||
| net = tinynet(sub_model=sub_name, | |||||
| num_classes=args.num_classes, | |||||
| drop_rate=0.0, | |||||
| drop_connect_rate=0.0, | |||||
| global_pool="avg", | |||||
| bn_tf=False, | |||||
| bn_momentum=None, | |||||
| bn_eps=None) | |||||
| print("Total number of parameters:", count_params(net)) | |||||
| input_size = net.default_cfg['input_size'][1] | |||||
| val_data_url = os.path.join(args.data_path, 'val') | |||||
| val_dataset = create_dataset_val(args.batch_size, | |||||
| val_data_url, | |||||
| workers=args.workers, | |||||
| distributed=False, | |||||
| input_size=input_size) | |||||
| loss = LabelSmoothingCrossEntropy(smooth_factor=args.smoothing, | |||||
| num_classes=args.num_classes) | |||||
| loss.add_flags_recursive(fp32=True, fp16=False) | |||||
| eval_metrics = {'Validation-Loss': Loss(), | |||||
| 'Top1-Acc': Top1CategoricalAccuracy(), | |||||
| 'Top5-Acc': Top5CategoricalAccuracy()} | |||||
| ckpt = load_checkpoint(args.ckpt) | |||||
| load_param_into_net(net, ckpt) | |||||
| net.set_train(False) | |||||
| model = Model(net, loss, metrics=eval_metrics) | |||||
| metrics = model.eval(val_dataset, dataset_sink_mode=False) | |||||
| print(metrics) | |||||
| if __name__ == '__main__': | |||||
| main() | |||||
| @@ -0,0 +1,42 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| cd ../ || exit | |||||
| current_exec_path=$(pwd) | |||||
| echo ${current_exec_path} | |||||
| export RANK_SIZE=1 | |||||
| export start=0 | |||||
| export value=$((start + RANK_SIZE)) | |||||
| export curtime | |||||
| curtime=$(date '+%Y%m%d-%H%M%S') | |||||
| echo "$curtime" | |||||
| rm ${current_exec_path}/device${start}_$curtime/ -rf | |||||
| mkdir ${current_exec_path}/device${start}_$curtime | |||||
| cd ${current_exec_path}/device${start}_$curtime || exit | |||||
| export RANK_ID=$start | |||||
| export DEVICE_ID=$start | |||||
| time python3 ${current_exec_path}/eval.py \ | |||||
| --model tinynet_c \ | |||||
| --num-classes 1000 \ | |||||
| --batch-size 128 \ | |||||
| --workers 8 \ | |||||
| --data_path /path_to_ImageNet/ \ | |||||
| --GPU \ | |||||
| --ckpt /path_to_ckpt/ \ | |||||
| --dataset_sink > tinynet_c_eval.log 2>&1 & | |||||
| @@ -0,0 +1,59 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| cd ../ || exit | |||||
| current_exec_path=$(pwd) | |||||
| echo ${current_exec_path} | |||||
| export RANK_SIZE=1 | |||||
| export start=0 | |||||
| export value=$(($start+$RANK_SIZE)) | |||||
| export curtime | |||||
| curtime=$(date '+%Y%m%d-%H%M%S') | |||||
| echo $curtime | |||||
| echo "rank_id = ${start}" | |||||
| rm ${current_exec_path}/device_$start/ -rf | |||||
| mkdir ${current_exec_path}/device_$start | |||||
| cd ${current_exec_path}/device_$start || exit | |||||
| export RANK_ID=$start | |||||
| export DEVICE_ID=$start | |||||
| time python3 ${current_exec_path}/train.py \ | |||||
| --model tinynet_c \ | |||||
| --drop 0.2 \ | |||||
| --drop-connect 0 \ | |||||
| --num-classes 1000 \ | |||||
| --opt-eps 0.001 \ | |||||
| --lr 0.048 \ | |||||
| --batch-size 128 \ | |||||
| --decay-epochs 2.4 \ | |||||
| --warmup-lr 1e-6 \ | |||||
| --warmup-epochs 3 \ | |||||
| --decay-rate 0.97 \ | |||||
| --ema-decay 0.9999 \ | |||||
| --weight-decay 1e-5 \ | |||||
| --epochs 100 \ | |||||
| --ckpt_save_epoch 1 \ | |||||
| --workers 8 \ | |||||
| --amp_level O0 \ | |||||
| --opt rmsprop \ | |||||
| --data_path /path_to_ImageNet/ \ | |||||
| --GPU \ | |||||
| --dataset_sink > tinynet_c.log 2>&1 & | |||||
| cd ${current_exec_path} || exit | |||||
| @@ -0,0 +1,82 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| # below help function was adapted from | |||||
| # https://unix.stackexchange.com/questions/31414/how-can-i-pass-a-command-line-argument-into-a-shell-script | |||||
| helpFunction() | |||||
| { | |||||
| echo "" | |||||
| echo "Usage: $0 -n num_device" | |||||
| echo -e "\t-n how many gpus to use for training" | |||||
| exit 1 # Exit script after printing help | |||||
| } | |||||
| while getopts "n:" opt | |||||
| do | |||||
| case "$opt" in | |||||
| n ) num_device="$OPTARG" ;; | |||||
| ? ) helpFunction ;; # Print helpFunction in case parameter is non-existent | |||||
| esac | |||||
| done | |||||
| # Print helpFunction in case parameters are empty | |||||
| if [ -z "$num_device" ] | |||||
| then | |||||
| echo "Some or all of the parameters are empty"; | |||||
| helpFunction | |||||
| fi | |||||
| # Begin script in case all parameters are correct | |||||
| echo "$num_device" | |||||
| cd ../ || exit | |||||
| current_exec_path=$(pwd) | |||||
| echo ${current_exec_path} | |||||
| export SLOG_PRINT_TO_STDOUT=0 | |||||
| export RANK_SIZE=$num_device | |||||
| export curtime | |||||
| curtime=$(date '+%Y%m%d-%H%M%S') | |||||
| echo $curtime | |||||
| echo $curtime >> starttime | |||||
| rm ${current_exec_path}/device_parallel/ -rf | |||||
| mkdir ${current_exec_path}/device_parallel | |||||
| cd ${current_exec_path}/device_parallel || exit | |||||
| echo $curtime >> starttime | |||||
| time mpirun -n $RANK_SIZE --allow-run-as-root python3 ${current_exec_path}/train.py \ | |||||
| --model tinynet_c \ | |||||
| --drop 0.2 \ | |||||
| --drop-connect 0 \ | |||||
| --num-classes 1000 \ | |||||
| --opt-eps 0.001 \ | |||||
| --lr 0.048 \ | |||||
| --batch-size 128 \ | |||||
| --decay-epochs 2.4 \ | |||||
| --warmup-lr 1e-6 \ | |||||
| --warmup-epochs 3 \ | |||||
| --decay-rate 0.97 \ | |||||
| --ema-decay 0.9999 \ | |||||
| --weight-decay 1e-5 \ | |||||
| --per_print_times 100 \ | |||||
| --epochs 450 \ | |||||
| --ckpt_save_epoch 1 \ | |||||
| --workers 8 \ | |||||
| --amp_level O0 \ | |||||
| --opt rmsprop \ | |||||
| --distributed \ | |||||
| --data_path /path_to_ImageNet/ \ | |||||
| --GPU \ | |||||
| --dataset_sink > tinynet_c.log 2>&1 & | |||||
| @@ -0,0 +1,203 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """custom callbacks for ema and loss""" | |||||
| from copy import deepcopy | |||||
| import numpy as np | |||||
| from mindspore.train.callback import Callback | |||||
| from mindspore.common.parameter import Parameter | |||||
| from mindspore.train.serialization import save_checkpoint | |||||
| from mindspore.nn import Loss, Top1CategoricalAccuracy, Top5CategoricalAccuracy | |||||
| from mindspore.train.model import Model | |||||
| from mindspore import Tensor | |||||
| def load_nparray_into_net(net, array_dict): | |||||
| """ | |||||
| Loads dictionary of numpy arrays into network. | |||||
| Args: | |||||
| net (Cell): Cell network. | |||||
| array_dict (dict): dictionary of numpy array format model weights. | |||||
| """ | |||||
| param_not_load = [] | |||||
| for _, param in net.parameters_and_names(): | |||||
| if param.name in array_dict: | |||||
| new_param = array_dict[param.name] | |||||
| param.set_data(Parameter(new_param.copy(), name=param.name)) | |||||
| else: | |||||
| param_not_load.append(param.name) | |||||
| return param_not_load | |||||
| class EmaEvalCallBack(Callback): | |||||
| """ | |||||
| Call back that will evaluate the model and save model checkpoint at | |||||
| the end of training epoch. | |||||
| Args: | |||||
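| model: MindSpore model instance. | |||||
| ema_network: network that is loaded with the step-wise exponential moving average of the weights for evaluation. | |||||
| eval_dataset: the evaluation dataset. | |||||
| loss_fn: loss function used when evaluating the EMA network. | |||||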
| decay (float): ema decay. | |||||
| save_epoch (int): defines how often to save checkpoint. | |||||
| dataset_sink_mode (bool): whether to use data sink mode. | |||||
| start_epoch (int): which epoch to start/resume training. | |||||
| """ | |||||
| def __init__(self, model, ema_network, eval_dataset, loss_fn, decay=0.999, | |||||
| save_epoch=1, dataset_sink_mode=True, start_epoch=0): | |||||
| self.model = model | |||||
| self.ema_network = ema_network | |||||
| self.eval_dataset = eval_dataset | |||||
| self.loss_fn = loss_fn | |||||
| self.decay = decay | |||||
| self.save_epoch = save_epoch | |||||
| self.shadow = {} | |||||
| self.ema_accuracy = {} | |||||
| self.best_ema_accuracy = 0 | |||||
| self.best_accuracy = 0 | |||||
| self.best_ema_epoch = 0 | |||||
| self.best_epoch = 0 | |||||
| self._start_epoch = start_epoch | |||||
| self.eval_metrics = {'Validation-Loss': Loss(), | |||||
| 'Top1-Acc': Top1CategoricalAccuracy(), | |||||
| 'Top5-Acc': Top5CategoricalAccuracy()} | |||||
| self.dataset_sink_mode = dataset_sink_mode | |||||
| def begin(self, run_context): | |||||
| """Initialize the EMA parameters """ | |||||
| cb_params = run_context.original_args() | |||||
| for _, param in cb_params.network.parameters_and_names(): | |||||
| self.shadow[param.name] = deepcopy(param.data.asnumpy()) | |||||
| def step_end(self, run_context): | |||||
| """Update the EMA parameters""" | |||||
| cb_params = run_context.original_args() | |||||
| for _, param in cb_params.network.parameters_and_names(): | |||||
| new_average = (1.0 - self.decay) * param.data.asnumpy().copy() + \ | |||||
| self.decay * self.shadow[param.name] | |||||
| self.shadow[param.name] = new_average | |||||
| def epoch_end(self, run_context): | |||||
| """evaluate the model and ema-model at the end of each epoch""" | |||||
| cb_params = run_context.original_args() | |||||
| cur_epoch = cb_params.cur_epoch_num + self._start_epoch - 1 | |||||
| save_ckpt = (cur_epoch % self.save_epoch == 0) | |||||
| acc = self.model.eval( | |||||
| self.eval_dataset, dataset_sink_mode=self.dataset_sink_mode) | |||||
| print("Model Accuracy:", acc) | |||||
| load_nparray_into_net(self.ema_network, self.shadow) | |||||
| self.ema_network.set_train(False) | |||||
| model_ema = Model(self.ema_network, loss_fn=self.loss_fn, | |||||
| metrics=self.eval_metrics) | |||||
| ema_acc = model_ema.eval( | |||||
| self.eval_dataset, dataset_sink_mode=self.dataset_sink_mode) | |||||
| print("EMA-Model Accuracy:", ema_acc) | |||||
| self.ema_accuracy[cur_epoch] = ema_acc["Top1-Acc"] | |||||
| output = [{"name": k, "data": Tensor(v)} | |||||
| for k, v in self.shadow.items()] | |||||
| if self.best_ema_accuracy < ema_acc["Top1-Acc"]: | |||||
| self.best_ema_accuracy = ema_acc["Top1-Acc"] | |||||
| self.best_ema_epoch = cur_epoch | |||||
| save_checkpoint(output, "ema_best.ckpt") | |||||
| if self.best_accuracy < acc["Top1-Acc"]: | |||||
| self.best_accuracy = acc["Top1-Acc"] | |||||
| self.best_epoch = cur_epoch | |||||
| print("Best Model Accuracy: %s, at epoch %s" % | |||||
| (self.best_accuracy, self.best_epoch)) | |||||
| print("Best EMA-Model Accuracy: %s, at epoch %s" % | |||||
| (self.best_ema_accuracy, self.best_ema_epoch)) | |||||
| if save_ckpt: | |||||
| # Save the ema_model checkpoints | |||||
| ckpt = "{}-{}.ckpt".format("ema", cur_epoch) | |||||
| save_checkpoint(output, ckpt) | |||||
| save_checkpoint(output, "ema_last.ckpt") | |||||
| # Save the model checkpoints | |||||
| save_checkpoint(cb_params.train_network, "last.ckpt") | |||||
| print("Top 10 EMA-Model Accuracies: ") | |||||
| count = 0 | |||||
| for epoch in sorted(self.ema_accuracy, key=self.ema_accuracy.get, | |||||
| reverse=True): | |||||
| if count == 10: | |||||
| break | |||||
| print("epoch: %s, Top-1: %s" % (epoch, self.ema_accuracy[epoch])) | |||||
| count += 1 | |||||
| class LossMonitor(Callback): | |||||
| """ | |||||
| Monitor the loss in training. | |||||
| If the loss is NAN or INF, it will terminate training. | |||||
| Note: | |||||
| If per_print_times is 0, do not print loss. | |||||
| Args: | |||||
| lr_array (numpy.array): scheduled learning rate. | |||||
| total_epochs (int): Total number of epochs for training. | |||||
| per_print_times (int): Print the loss every per_print_times steps. Default: 1. | |||||
| start_epoch (int): which epoch to start, used when resume from a | |||||
| certain epoch. | |||||
| Raises: | |||||
| ValueError: If per_print_times is not an integer or less than zero. | |||||
| """ | |||||
| def __init__(self, lr_array, total_epochs, per_print_times=1, start_epoch=0): | |||||
| super(LossMonitor, self).__init__() | |||||
| if not isinstance(per_print_times, int) or per_print_times < 0: | |||||
| raise ValueError("per_print_times must be int and >= 0.") | |||||
| self._per_print_times = per_print_times | |||||
| self._lr_array = lr_array | |||||
| self._total_epochs = total_epochs | |||||
| self._start_epoch = start_epoch | |||||
| def step_end(self, run_context): | |||||
| """log epoch, step, loss and learning rate""" | |||||
| cb_params = run_context.original_args() | |||||
| loss = cb_params.net_outputs | |||||
| cur_epoch_num = cb_params.cur_epoch_num + self._start_epoch - 1 | |||||
| if isinstance(loss, (tuple, list)): | |||||
| if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray): | |||||
| loss = loss[0] | |||||
| if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray): | |||||
| loss = np.mean(loss.asnumpy()) | |||||
| global_step = cb_params.cur_step_num - 1 | |||||
| cur_step_in_epoch = global_step % cb_params.batch_num + 1 | |||||
| if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)): | |||||
| raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format( | |||||
| cur_epoch_num, cur_step_in_epoch)) | |||||
| if self._per_print_times != 0 and cur_step_in_epoch % self._per_print_times == 0: | |||||
| print("epoch: %s/%s, step: %s/%s, loss is %s, learning rate: %s" | |||||
| % (cur_epoch_num, self._total_epochs, cur_step_in_epoch, | |||||
| cb_params.batch_num, loss, self._lr_array[global_step]), | |||||
| flush=True) | |||||
| @@ -0,0 +1,143 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Data operations, will be used in train.py and eval.py""" | |||||
| import math | |||||
| import os | |||||
| import numpy as np | |||||
| import mindspore.dataset.vision.py_transforms as py_vision | |||||
| import mindspore.dataset.transforms.py_transforms as py_transforms | |||||
| import mindspore.dataset.transforms.c_transforms as c_transforms | |||||
| import mindspore.common.dtype as mstype | |||||
| import mindspore.dataset as ds | |||||
| from mindspore.communication.management import get_rank, get_group_size | |||||
| from mindspore.dataset.vision import Inter | |||||
| # values that should remain constant | |||||
| DEFAULT_CROP_PCT = 0.875 | |||||
| IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) | |||||
| IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) | |||||
| # data preprocess configs | |||||
| SCALE = (0.08, 1.0) | |||||
| RATIO = (3./4., 4./3.) | |||||
| ds.config.set_seed(1) | |||||
| def split_imgs_and_labels(imgs, labels, batchInfo): | |||||
| """split data into labels and images""" | |||||
| ret_imgs = [] | |||||
| ret_labels = [] | |||||
| for i, image in enumerate(imgs): | |||||
| ret_imgs.append(image) | |||||
| ret_labels.append(labels[i]) | |||||
| return np.array(ret_imgs), np.array(ret_labels) | |||||
| def create_dataset(batch_size, train_data_url='', workers=8, distributed=False, | |||||
| input_size=224, color_jitter=0.4): | |||||
| """Create ImageNet training dataset""" | |||||
| if not os.path.exists(train_data_url): | |||||
| raise ValueError('Path does not exist') | |||||
| decode_op = py_vision.Decode() | |||||
| type_cast_op = c_transforms.TypeCast(mstype.int32) | |||||
| random_resize_crop_bicubic = py_vision.RandomResizedCrop(size=(input_size, input_size), | |||||
| scale=SCALE, ratio=RATIO, | |||||
| interpolation=Inter.BICUBIC) | |||||
| random_horizontal_flip_op = py_vision.RandomHorizontalFlip(0.5) | |||||
| adjust_range = (max(0, 1 - color_jitter), 1 + color_jitter) | |||||
| random_color_jitter_op = py_vision.RandomColorAdjust(brightness=adjust_range, | |||||
| contrast=adjust_range, | |||||
| saturation=adjust_range) | |||||
| to_tensor = py_vision.ToTensor() | |||||
| normalize_op = py_vision.Normalize( | |||||
| IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) | |||||
| # assemble all the transforms | |||||
| image_ops = py_transforms.Compose([decode_op, random_resize_crop_bicubic, | |||||
| random_horizontal_flip_op, random_color_jitter_op, to_tensor, normalize_op]) | |||||
| rank_id = get_rank() if distributed else 0 | |||||
| rank_size = get_group_size() if distributed else 1 | |||||
| dataset_train = ds.ImageFolderDataset(train_data_url, | |||||
| num_parallel_workers=workers, | |||||
| shuffle=True, | |||||
| num_shards=rank_size, | |||||
| shard_id=rank_id) | |||||
| dataset_train = dataset_train.map(input_columns=["image"], | |||||
| operations=image_ops, | |||||
| num_parallel_workers=workers) | |||||
| dataset_train = dataset_train.map(input_columns=["label"], | |||||
| operations=type_cast_op, | |||||
| num_parallel_workers=workers) | |||||
| # batch dealing | |||||
| ds_train = dataset_train.batch(batch_size, | |||||
| per_batch_map=split_imgs_and_labels, | |||||
| input_columns=["image", "label"], | |||||
| num_parallel_workers=2, | |||||
| drop_remainder=True) | |||||
| ds_train = ds_train.repeat(1) | |||||
| return ds_train | |||||
| def create_dataset_val(batch_size=128, val_data_url='', workers=8, distributed=False, | |||||
| input_size=224): | |||||
| """Create ImageNet validation dataset""" | |||||
| if not os.path.exists(val_data_url): | |||||
| raise ValueError('Path does not exist') | |||||
| rank_id = get_rank() if distributed else 0 | |||||
| rank_size = get_group_size() if distributed else 1 | |||||
| dataset = ds.ImageFolderDataset(val_data_url, num_parallel_workers=workers, | |||||
| num_shards=rank_size, shard_id=rank_id) | |||||
| scale_size = None | |||||
| if isinstance(input_size, tuple): | |||||
| assert len(input_size) == 2 | |||||
| if input_size[-1] == input_size[-2]: | |||||
| scale_size = int(math.floor(input_size[0] / DEFAULT_CROP_PCT)) | |||||
| else: | |||||
| scale_size = tuple([int(x / DEFAULT_CROP_PCT) for x in input_size]) | |||||
| else: | |||||
| scale_size = int(math.floor(input_size / DEFAULT_CROP_PCT)) | |||||
| type_cast_op = c_transforms.TypeCast(mstype.int32) | |||||
| decode_op = py_vision.Decode() | |||||
| resize_op = py_vision.Resize(size=scale_size, interpolation=Inter.BICUBIC) | |||||
| center_crop = py_vision.CenterCrop(size=input_size) | |||||
| to_tensor = py_vision.ToTensor() | |||||
| normalize_op = py_vision.Normalize( | |||||
| IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) | |||||
| image_ops = py_transforms.Compose([decode_op, resize_op, center_crop, | |||||
| to_tensor, normalize_op]) | |||||
| dataset = dataset.map(input_columns=["label"], operations=type_cast_op, | |||||
| num_parallel_workers=workers) | |||||
| dataset = dataset.map(input_columns=["image"], operations=image_ops, | |||||
| num_parallel_workers=workers) | |||||
| dataset = dataset.batch(batch_size, per_batch_map=split_imgs_and_labels, | |||||
| input_columns=["image", "label"], | |||||
| num_parallel_workers=2, | |||||
| drop_remainder=True) | |||||
| dataset = dataset.repeat(1) | |||||
| return dataset | |||||
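The resize-then-crop arithmetic in `create_dataset_val` can be checked in isolation. The sketch below reproduces only the `scale_size` computation in plain Python. `compute_scale_size` is a hypothetical helper name, and `DEFAULT_CROP_PCT = 0.875` is an assumption: the constant is defined outside this excerpt, and 0.875 is simply the `crop_pct` default that appears in `_cfg`.

```python
import math

# Assumed value; the real constant is defined outside this excerpt.
DEFAULT_CROP_PCT = 0.875


def compute_scale_size(input_size, crop_pct=DEFAULT_CROP_PCT):
    """Return the resize target used before the center crop."""
    if isinstance(input_size, tuple):
        assert len(input_size) == 2
        if input_size[0] == input_size[1]:
            # Square input: a single int, so the shorter side is resized.
            return int(math.floor(input_size[0] / crop_pct))
        # Non-square input: scale each side independently.
        return tuple(int(x / crop_pct) for x in input_size)
    return int(math.floor(input_size / crop_pct))


print(compute_scale_size(224))         # image is resized to this, then cropped to 224
print(compute_scale_size((224, 192)))
```

Under this assumption a 224-pixel crop is taken from an image resized to 256 pixels, so the crop keeps the central 87.5% of each side.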
| @@ -0,0 +1,44 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """define loss function for network.""" | |||||
| from mindspore.nn.loss.loss import _Loss | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.ops import functional as F | |||||
| from mindspore import Tensor | |||||
| from mindspore.common import dtype as mstype | |||||
| import mindspore.nn as nn | |||||
| class LabelSmoothingCrossEntropy(_Loss): | |||||
| """cross-entropy with label smoothing""" | |||||
| def __init__(self, smooth_factor=0.1, num_classes=1000): | |||||
| super(LabelSmoothingCrossEntropy, self).__init__() | |||||
| self.onehot = P.OneHot() | |||||
| self.on_value = Tensor(1.0 - smooth_factor, mstype.float32) | |||||
| self.off_value = Tensor(1.0 * smooth_factor / | |||||
| (num_classes - 1), mstype.float32) | |||||
| self.ce = nn.SoftmaxCrossEntropyWithLogits() | |||||
| self.mean = P.ReduceMean(False) | |||||
| self.cast = P.Cast() | |||||
| def construct(self, logits, label): | |||||
| label = self.cast(label, mstype.int32) | |||||
| one_hot_label = self.onehot(label, F.shape( | |||||
| logits)[1], self.on_value, self.off_value) | |||||
| loss_logit = self.ce(logits, one_hot_label) | |||||
| loss_logit = self.mean(loss_logit, 0) | |||||
| return loss_logit | |||||
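To make the smoothing in `LabelSmoothingCrossEntropy` concrete, here is a minimal plain-Python sketch for a single sample. It is not the MindSpore implementation: `smoothed_targets` and `label_smoothing_ce` are hypothetical names, and the cross-entropy is assumed to be the standard softmax cross-entropy against a dense target vector. The on/off values follow the formulas in `__init__` above.

```python
import math


def smoothed_targets(label, num_classes, smooth_factor=0.1):
    """Dense target vector: 1 - s on the true class, s / (K - 1) elsewhere."""
    on_value = 1.0 - smooth_factor
    off_value = smooth_factor / (num_classes - 1)
    return [on_value if i == label else off_value for i in range(num_classes)]


def label_smoothing_ce(logits, label, smooth_factor=0.1):
    """Softmax cross-entropy of one sample against the smoothed targets."""
    targets = smoothed_targets(label, len(logits), smooth_factor)
    # Numerically stable log-softmax.
    max_logit = max(logits)
    log_sum_exp = max_logit + math.log(sum(math.exp(z - max_logit) for z in logits))
    log_probs = [z - log_sum_exp for z in logits]
    return -sum(t * lp for t, lp in zip(targets, log_probs))


print(smoothed_targets(2, 5, 0.1))
print(label_smoothing_ce([2.0, 0.5, -1.0, 0.0], label=0))
```

Because the off value is divided by `num_classes - 1`, the smoothed targets still sum to 1, so the loss remains a proper cross-entropy between two distributions.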
| @@ -0,0 +1,808 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Tinynet model definition""" | |||||
| import math | |||||
| import re | |||||
| from copy import deepcopy | |||||
| import mindspore.nn as nn | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.common.initializer import Normal, Zero, One, initializer, Uniform | |||||
| from mindspore import context, ms_function | |||||
| from mindspore.common.parameter import Parameter | |||||
| # Imagenet constant values | |||||
| IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) | |||||
| IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) | |||||
| # model structure configurations for TinyNets, values are | |||||
| # (resolution multiplier, channel multiplier, depth multiplier) | |||||
| # only tinynet-c is available for now; we will release other tinynet | |||||
| # models soon | |||||
| # codes are inspired and partially adapted from | |||||
| # https://github.com/rwightman/gen-efficientnet-pytorch | |||||
| TINYNET_CFG = {"c": (0.825, 0.54, 0.85)} | |||||
| relu = P.ReLU() | |||||
| sigmoid = P.Sigmoid() | |||||
| def _cfg(url='', **kwargs): | |||||
| return { | |||||
| 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), | |||||
| 'crop_pct': 0.875, 'interpolation': 'bicubic', | |||||
| 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, | |||||
| 'first_conv': 'conv_stem', 'classifier': 'classifier', | |||||
| **kwargs | |||||
| } | |||||
| default_cfgs = { | |||||
| 'efficientnet_b0': _cfg( | |||||
| url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0-d6904d92.pth'), | |||||
| 'efficientnet_b1': _cfg( | |||||
| url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth', | |||||
| input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), | |||||
| 'efficientnet_b2': _cfg( | |||||
| url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2-cf78dc4d.pth', | |||||
| input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), | |||||
| 'efficientnet_b3': _cfg( | |||||
| url='', input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), | |||||
| 'efficientnet_b4': _cfg( | |||||
| url='', input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), | |||||
| } | |||||
| _DEBUG = False | |||||
| # Default args for PyTorch BN impl | |||||
| _BN_MOMENTUM_PT_DEFAULT = 0.1 | |||||
| _BN_EPS_PT_DEFAULT = 1e-5 | |||||
| _BN_ARGS_PT = dict(momentum=_BN_MOMENTUM_PT_DEFAULT, eps=_BN_EPS_PT_DEFAULT) | |||||
| # Defaults used for Google/Tensorflow training of mobile networks /w | |||||
| # RMSprop as per papers and TF reference implementations. PT momentum | |||||
| # equiv for TF decay is (1 - TF decay) | |||||
| # NOTE: momentum varies btw .99 and .9997 depending on source | |||||
| # .99 in official TF TPU impl | |||||
| # .9997 (/w .999 in search space) for paper | |||||
| _BN_MOMENTUM_TF_DEFAULT = 1 - 0.99 | |||||
| _BN_EPS_TF_DEFAULT = 1e-3 | |||||
| _BN_ARGS_TF = dict(momentum=_BN_MOMENTUM_TF_DEFAULT, eps=_BN_EPS_TF_DEFAULT) | |||||
| def _initialize_weight_goog(shape=None, layer_type='conv', bias=False): | |||||
| """Google style weight initialization""" | |||||
| if layer_type not in ('conv', 'bn', 'fc'): | |||||
| raise ValueError( | |||||
| 'The layer type is not known, the supported are conv, bn and fc') | |||||
| if bias: | |||||
| return Zero() | |||||
| if layer_type == 'conv': | |||||
| assert isinstance(shape, (tuple, list)) and len( | |||||
| shape) == 3, 'The shape must be 3 scalars, and are in_chs, ks, out_chs respectively' | |||||
| n = shape[1] * shape[1] * shape[2] | |||||
| return Normal(math.sqrt(2.0 / n)) | |||||
| if layer_type == 'bn': | |||||
| return One() | |||||
| assert isinstance(shape, (tuple, list)) and len( | |||||
| shape) == 2, 'The shape must be 2 scalars, and are in_chs, out_chs respectively' | |||||
| n = shape[1] | |||||
| init_range = 1.0 / math.sqrt(n) | |||||
| return Uniform(init_range) | |||||
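The scale values that `_initialize_weight_goog` hands to the initializers are easy to reproduce by hand. The sketch below computes only those numbers in plain Python; `conv_init_std` and `fc_init_range` are hypothetical helper names, and it is assumed that `Normal` takes a standard deviation and `Uniform` takes a symmetric half-range.

```python
import math


def conv_init_std(kernel_size, out_chs):
    """Std for conv weights: sqrt(2 / fan_out), fan_out = ks * ks * out_chs."""
    n = kernel_size * kernel_size * out_chs
    return math.sqrt(2.0 / n)


def fc_init_range(out_chs):
    """Half-range for dense weights: 1 / sqrt(out_chs)."""
    return 1.0 / math.sqrt(out_chs)


print(conv_init_std(3, 32))   # 3x3 conv with 32 output channels
print(fc_init_range(1000))    # classifier with 1000 outputs
```

Note that the input channel count is passed in the shape tuple but does not enter either formula; only the kernel size and the output channel count are used.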
| def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, | |||||
| pad_mode='same', bias=False): | |||||
| """convolution wrapper""" | |||||
| weight_init_value = _initialize_weight_goog( | |||||
| shape=(in_channels, kernel_size, out_channels)) | |||||
| bias_init_value = _initialize_weight_goog(bias=True) if bias else None | |||||
| if bias: | |||||
| return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, | |||||
| padding=padding, pad_mode=pad_mode, weight_init=weight_init_value, | |||||
| has_bias=bias, bias_init=bias_init_value) | |||||
| return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, | |||||
| padding=padding, pad_mode=pad_mode, weight_init=weight_init_value, | |||||
| has_bias=bias) | |||||
| def _conv1x1(in_channels, out_channels, stride=1, padding=0, pad_mode='same', bias=False): | |||||
| """1x1 convolution wrapper""" | |||||
| weight_init_value = _initialize_weight_goog( | |||||
| shape=(in_channels, 1, out_channels)) | |||||
| bias_init_value = _initialize_weight_goog(bias=True) if bias else None | |||||
| if bias: | |||||
| return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, | |||||
| padding=padding, pad_mode=pad_mode, weight_init=weight_init_value, | |||||
| has_bias=bias, bias_init=bias_init_value) | |||||
| return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, | |||||
| padding=padding, pad_mode=pad_mode, weight_init=weight_init_value, | |||||
| has_bias=bias) | |||||
| def _conv_group(in_channels, out_channels, group, kernel_size=3, stride=1, padding=0, | |||||
| pad_mode='same', bias=False): | |||||
| """group convolution wrapper""" | |||||
| weight_init_value = _initialize_weight_goog( | |||||
| shape=(in_channels, kernel_size, out_channels)) | |||||
| bias_init_value = _initialize_weight_goog(bias=True) if bias else None | |||||
| if bias: | |||||
| return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, | |||||
| padding=padding, pad_mode=pad_mode, weight_init=weight_init_value, | |||||
| group=group, has_bias=bias, bias_init=bias_init_value) | |||||
| return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, | |||||
| padding=padding, pad_mode=pad_mode, weight_init=weight_init_value, | |||||
| group=group, has_bias=bias) | |||||
| def _fused_bn(channels, momentum=0.1, eps=1e-4, gamma_init=1, beta_init=0): | |||||
| return nn.BatchNorm2d(channels, eps=eps, momentum=1-momentum, gamma_init=gamma_init, | |||||
| beta_init=beta_init) | |||||
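`_fused_bn` passes `1 - momentum` because the BN argument dictionaries above follow the PyTorch convention (momentum is the weight of the new batch statistic), while `nn.BatchNorm2d` in MindSpore is assumed here to use the complementary convention (momentum is the weight of the running statistic). A small sketch of the two conversions, with hypothetical helper names:

```python
def tf_decay_to_pt_momentum(decay):
    """TF-style decay -> PyTorch-style momentum, as noted in the comments above."""
    return 1.0 - decay


def pt_momentum_to_ms(momentum):
    """PyTorch-style momentum -> the value _fused_bn forwards to nn.BatchNorm2d."""
    return 1.0 - momentum


pt_default = 0.1                               # _BN_MOMENTUM_PT_DEFAULT
tf_default = tf_decay_to_pt_momentum(0.99)     # _BN_MOMENTUM_TF_DEFAULT
print(pt_momentum_to_ms(pt_default))
print(pt_momentum_to_ms(tf_default))
```

Under that assumption, the TF default round-trips back to roughly the original decay of 0.99.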
| def _dense(in_channels, output_channels, bias=True, activation=None): | |||||
| weight_init_value = _initialize_weight_goog(shape=(in_channels, output_channels), | |||||
| layer_type='fc') | |||||
| bias_init_value = _initialize_weight_goog(bias=True) if bias else None | |||||
| if bias: | |||||
| return nn.Dense(in_channels, output_channels, weight_init=weight_init_value, | |||||
| bias_init=bias_init_value, has_bias=bias, activation=activation) | |||||
| return nn.Dense(in_channels, output_channels, weight_init=weight_init_value, | |||||
| has_bias=bias, activation=activation) | |||||
| def _resolve_bn_args(kwargs): | |||||
| bn_args = _BN_ARGS_TF.copy() if kwargs.pop( | |||||
| 'bn_tf', False) else _BN_ARGS_PT.copy() | |||||
| bn_momentum = kwargs.pop('bn_momentum', None) | |||||
| if bn_momentum is not None: | |||||
| bn_args['momentum'] = bn_momentum | |||||
| bn_eps = kwargs.pop('bn_eps', None) | |||||
| if bn_eps is not None: | |||||
| bn_args['eps'] = bn_eps | |||||
| return bn_args | |||||
| def _round_channels(channels, multiplier=1.0, divisor=8, channel_min=None): | |||||
| """Round number of channels based on the channel multiplier.""" | |||||
| if not multiplier: | |||||
| return channels | |||||
| channels *= multiplier | |||||
| channel_min = channel_min or divisor | |||||
| new_channels = max( | |||||
| int(channels + divisor / 2) // divisor * divisor, | |||||
| channel_min) | |||||
| # Make sure that round down does not go down by more than 10%. | |||||
| if new_channels < 0.9 * channels: | |||||
| new_channels += divisor | |||||
| return new_channels | |||||
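A few worked values show how the rounding in `_round_channels` behaves. The sketch below restates the same arithmetic as a standalone function (named `round_channels` here only to avoid clashing with the original) and applies the tinynet-c channel multiplier of 0.54 from `TINYNET_CFG`. The input widths 32, 1280 and 11 are illustrative choices, not values taken from the model definition.

```python
def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
    """Scale a channel count and snap it to a multiple of `divisor`."""
    if not multiplier:
        return channels
    channels *= multiplier
    channel_min = channel_min or divisor
    # Round to the nearest multiple of divisor, but never below channel_min.
    new_channels = max(int(channels + divisor / 2) // divisor * divisor, channel_min)
    # Rounding down must not lose more than 10% of the scaled width.
    if new_channels < 0.9 * channels:
        new_channels += divisor
    return new_channels


print(round_channels(32, 0.54))    # 17.28 scaled channels, snapped to a multiple of 8
print(round_channels(1280, 0.54))  # 691.2 scaled channels
print(round_channels(11))          # rounding down to 8 would lose more than 10%
```

The last call exercises the 10% guard: 11 would round down to 8, which is below 9.9, so one more `divisor` is added.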
| def _parse_ksize(ss): | |||||
| if ss.isdigit(): | |||||
| return int(ss) | |||||
| return [int(k) for k in ss.split('.')] | |||||
| def _decode_block_str(block_str, depth_multiplier=1.0): | |||||
| """ Decode block definition string | |||||
| Gets a list of block arg (dicts) through a string notation of arguments. | |||||
| E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip | |||||
| All args can exist in any order with the exception of the leading string which | |||||
| is assumed to indicate the block type. | |||||
| leading string - block type ( | |||||
| ir = InvertedResidual, ds = DepthwiseSep, dsa = DepthwiseSep with pw act, cn = ConvBnAct) | |||||
| r - number of repeat blocks, | |||||
| k - kernel size, | |||||
| s - strides (1-9), | |||||
| e - expansion ratio, | |||||
| c - output channels, | |||||
| se - squeeze/excitation ratio | |||||
| n - activation fn ('re', 'r6', 'hs', or 'sw') | |||||
| Args: | |||||
| block_str: a string representation of block arguments. | |||||
| Returns: | |||||
| A list of block args (dicts) | |||||
| Raises: | |||||
| ValueError: if the string def is not properly specified (TODO) | |||||
| """ | |||||
| assert isinstance(block_str, str) | |||||
| ops = block_str.split('_') | |||||
| block_type = ops[0] # take the block type off the front | |||||
| ops = ops[1:] | |||||
| options = {} | |||||
| noskip = False | |||||
| for op in ops: | |||||
| if op == 'noskip': | |||||
| noskip = True | |||||
| elif op.startswith('n'): | |||||
| # activation fn | |||||
| key = op[0] | |||||
| v = op[1:] | |||||
| if v not in ('re', 'r6', 'hs', 'sw'): | |||||
| continue | |||||
| # per-block activations are not supported yet; store None so that | |||||
| # the model default activation is used instead | |||||
| print('not support') | |||||
| options[key] = None | |||||
| else: | |||||
| # all numeric options | |||||
| splits = re.split(r'(\d.*)', op) | |||||
| if len(splits) >= 2: | |||||
| key, value = splits[:2] | |||||
| options[key] = value | |||||
| act_fn = options['n'] if 'n' in options else None | |||||
| exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1 | |||||
| pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1 | |||||
| fake_in_chs = int(options['fc']) if 'fc' in options else 0 | |||||
| num_repeat = int(options['r']) | |||||
| # each type of block has different valid arguments, fill accordingly | |||||
| if block_type == 'ir': | |||||
| block_args = dict( | |||||
| block_type=block_type, | |||||
| dw_kernel_size=_parse_ksize(options['k']), | |||||
| exp_kernel_size=exp_kernel_size, | |||||
| pw_kernel_size=pw_kernel_size, | |||||
| out_chs=int(options['c']), | |||||
| exp_ratio=float(options['e']), | |||||
| se_ratio=float(options['se']) if 'se' in options else None, | |||||
| stride=int(options['s']), | |||||
| act_fn=act_fn, | |||||
| noskip=noskip, | |||||
| ) | |||||
| elif block_type in ('ds', 'dsa'): | |||||
| block_args = dict( | |||||
| block_type=block_type, | |||||
| dw_kernel_size=_parse_ksize(options['k']), | |||||
| pw_kernel_size=pw_kernel_size, | |||||
| out_chs=int(options['c']), | |||||
| se_ratio=float(options['se']) if 'se' in options else None, | |||||
| stride=int(options['s']), | |||||
| act_fn=act_fn, | |||||
| pw_act=block_type == 'dsa', | |||||
| noskip=block_type == 'dsa' or noskip, | |||||
| ) | |||||
| elif block_type == 'er': | |||||
| block_args = dict( | |||||
| block_type=block_type, | |||||
| exp_kernel_size=_parse_ksize(options['k']), | |||||
| pw_kernel_size=pw_kernel_size, | |||||
| out_chs=int(options['c']), | |||||
| exp_ratio=float(options['e']), | |||||
| fake_in_chs=fake_in_chs, | |||||
| se_ratio=float(options['se']) if 'se' in options else None, | |||||
| stride=int(options['s']), | |||||
| act_fn=act_fn, | |||||
| noskip=noskip, | |||||
| ) | |||||
| elif block_type == 'cn': | |||||
| block_args = dict( | |||||
| block_type=block_type, | |||||
| kernel_size=int(options['k']), | |||||
| out_chs=int(options['c']), | |||||
| stride=int(options['s']), | |||||
| act_fn=act_fn, | |||||
| ) | |||||
| else: | |||||
| assert False, 'Unknown block type (%s)' % block_type | |||||
| return block_args, num_repeat | |||||
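The tokenizing step of `_decode_block_str` can be tried on its own. The sketch below covers only the splitting of a block string into a type, an options dictionary and the `noskip` flag; it omits the activation handling and the per-block-type argument dictionaries. `parse_block_options` is a hypothetical name, and the example strings are made up for illustration.

```python
import re


def parse_block_options(block_str):
    """Split e.g. 'ir_r2_k3_s2_e6_c24_se0.25' into (type, options, noskip)."""
    ops = block_str.split('_')
    block_type, ops = ops[0], ops[1:]
    options = {}
    noskip = False
    for op in ops:
        if op == 'noskip':
            noskip = True
            continue
        # Split at the first digit: 'se0.25' -> ['se', '0.25', ''].
        splits = re.split(r'(\d.*)', op)
        if len(splits) >= 2:
            key, value = splits[:2]
            options[key] = value
    return block_type, options, noskip


print(parse_block_options('ir_r2_k3_s2_e6_c24_se0.25'))
print(parse_block_options('ds_r1_k3_s1_c16_noskip'))
```

Values stay as strings at this stage; the original function converts them with `int`, `float` or `_parse_ksize` when it fills the block arguments.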
| def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'): | |||||
| """ Per-stage depth scaling | |||||
| Scales the block repeats in each stage. This depth scaling impl maintains | |||||
| compatibility with the EfficientNet scaling method, while allowing sensible | |||||
| scaling for other models that may have multiple block arg definitions in each stage. | |||||
| """ | |||||
| # We scale the total repeat count for each stage, there may be multiple | |||||
| # block arg defs per stage so we need to sum. | |||||
| num_repeat = sum(repeats) | |||||
| if depth_trunc == 'round': | |||||
| # Truncating to int by rounding allows stages with few repeats to remain | |||||
| # proportionally smaller for longer. This is a good choice when stage definitions | |||||
| # include single repeat stages that we'd prefer to keep that way as long as possible | |||||
| num_repeat_scaled = max(1, round(num_repeat * depth_multiplier)) | |||||
| else: | |||||
| # The default for EfficientNet truncates repeats to int via 'ceil'. | |||||
| # Any multiplier > 1.0 will result in an increased depth for every stage. | |||||
| num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier)) | |||||
| # Proportionally distribute repeat count scaling to each block definition in the stage. | |||||
| # Allocation is done in reverse as it results in the first block being less likely to be scaled. | |||||
| # The first block makes less sense to repeat in most of the arch definitions. | |||||
| repeats_scaled = [] | |||||
| for r in repeats[::-1]: | |||||
| rs = max(1, round((r / num_repeat * num_repeat_scaled))) | |||||
| repeats_scaled.append(rs) | |||||
| num_repeat -= r | |||||
| num_repeat_scaled -= rs | |||||
| repeats_scaled = repeats_scaled[::-1] | |||||
| # Apply the calculated scaling to each block arg in the stage | |||||
| sa_scaled = [] | |||||
| for ba, rep in zip(stack_args, repeats_scaled): | |||||
| sa_scaled.extend([deepcopy(ba) for _ in range(rep)]) | |||||
| return sa_scaled | |||||
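The difference between the `'ceil'` and `'round'` truncation modes in `_scale_stage_depth` is easiest to see on repeat counts alone. The sketch below keeps the repeat arithmetic and drops the block-argument copying; `scale_repeats` is a hypothetical name, and the repeat lists are illustrative rather than taken from a real architecture definition.

```python
import math


def scale_repeats(repeats, depth_multiplier=1.0, depth_trunc='ceil'):
    """Return the scaled repeat count for each block definition in a stage."""
    num_repeat = sum(repeats)
    if depth_trunc == 'round':
        num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
    else:
        num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))
    repeats_scaled = []
    # Allocate from the last block definition backwards, so the first
    # definition in the stage is the least likely to be scaled.
    for r in repeats[::-1]:
        rs = max(1, round(r / num_repeat * num_repeat_scaled))
        repeats_scaled.append(rs)
        num_repeat -= r
        num_repeat_scaled -= rs
    return repeats_scaled[::-1]


print(scale_repeats([1, 3], 2.0))           # two block definitions, depth doubled
print(scale_repeats([4], 0.85, 'ceil'))     # ceil keeps all four repeats
print(scale_repeats([4], 0.85, 'round'))    # round drops one repeat
```

With a multiplier below 1.0, `'ceil'` only removes a repeat once the scaled count falls to or below the next lower integer, whereas `'round'` shrinks stages earlier.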
| def _decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil'): | |||||
| """further decode the architecture definition into model-ready format""" | |||||
| arch_args = [] | |||||
| for _, block_strings in enumerate(arch_def): | |||||
| assert isinstance(block_strings, list) | |||||
| stack_args = [] | |||||
| repeats = [] | |||||
| for block_str in block_strings: | |||||
| assert isinstance(block_str, str) | |||||
| ba, rep = _decode_block_str(block_str) | |||||
| stack_args.append(ba) | |||||
| repeats.append(rep) | |||||
| arch_args.append(_scale_stage_depth( | |||||
| stack_args, repeats, depth_multiplier, depth_trunc)) | |||||
| return arch_args | |||||
| class Swish(nn.Cell): | |||||
| """swish activation function""" | |||||
| def __init__(self): | |||||
| super(Swish, self).__init__() | |||||
| self.sigmoid = P.Sigmoid() | |||||
| def construct(self, x): | |||||
| return x * self.sigmoid(x) | |||||
| @ms_function | |||||
| def swish(x): | |||||
| return x * nn.Sigmoid()(x) | |||||
| class BlockBuilder(nn.Cell): | |||||
| """build efficient-net convolution blocks""" | |||||
| def __init__(self, builder_in_channels, builder_block_args, channel_multiplier=1.0, | |||||
| channel_divisor=8, channel_min=None, pad_type='', act_fn=None, | |||||
| se_gate_fn=sigmoid, se_reduce_mid=False, bn_args=None, | |||||
| drop_connect_rate=0., verbose=False): | |||||
| super(BlockBuilder, self).__init__() | |||||
| self.channel_multiplier = channel_multiplier | |||||
| self.channel_divisor = channel_divisor | |||||
| self.channel_min = channel_min | |||||
| self.pad_type = pad_type | |||||
| self.act_fn = Swish() | |||||
| self.se_gate_fn = se_gate_fn | |||||
| self.se_reduce_mid = se_reduce_mid | |||||
| self.bn_args = bn_args | |||||
| self.drop_connect_rate = drop_connect_rate | |||||
| self.verbose = verbose | |||||
| # updated during build | |||||
| self.in_chs = None | |||||
| self.block_idx = 0 | |||||
| self.block_count = 0 | |||||
| self.layer = self._make_layer(builder_in_channels, builder_block_args) | |||||
| def _round_channels(self, chs): | |||||
| return _round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min) | |||||
| def _make_block(self, ba): | |||||
| """make the current block based on the block argument""" | |||||
| bt = ba.pop('block_type') | |||||
| ba['in_chs'] = self.in_chs | |||||
| ba['out_chs'] = self._round_channels(ba['out_chs']) | |||||
| if 'fake_in_chs' in ba and ba['fake_in_chs']: | |||||
| # this is a hack to work around mismatch in origin impl input filters | |||||
| ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs']) | |||||
| ba['bn_args'] = self.bn_args | |||||
| ba['pad_type'] = self.pad_type | |||||
| # block act fn overrides the model default | |||||
| ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn | |||||
| assert ba['act_fn'] is not None | |||||
| if bt == 'ir': | |||||
| ba['drop_connect_rate'] = self.drop_connect_rate * \ | |||||
| self.block_idx / self.block_count | |||||
| ba['se_gate_fn'] = self.se_gate_fn | |||||
| ba['se_reduce_mid'] = self.se_reduce_mid | |||||
| block = InvertedResidual(**ba) | |||||
| elif bt in ('ds', 'dsa'): | |||||
| ba['drop_connect_rate'] = self.drop_connect_rate * \ | |||||
| self.block_idx / self.block_count | |||||
| block = DepthwiseSeparableConv(**ba) | |||||
| else: | |||||
| assert False, 'Unknown block type (%s) while building model.' % bt | |||||
| self.in_chs = ba['out_chs'] | |||||
| return block | |||||
| def _make_stack(self, stack_args): | |||||
| """make a stack of blocks""" | |||||
| blocks = [] | |||||
| # each stack (stage) contains a list of block arguments | |||||
| for i, ba in enumerate(stack_args): | |||||
| if i >= 1: | |||||
| # only the first block in any stack can have a stride > 1 | |||||
| ba['stride'] = 1 | |||||
| block = self._make_block(ba) | |||||
| blocks.append(block) | |||||
| self.block_idx += 1 # incr global idx (across all stacks) | |||||
| return nn.SequentialCell(blocks) | |||||
| def _make_layer(self, in_chs, block_args): | |||||
| """ Build the entire layer | |||||
| Args: | |||||
| in_chs: Number of input-channels passed to first block | |||||
| block_args: A list of lists, outer list defines stages, inner | |||||
| list contains strings defining block configuration(s) | |||||
| Returns: | |||||
| A SequentialCell of block stacks (each stack wrapped in nn.SequentialCell) | |||||
| """ | |||||
| self.in_chs = in_chs | |||||
| self.block_count = sum([len(x) for x in block_args]) | |||||
| self.block_idx = 0 | |||||
| blocks = [] | |||||
| # outer list of block_args defines the stacks ('stages' by some conventions) | |||||
| for _, stack in enumerate(block_args): | |||||
| assert isinstance(stack, list) | |||||
| stack = self._make_stack(stack) | |||||
| blocks.append(stack) | |||||
| return nn.SequentialCell(blocks) | |||||
| def construct(self, x): | |||||
| return self.layer(x) | |||||
| class DepthWiseConv(nn.Cell): | |||||
| """depth-wise convolution""" | |||||
| def __init__(self, in_planes, kernel_size, stride): | |||||
| super(DepthWiseConv, self).__init__() | |||||
| platform = context.get_context("device_target") | |||||
| weight_shape = [1, kernel_size, in_planes] | |||||
| weight_init = _initialize_weight_goog(shape=weight_shape) | |||||
| if platform == "GPU": | |||||
| self.depthwise_conv = P.Conv2D(out_channel=in_planes*1, | |||||
| kernel_size=kernel_size, | |||||
| stride=stride, | |||||
| pad=int(kernel_size/2), | |||||
| pad_mode="pad", | |||||
| group=in_planes) | |||||
| self.weight = Parameter(initializer(weight_init, | |||||
| [in_planes*1, 1, kernel_size, kernel_size]), name='depthwise_weight') | |||||
| else: | |||||
| self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=1, | |||||
| kernel_size=kernel_size, | |||||
| stride=stride, pad_mode='pad', | |||||
| pad=int(kernel_size/2)) | |||||
| self.weight = Parameter(initializer(weight_init, | |||||
| [1, in_planes, kernel_size, kernel_size]), name='depthwise_weight') | |||||
| def construct(self, x): | |||||
| x = self.depthwise_conv(x, self.weight) | |||||
| return x | |||||
| class DropConnect(nn.Cell): | |||||
| """drop connect implementation""" | |||||
| def __init__(self, drop_connect_rate=0., seed0=0, seed1=0): | |||||
| super(DropConnect, self).__init__() | |||||
| self.shape = P.Shape() | |||||
| self.dtype = P.DType() | |||||
| self.keep_prob = 1 - drop_connect_rate | |||||
| self.dropout = P.Dropout(keep_prob=self.keep_prob) | |||||
| def construct(self, x): | |||||
| shape = self.shape(x) | |||||
| dtype = self.dtype(x) | |||||
| ones_tensor = P.Fill()(dtype, (shape[0], 1, 1, 1), 1) | |||||
| _, mask_ = self.dropout(ones_tensor) | |||||
| x = x * mask_ | |||||
| return x | |||||
| def drop_connect(inputs, training=False, drop_connect_rate=0.): | |||||
| if not training: | |||||
| return inputs | |||||
| return DropConnect(drop_connect_rate)(inputs) | |||||
| class SqueezeExcite(nn.Cell): | |||||
| """squeeze-excite implementation""" | |||||
| def __init__(self, in_chs, reduce_chs=None, act_fn=relu, gate_fn=sigmoid): | |||||
| super(SqueezeExcite, self).__init__() | |||||
| self.act_fn = Swish() | |||||
| self.gate_fn = gate_fn | |||||
| reduce_chs = reduce_chs or in_chs | |||||
| self.conv_reduce = nn.Conv2d(in_channels=in_chs, out_channels=reduce_chs, | |||||
| kernel_size=1, has_bias=True, pad_mode='pad') | |||||
| self.conv_expand = nn.Conv2d(in_channels=reduce_chs, out_channels=in_chs, | |||||
| kernel_size=1, has_bias=True, pad_mode='pad') | |||||
| self.avg_global_pool = P.ReduceMean(keep_dims=True) | |||||
| def construct(self, x): | |||||
| x_se = self.avg_global_pool(x, (2, 3)) | |||||
| x_se = self.conv_reduce(x_se) | |||||
| x_se = self.act_fn(x_se) | |||||
| x_se = self.conv_expand(x_se) | |||||
| x_se = self.gate_fn(x_se) | |||||
| x = x * x_se | |||||
| return x | |||||
| class DepthwiseSeparableConv(nn.Cell): | |||||
| """depth-wise convolution -> (squeeze-excite) -> point-wise convolution""" | |||||
| def __init__(self, in_chs, out_chs, dw_kernel_size=3, | |||||
| stride=1, pad_type='', act_fn=relu, noskip=False, | |||||
| pw_kernel_size=1, pw_act=False, se_ratio=0., se_gate_fn=sigmoid, | |||||
| bn_args=None, drop_connect_rate=0.): | |||||
| super(DepthwiseSeparableConv, self).__init__() | |||||
| assert stride in [1, 2], 'stride must be 1 or 2' | |||||
| self.has_se = se_ratio is not None and se_ratio > 0. | |||||
| self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip | |||||
| self.has_pw_act = pw_act | |||||
| self.act_fn = Swish() | |||||
| self.drop_connect_rate = drop_connect_rate | |||||
| self.conv_dw = DepthWiseConv(in_chs, dw_kernel_size, stride) | |||||
| self.bn1 = _fused_bn(in_chs, **bn_args) | |||||
| if self.has_se: | |||||
| self.se = SqueezeExcite(in_chs, reduce_chs=max(1, int(in_chs * se_ratio)), | |||||
| act_fn=act_fn, gate_fn=se_gate_fn) | |||||
| self.conv_pw = _conv1x1(in_chs, out_chs) | |||||
| self.bn2 = _fused_bn(out_chs, **bn_args) | |||||
| def construct(self, x): | |||||
| """forward the depthwise separable conv""" | |||||
| identity = x | |||||
| x = self.conv_dw(x) | |||||
| x = self.bn1(x) | |||||
| x = self.act_fn(x) | |||||
| if self.has_se: | |||||
| x = self.se(x) | |||||
| x = self.conv_pw(x) | |||||
| x = self.bn2(x) | |||||
| if self.has_pw_act: | |||||
| x = self.act_fn(x) | |||||
| if self.has_residual: | |||||
| if self.drop_connect_rate > 0.: | |||||
| x = drop_connect(x, self.training, self.drop_connect_rate) | |||||
| x = x + identity | |||||
| return x | |||||
| class InvertedResidual(nn.Cell): | |||||
| """inverted-residual block implementation""" | |||||
| def __init__(self, in_chs, out_chs, dw_kernel_size=3, stride=1, | |||||
| pad_type='', act_fn=relu, pw_kernel_size=1, | |||||
| noskip=False, exp_ratio=1., exp_kernel_size=1, se_ratio=0., | |||||
| se_reduce_mid=False, se_gate_fn=sigmoid, shuffle_type=None, | |||||
| bn_args=None, drop_connect_rate=0.): | |||||
| super(InvertedResidual, self).__init__() | |||||
| mid_chs = int(in_chs * exp_ratio) | |||||
| self.has_se = se_ratio is not None and se_ratio > 0. | |||||
| self.has_residual = (in_chs == out_chs and stride == 1) and not noskip | |||||
| self.act_fn = Swish() | |||||
| self.drop_connect_rate = drop_connect_rate | |||||
| self.conv_pw = _conv(in_chs, mid_chs, exp_kernel_size) | |||||
| self.bn1 = _fused_bn(mid_chs, **bn_args) | |||||
| self.shuffle_type = shuffle_type | |||||
| if self.shuffle_type is not None and isinstance(exp_kernel_size, list): | |||||
| self.shuffle = None | |||||
| self.conv_dw = DepthWiseConv(mid_chs, dw_kernel_size, stride) | |||||
| self.bn2 = _fused_bn(mid_chs, **bn_args) | |||||
| if self.has_se: | |||||
| se_base_chs = mid_chs if se_reduce_mid else in_chs | |||||
| self.se = SqueezeExcite( | |||||
| mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)), | |||||
| act_fn=act_fn, gate_fn=se_gate_fn | |||||
| ) | |||||
| self.conv_pwl = _conv(mid_chs, out_chs, pw_kernel_size) | |||||
| self.bn3 = _fused_bn(out_chs, **bn_args) | |||||
| def construct(self, x): | |||||
| """forward the inverted-residual block""" | |||||
| identity = x | |||||
| x = self.conv_pw(x) | |||||
| x = self.bn1(x) | |||||
| x = self.act_fn(x) | |||||
| x = self.conv_dw(x) | |||||
| x = self.bn2(x) | |||||
| x = self.act_fn(x) | |||||
| if self.has_se: | |||||
| x = self.se(x) | |||||
| x = self.conv_pwl(x) | |||||
| x = self.bn3(x) | |||||
| if self.has_residual: | |||||
| if self.drop_connect_rate > 0: | |||||
| x = drop_connect(x, self.training, self.drop_connect_rate) | |||||
| x = x + identity | |||||
| return x | |||||
| class GenEfficientNet(nn.Cell): | |||||
| """Generate EfficientNet architecture""" | |||||
| def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280, | |||||
| channel_multiplier=1.0, channel_divisor=8, channel_min=None, | |||||
| pad_type='', act_fn=relu, drop_rate=0., drop_connect_rate=0., | |||||
| se_gate_fn=sigmoid, se_reduce_mid=False, bn_args=None, | |||||
| global_pool='avg', head_conv='default', weight_init='goog'): | |||||
| super(GenEfficientNet, self).__init__() | |||||
| bn_args = _BN_ARGS_PT if bn_args is None else bn_args | |||||
| self.num_classes = num_classes | |||||
| self.drop_rate = drop_rate | |||||
| self.num_features = num_features | |||||
| self.conv_stem = _conv(in_chans, stem_size, 3, | |||||
| stride=2, padding=1, pad_mode='pad') | |||||
| self.bn1 = _fused_bn(stem_size, **bn_args) | |||||
| self.act_fn = Swish() | |||||
| in_chans = stem_size | |||||
| self.blocks = BlockBuilder(in_chans, block_args, channel_multiplier, | |||||
| channel_divisor, channel_min, | |||||
| pad_type, act_fn, se_gate_fn, se_reduce_mid, | |||||
| bn_args, drop_connect_rate, verbose=_DEBUG) | |||||
| in_chs = self.blocks.in_chs | |||||
| if not head_conv or head_conv == 'none': | |||||
| self.efficient_head = False | |||||
| self.conv_head = None | |||||
| assert in_chs == self.num_features | |||||
| else: | |||||
| self.efficient_head = head_conv == 'efficient' | |||||
| self.conv_head = _conv1x1(in_chs, self.num_features) | |||||
| self.bn2 = None if self.efficient_head else _fused_bn( | |||||
| self.num_features, **bn_args) | |||||
| self.global_pool = P.ReduceMean(keep_dims=True) | |||||
| self.classifier = _dense(self.num_features, self.num_classes) | |||||
| self.reshape = P.Reshape() | |||||
| self.shape = P.Shape() | |||||
| self.drop_out = nn.Dropout(keep_prob=1-self.drop_rate) | |||||
| def construct(self, x): | |||||
| """efficient net entry point""" | |||||
| x = self.conv_stem(x) | |||||
| x = self.bn1(x) | |||||
| x = self.act_fn(x) | |||||
| x = self.blocks(x) | |||||
| if self.efficient_head: | |||||
| x = self.global_pool(x, (2, 3)) | |||||
| x = self.conv_head(x) | |||||
| x = self.act_fn(x) | |||||
| x = self.reshape(x, (self.shape(x)[0], -1)) | |||||
| else: | |||||
| if self.conv_head is not None: | |||||
| x = self.conv_head(x) | |||||
| x = self.bn2(x) | |||||
| x = self.act_fn(x) | |||||
| x = self.global_pool(x, (2, 3)) | |||||
| x = self.reshape(x, (self.shape(x)[0], -1)) | |||||
| if self.training and self.drop_rate > 0.: | |||||
| x = self.drop_out(x) | |||||
| return self.classifier(x) | |||||
| def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=1000, **kwargs): | |||||
| """Creates an EfficientNet model. | |||||
| Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py | |||||
| Paper: https://arxiv.org/abs/1905.11946 | |||||
| EfficientNet params | |||||
| name: (channel_multiplier, depth_multiplier, resolution, dropout_rate) | |||||
| 'efficientnet-b0': (1.0, 1.0, 224, 0.2), | |||||
| 'efficientnet-b1': (1.0, 1.1, 240, 0.2), | |||||
| 'efficientnet-b2': (1.1, 1.2, 260, 0.3), | |||||
| 'efficientnet-b3': (1.2, 1.4, 300, 0.3), | |||||
| 'efficientnet-b4': (1.4, 1.8, 380, 0.4), | |||||
| 'efficientnet-b5': (1.6, 2.2, 456, 0.4), | |||||
| 'efficientnet-b6': (1.8, 2.6, 528, 0.5), | |||||
| 'efficientnet-b7': (2.0, 3.1, 600, 0.5), | |||||
| Args: | |||||
| channel_multiplier (float): multiplier to number of channels per layer | |||||
| depth_multiplier (float): multiplier to number of repeats per stage | |||||
| """ | |||||
| arch_def = [ | |||||
| ['ds_r1_k3_s1_e1_c16_se0.25'], | |||||
| ['ir_r2_k3_s2_e6_c24_se0.25'], | |||||
| ['ir_r2_k5_s2_e6_c40_se0.25'], | |||||
| ['ir_r3_k3_s2_e6_c80_se0.25'], | |||||
| ['ir_r3_k5_s1_e6_c112_se0.25'], | |||||
| ['ir_r4_k5_s2_e6_c192_se0.25'], | |||||
| ['ir_r1_k3_s1_e6_c320_se0.25'], | |||||
| ] | |||||
| num_features = max(1280, _round_channels( | |||||
| 1280, channel_multiplier, 8, None)) | |||||
| model = GenEfficientNet( | |||||
| _decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'), | |||||
| num_classes=num_classes, | |||||
| stem_size=32, | |||||
| channel_multiplier=channel_multiplier, | |||||
| num_features=num_features, | |||||
| bn_args=_resolve_bn_args(kwargs), | |||||
| act_fn=Swish, | |||||
| **kwargs) | |||||
| return model | |||||
| def tinynet(sub_model="c", num_classes=1000, in_chans=3, **kwargs): | |||||
| """ TinyNet Models """ | |||||
| # choose a sub model | |||||
| r, w, d = TINYNET_CFG[sub_model] | |||||
| # copy the shared config so repeated calls do not mutate default_cfgs | |||||
| default_cfg = dict(default_cfgs['efficientnet_b0']) | |||||
| assert default_cfg['input_size'] == (3, 224, 224), "All tinynet models are \ | |||||
| evolved from EfficientNet-B0, which has input dimension of 3*224*224" | |||||
| channel, height, width = default_cfg['input_size'] | |||||
| height = int(r * height) | |||||
| width = int(r * width) | |||||
| default_cfg['input_size'] = (channel, height, width) | |||||
| print("Data processing configuration for current model + dataset:") | |||||
| print("input_size:", default_cfg['input_size']) | |||||
| print("channel multiplier:%s, depth multiplier:%s, resolution multiplier:%s" % (w, d, r)) | |||||
| model = _gen_efficientnet( | |||||
| channel_multiplier=w, depth_multiplier=d, | |||||
| num_classes=num_classes, in_chans=in_chans, **kwargs) | |||||
| model.default_cfg = default_cfg | |||||
| return model | |||||
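The block strings in `arch_def` above (for example `'ir_r2_k3_s2_e6_c24_se0.25'`) are decoded by `_decode_arch_def`, which is not part of this excerpt. The sketch below shows one plausible reading of that notation, assuming the usual EfficientNet convention: `r` = repeats, `k` = kernel size, `s` = stride, `e` = expansion ratio, `c` = output channels, `se` = squeeze-and-excitation ratio. The names `decode_block_str` and `scale_repeats` are hypothetical helpers written for illustration, and the rounding rule in `scale_repeats` is only an assumption about what `depth_trunc='round'` does.

```python
import re


def decode_block_str(block_str):
    """Parse a block string such as 'ir_r2_k3_s2_e6_c24_se0.25' into a dict.

    Hypothetical helper: the field meanings follow the common EfficientNet
    naming convention and are an assumption, not the repository's decoder.
    """
    parts = block_str.split("_")
    block_type = parts[0]  # e.g. 'ds' (depthwise-separable) or 'ir' (inverted residual)
    options = {}
    for part in parts[1:]:
        match = re.match(r"([a-z]+)([\d.]+)$", part)
        if match is None:
            raise ValueError("unrecognised option: %r" % part)
        key, value = match.groups()
        options[key] = value
    return {
        "block_type": block_type,
        "repeats": int(options["r"]),
        "kernel_size": int(options["k"]),
        "stride": int(options["s"]),
        "expand_ratio": int(options["e"]),
        "out_chs": int(options["c"]),
        "se_ratio": float(options.get("se", 0.0)),
    }


def scale_repeats(repeats, depth_multiplier):
    """Scale a stage's repeat count, keeping at least one block (assumed rule)."""
    return max(1, round(repeats * depth_multiplier))


if __name__ == "__main__":
    cfg = decode_block_str("ir_r2_k3_s2_e6_c24_se0.25")
    print(cfg)
    print(scale_repeats(cfg["repeats"], 0.5))
```

Reading the strings this way makes it easier to see how the depth multiplier `d` passed by `tinynet` shrinks each stage while the channel multiplier `w` is applied separately inside the block builder.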
| @@ -0,0 +1,89 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """model utils""" | |||||
| import math | |||||
| import argparse | |||||
| import numpy as np | |||||
| def str2bool(value): | |||||
| """Convert string arguments to bool type""" | |||||
| if value.lower() in ('yes', 'true', 't', 'y', '1'): | |||||
| return True | |||||
| if value.lower() in ('no', 'false', 'f', 'n', '0'): | |||||
| return False | |||||
| raise argparse.ArgumentTypeError('Boolean value expected.') | |||||
| def get_lr(base_lr, total_epochs, steps_per_epoch, decay_epochs=1, decay_rate=0.9, | |||||
| warmup_epochs=0., warmup_lr_init=0., global_epoch=0): | |||||
| """Get scheduled learning rate""" | |||||
| lr_each_step = [] | |||||
| total_steps = steps_per_epoch * total_epochs | |||||
| global_steps = steps_per_epoch * global_epoch | |||||
| self_warmup_delta = ((base_lr - warmup_lr_init) / \ | |||||
| warmup_epochs) if warmup_epochs > 0 else 0 | |||||
| self_decay_rate = decay_rate if decay_rate < 1 else 1/decay_rate | |||||
| for i in range(total_steps): | |||||
| epochs = math.floor(i/steps_per_epoch) | |||||
| cond = 1 if (epochs < warmup_epochs) else 0 | |||||
| warmup_lr = warmup_lr_init + epochs * self_warmup_delta | |||||
| decay_nums = math.floor(epochs / decay_epochs) | |||||
| decay_rate = math.pow(self_decay_rate, decay_nums) | |||||
| decay_lr = base_lr * decay_rate | |||||
| lr = cond * warmup_lr + (1 - cond) * decay_lr | |||||
| lr_each_step.append(lr) | |||||
| lr_each_step = lr_each_step[global_steps:] | |||||
| lr_each_step = np.array(lr_each_step).astype(np.float32) | |||||
| return lr_each_step | |||||
| def add_weight_decay(net, weight_decay=1e-5, skip_list=None): | |||||
| """Apply weight decay to only conv and dense layers (len(shape) >= 2) | |||||
| Args: | |||||
| net (mindspore.nn.Cell): MindSpore network instance | |||||
| weight_decay (float): weight decay to be used. | |||||
| skip_list (tuple): list of parameter names without weight decay | |||||
| Returns: | |||||
| A list of group of parameters, separated by different weight decay. | |||||
| """ | |||||
| decay = [] | |||||
| no_decay = [] | |||||
| if not skip_list: | |||||
| skip_list = () | |||||
| for param in net.trainable_params(): | |||||
| if len(param.shape) == 1 or \ | |||||
| param.name.endswith(".bias") or \ | |||||
| param.name in skip_list: | |||||
| no_decay.append(param) | |||||
| else: | |||||
| decay.append(param) | |||||
| return [ | |||||
| {'params': no_decay, 'weight_decay': 0.}, | |||||
| {'params': decay, 'weight_decay': weight_decay}] | |||||
| def count_params(net): | |||||
| """Count number of parameters in the network | |||||
| Args: | |||||
| net (mindspore.nn.Cell): MindSpore network instance | |||||
| Returns: | |||||
| total_params (int): Total number of trainable params | |||||
| """ | |||||
| total_params = 0 | |||||
| for param in net.trainable_params(): | |||||
| total_params += np.prod(param.shape) | |||||
| return total_params | |||||
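The schedule built by `get_lr` in the file above is a linear warmup followed by step decay. The sketch below reproduces that logic at epoch granularity in plain Python so the shape of the curve is easy to inspect. The function name `step_lr_with_warmup` and the example values are illustrative assumptions; the original returns one value per training step as a NumPy array, not one per epoch.

```python
import math


def step_lr_with_warmup(base_lr, total_epochs, decay_epochs, decay_rate,
                        warmup_epochs, warmup_lr_init):
    """Return one learning rate per epoch: linear warmup, then step decay.

    Illustrative re-statement of the get_lr logic; not the original function.
    """
    warmup_delta = ((base_lr - warmup_lr_init) / warmup_epochs
                    if warmup_epochs > 0 else 0.0)
    # As in get_lr, a rate above 1 is inverted so the schedule always decays.
    rate = decay_rate if decay_rate < 1 else 1 / decay_rate
    rates = []
    for epoch in range(total_epochs):
        if epoch < warmup_epochs:
            rates.append(warmup_lr_init + epoch * warmup_delta)
        else:
            # The decay count is measured from epoch 0, not from the end of warmup.
            rates.append(base_lr * math.pow(rate, math.floor(epoch / decay_epochs)))
    return rates


schedule = step_lr_with_warmup(base_lr=0.01, total_epochs=8, decay_epochs=3,
                               decay_rate=0.1, warmup_epochs=2,
                               warmup_lr_init=0.0001)

if __name__ == "__main__":
    for epoch, lr in enumerate(schedule):
        print(epoch, lr)
```

Note that the warmup never quite reaches `base_lr` inside the warmup window: the last warmup epoch sits one increment below it, and the first post-warmup epoch jumps to the decayed base rate.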
| @@ -0,0 +1,250 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Training Interface""" | |||||
| import sys | |||||
| import os | |||||
| import argparse | |||||
| import copy | |||||
| from mindspore.communication.management import init, get_rank, get_group_size | |||||
| from mindspore.train.model import ParallelMode, Model | |||||
| from mindspore.train.callback import TimeMonitor | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| from mindspore.train.loss_scale_manager import FixedLossScaleManager | |||||
| from mindspore.nn import SGD, RMSProp, Loss, Top1CategoricalAccuracy, \ | |||||
| Top5CategoricalAccuracy | |||||
| from mindspore import context, Tensor | |||||
| from src.dataset import create_dataset, create_dataset_val | |||||
| from src.utils import add_weight_decay, count_params, str2bool, get_lr | |||||
| from src.callback import EmaEvalCallBack, LossMonitor | |||||
| from src.loss import LabelSmoothingCrossEntropy | |||||
| from src.tinynet import tinynet | |||||
| parser = argparse.ArgumentParser(description='Training') | |||||
| # training parameters | |||||
| parser.add_argument('--data_path', type=str, default="", metavar="DIR", | |||||
| help='path to dataset') | |||||
| parser.add_argument('--model', default='tinynet_c', type=str, metavar='MODEL', | |||||
| help='Name of model to train (default: "tinynet_c")') | |||||
| parser.add_argument('--num-classes', type=int, default=1000, metavar='N', | |||||
| help='number of label classes (default: 1000)') | |||||
| parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N', | |||||
| help='input batch size for training (default: 32)') | |||||
| parser.add_argument('--drop', type=float, default=0.0, metavar='DROP', | |||||
| help='Dropout rate (default: 0.)') | |||||
| parser.add_argument('--drop-connect', type=float, default=0.0, metavar='DROP', | |||||
| help='Drop connect rate (default: 0.)') | |||||
| parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', | |||||
| choices=['sgd', 'rmsprop'], help='Optimizer (default: "sgd")') | |||||
| parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', | |||||
| help='Optimizer Epsilon (default: 1e-8)') | |||||
| parser.add_argument('--momentum', type=float, default=0.9, metavar='M', | |||||
| help='SGD momentum (default: 0.9)') | |||||
| parser.add_argument('--weight-decay', type=float, default=0.0001, | |||||
| help='weight decay (default: 0.0001)') | |||||
| parser.add_argument('--lr', type=float, default=0.01, metavar='LR', | |||||
| help='learning rate (default: 0.01)') | |||||
| parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR', | |||||
| help='warmup learning rate (default: 0.0001)') | |||||
| parser.add_argument('--epochs', type=int, default=200, metavar='N', | |||||
| help='number of epochs to train (default: 200)') | |||||
| parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', | |||||
| help='epoch interval to decay LR') | |||||
| parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N', | |||||
| help='epochs to warmup LR, if scheduler supports') | |||||
| parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', | |||||
| help='LR decay rate (default: 0.1)') | |||||
| parser.add_argument('--smoothing', type=float, default=0.1, | |||||
| help='label smoothing (default: 0.1)') | |||||
| parser.add_argument('--ema-decay', type=float, default=0, | |||||
| help='decay factor for model weights moving average \ | |||||
| (default: 0; training requires a value > 0, e.g. 0.999)') | |||||
| parser.add_argument('--amp_level', type=str, default='O0') | |||||
| parser.add_argument('--per_print_times', type=int, default=100) | |||||
| # batch norm parameters | |||||
| parser.add_argument('--bn-tf', action='store_true', default=False, | |||||
| help='Use TensorFlow BatchNorm defaults for models that \ | |||||
| support it (default: False)') | |||||
| parser.add_argument('--bn-momentum', type=float, default=None, | |||||
| help='BatchNorm momentum override (if not None)') | |||||
| parser.add_argument('--bn-eps', type=float, default=None, | |||||
| help='BatchNorm epsilon override (if not None)') | |||||
| # parallel parameters | |||||
| parser.add_argument('-j', '--workers', type=int, default=4, metavar='N', | |||||
| help='how many training processes to use (default: 4)') | |||||
| parser.add_argument('--distributed', action='store_true', default=False) | |||||
| parser.add_argument('--dataset_sink', action='store_true', default=True) | |||||
| # checkpoint config | |||||
| parser.add_argument('--ckpt', type=str, default=None) | |||||
| parser.add_argument('--ckpt_save_epoch', type=int, default=1) | |||||
| parser.add_argument('--loss_scale', type=int, | |||||
| default=1024, help='static loss scale') | |||||
| parser.add_argument('--train', type=str2bool, default=True, help='train or eval') | |||||
| parser.add_argument('--GPU', action='store_true', default=False, | |||||
| help='Use GPU for training (default: False)') | |||||
| def main(): | |||||
| """Main entry point for training""" | |||||
| args = parser.parse_args() | |||||
| print(sys.argv) | |||||
| devid, args.rank_id, args.rank_size = 0, 0, 1 | |||||
| context.set_context(mode=context.GRAPH_MODE) | |||||
| if args.distributed: | |||||
| if args.GPU: | |||||
| init("nccl") | |||||
| context.set_context(device_target='GPU') | |||||
| else: | |||||
| init() | |||||
| devid = int(os.getenv('DEVICE_ID')) | |||||
| context.set_context(device_target='Ascend', | |||||
| device_id=devid, | |||||
| reserve_class_name_in_scope=False) | |||||
| context.reset_auto_parallel_context() | |||||
| args.rank_id = get_rank() | |||||
| args.rank_size = get_group_size() | |||||
| context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, | |||||
| gradients_mean=True, | |||||
| device_num=args.rank_size) | |||||
| else: | |||||
| if args.GPU: | |||||
| context.set_context(device_target='GPU') | |||||
| is_master = not args.distributed or (args.rank_id == 0) | |||||
| # parse model argument | |||||
| assert args.model.startswith( | |||||
| "tinynet"), "Only Tinynet models are supported." | |||||
| _, sub_name = args.model.split("_") | |||||
| net = tinynet(sub_model=sub_name, | |||||
| num_classes=args.num_classes, | |||||
| drop_rate=args.drop, | |||||
| drop_connect_rate=args.drop_connect, | |||||
| global_pool="avg", | |||||
| bn_tf=args.bn_tf, | |||||
| bn_momentum=args.bn_momentum, | |||||
| bn_eps=args.bn_eps) | |||||
| if is_master: | |||||
| print("Total number of parameters:", count_params(net)) | |||||
| # input image size of the network | |||||
| input_size = net.default_cfg['input_size'][1] | |||||
| train_dataset = val_dataset = None | |||||
| train_data_url = os.path.join(args.data_path, 'train') | |||||
| val_data_url = os.path.join(args.data_path, 'val') | |||||
| val_dataset = create_dataset_val(args.batch_size, | |||||
| val_data_url, | |||||
| workers=args.workers, | |||||
| distributed=False, | |||||
| input_size=input_size) | |||||
| if args.train: | |||||
| train_dataset = create_dataset(args.batch_size, | |||||
| train_data_url, | |||||
| workers=args.workers, | |||||
| distributed=args.distributed, | |||||
| input_size=input_size) | |||||
| batches_per_epoch = train_dataset.get_dataset_size() | |||||
| loss = LabelSmoothingCrossEntropy( | |||||
| smooth_factor=args.smoothing, num_classes=args.num_classes) | |||||
| time_cb = TimeMonitor(data_size=batches_per_epoch) | |||||
| loss_scale_manager = FixedLossScaleManager( | |||||
| args.loss_scale, drop_overflow_update=False) | |||||
| lr_array = get_lr(base_lr=args.lr, | |||||
| total_epochs=args.epochs, | |||||
| steps_per_epoch=batches_per_epoch, | |||||
| decay_epochs=args.decay_epochs, | |||||
| decay_rate=args.decay_rate, | |||||
| warmup_epochs=args.warmup_epochs, | |||||
| warmup_lr_init=args.warmup_lr, | |||||
| global_epoch=0) | |||||
| lr = Tensor(lr_array) | |||||
| loss_cb = LossMonitor(lr_array, | |||||
| args.epochs, | |||||
| per_print_times=args.per_print_times, | |||||
| start_epoch=0) | |||||
| param_group = add_weight_decay(net, weight_decay=args.weight_decay) | |||||
| if args.opt == 'sgd': | |||||
| if is_master: | |||||
| print('Using SGD optimizer') | |||||
| optimizer = SGD(param_group, | |||||
| learning_rate=lr, | |||||
| momentum=args.momentum, | |||||
| weight_decay=args.weight_decay, | |||||
| loss_scale=args.loss_scale) | |||||
| elif args.opt == 'rmsprop': | |||||
| if is_master: | |||||
| print('Using rmsprop optimizer') | |||||
| optimizer = RMSProp(param_group, | |||||
| learning_rate=lr, | |||||
| decay=0.9, | |||||
| weight_decay=args.weight_decay, | |||||
| momentum=args.momentum, | |||||
| epsilon=args.opt_eps, | |||||
| loss_scale=args.loss_scale) | |||||
| loss.add_flags_recursive(fp32=True, fp16=False) | |||||
| eval_metrics = {'Validation-Loss': Loss(), | |||||
| 'Top1-Acc': Top1CategoricalAccuracy(), | |||||
| 'Top5-Acc': Top5CategoricalAccuracy()} | |||||
| if args.ckpt: | |||||
| ckpt = load_checkpoint(args.ckpt) | |||||
| load_param_into_net(net, ckpt) | |||||
| net.set_train(False) | |||||
| model = Model(net, loss, optimizer, metrics=eval_metrics, | |||||
| loss_scale_manager=loss_scale_manager, | |||||
| amp_level=args.amp_level) | |||||
| net_ema = copy.deepcopy(net) | |||||
| net_ema.set_train(False) | |||||
| assert args.ema_decay > 0, "EMA should be used in tinynet training." | |||||
| ema_cb = EmaEvalCallBack(model=model, | |||||
| ema_network=net_ema, | |||||
| loss_fn=loss, | |||||
| eval_dataset=val_dataset, | |||||
| decay=args.ema_decay, | |||||
| save_epoch=args.ckpt_save_epoch, | |||||
| dataset_sink_mode=args.dataset_sink, | |||||
| start_epoch=0) | |||||
| callbacks = [loss_cb, ema_cb, time_cb] if is_master else [] | |||||
| if is_master: | |||||
| print("Training on " + args.model | |||||
| + " with " + str(args.num_classes) + " classes") | |||||
| model.train(args.epochs, train_dataset, callbacks=callbacks, | |||||
| dataset_sink_mode=args.dataset_sink) | |||||
| if __name__ == '__main__': | |||||
| main() | |||||
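The training script above insists on `--ema-decay > 0` and hands a deep copy of the network to `EmaEvalCallBack`, whose implementation is not shown in this excerpt. Assuming it maintains the usual exponential moving average of the weights, the update it performs after each step would look roughly like the sketch below. The helper `ema_update` and the dictionary-of-floats representation are hypothetical simplifications; real parameters are tensors.

```python
def ema_update(ema_weights, weights, decay):
    """Blend current weights into the running average.

    Hypothetical helper illustrating ema = decay * ema + (1 - decay) * w;
    it is an assumption about what the EMA callback does.
    """
    return {
        name: decay * ema_weights[name] + (1.0 - decay) * value
        for name, value in weights.items()
    }


if __name__ == "__main__":
    ema = {"conv.weight": 1.0}
    # Three training steps that each leave the raw weight at 0.0.
    for _ in range(3):
        ema = ema_update(ema, {"conv.weight": 0.0}, decay=0.9)
    print(ema)
```

With a decay close to 1 (the help text mentions 0.999) the averaged weights move slowly, which is why the evaluation in the callback is run on the EMA copy and not on the raw network.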