diff --git a/modelscope/models/multi_modal/mmr/models/clip_for_mm_video_embedding.py b/modelscope/models/multi_modal/mmr/models/clip_for_mm_video_embedding.py
index 2a72985f..f1b1a6c7 100644
--- a/modelscope/models/multi_modal/mmr/models/clip_for_mm_video_embedding.py
+++ b/modelscope/models/multi_modal/mmr/models/clip_for_mm_video_embedding.py
@@ -1,9 +1,13 @@
 # The implementation is adopted from the CLIP4Clip implementation,
 # made pubicly available under Apache License, Version 2.0 at https://github.com/ArrowLuo/CLIP4Clip
+import os
 import random
+import uuid
 from os.path import exists
+from tempfile import TemporaryDirectory
 from typing import Any, Dict
+from urllib.parse import urlparse
 
 import json
 import numpy as np
@@ -11,6 +15,7 @@ import torch
 from decord import VideoReader, cpu
 from PIL import Image
 
+from modelscope.hub.file_download import http_get_file
 from modelscope.metainfo import Models
 from modelscope.models import TorchModel
 from modelscope.models.builder import MODELS
@@ -68,12 +73,16 @@ class VideoCLIPForMultiModalEmbedding(TorchModel):
         self.model.to(self.device)
 
     def _get_text(self, caption, tokenizer, enable_zh=False):
-        if len(caption) == 3:
-            _caption_text, s, e = caption
-        elif len(caption) == 4:
-            _caption_text, s, e, pos = caption
-        else:
-            NotImplementedError
+
+        if isinstance(caption, str):
+            _caption_text, s, e = caption, None, None
+        elif isinstance(caption, tuple):
+            if len(caption) == 3:
+                _caption_text, s, e = caption
+            elif len(caption) == 4:
+                _caption_text, s, e, pos = caption
+            else:
+                raise NotImplementedError
 
         if isinstance(_caption_text, list):
             caption_text = random.choice(_caption_text)
@@ -137,11 +146,25 @@ class VideoCLIPForMultiModalEmbedding(TorchModel):
         elif start_time == end_time:
             end_time = end_time + 1
 
-        if exists(video_path):
+        url_parsed = urlparse(video_path)
+        if url_parsed.scheme in ('file', '') and exists(
+                url_parsed.path):  # Possibly a local file
             vreader = VideoReader(video_path, ctx=cpu(0))
         else:
-            logger.error('non video input, output is wrong!!!')
-            return video, video_mask
+            try:
+                with TemporaryDirectory() as temporary_cache_dir:
+                    random_str = uuid.uuid4().hex
+                    http_get_file(
+                        url=video_path,
+                        local_dir=temporary_cache_dir,
+                        file_name=random_str,
+                        cookies=None)
+                    temp_file_path = os.path.join(temporary_cache_dir,
+                                                  random_str)
+                    vreader = VideoReader(temp_file_path, ctx=cpu(0))
+            except Exception as ex:
+                logger.error('invalid video input, error: {}'.format(ex))
+                return video, video_mask
 
         fps = vreader.get_avg_fps()
         f_start = 0 if start_time is None else int(start_time * fps)
diff --git a/tests/pipelines/test_video_multi_modal_embedding.py b/tests/pipelines/test_video_multi_modal_embedding.py
index f4aa4d24..afe5940d 100644
--- a/tests/pipelines/test_video_multi_modal_embedding.py
+++ b/tests/pipelines/test_video_multi_modal_embedding.py
@@ -17,8 +17,8 @@ class VideoMultiModalEmbeddingTest(unittest.TestCase, DemoCompatibilityCheck):
         self.task = Tasks.video_multi_modal_embedding
         self.model_id = 'damo/multi_modal_clip_vtretrival_msrvtt_53'
 
-    video_path = 'data/test/videos/multi_modal_test_video_9770.mp4'
-    caption = ('a person is connecting something to system', None, None)
+    video_path = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/videos/multi_modal_test_video_9770.mp4'
+    caption = 'a person is connecting something to system'
     _input = {'video': video_path, 'text': caption}
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
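Note on the `_get_text` hunk: the caption argument now accepts a bare string (meaning "no start/end timestamps") in addition to the pre-existing 3- and 4-tuple forms. A minimal sketch of the widened contract, normalized to one shape; the `unpack_caption` helper is illustrative, not part of the modelscope API:

```python
def unpack_caption(caption):
    """Normalize the caption forms accepted by _get_text to a 4-tuple."""
    if isinstance(caption, str):
        # Bare string: caption text only, no start/end timestamps.
        return caption, None, None, None
    if isinstance(caption, tuple):
        if len(caption) == 3:
            text, s, e = caption
            return text, s, e, None
        if len(caption) == 4:
            return caption
    raise NotImplementedError('unsupported caption: {!r}'.format(caption))
```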
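The video-loading hunk is the core of the change: a plain path (empty URL scheme) or `file://` URL is decoded in place, while anything else is downloaded into a temporary directory under a random name and decoded from there. Below is a self-contained sketch of that dispatch, assuming `decord` and `modelscope.hub.file_download.http_get_file` behave as used in the diff; `load_frames`, `_decode`, and the `num_frames` parameter are illustrative, not part of the modelscope API:

```python
import os
import uuid
from os.path import exists
from tempfile import TemporaryDirectory
from urllib.parse import urlparse

import numpy as np
from decord import VideoReader, cpu

from modelscope.hub.file_download import http_get_file


def load_frames(video_path, num_frames=12):
    """Decode evenly spaced frames from a local path, file:// URL, or http(s) URL."""
    parsed = urlparse(video_path)
    if parsed.scheme in ('file', '') and exists(parsed.path):
        # Plain local paths have an empty scheme, so both spellings land here.
        return _decode(video_path, num_frames)
    # Remote input: fetch into a throwaway directory under a random name and
    # decode before TemporaryDirectory deletes the file again.
    with TemporaryDirectory() as cache_dir:
        file_name = uuid.uuid4().hex
        http_get_file(
            url=video_path,
            local_dir=cache_dir,
            file_name=file_name,
            cookies=None)
        return _decode(os.path.join(cache_dir, file_name), num_frames)


def _decode(path, num_frames):
    vreader = VideoReader(path, ctx=cpu(0))
    indices = np.linspace(0, len(vreader) - 1, num_frames).astype(int)
    return vreader.get_batch(indices).asnumpy()
```

Downloading under a `uuid4` name into a per-call temporary directory keeps concurrent requests for the same URL from clobbering each other, at the cost of re-downloading on every call.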
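With the test update, end-to-end usage looks like the following. This is a hedged sketch assuming the standard modelscope `pipeline` entry point, with the model id, URL, and caption taken from the test above:

```python
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

video_embedder = pipeline(
    task=Tasks.video_multi_modal_embedding,
    model='damo/multi_modal_clip_vtretrival_msrvtt_53')

# 'video' may now be an http(s) URL, and 'text' a bare caption string.
result = video_embedder({
    'video': 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/videos/multi_modal_test_video_9770.mp4',
    'text': 'a person is connecting something to system',
})
```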