Browse Source

[to #42322933] fix sample collate bug

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9615938

    * add init and make demo compatible

* make demo compatible

* fix comments

* add distilled ut

* Merge remote-tracking branch 'origin/master' into ofa/bug_fix

* fix sample collate bug

* Merge remote-tracking branch 'origin/master' into ofa/bug_fix

# Conflicts:
#	tests/pipelines/test_ofa_tasks.py

* Merge remote-tracking branch 'origin/master' into ofa/bug_fix

* fix sample collate bug
master
yichang.zyc 3 years ago
parent
commit
9db97366eb
2 changed files with 8 additions and 3 deletions
  1. +4
    -1
      modelscope/preprocessors/multi_modal.py
  2. +4
    -2
      tests/pipelines/test_ofa_tasks.py

+ 4
- 1
modelscope/preprocessors/multi_modal.py View File

@@ -76,7 +76,10 @@ class OfaPreprocessor(Preprocessor):
else:
data = self._build_dict(input)
sample = self.preprocess(data)
sample['sample'] = data
str_data = dict()
for k, v in data.items():
str_data[k] = str(v)
sample['sample'] = str_data
return collate_fn([sample],
pad_idx=self.tokenizer.pad_token_id,
eos_idx=self.tokenizer.eos_token_id)


+ 4
- 2
tests/pipelines/test_ofa_tasks.py View File

@@ -1,6 +1,8 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from PIL import Image

from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
@@ -181,8 +183,8 @@ class OfaTasksTest(unittest.TestCase):
task=Tasks.image_captioning,
model=model,
)
result = img_captioning(
{'image': 'data/test/images/image_captioning.png'})
image = Image.open('data/test/images/image_captioning.png')
result = img_captioning(image)
print(result[OutputKeys.CAPTION])

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')


Loading…
Cancel
Save