diff --git a/modelscope/models/nlp/space_for_dialog_modeling.py b/modelscope/models/nlp/space_for_dialog_modeling.py index 9ac6e099..35269e53 100644 --- a/modelscope/models/nlp/space_for_dialog_modeling.py +++ b/modelscope/models/nlp/space_for_dialog_modeling.py @@ -31,6 +31,10 @@ class SpaceForDialogModeling(Model): 'config', Config.from_file( os.path.join(self.model_dir, ModelFile.CONFIGURATION))) + + import torch + self.config.use_gpu = self.config.use_gpu and torch.cuda.is_available() + self.text_field = kwargs.pop( 'text_field', MultiWOZBPETextField(self.model_dir, config=self.config)) diff --git a/modelscope/preprocessors/space/dialog_modeling_preprocessor.py b/modelscope/preprocessors/space/dialog_modeling_preprocessor.py index db83d906..79059a9f 100644 --- a/modelscope/preprocessors/space/dialog_modeling_preprocessor.py +++ b/modelscope/preprocessors/space/dialog_modeling_preprocessor.py @@ -29,6 +29,10 @@ class DialogModelingPreprocessor(Preprocessor): self.model_dir: str = model_dir self.config = Config.from_file( os.path.join(self.model_dir, ModelFile.CONFIGURATION)) + + import torch + self.config.use_gpu = self.config.use_gpu and torch.cuda.is_available() + self.text_field = MultiWOZBPETextField( self.model_dir, config=self.config)