Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10539423master
| @@ -305,6 +305,7 @@ class Trainers(object): | |||||
| face_detection_scrfd = 'face-detection-scrfd' | face_detection_scrfd = 'face-detection-scrfd' | ||||
| card_detection_scrfd = 'card-detection-scrfd' | card_detection_scrfd = 'card-detection-scrfd' | ||||
| image_inpainting = 'image-inpainting' | image_inpainting = 'image-inpainting' | ||||
| referring_video_object_segmentation = 'referring-video-object-segmentation' | |||||
| image_classification_team = 'image-classification-team' | image_classification_team = 'image-classification-team' | ||||
| # nlp trainers | # nlp trainers | ||||
| @@ -423,6 +424,8 @@ class Metrics(object): | |||||
| image_inpainting_metric = 'image-inpainting-metric' | image_inpainting_metric = 'image-inpainting-metric' | ||||
| # metric for ocr | # metric for ocr | ||||
| NED = 'ned' | NED = 'ned' | ||||
| # metric for referring-video-object-segmentation task | |||||
| referring_video_object_segmentation_metric = 'referring-video-object-segmentation-metric' | |||||
| class Optimizers(object): | class Optimizers(object): | ||||
| @@ -20,6 +20,7 @@ if TYPE_CHECKING: | |||||
| from .accuracy_metric import AccuracyMetric | from .accuracy_metric import AccuracyMetric | ||||
| from .bleu_metric import BleuMetric | from .bleu_metric import BleuMetric | ||||
| from .image_inpainting_metric import ImageInpaintingMetric | from .image_inpainting_metric import ImageInpaintingMetric | ||||
| from .referring_video_object_segmentation_metric import ReferringVideoObjectSegmentationMetric | |||||
| else: | else: | ||||
| _import_structure = { | _import_structure = { | ||||
| @@ -40,6 +41,8 @@ else: | |||||
| 'image_inpainting_metric': ['ImageInpaintingMetric'], | 'image_inpainting_metric': ['ImageInpaintingMetric'], | ||||
| 'accuracy_metric': ['AccuracyMetric'], | 'accuracy_metric': ['AccuracyMetric'], | ||||
| 'bleu_metric': ['BleuMetric'], | 'bleu_metric': ['BleuMetric'], | ||||
| 'referring_video_object_segmentation_metric': | |||||
| ['ReferringVideoObjectSegmentationMetric'], | |||||
| } | } | ||||
| import sys | import sys | ||||
| @@ -43,6 +43,8 @@ task_default_metrics = { | |||||
| Tasks.visual_question_answering: [Metrics.text_gen_metric], | Tasks.visual_question_answering: [Metrics.text_gen_metric], | ||||
| Tasks.movie_scene_segmentation: [Metrics.movie_scene_segmentation_metric], | Tasks.movie_scene_segmentation: [Metrics.movie_scene_segmentation_metric], | ||||
| Tasks.image_inpainting: [Metrics.image_inpainting_metric], | Tasks.image_inpainting: [Metrics.image_inpainting_metric], | ||||
| Tasks.referring_video_object_segmentation: | |||||
| [Metrics.referring_video_object_segmentation_metric], | |||||
| } | } | ||||
| @@ -0,0 +1,108 @@ | |||||
| # Part of the implementation is borrowed and modified from MTTR, | |||||
| # publicly available at https://github.com/mttr2021/MTTR | |||||
| from typing import Dict | |||||
| import numpy as np | |||||
| import torch | |||||
| from pycocotools.coco import COCO | |||||
| from pycocotools.cocoeval import COCOeval | |||||
| from pycocotools.mask import decode | |||||
| from tqdm import tqdm | |||||
| from modelscope.metainfo import Metrics | |||||
| from modelscope.utils.registry import default_group | |||||
| from .base import Metric | |||||
| from .builder import METRICS, MetricKeys | |||||
| @METRICS.register_module( | |||||
| group_key=default_group, | |||||
| module_name=Metrics.referring_video_object_segmentation_metric) | |||||
| class ReferringVideoObjectSegmentationMetric(Metric): | |||||
| """The metric computation class for movie scene segmentation classes. | |||||
| """ | |||||
| def __init__(self, | |||||
| ann_file=None, | |||||
| calculate_precision_and_iou_metrics=True): | |||||
| self.ann_file = ann_file | |||||
| self.calculate_precision_and_iou_metrics = calculate_precision_and_iou_metrics | |||||
| self.preds = [] | |||||
| def add(self, outputs: Dict, inputs: Dict): | |||||
| preds_batch = outputs['pred'] | |||||
| self.preds.extend(preds_batch) | |||||
| def evaluate(self): | |||||
| coco_gt = COCO(self.ann_file) | |||||
| coco_pred = coco_gt.loadRes(self.preds) | |||||
| coco_eval = COCOeval(coco_gt, coco_pred, iouType='segm') | |||||
| coco_eval.params.useCats = 0 | |||||
| coco_eval.evaluate() | |||||
| coco_eval.accumulate() | |||||
| coco_eval.summarize() | |||||
| ap_labels = [ | |||||
| 'mAP 0.5:0.95', 'AP 0.5', 'AP 0.75', 'AP 0.5:0.95 S', | |||||
| 'AP 0.5:0.95 M', 'AP 0.5:0.95 L' | |||||
| ] | |||||
| ap_metrics = coco_eval.stats[:6] | |||||
| eval_metrics = {la: m for la, m in zip(ap_labels, ap_metrics)} | |||||
| if self.calculate_precision_and_iou_metrics: | |||||
| precision_at_k, overall_iou, mean_iou = calculate_precision_at_k_and_iou_metrics( | |||||
| coco_gt, coco_pred) | |||||
| eval_metrics.update({ | |||||
| f'P@{k}': m | |||||
| for k, m in zip([0.5, 0.6, 0.7, 0.8, 0.9], precision_at_k) | |||||
| }) | |||||
| eval_metrics.update({ | |||||
| 'overall_iou': overall_iou, | |||||
| 'mean_iou': mean_iou | |||||
| }) | |||||
| return eval_metrics | |||||
| def compute_iou(outputs: torch.Tensor, labels: torch.Tensor, EPS=1e-6): | |||||
| outputs = outputs.int() | |||||
| intersection = (outputs & labels).float().sum( | |||||
| (1, 2)) # Will be zero if Truth=0 or Prediction=0 | |||||
| union = (outputs | labels).float().sum( | |||||
| (1, 2)) # Will be zero if both are 0 | |||||
| iou = (intersection + EPS) / (union + EPS | |||||
| ) # EPS is used to avoid division by zero | |||||
| return iou, intersection, union | |||||
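
Reviewer note: the IoU helper above can be sanity-checked on toy masks. The snippet below is illustrative only (shapes and values are made up) and assumes compute_iou is in scope:

    import torch

    pred = torch.zeros(1, 4, 4, dtype=torch.int64)
    pred[0, :2, :2] = 1      # predicted 2x2 square, area 4
    gt = torch.zeros(1, 4, 4, dtype=torch.int64)
    gt[0, 1:3, 1:3] = 1      # ground-truth 2x2 square shifted by one pixel, area 4

    iou, intersection, union = compute_iou(pred, gt)
    # intersection = 1, union = 7, so iou ≈ (1 + EPS) / (7 + EPS) ≈ 0.143
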
| def calculate_precision_at_k_and_iou_metrics(coco_gt: COCO, coco_pred: COCO): | |||||
| print('evaluating precision@k & iou metrics...') | |||||
| counters_by_iou = {iou: 0 for iou in [0.5, 0.6, 0.7, 0.8, 0.9]} | |||||
| total_intersection_area = 0 | |||||
| total_union_area = 0 | |||||
| ious_list = [] | |||||
| for instance in tqdm(coco_gt.imgs.keys() | |||||
| ): # each image_id contains exactly one instance | |||||
| gt_annot = coco_gt.imgToAnns[instance][0] | |||||
| gt_mask = decode(gt_annot['segmentation']) | |||||
| pred_annots = coco_pred.imgToAnns[instance] | |||||
| pred_annot = sorted( | |||||
| pred_annots, | |||||
| key=lambda a: a['score'])[-1] # choose pred with highest score | |||||
| pred_mask = decode(pred_annot['segmentation']) | |||||
| iou, intersection, union = compute_iou( | |||||
| torch.tensor(pred_mask).unsqueeze(0), | |||||
| torch.tensor(gt_mask).unsqueeze(0)) | |||||
| iou, intersection, union = iou.item(), intersection.item(), union.item( | |||||
| ) | |||||
| for iou_threshold in counters_by_iou.keys(): | |||||
| if iou > iou_threshold: | |||||
| counters_by_iou[iou_threshold] += 1 | |||||
| total_intersection_area += intersection | |||||
| total_union_area += union | |||||
| ious_list.append(iou) | |||||
| num_samples = len(ious_list) | |||||
| precision_at_k = np.array(list(counters_by_iou.values())) / num_samples | |||||
| overall_iou = total_intersection_area / total_union_area | |||||
| mean_iou = np.mean(ious_list) | |||||
| return precision_at_k, overall_iou, mean_iou | |||||
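
Reviewer note: a tiny illustration of the precision@K bookkeeping above, with made-up per-sample IoUs (not taken from any real run):

    import numpy as np

    ious = [0.55, 0.72, 0.91]
    thresholds = [0.5, 0.6, 0.7, 0.8, 0.9]
    precision_at_k = np.array([sum(iou > t for iou in ious) for t in thresholds]) / len(ious)
    # -> [1.0, 0.667, 0.667, 0.333, 0.333], i.e. P@0.5 = 3/3, P@0.7 = 2/3, P@0.9 = 1/3
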
| @@ -5,11 +5,11 @@ from modelscope.utils.import_utils import LazyImportModule | |||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||
| from .model import MovieSceneSegmentation | |||||
| from .model import ReferringVideoObjectSegmentation | |||||
| else: | else: | ||||
| _import_structure = { | _import_structure = { | ||||
| 'model': ['MovieSceneSegmentation'], | |||||
| 'model': ['ReferringVideoObjectSegmentation'], | |||||
| } | } | ||||
| import sys | import sys | ||||
| @@ -1,4 +1,6 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| # Part of the implementation is borrowed and modified from MTTR, | |||||
| # publicly available at https://github.com/mttr2021/MTTR | |||||
| import os.path as osp | import os.path as osp | ||||
| from typing import Any, Dict | from typing import Any, Dict | ||||
| @@ -10,7 +12,9 @@ from modelscope.models.builder import MODELS | |||||
| from modelscope.utils.config import Config | from modelscope.utils.config import Config | ||||
| from modelscope.utils.constant import ModelFile, Tasks | from modelscope.utils.constant import ModelFile, Tasks | ||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from .utils import (MTTR, A2DSentencesPostProcess, ReferYoutubeVOSPostProcess, | |||||
| from .utils import (MTTR, A2DSentencesPostProcess, HungarianMatcher, | |||||
| ReferYoutubeVOSPostProcess, SetCriterion, | |||||
| flatten_temporal_batch_dims, | |||||
| nested_tensor_from_videos_list) | nested_tensor_from_videos_list) | ||||
| logger = get_logger() | logger = get_logger() | ||||
| @@ -35,16 +39,66 @@ class ReferringVideoObjectSegmentation(TorchModel): | |||||
| params_dict = params_dict['model_state_dict'] | params_dict = params_dict['model_state_dict'] | ||||
| self.model.load_state_dict(params_dict, strict=True) | self.model.load_state_dict(params_dict, strict=True) | ||||
| dataset_name = self.cfg.pipeline.dataset_name | |||||
| if dataset_name == 'a2d_sentences' or dataset_name == 'jhmdb_sentences': | |||||
| self.postprocessor = A2DSentencesPostProcess() | |||||
| elif dataset_name == 'ref_youtube_vos': | |||||
| self.postprocessor = ReferYoutubeVOSPostProcess() | |||||
| self.set_postprocessor(self.cfg.pipeline.dataset_name) | |||||
| self.set_criterion() | |||||
| def set_device(self, device, name): | |||||
| self.device = device | |||||
| self._device_name = name | |||||
| def set_postprocessor(self, dataset_name): | |||||
| if 'a2d_sentences' in dataset_name or 'jhmdb_sentences' in dataset_name: | |||||
| self.postprocessor = A2DSentencesPostProcess() # fine-tune | |||||
| elif 'ref_youtube_vos' in dataset_name: | |||||
| self.postprocessor = ReferYoutubeVOSPostProcess() # inference | |||||
| else: | else: | ||||
| assert False, f'postprocessing for dataset: {dataset_name} is not supported' | assert False, f'postprocessing for dataset: {dataset_name} is not supported' | ||||
| def forward(self, inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]: | |||||
| return inputs | |||||
| def forward(self, inputs: Dict[str, Any]): | |||||
| samples = inputs['samples'] | |||||
| targets = inputs['targets'] | |||||
| text_queries = inputs['text_queries'] | |||||
| valid_indices = torch.tensor( | |||||
| [i for i, t in enumerate(targets) if None not in t]) | |||||
| targets = [targets[i] for i in valid_indices.tolist()] | |||||
| if self._device_name == 'gpu': | |||||
| samples = samples.to(self.device) | |||||
| valid_indices = valid_indices.to(self.device) | |||||
| if isinstance(text_queries, tuple): | |||||
| text_queries = list(text_queries) | |||||
| outputs = self.model(samples, valid_indices, text_queries) | |||||
| losses = -1 | |||||
| if self.training: | |||||
| loss_dict = self.criterion(outputs, targets) | |||||
| weight_dict = self.criterion.weight_dict | |||||
| losses = sum(loss_dict[k] * weight_dict[k] | |||||
| for k in loss_dict.keys() if k in weight_dict) | |||||
| predictions = [] | |||||
| if not self.training: | |||||
| outputs.pop('aux_outputs', None) | |||||
| outputs, targets = flatten_temporal_batch_dims(outputs, targets) | |||||
| processed_outputs = self.postprocessor( | |||||
| outputs, | |||||
| resized_padded_sample_size=samples.tensors.shape[-2:], | |||||
| resized_sample_sizes=[t['size'] for t in targets], | |||||
| orig_sample_sizes=[t['orig_size'] for t in targets]) | |||||
| image_ids = [t['image_id'] for t in targets] | |||||
| predictions = [] | |||||
| for p, image_id in zip(processed_outputs, image_ids): | |||||
| for s, m in zip(p['scores'], p['rle_masks']): | |||||
| predictions.append({ | |||||
| 'image_id': image_id, | |||||
| 'category_id': | |||||
| 1, # dummy label, as categories are not predicted in ref-vos | |||||
| 'segmentation': m, | |||||
| 'score': s.item() | |||||
| }) | |||||
| re = dict(pred=predictions, loss=losses) | |||||
| return re | |||||
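
Reviewer note: each element appended to `predictions` is a COCO-style result record that the metric later feeds to `COCO.loadRes`; an example record looks roughly like this (all values illustrative):

    example_prediction = {
        'image_id': 'v_<video_id>_f_<frame_idx>_i_<instance_id>',   # built by the dataset
        'category_id': 1,                                           # dummy label
        'segmentation': {'size': [320, 320], 'counts': '...'},      # RLE mask from pycocotools
        'score': 0.87,
    }
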
| def inference(self, **kwargs): | def inference(self, **kwargs): | ||||
| window = kwargs['window'] | window = kwargs['window'] | ||||
| @@ -63,3 +117,26 @@ class ReferringVideoObjectSegmentation(TorchModel): | |||||
| def postprocess(self, inputs: Dict[str, Any], **kwargs): | def postprocess(self, inputs: Dict[str, Any], **kwargs): | ||||
| return inputs | return inputs | ||||
| def set_criterion(self): | |||||
| matcher = HungarianMatcher( | |||||
| cost_is_referred=self.cfg.matcher.set_cost_is_referred, | |||||
| cost_dice=self.cfg.matcher.set_cost_dice) | |||||
| weight_dict = { | |||||
| 'loss_is_referred': self.cfg.loss.is_referred_loss_coef, | |||||
| 'loss_dice': self.cfg.loss.dice_loss_coef, | |||||
| 'loss_sigmoid_focal': self.cfg.loss.sigmoid_focal_loss_coef | |||||
| } | |||||
| if self.cfg.loss.aux_loss: | |||||
| aux_weight_dict = {} | |||||
| for i in range(self.cfg.model.num_decoder_layers - 1): | |||||
| aux_weight_dict.update( | |||||
| {k + f'_{i}': v | |||||
| for k, v in weight_dict.items()}) | |||||
| weight_dict.update(aux_weight_dict) | |||||
| self.criterion = SetCriterion( | |||||
| matcher=matcher, | |||||
| weight_dict=weight_dict, | |||||
| eos_coef=self.cfg.loss.eos_coef) | |||||
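
Reviewer note: to make the aux-loss bookkeeping concrete, here is a sketch of the weight-dict expansion with made-up coefficients and num_decoder_layers = 3 (the real values come from the config):

    weight_dict = {'loss_is_referred': 2.0, 'loss_dice': 5.0, 'loss_sigmoid_focal': 2.0}
    aux_weight_dict = {f'{k}_{i}': v
                       for i in range(3 - 1)          # one copy per intermediate decoder layer
                       for k, v in weight_dict.items()}
    weight_dict.update(aux_weight_dict)
    # keys now include loss_dice_0, loss_dice_1, loss_is_referred_0, ... with the same weights
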
| @@ -1,4 +1,6 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| from .misc import nested_tensor_from_videos_list | |||||
| from .criterion import SetCriterion, flatten_temporal_batch_dims | |||||
| from .matcher import HungarianMatcher | |||||
| from .misc import interpolate, nested_tensor_from_videos_list | |||||
| from .mttr import MTTR | from .mttr import MTTR | ||||
| from .postprocessing import A2DSentencesPostProcess, ReferYoutubeVOSPostProcess | from .postprocessing import A2DSentencesPostProcess, ReferYoutubeVOSPostProcess | ||||
| @@ -0,0 +1,198 @@ | |||||
| # The implementation is adopted from MTTR, | |||||
| # made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR | |||||
| # Modified from DETR https://github.com/facebookresearch/detr | |||||
| import torch | |||||
| from torch import nn | |||||
| from .misc import (get_world_size, interpolate, is_dist_avail_and_initialized, | |||||
| nested_tensor_from_tensor_list) | |||||
| from .segmentation import dice_loss, sigmoid_focal_loss | |||||
| class SetCriterion(nn.Module): | |||||
| """ This class computes the loss for MTTR. | |||||
| The process happens in two steps: | |||||
| 1) we compute the hungarian assignment between the ground-truth and predicted sequences. | |||||
| 2) we supervise each pair of matched ground-truth / prediction sequences (mask + reference prediction) | |||||
| """ | |||||
| def __init__(self, matcher, weight_dict, eos_coef): | |||||
| """ Create the criterion. | |||||
| Parameters: | |||||
| matcher: module able to compute a matching between targets and proposals | |||||
| weight_dict: dict containing as key the names of the losses and as values their relative weight. | |||||
| eos_coef: relative classification weight applied to the un-referred category | |||||
| """ | |||||
| super().__init__() | |||||
| self.matcher = matcher | |||||
| self.weight_dict = weight_dict | |||||
| self.eos_coef = eos_coef | |||||
| # make sure that only loss functions with non-zero weights are computed: | |||||
| losses_to_compute = [] | |||||
| if weight_dict['loss_dice'] > 0 or weight_dict[ | |||||
| 'loss_sigmoid_focal'] > 0: | |||||
| losses_to_compute.append('masks') | |||||
| if weight_dict['loss_is_referred'] > 0: | |||||
| losses_to_compute.append('is_referred') | |||||
| self.losses = losses_to_compute | |||||
| def forward(self, outputs, targets): | |||||
| aux_outputs_list = outputs.pop('aux_outputs', None) | |||||
| # compute the losses for the output of the last decoder layer: | |||||
| losses = self.compute_criterion( | |||||
| outputs, targets, losses_to_compute=self.losses) | |||||
| # In case of auxiliary losses, we repeat this process with the output of each intermediate decoder layer. | |||||
| if aux_outputs_list is not None: | |||||
| aux_losses_to_compute = self.losses.copy() | |||||
| for i, aux_outputs in enumerate(aux_outputs_list): | |||||
| losses_dict = self.compute_criterion(aux_outputs, targets, | |||||
| aux_losses_to_compute) | |||||
| losses_dict = {k + f'_{i}': v for k, v in losses_dict.items()} | |||||
| losses.update(losses_dict) | |||||
| return losses | |||||
| def compute_criterion(self, outputs, targets, losses_to_compute): | |||||
| # Retrieve the matching between the outputs of the last layer and the targets | |||||
| indices = self.matcher(outputs, targets) | |||||
| # T & B dims are flattened so loss functions can be computed per frame (but with same indices per video). | |||||
| # also, indices are repeated so the same indices can be used for frames of the same video. | |||||
| T = len(targets) | |||||
| outputs, targets = flatten_temporal_batch_dims(outputs, targets) | |||||
| # repeat the indices list T times so the same indices can be used for each video frame | |||||
| indices = T * indices | |||||
| # Compute the average number of target masks across all nodes, for normalization purposes | |||||
| num_masks = sum(len(t['masks']) for t in targets) | |||||
| num_masks = torch.as_tensor([num_masks], | |||||
| dtype=torch.float, | |||||
| device=indices[0][0].device) | |||||
| if is_dist_avail_and_initialized(): | |||||
| torch.distributed.all_reduce(num_masks) | |||||
| num_masks = torch.clamp(num_masks / get_world_size(), min=1).item() | |||||
| # Compute all the requested losses | |||||
| losses = {} | |||||
| for loss in losses_to_compute: | |||||
| losses.update( | |||||
| self.get_loss( | |||||
| loss, outputs, targets, indices, num_masks=num_masks)) | |||||
| return losses | |||||
| def loss_is_referred(self, outputs, targets, indices, **kwargs): | |||||
| device = outputs['pred_is_referred'].device | |||||
| bs = outputs['pred_is_referred'].shape[0] | |||||
| pred_is_referred = outputs['pred_is_referred'].log_softmax( | |||||
| dim=-1) # note that log-softmax is used here | |||||
| target_is_referred = torch.zeros_like(pred_is_referred) | |||||
| # extract indices of object queries that were matched with text-referred target objects | |||||
| query_referred_indices = self._get_query_referred_indices( | |||||
| indices, targets) | |||||
| # by default penalize compared to the no-object class (last token) | |||||
| target_is_referred[:, :, :] = torch.tensor([0.0, 1.0], device=device) | |||||
| if 'is_ref_inst_visible' in targets[ | |||||
| 0]: # visibility labels are available per-frame for the referred object: | |||||
| is_ref_inst_visible_per_frame = torch.stack( | |||||
| [t['is_ref_inst_visible'] for t in targets]) | |||||
| ref_inst_visible_frame_indices = is_ref_inst_visible_per_frame.nonzero( | |||||
| ).squeeze() | |||||
| # keep only the matched query indices of the frames in which the referred object is visible: | |||||
| visible_query_referred_indices = query_referred_indices[ | |||||
| ref_inst_visible_frame_indices] | |||||
| target_is_referred[ref_inst_visible_frame_indices, | |||||
| visible_query_referred_indices] = torch.tensor( | |||||
| [1.0, 0.0], device=device) | |||||
| else: # assume that the referred object is visible in every frame: | |||||
| target_is_referred[torch.arange(bs), | |||||
| query_referred_indices] = torch.tensor( | |||||
| [1.0, 0.0], device=device) | |||||
| loss = -(pred_is_referred * target_is_referred).sum(-1) | |||||
| # apply no-object class weights: | |||||
| eos_coef = torch.full(loss.shape, self.eos_coef, device=loss.device) | |||||
| eos_coef[torch.arange(bs), query_referred_indices] = 1.0 | |||||
| loss = loss * eos_coef | |||||
| bs = len(indices) | |||||
| loss = loss.sum() / bs # sum and normalize the loss by the batch size | |||||
| losses = {'loss_is_referred': loss} | |||||
| return losses | |||||
| def loss_masks(self, outputs, targets, indices, num_masks, **kwargs): | |||||
| assert 'pred_masks' in outputs | |||||
| src_idx = self._get_src_permutation_idx(indices) | |||||
| tgt_idx = self._get_tgt_permutation_idx(indices) | |||||
| src_masks = outputs['pred_masks'] | |||||
| src_masks = src_masks[src_idx] | |||||
| masks = [t['masks'] for t in targets] | |||||
| target_masks, valid = nested_tensor_from_tensor_list(masks).decompose() | |||||
| target_masks = target_masks.to(src_masks) | |||||
| target_masks = target_masks[tgt_idx] | |||||
| # upsample predictions to the target size | |||||
| src_masks = interpolate( | |||||
| src_masks[:, None], | |||||
| size=target_masks.shape[-2:], | |||||
| mode='bilinear', | |||||
| align_corners=False) | |||||
| src_masks = src_masks[:, 0].flatten(1) | |||||
| target_masks = target_masks.flatten(1) | |||||
| target_masks = target_masks.view(src_masks.shape) | |||||
| losses = { | |||||
| 'loss_sigmoid_focal': | |||||
| sigmoid_focal_loss(src_masks, target_masks, num_masks), | |||||
| 'loss_dice': | |||||
| dice_loss(src_masks, target_masks, num_masks), | |||||
| } | |||||
| return losses | |||||
| @staticmethod | |||||
| def _get_src_permutation_idx(indices): | |||||
| # permute predictions following indices | |||||
| batch_idx = torch.cat( | |||||
| [torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) | |||||
| src_idx = torch.cat([src for (src, _) in indices]) | |||||
| return batch_idx, src_idx | |||||
| @staticmethod | |||||
| def _get_tgt_permutation_idx(indices): | |||||
| # permute targets following indices | |||||
| batch_idx = torch.cat( | |||||
| [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) | |||||
| tgt_idx = torch.cat([tgt for (_, tgt) in indices]) | |||||
| return batch_idx, tgt_idx | |||||
| @staticmethod | |||||
| def _get_query_referred_indices(indices, targets): | |||||
| """ | |||||
| extract indices of object queries that were matched with text-referred target objects | |||||
| """ | |||||
| query_referred_indices = [] | |||||
| for (query_idxs, target_idxs), target in zip(indices, targets): | |||||
| ref_query_idx = query_idxs[torch.where( | |||||
| target_idxs == target['referred_instance_idx'])[0]] | |||||
| query_referred_indices.append(ref_query_idx) | |||||
| query_referred_indices = torch.cat(query_referred_indices) | |||||
| return query_referred_indices | |||||
| def get_loss(self, loss, outputs, targets, indices, **kwargs): | |||||
| loss_map = { | |||||
| 'masks': self.loss_masks, | |||||
| 'is_referred': self.loss_is_referred, | |||||
| } | |||||
| assert loss in loss_map, f'do you really want to compute {loss} loss?' | |||||
| return loss_map[loss](outputs, targets, indices, **kwargs) | |||||
| def flatten_temporal_batch_dims(outputs, targets): | |||||
| for k in outputs.keys(): | |||||
| if isinstance(outputs[k], torch.Tensor): | |||||
| outputs[k] = outputs[k].flatten(0, 1) | |||||
| else: # list | |||||
| outputs[k] = [i for step_t in outputs[k] for i in step_t] | |||||
| targets = [ | |||||
| frame_t_target for step_t in targets for frame_t_target in step_t | |||||
| ] | |||||
| return outputs, targets | |||||
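
Reviewer note: a toy illustration (made-up shapes) of the layout change performed by flatten_temporal_batch_dims: tensors go from [T, B, ...] to [T*B, ...] and the nested per-time-step target lists are flattened into one list of per-frame dicts.

    import torch

    outputs = {'pred_is_referred': torch.zeros(2, 3, 5, 2)}   # T=2, B=3, 5 queries
    targets = [[{'t': t, 'b': b} for b in range(3)] for t in range(2)]

    for k in outputs:
        outputs[k] = outputs[k].flatten(0, 1)                             # -> shape [6, 5, 2]
    targets = [frame_t for step_t in targets for frame_t in step_t]       # length 6
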
| @@ -0,0 +1,163 @@ | |||||
| # The implementation is adopted from MTTR, | |||||
| # made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR | |||||
| # Modified from DETR https://github.com/facebookresearch/detr | |||||
| # Module to compute the matching cost and solve the corresponding LSAP. | |||||
| import torch | |||||
| from scipy.optimize import linear_sum_assignment | |||||
| from torch import nn | |||||
| from .misc import interpolate, nested_tensor_from_tensor_list | |||||
| class HungarianMatcher(nn.Module): | |||||
| """This class computes an assignment between the targets and the predictions of the network | |||||
| For efficiency reasons, the targets don't include the no_object. Because of this, in general, | |||||
| there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, | |||||
| while the others are un-matched (and thus treated as non-objects). | |||||
| """ | |||||
| def __init__(self, cost_is_referred: float = 1, cost_dice: float = 1): | |||||
| """Creates the matcher | |||||
| Params: | |||||
| cost_is_referred: This is the relative weight of the reference cost in the total matching cost | |||||
| cost_dice: This is the relative weight of the dice cost in the total matching cost | |||||
| """ | |||||
| super().__init__() | |||||
| self.cost_is_referred = cost_is_referred | |||||
| self.cost_dice = cost_dice | |||||
| assert cost_is_referred != 0 or cost_dice != 0, "all costs can't be 0" | |||||
| @torch.inference_mode() | |||||
| def forward(self, outputs, targets): | |||||
| """ Performs the matching | |||||
| Params: | |||||
| outputs: A dict that contains at least these entries: | |||||
| "pred_is_referred": Tensor of dim [time, batch_size, num_queries, 2] with the reference logits | |||||
| "pred_masks": Tensor of dim [time, batch_size, num_queries, H, W] with the predicted masks logits | |||||
| targets: A list of lists of targets (outer - time steps, inner - batch samples). Each target is a dict | |||||
| which contains mask and reference ground-truth information for a single frame. | |||||
| Returns: | |||||
| A list of size batch_size, containing tuples of (index_i, index_j) where: | |||||
| - index_i is the indices of the selected predictions (in order) | |||||
| - index_j is the indices of the corresponding selected targets (in order) | |||||
| For each batch element, it holds: | |||||
| len(index_i) = len(index_j) = min(num_queries, num_target_masks) | |||||
| """ | |||||
| t, bs, num_queries = outputs['pred_masks'].shape[:3] | |||||
| # We flatten to compute the cost matrices in a batch | |||||
| out_masks = outputs['pred_masks'].flatten( | |||||
| 1, 2) # [t, batch_size * num_queries, mask_h, mask_w] | |||||
| # preprocess and concat the target masks | |||||
| tgt_masks = [[ | |||||
| m for v in t_step_batch for m in v['masks'].unsqueeze(1) | |||||
| ] for t_step_batch in targets] | |||||
| # pad the target masks to a uniform shape | |||||
| tgt_masks, valid = list( | |||||
| zip(*[ | |||||
| nested_tensor_from_tensor_list(t).decompose() | |||||
| for t in tgt_masks | |||||
| ])) | |||||
| tgt_masks = torch.stack(tgt_masks).squeeze(2) | |||||
| # upsample predicted masks to target mask size | |||||
| out_masks = interpolate( | |||||
| out_masks, | |||||
| size=tgt_masks.shape[-2:], | |||||
| mode='bilinear', | |||||
| align_corners=False) | |||||
| # Compute the soft-tokens cost: | |||||
| if self.cost_is_referred > 0: | |||||
| cost_is_referred = compute_is_referred_cost(outputs, targets) | |||||
| else: | |||||
| cost_is_referred = 0 | |||||
| # Compute the DICE coefficient between the masks: | |||||
| if self.cost_dice > 0: | |||||
| cost_dice = -dice_coef(out_masks, tgt_masks) | |||||
| else: | |||||
| cost_dice = 0 | |||||
| # Final cost matrix | |||||
| C = self.cost_is_referred * cost_is_referred + self.cost_dice * cost_dice | |||||
| C = C.view(bs, num_queries, -1).cpu() | |||||
| num_traj_per_batch = [ | |||||
| len(v['masks']) for v in targets[0] | |||||
| ] # number of instance trajectories in each batch | |||||
| indices = [ | |||||
| linear_sum_assignment(c[i]) | |||||
| for i, c in enumerate(C.split(num_traj_per_batch, -1)) | |||||
| ] | |||||
| device = out_masks.device | |||||
| return [(torch.as_tensor(i, dtype=torch.int64, device=device), | |||||
| torch.as_tensor(j, dtype=torch.int64, device=device)) | |||||
| for i, j in indices] | |||||
| def dice_coef(inputs, targets, smooth=1.0): | |||||
| """ | |||||
| Compute the DICE coefficient, similar to generalized IOU for masks | |||||
| Args: | |||||
| inputs: A float tensor of arbitrary shape. | |||||
| The predictions for each example. | |||||
| targets: A float tensor with the same shape as inputs. Stores the binary | |||||
| classification label for each element in inputs | |||||
| (0 for the negative class and 1 for the positive class). | |||||
| """ | |||||
| inputs = inputs.sigmoid().flatten(2).unsqueeze(2) | |||||
| targets = targets.flatten(2).unsqueeze(1) | |||||
| numerator = 2 * (inputs * targets).sum(-1) | |||||
| denominator = inputs.sum(-1) + targets.sum(-1) | |||||
| coef = (numerator + smooth) / (denominator + smooth) | |||||
| coef = coef.mean( | |||||
| 0) # average on the temporal dim to get instance trajectory scores | |||||
| return coef | |||||
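
Reviewer note: a quick numeric check of the soft Dice above (illustrative shapes; assumes dice_coef is in scope). With saturated logits the sigmoid is ~0/1, so the value reduces to the usual 2|A∩B| / (|A| + |B|) up to the smoothing term.

    import torch

    pred_logits = torch.full((1, 1, 1, 4), -20.0)   # [t, num_preds, H, W]
    pred_logits[..., :2] = 20.0                     # prediction covers pixels 0 and 1
    gt = torch.zeros((1, 1, 1, 4))                  # [t, num_targets, H, W]
    gt[..., 1:3] = 1.0                              # target covers pixels 1 and 2

    print(dice_coef(pred_logits, gt))               # ≈ (2*1 + 1) / (2 + 2 + 1) = 0.6
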
| def compute_is_referred_cost(outputs, targets): | |||||
| pred_is_referred = outputs['pred_is_referred'].flatten(1, 2).softmax( | |||||
| dim=-1) # [t, b*nq, 2] | |||||
| device = pred_is_referred.device | |||||
| t = pred_is_referred.shape[0] | |||||
| # number of instance trajectories in each batch | |||||
| num_traj_per_batch = torch.tensor([len(v['masks']) for v in targets[0]], | |||||
| device=device) | |||||
| total_trajectories = num_traj_per_batch.sum() | |||||
| # note that ref_indices are shared across time steps: | |||||
| ref_indices = torch.tensor( | |||||
| [v['referred_instance_idx'] for v in targets[0]], device=device) | |||||
| # convert ref_indices to fit flattened batch targets: | |||||
| ref_indices += torch.cat( | |||||
| (torch.zeros(1, dtype=torch.long, | |||||
| device=device), num_traj_per_batch.cumsum(0)[:-1])) | |||||
| # number of instance trajectories in each batch | |||||
| target_is_referred = torch.zeros((t, total_trajectories, 2), device=device) | |||||
| # 'no object' class by default (for un-referred objects) | |||||
| target_is_referred[:, :, :] = torch.tensor([0.0, 1.0], device=device) | |||||
| if 'is_ref_inst_visible' in targets[0][ | |||||
| 0]: # visibility labels are available per-frame for the referred object: | |||||
| is_ref_inst_visible = torch.stack([ | |||||
| torch.stack([t['is_ref_inst_visible'] for t in t_step]) | |||||
| for t_step in targets | |||||
| ]).permute(1, 0) | |||||
| for ref_idx, is_visible in zip(ref_indices, is_ref_inst_visible): | |||||
| is_visible = is_visible.nonzero().squeeze() | |||||
| target_is_referred[is_visible, | |||||
| ref_idx, :] = torch.tensor([1.0, 0.0], | |||||
| device=device) | |||||
| else: # assume that the referred object is visible in every frame: | |||||
| target_is_referred[:, ref_indices, :] = torch.tensor([1.0, 0.0], | |||||
| device=device) | |||||
| cost_is_referred = -(pred_is_referred.unsqueeze(2) | |||||
| * target_is_referred.unsqueeze(1)).sum(dim=-1).mean( | |||||
| dim=0) | |||||
| return cost_is_referred | |||||
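
Reviewer note: for readers unfamiliar with the final assignment step, this is what linear_sum_assignment does on a single [num_queries, num_target_trajectories] cost block (values made up); queries that receive no match are left out and later treated as non-referred.

    import numpy as np
    from scipy.optimize import linear_sum_assignment

    cost = np.array([[0.9, 0.1],
                     [0.2, 0.8],
                     [0.5, 0.5]])                 # 3 queries, 2 target trajectories
    query_idx, target_idx = linear_sum_assignment(cost)
    # query_idx = [0, 1], target_idx = [1, 0]: query 0 -> trajectory 1, query 1 -> trajectory 0
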
| @@ -122,8 +122,8 @@ class MultimodalTransformer(nn.Module): | |||||
| with torch.inference_mode(mode=self.freeze_text_encoder): | with torch.inference_mode(mode=self.freeze_text_encoder): | ||||
| encoded_text = self.text_encoder(**tokenized_queries) | encoded_text = self.text_encoder(**tokenized_queries) | ||||
| # Transpose memory because pytorch's attention expects sequence first | # Transpose memory because pytorch's attention expects sequence first | ||||
| txt_memory = rearrange(encoded_text.last_hidden_state, | |||||
| 'b s c -> s b c') | |||||
| tmp_last_hidden_state = encoded_text.last_hidden_state.clone() | |||||
| txt_memory = rearrange(tmp_last_hidden_state, 'b s c -> s b c') | |||||
| txt_memory = self.txt_proj( | txt_memory = self.txt_proj( | ||||
| txt_memory) # change text embeddings dim to model dim | txt_memory) # change text embeddings dim to model dim | ||||
| # Invert attention mask that we get from huggingface because its the opposite in pytorch transformer | # Invert attention mask that we get from huggingface because its the opposite in pytorch transformer | ||||
| @@ -123,7 +123,8 @@ class WindowAttention3D(nn.Module): | |||||
| # define a parameter table of relative position bias | # define a parameter table of relative position bias | ||||
| wd, wh, ww = window_size | wd, wh, ww = window_size | ||||
| self.relative_position_bias_table = nn.Parameter( | self.relative_position_bias_table = nn.Parameter( | ||||
| torch.zeros((2 * wd - 1) * (2 * wh - 1) * (2 * ww - 1), num_heads)) | |||||
| torch.zeros((2 * wd - 1) * (2 * wh - 1) * (2 * ww - 1), | |||||
| num_heads)) # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH | |||||
| # get pair-wise relative position index for each token inside the window | # get pair-wise relative position index for each token inside the window | ||||
| coords_d = torch.arange(self.window_size[0]) | coords_d = torch.arange(self.window_size[0]) | ||||
| @@ -13,6 +13,7 @@ if TYPE_CHECKING: | |||||
| from .video_summarization_dataset import VideoSummarizationDataset | from .video_summarization_dataset import VideoSummarizationDataset | ||||
| from .image_inpainting import ImageInpaintingDataset | from .image_inpainting import ImageInpaintingDataset | ||||
| from .text_ranking_dataset import TextRankingDataset | from .text_ranking_dataset import TextRankingDataset | ||||
| from .referring_video_object_segmentation import ReferringVideoObjectSegmentationDataset | |||||
| else: | else: | ||||
| _import_structure = { | _import_structure = { | ||||
| @@ -29,6 +30,8 @@ else: | |||||
| 'sidd_image_denoising_dataset': ['SiddImageDenoisingDataset'], | 'sidd_image_denoising_dataset': ['SiddImageDenoisingDataset'], | ||||
| 'image_portrait_enhancement_dataset': | 'image_portrait_enhancement_dataset': | ||||
| ['ImagePortraitEnhancementDataset'], | ['ImagePortraitEnhancementDataset'], | ||||
| 'referring_video_object_segmentation': | |||||
| ['ReferringVideoObjectSegmentationDataset'], | |||||
| } | } | ||||
| import sys | import sys | ||||
| @@ -0,0 +1,3 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from .referring_video_object_segmentation_dataset import \ | |||||
| ReferringVideoObjectSegmentationDataset | |||||
| @@ -0,0 +1,361 @@ | |||||
| # Part of the implementation is borrowed and modified from MTTR, | |||||
| # publicly available at https://github.com/mttr2021/MTTR | |||||
| from glob import glob | |||||
| from os import path as osp | |||||
| import h5py | |||||
| import json | |||||
| import numpy as np | |||||
| import pandas | |||||
| import torch | |||||
| import torch.distributed as dist | |||||
| import torchvision.transforms.functional as F | |||||
| from pycocotools.mask import area, encode | |||||
| from torchvision.io import read_video | |||||
| from tqdm import tqdm | |||||
| from modelscope.metainfo import Models | |||||
| from modelscope.models.cv.referring_video_object_segmentation.utils import \ | |||||
| nested_tensor_from_videos_list | |||||
| from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS | |||||
| from modelscope.msdatasets.task_datasets.torch_base_dataset import \ | |||||
| TorchTaskDataset | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.logger import get_logger | |||||
| from . import transformers as T | |||||
| LOGGER = get_logger() | |||||
| def get_image_id(video_id, frame_idx, ref_instance_a2d_id): | |||||
| image_id = f'v_{video_id}_f_{frame_idx}_i_{ref_instance_a2d_id}' | |||||
| return image_id | |||||
| @TASK_DATASETS.register_module( | |||||
| Tasks.referring_video_object_segmentation, | |||||
| module_name=Models.referring_video_object_segmentation) | |||||
| class ReferringVideoObjectSegmentationDataset(TorchTaskDataset): | |||||
| def __init__(self, **kwargs): | |||||
| split_config = kwargs['split_config'] | |||||
| LOGGER.info(kwargs) | |||||
| data_cfg = kwargs.get('cfg').data_kwargs | |||||
| trans_cfg = kwargs.get('cfg').transformers_kwargs | |||||
| distributed = data_cfg.get('distributed', False) | |||||
| self.data_root = next(iter(split_config.values())) | |||||
| if not osp.exists(self.data_root): | |||||
| self.data_root = osp.dirname(self.data_root) | |||||
| assert osp.exists(self.data_root) | |||||
| self.window_size = data_cfg.get('window_size', 8) | |||||
| self.mask_annotations_dir = osp.join( | |||||
| self.data_root, 'text_annotations/annotation_with_instances') | |||||
| self.videos_dir = osp.join(self.data_root, 'Release/CLIPS320') | |||||
| self.subset_type = next(iter(split_config.keys())) | |||||
| self.text_annotations = self.get_text_annotations( | |||||
| self.data_root, self.subset_type, distributed) | |||||
| self.transforms = A2dSentencesTransforms(self.subset_type, **trans_cfg) | |||||
| self.collator = Collator() | |||||
| self.ann_file = osp.join( | |||||
| self.data_root, | |||||
| data_cfg.get('ann_file', | |||||
| 'a2d_sentences_test_annotations_in_coco_format.json')) | |||||
| # create ground-truth test annotations for the evaluation process if necessary: | |||||
| if self.subset_type == 'test' and not osp.exists(self.ann_file): | |||||
| if (distributed and dist.get_rank() == 0) or not distributed: | |||||
| create_a2d_sentences_ground_truth_test_annotations( | |||||
| self.data_root, self.subset_type, | |||||
| self.mask_annotations_dir, self.ann_file) | |||||
| if distributed: | |||||
| dist.barrier() | |||||
| def __len__(self): | |||||
| return len(self.text_annotations) | |||||
| def __getitem__(self, idx): | |||||
| text_query, video_id, frame_idx, instance_id = self.text_annotations[ | |||||
| idx] | |||||
| text_query = ' '.join( | |||||
| text_query.lower().split()) # clean up the text query | |||||
| # read the source window frames: | |||||
| video_frames, _, _ = read_video( | |||||
| osp.join(self.videos_dir, f'{video_id}.mp4'), | |||||
| pts_unit='sec') # (T, H, W, C) | |||||
| # get a window of window_size frames with frame frame_idx in the middle. | |||||
| # note that the original a2d dataset is 1 indexed, so we have to subtract 1 from frame_idx | |||||
| start_idx, end_idx = frame_idx - 1 - self.window_size // 2, frame_idx - 1 + ( | |||||
| self.window_size + 1) // 2 | |||||
| # extract the window source frames: | |||||
| source_frames = [] | |||||
| for i in range(start_idx, end_idx): | |||||
| i = min(max(i, 0), | |||||
| len(video_frames) | |||||
| - 1) # pad out of range indices with edge frames | |||||
| source_frames.append( | |||||
| F.to_pil_image(video_frames[i].permute(2, 0, 1))) | |||||
| # read the instance mask: | |||||
| frame_annot_path = osp.join(self.mask_annotations_dir, video_id, | |||||
| f'{frame_idx:05d}.h5') | |||||
| f = h5py.File(frame_annot_path, 'r') | |||||
| instances = list(f['instance']) | |||||
| instance_idx = instances.index( | |||||
| instance_id) # existence was already validated during init | |||||
| instance_masks = np.array(f['reMask']) | |||||
| if len(instances) == 1: | |||||
| instance_masks = instance_masks[np.newaxis, ...] | |||||
| instance_masks = torch.tensor(instance_masks).transpose(1, 2) | |||||
| mask_rles = [encode(mask) for mask in instance_masks.numpy()] | |||||
| mask_areas = area(mask_rles).astype(float)  # np.float is removed in recent NumPy | |||||
| f.close() | |||||
| # create the target dict for the center frame: | |||||
| target = { | |||||
| 'masks': instance_masks, | |||||
| 'orig_size': instance_masks. | |||||
| shape[-2:], # original frame shape without any augmentations | |||||
| # size with augmentations, will be changed inside transforms if necessary | |||||
| 'size': instance_masks.shape[-2:], | |||||
| 'referred_instance_idx': torch.tensor( | |||||
| instance_idx), # idx in 'masks' of the text referred instance | |||||
| 'area': torch.tensor(mask_areas), | |||||
| 'iscrowd': | |||||
| torch.zeros(len(instance_masks) | |||||
| ), # for compatibility with DETR COCO transforms | |||||
| 'image_id': get_image_id(video_id, frame_idx, instance_id) | |||||
| } | |||||
| # create dummy targets for adjacent frames: | |||||
| targets = self.window_size * [None] | |||||
| center_frame_idx = self.window_size // 2 | |||||
| targets[center_frame_idx] = target | |||||
| source_frames, targets, text_query = self.transforms( | |||||
| source_frames, targets, text_query) | |||||
| return source_frames, targets, text_query | |||||
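
Reviewer note: the window indexing above can be checked by hand; with window_size = 8 and a 1-indexed center frame_idx = 5 (values made up), frames 0..7 of the 0-indexed video are read and the annotated frame lands at position window_size // 2:

    window_size, frame_idx = 8, 5
    start_idx = frame_idx - 1 - window_size // 2          # 0
    end_idx = frame_idx - 1 + (window_size + 1) // 2      # 8
    assert list(range(start_idx, end_idx)) == [0, 1, 2, 3, 4, 5, 6, 7]
    assert start_idx + window_size // 2 == frame_idx - 1  # center of the window
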
| @staticmethod | |||||
| def get_text_annotations(root_path, subset, distributed): | |||||
| saved_annotations_file_path = osp.join( | |||||
| root_path, f'sentences_single_frame_{subset}_annotations.json') | |||||
| if osp.exists(saved_annotations_file_path): | |||||
| with open(saved_annotations_file_path, 'r') as f: | |||||
| text_annotations_by_frame = [tuple(a) for a in json.load(f)] | |||||
| return text_annotations_by_frame | |||||
| elif (distributed and dist.get_rank() == 0) or not distributed: | |||||
| print(f'building a2d sentences {subset} text annotations...') | |||||
| # without 'header=None' pandas would treat the first sample as the header row | |||||
| a2d_data_info = pandas.read_csv( | |||||
| osp.join(root_path, 'Release/videoset.csv'), header=None) | |||||
| # 'vid', 'label', 'start_time', 'end_time', 'height', 'width', 'total_frames', 'annotated_frames', 'subset' | |||||
| a2d_data_info.columns = [ | |||||
| 'vid', '', '', '', '', '', '', '', 'subset' | |||||
| ] | |||||
| with open( | |||||
| osp.join(root_path, 'text_annotations/missed_videos.txt'), | |||||
| 'r') as f: | |||||
| unused_videos = f.read().splitlines() | |||||
| subsets = {'train': 0, 'test': 1} | |||||
| # filter unused videos and videos which do not belong to our train/test subset: | |||||
| used_videos = a2d_data_info[ | |||||
| ~a2d_data_info.vid.isin(unused_videos) | |||||
| & (a2d_data_info.subset == subsets[subset])] | |||||
| used_videos_ids = list(used_videos['vid']) | |||||
| text_annotations = pandas.read_csv( | |||||
| osp.join(root_path, 'text_annotations/annotation.txt')) | |||||
| # filter the text annotations based on the used videos: | |||||
| used_text_annotations = text_annotations[ | |||||
| text_annotations.video_id.isin(used_videos_ids)] | |||||
| # remove a single dataset annotation mistake in video: T6bNPuKV-wY | |||||
| used_text_annotations = used_text_annotations[ | |||||
| used_text_annotations['instance_id'] != '1 (copy)'] | |||||
| # convert data-frame to list of tuples: | |||||
| used_text_annotations = list( | |||||
| used_text_annotations.to_records(index=False)) | |||||
| text_annotations_by_frame = [] | |||||
| mask_annotations_dir = osp.join( | |||||
| root_path, 'text_annotations/annotation_with_instances') | |||||
| for video_id, instance_id, text_query in tqdm( | |||||
| used_text_annotations): | |||||
| frame_annot_paths = sorted( | |||||
| glob(osp.join(mask_annotations_dir, video_id, '*.h5'))) | |||||
| instance_id = int(instance_id) | |||||
| for p in frame_annot_paths: | |||||
| f = h5py.File(p, 'r') | |||||
| instances = list(f['instance']) | |||||
| if instance_id in instances: | |||||
| # in case this instance does not appear in this frame it has no ground-truth mask, and thus this | |||||
| # frame-instance pair is ignored in evaluation, same as SOTA method: CMPC-V. check out: | |||||
| # https://github.com/spyflying/CMPC-Refseg/blob/094639b8bf00cc169ea7b49cdf9c87fdfc70d963/CMPC_video/build_A2D_batches.py#L98 | |||||
| frame_idx = int(p.split('/')[-1].split('.')[0]) | |||||
| text_query = text_query.lower( | |||||
| ) # lower the text query prior to augmentation & tokenization | |||||
| text_annotations_by_frame.append( | |||||
| (text_query, video_id, frame_idx, instance_id)) | |||||
| with open(saved_annotations_file_path, 'w') as f: | |||||
| json.dump(text_annotations_by_frame, f) | |||||
| if distributed: | |||||
| dist.barrier() | |||||
| with open(saved_annotations_file_path, 'r') as f: | |||||
| text_annotations_by_frame = [tuple(a) for a in json.load(f)] | |||||
| return text_annotations_by_frame | |||||
| class A2dSentencesTransforms: | |||||
| def __init__(self, subset_type, horizontal_flip_augmentations, | |||||
| resize_and_crop_augmentations, train_short_size, | |||||
| train_max_size, eval_short_size, eval_max_size, **kwargs): | |||||
| self.h_flip_augmentation = subset_type == 'train' and horizontal_flip_augmentations | |||||
| normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) | |||||
| scales = [ | |||||
| train_short_size | |||||
| ] # no more scales for now due to GPU memory constraints. might be changed later | |||||
| transforms = [] | |||||
| if resize_and_crop_augmentations: | |||||
| if subset_type == 'train': | |||||
| transforms.append( | |||||
| T.RandomResize(scales, max_size=train_max_size)) | |||||
| elif subset_type == 'test': | |||||
| transforms.append( | |||||
| T.RandomResize([eval_short_size], max_size=eval_max_size)) | |||||
| transforms.extend([T.ToTensor(), normalize]) | |||||
| self.size_transforms = T.Compose(transforms) | |||||
| def __call__(self, source_frames, targets, text_query): | |||||
| if self.h_flip_augmentation and torch.rand(1) > 0.5: | |||||
| source_frames = [F.hflip(f) for f in source_frames] | |||||
| targets[len(targets) // 2]['masks'] = F.hflip( | |||||
| targets[len(targets) // 2]['masks']) | |||||
| # Note - it is possible for both 'right' and 'left' to appear in the same query, hence the placeholder swap below: | |||||
| text_query = text_query.replace('left', '@').replace( | |||||
| 'right', 'left').replace('@', 'right') | |||||
| source_frames, targets = list( | |||||
| zip(*[ | |||||
| self.size_transforms(f, t) | |||||
| for f, t in zip(source_frames, targets) | |||||
| ])) | |||||
| source_frames = torch.stack(source_frames) # [T, 3, H, W] | |||||
| return source_frames, targets, text_query | |||||
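
Reviewer note: the three-step replace above swaps 'left' and 'right' without one clobbering the other; in isolation (made-up query):

    text_query = 'the man to the left of the red car'
    flipped = text_query.replace('left', '@').replace('right', 'left').replace('@', 'right')
    # -> 'the man to the right of the red car'
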
| class Collator: | |||||
| def __call__(self, batch): | |||||
| samples, targets, text_queries = list(zip(*batch)) | |||||
| samples = nested_tensor_from_videos_list(samples) # [T, B, C, H, W] | |||||
| # convert targets to a list of tuples. outer list - time steps, inner tuples - time step batch | |||||
| targets = list(zip(*targets)) | |||||
| batch_dict = { | |||||
| 'samples': samples, | |||||
| 'targets': targets, | |||||
| 'text_queries': text_queries | |||||
| } | |||||
| return batch_dict | |||||
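
Reviewer note: the zip(*targets) above transposes per-sample windows into per-time-step batches; a minimal illustration with made-up placeholders (B = 3 samples, T = 2 frames):

    per_sample_targets = [('t0_a', 't1_a'), ('t0_b', 't1_b'), ('t0_c', 't1_c')]
    per_step_targets = list(zip(*per_sample_targets))
    # -> [('t0_a', 't0_b', 't0_c'), ('t1_a', 't1_b', 't1_c')]  outer list: time, inner tuples: batch
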
| def get_text_annotations_gt(root_path, subset): | |||||
| # without 'header=None' pandas would treat the first sample as the header row | |||||
| a2d_data_info = pandas.read_csv( | |||||
| osp.join(root_path, 'Release/videoset.csv'), header=None) | |||||
| # 'vid', 'label', 'start_time', 'end_time', 'height', 'width', 'total_frames', 'annotated_frames', 'subset' | |||||
| a2d_data_info.columns = ['vid', '', '', '', '', '', '', '', 'subset'] | |||||
| with open(osp.join(root_path, 'text_annotations/missed_videos.txt'), | |||||
| 'r') as f: | |||||
| unused_videos = f.read().splitlines() | |||||
| subsets = {'train': 0, 'test': 1} | |||||
| # filter unused videos and videos which do not belong to our train/test subset: | |||||
| used_videos = a2d_data_info[~a2d_data_info.vid.isin(unused_videos) | |||||
| & (a2d_data_info.subset == subsets[subset])] | |||||
| used_videos_ids = list(used_videos['vid']) | |||||
| text_annotations = pandas.read_csv( | |||||
| osp.join(root_path, 'text_annotations/annotation.txt')) | |||||
| # filter the text annotations based on the used videos: | |||||
| used_text_annotations = text_annotations[text_annotations.video_id.isin( | |||||
| used_videos_ids)] | |||||
| # convert data-frame to list of tuples: | |||||
| used_text_annotations = list(used_text_annotations.to_records(index=False)) | |||||
| return used_text_annotations | |||||
| def create_a2d_sentences_ground_truth_test_annotations(dataset_path, | |||||
| subset_type, | |||||
| mask_annotations_dir, | |||||
| output_path): | |||||
| text_annotations = get_text_annotations_gt(dataset_path, subset_type) | |||||
| # Note - it is very important to start counting the instance and category ids from 1 (not 0). This is implicitly | |||||
| # expected by pycocotools as it is the convention of the original coco dataset annotations. | |||||
| categories_dict = [{ | |||||
| 'id': 1, | |||||
| 'name': 'dummy_class' | |||||
| }] # dummy class, as categories are not used/predicted in RVOS | |||||
| images_dict = [] | |||||
| annotations_dict = [] | |||||
| images_set = set() | |||||
| instance_id_counter = 1 | |||||
| for annot in tqdm(text_annotations): | |||||
| video_id, instance_id, text_query = annot | |||||
| annot_paths = sorted( | |||||
| glob(osp.join(mask_annotations_dir, video_id, '*.h5'))) | |||||
| for p in annot_paths: | |||||
| f = h5py.File(p, 'r') | |||||
| instances = list(f['instance']) | |||||
| try: | |||||
| instance_idx = instances.index(int(instance_id)) | |||||
| # in case this instance does not appear in this frame it has no ground-truth mask, and thus this | |||||
| # frame-instance pair is ignored in evaluation, same as SOTA method: CMPC-V. check out: | |||||
| # https://github.com/spyflying/CMPC-Refseg/blob/094639b8bf00cc169ea7b49cdf9c87fdfc70d963/CMPC_video/build_A2D_batches.py#L98 | |||||
| except ValueError: | |||||
| continue # instance_id does not appear in current frame | |||||
| mask = f['reMask'][instance_idx] if len( | |||||
| instances) > 1 else np.array(f['reMask']) | |||||
| mask = mask.transpose() | |||||
| frame_idx = int(p.split('/')[-1].split('.')[0]) | |||||
| image_id = get_image_id(video_id, frame_idx, instance_id) | |||||
| assert image_id not in images_set, f'error: image id: {image_id} appeared twice' | |||||
| images_set.add(image_id) | |||||
| images_dict.append({ | |||||
| 'id': image_id, | |||||
| 'height': mask.shape[0], | |||||
| 'width': mask.shape[1] | |||||
| }) | |||||
| mask_rle = encode(mask) | |||||
| mask_rle['counts'] = mask_rle['counts'].decode('ascii') | |||||
| mask_area = float(area(mask_rle)) | |||||
| bbox = f['reBBox'][:, instance_idx] if len( | |||||
| instances) > 1 else np.array( | |||||
| f['reBBox']).squeeze() # x1y1x2y2 form | |||||
| bbox_xywh = [ | |||||
| bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1] | |||||
| ] | |||||
| instance_annot = { | |||||
| 'id': instance_id_counter, | |||||
| 'image_id': image_id, | |||||
| 'category_id': | |||||
| 1, # dummy class, as categories are not used/predicted in ref-vos | |||||
| 'segmentation': mask_rle, | |||||
| 'area': mask_area, | |||||
| 'bbox': bbox_xywh, | |||||
| 'iscrowd': 0, | |||||
| } | |||||
| annotations_dict.append(instance_annot) | |||||
| instance_id_counter += 1 | |||||
| dataset_dict = { | |||||
| 'categories': categories_dict, | |||||
| 'images': images_dict, | |||||
| 'annotations': annotations_dict | |||||
| } | |||||
| with open(output_path, 'w') as f: | |||||
| json.dump(dataset_dict, f) | |||||
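
Reviewer note: for reference, the generated ground-truth file follows the standard COCO layout; a single-image sketch with placeholder values (RLE counts elided):

    dataset_dict = {
        'categories': [{'id': 1, 'name': 'dummy_class'}],
        'images': [{'id': 'v_<video_id>_f_<frame_idx>_i_<instance_id>',
                    'height': 320, 'width': 320}],
        'annotations': [{
            'id': 1,
            'image_id': 'v_<video_id>_f_<frame_idx>_i_<instance_id>',
            'category_id': 1,
            'segmentation': {'size': [320, 320], 'counts': '...'},   # RLE string
            'area': 1234.0,
            'bbox': [10.0, 20.0, 30.0, 40.0],                        # x, y, w, h
            'iscrowd': 0,
        }],
    }
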
| @@ -0,0 +1,294 @@ | |||||
| # The implementation is adopted from MTTR, | |||||
| # made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR | |||||
| # Modified from DETR https://github.com/facebookresearch/detr | |||||
| import random | |||||
| import PIL | |||||
| import torch | |||||
| import torchvision.transforms as T | |||||
| import torchvision.transforms.functional as F | |||||
| from modelscope.models.cv.referring_video_object_segmentation.utils import \ | |||||
| interpolate | |||||
| def crop(image, target, region): | |||||
| cropped_image = F.crop(image, *region) | |||||
| target = target.copy() | |||||
| i, j, h, w = region | |||||
| # should we do something wrt the original size? | |||||
| target['size'] = torch.tensor([h, w]) | |||||
| fields = ['labels', 'area', 'iscrowd'] | |||||
| if 'boxes' in target: | |||||
| boxes = target['boxes'] | |||||
| max_size = torch.as_tensor([w, h], dtype=torch.float32) | |||||
| cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) | |||||
| cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) | |||||
| cropped_boxes = cropped_boxes.clamp(min=0) | |||||
| area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) | |||||
| target['boxes'] = cropped_boxes.reshape(-1, 4) | |||||
| target['area'] = area | |||||
| fields.append('boxes') | |||||
| if 'masks' in target: | |||||
| # FIXME should we update the area here if there are no boxes? | |||||
| target['masks'] = target['masks'][:, i:i + h, j:j + w] | |||||
| fields.append('masks') | |||||
| # remove elements whose boxes or masks have zero area | |||||
| if 'boxes' in target or 'masks' in target: | |||||
| # favor boxes selection when defining which elements to keep | |||||
| # this is compatible with previous implementation | |||||
| if 'boxes' in target: | |||||
| cropped_boxes = target['boxes'].reshape(-1, 2, 2) | |||||
| keep = torch.all( | |||||
| cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) | |||||
| else: | |||||
| keep = target['masks'].flatten(1).any(1) | |||||
| for field in fields: | |||||
| target[field] = target[field][keep] | |||||
| return cropped_image, target | |||||
| def hflip(image, target): | |||||
| flipped_image = F.hflip(image) | |||||
| w, h = image.size | |||||
| target = target.copy() | |||||
| if 'boxes' in target: | |||||
| boxes = target['boxes'] | |||||
| boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor( | |||||
| [-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0]) | |||||
| target['boxes'] = boxes | |||||
| if 'masks' in target: | |||||
| target['masks'] = target['masks'].flip(-1) | |||||
| return flipped_image, target | |||||
| def resize(image, target, size, max_size=None): | |||||
| # size can be min_size (scalar) or (w, h) tuple | |||||
| def get_size_with_aspect_ratio(image_size, size, max_size=None): | |||||
| w, h = image_size | |||||
| if max_size is not None: | |||||
| min_original_size = float(min((w, h))) | |||||
| max_original_size = float(max((w, h))) | |||||
| if max_original_size / min_original_size * size > max_size: | |||||
| size = int( | |||||
| round(max_size * min_original_size / max_original_size)) | |||||
| if (w <= h and w == size) or (h <= w and h == size): | |||||
| return (h, w) | |||||
| if w < h: | |||||
| ow = size | |||||
| oh = int(size * h / w) | |||||
| else: | |||||
| oh = size | |||||
| ow = int(size * w / h) | |||||
| return (oh, ow) | |||||
| def get_size(image_size, size, max_size=None): | |||||
| if isinstance(size, (list, tuple)): | |||||
| return size[::-1] | |||||
| else: | |||||
| return get_size_with_aspect_ratio(image_size, size, max_size) | |||||
| size = get_size(image.size, size, max_size) | |||||
| rescaled_image = F.resize(image, size) | |||||
| if target is None: | |||||
| return rescaled_image, None | |||||
| ratios = tuple( | |||||
| float(s) / float(s_orig) | |||||
| for s, s_orig in zip(rescaled_image.size, image.size)) | |||||
| ratio_width, ratio_height = ratios | |||||
| target = target.copy() | |||||
| if 'boxes' in target: | |||||
| boxes = target['boxes'] | |||||
| scaled_boxes = boxes * torch.as_tensor( | |||||
| [ratio_width, ratio_height, ratio_width, ratio_height]) | |||||
| target['boxes'] = scaled_boxes | |||||
| if 'area' in target: | |||||
| area = target['area'] | |||||
| scaled_area = area * (ratio_width * ratio_height) | |||||
| target['area'] = scaled_area | |||||
| h, w = size | |||||
| target['size'] = torch.tensor([h, w]) | |||||
| if 'masks' in target: | |||||
| target['masks'] = interpolate( | |||||
| target['masks'][:, None].float(), size, mode='nearest')[:, 0] > 0.5 | |||||
| return rescaled_image, target | |||||
| def pad(image, target, padding): | |||||
| # assumes that we only pad on the bottom right corners | |||||
| padded_image = F.pad(image, (0, 0, padding[0], padding[1])) | |||||
| if target is None: | |||||
| return padded_image, None | |||||
| target = target.copy() | |||||
| # should we do something wrt the original size? | |||||
| target['size'] = torch.tensor(padded_image.size[::-1]) | |||||
| if 'masks' in target: | |||||
| target['masks'] = torch.nn.functional.pad( | |||||
| target['masks'], (0, padding[0], 0, padding[1])) | |||||
| return padded_image, target | |||||
| class RandomCrop(object): | |||||
| def __init__(self, size): | |||||
| self.size = size | |||||
| def __call__(self, img, target): | |||||
| region = T.RandomCrop.get_params(img, self.size) | |||||
| return crop(img, target, region) | |||||
| class RandomSizeCrop(object): | |||||
| def __init__(self, min_size: int, max_size: int): | |||||
| self.min_size = min_size | |||||
| self.max_size = max_size | |||||
| def __call__(self, img: PIL.Image.Image, target: dict): | |||||
| w = random.randint(self.min_size, min(img.width, self.max_size)) | |||||
| h = random.randint(self.min_size, min(img.height, self.max_size)) | |||||
| region = T.RandomCrop.get_params(img, [h, w]) | |||||
| return crop(img, target, region) | |||||
| class CenterCrop(object): | |||||
| def __init__(self, size): | |||||
| self.size = size | |||||
| def __call__(self, img, target): | |||||
| image_width, image_height = img.size | |||||
| crop_height, crop_width = self.size | |||||
| crop_top = int(round((image_height - crop_height) / 2.)) | |||||
| crop_left = int(round((image_width - crop_width) / 2.)) | |||||
| return crop(img, target, | |||||
| (crop_top, crop_left, crop_height, crop_width)) | |||||
| class RandomHorizontalFlip(object): | |||||
| def __init__(self, p=0.5): | |||||
| self.p = p | |||||
| def __call__(self, img, target): | |||||
| if random.random() < self.p: | |||||
| return hflip(img, target) | |||||
| return img, target | |||||
| class RandomResize(object): | |||||
| def __init__(self, sizes, max_size=None): | |||||
| assert isinstance(sizes, (list, tuple)) | |||||
| self.sizes = sizes | |||||
| self.max_size = max_size | |||||
| def __call__(self, img, target=None): | |||||
| size = random.choice(self.sizes) | |||||
| return resize(img, target, size, self.max_size) | |||||
| class RandomPad(object): | |||||
| def __init__(self, max_pad): | |||||
| self.max_pad = max_pad | |||||
| def __call__(self, img, target): | |||||
| pad_x = random.randint(0, self.max_pad) | |||||
| pad_y = random.randint(0, self.max_pad) | |||||
| return pad(img, target, (pad_x, pad_y)) | |||||
| class RandomSelect(object): | |||||
| """ | |||||
| Randomly selects between transforms1 and transforms2, | |||||
| with probability p for transforms1 and (1 - p) for transforms2 | |||||
| """ | |||||
| def __init__(self, transforms1, transforms2, p=0.5): | |||||
| self.transforms1 = transforms1 | |||||
| self.transforms2 = transforms2 | |||||
| self.p = p | |||||
| def __call__(self, img, target): | |||||
| if random.random() < self.p: | |||||
| return self.transforms1(img, target) | |||||
| return self.transforms2(img, target) | |||||
| class ToTensor(object): | |||||
| def __call__(self, img, target): | |||||
| return F.to_tensor(img), target | |||||
| class RandomErasing(object): | |||||
| def __init__(self, *args, **kwargs): | |||||
| self.eraser = T.RandomErasing(*args, **kwargs) | |||||
| def __call__(self, img, target): | |||||
| return self.eraser(img), target | |||||
| class Normalize(object): | |||||
| def __init__(self, mean, std): | |||||
| self.mean = mean | |||||
| self.std = std | |||||
| def __call__(self, image, target=None): | |||||
| image = F.normalize(image, mean=self.mean, std=self.std) | |||||
| if target is None: | |||||
| return image, None | |||||
| target = target.copy() | |||||
| h, w = image.shape[-2:] | |||||
| if 'boxes' in target: | |||||
| boxes = target['boxes'] | |||||
| boxes = box_xyxy_to_cxcywh(boxes) | |||||
| boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32) | |||||
| target['boxes'] = boxes | |||||
| return image, target | |||||
| class Compose(object): | |||||
| def __init__(self, transforms): | |||||
| self.transforms = transforms | |||||
| def __call__(self, image, target): | |||||
| for t in self.transforms: | |||||
| image, target = t(image, target) | |||||
| return image, target | |||||
| def __repr__(self): | |||||
| format_string = self.__class__.__name__ + '(' | |||||
| for t in self.transforms: | |||||
| format_string += '\n' | |||||
| format_string += ' {0}'.format(t) | |||||
| format_string += '\n)' | |||||
| return format_string | |||||
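These classes follow the DETR/MTTR convention of transforming the frame and its target dict together, so they can be chained freely. A minimal sketch of how they could be composed for training; the resize scales, crop bounds, and normalization statistics below are illustrative assumptions, the actual values are taken from the model configuration:

    # Hypothetical training-time composition of the transforms defined above.
    # Sizes and mean/std are placeholders, not values from this PR.
    train_transforms = Compose([
        RandomHorizontalFlip(p=0.5),
        RandomSelect(
            RandomResize([360, 420, 480], max_size=800),
            Compose([
                RandomResize([400, 500]),
                RandomSizeCrop(300, 500),
                RandomResize([360, 420, 480], max_size=800),
            ]),
        ),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Every transform returns an (image, target) pair, so the whole chain is
    # applied in one call per frame:
    # image, target = train_transforms(frame, frame_target)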
| @@ -157,7 +157,13 @@ class ReferringVideoObjectSegmentationPipeline(Pipeline): | |||||
| * text_border_height_per_query, 0, 0)) | * text_border_height_per_query, 0, 0)) | ||||
| W, H = vid_frame.size | W, H = vid_frame.size | ||||
| draw = ImageDraw.Draw(vid_frame) | draw = ImageDraw.Draw(vid_frame) | ||||
| font = ImageFont.truetype(font='DejaVuSansMono.ttf', size=30) | |||||
| if self.model.cfg.pipeline.output_font: | |||||
| font = ImageFont.truetype( | |||||
| font=self.model.cfg.pipeline.output_font, | |||||
| size=self.model.cfg.pipeline.output_font_size) | |||||
| else: | |||||
| font = ImageFont.load_default() | |||||
| for i, (text_query, color) in enumerate( | for i, (text_query, color) in enumerate( | ||||
| zip(self.text_queries, colors), start=1): | zip(self.text_queries, colors), start=1): | ||||
| w, h = draw.textsize(text_query, font=font) | w, h = draw.textsize(text_query, font=font) | ||||
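This change replaces the hard-coded DejaVuSansMono.ttf lookup with configuration-driven font selection, falling back to PIL's built-in bitmap font when no font is configured. A hypothetical excerpt of the model's pipeline configuration, written here as a Python dict; only output_font and output_font_size are keys the new code reads, the surrounding structure is assumed:

    # Hypothetical pipeline section of the model configuration.
    pipeline_cfg = {
        'pipeline': {
            'output_font': 'DejaVuSansMono.ttf',  # any TrueType font PIL can locate
            'output_font_size': 30,
            # leaving 'output_font' empty makes the pipeline use
            # ImageFont.load_default(), so rendering no longer fails on hosts
            # without DejaVu fonts installed
        }
    }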
| @@ -9,7 +9,8 @@ if TYPE_CHECKING: | |||||
| from .builder import build_trainer | from .builder import build_trainer | ||||
| from .cv import (ImageInstanceSegmentationTrainer, | from .cv import (ImageInstanceSegmentationTrainer, | ||||
| ImagePortraitEnhancementTrainer, | ImagePortraitEnhancementTrainer, | ||||
| MovieSceneSegmentationTrainer, ImageInpaintingTrainer) | |||||
| MovieSceneSegmentationTrainer, ImageInpaintingTrainer, | |||||
| ReferringVideoObjectSegmentationTrainer) | |||||
| from .multi_modal import CLIPTrainer | from .multi_modal import CLIPTrainer | ||||
| from .nlp import SequenceClassificationTrainer, TextRankingTrainer | from .nlp import SequenceClassificationTrainer, TextRankingTrainer | ||||
| from .nlp_trainer import NlpEpochBasedTrainer, VecoTrainer, NlpTrainerArguments | from .nlp_trainer import NlpEpochBasedTrainer, VecoTrainer, NlpTrainerArguments | ||||
| @@ -9,6 +9,7 @@ if TYPE_CHECKING: | |||||
| from .image_portrait_enhancement_trainer import ImagePortraitEnhancementTrainer | from .image_portrait_enhancement_trainer import ImagePortraitEnhancementTrainer | ||||
| from .movie_scene_segmentation_trainer import MovieSceneSegmentationTrainer | from .movie_scene_segmentation_trainer import MovieSceneSegmentationTrainer | ||||
| from .image_inpainting_trainer import ImageInpaintingTrainer | from .image_inpainting_trainer import ImageInpaintingTrainer | ||||
| from .referring_video_object_segmentation_trainer import ReferringVideoObjectSegmentationTrainer | |||||
| else: | else: | ||||
| _import_structure = { | _import_structure = { | ||||
| @@ -17,7 +18,9 @@ else: | |||||
| 'image_portrait_enhancement_trainer': | 'image_portrait_enhancement_trainer': | ||||
| ['ImagePortraitEnhancementTrainer'], | ['ImagePortraitEnhancementTrainer'], | ||||
| 'movie_scene_segmentation_trainer': ['MovieSceneSegmentationTrainer'], | 'movie_scene_segmentation_trainer': ['MovieSceneSegmentationTrainer'], | ||||
| 'image_inpainting_trainer': ['ImageInpaintingTrainer'] | |||||
| 'image_inpainting_trainer': ['ImageInpaintingTrainer'], | |||||
| 'referring_video_object_segmentation_trainer': | |||||
| ['ReferringVideoObjectSegmentationTrainer'] | |||||
| } | } | ||||
| import sys | import sys | ||||
| @@ -0,0 +1,63 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| import torch | |||||
| from modelscope.metainfo import Trainers | |||||
| from modelscope.trainers.builder import TRAINERS | |||||
| from modelscope.trainers.trainer import EpochBasedTrainer | |||||
| from modelscope.utils.constant import ModeKeys | |||||
| @TRAINERS.register_module( | |||||
| module_name=Trainers.referring_video_object_segmentation) | |||||
| class ReferringVideoObjectSegmentationTrainer(EpochBasedTrainer): | |||||
| def __init__(self, *args, **kwargs): | |||||
| super().__init__(*args, **kwargs) | |||||
| self.model.set_postprocessor(self.cfg.dataset.name) | |||||
| self.train_data_collator = self.train_dataset.collator | |||||
| self.eval_data_collator = self.eval_dataset.collator | |||||
| device_name = kwargs.get('device', 'gpu') | |||||
| self.model.set_device(self.device, device_name) | |||||
| def train(self, *args, **kwargs): | |||||
| self.model.criterion.train() | |||||
| super().train(*args, **kwargs) | |||||
| def evaluate(self, checkpoint_path=None): | |||||
| if checkpoint_path is not None and os.path.isfile(checkpoint_path): | |||||
| from modelscope.trainers.hooks import CheckpointHook | |||||
| CheckpointHook.load_checkpoint(checkpoint_path, self) | |||||
| self.model.eval() | |||||
| self._mode = ModeKeys.EVAL | |||||
| if self.eval_dataset is None: | |||||
| self.eval_dataloader = self.get_eval_data_loader() | |||||
| else: | |||||
| self.eval_dataloader = self._build_dataloader_with_dataset( | |||||
| self.eval_dataset, | |||||
| dist=self._dist, | |||||
| seed=self._seed, | |||||
| collate_fn=self.eval_data_collator, | |||||
| **self.cfg.evaluation.get('dataloader', {})) | |||||
| self.data_loader = self.eval_dataloader | |||||
| from modelscope.metrics import build_metric | |||||
| ann_file = self.eval_dataset.ann_file | |||||
| metric_classes = [] | |||||
| for metric in self.metrics: | |||||
| metric.update({'ann_file': ann_file}) | |||||
| metric_classes.append(build_metric(metric)) | |||||
| for m in metric_classes: | |||||
| m.trainer = self | |||||
| metric_values = self.evaluation_loop(self.eval_dataloader, | |||||
| metric_classes) | |||||
| self._metric_values = metric_values | |||||
| return metric_values | |||||
| def prediction_step(self, model, inputs): | |||||
| pass | |||||
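In evaluate(), each metric config dict is enriched with the eval dataset's annotation file before being built, and every metric instance gets a back-reference to the trainer. A minimal sketch of that wiring, assuming the usual 'type' registry key and a placeholder annotation path:

    from modelscope.metrics import build_metric

    # 'type' as the registry key is the common modelscope config convention
    # (an assumption here); the path stands in for eval_dataset.ann_file.
    metric_cfg = {
        'type': 'referring-video-object-segmentation-metric',
        'ann_file': '/path/to/valid_annotations.json',
    }
    metric = build_metric(metric_cfg)
    # evaluate() then sets metric.trainer = self on each instance before
    # running evaluation_loop over the eval dataloader.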
| @@ -62,7 +62,10 @@ def single_gpu_test(trainer, | |||||
| if 'nsentences' in data: | if 'nsentences' in data: | ||||
| batch_size = data['nsentences'] | batch_size = data['nsentences'] | ||||
| else: | else: | ||||
| batch_size = len(next(iter(data.values()))) | |||||
| try: | |||||
| batch_size = len(next(iter(data.values()))) | |||||
| except Exception: | |||||
| batch_size = data_loader.batch_size | |||||
| else: | else: | ||||
| batch_size = len(data) | batch_size = len(data) | ||||
| for _ in range(batch_size): | for _ in range(batch_size): | ||||
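The batch-size probe previously assumed the first value in a dict batch supports len(); the nested structures produced by custom collators (such as the one this trainer installs) do not, so the probe now falls back to the dataloader's configured batch_size. A small standalone illustration with a deliberately simplified batch dict:

    # Simplified illustration (the real collator output is more structured):
    # when the first value of the batch dict has no len(), the probe raises
    # and the dataloader's configured batch_size is used instead.
    data = {'samples': object(), 'targets': object()}
    try:
        batch_size = len(next(iter(data.values())))
    except Exception:
        batch_size = 4  # stands in for data_loader.batch_size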
| @@ -0,0 +1,101 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| import shutil | |||||
| import tempfile | |||||
| import unittest | |||||
| import zipfile | |||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.metainfo import Trainers | |||||
| from modelscope.models.cv.movie_scene_segmentation import \ | |||||
| MovieSceneSegmentationModel | |||||
| from modelscope.msdatasets import MsDataset | |||||
| from modelscope.trainers import build_trainer | |||||
| from modelscope.utils.config import Config, ConfigDict | |||||
| from modelscope.utils.constant import ModelFile | |||||
| from modelscope.utils.test_utils import test_level | |||||
| class TestReferringVideoObjectSegmentationTrainer(unittest.TestCase): | |||||
| model_id = 'damo/cv_swin-t_referring_video-object-segmentation' | |||||
| dataset_name = 'referring_vos_toydata' | |||||
| def setUp(self): | |||||
| print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) | |||||
| cache_path = snapshot_download(self.model_id) | |||||
| config_path = os.path.join(cache_path, ModelFile.CONFIGURATION) | |||||
| cfg = Config.from_file(config_path) | |||||
| max_epochs = cfg.train.max_epochs | |||||
| train_data_cfg = ConfigDict( | |||||
| name=self.dataset_name, | |||||
| split='train', | |||||
| test_mode=False, | |||||
| cfg=cfg.dataset) | |||||
| test_data_cfg = ConfigDict( | |||||
| name=self.dataset_name, | |||||
| split='test', | |||||
| test_mode=True, | |||||
| cfg=cfg.dataset) | |||||
| self.train_dataset = MsDataset.load( | |||||
| dataset_name=train_data_cfg.name, | |||||
| split=train_data_cfg.split, | |||||
| cfg=train_data_cfg.cfg, | |||||
| namespace='damo', | |||||
| test_mode=train_data_cfg.test_mode) | |||||
| assert next( | |||||
| iter(self.train_dataset.config_kwargs['split_config'].values())) | |||||
| self.test_dataset = MsDataset.load( | |||||
| dataset_name=test_data_cfg.name, | |||||
| split=test_data_cfg.split, | |||||
| cfg=test_data_cfg.cfg, | |||||
| namespace='damo', | |||||
| test_mode=test_data_cfg.test_mode) | |||||
| assert next( | |||||
| iter(self.test_dataset.config_kwargs['split_config'].values())) | |||||
| self.max_epochs = max_epochs | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_trainer(self): | |||||
| kwargs = dict( | |||||
| model=self.model_id, | |||||
| train_dataset=self.train_dataset, | |||||
| eval_dataset=self.test_dataset, | |||||
| work_dir='./work_dir') | |||||
| trainer = build_trainer( | |||||
| name=Trainers.referring_video_object_segmentation, | |||||
| default_args=kwargs) | |||||
| trainer.train() | |||||
| results_files = os.listdir(trainer.work_dir) | |||||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_trainer_with_model_and_args(self): | |||||
| cache_path = snapshot_download(self.model_id) | |||||
| model = MovieSceneSegmentationModel.from_pretrained(cache_path) | |||||
| kwargs = dict( | |||||
| cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), | |||||
| model=model, | |||||
| train_dataset=self.train_dataset, | |||||
| eval_dataset=self.test_dataset, | |||||
| work_dir='./work_dir') | |||||
| trainer = build_trainer( | |||||
| name=Trainers.referring_video_object_segmentation, | |||||
| default_args=kwargs) | |||||
| trainer.train() | |||||
| results_files = os.listdir(trainer.work_dir) | |||||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||