From e76f5a96a3e0a5130ed00b30d827bb92f325a35f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=A1=8C=E5=97=94?= Date: Wed, 19 Oct 2022 17:34:59 +0800 Subject: [PATCH] fix comments --- modelscope/metainfo.py | 2 +- modelscope/preprocessors/multi_modal.py | 2 ++ modelscope/utils/multi_modal/forked_pdb.py | 17 ----------------- tests/pipelines/test_ofa_tasks.py | 5 +---- tests/trainers/test_ofa_trainer.py | 19 ++++++++++++++----- 5 files changed, 18 insertions(+), 27 deletions(-) delete mode 100644 modelscope/utils/multi_modal/forked_pdb.py diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 0b4291f0..c3fe5594 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -278,7 +278,7 @@ class Trainers(object): # multi-modal trainers clip_multi_modal_embedding = 'clip-multi-modal-embedding' - ofa_tasks = 'ofa-tasks-trainer' + ofa_tasks = 'ofa' # cv trainers image_instance_segmentation = 'image-instance-segmentation' diff --git a/modelscope/preprocessors/multi_modal.py b/modelscope/preprocessors/multi_modal.py index 6d06bbb9..73742c47 100644 --- a/modelscope/preprocessors/multi_modal.py +++ b/modelscope/preprocessors/multi_modal.py @@ -83,6 +83,8 @@ class OfaPreprocessor(Preprocessor): return data def _compatible_with_pretrain(self, data): + # Images used during pretraining were all converted via PIL, and PIL save usually applies lossy compression; + # this logic is added to keep preprocessing consistent with pretraining if 'image' in data and self.cfg.model.get('type', None) == 'ofa': if isinstance(data['image'], str): image = load_image(data['image']) diff --git a/modelscope/utils/multi_modal/forked_pdb.py b/modelscope/utils/multi_modal/forked_pdb.py deleted file mode 100644 index 56107d1f..00000000 --- a/modelscope/utils/multi_modal/forked_pdb.py +++ /dev/null @@ -1,17 +0,0 @@ -import pdb -import sys - - -class ForkedPdb(pdb.Pdb): - """A Pdb subclass that may be used - from a forked multiprocessing child - - """ - - def interaction(self, *args, **kwargs): - _stdin = sys.stdin - try: - sys.stdin = open('/dev/stdin') - pdb.Pdb.interaction(self, *args, 
**kwargs) - finally: - sys.stdin = _stdin diff --git a/tests/pipelines/test_ofa_tasks.py b/tests/pipelines/test_ofa_tasks.py index 104c2869..f8366508 100644 --- a/tests/pipelines/test_ofa_tasks.py +++ b/tests/pipelines/test_ofa_tasks.py @@ -91,11 +91,8 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck): @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_with_text_classification_with_model(self): - # model = Model.from_pretrained( - # 'damo/ofa_text-classification_mnli_large_en') model = Model.from_pretrained( - '/apsarapangu/disk2/yichang.zyc/ckpt/MaaS/ofa_text-classification_mnli_large_en' - ) + 'damo/ofa_text-classification_mnli_large_en') ofa_pipe = pipeline(Tasks.text_classification, model=model) text = 'One of our number will carry out your instructions minutely.' text2 = 'A member of my team will execute your orders with immense precision.' diff --git a/tests/trainers/test_ofa_trainer.py b/tests/trainers/test_ofa_trainer.py index c0704061..8aab3544 100644 --- a/tests/trainers/test_ofa_trainer.py +++ b/tests/trainers/test_ofa_trainer.py @@ -1,9 +1,12 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
+import glob import os +import os.path as osp import shutil import unittest -from modelscope.trainers.multi_modal.ofa import OFATrainer +from modelscope.metainfo import Trainers +from modelscope.trainers import build_trainer from modelscope.utils.test_utils import test_level @@ -11,10 +14,16 @@ class TestOfaTrainer(unittest.TestCase): @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_trainer(self): - model_id = 'damo/ofa_image-caption_coco_huge_en' - self.trainer = OFATrainer(model_id) - os.makedirs(self.trainer.work_dir, exist_ok=True) - self.trainer.train() + os.environ['LOCAL_RANK'] = '0' + model_id = 'damo/ofa_text-classification_mnli_large_en' + default_args = {'model': model_id} + trainer = build_trainer( + name=Trainers.ofa_tasks, default_args=default_args) + os.makedirs(trainer.work_dir, exist_ok=True) + trainer.train() + assert len( + glob.glob(osp.join(trainer.work_dir, + 'best_epoch*_accuracy*.pth'))) == 2 - if os.path.exists(self.trainer.work_dir): - shutil.rmtree(self.trainer.work_dir) + if os.path.exists(trainer.work_dir): + shutil.rmtree(trainer.work_dir)