diff --git a/model_zoo/official/cv/efficientnet/README.md b/model_zoo/official/cv/efficientnet/README.md new file mode 100644 index 0000000000..24ffa9b6c8 --- /dev/null +++ b/model_zoo/official/cv/efficientnet/README.md @@ -0,0 +1,111 @@ +# EfficientNet-B0 Example + +## Description + +This is an example of training EfficientNet-B0 in MindSpore. + +## Requirements + +- Install [MindSpore](http://www.mindspore.cn/install/en). +- Download the dataset. + +## Structure + +```shell +. +└─efficientnet + ├─README.md + ├─scripts + ├─run_standalone_train_for_gpu.sh # launch standalone training with gpu platform(1p) + ├─run_distribute_train_for_gpu.sh # launch distributed training with gpu platform(8p) + └─run_eval_for_gpu.sh # launch evaluating with gpu platform + ├─src + ├─config.py # parameter configuration + ├─dataset.py # data preprocessing + ├─efficientnet.py # network definition + ├─loss.py # Customized loss function + ├─transform_utils.py # random augment utils + ├─transform.py # random augment class + ├─eval.py # eval net + └─train.py # train net + +``` + +## Parameter Configuration + +Parameters for both training and evaluation can be set in config.py. + +``` +'random_seed': 1, # fix random seed +'model': 'efficientnet_b0', # model name +'drop': 0.2, # dropout rate +'drop_connect': 0.2, # drop connect rate +'opt_eps': 0.001, # optimizer epsilon +'lr': 0.064, # learning rate LR +'batch_size': 128, # batch size +'decay_epochs': 2.4, # epoch interval to decay LR +'warmup_epochs': 5, # epochs to warmup LR +'decay_rate': 0.97, # LR decay rate +'weight_decay': 1e-5, # weight decay +'epochs': 600, # number of epochs to train +'workers': 8, # number of data processing processes +'amp_level': 'O0', # amp level +'opt': 'rmsprop', # optimizer +'num_classes': 1000, # number of classes +'gp': 'avg', # type of global pool, "avg", "max", "avgmax", "avgmaxc" +'momentum': 0.9, # optimizer momentum +'warmup_lr_init': 0.0001, # init warmup LR +'smoothing': 0.1, # label smoothing factor 
+'bn_tf': False, # use TensorFlow BatchNorm defaults +'keep_checkpoint_max': 10, # max number of checkpoints to keep +'loss_scale': 1024, # loss scale +'resume_start_epoch': 0, # resume start epoch +``` + +## Running the example + +### Train + +#### Usage + +``` +# distributed training example(8p) +sh run_distribute_train_for_gpu.sh DATA_DIR +# standalone training +sh run_standalone_train_for_gpu.sh DATA_DIR DEVICE_ID +``` + +#### Launch + +```bash +# distributed training example(8p) for GPU +sh scripts/run_distribute_train_for_gpu.sh /dataset +# standalone training example for GPU +sh scripts/run_standalone_train_for_gpu.sh /dataset 0 +``` + +#### Result + +You can find the checkpoint files together with the training results in the log. + +### Evaluation + +#### Usage + +``` +# Evaluation +sh run_eval_for_gpu.sh DATA_DIR DEVICE_ID PATH_CHECKPOINT +``` + +#### Launch + +```bash +# Evaluation with checkpoint +sh scripts/run_eval_for_gpu.sh /dataset 0 ./checkpoint/efficientnet_b0-600_1251.ckpt +``` + +> Checkpoints are produced during the training process. + +#### Result + +The evaluation result is written to `eval.log` in the directory from which the script is launched; the metrics are printed in that log. diff --git a/model_zoo/official/cv/efficientnet/eval.py b/model_zoo/official/cv/efficientnet/eval.py new file mode 100644 index 0000000000..098db060ba --- /dev/null +++ b/model_zoo/official/cv/efficientnet/eval.py @@ -0,0 +1,61 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Evaluate an EfficientNet-B0 checkpoint on the ImageNet validation split.""" +import argparse +import os + +import mindspore.nn as nn +from mindspore import context +from mindspore.train.model import Model +from mindspore.train.serialization import load_checkpoint, load_param_into_net + +from src.config import efficientnet_b0_config_gpu as cfg +from src.dataset import create_dataset_val +from src.efficientnet import efficientnet_b0 +from src.loss import LabelSmoothingCrossEntropy + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='image classification evaluation') + parser.add_argument('--checkpoint', type=str, default='', help='checkpoint of efficientnet (Default: None)') + parser.add_argument('--data_path', type=str, default='', help='Dataset path') + parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU'), help='run platform') + args_opt = parser.parse_args() + + if args_opt.platform == 'Ascend': + device_id = int(os.getenv('DEVICE_ID', '0')) # fall back to device 0: int(None) raised TypeError when DEVICE_ID was unset + context.set_context(device_id=device_id) + + context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform) + + net = efficientnet_b0(num_classes=cfg.num_classes, + drop_rate=cfg.drop, + drop_connect_rate=cfg.drop_connect, + global_pool=cfg.gp, + bn_tf=cfg.bn_tf, + ) + + ckpt = load_checkpoint(args_opt.checkpoint) # parameters read from the --checkpoint file + load_param_into_net(net, ckpt) + net.set_train(False) # evaluation only, no training behaviour + val_data_url = os.path.join(args_opt.data_path, 'val') # validation images are expected under <data_path>/val + dataset = create_dataset_val(cfg.batch_size, val_data_url, workers=cfg.workers, distributed=False) + loss = LabelSmoothingCrossEntropy(smooth_factor=cfg.smoothing) + eval_metrics = {'Loss': nn.Loss(), + 'Top1-Acc': nn.Top1CategoricalAccuracy(), + 'Top5-Acc': nn.Top5CategoricalAccuracy()} + model = Model(net, loss, optimizer=None, metrics=eval_metrics) + + metrics = model.eval(dataset) + print("metric: ", metrics) diff --git a/model_zoo/official/cv/efficientnet/scripts/run_distribute_train_for_gpu.sh 
b/model_zoo/official/cv/efficientnet/scripts/run_distribute_train_for_gpu.sh new file mode 100644 index 0000000000..c9165841a8 --- /dev/null +++ b/model_zoo/official/cv/efficientnet/scripts/run_distribute_train_for_gpu.sh @@ -0,0 +1,32 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +DATA_DIR=$1 + +current_exec_path=$(pwd) +echo ${current_exec_path} + +curtime=`date '+%Y%m%d-%H%M%S'` +RANK_SIZE=8 + +rm ${current_exec_path}/device_parallel/ -rf +mkdir ${current_exec_path}/device_parallel +echo ${curtime} > ${current_exec_path}/device_parallel/starttime + +mpirun --allow-run-as-root -n $RANK_SIZE python ${current_exec_path}/train.py \ + --GPU \ + --distributed \ + --data_path ${DATA_DIR} \ + --cur_time ${curtime} > ${current_exec_path}/device_parallel/efficientnet_b0.log 2>&1 & diff --git a/model_zoo/official/cv/efficientnet/scripts/run_eval_for_gpu.sh b/model_zoo/official/cv/efficientnet/scripts/run_eval_for_gpu.sh new file mode 100644 index 0000000000..32ef1273bf --- /dev/null +++ b/model_zoo/official/cv/efficientnet/scripts/run_eval_for_gpu.sh @@ -0,0 +1,27 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +DATA_DIR=$1 +DEVICE_ID=$2 +PATH_CHECKPOINT=$3 + +current_exec_path=$(pwd) +echo ${current_exec_path} + +curtime=`date '+%Y%m%d-%H%M%S'` + +echo ${curtime} > ${current_exec_path}/eval_starttime + +CUDA_VISIBLE_DEVICES=${DEVICE_ID} python ./eval.py --platform 'GPU' --data_path ${DATA_DIR} --checkpoint ${PATH_CHECKPOINT} > ${current_exec_path}/eval.log 2>&1 & diff --git a/model_zoo/official/cv/efficientnet/scripts/run_standalone_train_for_gpu.sh b/model_zoo/official/cv/efficientnet/scripts/run_standalone_train_for_gpu.sh new file mode 100644 index 0000000000..ad3d6bdfa8 --- /dev/null +++ b/model_zoo/official/cv/efficientnet/scripts/run_standalone_train_for_gpu.sh @@ -0,0 +1,31 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +DATA_DIR=$1 +DEVICE_ID=$2 + +current_exec_path=$(pwd) +echo ${current_exec_path} + +curtime=`date '+%Y%m%d-%H%M%S'` + +rm ${current_exec_path}/device_${DEVICE_ID}/ -rf +mkdir ${current_exec_path}/device_${DEVICE_ID} +echo ${curtime} > ${current_exec_path}/device_${DEVICE_ID}/starttime + +CUDA_VISIBLE_DEVICES=${DEVICE_ID} python ${current_exec_path}/train.py \ + --GPU \ + --data_path ${DATA_DIR} \ + --cur_time ${curtime} > ${current_exec_path}/device_${DEVICE_ID}/efficientnet_b0.log 2>&1 & diff --git a/model_zoo/official/cv/efficientnet/src/config.py b/model_zoo/official/cv/efficientnet/src/config.py new file mode 100644 index 0000000000..09ea624716 --- /dev/null +++ b/model_zoo/official/cv/efficientnet/src/config.py @@ -0,0 +1,47 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +""" +network config setting +""" +from easydict import EasyDict as edict + +efficientnet_b0_config_gpu = edict({ + 'random_seed': 1, + 'model': 'efficientnet_b0', + 'drop': 0.2, + 'drop_connect': 0.2, + 'opt_eps': 0.001, + 'lr': 0.064, + 'batch_size': 128, + 'decay_epochs': 2.4, + 'warmup_epochs': 5, + 'decay_rate': 0.97, + 'weight_decay': 1e-5, + 'epochs': 600, + 'workers': 8, + 'amp_level': 'O0', + 'opt': 'rmsprop', + 'num_classes': 1000, + #'Type of global pool, "avg", "max", "avgmax", "avgmaxc" + 'gp': 'avg', + 'momentum': 0.9, + 'warmup_lr_init': 0.0001, + 'smoothing': 0.1, + #Use Tensorflow BatchNorm defaults for models that support it + 'bn_tf': False, + 'keep_checkpoint_max': 10, + 'loss_scale': 1024, + 'resume_start_epoch': 0, +}) diff --git a/model_zoo/official/cv/efficientnet/src/dataset.py b/model_zoo/official/cv/efficientnet/src/dataset.py new file mode 100644 index 0000000000..e30dd87bc4 --- /dev/null +++ b/model_zoo/official/cv/efficientnet/src/dataset.py @@ -0,0 +1,125 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +""" +Data operations, will be used in train.py and eval.py +""" +import math +import os + +import mindspore.common.dtype as mstype +import mindspore.dataset as ds +import mindspore.dataset.transforms.c_transforms as C2 +import mindspore.dataset.vision.c_transforms as C +from mindspore.communication.management import get_group_size, get_rank +from mindspore.dataset.vision import Inter + +from src.config import efficientnet_b0_config_gpu as cfg +from src.transform import RandAugment + +ds.config.set_seed(cfg.random_seed) + + +IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) +IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) + +img_size = (224, 224) +crop_pct = 0.875 +rescale = 1.0 / 255.0 +shift = 0.0 +inter_method = 'bilinear' +resize_value = 224 # img_size +scale = (0.08, 1.0) +ratio = (3./4., 4./3.) +inter_str = 'bicubic' + +def str2MsInter(method): + if method == 'bicubic': + return Inter.BICUBIC + if method == 'nearest': + return Inter.NEAREST + return Inter.BILINEAR + +def create_dataset(batch_size, train_data_url='', workers=8, distributed=False): + if not os.path.exists(train_data_url): + raise ValueError('Path not exists') + interpolation = str2MsInter(inter_str) + + c_decode_op = C.Decode() + type_cast_op = C2.TypeCast(mstype.int32) + random_resize_crop_op = C.RandomResizedCrop(size=(resize_value, resize_value), scale=scale, ratio=ratio, + interpolation=interpolation) + random_horizontal_flip_op = C.RandomHorizontalFlip(0.5) + + efficient_rand_augment = RandAugment() + + image_ops = [c_decode_op, random_resize_crop_op, random_horizontal_flip_op] + + rank_id = get_rank() if distributed else 0 + rank_size = get_group_size() if distributed else 1 + + dataset_train = ds.ImageFolderDataset(train_data_url, + num_parallel_workers=workers, + shuffle=True, + num_shards=rank_size, + shard_id=rank_id) + dataset_train = dataset_train.map(input_columns=["image"], + operations=image_ops, + 
num_parallel_workers=workers) + dataset_train = dataset_train.map(input_columns=["label"], + operations=type_cast_op, + num_parallel_workers=workers) + ds_train = dataset_train.batch(batch_size, + per_batch_map=efficient_rand_augment, + input_columns=["image", "label"], + num_parallel_workers=2, + drop_remainder=True) + ds_train = ds_train.repeat(1) + return ds_train + + +def create_dataset_val(batch_size=128, val_data_url='', workers=8, distributed=False): + if not os.path.exists(val_data_url): + raise ValueError('Path not exists') + rank_id = get_rank() if distributed else 0 + rank_size = get_group_size() if distributed else 1 + dataset = ds.ImageFolderDataset(val_data_url, num_parallel_workers=workers, + num_shards=rank_size, shard_id=rank_id, shuffle=False) + scale_size = None + interpolation = str2MsInter(inter_method) + + if isinstance(img_size, tuple): + assert len(img_size) == 2 + if img_size[-1] == img_size[-2]: + scale_size = int(math.floor(img_size[0] / crop_pct)) + else: + scale_size = tuple([int(x / crop_pct) for x in img_size]) + else: + scale_size = int(math.floor(img_size / crop_pct)) + + type_cast_op = C2.TypeCast(mstype.int32) + decode_op = C.Decode() + resize_op = C.Resize(size=scale_size, interpolation=interpolation) + center_crop = C.CenterCrop(size=224) + rescale_op = C.Rescale(rescale, shift) + normalize_op = C.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) + changeswap_op = C.HWC2CHW() + + ctrans = [decode_op, resize_op, center_crop, rescale_op, normalize_op, changeswap_op] + + dataset = dataset.map(input_columns=["label"], operations=type_cast_op, num_parallel_workers=workers) + dataset = dataset.map(input_columns=["image"], operations=ctrans, num_parallel_workers=workers) + dataset = dataset.batch(batch_size, drop_remainder=True, num_parallel_workers=workers) + dataset = dataset.repeat(1) + return dataset diff --git a/model_zoo/official/cv/efficientnet/src/efficientnet.py b/model_zoo/official/cv/efficientnet/src/efficientnet.py new 
file mode 100644 index 0000000000..f8985a513a --- /dev/null +++ b/model_zoo/official/cv/efficientnet/src/efficientnet.py @@ -0,0 +1,746 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""EfficientNet model definition""" +import logging +import math +import re +from copy import deepcopy + +import mindspore as ms +import mindspore.nn as nn +from mindspore import context, ms_function +from mindspore.common.initializer import (Normal, One, Uniform, Zero, + initializer) +from mindspore.common.parameter import Parameter +from mindspore.ops import operations as P +from mindspore.ops.composite import clip_by_value + +relu = P.ReLU() +sigmoid = P.Sigmoid() + +IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) +IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) +IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) +IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) +IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255) +IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3) + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv_stem', 'classifier': 'classifier', + **kwargs + } + + +default_cfgs = { + 'efficientnet_b0': _cfg( + 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0-d6904d92.pth'), + 'efficientnet_b1': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth', + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'efficientnet_b2': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2-cf78dc4d.pth', + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), + 'efficientnet_b3': _cfg( + url='', input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + 'efficientnet_b4': _cfg( + url='', input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), +} + +_DEBUG = False + +_BN_MOMENTUM_PT_DEFAULT = 0.1 +_BN_EPS_PT_DEFAULT = 1e-5 +_BN_ARGS_PT = dict(momentum=_BN_MOMENTUM_PT_DEFAULT, eps=_BN_EPS_PT_DEFAULT) +_BN_MOMENTUM_TF_DEFAULT = 1 - 0.99 +_BN_EPS_TF_DEFAULT = 1e-3 +_BN_ARGS_TF = dict(momentum=_BN_MOMENTUM_TF_DEFAULT, eps=_BN_EPS_TF_DEFAULT) + + +def _initialize_weight_goog(shape=None, layer_type='conv', bias=False): + if layer_type not in ('conv', 'bn', 'fc'): + raise ValueError('The layer type is not known, the supported are conv, bn and fc') + if bias: + return Zero() + if layer_type == 'conv': + assert isinstance(shape, (tuple, list)) and len( + shape) == 3, 'The shape must be 3 scalars, and are in_chs, ks, out_chs respectively' + n = shape[1] * shape[1] * shape[2] + return Normal(math.sqrt(2.0 / n)) + if layer_type == 'bn': + return One() + assert isinstance(shape, (tuple, list)) and len( + shape) == 2, 'The shape must be 2 scalars, and are in_chs, out_chs respectively' + n = shape[1] + init_range = 1.0 / math.sqrt(n) + return Uniform(init_range) + + +def _initialize_weight_default(shape=None, layer_type='conv', bias=False): + if layer_type not in ('conv', 'bn', 'fc'): + raise ValueError('The layer type is not known, the supported are conv, bn and fc') + if bias and layer_type == 
'bn': + return Zero() + if layer_type == 'conv': + return One() + if layer_type == 'bn': + return One() + return One() + + +def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='same', bias=False): + weight_init_value = _initialize_weight_goog(shape=(in_channels, kernel_size, out_channels)) + bias_init_value = _initialize_weight_goog(bias=True) if bias else None + if bias: + return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, + padding=padding, pad_mode=pad_mode, weight_init=weight_init_value, + has_bias=bias, bias_init=bias_init_value) + return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, + padding=padding, pad_mode=pad_mode, weight_init=weight_init_value, + has_bias=bias) + + +def _conv1x1(in_channels, out_channels, stride=1, padding=0, pad_mode='same', bias=False): + weight_init_value = _initialize_weight_goog(shape=(in_channels, 1, out_channels)) + bias_init_value = _initialize_weight_goog(bias=True) if bias else None + if bias: + return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, + padding=padding, pad_mode=pad_mode, weight_init=weight_init_value, + has_bias=bias, bias_init=bias_init_value) + return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, + padding=padding, pad_mode=pad_mode, weight_init=weight_init_value, + has_bias=bias) + + +def _conv_group(in_channels, out_channels, group, kernel_size=3, stride=1, padding=0, pad_mode='same', bias=False): + weight_init_value = _initialize_weight_goog(shape=(in_channels, kernel_size, out_channels)) + bias_init_value = _initialize_weight_goog(bias=True) if bias else None + if bias: + return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, + padding=padding, pad_mode=pad_mode, weight_init=weight_init_value, + group=group, has_bias=bias, bias_init=bias_init_value) + return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, + 
padding=padding, pad_mode=pad_mode, weight_init=weight_init_value, + group=group, has_bias=bias) + + +def _fused_bn(channels, momentum=0.1, eps=1e-4, gamma_init=1, beta_init=0): + return nn.BatchNorm2d(channels, eps=eps, momentum=1 - momentum, gamma_init=gamma_init, beta_init=beta_init) + + +def _dense(in_channels, output_channels, bias=True, activation=None): + weight_init_value = _initialize_weight_goog(shape=(in_channels, output_channels), layer_type='fc') + bias_init_value = _initialize_weight_goog(bias=True) if bias else None + if bias: + return nn.Dense(in_channels, output_channels, weight_init=weight_init_value, bias_init=bias_init_value, + has_bias=bias, activation=activation) + return nn.Dense(in_channels, output_channels, weight_init=weight_init_value, has_bias=bias, + activation=activation) + + +def _resolve_bn_args(kwargs): + bn_args = _BN_ARGS_TF.copy() if kwargs.pop('bn_tf', False) else _BN_ARGS_PT.copy() + bn_momentum = kwargs.pop('bn_momentum', None) + if bn_momentum is not None: + bn_args['momentum'] = bn_momentum + bn_eps = kwargs.pop('bn_eps', None) + if bn_eps is not None: + bn_args['eps'] = bn_eps + return bn_args + + +def _round_channels(channels, multiplier=1.0, divisor=8, channel_min=None): + """Round number of filters based on depth multiplier.""" + if not multiplier: + return channels + + channels *= multiplier + channel_min = channel_min or divisor + new_channels = max( + int(channels + divisor / 2) // divisor * divisor, + channel_min) + if new_channels < 0.9 * channels: + new_channels += divisor + return new_channels + + +def _parse_ksize(ss): + if ss.isdigit(): + return int(ss) + return [int(k) for k in ss.split('.')] + + +def _decode_block_str(block_str, depth_multiplier=1.0): + """ Decode block definition string + + Gets a list of block arg (dicts) through a string notation of arguments. + E.g. 
+ ir_r2_k3_s2_e1_i32_o16_se0.25_noskip + + All args can exist in any order with the exception of the leading string which + is assumed to indicate the block type. + + leading string - block type ( + ir = InvertedResidual, ds = DepthwiseSep, dsa = DepthwiseSep with pw act, cn = ConvBnAct) + r - number of repeat blocks, + k - kernel size, + s - strides (1-9), + e - expansion ratio, + c - output channels, + se - squeeze/excitation ratio + n - activation fn ('re', 'r6', 'hs', or 'sw'); reported as unsupported and ignored + Args: + block_str: a string representation of block arguments. + Returns: + A tuple (block_args, num_repeat): the block argument dict and its repeat count + Raises: + ValueError: if the string def not properly specified (TODO) + """ + assert isinstance(block_str, str) + ops = block_str.split('_') + block_type = ops[0] + ops = ops[1:] + options = {} + noskip = False + for op in ops: + if op == 'noskip': + noskip = True + elif op.startswith('n'): + # activation fn override: not supported here, it is only reported + key = op[0] + v = op[1:] + if v == 're': + print('not support') + elif v == 'r6': + print('not support') + elif v == 'hs': + print('not support') + elif v == 'sw': + print('not support') + else: + continue + # store nothing: the former `options[key] = value` read an unbound or stale `value`; act_fn stays None so the builder default is used + else: + # all numeric options + splits = re.split(r'(\d.*)', op) + if len(splits) >= 2: + key, value = splits[:2] + options[key] = value + + act_fn = options['n'] if 'n' in options else None + exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1 + pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1 + fake_in_chs = int(options['fc']) if 'fc' in options else 0 + + num_repeat = int(options['r']) + # each type of block has different valid arguments, fill accordingly + if block_type == 'ir': + block_args = dict( + block_type=block_type, + dw_kernel_size=_parse_ksize(options['k']), + exp_kernel_size=exp_kernel_size, + pw_kernel_size=pw_kernel_size, + out_chs=int(options['c']), + exp_ratio=float(options['e']), + se_ratio=float(options['se']) if 'se' in options else None, + stride=int(options['s']), + act_fn=act_fn, 
noskip=noskip, + ) + elif block_type in ('ds', 'dsa'): + block_args = dict( + block_type=block_type, + dw_kernel_size=_parse_ksize(options['k']), + pw_kernel_size=pw_kernel_size, + out_chs=int(options['c']), + se_ratio=float(options['se']) if 'se' in options else None, + stride=int(options['s']), + act_fn=act_fn, + pw_act=block_type == 'dsa', + noskip=block_type == 'dsa' or noskip, + ) + elif block_type == 'er': + block_args = dict( + block_type=block_type, + exp_kernel_size=_parse_ksize(options['k']), + pw_kernel_size=pw_kernel_size, + out_chs=int(options['c']), + exp_ratio=float(options['e']), + fake_in_chs=fake_in_chs, + se_ratio=float(options['se']) if 'se' in options else None, + stride=int(options['s']), + act_fn=act_fn, + noskip=noskip, + ) + elif block_type == 'cn': + block_args = dict( + block_type=block_type, + kernel_size=int(options['k']), + out_chs=int(options['c']), + stride=int(options['s']), + act_fn=act_fn, + ) + else: + assert False, 'Unknown block type (%s)' % block_type + + return block_args, num_repeat + + +def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'): + """ Per-stage depth scaling + Scales the block repeats in each stage. This depth scaling impl maintains + compatibility with the EfficientNet scaling method, while allowing sensible + scaling for other models that may have multiple block arg definitions in each stage. + """ + + # We scale the total repeat count for each stage, there may be multiple + # block arg defs per stage so we need to sum. + num_repeat = sum(repeats) + if depth_trunc == 'round': + # Truncating to int by rounding allows stages with few repeats to remain + # proportionally smaller for longer. This is a good choice when stage definitions + # include single repeat stages that we'd prefer to keep that way as long as possible + num_repeat_scaled = max(1, round(num_repeat * depth_multiplier)) + else: + # The default for EfficientNet truncates repeats to int via 'ceil'. 
+ # Any multiplier > 1.0 will result in an increased depth for every stage. + num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier)) + + # Proportionally distribute repeat count scaling to each block definition in the stage. + # Allocation is done in reverse as it results in the first block being less likely to be scaled. + # The first block makes less sense to repeat in most of the arch definitions. + repeats_scaled = [] + for r in repeats[::-1]: + rs = max(1, round((r / num_repeat * num_repeat_scaled))) + repeats_scaled.append(rs) + num_repeat -= r + num_repeat_scaled -= rs + repeats_scaled = repeats_scaled[::-1] + + # Apply the calculated scaling to each block arg in the stage + sa_scaled = [] + for ba, rep in zip(stack_args, repeats_scaled): + sa_scaled.extend([deepcopy(ba) for _ in range(rep)]) + return sa_scaled + + +def _decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil'): + arch_args = [] + for _, block_strings in enumerate(arch_def): + assert isinstance(block_strings, list) + stack_args = [] + repeats = [] + for block_str in block_strings: + assert isinstance(block_str, str) + ba, rep = _decode_block_str(block_str) + stack_args.append(ba) + repeats.append(rep) + arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc)) + return arch_args + + +@ms_function +def hard_swish(x): + x = P.Cast()(x, ms.float32) + y = x + 3.0 + y = clip_by_value(y, 0.0, 6.0) + y = y / 6.0 + return x * y + + +class BlockBuilder(nn.Cell): + def __init__(self, builder_in_channels, builder_block_args, channel_multiplier=1.0, channel_divisor=8, + channel_min=None, pad_type='', act_fn=None, se_gate_fn=sigmoid, se_reduce_mid=False, + bn_args=None, drop_connect_rate=0., verbose=False): + super(BlockBuilder, self).__init__() + + bn_args = _BN_ARGS_PT if bn_args is None else bn_args + self.channel_multiplier = channel_multiplier + self.channel_divisor = channel_divisor + self.channel_min = channel_min + self.pad_type = pad_type + 
self.act_fn = act_fn + self.se_gate_fn = se_gate_fn + self.se_reduce_mid = se_reduce_mid + self.bn_args = bn_args + self.drop_connect_rate = drop_connect_rate + self.verbose = verbose + + self.in_chs = None + self.block_idx = 0 + self.block_count = 0 + self.layer = self._make_layer(builder_in_channels, builder_block_args) + + def _round_channels(self, chs): + return _round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min) + + def _make_block(self, ba): + bt = ba.pop('block_type') # removed from ba so that ba can be passed as constructor kwargs below + ba['in_chs'] = self.in_chs + ba['out_chs'] = self._round_channels(ba['out_chs']) + if 'fake_in_chs' in ba and ba['fake_in_chs']: + ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs']) + ba['bn_args'] = self.bn_args + ba['pad_type'] = self.pad_type + ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn # per-block activation wins over the builder default + assert ba['act_fn'] is not None + if bt == 'ir': + ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count # rate grows linearly with the block index + ba['se_gate_fn'] = self.se_gate_fn + ba['se_reduce_mid'] = self.se_reduce_mid + if self.verbose: + logging.info(' InvertedResidual %d, Args: %s', self.block_idx, str(ba)) + block = InvertedResidual(**ba) + elif bt in ('ds', 'dsa'): + ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count # rate grows linearly with the block index + if self.verbose: + logging.info(' DepthwiseSeparable %d, Args: %s', self.block_idx, str(ba)) + block = DepthwiseSeparableConv(**ba) + else: + assert False, 'Unknown block type (%s) while building model.' 
% bt + self.in_chs = ba['out_chs'] + + return block + + def _make_stack(self, stack_args): + blocks = [] + # each stack (stage) contains a list of block arguments + for i, ba in enumerate(stack_args): + if self.verbose: + logging.info(' Block: %d', i) + if i >= 1: + # only the first block in any stack can have a stride > 1 + ba['stride'] = 1 + block = self._make_block(ba) + blocks.append(block) + self.block_idx += 1 + return nn.SequentialCell(blocks) + + def _make_layer(self, in_chs, block_args): + """ Build the blocks + Args: + in_chs: Number of input-channels passed to first block + block_args: A list of lists, outer list defines stages, inner + list contains strings defining block configuration(s) + Return: + List of block stacks (each stack wrapped in nn.Sequential) + """ + if self.verbose: + logging.info('Building model trunk with %d stages...', len(block_args)) + self.in_chs = in_chs + self.block_count = sum([len(x) for x in block_args]) + self.block_idx = 0 + blocks = [] + + for stack_idx, stack in enumerate(block_args): + if self.verbose: + logging.info('Stack: %d', stack_idx) + assert isinstance(stack, list) + stack = self._make_stack(stack) + blocks.append(stack) + return nn.SequentialCell(blocks) + + def construct(self, x): + return self.layer(x) + + +class DepthWiseConv(nn.Cell): + def __init__(self, in_planes, kernel_size, stride): + super(DepthWiseConv, self).__init__() + platform = context.get_context("device_target") + weight_shape = [1, kernel_size, in_planes] + weight_init = _initialize_weight_goog(shape=weight_shape) + if platform == "GPU": + self.depthwise_conv = P.Conv2D(out_channel=in_planes * 1, kernel_size=kernel_size, + stride=stride, pad_mode="same", group=in_planes) + self.weight = Parameter(initializer( + weight_init, [in_planes * 1, 1, kernel_size, kernel_size]), name='depthwise_weight') + else: + self.depthwise_conv = P.DepthwiseConv2dNative( + channel_multiplier=1, kernel_size=kernel_size, stride=stride, pad_mode='same',) + 
class DepthwiseSeparableConv(nn.Cell):
    """Depthwise conv -> BN -> act -> optional SE -> 1x1 pointwise conv -> BN.

    A residual connection (with optional drop-connect during training) is
    added when the stride is 1, input and output channel counts match and
    ``noskip`` is false.

    Args:
        in_chs: Input channel count.
        out_chs: Output channel count of the pointwise convolution.
        dw_kernel_size: Kernel size of the depthwise convolution.
        stride: Depthwise stride; must be 1 or 2.
        pad_type: Accepted for interface parity; not referenced here.
        act_fn: Activation applied after the depthwise BN (and after the
            pointwise BN when ``pw_act`` is true).
        noskip: Disable the residual connection.
        pw_kernel_size: Accepted for interface parity; not referenced here.
        pw_act: Apply ``act_fn`` after the pointwise BN.
        se_ratio: Squeeze-excite reduction ratio; SE is used only when > 0.
        se_gate_fn: Gate function handed to ``SqueezeExcite``.
        bn_args: BatchNorm keyword arguments; ``_BN_ARGS_PT`` when None.
        drop_connect_rate: Drop-connect rate applied on the residual branch.
    """

    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
                 stride=1, pad_type='', act_fn=relu, noskip=False,
                 pw_kernel_size=1, pw_act=False, se_ratio=0., se_gate_fn=sigmoid,
                 bn_args=None, drop_connect_rate=0.):
        super(DepthwiseSeparableConv, self).__init__()

        if bn_args is None:
            bn_args = _BN_ARGS_PT
        assert stride in [1, 2], 'stride must be 1 or 2'

        shape_preserving = stride == 1 and in_chs == out_chs
        self.has_se = se_ratio is not None and se_ratio > 0.
        self.has_residual = shape_preserving and not noskip
        self.has_pw_act = pw_act
        self.act_fn = act_fn
        self.drop_connect_rate = drop_connect_rate

        # Depthwise stage.
        self.conv_dw = DepthWiseConv(in_chs, dw_kernel_size, stride)
        self.bn1 = _fused_bn(in_chs, **bn_args)

        # Optional squeeze-and-excite, sized from the input channel count.
        if self.has_se:
            squeeze_chs = max(1, int(in_chs * se_ratio))
            self.se = SqueezeExcite(in_chs, reduce_chs=squeeze_chs,
                                    act_fn=act_fn, gate_fn=se_gate_fn)

        # Pointwise projection stage.
        self.conv_pw = _conv1x1(in_chs, out_chs)
        self.bn2 = _fused_bn(out_chs, **bn_args)

    def construct(self, x):
        """Apply the block to ``x`` and return the result."""
        shortcut = x

        out = self.conv_dw(x)
        out = self.bn1(out)
        out = self.act_fn(out)

        if self.has_se:
            out = self.se(out)

        out = self.conv_pw(out)
        out = self.bn2(out)
        if self.has_pw_act:
            out = self.act_fn(out)

        if self.has_residual:
            if self.drop_connect_rate > 0.:
                out = drop_connect(out, self.training, self.drop_connect_rate)
            out = out + shortcut

        return out
class GenEfficientNet(nn.Cell):
    """Generic EfficientNet-style classifier: stem -> blocks -> head -> classifier.

    Args:
        block_args: Decoded per-stage block arguments, passed to ``BlockBuilder``.
        num_classes: Output size of the final dense classifier.
        in_chans: Number of input image channels.
        stem_size: Stem conv output channels before rounding.
        num_features: Channel count entering the classifier.
        channel_multiplier, channel_divisor, channel_min: Channel rounding
            parameters for the stem and for ``BlockBuilder``.
        pad_type: Forwarded to ``BlockBuilder``.
        act_fn: Activation used in the stem, the blocks and the head.
        drop_rate: Dropout rate before the classifier (training only).
        drop_connect_rate: Forwarded to ``BlockBuilder``.
        se_gate_fn, se_reduce_mid: Forwarded to ``BlockBuilder``.
        bn_args: BatchNorm keyword arguments; ``_BN_ARGS_PT`` when None.
        global_pool: Accepted for interface compatibility; this class always
            pools with ``ReduceMean`` and does not read the value.
        head_conv: Falsy or ``'none'`` disables the head conv (then the trunk
            must already output ``num_features`` channels); ``'efficient'``
            pools before the head conv and skips its BatchNorm; any other
            value uses conv -> BN -> act -> pool.
        weight_init: Accepted for interface compatibility; not read here.
    """

    def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280,
                 channel_multiplier=1.0, channel_divisor=8, channel_min=None,
                 pad_type='', act_fn=relu, drop_rate=0., drop_connect_rate=0.,
                 se_gate_fn=sigmoid, se_reduce_mid=False, bn_args=None,
                 global_pool='avg', head_conv='default', weight_init='goog'):
        super(GenEfficientNet, self).__init__()

        bn_args = _BN_ARGS_PT if bn_args is None else bn_args
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        self.act_fn = act_fn
        self.num_features = num_features

        # Stem: 3x3 stride-2 convolution followed by BatchNorm.
        stem_size = _round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
        self.conv_stem = _conv(in_chans, stem_size, 3, stride=2)
        self.bn1 = _fused_bn(stem_size, **bn_args)
        in_chans = stem_size

        # Trunk; the builder tracks the running channel count in `in_chs`.
        self.blocks = BlockBuilder(in_chans, block_args, channel_multiplier, channel_divisor, channel_min,
                                   pad_type, act_fn, se_gate_fn, se_reduce_mid,
                                   bn_args, drop_connect_rate, verbose=_DEBUG)
        in_chs = self.blocks.in_chs

        if not head_conv or head_conv == 'none':
            self.efficient_head = False
            self.conv_head = None
            assert in_chs == self.num_features
        else:
            self.efficient_head = head_conv == 'efficient'
            self.conv_head = _conv1x1(in_chs, self.num_features)
        # The efficient head has no BatchNorm after the head convolution.
        self.bn2 = None if self.efficient_head else _fused_bn(self.num_features, **bn_args)
        self.global_pool = P.ReduceMean(keep_dims=True)
        self.classifier = _dense(self.num_features, self.num_classes)
        self.reshape = P.Reshape()
        self.shape = P.Shape()
        self.drop_out = nn.Dropout(keep_prob=1 - self.drop_rate)

    def construct(self, x):
        """Return the classifier output for the input batch ``x``."""
        x = self.conv_stem(x)
        x = self.bn1(x)
        x = self.act_fn(x)
        x = self.blocks(x)

        if self.efficient_head:
            # Pool first so the head convolution runs on the reduced map.
            x = self.global_pool(x, (2, 3))
            x = self.conv_head(x)
            x = self.act_fn(x)
            # Reshape takes (tensor, target_shape); the tensor argument was
            # previously missing here, unlike in the branch below.
            x = self.reshape(x, (self.shape(x)[0], -1))
        else:
            if self.conv_head is not None:
                x = self.conv_head(x)
                x = self.bn2(x)
            x = self.act_fn(x)
            x = self.global_pool(x, (2, 3))
            x = self.reshape(x, (self.shape(x)[0], -1))

        if self.training and self.drop_rate > 0.:
            x = self.drop_out(x)
        return self.classifier(x)
def efficientnet_b0(num_classes=1000, in_chans=3, **kwargs):
    """Create EfficientNet-B0 (channel and depth multipliers both 1.0).

    The matching entry of ``default_cfgs`` is attached to the returned model
    as ``default_cfg``. Extra keyword arguments go to ``_gen_efficientnet``.
    """
    b0_cfg = default_cfgs['efficientnet_b0']
    net = _gen_efficientnet(
        channel_multiplier=1.0,
        depth_multiplier=1.0,
        num_classes=num_classes,
        in_chans=in_chans,
        **kwargs
    )
    net.default_cfg = b0_cfg
    return net
class LabelSmoothingCrossEntropy(_Loss):
    """Softmax cross entropy against label-smoothed one-hot targets.

    The true class receives ``1 - smooth_factor``; every other class receives
    ``smooth_factor / (num_classes - 1)``. The per-sample losses are averaged
    over axis 0.
    """

    def __init__(self, smooth_factor=0.1, num_classes=1000):
        super(LabelSmoothingCrossEntropy, self).__init__()
        target_prob = 1.0 - smooth_factor
        other_prob = 1.0 * smooth_factor / (num_classes - 1)
        self.onehot = P.OneHot()
        self.on_value = Tensor(target_prob, mstype.float32)
        self.off_value = Tensor(other_prob, mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean(False)

    def construct(self, logits, label):
        """Return the mean smoothed cross-entropy of ``logits`` w.r.t. ``label``."""
        # One-hot depth is taken from the logits' second dimension.
        depth = F.shape(logits)[1]
        smoothed_target = self.onehot(label, depth, self.on_value, self.off_value)
        per_sample = self.ce(logits, smoothed_target)
        return self.mean(per_sample, 0)
class RandAugment:
    """Per-batch callable applying RandAugment, ToTensor and Normalize.

    ``config_str`` is a RandAugment config string and ``hparams`` a dict of
    hyper-parameters; both are handed to
    ``transform_utils.rand_augment_transform`` on every call.
    """

    def __init__(self, config_str="rand-m9-mstd0.5", hparams=None):
        if hparams is None:
            hparams = {}
        self.config_str = config_str
        self.hparams = hparams

    def __call__(self, imgs, labels, batchInfo):
        """Augment every image in ``imgs``; return (images, labels) as arrays.

        ``batchInfo`` is accepted but not used.
        """
        # Inputs are converted to PIL images before augmentation.
        as_pil = P.ToPIL()
        as_tensor = P.ToTensor()
        normalize = P.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
        augment = transform_utils.rand_augment_transform(self.config_str, self.hparams)

        out_imgs, out_labels = [], []
        for idx, raw in enumerate(imgs):
            augmented = augment(as_pil(raw))
            out_imgs.append(normalize(as_tensor(augmented)))
            out_labels.append(labels[idx])
        return np.array(out_imgs), np.array(out_labels)
"License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +random augment utils +""" +import math +import random +import re + +import numpy as np +import PIL +from PIL import Image, ImageEnhance, ImageOps + +_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]]) +_FILL = (128, 128, 128) +_MAX_LEVEL = 10. +_HPARAMS_DEFAULT = dict(translate_const=250, img_mean=_FILL) +_RAND_TRANSFORMS = [ + 'Distort', + 'Zoom', + 'Blur', + 'Skew', + 'AutoContrast', + 'Equalize', + 'Invert', + 'Rotate', + 'PosterizeTpu', + 'Solarize', + 'SolarizeAdd', + 'Color', + 'Contrast', + 'Brightness', + 'Sharpness', + 'ShearX', + 'ShearY', + 'TranslateXRel', + 'TranslateYRel', +] +_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) +_RAND_CHOICE_WEIGHTS_0 = { + 'Rotate': 0.3, + 'ShearX': 0.2, + 'ShearY': 0.2, + 'TranslateXRel': 0.1, + 'TranslateYRel': 0.1, + 'Color': .025, + 'Sharpness': 0.025, + 'AutoContrast': 0.025, + 'Solarize': .005, + 'SolarizeAdd': .005, + 'Contrast': .005, + 'Brightness': .005, + 'Equalize': .005, + 'PosterizeTpu': 0, + 'Invert': 0, + 'Distort': 0, + 'Zoom': 0, + 'Blur': 0, + 'Skew': 0, +} + + +def _interpolation(kwargs): + interpolation = kwargs.pop('resample', Image.BILINEAR) + if isinstance(interpolation, (list, tuple)): + return random.choice(interpolation) + return interpolation + + +def _check_args_tf(kwargs): + if 'fillcolor' in kwargs and _PIL_VER < (5, 0): + kwargs.pop('fillcolor') + kwargs['resample'] = _interpolation(kwargs) + +# 
def rotate(img, degrees, **kwargs):
    """Rotate ``img`` by ``degrees``, always resampling with ``Image.BICUBIC``.

    Any ``resample`` value supplied by the caller is replaced. Remaining
    keyword arguments are forwarded to PIL when the installed version is at
    least 5.0; for older versions only ``resample`` is passed.

    Args:
        img: Image object exposing ``rotate``, ``transform`` and ``size``.
        degrees: Rotation angle in degrees.
        **kwargs: Extra keyword arguments for the PIL call.

    Returns:
        The rotated image.
    """
    # Plain assignment instead of pop-then-set: the old value is discarded
    # either way, and this no longer raises KeyError when 'resample' is absent.
    kwargs['resample'] = Image.BICUBIC
    if _PIL_VER >= (5, 2):
        return img.rotate(degrees, **kwargs)
    if _PIL_VER >= (5, 0):
        # Build the affine matrix for a rotation about the image centre by
        # hand and apply it through img.transform.
        w, h = img.size
        post_trans = (0, 0)
        rotn_center = (w / 2.0, h / 2.0)
        angle = -math.radians(degrees)
        matrix = [
            round(math.cos(angle), 15),
            round(math.sin(angle), 15),
            0.0,
            round(-math.sin(angle), 15),
            round(math.cos(angle), 15),
            0.0,
        ]

        def transform(x, y, matrix):
            (a, b, c, d, e, f) = matrix
            return a * x + b * y + c, d * x + e * y + f

        # Translate so that the rotation centre maps onto itself.
        matrix[2], matrix[5] = transform(
            -rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix
        )
        matrix[2] += rotn_center[0]
        matrix[5] += rotn_center[1]
        return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
    return img.rotate(degrees, resample=kwargs['resample'])
def solarize_add(img, add, thresh=128, **__):
    """Add ``add`` (clamped to 255) to every pixel value below ``thresh``.

    Only images whose mode is "L" or "RGB" are changed (via ``img.point``);
    any other mode is returned untouched. Extra keyword arguments are ignored.
    """
    # 256-entry lookup table: values under the threshold are brightened.
    table = [min(255, value + add) if value < thresh else value
             for value in range(256)]
    mode = img.mode
    if mode not in ("L", "RGB"):
        return img
    if mode == "RGB":
        # One copy of the table per colour band.
        table = table * 3
    return img.point(table)
def _enhance_level_to_arg(level, _hparams):
    """Convert a magnitude level to an enhancement factor, as a 1-tuple.

    A level of 0 gives 0.1 and a level of ``_MAX_LEVEL`` gives 1.9.
    ``_hparams`` is unused.
    """
    fraction = level / _MAX_LEVEL
    factor = fraction * 1.8 + 0.1
    return (factor,)
(level,) + + +def distort(img, v, **__): + w, h = img.size + horizontal_tiles = int(0.1 * v) + vertical_tiles = int(0.1 * v) + + width_of_square = int(math.floor(w / float(horizontal_tiles))) + height_of_square = int(math.floor(h / float(vertical_tiles))) + width_of_last_square = w - (width_of_square * (horizontal_tiles - 1)) + height_of_last_square = h - (height_of_square * (vertical_tiles - 1)) + dimensions = [] + + for vertical_tile in range(vertical_tiles): + for horizontal_tile in range(horizontal_tiles): + if vertical_tile == (vertical_tiles - 1) and horizontal_tile == (horizontal_tiles - 1): + dimensions.append([horizontal_tile * width_of_square, + vertical_tile * height_of_square, + width_of_last_square + (horizontal_tile * width_of_square), + height_of_last_square + (height_of_square * vertical_tile)]) + elif vertical_tile == (vertical_tiles - 1): + dimensions.append([horizontal_tile * width_of_square, + vertical_tile * height_of_square, + width_of_square + (horizontal_tile * width_of_square), + height_of_last_square + (height_of_square * vertical_tile)]) + elif horizontal_tile == (horizontal_tiles - 1): + dimensions.append([horizontal_tile * width_of_square, + vertical_tile * height_of_square, + width_of_last_square + (horizontal_tile * width_of_square), + height_of_square + (height_of_square * vertical_tile)]) + else: + dimensions.append([horizontal_tile * width_of_square, + vertical_tile * height_of_square, + width_of_square + (horizontal_tile * width_of_square), + height_of_square + (height_of_square * vertical_tile)]) + last_column = [] + for i in range(vertical_tiles): + last_column.append((horizontal_tiles - 1) + horizontal_tiles * i) + + last_row = range((horizontal_tiles * vertical_tiles) - horizontal_tiles, horizontal_tiles * vertical_tiles) + + polygons = [] + for x1, y1, x2, y2 in dimensions: + polygons.append([x1, y1, x1, y2, x2, y2, x2, y1]) + + polygon_indices = [] + for i in range((vertical_tiles * horizontal_tiles) - 1): + if i not in 
def skew(img, v, **__):
    """Apply a perspective skew to ``img``.

    Each corner of the image rectangle is pushed outwards by
    ``ceil(max(width, height) * v)`` pixels along one axis, the eight
    perspective coefficients mapping the displaced corners back onto the
    original rectangle are solved for, and the image is transformed with them
    using bicubic resampling. Extra keyword arguments are ignored.

    Args:
        img: PIL image.
        v: Skew magnitude as a fraction of the longer image side.

    Returns:
        The transformed image, same size as the input.
    """
    w, h = img.size
    x1 = 0
    x2 = h
    y1 = 0
    y2 = w
    original_plane = [(y1, x1), (y2, x1), (y2, x2), (y1, x2)]
    skew_amount = int(math.ceil(max(w, h) * v))
    new_plane = [(y1 - skew_amount, x1),   # Top Left
                 (y2, x1 - skew_amount),   # Top Right
                 (y2 + skew_amount, x2),   # Bottom Right
                 (y1, x2 + skew_amount)]   # Bottom Left

    # Two linear equations per corner pair; 8 equations for 8 coefficients.
    rows = []
    for p1, p2 in zip(new_plane, original_plane):
        rows.append([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]])
        rows.append([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]])

    # `np.float` was removed from NumPy (1.24); use an explicit float64 ndarray
    # instead of the discouraged np.matrix.
    lhs = np.array(rows, dtype=np.float64)
    rhs = np.array(original_plane).reshape(8)
    coefficients = np.dot(np.linalg.pinv(lhs), rhs).reshape(8)

    return img.transform(img.size, PIL.Image.PERSPECTIVE, coefficients,
                         resample=PIL.Image.BICUBIC)
# Maps each augmentation name to the helper that turns a magnitude level
# (plus the hparams dict) into the tuple of positional arguments for that op.
# A value of None marks ops that receive no level-derived arguments.
LEVEL_TO_ARG = {
    'Distort': _distort_level_to_arg,
    'Zoom': _zoom_level_to_arg,
    'Blur': _blur_level_to_arg,
    'Skew': _skew_level_to_arg,
    'AutoContrast': None,
    'Equalize': None,
    'Invert': None,
    'Rotate': _rotate_level_to_arg,
    'PosterizeOriginal': _posterize_original_level_to_arg,
    'PosterizeResearch': _posterize_research_level_to_arg,
    'PosterizeTpu': _posterize_tpu_level_to_arg,
    'Solarize': _solarize_level_to_arg,
    'SolarizeAdd': _solarize_add_level_to_arg,
    'Color': _enhance_level_to_arg,
    'Contrast': _enhance_level_to_arg,
    'Brightness': _enhance_level_to_arg,
    'Sharpness': _enhance_level_to_arg,
    'ShearX': _shear_level_to_arg,
    'ShearY': _shear_level_to_arg,
    'TranslateX': _translate_abs_level_to_arg,
    'TranslateY': _translate_abs_level_to_arg,
    'TranslateXRel': _translate_rel_level_to_arg,
    'TranslateYRel': _translate_rel_level_to_arg,
}
class AutoAugmentOp:
    """One named augmentation, applied with probability ``prob``.

    Args:
        name: Key into ``NAME_TO_OP`` / ``LEVEL_TO_ARG``.
        prob: Probability of applying the op on a given call.
        magnitude: Base magnitude level handed to the level-to-arg helper.
        hparams: Hyper-parameter dict; ``_HPARAMS_DEFAULT`` when falsy. The
            keys ``img_mean``, ``interpolation`` and ``magnitude_std`` are
            read here.
    """

    def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
        if not hparams:
            hparams = _HPARAMS_DEFAULT
        self.aug_fn = NAME_TO_OP[name]
        self.level_fn = LEVEL_TO_ARG[name]
        self.prob = prob
        self.magnitude = magnitude
        # Keep a private copy so later caller-side changes do not leak in.
        self.hparams = hparams.copy()
        self.kwargs = {
            'fillcolor': hparams.get('img_mean', _FILL),
            'resample': hparams.get('interpolation', _RANDOM_INTERPOLATION),
        }
        self.magnitude_std = self.hparams.get('magnitude_std', 0)

    def __call__(self, img):
        """Return ``img`` augmented, or unchanged when the op is skipped."""
        if random.random() > self.prob:
            return img

        level = self.magnitude
        noise_std = self.magnitude_std
        if noise_std and noise_std > 0:
            # Jitter the magnitude with Gaussian noise.
            level = random.gauss(level, noise_std)

        if self.level_fn is None:
            op_args = tuple()
        else:
            op_args = self.level_fn(level, self.hparams)
        return self.aug_fn(img, *op_args, **self.kwargs)
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train imagenet."""
import argparse
import math
import os
import random
import time

import numpy as np
import mindspore
from mindspore import Tensor, context
from mindspore.communication.management import get_group_size, get_rank, init
from mindspore.nn import SGD, RMSProp
from mindspore.train.callback import (CheckpointConfig, LossMonitor,
                                      ModelCheckpoint, TimeMonitor)
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.model import Model, ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import efficientnet_b0_config_gpu as cfg
from src.dataset import create_dataset
from src.efficientnet import efficientnet_b0
from src.loss import LabelSmoothingCrossEntropy

# Seed every random source used by the script from the configured seed.
mindspore.common.set_seed(cfg.random_seed)
random.seed(cfg.random_seed)
np.random.seed(cfg.random_seed)


def get_lr(base_lr, total_epochs, steps_per_epoch, decay_steps=1,
           decay_rate=0.9, warmup_steps=0., warmup_lr_init=0., global_epoch=0):
    """Build the per-step learning-rate schedule as a float32 array.

    The rate is constant within an epoch. While the epoch index is below
    ``warmup_steps`` it grows linearly from ``warmup_lr_init`` towards
    ``base_lr``; afterwards it is ``base_lr`` multiplied by the decay factor
    once for every ``decay_steps`` epochs elapsed since epoch 0.

    Args:
        base_lr: learning rate reached after warmup.
        total_epochs: number of epochs covered by the full schedule.
        steps_per_epoch: number of steps in one epoch.
        decay_steps: decay interval, measured in epochs (may be fractional).
        decay_rate: multiplicative decay; a value >= 1 is inverted so the
            schedule never grows.
        warmup_steps: warmup length, measured in epochs.
        warmup_lr_init: learning rate of the first warmup epoch.
        global_epoch: number of leading epochs to drop from the schedule.

    Returns:
        1-D ``np.float32`` array holding one learning rate per remaining step.
    """
    warmup_delta = ((base_lr - warmup_lr_init) /
                    warmup_steps) if warmup_steps > 0 else 0
    decay_factor = decay_rate if decay_rate < 1 else 1 / decay_rate

    # The rate only depends on the epoch, so compute it once per epoch and
    # expand to steps afterwards instead of recomputing it for every step.
    lr_per_epoch = []
    for epoch in range(total_epochs):
        if epoch < warmup_steps:
            lr = warmup_lr_init + epoch * warmup_delta
        else:
            decay_nums = math.floor(epoch / decay_steps)
            lr = base_lr * math.pow(decay_factor, decay_nums)
        lr_per_epoch.append(lr)

    lr_each_step = np.repeat(
        np.array(lr_per_epoch, dtype=np.float32), steps_per_epoch)
    return lr_each_step[steps_per_epoch * global_epoch:]


def get_outdir(path, *paths, inc=False):
    """Create and return an output directory.

    Args:
        path: base directory.
        *paths: further path components joined onto ``path``.
        inc: when the directory already exists and ``inc`` is true, create
            and return the first free ``<outdir>-<n>`` with ``n`` in 1..99
            instead of reusing the existing directory.

    Returns:
        The path of the directory to use.

    Raises:
        AssertionError: if ``inc`` is true and all 99 suffixed names exist.
    """
    outdir = os.path.join(path, *paths)
    # Create first and react to FileExistsError instead of checking for
    # existence beforehand, so concurrent callers cannot race each other.
    try:
        os.makedirs(outdir)
        return outdir
    except FileExistsError:
        if not inc:
            return outdir

    for count in range(1, 100):
        outdir_inc = outdir + '-' + str(count)
        try:
            os.makedirs(outdir_inc)
            return outdir_inc
        except FileExistsError:
            continue
    # Raised explicitly (not via ``assert``) so the limit also holds under -O.
    raise AssertionError('no free output directory name for ' + outdir)


parser = argparse.ArgumentParser(
    description='Training configuration', add_help=False)
parser.add_argument('--data_path', type=str, default='/home/dataset/imagenet_jpeg/', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--distributed', action='store_true', default=False)
parser.add_argument('--GPU', action='store_true', default=False,
                    help='Use GPU for training (default: False)')
parser.add_argument('--cur_time', type=str,
                    default='19701010-000000', help='current time')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='Resume full model and optimizer state from checkpoint (default: none)')


def main():
    """Train EfficientNet-B0 using the command-line options and ``cfg``.

    Raises:
        ValueError: if ``cfg.opt`` is neither ``'sgd'`` nor ``'rmsprop'``.
    """
    args, _ = parser.parse_known_args()
    devid, rank_id, rank_size = 0, 0, 1

    context.set_context(mode=context.GRAPH_MODE)

    if args.distributed:
        if args.GPU:
            init("nccl")
            context.set_context(device_target='GPU')
        else:
            init()
            devid = int(os.getenv('DEVICE_ID'))
            context.set_context(
                device_target='Ascend', device_id=devid, reserve_class_name_in_scope=False)
        context.reset_auto_parallel_context()
        rank_id = get_rank()
        rank_size = get_group_size()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True, device_num=rank_size)
    elif args.GPU:
        context.set_context(device_target='GPU')

    # Only the master process gets the monitoring/checkpoint callbacks.
    is_master = not args.distributed or (rank_id == 0)

    net = efficientnet_b0(num_classes=cfg.num_classes,
                          drop_rate=cfg.drop,
                          drop_connect_rate=cfg.drop_connect,
                          global_pool=cfg.gp,
                          bn_tf=cfg.bn_tf,
                          )

    cur_time = args.cur_time
    output_base = './output'

    exp_name = '-'.join([
        cur_time,
        cfg.model,
        str(224)
    ])
    # Each rank waits rank_id seconds before creating the output directory
    # (presumably to stagger directory creation across ranks).
    time.sleep(rank_id)
    output_dir = get_outdir(output_base, exp_name)

    train_data_url = os.path.join(args.data_path, 'train')
    train_dataset = create_dataset(
        cfg.batch_size, train_data_url, workers=cfg.workers, distributed=args.distributed)
    batches_per_epoch = train_dataset.get_dataset_size()

    loss_cb = LossMonitor(per_print_times=batches_per_epoch)
    loss = LabelSmoothingCrossEntropy(smooth_factor=cfg.smoothing)
    time_cb = TimeMonitor(data_size=batches_per_epoch)
    loss_scale_manager = FixedLossScaleManager(
        cfg.loss_scale, drop_overflow_update=False)

    # Save one checkpoint per epoch.
    config_ck = CheckpointConfig(
        save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(
        prefix=cfg.model, directory=output_dir, config=config_ck)

    lr = Tensor(get_lr(base_lr=cfg.lr, total_epochs=cfg.epochs, steps_per_epoch=batches_per_epoch,
                       decay_steps=cfg.decay_epochs, decay_rate=cfg.decay_rate,
                       warmup_steps=cfg.warmup_epochs, warmup_lr_init=cfg.warmup_lr_init,
                       global_epoch=cfg.resume_start_epoch))
    if cfg.opt == 'sgd':
        optimizer = SGD(net.trainable_params(), learning_rate=lr, momentum=cfg.momentum,
                        weight_decay=cfg.weight_decay,
                        loss_scale=cfg.loss_scale
                        )
    elif cfg.opt == 'rmsprop':
        optimizer = RMSProp(net.trainable_params(), learning_rate=lr, decay=0.9, weight_decay=cfg.weight_decay,
                            momentum=cfg.momentum, epsilon=cfg.opt_eps, loss_scale=cfg.loss_scale
                            )
    else:
        # Fail here with a clear message instead of hitting an unbound
        # ``optimizer`` further down.
        raise ValueError(
            "Unsupported optimizer '{}': expected 'sgd' or 'rmsprop'".format(cfg.opt))

    loss.add_flags_recursive(fp32=True, fp16=False)

    if args.resume:
        ckpt = load_checkpoint(args.resume)
        load_param_into_net(net, ckpt)

    model = Model(net, loss, optimizer,
                  loss_scale_manager=loss_scale_manager,
                  amp_level=cfg.amp_level
                  )

    callbacks = [loss_cb, ckpoint_cb, time_cb] if is_master else []

    # When resuming, only the epochs after cfg.resume_start_epoch are run.
    if args.resume:
        epochs_to_run = cfg.epochs - cfg.resume_start_epoch
    else:
        epochs_to_run = cfg.epochs
    model.train(epochs_to_run, train_dataset,
                callbacks=callbacks, dataset_sink_mode=True)


if __name__ == '__main__':
    main()