Browse Source

[to #42322933] fix sample collate bug

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9615938

    * add init and make demo compatible

* make demo compatible

* fix comments

* add distilled ut

* Merge remote-tracking branch 'origin/master' into ofa/bug_fix

* fix sample collate bug

* Merge remote-tracking branch 'origin/master' into ofa/bug_fix

# Conflicts:
#	tests/pipelines/test_ofa_tasks.py

* Merge remote-tracking branch 'origin/master' into ofa/bug_fix

* fix sample collate bug
master
yichang.zyc 3 years ago
parent
commit
9db97366eb
2 changed files with 8 additions and 3 deletions
  1. +4
    -1
      modelscope/preprocessors/multi_modal.py
  2. +4
    -2
      tests/pipelines/test_ofa_tasks.py

+ 4
- 1
modelscope/preprocessors/multi_modal.py View File

@@ -76,7 +76,10 @@ class OfaPreprocessor(Preprocessor):
else: else:
data = self._build_dict(input) data = self._build_dict(input)
sample = self.preprocess(data) sample = self.preprocess(data)
sample['sample'] = data
str_data = dict()
for k, v in data.items():
str_data[k] = str(v)
sample['sample'] = str_data
return collate_fn([sample], return collate_fn([sample],
pad_idx=self.tokenizer.pad_token_id, pad_idx=self.tokenizer.pad_token_id,
eos_idx=self.tokenizer.eos_token_id) eos_idx=self.tokenizer.eos_token_id)


+ 4
- 2
tests/pipelines/test_ofa_tasks.py View File

@@ -1,6 +1,8 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
import unittest import unittest


from PIL import Image

from modelscope.models import Model from modelscope.models import Model
from modelscope.outputs import OutputKeys from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline from modelscope.pipelines import pipeline
@@ -181,8 +183,8 @@ class OfaTasksTest(unittest.TestCase):
task=Tasks.image_captioning, task=Tasks.image_captioning,
model=model, model=model,
) )
result = img_captioning(
{'image': 'data/test/images/image_captioning.png'})
image = Image.open('data/test/images/image_captioning.png')
result = img_captioning(image)
print(result[OutputKeys.CAPTION]) print(result[OutputKeys.CAPTION])


@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')


Loading…
Cancel
Save