You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dialog_generation_preprocessor.py 1.5 kB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import uuid
  4. from typing import Any, Dict, Union
  5. from maas_lib.data.nlp.space.fields.gen_field import MultiWOZBPETextField
  6. from maas_lib.utils.config import Config
  7. from maas_lib.utils.constant import Fields, InputFields
  8. from maas_lib.utils.type_assert import type_assert
  9. from ..base import Preprocessor
  10. from ..builder import PREPROCESSORS
# Public API of this module: `from <module> import *` exposes only the
# preprocessor class defined below.
__all__ = ['DialogGenerationPreprocessor']
  12. @PREPROCESSORS.register_module(Fields.nlp, module_name=r'space-generation')
  13. class DialogGenerationPreprocessor(Preprocessor):
  14. def __init__(self, model_dir: str, *args, **kwargs):
  15. """preprocess the data via the vocab.txt from the `model_dir` path
  16. Args:
  17. model_dir (str): model path
  18. """
  19. super().__init__(*args, **kwargs)
  20. self.model_dir: str = model_dir
  21. self.config = Config.from_file(
  22. os.path.join(self.model_dir, 'configuration.json'))
  23. self.text_field = MultiWOZBPETextField(
  24. self.model_dir, config=self.config)
  25. @type_assert(object, str)
  26. def __call__(self, data: str) -> Dict[str, Any]:
  27. """process the raw input data
  28. Args:
  29. data (str): a sentence
  30. Example:
  31. 'you are so handsome.'
  32. Returns:
  33. Dict[str, Any]: the preprocessed data
  34. """
  35. idx = self.text_field.get_ids(data)
  36. return {'user_idx': idx}

致力于通过开放的社区合作，开源AI模型以及相关创新技术，推动基于模型即服务的生态繁荣发展