@@ -90,10 +90,15 @@ sh run_standalone_train.sh DEVICE_ID DATA_PATH
 #### Launch
 ```bash
-# distributed training example(8p)
+# distributed training example (8p) for Ascend
 sh scripts/run_distribute_train.sh MINDSPORE_HCCL_CONFIG_PATH /dataset/train
-# standalone training example
+# standalone training example for Ascend
 sh scripts/run_standalone_train.sh 0 /dataset/train
+# distributed training example (8p) for GPU
+sh scripts/run_distribute_train_for_gpu.sh /dataset/train
+# standalone training example for GPU
+sh scripts/run_standalone_train_for_gpu.sh 0 /dataset/train
 ```

 #### Result
@@ -106,14 +111,15 @@ You can find checkpoint file together with result in log.
 ```
 # Evaluation
-sh run_eval.sh DEVICE_ID DATA_PATH PRETRAINED_CKPT_PATH
+sh run_eval.sh DEVICE_ID DATA_PATH PRETRAINED_CKPT_PATH PLATFORM
 ```
+PLATFORM should be either Ascend or GPU; the default is Ascend.
 #### Launch
 ```bash
 # Evaluation with checkpoint
-sh scripts/run_eval.sh 0 /opt/npu/datasets/classification/val /resnext50_100.ckpt
+sh scripts/run_eval.sh 0 /opt/npu/datasets/classification/val /resnext50_100.ckpt Ascend
 ```
 > checkpoint can be produced in training process.
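For GPU evaluation, pass `GPU` as the fourth argument to the same script, e.g. `sh scripts/run_eval.sh 0 /opt/npu/datasets/classification/val /resnext50_100.ckpt GPU`.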
@@ -29,15 +29,11 @@ from mindspore.ops import functional as F
 from mindspore.common import dtype as mstype
 from src.utils.logging import get_logger
+from src.utils.auto_mixed_precision import auto_mixed_precision
 from src.image_classification import get_network
 from src.dataset import classification_dataset
 from src.config import config

-devid = int(os.getenv('DEVICE_ID'))
-context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
-                    device_target="Ascend", save_graphs=False, device_id=devid)

 class ParameterReduce(nn.Cell):
     """ParameterReduce"""
@@ -56,6 +52,7 @@ class ParameterReduce(nn.Cell):
 def parse_args(cloud_args=None):
     """parse_args"""
     parser = argparse.ArgumentParser('mindspore classification test')
+    parser.add_argument('--platform', type=str, default='Ascend', choices=('Ascend', 'GPU'), help='run platform')

     # dataset related
     parser.add_argument('--data_dir', type=str, default='/opt/npu/datasets/classification/val', help='eval data dir')
@@ -108,12 +105,25 @@ def merge_args(args, cloud_args):
 def test(cloud_args=None):
     """test"""
     args = parse_args(cloud_args)
+    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
+                        device_target=args.platform, save_graphs=False)
+    if os.getenv('DEVICE_ID', "not_set").isdigit():
+        context.set_context(device_id=int(os.getenv('DEVICE_ID')))

     # init distributed
     if args.is_distributed:
-        init()
+        if args.platform == "Ascend":
+            init()
+        elif args.platform == "GPU":
+            init("nccl")
         args.rank = get_rank()
         args.group_size = get_group_size()
+        parallel_mode = ParallelMode.DATA_PARALLEL
+        context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size,
+                                          parameter_broadcast=True, mirror_mean=True)
+    else:
+        args.rank = 0
+        args.group_size = 1

     args.outputs_dir = os.path.join(args.log_path,
                                     datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
@@ -140,7 +150,7 @@ def test(cloud_args=None):
                                         max_epoch=1, rank=args.rank, group_size=args.group_size,
                                         mode='eval')
     eval_dataloader = de_dataset.create_tuple_iterator()
-    network = get_network(args.backbone, args.num_classes)
+    network = get_network(args.backbone, args.num_classes, platform=args.platform)
     if network is None:
         raise NotImplementedError('not implement {}'.format(args.backbone))
@@ -157,12 +167,13 @@ def test(cloud_args=None):
     load_param_into_net(network, param_dict_new)
     args.logger.info('load model {} success'.format(model))

-    # must add
-    network.add_flags_recursive(fp16=True)
     img_tot = 0
     top1_correct = 0
     top5_correct = 0
+    if args.platform == "Ascend":
+        network.to_float(mstype.float16)
+    else:
+        auto_mixed_precision(network)
     network.set_train(False)
     t_end = time.time()
     it = 0
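Note the asymmetry between the two backends here: on Ascend the whole network is cast to fp16 via `to_float(mstype.float16)`, while on GPU the new `auto_mixed_precision` helper (added under `src/utils/` below) casts to fp16 but keeps BatchNorm, the grouped `conv2` layers, and the head's output in fp32 for numerical stability.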
@@ -0,0 +1,30 @@
+#!/bin/bash
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+DATA_DIR=$1
+export RANK_SIZE=8
+PATH_CHECKPOINT=""
+if [ $# == 2 ]
+then
+    PATH_CHECKPOINT=$2
+fi
+
+mpirun --allow-run-as-root -n $RANK_SIZE \
+    python train.py \
+    --is_distribute=1 \
+    --platform="GPU" \
+    --pretrained=$PATH_CHECKPOINT \
+    --data_dir=$DATA_DIR > log.txt 2>&1 &
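Unlike the Ascend launcher, this script needs no `MINDSPORE_HCCL_CONFIG_PATH`: `mpirun` spawns `RANK_SIZE` worker processes, and each worker obtains its rank and group size from `init("nccl")` / `get_rank()` / `get_group_size()` inside `train.py`.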
@@ -14,11 +14,16 @@
 # limitations under the License.
 # ============================================================================
-DEVICE_ID=$1
+export DEVICE_ID=$1
 DATA_DIR=$2
 PATH_CHECKPOINT=$3
+PLATFORM=Ascend
+if [ $# == 4 ]
+then
+    PLATFORM=$4
+fi

 python eval.py \
-    --device_id=$DEVICE_ID \
     --pretrained=$PATH_CHECKPOINT \
+    --platform=$PLATFORM \
     --data_dir=$DATA_DIR > log.txt 2>&1 &
@@ -14,7 +14,7 @@
 # limitations under the License.
 # ============================================================================
-DEVICE_ID=$1
+export DEVICE_ID=$1
 DATA_DIR=$2
 PATH_CHECKPOINT=""
 if [ $# == 3 ]
@@ -0,0 +1,30 @@
+#!/bin/bash
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+export DEVICE_ID=$1
+DATA_DIR=$2
+PATH_CHECKPOINT=""
+if [ $# == 3 ]
+then
+    PATH_CHECKPOINT=$3
+fi
+
+python train.py \
+    --is_distribute=0 \
+    --pretrained=$PATH_CHECKPOINT \
+    --platform="GPU" \
+    --data_dir=$DATA_DIR > log.txt 2>&1 &
@@ -87,7 +87,8 @@ class BasicBlock(nn.Cell):
     """
     expansion = 1

-    def __init__(self, in_channels, out_channels, stride=1, down_sample=None, use_se=False, **kwargs):
+    def __init__(self, in_channels, out_channels, stride=1, down_sample=None, use_se=False,
+                 platform="Ascend", **kwargs):
         super(BasicBlock, self).__init__()
         self.conv1 = conv3x3(in_channels, out_channels, stride=stride)
         self.bn1 = nn.BatchNorm2d(out_channels)
@@ -142,7 +143,7 @@ class Bottleneck(nn.Cell):
     expansion = 4

     def __init__(self, in_channels, out_channels, stride=1, down_sample=None,
-                 base_width=64, groups=1, use_se=False, **kwargs):
+                 base_width=64, groups=1, use_se=False, platform="Ascend", **kwargs):
         super(Bottleneck, self).__init__()
         width = int(out_channels * (base_width / 64.0)) * groups
@@ -153,7 +154,11 @@ class Bottleneck(nn.Cell):
         self.conv3x3s = nn.CellList()

-        self.conv2 = GroupConv(width, width, 3, stride, pad=1, groups=groups)
+        if platform == "GPU":
+            self.conv2 = nn.Conv2d(width, width, 3, stride, pad_mode='pad', padding=1, group=groups)
+        else:
+            self.conv2 = GroupConv(width, width, 3, stride, pad=1, groups=groups)
         self.op_split = Split(axis=1, output_num=self.groups)
         self.op_concat = Concat(axis=1)
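On GPU the block uses `nn.Conv2d`'s native `group` argument instead of the `Split`/`Concat`-based `GroupConv` workaround, presumably because grouped convolution is supported directly by the GPU backend. A minimal sketch of the substituted layer, with placeholder channel/group counts (the real values come from `width` and `groups`):

```python
import mindspore.nn as nn

# Grouped 3x3 convolution on the GPU path; 128 channels / 32 groups are
# illustrative values only, standing in for `width` and `groups`.
conv2 = nn.Conv2d(128, 128, kernel_size=3, stride=1,
                  pad_mode='pad', padding=1, group=32)
```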
@@ -211,7 +216,7 @@ class ResNet(nn.Cell):
     Examples:
         >>>ResNet()
     """
-    def __init__(self, block, layers, width_per_group=64, groups=1, use_se=False):
+    def __init__(self, block, layers, width_per_group=64, groups=1, use_se=False, platform="Ascend"):
         super(ResNet, self).__init__()
         self.in_channels = 64
         self.groups = groups
@@ -222,10 +227,10 @@ class ResNet(nn.Cell):
         self.relu = P.ReLU()
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')

-        self.layer1 = self._make_layer(block, 64, layers[0], use_se=use_se)
-        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, use_se=use_se)
-        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, use_se=use_se)
-        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, use_se=use_se)
+        self.layer1 = self._make_layer(block, 64, layers[0], use_se=use_se, platform=platform)
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, use_se=use_se, platform=platform)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, use_se=use_se, platform=platform)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, use_se=use_se, platform=platform)

         self.out_channels = 512 * block.expansion
         self.cast = P.Cast()
@@ -242,7 +247,7 @@ class ResNet(nn.Cell):
         return x

-    def _make_layer(self, block, out_channels, blocks_num, stride=1, use_se=False):
+    def _make_layer(self, block, out_channels, blocks_num, stride=1, use_se=False, platform="Ascend"):
         """_make_layer"""
         down_sample = None
         if stride != 1 or self.in_channels != out_channels * block.expansion:
@@ -257,11 +262,12 @@ class ResNet(nn.Cell):
                             down_sample=down_sample,
                             base_width=self.base_width,
                             groups=self.groups,
-                            use_se=use_se))
+                            use_se=use_se,
+                            platform=platform))
         self.in_channels = out_channels * block.expansion
         for _ in range(1, blocks_num):
-            layers.append(block(self.in_channels, out_channels,
-                                base_width=self.base_width, groups=self.groups, use_se=use_se))
+            layers.append(block(self.in_channels, out_channels, base_width=self.base_width,
+                                groups=self.groups, use_se=use_se, platform=platform))

         return nn.SequentialCell(layers)
@@ -269,5 +275,5 @@ class ResNet(nn.Cell):
         return self.out_channels


-def resnext50():
-    return ResNet(Bottleneck, [3, 4, 6, 3], width_per_group=4, groups=32)
+def resnext50(platform="Ascend"):
+    return ResNet(Bottleneck, [3, 4, 6, 3], width_per_group=4, groups=32, platform=platform)
@@ -36,7 +36,8 @@ config = ed({
     "label_smooth": 1,
     "label_smooth_factor": 0.1,
-    "ckpt_interval": 1250,
+    "ckpt_interval": 5,
+    "ckpt_save_max": 5,
     "ckpt_path": 'outputs/',
     "is_save_on_master": 1,
@@ -143,8 +143,10 @@ def classification_dataset(data_dir, image_size, per_batch_size, max_epoch, rank
         de_dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=sampler)
         de_dataset.set_dataset_size(len(sampler))

-    de_dataset = de_dataset.map(input_columns="image", num_parallel_workers=8, operations=transform_img)
-    de_dataset = de_dataset.map(input_columns="label", num_parallel_workers=8, operations=transform_label)
+    de_dataset = de_dataset.map(input_columns="image", num_parallel_workers=num_parallel_workers,
+                                operations=transform_img)
+    de_dataset = de_dataset.map(input_columns="label", num_parallel_workers=num_parallel_workers,
+                                operations=transform_label)

     columns_to_project = ["image", "label"]
     de_dataset = de_dataset.project(columns=columns_to_project)
@@ -50,9 +50,9 @@ class Resnet(ImageClassificationNetwork):
     Returns:
         Resnet.
     """
-    def __init__(self, backbone_name, num_classes):
+    def __init__(self, backbone_name, num_classes, platform="Ascend"):
         self.backbone_name = backbone_name
-        backbone = backbones.__dict__[self.backbone_name]()
+        backbone = backbones.__dict__[self.backbone_name](platform=platform)
         out_channels = backbone.get_out_channels()
         head = heads.CommonHead(num_classes=num_classes, out_channels=out_channels)
         super(Resnet, self).__init__(backbone, head)
@@ -79,7 +79,7 @@ class Resnet(ImageClassificationNetwork):

-def get_network(backbone_name, num_classes):
+def get_network(backbone_name, num_classes, platform="Ascend"):
     if backbone_name in ['resnext50']:
-        return Resnet(backbone_name, num_classes)
+        return Resnet(backbone_name, num_classes, platform)
     return None
@@ -0,0 +1,56 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Auto mixed precision."""
+import mindspore.nn as nn
+from mindspore.ops import functional as F
+from mindspore._checkparam import Validator as validator
+from mindspore.common import dtype as mstype
+
+
+class OutputTo(nn.Cell):
+    "Cast cell output back to float16 or float32"
+
+    def __init__(self, op, to_type=mstype.float16):
+        super(OutputTo, self).__init__(auto_prefix=False)
+        self._op = op
+        validator.check_type_name('to_type', to_type, [mstype.float16, mstype.float32], None)
+        self.to_type = to_type
+
+    def construct(self, x):
+        return F.cast(self._op(x), self.to_type)
+
+
+def auto_mixed_precision(network):
+    """Do keep batchnorm fp32."""
+    cells = network.name_cells()
+    change = False
+    network.to_float(mstype.float16)
+    for name in cells:
+        subcell = cells[name]
+        if subcell == network:
+            continue
+        elif name == 'fc':
+            network.insert_child_to_cell(name, OutputTo(subcell, mstype.float32))
+            change = True
+        elif name == 'conv2':
+            subcell.to_float(mstype.float32)
+            change = True
+        elif isinstance(subcell, (nn.BatchNorm2d, nn.BatchNorm1d)):
+            network.insert_child_to_cell(name, OutputTo(subcell.to_float(mstype.float32), mstype.float16))
+            change = True
+        else:
+            auto_mixed_precision(subcell)
+    if isinstance(network, nn.SequentialCell) and change:
+        network.cell_list = list(network.cells())
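A sketch of how this helper is applied, mirroring the call sites in `eval.py` above and `train.py` below; `num_classes=1000` is an illustrative assumption:

```python
from src.image_classification import get_network
from src.utils.auto_mixed_precision import auto_mixed_precision

# Build the GPU variant of the backbone, then apply the custom AMP pass:
# the whole network runs in fp16 except BatchNorm (fp32 compute, output cast
# back to fp16), the grouped `conv2` layers (kept fp32), and the `fc` head,
# whose output is cast to fp32.
network = get_network('resnext50', num_classes=1000, platform='GPU')
auto_mixed_precision(network)
```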
@@ -29,14 +29,10 @@ class GlobalAvgPooling(nn.Cell):
     """
     def __init__(self):
         super(GlobalAvgPooling, self).__init__()
-        self.mean = P.ReduceMean(True)
-        self.shape = P.Shape()
-        self.reshape = P.Reshape()
+        self.mean = P.ReduceMean(False)

     def construct(self, x):
         x = self.mean(x, (2, 3))
-        b, c, _, _ = self.shape(x)
-        x = self.reshape(x, (b, c))
         return x
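This simplification works because `ReduceMean` with `keep_dims=False` drops the reduced axes directly, so the old `Shape`/`Reshape` round-trip that flattened the `(N, C, 1, 1)` output of `ReduceMean(True)` is no longer needed. A NumPy analogue of the shape behaviour:

```python
import numpy as np

x = np.ones((8, 2048, 7, 7), dtype=np.float32)  # NCHW feature map, example sizes
y = x.mean(axis=(2, 3))                         # keep_dims=False semantics
assert y.shape == (8, 2048)                     # already (N, C); no reshape needed
```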
@@ -36,11 +36,9 @@ from src.warmup_cosine_annealing_lr import warmup_cosine_annealing_lr
 from src.utils.logging import get_logger
 from src.utils.optimizers__init__ import get_param_groups
 from src.image_classification import get_network
+from src.utils.auto_mixed_precision import auto_mixed_precision
 from src.config import config

-devid = int(os.getenv('DEVICE_ID'))
-context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
-                    device_target="Ascend", save_graphs=False, device_id=devid)

 class BuildTrainNetwork(nn.Cell):
     """build training network"""
@@ -109,6 +107,7 @@ class ProgressMonitor(Callback):
 def parse_args(cloud_args=None):
     """parameters"""
     parser = argparse.ArgumentParser('mindspore classification training')
+    parser.add_argument('--platform', type=str, default='Ascend', choices=('Ascend', 'GPU'), help='run platform')

     # dataset related
     parser.add_argument('--data_dir', type=str, default='', help='train data dir')
@@ -141,6 +140,7 @@ def parse_args(cloud_args=None):
     args.label_smooth = config.label_smooth
     args.label_smooth_factor = config.label_smooth_factor
     args.ckpt_interval = config.ckpt_interval
+    args.ckpt_save_max = config.ckpt_save_max
     args.ckpt_path = config.ckpt_path
     args.is_save_on_master = config.is_save_on_master
     args.rank = config.rank
@@ -166,12 +166,25 @@ def merge_args(args, cloud_args):
 def train(cloud_args=None):
     """training process"""
     args = parse_args(cloud_args)
+    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
+                        device_target=args.platform, save_graphs=False)
+    if os.getenv('DEVICE_ID', "not_set").isdigit():
+        context.set_context(device_id=int(os.getenv('DEVICE_ID')))

     # init distributed
     if args.is_distributed:
-        init()
+        if args.platform == "Ascend":
+            init()
+        else:
+            init("nccl")
         args.rank = get_rank()
         args.group_size = get_group_size()
+        parallel_mode = ParallelMode.DATA_PARALLEL
+        context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size,
+                                          parameter_broadcast=True, mirror_mean=True)
+    else:
+        args.rank = 0
+        args.group_size = 1

     if args.is_dynamic_loss_scale == 1:
         args.loss_scale = 1  # for dynamic loss scale can not set loss scale in momentum opt
@@ -192,7 +205,7 @@ def train(cloud_args=None):
     # dataloader
     de_dataset = classification_dataset(args.data_dir, args.image_size,
                                         args.per_batch_size, 1,
-                                        args.rank, args.group_size)
+                                        args.rank, args.group_size, num_parallel_workers=8)
     de_dataset.map_model = 4  # !!!important
     args.steps_per_epoch = de_dataset.get_dataset_size()
@@ -201,15 +214,9 @@ def train(cloud_args=None):
     # network
     args.logger.important_info('start create network')
     # get network and init
-    network = get_network(args.backbone, args.num_classes)
+    network = get_network(args.backbone, args.num_classes, platform=args.platform)
     if network is None:
         raise NotImplementedError('not implement {}'.format(args.backbone))
-    network.add_flags_recursive(fp16=True)
-
-    # loss
-    if not args.label_smooth:
-        args.label_smooth_factor = 0.0
-    criterion = CrossEntropy(smooth_factor=args.label_smooth_factor,
-                             num_classes=args.num_classes)

     # load pretrain model
     if os.path.isfile(args.pretrained):
@@ -252,31 +259,29 @@ def train(cloud_args=None):
                    loss_scale=args.loss_scale)
-    criterion.add_flags_recursive(fp32=True)
+    # loss
+    if not args.label_smooth:
+        args.label_smooth_factor = 0.0
+    loss = CrossEntropy(smooth_factor=args.label_smooth_factor, num_classes=args.num_classes)

-    # package training process, adjust lr + forward + backward + optimizer
-    train_net = BuildTrainNetwork(network, criterion)
-    if args.is_distributed:
-        parallel_mode = ParallelMode.DATA_PARALLEL
-    else:
-        parallel_mode = ParallelMode.STAND_ALONE
     if args.is_dynamic_loss_scale == 1:
         loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000)
     else:
         loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False)

-    # Model api changed since TR5_branch 2020/03/09
-    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size,
-                                      parameter_broadcast=True, mirror_mean=True)
-    model = Model(train_net, optimizer=opt, metrics=None, loss_scale_manager=loss_scale_manager)
+    if args.platform == "Ascend":
+        model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager,
+                      metrics={'acc'}, amp_level="O3")
+    else:
+        auto_mixed_precision(network)
+        model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager, metrics={'acc'})

     # checkpoint save
     progress_cb = ProgressMonitor(args)
     callbacks = [progress_cb,]
     if args.rank_save_ckpt_flag:
-        ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
-        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval,
-                                       keep_checkpoint_max=ckpt_max_num)
+        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * args.steps_per_epoch,
+                                       keep_checkpoint_max=args.ckpt_save_max)
         ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                   directory=args.outputs_dir,
                                   prefix='{}'.format(args.rank))
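With this change the checkpoint cadence is expressed in epochs and the retention count comes straight from the config. Illustrative arithmetic, with an assumed `steps_per_epoch`:

```python
# All numbers are hypothetical, for illustration only.
steps_per_epoch = 1250          # depends on dataset size, batch size, device count
ckpt_interval = 5               # now "every 5 epochs" (was a raw step count)
ckpt_save_max = 5               # at most 5 checkpoint files kept on disk

save_checkpoint_steps = ckpt_interval * steps_per_epoch  # -> 6250 steps per checkpoint
```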