@@ -42,14 +42,13 @@ from transformers.utils import logging
 
 from modelscope.models.multi_modal.mplug.configuration_mplug import MPlugConfig
 from modelscope.models.multi_modal.mplug.predictor import TextGenerator
+from modelscope.utils.constant import ModelFile
 
 transformers.logging.set_verbosity_error()
 
 logger = logging.get_logger(__name__)
 
 CONFIG_NAME = 'config.yaml'
-WEIGHTS_NAME = 'pytorch_model.bin'
-VOCAB_NAME = 'vocab.txt'
 
 _CONFIG_FOR_DOC = 'BertConfig'
 _TOKENIZER_FOR_DOC = 'BertTokenizer'
@@ -1733,7 +1732,7 @@ class MPlugForVisualQuestionAnswering(PreTrainedModel):
         super().__init__(config)
         self.config = config
         self.tokenizer = BertTokenizer.from_pretrained(
-            os.path.join(config.model_dir, VOCAB_NAME))
+            os.path.join(config.model_dir, ModelFile.VOCAB_FILE))
         self.module_setting(config)
         self.visual_encoder = self._initialize_clip(config)
         self.text_encoder = BertModel(
@@ -1751,7 +1750,8 @@ class MPlugForVisualQuestionAnswering(PreTrainedModel):
         config.model_dir = model_dir
         model = cls(config)
         if load_checkpoint:
-            checkpoint_path = os.path.join(model_dir, WEIGHTS_NAME)
+            checkpoint_path = os.path.join(model_dir,
+                                           ModelFile.TORCH_MODEL_BIN_FILE)
             checkpoint = torch.load(checkpoint_path, map_location='cpu')
             if 'model' in checkpoint:
                 state_dict = checkpoint['model']
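
Note (not part of the patch): the diff replaces the locally defined file-name constants with the shared ones from modelscope.utils.constant. A minimal sketch of the assumption behind that swap, using a placeholder model directory, is:

import os

from modelscope.utils.constant import ModelFile

# Placeholder directory; any mPLUG model dir laid out as before works the same way.
model_dir = '/path/to/mplug_model_dir'

# The shared constants are expected to name the same files the removed
# WEIGHTS_NAME / VOCAB_NAME strings hard-coded ('pytorch_model.bin' / 'vocab.txt'),
# so existing checkpoint directories keep resolving to the same paths.
checkpoint_path = os.path.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE)
vocab_path = os.path.join(model_dir, ModelFile.VOCAB_FILE)
print(checkpoint_path)
print(vocab_path)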