Merge pull request !5214 from hanhuifeng/yolov3_gpu
@@ -53,8 +53,8 @@ Dataset used: [COCO2014](https://cocodataset.org/#download)
# [Environment Requirements](#contents)
- Hardware(Ascend)
    - Prepare hardware environment with Ascend processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Hardware(Ascend/GPU)
    - Prepare hardware environment with an Ascend or GPU processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
    - [MindSpore](http://10.90.67.50/mindspore/archive/20200506/OpenSource/me_vm_x86/)
- For more information, please check the resources below:
@@ -65,7 +65,7 @@ Dataset used: [COCO2014](https://cocodataset.org/#download)
# [Quick Start](#contents)
After installing MindSpore via the official website, you can start training and evaluation in Ascend as follows:
After installing MindSpore via the official website, you can start training and evaluation as follows. If running on GPU, please add `--device_target=GPU` to the python command or use the "_gpu" shell script ("xxx_gpu.sh").
```
# The darknet53_backbone.ckpt in the following script is obtained by training darknet53 as described in the paper.
@@ -87,9 +87,12 @@ python train.py \
# standalone training example (1p) by shell script
sh run_standalone_train.sh dataset/coco2014 darknet53_backbone.ckpt

# distributed training example (8p) by shell script
# For Ascend devices, distributed training example (8p) by shell script
sh run_distribute_train.sh dataset/coco2014 darknet53_backbone.ckpt rank_table_8p.json

# For GPU devices, distributed training example (8p) by shell script
sh run_distribute_train_gpu.sh dataset/coco2014 darknet53_backbone.ckpt

# run evaluation by python command
python eval.py \
    --data_dir=./dataset/coco2014 \
@@ -113,6 +116,9 @@ sh run_eval.sh dataset/coco2014/ checkpoint/0-319_102400.ckpt
├─run_standalone_train.sh         # launch standalone training(1p) in ascend
├─run_distribute_train.sh         # launch distributed training(8p) in ascend
└─run_eval.sh                     # launch evaluating in ascend
├─run_standalone_train_gpu.sh     # launch standalone training(1p) in gpu
├─run_distribute_train_gpu.sh     # launch distributed training(8p) in gpu
└─run_eval_gpu.sh                 # launch evaluating in gpu
├─src
  ├─__init__.py                   # python init file
  ├─config.py                     # parameter configuration
@@ -138,6 +144,7 @@ Major parameters in train.py are as follows.
optional arguments:
  -h, --help            show this help message and exit
  --device_target       device where the code will be implemented: "Ascend" | "GPU", default is "Ascend"
  --data_dir DATA_DIR   Train dataset directory.
  --per_batch_size PER_BATCH_SIZE
                        Batch size for Training. Default: 32.
@@ -212,7 +219,7 @@ python train.py \
    --lr_scheduler=cosine_annealing > log.txt 2>&1 &
```
The python command above will run in the background, you can view the results through the file `log.txt`.
The python command above will run in the background; you can view the results through the file `log.txt`. If running on GPU, please add `--device_target=GPU` to the python command.
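
For example, a minimal sketch of the standalone training command on GPU (editor's illustration; the dataset and checkpoint paths are placeholders, and the flags mirror those used in run_standalone_train_gpu.sh):
```
python train.py \
    --data_dir=./dataset/coco2014 \
    --pretrained_backbone=darknet53_backbone.ckpt \
    --device_target=GPU \
    --is_distributed=0 \
    --lr=0.1 \
    --T_max=320 \
    --max_epoch=320 \
    --warmup_epochs=4 \
    --training_shape=416 \
    --lr_scheduler=cosine_annealing > log.txt 2>&1 &
```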
After training, you'll get some checkpoint files under the outputs folder by default. The loss values will look as follows:
@@ -228,9 +235,14 @@ The model checkpoint will be saved in the outputs directory.
### Distributed Training
For Ascend devices, run the distributed training example (8p) by shell script:
```
sh run_distribute_train.sh dataset/coco2014 darknet53_backbone.ckpt rank_table_8p.json
```
For GPU devices, run the distributed training example (8p) by shell script:
```
sh run_distribute_train_gpu.sh dataset/coco2014 darknet53_backbone.ckpt
```
The above shell script will run distributed training in the background. You can view the results through the file `train_parallel[X]/log.txt`. The loss values will look as follows:
@@ -254,7 +266,7 @@ epoch[319], iter[102300], loss:35.430038, 423.49 imgs/sec, lr:2.409552052995423e
### Evaluation
Before running the command below.
Before running the command below, note that if you are running on GPU, you should add `--device_target=GPU` to the python command or use the "_gpu" shell script ("xxx_gpu.sh").
```
python eval.py \
@@ -35,9 +35,6 @@ from src.logger import get_logger
from src.yolo_dataset import create_yolo_dataset
from src.config import ConfigYOLOV3DarkNet53

devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=devid)


class Redirct:
    def __init__(self):
@@ -208,6 +205,10 @@ def parse_args():
    """Parse arguments."""
    parser = argparse.ArgumentParser('mindspore coco testing')
    # device related
    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
                        help='device where the code will be implemented. (Default: Ascend)')
    # dataset related
    parser.add_argument('--data_dir', type=str, default='', help='train data dir')
    parser.add_argument('--per_batch_size', default=1, type=int, help='batch size for per gpu')
@@ -243,10 +244,13 @@ def test():
    start_time = time.time()
    args = parse_args()
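    # Fall back to device 0 when the DEVICE_ID environment variable is not set
    # (typical for single-device GPU runs).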
    devid = int(os.getenv('DEVICE_ID')) if os.getenv('DEVICE_ID') else 0
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=True, device_id=devid)

    # logger
    args.outputs_dir = os.path.join(args.log_path,
                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    rank_id = int(os.environ.get('RANK_ID'))
    rank_id = int(os.environ.get('RANK_ID')) if os.environ.get('RANK_ID') else 0
    args.logger = get_logger(args.outputs_dir, rank_id)

    context.reset_auto_parallel_context()
@@ -0,0 +1,67 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 2 ]
then
    echo "Usage: sh run_distribute_train_gpu.sh [DATASET_PATH] [PRETRAINED_BACKBONE]"
    exit 1
fi
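
# Resolve a path argument to an absolute path (relative paths are resolved against the current directory).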
get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

DATASET_PATH=$(get_real_path $1)
PRETRAINED_BACKBONE=$(get_real_path $2)
echo $DATASET_PATH
echo $PRETRAINED_BACKBONE

if [ ! -d $DATASET_PATH ]
then
    echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
    exit 1
fi

if [ ! -f $PRETRAINED_BACKBONE ]
then
    echo "error: PRETRAINED_BACKBONE=$PRETRAINED_BACKBONE is not a file"
    exit 1
fi

export DEVICE_NUM=8

rm -rf ./train_parallel
mkdir ./train_parallel
cp ../*.py ./train_parallel
cp -r ../src ./train_parallel
cd ./train_parallel || exit
env > env.log
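# Launch DEVICE_NUM training processes with OpenMPI; each rank initializes the
# NCCL backend in train.py and trains on its own GPU.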
mpirun --allow-run-as-root -n ${DEVICE_NUM} python train.py \
    --data_dir=$DATASET_PATH \
    --pretrained_backbone=$PRETRAINED_BACKBONE \
    --device_target=GPU \
    --is_distributed=1 \
    --lr=0.1 \
    --T_max=320 \
    --max_epoch=320 \
    --warmup_epochs=4 \
    --training_shape=416 \
    --lr_scheduler=cosine_annealing > log.txt 2>&1 &
cd ..
@@ -0,0 +1,67 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 2 ]
then
    echo "Usage: sh run_eval_gpu.sh [DATASET_PATH] [CHECKPOINT_PATH]"
    exit 1
fi
get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

DATASET_PATH=$(get_real_path $1)
CHECKPOINT_PATH=$(get_real_path $2)
echo $DATASET_PATH
echo $CHECKPOINT_PATH

if [ ! -d $DATASET_PATH ]
then
    echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
    exit 1
fi

if [ ! -f $CHECKPOINT_PATH ]
then
    echo "error: CHECKPOINT_PATH=$CHECKPOINT_PATH is not a file"
    exit 1
fi
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0

if [ -d "eval" ];
then
    rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start inferring for device $DEVICE_ID"
python eval.py \
    --device_target="GPU" \
    --data_dir=$DATASET_PATH \
    --pretrained=$CHECKPOINT_PATH \
    --testing_shape=416 > log.txt 2>&1 &
cd ..
@@ -0,0 +1,75 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 2 ]
then
    echo "Usage: sh run_standalone_train_gpu.sh [DATASET_PATH] [PRETRAINED_BACKBONE]"
    exit 1
fi
get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

DATASET_PATH=$(get_real_path $1)
echo $DATASET_PATH
PRETRAINED_BACKBONE=$(get_real_path $2)
echo $PRETRAINED_BACKBONE

if [ ! -d $DATASET_PATH ]
then
    echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
    exit 1
fi

if [ ! -f $PRETRAINED_BACKBONE ]
then
    echo "error: PRETRAINED_BACKBONE=$PRETRAINED_BACKBONE is not a file"
    exit 1
fi
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1

if [ -d "train" ];
then
    rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp -r ../src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
python train.py \
    --device_target="GPU" \
    --data_dir=$DATASET_PATH \
    --pretrained_backbone=$PRETRAINED_BACKBONE \
    --is_distributed=0 \
    --lr=0.1 \
    --T_max=320 \
    --max_epoch=320 \
    --warmup_epochs=4 \
    --training_shape=416 \
    --lr_scheduler=cosine_annealing > log.txt 2>&1 &
cd ..
@@ -465,6 +465,11 @@ class MultiScaleTrans:
        self.seed_list = self.generate_seed_list(seed_num=self.seed_num)
        self.resize_count_num = int(np.ceil(self.dataset_size / self.resize_rate))
        self.device_num = device_num
        self.anchor_scales = config.anchor_scales
        self.num_classes = config.num_classes
        self.max_box = config.max_box
        self.label_smooth = config.label_smooth
        self.label_smooth_factor = config.label_smooth_factor

    def generate_seed_list(self, init_seed=1234, seed_num=int(1e6), seed_range=(1, 1000)):
        seed_list = []
@@ -474,13 +479,20 @@ class MultiScaleTrans:
            seed_list.append(seed)
        return seed_list

    def __call__(self, imgs, annos, batchInfo):
    def __call__(self, imgs, annos, x1, x2, x3, x4, x5, x6, batchInfo):
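        # x1..x6 are placeholder columns emitted by COCOYoloDataset so that the
        # per_batch_map signature matches the dataset's eight columns; the real
        # bbox/gt_box arrays are computed below and returned in their place.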
        epoch_num = batchInfo.get_epoch_num()
        size_idx = int(batchInfo.get_batch_num() / self.resize_rate)
        seed_key = self.seed_list[(epoch_num * self.resize_count_num + size_idx) % self.seed_num]
        ret_imgs = []
        ret_annos = []

        bbox1 = []
        bbox2 = []
        bbox3 = []
        gt1 = []
        gt2 = []
        gt3 = []

        if self.size_dict.get(seed_key, None) is None:
            random.seed(seed_key)
            new_size = random.choice(self.config.multi_scale)
@@ -491,8 +503,19 @@ class MultiScaleTrans:
        for img, anno in zip(imgs, annos):
            img, anno = preprocess_fn(img, anno, self.config, input_size, self.device_num)
            ret_imgs.append(img.transpose(2, 0, 1).copy())
            ret_annos.append(anno)
        return np.array(ret_imgs), np.array(ret_annos)
            bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \
                _preprocess_true_boxes(true_boxes=anno, anchors=self.anchor_scales, in_shape=img.shape[0:2],
                                       num_classes=self.num_classes, max_boxes=self.max_box,
                                       label_smooth=self.label_smooth, label_smooth_factor=self.label_smooth_factor)
            bbox1.append(bbox_true_1)
            bbox2.append(bbox_true_2)
            bbox3.append(bbox_true_3)
            gt1.append(gt_box1)
            gt2.append(gt_box2)
            gt3.append(gt_box3)
            ret_annos.append(0)
        return np.array(ret_imgs), np.array(ret_annos), np.array(bbox1), np.array(bbox2), np.array(bbox3), \
               np.array(gt1), np.array(gt2), np.array(gt3)


def thread_batch_preprocess_true_box(annos, config, input_shape, result_index, batch_bbox_true_1, batch_bbox_true_2,
@@ -15,6 +15,9 @@
"""Util class or function."""
from mindspore.train.serialization import load_checkpoint
import mindspore.nn as nn
import mindspore.common.dtype as mstype

from .yolo import YoloLossBlock


class AverageMeter:
@@ -175,3 +178,10 @@ class ShapeRecord:
        for key in self.shape_record:
            rate = self.shape_record[key] / float(self.shape_record['total'])
            logger.info('shape {}: {:.2f}%'.format(key, rate*100))

def keep_loss_fp32(network):
    """Keep the loss computation of the network in float32."""
    for _, cell in network.cells_and_names():
        if isinstance(cell, (YoloLossBlock,)):
            cell.to_float(mstype.float32)
@@ -15,6 +15,7 @@
"""YOLOV3 dataset."""
import os
import multiprocessing

from PIL import Image
from pycocotools.coco import COCO

import mindspore.dataset as de
@@ -126,7 +127,7 @@ class COCOYoloDataset:
            tmp.append(int(label))
            # tmp [x_min y_min x_max y_max, label]
            out_target.append(tmp)
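        # The six empty lists are placeholder columns ("bbox1..3", "gt_box1..3");
        # they give the dataset an eight-column schema matching what MultiScaleTrans
        # fills in during per_batch_map.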
        return img, out_target
        return img, out_target, [], [], [], [], [], []

    def __len__(self):
        return len(self.img_ids)
@@ -155,20 +156,22 @@ def create_yolo_dataset(image_dir, anno_path, batch_size, max_epoch, device_num,
    hwc_to_chw = CV.HWC2CHW()
    config.dataset_size = len(yolo_dataset)
    num_parallel_workers1 = int(64 / device_num)
    num_parallel_workers2 = int(16 / device_num)
    cores = multiprocessing.cpu_count()
    num_parallel_workers = int(cores / device_num)
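    # Size the dataset worker pool from the host CPU count so the available cores
    # are shared across devices (capped below to avoid oversubscription).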
    if is_training:
        multi_scale_trans = MultiScaleTrans(config, device_num)
        dataset_column_names = ["image", "annotation", "bbox1", "bbox2", "bbox3",
                                "gt_box1", "gt_box2", "gt_box3"]
        if device_num != 8:
            ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "annotation"],
                                     num_parallel_workers=num_parallel_workers1,
            ds = de.GeneratorDataset(yolo_dataset, column_names=dataset_column_names,
                                     num_parallel_workers=min(32, num_parallel_workers),
                                     sampler=distributed_sampler)
            ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=['image', 'annotation'],
                          num_parallel_workers=num_parallel_workers2, drop_remainder=True)
            ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=dataset_column_names,
                          num_parallel_workers=min(32, num_parallel_workers), drop_remainder=True)
        else:
            ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "annotation"], sampler=distributed_sampler)
            ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=['image', 'annotation'],
                          num_parallel_workers=8, drop_remainder=True)
            ds = de.GeneratorDataset(yolo_dataset, column_names=dataset_column_names, sampler=distributed_sampler)
            ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=dataset_column_names,
                          num_parallel_workers=min(8, num_parallel_workers), drop_remainder=True)
    else:
        ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "img_id"],
                                 sampler=distributed_sampler)
@@ -28,6 +28,8 @@ from mindspore.train.callback import ModelCheckpoint, RunContext
from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig
import mindspore as ms
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore import amp
from mindspore.train.loss_scale_manager import FixedLossScaleManager

from src.yolo import YOLOV3DarkNet53, YoloWithLossCell, TrainingWrapper
from src.logger import get_logger
@@ -37,13 +39,7 @@ from src.lr_scheduler import warmup_step_lr, warmup_cosine_annealing_lr, \
from src.yolo_dataset import create_yolo_dataset
from src.initializer import default_recurisive_init
from src.config import ConfigYOLOV3DarkNet53
from src.transforms import batch_preprocess_true_box, batch_preprocess_true_box_single
from src.util import ShapeRecord

devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
                    device_target="Ascend", save_graphs=True, device_id=devid)
from src.util import keep_loss_fp32


class BuildTrainNetwork(nn.Cell):
@@ -62,6 +58,10 @@ def parse_args():
    """Parse train arguments."""
    parser = argparse.ArgumentParser('mindspore coco training')
    # device related
    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
                        help='device where the code will be implemented. (Default: Ascend)')
    # dataset related
    parser.add_argument('--data_dir', type=str, help='Train dataset directory.')
    parser.add_argument('--per_batch_size', default=32, type=int, help='Batch size for Training. Default: 32.')
@@ -136,9 +136,16 @@ def train():
    """Train function."""
    args = parse_args()

    devid = int(os.getenv('DEVICE_ID')) if os.getenv('DEVICE_ID') else 0
    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
                        device_target=args.device_target, save_graphs=True, device_id=devid)

    # init distributed
    if args.is_distributed:
        init()
        if args.device_target == "Ascend":
            init()
        else:
            init("nccl")
        args.rank = get_rank()
        args.group_size = get_group_size()
@@ -259,9 +266,19 @@ def train():
                   momentum=args.momentum,
                   weight_decay=args.weight_decay,
                   loss_scale=args.loss_scale)

    network = TrainingWrapper(network, opt)
    network.set_train()
    enable_amp = False
    is_gpu = context.get_context("device_target") == "GPU"
    if is_gpu:
        enable_amp = True
    if enable_amp:
        loss_scale_value = 1.0
        loss_scale = FixedLossScaleManager(loss_scale_value, drop_overflow_update=False)
        network = amp.build_train_network(network, optimizer=opt, loss_scale_manager=loss_scale,
                                          level="O2", keep_batchnorm_fp32=True)
        keep_loss_fp32(network)
    else:
        network = TrainingWrapper(network, opt)
        network.set_train()

    if args.rank_save_ckpt_flag:
        # checkpoint save
@@ -282,28 +299,19 @@ def train():
    t_end = time.time()
    data_loader = ds.create_dict_iterator()

    shape_record = ShapeRecord()
    for i, data in enumerate(data_loader):
        images = data["image"]
        input_shape = images.shape[2:4]
        args.logger.info('iter[{}], shape{}'.format(i, input_shape[0]))
        shape_record.set(input_shape)
        images = Tensor(images)

        annos = data["annotation"]
        if args.group_size == 1:
            batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, batch_gt_box2 = \
                batch_preprocess_true_box(annos, config, input_shape)
        else:
            batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, batch_gt_box2 = \
                batch_preprocess_true_box_single(annos, config, input_shape)

        batch_y_true_0 = Tensor(batch_y_true_0)
        batch_y_true_1 = Tensor(batch_y_true_1)
        batch_y_true_2 = Tensor(batch_y_true_2)
        batch_gt_box0 = Tensor(batch_gt_box0)
        batch_gt_box1 = Tensor(batch_gt_box1)
        batch_gt_box2 = Tensor(batch_gt_box2)
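        # The true-box and gt-box tensors are now read directly from the dataset
        # columns that MultiScaleTrans fills inside the data pipeline, replacing
        # the per-iteration host-side preprocessing removed above.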
        batch_y_true_0 = Tensor(data['bbox1'])
        batch_y_true_1 = Tensor(data['bbox2'])
        batch_y_true_2 = Tensor(data['bbox3'])
        batch_gt_box0 = Tensor(data['gt_box1'])
        batch_gt_box1 = Tensor(data['gt_box2'])
        batch_gt_box2 = Tensor(data['gt_box3'])

        input_shape = Tensor(tuple(input_shape[::-1]), ms.float32)
        loss = network(images, batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1,