@@ -0,0 +1,111 @@
# EfficientNet-B0 Example

## Description

This is an example of training EfficientNet-B0 in MindSpore.

## Requirements

- Install [MindSpore](http://www.mindspore.cn/install/en).
- Download the ImageNet dataset.

## Structure

```shell
.
└─efficientnet
  ├─README.md
  ├─scripts
  │ ├─run_standalone_train_for_gpu.sh   # launch standalone training on GPU (1p)
  │ ├─run_distribute_train_for_gpu.sh   # launch distributed training on GPU (8p)
  │ └─run_eval_for_gpu.sh               # launch evaluation on GPU
  ├─src
  │ ├─config.py                         # parameter configuration
  │ ├─dataset.py                        # data preprocessing
  │ ├─efficientnet.py                   # network definition
  │ ├─loss.py                           # customized loss function
  │ ├─transform_utils.py                # random augment utils
  │ └─transform.py                      # random augment class
  ├─eval.py                             # evaluate the network
  └─train.py                            # train the network
```

## Parameter Configuration

Parameters for both training and evaluation can be set in `src/config.py`:

```
'random_seed': 1,            # fix random seed
'model': 'efficientnet_b0',  # model name
'drop': 0.2,                 # dropout rate
'drop_connect': 0.2,         # drop connect rate
'opt_eps': 0.001,            # optimizer epsilon
'lr': 0.064,                 # learning rate (LR)
'batch_size': 128,           # batch size
'decay_epochs': 2.4,         # epoch interval to decay LR
'warmup_epochs': 5,          # epochs to warm up LR
'decay_rate': 0.97,          # LR decay rate
'weight_decay': 1e-5,        # weight decay
'epochs': 600,               # number of epochs to train
'workers': 8,                # number of data-loading workers
'amp_level': 'O0',           # AMP level
'opt': 'rmsprop',            # optimizer
'num_classes': 1000,         # number of classes
'gp': 'avg',                 # type of global pool: "avg", "max", "avgmax", "avgmaxc"
'momentum': 0.9,             # optimizer momentum
'warmup_lr_init': 0.0001,    # initial warmup LR
'smoothing': 0.1,            # label smoothing factor
'bn_tf': False,              # use TensorFlow BatchNorm defaults
'keep_checkpoint_max': 10,   # max number of checkpoints to keep
'loss_scale': 1024,          # loss scale
'resume_start_epoch': 0,     # resume start epoch
```
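These values live in `src/config.py` as an `easydict` named `efficientnet_b0_config_gpu`, which is what `eval.py` and `src/dataset.py` import. For quick experiments you can read or override a field from Python before launching a run; a minimal sketch (assuming it is run from this model's root directory, with the `batch_size` override purely illustrative):

```python
from src.config import efficientnet_b0_config_gpu as cfg

# cfg is a plain EasyDict, so every key in the table above is available as an attribute.
print(cfg.model, cfg.lr, cfg.epochs)

# Illustrative override only; the shipped scripts use the defaults from src/config.py.
cfg.batch_size = 64
```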
## Running the example

### Train

#### Usage

```
# distributed training example (8p)
sh run_distribute_train_for_gpu.sh DATA_DIR
# standalone training
sh run_standalone_train_for_gpu.sh DATA_DIR DEVICE_ID
```

#### Launch

```bash
# distributed training example (8p) for GPU
sh scripts/run_distribute_train_for_gpu.sh /dataset
# standalone training example for GPU
sh scripts/run_standalone_train_for_gpu.sh /dataset 0
```

#### Result

Checkpoint files are saved during training; the training results can be found in the log.

### Evaluation

#### Usage

```
# Evaluation
sh run_eval_for_gpu.sh DATA_DIR DEVICE_ID PATH_CHECKPOINT
```

#### Launch

```bash
# Evaluation with a checkpoint
sh scripts/run_eval_for_gpu.sh /dataset 0 ./checkpoint/efficientnet_b0-600_1251.ckpt
```

> The checkpoint is produced during the training process.

#### Result

The evaluation result is stored in the scripts path. In the log you will find a line of the form `metric: {'Loss': ..., 'Top1-Acc': ..., 'Top5-Acc': ...}` printed by `eval.py`.
| @@ -0,0 +1,61 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """evaluate imagenet""" | |||||
| import argparse | |||||
| import os | |||||
| import mindspore.nn as nn | |||||
| from mindspore import context | |||||
| from mindspore.train.model import Model | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| from src.config import efficientnet_b0_config_gpu as cfg | |||||
| from src.dataset import create_dataset_val | |||||
| from src.efficientnet import efficientnet_b0 | |||||
| from src.loss import LabelSmoothingCrossEntropy | |||||
| if __name__ == '__main__': | |||||
| parser = argparse.ArgumentParser(description='image classification evaluation') | |||||
| parser.add_argument('--checkpoint', type=str, default='', help='checkpoint of efficientnet (Default: None)') | |||||
| parser.add_argument('--data_path', type=str, default='', help='Dataset path') | |||||
| parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU'), help='run platform') | |||||
| args_opt = parser.parse_args() | |||||
| if args_opt.platform == 'Ascend': | |||||
| device_id = int(os.getenv('DEVICE_ID')) | |||||
| context.set_context(device_id=device_id) | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform) | |||||
| net = efficientnet_b0(num_classes=cfg.num_classes, | |||||
| drop_rate=cfg.drop, | |||||
| drop_connect_rate=cfg.drop_connect, | |||||
| global_pool=cfg.gp, | |||||
| bn_tf=cfg.bn_tf, | |||||
| ) | |||||
| ckpt = load_checkpoint(args_opt.checkpoint) | |||||
| load_param_into_net(net, ckpt) | |||||
| net.set_train(False) | |||||
| val_data_url = os.path.join(args_opt.data_path, 'val') | |||||
| dataset = create_dataset_val(cfg.batch_size, val_data_url, workers=cfg.workers, distributed=False) | |||||
| loss = LabelSmoothingCrossEntropy(smooth_factor=cfg.smoothing) | |||||
| eval_metrics = {'Loss': nn.Loss(), | |||||
| 'Top1-Acc': nn.Top1CategoricalAccuracy(), | |||||
| 'Top5-Acc': nn.Top5CategoricalAccuracy()} | |||||
| model = Model(net, loss, optimizer=None, metrics=eval_metrics) | |||||
| metrics = model.eval(dataset) | |||||
| print("metric: ", metrics) | |||||
| @@ -0,0 +1,32 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
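# Usage: sh run_distribute_train_for_gpu.sh DATA_DIR
# Launches 8-GPU distributed training via mpirun; output is written to
# ./device_parallel/efficientnet_b0.log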
| DATA_DIR=$1 | |||||
| current_exec_path=$(pwd) | |||||
| echo ${current_exec_path} | |||||
| curtime=`date '+%Y%m%d-%H%M%S'` | |||||
| RANK_SIZE=8 | |||||
| rm ${current_exec_path}/device_parallel/ -rf | |||||
| mkdir ${current_exec_path}/device_parallel | |||||
| echo ${curtime} > ${current_exec_path}/device_parallel/starttime | |||||
| mpirun --allow-run-as-root -n $RANK_SIZE python ${current_exec_path}/train.py \ | |||||
| --GPU \ | |||||
| --distributed \ | |||||
| --data_path ${DATA_DIR} \ | |||||
| --cur_time ${curtime} > ${current_exec_path}/device_parallel/efficientnet_b0.log 2>&1 & | |||||
| @@ -0,0 +1,27 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
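# Usage: sh run_eval_for_gpu.sh DATA_DIR DEVICE_ID PATH_CHECKPOINT
# Evaluates the given checkpoint on one GPU; output is written to ./eval.log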
| DATA_DIR=$1 | |||||
| DEVICE_ID=$2 | |||||
| PATH_CHECKPOINT=$3 | |||||
| current_exec_path=$(pwd) | |||||
| echo ${current_exec_path} | |||||
| curtime=`date '+%Y%m%d-%H%M%S'` | |||||
| echo ${curtime} > ${current_exec_path}/eval_starttime | |||||
| CUDA_VISIBLE_DEVICES=${DEVICE_ID} python ./eval.py --platform 'GPU' --data_path ${DATA_DIR} --checkpoint ${PATH_CHECKPOINT} > ${current_exec_path}/eval.log 2>&1 & | |||||
| @@ -0,0 +1,31 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
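# Usage: sh run_standalone_train_for_gpu.sh DATA_DIR DEVICE_ID
# Launches single-GPU training; output is written to ./device_${DEVICE_ID}/efficientnet_b0.log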
| DATA_DIR=$1 | |||||
| DEVICE_ID=$2 | |||||
| current_exec_path=$(pwd) | |||||
| echo ${current_exec_path} | |||||
| curtime=`date '+%Y%m%d-%H%M%S'` | |||||
| rm ${current_exec_path}/device_${DEVICE_ID}/ -rf | |||||
| mkdir ${current_exec_path}/device_${DEVICE_ID} | |||||
| echo ${curtime} > ${current_exec_path}/device_${DEVICE_ID}/starttime | |||||
| CUDA_VISIBLE_DEVICES=${DEVICE_ID} python ${current_exec_path}/train.py \ | |||||
| --GPU \ | |||||
| --data_path ${DATA_DIR} \ | |||||
| --cur_time ${curtime} > ${current_exec_path}/device_${DEVICE_ID}/efficientnet_b0.log 2>&1 & | |||||
| @@ -0,0 +1,47 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| network config setting | |||||
| """ | |||||
| from easydict import EasyDict as edict | |||||
| efficientnet_b0_config_gpu = edict({ | |||||
| 'random_seed': 1, | |||||
| 'model': 'efficientnet_b0', | |||||
| 'drop': 0.2, | |||||
| 'drop_connect': 0.2, | |||||
| 'opt_eps': 0.001, | |||||
| 'lr': 0.064, | |||||
| 'batch_size': 128, | |||||
| 'decay_epochs': 2.4, | |||||
| 'warmup_epochs': 5, | |||||
| 'decay_rate': 0.97, | |||||
| 'weight_decay': 1e-5, | |||||
| 'epochs': 600, | |||||
| 'workers': 8, | |||||
| 'amp_level': 'O0', | |||||
| 'opt': 'rmsprop', | |||||
| 'num_classes': 1000, | |||||
    # Type of global pool: "avg", "max", "avgmax", "avgmaxc"
| 'gp': 'avg', | |||||
| 'momentum': 0.9, | |||||
| 'warmup_lr_init': 0.0001, | |||||
| 'smoothing': 0.1, | |||||
    # Use TensorFlow BatchNorm defaults for models that support it
| 'bn_tf': False, | |||||
| 'keep_checkpoint_max': 10, | |||||
| 'loss_scale': 1024, | |||||
| 'resume_start_epoch': 0, | |||||
| }) | |||||
| @@ -0,0 +1,125 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| Data operations, will be used in train.py and eval.py | |||||
| """ | |||||
| import math | |||||
| import os | |||||
| import mindspore.common.dtype as mstype | |||||
| import mindspore.dataset as ds | |||||
| import mindspore.dataset.transforms.c_transforms as C2 | |||||
| import mindspore.dataset.vision.c_transforms as C | |||||
| from mindspore.communication.management import get_group_size, get_rank | |||||
| from mindspore.dataset.vision import Inter | |||||
| from src.config import efficientnet_b0_config_gpu as cfg | |||||
| from src.transform import RandAugment | |||||
| ds.config.set_seed(cfg.random_seed) | |||||
| IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) | |||||
| IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) | |||||
| img_size = (224, 224) | |||||
| crop_pct = 0.875 | |||||
| rescale = 1.0 / 255.0 | |||||
| shift = 0.0 | |||||
| inter_method = 'bilinear' | |||||
| resize_value = 224 # img_size | |||||
| scale = (0.08, 1.0) | |||||
| ratio = (3./4., 4./3.) | |||||
| inter_str = 'bicubic' | |||||
| def str2MsInter(method): | |||||
| if method == 'bicubic': | |||||
| return Inter.BICUBIC | |||||
| if method == 'nearest': | |||||
| return Inter.NEAREST | |||||
| return Inter.BILINEAR | |||||
| def create_dataset(batch_size, train_data_url='', workers=8, distributed=False): | |||||
| if not os.path.exists(train_data_url): | |||||
        raise ValueError('Path does not exist')
| interpolation = str2MsInter(inter_str) | |||||
| c_decode_op = C.Decode() | |||||
| type_cast_op = C2.TypeCast(mstype.int32) | |||||
| random_resize_crop_op = C.RandomResizedCrop(size=(resize_value, resize_value), scale=scale, ratio=ratio, | |||||
| interpolation=interpolation) | |||||
| random_horizontal_flip_op = C.RandomHorizontalFlip(0.5) | |||||
| efficient_rand_augment = RandAugment() | |||||
| image_ops = [c_decode_op, random_resize_crop_op, random_horizontal_flip_op] | |||||
| rank_id = get_rank() if distributed else 0 | |||||
| rank_size = get_group_size() if distributed else 1 | |||||
| dataset_train = ds.ImageFolderDataset(train_data_url, | |||||
| num_parallel_workers=workers, | |||||
| shuffle=True, | |||||
| num_shards=rank_size, | |||||
| shard_id=rank_id) | |||||
| dataset_train = dataset_train.map(input_columns=["image"], | |||||
| operations=image_ops, | |||||
| num_parallel_workers=workers) | |||||
| dataset_train = dataset_train.map(input_columns=["label"], | |||||
| operations=type_cast_op, | |||||
| num_parallel_workers=workers) | |||||
| ds_train = dataset_train.batch(batch_size, | |||||
| per_batch_map=efficient_rand_augment, | |||||
| input_columns=["image", "label"], | |||||
| num_parallel_workers=2, | |||||
| drop_remainder=True) | |||||
| ds_train = ds_train.repeat(1) | |||||
| return ds_train | |||||
| def create_dataset_val(batch_size=128, val_data_url='', workers=8, distributed=False): | |||||
| if not os.path.exists(val_data_url): | |||||
        raise ValueError('Path does not exist')
| rank_id = get_rank() if distributed else 0 | |||||
| rank_size = get_group_size() if distributed else 1 | |||||
| dataset = ds.ImageFolderDataset(val_data_url, num_parallel_workers=workers, | |||||
| num_shards=rank_size, shard_id=rank_id, shuffle=False) | |||||
| scale_size = None | |||||
| interpolation = str2MsInter(inter_method) | |||||
| if isinstance(img_size, tuple): | |||||
| assert len(img_size) == 2 | |||||
| if img_size[-1] == img_size[-2]: | |||||
| scale_size = int(math.floor(img_size[0] / crop_pct)) | |||||
| else: | |||||
| scale_size = tuple([int(x / crop_pct) for x in img_size]) | |||||
| else: | |||||
| scale_size = int(math.floor(img_size / crop_pct)) | |||||
| type_cast_op = C2.TypeCast(mstype.int32) | |||||
| decode_op = C.Decode() | |||||
| resize_op = C.Resize(size=scale_size, interpolation=interpolation) | |||||
| center_crop = C.CenterCrop(size=224) | |||||
| rescale_op = C.Rescale(rescale, shift) | |||||
| normalize_op = C.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) | |||||
| changeswap_op = C.HWC2CHW() | |||||
| ctrans = [decode_op, resize_op, center_crop, rescale_op, normalize_op, changeswap_op] | |||||
| dataset = dataset.map(input_columns=["label"], operations=type_cast_op, num_parallel_workers=workers) | |||||
| dataset = dataset.map(input_columns=["image"], operations=ctrans, num_parallel_workers=workers) | |||||
| dataset = dataset.batch(batch_size, drop_remainder=True, num_parallel_workers=workers) | |||||
| dataset = dataset.repeat(1) | |||||
| return dataset | |||||
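# Typical usage (mirroring eval.py): the validation pipeline is built with
#     create_dataset_val(cfg.batch_size, os.path.join(data_path, 'val'),
#                        workers=cfg.workers, distributed=False)
# create_dataset() builds the corresponding training pipeline
# (decode -> random resized crop -> random flip -> RandAugment applied per batch).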
| @@ -0,0 +1,746 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """EfficientNet model definition""" | |||||
| import logging | |||||
| import math | |||||
| import re | |||||
| from copy import deepcopy | |||||
| import mindspore as ms | |||||
| import mindspore.nn as nn | |||||
| from mindspore import context, ms_function | |||||
| from mindspore.common.initializer import (Normal, One, Uniform, Zero, | |||||
| initializer) | |||||
| from mindspore.common.parameter import Parameter | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.ops.composite import clip_by_value | |||||
| relu = P.ReLU() | |||||
| sigmoid = P.Sigmoid() | |||||
| IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) | |||||
| IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) | |||||
| IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) | |||||
| IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) | |||||
| IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255) | |||||
| IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3) | |||||
| def _cfg(url='', **kwargs): | |||||
| return { | |||||
| 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), | |||||
| 'crop_pct': 0.875, 'interpolation': 'bicubic', | |||||
| 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, | |||||
| 'first_conv': 'conv_stem', 'classifier': 'classifier', | |||||
| **kwargs | |||||
| } | |||||
| default_cfgs = { | |||||
| 'efficientnet_b0': _cfg( | |||||
| url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0-d6904d92.pth'), | |||||
| 'efficientnet_b1': _cfg( | |||||
| url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth', | |||||
| input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), | |||||
| 'efficientnet_b2': _cfg( | |||||
| url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2-cf78dc4d.pth', | |||||
| input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), | |||||
| 'efficientnet_b3': _cfg( | |||||
| url='', input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), | |||||
| 'efficientnet_b4': _cfg( | |||||
| url='', input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), | |||||
| } | |||||
| _DEBUG = False | |||||
| _BN_MOMENTUM_PT_DEFAULT = 0.1 | |||||
| _BN_EPS_PT_DEFAULT = 1e-5 | |||||
| _BN_ARGS_PT = dict(momentum=_BN_MOMENTUM_PT_DEFAULT, eps=_BN_EPS_PT_DEFAULT) | |||||
| _BN_MOMENTUM_TF_DEFAULT = 1 - 0.99 | |||||
| _BN_EPS_TF_DEFAULT = 1e-3 | |||||
| _BN_ARGS_TF = dict(momentum=_BN_MOMENTUM_TF_DEFAULT, eps=_BN_EPS_TF_DEFAULT) | |||||
| def _initialize_weight_goog(shape=None, layer_type='conv', bias=False): | |||||
| if layer_type not in ('conv', 'bn', 'fc'): | |||||
| raise ValueError('The layer type is not known, the supported are conv, bn and fc') | |||||
| if bias: | |||||
| return Zero() | |||||
| if layer_type == 'conv': | |||||
| assert isinstance(shape, (tuple, list)) and len( | |||||
| shape) == 3, 'The shape must be 3 scalars, and are in_chs, ks, out_chs respectively' | |||||
| n = shape[1] * shape[1] * shape[2] | |||||
| return Normal(math.sqrt(2.0 / n)) | |||||
| if layer_type == 'bn': | |||||
| return One() | |||||
| assert isinstance(shape, (tuple, list)) and len( | |||||
| shape) == 2, 'The shape must be 2 scalars, and are in_chs, out_chs respectively' | |||||
| n = shape[1] | |||||
| init_range = 1.0 / math.sqrt(n) | |||||
| return Uniform(init_range) | |||||
| def _initialize_weight_default(shape=None, layer_type='conv', bias=False): | |||||
| if layer_type not in ('conv', 'bn', 'fc'): | |||||
| raise ValueError('The layer type is not known, the supported are conv, bn and fc') | |||||
| if bias and layer_type == 'bn': | |||||
| return Zero() | |||||
| if layer_type == 'conv': | |||||
| return One() | |||||
| if layer_type == 'bn': | |||||
| return One() | |||||
| return One() | |||||
| def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='same', bias=False): | |||||
| weight_init_value = _initialize_weight_goog(shape=(in_channels, kernel_size, out_channels)) | |||||
| bias_init_value = _initialize_weight_goog(bias=True) if bias else None | |||||
| if bias: | |||||
| return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, | |||||
| padding=padding, pad_mode=pad_mode, weight_init=weight_init_value, | |||||
| has_bias=bias, bias_init=bias_init_value) | |||||
| return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, | |||||
| padding=padding, pad_mode=pad_mode, weight_init=weight_init_value, | |||||
| has_bias=bias) | |||||
| def _conv1x1(in_channels, out_channels, stride=1, padding=0, pad_mode='same', bias=False): | |||||
| weight_init_value = _initialize_weight_goog(shape=(in_channels, 1, out_channels)) | |||||
| bias_init_value = _initialize_weight_goog(bias=True) if bias else None | |||||
| if bias: | |||||
| return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, | |||||
| padding=padding, pad_mode=pad_mode, weight_init=weight_init_value, | |||||
| has_bias=bias, bias_init=bias_init_value) | |||||
| return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, | |||||
| padding=padding, pad_mode=pad_mode, weight_init=weight_init_value, | |||||
| has_bias=bias) | |||||
| def _conv_group(in_channels, out_channels, group, kernel_size=3, stride=1, padding=0, pad_mode='same', bias=False): | |||||
| weight_init_value = _initialize_weight_goog(shape=(in_channels, kernel_size, out_channels)) | |||||
| bias_init_value = _initialize_weight_goog(bias=True) if bias else None | |||||
| if bias: | |||||
| return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, | |||||
| padding=padding, pad_mode=pad_mode, weight_init=weight_init_value, | |||||
| group=group, has_bias=bias, bias_init=bias_init_value) | |||||
| return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, | |||||
| padding=padding, pad_mode=pad_mode, weight_init=weight_init_value, | |||||
| group=group, has_bias=bias) | |||||
| def _fused_bn(channels, momentum=0.1, eps=1e-4, gamma_init=1, beta_init=0): | |||||
| return nn.BatchNorm2d(channels, eps=eps, momentum=1 - momentum, gamma_init=gamma_init, beta_init=beta_init) | |||||
| def _dense(in_channels, output_channels, bias=True, activation=None): | |||||
| weight_init_value = _initialize_weight_goog(shape=(in_channels, output_channels), layer_type='fc') | |||||
| bias_init_value = _initialize_weight_goog(bias=True) if bias else None | |||||
| if bias: | |||||
| return nn.Dense(in_channels, output_channels, weight_init=weight_init_value, bias_init=bias_init_value, | |||||
| has_bias=bias, activation=activation) | |||||
| return nn.Dense(in_channels, output_channels, weight_init=weight_init_value, has_bias=bias, | |||||
| activation=activation) | |||||
| def _resolve_bn_args(kwargs): | |||||
| bn_args = _BN_ARGS_TF.copy() if kwargs.pop('bn_tf', False) else _BN_ARGS_PT.copy() | |||||
| bn_momentum = kwargs.pop('bn_momentum', None) | |||||
| if bn_momentum is not None: | |||||
| bn_args['momentum'] = bn_momentum | |||||
| bn_eps = kwargs.pop('bn_eps', None) | |||||
| if bn_eps is not None: | |||||
| bn_args['eps'] = bn_eps | |||||
| return bn_args | |||||
| def _round_channels(channels, multiplier=1.0, divisor=8, channel_min=None): | |||||
| """Round number of filters based on depth multiplier.""" | |||||
| if not multiplier: | |||||
| return channels | |||||
| channels *= multiplier | |||||
| channel_min = channel_min or divisor | |||||
| new_channels = max( | |||||
| int(channels + divisor / 2) // divisor * divisor, | |||||
| channel_min) | |||||
| if new_channels < 0.9 * channels: | |||||
| new_channels += divisor | |||||
| return new_channels | |||||
| def _parse_ksize(ss): | |||||
| if ss.isdigit(): | |||||
| return int(ss) | |||||
| return [int(k) for k in ss.split('.')] | |||||
| def _decode_block_str(block_str, depth_multiplier=1.0): | |||||
| """ Decode block definition string | |||||
| Gets a list of block arg (dicts) through a string notation of arguments. | |||||
| E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip | |||||
| All args can exist in any order with the exception of the leading string which | |||||
| is assumed to indicate the block type. | |||||
| leading string - block type ( | |||||
        ir = InvertedResidual, ds = DepthwiseSep, dsa = DepthwiseSep with pw act, cn = ConvBnAct)
| r - number of repeat blocks, | |||||
| k - kernel size, | |||||
| s - strides (1-9), | |||||
| e - expansion ratio, | |||||
| c - output channels, | |||||
| se - squeeze/excitation ratio | |||||
| n - activation fn ('re', 'r6', 'hs', or 'sw') | |||||
| Args: | |||||
| block_str: a string representation of block arguments. | |||||
| Returns: | |||||
| A list of block args (dicts) | |||||
| Raises: | |||||
        ValueError: if the string definition is not properly specified (TODO)
| """ | |||||
| assert isinstance(block_str, str) | |||||
| ops = block_str.split('_') | |||||
| block_type = ops[0] | |||||
| ops = ops[1:] | |||||
| options = {} | |||||
| noskip = False | |||||
| for op in ops: | |||||
| if op == 'noskip': | |||||
| noskip = True | |||||
        elif op.startswith('n'):
            # activation fn override; none of the listed activations ('re', 'r6', 'hs', 'sw')
            # are supported in this port, so the option is reported and then ignored
            if op[1:] in ('re', 'r6', 'hs', 'sw'):
                print('not support')
            continue
| else: | |||||
| # all numeric options | |||||
| splits = re.split(r'(\d.*)', op) | |||||
| if len(splits) >= 2: | |||||
| key, value = splits[:2] | |||||
| options[key] = value | |||||
| act_fn = options['n'] if 'n' in options else None | |||||
| exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1 | |||||
| pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1 | |||||
| fake_in_chs = int(options['fc']) if 'fc' in options else 0 | |||||
| num_repeat = int(options['r']) | |||||
| # each type of block has different valid arguments, fill accordingly | |||||
| if block_type == 'ir': | |||||
| block_args = dict( | |||||
| block_type=block_type, | |||||
| dw_kernel_size=_parse_ksize(options['k']), | |||||
| exp_kernel_size=exp_kernel_size, | |||||
| pw_kernel_size=pw_kernel_size, | |||||
| out_chs=int(options['c']), | |||||
| exp_ratio=float(options['e']), | |||||
| se_ratio=float(options['se']) if 'se' in options else None, | |||||
| stride=int(options['s']), | |||||
| act_fn=act_fn, | |||||
| noskip=noskip, | |||||
| ) | |||||
| elif block_type in ('ds', 'dsa'): | |||||
| block_args = dict( | |||||
| block_type=block_type, | |||||
| dw_kernel_size=_parse_ksize(options['k']), | |||||
| pw_kernel_size=pw_kernel_size, | |||||
| out_chs=int(options['c']), | |||||
| se_ratio=float(options['se']) if 'se' in options else None, | |||||
| stride=int(options['s']), | |||||
| act_fn=act_fn, | |||||
| pw_act=block_type == 'dsa', | |||||
| noskip=block_type == 'dsa' or noskip, | |||||
| ) | |||||
| elif block_type == 'er': | |||||
| block_args = dict( | |||||
| block_type=block_type, | |||||
| exp_kernel_size=_parse_ksize(options['k']), | |||||
| pw_kernel_size=pw_kernel_size, | |||||
| out_chs=int(options['c']), | |||||
| exp_ratio=float(options['e']), | |||||
| fake_in_chs=fake_in_chs, | |||||
| se_ratio=float(options['se']) if 'se' in options else None, | |||||
| stride=int(options['s']), | |||||
| act_fn=act_fn, | |||||
| noskip=noskip, | |||||
| ) | |||||
| elif block_type == 'cn': | |||||
| block_args = dict( | |||||
| block_type=block_type, | |||||
| kernel_size=int(options['k']), | |||||
| out_chs=int(options['c']), | |||||
| stride=int(options['s']), | |||||
| act_fn=act_fn, | |||||
| ) | |||||
| else: | |||||
| assert False, 'Unknown block type (%s)' % block_type | |||||
| return block_args, num_repeat | |||||
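# Example: _decode_block_str('ds_r1_k3_s1_e1_c16_se0.25') (the first stage of EfficientNet-B0)
# returns a depthwise-separable block spec
#     {'block_type': 'ds', 'dw_kernel_size': 3, 'pw_kernel_size': 1, 'out_chs': 16,
#      'se_ratio': 0.25, 'stride': 1, 'act_fn': None, 'pw_act': False, 'noskip': False}
# together with num_repeat = 1.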
| def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'): | |||||
| """ Per-stage depth scaling | |||||
| Scales the block repeats in each stage. This depth scaling impl maintains | |||||
| compatibility with the EfficientNet scaling method, while allowing sensible | |||||
| scaling for other models that may have multiple block arg definitions in each stage. | |||||
| """ | |||||
| # We scale the total repeat count for each stage, there may be multiple | |||||
| # block arg defs per stage so we need to sum. | |||||
| num_repeat = sum(repeats) | |||||
| if depth_trunc == 'round': | |||||
| # Truncating to int by rounding allows stages with few repeats to remain | |||||
| # proportionally smaller for longer. This is a good choice when stage definitions | |||||
| # include single repeat stages that we'd prefer to keep that way as long as possible | |||||
| num_repeat_scaled = max(1, round(num_repeat * depth_multiplier)) | |||||
| else: | |||||
| # The default for EfficientNet truncates repeats to int via 'ceil'. | |||||
| # Any multiplier > 1.0 will result in an increased depth for every stage. | |||||
| num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier)) | |||||
| # Proportionally distribute repeat count scaling to each block definition in the stage. | |||||
| # Allocation is done in reverse as it results in the first block being less likely to be scaled. | |||||
| # The first block makes less sense to repeat in most of the arch definitions. | |||||
| repeats_scaled = [] | |||||
| for r in repeats[::-1]: | |||||
| rs = max(1, round((r / num_repeat * num_repeat_scaled))) | |||||
| repeats_scaled.append(rs) | |||||
| num_repeat -= r | |||||
| num_repeat_scaled -= rs | |||||
| repeats_scaled = repeats_scaled[::-1] | |||||
| # Apply the calculated scaling to each block arg in the stage | |||||
| sa_scaled = [] | |||||
| for ba, rep in zip(stack_args, repeats_scaled): | |||||
| sa_scaled.extend([deepcopy(ba) for _ in range(rep)]) | |||||
| return sa_scaled | |||||
| def _decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil'): | |||||
| arch_args = [] | |||||
| for _, block_strings in enumerate(arch_def): | |||||
| assert isinstance(block_strings, list) | |||||
| stack_args = [] | |||||
| repeats = [] | |||||
| for block_str in block_strings: | |||||
| assert isinstance(block_str, str) | |||||
| ba, rep = _decode_block_str(block_str) | |||||
| stack_args.append(ba) | |||||
| repeats.append(rep) | |||||
| arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc)) | |||||
| return arch_args | |||||
| @ms_function | |||||
| def hard_swish(x): | |||||
| x = P.Cast()(x, ms.float32) | |||||
| y = x + 3.0 | |||||
| y = clip_by_value(y, 0.0, 6.0) | |||||
| y = y / 6.0 | |||||
| return x * y | |||||
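# hard_swish(x) = x * clip(x + 3, 0, 6) / 6, a piecewise-linear approximation of swish;
# _gen_efficientnet below passes it as the activation function for the whole network.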
| class BlockBuilder(nn.Cell): | |||||
| def __init__(self, builder_in_channels, builder_block_args, channel_multiplier=1.0, channel_divisor=8, | |||||
| channel_min=None, pad_type='', act_fn=None, se_gate_fn=sigmoid, se_reduce_mid=False, | |||||
| bn_args=None, drop_connect_rate=0., verbose=False): | |||||
| super(BlockBuilder, self).__init__() | |||||
| bn_args = _BN_ARGS_PT if bn_args is None else bn_args | |||||
| self.channel_multiplier = channel_multiplier | |||||
| self.channel_divisor = channel_divisor | |||||
| self.channel_min = channel_min | |||||
| self.pad_type = pad_type | |||||
| self.act_fn = act_fn | |||||
| self.se_gate_fn = se_gate_fn | |||||
| self.se_reduce_mid = se_reduce_mid | |||||
| self.bn_args = bn_args | |||||
| self.drop_connect_rate = drop_connect_rate | |||||
| self.verbose = verbose | |||||
| self.in_chs = None | |||||
| self.block_idx = 0 | |||||
| self.block_count = 0 | |||||
| self.layer = self._make_layer(builder_in_channels, builder_block_args) | |||||
| def _round_channels(self, chs): | |||||
| return _round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min) | |||||
| def _make_block(self, ba): | |||||
| bt = ba.pop('block_type') | |||||
| ba['in_chs'] = self.in_chs | |||||
| ba['out_chs'] = self._round_channels(ba['out_chs']) | |||||
| if 'fake_in_chs' in ba and ba['fake_in_chs']: | |||||
| ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs']) | |||||
| ba['bn_args'] = self.bn_args | |||||
| ba['pad_type'] = self.pad_type | |||||
| ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn | |||||
| assert ba['act_fn'] is not None | |||||
| if bt == 'ir': | |||||
| ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count | |||||
| ba['se_gate_fn'] = self.se_gate_fn | |||||
| ba['se_reduce_mid'] = self.se_reduce_mid | |||||
| if self.verbose: | |||||
| logging.info(' InvertedResidual %d, Args: %s', self.block_idx, str(ba)) | |||||
| block = InvertedResidual(**ba) | |||||
| elif bt in ('ds', 'dsa'): | |||||
| ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count | |||||
| if self.verbose: | |||||
| logging.info(' DepthwiseSeparable %d, Args: %s', self.block_idx, str(ba)) | |||||
| block = DepthwiseSeparableConv(**ba) | |||||
| else: | |||||
            assert False, 'Unknown block type (%s) while building model.' % bt
| self.in_chs = ba['out_chs'] | |||||
| return block | |||||
| def _make_stack(self, stack_args): | |||||
| blocks = [] | |||||
| # each stack (stage) contains a list of block arguments | |||||
| for i, ba in enumerate(stack_args): | |||||
| if self.verbose: | |||||
| logging.info(' Block: %d', i) | |||||
| if i >= 1: | |||||
| # only the first block in any stack can have a stride > 1 | |||||
| ba['stride'] = 1 | |||||
| block = self._make_block(ba) | |||||
| blocks.append(block) | |||||
| self.block_idx += 1 | |||||
| return nn.SequentialCell(blocks) | |||||
| def _make_layer(self, in_chs, block_args): | |||||
| """ Build the blocks | |||||
| Args: | |||||
| in_chs: Number of input-channels passed to first block | |||||
| block_args: A list of lists, outer list defines stages, inner | |||||
| list contains strings defining block configuration(s) | |||||
| Return: | |||||
| List of block stacks (each stack wrapped in nn.Sequential) | |||||
| """ | |||||
| if self.verbose: | |||||
| logging.info('Building model trunk with %d stages...', len(block_args)) | |||||
| self.in_chs = in_chs | |||||
| self.block_count = sum([len(x) for x in block_args]) | |||||
| self.block_idx = 0 | |||||
| blocks = [] | |||||
| for stack_idx, stack in enumerate(block_args): | |||||
| if self.verbose: | |||||
| logging.info('Stack: %d', stack_idx) | |||||
| assert isinstance(stack, list) | |||||
| stack = self._make_stack(stack) | |||||
| blocks.append(stack) | |||||
| return nn.SequentialCell(blocks) | |||||
| def construct(self, x): | |||||
| return self.layer(x) | |||||
| class DepthWiseConv(nn.Cell): | |||||
| def __init__(self, in_planes, kernel_size, stride): | |||||
| super(DepthWiseConv, self).__init__() | |||||
| platform = context.get_context("device_target") | |||||
| weight_shape = [1, kernel_size, in_planes] | |||||
| weight_init = _initialize_weight_goog(shape=weight_shape) | |||||
| if platform == "GPU": | |||||
| self.depthwise_conv = P.Conv2D(out_channel=in_planes * 1, kernel_size=kernel_size, | |||||
| stride=stride, pad_mode="same", group=in_planes) | |||||
| self.weight = Parameter(initializer( | |||||
| weight_init, [in_planes * 1, 1, kernel_size, kernel_size]), name='depthwise_weight') | |||||
| else: | |||||
| self.depthwise_conv = P.DepthwiseConv2dNative( | |||||
| channel_multiplier=1, kernel_size=kernel_size, stride=stride, pad_mode='same',) | |||||
| self.weight = Parameter(initializer( | |||||
| weight_init, [1, in_planes, kernel_size, kernel_size]), name='depthwise_weight') | |||||
| def construct(self, x): | |||||
| x = self.depthwise_conv(x, self.weight) | |||||
| return x | |||||
| class DropConnect(nn.Cell): | |||||
| def __init__(self, drop_connect_rate=0., seed0=0, seed1=0): | |||||
| super(DropConnect, self).__init__() | |||||
| self.shape = P.Shape() | |||||
| self.dtype = P.DType() | |||||
| self.keep_prob = 1 - drop_connect_rate | |||||
| self.dropout = P.Dropout(keep_prob=self.keep_prob) | |||||
| def construct(self, x): | |||||
| shape = self.shape(x) | |||||
| dtype = self.dtype(x) | |||||
| ones_tensor = P.Fill()(dtype, (shape[0], 1, 1, 1), 1) | |||||
| _, mask_ = self.dropout(ones_tensor) | |||||
| x = x * mask_ | |||||
| return x | |||||
| def drop_connect(inputs, training=False, drop_connect_rate=0.): | |||||
| if not training: | |||||
| return inputs | |||||
| return DropConnect(drop_connect_rate)(inputs) | |||||
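# drop_connect (stochastic depth): during training, entire samples in the batch are randomly
# zeroed by a (N, 1, 1, 1) mask with keep probability 1 - drop_connect_rate; at inference the
# input is returned unchanged.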
| class SqueezeExcite(nn.Cell): | |||||
| def __init__(self, in_chs, reduce_chs=None, act_fn=relu, gate_fn=sigmoid): | |||||
| super(SqueezeExcite, self).__init__() | |||||
| self.act_fn = act_fn | |||||
| self.gate_fn = gate_fn | |||||
| reduce_chs = reduce_chs or in_chs | |||||
| self.conv_reduce = _dense(in_chs, reduce_chs, bias=True) | |||||
| self.conv_expand = _dense(reduce_chs, in_chs, bias=True) | |||||
| self.avg_global_pool = P.ReduceMean(keep_dims=False) | |||||
| def construct(self, x): | |||||
| x_se = self.avg_global_pool(x, (2, 3)) | |||||
| x_se = self.conv_reduce(x_se) | |||||
| x_se = self.act_fn(x_se) | |||||
| x_se = self.conv_expand(x_se) | |||||
| x_se = self.gate_fn(x_se) | |||||
| x_se = P.ExpandDims()(x_se, 2) | |||||
| x_se = P.ExpandDims()(x_se, 3) | |||||
| x = x * x_se | |||||
| return x | |||||
| class DepthwiseSeparableConv(nn.Cell): | |||||
| def __init__(self, in_chs, out_chs, dw_kernel_size=3, | |||||
| stride=1, pad_type='', act_fn=relu, noskip=False, | |||||
| pw_kernel_size=1, pw_act=False, se_ratio=0., se_gate_fn=sigmoid, | |||||
| bn_args=None, drop_connect_rate=0.): | |||||
| super(DepthwiseSeparableConv, self).__init__() | |||||
| bn_args = _BN_ARGS_PT if bn_args is None else bn_args | |||||
| assert stride in [1, 2], 'stride must be 1 or 2' | |||||
| self.has_se = se_ratio is not None and se_ratio > 0. | |||||
| self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip | |||||
| self.has_pw_act = pw_act | |||||
| self.act_fn = act_fn | |||||
| self.drop_connect_rate = drop_connect_rate | |||||
| self.conv_dw = DepthWiseConv(in_chs, dw_kernel_size, stride) | |||||
| self.bn1 = _fused_bn(in_chs, **bn_args) | |||||
| # | |||||
| if self.has_se: | |||||
| self.se = SqueezeExcite(in_chs, reduce_chs=max(1, int(in_chs * se_ratio)), | |||||
| act_fn=act_fn, gate_fn=se_gate_fn) | |||||
| self.conv_pw = _conv1x1(in_chs, out_chs) | |||||
| self.bn2 = _fused_bn(out_chs, **bn_args) | |||||
| def construct(self, x): | |||||
| identity = x | |||||
| x = self.conv_dw(x) | |||||
| x = self.bn1(x) | |||||
| x = self.act_fn(x) | |||||
| if self.has_se: | |||||
| x = self.se(x) | |||||
| x = self.conv_pw(x) | |||||
| x = self.bn2(x) | |||||
| if self.has_pw_act: | |||||
| x = self.act_fn(x) | |||||
| if self.has_residual: | |||||
| if self.drop_connect_rate > 0.: | |||||
| x = drop_connect(x, self.training, self.drop_connect_rate) | |||||
| x = x + identity | |||||
| return x | |||||
| class InvertedResidual(nn.Cell): | |||||
| def __init__(self, in_chs, out_chs, dw_kernel_size=3, stride=1, | |||||
| pad_type='', act_fn=relu, pw_kernel_size=1, | |||||
| noskip=False, exp_ratio=1., exp_kernel_size=1, se_ratio=0., | |||||
| se_reduce_mid=False, se_gate_fn=sigmoid, shuffle_type=None, | |||||
| bn_args=None, drop_connect_rate=0.): | |||||
| super(InvertedResidual, self).__init__() | |||||
| bn_args = _BN_ARGS_PT if bn_args is None else bn_args | |||||
| mid_chs = int(in_chs * exp_ratio) | |||||
| self.has_se = se_ratio is not None and se_ratio > 0. | |||||
| self.has_residual = (in_chs == out_chs and stride == 1) and not noskip | |||||
| self.act_fn = act_fn | |||||
| self.drop_connect_rate = drop_connect_rate | |||||
| self.conv_pw = _conv(in_chs, mid_chs, exp_kernel_size) | |||||
| self.bn1 = _fused_bn(mid_chs, **bn_args) | |||||
| self.shuffle_type = shuffle_type | |||||
| if self.shuffle_type is not None and isinstance(exp_kernel_size, list): | |||||
| self.shuffle = None | |||||
| self.conv_dw = DepthWiseConv(mid_chs, dw_kernel_size, stride) | |||||
| self.bn2 = _fused_bn(mid_chs, **bn_args) | |||||
| if self.has_se: | |||||
| se_base_chs = mid_chs if se_reduce_mid else in_chs | |||||
| self.se = SqueezeExcite(mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)), | |||||
| act_fn=act_fn, gate_fn=se_gate_fn) | |||||
| self.conv_pwl = _conv(mid_chs, out_chs, pw_kernel_size) | |||||
| self.bn3 = _fused_bn(out_chs, **bn_args) | |||||
| def construct(self, x): | |||||
| identity = x | |||||
| x = self.conv_pw(x) | |||||
| x = self.bn1(x) | |||||
| x = self.act_fn(x) | |||||
| x = self.conv_dw(x) | |||||
| x = self.bn2(x) | |||||
| x = self.act_fn(x) | |||||
| if self.has_se: | |||||
| x = self.se(x) | |||||
| x = self.conv_pwl(x) | |||||
| x = self.bn3(x) | |||||
| if self.has_residual: | |||||
| if self.drop_connect_rate > 0: | |||||
| x = drop_connect(x, self.training, self.drop_connect_rate) | |||||
| x = x + identity | |||||
| return x | |||||
| class GenEfficientNet(nn.Cell): | |||||
| def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280, | |||||
| channel_multiplier=1.0, channel_divisor=8, channel_min=None, | |||||
| pad_type='', act_fn=relu, drop_rate=0., drop_connect_rate=0., | |||||
| se_gate_fn=sigmoid, se_reduce_mid=False, bn_args=None, | |||||
| global_pool='avg', head_conv='default', weight_init='goog'): | |||||
| super(GenEfficientNet, self).__init__() | |||||
| bn_args = _BN_ARGS_PT if bn_args is None else bn_args | |||||
| self.num_classes = num_classes | |||||
| self.drop_rate = drop_rate | |||||
| self.act_fn = act_fn | |||||
| self.num_features = num_features | |||||
| stem_size = _round_channels(stem_size, channel_multiplier, channel_divisor, channel_min) | |||||
| self.conv_stem = _conv(in_chans, stem_size, 3, stride=2) | |||||
| self.bn1 = _fused_bn(stem_size, **bn_args) | |||||
| in_chans = stem_size | |||||
| self.blocks = BlockBuilder(in_chans, block_args, channel_multiplier, channel_divisor, channel_min, | |||||
| pad_type, act_fn, se_gate_fn, se_reduce_mid, | |||||
| bn_args, drop_connect_rate, verbose=_DEBUG) | |||||
| in_chs = self.blocks.in_chs | |||||
| if not head_conv or head_conv == 'none': | |||||
| self.efficient_head = False | |||||
| self.conv_head = None | |||||
| assert in_chs == self.num_features | |||||
| else: | |||||
| self.efficient_head = head_conv == 'efficient' | |||||
| self.conv_head = _conv1x1(in_chs, self.num_features) | |||||
| self.bn2 = None if self.efficient_head else _fused_bn(self.num_features, **bn_args) | |||||
| self.global_pool = P.ReduceMean(keep_dims=True) | |||||
| self.classifier = _dense(self.num_features, self.num_classes) | |||||
| self.reshape = P.Reshape() | |||||
| self.shape = P.Shape() | |||||
| self.drop_out = nn.Dropout(keep_prob=1 - self.drop_rate) | |||||
| def construct(self, x): | |||||
| x = self.conv_stem(x) | |||||
| x = self.bn1(x) | |||||
| x = self.act_fn(x) | |||||
| x = self.blocks(x) | |||||
| if self.efficient_head: | |||||
| x = self.global_pool(x, (2, 3)) | |||||
| x = self.conv_head(x) | |||||
| x = self.act_fn(x) | |||||
            x = self.reshape(x, (self.shape(x)[0], -1))
| else: | |||||
| if self.conv_head is not None: | |||||
| x = self.conv_head(x) | |||||
| x = self.bn2(x) | |||||
| x = self.act_fn(x) | |||||
| x = self.global_pool(x, (2, 3)) | |||||
| x = self.reshape(x, (self.shape(x)[0], -1)) | |||||
| if self.training and self.drop_rate > 0.: | |||||
| x = self.drop_out(x) | |||||
| return self.classifier(x) | |||||
| def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=1000, **kwargs): | |||||
| """Creates an EfficientNet model. | |||||
| Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py | |||||
| Paper: https://arxiv.org/abs/1905.11946 | |||||
| EfficientNet params | |||||
| name: (channel_multiplier, depth_multiplier, resolution, dropout_rate) | |||||
| 'efficientnet-b0': (1.0, 1.0, 224, 0.2), | |||||
| 'efficientnet-b1': (1.0, 1.1, 240, 0.2), | |||||
| 'efficientnet-b2': (1.1, 1.2, 260, 0.3), | |||||
| 'efficientnet-b3': (1.2, 1.4, 300, 0.3), | |||||
| 'efficientnet-b4': (1.4, 1.8, 380, 0.4), | |||||
| 'efficientnet-b5': (1.6, 2.2, 456, 0.4), | |||||
| 'efficientnet-b6': (1.8, 2.6, 528, 0.5), | |||||
| 'efficientnet-b7': (2.0, 3.1, 600, 0.5), | |||||
| Args: | |||||
| channel_multiplier: multiplier to number of channels per layer | |||||
| depth_multiplier: multiplier to number of repeats per stage | |||||
| """ | |||||
| arch_def = [ | |||||
| ['ds_r1_k3_s1_e1_c16_se0.25'], | |||||
| ['ir_r2_k3_s2_e6_c24_se0.25'], | |||||
| ['ir_r2_k5_s2_e6_c40_se0.25'], | |||||
| ['ir_r3_k3_s2_e6_c80_se0.25'], | |||||
| ['ir_r3_k5_s1_e6_c112_se0.25'], | |||||
| ['ir_r4_k5_s2_e6_c192_se0.25'], | |||||
| ['ir_r1_k3_s1_e6_c320_se0.25'], | |||||
| ] | |||||
| num_features = _round_channels(1280, channel_multiplier, 8, None) | |||||
| model = GenEfficientNet( | |||||
| _decode_arch_def(arch_def, depth_multiplier), | |||||
| num_classes=num_classes, | |||||
| stem_size=32, | |||||
| channel_multiplier=channel_multiplier, | |||||
| num_features=num_features, | |||||
| bn_args=_resolve_bn_args(kwargs), | |||||
| act_fn=hard_swish, | |||||
| **kwargs | |||||
| ) | |||||
| return model | |||||
| def efficientnet_b0(num_classes=1000, in_chans=3, **kwargs): | |||||
| """ EfficientNet-B0 """ | |||||
| default_cfg = default_cfgs['efficientnet_b0'] | |||||
| model = _gen_efficientnet( | |||||
| channel_multiplier=1.0, depth_multiplier=1.0, | |||||
| num_classes=num_classes, in_chans=in_chans, **kwargs) | |||||
| model.default_cfg = default_cfg | |||||
| return model | |||||
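# Example (mirroring eval.py): build EfficientNet-B0 for ImageNet with the GPU config.
#     from src.config import efficientnet_b0_config_gpu as cfg
#     net = efficientnet_b0(num_classes=cfg.num_classes, drop_rate=cfg.drop,
#                           drop_connect_rate=cfg.drop_connect, global_pool=cfg.gp, bn_tf=cfg.bn_tf)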
| @@ -0,0 +1,37 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """define loss function for network.""" | |||||
| from mindspore.nn.loss.loss import _Loss | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.ops import functional as F | |||||
| from mindspore.common import dtype as mstype | |||||
| from mindspore import Tensor | |||||
| import mindspore.nn as nn | |||||
| class LabelSmoothingCrossEntropy(_Loss): | |||||
| def __init__(self, smooth_factor=0.1, num_classes=1000): | |||||
| super(LabelSmoothingCrossEntropy, self).__init__() | |||||
| self.onehot = P.OneHot() | |||||
| self.on_value = Tensor(1.0 - smooth_factor, mstype.float32) | |||||
| self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32) | |||||
| self.ce = nn.SoftmaxCrossEntropyWithLogits() | |||||
| self.mean = P.ReduceMean(False) | |||||
| def construct(self, logits, label): | |||||
| one_hot_label = self.onehot(label, F.shape(logits)[1], self.on_value, self.off_value) | |||||
| loss_logit = self.ce(logits, one_hot_label) | |||||
| loss_logit = self.mean(loss_logit, 0) | |||||
| return loss_logit | |||||
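# With the defaults smooth_factor=0.1 and num_classes=1000, the true class receives a target
# probability of 0.9 and every other class 0.1 / 999, so the smoothed distribution still sums to 1.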
| @@ -0,0 +1,48 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| random augment class | |||||
| """ | |||||
| import numpy as np | |||||
| import mindspore.dataset.vision.py_transforms as P | |||||
| from src import transform_utils | |||||
| IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) | |||||
| IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) | |||||
| class RandAugment: | |||||
    # config_str: str, the rand augment policy, e.g. "rand-m9-mstd0.5"
    # hparams: dict of extra hyper-parameters forwarded to transform_utils.rand_augment_transform
| def __init__(self, config_str="rand-m9-mstd0.5", hparams=None): | |||||
| hparams = hparams if hparams is not None else {} | |||||
| self.config_str = config_str | |||||
| self.hparams = hparams | |||||
| def __call__(self, imgs, labels, batchInfo): | |||||
        # convert each decoded image to a PIL image, apply rand augment, then normalize to CHW
| ret_imgs = [] | |||||
| ret_labels = [] | |||||
| py_to_pil_op = P.ToPIL() | |||||
| to_tensor = P.ToTensor() | |||||
| normalize_op = P.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) | |||||
| rand_augment_ops = transform_utils.rand_augment_transform(self.config_str, self.hparams) | |||||
| for i, image in enumerate(imgs): | |||||
| img_pil = py_to_pil_op(image) | |||||
| img_pil = rand_augment_ops(img_pil) | |||||
| img_array = to_tensor(img_pil) | |||||
| img_array = normalize_op(img_array) | |||||
| ret_imgs.append(img_array) | |||||
| ret_labels.append(labels[i]) | |||||
| return np.array(ret_imgs), np.array(ret_labels) | |||||
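# RandAugment is applied as a per-batch map in dataset.py:
#     dataset_train.batch(batch_size, per_batch_map=RandAugment(), input_columns=["image", "label"], ...)
# so every image in the batch is augmented, normalized and returned in CHW layout here.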
| @@ -0,0 +1,571 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| random augment utils | |||||
| """ | |||||
| import math | |||||
| import random | |||||
| import re | |||||
| import numpy as np | |||||
| import PIL | |||||
| from PIL import Image, ImageEnhance, ImageOps | |||||
| _PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]]) | |||||
| _FILL = (128, 128, 128) | |||||
| _MAX_LEVEL = 10. | |||||
| _HPARAMS_DEFAULT = dict(translate_const=250, img_mean=_FILL) | |||||
| _RAND_TRANSFORMS = [ | |||||
| 'Distort', | |||||
| 'Zoom', | |||||
| 'Blur', | |||||
| 'Skew', | |||||
| 'AutoContrast', | |||||
| 'Equalize', | |||||
| 'Invert', | |||||
| 'Rotate', | |||||
| 'PosterizeTpu', | |||||
| 'Solarize', | |||||
| 'SolarizeAdd', | |||||
| 'Color', | |||||
| 'Contrast', | |||||
| 'Brightness', | |||||
| 'Sharpness', | |||||
| 'ShearX', | |||||
| 'ShearY', | |||||
| 'TranslateXRel', | |||||
| 'TranslateYRel', | |||||
| ] | |||||
| _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) | |||||
| _RAND_CHOICE_WEIGHTS_0 = { | |||||
| 'Rotate': 0.3, | |||||
| 'ShearX': 0.2, | |||||
| 'ShearY': 0.2, | |||||
| 'TranslateXRel': 0.1, | |||||
| 'TranslateYRel': 0.1, | |||||
| 'Color': .025, | |||||
| 'Sharpness': 0.025, | |||||
| 'AutoContrast': 0.025, | |||||
| 'Solarize': .005, | |||||
| 'SolarizeAdd': .005, | |||||
| 'Contrast': .005, | |||||
| 'Brightness': .005, | |||||
| 'Equalize': .005, | |||||
| 'PosterizeTpu': 0, | |||||
| 'Invert': 0, | |||||
| 'Distort': 0, | |||||
| 'Zoom': 0, | |||||
| 'Blur': 0, | |||||
| 'Skew': 0, | |||||
| } | |||||
| def _interpolation(kwargs): | |||||
| interpolation = kwargs.pop('resample', Image.BILINEAR) | |||||
| if isinstance(interpolation, (list, tuple)): | |||||
| return random.choice(interpolation) | |||||
| return interpolation | |||||
| def _check_args_tf(kwargs): | |||||
| if 'fillcolor' in kwargs and _PIL_VER < (5, 0): | |||||
| kwargs.pop('fillcolor') | |||||
| kwargs['resample'] = _interpolation(kwargs) | |||||
| # define all kinds of functions | |||||
| def _randomly_negate(v): | |||||
| return -v if random.random() > 0.5 else v | |||||
| def shear_x(img, factor, **kwargs): | |||||
| _check_args_tf(kwargs) | |||||
| return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs) | |||||
| def shear_y(img, factor, **kwargs): | |||||
| _check_args_tf(kwargs) | |||||
| return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs) | |||||
| def translate_x_rel(img, pct, **kwargs): | |||||
| pixels = pct * img.size[0] | |||||
| _check_args_tf(kwargs) | |||||
| return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) | |||||
| def translate_y_rel(img, pct, **kwargs): | |||||
| pixels = pct * img.size[1] | |||||
| _check_args_tf(kwargs) | |||||
| return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) | |||||
| def translate_x_abs(img, pixels, **kwargs): | |||||
| _check_args_tf(kwargs) | |||||
| return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) | |||||
| def translate_y_abs(img, pixels, **kwargs): | |||||
| _check_args_tf(kwargs) | |||||
| return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) | |||||
| def rotate(img, degrees, **kwargs): | |||||
| kwargs_new = kwargs | |||||
| kwargs_new.pop('resample') | |||||
| kwargs_new['resample'] = Image.BICUBIC | |||||
| if _PIL_VER >= (5, 2): | |||||
| return img.rotate(degrees, **kwargs_new) | |||||
| if _PIL_VER >= (5, 0): | |||||
| w, h = img.size | |||||
| post_trans = (0, 0) | |||||
| rotn_center = (w / 2.0, h / 2.0) | |||||
| angle = -math.radians(degrees) | |||||
| matrix = [ | |||||
| round(math.cos(angle), 15), | |||||
| round(math.sin(angle), 15), | |||||
| 0.0, | |||||
| round(-math.sin(angle), 15), | |||||
| round(math.cos(angle), 15), | |||||
| 0.0, | |||||
| ] | |||||
| def transform(x, y, matrix): | |||||
| (a, b, c, d, e, f) = matrix | |||||
| return a * x + b * y + c, d * x + e * y + f | |||||
| matrix[2], matrix[5] = transform( | |||||
| -rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix | |||||
| ) | |||||
| matrix[2] += rotn_center[0] | |||||
| matrix[5] += rotn_center[1] | |||||
| return img.transform(img.size, Image.AFFINE, matrix, **kwargs_new) | |||||
| return img.rotate(degrees, resample=kwargs['resample']) | |||||
| def auto_contrast(img, **__): | |||||
| return ImageOps.autocontrast(img) | |||||
| def invert(img, **__): | |||||
| return ImageOps.invert(img) | |||||
| def equalize(img, **__): | |||||
| return ImageOps.equalize(img) | |||||
| def solarize(img, thresh, **__): | |||||
| return ImageOps.solarize(img, thresh) | |||||
| def solarize_add(img, add, thresh=128, **__): | |||||
| lut = [] | |||||
| for i in range(256): | |||||
| if i < thresh: | |||||
| lut.append(min(255, i + add)) | |||||
| else: | |||||
| lut.append(i) | |||||
| if img.mode in ("L", "RGB"): | |||||
| if img.mode == "RGB" and len(lut) == 256: | |||||
| lut = lut + lut + lut | |||||
| return img.point(lut) | |||||
| return img | |||||
| def posterize(img, bits_to_keep, **__): | |||||
| if bits_to_keep >= 8: | |||||
| return img | |||||
| return ImageOps.posterize(img, bits_to_keep) | |||||
| def contrast(img, factor, **__): | |||||
| return ImageEnhance.Contrast(img).enhance(factor) | |||||
| def color(img, factor, **__): | |||||
| return ImageEnhance.Color(img).enhance(factor) | |||||
| def brightness(img, factor, **__): | |||||
| return ImageEnhance.Brightness(img).enhance(factor) | |||||
| def sharpness(img, factor, **__): | |||||
| return ImageEnhance.Sharpness(img).enhance(factor) | |||||
| def _rotate_level_to_arg(level, _hparams): | |||||
| # range [-30, 30] | |||||
| level = (level / _MAX_LEVEL) * 30. | |||||
| level = _randomly_negate(level) | |||||
| return (level,) | |||||
| def _enhance_level_to_arg(level, _hparams): | |||||
| # range [0.1, 1.9] | |||||
| return ((level / _MAX_LEVEL) * 1.8 + 0.1,) | |||||
| def _shear_level_to_arg(level, _hparams): | |||||
| # range [-0.3, 0.3] | |||||
| level = (level / _MAX_LEVEL) * 0.3 | |||||
| level = _randomly_negate(level) | |||||
| return (level,) | |||||
| def _translate_abs_level_to_arg(level, hparams): | |||||
| translate_const = hparams['translate_const'] | |||||
| level = (level / _MAX_LEVEL) * float(translate_const) | |||||
| level = _randomly_negate(level) | |||||
| return (level,) | |||||
| def _translate_rel_level_to_arg(level, _hparams): | |||||
| # range [-0.45, 0.45] | |||||
| level = (level / _MAX_LEVEL) * 0.45 | |||||
| level = _randomly_negate(level) | |||||
| return (level,) | |||||
| def _posterize_original_level_to_arg(level, _hparams): | |||||
| # As per original AutoAugment paper description | |||||
| # range [4, 8], 'keep 4 up to 8 MSB of image' | |||||
| return (int((level / _MAX_LEVEL) * 4) + 4,) | |||||
| def _posterize_research_level_to_arg(level, _hparams): | |||||
| # As per Tensorflow models research and UDA impl | |||||
| # range [4, 0], 'keep 4 down to 0 MSB of original image' | |||||
| return (4 - int((level / _MAX_LEVEL) * 4),) | |||||
| def _posterize_tpu_level_to_arg(level, _hparams): | |||||
| # As per Tensorflow TPU EfficientNet impl | |||||
| # range [0, 4], 'keep 0 up to 4 MSB of original image' | |||||
| return (int((level / _MAX_LEVEL) * 4),) | |||||
| def _solarize_level_to_arg(level, _hparams): | |||||
| # range [0, 256] | |||||
| return (int((level / _MAX_LEVEL) * 256),) | |||||
| def _solarize_add_level_to_arg(level, _hparams): | |||||
| # range [0, 110] | |||||
| return (int((level / _MAX_LEVEL) * 110),) | |||||
| def _distort_level_to_arg(level, _hparams): | |||||
| return (int((level / _MAX_LEVEL) * 10 + 10),) | |||||
| def _zoom_level_to_arg(level, _hparams): | |||||
| return ((level / _MAX_LEVEL) * 0.4,) | |||||
| def _blur_level_to_arg(level, _hparams): | |||||
| level = (level / _MAX_LEVEL) * 0.5 | |||||
| level = _randomly_negate(level) | |||||
| return (level,) | |||||
| def _skew_level_to_arg(level, _hparams): | |||||
| level = (level / _MAX_LEVEL) * 0.3 | |||||
| level = _randomly_negate(level) | |||||
| return (level,) | |||||
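| # Worked example (illustrative): at magnitude level 9, _rotate_level_to_arg yields | |||||
| # +/-(9 / 10) * 30 = +/-27 degrees and _enhance_level_to_arg yields (9 / 10) * 1.8 + 0.1 = 1.72, | |||||
| # i.e. higher levels map to stronger (and, where applicable, randomly negated) transforms. | |||||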
| def distort(img, v, **__): | |||||
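| # Grid distortion: the image is split into roughly int(0.1 * v) x int(0.1 * v) tiles, the | |||||
| # shared interior corner of every 2x2 group of tiles is shifted by (v, v) pixels, and the | |||||
| # result is rendered with a MESH transform. | |||||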
| w, h = img.size | |||||
| horizontal_tiles = int(0.1 * v) | |||||
| vertical_tiles = int(0.1 * v) | |||||
| width_of_square = int(math.floor(w / float(horizontal_tiles))) | |||||
| height_of_square = int(math.floor(h / float(vertical_tiles))) | |||||
| width_of_last_square = w - (width_of_square * (horizontal_tiles - 1)) | |||||
| height_of_last_square = h - (height_of_square * (vertical_tiles - 1)) | |||||
| dimensions = [] | |||||
| for vertical_tile in range(vertical_tiles): | |||||
| for horizontal_tile in range(horizontal_tiles): | |||||
| if vertical_tile == (vertical_tiles - 1) and horizontal_tile == (horizontal_tiles - 1): | |||||
| dimensions.append([horizontal_tile * width_of_square, | |||||
| vertical_tile * height_of_square, | |||||
| width_of_last_square + (horizontal_tile * width_of_square), | |||||
| height_of_last_square + (height_of_square * vertical_tile)]) | |||||
| elif vertical_tile == (vertical_tiles - 1): | |||||
| dimensions.append([horizontal_tile * width_of_square, | |||||
| vertical_tile * height_of_square, | |||||
| width_of_square + (horizontal_tile * width_of_square), | |||||
| height_of_last_square + (height_of_square * vertical_tile)]) | |||||
| elif horizontal_tile == (horizontal_tiles - 1): | |||||
| dimensions.append([horizontal_tile * width_of_square, | |||||
| vertical_tile * height_of_square, | |||||
| width_of_last_square + (horizontal_tile * width_of_square), | |||||
| height_of_square + (height_of_square * vertical_tile)]) | |||||
| else: | |||||
| dimensions.append([horizontal_tile * width_of_square, | |||||
| vertical_tile * height_of_square, | |||||
| width_of_square + (horizontal_tile * width_of_square), | |||||
| height_of_square + (height_of_square * vertical_tile)]) | |||||
| last_column = [] | |||||
| for i in range(vertical_tiles): | |||||
| last_column.append((horizontal_tiles - 1) + horizontal_tiles * i) | |||||
| last_row = range((horizontal_tiles * vertical_tiles) - horizontal_tiles, horizontal_tiles * vertical_tiles) | |||||
| polygons = [] | |||||
| for x1, y1, x2, y2 in dimensions: | |||||
| polygons.append([x1, y1, x1, y2, x2, y2, x2, y1]) | |||||
| polygon_indices = [] | |||||
| for i in range((vertical_tiles * horizontal_tiles) - 1): | |||||
| if i not in last_row and i not in last_column: | |||||
| polygon_indices.append([i, i + 1, i + horizontal_tiles, i + 1 + horizontal_tiles]) | |||||
| for a, b, c, d in polygon_indices: | |||||
| dx = v | |||||
| dy = v | |||||
| x1, y1, x2, y2, x3, y3, x4, y4 = polygons[a] | |||||
| polygons[a] = [x1, y1, | |||||
| x2, y2, | |||||
| x3 + dx, y3 + dy, | |||||
| x4, y4] | |||||
| x1, y1, x2, y2, x3, y3, x4, y4 = polygons[b] | |||||
| polygons[b] = [x1, y1, | |||||
| x2 + dx, y2 + dy, | |||||
| x3, y3, | |||||
| x4, y4] | |||||
| x1, y1, x2, y2, x3, y3, x4, y4 = polygons[c] | |||||
| polygons[c] = [x1, y1, | |||||
| x2, y2, | |||||
| x3, y3, | |||||
| x4 + dx, y4 + dy] | |||||
| x1, y1, x2, y2, x3, y3, x4, y4 = polygons[d] | |||||
| polygons[d] = [x1 + dx, y1 + dy, | |||||
| x2, y2, | |||||
| x3, y3, | |||||
| x4, y4] | |||||
| generated_mesh = [] | |||||
| for dim, polygon in zip(dimensions, polygons): | |||||
| generated_mesh.append([dim, polygon]) | |||||
| return img.transform(img.size, PIL.Image.MESH, generated_mesh, resample=PIL.Image.BICUBIC) | |||||
| def zoom(img, v, **__): | |||||
| #assert 0.1 <= v <= 2 | |||||
| w, h = img.size | |||||
| image_zoomed = img.resize((int(round(img.size[0] * v)), | |||||
| int(round(img.size[1] * v))), | |||||
| resample=PIL.Image.BICUBIC) | |||||
| w_zoomed, h_zoomed = image_zoomed.size | |||||
| return image_zoomed.crop((math.floor((float(w_zoomed) / 2) - (float(w) / 2)), | |||||
| math.floor((float(h_zoomed) / 2) - (float(h) / 2)), | |||||
| math.floor((float(w_zoomed) / 2) + (float(w) / 2)), | |||||
| math.floor((float(h_zoomed) / 2) + (float(h) / 2)))) | |||||
| def erase(img, v, **__): | |||||
| #assert 0.1<= v <= 1 | |||||
| w, h = img.size | |||||
| w_occlusion = int(w * v) | |||||
| h_occlusion = int(h * v) | |||||
| if len(img.getbands()) == 1: | |||||
| rectangle = PIL.Image.fromarray(np.uint8(np.random.rand(w_occlusion, h_occlusion) * 255)) | |||||
| else: | |||||
| rectangle = PIL.Image.fromarray(np.uint8(np.random.rand(w_occlusion, h_occlusion, len(img.getbands())) * 255)) | |||||
| random_position_x = random.randint(0, w - w_occlusion) | |||||
| random_position_y = random.randint(0, h - h_occlusion) | |||||
| img.paste(rectangle, (random_position_x, random_position_y)) | |||||
| return img | |||||
| def skew(img, v, **__): | |||||
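| # Perspective skew: maps the image corners to a quadrilateral offset by ceil(max(w, h) * v) | |||||
| # pixels, solves the eight perspective coefficients via a least-squares pseudo-inverse, and | |||||
| # applies them with a PERSPECTIVE transform. | |||||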
| #assert -1 <= v <= 1 | |||||
| w, h = img.size | |||||
| x1 = 0 | |||||
| x2 = h | |||||
| y1 = 0 | |||||
| y2 = w | |||||
| original_plane = [(y1, x1), (y2, x1), (y2, x2), (y1, x2)] | |||||
| max_skew_amount = max(w, h) | |||||
| max_skew_amount = int(math.ceil(max_skew_amount * v)) | |||||
| skew_amount = max_skew_amount | |||||
| new_plane = [(y1 - skew_amount, x1), # Top Left | |||||
| (y2, x1 - skew_amount), # Top Right | |||||
| (y2 + skew_amount, x2), # Bottom Right | |||||
| (y1, x2 + skew_amount)] | |||||
| matrix = [] | |||||
| for p1, p2 in zip(new_plane, original_plane): | |||||
| matrix.append([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]]) | |||||
| matrix.append([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]]) | |||||
| A = np.matrix(matrix, dtype=float)  # 'np.float' was removed in newer NumPy releases; use the builtin float | |||||
| B = np.array(original_plane).reshape(8) | |||||
| perspective_skew_coefficients_matrix = np.dot(np.linalg.pinv(A), B) | |||||
| perspective_skew_coefficients_matrix = np.array(perspective_skew_coefficients_matrix).reshape(8) | |||||
| return img.transform(img.size, PIL.Image.PERSPECTIVE, perspective_skew_coefficients_matrix, | |||||
| resample=PIL.Image.BICUBIC) | |||||
| def blur(img, v, **__): | |||||
| #assert -3 <= v <= 3 | |||||
| return img.filter(PIL.ImageFilter.GaussianBlur(v)) | |||||
| def rand_augment_ops(magnitude=10, hparams=None, transforms=None): | |||||
| hparams = hparams or _HPARAMS_DEFAULT | |||||
| transforms = transforms or _RAND_TRANSFORMS | |||||
| return [AutoAugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms] | |||||
| def _select_rand_weights(weight_idx=0, transforms=None): | |||||
| transforms = transforms or _RAND_TRANSFORMS | |||||
| assert weight_idx == 0 # only one set of weights currently | |||||
| rand_weights = _RAND_CHOICE_WEIGHTS_0 | |||||
| probs = [rand_weights[k] for k in transforms] | |||||
| probs /= np.sum(probs) | |||||
| return probs | |||||
| def rand_augment_transform(config_str, hparams): | |||||
| magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10) | |||||
| num_layers = 2 # default to 2 ops per image | |||||
| weight_idx = None # default to no probability weights for op choice | |||||
| config = config_str.split('-') | |||||
| assert config[0] == 'rand' | |||||
| config = config[1:] | |||||
| for c in config: | |||||
| cs = re.split(r'(\d.*)', c) | |||||
| if len(cs) < 2: | |||||
| continue | |||||
| key, val = cs[:2] | |||||
| if key == 'mstd': | |||||
| # noise param injected via hparams for now | |||||
| hparams.setdefault('magnitude_std', float(val)) | |||||
| elif key == 'm': | |||||
| magnitude = int(val) | |||||
| elif key == 'n': | |||||
| num_layers = int(val) | |||||
| elif key == 'w': | |||||
| weight_idx = int(val) | |||||
| else: | |||||
| assert False, 'Unknown RandAugment config section' | |||||
| ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams) | |||||
| choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx) | |||||
| final_result = RandAugment(ra_ops, num_layers, choice_weights=choice_weights) | |||||
| return final_result | |||||
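| # Example config strings (illustrative): 'rand-m9-n2' -> magnitude 9, two ops per image; | |||||
| # 'rand-m9-n2-mstd0.5-w0' additionally draws each op's magnitude from N(9, 0.5) and samples | |||||
| # ops with the _RAND_CHOICE_WEIGHTS_0 probabilities. | |||||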
| LEVEL_TO_ARG = { | |||||
| 'Distort': _distort_level_to_arg, | |||||
| 'Zoom': _zoom_level_to_arg, | |||||
| 'Blur': _blur_level_to_arg, | |||||
| 'Skew': _skew_level_to_arg, | |||||
| 'AutoContrast': None, | |||||
| 'Equalize': None, | |||||
| 'Invert': None, | |||||
| 'Rotate': _rotate_level_to_arg, | |||||
| 'PosterizeOriginal': _posterize_original_level_to_arg, | |||||
| 'PosterizeResearch': _posterize_research_level_to_arg, | |||||
| 'PosterizeTpu': _posterize_tpu_level_to_arg, | |||||
| 'Solarize': _solarize_level_to_arg, | |||||
| 'SolarizeAdd': _solarize_add_level_to_arg, | |||||
| 'Color': _enhance_level_to_arg, | |||||
| 'Contrast': _enhance_level_to_arg, | |||||
| 'Brightness': _enhance_level_to_arg, | |||||
| 'Sharpness': _enhance_level_to_arg, | |||||
| 'ShearX': _shear_level_to_arg, | |||||
| 'ShearY': _shear_level_to_arg, | |||||
| 'TranslateX': _translate_abs_level_to_arg, | |||||
| 'TranslateY': _translate_abs_level_to_arg, | |||||
| 'TranslateXRel': _translate_rel_level_to_arg, | |||||
| 'TranslateYRel': _translate_rel_level_to_arg, | |||||
| } | |||||
| NAME_TO_OP = { | |||||
| 'Distort': distort, | |||||
| 'Zoom': zoom, | |||||
| 'Blur': blur, | |||||
| 'Skew': skew, | |||||
| 'AutoContrast': auto_contrast, | |||||
| 'Equalize': equalize, | |||||
| 'Invert': invert, | |||||
| 'Rotate': rotate, | |||||
| 'PosterizeOriginal': posterize, | |||||
| 'PosterizeResearch': posterize, | |||||
| 'PosterizeTpu': posterize, | |||||
| 'Solarize': solarize, | |||||
| 'SolarizeAdd': solarize_add, | |||||
| 'Color': color, | |||||
| 'Contrast': contrast, | |||||
| 'Brightness': brightness, | |||||
| 'Sharpness': sharpness, | |||||
| 'ShearX': shear_x, | |||||
| 'ShearY': shear_y, | |||||
| 'TranslateX': translate_x_abs, | |||||
| 'TranslateY': translate_y_abs, | |||||
| 'TranslateXRel': translate_x_rel, | |||||
| 'TranslateYRel': translate_y_rel, | |||||
| } | |||||
| class AutoAugmentOp: | |||||
| def __init__(self, name, prob=0.5, magnitude=10, hparams=None): | |||||
| hparams = hparams or _HPARAMS_DEFAULT | |||||
| self.aug_fn = NAME_TO_OP[name] | |||||
| self.level_fn = LEVEL_TO_ARG[name] | |||||
| self.prob = prob | |||||
| self.magnitude = magnitude | |||||
| self.hparams = hparams.copy() | |||||
| self.kwargs = dict( | |||||
| fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL, | |||||
| resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION, | |||||
| ) | |||||
| self.magnitude_std = self.hparams.get('magnitude_std', 0) | |||||
| def __call__(self, img): | |||||
| if random.random() > self.prob: | |||||
| return img | |||||
| magnitude = self.magnitude | |||||
| if self.magnitude_std and self.magnitude_std > 0: | |||||
| magnitude = random.gauss(magnitude, self.magnitude_std) | |||||
| level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple() | |||||
| return self.aug_fn(img, *level_args, **self.kwargs) | |||||
| class RandAugment: | |||||
| def __init__(self, ops, num_layers=2, choice_weights=None): | |||||
| self.ops = ops | |||||
| self.num_layers = num_layers | |||||
| self.choice_weights = choice_weights | |||||
| def __call__(self, img): | |||||
| ops = np.random.choice( | |||||
| self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights) | |||||
| for op in ops: | |||||
| img = op(img) | |||||
| return img | |||||
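| # Usage sketch (illustrative only; the file name below is hypothetical): | |||||
| #   ra = rand_augment_transform('rand-m9-n2-w0', dict(_HPARAMS_DEFAULT)) | |||||
| #   augmented = ra(Image.open('some_image.jpg').convert('RGB'))  # PIL image in, PIL image out | |||||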
| @@ -0,0 +1,191 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """train imagenet.""" | |||||
| import argparse | |||||
| import math | |||||
| import os | |||||
| import random | |||||
| import time | |||||
| import numpy as np | |||||
| import mindspore | |||||
| from mindspore import Tensor, context | |||||
| from mindspore.communication.management import get_group_size, get_rank, init | |||||
| from mindspore.nn import SGD, RMSProp | |||||
| from mindspore.train.callback import (CheckpointConfig, LossMonitor, | |||||
| ModelCheckpoint, TimeMonitor) | |||||
| from mindspore.train.loss_scale_manager import FixedLossScaleManager | |||||
| from mindspore.train.model import Model, ParallelMode | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| from src.config import efficientnet_b0_config_gpu as cfg | |||||
| from src.dataset import create_dataset | |||||
| from src.efficientnet import efficientnet_b0 | |||||
| from src.loss import LabelSmoothingCrossEntropy | |||||
| mindspore.common.set_seed(cfg.random_seed) | |||||
| random.seed(cfg.random_seed) | |||||
| np.random.seed(cfg.random_seed) | |||||
| def get_lr(base_lr, total_epochs, steps_per_epoch, decay_steps=1, | |||||
| decay_rate=0.9, warmup_steps=0., warmup_lr_init=0., global_epoch=0): | |||||
| lr_each_step = [] | |||||
| total_steps = steps_per_epoch * total_epochs | |||||
| global_steps = steps_per_epoch * global_epoch | |||||
| self_warmup_delta = ((base_lr - warmup_lr_init) / | |||||
| warmup_steps) if warmup_steps > 0 else 0 | |||||
| self_decay_rate = decay_rate if decay_rate < 1 else 1 / decay_rate | |||||
| for i in range(total_steps): | |||||
| steps = math.floor(i / steps_per_epoch) | |||||
| cond = 1 if (steps < warmup_steps) else 0 | |||||
| warmup_lr = warmup_lr_init + steps * self_warmup_delta | |||||
| decay_nums = math.floor(steps / decay_steps) | |||||
| decayed_rate = math.pow(self_decay_rate, decay_nums)  # avoid rebinding the decay_rate parameter | |||||
| decay_lr = base_lr * decayed_rate | |||||
| lr = cond * warmup_lr + (1 - cond) * decay_lr | |||||
| lr_each_step.append(lr) | |||||
| lr_each_step = lr_each_step[global_steps:] | |||||
| lr_each_step = np.array(lr_each_step).astype(np.float32) | |||||
| return lr_each_step | |||||
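| # Worked example (illustrative values): with base_lr=0.1, warmup_steps=5, warmup_lr_init=0.0, | |||||
| # decay_steps=2 and decay_rate=0.9, epoch 3 (still in warmup) gets lr = 3 * 0.1 / 5 = 0.06, while | |||||
| # epoch 10 gets lr = 0.1 * 0.9**floor(10 / 2) ~= 0.059; values are emitted per step, so each | |||||
| # epoch's lr is repeated steps_per_epoch times. | |||||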
| def get_outdir(path, *paths, inc=False): | |||||
| outdir = os.path.join(path, *paths) | |||||
| if not os.path.exists(outdir): | |||||
| os.makedirs(outdir) | |||||
| elif inc: | |||||
| count = 1 | |||||
| outdir_inc = outdir + '-' + str(count) | |||||
| while os.path.exists(outdir_inc): | |||||
| count = count + 1 | |||||
| outdir_inc = outdir + '-' + str(count) | |||||
| assert count < 100 | |||||
| outdir = outdir_inc | |||||
| os.makedirs(outdir) | |||||
| return outdir | |||||
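| # With inc=True an existing directory, say './output/exp' (hypothetical), resolves to | |||||
| # './output/exp-1', './output/exp-2', ... (an assertion limits this to fewer than 100 collisions); | |||||
| # with the default inc=False an existing directory is simply reused. | |||||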
| parser = argparse.ArgumentParser( | |||||
| description='Training configuration', add_help=False) | |||||
| parser.add_argument('--data_path', type=str, default='/home/dataset/imagenet_jpeg/', metavar='DIR', | |||||
| help='path to dataset') | |||||
| parser.add_argument('--distributed', action='store_true', default=False) | |||||
| parser.add_argument('--GPU', action='store_true', default=False, | |||||
| help='Use GPU for training (default: False)') | |||||
| parser.add_argument('--cur_time', type=str, | |||||
| default='19701010-000000', help='current time') | |||||
| parser.add_argument('--resume', default='', type=str, metavar='PATH', | |||||
| help='Resume full model and optimizer state from checkpoint (default: none)') | |||||
| def main(): | |||||
| args, _ = parser.parse_known_args() | |||||
| devid, rank_id, rank_size = 0, 0, 1 | |||||
| context.set_context(mode=context.GRAPH_MODE) | |||||
| if args.distributed: | |||||
| if args.GPU: | |||||
| init("nccl") | |||||
| context.set_context(device_target='GPU') | |||||
| else: | |||||
| init() | |||||
| devid = int(os.getenv('DEVICE_ID')) | |||||
| context.set_context( | |||||
| device_target='Ascend', device_id=devid, reserve_class_name_in_scope=False) | |||||
| context.reset_auto_parallel_context() | |||||
| rank_id = get_rank() | |||||
| rank_size = get_group_size() | |||||
| context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, | |||||
| gradients_mean=True, device_num=rank_size) | |||||
| else: | |||||
| if args.GPU: | |||||
| context.set_context(device_target='GPU') | |||||
| is_master = not args.distributed or (rank_id == 0) | |||||
| net = efficientnet_b0(num_classes=cfg.num_classes, | |||||
| drop_rate=cfg.drop, | |||||
| drop_connect_rate=cfg.drop_connect, | |||||
| global_pool=cfg.gp, | |||||
| bn_tf=cfg.bn_tf, | |||||
| ) | |||||
| cur_time = args.cur_time | |||||
| output_base = './output' | |||||
| exp_name = '-'.join([ | |||||
| cur_time, | |||||
| cfg.model, | |||||
| str(224) | |||||
| ]) | |||||
| time.sleep(rank_id)  # stagger ranks briefly, presumably to avoid racing on output directory creation | |||||
| output_dir = get_outdir(output_base, exp_name) | |||||
| train_data_url = os.path.join(args.data_path, 'train') | |||||
| train_dataset = create_dataset( | |||||
| cfg.batch_size, train_data_url, workers=cfg.workers, distributed=args.distributed) | |||||
| batches_per_epoch = train_dataset.get_dataset_size() | |||||
| loss_cb = LossMonitor(per_print_times=batches_per_epoch) | |||||
| loss = LabelSmoothingCrossEntropy(smooth_factor=cfg.smoothing) | |||||
| time_cb = TimeMonitor(data_size=batches_per_epoch) | |||||
| loss_scale_manager = FixedLossScaleManager( | |||||
| cfg.loss_scale, drop_overflow_update=False) | |||||
| config_ck = CheckpointConfig( | |||||
| save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=cfg.keep_checkpoint_max) | |||||
| ckpoint_cb = ModelCheckpoint( | |||||
| prefix=cfg.model, directory=output_dir, config=config_ck) | |||||
| lr = Tensor(get_lr(base_lr=cfg.lr, total_epochs=cfg.epochs, steps_per_epoch=batches_per_epoch, | |||||
| decay_steps=cfg.decay_epochs, decay_rate=cfg.decay_rate, | |||||
| warmup_steps=cfg.warmup_epochs, warmup_lr_init=cfg.warmup_lr_init, | |||||
| global_epoch=cfg.resume_start_epoch)) | |||||
| if cfg.opt == 'sgd': | |||||
| optimizer = SGD(net.trainable_params(), learning_rate=lr, momentum=cfg.momentum, | |||||
| weight_decay=cfg.weight_decay, | |||||
| loss_scale=cfg.loss_scale | |||||
| ) | |||||
| elif cfg.opt == 'rmsprop': | |||||
| optimizer = RMSProp(net.trainable_params(), learning_rate=lr, decay=0.9, weight_decay=cfg.weight_decay, | |||||
| momentum=cfg.momentum, epsilon=cfg.opt_eps, loss_scale=cfg.loss_scale | |||||
| ) | |||||
| else: | |||||
| raise ValueError('Unsupported optimizer: {}'.format(cfg.opt)) | |||||
| loss.add_flags_recursive(fp32=True, fp16=False) | |||||
| if args.resume: | |||||
| ckpt = load_checkpoint(args.resume) | |||||
| load_param_into_net(net, ckpt) | |||||
| model = Model(net, loss, optimizer, | |||||
| loss_scale_manager=loss_scale_manager, | |||||
| amp_level=cfg.amp_level | |||||
| ) | |||||
| callbacks = [loss_cb, ckpoint_cb, time_cb] if is_master else [] | |||||
| if args.resume: | |||||
| real_epoch = cfg.epochs - cfg.resume_start_epoch | |||||
| model.train(real_epoch, train_dataset, | |||||
| callbacks=callbacks, dataset_sink_mode=True) | |||||
| else: | |||||
| model.train(cfg.epochs, train_dataset, | |||||
| callbacks=callbacks, dataset_sink_mode=True) | |||||
| if __name__ == '__main__': | |||||
| main() | |||||