| @@ -19,6 +19,12 @@ class MaskedLMModelBase(Model): | |||||
| def build_model(self): | def build_model(self): | ||||
| raise NotImplementedError() | raise NotImplementedError() | ||||
| def train(self): | |||||
| return self.model.train() | |||||
| def eval(self): | |||||
| return self.model.eval() | |||||
| @property | @property | ||||
| def config(self): | def config(self): | ||||
| if hasattr(self.model, "config"): | if hasattr(self.model, "config"): | ||||
| @@ -26,6 +26,12 @@ class PalmForTextGeneration(Model): | |||||
| self.tokenizer = model.tokenizer | self.tokenizer = model.tokenizer | ||||
| self.generator = Translator(model) | self.generator = Translator(model) | ||||
| def train(self): | |||||
| return self.generator.train() | |||||
| def eval(self): | |||||
| return self.generator.eval() | |||||
| def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | ||||
| """return the result by the model | """return the result by the model | ||||
| @@ -29,6 +29,12 @@ class SbertForTokenClassification(Model): | |||||
| self.model_dir) | self.model_dir) | ||||
| self.config = sofa.SbertConfig.from_pretrained(self.model_dir) | self.config = sofa.SbertConfig.from_pretrained(self.model_dir) | ||||
| def train(self): | |||||
| return self.model.train() | |||||
| def eval(self): | |||||
| return self.model.eval() | |||||
| def forward(self, input: Dict[str, | def forward(self, input: Dict[str, | ||||
| Any]) -> Dict[str, Union[str, np.ndarray]]: | Any]) -> Dict[str, Union[str, np.ndarray]]: | ||||
| """return the result by the model | """return the result by the model | ||||
| @@ -314,7 +314,7 @@ class TextGenerationPreprocessor(Preprocessor): | |||||
| rst['input_ids'].append(feature['input_ids']) | rst['input_ids'].append(feature['input_ids']) | ||||
| rst['attention_mask'].append(feature['attention_mask']) | rst['attention_mask'].append(feature['attention_mask']) | ||||
| rst['token_type_ids'].append(feature['token_type_ids']) | |||||
| # rst['token_type_ids'].append(feature['token_type_ids']) | |||||
| return {k: torch.tensor(v) for k, v in rst.items()} | return {k: torch.tensor(v) for k, v in rst.items()} | ||||