ly119399 yingda.chen 3 years ago
parent
commit
e22a8a00f0
2 changed files with 8 additions and 0 deletions
  1. +4
    -0
      modelscope/models/nlp/space_for_dialog_modeling.py
  2. +4
    -0
      modelscope/preprocessors/space/dialog_modeling_preprocessor.py

+ 4
- 0
modelscope/models/nlp/space_for_dialog_modeling.py View File

@@ -31,6 +31,10 @@ class SpaceForDialogModeling(Model):
'config', 'config',
Config.from_file( Config.from_file(
os.path.join(self.model_dir, ModelFile.CONFIGURATION))) os.path.join(self.model_dir, ModelFile.CONFIGURATION)))

import torch
self.config.use_gpu = self.config.use_gpu and torch.cuda.is_available()

self.text_field = kwargs.pop( self.text_field = kwargs.pop(
'text_field', 'text_field',
MultiWOZBPETextField(self.model_dir, config=self.config)) MultiWOZBPETextField(self.model_dir, config=self.config))


+ 4
- 0
modelscope/preprocessors/space/dialog_modeling_preprocessor.py View File

@@ -29,6 +29,10 @@ class DialogModelingPreprocessor(Preprocessor):
self.model_dir: str = model_dir self.model_dir: str = model_dir
self.config = Config.from_file( self.config = Config.from_file(
os.path.join(self.model_dir, ModelFile.CONFIGURATION)) os.path.join(self.model_dir, ModelFile.CONFIGURATION))

import torch
self.config.use_gpu = self.config.use_gpu and torch.cuda.is_available()

self.text_field = MultiWOZBPETextField( self.text_field = MultiWOZBPETextField(
self.model_dir, config=self.config) self.model_dir, config=self.config)




Loading…
Cancel
Save