| @@ -31,20 +31,20 @@ class MaskedLMModelBase(Model): | |||||
| return self.model.config | return self.model.config | ||||
| return None | return None | ||||
| def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, np.ndarray]: | |||||
| def forward(self, input: Dict[str, Tensor]) -> Dict[str, np.ndarray]: | |||||
| """return the result by the model | """return the result by the model | ||||
| Args: | Args: | ||||
| inputs (Dict[str, Any]): the preprocessed data | |||||
| input (Dict[str, Any]): the preprocessed data | |||||
| Returns: | Returns: | ||||
| Dict[str, np.ndarray]: results | Dict[str, np.ndarray]: results | ||||
| """ | """ | ||||
| rst = self.model( | rst = self.model( | ||||
| input_ids=inputs['input_ids'], | |||||
| attention_mask=inputs['attention_mask'], | |||||
| token_type_ids=inputs['token_type_ids']) | |||||
| return {'logits': rst['logits'], 'input_ids': inputs['input_ids']} | |||||
| input_ids=input['input_ids'], | |||||
| attention_mask=input['attention_mask'], | |||||
| token_type_ids=input['token_type_ids']) | |||||
| return {'logits': rst['logits'], 'input_ids': input['input_ids']} | |||||
| @MODELS.register_module(Tasks.fill_mask, module_name=Models.structbert) | @MODELS.register_module(Tasks.fill_mask, module_name=Models.structbert) | ||||