| @@ -282,7 +282,7 @@ class Trainers(object): | |||||
| # multi-modal trainers | # multi-modal trainers | ||||
| clip_multi_modal_embedding = 'clip-multi-modal-embedding' | clip_multi_modal_embedding = 'clip-multi-modal-embedding' | ||||
| ofa_tasks = 'ofa' | |||||
| ofa = 'ofa' | |||||
| # cv trainers | # cv trainers | ||||
| image_instance_segmentation = 'image-instance-segmentation' | image_instance_segmentation = 'image-instance-segmentation' | ||||
| @@ -74,9 +74,7 @@ class OfaPreprocessor(Preprocessor): | |||||
| data[key] = item | data[key] = item | ||||
| return data | return data | ||||
| def _compatible_with_pretrain(self, data): | |||||
| # 预训练的时候使用的image都是经过pil转换的,PIL save的时候一般会进行有损压缩,为了保证和预训练一致 | |||||
| # 所以增加了这个逻辑 | |||||
| def _ofa_input_compatibility_conversion(self, data): | |||||
| if 'image' in data and self.cfg.model.get('type', None) == 'ofa': | if 'image' in data and self.cfg.model.get('type', None) == 'ofa': | ||||
| if isinstance(data['image'], str): | if isinstance(data['image'], str): | ||||
| image = load_image(data['image']) | image = load_image(data['image']) | ||||
| @@ -95,7 +93,7 @@ class OfaPreprocessor(Preprocessor): | |||||
| data = input | data = input | ||||
| else: | else: | ||||
| data = self._build_dict(input) | data = self._build_dict(input) | ||||
| data = self._compatible_with_pretrain(data) | |||||
| data = self._ofa_input_compatibility_conversion(data) | |||||
| sample = self.preprocess(data) | sample = self.preprocess(data) | ||||
| str_data = dict() | str_data = dict() | ||||
| for k, v in data.items(): | for k, v in data.items(): | ||||
| @@ -27,7 +27,7 @@ from .ofa_trainer_utils import (AdjustLabelSmoothedCrossEntropyCriterion, | |||||
| get_schedule) | get_schedule) | ||||
| @TRAINERS.register_module(module_name=Trainers.ofa_tasks) | |||||
| @TRAINERS.register_module(module_name=Trainers.ofa) | |||||
| class OFATrainer(EpochBasedTrainer): | class OFATrainer(EpochBasedTrainer): | ||||
| def __init__( | def __init__( | ||||
| @@ -93,7 +93,7 @@ class TestOfaTrainer(unittest.TestCase): | |||||
| split='validation[:10]'), | split='validation[:10]'), | ||||
| metrics=[Metrics.BLEU], | metrics=[Metrics.BLEU], | ||||
| cfg_file=config_file) | cfg_file=config_file) | ||||
| trainer = build_trainer(name=Trainers.ofa_tasks, default_args=args) | |||||
| trainer = build_trainer(name=Trainers.ofa, default_args=args) | |||||
| trainer.train() | trainer.train() | ||||
| self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE, | self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE, | ||||