| @@ -278,7 +278,7 @@ class Trainers(object): | |||||
| # multi-modal trainers | # multi-modal trainers | ||||
| clip_multi_modal_embedding = 'clip-multi-modal-embedding' | clip_multi_modal_embedding = 'clip-multi-modal-embedding' | ||||
| ofa_tasks = 'ofa-tasks-trainer' | |||||
| ofa_tasks = 'ofa' | |||||
| # cv trainers | # cv trainers | ||||
| image_instance_segmentation = 'image-instance-segmentation' | image_instance_segmentation = 'image-instance-segmentation' | ||||
| @@ -83,6 +83,8 @@ class OfaPreprocessor(Preprocessor): | |||||
| return data | return data | ||||
| def _compatible_with_pretrain(self, data): | def _compatible_with_pretrain(self, data): | ||||
| # 预训练的时候使用的image都是经过pil转换的,PIL save的时候一般会进行有损压缩,为了保证和预训练一致 | |||||
| # 所以增加了这个逻辑 | |||||
| if 'image' in data and self.cfg.model.get('type', None) == 'ofa': | if 'image' in data and self.cfg.model.get('type', None) == 'ofa': | ||||
| if isinstance(data['image'], str): | if isinstance(data['image'], str): | ||||
| image = load_image(data['image']) | image = load_image(data['image']) | ||||
| @@ -1,17 +0,0 @@ | |||||
| import pdb | |||||
| import sys | |||||
| class ForkedPdb(pdb.Pdb): | |||||
| """A Pdb subclass that may be used | |||||
| from a forked multiprocessing child | |||||
| """ | |||||
| def interaction(self, *args, **kwargs): | |||||
| _stdin = sys.stdin | |||||
| try: | |||||
| sys.stdin = open('/dev/stdin') | |||||
| pdb.Pdb.interaction(self, *args, **kwargs) | |||||
| finally: | |||||
| sys.stdin = _stdin | |||||
| @@ -91,11 +91,8 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck): | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | ||||
| def test_run_with_text_classification_with_model(self): | def test_run_with_text_classification_with_model(self): | ||||
| # model = Model.from_pretrained( | |||||
| # 'damo/ofa_text-classification_mnli_large_en') | |||||
| model = Model.from_pretrained( | model = Model.from_pretrained( | ||||
| '/apsarapangu/disk2/yichang.zyc/ckpt/MaaS/ofa_text-classification_mnli_large_en' | |||||
| ) | |||||
| 'damo/ofa_text-classification_mnli_large_en') | |||||
| ofa_pipe = pipeline(Tasks.text_classification, model=model) | ofa_pipe = pipeline(Tasks.text_classification, model=model) | ||||
| text = 'One of our number will carry out your instructions minutely.' | text = 'One of our number will carry out your instructions minutely.' | ||||
| text2 = 'A member of my team will execute your orders with immense precision.' | text2 = 'A member of my team will execute your orders with immense precision.' | ||||
| @@ -1,9 +1,12 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import glob | |||||
| import os | import os | ||||
| import os.path as osp | |||||
| import shutil | import shutil | ||||
| import unittest | import unittest | ||||
| from modelscope.trainers.multi_modal.ofa import OFATrainer | |||||
| from modelscope.metainfo import Trainers | |||||
| from modelscope.trainers import build_trainer | |||||
| from modelscope.utils.test_utils import test_level | from modelscope.utils.test_utils import test_level | ||||
| @@ -11,10 +14,16 @@ class TestOfaTrainer(unittest.TestCase): | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | ||||
| def test_trainer(self): | def test_trainer(self): | ||||
| model_id = 'damo/ofa_image-caption_coco_huge_en' | |||||
| self.trainer = OFATrainer(model_id) | |||||
| os.makedirs(self.trainer.work_dir, exist_ok=True) | |||||
| self.trainer.train() | |||||
| os.environ['LOCAL_RANK'] = '0' | |||||
| model_id = 'damo/ofa_text-classification_mnli_large_en' | |||||
| default_args = {'model': model_id} | |||||
| trainer = build_trainer( | |||||
| name=Trainers.ofa_tasks, default_args=default_args) | |||||
| os.makedirs(trainer.work_dir, exist_ok=True) | |||||
| trainer.train() | |||||
| assert len( | |||||
| glob.glob(osp.join(trainer.work_dir, | |||||
| 'best_epoch*_accuracy*.pth'))) == 2 | |||||
| if os.path.exists(self.trainer.work_dir): | if os.path.exists(self.trainer.work_dir): | ||||
| shutil.rmtree(self.trainer.work_dir) | shutil.rmtree(self.trainer.work_dir) | ||||