| @@ -16,6 +16,7 @@ class Models(object): | |||
| palm = 'palm-v2' | |||
| structbert = 'structbert' | |||
| veco = 'veco' | |||
| translation = 'csanmt-translation' | |||
| space = 'space' | |||
| # audio models | |||
| @@ -56,6 +57,7 @@ class Pipelines(object): | |||
| sentiment_analysis = 'sentiment-analysis' | |||
| sentiment_classification = 'sentiment-classification' | |||
| fill_mask = 'fill-mask' | |||
| csanmt_translation = 'csanmt-translation' | |||
| nli = 'nli' | |||
| dialog_intent_prediction = 'dialog-intent-prediction' | |||
| dialog_modeling = 'dialog-modeling' | |||
| @@ -16,7 +16,8 @@ try: | |||
| from .audio.kws import GenericKeyWordSpotting | |||
| from .multi_modal import OfaForImageCaptioning | |||
| from .nlp import (BertForMaskedLM, BertForSequenceClassification, | |||
| SbertForNLI, SbertForSentenceSimilarity, | |||
| CsanmtForTranslation, SbertForNLI, | |||
| SbertForSentenceSimilarity, | |||
| SbertForSentimentClassification, | |||
| SbertForTokenClassification, | |||
| SbertForZeroShotClassification, SpaceForDialogIntent, | |||
| @@ -1,4 +1,5 @@ | |||
| from .bert_for_sequence_classification import * # noqa F403 | |||
| from .csanmt_for_translation import * # noqa F403 | |||
| from .masked_language_model import * # noqa F403 | |||
| from .palm_for_text_generation import * # noqa F403 | |||
| from .sbert_for_nli import * # noqa F403 | |||
| @@ -21,12 +21,12 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
| Tasks.sentence_similarity: | |||
| (Pipelines.sentence_similarity, | |||
| 'damo/nlp_structbert_sentence-similarity_chinese-base'), | |||
| Tasks.translation: (Pipelines.csanmt_translation, | |||
| 'damo/nlp_csanmt_translation'), | |||
| Tasks.nli: (Pipelines.nli, 'damo/nlp_structbert_nli_chinese-base'), | |||
| Tasks.sentiment_classification: | |||
| (Pipelines.sentiment_classification, | |||
| 'damo/nlp_structbert_sentiment-classification_chinese-base'), | |||
| Tasks.text_classification: ('bert-sentiment-analysis', | |||
| 'damo/bert-base-sst2'), | |||
| Tasks.image_matting: (Pipelines.image_matting, | |||
| 'damo/cv_unet_image-matting'), | |||
| Tasks.text_classification: (Pipelines.sentiment_analysis, | |||
| @@ -8,6 +8,7 @@ try: | |||
| from .sentiment_classification_pipeline import * # noqa F403 | |||
| from .sequence_classification_pipeline import * # noqa F403 | |||
| from .text_generation_pipeline import * # noqa F403 | |||
| from .translation_pipeline import * # noqa F403 | |||
| from .word_segmentation_pipeline import * # noqa F403 | |||
| from .zero_shot_classification_pipeline import * # noqa F403 | |||
| except ModuleNotFoundError as e: | |||
| @@ -0,0 +1,119 @@ | |||
| import os.path as osp | |||
| from typing import Any, Dict, Optional, Union | |||
| import numpy as np | |||
| import tensorflow as tf | |||
| from ...hub.snapshot_download import snapshot_download | |||
| from ...metainfo import Pipelines | |||
| from ...models import Model | |||
| from ...models.nlp import CsanmtForTranslation | |||
| from ...utils.constant import ModelFile, Tasks | |||
| from ...utils.logger import get_logger | |||
| from ..base import Pipeline, Tensor | |||
| from ..builder import PIPELINES | |||
| from ..outputs import OutputKeys | |||
| if tf.__version__ >= '2.0': | |||
| tf = tf.compat.v1 | |||
| tf.disable_eager_execution() | |||
| logger = get_logger() | |||
| __all__ = ['TranslationPipeline'] | |||
| # constant | |||
| PARAMS = { | |||
| 'hidden_size': 512, | |||
| 'filter_size': 2048, | |||
| 'num_heads': 8, | |||
| 'num_encoder_layers': 6, | |||
| 'num_decoder_layers': 6, | |||
| 'attention_dropout': 0.0, | |||
| 'residual_dropout': 0.0, | |||
| 'relu_dropout': 0.0, | |||
| 'layer_preproc': 'none', | |||
| 'layer_postproc': 'layer_norm', | |||
| 'shared_embedding_and_softmax_weights': True, | |||
| 'shared_source_target_embedding': True, | |||
| 'initializer_scale': 0.1, | |||
| 'train_max_len': 100, | |||
| 'confidence': 0.9, | |||
| 'position_info_type': 'absolute', | |||
| 'max_relative_dis': 16, | |||
| 'beam_size': 4, | |||
| 'lp_rate': 0.6, | |||
| 'num_semantic_encoder_layers': 4, | |||
| 'max_decoded_trg_len': 100, | |||
| 'src_vocab_size': 37006, | |||
| 'trg_vocab_size': 37006, | |||
| 'vocab_src': 'src_vocab.txt', | |||
| 'vocab_trg': 'trg_vocab.txt' | |||
| } | |||
| @PIPELINES.register_module( | |||
| Tasks.translation, module_name=Pipelines.csanmt_translation) | |||
| class TranslationPipeline(Pipeline): | |||
| def __init__(self, model: str, **kwargs): | |||
| if not osp.exists(model): | |||
| model = snapshot_download(model) | |||
| tf.reset_default_graph() | |||
| model_path = osp.join( | |||
| osp.join(model, ModelFile.TF_CHECKPOINT_FOLDER), 'ckpt-0') | |||
| self.params = PARAMS | |||
| self._src_vocab_path = osp.join(model, self.params['vocab_src']) | |||
| self._src_vocab = dict([ | |||
| (w.strip(), i) for i, w in enumerate(open(self._src_vocab_path)) | |||
| ]) | |||
| self._trg_vocab_path = osp.join(model, self.params['vocab_trg']) | |||
| self._trg_rvocab = dict([ | |||
| (i, w.strip()) for i, w in enumerate(open(self._trg_vocab_path)) | |||
| ]) | |||
| config = tf.ConfigProto(allow_soft_placement=True) | |||
| config.gpu_options.allow_growth = True | |||
| self._session = tf.Session(config=config) | |||
| self.input_wids = tf.placeholder( | |||
| dtype=tf.int64, shape=[None, None], name='input_wids') | |||
| self.output = {} | |||
| # model | |||
| csanmt_model = CsanmtForTranslation(model, params=self.params) | |||
| output = csanmt_model(self.input_wids) | |||
| self.output.update(output) | |||
| with self._session.as_default() as sess: | |||
| logger.info(f'loading model from {model_path}') | |||
| # load model | |||
| model_loader = tf.train.Saver(tf.global_variables()) | |||
| model_loader.restore(sess, model_path) | |||
| def preprocess(self, input: str) -> Dict[str, Any]: | |||
| input_ids = np.array([[ | |||
| self._src_vocab[w] | |||
| if w in self._src_vocab else self.params['src_vocab_size'] | |||
| for w in input.strip().split() | |||
| ]]) | |||
| result = {'input_ids': input_ids} | |||
| return result | |||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
| with self._session.as_default(): | |||
| feed_dict = {self.input_wids: input['input_ids']} | |||
| sess_outputs = self._session.run(self.output, feed_dict=feed_dict) | |||
| return sess_outputs | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| output_seqs = inputs['output_seqs'][0] | |||
| wids = list(output_seqs[0]) + [0] | |||
| wids = wids[:wids.index(0)] | |||
| translation_out = ' '.join([ | |||
| self._trg_rvocab[wid] if wid in self._trg_rvocab else '<unk>' | |||
| for wid in wids | |||
| ]).replace('@@ ', '').replace('@@', '') | |||
| result = {OutputKeys.TRANSLATION: translation_out} | |||
| return result | |||
| @@ -18,6 +18,7 @@ class OutputKeys(object): | |||
| OUTPUT_PCM = 'output_pcm' | |||
| IMG_EMBEDDING = 'img_embedding' | |||
| TEXT_EMBEDDING = 'text_embedding' | |||
| TRANSLATION = 'translation' | |||
| RESPONSE = 'response' | |||
| PREDICTION = 'prediction' | |||
| DIALOG_STATES = 'dialog_states' | |||
| @@ -124,6 +125,12 @@ TASK_OUTPUTS = { | |||
| # } | |||
| Tasks.sentence_similarity: [OutputKeys.SCORES, OutputKeys.LABELS], | |||
| # translation result for a source sentence | |||
| # { | |||
| # "translation": “北京是中国的首都” | |||
| # } | |||
| Tasks.translation: [OutputKeys.TRANSLATION], | |||
| # sentiment classification result for single sample | |||
| # { | |||
| # "labels": ["happy", "sad", "calm", "angry"], | |||
| @@ -0,0 +1,22 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import shutil | |||
| import unittest | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.pipelines import TranslationPipeline, pipeline | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.test_utils import test_level | |||
| class TranslationTest(unittest.TestCase): | |||
| model_id = 'damo/nlp_csanmt_translation' | |||
| inputs = 'Gut@@ ach : Incre@@ ased safety for pedestri@@ ans' | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_name(self): | |||
| pipeline_ins = pipeline(task=Tasks.translation, model=self.model_id) | |||
| print(pipeline_ins(input=self.inputs)) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -15,7 +15,7 @@ class TextToImageSynthesisTest(unittest.TestCase): | |||
| model_id = 'damo/cv_imagen_text-to-image-synthesis_tiny' | |||
| test_text = {'text': '宇航员'} | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_model_from_modelhub(self): | |||
| model = Model.from_pretrained(self.model_id) | |||
| pipe_line_text_to_image_synthesis = pipeline( | |||
| @@ -24,7 +24,7 @@ class TextToImageSynthesisTest(unittest.TestCase): | |||
| self.test_text)[OutputKeys.OUTPUT_IMG] | |||
| print(np.sum(np.abs(img))) | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_model_name(self): | |||
| pipe_line_text_to_image_synthesis = pipeline( | |||
| task=Tasks.text_to_image_synthesis, model=self.model_id) | |||