Browse Source

revert a mis modification

master
雨泓 3 years ago
parent
commit
2eb633ec93
1 changed files with 6 additions and 6 deletions
  1. +6
    -6
      modelscope/models/nlp/masked_language_model.py

+ 6
- 6
modelscope/models/nlp/masked_language_model.py View File

@@ -31,20 +31,20 @@ class MaskedLMModelBase(Model):
return self.model.config return self.model.config
return None return None


def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, np.ndarray]:
def forward(self, input: Dict[str, Tensor]) -> Dict[str, np.ndarray]:
"""return the result by the model """return the result by the model


Args: Args:
inputs (Dict[str, Any]): the preprocessed data
input (Dict[str, Any]): the preprocessed data


Returns: Returns:
Dict[str, np.ndarray]: results Dict[str, np.ndarray]: results
""" """
rst = self.model( rst = self.model(
input_ids=inputs['input_ids'],
attention_mask=inputs['attention_mask'],
token_type_ids=inputs['token_type_ids'])
return {'logits': rst['logits'], 'input_ids': inputs['input_ids']}
input_ids=input['input_ids'],
attention_mask=input['attention_mask'],
token_type_ids=input['token_type_ids'])
return {'logits': rst['logits'], 'input_ids': input['input_ids']}




@MODELS.register_module(Tasks.fill_mask, module_name=Models.structbert) @MODELS.register_module(Tasks.fill_mask, module_name=Models.structbert)


Loading…
Cancel
Save