Merge internal master to GitHub (1028master)
@@ -176,7 +176,10 @@ class HubApi:
         """
         cookies = ModelScopeConfig.get_cookies()
         owner_or_group, name = model_id_to_group_owner_name(model_id)
-        path = f'{self.endpoint}/api/v1/models/{owner_or_group}/{name}?Revision={revision}'
+        if revision:
+            path = f'{self.endpoint}/api/v1/models/{owner_or_group}/{name}?Revision={revision}'
+        else:
+            path = f'{self.endpoint}/api/v1/models/{owner_or_group}/{name}'
         r = requests.get(path, cookies=cookies, headers=self.headers)
         handle_http_response(r, logger, cookies, model_id)
@@ -447,8 +450,12 @@ class HubApi:
         Returns:
             List[dict]: Model file list.
         """
-        path = '%s/api/v1/models/%s/repo/files?Revision=%s&Recursive=%s' % (
-            self.endpoint, model_id, revision, recursive)
+        if revision:
+            path = '%s/api/v1/models/%s/repo/files?Revision=%s&Recursive=%s' % (
+                self.endpoint, model_id, revision, recursive)
+        else:
+            path = '%s/api/v1/models/%s/repo/files?Recursive=%s' % (
+                self.endpoint, model_id, recursive)
         cookies = self._check_cookie(use_cookies)
         if root is not None:
             path = path + f'&Root={root}'
@@ -499,13 +506,14 @@ class HubApi:
             shutil.rmtree(cache_dir)
         os.makedirs(cache_dir, exist_ok=True)
         datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}'
-        r = requests.get(datahub_url)
+        cookies = ModelScopeConfig.get_cookies()
+        r = requests.get(datahub_url, cookies=cookies)
         resp = r.json()
         datahub_raise_on_error(datahub_url, resp)
         dataset_id = resp['Data']['Id']
         dataset_type = resp['Data']['Type']
         datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={revision}'
-        r = requests.get(datahub_url, headers=self.headers)
+        r = requests.get(datahub_url, cookies=cookies, headers=self.headers)
         resp = r.json()
         datahub_raise_on_error(datahub_url, resp)
         file_list = resp['Data']
@@ -524,7 +532,7 @@ class HubApi:
             if extension in dataset_meta_format:
                 datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \
                               f'Revision={revision}&FilePath={file_path}'
-                r = requests.get(datahub_url)
+                r = requests.get(datahub_url, cookies=cookies)
                 raise_for_http_status(r)
                 local_path = os.path.join(cache_dir, file_path)
                 if os.path.exists(local_path):
@@ -569,9 +577,7 @@ class HubApi:
         datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/' \
                       f'ststoken?Revision={revision}'
         cookies = requests.utils.dict_from_cookiejar(cookies)
-        r = requests.get(
-            url=datahub_url, cookies=cookies, headers=self.headers)
+        r = requests.get(url=datahub_url, cookies=cookies, headers=self.headers)
         resp = r.json()
         raise_on_error(resp)
         return resp['Data']
@@ -582,9 +588,6 @@ class HubApi:
             f'MaxLimit={max_limit}&Revision={revision}&Recursive={is_recursive}&FilterDir={is_filter_dir}'
         cookies = ModelScopeConfig.get_cookies()
-        if cookies:
-            cookies = requests.utils.dict_from_cookiejar(cookies)
         resp = requests.get(url=url, cookies=cookies)
         resp = resp.json()
         raise_on_error(resp)
@@ -593,17 +596,48 @@ class HubApi:
     def on_dataset_download(self, dataset_name: str, namespace: str) -> None:
         url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/increase'
-        r = requests.post(url, headers=self.headers)
+        cookies = ModelScopeConfig.get_cookies()
+        r = requests.post(url, cookies=cookies, headers=self.headers)
         raise_for_http_status(r)

+    def delete_oss_dataset_object(self, object_name: str, dataset_name: str,
+                                  namespace: str, revision: str) -> str:
+        if not object_name or not dataset_name or not namespace or not revision:
+            raise ValueError('Args cannot be empty!')
+
+        url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss?Path={object_name}&Revision={revision}'
+
+        cookies = self.check_local_cookies(use_cookies=True)
+        resp = requests.delete(url=url, cookies=cookies)
+        resp = resp.json()
+        raise_on_error(resp)
+        resp = resp['Message']
+        return resp
+
+    def delete_oss_dataset_dir(self, object_name: str, dataset_name: str,
+                               namespace: str, revision: str) -> str:
+        if not object_name or not dataset_name or not namespace or not revision:
+            raise ValueError('Args cannot be empty!')
+
+        url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss/prefix?Prefix={object_name}/' \
+              f'&Revision={revision}'
+
+        cookies = self.check_local_cookies(use_cookies=True)
+        resp = requests.delete(url=url, cookies=cookies)
+        resp = resp.json()
+        raise_on_error(resp)
+        resp = resp['Message']
+        return resp
+
     @staticmethod
     def datahub_remote_call(url):
-        r = requests.get(url, headers={'user-agent': ModelScopeConfig.get_user_agent()})
+        cookies = ModelScopeConfig.get_cookies()
+        r = requests.get(url, cookies=cookies, headers={'user-agent': ModelScopeConfig.get_user_agent()})
         resp = r.json()
         datahub_raise_on_error(url, resp)
         return resp['Data']

-    def check_cookies_upload_data(self, use_cookies) -> CookieJar:
+    def check_local_cookies(self, use_cookies) -> CookieJar:
         return self._check_cookie(use_cookies=use_cookies)
@@ -63,6 +63,7 @@ def handle_http_post_error(response, url, request_body):
     except HTTPError as error:
         logger.error('Request %s with body: %s exception' %
                      (url, request_body))
+        logger.error('Response details: %s' % response.content)
         raise error
@@ -254,6 +254,7 @@ class Pipelines(object):
     translation_en_to_de = 'translation_en_to_de'  # keep it underscore
     translation_en_to_ro = 'translation_en_to_ro'  # keep it underscore
     translation_en_to_fr = 'translation_en_to_fr'  # keep it underscore
+    token_classification = 'token-classification'

     # audio tasks
     sambert_hifigan_tts = 'sambert-hifigan-tts'
@@ -305,6 +306,8 @@ class Trainers(object):
    face_detection_scrfd = 'face-detection-scrfd'
    card_detection_scrfd = 'card-detection-scrfd'
    image_inpainting = 'image-inpainting'
+    referring_video_object_segmentation = 'referring-video-object-segmentation'
+    image_classification_team = 'image-classification-team'

     # nlp trainers
     bert_sentiment_analysis = 'bert-sentiment-analysis'
@@ -422,6 +425,8 @@ class Metrics(object):
     image_inpainting_metric = 'image-inpainting-metric'
     # metric for ocr
     NED = 'ned'
+    # metric for referring-video-object-segmentation task
+    referring_video_object_segmentation_metric = 'referring-video-object-segmentation-metric'


 class Optimizers(object):
@@ -20,6 +20,7 @@ if TYPE_CHECKING:
     from .accuracy_metric import AccuracyMetric
     from .bleu_metric import BleuMetric
     from .image_inpainting_metric import ImageInpaintingMetric
+    from .referring_video_object_segmentation_metric import ReferringVideoObjectSegmentationMetric

 else:
     _import_structure = {
@@ -40,6 +41,8 @@ else:
         'image_inpainting_metric': ['ImageInpaintingMetric'],
         'accuracy_metric': ['AccuracyMetric'],
         'bleu_metric': ['BleuMetric'],
+        'referring_video_object_segmentation_metric':
+        ['ReferringVideoObjectSegmentationMetric'],
     }

     import sys
@@ -43,6 +43,8 @@ task_default_metrics = {
     Tasks.visual_question_answering: [Metrics.text_gen_metric],
     Tasks.movie_scene_segmentation: [Metrics.movie_scene_segmentation_metric],
     Tasks.image_inpainting: [Metrics.image_inpainting_metric],
+    Tasks.referring_video_object_segmentation:
+    [Metrics.referring_video_object_segmentation_metric],
 }
@@ -0,0 +1,108 @@
+# Part of the implementation is borrowed and modified from MTTR,
+# publicly available at https://github.com/mttr2021/MTTR
+from typing import Dict
+
+import numpy as np
+import torch
+from pycocotools.coco import COCO
+from pycocotools.cocoeval import COCOeval
+from pycocotools.mask import decode
+from tqdm import tqdm
+
+from modelscope.metainfo import Metrics
+from modelscope.utils.registry import default_group
+from .base import Metric
+from .builder import METRICS, MetricKeys
+
+
+@METRICS.register_module(
+    group_key=default_group,
+    module_name=Metrics.referring_video_object_segmentation_metric)
+class ReferringVideoObjectSegmentationMetric(Metric):
+    """The metric computation class for the referring video object segmentation task.
+    """
+
+    def __init__(self,
+                 ann_file=None,
+                 calculate_precision_and_iou_metrics=True):
+        self.ann_file = ann_file
+        self.calculate_precision_and_iou_metrics = calculate_precision_and_iou_metrics
+        self.preds = []
+
+    def add(self, outputs: Dict, inputs: Dict):
+        preds_batch = outputs['pred']
+        self.preds.extend(preds_batch)
+
+    def evaluate(self):
+        coco_gt = COCO(self.ann_file)
+        coco_pred = coco_gt.loadRes(self.preds)
+        coco_eval = COCOeval(coco_gt, coco_pred, iouType='segm')
+        coco_eval.params.useCats = 0
+        coco_eval.evaluate()
+        coco_eval.accumulate()
+        coco_eval.summarize()
+        ap_labels = [
+            'mAP 0.5:0.95', 'AP 0.5', 'AP 0.75', 'AP 0.5:0.95 S',
+            'AP 0.5:0.95 M', 'AP 0.5:0.95 L'
+        ]
+        ap_metrics = coco_eval.stats[:6]
+        eval_metrics = {la: m for la, m in zip(ap_labels, ap_metrics)}
+        if self.calculate_precision_and_iou_metrics:
+            precision_at_k, overall_iou, mean_iou = calculate_precision_at_k_and_iou_metrics(
+                coco_gt, coco_pred)
+            eval_metrics.update({
+                f'P@{k}': m
+                for k, m in zip([0.5, 0.6, 0.7, 0.8, 0.9], precision_at_k)
+            })
+            eval_metrics.update({
+                'overall_iou': overall_iou,
+                'mean_iou': mean_iou
+            })
+        return eval_metrics
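A hedged usage sketch of the metric class above; the annotation-file path and the prediction entry are hypothetical, and 'segmentation' must be a pycocotools-style RLE produced by the model's forward pass (shown later in this PR):

metric = ReferringVideoObjectSegmentationMetric(
    ann_file='a2d_sentences_test_annotations_in_coco_format.json')  # hypothetical path
metric.add(
    outputs={'pred': [{'image_id': 'v_abc_f_5_i_1',     # made-up id
                       'category_id': 1,                # dummy label, as in the model code
                       'segmentation': some_rle,        # hypothetical pycocotools RLE dict
                       'score': 0.91}]},
    inputs={})
results = metric.evaluate()
print(results['mAP 0.5:0.95'], results.get('P@0.5'))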
+
+
+def compute_iou(outputs: torch.Tensor, labels: torch.Tensor, EPS=1e-6):
+    outputs = outputs.int()
+    intersection = (outputs & labels).float().sum(
+        (1, 2))  # Will be zero if Truth=0 or Prediction=0
+    union = (outputs | labels).float().sum(
+        (1, 2))  # Will be zero if both are 0
+    iou = (intersection + EPS) / (union + EPS)  # EPS is used to avoid division by zero
+    return iou, intersection, union
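A tiny worked example of compute_iou on 2x2 toy masks (the sums over dims (1, 2) imply an (N, H, W) layout):

import torch

pred = torch.tensor([[[1, 1], [0, 0]]])  # one 2x2 predicted mask
gt = torch.tensor([[[1, 0], [0, 0]]])    # one 2x2 ground-truth mask
iou, inter, union = compute_iou(pred, gt)
# intersection = 1, union = 2 -> IoU = (1 + 1e-6) / (2 + 1e-6) ~= 0.5
print(iou)  # tensor([0.5000])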
+
+
+def calculate_precision_at_k_and_iou_metrics(coco_gt: COCO, coco_pred: COCO):
+    print('evaluating precision@k & iou metrics...')
+    counters_by_iou = {iou: 0 for iou in [0.5, 0.6, 0.7, 0.8, 0.9]}
+    total_intersection_area = 0
+    total_union_area = 0
+    ious_list = []
+    for instance in tqdm(coco_gt.imgs.keys()):  # each image_id contains exactly one instance
+        gt_annot = coco_gt.imgToAnns[instance][0]
+        gt_mask = decode(gt_annot['segmentation'])
+        pred_annots = coco_pred.imgToAnns[instance]
+        pred_annot = sorted(
+            pred_annots, key=lambda a: a['score'])[-1]  # choose the pred with the highest score
+        pred_mask = decode(pred_annot['segmentation'])
+        iou, intersection, union = compute_iou(
+            torch.tensor(pred_mask).unsqueeze(0),
+            torch.tensor(gt_mask).unsqueeze(0))
+        iou, intersection, union = iou.item(), intersection.item(), union.item()
+        for iou_threshold in counters_by_iou.keys():
+            if iou > iou_threshold:
+                counters_by_iou[iou_threshold] += 1
+        total_intersection_area += intersection
+        total_union_area += union
+        ious_list.append(iou)
+    num_samples = len(ious_list)
+    precision_at_k = np.array(list(counters_by_iou.values())) / num_samples
+    overall_iou = total_intersection_area / total_union_area
+    mean_iou = np.mean(ious_list)
+    return precision_at_k, overall_iou, mean_iou
@@ -3,6 +3,7 @@
 from typing import Dict

 import numpy as np
+from sklearn.metrics import accuracy_score, f1_score

 from modelscope.metainfo import Metrics
 from modelscope.outputs import OutputKeys
@@ -41,5 +42,11 @@ class SequenceClassificationMetric(Metric):
         preds = np.argmax(preds, axis=1)
         return {
             MetricKeys.ACCURACY:
-            (preds == labels).astype(np.float32).mean().item()
+            accuracy_score(labels, preds),
+            MetricKeys.F1:
+            f1_score(
+                labels,
+                preds,
+                average='micro' if any(label > 1
+                                       for label in labels) else None),
         }
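A quick sketch of what the new sklearn-based return computes for a multiclass batch (toy arrays; for single-label multiclass data, micro-averaged F1 equals accuracy):

import numpy as np
from sklearn.metrics import accuracy_score, f1_score

labels = np.array([0, 2, 1, 2])
preds = np.array([0, 2, 2, 2])                    # one mistake
print(accuracy_score(labels, preds))              # 0.75
# labels contain a class id > 1, so the metric switches to micro averaging:
print(f1_score(labels, preds, average='micro'))   # 0.75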
@@ -2,7 +2,7 @@
 from typing import Dict, Iterable, List

-from nltk.translate.bleu_score import sentence_bleu
+from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu
 from rouge import Rouge

 from modelscope.metainfo import Metrics
@@ -63,14 +63,18 @@ class TextGenerationMetric(Metric):
         rouge_scores = self.rouge.get_scores(hyps=preds, refs=tgts)
         rouge_1 = mean(map(lambda score: score['rouge-1']['f'], rouge_scores))
         rouge_l = mean(map(lambda score: score['rouge-l']['f'], rouge_scores))
-        pred_split = tuple(pred.split(' ') for pred in self.preds)
-        tgt_split = tuple(tgt.split(' ') for tgt in self.tgts)
-        bleu_1 = mean(
-            sentence_bleu([tgt], pred, weights=(1, 0, 0, 0))
-            for pred, tgt in zip(pred_split, tgt_split))
-        bleu_4 = mean(
-            sentence_bleu([tgt], pred)
-            for pred, tgt in zip(pred_split, tgt_split))
+        pred_list = [each.strip().split(' ') for each in self.preds]
+        tgt_list = [[each.strip().split(' ')] for each in self.tgts]
+        bleu_1 = corpus_bleu(
+            tgt_list,
+            pred_list,
+            weights=(1, 0, 0, 0),
+            smoothing_function=SmoothingFunction().method3)
+        bleu_4 = corpus_bleu(
+            tgt_list,
+            pred_list,
+            smoothing_function=SmoothingFunction().method3)
         return {
             MetricKeys.ROUGE_1: rouge_1,
             MetricKeys.ROUGE_L: rouge_l,
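A minimal sketch of the switch from per-sentence to corpus-level BLEU (toy strings; note that corpus_bleu expects a list of reference lists per hypothesis, which is why each target is wrapped in an extra list):

from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu

preds = ['the cat sat on the mat', 'hello world']
tgts = ['the cat is on the mat', 'hello there world']
pred_list = [p.strip().split(' ') for p in preds]
tgt_list = [[t.strip().split(' ')] for t in tgts]
bleu_1 = corpus_bleu(tgt_list, pred_list, weights=(1, 0, 0, 0),
                     smoothing_function=SmoothingFunction().method3)
print(bleu_1)  # a single corpus-level unigram score in [0, 1]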
@@ -67,8 +67,28 @@ class Model(ABC):
                         cfg_dict: Config = None,
                         device: str = None,
                         **kwargs):
-        """ Instantiate a model from local directory or remote model repo. Note
+        """Instantiate a model from a local directory or a remote model repo. Note
         that when loading from remote, the model revision can be specified.
+
+        Args:
+            model_name_or_path(str): A model dir or a model id to be loaded
+            revision(str, `optional`): The revision used when the model_name_or_path is
+                a model id of the remote hub. Default: `master`.
+            cfg_dict(Config, `optional`): An optional model config. If provided, it will replace
+                the config read out of the `model_name_or_path`.
+            device(str, `optional`): The device to load the model.
+            **kwargs:
+                task(str, `optional`): The `Tasks` enumeration value to replace the task value
+                    read out of the config in the `model_name_or_path`. This is useful when the model
+                    to be loaded is not equal to the model saved.
+                    For example, load a `backbone` into a `text-classification` model.
+                Other kwargs will be directly fed into the `model` key, to replace the default configs.
+
+        Returns:
+            A model instance.
+
+        Examples:
+            >>> from modelscope.models import Model
+            >>> Model.from_pretrained('damo/nlp_structbert_backbone_base_std', task='text-classification')
+        """
         prefetched = kwargs.get('model_prefetched')
         if prefetched is not None:
@@ -5,11 +5,11 @@ from modelscope.utils.import_utils import LazyImportModule

 if TYPE_CHECKING:
-    from .model import MovieSceneSegmentation
+    from .model import ReferringVideoObjectSegmentation

 else:
     _import_structure = {
-        'model': ['MovieSceneSegmentation'],
+        'model': ['ReferringVideoObjectSegmentation'],
     }

     import sys
@@ -1,4 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+# Part of the implementation is borrowed and modified from MTTR,
+# publicly available at https://github.com/mttr2021/MTTR
 import os.path as osp
 from typing import Any, Dict
@@ -10,7 +12,9 @@ from modelscope.models.builder import MODELS
 from modelscope.utils.config import Config
 from modelscope.utils.constant import ModelFile, Tasks
 from modelscope.utils.logger import get_logger
-from .utils import (MTTR, A2DSentencesPostProcess, ReferYoutubeVOSPostProcess,
+from .utils import (MTTR, A2DSentencesPostProcess, HungarianMatcher,
+                    ReferYoutubeVOSPostProcess, SetCriterion,
+                    flatten_temporal_batch_dims,
                     nested_tensor_from_videos_list)

 logger = get_logger()
@@ -35,16 +39,66 @@ class ReferringVideoObjectSegmentation(TorchModel):
             params_dict = params_dict['model_state_dict']
         self.model.load_state_dict(params_dict, strict=True)

-        dataset_name = self.cfg.pipeline.dataset_name
-        if dataset_name == 'a2d_sentences' or dataset_name == 'jhmdb_sentences':
-            self.postprocessor = A2DSentencesPostProcess()
-        elif dataset_name == 'ref_youtube_vos':
-            self.postprocessor = ReferYoutubeVOSPostProcess()
+        self.set_postprocessor(self.cfg.pipeline.dataset_name)
+        self.set_criterion()

     def set_device(self, device, name):
         self.device = device
         self._device_name = name

+    def set_postprocessor(self, dataset_name):
+        if 'a2d_sentences' in dataset_name or 'jhmdb_sentences' in dataset_name:
+            self.postprocessor = A2DSentencesPostProcess()  # fine-tune
+        elif 'ref_youtube_vos' in dataset_name:
+            self.postprocessor = ReferYoutubeVOSPostProcess()  # inference
+        else:
+            assert False, f'postprocessing for dataset: {dataset_name} is not supported'
+
-    def forward(self, inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]:
-        return inputs
+    def forward(self, inputs: Dict[str, Any]):
+        samples = inputs['samples']
+        targets = inputs['targets']
+        text_queries = inputs['text_queries']
+
+        valid_indices = torch.tensor(
+            [i for i, t in enumerate(targets) if None not in t])
+        targets = [targets[i] for i in valid_indices.tolist()]
+        if self._device_name == 'gpu':
+            samples = samples.to(self.device)
+            valid_indices = valid_indices.to(self.device)
+        if isinstance(text_queries, tuple):
+            text_queries = list(text_queries)
+        outputs = self.model(samples, valid_indices, text_queries)
+
+        losses = -1
+        if self.training:
+            loss_dict = self.criterion(outputs, targets)
+            weight_dict = self.criterion.weight_dict
+            losses = sum(loss_dict[k] * weight_dict[k]
+                         for k in loss_dict.keys() if k in weight_dict)
+
+        predictions = []
+        if not self.training:
+            outputs.pop('aux_outputs', None)
+            outputs, targets = flatten_temporal_batch_dims(outputs, targets)
+            processed_outputs = self.postprocessor(
+                outputs,
+                resized_padded_sample_size=samples.tensors.shape[-2:],
+                resized_sample_sizes=[t['size'] for t in targets],
+                orig_sample_sizes=[t['orig_size'] for t in targets])
+            image_ids = [t['image_id'] for t in targets]
+            for p, image_id in zip(processed_outputs, image_ids):
+                for s, m in zip(p['scores'], p['rle_masks']):
+                    predictions.append({
+                        'image_id': image_id,
+                        'category_id': 1,  # dummy label, as categories are not predicted in ref-vos
+                        'segmentation': m,
+                        'score': s.item()
+                    })
+
+        re = dict(pred=predictions, loss=losses)
+        return re

     def inference(self, **kwargs):
         window = kwargs['window']
@@ -63,3 +117,26 @@ class ReferringVideoObjectSegmentation(TorchModel):

     def postprocess(self, inputs: Dict[str, Any], **kwargs):
         return inputs
+
+    def set_criterion(self):
+        matcher = HungarianMatcher(
+            cost_is_referred=self.cfg.matcher.set_cost_is_referred,
+            cost_dice=self.cfg.matcher.set_cost_dice)
+        weight_dict = {
+            'loss_is_referred': self.cfg.loss.is_referred_loss_coef,
+            'loss_dice': self.cfg.loss.dice_loss_coef,
+            'loss_sigmoid_focal': self.cfg.loss.sigmoid_focal_loss_coef
+        }
+
+        if self.cfg.loss.aux_loss:
+            aux_weight_dict = {}
+            for i in range(self.cfg.model.num_decoder_layers - 1):
+                aux_weight_dict.update(
+                    {k + f'_{i}': v
+                     for k, v in weight_dict.items()})
+            weight_dict.update(aux_weight_dict)
+
+        self.criterion = SetCriterion(
+            matcher=matcher,
+            weight_dict=weight_dict,
+            eos_coef=self.cfg.loss.eos_coef)
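To make the auxiliary-loss bookkeeping in set_criterion concrete, a small sketch of how the weight dict expands (the coefficient values here are made up, not the repository defaults):

weight_dict = {'loss_is_referred': 2.0, 'loss_dice': 5.0, 'loss_sigmoid_focal': 2.0}
num_decoder_layers = 3                       # hypothetical
aux_weight_dict = {}
for i in range(num_decoder_layers - 1):
    aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()})
weight_dict.update(aux_weight_dict)
# -> adds 'loss_is_referred_0', ..., 'loss_sigmoid_focal_1': one suffixed copy of
#    each loss weight per intermediate decoder layer, matching the suffixes that
#    SetCriterion.forward appends to the auxiliary losses.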
@@ -1,4 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-from .misc import nested_tensor_from_videos_list
+from .criterion import SetCriterion, flatten_temporal_batch_dims
+from .matcher import HungarianMatcher
+from .misc import interpolate, nested_tensor_from_videos_list
 from .mttr import MTTR
 from .postprocessing import A2DSentencesPostProcess, ReferYoutubeVOSPostProcess
@@ -0,0 +1,198 @@
+# The implementation is adopted from MTTR,
+# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR
+# Modified from DETR https://github.com/facebookresearch/detr
+import torch
+from torch import nn
+
+from .misc import (get_world_size, interpolate, is_dist_avail_and_initialized,
+                   nested_tensor_from_tensor_list)
+from .segmentation import dice_loss, sigmoid_focal_loss
+
+
+class SetCriterion(nn.Module):
+    """ This class computes the loss for MTTR.
+    The process happens in two steps:
+        1) we compute the Hungarian assignment between the ground-truth and predicted sequences.
+        2) we supervise each pair of matched ground-truth / prediction sequences (mask + reference prediction)
+    """
+
+    def __init__(self, matcher, weight_dict, eos_coef):
+        """ Create the criterion.
+        Parameters:
+            matcher: module able to compute a matching between targets and proposals
+            weight_dict: dict containing as key the names of the losses and as values their relative weight.
+            eos_coef: relative classification weight applied to the un-referred category
+        """
+        super().__init__()
+        self.matcher = matcher
+        self.weight_dict = weight_dict
+        self.eos_coef = eos_coef
+        # make sure that only loss functions with non-zero weights are computed:
+        losses_to_compute = []
+        if weight_dict['loss_dice'] > 0 or weight_dict['loss_sigmoid_focal'] > 0:
+            losses_to_compute.append('masks')
+        if weight_dict['loss_is_referred'] > 0:
+            losses_to_compute.append('is_referred')
+        self.losses = losses_to_compute
+
+    def forward(self, outputs, targets):
+        aux_outputs_list = outputs.pop('aux_outputs', None)
+
+        # compute the losses for the output of the last decoder layer:
+        losses = self.compute_criterion(
+            outputs, targets, losses_to_compute=self.losses)
+
+        # In case of auxiliary losses, we repeat this process with the output of each intermediate decoder layer.
+        if aux_outputs_list is not None:
+            aux_losses_to_compute = self.losses.copy()
+            for i, aux_outputs in enumerate(aux_outputs_list):
+                losses_dict = self.compute_criterion(aux_outputs, targets,
+                                                     aux_losses_to_compute)
+                losses_dict = {k + f'_{i}': v for k, v in losses_dict.items()}
+                losses.update(losses_dict)
+
+        return losses
+
+    def compute_criterion(self, outputs, targets, losses_to_compute):
+        # Retrieve the matching between the outputs of the last layer and the targets
+        indices = self.matcher(outputs, targets)
+
+        # T & B dims are flattened so loss functions can be computed per frame (but with same indices per video).
+        # Also, indices are repeated so the same indices can be used for frames of the same video.
+        T = len(targets)
+        outputs, targets = flatten_temporal_batch_dims(outputs, targets)
+        # repeat the indices list T times so the same indices can be used for each video frame
+        indices = T * indices
+
+        # Compute the average number of target masks across all nodes, for normalization purposes
+        num_masks = sum(len(t['masks']) for t in targets)
+        num_masks = torch.as_tensor([num_masks],
+                                    dtype=torch.float,
+                                    device=indices[0][0].device)
+        if is_dist_avail_and_initialized():
+            torch.distributed.all_reduce(num_masks)
+        num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()
+
+        # Compute all the requested losses
+        losses = {}
+        for loss in losses_to_compute:
+            losses.update(
+                self.get_loss(
+                    loss, outputs, targets, indices, num_masks=num_masks))
+        return losses
+
+    def loss_is_referred(self, outputs, targets, indices, **kwargs):
+        device = outputs['pred_is_referred'].device
+        bs = outputs['pred_is_referred'].shape[0]
+        pred_is_referred = outputs['pred_is_referred'].log_softmax(
+            dim=-1)  # note that log-softmax is used here
+        target_is_referred = torch.zeros_like(pred_is_referred)
+        # extract indices of object queries that were matched with text-referred target objects
+        query_referred_indices = self._get_query_referred_indices(
+            indices, targets)
+        # by default penalize compared to the no-object class (last token)
+        target_is_referred[:, :, :] = torch.tensor([0.0, 1.0], device=device)
+        if 'is_ref_inst_visible' in targets[0]:
+            # visibility labels are available per-frame for the referred object:
+            is_ref_inst_visible_per_frame = torch.stack(
+                [t['is_ref_inst_visible'] for t in targets])
+            ref_inst_visible_frame_indices = is_ref_inst_visible_per_frame.nonzero().squeeze()
+            # keep only the matched query indices of the frames in which the referred object is visible:
+            visible_query_referred_indices = query_referred_indices[
+                ref_inst_visible_frame_indices]
+            target_is_referred[ref_inst_visible_frame_indices,
+                               visible_query_referred_indices] = torch.tensor(
+                                   [1.0, 0.0], device=device)
+        else:  # assume that the referred object is visible in every frame:
+            target_is_referred[torch.arange(bs),
+                               query_referred_indices] = torch.tensor(
+                                   [1.0, 0.0], device=device)
+        loss = -(pred_is_referred * target_is_referred).sum(-1)
+        # apply no-object class weights:
+        eos_coef = torch.full(loss.shape, self.eos_coef, device=loss.device)
+        eos_coef[torch.arange(bs), query_referred_indices] = 1.0
+        loss = loss * eos_coef
+        bs = len(indices)
+        loss = loss.sum() / bs  # sum and normalize the loss by the batch size
+        losses = {'loss_is_referred': loss}
+        return losses
+
+    def loss_masks(self, outputs, targets, indices, num_masks, **kwargs):
+        assert 'pred_masks' in outputs
+
+        src_idx = self._get_src_permutation_idx(indices)
+        tgt_idx = self._get_tgt_permutation_idx(indices)
+        src_masks = outputs['pred_masks']
+        src_masks = src_masks[src_idx]
+        masks = [t['masks'] for t in targets]
+        target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
+        target_masks = target_masks.to(src_masks)
+        target_masks = target_masks[tgt_idx]
+
+        # upsample predictions to the target size
+        src_masks = interpolate(
+            src_masks[:, None],
+            size=target_masks.shape[-2:],
+            mode='bilinear',
+            align_corners=False)
+        src_masks = src_masks[:, 0].flatten(1)
+
+        target_masks = target_masks.flatten(1)
+        target_masks = target_masks.view(src_masks.shape)
+        losses = {
+            'loss_sigmoid_focal':
+            sigmoid_focal_loss(src_masks, target_masks, num_masks),
+            'loss_dice':
+            dice_loss(src_masks, target_masks, num_masks),
+        }
+        return losses
+
+    @staticmethod
+    def _get_src_permutation_idx(indices):
+        # permute predictions following indices
+        batch_idx = torch.cat(
+            [torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
+        src_idx = torch.cat([src for (src, _) in indices])
+        return batch_idx, src_idx
+
+    @staticmethod
+    def _get_tgt_permutation_idx(indices):
+        # permute targets following indices
+        batch_idx = torch.cat(
+            [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
+        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
+        return batch_idx, tgt_idx
+
+    @staticmethod
+    def _get_query_referred_indices(indices, targets):
+        """
+        Extract indices of object queries that were matched with text-referred target objects
+        """
+        query_referred_indices = []
+        for (query_idxs, target_idxs), target in zip(indices, targets):
+            ref_query_idx = query_idxs[torch.where(
+                target_idxs == target['referred_instance_idx'])[0]]
+            query_referred_indices.append(ref_query_idx)
+        query_referred_indices = torch.cat(query_referred_indices)
+        return query_referred_indices
+
+    def get_loss(self, loss, outputs, targets, indices, **kwargs):
+        loss_map = {
+            'masks': self.loss_masks,
+            'is_referred': self.loss_is_referred,
+        }
+        assert loss in loss_map, f'do you really want to compute {loss} loss?'
+        return loss_map[loss](outputs, targets, indices, **kwargs)
+
+
+def flatten_temporal_batch_dims(outputs, targets):
+    for k in outputs.keys():
+        if isinstance(outputs[k], torch.Tensor):
+            outputs[k] = outputs[k].flatten(0, 1)
+        else:  # list
+            outputs[k] = [i for step_t in outputs[k] for i in step_t]
+    targets = [
+        frame_t_target for step_t in targets for frame_t_target in step_t
+    ]
+    return outputs, targets
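A small trace of what flatten_temporal_batch_dims does to shapes, using toy inputs with T=2 time steps and B=3 batch samples:

import torch

outputs = {'pred_masks': torch.zeros(2, 3, 5, 16, 16)}  # [T, B, queries, H, W]
targets = [[{'t': 0}] * 3, [{'t': 1}] * 3]              # outer list: time, inner: batch
outputs, targets = flatten_temporal_batch_dims(outputs, targets)
print(outputs['pred_masks'].shape)  # torch.Size([6, 5, 16, 16]) -> T*B merged
print(len(targets))                 # 6 flat per-frame target dicts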
@@ -0,0 +1,163 @@
+# The implementation is adopted from MTTR,
+# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR
+# Modified from DETR https://github.com/facebookresearch/detr
+# Module to compute the matching cost and solve the corresponding LSAP.
+import torch
+from scipy.optimize import linear_sum_assignment
+from torch import nn
+
+from .misc import interpolate, nested_tensor_from_tensor_list
+
+
+class HungarianMatcher(nn.Module):
+    """This class computes an assignment between the targets and the predictions of the network.
+    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
+    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
+    while the others are un-matched (and thus treated as non-objects).
+    """
+
+    def __init__(self, cost_is_referred: float = 1, cost_dice: float = 1):
+        """Creates the matcher
+        Params:
+            cost_is_referred: This is the relative weight of the reference cost in the total matching cost
+            cost_dice: This is the relative weight of the dice cost in the total matching cost
+        """
+        super().__init__()
+        self.cost_is_referred = cost_is_referred
+        self.cost_dice = cost_dice
+        assert cost_is_referred != 0 or cost_dice != 0, "all costs can't be 0"
+
+    @torch.inference_mode()
+    def forward(self, outputs, targets):
+        """ Performs the matching
+        Params:
+            outputs: A dict that contains at least these entries:
+                "pred_is_referred": Tensor of dim [time, batch_size, num_queries, 2] with the reference logits
+                "pred_masks": Tensor of dim [time, batch_size, num_queries, H, W] with the predicted mask logits
+            targets: A list of lists of targets (outer - time steps, inner - batch samples). Each target is a dict
+                which contains mask and reference ground-truth information for a single frame.
+        Returns:
+            A list of size batch_size, containing tuples of (index_i, index_j) where:
+                - index_i is the indices of the selected predictions (in order)
+                - index_j is the indices of the corresponding selected targets (in order)
+            For each batch element, it holds:
+                len(index_i) = len(index_j) = min(num_queries, num_target_masks)
+        """
+        t, bs, num_queries = outputs['pred_masks'].shape[:3]
+
+        # We flatten to compute the cost matrices in a batch
+        out_masks = outputs['pred_masks'].flatten(
+            1, 2)  # [t, batch_size * num_queries, mask_h, mask_w]
+
+        # preprocess and concat the target masks
+        tgt_masks = [[
+            m for v in t_step_batch for m in v['masks'].unsqueeze(1)
+        ] for t_step_batch in targets]
+        # pad the target masks to a uniform shape
+        tgt_masks, valid = list(
+            zip(*[
+                nested_tensor_from_tensor_list(t).decompose()
+                for t in tgt_masks
+            ]))
+        tgt_masks = torch.stack(tgt_masks).squeeze(2)
+
+        # upsample predicted masks to target mask size
+        out_masks = interpolate(
+            out_masks,
+            size=tgt_masks.shape[-2:],
+            mode='bilinear',
+            align_corners=False)
+
+        # Compute the soft-tokens cost:
+        if self.cost_is_referred > 0:
+            cost_is_referred = compute_is_referred_cost(outputs, targets)
+        else:
+            cost_is_referred = 0
+
+        # Compute the DICE coefficient between the masks:
+        if self.cost_dice > 0:
+            cost_dice = -dice_coef(out_masks, tgt_masks)
+        else:
+            cost_dice = 0
+
+        # Final cost matrix
+        C = self.cost_is_referred * cost_is_referred + self.cost_dice * cost_dice
+        C = C.view(bs, num_queries, -1).cpu()
+
+        num_traj_per_batch = [
+            len(v['masks']) for v in targets[0]
+        ]  # number of instance trajectories in each batch
+        indices = [
+            linear_sum_assignment(c[i])
+            for i, c in enumerate(C.split(num_traj_per_batch, -1))
+        ]
+        device = out_masks.device
+        return [(torch.as_tensor(i, dtype=torch.int64, device=device),
+                 torch.as_tensor(j, dtype=torch.int64, device=device))
+                for i, j in indices]
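The core of the matcher is scipy's linear_sum_assignment; a toy cost matrix makes the output format concrete (4 queries, 2 target trajectories, lower cost = better match):

import numpy as np
from scipy.optimize import linear_sum_assignment

cost = np.array([[0.9, 0.1],
                 [0.4, 0.8],
                 [0.2, 0.7],
                 [0.6, 0.3]])
row_ind, col_ind = linear_sum_assignment(cost)
print(row_ind, col_ind)  # [0 2] [1 0]: query 0 -> target 1, query 2 -> target 0
# total cost 0.1 + 0.2 = 0.3 is the minimum over all 1-to-1 assignments;
# the two unmatched queries are treated as non-objects.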
+
+
+def dice_coef(inputs, targets, smooth=1.0):
+    """
+    Compute the DICE coefficient, similar to generalized IOU for masks
+    Args:
+        inputs: A float tensor of arbitrary shape.
+                The predictions for each example.
+        targets: A float tensor with the same shape as inputs. Stores the binary
+                 classification label for each element in inputs
+                 (0 for the negative class and 1 for the positive class).
+    """
+    inputs = inputs.sigmoid().flatten(2).unsqueeze(2)
+    targets = targets.flatten(2).unsqueeze(1)
+    numerator = 2 * (inputs * targets).sum(-1)
+    denominator = inputs.sum(-1) + targets.sum(-1)
+    coef = (numerator + smooth) / (denominator + smooth)
+    coef = coef.mean(0)  # average on the temporal dim to get instance trajectory scores
+    return coef
+
+
+def compute_is_referred_cost(outputs, targets):
+    pred_is_referred = outputs['pred_is_referred'].flatten(1, 2).softmax(
+        dim=-1)  # [t, b*nq, 2]
+    device = pred_is_referred.device
+    t = pred_is_referred.shape[0]
+    # number of instance trajectories in each batch
+    num_traj_per_batch = torch.tensor([len(v['masks']) for v in targets[0]],
+                                      device=device)
+    total_trajectories = num_traj_per_batch.sum()
+    # note that ref_indices are shared across time steps:
+    ref_indices = torch.tensor(
+        [v['referred_instance_idx'] for v in targets[0]], device=device)
+    # convert ref_indices to fit flattened batch targets:
+    ref_indices += torch.cat(
+        (torch.zeros(1, dtype=torch.long, device=device),
+         num_traj_per_batch.cumsum(0)[:-1]))
+    target_is_referred = torch.zeros((t, total_trajectories, 2), device=device)
+    # 'no object' class by default (for un-referred objects)
+    target_is_referred[:, :, :] = torch.tensor([0.0, 1.0], device=device)
+    if 'is_ref_inst_visible' in targets[0][0]:
+        # visibility labels are available per-frame for the referred object:
+        is_ref_inst_visible = torch.stack([
+            torch.stack([t['is_ref_inst_visible'] for t in t_step])
+            for t_step in targets
+        ]).permute(1, 0)
+        for ref_idx, is_visible in zip(ref_indices, is_ref_inst_visible):
+            is_visible = is_visible.nonzero().squeeze()
+            target_is_referred[is_visible, ref_idx, :] = torch.tensor(
+                [1.0, 0.0], device=device)
+    else:  # assume that the referred object is visible in every frame:
+        target_is_referred[:, ref_indices, :] = torch.tensor([1.0, 0.0],
+                                                             device=device)
+    cost_is_referred = -(pred_is_referred.unsqueeze(2)
+                         * target_is_referred.unsqueeze(1)).sum(dim=-1).mean(dim=0)
+    return cost_is_referred
@@ -122,8 +122,8 @@ class MultimodalTransformer(nn.Module):
         with torch.inference_mode(mode=self.freeze_text_encoder):
             encoded_text = self.text_encoder(**tokenized_queries)
         # Transpose memory because pytorch's attention expects sequence first
-        txt_memory = rearrange(encoded_text.last_hidden_state,
-                               'b s c -> s b c')
+        tmp_last_hidden_state = encoded_text.last_hidden_state.clone()
+        txt_memory = rearrange(tmp_last_hidden_state, 'b s c -> s b c')
         txt_memory = self.txt_proj(
             txt_memory)  # change text embeddings dim to model dim
         # Invert attention mask that we get from huggingface because it's the opposite in pytorch transformer
@@ -123,7 +123,8 @@ class WindowAttention3D(nn.Module):
         # define a parameter table of relative position bias
         wd, wh, ww = window_size
         self.relative_position_bias_table = nn.Parameter(
-            torch.zeros((2 * wd - 1) * (2 * wh - 1) * (2 * ww - 1), num_heads))
+            torch.zeros((2 * wd - 1) * (2 * wh - 1) * (2 * ww - 1),
+                        num_heads))  # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH

         # get pair-wise relative position index for each token inside the window
         coords_d = torch.arange(self.window_size[0])
@@ -269,8 +269,11 @@ class TinyNAS(nn.Module):
             the_block_class = block_info['class']
             if the_block_class == 'ConvKXBNRELU':
                 if use_focus:
-                    the_block = Focus(block_info['in'], block_info['out'],
-                                      block_info['k'])
+                    the_block = Focus(
+                        block_info['in'],
+                        block_info['out'],
+                        block_info['k'],
+                        act=act)
                 else:
                     the_block = ConvKXBNRELU(
                         block_info['in'],
@@ -6,6 +6,7 @@ import pickle
 import cv2
 import torch
 import torch.nn as nn
+import torchvision

 from modelscope.metainfo import Models
@@ -47,6 +48,7 @@ class SingleStageDetector(TorchModel):
         self.backbone = build_backbone(self.cfg.model.backbone)
         self.neck = build_neck(self.cfg.model.neck)
         self.head = build_head(self.cfg.model.head)
+        self.apply(self.init_bn)
         self.load_pretrain_model(model_path)
@@ -59,6 +61,12 @@ class SingleStageDetector(TorchModel):
             new_state_dict[k] = v
         self.load_state_dict(new_state_dict, strict=True)

+    def init_bn(self, M):
+        for m in M.modules():
+            if isinstance(m, nn.BatchNorm2d):
+                m.eps = 1e-3
+                m.momentum = 0.03
+
     def inference(self, x):
         if self.training:
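For context, nn.Module.apply calls a function on every submodule, so self.apply(self.init_bn) above re-tunes every BatchNorm2d in the detector (eps/momentum values like these are commonly used in YOLO-family detectors). A standalone sketch with a made-up network:

import torch.nn as nn

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())  # hypothetical

def init_bn(M):
    for m in M.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.eps = 1e-3
            m.momentum = 0.03

net.apply(init_bn)
print(net[1].eps, net[1].momentum)  # 0.001 0.03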
@@ -1,6 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os
+from os import path as osp
 from typing import Any, Dict

 import json
@@ -23,7 +24,8 @@ from modelscope.models.multi_modal.ofa import OFAModel, OFATokenizer
 from modelscope.models.multi_modal.ofa.generate import sequence_generator as sg
 from modelscope.models.multi_modal.ofa.generate.search import Sampling
 from modelscope.models.multi_modal.ofa.generate.utils import move_to_device
-from modelscope.utils.constant import Tasks
+from modelscope.utils.config import Config
+from modelscope.utils.constant import ModelFile, Tasks

 try:
     from torchvision.transforms import InterpolationMode
@@ -133,6 +135,8 @@ class OfaForTextToImageSynthesis(Model):
         super().__init__(model_dir=model_dir, *args, **kwargs)
         # Initialize ofa
         model = OFAModel.from_pretrained(model_dir)
+        self.cfg = Config.from_file(
+            osp.join(model_dir, ModelFile.CONFIGURATION))
         self.model = model.module if hasattr(model, 'module') else model
         self.tokenizer = OFATokenizer.from_pretrained(model_dir)
         self.tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)])
@@ -171,6 +175,8 @@ class OfaForTextToImageSynthesis(Model):
             'gen_code': True,
             'constraint_range': '50265,58457'
         }
+        if hasattr(self.cfg.model, 'beam_search'):
+            sg_args.update(self.cfg.model.beam_search)
         self.generator = sg.SequenceGenerator(**sg_args)

     def clip_tokenize(self, texts, context_length=77, truncate=False):
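A hedged sketch of how the new hasattr check lets the model's configuration override the generator defaults; the beam_search keys below are illustrative, not confirmed repository defaults:

# If configuration.json contains a fragment like
#   {"model": {"beam_search": {"beam_size": 5, "max_len_b": 1024}}}
# then self.cfg.model.beam_search exists and its entries extend sg_args:
sg_args = {'gen_code': True, 'constraint_range': '50265,58457'}
beam_search = {'beam_size': 5, 'max_len_b': 1024}  # hypothetical values
sg_args.update(beam_search)
# sg_args now carries both the hard-coded defaults and the per-model overrides
# before being passed to sg.SequenceGenerator(**sg_args).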
@@ -8,7 +8,6 @@ from torch import nn
 from modelscope.metainfo import Heads
 from modelscope.models.base import TorchHead
 from modelscope.models.builder import HEADS
-from modelscope.outputs import OutputKeys
 from modelscope.utils.constant import Tasks
@@ -27,9 +26,8 @@ class TextGenerationHead(TorchHead):
     def forward(self, inputs=None):
         logits = self.linear(inputs)
-        return {OutputKeys.LOGITS: logits}
+        return logits

-    def compute_loss(self, outputs: Dict[str, torch.Tensor],
-                     labels) -> Dict[str, torch.Tensor]:
-        logits = outputs[OutputKeys.LOGITS]
-        return {OutputKeys.LOSS: F.cross_entropy(logits, labels)}
+    def compute_loss(self, logits: torch.Tensor,
+                     labels) -> torch.Tensor:
+        return F.cross_entropy(logits, labels)
@@ -1,7 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from typing import Any, Dict

-import addict
 import numpy as np
 from transformers.modeling_utils import PreTrainedModel

@@ -9,7 +8,8 @@ from modelscope.metainfo import TaskModels
 from modelscope.models.builder import MODELS
 from modelscope.models.nlp.task_models.task_model import \
     SingleBackboneTaskModelBase
-from modelscope.outputs import OutputKeys
+from modelscope.outputs import (OutputKeys, TextGenerationModelOutput,
+                                TokenGeneratorOutput)
 from modelscope.utils.constant import Tasks

 __all__ = ['TaskModelForTextGeneration']
@@ -43,12 +43,12 @@ class TaskModelForTextGeneration(SingleBackboneTaskModelBase, PreTrainedModel):
         backbone_outputs = super().forward(input)
         hidden_states = backbone_outputs[0]

-        outputs = self.head.forward(hidden_states)
+        logits = self.head.forward(hidden_states)
+        loss = None
         if labels is not None:
             input[OutputKeys.LABELS] = labels
-            loss = self.compute_loss(outputs, labels)
-            outputs.update(loss)
-        return addict.Dict(outputs)
+            loss = self.compute_loss(logits, labels)
+        return TextGenerationModelOutput(logits=logits, loss=loss)

     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
         # only last token for inputs_ids if past is defined in kwargs
@@ -76,4 +76,12 @@ class TaskModelForTextGeneration(SingleBackboneTaskModelBase, PreTrainedModel):
     def generate(self, inputs, *args, **kwargs):
         input_ids = inputs['input_ids'] if isinstance(inputs, Dict) else inputs
-        return super().generate(input_ids, *args, **kwargs)
+        generate_output = super().generate(input_ids, *args, **kwargs)
+        if isinstance(generate_output, Dict):
+            return TokenGeneratorOutput(
+                sequences=generate_output.sequences,
+                scores=generate_output.scores,
+                attentions=generate_output.attentions,
+                hidden_states=generate_output.hidden_states)
+        else:
+            return TokenGeneratorOutput(sequences=generate_output)
@@ -66,7 +66,6 @@ class TokenClassificationModel(SingleBackboneTaskModelBase):
             attentions=outputs.attentions,
             offset_mapping=input['offset_mapping'],
         )
-        return outputs

     def extract_logits(self, outputs):
         return outputs[OutputKeys.LOGITS].cpu().detach()
@@ -288,8 +288,8 @@ class InvariantPointAttention(nn.Module):
         pt_att *= pt_att
         pt_att = pt_att.sum(dim=-1)
-        head_weights = self.softplus(self.head_weights).view(
-            *((1, ) * len(pt_att.shape[:-2]) + (-1, 1)))
+        head_weights = self.softplus(self.head_weights).view(  # noqa
+            *((1, ) * len(pt_att.shape[:-2]) + (-1, 1)))  # noqa
         head_weights = head_weights * math.sqrt(
             1.0 / (3 * (self.num_qk_points * 9.0 / 2)))
         pt_att *= head_weights * (-0.5)
@@ -20,13 +20,15 @@ from modelscope.msdatasets.task_datasets.builder import build_task_dataset
 from modelscope.msdatasets.utils.dataset_builder import ExternalDataset
 from modelscope.msdatasets.utils.dataset_utils import (
     get_dataset_files, get_target_dataset_structure, load_dataset_builder)
+from modelscope.msdatasets.utils.delete_utils import DatasetDeleteManager
 from modelscope.msdatasets.utils.download_utils import DatasetDownloadManager
 from modelscope.msdatasets.utils.upload_utils import DatasetUploadManager
 from modelscope.utils.config import ConfigDict
 from modelscope.utils.config_ds import MS_DATASETS_CACHE
 from modelscope.utils.constant import (DEFAULT_DATASET_NAMESPACE,
                                        DEFAULT_DATASET_REVISION,
-                                       DatasetFormations, DownloadMode, Hubs)
+                                       DatasetFormations, DownloadMode, Hubs,
+                                       UploadMode)
 from modelscope.utils.logger import get_logger

 logger = get_logger()
@@ -576,15 +578,17 @@ class MsDataset:
         return self._hf_ds.rename_columns(column_mapping)

     @staticmethod
-    def upload(object_name: str,
-               local_file_path: str,
-               dataset_name: str,
-               namespace: Optional[str] = DEFAULT_DATASET_NAMESPACE,
-               version: Optional[str] = DEFAULT_DATASET_REVISION,
-               num_processes: Optional[int] = None,
-               chunksize: Optional[int] = 1,
-               filter_hidden_files: Optional[bool] = True) -> None:
-        """Upload dataset file or directory to the ModelScope Hub. Please login to the ModelScope Hub first.
+    def upload(
+            object_name: str,
+            local_file_path: str,
+            dataset_name: str,
+            namespace: Optional[str] = DEFAULT_DATASET_NAMESPACE,
+            version: Optional[str] = DEFAULT_DATASET_REVISION,
+            num_processes: Optional[int] = None,
+            chunksize: Optional[int] = 1,
+            filter_hidden_files: Optional[bool] = True,
+            upload_mode: Optional[UploadMode] = UploadMode.OVERWRITE) -> None:
+        """Upload a dataset file or directory to the ModelScope Hub. Please log in to the ModelScope Hub first.

         Args:
             object_name (str): The object name on ModelScope, in the form of your-dataset-name.zip or your-dataset-name
@@ -592,7 +596,7 @@ class MsDataset:
             dataset_name (str): Name of the dataset
             namespace(str, optional): Namespace of the dataset
             version: Optional[str]: Version of the dataset
-            num_processes: Optional[int]: The number of processes used for multi-process uploading.
+            num_processes: Optional[int]: The number of processes used for multiprocess uploading.
                 This is only applicable when local_file_path is a directory, and we are uploading multiple files
                 inside the directory. When None provided, the number returned by os.cpu_count() is used as default.
             chunksize: Optional[int]: The chunksize of objects to upload.
@@ -600,24 +604,34 @@ class MsDataset:
                 using the default value of 1. Available if local_file_path is a directory.
             filter_hidden_files: Optional[bool]: Whether to filter hidden files.
                 Available if local_file_path is a directory.
+            upload_mode: Optional[UploadMode]: How to upload objects from local. Default: UploadMode.OVERWRITE, upload
+                all objects from local; existing remote objects may be overwritten.

         Returns:
             None
         """
         if not object_name:
             raise ValueError('object_name cannot be empty!')

         _upload_manager = DatasetUploadManager(
             dataset_name=dataset_name, namespace=namespace, version=version)

+        upload_mode = UploadMode(upload_mode or UploadMode.OVERWRITE)
+
         if os.path.isfile(local_file_path):
             _upload_manager.upload(
-                object_name=object_name, local_file_path=local_file_path)
+                object_name=object_name,
+                local_file_path=local_file_path,
+                upload_mode=upload_mode)
         elif os.path.isdir(local_file_path):
             _upload_manager.upload_dir(
                 object_dir_name=object_name,
                 local_dir_path=local_file_path,
                 num_processes=num_processes,
                 chunksize=chunksize,
-                filter_hidden_files=filter_hidden_files)
+                filter_hidden_files=filter_hidden_files,
+                upload_mode=upload_mode)
         else:
             raise ValueError(
                 f'{local_file_path} is not a valid file path or directory')
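A hedged usage sketch of the extended upload API; the dataset name, namespace, and local path are made up:

from modelscope.msdatasets import MsDataset
from modelscope.utils.constant import UploadMode

MsDataset.upload(
    object_name='train/',                    # remote directory (note the trailing '/')
    local_file_path='/tmp/my_dataset/train',  # hypothetical local directory
    dataset_name='my-dataset',
    namespace='my-namespace',
    upload_mode=UploadMode.OVERWRITE)        # the default; may overwrite remote objects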
@@ -672,7 +686,7 @@ class MsDataset:
             revision of the model you want to clone from. Can be any of a branch, tag or commit hash
         auth_token(`Optional[str]`):
             token obtained when calling `HubApi.login()`. Usually you can safely ignore the parameter
-            as the token is already saved when you login the first time, if None, we will use saved token.
+            as the token is already saved when you log in the first time. If None, we will use the saved token.
         git_path:(`Optional[str]`):
             The git command line path, if None, we use 'git'
         force (Optional[bool]): whether to use forced-push.
@@ -687,8 +701,29 @@ class MsDataset:
             revision=revision,
             auth_token=auth_token,
             git_path=git_path)
-        _repo.push(
-            commit_message=commit_message,
-            local_branch=revision,
-            remote_branch=revision,
-            force=force)
+        _repo.push(commit_message=commit_message, branch=revision, force=force)

+    @staticmethod
+    def delete(object_name: str,
+               dataset_name: str,
+               namespace: Optional[str] = DEFAULT_DATASET_NAMESPACE,
+               version: Optional[str] = DEFAULT_DATASET_REVISION) -> str:
+        """Delete an object of the dataset. Please log in first and make sure you have permission to manage
+        the dataset.
+
+        Args:
+            object_name (str): The object name of the dataset to be deleted. Could be the name of a file or a
+                directory; if it's a directory, it ends with `/`.
+                For example: your-data-name.zip, train/001/img_001.png, train/, ...
+            dataset_name (str): Path or name of the dataset.
+            namespace(str, optional): Namespace of the dataset.
+            version (str, optional): Version of the dataset.
+
+        Returns:
+            res_msg (str): Response message.
+        """
+        _delete_manager = DatasetDeleteManager(
+            dataset_name=dataset_name, namespace=namespace, version=version)
+        resp_msg = _delete_manager.delete(object_name=object_name)
+        logger.info(f'Object {object_name} successfully removed!')
+        return resp_msg
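A hedged usage sketch mirroring the docstring above; all names are made up:

from modelscope.msdatasets import MsDataset

msg = MsDataset.delete(
    object_name='train/001/img_001.png',  # a single file; use 'train/' for a directory
    dataset_name='my-dataset',
    namespace='my-namespace')
print(msg)  # server response message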
@@ -13,6 +13,7 @@ if TYPE_CHECKING:
     from .video_summarization_dataset import VideoSummarizationDataset
     from .image_inpainting import ImageInpaintingDataset
     from .text_ranking_dataset import TextRankingDataset
+    from .referring_video_object_segmentation import ReferringVideoObjectSegmentationDataset

 else:
     _import_structure = {
@@ -29,6 +30,8 @@ else:
         'sidd_image_denoising_dataset': ['SiddImageDenoisingDataset'],
         'image_portrait_enhancement_dataset':
         ['ImagePortraitEnhancementDataset'],
+        'referring_video_object_segmentation':
+        ['ReferringVideoObjectSegmentationDataset'],
     }

     import sys
@@ -0,0 +1,3 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from .referring_video_object_segmentation_dataset import \
+    ReferringVideoObjectSegmentationDataset
| @@ -0,0 +1,361 @@ | |||
| # Part of the implementation is borrowed and modified from MTTR, | |||
| # publicly available at https://github.com/mttr2021/MTTR | |||
| from glob import glob | |||
| from os import path as osp | |||
| import h5py | |||
| import json | |||
| import numpy as np | |||
| import pandas | |||
| import torch | |||
| import torch.distributed as dist | |||
| import torchvision.transforms.functional as F | |||
| from pycocotools.mask import area, encode | |||
| from torchvision.io import read_video | |||
| from tqdm import tqdm | |||
| from modelscope.metainfo import Models | |||
| from modelscope.models.cv.referring_video_object_segmentation.utils import \ | |||
| nested_tensor_from_videos_list | |||
| from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS | |||
| from modelscope.msdatasets.task_datasets.torch_base_dataset import \ | |||
| TorchTaskDataset | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| from . import transformers as T | |||
| LOGGER = get_logger() | |||
| def get_image_id(video_id, frame_idx, ref_instance_a2d_id): | |||
| image_id = f'v_{video_id}_f_{frame_idx}_i_{ref_instance_a2d_id}' | |||
| return image_id | |||
| @TASK_DATASETS.register_module( | |||
| Tasks.referring_video_object_segmentation, | |||
| module_name=Models.referring_video_object_segmentation) | |||
| class ReferringVideoObjectSegmentationDataset(TorchTaskDataset): | |||
| def __init__(self, **kwargs): | |||
| split_config = kwargs['split_config'] | |||
| LOGGER.info(kwargs) | |||
| data_cfg = kwargs.get('cfg').data_kwargs | |||
| trans_cfg = kwargs.get('cfg').transformers_kwargs | |||
| distributed = data_cfg.get('distributed', False) | |||
| self.data_root = next(iter(split_config.values())) | |||
| if not osp.exists(self.data_root): | |||
| self.data_root = osp.dirname(self.data_root) | |||
| assert osp.exists(self.data_root) | |||
| self.window_size = data_cfg.get('window_size', 8) | |||
| self.mask_annotations_dir = osp.join( | |||
| self.data_root, 'text_annotations/annotation_with_instances') | |||
| self.videos_dir = osp.join(self.data_root, 'Release/CLIPS320') | |||
| self.subset_type = next(iter(split_config.keys())) | |||
| self.text_annotations = self.get_text_annotations( | |||
| self.data_root, self.subset_type, distributed) | |||
| self.transforms = A2dSentencesTransforms(self.subset_type, **trans_cfg) | |||
| self.collator = Collator() | |||
| self.ann_file = osp.join( | |||
| self.data_root, | |||
| data_cfg.get('ann_file', | |||
| 'a2d_sentences_test_annotations_in_coco_format.json')) | |||
| # create ground-truth test annotations for the evaluation process if necessary: | |||
| if self.subset_type == 'test' and not osp.exists(self.ann_file): | |||
| if (distributed and dist.get_rank() == 0) or not distributed: | |||
| create_a2d_sentences_ground_truth_test_annotations( | |||
| self.data_root, self.subset_type, | |||
| self.mask_annotations_dir, self.ann_file) | |||
| if distributed: | |||
| dist.barrier() | |||
| def __len__(self): | |||
| return len(self.text_annotations) | |||
| def __getitem__(self, idx): | |||
| text_query, video_id, frame_idx, instance_id = self.text_annotations[ | |||
| idx] | |||
| text_query = ' '.join( | |||
| text_query.lower().split()) # clean up the text query | |||
| # read the source window frames: | |||
| video_frames, _, _ = read_video( | |||
| osp.join(self.videos_dir, f'{video_id}.mp4'), | |||
| pts_unit='sec') # (T, H, W, C) | |||
| # get a window of window_size frames with frame frame_idx in the middle. | |||
| # note that the original a2d dataset is 1 indexed, so we have to subtract 1 from frame_idx | |||
| start_idx, end_idx = frame_idx - 1 - self.window_size // 2, frame_idx - 1 + ( | |||
| self.window_size + 1) // 2 | |||
| # extract the window source frames: | |||
| source_frames = [] | |||
| for i in range(start_idx, end_idx): | |||
| i = min(max(i, 0), | |||
| len(video_frames) | |||
| - 1) # pad out of range indices with edge frames | |||
| source_frames.append( | |||
| F.to_pil_image(video_frames[i].permute(2, 0, 1))) | |||
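| # e.g. with window_size=8 and frame_idx=5 (1-indexed): start_idx=0, end_idx=8, | |||
| # so 0-indexed frames 0..7 are read and the annotated frame (index 4) sits mid-window. | |||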
| # read the instance mask: | |||
| frame_annot_path = osp.join(self.mask_annotations_dir, video_id, | |||
| f'{frame_idx:05d}.h5') | |||
| f = h5py.File(frame_annot_path, 'r') | |||
| instances = list(f['instance']) | |||
| instance_idx = instances.index( | |||
| instance_id) # existence was already validated during init | |||
| instance_masks = np.array(f['reMask']) | |||
| if len(instances) == 1: | |||
| instance_masks = instance_masks[np.newaxis, ...] | |||
| instance_masks = torch.tensor(instance_masks).transpose(1, 2) | |||
| mask_rles = [encode(mask) for mask in instance_masks.numpy()] | |||
| mask_areas = area(mask_rles).astype(float)  # the np.float alias is removed in modern NumPy | |||
| f.close() | |||
| # create the target dict for the center frame: | |||
| target = { | |||
| 'masks': instance_masks, | |||
| 'orig_size': instance_masks.shape[-2:],  # original frame shape without any augmentations | |||
| # size with augmentations, will be changed inside transforms if necessary | |||
| 'size': instance_masks.shape[-2:], | |||
| 'referred_instance_idx': torch.tensor( | |||
| instance_idx), # idx in 'masks' of the text referred instance | |||
| 'area': torch.tensor(mask_areas), | |||
| 'iscrowd': torch.zeros(len(instance_masks)),  # for compatibility with DETR COCO transforms | |||
| 'image_id': get_image_id(video_id, frame_idx, instance_id) | |||
| } | |||
| # create dummy targets for adjacent frames: | |||
| targets = self.window_size * [None] | |||
| center_frame_idx = self.window_size // 2 | |||
| targets[center_frame_idx] = target | |||
| source_frames, targets, text_query = self.transforms( | |||
| source_frames, targets, text_query) | |||
| return source_frames, targets, text_query | |||
| @staticmethod | |||
| def get_text_annotations(root_path, subset, distributed): | |||
| saved_annotations_file_path = osp.join( | |||
| root_path, f'sentences_single_frame_{subset}_annotations.json') | |||
| if osp.exists(saved_annotations_file_path): | |||
| with open(saved_annotations_file_path, 'r') as f: | |||
| text_annotations_by_frame = [tuple(a) for a in json.load(f)] | |||
| return text_annotations_by_frame | |||
| elif (distributed and dist.get_rank() == 0) or not distributed: | |||
| print(f'building a2d sentences {subset} text annotations...') | |||
| # without header=None pandas would treat the first row as a header and drop the first sample | |||
| a2d_data_info = pandas.read_csv( | |||
| osp.join(root_path, 'Release/videoset.csv'), header=None) | |||
| # 'vid', 'label', 'start_time', 'end_time', 'height', 'width', 'total_frames', 'annotated_frames', 'subset' | |||
| a2d_data_info.columns = [ | |||
| 'vid', '', '', '', '', '', '', '', 'subset' | |||
| ] | |||
| with open( | |||
| osp.join(root_path, 'text_annotations/missed_videos.txt'), | |||
| 'r') as f: | |||
| unused_videos = f.read().splitlines() | |||
| subsets = {'train': 0, 'test': 1} | |||
| # filter unused videos and videos which do not belong to our train/test subset: | |||
| used_videos = a2d_data_info[ | |||
| ~a2d_data_info.vid.isin(unused_videos) | |||
| & (a2d_data_info.subset == subsets[subset])] | |||
| used_videos_ids = list(used_videos['vid']) | |||
| text_annotations = pandas.read_csv( | |||
| osp.join(root_path, 'text_annotations/annotation.txt')) | |||
| # filter the text annotations based on the used videos: | |||
| used_text_annotations = text_annotations[ | |||
| text_annotations.video_id.isin(used_videos_ids)] | |||
| # remove a single dataset annotation mistake in video: T6bNPuKV-wY | |||
| used_text_annotations = used_text_annotations[ | |||
| used_text_annotations['instance_id'] != '1 (copy)'] | |||
| # convert data-frame to list of tuples: | |||
| used_text_annotations = list( | |||
| used_text_annotations.to_records(index=False)) | |||
| text_annotations_by_frame = [] | |||
| mask_annotations_dir = osp.join( | |||
| root_path, 'text_annotations/annotation_with_instances') | |||
| for video_id, instance_id, text_query in tqdm( | |||
| used_text_annotations): | |||
| frame_annot_paths = sorted( | |||
| glob(osp.join(mask_annotations_dir, video_id, '*.h5'))) | |||
| instance_id = int(instance_id) | |||
| for p in frame_annot_paths: | |||
| f = h5py.File(p, 'r') | |||
| instances = list(f['instance']) | |||
| if instance_id in instances: | |||
| # if this instance does not appear in this frame it has no ground-truth mask, and thus this | |||
| # frame-instance pair is ignored in evaluation, as in the SOTA method CMPC-V; see: | |||
| # https://github.com/spyflying/CMPC-Refseg/blob/094639b8bf00cc169ea7b49cdf9c87fdfc70d963/CMPC_video/build_A2D_batches.py#L98 | |||
| frame_idx = int(p.split('/')[-1].split('.')[0]) | |||
| text_query = text_query.lower( | |||
| ) # lower the text query prior to augmentation & tokenization | |||
| text_annotations_by_frame.append( | |||
| (text_query, video_id, frame_idx, instance_id)) | |||
| with open(saved_annotations_file_path, 'w') as f: | |||
| json.dump(text_annotations_by_frame, f) | |||
| if distributed: | |||
| dist.barrier() | |||
| with open(saved_annotations_file_path, 'r') as f: | |||
| text_annotations_by_frame = [tuple(a) for a in json.load(f)] | |||
| return text_annotations_by_frame | |||
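| # Editor's sketch (not part of the PR): wiring this dataset and its collator into | |||
| # a PyTorch DataLoader; `cfg` and the data root are hypothetical placeholders. | |||
| # from torch.utils.data import DataLoader | |||
| # dataset = ReferringVideoObjectSegmentationDataset( | |||
| #     split_config={'train': '/path/to/a2d_sentences'}, cfg=cfg) | |||
| # loader = DataLoader(dataset, batch_size=2, collate_fn=dataset.collator) | |||
| # for batch in loader: | |||
| #     samples, targets, queries = batch['samples'], batch['targets'], batch['text_queries'] | |||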
| class A2dSentencesTransforms: | |||
| def __init__(self, subset_type, horizontal_flip_augmentations, | |||
| resize_and_crop_augmentations, train_short_size, | |||
| train_max_size, eval_short_size, eval_max_size, **kwargs): | |||
| self.h_flip_augmentation = subset_type == 'train' and horizontal_flip_augmentations | |||
| normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) | |||
| scales = [ | |||
| train_short_size | |||
| ] # no more scales for now due to GPU memory constraints. might be changed later | |||
| transforms = [] | |||
| if resize_and_crop_augmentations: | |||
| if subset_type == 'train': | |||
| transforms.append( | |||
| T.RandomResize(scales, max_size=train_max_size)) | |||
| elif subset_type == 'test': | |||
| transforms.append( | |||
| T.RandomResize([eval_short_size], max_size=eval_max_size)) | |||
| transforms.extend([T.ToTensor(), normalize]) | |||
| self.size_transforms = T.Compose(transforms) | |||
| def __call__(self, source_frames, targets, text_query): | |||
| if self.h_flip_augmentation and torch.rand(1) > 0.5: | |||
| source_frames = [F.hflip(f) for f in source_frames] | |||
| targets[len(targets) // 2]['masks'] = F.hflip( | |||
| targets[len(targets) // 2]['masks']) | |||
| # Note - it is possible for both 'right' and 'left' to appear together in the same query, hence this fix: | |||
| text_query = text_query.replace('left', '@').replace( | |||
| 'right', 'left').replace('@', 'right') | |||
| source_frames, targets = list( | |||
| zip(*[ | |||
| self.size_transforms(f, t) | |||
| for f, t in zip(source_frames, targets) | |||
| ])) | |||
| source_frames = torch.stack(source_frames) # [T, 3, H, W] | |||
| return source_frames, targets, text_query | |||
| class Collator: | |||
| def __call__(self, batch): | |||
| samples, targets, text_queries = list(zip(*batch)) | |||
| samples = nested_tensor_from_videos_list(samples) # [T, B, C, H, W] | |||
| # convert targets to a list of tuples. outer list - time steps, inner tuples - time step batch | |||
| targets = list(zip(*targets)) | |||
| batch_dict = { | |||
| 'samples': samples, | |||
| 'targets': targets, | |||
| 'text_queries': text_queries | |||
| } | |||
| return batch_dict | |||
| def get_text_annotations_gt(root_path, subset): | |||
| # without header=None pandas would treat the first row as a header and drop the first sample | |||
| a2d_data_info = pandas.read_csv( | |||
| osp.join(root_path, 'Release/videoset.csv'), header=None) | |||
| # 'vid', 'label', 'start_time', 'end_time', 'height', 'width', 'total_frames', 'annotated_frames', 'subset' | |||
| a2d_data_info.columns = ['vid', '', '', '', '', '', '', '', 'subset'] | |||
| with open(osp.join(root_path, 'text_annotations/missed_videos.txt'), | |||
| 'r') as f: | |||
| unused_videos = f.read().splitlines() | |||
| subsets = {'train': 0, 'test': 1} | |||
| # filter unused videos and videos which do not belong to our train/test subset: | |||
| used_videos = a2d_data_info[~a2d_data_info.vid.isin(unused_videos) | |||
| & (a2d_data_info.subset == subsets[subset])] | |||
| used_videos_ids = list(used_videos['vid']) | |||
| text_annotations = pandas.read_csv( | |||
| osp.join(root_path, 'text_annotations/annotation.txt')) | |||
| # filter the text annotations based on the used videos: | |||
| used_text_annotations = text_annotations[text_annotations.video_id.isin( | |||
| used_videos_ids)] | |||
| # convert data-frame to list of tuples: | |||
| used_text_annotations = list(used_text_annotations.to_records(index=False)) | |||
| return used_text_annotations | |||
| def create_a2d_sentences_ground_truth_test_annotations(dataset_path, | |||
| subset_type, | |||
| mask_annotations_dir, | |||
| output_path): | |||
| text_annotations = get_text_annotations_gt(dataset_path, subset_type) | |||
| # Note - it is very important to start counting the instance and category ids from 1 (not 0). This is implicitly | |||
| # expected by pycocotools as it is the convention of the original coco dataset annotations. | |||
| categories_dict = [{ | |||
| 'id': 1, | |||
| 'name': 'dummy_class' | |||
| }] # dummy class, as categories are not used/predicted in RVOS | |||
| images_dict = [] | |||
| annotations_dict = [] | |||
| images_set = set() | |||
| instance_id_counter = 1 | |||
| for annot in tqdm(text_annotations): | |||
| video_id, instance_id, text_query = annot | |||
| annot_paths = sorted( | |||
| glob(osp.join(mask_annotations_dir, video_id, '*.h5'))) | |||
| for p in annot_paths: | |||
| f = h5py.File(p, 'r') | |||
| instances = list(f['instance']) | |||
| try: | |||
| instance_idx = instances.index(int(instance_id)) | |||
| # if this instance does not appear in this frame it has no ground-truth mask, and thus this | |||
| # frame-instance pair is ignored in evaluation, as in the SOTA method CMPC-V; see: | |||
| # https://github.com/spyflying/CMPC-Refseg/blob/094639b8bf00cc169ea7b49cdf9c87fdfc70d963/CMPC_video/build_A2D_batches.py#L98 | |||
| except ValueError: | |||
| continue # instance_id does not appear in current frame | |||
| mask = f['reMask'][instance_idx] if len( | |||
| instances) > 1 else np.array(f['reMask']) | |||
| mask = mask.transpose() | |||
| frame_idx = int(p.split('/')[-1].split('.')[0]) | |||
| image_id = get_image_id(video_id, frame_idx, instance_id) | |||
| assert image_id not in images_set, f'error: image id: {image_id} appeared twice' | |||
| images_set.add(image_id) | |||
| images_dict.append({ | |||
| 'id': image_id, | |||
| 'height': mask.shape[0], | |||
| 'width': mask.shape[1] | |||
| }) | |||
| mask_rle = encode(mask) | |||
| mask_rle['counts'] = mask_rle['counts'].decode('ascii') | |||
| mask_area = float(area(mask_rle)) | |||
| bbox = f['reBBox'][:, instance_idx] if len( | |||
| instances) > 1 else np.array( | |||
| f['reBBox']).squeeze() # x1y1x2y2 form | |||
| bbox_xywh = [ | |||
| bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1] | |||
| ] | |||
| instance_annot = { | |||
| 'id': instance_id_counter, | |||
| 'image_id': image_id, | |||
| 'category_id': | |||
| 1, # dummy class, as categories are not used/predicted in ref-vos | |||
| 'segmentation': mask_rle, | |||
| 'area': mask_area, | |||
| 'bbox': bbox_xywh, | |||
| 'iscrowd': 0, | |||
| } | |||
| annotations_dict.append(instance_annot) | |||
| instance_id_counter += 1 | |||
| dataset_dict = { | |||
| 'categories': categories_dict, | |||
| 'images': images_dict, | |||
| 'annotations': annotations_dict | |||
| } | |||
| with open(output_path, 'w') as f: | |||
| json.dump(dataset_dict, f) | |||
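| # Editor's note: the dumped file is a standard COCO-format annotations JSON, so it | |||
| # can be loaded for evaluation with pycocotools, e.g.: | |||
| # from pycocotools.coco import COCO | |||
| # coco_gt = COCO(output_path) | |||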
| @@ -0,0 +1,294 @@ | |||
| # The implementation is adopted from MTTR, | |||
| # made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR | |||
| # Modified from DETR https://github.com/facebookresearch/detr | |||
| import random | |||
| import PIL | |||
| import torch | |||
| import torchvision.transforms as T | |||
| import torchvision.transforms.functional as F | |||
| # Normalize below also uses box_xyxy_to_cxcywh; it is assumed to be exported by | |||
| # the same utils module (as in DETR's util.box_ops). | |||
| from modelscope.models.cv.referring_video_object_segmentation.utils import ( | |||
| box_xyxy_to_cxcywh, interpolate) | |||
| def crop(image, target, region): | |||
| cropped_image = F.crop(image, *region) | |||
| target = target.copy() | |||
| i, j, h, w = region  # (top, left, height, width), as returned by T.RandomCrop.get_params | |||
| # should we do something wrt the original size? | |||
| target['size'] = torch.tensor([h, w]) | |||
| fields = ['labels', 'area', 'iscrowd'] | |||
| if 'boxes' in target: | |||
| boxes = target['boxes'] | |||
| max_size = torch.as_tensor([w, h], dtype=torch.float32) | |||
| cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) | |||
| cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) | |||
| cropped_boxes = cropped_boxes.clamp(min=0) | |||
| area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) | |||
| target['boxes'] = cropped_boxes.reshape(-1, 4) | |||
| target['area'] = area | |||
| fields.append('boxes') | |||
| if 'masks' in target: | |||
| # FIXME should we update the area here if there are no boxes? | |||
| target['masks'] = target['masks'][:, i:i + h, j:j + w] | |||
| fields.append('masks') | |||
| # remove elements whose boxes or masks have zero area | |||
| if 'boxes' in target or 'masks' in target: | |||
| # favor boxes selection when defining which elements to keep | |||
| # this is compatible with previous implementation | |||
| if 'boxes' in target: | |||
| cropped_boxes = target['boxes'].reshape(-1, 2, 2) | |||
| keep = torch.all( | |||
| cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) | |||
| else: | |||
| keep = target['masks'].flatten(1).any(1) | |||
| for field in fields: | |||
| target[field] = target[field][keep] | |||
| return cropped_image, target | |||
| def hflip(image, target): | |||
| flipped_image = F.hflip(image) | |||
| w, h = image.size | |||
| target = target.copy() | |||
| if 'boxes' in target: | |||
| boxes = target['boxes'] | |||
| boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor( | |||
| [-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0]) | |||
| target['boxes'] = boxes | |||
| if 'masks' in target: | |||
| target['masks'] = target['masks'].flip(-1) | |||
| return flipped_image, target | |||
| def resize(image, target, size, max_size=None): | |||
| # size can be min_size (scalar) or (w, h) tuple | |||
| def get_size_with_aspect_ratio(image_size, size, max_size=None): | |||
| w, h = image_size | |||
| if max_size is not None: | |||
| min_original_size = float(min((w, h))) | |||
| max_original_size = float(max((w, h))) | |||
| if max_original_size / min_original_size * size > max_size: | |||
| size = int( | |||
| round(max_size * min_original_size / max_original_size)) | |||
| if (w <= h and w == size) or (h <= w and h == size): | |||
| return (h, w) | |||
| if w < h: | |||
| ow = size | |||
| oh = int(size * h / w) | |||
| else: | |||
| oh = size | |||
| ow = int(size * w / h) | |||
| return (oh, ow) | |||
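| # e.g. a 640x480 image with size=360 and max_size=640: the scaled long side would be | |||
| # 640/480*360 = 480 <= 640, so size stays 360 and the returned (oh, ow) is (360, 480). | |||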
| def get_size(image_size, size, max_size=None): | |||
| if isinstance(size, (list, tuple)): | |||
| return size[::-1] | |||
| else: | |||
| return get_size_with_aspect_ratio(image_size, size, max_size) | |||
| size = get_size(image.size, size, max_size) | |||
| rescaled_image = F.resize(image, size) | |||
| if target is None: | |||
| return rescaled_image, None | |||
| ratios = tuple( | |||
| float(s) / float(s_orig) | |||
| for s, s_orig in zip(rescaled_image.size, image.size)) | |||
| ratio_width, ratio_height = ratios | |||
| target = target.copy() | |||
| if 'boxes' in target: | |||
| boxes = target['boxes'] | |||
| scaled_boxes = boxes * torch.as_tensor( | |||
| [ratio_width, ratio_height, ratio_width, ratio_height]) | |||
| target['boxes'] = scaled_boxes | |||
| if 'area' in target: | |||
| area = target['area'] | |||
| scaled_area = area * (ratio_width * ratio_height) | |||
| target['area'] = scaled_area | |||
| h, w = size | |||
| target['size'] = torch.tensor([h, w]) | |||
| if 'masks' in target: | |||
| target['masks'] = interpolate( | |||
| target['masks'][:, None].float(), size, mode='nearest')[:, 0] > 0.5 | |||
| return rescaled_image, target | |||
| def pad(image, target, padding): | |||
| # assumes that we only pad on the bottom right corners | |||
| padded_image = F.pad(image, (0, 0, padding[0], padding[1])) | |||
| if target is None: | |||
| return padded_image, None | |||
| target = target.copy() | |||
| # should we do something wrt the original size? | |||
| target['size'] = torch.tensor(padded_image.size[::-1]) | |||
| if 'masks' in target: | |||
| target['masks'] = torch.nn.functional.pad( | |||
| target['masks'], (0, padding[0], 0, padding[1])) | |||
| return padded_image, target | |||
| class RandomCrop(object): | |||
| def __init__(self, size): | |||
| self.size = size | |||
| def __call__(self, img, target): | |||
| region = T.RandomCrop.get_params(img, self.size) | |||
| return crop(img, target, region) | |||
| class RandomSizeCrop(object): | |||
| def __init__(self, min_size: int, max_size: int): | |||
| self.min_size = min_size | |||
| self.max_size = max_size | |||
| def __call__(self, img: PIL.Image.Image, target: dict): | |||
| w = random.randint(self.min_size, min(img.width, self.max_size)) | |||
| h = random.randint(self.min_size, min(img.height, self.max_size)) | |||
| region = T.RandomCrop.get_params(img, [h, w]) | |||
| return crop(img, target, region) | |||
| class CenterCrop(object): | |||
| def __init__(self, size): | |||
| self.size = size | |||
| def __call__(self, img, target): | |||
| image_width, image_height = img.size | |||
| crop_height, crop_width = self.size | |||
| crop_top = int(round((image_height - crop_height) / 2.)) | |||
| crop_left = int(round((image_width - crop_width) / 2.)) | |||
| return crop(img, target, | |||
| (crop_top, crop_left, crop_height, crop_width)) | |||
| class RandomHorizontalFlip(object): | |||
| def __init__(self, p=0.5): | |||
| self.p = p | |||
| def __call__(self, img, target): | |||
| if random.random() < self.p: | |||
| return hflip(img, target) | |||
| return img, target | |||
| class RandomResize(object): | |||
| def __init__(self, sizes, max_size=None): | |||
| assert isinstance(sizes, (list, tuple)) | |||
| self.sizes = sizes | |||
| self.max_size = max_size | |||
| def __call__(self, img, target=None): | |||
| size = random.choice(self.sizes) | |||
| return resize(img, target, size, self.max_size) | |||
| class RandomPad(object): | |||
| def __init__(self, max_pad): | |||
| self.max_pad = max_pad | |||
| def __call__(self, img, target): | |||
| pad_x = random.randint(0, self.max_pad) | |||
| pad_y = random.randint(0, self.max_pad) | |||
| return pad(img, target, (pad_x, pad_y)) | |||
| class RandomSelect(object): | |||
| """ | |||
| Randomly selects between transforms1 and transforms2, | |||
| with probability p for transforms1 and (1 - p) for transforms2 | |||
| """ | |||
| def __init__(self, transforms1, transforms2, p=0.5): | |||
| self.transforms1 = transforms1 | |||
| self.transforms2 = transforms2 | |||
| self.p = p | |||
| def __call__(self, img, target): | |||
| if random.random() < self.p: | |||
| return self.transforms1(img, target) | |||
| return self.transforms2(img, target) | |||
| class ToTensor(object): | |||
| def __call__(self, img, target): | |||
| return F.to_tensor(img), target | |||
| class RandomErasing(object): | |||
| def __init__(self, *args, **kwargs): | |||
| self.eraser = T.RandomErasing(*args, **kwargs) | |||
| def __call__(self, img, target): | |||
| return self.eraser(img), target | |||
| class Normalize(object): | |||
| def __init__(self, mean, std): | |||
| self.mean = mean | |||
| self.std = std | |||
| def __call__(self, image, target=None): | |||
| image = F.normalize(image, mean=self.mean, std=self.std) | |||
| if target is None: | |||
| return image, None | |||
| target = target.copy() | |||
| h, w = image.shape[-2:] | |||
| if 'boxes' in target: | |||
| boxes = target['boxes'] | |||
| boxes = box_xyxy_to_cxcywh(boxes) | |||
| boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32) | |||
| target['boxes'] = boxes | |||
| return image, target | |||
| class Compose(object): | |||
| def __init__(self, transforms): | |||
| self.transforms = transforms | |||
| def __call__(self, image, target): | |||
| for t in self.transforms: | |||
| image, target = t(image, target) | |||
| return image, target | |||
| def __repr__(self): | |||
| format_string = self.__class__.__name__ + '(' | |||
| for t in self.transforms: | |||
| format_string += '\n' | |||
| format_string += ' {0}'.format(t) | |||
| format_string += '\n)' | |||
| return format_string | |||
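| # Editor's illustration of composing these transforms (mirrors what | |||
| # A2dSentencesTransforms builds above; sizes are hypothetical): | |||
| # transform = Compose([ | |||
| #     RandomResize([360], max_size=640), | |||
| #     ToTensor(), | |||
| #     Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), | |||
| # ]) | |||
| # image, target = transform(pil_image, target) | |||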
| @@ -82,7 +82,7 @@ def list_dataset_objects(hub_api: HubApi, max_limit: int, is_recursive: bool, | |||
| dataset_name: str, namespace: str, | |||
| version: str) -> list: | |||
| """ | |||
| List all of objects for specific dataset. | |||
| List all objects for specific dataset. | |||
| Args: | |||
| hub_api (class HubApi): HubApi instance. | |||
| @@ -0,0 +1,32 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from modelscope.hub.api import HubApi | |||
| class DatasetDeleteManager(object): | |||
| def __init__(self, dataset_name: str, namespace: str, version: str): | |||
| self.api = HubApi() | |||
| self.dataset_name = dataset_name | |||
| self.namespace = namespace | |||
| self.version = version | |||
| def delete(self, object_name: str) -> str: | |||
| # single object | |||
| if not object_name.endswith('/'): | |||
| resp_msg = self.api.delete_oss_dataset_object( | |||
| object_name=object_name, | |||
| dataset_name=self.dataset_name, | |||
| namespace=self.namespace, | |||
| revision=self.version) | |||
| else: | |||
| # multiple objects | |||
| object_name = object_name.strip('/') | |||
| resp_msg = self.api.delete_oss_dataset_dir( | |||
| object_name=object_name, | |||
| dataset_name=self.dataset_name, | |||
| namespace=self.namespace, | |||
| revision=self.version) | |||
| return resp_msg | |||
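| # A hypothetical usage sketch of the manager above: | |||
| # manager = DatasetDeleteManager('my_dataset', 'my_namespace', 'master') | |||
| # manager.delete('train/001.png')  # single object | |||
| # manager.delete('train/')         # trailing slash: delete a whole directory | |||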
| @@ -27,7 +27,11 @@ class DatasetDownloadManager(DownloadManager): | |||
| oss_config = api.get_dataset_access_config(self._dataset_name, | |||
| self._namespace, | |||
| self._version) | |||
| self.oss_utilities = OssUtilities(oss_config) | |||
| self.oss_utilities = OssUtilities( | |||
| oss_config=oss_config, | |||
| dataset_name=self._dataset_name, | |||
| namespace=self._namespace, | |||
| revision=self._version) | |||
| def _download(self, url_or_filename: str, | |||
| download_config: DownloadConfig) -> str: | |||
| @@ -6,19 +6,28 @@ import os | |||
| import oss2 | |||
| from datasets.utils.file_utils import hash_url_to_filename | |||
| from modelscope.hub.api import HubApi | |||
| from modelscope.utils.constant import UploadMode | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| ACCESS_ID = 'AccessId' | |||
| ACCESS_SECRET = 'AccessSecret' | |||
| SECURITY_TOKEN = 'SecurityToken' | |||
| BUCKET = 'Bucket' | |||
| BACK_DIR = 'BackupDir' | |||
| DIR = 'Dir' | |||
| class OssUtilities: | |||
| def __init__(self, oss_config): | |||
| self.key = oss_config['AccessId'] | |||
| self.secret = oss_config['AccessSecret'] | |||
| self.token = oss_config['SecurityToken'] | |||
| self.endpoint = f"https://{oss_config['Region']}.aliyuncs.com" | |||
| self.bucket_name = oss_config['Bucket'] | |||
| auth = oss2.StsAuth(self.key, self.secret, self.token) | |||
| self.bucket = oss2.Bucket(auth, self.endpoint, self.bucket_name) | |||
| self.oss_dir = oss_config['Dir'] | |||
| self.oss_backup_dir = oss_config['BackupDir'] | |||
| def __init__(self, oss_config, dataset_name, namespace, revision): | |||
| self._do_init(oss_config=oss_config) | |||
| self.dataset_name = dataset_name | |||
| self.namespace = namespace | |||
| self.revision = revision | |||
| self.upload_resumable_tmp_store = '/tmp/modelscope/tmp_dataset' | |||
| self.upload_multipart_threshold = 50 * 1024 * 1024 | |||
| @@ -26,6 +35,28 @@ class OssUtilities: | |||
| self.upload_num_threads = 4 | |||
| self.upload_max_retries = 3 | |||
| self.api = HubApi() | |||
| def _do_init(self, oss_config): | |||
| self.key = oss_config[ACCESS_ID] | |||
| self.secret = oss_config[ACCESS_SECRET] | |||
| self.token = oss_config[SECURITY_TOKEN] | |||
| self.endpoint = f"https://{oss_config['Region']}.aliyuncs.com" | |||
| self.bucket_name = oss_config[BUCKET] | |||
| auth = oss2.StsAuth(self.key, self.secret, self.token) | |||
| self.bucket = oss2.Bucket(auth, self.endpoint, self.bucket_name) | |||
| self.oss_dir = oss_config[DIR] | |||
| self.oss_backup_dir = oss_config[BACK_DIR] | |||
| def _reload_sts(self): | |||
| cookies = self.api.check_local_cookies(use_cookies=True) | |||
| oss_config_refresh = self.api.get_dataset_access_config_session( | |||
| cookies=cookies, | |||
| dataset_name=self.dataset_name, | |||
| namespace=self.namespace, | |||
| revision=self.revision) | |||
| self._do_init(oss_config_refresh) | |||
| @staticmethod | |||
| def _percentage(consumed_bytes, total_bytes): | |||
| if total_bytes: | |||
| @@ -51,7 +82,8 @@ class OssUtilities: | |||
| return local_path | |||
| def upload(self, oss_object_name: str, local_file_path: str, | |||
| indicate_individual_progress: bool) -> str: | |||
| indicate_individual_progress: bool, | |||
| upload_mode: UploadMode) -> str: | |||
| retry_count = 0 | |||
| object_key = os.path.join(self.oss_dir, oss_object_name) | |||
| resumable_store = oss2.ResumableStore( | |||
| @@ -64,6 +96,13 @@ class OssUtilities: | |||
| while True: | |||
| try: | |||
| retry_count += 1 | |||
| exist = self.bucket.object_exists(object_key) | |||
| if upload_mode == UploadMode.APPEND and exist: | |||
| logger.info( | |||
| f'Skipping {oss_object_name}: object already exists and upload mode is {upload_mode.value}.' | |||
| ) | |||
| break | |||
| oss2.resumable_upload( | |||
| self.bucket, | |||
| object_key, | |||
| @@ -74,7 +113,9 @@ class OssUtilities: | |||
| progress_callback=progress_callback, | |||
| num_threads=self.upload_num_threads) | |||
| break | |||
| except Exception: | |||
| except Exception as e: | |||
| # refresh expired STS credentials on 403 before the next retry | |||
| if getattr(e, 'status', None) == 403: | |||
| self._reload_sts() | |||
| if retry_count >= self.upload_max_retries: | |||
| raise | |||
| @@ -5,6 +5,7 @@ from multiprocessing.dummy import Pool as ThreadPool | |||
| from tqdm import tqdm | |||
| from modelscope.utils.constant import UploadMode | |||
| from .oss_utils import OssUtilities | |||
| @@ -13,38 +14,45 @@ class DatasetUploadManager(object): | |||
| def __init__(self, dataset_name: str, namespace: str, version: str): | |||
| from modelscope.hub.api import HubApi | |||
| _hub_api = HubApi() | |||
| _cookies = _hub_api.check_cookies_upload_data(use_cookies=True) | |||
| _cookies = _hub_api.check_local_cookies(use_cookies=True) | |||
| _oss_config = _hub_api.get_dataset_access_config_session( | |||
| cookies=_cookies, | |||
| dataset_name=dataset_name, | |||
| namespace=namespace, | |||
| revision=version) | |||
| self.oss_utilities = OssUtilities(_oss_config) | |||
| self.oss_utilities = OssUtilities( | |||
| oss_config=_oss_config, | |||
| dataset_name=dataset_name, | |||
| namespace=namespace, | |||
| revision=version) | |||
| def upload(self, object_name: str, local_file_path: str) -> str: | |||
| def upload(self, object_name: str, local_file_path: str, | |||
| upload_mode: UploadMode) -> str: | |||
| object_key = self.oss_utilities.upload( | |||
| oss_object_name=object_name, | |||
| local_file_path=local_file_path, | |||
| indicate_individual_progress=True) | |||
| indicate_individual_progress=True, | |||
| upload_mode=upload_mode) | |||
| return object_key | |||
| def upload_dir(self, object_dir_name: str, local_dir_path: str, | |||
| num_processes: int, chunksize: int, | |||
| filter_hidden_files: bool) -> int: | |||
| filter_hidden_files: bool, upload_mode: UploadMode) -> int: | |||
| def run_upload(args): | |||
| self.oss_utilities.upload( | |||
| oss_object_name=args[0], | |||
| local_file_path=args[1], | |||
| indicate_individual_progress=False) | |||
| indicate_individual_progress=False, | |||
| upload_mode=upload_mode) | |||
| files_list = [] | |||
| for root, dirs, files in os.walk(local_dir_path): | |||
| for file_name in files: | |||
| if filter_hidden_files and file_name.startswith('.'): | |||
| continue | |||
| # Concatenate directory name and relative path into a oss object key. e.g., train/001/1_1230.png | |||
| # Concatenate directory name and relative path into oss object key. e.g., train/001/1_1230.png | |||
| object_name = os.path.join( | |||
| object_dir_name, | |||
| root.replace(local_dir_path, '', 1).strip('/'), file_name) | |||
| @@ -541,3 +541,50 @@ class Seq2SeqLMOutput(ModelOutputBase): | |||
| encoder_last_hidden_state: Optional[Tensor] = None | |||
| encoder_hidden_states: Optional[Tuple[Tensor]] = None | |||
| encoder_attentions: Optional[Tuple[Tensor]] = None | |||
| @dataclass | |||
| class TextGenerationModelOutput(ModelOutputBase): | |||
| """The output class for text generation models. | |||
| Args: | |||
| logits (`Tensor`): The logits output of the model. | |||
| loss (`Tensor`, *optional*): The loss of the model, available when training. | |||
| hidden_states (`Tensor`, *optional*): Hidden-states of the model at the | |||
| output of each layer plus the optional initial embedding outputs. | |||
| """ | |||
| logits: Tensor = None | |||
| loss: Tensor = None | |||
| @dataclass | |||
| class TokenGeneratorOutput(ModelOutputBase): | |||
| """ | |||
| The output class for generate method of text generation models. | |||
| Args: | |||
| sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): | |||
| The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter | |||
| if all batches finished early due to the `eos_token_id`. | |||
| scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` | |||
| is passed or when `config.output_scores=True`): | |||
| Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) | |||
| at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for | |||
| each generated token), with each tensor of shape `(batch_size*num_return_sequences, config.vocab_size)`. | |||
| attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` | |||
| is passed or `config.output_attentions=True`): | |||
| Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of | |||
| `torch.FloatTensor` of shape `(num_return_sequences*batch_size, num_heads, generated_length, | |||
| sequence_length)`. | |||
| hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` | |||
| is passed or when `config.output_hidden_states=True`): | |||
| Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of | |||
| `torch.FloatTensor` of shape `(num_return_sequences*batch_size, generated_length, hidden_size)`. | |||
| """ | |||
| sequences: Tensor = None | |||
| scores: Optional[Tuple[Tensor]] = None | |||
| attentions: Optional[Tuple[Tuple[Tensor]]] = None | |||
| hidden_states: Optional[Tuple[Tuple[Tensor]]] = None | |||
| @@ -157,7 +157,13 @@ class ReferringVideoObjectSegmentationPipeline(Pipeline): | |||
| * text_border_height_per_query, 0, 0)) | |||
| W, H = vid_frame.size | |||
| draw = ImageDraw.Draw(vid_frame) | |||
| font = ImageFont.truetype(font='DejaVuSansMono.ttf', size=30) | |||
| if self.model.cfg.pipeline.output_font: | |||
| font = ImageFont.truetype( | |||
| font=self.model.cfg.pipeline.output_font, | |||
| size=self.model.cfg.pipeline.output_font_size) | |||
| else: | |||
| font = ImageFont.load_default() | |||
| for i, (text_query, color) in enumerate( | |||
| zip(self.text_queries, colors), start=1): | |||
| w, h = draw.textsize(text_query, font=font) | |||
| @@ -104,6 +104,10 @@ class TextGenerationPipeline(Pipeline): | |||
| tokenizer = self.preprocessor.tokenizer | |||
| return tokenizer.decode(inputs.tolist(), skip_special_tokens=True) | |||
| def sentence_piece(self, inputs) -> str: | |||
| tokenizer = self.preprocessor.tokenizer | |||
| return tokenizer.decode(inputs.tolist()) | |||
| def roberta(self, inputs) -> str: | |||
| tokenizer = self.preprocessor.tokenizer | |||
| decoded = tokenizer.decode(inputs.tolist()) | |||
| @@ -121,7 +125,7 @@ class TextGenerationPipeline(Pipeline): | |||
| Dict[str, str]: the prediction results | |||
| """ | |||
| inputs = inputs['sequences'] | |||
| if isinstance(inputs, list): | |||
| if isinstance(inputs, list) or len(inputs.shape) > 1: | |||
| inputs = inputs[0] | |||
| decoded = getattr(self, self.postprocessor)(inputs) | |||
| text = self._remove_space_between_chinese_chars(decoded) | |||
| @@ -17,6 +17,8 @@ from modelscope.utils.tensor_utils import (torch_nested_detach, | |||
| __all__ = ['TokenClassificationPipeline'] | |||
| @PIPELINES.register_module( | |||
| Tasks.token_classification, module_name=Pipelines.token_classification) | |||
| @PIPELINES.register_module( | |||
| Tasks.token_classification, module_name=Pipelines.part_of_speech) | |||
| @PIPELINES.register_module( | |||
| @@ -41,7 +43,7 @@ class TokenClassificationPipeline(Pipeline): | |||
| str) else model | |||
| if preprocessor is None: | |||
| preprocessor = Model.from_pretrained( | |||
| preprocessor = Preprocessor.from_pretrained( | |||
| model.model_dir, | |||
| sequence_length=kwargs.pop('sequence_length', 128)) | |||
| model.eval() | |||
| @@ -147,8 +147,50 @@ class Preprocessor(ABC): | |||
| cfg_dict: Config = None, | |||
| preprocessor_mode=ModeKeys.INFERENCE, | |||
| **kwargs): | |||
| """ Instantiate a model from local directory or remote model repo. Note | |||
| """Instantiate a preprocessor from local directory or remote model repo. Note | |||
| that when loading from remote, the model revision can be specified. | |||
| Args: | |||
| model_name_or_path(str): A model dir or a model id used to load the preprocessor out. | |||
| revision(str, `optional`): The revision used when the model_name_or_path is | |||
| a model id of the remote hub. default `master`. | |||
| cfg_dict(Config, `optional`): An optional config. If provided, it will replace | |||
| the config read out of the `model_name_or_path` | |||
| preprocessor_mode(str, `optional`): Specify the working mode of the preprocessor, can be `train`, `eval`, | |||
| or `inference`. Default value `inference`. | |||
| The preprocessor field in the config may contain two sub preprocessors: | |||
| >>> { | |||
| >>> "train": { | |||
| >>> "type": "some-train-preprocessor" | |||
| >>> }, | |||
| >>> "val": { | |||
| >>> "type": "some-eval-preprocessor" | |||
| >>> } | |||
| >>> } | |||
| In this scenario, the `train` preprocessor will be loaded in the `train` mode, the `val` preprocessor | |||
| will be loaded in the `eval` or `inference` mode. The `mode` field in the preprocessor class | |||
| will be assigned in all the modes. | |||
| Or just one: | |||
| >>> { | |||
| >>> "type": "some-train-preprocessor" | |||
| >>> } | |||
| In this scenario, the sole preprocessor will be loaded in all the modes, | |||
| and the `mode` field in the preprocessor class will be assigned. | |||
| **kwargs: | |||
| task(str, `optional`): The `Tasks` enumeration value to replace the task value | |||
| read out of config in the `model_name_or_path`. | |||
| This is useful when the preprocessor does not have a `type` field and the task to be used is not | |||
| equal to the task of which the model is saved. | |||
| Other kwargs will be directly fed into the preprocessor, to replace the default configs. | |||
| Returns: | |||
| The preprocessor instance. | |||
| Examples: | |||
| >>> from modelscope.preprocessors import Preprocessor | |||
| >>> Preprocessor.from_pretrained('damo/nlp_debertav2_fill-mask_chinese-base') | |||
| """ | |||
| if not os.path.exists(model_name_or_path): | |||
| model_dir = snapshot_download( | |||
| @@ -157,7 +157,7 @@ class MPlugPreprocessor(Preprocessor): | |||
| def image_open(self, path: str) -> Tuple[Image.Image, int]: | |||
| if path not in self._image_map: | |||
| index = len(self._image_map) | |||
| self._image_map[path] = (Image.open(path), index) | |||
| self._image_map[path] = (load_image(path), index) | |||
| return self._image_map[path] | |||
| def __call__( | |||
| @@ -9,7 +9,8 @@ if TYPE_CHECKING: | |||
| from .builder import build_trainer | |||
| from .cv import (ImageInstanceSegmentationTrainer, | |||
| ImagePortraitEnhancementTrainer, | |||
| MovieSceneSegmentationTrainer, ImageInpaintingTrainer) | |||
| MovieSceneSegmentationTrainer, ImageInpaintingTrainer, | |||
| ReferringVideoObjectSegmentationTrainer) | |||
| from .multi_modal import CLIPTrainer | |||
| from .nlp import SequenceClassificationTrainer, TextRankingTrainer | |||
| from .nlp_trainer import NlpEpochBasedTrainer, VecoTrainer, NlpTrainerArguments | |||
| @@ -9,6 +9,7 @@ if TYPE_CHECKING: | |||
| from .image_portrait_enhancement_trainer import ImagePortraitEnhancementTrainer | |||
| from .movie_scene_segmentation_trainer import MovieSceneSegmentationTrainer | |||
| from .image_inpainting_trainer import ImageInpaintingTrainer | |||
| from .referring_video_object_segmentation_trainer import ReferringVideoObjectSegmentationTrainer | |||
| else: | |||
| _import_structure = { | |||
| @@ -17,7 +18,9 @@ else: | |||
| 'image_portrait_enhancement_trainer': | |||
| ['ImagePortraitEnhancementTrainer'], | |||
| 'movie_scene_segmentation_trainer': ['MovieSceneSegmentationTrainer'], | |||
| 'image_inpainting_trainer': ['ImageInpaintingTrainer'] | |||
| 'image_inpainting_trainer': ['ImageInpaintingTrainer'], | |||
| 'referring_video_object_segmentation_trainer': | |||
| ['ReferringVideoObjectSegmentationTrainer'] | |||
| } | |||
| import sys | |||
| @@ -0,0 +1,63 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import torch | |||
| from modelscope.metainfo import Trainers | |||
| from modelscope.trainers.builder import TRAINERS | |||
| from modelscope.trainers.trainer import EpochBasedTrainer | |||
| from modelscope.utils.constant import ModeKeys | |||
| @TRAINERS.register_module( | |||
| module_name=Trainers.referring_video_object_segmentation) | |||
| class ReferringVideoObjectSegmentationTrainer(EpochBasedTrainer): | |||
| def __init__(self, *args, **kwargs): | |||
| super().__init__(*args, **kwargs) | |||
| self.model.set_postprocessor(self.cfg.dataset.name) | |||
| self.train_data_collator = self.train_dataset.collator | |||
| self.eval_data_collator = self.eval_dataset.collator | |||
| device_name = kwargs.get('device', 'gpu') | |||
| self.model.set_device(self.device, device_name) | |||
| def train(self, *args, **kwargs): | |||
| self.model.criterion.train() | |||
| super().train(*args, **kwargs) | |||
| def evaluate(self, checkpoint_path=None): | |||
| if checkpoint_path is not None and os.path.isfile(checkpoint_path): | |||
| from modelscope.trainers.hooks import CheckpointHook | |||
| CheckpointHook.load_checkpoint(checkpoint_path, self) | |||
| self.model.eval() | |||
| self._mode = ModeKeys.EVAL | |||
| if self.eval_dataset is None: | |||
| self.eval_dataloader = self.get_eval_data_loader() | |||
| else: | |||
| self.eval_dataloader = self._build_dataloader_with_dataset( | |||
| self.eval_dataset, | |||
| dist=self._dist, | |||
| seed=self._seed, | |||
| collate_fn=self.eval_data_collator, | |||
| **self.cfg.evaluation.get('dataloader', {})) | |||
| self.data_loader = self.eval_dataloader | |||
| from modelscope.metrics import build_metric | |||
| ann_file = self.eval_dataset.ann_file | |||
| metric_classes = [] | |||
| for metric in self.metrics: | |||
| metric.update({'ann_file': ann_file}) | |||
| metric_classes.append(build_metric(metric)) | |||
| for m in metric_classes: | |||
| m.trainer = self | |||
| metric_values = self.evaluation_loop(self.eval_dataloader, | |||
| metric_classes) | |||
| self._metric_values = metric_values | |||
| return metric_values | |||
| def prediction_step(self, model, inputs): | |||
| pass | |||
| @@ -101,8 +101,9 @@ class CheckpointHook(Hook): | |||
| model = trainer.model.module | |||
| else: | |||
| model = trainer.model | |||
| meta = load_checkpoint(filename, model, trainer.optimizer, | |||
| trainer.lr_scheduler) | |||
| meta = load_checkpoint(filename, model, | |||
| getattr(trainer, 'optimizer', None), | |||
| getattr(trainer, 'lr_scheduler', None)) | |||
| trainer._epoch = meta.get('epoch', trainer._epoch) | |||
| trainer._iter = meta.get('iter', trainer._iter) | |||
| trainer._inner_iter = meta.get('inner_iter', trainer._inner_iter) | |||
| @@ -111,7 +112,7 @@ class CheckpointHook(Hook): | |||
| # hook: Hook | |||
| key = f'{hook.__class__}-{i}' | |||
| if key in meta and hasattr(hook, 'load_state_dict'): | |||
| hook.load_state_dict(meta[key]) | |||
| hook.load_state_dict(meta.get(key, {})) | |||
| else: | |||
| trainer.logger.warn( | |||
| f'The state_dict of hook {hook.__class__} at index {i} is not found in the checkpoint file.' | |||
| @@ -123,7 +124,7 @@ class CheckpointHook(Hook): | |||
| f'The modelscope version of loaded checkpoint does not match the runtime version. ' | |||
| f'The saved version: {version}, runtime version: {__version__}' | |||
| ) | |||
| trainer.logger.warn( | |||
| trainer.logger.info( | |||
| f'Checkpoint {filename} saving time: {meta.get("time")}') | |||
| return meta | |||
| @@ -171,12 +172,17 @@ class CheckpointHook(Hook): | |||
| else: | |||
| model = trainer.model | |||
| config = trainer.cfg.to_dict() | |||
| # override the pipeline type with the task name after finetuning is done, | |||
| # to avoid cases like a fill-mask pipeline paired with a text classification task | |||
| config['pipeline'] = {'type': config['task']} | |||
| if hasattr(model, 'save_pretrained'): | |||
| model.save_pretrained( | |||
| output_dir, | |||
| ModelFile.TORCH_MODEL_BIN_FILE, | |||
| save_function=save_checkpoint, | |||
| config=trainer.cfg.to_dict(), | |||
| config=config, | |||
| with_meta=False) | |||
| def after_train_iter(self, trainer): | |||
| @@ -5,9 +5,13 @@ from modelscope.utils.import_utils import LazyImportModule | |||
| if TYPE_CHECKING: | |||
| from .clip import CLIPTrainer | |||
| from .team import TEAMImgClsTrainer | |||
| else: | |||
| _import_structure = {'clip': ['CLIPTrainer']} | |||
| _import_structure = { | |||
| 'clip': ['CLIPTrainer'], | |||
| 'team': ['TEAMImgClsTrainer'] | |||
| } | |||
| import sys | |||
| @@ -0,0 +1,3 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from .team_trainer import TEAMImgClsTrainer | |||
| @@ -0,0 +1,144 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| from collections import OrderedDict | |||
| from typing import Callable, Dict, Optional | |||
| import numpy as np | |||
| import torch | |||
| import torch.nn as nn | |||
| import torchvision.datasets as datasets | |||
| import torchvision.transforms as transforms | |||
| from sklearn.metrics import confusion_matrix | |||
| from torch.optim import AdamW | |||
| from torch.utils.data import DataLoader, Dataset | |||
| from modelscope.metainfo import Trainers | |||
| from modelscope.models.base import Model | |||
| from modelscope.msdatasets import MsDataset | |||
| from modelscope.trainers.base import BaseTrainer | |||
| from modelscope.trainers.builder import TRAINERS | |||
| from modelscope.trainers.multi_modal.team.team_trainer_utils import ( | |||
| get_optimizer, train_mapping, val_mapping) | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import DownloadMode, ModeKeys | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| @TRAINERS.register_module(module_name=Trainers.image_classification_team) | |||
| class TEAMImgClsTrainer(BaseTrainer): | |||
| def __init__(self, cfg_file: str, model: str, device_id: int, | |||
| data_collator: Callable, train_dataset: Dataset, | |||
| val_dataset: Dataset, *args, **kwargs): | |||
| super().__init__(cfg_file) | |||
| self.cfg = Config.from_file(cfg_file) | |||
| team_model = Model.from_pretrained(model) | |||
| image_model = team_model.model.image_model.vision_transformer | |||
| classification_model = nn.Sequential( | |||
| OrderedDict([('encoder', image_model), | |||
| ('classifier', | |||
| nn.Linear(768, self.cfg.dataset.class_num))])) | |||
| self.model = classification_model | |||
| for pname, param in self.model.named_parameters(): | |||
| if 'encoder' in pname: | |||
| param.requires_grad = False | |||
| self.device_id = device_id | |||
| self.total_epoch = self.cfg.train.epoch | |||
| self.train_batch_size = self.cfg.train.batch_size | |||
| self.val_batch_size = self.cfg.evaluation.batch_size | |||
| self.ckpt_dir = self.cfg.train.ckpt_dir | |||
| self.collate_fn = data_collator | |||
| self.train_dataset = train_dataset | |||
| self.val_dataset = val_dataset | |||
| self.criterion = nn.CrossEntropyLoss().to(self.device_id) | |||
| def train(self, *args, **kwargs): | |||
| self.model.train() | |||
| self.model.to(self.device_id) | |||
| optimizer = get_optimizer(self.model) | |||
| for epoch in range(self.total_epoch): | |||
| train_params = { | |||
| 'pin_memory': True, | |||
| 'collate_fn': self.collate_fn, | |||
| 'batch_size': self.train_batch_size, | |||
| 'shuffle': True, | |||
| 'drop_last': True, | |||
| 'num_workers': 8 | |||
| } | |||
| train_loader = DataLoader(self.train_dataset, **train_params) | |||
| for batch_idx, data in enumerate(train_loader): | |||
| img_tensor, label_tensor = data['pixel_values'], data['labels'] | |||
| img_tensor = img_tensor.to(self.device_id, non_blocking=True) | |||
| label_tensor = label_tensor.to( | |||
| self.device_id, non_blocking=True) | |||
| pred_logits = self.model(img_tensor) | |||
| loss = self.criterion(pred_logits, label_tensor) | |||
| optimizer.zero_grad() | |||
| loss.backward() | |||
| optimizer.step() | |||
| if batch_idx % 10 == 0: | |||
| logger.info( | |||
| 'epoch: {}, train batch {}/{}, loss={:.5f}'.format( | |||
| epoch, batch_idx, len(train_loader), loss.item())) | |||
| os.makedirs(self.ckpt_dir, exist_ok=True) | |||
| torch.save(self.model.state_dict(), | |||
| '{}/epoch{}.pth'.format(self.ckpt_dir, epoch)) | |||
| self.evaluate() | |||
| def evaluate(self, | |||
| checkpoint_path: Optional[str] = None, | |||
| *args, | |||
| **kwargs) -> Dict[str, float]: | |||
| if checkpoint_path is not None: | |||
| checkpoint_params = torch.load(checkpoint_path, 'cpu') | |||
| self.model.load_state_dict(checkpoint_params) | |||
| self.model.eval() | |||
| self.model.to(self.device_id) | |||
| val_params = { | |||
| 'collate_fn': self.collate_fn, | |||
| 'batch_size': self.val_batch_size, | |||
| 'shuffle': False, | |||
| 'drop_last': False, | |||
| 'num_workers': 8 | |||
| } | |||
| val_loader = DataLoader(self.val_dataset, **val_params) | |||
| tp_cnt, processed_cnt = 0, 0 | |||
| all_pred_labels, all_gt_labels = [], [] | |||
| with torch.no_grad(): | |||
| for batch_idx, data in enumerate(val_loader): | |||
| img_tensor, label_tensor = data['pixel_values'], data['labels'] | |||
| img_tensor = img_tensor.to(self.device_id, non_blocking=True) | |||
| label_tensor = label_tensor.to( | |||
| self.device_id, non_blocking=True) | |||
| pred_logits = self.model(img_tensor) | |||
| pred_labels = torch.max(pred_logits, dim=1)[1] | |||
| tp_cnt += torch.sum(pred_labels == label_tensor).item() | |||
| processed_cnt += img_tensor.shape[0] | |||
| logger.info('Accuracy: {:.3f}'.format(tp_cnt / processed_cnt)) | |||
| all_pred_labels.extend(pred_labels.tolist()) | |||
| all_gt_labels.extend(label_tensor.tolist()) | |||
| conf_mat = confusion_matrix(all_gt_labels, all_pred_labels) | |||
| acc_mean_per_class = np.mean(conf_mat.diagonal() | |||
| / conf_mat.sum(axis=1)) | |||
| logger.info( | |||
| 'Accuracy mean per class: {:.3f}'.format(acc_mean_per_class)) | |||
| @@ -0,0 +1,87 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import torch | |||
| import torchvision.transforms as transforms | |||
| from PIL import Image | |||
| from torch.optim import AdamW | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| train_transforms = transforms.Compose([ | |||
| transforms.RandomResizedCrop(224), | |||
| transforms.RandomHorizontalFlip(), | |||
| transforms.ToTensor(), | |||
| transforms.Normalize((0.48145466, 0.4578275, 0.40821073), | |||
| (0.26862954, 0.26130258, 0.27577711)), | |||
| ]) | |||
| val_transforms = transforms.Compose([ | |||
| transforms.Resize(256), | |||
| transforms.CenterCrop(224), | |||
| transforms.ToTensor(), | |||
| transforms.Normalize((0.48145466, 0.4578275, 0.40821073), | |||
| (0.26862954, 0.26130258, 0.27577711)), | |||
| ]) | |||
| def train_mapping(examples): | |||
| examples['pixel_values'] = [ | |||
| train_transforms(Image.open(image).convert('RGB')) | |||
| for image in examples['image:FILE'] | |||
| ] | |||
| examples['labels'] = [label for label in examples['label:LABEL']] | |||
| return examples | |||
| def val_mapping(examples): | |||
| examples['pixel_values'] = [ | |||
| val_transforms(Image.open(image).convert('RGB')) | |||
| for image in examples['image:FILE'] | |||
| ] | |||
| examples['labels'] = [label for label in examples['label:LABEL']] | |||
| return examples | |||
| def collate_fn(examples): | |||
| images = [] | |||
| labels = [] | |||
| for example in examples: | |||
| images.append((example['pixel_values'])) | |||
| labels.append(example['labels']) | |||
| pixel_values = torch.stack(images) | |||
| labels = torch.tensor(labels) | |||
| return {'pixel_values': pixel_values, 'labels': labels} | |||
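| # e.g. DataLoader(train_dataset, batch_size=32, collate_fn=collate_fn) then yields | |||
| # batches with 'pixel_values' of shape [B, 3, 224, 224] and 'labels' of shape [B]. | |||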
| def get_params_groups(ddp_model, lr): | |||
| large_lr_params = [] | |||
| small_lr_params = [] | |||
| for name, param in ddp_model.named_parameters(): | |||
| if not param.requires_grad: | |||
| continue | |||
| if 'encoder' in name: | |||
| small_lr_params.append(param) | |||
| elif 'classifier' in name: | |||
| large_lr_params.append(param) | |||
| else: | |||
| logger.info('skip param: {}'.format(name)) | |||
| params_groups = [{ | |||
| 'params': small_lr_params, | |||
| 'lr': lr / 10.0 | |||
| }, { | |||
| 'params': large_lr_params, | |||
| 'lr': lr | |||
| }] | |||
| return params_groups | |||
| def get_optimizer(ddp_model): | |||
| lr_init = 1e-3 | |||
| betas = [0.9, 0.999] | |||
| weight_decay = 0.02 | |||
| params_groups = get_params_groups(ddp_model, lr=lr_init) | |||
| return AdamW( | |||
| params_groups, lr=lr_init, betas=betas, weight_decay=weight_decay) | |||
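| # Note (editor's): TEAMImgClsTrainer above freezes all 'encoder' parameters | |||
| # (requires_grad=False), so the small-lr group ends up empty and only the | |||
| # classifier head is actually updated, at lr_init=1e-3. | |||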
| @@ -646,7 +646,9 @@ class VecoTrainer(NlpEpochBasedTrainer): | |||
| break | |||
| for metric_name in self.metrics: | |||
| metric_values[metric_name] = np.average( | |||
| [m[metric_name] for m in metric_values.values()]) | |||
| all_metrics = [m[metric_name] for m in metric_values.values()] | |||
| for key in all_metrics[0].keys(): | |||
| metric_values[key] = np.average( | |||
| [metric[key] for metric in all_metrics]) | |||
| return metric_values | |||
| @@ -667,10 +667,25 @@ class EpochBasedTrainer(BaseTrainer): | |||
| return dataset | |||
| def build_optimizer(self, cfg: ConfigDict, default_args: dict = None): | |||
| return build_optimizer(self.model, cfg=cfg, default_args=default_args) | |||
| try: | |||
| return build_optimizer( | |||
| self.model, cfg=cfg, default_args=default_args) | |||
| except KeyError as e: | |||
| self.logger.error( | |||
| f'Build optimizer error: the optimizer {cfg} is a native torch optimizer; ' | |||
| f'please check whether your torch version ({torch.__version__}) matches the config.' | |||
| ) | |||
| raise e | |||
| def build_lr_scheduler(self, cfg: ConfigDict, default_args: dict = None): | |||
| return build_lr_scheduler(cfg=cfg, default_args=default_args) | |||
| try: | |||
| return build_lr_scheduler(cfg=cfg, default_args=default_args) | |||
| except KeyError as e: | |||
| self.logger.error( | |||
| f'Build lr_scheduler error: the lr_scheduler {cfg} is a native torch lr_scheduler; ' | |||
| f'please check whether your torch version ({torch.__version__}) matches the config.' | |||
| ) | |||
| raise e | |||
| def create_optimizer_and_scheduler(self): | |||
| """ Create optimizer and lr scheduler | |||
| @@ -62,7 +62,10 @@ def single_gpu_test(trainer, | |||
| if 'nsentences' in data: | |||
| batch_size = data['nsentences'] | |||
| else: | |||
| batch_size = len(next(iter(data.values()))) | |||
| try: | |||
| batch_size = len(next(iter(data.values()))) | |||
| except Exception: | |||
| batch_size = data_loader.batch_size | |||
| else: | |||
| batch_size = len(data) | |||
| for _ in range(batch_size): | |||
| @@ -134,9 +134,7 @@ def load_checkpoint(filename, | |||
| state_dict = checkpoint if 'state_dict' not in checkpoint else checkpoint[ | |||
| 'state_dict'] | |||
| model.load_state_dict(state_dict) | |||
| if 'meta' in checkpoint: | |||
| return checkpoint.get('meta', {}) | |||
| return checkpoint.get('meta', {}) | |||
| def save_pretrained(model, | |||
| @@ -238,6 +238,15 @@ class DownloadMode(enum.Enum): | |||
| FORCE_REDOWNLOAD = 'force_redownload' | |||
| class UploadMode(enum.Enum): | |||
| """ How to upload object to remote. | |||
| """ | |||
| # Upload all objects from local, existing remote objects may be overwritten. (Default) | |||
| OVERWRITE = 'overwrite' | |||
| # Upload local objects in append mode, skipping all existing remote objects. | |||
| APPEND = 'append' | |||
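| # e.g. with UploadMode.APPEND, OssUtilities.upload above calls object_exists() and | |||
| # skips objects already present remotely; OVERWRITE re-uploads everything. | |||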
| class DatasetFormations(enum.Enum): | |||
| """ How a dataset is organized and interpreted | |||
| """ | |||
| @@ -87,21 +87,23 @@ class HubOperationTest(unittest.TestCase): | |||
| assert mdtime1 == mdtime2 | |||
| def test_download_public_without_login(self): | |||
| self.prepare_case() | |||
| rmtree(ModelScopeConfig.path_credential) | |||
| snapshot_path = snapshot_download( | |||
| model_id=self.model_id, revision=self.revision) | |||
| downloaded_file_path = os.path.join(snapshot_path, | |||
| download_model_file_name) | |||
| assert os.path.exists(downloaded_file_path) | |||
| temporary_dir = tempfile.mkdtemp() | |||
| downloaded_file = model_file_download( | |||
| model_id=self.model_id, | |||
| file_path=download_model_file_name, | |||
| revision=self.revision, | |||
| cache_dir=temporary_dir) | |||
| assert os.path.exists(downloaded_file) | |||
| self.api.login(TEST_ACCESS_TOKEN1) | |||
| try: | |||
| self.prepare_case() | |||
| rmtree(ModelScopeConfig.path_credential) | |||
| snapshot_path = snapshot_download( | |||
| model_id=self.model_id, revision=self.revision) | |||
| downloaded_file_path = os.path.join(snapshot_path, | |||
| download_model_file_name) | |||
| assert os.path.exists(downloaded_file_path) | |||
| temporary_dir = tempfile.mkdtemp() | |||
| downloaded_file = model_file_download( | |||
| model_id=self.model_id, | |||
| file_path=download_model_file_name, | |||
| revision=self.revision, | |||
| cache_dir=temporary_dir) | |||
| assert os.path.exists(downloaded_file) | |||
| finally: | |||
| self.api.login(TEST_ACCESS_TOKEN1) | |||
| def test_snapshot_delete_download_cache_file(self): | |||
| self.prepare_case() | |||
| @@ -0,0 +1,32 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import unittest | |||
| import numpy as np | |||
| from modelscope.metrics.sequence_classification_metric import \ | |||
| SequenceClassificationMetric | |||
| from modelscope.utils.test_utils import test_level | |||
| class TestTextClsMetrics(unittest.TestCase): | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_value(self): | |||
| metric = SequenceClassificationMetric() | |||
| outputs = { | |||
| 'logits': | |||
| np.array([[2.0, 1.0, 0.5], [1.0, 1.5, 1.0], [2.0, 1.0, 3.0], | |||
| [2.4, 1.5, 4.0], [2.0, 1.0, 3.0], [2.4, 1.5, 1.7], | |||
| [2.0, 1.0, 0.5], [2.4, 1.5, 0.5]]) | |||
| } | |||
| inputs = {'labels': np.array([0, 1, 2, 2, 0, 1, 2, 2])} | |||
| metric.add(outputs, inputs) | |||
| ret = metric.evaluate() | |||
| self.assertTrue(np.isclose(ret['f1'], 0.5)) | |||
| self.assertTrue(np.isclose(ret['accuracy'], 0.5)) | |||
| print(ret) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -0,0 +1,112 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import shutil | |||
| import tempfile | |||
| import unittest | |||
| import zipfile | |||
| from modelscope.msdatasets import MsDataset | |||
| from modelscope.utils import logger as logging | |||
| from modelscope.utils.test_utils import test_level | |||
| logger = logging.get_logger(__name__) | |||
| KEY_EXTRACTED = 'extracted' | |||
| EXPECTED_MSG = 'success' | |||
| class DatasetDeleteTest(unittest.TestCase): | |||
| def setUp(self): | |||
| self.old_dir = os.getcwd() | |||
| self.dataset_name = 'small_coco_for_test' | |||
| self.dataset_file_name = self.dataset_name | |||
| self.prepared_dataset_name = 'pets_small' | |||
| self.token = os.getenv('TEST_UPLOAD_MS_TOKEN') | |||
| error_msg = 'The ModelScope token cannot be empty; please set the env variable TEST_UPLOAD_MS_TOKEN.' | |||
| self.assertIsNotNone(self.token, msg=error_msg) | |||
| from modelscope.hub.api import HubApi | |||
| from modelscope.hub.api import ModelScopeConfig | |||
| self.api = HubApi() | |||
| self.api.login(self.token) | |||
| # get user info | |||
| self.namespace, _ = ModelScopeConfig.get_user_info() | |||
| self.temp_dir = tempfile.mkdtemp() | |||
| self.test_work_dir = os.path.join(self.temp_dir, self.dataset_name) | |||
| if not os.path.exists(self.test_work_dir): | |||
| os.makedirs(self.test_work_dir) | |||
| def tearDown(self): | |||
| os.chdir(self.old_dir) | |||
| shutil.rmtree(self.temp_dir, ignore_errors=True) | |||
| logger.info( | |||
| f'Temporary directory {self.temp_dir} successfully removed!') | |||
| @staticmethod | |||
| def get_raw_downloaded_file_path(extracted_path): | |||
| raw_downloaded_file_path = '' | |||
| raw_data_dir = os.path.abspath( | |||
| os.path.join(extracted_path, '../../..')) | |||
| for root, dirs, files in os.walk(raw_data_dir): | |||
| if KEY_EXTRACTED in dirs: | |||
| for file in files: | |||
| curr_file_path = os.path.join(root, file) | |||
| if zipfile.is_zipfile(curr_file_path): | |||
| raw_downloaded_file_path = curr_file_path | |||
| return raw_downloaded_file_path | |||
| def upload_test_file(self): | |||
| # Get the prepared data from hub, using default modelscope namespace | |||
| ms_ds_train = MsDataset.load(self.prepared_dataset_name, split='train') | |||
| config_res = ms_ds_train._hf_ds.config_kwargs | |||
| extracted_path = config_res.get('split_config').get('train') | |||
| raw_zipfile_path = self.get_raw_downloaded_file_path(extracted_path) | |||
| object_name = self.dataset_file_name + '_for_del.zip' | |||
| MsDataset.upload( | |||
| object_name=object_name, | |||
| local_file_path=raw_zipfile_path, | |||
| dataset_name=self.dataset_name, | |||
| namespace=self.namespace) | |||
| return object_name | |||
| def upload_test_dir(self): | |||
| ms_ds_train = MsDataset.load(self.prepared_dataset_name, split='train') | |||
| config_train = ms_ds_train._hf_ds.config_kwargs | |||
| extracted_path_train = config_train.get('split_config').get('train') | |||
| object_name = 'train_for_del' | |||
| MsDataset.upload( | |||
| object_name=object_name, | |||
| local_file_path=os.path.join(extracted_path_train, | |||
| 'Pets/images/train'), | |||
| dataset_name=self.dataset_name, | |||
| namespace=self.namespace) | |||
| return object_name + '/' | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_ds_delete_object(self): | |||
| # upload prepared data | |||
| file_name = self.upload_test_file() | |||
| dir_name = self.upload_test_dir() | |||
| # delete object | |||
| del_file_msg = MsDataset.delete( | |||
| object_name=file_name, | |||
| dataset_name=self.dataset_name, | |||
| namespace=self.namespace) | |||
| del_dir_msg = MsDataset.delete( | |||
| object_name=dir_name, | |||
| dataset_name=self.dataset_name, | |||
| namespace=self.namespace) | |||
| assert all([del_file_msg == EXPECTED_MSG, del_dir_msg == EXPECTED_MSG]) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||
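The delete API exercised above returns the server's status message, which the test compares against 'success'. Distilled to the bare call (dataset and namespace names are placeholders; the trailing slash on object_name targets a directory, as upload_test_dir suggests):

    from modelscope.msdatasets import MsDataset

    msg = MsDataset.delete(
        object_name='train_for_del/',        # trailing slash -> directory
        dataset_name='small_coco_for_test',
        namespace='your_namespace')
    assert msg == 'success'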
| @@ -243,6 +243,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| def test_run_with_text_to_image_synthesis_with_name(self): | |||
| model = 'damo/ofa_text-to-image-synthesis_coco_large_en' | |||
| ofa_pipe = pipeline(Tasks.text_to_image_synthesis, model=model) | |||
| ofa_pipe.model.generator.beam_size = 2 | |||
| example = {'text': 'a bear in the water.'} | |||
| result = ofa_pipe(example) | |||
| result[OutputKeys.OUTPUT_IMG].save('result.png') | |||
| @@ -253,6 +254,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| model = Model.from_pretrained( | |||
| 'damo/ofa_text-to-image-synthesis_coco_large_en') | |||
| ofa_pipe = pipeline(Tasks.text_to_image_synthesis, model=model) | |||
| ofa_pipe.model.generator.beam_size = 2 | |||
| example = {'text': 'a bear in the water.'} | |||
| result = ofa_pipe(example) | |||
| result[OutputKeys.OUTPUT_IMG].save('result.png') | |||
| @@ -14,7 +14,7 @@ class ReferringVideoObjectSegmentationTest(unittest.TestCase, | |||
| self.task = Tasks.referring_video_object_segmentation | |||
| self.model_id = 'damo/cv_swin-t_referring_video-object-segmentation' | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| @unittest.skip('skip since the model is set to private for now') | |||
| def test_referring_video_object_segmentation(self): | |||
| input_location = 'data/test/videos/referring_video_object_segmentation_test_video.mp4' | |||
| text_queries = [ | |||
| @@ -31,7 +31,7 @@ class ReferringVideoObjectSegmentationTest(unittest.TestCase, | |||
| else: | |||
| raise ValueError('process error') | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| @unittest.skip('skip since the model is set to private for now') | |||
| def test_referring_video_object_segmentation_with_default_task(self): | |||
| input_location = 'data/test/videos/referring_video_object_segmentation_test_video.mp4' | |||
| text_queries = [ | |||
| @@ -183,7 +183,7 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| task=Tasks.text_generation, model='langboat/bloom-1b4-zh') | |||
| print(pipe('中国的首都是')) | |||
| @unittest.skip("Langboat's checkpoint has not been uploaded to modelhub") | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_gpt_neo(self): | |||
| pipe = pipeline( | |||
| task=Tasks.text_generation, model='langboat/mengzi-gpt-neo-base') | |||
| @@ -20,16 +20,16 @@ class TinynasObjectDetectionTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| Tasks.image_object_detection, model='damo/cv_tinynas_detection') | |||
| result = tinynas_object_detection( | |||
| 'data/test/images/image_detection.jpg') | |||
| print(result) | |||
| print('airdet', result) | |||
| @unittest.skip('will be enabled after damoyolo officially released') | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_damoyolo(self): | |||
| tinynas_object_detection = pipeline( | |||
| Tasks.image_object_detection, | |||
| model='damo/cv_tinynas_object-detection_damoyolo') | |||
| result = tinynas_object_detection( | |||
| 'data/test/images/image_detection.jpg') | |||
| print(result) | |||
| print('damoyolo', result) | |||
| @unittest.skip('demo compatibility test is only enabled on a needed-basis') | |||
| def test_demo_compatibility(self): | |||
| @@ -39,7 +39,8 @@ class TinynasObjectDetectionTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| def test_image_object_detection_auto_pipeline(self): | |||
| test_image = 'data/test/images/image_detection.jpg' | |||
| tinynas_object_detection = pipeline( | |||
| Tasks.image_object_detection, model='damo/cv_tinynas_detection') | |||
| Tasks.image_object_detection, | |||
| model='damo/cv_tinynas_object-detection_damoyolo') | |||
| result = tinynas_object_detection(test_image) | |||
| tinynas_object_detection.show_result(test_image, result, | |||
| 'demo_ret.jpg') | |||
| @@ -346,7 +346,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase): | |||
| train_datasets = [] | |||
| from datasets import DownloadConfig | |||
| dc = DownloadConfig() | |||
| dc.local_files_only = True | |||
| dc.local_files_only = False | |||
| for lang in langs: | |||
| train_datasets.append( | |||
| load_dataset('xnli', lang, split='train', download_config=dc)) | |||
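Flipping local_files_only to False lets the xnli splits be fetched from the network when they are missing from the local cache; with True, a cache miss raises instead of downloading. The same pattern in isolation (standard HuggingFace datasets API):

    from datasets import DownloadConfig, load_dataset

    dc = DownloadConfig()
    dc.local_files_only = False  # allow a network fetch on cache miss
    ds = load_dataset('xnli', 'en', split='train', download_config=dc)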
| @@ -0,0 +1,101 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import shutil | |||
| import tempfile | |||
| import unittest | |||
| import zipfile | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.metainfo import Trainers | |||
| from modelscope.models.cv.movie_scene_segmentation import \ | |||
| MovieSceneSegmentationModel | |||
| from modelscope.msdatasets import MsDataset | |||
| from modelscope.trainers import build_trainer | |||
| from modelscope.utils.config import Config, ConfigDict | |||
| from modelscope.utils.constant import ModelFile | |||
| from modelscope.utils.test_utils import test_level | |||
| class TestReferringVideoObjectSegmentationTrainer(unittest.TestCase): | |||
| model_id = 'damo/cv_swin-t_referring_video-object-segmentation' | |||
| dataset_name = 'referring_vos_toydata' | |||
| def setUp(self): | |||
| print('Testing %s.%s' % (type(self).__name__, self._testMethodName)) | |||
| cache_path = snapshot_download(self.model_id) | |||
| config_path = os.path.join(cache_path, ModelFile.CONFIGURATION) | |||
| cfg = Config.from_file(config_path) | |||
| max_epochs = cfg.train.max_epochs | |||
| train_data_cfg = ConfigDict( | |||
| name=self.dataset_name, | |||
| split='train', | |||
| test_mode=False, | |||
| cfg=cfg.dataset) | |||
| test_data_cfg = ConfigDict( | |||
| name=self.dataset_name, | |||
| split='test', | |||
| test_mode=True, | |||
| cfg=cfg.dataset) | |||
| self.train_dataset = MsDataset.load( | |||
| dataset_name=train_data_cfg.name, | |||
| split=train_data_cfg.split, | |||
| cfg=train_data_cfg.cfg, | |||
| namespace='damo', | |||
| test_mode=train_data_cfg.test_mode) | |||
| assert next( | |||
| iter(self.train_dataset.config_kwargs['split_config'].values())) | |||
| self.test_dataset = MsDataset.load( | |||
| dataset_name=test_data_cfg.name, | |||
| split=test_data_cfg.split, | |||
| cfg=test_data_cfg.cfg, | |||
| namespace='damo', | |||
| test_mode=test_data_cfg.test_mode) | |||
| assert next( | |||
| iter(self.test_dataset.config_kwargs['split_config'].values())) | |||
| self.max_epochs = max_epochs | |||
| @unittest.skip('skip since the model is set to private for now') | |||
| def test_trainer(self): | |||
| kwargs = dict( | |||
| model=self.model_id, | |||
| train_dataset=self.train_dataset, | |||
| eval_dataset=self.test_dataset, | |||
| work_dir='./work_dir') | |||
| trainer = build_trainer( | |||
| name=Trainers.referring_video_object_segmentation, | |||
| default_args=kwargs) | |||
| trainer.train() | |||
| results_files = os.listdir(trainer.work_dir) | |||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||
| @unittest.skip('skip since the model is set to private for now') | |||
| def test_trainer_with_model_and_args(self): | |||
| cache_path = snapshot_download(self.model_id) | |||
| model = MovieSceneSegmentationModel.from_pretrained(cache_path) | |||
| kwargs = dict( | |||
| cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), | |||
| model=model, | |||
| train_dataset=self.train_dataset, | |||
| eval_dataset=self.test_dataset, | |||
| work_dir='./work_dir') | |||
| trainer = build_trainer( | |||
| name=Trainers.referring_video_object_segmentation, | |||
| default_args=kwargs) | |||
| trainer.train() | |||
| results_files = os.listdir(trainer.work_dir) | |||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -0,0 +1,94 @@ | |||
| import os | |||
| import unittest | |||
| import json | |||
| import requests | |||
| import torch | |||
| import torch.distributed as dist | |||
| import torch.multiprocessing as mp | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.metainfo import Trainers | |||
| from modelscope.msdatasets import MsDataset | |||
| from modelscope.trainers import build_trainer | |||
| from modelscope.trainers.multi_modal.team.team_trainer_utils import ( | |||
| collate_fn, train_mapping, val_mapping) | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import DownloadMode, ModeKeys, ModelFile | |||
| from modelscope.utils.logger import get_logger | |||
| from modelscope.utils.test_utils import test_level | |||
| logger = get_logger() | |||
| def train_worker(device_id): | |||
| model_id = 'damo/multi-modal_team-vit-large-patch14_multi-modal-similarity' | |||
| ckpt_dir = './ckpt' | |||
| os.makedirs(ckpt_dir, exist_ok=True) | |||
| # Use epoch=1 for faster training here | |||
| cfg = Config({ | |||
| 'framework': 'pytorch', | |||
| 'task': 'multi-modal-similarity', | |||
| 'pipeline': { | |||
| 'type': 'multi-modal-similarity' | |||
| }, | |||
| 'model': { | |||
| 'type': 'team-multi-modal-similarity' | |||
| }, | |||
| 'dataset': { | |||
| 'name': 'Caltech101', | |||
| 'class_num': 101 | |||
| }, | |||
| 'preprocessor': {}, | |||
| 'train': { | |||
| 'epoch': 1, | |||
| 'batch_size': 32, | |||
| 'ckpt_dir': ckpt_dir | |||
| }, | |||
| 'evaluation': { | |||
| 'batch_size': 64 | |||
| } | |||
| }) | |||
| cfg_file = '{}/{}'.format(ckpt_dir, ModelFile.CONFIGURATION) | |||
| cfg.dump(cfg_file) | |||
| train_dataset = MsDataset.load( | |||
| cfg.dataset.name, | |||
| namespace='modelscope', | |||
| split='train', | |||
| download_mode=DownloadMode.FORCE_REDOWNLOAD).to_hf_dataset() | |||
| train_dataset = train_dataset.with_transform(train_mapping) | |||
| val_dataset = MsDataset.load( | |||
| cfg.dataset.name, | |||
| namespace='modelscope', | |||
| split='validation', | |||
| download_mode=DownloadMode.FORCE_REDOWNLOAD).to_hf_dataset() | |||
| val_dataset = val_dataset.with_transform(val_mapping) | |||
| default_args = dict( | |||
| cfg_file=cfg_file, | |||
| model=model_id, | |||
| device_id=device_id, | |||
| data_collator=collate_fn, | |||
| train_dataset=train_dataset, | |||
| val_dataset=val_dataset) | |||
| trainer = build_trainer( | |||
| name=Trainers.image_classification_team, default_args=default_args) | |||
| trainer.train() | |||
| trainer.evaluate() | |||
| class TEAMTransferTrainerTest(unittest.TestCase): | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_trainer(self): | |||
| if torch.cuda.device_count() > 0: | |||
| train_worker(device_id=0) | |||
| else: | |||
| train_worker(device_id=-1) | |||
| logger.info('Training done') | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||
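torch.distributed and torch.multiprocessing are imported above, but the test drives train_worker on a single device. If multi-GPU coverage were wanted, a conventional launch would look like the sketch below; whether the TEAM trainer actually supports spawned workers is an assumption here:

    import torch
    import torch.multiprocessing as mp

    def launch():
        world_size = torch.cuda.device_count()
        if world_size > 1:
            # mp.spawn passes each process rank as the first argument,
            # which train_worker consumes as device_id
            mp.spawn(train_worker, nprocs=world_size)
        else:
            train_worker(device_id=0 if world_size == 1 else -1)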
| @@ -119,7 +119,7 @@ class TestTrainerWithNlp(unittest.TestCase): | |||
| checkpoint_path=os.path.join(self.tmp_dir, 'epoch_10.pth')) | |||
| self.assertTrue(Metrics.accuracy in eval_results) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| @unittest.skip('skip for now before test is re-configured') | |||
| def test_trainer_with_configured_datasets(self): | |||
| model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' | |||
| cfg: Config = read_config(model_id) | |||
| @@ -223,13 +223,31 @@ class TestTrainerWithNlp(unittest.TestCase): | |||
| trainer, 'trainer_continue_train', level='strict'): | |||
| trainer.train(os.path.join(self.tmp_dir, 'iter_3.pth')) | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_trainer_with_evaluation(self): | |||
| tmp_dir = tempfile.TemporaryDirectory().name | |||
| if not os.path.exists(tmp_dir): | |||
| os.makedirs(tmp_dir) | |||
| model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny' | |||
| cache_path = snapshot_download(model_id) | |||
| model = SbertForSequenceClassification.from_pretrained(cache_path) | |||
| kwargs = dict( | |||
| cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), | |||
| model=model, | |||
| eval_dataset=self.dataset, | |||
| work_dir=tmp_dir)  # use the temp dir created above | |||
| trainer = build_trainer(default_args=kwargs) | |||
| print(trainer.evaluate(cache_path + '/pytorch_model.bin')) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_trainer_with_model_and_args(self): | |||
| tmp_dir = tempfile.TemporaryDirectory().name | |||
| if not os.path.exists(tmp_dir): | |||
| os.makedirs(tmp_dir) | |||
| model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' | |||
| model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny' | |||
| cache_path = snapshot_download(model_id) | |||
| model = SbertForSequenceClassification.from_pretrained(cache_path) | |||
| kwargs = dict( | |||