Browse Source

修复初始化过程参数未生效问题

此前文生图模型没有加载configuration.json中的参数 影响默认配置
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10558026
master
menrui.mr yingda.chen 3 years ago
parent
commit
c7b0787049
2 changed files with 9 additions and 1 deletions
  1. +7
    -1
      modelscope/models/multi_modal/ofa_for_text_to_image_synthesis_model.py
  2. +2
    -0
      tests/pipelines/test_ofa_tasks.py

+ 7
- 1
modelscope/models/multi_modal/ofa_for_text_to_image_synthesis_model.py View File

@@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.


import os import os
from os import path as osp
from typing import Any, Dict from typing import Any, Dict


import json import json
@@ -23,7 +24,8 @@ from modelscope.models.multi_modal.ofa import OFAModel, OFATokenizer
from modelscope.models.multi_modal.ofa.generate import sequence_generator as sg from modelscope.models.multi_modal.ofa.generate import sequence_generator as sg
from modelscope.models.multi_modal.ofa.generate.search import Sampling from modelscope.models.multi_modal.ofa.generate.search import Sampling
from modelscope.models.multi_modal.ofa.generate.utils import move_to_device from modelscope.models.multi_modal.ofa.generate.utils import move_to_device
from modelscope.utils.constant import Tasks
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks


try: try:
from torchvision.transforms import InterpolationMode from torchvision.transforms import InterpolationMode
@@ -133,6 +135,8 @@ class OfaForTextToImageSynthesis(Model):
super().__init__(model_dir=model_dir, *args, **kwargs) super().__init__(model_dir=model_dir, *args, **kwargs)
# Initialize ofa # Initialize ofa
model = OFAModel.from_pretrained(model_dir) model = OFAModel.from_pretrained(model_dir)
self.cfg = Config.from_file(
osp.join(model_dir, ModelFile.CONFIGURATION))
self.model = model.module if hasattr(model, 'module') else model self.model = model.module if hasattr(model, 'module') else model
self.tokenizer = OFATokenizer.from_pretrained(model_dir) self.tokenizer = OFATokenizer.from_pretrained(model_dir)
self.tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)]) self.tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)])
@@ -171,6 +175,8 @@ class OfaForTextToImageSynthesis(Model):
'gen_code': True, 'gen_code': True,
'constraint_range': '50265,58457' 'constraint_range': '50265,58457'
} }
if hasattr(self.cfg.model, 'beam_search'):
sg_args.update(self.cfg.model.beam_search)
self.generator = sg.SequenceGenerator(**sg_args) self.generator = sg.SequenceGenerator(**sg_args)


def clip_tokenize(self, texts, context_length=77, truncate=False): def clip_tokenize(self, texts, context_length=77, truncate=False):


+ 2
- 0
tests/pipelines/test_ofa_tasks.py View File

@@ -243,6 +243,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck):
def test_run_with_text_to_image_synthesis_with_name(self): def test_run_with_text_to_image_synthesis_with_name(self):
model = 'damo/ofa_text-to-image-synthesis_coco_large_en' model = 'damo/ofa_text-to-image-synthesis_coco_large_en'
ofa_pipe = pipeline(Tasks.text_to_image_synthesis, model=model) ofa_pipe = pipeline(Tasks.text_to_image_synthesis, model=model)
ofa_pipe.model.generator.beam_size = 2
example = {'text': 'a bear in the water.'} example = {'text': 'a bear in the water.'}
result = ofa_pipe(example) result = ofa_pipe(example)
result[OutputKeys.OUTPUT_IMG].save('result.png') result[OutputKeys.OUTPUT_IMG].save('result.png')
@@ -253,6 +254,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck):
model = Model.from_pretrained( model = Model.from_pretrained(
'damo/ofa_text-to-image-synthesis_coco_large_en') 'damo/ofa_text-to-image-synthesis_coco_large_en')
ofa_pipe = pipeline(Tasks.text_to_image_synthesis, model=model) ofa_pipe = pipeline(Tasks.text_to_image_synthesis, model=model)
ofa_pipe.model.generator.beam_size = 2
example = {'text': 'a bear in the water.'} example = {'text': 'a bear in the water.'}
result = ofa_pipe(example) result = ofa_pipe(example)
result[OutputKeys.OUTPUT_IMG].save('result.png') result[OutputKeys.OUTPUT_IMG].save('result.png')


Loading…
Cancel
Save