| @@ -10,6 +10,8 @@ import sys | |||||
| import tarfile | import tarfile | ||||
| import tempfile | import tempfile | ||||
| import operator | import operator | ||||
| import types | |||||
| import functools | |||||
| from collections import OrderedDict, UserDict | from collections import OrderedDict, UserDict | ||||
| from contextlib import contextmanager | from contextlib import contextmanager | ||||
| from dataclasses import fields | from dataclasses import fields | ||||
| @@ -37,6 +39,8 @@ if _NEED_IMPORT_TORCH: | |||||
| import torch | import torch | ||||
| _torch_version = importlib_metadata.version("torch") | _torch_version = importlib_metadata.version("torch") | ||||
| ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} | |||||
| hf_cache_home = os.path.expanduser( | hf_cache_home = os.path.expanduser( | ||||
| os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface")) | os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface")) | ||||
| ) | ) | ||||
| @@ -45,10 +49,9 @@ default_cache_path = os.path.join(hf_cache_home, "transformers") | |||||
| PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path) | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path) | ||||
| PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE) | PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE) | ||||
| TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE) | TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE) | ||||
| HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) | |||||
| TRANSFORMERS_DYNAMIC_MODULE_NAME = "transformers_modules" | |||||
| SESSION_ID = uuid4().hex | SESSION_ID = uuid4().hex | ||||
| ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} | |||||
| DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", False) in ENV_VARS_TRUE_VALUES | DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", False) in ENV_VARS_TRUE_VALUES | ||||
| WEIGHTS_NAME = "pytorch_model.bin" | WEIGHTS_NAME = "pytorch_model.bin" | ||||
| @@ -1043,3 +1046,11 @@ class TensorType(ExplicitEnum): | |||||
| PYTORCH = "pt" | PYTORCH = "pt" | ||||
| NUMPY = "np" | NUMPY = "np" | ||||
| def copy_func(f): | |||||
| """Returns a copy of a function f.""" | |||||
| # Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard) | |||||
| g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__) | |||||
| g = functools.update_wrapper(g, f) | |||||
| g.__kwdefaults__ = f.__kwdefaults__ | |||||
| return g | |||||
| @@ -0,0 +1,83 @@ | |||||
| __all__ = [ | |||||
| "ALL_PRETRAINED_CONFIG_ARCHIVE_MAP", | |||||
| "CONFIG_MAPPING", | |||||
| "MODEL_NAMES_MAPPING", | |||||
| "AutoConfig", | |||||
| "TOKENIZER_MAPPING", | |||||
| "get_values", | |||||
| "MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING", | |||||
| "MODEL_FOR_CAUSAL_LM_MAPPING", | |||||
| "MODEL_FOR_CTC_MAPPING", | |||||
| "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING", | |||||
| "MODEL_FOR_MASKED_LM_MAPPING", | |||||
| "MODEL_FOR_MULTIPLE_CHOICE_MAPPING", | |||||
| "MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", | |||||
| "MODEL_FOR_OBJECT_DETECTION_MAPPING", | |||||
| "MODEL_FOR_PRETRAINING_MAPPING", | |||||
| "MODEL_FOR_QUESTION_ANSWERING_MAPPING", | |||||
| "MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING", | |||||
| "MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", | |||||
| "MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING", | |||||
| "MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING", | |||||
| "MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", | |||||
| "MODEL_MAPPING", | |||||
| "MODEL_WITH_LM_HEAD_MAPPING", | |||||
| "AutoModel", | |||||
| "AutoModelForAudioClassification", | |||||
| "AutoModelForCausalLM", | |||||
| "AutoModelForCTC", | |||||
| "AutoModelForImageClassification", | |||||
| "AutoModelForMaskedLM", | |||||
| "AutoModelForMultipleChoice", | |||||
| "AutoModelForNextSentencePrediction", | |||||
| "AutoModelForObjectDetection", | |||||
| "AutoModelForPreTraining", | |||||
| "AutoModelForQuestionAnswering", | |||||
| "AutoModelForSeq2SeqLM", | |||||
| "AutoModelForSequenceClassification", | |||||
| "AutoModelForSpeechSeq2Seq", | |||||
| "AutoModelForTableQuestionAnswering", | |||||
| "AutoModelForTokenClassification", | |||||
| "AutoModelWithLMHead", | |||||
| ] | |||||
| from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, MODEL_NAMES_MAPPING, \ | |||||
| AutoConfig | |||||
| from .tokenization_auto import TOKENIZER_MAPPING | |||||
| from .auto_factory import get_values | |||||
| from .modeling_auto import ( | |||||
| MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, | |||||
| MODEL_FOR_CAUSAL_LM_MAPPING, | |||||
| MODEL_FOR_CTC_MAPPING, | |||||
| MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, | |||||
| MODEL_FOR_MASKED_LM_MAPPING, | |||||
| MODEL_FOR_MULTIPLE_CHOICE_MAPPING, | |||||
| MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, | |||||
| MODEL_FOR_OBJECT_DETECTION_MAPPING, | |||||
| MODEL_FOR_PRETRAINING_MAPPING, | |||||
| MODEL_FOR_QUESTION_ANSWERING_MAPPING, | |||||
| MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, | |||||
| MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, | |||||
| MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, | |||||
| MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING, | |||||
| MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, | |||||
| MODEL_MAPPING, | |||||
| MODEL_WITH_LM_HEAD_MAPPING, | |||||
| AutoModel, | |||||
| AutoModelForAudioClassification, | |||||
| AutoModelForCausalLM, | |||||
| AutoModelForCTC, | |||||
| AutoModelForImageClassification, | |||||
| AutoModelForMaskedLM, | |||||
| AutoModelForMultipleChoice, | |||||
| AutoModelForNextSentencePrediction, | |||||
| AutoModelForObjectDetection, | |||||
| AutoModelForPreTraining, | |||||
| AutoModelForQuestionAnswering, | |||||
| AutoModelForSeq2SeqLM, | |||||
| AutoModelForSequenceClassification, | |||||
| AutoModelForSpeechSeq2Seq, | |||||
| AutoModelForTableQuestionAnswering, | |||||
| AutoModelForTokenClassification, | |||||
| AutoModelWithLMHead, | |||||
| ) | |||||
| @@ -0,0 +1,562 @@ | |||||
| # coding=utf-8 | |||||
| # Copyright 2021 The HuggingFace Inc. team. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """Factory function to build auto-model classes.""" | |||||
| import importlib | |||||
| from collections import OrderedDict | |||||
| from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings | |||||
| from .dynamic import get_class_from_dynamic_module | |||||
| from fastNLP.transformers.torch.configuration_utils import PretrainedConfig | |||||
| from fastNLP.transformers.torch.file_utils import copy_func | |||||
| from fastNLP.core.log import logger | |||||
| CLASS_DOCSTRING = """ | |||||
| This is a generic model class that will be instantiated as one of the model classes of the library when created | |||||
| with the :meth:`~transformers.BaseAutoModelClass.from_pretrained` class method or the | |||||
| :meth:`~transformers.BaseAutoModelClass.from_config` class method. | |||||
| This class cannot be instantiated directly using ``__init__()`` (throws an error). | |||||
| """ | |||||
| FROM_CONFIG_DOCSTRING = """ | |||||
| Instantiates one of the model classes of the library from a configuration. | |||||
| Note: | |||||
| Loading a model from its configuration file does **not** load the model weights. It only affects the | |||||
| model's configuration. Use :meth:`~transformers.BaseAutoModelClass.from_pretrained` to load the model | |||||
| weights. | |||||
| Args: | |||||
| config (:class:`~transformers.PretrainedConfig`): | |||||
| The model class to instantiate is selected based on the configuration class: | |||||
| List options | |||||
| Examples:: | |||||
| >>> from transformers import AutoConfig, BaseAutoModelClass | |||||
| >>> # Download configuration from huggingface.co and cache. | |||||
| >>> config = AutoConfig.from_pretrained('checkpoint_placeholder') | |||||
| >>> model = BaseAutoModelClass.from_config(config) | |||||
| """ | |||||
| FROM_PRETRAINED_TORCH_DOCSTRING = """ | |||||
| Instantiate one of the model classes of the library from a pretrained model. | |||||
| The model class to instantiate is selected based on the :obj:`model_type` property of the config object (either | |||||
| passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible), or when it's missing, | |||||
| by falling back to using pattern matching on :obj:`pretrained_model_name_or_path`: | |||||
| List options | |||||
| The model is set in evaluation mode by default using ``model.eval()`` (so for instance, dropout modules are | |||||
| deactivated). To train the model, you should first set it back in training mode with ``model.train()`` | |||||
| Args: | |||||
| pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): | |||||
| Can be either: | |||||
| - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co. | |||||
| Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under | |||||
| a user or organization name, like ``dbmdz/bert-base-german-cased``. | |||||
| - A path to a `directory` containing model weights saved using | |||||
| :func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``. | |||||
| - A path or url to a `tensorflow index checkpoint file` (e.g, ``./tf_model/model.ckpt.index``). In | |||||
| this case, ``from_tf`` should be set to :obj:`True` and a configuration object should be provided | |||||
| as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in | |||||
| a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. | |||||
| model_args (additional positional arguments, `optional`): | |||||
| Will be passed along to the underlying model ``__init__()`` method. | |||||
| config (:class:`~transformers.PretrainedConfig`, `optional`): | |||||
| Configuration for the model to use instead of an automatically loaded configuration. Configuration can | |||||
| be automatically loaded when: | |||||
| - The model is a model provided by the library (loaded with the `model id` string of a pretrained | |||||
| model). | |||||
| - The model was saved using :meth:`~transformers.PreTrainedModel.save_pretrained` and is reloaded | |||||
| by supplying the save directory. | |||||
| - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a | |||||
| configuration JSON file named `config.json` is found in the directory. | |||||
| state_dict (`Dict[str, torch.Tensor]`, `optional`): | |||||
| A state dictionary to use instead of a state dictionary loaded from saved weights file. | |||||
| This option can be used if you want to create a model from a pretrained configuration but load your own | |||||
| weights. In this case though, you should check if using | |||||
| :func:`~transformers.PreTrainedModel.save_pretrained` and | |||||
| :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option. | |||||
| cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`): | |||||
| Path to a directory in which a downloaded pretrained model configuration should be cached if the | |||||
| standard cache should not be used. | |||||
| from_tf (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||||
| Load the model weights from a TensorFlow checkpoint save file (see docstring of | |||||
| ``pretrained_model_name_or_path`` argument). | |||||
| force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||||
| Whether or not to force the (re-)download of the model weights and configuration files, overriding the | |||||
| cached versions if they exist. | |||||
| resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||||
| Whether or not to delete incompletely received files. Will attempt to resume the download if such a | |||||
| file exists. | |||||
| proxies (:obj:`Dict[str, str], `optional`): | |||||
| A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128', | |||||
| 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. | |||||
| output_loading_info(:obj:`bool`, `optional`, defaults to :obj:`False`): | |||||
| Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. | |||||
| local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`): | |||||
| Whether or not to only look at local files (e.g., not try downloading the model). | |||||
| revision(:obj:`str`, `optional`, defaults to :obj:`"main"`): | |||||
| The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a | |||||
| git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any | |||||
| identifier allowed by git. | |||||
| trust_remote_code (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||||
| Whether or not to allow for custom models defined on the Hub in their own modeling files. This option | |||||
| should only be set to :obj:`True` for repositories you trust and in which you have read the code, as it | |||||
| will execute code present on the Hub on your local machine. | |||||
| kwargs (additional keyword arguments, `optional`): | |||||
| Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., | |||||
| :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or | |||||
| automatically loaded: | |||||
| - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the | |||||
| underlying model's ``__init__`` method (we assume all relevant updates to the configuration have | |||||
| already been done) | |||||
| - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class | |||||
| initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of | |||||
| ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute | |||||
| with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration | |||||
| attribute will be passed to the underlying model's ``__init__`` function. | |||||
| Examples:: | |||||
| >>> from transformers import AutoConfig, BaseAutoModelClass | |||||
| >>> # Download model and configuration from huggingface.co and cache. | |||||
| >>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder') | |||||
| >>> # Update configuration during loading | |||||
| >>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder', output_attentions=True) | |||||
| >>> model.config.output_attentions | |||||
| True | |||||
| >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower) | |||||
| >>> config = AutoConfig.from_pretrained('./tf_model/shortcut_placeholder_tf_model_config.json') | |||||
| >>> model = BaseAutoModelClass.from_pretrained('./tf_model/shortcut_placeholder_tf_checkpoint.ckpt.index', from_tf=True, config=config) | |||||
| """ | |||||
| FROM_PRETRAINED_TF_DOCSTRING = """ | |||||
| Instantiate one of the model classes of the library from a pretrained model. | |||||
| The model class to instantiate is selected based on the :obj:`model_type` property of the config object (either | |||||
| passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible), or when it's missing, | |||||
| by falling back to using pattern matching on :obj:`pretrained_model_name_or_path`: | |||||
| List options | |||||
| Args: | |||||
| pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): | |||||
| Can be either: | |||||
| - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co. | |||||
| Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under | |||||
| a user or organization name, like ``dbmdz/bert-base-german-cased``. | |||||
| - A path to a `directory` containing model weights saved using | |||||
| :func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``. | |||||
| - A path or url to a `PyTorch state_dict save file` (e.g, ``./pt_model/pytorch_model.bin``). In | |||||
| this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided | |||||
| as ``config`` argument. This loading path is slower than converting the PyTorch model in a | |||||
| TensorFlow model using the provided conversion scripts and loading the TensorFlow model | |||||
| afterwards. | |||||
| model_args (additional positional arguments, `optional`): | |||||
| Will be passed along to the underlying model ``__init__()`` method. | |||||
| config (:class:`~transformers.PretrainedConfig`, `optional`): | |||||
| Configuration for the model to use instead of an automatically loaded configuration. Configuration can | |||||
| be automatically loaded when: | |||||
| - The model is a model provided by the library (loaded with the `model id` string of a pretrained | |||||
| model). | |||||
| - The model was saved using :meth:`~transformers.PreTrainedModel.save_pretrained` and is reloaded | |||||
| by supplying the save directory. | |||||
| - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a | |||||
| configuration JSON file named `config.json` is found in the directory. | |||||
| cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`): | |||||
| Path to a directory in which a downloaded pretrained model configuration should be cached if the | |||||
| standard cache should not be used. | |||||
| from_pt (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||||
| Load the model weights from a PyTorch checkpoint save file (see docstring of | |||||
| ``pretrained_model_name_or_path`` argument). | |||||
| force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||||
| Whether or not to force the (re-)download of the model weights and configuration files, overriding the | |||||
| cached versions if they exist. | |||||
| resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||||
| Whether or not to delete incompletely received files. Will attempt to resume the download if such a | |||||
| file exists. | |||||
| proxies (:obj:`Dict[str, str], `optional`): | |||||
| A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128', | |||||
| 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. | |||||
| output_loading_info(:obj:`bool`, `optional`, defaults to :obj:`False`): | |||||
| Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. | |||||
| local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`): | |||||
| Whether or not to only look at local files (e.g., not try downloading the model). | |||||
| revision(:obj:`str`, `optional`, defaults to :obj:`"main"`): | |||||
| The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a | |||||
| git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any | |||||
| identifier allowed by git. | |||||
| trust_remote_code (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||||
| Whether or not to allow for custom models defined on the Hub in their own modeling files. This option | |||||
| should only be set to :obj:`True` for repositories you trust and in which you have read the code, as it | |||||
| will execute code present on the Hub on your local machine. | |||||
| kwargs (additional keyword arguments, `optional`): | |||||
| Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., | |||||
| :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or | |||||
| automatically loaded: | |||||
| - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the | |||||
| underlying model's ``__init__`` method (we assume all relevant updates to the configuration have | |||||
| already been done) | |||||
| - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class | |||||
| initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of | |||||
| ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute | |||||
| with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration | |||||
| attribute will be passed to the underlying model's ``__init__`` function. | |||||
| Examples:: | |||||
| >>> from transformers import AutoConfig, BaseAutoModelClass | |||||
| >>> # Download model and configuration from huggingface.co and cache. | |||||
| >>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder') | |||||
| >>> # Update configuration during loading | |||||
| >>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder', output_attentions=True) | |||||
| >>> model.config.output_attentions | |||||
| True | |||||
| >>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower) | |||||
| >>> config = AutoConfig.from_pretrained('./pt_model/shortcut_placeholder_pt_model_config.json') | |||||
| >>> model = BaseAutoModelClass.from_pretrained('./pt_model/shortcut_placeholder_pytorch_model.bin', from_pt=True, config=config) | |||||
| """ | |||||
| FROM_PRETRAINED_FLAX_DOCSTRING = """ | |||||
| Instantiate one of the model classes of the library from a pretrained model. | |||||
| The model class to instantiate is selected based on the :obj:`model_type` property of the config object (either | |||||
| passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible), or when it's missing, | |||||
| by falling back to using pattern matching on :obj:`pretrained_model_name_or_path`: | |||||
| List options | |||||
| Args: | |||||
| pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): | |||||
| Can be either: | |||||
| - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co. | |||||
| Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under | |||||
| a user or organization name, like ``dbmdz/bert-base-german-cased``. | |||||
| - A path to a `directory` containing model weights saved using | |||||
| :func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``. | |||||
| - A path or url to a `PyTorch state_dict save file` (e.g, ``./pt_model/pytorch_model.bin``). In | |||||
| this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided | |||||
| as ``config`` argument. This loading path is slower than converting the PyTorch model in a | |||||
| TensorFlow model using the provided conversion scripts and loading the TensorFlow model | |||||
| afterwards. | |||||
| model_args (additional positional arguments, `optional`): | |||||
| Will be passed along to the underlying model ``__init__()`` method. | |||||
| config (:class:`~transformers.PretrainedConfig`, `optional`): | |||||
| Configuration for the model to use instead of an automatically loaded configuration. Configuration can | |||||
| be automatically loaded when: | |||||
| - The model is a model provided by the library (loaded with the `model id` string of a pretrained | |||||
| model). | |||||
| - The model was saved using :meth:`~transformers.PreTrainedModel.save_pretrained` and is reloaded | |||||
| by supplying the save directory. | |||||
| - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a | |||||
| configuration JSON file named `config.json` is found in the directory. | |||||
| cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`): | |||||
| Path to a directory in which a downloaded pretrained model configuration should be cached if the | |||||
| standard cache should not be used. | |||||
| from_pt (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||||
| Load the model weights from a PyTorch checkpoint save file (see docstring of | |||||
| ``pretrained_model_name_or_path`` argument). | |||||
| force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||||
| Whether or not to force the (re-)download of the model weights and configuration files, overriding the | |||||
| cached versions if they exist. | |||||
| resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||||
| Whether or not to delete incompletely received files. Will attempt to resume the download if such a | |||||
| file exists. | |||||
| proxies (:obj:`Dict[str, str], `optional`): | |||||
| A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128', | |||||
| 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. | |||||
| output_loading_info(:obj:`bool`, `optional`, defaults to :obj:`False`): | |||||
| Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. | |||||
| local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`): | |||||
| Whether or not to only look at local files (e.g., not try downloading the model). | |||||
| revision(:obj:`str`, `optional`, defaults to :obj:`"main"`): | |||||
| The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a | |||||
| git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any | |||||
| identifier allowed by git. | |||||
| trust_remote_code (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||||
| Whether or not to allow for custom models defined on the Hub in their own modeling files. This option | |||||
| should only be set to :obj:`True` for repositories you trust and in which you have read the code, as it | |||||
| will execute code present on the Hub on your local machine. | |||||
| kwargs (additional keyword arguments, `optional`): | |||||
| Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., | |||||
| :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or | |||||
| automatically loaded: | |||||
| - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the | |||||
| underlying model's ``__init__`` method (we assume all relevant updates to the configuration have | |||||
| already been done) | |||||
| - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class | |||||
| initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of | |||||
| ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute | |||||
| with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration | |||||
| attribute will be passed to the underlying model's ``__init__`` function. | |||||
| Examples:: | |||||
| >>> from transformers import AutoConfig, BaseAutoModelClass | |||||
| >>> # Download model and configuration from huggingface.co and cache. | |||||
| >>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder') | |||||
| >>> # Update configuration during loading | |||||
| >>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder', output_attentions=True) | |||||
| >>> model.config.output_attentions | |||||
| True | |||||
| >>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower) | |||||
| >>> config = AutoConfig.from_pretrained('./pt_model/shortcut_placeholder_pt_model_config.json') | |||||
| >>> model = BaseAutoModelClass.from_pretrained('./pt_model/shortcut_placeholder_pytorch_model.bin', from_pt=True, config=config) | |||||
| """ | |||||
| def _get_model_class(config, model_mapping): | |||||
| supported_models = model_mapping[type(config)] | |||||
| if not isinstance(supported_models, (list, tuple)): | |||||
| return supported_models | |||||
| name_to_model = {model.__name__: model for model in supported_models} | |||||
| architectures = getattr(config, "architectures", []) | |||||
| for arch in architectures: | |||||
| if arch in name_to_model: | |||||
| return name_to_model[arch] | |||||
| elif f"TF{arch}" in name_to_model: | |||||
| return name_to_model[f"TF{arch}"] | |||||
| elif f"Flax{arch}" in name_to_model: | |||||
| return name_to_model[f"Flax{arch}"] | |||||
| # If not architecture is set in the config or match the supported models, the first element of the tuple is the | |||||
| # defaults. | |||||
| return supported_models[0] | |||||
| class _BaseAutoModelClass: | |||||
| # Base class for auto models. | |||||
| _model_mapping = None | |||||
| def __init__(self, *args, **kwargs): | |||||
| raise EnvironmentError( | |||||
| f"{self.__class__.__name__} is designed to be instantiated " | |||||
| f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or " | |||||
| f"`{self.__class__.__name__}.from_config(config)` methods." | |||||
| ) | |||||
| @classmethod | |||||
| def from_config(cls, config, **kwargs): | |||||
| if type(config) in cls._model_mapping.keys(): | |||||
| model_class = _get_model_class(config, cls._model_mapping) | |||||
| return model_class._from_config(config, **kwargs) | |||||
| raise ValueError( | |||||
| f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" | |||||
| f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}." | |||||
| ) | |||||
| @classmethod | |||||
| def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): | |||||
| config = kwargs.pop("config", None) | |||||
| trust_remote_code = kwargs.pop("trust_remote_code", False) | |||||
| kwargs["_from_auto"] = True | |||||
| if not isinstance(config, PretrainedConfig): | |||||
| config, kwargs = AutoConfig.from_pretrained( | |||||
| pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs | |||||
| ) | |||||
| if hasattr(config, "auto_map") and cls.__name__ in config.auto_map: | |||||
| if not trust_remote_code: | |||||
| raise ValueError( | |||||
| f"Loading {pretrained_model_name_or_path} requires you to execute the modeling file in that repo " | |||||
| "on your local machine. Make sure you have read the code there to avoid malicious use, then set " | |||||
| "the option `trust_remote_code=True` to remove this error." | |||||
| ) | |||||
| if kwargs.get("revision", None) is None: | |||||
| logger.warn( | |||||
| "Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure " | |||||
| "no malicious code has been contributed in a newer revision." | |||||
| ) | |||||
| class_ref = config.auto_map[cls.__name__] | |||||
| module_file, class_name = class_ref.split(".") | |||||
| model_class = get_class_from_dynamic_module( | |||||
| pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs | |||||
| ) | |||||
| return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) | |||||
| elif type(config) in cls._model_mapping.keys(): | |||||
| model_class = _get_model_class(config, cls._model_mapping) | |||||
| return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) | |||||
| raise ValueError( | |||||
| f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" | |||||
| f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}." | |||||
| ) | |||||
| def insert_head_doc(docstring, head_doc=""): | |||||
| if len(head_doc) > 0: | |||||
| return docstring.replace( | |||||
| "one of the model classes of the library ", | |||||
| f"one of the model classes of the library (with a {head_doc} head) ", | |||||
| ) | |||||
| return docstring.replace( | |||||
| "one of the model classes of the library ", "one of the base model classes of the library " | |||||
| ) | |||||
| def auto_class_update(cls, checkpoint_for_example="bert-base-cased", head_doc=""): | |||||
| # Create a new class with the right name from the base class | |||||
| model_mapping = cls._model_mapping | |||||
| name = cls.__name__ | |||||
| class_docstring = insert_head_doc(CLASS_DOCSTRING, head_doc=head_doc) | |||||
| cls.__doc__ = class_docstring.replace("BaseAutoModelClass", name) | |||||
| # Now we need to copy and re-register `from_config` and `from_pretrained` as class methods otherwise we can't | |||||
| # have a specific docstrings for them. | |||||
| from_config = copy_func(_BaseAutoModelClass.from_config) | |||||
| from_config_docstring = insert_head_doc(FROM_CONFIG_DOCSTRING, head_doc=head_doc) | |||||
| from_config_docstring = from_config_docstring.replace("BaseAutoModelClass", name) | |||||
| from_config_docstring = from_config_docstring.replace("checkpoint_placeholder", checkpoint_for_example) | |||||
| from_config.__doc__ = from_config_docstring | |||||
| from_config = replace_list_option_in_docstrings(model_mapping._model_mapping, use_model_types=False)(from_config) | |||||
| cls.from_config = classmethod(from_config) | |||||
| if name.startswith("TF"): | |||||
| from_pretrained_docstring = FROM_PRETRAINED_TF_DOCSTRING | |||||
| elif name.startswith("Flax"): | |||||
| from_pretrained_docstring = FROM_PRETRAINED_FLAX_DOCSTRING | |||||
| else: | |||||
| from_pretrained_docstring = FROM_PRETRAINED_TORCH_DOCSTRING | |||||
| from_pretrained = copy_func(_BaseAutoModelClass.from_pretrained) | |||||
| from_pretrained_docstring = insert_head_doc(from_pretrained_docstring, head_doc=head_doc) | |||||
| from_pretrained_docstring = from_pretrained_docstring.replace("BaseAutoModelClass", name) | |||||
| from_pretrained_docstring = from_pretrained_docstring.replace("checkpoint_placeholder", checkpoint_for_example) | |||||
| shortcut = checkpoint_for_example.split("/")[-1].split("-")[0] | |||||
| from_pretrained_docstring = from_pretrained_docstring.replace("shortcut_placeholder", shortcut) | |||||
| from_pretrained.__doc__ = from_pretrained_docstring | |||||
| from_pretrained = replace_list_option_in_docstrings(model_mapping._model_mapping)(from_pretrained) | |||||
| cls.from_pretrained = classmethod(from_pretrained) | |||||
| return cls | |||||
| def get_values(model_mapping): | |||||
| result = [] | |||||
| for model in model_mapping.values(): | |||||
| if isinstance(model, (list, tuple)): | |||||
| result += list(model) | |||||
| else: | |||||
| result.append(model) | |||||
| return result | |||||
| def getattribute_from_module(module, attr): | |||||
| if attr is None: | |||||
| return None | |||||
| if isinstance(attr, tuple): | |||||
| return tuple(getattribute_from_module(module, a) for a in attr) | |||||
| if hasattr(module, attr): | |||||
| return getattr(module, attr) | |||||
| # Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the | |||||
| # object at the top level. | |||||
| transformers_module = importlib.import_module("transformers") | |||||
| return getattribute_from_module(transformers_module, attr) | |||||
| class _LazyAutoMapping(OrderedDict): | |||||
| """ | |||||
| " A mapping config to object (model or tokenizer for instance) that will load keys and values when it is accessed. | |||||
| Args: | |||||
| - config_mapping: The map model type to config class | |||||
| - model_mapping: The map model type to model (or tokenizer) class | |||||
| """ | |||||
| def __init__(self, config_mapping, model_mapping): | |||||
| self._config_mapping = config_mapping | |||||
| self._reverse_config_mapping = {v: k for k, v in config_mapping.items()} | |||||
| self._model_mapping = model_mapping | |||||
| self._modules = {} | |||||
| def __getitem__(self, key): | |||||
| model_type = self._reverse_config_mapping[key.__name__] | |||||
| if model_type not in self._model_mapping: | |||||
| raise KeyError(key) | |||||
| model_name = self._model_mapping[model_type] | |||||
| return self._load_attr_from_module(model_type, model_name) | |||||
| def _load_attr_from_module(self, model_type, attr): | |||||
| module_name = model_type_to_module_name(model_type) | |||||
| if module_name not in self._modules: | |||||
| self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models") | |||||
| return getattribute_from_module(self._modules[module_name], attr) | |||||
| def keys(self): | |||||
| return [ | |||||
| self._load_attr_from_module(key, name) | |||||
| for key, name in self._config_mapping.items() | |||||
| if key in self._model_mapping.keys() | |||||
| ] | |||||
| def get(self, key, default): | |||||
| try: | |||||
| return self.__getitem__(key) | |||||
| except KeyError: | |||||
| return default | |||||
| def __bool__(self): | |||||
| return bool(self.keys()) | |||||
| def values(self): | |||||
| return [ | |||||
| self._load_attr_from_module(key, name) | |||||
| for key, name in self._model_mapping.items() | |||||
| if key in self._config_mapping.keys() | |||||
| ] | |||||
| def items(self): | |||||
| return [ | |||||
| ( | |||||
| self._load_attr_from_module(key, self._config_mapping[key]), | |||||
| self._load_attr_from_module(key, self._model_mapping[key]), | |||||
| ) | |||||
| for key in self._model_mapping.keys() | |||||
| if key in self._config_mapping.keys() | |||||
| ] | |||||
| def __iter__(self): | |||||
| return iter(self._mapping.keys()) | |||||
| def __contains__(self, item): | |||||
| if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping: | |||||
| return False | |||||
| model_type = self._reverse_config_mapping[item.__name__] | |||||
| return model_type in self._model_mapping | |||||
| @@ -0,0 +1,208 @@ | |||||
| import importlib | |||||
| import os | |||||
| import re | |||||
| import shutil | |||||
| import sys | |||||
| from pathlib import Path | |||||
| from typing import Dict, Optional, Union | |||||
| from fastNLP.transformers.torch.file_utils import ( | |||||
| HF_MODULES_CACHE, | |||||
| TRANSFORMERS_DYNAMIC_MODULE_NAME, | |||||
| cached_path, | |||||
| hf_bucket_url, | |||||
| is_offline_mode, | |||||
| ) | |||||
| from fastNLP.core.log import logger | |||||
| def init_hf_modules(): | |||||
| """ | |||||
| Creates the cache directory for modules with an init, and adds it to the Python path. | |||||
| """ | |||||
| # This function has already been executed if HF_MODULES_CACHE already is in the Python path. | |||||
| if HF_MODULES_CACHE in sys.path: | |||||
| return | |||||
| sys.path.append(HF_MODULES_CACHE) | |||||
| os.makedirs(HF_MODULES_CACHE, exist_ok=True) | |||||
| init_path = Path(HF_MODULES_CACHE) / "__init__.py" | |||||
| if not init_path.exists(): | |||||
| init_path.touch() | |||||
| def create_dynamic_module(name: Union[str, os.PathLike]): | |||||
| """ | |||||
| Creates a dynamic module in the cache directory for modules. | |||||
| """ | |||||
| init_hf_modules() | |||||
| dynamic_module_path = Path(HF_MODULES_CACHE) / name | |||||
| # If the parent module does not exist yet, recursively create it. | |||||
| if not dynamic_module_path.parent.exists(): | |||||
| create_dynamic_module(dynamic_module_path.parent) | |||||
| os.makedirs(dynamic_module_path, exist_ok=True) | |||||
| init_path = dynamic_module_path / "__init__.py" | |||||
| if not init_path.exists(): | |||||
| init_path.touch() | |||||
| def check_imports(filename): | |||||
| """ | |||||
| Check if the current Python environment contains all the libraries that are imported in a file. | |||||
| """ | |||||
| with open(filename, "r", encoding="utf-8") as f: | |||||
| content = f.read() | |||||
| # Imports of the form `import xxx` | |||||
| imports = re.findall("^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE) | |||||
| # Imports of the form `from xxx import yyy` | |||||
| imports += re.findall("^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE) | |||||
| # Only keep the top-level module | |||||
| imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")] | |||||
| # Unique-ify and test we got them all | |||||
| imports = list(set(imports)) | |||||
| missing_packages = [] | |||||
| for imp in imports: | |||||
| try: | |||||
| importlib.import_module(imp) | |||||
| except ImportError: | |||||
| missing_packages.append(imp) | |||||
| if len(missing_packages) > 0: | |||||
| raise ImportError( | |||||
| "This modeling file requires the following packages that were not found in your environment: " | |||||
| f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`" | |||||
| ) | |||||
| def get_class_in_module(class_name, module_path): | |||||
| """ | |||||
| Import a module on the cache directory for modules and extract a class from it. | |||||
| """ | |||||
| module_path = module_path.replace(os.path.sep, ".") | |||||
| module = importlib.import_module(module_path) | |||||
| return getattr(module, class_name) | |||||
| def get_class_from_dynamic_module( | |||||
| pretrained_model_name_or_path: Union[str, os.PathLike], | |||||
| module_file: str, | |||||
| class_name: str, | |||||
| cache_dir: Optional[Union[str, os.PathLike]] = None, | |||||
| force_download: bool = False, | |||||
| resume_download: bool = False, | |||||
| proxies: Optional[Dict[str, str]] = None, | |||||
| use_auth_token: Optional[Union[bool, str]] = None, | |||||
| revision: Optional[str] = None, | |||||
| local_files_only: bool = False, | |||||
| **kwargs, | |||||
| ): | |||||
| """ | |||||
| Extracts a class from a module file, present in the local folder or repository of a model. | |||||
| .. warning:: | |||||
| Calling this function will execute the code in the module file found locally or downloaded from the Hub. It | |||||
| should therefore only be called on trusted repos. | |||||
| Args: | |||||
| pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): | |||||
| This can be either: | |||||
| - a string, the `model id` of a pretrained model configuration hosted inside a model repo on | |||||
| huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or | |||||
| namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``. | |||||
| - a path to a `directory` containing a configuration file saved using the | |||||
| :func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g., ``./my_model_directory/``. | |||||
| module_file (:obj:`str`): | |||||
| The name of the module file containing the class to look for. | |||||
| class_name (:obj:`str`): | |||||
| The name of the class to import in the module. | |||||
| cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`): | |||||
| Path to a directory in which a downloaded pretrained model configuration should be cached if the standard | |||||
| cache should not be used. | |||||
| force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||||
| Whether or not to force to (re-)download the configuration files and override the cached versions if they | |||||
| exist. | |||||
| resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||||
| Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists. | |||||
| proxies (:obj:`Dict[str, str]`, `optional`): | |||||
| A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128', | |||||
| 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. | |||||
| use_auth_token (:obj:`str` or `bool`, `optional`): | |||||
| The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token | |||||
| generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). | |||||
| revision(:obj:`str`, `optional`, defaults to :obj:`"main"`): | |||||
| The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a | |||||
| git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any | |||||
| identifier allowed by git. | |||||
| local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||||
| If :obj:`True`, will only try to load the tokenizer configuration from local files. | |||||
| .. note:: | |||||
| Passing :obj:`use_auth_token=True` is required when you want to use a private model. | |||||
| Returns: | |||||
| :obj:`type`: The class, dynamically imported from the module. | |||||
| Examples:: | |||||
| # Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this | |||||
| # module. | |||||
| cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel") | |||||
| """ | |||||
| if is_offline_mode() and not local_files_only: | |||||
| logger.info("Offline mode: forcing local_files_only=True") | |||||
| local_files_only = True | |||||
| # Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file. | |||||
| pretrained_model_name_or_path = str(pretrained_model_name_or_path) | |||||
| if os.path.isdir(pretrained_model_name_or_path): | |||||
| module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file) | |||||
| submodule = "local" | |||||
| else: | |||||
| module_file_or_url = hf_bucket_url( | |||||
| pretrained_model_name_or_path, filename=module_file, revision=revision, mirror=None | |||||
| ) | |||||
| submodule = pretrained_model_name_or_path.replace("/", os.path.sep) | |||||
| try: | |||||
| # Load from URL or cache if already cached | |||||
| resolved_module_file = cached_path( | |||||
| module_file_or_url, | |||||
| cache_dir=cache_dir, | |||||
| force_download=force_download, | |||||
| proxies=proxies, | |||||
| resume_download=resume_download, | |||||
| local_files_only=local_files_only, | |||||
| use_auth_token=use_auth_token, | |||||
| ) | |||||
| except EnvironmentError: | |||||
| logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.") | |||||
| raise | |||||
| # Check we have all the requirements in our environment | |||||
| check_imports(resolved_module_file) | |||||
| # Now we move the module inside our cached dynamic modules. | |||||
| full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule | |||||
| create_dynamic_module(full_submodule) | |||||
| submodule_path = Path(HF_MODULES_CACHE) / full_submodule | |||||
| if submodule == "local": | |||||
| # We always copy local files (we could hash the file to see if there was a change, and give them the name of | |||||
| # that hash, to only copy when there is a modification but it seems overkill for now). | |||||
| # The only reason we do the copy is to avoid putting too many folders in sys.path. | |||||
| module_name = module_file | |||||
| shutil.copy(resolved_module_file, submodule_path / module_file) | |||||
| else: | |||||
| # The module file will end up being named module_file + the etag. This way we get the benefit of versioning. | |||||
| resolved_module_file_name = Path(resolved_module_file).name | |||||
| module_name_parts = [module_file.replace(".py", "")] + resolved_module_file_name.split(".") | |||||
| module_name = "_".join(module_name_parts) + ".py" | |||||
| if not (submodule_path / module_name).exists(): | |||||
| shutil.copy(resolved_module_file, submodule_path / module_name) | |||||
| # And lastly we get the class inside our newly created module | |||||
| final_module = os.path.join(full_submodule, module_name.replace(".py", "")) | |||||
| return get_class_in_module(class_name, final_module) | |||||
| @@ -0,0 +1,663 @@ | |||||
| # coding=utf-8 | |||||
| # Copyright 2018 The HuggingFace Inc. team. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """ Auto Model class. """ | |||||
| import warnings | |||||
| from collections import OrderedDict | |||||
| from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update | |||||
| from .configuration_auto import CONFIG_MAPPING_NAMES | |||||
| from fastNLP.core.log import logger | |||||
| MODEL_MAPPING_NAMES = OrderedDict( | |||||
| [ | |||||
| # Base model mapping | |||||
| ("fnet", "FNetModel"), | |||||
| ("gptj", "GPTJModel"), | |||||
| ("layoutlmv2", "LayoutLMv2Model"), | |||||
| ("beit", "BeitModel"), | |||||
| ("rembert", "RemBertModel"), | |||||
| ("visual_bert", "VisualBertModel"), | |||||
| ("canine", "CanineModel"), | |||||
| ("roformer", "RoFormerModel"), | |||||
| ("clip", "CLIPModel"), | |||||
| ("bigbird_pegasus", "BigBirdPegasusModel"), | |||||
| ("deit", "DeiTModel"), | |||||
| ("luke", "LukeModel"), | |||||
| ("detr", "DetrModel"), | |||||
| ("gpt_neo", "GPTNeoModel"), | |||||
| ("big_bird", "BigBirdModel"), | |||||
| ("speech_to_text", "Speech2TextModel"), | |||||
| ("vit", "ViTModel"), | |||||
| ("wav2vec2", "Wav2Vec2Model"), | |||||
| ("hubert", "HubertModel"), | |||||
| ("m2m_100", "M2M100Model"), | |||||
| ("convbert", "ConvBertModel"), | |||||
| ("led", "LEDModel"), | |||||
| ("blenderbot-small", "BlenderbotSmallModel"), | |||||
| ("retribert", "RetriBertModel"), | |||||
| ("mt5", "MT5Model"), | |||||
| ("t5", "T5Model"), | |||||
| ("pegasus", "PegasusModel"), | |||||
| ("marian", "MarianModel"), | |||||
| ("mbart", "MBartModel"), | |||||
| ("blenderbot", "BlenderbotModel"), | |||||
| ("distilbert", "DistilBertModel"), | |||||
| ("albert", "AlbertModel"), | |||||
| ("camembert", "CamembertModel"), | |||||
| ("xlm-roberta", "XLMRobertaModel"), | |||||
| ("bart", "BartModel"), | |||||
| ("longformer", "LongformerModel"), | |||||
| ("roberta", "RobertaModel"), | |||||
| ("layoutlm", "LayoutLMModel"), | |||||
| ("squeezebert", "SqueezeBertModel"), | |||||
| ("bert", "BertModel"), | |||||
| ("openai-gpt", "OpenAIGPTModel"), | |||||
| ("gpt2", "GPT2Model"), | |||||
| ("megatron-bert", "MegatronBertModel"), | |||||
| ("mobilebert", "MobileBertModel"), | |||||
| ("transfo-xl", "TransfoXLModel"), | |||||
| ("xlnet", "XLNetModel"), | |||||
| ("flaubert", "FlaubertModel"), | |||||
| ("fsmt", "FSMTModel"), | |||||
| ("xlm", "XLMModel"), | |||||
| ("ctrl", "CTRLModel"), | |||||
| ("electra", "ElectraModel"), | |||||
| ("reformer", "ReformerModel"), | |||||
| ("funnel", ("FunnelModel", "FunnelBaseModel")), | |||||
| ("lxmert", "LxmertModel"), | |||||
| ("bert-generation", "BertGenerationEncoder"), | |||||
| ("deberta", "DebertaModel"), | |||||
| ("deberta-v2", "DebertaV2Model"), | |||||
| ("dpr", "DPRQuestionEncoder"), | |||||
| ("xlm-prophetnet", "XLMProphetNetModel"), | |||||
| ("prophetnet", "ProphetNetModel"), | |||||
| ("mpnet", "MPNetModel"), | |||||
| ("tapas", "TapasModel"), | |||||
| ("ibert", "IBertModel"), | |||||
| ("splinter", "SplinterModel"), | |||||
| ] | |||||
| ) | |||||
| MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict( | |||||
| [ | |||||
| # Model for pre-training mapping | |||||
| ("fnet", "FNetForPreTraining"), | |||||
| ("visual_bert", "VisualBertForPreTraining"), | |||||
| ("layoutlm", "LayoutLMForMaskedLM"), | |||||
| ("retribert", "RetriBertModel"), | |||||
| ("t5", "T5ForConditionalGeneration"), | |||||
| ("distilbert", "DistilBertForMaskedLM"), | |||||
| ("albert", "AlbertForPreTraining"), | |||||
| ("camembert", "CamembertForMaskedLM"), | |||||
| ("xlm-roberta", "XLMRobertaForMaskedLM"), | |||||
| ("bart", "BartForConditionalGeneration"), | |||||
| ("fsmt", "FSMTForConditionalGeneration"), | |||||
| ("longformer", "LongformerForMaskedLM"), | |||||
| ("roberta", "RobertaForMaskedLM"), | |||||
| ("squeezebert", "SqueezeBertForMaskedLM"), | |||||
| ("bert", "BertForPreTraining"), | |||||
| ("big_bird", "BigBirdForPreTraining"), | |||||
| ("openai-gpt", "OpenAIGPTLMHeadModel"), | |||||
| ("gpt2", "GPT2LMHeadModel"), | |||||
| ("megatron-bert", "MegatronBertForPreTraining"), | |||||
| ("mobilebert", "MobileBertForPreTraining"), | |||||
| ("transfo-xl", "TransfoXLLMHeadModel"), | |||||
| ("xlnet", "XLNetLMHeadModel"), | |||||
| ("flaubert", "FlaubertWithLMHeadModel"), | |||||
| ("xlm", "XLMWithLMHeadModel"), | |||||
| ("ctrl", "CTRLLMHeadModel"), | |||||
| ("electra", "ElectraForPreTraining"), | |||||
| ("lxmert", "LxmertForPreTraining"), | |||||
| ("funnel", "FunnelForPreTraining"), | |||||
| ("mpnet", "MPNetForMaskedLM"), | |||||
| ("tapas", "TapasForMaskedLM"), | |||||
| ("ibert", "IBertForMaskedLM"), | |||||
| ("deberta", "DebertaForMaskedLM"), | |||||
| ("deberta-v2", "DebertaV2ForMaskedLM"), | |||||
| ("wav2vec2", "Wav2Vec2ForPreTraining"), | |||||
| ] | |||||
| ) | |||||
| MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict( | |||||
| [ | |||||
| # Model with LM heads mapping | |||||
| ("fnet", "FNetForMaskedLM"), | |||||
| ("gptj", "GPTJForCausalLM"), | |||||
| ("rembert", "RemBertForMaskedLM"), | |||||
| ("roformer", "RoFormerForMaskedLM"), | |||||
| ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"), | |||||
| ("gpt_neo", "GPTNeoForCausalLM"), | |||||
| ("big_bird", "BigBirdForMaskedLM"), | |||||
| ("speech_to_text", "Speech2TextForConditionalGeneration"), | |||||
| ("wav2vec2", "Wav2Vec2ForMaskedLM"), | |||||
| ("m2m_100", "M2M100ForConditionalGeneration"), | |||||
| ("convbert", "ConvBertForMaskedLM"), | |||||
| ("led", "LEDForConditionalGeneration"), | |||||
| ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"), | |||||
| ("layoutlm", "LayoutLMForMaskedLM"), | |||||
| ("t5", "T5ForConditionalGeneration"), | |||||
| ("distilbert", "DistilBertForMaskedLM"), | |||||
| ("albert", "AlbertForMaskedLM"), | |||||
| ("camembert", "CamembertForMaskedLM"), | |||||
| ("xlm-roberta", "XLMRobertaForMaskedLM"), | |||||
| ("marian", "MarianMTModel"), | |||||
| ("fsmt", "FSMTForConditionalGeneration"), | |||||
| ("bart", "BartForConditionalGeneration"), | |||||
| ("longformer", "LongformerForMaskedLM"), | |||||
| ("roberta", "RobertaForMaskedLM"), | |||||
| ("squeezebert", "SqueezeBertForMaskedLM"), | |||||
| ("bert", "BertForMaskedLM"), | |||||
| ("openai-gpt", "OpenAIGPTLMHeadModel"), | |||||
| ("gpt2", "GPT2LMHeadModel"), | |||||
| ("megatron-bert", "MegatronBertForCausalLM"), | |||||
| ("mobilebert", "MobileBertForMaskedLM"), | |||||
| ("transfo-xl", "TransfoXLLMHeadModel"), | |||||
| ("xlnet", "XLNetLMHeadModel"), | |||||
| ("flaubert", "FlaubertWithLMHeadModel"), | |||||
| ("xlm", "XLMWithLMHeadModel"), | |||||
| ("ctrl", "CTRLLMHeadModel"), | |||||
| ("electra", "ElectraForMaskedLM"), | |||||
| ("encoder-decoder", "EncoderDecoderModel"), | |||||
| ("reformer", "ReformerModelWithLMHead"), | |||||
| ("funnel", "FunnelForMaskedLM"), | |||||
| ("mpnet", "MPNetForMaskedLM"), | |||||
| ("tapas", "TapasForMaskedLM"), | |||||
| ("deberta", "DebertaForMaskedLM"), | |||||
| ("deberta-v2", "DebertaV2ForMaskedLM"), | |||||
| ("ibert", "IBertForMaskedLM"), | |||||
| ] | |||||
| ) | |||||
| MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( | |||||
| [ | |||||
| # Model for Causal LM mapping | |||||
| ("gptj", "GPTJForCausalLM"), | |||||
| ("rembert", "RemBertForCausalLM"), | |||||
| ("roformer", "RoFormerForCausalLM"), | |||||
| ("bigbird_pegasus", "BigBirdPegasusForCausalLM"), | |||||
| ("gpt_neo", "GPTNeoForCausalLM"), | |||||
| ("big_bird", "BigBirdForCausalLM"), | |||||
| ("camembert", "CamembertForCausalLM"), | |||||
| ("xlm-roberta", "XLMRobertaForCausalLM"), | |||||
| ("roberta", "RobertaForCausalLM"), | |||||
| ("bert", "BertLMHeadModel"), | |||||
| ("openai-gpt", "OpenAIGPTLMHeadModel"), | |||||
| ("gpt2", "GPT2LMHeadModel"), | |||||
| ("transfo-xl", "TransfoXLLMHeadModel"), | |||||
| ("xlnet", "XLNetLMHeadModel"), | |||||
| ("xlm", "XLMWithLMHeadModel"), | |||||
| ("ctrl", "CTRLLMHeadModel"), | |||||
| ("reformer", "ReformerModelWithLMHead"), | |||||
| ("bert-generation", "BertGenerationDecoder"), | |||||
| ("xlm-prophetnet", "XLMProphetNetForCausalLM"), | |||||
| ("prophetnet", "ProphetNetForCausalLM"), | |||||
| ("bart", "BartForCausalLM"), | |||||
| ("mbart", "MBartForCausalLM"), | |||||
| ("pegasus", "PegasusForCausalLM"), | |||||
| ("marian", "MarianForCausalLM"), | |||||
| ("blenderbot", "BlenderbotForCausalLM"), | |||||
| ("blenderbot-small", "BlenderbotSmallForCausalLM"), | |||||
| ("megatron-bert", "MegatronBertForCausalLM"), | |||||
| ("speech_to_text_2", "Speech2Text2ForCausalLM"), | |||||
| ] | |||||
| ) | |||||
| MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( | |||||
| [ | |||||
| # Model for Image Classification mapping | |||||
| ("vit", "ViTForImageClassification"), | |||||
| ("deit", ("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher")), | |||||
| ("beit", "BeitForImageClassification"), | |||||
| ] | |||||
| ) | |||||
| MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict( | |||||
| [ | |||||
| # Model for Masked LM mapping | |||||
| ("fnet", "FNetForMaskedLM"), | |||||
| ("rembert", "RemBertForMaskedLM"), | |||||
| ("roformer", "RoFormerForMaskedLM"), | |||||
| ("big_bird", "BigBirdForMaskedLM"), | |||||
| ("wav2vec2", "Wav2Vec2ForMaskedLM"), | |||||
| ("convbert", "ConvBertForMaskedLM"), | |||||
| ("layoutlm", "LayoutLMForMaskedLM"), | |||||
| ("distilbert", "DistilBertForMaskedLM"), | |||||
| ("albert", "AlbertForMaskedLM"), | |||||
| ("bart", "BartForConditionalGeneration"), | |||||
| ("mbart", "MBartForConditionalGeneration"), | |||||
| ("camembert", "CamembertForMaskedLM"), | |||||
| ("xlm-roberta", "XLMRobertaForMaskedLM"), | |||||
| ("longformer", "LongformerForMaskedLM"), | |||||
| ("roberta", "RobertaForMaskedLM"), | |||||
| ("squeezebert", "SqueezeBertForMaskedLM"), | |||||
| ("bert", "BertForMaskedLM"), | |||||
| ("megatron-bert", "MegatronBertForMaskedLM"), | |||||
| ("mobilebert", "MobileBertForMaskedLM"), | |||||
| ("flaubert", "FlaubertWithLMHeadModel"), | |||||
| ("xlm", "XLMWithLMHeadModel"), | |||||
| ("electra", "ElectraForMaskedLM"), | |||||
| ("reformer", "ReformerForMaskedLM"), | |||||
| ("funnel", "FunnelForMaskedLM"), | |||||
| ("mpnet", "MPNetForMaskedLM"), | |||||
| ("tapas", "TapasForMaskedLM"), | |||||
| ("deberta", "DebertaForMaskedLM"), | |||||
| ("deberta-v2", "DebertaV2ForMaskedLM"), | |||||
| ("ibert", "IBertForMaskedLM"), | |||||
| ] | |||||
| ) | |||||
| MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict( | |||||
| [ | |||||
| # Model for Object Detection mapping | |||||
| ("detr", "DetrForObjectDetection"), | |||||
| ] | |||||
| ) | |||||
| MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict( | |||||
| [ | |||||
| # Model for Seq2Seq Causal LM mapping | |||||
| ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"), | |||||
| ("m2m_100", "M2M100ForConditionalGeneration"), | |||||
| ("led", "LEDForConditionalGeneration"), | |||||
| ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"), | |||||
| ("mt5", "MT5ForConditionalGeneration"), | |||||
| ("t5", "T5ForConditionalGeneration"), | |||||
| ("pegasus", "PegasusForConditionalGeneration"), | |||||
| ("marian", "MarianMTModel"), | |||||
| ("mbart", "MBartForConditionalGeneration"), | |||||
| ("blenderbot", "BlenderbotForConditionalGeneration"), | |||||
| ("bart", "BartForConditionalGeneration"), | |||||
| ("fsmt", "FSMTForConditionalGeneration"), | |||||
| ("encoder-decoder", "EncoderDecoderModel"), | |||||
| ("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"), | |||||
| ("prophetnet", "ProphetNetForConditionalGeneration"), | |||||
| ] | |||||
| ) | |||||
| MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict( | |||||
| [ | |||||
| ("speech-encoder-decoder", "SpeechEncoderDecoderModel"), | |||||
| ("speech_to_text", "Speech2TextForConditionalGeneration"), | |||||
| ] | |||||
| ) | |||||
| MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( | |||||
| [ | |||||
| # Model for Sequence Classification mapping | |||||
| ("fnet", "FNetForSequenceClassification"), | |||||
| ("gptj", "GPTJForSequenceClassification"), | |||||
| ("layoutlmv2", "LayoutLMv2ForSequenceClassification"), | |||||
| ("rembert", "RemBertForSequenceClassification"), | |||||
| ("canine", "CanineForSequenceClassification"), | |||||
| ("roformer", "RoFormerForSequenceClassification"), | |||||
| ("bigbird_pegasus", "BigBirdPegasusForSequenceClassification"), | |||||
| ("big_bird", "BigBirdForSequenceClassification"), | |||||
| ("convbert", "ConvBertForSequenceClassification"), | |||||
| ("led", "LEDForSequenceClassification"), | |||||
| ("distilbert", "DistilBertForSequenceClassification"), | |||||
| ("albert", "AlbertForSequenceClassification"), | |||||
| ("camembert", "CamembertForSequenceClassification"), | |||||
| ("xlm-roberta", "XLMRobertaForSequenceClassification"), | |||||
| ("mbart", "MBartForSequenceClassification"), | |||||
| ("bart", "BartForSequenceClassification"), | |||||
| ("longformer", "LongformerForSequenceClassification"), | |||||
| ("roberta", "RobertaForSequenceClassification"), | |||||
| ("squeezebert", "SqueezeBertForSequenceClassification"), | |||||
| ("layoutlm", "LayoutLMForSequenceClassification"), | |||||
| ("bert", "BertForSequenceClassification"), | |||||
| ("xlnet", "XLNetForSequenceClassification"), | |||||
| ("megatron-bert", "MegatronBertForSequenceClassification"), | |||||
| ("mobilebert", "MobileBertForSequenceClassification"), | |||||
| ("flaubert", "FlaubertForSequenceClassification"), | |||||
| ("xlm", "XLMForSequenceClassification"), | |||||
| ("electra", "ElectraForSequenceClassification"), | |||||
| ("funnel", "FunnelForSequenceClassification"), | |||||
| ("deberta", "DebertaForSequenceClassification"), | |||||
| ("deberta-v2", "DebertaV2ForSequenceClassification"), | |||||
| ("gpt2", "GPT2ForSequenceClassification"), | |||||
| ("gpt_neo", "GPTNeoForSequenceClassification"), | |||||
| ("openai-gpt", "OpenAIGPTForSequenceClassification"), | |||||
| ("reformer", "ReformerForSequenceClassification"), | |||||
| ("ctrl", "CTRLForSequenceClassification"), | |||||
| ("transfo-xl", "TransfoXLForSequenceClassification"), | |||||
| ("mpnet", "MPNetForSequenceClassification"), | |||||
| ("tapas", "TapasForSequenceClassification"), | |||||
| ("ibert", "IBertForSequenceClassification"), | |||||
| ] | |||||
| ) | |||||
| MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( | |||||
| [ | |||||
| # Model for Question Answering mapping | |||||
| ("fnet", "FNetForQuestionAnswering"), | |||||
| ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"), | |||||
| ("rembert", "RemBertForQuestionAnswering"), | |||||
| ("canine", "CanineForQuestionAnswering"), | |||||
| ("roformer", "RoFormerForQuestionAnswering"), | |||||
| ("bigbird_pegasus", "BigBirdPegasusForQuestionAnswering"), | |||||
| ("big_bird", "BigBirdForQuestionAnswering"), | |||||
| ("convbert", "ConvBertForQuestionAnswering"), | |||||
| ("led", "LEDForQuestionAnswering"), | |||||
| ("distilbert", "DistilBertForQuestionAnswering"), | |||||
| ("albert", "AlbertForQuestionAnswering"), | |||||
| ("camembert", "CamembertForQuestionAnswering"), | |||||
| ("bart", "BartForQuestionAnswering"), | |||||
| ("mbart", "MBartForQuestionAnswering"), | |||||
| ("longformer", "LongformerForQuestionAnswering"), | |||||
| ("xlm-roberta", "XLMRobertaForQuestionAnswering"), | |||||
| ("roberta", "RobertaForQuestionAnswering"), | |||||
| ("squeezebert", "SqueezeBertForQuestionAnswering"), | |||||
| ("bert", "BertForQuestionAnswering"), | |||||
| ("xlnet", "XLNetForQuestionAnsweringSimple"), | |||||
| ("flaubert", "FlaubertForQuestionAnsweringSimple"), | |||||
| ("megatron-bert", "MegatronBertForQuestionAnswering"), | |||||
| ("mobilebert", "MobileBertForQuestionAnswering"), | |||||
| ("xlm", "XLMForQuestionAnsweringSimple"), | |||||
| ("electra", "ElectraForQuestionAnswering"), | |||||
| ("reformer", "ReformerForQuestionAnswering"), | |||||
| ("funnel", "FunnelForQuestionAnswering"), | |||||
| ("lxmert", "LxmertForQuestionAnswering"), | |||||
| ("mpnet", "MPNetForQuestionAnswering"), | |||||
| ("deberta", "DebertaForQuestionAnswering"), | |||||
| ("deberta-v2", "DebertaV2ForQuestionAnswering"), | |||||
| ("ibert", "IBertForQuestionAnswering"), | |||||
| ("splinter", "SplinterForQuestionAnswering"), | |||||
| ] | |||||
| ) | |||||
| MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( | |||||
| [ | |||||
| # Model for Table Question Answering mapping | |||||
| ("tapas", "TapasForQuestionAnswering"), | |||||
| ] | |||||
| ) | |||||
| MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( | |||||
| [ | |||||
| # Model for Token Classification mapping | |||||
| ("fnet", "FNetForTokenClassification"), | |||||
| ("layoutlmv2", "LayoutLMv2ForTokenClassification"), | |||||
| ("rembert", "RemBertForTokenClassification"), | |||||
| ("canine", "CanineForTokenClassification"), | |||||
| ("roformer", "RoFormerForTokenClassification"), | |||||
| ("big_bird", "BigBirdForTokenClassification"), | |||||
| ("convbert", "ConvBertForTokenClassification"), | |||||
| ("layoutlm", "LayoutLMForTokenClassification"), | |||||
| ("distilbert", "DistilBertForTokenClassification"), | |||||
| ("camembert", "CamembertForTokenClassification"), | |||||
| ("flaubert", "FlaubertForTokenClassification"), | |||||
| ("xlm", "XLMForTokenClassification"), | |||||
| ("xlm-roberta", "XLMRobertaForTokenClassification"), | |||||
| ("longformer", "LongformerForTokenClassification"), | |||||
| ("roberta", "RobertaForTokenClassification"), | |||||
| ("squeezebert", "SqueezeBertForTokenClassification"), | |||||
| ("bert", "BertForTokenClassification"), | |||||
| ("megatron-bert", "MegatronBertForTokenClassification"), | |||||
| ("mobilebert", "MobileBertForTokenClassification"), | |||||
| ("xlnet", "XLNetForTokenClassification"), | |||||
| ("albert", "AlbertForTokenClassification"), | |||||
| ("electra", "ElectraForTokenClassification"), | |||||
| ("funnel", "FunnelForTokenClassification"), | |||||
| ("mpnet", "MPNetForTokenClassification"), | |||||
| ("deberta", "DebertaForTokenClassification"), | |||||
| ("deberta-v2", "DebertaV2ForTokenClassification"), | |||||
| ("gpt2", "GPT2ForTokenClassification"), | |||||
| ("ibert", "IBertForTokenClassification"), | |||||
| ] | |||||
| ) | |||||
| MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict( | |||||
| [ | |||||
| # Model for Multiple Choice mapping | |||||
| ("fnet", "FNetForMultipleChoice"), | |||||
| ("rembert", "RemBertForMultipleChoice"), | |||||
| ("canine", "CanineForMultipleChoice"), | |||||
| ("roformer", "RoFormerForMultipleChoice"), | |||||
| ("big_bird", "BigBirdForMultipleChoice"), | |||||
| ("convbert", "ConvBertForMultipleChoice"), | |||||
| ("camembert", "CamembertForMultipleChoice"), | |||||
| ("electra", "ElectraForMultipleChoice"), | |||||
| ("xlm-roberta", "XLMRobertaForMultipleChoice"), | |||||
| ("longformer", "LongformerForMultipleChoice"), | |||||
| ("roberta", "RobertaForMultipleChoice"), | |||||
| ("squeezebert", "SqueezeBertForMultipleChoice"), | |||||
| ("bert", "BertForMultipleChoice"), | |||||
| ("distilbert", "DistilBertForMultipleChoice"), | |||||
| ("megatron-bert", "MegatronBertForMultipleChoice"), | |||||
| ("mobilebert", "MobileBertForMultipleChoice"), | |||||
| ("xlnet", "XLNetForMultipleChoice"), | |||||
| ("albert", "AlbertForMultipleChoice"), | |||||
| ("xlm", "XLMForMultipleChoice"), | |||||
| ("flaubert", "FlaubertForMultipleChoice"), | |||||
| ("funnel", "FunnelForMultipleChoice"), | |||||
| ("mpnet", "MPNetForMultipleChoice"), | |||||
| ("ibert", "IBertForMultipleChoice"), | |||||
| ] | |||||
| ) | |||||
| MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict( | |||||
| [ | |||||
| ("bert", "BertForNextSentencePrediction"), | |||||
| ("fnet", "FNetForNextSentencePrediction"), | |||||
| ("megatron-bert", "MegatronBertForNextSentencePrediction"), | |||||
| ("mobilebert", "MobileBertForNextSentencePrediction"), | |||||
| ] | |||||
| ) | |||||
| MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict( | |||||
| [ | |||||
| # Model for Audio Classification mapping | |||||
| ("wav2vec2", "Wav2Vec2ForSequenceClassification"), | |||||
| ("hubert", "HubertForSequenceClassification"), | |||||
| ] | |||||
| ) | |||||
| MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict( | |||||
| [ | |||||
| # Model for Connectionist temporal classification (CTC) mapping | |||||
| ("wav2vec2", "Wav2Vec2ForCTC"), | |||||
| ("hubert", "HubertForCTC"), | |||||
| ] | |||||
| ) | |||||
| MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES) | |||||
| MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES) | |||||
| MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES) | |||||
| MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES) | |||||
| MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping( | |||||
| CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES | |||||
| ) | |||||
| MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES) | |||||
| MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES) | |||||
| MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping( | |||||
| CONFIG_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES | |||||
| ) | |||||
| MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping( | |||||
| CONFIG_MAPPING_NAMES, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES | |||||
| ) | |||||
| MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping( | |||||
| CONFIG_MAPPING_NAMES, MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES | |||||
| ) | |||||
| MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping( | |||||
| CONFIG_MAPPING_NAMES, MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES | |||||
| ) | |||||
| MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping( | |||||
| CONFIG_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES | |||||
| ) | |||||
| MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES) | |||||
| MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping( | |||||
| CONFIG_MAPPING_NAMES, MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES | |||||
| ) | |||||
| MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping( | |||||
| CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES | |||||
| ) | |||||
| MODEL_FOR_CTC_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CTC_MAPPING_NAMES) | |||||
| MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES) | |||||
| class AutoModel(_BaseAutoModelClass): | |||||
| _model_mapping = MODEL_MAPPING | |||||
| AutoModel = auto_class_update(AutoModel) | |||||
| class AutoModelForPreTraining(_BaseAutoModelClass): | |||||
| _model_mapping = MODEL_FOR_PRETRAINING_MAPPING | |||||
| AutoModelForPreTraining = auto_class_update(AutoModelForPreTraining, head_doc="pretraining") | |||||
| # Private on purpose, the public class will add the deprecation warnings. | |||||
| class _AutoModelWithLMHead(_BaseAutoModelClass): | |||||
| _model_mapping = MODEL_WITH_LM_HEAD_MAPPING | |||||
| _AutoModelWithLMHead = auto_class_update(_AutoModelWithLMHead, head_doc="language modeling") | |||||
| class AutoModelForCausalLM(_BaseAutoModelClass): | |||||
| _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING | |||||
| AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling") | |||||
| class AutoModelForMaskedLM(_BaseAutoModelClass): | |||||
| _model_mapping = MODEL_FOR_MASKED_LM_MAPPING | |||||
| AutoModelForMaskedLM = auto_class_update(AutoModelForMaskedLM, head_doc="masked language modeling") | |||||
| class AutoModelForSeq2SeqLM(_BaseAutoModelClass): | |||||
| _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING | |||||
| AutoModelForSeq2SeqLM = auto_class_update( | |||||
| AutoModelForSeq2SeqLM, head_doc="sequence-to-sequence language modeling", checkpoint_for_example="t5-base" | |||||
| ) | |||||
| class AutoModelForSequenceClassification(_BaseAutoModelClass): | |||||
| _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING | |||||
| AutoModelForSequenceClassification = auto_class_update( | |||||
| AutoModelForSequenceClassification, head_doc="sequence classification" | |||||
| ) | |||||
| class AutoModelForQuestionAnswering(_BaseAutoModelClass): | |||||
| _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING | |||||
| AutoModelForQuestionAnswering = auto_class_update(AutoModelForQuestionAnswering, head_doc="question answering") | |||||
| class AutoModelForTableQuestionAnswering(_BaseAutoModelClass): | |||||
| _model_mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING | |||||
| AutoModelForTableQuestionAnswering = auto_class_update( | |||||
| AutoModelForTableQuestionAnswering, | |||||
| head_doc="table question answering", | |||||
| checkpoint_for_example="google/tapas-base-finetuned-wtq", | |||||
| ) | |||||
| class AutoModelForTokenClassification(_BaseAutoModelClass): | |||||
| _model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING | |||||
| AutoModelForTokenClassification = auto_class_update(AutoModelForTokenClassification, head_doc="token classification") | |||||
| class AutoModelForMultipleChoice(_BaseAutoModelClass): | |||||
| _model_mapping = MODEL_FOR_MULTIPLE_CHOICE_MAPPING | |||||
| AutoModelForMultipleChoice = auto_class_update(AutoModelForMultipleChoice, head_doc="multiple choice") | |||||
| class AutoModelForNextSentencePrediction(_BaseAutoModelClass): | |||||
| _model_mapping = MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING | |||||
| AutoModelForNextSentencePrediction = auto_class_update( | |||||
| AutoModelForNextSentencePrediction, head_doc="next sentence prediction" | |||||
| ) | |||||
| class AutoModelForImageClassification(_BaseAutoModelClass): | |||||
| _model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING | |||||
| AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification") | |||||
| class AutoModelForObjectDetection(_BaseAutoModelClass): | |||||
| _model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING | |||||
| AutoModelForObjectDetection = auto_class_update(AutoModelForObjectDetection, head_doc="object detection") | |||||
| class AutoModelForAudioClassification(_BaseAutoModelClass): | |||||
| _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING | |||||
| AutoModelForAudioClassification = auto_class_update(AutoModelForAudioClassification, head_doc="audio classification") | |||||
| class AutoModelForCTC(_BaseAutoModelClass): | |||||
| _model_mapping = MODEL_FOR_CTC_MAPPING | |||||
| AutoModelForCTC = auto_class_update(AutoModelForCTC, head_doc="connectionist temporal classification") | |||||
| class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass): | |||||
| _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING | |||||
| AutoModelForSpeechSeq2Seq = auto_class_update( | |||||
| AutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeing" | |||||
| ) | |||||
| class AutoModelWithLMHead(_AutoModelWithLMHead): | |||||
| @classmethod | |||||
| def from_config(cls, config): | |||||
| warnings.warn( | |||||
| "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use " | |||||
| "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and " | |||||
| "`AutoModelForSeq2SeqLM` for encoder-decoder models.", | |||||
| FutureWarning, | |||||
| ) | |||||
| return super().from_config(config) | |||||
| @classmethod | |||||
| def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): | |||||
| warnings.warn( | |||||
| "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use " | |||||
| "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and " | |||||
| "`AutoModelForSeq2SeqLM` for encoder-decoder models.", | |||||
| FutureWarning, | |||||
| ) | |||||
| return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) | |||||