|
|
|
@@ -2,19 +2,19 @@ import os |
|
|
|
from typing import Any, Dict, Optional |
|
|
|
|
|
|
|
from ....preprocessors.space.fields.gen_field import MultiWOZBPETextField |
|
|
|
from ....trainers.nlp.space.trainer.gen_trainer import MultiWOZTrainer |
|
|
|
from ....utils.config import Config |
|
|
|
from ....utils.constant import Tasks |
|
|
|
from ....utils.constant import ModelFile, Tasks |
|
|
|
from ...base import Model, Tensor |
|
|
|
from ...builder import MODELS |
|
|
|
from .application.gen_app import MultiWOZTrainer |
|
|
|
from .model.generator import Generator |
|
|
|
from .model.model_base import ModelBase |
|
|
|
from .model.model_base import SpaceModelBase |
|
|
|
|
|
|
|
__all__ = ['DialogModelingModel'] |
|
|
|
__all__ = ['SpaceForDialogModelingModel'] |
|
|
|
|
|
|
|
|
|
|
|
@MODELS.register_module(Tasks.dialog_modeling, module_name=r'space') |
|
|
|
class DialogModelingModel(Model): |
|
|
|
class SpaceForDialogModelingModel(Model): |
|
|
|
|
|
|
|
def __init__(self, model_dir: str, *args, **kwargs): |
|
|
|
"""initialize the test generation model from the `model_dir` path. |
|
|
|
@@ -30,12 +30,12 @@ class DialogModelingModel(Model): |
|
|
|
self.config = kwargs.pop( |
|
|
|
'config', |
|
|
|
Config.from_file( |
|
|
|
os.path.join(self.model_dir, 'configuration.json'))) |
|
|
|
os.path.join(self.model_dir, ModelFile.CONFIGURATION))) |
|
|
|
self.text_field = kwargs.pop( |
|
|
|
'text_field', |
|
|
|
MultiWOZBPETextField(self.model_dir, config=self.config)) |
|
|
|
self.generator = Generator.create(self.config, reader=self.text_field) |
|
|
|
self.model = ModelBase.create( |
|
|
|
self.model = SpaceModelBase.create( |
|
|
|
model_dir=model_dir, |
|
|
|
config=self.config, |
|
|
|
reader=self.text_field, |
|
|
|
|