diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 01b08699..f77ff299 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -281,6 +281,7 @@ class Trainers(object): # multi-modal trainers clip_multi_modal_embedding = 'clip-multi-modal-embedding' + ofa = 'ofa' # cv trainers image_instance_segmentation = 'image-instance-segmentation' @@ -375,6 +376,9 @@ class Metrics(object): accuracy = 'accuracy' audio_noise_metric = 'audio-noise-metric' + # text gen + BLEU = 'bleu' + # metrics for image denoise task image_denoise_metric = 'image-denoise-metric' @@ -395,6 +399,8 @@ class Metrics(object): movie_scene_segmentation_metric = 'movie-scene-segmentation-metric' # metric for inpainting task image_inpainting_metric = 'image-inpainting-metric' + # metric for ocr + NED = 'ned' class Optimizers(object): diff --git a/modelscope/metrics/__init__.py b/modelscope/metrics/__init__.py index e6a03a22..c022eaf4 100644 --- a/modelscope/metrics/__init__.py +++ b/modelscope/metrics/__init__.py @@ -17,6 +17,8 @@ if TYPE_CHECKING: from .token_classification_metric import TokenClassificationMetric from .video_summarization_metric import VideoSummarizationMetric from .movie_scene_segmentation_metric import MovieSceneSegmentationMetric + from .accuracy_metric import AccuracyMetric + from .bleu_metric import BleuMetric from .image_inpainting_metric import ImageInpaintingMetric else: @@ -36,6 +38,8 @@ else: 'video_summarization_metric': ['VideoSummarizationMetric'], 'movie_scene_segmentation_metric': ['MovieSceneSegmentationMetric'], 'image_inpainting_metric': ['ImageInpaintingMetric'], + 'accuracy_metric': ['AccuracyMetric'], + 'bleu_metric': ['BleuMetric'], } import sys diff --git a/modelscope/metrics/accuracy_metric.py b/modelscope/metrics/accuracy_metric.py new file mode 100644 index 00000000..1761786e --- /dev/null +++ b/modelscope/metrics/accuracy_metric.py @@ -0,0 +1,46 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Dict + +import numpy as np + +from modelscope.metainfo import Metrics +from modelscope.outputs import OutputKeys +from modelscope.utils.registry import default_group +from .base import Metric +from .builder import METRICS, MetricKeys + + +@METRICS.register_module(group_key=default_group, module_name=Metrics.accuracy) +class AccuracyMetric(Metric): + """The metric computation class for classification classes. + + This metric class calculates accuracy for the whole input batches. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.preds = [] + self.labels = [] + + def add(self, outputs: Dict, inputs: Dict): + label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS + ground_truths = inputs[label_name] + eval_results = outputs[label_name] + assert type(ground_truths) == type(eval_results) + if isinstance(ground_truths, list): + self.preds.extend(eval_results) + self.labels.extend(ground_truths) + elif isinstance(ground_truths, np.ndarray): + self.preds.extend(eval_results.tolist()) + self.labels.extend(ground_truths.tolist()) + else: + raise 'only support list or np.ndarray' + + def evaluate(self): + assert len(self.preds) == len(self.labels) + return { + MetricKeys.ACCURACY: (np.asarray([ + pred == ref for pred, ref in zip(self.preds, self.labels) + ])).mean().item() + } diff --git a/modelscope/metrics/bleu_metric.py b/modelscope/metrics/bleu_metric.py new file mode 100644 index 00000000..7c134b6a --- /dev/null +++ b/modelscope/metrics/bleu_metric.py @@ -0,0 +1,42 @@ +from itertools import zip_longest +from typing import Dict + +import sacrebleu + +from modelscope.metainfo import Metrics +from modelscope.utils.registry import default_group +from .base import Metric +from .builder import METRICS, MetricKeys + +EVAL_BLEU_ORDER = 4 + + +@METRICS.register_module(group_key=default_group, module_name=Metrics.BLEU) +class BleuMetric(Metric): + """The metric computation bleu for text generation classes. + + This metric class calculates accuracy for the whole input batches. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.eval_tokenized_bleu = kwargs.get('eval_tokenized_bleu', False) + self.hyp_name = kwargs.get('hyp_name', 'hyp') + self.ref_name = kwargs.get('ref_name', 'ref') + self.refs = list() + self.hyps = list() + + def add(self, outputs: Dict, inputs: Dict): + self.refs.extend(inputs[self.ref_name]) + self.hyps.extend(outputs[self.hyp_name]) + + def evaluate(self): + if self.eval_tokenized_bleu: + bleu = sacrebleu.corpus_bleu( + self.hyps, list(zip_longest(*self.refs)), tokenize='none') + else: + bleu = sacrebleu.corpus_bleu(self.hyps, + list(zip_longest(*self.refs))) + return { + MetricKeys.BLEU_4: bleu.score, + } diff --git a/modelscope/metrics/builder.py b/modelscope/metrics/builder.py index 1c8e16d7..da3b64c7 100644 --- a/modelscope/metrics/builder.py +++ b/modelscope/metrics/builder.py @@ -23,6 +23,7 @@ class MetricKeys(object): BLEU_4 = 'bleu-4' ROUGE_1 = 'rouge-1' ROUGE_L = 'rouge-l' + NED = 'ned' # ocr metric task_default_metrics = { diff --git a/modelscope/metrics/ciderD/__init__.py b/modelscope/metrics/ciderD/__init__.py new file mode 100755 index 00000000..3f7d85bb --- /dev/null +++ b/modelscope/metrics/ciderD/__init__.py @@ -0,0 +1 @@ +__author__ = 'tylin' diff --git a/modelscope/metrics/ciderD/ciderD.py b/modelscope/metrics/ciderD/ciderD.py new file mode 100755 index 00000000..05c7eb23 --- /dev/null +++ b/modelscope/metrics/ciderD/ciderD.py @@ -0,0 +1,57 @@ +# Filename: ciderD.py +# +# Description: Describes the class to compute the CIDEr-D (Consensus-Based Image Description Evaluation) Metric +# by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) +# +# Creation Date: Sun Feb 8 14:16:54 2015 +# +# Authors: Ramakrishna Vedantam and Tsung-Yi Lin +from __future__ import absolute_import, division, print_function + +from .ciderD_scorer import CiderScorer + + +class CiderD: + """ + Main Class to compute the CIDEr metric + + """ + + def 
__init__(self, n=4, sigma=6.0, df='corpus'): + # set cider to sum over 1 to 4-grams + self._n = n + # set the standard deviation parameter for gaussian penalty + self._sigma = sigma + # set which where to compute document frequencies from + self._df = df + self.cider_scorer = CiderScorer(n=self._n, df_mode=self._df) + + def compute_score(self, gts, res): + """ + Main function to compute CIDEr score + :param hypo_for_image (dict) : dictionary with key and value + ref_for_image (dict) : dictionary with key and value + :return: cider (float) : computed CIDEr score for the corpus + """ # noqa + + # clear all the previous hypos and refs + tmp_cider_scorer = self.cider_scorer.copy_empty() + tmp_cider_scorer.clear() + for res_id in res: + + hypo = res_id['caption'] + ref = gts[res_id['image_id']] + + # Sanity check. + assert (type(hypo) is list) + assert (len(hypo) == 1) + assert (type(ref) is list) + assert (len(ref) > 0) + tmp_cider_scorer += (hypo[0], ref) + + (score, scores) = tmp_cider_scorer.compute_score() + + return score, scores + + def method(self): + return 'CIDEr-D' diff --git a/modelscope/metrics/ciderD/ciderD_scorer.py b/modelscope/metrics/ciderD/ciderD_scorer.py new file mode 100755 index 00000000..4157ec11 --- /dev/null +++ b/modelscope/metrics/ciderD/ciderD_scorer.py @@ -0,0 +1,233 @@ +#!/usr/bin/env python +# Tsung-Yi Lin +# Ramakrishna Vedantam +from __future__ import absolute_import, division, print_function +import copy +import math +import os +import pdb +from collections import defaultdict + +import numpy as np +import six +from six.moves import cPickle + + +def precook(s, n=4, out=False): + """ + Takes a string as input and returns an object that can be given to + either cook_refs or cook_test. This is optional: cook_refs and cook_test + can take string arguments as well. + :param s: string : sentence to be converted into ngrams + :param n: int : number of ngrams for which representation is calculated + :return: term frequency vector for occuring ngrams + """ + words = s.split() + counts = defaultdict(int) + for k in range(1, n + 1): + for i in range(len(words) - k + 1): + ngram = tuple(words[i:i + k]) + counts[ngram] += 1 + return counts + + +def cook_refs(refs, n=4): # lhuang: oracle will call with "average" + '''Takes a list of reference sentences for a single segment + and returns an object that encapsulates everything that BLEU + needs to know about them. + :param refs: list of string : reference sentences for some image + :param n: int : number of ngrams for which (ngram) representation is calculated + :return: result (list of dict) + ''' + return [precook(ref, n) for ref in refs] + + +def cook_test(test, n=4): + '''Takes a test sentence and returns an object that + encapsulates everything that BLEU needs to know about it. + :param test: list of string : hypothesis sentence for some image + :param n: int : number of ngrams for which (ngram) representation is calculated + :return: result (dict) + ''' + return precook(test, n, True) + + +class CiderScorer(object): + """CIDEr scorer. 
+ """ + + def copy(self): + ''' copy the refs.''' + new = CiderScorer(n=self.n) + new.ctest = copy.copy(self.ctest) + new.crefs = copy.copy(self.crefs) + return new + + def copy_empty(self): + new = CiderScorer(df_mode='corpus', n=self.n, sigma=self.sigma) + new.df_mode = self.df_mode + new.ref_len = self.ref_len + new.document_frequency = self.document_frequency + return new + + def __init__(self, df_mode='corpus', test=None, refs=None, n=4, sigma=6.0): + ''' singular instance ''' + self.n = n + self.sigma = sigma + self.crefs = [] + self.ctest = [] + self.df_mode = df_mode + self.ref_len = None + if self.df_mode != 'corpus': + pkl_file = cPickle.load( + open(df_mode, 'rb'), + **(dict(encoding='latin1') if six.PY3 else {})) + self.ref_len = np.log(float(pkl_file['ref_len'])) + self.document_frequency = pkl_file['document_frequency'] + else: + self.document_frequency = None + self.cook_append(test, refs) + + def clear(self): + self.crefs = [] + self.ctest = [] + + def cook_append(self, test, refs): + '''called by constructor and __iadd__ to avoid creating new instances.''' + + if refs is not None: + self.crefs.append(cook_refs(refs)) + if test is not None: + self.ctest.append(cook_test(test)) # N.B.: -1 + else: + self.ctest.append( + None) # lens of crefs and ctest have to match + + def size(self): + assert len(self.crefs) == len( + self.ctest), 'refs/test mismatch! %d<>%d' % (len( + self.crefs), len(self.ctest)) + return len(self.crefs) + + def __iadd__(self, other): + '''add an instance (e.g., from another sentence).''' + + if type(other) is tuple: + # avoid creating new CiderScorer instances + self.cook_append(other[0], other[1]) + else: + self.ctest.extend(other.ctest) + self.crefs.extend(other.crefs) + + return self + + def compute_doc_freq(self): + """ + Compute term frequency for reference data. + This will be used to compute idf (inverse document frequency later) + The term frequency is stored in the object + :return: None + """ + for refs in self.crefs: + # refs, k ref captions of one image + for ngram in set([ + ngram for ref in refs for (ngram, count) in ref.items() + ]): # noqa + self.document_frequency[ngram] += 1 + + def compute_cider(self): + + def counts2vec(cnts): + """ + Function maps counts of ngram to vector of tfidf weights. + The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights. + The n-th entry of array denotes length of n-grams. + :param cnts: + :return: vec (array of dict), norm (array of float), length (int) + """ + vec = [defaultdict(float) for _ in range(self.n)] + length = 0 + norm = [0.0 for _ in range(self.n)] + for (ngram, term_freq) in cnts.items(): + # give word count 1 if it doesn't appear in reference corpus + df = np.log(max(1.0, self.document_frequency[ngram])) + # ngram index + n = len(ngram) - 1 + # tf (term_freq) * idf (precomputed idf) for n-grams + vec[n][ngram] = float(term_freq) * (self.ref_len - df) + # compute norm for the vector. the norm will be used for computing similarity + norm[n] += pow(vec[n][ngram], 2) + + if n == 1: + length += term_freq + norm = [np.sqrt(n) for n in norm] + return vec, norm, length + + def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): + ''' + Compute the cosine similarity of two vectors. 
+ :param vec_hyp: array of dictionary for vector corresponding to hypothesis + :param vec_ref: array of dictionary for vector corresponding to reference + :param norm_hyp: array of float for vector corresponding to hypothesis + :param norm_ref: array of float for vector corresponding to reference + :param length_hyp: int containing length of hypothesis + :param length_ref: int containing length of reference + :return: array of score for each n-grams cosine similarity + ''' + delta = float(length_hyp - length_ref) + # measure consine similarity + val = np.array([0.0 for _ in range(self.n)]) + for n in range(self.n): + # ngram + for (ngram, count) in vec_hyp[n].items(): + # vrama91 : added clipping + val[n] += min(vec_hyp[n][ngram], + vec_ref[n][ngram]) * vec_ref[n][ngram] + + if (norm_hyp[n] != 0) and (norm_ref[n] != 0): + val[n] /= (norm_hyp[n] * norm_ref[n]) + + assert (not math.isnan(val[n])) + # vrama91: added a length based gaussian penalty + val[n] *= np.e**(-(delta**2) / (2 * self.sigma**2)) + return val + + # compute log reference length + if self.df_mode == 'corpus': + self.ref_len = np.log(float(len(self.crefs))) + # elif self.df_mode == "coco-val-df": + # if coco option selected, use length of coco-val set + # self.ref_len = np.log(float(40504)) + + scores = [] + for test, refs in zip(self.ctest, self.crefs): + # compute vector for test captions + vec, norm, length = counts2vec(test) + # compute vector for ref captions + score = np.array([0.0 for _ in range(self.n)]) + for ref in refs: + vec_ref, norm_ref, length_ref = counts2vec(ref) + score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) + # change by vrama91 - mean of ngram scores, instead of sum + score_avg = np.mean(score) + # divide by number of references + score_avg /= len(refs) + # multiply score by 10 + score_avg *= 10.0 + # append score of an image to the score list + scores.append(score_avg) + return scores + + def compute_score(self, option=None, verbose=0): + # compute idf + if self.df_mode == 'corpus': + self.document_frequency = defaultdict(float) + self.compute_doc_freq() + # assert to check document frequency + assert (len(self.ctest) >= max(self.document_frequency.values())) + # import json for now and write the corresponding files + # compute cider score + score = self.compute_cider() + # debug + # print score + return np.mean(np.array(score)), np.array(score) diff --git a/modelscope/models/multi_modal/ofa/adaptor/__init__.py b/modelscope/models/multi_modal/ofa/adaptor/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/multi_modal/ofa/generate/search.py b/modelscope/models/multi_modal/ofa/generate/search.py index 63ecb0a9..0dcaf6b3 100644 --- a/modelscope/models/multi_modal/ofa/generate/search.py +++ b/modelscope/models/multi_modal/ofa/generate/search.py @@ -148,7 +148,7 @@ class BeamSearch(Search): scores_buf = top_prediction[0] indices_buf = top_prediction[1] # Project back into relative indices and beams - beams_buf = indices_buf // vocab_size + beams_buf = torch.div(indices_buf, vocab_size, rounding_mode='floor') indices_buf = indices_buf.fmod(vocab_size) # At this point, beams_buf and indices_buf are single-dim and contain relative indices diff --git a/modelscope/models/multi_modal/ofa/generate/sequence_generator.py b/modelscope/models/multi_modal/ofa/generate/sequence_generator.py index 590fb67b..e42d3c8e 100644 --- a/modelscope/models/multi_modal/ofa/generate/sequence_generator.py +++ b/modelscope/models/multi_modal/ofa/generate/sequence_generator.py @@ 
-385,12 +385,7 @@ class SequenceGenerator(nn.Module): attn = torch.empty(bsz * beam_size, avg_attn_scores.size(1), max_len + 2).to(scores) - # print("+++++++ debug attention shape +++++++") - # print("attn", attn.shape) - # print("avg_attn_scores", avg_attn_scores.shape) attn[:, :, step + 1].copy_(avg_attn_scores) - # print("attn[:, :, step + 1]", attn[:, :, step + 1].shape) - # print("attn", attn.shape) scores = scores.type_as(lprobs) eos_bbsz_idx = torch.empty(0).to( @@ -404,8 +399,28 @@ class SequenceGenerator(nn.Module): self.search.set_src_lengths(src_lengths) if self.repeat_ngram_blocker is not None: - lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, - beam_size, step) + # process prefix_tokens + p_toks_len = prefix_tokens.ne(self.pad).sum( + dim=1) if prefix_tokens is not None else None + if p_toks_len is not None: + p_toks_len_beam = p_toks_len.unsqueeze(-1).repeat( + 1, beam_size).view(-1) + no_repeat_ngram_size = self.repeat_ngram_blocker.no_repeat_ngram_size + out_prefix = p_toks_len_beam < ( + step + no_repeat_ngram_size - 1) + else: + out_prefix = torch.ones(bsz * beam_size).bool() + ngram_blocker_tokens = tokens[out_prefix] + ngram_blocker_lprobs = lprobs[out_prefix] + ngram_blocker_bsz = torch.div( + out_prefix.sum(), beam_size, rounding_mode='trunc') + + lprobs[out_prefix] = self.repeat_ngram_blocker( + tokens=ngram_blocker_tokens, + lprobs=ngram_blocker_lprobs, + bsz=ngram_blocker_bsz, + beam_size=beam_size, + step=step) # Shape: (batch, cand_size) cand_scores, cand_indices, cand_beams = self.search.step( @@ -415,7 +430,6 @@ class SequenceGenerator(nn.Module): tokens[:, :step + 1], original_batch_idxs, ) - # cand_bbsz_idx contains beam indices for the top candidate # hypotheses, with a range of values: [0, bsz*beam_size), # and dimensions: [bsz, cand_size] @@ -671,7 +685,7 @@ class SequenceGenerator(nn.Module): cum_unfin.append(prev) cum_fin_tensor = torch.tensor(cum_unfin, dtype=torch.int).to(bbsz_idx) - unfin_idx = bbsz_idx // beam_size + unfin_idx = torch.div(bbsz_idx, beam_size, rounding_mode='floor') sent = unfin_idx + torch.index_select(cum_fin_tensor, 0, unfin_idx) # Create a set of "{sent}{unfin_idx}", where diff --git a/modelscope/models/multi_modal/ofa/modeling_ofa.py b/modelscope/models/multi_modal/ofa/modeling_ofa.py index 01cc02f9..0a7a2ce6 100755 --- a/modelscope/models/multi_modal/ofa/modeling_ofa.py +++ b/modelscope/models/multi_modal/ofa/modeling_ofa.py @@ -19,6 +19,7 @@ from dataclasses import dataclass from typing import Dict, List, Optional, Tuple import torch +from packaging import version from torch import Tensor, nn from torch.nn import functional as F from transformers.activations import ACT2FN @@ -40,6 +41,8 @@ logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = 'ofa-base' _CONFIG_FOR_DOC = 'OFAConfig' _TOKENIZER_FOR_DOC = 'OFATokenizer' +TORCH_VERSION = version.parse(torch.__version__) +TORCH_MESH_GRID_WARNING_VERSION = version.parse('1.9.1') DEFAULT_MAX_SOURCE_POSITIONS = 1024 DEFAULT_MAX_TARGET_POSITIONS = 1024 @@ -51,6 +54,7 @@ OFA_PRETRAINED_MODEL_ARCHIVE_LIST = [ 'ofa-medium', 'ofa-base', 'ofa-large', + 'ofa-huge', ] try: @@ -114,7 +118,11 @@ def make_image_bucket_position(bucket_size, num_relative_distance): """ coords_h = torch.arange(bucket_size) coords_w = torch.arange(bucket_size) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + if TORCH_VERSION > TORCH_MESH_GRID_WARNING_VERSION: + coords = torch.stack( + torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww + else: + coords = 
torch.stack(torch.meshgrid([coords_h, coords_w])) coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - \ coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww diff --git a/modelscope/models/multi_modal/ofa/utils/constant.py b/modelscope/models/multi_modal/ofa/utils/constant.py index d3257383..b3776f8f 100644 --- a/modelscope/models/multi_modal/ofa/utils/constant.py +++ b/modelscope/models/multi_modal/ofa/utils/constant.py @@ -8,7 +8,7 @@ OFA_TASK_KEY_MAPPING = { Tasks.text_summarization: OutputKeys.TEXT, Tasks.visual_question_answering: OutputKeys.TEXT, Tasks.visual_grounding: OutputKeys.BOXES, - Tasks.text_classification: (OutputKeys.SCORES, OutputKeys.LABELS), + Tasks.text_classification: OutputKeys.LABELS, Tasks.image_classification: OutputKeys.LABELS, - Tasks.visual_entailment: (OutputKeys.SCORES, OutputKeys.LABELS), + Tasks.visual_entailment: OutputKeys.LABELS, } diff --git a/modelscope/models/multi_modal/ofa_for_all_tasks.py b/modelscope/models/multi_modal/ofa_for_all_tasks.py index 6e331228..56d19ad8 100644 --- a/modelscope/models/multi_modal/ofa_for_all_tasks.py +++ b/modelscope/models/multi_modal/ofa_for_all_tasks.py @@ -1,8 +1,10 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import math +import os import string +from functools import partial from os import path as osp -from typing import Any, Dict +from typing import Any, Callable, Dict, List, Optional, Union import json import torch.cuda @@ -10,7 +12,6 @@ import torch.nn.functional as F from modelscope.metainfo import Models from modelscope.models import TorchModel -from modelscope.models.base import Tensor from modelscope.models.builder import MODELS from modelscope.outputs import OutputKeys from modelscope.preprocessors.ofa.utils.collate import collate_tokens @@ -66,10 +67,9 @@ class OfaForAllTasks(TorchModel): self.gen_type = self.cfg.model.get('gen_type', 'generation') assert self.gen_type in ['generation', 'traverse'], \ 'model.gen_type must be in ["generation", "traverse"]' - self._device = torch.device('cuda') if torch.cuda.is_available() \ - else torch.device('cpu') - self.eos_item = torch.LongTensor([self.tokenizer.eos_token_id - ]).to(self._device) + self.bos_item = torch.LongTensor([self.tokenizer.bos_token_id]) + self.pad_item = torch.LongTensor([self.tokenizer.pad_token_id]) + self.eos_item = torch.LongTensor([self.tokenizer.eos_token_id]) self.index2ans = {} self.ans2label_dict = {} self.load_ans2label() @@ -90,7 +90,8 @@ class OfaForAllTasks(TorchModel): self.val_masks_l = [] self.build_trie() sg_args['constraint_trie'] = self.constraint_trie - self.model.to(self._device) + else: + self.constraint_trie = None self.generator = sg.SequenceGenerator(**sg_args) inference_d = { 'generation': self._text_gen_inference, @@ -108,8 +109,16 @@ class OfaForAllTasks(TorchModel): } def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + input = move_to_device(input, self.model.device) + if self.model.training: + return self.model(**input['net_input']) + else: + return self.inference(input) + + def inference(self, input: Dict[str, Any]) -> Dict[str, Any]: ret = self.task_inference_mapping[self.cfg.task](input) - ret['samples'] = input['samples'] + if 'samples' in input: + ret['samples'] = input['samples'] for key in [ OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES, OutputKeys.LABELS, OutputKeys.SCORES @@ -118,21 +127,33 @@ class OfaForAllTasks(TorchModel): ret[key] = None return ret - def postprocess(self, input: Dict[str, Tensor], - **kwargs) -> Dict[str, Tensor]: - if 
self.cfg.task == Tasks.image_captioning: - caption = [ - cap.translate(self.transtab).strip() - for cap in input[OutputKeys.CAPTION] - ] - input[OutputKeys.CAPTION] = caption + def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]: + if not self.model.training and self.cfg.task == Tasks.image_captioning: + caption = input[OutputKeys.CAPTION] + result_l = list() + for cap in caption: + result_l.append(cap.translate(self.transtab).strip()) + input[OutputKeys.CAPTION] = result_l return input def _text_gen_inference(self, input): - input = move_to_device(input, self._device) - gen_output = self.generator.generate([self.model], input) - gen = [gen_output[i][0]['tokens'] for i in range(len(gen_output))] - result = self.tokenizer.batch_decode(gen, skip_special_tokens=True) + gen_outputs = self.generator.generate([self.model], + input, + prefix_tokens=input.get( + 'prefix_tokens', None)) + gen_l = list() + for idx, gen_out in enumerate(gen_outputs): + if len(gen_out) > 0: + decode_tokens = gen_out[0]['tokens'] + if 'prefix_tokens' in input: + prefix_len = input['prefix_tokens'][idx].ne( + self.pad_item.to(self.model.device)).sum() + decode_tokens = decode_tokens[prefix_len:] + gen_l.append(decode_tokens) + else: + gen_l.append('') + result = self.tokenizer.batch_decode(gen_l, skip_special_tokens=True) + result = [item.strip() for item in result] # text generation tasks have no score ret = {OFA_TASK_KEY_MAPPING[self.cfg.task]: result} if self.cfg.task.endswith('classification'): @@ -140,7 +161,6 @@ class OfaForAllTasks(TorchModel): return ret def _visual_grounding_inference(self, input): - input = move_to_device(input, self._device) gen_output = self.generator.generate([self.model], input) tokens = [gen_output[i][0]['tokens'] for i in range(len(gen_output))] region_coord_l = list() @@ -160,7 +180,6 @@ class OfaForAllTasks(TorchModel): } def _traverse_inference(self, input): - input = move_to_device(input, self._device) encoder_input = dict() for key in input['net_input'].keys(): encoder_input[key] = input['net_input'][key] @@ -170,13 +189,14 @@ class OfaForAllTasks(TorchModel): valid_size = len(val_ans) valid_tgt_items = [ torch.cat([ - torch.tensor(decoder_prompt[1:]), valid_answer, + torch.tensor(decoder_prompt[1:]).to('cpu'), valid_answer, self.eos_item ]) for decoder_prompt in input['decoder_prompts'] for valid_answer in val_ans ] valid_prev_items = [ - torch.cat([torch.tensor(decoder_prompt), valid_answer]) + torch.cat( + [torch.tensor(decoder_prompt).to('cpu'), valid_answer]) for decoder_prompt in input['decoder_prompts'] for valid_answer in val_ans ] @@ -184,19 +204,19 @@ class OfaForAllTasks(TorchModel): torch.cat([ torch.zeros( len(decoder_prompt) - 1, - valid_constraint_mask.size(1)).bool().to(self._device), + valid_constraint_mask.size(1)).bool(), valid_constraint_mask], dim=0) # yapf: disable for decoder_prompt in input['decoder_prompts'] # yapf: disable for valid_constraint_mask in val_masks] # yapf: disable valid_tgt = collate_tokens( valid_tgt_items, - pad_idx=self.tokenizer.pad_token_id).to(self._device) + pad_idx=self.tokenizer.pad_token_id).to(self.model.device) valid_prev_output = collate_tokens( valid_prev_items, - pad_idx=self.tokenizer.pad_token_id).to(self._device) + pad_idx=self.tokenizer.pad_token_id).to(self.model.device) val_masks = collate_tokens( valid_constraint_mask_items, - pad_idx=self.tokenizer.pad_token_id).to(self._device) + pad_idx=self.tokenizer.pad_token_id).to(self.model.device) new_encoder_out = { 'last_hidden_state': 
encoder_out['last_hidden_state'].repeat_interleave( @@ -271,10 +291,23 @@ class OfaForAllTasks(TorchModel): self.val_masks_l += [ constraint_mask_list[i:i + self.val_batch_size] ] - self.val_ans_l = move_to_device(self.val_ans_l, self._device) - self.val_masks_l = move_to_device(self.val_masks_l, self._device) def load_ans2label(self): if self.cfg.model.get('answer2label', None): - filename = osp.join(self.model_dir, self.cfg.model.answer2label) - self.ans2label_dict = json.load(open(filename)) + ans2label_file = osp.join(self.model_dir, + self.cfg.model.answer2label) + with open(ans2label_file, 'r') as reader: + self.ans2label_dict = json.load(reader) + + def save_pretrained(self, + target_folder: Union[str, os.PathLike], + save_checkpoint_names: Union[str, List[str]] = None, + save_function: Callable = None, + config: Optional[dict] = None, + **kwargs): + super(OfaForAllTasks, self). \ + save_pretrained(target_folder=target_folder, + save_checkpoint_names=save_checkpoint_names, + save_function=partial(save_function, with_meta=False), + config=config, + **kwargs) diff --git a/modelscope/pipelines/cv/image_classification_pipeline.py b/modelscope/pipelines/cv/image_classification_pipeline.py index 49467eab..69dbd1fb 100644 --- a/modelscope/pipelines/cv/image_classification_pipeline.py +++ b/modelscope/pipelines/cv/image_classification_pipeline.py @@ -13,6 +13,7 @@ from modelscope.pipelines.base import Input, Model, Pipeline from modelscope.pipelines.builder import PIPELINES from modelscope.preprocessors import OfaPreprocessor, Preprocessor, load_image from modelscope.utils.constant import Tasks +from modelscope.utils.device import get_device from modelscope.utils.logger import get_logger logger = get_logger() @@ -36,6 +37,7 @@ class ImageClassificationPipeline(Pipeline): else: raise NotImplementedError pipe_model.model.eval() + pipe_model.to(get_device()) if preprocessor is None and isinstance(pipe_model, OfaForAllTasks): preprocessor = OfaPreprocessor(model_dir=pipe_model.model_dir) super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) diff --git a/modelscope/preprocessors/multi_modal.py b/modelscope/preprocessors/multi_modal.py index 4427c096..256c5243 100644 --- a/modelscope/preprocessors/multi_modal.py +++ b/modelscope/preprocessors/multi_modal.py @@ -1,5 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
import os.path as osp +from io import BytesIO from typing import Any, Dict, List, Tuple, Union import torch @@ -15,6 +16,7 @@ from .base import Preprocessor from .builder import PREPROCESSORS from .ofa import * # noqa from .ofa.utils.collate import collate_fn +from .ofa.utils.constant import OFA_TASK_KEY_MAPPING __all__ = [ 'OfaPreprocessor', @@ -26,11 +28,16 @@ __all__ = [ Fields.multi_modal, module_name=Preprocessors.ofa_tasks_preprocessor) class OfaPreprocessor(Preprocessor): - def __init__(self, model_dir: str, *args, **kwargs): + def __init__(self, + model_dir: str, + mode=ModeKeys.INFERENCE, + *args, + **kwargs): """preprocess the data Args: model_dir (str): model path + mode: preprocessor mode (model mode) """ super().__init__(*args, **kwargs) preprocess_mapping = { @@ -45,25 +52,18 @@ class OfaPreprocessor(Preprocessor): Tasks.text_summarization: OfaSummarizationPreprocessor, Tasks.text_to_image_synthesis: OfaTextToImageSynthesisPreprocessor } - input_key_mapping = { - Tasks.ocr_recognition: ['image'], - Tasks.image_captioning: ['image'], - Tasks.image_classification: ['image'], - Tasks.text_summarization: ['text'], - Tasks.text_classification: ['text', 'text2'], - Tasks.visual_grounding: ['image', 'text'], - Tasks.visual_question_answering: ['image', 'text'], - Tasks.visual_entailment: ['image', 'text', 'text2'], - Tasks.text_to_image_synthesis: ['text'] - } model_dir = model_dir if osp.exists(model_dir) else snapshot_download( model_dir) self.cfg = Config.from_file( osp.join(model_dir, ModelFile.CONFIGURATION)) - self.preprocess = preprocess_mapping[self.cfg.task](self.cfg, - model_dir) - self.keys = input_key_mapping[self.cfg.task] + self.preprocess = preprocess_mapping[self.cfg.task]( + cfg=self.cfg, model_dir=model_dir, mode=mode) + self.keys = OFA_TASK_KEY_MAPPING[self.cfg.task] self.tokenizer = self.preprocess.tokenizer + if kwargs.get('no_collate', None): + self.no_collate = True + else: + self.no_collate = False # just for modelscope demo def _build_dict(self, input: Union[Input, List[Input]]) -> Dict[str, Any]: @@ -74,20 +74,37 @@ class OfaPreprocessor(Preprocessor): data[key] = item return data + def _ofa_input_compatibility_conversion(self, data): + if 'image' in data and self.cfg.model.get('type', None) == 'ofa': + if isinstance(data['image'], str): + image = load_image(data['image']) + else: + image = data['image'] + if image.mode != 'RGB': + image = image.convert('RGB') + img_buffer = BytesIO() + image.save(img_buffer, format='JPEG') + data['image'] = Image.open(img_buffer) + return data + def __call__(self, input: Union[str, tuple, Dict[str, Any]], *args, **kwargs) -> Dict[str, Any]: if isinstance(input, dict): data = input else: data = self._build_dict(input) + data = self._ofa_input_compatibility_conversion(data) sample = self.preprocess(data) str_data = dict() for k, v in data.items(): str_data[k] = str(v) sample['sample'] = str_data - return collate_fn([sample], - pad_idx=self.tokenizer.pad_token_id, - eos_idx=self.tokenizer.eos_token_id) + if self.no_collate: + return sample + else: + return collate_fn([sample], + pad_idx=self.tokenizer.pad_token_id, + eos_idx=self.tokenizer.eos_token_id) @PREPROCESSORS.register_module( @@ -140,7 +157,7 @@ class MPlugPreprocessor(Preprocessor): def image_open(self, path: str) -> Tuple[Image.Image, int]: if path not in self._image_map: index = len(self._image_map) - self._image_map[path] = (load_image(path), index) + self._image_map[path] = (Image.open(path), index) return self._image_map[path] def __call__( diff --git 
a/modelscope/preprocessors/ofa/base.py b/modelscope/preprocessors/ofa/base.py index 691f8b36..55b3895d 100644 --- a/modelscope/preprocessors/ofa/base.py +++ b/modelscope/preprocessors/ofa/base.py @@ -1,26 +1,31 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import re +import string from os import path as osp import json import numpy as np import torch +from PIL import Image from modelscope.models.multi_modal.ofa import OFATokenizer, OFATokenizerZH +from modelscope.preprocessors.image import load_image from modelscope.utils.trie import Trie +from .utils.constant import OFA_TASK_KEY_MAPPING from .utils.random_help import set_torch_seed class OfaBasePreprocessor: - def __init__(self, cfg, model_dir): - """preprocess the data + def __init__(self, cfg, model_dir, mode, *args, **kwargs): + """preprocess the data via the vocab.txt from the `model_dir` path Args: cfg(modelscope.utils.config.ConfigDict) : model config model_dir (str): model path """ self.cfg = cfg + self.mode = mode self.language = self.cfg.model.get('language', 'en') if self.language == 'en': tokenizer = OFATokenizer.from_pretrained(model_dir) @@ -41,6 +46,7 @@ class OfaBasePreprocessor: for key, value in tokenizer.get_vocab().items() } self.max_src_length = cfg.model.get('max_src_length', 256) + self.max_tgt_length = cfg.model.get('max_tgt_length', 256) self.max_image_size = cfg.model.get('max_image_size', 512) self.language = self.cfg.model.get('language', 'en') self.prompt_type = self.cfg.model.get('prompt_type', 'none') @@ -56,26 +62,40 @@ class OfaBasePreprocessor: self.mean = [0.5, 0.5, 0.5] self.std = [0.5, 0.5, 0.5] self.patch_image_size = self.cfg.model.get('patch_image_size', 480) + self.column_map = { + key: key + for key in OFA_TASK_KEY_MAPPING[self.cfg.task] + } + if hasattr(self.cfg, + 'dataset') and self.cfg.dataset.column_map is not None: + for k, v in self.cfg.dataset.column_map.items(): + self.column_map[k] = v + self.transtab = str.maketrans( + {key: None + for key in string.punctuation}) self.constraint_trie = None - self.index2ans = {} - if self.cfg.model.get('answer2label', False): + if self.cfg.model.get('answer2label', None): ans2label_file = osp.join(model_dir, self.cfg.model.answer2label) - ans2label_dict = json.load(open(ans2label_file, 'r')) + with open(ans2label_file, 'r') as reader: + ans2label_dict = json.load(reader) + self.ans2label = ans2label_dict + self.label2ans = {v: k for k, v in self.ans2label.items()} self.constraint_trie = Trie(tokenizer.eos_token_id) for i, answer in enumerate(ans2label_dict.keys()): - answer_item = tokenizer( - ' ' + answer, - return_tensors='pt', - add_special_tokens=False).input_ids.squeeze(0) + answer_item = self.tokenize_text( + ' ' + answer, add_bos=False, add_eos=False) self.constraint_trie.insert([tokenizer.bos_token_id] + answer_item.tolist() + [tokenizer.eos_token_id]) - def get_inputs(self, text, add_bos=True, add_eos=True): + def tokenize_text(self, text, add_bos=True, add_eos=True): + if text is None: + return None inputs = self.tokenizer( text, max_length=self.max_src_length, add_special_tokens=False, + truncation=True, return_tensors='pt')['input_ids'].squeeze(0) if add_bos: inputs = torch.cat([self.bos_item, inputs]) @@ -85,7 +105,7 @@ class OfaBasePreprocessor: @staticmethod def pre_caption(caption, max_words=None): - caption = caption.lower().lstrip(',.!?*#:;~').replace('-', ' ')\ + caption = caption.lower().lstrip(',.!?*#:;~').replace('-', ' ') \ .replace('/', ' ').replace('', 'person') caption = re.sub( @@ -123,3 +143,23 @@ class 
OfaBasePreprocessor: question = ' '.join(question_words[:max_ques_words]) return question + + def add_constraint_mask(self, sample): + target_itm = sample['target'] + len_label_itm = target_itm.ne(self.pad_item).sum(dim=0).item() + if self.constraint_trie: + constraint_mask = torch.zeros( + (len(target_itm), len(self.tgt_dict))).bool() + start_idx = len(target_itm) - len_label_itm + for i in range(start_idx, len(target_itm)): + constraint_prefix_token = self.bos_item.tolist( + ) + target_itm[start_idx:i].tolist() + constraint_nodes = self.constraint_trie.get_next_layer( + constraint_prefix_token) + constraint_mask[i][constraint_nodes] = True + sample['constraint_mask'] = constraint_mask + + def get_img_pil(self, path_or_url_or_pil): + image = path_or_url_or_pil if isinstance(path_or_url_or_pil, Image.Image) \ + else load_image(path_or_url_or_pil) + return image diff --git a/modelscope/preprocessors/ofa/image_captioning.py b/modelscope/preprocessors/ofa/image_captioning.py index 318a8a6d..af623297 100644 --- a/modelscope/preprocessors/ofa/image_captioning.py +++ b/modelscope/preprocessors/ofa/image_captioning.py @@ -1,42 +1,67 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from typing import Any, Dict, Union +from typing import Any, Dict import torch -from PIL import Image from torchvision import transforms -from modelscope.preprocessors.image import load_image +from modelscope.utils.constant import ModeKeys from .base import OfaBasePreprocessor class OfaImageCaptioningPreprocessor(OfaBasePreprocessor): - def __init__(self, cfg, model_dir): + def __init__(self, + cfg, + model_dir, + mode=ModeKeys.INFERENCE, + *args, + **kwargs): """preprocess the data Args: cfg(modelscope.utils.config.ConfigDict) : model config - model_dir (str): model path + model_dir (str): model path, + mode: preprocessor mode (model mode) """ - super(OfaImageCaptioningPreprocessor, self).__init__(cfg, model_dir) + super(OfaImageCaptioningPreprocessor, + self).__init__(cfg, model_dir, mode, *args, **kwargs) # Initialize transform self.patch_resize_transform = transforms.Compose([ lambda image: image.convert('RGB'), - transforms.Resize((self.patch_image_size, self.patch_image_size), - interpolation=Image.BICUBIC), + transforms.Resize( + (self.patch_image_size, self.patch_image_size), + interpolation=transforms.InterpolationMode.BICUBIC), transforms.ToTensor(), transforms.Normalize(mean=self.mean, std=self.std), ]) def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: - image = data['image'] if isinstance( - data['image'], Image.Image) else load_image(data['image']) + if self.mode == ModeKeys.TRAIN: + return self._build_train_sample(data) + else: + return self._build_infer_sample(data) + + def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: + sample = self._build_infer_sample(data) + target = data[self.column_map['text']] + target = target.translate(self.transtab).strip() + target_token_list = target.strip().split() + target = ' '.join(target_token_list[:self.max_tgt_length]) + sample['target'] = self.tokenize_text(target, add_bos=False) + sample['prev_output_tokens'] = torch.cat( + [self.bos_item, sample['target'][:-1]]) + return sample + + def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: + image = self.get_img_pil(data[self.column_map['image']]) patch_image = self.patch_resize_transform(image) prompt = self.cfg.model.get('prompt', ' what does the image describe?') - inputs = self.get_inputs(prompt) + inputs = self.tokenize_text(prompt) sample = { 'source': inputs, 
'patch_image': patch_image, 'patch_mask': torch.tensor([True]) } + if 'text' in self.column_map and self.column_map['text'] in data: + sample['label'] = data[self.column_map['text']] return sample diff --git a/modelscope/preprocessors/ofa/image_classification.py b/modelscope/preprocessors/ofa/image_classification.py index dd2de634..49968823 100644 --- a/modelscope/preprocessors/ofa/image_classification.py +++ b/modelscope/preprocessors/ofa/image_classification.py @@ -6,25 +6,33 @@ from PIL import Image from torchvision import transforms from modelscope.preprocessors.image import load_image +from modelscope.utils.constant import ModeKeys from .base import OfaBasePreprocessor class OfaImageClassificationPreprocessor(OfaBasePreprocessor): - def __init__(self, cfg, model_dir): + def __init__(self, + cfg, + model_dir, + mode=ModeKeys.INFERENCE, + *args, + **kwargs): """preprocess the data Args: cfg(modelscope.utils.config.ConfigDict) : model config - model_dir (str): model path + model_dir (str): model path, + mode: preprocessor mode (model mode) """ super(OfaImageClassificationPreprocessor, - self).__init__(cfg, model_dir) + self).__init__(cfg, model_dir, mode, *args, **kwargs) # Initialize transform self.patch_resize_transform = transforms.Compose([ lambda image: image.convert('RGB'), - transforms.Resize((self.patch_image_size, self.patch_image_size), - interpolation=Image.BICUBIC), + transforms.Resize( + (self.patch_image_size, self.patch_image_size), + interpolation=transforms.InterpolationMode.BICUBIC), transforms.ToTensor(), transforms.Normalize(mean=self.mean, std=self.std), ]) @@ -34,7 +42,7 @@ class OfaImageClassificationPreprocessor(OfaBasePreprocessor): data['image'], Image.Image) else load_image(data['image']) patch_image = self.patch_resize_transform(image) prompt = self.cfg.model.get('prompt', ' what does the image describe?') - inputs = self.get_inputs(prompt) + inputs = self.tokenize_text(prompt) sample = { 'source': inputs, 'patch_image': patch_image, diff --git a/modelscope/preprocessors/ofa/ocr_recognition.py b/modelscope/preprocessors/ofa/ocr_recognition.py index 1d30e572..1761dbd4 100644 --- a/modelscope/preprocessors/ofa/ocr_recognition.py +++ b/modelscope/preprocessors/ofa/ocr_recognition.py @@ -1,7 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
-import random -import unicodedata -from typing import Any, Dict, Union +from typing import Any, Dict import torch from PIL import Image @@ -10,6 +8,7 @@ from torchvision.transforms import InterpolationMode from torchvision.transforms import functional as F from modelscope.preprocessors.image import load_image +from modelscope.utils.constant import ModeKeys from .base import OfaBasePreprocessor IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) @@ -59,14 +58,21 @@ def ocr_resize(img, patch_image_size, is_document=False): class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor): - def __init__(self, cfg, model_dir): + def __init__(self, + cfg, + model_dir, + mode=ModeKeys.INFERENCE, + *args, + **kwargs): """preprocess the data Args: cfg(modelscope.utils.config.ConfigDict) : model config - model_dir (str): model path + model_dir (str): model path, + mode: preprocessor mode (model mode) """ - super(OfaOcrRecognitionPreprocessor, self).__init__(cfg, model_dir) + super(OfaOcrRecognitionPreprocessor, + self).__init__(cfg, model_dir, mode, *args, **kwargs) # Initialize transform if self.cfg.model.imagenet_default_mean_and_std: mean = IMAGENET_DEFAULT_MEAN @@ -89,7 +95,7 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor): data['image'], Image.Image) else load_image(data['image']) patch_image = self.patch_resize_transform(image) prompt = self.cfg.model.get('prompt', '图片上的文字是什么?') - inputs = self.get_inputs(prompt) + inputs = self.tokenize_text(prompt) sample = { 'source': inputs, diff --git a/modelscope/preprocessors/ofa/summarization.py b/modelscope/preprocessors/ofa/summarization.py index 99028e61..cfd3c23d 100644 --- a/modelscope/preprocessors/ofa/summarization.py +++ b/modelscope/preprocessors/ofa/summarization.py @@ -1,19 +1,27 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from typing import Any, Dict +from modelscope.utils.constant import ModeKeys from .base import OfaBasePreprocessor class OfaSummarizationPreprocessor(OfaBasePreprocessor): - def __init__(self, cfg, model_dir): + def __init__(self, + cfg, + model_dir, + mode=ModeKeys.INFERENCE, + *args, + **kwargs): """preprocess the data Args: cfg(modelscope.utils.config.ConfigDict) : model config - model_dir (str): model path + model_dir (str): model path, + mode: preprocessor mode (model mode) """ - super(OfaSummarizationPreprocessor, self).__init__(cfg, model_dir) + super(OfaSummarizationPreprocessor, + self).__init__(cfg, model_dir, mode, *args, **kwargs) def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: source = super().pre_caption( @@ -23,7 +31,7 @@ class OfaSummarizationPreprocessor(OfaBasePreprocessor): prompt = self.cfg.model.get( 'prompt', ' " {} " Summarize the article with a title: ') text = prompt.format(source) - inputs = self.get_inputs(text) + inputs = self.tokenize_text(text) if self.prompt_type == 'none': decoder_prompt = self.bos_item elif self.prompt_type == 'prev_output': diff --git a/modelscope/preprocessors/ofa/text_classification.py b/modelscope/preprocessors/ofa/text_classification.py index 5673a07f..24c4f67e 100644 --- a/modelscope/preprocessors/ofa/text_classification.py +++ b/modelscope/preprocessors/ofa/text_classification.py @@ -1,38 +1,81 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
from typing import Any, Dict +import torch + +from modelscope.utils.constant import ModeKeys from .base import OfaBasePreprocessor class OfaTextClassificationPreprocessor(OfaBasePreprocessor): - def __init__(self, cfg, model_dir): + def __init__(self, + cfg, + model_dir, + mode=ModeKeys.INFERENCE, + *args, + **kwargs): """preprocess the data Args: cfg(modelscope.utils.config.ConfigDict) : model config - model_dir (str): model path + model_dir (str): model path, + mode: preprocessor mode (model mode) """ - super(OfaTextClassificationPreprocessor, self).__init__(cfg, model_dir) + super(OfaTextClassificationPreprocessor, + self).__init__(cfg, model_dir, mode, *args, **kwargs) def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + if self.mode == ModeKeys.TRAIN: + return self._build_train_sample(data) + else: + return self._build_infer_sample(data) + + def _build_instruction(self, data): text1 = ' '.join( data['text'].lower().strip().split()[:self.max_src_length]) text2 = ' '.join( data['text2'].lower().strip().split()[:self.max_src_length]) prompt = ' can text1 " {} " imply text2 " {} "?' text = prompt.format(text1, text2) - inputs = self.get_inputs(text) + instruction_itm = self.tokenize_text(text) + return instruction_itm + + def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: + instruction_itm = self._build_instruction(data) + assert 'label' in data, 'there must has `label` column in train phase ' + label = data['label'] + if self.label2ans: + label = self.label2ans[label] # ans + label_itm = self.tokenize_text(f' {label}', add_bos=False) + if self.prompt_type == 'none': + target_itm = label_itm + elif self.prompt_type == 'prev_output': + target_itm = torch.cat([instruction_itm[1:-1], label_itm]) + else: + raise NotImplementedError + prev_output_itm = torch.cat([self.bos_item, target_itm[:-1]]) + target_itm[:-len(label_itm)] = self.pad_item + sample = { + 'source': instruction_itm, + 'target': target_itm, + 'prev_output_tokens': prev_output_itm, + } + self.add_constraint_mask(sample) + return sample + + def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: + instruction_itm = self._build_instruction(data) if self.prompt_type == 'none': - decoder_prompt = self.bos_item - elif self.prompt_type == 'src': - decoder_prompt = inputs + prefix_token = [] elif self.prompt_type == 'prev_output': - decoder_prompt = inputs[:-1] + prefix_token = instruction_itm[:-1] # remove eos else: raise NotImplementedError sample = { - 'source': inputs, - 'decoder_prompt': decoder_prompt, + 'source': instruction_itm, + 'prefix_token': prefix_token, } + if 'label' in data: + sample['label'] = self.label2ans[data['label']] return sample diff --git a/modelscope/preprocessors/ofa/text_to_image_synthesis.py b/modelscope/preprocessors/ofa/text_to_image_synthesis.py index e10de82c..2f6000eb 100644 --- a/modelscope/preprocessors/ofa/text_to_image_synthesis.py +++ b/modelscope/preprocessors/ofa/text_to_image_synthesis.py @@ -3,26 +3,34 @@ from typing import Any, Dict import torch +from modelscope.utils.constant import ModeKeys from .base import OfaBasePreprocessor class OfaTextToImageSynthesisPreprocessor(OfaBasePreprocessor): - def __init__(self, cfg, model_dir): + def __init__(self, + cfg, + model_dir, + mode=ModeKeys.INFERENCE, + *args, + **kwargs): """preprocess the data Args: - model_dir (str): model path + cfg(modelscope.utils.config.ConfigDict) : model config + model_dir (str): model path, + mode: preprocessor mode (model mode) """ super(OfaTextToImageSynthesisPreprocessor, 
- self).__init__(cfg, model_dir) + self).__init__(cfg, model_dir, mode, *args, **kwargs) self.max_src_length = 64 def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: source = ' '.join( data['text'].lower().strip().split()[:self.max_src_length]) source = 'what is the complete image? caption: {}'.format(source) - inputs = self.get_inputs(source) + inputs = self.tokenize_text(source) sample = { 'source': inputs, 'patch_images': None, diff --git a/modelscope/preprocessors/ofa/utils/collate.py b/modelscope/preprocessors/ofa/utils/collate.py index 4bfa8eca..f7775680 100644 --- a/modelscope/preprocessors/ofa/utils/collate.py +++ b/modelscope/preprocessors/ofa/utils/collate.py @@ -49,11 +49,15 @@ def collate_fn(samples, pad_idx, eos_idx): batch['conf'] = torch.cat([s['conf'] for s in samples], dim=0) if samples[0].get('ref_dict', None) is not None: batch['ref_dict'] = np.array([s['ref_dict'] for s in samples]) + if samples[0].get('label', None) is not None: + batch['labels'] = np.array([s['label'] for s in samples]).tolist() if samples[0].get('constraint_mask', None) is not None: batch['constraint_masks'] = merge('constraint_mask') if samples[0].get('decoder_prompt', None) is not None: batch['decoder_prompts'] = np.array( [s['decoder_prompt'].tolist() for s in samples]) + if samples[0].get('prefix_token', None) is not None: + batch['prefix_tokens'] = merge('prefix_token') # For detection and visual grounding if samples[0].get('w_resize_ratio', None) is not None: batch['w_resize_ratios'] = torch.stack( diff --git a/modelscope/preprocessors/ofa/utils/constant.py b/modelscope/preprocessors/ofa/utils/constant.py new file mode 100644 index 00000000..102d27c0 --- /dev/null +++ b/modelscope/preprocessors/ofa/utils/constant.py @@ -0,0 +1,13 @@ +from modelscope.utils.constant import Tasks + +OFA_TASK_KEY_MAPPING = { + Tasks.ocr_recognition: ['image'], + Tasks.image_captioning: ['image'], + Tasks.image_classification: ['image'], + Tasks.text_summarization: ['text'], + Tasks.text_classification: ['text', 'text2'], + Tasks.visual_grounding: ['image', 'text'], + Tasks.visual_question_answering: ['image', 'text'], + Tasks.visual_entailment: ['image', 'text', 'text2'], + Tasks.text_to_image_synthesis: ['text'] +} diff --git a/modelscope/preprocessors/ofa/visual_entailment.py b/modelscope/preprocessors/ofa/visual_entailment.py index 6002c4a6..61c3cc6a 100644 --- a/modelscope/preprocessors/ofa/visual_entailment.py +++ b/modelscope/preprocessors/ofa/visual_entailment.py @@ -6,24 +6,33 @@ from PIL import Image from torchvision import transforms from modelscope.preprocessors.image import load_image +from modelscope.utils.constant import ModeKeys from .base import OfaBasePreprocessor class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor): - def __init__(self, cfg, model_dir): + def __init__(self, + cfg, + model_dir, + mode=ModeKeys.INFERENCE, + *args, + **kwargs): """preprocess the data Args: cfg(modelscope.utils.config.ConfigDict) : model config - model_dir (str): model path + model_dir (str): model path, + mode: preprocessor mode (model mode) """ - super(OfaVisualEntailmentPreprocessor, self).__init__(cfg, model_dir) + super(OfaVisualEntailmentPreprocessor, + self).__init__(cfg, model_dir, mode, *args, **kwargs) # Initialize transform self.patch_resize_transform = transforms.Compose([ lambda image: image.convert('RGB'), - transforms.Resize((self.patch_image_size, self.patch_image_size), - interpolation=Image.BICUBIC), + transforms.Resize( + (self.patch_image_size, self.patch_image_size), + 
interpolation=transforms.InterpolationMode.BICUBIC), transforms.ToTensor(), transforms.Normalize(mean=self.mean, std=self.std), ]) @@ -44,7 +53,7 @@ class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor): prompt = self.cfg.model.get( 'prompt', ' can image and text1 " {} " imply text2 " {} "?') text = prompt.format(caption, hypothesis) - inputs = self.get_inputs(text) + inputs = self.tokenize_text(text) if self.prompt_type == 'none': decoder_prompt = self.bos_item elif self.prompt_type == 'src': diff --git a/modelscope/preprocessors/ofa/visual_grounding.py b/modelscope/preprocessors/ofa/visual_grounding.py index 022e5788..8b116463 100644 --- a/modelscope/preprocessors/ofa/visual_grounding.py +++ b/modelscope/preprocessors/ofa/visual_grounding.py @@ -6,24 +6,33 @@ from PIL import Image from torchvision import transforms from modelscope.preprocessors.image import load_image +from modelscope.utils.constant import ModeKeys from .base import OfaBasePreprocessor class OfaVisualGroundingPreprocessor(OfaBasePreprocessor): - def __init__(self, cfg, model_dir): + def __init__(self, + cfg, + model_dir, + mode=ModeKeys.INFERENCE, + *args, + **kwargs): """preprocess the data Args: cfg(modelscope.utils.config.ConfigDict) : model config - model_dir (str): model path + model_dir (str): model path, + mode: preprocessor mode (model mode) """ - super(OfaVisualGroundingPreprocessor, self).__init__(cfg, model_dir) + super(OfaVisualGroundingPreprocessor, + self).__init__(cfg, model_dir, mode, *args, **kwargs) # Initialize transform self.patch_resize_transform = transforms.Compose([ lambda image: image.convert('RGB'), - transforms.Resize((self.patch_image_size, self.patch_image_size), - interpolation=Image.BICUBIC), + transforms.Resize( + (self.patch_image_size, self.patch_image_size), + interpolation=transforms.InterpolationMode.BICUBIC), transforms.ToTensor(), transforms.Normalize(mean=self.mean, std=self.std), ]) @@ -39,7 +48,7 @@ class OfaVisualGroundingPreprocessor(OfaBasePreprocessor): prompt = self.cfg.model.get( 'prompt', ' which region does the text " {} " describe?') text = prompt.format(src_caption) - src_item = self.get_inputs(text) + src_item = self.tokenize_text(text) sample = { 'source': src_item, 'patch_image': patch_image, diff --git a/modelscope/preprocessors/ofa/visual_question_answering.py b/modelscope/preprocessors/ofa/visual_question_answering.py index d34d1db0..11104e7e 100644 --- a/modelscope/preprocessors/ofa/visual_question_answering.py +++ b/modelscope/preprocessors/ofa/visual_question_answering.py @@ -6,25 +6,33 @@ from PIL import Image from torchvision import transforms from modelscope.preprocessors.image import load_image +from modelscope.utils.constant import ModeKeys from .base import OfaBasePreprocessor class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor): - def __init__(self, cfg, model_dir): + def __init__(self, + cfg, + model_dir, + mode=ModeKeys.INFERENCE, + *args, + **kwargs): """preprocess the data Args: cfg(modelscope.utils.config.ConfigDict) : model config - model_dir (str): model path + model_dir (str): model path, + mode: preprocessor mode (model mode) """ super(OfaVisualQuestionAnsweringPreprocessor, - self).__init__(cfg, model_dir) + self).__init__(cfg, model_dir, mode, *args, **kwargs) # Initialize transform self.patch_resize_transform = transforms.Compose([ lambda image: image.convert('RGB'), - transforms.Resize((self.patch_image_size, self.patch_image_size), - interpolation=Image.BICUBIC), + transforms.Resize( + (self.patch_image_size, 
self.patch_image_size), + interpolation=transforms.InterpolationMode.BICUBIC), transforms.ToTensor(), transforms.Normalize(mean=self.mean, std=self.std), ]) @@ -34,7 +42,7 @@ class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor): data['image'], Image.Image) else load_image(data['image']) patch_image = self.patch_resize_transform(image) text = ' {}'.format(data['text']) - inputs = self.get_inputs(text) + inputs = self.tokenize_text(text) if self.prompt_type == 'none': decoder_prompt = self.bos_item elif self.prompt_type == 'src': diff --git a/modelscope/trainers/multi_modal/ofa/__init__.py b/modelscope/trainers/multi_modal/ofa/__init__.py new file mode 100644 index 00000000..34e4ec7a --- /dev/null +++ b/modelscope/trainers/multi_modal/ofa/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from .ofa_trainer import OFATrainer diff --git a/modelscope/trainers/multi_modal/ofa/ofa_trainer.py b/modelscope/trainers/multi_modal/ofa/ofa_trainer.py new file mode 100644 index 00000000..02853925 --- /dev/null +++ b/modelscope/trainers/multi_modal/ofa/ofa_trainer.py @@ -0,0 +1,154 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import math +import os +import shutil +from functools import partial +from typing import Callable, Dict, Optional, Tuple, Union + +import torch +from torch import distributed as dist +from torch import nn +from torch.utils.data import Dataset + +from modelscope.metainfo import Trainers +from modelscope.models.base import Model, TorchModel +from modelscope.msdatasets.ms_dataset import MsDataset +from modelscope.preprocessors.base import Preprocessor +from modelscope.preprocessors.multi_modal import OfaPreprocessor +from modelscope.preprocessors.ofa.utils.collate import collate_fn +from modelscope.trainers import EpochBasedTrainer +from modelscope.trainers.builder import TRAINERS +from modelscope.trainers.optimizer.builder import build_optimizer +from modelscope.utils.config import Config +from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigKeys, + ModeKeys) +from .ofa_trainer_utils import (AdjustLabelSmoothedCrossEntropyCriterion, + get_schedule) + + +@TRAINERS.register_module(module_name=Trainers.ofa) +class OFATrainer(EpochBasedTrainer): + + def __init__( + self, + model: Optional[Union[TorchModel, nn.Module, str]] = None, + cfg_file: Optional[str] = None, + arg_parse_fn: Optional[Callable] = None, + data_collator: Optional[Union[Callable, Dict[str, + Callable]]] = None, + train_dataset: Optional[Union[MsDataset, Dataset]] = None, + eval_dataset: Optional[Union[MsDataset, Dataset]] = None, + preprocessor: Optional[Union[Preprocessor, + Dict[str, Preprocessor]]] = None, + optimizers: Tuple[torch.optim.Optimizer, + torch.optim.lr_scheduler._LRScheduler] = (None, + None), + model_revision: Optional[str] = DEFAULT_MODEL_REVISION, + seed: int = 42, + **kwargs): + model = Model.from_pretrained(model, revision=model_revision) + model_dir = model.model_dir + cfg = Config.from_file(cfg_file) + if 'work_dir' not in kwargs or len(kwargs['work_dir']) == 0: + work_dir = cfg.train.work_dir + else: + work_dir = kwargs['work_dir'] + tokenizer_files = { + 'zh': [ + 'tokenizer.json', 'tokenizer_config.json', 'vocab.txt', + 'config.json' + ], + 'en': + ['tokenizer.json', 'vocab.json', 'merges.txt', 'config.json'], + } + for filename in tokenizer_files[cfg.model.get('language', 'en')]: + finetune_file = os.path.join(work_dir, filename) + pretrain_file = os.path.join(model_dir, filename) + if os.path.exists(finetune_file): + 
continue + if os.path.exists(pretrain_file): + shutil.copy(pretrain_file, finetune_file) + + if preprocessor is None: + preprocessor = { + ConfigKeys.train: + OfaPreprocessor( + model_dir=work_dir, mode=ModeKeys.TRAIN, no_collate=True), + ConfigKeys.val: + OfaPreprocessor( + model_dir=work_dir, mode=ModeKeys.EVAL, no_collate=True), + } + # use torchrun launch + world_size = int(os.environ.get('WORLD_SIZE', 1)) + epoch_steps = math.ceil( + len(train_dataset) / # noqa + (cfg.train.dataloader.batch_size_per_gpu * world_size)) # noqa + cfg.train.lr_scheduler.num_train_steps = epoch_steps * cfg.train.max_epochs + cfg.train.criterion.tokenizer = model.tokenizer + self.criterion = AdjustLabelSmoothedCrossEntropyCriterion( + cfg.train.criterion) + if optimizers[0] is None: + optimizer = build_optimizer(model, cfg=cfg.train.optimizer) + else: + optimizer = optimizers[0] + if optimizers[1] is None: + scheduler_class, scheduler_args = get_schedule( + cfg.train.lr_scheduler) + if scheduler_class is not None: + lr_scheduler = scheduler_class(**{'optimizer': optimizer}, + **scheduler_args) + else: + lr_scheduler = None + else: + lr_scheduler = optimizers[1] + optimizers = (optimizer, lr_scheduler) + if data_collator is None: + data_collator = partial( + collate_fn, + pad_idx=model.tokenizer.pad_token_id, + eos_idx=model.tokenizer.eos_token_id, + ) + if 'launcher' not in kwargs and cfg.train.get('launcher', None): + kwargs['launcher'] = cfg.train.launcher + if 'use_fp16' not in kwargs and cfg.train.get('use_fp16', False): + kwargs['use_fp16'] = cfg.train.use_fp16 + kwargs['to_tensor'] = False + super().__init__( + model=model, + cfg_file=cfg_file, + arg_parse_fn=arg_parse_fn, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + preprocessor=preprocessor, + optimizers=optimizers, + seed=seed, + **kwargs, + ) + + def train_step(self, model, inputs): + model.train() + model_outputs = model.forward(inputs) + loss, sample_size, logging_output = self.criterion( + model_outputs, inputs) + train_outputs = {'loss': loss} + # add model output info to log + if 'log_vars' not in train_outputs: + default_keys_pattern = ['loss'] + match_keys = set([]) + for key_p in default_keys_pattern: + match_keys.update( + [key for key in train_outputs.keys() if key_p in key]) + log_vars = {} + for key in match_keys: + value = train_outputs.get(key, None) + if value is not None: + if dist.is_available() and dist.is_initialized(): + value = value.data.clone() + dist.all_reduce(value.div_(dist.get_world_size())) + log_vars.update({key: value.item()}) + self.log_buffer.update(log_vars) + else: + self.log_buffer.update(train_outputs['log_vars']) + self.train_outputs = train_outputs diff --git a/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py b/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py new file mode 100644 index 00000000..2189a5db --- /dev/null +++ b/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py @@ -0,0 +1,243 @@ +# Copyright 2022 The OFA-Sys Team. +# All rights reserved. +# This source code is licensed under the Apache 2.0 license +# found in the LICENSE file in the root directory. 
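For orientation, here is a minimal, self-contained sketch of how the step counts computed in OFATrainer.__init__ above size the learning-rate schedule. The numbers are made-up stand-ins for len(train_dataset) and the cfg.train values, and the scheduler call mirrors what get_schedule resolves to for name='linear'; it is not taken verbatim from this patch.

import math

import torch
import transformers

# Hypothetical stand-ins for len(train_dataset), cfg.train.dataloader.* and cfg.train.max_epochs.
num_samples = 10000
batch_size_per_gpu = 4
world_size = 2            # torchrun exports WORLD_SIZE; single-process runs default to 1
max_epochs = 3
warmup_proportion = 0.01

# Same arithmetic as OFATrainer.__init__: optimizer steps per epoch across all workers, rounded up.
epoch_steps = math.ceil(num_samples / (batch_size_per_gpu * world_size))
num_train_steps = epoch_steps * max_epochs

# get_schedule resolves name='linear' to transformers.get_linear_schedule_with_warmup plus kwargs;
# the trainer then calls scheduler_class(optimizer=optimizer, **scheduler_args).
model = torch.nn.Linear(8, 8)                                    # toy model standing in for OFA
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
lr_scheduler = transformers.get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(warmup_proportion * num_train_steps),
    num_training_steps=num_train_steps)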
+import math + +import numpy as np +import torch +import torch.nn.functional as F +import transformers +from torch.nn.modules.loss import _Loss + + +def construct_rdrop_sample(x): + if isinstance(x, dict): + for key in x: + x[key] = construct_rdrop_sample(x[key]) + return x + elif isinstance(x, torch.Tensor): + return x.repeat(2, *([1] * (x.dim() - 1))) + elif isinstance(x, int): + return x * 2 + elif isinstance(x, np.ndarray): + return x.repeat(2) + else: + raise NotImplementedError + + +def kl_loss(p, q): + p_loss = F.kl_div(p, torch.exp(q), reduction='sum') + q_loss = F.kl_div(q, torch.exp(p), reduction='sum') + loss = (p_loss + q_loss) / 2 + return loss + + +def label_smoothed_nll_loss(lprobs, + target, + epsilon, + update_num, + reduce=True, + drop_worst_ratio=0.0, + drop_worst_after=0, + use_rdrop=False, + reg_alpha=1.0, + constraint_masks=None, + constraint_start=None, + constraint_end=None): + if target.dim() == lprobs.dim() - 1: + target = target.unsqueeze(-1) + nll_loss = -lprobs.gather(dim=-1, index=target).squeeze(-1) + if constraint_masks is not None: + smooth_loss = -lprobs.masked_fill(~constraint_masks, 0).sum( + dim=-1, keepdim=True).squeeze(-1) + eps_i = epsilon / (constraint_masks.sum(1) - 1 + 1e-6) + elif constraint_start is not None and constraint_end is not None: + constraint_range = [0, 1, 2, 3] + list( + range(constraint_start, constraint_end)) + smooth_loss = -lprobs[:, constraint_range].sum( + dim=-1, keepdim=True).squeeze(-1) + eps_i = epsilon / (len(constraint_range) - 1 + 1e-6) + else: + smooth_loss = -lprobs.sum(dim=-1, keepdim=True).squeeze(-1) + eps_i = epsilon / (lprobs.size(-1) - 1) + loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss + if drop_worst_ratio > 0 and update_num > drop_worst_after: + if use_rdrop: + true_batch_size = loss.size(0) // 2 + _, indices = torch.topk( + loss[:true_batch_size], + k=int(true_batch_size * (1 - drop_worst_ratio)), + largest=False) + loss = torch.cat([loss[indices], loss[indices + true_batch_size]]) + nll_loss = torch.cat( + [nll_loss[indices], nll_loss[indices + true_batch_size]]) + lprobs = torch.cat( + [lprobs[indices], lprobs[indices + true_batch_size]]) + else: + loss, indices = torch.topk( + loss, + k=int(loss.shape[0] * (1 - drop_worst_ratio)), + largest=False) + nll_loss = nll_loss[indices] + lprobs = lprobs[indices] + + ntokens = loss.numel() + nll_loss = nll_loss.sum() / ntokens # 后面在grads里面处理 + loss = loss.sum() / ntokens # 后面在grads里面处理 + if use_rdrop: + true_batch_size = lprobs.size(0) // 2 + p = lprobs[:true_batch_size] + q = lprobs[true_batch_size:] + if constraint_start is not None and constraint_end is not None: + constraint_range = [0, 1, 2, 3] + list( + range(constraint_start, constraint_end)) + p = p[:, constraint_range] + q = q[:, constraint_range] + loss += kl_loss(p, q) * reg_alpha + + return loss, nll_loss, ntokens + + +class AdjustLabelSmoothedCrossEntropyCriterion(_Loss): + + def __init__(self, args): + super().__init__() + self.sentence_avg = args.sentence_avg + self.eps = args.label_smoothing + self.ignore_prefix_size = args.ignore_prefix_size + self.ignore_eos = args.ignore_eos + self.report_accuracy = args.report_accuracy + self.drop_worst_ratio = args.drop_worst_ratio + self.drop_worst_after = args.drop_worst_after + self.use_rdrop = args.use_rdrop + self.reg_alpha = args.reg_alpha + self.sample_patch_num = args.sample_patch_num + + self.constraint_start = None + self.constraint_end = None + if args.constraint_range: + constraint_start, constraint_end = 
args.constraint_range.split(',') + self.constraint_start = int(constraint_start) + self.constraint_end = int(constraint_end) + self.padding_idx = args.tokenizer.pad_token_id + self.args = args + + def forward(self, output, sample, update_num=0, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + if self.use_rdrop: + construct_rdrop_sample(sample) + + loss, nll_loss, ntokens = self.compute_loss( + output, sample, update_num, reduce=reduce) + sample_size = ( + sample['target'].size(0) if self.sentence_avg else ntokens) + logging_output = { + 'loss': loss.data, + 'nll_loss': nll_loss.data, + 'ntokens': sample['ntokens'], + 'nsentences': sample['nsentences'], + 'sample_size': sample_size, + } + return loss, sample_size, logging_output + + def get_lprobs_and_target(self, net_output, sample): + conf = sample['conf'][:, None, None] if 'conf' in sample and sample[ + 'conf'] is not None else 1 + constraint_masks = None + if 'constraint_masks' in sample and sample[ + 'constraint_masks'] is not None: + constraint_masks = sample['constraint_masks'] + net_output[0].masked_fill_(~constraint_masks, -math.inf) + if self.constraint_start is not None and self.constraint_end is not None: + net_output[0][:, :, 4:self.constraint_start] = -math.inf + net_output[0][:, :, self.constraint_end:] = -math.inf + lprobs = F.log_softmax( + net_output[0], dim=-1, dtype=torch.float32) * conf + target = sample['target'] + if self.ignore_prefix_size > 0: + lprobs = lprobs[:, self.ignore_prefix_size:, :].contiguous() + target = target[:, self.ignore_prefix_size:].contiguous() + if constraint_masks is not None: + constraint_masks = constraint_masks[:, self.ignore_prefix_size:, :].contiguous() # yapf: disable + if self.ignore_eos: + bsz, seq_len, embed_dim = lprobs.size() + eos_indices = target.eq(self.task.tgt_dict.eos()) + lprobs = lprobs[~eos_indices].reshape(bsz, seq_len - 1, embed_dim) + target = target[~eos_indices].reshape(bsz, seq_len - 1) + if constraint_masks is not None: + constraint_masks = constraint_masks[~eos_indices].reshape( + bsz, seq_len - 1, embed_dim) + if constraint_masks is not None: + constraint_masks = constraint_masks.view(-1, + constraint_masks.size(-1)) + return lprobs.view(-1, + lprobs.size(-1)), target.view(-1), constraint_masks + + def compute_loss(self, net_output, sample, update_num, reduce=True): + lprobs, target, constraint_masks = self.get_lprobs_and_target( + net_output, sample) + if constraint_masks is not None: + constraint_masks = constraint_masks[target != self.padding_idx] + lprobs = lprobs[target != self.padding_idx] + target = target[target != self.padding_idx] + loss, nll_loss, ntokens = label_smoothed_nll_loss( + lprobs, + target, + self.eps, + update_num, + reduce=reduce, + drop_worst_ratio=self.drop_worst_ratio, + drop_worst_after=self.drop_worst_after, + use_rdrop=self.use_rdrop, + reg_alpha=self.reg_alpha, + constraint_masks=constraint_masks, + constraint_start=self.constraint_start, + constraint_end=self.constraint_end) + return loss, nll_loss, ntokens + + +def get_schedule(scheduler): + + if scheduler.name == 'const': + scheduler_class = transformers.get_constant_schedule_with_warmup + scheduler_args = { + 'num_warmup_steps': + int(scheduler.warmup_proportion * scheduler.num_train_steps) + } + elif scheduler.name == 'linear': + scheduler_class = transformers.get_linear_schedule_with_warmup + 
scheduler_args = { + 'num_warmup_steps': + int(scheduler.warmup_proportion * scheduler.num_train_steps), + 'num_training_steps': + scheduler.num_train_steps + } + elif scheduler.name == 'cosine': + scheduler_class = transformers.get_cosine_schedule_with_warmup + scheduler_args = { + 'num_warmup_steps': + int(scheduler.warmup_proportion * scheduler.num_train_steps), + 'num_training_steps': + scheduler.num_train_steps + } + elif scheduler.name == 'polynomial_decay': + scheduler_class = transformers.get_polynomial_decay_schedule_with_warmup + scheduler_args = { + 'num_warmup_steps': + int(scheduler.warmup_proportion * scheduler.num_train_steps), + 'num_training_steps': + scheduler.num_train_steps, + 'lr_end': + scheduler.lr_end + } + else: + raise NotImplementedError + + return scheduler_class, scheduler_args diff --git a/modelscope/trainers/trainer.py b/modelscope/trainers/trainer.py index 0dc6ece4..f47bff10 100644 --- a/modelscope/trainers/trainer.py +++ b/modelscope/trainers/trainer.py @@ -168,19 +168,20 @@ class EpochBasedTrainer(BaseTrainer): device_name = f'cuda:{local_rank}' self.device = create_device(device_name) - self.train_dataset = self.to_task_dataset( train_dataset, mode=ModeKeys.TRAIN, task_data_config=self.cfg.dataset.get('train', None) if hasattr( self.cfg, 'dataset') else None, - preprocessor=self.train_preprocessor) + preprocessor=self.train_preprocessor, + **kwargs) self.eval_dataset = self.to_task_dataset( eval_dataset, mode=ModeKeys.EVAL, task_data_config=self.cfg.dataset.get('val', None) if hasattr( self.cfg, 'dataset') else None, - preprocessor=self.eval_preprocessor) + preprocessor=self.eval_preprocessor, + **kwargs) self.train_data_collator, self.eval_default_collate = None, None if isinstance(data_collator, Mapping): @@ -216,7 +217,6 @@ class EpochBasedTrainer(BaseTrainer): self._max_epochs = self.cfg.train.max_epochs else: self._max_epochs = kwargs['max_epochs'] - self._train_iters_per_epoch = kwargs.get('train_iters_per_epoch', None) self._eval_iters_per_epoch = kwargs.get('val_iters_per_epoch', None) if self._train_iters_per_epoch is None and hasattr( @@ -306,13 +306,15 @@ class EpochBasedTrainer(BaseTrainer): datasets: Union[Dataset, List[Dataset]], mode: str, task_data_config: Config = None, - preprocessor: Optional[Preprocessor] = None): + preprocessor: Optional[Preprocessor] = None, + **kwargs): """Build the task specific dataset processor for this trainer. Returns: The task dataset processor for the task. If no result for the very model-type and task, the default TaskDataset will be returned. 
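Stepping back to the criterion defined in ofa_trainer_utils.py above: with R-Drop, constraint masks and drop-worst disabled, label_smoothed_nll_loss reduces to the textbook label-smoothing formula. A toy numeric sketch with random logits and a 5-word vocabulary (the values are made up; epsilon matches the label_smoothing of 0.1 used in the finetune config in the test below):

import torch
import torch.nn.functional as F

epsilon = 0.1                               # cf. label_smoothing: 0.1 in the test config
logits = torch.randn(2, 5)                  # two target tokens over a 5-word toy vocabulary
target = torch.tensor([1, 3])

lprobs = F.log_softmax(logits, dim=-1)
nll_loss = -lprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)
smooth_loss = -lprobs.sum(dim=-1)           # uniform smoothing mass over the whole vocabulary
eps_i = epsilon / (lprobs.size(-1) - 1)
loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss
print(loss.sum() / loss.numel())            # the criterion normalizes by ntokens in the same way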
""" try: + to_tensor = kwargs.get('to_tensor', True) if not datasets: return datasets if isinstance(datasets, TorchTaskDataset): @@ -328,7 +330,8 @@ class EpochBasedTrainer(BaseTrainer): return datasets.to_torch_dataset( task_data_config=task_data_config, task_name=self.cfg.task, - preprocessors=preprocessor) + preprocessors=preprocessor, + to_tensor=to_tensor) elif isinstance(datasets, List) and isinstance( datasets[0], MsDataset): if task_data_config is None: @@ -342,7 +345,8 @@ class EpochBasedTrainer(BaseTrainer): d.to_torch_dataset( task_data_config=task_data_config, task_name=self.cfg.task, - preprocessors=preprocessor) for d in datasets + preprocessors=preprocessor, + to_tensor=to_tensor) for d in datasets ] cfg = ConfigDict( type=self.cfg.model.type, mode=mode, datasets=datasets) @@ -497,6 +501,7 @@ class EpochBasedTrainer(BaseTrainer): dp_cfg = dict( type='DistributedDataParallel', module=model, + find_unused_parameters=True, device_ids=[torch.cuda.current_device()]) return build_parallel(dp_cfg) @@ -779,7 +784,7 @@ class EpochBasedTrainer(BaseTrainer): batch_size = batch_size_per_gpu num_workers = workers_per_gpu - if dist: + if dist and not isinstance(dataset, torch.utils.data.IterableDataset): sampler = DistributedSampler( dataset, num_replicas=world_size, rank=rank, shuffle=shuffle) else: diff --git a/modelscope/trainers/utils/inference.py b/modelscope/trainers/utils/inference.py index 7f5d4ec3..1f8f8ed0 100644 --- a/modelscope/trainers/utils/inference.py +++ b/modelscope/trainers/utils/inference.py @@ -69,7 +69,10 @@ def single_gpu_test(model, batch_size = 1 # iteration count else: if isinstance(data, dict): - batch_size = len(next(iter(data.values()))) + if 'nsentences' in data: + batch_size = data['nsentences'] + else: + batch_size = len(next(iter(data.values()))) else: batch_size = len(data) for _ in range(batch_size): @@ -152,21 +155,29 @@ def multi_gpu_test(model, result = model.forward(data) results.append(result) - if rank == 0: - if isinstance(data, dict): - batch_size = len(next(iter(data.values()))) + if isinstance(data, dict): + if 'nsentences' in data: + batch_size = data['nsentences'] else: - batch_size = len(data) - - if progress_with_iters: - total_samples += batch_size * world_size - batch_size = 1 # iteration count + batch_size = len(next(iter(data.values()))) + else: + batch_size = len(data) + if i >= (data_len // world_size) - 1: + total_samples = torch.LongTensor([batch_size]).to(model.device) + dist.all_reduce(total_samples, op=dist.reduce_op.SUM) + total_samples = total_samples.item() + else: + total_samples = batch_size * world_size + if progress_with_iters: + iter_cnt_all = world_size + else: + iter_cnt_all = total_samples + count += iter_cnt_all - batch_size_all = batch_size * world_size - count += batch_size_all + if rank == 0: if count > data_len: - batch_size_all = data_len - (count - batch_size_all) - for _ in range(batch_size_all): + iter_cnt_all = data_len - (count - iter_cnt_all) + for _ in range(iter_cnt_all): pbar.update() if progress_with_iters and (i + 1) >= data_len: diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 6a9d6fd5..86b7bb7d 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -280,6 +280,7 @@ class ConfigKeys(object): """Fixed keywords in configuration file""" train = 'train' val = 'val' + test = 'test' class Requirements(object): diff --git a/modelscope/utils/device.py b/modelscope/utils/device.py index 6fc59e37..83faa261 100644 --- a/modelscope/utils/device.py +++ 
b/modelscope/utils/device.py @@ -1,5 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. - +import os from contextlib import contextmanager from modelscope.utils.constant import Devices, Frameworks @@ -106,3 +106,17 @@ def create_device(device_name): device = torch.device('cpu') return device + + +def get_device(): + import torch + from torch import distributed as dist + if torch.cuda.is_available(): + if dist.is_available() and dist.is_initialized( + ) and 'LOCAL_RANK' in os.environ: + device_id = f"cuda:{os.environ['LOCAL_RANK']}" + else: + device_id = 'cuda:0' + else: + device_id = 'cpu' + return torch.device(device_id) diff --git a/modelscope/utils/multi_modal/fp16/__init__.py b/modelscope/utils/multi_modal/fp16/__init__.py new file mode 100644 index 00000000..81250858 --- /dev/null +++ b/modelscope/utils/multi_modal/fp16/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .fp16 import FP16_Module, FP16_Optimizer diff --git a/modelscope/utils/multi_modal/fp16/fp16.py b/modelscope/utils/multi_modal/fp16/fp16.py new file mode 100755 index 00000000..37a80e65 --- /dev/null +++ b/modelscope/utils/multi_modal/fp16/fp16.py @@ -0,0 +1,655 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Stable version of apex FP16 Optimizer""" +import torch +from torch import nn +from torch.autograd import Variable +from torch.nn.parameter import Parameter + +from .fp16util import (master_params_to_model_params, + model_grads_to_master_grads) +from .loss_scaler import DynamicLossScaler, LossScaler + +FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor) +HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor) + + +def conversion_helper(val, conversion): + """Apply conversion to val. 
Recursively apply conversion if `val` is a nested tuple/list structure.""" + if not isinstance(val, (tuple, list)): + return conversion(val) + rtn = [conversion_helper(v, conversion) for v in val] + if isinstance(val, tuple): + rtn = tuple(rtn) + return rtn + + +def fp32_to_fp16(val): + """Convert fp32 `val` to fp16""" + + def half_conversion(val): + val_typecheck = val + if isinstance(val_typecheck, (Parameter, Variable)): + val_typecheck = val.data + if isinstance(val_typecheck, FLOAT_TYPES): + val = val.half() + return val + + return conversion_helper(val, half_conversion) + + +def fp16_to_fp32(val): + """Convert fp16 `val` to fp32""" + + def float_conversion(val): + val_typecheck = val + if isinstance(val_typecheck, (Parameter, Variable)): + val_typecheck = val.data + if isinstance(val_typecheck, HALF_TYPES): + val = val.float() + return val + + return conversion_helper(val, float_conversion) + + +class FP16_Module(nn.Module): + + def __init__(self, module): + super(FP16_Module, self).__init__() + self.add_module('module', module.half()) + + def forward(self, *inputs, **kwargs): + return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs)) + + def state_dict(self, destination=None, prefix='', keep_vars=False): + return self.module.state_dict(destination, prefix, keep_vars) + + def load_state_dict(self, state_dict, strict=True): + self.module.load_state_dict(state_dict, strict=strict) + + +class FP16_Optimizer(object): + """ + :class:`FP16_Optimizer` is designed to wrap an existing PyTorch optimizer, + and manage static or dynamic loss scaling and master weights in a manner transparent to the user. + For standard use, only two lines must be changed: creating the :class:`FP16_Optimizer` instance, + and changing the call to ``backward``. + + Example:: + + model = torch.nn.Linear(D_in, D_out).cuda().half() + optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) + # Name the FP16_Optimizer instance to replace the existing optimizer + # (recommended but not required): + optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) + ... + # loss.backward() becomes: + optimizer.backward(loss) + ... + + Example with dynamic loss scaling:: + + ... + optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) + # optional arg to control dynamic loss scaling behavior + # dynamic_loss_args={'scale_window' : 500}) + # Usually, dynamic_loss_args is not necessary. + + Args: + init_optimizer (torch.optim.optimizer): Existing optimizer created with the parameters to optimize. Internally, :class:`FP16_Optimizer` replaces the passed optimizer's fp16 parameters, if any, with fp32 master parameters copied from the original ones. :class:`FP16_Optimizer` also stores references to the original fp16 parameters, and updates these fp16 parameters from the master fp32 copy at the end of each :attr:`step`. # noqa + static_loss_scale (float, optional, default=1.0): Loss scale used internally to scale gradients computed by the model. Any fp16 gradients will be copied to fp32, then downscaled before being applied to the fp32 master params, so ``static_loss_scale`` should not affect learning rate. # noqa + dynamic_loss_scale (bool, optional, default=False): Use dynamic loss scaling. If True, this will override any ``static_loss_scale`` option. # noqa + dynamic_loss_args (dict, optional, default=None): Dict of kwargs that will be forwarded to the internal :class:`DynamicLossScaler` instance's constructor. Keys of this dict must match kwargs accepted by :class:`DynamicLossScaler`'s constructor. 
If ``dynamic_loss_args`` is unspecified, :class:`DynamicLossScaler`'s defaults will be used. # noqa + verbose (bool, optional, default=True): By default, FP16_Optimizer's constructor prints out the parameters and parameter groups it is ingesting, as a sanity check. If this becomes annoying (e.g. for large models), it can be disabled by passing ``verbose=False``. ``verbose=False`` will not disable printing when the loss scale is readjusted during dynamic loss scaling. # noqa + + ``init_optimizer`` is expected to have been constructed in the ordinary way. + It is recommended (although not required) that the newly constructed :class:`FP16_Optimizer` instance be + named to replace ``init_optimizer``, for two reasons: + First, it means that references to the same name + later in the file will not have to change. + Second, :class:`FP16_Optimizer` reserves the right (as an implementation detail) to + modify ``init_optimizer``. If you do choose a unique name for the new + :class:`FP16_Optimizer` instance, you should only work with this new instance, + because the preexisting optimizer might no longer behave as expected. + + ``init_optimizer`` may be any Pytorch optimizer. + It may contain a mixture of fp16 and fp32 parameters organized into any number of + ``param_groups`` with different hyperparameters. The :class:`FP16_Optimizer` constructor will + ingest these ``param_groups`` and remember them. + + Calls to :: + + loss.backward() + + must be replaced with :: + + optimizer.backward(loss) + + because :class:`FP16_Optimizer` requires ownership of the backward pass to implement + loss scaling and copies to master gradients. + + .. note:: + Loss scaling, either static or dynamic, is orthogonal to learning rate, because gradients + are downscaled before being applied. This means that adjusting the loss scale, or using + dynamic loss scaling, should not require retuning the learning rate or any other + hyperparameters. + + + **Advanced options** + + **Closures**: :class:`FP16_Optimizer` can wrap a Pytorch optimizer that receives a closure. + See docstring for :attr:`step`. + + **Gradient clipping**: Use :attr:`clip_master_grads`. + + **Multiple losses**: If your model accumulates gradients from multiple losses, + this can be made more efficient by supplying ``update_master_grads=False`` + to :attr:`backward`. See docstring for :attr:`backward`. + + **Manually adjusting loss scale**: The current loss scale can be retrieved or set via :: + + print(optimizer.loss_scale) + optimizer.loss_scale = new_loss_scale + + For static loss scaling, manually adjusting the loss scale over time is a reasonable + thing to do. During later epochs, gradients may become smaller, and a + higher loss scale may be required, analogous to scheduling the learning rate. Dynamic loss + scaling is more subtle (see :class:`DynamicLossScaler`) and in this case, manually adjusting + the loss scale is not recommended. + + **Multi_GPU training**: If the wrapped ``init_optimizer`` was created from a model wrapped in + Pytorch DistributedDataParallel or Apex DistributedDataParallel, :class:`FP16_Optimizer` + should still work as intended. + """ + + def __init__(self, + init_optimizer, + static_loss_scale=1.0, + dynamic_loss_scale=False, + dynamic_loss_args=None, + verbose=False): + if not torch.cuda.is_available: + raise SystemError('Cannot use fp16 without CUDA.') + + self.verbose = verbose + + self.optimizer = init_optimizer + # init_state_dict sets up an alternative way to cast per-param state tensors. 
+ # Stashing here in case https://github.com/pytorch/pytorch/issues/7733 makes it necessary. + # init_state_dict = init_optimizer.state_dict() + + self.fp16_groups = [] + self.fp32_from_fp16_groups = [] + self.fp32_from_fp32_groups = [] + for i, param_group in enumerate(self.optimizer.param_groups): + self.maybe_print( + 'FP16_Optimizer processing param group {}:'.format(i)) + fp16_params_this_group = [] + fp32_params_this_group = [] + fp32_from_fp16_params_this_group = [] + for i, param in enumerate(param_group['params']): + if param.requires_grad: + if param.type() == 'torch.cuda.HalfTensor': + self.maybe_print( + 'FP16_Optimizer received torch.cuda.HalfTensor with {}' + .format(param.size())) + fp16_params_this_group.append(param) + master_param = param.detach().clone().float() + master_param.requires_grad = True + # Copythe model parallel flag. + master_param.model_parallel = param.model_parallel + param_group['params'][i] = master_param + fp32_from_fp16_params_this_group.append(master_param) + # Reset existing state dict key to the new master param. + # We still need to recast per-param state tensors, if any, to FP32. + if param in self.optimizer.state: + self.optimizer.state[ + master_param] = self.optimizer.state.pop(param) + elif param.type() == 'torch.cuda.FloatTensor': + self.maybe_print( + 'FP16_Optimizer received torch.cuda.FloatTensor with {}' + .format(param.size())) + fp32_params_this_group.append(param) + param_group['params'][i] = param + else: + raise TypeError( + 'Wrapped parameters must be either ' + 'torch.cuda.FloatTensor or torch.cuda.HalfTensor. ' + 'Received {}'.format(param.type())) + + self.fp16_groups.append(fp16_params_this_group) + self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group) + self.fp32_from_fp32_groups.append(fp32_params_this_group) + + # Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors + self.optimizer.load_state_dict(self.optimizer.state_dict()) + # alternative way to cast per-param state tensors: + # self.optimizer.load_state_dict(init_state_dict) + + if dynamic_loss_scale: + self.dynamic_loss_scale = True + if dynamic_loss_args is not None: + self.loss_scaler = DynamicLossScaler(**dynamic_loss_args) + else: + self.loss_scaler = DynamicLossScaler() + else: + self.dynamic_loss_scale = False + self.loss_scaler = LossScaler(static_loss_scale) + + self.overflow = False + self.first_closure_call_this_step = True + + self.clip_grad_norm = nn.utils.clip_grad.clip_grad_norm_ + + def maybe_print(self, msg): + if self.verbose: + print(msg) + + def __getstate__(self): + raise RuntimeError( + 'FP16_Optimizer should be serialized using state_dict().') + + def __setstate__(self, state): + raise RuntimeError( + 'FP16_Optimizer should be deserialized using load_state_dict().') + + def zero_grad(self, set_grads_to_None=False): + """ + Zero fp32 and fp16 parameter grads. + """ + # In principle, only the .grad attributes of the model params need to be zeroed, + # because gradients are copied into the FP32 master params. 
However, we zero + # all gradients owned by the optimizer, just to be safe: + for group in self.optimizer.param_groups: + for p in group['params']: + if set_grads_to_None: + p.grad = None + else: + if p.grad is not None: + p.grad.detach_() + p.grad.zero_() + + # Zero fp16 gradients owned by the model: + for fp16_group in self.fp16_groups: + for param in fp16_group: + if set_grads_to_None: + param.grad = None + else: + if param.grad is not None: + param.grad.detach_( + ) # as in torch.optim.optimizer.zero_grad() + param.grad.zero_() + + def _check_overflow(self): + params = [] + for group in self.fp16_groups: + for param in group: + params.append(param) + for group in self.fp32_from_fp32_groups: + for param in group: + params.append(param) + self.overflow = self.loss_scaler.has_overflow(params) + + def _update_scale(self, has_overflow=False): + self.loss_scaler.update_scale(has_overflow) + + def _master_params_to_model_params(self): + for fp16_group, fp32_from_fp16_group in zip( + self.fp16_groups, self.fp32_from_fp16_groups): + master_params_to_model_params(fp16_group, fp32_from_fp16_group) + + def _model_params_to_master_params(self): + for fp16_group, fp32_from_fp16_group in zip( + self.fp16_groups, self.fp32_from_fp16_groups): + master_params_to_model_params(fp32_from_fp16_group, fp16_group) + + # To consider: Integrate distributed with this wrapper by registering a hook on each variable + # that does the overflow check, gradient copy + downscale, and fp32 allreduce in a different stream. + def _model_grads_to_master_grads(self): + for fp16_group, fp32_from_fp16_group in zip( + self.fp16_groups, self.fp32_from_fp16_groups): + model_grads_to_master_grads(fp16_group, fp32_from_fp16_group) + + def _downscale_master(self): + if self.loss_scale != 1.0: + for group in self.optimizer.param_groups: + for param in group['params']: + if param.grad is not None: + param.grad.data.mul_(1. / self.loss_scale) + + def clip_master_grads(self, max_norm, norm_type=2): + """ + Clips fp32 master gradients via ``torch.nn.utils.clip_grad_norm``. + + Args: + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + + Returns: + Total norm of the current fp32 gradients (viewed as a single vector). + + .. warning:: + Returns -1 if the most recently computed fp16 gradients overflowed (that is, if ``self.overflow`` is ``True``). # noqa + """ + if not self.overflow: + fp32_params = [] + for param_group in self.optimizer.param_groups: + for param in param_group['params']: + fp32_params.append(param) + return self.clip_grad_norm(fp32_params, max_norm, norm_type) + else: + return -1 + + def state_dict(self): + """ + Returns a dict containing the current state of this :class:`FP16_Optimizer` instance. + This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict + of the contained Pytorch optimizer. 
+ Example:: + + checkpoint = {} + checkpoint['model'] = model.state_dict() + checkpoint['optimizer'] = optimizer.state_dict() + torch.save(checkpoint, "saved.pth") + """ + state_dict = {} + state_dict['loss_scaler'] = self.loss_scaler + state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale + state_dict['overflow'] = self.overflow + state_dict[ + 'first_closure_call_this_step'] = self.first_closure_call_this_step + state_dict['optimizer_state_dict'] = self.optimizer.state_dict() + state_dict['fp32_from_fp16'] = self.fp32_from_fp16_groups + return state_dict + + def load_state_dict(self, state_dict): + """ + Loads a state_dict created by an earlier call to state_dict(). + If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``, + whose parameters in turn came from ``model``, it is expected that the user + will call ``model.load_state_dict()`` before + ``fp16_optimizer_instance.load_state_dict()`` is called. + + Example:: + + model = torch.nn.Linear(D_in, D_out).cuda().half() + optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) + optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) + ... + checkpoint = torch.load("saved.pth") + model.load_state_dict(checkpoint['model']) + optimizer.load_state_dict(checkpoint['optimizer']) + """ + # I think it should actually be ok to reload the optimizer before the model. + self.loss_scaler = state_dict['loss_scaler'] + self.dynamic_loss_scale = state_dict['dynamic_loss_scale'] + self.overflow = state_dict['overflow'] + self.first_closure_call_this_step = state_dict[ + 'first_closure_call_this_step'] + self.optimizer.load_state_dict(state_dict['optimizer_state_dict']) + # At this point, the optimizer's references to the model's fp32 parameters are up to date. + # The optimizer's hyperparameters and internal buffers are also up to date. + # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still + # out of date. There are two options. + # 1: Refresh the master params from the model's fp16 params. + # This requires less storage but incurs precision loss. + # 2: Save and restore the fp32 master copies separately. + # We choose option 2. + # + # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device + # of their associated parameters, because it's possible those buffers might not exist yet in + # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been + # constructed in the same way as the one whose state_dict we are loading, the same master params + # are guaranteed to exist, so we can just copy_() from the saved master params. + for current_group, saved_group in zip(self.fp32_from_fp16_groups, + state_dict['fp32_from_fp16']): + for current, saved in zip(current_group, saved_group): + current.data.copy_(saved.data) + + def step(self, closure=None): # could add clip option. + """ + If no closure is supplied, :attr:`step` should be called after + ``fp16_optimizer_obj.backward(loss)``. + :attr:`step` updates the fp32 master copy of parameters using the optimizer supplied to + :class:`FP16_Optimizer`'s constructor, then copies the updated fp32 params into the fp16 params + originally referenced by :class:`FP16_Optimizer`'s constructor, so the user may immediately run + another forward pass using their model. + + If a closure is supplied, :attr:`step` may be called without a prior call to + :attr:`backward(loss)`. + This control flow is identical to `ordinary Pytorch optimizer use`_ with closures. 
+ However, the user should take care that any ``loss.backward()`` call within the closure + has been replaced by ``fp16_optimizer_obj.backward(loss)``. + + Args: + closure (optional): Closure that will be supplied to the underlying optimizer originally passed to :class:`FP16_Optimizer`'s constructor. closure should call :attr:`zero_grad()` on the :class:`FP16_Optimizer` object, compute the loss, call :attr:`backward(loss)`, and return the loss. # noqa + + Example with closure:: + + # optimizer is assumed to be an FP16_Optimizer object, previously constructed from an + # existing pytorch optimizer. + for input, target in dataset: + def closure(): + optimizer.zero_grad() + output = model(input) + loss = loss_fn(output, target) + # loss.backward() becomes: + optimizer.backward(loss) + return loss + optimizer.step(closure) + + .. warning:: + Currently, calling :attr:`step` with a closure is not compatible with dynamic loss scaling. + + .. _`ordinary Pytorch optimizer use`: + http://pytorch.org/docs/master/optim.html#optimizer-step-closure + """ + + scale = self.loss_scaler.loss_scale + self._update_scale(self.overflow) + + if self.overflow: + self.maybe_print( + 'OVERFLOW! Skipping step. Attempted loss scale: {}, reducing to {}' + .format(scale, self.loss_scale)) + return + + if closure is not None: + retval = self._step_with_closure(closure) + else: + retval = self.optimizer.step() + + self._master_params_to_model_params() + + return retval + + def _step_with_closure(self, closure): + + def wrapped_closure(): + # helpful for debugging + # print("Calling wrapped_closure, first_closure_call_this_step = {}" + # .format(self.first_closure_call_this_step)) + if self.first_closure_call_this_step: + # We expect that the fp16 params are initially fresh on entering self.step(), + # so _master_params_to_model_params() is unnecessary the first time wrapped_closure() + # is called within self.optimizer.step(). + self.first_closure_call_this_step = False + else: + # If self.optimizer.step() internally calls wrapped_closure more than once, + # it may update the fp32 params after each call. However, self.optimizer + # doesn't know about the fp16 params at all. If the fp32 params get updated, + # we can't rely on self.optimizer to refresh the fp16 params. We need + # to handle that manually: + self._master_params_to_model_params() + # Our API expects the user to give us ownership of the backward() call by + # replacing all calls to loss.backward() with optimizer.backward(loss). + # This requirement holds whether or not the call to backward() is made within a closure. + # If the user is properly calling optimizer.backward(loss) within "closure," + # calling closure() here will give the fp32 master params fresh gradients + # for the optimizer to play with, so all wrapped_closure needs to do is call + # closure() and return the loss. + temp_loss = closure() + while (self.overflow): + scale = self.loss_scaler.loss_scale + self._update_scale(self.overflow) + self.maybe_print( + 'OVERFLOW within closure! Skipping step. Attempted loss scale: {}, ' + 'reducing to {}'.format(scale, self.loss_scale)) + temp_loss = closure() + return temp_loss + + retval = self.optimizer.step(wrapped_closure) + + self.first_closure_call_this_step = True + + return retval + + def backward(self, loss, update_master_grads=True, retain_graph=False): + """ + :attr:`backward` performs the following conceptual steps: + + 1. fp32_loss = loss.float() (see first Note below) + 2. scaled_loss = fp32_loss*loss_scale + 3. 
scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's leaves (which may be fp16, fp32, or a mixture, depending how your model was defined). # noqa + 4. fp16 grads are then copied to the master params' ``.grad`` attributes (see second Note), which are guaranteed to be fp32. # noqa + 5. Finally, master grads are divided by loss_scale. + + In this way, after :attr:`backward`, the master params have fresh gradients, + and :attr:`step` may be called. + + .. note:: + :attr:`backward` internally converts the loss to fp32 before applying the loss scale. + This provides some additional safety against overflow if the user has supplied an + fp16 loss value. + However, for maximum overflow safety, the user should + compute the loss criterion (MSE, cross entropy, etc) in fp32 before supplying it to + :attr:`backward`. + + .. warning:: + The gradients found in a model's leaves after the call to + :attr:`backward` should not be regarded as valid in general, + because it's possible + they have been scaled (and in the case of dynamic loss scaling, + the scale factor may change over time). + If the user wants to inspect gradients after a call to :attr:`backward`, + only the master gradients should be regarded as valid. These can be retrieved via + :attr:`inspect_master_grad_data()`. + + Args: + loss: The loss output by the user's model. loss may be either float or half (but see first Note above). + update_master_grads (bool, optional, default=True): Option to copy fp16 grads to fp32 grads on this call. By setting this to False, the user can delay the copy, which is useful to eliminate redundant fp16->fp32 grad copies if :attr:`backward` is being called on multiple losses in one iteration. If set to False, the user becomes responsible for calling :attr:`update_master_grads` before calling :attr:`step`. # noqa + retain_graph (bool, optional, default=False): Forwards the usual ``retain_graph=True`` option to the internal call to ``loss.backward``. If ``retain_graph`` is being used to accumulate gradient values from multiple backward passes before calling ``optimizer.step``, passing ``update_master_grads=False`` is also recommended (see Example below). # noqa + + Example:: + + # Ordinary operation: + optimizer.backward(loss) + + # Naive operation with multiple losses (technically valid, but less efficient): + # fp32 grads will be correct after the second call, but + # the first call incurs an unnecessary fp16->fp32 grad copy. + optimizer.backward(loss1) + optimizer.backward(loss2) + + # More efficient way to handle multiple losses: + # The fp16->fp32 grad copy is delayed until fp16 grads from all + # losses have been accumulated. + optimizer.backward(loss1, update_master_grads=False) + optimizer.backward(loss2, update_master_grads=False) + optimizer.update_master_grads() + """ + # To consider: try multiple backward passes using retain_grad=True to find + # a loss scale that works. After you find a loss scale that works, do a final dummy + # backward pass with retain_graph=False to tear down the graph. Doing this would avoid + # discarding the iteration, but probably wouldn't improve overall efficiency. + self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) + if update_master_grads: + self.update_master_grads() + + def update_master_grads(self): + """ + Copy the ``.grad`` attribute from stored references to fp16 parameters to + the ``.grad`` attribute of the fp32 master parameters that are directly + updated by the optimizer. 
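The conceptual steps listed in the backward() docstring above boil down to a small round trip between fp16 model weights and their fp32 master copies. A condensed, class-free sketch of that flow, using plain tensors and a made-up static loss scale; the helper names in the comments refer to the FP16_Optimizer methods above:

import torch

# One fp16 "model" parameter and its fp32 master copy (cf. the FP16_Optimizer constructor).
fp16_param = torch.randn(4, dtype=torch.float16, requires_grad=True)
master_param = fp16_param.detach().clone().float().requires_grad_(True)
optimizer = torch.optim.SGD([master_param], lr=0.1)
loss_scale = 128.0

loss = (fp16_param.float() ** 2).sum()
(loss * loss_scale).backward()                      # scaled_loss.backward()

# _model_grads_to_master_grads + _downscale_master: copy the fp16 grad to fp32, then unscale it.
master_param.grad = fp16_param.grad.detach().float() / loss_scale
optimizer.step()                                    # the actual update happens in fp32

# _master_params_to_model_params: push the updated fp32 weights back into the fp16 copy.
with torch.no_grad():
    fp16_param.copy_(master_param.half())
fp16_param.grad = None                              # cf. zero_grad(set_grads_to_None=True)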
:attr:`update_master_grads` only needs to be called if + ``fp16_optimizer_obj.backward`` was called with ``update_master_grads=False``. + """ + if self.dynamic_loss_scale: + self._check_overflow() + if self.overflow: return # noqa + self._model_grads_to_master_grads() + self._downscale_master() + + def inspect_master_grad_data(self): + """ + When running with :class:`FP16_Optimizer`, + ``.grad`` attributes of a model's fp16 leaves should not be + regarded as truthful, because they might be scaled. + After a call to :attr:`fp16_optimizer_obj.backward(loss)`, if no overflow was encountered, + the fp32 master params' ``.grad`` + attributes will contain valid gradients properly divided by the loss scale. However, + because :class:`FP16_Optimizer` flattens some parameters, accessing them may be + nonintuitive. :attr:`inspect_master_grad_data` + allows those gradients to be viewed with shapes corresponding to their associated model leaves. + + Returns: + List of lists (one list for each parameter group). The list for each parameter group + is a list of the ``.grad.data`` attributes of the fp32 master params belonging to that group. + """ + if self.overflow: + print( + 'Warning: calling FP16_Optimizer.inspect_master_grad_data while in an overflow state. ' + 'Gradients are currently invalid (may be inf, nan, or stale). Returning None.' + ) + return None + else: + # The optimizer owns only references to master params. + master_grads_data = [] + for param_group in self.optimizer.param_groups: + master_grads_this_group = [] + for param in param_group['params']: + if param.grad is not None: + master_grads_this_group.append(param.grad.data) + else: + master_grads_this_group.append(None) + master_grads_data.append(master_grads_this_group) + return master_grads_data + + # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale" + def _get_loss_scale(self): + return self.loss_scaler.loss_scale + + def _set_loss_scale(self, value): + self.loss_scaler.cur_scale = value + + loss_scale = property(_get_loss_scale, _set_loss_scale) + + # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state" + def _get_state(self): + return self.optimizer.state + + def _set_state(self, value): + self.optimizer.state = value + + state = property(_get_state, _set_state) + + # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups" + # (for example, to adjust the learning rate) + def _get_param_groups(self): + return self.optimizer.param_groups + + def _set_param_groups(self, value): + self.optimizer.param_groups = value + + param_groups = property(_get_param_groups, _set_param_groups) diff --git a/modelscope/utils/multi_modal/fp16/fp16util.py b/modelscope/utils/multi_modal/fp16/fp16util.py new file mode 100644 index 00000000..29595a6c --- /dev/null +++ b/modelscope/utils/multi_modal/fp16/fp16util.py @@ -0,0 +1,216 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +import torch.nn as nn +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +from torch.autograd import Variable + + +class tofp16(nn.Module): + """ + Utility module that implements:: + + def forward(self, input): + return input.half() + """ + + def __init__(self): + super(tofp16, self).__init__() + + def forward(self, input): + return input.half() + + +def BN_convert_float(module): + """ + Utility function for network_to_half(). + + Retained for legacy purposes. + """ + if isinstance( + module, + torch.nn.modules.batchnorm._BatchNorm) and module.affine is True: + module.float() + for child in module.children(): + BN_convert_float(child) + return module + + +def network_to_half(network): + """ + Convert model to half precision in a batchnorm-safe way. + + Retained for legacy purposes. It is recommended to use FP16Model. + """ + return nn.Sequential(tofp16(), BN_convert_float(network.half())) + + +def convert_module(module, dtype): + """ + Converts a module's immediate parameters and buffers to dtype. + """ + for param in module.parameters(recurse=False): + if param is not None: + if param.data.dtype.is_floating_point: + param.data = param.data.to(dtype=dtype) + if param._grad is not None and param._grad.data.dtype.is_floating_point: + param._grad.data = param._grad.data.to(dtype=dtype) + + for buf in module.buffers(recurse=False): + if buf is not None and buf.data.dtype.is_floating_point: + buf.data = buf.data.to(dtype=dtype) + + +def convert_network(network, dtype): + """ + Converts a network's parameters and buffers to dtype. + """ + for module in network.modules(): + if isinstance(module, torch.nn.modules.batchnorm._BatchNorm + ) and module.affine is True: + continue + convert_module(module, dtype) + return network + + +class FP16Model(nn.Module): + """ + Convert model to half precision in a batchnorm-safe way. + """ + + def __init__(self, network): + super(FP16Model, self).__init__() + self.network = convert_network(network, dtype=torch.half) + + def forward(self, *inputs): + inputs = tuple(t.half() for t in inputs) + return self.network(*inputs) + + +def backwards_debug_hook(grad): + raise RuntimeError( + 'master_params recieved a gradient in the backward pass!') + + +def prep_param_lists(model, flat_master=False): + """ + Creates a list of FP32 master parameters for a given model, as in + `Training Neural Networks with Mixed Precision: Real Examples`_. + + Args: + model (torch.nn.Module): Existing Pytorch model + flat_master (bool, optional, default=False): Flatten the master parameters into a single tensor, as a performance optimization. # noqa + Returns: + A tuple (``model_params``, ``master_params``). ``model_params`` is a list of the model's parameters for later use with :func:`model_grads_to_master_grads` and :func:`master_params_to_model_params`. ``master_params`` is a list of FP32 master gradients. If ``flat_master=True``, ``master_params`` will be a list with one element. # noqa + + Example:: + + model_params, master_params = prep_param_lists(model) + + .. warning:: + Currently, if ``flat_master=True``, all the model's parameters must be the same type. If the model has parameters of different types, use ``flat_master=False``, or use :class:`FP16_Optimizer`. # noqa + + .. 
_`Training Neural Networks with Mixed Precision: Real Examples`: + http://on-demand.gputechconf.com/gtc/2018/video/S81012/ + """ + model_params = [ + param for param in model.parameters() if param.requires_grad + ] + + if flat_master: + # Give the user some more useful error messages + try: + # flatten_dense_tensors returns a contiguous flat array. + # http://pytorch.org/docs/master/_modules/torch/_utils.html + master_params = _flatten_dense_tensors( + [param.data for param in model_params]).float() + except: # noqa + print( + 'Error in prep_param_lists: model may contain a mixture of parameters ' + 'of different types. Use flat_master=False, or use F16_Optimizer.' + ) + raise + master_params = torch.nn.Parameter(master_params) + master_params.requires_grad = True + # master_params.register_hook(backwards_debug_hook) + if master_params.grad is None: + master_params.grad = master_params.new(*master_params.size()) + return model_params, [master_params] + else: + master_params = [ + param.clone().float().detach() for param in model_params + ] + for param in master_params: + param.requires_grad = True + return model_params, master_params + + +def model_grads_to_master_grads(model_params, + master_params, + flat_master=False): + """ + Copy model gradients to master gradients. + + Args: + model_params: List of model parameters created by :func:`prep_param_lists`. + master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`model_grads_to_master_grads`. # noqa + """ + if flat_master: + # The flattening may incur one more deep copy than is necessary. + master_params[0].grad.data.copy_( + _flatten_dense_tensors([p.grad.data for p in model_params])) + else: + for model, master in zip(model_params, master_params): + if model.grad is not None: + if master.grad is None: + master.grad = Variable( + master.data.new(*master.data.size())) + master.grad.data.copy_(model.grad.data) + else: + master.grad = None + + +def master_params_to_model_params(model_params, + master_params, + flat_master=False): + """ + Copy master parameters to model parameters. + + Args: + model_params: List of model parameters created by :func:`prep_param_lists`. + master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`master_params_to_model_params`. # noqa + """ + if flat_master: + for model, master in zip( + model_params, + _unflatten_dense_tensors(master_params[0].data, model_params)): + model.data.copy_(master) + else: + for model, master in zip(model_params, master_params): + model.data.copy_(master.data) + + +# Backward compatibility fixes + + +def to_python_float(t): + if hasattr(t, 'item'): + return t.item() + else: + return t[0] + + +TORCH_MAJOR = int(torch.__version__.split('.')[0]) +TORCH_MINOR = int(torch.__version__.split('.')[1]) diff --git a/modelscope/utils/multi_modal/fp16/loss_scaler.py b/modelscope/utils/multi_modal/fp16/loss_scaler.py new file mode 100755 index 00000000..fc55a4ed --- /dev/null +++ b/modelscope/utils/multi_modal/fp16/loss_scaler.py @@ -0,0 +1,237 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +# item() is a recent addition, so this helps with backward compatibility. +def to_python_float(t): + if hasattr(t, 'item'): + return t.item() + else: + return t[0] + + +class LossScaler: + """ + Class that manages a static loss scale. This class is intended to interact with + :class:`FP16_Optimizer`, and should not be directly manipulated by the user. + + Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to + :class:`FP16_Optimizer`'s constructor. + + Args: + scale (float, optional, default=1.0): The loss scale. + """ + + def __init__(self, scale=1): + self.cur_scale = scale + + # `params` is a list / generator of torch.Variable + def has_overflow(self, params): + return False + + # `x` is a torch.Tensor + def _has_inf_or_nan(x): + return False + + def update_scale(self, overflow): + pass + + @property + def loss_scale(self): + return self.cur_scale + + def scale_gradient(self, module, grad_in, grad_out): + return tuple(self.loss_scale * g for g in grad_in) + + def backward(self, loss, retain_graph=False): + scaled_loss = loss * self.loss_scale + scaled_loss.backward(retain_graph=retain_graph) + + +class DynamicLossScaler: + """ + Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler` + indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of + :class:`FP16_Optimizer`. However, it's important to understand how :class:`DynamicLossScaler` + operates, because the default options can be changed using the + the ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor. + + Loss scaling is designed to combat the problem of underflowing gradients encountered at long + times when training fp16 networks. Dynamic loss scaling begins by attempting a very high loss + scale. Ironically, this may result in OVERflowing gradients. If overflowing gradients are + encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has + occurred. + :class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch, + and :class:`DynamicLossScaler` adjusts the loss scale to a lower value. + If a certain number of iterations occur without overflowing gradients detected, + :class:`DynamicLossScaler` increases the loss scale once more. + In this way :class:`DynamicLossScaler` attempts to "ride the edge" of + always using the highest loss scale possible without incurring overflow. + + Args: + init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler.` + scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``. # noqa + scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale. 
# noqa + """ + + def __init__(self, + init_scale=2**32, + scale_factor=2., + scale_window=1000, + min_scale=1, + delayed_shift=1, + consecutive_hysteresis=False): + self.cur_scale = init_scale + self.cur_iter = 0 + self.last_overflow_iter = -1 + self.scale_factor = scale_factor + self.scale_window = scale_window + self.min_scale = min_scale + self.delayed_shift = delayed_shift + self.cur_hysteresis = delayed_shift + self.consecutive_hysteresis = consecutive_hysteresis + + # `params` is a list / generator of torch.Variable + def has_overflow_serial(self, params): + for p in params: + if p.grad is not None and DynamicLossScaler._has_inf_or_nan( + p.grad.data): + return True + + return False + + def has_overflow(self, params): + overflow = self.has_overflow_serial(params) + overflow_gpu = torch.cuda.ByteTensor([overflow]) + overflow = overflow_gpu[0].item() + return bool(overflow) + + # `x` is a torch.Tensor + def _has_inf_or_nan(x): + try: + # if x is half, the .float() incurs an additional deep copy, but it's necessary if + # Pytorch's .sum() creates a one-element tensor of the same type as x + # (which is true for some recent version of pytorch). + cpu_sum = float(x.float().sum()) + # More efficient version that can be used if .sum() returns a Python scalar + # cpu_sum = float(x.sum()) + except RuntimeError as instance: + # We want to check if inst is actually an overflow exception. + # RuntimeError could come from a different error. + # If so, we still want the exception to propagate. + if 'value cannot be converted' not in instance.args[0]: + raise + return True + else: + if cpu_sum == float( + 'inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: + return True + return False + + # `overflow` is boolean indicating whether the gradient overflowed + def update_scale(self, overflow): + + if not hasattr(self, 'min_scale'): + self.min_scale = 1 + if not hasattr(self, 'delayed_shift'): + self.delayed_shift = 1 + if not hasattr(self, 'cur_hysteresis'): + self.cur_hysteresis = 1 + if not hasattr(self, 'consecutive_hysteresis'): + self.consecutive_hysteresis = True + if overflow: + # self.cur_scale /= self.scale_factor + if self.delayed_shift == 1 or self.cur_hysteresis == 1: + self.cur_scale = max(self.cur_scale / self.scale_factor, + self.min_scale) + else: + self.cur_hysteresis -= 1 + self.last_overflow_iter = self.cur_iter + else: + if self.consecutive_hysteresis: + self.cur_hysteresis = self.delayed_shift + if (self.cur_iter + - self.last_overflow_iter) % self.scale_window == 0: + if not self.consecutive_hysteresis: + self.cur_hysteresis = self.delayed_shift + self.cur_scale *= self.scale_factor + self.cur_iter += 1 + + @property + def loss_scale(self): + return self.cur_scale + + def scale_gradient(self, module, grad_in, grad_out): + return tuple(self.loss_scale * g for g in grad_in) + + def backward(self, loss, retain_graph=False): + scaled_loss = loss * self.loss_scale + scaled_loss.backward(retain_graph=retain_graph) + + +############################################################## +# Example usage below here -- assuming it's in a separate file +############################################################## +""" +TO-DO separate out into an example. +if __name__ == "__main__": + import torch + from torch.autograd import Variable + from dynamic_loss_scaler import DynamicLossScaler + + # N is batch size; D_in is input dimension; + # H is hidden dimension; D_out is output dimension. 
diff --git a/requirements/multi-modal.txt b/requirements/multi-modal.txt
index 02e87baa..255f6155 100644
--- a/requirements/multi-modal.txt
+++ b/requirements/multi-modal.txt
@@ -5,6 +5,7 @@ pycocotools>=2.0.4
 # rough-score was just recently updated from 0.0.4 to 0.0.7
 # which introduced compatability issues that are being investigated
 rouge_score<=0.0.4
+sacrebleu
 taming-transformers-rom1504
 timm
 tokenizers
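The new `sacrebleu` requirement backs the `bleu` metric that the OFA trainer test below configures under `evaluation.metrics`. A quick, illustrative check of what that dependency provides (toy sentences, a recent sacrebleu release assumed; scores are on a 0-100 scale):

    import sacrebleu

    hyps = ['a dog runs on the beach']
    refs = [['a dog is running on the beach']]  # one reference stream, aligned with hyps
    print(sacrebleu.corpus_bleu(hyps, refs).score)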
diff --git a/tests/trainers/test_ofa_trainer.py b/tests/trainers/test_ofa_trainer.py
new file mode 100644
index 00000000..06003625
--- /dev/null
+++ b/tests/trainers/test_ofa_trainer.py
@@ -0,0 +1,105 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import shutil
+import unittest
+
+import json
+
+from modelscope.metainfo import Metrics, Trainers
+from modelscope.msdatasets import MsDataset
+from modelscope.trainers import build_trainer
+from modelscope.utils.constant import ModelFile
+from modelscope.utils.test_utils import test_level
+
+
+class TestOfaTrainer(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.finetune_cfg = \
+            {'framework': 'pytorch',
+             'task': 'image-captioning',
+             'model': {'type': 'ofa',
+                       'beam_search': {'beam_size': 5,
+                                       'max_len_b': 16,
+                                       'min_len': 1,
+                                       'no_repeat_ngram_size': 0},
+                       'seed': 7,
+                       'max_src_length': 256,
+                       'language': 'en',
+                       'gen_type': 'generation',
+                       'patch_image_size': 480,
+                       'max_image_size': 480,
+                       'imagenet_default_mean_and_std': False},
+             'pipeline': {'type': 'image-captioning'},
+             'dataset': {'column_map': {'text': 'caption'}},
+             'train': {'work_dir': 'work/ckpts/caption',
+                       # 'launcher': 'pytorch',
+                       'max_epochs': 1,
+                       'use_fp16': True,
+                       'dataloader': {'batch_size_per_gpu': 1, 'workers_per_gpu': 0},
+                       'lr_scheduler': {'name': 'polynomial_decay',
+                                        'warmup_proportion': 0.01,
+                                        'lr_end': 1e-07},
+                       'lr_scheduler_hook': {'type': 'LrSchedulerHook', 'by_epoch': False},
+                       'optimizer': {'type': 'AdamW', 'lr': 5e-05, 'weight_decay': 0.01},
+                       'optimizer_hook': {'type': 'TorchAMPOptimizerHook',
+                                          'cumulative_iters': 1,
+                                          'grad_clip': {'max_norm': 1.0, 'norm_type': 2},
+                                          'loss_keys': 'loss'},
+                       'criterion': {'name': 'AdjustLabelSmoothedCrossEntropyCriterion',
+                                     'constraint_range': None,
+                                     'drop_worst_after': 0,
+                                     'drop_worst_ratio': 0.0,
+                                     'ignore_eos': False,
+                                     'ignore_prefix_size': 0,
+                                     'label_smoothing': 0.1,
+                                     'reg_alpha': 1.0,
+                                     'report_accuracy': False,
+                                     'sample_patch_num': 196,
+                                     'sentence_avg': False,
+                                     'use_rdrop': False},
+                       'hooks': [{'type': 'BestCkptSaverHook',
+                                  'metric_key': 'bleu-4',
+                                  'interval': 100},
+                                 {'type': 'TextLoggerHook', 'interval': 1},
+                                 {'type': 'IterTimerHook'},
+                                 {'type': 'EvaluationHook', 'by_epoch': True, 'interval': 1}]},
+             'evaluation': {'dataloader': {'batch_size_per_gpu': 4, 'workers_per_gpu': 0},
+                            'metrics': [{'type': 'bleu',
+                                         'eval_tokenized_bleu': False,
+                                         'ref_name': 'labels',
+                                         'hyp_name': 'caption'}]},
+             'preprocessor': []}
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_trainer_std(self):
+        WORKSPACE = './workspace/ckpts/caption'
+        os.makedirs(WORKSPACE, exist_ok=True)
+        config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION)
+        with open(config_file, 'w') as writer:
+            json.dump(self.finetune_cfg, writer)
+
+        pretrained_model = 'damo/ofa_image-caption_coco_distilled_en'
+        args = dict(
+            model=pretrained_model,
+            work_dir=WORKSPACE,
+            train_dataset=MsDataset.load(
+                'coco_2014_caption',
+                namespace='modelscope',
+                split='train[:20]'),
+            eval_dataset=MsDataset.load(
+                'coco_2014_caption',
+                namespace='modelscope',
+                split='validation[:10]'),
+            metrics=[Metrics.BLEU],
+            cfg_file=config_file)
+        trainer = build_trainer(name=Trainers.ofa, default_args=args)
+        trainer.train()
+
+        self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE,
+                      os.listdir(os.path.join(WORKSPACE, 'output')))
+        shutil.rmtree(WORKSPACE)
+
+
+if __name__ == '__main__':
+    unittest.main()
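As a usage note (a sketch under stated assumptions, not part of the patch): once the configuration file has been written to the work dir exactly as in `setUp()` above, the same trainer can be driven for an evaluation-only pass. `trainer.evaluate()` and the exact result keys are assumptions inferred from this configuration (the `BestCkptSaverHook` above tracks `bleu-4`), not behaviour the patch itself asserts.

    # Sketch only: mirrors test_trainer_std() above, but skips training.
    import os

    from modelscope.metainfo import Metrics, Trainers
    from modelscope.msdatasets import MsDataset
    from modelscope.trainers import build_trainer
    from modelscope.utils.constant import ModelFile

    WORKSPACE = './workspace/ckpts/caption'  # same layout as the test
    args = dict(
        model='damo/ofa_image-caption_coco_distilled_en',
        work_dir=WORKSPACE,
        eval_dataset=MsDataset.load(
            'coco_2014_caption', namespace='modelscope',
            split='validation[:10]'),
        metrics=[Metrics.BLEU],
        cfg_file=os.path.join(WORKSPACE, ModelFile.CONFIGURATION))
    trainer = build_trainer(name=Trainers.ofa, default_args=args)
    print(trainer.evaluate())  # expected to include the 'bleu-4' score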