yingda.chen 3 years ago
commit 0260b89431
21 changed files with 114 additions and 73 deletions
  1. +1 -1  docs/source/quick_start.md
  2. +7 -2  modelscope/models/multi_modal/clip/clip_model.py
  3. +2 -1  modelscope/models/multi_modal/image_captioning_model.py
  4. +3 -2  modelscope/pipelines/audio/ans_pipeline.py
  5. +4 -3  modelscope/pipelines/audio/linear_aec_pipeline.py
  6. +2 -1  modelscope/pipelines/cv/action_recognition_pipeline.py
  7. +5 -2  modelscope/pipelines/cv/animal_recog_pipeline.py
  8. +3 -2  modelscope/pipelines/cv/image_cartoon_pipeline.py
  9. +4 -3  modelscope/pipelines/cv/image_matting_pipeline.py
  10. +2 -1  modelscope/pipelines/cv/ocr_detection_pipeline.py
  11. +2 -1  modelscope/pipelines/nlp/fill_mask_pipeline.py
  12. +2 -1  modelscope/pipelines/nlp/sentence_similarity_pipeline.py
  13. +2 -1  modelscope/pipelines/nlp/sequence_classification_pipeline.py
  14. +2 -1  modelscope/pipelines/nlp/text_generation_pipeline.py
  15. +2 -4  modelscope/pipelines/nlp/word_segmentation_pipeline.py
  16. +3 -2  modelscope/pipelines/nlp/zero_shot_classification_pipeline.py
  17. +54 -35  modelscope/pipelines/outputs.py
  18. +4 -3  tests/pipelines/test_base.py
  19. +2 -1  tests/pipelines/test_image_captioning.py
  20. +6 -5  tests/pipelines/test_image_matting.py
  21. +2 -1  tests/pipelines/test_person_image_cartoon.py

+1 -1  docs/source/quick_start.md

@@ -29,7 +29,7 @@ pip install model_scope[all] -f https://pai-vision-data-hz.oss-cn-zhangjiakou.al
```
### Install from source
Suitable for local development and debugging; after modifying the source code you can run it directly
Before downloading the source, first contact (临在, 谦言, 颖达, 一耘) to request repository access, then clone the code to your local machine
To download the source, simply clone the code to your local machine
```shell
git clone git@gitlab.alibaba-inc.com:Ali-MaaS/MaaS-lib.git modelscope
git fetch origin master


+7 -2  modelscope/models/multi_modal/clip/clip_model.py

@@ -108,7 +108,11 @@ class CLIPForMultiModalEmbedding(Model):
return text_ids_tensor, text_mask_tensor

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
output = {'img_embedding': None, 'text_embedding': None}
from modelscope.pipelines.outputs import OutputKeys
output = {
OutputKeys.IMG_EMBEDDING: None,
OutputKeys.TEXT_EMBEDDING: None
}
if 'img' in input and input['img'] is not None:
input_img = input['img']
if isinstance(input_img, Image.Image):
@@ -130,7 +134,8 @@ class CLIPForMultiModalEmbedding(Model):

img_embedding = self.clip_model(
input_data=img_tensor, input_type='img')
output['img_embedding'] = img_embedding.data.cpu().numpy()
from modelscope.pipelines.outputs import OutputKeys
output[OutputKeys.IMG_EMBEDDING] = img_embedding.data.cpu().numpy()

if 'text' in input and input['text'] is not None:
text_str = input['text']


+2 -1  modelscope/models/multi_modal/image_captioning_model.py

@@ -76,9 +76,10 @@ class OfaForImageCaptioning(Model):
input = fairseq.utils.move_to_cuda(input, device=self._device)
results, _ = self.eval_caption(self.task, self.generator, self.models,
input)
from ...pipelines.outputs import OutputKeys
return {
'image_id': results[0]['image_id'],
'caption': results[0]['caption']
OutputKeys.CAPTION: results[0][OutputKeys.CAPTION]
}

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:


+3 -2  modelscope/pipelines/audio/ans_pipeline.py

@@ -10,6 +10,7 @@ from modelscope.metainfo import Pipelines
from modelscope.utils.constant import Tasks
from ..base import Input, Pipeline
from ..builder import PIPELINES
from ..outputs import OutputKeys


def audio_norm(x):
@@ -108,10 +109,10 @@ class ANSPipeline(Pipeline):
current_idx += stride
else:
outputs = self.model(ndarray)['wav_l2'][0].cpu().numpy()
return {'output_pcm': outputs[:nsamples]}
return {OutputKeys.OUTPUT_PCM: outputs[:nsamples]}

def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
if 'output_path' in kwargs.keys():
sf.write(kwargs['output_path'], inputs['output_pcm'],
sf.write(kwargs['output_path'], inputs[OutputKeys.OUTPUT_PCM],
self.SAMPLE_RATE)
return inputs
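
This pipeline returns the denoised waveform under `OutputKeys.OUTPUT_PCM`. A minimal consumer sketch (the model id, input file, and 16 kHz sample rate are illustrative assumptions, not taken from this commit):

```python
# Hedged usage sketch: read OUTPUT_PCM from an acoustic noise suppression
# result and write it to disk. Model id, input path, and sample rate are
# placeholders, not values from this commit.
import soundfile as sf

from modelscope.pipelines import pipeline
from modelscope.pipelines.outputs import OutputKeys
from modelscope.utils.constant import Tasks

ans = pipeline(Tasks.speech_signal_process, model='<your-ans-model-id>')
result = ans('data/test/audios/speech_with_noise.wav')
sf.write('denoised.wav', result[OutputKeys.OUTPUT_PCM], 16000)
```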

+4 -3  modelscope/pipelines/audio/linear_aec_pipeline.py

@@ -12,6 +12,7 @@ from modelscope.preprocessors.audio import LinearAECAndFbank
from modelscope.utils.constant import ModelFile, Tasks
from ..base import Pipeline
from ..builder import PIPELINES
from ..outputs import OutputKeys

FEATURE_MVN = 'feature.DEY.mvn.txt'

@@ -120,7 +121,7 @@ class LinearAECPipeline(Pipeline):
}
"""
output_data = self._process(inputs['feature'], inputs['base'])
return {'output_pcm': output_data}
return {OutputKeys.OUTPUT_PCM: output_data}

def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
r"""The post process. Will save audio to file, if the output_path is given.
@@ -140,8 +141,8 @@ class LinearAECPipeline(Pipeline):
"""
if 'output_path' in kwargs.keys():
wav.write(kwargs['output_path'], self.preprocessor.SAMPLE_RATE,
inputs['output_pcm'].astype(np.int16))
inputs['output_pcm'] = inputs['output_pcm'] / 32768.0
inputs[OutputKeys.OUTPUT_PCM].astype(np.int16))
inputs[OutputKeys.OUTPUT_PCM] = inputs[OutputKeys.OUTPUT_PCM] / 32768.0
return inputs

def _process(self, fbanks, mixture):


+2 -1  modelscope/pipelines/cv/action_recognition_pipeline.py

@@ -16,6 +16,7 @@ from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from ..base import Pipeline
from ..builder import PIPELINES
from ..outputs import OutputKeys

logger = get_logger()

@@ -49,7 +50,7 @@ class ActionRecognitionPipeline(Pipeline):
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
pred = self.perform_inference(input['video_data'])
output_label = self.label_mapping[str(pred)]
return {'output_label': output_label}
return {OutputKeys.LABELS: output_label}

@torch.no_grad()
def perform_inference(self, data, max_bsz=4):


+5 -2  modelscope/pipelines/cv/animal_recog_pipeline.py

@@ -18,6 +18,7 @@ from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from ..base import Pipeline
from ..builder import PIPELINES
from ..outputs import OutputKeys

logger = get_logger()

@@ -121,7 +122,9 @@ class AnimalRecogPipeline(Pipeline):
label_mapping = f.readlines()
score = torch.max(inputs['outputs'])
inputs = {
'scores': score.item(),
'labels': label_mapping[inputs['outputs'].argmax()].split('\t')[1]
OutputKeys.SCORES:
score.item(),
OutputKeys.LABELS:
label_mapping[inputs['outputs'].argmax()].split('\t')[1]
}
return inputs

+3 -2  modelscope/pipelines/cv/image_cartoon_pipeline.py

@@ -17,6 +17,7 @@ from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from ..base import Pipeline
from ..builder import PIPELINES
from ..outputs import OutputKeys

if tf.__version__ >= '2.0':
tf = tf.compat.v1
@@ -94,7 +95,7 @@ class ImageCartoonPipeline(Pipeline):
landmarks = self.detect_face(img)
if landmarks is None:
print('No face detected!')
return {'output_png': None}
return {OutputKeys.OUTPUT_IMG: None}

# background process
pad_bg, pad_h, pad_w = padTo16x(img_brg)
@@ -143,7 +144,7 @@ class ImageCartoonPipeline(Pipeline):

res = cv2.resize(res, (ori_w, ori_h), interpolation=cv2.INTER_AREA)

return {'output_png': res}
return {OutputKeys.OUTPUT_IMG: res}

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs

+4 -3  modelscope/pipelines/cv/image_matting_pipeline.py

@@ -12,6 +12,7 @@ from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from ..base import Pipeline
from ..builder import PIPELINES
from ..outputs import OutputKeys

logger = get_logger()

@@ -60,9 +61,9 @@ class ImageMattingPipeline(Pipeline):
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
with self._session.as_default():
feed_dict = {self.input_name: input['img']}
output_png = self._session.run(self.output, feed_dict=feed_dict)
output_png = cv2.cvtColor(output_png, cv2.COLOR_RGBA2BGRA)
return {'output_png': output_png}
output_img = self._session.run(self.output, feed_dict=feed_dict)
output_img = cv2.cvtColor(output_img, cv2.COLOR_RGBA2BGRA)
return {OutputKeys.OUTPUT_IMG: output_img}

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs

+2 -1  modelscope/pipelines/cv/ocr_detection_pipeline.py

@@ -16,6 +16,7 @@ from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from ..base import Pipeline
from ..builder import PIPELINES
from ..outputs import OutputKeys
from .ocr_utils import model_resnet_mutex_v4_linewithchar, ops, utils

if tf.__version__ >= '2.0':
@@ -174,5 +175,5 @@ class OCRDetectionPipeline(Pipeline):
dt_nms = utils.nms_python(dt_n9)
dt_polygons = np.array([o[:8] for o in dt_nms])

result = {'det_polygons': dt_polygons}
result = {OutputKeys.POLYGONS: dt_polygons}
return result

+2 -1  modelscope/pipelines/nlp/fill_mask_pipeline.py

@@ -9,6 +9,7 @@ from ...utils.config import Config
from ...utils.constant import ModelFile, Tasks
from ..base import Pipeline, Tensor
from ..builder import PIPELINES
from ..outputs import OutputKeys

__all__ = ['FillMaskPipeline']
_type_map = {'veco': 'roberta', 'sbert': 'bert'}
@@ -96,4 +97,4 @@ class FillMaskPipeline(Pipeline):
pred_string = rep_tokens(pred_string, self.rep_map[process_type])
pred_strings.append(pred_string)

return {'text': pred_strings}
return {OutputKeys.TEXT: pred_strings}

+2 -1  modelscope/pipelines/nlp/sentence_similarity_pipeline.py

@@ -9,6 +9,7 @@ from modelscope.utils.constant import Tasks
from ...models import Model
from ..base import Input, Pipeline
from ..builder import PIPELINES
from ..outputs import OutputKeys

__all__ = ['SentenceSimilarityPipeline']

@@ -59,4 +60,4 @@ class SentenceSimilarityPipeline(Pipeline):
probs = probs[cls_ids].tolist()
cls_names = [self.model.id2label[cid] for cid in cls_ids]
b = 0
return {'scores': probs[b], 'labels': cls_names[b]}
return {OutputKeys.SCORES: probs[b], OutputKeys.LABELS: cls_names[b]}

+2 -1  modelscope/pipelines/nlp/sequence_classification_pipeline.py

@@ -9,6 +9,7 @@ from modelscope.utils.constant import Tasks
from ...models import Model
from ..base import Input, Pipeline
from ..builder import PIPELINES
from ..outputs import OutputKeys

__all__ = ['SequenceClassificationPipeline']

@@ -64,4 +65,4 @@ class SequenceClassificationPipeline(Pipeline):

cls_names = [self.model.id2label[cid] for cid in cls_ids]

return {'scores': probs, 'labels': cls_names}
return {OutputKeys.SCORES: probs, OutputKeys.LABELS: cls_names}

+2 -1  modelscope/pipelines/nlp/text_generation_pipeline.py

@@ -7,6 +7,7 @@ from modelscope.preprocessors import TextGenerationPreprocessor
from modelscope.utils.constant import Tasks
from ..base import Pipeline, Tensor
from ..builder import PIPELINES
from ..outputs import OutputKeys

__all__ = ['TextGenerationPipeline']

@@ -61,4 +62,4 @@ class TextGenerationPipeline(Pipeline):
for _old, _new in replace_tokens_roberta:
pred_string = pred_string.replace(_old, _new)
pred_string.strip()
return {'text': pred_string}
return {OutputKeys.TEXT: pred_string}

+2 -4  modelscope/pipelines/nlp/word_segmentation_pipeline.py

@@ -7,6 +7,7 @@ from modelscope.preprocessors import TokenClassifcationPreprocessor
from modelscope.utils.constant import Tasks
from ..base import Pipeline, Tensor
from ..builder import PIPELINES
from ..outputs import OutputKeys

__all__ = ['WordSegmentationPipeline']

@@ -63,7 +64,4 @@ class WordSegmentationPipeline(Pipeline):
if chunk:
chunks.append(chunk)
seg_result = ' '.join(chunks)
rst = {
'output': seg_result,
}
return rst
return {OutputKeys.OUTPUT: seg_result}

+3 -2  modelscope/pipelines/nlp/zero_shot_classification_pipeline.py

@@ -14,6 +14,7 @@ from ...preprocessors import ZeroShotClassificationPreprocessor
from ...utils.constant import Tasks
from ..base import Input, Pipeline
from ..builder import PIPELINES
from ..outputs import OutputKeys

__all__ = ['ZeroShotClassificationPipeline']

@@ -91,7 +92,7 @@ class ZeroShotClassificationPipeline(Pipeline):

reversed_index = list(reversed(scores.argsort()))
result = {
'labels': [candidate_labels[i] for i in reversed_index],
'scores': [scores[i].item() for i in reversed_index],
OutputKeys.LABELS: [candidate_labels[i] for i in reversed_index],
OutputKeys.SCORES: [scores[i].item() for i in reversed_index],
}
return result

+54 -35  modelscope/pipelines/outputs.py

@@ -2,54 +2,72 @@

from modelscope.utils.constant import Tasks


class OutputKeys(object):
SCORES = 'scores'
LABELS = 'labels'
POSES = 'poses'
CAPTION = 'caption'
BOXES = 'boxes'
TEXT = 'text'
POLYGONS = 'polygons'
OUTPUT = 'output'
OUTPUT_IMG = 'output_img'
OUTPUT_PCM = 'output_pcm'
IMG_EMBEDDING = 'img_embedding'
TEXT_EMBEDDING = 'text_embedding'


TASK_OUTPUTS = {

# ============ vision tasks ===================

# image classification result for single sample
# {
# "labels": ["dog", "horse", "cow", "cat"],
# "scores": [0.9, 0.1, 0.05, 0.05]
# "labels": ["dog", "horse", "cow", "cat"],
# }
Tasks.image_classification: ['scores', 'labels'],
Tasks.image_tagging: ['scores', 'labels'],
Tasks.image_classification: [OutputKeys.SCORES, OutputKeys.LABELS],
Tasks.image_tagging: [OutputKeys.SCORES, OutputKeys.LABELS],

# object detection result for single sample
# {
# "scores": [0.9, 0.1, 0.05, 0.05]
# "labels": ["dog", "horse", "cow", "cat"],
# "boxes": [
# [x1, y1, x2, y2],
# [x1, y1, x2, y2],
# [x1, y1, x2, y2],
# ],
# "labels": ["dog", "horse", "cow", "cat"],
# "scores": [0.9, 0.1, 0.05, 0.05]
# }
Tasks.object_detection: ['scores', 'labels', 'boxes'],
Tasks.object_detection:
[OutputKeys.SCORES, OutputKeys.LABELS, OutputKeys.BOXES],

# instance segmentation result for single sample
# {
# "masks": [
# np.array in bgr channel order
# ],
# "scores": [0.9, 0.1, 0.05, 0.05],
# "labels": ["dog", "horse", "cow", "cat"],
# "scores": [0.9, 0.1, 0.05, 0.05]
# "boxes": [
# np.array in bgr channel order
# ]
# }
Tasks.image_segmentation: ['scores', 'labels', 'boxes'],
Tasks.image_segmentation:
[OutputKeys.SCORES, OutputKeys.LABELS, OutputKeys.BOXES],

# image generation/editing/matting result for single sample
# {
# "output_png": np.array with shape(h, w, 4)
# "output_img": np.array with shape(h, w, 4)
# for matting or (h, w, 3) for general purpose
# }
Tasks.image_editing: ['output_png'],
Tasks.image_matting: ['output_png'],
Tasks.image_generation: ['output_png'],
Tasks.image_editing: [OutputKeys.OUTPUT_IMG],
Tasks.image_matting: [OutputKeys.OUTPUT_IMG],
Tasks.image_generation: [OutputKeys.OUTPUT_IMG],

# action recognition result for single video
# {
# "output_label": "abseiling"
# }
Tasks.action_recognition: ['output_label'],
Tasks.action_recognition: [OutputKeys.LABELS],

# pose estimation result for single sample
# {
@@ -58,55 +76,55 @@ TASK_OUTPUTS = {
# "boxes": np.array with shape [num_pose, 4], each box is
# [x1, y1, x2, y2]
# }
Tasks.pose_estimation: ['poses', 'boxes'],
Tasks.pose_estimation: [OutputKeys.POSES, OutputKeys.BOXES],

# ocr detection result for single sample
# {
# "det_polygons": np.array with shape [num_text, 8], each box is
# "polygons": np.array with shape [num_text, 8], each polygon is
# [x1, y1, x2, y2, x3, y3, x4, y4]
# }
Tasks.ocr_detection: ['det_polygons'],
Tasks.ocr_detection: [OutputKeys.POLYGONS],

# ============ nlp tasks ===================

# text classification result for single sample
# {
# "labels": ["happy", "sad", "calm", "angry"],
# "scores": [0.9, 0.1, 0.05, 0.05]
# "labels": ["happy", "sad", "calm", "angry"],
# }
Tasks.text_classification: ['scores', 'labels'],
Tasks.text_classification: [OutputKeys.SCORES, OutputKeys.LABELS],

# text generation result for single sample
# {
# "text": "this is text generated by a model."
# "text": "this is the text generated by a model."
# }
Tasks.text_generation: ['text'],
Tasks.text_generation: [OutputKeys.TEXT],

# fill mask result for single sample
# {
# "text": "this is the text which masks filled by model."
# }
Tasks.fill_mask: ['text'],
Tasks.fill_mask: [OutputKeys.TEXT],

# word segmentation result for single sample
# {
# "output": "今天 天气 不错 , 适合 出去 游玩"
# }
Tasks.word_segmentation: ['output'],
Tasks.word_segmentation: [OutputKeys.OUTPUT],

# sentence similarity result for single sample
# {
# "labels": "1",
# "scores": 0.9
# "labels": "1",
# }
Tasks.sentence_similarity: ['scores', 'labels'],
Tasks.sentence_similarity: [OutputKeys.SCORES, OutputKeys.LABELS],

# zero-shot classification result for single sample
# {
# "labels": ["happy", "sad", "calm", "angry"],
# "scores": [0.9, 0.1, 0.05, 0.05]
# "labels": ["happy", "sad", "calm", "angry"],
# }
Tasks.zero_shot_classification: ['scores', 'labels'],
Tasks.zero_shot_classification: [OutputKeys.SCORES, OutputKeys.LABELS],

# ============ audio tasks ===================

@@ -114,7 +132,7 @@ TASK_OUTPUTS = {
# {
# "output_pcm": np.array with shape(samples,) and dtype float32
# }
Tasks.speech_signal_process: ['output_pcm'],
Tasks.speech_signal_process: [OutputKeys.OUTPUT_PCM],

# ============ multi-modal tasks ===================

@@ -122,14 +140,15 @@ TASK_OUTPUTS = {
# {
# "caption": "this is an image caption text."
# }
Tasks.image_captioning: ['caption'],
Tasks.image_captioning: [OutputKeys.CAPTION],

# multi-modal embedding result for single sample
# {
# "img_embedding": np.array with shape [1, D],
# "text_embedding": np.array with shape [1, D]
# }
Tasks.multi_modal_embedding: ['img_embedding', 'text_embedding'],
Tasks.multi_modal_embedding:
[OutputKeys.IMG_EMBEDDING, OutputKeys.TEXT_EMBEDDING],

# visual grounding result for single sample
# {
@@ -140,11 +159,11 @@ TASK_OUTPUTS = {
# ],
# "scores": [0.9, 0.1, 0.05, 0.05]
# }
Tasks.visual_grounding: ['boxes', 'scores'],
Tasks.visual_grounding: [OutputKeys.BOXES, OutputKeys.SCORES],

# text_to_image result for a single sample
# {
# "image": np.ndarray with shape [height, width, 3]
# "output_img": np.ndarray with shape [height, width, 3]
# }
Tasks.text_to_image_synthesis: ['image']
Tasks.text_to_image_synthesis: [OutputKeys.OUTPUT_IMG]
}
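
The `OutputKeys` constants and the `TASK_OUTPUTS` table above can be used together by callers: results are read through the shared key constants, and `TASK_OUTPUTS` lists the keys each task is expected to return. A minimal sketch (the model id and image path are placeholders, not taken from this commit):

```python
# Hedged usage sketch: access a pipeline result via OutputKeys and check it
# against the keys declared in TASK_OUTPUTS. Model id and image path are
# placeholders, not values from this commit.
from modelscope.pipelines import pipeline
from modelscope.pipelines.outputs import OutputKeys, TASK_OUTPUTS
from modelscope.utils.constant import Tasks

classifier = pipeline(Tasks.image_classification, model='<your-model-id>')
result = classifier('data/test/images/image1.jpg')

top_label = result[OutputKeys.LABELS][0]   # e.g. "dog"
top_score = result[OutputKeys.SCORES][0]   # e.g. 0.9

# Every key declared for the task should be present in the result dict.
assert all(key in result for key in TASK_OUTPUTS[Tasks.image_classification])
```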

+4 -3  tests/pipelines/test_base.py

@@ -8,6 +8,7 @@ import PIL

from modelscope.pipelines import Pipeline, pipeline
from modelscope.pipelines.builder import PIPELINES, add_default_pipeline_info
from modelscope.pipelines.outputs import OutputKeys
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from modelscope.utils.registry import default_group
@@ -68,7 +69,7 @@ class CustomPipelineTest(unittest.TestCase):
outputs['filename'] = inputs['url']
img = inputs['img']
new_image = img.resize((img.width // 2, img.height // 2))
outputs['output_png'] = np.array(new_image)
outputs[OutputKeys.OUTPUT_IMG] = np.array(new_image)
return outputs

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
@@ -83,13 +84,13 @@ class CustomPipelineTest(unittest.TestCase):
img_url = 'data/test/images/image1.jpg'
output = pipe(img_url)
self.assertEqual(output['filename'], img_url)
self.assertEqual(output['output_png'].shape, (318, 512, 3))
self.assertEqual(output[OutputKeys.OUTPUT_IMG].shape, (318, 512, 3))

outputs = pipe([img_url for i in range(4)])
self.assertEqual(len(outputs), 4)
for out in outputs:
self.assertEqual(out['filename'], img_url)
self.assertEqual(out['output_png'].shape, (318, 512, 3))
self.assertEqual(out[OutputKeys.OUTPUT_IMG].shape, (318, 512, 3))


if __name__ == '__main__':


+2 -1  tests/pipelines/test_image_captioning.py

@@ -3,6 +3,7 @@
import unittest

from modelscope.pipelines import pipeline
from modelscope.pipelines.outputs import OutputKeys
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level

@@ -15,7 +16,7 @@ class ImageCaptionTest(unittest.TestCase):
Tasks.image_captioning,
model='damo/ofa_image-caption_coco_large_en')
result = img_captioning('data/test/images/image_captioning.png')
print(result['caption'])
print(result[OutputKeys.CAPTION])


if __name__ == '__main__':


+6 -5  tests/pipelines/test_image_matting.py

@@ -9,6 +9,7 @@ import cv2
from modelscope.fileio import File
from modelscope.msdatasets import MsDataset
from modelscope.pipelines import pipeline
from modelscope.pipelines.outputs import OutputKeys
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.test_utils import test_level

@@ -29,7 +30,7 @@ class ImageMattingTest(unittest.TestCase):
img_matting = pipeline(Tasks.image_matting, model=tmp_dir)

result = img_matting('data/test/images/image_matting.png')
cv2.imwrite('result.png', result['output_png'])
cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG])

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_dataset(self):
@@ -41,7 +42,7 @@ class ImageMattingTest(unittest.TestCase):
img_matting = pipeline(Tasks.image_matting, model=self.model_id)
# note that for dataset output, the inference-output is a Generator that can be iterated.
result = img_matting(dataset)
cv2.imwrite('result.png', next(result)['output_png'])
cv2.imwrite('result.png', next(result)[OutputKeys.OUTPUT_IMG])
print(f'Output written to {osp.abspath("result.png")}')

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@@ -49,7 +50,7 @@ class ImageMattingTest(unittest.TestCase):
img_matting = pipeline(Tasks.image_matting, model=self.model_id)

result = img_matting('data/test/images/image_matting.png')
cv2.imwrite('result.png', result['output_png'])
cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG])
print(f'Output written to {osp.abspath("result.png")}')

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@@ -57,7 +58,7 @@ class ImageMattingTest(unittest.TestCase):
img_matting = pipeline(Tasks.image_matting)

result = img_matting('data/test/images/image_matting.png')
cv2.imwrite('result.png', result['output_png'])
cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG])
print(f'Output written to {osp.abspath("result.png")}')

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@@ -67,7 +68,7 @@ class ImageMattingTest(unittest.TestCase):
img_matting = pipeline(Tasks.image_matting, model=self.model_id)
result = img_matting(dataset)
for i in range(10):
cv2.imwrite(f'result_{i}.png', next(result)['output_png'])
cv2.imwrite(f'result_{i}.png', next(result)[OutputKeys.OUTPUT_IMG])
print(
f'Output written to dir: {osp.dirname(osp.abspath("result_0.png"))}'
)


+2 -1  tests/pipelines/test_person_image_cartoon.py

@@ -7,6 +7,7 @@ import cv2

from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.outputs import OutputKeys
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level

@@ -22,7 +23,7 @@ class ImageCartoonTest(unittest.TestCase):
def pipeline_inference(self, pipeline: Pipeline, input_location: str):
result = pipeline(input_location)
if result is not None:
cv2.imwrite('result.png', result['output_png'])
cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG])
print(f'Output written to {osp.abspath("result.png")}')

@unittest.skip('deprecated, download model from model hub instead')

