
modify ofatrainer

master
翎航 3 years ago
commit 9b8cfc4ece
2 changed files with 40 additions and 10 deletions
  1. modelscope/trainers/multi_modal/ofa/ofa_trainer.py (+8 -7)
  2. tests/trainers/test_ofa_trainer.py (+32 -3)

modelscope/trainers/multi_modal/ofa/ofa_trainer.py (+8 -7)

@@ -24,12 +24,13 @@ from .ofa_trainer_utils import (AdjustLabelSmoothedCrossEntropyCriterion,
 @TRAINERS.register_module(module_name=Trainers.ofa_tasks)
 class OFATrainer(EpochBasedTrainer):
 
-    def __init__(self, model: str, *args, **kwargs):
+    def __init__(self, model: str, cfg_file, work_dir, train_dataset,
+                 eval_dataset, *args, **kwargs):
         model = Model.from_pretrained(model)
         model_dir = model.model_dir
-        cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION)
+        # cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION)
         cfg = Config.from_file(cfg_file)
-        dataset = self._build_dataset_with_config(cfg)
+        # dataset = self._build_dataset_with_config(cfg)
         preprocessor = {
             ConfigKeys.train:
             OfaPreprocessor(
@@ -41,7 +42,7 @@ class OFATrainer(EpochBasedTrainer):
         # use torchrun launch
         world_size = int(os.environ.get('WORLD_SIZE', 1))
         epoch_steps = math.ceil(
-            len(dataset['train']) /  # noqa
+            len(train_dataset) /  # noqa
             (cfg.train.dataloader.batch_size_per_gpu * world_size))  # noqa
         cfg.train.lr_scheduler.num_train_steps = epoch_steps * cfg.train.max_epochs
         cfg.train.criterion.tokenizer = model.tokenizer
@@ -68,11 +69,11 @@ class OFATrainer(EpochBasedTrainer):
             cfg_file=cfg_file,
             model=model,
             data_collator=collator,
-            train_dataset=dataset['train'],
-            eval_dataset=dataset['valid'],
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
             preprocessor=preprocessor,
             optimizers=(optimizer, lr_scheduler),
-            work_dir=cfg.train.work_dir,
+            work_dir=work_dir,
             *args,
             **kwargs,
         )
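This commit moves configuration lookup and dataset construction out of OFATrainer: the caller now passes cfg_file, work_dir, and the train/eval datasets explicitly, and the scheduler length is derived from the dataset handed in. A minimal sketch of that step arithmetic, with illustrative stand-in values (the sample count, batch size, and epoch count below are assumptions, not taken from the source):

import math
import os

# Illustrative stand-ins; in OFATrainer these come from the training config
# (cfg.train.dataloader.batch_size_per_gpu, cfg.train.max_epochs) and from
# the torchrun environment (WORLD_SIZE).
num_train_samples = 100                           # len(train_dataset)
batch_size_per_gpu = 4                            # assumed batch size
world_size = int(os.environ.get('WORLD_SIZE', 1))
max_epochs = 3                                    # assumed epoch count

# Each step consumes batch_size_per_gpu samples on each of world_size ranks,
# so one epoch needs ceil(samples / (batch * world_size)) steps.
epoch_steps = math.ceil(num_train_samples / (batch_size_per_gpu * world_size))
num_train_steps = epoch_steps * max_epochs        # fed to the lr scheduler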


tests/trainers/test_ofa_trainer.py (+32 -3)

@@ -3,22 +3,51 @@ import glob
 import os
 import os.path as osp
 import shutil
+import tempfile
 import unittest
 
 from modelscope.metainfo import Trainers
+from modelscope.msdatasets import MsDataset
 from modelscope.trainers import build_trainer
+from modelscope.utils.constant import DownloadMode
 from modelscope.utils.test_utils import test_level
 
 
 class TestOfaTrainer(unittest.TestCase):
 
+    def setUp(self):
+        column_map = {'premise': 'text', 'hypothesis': 'text2'}
+        data_train = MsDataset.load(
+            dataset_name='glue',
+            subset_name='mnli',
+            namespace='modelscope',
+            split='train[:100]',
+            download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)
+        self.train_dataset = MsDataset.from_hf_dataset(
+            data_train._hf_ds.rename_columns(column_map))
+        data_eval = MsDataset.load(
+            dataset_name='glue',
+            subset_name='mnli',
+            namespace='modelscope',
+            split='validation_matched[:8]',
+            download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)
+        self.test_dataset = MsDataset.from_hf_dataset(
+            data_eval._hf_ds.rename_columns(column_map))
+
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_trainer(self):
         os.environ['LOCAL_RANK'] = '0'
         model_id = 'damo/ofa_text-classification_mnli_large_en'
-        default_args = {'model': model_id}
-        trainer = build_trainer(
-            name=Trainers.ofa_tasks, default_args=default_args)
+
+        kwargs = dict(
+            model=model_id,
+            cfg_file=
+            '/Users/running_you/.cache/modelscope/hub/damo/ofa_text-classification_mnli_large_en//configuration.json',
+            train_dataset=self.train_dataset,
+            eval_dataset=self.test_dataset,
+            work_dir='/Users/running_you/.cache/modelscope/hub/work/mnli')
+
+        trainer = build_trainer(name=Trainers.ofa_tasks, default_args=kwargs)
         os.makedirs(trainer.work_dir, exist_ok=True)
         trainer.train()
         assert len(
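The hunk above pins cfg_file and work_dir to per-user /Users/running_you paths, even though it adds import tempfile. A more portable sketch of the same construction, wrapped in a hypothetical helper build_ofa_trainer and assuming modelscope's snapshot_download resolves the cached model directory (both the helper and the configuration.json location are assumptions, not part of this commit):

import os
import tempfile

from modelscope.hub.snapshot_download import snapshot_download  # assumed API
from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer

def build_ofa_trainer(train_dataset, eval_dataset):
    # Resolve the cached model snapshot instead of hardcoding a per-user path.
    model_id = 'damo/ofa_text-classification_mnli_large_en'
    model_dir = snapshot_download(model_id)
    kwargs = dict(
        model=model_id,
        cfg_file=os.path.join(model_dir, 'configuration.json'),
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        work_dir=tempfile.mkdtemp())  # throwaway work dir for the test run
    return build_trainer(name=Trainers.ofa_tasks, default_args=kwargs)

Inside test_trainer, the datasets built in setUp() (self.train_dataset, self.test_dataset) would be passed straight through to this helper.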

