From e22a8a00f0e0c23f6795014b2dc7032f7559166e Mon Sep 17 00:00:00 2001 From: ly119399 Date: Fri, 22 Jul 2022 17:26:25 +0800 Subject: [PATCH] [to #42322933] dialog modeling use gpu default Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9443821 --- modelscope/models/nlp/space_for_dialog_modeling.py | 4 ++++ .../preprocessors/space/dialog_modeling_preprocessor.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/modelscope/models/nlp/space_for_dialog_modeling.py b/modelscope/models/nlp/space_for_dialog_modeling.py index 9ac6e099..35269e53 100644 --- a/modelscope/models/nlp/space_for_dialog_modeling.py +++ b/modelscope/models/nlp/space_for_dialog_modeling.py @@ -31,6 +31,10 @@ class SpaceForDialogModeling(Model): 'config', Config.from_file( os.path.join(self.model_dir, ModelFile.CONFIGURATION))) + + import torch + self.config.use_gpu = self.config.use_gpu and torch.cuda.is_available() + self.text_field = kwargs.pop( 'text_field', MultiWOZBPETextField(self.model_dir, config=self.config)) diff --git a/modelscope/preprocessors/space/dialog_modeling_preprocessor.py b/modelscope/preprocessors/space/dialog_modeling_preprocessor.py index db83d906..79059a9f 100644 --- a/modelscope/preprocessors/space/dialog_modeling_preprocessor.py +++ b/modelscope/preprocessors/space/dialog_modeling_preprocessor.py @@ -29,6 +29,10 @@ class DialogModelingPreprocessor(Preprocessor): self.model_dir: str = model_dir self.config = Config.from_file( os.path.join(self.model_dir, ModelFile.CONFIGURATION)) + + import torch + self.config.use_gpu = self.config.use_gpu and torch.cuda.is_available() + self.text_field = MultiWOZBPETextField( self.model_dir, config=self.config)