|
|
|
@@ -18,13 +18,13 @@ __all__ = ['CsanmtForTranslation'] |
|
|
|
@MODELS.register_module(Tasks.translation, module_name=Models.translation) |
|
|
|
class CsanmtForTranslation(Model): |
|
|
|
|
|
|
|
def __init__(self, model_dir, params, *args, **kwargs): |
|
|
|
def __init__(self, model_dir, *args, **kwargs): |
|
|
|
""" |
|
|
|
Args: |
|
|
|
params (dict): the model configuration. |
|
|
|
""" |
|
|
|
super().__init__(model_dir, *args, **kwargs) |
|
|
|
self.params = params |
|
|
|
self.params = kwargs |
|
|
|
|
|
|
|
def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: |
|
|
|
"""return the result by the model |
|
|
|
|