| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:8e9ab135da7eacabdeeeee11ba4b7bcdd1bfac128cf92a9de9c79f984060ae1e | |||
| size 259865 | |||
| @@ -11,6 +11,7 @@ class Models(object): | |||
| """ | |||
| # vision models | |||
| csrnet = 'csrnet' | |||
| cascade_mask_rcnn_swin = 'cascade_mask_rcnn_swin' | |||
| # nlp models | |||
| bert = 'bert' | |||
| @@ -67,6 +68,7 @@ class Pipelines(object): | |||
| image_super_resolution = 'rrdb-image-super-resolution' | |||
| face_image_generation = 'gan-face-image-generation' | |||
| style_transfer = 'AAMS-style-transfer' | |||
| image_instance_segmentation = 'cascade-mask-rcnn-swin-image-instance-segmentation' | |||
| # nlp tasks | |||
| sentence_similarity = 'sentence-similarity' | |||
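# Hedged usage sketch (not part of this diff): once the pipeline name above is registered,
# it should be reachable through the usual modelscope pipeline factory. The model id below is
# a placeholder, not a real hub id, and the exact output keys depend on the pipeline's spec.
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

segmentor = pipeline(Tasks.image_segmentation, model='<namespace>/<cascade-mask-rcnn-swin-model>')
result = segmentor('demo.jpg')  # expected to contain per-instance boxes, labels and masks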
| @@ -124,6 +126,7 @@ class Preprocessors(object): | |||
| # cv preprocessor | |||
| load_image = 'load-image' | |||
| image_color_enhance_preprocessor = 'image-color-enhance-preprocessor' | |||
| image_instance_segmentation_preprocessor = 'image-instance-segmentation-preprocessor' | |||
| # nlp preprocessor | |||
| sen_sim_tokenizer = 'sen-sim-tokenizer' | |||
| @@ -157,6 +160,8 @@ class Metrics(object): | |||
| # accuracy | |||
| accuracy = 'accuracy' | |||
| # metric for image instance segmentation task | |||
| image_ins_seg_coco_metric = 'image-ins-seg-coco-metric' | |||
| # metrics for sequence classification task | |||
| seq_cls_metric = 'seq_cls_metric' | |||
| # metrics for token-classification task | |||
| @@ -1,5 +1,7 @@ | |||
| from .base import Metric | |||
| from .builder import METRICS, build_metric, task_default_metrics | |||
| from .image_color_enhance_metric import ImageColorEnhanceMetric | |||
| from .image_instance_segmentation_metric import \ | |||
| ImageInstanceSegmentationCOCOMetric | |||
| from .sequence_classification_metric import SequenceClassificationMetric | |||
| from .text_generation_metric import TextGenerationMetric | |||
| @@ -18,6 +18,7 @@ class MetricKeys(object): | |||
| task_default_metrics = { | |||
| Tasks.image_segmentation: [Metrics.image_ins_seg_coco_metric], | |||
| Tasks.sentence_similarity: [Metrics.seq_cls_metric], | |||
| Tasks.sentiment_classification: [Metrics.seq_cls_metric], | |||
| Tasks.text_generation: [Metrics.text_gen_metric], | |||
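# Hedged sketch of how the new default-metric entry is expected to resolve (the exact
# build_metric signature is assumed here; it may take additional registry arguments).
from modelscope.metainfo import Metrics
from modelscope.metrics import build_metric, task_default_metrics
from modelscope.utils.constant import Tasks

metric_name = task_default_metrics[Tasks.image_segmentation][0]
assert metric_name == Metrics.image_ins_seg_coco_metric    # 'image-ins-seg-coco-metric'
metric = build_metric(metric_name)                         # -> ImageInstanceSegmentationCOCOMetric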
| @@ -0,0 +1,312 @@ | |||
| import os.path as osp | |||
| import tempfile | |||
| from collections import OrderedDict | |||
| from typing import Any, Dict | |||
| import numpy as np | |||
| import pycocotools.mask as mask_util | |||
| from pycocotools.coco import COCO | |||
| from pycocotools.cocoeval import COCOeval | |||
| from modelscope.fileio import dump, load | |||
| from modelscope.metainfo import Metrics | |||
| from modelscope.metrics import METRICS, Metric | |||
| from modelscope.utils.registry import default_group | |||
| @METRICS.register_module( | |||
| group_key=default_group, module_name=Metrics.image_ins_seg_coco_metric) | |||
| class ImageInstanceSegmentationCOCOMetric(Metric): | |||
| """The metric computation class for COCO-style image instance segmentation. | |||
| """ | |||
| def __init__(self): | |||
| self.ann_file = None | |||
| self.classes = None | |||
| self.metrics = ['bbox', 'segm'] | |||
| self.proposal_nums = (100, 300, 1000) | |||
| self.iou_thrs = np.linspace( | |||
| .5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True) | |||
| self.results = [] | |||
| def add(self, outputs: Dict[str, Any], inputs: Dict[str, Any]): | |||
| result = outputs['eval_result'] | |||
| # encode mask results | |||
| if isinstance(result[0], tuple): | |||
| result = [(bbox_results, encode_mask_results(mask_results)) | |||
| for bbox_results, mask_results in result] | |||
| self.results.extend(result) | |||
| if self.ann_file is None: | |||
| self.ann_file = outputs['img_metas'][0]['ann_file'] | |||
| self.classes = outputs['img_metas'][0]['classes'] | |||
| def evaluate(self): | |||
| cocoGt = COCO(self.ann_file) | |||
| self.cat_ids = cocoGt.getCatIds(catNms=self.classes) | |||
| self.img_ids = cocoGt.getImgIds() | |||
| result_files, tmp_dir = self.format_results(self.results, self.img_ids) | |||
| eval_results = OrderedDict() | |||
| for metric in self.metrics: | |||
| iou_type = metric | |||
| if metric not in result_files: | |||
| raise KeyError(f'{metric} is not in results') | |||
| try: | |||
| predictions = load(result_files[metric]) | |||
| if iou_type == 'segm': | |||
| # Refer to https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/coco.py#L331 # noqa | |||
| # When evaluating mask AP, if the results contain bbox, | |||
| # cocoapi will use the box area instead of the mask area | |||
| # for calculating the instance area. Though the overall AP | |||
| # is not affected, this leads to different | |||
| # small/medium/large mask AP results. | |||
| for x in predictions: | |||
| x.pop('bbox') | |||
| cocoDt = cocoGt.loadRes(predictions) | |||
| except IndexError: | |||
| print('The testing results of the whole dataset are empty.') | |||
| break | |||
| cocoEval = COCOeval(cocoGt, cocoDt, iou_type) | |||
| cocoEval.params.catIds = self.cat_ids | |||
| cocoEval.params.imgIds = self.img_ids | |||
| cocoEval.params.maxDets = list(self.proposal_nums) | |||
| cocoEval.params.iouThrs = self.iou_thrs | |||
| # mapping of cocoEval.stats | |||
| coco_metric_names = { | |||
| 'mAP': 0, | |||
| 'mAP_50': 1, | |||
| 'mAP_75': 2, | |||
| 'mAP_s': 3, | |||
| 'mAP_m': 4, | |||
| 'mAP_l': 5, | |||
| 'AR@100': 6, | |||
| 'AR@300': 7, | |||
| 'AR@1000': 8, | |||
| 'AR_s@1000': 9, | |||
| 'AR_m@1000': 10, | |||
| 'AR_l@1000': 11 | |||
| } | |||
| cocoEval.evaluate() | |||
| cocoEval.accumulate() | |||
| cocoEval.summarize() | |||
| metric_items = [ | |||
| 'mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l' | |||
| ] | |||
| for metric_item in metric_items: | |||
| key = f'{metric}_{metric_item}' | |||
| val = float( | |||
| f'{cocoEval.stats[coco_metric_names[metric_item]]:.3f}') | |||
| eval_results[key] = val | |||
| ap = cocoEval.stats[:6] | |||
| eval_results[f'{metric}_mAP_copypaste'] = ( | |||
| f'{ap[0]:.3f} {ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} ' | |||
| f'{ap[4]:.3f} {ap[5]:.3f}') | |||
| if tmp_dir is not None: | |||
| tmp_dir.cleanup() | |||
| return eval_results | |||
| def format_results(self, results, img_ids, jsonfile_prefix=None, **kwargs): | |||
| """Format the results to json (standard format for COCO evaluation). | |||
| Args: | |||
| results (list[tuple | numpy.ndarray]): Testing results of the | |||
| dataset. | |||
| img_ids (list[int]): Image ids of the dataset. | |||
| jsonfile_prefix (str | None): The prefix of json files. It includes | |||
| the file path and the prefix of filename, e.g., "a/b/prefix". | |||
| If not specified, a temp file will be created. Default: None. | |||
| Returns: | |||
| tuple: (result_files, tmp_dir), result_files is a dict containing \ | |||
| the json filepaths, tmp_dir is the temporary directory created \ | |||
| for saving json files when jsonfile_prefix is not specified. | |||
| """ | |||
| assert isinstance(results, list), 'results must be a list' | |||
| assert len(results) == len(img_ids), ( | |||
| 'The length of results is not equal to the dataset length: {} != {}'. | |||
| format(len(results), len(img_ids))) | |||
| if jsonfile_prefix is None: | |||
| tmp_dir = tempfile.TemporaryDirectory() | |||
| jsonfile_prefix = osp.join(tmp_dir.name, 'results') | |||
| else: | |||
| tmp_dir = None | |||
| result_files = self.results2json(results, jsonfile_prefix) | |||
| return result_files, tmp_dir | |||
| def xyxy2xywh(self, bbox): | |||
| """Convert ``xyxy`` style bounding boxes to ``xywh`` style for COCO | |||
| evaluation. | |||
| Args: | |||
| bbox (numpy.ndarray): The bounding boxes, shape (4, ), in | |||
| ``xyxy`` order. | |||
| Returns: | |||
| list[float]: The converted bounding boxes, in ``xywh`` order. | |||
| """ | |||
| _bbox = bbox.tolist() | |||
| return [ | |||
| _bbox[0], | |||
| _bbox[1], | |||
| _bbox[2] - _bbox[0], | |||
| _bbox[3] - _bbox[1], | |||
| ] | |||
| def _proposal2json(self, results): | |||
| """Convert proposal results to COCO json style.""" | |||
| json_results = [] | |||
| for idx in range(len(self.img_ids)): | |||
| img_id = self.img_ids[idx] | |||
| bboxes = results[idx] | |||
| for i in range(bboxes.shape[0]): | |||
| data = dict() | |||
| data['image_id'] = img_id | |||
| data['bbox'] = self.xyxy2xywh(bboxes[i]) | |||
| data['score'] = float(bboxes[i][4]) | |||
| data['category_id'] = 1 | |||
| json_results.append(data) | |||
| return json_results | |||
| def _det2json(self, results): | |||
| """Convert detection results to COCO json style.""" | |||
| json_results = [] | |||
| for idx in range(len(self.img_ids)): | |||
| img_id = self.img_ids[idx] | |||
| result = results[idx] | |||
| for label in range(len(result)): | |||
| # Here we skip invalid predicted labels, since a fixed num_classes of 80 (COCO) is used | |||
| # (assuming the input dataset has no more than 80 classes). | |||
| # In practice, it is recommended to manually set `num_classes=${your test dataset class num}` | |||
| # in configuration.json. | |||
| if label >= len(self.classes): | |||
| break | |||
| bboxes = result[label] | |||
| for i in range(bboxes.shape[0]): | |||
| data = dict() | |||
| data['image_id'] = img_id | |||
| data['bbox'] = self.xyxy2xywh(bboxes[i]) | |||
| data['score'] = float(bboxes[i][4]) | |||
| data['category_id'] = self.cat_ids[label] | |||
| json_results.append(data) | |||
| return json_results | |||
| def _segm2json(self, results): | |||
| """Convert instance segmentation results to COCO json style.""" | |||
| bbox_json_results = [] | |||
| segm_json_results = [] | |||
| for idx in range(len(self.img_ids)): | |||
| img_id = self.img_ids[idx] | |||
| det, seg = results[idx] | |||
| for label in range(len(det)): | |||
| # Here we skip invalid predicted labels, since a fixed num_classes of 80 (COCO) is used | |||
| # (assuming the input dataset has no more than 80 classes). | |||
| # In practice, it is recommended to manually set `num_classes=${your test dataset class num}` | |||
| # in configuration.json. | |||
| if label >= len(self.classes): | |||
| break | |||
| # bbox results | |||
| bboxes = det[label] | |||
| for i in range(bboxes.shape[0]): | |||
| data = dict() | |||
| data['image_id'] = img_id | |||
| data['bbox'] = self.xyxy2xywh(bboxes[i]) | |||
| data['score'] = float(bboxes[i][4]) | |||
| data['category_id'] = self.cat_ids[label] | |||
| bbox_json_results.append(data) | |||
| # segm results | |||
| # some detectors use different scores for bbox and mask | |||
| if isinstance(seg, tuple): | |||
| segms = seg[0][label] | |||
| mask_score = seg[1][label] | |||
| else: | |||
| segms = seg[label] | |||
| mask_score = [bbox[4] for bbox in bboxes] | |||
| for i in range(bboxes.shape[0]): | |||
| data = dict() | |||
| data['image_id'] = img_id | |||
| data['bbox'] = self.xyxy2xywh(bboxes[i]) | |||
| data['score'] = float(mask_score[i]) | |||
| data['category_id'] = self.cat_ids[label] | |||
| if isinstance(segms[i]['counts'], bytes): | |||
| segms[i]['counts'] = segms[i]['counts'].decode() | |||
| data['segmentation'] = segms[i] | |||
| segm_json_results.append(data) | |||
| return bbox_json_results, segm_json_results | |||
| def results2json(self, results, outfile_prefix): | |||
| """Dump the detection results to a COCO style json file. | |||
| There are 3 types of results: proposals, bbox predictions, mask | |||
| predictions, and they have different data types. This method will | |||
| automatically recognize the type, and dump them to json files. | |||
| Args: | |||
| results (list[list | tuple | ndarray]): Testing results of the | |||
| dataset. | |||
| outfile_prefix (str): The filename prefix of the json files. If the | |||
| prefix is "somepath/xxx", the json files will be named | |||
| "somepath/xxx.bbox.json", "somepath/xxx.segm.json", | |||
| "somepath/xxx.proposal.json". | |||
| Returns: | |||
| dict[str: str]: Possible keys are "bbox", "segm", "proposal", and \ | |||
| values are corresponding filenames. | |||
| """ | |||
| result_files = dict() | |||
| if isinstance(results[0], list): | |||
| json_results = self._det2json(results) | |||
| result_files['bbox'] = f'{outfile_prefix}.bbox.json' | |||
| result_files['proposal'] = f'{outfile_prefix}.bbox.json' | |||
| dump(json_results, result_files['bbox']) | |||
| elif isinstance(results[0], tuple): | |||
| json_results = self._segm2json(results) | |||
| result_files['bbox'] = f'{outfile_prefix}.bbox.json' | |||
| result_files['proposal'] = f'{outfile_prefix}.bbox.json' | |||
| result_files['segm'] = f'{outfile_prefix}.segm.json' | |||
| dump(json_results[0], result_files['bbox']) | |||
| dump(json_results[1], result_files['segm']) | |||
| elif isinstance(results[0], np.ndarray): | |||
| json_results = self._proposal2json(results) | |||
| result_files['proposal'] = f'{outfile_prefix}.proposal.json' | |||
| dump(json_results, result_files['proposal']) | |||
| else: | |||
| raise TypeError('invalid type of results') | |||
| return result_files | |||
| def encode_mask_results(mask_results): | |||
| """Encode bitmap mask to RLE code. | |||
| Args: | |||
| mask_results (list | tuple[list]): bitmap mask results. | |||
| In Mask Scoring R-CNN, mask_results is a tuple of (segm_results, | |||
| segm_cls_score). | |||
| Returns: | |||
| list | tuple: RLE encoded mask. | |||
| """ | |||
| if isinstance(mask_results, tuple): # mask scoring | |||
| cls_segms, cls_mask_scores = mask_results | |||
| else: | |||
| cls_segms = mask_results | |||
| num_classes = len(cls_segms) | |||
| encoded_mask_results = [[] for _ in range(num_classes)] | |||
| for i in range(len(cls_segms)): | |||
| for cls_segm in cls_segms[i]: | |||
| encoded_mask_results[i].append( | |||
| mask_util.encode( | |||
| np.array( | |||
| cls_segm[:, :, np.newaxis], order='F', | |||
| dtype='uint8'))[0]) # encoded with RLE | |||
| if isinstance(mask_results, tuple): | |||
| return encoded_mask_results, cls_mask_scores | |||
| else: | |||
| return encoded_mask_results | |||
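# Hedged round-trip check for encode_mask_results, using the numpy / pycocotools imports at the
# top of this file: a single 4x4 bitmap for one class is RLE-encoded and decodes back unchanged.
fake_mask = np.zeros((4, 4), dtype=np.uint8)
fake_mask[1:3, 1:3] = 1
encoded = encode_mask_results([[fake_mask]])           # one class, one instance
rle = encoded[0][0]                                    # RLE dict with 'size' and 'counts'
assert mask_util.area(rle) == 4                        # the 2x2 foreground region
assert (mask_util.decode(rle) == fake_mask).all()      # decodes back to the original bitmap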
| @@ -0,0 +1,2 @@ | |||
| from .cascade_mask_rcnn_swin import CascadeMaskRCNNSwin | |||
| from .model import CascadeMaskRCNNSwinModel | |||
| @@ -0,0 +1 @@ | |||
| from .swin_transformer import SwinTransformer | |||
| @@ -0,0 +1,694 @@ | |||
| # Modified from: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py | |||
| import numpy as np | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| import torch.utils.checkpoint as checkpoint | |||
| from timm.models.layers import DropPath, to_2tuple, trunc_normal_ | |||
| class Mlp(nn.Module): | |||
| """ Multilayer perceptron.""" | |||
| def __init__(self, | |||
| in_features, | |||
| hidden_features=None, | |||
| out_features=None, | |||
| act_layer=nn.GELU, | |||
| drop=0.): | |||
| super().__init__() | |||
| out_features = out_features or in_features | |||
| hidden_features = hidden_features or in_features | |||
| self.fc1 = nn.Linear(in_features, hidden_features) | |||
| self.act = act_layer() | |||
| self.fc2 = nn.Linear(hidden_features, out_features) | |||
| self.drop = nn.Dropout(drop) | |||
| def forward(self, x): | |||
| x = self.fc1(x) | |||
| x = self.act(x) | |||
| x = self.drop(x) | |||
| x = self.fc2(x) | |||
| x = self.drop(x) | |||
| return x | |||
| def window_partition(x, window_size): | |||
| """ | |||
| Args: | |||
| x: (B, H, W, C) | |||
| window_size (int): window size | |||
| Returns: | |||
| windows: (num_windows*B, window_size, window_size, C) | |||
| """ | |||
| B, H, W, C = x.shape | |||
| x = x.view(B, H // window_size, window_size, W // window_size, window_size, | |||
| C) | |||
| windows = x.permute(0, 1, 3, 2, 4, | |||
| 5).contiguous().view(-1, window_size, window_size, C) | |||
| return windows | |||
| def window_reverse(windows, window_size, H, W): | |||
| """ | |||
| Args: | |||
| windows: (num_windows*B, window_size, window_size, C) | |||
| window_size (int): Window size | |||
| H (int): Height of image | |||
| W (int): Width of image | |||
| Returns: | |||
| x: (B, H, W, C) | |||
| """ | |||
| B = int(windows.shape[0] / (H * W / window_size / window_size)) | |||
| x = windows.view(B, H // window_size, W // window_size, window_size, | |||
| window_size, -1) | |||
| x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) | |||
| return x | |||
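# Hedged shape check for the two helpers above: partitioning a (B, H, W, C) feature map into
# non-overlapping windows and reversing it should round-trip exactly when H and W are multiples
# of window_size (torch is imported at the top of this file).
feat = torch.randn(2, 14, 14, 96)
wins = window_partition(feat, 7)                        # (2 * 4, 7, 7, 96)
assert wins.shape == (8, 7, 7, 96)
assert torch.equal(window_reverse(wins, 7, 14, 14), feat)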
| class WindowAttention(nn.Module): | |||
| """ Window based multi-head self attention (W-MSA) module with relative position bias. | |||
| It supports both shifted and non-shifted windows. | |||
| Args: | |||
| dim (int): Number of input channels. | |||
| window_size (tuple[int]): The height and width of the window. | |||
| num_heads (int): Number of attention heads. | |||
| qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True | |||
| qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set | |||
| attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 | |||
| proj_drop (float, optional): Dropout ratio of output. Default: 0.0 | |||
| """ | |||
| def __init__(self, | |||
| dim, | |||
| window_size, | |||
| num_heads, | |||
| qkv_bias=True, | |||
| qk_scale=None, | |||
| attn_drop=0., | |||
| proj_drop=0.): | |||
| super().__init__() | |||
| self.dim = dim | |||
| self.window_size = window_size # Wh, Ww | |||
| self.num_heads = num_heads | |||
| head_dim = dim // num_heads | |||
| self.scale = qk_scale or head_dim**-0.5 | |||
| # define a parameter table of relative position bias | |||
| self.relative_position_bias_table = nn.Parameter( | |||
| torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), | |||
| num_heads)) # 2*Wh-1 * 2*Ww-1, nH | |||
| # get pair-wise relative position index for each token inside the window | |||
| coords_h = torch.arange(self.window_size[0]) | |||
| coords_w = torch.arange(self.window_size[1]) | |||
| coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww | |||
| coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww | |||
| relative_coords = coords_flatten[:, :, | |||
| None] - coords_flatten[:, | |||
| None, :] # 2, Wh*Ww, Wh*Ww | |||
| relative_coords = relative_coords.permute( | |||
| 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 | |||
| relative_coords[:, :, | |||
| 0] += self.window_size[0] - 1 # shift to start from 0 | |||
| relative_coords[:, :, 1] += self.window_size[1] - 1 | |||
| relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 | |||
| relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww | |||
| self.register_buffer('relative_position_index', | |||
| relative_position_index) | |||
| self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) | |||
| self.attn_drop = nn.Dropout(attn_drop) | |||
| self.proj = nn.Linear(dim, dim) | |||
| self.proj_drop = nn.Dropout(proj_drop) | |||
| trunc_normal_(self.relative_position_bias_table, std=.02) | |||
| self.softmax = nn.Softmax(dim=-1) | |||
| def forward(self, x, mask=None): | |||
| """ Forward function. | |||
| Args: | |||
| x: input features with shape of (num_windows*B, N, C) | |||
| mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None | |||
| """ | |||
| B_, N, C = x.shape | |||
| qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, | |||
| C // self.num_heads).permute(2, 0, 3, 1, 4) | |||
| q, k, v = qkv[0], qkv[1], qkv[ | |||
| 2] # make torchscript happy (cannot use tensor as tuple) | |||
| q = q * self.scale | |||
| attn = (q @ k.transpose(-2, -1)) | |||
| relative_position_bias = self.relative_position_bias_table[ | |||
| self.relative_position_index.view(-1)].view( | |||
| self.window_size[0] * self.window_size[1], | |||
| self.window_size[0] * self.window_size[1], | |||
| -1) # Wh*Ww,Wh*Ww,nH | |||
| relative_position_bias = relative_position_bias.permute( | |||
| 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww | |||
| attn = attn + relative_position_bias.unsqueeze(0) | |||
| if mask is not None: | |||
| nW = mask.shape[0] | |||
| attn = attn.view(B_ // nW, nW, self.num_heads, N, | |||
| N) + mask.unsqueeze(1).unsqueeze(0) | |||
| attn = attn.view(-1, self.num_heads, N, N) | |||
| attn = self.softmax(attn) | |||
| else: | |||
| attn = self.softmax(attn) | |||
| attn = self.attn_drop(attn) | |||
| x = (attn @ v).transpose(1, 2).reshape(B_, N, C) | |||
| x = self.proj(x) | |||
| x = self.proj_drop(x) | |||
| return x | |||
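# Hedged smoke test for WindowAttention: with dim=96, a 7x7 window and 3 heads, a batch of
# window tokens of shape (num_windows*B, 49, 96) maps back to the same shape, and the relative
# position bias table holds (2*7-1)**2 = 169 entries per head.
win_attn = WindowAttention(dim=96, window_size=(7, 7), num_heads=3)
tokens = torch.randn(8, 49, 96)
assert win_attn(tokens).shape == (8, 49, 96)
assert win_attn.relative_position_bias_table.shape == (169, 3)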
| class SwinTransformerBlock(nn.Module): | |||
| """ Swin Transformer Block. | |||
| Args: | |||
| dim (int): Number of input channels. | |||
| num_heads (int): Number of attention heads. | |||
| window_size (int): Window size. | |||
| shift_size (int): Shift size for SW-MSA. | |||
| mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. | |||
| qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True | |||
| qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. | |||
| drop (float, optional): Dropout rate. Default: 0.0 | |||
| attn_drop (float, optional): Attention dropout rate. Default: 0.0 | |||
| drop_path (float, optional): Stochastic depth rate. Default: 0.0 | |||
| act_layer (nn.Module, optional): Activation layer. Default: nn.GELU | |||
| norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm | |||
| """ | |||
| def __init__(self, | |||
| dim, | |||
| num_heads, | |||
| window_size=7, | |||
| shift_size=0, | |||
| mlp_ratio=4., | |||
| qkv_bias=True, | |||
| qk_scale=None, | |||
| drop=0., | |||
| attn_drop=0., | |||
| drop_path=0., | |||
| act_layer=nn.GELU, | |||
| norm_layer=nn.LayerNorm): | |||
| super().__init__() | |||
| self.dim = dim | |||
| self.num_heads = num_heads | |||
| self.window_size = window_size | |||
| self.shift_size = shift_size | |||
| self.mlp_ratio = mlp_ratio | |||
| assert 0 <= self.shift_size < self.window_size, 'shift_size must be in the range [0, window_size)' | |||
| self.norm1 = norm_layer(dim) | |||
| self.attn = WindowAttention( | |||
| dim, | |||
| window_size=to_2tuple(self.window_size), | |||
| num_heads=num_heads, | |||
| qkv_bias=qkv_bias, | |||
| qk_scale=qk_scale, | |||
| attn_drop=attn_drop, | |||
| proj_drop=drop) | |||
| self.drop_path = DropPath( | |||
| drop_path) if drop_path > 0. else nn.Identity() | |||
| self.norm2 = norm_layer(dim) | |||
| mlp_hidden_dim = int(dim * mlp_ratio) | |||
| self.mlp = Mlp( | |||
| in_features=dim, | |||
| hidden_features=mlp_hidden_dim, | |||
| act_layer=act_layer, | |||
| drop=drop) | |||
| self.H = None | |||
| self.W = None | |||
| def forward(self, x, mask_matrix): | |||
| """ Forward function. | |||
| Args: | |||
| x: Input feature, tensor size (B, H*W, C). | |||
| H, W: Spatial resolution of the input feature. | |||
| mask_matrix: Attention mask for cyclic shift. | |||
| """ | |||
| B, L, C = x.shape | |||
| H, W = self.H, self.W | |||
| assert L == H * W, 'input feature has wrong size' | |||
| shortcut = x | |||
| x = self.norm1(x) | |||
| x = x.view(B, H, W, C) | |||
| # pad feature maps to multiples of window size | |||
| pad_l = pad_t = 0 | |||
| pad_r = (self.window_size - W % self.window_size) % self.window_size | |||
| pad_b = (self.window_size - H % self.window_size) % self.window_size | |||
| x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) | |||
| _, Hp, Wp, _ = x.shape | |||
| # cyclic shift | |||
| if self.shift_size > 0: | |||
| shifted_x = torch.roll( | |||
| x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) | |||
| attn_mask = mask_matrix | |||
| else: | |||
| shifted_x = x | |||
| attn_mask = None | |||
| # partition windows | |||
| x_windows = window_partition( | |||
| shifted_x, self.window_size) # nW*B, window_size, window_size, C | |||
| x_windows = x_windows.view(-1, self.window_size * self.window_size, | |||
| C) # nW*B, window_size*window_size, C | |||
| # W-MSA/SW-MSA | |||
| attn_windows = self.attn( | |||
| x_windows, mask=attn_mask) # nW*B, window_size*window_size, C | |||
| # merge windows | |||
| attn_windows = attn_windows.view(-1, self.window_size, | |||
| self.window_size, C) | |||
| shifted_x = window_reverse(attn_windows, self.window_size, Hp, | |||
| Wp) # B H' W' C | |||
| # reverse cyclic shift | |||
| if self.shift_size > 0: | |||
| x = torch.roll( | |||
| shifted_x, | |||
| shifts=(self.shift_size, self.shift_size), | |||
| dims=(1, 2)) | |||
| else: | |||
| x = shifted_x | |||
| if pad_r > 0 or pad_b > 0: | |||
| x = x[:, :H, :W, :].contiguous() | |||
| x = x.view(B, H * W, C) | |||
| # FFN | |||
| x = shortcut + self.drop_path(x) | |||
| x = x + self.drop_path(self.mlp(self.norm2(x))) | |||
| return x | |||
| class PatchMerging(nn.Module): | |||
| """ Patch Merging Layer | |||
| Args: | |||
| dim (int): Number of input channels. | |||
| norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm | |||
| """ | |||
| def __init__(self, dim, norm_layer=nn.LayerNorm): | |||
| super().__init__() | |||
| self.dim = dim | |||
| self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) | |||
| self.norm = norm_layer(4 * dim) | |||
| def forward(self, x, H, W): | |||
| """ Forward function. | |||
| Args: | |||
| x: Input feature, tensor size (B, H*W, C). | |||
| H, W: Spatial resolution of the input feature. | |||
| """ | |||
| B, L, C = x.shape | |||
| assert L == H * W, 'input feature has wrong size' | |||
| x = x.view(B, H, W, C) | |||
| # padding | |||
| pad_input = (H % 2 == 1) or (W % 2 == 1) | |||
| if pad_input: | |||
| x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) | |||
| x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C | |||
| x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C | |||
| x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C | |||
| x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C | |||
| x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C | |||
| x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C | |||
| x = self.norm(x) | |||
| x = self.reduction(x) | |||
| return x | |||
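# Hedged shape check for PatchMerging: a (B, H*W, C) input comes out as (B, H/2 * W/2, 2C),
# i.e. the spatial resolution halves while the channel count doubles.
merge = PatchMerging(dim=96)
merged = merge(torch.randn(2, 14 * 14, 96), 14, 14)
assert merged.shape == (2, 7 * 7, 192)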
| class BasicLayer(nn.Module): | |||
| """ A basic Swin Transformer layer for one stage. | |||
| Args: | |||
| dim (int): Number of feature channels | |||
| depth (int): Depth of this stage (number of blocks). | |||
| num_heads (int): Number of attention heads. | |||
| window_size (int): Local window size. Default: 7. | |||
| mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. | |||
| qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True | |||
| qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. | |||
| drop (float, optional): Dropout rate. Default: 0.0 | |||
| attn_drop (float, optional): Attention dropout rate. Default: 0.0 | |||
| drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 | |||
| norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm | |||
| downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None | |||
| use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. | |||
| """ | |||
| def __init__(self, | |||
| dim, | |||
| depth, | |||
| num_heads, | |||
| window_size=7, | |||
| mlp_ratio=4., | |||
| qkv_bias=True, | |||
| qk_scale=None, | |||
| drop=0., | |||
| attn_drop=0., | |||
| drop_path=0., | |||
| norm_layer=nn.LayerNorm, | |||
| downsample=None, | |||
| use_checkpoint=False): | |||
| super().__init__() | |||
| self.window_size = window_size | |||
| self.shift_size = window_size // 2 | |||
| self.depth = depth | |||
| self.use_checkpoint = use_checkpoint | |||
| # build blocks | |||
| self.blocks = nn.ModuleList([ | |||
| SwinTransformerBlock( | |||
| dim=dim, | |||
| num_heads=num_heads, | |||
| window_size=window_size, | |||
| shift_size=0 if (i % 2 == 0) else window_size // 2, | |||
| mlp_ratio=mlp_ratio, | |||
| qkv_bias=qkv_bias, | |||
| qk_scale=qk_scale, | |||
| drop=drop, | |||
| attn_drop=attn_drop, | |||
| drop_path=drop_path[i] | |||
| if isinstance(drop_path, list) else drop_path, | |||
| norm_layer=norm_layer) for i in range(depth) | |||
| ]) | |||
| # patch merging layer | |||
| if downsample is not None: | |||
| self.downsample = downsample(dim=dim, norm_layer=norm_layer) | |||
| else: | |||
| self.downsample = None | |||
| def forward(self, x, H, W): | |||
| """ Forward function. | |||
| Args: | |||
| x: Input feature, tensor size (B, H*W, C). | |||
| H, W: Spatial resolution of the input feature. | |||
| """ | |||
| # calculate attention mask for SW-MSA | |||
| Hp = int(np.ceil(H / self.window_size)) * self.window_size | |||
| Wp = int(np.ceil(W / self.window_size)) * self.window_size | |||
| img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 | |||
| h_slices = (slice(0, -self.window_size), | |||
| slice(-self.window_size, | |||
| -self.shift_size), slice(-self.shift_size, None)) | |||
| w_slices = (slice(0, -self.window_size), | |||
| slice(-self.window_size, | |||
| -self.shift_size), slice(-self.shift_size, None)) | |||
| cnt = 0 | |||
| for h in h_slices: | |||
| for w in w_slices: | |||
| img_mask[:, h, w, :] = cnt | |||
| cnt += 1 | |||
| mask_windows = window_partition( | |||
| img_mask, self.window_size) # nW, window_size, window_size, 1 | |||
| mask_windows = mask_windows.view(-1, | |||
| self.window_size * self.window_size) | |||
| attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) | |||
| attn_mask = attn_mask.masked_fill(attn_mask != 0, | |||
| float(-100.0)).masked_fill( | |||
| attn_mask == 0, float(0.0)) | |||
| for blk in self.blocks: | |||
| blk.H, blk.W = H, W | |||
| if self.use_checkpoint: | |||
| x = checkpoint.checkpoint(blk, x, attn_mask) | |||
| else: | |||
| x = blk(x, attn_mask) | |||
| if self.downsample is not None: | |||
| x_down = self.downsample(x, H, W) | |||
| Wh, Ww = (H + 1) // 2, (W + 1) // 2 | |||
| return x, H, W, x_down, Wh, Ww | |||
| else: | |||
| return x, H, W, x, H, W | |||
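# Hedged end-to-end check for one stage: two blocks (W-MSA then SW-MSA) at 56x56 resolution
# with dim=96, followed by PatchMerging, should return the pre-merge features together with
# the downsampled 28x28, 192-channel features for the next stage.
stage = BasicLayer(dim=96, depth=2, num_heads=3, downsample=PatchMerging)
x_out, H, W, x_down, Wh, Ww = stage(torch.randn(1, 56 * 56, 96), 56, 56)
assert x_out.shape == (1, 56 * 56, 96)
assert x_down.shape == (1, 28 * 28, 192) and (Wh, Ww) == (28, 28)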
| class PatchEmbed(nn.Module): | |||
| """ Image to Patch Embedding | |||
| Args: | |||
| patch_size (int): Patch token size. Default: 4. | |||
| in_chans (int): Number of input image channels. Default: 3. | |||
| embed_dim (int): Number of linear projection output channels. Default: 96. | |||
| norm_layer (nn.Module, optional): Normalization layer. Default: None | |||
| """ | |||
| def __init__(self, | |||
| patch_size=4, | |||
| in_chans=3, | |||
| embed_dim=96, | |||
| norm_layer=None): | |||
| super().__init__() | |||
| patch_size = to_2tuple(patch_size) | |||
| self.patch_size = patch_size | |||
| self.in_chans = in_chans | |||
| self.embed_dim = embed_dim | |||
| self.proj = nn.Conv2d( | |||
| in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) | |||
| if norm_layer is not None: | |||
| self.norm = norm_layer(embed_dim) | |||
| else: | |||
| self.norm = None | |||
| def forward(self, x): | |||
| """Forward function.""" | |||
| # padding | |||
| _, _, H, W = x.size() | |||
| if W % self.patch_size[1] != 0: | |||
| x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) | |||
| if H % self.patch_size[0] != 0: | |||
| x = F.pad(x, | |||
| (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) | |||
| x = self.proj(x) # B C Wh Ww | |||
| if self.norm is not None: | |||
| Wh, Ww = x.size(2), x.size(3) | |||
| x = x.flatten(2).transpose(1, 2) | |||
| x = self.norm(x) | |||
| x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) | |||
| return x | |||
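# Hedged check for PatchEmbed: a 4x4 conv with stride 4 turns a (1, 3, 224, 224) image into a
# (1, 96, 56, 56) token map; inputs whose sides are not multiples of the patch size are padded
# on the right/bottom first.
embed = PatchEmbed(patch_size=4, in_chans=3, embed_dim=96)
assert embed(torch.randn(1, 3, 224, 224)).shape == (1, 96, 56, 56)
assert embed(torch.randn(1, 3, 225, 225)).shape == (1, 96, 57, 57)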
| class SwinTransformer(nn.Module): | |||
| """ Swin Transformer backbone. | |||
| A PyTorch implementation of `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - | |||
| https://arxiv.org/pdf/2103.14030 | |||
| Adapted from | |||
| https://github.com/SwinTransformer/Swin-Transformer-Object-Detection | |||
| Args: | |||
| pretrain_img_size (int): Input image size for training the pretrained model, | |||
| used in absolute position embedding. Default: 224. | |||
| patch_size (int | tuple(int)): Patch size. Default: 4. | |||
| in_chans (int): Number of input image channels. Default: 3. | |||
| embed_dim (int): Number of linear projection output channels. Default: 96. | |||
| depths (tuple[int]): Depths of each Swin Transformer stage. | |||
| num_heads (tuple[int]): Number of attention heads in each stage. | |||
| window_size (int): Window size. Default: 7. | |||
| mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. | |||
| qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True | |||
| qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. | |||
| drop_rate (float): Dropout rate. | |||
| attn_drop_rate (float): Attention dropout rate. Default: 0. | |||
| drop_path_rate (float): Stochastic depth rate. Default: 0.2. | |||
| norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. | |||
| ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. | |||
| patch_norm (bool): If True, add normalization after patch embedding. Default: True. | |||
| out_indices (Sequence[int]): Output from which stages. | |||
| frozen_stages (int): Stages to be frozen (stop grad and set eval mode). | |||
| -1 means not freezing any parameters. | |||
| use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. | |||
| """ | |||
| def __init__(self, | |||
| pretrain_img_size=224, | |||
| patch_size=4, | |||
| in_chans=3, | |||
| embed_dim=96, | |||
| depths=[2, 2, 6, 2], | |||
| num_heads=[3, 6, 12, 24], | |||
| window_size=7, | |||
| mlp_ratio=4., | |||
| qkv_bias=True, | |||
| qk_scale=None, | |||
| drop_rate=0., | |||
| attn_drop_rate=0., | |||
| drop_path_rate=0.2, | |||
| norm_layer=nn.LayerNorm, | |||
| ape=False, | |||
| patch_norm=True, | |||
| out_indices=(0, 1, 2, 3), | |||
| frozen_stages=-1, | |||
| use_checkpoint=False): | |||
| super().__init__() | |||
| self.pretrain_img_size = pretrain_img_size | |||
| self.num_layers = len(depths) | |||
| self.embed_dim = embed_dim | |||
| self.ape = ape | |||
| self.patch_norm = patch_norm | |||
| self.out_indices = out_indices | |||
| self.frozen_stages = frozen_stages | |||
| # split image into non-overlapping patches | |||
| self.patch_embed = PatchEmbed( | |||
| patch_size=patch_size, | |||
| in_chans=in_chans, | |||
| embed_dim=embed_dim, | |||
| norm_layer=norm_layer if self.patch_norm else None) | |||
| # absolute position embedding | |||
| if self.ape: | |||
| pretrain_img_size = to_2tuple(pretrain_img_size) | |||
| patch_size = to_2tuple(patch_size) | |||
| patches_resolution = [ | |||
| pretrain_img_size[0] // patch_size[0], | |||
| pretrain_img_size[1] // patch_size[1] | |||
| ] | |||
| self.absolute_pos_embed = nn.Parameter( | |||
| torch.zeros(1, embed_dim, patches_resolution[0], | |||
| patches_resolution[1])) | |||
| trunc_normal_(self.absolute_pos_embed, std=.02) | |||
| self.pos_drop = nn.Dropout(p=drop_rate) | |||
| # stochastic depth | |||
| dpr = [ | |||
| x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) | |||
| ] # stochastic depth decay rule | |||
| # build layers | |||
| self.layers = nn.ModuleList() | |||
| for i_layer in range(self.num_layers): | |||
| layer = BasicLayer( | |||
| dim=int(embed_dim * 2**i_layer), | |||
| depth=depths[i_layer], | |||
| num_heads=num_heads[i_layer], | |||
| window_size=window_size, | |||
| mlp_ratio=mlp_ratio, | |||
| qkv_bias=qkv_bias, | |||
| qk_scale=qk_scale, | |||
| drop=drop_rate, | |||
| attn_drop=attn_drop_rate, | |||
| drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], | |||
| norm_layer=norm_layer, | |||
| downsample=PatchMerging if | |||
| (i_layer < self.num_layers - 1) else None, | |||
| use_checkpoint=use_checkpoint) | |||
| self.layers.append(layer) | |||
| num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)] | |||
| self.num_features = num_features | |||
| # add a norm layer for each output | |||
| for i_layer in out_indices: | |||
| layer = norm_layer(num_features[i_layer]) | |||
| layer_name = f'norm{i_layer}' | |||
| self.add_module(layer_name, layer) | |||
| self._freeze_stages() | |||
| def _freeze_stages(self): | |||
| if self.frozen_stages >= 0: | |||
| self.patch_embed.eval() | |||
| for param in self.patch_embed.parameters(): | |||
| param.requires_grad = False | |||
| if self.frozen_stages >= 1 and self.ape: | |||
| self.absolute_pos_embed.requires_grad = False | |||
| if self.frozen_stages >= 2: | |||
| self.pos_drop.eval() | |||
| for i in range(0, self.frozen_stages - 1): | |||
| m = self.layers[i] | |||
| m.eval() | |||
| for param in m.parameters(): | |||
| param.requires_grad = False | |||
| def init_weights(self): | |||
| """Initialize the weights in backbone.""" | |||
| def _init_weights(m): | |||
| if isinstance(m, nn.Linear): | |||
| trunc_normal_(m.weight, std=.02) | |||
| if isinstance(m, nn.Linear) and m.bias is not None: | |||
| nn.init.constant_(m.bias, 0) | |||
| elif isinstance(m, nn.LayerNorm): | |||
| nn.init.constant_(m.bias, 0) | |||
| nn.init.constant_(m.weight, 1.0) | |||
| self.apply(_init_weights) | |||
| def forward(self, x): | |||
| """Forward function.""" | |||
| x = self.patch_embed(x) | |||
| Wh, Ww = x.size(2), x.size(3) | |||
| if self.ape: | |||
| # interpolate the position embedding to the corresponding size | |||
| absolute_pos_embed = F.interpolate( | |||
| self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic') | |||
| x = (x + absolute_pos_embed).flatten(2).transpose(1, | |||
| 2) # B Wh*Ww C | |||
| else: | |||
| x = x.flatten(2).transpose(1, 2) | |||
| x = self.pos_drop(x) | |||
| outs = [] | |||
| for i in range(self.num_layers): | |||
| layer = self.layers[i] | |||
| x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) | |||
| if i in self.out_indices: | |||
| norm_layer = getattr(self, f'norm{i}') | |||
| x_out = norm_layer(x_out) | |||
| out = x_out.view(-1, H, W, | |||
| self.num_features[i]).permute(0, 3, 1, | |||
| 2).contiguous() | |||
| outs.append(out) | |||
| return tuple(outs) | |||
| def train(self, mode=True): | |||
| """Convert the model into training mode while keep layers freezed.""" | |||
| super(SwinTransformer, self).train(mode) | |||
| self._freeze_stages() | |||
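# Hedged smoke test for the backbone with the Swin-T defaults above: four feature maps at
# strides 4/8/16/32 with 96/192/384/768 channels, as expected by an FPN-style neck.
if __name__ == '__main__':
    backbone = SwinTransformer()
    feats = backbone(torch.randn(1, 3, 224, 224))
    assert [tuple(f.shape) for f in feats] == [(1, 96, 56, 56), (1, 192, 28, 28),
                                               (1, 384, 14, 14), (1, 768, 7, 7)]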
| @@ -0,0 +1,266 @@ | |||
| import os | |||
| from collections import OrderedDict | |||
| import torch | |||
| import torch.distributed as dist | |||
| import torch.nn as nn | |||
| from modelscope.models.cv.image_instance_segmentation.backbones import \ | |||
| SwinTransformer | |||
| from modelscope.utils.constant import ModelFile | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| def build_backbone(cfg): | |||
| assert isinstance(cfg, dict) | |||
| cfg = cfg.copy() | |||
| type = cfg.pop('type') | |||
| if type == 'SwinTransformer': | |||
| return SwinTransformer(**cfg) | |||
| else: | |||
| raise ValueError(f'backbone \'{type}\' is not supported.') | |||
| def build_neck(cfg): | |||
| assert isinstance(cfg, dict) | |||
| cfg = cfg.copy() | |||
| type = cfg.pop('type') | |||
| if type == 'FPN': | |||
| from mmdet.models import FPN | |||
| return FPN(**cfg) | |||
| else: | |||
| raise ValueError(f'neck \'{type}\' is not supported.') | |||
| def build_rpn_head(cfg): | |||
| assert isinstance(cfg, dict) | |||
| cfg = cfg.copy() | |||
| type = cfg.pop('type') | |||
| if type == 'RPNHead': | |||
| from mmdet.models import RPNHead | |||
| return RPNHead(**cfg) | |||
| else: | |||
| raise ValueError(f'rpn head \'{type}\' is not supported.') | |||
| def build_roi_head(cfg): | |||
| assert isinstance(cfg, dict) | |||
| cfg = cfg.copy() | |||
| type = cfg.pop('type') | |||
| if type == 'CascadeRoIHead': | |||
| from mmdet.models import CascadeRoIHead | |||
| return CascadeRoIHead(**cfg) | |||
| else: | |||
| raise ValueError(f'roi head \'{type}\' is not supported.') | |||
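# Hedged sketch of how the builder helpers above are meant to be fed: plain config dicts whose
# 'type' key selects the module and whose remaining keys become constructor kwargs. The Swin-T
# values mirror the backbone defaults; the FPN channels are illustrative, not taken from the
# shipped configuration.json.
swin = build_backbone(
    dict(type='SwinTransformer', embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24]))
fpn = build_neck(
    dict(type='FPN', in_channels=[96, 192, 384, 768], out_channels=256, num_outs=5))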
| class CascadeMaskRCNNSwin(nn.Module): | |||
| def __init__(self, | |||
| backbone, | |||
| neck, | |||
| rpn_head, | |||
| roi_head, | |||
| pretrained=None, | |||
| **kwargs): | |||
| """ | |||
| Args: | |||
| backbone (dict): backbone config. | |||
| neck (dict): neck config. | |||
| rpn_head (dict): rpn_head config. | |||
| roi_head (dict): roi_head config. | |||
| pretrained (bool): whether to load pretrained weights from model_dir. | |||
| """ | |||
| super(CascadeMaskRCNNSwin, self).__init__() | |||
| self.backbone = build_backbone(backbone) | |||
| self.neck = build_neck(neck) | |||
| self.rpn_head = build_rpn_head(rpn_head) | |||
| self.roi_head = build_roi_head(roi_head) | |||
| self.classes = kwargs.pop('classes', None) | |||
| if pretrained: | |||
| assert 'model_dir' in kwargs, 'pretrained model dir is missing.' | |||
| model_path = os.path.join(kwargs['model_dir'], | |||
| ModelFile.TORCH_MODEL_FILE) | |||
| logger.info(f'loading model from {model_path}') | |||
| weight = torch.load(model_path)['state_dict'] | |||
| tgt_weight = self.state_dict() | |||
| for name in list(weight.keys()): | |||
| if name in tgt_weight: | |||
| load_size = weight[name].size() | |||
| tgt_size = tgt_weight[name].size() | |||
| mis_match = False | |||
| if len(load_size) != len(tgt_size): | |||
| mis_match = True | |||
| else: | |||
| for n1, n2 in zip(load_size, tgt_size): | |||
| if n1 != n2: | |||
| mis_match = True | |||
| break | |||
| if mis_match: | |||
| logger.info(f'size mismatch for {name}, skip loading.') | |||
| del weight[name] | |||
| self.load_state_dict(weight, strict=False) | |||
| logger.info('load model done') | |||
| from mmcv.parallel import DataContainer, scatter | |||
| self.data_container = DataContainer | |||
| self.scatter = scatter | |||
| def extract_feat(self, img): | |||
| x = self.backbone(img) | |||
| x = self.neck(x) | |||
| return x | |||
| def forward_train(self, | |||
| img, | |||
| img_metas, | |||
| gt_bboxes, | |||
| gt_labels, | |||
| gt_bboxes_ignore=None, | |||
| gt_masks=None, | |||
| proposals=None, | |||
| **kwargs): | |||
| """ | |||
| Args: | |||
| img (Tensor): of shape (N, C, H, W) encoding input images. | |||
| Typically these should be mean centered and std scaled. | |||
| img_metas (list[dict]): list of image info dict where each dict | |||
| has: 'img_shape', 'scale_factor', 'flip', and may also contain | |||
| 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. | |||
| For details on the values of these keys see | |||
| `mmdet/datasets/pipelines/formatting.py:Collect`. | |||
| gt_bboxes (list[Tensor]): Ground truth bboxes for each image with | |||
| shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. | |||
| gt_labels (list[Tensor]): class indices corresponding to each box | |||
| gt_bboxes_ignore (None | list[Tensor]): specify which bounding | |||
| boxes can be ignored when computing the loss. | |||
| gt_masks (None | Tensor): true segmentation masks for each box | |||
| used if the architecture supports a segmentation task. | |||
| proposals: override rpn proposals with custom proposals. Use when | |||
| `with_rpn` is False. | |||
| Returns: | |||
| dict[str, Tensor]: a dictionary of loss components | |||
| """ | |||
| x = self.extract_feat(img) | |||
| losses = dict() | |||
| # RPN forward and loss | |||
| proposal_cfg = self.rpn_head.train_cfg.get('rpn_proposal', | |||
| self.rpn_head.test_cfg) | |||
| rpn_losses, proposal_list = self.rpn_head.forward_train( | |||
| x, | |||
| img_metas, | |||
| gt_bboxes, | |||
| gt_labels=None, | |||
| gt_bboxes_ignore=gt_bboxes_ignore, | |||
| proposal_cfg=proposal_cfg, | |||
| **kwargs) | |||
| losses.update(rpn_losses) | |||
| roi_losses = self.roi_head.forward_train(x, img_metas, proposal_list, | |||
| gt_bboxes, gt_labels, | |||
| gt_bboxes_ignore, gt_masks, | |||
| **kwargs) | |||
| losses.update(roi_losses) | |||
| return losses | |||
| def forward_test(self, img, img_metas, proposals=None, rescale=True): | |||
| x = self.extract_feat(img) | |||
| if proposals is None: | |||
| proposal_list = self.rpn_head.simple_test_rpn(x, img_metas) | |||
| else: | |||
| proposal_list = proposals | |||
| result = self.roi_head.simple_test( | |||
| x, proposal_list, img_metas, rescale=rescale) | |||
| return dict(eval_result=result, img_metas=img_metas) | |||
| def forward(self, img, img_metas, **kwargs): | |||
| # currently only supports CPU or single-GPU execution | |||
| if isinstance(img, self.data_container): | |||
| img = img.data[0] | |||
| if isinstance(img_metas, self.data_container): | |||
| img_metas = img_metas.data[0] | |||
| for k, w in kwargs.items(): | |||
| if isinstance(w, self.data_container): | |||
| w = w.data[0] | |||
| kwargs[k] = w | |||
| if next(self.parameters()).is_cuda: | |||
| device = next(self.parameters()).device | |||
| img = self.scatter(img, [device])[0] | |||
| img_metas = self.scatter(img_metas, [device])[0] | |||
| for k, w in kwargs.items(): | |||
| kwargs[k] = self.scatter(w, [device])[0] | |||
| if self.training: | |||
| losses = self.forward_train(img, img_metas, **kwargs) | |||
| loss, log_vars = self._parse_losses(losses) | |||
| outputs = dict( | |||
| loss=loss, log_vars=log_vars, num_samples=len(img_metas)) | |||
| return outputs | |||
| else: | |||
| return self.forward_test(img, img_metas, **kwargs) | |||
| def _parse_losses(self, losses): | |||
| log_vars = OrderedDict() | |||
| for loss_name, loss_value in losses.items(): | |||
| if isinstance(loss_value, torch.Tensor): | |||
| log_vars[loss_name] = loss_value.mean() | |||
| elif isinstance(loss_value, list): | |||
| log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) | |||
| else: | |||
| raise TypeError( | |||
| f'{loss_name} is not a tensor or list of tensors') | |||
| loss = sum(_value for _key, _value in log_vars.items() | |||
| if 'loss' in _key) | |||
| log_vars['loss'] = loss | |||
| for loss_name, loss_value in log_vars.items(): | |||
| # reduce loss when distributed training | |||
| if dist.is_available() and dist.is_initialized(): | |||
| loss_value = loss_value.data.clone() | |||
| dist.all_reduce(loss_value.div_(dist.get_world_size())) | |||
| log_vars[loss_name] = loss_value.item() | |||
| return loss, log_vars | |||
| def train_step(self, data, optimizer): | |||
| losses = self(**data) | |||
| loss, log_vars = self._parse_losses(losses) | |||
| outputs = dict( | |||
| loss=loss, log_vars=log_vars, num_samples=len(data['img_metas'])) | |||
| return outputs | |||
| def val_step(self, data, optimizer=None): | |||
| losses = self(**data) | |||
| loss, log_vars = self._parse_losses(losses) | |||
| outputs = dict( | |||
| loss=loss, log_vars=log_vars, num_samples=len(data['img_metas'])) | |||
| return outputs | |||
| @@ -0,0 +1,2 @@ | |||
| from .dataset import ImageInstanceSegmentationCocoDataset | |||
| from .transforms import build_preprocess_transform | |||
| @@ -0,0 +1,332 @@ | |||
| import os.path as osp | |||
| import numpy as np | |||
| from pycocotools.coco import COCO | |||
| from torch.utils.data import Dataset | |||
| class ImageInstanceSegmentationCocoDataset(Dataset): | |||
| """Coco-style dataset for image instance segmentation. | |||
| Args: | |||
| ann_file (str): Annotation file path. | |||
| classes (Sequence[str], optional): Specify classes to load. | |||
| If None, ``cls.CLASSES`` will be used. Default: None. | |||
| data_root (str, optional): Data root for ``ann_file``, | |||
| ``img_prefix``, ``seg_prefix``, ``proposal_file`` if specified. | |||
| test_mode (bool, optional): If set True, annotations will not be loaded. | |||
| filter_empty_gt (bool, optional): If set True, images without bounding | |||
| boxes of the dataset's classes will be filtered out. This option | |||
| only works when `test_mode=False`, i.e., we never filter images | |||
| during tests. | |||
| """ | |||
| CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', | |||
| 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', | |||
| 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', | |||
| 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', | |||
| 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', | |||
| 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', | |||
| 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', | |||
| 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', | |||
| 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', | |||
| 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', | |||
| 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', | |||
| 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', | |||
| 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', | |||
| 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush') | |||
| def __init__(self, | |||
| ann_file, | |||
| classes=None, | |||
| data_root=None, | |||
| img_prefix='', | |||
| seg_prefix=None, | |||
| test_mode=False, | |||
| filter_empty_gt=True): | |||
| self.ann_file = ann_file | |||
| self.data_root = data_root | |||
| self.img_prefix = img_prefix | |||
| self.seg_prefix = seg_prefix | |||
| self.test_mode = test_mode | |||
| self.filter_empty_gt = filter_empty_gt | |||
| self.CLASSES = self.get_classes(classes) | |||
| # join paths if data_root is specified | |||
| if self.data_root is not None: | |||
| if not osp.isabs(self.ann_file): | |||
| self.ann_file = osp.join(self.data_root, self.ann_file) | |||
| if not (self.img_prefix is None or osp.isabs(self.img_prefix)): | |||
| self.img_prefix = osp.join(self.data_root, self.img_prefix) | |||
| if not (self.seg_prefix is None or osp.isabs(self.seg_prefix)): | |||
| self.seg_prefix = osp.join(self.data_root, self.seg_prefix) | |||
| # load annotations | |||
| self.data_infos = self.load_annotations(self.ann_file) | |||
| # filter out images that are too small or contain no annotations | |||
| if not test_mode: | |||
| valid_inds = self._filter_imgs() | |||
| self.data_infos = [self.data_infos[i] for i in valid_inds] | |||
| # set group flag for the sampler | |||
| self._set_group_flag() | |||
| self.preprocessor = None | |||
| def __len__(self): | |||
| """Total number of samples of data.""" | |||
| return len(self.data_infos) | |||
| def load_annotations(self, ann_file): | |||
| """Load annotation from COCO style annotation file. | |||
| Args: | |||
| ann_file (str): Path of annotation file. | |||
| Returns: | |||
| list[dict]: Annotation info from COCO api. | |||
| """ | |||
| self.coco = COCO(ann_file) | |||
| # The order of returned `cat_ids` will not | |||
| # change with the order of the CLASSES | |||
| self.cat_ids = self.coco.getCatIds(catNms=self.CLASSES) | |||
| self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)} | |||
| self.img_ids = self.coco.getImgIds() | |||
| data_infos = [] | |||
| total_ann_ids = [] | |||
| for i in self.img_ids: | |||
| info = self.coco.loadImgs([i])[0] | |||
| info['filename'] = info['file_name'] | |||
| info['ann_file'] = ann_file | |||
| info['classes'] = self.CLASSES | |||
| data_infos.append(info) | |||
| ann_ids = self.coco.getAnnIds(imgIds=[i]) | |||
| total_ann_ids.extend(ann_ids) | |||
| assert len(set(total_ann_ids)) == len( | |||
| total_ann_ids), f"Annotation ids in '{ann_file}' are not unique!" | |||
| return data_infos | |||
| def get_ann_info(self, idx): | |||
| """Get COCO annotation by index. | |||
| Args: | |||
| idx (int): Index of data. | |||
| Returns: | |||
| dict: Annotation info of specified index. | |||
| """ | |||
| img_id = self.data_infos[idx]['id'] | |||
| ann_ids = self.coco.getAnnIds(imgIds=[img_id]) | |||
| ann_info = self.coco.loadAnns(ann_ids) | |||
| return self._parse_ann_info(self.data_infos[idx], ann_info) | |||
| def get_cat_ids(self, idx): | |||
| """Get COCO category ids by index. | |||
| Args: | |||
| idx (int): Index of data. | |||
| Returns: | |||
| list[int]: All categories in the image of specified index. | |||
| """ | |||
| img_id = self.data_infos[idx]['id'] | |||
| ann_ids = self.coco.getAnnIds(imgIds=[img_id]) | |||
| ann_info = self.coco.loadAnns(ann_ids) | |||
| return [ann['category_id'] for ann in ann_info] | |||
| def pre_pipeline(self, results): | |||
| """Prepare results dict for pipeline.""" | |||
| results['img_prefix'] = self.img_prefix | |||
| results['seg_prefix'] = self.seg_prefix | |||
| results['bbox_fields'] = [] | |||
| results['mask_fields'] = [] | |||
| results['seg_fields'] = [] | |||
| def _filter_imgs(self, min_size=32): | |||
| """Filter images too small or without ground truths.""" | |||
| valid_inds = [] | |||
| # obtain images that contain annotation | |||
| ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values()) | |||
| # obtain images that contain annotations of the required categories | |||
| ids_in_cat = set() | |||
| for i, class_id in enumerate(self.cat_ids): | |||
| ids_in_cat |= set(self.coco.catToImgs[class_id]) | |||
| # merge the image id sets of the two conditions and use the merged set | |||
| # to filter out images if self.filter_empty_gt=True | |||
| ids_in_cat &= ids_with_ann | |||
| valid_img_ids = [] | |||
| for i, img_info in enumerate(self.data_infos): | |||
| img_id = self.img_ids[i] | |||
| if self.filter_empty_gt and img_id not in ids_in_cat: | |||
| continue | |||
| if min(img_info['width'], img_info['height']) >= min_size: | |||
| valid_inds.append(i) | |||
| valid_img_ids.append(img_id) | |||
| self.img_ids = valid_img_ids | |||
| return valid_inds | |||
| def _parse_ann_info(self, img_info, ann_info): | |||
| """Parse bbox and mask annotation. | |||
| Args: | |||
| ann_info (list[dict]): Annotation info of an image. | |||
| Returns: | |||
| dict: A dict containing the following keys: bboxes, bboxes_ignore,\ | |||
| labels, masks, seg_map. "masks" are raw annotations and not \ | |||
| decoded into binary masks. | |||
| """ | |||
| gt_bboxes = [] | |||
| gt_labels = [] | |||
| gt_bboxes_ignore = [] | |||
| gt_masks_ann = [] | |||
| for i, ann in enumerate(ann_info): | |||
| if ann.get('ignore', False): | |||
| continue | |||
| x1, y1, w, h = ann['bbox'] | |||
| inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0)) | |||
| inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0)) | |||
| if inter_w * inter_h == 0: | |||
| continue | |||
| if ann['area'] <= 0 or w < 1 or h < 1: | |||
| continue | |||
| if ann['category_id'] not in self.cat_ids: | |||
| continue | |||
| bbox = [x1, y1, x1 + w, y1 + h] | |||
| if ann.get('iscrowd', False): | |||
| gt_bboxes_ignore.append(bbox) | |||
| else: | |||
| gt_bboxes.append(bbox) | |||
| gt_labels.append(self.cat2label[ann['category_id']]) | |||
| gt_masks_ann.append(ann.get('segmentation', None)) | |||
| if gt_bboxes: | |||
| gt_bboxes = np.array(gt_bboxes, dtype=np.float32) | |||
| gt_labels = np.array(gt_labels, dtype=np.int64) | |||
| else: | |||
| gt_bboxes = np.zeros((0, 4), dtype=np.float32) | |||
| gt_labels = np.array([], dtype=np.int64) | |||
| if gt_bboxes_ignore: | |||
| gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32) | |||
| else: | |||
| gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32) | |||
| seg_map = img_info['filename'].replace('jpg', 'png') | |||
| ann = dict( | |||
| bboxes=gt_bboxes, | |||
| labels=gt_labels, | |||
| bboxes_ignore=gt_bboxes_ignore, | |||
| masks=gt_masks_ann, | |||
| seg_map=seg_map) | |||
| return ann | |||
| def _set_group_flag(self): | |||
| """Set flag according to image aspect ratio. | |||
| Images with aspect ratio greater than 1 will be set as group 1, | |||
| otherwise group 0. | |||
| """ | |||
| self.flag = np.zeros(len(self), dtype=np.uint8) | |||
| for i in range(len(self)): | |||
| img_info = self.data_infos[i] | |||
| if img_info['width'] / img_info['height'] > 1: | |||
| self.flag[i] = 1 | |||
| def _rand_another(self, idx): | |||
| """Get another random index from the same group as the given index.""" | |||
| pool = np.where(self.flag == self.flag[idx])[0] | |||
| return np.random.choice(pool) | |||
| def __getitem__(self, idx): | |||
| """Get training/test data after pipeline. | |||
| Args: | |||
| idx (int): Index of data. | |||
| Returns: | |||
| dict: Training/test data (with annotation if `test_mode` is set \ | |||
| True). | |||
| """ | |||
| if self.test_mode: | |||
| return self.prepare_test_img(idx) | |||
| while True: | |||
| data = self.prepare_train_img(idx) | |||
| if data is None: | |||
| idx = self._rand_another(idx) | |||
| continue | |||
| return data | |||
| def prepare_train_img(self, idx): | |||
| """Get training data and annotations after pipeline. | |||
| Args: | |||
| idx (int): Index of data. | |||
| Returns: | |||
| dict: Training data and annotation after pipeline with new keys \ | |||
| introduced by pipeline. | |||
| """ | |||
| img_info = self.data_infos[idx] | |||
| ann_info = self.get_ann_info(idx) | |||
| results = dict(img_info=img_info, ann_info=ann_info) | |||
| self.pre_pipeline(results) | |||
| if self.preprocessor is None: | |||
| return results | |||
| self.preprocessor.train() | |||
| return self.preprocessor(results) | |||
| def prepare_test_img(self, idx): | |||
| """Get testing data after pipeline. | |||
| Args: | |||
| idx (int): Index of data. | |||
| Returns: | |||
| dict: Testing data after pipeline with new keys introduced by \ | |||
| pipeline. | |||
| """ | |||
| img_info = self.data_infos[idx] | |||
| results = dict(img_info=img_info) | |||
| self.pre_pipeline(results) | |||
| if self.preprocessor is None: | |||
| return results | |||
| self.preprocessor.eval() | |||
| results = self.preprocessor(results) | |||
| return results | |||
| @classmethod | |||
| def get_classes(cls, classes=None): | |||
| """Get class names of current dataset. | |||
| Args: | |||
| classes (Sequence[str] | None): If classes is None, use | |||
| default CLASSES defined by builtin dataset. If classes is | |||
| a tuple or list, override the CLASSES defined by the dataset. | |||
| Returns: | |||
| tuple[str] or list[str]: Names of categories of the dataset. | |||
| """ | |||
| if classes is None: | |||
| return cls.CLASSES | |||
| if isinstance(classes, (tuple, list)): | |||
| class_names = classes | |||
| else: | |||
| raise ValueError(f'Unsupported type {type(classes)} of classes.') | |||
| return class_names | |||
| def to_torch_dataset(self, preprocessors=None): | |||
| self.preprocessor = preprocessors | |||
| return self | |||
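| A minimal usage sketch for the dataset above (illustrative only: the annotation path, image prefix and class names are placeholders, and `preprocessor` is assumed to be an already-built ImageInstanceSegmentationPreprocessor; the constructor arguments mirror the ones used in the unit tests later in this diff): | |||
| # Hypothetical COCO-style layout; paths and classes are examples, not part of this diff. | |||
| train_dataset = ImageInstanceSegmentationCocoDataset( | |||
| 'data/annotations/instances_train.json', | |||
| classes=('Cat', 'Dog'), | |||
| data_root='data/', | |||
| img_prefix='data/images/train/', | |||
| test_mode=False) | |||
| train_dataset = train_dataset.to_torch_dataset(preprocessors=preprocessor) | |||
| sample = train_dataset[0]  # runs the preprocessor's train-time transforms | |||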
| @@ -0,0 +1,109 @@ | |||
| import os.path as osp | |||
| import numpy as np | |||
| from modelscope.fileio import File | |||
| def build_preprocess_transform(cfg): | |||
| assert isinstance(cfg, dict) | |||
| cfg = cfg.copy() | |||
| type = cfg.pop('type') | |||
| if type == 'LoadImageFromFile': | |||
| return LoadImageFromFile(**cfg) | |||
| elif type == 'LoadAnnotations': | |||
| from mmdet.datasets.pipelines import LoadAnnotations | |||
| return LoadAnnotations(**cfg) | |||
| elif type == 'Resize': | |||
| if 'img_scale' in cfg: | |||
| if isinstance(cfg.img_scale[0], list): | |||
| elems = [] | |||
| for elem in cfg.img_scale: | |||
| elems.append(tuple(elem)) | |||
| cfg.img_scale = elems | |||
| else: | |||
| cfg.img_scale = tuple(cfg.img_scale) | |||
| from mmdet.datasets.pipelines import Resize | |||
| return Resize(**cfg) | |||
| elif type == 'RandomFlip': | |||
| from mmdet.datasets.pipelines import RandomFlip | |||
| return RandomFlip(**cfg) | |||
| elif type == 'Normalize': | |||
| from mmdet.datasets.pipelines import Normalize | |||
| return Normalize(**cfg) | |||
| elif type == 'Pad': | |||
| from mmdet.datasets.pipelines import Pad | |||
| return Pad(**cfg) | |||
| elif type == 'DefaultFormatBundle': | |||
| from mmdet.datasets.pipelines import DefaultFormatBundle | |||
| return DefaultFormatBundle(**cfg) | |||
| elif type == 'ImageToTensor': | |||
| from mmdet.datasets.pipelines import ImageToTensor | |||
| return ImageToTensor(**cfg) | |||
| elif type == 'Collect': | |||
| from mmdet.datasets.pipelines import Collect | |||
| return Collect(**cfg) | |||
| else: | |||
| raise ValueError(f'preprocess transform \'{type}\' is not supported.') | |||
| class LoadImageFromFile: | |||
| """Load an image from file. | |||
| Required keys are "img_prefix" and "img_info" (a dict that must contain the | |||
| keys "filename", "ann_file" and "classes"). Added or updated keys are | |||
| "filename", "ori_filename", "img", "img_shape", "ori_shape" (same as | |||
| `img_shape`), "img_fields", "ann_file" and "classes". | |||
| Args: | |||
| to_float32 (bool): Whether to convert the loaded image to a float32 | |||
| numpy array. If set to False, the loaded image is an uint8 array. | |||
| Defaults to False. | |||
| """ | |||
| def __init__(self, to_float32=False, mode='rgb'): | |||
| self.to_float32 = to_float32 | |||
| self.mode = mode | |||
| from mmcv import imfrombytes | |||
| self.imfrombytes = imfrombytes | |||
| def __call__(self, results): | |||
| """Call functions to load image and get image meta information. | |||
| Args: | |||
| results (dict): Result dict from :obj:`ImageInstanceSegmentationDataset`. | |||
| Returns: | |||
| dict: The dict contains loaded image and meta information. | |||
| """ | |||
| if results['img_prefix'] is not None: | |||
| filename = osp.join(results['img_prefix'], | |||
| results['img_info']['filename']) | |||
| else: | |||
| filename = results['img_info']['filename'] | |||
| img_bytes = File.read(filename) | |||
| img = self.imfrombytes(img_bytes, 'color', 'bgr', backend='pillow') | |||
| if self.to_float32: | |||
| img = img.astype(np.float32) | |||
| results['filename'] = filename | |||
| results['ori_filename'] = results['img_info']['filename'] | |||
| results['img'] = img | |||
| results['img_shape'] = img.shape | |||
| results['ori_shape'] = img.shape | |||
| results['img_fields'] = ['img'] | |||
| results['ann_file'] = results['img_info']['ann_file'] | |||
| results['classes'] = results['img_info']['classes'] | |||
| return results | |||
| def __repr__(self): | |||
| repr_str = (f'{self.__class__.__name__}(' | |||
| f'to_float32={self.to_float32}, ' | |||
| f"mode='{self.mode}'") | |||
| return repr_str | |||
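| For reference, a small sketch of how the config-driven transform builder above can be used. The transform types are the ones handled in build_preprocess_transform, while the concrete values (normalization statistics, pad divisor, tensor keys) are illustrative and not taken from this diff: | |||
| # Illustrative test-time transform configs. | |||
| cfgs = [ | |||
| dict(type='LoadImageFromFile'), | |||
| dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True), | |||
| dict(type='Pad', size_divisor=32), | |||
| dict(type='ImageToTensor', keys=['img']), | |||
| ] | |||
| transforms = [build_preprocess_transform(c) for c in cfgs] | |||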
| @@ -0,0 +1,49 @@ | |||
| import os | |||
| from typing import Any, Dict | |||
| import torch | |||
| from modelscope.metainfo import Models | |||
| from modelscope.models.base import TorchModel | |||
| from modelscope.models.builder import MODELS | |||
| from modelscope.models.cv.image_instance_segmentation import \ | |||
| CascadeMaskRCNNSwin | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| @MODELS.register_module( | |||
| Tasks.image_segmentation, module_name=Models.cascade_mask_rcnn_swin) | |||
| class CascadeMaskRCNNSwinModel(TorchModel): | |||
| def __init__(self, model_dir=None, *args, **kwargs): | |||
| """ | |||
| Args: | |||
| model_dir (str): model directory. | |||
| """ | |||
| super(CascadeMaskRCNNSwinModel, self).__init__( | |||
| model_dir=model_dir, *args, **kwargs) | |||
| if 'backbone' not in kwargs: | |||
| config_path = os.path.join(model_dir, ModelFile.CONFIGURATION) | |||
| cfg = Config.from_file(config_path) | |||
| model_cfg = cfg.model | |||
| kwargs.update(model_cfg) | |||
| self.model = CascadeMaskRCNNSwin(model_dir=model_dir, **kwargs) | |||
| self.device = torch.device( | |||
| 'cuda' if torch.cuda.is_available() else 'cpu') | |||
| self.model.to(self.device) | |||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
| output = self.model(**input) | |||
| return output | |||
| def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]: | |||
| return input | |||
| def compute_loss(self, outputs: Dict[str, Any], labels): | |||
| pass | |||
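| A hedged sketch of how the wrapper above can be instantiated from a downloaded snapshot; the model id is the one used in the tests later in this diff, and the input layout ('img' plus 'img_metas') follows the pipeline's preprocess output further below: | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| # The backbone config is read from configuration.json inside the snapshot when no 'backbone' kwarg is given. | |||
| model_dir = snapshot_download('damo/cv_swin-b_image-instance-segmentation_coco') | |||
| model = CascadeMaskRCNNSwinModel(model_dir) | |||
| # outputs = model({'img': batched_imgs, 'img_metas': img_metas})  # keys as prepared by the pipeline | |||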
| @@ -0,0 +1,203 @@ | |||
| import itertools | |||
| import cv2 | |||
| import numpy as np | |||
| import pycocotools.mask as maskUtils | |||
| import torch | |||
| from modelscope.outputs import OutputKeys | |||
| def get_seg_bboxes(bboxes, labels, segms=None, class_names=None, score_thr=0.): | |||
| assert bboxes.ndim == 2, \ | |||
| f'bboxes ndim should be 2, but its ndim is {bboxes.ndim}.' | |||
| assert labels.ndim == 1, \ | |||
| f'labels ndim should be 1, but its ndim is {labels.ndim}.' | |||
| assert bboxes.shape[0] == labels.shape[0], \ | |||
| 'bboxes and labels should have the same length.' | |||
| assert bboxes.shape[1] == 4 or bboxes.shape[1] == 5, \ | |||
| f'bboxes.shape[1] should be 4 or 5, but it is {bboxes.shape[1]}.' | |||
| if score_thr > 0: | |||
| assert bboxes.shape[1] == 5 | |||
| scores = bboxes[:, -1] | |||
| inds = scores > score_thr | |||
| bboxes = bboxes[inds, :] | |||
| labels = labels[inds] | |||
| if segms is not None: | |||
| segms = segms[inds, ...] | |||
| bboxes_names = [] | |||
| for i, (bbox, label) in enumerate(zip(bboxes, labels)): | |||
| label_name = class_names[ | |||
| label] if class_names is not None else f'class {label}' | |||
| bbox = [0 if b < 0 else b for b in list(bbox)] | |||
| bbox.append(label_name) | |||
| bbox.append(segms[i].astype(bool)) | |||
| bboxes_names.append(bbox) | |||
| return bboxes_names | |||
| def get_img_seg_results(det_rawdata=None, | |||
| class_names=None, | |||
| score_thr=0.3, | |||
| is_decode=True): | |||
| """Get all boxes of one image. | |||
| score_thr: classification probability threshold. | |||
| Output format: [[x1, y1, x2, y2, prob, cls_name, mask], ...] | |||
| """ | |||
| assert det_rawdata is not None, 'det_rawdata should be not None.' | |||
| assert class_names is not None, 'class_names should be not None.' | |||
| if isinstance(det_rawdata, tuple): | |||
| bbox_result, segm_result = det_rawdata | |||
| if isinstance(segm_result, tuple): | |||
| segm_result = segm_result[0] # ms rcnn | |||
| else: | |||
| bbox_result, segm_result = det_rawdata, None | |||
| bboxes = np.vstack(bbox_result) | |||
| labels = [ | |||
| np.full(bbox.shape[0], i, dtype=np.int32) | |||
| for i, bbox in enumerate(bbox_result) | |||
| ] | |||
| labels = np.concatenate(labels) | |||
| segms = None | |||
| if segm_result is not None and len(labels) > 0: # non empty | |||
| segms = list(itertools.chain(*segm_result)) | |||
| if is_decode: | |||
| segms = maskUtils.decode(segms) | |||
| segms = segms.transpose(2, 0, 1) | |||
| if isinstance(segms[0], torch.Tensor): | |||
| segms = torch.stack(segms, dim=0).detach().cpu().numpy() | |||
| else: | |||
| segms = np.stack(segms, axis=0) | |||
| bboxes_names = get_seg_bboxes( | |||
| bboxes, | |||
| labels, | |||
| segms=segms, | |||
| class_names=class_names, | |||
| score_thr=score_thr) | |||
| return bboxes_names | |||
| def get_img_ins_seg_result(img_seg_result=None, | |||
| class_names=None, | |||
| score_thr=0.3): | |||
| assert img_seg_result is not None, 'img_seg_result should be not None.' | |||
| assert class_names is not None, 'class_names should be not None.' | |||
| img_seg_result = get_img_seg_results( | |||
| det_rawdata=(img_seg_result[0], img_seg_result[1]), | |||
| class_names=class_names, | |||
| score_thr=score_thr, | |||
| is_decode=False) | |||
| results_dict = { | |||
| OutputKeys.BOXES: [], | |||
| OutputKeys.MASKS: [], | |||
| OutputKeys.LABELS: [], | |||
| OutputKeys.SCORES: [] | |||
| } | |||
| for seg_result in img_seg_result: | |||
| box = { | |||
| 'x': int(seg_result[0]), | |||
| 'y': int(seg_result[1]), | |||
| 'w': int(seg_result[2] - seg_result[0]), | |||
| 'h': int(seg_result[3] - seg_result[1]) | |||
| } | |||
| score = float(seg_result[4]) | |||
| category = seg_result[5] | |||
| mask = np.array(seg_result[6], order='F', dtype='uint8') | |||
| mask = mask.astype(float) | |||
| results_dict[OutputKeys.BOXES].append(box) | |||
| results_dict[OutputKeys.MASKS].append(mask) | |||
| results_dict[OutputKeys.SCORES].append(score) | |||
| results_dict[OutputKeys.LABELS].append(category) | |||
| return results_dict | |||
| def show_result( | |||
| img, | |||
| result, | |||
| out_file='result.jpg', | |||
| show_box=True, | |||
| show_label=True, | |||
| show_score=True, | |||
| alpha=0.5, | |||
| fontScale=0.5, | |||
| fontFace=cv2.FONT_HERSHEY_COMPLEX_SMALL, | |||
| thickness=1, | |||
| ): | |||
| assert isinstance(img, (str, np.ndarray)), \ | |||
| f'img must be str or np.ndarray, but got {type(img)}.' | |||
| if isinstance(img, str): | |||
| img = cv2.imread(img) | |||
| if len(img.shape) == 2: | |||
| img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) | |||
| img = img.astype(np.float32) | |||
| labels = result[OutputKeys.LABELS] | |||
| scores = result[OutputKeys.SCORES] | |||
| boxes = result[OutputKeys.BOXES] | |||
| masks = result[OutputKeys.MASKS] | |||
| for label, score, box, mask in zip(labels, scores, boxes, masks): | |||
| random_color = np.array([ | |||
| np.random.random() * 255.0, | |||
| np.random.random() * 255.0, | |||
| np.random.random() * 255.0 | |||
| ]) | |||
| x1 = int(box['x']) | |||
| y1 = int(box['y']) | |||
| w = int(box['w']) | |||
| h = int(box['h']) | |||
| x2 = x1 + w | |||
| y2 = y1 + h | |||
| if show_box: | |||
| cv2.rectangle( | |||
| img, (x1, y1), (x2, y2), random_color, thickness=thickness) | |||
| if show_label or show_score: | |||
| if show_label and show_score: | |||
| text = '{}|{}'.format(label, round(float(score), 2)) | |||
| elif show_label: | |||
| text = '{}'.format(label) | |||
| else: | |||
| text = '{}'.format(round(float(score), 2)) | |||
| retval, baseLine = cv2.getTextSize( | |||
| text, | |||
| fontFace=fontFace, | |||
| fontScale=fontScale, | |||
| thickness=thickness) | |||
| cv2.rectangle( | |||
| img, (x1, y1 - retval[1] - baseLine), (x1 + retval[0], y1), | |||
| thickness=-1, | |||
| color=(0, 0, 0)) | |||
| cv2.putText( | |||
| img, | |||
| text, (x1, y1 - baseLine), | |||
| fontScale=fontScale, | |||
| fontFace=fontFace, | |||
| thickness=thickness, | |||
| color=random_color) | |||
| idx = np.nonzero(mask) | |||
| img[idx[0], idx[1], :] *= 1.0 - alpha | |||
| img[idx[0], idx[1], :] += alpha * random_color | |||
| cv2.imwrite(out_file, img) | |||
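| To show how the two helpers above fit together, a sketch under the assumption that `raw` is the per-image (bbox_results, segm_results) pair, i.e. what the pipeline below passes as inputs['eval_result'][0]; the class names and image paths are placeholders: | |||
| result = get_img_ins_seg_result( | |||
| img_seg_result=raw, class_names=('Cat', 'Dog'), score_thr=0.3) | |||
| show_result('demo.jpg', result, out_file='demo_seg.jpg')  # draws boxes, labels, scores and masks | |||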
| @@ -13,6 +13,7 @@ class OutputKeys(object): | |||
| POSES = 'poses' | |||
| CAPTION = 'caption' | |||
| BOXES = 'boxes' | |||
| MASKS = 'masks' | |||
| TEXT = 'text' | |||
| POLYGONS = 'polygons' | |||
| OUTPUT = 'output' | |||
| @@ -76,6 +76,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
| 'damo/cv_daflow_virtual-tryon_base'), | |||
| Tasks.image_colorization: (Pipelines.image_colorization, | |||
| 'damo/cv_unet_image-colorization'), | |||
| Tasks.image_segmentation: | |||
| (Pipelines.image_instance_segmentation, | |||
| 'damo/cv_swin-b_image-instance-segmentation_coco'), | |||
| Tasks.style_transfer: (Pipelines.style_transfer, | |||
| 'damo/cv_aams_style-transfer_damo'), | |||
| Tasks.face_image_generation: (Pipelines.face_image_generation, | |||
| @@ -11,6 +11,7 @@ try: | |||
| from .image_colorization_pipeline import ImageColorizationPipeline | |||
| from .image_super_resolution_pipeline import ImageSuperResolutionPipeline | |||
| from .face_image_generation_pipeline import FaceImageGenerationPipeline | |||
| from .image_instance_segmentation_pipeline import ImageInstanceSegmentationPipeline | |||
| except ModuleNotFoundError as e: | |||
| if str(e) == "No module named 'torch'": | |||
| pass | |||
| @@ -0,0 +1,105 @@ | |||
| import os | |||
| from typing import Any, Dict, Optional, Union | |||
| import cv2 | |||
| import numpy as np | |||
| import torch | |||
| from PIL import Image | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models.cv.image_instance_segmentation.model import \ | |||
| CascadeMaskRCNNSwinModel | |||
| from modelscope.models.cv.image_instance_segmentation.postprocess_utils import \ | |||
| get_img_ins_seg_result | |||
| from modelscope.pipelines.base import Input, Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import (ImageInstanceSegmentationPreprocessor, | |||
| build_preprocessor, load_image) | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import Fields, ModelFile, Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| @PIPELINES.register_module( | |||
| Tasks.image_segmentation, | |||
| module_name=Pipelines.image_instance_segmentation) | |||
| class ImageInstanceSegmentationPipeline(Pipeline): | |||
| def __init__(self, | |||
| model: Union[CascadeMaskRCNNSwinModel, str], | |||
| preprocessor: Optional[ | |||
| ImageInstanceSegmentationPreprocessor] = None, | |||
| **kwargs): | |||
| """use `model` and `preprocessor` to create a image instance segmentation pipeline for prediction | |||
| Args: | |||
| model (CascadeMaskRCNNSwinModel | str): a model instance | |||
| preprocessor (CascadeMaskRCNNSwinPreprocessor | None): a preprocessor instance | |||
| """ | |||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
| if preprocessor is None: | |||
| config_path = os.path.join(self.model.model_dir, | |||
| ModelFile.CONFIGURATION) | |||
| cfg = Config.from_file(config_path) | |||
| self.preprocessor = build_preprocessor(cfg.preprocessor, Fields.cv) | |||
| else: | |||
| self.preprocessor = preprocessor | |||
| self.preprocessor.eval() | |||
| self.model.eval() | |||
| def _collate_fn(self, data): | |||
| # don't require collating | |||
| return data | |||
| def preprocess(self, input: Input, **preprocess_params) -> Dict[str, Any]: | |||
| filename = None | |||
| img = None | |||
| if isinstance(input, str): | |||
| filename = input | |||
| img = np.array(load_image(input)) | |||
| img = img[:, :, ::-1] # convert to bgr | |||
| elif isinstance(input, Image.Image): | |||
| img = np.array(input.convert('RGB')) | |||
| img = img[:, :, ::-1] # convert to bgr | |||
| elif isinstance(input, np.ndarray): | |||
| if len(input.shape) == 2: | |||
| img = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR) | |||
| else: | |||
| img = input  # assume an already-decoded (H, W, 3) BGR array | |||
| else: | |||
| raise TypeError(f'input should be either str, PIL.Image,' | |||
| f' np.ndarray, but got {type(input)}') | |||
| result = { | |||
| 'img': img, | |||
| 'img_shape': img.shape, | |||
| 'ori_shape': img.shape, | |||
| 'img_fields': ['img'], | |||
| 'img_prefix': '', | |||
| 'img_info': { | |||
| 'filename': filename, | |||
| 'ann_file': None, | |||
| 'classes': None | |||
| }, | |||
| } | |||
| result = self.preprocessor(result) | |||
| # stacked as a batch | |||
| result['img'] = torch.stack([result['img']], dim=0) | |||
| result['img_metas'] = [result['img_metas'].data] | |||
| return result | |||
| def forward(self, input: Dict[str, Any], | |||
| **forward_params) -> Dict[str, Any]: | |||
| with torch.no_grad(): | |||
| output = self.model(input) | |||
| return output | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| result = get_img_ins_seg_result( | |||
| img_seg_result=inputs['eval_result'][0], | |||
| class_names=self.model.model.classes) | |||
| return result | |||
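| A minimal invocation sketch for the pipeline above (the same pattern is exercised by the unit tests further down; the image path is a placeholder): | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.utils.constant import Tasks | |||
| segmentor = pipeline(Tasks.image_segmentation, model='damo/cv_swin-b_image-instance-segmentation_coco') | |||
| result = segmentor('some_image.jpg')  # placeholder path | |||
| print(result[OutputKeys.LABELS], result[OutputKeys.SCORES]) | |||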
| @@ -20,6 +20,7 @@ try: | |||
| from .space.dialog_modeling_preprocessor import * # noqa F403 | |||
| from .space.dialog_state_tracking_preprocessor import * # noqa F403 | |||
| from .image import ImageColorEnhanceFinetunePreprocessor | |||
| from .image import ImageInstanceSegmentationPreprocessor | |||
| except ModuleNotFoundError as e: | |||
| if str(e) == "No module named 'tensorflow'": | |||
| print(TENSORFLOW_IMPORT_ERROR.format('tts')) | |||
| @@ -136,3 +136,72 @@ class ImageColorEnhanceFinetunePreprocessor(Preprocessor): | |||
| """ | |||
| return data | |||
| @PREPROCESSORS.register_module( | |||
| Fields.cv, | |||
| module_name=Preprocessors.image_instance_segmentation_preprocessor) | |||
| class ImageInstanceSegmentationPreprocessor(Preprocessor): | |||
| def __init__(self, *args, **kwargs): | |||
| """image instance segmentation preprocessor in the fine-tune scenario | |||
| """ | |||
| super().__init__(*args, **kwargs) | |||
| self.training = kwargs.pop('training', True) | |||
| self.preprocessor_train_cfg = kwargs.pop('train', None) | |||
| self.preprocessor_test_cfg = kwargs.pop('val', None) | |||
| self.train_transforms = [] | |||
| self.test_transforms = [] | |||
| from modelscope.models.cv.image_instance_segmentation.datasets import \ | |||
| build_preprocess_transform | |||
| if self.preprocessor_train_cfg is not None: | |||
| if isinstance(self.preprocessor_train_cfg, dict): | |||
| self.preprocessor_train_cfg = [self.preprocessor_train_cfg] | |||
| for cfg in self.preprocessor_train_cfg: | |||
| transform = build_preprocess_transform(cfg) | |||
| self.train_transforms.append(transform) | |||
| if self.preprocessor_test_cfg is not None: | |||
| if isinstance(self.preprocessor_test_cfg, dict): | |||
| self.preprocessor_test_cfg = [self.preprocessor_test_cfg] | |||
| for cfg in self.preprocessor_test_cfg: | |||
| transform = build_preprocess_transform(cfg) | |||
| self.test_transforms.append(transform) | |||
| def train(self): | |||
| self.training = True | |||
| return | |||
| def eval(self): | |||
| self.training = False | |||
| return | |||
| @type_assert(object, object) | |||
| def __call__(self, results: Dict[str, Any]): | |||
| """process the raw input data | |||
| Args: | |||
| results (dict): Result dict from loading pipeline. | |||
| Returns: | |||
| Dict[str, Any] | None: the preprocessed data | |||
| """ | |||
| if self.training: | |||
| transforms = self.train_transforms | |||
| else: | |||
| transforms = self.test_transforms | |||
| for t in transforms: | |||
| results = t(results) | |||
| if results is None: | |||
| return None | |||
| return results | |||
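| A sketch of how the preprocessor above is typically built; in the pipeline it is created from the model's configuration.json via build_preprocessor(cfg.preprocessor, Fields.cv), so the transform configs listed here are purely illustrative: | |||
| # Illustrative transform configs; real values live in the model's configuration.json. | |||
| preprocessor = ImageInstanceSegmentationPreprocessor( | |||
| train=[dict(type='LoadImageFromFile'), dict(type='LoadAnnotations', with_bbox=True, with_mask=True)], | |||
| val=[dict(type='LoadImageFromFile')]) | |||
| preprocessor.eval()  # switch to the test-time transform list | |||
| # processed = preprocessor(results)  # 'results' as produced by the dataset / pipeline preprocess | |||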
| @@ -1,4 +1,5 @@ | |||
| from .base import DummyTrainer | |||
| from .builder import build_trainer | |||
| from .cv import ImageInstanceSegmentationTrainer | |||
| from .nlp import SequenceClassificationTrainer | |||
| from .trainer import EpochBasedTrainer | |||
| @@ -0,0 +1,2 @@ | |||
| from .image_instance_segmentation_trainer import \ | |||
| ImageInstanceSegmentationTrainer | |||
| @@ -0,0 +1,27 @@ | |||
| from modelscope.trainers.builder import TRAINERS | |||
| from modelscope.trainers.trainer import EpochBasedTrainer | |||
| @TRAINERS.register_module(module_name='image-instance-segmentation') | |||
| class ImageInstanceSegmentationTrainer(EpochBasedTrainer): | |||
| def __init__(self, *args, **kwargs): | |||
| super().__init__(*args, **kwargs) | |||
| def collate_fn(self, data): | |||
| # skip collating here: the batch contains special data types (e.g., BitmapMasks) | |||
| return data | |||
| def train(self, *args, **kwargs): | |||
| super().train(*args, **kwargs) | |||
| def evaluate(self, *args, **kwargs): | |||
| metric_values = super().evaluate(*args, **kwargs) | |||
| return metric_values | |||
| def prediction_step(self, model, inputs): | |||
| pass | |||
| def to_task_dataset(self, datasets, mode, preprocessor=None): | |||
| # temporary passthrough until the dataset interface stabilizes | |||
| return datasets.to_torch_dataset(preprocessor) | |||
| @@ -0,0 +1,60 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import unittest | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.models import Model | |||
| from modelscope.models.cv.image_instance_segmentation.model import \ | |||
| CascadeMaskRCNNSwinModel | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines import ImageInstanceSegmentationPipeline, pipeline | |||
| from modelscope.preprocessors import build_preprocessor | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import Fields, ModelFile, Tasks | |||
| from modelscope.utils.test_utils import test_level | |||
| class ImageInstanceSegmentationTest(unittest.TestCase): | |||
| model_id = 'damo/cv_swin-b_image-instance-segmentation_coco' | |||
| image = 'data/test/images/image_instance_segmentation.jpg' | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_from_modelhub(self): | |||
| model = Model.from_pretrained(self.model_id) | |||
| config_path = os.path.join(model.model_dir, ModelFile.CONFIGURATION) | |||
| cfg = Config.from_file(config_path) | |||
| preprocessor = build_preprocessor(cfg.preprocessor, Fields.cv) | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.image_segmentation, | |||
| model=model, | |||
| preprocessor=preprocessor) | |||
| print(pipeline_ins(input=self.image)[OutputKeys.LABELS]) | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_run_with_model_name(self): | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.image_segmentation, model=self.model_id) | |||
| print(pipeline_ins(input=self.image)[OutputKeys.LABELS]) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_default_model(self): | |||
| pipeline_ins = pipeline(task=Tasks.image_segmentation) | |||
| print(pipeline_ins(input=self.image)[OutputKeys.LABELS]) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_by_direct_model_download(self): | |||
| cache_path = snapshot_download(self.model_id) | |||
| config_path = os.path.join(cache_path, ModelFile.CONFIGURATION) | |||
| cfg = Config.from_file(config_path) | |||
| preprocessor = build_preprocessor(cfg.preprocessor, Fields.cv) | |||
| model = CascadeMaskRCNNSwinModel(cache_path) | |||
| pipeline1 = ImageInstanceSegmentationPipeline( | |||
| model, preprocessor=preprocessor) | |||
| pipeline2 = pipeline( | |||
| Tasks.image_segmentation, model=model, preprocessor=preprocessor) | |||
| print(f'pipeline1: {pipeline1(input=self.image)[OutputKeys.LABELS]}') | |||
| print(f'pipeline2: {pipeline2(input=self.image)[OutputKeys.LABELS]}') | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -0,0 +1,117 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import shutil | |||
| import tempfile | |||
| import unittest | |||
| import zipfile | |||
| from functools import partial | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.models.cv.image_instance_segmentation import \ | |||
| CascadeMaskRCNNSwinModel | |||
| from modelscope.models.cv.image_instance_segmentation.datasets import \ | |||
| ImageInstanceSegmentationCocoDataset | |||
| from modelscope.trainers import build_trainer | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import ModelFile | |||
| from modelscope.utils.test_utils import test_level | |||
| class TestImageInstanceSegmentationTrainer(unittest.TestCase): | |||
| model_id = 'damo/cv_swin-b_image-instance-segmentation_coco' | |||
| def setUp(self): | |||
| print('Testing %s.%s' % (type(self).__name__, self._testMethodName)) | |||
| cache_path = snapshot_download(self.model_id) | |||
| config_path = os.path.join(cache_path, ModelFile.CONFIGURATION) | |||
| cfg = Config.from_file(config_path) | |||
| data_root = cfg.dataset.data_root | |||
| classes = tuple(cfg.dataset.classes) | |||
| max_epochs = cfg.train.max_epochs | |||
| samples_per_gpu = cfg.train.dataloader.batch_size_per_gpu | |||
| if data_root is None: | |||
| # use default toy data | |||
| dataset_path = os.path.join(cache_path, 'toydata.zip') | |||
| with zipfile.ZipFile(dataset_path, 'r') as zipf: | |||
| zipf.extractall(cache_path) | |||
| data_root = cache_path + '/toydata/' | |||
| classes = ('Cat', 'Dog') | |||
| self.train_dataset = ImageInstanceSegmentationCocoDataset( | |||
| data_root + 'annotations/instances_train.json', | |||
| classes=classes, | |||
| data_root=data_root, | |||
| img_prefix=data_root + 'images/train/', | |||
| seg_prefix=None, | |||
| test_mode=False) | |||
| self.eval_dataset = ImageInstanceSegmentationCocoDataset( | |||
| data_root + 'annotations/instances_val.json', | |||
| classes=classes, | |||
| data_root=data_root, | |||
| img_prefix=data_root + 'images/val/', | |||
| seg_prefix=None, | |||
| test_mode=True) | |||
| from mmcv.parallel import collate | |||
| self.collate_fn = partial(collate, samples_per_gpu=samples_per_gpu) | |||
| self.max_epochs = max_epochs | |||
| self.tmp_dir = tempfile.TemporaryDirectory().name | |||
| if not os.path.exists(self.tmp_dir): | |||
| os.makedirs(self.tmp_dir) | |||
| def tearDown(self): | |||
| shutil.rmtree(self.tmp_dir) | |||
| super().tearDown() | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_trainer(self): | |||
| kwargs = dict( | |||
| model=self.model_id, | |||
| data_collator=self.collate_fn, | |||
| train_dataset=self.train_dataset, | |||
| eval_dataset=self.eval_dataset, | |||
| work_dir=self.tmp_dir) | |||
| trainer = build_trainer( | |||
| name='image-instance-segmentation', default_args=kwargs) | |||
| trainer.train() | |||
| results_files = os.listdir(self.tmp_dir) | |||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||
| for i in range(self.max_epochs): | |||
| self.assertIn(f'epoch_{i+1}.pth', results_files) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_trainer_with_model_and_args(self): | |||
| tmp_dir = tempfile.TemporaryDirectory().name | |||
| if not os.path.exists(tmp_dir): | |||
| os.makedirs(tmp_dir) | |||
| cache_path = snapshot_download(self.model_id) | |||
| model = CascadeMaskRCNNSwinModel.from_pretrained(cache_path) | |||
| kwargs = dict( | |||
| cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), | |||
| model=model, | |||
| data_collator=self.collate_fn, | |||
| train_dataset=self.train_dataset, | |||
| eval_dataset=self.eval_dataset, | |||
| work_dir=self.tmp_dir) | |||
| trainer = build_trainer( | |||
| name='image-instance-segmentation', default_args=kwargs) | |||
| trainer.train() | |||
| results_files = os.listdir(self.tmp_dir) | |||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||
| for i in range(self.max_epochs): | |||
| self.assertIn(f'epoch_{i+1}.pth', results_files) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||