
[to #42322933] Add S4: child-tuning

1. add the child-tuning optimizer (ChildTuningAdamW) and its unit test; a config sketch is shown below
2. fix a training bug that could interrupt training after a cross-evaluation
3. move model.params from cfg to default_args in build_optimizer, so the params are not written into the config saved by save_pretrained
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9891963
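For orientation, a minimal sketch of how the new optimizer is selected through the trainer configuration; the values mirror the cluewsc test added below and are illustrative, not defaults:

# Sketch only: the optimizer section of a trainer config after this change.
# 'mode' and 'reserve_p' are the ChildTuningAdamW-specific arguments.
cfg.train.optimizer = {
    'type': 'ChildTuningAdamW',
    'lr': 1e-5,
    'mode': 'ChildTuning-F',  # or 'ChildTuning-D' together with a Fisher mask
    'reserve_p': 0.2,
    'options': {},
}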
master · yuze.zyz, yingda.chen · 3 years ago · commit 88d0804dcd
6 changed files with 331 additions and 5 deletions:

  1. modelscope/trainers/optimizer/__init__.py (+2, -1)
  2. modelscope/trainers/optimizer/builder.py (+4, -1)
  3. modelscope/trainers/optimizer/child_tuning_adamw_optimizer.py (+188, -0)
  4. modelscope/trainers/trainer.py (+1, -0)
  5. tests/trainers/test_finetune_sequence_classification.py (+130, -2)
  6. tests/trainers/test_finetune_token_classificatin.py (+6, -1)

modelscope/trainers/optimizer/__init__.py (+2, -1)

@@ -1,4 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .builder import OPTIMIZERS, build_optimizer
+from .child_tuning_adamw_optimizer import ChildTuningAdamW

-__all__ = ['OPTIMIZERS', 'build_optimizer']
+__all__ = ['OPTIMIZERS', 'build_optimizer', 'ChildTuningAdamW']
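With the new export, user code can import the optimizer class directly from the optimizer package, e.g.:

from modelscope.trainers.optimizer import ChildTuningAdamW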

modelscope/trainers/optimizer/builder.py (+4, -1)

@@ -20,7 +20,10 @@ def build_optimizer(model: torch.nn.Module,
    """
    if hasattr(model, 'module'):
        model = model.module
-   cfg.params = model.parameters()
+
+   if default_args is None:
+       default_args = {}
+   default_args['params'] = model.parameters()

    return build_from_cfg(
        cfg, OPTIMIZERS, group_key=default_group, default_args=default_args)
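Because build_optimizer now feeds the parameters in through default_args, the optimizer entry of a configuration stays a plain, serializable mapping; nothing model-specific is written back into cfg, so save_pretrained can persist it untouched. A minimal usage sketch, assuming a (model, cfg, default_args=None) signature and that 'AdamW' is registered in OPTIMIZERS; the hyper-parameters are illustrative only:

import torch

from modelscope.trainers.optimizer import build_optimizer

model = torch.nn.Linear(16, 2)
# No 'params' key here: the model parameters are injected via default_args.
cfg = {'type': 'AdamW', 'lr': 1e-5, 'weight_decay': 0.01}
optimizer = build_optimizer(model, cfg)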


modelscope/trainers/optimizer/child_tuning_adamw_optimizer.py (+188, -0)

@@ -0,0 +1,188 @@
# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import types
from typing import Callable, Iterable, Tuple

import numpy as np
import torch
from torch.distributions.bernoulli import Bernoulli
from torch.optim import Optimizer

from modelscope.utils.logger import get_logger
from .builder import OPTIMIZERS, default_group

logger = get_logger(__name__)

__all__ = ['calculate_fisher', 'ChildTuningAdamW']


def calculate_fisher(model: torch.nn.Module,
                     data_loader,
                     forward_step,
                     reserve_p,
                     grad_clip=None):

    gradient_mask = dict()
    model.train()
    for name, params in model.named_parameters():
        if 'layer' in name:
            gradient_mask[params] = params.new_zeros(params.size())

    iters = len(data_loader)
    for inputs in data_loader:
        loss = forward_step(model, inputs)
        loss.backward()
        for name, params in model.named_parameters():
            if 'layer' in name:
                if grad_clip is not None:
                    torch.nn.utils.clip_grad_norm_(params, **grad_clip)
                gradient_mask[params] += (params.grad**2) / iters
        model.zero_grad()

    logger.info('Calculate Fisher Information...')

    # Numpy
    r = None
    for k, v in gradient_mask.items():
        v = v.view(-1).cpu().numpy()
        if r is None:
            r = v
        else:
            r = np.append(r, v)
    polar = np.percentile(r, (1 - reserve_p) * 100)
    for k in gradient_mask:
        gradient_mask[k] = gradient_mask[k] >= polar
    print('Polar => {}'.format(polar))

    # TODO: pytorch: torch.kthvalue

    return gradient_mask


@OPTIMIZERS.register_module(
    group_key=default_group, module_name='ChildTuningAdamW')
class ChildTuningAdamW(Optimizer):

    def __init__(self,
                 params: Iterable[torch.nn.parameter.Parameter],
                 lr: float = 1e-3,
                 betas: Tuple[float, float] = (0.9, 0.999),
                 eps: float = 1e-6,
                 weight_decay: float = 0.0,
                 correct_bias: bool = True,
                 reserve_p=1.0,
                 mode=None):
        if lr < 0.0:
            raise ValueError(
                'Invalid learning rate: {} - should be >= 0.0'.format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(
                'Invalid beta parameter: {} - should be in [0.0, 1.0['.format(
                    betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(
                'Invalid beta parameter: {} - should be in [0.0, 1.0['.format(
                    betas[1]))
        if not 0.0 <= eps:
            raise ValueError(
                'Invalid epsilon value: {} - should be >= 0.0'.format(eps))
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            correct_bias=correct_bias)
        super().__init__(params, defaults)

        self.gradient_mask = None
        self.reserve_p = reserve_p
        self.mode = mode

    def set_gradient_mask(self, gradient_mask):
        self.gradient_mask = gradient_mask

    def step(self, closure: Callable = None):
        """
        Performs a single optimization step.
        Arguments:
            closure (:obj:`Callable`, `optional`): A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        'Adam does not support sparse gradients, please consider SparseAdam instead'
                    )

                # ChildTuning code
                if self.mode is not None:
                    if self.mode == 'ChildTuning-D':
                        if p in self.gradient_mask:
                            grad *= self.gradient_mask[p]
                    else:
                        # ChildTuning-F
                        grad_mask = Bernoulli(
                            grad.new_full(
                                size=grad.size(), fill_value=self.reserve_p))
                        grad *= grad_mask.sample() / self.reserve_p

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Decay the first and second moment running average coefficient
                # In-place operations to update the averages at the same time
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                denom = exp_avg_sq.sqrt().add_(group['eps'])

                step_size = group['lr']
                if group['correct_bias']:  # No bias correction for Bert
                    bias_correction1 = 1.0 - beta1**state['step']
                    bias_correction2 = 1.0 - beta2**state['step']
                    step_size = step_size * math.sqrt(
                        bias_correction2) / bias_correction1

                p.data.addcdiv_(exp_avg, denom, value=-step_size)

                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                # Add weight decay at the end (fixed version)
                p.data.add_(p.data, alpha=-group['lr'] * group['weight_decay'])

        return loss
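A minimal, hedged sketch of how the two entry points above fit together outside the trainer. The toy model, data loader and forward_step are placeholders invented for the example (calculate_fisher only builds masks for parameters whose name contains 'layer', which matches transformer backbones, so the toy module names its layer accordingly). ChildTuning-D derives the child network from a Fisher-information mask, while ChildTuning-F only needs mode and reserve_p:

import torch
from torch.utils.data import DataLoader, TensorDataset

from modelscope.trainers.optimizer.child_tuning_adamw_optimizer import (
    ChildTuningAdamW, calculate_fisher)


class ToyModel(torch.nn.Module):
    """Placeholder backbone; 'layer' in the parameter name is deliberate."""

    def __init__(self):
        super().__init__()
        self.layer_0 = torch.nn.Linear(8, 2)

    def forward(self, x, y):
        return {'loss': torch.nn.functional.cross_entropy(self.layer_0(x), y)}


def forward_step(model, inputs):
    # Must return the scalar training loss for one batch.
    x, y = inputs
    return model(x, y)['loss']


model = ToyModel()
data_loader = DataLoader(
    TensorDataset(torch.randn(32, 8), torch.randint(0, 2, (32, ))),
    batch_size=8)

# ChildTuning-D: estimate the Fisher mask once, then hand it to the optimizer.
mask = calculate_fisher(model, data_loader, forward_step, reserve_p=0.2)
optimizer = ChildTuningAdamW(
    model.parameters(), lr=1e-5, reserve_p=0.2, mode='ChildTuning-D')
optimizer.set_gradient_mask(mask)

# ChildTuning-F needs no mask; gradients are thinned by a Bernoulli sample:
# optimizer = ChildTuningAdamW(model.parameters(), lr=1e-5,
#                              reserve_p=0.2, mode='ChildTuning-F')

for inputs in data_loader:
    loss = forward_step(model, inputs)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()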

modelscope/trainers/trainer.py (+1, -0)

@@ -800,6 +800,7 @@ class EpochBasedTrainer(BaseTrainer):
                self.invoke_hook(TrainerStages.after_train_iter)
                del self.data_batch
                self._iter += 1
+               self._mode = ModeKeys.TRAIN  # restore train mode; a cross-evaluation during the epoch can leave the trainer in EVAL

                if i + 1 >= self.iters_per_epoch:
                    break


tests/trainers/test_finetune_sequence_classification.py (+130, -2)

@@ -6,9 +6,15 @@ import unittest

from modelscope.metainfo import Preprocessors, Trainers
from modelscope.models import Model
from modelscope.msdatasets import MsDataset
from modelscope.pipelines import pipeline
from modelscope.trainers import build_trainer
from modelscope.trainers.hooks import Hook
from modelscope.trainers.nlp_trainer import NlpEpochBasedTrainer
from modelscope.trainers.optimizer.child_tuning_adamw_optimizer import \
calculate_fisher
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.data_utils import to_device


class TestFinetuneSequenceClassification(unittest.TestCase):
@@ -69,6 +75,10 @@ class TestFinetuneSequenceClassification(unittest.TestCase):

    @unittest.skip
    def test_finetune_afqmc(self):
        """This unittest is used to reproduce the clue:afqmc dataset + structbert model training results.

        Users can train on a custom dataset by modifying this piece of code and commenting out the @unittest.skip.
        """

        def cfg_modify_fn(cfg):
            cfg.task = Tasks.sentence_similarity
@@ -114,7 +124,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
        dc.local_files_only = True
        dataset = load_dataset('clue', 'afqmc', download_config=dc)
        self.finetune(
-           model_id='damo/nlp_structbert_backbone_tiny_std',
+           model_id='damo/nlp_structbert_backbone_base_std',
            train_dataset=dataset['train'],
            eval_dataset=dataset['validation'],
            cfg_modify_fn=cfg_modify_fn)
@@ -124,6 +134,10 @@ class TestFinetuneSequenceClassification(unittest.TestCase):

    @unittest.skip
    def test_finetune_tnews(self):
        """This unittest is used to reproduce the clue:tnews dataset + structbert model training results.

        Users can train on a custom dataset by modifying this piece of code and commenting out the @unittest.skip.
        """

        def cfg_modify_fn(cfg):
            # TODO no proper task for tnews
@@ -175,13 +189,21 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
        dataset = load_dataset('clue', 'tnews', download_config=dc)

        self.finetune(
-           model_id='damo/nlp_structbert_backbone_tiny_std',
+           model_id='damo/nlp_structbert_backbone_base_std',
            train_dataset=dataset['train'],
            eval_dataset=dataset['validation'],
            cfg_modify_fn=cfg_modify_fn)

    @unittest.skip
    def test_veco_xnli(self):
        """This unittest is used to reproduce the xnli dataset + veco model training results.

        Here we follow the training scenario listed in the AliceMind open source project:
        https://github.com/alibaba/AliceMind/tree/main/VECO
        by training on the English-language subset.
        Users can train on a custom dataset by modifying this piece of code and commenting out the @unittest.skip.
        """

        from datasets import load_dataset
        langs = ['en']
        langs_eval = ['en']
@@ -267,6 +289,112 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
            name=Trainers.nlp_veco_trainer,
            cfg_modify_fn=cfg_modify_fn)

    @unittest.skip
    def test_finetune_cluewsc(self):
        """This unittest is used to reproduce the clue:wsc dataset + structbert model training results.

        A runnable sample of child-tuning is also shown here.

        Users can train on a custom dataset by modifying this piece of code and commenting out the @unittest.skip.
        """

        child_tuning_type = 'ChildTuning-F'
        mode = {}
        if child_tuning_type is not None:
            mode = {'mode': child_tuning_type, 'reserve_p': 0.2}

        def cfg_modify_fn(cfg):
            cfg.task = 'nli'
            cfg['preprocessor'] = {'type': 'nli-tokenizer'}
            cfg['dataset'] = {
                'train': {
                    'labels': ['0', '1'],
                    'first_sequence': 'text',
                    'second_sequence': 'text2',
                    'label': 'label',
                }
            }
            cfg.train.dataloader.batch_size_per_gpu = 16
            cfg.train.max_epochs = 30
            cfg.train.optimizer = {
                'type':
                'AdamW' if child_tuning_type is None else 'ChildTuningAdamW',
                'lr': 1e-5,
                'options': {},
                **mode,
            }
            cfg.train.lr_scheduler = {
                'type':
                'LinearLR',
                'start_factor':
                1.0,
                'end_factor':
                0.0,
                'total_iters':
                int(
                    len(dataset['train'])
                    / cfg.train.dataloader.batch_size_per_gpu)
                * cfg.train.max_epochs,
                'options': {
                    'by_epoch': False
                }
            }
            cfg.train.hooks = [{
                'type': 'CheckpointHook',
                'interval': 1
            }, {
                'type': 'TextLoggerHook',
                'interval': 1
            }, {
                'type': 'IterTimerHook'
            }, {
                'type': 'EvaluationHook',
                'by_epoch': False,
                'interval': 30
            }]
            return cfg

        def add_sentence2(features):
            return {
                'text2':
                features['target']['span2_text'] + '指代'
                + features['target']['span1_text']
            }

        dataset = MsDataset.load('clue', subset_name='cluewsc2020')
        dataset = {
            k: v.to_hf_dataset().map(add_sentence2)
            for k, v in dataset.items()
        }

        kwargs = dict(
            model='damo/nlp_structbert_backbone_base_std',
            train_dataset=dataset['train'],
            eval_dataset=dataset['validation'],
            work_dir=self.tmp_dir,
            cfg_modify_fn=cfg_modify_fn)

        os.environ['LOCAL_RANK'] = '0'
        trainer: NlpEpochBasedTrainer = build_trainer(
            name=Trainers.nlp_base_trainer, default_args=kwargs)

        class CalculateFisherHook(Hook):

            @staticmethod
            def forward_step(model, inputs):
                inputs = to_device(inputs, trainer.device)
                trainer.train_step(model, inputs)
                return trainer.train_outputs['loss']

            def before_run(self, trainer: NlpEpochBasedTrainer):
                v = calculate_fisher(trainer.model, trainer.train_dataloader,
                                     self.forward_step, 0.2)
                trainer.optimizer.set_gradient_mask(v)

        if child_tuning_type == 'ChildTuning-D':
            trainer.register_hook(CalculateFisherHook())
        trainer.train()


if __name__ == '__main__':
    unittest.main()

tests/trainers/test_finetune_token_classificatin.py (+6, -1)

@@ -47,6 +47,11 @@ class TestFinetuneTokenClassification(unittest.TestCase):

    @unittest.skip
    def test_word_segmentation(self):
        """This unittest is used to reproduce the icwb2:pku dataset + structbert model training results.

        Users can train on a custom dataset by modifying this piece of code and commenting out the @unittest.skip.
        """

        os.system(
            f'curl http://sighan.cs.uchicago.edu/bakeoff2005/data/icwb2-data.zip > {self.tmp_dir}/icwb2-data.zip'
        )
@@ -114,7 +119,7 @@ class TestFinetuneTokenClassification(unittest.TestCase):
            return cfg

        self.finetune(
-           'damo/nlp_structbert_backbone_tiny_std',
+           'damo/nlp_structbert_backbone_base_std',
            train_dataset,
            dev_dataset,
            cfg_modify_fn=cfg_modify_fn)

