Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10233235 * add input definition for pipeline * support multiple input format for single pipeline * change param for body_3d_keypointsmaster
| @@ -0,0 +1,236 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import cv2 | |||||
| import numpy as np | |||||
| from PIL import Image | |||||
| from modelscope.models.base.base_head import Input | |||||
| from modelscope.utils.constant import Tasks | |||||
| class InputKeys(object): | |||||
| IMAGE = 'image' | |||||
| TEXT = 'text' | |||||
| VIDEO = 'video' | |||||
| class InputType(object): | |||||
| IMAGE = 'image' | |||||
| TEXT = 'text' | |||||
| AUDIO = 'audio' | |||||
| VIDEO = 'video' | |||||
| BOX = 'box' | |||||
| DICT = 'dict' | |||||
| LIST = 'list' | |||||
| INT = 'int' | |||||
| INPUT_TYPE = { | |||||
| InputType.IMAGE: (str, np.ndarray, Image.Image), | |||||
| InputType.TEXT: str, | |||||
| InputType.AUDIO: (str, np.ndarray), | |||||
| InputType.VIDEO: (str, np.ndarray, cv2.VideoCapture), | |||||
| InputType.BOX: (list, np.ndarray), | |||||
| InputType.DICT: (dict, type(None)), | |||||
| InputType.LIST: (list, type(None)), | |||||
| InputType.INT: int, | |||||
| } | |||||
| def check_input_type(input_type, input): | |||||
| expected_type = INPUT_TYPE[input_type] | |||||
| assert isinstance(input, expected_type), \ | |||||
| f'invalid input type for {input_type}, expected {expected_type} but got {type(input)}\n {input}' | |||||
| TASK_INPUTS = { | |||||
| # if task input is single var, value is InputType | |||||
| # if task input is a tuple, value is tuple of InputType | |||||
| # if task input is a dict, value is a dict of InputType, where key | |||||
| # equals the one needed in pipeline input dict | |||||
| # if task input is a list, value is a set of input format, in which | |||||
| # each elements corresponds to one input format as described above. | |||||
| # ============ vision tasks =================== | |||||
| Tasks.ocr_detection: | |||||
| InputType.IMAGE, | |||||
| Tasks.ocr_recognition: | |||||
| InputType.IMAGE, | |||||
| Tasks.face_2d_keypoints: | |||||
| InputType.IMAGE, | |||||
| Tasks.face_detection: | |||||
| InputType.IMAGE, | |||||
| Tasks.facial_expression_recognition: | |||||
| InputType.IMAGE, | |||||
| Tasks.face_recognition: | |||||
| InputType.IMAGE, | |||||
| Tasks.human_detection: | |||||
| InputType.IMAGE, | |||||
| Tasks.face_image_generation: | |||||
| InputType.INT, | |||||
| Tasks.image_classification: | |||||
| InputType.IMAGE, | |||||
| Tasks.image_object_detection: | |||||
| InputType.IMAGE, | |||||
| Tasks.image_segmentation: | |||||
| InputType.IMAGE, | |||||
| Tasks.portrait_matting: | |||||
| InputType.IMAGE, | |||||
| # image editing task result for a single image | |||||
| Tasks.skin_retouching: | |||||
| InputType.IMAGE, | |||||
| Tasks.image_super_resolution: | |||||
| InputType.IMAGE, | |||||
| Tasks.image_colorization: | |||||
| InputType.IMAGE, | |||||
| Tasks.image_color_enhancement: | |||||
| InputType.IMAGE, | |||||
| Tasks.image_denoising: | |||||
| InputType.IMAGE, | |||||
| Tasks.image_portrait_enhancement: | |||||
| InputType.IMAGE, | |||||
| Tasks.crowd_counting: | |||||
| InputType.IMAGE, | |||||
| # image generation task result for a single image | |||||
| Tasks.image_to_image_generation: | |||||
| InputType.IMAGE, | |||||
| Tasks.image_to_image_translation: | |||||
| InputType.IMAGE, | |||||
| Tasks.image_style_transfer: | |||||
| InputType.IMAGE, | |||||
| Tasks.image_portrait_stylization: | |||||
| InputType.IMAGE, | |||||
| Tasks.live_category: | |||||
| InputType.VIDEO, | |||||
| Tasks.action_recognition: | |||||
| InputType.VIDEO, | |||||
| Tasks.body_2d_keypoints: | |||||
| InputType.IMAGE, | |||||
| Tasks.body_3d_keypoints: | |||||
| InputType.VIDEO, | |||||
| Tasks.hand_2d_keypoints: | |||||
| InputType.IMAGE, | |||||
| Tasks.video_single_object_tracking: (InputType.VIDEO, InputType.BOX), | |||||
| Tasks.video_category: | |||||
| InputType.VIDEO, | |||||
| Tasks.product_retrieval_embedding: | |||||
| InputType.IMAGE, | |||||
| Tasks.video_embedding: | |||||
| InputType.VIDEO, | |||||
| Tasks.virtual_try_on: (InputType.IMAGE, InputType.IMAGE, InputType.IMAGE), | |||||
| Tasks.text_driven_segmentation: { | |||||
| InputKeys.IMAGE: InputType.IMAGE, | |||||
| InputKeys.TEXT: InputType.TEXT | |||||
| }, | |||||
| Tasks.shop_segmentation: | |||||
| InputType.IMAGE, | |||||
| Tasks.movie_scene_segmentation: | |||||
| InputType.VIDEO, | |||||
| # ============ nlp tasks =================== | |||||
| Tasks.text_classification: [ | |||||
| InputType.TEXT, | |||||
| (InputType.TEXT, InputType.TEXT), | |||||
| { | |||||
| 'text': InputType.TEXT, | |||||
| 'text2': InputType.TEXT | |||||
| }, | |||||
| ], | |||||
| Tasks.sentence_similarity: (InputType.TEXT, InputType.TEXT), | |||||
| Tasks.nli: (InputType.TEXT, InputType.TEXT), | |||||
| Tasks.sentiment_classification: | |||||
| InputType.TEXT, | |||||
| Tasks.zero_shot_classification: | |||||
| InputType.TEXT, | |||||
| Tasks.relation_extraction: | |||||
| InputType.TEXT, | |||||
| Tasks.translation: | |||||
| InputType.TEXT, | |||||
| Tasks.word_segmentation: | |||||
| InputType.TEXT, | |||||
| Tasks.part_of_speech: | |||||
| InputType.TEXT, | |||||
| Tasks.named_entity_recognition: | |||||
| InputType.TEXT, | |||||
| Tasks.text_error_correction: | |||||
| InputType.TEXT, | |||||
| Tasks.sentence_embedding: { | |||||
| 'source_sentence': InputType.LIST, | |||||
| 'sentences_to_compare': InputType.LIST, | |||||
| }, | |||||
| Tasks.passage_ranking: (InputType.TEXT, InputType.TEXT), | |||||
| Tasks.text_generation: | |||||
| InputType.TEXT, | |||||
| Tasks.fill_mask: | |||||
| InputType.TEXT, | |||||
| Tasks.task_oriented_conversation: { | |||||
| 'user_input': InputType.TEXT, | |||||
| 'history': InputType.DICT, | |||||
| }, | |||||
| Tasks.table_question_answering: { | |||||
| 'question': InputType.TEXT, | |||||
| 'history_sql': InputType.DICT, | |||||
| }, | |||||
| Tasks.faq_question_answering: { | |||||
| 'query_set': InputType.LIST, | |||||
| 'support_set': InputType.LIST, | |||||
| }, | |||||
| # ============ audio tasks =================== | |||||
| Tasks.auto_speech_recognition: | |||||
| InputType.AUDIO, | |||||
| Tasks.speech_signal_process: | |||||
| InputType.AUDIO, | |||||
| Tasks.acoustic_echo_cancellation: { | |||||
| 'nearend_mic': InputType.AUDIO, | |||||
| 'farend_speech': InputType.AUDIO | |||||
| }, | |||||
| Tasks.acoustic_noise_suppression: | |||||
| InputType.AUDIO, | |||||
| Tasks.text_to_speech: | |||||
| InputType.TEXT, | |||||
| Tasks.keyword_spotting: | |||||
| InputType.AUDIO, | |||||
| # ============ multi-modal tasks =================== | |||||
| Tasks.image_captioning: | |||||
| InputType.IMAGE, | |||||
| Tasks.visual_grounding: { | |||||
| 'image': InputType.IMAGE, | |||||
| 'text': InputType.TEXT | |||||
| }, | |||||
| Tasks.text_to_image_synthesis: { | |||||
| 'text': InputType.TEXT, | |||||
| }, | |||||
| Tasks.multi_modal_embedding: { | |||||
| 'img': InputType.IMAGE, | |||||
| 'text': InputType.TEXT | |||||
| }, | |||||
| Tasks.generative_multi_modal_embedding: { | |||||
| 'image': InputType.IMAGE, | |||||
| 'text': InputType.TEXT | |||||
| }, | |||||
| Tasks.multi_modal_similarity: { | |||||
| 'img': InputType.IMAGE, | |||||
| 'text': InputType.TEXT | |||||
| }, | |||||
| Tasks.visual_question_answering: { | |||||
| 'image': InputType.IMAGE, | |||||
| 'text': InputType.TEXT | |||||
| }, | |||||
| Tasks.visual_entailment: { | |||||
| 'image': InputType.IMAGE, | |||||
| 'text': InputType.TEXT, | |||||
| 'text2': InputType.TEXT, | |||||
| }, | |||||
| Tasks.action_detection: | |||||
| InputType.VIDEO, | |||||
| Tasks.image_reid_person: | |||||
| InputType.IMAGE, | |||||
| Tasks.video_inpainting: { | |||||
| 'video_input_path': InputType.TEXT, | |||||
| 'video_output_path': InputType.TEXT, | |||||
| 'mask_path': InputType.TEXT, | |||||
| } | |||||
| } | |||||
| @@ -13,6 +13,7 @@ import numpy as np | |||||
| from modelscope.models.base import Model | from modelscope.models.base import Model | ||||
| from modelscope.msdatasets import MsDataset | from modelscope.msdatasets import MsDataset | ||||
| from modelscope.outputs import TASK_OUTPUTS | from modelscope.outputs import TASK_OUTPUTS | ||||
| from modelscope.pipeline_inputs import TASK_INPUTS, check_input_type | |||||
| from modelscope.preprocessors import Preprocessor | from modelscope.preprocessors import Preprocessor | ||||
| from modelscope.utils.config import Config | from modelscope.utils.config import Config | ||||
| from modelscope.utils.constant import Frameworks, ModelFile | from modelscope.utils.constant import Frameworks, ModelFile | ||||
| @@ -210,7 +211,7 @@ class Pipeline(ABC): | |||||
| preprocess_params = kwargs.get('preprocess_params', {}) | preprocess_params = kwargs.get('preprocess_params', {}) | ||||
| forward_params = kwargs.get('forward_params', {}) | forward_params = kwargs.get('forward_params', {}) | ||||
| postprocess_params = kwargs.get('postprocess_params', {}) | postprocess_params = kwargs.get('postprocess_params', {}) | ||||
| self._check_input(input) | |||||
| out = self.preprocess(input, **preprocess_params) | out = self.preprocess(input, **preprocess_params) | ||||
| with device_placement(self.framework, self.device_name): | with device_placement(self.framework, self.device_name): | ||||
| if self.framework == Frameworks.torch: | if self.framework == Frameworks.torch: | ||||
| @@ -225,6 +226,42 @@ class Pipeline(ABC): | |||||
| self._check_output(out) | self._check_output(out) | ||||
| return out | return out | ||||
| def _check_input(self, input): | |||||
| task_name = self.group_key | |||||
| if task_name in TASK_INPUTS: | |||||
| input_type = TASK_INPUTS[task_name] | |||||
| # if multiple input formats are defined, we first | |||||
| # found the one that match input data and check | |||||
| if isinstance(input_type, list): | |||||
| matched_type = None | |||||
| for t in input_type: | |||||
| if type(t) == type(input): | |||||
| matched_type = t | |||||
| break | |||||
| if matched_type is None: | |||||
| err_msg = 'input data format for current pipeline should be one of following: \n' | |||||
| for t in input_type: | |||||
| err_msg += f'{t}\n' | |||||
| raise ValueError(err_msg) | |||||
| else: | |||||
| input_type = matched_type | |||||
| if isinstance(input_type, str): | |||||
| check_input_type(input_type, input) | |||||
| elif isinstance(input_type, tuple): | |||||
| for t, input_ele in zip(input_type, input): | |||||
| check_input_type(t, input_ele) | |||||
| elif isinstance(input_type, dict): | |||||
| for k in input_type.keys(): | |||||
| # allow single input for multi-modal models | |||||
| if k in input: | |||||
| check_input_type(input_type[k], input[k]) | |||||
| else: | |||||
| raise ValueError(f'invalid input_type definition {input_type}') | |||||
| else: | |||||
| logger.warning(f'task {task_name} input definition is missing') | |||||
| def _check_output(self, input): | def _check_output(self, input): | ||||
| # this attribute is dynamically attached by registry | # this attribute is dynamically attached by registry | ||||
| # when cls is registered in registry using task name | # when cls is registered in registry using task name | ||||
| @@ -132,12 +132,7 @@ class Body3DKeypointsPipeline(Pipeline): | |||||
| device='gpu' if torch.cuda.is_available() else 'cpu') | device='gpu' if torch.cuda.is_available() else 'cpu') | ||||
| def preprocess(self, input: Input) -> Dict[str, Any]: | def preprocess(self, input: Input) -> Dict[str, Any]: | ||||
| video_url = input.get('input_video') | |||||
| self.output_video_path = input.get('output_video_path') | |||||
| if self.output_video_path is None: | |||||
| self.output_video_path = tempfile.NamedTemporaryFile( | |||||
| suffix='.mp4').name | |||||
| video_url = input | |||||
| video_frames = self.read_video_frames(video_url) | video_frames = self.read_video_frames(video_url) | ||||
| if 0 == len(video_frames): | if 0 == len(video_frames): | ||||
| res = {'success': False, 'msg': 'get video frame failed.'} | res = {'success': False, 'msg': 'get video frame failed.'} | ||||
| @@ -194,9 +189,13 @@ class Body3DKeypointsPipeline(Pipeline): | |||||
| pred_3d_pose = poses.data.cpu().numpy()[ | pred_3d_pose = poses.data.cpu().numpy()[ | ||||
| 0] # [frame_num, joint_num, joint_dim] | 0] # [frame_num, joint_num, joint_dim] | ||||
| output_video_path = kwargs.get('output_video', None) | |||||
| if output_video_path is None: | |||||
| output_video_path = tempfile.NamedTemporaryFile( | |||||
| suffix='.mp4').name | |||||
| if 'render' in self.keypoint_model_3d.cfg.keys(): | if 'render' in self.keypoint_model_3d.cfg.keys(): | ||||
| self.render_prediction(pred_3d_pose) | |||||
| res[OutputKeys.OUTPUT_VIDEO] = self.output_video_path | |||||
| self.render_prediction(pred_3d_pose, output_video_path) | |||||
| res[OutputKeys.OUTPUT_VIDEO] = output_video_path | |||||
| res[OutputKeys.POSES] = pred_3d_pose | res[OutputKeys.POSES] = pred_3d_pose | ||||
| res[OutputKeys.TIMESTAMPS] = self.timestamps | res[OutputKeys.TIMESTAMPS] = self.timestamps | ||||
| @@ -252,12 +251,12 @@ class Body3DKeypointsPipeline(Pipeline): | |||||
| cap.release() | cap.release() | ||||
| return frames | return frames | ||||
| def render_prediction(self, pose3d_cam_rr): | |||||
| def render_prediction(self, pose3d_cam_rr, output_video_path): | |||||
| """render predict result 3d poses. | """render predict result 3d poses. | ||||
| Args: | Args: | ||||
| pose3d_cam_rr (nd.array): [frame_num, joint_num, joint_dim], 3d pose joints | pose3d_cam_rr (nd.array): [frame_num, joint_num, joint_dim], 3d pose joints | ||||
| output_video_path (str): output path for video | |||||
| Returns: | Returns: | ||||
| """ | """ | ||||
| frame_num = pose3d_cam_rr.shape[0] | frame_num = pose3d_cam_rr.shape[0] | ||||
| @@ -359,4 +358,4 @@ class Body3DKeypointsPipeline(Pipeline): | |||||
| # save mp4 | # save mp4 | ||||
| Writer = writers['ffmpeg'] | Writer = writers['ffmpeg'] | ||||
| writer = Writer(fps=self.fps, metadata={}, bitrate=4096) | writer = Writer(fps=self.fps, metadata={}, bitrate=4096) | ||||
| ani.save(self.output_video_path, writer=writer) | |||||
| ani.save(output_video_path, writer=writer) | |||||
| @@ -20,7 +20,7 @@ class Body3DKeypointsTest(unittest.TestCase, DemoCompatibilityCheck): | |||||
| self.task = Tasks.body_3d_keypoints | self.task = Tasks.body_3d_keypoints | ||||
| def pipeline_inference(self, pipeline: Pipeline, pipeline_input): | def pipeline_inference(self, pipeline: Pipeline, pipeline_input): | ||||
| output = pipeline(pipeline_input) | |||||
| output = pipeline(pipeline_input, output_video='./result.mp4') | |||||
| poses = np.array(output[OutputKeys.POSES]) | poses = np.array(output[OutputKeys.POSES]) | ||||
| print(f'result 3d points shape {poses.shape}') | print(f'result 3d points shape {poses.shape}') | ||||
| @@ -28,10 +28,7 @@ class Body3DKeypointsTest(unittest.TestCase, DemoCompatibilityCheck): | |||||
| def test_run_modelhub_with_video_file(self): | def test_run_modelhub_with_video_file(self): | ||||
| body_3d_keypoints = pipeline( | body_3d_keypoints = pipeline( | ||||
| Tasks.body_3d_keypoints, model=self.model_id) | Tasks.body_3d_keypoints, model=self.model_id) | ||||
| pipeline_input = { | |||||
| 'input_video': self.test_video, | |||||
| 'output_video_path': './result.mp4' | |||||
| } | |||||
| pipeline_input = self.test_video | |||||
| self.pipeline_inference( | self.pipeline_inference( | ||||
| body_3d_keypoints, pipeline_input=pipeline_input) | body_3d_keypoints, pipeline_input=pipeline_input) | ||||
| @@ -42,10 +39,7 @@ class Body3DKeypointsTest(unittest.TestCase, DemoCompatibilityCheck): | |||||
| if not cap.isOpened(): | if not cap.isOpened(): | ||||
| raise Exception('modelscope error: %s cannot be decoded by OpenCV.' | raise Exception('modelscope error: %s cannot be decoded by OpenCV.' | ||||
| % (self.test_video)) | % (self.test_video)) | ||||
| pipeline_input = { | |||||
| 'input_video': cap, | |||||
| 'output_video_path': './result.mp4' | |||||
| } | |||||
| pipeline_input = self.test_video | |||||
| self.pipeline_inference( | self.pipeline_inference( | ||||
| body_3d_keypoints, pipeline_input=pipeline_input) | body_3d_keypoints, pipeline_input=pipeline_input) | ||||
| @@ -26,7 +26,7 @@ class MplugTasksTest(unittest.TestCase, DemoCompatibilityCheck): | |||||
| model=model, | model=model, | ||||
| ) | ) | ||||
| image = Image.open('data/test/images/image_mplug_vqa.jpg') | image = Image.open('data/test/images/image_mplug_vqa.jpg') | ||||
| result = pipeline_caption({'image': image}) | |||||
| result = pipeline_caption(image) | |||||
| print(result[OutputKeys.CAPTION]) | print(result[OutputKeys.CAPTION]) | ||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | ||||
| @@ -35,7 +35,7 @@ class MplugTasksTest(unittest.TestCase, DemoCompatibilityCheck): | |||||
| Tasks.image_captioning, | Tasks.image_captioning, | ||||
| model='damo/mplug_image-captioning_coco_base_en') | model='damo/mplug_image-captioning_coco_base_en') | ||||
| image = Image.open('data/test/images/image_mplug_vqa.jpg') | image = Image.open('data/test/images/image_mplug_vqa.jpg') | ||||
| result = pipeline_caption({'image': image}) | |||||
| result = pipeline_caption(image) | |||||
| print(result[OutputKeys.CAPTION]) | print(result[OutputKeys.CAPTION]) | ||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | ||||
| @@ -34,7 +34,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck): | |||||
| model=model, | model=model, | ||||
| ) | ) | ||||
| image = 'data/test/images/image_captioning.png' | image = 'data/test/images/image_captioning.png' | ||||
| result = img_captioning({'image': image}) | |||||
| result = img_captioning(image) | |||||
| print(result[OutputKeys.CAPTION]) | print(result[OutputKeys.CAPTION]) | ||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | ||||
| @@ -42,8 +42,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck): | |||||
| img_captioning = pipeline( | img_captioning = pipeline( | ||||
| Tasks.image_captioning, | Tasks.image_captioning, | ||||
| model='damo/ofa_image-caption_coco_large_en') | model='damo/ofa_image-caption_coco_large_en') | ||||
| result = img_captioning( | |||||
| {'image': 'data/test/images/image_captioning.png'}) | |||||
| result = img_captioning('data/test/images/image_captioning.png') | |||||
| print(result[OutputKeys.CAPTION]) | print(result[OutputKeys.CAPTION]) | ||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | ||||
| @@ -52,8 +51,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck): | |||||
| 'damo/ofa_image-classification_imagenet_large_en') | 'damo/ofa_image-classification_imagenet_large_en') | ||||
| ofa_pipe = pipeline(Tasks.image_classification, model=model) | ofa_pipe = pipeline(Tasks.image_classification, model=model) | ||||
| image = 'data/test/images/image_classification.png' | image = 'data/test/images/image_classification.png' | ||||
| input = {'image': image} | |||||
| result = ofa_pipe(input) | |||||
| result = ofa_pipe(image) | |||||
| print(result) | print(result) | ||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | ||||
| @@ -62,8 +60,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck): | |||||
| Tasks.image_classification, | Tasks.image_classification, | ||||
| model='damo/ofa_image-classification_imagenet_large_en') | model='damo/ofa_image-classification_imagenet_large_en') | ||||
| image = 'data/test/images/image_classification.png' | image = 'data/test/images/image_classification.png' | ||||
| input = {'image': image} | |||||
| result = ofa_pipe(input) | |||||
| result = ofa_pipe(image) | |||||
| print(result) | print(result) | ||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | ||||
| @@ -99,8 +96,8 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck): | |||||
| ofa_pipe = pipeline(Tasks.text_classification, model=model) | ofa_pipe = pipeline(Tasks.text_classification, model=model) | ||||
| text = 'One of our number will carry out your instructions minutely.' | text = 'One of our number will carry out your instructions minutely.' | ||||
| text2 = 'A member of my team will execute your orders with immense precision.' | text2 = 'A member of my team will execute your orders with immense precision.' | ||||
| input = {'text': text, 'text2': text2} | |||||
| result = ofa_pipe(input) | |||||
| result = ofa_pipe((text, text2)) | |||||
| result = ofa_pipe({'text': text, 'text2': text2}) | |||||
| print(result) | print(result) | ||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | ||||
| @@ -110,8 +107,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck): | |||||
| model='damo/ofa_text-classification_mnli_large_en') | model='damo/ofa_text-classification_mnli_large_en') | ||||
| text = 'One of our number will carry out your instructions minutely.' | text = 'One of our number will carry out your instructions minutely.' | ||||
| text2 = 'A member of my team will execute your orders with immense precision.' | text2 = 'A member of my team will execute your orders with immense precision.' | ||||
| input = {'text': text, 'text2': text2} | |||||
| result = ofa_pipe(input) | |||||
| result = ofa_pipe((text, text2)) | |||||
| print(result) | print(result) | ||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | ||||