Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9198181 (target branch: master)
@@ -16,6 +16,7 @@ class Models(object):
     palm = 'palm-v2'
     structbert = 'structbert'
     veco = 'veco'
+    translation = 'csanmt-translation'
     space = 'space'
 
     # audio models
@@ -56,6 +57,7 @@ class Pipelines(object):
     sentiment_analysis = 'sentiment-analysis'
     sentiment_classification = 'sentiment-classification'
     fill_mask = 'fill-mask'
+    csanmt_translation = 'csanmt-translation'
     nli = 'nli'
     dialog_intent_prediction = 'dialog-intent-prediction'
     dialog_modeling = 'dialog-modeling'
@@ -15,13 +15,12 @@ except ModuleNotFoundError as e:
 try:
     from .audio.kws import GenericKeyWordSpotting
     from .multi_modal import OfaForImageCaptioning
-    from .nlp import (BertForMaskedLM, BertForSequenceClassification,
-                      SbertForNLI, SbertForSentenceSimilarity,
-                      SbertForSentimentClassification,
-                      SbertForTokenClassification,
-                      SbertForZeroShotClassification, SpaceForDialogIntent,
-                      SpaceForDialogModeling, StructBertForMaskedLM,
-                      VecoForMaskedLM)
+    from .nlp import (
+        BertForMaskedLM, BertForSequenceClassification, CsanmtForTranslation,
+        SbertForNLI, SbertForSentenceSimilarity,
+        SbertForSentimentClassification, SbertForTokenClassification,
+        SbertForZeroShotClassification, SpaceForDialogIntent,
+        SpaceForDialogModeling, StructBertForMaskedLM, VecoForMaskedLM)
     from .audio.ans.frcrn import FRCRNModel
 except ModuleNotFoundError as e:
     if str(e) == "No module named 'pytorch'":
@@ -1,4 +1,5 @@
 from .bert_for_sequence_classification import *  # noqa F403
+from .csanmt_for_translation import *  # noqa F403
 from .masked_language_model import *  # noqa F403
 from .palm_for_text_generation import *  # noqa F403
 from .sbert_for_nli import *  # noqa F403
@@ -21,12 +21,12 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     Tasks.sentence_similarity:
     (Pipelines.sentence_similarity,
      'damo/nlp_structbert_sentence-similarity_chinese-base'),
+    Tasks.translation: (Pipelines.csanmt_translation,
+                        'damo/nlp_csanmt_translation'),
     Tasks.nli: (Pipelines.nli, 'damo/nlp_structbert_nli_chinese-base'),
     Tasks.sentiment_classification:
     (Pipelines.sentiment_classification,
      'damo/nlp_structbert_sentiment-classification_chinese-base'),
-    Tasks.text_classification: ('bert-sentiment-analysis',
-                                'damo/bert-base-sst2'),
     Tasks.image_matting: (Pipelines.image_matting,
                           'damo/cv_unet_image-matting'),
     Tasks.text_classification: (Pipelines.sentiment_analysis,
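
With this default registered, constructing a translation pipeline by task alone resolves to the CSANMT model. A minimal usage sketch (model id taken from the mapping above; downloading it requires access to the ModelScope hub):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # No model argument: DEFAULT_MODEL_FOR_PIPELINE supplies
    # (Pipelines.csanmt_translation, 'damo/nlp_csanmt_translation').
    translator = pipeline(task=Tasks.translation)

The removed 'bert-sentiment-analysis' entry was a duplicate Tasks.text_classification key; in a dict literal the later entry wins, so only the Pipelines.sentiment_analysis mapping below ever took effect anyway.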
@@ -7,6 +7,7 @@ try:
     from .sentiment_classification_pipeline import *  # noqa F403
     from .sequence_classification_pipeline import *  # noqa F403
     from .text_generation_pipeline import *  # noqa F403
+    from .translation_pipeline import *  # noqa F403
     from .word_segmentation_pipeline import *  # noqa F403
     from .zero_shot_classification_pipeline import *  # noqa F403
 except ModuleNotFoundError as e:
@@ -0,0 +1,119 @@
+import os.path as osp
+from typing import Any, Dict, Optional, Union
+
+import numpy as np
+import tensorflow as tf
+
+from ...hub.snapshot_download import snapshot_download
+from ...metainfo import Pipelines
+from ...models import Model
+from ...models.nlp import CsanmtForTranslation
+from ...utils.constant import ModelFile, Tasks
+from ...utils.logger import get_logger
+from ..base import Pipeline, Tensor
+from ..builder import PIPELINES
+from ..outputs import OutputKeys
+
+if tf.__version__ >= '2.0':
+    tf = tf.compat.v1
+    tf.disable_eager_execution()
+
+logger = get_logger()
+
+__all__ = ['TranslationPipeline']
+
+# constant
+PARAMS = {
+    'hidden_size': 512,
+    'filter_size': 2048,
+    'num_heads': 8,
+    'num_encoder_layers': 6,
+    'num_decoder_layers': 6,
+    'attention_dropout': 0.0,
+    'residual_dropout': 0.0,
+    'relu_dropout': 0.0,
+    'layer_preproc': 'none',
+    'layer_postproc': 'layer_norm',
+    'shared_embedding_and_softmax_weights': True,
+    'shared_source_target_embedding': True,
+    'initializer_scale': 0.1,
+    'train_max_len': 100,
+    'confidence': 0.9,
+    'position_info_type': 'absolute',
+    'max_relative_dis': 16,
+    'beam_size': 4,
+    'lp_rate': 0.6,
+    'num_semantic_encoder_layers': 4,
+    'max_decoded_trg_len': 100,
+    'src_vocab_size': 37006,
+    'trg_vocab_size': 37006,
+    'vocab_src': 'src_vocab.txt',
+    'vocab_trg': 'trg_vocab.txt'
+}
+
+
+@PIPELINES.register_module(
+    Tasks.translation, module_name=Pipelines.csanmt_translation)
+class TranslationPipeline(Pipeline):
+
+    def __init__(self, model: str, **kwargs):
+        if not osp.exists(model):
+            model = snapshot_download(model)
+        tf.reset_default_graph()
+
+        model_path = osp.join(
+            osp.join(model, ModelFile.TF_CHECKPOINT_FOLDER), 'ckpt-0')
+
+        self.params = PARAMS
+        self._src_vocab_path = osp.join(model, self.params['vocab_src'])
+        self._src_vocab = dict([
+            (w.strip(), i) for i, w in enumerate(open(self._src_vocab_path))
+        ])
+        self._trg_vocab_path = osp.join(model, self.params['vocab_trg'])
+        self._trg_rvocab = dict([
+            (i, w.strip()) for i, w in enumerate(open(self._trg_vocab_path))
+        ])
+
+        config = tf.ConfigProto(allow_soft_placement=True)
+        config.gpu_options.allow_growth = True
+        self._session = tf.Session(config=config)
+
+        self.input_wids = tf.placeholder(
+            dtype=tf.int64, shape=[None, None], name='input_wids')
+        self.output = {}
+
+        # model
+        csanmt_model = CsanmtForTranslation(model, params=self.params)
+        output = csanmt_model(self.input_wids)
+        self.output.update(output)
+
+        with self._session.as_default() as sess:
+            logger.info(f'loading model from {model_path}')
+            # load model
+            model_loader = tf.train.Saver(tf.global_variables())
+            model_loader.restore(sess, model_path)
+
+    def preprocess(self, input: str) -> Dict[str, Any]:
+        input_ids = np.array([[
+            self._src_vocab[w]
+            if w in self._src_vocab else self.params['src_vocab_size']
+            for w in input.strip().split()
+        ]])
+        result = {'input_ids': input_ids}
+        return result
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        with self._session.as_default():
+            feed_dict = {self.input_wids: input['input_ids']}
+            sess_outputs = self._session.run(self.output, feed_dict=feed_dict)
+        return sess_outputs
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        output_seqs = inputs['output_seqs'][0]
+        wids = list(output_seqs[0]) + [0]
+        wids = wids[:wids.index(0)]
+        translation_out = ' '.join([
+            self._trg_rvocab[wid] if wid in self._trg_rvocab else '<unk>'
+            for wid in wids
+        ]).replace('@@ ', '').replace('@@', '')
+        result = {OutputKeys.TRANSLATION: translation_out}
+        return result
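
Taken together: preprocess maps a BPE-segmented source sentence to vocabulary ids (out-of-vocabulary tokens fall back to src_vocab_size), forward feeds them through the frozen TF1-style graph, and postprocess cuts the best beam at the first 0 id and strips the '@@ ' subword markers. A hedged end-to-end sketch, reusing the BPE-split input from the test added later in this PR:

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    translator = pipeline(
        task=Tasks.translation, model='damo/nlp_csanmt_translation')
    result = translator(input='Gut@@ ach : Incre@@ ased safety for pedestri@@ ans')
    print(result['translation'])  # key is OutputKeys.TRANSLATION

Note the pipeline only whitespace-splits its input, so the text must already be tokenized and BPE-segmented with the same subword model that produced src_vocab.txt.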
@@ -18,6 +18,7 @@ class OutputKeys(object):
     OUTPUT_PCM = 'output_pcm'
     IMG_EMBEDDING = 'img_embedding'
     TEXT_EMBEDDING = 'text_embedding'
+    TRANSLATION = 'translation'
     RESPONSE = 'response'
     PREDICTION = 'prediction'
@@ -123,6 +124,12 @@ TASK_OUTPUTS = {
     #   }
     Tasks.sentence_similarity: [OutputKeys.SCORES, OutputKeys.LABELS],
 
+    # translation result for a source sentence
+    # {
+    #     "translation": "北京是中国的首都"
+    # }
+    Tasks.translation: [OutputKeys.TRANSLATION],
+
     # sentiment classification result for single sample
     # {
     #   "labels": ["happy", "sad", "calm", "angry"],
@@ -0,0 +1,22 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import shutil
+import unittest
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.pipelines import TranslationPipeline, pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.test_utils import test_level
+
+
+class TranslationTest(unittest.TestCase):
+    model_id = 'damo/nlp_csanmt_translation'
+    inputs = 'Gut@@ ach : Incre@@ ased safety for pedestri@@ ans'
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_model_name(self):
+        pipeline_ins = pipeline(task=Tasks.translation, model=self.model_id)
+        print(pipeline_ins(input=self.inputs))
+
+
+if __name__ == '__main__':
+    unittest.main()
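
The test input is already BPE-segmented ('Gut@@ ach' detokenizes to 'Gutach'), which is exactly the format preprocess expects; postprocess strips the markers from the translation it prints. Since the skip guard is test_level() >= 0, this case runs at every test level; it can also be invoked directly with python -m unittest on the test module (module path assumed from the repo's tests layout).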