Add pipelines for the following new models:

- [Multilingual Quality Estimation](https://modelscope.cn/models/damo/nlp_translation_quality_estimation_multilingual/summary)
- [Automatic Post-Editing (En-De)](https://modelscope.cn/models/damo/nlp_automatic_post_editing_for_translation_en2de/summary)
- [Domain classification (Zh)](https://modelscope.cn/models/damo/nlp_domain_classification_chinese/summary)
- [Style classification (Zh)](https://modelscope.cn/models/damo/nlp_style_classification_chinese/summary)
- [Style classification (En)](https://modelscope.cn/models/damo/nlp_style_classification_english/summary)

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10315370 (target branch: master)
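For reviewers, a quick usage sketch of the three new pipeline types (model IDs as listed above; the example inputs are illustrative, and the input formats follow the new tests at the end of this diff):

```python
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Translation quality estimation: registered under sentence_similarity,
# takes a dict with the source text and its translation.
qe = pipeline(
    Tasks.sentence_similarity,
    model='damo/nlp_translation_quality_estimation_multilingual')
print(qe(input={'source_text': 'Love is a losing game',
                'target_text': '宝贝,人和人一场游戏'}))

# Automatic post-editing: registered under translation, takes the source
# sentence and the MT hypothesis joined by '\005' in a single string
# (this En-De pair is an illustrative example).
ape = pipeline(
    Tasks.translation,
    model='damo/nlp_automatic_post_editing_for_translation_en2de')
print(ape(input='How are you?\005Wie geht es Ihnen?'))

# Domain / style classification: fastText-based text classification.
domain = pipeline(
    Tasks.text_classification,
    model='damo/nlp_domain_classification_chinese')
print(domain(input='通过这种方式产生的离子吸收大地水分之后,可以通过潮解作用,将活性电解离子有效释放到周围土壤中。'))
```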
@@ -189,6 +189,9 @@ class Pipelines(object):
    product_segmentation = 'product-segmentation'
    # nlp tasks
    automatic_post_editing = 'automatic-post-editing'
    translation_quality_estimation = 'translation-quality-estimation'
    domain_classification = 'domain-classification'
    sentence_similarity = 'sentence-similarity'
    word_segmentation = 'word-segmentation'
    part_of_speech = 'part-of-speech'
@@ -4,12 +4,13 @@ from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .automatic_post_editing_pipeline import AutomaticPostEditingPipeline
    from .conversational_text_to_sql_pipeline import ConversationalTextToSqlPipeline
    from .table_question_answering_pipeline import TableQuestionAnsweringPipeline
    from .dialog_intent_prediction_pipeline import DialogIntentPredictionPipeline
    from .dialog_modeling_pipeline import DialogModelingPipeline
    from .dialog_state_tracking_pipeline import DialogStateTrackingPipeline
    from .document_segmentation_pipeline import DocumentSegmentationPipeline
    from .fasttext_sequence_classification_pipeline import FasttextSequenceClassificationPipeline
    from .faq_question_answering_pipeline import FaqQuestionAnsweringPipeline
    from .feature_extraction_pipeline import FeatureExtractionPipeline
    from .fill_mask_pipeline import FillMaskPipeline
@@ -20,6 +21,8 @@ if TYPE_CHECKING:
    from .sentence_embedding_pipeline import SentenceEmbeddingPipeline
    from .sequence_classification_pipeline import SequenceClassificationPipeline
    from .summarization_pipeline import SummarizationPipeline
    from .table_question_answering_pipeline import TableQuestionAnsweringPipeline
    from .translation_quality_estimation_pipeline import TranslationQualityEstimationPipeline
    from .text_classification_pipeline import TextClassificationPipeline
    from .text_error_correction_pipeline import TextErrorCorrectionPipeline
    from .text_generation_pipeline import TextGenerationPipeline
@@ -31,14 +34,15 @@ if TYPE_CHECKING:
else:
    _import_structure = {
        'automatic_post_editing_pipeline': ['AutomaticPostEditingPipeline'],
        'conversational_text_to_sql_pipeline':
        ['ConversationalTextToSqlPipeline'],
        'table_question_answering_pipeline':
        ['TableQuestionAnsweringPipeline'],
        'dialog_intent_prediction_pipeline':
        ['DialogIntentPredictionPipeline'],
        'dialog_modeling_pipeline': ['DialogModelingPipeline'],
        'dialog_state_tracking_pipeline': ['DialogStateTrackingPipeline'],
        'domain_classification_pipeline':
        ['FasttextSequenceClassificationPipeline'],
        'document_segmentation_pipeline': ['DocumentSegmentationPipeline'],
        'faq_question_answering_pipeline': ['FaqQuestionAnsweringPipeline'],
        'feature_extraction_pipeline': ['FeatureExtractionPipeline'],
@@ -51,12 +55,16 @@ else:
        'sentence_embedding_pipeline': ['SentenceEmbeddingPipeline'],
        'sequence_classification_pipeline': ['SequenceClassificationPipeline'],
        'summarization_pipeline': ['SummarizationPipeline'],
        'table_question_answering_pipeline':
        ['TableQuestionAnsweringPipeline'],
        'text_classification_pipeline': ['TextClassificationPipeline'],
        'text_error_correction_pipeline': ['TextErrorCorrectionPipeline'],
        'text_generation_pipeline': ['TextGenerationPipeline'],
        'text2text_generation_pipeline': ['Text2TextGenerationPipeline'],
        'token_classification_pipeline': ['TokenClassificationPipeline'],
        'translation_pipeline': ['TranslationPipeline'],
        'translation_quality_estimation_pipeline':
        ['TranslationQualityEstimationPipeline'],
        'word_segmentation_pipeline': ['WordSegmentationPipeline'],
        'zero_shot_classification_pipeline':
        ['ZeroShotClassificationPipeline'],
@@ -0,0 +1,158 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from html import unescape
from typing import Any, Dict

import jieba
import numpy as np
import tensorflow as tf
from sacremoses import (MosesDetokenizer, MosesDetruecaser,
                        MosesPunctNormalizer, MosesTokenizer, MosesTruecaser)
from sentencepiece import SentencePieceProcessor
from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
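# beam_search_ops is presumably imported only for its side effect of
# registering the beam-search ops that the SavedModel graph needs; the
# name itself is not referenced below.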
from modelscope.metainfo import Pipelines
from modelscope.models.base import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.config import Config, ConfigFields
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

if tf.__version__ >= '2.0':
    tf = tf.compat.v1
    tf.disable_eager_execution()

logger = get_logger()

__all__ = ['AutomaticPostEditingPipeline']


@PIPELINES.register_module(
    Tasks.translation, module_name=Pipelines.automatic_post_editing)
class AutomaticPostEditingPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
| """Build an automatic post editing pipeline with a model dir. | |||
| @param model: Model path for saved pb file | |||
| """ | |||
        super().__init__(model=model, **kwargs)
        export_dir = model
        self.cfg = Config.from_file(
            os.path.join(export_dir, ModelFile.CONFIGURATION))
        joint_vocab_file = os.path.join(
            export_dir, self.cfg[ConfigFields.preprocessor]['vocab'])
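        # Token-to-id and id-to-token maps built from the shared (joint
        # source/target) vocabulary file.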
        self.vocab = dict([(w.strip(), i) for i, w in enumerate(
            open(joint_vocab_file, 'r', encoding='utf8'))])
        self.vocab_reverse = dict([(i, w.strip()) for i, w in enumerate(
            open(joint_vocab_file, 'r', encoding='utf8'))])
        self.unk_id = self.cfg[ConfigFields.preprocessor].get('unk_id', -1)
        strip_unk = self.cfg.get(ConfigFields.postprocessor,
                                 {}).get('strip_unk', True)
        self.unk_token = '' if strip_unk else self.cfg.get(
            ConfigFields.postprocessor, {}).get('unk_token', '<unk>')
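        # If no unk id is configured, fall back to the last vocabulary
        # entry, which is assumed to be the <unk> token.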
        if self.unk_id == -1:
            self.unk_id = len(self.vocab) - 1

        tf.reset_default_graph()
        tf_config = tf.ConfigProto(allow_soft_placement=True)
        tf_config.gpu_options.allow_growth = True
        self._session = tf.Session(config=tf_config)
        tf.saved_model.loader.load(
            self._session, [tf.python.saved_model.tag_constants.SERVING],
            export_dir)
        default_graph = tf.get_default_graph()
        self.input_src_id_placeholder = default_graph.get_tensor_by_name(
            'Placeholder:0')
        self.input_src_len_placeholder = default_graph.get_tensor_by_name(
            'Placeholder_1:0')
        self.input_mt_id_placeholder = default_graph.get_tensor_by_name(
            'Placeholder_2:0')
        self.input_mt_len_placeholder = default_graph.get_tensor_by_name(
            'Placeholder_3:0')
        output_id_beam = default_graph.get_tensor_by_name(
            'enc2enc/decoder/transpose:0')
        output_len_beam = default_graph.get_tensor_by_name(
            'enc2enc/decoder/Minimum:0')
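        # Keep only the top (first) beam hypothesis for both ids and lengths.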
        output_id = tf.cast(
            tf.map_fn(lambda x: x[0], output_id_beam), dtype=tf.int64)
        output_len = tf.map_fn(lambda x: x[0], output_len_beam)
        self.output = {'output_ids': output_id, 'output_lens': output_len}

        init = tf.global_variables_initializer()
        local_init = tf.local_variables_initializer()
        self._session.run([init, local_init])
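        # Load the SavedModel a second time so the trained weights overwrite
        # the freshly initialized variables (assumed intent of the
        # init-then-reload sequence above).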
        tf.saved_model.loader.load(
            self._session, [tf.python.saved_model.tag_constants.SERVING],
            export_dir)

        # preprocess
        self._src_lang = self.cfg[ConfigFields.preprocessor]['src_lang']
        self._tgt_lang = self.cfg[ConfigFields.preprocessor]['tgt_lang']
        tok_escape = self.cfg[ConfigFields.preprocessor].get(
            'tokenize_escape', False)
        src_tokenizer = MosesTokenizer(lang=self._src_lang)
        mt_tokenizer = MosesTokenizer(lang=self._tgt_lang)
        truecase_model = os.path.join(
            export_dir, self.cfg[ConfigFields.preprocessor]['truecaser'])
        truecaser = MosesTruecaser(load_from=truecase_model)
        sp_model = os.path.join(
            export_dir, self.cfg[ConfigFields.preprocessor]['sentencepiece'])
        sp = SentencePieceProcessor()
        sp.load(sp_model)
        self.src_preprocess = lambda x: ' '.join(
            sp.encode_as_pieces(
                truecaser.truecase(
                    src_tokenizer.tokenize(
                        x, return_str=True, escape=tok_escape),
                    return_str=True)))
        self.mt_preprocess = lambda x: ' '.join(
            sp.encode_as_pieces(
                truecaser.truecase(
                    mt_tokenizer.tokenize(
                        x, return_str=True, escape=tok_escape),
                    return_str=True)))

        # post process, de-bpe, de-truecase, detok
        detruecaser = MosesDetruecaser()
        detokenizer = MosesDetokenizer(lang=self._tgt_lang)
        self.postprocess_fun = lambda x: detokenizer.detokenize(
            detruecaser.detruecase(
                x.replace(' ▁', '@@').replace(' ', '').replace('@@', ' ').
                strip()[1:],
                return_str=True).split())
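        # De-BPE walkthrough (SentencePiece marks word starts with '▁'), e.g.
        #   '▁Das ▁ist ▁gu t' -> '▁Das@@ist@@gu t'  (' ▁' -> '@@')
        #                     -> '▁Das@@ist@@gut'   (drop remaining spaces)
        #                     -> '▁Das ist gut'     ('@@' -> ' ')
        #                     -> 'Das ist gut'      (strip()[1:] drops the leading '▁')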
    def preprocess(self, input: str) -> Dict[str, Any]:
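        # The pipeline input packs the source sentence and the MT hypothesis
        # into one string, separated by '\005' (see the new test below).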
        src, mt = input.split('\005', 1)
        src_sp, mt_sp = self.src_preprocess(src), self.mt_preprocess(mt)
        input_src_ids = np.array(
            [[self.vocab.get(w, self.unk_id) for w in src_sp.strip().split()]])
        input_mt_ids = np.array(
            [[self.vocab.get(w, self.unk_id) for w in mt_sp.strip().split()]])
        input_src_lens = [len(x) for x in input_src_ids]
        input_mt_lens = [len(x) for x in input_mt_ids]
        feed_dict = {
            self.input_src_id_placeholder: input_src_ids,
            self.input_mt_id_placeholder: input_mt_ids,
            self.input_src_len_placeholder: input_src_lens,
            self.input_mt_len_placeholder: input_mt_lens
        }
        return feed_dict

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        with self._session.as_default():
            sess_outputs = self._session.run(self.output, feed_dict=input)
        return sess_outputs

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        output_ids, output_len = inputs['output_ids'][0], inputs[
            'output_lens'][0]
        output_ids = output_ids[:output_len - 1]  # -1 for </s>
        output_tokens = ' '.join([
            self.vocab_reverse.get(wid, self.unk_token) for wid in output_ids
        ])
        post_editing_output = self.postprocess_fun(output_tokens)
        result = {OutputKeys.TRANSLATION: post_editing_output}
        return result
@@ -0,0 +1,69 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import Any, Dict, Union

import numpy as np
import sentencepiece
from fasttext import load_model
from fasttext.FastText import _FastText

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import SequenceClassificationPreprocessor
from modelscope.utils.constant import ModelFile, Tasks

__all__ = ['FasttextSequenceClassificationPipeline']


def sentencepiece_tokenize(sp_model, sent):
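    """Segment ``sent`` with SentencePiece and space-join the pieces,
    since fastText expects whitespace-tokenized input text."""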
    tokens = []
    for t in sp_model.EncodeAsPieces(sent):
        s = t.strip()
        if s:
            tokens.append(s)
    return ' '.join(tokens)
@PIPELINES.register_module(
    Tasks.text_classification, module_name=Pipelines.domain_classification)
class FasttextSequenceClassificationPipeline(Pipeline):

    def __init__(self, model: Union[str, _FastText], **kwargs):
| """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | |||
| Args: | |||
| model: a model directory including model.bin and spm.model | |||
| preprocessor (SequenceClassificationPreprocessor): a preprocessor instance | |||
| """ | |||
        super().__init__(model=model)
        model_file = os.path.join(model, ModelFile.TORCH_MODEL_BIN_FILE)
        spm_file = os.path.join(model, 'sentencepiece.model')
        assert os.path.isdir(model) and os.path.exists(model_file) and os.path.exists(spm_file), \
            '`model` should be a directory containing the fastText binary and `sentencepiece.model`'
        self.model = load_model(model_file)
        self.spm = sentencepiece.SentencePieceProcessor()
        self.spm.Load(spm_file)
    def preprocess(self, inputs: str) -> Dict[str, Any]:
        text = inputs.strip()
        text_sp = sentencepiece_tokenize(self.spm, text)
        return {'text_sp': text_sp, 'text': text}

    def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        topk = inputs.get('topk', -1)
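        # fastText returns labels prefixed with '__label__'; k=-1 requests
        # scores for all labels rather than only the top one.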
        label, probs = self.model.predict(inputs['text_sp'], k=topk)
        label = [x.replace('__label__', '') for x in label]
        result = {
            OutputKeys.LABEL: label[0],
            OutputKeys.SCORE: probs[0],
            OutputKeys.LABELS: label,
            OutputKeys.SCORES: probs
        }
        return result

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
@@ -0,0 +1,72 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import io
import os
from typing import Any, Dict, Union

import numpy as np
import torch
from transformers import XLMRobertaTokenizer

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.models.nlp import BertForSequenceClassification
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import SequenceClassificationPreprocessor
from modelscope.utils.constant import ModelFile, Tasks

__all__ = ['TranslationQualityEstimationPipeline']


@PIPELINES.register_module(
    Tasks.sentence_similarity,
    module_name=Pipelines.translation_quality_estimation)
class TranslationQualityEstimationPipeline(Pipeline):

    def __init__(self, model: str, device: str = 'gpu', **kwargs):
        super().__init__(model=model, device=device)
        model_file = os.path.join(model, ModelFile.TORCH_MODEL_FILE)
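        # Read the TorchScript archive into an in-memory buffer before
        # torch.jit.load, presumably so the file handle is released and the
        # model can be remapped onto the configured device.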
        with open(model_file, 'rb') as f:
            buffer = io.BytesIO(f.read())
        self.tokenizer = XLMRobertaTokenizer.from_pretrained(model)
        self.model = torch.jit.load(
            buffer, map_location=self.device).to(self.device)

    def preprocess(self, inputs: Dict[str, Any]):
        src_text = inputs['source_text'].strip()
        tgt_text = inputs['target_text'].strip()
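        # Encode the (source, target) pair jointly, cross-encoder style,
        # with the XLM-R tokenizer.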
        encoded_inputs = self.tokenizer.batch_encode_plus(
            [[src_text, tgt_text]],
            return_tensors='pt',
            padding=True,
            truncation=True)
        input_ids = encoded_inputs['input_ids'].to(self.device)
        attention_mask = encoded_inputs['attention_mask'].to(self.device)
        inputs.update({
            'input_ids': input_ids,
            'attention_mask': attention_mask
        })
        return inputs

    def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        if 'input_ids' not in inputs:
            inputs = self.preprocess(inputs)
        res = self.model(inputs['input_ids'], inputs['attention_mask'])
        result = {
            OutputKeys.LABELS: '-1',
            OutputKeys.SCORES: res[0].detach().squeeze().tolist()
        }
        return result

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Process the prediction results.

        Args:
            inputs (Dict[str, Any]): input data dict

        Returns:
            Dict[str, str]: the prediction results
        """
        return inputs
@@ -267,6 +267,7 @@ class ConfigFields(object):
    preprocessor = 'preprocessor'
    train = 'train'
    evaluation = 'evaluation'
    postprocessor = 'postprocessor'


class ConfigKeys(object):
@@ -1,7 +1,10 @@
en_core_web_sm>=2.3.5
fasttext
jieba>=0.42.1
megatron_util
pai-easynlp
# protobuf releases beyond the 3.20 series are not compatible with
# TensorFlow 1.x and are therefore discouraged
protobuf>=3.19.0,<3.21.0
# rouge-score was just recently updated from 0.0.4 to 0.0.7,
# which introduced compatibility issues that are being investigated
rouge_score<=0.0.4
@@ -0,0 +1,30 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level


class AutomaticPostEditingTest(unittest.TestCase, DemoCompatibilityCheck):

    def setUp(self) -> None:
        self.task = Tasks.translation
        self.model_id = 'damo/nlp_automatic_post_editing_for_translation_en2de'

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_name_for_en2de(self):
        inputs = 'Simultaneously, the Legion took part to the pacification of Algeria, plagued by various tribal ' \
                 'rebellions and razzias.\005Gleichzeitig nahm die Legion an der Befriedung Algeriens teil, die von ' \
                 'verschiedenen Stammesaufständen und Rasias heimgesucht wurde.'
        pipeline_ins = pipeline(self.task, model=self.model_id)
        print(pipeline_ins(input=inputs))

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_demo_compatibility(self):
        self.compatibility_check()


if __name__ == '__main__':
    unittest.main()
@@ -0,0 +1,45 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level


class DomainClassificationTest(unittest.TestCase, DemoCompatibilityCheck):

    def setUp(self) -> None:
        self.task = Tasks.text_classification
        self.model_id = 'damo/nlp_domain_classification_chinese'

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_name_for_zh_domain(self):
        inputs = '通过这种方式产生的离子吸收大地水分之后,可以通过潮解作用,将活性电解离子有效释放到周围土壤中,使接地极成为一个离子发生装置,' \
                 '从而改善周边土质使之达到接地要求。'
        pipeline_ins = pipeline(self.task, model=self.model_id)
        print(pipeline_ins(input=inputs))

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_name_for_zh_style(self):
        model_id = 'damo/nlp_style_classification_chinese'
        inputs = '通过这种方式产生的离子吸收大地水分之后,可以通过潮解作用,将活性电解离子有效释放到周围土壤中,使接地极成为一个离子发生装置,' \
                 '从而改善周边土质使之达到接地要求。'
        pipeline_ins = pipeline(self.task, model=model_id)
        print(pipeline_ins(input=inputs))

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_name_for_en_style(self):
        model_id = 'damo/nlp_style_classification_english'
        inputs = 'High Power 11.1V 5200mAh Lipo Battery For RC Car Robot Airplanes ' \
                 'Helicopter RC Drone Parts 3s Lithium battery 11.1v Battery'
        pipeline_ins = pipeline(self.task, model=model_id)
        print(pipeline_ins(input=inputs))

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_demo_compatibility(self):
        self.compatibility_check()


if __name__ == '__main__':
    unittest.main()
@@ -0,0 +1,32 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level


class TranslationQualityEstimationTest(unittest.TestCase,
                                       DemoCompatibilityCheck):

    def setUp(self) -> None:
        self.task = Tasks.sentence_similarity
        self.model_id = 'damo/nlp_translation_quality_estimation_multilingual'

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_name_for_en2zh(self):
        inputs = {
            'source_text': 'Love is a losing game',
            'target_text': '宝贝,人和人一场游戏'
        }
        pipeline_ins = pipeline(self.task, model=self.model_id)
        print(pipeline_ins(input=inputs))

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_demo_compatibility(self):
        self.compatibility_check()


if __name__ == '__main__':
    unittest.main()