Merge pull request !4319 from chengxb7532/master (tags/v0.7.0-beta)
@@ -29,7 +29,7 @@ from mindspore._checkparam import Rel
 import mindspore.context as context
 from .normalization import BatchNorm2d, BatchNorm1d
-from .activation import get_activation, ReLU
+from .activation import get_activation, ReLU, LeakyReLU
 from ..cell import Cell
 from . import conv, basic
 from ..._checkparam import ParamValidator as validator
@@ -115,7 +115,11 @@ class Conv2dBnAct(Cell):
                  weight_init='normal',
                  bias_init='zeros',
                  has_bn=False,
-                 activation=None):
+                 momentum=0.9,
+                 eps=1e-5,
+                 activation=None,
+                 alpha=0.2,
+                 after_fake=True):
         super(Conv2dBnAct, self).__init__()
         if context.get_context('device_target') == "Ascend" and group > 1:
@@ -145,9 +149,13 @@ class Conv2dBnAct(Cell):
         self.has_bn = validator.check_bool("has_bn", has_bn)
         self.has_act = activation is not None
+        self.after_fake = after_fake
         if has_bn:
-            self.batchnorm = BatchNorm2d(out_channels)
-        self.activation = get_activation(activation)
+            self.batchnorm = BatchNorm2d(out_channels, eps, momentum)
+        if activation == "leakyrelu":
+            self.activation = LeakyReLU(alpha)
+        else:
+            self.activation = get_activation(activation)

     def construct(self, x):
         x = self.conv(x)
@@ -244,7 +244,7 @@ class ConvertToQuantNetwork:
         subcell.conv = conv_inner
         if subcell.has_act and subcell.activation is not None:
             subcell.activation = self._convert_activation(subcell.activation)
-        else:
+        elif subcell.after_fake:
             subcell.has_act = True
             subcell.activation = _AddFakeQuantAfterSubCell(F.identity,
                                                            num_bits=self.act_bits,
@@ -274,7 +274,7 @@ class ConvertToQuantNetwork:
         subcell.dense = dense_inner
         if subcell.has_act and subcell.activation is not None:
             subcell.activation = self._convert_activation(subcell.activation)
-        else:
+        elif subcell.after_fake:
             subcell.has_act = True
             subcell.activation = _AddFakeQuantAfterSubCell(F.identity,
                                                            num_bits=self.act_bits,
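Taken together, these hunks let `Conv2dBnAct` forward BatchNorm hyper-parameters (`eps`, `momentum`), resolve a `'leakyrelu'` activation to `LeakyReLU(alpha)`, and expose an `after_fake` switch that `ConvertToQuantNetwork` consults before appending a fake-quant node to cells without an activation. A minimal usage sketch; the argument values are illustrative, not taken from this PR:

```python
from mindspore import nn

# Hypothetical cell exercising the new keyword arguments added above.
block = nn.Conv2dBnAct(in_channels=32, out_channels=64, kernel_size=3,
                       has_bn=True,
                       eps=1e-5, momentum=0.9,   # forwarded to BatchNorm2d
                       activation='leakyrelu',   # resolved to LeakyReLU(alpha)
                       alpha=0.2,
                       after_fake=True)          # let quant conversion add FakeQuant
```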
@@ -0,0 +1,143 @@
# YOLOV3-DarkNet53-Quant Example

## Description

This is an example of training YOLOV3-DarkNet53-Quant with the COCO2014 dataset in MindSpore.

## Requirements

- Install [MindSpore](https://www.mindspore.cn/install/en).
- Download the COCO2014 dataset.

> Unzip the COCO2014 dataset to any path you want; the folder should include the train and eval datasets as follows:

```
.
└─dataset
    ├─train2014
    ├─val2014
    └─annotations
```
## Structure

```shell
.
└─yolov3_darknet53_quant
    ├─README.md
    ├─scripts
    │   ├─run_standalone_train.sh    # launch standalone training (1p)
    │   ├─run_distribute_train.sh    # launch distributed training (8p)
    │   └─run_eval.sh                # launch evaluating
    ├─src
    │   ├─__init__.py                # python init file
    │   ├─config.py                  # parameter configuration
    │   ├─darknet.py                 # backbone of network
    │   ├─distributed_sampler.py     # iterator of dataset
    │   ├─initializer.py             # initializer of parameters
    │   ├─logger.py                  # log function
    │   ├─loss.py                    # loss function
    │   ├─lr_scheduler.py            # generate learning rate
    │   ├─transforms.py              # preprocess data
    │   ├─util.py                    # util function
    │   ├─yolo.py                    # yolov3 network
    │   └─yolo_dataset.py            # create dataset for YOLOV3
    ├─eval.py                        # eval net
    └─train.py                       # train net
```
## Running the example

### Train

#### Usage

```
# distributed training
sh run_distribute_train.sh [DATASET_PATH] [RESUME_YOLOV3] [MINDSPORE_HCCL_CONFIG_PATH]

# standalone training
sh run_standalone_train.sh [DATASET_PATH] [RESUME_YOLOV3]
```

#### Launch

```bash
# distributed training example (8p)
sh run_distribute_train.sh dataset/coco2014 yolov3_darknet_noquant_ckpt/0-320_102400.ckpt rank_table_8p.json

# standalone training example (1p)
sh run_standalone_train.sh dataset/coco2014 yolov3_darknet_noquant_ckpt/0-320_102400.ckpt
```

> For details about rank_table.json, refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html).

#### Result

Training results are stored in the scripts path, in a folder whose name begins with "train" or "train_parallel". You can find the checkpoint files together with results like the following in log.txt.
```
# distribute training result (8p)
epoch[0], iter[0], loss:483.341675, 0.31 imgs/sec, lr:0.0
epoch[0], iter[100], loss:55.690952, 3.46 imgs/sec, lr:0.0
epoch[0], iter[200], loss:54.045728, 126.54 imgs/sec, lr:0.0
epoch[0], iter[300], loss:48.771608, 133.04 imgs/sec, lr:0.0
epoch[0], iter[400], loss:48.486769, 139.69 imgs/sec, lr:0.0
epoch[0], iter[500], loss:48.649275, 143.29 imgs/sec, lr:0.0
epoch[0], iter[600], loss:44.731309, 144.03 imgs/sec, lr:0.0
epoch[1], iter[700], loss:43.037023, 136.08 imgs/sec, lr:0.0
epoch[1], iter[800], loss:41.514788, 132.94 imgs/sec, lr:0.0
…
epoch[133], iter[85700], loss:33.326716, 136.14 imgs/sec, lr:6.497331924038008e-06
epoch[133], iter[85800], loss:34.968744, 136.76 imgs/sec, lr:6.497331924038008e-06
epoch[134], iter[85900], loss:35.868543, 137.08 imgs/sec, lr:1.6245529650404933e-06
epoch[134], iter[86000], loss:35.740817, 139.49 imgs/sec, lr:1.6245529650404933e-06
epoch[134], iter[86100], loss:34.600463, 141.47 imgs/sec, lr:1.6245529650404933e-06
epoch[134], iter[86200], loss:36.641916, 137.91 imgs/sec, lr:1.6245529650404933e-06
epoch[134], iter[86300], loss:32.819769, 138.17 imgs/sec, lr:1.6245529650404933e-06
epoch[134], iter[86400], loss:35.603033, 142.23 imgs/sec, lr:1.6245529650404933e-06
epoch[134], iter[86500], loss:34.303755, 145.18 imgs/sec, lr:1.6245529650404933e-06
...
```
### Infer

#### Usage

```
# infer
sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]
```

#### Launch

```bash
# infer with checkpoint
sh run_eval.sh dataset/coco2014/ checkpoint/0-135.ckpt
```

> The checkpoint file can be produced during the training process.

#### Result

Inference results are stored in the scripts path, in a folder named "eval". You can find results like the following in log.txt.
```
=============coco eval result=========
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.310
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.531
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.322
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.130
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.326
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.425
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.260
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.402
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.429
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.232
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.450
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.558
```
@@ -0,0 +1,336 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""YoloV3 eval."""
import os
import argparse
import datetime
import time
import sys
from collections import defaultdict

import numpy as np
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

from mindspore import Tensor
from mindspore.train import ParallelMode
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore as ms
from mindspore.train.quant import quant

from src.yolo import YOLOV3DarkNet53
from src.logger import get_logger
from src.yolo_dataset import create_yolo_dataset
from src.config import ConfigYOLOV3DarkNet53

devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=devid)
class Redirct:
    def __init__(self):
        self.content = ""

    def write(self, content):
        self.content += content

    def flush(self):
        self.content = ""
class DetectionEngine:
    """Detection engine."""

    def __init__(self, args):
        self.ignore_threshold = args.ignore_threshold
        self.labels = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
                       'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat',
                       'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack',
                       'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
                       'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
                       'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
                       'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
                       'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
                       'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book',
                       'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
        self.num_classes = len(self.labels)
        self.results = {}
        self.file_path = ''
        self.save_prefix = args.outputs_dir
        self.annFile = args.annFile
        self._coco = COCO(self.annFile)
        self._img_ids = list(sorted(self._coco.imgs.keys()))
        self.det_boxes = []
        self.nms_thresh = args.nms_thresh
        self.coco_catIds = self._coco.getCatIds()
    def do_nms_for_results(self):
        """Get result boxes."""
        for img_id in self.results:
            for clsi in self.results[img_id]:
                dets = self.results[img_id][clsi]
                dets = np.array(dets)
                keep_index = self._nms(dets, self.nms_thresh)

                keep_box = [{'image_id': int(img_id),
                             'category_id': int(clsi),
                             'bbox': list(dets[i][:4].astype(float)),
                             'score': dets[i][4].astype(float)}
                            for i in keep_index]
                self.det_boxes.extend(keep_box)
    def _nms(self, dets, thresh):
        """Calculate NMS."""
        # convert xywh -> xmin ymin xmax ymax
        x1 = dets[:, 0]
        y1 = dets[:, 1]
        x2 = x1 + dets[:, 2]
        y2 = y1 + dets[:, 3]
        scores = dets[:, 4]

        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
        order = scores.argsort()[::-1]

        keep = []
        while order.size > 0:
            i = order[0]
            keep.append(i)
            xx1 = np.maximum(x1[i], x1[order[1:]])
            yy1 = np.maximum(y1[i], y1[order[1:]])
            xx2 = np.minimum(x2[i], x2[order[1:]])
            yy2 = np.minimum(y2[i], y2[order[1:]])

            w = np.maximum(0.0, xx2 - xx1 + 1)
            h = np.maximum(0.0, yy2 - yy1 + 1)
            inter = w * h
            ovr = inter / (areas[i] + areas[order[1:]] - inter)

            inds = np.where(ovr <= thresh)[0]
            order = order[inds + 1]
        return keep
    def write_result(self):
        """Save result to file."""
        import json
        t = datetime.datetime.now().strftime('_%Y_%m_%d_%H_%M_%S')
        try:
            self.file_path = self.save_prefix + '/predict' + t + '.json'
            f = open(self.file_path, 'w')
            json.dump(self.det_boxes, f)
        except IOError as e:
            raise RuntimeError("Unable to open json file to dump. What(): {}".format(str(e)))
        else:
            f.close()
            return self.file_path
    def get_eval_result(self):
        """Get eval result."""
        cocoGt = COCO(self.annFile)
        cocoDt = cocoGt.loadRes(self.file_path)
        cocoEval = COCOeval(cocoGt, cocoDt, 'bbox')
        cocoEval.evaluate()
        cocoEval.accumulate()
        rdct = Redirct()
        stdout = sys.stdout
        sys.stdout = rdct
        cocoEval.summarize()
        sys.stdout = stdout
        return rdct.content
    def detect(self, outputs, batch, image_shape, image_id):
        """Detect boxes."""
        outputs_num = len(outputs)
        # each output has shape [batch, gx, gy, anchors, 5 + 80], e.g. 32, 52, 52, 3, 85
        for batch_id in range(batch):
            for out_id in range(outputs_num):
                # 32, 52, 52, 3, 85
                out_item = outputs[out_id]
                # 52, 52, 3, 85
                out_item_single = out_item[batch_id, :]
                # get number of items in one head, [B, gx, gy, anchors, 5+80]
                dimensions = out_item_single.shape[:-1]
                out_num = 1
                for d in dimensions:
                    out_num *= d
                ori_w, ori_h = image_shape[batch_id]
                img_id = int(image_id[batch_id])
                x = out_item_single[..., 0] * ori_w
                y = out_item_single[..., 1] * ori_h
                w = out_item_single[..., 2] * ori_w
                h = out_item_single[..., 3] * ori_h

                conf = out_item_single[..., 4:5]
                cls_emb = out_item_single[..., 5:]

                cls_argmax = np.expand_dims(np.argmax(cls_emb, axis=-1), axis=-1)
                x = x.reshape(-1)
                y = y.reshape(-1)
                w = w.reshape(-1)
                h = h.reshape(-1)
                cls_emb = cls_emb.reshape(-1, 80)
                conf = conf.reshape(-1)
                cls_argmax = cls_argmax.reshape(-1)

                x_top_left = x - w / 2.
                y_top_left = y - h / 2.
                # create an all-False mask, then mark the argmax class of each box
                flag = np.random.random(cls_emb.shape) > sys.maxsize
                for i in range(flag.shape[0]):
                    c = cls_argmax[i]
                    flag[i, c] = True
                confidence = cls_emb[flag] * conf
                for x_lefti, y_lefti, wi, hi, confi, clsi in zip(x_top_left, y_top_left, w, h, confidence, cls_argmax):
                    if confi < self.ignore_threshold:
                        continue
                    if img_id not in self.results:
                        self.results[img_id] = defaultdict(list)
                    x_lefti = max(0, x_lefti)
                    y_lefti = max(0, y_lefti)
                    wi = min(wi, ori_w)
                    hi = min(hi, ori_h)
                    # transform catId to match coco
                    coco_clsi = self.coco_catIds[clsi]
                    self.results[img_id][coco_clsi].append([x_lefti, y_lefti, wi, hi, confi])
def parse_args():
    """Parse arguments."""
    parser = argparse.ArgumentParser('mindspore coco testing')

    # dataset related
    parser.add_argument('--data_dir', type=str, default='', help='train data dir')
    parser.add_argument('--per_batch_size', default=1, type=int, help='batch size for per gpu')

    # network related
    parser.add_argument('--pretrained', default='', type=str, help='model_path, local pretrained model to load')

    # logging related
    parser.add_argument('--log_path', type=str, default='outputs/', help='checkpoint save location')

    # detect related
    parser.add_argument('--nms_thresh', type=float, default=0.5, help='threshold for NMS')
    parser.add_argument('--annFile', type=str, default='', help='path to annotation')
    parser.add_argument('--testing_shape', type=str, default='', help='shape for test')
    parser.add_argument('--ignore_threshold', type=float, default=0.001, help='threshold to throw low quality boxes')

    args, _ = parser.parse_known_args()

    args.data_root = os.path.join(args.data_dir, 'val2014')
    args.annFile = os.path.join(args.data_dir, 'annotations/instances_val2014.json')

    return args
def convert_testing_shape(args):
    """Convert testing shape to list."""
    testing_shape = [int(args.testing_shape), int(args.testing_shape)]
    return testing_shape
def test():
    """The function of eval."""
    start_time = time.time()
    args = parse_args()

    # logger
    args.outputs_dir = os.path.join(args.log_path,
                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    rank_id = int(os.environ.get('RANK_ID'))
    args.logger = get_logger(args.outputs_dir, rank_id)

    context.reset_auto_parallel_context()
    parallel_mode = ParallelMode.STAND_ALONE
    context.set_auto_parallel_context(parallel_mode=parallel_mode, mirror_mean=True, device_num=1)

    args.logger.info('Creating Network....')
    network = YOLOV3DarkNet53(is_training=False)

    config = ConfigYOLOV3DarkNet53()
    if args.testing_shape:
        config.test_img_shape = convert_testing_shape(args)

    # convert the fusion network to a quantization aware network
    if config.quantization_aware:
        network = quant.convert_quant_network(network,
                                              bn_fold=True,
                                              per_channel=[True, False],
                                              symmetric=[True, False])

    args.logger.info(args.pretrained)
    if os.path.isfile(args.pretrained):
        param_dict = load_checkpoint(args.pretrained)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                continue
            elif key.startswith('yolo_network.'):
                param_dict_new[key[13:]] = values
            else:
                param_dict_new[key] = values
        load_param_into_net(network, param_dict_new)
        args.logger.info('load_model {} success'.format(args.pretrained))
    else:
        args.logger.info('{} not exists or not a pre-trained file'.format(args.pretrained))
        raise FileNotFoundError('{} not exists or not a pre-trained file'.format(args.pretrained))
    data_root = args.data_root
    ann_file = args.annFile

    ds, data_size = create_yolo_dataset(data_root, ann_file, is_training=False, batch_size=args.per_batch_size,
                                        max_epoch=1, device_num=1, rank=rank_id, shuffle=False,
                                        config=config)

    args.logger.info('testing shape : {}'.format(config.test_img_shape))
    args.logger.info('total {} images to eval'.format(data_size))

    network.set_train(False)

    # init detection engine
    detection = DetectionEngine(args)

    input_shape = Tensor(tuple(config.test_img_shape), ms.float32)
    args.logger.info('Start inference....')
    for i, data in enumerate(ds.create_dict_iterator()):
        image = Tensor(data["image"])

        image_shape = Tensor(data["image_shape"])
        image_id = Tensor(data["img_id"])

        prediction = network(image, input_shape)
        output_big, output_me, output_small = prediction
        output_big = output_big.asnumpy()
        output_me = output_me.asnumpy()
        output_small = output_small.asnumpy()
        image_id = image_id.asnumpy()
        image_shape = image_shape.asnumpy()

        detection.detect([output_small, output_me, output_big], args.per_batch_size, image_shape, image_id)
        if i % 1000 == 0:
            args.logger.info('Processing... {:.2f}% '.format(i * args.per_batch_size / data_size * 100))

    args.logger.info('Calculating mAP...')
    detection.do_nms_for_results()
    result_file_path = detection.write_result()
    args.logger.info('result file path: {}'.format(result_file_path))
    eval_result = detection.get_eval_result()
    cost_time = time.time() - start_time
    args.logger.info('\n=============coco eval result=========\n' + eval_result)
    args.logger.info('testing cost time {:.2f}h'.format(cost_time / 3600.))


if __name__ == "__main__":
    test()
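For reference, the greedy IoU suppression implemented by `DetectionEngine._nms` above can be checked in isolation. The following self-contained sketch reproduces its logic on toy `[x, y, w, h, score]` rows; it is illustrative only and not part of the PR:

```python
import numpy as np

def nms_xywh(dets, thresh=0.5):
    """Greedy NMS over [x, y, w, h, score] rows, mirroring DetectionEngine._nms."""
    x1, y1 = dets[:, 0], dets[:, 1]
    x2, y2 = x1 + dets[:, 2], y1 + dets[:, 3]
    scores = dets[:, 4]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]          # indices by descending score
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])
        inter = np.maximum(0.0, xx2 - xx1 + 1) * np.maximum(0.0, yy2 - yy1 + 1)
        ovr = inter / (areas[i] + areas[order[1:]] - inter)
        order = order[np.where(ovr <= thresh)[0] + 1]  # keep boxes below the IoU threshold
    return keep

dets = np.array([[0., 0., 10., 10., 0.9],     # highest score, kept
                 [1., 1., 10., 10., 0.8],     # IoU ~0.70 with box 0, suppressed
                 [50., 50., 10., 10., 0.7]])  # disjoint, kept
print(nms_xywh(dets))  # [0, 2]
```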
@@ -0,0 +1,83 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 3 ]
then
    echo "Usage: sh run_distribute_train.sh [DATASET_PATH] [RESUME_YOLOV3] [MINDSPORE_HCCL_CONFIG_PATH]"
    exit 1
fi
get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

DATASET_PATH=$(get_real_path $1)
RESUME_YOLOV3=$(get_real_path $2)
MINDSPORE_HCCL_CONFIG_PATH=$(get_real_path $3)

echo $DATASET_PATH
echo $RESUME_YOLOV3
echo $MINDSPORE_HCCL_CONFIG_PATH

if [ ! -d $DATASET_PATH ]
then
    echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
    exit 1
fi
if [ ! -f $RESUME_YOLOV3 ]
then
    echo "error: RESUME_YOLOV3=$RESUME_YOLOV3 is not a file"
    exit 1
fi

if [ ! -f $MINDSPORE_HCCL_CONFIG_PATH ]
then
    echo "error: MINDSPORE_HCCL_CONFIG_PATH=$MINDSPORE_HCCL_CONFIG_PATH is not a file"
    exit 1
fi
export DEVICE_NUM=8
export RANK_SIZE=8
export MINDSPORE_HCCL_CONFIG_PATH=$MINDSPORE_HCCL_CONFIG_PATH

for((i=0; i<${DEVICE_NUM}; i++))
do
    export DEVICE_ID=$i
    export RANK_ID=$i
    rm -rf ./train_parallel$i
    mkdir ./train_parallel$i
    cp ../*.py ./train_parallel$i
    cp -r ../src ./train_parallel$i
    cd ./train_parallel$i || exit
    echo "start training for rank $RANK_ID, device $DEVICE_ID"
    env > env.log
    python train.py \
        --data_dir=$DATASET_PATH \
        --resume_yolov3=$RESUME_YOLOV3 \
        --is_distributed=1 \
        --per_batch_size=16 \
        --lr=0.012 \
        --T_max=135 \
        --max_epoch=135 \
        --warmup_epochs=5 \
        --lr_scheduler=cosine_annealing > log.txt 2>&1 &
    cd ..
done
@@ -0,0 +1,67 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 2 ]
then
    echo "Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

DATASET_PATH=$(get_real_path $1)
CHECKPOINT_PATH=$(get_real_path $2)

echo $DATASET_PATH
echo $CHECKPOINT_PATH
if [ ! -d $DATASET_PATH ]
then
    echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
    exit 1
fi

if [ ! -f $CHECKPOINT_PATH ]
then
    echo "error: CHECKPOINT_PATH=$CHECKPOINT_PATH is not a file"
    exit 1
fi
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0

if [ -d "eval" ];
then
    rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp -r ../src ./eval
cd ./eval || exit

env > env.log
echo "start inferring for device $DEVICE_ID"
python eval.py \
    --data_dir=$DATASET_PATH \
    --pretrained=$CHECKPOINT_PATH \
    --testing_shape=416 > log.txt 2>&1 &
cd ..
@@ -0,0 +1,74 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 2 ]
then
    echo "Usage: sh run_standalone_train.sh [DATASET_PATH] [RESUME_YOLOV3]"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

DATASET_PATH=$(get_real_path $1)
echo $DATASET_PATH

RESUME_YOLOV3=$(get_real_path $2)
echo $RESUME_YOLOV3

if [ ! -d $DATASET_PATH ]
then
    echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
    exit 1
fi
if [ ! -f $RESUME_YOLOV3 ]
then
    echo "error: RESUME_YOLOV3=$RESUME_YOLOV3 is not a file"
    exit 1
fi
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1

if [ -d "train" ];
then
    rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp -r ../src ./train
cd ./train || exit

echo "start training for device $DEVICE_ID"
env > env.log
python train.py \
    --data_dir=$DATASET_PATH \
    --resume_yolov3=$RESUME_YOLOV3 \
    --is_distributed=0 \
    --per_batch_size=16 \
    --lr=0.004 \
    --T_max=135 \
    --max_epoch=135 \
    --warmup_epochs=5 \
    --lr_scheduler=cosine_annealing > log.txt 2>&1 &
cd ..
@@ -0,0 +1,69 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Config parameters for Darknet based yolov3_darknet53 models."""


class ConfigYOLOV3DarkNet53:
    """
    Config parameters for the yolov3_darknet53.

    Examples:
        ConfigYOLOV3DarkNet53()
    """
    # train_param
    # data augmentation related
    hue = 0.1
    saturation = 1.5
    value = 1.5
    jitter = 0.3

    resize_rate = 1
    multi_scale = [[320, 320],
                   [352, 352],
                   [384, 384],
                   [416, 416],
                   [448, 448],
                   [480, 480],
                   [512, 512],
                   [544, 544],
                   [576, 576],
                   [608, 608]
                   ]

    num_classes = 80
    max_box = 50

    backbone_input_shape = [32, 64, 128, 256, 512]
    backbone_shape = [64, 128, 256, 512, 1024]
    backbone_layers = [1, 2, 8, 8, 4]

    # confidence under ignore_threshold means no object when training
    ignore_threshold = 0.7

    # h -> w
    anchor_scales = [(10, 13),
                     (16, 30),
                     (33, 23),
                     (30, 61),
                     (62, 45),
                     (59, 119),
                     (116, 90),
                     (156, 198),
                     (373, 326)]
    out_channel = 255

    quantization_aware = True

    # test_param
    test_img_shape = [416, 416]
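For orientation, a quick illustrative read of the values defined above; the printed comments reflect the defaults in this file:

```python
cfg = ConfigYOLOV3DarkNet53()
print(cfg.num_classes)         # 80
print(cfg.test_img_shape)      # [416, 416]
print(cfg.quantization_aware)  # True -> eval.py converts to a quant-aware network
```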
@@ -0,0 +1,208 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""DarkNet model."""
import mindspore.nn as nn
from mindspore.ops import operations as P


def conv_block(in_channels,
               out_channels,
               kernel_size,
               stride,
               dilation=1):
    """Get a conv2d + batchnorm + relu layer."""
    pad_mode = 'same'
    padding = 0

    return nn.Conv2dBnAct(in_channels, out_channels, kernel_size,
                          stride=stride,
                          pad_mode=pad_mode,
                          padding=padding,
                          dilation=dilation,
                          has_bn=True,
                          momentum=0.1,
                          activation='relu')
class ResidualBlock(nn.Cell):
    """
    DarkNet V1 residual block definition.

    Args:
        in_channels: Integer. Input channel.
        out_channels: Integer. Output channel.

    Returns:
        Tensor, output tensor.

    Examples:
        ResidualBlock(3, 208)
    """
    expansion = 4

    def __init__(self,
                 in_channels,
                 out_channels):
        super(ResidualBlock, self).__init__()
        out_chls = out_channels // 2
        self.conv1 = conv_block(in_channels, out_chls, kernel_size=1, stride=1)
        self.conv2 = conv_block(out_chls, out_channels, kernel_size=3, stride=1)
        self.add = P.TensorAdd()

    def construct(self, x):
        identity = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.add(out, identity)
        return out
class DarkNet(nn.Cell):
    """
    DarkNet V1 network.

    Args:
        block: Cell. Block for network.
        layer_nums: List. Numbers of different layers.
        in_channels: Integer. Input channel.
        out_channels: Integer. Output channel.
        detect: Bool. Whether detect or not. Default: False.

    Returns:
        Tuple, tuple of output tensor, (f1, f2, f3, f4, f5).

    Examples:
        DarkNet(ResidualBlock,
                [1, 2, 8, 8, 4],
                [32, 64, 128, 256, 512],
                [64, 128, 256, 512, 1024],
                100)
    """

    def __init__(self,
                 block,
                 layer_nums,
                 in_channels,
                 out_channels,
                 detect=False):
        super(DarkNet, self).__init__()

        self.outchannel = out_channels[-1]
        self.detect = detect

        if not len(layer_nums) == len(in_channels) == len(out_channels) == 5:
            raise ValueError("the lengths of the layer_nums, in_channels and out_channels lists must be 5!")
        self.conv0 = conv_block(3,
                                in_channels[0],
                                kernel_size=3,
                                stride=1)
        self.conv1 = conv_block(in_channels[0],
                                out_channels[0],
                                kernel_size=3,
                                stride=2)
        self.conv2 = conv_block(in_channels[1],
                                out_channels[1],
                                kernel_size=3,
                                stride=2)
        self.conv3 = conv_block(in_channels[2],
                                out_channels[2],
                                kernel_size=3,
                                stride=2)
        self.conv4 = conv_block(in_channels[3],
                                out_channels[3],
                                kernel_size=3,
                                stride=2)
        self.conv5 = conv_block(in_channels[4],
                                out_channels[4],
                                kernel_size=3,
                                stride=2)

        self.layer1 = self._make_layer(block,
                                       layer_nums[0],
                                       in_channel=out_channels[0],
                                       out_channel=out_channels[0])
        self.layer2 = self._make_layer(block,
                                       layer_nums[1],
                                       in_channel=out_channels[1],
                                       out_channel=out_channels[1])
        self.layer3 = self._make_layer(block,
                                       layer_nums[2],
                                       in_channel=out_channels[2],
                                       out_channel=out_channels[2])
        self.layer4 = self._make_layer(block,
                                       layer_nums[3],
                                       in_channel=out_channels[3],
                                       out_channel=out_channels[3])
        self.layer5 = self._make_layer(block,
                                       layer_nums[4],
                                       in_channel=out_channels[4],
                                       out_channel=out_channels[4])
    def _make_layer(self, block, layer_num, in_channel, out_channel):
        """
        Make Layer for DarkNet.

        Args:
            block: Cell. DarkNet block.
            layer_num: Integer. Layer number.
            in_channel: Integer. Input channel.
            out_channel: Integer. Output channel.

        Examples:
            _make_layer(ConvBlock, 1, 128, 256)
        """
        layers = []
        darkblk = block(in_channel, out_channel)
        layers.append(darkblk)

        for _ in range(1, layer_num):
            darkblk = block(out_channel, out_channel)
            layers.append(darkblk)

        return nn.SequentialCell(layers)
    def construct(self, x):
        c1 = self.conv0(x)
        c2 = self.conv1(c1)
        c3 = self.layer1(c2)
        c4 = self.conv2(c3)
        c5 = self.layer2(c4)
        c6 = self.conv3(c5)
        c7 = self.layer3(c6)
        c8 = self.conv4(c7)
        c9 = self.layer4(c8)
        c10 = self.conv5(c9)
        c11 = self.layer5(c10)

        if self.detect:
            return c7, c9, c11

        return c11

    def get_out_channels(self):
        return self.outchannel


def darknet53():
    """
    Get DarkNet53 neural network.

    Returns:
        Cell, cell instance of DarkNet53 neural network.

    Examples:
        darknet53()
    """
    return DarkNet(ResidualBlock, [1, 2, 8, 8, 4],
                   [32, 64, 128, 256, 512],
                   [64, 128, 256, 512, 1024])
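A brief illustrative sketch of the two ways this backbone is used; the `detect=True` path is presumably consumed by the YOLOv3 heads in src/yolo.py, which this diff does not show:

```python
# Backbone only: construct() returns the final feature map (1024 channels).
net = darknet53()
print(net.get_out_channels())  # 1024

# Detection mode: DarkNet(..., detect=True) instead returns the three
# intermediate maps (c7, c9, c11) at strides 8, 16 and 32.
```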
@@ -0,0 +1,60 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Yolo dataset distributed sampler."""
from __future__ import division

import math
import numpy as np


class DistributedSampler:
    """Distributed sampler."""

    def __init__(self, dataset_size, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            print("***********Setting world_size to 1 since it is not passed in ******************")
            num_replicas = 1
        if rank is None:
            print("***********Setting rank to 0 since it is not passed in ******************")
            rank = 0
        self.dataset_size = dataset_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(dataset_size * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
    def __iter__(self):
        # deterministically shuffle based on epoch
        if self.shuffle:
            indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size)
            # np.array of numbers from 0 to dataset_size - 1, used as indices of the dataset;
            # change it to list type
            indices = indices.tolist()
            self.epoch += 1
        else:
            indices = list(range(self.dataset_size))

        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples
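A quick illustrative check of the padding and strided subsampling above (not part of the original file):

```python
sampler = DistributedSampler(dataset_size=10, num_replicas=4, rank=1, shuffle=False)
print(list(sampler))  # [1, 5, 9] -> every 4th padded index, starting at rank 1
print(len(sampler))   # 3 == ceil(10 / 4)
```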
@@ -0,0 +1,179 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parameter init."""
import math

import numpy as np
from mindspore.common import initializer as init
from mindspore.common.initializer import Initializer as MeInitializer
import mindspore.nn as nn
from mindspore import Tensor

np.random.seed(5)
def calculate_gain(nonlinearity, param=None):
    r"""Return the recommended gain value for the given nonlinearity function.

    The values are as follows:

    ================= ====================================================
    nonlinearity      gain
    ================= ====================================================
    Linear / Identity :math:`1`
    Conv{1,2,3}D      :math:`1`
    Sigmoid           :math:`1`
    Tanh              :math:`\frac{5}{3}`
    ReLU              :math:`\sqrt{2}`
    Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
    ================= ====================================================

    Args:
        nonlinearity: the non-linear function (`nn.functional` name)
        param: optional parameter for the non-linear function

    Examples:
        >>> gain = calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2
    """
    linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
        return 1
    if nonlinearity == 'tanh':
        return 5.0 / 3
    if nonlinearity == 'relu':
        return math.sqrt(2.0)
    if nonlinearity == 'leaky_relu':
        if param is None:
            negative_slope = 0.01
        elif not isinstance(param, bool) and isinstance(param, (int, float)):
            # True/False are instances of int, hence the bool check above
            negative_slope = param
        else:
            raise ValueError("negative_slope {} not a valid number".format(param))
        return math.sqrt(2.0 / (1 + negative_slope ** 2))
    raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
def _assignment(arr, num):
    """Assign the value of 'num' to 'arr'."""
    if arr.shape == ():
        arr = arr.reshape((1))
        arr[:] = num
        arr = arr.reshape(())
    else:
        if isinstance(num, np.ndarray):
            arr[:] = num[:]
        else:
            arr[:] = num
    return arr


def _calculate_correct_fan(array, mode):
    mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out']
    if mode not in valid_modes:
        raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
    fan_in, fan_out = _calculate_fan_in_and_fan_out(array)
    return fan_in if mode == 'fan_in' else fan_out
def kaiming_uniform_(arr, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    r"""Fills the input array with values according to the method
    described in `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification` - He, K. et al. (2015), using a
    uniform distribution. The resulting array will have values sampled from
    :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}

    Also known as He initialization.

    Args:
        arr: an n-dimensional array
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).

    Examples:
        >>> w = np.empty((3, 5))
        >>> kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
    """
    fan = _calculate_correct_fan(arr, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * std  # calculate uniform bounds from standard deviation
    return np.random.uniform(-bound, bound, arr.shape)
def _calculate_fan_in_and_fan_out(arr):
    """Calculate fan in and fan out."""
    dimensions = len(arr.shape)
    if dimensions < 2:
        raise ValueError("Fan in and fan out can not be computed for array with fewer than 2 dimensions")
    num_input_fmaps = arr.shape[1]
    num_output_fmaps = arr.shape[0]
    receptive_field_size = 1
    if dimensions > 2:
        receptive_field_size = arr[0][0].size
    fan_in = num_input_fmaps * receptive_field_size
    fan_out = num_output_fmaps * receptive_field_size
    return fan_in, fan_out
class KaimingUniform(MeInitializer):
    """Kaiming uniform initializer."""

    def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
        super(KaimingUniform, self).__init__()
        self.a = a
        self.mode = mode
        self.nonlinearity = nonlinearity

    def _initialize(self, arr):
        tmp = kaiming_uniform_(arr, self.a, self.mode, self.nonlinearity)
        _assignment(arr, tmp)


def default_recurisive_init(custom_cell):
    """Initialize parameters recursively."""
    for _, cell in custom_cell.cells_and_names():
        if isinstance(cell, nn.Conv2d):
            cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
                                                         cell.weight.default_input.shape,
                                                         cell.weight.default_input.dtype).to_tensor()
            if cell.bias is not None:
                fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.default_input.asnumpy())
                bound = 1 / math.sqrt(fan_in)
                cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape),
                                                 cell.bias.default_input.dtype)
        elif isinstance(cell, nn.Dense):
            cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
                                                         cell.weight.default_input.shape,
                                                         cell.weight.default_input.dtype).to_tensor()
            if cell.bias is not None:
                fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.default_input.asnumpy())
                bound = 1 / math.sqrt(fan_in)
                cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape),
                                                 cell.bias.default_input.dtype)
        elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
            pass
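To make the fan and bound arithmetic above concrete, a small illustrative check using only the helpers defined in this file:

```python
w = np.zeros((64, 32, 3, 3))  # (out_channels, in_channels, kh, kw)
fan_in, fan_out = _calculate_fan_in_and_fan_out(w)
print(fan_in, fan_out)        # 288 576  (= 32*3*3, 64*3*3)

sample = kaiming_uniform_(w, a=math.sqrt(5))  # uniform in [-1/sqrt(288), 1/sqrt(288)]
print(sample.shape)           # (64, 32, 3, 3)
```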
@@ -0,0 +1,80 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Custom Logger."""
import os
import sys
import logging
from datetime import datetime


class LOGGER(logging.Logger):
    """
    Logger.

    Args:
        logger_name: String. Logger name.
        rank: Integer. Rank id.
    """

    def __init__(self, logger_name, rank=0):
        super(LOGGER, self).__init__(logger_name)
        self.rank = rank
        if rank % 8 == 0:
            console = logging.StreamHandler(sys.stdout)
            console.setLevel(logging.INFO)
            formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
            console.setFormatter(formatter)
            self.addHandler(console)
    def setup_logging_file(self, log_dir, rank=0):
        """Setup logging file."""
        self.rank = rank
        if not os.path.exists(log_dir):
            os.makedirs(log_dir, exist_ok=True)
        log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank)
        self.log_fn = os.path.join(log_dir, log_name)
        fh = logging.FileHandler(self.log_fn)
        fh.setLevel(logging.INFO)
        formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
        fh.setFormatter(formatter)
        self.addHandler(fh)

    def info(self, msg, *args, **kwargs):
        if self.isEnabledFor(logging.INFO):
            self._log(logging.INFO, msg, args, **kwargs)

    def save_args(self, args):
        self.info('Args:')
        args_dict = vars(args)
        for key in args_dict.keys():
            self.info('--> %s: %s', key, args_dict[key])
        self.info('')

    def important_info(self, msg, *args, **kwargs):
        if self.isEnabledFor(logging.INFO) and self.rank == 0:
            line_width = 2
            important_msg = '\n'
            important_msg += ('*' * 70 + '\n') * line_width
            important_msg += ('*' * line_width + '\n') * 2
            important_msg += '*' * line_width + ' ' * 8 + msg + '\n'
            important_msg += ('*' * line_width + '\n') * 2
            important_msg += ('*' * 70 + '\n') * line_width
            self.info(important_msg, *args, **kwargs)


def get_logger(path, rank):
    """Get Logger."""
    logger = LOGGER('yolov3_darknet53', rank)
    logger.setup_logging_file(path, rank)
    return logger
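Typical use, mirroring how eval.py above obtains its logger; the path is illustrative:

```python
logger = get_logger('outputs/2020-01-01_time_00_00_00', rank=0)
logger.info('loaded %d images', 123)  # echoed to stdout and the *_rank_0.log file
```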
@@ -0,0 +1,70 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""YOLOV3 loss."""
from mindspore.ops import operations as P
import mindspore.nn as nn


class XYLoss(nn.Cell):
    """Loss for x and y."""

    def __init__(self):
        super(XYLoss, self).__init__()
        self.cross_entropy = P.SigmoidCrossEntropyWithLogits()
        self.reduce_sum = P.ReduceSum()

    def construct(self, object_mask, box_loss_scale, predict_xy, true_xy):
        xy_loss = object_mask * box_loss_scale * self.cross_entropy(predict_xy, true_xy)
        xy_loss = self.reduce_sum(xy_loss, ())
        return xy_loss
class WHLoss(nn.Cell):
    """Loss for w and h."""

    def __init__(self):
        super(WHLoss, self).__init__()
        self.square = P.Square()
        self.reduce_sum = P.ReduceSum()

    def construct(self, object_mask, box_loss_scale, predict_wh, true_wh):
        wh_loss = object_mask * box_loss_scale * 0.5 * self.square(true_wh - predict_wh)
        wh_loss = self.reduce_sum(wh_loss, ())
        return wh_loss
class ConfidenceLoss(nn.Cell):
    """Loss for confidence."""

    def __init__(self):
        super(ConfidenceLoss, self).__init__()
        self.cross_entropy = P.SigmoidCrossEntropyWithLogits()
        self.reduce_sum = P.ReduceSum()

    def construct(self, object_mask, predict_confidence, ignore_mask):
        confidence_loss = self.cross_entropy(predict_confidence, object_mask)
        confidence_loss = object_mask * confidence_loss + (1 - object_mask) * confidence_loss * ignore_mask
        confidence_loss = self.reduce_sum(confidence_loss, ())
        return confidence_loss


class ClassLoss(nn.Cell):
    """Loss for classification."""

    def __init__(self):
        super(ClassLoss, self).__init__()
        self.cross_entropy = P.SigmoidCrossEntropyWithLogits()
        self.reduce_sum = P.ReduceSum()

    def construct(self, object_mask, predict_class, class_probs):
        class_loss = object_mask * self.cross_entropy(predict_class, class_probs)
        class_loss = self.reduce_sum(class_loss, ())
        return class_loss
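A hedged sketch of driving one of these cells directly; the shapes are illustrative, and the real call sites (presumably inside the YOLO loss blocks of src/yolo.py) are not shown in this diff:

```python
import numpy as np
import mindspore as ms

xy_loss = XYLoss()
# One 13x13 head with 3 anchors; mask/scale broadcast over the (x, y) pair.
object_mask = ms.Tensor(np.ones((1, 13, 13, 3, 1), np.float32))
box_loss_scale = ms.Tensor(np.ones((1, 13, 13, 3, 1), np.float32))
predict_xy = ms.Tensor(np.zeros((1, 13, 13, 3, 2), np.float32))
true_xy = ms.Tensor(np.full((1, 13, 13, 3, 2), 0.5, np.float32))
print(xy_loss(object_mask, box_loss_scale, predict_xy, true_xy))  # scalar sum
```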
@@ -0,0 +1,143 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Learning rate scheduler."""
import math
from collections import Counter

import numpy as np


def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
    """Linear learning rate."""
    lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
    lr = float(init_lr) + lr_inc * current_step
    return lr
def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1):
    """Warmup step learning rate."""
    base_lr = lr
    warmup_init_lr = 0
    total_steps = int(max_epoch * steps_per_epoch)
    warmup_steps = int(warmup_epochs * steps_per_epoch)
    milestones = lr_epochs
    milestones_steps = []
    for milestone in milestones:
        milestones_step = milestone * steps_per_epoch
        milestones_steps.append(milestones_step)

    lr_each_step = []
    lr = base_lr
    milestones_steps_counter = Counter(milestones_steps)
    for i in range(total_steps):
        if i < warmup_steps:
            lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
        else:
            lr = lr * gamma ** milestones_steps_counter[i]
        lr_each_step.append(lr)

    return np.array(lr_each_step).astype(np.float32)


def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1):
    return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma)


def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1):
    lr_epochs = []
    for i in range(1, max_epoch):
        if i % epoch_size == 0:
            lr_epochs.append(i)
    return multi_step_lr(lr, lr_epochs, steps_per_epoch, max_epoch, gamma=gamma)
def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
    """Cosine annealing learning rate."""
    base_lr = lr
    total_steps = int(max_epoch * steps_per_epoch)
    warmup_steps = int(warmup_epochs * steps_per_epoch)

    lr_each_step = []
    for i in range(total_steps):
        last_epoch = i // steps_per_epoch
        if i < warmup_steps:
            lr = 0
        else:
            lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi * last_epoch / T_max)) / 2
        lr_each_step.append(lr)

    return np.array(lr_each_step).astype(np.float32)
def warmup_cosine_annealing_lr_V2(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
    """Cosine annealing learning rate V2."""
    base_lr = lr
    warmup_init_lr = 0
    total_steps = int(max_epoch * steps_per_epoch)
    warmup_steps = int(warmup_epochs * steps_per_epoch)

    last_lr = 0
    last_epoch_V1 = 0

    T_max_V2 = int(max_epoch * 1 / 3)

    lr_each_step = []
    for i in range(total_steps):
        last_epoch = i // steps_per_epoch
        if i < warmup_steps:
            lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
        else:
            if i < total_steps * 2 / 3:
                lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi * last_epoch / T_max)) / 2
                last_lr = lr
                last_epoch_V1 = last_epoch
            else:
                base_lr = last_lr
                last_epoch = last_epoch - last_epoch_V1
                lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi * last_epoch / T_max_V2)) / 2

        lr_each_step.append(lr)
    return np.array(lr_each_step).astype(np.float32)
| def warmup_cosine_annealing_lr_sample(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0): | |||
| """Warmup cosine annealing learning rate.""" | |||
| start_sample_epoch = 60 | |||
| step_sample = 2 | |||
| tobe_sampled_epoch = 60 | |||
| end_sampled_epoch = start_sample_epoch + step_sample*tobe_sampled_epoch | |||
| max_sampled_epoch = max_epoch+tobe_sampled_epoch | |||
| T_max = max_sampled_epoch | |||
| base_lr = lr | |||
| warmup_init_lr = 0 | |||
| total_steps = int(max_epoch * steps_per_epoch) | |||
| total_sampled_steps = int(max_sampled_epoch * steps_per_epoch) | |||
| warmup_steps = int(warmup_epochs * steps_per_epoch) | |||
| lr_each_step = [] | |||
| for i in range(total_sampled_steps): | |||
| last_epoch = i // steps_per_epoch | |||
| if last_epoch in range(start_sample_epoch, end_sampled_epoch, step_sample): | |||
| continue | |||
| if i < warmup_steps: | |||
| lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) | |||
| else: | |||
| lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / T_max)) / 2 | |||
| lr_each_step.append(lr) | |||
| assert total_steps == len(lr_each_step) | |||
| return np.array(lr_each_step).astype(np.float32) | |||
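| Each scheduler above returns one learning-rate value per training step as a float32 numpy array. A minimal sketch of the step-decay variant (epoch counts and milestones are illustrative, and it assumes warmup starts from 0 as in the cosine variants): | |||
| ```python | |||
| # 10 epochs at 100 steps/epoch, decaying 10x at epochs 6 and 8, | |||
| # with one epoch of linear warmup up to the base lr of 0.1 | |||
| lr_steps = warmup_step_lr(0.1, [6, 8], steps_per_epoch=100, | |||
|                           warmup_epochs=1, max_epoch=10) | |||
| assert lr_steps.shape == (1000,)          # one value per step | |||
| assert abs(lr_steps[599] - 0.1) < 1e-6    # base lr just before the first milestone | |||
| assert abs(lr_steps[600] - 0.01) < 1e-6   # decayed by gamma=0.1 at step 600 | |||
| ``` | |||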
| @@ -0,0 +1,577 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Preprocess dataset.""" | |||
| import random | |||
| import threading | |||
| import copy | |||
| import numpy as np | |||
| from PIL import Image | |||
| import cv2 | |||
| def _rand(a=0., b=1.): | |||
| return np.random.rand() * (b - a) + a | |||
| def bbox_iou(bbox_a, bbox_b, offset=0): | |||
| """Calculate Intersection-Over-Union(IOU) of two bounding boxes. | |||
| Parameters | |||
| ---------- | |||
| bbox_a : numpy.ndarray | |||
| An ndarray with shape :math:`(N, 4)`. | |||
| bbox_b : numpy.ndarray | |||
| An ndarray with shape :math:`(M, 4)`. | |||
| offset : float or int, default is 0 | |||
| The ``offset`` is used to control whether the width (or height) is computed as | |||
| (right - left + ``offset``). | |||
| Note that the offset must be 0 for normalized bboxes, whose ranges are in ``[0, 1]``. | |||
| Returns | |||
| ------- | |||
| numpy.ndarray | |||
| An ndarray with shape :math:`(N, M)` indicates IOU between each pairs of | |||
| bounding boxes in `bbox_a` and `bbox_b`. | |||
| """ | |||
| if bbox_a.shape[1] < 4 or bbox_b.shape[1] < 4: | |||
| raise IndexError("Bounding boxes axis 1 must have at least length 4") | |||
| tl = np.maximum(bbox_a[:, None, :2], bbox_b[:, :2]) | |||
| br = np.minimum(bbox_a[:, None, 2:4], bbox_b[:, 2:4]) | |||
| area_i = np.prod(br - tl + offset, axis=2) * (tl < br).all(axis=2) | |||
| area_a = np.prod(bbox_a[:, 2:4] - bbox_a[:, :2] + offset, axis=1) | |||
| area_b = np.prod(bbox_b[:, 2:4] - bbox_b[:, :2] + offset, axis=1) | |||
| return area_i / (area_a[:, None] + area_b - area_i) | |||
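| A worked example of the IoU computation above (hand-checked values, not from the source): | |||
| ```python | |||
| import numpy as np | |||
| a = np.array([[0., 0., 2., 2.]])   # one 2x2 box | |||
| b = np.array([[1., 1., 3., 3.]])   # shifted 2x2 box; overlap area is 1 | |||
| iou = bbox_iou(a, b)               # 1 / (4 + 4 - 1) | |||
| assert iou.shape == (1, 1) | |||
| assert abs(iou[0, 0] - 1. / 7.) < 1e-6 | |||
| ``` | |||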
| def statistic_normalize_img(img, statistic_norm): | |||
| """Statistic normalize images.""" | |||
| # img: RGB | |||
| if isinstance(img, Image.Image): | |||
| img = np.array(img) | |||
| img = img/255. | |||
| mean = np.array([0.485, 0.456, 0.406]) | |||
| std = np.array([0.229, 0.224, 0.225]) | |||
| if statistic_norm: | |||
| img = (img - mean) / std | |||
| return img | |||
| def get_interp_method(interp, sizes=()): | |||
| """Get the interpolation method for resize functions. | |||
| The major purpose of this function is to wrap a random interp method selection | |||
| and an auto-estimation method. | |||
| Parameters | |||
| ---------- | |||
| interp : int | |||
| interpolation method for all resizing operations | |||
| Possible values: | |||
| 0: Nearest Neighbors Interpolation. | |||
| 1: Bilinear interpolation. | |||
| 2: Bicubic interpolation over 4x4 pixel neighborhood. | |||
| 3: Nearest Neighbors, used here in place of Area-based interpolation | |||
| (resampling using pixel area relation), which PIL does not provide. | |||
| Area-based is usually preferred for image decimation because it gives | |||
| moire-free results, but when the image is zoomed it behaves like | |||
| Nearest Neighbors. | |||
| 4: Lanczos interpolation over 8x8 pixel neighborhood. | |||
| 9: Cubic for enlarge, area for shrink, bilinear for others | |||
| 10: Random selection from the interpolation methods mentioned above. | |||
| Note: | |||
| When shrinking an image, it will generally look best with AREA-based | |||
| interpolation, whereas, when enlarging an image, it will generally look best | |||
| with Bicubic (slow) or Bilinear (faster but still looks OK). | |||
| More details can be found in the documentation of OpenCV, please refer to | |||
| http://docs.opencv.org/master/da/d54/group__imgproc__transform.html. | |||
| sizes : tuple of int | |||
| (old_height, old_width, new_height, new_width); if not provided, auto (9) | |||
| returns Bicubic (2). | |||
| Returns | |||
| ------- | |||
| int | |||
| interp method from 0 to 4 | |||
| """ | |||
| if interp == 9: | |||
| if sizes: | |||
| assert len(sizes) == 4 | |||
| oh, ow, nh, nw = sizes | |||
| if nh > oh and nw > ow: | |||
| return 2 | |||
| if nh < oh and nw < ow: | |||
| return 0 | |||
| return 1 | |||
| return 2 | |||
| if interp == 10: | |||
| return random.randint(0, 4) | |||
| if interp not in (0, 1, 2, 3, 4): | |||
| raise ValueError('Unknown interp method %d' % interp) | |||
| return interp | |||
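| A quick check of the auto-estimation logic; sizes are (old_h, old_w, new_h, new_w): | |||
| ```python | |||
| assert get_interp_method(9, (100, 100, 50, 50)) == 0     # shrink both dims -> Nearest | |||
| assert get_interp_method(9, (100, 100, 200, 200)) == 2   # enlarge both dims -> Bicubic | |||
| assert get_interp_method(9, (100, 100, 50, 200)) == 1    # mixed -> Bilinear | |||
| ``` | |||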
| def pil_image_reshape(interp): | |||
| """Reshape pil image.""" | |||
| reshape_type = { | |||
| 0: Image.NEAREST, | |||
| 1: Image.BILINEAR, | |||
| 2: Image.BICUBIC, | |||
| 3: Image.NEAREST, | |||
| 4: Image.LANCZOS, | |||
| } | |||
| return reshape_type[interp] | |||
| def _preprocess_true_boxes(true_boxes, anchors, in_shape, num_classes, | |||
| max_boxes, label_smooth, label_smooth_factor=0.1): | |||
| """Preprocess annotation boxes.""" | |||
| anchors = np.array(anchors) | |||
| num_layers = anchors.shape[0] // 3 | |||
| anchor_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]] | |||
| true_boxes = np.array(true_boxes, dtype='float32') | |||
| input_shape = np.array(in_shape, dtype='int32') | |||
| boxes_xy = (true_boxes[..., 0:2] + true_boxes[..., 2:4]) // 2. | |||
| # trans to box center point | |||
| boxes_wh = true_boxes[..., 2:4] - true_boxes[..., 0:2] | |||
| # input_shape is [h, w] | |||
| true_boxes[..., 0:2] = boxes_xy / input_shape[::-1] | |||
| true_boxes[..., 2:4] = boxes_wh / input_shape[::-1] | |||
| # true_boxes = [xywh] | |||
| grid_shapes = [input_shape // 32, input_shape // 16, input_shape // 8] | |||
| # grid_shape [h, w] | |||
| y_true = [np.zeros((grid_shapes[l][0], grid_shapes[l][1], len(anchor_mask[l]), | |||
| 5 + num_classes), dtype='float32') for l in range(num_layers)] | |||
| # y_true [gridy, gridx] | |||
| anchors = np.expand_dims(anchors, 0) | |||
| anchors_max = anchors / 2. | |||
| anchors_min = -anchors_max | |||
| valid_mask = boxes_wh[..., 0] > 0 | |||
| wh = boxes_wh[valid_mask] | |||
| if wh.size > 0: | |||
| wh = np.expand_dims(wh, -2) | |||
| boxes_max = wh / 2. | |||
| boxes_min = -boxes_max | |||
| intersect_min = np.maximum(boxes_min, anchors_min) | |||
| intersect_max = np.minimum(boxes_max, anchors_max) | |||
| intersect_wh = np.maximum(intersect_max - intersect_min, 0.) | |||
| intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1] | |||
| box_area = wh[..., 0] * wh[..., 1] | |||
| anchor_area = anchors[..., 0] * anchors[..., 1] | |||
| iou = intersect_area / (box_area + anchor_area - intersect_area) | |||
| best_anchor = np.argmax(iou, axis=-1) | |||
| for t, n in enumerate(best_anchor): | |||
| for l in range(num_layers): | |||
| if n in anchor_mask[l]: | |||
| i = np.floor(true_boxes[t, 0] * grid_shapes[l][1]).astype('int32') # grid_x | |||
| j = np.floor(true_boxes[t, 1] * grid_shapes[l][0]).astype('int32') # grid_y | |||
| k = anchor_mask[l].index(n) | |||
| c = true_boxes[t, 4].astype('int32') | |||
| y_true[l][j, i, k, 0:4] = true_boxes[t, 0:4] | |||
| y_true[l][j, i, k, 4] = 1. | |||
| # label smoothing | |||
| if label_smooth: | |||
| sigma = label_smooth_factor/(num_classes-1) | |||
| y_true[l][j, i, k, 5:] = sigma | |||
| y_true[l][j, i, k, 5+c] = 1-label_smooth_factor | |||
| else: | |||
| y_true[l][j, i, k, 5 + c] = 1. | |||
| # pad_gt_boxes for avoiding dynamic shape | |||
| pad_gt_box0 = np.zeros(shape=[max_boxes, 4], dtype=np.float32) | |||
| pad_gt_box1 = np.zeros(shape=[max_boxes, 4], dtype=np.float32) | |||
| pad_gt_box2 = np.zeros(shape=[max_boxes, 4], dtype=np.float32) | |||
| mask0 = np.reshape(y_true[0][..., 4:5], [-1]) | |||
| gt_box0 = np.reshape(y_true[0][..., 0:4], [-1, 4]) | |||
| # gt_box [boxes, [x,y,w,h]] | |||
| gt_box0 = gt_box0[mask0 == 1] | |||
| # gt_box0: all boxes that contain an object | |||
| pad_gt_box0[:gt_box0.shape[0]] = gt_box0 | |||
| # gt_box0.shape[0]: total number of boxes in gt_box0 | |||
| # the first N rows of pad_gt_box0 are real boxes; the rest are zero padding | |||
| mask1 = np.reshape(y_true[1][..., 4:5], [-1]) | |||
| gt_box1 = np.reshape(y_true[1][..., 0:4], [-1, 4]) | |||
| gt_box1 = gt_box1[mask1 == 1] | |||
| pad_gt_box1[:gt_box1.shape[0]] = gt_box1 | |||
| mask2 = np.reshape(y_true[2][..., 4:5], [-1]) | |||
| gt_box2 = np.reshape(y_true[2][..., 0:4], [-1, 4]) | |||
| gt_box2 = gt_box2[mask2 == 1] | |||
| pad_gt_box2[:gt_box2.shape[0]] = gt_box2 | |||
| return y_true[0], y_true[1], y_true[2], pad_gt_box0, pad_gt_box1, pad_gt_box2 | |||
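| A shape sketch for the label encoder above. The anchors below are the standard YOLOv3 COCO anchors and stand in for config.anchor_scales from src/config.py; the box values are made up: | |||
| ```python | |||
| import numpy as np | |||
| anchors = [(10, 13), (16, 30), (33, 23), (30, 61), (62, 45), | |||
|            (59, 119), (116, 90), (156, 198), (373, 326)] | |||
| boxes = np.zeros((50, 5), dtype=np.float32)   # [x_min, y_min, x_max, y_max, class] | |||
| boxes[0] = [100, 120, 180, 260, 7]            # one real box, rest is padding | |||
| y1, y2, y3, gt1, gt2, gt3 = _preprocess_true_boxes( | |||
|     boxes, anchors, in_shape=(416, 416), num_classes=80, | |||
|     max_boxes=50, label_smooth=False) | |||
| # one label map per detection scale (strides 32 / 16 / 8) | |||
| assert y1.shape == (13, 13, 3, 85) | |||
| assert y2.shape == (26, 26, 3, 85) | |||
| assert y3.shape == (52, 52, 3, 85) | |||
| assert gt1.shape == gt2.shape == gt3.shape == (50, 4) | |||
| ``` | |||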
| def _reshape_data(image, image_size): | |||
| """Reshape image.""" | |||
| if not isinstance(image, Image.Image): | |||
| image = Image.fromarray(image) | |||
| ori_w, ori_h = image.size | |||
| ori_image_shape = np.array([ori_w, ori_h], np.int32) | |||
| # ori_image_shape holds [original width, original height] | |||
| h, w = image_size | |||
| interp = get_interp_method(interp=9, sizes=(ori_h, ori_w, h, w)) | |||
| image = image.resize((w, h), pil_image_reshape(interp)) | |||
| image_data = statistic_normalize_img(image, statistic_norm=True) | |||
| if len(image_data.shape) == 2: | |||
| image_data = np.expand_dims(image_data, axis=-1) | |||
| image_data = np.concatenate([image_data, image_data, image_data], axis=-1) | |||
| image_data = image_data.astype(np.float32) | |||
| return image_data, ori_image_shape | |||
| def color_distortion(img, hue, sat, val, device_num): | |||
| """Color distortion.""" | |||
| hue = _rand(-hue, hue) | |||
| sat = _rand(1, sat) if _rand() < .5 else 1 / _rand(1, sat) | |||
| val = _rand(1, val) if _rand() < .5 else 1 / _rand(1, val) | |||
| if device_num != 1: | |||
| cv2.setNumThreads(1) | |||
| x = cv2.cvtColor(img, cv2.COLOR_RGB2HSV_FULL) | |||
| x = x / 255. | |||
| x[..., 0] += hue | |||
| x[..., 0][x[..., 0] > 1] -= 1 | |||
| x[..., 0][x[..., 0] < 0] += 1 | |||
| x[..., 1] *= sat | |||
| x[..., 2] *= val | |||
| x[x > 1] = 1 | |||
| x[x < 0] = 0 | |||
| x = x * 255. | |||
| x = x.astype(np.uint8) | |||
| image_data = cv2.cvtColor(x, cv2.COLOR_HSV2RGB_FULL) | |||
| return image_data | |||
| def flip_pil_image(img): | |||
| return img.transpose(Image.FLIP_LEFT_RIGHT) | |||
| def convert_gray_to_color(img): | |||
| if len(img.shape) == 2: | |||
| img = np.expand_dims(img, axis=-1) | |||
| img = np.concatenate([img, img, img], axis=-1) | |||
| return img | |||
| def _is_iou_satisfied_constraint(min_iou, max_iou, box, crop_box): | |||
| iou = bbox_iou(box, crop_box) | |||
| return min_iou <= iou.min() and max_iou >= iou.max() | |||
| def _choose_candidate_by_constraints(max_trial, input_w, input_h, image_w, image_h, jitter, box, use_constraints): | |||
| """Choose candidate by constraints.""" | |||
| if use_constraints: | |||
| constraints = ( | |||
| (0.1, None), | |||
| (0.3, None), | |||
| (0.5, None), | |||
| (0.7, None), | |||
| (0.9, None), | |||
| (None, 1), | |||
| ) | |||
| else: | |||
| constraints = ( | |||
| (None, None), | |||
| ) | |||
| # add default candidate | |||
| candidates = [(0, 0, input_w, input_h)] | |||
| for constraint in constraints: | |||
| min_iou, max_iou = constraint | |||
| min_iou = -np.inf if min_iou is None else min_iou | |||
| max_iou = np.inf if max_iou is None else max_iou | |||
| for _ in range(max_trial): | |||
| # box_data should have at least one box | |||
| new_ar = float(input_w) / float(input_h) * _rand(1 - jitter, 1 + jitter) / _rand(1 - jitter, 1 + jitter) | |||
| scale = _rand(0.25, 2) | |||
| if new_ar < 1: | |||
| nh = int(scale * input_h) | |||
| nw = int(nh * new_ar) | |||
| else: | |||
| nw = int(scale * input_w) | |||
| nh = int(nw / new_ar) | |||
| dx = int(_rand(0, input_w - nw)) | |||
| dy = int(_rand(0, input_h - nh)) | |||
| if box.size > 0: | |||
| t_box = copy.deepcopy(box) | |||
| t_box[:, [0, 2]] = t_box[:, [0, 2]] * float(nw) / float(image_w) + dx | |||
| t_box[:, [1, 3]] = t_box[:, [1, 3]] * float(nh) / float(image_h) + dy | |||
| crop_box = np.array((0, 0, input_w, input_h)) | |||
| if not _is_iou_satisfied_constraint(min_iou, max_iou, t_box, crop_box[np.newaxis]): | |||
| continue | |||
| else: | |||
| candidates.append((dx, dy, nw, nh)) | |||
| else: | |||
| raise Exception("!!! annotation box is less than 1") | |||
| return candidates | |||
| def _correct_bbox_by_candidates(candidates, input_w, input_h, image_w, | |||
| image_h, flip, box, box_data, allow_outside_center): | |||
| """Calculate correct boxes.""" | |||
| while candidates: | |||
| if len(candidates) > 1: | |||
| # ignore the default candidate, which does not crop | |||
| candidate = candidates.pop(np.random.randint(1, len(candidates))) | |||
| else: | |||
| candidate = candidates.pop(np.random.randint(0, len(candidates))) | |||
| dx, dy, nw, nh = candidate | |||
| t_box = copy.deepcopy(box) | |||
| t_box[:, [0, 2]] = t_box[:, [0, 2]] * float(nw) / float(image_w) + dx | |||
| t_box[:, [1, 3]] = t_box[:, [1, 3]] * float(nh) / float(image_h) + dy | |||
| if flip: | |||
| t_box[:, [0, 2]] = input_w - t_box[:, [2, 0]] | |||
| if not allow_outside_center: | |||
| t_box = t_box[np.logical_and((t_box[:, 0] + t_box[:, 2])/2. >= 0., (t_box[:, 1] + t_box[:, 3])/2. >= 0.)] | |||
| t_box = t_box[np.logical_and((t_box[:, 0] + t_box[:, 2]) / 2. <= input_w, | |||
| (t_box[:, 1] + t_box[:, 3]) / 2. <= input_h)] | |||
| # clamp x, y at zero: after applying dx and dy some coordinates can go negative | |||
| t_box[:, 0:2][t_box[:, 0:2] < 0] = 0 | |||
| # clamp the right/bottom coordinates so the box does not exceed the input size | |||
| t_box[:, 2][t_box[:, 2] > input_w] = input_w | |||
| t_box[:, 3][t_box[:, 3] > input_h] = input_h | |||
| box_w = t_box[:, 2] - t_box[:, 0] | |||
| box_h = t_box[:, 3] - t_box[:, 1] | |||
| # discard invalid box: w or h smaller than 1 pixel | |||
| t_box = t_box[np.logical_and(box_w > 1, box_h > 1)] | |||
| if t_box.shape[0] > 0: | |||
| # valid boxes found; keep them and stop searching | |||
| box_data[: len(t_box)] = t_box | |||
| return box_data, candidate | |||
| raise Exception('no candidate satisfies the bbox correction constraints') | |||
| def _data_aug(image, box, jitter, hue, sat, val, image_input_size, max_boxes, | |||
| anchors, num_classes, max_trial=10, device_num=1): | |||
| """Crop an image randomly with bounding box constraints. | |||
| This data augmentation is used in training of | |||
| Single Shot Multibox Detector [#]_. More details can be found in | |||
| data augmentation section of the original paper. | |||
| .. [#] Wei Liu, Dragomir Anguelov, Dumitru Erhan, Christian Szegedy, | |||
| Scott Reed, Cheng-Yang Fu, Alexander C. Berg. | |||
| SSD: Single Shot MultiBox Detector. ECCV 2016.""" | |||
| if not isinstance(image, Image.Image): | |||
| image = Image.fromarray(image) | |||
| image_w, image_h = image.size | |||
| input_h, input_w = image_input_size | |||
| np.random.shuffle(box) | |||
| if len(box) > max_boxes: | |||
| box = box[:max_boxes] | |||
| flip = _rand() < .5 | |||
| box_data = np.zeros((max_boxes, 5)) | |||
| candidates = _choose_candidate_by_constraints(use_constraints=False, | |||
| max_trial=max_trial, | |||
| input_w=input_w, | |||
| input_h=input_h, | |||
| image_w=image_w, | |||
| image_h=image_h, | |||
| jitter=jitter, | |||
| box=box) | |||
| box_data, candidate = _correct_bbox_by_candidates(candidates=candidates, | |||
| input_w=input_w, | |||
| input_h=input_h, | |||
| image_w=image_w, | |||
| image_h=image_h, | |||
| flip=flip, | |||
| box=box, | |||
| box_data=box_data, | |||
| allow_outside_center=True) | |||
| dx, dy, nw, nh = candidate | |||
| interp = get_interp_method(interp=10) | |||
| image = image.resize((nw, nh), pil_image_reshape(interp)) | |||
| # paste the resized image onto a gray background | |||
| new_image = Image.new('RGB', (input_w, input_h), (128, 128, 128)) | |||
| new_image.paste(image, (dx, dy)) | |||
| image = new_image | |||
| if flip: | |||
| image = flip_pil_image(image) | |||
| image = np.array(image) | |||
| image = convert_gray_to_color(image) | |||
| image_data = color_distortion(image, hue, sat, val, device_num) | |||
| image_data = statistic_normalize_img(image_data, statistic_norm=True) | |||
| image_data = image_data.astype(np.float32) | |||
| return image_data, box_data | |||
| def preprocess_fn(image, box, config, input_size, device_num): | |||
| """Preprocess data function.""" | |||
| config_anchors = config.anchor_scales | |||
| anchors = np.array([list(x) for x in config_anchors]) | |||
| max_boxes = config.max_box | |||
| num_classes = config.num_classes | |||
| jitter = config.jitter | |||
| hue = config.hue | |||
| sat = config.saturation | |||
| val = config.value | |||
| image, anno = _data_aug(image, box, jitter=jitter, hue=hue, sat=sat, val=val, | |||
| image_input_size=input_size, max_boxes=max_boxes, | |||
| num_classes=num_classes, anchors=anchors, device_num=device_num) | |||
| return image, anno | |||
| def reshape_fn(image, img_id, config): | |||
| input_size = config.test_img_shape | |||
| image, ori_image_shape = _reshape_data(image, image_size=input_size) | |||
| return image, ori_image_shape, img_id | |||
| class MultiScaleTrans: | |||
| """Multi scale transform.""" | |||
| def __init__(self, config, device_num): | |||
| self.config = config | |||
| self.seed = 0 | |||
| self.size_list = [] | |||
| self.resize_rate = config.resize_rate | |||
| self.dataset_size = config.dataset_size | |||
| self.size_dict = {} | |||
| self.seed_num = int(1e6) | |||
| self.seed_list = self.generate_seed_list(seed_num=self.seed_num) | |||
| self.resize_count_num = int(np.ceil(self.dataset_size / self.resize_rate)) | |||
| self.device_num = device_num | |||
| def generate_seed_list(self, init_seed=1234, seed_num=int(1e6), seed_range=(1, 1000)): | |||
| seed_list = [] | |||
| random.seed(init_seed) | |||
| for _ in range(seed_num): | |||
| seed = random.randint(seed_range[0], seed_range[1]) | |||
| seed_list.append(seed) | |||
| return seed_list | |||
| def __call__(self, imgs, annos, batchInfo): | |||
| epoch_num = batchInfo.get_epoch_num() | |||
| size_idx = int(batchInfo.get_batch_num() / self.resize_rate) | |||
| seed_key = self.seed_list[(epoch_num * self.resize_count_num + size_idx) % self.seed_num] | |||
| ret_imgs = [] | |||
| ret_annos = [] | |||
| if self.size_dict.get(seed_key, None) is None: | |||
| random.seed(seed_key) | |||
| new_size = random.choice(self.config.multi_scale) | |||
| self.size_dict[seed_key] = new_size | |||
| seed = seed_key | |||
| input_size = self.size_dict[seed] | |||
| for img, anno in zip(imgs, annos): | |||
| img, anno = preprocess_fn(img, anno, self.config, input_size, self.device_num) | |||
| ret_imgs.append(img.transpose(2, 0, 1).copy()) | |||
| ret_annos.append(anno) | |||
| return np.array(ret_imgs), np.array(ret_annos) | |||
| def thread_batch_preprocess_true_box(annos, config, input_shape, result_index, batch_bbox_true_1, batch_bbox_true_2, | |||
| batch_bbox_true_3, batch_gt_box1, batch_gt_box2, batch_gt_box3): | |||
| """Preprocess true box for multi-thread.""" | |||
| i = 0 | |||
| for anno in annos: | |||
| bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \ | |||
| _preprocess_true_boxes(true_boxes=anno, anchors=config.anchor_scales, in_shape=input_shape, | |||
| num_classes=config.num_classes, max_boxes=config.max_box, | |||
| label_smooth=config.label_smooth, label_smooth_factor=config.label_smooth_factor) | |||
| batch_bbox_true_1[result_index + i] = bbox_true_1 | |||
| batch_bbox_true_2[result_index + i] = bbox_true_2 | |||
| batch_bbox_true_3[result_index + i] = bbox_true_3 | |||
| batch_gt_box1[result_index + i] = gt_box1 | |||
| batch_gt_box2[result_index + i] = gt_box2 | |||
| batch_gt_box3[result_index + i] = gt_box3 | |||
| i = i + 1 | |||
| def batch_preprocess_true_box(annos, config, input_shape): | |||
| """Preprocess true box with multi-thread.""" | |||
| batch_bbox_true_1 = [] | |||
| batch_bbox_true_2 = [] | |||
| batch_bbox_true_3 = [] | |||
| batch_gt_box1 = [] | |||
| batch_gt_box2 = [] | |||
| batch_gt_box3 = [] | |||
| threads = [] | |||
| step = 4 | |||
| for index in range(0, len(annos), step): | |||
| for _ in range(step): | |||
| batch_bbox_true_1.append(None) | |||
| batch_bbox_true_2.append(None) | |||
| batch_bbox_true_3.append(None) | |||
| batch_gt_box1.append(None) | |||
| batch_gt_box2.append(None) | |||
| batch_gt_box3.append(None) | |||
| step_anno = annos[index: index + step] | |||
| t = threading.Thread(target=thread_batch_preprocess_true_box, | |||
| args=(step_anno, config, input_shape, index, batch_bbox_true_1, batch_bbox_true_2, | |||
| batch_bbox_true_3, batch_gt_box1, batch_gt_box2, batch_gt_box3)) | |||
| t.start() | |||
| threads.append(t) | |||
| for t in threads: | |||
| t.join() | |||
| return np.array(batch_bbox_true_1), np.array(batch_bbox_true_2), np.array(batch_bbox_true_3), \ | |||
| np.array(batch_gt_box1), np.array(batch_gt_box2), np.array(batch_gt_box3) | |||
| def batch_preprocess_true_box_single(annos, config, input_shape): | |||
| """Preprocess true boxes.""" | |||
| batch_bbox_true_1 = [] | |||
| batch_bbox_true_2 = [] | |||
| batch_bbox_true_3 = [] | |||
| batch_gt_box1 = [] | |||
| batch_gt_box2 = [] | |||
| batch_gt_box3 = [] | |||
| for anno in annos: | |||
| bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \ | |||
| _preprocess_true_boxes(true_boxes=anno, anchors=config.anchor_scales, in_shape=input_shape, | |||
| num_classes=config.num_classes, max_boxes=config.max_box, | |||
| label_smooth=config.label_smooth, label_smooth_factor=config.label_smooth_factor) | |||
| batch_bbox_true_1.append(bbox_true_1) | |||
| batch_bbox_true_2.append(bbox_true_2) | |||
| batch_bbox_true_3.append(bbox_true_3) | |||
| batch_gt_box1.append(gt_box1) | |||
| batch_gt_box2.append(gt_box2) | |||
| batch_gt_box3.append(gt_box3) | |||
| return np.array(batch_bbox_true_1), np.array(batch_bbox_true_2), np.array(batch_bbox_true_3), \ | |||
| np.array(batch_gt_box1), np.array(batch_gt_box2), np.array(batch_gt_box3) | |||
| @@ -0,0 +1,177 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Util class or function.""" | |||
| from mindspore.train.serialization import load_checkpoint | |||
| import mindspore.nn as nn | |||
| class AverageMeter: | |||
| """Computes and stores the average and current value""" | |||
| def __init__(self, name, fmt=':f', tb_writer=None): | |||
| self.name = name | |||
| self.fmt = fmt | |||
| self.reset() | |||
| self.tb_writer = tb_writer | |||
| self.cur_step = 1 | |||
| self.val = 0 | |||
| self.avg = 0 | |||
| self.sum = 0 | |||
| self.count = 0 | |||
| def reset(self): | |||
| self.val = 0 | |||
| self.avg = 0 | |||
| self.sum = 0 | |||
| self.count = 0 | |||
| def update(self, val, n=1): | |||
| self.val = val | |||
| self.sum += val * n | |||
| self.count += n | |||
| self.avg = self.sum / self.count | |||
| if self.tb_writer is not None: | |||
| self.tb_writer.add_scalar(self.name, self.val, self.cur_step) | |||
| self.cur_step += 1 | |||
| def __str__(self): | |||
| fmtstr = '{name}:{avg' + self.fmt + '}' | |||
| return fmtstr.format(**self.__dict__) | |||
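| Typical use during training (a sketch; the tb_writer hook is optional): | |||
| ```python | |||
| loss_meter = AverageMeter('loss', ':.4f') | |||
| for batch_loss in (2.0, 1.0): | |||
|     loss_meter.update(batch_loss) | |||
| print(loss_meter)   # -> loss:1.5000 (running average over the updates) | |||
| ``` | |||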
| def load_backbone(net, ckpt_path, args): | |||
| """Load darknet53 backbone checkpoint.""" | |||
| param_dict = load_checkpoint(ckpt_path) | |||
| yolo_backbone_prefix = 'feature_map.backbone' | |||
| darknet_backbone_prefix = 'network.backbone' | |||
| find_param = [] | |||
| not_found_param = [] | |||
| for name, cell in net.cells_and_names(): | |||
| if name.startswith(yolo_backbone_prefix): | |||
| name = name.replace(yolo_backbone_prefix, darknet_backbone_prefix) | |||
| if isinstance(cell, (nn.Conv2d, nn.Dense)): | |||
| darknet_weight = '{}.weight'.format(name) | |||
| darknet_bias = '{}.bias'.format(name) | |||
| if darknet_weight in param_dict: | |||
| cell.weight.default_input = param_dict[darknet_weight].data | |||
| find_param.append(darknet_weight) | |||
| else: | |||
| not_found_param.append(darknet_weight) | |||
| if darknet_bias in param_dict: | |||
| cell.bias.default_input = param_dict[darknet_bias].data | |||
| find_param.append(darknet_bias) | |||
| else: | |||
| not_found_param.append(darknet_bias) | |||
| elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)): | |||
| darknet_moving_mean = '{}.moving_mean'.format(name) | |||
| darknet_moving_variance = '{}.moving_variance'.format(name) | |||
| darknet_gamma = '{}.gamma'.format(name) | |||
| darknet_beta = '{}.beta'.format(name) | |||
| if darknet_moving_mean in param_dict: | |||
| cell.moving_mean.default_input = param_dict[darknet_moving_mean].data | |||
| find_param.append(darknet_moving_mean) | |||
| else: | |||
| not_found_param.append(darknet_moving_mean) | |||
| if darknet_moving_variance in param_dict: | |||
| cell.moving_variance.default_input = param_dict[darknet_moving_variance].data | |||
| find_param.append(darknet_moving_variance) | |||
| else: | |||
| not_found_param.append(darknet_moving_variance) | |||
| if darknet_gamma in param_dict: | |||
| cell.gamma.default_input = param_dict[darknet_gamma].data | |||
| find_param.append(darknet_gamma) | |||
| else: | |||
| not_found_param.append(darknet_gamma) | |||
| if darknet_beta in param_dict: | |||
| cell.beta.default_input = param_dict[darknet_beta].data | |||
| find_param.append(darknet_beta) | |||
| else: | |||
| not_found_param.append(darknet_beta) | |||
| args.logger.info('================found_param {}========='.format(len(find_param))) | |||
| args.logger.info(find_param) | |||
| args.logger.info('================not_found_param {}========='.format(len(not_found_param))) | |||
| args.logger.info(not_found_param) | |||
| args.logger.info('=====load {} successfully ====='.format(ckpt_path)) | |||
| return net | |||
| def default_wd_filter(x): | |||
| """default weight decay filter.""" | |||
| parameter_name = x.name | |||
| if parameter_name.endswith('.bias'): | |||
| # all bias not using weight decay | |||
| return False | |||
| if parameter_name.endswith('.gamma'): | |||
| # bn gamma does not use weight decay (note: x may not include BN parameters here) | |||
| return False | |||
| if parameter_name.endswith('.beta'): | |||
| # bn beta does not use weight decay (note: x may not include BN parameters here) | |||
| return False | |||
| return True | |||
| def get_param_groups(network): | |||
| """Param groups for optimizer.""" | |||
| decay_params = [] | |||
| no_decay_params = [] | |||
| for x in network.trainable_params(): | |||
| parameter_name = x.name | |||
| if parameter_name.endswith('.bias'): | |||
| # all bias not using weight decay | |||
| no_decay_params.append(x) | |||
| elif parameter_name.endswith('.gamma'): | |||
| # bn gamma does not use weight decay | |||
| no_decay_params.append(x) | |||
| elif parameter_name.endswith('.beta'): | |||
| # bn beta does not use weight decay | |||
| no_decay_params.append(x) | |||
| else: | |||
| decay_params.append(x) | |||
| return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}] | |||
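| The two groups plug straight into an optimizer, so biases and BN gamma/beta skip weight decay while everything else keeps it. A sketch with placeholder hyper-parameters ('network' is any Cell): | |||
| ```python | |||
| from mindspore.nn.optim.momentum import Momentum | |||
| opt = Momentum(params=get_param_groups(network), | |||
|                learning_rate=0.001, momentum=0.9, | |||
|                weight_decay=0.0005)   # applies only to the decay group | |||
| ``` | |||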
| class ShapeRecord: | |||
| """Log image shape.""" | |||
| def __init__(self): | |||
| self.shape_record = { | |||
| 320: 0, | |||
| 352: 0, | |||
| 384: 0, | |||
| 416: 0, | |||
| 448: 0, | |||
| 480: 0, | |||
| 512: 0, | |||
| 544: 0, | |||
| 576: 0, | |||
| 608: 0, | |||
| 'total': 0 | |||
| } | |||
| def set(self, shape): | |||
| if len(shape) > 1: | |||
| shape = shape[0] | |||
| shape = int(shape) | |||
| self.shape_record[shape] += 1 | |||
| self.shape_record['total'] += 1 | |||
| def show(self, logger): | |||
| for key in self.shape_record: | |||
| rate = self.shape_record[key] / float(self.shape_record['total']) | |||
| logger.info('shape {}: {:.2f}%'.format(key, rate*100)) | |||
| @@ -0,0 +1,437 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """YOLOv3 based on DarkNet.""" | |||
| import mindspore as ms | |||
| import mindspore.nn as nn | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore import context | |||
| from mindspore.parallel._auto_parallel_context import auto_parallel_context | |||
| from mindspore.communication.management import get_group_size | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import functional as F | |||
| from mindspore.ops import composite as C | |||
| from src.darknet import DarkNet, ResidualBlock | |||
| from src.config import ConfigYOLOV3DarkNet53 | |||
| from src.loss import XYLoss, WHLoss, ConfidenceLoss, ClassLoss | |||
| def _conv_bn_relu(in_channel, | |||
| out_channel, | |||
| ksize, | |||
| stride=1, | |||
| padding=0, | |||
| dilation=1, | |||
| alpha=0.1, | |||
| momentum=0.9, | |||
| eps=1e-5, | |||
| pad_mode="same"): | |||
| """Get a conv2d batchnorm and relu layer""" | |||
| return nn.Conv2dBnAct(in_channel, out_channel, ksize, | |||
| stride=stride, | |||
| pad_mode=pad_mode, | |||
| padding=padding, | |||
| dilation=dilation, | |||
| has_bn=True, | |||
| momentum=momentum, | |||
| eps=eps, | |||
| activation='leakyrelu', | |||
| alpha=alpha) | |||
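| For example, a 3x3 fused block from 32 to 64 channels, built with the defaults above (a sketch): | |||
| ```python | |||
| # uses BN (momentum 0.9, eps 1e-5) and LeakyReLU with alpha=0.1 by default | |||
| layer = _conv_bn_relu(32, 64, ksize=3) | |||
| # pad_mode="same" with stride 1 preserves spatial size: (N, 32, H, W) -> (N, 64, H, W) | |||
| ``` | |||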
| class YoloBlock(nn.Cell): | |||
| """ | |||
| YoloBlock for YOLOv3. | |||
| Args: | |||
| in_channels: Integer. Input channel. | |||
| out_chls: Integer. Middle channel. | |||
| out_channels: Integer. Output channel. | |||
| Returns: | |||
| Tuple, two output tensors: the last middle feature map (c5) and the block output. | |||
| Examples: | |||
| YoloBlock(1024, 512, 255) | |||
| """ | |||
| def __init__(self, in_channels, out_chls, out_channels): | |||
| super(YoloBlock, self).__init__() | |||
| out_chls_2 = out_chls*2 | |||
| self.conv0 = _conv_bn_relu(in_channels, out_chls, ksize=1) | |||
| self.conv1 = _conv_bn_relu(out_chls, out_chls_2, ksize=3) | |||
| self.conv2 = _conv_bn_relu(out_chls_2, out_chls, ksize=1) | |||
| self.conv3 = _conv_bn_relu(out_chls, out_chls_2, ksize=3) | |||
| self.conv4 = _conv_bn_relu(out_chls_2, out_chls, ksize=1) | |||
| self.conv5 = _conv_bn_relu(out_chls, out_chls_2, ksize=3) | |||
| self.conv6 = nn.Conv2dBnAct(out_chls_2, out_channels, kernel_size=1, stride=1, | |||
| has_bias=True, has_bn=False, activation=None, after_fake=False) | |||
| def construct(self, x): | |||
| c1 = self.conv0(x) | |||
| c2 = self.conv1(c1) | |||
| c3 = self.conv2(c2) | |||
| c4 = self.conv3(c3) | |||
| c5 = self.conv4(c4) | |||
| c6 = self.conv5(c5) | |||
| out = self.conv6(c6) | |||
| return c5, out | |||
| class YOLOv3(nn.Cell): | |||
| """ | |||
| YOLOv3 Network. | |||
| Note: | |||
| backbone = darknet53 | |||
| Args: | |||
| backbone_shape: List. Darknet output channels shape. | |||
| backbone: Cell. Backbone Network. | |||
| out_channel: Integer. Output channel. | |||
| Returns: | |||
| Tuple, three output tensors, one per detection scale (big, medium, small objects). | |||
| Examples: | |||
| YOLOv3(backbone_shape=[64, 128, 256, 512, 1024], | |||
| backbone=darknet53(), | |||
| out_channel=255) | |||
| """ | |||
| def __init__(self, backbone_shape, backbone, out_channel): | |||
| super(YOLOv3, self).__init__() | |||
| self.out_channel = out_channel | |||
| self.backbone = backbone | |||
| self.backblock0 = YoloBlock(backbone_shape[-1], out_chls=backbone_shape[-2], out_channels=out_channel) | |||
| self.conv1 = _conv_bn_relu(in_channel=backbone_shape[-2], out_channel=backbone_shape[-2]//2, ksize=1) | |||
| self.backblock1 = YoloBlock(in_channels=backbone_shape[-2]+backbone_shape[-3], | |||
| out_chls=backbone_shape[-3], | |||
| out_channels=out_channel) | |||
| self.conv2 = _conv_bn_relu(in_channel=backbone_shape[-3], out_channel=backbone_shape[-3]//2, ksize=1) | |||
| self.backblock2 = YoloBlock(in_channels=backbone_shape[-3]+backbone_shape[-4], | |||
| out_chls=backbone_shape[-4], | |||
| out_channels=out_channel) | |||
| self.concat = P.Concat(axis=1) | |||
| def construct(self, x): | |||
| # input_shape of x is (batch_size, 3, h, w) | |||
| # feature_map1 is (batch_size, backbone_shape[2], h/8, w/8) | |||
| # feature_map2 is (batch_size, backbone_shape[3], h/16, w/16) | |||
| # feature_map3 is (batch_size, backbone_shape[4], h/32, w/32) | |||
| img_height = P.Shape()(x)[2] | |||
| img_width = P.Shape()(x)[3] | |||
| feature_map1, feature_map2, feature_map3 = self.backbone(x) | |||
| con1, big_object_output = self.backblock0(feature_map3) | |||
| con1 = self.conv1(con1) | |||
| ups1 = P.ResizeNearestNeighbor((img_height / 16, img_width / 16))(con1) | |||
| con1 = self.concat((ups1, feature_map2)) | |||
| con2, medium_object_output = self.backblock1(con1) | |||
| con2 = self.conv2(con2) | |||
| ups2 = P.ResizeNearestNeighbor((img_height / 8, img_width / 8))(con2) | |||
| con3 = self.concat((ups2, feature_map1)) | |||
| _, small_object_output = self.backblock2(con3) | |||
| return big_object_output, medium_object_output, small_object_output | |||
| class DetectionBlock(nn.Cell): | |||
| """ | |||
| YOLOv3 detection Network. It will finally output the detection result. | |||
| Args: | |||
| scale: Character, one of 's', 'm' or 'l'. | |||
| config: ConfigYOLOV3DarkNet53, Configuration instance. | |||
| is_training: Bool, Whether train or not, default True. | |||
| Returns: | |||
| Tuple, (grid, prediction, box_xy, box_wh) in training; otherwise the concatenated detection tensor. | |||
| Examples: | |||
| DetectionBlock(scale='l') | |||
| """ | |||
| def __init__(self, scale, config=ConfigYOLOV3DarkNet53(), is_training=True): | |||
| super(DetectionBlock, self).__init__() | |||
| self.config = config | |||
| if scale == 's': | |||
| idx = (0, 1, 2) | |||
| elif scale == 'm': | |||
| idx = (3, 4, 5) | |||
| elif scale == 'l': | |||
| idx = (6, 7, 8) | |||
| else: | |||
| raise KeyError("Invalid scale value for DetectionBlock") | |||
| self.anchors = Tensor([self.config.anchor_scales[i] for i in idx], ms.float32) | |||
| self.num_anchors_per_scale = 3 | |||
| self.num_attrib = 4+1+self.config.num_classes | |||
| self.lambda_coord = 1 | |||
| self.sigmoid = nn.Sigmoid() | |||
| self.reshape = P.Reshape() | |||
| self.tile = P.Tile() | |||
| self.concat = P.Concat(axis=-1) | |||
| self.conf_training = is_training | |||
| def construct(self, x, input_shape): | |||
| num_batch = P.Shape()(x)[0] | |||
| grid_size = P.Shape()(x)[2:4] | |||
| # Reshape and transpose the feature to [n, grid_size[0], grid_size[1], 3, num_attrib] | |||
| prediction = P.Reshape()(x, (num_batch, | |||
| self.num_anchors_per_scale, | |||
| self.num_attrib, | |||
| grid_size[0], | |||
| grid_size[1])) | |||
| prediction = P.Transpose()(prediction, (0, 3, 4, 1, 2)) | |||
| range_x = range(grid_size[1]) | |||
| range_y = range(grid_size[0]) | |||
| grid_x = P.Cast()(F.tuple_to_array(range_x), ms.float32) | |||
| grid_y = P.Cast()(F.tuple_to_array(range_y), ms.float32) | |||
| # Tensor of shape [grid_size[0], grid_size[1], 1, 1] representing the coordinate of x/y axis for each grid | |||
| # [batch, gridx, gridy, 1, 1] | |||
| grid_x = self.tile(self.reshape(grid_x, (1, 1, -1, 1, 1)), (1, grid_size[0], 1, 1, 1)) | |||
| grid_y = self.tile(self.reshape(grid_y, (1, -1, 1, 1, 1)), (1, 1, grid_size[1], 1, 1)) | |||
| # Shape is [grid_size[0], grid_size[1], 1, 2] | |||
| grid = self.concat((grid_x, grid_y)) | |||
| box_xy = prediction[:, :, :, :, :2] | |||
| box_wh = prediction[:, :, :, :, 2:4] | |||
| box_confidence = prediction[:, :, :, :, 4:5] | |||
| box_probs = prediction[:, :, :, :, 5:] | |||
| # gridsize1 is x | |||
| # gridsize0 is y | |||
| box_xy = (self.sigmoid(box_xy) + grid) / P.Cast()(F.tuple_to_array((grid_size[1], grid_size[0])), ms.float32) | |||
| # box_wh is w->h | |||
| box_wh = P.Exp()(box_wh) * self.anchors / input_shape | |||
| box_confidence = self.sigmoid(box_confidence) | |||
| box_probs = self.sigmoid(box_probs) | |||
| if self.conf_training: | |||
| return grid, prediction, box_xy, box_wh | |||
| return self.concat((box_xy, box_wh, box_confidence, box_probs)) | |||
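| The decode above follows the standard YOLOv3 equations: box_xy = (sigmoid(t_xy) + grid) / grid_size and box_wh = exp(t_wh) * anchors / input_shape. A plain-numpy illustration for one anchor at one grid cell (a sketch, not part of the network code): | |||
| ```python | |||
| import numpy as np | |||
| def decode_one(t_xy, t_wh, cx, cy, anchor, grid_size, input_shape): | |||
|     sigmoid = lambda v: 1. / (1. + np.exp(-v)) | |||
|     box_xy = (sigmoid(np.asarray(t_xy)) + (cx, cy)) / grid_size           # normalized center | |||
|     box_wh = np.exp(np.asarray(t_wh)) * np.asarray(anchor) / input_shape  # normalized w, h | |||
|     return box_xy, box_wh | |||
| xy, wh = decode_one([0., 0.], [0., 0.], cx=6, cy=6, | |||
|                     anchor=(116, 90), grid_size=13, input_shape=416) | |||
| # zero logits: center half a cell past (6, 6); size exactly one anchor box | |||
| assert np.allclose(xy, (6.5 / 13, 6.5 / 13)) | |||
| assert np.allclose(wh, (116 / 416, 90 / 416)) | |||
| ``` | |||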
| class Iou(nn.Cell): | |||
| """Calculate the iou of boxes""" | |||
| def __init__(self): | |||
| super(Iou, self).__init__() | |||
| self.min = P.Minimum() | |||
| self.max = P.Maximum() | |||
| def construct(self, box1, box2): | |||
| # box1: pred_box [batch, gx, gy, anchors, 1, 4] ->4: [x_center, y_center, w, h] | |||
| # box2: gt_box [batch, 1, 1, 1, maxbox, 4] | |||
| # convert to topLeft and rightDown | |||
| box1_xy = box1[:, :, :, :, :, :2] | |||
| box1_wh = box1[:, :, :, :, :, 2:4] | |||
| box1_mins = box1_xy - box1_wh / F.scalar_to_array(2.0) # topLeft | |||
| box1_maxs = box1_xy + box1_wh / F.scalar_to_array(2.0) # rightDown | |||
| box2_xy = box2[:, :, :, :, :, :2] | |||
| box2_wh = box2[:, :, :, :, :, 2:4] | |||
| box2_mins = box2_xy - box2_wh / F.scalar_to_array(2.0) | |||
| box2_maxs = box2_xy + box2_wh / F.scalar_to_array(2.0) | |||
| intersect_mins = self.max(box1_mins, box2_mins) | |||
| intersect_maxs = self.min(box1_maxs, box2_maxs) | |||
| intersect_wh = self.max(intersect_maxs - intersect_mins, F.scalar_to_array(0.0)) | |||
| # P.Squeeze: for efficient slicing | |||
| intersect_area = P.Squeeze(-1)(intersect_wh[:, :, :, :, :, 0:1]) * \ | |||
| P.Squeeze(-1)(intersect_wh[:, :, :, :, :, 1:2]) | |||
| box1_area = P.Squeeze(-1)(box1_wh[:, :, :, :, :, 0:1]) * P.Squeeze(-1)(box1_wh[:, :, :, :, :, 1:2]) | |||
| box2_area = P.Squeeze(-1)(box2_wh[:, :, :, :, :, 0:1]) * P.Squeeze(-1)(box2_wh[:, :, :, :, :, 1:2]) | |||
| iou = intersect_area / (box1_area + box2_area - intersect_area) | |||
| # iou : [batch, gx, gy, anchors, maxboxes] | |||
| return iou | |||
| class YoloLossBlock(nn.Cell): | |||
| """ | |||
| Loss block cell of YOLOV3 network. | |||
| """ | |||
| def __init__(self, scale, config=ConfigYOLOV3DarkNet53()): | |||
| super(YoloLossBlock, self).__init__() | |||
| self.config = config | |||
| if scale == 's': | |||
| # anchor mask | |||
| idx = (0, 1, 2) | |||
| elif scale == 'm': | |||
| idx = (3, 4, 5) | |||
| elif scale == 'l': | |||
| idx = (6, 7, 8) | |||
| else: | |||
| raise KeyError("Invalid scale value for DetectionBlock") | |||
| self.anchors = Tensor([self.config.anchor_scales[i] for i in idx], ms.float32) | |||
| self.ignore_threshold = Tensor(self.config.ignore_threshold, ms.float32) | |||
| self.concat = P.Concat(axis=-1) | |||
| self.iou = Iou() | |||
| self.reduce_max = P.ReduceMax(keep_dims=False) | |||
| self.xy_loss = XYLoss() | |||
| self.wh_loss = WHLoss() | |||
| self.confidenceLoss = ConfidenceLoss() | |||
| self.classLoss = ClassLoss() | |||
| def construct(self, grid, prediction, pred_xy, pred_wh, y_true, gt_box, input_shape): | |||
| # prediction : origin output from yolo | |||
| # pred_xy: (sigmoid(xy)+grid)/grid_size | |||
| # pred_wh: (exp(wh)*anchors)/input_shape | |||
| # y_true : after normalize | |||
| # gt_box: [batch, maxboxes, xyhw] after normalize | |||
| object_mask = y_true[:, :, :, :, 4:5] | |||
| class_probs = y_true[:, :, :, :, 5:] | |||
| grid_shape = P.Shape()(prediction)[1:3] | |||
| grid_shape = P.Cast()(F.tuple_to_array(grid_shape[::-1]), ms.float32) | |||
| pred_boxes = self.concat((pred_xy, pred_wh)) | |||
| true_xy = y_true[:, :, :, :, :2] * grid_shape - grid | |||
| true_wh = y_true[:, :, :, :, 2:4] | |||
| true_wh = P.Select()(P.Equal()(true_wh, 0.0), | |||
| P.Fill()(P.DType()(true_wh), | |||
| P.Shape()(true_wh), 1.0), | |||
| true_wh) | |||
| true_wh = P.Log()(true_wh / self.anchors * input_shape) | |||
| # 2 - w*h: down-weight large boxes so that small objects, which need more precision, get a larger loss scale | |||
| box_loss_scale = 2 - y_true[:, :, :, :, 2:3] * y_true[:, :, :, :, 3:4] | |||
| gt_shape = P.Shape()(gt_box) | |||
| gt_box = P.Reshape()(gt_box, (gt_shape[0], 1, 1, 1, gt_shape[1], gt_shape[2])) | |||
| # add one more dimension for broadcast | |||
| iou = self.iou(P.ExpandDims()(pred_boxes, -2), gt_box) | |||
| # gt_box is x,y,h,w after normalize | |||
| # [batch, grid[0], grid[1], num_anchor, num_gt] | |||
| best_iou = self.reduce_max(iou, -1) | |||
| # [batch, grid[0], grid[1], num_anchor] | |||
| # ignore_mask marks positions whose best IOU falls below the threshold | |||
| ignore_mask = best_iou < self.ignore_threshold | |||
| ignore_mask = P.Cast()(ignore_mask, ms.float32) | |||
| ignore_mask = P.ExpandDims()(ignore_mask, -1) | |||
| # backpropagating through ignore_mask would spend a lot of time in MaximumGrad and MinimumGrad, | |||
| # so we turn off its gradient | |||
| ignore_mask = F.stop_gradient(ignore_mask) | |||
| xy_loss = self.xy_loss(object_mask, box_loss_scale, prediction[:, :, :, :, :2], true_xy) | |||
| wh_loss = self.wh_loss(object_mask, box_loss_scale, prediction[:, :, :, :, 2:4], true_wh) | |||
| confidence_loss = self.confidenceLoss(object_mask, prediction[:, :, :, :, 4:5], ignore_mask) | |||
| class_loss = self.classLoss(object_mask, prediction[:, :, :, :, 5:], class_probs) | |||
| loss = xy_loss + wh_loss + confidence_loss + class_loss | |||
| batch_size = P.Shape()(prediction)[0] | |||
| return loss / batch_size | |||
| class YOLOV3DarkNet53(nn.Cell): | |||
| """ | |||
| Darknet based YOLOV3 network. | |||
| Args: | |||
| is_training: Bool. Whether train or not. | |||
| Returns: | |||
| Cell, cell instance of Darknet based YOLOV3 neural network. | |||
| Examples: | |||
| YOLOV3DarkNet53(True) | |||
| """ | |||
| def __init__(self, is_training): | |||
| super(YOLOV3DarkNet53, self).__init__() | |||
| self.config = ConfigYOLOV3DarkNet53() | |||
| # YOLOv3 network | |||
| self.feature_map = YOLOv3(backbone=DarkNet(ResidualBlock, self.config.backbone_layers, | |||
| self.config.backbone_input_shape, | |||
| self.config.backbone_shape, | |||
| detect=True), | |||
| backbone_shape=self.config.backbone_shape, | |||
| out_channel=self.config.out_channel) | |||
| # prediction on the default anchor boxes | |||
| self.detect_1 = DetectionBlock('l', is_training=is_training) | |||
| self.detect_2 = DetectionBlock('m', is_training=is_training) | |||
| self.detect_3 = DetectionBlock('s', is_training=is_training) | |||
| def construct(self, x, input_shape): | |||
| big_object_output, medium_object_output, small_object_output = self.feature_map(x) | |||
| output_big = self.detect_1(big_object_output, input_shape) | |||
| output_me = self.detect_2(medium_object_output, input_shape) | |||
| output_small = self.detect_3(small_object_output, input_shape) | |||
| # 'big' is the output for large objects, produced from the smallest feature map | |||
| return output_big, output_me, output_small | |||
| class YoloWithLossCell(nn.Cell): | |||
| """YOLOV3 loss.""" | |||
| def __init__(self, network): | |||
| super(YoloWithLossCell, self).__init__() | |||
| self.yolo_network = network | |||
| self.config = ConfigYOLOV3DarkNet53() | |||
| self.loss_big = YoloLossBlock('l', self.config) | |||
| self.loss_me = YoloLossBlock('m', self.config) | |||
| self.loss_small = YoloLossBlock('s', self.config) | |||
| def construct(self, x, y_true_0, y_true_1, y_true_2, gt_0, gt_1, gt_2, input_shape): | |||
| yolo_out = self.yolo_network(x, input_shape) | |||
| loss_l = self.loss_big(*yolo_out[0], y_true_0, gt_0, input_shape) | |||
| loss_m = self.loss_me(*yolo_out[1], y_true_1, gt_1, input_shape) | |||
| loss_s = self.loss_small(*yolo_out[2], y_true_2, gt_2, input_shape) | |||
| return loss_l + loss_m + loss_s | |||
| class TrainingWrapper(nn.Cell): | |||
| """Training wrapper.""" | |||
| def __init__(self, network, optimizer, sens=1.0): | |||
| super(TrainingWrapper, self).__init__(auto_prefix=False) | |||
| self.network = network | |||
| self.weights = optimizer.parameters | |||
| self.optimizer = optimizer | |||
| self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) | |||
| self.sens = sens | |||
| self.reducer_flag = False | |||
| self.grad_reducer = None | |||
| self.parallel_mode = context.get_auto_parallel_context("parallel_mode") | |||
| if self.parallel_mode in [ms.ParallelMode.DATA_PARALLEL, ms.ParallelMode.HYBRID_PARALLEL]: | |||
| self.reducer_flag = True | |||
| if self.reducer_flag: | |||
| mean = context.get_auto_parallel_context("mirror_mean") | |||
| if auto_parallel_context().get_device_num_is_set(): | |||
| degree = context.get_auto_parallel_context("device_num") | |||
| else: | |||
| degree = get_group_size() | |||
| self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree) | |||
| def construct(self, *args): | |||
| weights = self.weights | |||
| loss = self.network(*args) | |||
| sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) | |||
| grads = self.grad(self.network, weights)(*args, sens) | |||
| if self.reducer_flag: | |||
| grads = self.grad_reducer(grads) | |||
| return F.depend(loss, self.optimizer(grads)) | |||
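| Typical assembly of the three cells above (a sketch; 'opt' stands for any MindSpore optimizer built over the network parameters): | |||
| ```python | |||
| net = YOLOV3DarkNet53(is_training=True) | |||
| net = YoloWithLossCell(net) | |||
| net = TrainingWrapper(net, opt)   # returns the loss and applies gradients each step | |||
| ``` | |||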
| @@ -0,0 +1,184 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """YOLOV3 dataset.""" | |||
| import os | |||
| from PIL import Image | |||
| from pycocotools.coco import COCO | |||
| import mindspore.dataset as de | |||
| import mindspore.dataset.transforms.vision.c_transforms as CV | |||
| from src.distributed_sampler import DistributedSampler | |||
| from src.transforms import reshape_fn, MultiScaleTrans | |||
| min_keypoints_per_image = 10 | |||
| def _has_only_empty_bbox(anno): | |||
| return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno) | |||
| def _count_visible_keypoints(anno): | |||
| return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno) | |||
| def has_valid_annotation(anno): | |||
| """Check annotation file.""" | |||
| # if it's empty, there is no annotation | |||
| if not anno: | |||
| return False | |||
| # if all boxes have close to zero area, there is no annotation | |||
| if _has_only_empty_bbox(anno): | |||
| return False | |||
| # keypoint tasks have a slightly different criterion for considering | |||
| # whether an annotation is valid | |||
| if "keypoints" not in anno[0]: | |||
| return True | |||
| # for keypoint detection tasks, only consider valid images those | |||
| # containing at least min_keypoints_per_image | |||
| if _count_visible_keypoints(anno) >= min_keypoints_per_image: | |||
| return True | |||
| return False | |||
| class COCOYoloDataset: | |||
| """YOLOV3 Dataset for COCO.""" | |||
| def __init__(self, root, ann_file, remove_images_without_annotations=True, | |||
| filter_crowd_anno=True, is_training=True): | |||
| self.coco = COCO(ann_file) | |||
| self.root = root | |||
| self.img_ids = list(sorted(self.coco.imgs.keys())) | |||
| self.filter_crowd_anno = filter_crowd_anno | |||
| self.is_training = is_training | |||
| # filter images without any annotations | |||
| if remove_images_without_annotations: | |||
| img_ids = [] | |||
| for img_id in self.img_ids: | |||
| ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None) | |||
| anno = self.coco.loadAnns(ann_ids) | |||
| if has_valid_annotation(anno): | |||
| img_ids.append(img_id) | |||
| self.img_ids = img_ids | |||
| self.categories = {cat["id"]: cat["name"] for cat in self.coco.cats.values()} | |||
| self.cat_ids_to_continuous_ids = { | |||
| v: i for i, v in enumerate(self.coco.getCatIds()) | |||
| } | |||
| self.continuous_ids_cat_ids = { | |||
| v: k for k, v in self.cat_ids_to_continuous_ids.items() | |||
| } | |||
| def __getitem__(self, index): | |||
| """ | |||
| Args: | |||
| index (int): Index | |||
| Returns: | |||
| (img, target) (tuple): target is a dictionary containing "bbox", "segmentation" or "keypoints", | |||
| generated from the image's annotation. img is a PIL image. | |||
| """ | |||
| coco = self.coco | |||
| img_id = self.img_ids[index] | |||
| img_path = coco.loadImgs(img_id)[0]["file_name"] | |||
| img = Image.open(os.path.join(self.root, img_path)).convert("RGB") | |||
| if not self.is_training: | |||
| return img, img_id | |||
| ann_ids = coco.getAnnIds(imgIds=img_id) | |||
| target = coco.loadAnns(ann_ids) | |||
| # filter crowd annotations | |||
| if self.filter_crowd_anno: | |||
| annos = [anno for anno in target if anno["iscrowd"] == 0] | |||
| else: | |||
| annos = [anno for anno in target] | |||
| target = {} | |||
| boxes = [anno["bbox"] for anno in annos] | |||
| target["bboxes"] = boxes | |||
| classes = [anno["category_id"] for anno in annos] | |||
| classes = [self.cat_ids_to_continuous_ids[cl] for cl in classes] | |||
| target["labels"] = classes | |||
| bboxes = target['bboxes'] | |||
| labels = target['labels'] | |||
| out_target = [] | |||
| for bbox, label in zip(bboxes, labels): | |||
| tmp = [] | |||
| # convert to [x_min y_min x_max y_max] | |||
| bbox = self._convertTopDown(bbox) | |||
| tmp.extend(bbox) | |||
| tmp.append(int(label)) | |||
| # tmp [x_min y_min x_max y_max, label] | |||
| out_target.append(tmp) | |||
| return img, out_target | |||
| def __len__(self): | |||
| return len(self.img_ids) | |||
| def _convertTopDown(self, bbox): | |||
| x_min = bbox[0] | |||
| y_min = bbox[1] | |||
| w = bbox[2] | |||
| h = bbox[3] | |||
| return [x_min, y_min, x_min+w, y_min+h] | |||
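| Standalone use of the dataset class (the paths are placeholders for the COCO2014 layout described in the README): | |||
| ```python | |||
| dataset = COCOYoloDataset(root='/path/to/coco/train2014', | |||
|                           ann_file='/path/to/coco/annotations/instances_train2014.json') | |||
| img, target = dataset[0]   # PIL image and [[x_min, y_min, x_max, y_max, label], ...] | |||
| ``` | |||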
| def create_yolo_dataset(image_dir, anno_path, batch_size, max_epoch, device_num, rank, | |||
| config=None, is_training=True, shuffle=True): | |||
| """Create dataset for YOLOV3.""" | |||
| if is_training: | |||
| filter_crowd = True | |||
| remove_empty_anno = True | |||
| else: | |||
| filter_crowd = False | |||
| remove_empty_anno = False | |||
| yolo_dataset = COCOYoloDataset(root=image_dir, ann_file=anno_path, filter_crowd_anno=filter_crowd, | |||
| remove_images_without_annotations=remove_empty_anno, is_training=is_training) | |||
| distributed_sampler = DistributedSampler(len(yolo_dataset), device_num, rank, shuffle=shuffle) | |||
| hwc_to_chw = CV.HWC2CHW() | |||
| config.dataset_size = len(yolo_dataset) | |||
| num_parallel_workers1 = int(64 / device_num) | |||
| num_parallel_workers2 = int(16 / device_num) | |||
| if is_training: | |||
| multi_scale_trans = MultiScaleTrans(config, device_num) | |||
| if device_num != 8: | |||
| ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "annotation"], | |||
| num_parallel_workers=num_parallel_workers1, | |||
| sampler=distributed_sampler) | |||
| ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=['image', 'annotation'], | |||
| num_parallel_workers=num_parallel_workers2, drop_remainder=True) | |||
| else: | |||
| ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "annotation"], sampler=distributed_sampler) | |||
| ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=['image', 'annotation'], | |||
| num_parallel_workers=8, drop_remainder=True) | |||
| else: | |||
| ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "img_id"], | |||
| sampler=distributed_sampler) | |||
| compose_map_func = (lambda image, img_id: reshape_fn(image, img_id, config)) | |||
| ds = ds.map(input_columns=["image", "img_id"], | |||
| output_columns=["image", "image_shape", "img_id"], | |||
| columns_order=["image", "image_shape", "img_id"], | |||
| operations=compose_map_func, num_parallel_workers=8) | |||
| ds = ds.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=8) | |||
| ds = ds.batch(batch_size, drop_remainder=True) | |||
| ds = ds.repeat(max_epoch) | |||
| return ds, len(yolo_dataset) | |||
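| Wiring it up for single-device training (a sketch; the paths are placeholders and 'config' is a prepared ConfigYOLOV3DarkNet53 instance): | |||
| ```python | |||
| ds, data_size = create_yolo_dataset('/path/to/coco/train2014', | |||
|                                     '/path/to/coco/annotations/instances_train2014.json', | |||
|                                     batch_size=32, max_epoch=320, | |||
|                                     device_num=1, rank=0, | |||
|                                     config=config, is_training=True) | |||
| ``` | |||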
| @@ -0,0 +1,362 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """YoloV3 train.""" | |||
| import os | |||
| import time | |||
| import argparse | |||
| import datetime | |||
| from mindspore import ParallelMode | |||
| from mindspore.nn.optim.momentum import Momentum | |||
| from mindspore import Tensor | |||
| from mindspore import context | |||
| from mindspore.communication.management import init, get_rank, get_group_size | |||
| from mindspore.train.callback import ModelCheckpoint, RunContext | |||
| from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig | |||
| import mindspore as ms | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from mindspore.train.quant import quant | |||
| from src.yolo import YOLOV3DarkNet53, YoloWithLossCell, TrainingWrapper | |||
| from src.logger import get_logger | |||
| from src.util import AverageMeter, load_backbone, get_param_groups | |||
| from src.lr_scheduler import warmup_step_lr, warmup_cosine_annealing_lr, \ | |||
| warmup_cosine_annealing_lr_V2, warmup_cosine_annealing_lr_sample | |||
| from src.yolo_dataset import create_yolo_dataset | |||
| from src.initializer import default_recurisive_init | |||
| from src.config import ConfigYOLOV3DarkNet53 | |||
| from src.transforms import batch_preprocess_true_box, batch_preprocess_true_box_single | |||
| from src.util import ShapeRecord | |||
| devid = int(os.getenv('DEVICE_ID')) | |||
| context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, | |||
| device_target="Ascend", save_graphs=True, device_id=devid) | |||
| def parse_args(): | |||
| """Parse train arguments.""" | |||
| parser = argparse.ArgumentParser('mindspore coco training') | |||
| # dataset related | |||
| parser.add_argument('--data_dir', type=str, default='', help='train data dir') | |||
| parser.add_argument('--per_batch_size', default=32, type=int, help='batch size per device') | |||
| # network related | |||
| parser.add_argument('--pretrained_backbone', default='', type=str, help='model_path, local pretrained backbone' | |||
| ' model to load') | |||
| parser.add_argument('--resume_yolov3', default='', type=str, help='path of pretrained yolov3') | |||
| # optimizer and lr related | |||
| parser.add_argument('--lr_scheduler', default='exponential', type=str, | |||
| help='lr-scheduler, option type: exponential, cosine_annealing') | |||
| parser.add_argument('--lr', default=0.001, type=float, help='learning rate of the training') | |||
| parser.add_argument('--lr_epochs', type=str, default='220,250', help='epochs at which lr decays') | |||
| parser.add_argument('--lr_gamma', type=float, default=0.1, | |||
| help='decay factor for the exponential lr_scheduler') | |||
| parser.add_argument('--eta_min', type=float, default=0., help='eta_min in cosine_annealing scheduler') | |||
| parser.add_argument('--T_max', type=int, default=320, help='T_max in cosine_annealing scheduler') | |||
| parser.add_argument('--max_epoch', type=int, default=320, help='max epoch num to train the model') | |||
| parser.add_argument('--warmup_epochs', default=0, type=float, help='warmup epoch') | |||
| parser.add_argument('--weight_decay', type=float, default=0.0005, help='weight decay') | |||
| parser.add_argument('--momentum', type=float, default=0.9, help='momentum') | |||
| # loss related | |||
| parser.add_argument('--loss_scale', type=int, default=1024, help='static loss scale') | |||
| parser.add_argument('--label_smooth', type=int, default=0, help='whether to use label smoothing in the CE loss') | |||
| parser.add_argument('--label_smooth_factor', type=float, default=0.1, help='smoothing strength applied to the one-hot labels') | |||
| # logging related | |||
| parser.add_argument('--log_interval', type=int, default=100, help='logging interval') | |||
| parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location') | |||
| parser.add_argument('--ckpt_interval', type=int, default=None, help='checkpoint save interval in steps (defaults to one epoch)') | |||
| parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master rank only (1) or on all ranks (0)') | |||
| # distributed related | |||
| parser.add_argument('--is_distributed', type=int, default=1, help='whether to run distributed training') | |||
| parser.add_argument('--rank', type=int, default=0, help='local rank of distributed') | |||
| parser.add_argument('--group_size', type=int, default=1, help='world size of distributed') | |||
| # roma obs | |||
| parser.add_argument('--train_url', type=str, default="", help='train url') | |||
| # profiler init | |||
| parser.add_argument('--need_profiler', type=int, default=0, help='whether to use the profiler') | |||
| # reset default config | |||
| parser.add_argument('--training_shape', type=str, default="", help='fixed training shape') | |||
| parser.add_argument('--resize_rate', type=int, default=None, help='resize rate for multi-scale training') | |||
| args, _ = parser.parse_known_args() | |||
| if args.lr_scheduler == 'cosine_annealing' and args.max_epoch > args.T_max: | |||
| args.T_max = args.max_epoch | |||
| args.lr_epochs = list(map(int, args.lr_epochs.split(','))) | |||
| args.data_root = os.path.join(args.data_dir, 'train2014') | |||
| args.annFile = os.path.join(args.data_dir, 'annotations/instances_train2014.json') | |||
| return args | |||
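| Since parse_known_args() is used, unrecognized flags injected by a launcher are silently dropped rather than raising an error. A hypothetical invocation sketch: | |||
| ```python | |||
| # Hypothetical invocation; the unknown flag is ignored by parse_known_args(). | |||
| import sys | |||
| sys.argv = ['train.py', '--data_dir', '/path/to/dataset', | |||
|             '--lr_scheduler', 'cosine_annealing', | |||
|             '--some_launcher_flag', '1']   # unknown: silently dropped | |||
| args = parse_args() | |||
| assert args.annFile.endswith('annotations/instances_train2014.json') | |||
| ``` | |||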
| def convert_training_shape(args): | |||
| """Convert the fixed training shape argument into an [h, w] pair.""" | |||
| training_shape = [int(args.training_shape), int(args.training_shape)] | |||
| return training_shape | |||
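| For example, passing --training_shape 416 pins multi-scale training to a single 416x416 shape (hypothetical _Args stub below): | |||
| ```python | |||
| # Hypothetical stub showing what the helper produces: | |||
| class _Args: | |||
|     training_shape = "416" | |||
| assert convert_training_shape(_Args()) == [416, 416] | |||
| # train() then sets config.multi_scale = [[416, 416]] | |||
| ``` | |||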
| def train(): | |||
| """Train function.""" | |||
| args = parse_args() | |||
| # init distributed | |||
| if args.is_distributed: | |||
| init() | |||
| args.rank = get_rank() | |||
| args.group_size = get_group_size() | |||
| # select whether the master rank or all ranks save checkpoints; compatible with model parallel | |||
| args.rank_save_ckpt_flag = 0 | |||
| if args.is_save_on_master: | |||
| if args.rank == 0: | |||
| args.rank_save_ckpt_flag = 1 | |||
| else: | |||
| args.rank_save_ckpt_flag = 1 | |||
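| The nested flag logic above reduces to a single expression, shown here as an equivalent sketch: | |||
| ```python | |||
| # Equivalent to the if/else above: save on this rank when saving on all | |||
| # ranks, or when this is the master rank. | |||
| args.rank_save_ckpt_flag = int(not args.is_save_on_master or args.rank == 0) | |||
| ``` | |||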
| # logger | |||
| args.outputs_dir = os.path.join(args.ckpt_path, | |||
| datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) | |||
| args.logger = get_logger(args.outputs_dir, args.rank) | |||
| args.logger.save_args(args) | |||
| if args.need_profiler: | |||
| from mindinsight.profiler.profiling import Profiler | |||
| profiler = Profiler(output_path=args.outputs_dir, is_detail=True, is_show_op_path=True) | |||
| loss_meter = AverageMeter('loss') | |||
| context.reset_auto_parallel_context() | |||
| if args.is_distributed: | |||
| parallel_mode = ParallelMode.DATA_PARALLEL | |||
| degree = get_group_size() | |||
| else: | |||
| parallel_mode = ParallelMode.STAND_ALONE | |||
| degree = 1 | |||
| context.set_auto_parallel_context(parallel_mode=parallel_mode, mirror_mean=True, device_num=degree) | |||
| network = YOLOV3DarkNet53(is_training=True) | |||
| # default initialization is Kaiming normal | |||
| default_recurisive_init(network) | |||
| if args.pretrained_backbone: | |||
| network = load_backbone(network, args.pretrained_backbone, args) | |||
| args.logger.info('load pre-trained backbone {} into network'.format(args.pretrained_backbone)) | |||
| else: | |||
| args.logger.info('No pre-trained backbone is loaded, please be careful') | |||
| if args.resume_yolov3: | |||
| param_dict = load_checkpoint(args.resume_yolov3) | |||
| param_dict_new = {} | |||
| for key, values in param_dict.items(): | |||
| args.logger.info('ckpt param name = {}'.format(key)) | |||
| if key.startswith('moments.') or key.startswith('global_') or \ | |||
| key.startswith('learning_rate') or key.startswith('momentum'): | |||
| continue | |||
| elif key.startswith('yolo_network.'): | |||
| key_new = key[13:] | |||
| if key_new.endswith('1.beta'): | |||
| key_new = key_new.replace('1.beta', 'batchnorm.beta') | |||
| if key_new.endswith('1.gamma'): | |||
| key_new = key_new.replace('1.gamma', 'batchnorm.gamma') | |||
| if key_new.endswith('1.moving_mean'): | |||
| key_new = key_new.replace('1.moving_mean', 'batchnorm.moving_mean') | |||
| if key_new.endswith('1.moving_variance'): | |||
| key_new = key_new.replace('1.moving_variance', 'batchnorm.moving_variance') | |||
| if key_new.endswith('.weight'): | |||
| if key_new.endswith('0.weight'): | |||
| key_new = key_new.replace('0.weight', 'conv.weight') | |||
| else: | |||
| key_new = key_new.replace('.weight', '.conv.weight') | |||
| if key_new.endswith('.bias'): | |||
| key_new = key_new.replace('.bias', '.conv.bias') | |||
| param_dict_new[key_new] = values | |||
| args.logger.info('in resume {}'.format(key_new)) | |||
| else: | |||
| param_dict_new[key] = values | |||
| args.logger.info('in resume {}'.format(key)) | |||
| args.logger.info('resume finished') | |||
| for _, param in network.parameters_and_names(): | |||
| args.logger.info('network param name = {}'.format(param.name)) | |||
| if param.name not in param_dict_new: | |||
| args.logger.info('unmatched param name = {}'.format(param.name)) | |||
| load_param_into_net(network, param_dict_new) | |||
| args.logger.info('load_model {} success'.format(args.resume_yolov3)) | |||
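| The remapping strips the 'yolo_network.' prefix and renames the fused-cell suffixes so that checkpoint keys from a plain network line up with the Conv2dBnAct parameter names. A hypothetical key traced through the rules above: | |||
| ```python | |||
| # Hypothetical checkpoint key traced through the renaming rules: | |||
| old = 'yolo_network.feature_map.backbone.conv0.1.gamma' | |||
| # key[13:] drops the 'yolo_network.' prefix, then '1.gamma' -> 'batchnorm.gamma' | |||
| new = 'feature_map.backbone.conv0.batchnorm.gamma' | |||
| # likewise '0.weight' -> 'conv.weight' and '.bias' -> '.conv.bias' | |||
| ``` | |||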
| config = ConfigYOLOV3DarkNet53() | |||
| # convert fusion network to quantization aware network | |||
| if config.quantization_aware: | |||
| network = quant.convert_quant_network(network, | |||
| bn_fold=True, | |||
| per_channel=[True, False], | |||
| symmetric=[True, False]) | |||
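| The two-element lists are understood to configure weights and activations respectively in this MindSpore version; spelled out below (an assumption worth checking against the quant API docs): | |||
| ```python | |||
| # Same call with the assumed list semantics spelled out: | |||
| network = quant.convert_quant_network( | |||
|     network, | |||
|     bn_fold=True,                # fold BatchNorm into the preceding conv | |||
|     per_channel=[True, False],   # weights per channel, activations per tensor | |||
|     symmetric=[True, False])     # weights symmetric, activations asymmetric | |||
| ``` | |||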
| network = YoloWithLossCell(network) | |||
| args.logger.info('finished getting network') | |||
| config.label_smooth = args.label_smooth | |||
| config.label_smooth_factor = args.label_smooth_factor | |||
| if args.training_shape: | |||
| config.multi_scale = [convert_training_shape(args)] | |||
| if args.resize_rate: | |||
| config.resize_rate = args.resize_rate | |||
| ds, data_size = create_yolo_dataset(image_dir=args.data_root, anno_path=args.annFile, is_training=True, | |||
| batch_size=args.per_batch_size, max_epoch=args.max_epoch, | |||
| device_num=args.group_size, rank=args.rank, config=config) | |||
| args.logger.info('Finished loading dataset') | |||
| args.steps_per_epoch = int(data_size / args.per_batch_size / args.group_size) | |||
| if not args.ckpt_interval: | |||
| args.ckpt_interval = args.steps_per_epoch | |||
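| With hypothetical COCO2014-scale numbers (train2014 holds about 82,783 images), the arithmetic works out as follows: | |||
| ```python | |||
| # Hypothetical sizes: ~82783 train2014 images, batch 32, 8 devices. | |||
| data_size, per_batch_size, group_size = 82783, 32, 8 | |||
| steps_per_epoch = int(data_size / per_batch_size / group_size)  # 323 | |||
| # with --ckpt_interval unset, a checkpoint is saved once per epoch (every 323 steps) | |||
| ``` | |||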
| # lr scheduler | |||
| if args.lr_scheduler == 'exponential': | |||
| lr = warmup_step_lr(args.lr, | |||
| args.lr_epochs, | |||
| args.steps_per_epoch, | |||
| args.warmup_epochs, | |||
| args.max_epoch, | |||
| gamma=args.lr_gamma, | |||
| ) | |||
| elif args.lr_scheduler == 'cosine_annealing': | |||
| lr = warmup_cosine_annealing_lr(args.lr, | |||
| args.steps_per_epoch, | |||
| args.warmup_epochs, | |||
| args.max_epoch, | |||
| args.T_max, | |||
| args.eta_min) | |||
| elif args.lr_scheduler == 'cosine_annealing_V2': | |||
| lr = warmup_cosine_annealing_lr_V2(args.lr, | |||
| args.steps_per_epoch, | |||
| args.warmup_epochs, | |||
| args.max_epoch, | |||
| args.T_max, | |||
| args.eta_min) | |||
| elif args.lr_scheduler == 'cosine_annealing_sample': | |||
| lr = warmup_cosine_annealing_lr_sample(args.lr, | |||
| args.steps_per_epoch, | |||
| args.warmup_epochs, | |||
| args.max_epoch, | |||
| args.T_max, | |||
| args.eta_min) | |||
| else: | |||
| raise NotImplementedError(args.lr_scheduler) | |||
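| The schedule helpers live in src/lr_scheduler.py and are not shown in this diff; here is a minimal sketch, assuming warmup_cosine_annealing_lr performs a linear warmup followed by standard per-epoch cosine annealing: | |||
| ```python | |||
| import numpy as np | |||
| def warmup_cosine_annealing_sketch(lr_max, steps_per_epoch, warmup_epochs, | |||
|                                    max_epoch, t_max, eta_min): | |||
|     """Minimal sketch of an assumed warmup + cosine-annealing schedule.""" | |||
|     total_steps = steps_per_epoch * max_epoch | |||
|     warmup_steps = int(steps_per_epoch * warmup_epochs) | |||
|     lr = np.zeros(total_steps, dtype=np.float32) | |||
|     for step in range(total_steps): | |||
|         if step < warmup_steps: | |||
|             lr[step] = lr_max * (step + 1) / warmup_steps        # linear warmup | |||
|         else: | |||
|             epoch = step // steps_per_epoch | |||
|             lr[step] = eta_min + (lr_max - eta_min) * \ | |||
|                 (1 + np.cos(np.pi * epoch / t_max)) / 2          # cosine annealing | |||
|     return lr | |||
| ``` | |||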
| opt = Momentum(params=get_param_groups(network), | |||
| learning_rate=Tensor(lr), | |||
| momentum=args.momentum, | |||
| weight_decay=args.weight_decay, | |||
| loss_scale=args.loss_scale) | |||
| network = TrainingWrapper(network, opt) | |||
| network.set_train() | |||
| if args.rank_save_ckpt_flag: | |||
| # checkpoint save | |||
| ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval | |||
| ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval, | |||
| keep_checkpoint_max=ckpt_max_num) | |||
| ckpt_cb = ModelCheckpoint(config=ckpt_config, | |||
| directory=args.outputs_dir, | |||
| prefix='{}'.format(args.rank)) | |||
| cb_params = _InternalCallbackParam() | |||
| cb_params.train_network = network | |||
| cb_params.epoch_num = ckpt_max_num | |||
| cb_params.cur_epoch_num = 1 | |||
| run_context = RunContext(cb_params) | |||
| ckpt_cb.begin(run_context) | |||
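| Because a hand-written loop replaces Model.train, the ModelCheckpoint callback is driven manually: begin() is called once here, and step_end() fires inside the loop with cb_params updated each iteration. The protocol, as a sketch using the names from this script: | |||
| ```python | |||
| # Manual callback protocol used by this script (sketch): | |||
| total_steps = args.max_epoch * args.steps_per_epoch | |||
| ckpt_cb.begin(run_context)              # once, before the loop | |||
| for step in range(1, total_steps + 1): | |||
|     cb_params.cur_step_num = step       # keep callback state current | |||
|     ckpt_cb.step_end(run_context)       # may write a checkpoint this step | |||
| ``` | |||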
| old_progress = -1 | |||
| t_end = time.time() | |||
| data_loader = ds.create_dict_iterator() | |||
| shape_record = ShapeRecord() | |||
| for i, data in enumerate(data_loader): | |||
| images = data["image"] | |||
| input_shape = images.shape[2:4] | |||
| args.logger.info('iter[{}], shape {}'.format(i, input_shape[0])) | |||
| shape_record.set(input_shape) | |||
| images = Tensor(images) | |||
| annos = data["annotation"] | |||
| if args.group_size == 1: | |||
| batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, batch_gt_box2 = \ | |||
| batch_preprocess_true_box(annos, config, input_shape) | |||
| else: | |||
| batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, batch_gt_box2 = \ | |||
| batch_preprocess_true_box_single(annos, config, input_shape) | |||
| batch_y_true_0 = Tensor(batch_y_true_0) | |||
| batch_y_true_1 = Tensor(batch_y_true_1) | |||
| batch_y_true_2 = Tensor(batch_y_true_2) | |||
| batch_gt_box0 = Tensor(batch_gt_box0) | |||
| batch_gt_box1 = Tensor(batch_gt_box1) | |||
| batch_gt_box2 = Tensor(batch_gt_box2) | |||
| input_shape = Tensor(tuple(input_shape[::-1]), ms.float32) | |||
| loss = network(images, batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, | |||
| batch_gt_box2, input_shape) | |||
| loss_meter.update(loss.asnumpy()) | |||
| if args.rank_save_ckpt_flag: | |||
| # ckpt progress | |||
| cb_params.cur_step_num = i + 1 # current step number | |||
| cb_params.batch_num = i + 2 | |||
| ckpt_cb.step_end(run_context) | |||
| if i % args.log_interval == 0: | |||
| time_used = time.time() - t_end | |||
| epoch = int(i / args.steps_per_epoch) | |||
| fps = args.per_batch_size * (i - old_progress) * args.group_size / time_used | |||
| if args.rank == 0: | |||
| args.logger.info( | |||
| 'epoch[{}], iter[{}], {}, {:.2f} imgs/sec, lr:{}'.format(epoch, i, loss_meter, fps, lr[i])) | |||
| t_end = time.time() | |||
| loss_meter.reset() | |||
| old_progress = i | |||
| if (i + 1) % args.steps_per_epoch == 0 and args.rank_save_ckpt_flag: | |||
| cb_params.cur_epoch_num += 1 | |||
| if args.need_profiler: | |||
| if i == 10: | |||
| profiler.analyse() | |||
| break | |||
| args.logger.info('==========end training===============') | |||
| if __name__ == "__main__": | |||
| train() | |||
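| The throughput log in the loop above reports images per second across all devices; a worked example with hypothetical numbers: | |||
| ```python | |||
| # Worked example of the fps formula from the training loop (hypothetical values): | |||
| per_batch_size, group_size = 32, 8 | |||
| i, old_progress = 199, 99      # 100 iterations since the last log | |||
| time_used = 25.6               # seconds elapsed in that window | |||
| fps = per_batch_size * (i - old_progress) * group_size / time_used  # 1000.0 imgs/sec | |||
| ``` | |||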