Browse Source

all ut passed

master
雨泓 3 years ago
parent
commit
1476e08b82
4 changed files with 19 additions and 1 deletions
  1. +6
    -0
      modelscope/models/nlp/masked_language_model.py
  2. +6
    -0
      modelscope/models/nlp/palm_for_text_generation.py
  3. +6
    -0
      modelscope/models/nlp/sbert_for_token_classification.py
  4. +1
    -1
      modelscope/preprocessors/nlp.py

+ 6
- 0
modelscope/models/nlp/masked_language_model.py View File

@@ -19,6 +19,12 @@ class MaskedLMModelBase(Model):
def build_model(self):
raise NotImplementedError()

def train(self):
return self.model.train()

def eval(self):
return self.model.eval()

@property
def config(self):
if hasattr(self.model, "config"):


+ 6
- 0
modelscope/models/nlp/palm_for_text_generation.py View File

@@ -26,6 +26,12 @@ class PalmForTextGeneration(Model):
self.tokenizer = model.tokenizer
self.generator = Translator(model)

def train(self):
return self.generator.train()

def eval(self):
return self.generator.eval()

def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""return the result by the model



+ 6
- 0
modelscope/models/nlp/sbert_for_token_classification.py View File

@@ -29,6 +29,12 @@ class SbertForTokenClassification(Model):
self.model_dir)
self.config = sofa.SbertConfig.from_pretrained(self.model_dir)

def train(self):
return self.model.train()

def eval(self):
return self.model.eval()
def forward(self, input: Dict[str,
Any]) -> Dict[str, Union[str, np.ndarray]]:
"""return the result by the model


+ 1
- 1
modelscope/preprocessors/nlp.py View File

@@ -314,7 +314,7 @@ class TextGenerationPreprocessor(Preprocessor):

rst['input_ids'].append(feature['input_ids'])
rst['attention_mask'].append(feature['attention_mask'])
rst['token_type_ids'].append(feature['token_type_ids'])
# rst['token_type_ids'].append(feature['token_type_ids'])
return {k: torch.tensor(v) for k, v in rst.items()}




Loading…
Cancel
Save