1. Add F1 score to sequence classification metric
2. Fix a bug where the trainer's evaluate method could not load a checkpoint that is a plain pytorch_model.bin (a bare state dict without optimizer, lr_scheduler, or meta information)
3. Fix a bug in evaluation of veco trainer
4. Log an error hint when building the optimizer or lr_scheduler in the trainer fails, suggesting that the configured native torch class may require a newer torch version
5. Add docstrings (arguments, return value, and examples) to `Model.from_pretrained` and `Preprocessor.from_pretrained`
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10532230
master
| @@ -3,6 +3,7 @@ | |||||
| from typing import Dict | from typing import Dict | ||||
| import numpy as np | import numpy as np | ||||
| from sklearn.metrics import accuracy_score, f1_score | |||||
| from modelscope.metainfo import Metrics | from modelscope.metainfo import Metrics | ||||
| from modelscope.outputs import OutputKeys | from modelscope.outputs import OutputKeys | ||||
| @@ -41,5 +42,11 @@ class SequenceClassificationMetric(Metric): | |||||
| preds = np.argmax(preds, axis=1) | preds = np.argmax(preds, axis=1) | ||||
| return { | return { | ||||
| MetricKeys.ACCURACY: | MetricKeys.ACCURACY: | ||||
| (preds == labels).astype(np.float32).mean().item() | |||||
| accuracy_score(labels, preds), | |||||
| MetricKeys.F1: | |||||
| f1_score( | |||||
| labels, | |||||
| preds, | |||||
| average='micro' if any([label > 1 | |||||
| for label in labels]) else None), | |||||
| } | } | ||||
| @@ -67,8 +67,28 @@ class Model(ABC): | |||||
| cfg_dict: Config = None, | cfg_dict: Config = None, | ||||
| device: str = None, | device: str = None, | ||||
| **kwargs): | **kwargs): | ||||
| """ Instantiate a model from local directory or remote model repo. Note | |||||
| """Instantiate a model from local directory or remote model repo. Note | |||||
| that when loading from remote, the model revision can be specified. | that when loading from remote, the model revision can be specified. | ||||
| Args: | |||||
| model_name_or_path(str): A model dir or a model id to be loaded | |||||
| revision(str, `optional`): The revision used when the model_name_or_path is | |||||
| a model id of the remote hub. default `master`. | |||||
| cfg_dict(Config, `optional`): An optional model config. If provided, it will replace | |||||
| the config read out of the `model_name_or_path` | |||||
| device(str, `optional`): The device to load the model. | |||||
| **kwargs: | |||||
| task(str, `optional`): The `Tasks` enumeration value to replace the task value | |||||
| read out of config in the `model_name_or_path`. This is useful when the model to be loaded is not | |||||
| equal to the model saved. | |||||
| For example, load a `backbone` into a `text-classification` model. | |||||
| Other kwargs will be directly fed into the `model` key, to replace the default configs. | |||||
| Returns: | |||||
| A model instance. | |||||
| Examples: | |||||
| >>> from modelscope.models import Model | |||||
| >>> Model.from_pretrained('damo/nlp_structbert_backbone_base_std', task='text-classification') | |||||
| """ | """ | ||||
| prefetched = kwargs.get('model_prefetched') | prefetched = kwargs.get('model_prefetched') | ||||
| if prefetched is not None: | if prefetched is not None: | ||||
| @@ -288,8 +288,8 @@ class InvariantPointAttention(nn.Module): | |||||
| pt_att *= pt_att | pt_att *= pt_att | ||||
| pt_att = pt_att.sum(dim=-1) | pt_att = pt_att.sum(dim=-1) | ||||
| head_weights = self.softplus(self.head_weights).view( | |||||
| *((1, ) * len(pt_att.shape[:-2]) + (-1, 1))) | |||||
| head_weights = self.softplus(self.head_weights).view( # noqa | |||||
| *((1, ) * len(pt_att.shape[:-2]) + (-1, 1))) # noqa | |||||
| head_weights = head_weights * math.sqrt( | head_weights = head_weights * math.sqrt( | ||||
| 1.0 / (3 * (self.num_qk_points * 9.0 / 2))) | 1.0 / (3 * (self.num_qk_points * 9.0 / 2))) | ||||
| pt_att *= head_weights * (-0.5) | pt_att *= head_weights * (-0.5) | ||||
| @@ -147,8 +147,50 @@ class Preprocessor(ABC): | |||||
| cfg_dict: Config = None, | cfg_dict: Config = None, | ||||
| preprocessor_mode=ModeKeys.INFERENCE, | preprocessor_mode=ModeKeys.INFERENCE, | ||||
| **kwargs): | **kwargs): | ||||
| """ Instantiate a model from local directory or remote model repo. Note | |||||
| """Instantiate a preprocessor from local directory or remote model repo. Note | |||||
| that when loading from remote, the model revision can be specified. | that when loading from remote, the model revision can be specified. | ||||
| Args: | |||||
| model_name_or_path(str): A model dir or a model id used to load the preprocessor out. | |||||
| revision(str, `optional`): The revision used when the model_name_or_path is | |||||
| a model id of the remote hub. default `master`. | |||||
| cfg_dict(Config, `optional`): An optional config. If provided, it will replace | |||||
| the config read out of the `model_name_or_path` | |||||
| preprocessor_mode(str, `optional`): Specify the working mode of the preprocessor, can be `train`, `eval`, | |||||
| or `inference`. Default value `inference`. | |||||
| The preprocessor field in the config may contain two sub preprocessors: | |||||
| >>> { | |||||
| >>> "train": { | |||||
| >>> "type": "some-train-preprocessor" | |||||
| >>> }, | |||||
| >>> "val": { | |||||
| >>> "type": "some-eval-preprocessor" | |||||
| >>> } | |||||
| >>> } | |||||
| In this scenario, the `train` preprocessor will be loaded in the `train` mode, the `val` preprocessor | |||||
| will be loaded in the `eval` or `inference` mode. The `mode` field in the preprocessor class | |||||
| will be assigned in all the modes. | |||||
| Or just one: | |||||
| >>> { | |||||
| >>> "type": "some-train-preprocessor" | |||||
| >>> } | |||||
| In this scenario, the sole preprocessor will be loaded in all the modes, | |||||
| and the `mode` field in the preprocessor class will be assigned. | |||||
| **kwargs: | |||||
| task(str, `optional`): The `Tasks` enumeration value to replace the task value | |||||
| read out of config in the `model_name_or_path`. | |||||
| This is useful when the preprocessor does not have a `type` field and the task to be used is not | |||||
| equal to the task of which the model is saved. | |||||
| Other kwargs will be directly fed into the preprocessor, to replace the default configs. | |||||
| Returns: | |||||
| The preprocessor instance. | |||||
| Examples: | |||||
| >>> from modelscope.preprocessors import Preprocessor | |||||
| >>> Preprocessor.from_pretrained('damo/nlp_debertav2_fill-mask_chinese-base') | |||||
| """ | """ | ||||
| if not os.path.exists(model_name_or_path): | if not os.path.exists(model_name_or_path): | ||||
| model_dir = snapshot_download( | model_dir = snapshot_download( | ||||
| @@ -101,8 +101,9 @@ class CheckpointHook(Hook): | |||||
| model = trainer.model.module | model = trainer.model.module | ||||
| else: | else: | ||||
| model = trainer.model | model = trainer.model | ||||
| meta = load_checkpoint(filename, model, trainer.optimizer, | |||||
| trainer.lr_scheduler) | |||||
| meta = load_checkpoint(filename, model, | |||||
| getattr(trainer, 'optimizer', None), | |||||
| getattr(trainer, 'lr_scheduler', None)) | |||||
| trainer._epoch = meta.get('epoch', trainer._epoch) | trainer._epoch = meta.get('epoch', trainer._epoch) | ||||
| trainer._iter = meta.get('iter', trainer._iter) | trainer._iter = meta.get('iter', trainer._iter) | ||||
| trainer._inner_iter = meta.get('inner_iter', trainer._inner_iter) | trainer._inner_iter = meta.get('inner_iter', trainer._inner_iter) | ||||
| @@ -111,7 +112,7 @@ class CheckpointHook(Hook): | |||||
| # hook: Hook | # hook: Hook | ||||
| key = f'{hook.__class__}-{i}' | key = f'{hook.__class__}-{i}' | ||||
| if key in meta and hasattr(hook, 'load_state_dict'): | if key in meta and hasattr(hook, 'load_state_dict'): | ||||
| hook.load_state_dict(meta[key]) | |||||
| hook.load_state_dict(meta.get(key, {})) | |||||
| else: | else: | ||||
| trainer.logger.warn( | trainer.logger.warn( | ||||
| f'The state_dict of hook {hook.__class__} at index {i} is not found in the checkpoint file.' | f'The state_dict of hook {hook.__class__} at index {i} is not found in the checkpoint file.' | ||||
| @@ -123,7 +124,7 @@ class CheckpointHook(Hook): | |||||
| f'The modelscope version of loaded checkpoint does not match the runtime version. ' | f'The modelscope version of loaded checkpoint does not match the runtime version. ' | ||||
| f'The saved version: {version}, runtime version: {__version__}' | f'The saved version: {version}, runtime version: {__version__}' | ||||
| ) | ) | ||||
| trainer.logger.warn( | |||||
| trainer.logger.info( | |||||
| f'Checkpoint {filename} saving time: {meta.get("time")}') | f'Checkpoint {filename} saving time: {meta.get("time")}') | ||||
| return meta | return meta | ||||
| @@ -646,7 +646,9 @@ class VecoTrainer(NlpEpochBasedTrainer): | |||||
| break | break | ||||
| for metric_name in self.metrics: | for metric_name in self.metrics: | ||||
| metric_values[metric_name] = np.average( | |||||
| [m[metric_name] for m in metric_values.values()]) | |||||
| all_metrics = [m[metric_name] for m in metric_values.values()] | |||||
| for key in all_metrics[0].keys(): | |||||
| metric_values[key] = np.average( | |||||
| [metric[key] for metric in all_metrics]) | |||||
| return metric_values | return metric_values | ||||
| @@ -667,10 +667,25 @@ class EpochBasedTrainer(BaseTrainer): | |||||
| return dataset | return dataset | ||||
| def build_optimizer(self, cfg: ConfigDict, default_args: dict = None): | def build_optimizer(self, cfg: ConfigDict, default_args: dict = None): | ||||
| return build_optimizer(self.model, cfg=cfg, default_args=default_args) | |||||
| try: | |||||
| return build_optimizer( | |||||
| self.model, cfg=cfg, default_args=default_args) | |||||
| except KeyError as e: | |||||
| self.logger.error( | |||||
| f'Build optimizer error, the optimizer {cfg} is native torch optimizer, ' | |||||
| f'please check if your torch with version: {torch.__version__} matches the config.' | |||||
| ) | |||||
| raise e | |||||
| def build_lr_scheduler(self, cfg: ConfigDict, default_args: dict = None): | def build_lr_scheduler(self, cfg: ConfigDict, default_args: dict = None): | ||||
| return build_lr_scheduler(cfg=cfg, default_args=default_args) | |||||
| try: | |||||
| return build_lr_scheduler(cfg=cfg, default_args=default_args) | |||||
| except KeyError as e: | |||||
| self.logger.error( | |||||
| f'Build lr_scheduler error, the lr_scheduler {cfg} is native torch lr_scheduler, ' | |||||
| f'please check if your torch with version: {torch.__version__} matches the config.' | |||||
| ) | |||||
| raise e | |||||
| def create_optimizer_and_scheduler(self): | def create_optimizer_and_scheduler(self): | ||||
| """ Create optimizer and lr scheduler | """ Create optimizer and lr scheduler | ||||
| @@ -134,9 +134,7 @@ def load_checkpoint(filename, | |||||
| state_dict = checkpoint if 'state_dict' not in checkpoint else checkpoint[ | state_dict = checkpoint if 'state_dict' not in checkpoint else checkpoint[ | ||||
| 'state_dict'] | 'state_dict'] | ||||
| model.load_state_dict(state_dict) | model.load_state_dict(state_dict) | ||||
| if 'meta' in checkpoint: | |||||
| return checkpoint.get('meta', {}) | |||||
| return checkpoint.get('meta', {}) | |||||
| def save_pretrained(model, | def save_pretrained(model, | ||||
| @@ -0,0 +1,32 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import unittest | |||||
| import numpy as np | |||||
| from modelscope.metrics.sequence_classification_metric import \ | |||||
| SequenceClassificationMetric | |||||
| from modelscope.utils.test_utils import test_level | |||||
| class TestTextClsMetrics(unittest.TestCase): | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_value(self): | |||||
| metric = SequenceClassificationMetric() | |||||
| outputs = { | |||||
| 'logits': | |||||
| np.array([[2.0, 1.0, 0.5], [1.0, 1.5, 1.0], [2.0, 1.0, 3.0], | |||||
| [2.4, 1.5, 4.0], [2.0, 1.0, 3.0], [2.4, 1.5, 1.7], | |||||
| [2.0, 1.0, 0.5], [2.4, 1.5, 0.5]]) | |||||
| } | |||||
| inputs = {'labels': np.array([0, 1, 2, 2, 0, 1, 2, 2])} | |||||
| metric.add(outputs, inputs) | |||||
| ret = metric.evaluate() | |||||
| self.assertTrue(np.isclose(ret['f1'], 0.5)) | |||||
| self.assertTrue(np.isclose(ret['accuracy'], 0.5)) | |||||
| print(ret) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||
| @@ -346,7 +346,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase): | |||||
| train_datasets = [] | train_datasets = [] | ||||
| from datasets import DownloadConfig | from datasets import DownloadConfig | ||||
| dc = DownloadConfig() | dc = DownloadConfig() | ||||
| dc.local_files_only = True | |||||
| dc.local_files_only = False | |||||
| for lang in langs: | for lang in langs: | ||||
| train_datasets.append( | train_datasets.append( | ||||
| load_dataset('xnli', lang, split='train', download_config=dc)) | load_dataset('xnli', lang, split='train', download_config=dc)) | ||||
| @@ -223,13 +223,31 @@ class TestTrainerWithNlp(unittest.TestCase): | |||||
| trainer, 'trainer_continue_train', level='strict'): | trainer, 'trainer_continue_train', level='strict'): | ||||
| trainer.train(os.path.join(self.tmp_dir, 'iter_3.pth')) | trainer.train(os.path.join(self.tmp_dir, 'iter_3.pth')) | ||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
| def test_trainer_with_evaluation(self): | |||||
| tmp_dir = tempfile.TemporaryDirectory().name | |||||
| if not os.path.exists(tmp_dir): | |||||
| os.makedirs(tmp_dir) | |||||
| model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny' | |||||
| cache_path = snapshot_download(model_id) | |||||
| model = SbertForSequenceClassification.from_pretrained(cache_path) | |||||
| kwargs = dict( | |||||
| cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), | |||||
| model=model, | |||||
| eval_dataset=self.dataset, | |||||
| work_dir=self.tmp_dir) | |||||
| trainer = build_trainer(default_args=kwargs) | |||||
| print(trainer.evaluate(cache_path + '/pytorch_model.bin')) | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | ||||
| def test_trainer_with_model_and_args(self): | def test_trainer_with_model_and_args(self): | ||||
| tmp_dir = tempfile.TemporaryDirectory().name | tmp_dir = tempfile.TemporaryDirectory().name | ||||
| if not os.path.exists(tmp_dir): | if not os.path.exists(tmp_dir): | ||||
| os.makedirs(tmp_dir) | os.makedirs(tmp_dir) | ||||
| model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' | |||||
| model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny' | |||||
| cache_path = snapshot_download(model_id) | cache_path = snapshot_download(model_id) | ||||
| model = SbertForSequenceClassification.from_pretrained(cache_path) | model = SbertForSequenceClassification.from_pretrained(cache_path) | ||||
| kwargs = dict( | kwargs = dict( | ||||