
Merge remote-tracking branch 'origin/ofa/finetune_loss' into ofa/finetune

# Conflicts:
#	tests/trainers/test_ofa_trainer.py
Branch: master
Author: 行嗔, 3 years ago
Commit: 0c64d3fca5
4 changed files with 23 additions and 17 deletions:
  1. modelscope/trainers/multi_modal/ofa/ofa_trainer.py (+3 / -4)
  2. modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py (+16 / -12)
  3. tests/trainers/test_ofa_trainer.py (+3 / -1)
  4. tests/trainers/workspace/ckpts/caption/configuration.json (+1 / -0)

modelscope/trainers/multi_modal/ofa/ofa_trainer.py (+3 / -4)

@@ -129,10 +129,9 @@ class OFATrainer(EpochBasedTrainer):
 
     def train_step(self, model, inputs):
         model.train()
-        model_outputs = model.forward(inputs)
-        loss, sample_size, logging_output = self.criterion(
-            model_outputs, inputs)
-        train_outputs = {'loss': loss}
+        # model_outputs = model.forward(inputs)
+        loss, sample_size, logging_output = self.criterion(model, inputs)
+        train_outputs = {'loss': loss / 100}
         # add model output info to log
         if 'log_vars' not in train_outputs:
             default_keys_pattern = ['loss']

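Note: the hunk above moves the forward pass out of train_step and into the criterion, which now receives the model itself; the returned loss is also scaled by 1/100 before logging. A minimal sketch of that contract, using hypothetical toy classes rather than the real OFATrainer/AdjustLabelSmoothedCrossEntropyCriterion:

# Sketch only: toy model and criterion, not the ModelScope implementations.
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyCriterion(nn.Module):
    def forward(self, model, sample):
        # The criterion owns the forward pass now, instead of receiving outputs.
        logits = model(sample['net_input'])
        loss = F.cross_entropy(logits, sample['target'], reduction='sum')
        sample_size = sample['target'].numel()
        logging_output = {'loss': loss.data / 100, 'ntokens': sample_size}
        return loss, sample_size, logging_output

def train_step(model, criterion, inputs):
    model.train()
    loss, sample_size, logging_output = criterion(model, inputs)
    return {'loss': loss / 100}  # same 1/100 scaling as in the diff

model = nn.Linear(8, 4)
sample = {'net_input': torch.randn(2, 8), 'target': torch.tensor([1, 3])}
print(train_step(model, ToyCriterion(), sample))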

modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py (+16 / -12)

@@ -123,7 +123,7 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
         self.padding_idx = args.tokenizer.pad_token_id
         self.args = args
 
-    def forward(self, output, sample, update_num=0, reduce=True):
+    def forward(self, model, sample, update_num=0, reduce=True):
         """Compute the loss for the given sample.
 
         Returns a tuple with three elements:
@@ -131,15 +131,20 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
+        if 'labels' in sample:
+            del sample['labels']
+        if 'samples' in sample:
+            del sample['samples']
        if self.use_rdrop:
            construct_rdrop_sample(sample)
 
+        output = model.model(**sample['net_input'])
        loss, nll_loss, ntokens = self.compute_loss(
-            output, sample, update_num, reduce=reduce)
+            output.logits, sample, update_num, reduce=reduce)
        sample_size = (
            sample['target'].size(0) if self.sentence_avg else ntokens)
        logging_output = {
-            'loss': loss.data,
+            'loss': loss.data / 100,
            'nll_loss': nll_loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['nsentences'],
@@ -147,19 +152,18 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
        }
        return loss, sample_size, logging_output
 
-    def get_lprobs_and_target(self, net_output, sample):
+    def get_lprobs_and_target(self, logits, sample):
        conf = sample['conf'][:, None, None] if 'conf' in sample and sample[
            'conf'] is not None else 1
        constraint_masks = None
        if 'constraint_masks' in sample and sample[
                'constraint_masks'] is not None:
            constraint_masks = sample['constraint_masks']
-            net_output[0].masked_fill_(~constraint_masks, -math.inf)
+            logits.masked_fill_(~constraint_masks, -math.inf)
        if self.constraint_start is not None and self.constraint_end is not None:
-            net_output[0][:, :, 4:self.constraint_start] = -math.inf
-            net_output[0][:, :, self.constraint_end:] = -math.inf
-        lprobs = F.log_softmax(
-            net_output[0], dim=-1, dtype=torch.float32) * conf
+            logits[:, :, 4:self.constraint_start] = -math.inf
+            logits[:, :, self.constraint_end:] = -math.inf
+        lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) * conf
        target = sample['target']
        if self.ignore_prefix_size > 0:
            lprobs = lprobs[:, self.ignore_prefix_size:, :].contiguous()
@@ -180,9 +184,9 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
        return lprobs.view(-1,
                           lprobs.size(-1)), target.view(-1), constraint_masks
 
-    def compute_loss(self, net_output, sample, update_num, reduce=True):
+    def compute_loss(self, logits, sample, update_num, reduce=True):
        lprobs, target, constraint_masks = self.get_lprobs_and_target(
-            net_output, sample)
+            logits, sample)
        if constraint_masks is not None:
            constraint_masks = constraint_masks[target != self.padding_idx]
        lprobs = lprobs[target != self.padding_idx]

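Note: a minimal sketch of what the refactored get_lprobs_and_target/compute_loss now do with raw logits (assumed shapes: logits [batch, seq, vocab], target [batch, seq]; simplified, not the library code): forbidden positions are masked to -inf before log_softmax, padding targets are dropped, and a label-smoothed NLL is computed.

# Sketch only; simplified relative to ofa_trainer_utils.py.
import math
import torch
import torch.nn.functional as F

def smoothed_loss(logits, target, constraint_masks=None, padding_idx=0, eps=0.1):
    if constraint_masks is not None:
        logits.masked_fill_(~constraint_masks, -math.inf)  # forbid masked tokens
    lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    lprobs = lprobs.view(-1, lprobs.size(-1))
    target = target.view(-1)
    keep = target != padding_idx                           # drop padding positions
    lprobs, target = lprobs[keep], target[keep]
    nll_loss = -lprobs.gather(1, target.unsqueeze(1)).squeeze(1)
    finite = torch.isfinite(lprobs)                        # ignore -inf entries
    smooth_loss = -(lprobs.masked_fill(~finite, 0.0).sum(-1) / finite.sum(-1))
    loss = (1.0 - eps) * nll_loss + eps * smooth_loss
    return loss.sum(), nll_loss.sum()

logits = torch.randn(2, 3, 10)
target = torch.randint(1, 10, (2, 3))
print(smoothed_loss(logits, target))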

tests/trainers/test_ofa_trainer.py (+3 / -1)

@@ -1,5 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import glob
 import os
+import os.path as osp
 import shutil
 import unittest


@@ -57,7 +59,7 @@ class TestOfaTrainer(unittest.TestCase):
                'report_accuracy': False,
                'sample_patch_num': 196,
                'sentence_avg': False,
-                'use_rdrop': False},
+                'use_rdrop': True},
            'hooks': [{'type': 'BestCkptSaverHook',
                       'metric_key': 'bleu-4',
                       'interval': 100},

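Note: the test config now turns on use_rdrop, so the criterion calls construct_rdrop_sample before computing the loss. A hedged sketch of the general R-Drop idea (the actual helpers in ofa_trainer_utils.py may differ): each sample is duplicated so dropout yields two independent forward passes, and a symmetric KL term (weighted by reg_alpha) penalizes disagreement between them.

# Sketch of the R-Drop idea only; not the real construct_rdrop_sample.
import torch
import torch.nn.functional as F

def duplicate_batch(t):
    # Repeat every example so dropout produces two independent passes.
    return torch.cat([t, t], dim=0)

def rdrop_kl(lprobs, reg_alpha=1.0):
    # Split the doubled batch and penalize divergence between the two passes.
    p, q = lprobs.chunk(2, dim=0)
    kl = F.kl_div(p, q, log_target=True, reduction='sum') \
        + F.kl_div(q, p, log_target=True, reduction='sum')
    return reg_alpha * kl / 2

lprobs = F.log_softmax(torch.randn(4, 10), dim=-1)  # doubled batch of 2
print(rdrop_kl(lprobs))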

tests/trainers/workspace/ckpts/caption/configuration.json (+1 / -0)

@@ -0,0 +1 @@
{"framework": "pytorch", "task": "image-captioning", "model": {"type": "ofa", "beam_search": {"beam_size": 5, "max_len_b": 16, "min_len": 1, "no_repeat_ngram_size": 0}, "seed": 7, "max_src_length": 256, "language": "en", "gen_type": "generation", "patch_image_size": 480, "max_image_size": 480, "imagenet_default_mean_and_std": false}, "pipeline": {"type": "image-captioning"}, "dataset": {"column_map": {"text": "caption"}}, "train": {"work_dir": "work/ckpts/caption", "max_epochs": 1, "use_fp16": true, "dataloader": {"batch_size_per_gpu": 4, "workers_per_gpu": 0}, "lr_scheduler": {"name": "polynomial_decay", "warmup_proportion": 0.01, "lr_end": 1e-07}, "lr_scheduler_hook": {"type": "LrSchedulerHook", "by_epoch": false}, "optimizer": {"type": "AdamW", "lr": 5e-05, "weight_decay": 0.01}, "optimizer_hook": {"type": "TorchAMPOptimizerHook", "cumulative_iters": 1, "grad_clip": {"max_norm": 1.0, "norm_type": 2}, "loss_keys": "loss"}, "criterion": {"name": "AdjustLabelSmoothedCrossEntropyCriterion", "constraint_range": null, "drop_worst_after": 0, "drop_worst_ratio": 0.0, "ignore_eos": false, "ignore_prefix_size": 0, "label_smoothing": 0.0, "reg_alpha": 1.0, "report_accuracy": false, "sample_patch_num": 196, "sentence_avg": false, "use_rdrop": true}, "hooks": [{"type": "BestCkptSaverHook", "metric_key": "bleu-4", "interval": 100}, {"type": "TextLoggerHook", "interval": 1}, {"type": "IterTimerHook"}, {"type": "EvaluationHook", "by_epoch": true, "interval": 1}]}, "evaluation": {"dataloader": {"batch_size_per_gpu": 4, "workers_per_gpu": 0}, "metrics": [{"type": "bleu", "eval_tokenized_bleu": false, "ref_name": "labels", "hyp_name": "caption"}]}, "preprocessor": []}

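Note: a quick sanity check (hedged; plain json from the standard library, run from the repo root) that the committed configuration parses and exposes the criterion arguments the trainer reads:

import json

with open('tests/trainers/workspace/ckpts/caption/configuration.json') as f:
    cfg = json.load(f)
print(cfg['train']['criterion']['use_rdrop'])        # True
print(cfg['train']['criterion']['label_smoothing'])  # 0.0
print(cfg['train']['optimizer']['lr'])               # 5e-05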