diff --git a/modelscope/models/nlp/csanmt_for_translation.py b/modelscope/models/nlp/csanmt_for_translation.py index 6906f41c..83b58060 100644 --- a/modelscope/models/nlp/csanmt_for_translation.py +++ b/modelscope/models/nlp/csanmt_for_translation.py @@ -21,7 +21,7 @@ class CsanmtForTranslation(Model): params (dict): the model configuration. """ super().__init__(model_dir, *args, **kwargs) - self.params = kwargs['params'] + self.params = kwargs def __call__(self, input: Dict[str, Tensor], diff --git a/modelscope/pipelines/nlp/translation_pipeline.py b/modelscope/pipelines/nlp/translation_pipeline.py index 67ff3927..ebd51f02 100644 --- a/modelscope/pipelines/nlp/translation_pipeline.py +++ b/modelscope/pipelines/nlp/translation_pipeline.py @@ -1,5 +1,4 @@ import os.path as osp -from threading import Lock from typing import Any, Dict import numpy as np @@ -10,7 +9,7 @@ from modelscope.outputs import OutputKeys from modelscope.pipelines.base import Pipeline from modelscope.pipelines.builder import PIPELINES from modelscope.utils.config import Config -from modelscope.utils.constant import Frameworks, ModelFile, Tasks +from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.logger import get_logger if tf.__version__ >= '2.0': @@ -27,25 +26,22 @@ __all__ = ['TranslationPipeline'] class TranslationPipeline(Pipeline): def __init__(self, model: str, **kwargs): - tf.reset_default_graph() - self.framework = Frameworks.tf - self.device_name = 'cpu' - super().__init__(model=model) + model = self.model.model_dir + tf.reset_default_graph() model_path = osp.join( osp.join(model, ModelFile.TF_CHECKPOINT_FOLDER), 'ckpt-0') self.cfg = Config.from_file(osp.join(model, ModelFile.CONFIGURATION)) - self.params = {} - self._override_params_from_file() - - self._src_vocab_path = osp.join(model, self.params['vocab_src']) + self._src_vocab_path = osp.join( + model, self.cfg['dataset']['src_vocab']['file']) self._src_vocab = dict([ (w.strip(), i) for i, w in enumerate(open(self._src_vocab_path)) ]) - self._trg_vocab_path = osp.join(model, self.params['vocab_trg']) + self._trg_vocab_path = osp.join( + model, self.cfg['dataset']['trg_vocab']['file']) self._trg_rvocab = dict([ (i, w.strip()) for i, w in enumerate(open(self._trg_vocab_path)) ]) @@ -59,7 +55,6 @@ class TranslationPipeline(Pipeline): self.output = {} # model - self.model = CsanmtForTranslation(model_path, params=self.params) output = self.model(self.input_wids) self.output.update(output) @@ -69,53 +64,10 @@ class TranslationPipeline(Pipeline): model_loader = tf.train.Saver(tf.global_variables()) model_loader.restore(sess, model_path) - def _override_params_from_file(self): - - # model - self.params['hidden_size'] = self.cfg['model']['hidden_size'] - self.params['filter_size'] = self.cfg['model']['filter_size'] - self.params['num_heads'] = self.cfg['model']['num_heads'] - self.params['num_encoder_layers'] = self.cfg['model'][ - 'num_encoder_layers'] - self.params['num_decoder_layers'] = self.cfg['model'][ - 'num_decoder_layers'] - self.params['layer_preproc'] = self.cfg['model']['layer_preproc'] - self.params['layer_postproc'] = self.cfg['model']['layer_postproc'] - self.params['shared_embedding_and_softmax_weights'] = self.cfg[ - 'model']['shared_embedding_and_softmax_weights'] - self.params['shared_source_target_embedding'] = self.cfg['model'][ - 'shared_source_target_embedding'] - self.params['initializer_scale'] = self.cfg['model'][ - 'initializer_scale'] - self.params['position_info_type'] = self.cfg['model'][ - 'position_info_type'] - self.params['max_relative_dis'] = self.cfg['model']['max_relative_dis'] - self.params['num_semantic_encoder_layers'] = self.cfg['model'][ - 'num_semantic_encoder_layers'] - self.params['src_vocab_size'] = self.cfg['model']['src_vocab_size'] - self.params['trg_vocab_size'] = self.cfg['model']['trg_vocab_size'] - self.params['attention_dropout'] = 0.0 - self.params['residual_dropout'] = 0.0 - self.params['relu_dropout'] = 0.0 - - # dataset - self.params['vocab_src'] = self.cfg['dataset']['src_vocab']['file'] - self.params['vocab_trg'] = self.cfg['dataset']['trg_vocab']['file'] - - # train - self.params['train_max_len'] = self.cfg['train']['train_max_len'] - self.params['confidence'] = self.cfg['train']['confidence'] - - # evaluation - self.params['beam_size'] = self.cfg['evaluation']['beam_size'] - self.params['lp_rate'] = self.cfg['evaluation']['lp_rate'] - self.params['max_decoded_trg_len'] = self.cfg['evaluation'][ - 'max_decoded_trg_len'] - def preprocess(self, input: str) -> Dict[str, Any]: input_ids = np.array([[ self._src_vocab[w] - if w in self._src_vocab else self.params['src_vocab_size'] + if w in self._src_vocab else self.cfg['model']['src_vocab_size'] for w in input.strip().split() ]]) result = {'input_ids': input_ids} diff --git a/modelscope/trainers/nlp/csanmt_translation_trainer.py b/modelscope/trainers/nlp/csanmt_translation_trainer.py index 219c5ff1..067c1d83 100644 --- a/modelscope/trainers/nlp/csanmt_translation_trainer.py +++ b/modelscope/trainers/nlp/csanmt_translation_trainer.py @@ -47,7 +47,7 @@ class CsanmtTranslationTrainer(BaseTrainer): self.global_step = tf.train.create_global_step() - self.model = CsanmtForTranslation(self.model_path, params=self.params) + self.model = CsanmtForTranslation(self.model_path, **self.params) output = self.model(input=self.source_wids, label=self.target_wids) self.output.update(output) @@ -319,6 +319,4 @@ def get_pretrained_variables_map(checkpoint_file_path, ignore_scope=None): if var_shape == saved_shapes[saved_var_name]: restore_vars.append(curr_var) restore_map[saved_var_name] = curr_var - tf.logging.info('Restore paramter %s from %s ...' % - (saved_var_name, checkpoint_file_path)) return restore_map diff --git a/tests/pipelines/test_csanmt_translation.py b/tests/pipelines/test_csanmt_translation.py index 699270d6..c852b1ff 100644 --- a/tests/pipelines/test_csanmt_translation.py +++ b/tests/pipelines/test_csanmt_translation.py @@ -10,7 +10,7 @@ class TranslationTest(unittest.TestCase): model_id = 'damo/nlp_csanmt_translation_zh2en' inputs = '声明 补充 说 , 沃伦 的 同事 都 深感 震惊 , 并且 希望 他 能够 投@@ 案@@ 自@@ 首 。' - @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_with_model_name(self): pipeline_ins = pipeline(task=Tasks.translation, model=self.model_id) print(pipeline_ins(input=self.inputs))