1. Adjust the file hierarchy of face_detection (scrfd now sits at the same level as the other newly added face_detection methods);
2. Add detection support for extremely large / rotated faces and update the new model;
3. Support loading a dataset for finetuning and evaluation;
4. Add a card_detection model that supports loading datasethub datasets for finetuning (see the usage sketch below the link)
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10244540
master
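As a rough usage sketch of what this change enables (the task name and model id here are illustrative placeholders, not taken from this diff):

from modelscope.pipelines import pipeline

# Inference with the new SCRFD-based card detector. The model id below is a
# hypothetical example; substitute the one actually published on the hub.
card_detector = pipeline(
    'card-detection', model='damo/cv_resnet_carddetection_scrfd34gkps')
result = card_detector('data/test/images/card_detection.jpg')
print(result)  # expected: boxes, scores and 4-corner keypoints per card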
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ecbc9d0827cfb92e93e7d75868b1724142685dc20d3b32023c3c657a7b688a9c
size 254845
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d510ab26ddc58ffea882c8ef850c1f9bd4444772f2bce7ebea3e76944536c3ae
size 48909
@@ -148,6 +148,7 @@ class Pipelines(object):
     salient_detection = 'u2net-salient-detection'
     image_classification = 'image-classification'
     face_detection = 'resnet-face-detection-scrfd10gkps'
+    card_detection = 'resnet-card-detection-scrfd34gkps'
     ulfd_face_detection = 'manual-face-detection-ulfd'
     facial_expression_recognition = 'vgg19-facial-expression-recognition-fer'
     retina_face_detection = 'resnet50-face-detection-retinaface'
@@ -270,6 +271,8 @@ class Trainers(object):
     image_portrait_enhancement = 'image-portrait-enhancement'
     video_summarization = 'video-summarization'
     movie_scene_segmentation = 'movie-scene-segmentation'
+    face_detection_scrfd = 'face-detection-scrfd'
+    card_detection_scrfd = 'card-detection-scrfd'
     image_inpainting = 'image-inpainting'

     # nlp trainers
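With the two trainer names registered above, a trainer can be built by key. A minimal sketch, assuming the usual ModelScope `build_trainer` entry point; the model id and work dir are hypothetical:

from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer

kwargs = dict(
    model='damo/cv_resnet_facedetection_scrfd10gkps',  # placeholder id
    work_dir='./scrfd_finetune')
trainer = build_trainer(
    name=Trainers.face_detection_scrfd, default_args=kwargs)
trainer.train()      # finetune on the configured dataset
trainer.evaluate()   # eval, as described in point 3 of the summary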
@@ -8,12 +8,14 @@ if TYPE_CHECKING:
     from .mtcnn import MtcnnFaceDetector
     from .retinaface import RetinaFaceDetection
     from .ulfd_slim import UlfdFaceDetector
+    from .scrfd import ScrfdDetect
 else:
     _import_structure = {
         'ulfd_slim': ['UlfdFaceDetector'],
         'retinaface': ['RetinaFaceDetection'],
         'mtcnn': ['MtcnnFaceDetector'],
-        'mogface': ['MogFaceDetector']
+        'mogface': ['MogFaceDetector'],
+        'scrfd': ['ScrfdDetect']
     }

     import sys
@@ -1,189 +0,0 @@
"""
The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at
https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/transforms.py
"""
import numpy as np
from mmdet.datasets.builder import PIPELINES
from numpy import random


@PIPELINES.register_module()
class RandomSquareCrop(object):
    """Random crop the image & bboxes, the cropped patches have minimum IoU
    requirement with original image & bboxes, the IoU threshold is randomly
    selected from min_ious.

    Args:
        min_ious (tuple): minimum IoU threshold for all intersections with
            bounding boxes
        min_crop_size (float): minimum crop's size (i.e. h,w := a*h, a*w,
            where a >= min_crop_size).

    Note:
        The keys for bboxes, labels and masks should be paired. That is, \
        `gt_bboxes` corresponds to `gt_labels` and `gt_masks`, and \
        `gt_bboxes_ignore` to `gt_labels_ignore` and `gt_masks_ignore`.
    """

    def __init__(self,
                 crop_ratio_range=None,
                 crop_choice=None,
                 bbox_clip_border=True):
        self.crop_ratio_range = crop_ratio_range
        self.crop_choice = crop_choice
        self.bbox_clip_border = bbox_clip_border
        assert (self.crop_ratio_range is None) ^ (self.crop_choice is None)
        if self.crop_ratio_range is not None:
            self.crop_ratio_min, self.crop_ratio_max = self.crop_ratio_range
        self.bbox2label = {
            'gt_bboxes': 'gt_labels',
            'gt_bboxes_ignore': 'gt_labels_ignore'
        }
        self.bbox2mask = {
            'gt_bboxes': 'gt_masks',
            'gt_bboxes_ignore': 'gt_masks_ignore'
        }

    def __call__(self, results):
        """Call function to crop images and bounding boxes with minimum IoU
        constraint.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Result dict with images and bounding boxes cropped, \
                'img_shape' key is updated.
        """
        if 'img_fields' in results:
            assert results['img_fields'] == ['img'], \
                'Only single img_fields is allowed'
        img = results['img']
        assert 'bbox_fields' in results
        assert 'gt_bboxes' in results
        boxes = results['gt_bboxes']
        h, w, c = img.shape
        scale_retry = 0
        if self.crop_ratio_range is not None:
            max_scale = self.crop_ratio_max
        else:
            max_scale = np.amax(self.crop_choice)
        while True:
            scale_retry += 1
            if scale_retry == 1 or max_scale > 1.0:
                if self.crop_ratio_range is not None:
                    scale = np.random.uniform(self.crop_ratio_min,
                                              self.crop_ratio_max)
                elif self.crop_choice is not None:
                    scale = np.random.choice(self.crop_choice)
            else:
                scale = scale * 1.2
            for i in range(250):
                short_side = min(w, h)
                cw = int(scale * short_side)
                ch = cw
                # TODO +1
                if w == cw:
                    left = 0
                elif w > cw:
                    left = random.randint(0, w - cw)
                else:
                    left = random.randint(w - cw, 0)
                if h == ch:
                    top = 0
                elif h > ch:
                    top = random.randint(0, h - ch)
                else:
                    top = random.randint(h - ch, 0)
                patch = np.array(
                    (int(left), int(top), int(left + cw), int(top + ch)),
                    dtype=np.int)

                # center of boxes should inside the crop img
                # only adjust boxes and instance masks when the gt is not empty
                # adjust boxes
                def is_center_of_bboxes_in_patch(boxes, patch):
                    # TODO >=
                    center = (boxes[:, :2] + boxes[:, 2:]) / 2
                    mask = \
                        ((center[:, 0] > patch[0])
                         * (center[:, 1] > patch[1])
                         * (center[:, 0] < patch[2])
                         * (center[:, 1] < patch[3]))
                    return mask

                mask = is_center_of_bboxes_in_patch(boxes, patch)
                if not mask.any():
                    continue
                for key in results.get('bbox_fields', []):
                    boxes = results[key].copy()
                    mask = is_center_of_bboxes_in_patch(boxes, patch)
                    boxes = boxes[mask]
                    if self.bbox_clip_border:
                        boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:])
                        boxes[:, :2] = boxes[:, :2].clip(min=patch[:2])
                    boxes -= np.tile(patch[:2], 2)
                    results[key] = boxes
                    # labels
                    label_key = self.bbox2label.get(key)
                    if label_key in results:
                        results[label_key] = results[label_key][mask]
                    # keypoints field
                    if key == 'gt_bboxes':
                        for kps_key in results.get('keypoints_fields', []):
                            keypointss = results[kps_key].copy()
                            keypointss = keypointss[mask, :, :]
                            if self.bbox_clip_border:
                                keypointss[:, :, :2] = \
                                    keypointss[:, :, :2].clip(max=patch[2:])
                                keypointss[:, :, :2] = \
                                    keypointss[:, :, :2].clip(min=patch[:2])
                            keypointss[:, :, 0] -= patch[0]
                            keypointss[:, :, 1] -= patch[1]
                            results[kps_key] = keypointss
                    # mask fields
                    mask_key = self.bbox2mask.get(key)
                    if mask_key in results:
                        results[mask_key] = results[mask_key][
                            mask.nonzero()[0]].crop(patch)

                # adjust the img no matter whether the gt is empty before crop
                rimg = np.ones((ch, cw, 3), dtype=img.dtype) * 128
                patch_from = patch.copy()
                patch_from[0] = max(0, patch_from[0])
                patch_from[1] = max(0, patch_from[1])
                patch_from[2] = min(img.shape[1], patch_from[2])
                patch_from[3] = min(img.shape[0], patch_from[3])
                patch_to = patch.copy()
                patch_to[0] = max(0, patch_to[0] * -1)
                patch_to[1] = max(0, patch_to[1] * -1)
                patch_to[2] = patch_to[0] + (patch_from[2] - patch_from[0])
                patch_to[3] = patch_to[1] + (patch_from[3] - patch_from[1])
                rimg[patch_to[1]:patch_to[3],
                     patch_to[0]:patch_to[2], :] = img[
                         patch_from[1]:patch_from[3],
                         patch_from[0]:patch_from[2], :]
                img = rimg
                results['img'] = img
                results['img_shape'] = img.shape
                return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(min_ious={self.min_iou}, '
        repr_str += f'crop_size={self.crop_size})'
        return repr_str
@@ -0,0 +1,2 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .scrfd_detect import ScrfdDetect
@@ -6,7 +6,7 @@ import numpy as np
 import torch


-def bbox2result(bboxes, labels, num_classes, kps=None):
+def bbox2result(bboxes, labels, num_classes, kps=None, num_kps=5):
     """Convert detection results to a list of numpy arrays.

     Args:
@@ -17,7 +17,7 @@ def bbox2result(bboxes, labels, num_classes, kps=None):
     Returns:
         list(ndarray): bbox results of each class
     """
-    bbox_len = 5 if kps is None else 5 + 10  # if has kps, add 10 kps into bbox
+    bbox_len = 5 if kps is None else 5 + num_kps * 2  # if has kps, add num_kps*2 into bbox
     if bboxes.shape[0] == 0:
         return [
             np.zeros((0, bbox_len), dtype=np.float32)
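A quick check of the new `num_kps` plumbing, using only the empty-input path visible in the hunk above. With num_kps=4 (e.g. card corners) each result row is box(4) + score(1) + 4 * 2 keypoint coords = 13 columns:

import torch

bboxes = torch.zeros((0, 5))
labels = torch.zeros((0, ), dtype=torch.long)
out = bbox2result(
    bboxes, labels, num_classes=1, kps=torch.zeros((0, 8)), num_kps=4)
print([r.shape for r in out])  # [(0, 13)]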
@@ -17,6 +17,7 @@ def multiclass_nms(multi_bboxes,
     Args:
         multi_bboxes (Tensor): shape (n, #class*4) or (n, 4)
+        multi_kps (Tensor): shape (n, #class*num_kps*2) or (n, num_kps*2)
         multi_scores (Tensor): shape (n, #class), where the last column
             contains scores of the background class, but this will be ignored.
         score_thr (float): bbox threshold, bboxes with scores lower than it
@@ -36,16 +37,18 @@ def multiclass_nms(multi_bboxes,
     num_classes = multi_scores.size(1) - 1
     # exclude background category
     kps = None
+    if multi_kps is not None:
+        num_kps = int((multi_kps.shape[1] / num_classes) / 2)
     if multi_bboxes.shape[1] > 4:
         bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4)
         if multi_kps is not None:
-            kps = multi_kps.view(multi_scores.size(0), -1, 10)
+            kps = multi_kps.view(multi_scores.size(0), -1, num_kps * 2)
     else:
         bboxes = multi_bboxes[:, None].expand(
             multi_scores.size(0), num_classes, 4)
         if multi_kps is not None:
             kps = multi_kps[:, None].expand(
-                multi_scores.size(0), num_classes, 10)
+                multi_scores.size(0), num_classes, num_kps * 2)
     scores = multi_scores[:, :-1]

     if score_factors is not None:
@@ -56,7 +59,7 @@ def multiclass_nms(multi_bboxes,
     bboxes = bboxes.reshape(-1, 4)
     if kps is not None:
-        kps = kps.reshape(-1, 10)
+        kps = kps.reshape(-1, num_kps * 2)
     scores = scores.reshape(-1)
     labels = labels.reshape(-1)
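The arithmetic behind this generalization: `num_kps` is recovered from the flattened keypoint tensor instead of being hard-coded to 5. A dummy-shape sketch:

import torch

# One foreground class with 4 keypoints -> 1 * 4 * 2 = 8 columns.
multi_kps = torch.zeros(100, 8)   # (n, #class * num_kps * 2)
num_classes = 1
num_kps = int((multi_kps.shape[1] / num_classes) / 2)
assert num_kps == 4
kps = multi_kps.view(multi_kps.size(0), -1, num_kps * 2)  # (100, 1, 8)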
@@ -2,6 +2,12 @@
 The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at
 https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines
 """
+from .auto_augment import RotateV2
+from .formating import DefaultFormatBundleV2
+from .loading import LoadAnnotationsV2
 from .transforms import RandomSquareCrop

-__all__ = ['RandomSquareCrop']
+__all__ = [
+    'RandomSquareCrop', 'LoadAnnotationsV2', 'RotateV2',
+    'DefaultFormatBundleV2'
+]
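These transforms are referenced by type name from an mmdet-style pipeline config. A hedged sketch of what a training pipeline could look like; the concrete values are illustrative, not taken from this diff:

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotationsV2', with_bbox=True, with_keypoints=True),
    dict(type='RotateV2', level=10, max_rotate_angle=90, prob=0.5),
    dict(
        type='RandomSquareCrop',
        crop_choice=[0.3, 0.45, 0.6, 0.8, 1.0],
        big_face_ratio=0.1),  # occasionally crop tight around a large face
    dict(type='DefaultFormatBundleV2'),
    dict(
        type='Collect',
        keys=['img', 'gt_bboxes', 'gt_labels', 'gt_keypointss']),
]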
@@ -0,0 +1,271 @@
"""
The implementation here is modified based on insightface, originally MIT license and publicly available at
https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/auto_augment.py
"""
import copy

import cv2
import mmcv
import numpy as np
from mmdet.datasets.builder import PIPELINES

_MAX_LEVEL = 10


def level_to_value(level, max_value):
    """Map from level to values based on max_value."""
    return (level / _MAX_LEVEL) * max_value


def random_negative(value, random_negative_prob):
    """Randomly negate value based on random_negative_prob."""
    return -value if np.random.rand() < random_negative_prob else value


def bbox2fields():
    """The key correspondence from bboxes to labels, masks and
    segmentations."""
    bbox2label = {
        'gt_bboxes': 'gt_labels',
        'gt_bboxes_ignore': 'gt_labels_ignore'
    }
    bbox2mask = {
        'gt_bboxes': 'gt_masks',
        'gt_bboxes_ignore': 'gt_masks_ignore'
    }
    bbox2seg = {
        'gt_bboxes': 'gt_semantic_seg',
    }
    return bbox2label, bbox2mask, bbox2seg
@PIPELINES.register_module()
class RotateV2(object):
    """Apply Rotate Transformation to image (and its corresponding bbox, mask,
    segmentation).

    Args:
        level (int | float): The level should be in range (0,_MAX_LEVEL].
        scale (int | float): Isotropic scale factor. Same as in
            ``mmcv.imrotate``.
        center (int | float | tuple[float]): Center point (w, h) of the
            rotation in the source image. If None, the center of the
            image will be used. Same as in ``mmcv.imrotate``.
        img_fill_val (int | float | tuple): The fill value for image border.
            If float, the same value will be used for all the three
            channels of image. If tuple, it should have 3 elements (i.e.
            equal to the number of channels of the image).
        seg_ignore_label (int): The fill value used for segmentation map.
            Note this value must equal ``ignore_label`` in ``semantic_head``
            of the corresponding config. Default 255.
        prob (float): The probability of performing the transformation,
            which should be in range [0, 1].
        max_rotate_angle (int | float): The maximum angle for the rotate
            transformation.
        random_negative_prob (float): The probability that turns the
            offset negative.
    """

    def __init__(self,
                 level,
                 scale=1,
                 center=None,
                 img_fill_val=128,
                 seg_ignore_label=255,
                 prob=0.5,
                 max_rotate_angle=30,
                 random_negative_prob=0.5):
        assert isinstance(level, (int, float)), \
            f'The level must be type int or float. got {type(level)}.'
        assert 0 <= level <= _MAX_LEVEL, \
            f'The level should be in range (0,{_MAX_LEVEL}]. got {level}.'
        assert isinstance(scale, (int, float)), \
            f'The scale must be type int or float. got type {type(scale)}.'
        if isinstance(center, (int, float)):
            center = (center, center)
        elif isinstance(center, tuple):
            assert len(center) == 2, 'center with type tuple must have '\
                f'2 elements. got {len(center)} elements.'
        else:
            assert center is None, 'center must be None or type int, '\
                f'float or tuple, got type {type(center)}.'
        if isinstance(img_fill_val, (float, int)):
            img_fill_val = tuple([float(img_fill_val)] * 3)
        elif isinstance(img_fill_val, tuple):
            assert len(img_fill_val) == 3, 'img_fill_val as tuple must '\
                f'have 3 elements. got {len(img_fill_val)}.'
            img_fill_val = tuple([float(val) for val in img_fill_val])
        else:
            raise ValueError(
                'img_fill_val must be float or tuple with 3 elements.')
        assert np.all([0 <= val <= 255 for val in img_fill_val]), \
            'all elements of img_fill_val should between range [0,255]. '\
            f'got {img_fill_val}.'
        assert 0 <= prob <= 1.0, 'The probability should be in range [0,1]. '\
            f'got {prob}.'
        assert isinstance(max_rotate_angle, (int, float)), 'max_rotate_angle '\
            f'should be type int or float. got type {type(max_rotate_angle)}.'
        self.level = level
        self.scale = scale
        # Rotation angle in degrees. Positive values mean
        # clockwise rotation.
        self.angle = level_to_value(level, max_rotate_angle)
        self.center = center
        self.img_fill_val = img_fill_val
        self.seg_ignore_label = seg_ignore_label
        self.prob = prob
        self.max_rotate_angle = max_rotate_angle
        self.random_negative_prob = random_negative_prob
    def _rotate_img(self, results, angle, center=None, scale=1.0):
        """Rotate the image.

        Args:
            results (dict): Result dict from loading pipeline.
            angle (float): Rotation angle in degrees, positive values
                mean clockwise rotation. Same in ``mmcv.imrotate``.
            center (tuple[float], optional): Center point (w, h) of the
                rotation. Same in ``mmcv.imrotate``.
            scale (int | float): Isotropic scale factor. Same in
                ``mmcv.imrotate``.
        """
        for key in results.get('img_fields', ['img']):
            img = results[key].copy()
            img_rotated = mmcv.imrotate(
                img, angle, center, scale, border_value=self.img_fill_val)
            results[key] = img_rotated.astype(img.dtype)
            results['img_shape'] = results[key].shape

    def _rotate_bboxes(self, results, rotate_matrix):
        """Rotate the bboxes."""
        h, w, c = results['img_shape']
        for key in results.get('bbox_fields', []):
            min_x, min_y, max_x, max_y = np.split(
                results[key], results[key].shape[-1], axis=-1)
            coordinates = np.stack([[min_x, min_y], [max_x, min_y],
                                    [min_x, max_y],
                                    [max_x, max_y]])  # [4, 2, nb_bbox, 1]
            # pad 1 to convert from format [x, y] to homogeneous
            # coordinates format [x, y, 1]
            coordinates = np.concatenate(
                (coordinates,
                 np.ones((4, 1, coordinates.shape[2], 1), coordinates.dtype)),
                axis=1)  # [4, 3, nb_bbox, 1]
            coordinates = coordinates.transpose(
                (2, 0, 1, 3))  # [nb_bbox, 4, 3, 1]
            rotated_coords = np.matmul(rotate_matrix,
                                       coordinates)  # [nb_bbox, 4, 2, 1]
            rotated_coords = rotated_coords[..., 0]  # [nb_bbox, 4, 2]
            min_x, min_y = np.min(
                rotated_coords[:, :, 0], axis=1), np.min(
                    rotated_coords[:, :, 1], axis=1)
            max_x, max_y = np.max(
                rotated_coords[:, :, 0], axis=1), np.max(
                    rotated_coords[:, :, 1], axis=1)
            results[key] = np.stack([min_x, min_y, max_x, max_y],
                                    axis=-1).astype(results[key].dtype)
    def _rotate_keypoints90(self, results, angle):
        """Rotate the keypoints, only valid when angle in [-90,90,-180,180]"""
        if angle not in [-90, 90, 180, -180
                         ] or self.scale != 1 or self.center is not None:
            return
        for key in results.get('keypoints_fields', []):
            k = results[key]
            if angle == 90:
                w, h, c = results['img'].shape
                new = np.stack([h - k[..., 1], k[..., 0], k[..., 2]], axis=-1)
            elif angle == -90:
                w, h, c = results['img'].shape
                new = np.stack([k[..., 1], w - k[..., 0], k[..., 2]], axis=-1)
            else:
                h, w, c = results['img'].shape
                new = np.stack([w - k[..., 0], h - k[..., 1], k[..., 2]],
                               axis=-1)
            # a kps is invalid if its third value is -1
            kps_invalid = new[..., -1][:, -1] == -1
            new[kps_invalid] = np.zeros(new.shape[1:]) - 1
            results[key] = new
    def _rotate_masks(self,
                      results,
                      angle,
                      center=None,
                      scale=1.0,
                      fill_val=0):
        """Rotate the masks."""
        h, w, c = results['img_shape']
        for key in results.get('mask_fields', []):
            masks = results[key]
            results[key] = masks.rotate((h, w), angle, center, scale, fill_val)

    def _rotate_seg(self,
                    results,
                    angle,
                    center=None,
                    scale=1.0,
                    fill_val=255):
        """Rotate the segmentation map."""
        for key in results.get('seg_fields', []):
            seg = results[key].copy()
            results[key] = mmcv.imrotate(
                seg, angle, center, scale,
                border_value=fill_val).astype(seg.dtype)

    def _filter_invalid(self, results, min_bbox_size=0):
        """Filter bboxes and corresponding masks too small after rotate
        augmentation."""
        bbox2label, bbox2mask, _ = bbox2fields()
        for key in results.get('bbox_fields', []):
            bbox_w = results[key][:, 2] - results[key][:, 0]
            bbox_h = results[key][:, 3] - results[key][:, 1]
            valid_inds = (bbox_w > min_bbox_size) & (bbox_h > min_bbox_size)
            valid_inds = np.nonzero(valid_inds)[0]
            results[key] = results[key][valid_inds]
            # label fields. e.g. gt_labels and gt_labels_ignore
            label_key = bbox2label.get(key)
            if label_key in results:
                results[label_key] = results[label_key][valid_inds]
            # mask fields, e.g. gt_masks and gt_masks_ignore
            mask_key = bbox2mask.get(key)
            if mask_key in results:
                results[mask_key] = results[mask_key][valid_inds]
    def __call__(self, results):
        """Call function to rotate images, bounding boxes, masks and semantic
        segmentation maps.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Rotated results.
        """
        if np.random.rand() > self.prob:
            return results
        h, w = results['img'].shape[:2]
        center = self.center
        if center is None:
            center = ((w - 1) * 0.5, (h - 1) * 0.5)
        angle = random_negative(self.angle, self.random_negative_prob)
        self._rotate_img(results, angle, center, self.scale)
        rotate_matrix = cv2.getRotationMatrix2D(center, -angle, self.scale)
        self._rotate_bboxes(results, rotate_matrix)
        self._rotate_keypoints90(results, angle)
        self._rotate_masks(results, angle, center, self.scale, fill_val=0)
        self._rotate_seg(
            results, angle, center, self.scale, fill_val=self.seg_ignore_label)
        self._filter_invalid(results)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(level={self.level}, '
        repr_str += f'scale={self.scale}, '
        repr_str += f'center={self.center}, '
        repr_str += f'img_fill_val={self.img_fill_val}, '
        repr_str += f'seg_ignore_label={self.seg_ignore_label}, '
        repr_str += f'prob={self.prob}, '
        repr_str += f'max_rotate_angle={self.max_rotate_angle}, '
        repr_str += f'random_negative_prob={self.random_negative_prob})'
        return repr_str
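A minimal smoke test of RotateV2 on a dummy sample (assumes mmcv/mmdet are installed; the field names follow the results-dict convention used above):

import numpy as np

# level=10 with max_rotate_angle=90 yields +/-90 degree rotations, which is
# the case _rotate_keypoints90 handles; prob=1.0 forces the transform to fire.
rotate = RotateV2(level=10, max_rotate_angle=90, prob=1.0)
results = {
    'img': np.zeros((100, 200, 3), dtype=np.uint8),
    'img_fields': ['img'],
    'gt_bboxes': np.array([[10., 10., 50., 60.]], dtype=np.float32),
    'bbox_fields': ['gt_bboxes'],
    # one instance with five (x, y, visibility) landmarks
    'gt_keypointss': np.array([[[30., 40., 1.]] * 5], dtype=np.float32),
    'keypoints_fields': ['gt_keypointss'],
}
results = rotate(results)
print(results['gt_bboxes'], results['gt_keypointss'][0, 0])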
@@ -0,0 +1,113 @@
"""
The implementation here is modified based on insightface, originally MIT license and publicly available at
https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/formating.py
"""
from collections.abc import Sequence

import mmcv
import numpy as np
import torch
from mmcv.parallel import DataContainer as DC
from mmdet.datasets.builder import PIPELINES


def to_tensor(data):
    """Convert objects of various python types to :obj:`torch.Tensor`.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence`, :class:`int` and :class:`float`.

    Args:
        data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
            be converted.
    """
    if isinstance(data, torch.Tensor):
        return data
    elif isinstance(data, np.ndarray):
        return torch.from_numpy(data)
    elif isinstance(data, Sequence) and not mmcv.is_str(data):
        return torch.tensor(data)
    elif isinstance(data, int):
        return torch.LongTensor([data])
    elif isinstance(data, float):
        return torch.FloatTensor([data])
    else:
        raise TypeError(f'type {type(data)} cannot be converted to tensor.')
@PIPELINES.register_module()
class DefaultFormatBundleV2(object):
    """Default formatting bundle.

    It simplifies the pipeline of formatting common fields, including "img",
    "proposals", "gt_bboxes", "gt_labels", "gt_masks" and "gt_semantic_seg".
    These fields are formatted as follows.

    - img: (1)transpose, (2)to tensor, (3)to DataContainer (stack=True)
    - proposals: (1)to tensor, (2)to DataContainer
    - gt_bboxes: (1)to tensor, (2)to DataContainer
    - gt_bboxes_ignore: (1)to tensor, (2)to DataContainer
    - gt_labels: (1)to tensor, (2)to DataContainer
    - gt_masks: (1)to tensor, (2)to DataContainer (cpu_only=True)
    - gt_semantic_seg: (1)unsqueeze dim-0 (2)to tensor, \
        (3)to DataContainer (stack=True)
    """

    def __call__(self, results):
        """Call function to transform and format common fields in results.

        Args:
            results (dict): Result dict contains the data to convert.

        Returns:
            dict: The result dict contains the data that is formatted with \
                default bundle.
        """
        if 'img' in results:
            img = results['img']
            # add default meta keys
            results = self._add_default_meta_keys(results)
            if len(img.shape) < 3:
                img = np.expand_dims(img, -1)
            img = np.ascontiguousarray(img.transpose(2, 0, 1))
            results['img'] = DC(to_tensor(img), stack=True)
        for key in [
                'proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_keypointss',
                'gt_labels'
        ]:
            if key not in results:
                continue
            results[key] = DC(to_tensor(results[key]))
        if 'gt_masks' in results:
            results['gt_masks'] = DC(results['gt_masks'], cpu_only=True)
        if 'gt_semantic_seg' in results:
            results['gt_semantic_seg'] = DC(
                to_tensor(results['gt_semantic_seg'][None, ...]), stack=True)
        return results
    def _add_default_meta_keys(self, results):
        """Add default meta keys.

        We set default meta keys including `pad_shape`, `scale_factor` and
        `img_norm_cfg` to avoid the case where no `Resize`, `Normalize` and
        `Pad` are implemented during the whole pipeline.

        Args:
            results (dict): Result dict contains the data to convert.

        Returns:
            results (dict): Updated result dict contains the data to convert.
        """
        img = results['img']
        results.setdefault('pad_shape', img.shape)
        results.setdefault('scale_factor', 1.0)
        num_channels = 1 if len(img.shape) < 3 else img.shape[2]
        results.setdefault(
            'img_norm_cfg',
            dict(
                mean=np.zeros(num_channels, dtype=np.float32),
                std=np.ones(num_channels, dtype=np.float32),
                to_rgb=False))
        return results

    def __repr__(self):
        return self.__class__.__name__
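A quick smoke test of the bundle on a dummy sample (assumes mmcv is installed):

import numpy as np

bundle = DefaultFormatBundleV2()
results = bundle({
    'img': np.zeros((64, 64, 3), dtype=np.float32),
    'gt_bboxes': np.array([[0., 0., 10., 10.]], dtype=np.float32),
    'gt_labels': np.array([0], dtype=np.int64),
})
print(results['img'].data.shape)  # torch.Size([3, 64, 64]) in a DataContainer
print(results['img_norm_cfg'])    # default mean/std/to_rgb meta added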
@@ -0,0 +1,225 @@
"""
The implementation here is modified based on insightface, originally MIT license and publicly available at
https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/loading.py
"""
import os.path as osp

import mmcv
import numpy as np
import pycocotools.mask as maskUtils
from mmdet.core import BitmapMasks, PolygonMasks
from mmdet.datasets.builder import PIPELINES


@PIPELINES.register_module()
class LoadAnnotationsV2(object):
    """Load multiple types of annotations.

    Args:
        with_bbox (bool): Whether to parse and load the bbox annotation.
            Default: True.
        with_label (bool): Whether to parse and load the label annotation.
            Default: True.
        with_keypoints (bool): Whether to parse and load the keypoints annotation.
            Default: False.
        with_mask (bool): Whether to parse and load the mask annotation.
            Default: False.
        with_seg (bool): Whether to parse and load the semantic segmentation
            annotation. Default: False.
        poly2mask (bool): Whether to convert the instance masks from polygons
            to bitmaps. Default: True.
        file_client_args (dict): Arguments to instantiate a FileClient.
            See :class:`mmcv.fileio.FileClient` for details.
            Defaults to ``dict(backend='disk')``.
    """
    def __init__(self,
                 with_bbox=True,
                 with_label=True,
                 with_keypoints=False,
                 with_mask=False,
                 with_seg=False,
                 poly2mask=True,
                 file_client_args=dict(backend='disk')):
        self.with_bbox = with_bbox
        self.with_label = with_label
        self.with_keypoints = with_keypoints
        self.with_mask = with_mask
        self.with_seg = with_seg
        self.poly2mask = poly2mask
        self.file_client_args = file_client_args.copy()
        self.file_client = None

    def _load_bboxes(self, results):
        """Private function to load bounding box annotations.

        Args:
            results (dict): Result dict from :obj:`mmdet.CustomDataset`.

        Returns:
            dict: The dict contains loaded bounding box annotations.
        """
        ann_info = results['ann_info']
        results['gt_bboxes'] = ann_info['bboxes'].copy()
        gt_bboxes_ignore = ann_info.get('bboxes_ignore', None)
        if gt_bboxes_ignore is not None:
            results['gt_bboxes_ignore'] = gt_bboxes_ignore.copy()
            results['bbox_fields'].append('gt_bboxes_ignore')
        results['bbox_fields'].append('gt_bboxes')
        return results
    def _load_keypoints(self, results):
        """Private function to load keypoints annotations.

        Args:
            results (dict): Result dict from :obj:`mmdet.CustomDataset`.

        Returns:
            dict: The dict contains loaded keypoints annotations.
        """
        ann_info = results['ann_info']
        results['gt_keypointss'] = ann_info['keypointss'].copy()
        results['keypoints_fields'] = ['gt_keypointss']
        return results
    def _load_labels(self, results):
        """Private function to load label annotations.

        Args:
            results (dict): Result dict from :obj:`mmdet.CustomDataset`.

        Returns:
            dict: The dict contains loaded label annotations.
        """
        results['gt_labels'] = results['ann_info']['labels'].copy()
        return results

    def _poly2mask(self, mask_ann, img_h, img_w):
        """Private function to convert masks represented with polygon to
        bitmaps.

        Args:
            mask_ann (list | dict): Polygon mask annotation input.
            img_h (int): The height of output mask.
            img_w (int): The width of output mask.

        Returns:
            numpy.ndarray: The decode bitmap mask of shape (img_h, img_w).
        """
        if isinstance(mask_ann, list):
            # polygon -- a single object might consist of multiple parts
            # we merge all parts into one mask rle code
            rles = maskUtils.frPyObjects(mask_ann, img_h, img_w)
            rle = maskUtils.merge(rles)
        elif isinstance(mask_ann['counts'], list):
            # uncompressed RLE
            rle = maskUtils.frPyObjects(mask_ann, img_h, img_w)
        else:
            # rle
            rle = mask_ann
        mask = maskUtils.decode(rle)
        return mask

    def process_polygons(self, polygons):
        """Convert polygons to list of ndarray and filter invalid polygons.

        Args:
            polygons (list[list]): Polygons of one instance.

        Returns:
            list[numpy.ndarray]: Processed polygons.
        """
        polygons = [np.array(p) for p in polygons]
        valid_polygons = []
        for polygon in polygons:
            if len(polygon) % 2 == 0 and len(polygon) >= 6:
                valid_polygons.append(polygon)
        return valid_polygons
    def _load_masks(self, results):
        """Private function to load mask annotations.

        Args:
            results (dict): Result dict from :obj:`mmdet.CustomDataset`.

        Returns:
            dict: The dict contains loaded mask annotations.
                If ``self.poly2mask`` is set ``True``, `gt_mask` will contain
                :obj:`BitmapMasks`. Otherwise, :obj:`PolygonMasks` is used.
        """
        h, w = results['img_info']['height'], results['img_info']['width']
        gt_masks = results['ann_info']['masks']
        if self.poly2mask:
            gt_masks = BitmapMasks(
                [self._poly2mask(mask, h, w) for mask in gt_masks], h, w)
        else:
            gt_masks = PolygonMasks(
                [self.process_polygons(polygons) for polygons in gt_masks], h,
                w)
        results['gt_masks'] = gt_masks
        results['mask_fields'].append('gt_masks')
        return results
    def _load_semantic_seg(self, results):
        """Private function to load semantic segmentation annotations.

        Args:
            results (dict): Result dict from :obj:`dataset`.

        Returns:
            dict: The dict contains loaded semantic segmentation annotations.
        """
        if self.file_client is None:
            self.file_client = mmcv.FileClient(**self.file_client_args)
        filename = osp.join(results['seg_prefix'],
                            results['ann_info']['seg_map'])
        img_bytes = self.file_client.get(filename)
        results['gt_semantic_seg'] = mmcv.imfrombytes(
            img_bytes, flag='unchanged').squeeze()
        results['seg_fields'].append('gt_semantic_seg')
        return results
    def __call__(self, results):
        """Call function to load multiple types annotations.

        Args:
            results (dict): Result dict from :obj:`mmdet.CustomDataset`.

        Returns:
            dict: The dict contains loaded bounding box, label, mask and
                semantic segmentation annotations.
        """
        if self.with_bbox:
            results = self._load_bboxes(results)
            if results is None:
                return None
        if self.with_label:
            results = self._load_labels(results)
        if self.with_keypoints:
            results = self._load_keypoints(results)
        if self.with_mask:
            results = self._load_masks(results)
        if self.with_seg:
            results = self._load_semantic_seg(results)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(with_bbox={self.with_bbox}, '
        repr_str += f'with_label={self.with_label}, '
        repr_str += f'with_keypoints={self.with_keypoints}, '
        repr_str += f'with_mask={self.with_mask}, '
        repr_str += f'with_seg={self.with_seg}, '
        repr_str += f'poly2mask={self.poly2mask}, '
        repr_str += f'file_client_args={self.file_client_args})'
        return repr_str
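A minimal sketch of the annotation layout LoadAnnotationsV2 expects, with a 4-corner keypoint annotation as used for cards (dummy values):

import numpy as np

loader = LoadAnnotationsV2(
    with_bbox=True, with_label=True, with_keypoints=True)
results = loader({
    'ann_info': {
        'bboxes': np.array([[10., 10., 50., 50.]], dtype=np.float32),
        'labels': np.array([0], dtype=np.int64),
        # one instance, 4 keypoints, each (x, y, visibility)
        'keypointss': np.zeros((1, 4, 3), dtype=np.float32),
    },
    'bbox_fields': [],
})
print(results['gt_bboxes'], results['gt_keypointss'].shape)  # ..., (1, 4, 3)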
@@ -0,0 +1,737 @@
"""
The implementation here is modified based on insightface, originally MIT license and publicly available at
https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/transforms.py
"""
import mmcv
import numpy as np
from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps
from mmdet.datasets.builder import PIPELINES
from numpy import random
@PIPELINES.register_module()
class ResizeV2(object):
    """Resize images & bbox & mask & kps.

    This transform resizes the input image to some scale. Bboxes and masks are
    then resized with the same scale factor. If the input dict contains the key
    "scale", then the scale in the input dict is used, otherwise the specified
    scale in the init method is used. If the input dict contains the key
    "scale_factor" (if MultiScaleFlipAug does not give img_scale but
    scale_factor), the actual scale will be computed by image shape and
    scale_factor.

    `img_scale` can either be a tuple (single-scale) or a list of tuple
    (multi-scale). There are 3 multiscale modes:

    - ``ratio_range is not None``: randomly sample a ratio from the ratio \
      range and multiply it with the image scale.
    - ``ratio_range is None`` and ``multiscale_mode == "range"``: randomly \
      sample a scale from the multiscale range.
    - ``ratio_range is None`` and ``multiscale_mode == "value"``: randomly \
      sample a scale from multiple scales.

    Args:
        img_scale (tuple or list[tuple]): Images scales for resizing.
        multiscale_mode (str): Either "range" or "value".
        ratio_range (tuple[float]): (min_ratio, max_ratio)
        keep_ratio (bool): Whether to keep the aspect ratio when resizing the
            image.
        bbox_clip_border (bool, optional): Whether to clip the objects outside
            the border of the image. Defaults to True.
        backend (str): Image resize backend, choices are 'cv2' and 'pillow'.
            These two backends generate slightly different results. Defaults
            to 'cv2'.
        override (bool, optional): Whether to override `scale` and
            `scale_factor` so as to call resize twice. Default False. If True,
            after the first resizing, the existing `scale` and `scale_factor`
            will be ignored so the second resizing can be allowed.
            This option is a work-around for multiple times of resize in DETR.
            Defaults to False.
    """
    def __init__(self,
                 img_scale=None,
                 multiscale_mode='range',
                 ratio_range=None,
                 keep_ratio=True,
                 bbox_clip_border=True,
                 backend='cv2',
                 override=False):
        if img_scale is None:
            self.img_scale = None
        else:
            if isinstance(img_scale, list):
                self.img_scale = img_scale
            else:
                self.img_scale = [img_scale]
            assert mmcv.is_list_of(self.img_scale, tuple)
        if ratio_range is not None:
            # mode 1: given a scale and a range of image ratio
            assert len(self.img_scale) == 1
        else:
            # mode 2: given multiple scales or a range of scales
            assert multiscale_mode in ['value', 'range']
        self.backend = backend
        self.multiscale_mode = multiscale_mode
        self.ratio_range = ratio_range
        self.keep_ratio = keep_ratio
        # TODO: refactor the override option in Resize
        self.override = override
        self.bbox_clip_border = bbox_clip_border

    @staticmethod
    def random_select(img_scales):
        """Randomly select an img_scale from given candidates.

        Args:
            img_scales (list[tuple]): Images scales for selection.

        Returns:
            (tuple, int): Returns a tuple ``(img_scale, scale_idx)``, \
                where ``img_scale`` is the selected image scale and \
                ``scale_idx`` is the selected index in the given candidates.
        """
        assert mmcv.is_list_of(img_scales, tuple)
        scale_idx = np.random.randint(len(img_scales))
        img_scale = img_scales[scale_idx]
        return img_scale, scale_idx
    @staticmethod
    def random_sample(img_scales):
        """Randomly sample an img_scale when ``multiscale_mode=='range'``.

        Args:
            img_scales (list[tuple]): Images scale range for sampling.
                There must be two tuples in img_scales, which specify the lower
                and upper bound of image scales.

        Returns:
            (tuple, None): Returns a tuple ``(img_scale, None)``, where \
                ``img_scale`` is sampled scale and None is just a placeholder \
                to be consistent with :func:`random_select`.
        """
        assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2
        img_scale_long = [max(s) for s in img_scales]
        img_scale_short = [min(s) for s in img_scales]
        long_edge = np.random.randint(
            min(img_scale_long),
            max(img_scale_long) + 1)
        short_edge = np.random.randint(
            min(img_scale_short),
            max(img_scale_short) + 1)
        img_scale = (long_edge, short_edge)
        return img_scale, None

    @staticmethod
    def random_sample_ratio(img_scale, ratio_range):
        """Randomly sample an img_scale when ``ratio_range`` is specified.

        A ratio will be randomly sampled from the range specified by
        ``ratio_range``. Then it would be multiplied with ``img_scale`` to
        generate sampled scale.

        Args:
            img_scale (tuple): Images scale base to multiply with ratio.
            ratio_range (tuple[float]): The minimum and maximum ratio to scale
                the ``img_scale``.

        Returns:
            (tuple, None): Returns a tuple ``(scale, None)``, where \
                ``scale`` is sampled ratio multiplied with ``img_scale`` and \
                None is just a placeholder to be consistent with \
                :func:`random_select`.
        """
        assert isinstance(img_scale, tuple) and len(img_scale) == 2
        min_ratio, max_ratio = ratio_range
        assert min_ratio <= max_ratio
        ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio
        scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio)
        return scale, None
    def _random_scale(self, results):
        """Randomly sample an img_scale according to ``ratio_range`` and
        ``multiscale_mode``.

        If ``ratio_range`` is specified, a ratio will be sampled and be
        multiplied with ``img_scale``.
        If multiple scales are specified by ``img_scale``, a scale will be
        sampled according to ``multiscale_mode``.
        Otherwise, single scale will be used.

        Args:
            results (dict): Result dict from :obj:`dataset`.

        Returns:
            dict: Two new keys 'scale` and 'scale_idx` are added into \
                ``results``, which would be used by subsequent pipelines.
        """
        if self.ratio_range is not None:
            scale, scale_idx = self.random_sample_ratio(
                self.img_scale[0], self.ratio_range)
        elif len(self.img_scale) == 1:
            scale, scale_idx = self.img_scale[0], 0
        elif self.multiscale_mode == 'range':
            scale, scale_idx = self.random_sample(self.img_scale)
        elif self.multiscale_mode == 'value':
            scale, scale_idx = self.random_select(self.img_scale)
        else:
            raise NotImplementedError
        results['scale'] = scale
        results['scale_idx'] = scale_idx

    def _resize_img(self, results):
        """Resize images with ``results['scale']``."""
        for key in results.get('img_fields', ['img']):
            if self.keep_ratio:
                img, scale_factor = mmcv.imrescale(
                    results[key],
                    results['scale'],
                    return_scale=True,
                    backend=self.backend)
                # the w_scale and h_scale has minor difference
                # a real fix should be done in the mmcv.imrescale in the future
                new_h, new_w = img.shape[:2]
                h, w = results[key].shape[:2]
                w_scale = new_w / w
                h_scale = new_h / h
            else:
                img, w_scale, h_scale = mmcv.imresize(
                    results[key],
                    results['scale'],
                    return_scale=True,
                    backend=self.backend)
            results[key] = img
            scale_factor = np.array([w_scale, h_scale, w_scale, h_scale],
                                    dtype=np.float32)
            results['img_shape'] = img.shape
            # in case that there is no padding
            results['pad_shape'] = img.shape
            results['scale_factor'] = scale_factor
            results['keep_ratio'] = self.keep_ratio
    def _resize_bboxes(self, results):
        """Resize bounding boxes with ``results['scale_factor']``."""
        for key in results.get('bbox_fields', []):
            bboxes = results[key] * results['scale_factor']
            if self.bbox_clip_border:
                img_shape = results['img_shape']
                bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1])
                bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0])
            results[key] = bboxes

    def _resize_keypoints(self, results):
        """Resize keypoints with ``results['scale_factor']``."""
        for key in results.get('keypoints_fields', []):
            keypointss = results[key].copy()
            factors = results['scale_factor']
            assert factors[0] == factors[2]
            assert factors[1] == factors[3]
            keypointss[:, :, 0] *= factors[0]
            keypointss[:, :, 1] *= factors[1]
            if self.bbox_clip_border:
                img_shape = results['img_shape']
                keypointss[:, :, 0] = np.clip(keypointss[:, :, 0], 0,
                                              img_shape[1])
                keypointss[:, :, 1] = np.clip(keypointss[:, :, 1], 0,
                                              img_shape[0])
            results[key] = keypointss

    def _resize_masks(self, results):
        """Resize masks with ``results['scale']``"""
        for key in results.get('mask_fields', []):
            if results[key] is None:
                continue
            if self.keep_ratio:
                results[key] = results[key].rescale(results['scale'])
            else:
                results[key] = results[key].resize(results['img_shape'][:2])

    def _resize_seg(self, results):
        """Resize semantic segmentation map with ``results['scale']``."""
        for key in results.get('seg_fields', []):
            if self.keep_ratio:
                gt_seg = mmcv.imrescale(
                    results[key],
                    results['scale'],
                    interpolation='nearest',
                    backend=self.backend)
            else:
                gt_seg = mmcv.imresize(
                    results[key],
                    results['scale'],
                    interpolation='nearest',
                    backend=self.backend)
            results[key] = gt_seg
    def __call__(self, results):
        """Call function to resize images, bounding boxes, masks, semantic
        segmentation map.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor', \
                'keep_ratio' keys are added into result dict.
        """
        if 'scale' not in results:
            if 'scale_factor' in results:
                img_shape = results['img'].shape[:2]
                scale_factor = results['scale_factor']
                assert isinstance(scale_factor, float)
                results['scale'] = tuple(
                    [int(x * scale_factor) for x in img_shape][::-1])
            else:
                self._random_scale(results)
        else:
            if not self.override:
                assert 'scale_factor' not in results, (
                    'scale and scale_factor cannot be both set.')
            else:
                results.pop('scale')
                if 'scale_factor' in results:
                    results.pop('scale_factor')
                self._random_scale(results)
        self._resize_img(results)
        self._resize_bboxes(results)
        self._resize_keypoints(results)
        self._resize_masks(results)
        self._resize_seg(results)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(img_scale={self.img_scale}, '
        repr_str += f'multiscale_mode={self.multiscale_mode}, '
        repr_str += f'ratio_range={self.ratio_range}, '
        repr_str += f'keep_ratio={self.keep_ratio}, '
        repr_str += f'bbox_clip_border={self.bbox_clip_border})'
        return repr_str
@PIPELINES.register_module()
class RandomFlipV2(object):
    """Flip the image & bbox & mask & kps.

    If the input dict contains the key "flip", then the flag will be used,
    otherwise it will be randomly decided by a ratio specified in the init
    method.

    When random flip is enabled, ``flip_ratio``/``direction`` can either be a
    float/string or tuple of float/string. There are 3 flip modes:

    - ``flip_ratio`` is float, ``direction`` is string: the image will be
      ``direction``ly flipped with probability of ``flip_ratio`` .
      E.g., ``flip_ratio=0.5``, ``direction='horizontal'``,
      then image will be horizontally flipped with probability of 0.5.
    - ``flip_ratio`` is float, ``direction`` is list of string: the image will
      be ``direction[i]``ly flipped with probability of
      ``flip_ratio/len(direction)``.
      E.g., ``flip_ratio=0.5``, ``direction=['horizontal', 'vertical']``,
      then image will be horizontally flipped with probability of 0.25,
      vertically with probability of 0.25.
    - ``flip_ratio`` is list of float, ``direction`` is list of string:
      given ``len(flip_ratio) == len(direction)``, the image will
      be ``direction[i]``ly flipped with probability of ``flip_ratio[i]``.
      E.g., ``flip_ratio=[0.3, 0.5]``, ``direction=['horizontal',
      'vertical']``, then image will be horizontally flipped with probability
      of 0.3, vertically with probability of 0.5.

    Args:
        flip_ratio (float | list[float], optional): The flipping probability.
            Default: None.
        direction(str | list[str], optional): The flipping direction. Options
            are 'horizontal', 'vertical', 'diagonal'. Default: 'horizontal'.
            If input is a list, the length must equal ``flip_ratio``. Each
            element in ``flip_ratio`` indicates the flip probability of
            corresponding direction.
    """

    def __init__(self, flip_ratio=None, direction='horizontal'):
        if isinstance(flip_ratio, list):
            assert mmcv.is_list_of(flip_ratio, float)
            assert 0 <= sum(flip_ratio) <= 1
        elif isinstance(flip_ratio, float):
            assert 0 <= flip_ratio <= 1
        elif flip_ratio is None:
            pass
        else:
            raise ValueError('flip_ratios must be None, float, '
                             'or list of float')
        self.flip_ratio = flip_ratio
        valid_directions = ['horizontal', 'vertical', 'diagonal']
        if isinstance(direction, str):
            assert direction in valid_directions
        elif isinstance(direction, list):
            assert mmcv.is_list_of(direction, str)
            assert set(direction).issubset(set(valid_directions))
        else:
            raise ValueError('direction must be either str or list of str')
        self.direction = direction
        if isinstance(flip_ratio, list):
            assert len(self.flip_ratio) == len(self.direction)
        self.count = 0
    def bbox_flip(self, bboxes, img_shape, direction):
        """Flip bboxes horizontally.

        Args:
            bboxes (numpy.ndarray): Bounding boxes, shape (..., 4*k)
            img_shape (tuple[int]): Image shape (height, width)
            direction (str): Flip direction. Options are 'horizontal',
                'vertical'.

        Returns:
            numpy.ndarray: Flipped bounding boxes.
        """
        assert bboxes.shape[-1] % 4 == 0
        flipped = bboxes.copy()
        if direction == 'horizontal':
            w = img_shape[1]
            flipped[..., 0::4] = w - bboxes[..., 2::4]
            flipped[..., 2::4] = w - bboxes[..., 0::4]
        elif direction == 'vertical':
            h = img_shape[0]
            flipped[..., 1::4] = h - bboxes[..., 3::4]
            flipped[..., 3::4] = h - bboxes[..., 1::4]
        elif direction == 'diagonal':
            w = img_shape[1]
            h = img_shape[0]
            flipped[..., 0::4] = w - bboxes[..., 2::4]
            flipped[..., 1::4] = h - bboxes[..., 3::4]
            flipped[..., 2::4] = w - bboxes[..., 0::4]
            flipped[..., 3::4] = h - bboxes[..., 1::4]
        else:
            raise ValueError(f"Invalid flipping direction '{direction}'")
        return flipped

    def keypoints_flip(self, keypointss, img_shape, direction):
        """Flip keypoints horizontally."""
        assert direction == 'horizontal'
        assert keypointss.shape[-1] == 3
        num_kps = keypointss.shape[1]
        assert num_kps in [4, 5], f'Only Support num_kps=4 or 5, got:{num_kps}'
        assert keypointss.ndim == 3
        flipped = keypointss.copy()
        if num_kps == 5:
            flip_order = [1, 0, 2, 4, 3]
        elif num_kps == 4:
            flip_order = [3, 2, 1, 0]
        for idx, a in enumerate(flip_order):
            flipped[:, idx, :] = keypointss[:, a, :]
        w = img_shape[1]
        flipped[..., 0] = w - flipped[..., 0]
        return flipped
| def __call__(self, results): | |||||
| """Call function to flip bounding boxes, masks, semantic segmentation | |||||
| maps. | |||||
| Args: | |||||
| results (dict): Result dict from loading pipeline. | |||||
| Returns: | |||||
| dict: Flipped results, 'flip', 'flip_direction' keys are added \ | |||||
| into result dict. | |||||
| """ | |||||
| if 'flip' not in results: | |||||
| if isinstance(self.direction, list): | |||||
| # None means non-flip | |||||
| direction_list = self.direction + [None] | |||||
| else: | |||||
| # None means non-flip | |||||
| direction_list = [self.direction, None] | |||||
| if isinstance(self.flip_ratio, list): | |||||
| non_flip_ratio = 1 - sum(self.flip_ratio) | |||||
| flip_ratio_list = self.flip_ratio + [non_flip_ratio] | |||||
| else: | |||||
| non_flip_ratio = 1 - self.flip_ratio | |||||
| # exclude non-flip | |||||
| single_ratio = self.flip_ratio / (len(direction_list) - 1) | |||||
| flip_ratio_list = [single_ratio] * (len(direction_list) | |||||
| - 1) + [non_flip_ratio] | |||||
| cur_dir = np.random.choice(direction_list, p=flip_ratio_list) | |||||
| results['flip'] = cur_dir is not None | |||||
| if 'flip_direction' not in results: | |||||
| results['flip_direction'] = cur_dir | |||||
| if results['flip']: | |||||
| # flip image | |||||
| for key in results.get('img_fields', ['img']): | |||||
| results[key] = mmcv.imflip( | |||||
| results[key], direction=results['flip_direction']) | |||||
| # flip bboxes | |||||
| for key in results.get('bbox_fields', []): | |||||
| results[key] = self.bbox_flip(results[key], | |||||
| results['img_shape'], | |||||
| results['flip_direction']) | |||||
| # flip kps | |||||
| for key in results.get('keypoints_fields', []): | |||||
| results[key] = self.keypoints_flip(results[key], | |||||
| results['img_shape'], | |||||
| results['flip_direction']) | |||||
| # flip masks | |||||
| for key in results.get('mask_fields', []): | |||||
| results[key] = results[key].flip(results['flip_direction']) | |||||
| # flip segs | |||||
| for key in results.get('seg_fields', []): | |||||
| results[key] = mmcv.imflip( | |||||
| results[key], direction=results['flip_direction']) | |||||
| return results | |||||
| def __repr__(self): | |||||
| return self.__class__.__name__ + f'(flip_ratio={self.flip_ratio})' | |||||
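# Sampling sketch for the flip logic above, with illustrative values
# flip_ratio=0.5 and direction=['horizontal', 'vertical']:
#   direction_list  = ['horizontal', 'vertical', None]
#   flip_ratio_list = [0.25, 0.25, 0.5]
# i.e. each direction is drawn with probability 0.25 and the image is
# left unflipped with probability 0.5.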
| @PIPELINES.register_module() | |||||
| class RandomSquareCrop(object): | |||||
| """Random crop the image & bboxes, the cropped patches have minimum IoU | |||||
| requirement with original image & bboxes, the IoU threshold is randomly | |||||
| selected from min_ious. | |||||
| Args: | |||||
| min_ious (tuple): minimum IoU threshold for all intersections with | |||||
| bounding boxes | |||||
| min_crop_size (float): minimum crop's size (i.e. h,w := a*h, a*w, | |||||
| where a >= min_crop_size). | |||||
| Note: | |||||
| The keys for bboxes, labels and masks should be paired. That is, \ | |||||
| `gt_bboxes` corresponds to `gt_labels` and `gt_masks`, and \ | |||||
| `gt_bboxes_ignore` to `gt_labels_ignore` and `gt_masks_ignore`. | |||||
| """ | |||||
| def __init__(self, | |||||
| crop_ratio_range=None, | |||||
| crop_choice=None, | |||||
| bbox_clip_border=True, | |||||
| big_face_ratio=0, | |||||
| big_face_crop_choice=None): | |||||
| self.crop_ratio_range = crop_ratio_range | |||||
| self.crop_choice = crop_choice | |||||
| self.big_face_crop_choice = big_face_crop_choice | |||||
| self.bbox_clip_border = bbox_clip_border | |||||
| assert (self.crop_ratio_range is None) ^ (self.crop_choice is None) | |||||
| if self.crop_ratio_range is not None: | |||||
| self.crop_ratio_min, self.crop_ratio_max = self.crop_ratio_range | |||||
| self.bbox2label = { | |||||
| 'gt_bboxes': 'gt_labels', | |||||
| 'gt_bboxes_ignore': 'gt_labels_ignore' | |||||
| } | |||||
| self.bbox2mask = { | |||||
| 'gt_bboxes': 'gt_masks', | |||||
| 'gt_bboxes_ignore': 'gt_masks_ignore' | |||||
| } | |||||
| assert big_face_ratio >= 0 and big_face_ratio <= 1.0 | |||||
| self.big_face_ratio = big_face_ratio | |||||
| def __call__(self, results): | |||||
| """Call function to crop images and bounding boxes with minimum IoU | |||||
| constraint. | |||||
| Args: | |||||
| results (dict): Result dict from loading pipeline. | |||||
| Returns: | |||||
| dict: Result dict with images and bounding boxes cropped, \ | |||||
| 'img_shape' key is updated. | |||||
| """ | |||||
| if 'img_fields' in results: | |||||
| assert results['img_fields'] == ['img'], \ | |||||
| 'Only single img_fields is allowed' | |||||
| img = results['img'] | |||||
| assert 'bbox_fields' in results | |||||
| assert 'gt_bboxes' in results | |||||
| # try the big-face augmentation branch | |||||
| find_bigface = False | |||||
| if np.random.random() < self.big_face_ratio: | |||||
| min_size = 100 # minimum face h and w, in pixels | |||||
| expand_ratio = 0.3 # expansion ratio of the cropped face along both w and h | |||||
| bbox = results['gt_bboxes'].copy() | |||||
| lmks = results['gt_keypointss'].copy() | |||||
| label = results['gt_labels'].copy() | |||||
| # filter small faces | |||||
| size_mask = ((bbox[:, 2] - bbox[:, 0]) > min_size) * ( | |||||
| (bbox[:, 3] - bbox[:, 1]) > min_size) | |||||
| bbox = bbox[size_mask] | |||||
| lmks = lmks[size_mask] | |||||
| label = label[size_mask] | |||||
| # randomly choose a face that has no overlap with others | |||||
| if len(bbox) > 0: | |||||
| overlaps = bbox_overlaps(bbox, bbox) | |||||
| overlaps -= np.eye(overlaps.shape[0]) | |||||
| iou_mask = np.sum(overlaps, axis=1) == 0 | |||||
| bbox = bbox[iou_mask] | |||||
| lmks = lmks[iou_mask] | |||||
| label = label[iou_mask] | |||||
| if len(bbox) > 0: | |||||
| choice = np.random.randint(len(bbox)) | |||||
| bbox = bbox[choice] | |||||
| lmks = lmks[choice] | |||||
| label = [label[choice]] | |||||
| w = bbox[2] - bbox[0] | |||||
| h = bbox[3] - bbox[1] | |||||
| x1 = bbox[0] - w * expand_ratio | |||||
| x2 = bbox[2] + w * expand_ratio | |||||
| y1 = bbox[1] - h * expand_ratio | |||||
| y2 = bbox[3] + h * expand_ratio | |||||
| x1, x2 = np.clip([x1, x2], 0, img.shape[1]) | |||||
| y1, y2 = np.clip([y1, y2], 0, img.shape[0]) | |||||
| bbox -= np.tile([x1, y1], 2) | |||||
| lmks -= (x1, y1, 0) | |||||
| find_bigface = True | |||||
| img = img[int(y1):int(y2), int(x1):int(x2), :] | |||||
| results['gt_bboxes'] = np.expand_dims(bbox, axis=0) | |||||
| results['gt_keypointss'] = np.expand_dims(lmks, axis=0) | |||||
| results['gt_labels'] = np.array(label) | |||||
| results['img'] = img | |||||
| boxes = results['gt_bboxes'] | |||||
| h, w, c = img.shape | |||||
| if self.crop_ratio_range is not None: | |||||
| max_scale = self.crop_ratio_max | |||||
| else: | |||||
| max_scale = np.amax(self.crop_choice) | |||||
| scale_retry = 0 | |||||
| while True: | |||||
| scale_retry += 1 | |||||
| if scale_retry == 1 or max_scale > 1.0: | |||||
| if self.crop_ratio_range is not None: | |||||
| scale = np.random.uniform(self.crop_ratio_min, | |||||
| self.crop_ratio_max) | |||||
| elif self.crop_choice is not None: | |||||
| scale = np.random.choice(self.crop_choice) | |||||
| else: | |||||
| scale = scale * 1.2 | |||||
| if find_bigface: | |||||
| # select a scale from big_face_crop_choice if in big_face mode | |||||
| scale = np.random.choice(self.big_face_crop_choice) | |||||
| for i in range(250): | |||||
| long_side = max(w, h) | |||||
| cw = int(scale * long_side) | |||||
| ch = cw | |||||
| # TODO +1 | |||||
| if w == cw: | |||||
| left = 0 | |||||
| elif w > cw: | |||||
| left = random.randint(0, w - cw) | |||||
| else: | |||||
| left = random.randint(w - cw, 0) | |||||
| if h == ch: | |||||
| top = 0 | |||||
| elif h > ch: | |||||
| top = random.randint(0, h - ch) | |||||
| else: | |||||
| top = random.randint(h - ch, 0) | |||||
| patch = np.array( | |||||
| (int(left), int(top), int(left + cw), int(top + ch)), | |||||
| dtype=np.int32) | |||||
| # center of boxes should inside the crop img | |||||
| # only adjust boxes and instance masks when the gt is not empty | |||||
| # adjust boxes | |||||
| def is_center_of_bboxes_in_patch(boxes, patch): | |||||
| # TODO >= | |||||
| center = (boxes[:, :2] + boxes[:, 2:]) / 2 | |||||
| mask = \ | |||||
| ((center[:, 0] > patch[0]) | |||||
| * (center[:, 1] > patch[1]) | |||||
| * (center[:, 0] < patch[2]) | |||||
| * (center[:, 1] < patch[3])) | |||||
| return mask | |||||
| mask = is_center_of_bboxes_in_patch(boxes, patch) | |||||
| if not mask.any(): | |||||
| continue | |||||
| for key in results.get('bbox_fields', []): | |||||
| boxes = results[key].copy() | |||||
| mask = is_center_of_bboxes_in_patch(boxes, patch) | |||||
| boxes = boxes[mask] | |||||
| if self.bbox_clip_border: | |||||
| boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:]) | |||||
| boxes[:, :2] = boxes[:, :2].clip(min=patch[:2]) | |||||
| boxes -= np.tile(patch[:2], 2) | |||||
| results[key] = boxes | |||||
| # labels | |||||
| label_key = self.bbox2label.get(key) | |||||
| if label_key in results: | |||||
| results[label_key] = results[label_key][mask] | |||||
| # keypoints field | |||||
| if key == 'gt_bboxes': | |||||
| for kps_key in results.get('keypoints_fields', []): | |||||
| keypointss = results[kps_key].copy() | |||||
| keypointss = keypointss[mask, :, :] | |||||
| if self.bbox_clip_border: | |||||
| keypointss[:, :, : | |||||
| 2] = keypointss[:, :, :2].clip( | |||||
| max=patch[2:]) | |||||
| keypointss[:, :, : | |||||
| 2] = keypointss[:, :, :2].clip( | |||||
| min=patch[:2]) | |||||
| keypointss[:, :, 0] -= patch[0] | |||||
| keypointss[:, :, 1] -= patch[1] | |||||
| results[kps_key] = keypointss | |||||
| # mask fields | |||||
| mask_key = self.bbox2mask.get(key) | |||||
| if mask_key in results: | |||||
| results[mask_key] = results[mask_key][mask.nonzero() | |||||
| [0]].crop(patch) | |||||
| # adjust the image regardless of whether the gt was empty before the crop | |||||
| rimg = np.ones((ch, cw, 3), dtype=img.dtype) * 128 | |||||
| patch_from = patch.copy() | |||||
| patch_from[0] = max(0, patch_from[0]) | |||||
| patch_from[1] = max(0, patch_from[1]) | |||||
| patch_from[2] = min(img.shape[1], patch_from[2]) | |||||
| patch_from[3] = min(img.shape[0], patch_from[3]) | |||||
| patch_to = patch.copy() | |||||
| patch_to[0] = max(0, patch_to[0] * -1) | |||||
| patch_to[1] = max(0, patch_to[1] * -1) | |||||
| patch_to[2] = patch_to[0] + (patch_from[2] - patch_from[0]) | |||||
| patch_to[3] = patch_to[1] + (patch_from[3] - patch_from[1]) | |||||
| rimg[patch_to[1]:patch_to[3], | |||||
| patch_to[0]:patch_to[2], :] = img[ | |||||
| patch_from[1]:patch_from[3], | |||||
| patch_from[0]:patch_from[2], :] | |||||
| img = rimg | |||||
| results['img'] = img | |||||
| results['img_shape'] = img.shape | |||||
| return results | |||||
| def __repr__(self): | |||||
| repr_str = self.__class__.__name__ | |||||
| repr_str += f'(crop_ratio_range={self.crop_ratio_range}, ' | |||||
| repr_str += f'crop_choice={self.crop_choice}, ' | |||||
| repr_str += f'big_face_ratio={self.big_face_ratio})' | |||||
| return repr_str | |||||
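A rough sketch of how this transform might appear in an mmdet train-pipeline config; the numeric values below are illustrative, not taken from this patch:

    dict(
        type='RandomSquareCrop',
        crop_choice=[0.3, 0.45, 0.6, 0.8, 1.0],  # square side as a ratio of the long side
        bbox_clip_border=True,
        big_face_ratio=0.3,  # ~30% of samples first crop around one large face
        big_face_crop_choice=[1.5, 2.0, 2.5])  # scales used in big-face mode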
| @@ -13,7 +13,7 @@ class RetinaFaceDataset(CustomDataset): | |||||
| CLASSES = ('FG', ) | CLASSES = ('FG', ) | ||||
| def __init__(self, min_size=None, **kwargs): | def __init__(self, min_size=None, **kwargs): | ||||
| self.NK = 5 | |||||
| self.NK = kwargs.pop('num_kps', 5) | |||||
| self.cat2label = {cat: i for i, cat in enumerate(self.CLASSES)} | self.cat2label = {cat: i for i, cat in enumerate(self.CLASSES)} | ||||
| self.min_size = min_size | self.min_size = min_size | ||||
| self.gt_path = kwargs.get('gt_path') | self.gt_path = kwargs.get('gt_path') | ||||
| @@ -33,7 +33,8 @@ class RetinaFaceDataset(CustomDataset): | |||||
| if len(values) > 4: | if len(values) > 4: | ||||
| if len(values) > 5: | if len(values) > 5: | ||||
| kps = np.array( | kps = np.array( | ||||
| values[4:19], dtype=np.float32).reshape((self.NK, 3)) | |||||
| values[4:4 + self.NK * 3], dtype=np.float32).reshape( | |||||
| (self.NK, 3)) | |||||
| for li in range(kps.shape[0]): | for li in range(kps.shape[0]): | ||||
| if (kps[li, :] == -1).all(): | if (kps[li, :] == -1).all(): | ||||
| kps[li][2] = 0.0 # weight = 0, ignore | kps[li][2] = 0.0 # weight = 0, ignore | ||||
| @@ -103,6 +103,7 @@ class SCRFDHead(AnchorHead): | |||||
| scale_mode=1, | scale_mode=1, | ||||
| dw_conv=False, | dw_conv=False, | ||||
| use_kps=False, | use_kps=False, | ||||
| num_kps=5, | |||||
| loss_kps=dict( | loss_kps=dict( | ||||
| type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.1), | type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.1), | ||||
| **kwargs): | **kwargs): | ||||
| @@ -116,7 +117,7 @@ class SCRFDHead(AnchorHead): | |||||
| self.scale_mode = scale_mode | self.scale_mode = scale_mode | ||||
| self.use_dfl = True | self.use_dfl = True | ||||
| self.dw_conv = dw_conv | self.dw_conv = dw_conv | ||||
| self.NK = 5 | |||||
| self.NK = num_kps | |||||
| self.extra_flops = 0.0 | self.extra_flops = 0.0 | ||||
| if loss_dfl is None or not loss_dfl: | if loss_dfl is None or not loss_dfl: | ||||
| self.use_dfl = False | self.use_dfl = False | ||||
| @@ -323,8 +324,8 @@ class SCRFDHead(AnchorHead): | |||||
| batch_size, -1, self.cls_out_channels).sigmoid() | batch_size, -1, self.cls_out_channels).sigmoid() | ||||
| bbox_pred = bbox_pred.permute(0, 2, 3, | bbox_pred = bbox_pred.permute(0, 2, 3, | ||||
| 1).reshape(batch_size, -1, 4) | 1).reshape(batch_size, -1, 4) | ||||
| kps_pred = kps_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 10) | |||||
| kps_pred = kps_pred.permute(0, 2, 3, | |||||
| 1).reshape(batch_size, -1, self.NK * 2) | |||||
| return cls_score, bbox_pred, kps_pred | return cls_score, bbox_pred, kps_pred | ||||
| def forward_train(self, | def forward_train(self, | ||||
| @@ -788,7 +789,7 @@ class SCRFDHead(AnchorHead): | |||||
| if self.use_dfl: | if self.use_dfl: | ||||
| kps_pred = self.integral(kps_pred) * stride[0] | kps_pred = self.integral(kps_pred) * stride[0] | ||||
| else: | else: | ||||
| kps_pred = kps_pred.reshape((-1, 10)) * stride[0] | |||||
| kps_pred = kps_pred.reshape((-1, self.NK * 2)) * stride[0] | |||||
| nms_pre = cfg.get('nms_pre', -1) | nms_pre = cfg.get('nms_pre', -1) | ||||
| if nms_pre > 0 and scores.shape[0] > nms_pre: | if nms_pre > 0 and scores.shape[0] > nms_pre: | ||||
| @@ -815,7 +816,7 @@ class SCRFDHead(AnchorHead): | |||||
| mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor) | mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor) | ||||
| if mlvl_kps is not None: | if mlvl_kps is not None: | ||||
| scale_factor2 = torch.tensor( | scale_factor2 = torch.tensor( | ||||
| [scale_factor[0], scale_factor[1]] * 5) | |||||
| [scale_factor[0], scale_factor[1]] * self.NK) | |||||
| mlvl_kps /= scale_factor2.to(mlvl_kps.device) | mlvl_kps /= scale_factor2.to(mlvl_kps.device) | ||||
| mlvl_scores = torch.cat(mlvl_scores) | mlvl_scores = torch.cat(mlvl_scores) | ||||
| @@ -54,7 +54,13 @@ class SCRFD(SingleStageDetector): | |||||
| gt_bboxes_ignore) | gt_bboxes_ignore) | ||||
| return losses | return losses | ||||
| def simple_test(self, img, img_metas, rescale=False): | |||||
| def simple_test(self, | |||||
| img, | |||||
| img_metas, | |||||
| rescale=False, | |||||
| repeat_head=1, | |||||
| output_kps_var=0, | |||||
| output_results=1): | |||||
| """Test function without test time augmentation. | """Test function without test time augmentation. | ||||
| Args: | Args: | ||||
| @@ -62,6 +68,9 @@ class SCRFD(SingleStageDetector): | |||||
| img_metas (list[dict]): List of image information. | img_metas (list[dict]): List of image information. | ||||
| rescale (bool, optional): Whether to rescale the results. | rescale (bool, optional): Whether to rescale the results. | ||||
| Defaults to False. | Defaults to False. | ||||
| repeat_head (int): number of times to repeat inference in the head | |||||
| output_kps_var (int): whether to output the keypoint variance used as a quality score | |||||
| output_results (int): 0: nothing, 1: bbox only, 2: both bbox and kps | |||||
| Returns: | Returns: | ||||
| list[list[np.ndarray]]: BBox results of each image and classes. | list[list[np.ndarray]]: BBox results of each image and classes. | ||||
| @@ -69,40 +78,71 @@ class SCRFD(SingleStageDetector): | |||||
| corresponds to each class. | corresponds to each class. | ||||
| """ | """ | ||||
| x = self.extract_feat(img) | x = self.extract_feat(img) | ||||
| outs = self.bbox_head(x) | |||||
| if torch.onnx.is_in_onnx_export(): | |||||
| print('single_stage.py in-onnx-export') | |||||
| print(outs.__class__) | |||||
| cls_score, bbox_pred, kps_pred = outs | |||||
| for c in cls_score: | |||||
| print(c.shape) | |||||
| for c in bbox_pred: | |||||
| print(c.shape) | |||||
| if self.bbox_head.use_kps: | |||||
| for c in kps_pred: | |||||
| assert repeat_head >= 1 | |||||
| kps_out0 = [] | |||||
| kps_out1 = [] | |||||
| kps_out2 = [] | |||||
| for i in range(repeat_head): | |||||
| outs = self.bbox_head(x) | |||||
| kps_out0 += [outs[2][0].detach().cpu().numpy()] | |||||
| kps_out1 += [outs[2][1].detach().cpu().numpy()] | |||||
| kps_out2 += [outs[2][2].detach().cpu().numpy()] | |||||
| if output_kps_var: | |||||
| var0 = np.var(np.vstack(kps_out0), axis=0).mean() | |||||
| var1 = np.var(np.vstack(kps_out1), axis=0).mean() | |||||
| var2 = np.var(np.vstack(kps_out2), axis=0).mean() | |||||
| var = np.mean([var0, var1, var2]) | |||||
| else: | |||||
| var = None | |||||
| if output_results > 0: | |||||
| if torch.onnx.is_in_onnx_export(): | |||||
| print('single_stage.py in-onnx-export') | |||||
| print(outs.__class__) | |||||
| cls_score, bbox_pred, kps_pred = outs | |||||
| for c in cls_score: | |||||
| print(c.shape) | |||||
| for c in bbox_pred: | |||||
| print(c.shape) | print(c.shape) | ||||
| return (cls_score, bbox_pred, kps_pred) | |||||
| else: | |||||
| return (cls_score, bbox_pred) | |||||
| bbox_list = self.bbox_head.get_bboxes( | |||||
| *outs, img_metas, rescale=rescale) | |||||
| if self.bbox_head.use_kps: | |||||
| for c in kps_pred: | |||||
| print(c.shape) | |||||
| return (cls_score, bbox_pred, kps_pred) | |||||
| else: | |||||
| return (cls_score, bbox_pred) | |||||
| bbox_list = self.bbox_head.get_bboxes( | |||||
| *outs, img_metas, rescale=rescale) | |||||
| # return kps if use_kps | |||||
| if len(bbox_list[0]) == 2: | |||||
| bbox_results = [ | |||||
| bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes) | |||||
| for det_bboxes, det_labels in bbox_list | |||||
| ] | |||||
| elif len(bbox_list[0]) == 3: | |||||
| bbox_results = [ | |||||
| bbox2result( | |||||
| det_bboxes, | |||||
| det_labels, | |||||
| self.bbox_head.num_classes, | |||||
| kps=det_kps) | |||||
| for det_bboxes, det_labels, det_kps in bbox_list | |||||
| ] | |||||
| return bbox_results | |||||
| # return kps if use_kps | |||||
| if len(bbox_list[0]) == 2: | |||||
| bbox_results = [ | |||||
| bbox2result(det_bboxes, det_labels, | |||||
| self.bbox_head.num_classes) | |||||
| for det_bboxes, det_labels in bbox_list | |||||
| ] | |||||
| elif len(bbox_list[0]) == 3: | |||||
| if output_results == 2: | |||||
| bbox_results = [ | |||||
| bbox2result( | |||||
| det_bboxes, | |||||
| det_labels, | |||||
| self.bbox_head.num_classes, | |||||
| kps=det_kps, | |||||
| num_kps=self.bbox_head.NK) | |||||
| for det_bboxes, det_labels, det_kps in bbox_list | |||||
| ] | |||||
| elif output_results == 1: | |||||
| bbox_results = [ | |||||
| bbox2result(det_bboxes, det_labels, | |||||
| self.bbox_head.num_classes) | |||||
| for det_bboxes, det_labels, _ in bbox_list | |||||
| ] | |||||
| else: | |||||
| bbox_results = None | |||||
| if var is not None: | |||||
| return bbox_results, var | |||||
| else: | |||||
| return bbox_results | |||||
| def feature_test(self, img): | def feature_test(self, img): | ||||
| x = self.extract_feat(img) | x = self.extract_feat(img) | ||||
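The new arguments allow a rough keypoint-quality score: the head is run `repeat_head` times and the mean variance of its raw keypoint outputs is returned (only meaningful if the head contains stochastic layers such as dropout). A hedged usage sketch, assuming `img`/`img_metas` are preprocessed as usual:

    # bbox results plus a scalar keypoint-variance quality score
    bbox_results, kps_var = detector.simple_test(
        img, img_metas, rescale=True,
        repeat_head=5,      # run the bbox head 5 times
        output_kps_var=1,   # also return the mean variance of the kps outputs
        output_results=1)   # 1: bbox results only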
| @@ -0,0 +1,71 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os.path as osp | |||||
| from copy import deepcopy | |||||
| from typing import Any, Dict | |||||
| import torch | |||||
| from modelscope.metainfo import Models | |||||
| from modelscope.models.base import TorchModel | |||||
| from modelscope.models.builder import MODELS | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | |||||
| __all__ = ['ScrfdDetect'] | |||||
| @MODELS.register_module(Tasks.face_detection, module_name=Models.scrfd) | |||||
| class ScrfdDetect(TorchModel): | |||||
| def __init__(self, model_dir: str, *args, **kwargs): | |||||
| """initialize the face detection model from the `model_dir` path. | |||||
| Args: | |||||
| model_dir (str): the model path. | |||||
| """ | |||||
| super().__init__(model_dir, *args, **kwargs) | |||||
| from mmcv import Config | |||||
| from mmcv.parallel import MMDataParallel | |||||
| from mmcv.runner import load_checkpoint | |||||
| from mmdet.models import build_detector | |||||
| from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets import RetinaFaceDataset | |||||
| from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets.pipelines import RandomSquareCrop | |||||
| from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.backbones import ResNetV1e | |||||
| from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.dense_heads import SCRFDHead | |||||
| from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.detectors import SCRFD | |||||
| cfg = Config.fromfile(osp.join(model_dir, 'mmcv_scrfd.py')) | |||||
| ckpt_path = osp.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE) | |||||
| cfg.model.test_cfg.score_thr = kwargs.get('score_thr', 0.3) | |||||
| detector = build_detector(cfg.model) | |||||
| logger.info(f'loading model from {ckpt_path}') | |||||
| device = torch.device( | |||||
| f'cuda:{0}' if torch.cuda.is_available() else 'cpu') | |||||
| load_checkpoint(detector, ckpt_path, map_location=device) | |||||
| detector = MMDataParallel(detector, device_ids=[0]) | |||||
| detector.eval() | |||||
| self.detector = detector | |||||
| logger.info('load model done') | |||||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||||
| result = self.detector( | |||||
| return_loss=False, | |||||
| rescale=True, | |||||
| img=[input['img'][0].unsqueeze(0)], | |||||
| img_metas=[[dict(input['img_metas'][0].data)]], | |||||
| output_results=2) | |||||
| assert result is not None | |||||
| result = result[0][0] | |||||
| bboxes = result[:, :4].tolist() | |||||
| kpss = result[:, 5:].tolist() | |||||
| scores = result[:, 4].tolist() | |||||
| return { | |||||
| OutputKeys.SCORES: scores, | |||||
| OutputKeys.BOXES: bboxes, | |||||
| OutputKeys.KEYPOINTS: kpss | |||||
| } | |||||
| def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]: | |||||
| return input | |||||
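For reference, a minimal sketch of driving this class directly, assuming `model_dir` points to a downloaded SCRFD model and `data` is the dict produced by the pipeline's preprocess step ({'img': ..., 'img_metas': ...}):

    detector = ScrfdDetect(model_dir=model_dir, score_thr=0.3)
    output = detector(data)
    # output: {'scores': [...], 'boxes': [[x1, y1, x2, y2], ...], 'keypoints': [...]}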
| @@ -90,6 +90,25 @@ TASK_OUTPUTS = { | |||||
| Tasks.face_detection: | Tasks.face_detection: | ||||
| [OutputKeys.SCORES, OutputKeys.BOXES, OutputKeys.KEYPOINTS], | [OutputKeys.SCORES, OutputKeys.BOXES, OutputKeys.KEYPOINTS], | ||||
| # card detection result for single sample | |||||
| # { | |||||
| # "scores": [0.9, 0.1, 0.05, 0.05] | |||||
| # "boxes": [ | |||||
| # [x1, y1, x2, y2], | |||||
| # [x1, y1, x2, y2], | |||||
| # [x1, y1, x2, y2], | |||||
| # [x1, y1, x2, y2], | |||||
| # ], | |||||
| # "keypoints": [ | |||||
| # [x1, y1, x2, y2, x3, y3, x4, y4], | |||||
| # [x1, y1, x2, y2, x3, y3, x4, y4], | |||||
| # [x1, y1, x2, y2, x3, y3, x4, y4], | |||||
| # [x1, y1, x2, y2, x3, y3, x4, y4], | |||||
| # ], | |||||
| # } | |||||
| Tasks.card_detection: | |||||
| [OutputKeys.SCORES, OutputKeys.BOXES, OutputKeys.KEYPOINTS], | |||||
| # facial expression recognition result for single sample | # facial expression recognition result for single sample | ||||
| # { | # { | ||||
| # "scores": [0.9, 0.1, 0.02, 0.02, 0.02, 0.02, 0.02], | # "scores": [0.9, 0.1, 0.02, 0.02, 0.02, 0.02, 0.02], | ||||
| @@ -116,6 +116,10 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||||
| Tasks.hand_2d_keypoints: | Tasks.hand_2d_keypoints: | ||||
| (Pipelines.hand_2d_keypoints, | (Pipelines.hand_2d_keypoints, | ||||
| 'damo/cv_hrnetw18_hand-pose-keypoints_coco-wholebody'), | 'damo/cv_hrnetw18_hand-pose-keypoints_coco-wholebody'), | ||||
| Tasks.face_detection: (Pipelines.face_detection, | |||||
| 'damo/cv_resnet_facedetection_scrfd10gkps'), | |||||
| Tasks.card_detection: (Pipelines.card_detection, | |||||
| 'damo/cv_resnet_carddetection_scrfd34gkps'), | |||||
| Tasks.face_detection: | Tasks.face_detection: | ||||
| (Pipelines.face_detection, | (Pipelines.face_detection, | ||||
| 'damo/cv_resnet101_face-detection_cvpr22papermogface'), | 'damo/cv_resnet101_face-detection_cvpr22papermogface'), | ||||
| @@ -0,0 +1,23 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.pipelines.builder import PIPELINES | |||||
| from modelscope.pipelines.cv.face_detection_pipeline import \ | |||||
| FaceDetectionPipeline | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | |||||
| @PIPELINES.register_module( | |||||
| Tasks.card_detection, module_name=Pipelines.card_detection) | |||||
| class CardDetectionPipeline(FaceDetectionPipeline): | |||||
| def __init__(self, model: str, **kwargs): | |||||
| """ | |||||
| Use `model` to create a card detection pipeline for prediction. | |||||
| Args: | |||||
| model: model id on modelscope hub. | |||||
| """ | |||||
| thr = 0.45 # card and face detection use different score thresholds | |||||
| super().__init__(model=model, score_thr=thr, **kwargs) | |||||
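Typical usage, mirroring the unit test further below:

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    card_detection = pipeline(
        Tasks.card_detection, model='damo/cv_resnet_carddetection_scrfd34gkps')
    result = card_detection('data/test/images/card_detection.jpg')
    # result: {'scores': [...], 'boxes': [...], 'keypoints': [...]}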
| @@ -8,6 +8,7 @@ import PIL | |||||
| import torch | import torch | ||||
| from modelscope.metainfo import Pipelines | from modelscope.metainfo import Pipelines | ||||
| from modelscope.models.cv.face_detection import ScrfdDetect | |||||
| from modelscope.outputs import OutputKeys | from modelscope.outputs import OutputKeys | ||||
| from modelscope.pipelines.base import Input, Pipeline | from modelscope.pipelines.base import Input, Pipeline | ||||
| from modelscope.pipelines.builder import PIPELINES | from modelscope.pipelines.builder import PIPELINES | ||||
| @@ -29,27 +30,8 @@ class FaceDetectionPipeline(Pipeline): | |||||
| model: model id on modelscope hub. | model: model id on modelscope hub. | ||||
| """ | """ | ||||
| super().__init__(model=model, **kwargs) | super().__init__(model=model, **kwargs) | ||||
| from mmcv import Config | |||||
| from mmcv.parallel import MMDataParallel | |||||
| from mmcv.runner import load_checkpoint | |||||
| from mmdet.models import build_detector | |||||
| from modelscope.models.cv.face_detection.mmdet_patch.datasets import RetinaFaceDataset | |||||
| from modelscope.models.cv.face_detection.mmdet_patch.datasets.pipelines import RandomSquareCrop | |||||
| from modelscope.models.cv.face_detection.mmdet_patch.models.backbones import ResNetV1e | |||||
| from modelscope.models.cv.face_detection.mmdet_patch.models.dense_heads import SCRFDHead | |||||
| from modelscope.models.cv.face_detection.mmdet_patch.models.detectors import SCRFD | |||||
| cfg = Config.fromfile(osp.join(model, 'mmcv_scrfd_10g_bnkps.py')) | |||||
| detector = build_detector( | |||||
| cfg.model, train_cfg=None, test_cfg=cfg.test_cfg) | |||||
| ckpt_path = osp.join(model, ModelFile.TORCH_MODEL_BIN_FILE) | |||||
| logger.info(f'loading model from {ckpt_path}') | |||||
| device = torch.device( | |||||
| f'cuda:{0}' if torch.cuda.is_available() else 'cpu') | |||||
| load_checkpoint(detector, ckpt_path, map_location=device) | |||||
| detector = MMDataParallel(detector, device_ids=[0]) | |||||
| detector.eval() | |||||
| detector = ScrfdDetect(model_dir=model, **kwargs) | |||||
| self.detector = detector | self.detector = detector | ||||
| logger.info('load model done') | |||||
| def preprocess(self, input: Input) -> Dict[str, Any]: | def preprocess(self, input: Input) -> Dict[str, Any]: | ||||
| img = LoadImage.convert_to_ndarray(input) | img = LoadImage.convert_to_ndarray(input) | ||||
| @@ -85,22 +67,7 @@ class FaceDetectionPipeline(Pipeline): | |||||
| return result | return result | ||||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | ||||
| result = self.detector( | |||||
| return_loss=False, | |||||
| rescale=True, | |||||
| img=[input['img'][0].unsqueeze(0)], | |||||
| img_metas=[[dict(input['img_metas'][0].data)]]) | |||||
| assert result is not None | |||||
| result = result[0][0] | |||||
| bboxes = result[:, :4].tolist() | |||||
| kpss = result[:, 5:].tolist() | |||||
| scores = result[:, 4].tolist() | |||||
| return { | |||||
| OutputKeys.SCORES: scores, | |||||
| OutputKeys.BOXES: bboxes, | |||||
| OutputKeys.KEYPOINTS: kpss | |||||
| } | |||||
| return self.detector(input) | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | ||||
| return inputs | return inputs | ||||
| @@ -49,7 +49,7 @@ class FaceRecognitionPipeline(Pipeline): | |||||
| # face detect pipeline | # face detect pipeline | ||||
| det_model_id = 'damo/cv_resnet_facedetection_scrfd10gkps' | det_model_id = 'damo/cv_resnet_facedetection_scrfd10gkps' | ||||
| self.face_detection = pipeline( | self.face_detection = pipeline( | ||||
| Tasks.face_detection, model=det_model_id) | |||||
| Tasks.face_detection, model=det_model_id, model_revision='v2') | |||||
| def _choose_face(self, | def _choose_face(self, | ||||
| det_result, | det_result, | ||||
| @@ -0,0 +1,18 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from modelscope.metainfo import Trainers | |||||
| from modelscope.trainers.builder import TRAINERS | |||||
| from modelscope.trainers.cv.face_detection_scrfd_trainer import \ | |||||
| FaceDetectionScrfdTrainer | |||||
| @TRAINERS.register_module(module_name=Trainers.card_detection_scrfd) | |||||
| class CardDetectionScrfdTrainer(FaceDetectionScrfdTrainer): | |||||
| def __init__(self, cfg_file: str, *args, **kwargs): | |||||
| """ High-level finetune api for SCRFD. | |||||
| Args: | |||||
| cfg_file: Path to configuration file. | |||||
| """ | |||||
| # card and face datasets use different image folder names | |||||
| super().__init__(cfg_file, imgdir_name='', **kwargs) | |||||
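The trainer is built through the registry exactly like the face variant; a minimal sketch (`kwargs` as in the unit tests below):

    from modelscope.metainfo import Trainers
    from modelscope.trainers import build_trainer

    trainer = build_trainer(
        name=Trainers.card_detection_scrfd, default_args=kwargs)
    trainer.train()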
| @@ -0,0 +1,154 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import copy | |||||
| import os | |||||
| import os.path as osp | |||||
| import time | |||||
| from typing import Callable, Dict, Optional | |||||
| from modelscope.metainfo import Trainers | |||||
| from modelscope.trainers.base import BaseTrainer | |||||
| from modelscope.trainers.builder import TRAINERS | |||||
| @TRAINERS.register_module(module_name=Trainers.face_detection_scrfd) | |||||
| class FaceDetectionScrfdTrainer(BaseTrainer): | |||||
| def __init__(self, | |||||
| cfg_file: str, | |||||
| cfg_modify_fn: Optional[Callable] = None, | |||||
| *args, | |||||
| **kwargs): | |||||
| """ High-level finetune api for SCRFD. | |||||
| Args: | |||||
| cfg_file: Path to configuration file. | |||||
| cfg_modify_fn: A function used to modify the config after it is read from the file. | |||||
| """ | |||||
| import mmcv | |||||
| from mmcv.runner import get_dist_info, init_dist | |||||
| from mmcv.utils import get_git_hash | |||||
| from mmdet.utils import collect_env, get_root_logger | |||||
| from mmdet.apis import set_random_seed | |||||
| from mmdet.models import build_detector | |||||
| from mmdet.datasets import build_dataset | |||||
| from mmdet import __version__ | |||||
| from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets import RetinaFaceDataset | |||||
| from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets.pipelines import DefaultFormatBundleV2 | |||||
| from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets.pipelines import LoadAnnotationsV2 | |||||
| from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets.pipelines import RotateV2 | |||||
| from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets.pipelines import RandomSquareCrop | |||||
| from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.backbones import ResNetV1e | |||||
| from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.dense_heads import SCRFDHead | |||||
| from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.detectors import SCRFD | |||||
| super().__init__(cfg_file) | |||||
| cfg = self.cfg | |||||
| if 'work_dir' in kwargs: | |||||
| cfg.work_dir = kwargs['work_dir'] | |||||
| else: | |||||
| # use config filename as default work_dir if work_dir is None | |||||
| cfg.work_dir = osp.join('./work_dirs', | |||||
| osp.splitext(osp.basename(cfg_file))[0]) | |||||
| mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) | |||||
| if 'resume_from' in kwargs: # pretrain model for finetune | |||||
| cfg.resume_from = kwargs['resume_from'] | |||||
| cfg.device = 'cuda' | |||||
| if 'gpu_ids' in kwargs: | |||||
| cfg.gpu_ids = kwargs['gpu_ids'] | |||||
| else: | |||||
| cfg.gpu_ids = range(1) | |||||
| labelfile_name = kwargs.pop('labelfile_name', 'labelv2.txt') | |||||
| imgdir_name = kwargs.pop('imgdir_name', 'images/') | |||||
| if 'train_root' in kwargs: | |||||
| cfg.data.train.ann_file = kwargs['train_root'] + labelfile_name | |||||
| cfg.data.train.img_prefix = kwargs['train_root'] + imgdir_name | |||||
| if 'val_root' in kwargs: | |||||
| cfg.data.val.ann_file = kwargs['val_root'] + labelfile_name | |||||
| cfg.data.val.img_prefix = kwargs['val_root'] + imgdir_name | |||||
| if 'total_epochs' in kwargs: | |||||
| cfg.total_epochs = kwargs['total_epochs'] | |||||
| if cfg_modify_fn is not None: | |||||
| cfg = cfg_modify_fn(cfg) | |||||
| if 'launcher' in kwargs: | |||||
| distributed = True | |||||
| init_dist(kwargs['launcher'], **cfg.dist_params) | |||||
| # re-set gpu_ids with distributed training mode | |||||
| _, world_size = get_dist_info() | |||||
| cfg.gpu_ids = range(world_size) | |||||
| else: | |||||
| distributed = False | |||||
| # no_validate=True disables checkpoint evaluation during training | |||||
| cfg.no_validate = kwargs.get('no_validate', False) | |||||
| # init the logger before other steps | |||||
| timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) | |||||
| log_file = osp.join(cfg.work_dir, f'{timestamp}.log') | |||||
| logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) | |||||
| # init the meta dict to record some important information such as | |||||
| # environment info and seed, which will be logged | |||||
| meta = dict() | |||||
| # log env info | |||||
| env_info_dict = collect_env() | |||||
| env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()]) | |||||
| dash_line = '-' * 60 + '\n' | |||||
| logger.info('Environment info:\n' + dash_line + env_info + '\n' | |||||
| + dash_line) | |||||
| meta['env_info'] = env_info | |||||
| meta['config'] = cfg.pretty_text | |||||
| # log some basic info | |||||
| logger.info(f'Distributed training: {distributed}') | |||||
| logger.info(f'Config:\n{cfg.pretty_text}') | |||||
| # set random seeds | |||||
| if 'seed' in kwargs: | |||||
| cfg.seed = kwargs['seed'] | |||||
| _deterministic = kwargs.get('deterministic', False) | |||||
| logger.info(f'Set random seed to {kwargs["seed"]}, ' | |||||
| f'deterministic: {_deterministic}') | |||||
| set_random_seed(kwargs['seed'], deterministic=_deterministic) | |||||
| else: | |||||
| cfg.seed = None | |||||
| meta['seed'] = cfg.seed | |||||
| meta['exp_name'] = osp.basename(cfg_file) | |||||
| model = build_detector(cfg.model) | |||||
| model.init_weights() | |||||
| datasets = [build_dataset(cfg.data.train)] | |||||
| if len(cfg.workflow) == 2: | |||||
| val_dataset = copy.deepcopy(cfg.data.val) | |||||
| val_dataset.pipeline = cfg.data.train.pipeline | |||||
| datasets.append(build_dataset(val_dataset)) | |||||
| if cfg.checkpoint_config is not None: | |||||
| # save mmdet version, config file content and class names in | |||||
| # checkpoints as meta data | |||||
| cfg.checkpoint_config.meta = dict( | |||||
| mmdet_version=__version__ + get_git_hash()[:7], | |||||
| CLASSES=datasets[0].CLASSES) | |||||
| # add an attribute for visualization convenience | |||||
| model.CLASSES = datasets[0].CLASSES | |||||
| self.cfg = cfg | |||||
| self.datasets = datasets | |||||
| self.model = model | |||||
| self.distributed = distributed | |||||
| self.timestamp = timestamp | |||||
| self.meta = meta | |||||
| self.logger = logger | |||||
| def train(self, *args, **kwargs): | |||||
| from mmdet.apis import train_detector | |||||
| train_detector( | |||||
| self.model, | |||||
| self.datasets, | |||||
| self.cfg, | |||||
| distributed=self.distributed, | |||||
| validate=(not self.cfg.no_validate), | |||||
| timestamp=self.timestamp, | |||||
| meta=self.meta) | |||||
| def evaluate(self, | |||||
| checkpoint_path: str = None, | |||||
| *args, | |||||
| **kwargs) -> Dict[str, float]: | |||||
| cfg = self.cfg.evaluation | |||||
| self.logger.info(f'eval cfg {cfg}') | |||||
| self.logger.info(f'checkpoint_path {checkpoint_path}') | |||||
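Putting the trainer together for finetuning, as exercised by the unit tests below; `cache_path`, `train_root` and `val_root` are placeholders for the downloaded model directory and the dataset roots (each dataset root is expected to contain labelv2.txt and an images/ folder, unless overridden):

    import os
    from modelscope.metainfo import Trainers
    from modelscope.trainers import build_trainer
    from modelscope.utils.constant import ModelFile

    kwargs = dict(
        cfg_file=os.path.join(cache_path, 'mmcv_scrfd.py'),
        work_dir='./work_dirs/scrfd_finetune',
        train_root=train_root,
        val_root=val_root,
        total_epochs=641,  # pretrain epochs (640) + 1 finetune epoch
        resume_from=os.path.join(cache_path, ModelFile.TORCH_MODEL_BIN_FILE))
    trainer = build_trainer(
        name=Trainers.face_detection_scrfd, default_args=kwargs)
    trainer.train()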
| @@ -19,6 +19,7 @@ class CVTasks(object): | |||||
| # human face body related | # human face body related | ||||
| animal_recognition = 'animal-recognition' | animal_recognition = 'animal-recognition' | ||||
| face_detection = 'face-detection' | face_detection = 'face-detection' | ||||
| card_detection = 'card-detection' | |||||
| face_recognition = 'face-recognition' | face_recognition = 'face-recognition' | ||||
| facial_expression_recognition = 'facial-expression-recognition' | facial_expression_recognition = 'facial-expression-recognition' | ||||
| face_2d_keypoints = 'face-2d-keypoints' | face_2d_keypoints = 'face-2d-keypoints' | ||||
| @@ -154,6 +154,54 @@ def draw_face_detection_result(img_path, detection_result): | |||||
| return img | return img | ||||
| def draw_card_detection_result(img_path, detection_result): | |||||
| def warp_img(src_img, kps, ratio): | |||||
| short_size = 500 | |||||
| if ratio > 1: | |||||
| obj_h = short_size | |||||
| obj_w = int(obj_h * ratio) | |||||
| else: | |||||
| obj_w = short_size | |||||
| obj_h = int(obj_w / ratio) | |||||
| input_pts = np.float32([kps[0], kps[1], kps[2], kps[3]]) | |||||
| output_pts = np.float32([[0, obj_h - 1], [0, 0], [obj_w - 1, 0], | |||||
| [obj_w - 1, obj_h - 1]]) | |||||
| M = cv2.getPerspectiveTransform(input_pts, output_pts) | |||||
| obj_img = cv2.warpPerspective(src_img, M, (obj_w, obj_h)) | |||||
| return obj_img | |||||
| bboxes = np.array(detection_result[OutputKeys.BOXES]) | |||||
| kpss = np.array(detection_result[OutputKeys.KEYPOINTS]) | |||||
| scores = np.array(detection_result[OutputKeys.SCORES]) | |||||
| img_list = [] | |||||
| ver_col = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (0, 255, 255)] | |||||
| img = cv2.imread(img_path) | |||||
| assert img is not None, f"Can't read img: {img_path}" | |||||
| img_list += [img] | |||||
| for i in range(len(scores)): | |||||
| bbox = bboxes[i].astype(np.int32) | |||||
| kps = kpss[i].reshape(-1, 2).astype(np.int32) | |||||
| # squared lengths of two adjacent sides of the detected quad | |||||
| _w = (kps[0][0] - kps[3][0])**2 + (kps[0][1] - kps[3][1])**2 | |||||
| _h = (kps[0][0] - kps[1][0])**2 + (kps[0][1] - kps[1][1])**2 | |||||
| # 1.59 roughly matches the ID-1 card aspect ratio (85.60mm x 53.98mm) | |||||
| ratio = 1.59 if _w >= _h else 1 / 1.59 | |||||
| card_img = warp_img(img, kps, ratio) | |||||
| img_list += [card_img] | |||||
| score = scores[i] | |||||
| x1, y1, x2, y2 = bbox | |||||
| cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 4) | |||||
| for k, kp in enumerate(kps): | |||||
| cv2.circle(img, tuple(kp), 1, color=ver_col[k], thickness=10) | |||||
| cv2.putText( | |||||
| img, | |||||
| f'{score:.2f}', (x1, y2), | |||||
| 1, | |||||
| 1.0, (0, 255, 0), | |||||
| thickness=1, | |||||
| lineType=8) | |||||
| return img_list | |||||
| def created_boxed_image(image_in, box): | def created_boxed_image(image_in, box): | ||||
| image = load_image(image_in) | image = load_image(image_in) | ||||
| img = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) | img = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) | ||||
| @@ -0,0 +1,66 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os.path as osp | |||||
| import unittest | |||||
| import cv2 | |||||
| from modelscope.msdatasets import MsDataset | |||||
| from modelscope.pipelines import pipeline | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.cv.image_utils import draw_card_detection_result | |||||
| from modelscope.utils.demo_utils import DemoCompatibilityCheck | |||||
| from modelscope.utils.test_utils import test_level | |||||
| class CardDetectionTest(unittest.TestCase, DemoCompatibilityCheck): | |||||
| def setUp(self) -> None: | |||||
| self.task = Tasks.card_detection | |||||
| self.model_id = 'damo/cv_resnet_carddetection_scrfd34gkps' | |||||
| def show_result(self, img_path, detection_result): | |||||
| img_list = draw_card_detection_result(img_path, detection_result) | |||||
| for i, img in enumerate(img_list): | |||||
| if i == 0: | |||||
| cv2.imwrite('result.jpg', img_list[0]) | |||||
| print( | |||||
| f'Found {len(img_list)-1} cards, output written to {osp.abspath("result.jpg")}' | |||||
| ) | |||||
| else: | |||||
| cv2.imwrite(f'card_{i}.jpg', img_list[i]) | |||||
| save_path = osp.abspath(f'card_{i}.jpg') | |||||
| print(f'detect card_{i}: {save_path}') | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
| def test_run_with_dataset(self): | |||||
| input_location = ['data/test/images/card_detection.jpg'] | |||||
| dataset = MsDataset.load(input_location, target='image') | |||||
| card_detection = pipeline(Tasks.card_detection, model=self.model_id) | |||||
| # Note: with a dataset as input, the inference output is a generator that can be iterated. | |||||
| result = card_detection(dataset) | |||||
| result = next(result) | |||||
| self.show_result(input_location[0], result) | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run_modelhub(self): | |||||
| card_detection = pipeline(Tasks.card_detection, model=self.model_id) | |||||
| img_path = 'data/test/images/card_detection.jpg' | |||||
| result = card_detection(img_path) | |||||
| self.show_result(img_path, result) | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_run_modelhub_default_model(self): | |||||
| card_detection = pipeline(Tasks.card_detection) | |||||
| img_path = 'data/test/images/card_detection.jpg' | |||||
| result = card_detection(img_path) | |||||
| self.show_result(img_path, result) | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_demo_compatibility(self): | |||||
| self.compatibility_check() | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||
| @@ -25,10 +25,11 @@ class FaceDetectionTest(unittest.TestCase, DemoCompatibilityCheck): | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | ||||
| def test_run_with_dataset(self): | def test_run_with_dataset(self): | ||||
| input_location = ['data/test/images/face_detection.png'] | |||||
| input_location = ['data/test/images/face_detection2.jpeg'] | |||||
| dataset = MsDataset.load(input_location, target='image') | dataset = MsDataset.load(input_location, target='image') | ||||
| face_detection = pipeline(Tasks.face_detection, model=self.model_id) | |||||
| face_detection = pipeline( | |||||
| Tasks.face_detection, model=self.model_id, model_revision='v2') | |||||
| # note that for dataset output, the inference-output is a Generator that can be iterated. | # note that for dataset output, the inference-output is a Generator that can be iterated. | ||||
| result = face_detection(dataset) | result = face_detection(dataset) | ||||
| result = next(result) | result = next(result) | ||||
| @@ -36,8 +37,9 @@ class FaceDetectionTest(unittest.TestCase, DemoCompatibilityCheck): | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | ||||
| def test_run_modelhub(self): | def test_run_modelhub(self): | ||||
| face_detection = pipeline(Tasks.face_detection, model=self.model_id) | |||||
| img_path = 'data/test/images/face_detection.png' | |||||
| face_detection = pipeline( | |||||
| Tasks.face_detection, model=self.model_id, model_revision='v2') | |||||
| img_path = 'data/test/images/face_detection2.jpeg' | |||||
| result = face_detection(img_path) | result = face_detection(img_path) | ||||
| self.show_result(img_path, result) | self.show_result(img_path, result) | ||||
| @@ -45,7 +47,7 @@ class FaceDetectionTest(unittest.TestCase, DemoCompatibilityCheck): | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | ||||
| def test_run_modelhub_default_model(self): | def test_run_modelhub_default_model(self): | ||||
| face_detection = pipeline(Tasks.face_detection) | face_detection = pipeline(Tasks.face_detection) | ||||
| img_path = 'data/test/images/face_detection.png' | |||||
| img_path = 'data/test/images/face_detection2.jpeg' | |||||
| result = face_detection(img_path) | result = face_detection(img_path) | ||||
| self.show_result(img_path, result) | self.show_result(img_path, result) | ||||
| @@ -0,0 +1,151 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import glob | |||||
| import os | |||||
| import shutil | |||||
| import tempfile | |||||
| import unittest | |||||
| import torch | |||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.metainfo import Trainers | |||||
| from modelscope.msdatasets import MsDataset | |||||
| from modelscope.trainers import build_trainer | |||||
| from modelscope.utils.config import Config | |||||
| from modelscope.utils.constant import ModelFile | |||||
| from modelscope.utils.test_utils import DistributedTestCase, test_level | |||||
| def _setup(): | |||||
| model_id = 'damo/cv_resnet_carddetection_scrfd34gkps' | |||||
| # Mini dataset for unit tests only; remove '_mini' to use the full dataset. | |||||
| ms_ds_syncards = MsDataset.load( | |||||
| 'SyntheticCards_mini', namespace='shaoxuan') | |||||
| data_path = ms_ds_syncards.config_kwargs['split_config'] | |||||
| train_dir = data_path['train'] | |||||
| val_dir = data_path['validation'] | |||||
| train_root = train_dir + '/' + os.listdir(train_dir)[0] + '/' | |||||
| val_root = val_dir + '/' + os.listdir(val_dir)[0] + '/' | |||||
| max_epochs = 1 # run epochs in unit test | |||||
| cache_path = snapshot_download(model_id) | |||||
| tmp_dir = tempfile.TemporaryDirectory().name | |||||
| if not os.path.exists(tmp_dir): | |||||
| os.makedirs(tmp_dir) | |||||
| return train_root, val_root, max_epochs, cache_path, tmp_dir | |||||
| def train_func(**kwargs): | |||||
| trainer = build_trainer( | |||||
| name=Trainers.card_detection_scrfd, default_args=kwargs) | |||||
| trainer.train() | |||||
| class TestCardDetectionScrfdTrainerSingleGPU(unittest.TestCase): | |||||
| def setUp(self): | |||||
| print(('SingleGPU Testing %s.%s' % | |||||
| (type(self).__name__, self._testMethodName))) | |||||
| self.train_root, self.val_root, self.max_epochs, self.cache_path, self.tmp_dir = _setup( | |||||
| ) | |||||
| def tearDown(self): | |||||
| shutil.rmtree(self.tmp_dir) | |||||
| super().tearDown() | |||||
| def _cfg_modify_fn(self, cfg): | |||||
| cfg.checkpoint_config.interval = 1 | |||||
| cfg.log_config.interval = 10 | |||||
| cfg.evaluation.interval = 1 | |||||
| cfg.data.workers_per_gpu = 3 | |||||
| cfg.data.samples_per_gpu = 4 # batch size | |||||
| return cfg | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_trainer_from_scratch(self): | |||||
| kwargs = dict( | |||||
| cfg_file=os.path.join(self.cache_path, 'mmcv_scrfd.py'), | |||||
| work_dir=self.tmp_dir, | |||||
| train_root=self.train_root, | |||||
| val_root=self.val_root, | |||||
| total_epochs=self.max_epochs, | |||||
| cfg_modify_fn=self._cfg_modify_fn) | |||||
| trainer = build_trainer( | |||||
| name=Trainers.card_detection_scrfd, default_args=kwargs) | |||||
| trainer.train() | |||||
| results_files = os.listdir(self.tmp_dir) | |||||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||||
| for i in range(self.max_epochs): | |||||
| self.assertIn(f'epoch_{i+1}.pth', results_files) | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_trainer_finetune(self): | |||||
| pretrain_epoch = 640 | |||||
| self.max_epochs += pretrain_epoch | |||||
| kwargs = dict( | |||||
| cfg_file=os.path.join(self.cache_path, 'mmcv_scrfd.py'), | |||||
| work_dir=self.tmp_dir, | |||||
| train_root=self.train_root, | |||||
| val_root=self.val_root, | |||||
| total_epochs=self.max_epochs, | |||||
| resume_from=os.path.join(self.cache_path, | |||||
| ModelFile.TORCH_MODEL_BIN_FILE), | |||||
| cfg_modify_fn=self._cfg_modify_fn) | |||||
| trainer = build_trainer( | |||||
| name=Trainers.card_detection_scrfd, default_args=kwargs) | |||||
| trainer.train() | |||||
| results_files = os.listdir(self.tmp_dir) | |||||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||||
| for i in range(pretrain_epoch, self.max_epochs): | |||||
| self.assertIn(f'epoch_{i+1}.pth', results_files) | |||||
| @unittest.skipIf(not torch.cuda.is_available() | |||||
| or torch.cuda.device_count() <= 1, 'distributed unittest') | |||||
| class TestCardDetectionScrfdTrainerMultiGpus(DistributedTestCase): | |||||
| def setUp(self): | |||||
| print(('MultiGPUs Testing %s.%s' % | |||||
| (type(self).__name__, self._testMethodName))) | |||||
| self.train_root, self.val_root, self.max_epochs, self.cache_path, self.tmp_dir = _setup( | |||||
| ) | |||||
| cfg_file_path = os.path.join(self.cache_path, 'mmcv_scrfd.py') | |||||
| cfg = Config.from_file(cfg_file_path) | |||||
| cfg.checkpoint_config.interval = 1 | |||||
| cfg.log_config.interval = 10 | |||||
| cfg.evaluation.interval = 1 | |||||
| cfg.data.workers_per_gpu = 3 | |||||
| cfg.data.samples_per_gpu = 4 | |||||
| cfg.dump(cfg_file_path) | |||||
| def tearDown(self): | |||||
| shutil.rmtree(self.tmp_dir) | |||||
| super().tearDown() | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
| def test_multi_gpus_finetune(self): | |||||
| pretrain_epoch = 640 | |||||
| self.max_epochs += pretrain_epoch | |||||
| kwargs = dict( | |||||
| cfg_file=os.path.join(self.cache_path, 'mmcv_scrfd.py'), | |||||
| work_dir=self.tmp_dir, | |||||
| train_root=self.train_root, | |||||
| val_root=self.val_root, | |||||
| total_epochs=self.max_epochs, | |||||
| resume_from=os.path.join(self.cache_path, | |||||
| ModelFile.TORCH_MODEL_BIN_FILE), | |||||
| launcher='pytorch') | |||||
| self.start(train_func, num_gpus=2, **kwargs) | |||||
| results_files = os.listdir(self.tmp_dir) | |||||
| json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json')) | |||||
| self.assertEqual(len(json_files), 1) | |||||
| for i in range(pretrain_epoch, self.max_epochs): | |||||
| self.assertIn(f'epoch_{i+1}.pth', results_files) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||
| @@ -0,0 +1,150 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import glob | |||||
| import os | |||||
| import shutil | |||||
| import tempfile | |||||
| import unittest | |||||
| import torch | |||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.metainfo import Trainers | |||||
| from modelscope.msdatasets import MsDataset | |||||
| from modelscope.trainers import build_trainer | |||||
| from modelscope.utils.config import Config | |||||
| from modelscope.utils.constant import ModelFile | |||||
| from modelscope.utils.test_utils import DistributedTestCase, test_level | |||||
| def _setup(): | |||||
| model_id = 'damo/cv_resnet_facedetection_scrfd10gkps' | |||||
| # Mini dataset for unit tests only; remove '_mini' to use the full dataset. | |||||
| ms_ds_widerface = MsDataset.load('WIDER_FACE_mini', namespace='shaoxuan') | |||||
| data_path = ms_ds_widerface.config_kwargs['split_config'] | |||||
| train_dir = data_path['train'] | |||||
| val_dir = data_path['validation'] | |||||
| train_root = train_dir + '/' + os.listdir(train_dir)[0] + '/' | |||||
| val_root = val_dir + '/' + os.listdir(val_dir)[0] + '/' | |||||
| max_epochs = 1 # run epochs in unit test | |||||
| cache_path = snapshot_download(model_id, revision='v2') | |||||
| tmp_dir = tempfile.TemporaryDirectory().name | |||||
| if not os.path.exists(tmp_dir): | |||||
| os.makedirs(tmp_dir) | |||||
| return train_root, val_root, max_epochs, cache_path, tmp_dir | |||||
| def train_func(**kwargs): | |||||
| trainer = build_trainer( | |||||
| name=Trainers.face_detection_scrfd, default_args=kwargs) | |||||
| trainer.train() | |||||
| class TestFaceDetectionScrfdTrainerSingleGPU(unittest.TestCase): | |||||
| def setUp(self): | |||||
| print(('SingleGPU Testing %s.%s' % | |||||
| (type(self).__name__, self._testMethodName))) | |||||
| self.train_root, self.val_root, self.max_epochs, self.cache_path, self.tmp_dir = _setup( | |||||
| ) | |||||
| def tearDown(self): | |||||
| shutil.rmtree(self.tmp_dir) | |||||
| super().tearDown() | |||||
| def _cfg_modify_fn(self, cfg): | |||||
| cfg.checkpoint_config.interval = 1 | |||||
| cfg.log_config.interval = 10 | |||||
| cfg.evaluation.interval = 1 | |||||
| cfg.data.workers_per_gpu = 3 | |||||
| cfg.data.samples_per_gpu = 4 # batch size | |||||
| return cfg | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_trainer_from_scratch(self): | |||||
| kwargs = dict( | |||||
| cfg_file=os.path.join(self.cache_path, 'mmcv_scrfd.py'), | |||||
| work_dir=self.tmp_dir, | |||||
| train_root=self.train_root, | |||||
| val_root=self.val_root, | |||||
| total_epochs=self.max_epochs, | |||||
| cfg_modify_fn=self._cfg_modify_fn) | |||||
| trainer = build_trainer( | |||||
| name=Trainers.face_detection_scrfd, default_args=kwargs) | |||||
| trainer.train() | |||||
| results_files = os.listdir(self.tmp_dir) | |||||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||||
| for i in range(self.max_epochs): | |||||
| self.assertIn(f'epoch_{i+1}.pth', results_files) | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_trainer_finetune(self): | |||||
| pretrain_epoch = 640 | |||||
| self.max_epochs += pretrain_epoch | |||||
| kwargs = dict( | |||||
| cfg_file=os.path.join(self.cache_path, 'mmcv_scrfd.py'), | |||||
| work_dir=self.tmp_dir, | |||||
| train_root=self.train_root, | |||||
| val_root=self.val_root, | |||||
| total_epochs=self.max_epochs, | |||||
| resume_from=os.path.join(self.cache_path, | |||||
| ModelFile.TORCH_MODEL_BIN_FILE), | |||||
| cfg_modify_fn=self._cfg_modify_fn) | |||||
| trainer = build_trainer( | |||||
| name=Trainers.face_detection_scrfd, default_args=kwargs) | |||||
| trainer.train() | |||||
| results_files = os.listdir(self.tmp_dir) | |||||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||||
| for i in range(pretrain_epoch, self.max_epochs): | |||||
| self.assertIn(f'epoch_{i+1}.pth', results_files) | |||||
| @unittest.skipIf(not torch.cuda.is_available() | |||||
| or torch.cuda.device_count() <= 1, 'distributed unittest') | |||||
| class TestFaceDetectionScrfdTrainerMultiGpus(DistributedTestCase): | |||||
| def setUp(self): | |||||
| print(('MultiGPUs Testing %s.%s' % | |||||
| (type(self).__name__, self._testMethodName))) | |||||
| self.train_root, self.val_root, self.max_epochs, self.cache_path, self.tmp_dir = _setup( | |||||
| ) | |||||
| cfg_file_path = os.path.join(self.cache_path, 'mmcv_scrfd.py') | |||||
| cfg = Config.from_file(cfg_file_path) | |||||
| cfg.checkpoint_config.interval = 1 | |||||
| cfg.log_config.interval = 10 | |||||
| cfg.evaluation.interval = 1 | |||||
| cfg.data.workers_per_gpu = 3 | |||||
| cfg.data.samples_per_gpu = 4 | |||||
| cfg.dump(cfg_file_path) | |||||
| def tearDown(self): | |||||
| shutil.rmtree(self.tmp_dir) | |||||
| super().tearDown() | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
| def test_multi_gpus_finetune(self): | |||||
| pretrain_epoch = 640 | |||||
| self.max_epochs += pretrain_epoch | |||||
| kwargs = dict( | |||||
| cfg_file=os.path.join(self.cache_path, 'mmcv_scrfd.py'), | |||||
| work_dir=self.tmp_dir, | |||||
| train_root=self.train_root, | |||||
| val_root=self.val_root, | |||||
| total_epochs=self.max_epochs, | |||||
| resume_from=os.path.join(self.cache_path, | |||||
| ModelFile.TORCH_MODEL_BIN_FILE), | |||||
| launcher='pytorch') | |||||
| self.start(train_func, num_gpus=2, **kwargs) | |||||
| results_files = os.listdir(self.tmp_dir) | |||||
| json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json')) | |||||
| self.assertEqual(len(json_files), 1) | |||||
| for i in range(pretrain_epoch, self.max_epochs): | |||||
| self.assertIn(f'epoch_{i+1}.pth', results_files) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||