Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10750369master
| @@ -31,7 +31,10 @@ class ReferringVideoObjectSegmentation(TorchModel): | |||||
| config_path = osp.join(model_dir, ModelFile.CONFIGURATION) | config_path = osp.join(model_dir, ModelFile.CONFIGURATION) | ||||
| self.cfg = Config.from_file(config_path) | self.cfg = Config.from_file(config_path) | ||||
| self.model = MTTR(**self.cfg.model) | |||||
| transformer_cfg_dir = osp.join(model_dir, 'transformer_cfg_dir') | |||||
| self.model = MTTR( | |||||
| transformer_cfg_dir=transformer_cfg_dir, **self.cfg.model) | |||||
| model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE) | model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE) | ||||
| params_dict = torch.load(model_path, map_location='cpu') | params_dict = torch.load(model_path, map_location='cpu') | ||||
| @@ -19,6 +19,7 @@ class MTTR(nn.Module): | |||||
| num_queries, | num_queries, | ||||
| mask_kernels_dim=8, | mask_kernels_dim=8, | ||||
| aux_loss=False, | aux_loss=False, | ||||
| transformer_cfg_dir=None, | |||||
| **kwargs): | **kwargs): | ||||
| """ | """ | ||||
| Parameters: | Parameters: | ||||
| @@ -29,7 +30,9 @@ class MTTR(nn.Module): | |||||
| """ | """ | ||||
| super().__init__() | super().__init__() | ||||
| self.backbone = init_backbone(**kwargs) | self.backbone = init_backbone(**kwargs) | ||||
| self.transformer = MultimodalTransformer(**kwargs) | |||||
| assert transformer_cfg_dir is not None | |||||
| self.transformer = MultimodalTransformer( | |||||
| transformer_cfg_dir=transformer_cfg_dir, **kwargs) | |||||
| d_model = self.transformer.d_model | d_model = self.transformer.d_model | ||||
| self.is_referred_head = nn.Linear( | self.is_referred_head = nn.Linear( | ||||
| d_model, | d_model, | ||||
| @@ -26,6 +26,7 @@ class MultimodalTransformer(nn.Module): | |||||
| num_decoder_layers=3, | num_decoder_layers=3, | ||||
| text_encoder_type='roberta-base', | text_encoder_type='roberta-base', | ||||
| freeze_text_encoder=True, | freeze_text_encoder=True, | ||||
| transformer_cfg_dir=None, | |||||
| **kwargs): | **kwargs): | ||||
| super().__init__() | super().__init__() | ||||
| self.d_model = kwargs['d_model'] | self.d_model = kwargs['d_model'] | ||||
| @@ -40,10 +41,12 @@ class MultimodalTransformer(nn.Module): | |||||
| self.pos_encoder_2d = PositionEmbeddingSine2D() | self.pos_encoder_2d = PositionEmbeddingSine2D() | ||||
| self._reset_parameters() | self._reset_parameters() | ||||
| self.text_encoder = RobertaModel.from_pretrained(text_encoder_type) | |||||
| if text_encoder_type != 'roberta-base': | |||||
| transformer_cfg_dir = text_encoder_type | |||||
| self.text_encoder = RobertaModel.from_pretrained(transformer_cfg_dir) | |||||
| self.text_encoder.pooler = None # this pooler is never used, this is a hack to avoid DDP problems... | self.text_encoder.pooler = None # this pooler is never used, this is a hack to avoid DDP problems... | ||||
| self.tokenizer = RobertaTokenizerFast.from_pretrained( | self.tokenizer = RobertaTokenizerFast.from_pretrained( | ||||
| text_encoder_type) | |||||
| transformer_cfg_dir) | |||||
| self.freeze_text_encoder = freeze_text_encoder | self.freeze_text_encoder = freeze_text_encoder | ||||
| if freeze_text_encoder: | if freeze_text_encoder: | ||||
| for p in self.text_encoder.parameters(): | for p in self.text_encoder.parameters(): | ||||