Browse Source

all ut passed

master
雨泓 3 years ago
parent
commit
1476e08b82
4 changed files with 19 additions and 1 deletions
  1. +6
    -0
      modelscope/models/nlp/masked_language_model.py
  2. +6
    -0
      modelscope/models/nlp/palm_for_text_generation.py
  3. +6
    -0
      modelscope/models/nlp/sbert_for_token_classification.py
  4. +1
    -1
      modelscope/preprocessors/nlp.py

+ 6
- 0
modelscope/models/nlp/masked_language_model.py View File

@@ -19,6 +19,12 @@ class MaskedLMModelBase(Model):
def build_model(self): def build_model(self):
raise NotImplementedError() raise NotImplementedError()


def train(self):
return self.model.train()

def eval(self):
return self.model.eval()

@property @property
def config(self): def config(self):
if hasattr(self.model, "config"): if hasattr(self.model, "config"):


+ 6
- 0
modelscope/models/nlp/palm_for_text_generation.py View File

@@ -26,6 +26,12 @@ class PalmForTextGeneration(Model):
self.tokenizer = model.tokenizer self.tokenizer = model.tokenizer
self.generator = Translator(model) self.generator = Translator(model)


def train(self):
return self.generator.train()

def eval(self):
return self.generator.eval()

def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""return the result by the model """return the result by the model




+ 6
- 0
modelscope/models/nlp/sbert_for_token_classification.py View File

@@ -29,6 +29,12 @@ class SbertForTokenClassification(Model):
self.model_dir) self.model_dir)
self.config = sofa.SbertConfig.from_pretrained(self.model_dir) self.config = sofa.SbertConfig.from_pretrained(self.model_dir)


def train(self):
return self.model.train()

def eval(self):
return self.model.eval()
def forward(self, input: Dict[str, def forward(self, input: Dict[str,
Any]) -> Dict[str, Union[str, np.ndarray]]: Any]) -> Dict[str, Union[str, np.ndarray]]:
"""return the result by the model """return the result by the model


+ 1
- 1
modelscope/preprocessors/nlp.py View File

@@ -314,7 +314,7 @@ class TextGenerationPreprocessor(Preprocessor):


rst['input_ids'].append(feature['input_ids']) rst['input_ids'].append(feature['input_ids'])
rst['attention_mask'].append(feature['attention_mask']) rst['attention_mask'].append(feature['attention_mask'])
rst['token_type_ids'].append(feature['token_type_ids'])
# rst['token_type_ids'].append(feature['token_type_ids'])
return {k: torch.tensor(v) for k, v in rst.items()} return {k: torch.tensor(v) for k, v in rst.items()}






Loading…
Cancel
Save