1. Add F1 score to the sequence classification metric
2. Fix a bug where the trainer's evaluate method did not support a plain pytorch_model.bin checkpoint (a usage sketch follows this list)
3. Fix a bug in the VecoTrainer evaluation
4. Add tips for the case where the lr_scheduler configured in the trainer requires a newer torch version
5. Add some comments
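A minimal usage sketch for item 2, evaluating directly from a plain pytorch_model.bin (the path and the already-prepared model, eval_dataset and work_dir are only illustrative; the new unit test in the diff below shows the exact call):
>>> from modelscope.trainers import build_trainer
>>> kwargs = dict(model=model, eval_dataset=eval_dataset, work_dir=work_dir)
>>> trainer = build_trainer(default_args=kwargs)
>>> # a raw state_dict checkpoint without optimizer/lr_scheduler state is now accepted
>>> trainer.evaluate('/path/to/pytorch_model.bin')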
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10532230
| @@ -3,6 +3,7 @@ | |||
| from typing import Dict | |||
| import numpy as np | |||
| from sklearn.metrics import accuracy_score, f1_score | |||
| from modelscope.metainfo import Metrics | |||
| from modelscope.outputs import OutputKeys | |||
| @@ -41,5 +42,11 @@ class SequenceClassificationMetric(Metric): | |||
| preds = np.argmax(preds, axis=1) | |||
| return { | |||
| MetricKeys.ACCURACY: | |||
| (preds == labels).astype(np.float32).mean().item() | |||
| accuracy_score(labels, preds), | |||
| MetricKeys.F1: | |||
| f1_score( | |||
| labels, | |||
| preds, | |||
| average='micro' if any([label > 1 | |||
| for label in labels]) else None), | |||
| } | |||
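For a single-label multi-class task, micro-averaged F1 reduces to plain accuracy, which is why the new unit test at the end of this review expects both metrics to equal 0.5. A quick sanity check with sklearn, reusing the labels and argmax predictions from that test:
>>> from sklearn.metrics import accuracy_score, f1_score
>>> labels = [0, 1, 2, 2, 0, 1, 2, 2]
>>> preds = [0, 1, 2, 2, 2, 0, 0, 0]
>>> accuracy_score(labels, preds)   # 0.5
>>> f1_score(labels, preds, average='micro')   # also 0.5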
| @@ -67,8 +67,28 @@ class Model(ABC): | |||
| cfg_dict: Config = None, | |||
| device: str = None, | |||
| **kwargs): | |||
| """ Instantiate a model from local directory or remote model repo. Note | |||
| """Instantiate a model from local directory or remote model repo. Note | |||
| that when loading from remote, the model revision can be specified. | |||
| Args: | |||
| model_name_or_path(str): A model dir or a model id to be loaded | |||
| revision(str, `optional`): The revision used when `model_name_or_path` is | |||
| a model id of the remote hub. Defaults to `master`. | |||
| cfg_dict(Config, `optional`): An optional model config. If provided, it replaces | |||
| the config read out of `model_name_or_path`. | |||
| device(str, `optional`): The device to load the model. | |||
| **kwargs: | |||
| task(str, `optional`): The `Tasks` enumeration value to replace the task value | |||
| read out of the config in `model_name_or_path`. This is useful when the model to be built | |||
| differs from the model that was saved, | |||
| for example, loading a saved `backbone` as a `text-classification` model. | |||
| Any other kwargs are fed directly into the `model` config key to override the default values. | |||
| Returns: | |||
| A model instance. | |||
| Examples: | |||
| >>> from modelscope.models import Model | |||
| >>> Model.from_pretrained('damo/nlp_structbert_backbone_base_std', task='text-classification') | |||
| """ | |||
| prefetched = kwargs.get('model_prefetched') | |||
| if prefetched is not None: | |||
| @@ -288,8 +288,8 @@ class InvariantPointAttention(nn.Module): | |||
| pt_att *= pt_att | |||
| pt_att = pt_att.sum(dim=-1) | |||
| head_weights = self.softplus(self.head_weights).view( | |||
| *((1, ) * len(pt_att.shape[:-2]) + (-1, 1))) | |||
| head_weights = self.softplus(self.head_weights).view( # noqa | |||
| *((1, ) * len(pt_att.shape[:-2]) + (-1, 1))) # noqa | |||
| head_weights = head_weights * math.sqrt( | |||
| 1.0 / (3 * (self.num_qk_points * 9.0 / 2))) | |||
| pt_att *= head_weights * (-0.5) | |||
| @@ -147,8 +147,50 @@ class Preprocessor(ABC): | |||
| cfg_dict: Config = None, | |||
| preprocessor_mode=ModeKeys.INFERENCE, | |||
| **kwargs): | |||
| """ Instantiate a model from local directory or remote model repo. Note | |||
| """Instantiate a preprocessor from local directory or remote model repo. Note | |||
| that when loading from remote, the model revision can be specified. | |||
| Args: | |||
| model_name_or_path(str): A model dir or a model id from which to load the preprocessor. | |||
| revision(str, `optional`): The revision used when `model_name_or_path` is | |||
| a model id of the remote hub. Defaults to `master`. | |||
| cfg_dict(Config, `optional`): An optional config. If provided, it replaces | |||
| the config read out of `model_name_or_path`. | |||
| preprocessor_mode(str, `optional`): The working mode of the preprocessor; one of `train`, `eval`, | |||
| or `inference`. Defaults to `inference`. | |||
| The preprocessor field in the config may contain two sub-preprocessors: | |||
| >>> { | |||
| >>> "train": { | |||
| >>> "type": "some-train-preprocessor" | |||
| >>> }, | |||
| >>> "val": { | |||
| >>> "type": "some-eval-preprocessor" | |||
| >>> } | |||
| >>> } | |||
| In this scenario, the `train` preprocessor will be loaded in `train` mode, and the `val` preprocessor | |||
| will be loaded in `eval` or `inference` mode. The `mode` field of the preprocessor instance | |||
| will be set in all modes. | |||
| Or just one: | |||
| >>> { | |||
| >>> "type": "some-train-preprocessor" | |||
| >>> } | |||
| In this scenario, the single preprocessor will be loaded in all modes, | |||
| and its `mode` field will be set accordingly. | |||
| **kwargs: | |||
| task(str, `optional`): The `Tasks` enumeration value to replace the task value | |||
| read out of config in the `model_name_or_path`. | |||
| This is useful when the preprocessor does not have a `type` field and the task to be used | |||
| differs from the task the model was saved for. | |||
| Any other kwargs are fed directly into the preprocessor to override its default configs. | |||
| Returns: | |||
| The preprocessor instance. | |||
| Examples: | |||
| >>> from modelscope.preprocessors import Preprocessor | |||
| >>> Preprocessor.from_pretrained('damo/nlp_debertav2_fill-mask_chinese-base') | |||
| """ | |||
| if not os.path.exists(model_name_or_path): | |||
| model_dir = snapshot_download( | |||
| @@ -101,8 +101,9 @@ class CheckpointHook(Hook): | |||
| model = trainer.model.module | |||
| else: | |||
| model = trainer.model | |||
| meta = load_checkpoint(filename, model, trainer.optimizer, | |||
| trainer.lr_scheduler) | |||
| meta = load_checkpoint(filename, model, | |||
| getattr(trainer, 'optimizer', None), | |||
| getattr(trainer, 'lr_scheduler', None)) | |||
| trainer._epoch = meta.get('epoch', trainer._epoch) | |||
| trainer._iter = meta.get('iter', trainer._iter) | |||
| trainer._inner_iter = meta.get('inner_iter', trainer._inner_iter) | |||
| @@ -111,7 +112,7 @@ class CheckpointHook(Hook): | |||
| # hook: Hook | |||
| key = f'{hook.__class__}-{i}' | |||
| if key in meta and hasattr(hook, 'load_state_dict'): | |||
| hook.load_state_dict(meta[key]) | |||
| hook.load_state_dict(meta.get(key, {})) | |||
| else: | |||
| trainer.logger.warn( | |||
| f'The state_dict of hook {hook.__class__} at index {i} is not found in the checkpoint file.' | |||
| @@ -123,7 +124,7 @@ class CheckpointHook(Hook): | |||
| f'The modelscope version of loaded checkpoint does not match the runtime version. ' | |||
| f'The saved version: {version}, runtime version: {__version__}' | |||
| ) | |||
| trainer.logger.warn( | |||
| trainer.logger.info( | |||
| f'Checkpoint {filename} saving time: {meta.get("time")}') | |||
| return meta | |||
| @@ -646,7 +646,9 @@ class VecoTrainer(NlpEpochBasedTrainer): | |||
| break | |||
| for metric_name in self.metrics: | |||
| metric_values[metric_name] = np.average( | |||
| [m[metric_name] for m in metric_values.values()]) | |||
| all_metrics = [m[metric_name] for m in metric_values.values()] | |||
| for key in all_metrics[0].keys(): | |||
| metric_values[key] = np.average( | |||
| [metric[key] for metric in all_metrics]) | |||
| return metric_values | |||
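Each metric object now returns a dict of values (accuracy and F1) instead of a single number, so VecoTrainer averages per key across the per-dataset results. A simplified sketch of the aggregation, assuming each evaluated dataset produced such a dict:
>>> import numpy as np
>>> per_dataset = {'en': {'accuracy': 0.8, 'f1': 0.7}, 'zh': {'accuracy': 0.6, 'f1': 0.5}}
>>> {k: np.average([m[k] for m in per_dataset.values()])
...  for k in ('accuracy', 'f1')}   # approximately {'accuracy': 0.7, 'f1': 0.6}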
| @@ -667,10 +667,25 @@ class EpochBasedTrainer(BaseTrainer): | |||
| return dataset | |||
| def build_optimizer(self, cfg: ConfigDict, default_args: dict = None): | |||
| return build_optimizer(self.model, cfg=cfg, default_args=default_args) | |||
| try: | |||
| return build_optimizer( | |||
| self.model, cfg=cfg, default_args=default_args) | |||
| except KeyError as e: | |||
| self.logger.error( | |||
| f'Build optimizer error: the optimizer {cfg} is a native torch optimizer, ' | |||
| f'please check whether your torch version ({torch.__version__}) matches this config.' | |||
| ) | |||
| raise e | |||
| def build_lr_scheduler(self, cfg: ConfigDict, default_args: dict = None): | |||
| return build_lr_scheduler(cfg=cfg, default_args=default_args) | |||
| try: | |||
| return build_lr_scheduler(cfg=cfg, default_args=default_args) | |||
| except KeyError as e: | |||
| self.logger.error( | |||
| f'Build lr_scheduler error: the lr_scheduler {cfg} is a native torch lr_scheduler, ' | |||
| f'please check whether your torch version ({torch.__version__}) matches this config.' | |||
| ) | |||
| raise e | |||
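As an illustration of item 4: a scheduler such as torch's LinearLR only exists in torch >= 1.10, so a hypothetical config like the one below fails the registry lookup with a KeyError on older installations; the trainer now logs the version hint above instead of surfacing a bare KeyError:
>>> from modelscope.utils.config import ConfigDict  # assumed import path for ConfigDict
>>> cfg = ConfigDict(type='LinearLR', start_factor=0.1, total_iters=1000)  # LinearLR requires torch >= 1.10
>>> trainer.build_lr_scheduler(cfg)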
| def create_optimizer_and_scheduler(self): | |||
| """ Create optimizer and lr scheduler | |||
| @@ -134,9 +134,7 @@ def load_checkpoint(filename, | |||
| state_dict = checkpoint if 'state_dict' not in checkpoint else checkpoint[ | |||
| 'state_dict'] | |||
| model.load_state_dict(state_dict) | |||
| if 'meta' in checkpoint: | |||
| return checkpoint.get('meta', {}) | |||
| return checkpoint.get('meta', {}) | |||
| def save_pretrained(model, | |||
| @@ -0,0 +1,32 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import unittest | |||
| import numpy as np | |||
| from modelscope.metrics.sequence_classification_metric import \ | |||
| SequenceClassificationMetric | |||
| from modelscope.utils.test_utils import test_level | |||
| class TestTextClsMetrics(unittest.TestCase): | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_value(self): | |||
| metric = SequenceClassificationMetric() | |||
| outputs = { | |||
| 'logits': | |||
| np.array([[2.0, 1.0, 0.5], [1.0, 1.5, 1.0], [2.0, 1.0, 3.0], | |||
| [2.4, 1.5, 4.0], [2.0, 1.0, 3.0], [2.4, 1.5, 1.7], | |||
| [2.0, 1.0, 0.5], [2.4, 1.5, 0.5]]) | |||
| } | |||
| inputs = {'labels': np.array([0, 1, 2, 2, 0, 1, 2, 2])} | |||
| metric.add(outputs, inputs) | |||
| ret = metric.evaluate() | |||
| self.assertTrue(np.isclose(ret['f1'], 0.5)) | |||
| self.assertTrue(np.isclose(ret['accuracy'], 0.5)) | |||
| print(ret) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -346,7 +346,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase): | |||
| train_datasets = [] | |||
| from datasets import DownloadConfig | |||
| dc = DownloadConfig() | |||
| dc.local_files_only = True | |||
| dc.local_files_only = False | |||
| for lang in langs: | |||
| train_datasets.append( | |||
| load_dataset('xnli', lang, split='train', download_config=dc)) | |||
| @@ -223,13 +223,31 @@ class TestTrainerWithNlp(unittest.TestCase): | |||
| trainer, 'trainer_continue_train', level='strict'): | |||
| trainer.train(os.path.join(self.tmp_dir, 'iter_3.pth')) | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_trainer_with_evaluation(self): | |||
| tmp_dir = tempfile.TemporaryDirectory().name | |||
| if not os.path.exists(tmp_dir): | |||
| os.makedirs(tmp_dir) | |||
| model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny' | |||
| cache_path = snapshot_download(model_id) | |||
| model = SbertForSequenceClassification.from_pretrained(cache_path) | |||
| kwargs = dict( | |||
| cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), | |||
| model=model, | |||
| eval_dataset=self.dataset, | |||
| work_dir=self.tmp_dir) | |||
| trainer = build_trainer(default_args=kwargs) | |||
| print(trainer.evaluate(cache_path + '/pytorch_model.bin')) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_trainer_with_model_and_args(self): | |||
| tmp_dir = tempfile.TemporaryDirectory().name | |||
| if not os.path.exists(tmp_dir): | |||
| os.makedirs(tmp_dir) | |||
| model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' | |||
| model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny' | |||
| cache_path = snapshot_download(model_id) | |||
| model = SbertForSequenceClassification.from_pretrained(cache_path) | |||
| kwargs = dict( | |||