| @@ -1,26 +1,32 @@ | |||
| awk -F: '/^[^#]/ { print $1 }' requirements/framework.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||
| awk -F: '/^[^#]/ { print $1 }' requirements/audio.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||
| awk -F: '/^[^#]/ { print $1 }' requirements/cv.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||
| awk -F: '/^[^#]/ { print $1 }' requirements/multi-modal.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||
| awk -F: '/^[^#]/ { print $1 }' requirements/nlp.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||
| pip install -r requirements/tests.txt | |||
| echo "Testing envs" | |||
| printenv | |||
| echo "ENV END" | |||
| if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then | |||
| awk -F: '/^[^#]/ { print $1 }' requirements/framework.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||
| awk -F: '/^[^#]/ { print $1 }' requirements/audio.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||
| awk -F: '/^[^#]/ { print $1 }' requirements/cv.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||
| awk -F: '/^[^#]/ { print $1 }' requirements/multi-modal.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||
| awk -F: '/^[^#]/ { print $1 }' requirements/nlp.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||
| pip install -r requirements/tests.txt | |||
| git config --global --add safe.directory /Maas-lib | |||
| git config --global user.email tmp | |||
| git config --global user.name tmp.com | |||
| git config --global --add safe.directory /Maas-lib | |||
| git config --global user.email tmp | |||
| git config --global user.name tmp.com | |||
| # linter test | |||
| # use internal project for pre-commit due to the network problem | |||
| if [ `git remote -v | grep alibaba | wc -l` -gt 1 ]; then | |||
| pre-commit run -c .pre-commit-config_local.yaml --all-files | |||
| fi | |||
| if [ $? -ne 0 ]; then | |||
| echo "linter test failed, please run 'pre-commit run --all-files' to check" | |||
| exit -1 | |||
| # linter test | |||
| # use internal project for pre-commit due to the network problem | |||
| if [ `git remote -v | grep alibaba | wc -l` -gt 1 ]; then | |||
| pre-commit run -c .pre-commit-config_local.yaml --all-files | |||
| if [ $? -ne 0 ]; then | |||
| echo "linter test failed, please run 'pre-commit run --all-files' to check" | |||
| exit -1 | |||
| fi | |||
| fi | |||
| # test with install | |||
| python setup.py install | |||
| else | |||
| echo "Running case in release image, run case directly!" | |||
| fi | |||
| # test with install | |||
| python setup.py install | |||
| if [ $# -eq 0 ]; then | |||
| ci_command="python tests/run.py --subprocess" | |||
| else | |||
| @@ -20,28 +20,52 @@ do | |||
| # pull image if there are update | |||
| docker pull ${IMAGE_NAME}:${IMAGE_VERSION} | |||
| docker run --rm --name $CONTAINER_NAME --shm-size=16gb \ | |||
| --cpuset-cpus=${cpu_sets_arr[$gpu]} \ | |||
| --gpus="device=$gpu" \ | |||
| -v $CODE_DIR:$CODE_DIR_IN_CONTAINER \ | |||
| -v $MODELSCOPE_CACHE:$MODELSCOPE_CACHE_DIR_IN_CONTAINER \ | |||
| -v $MODELSCOPE_HOME_CACHE/$gpu:/root \ | |||
| -v /home/admin/pre-commit:/home/admin/pre-commit \ | |||
| -e CI_TEST=True \ | |||
| -e TEST_LEVEL=$TEST_LEVEL \ | |||
| -e MODELSCOPE_CACHE=$MODELSCOPE_CACHE_DIR_IN_CONTAINER \ | |||
| -e MODELSCOPE_DOMAIN=$MODELSCOPE_DOMAIN \ | |||
| -e HUB_DATASET_ENDPOINT=$HUB_DATASET_ENDPOINT \ | |||
| -e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \ | |||
| -e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \ | |||
| -e TEST_LEVEL=$TEST_LEVEL \ | |||
| -e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \ | |||
| -e MODEL_TAG_URL=$MODEL_TAG_URL \ | |||
| --workdir=$CODE_DIR_IN_CONTAINER \ | |||
| --net host \ | |||
| ${IMAGE_NAME}:${IMAGE_VERSION} \ | |||
| $CI_COMMAND | |||
| if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then | |||
| docker run --rm --name $CONTAINER_NAME --shm-size=16gb \ | |||
| --cpuset-cpus=${cpu_sets_arr[$gpu]} \ | |||
| --gpus="device=$gpu" \ | |||
| -v $CODE_DIR:$CODE_DIR_IN_CONTAINER \ | |||
| -v $MODELSCOPE_CACHE:$MODELSCOPE_CACHE_DIR_IN_CONTAINER \ | |||
| -v $MODELSCOPE_HOME_CACHE/$gpu:/root \ | |||
| -v /home/admin/pre-commit:/home/admin/pre-commit \ | |||
| -e CI_TEST=True \ | |||
| -e TEST_LEVEL=$TEST_LEVEL \ | |||
| -e MODELSCOPE_CACHE=$MODELSCOPE_CACHE_DIR_IN_CONTAINER \ | |||
| -e MODELSCOPE_DOMAIN=$MODELSCOPE_DOMAIN \ | |||
| -e MODELSCOPE_SDK_DEBUG=True \ | |||
| -e HUB_DATASET_ENDPOINT=$HUB_DATASET_ENDPOINT \ | |||
| -e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \ | |||
| -e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \ | |||
| -e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \ | |||
| -e MODEL_TAG_URL=$MODEL_TAG_URL \ | |||
| --workdir=$CODE_DIR_IN_CONTAINER \ | |||
| --net host \ | |||
| ${IMAGE_NAME}:${IMAGE_VERSION} \ | |||
| $CI_COMMAND | |||
| else | |||
| docker run --rm --name $CONTAINER_NAME --shm-size=16gb \ | |||
| --cpuset-cpus=${cpu_sets_arr[$gpu]} \ | |||
| --gpus="device=$gpu" \ | |||
| -v $CODE_DIR:$CODE_DIR_IN_CONTAINER \ | |||
| -v $MODELSCOPE_CACHE:$MODELSCOPE_CACHE_DIR_IN_CONTAINER \ | |||
| -v $MODELSCOPE_HOME_CACHE/$gpu:/root \ | |||
| -v /home/admin/pre-commit:/home/admin/pre-commit \ | |||
| -e CI_TEST=True \ | |||
| -e TEST_LEVEL=$TEST_LEVEL \ | |||
| -e MODELSCOPE_CACHE=$MODELSCOPE_CACHE_DIR_IN_CONTAINER \ | |||
| -e MODELSCOPE_DOMAIN=$MODELSCOPE_DOMAIN \ | |||
| -e HUB_DATASET_ENDPOINT=$HUB_DATASET_ENDPOINT \ | |||
| -e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \ | |||
| -e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \ | |||
| -e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \ | |||
| -e MODEL_TAG_URL=$MODEL_TAG_URL \ | |||
| --workdir=$CODE_DIR_IN_CONTAINER \ | |||
| --net host \ | |||
| ${IMAGE_NAME}:${IMAGE_VERSION} \ | |||
| $CI_COMMAND | |||
| fi | |||
| if [ $? -ne 0 ]; then | |||
| echo "Running test case failed, please check the log!" | |||
| exit -1 | |||
| @@ -1,6 +1,6 @@ | |||
| repos: | |||
| - repo: https://gitlab.com/pycqa/flake8.git | |||
| rev: 3.8.3 | |||
| rev: 4.0.0 | |||
| hooks: | |||
| - id: flake8 | |||
| exclude: thirdparty/|examples/ | |||
| @@ -1,6 +1,6 @@ | |||
| repos: | |||
| - repo: /home/admin/pre-commit/flake8 | |||
| rev: 3.8.3 | |||
| rev: 4.0.0 | |||
| hooks: | |||
| - id: flake8 | |||
| exclude: thirdparty/|examples/ | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:e999c247bfebb03d556a31722f0ce7145cac20a67fac9da813ad336e1f549f9f | |||
| size 38954 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:32eb8d4d537941bf0edea69cd6723e8ba489fa3df64e13e29f96e4fae0b856f4 | |||
| size 93676 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:f57aee13ade70be6b2c6e4f5e5c7404bdb03057b63828baefbaadcf23855a4cb | |||
| size 472012 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:fee8e0460ca707f108782be0d93c555bf34fb6b1cb297e5fceed70192cc65f9b | |||
| size 71244 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:450e31f9df8c5b48c617900625f01cb64c484f079a9843179fe9feaa7d163e61 | |||
| size 181964 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:255494c41bc1dfb0c954d827ec6ce775900e4f7a55fb0a7881bdf9d66a03b425 | |||
| size 112078 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:22a55277908bbc3ef60a0cf56b230eb507b9e837574e8f493e93644b1d21c281 | |||
| size 200556 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:ee92191836c76412463d8b282a7ab4e1aa57386ba699ec011a3e2c4d64f32f4b | |||
| size 162636 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:77d1537fc584c1505d8aa10ec8c86af57ab661199e4f28fd7ffee3c22d1e4e61 | |||
| size 160204 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:e8d653a9a1ee49789c3df38e8da96af7118e0d8336d6ed12cd6458efa015071d | |||
| size 2327764 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:c589d77404ea17d4d24daeb8624dce7e1ac919dc75e6bed44ea9d116f0514150 | |||
| size 68524 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:76bf84536edbaf192a8a699efc62ba2b06056bac12c426ecfcc2e003d91fbd32 | |||
| size 53219 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:ecbc9d0827cfb92e93e7d75868b1724142685dc20d3b32023c3c657a7b688a9c | |||
| size 254845 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:d510ab26ddc58ffea882c8ef850c1f9bd4444772f2bce7ebea3e76944536c3ae | |||
| size 48909 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:b2c1119e3d521cf2e583b1e85fc9c9afd1d44954b433135039a98050a730932d | |||
| size 1127557 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:46db348eae61448f1668ce282caec21375e96c3268d53da44aa67ec32cbf4fa5 | |||
| size 2747938 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:709c1828ed2d56badf2f19a40194da9a5e5e6db2fb73ef55d047407f49bc7a15 | |||
| size 27616 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:772b19f76c98044e39330853928624f10e085106a4292b4dd19f865531080747 | |||
| size 959 | |||
| @@ -1,3 +0,0 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:379e11d7fc3734d3ec95afd0d86460b4653fbf4bb1f57f993610d6a6fd30fd3d | |||
| size 1702339 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:dec0fbb931cb609bf481e56b89cd2fbbab79839f22832c3bbe69a8fae2769cdd | |||
| size 167407 | |||
| @@ -1,3 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:4fd6fa6b23c2fdaf876606a767d9b64b1924e1acddfc06ac42db73ba86083280 | |||
| size 119940 | |||
| oid sha256:4eae921001139d7e3c06331c9ef2213f8fc1c23512acd95751559866fb770e96 | |||
| size 121855 | |||
| @@ -1,3 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:4d37672a0e299a08d2daf5c7fc29bfce96bb15701fe5e5e68f068861ac2ee705 | |||
| size 119619 | |||
| oid sha256:f97d34d7450d17d0a93647129ab10d16b1f6e70c34a73b6f7687b79519ee4f71 | |||
| size 121563 | |||
| @@ -1,3 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:c692e0753cfe349e520511427727a8252f141fa10e85f9a61562845e8d731f9a | |||
| size 119619 | |||
| oid sha256:a8355f27a3235209f206b5e75f4400353e5989e94cf4d71270b42ded8821d536 | |||
| size 121563 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:344ef971bdf310b76c6571d1f4994ab6abc5edc659654d71a4f75b14a30960c2 | |||
| size 152926 | |||
| @@ -1,3 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:44e3925c15d86d8596baeb6bd1d153d86f57b7489798b2cf988a1248e110fd62 | |||
| size 62231 | |||
| oid sha256:f0aeb07b6c9b40a0cfa7492e839431764e9bece93c906833a07c05e83520a399 | |||
| size 63161 | |||
| @@ -1,3 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:1ff17a0272752de4c88d4254b2e881f97f8ef022f03609d03ee1de0ae964368a | |||
| size 62235 | |||
| oid sha256:7aa5c7a2565ccf0d2eea4baf8adbd0e020dbe36a7159b31156c53141cc9b2df2 | |||
| size 63165 | |||
| @@ -1,3 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:9103ce2bc89212f67fb49ce70783b7667e376900d0f70fb8f5c4432eb74bc572 | |||
| size 60801 | |||
| oid sha256:cc6de82a8485fbfa008f6c2d5411cd07ba03e4a780bcb4e67efc6fba3c6ce92f | |||
| size 63597 | |||
| @@ -1,3 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:2d4dee34c7e83b77db04fb2f0d1200bfd37c7c24954c58e185da5cb96445975c | |||
| size 60801 | |||
| oid sha256:7d98ac11a4e9e2744a7402a5cc912da991a41938bbc5dd60f15ee5c6b3196030 | |||
| size 63349 | |||
| @@ -1,3 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:9e3ecc2c30d382641d561f84849b199c12bb1a9418e8099a191153f6f5275a85 | |||
| size 61589 | |||
| oid sha256:01f9b9bf6f8bbf9bb377d4cb6f399b2e5e065381f5b7332343e0db7b4fae72a5 | |||
| size 62519 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:a49c9bc74a60860c360a4bf4509fe9db915279aaabd953f354f2c38e9be1e6cb | |||
| size 2924691 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:f58df1d25590c158ae0a04b3999bd44b610cdaddb17d78afd84c34b3f00d4e87 | |||
| size 4068783 | |||
| @@ -76,7 +76,7 @@ RUN pip install --no-cache-dir --upgrade pip && \ | |||
| ENV SHELL=/bin/bash | |||
| # install special package | |||
| RUN pip install --no-cache-dir mmcls>=0.21.0 mmdet>=2.25.0 decord>=0.6.0 datasets==2.1.0 numpy==1.18.5 ipykernel fairseq | |||
RUN pip install --no-cache-dir 'mmcls>=0.21.0' 'mmdet>=2.25.0' 'decord>=0.6.0' datasets==2.1.0 numpy==1.18.5 ipykernel fairseq fasttext https://modelscope.oss-cn-beijing.aliyuncs.com/releases/dependencies/xtcocotools-1.12-cp37-cp37m-linux_x86_64.whl
| RUN if [ "$USE_GPU" = "True" ] ; then \ | |||
| pip install --no-cache-dir dgl-cu113 dglgo -f https://data.dgl.ai/wheels/repo.html; \ | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from .version import __version__ | |||
| from .version import __release_datetime__, __version__ | |||
| __all__ = ['__version__'] | |||
| __all__ = ['__version__', '__release_datetime__'] | |||
| @@ -19,10 +19,13 @@ class Exporter(ABC): | |||
| def from_model(cls, model: Model, **kwargs): | |||
| """Build the Exporter instance. | |||
| @param model: A model instance. it will be used to output the generated file, | |||
| Args: | |||
model: A Model instance. It will be used to generate the intermediate format file,
| and the configuration.json in its model_dir field will be used to create the exporter instance. | |||
| @param kwargs: Extra kwargs used to create the Exporter instance. | |||
| @return: The Exporter instance | |||
| kwargs: Extra kwargs used to create the Exporter instance. | |||
| Returns: | |||
| The Exporter instance | |||
| """ | |||
| cfg = Config.from_file( | |||
| os.path.join(model.model_dir, ModelFile.CONFIGURATION)) | |||
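# Usage sketch (the model id is a placeholder and the import path for Model is
# an assumption; the exporter class is selected via the model's configuration.json):
#   from modelscope.models import Model
#   model = Model.from_pretrained('your-namespace/your-model')
#   exporter = Exporter.from_model(model)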
| @@ -44,10 +47,13 @@ class Exporter(ABC): | |||
| In some cases, several files may be generated, | |||
so please return a dict that maps each generated name to its file path.
| @param opset: The version of the ONNX operator set to use. | |||
| @param outputs: The output dir. | |||
| @param kwargs: In this default implementation, | |||
| kwargs will be carried to generate_dummy_inputs as extra arguments (like input shape). | |||
| @return: A dict contains the model name with the model file path. | |||
| Args: | |||
| opset: The version of the ONNX operator set to use. | |||
| outputs: The output dir. | |||
| kwargs: In this default implementation, | |||
| kwargs will be carried to generate_dummy_inputs as extra arguments (like input shape). | |||
| Returns: | |||
A dict mapping the model name to the model file path.
| """ | |||
| pass | |||
| @@ -23,13 +23,18 @@ class SbertForSequenceClassificationExporter(TorchModelExporter): | |||
| def generate_dummy_inputs(self, | |||
| shape: Tuple = None, | |||
| pair: bool = False, | |||
| **kwargs) -> Dict[str, Any]: | |||
| """Generate dummy inputs for model exportation to onnx or other formats by tracing. | |||
| @param shape: A tuple of input shape which should have at most two dimensions. | |||
| shape = (1, ) batch_size=1, sequence_length will be taken from the preprocessor. | |||
| shape = (8, 128) batch_size=1, sequence_length=128, which will cover the config of the preprocessor. | |||
| @return: Dummy inputs. | |||
| Args: | |||
| shape: A tuple of input shape which should have at most two dimensions. | |||
shape = (1, ): batch_size=1, sequence_length will be taken from the preprocessor.
shape = (8, 128): batch_size=8, sequence_length=128, which will override the config of the preprocessor.
| pair(bool, `optional`): Whether to generate sentence pairs or single sentences. | |||
| Returns: | |||
| Dummy inputs. | |||
| """ | |||
| cfg = Config.from_file( | |||
| @@ -55,7 +60,7 @@ class SbertForSequenceClassificationExporter(TorchModelExporter): | |||
| **sequence_length | |||
| }) | |||
| preprocessor: Preprocessor = build_preprocessor(cfg, field_name) | |||
| if preprocessor.pair: | |||
| if pair: | |||
| first_sequence = preprocessor.tokenizer.unk_token | |||
| second_sequence = preprocessor.tokenizer.unk_token | |||
| else: | |||
| @@ -13,8 +13,8 @@ from modelscope.models import TorchModel | |||
| from modelscope.pipelines.base import collate_fn | |||
| from modelscope.utils.constant import ModelFile | |||
| from modelscope.utils.logger import get_logger | |||
| from modelscope.utils.regress_test_utils import compare_arguments_nested | |||
| from modelscope.utils.tensor_utils import torch_nested_numpify | |||
| from modelscope.utils.regress_test_utils import (compare_arguments_nested, | |||
| numpify_tensor_nested) | |||
| from .base import Exporter | |||
| logger = get_logger(__name__) | |||
| @@ -28,49 +28,61 @@ class TorchModelExporter(Exporter): | |||
| and to provide implementations for generate_dummy_inputs/inputs/outputs methods. | |||
| """ | |||
| def export_onnx(self, outputs: str, opset=11, **kwargs): | |||
| def export_onnx(self, output_dir: str, opset=13, **kwargs): | |||
| """Export the model as onnx format files. | |||
| In some cases, several files may be generated, | |||
so please return a dict that maps each generated name to its file path.
| @param opset: The version of the ONNX operator set to use. | |||
| @param outputs: The output dir. | |||
| @param kwargs: In this default implementation, | |||
| you can pass the arguments needed by _torch_export_onnx, other unrecognized args | |||
| will be carried to generate_dummy_inputs as extra arguments (such as input shape). | |||
| @return: A dict containing the model key - model file path pairs. | |||
| Args: | |||
| opset: The version of the ONNX operator set to use. | |||
| output_dir: The output dir. | |||
| kwargs: | |||
model: A model instance that will be exported in place of self.model.
| In this default implementation, | |||
| you can pass the arguments needed by _torch_export_onnx, other unrecognized args | |||
| will be carried to generate_dummy_inputs as extra arguments (such as input shape). | |||
| Returns: | |||
| A dict containing the model key - model file path pairs. | |||
| """ | |||
| model = self.model | |||
| model = self.model if 'model' not in kwargs else kwargs.pop('model') | |||
| if not isinstance(model, nn.Module) and hasattr(model, 'model'): | |||
| model = model.model | |||
| onnx_file = os.path.join(outputs, ModelFile.ONNX_MODEL_FILE) | |||
| onnx_file = os.path.join(output_dir, ModelFile.ONNX_MODEL_FILE) | |||
| self._torch_export_onnx(model, onnx_file, opset=opset, **kwargs) | |||
| return {'model': onnx_file} | |||
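# Usage sketch, continuing the from_model example above (paths are
# placeholders; `model=` is the new kwarg that overrides self.model):
#   files = exporter.export_onnx(output_dir='/tmp/export', opset=13)
#   # -> {'model': '/tmp/export/<ModelFile.ONNX_MODEL_FILE>'}
#   files = exporter.export_onnx(output_dir='/tmp/export', model=other_module)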
| def export_torch_script(self, outputs: str, **kwargs): | |||
| def export_torch_script(self, output_dir: str, **kwargs): | |||
| """Export the model as torch script files. | |||
| In some cases, several files may be generated, | |||
so please return a dict that maps each generated name to its file path.
| @param outputs: The output dir. | |||
| @param kwargs: In this default implementation, | |||
| Args: | |||
| output_dir: The output dir. | |||
| kwargs: | |||
model: A model instance that will be exported in place of self.model.
| In this default implementation, | |||
| you can pass the arguments needed by _torch_export_torch_script, other unrecognized args | |||
| will be carried to generate_dummy_inputs as extra arguments (like input shape). | |||
| @return: A dict contains the model name with the model file path. | |||
| Returns: | |||
A dict mapping the model name to the model file path.
| """ | |||
| model = self.model | |||
| model = self.model if 'model' not in kwargs else kwargs.pop('model') | |||
| if not isinstance(model, nn.Module) and hasattr(model, 'model'): | |||
| model = model.model | |||
| ts_file = os.path.join(outputs, ModelFile.TS_MODEL_FILE) | |||
| ts_file = os.path.join(output_dir, ModelFile.TS_MODEL_FILE) | |||
| # generate ts by tracing | |||
| self._torch_export_torch_script(model, ts_file, **kwargs) | |||
| return {'model': ts_file} | |||
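# Round-trip sketch (the output path is a placeholder; the returned dict
# points at the traced TorchScript file):
#   out = exporter.export_torch_script(output_dir='/tmp/export')
#   ts_model = torch.jit.load(out['model'])
#   ts_model.eval()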
| def generate_dummy_inputs(self, **kwargs) -> Dict[str, Any]: | |||
| """Generate dummy inputs for model exportation to onnx or other formats by tracing. | |||
| @return: Dummy inputs. | |||
| Returns: | |||
| Dummy inputs. | |||
| """ | |||
| return None | |||
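# A hypothetical override for a BERT-like model; the keys and shapes below are
# illustrative assumptions, not the library's actual preprocessor output:
#   def generate_dummy_inputs(self, **kwargs) -> Dict[str, Any]:
#       return {
#           'input_ids': torch.ones(1, 128, dtype=torch.long),
#           'attention_mask': torch.ones(1, 128, dtype=torch.long),
#           'token_type_ids': torch.zeros(1, 128, dtype=torch.long),
#       }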
| @@ -93,7 +105,7 @@ class TorchModelExporter(Exporter): | |||
| def _torch_export_onnx(self, | |||
| model: nn.Module, | |||
| output: str, | |||
| opset: int = 11, | |||
| opset: int = 13, | |||
| device: str = 'cpu', | |||
| validation: bool = True, | |||
| rtol: float = None, | |||
| @@ -101,18 +113,27 @@ class TorchModelExporter(Exporter): | |||
| **kwargs): | |||
| """Export the model to an onnx format file. | |||
| @param model: A torch.nn.Module instance to export. | |||
| @param output: The output file. | |||
| @param opset: The version of the ONNX operator set to use. | |||
| @param device: The device used to forward. | |||
| @param validation: Whether validate the export file. | |||
| @param rtol: The rtol used to regress the outputs. | |||
| @param atol: The atol used to regress the outputs. | |||
| Args: | |||
| model: A torch.nn.Module instance to export. | |||
| output: The output file. | |||
| opset: The version of the ONNX operator set to use. | |||
| device: The device used to forward. | |||
validation: Whether to validate the exported file.
| rtol: The rtol used to regress the outputs. | |||
| atol: The atol used to regress the outputs. | |||
| kwargs: | |||
dummy_inputs: Dummy inputs that will replace the result of self.generate_dummy_inputs().
inputs: An input structure that will replace the value of self.inputs.
outputs: An output structure that will replace the value of self.outputs.
| """ | |||
| dummy_inputs = self.generate_dummy_inputs(**kwargs) | |||
| inputs = self.inputs | |||
| outputs = self.outputs | |||
dummy_inputs = kwargs.pop('dummy_inputs') if 'dummy_inputs' in kwargs \
else self.generate_dummy_inputs(**kwargs)
inputs = kwargs.pop('inputs') if 'inputs' in kwargs else self.inputs
outputs = kwargs.pop('outputs') if 'outputs' in kwargs else self.outputs
| if dummy_inputs is None or inputs is None or outputs is None: | |||
| raise NotImplementedError( | |||
| 'Model property dummy_inputs,inputs,outputs must be set.') | |||
| @@ -125,7 +146,7 @@ class TorchModelExporter(Exporter): | |||
| if isinstance(dummy_inputs, Mapping): | |||
| dummy_inputs = dict(dummy_inputs) | |||
| onnx_outputs = list(self.outputs.keys()) | |||
| onnx_outputs = list(outputs.keys()) | |||
| with replace_call(): | |||
| onnx_export( | |||
| @@ -160,11 +181,13 @@ class TorchModelExporter(Exporter): | |||
| outputs_origin = model.forward( | |||
| *_decide_input_format(model, dummy_inputs)) | |||
| if isinstance(outputs_origin, Mapping): | |||
| outputs_origin = torch_nested_numpify( | |||
| outputs_origin = numpify_tensor_nested( | |||
| list(outputs_origin.values())) | |||
| elif isinstance(outputs_origin, (tuple, list)): | |||
| outputs_origin = numpify_tensor_nested(outputs_origin) | |||
| outputs = ort_session.run( | |||
| onnx_outputs, | |||
| torch_nested_numpify(dummy_inputs), | |||
| numpify_tensor_nested(dummy_inputs), | |||
| ) | |||
| tols = {} | |||
| @@ -184,19 +207,26 @@ class TorchModelExporter(Exporter): | |||
| validation: bool = True, | |||
| rtol: float = None, | |||
| atol: float = None, | |||
| strict: bool = True, | |||
| **kwargs): | |||
| """Export the model to a torch script file. | |||
| @param model: A torch.nn.Module instance to export. | |||
| @param output: The output file. | |||
| @param device: The device used to forward. | |||
| @param validation: Whether validate the export file. | |||
| @param rtol: The rtol used to regress the outputs. | |||
| @param atol: The atol used to regress the outputs. | |||
| Args: | |||
| model: A torch.nn.Module instance to export. | |||
| output: The output file. | |||
| device: The device used to forward. | |||
| validation: Whether validate the export file. | |||
| rtol: The rtol used to regress the outputs. | |||
| atol: The atol used to regress the outputs. | |||
strict: Whether to use strict mode in torch script tracing.
| kwargs: | |||
dummy_inputs: Dummy inputs that will replace the result of self.generate_dummy_inputs().
| """ | |||
| model.eval() | |||
| dummy_inputs = self.generate_dummy_inputs(**kwargs) | |||
dummy_inputs = kwargs.pop('dummy_inputs') if 'dummy_inputs' in kwargs \
else self.generate_dummy_inputs(**kwargs)
| if dummy_inputs is None: | |||
| raise NotImplementedError( | |||
| 'Model property dummy_inputs must be set.') | |||
| @@ -207,7 +237,7 @@ class TorchModelExporter(Exporter): | |||
| model.eval() | |||
| with replace_call(): | |||
| traced_model = torch.jit.trace( | |||
| model, dummy_inputs, strict=False) | |||
| model, dummy_inputs, strict=strict) | |||
| torch.jit.save(traced_model, output) | |||
| if validation: | |||
| @@ -216,9 +246,9 @@ class TorchModelExporter(Exporter): | |||
| model.eval() | |||
| ts_model.eval() | |||
| outputs = ts_model.forward(*dummy_inputs) | |||
| outputs = torch_nested_numpify(outputs) | |||
| outputs = numpify_tensor_nested(outputs) | |||
| outputs_origin = model.forward(*dummy_inputs) | |||
| outputs_origin = torch_nested_numpify(outputs_origin) | |||
| outputs_origin = numpify_tensor_nested(outputs_origin) | |||
| tols = {} | |||
| if rtol is not None: | |||
| tols['rtol'] = rtol | |||
| @@ -240,7 +270,6 @@ def replace_call(): | |||
| problems. Here we recover the call method to the default implementation of torch.nn.Module, and change it | |||
| back after the tracing was done. | |||
| """ | |||
| TorchModel.call_origin, TorchModel.__call__ = TorchModel.__call__, TorchModel._call_impl | |||
| yield | |||
| TorchModel.__call__ = TorchModel.call_origin | |||
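# replace_call() above is an instance of the standard temporary monkey-patch
# pattern. A generic, standalone sketch of the same idea (not part of the
# library API); unlike the block above it restores the attribute even if
# tracing raises:
from contextlib import contextmanager

@contextmanager
def swap_attr(obj, name, new_value):
    # Temporarily replace obj.<name>, restoring the original on exit.
    old_value = getattr(obj, name)
    setattr(obj, name, new_value)
    try:
        yield
    finally:
        setattr(obj, name, old_value)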
| @@ -1,32 +1,47 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # yapf: disable | |||
| import datetime | |||
| import os | |||
| import pickle | |||
| import platform | |||
| import shutil | |||
| import tempfile | |||
| import uuid | |||
| from collections import defaultdict | |||
| from http import HTTPStatus | |||
| from http.cookiejar import CookieJar | |||
| from os.path import expanduser | |||
| from typing import List, Optional, Tuple, Union | |||
| from typing import Dict, List, Optional, Tuple, Union | |||
| import requests | |||
| from modelscope import __version__ | |||
| from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, | |||
| API_RESPONSE_FIELD_EMAIL, | |||
| API_RESPONSE_FIELD_GIT_ACCESS_TOKEN, | |||
| API_RESPONSE_FIELD_MESSAGE, | |||
| API_RESPONSE_FIELD_USERNAME, | |||
| DEFAULT_CREDENTIALS_PATH) | |||
| DEFAULT_CREDENTIALS_PATH, | |||
| MODELSCOPE_ENVIRONMENT, ONE_YEAR_SECONDS, | |||
| Licenses, ModelVisibility) | |||
| from modelscope.hub.errors import (InvalidParameter, NotExistError, | |||
| NotLoginException, NoValidRevisionError, | |||
| RequestError, datahub_raise_on_error, | |||
| handle_http_post_error, | |||
| handle_http_response, is_ok, | |||
| raise_for_http_status, raise_on_error) | |||
| from modelscope.hub.git import GitCommandWrapper | |||
| from modelscope.hub.repository import Repository | |||
| from modelscope.utils.config_ds import DOWNLOADED_DATASETS_PATH | |||
| from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, | |||
| DEFAULT_MODEL_REVISION, | |||
| DatasetFormations, DatasetMetaFormats, | |||
| DownloadMode) | |||
| DEFAULT_REPOSITORY_REVISION, | |||
| MASTER_MODEL_BRANCH, DatasetFormations, | |||
| DatasetMetaFormats, DownloadMode, | |||
| ModelFile) | |||
| from modelscope.utils.logger import get_logger | |||
| from .errors import (InvalidParameter, NotExistError, RequestError, | |||
| datahub_raise_on_error, handle_http_response, is_ok, | |||
| raise_on_error) | |||
| from .utils.utils import (get_dataset_hub_endpoint, get_endpoint, | |||
| from .utils.utils import (get_endpoint, get_release_datetime, | |||
| model_id_to_group_owner_name) | |||
| logger = get_logger() | |||
| @@ -34,10 +49,9 @@ logger = get_logger() | |||
| class HubApi: | |||
| def __init__(self, endpoint=None, dataset_endpoint=None): | |||
| def __init__(self, endpoint=None): | |||
| self.endpoint = endpoint if endpoint is not None else get_endpoint() | |||
| self.dataset_endpoint = dataset_endpoint if dataset_endpoint is not None else get_dataset_hub_endpoint( | |||
| ) | |||
| self.headers = {'user-agent': ModelScopeConfig.get_user_agent()} | |||
| def login( | |||
| self, | |||
| @@ -57,8 +71,9 @@ class HubApi: | |||
| </Tip> | |||
| """ | |||
| path = f'{self.endpoint}/api/v1/login' | |||
| r = requests.post(path, json={'AccessToken': access_token}) | |||
| r.raise_for_status() | |||
| r = requests.post( | |||
| path, json={'AccessToken': access_token}, headers=self.headers) | |||
| raise_for_http_status(r) | |||
| d = r.json() | |||
| raise_on_error(d) | |||
| @@ -105,17 +120,16 @@ class HubApi: | |||
| path = f'{self.endpoint}/api/v1/models' | |||
| owner_or_group, name = model_id_to_group_owner_name(model_id) | |||
| body = { | |||
| 'Path': owner_or_group, | |||
| 'Name': name, | |||
| 'ChineseName': chinese_name, | |||
| 'Visibility': visibility, # server check | |||
| 'License': license | |||
| } | |||
| r = requests.post( | |||
| path, | |||
| json={ | |||
| 'Path': owner_or_group, | |||
| 'Name': name, | |||
| 'ChineseName': chinese_name, | |||
| 'Visibility': visibility, # server check | |||
| 'License': license | |||
| }, | |||
| cookies=cookies) | |||
| r.raise_for_status() | |||
| path, json=body, cookies=cookies, headers=self.headers) | |||
| handle_http_post_error(r, path, body) | |||
| raise_on_error(r.json()) | |||
| model_repo_url = f'{get_endpoint()}/{model_id}' | |||
| return model_repo_url | |||
| @@ -134,8 +148,8 @@ class HubApi: | |||
| raise ValueError('Token does not exist, please login first.') | |||
| path = f'{self.endpoint}/api/v1/models/{model_id}' | |||
| r = requests.delete(path, cookies=cookies) | |||
| r.raise_for_status() | |||
| r = requests.delete(path, cookies=cookies, headers=self.headers) | |||
| raise_for_http_status(r) | |||
| raise_on_error(r.json()) | |||
| def get_model_url(self, model_id): | |||
| @@ -164,7 +178,7 @@ class HubApi: | |||
| owner_or_group, name = model_id_to_group_owner_name(model_id) | |||
| path = f'{self.endpoint}/api/v1/models/{owner_or_group}/{name}?Revision={revision}' | |||
| r = requests.get(path, cookies=cookies) | |||
| r = requests.get(path, cookies=cookies, headers=self.headers) | |||
| handle_http_response(r, logger, cookies, model_id) | |||
| if r.status_code == HTTPStatus.OK: | |||
| if is_ok(r.json()): | |||
| @@ -172,13 +186,108 @@ class HubApi: | |||
| else: | |||
| raise NotExistError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||
| else: | |||
| r.raise_for_status() | |||
| raise_for_http_status(r) | |||
| def push_model(self, | |||
| model_id: str, | |||
| model_dir: str, | |||
| visibility: int = ModelVisibility.PUBLIC, | |||
| license: str = Licenses.APACHE_V2, | |||
| chinese_name: Optional[str] = None, | |||
| commit_message: Optional[str] = 'upload model', | |||
| revision: Optional[str] = DEFAULT_REPOSITORY_REVISION): | |||
| """ | |||
Upload a model from a given directory to a given repository. A valid model directory
| must contain a configuration.json file. | |||
| def list_model(self, | |||
| owner_or_group: str, | |||
| page_number=1, | |||
| page_size=10) -> dict: | |||
| """List model in owner or group. | |||
This function uploads the files in the given directory to the given repository. If the
repository does not exist remotely, it will be created automatically with the given
visibility, license and chinese_name parameters. If the revision does not exist in the
remote repository either, a new branch will be created for it.
HubApi's login must be called with a valid token, which can be obtained from
ModelScope's website, before calling this function.
| Args: | |||
model_id (`str`):
The model id to be uploaded; the caller must have write permission for it.
model_dir (`str`):
The absolute path of the finetune result.
visibility (`int`, defaults to `ModelVisibility.PUBLIC`):
Visibility of the newly created model (1-private, 5-public). If the model does
not exist in ModelScope, this function will create a new model with this
visibility, and the parameter is required. You can ignore it if you are sure
the model already exists.
license (`str`, defaults to `Licenses.APACHE_V2`):
License of the newly created model (see Licenses). If the model does not exist
in ModelScope, this function will create a new model with this license, and
the parameter is required. You can ignore it if you are sure the model
already exists.
chinese_name (`str`, *optional*, defaults to `None`):
Chinese name of the newly created model.
commit_message (`str`, *optional*, defaults to `'upload model'`):
Commit message of the push request.
revision (`str`, *optional*, defaults to `DEFAULT_REPOSITORY_REVISION`):
Which branch to push to. If the branch does not exist, a new branch will be
created and pushed to.
| """ | |||
| if model_id is None: | |||
| raise InvalidParameter('model_id cannot be empty!') | |||
| if model_dir is None: | |||
| raise InvalidParameter('model_dir cannot be empty!') | |||
| if not os.path.exists(model_dir) or os.path.isfile(model_dir): | |||
| raise InvalidParameter('model_dir must be a valid directory.') | |||
| cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION) | |||
| if not os.path.exists(cfg_file): | |||
| raise ValueError(f'{model_dir} must contain a configuration.json.') | |||
| cookies = ModelScopeConfig.get_cookies() | |||
| if cookies is None: | |||
| raise NotLoginException('Must login before upload!') | |||
| files_to_save = os.listdir(model_dir) | |||
| try: | |||
| self.get_model(model_id=model_id) | |||
| except Exception: | |||
| if visibility is None or license is None: | |||
| raise InvalidParameter( | |||
| 'visibility and license cannot be empty if want to create new repo' | |||
| ) | |||
| logger.info('Create new model %s' % model_id) | |||
| self.create_model( | |||
| model_id=model_id, | |||
| visibility=visibility, | |||
| license=license, | |||
| chinese_name=chinese_name) | |||
| tmp_dir = tempfile.mkdtemp() | |||
| git_wrapper = GitCommandWrapper() | |||
| try: | |||
| repo = Repository(model_dir=tmp_dir, clone_from=model_id) | |||
| branches = git_wrapper.get_remote_branches(tmp_dir) | |||
| if revision not in branches: | |||
| logger.info('Create new branch %s' % revision) | |||
| git_wrapper.new_branch(tmp_dir, revision) | |||
| git_wrapper.checkout(tmp_dir, revision) | |||
| for f in files_to_save: | |||
| if f[0] != '.': | |||
| src = os.path.join(model_dir, f) | |||
| if os.path.isdir(src): | |||
| shutil.copytree(src, os.path.join(tmp_dir, f)) | |||
| else: | |||
| shutil.copy(src, tmp_dir) | |||
| if not commit_message: | |||
| date = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S') | |||
| commit_message = '[automsg] push model %s to hub at %s' % ( | |||
| model_id, date) | |||
| repo.push(commit_message=commit_message, local_branch=revision, remote_branch=revision) | |||
| except Exception: | |||
| raise | |||
| finally: | |||
| shutil.rmtree(tmp_dir, ignore_errors=True) | |||
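# Usage sketch (token, model id and paths are placeholders):
#   api = HubApi()
#   api.login('your-sdk-access-token')
#   api.push_model(
#       model_id='your-namespace/your-model',
#       model_dir='/path/to/finetune_output',  # must contain configuration.json
#       visibility=ModelVisibility.PUBLIC,
#       license=Licenses.APACHE_V2)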
| def list_models(self, | |||
| owner_or_group: str, | |||
| page_number=1, | |||
| page_size=10) -> dict: | |||
| """List models in owner or group. | |||
| Args: | |||
| owner_or_group(`str`): owner or group. | |||
| @@ -193,7 +302,8 @@ class HubApi: | |||
| path, | |||
| data='{"Path":"%s", "PageNumber":%s, "PageSize": %s}' % | |||
| (owner_or_group, page_number, page_size), | |||
| cookies=cookies) | |||
| cookies=cookies, | |||
| headers=self.headers) | |||
| handle_http_response(r, logger, cookies, 'list_model') | |||
| if r.status_code == HTTPStatus.OK: | |||
| if is_ok(r.json()): | |||
| @@ -202,7 +312,7 @@ class HubApi: | |||
| else: | |||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||
| else: | |||
| r.raise_for_status() | |||
| raise_for_http_status(r) | |||
| return None | |||
| def _check_cookie(self, | |||
| @@ -217,10 +327,70 @@ class HubApi: | |||
| raise ValueError('Token does not exist, please login first.') | |||
| return cookies | |||
| def list_model_revisions( | |||
| self, | |||
| model_id: str, | |||
| cutoff_timestamp: int = None, | |||
| use_cookies: Union[bool, CookieJar] = False) -> List[str]: | |||
| """Get model branch and tags. | |||
| Args: | |||
| model_id (str): The model id | |||
| cutoff_timestamp (int): Tags created before the cutoff will be included. | |||
The timestamp is represented by the seconds elapsed since the epoch time.
use_cookies (Union[bool, CookieJar], optional): If a CookieJar is passed, it will be
used directly; if True, the cookie will be loaded from local storage. Defaults to False.
| Returns: | |||
List[str]: The list of revisions (tags), ordered by create-time.
| """ | |||
| cookies = self._check_cookie(use_cookies) | |||
| if cutoff_timestamp is None: | |||
| cutoff_timestamp = get_release_datetime() | |||
| path = f'{self.endpoint}/api/v1/models/{model_id}/revisions?EndTime=%s' % cutoff_timestamp | |||
| r = requests.get(path, cookies=cookies, headers=self.headers) | |||
| handle_http_response(r, logger, cookies, model_id) | |||
| d = r.json() | |||
| raise_on_error(d) | |||
| info = d[API_RESPONSE_FIELD_DATA] | |||
| # tags returned from backend are guaranteed to be ordered by create-time | |||
| tags = [x['Revision'] for x in info['RevisionMap']['Tags'] | |||
| ] if info['RevisionMap']['Tags'] else [] | |||
| return tags | |||
| def get_valid_revision(self, model_id: str, revision=None, cookies: Optional[CookieJar] = None): | |||
| release_timestamp = get_release_datetime() | |||
| current_timestamp = int(round(datetime.datetime.now().timestamp())) | |||
| # for active development in library codes (non-release-branches), release_timestamp | |||
| # is set to be a far-away-time-in-the-future, to ensure that we shall | |||
| # get the master-HEAD version from model repo by default (when no revision is provided) | |||
| if release_timestamp > current_timestamp + ONE_YEAR_SECONDS: | |||
| branches, tags = self.get_model_branches_and_tags( | |||
| model_id, use_cookies=False if cookies is None else cookies) | |||
| if revision is None: | |||
| revision = MASTER_MODEL_BRANCH | |||
| logger.info('Model revision not specified, use default: %s in development mode' % revision) | |||
| if revision not in branches and revision not in tags: | |||
raise NotExistError('The model: %s has no branch or tag: %s.' % (model_id, revision))
| else: | |||
| revisions = self.list_model_revisions( | |||
| model_id, cutoff_timestamp=release_timestamp, use_cookies=False if cookies is None else cookies) | |||
| if revision is None: | |||
| if len(revisions) == 0: | |||
| raise NoValidRevisionError('The model: %s has no valid revision!' % model_id) | |||
| # tags (revisions) returned from backend are guaranteed to be ordered by create-time | |||
| # we shall obtain the latest revision created earlier than release version of this branch | |||
| revision = revisions[0] | |||
| logger.info('Model revision not specified, use the latest revision: %s' % revision) | |||
| else: | |||
| if revision not in revisions: | |||
| raise NotExistError( | |||
| 'The model: %s has no revision: %s !' % (model_id, revision)) | |||
| return revision | |||
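# Resolution sketch (model id is a placeholder): in a release build this
# returns the newest tag created before the library's release datetime; in
# development mode (far-future release timestamp) it falls back to the
# master branch when no revision is given.
#   revision = HubApi().get_valid_revision('your-namespace/your-model')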
| def get_model_branches_and_tags( | |||
| self, | |||
| model_id: str, | |||
| use_cookies: Union[bool, CookieJar] = False | |||
| use_cookies: Union[bool, CookieJar] = False, | |||
| ) -> Tuple[List[str], List[str]]: | |||
| """Get model branch and tags. | |||
| @@ -234,7 +404,7 @@ class HubApi: | |||
| cookies = self._check_cookie(use_cookies) | |||
| path = f'{self.endpoint}/api/v1/models/{model_id}/revisions' | |||
| r = requests.get(path, cookies=cookies) | |||
| r = requests.get(path, cookies=cookies, headers=self.headers) | |||
| handle_http_response(r, logger, cookies, model_id) | |||
| d = r.json() | |||
| raise_on_error(d) | |||
| @@ -275,7 +445,11 @@ class HubApi: | |||
| if root is not None: | |||
| path = path + f'&Root={root}' | |||
| r = requests.get(path, cookies=cookies, headers=headers) | |||
| r = requests.get( | |||
| path, cookies=cookies, headers={ | |||
| **headers, | |||
| **self.headers | |||
| }) | |||
| handle_http_response(r, logger, cookies, model_id) | |||
| d = r.json() | |||
| @@ -290,11 +464,10 @@ class HubApi: | |||
| return files | |||
| def list_datasets(self): | |||
| path = f'{self.dataset_endpoint}/api/v1/datasets' | |||
| headers = None | |||
| path = f'{self.endpoint}/api/v1/datasets' | |||
| params = {} | |||
| r = requests.get(path, params=params, headers=headers) | |||
| r.raise_for_status() | |||
| r = requests.get(path, params=params, headers=self.headers) | |||
| raise_for_http_status(r) | |||
| dataset_list = r.json()[API_RESPONSE_FIELD_DATA] | |||
| return [x['Name'] for x in dataset_list] | |||
| @@ -317,14 +490,14 @@ class HubApi: | |||
| cache_dir): | |||
| shutil.rmtree(cache_dir) | |||
| os.makedirs(cache_dir, exist_ok=True) | |||
| datahub_url = f'{self.dataset_endpoint}/api/v1/datasets/{namespace}/{dataset_name}' | |||
| datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}' | |||
| r = requests.get(datahub_url) | |||
| resp = r.json() | |||
| datahub_raise_on_error(datahub_url, resp) | |||
| dataset_id = resp['Data']['Id'] | |||
| dataset_type = resp['Data']['Type'] | |||
| datahub_url = f'{self.dataset_endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={revision}' | |||
| r = requests.get(datahub_url) | |||
| datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={revision}' | |||
| r = requests.get(datahub_url, headers=self.headers) | |||
| resp = r.json() | |||
| datahub_raise_on_error(datahub_url, resp) | |||
| file_list = resp['Data'] | |||
| @@ -341,10 +514,10 @@ class HubApi: | |||
| file_path = file_info['Path'] | |||
| extension = os.path.splitext(file_path)[-1] | |||
| if extension in dataset_meta_format: | |||
| datahub_url = f'{self.dataset_endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \ | |||
| datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \ | |||
| f'Revision={revision}&FilePath={file_path}' | |||
| r = requests.get(datahub_url) | |||
| r.raise_for_status() | |||
| raise_for_http_status(r) | |||
| local_path = os.path.join(cache_dir, file_path) | |||
| if os.path.exists(local_path): | |||
| logger.warning( | |||
| @@ -365,7 +538,7 @@ class HubApi: | |||
| namespace: str, | |||
| revision: Optional[str] = DEFAULT_DATASET_REVISION): | |||
| if file_name.endswith('.csv'): | |||
| file_name = f'{self.dataset_endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \ | |||
| file_name = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \ | |||
| f'Revision={revision}&FilePath={file_name}' | |||
| return file_name | |||
| @@ -374,7 +547,7 @@ class HubApi: | |||
| dataset_name: str, | |||
| namespace: str, | |||
| revision: Optional[str] = DEFAULT_DATASET_REVISION): | |||
| datahub_url = f'{self.dataset_endpoint}/api/v1/datasets/{namespace}/{dataset_name}/' \ | |||
| datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/' \ | |||
| f'ststoken?Revision={revision}' | |||
| return self.datahub_remote_call(datahub_url) | |||
| @@ -385,23 +558,39 @@ class HubApi: | |||
| namespace: str, | |||
| revision: Optional[str] = DEFAULT_DATASET_REVISION): | |||
| datahub_url = f'{self.dataset_endpoint}/api/v1/datasets/{namespace}/{dataset_name}/' \ | |||
| datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/' \ | |||
| f'ststoken?Revision={revision}' | |||
| cookies = requests.utils.dict_from_cookiejar(cookies) | |||
| r = requests.get(url=datahub_url, cookies=cookies) | |||
| r = requests.get( | |||
| url=datahub_url, cookies=cookies, headers=self.headers) | |||
| resp = r.json() | |||
| raise_on_error(resp) | |||
| return resp['Data'] | |||
| def list_oss_dataset_objects(self, dataset_name, namespace, max_limit, | |||
| is_recursive, is_filter_dir, revision): | |||
| url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss/tree/?' \ | |||
| f'MaxLimit={max_limit}&Revision={revision}&Recursive={is_recursive}&FilterDir={is_filter_dir}' | |||
| cookies = ModelScopeConfig.get_cookies() | |||
| if cookies: | |||
| cookies = requests.utils.dict_from_cookiejar(cookies) | |||
resp = requests.get(url=url, cookies=cookies, headers=self.headers)
| resp = resp.json() | |||
| raise_on_error(resp) | |||
| resp = resp['Data'] | |||
| return resp | |||
| def on_dataset_download(self, dataset_name: str, namespace: str) -> None: | |||
| url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/increase' | |||
| r = requests.post(url) | |||
| r.raise_for_status() | |||
| r = requests.post(url, headers=self.headers) | |||
| raise_for_http_status(r) | |||
| @staticmethod | |||
| def datahub_remote_call(url): | |||
| r = requests.get(url) | |||
| r = requests.get(url, headers={'user-agent': ModelScopeConfig.get_user_agent()}) | |||
| resp = r.json() | |||
| datahub_raise_on_error(url, resp) | |||
| return resp['Data'] | |||
| @@ -415,6 +604,7 @@ class ModelScopeConfig: | |||
| COOKIES_FILE_NAME = 'cookies' | |||
| GIT_TOKEN_FILE_NAME = 'git_token' | |||
| USER_INFO_FILE_NAME = 'user' | |||
| USER_SESSION_ID_FILE_NAME = 'session' | |||
| @staticmethod | |||
| def make_sure_credential_path_exist(): | |||
| @@ -443,6 +633,23 @@ class ModelScopeConfig: | |||
| return cookies | |||
| return None | |||
| @staticmethod | |||
| def get_user_session_id(): | |||
| session_path = os.path.join(ModelScopeConfig.path_credential, | |||
| ModelScopeConfig.USER_SESSION_ID_FILE_NAME) | |||
| session_id = '' | |||
| if os.path.exists(session_path): | |||
| with open(session_path, 'rb') as f: | |||
| session_id = str(f.readline().strip(), encoding='utf-8') | |||
| if session_id == '' or len(session_id) != 32: | |||
| session_id = str(uuid.uuid4().hex) | |||
| ModelScopeConfig.make_sure_credential_path_exist() | |||
| with open(session_path, 'w+') as wf: | |||
| wf.write(session_id) | |||
| return session_id | |||
| @staticmethod | |||
| def save_token(token: str): | |||
| ModelScopeConfig.make_sure_credential_path_exist() | |||
| @@ -491,3 +698,32 @@ class ModelScopeConfig: | |||
| except FileNotFoundError: | |||
| pass | |||
| return token | |||
| @staticmethod | |||
def get_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
| """Formats a user-agent string with basic info about a request. | |||
| Args: | |||
| user_agent (`str`, `dict`, *optional*): | |||
| The user agent info in the form of a dictionary or a single string. | |||
| Returns: | |||
| The formatted user-agent string. | |||
| """ | |||
| env = 'custom' | |||
| if MODELSCOPE_ENVIRONMENT in os.environ: | |||
| env = os.environ[MODELSCOPE_ENVIRONMENT] | |||
| ua = 'modelscope/%s; python/%s; session_id/%s; platform/%s; processor/%s; env/%s' % ( | |||
| __version__, | |||
| platform.python_version(), | |||
| ModelScopeConfig.get_user_session_id(), | |||
| platform.platform(), | |||
| platform.processor(), | |||
| env, | |||
| ) | |||
if isinstance(user_agent, dict):
ua += '; ' + '; '.join(f'{k}/{v}' for k, v in user_agent.items())
elif isinstance(user_agent, str):
ua += '; ' + user_agent
| return ua | |||
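# Shape of the assembled string (values are illustrative; 'ci/true' is a
# hypothetical extra user_agent argument):
#   ModelScopeConfig.get_user_agent('ci/true')
#   -> 'modelscope/<version>; python/3.7.13; session_id/<32-hex-chars>;
#       platform/<platform>; processor/x86_64; env/custom; ci/true'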
| @@ -16,6 +16,9 @@ API_RESPONSE_FIELD_GIT_ACCESS_TOKEN = 'AccessToken' | |||
| API_RESPONSE_FIELD_USERNAME = 'Username' | |||
| API_RESPONSE_FIELD_EMAIL = 'Email' | |||
| API_RESPONSE_FIELD_MESSAGE = 'Message' | |||
| MODELSCOPE_ENVIRONMENT = 'MODELSCOPE_ENVIRONMENT' | |||
| MODELSCOPE_SDK_DEBUG = 'MODELSCOPE_SDK_DEBUG' | |||
| ONE_YEAR_SECONDS = 24 * 365 * 60 * 60 | |||
| class Licenses(object): | |||
| @@ -0,0 +1,339 @@ | |||
import urllib.parse
| from abc import ABC | |||
| from http import HTTPStatus | |||
| from typing import Optional | |||
| import json | |||
| import requests | |||
| from attrs import asdict, define, field, validators | |||
| from modelscope.hub.api import ModelScopeConfig | |||
| from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, | |||
| API_RESPONSE_FIELD_MESSAGE) | |||
| from modelscope.hub.errors import (NotLoginException, NotSupportError, | |||
| RequestError, handle_http_response, is_ok, | |||
| raise_for_http_status) | |||
| from modelscope.hub.utils.utils import get_endpoint | |||
| from modelscope.utils.logger import get_logger | |||
| # yapf: enable | |||
| logger = get_logger() | |||
| class Accelerator(object): | |||
| CPU = 'cpu' | |||
| GPU = 'gpu' | |||
| class Vendor(object): | |||
| EAS = 'eas' | |||
| class EASRegion(object): | |||
| beijing = 'cn-beijing' | |||
| hangzhou = 'cn-hangzhou' | |||
| class EASCpuInstanceType(object): | |||
| """EAS Cpu Instance TYpe, ref(https://help.aliyun.com/document_detail/144261.html) | |||
| """ | |||
| tiny = 'ecs.c6.2xlarge' | |||
| small = 'ecs.c6.4xlarge' | |||
| medium = 'ecs.c6.6xlarge' | |||
| large = 'ecs.c6.8xlarge' | |||
| class EASGpuInstanceType(object): | |||
| """EAS Cpu Instance TYpe, ref(https://help.aliyun.com/document_detail/144261.html) | |||
| """ | |||
| tiny = 'ecs.gn5-c28g1.7xlarge' | |||
| small = 'ecs.gn5-c8g1.4xlarge' | |||
| medium = 'ecs.gn6i-c24g1.12xlarge' | |||
| large = 'ecs.gn6e-c12g1.3xlarge' | |||
| def min_smaller_than_max(instance, attribute, value): | |||
| if value > instance.max_replica: | |||
| raise ValueError( | |||
| "'min_replica' value: %s has to be smaller than 'max_replica' value: %s!" | |||
| % (value, instance.max_replica)) | |||
| @define | |||
| class ServiceScalingConfig(object): | |||
| """Resource scaling config | |||
| Currently we ignore max_replica | |||
| Args: | |||
| max_replica: maximum replica | |||
| min_replica: minimum replica | |||
| """ | |||
| max_replica: int = field(default=1, validator=validators.ge(1)) | |||
| min_replica: int = field( | |||
| default=1, validator=[validators.ge(1), min_smaller_than_max]) | |||
| @define | |||
| class ServiceResourceConfig(object): | |||
| """Eas Resource request. | |||
| Args: | |||
| accelerator: the accelerator(cpu|gpu) | |||
| instance_type: the instance type. | |||
| scaling: The instance scaling config. | |||
| """ | |||
| instance_type: str | |||
| scaling: ServiceScalingConfig | |||
| accelerator: str = field( | |||
| default=Accelerator.CPU, | |||
| validator=validators.in_([Accelerator.CPU, Accelerator.GPU])) | |||
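# Declaration sketch: a small CPU serving resource (values are illustrative):
#   resource = ServiceResourceConfig(
#       instance_type=EASCpuInstanceType.small,
#       scaling=ServiceScalingConfig(min_replica=1, max_replica=2),
#       accelerator=Accelerator.CPU)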
| @define | |||
| class ServiceProviderParameters(ABC): | |||
| pass | |||
| @define | |||
| class EASDeployParameters(ServiceProviderParameters): | |||
| """Parameters for EAS Deployment. | |||
| Args: | |||
| resource_group: the resource group to deploy, current default. | |||
| region: The eas instance region(eg: cn-hangzhou). | |||
| access_key_id: The eas account access key id. | |||
| access_key_secret: The eas account access key secret. | |||
| vendor: must be 'eas' | |||
| """ | |||
| region: str | |||
| access_key_id: str | |||
| access_key_secret: str | |||
| resource_group: Optional[str] = None | |||
| vendor: str = field( | |||
| default=Vendor.EAS, validator=validators.in_([Vendor.EAS])) | |||
| @define | |||
| class EASListParameters(ServiceProviderParameters): | |||
| """EAS instance list parameters. | |||
| Args: | |||
| resource_group: the resource group to deploy, current default. | |||
| region: The eas instance region(eg: cn-hangzhou). | |||
| access_key_id: The eas account access key id. | |||
| access_key_secret: The eas account access key secret. | |||
| vendor: must be 'eas' | |||
| """ | |||
| access_key_id: str | |||
| access_key_secret: str | |||
| region: str = None | |||
| resource_group: str = None | |||
| vendor: str = field( | |||
| default=Vendor.EAS, validator=validators.in_([Vendor.EAS])) | |||
| @define | |||
| class DeployServiceParameters(object): | |||
| """Deploy service parameters | |||
| Args: | |||
| instance_name: the name of the service. | |||
| model_id: the modelscope model_id | |||
| revision: the modelscope model revision | |||
| resource: the resource requirement. | |||
| provider: the cloud service provider. | |||
| """ | |||
| instance_name: str | |||
| model_id: str | |||
| revision: str | |||
| resource: ServiceResourceConfig | |||
| provider: ServiceProviderParameters | |||
| class AttrsToQueryString(ABC): | |||
| """Convert the attrs class to json string. | |||
| Args: | |||
| """ | |||
| def to_query_str(self): | |||
| self_dict = asdict( | |||
| self.provider, filter=lambda attr, value: value is not None) | |||
| json_str = json.dumps(self_dict) | |||
logger.debug(json_str)
| safe_str = urllib.parse.quote_plus(json_str) | |||
logger.debug(safe_str)
| query_param = 'provider=%s' % safe_str | |||
| return query_param | |||
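# Serialization sketch (credentials are placeholders): only non-None provider
# fields are kept, dumped to JSON, then url-encoded with quote_plus:
#   GetServiceParameters(provider=EASListParameters(
#       access_key_id='AK_ID',
#       access_key_secret='AK_SECRET')).to_query_str()
#   -> 'provider=%7B%22access_key_id%22%3A+%22AK_ID%22...%7D'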
| @define | |||
| class ListServiceParameters(AttrsToQueryString): | |||
| provider: ServiceProviderParameters | |||
| skip: int = 0 | |||
| limit: int = 100 | |||
| @define | |||
| class GetServiceParameters(AttrsToQueryString): | |||
| provider: ServiceProviderParameters | |||
| @define | |||
| class DeleteServiceParameters(AttrsToQueryString): | |||
| provider: ServiceProviderParameters | |||
| class ServiceDeployer(object): | |||
| def __init__(self, endpoint=None): | |||
| self.endpoint = endpoint if endpoint is not None else get_endpoint() | |||
| self.headers = {'user-agent': ModelScopeConfig.get_user_agent()} | |||
| self.cookies = ModelScopeConfig.get_cookies() | |||
| if self.cookies is None: | |||
| raise NotLoginException( | |||
| 'Token does not exist, please login with HubApi first.') | |||
| # deploy_model | |||
| def create(self, model_id: str, revision: str, instance_name: str, | |||
| resource: ServiceResourceConfig, | |||
| provider: ServiceProviderParameters): | |||
| """Deploy model to cloud, current we only support PAI EAS, this is an async API , | |||
| and the deployment could take a while to finish remotely. Please check deploy instance | |||
| status separately via checking the status. | |||
| Args: | |||
| model_id (str): The deployed model id | |||
| revision (str): The model revision | |||
| instance_name (str): The deployed model instance name. | |||
| resource (ServiceResourceConfig): The service resource information. | |||
| provider (ServiceProviderParameters): The service provider parameter | |||
| Raises: | |||
NotLoginException: To use this api, you need to login first.
NotSupportError: The target platform is not supported.
RequestError: The server returned an error.
| Returns: | |||
| ServiceInstanceInfo: The information of the deployed service instance. | |||
| """ | |||
| if provider.vendor != Vendor.EAS: | |||
| raise NotSupportError( | |||
'Unsupported vendor: %s, only EAS is currently supported.' %
| (provider.vendor)) | |||
| create_params = DeployServiceParameters( | |||
| instance_name=instance_name, | |||
| model_id=model_id, | |||
| revision=revision, | |||
| resource=resource, | |||
| provider=provider) | |||
| path = f'{self.endpoint}/api/v1/deployer/endpoint' | |||
| body = asdict(create_params) | |||
| r = requests.post( | |||
| path, json=body, cookies=self.cookies, headers=self.headers) | |||
| handle_http_response(r, logger, self.cookies, 'create_service') | |||
| if r.status_code >= HTTPStatus.OK and r.status_code < HTTPStatus.MULTIPLE_CHOICES: | |||
| if is_ok(r.json()): | |||
| data = r.json()[API_RESPONSE_FIELD_DATA] | |||
| return data | |||
| else: | |||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||
| else: | |||
| raise_for_http_status(r) | |||
| return None | |||
| def get(self, instance_name: str, provider: ServiceProviderParameters): | |||
| """Query the specified instance information. | |||
| Args: | |||
| instance_name (str): The deployed instance name. | |||
            provider (ServiceProviderParameters): The cloud provider information; for EAS,
                region (e.g. cn-hangzhou), access_key_id and access_key_secret are required.
| Raises: | |||
            NotLoginException: To use this API, you need to login first.
            RequestError: The server returned an error.
| Returns: | |||
| Dict: The information of the requested service instance. | |||
| """ | |||
| params = GetServiceParameters(provider=provider) | |||
| path = '%s/api/v1/deployer/endpoint/%s?%s' % ( | |||
| self.endpoint, instance_name, params.to_query_str()) | |||
| r = requests.get(path, cookies=self.cookies, headers=self.headers) | |||
| handle_http_response(r, logger, self.cookies, 'get_service') | |||
| if r.status_code == HTTPStatus.OK: | |||
| if is_ok(r.json()): | |||
| data = r.json()[API_RESPONSE_FIELD_DATA] | |||
| return data | |||
| else: | |||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||
| else: | |||
| raise_for_http_status(r) | |||
| return None | |||
| def delete(self, instance_name: str, provider: ServiceProviderParameters): | |||
| """Delete deployed model, this api send delete command and return, it will take | |||
| some to delete, please check through the cloud console. | |||
| Args: | |||
| instance_name (str): The instance name you want to delete. | |||
            provider (ServiceProviderParameters): The cloud provider information; for EAS,
                region (e.g. cn-hangzhou), access_key_id and access_key_secret are required.
| Raises: | |||
            NotLoginException: To call this API, you need to login first.
            RequestError: The server returned an error.
| Returns: | |||
| Dict: The deleted instance information. | |||
| """ | |||
| params = DeleteServiceParameters(provider=provider) | |||
| path = '%s/api/v1/deployer/endpoint/%s?%s' % ( | |||
| self.endpoint, instance_name, params.to_query_str()) | |||
| r = requests.delete(path, cookies=self.cookies, headers=self.headers) | |||
| handle_http_response(r, logger, self.cookies, 'delete_service') | |||
| if r.status_code == HTTPStatus.OK: | |||
| if is_ok(r.json()): | |||
| data = r.json()[API_RESPONSE_FIELD_DATA] | |||
| return data | |||
| else: | |||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||
| else: | |||
| raise_for_http_status(r) | |||
| return None | |||
| def list(self, | |||
| provider: ServiceProviderParameters, | |||
| skip: int = 0, | |||
| limit: int = 100): | |||
| """List deployed model instances. | |||
| Args: | |||
| provider (ServiceProviderParameters): The cloud service provider parameter, | |||
| for eas, need access_key_id and access_key_secret. | |||
            skip: start offset of the list (currently not supported by the server).
            limit: maximum number of instances to return (currently not supported by the server).
| Raises: | |||
            NotLoginException: To use this API, you need to login first.
            RequestError: The server returned an error.
| Returns: | |||
| List: List of instance information | |||
| """ | |||
| params = ListServiceParameters( | |||
| provider=provider, skip=skip, limit=limit) | |||
| path = '%s/api/v1/deployer/endpoint?%s' % (self.endpoint, | |||
| params.to_query_str()) | |||
| r = requests.get(path, cookies=self.cookies, headers=self.headers) | |||
| handle_http_response(r, logger, self.cookies, 'list_service_instances') | |||
| if r.status_code == HTTPStatus.OK: | |||
| if is_ok(r.json()): | |||
| data = r.json()[API_RESPONSE_FIELD_DATA] | |||
| return data | |||
| else: | |||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||
| else: | |||
| raise_for_http_status(r) | |||
| return None | |||
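# A minimal end-to-end usage sketch for ServiceDeployer. The provider class
# and its fields are illustrative assumptions (only the Vendor.EAS contract
# above is authoritative); a prior HubApi login is required:
#
#   deployer = ServiceDeployer()
#   provider = EASProviderParameters(region='cn-hangzhou',
#                                    access_key_id='AK...',
#                                    access_key_secret='SK...')
#   info = deployer.create(model_id='damo/some-model',
#                          revision='v1.0.0',
#                          instance_name='my-service',
#                          resource=resource,  # a ServiceResourceConfig
#                          provider=provider)
#   deployer.get('my-service', provider)     # poll deployment status
#   deployer.list(provider)                  # enumerate deployed instances
#   deployer.delete('my-service', provider)  # tear the service down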
| @@ -4,6 +4,18 @@ from http import HTTPStatus | |||
| from requests.exceptions import HTTPError | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| class NotSupportError(Exception): | |||
| pass | |||
| class NoValidRevisionError(Exception): | |||
| pass | |||
| class NotExistError(Exception): | |||
| pass | |||
| @@ -45,15 +57,25 @@ def is_ok(rsp): | |||
| return rsp['Code'] == HTTPStatus.OK and rsp['Success'] | |||
| def handle_http_post_error(response, url, request_body): | |||
| try: | |||
| response.raise_for_status() | |||
| except HTTPError as error: | |||
        logger.error('Request %s with body: %s failed with exception' %
                     (url, request_body))
| raise error | |||
| def handle_http_response(response, logger, cookies, model_id): | |||
| try: | |||
| response.raise_for_status() | |||
| except HTTPError: | |||
| except HTTPError as error: | |||
        if cookies is None:
            logger.error(
                f'Authentication token does not exist, failed to access model {model_id} '
                'which may not exist or may be private. Please login first.')
| raise | |||
| logger.error('Response details: %s' % response.content) | |||
| raise error | |||
| def raise_on_error(rsp): | |||
| @@ -81,3 +103,33 @@ def datahub_raise_on_error(url, rsp): | |||
| raise RequestError( | |||
| f"Url = {url}, Status = {rsp.get('status')}, error = {rsp.get('error')}, message = {rsp.get('message')}" | |||
| ) | |||
| def raise_for_http_status(rsp): | |||
| """ | |||
| Attempt to decode utf-8 first since some servers | |||
| localize reason strings, for invalid utf-8, fall back | |||
| to decoding with iso-8859-1. | |||
| """ | |||
| http_error_msg = '' | |||
| if isinstance(rsp.reason, bytes): | |||
| try: | |||
| reason = rsp.reason.decode('utf-8') | |||
| except UnicodeDecodeError: | |||
| reason = rsp.reason.decode('iso-8859-1') | |||
| else: | |||
| reason = rsp.reason | |||
| if 400 <= rsp.status_code < 500: | |||
| http_error_msg = u'%s Client Error: %s for url: %s' % (rsp.status_code, | |||
| reason, rsp.url) | |||
| elif 500 <= rsp.status_code < 600: | |||
| http_error_msg = u'%s Server Error: %s for url: %s' % (rsp.status_code, | |||
| reason, rsp.url) | |||
| if http_error_msg: | |||
| req = rsp.request | |||
| if req.method == 'POST': | |||
| http_error_msg = u'%s, body: %s' % (http_error_msg, req.body) | |||
| raise HTTPError(http_error_msg, response=rsp) | |||
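# For example, a failed GET surfaces as HTTPError('404 Client Error:
# Not Found for url: <url>'), and for POST requests the request body is
# appended to the message, which makes server-side rejections debuggable
# from the client log alone.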
| @@ -2,29 +2,25 @@ | |||
| import copy | |||
| import os | |||
| import sys | |||
| import tempfile | |||
| from functools import partial | |||
| from http.cookiejar import CookieJar | |||
| from pathlib import Path | |||
| from typing import Dict, Optional, Union | |||
| from uuid import uuid4 | |||
| import requests | |||
| from filelock import FileLock | |||
| from tqdm import tqdm | |||
| from modelscope import __version__ | |||
| from modelscope.hub.api import HubApi, ModelScopeConfig | |||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION | |||
| from modelscope.utils.logger import get_logger | |||
| from .api import HubApi, ModelScopeConfig | |||
| from .constants import FILE_HASH | |||
| from .errors import FileDownloadError, NotExistError | |||
| from .utils.caching import ModelFileSystemCache | |||
| from .utils.utils import (file_integrity_validation, get_cache_dir, | |||
| get_endpoint, model_id_to_group_owner_name) | |||
| SESSION_ID = uuid4().hex | |||
| logger = get_logger() | |||
| @@ -35,6 +31,7 @@ def model_file_download( | |||
| cache_dir: Optional[str] = None, | |||
| user_agent: Union[Dict, str, None] = None, | |||
| local_files_only: Optional[bool] = False, | |||
| cookies: Optional[CookieJar] = None, | |||
| ) -> Optional[str]: # pragma: no cover | |||
| """ | |||
| Download from a given URL and cache it if it's not already present in the | |||
| @@ -105,54 +102,47 @@ def model_file_download( | |||
| " online, set 'local_files_only' to False.") | |||
| _api = HubApi() | |||
| headers = {'user-agent': http_user_agent(user_agent=user_agent, )} | |||
| cookies = ModelScopeConfig.get_cookies() | |||
| branches, tags = _api.get_model_branches_and_tags( | |||
| model_id, use_cookies=False if cookies is None else cookies) | |||
| headers = { | |||
| 'user-agent': ModelScopeConfig.get_user_agent(user_agent=user_agent, ) | |||
| } | |||
| if cookies is None: | |||
| cookies = ModelScopeConfig.get_cookies() | |||
| revision = _api.get_valid_revision( | |||
| model_id, revision=revision, cookies=cookies) | |||
| file_to_download_info = None | |||
| is_commit_id = False | |||
| if revision in branches or revision in tags: # The revision is version or tag, | |||
| # we need to confirm the version is up to date | |||
| # we need to get the file list to check if the lateast version is cached, if so return, otherwise download | |||
| model_files = _api.get_model_files( | |||
| model_id=model_id, | |||
| revision=revision, | |||
| recursive=True, | |||
| use_cookies=False if cookies is None else cookies) | |||
| for model_file in model_files: | |||
| if model_file['Type'] == 'tree': | |||
| continue | |||
| if model_file['Path'] == file_path: | |||
| if cache.exists(model_file): | |||
| return cache.get_file_by_info(model_file) | |||
| else: | |||
| file_to_download_info = model_file | |||
| break | |||
| if file_to_download_info is None: | |||
            raise NotExistError('The file path: %s does not exist in: %s' %
                                (file_path, model_id))
| else: # the revision is commit id. | |||
| cached_file_path = cache.get_file_by_path_and_commit_id( | |||
| file_path, revision) | |||
| if cached_file_path is not None: | |||
| file_name = os.path.basename(cached_file_path) | |||
| logger.info( | |||
| f'File {file_name} already in cache, skip downloading!') | |||
| return cached_file_path # the file is in cache. | |||
| is_commit_id = True | |||
| # we need to confirm the version is up-to-date | |||
| # we need to get the file list to check if the latest version is cached, if so return, otherwise download | |||
| model_files = _api.get_model_files( | |||
| model_id=model_id, | |||
| revision=revision, | |||
| recursive=True, | |||
| use_cookies=False if cookies is None else cookies) | |||
| for model_file in model_files: | |||
| if model_file['Type'] == 'tree': | |||
| continue | |||
| if model_file['Path'] == file_path: | |||
| if cache.exists(model_file): | |||
| logger.info( | |||
| f'File {model_file["Name"]} already in cache, skip downloading!' | |||
| ) | |||
| return cache.get_file_by_info(model_file) | |||
| else: | |||
| file_to_download_info = model_file | |||
| break | |||
| if file_to_download_info is None: | |||
        raise NotExistError('The file path: %s does not exist in: %s' %
                            (file_path, model_id))
| # we need to download again | |||
| url_to_download = get_file_download_url(model_id, file_path, revision) | |||
| file_to_download_info = { | |||
| 'Path': | |||
| file_path, | |||
| 'Revision': | |||
| revision if is_commit_id else file_to_download_info['Revision'], | |||
| FILE_HASH: | |||
| None if (is_commit_id or FILE_HASH not in file_to_download_info) else | |||
| file_to_download_info[FILE_HASH] | |||
| 'Path': file_path, | |||
| 'Revision': file_to_download_info['Revision'], | |||
| FILE_HASH: file_to_download_info[FILE_HASH] | |||
| } | |||
| temp_file_name = next(tempfile._get_candidate_names()) | |||
| @@ -171,25 +161,6 @@ def model_file_download( | |||
| os.path.join(temporary_cache_dir, temp_file_name)) | |||
| def http_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str: | |||
| """Formats a user-agent string with basic info about a request. | |||
| Args: | |||
| user_agent (`str`, `dict`, *optional*): | |||
| The user agent info in the form of a dictionary or a single string. | |||
| Returns: | |||
| The formatted user-agent string. | |||
| """ | |||
| ua = f'modelscope/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}' | |||
| if isinstance(user_agent, dict): | |||
| ua = '; '.join(f'{k}/{v}' for k, v in user_agent.items()) | |||
| elif isinstance(user_agent, str): | |||
| ua = user_agent | |||
| return ua | |||
| def get_file_download_url(model_id: str, file_path: str, revision: str): | |||
| """ | |||
| Format file download url according to `model_id`, `revision` and `file_path`. | |||
| @@ -3,10 +3,9 @@ | |||
| import os | |||
| import subprocess | |||
| from typing import List | |||
| from xmlrpc.client import Boolean | |||
| from modelscope.utils.logger import get_logger | |||
| from .api import ModelScopeConfig | |||
| from ..utils.constant import MASTER_MODEL_BRANCH | |||
| from .errors import GitError | |||
| logger = get_logger() | |||
| @@ -131,6 +130,7 @@ class GitCommandWrapper(metaclass=Singleton): | |||
| return response | |||
| def add_user_info(self, repo_base_dir, repo_name): | |||
| from modelscope.hub.api import ModelScopeConfig | |||
| user_name, user_email = ModelScopeConfig.get_user_info() | |||
| if user_name and user_email: | |||
| # config user.name and user.email if exist | |||
| @@ -138,8 +138,8 @@ class GitCommandWrapper(metaclass=Singleton): | |||
| repo_base_dir, repo_name, user_name) | |||
| response = self._run_git_command(*config_user_name_args.split(' ')) | |||
| logger.debug(response.stdout.decode('utf8')) | |||
| config_user_email_args = '-C %s/%s config user.name %s' % ( | |||
| repo_base_dir, repo_name, user_name) | |||
| config_user_email_args = '-C %s/%s config user.email %s' % ( | |||
| repo_base_dir, repo_name, user_email) | |||
| response = self._run_git_command( | |||
| *config_user_email_args.split(' ')) | |||
| logger.debug(response.stdout.decode('utf8')) | |||
| @@ -177,6 +177,18 @@ class GitCommandWrapper(metaclass=Singleton): | |||
| cmds = ['-C', '%s' % repo_dir, 'checkout', '-b', revision] | |||
| return self._run_git_command(*cmds) | |||
| def get_remote_branches(self, repo_dir: str): | |||
| cmds = ['-C', '%s' % repo_dir, 'branch', '-r'] | |||
| rsp = self._run_git_command(*cmds) | |||
| info = [ | |||
| line.strip() | |||
| for line in rsp.stdout.decode('utf8').strip().split(os.linesep) | |||
| ] | |||
| if len(info) == 1: | |||
| return ['/'.join(info[0].split('/')[1:])] | |||
| else: | |||
| return ['/'.join(line.split('/')[1:]) for line in info[1:]] | |||
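    # Parsing sketch for get_remote_branches: `git branch -r` typically prints
    #   origin/HEAD -> origin/master
    #   origin/master
    #   origin/dev
    # Stripping each line and dropping the leading remote name yields
    # ['master', 'dev']; the multi-line case assumes the first line is the
    # HEAD pointer and skips it, while a single line is returned as-is.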
| def pull(self, repo_dir: str): | |||
| cmds = ['-C', repo_dir, 'pull'] | |||
| return self._run_git_command(*cmds) | |||
| @@ -216,3 +228,22 @@ class GitCommandWrapper(metaclass=Singleton): | |||
| files.append(line.split(' ')[-1]) | |||
| return files | |||
| def tag(self, | |||
| repo_dir: str, | |||
| tag_name: str, | |||
| message: str, | |||
| ref: str = MASTER_MODEL_BRANCH): | |||
| cmd_args = [ | |||
| '-C', repo_dir, 'tag', tag_name, '-m', | |||
| '"%s"' % message, ref | |||
| ] | |||
| rsp = self._run_git_command(*cmd_args) | |||
| logger.debug(rsp.stdout.decode('utf8')) | |||
| return rsp | |||
| def push_tag(self, repo_dir: str, tag_name): | |||
| cmd_args = ['-C', repo_dir, 'push', 'origin', tag_name] | |||
| rsp = self._run_git_command(*cmd_args) | |||
| logger.debug(rsp.stdout.decode('utf8')) | |||
| return rsp | |||
| @@ -5,9 +5,9 @@ from typing import Optional | |||
| from modelscope.hub.errors import GitError, InvalidParameter, NotLoginException | |||
| from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, | |||
| DEFAULT_MODEL_REVISION) | |||
| DEFAULT_REPOSITORY_REVISION, | |||
| MASTER_MODEL_BRANCH) | |||
| from modelscope.utils.logger import get_logger | |||
| from .api import ModelScopeConfig | |||
| from .git import GitCommandWrapper | |||
| from .utils.utils import get_endpoint | |||
| @@ -21,7 +21,7 @@ class Repository: | |||
| def __init__(self, | |||
| model_dir: str, | |||
| clone_from: str, | |||
| revision: Optional[str] = DEFAULT_MODEL_REVISION, | |||
| revision: Optional[str] = DEFAULT_REPOSITORY_REVISION, | |||
| auth_token: Optional[str] = None, | |||
| git_path: Optional[str] = None): | |||
| """ | |||
| @@ -47,6 +47,7 @@ class Repository: | |||
| err_msg = 'a non-default value of revision cannot be empty.' | |||
| raise InvalidParameter(err_msg) | |||
| from modelscope.hub.api import ModelScopeConfig | |||
| if auth_token: | |||
| self.auth_token = auth_token | |||
| else: | |||
| @@ -89,7 +90,8 @@ class Repository: | |||
| def push(self, | |||
| commit_message: str, | |||
| branch: Optional[str] = DEFAULT_MODEL_REVISION, | |||
| local_branch: Optional[str] = DEFAULT_REPOSITORY_REVISION, | |||
| remote_branch: Optional[str] = DEFAULT_REPOSITORY_REVISION, | |||
| force: bool = False): | |||
| """Push local files to remote, this method will do. | |||
| git pull | |||
| @@ -116,14 +118,48 @@ class Repository: | |||
| url = self.git_wrapper.get_repo_remote_url(self.model_dir) | |||
| self.git_wrapper.pull(self.model_dir) | |||
| self.git_wrapper.add(self.model_dir, all_files=True) | |||
| self.git_wrapper.commit(self.model_dir, commit_message) | |||
| self.git_wrapper.push( | |||
| repo_dir=self.model_dir, | |||
| token=self.auth_token, | |||
| url=url, | |||
| local_branch=branch, | |||
| remote_branch=branch) | |||
| local_branch=local_branch, | |||
| remote_branch=remote_branch) | |||
| def tag(self, tag_name: str, message: str, ref: str = MASTER_MODEL_BRANCH): | |||
| """Create a new tag. | |||
| Args: | |||
| tag_name (str): The name of the tag | |||
| message (str): The tag message. | |||
| ref (str): The tag reference, can be commit id or branch. | |||
| """ | |||
| if tag_name is None or tag_name == '': | |||
| msg = 'We use tag-based revision, therefore tag_name cannot be None or empty.' | |||
| raise InvalidParameter(msg) | |||
        if message is None or message == '':
            msg = 'We use annotated tags, therefore message cannot be None or empty.'
            raise InvalidParameter(msg)
| self.git_wrapper.tag( | |||
| repo_dir=self.model_dir, | |||
| tag_name=tag_name, | |||
| message=message, | |||
| ref=ref) | |||
| def tag_and_push(self, | |||
| tag_name: str, | |||
| message: str, | |||
| ref: str = MASTER_MODEL_BRANCH): | |||
| """Create tag and push to remote | |||
| Args: | |||
| tag_name (str): The name of the tag | |||
| message (str): The tag message. | |||
| ref (str, optional): The tag ref, can be commit id or branch. Defaults to MASTER_MODEL_BRANCH. | |||
| """ | |||
| self.tag(tag_name, message, ref) | |||
| self.git_wrapper.push_tag(repo_dir=self.model_dir, tag_name=tag_name) | |||
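# A minimal tagging sketch for Repository; the paths, model id and tag
# values below are illustrative, not authoritative:
#
#   repo = Repository(model_dir='/tmp/my_model',
#                     clone_from='damo/some-model')
#   repo.push('add new weights')
#   repo.tag_and_push(tag_name='v1.0.1', message='release v1.0.1')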
| class DatasetRepository: | |||
| @@ -166,7 +202,7 @@ class DatasetRepository: | |||
| err_msg = 'a non-default value of revision cannot be empty.' | |||
| raise InvalidParameter(err_msg) | |||
| self.revision = revision | |||
| from modelscope.hub.api import ModelScopeConfig | |||
| if auth_token: | |||
| self.auth_token = auth_token | |||
| else: | |||
| @@ -2,16 +2,15 @@ | |||
| import os | |||
| import tempfile | |||
| from http.cookiejar import CookieJar | |||
| from pathlib import Path | |||
| from typing import Dict, Optional, Union | |||
| from modelscope.hub.api import HubApi, ModelScopeConfig | |||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION | |||
| from modelscope.utils.logger import get_logger | |||
| from .api import HubApi, ModelScopeConfig | |||
| from .constants import FILE_HASH | |||
| from .errors import NotExistError | |||
| from .file_download import (get_file_download_url, http_get_file, | |||
| http_user_agent) | |||
| from .file_download import get_file_download_url, http_get_file | |||
| from .utils.caching import ModelFileSystemCache | |||
| from .utils.utils import (file_integrity_validation, get_cache_dir, | |||
| model_id_to_group_owner_name) | |||
| @@ -23,7 +22,8 @@ def snapshot_download(model_id: str, | |||
| revision: Optional[str] = DEFAULT_MODEL_REVISION, | |||
| cache_dir: Union[str, Path, None] = None, | |||
| user_agent: Optional[Union[Dict, str]] = None, | |||
| local_files_only: Optional[bool] = False) -> str: | |||
| local_files_only: Optional[bool] = False, | |||
| cookies: Optional[CookieJar] = None) -> str: | |||
| """Download all files of a repo. | |||
| Downloads a whole snapshot of a repo's files at the specified revision. This | |||
| is useful when you want all files from a repo, because you don't know which | |||
| @@ -81,15 +81,15 @@ def snapshot_download(model_id: str, | |||
| ) # we can not confirm the cached file is for snapshot 'revision' | |||
| else: | |||
| # make headers | |||
| headers = {'user-agent': http_user_agent(user_agent=user_agent, )} | |||
| headers = { | |||
| 'user-agent': | |||
| ModelScopeConfig.get_user_agent(user_agent=user_agent, ) | |||
| } | |||
| _api = HubApi() | |||
| cookies = ModelScopeConfig.get_cookies() | |||
| # get file list from model repo | |||
| branches, tags = _api.get_model_branches_and_tags( | |||
| model_id, use_cookies=False if cookies is None else cookies) | |||
| if revision not in branches and revision not in tags: | |||
| raise NotExistError('The specified branch or tag : %s not exist!' | |||
| % revision) | |||
| if cookies is None: | |||
| cookies = ModelScopeConfig.get_cookies() | |||
| revision = _api.get_valid_revision( | |||
| model_id, revision=revision, cookies=cookies) | |||
| snapshot_header = headers if 'CI_TEST' in os.environ else { | |||
| **headers, | |||
| @@ -110,7 +110,7 @@ def snapshot_download(model_id: str, | |||
| for model_file in model_files: | |||
| if model_file['Type'] == 'tree': | |||
| continue | |||
| # check model_file is exist in cache, if exist, skip download, otherwise download | |||
            # check whether model_file exists in the cache; if so, skip the download
| if cache.exists(model_file): | |||
| file_name = os.path.basename(model_file['Name']) | |||
| logger.info( | |||
| @@ -2,12 +2,12 @@ | |||
| import hashlib | |||
| import os | |||
| from datetime import datetime | |||
| from typing import Optional | |||
| from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DATA_ENDPOINT, | |||
| DEFAULT_MODELSCOPE_DOMAIN, | |||
| from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN, | |||
| DEFAULT_MODELSCOPE_GROUP, | |||
| MODEL_ID_SEPARATOR, | |||
| MODEL_ID_SEPARATOR, MODELSCOPE_SDK_DEBUG, | |||
| MODELSCOPE_URL_SCHEME) | |||
| from modelscope.hub.errors import FileIntegrityError | |||
| from modelscope.utils.file_utils import get_default_cache_dir | |||
| @@ -38,17 +38,24 @@ def get_cache_dir(model_id: Optional[str] = None): | |||
| base_path, model_id + '/') | |||
| def get_release_datetime(): | |||
| if MODELSCOPE_SDK_DEBUG in os.environ: | |||
| rt = int(round(datetime.now().timestamp())) | |||
| else: | |||
| from modelscope import version | |||
| rt = int( | |||
| round( | |||
| datetime.strptime(version.__release_datetime__, | |||
| '%Y-%m-%d %H:%M:%S').timestamp())) | |||
| return rt | |||
| def get_endpoint(): | |||
| modelscope_domain = os.getenv('MODELSCOPE_DOMAIN', | |||
| DEFAULT_MODELSCOPE_DOMAIN) | |||
| return MODELSCOPE_URL_SCHEME + modelscope_domain | |||
| def get_dataset_hub_endpoint(): | |||
| return os.environ.get('HUB_DATASET_ENDPOINT', | |||
| DEFAULT_MODELSCOPE_DATA_ENDPOINT) | |||
| def compute_hash(file_path): | |||
| BUFFER_SIZE = 1024 * 64 # 64k buffer size | |||
| sha256_hash = hashlib.sha256() | |||
| @@ -9,11 +9,14 @@ class Models(object): | |||
| Model name should only contain model info but not task info. | |||
| """ | |||
| # tinynas models | |||
| tinynas_detection = 'tinynas-detection' | |||
| tinynas_damoyolo = 'tinynas-damoyolo' | |||
| # vision models | |||
| detection = 'detection' | |||
| realtime_object_detection = 'realtime-object-detection' | |||
| realtime_video_object_detection = 'realtime-video-object-detection' | |||
| scrfd = 'scrfd' | |||
| classification_model = 'ClassificationModel' | |||
| nafnet = 'nafnet' | |||
| @@ -27,11 +30,13 @@ class Models(object): | |||
| face_2d_keypoints = 'face-2d-keypoints' | |||
| panoptic_segmentation = 'swinL-panoptic-segmentation' | |||
| image_reid_person = 'passvitb' | |||
| image_inpainting = 'FFTInpainting' | |||
| video_summarization = 'pgl-video-summarization' | |||
| swinL_semantic_segmentation = 'swinL-semantic-segmentation' | |||
| vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation' | |||
| text_driven_segmentation = 'text-driven-segmentation' | |||
| resnet50_bert = 'resnet50-bert' | |||
| referring_video_object_segmentation = 'swinT-referring-video-object-segmentation' | |||
| fer = 'fer' | |||
| retinaface = 'retinaface' | |||
| shop_segmentation = 'shop-segmentation' | |||
| @@ -39,14 +44,18 @@ class Models(object): | |||
| mtcnn = 'mtcnn' | |||
| ulfd = 'ulfd' | |||
| video_inpainting = 'video-inpainting' | |||
| human_wholebody_keypoint = 'human-wholebody-keypoint' | |||
| hand_static = 'hand-static' | |||
| face_human_hand_detection = 'face-human-hand-detection' | |||
| face_emotion = 'face-emotion' | |||
| product_segmentation = 'product-segmentation' | |||
| image_body_reshaping = 'image-body-reshaping' | |||
| # EasyCV models | |||
| yolox = 'YOLOX' | |||
| segformer = 'Segformer' | |||
| hand_2d_keypoints = 'HRNet-Hand2D-Keypoints' | |||
| image_object_detection_auto = 'image-object-detection-auto' | |||
| # nlp models | |||
| bert = 'bert' | |||
| @@ -58,18 +67,22 @@ class Models(object): | |||
| space_dst = 'space-dst' | |||
| space_intent = 'space-intent' | |||
| space_modeling = 'space-modeling' | |||
| star = 'star' | |||
| star3 = 'star3' | |||
| space_T_en = 'space-T-en' | |||
| space_T_cn = 'space-T-cn' | |||
| tcrf = 'transformer-crf' | |||
| tcrf_wseg = 'transformer-crf-for-word-segmentation' | |||
| transformer_softmax = 'transformer-softmax' | |||
| lcrf = 'lstm-crf' | |||
| lcrf_wseg = 'lstm-crf-for-word-segmentation' | |||
| gcnncrf = 'gcnn-crf' | |||
| bart = 'bart' | |||
| gpt3 = 'gpt3' | |||
| gpt_neo = 'gpt-neo' | |||
| plug = 'plug' | |||
| bert_for_ds = 'bert-for-document-segmentation' | |||
| ponet = 'ponet' | |||
| T5 = 'T5' | |||
| bloom = 'bloom' | |||
| # audio models | |||
| sambert_hifigan = 'sambert-hifigan' | |||
| @@ -88,6 +101,10 @@ class Models(object): | |||
| team = 'team-multi-modal-similarity' | |||
| video_clip = 'video-clip-multi-modal-embedding' | |||
| # science models | |||
| unifold = 'unifold' | |||
| unifold_symmetry = 'unifold-symmetry' | |||
| class TaskModels(object): | |||
| # nlp task | |||
| @@ -96,6 +113,7 @@ class TaskModels(object): | |||
| information_extraction = 'information-extraction' | |||
| fill_mask = 'fill-mask' | |||
| feature_extraction = 'feature-extraction' | |||
| text_generation = 'text-generation' | |||
| class Heads(object): | |||
| @@ -111,6 +129,8 @@ class Heads(object): | |||
| token_classification = 'token-classification' | |||
| # extraction | |||
| information_extraction = 'information-extraction' | |||
| # text gen | |||
| text_generation = 'text-generation' | |||
| class Pipelines(object): | |||
| @@ -144,6 +164,7 @@ class Pipelines(object): | |||
| salient_detection = 'u2net-salient-detection' | |||
| image_classification = 'image-classification' | |||
| face_detection = 'resnet-face-detection-scrfd10gkps' | |||
| card_detection = 'resnet-card-detection-scrfd34gkps' | |||
| ulfd_face_detection = 'manual-face-detection-ulfd' | |||
| facial_expression_recognition = 'vgg19-facial-expression-recognition-fer' | |||
| retina_face_detection = 'resnet50-face-detection-retinaface' | |||
| @@ -160,6 +181,7 @@ class Pipelines(object): | |||
| face_image_generation = 'gan-face-image-generation' | |||
| product_retrieval_embedding = 'resnet50-product-retrieval-embedding' | |||
| realtime_object_detection = 'cspnet_realtime-object-detection_yolox' | |||
| realtime_video_object_detection = 'cspnet_realtime-video-object-detection_streamyolo' | |||
| face_recognition = 'ir101-face-recognition-cfglint' | |||
| image_instance_segmentation = 'cascade-mask-rcnn-swin-image-instance-segmentation' | |||
| image2image_translation = 'image-to-image-translation' | |||
| @@ -168,6 +190,7 @@ class Pipelines(object): | |||
| ocr_recognition = 'convnextTiny-ocr-recognition' | |||
| image_portrait_enhancement = 'gpen-image-portrait-enhancement' | |||
| image_to_image_generation = 'image-to-image-generation' | |||
| image_object_detection_auto = 'yolox_image-object-detection-auto' | |||
| skin_retouching = 'unet-skin-retouching' | |||
| tinynas_classification = 'tinynas-classification' | |||
| tinynas_detection = 'tinynas-detection' | |||
| @@ -178,21 +201,32 @@ class Pipelines(object): | |||
| video_summarization = 'googlenet_pgl_video_summarization' | |||
| image_semantic_segmentation = 'image-semantic-segmentation' | |||
| image_reid_person = 'passvitb-image-reid-person' | |||
| image_inpainting = 'fft-inpainting' | |||
| text_driven_segmentation = 'text-driven-segmentation' | |||
| movie_scene_segmentation = 'resnet50-bert-movie-scene-segmentation' | |||
| shop_segmentation = 'shop-segmentation' | |||
| video_inpainting = 'video-inpainting' | |||
| human_wholebody_keypoint = 'hrnetw48_human-wholebody-keypoint_image' | |||
| pst_action_recognition = 'patchshift-action-recognition' | |||
| hand_static = 'hand-static' | |||
| face_human_hand_detection = 'face-human-hand-detection' | |||
| face_emotion = 'face-emotion' | |||
| product_segmentation = 'product-segmentation' | |||
| image_body_reshaping = 'flow-based-body-reshaping' | |||
| referring_video_object_segmentation = 'referring-video-object-segmentation' | |||
| # nlp tasks | |||
| automatic_post_editing = 'automatic-post-editing' | |||
| translation_quality_estimation = 'translation-quality-estimation' | |||
| domain_classification = 'domain-classification' | |||
| sentence_similarity = 'sentence-similarity' | |||
| word_segmentation = 'word-segmentation' | |||
| multilingual_word_segmentation = 'multilingual-word-segmentation' | |||
| word_segmentation_thai = 'word-segmentation-thai' | |||
| part_of_speech = 'part-of-speech' | |||
| named_entity_recognition = 'named-entity-recognition' | |||
| named_entity_recognition_thai = 'named-entity-recognition-thai' | |||
| named_entity_recognition_viet = 'named-entity-recognition-viet' | |||
| text_generation = 'text-generation' | |||
| text2text_generation = 'text2text-generation' | |||
| sentiment_analysis = 'sentiment-analysis' | |||
| @@ -208,14 +242,18 @@ class Pipelines(object): | |||
| zero_shot_classification = 'zero-shot-classification' | |||
| text_error_correction = 'text-error-correction' | |||
| plug_generation = 'plug-generation' | |||
| gpt3_generation = 'gpt3-generation' | |||
| faq_question_answering = 'faq-question-answering' | |||
| conversational_text_to_sql = 'conversational-text-to-sql' | |||
| table_question_answering_pipeline = 'table-question-answering-pipeline' | |||
| sentence_embedding = 'sentence-embedding' | |||
| passage_ranking = 'passage-ranking' | |||
| text_ranking = 'text-ranking' | |||
| relation_extraction = 'relation-extraction' | |||
| document_segmentation = 'document-segmentation' | |||
| feature_extraction = 'feature-extraction' | |||
| translation_en_to_de = 'translation_en_to_de' # keep it underscore | |||
| translation_en_to_ro = 'translation_en_to_ro' # keep it underscore | |||
| translation_en_to_fr = 'translation_en_to_fr' # keep it underscore | |||
| # audio tasks | |||
| sambert_hifigan_tts = 'sambert-hifigan-tts' | |||
| @@ -236,6 +274,10 @@ class Pipelines(object): | |||
| text_to_image_synthesis = 'text-to-image-synthesis' | |||
| video_multi_modal_embedding = 'video-multi-modal-embedding' | |||
| image_text_retrieval = 'image-text-retrieval' | |||
| ofa_ocr_recognition = 'ofa-ocr-recognition' | |||
| # science tasks | |||
| protein_structure = 'unifold-protein-structure' | |||
| class Trainers(object): | |||
| @@ -253,12 +295,16 @@ class Trainers(object): | |||
| # multi-modal trainers | |||
| clip_multi_modal_embedding = 'clip-multi-modal-embedding' | |||
| ofa = 'ofa' | |||
| # cv trainers | |||
| image_instance_segmentation = 'image-instance-segmentation' | |||
| image_portrait_enhancement = 'image-portrait-enhancement' | |||
| video_summarization = 'video-summarization' | |||
| movie_scene_segmentation = 'movie-scene-segmentation' | |||
| face_detection_scrfd = 'face-detection-scrfd' | |||
| card_detection_scrfd = 'card-detection-scrfd' | |||
| image_inpainting = 'image-inpainting' | |||
| # nlp trainers | |||
| bert_sentiment_analysis = 'bert-sentiment-analysis' | |||
| @@ -266,10 +312,11 @@ class Trainers(object): | |||
| dialog_intent_trainer = 'dialog-intent-trainer' | |||
| nlp_base_trainer = 'nlp-base-trainer' | |||
| nlp_veco_trainer = 'nlp-veco-trainer' | |||
| nlp_passage_ranking_trainer = 'nlp-passage-ranking-trainer' | |||
| nlp_text_ranking_trainer = 'nlp-text-ranking-trainer' | |||
| # audio trainers | |||
| speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' | |||
| speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield' | |||
| class Preprocessors(object): | |||
| @@ -298,8 +345,12 @@ class Preprocessors(object): | |||
| bert_seq_cls_tokenizer = 'bert-seq-cls-tokenizer' | |||
| text_gen_tokenizer = 'text-gen-tokenizer' | |||
| text2text_gen_preprocessor = 'text2text-gen-preprocessor' | |||
| text_gen_jieba_tokenizer = 'text-gen-jieba-tokenizer' | |||
| text2text_translate_preprocessor = 'text2text-translate-preprocessor' | |||
| token_cls_tokenizer = 'token-cls-tokenizer' | |||
| ner_tokenizer = 'ner-tokenizer' | |||
| thai_ner_tokenizer = 'thai-ner-tokenizer' | |||
| viet_ner_tokenizer = 'viet-ner-tokenizer' | |||
| nli_tokenizer = 'nli-tokenizer' | |||
| sen_cls_tokenizer = 'sen-cls-tokenizer' | |||
| dialog_intent_preprocessor = 'dialog-intent-preprocessor' | |||
| @@ -309,9 +360,10 @@ class Preprocessors(object): | |||
| zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer' | |||
| text_error_correction = 'text-error-correction' | |||
| sentence_embedding = 'sentence-embedding' | |||
| passage_ranking = 'passage-ranking' | |||
| text_ranking = 'text-ranking' | |||
| sequence_labeling_tokenizer = 'sequence-labeling-tokenizer' | |||
| word_segment_text_to_label_preprocessor = 'word-segment-text-to-label-preprocessor' | |||
| thai_wseg_tokenizer = 'thai-wseg-tokenizer' | |||
| fill_mask = 'fill-mask' | |||
| fill_mask_ponet = 'fill-mask-ponet' | |||
| faq_question_answering_preprocessor = 'faq-question-answering-preprocessor' | |||
| @@ -320,6 +372,7 @@ class Preprocessors(object): | |||
| re_tokenizer = 're-tokenizer' | |||
| document_segmentation = 'document-segmentation' | |||
| feature_extraction = 'feature-extraction' | |||
| sentence_piece = 'sentence-piece' | |||
| # audio preprocessor | |||
| linear_aec_fbank = 'linear-aec-fbank' | |||
| @@ -331,6 +384,9 @@ class Preprocessors(object): | |||
| ofa_tasks_preprocessor = 'ofa-tasks-preprocessor' | |||
| mplug_tasks_preprocessor = 'mplug-tasks-preprocessor' | |||
| # science preprocessor | |||
| unifold_preprocessor = 'unifold-preprocessor' | |||
| class Metrics(object): | |||
| """ Names for different metrics. | |||
| @@ -340,6 +396,9 @@ class Metrics(object): | |||
| accuracy = 'accuracy' | |||
| audio_noise_metric = 'audio-noise-metric' | |||
| # text gen | |||
| BLEU = 'bleu' | |||
| # metrics for image denoise task | |||
| image_denoise_metric = 'image-denoise-metric' | |||
| @@ -358,6 +417,10 @@ class Metrics(object): | |||
| video_summarization_metric = 'video-summarization-metric' | |||
| # metric for movie-scene-segmentation task | |||
| movie_scene_segmentation_metric = 'movie-scene-segmentation-metric' | |||
| # metric for inpainting task | |||
| image_inpainting_metric = 'image-inpainting-metric' | |||
| # metric for ocr | |||
| NED = 'ned' | |||
| class Optimizers(object): | |||
| @@ -399,6 +462,9 @@ class Hooks(object): | |||
| IterTimerHook = 'IterTimerHook' | |||
| EvaluationHook = 'EvaluationHook' | |||
| # Compression | |||
| SparsityHook = 'SparsityHook' | |||
| class LR_Schedulers(object): | |||
| """learning rate scheduler is defined here | |||
| @@ -413,7 +479,10 @@ class Datasets(object): | |||
| """ Names for different datasets. | |||
| """ | |||
| ClsDataset = 'ClsDataset' | |||
| Face2dKeypointsDataset = 'Face2dKeypointsDataset' | |||
| Face2dKeypointsDataset = 'FaceKeypointDataset' | |||
| HandCocoWholeBodyDataset = 'HandCocoWholeBodyDataset' | |||
| HumanWholeBodyKeypointDataset = 'WholeBodyCocoTopDownDataset' | |||
| SegDataset = 'SegDataset' | |||
| DetDataset = 'DetDataset' | |||
| DetImagesMixDataset = 'DetImagesMixDataset' | |||
| PairedDataset = 'PairedDataset' | |||
| @@ -17,6 +17,9 @@ if TYPE_CHECKING: | |||
| from .token_classification_metric import TokenClassificationMetric | |||
| from .video_summarization_metric import VideoSummarizationMetric | |||
| from .movie_scene_segmentation_metric import MovieSceneSegmentationMetric | |||
| from .accuracy_metric import AccuracyMetric | |||
| from .bleu_metric import BleuMetric | |||
| from .image_inpainting_metric import ImageInpaintingMetric | |||
| else: | |||
| _import_structure = { | |||
| @@ -34,6 +37,9 @@ else: | |||
| 'token_classification_metric': ['TokenClassificationMetric'], | |||
| 'video_summarization_metric': ['VideoSummarizationMetric'], | |||
| 'movie_scene_segmentation_metric': ['MovieSceneSegmentationMetric'], | |||
| 'image_inpainting_metric': ['ImageInpaintingMetric'], | |||
| 'accuracy_metric': ['AccuracyMetric'], | |||
| 'bleu_metric': ['BleuMetric'], | |||
| } | |||
| import sys | |||
| @@ -0,0 +1,46 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import Dict | |||
| import numpy as np | |||
| from modelscope.metainfo import Metrics | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.utils.registry import default_group | |||
| from .base import Metric | |||
| from .builder import METRICS, MetricKeys | |||
| @METRICS.register_module(group_key=default_group, module_name=Metrics.accuracy) | |||
| class AccuracyMetric(Metric): | |||
| """The metric computation class for classification classes. | |||
| This metric class calculates accuracy for the whole input batches. | |||
| """ | |||
| def __init__(self, *args, **kwargs): | |||
| super().__init__(*args, **kwargs) | |||
| self.preds = [] | |||
| self.labels = [] | |||
| def add(self, outputs: Dict, inputs: Dict): | |||
| label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS | |||
| ground_truths = inputs[label_name] | |||
| eval_results = outputs[label_name] | |||
| assert type(ground_truths) == type(eval_results) | |||
| if isinstance(ground_truths, list): | |||
| self.preds.extend(eval_results) | |||
| self.labels.extend(ground_truths) | |||
| elif isinstance(ground_truths, np.ndarray): | |||
| self.preds.extend(eval_results.tolist()) | |||
| self.labels.extend(ground_truths.tolist()) | |||
| else: | |||
            raise TypeError('only list or np.ndarray inputs are supported')
| def evaluate(self): | |||
| assert len(self.preds) == len(self.labels) | |||
| return { | |||
| MetricKeys.ACCURACY: (np.asarray([ | |||
| pred == ref for pred, ref in zip(self.preds, self.labels) | |||
| ])).mean().item() | |||
| } | |||
| @@ -35,6 +35,8 @@ class AudioNoiseMetric(Metric): | |||
| total_loss = avg_loss + avg_amp + avg_phase + avg_sisnr | |||
| return { | |||
| 'total_loss': total_loss.item(), | |||
| 'avg_sisnr': avg_sisnr.item(), | |||
            # the model uses the negated SiSNR as a calculation shortcut;
            # revert the sign in the evaluation result
| 'avg_sisnr': -avg_sisnr.item(), | |||
| MetricKeys.AVERAGE_LOSS: avg_loss.item() | |||
| } | |||
| @@ -10,8 +10,8 @@ class Metric(ABC): | |||
| complex metrics for a specific task with or without other Metric subclasses. | |||
| """ | |||
| def __init__(self, trainer=None, *args, **kwargs): | |||
| self.trainer = trainer | |||
| def __init__(self, *args, **kwargs): | |||
| pass | |||
| @abstractmethod | |||
| def add(self, outputs: Dict, inputs: Dict): | |||
| @@ -0,0 +1,42 @@ | |||
| from itertools import zip_longest | |||
| from typing import Dict | |||
| import sacrebleu | |||
| from modelscope.metainfo import Metrics | |||
| from modelscope.utils.registry import default_group | |||
| from .base import Metric | |||
| from .builder import METRICS, MetricKeys | |||
| EVAL_BLEU_ORDER = 4 | |||
| @METRICS.register_module(group_key=default_group, module_name=Metrics.BLEU) | |||
| class BleuMetric(Metric): | |||
| """The metric computation bleu for text generation classes. | |||
| This metric class calculates accuracy for the whole input batches. | |||
| """ | |||
| def __init__(self, *args, **kwargs): | |||
| super().__init__(*args, **kwargs) | |||
| self.eval_tokenized_bleu = kwargs.get('eval_tokenized_bleu', False) | |||
| self.hyp_name = kwargs.get('hyp_name', 'hyp') | |||
| self.ref_name = kwargs.get('ref_name', 'ref') | |||
| self.refs = list() | |||
| self.hyps = list() | |||
| def add(self, outputs: Dict, inputs: Dict): | |||
| self.refs.extend(inputs[self.ref_name]) | |||
| self.hyps.extend(outputs[self.hyp_name]) | |||
| def evaluate(self): | |||
| if self.eval_tokenized_bleu: | |||
| bleu = sacrebleu.corpus_bleu( | |||
| self.hyps, list(zip_longest(*self.refs)), tokenize='none') | |||
| else: | |||
| bleu = sacrebleu.corpus_bleu(self.hyps, | |||
| list(zip_longest(*self.refs))) | |||
| return { | |||
| MetricKeys.BLEU_4: bleu.score, | |||
| } | |||
| @@ -18,10 +18,12 @@ class MetricKeys(object): | |||
| SSIM = 'ssim' | |||
| AVERAGE_LOSS = 'avg_loss' | |||
| FScore = 'fscore' | |||
| FID = 'fid' | |||
| BLEU_1 = 'bleu-1' | |||
| BLEU_4 = 'bleu-4' | |||
| ROUGE_1 = 'rouge-1' | |||
| ROUGE_L = 'rouge-l' | |||
| NED = 'ned' # ocr metric | |||
| task_default_metrics = { | |||
| @@ -31,6 +33,7 @@ task_default_metrics = { | |||
| Tasks.sentiment_classification: [Metrics.seq_cls_metric], | |||
| Tasks.token_classification: [Metrics.token_cls_metric], | |||
| Tasks.text_generation: [Metrics.text_gen_metric], | |||
| Tasks.text_classification: [Metrics.seq_cls_metric], | |||
| Tasks.image_denoising: [Metrics.image_denoise_metric], | |||
| Tasks.image_color_enhancement: [Metrics.image_color_enhance_metric], | |||
| Tasks.image_portrait_enhancement: | |||
| @@ -39,6 +42,7 @@ task_default_metrics = { | |||
| Tasks.image_captioning: [Metrics.text_gen_metric], | |||
| Tasks.visual_question_answering: [Metrics.text_gen_metric], | |||
| Tasks.movie_scene_segmentation: [Metrics.movie_scene_segmentation_metric], | |||
| Tasks.image_inpainting: [Metrics.image_inpainting_metric], | |||
| } | |||
| @@ -0,0 +1 @@ | |||
| __author__ = 'tylin' | |||
| @@ -0,0 +1,57 @@ | |||
| # Filename: ciderD.py | |||
| # | |||
| # Description: Describes the class to compute the CIDEr-D (Consensus-Based Image Description Evaluation) Metric | |||
| # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) | |||
| # | |||
| # Creation Date: Sun Feb 8 14:16:54 2015 | |||
| # | |||
| # Authors: Ramakrishna Vedantam <vrama91@vt.edu> and Tsung-Yi Lin <tl483@cornell.edu> | |||
| from __future__ import absolute_import, division, print_function | |||
| from .ciderD_scorer import CiderScorer | |||
| class CiderD: | |||
| """ | |||
| Main Class to compute the CIDEr metric | |||
| """ | |||
| def __init__(self, n=4, sigma=6.0, df='corpus'): | |||
| # set cider to sum over 1 to 4-grams | |||
| self._n = n | |||
| # set the standard deviation parameter for gaussian penalty | |||
| self._sigma = sigma | |||
        # set where to compute document frequencies from
| self._df = df | |||
| self.cider_scorer = CiderScorer(n=self._n, df_mode=self._df) | |||
| def compute_score(self, gts, res): | |||
| """ | |||
| Main function to compute CIDEr score | |||
| :param hypo_for_image (dict) : dictionary with key <image> and value <tokenized hypothesis / candidate sentence> | |||
| ref_for_image (dict) : dictionary with key <image> and value <tokenized reference sentence> | |||
| :return: cider (float) : computed CIDEr score for the corpus | |||
| """ # noqa | |||
| # clear all the previous hypos and refs | |||
| tmp_cider_scorer = self.cider_scorer.copy_empty() | |||
| tmp_cider_scorer.clear() | |||
| for res_id in res: | |||
| hypo = res_id['caption'] | |||
| ref = gts[res_id['image_id']] | |||
| # Sanity check. | |||
| assert (type(hypo) is list) | |||
| assert (len(hypo) == 1) | |||
| assert (type(ref) is list) | |||
| assert (len(ref) > 0) | |||
| tmp_cider_scorer += (hypo[0], ref) | |||
| (score, scores) = tmp_cider_scorer.compute_score() | |||
| return score, scores | |||
| def method(self): | |||
| return 'CIDEr-D' | |||
| @@ -0,0 +1,233 @@ | |||
| #!/usr/bin/env python | |||
| # Tsung-Yi Lin <tl483@cornell.edu> | |||
| # Ramakrishna Vedantam <vrama91@vt.edu> | |||
| from __future__ import absolute_import, division, print_function | |||
| import copy | |||
| import math | |||
| import os | |||
| from collections import defaultdict | |||
| import numpy as np | |||
| import six | |||
| from six.moves import cPickle | |||
| def precook(s, n=4, out=False): | |||
| """ | |||
| Takes a string as input and returns an object that can be given to | |||
| either cook_refs or cook_test. This is optional: cook_refs and cook_test | |||
| can take string arguments as well. | |||
| :param s: string : sentence to be converted into ngrams | |||
| :param n: int : number of ngrams for which representation is calculated | |||
    :return: term frequency vector for occurring ngrams
| """ | |||
| words = s.split() | |||
| counts = defaultdict(int) | |||
| for k in range(1, n + 1): | |||
| for i in range(len(words) - k + 1): | |||
| ngram = tuple(words[i:i + k]) | |||
| counts[ngram] += 1 | |||
| return counts | |||
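# For example, precook('the cat the', n=2) returns the ngram term-frequency
# dict {('the',): 2, ('cat',): 1, ('the', 'cat'): 1, ('cat', 'the'): 1}.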
| def cook_refs(refs, n=4): # lhuang: oracle will call with "average" | |||
| '''Takes a list of reference sentences for a single segment | |||
| and returns an object that encapsulates everything that BLEU | |||
| needs to know about them. | |||
| :param refs: list of string : reference sentences for some image | |||
| :param n: int : number of ngrams for which (ngram) representation is calculated | |||
| :return: result (list of dict) | |||
| ''' | |||
| return [precook(ref, n) for ref in refs] | |||
| def cook_test(test, n=4): | |||
| '''Takes a test sentence and returns an object that | |||
| encapsulates everything that BLEU needs to know about it. | |||
| :param test: list of string : hypothesis sentence for some image | |||
| :param n: int : number of ngrams for which (ngram) representation is calculated | |||
| :return: result (dict) | |||
| ''' | |||
| return precook(test, n, True) | |||
| class CiderScorer(object): | |||
| """CIDEr scorer. | |||
| """ | |||
| def copy(self): | |||
| ''' copy the refs.''' | |||
| new = CiderScorer(n=self.n) | |||
| new.ctest = copy.copy(self.ctest) | |||
| new.crefs = copy.copy(self.crefs) | |||
| return new | |||
| def copy_empty(self): | |||
| new = CiderScorer(df_mode='corpus', n=self.n, sigma=self.sigma) | |||
| new.df_mode = self.df_mode | |||
| new.ref_len = self.ref_len | |||
| new.document_frequency = self.document_frequency | |||
| return new | |||
| def __init__(self, df_mode='corpus', test=None, refs=None, n=4, sigma=6.0): | |||
| ''' singular instance ''' | |||
| self.n = n | |||
| self.sigma = sigma | |||
| self.crefs = [] | |||
| self.ctest = [] | |||
| self.df_mode = df_mode | |||
| self.ref_len = None | |||
| if self.df_mode != 'corpus': | |||
| pkl_file = cPickle.load( | |||
| open(df_mode, 'rb'), | |||
| **(dict(encoding='latin1') if six.PY3 else {})) | |||
| self.ref_len = np.log(float(pkl_file['ref_len'])) | |||
| self.document_frequency = pkl_file['document_frequency'] | |||
| else: | |||
| self.document_frequency = None | |||
| self.cook_append(test, refs) | |||
| def clear(self): | |||
| self.crefs = [] | |||
| self.ctest = [] | |||
| def cook_append(self, test, refs): | |||
| '''called by constructor and __iadd__ to avoid creating new instances.''' | |||
| if refs is not None: | |||
| self.crefs.append(cook_refs(refs)) | |||
| if test is not None: | |||
| self.ctest.append(cook_test(test)) # N.B.: -1 | |||
| else: | |||
| self.ctest.append( | |||
| None) # lens of crefs and ctest have to match | |||
| def size(self): | |||
| assert len(self.crefs) == len( | |||
| self.ctest), 'refs/test mismatch! %d<>%d' % (len( | |||
| self.crefs), len(self.ctest)) | |||
| return len(self.crefs) | |||
| def __iadd__(self, other): | |||
| '''add an instance (e.g., from another sentence).''' | |||
| if type(other) is tuple: | |||
| # avoid creating new CiderScorer instances | |||
| self.cook_append(other[0], other[1]) | |||
| else: | |||
| self.ctest.extend(other.ctest) | |||
| self.crefs.extend(other.crefs) | |||
| return self | |||
| def compute_doc_freq(self): | |||
| """ | |||
| Compute term frequency for reference data. | |||
| This will be used to compute idf (inverse document frequency later) | |||
| The term frequency is stored in the object | |||
| :return: None | |||
| """ | |||
| for refs in self.crefs: | |||
| # refs, k ref captions of one image | |||
| for ngram in set([ | |||
| ngram for ref in refs for (ngram, count) in ref.items() | |||
| ]): # noqa | |||
| self.document_frequency[ngram] += 1 | |||
| def compute_cider(self): | |||
| def counts2vec(cnts): | |||
| """ | |||
| Function maps counts of ngram to vector of tfidf weights. | |||
| The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights. | |||
| The n-th entry of array denotes length of n-grams. | |||
| :param cnts: | |||
| :return: vec (array of dict), norm (array of float), length (int) | |||
| """ | |||
| vec = [defaultdict(float) for _ in range(self.n)] | |||
| length = 0 | |||
| norm = [0.0 for _ in range(self.n)] | |||
| for (ngram, term_freq) in cnts.items(): | |||
| # give word count 1 if it doesn't appear in reference corpus | |||
| df = np.log(max(1.0, self.document_frequency[ngram])) | |||
| # ngram index | |||
| n = len(ngram) - 1 | |||
| # tf (term_freq) * idf (precomputed idf) for n-grams | |||
| vec[n][ngram] = float(term_freq) * (self.ref_len - df) | |||
| # compute norm for the vector. the norm will be used for computing similarity | |||
| norm[n] += pow(vec[n][ngram], 2) | |||
| if n == 1: | |||
| length += term_freq | |||
| norm = [np.sqrt(n) for n in norm] | |||
| return vec, norm, length | |||
| def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): | |||
| ''' | |||
| Compute the cosine similarity of two vectors. | |||
| :param vec_hyp: array of dictionary for vector corresponding to hypothesis | |||
| :param vec_ref: array of dictionary for vector corresponding to reference | |||
| :param norm_hyp: array of float for vector corresponding to hypothesis | |||
| :param norm_ref: array of float for vector corresponding to reference | |||
| :param length_hyp: int containing length of hypothesis | |||
| :param length_ref: int containing length of reference | |||
| :return: array of score for each n-grams cosine similarity | |||
| ''' | |||
| delta = float(length_hyp - length_ref) | |||
            # measure cosine similarity
| val = np.array([0.0 for _ in range(self.n)]) | |||
| for n in range(self.n): | |||
| # ngram | |||
| for (ngram, count) in vec_hyp[n].items(): | |||
| # vrama91 : added clipping | |||
| val[n] += min(vec_hyp[n][ngram], | |||
| vec_ref[n][ngram]) * vec_ref[n][ngram] | |||
| if (norm_hyp[n] != 0) and (norm_ref[n] != 0): | |||
| val[n] /= (norm_hyp[n] * norm_ref[n]) | |||
| assert (not math.isnan(val[n])) | |||
| # vrama91: added a length based gaussian penalty | |||
| val[n] *= np.e**(-(delta**2) / (2 * self.sigma**2)) | |||
| return val | |||
| # compute log reference length | |||
| if self.df_mode == 'corpus': | |||
| self.ref_len = np.log(float(len(self.crefs))) | |||
| # elif self.df_mode == "coco-val-df": | |||
| # if coco option selected, use length of coco-val set | |||
| # self.ref_len = np.log(float(40504)) | |||
| scores = [] | |||
| for test, refs in zip(self.ctest, self.crefs): | |||
| # compute vector for test captions | |||
| vec, norm, length = counts2vec(test) | |||
| # compute vector for ref captions | |||
| score = np.array([0.0 for _ in range(self.n)]) | |||
| for ref in refs: | |||
| vec_ref, norm_ref, length_ref = counts2vec(ref) | |||
| score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) | |||
| # change by vrama91 - mean of ngram scores, instead of sum | |||
| score_avg = np.mean(score) | |||
| # divide by number of references | |||
| score_avg /= len(refs) | |||
| # multiply score by 10 | |||
| score_avg *= 10.0 | |||
| # append score of an image to the score list | |||
| scores.append(score_avg) | |||
| return scores | |||
| def compute_score(self, option=None, verbose=0): | |||
| # compute idf | |||
| if self.df_mode == 'corpus': | |||
| self.document_frequency = defaultdict(float) | |||
| self.compute_doc_freq() | |||
| # assert to check document frequency | |||
| assert (len(self.ctest) >= max(self.document_frequency.values())) | |||
| # import json for now and write the corresponding files | |||
| # compute cider score | |||
| score = self.compute_cider() | |||
| # debug | |||
| # print score | |||
| return np.mean(np.array(score)), np.array(score) | |||
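# A minimal sketch of the input format compute_score expects; the ids and
# captions are illustrative. Note that with a single-image corpus the
# document frequencies make every tf-idf weight zero, so real use feeds
# the whole evaluation set at once:
#
#   scorer = CiderD(df='corpus')
#   gts = {0: ['a cat sits on a mat', 'a cat on a mat']}
#   res = [{'image_id': 0, 'caption': ['a cat sits on a mat']}]
#   corpus_score, per_image_scores = scorer.compute_score(gts, res)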
| @@ -1,12 +1,16 @@ | |||
| # ------------------------------------------------------------------------ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # ------------------------------------------------------------------------ | |||
| # modified from https://github.com/megvii-research/NAFNet/blob/main/basicsr/metrics/psnr_ssim.py | |||
| # ------------------------------------------------------------------------ | |||
| from typing import Dict | |||
| import cv2 | |||
| import numpy as np | |||
| from skimage.metrics import peak_signal_noise_ratio, structural_similarity | |||
| import torch | |||
| from modelscope.metainfo import Metrics | |||
| from modelscope.utils.registry import default_group | |||
| from modelscope.utils.tensor_utils import (torch_nested_detach, | |||
| torch_nested_numpify) | |||
| from .base import Metric | |||
| from .builder import METRICS, MetricKeys | |||
| @@ -20,26 +24,249 @@ class ImageDenoiseMetric(Metric): | |||
| label_name = 'target' | |||
| def __init__(self): | |||
| super(ImageDenoiseMetric, self).__init__() | |||
| self.preds = [] | |||
| self.labels = [] | |||
| def add(self, outputs: Dict, inputs: Dict): | |||
| ground_truths = outputs[ImageDenoiseMetric.label_name] | |||
| eval_results = outputs[ImageDenoiseMetric.pred_name] | |||
| self.preds.append( | |||
| torch_nested_numpify(torch_nested_detach(eval_results))) | |||
| self.labels.append( | |||
| torch_nested_numpify(torch_nested_detach(ground_truths))) | |||
| self.preds.append(eval_results) | |||
| self.labels.append(ground_truths) | |||
| def evaluate(self): | |||
| psnr_list, ssim_list = [], [] | |||
| for (pred, label) in zip(self.preds, self.labels): | |||
| psnr_list.append( | |||
| peak_signal_noise_ratio(label[0], pred[0], data_range=255)) | |||
| ssim_list.append( | |||
| structural_similarity( | |||
| label[0], pred[0], multichannel=True, data_range=255)) | |||
| psnr_list.append(calculate_psnr(label[0], pred[0], crop_border=0)) | |||
| ssim_list.append(calculate_ssim(label[0], pred[0], crop_border=0)) | |||
| return { | |||
| MetricKeys.PSNR: np.mean(psnr_list), | |||
| MetricKeys.SSIM: np.mean(ssim_list) | |||
| } | |||
| def reorder_image(img, input_order='HWC'): | |||
| """Reorder images to 'HWC' order. | |||
| If the input_order is (h, w), return (h, w, 1); | |||
| If the input_order is (c, h, w), return (h, w, c); | |||
| If the input_order is (h, w, c), return as it is. | |||
| Args: | |||
| img (ndarray): Input image. | |||
| input_order (str): Whether the input order is 'HWC' or 'CHW'. | |||
| If the input image shape is (h, w), input_order will not have | |||
| effects. Default: 'HWC'. | |||
| Returns: | |||
| ndarray: reordered image. | |||
| """ | |||
| if input_order not in ['HWC', 'CHW']: | |||
| raise ValueError( | |||
| f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'" | |||
| ) | |||
| if len(img.shape) == 2: | |||
| img = img[..., None] | |||
| if input_order == 'CHW': | |||
| img = img.transpose(1, 2, 0) | |||
| return img | |||
| def calculate_psnr(img1, img2, crop_border, input_order='HWC'): | |||
| """Calculate PSNR (Peak Signal-to-Noise Ratio). | |||
| Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio | |||
| Args: | |||
| img1 (ndarray/tensor): Images with range [0, 255]/[0, 1]. | |||
| img2 (ndarray/tensor): Images with range [0, 255]/[0, 1]. | |||
| crop_border (int): Cropped pixels in each edge of an image. These | |||
| pixels are not involved in the PSNR calculation. | |||
| input_order (str): Whether the input order is 'HWC' or 'CHW'. | |||
| Default: 'HWC'. | |||
| Returns: | |||
| float: psnr result. | |||
| """ | |||
| assert img1.shape == img2.shape, ( | |||
        f'Image shapes are different: {img1.shape}, {img2.shape}.')
| if input_order not in ['HWC', 'CHW']: | |||
| raise ValueError( | |||
| f'Wrong input_order {input_order}. Supported input_orders are ' | |||
| '"HWC" and "CHW"') | |||
| if type(img1) == torch.Tensor: | |||
| if len(img1.shape) == 4: | |||
| img1 = img1.squeeze(0) | |||
| img1 = img1.detach().cpu().numpy().transpose(1, 2, 0) | |||
| if type(img2) == torch.Tensor: | |||
| if len(img2.shape) == 4: | |||
| img2 = img2.squeeze(0) | |||
| img2 = img2.detach().cpu().numpy().transpose(1, 2, 0) | |||
| img1 = reorder_image(img1, input_order=input_order) | |||
| img2 = reorder_image(img2, input_order=input_order) | |||
| img1 = img1.astype(np.float64) | |||
| img2 = img2.astype(np.float64) | |||
| if crop_border != 0: | |||
| img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] | |||
| img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] | |||
| def _psnr(img1, img2): | |||
| mse = np.mean((img1 - img2)**2) | |||
| if mse == 0: | |||
| return float('inf') | |||
| max_value = 1. if img1.max() <= 1 else 255. | |||
| return 20. * np.log10(max_value / np.sqrt(mse)) | |||
| return _psnr(img1, img2) | |||
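| # Usage sketch for calculate_psnr, assuming uint8-range HWC arrays: identical | |||
| # inputs yield inf; otherwise the result is 20 * log10(255 / sqrt(mse)). | |||
| def _demo_calculate_psnr(): | |||
|     rng = np.random.default_rng(0) | |||
|     img = rng.integers(0, 256, size=(32, 32, 3)).astype(np.float64) | |||
|     noisy = np.clip(img + rng.normal(0, 5, size=img.shape), 0, 255) | |||
|     assert calculate_psnr(img, img, crop_border=0) == float('inf') | |||
|     print(calculate_psnr(img, noisy, crop_border=0))  # roughly 34 dB for sigma=5 | |||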
| def calculate_ssim(img1, img2, crop_border, input_order='HWC', ssim3d=True): | |||
| """Calculate SSIM (structural similarity). | |||
| Ref: | |||
| Image quality assessment: From error visibility to structural similarity | |||
| The results are the same as that of the official released MATLAB code in | |||
| https://ece.uwaterloo.ca/~z70wang/research/ssim/. | |||
| For three-channel images, SSIM is computed with a 3D window when | |||
| ssim3d=True, otherwise per channel and then averaged. | |||
| Args: | |||
| img1 (ndarray): Images with range [0, 255]. | |||
| img2 (ndarray): Images with range [0, 255]. | |||
| crop_border (int): Cropped pixels in each edge of an image. These | |||
| pixels are not involved in the SSIM calculation. | |||
| input_order (str): Whether the input order is 'HWC' or 'CHW'. | |||
| Default: 'HWC'. | |||
| ssim3d (bool): Whether to compute SSIM over a 3D (H, W, C) Gaussian | |||
| window instead of per-channel 2D windows. Default: True. | |||
| Returns: | |||
| float: ssim result. | |||
| """ | |||
| assert img1.shape == img2.shape, ( | |||
| f'Image shapes are different: {img1.shape}, {img2.shape}.') | |||
| if input_order not in ['HWC', 'CHW']: | |||
| raise ValueError( | |||
| f'Wrong input_order {input_order}. Supported input_orders are ' | |||
| '"HWC" and "CHW"') | |||
| if type(img1) == torch.Tensor: | |||
| if len(img1.shape) == 4: | |||
| img1 = img1.squeeze(0) | |||
| img1 = img1.detach().cpu().numpy().transpose(1, 2, 0) | |||
| if type(img2) == torch.Tensor: | |||
| if len(img2.shape) == 4: | |||
| img2 = img2.squeeze(0) | |||
| img2 = img2.detach().cpu().numpy().transpose(1, 2, 0) | |||
| img1 = reorder_image(img1, input_order=input_order) | |||
| img2 = reorder_image(img2, input_order=input_order) | |||
| img1 = img1.astype(np.float64) | |||
| img2 = img2.astype(np.float64) | |||
| if crop_border != 0: | |||
| img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] | |||
| img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] | |||
| def _cal_ssim(img1, img2): | |||
| ssims = [] | |||
| max_value = 1 if img1.max() <= 1 else 255 | |||
| with torch.no_grad(): | |||
| final_ssim = _ssim_3d(img1, img2, max_value) if ssim3d else _ssim( | |||
| img1, img2, max_value) | |||
| ssims.append(final_ssim) | |||
| return np.array(ssims).mean() | |||
| return _cal_ssim(img1, img2) | |||
| def _ssim(img, img2, max_value): | |||
| """Calculate SSIM (structural similarity) for one channel images. | |||
| It is called by func:`calculate_ssim`. | |||
| Args: | |||
| img (ndarray): Images with range [0, 255] with order 'HWC'. | |||
| img2 (ndarray): Images with range [0, 255] with order 'HWC'. | |||
| Returns: | |||
| float: SSIM result. | |||
| """ | |||
| c1 = (0.01 * max_value)**2 | |||
| c2 = (0.03 * max_value)**2 | |||
| img = img.astype(np.float64) | |||
| img2 = img2.astype(np.float64) | |||
| kernel = cv2.getGaussianKernel(11, 1.5) | |||
| window = np.outer(kernel, kernel.transpose()) | |||
| mu1 = cv2.filter2D(img, -1, window)[5:-5, | |||
| 5:-5] # valid mode for window size 11 | |||
| mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] | |||
| mu1_sq = mu1**2 | |||
| mu2_sq = mu2**2 | |||
| mu1_mu2 = mu1 * mu2 | |||
| sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq | |||
| sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq | |||
| sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 | |||
| tmp1 = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2) | |||
| tmp2 = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2) | |||
| ssim_map = tmp1 / tmp2 | |||
| return ssim_map.mean() | |||
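| # Sanity sketch for _ssim: identical inputs score 1.0; images must be at | |||
| # least 11x11 so the valid-mode crop after Gaussian filtering is non-empty. | |||
| def _demo_ssim_2d(): | |||
|     img = np.random.rand(32, 32) * 255 | |||
|     assert abs(_ssim(img, img, 255) - 1.0) < 1e-9 | |||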
| def _3d_gaussian_calculator(img, conv3d): | |||
| out = conv3d(img.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0) | |||
| return out | |||
| def _generate_3d_gaussian_kernel(): | |||
| kernel = cv2.getGaussianKernel(11, 1.5) | |||
| window = np.outer(kernel, kernel.transpose()) | |||
| kernel_3 = cv2.getGaussianKernel(11, 1.5) | |||
| kernel = torch.tensor(np.stack([window * k for k in kernel_3], axis=0)) | |||
| conv3d = torch.nn.Conv3d( | |||
| 1, | |||
| 1, (11, 11, 11), | |||
| stride=1, | |||
| padding=(5, 5, 5), | |||
| bias=False, | |||
| padding_mode='replicate') | |||
| conv3d.weight.requires_grad = False | |||
| conv3d.weight[0, 0, :, :, :] = kernel | |||
| return conv3d | |||
| def _ssim_3d(img1, img2, max_value): | |||
| """Calculate SSIM (structural similarity) with a 3D Gaussian window. | |||
| It is called by func:`calculate_ssim`. | |||
| Args: | |||
| img1 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'. | |||
| img2 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'. | |||
| Returns: | |||
| float: ssim result. | |||
| """ | |||
| assert len(img1.shape) == 3 and len(img2.shape) == 3 | |||
| C1 = (0.01 * max_value)**2 | |||
| C2 = (0.03 * max_value)**2 | |||
| img1 = img1.astype(np.float64) | |||
| img2 = img2.astype(np.float64) | |||
| # fall back to CPU when CUDA is unavailable instead of failing on .cuda() | |||
| device = 'cuda' if torch.cuda.is_available() else 'cpu' | |||
| kernel = _generate_3d_gaussian_kernel().to(device) | |||
| img1 = torch.tensor(img1).float().to(device) | |||
| img2 = torch.tensor(img2).float().to(device) | |||
| mu1 = _3d_gaussian_calculator(img1, kernel) | |||
| mu2 = _3d_gaussian_calculator(img2, kernel) | |||
| mu1_sq = mu1**2 | |||
| mu2_sq = mu2**2 | |||
| mu1_mu2 = mu1 * mu2 | |||
| sigma1_sq = _3d_gaussian_calculator(img1**2, kernel) - mu1_sq | |||
| sigma2_sq = _3d_gaussian_calculator(img2**2, kernel) - mu2_sq | |||
| sigma12 = _3d_gaussian_calculator(img1 * img2, kernel) - mu1_mu2 | |||
| tmp1 = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2) | |||
| tmp2 = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) | |||
| ssim_map = tmp1 / tmp2 | |||
| return float(ssim_map.mean()) | |||
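| # Usage sketch for calculate_ssim: accepts HWC ndarrays or (N)CHW tensors; | |||
| # the per-channel 2D path is selected here via ssim3d=False. | |||
| def _demo_calculate_ssim(): | |||
|     img = np.random.rand(32, 32, 3) * 255 | |||
|     val = calculate_ssim(img, img, crop_border=0, ssim3d=False) | |||
|     assert abs(val - 1.0) < 1e-9 | |||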
| @@ -0,0 +1,210 @@ | |||
| """ | |||
| Part of the implementation is borrowed and modified from LaMa, publicly available at | |||
| https://github.com/saic-mdal/lama | |||
| """ | |||
| from typing import Dict | |||
| import numpy as np | |||
| import torch | |||
| import torch.nn.functional as F | |||
| from scipy import linalg | |||
| from modelscope.metainfo import Metrics | |||
| from modelscope.models.cv.image_inpainting.modules.inception import InceptionV3 | |||
| from modelscope.utils.registry import default_group | |||
| from modelscope.utils.tensor_utils import (torch_nested_detach, | |||
| torch_nested_numpify) | |||
| from .base import Metric | |||
| from .builder import METRICS, MetricKeys | |||
| def fid_calculate_activation_statistics(act): | |||
| mu = np.mean(act, axis=0) | |||
| sigma = np.cov(act, rowvar=False) | |||
| return mu, sigma | |||
| def calculate_frechet_distance(activations_pred, activations_target, eps=1e-6): | |||
| mu1, sigma1 = fid_calculate_activation_statistics(activations_pred) | |||
| mu2, sigma2 = fid_calculate_activation_statistics(activations_target) | |||
| diff = mu1 - mu2 | |||
| # Product might be almost singular | |||
| covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) | |||
| if not np.isfinite(covmean).all(): | |||
| offset = np.eye(sigma1.shape[0]) * eps | |||
| covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) | |||
| # Numerical error might give slight imaginary component | |||
| if np.iscomplexobj(covmean): | |||
| # if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): | |||
| if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-2): | |||
| m = np.max(np.abs(covmean.imag)) | |||
| raise ValueError('Imaginary component {}'.format(m)) | |||
| covmean = covmean.real | |||
| tr_covmean = np.trace(covmean) | |||
| return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) | |||
| - 2 * tr_covmean) | |||
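| # Sanity sketch for calculate_frechet_distance: the distance from a set of | |||
| # activations to itself is ~0, and a pure mean shift with matching covariance | |||
| # contributes ||mu1 - mu2||^2 (8.0 for a shift of 1 in each of 8 dims). | |||
| def _demo_frechet_distance(): | |||
|     rng = np.random.default_rng(0) | |||
|     act = rng.normal(size=(256, 8)) | |||
|     assert calculate_frechet_distance(act, act) < 1e-4 | |||
|     print(calculate_frechet_distance(act, act + 1.0))  # approximately 8.0 | |||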
| class FIDScore(torch.nn.Module): | |||
| def __init__(self, dims=2048, eps=1e-6): | |||
| super().__init__() | |||
| if getattr(FIDScore, '_MODEL', None) is None: | |||
| block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] | |||
| FIDScore._MODEL = InceptionV3([block_idx]).eval() | |||
| self.model = FIDScore._MODEL | |||
| self.eps = eps | |||
| self.reset() | |||
| def forward(self, pred_batch, target_batch, mask=None): | |||
| activations_pred = self._get_activations(pred_batch) | |||
| activations_target = self._get_activations(target_batch) | |||
| self.activations_pred.append(activations_pred.detach().cpu()) | |||
| self.activations_target.append(activations_target.detach().cpu()) | |||
| def get_value(self): | |||
| activations_pred, activations_target = (self.activations_pred, | |||
| self.activations_target) | |||
| activations_pred = torch.cat(activations_pred).cpu().numpy() | |||
| activations_target = torch.cat(activations_target).cpu().numpy() | |||
| total_distance = calculate_frechet_distance( | |||
| activations_pred, activations_target, eps=self.eps) | |||
| self.reset() | |||
| return total_distance | |||
| def reset(self): | |||
| self.activations_pred = [] | |||
| self.activations_target = [] | |||
| def _get_activations(self, batch): | |||
| activations = self.model(batch)[0] | |||
| if activations.shape[2] != 1 or activations.shape[3] != 1: | |||
| assert False, \ | |||
| 'We should not have got here, because Inception always scales inputs to 299x299' | |||
| activations = activations.squeeze(-1).squeeze(-1) | |||
| return activations | |||
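| # Usage sketch for FIDScore (assumes the modelscope InceptionV3 weights can | |||
| # be built; inputs are NCHW float image batches, e.g. values in [0, 1]): | |||
| def _demo_fid_score(): | |||
|     fid = FIDScore() | |||
|     with torch.no_grad(): | |||
|         fid(torch.rand(4, 3, 299, 299), torch.rand(4, 3, 299, 299)) | |||
|     print(fid.get_value())  # distance over accumulated batches, then resets | |||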
| class SSIM(torch.nn.Module): | |||
| """SSIM. Modified from: | |||
| https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py | |||
| """ | |||
| def __init__(self, window_size=11, size_average=True): | |||
| super().__init__() | |||
| self.window_size = window_size | |||
| self.size_average = size_average | |||
| self.channel = 1 | |||
| self.register_buffer('window', | |||
| self._create_window(window_size, self.channel)) | |||
| def forward(self, img1, img2): | |||
| assert len(img1.shape) == 4 | |||
| channel = img1.size()[1] | |||
| if channel == self.channel and self.window.data.type( | |||
| ) == img1.data.type(): | |||
| window = self.window | |||
| else: | |||
| window = self._create_window(self.window_size, channel) | |||
| window = window.type_as(img1) | |||
| self.window = window | |||
| self.channel = channel | |||
| return self._ssim(img1, img2, window, self.window_size, channel, | |||
| self.size_average) | |||
| def _gaussian(self, window_size, sigma): | |||
| gauss = torch.Tensor([ | |||
| np.exp(-(x - (window_size // 2))**2 / float(2 * sigma**2)) | |||
| for x in range(window_size) | |||
| ]) | |||
| return gauss / gauss.sum() | |||
| def _create_window(self, window_size, channel): | |||
| _1D_window = self._gaussian(window_size, 1.5).unsqueeze(1) | |||
| _2D_window = _1D_window.mm( | |||
| _1D_window.t()).float().unsqueeze(0).unsqueeze(0) | |||
| return _2D_window.expand(channel, 1, window_size, | |||
| window_size).contiguous() | |||
| def _ssim(self, | |||
| img1, | |||
| img2, | |||
| window, | |||
| window_size, | |||
| channel, | |||
| size_average=True): | |||
| mu1 = F.conv2d( | |||
| img1, window, padding=(window_size // 2), groups=channel) | |||
| mu2 = F.conv2d( | |||
| img2, window, padding=(window_size // 2), groups=channel) | |||
| mu1_sq = mu1.pow(2) | |||
| mu2_sq = mu2.pow(2) | |||
| mu1_mu2 = mu1 * mu2 | |||
| sigma1_sq = F.conv2d( | |||
| img1 * img1, window, padding=(window_size // 2), | |||
| groups=channel) - mu1_sq | |||
| sigma2_sq = F.conv2d( | |||
| img2 * img2, window, padding=(window_size // 2), | |||
| groups=channel) - mu2_sq | |||
| sigma12 = F.conv2d( | |||
| img1 * img2, window, padding=(window_size // 2), | |||
| groups=channel) - mu1_mu2 | |||
| C1 = 0.01**2 | |||
| C2 = 0.03**2 | |||
| ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \ | |||
| ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) | |||
| if size_average: | |||
| return ssim_map.mean() | |||
| return ssim_map.mean(1).mean(1).mean(1) | |||
| def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, | |||
| missing_keys, unexpected_keys, error_msgs): | |||
| return | |||
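| # Usage sketch for the SSIM module: NCHW float tensors; with | |||
| # size_average=False it returns one score per image in the batch. | |||
| def _demo_ssim_module(): | |||
|     ssim = SSIM(window_size=11, size_average=False).eval() | |||
|     a = torch.rand(2, 3, 64, 64) | |||
|     print(ssim(a, a))  # tensor([1., 1.]) | |||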
| @METRICS.register_module( | |||
| group_key=default_group, module_name=Metrics.image_inpainting_metric) | |||
| class ImageInpaintingMetric(Metric): | |||
| """The metric computation class for image inpainting classes. | |||
| """ | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.preds = [] | |||
| self.targets = [] | |||
| self.SSIM = SSIM(window_size=11, size_average=False).eval() | |||
| device = 'cuda' if torch.cuda.is_available() else 'cpu' | |||
| self.FID = FIDScore().to(device) | |||
| def add(self, outputs: Dict, inputs: Dict): | |||
| pred = outputs['inpainted'] | |||
| target = inputs['image'] | |||
| self.preds.append(torch_nested_detach(pred)) | |||
| self.targets.append(torch_nested_detach(target)) | |||
| def evaluate(self): | |||
| ssim_list = [] | |||
| for (pred, target) in zip(self.preds, self.targets): | |||
| ssim_list.append(self.SSIM(pred, target)) | |||
| self.FID(pred, target) | |||
| ssim_list = torch_nested_numpify(ssim_list) | |||
| fid = self.FID.get_value() | |||
| return {MetricKeys.SSIM: np.mean(ssim_list), MetricKeys.FID: fid} | |||
| @@ -1,5 +1,8 @@ | |||
| # Part of the implementation is borrowed and modified from BasicSR, publicly available at | |||
| # https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/metrics/psnr_ssim.py | |||
| from typing import Dict | |||
| import cv2 | |||
| import numpy as np | |||
| from modelscope.metainfo import Metrics | |||
| @@ -35,6 +38,7 @@ class ImagePortraitEnhancementMetric(Metric): | |||
| def add(self, outputs: Dict, inputs: Dict): | |||
| ground_truths = outputs['target'] | |||
| eval_results = outputs['pred'] | |||
| self.preds.extend(eval_results) | |||
| self.targets.extend(ground_truths) | |||
| @@ -34,17 +34,24 @@ class TokenClassificationMetric(Metric): | |||
| self.labels.append( | |||
| torch_nested_numpify(torch_nested_detach(ground_truths))) | |||
| def __init__(self, return_entity_level_metrics=False, *args, **kwargs): | |||
| def __init__(self, | |||
| return_entity_level_metrics=False, | |||
| label2id=None, | |||
| *args, | |||
| **kwargs): | |||
| super().__init__(*args, **kwargs) | |||
| self.return_entity_level_metrics = return_entity_level_metrics | |||
| self.preds = [] | |||
| self.labels = [] | |||
| self.label2id = label2id | |||
| def evaluate(self): | |||
| self.id2label = { | |||
| id: label | |||
| for label, id in self.trainer.label2id.items() | |||
| } | |||
| label2id = self.label2id | |||
| if label2id is None: | |||
| assert hasattr(self, 'trainer') | |||
| label2id = self.trainer.label2id | |||
| self.id2label = {id: label for label, id in label2id.items()} | |||
| self.preds = np.concatenate(self.preds, axis=0) | |||
| self.labels = np.concatenate(self.labels, axis=0) | |||
| predictions = np.argmax(self.preds, axis=-1) | |||
| @@ -1,3 +1,6 @@ | |||
| # Part of the implementation is borrowed and modified from PGL-SUM, | |||
| # publicly available at https://github.com/e-apostolidis/PGL-SUM | |||
| from typing import Dict | |||
| import numpy as np | |||
| @@ -1,3 +1,5 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| from typing import Any, Dict | |||
| @@ -1,15 +1,14 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| from typing import Dict | |||
| import torch | |||
| from typing import Dict, Optional | |||
| from modelscope.metainfo import Models | |||
| from modelscope.models import TorchModel | |||
| from modelscope.models.base import Tensor | |||
| from modelscope.models.builder import MODELS | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.audio.audio_utils import update_conf | |||
| from modelscope.utils.constant import Tasks | |||
| from .fsmn_sele_v2 import FSMNSeleNetV2 | |||
| @@ -20,48 +19,38 @@ class FSMNSeleNetV2Decorator(TorchModel): | |||
| MODEL_TXT = 'model.txt' | |||
| SC_CONFIG = 'sound_connect.conf' | |||
| SC_CONF_ITEM_KWS_MODEL = '${kws_model}' | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| def __init__(self, | |||
| model_dir: str, | |||
| training: Optional[bool] = False, | |||
| *args, | |||
| **kwargs): | |||
| """initialize the dfsmn model from the `model_dir` path. | |||
| Args: | |||
| model_dir (str): the model path. | |||
| """ | |||
| super().__init__(model_dir, *args, **kwargs) | |||
| sc_config_file = os.path.join(model_dir, self.SC_CONFIG) | |||
| model_txt_file = os.path.join(model_dir, self.MODEL_TXT) | |||
| model_bin_file = os.path.join(model_dir, | |||
| ModelFile.TORCH_MODEL_BIN_FILE) | |||
| self._model = None | |||
| if os.path.exists(model_bin_file): | |||
| kwargs.pop('device') | |||
| self._model = FSMNSeleNetV2(*args, **kwargs) | |||
| checkpoint = torch.load(model_bin_file) | |||
| self._model.load_state_dict(checkpoint, strict=False) | |||
| self._sc = None | |||
| if os.path.exists(model_txt_file): | |||
| with open(sc_config_file) as f: | |||
| lines = f.readlines() | |||
| with open(sc_config_file, 'w') as f: | |||
| for line in lines: | |||
| if self.SC_CONF_ITEM_KWS_MODEL in line: | |||
| line = line.replace(self.SC_CONF_ITEM_KWS_MODEL, | |||
| model_txt_file) | |||
| f.write(line) | |||
| import py_sound_connect | |||
| self._sc = py_sound_connect.SoundConnect(sc_config_file) | |||
| self.size_in = self._sc.bytesPerBlockIn() | |||
| self.size_out = self._sc.bytesPerBlockOut() | |||
| if self._model is None and self._sc is None: | |||
| raise Exception( | |||
| f'Invalid model directory! Neither {model_txt_file} nor {model_bin_file} exists.' | |||
| ) | |||
| if training: | |||
| self.model = FSMNSeleNetV2(*args, **kwargs) | |||
| else: | |||
| sc_config_file = os.path.join(model_dir, self.SC_CONFIG) | |||
| model_txt_file = os.path.join(model_dir, self.MODEL_TXT) | |||
| self._sc = None | |||
| if os.path.exists(model_txt_file): | |||
| conf_dict = dict(mode=56542, kws_model=model_txt_file) | |||
| update_conf(sc_config_file, sc_config_file, conf_dict) | |||
| import py_sound_connect | |||
| self._sc = py_sound_connect.SoundConnect(sc_config_file) | |||
| self.size_in = self._sc.bytesPerBlockIn() | |||
| self.size_out = self._sc.bytesPerBlockOut() | |||
| else: | |||
| raise Exception( | |||
| f'Invalid model directory! Failed to load model file: {model_txt_file}.' | |||
| ) | |||
| def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||
| ... | |||
| return self.model.forward(input) | |||
| def forward_decode(self, data: bytes): | |||
| result = {'pcm': self._sc.process(data, self.size_out)} | |||
| @@ -1,3 +1,5 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| from typing import Any, Dict | |||
| @@ -2,6 +2,7 @@ | |||
| import os | |||
| import pickle as pkl | |||
| from threading import Lock | |||
| import json | |||
| import numpy as np | |||
| @@ -27,6 +28,7 @@ class Voice: | |||
| self.__am_config = AttrDict(**am_config) | |||
| self.__voc_config = AttrDict(**voc_config) | |||
| self.__model_loaded = False | |||
| self.__lock = Lock() | |||
| if 'am' not in self.__am_config: | |||
| raise TtsModelConfigurationException( | |||
| 'modelscope error: am configuration invalid') | |||
| @@ -71,34 +73,35 @@ class Voice: | |||
| self.__generator.remove_weight_norm() | |||
| def __am_forward(self, symbol_seq): | |||
| with torch.no_grad(): | |||
| inputs_feat_lst = self.__ling_unit.encode_symbol_sequence( | |||
| symbol_seq) | |||
| inputs_sy = torch.from_numpy(inputs_feat_lst[0]).long().to( | |||
| self.__device) | |||
| inputs_tone = torch.from_numpy(inputs_feat_lst[1]).long().to( | |||
| self.__device) | |||
| inputs_syllable = torch.from_numpy(inputs_feat_lst[2]).long().to( | |||
| self.__device) | |||
| inputs_ws = torch.from_numpy(inputs_feat_lst[3]).long().to( | |||
| self.__device) | |||
| inputs_ling = torch.stack( | |||
| [inputs_sy, inputs_tone, inputs_syllable, inputs_ws], | |||
| dim=-1).unsqueeze(0) | |||
| inputs_emo = torch.from_numpy(inputs_feat_lst[4]).long().to( | |||
| self.__device).unsqueeze(0) | |||
| inputs_spk = torch.from_numpy(inputs_feat_lst[5]).long().to( | |||
| self.__device).unsqueeze(0) | |||
| inputs_len = torch.zeros(1).to(self.__device).long( | |||
| ) + inputs_emo.size(1) - 1 # minus 1 for "~" | |||
| res = self.__am_net(inputs_ling[:, :-1, :], inputs_emo[:, :-1], | |||
| inputs_spk[:, :-1], inputs_len) | |||
| postnet_outputs = res['postnet_outputs'] | |||
| LR_length_rounded = res['LR_length_rounded'] | |||
| valid_length = int(LR_length_rounded[0].item()) | |||
| postnet_outputs = postnet_outputs[ | |||
| 0, :valid_length, :].cpu().numpy() | |||
| return postnet_outputs | |||
| with self.__lock: | |||
| with torch.no_grad(): | |||
| inputs_feat_lst = self.__ling_unit.encode_symbol_sequence( | |||
| symbol_seq) | |||
| inputs_sy = torch.from_numpy(inputs_feat_lst[0]).long().to( | |||
| self.__device) | |||
| inputs_tone = torch.from_numpy(inputs_feat_lst[1]).long().to( | |||
| self.__device) | |||
| inputs_syllable = torch.from_numpy( | |||
| inputs_feat_lst[2]).long().to(self.__device) | |||
| inputs_ws = torch.from_numpy(inputs_feat_lst[3]).long().to( | |||
| self.__device) | |||
| inputs_ling = torch.stack( | |||
| [inputs_sy, inputs_tone, inputs_syllable, inputs_ws], | |||
| dim=-1).unsqueeze(0) | |||
| inputs_emo = torch.from_numpy(inputs_feat_lst[4]).long().to( | |||
| self.__device).unsqueeze(0) | |||
| inputs_spk = torch.from_numpy(inputs_feat_lst[5]).long().to( | |||
| self.__device).unsqueeze(0) | |||
| inputs_len = torch.zeros(1).to(self.__device).long( | |||
| ) + inputs_emo.size(1) - 1 # minus 1 for "~" | |||
| res = self.__am_net(inputs_ling[:, :-1, :], inputs_emo[:, :-1], | |||
| inputs_spk[:, :-1], inputs_len) | |||
| postnet_outputs = res['postnet_outputs'] | |||
| LR_length_rounded = res['LR_length_rounded'] | |||
| valid_length = int(LR_length_rounded[0].item()) | |||
| postnet_outputs = postnet_outputs[ | |||
| 0, :valid_length, :].cpu().numpy() | |||
| return postnet_outputs | |||
| def __vocoder_forward(self, melspec): | |||
| dim0 = list(melspec.shape)[-1] | |||
| @@ -118,14 +121,15 @@ class Voice: | |||
| return audio | |||
| def forward(self, symbol_seq): | |||
| if not self.__model_loaded: | |||
| torch.manual_seed(self.__am_config.seed) | |||
| if torch.cuda.is_available(): | |||
| with self.__lock: | |||
| if not self.__model_loaded: | |||
| torch.manual_seed(self.__am_config.seed) | |||
| self.__device = torch.device('cuda') | |||
| else: | |||
| self.__device = torch.device('cpu') | |||
| self.__load_am() | |||
| self.__load_vocoder() | |||
| self.__model_loaded = True | |||
| if torch.cuda.is_available(): | |||
| torch.manual_seed(self.__am_config.seed) | |||
| self.__device = torch.device('cuda') | |||
| else: | |||
| self.__device = torch.device('cpu') | |||
| self.__load_am() | |||
| self.__load_vocoder() | |||
| self.__model_loaded = True | |||
| return self.__vocoder_forward(self.__am_forward(symbol_seq)) | |||
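| # Concurrency sketch: with the lock in place, forward() can be called from | |||
| # several threads and lazy model loading still happens exactly once. `voice` | |||
| # is assumed to be a constructed Voice, `seqs` a list of symbol sequences. | |||
| def _demo_concurrent_forward(voice, seqs): | |||
|     from concurrent.futures import ThreadPoolExecutor | |||
|     with ThreadPoolExecutor(max_workers=2) as pool: | |||
|         return list(pool.map(voice.forward, seqs)) | |||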
| @@ -5,11 +5,11 @@ from abc import ABC, abstractmethod | |||
| from typing import Any, Callable, Dict, List, Optional, Union | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.models.builder import build_model | |||
| from modelscope.utils.checkpoint import save_pretrained | |||
| from modelscope.models.builder import MODELS, build_model | |||
| from modelscope.utils.checkpoint import save_checkpoint, save_pretrained | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile | |||
| from modelscope.utils.device import device_placement, verify_device | |||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile, Tasks | |||
| from modelscope.utils.device import verify_device | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| @@ -66,7 +66,6 @@ class Model(ABC): | |||
| revision: Optional[str] = DEFAULT_MODEL_REVISION, | |||
| cfg_dict: Config = None, | |||
| device: str = None, | |||
| *model_args, | |||
| **kwargs): | |||
| """ Instantiate a model from local directory or remote model repo. Note | |||
| that when loading from remote, the model revision can be specified. | |||
| @@ -90,11 +89,11 @@ class Model(ABC): | |||
| cfg = Config.from_file( | |||
| osp.join(local_model_dir, ModelFile.CONFIGURATION)) | |||
| task_name = cfg.task | |||
| if 'task' in kwargs: | |||
| task_name = kwargs.pop('task') | |||
| model_cfg = cfg.model | |||
| if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'): | |||
| model_cfg.type = model_cfg.model_type | |||
| model_cfg.model_dir = local_model_dir | |||
| for k, v in kwargs.items(): | |||
| model_cfg[k] = v | |||
| @@ -109,15 +108,19 @@ class Model(ABC): | |||
| # dynamically add pipeline info to model for pipeline inference | |||
| if hasattr(cfg, 'pipeline'): | |||
| model.pipeline = cfg.pipeline | |||
| if not hasattr(model, 'cfg'): | |||
| model.cfg = cfg | |||
| return model | |||
| def save_pretrained(self, | |||
| target_folder: Union[str, os.PathLike], | |||
| save_checkpoint_names: Union[str, List[str]] = None, | |||
| save_function: Callable = None, | |||
| save_function: Callable = save_checkpoint, | |||
| config: Optional[dict] = None, | |||
| **kwargs): | |||
| """save the pretrained model, its configuration and other related files to a directory, so that it can be re-loaded | |||
| """save the pretrained model, its configuration and other related files to a directory, | |||
| so that it can be re-loaded | |||
| Args: | |||
| target_folder (Union[str, os.PathLike]): | |||
| @@ -133,5 +136,10 @@ class Model(ABC): | |||
| The config for the configuration.json, might not be identical with model.config | |||
| """ | |||
| if config is None and hasattr(self, 'cfg'): | |||
| config = self.cfg | |||
| assert config is not None, 'Cannot save the model because the model config is empty.' | |||
| if isinstance(config, Config): | |||
| config = config.to_dict() | |||
| save_pretrained(self, target_folder, save_checkpoint_names, | |||
| save_function, config, **kwargs) | |||
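| # Usage sketch (hypothetical model id and path): from_pretrained resolves a | |||
| # local directory or downloads the repo at `revision`; save_pretrained then | |||
| # re-exports the weights together with the configuration it was loaded with. | |||
| def _demo_save_roundtrip(): | |||
|     model = Model.from_pretrained('damo/some-model-id', revision='v1.0.0') | |||
|     model.save_pretrained('/tmp/exported_model') | |||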
| @@ -1,12 +1,20 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from modelscope.utils.config import ConfigDict | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.import_utils import INDEX_KEY, LazyImportModule | |||
| from modelscope.utils.registry import TYPE_NAME, Registry, build_from_cfg | |||
| MODELS = Registry('models') | |||
| BACKBONES = Registry('backbones') | |||
| BACKBONES = MODELS | |||
| HEADS = Registry('heads') | |||
| modules = LazyImportModule.AST_INDEX[INDEX_KEY] | |||
| for module_index in list(modules.keys()): | |||
| if module_index[1] == Tasks.backbone and module_index[0] == 'BACKBONES': | |||
| modules[(MODELS.name.upper(), module_index[1], | |||
| module_index[2])] = modules[module_index] | |||
| def build_model(cfg: ConfigDict, | |||
| task_name: str = None, | |||
| @@ -23,30 +31,27 @@ def build_model(cfg: ConfigDict, | |||
| cfg, MODELS, group_key=task_name, default_args=default_args) | |||
| def build_backbone(cfg: ConfigDict, | |||
| field: str = None, | |||
| default_args: dict = None): | |||
| def build_backbone(cfg: ConfigDict, default_args: dict = None): | |||
| """ build backbone given backbone config dict | |||
| Args: | |||
| cfg (:obj:`ConfigDict`): config dict for backbone object. | |||
| field (str, optional): field, such as CV, NLP's backbone | |||
| default_args (dict, optional): Default initialization arguments. | |||
| """ | |||
| return build_from_cfg( | |||
| cfg, BACKBONES, group_key=field, default_args=default_args) | |||
| cfg, BACKBONES, group_key=Tasks.backbone, default_args=default_args) | |||
| def build_head(cfg: ConfigDict, | |||
| group_key: str = None, | |||
| task_name: str = None, | |||
| default_args: dict = None): | |||
| """ build head given config dict | |||
| Args: | |||
| cfg (:obj:`ConfigDict`): config dict for head object. | |||
| task_name (str, optional): task name, refer to | |||
| :obj:`Tasks` for more details | |||
| default_args (dict, optional): Default initialization arguments. | |||
| """ | |||
| if group_key is None: | |||
| group_key = cfg[TYPE_NAME] | |||
| return build_from_cfg( | |||
| cfg, HEADS, group_key=group_key, default_args=default_args) | |||
| cfg, HEADS, group_key=task_name, default_args=default_args) | |||
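| # Sketch of the lookup flow (hypothetical registered name): cfg.type picks | |||
| # the class, and backbones are now grouped under Tasks.backbone because | |||
| # BACKBONES aliases MODELS. | |||
| def _demo_build_backbone(): | |||
|     cfg = ConfigDict(type='some-registered-backbone') | |||
|     return build_backbone(cfg, default_args=None) | |||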
| @@ -4,14 +4,16 @@ | |||
| from . import (action_recognition, animal_recognition, body_2d_keypoints, | |||
| body_3d_keypoints, cartoon, cmdssl_video_embedding, | |||
| crowd_counting, face_2d_keypoints, face_detection, | |||
| face_generation, image_classification, image_color_enhance, | |||
| image_colorization, image_denoise, image_instance_segmentation, | |||
| face_generation, human_wholebody_keypoint, image_classification, | |||
| image_color_enhance, image_colorization, image_denoise, | |||
| image_inpainting, image_instance_segmentation, | |||
| image_panoptic_segmentation, image_portrait_enhancement, | |||
| image_reid_person, image_semantic_segmentation, | |||
| image_to_image_generation, image_to_image_translation, | |||
| movie_scene_segmentation, object_detection, | |||
| product_retrieval_embedding, realtime_object_detection, | |||
| salient_detection, shop_segmentation, super_resolution, | |||
| referring_video_object_segmentation, salient_detection, | |||
| shop_segmentation, super_resolution, | |||
| video_single_object_tracking, video_summarization, virual_tryon) | |||
| # yapf: enable | |||
| @@ -4,6 +4,7 @@ import os | |||
| import os.path as osp | |||
| import shutil | |||
| import subprocess | |||
| import uuid | |||
| import cv2 | |||
| import numpy as np | |||
| @@ -84,7 +85,9 @@ class ActionDetONNX(Model): | |||
| def forward_video(self, video_name, scale): | |||
| min_size, max_size = self._get_sizes(scale) | |||
| tmp_dir = osp.join(self.tmp_dir, osp.basename(video_name)[:-4]) | |||
| tmp_dir = osp.join( | |||
| self.tmp_dir, | |||
| str(uuid.uuid1()) + '_' + osp.basename(video_name)[:-4]) | |||
| if osp.exists(tmp_dir): | |||
| shutil.rmtree(tmp_dir) | |||
| os.makedirs(tmp_dir) | |||
| @@ -110,6 +113,7 @@ class ActionDetONNX(Model): | |||
| len(frame_names) * self.temporal_stride, | |||
| self.temporal_stride)) | |||
| batch_imgs = [self.parse_frames(names) for names in frame_names] | |||
| shutil.rmtree(tmp_dir) | |||
| N, _, T, H, W = batch_imgs[0].shape | |||
| scale_min = min_size / min(H, W) | |||
| @@ -128,7 +132,6 @@ class ActionDetONNX(Model): | |||
| 'timestamp': t, | |||
| 'actions': res | |||
| } for t, res in zip(timestamp, results)] | |||
| shutil.rmtree(tmp_dir) | |||
| return results | |||
| def forward(self, video_name): | |||
| @@ -1,3 +1,5 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| from typing import Any, Dict, Optional, Union | |||
| @@ -1,10 +1,10 @@ | |||
| # ------------------------------------------------------------------------------ | |||
| # Copyright (c) Microsoft | |||
| # Licensed under the MIT License. | |||
| # Written by Bin Xiao (Bin.Xiao@microsoft.com) | |||
| # Modified by Ke Sun (sunk@mail.ustc.edu.cn) | |||
| # https://github.com/HRNet/HRNet-Image-Classification/blob/master/lib/models/cls_hrnet.py | |||
| # ------------------------------------------------------------------------------ | |||
| """ | |||
| Copyright (c) Microsoft | |||
| Licensed under the MIT License. | |||
| Written by Bin Xiao (Bin.Xiao@microsoft.com) | |||
| Modified by Ke Sun (sunk@mail.ustc.edu.cn) | |||
| https://github.com/HRNet/HRNet-Image-Classification/blob/master/lib/models/cls_hrnet.py | |||
| """ | |||
| import functools | |||
| import logging | |||
| @@ -8,12 +8,14 @@ if TYPE_CHECKING: | |||
| from .mtcnn import MtcnnFaceDetector | |||
| from .retinaface import RetinaFaceDetection | |||
| from .ulfd_slim import UlfdFaceDetector | |||
| from .scrfd import ScrfdDetect | |||
| else: | |||
| _import_structure = { | |||
| 'ulfd_slim': ['UlfdFaceDetector'], | |||
| 'retinaface': ['RetinaFaceDetection'], | |||
| 'mtcnn': ['MtcnnFaceDetector'], | |||
| 'mogface': ['MogFaceDetector'] | |||
| 'mogface': ['MogFaceDetector'], | |||
| 'scrfd': ['ScrfdDetect'] | |||
| } | |||
| import sys | |||
| @@ -1,189 +0,0 @@ | |||
| """ | |||
| The implementation here is modified based on insightface, originally MIT-licensed and publicly available at | |||
| https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/transforms.py | |||
| """ | |||
| import numpy as np | |||
| from mmdet.datasets.builder import PIPELINES | |||
| from numpy import random | |||
| @PIPELINES.register_module() | |||
| class RandomSquareCrop(object): | |||
| """Random crop the image & bboxes, the cropped patches have minimum IoU | |||
| requirement with original image & bboxes, the IoU threshold is randomly | |||
| selected from min_ious. | |||
| Args: | |||
| min_ious (tuple): minimum IoU threshold for all intersections with | |||
| bounding boxes | |||
| min_crop_size (float): minimum crop's size (i.e. h,w := a*h, a*w, | |||
| where a >= min_crop_size). | |||
| Note: | |||
| The keys for bboxes, labels and masks should be paired. That is, \ | |||
| `gt_bboxes` corresponds to `gt_labels` and `gt_masks`, and \ | |||
| `gt_bboxes_ignore` to `gt_labels_ignore` and `gt_masks_ignore`. | |||
| """ | |||
| def __init__(self, | |||
| crop_ratio_range=None, | |||
| crop_choice=None, | |||
| bbox_clip_border=True): | |||
| self.crop_ratio_range = crop_ratio_range | |||
| self.crop_choice = crop_choice | |||
| self.bbox_clip_border = bbox_clip_border | |||
| assert (self.crop_ratio_range is None) ^ (self.crop_choice is None) | |||
| if self.crop_ratio_range is not None: | |||
| self.crop_ratio_min, self.crop_ratio_max = self.crop_ratio_range | |||
| self.bbox2label = { | |||
| 'gt_bboxes': 'gt_labels', | |||
| 'gt_bboxes_ignore': 'gt_labels_ignore' | |||
| } | |||
| self.bbox2mask = { | |||
| 'gt_bboxes': 'gt_masks', | |||
| 'gt_bboxes_ignore': 'gt_masks_ignore' | |||
| } | |||
| def __call__(self, results): | |||
| """Call function to crop images and bounding boxes with minimum IoU | |||
| constraint. | |||
| Args: | |||
| results (dict): Result dict from loading pipeline. | |||
| Returns: | |||
| dict: Result dict with images and bounding boxes cropped, \ | |||
| 'img_shape' key is updated. | |||
| """ | |||
| if 'img_fields' in results: | |||
| assert results['img_fields'] == ['img'], \ | |||
| 'Only single img_fields is allowed' | |||
| img = results['img'] | |||
| assert 'bbox_fields' in results | |||
| assert 'gt_bboxes' in results | |||
| boxes = results['gt_bboxes'] | |||
| h, w, c = img.shape | |||
| scale_retry = 0 | |||
| if self.crop_ratio_range is not None: | |||
| max_scale = self.crop_ratio_max | |||
| else: | |||
| max_scale = np.amax(self.crop_choice) | |||
| while True: | |||
| scale_retry += 1 | |||
| if scale_retry == 1 or max_scale > 1.0: | |||
| if self.crop_ratio_range is not None: | |||
| scale = np.random.uniform(self.crop_ratio_min, | |||
| self.crop_ratio_max) | |||
| elif self.crop_choice is not None: | |||
| scale = np.random.choice(self.crop_choice) | |||
| else: | |||
| scale = scale * 1.2 | |||
| for i in range(250): | |||
| short_side = min(w, h) | |||
| cw = int(scale * short_side) | |||
| ch = cw | |||
| # TODO +1 | |||
| if w == cw: | |||
| left = 0 | |||
| elif w > cw: | |||
| left = random.randint(0, w - cw) | |||
| else: | |||
| left = random.randint(w - cw, 0) | |||
| if h == ch: | |||
| top = 0 | |||
| elif h > ch: | |||
| top = random.randint(0, h - ch) | |||
| else: | |||
| top = random.randint(h - ch, 0) | |||
| patch = np.array( | |||
| (int(left), int(top), int(left + cw), int(top + ch)), | |||
| dtype=np.int) | |||
| # center of boxes should inside the crop img | |||
| # only adjust boxes and instance masks when the gt is not empty | |||
| # adjust boxes | |||
| def is_center_of_bboxes_in_patch(boxes, patch): | |||
| # TODO >= | |||
| center = (boxes[:, :2] + boxes[:, 2:]) / 2 | |||
| mask = \ | |||
| ((center[:, 0] > patch[0]) | |||
| * (center[:, 1] > patch[1]) | |||
| * (center[:, 0] < patch[2]) | |||
| * (center[:, 1] < patch[3])) | |||
| return mask | |||
| mask = is_center_of_bboxes_in_patch(boxes, patch) | |||
| if not mask.any(): | |||
| continue | |||
| for key in results.get('bbox_fields', []): | |||
| boxes = results[key].copy() | |||
| mask = is_center_of_bboxes_in_patch(boxes, patch) | |||
| boxes = boxes[mask] | |||
| if self.bbox_clip_border: | |||
| boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:]) | |||
| boxes[:, :2] = boxes[:, :2].clip(min=patch[:2]) | |||
| boxes -= np.tile(patch[:2], 2) | |||
| results[key] = boxes | |||
| # labels | |||
| label_key = self.bbox2label.get(key) | |||
| if label_key in results: | |||
| results[label_key] = results[label_key][mask] | |||
| # keypoints field | |||
| if key == 'gt_bboxes': | |||
| for kps_key in results.get('keypoints_fields', []): | |||
| keypointss = results[kps_key].copy() | |||
| keypointss = keypointss[mask, :, :] | |||
| if self.bbox_clip_border: | |||
| keypointss[:, :, : | |||
| 2] = keypointss[:, :, :2].clip( | |||
| max=patch[2:]) | |||
| keypointss[:, :, : | |||
| 2] = keypointss[:, :, :2].clip( | |||
| min=patch[:2]) | |||
| keypointss[:, :, 0] -= patch[0] | |||
| keypointss[:, :, 1] -= patch[1] | |||
| results[kps_key] = keypointss | |||
| # mask fields | |||
| mask_key = self.bbox2mask.get(key) | |||
| if mask_key in results: | |||
| results[mask_key] = results[mask_key][mask.nonzero() | |||
| [0]].crop(patch) | |||
| # adjust the img no matter whether the gt is empty before crop | |||
| rimg = np.ones((ch, cw, 3), dtype=img.dtype) * 128 | |||
| patch_from = patch.copy() | |||
| patch_from[0] = max(0, patch_from[0]) | |||
| patch_from[1] = max(0, patch_from[1]) | |||
| patch_from[2] = min(img.shape[1], patch_from[2]) | |||
| patch_from[3] = min(img.shape[0], patch_from[3]) | |||
| patch_to = patch.copy() | |||
| patch_to[0] = max(0, patch_to[0] * -1) | |||
| patch_to[1] = max(0, patch_to[1] * -1) | |||
| patch_to[2] = patch_to[0] + (patch_from[2] - patch_from[0]) | |||
| patch_to[3] = patch_to[1] + (patch_from[3] - patch_from[1]) | |||
| rimg[patch_to[1]:patch_to[3], | |||
| patch_to[0]:patch_to[2], :] = img[ | |||
| patch_from[1]:patch_from[3], | |||
| patch_from[0]:patch_from[2], :] | |||
| img = rimg | |||
| results['img'] = img | |||
| results['img_shape'] = img.shape | |||
| return results | |||
| def __repr__(self): | |||
| repr_str = self.__class__.__name__ | |||
| repr_str += f'(min_ious={self.min_iou}, ' | |||
| repr_str += f'crop_size={self.crop_size})' | |||
| return repr_str | |||
| @@ -1,3 +1,5 @@ | |||
| # The implementation is based on MogFace, available at | |||
| # https://github.com/damo-cv/MogFace | |||
| import os | |||
| import cv2 | |||
| @@ -0,0 +1,2 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from .scrfd_detect import ScrfdDetect | |||
| @@ -6,7 +6,7 @@ import numpy as np | |||
| import torch | |||
| def bbox2result(bboxes, labels, num_classes, kps=None): | |||
| def bbox2result(bboxes, labels, num_classes, kps=None, num_kps=5): | |||
| """Convert detection results to a list of numpy arrays. | |||
| Args: | |||
| @@ -17,7 +17,7 @@ def bbox2result(bboxes, labels, num_classes, kps=None): | |||
| Returns: | |||
| list(ndarray): bbox results of each class | |||
| """ | |||
| bbox_len = 5 if kps is None else 5 + 10 # if has kps, add 10 kps into bbox | |||
| bbox_len = 5 if kps is None else 5 + num_kps * 2 # if has kps, add num_kps*2 into bbox | |||
| if bboxes.shape[0] == 0: | |||
| return [ | |||
| np.zeros((0, bbox_len), dtype=np.float32) | |||
| @@ -17,6 +17,7 @@ def multiclass_nms(multi_bboxes, | |||
| Args: | |||
| multi_bboxes (Tensor): shape (n, #class*4) or (n, 4) | |||
| multi_kps (Tensor): shape (n, #class*num_kps*2) or (n, num_kps*2) | |||
| multi_scores (Tensor): shape (n, #class), where the last column | |||
| contains scores of the background class, but this will be ignored. | |||
| score_thr (float): bbox threshold, bboxes with scores lower than it | |||
| @@ -36,16 +37,18 @@ def multiclass_nms(multi_bboxes, | |||
| num_classes = multi_scores.size(1) - 1 | |||
| # exclude background category | |||
| kps = None | |||
| if multi_kps is not None: | |||
| num_kps = int((multi_kps.shape[1] / num_classes) / 2) | |||
| if multi_bboxes.shape[1] > 4: | |||
| bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4) | |||
| if multi_kps is not None: | |||
| kps = multi_kps.view(multi_scores.size(0), -1, 10) | |||
| kps = multi_kps.view(multi_scores.size(0), -1, num_kps * 2) | |||
| else: | |||
| bboxes = multi_bboxes[:, None].expand( | |||
| multi_scores.size(0), num_classes, 4) | |||
| if multi_kps is not None: | |||
| kps = multi_kps[:, None].expand( | |||
| multi_scores.size(0), num_classes, 10) | |||
| multi_scores.size(0), num_classes, num_kps * 2) | |||
| scores = multi_scores[:, :-1] | |||
| if score_factors is not None: | |||
| @@ -56,7 +59,7 @@ def multiclass_nms(multi_bboxes, | |||
| bboxes = bboxes.reshape(-1, 4) | |||
| if kps is not None: | |||
| kps = kps.reshape(-1, 10) | |||
| kps = kps.reshape(-1, num_kps * 2) | |||
| scores = scores.reshape(-1) | |||
| labels = labels.reshape(-1) | |||
| @@ -2,6 +2,12 @@ | |||
| The implementation here is modified based on insightface, originally MIT-licensed and publicly available at | |||
| https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines | |||
| """ | |||
| from .auto_augment import RotateV2 | |||
| from .formating import DefaultFormatBundleV2 | |||
| from .loading import LoadAnnotationsV2 | |||
| from .transforms import RandomSquareCrop | |||
| __all__ = ['RandomSquareCrop'] | |||
| __all__ = [ | |||
| 'RandomSquareCrop', 'LoadAnnotationsV2', 'RotateV2', | |||
| 'DefaultFormatBundleV2' | |||
| ] | |||
| @@ -0,0 +1,271 @@ | |||
| """ | |||
| The implementation here is modified based on insightface, originally MIT-licensed and publicly available at | |||
| https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/auto_augment.py | |||
| """ | |||
| import copy | |||
| import cv2 | |||
| import mmcv | |||
| import numpy as np | |||
| from mmdet.datasets.builder import PIPELINES | |||
| _MAX_LEVEL = 10 | |||
| def level_to_value(level, max_value): | |||
| """Map from level to values based on max_value.""" | |||
| return (level / _MAX_LEVEL) * max_value | |||
| def random_negative(value, random_negative_prob): | |||
| """Randomly negate value based on random_negative_prob.""" | |||
| return -value if np.random.rand() < random_negative_prob else value | |||
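| # Quick sketch of the helpers: level 10 maps to max_value, level 5 to half | |||
| # of it, and random_negative flips the sign with the given probability. | |||
| def _demo_level_helpers(): | |||
|     assert level_to_value(_MAX_LEVEL, 30) == 30.0 | |||
|     assert level_to_value(5, 30) == 15.0 | |||
|     assert random_negative(4.0, random_negative_prob=0.0) == 4.0 | |||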
| def bbox2fields(): | |||
| """The key correspondence from bboxes to labels, masks and | |||
| segmentations.""" | |||
| bbox2label = { | |||
| 'gt_bboxes': 'gt_labels', | |||
| 'gt_bboxes_ignore': 'gt_labels_ignore' | |||
| } | |||
| bbox2mask = { | |||
| 'gt_bboxes': 'gt_masks', | |||
| 'gt_bboxes_ignore': 'gt_masks_ignore' | |||
| } | |||
| bbox2seg = { | |||
| 'gt_bboxes': 'gt_semantic_seg', | |||
| } | |||
| return bbox2label, bbox2mask, bbox2seg | |||
| @PIPELINES.register_module() | |||
| class RotateV2(object): | |||
| """Apply Rotate Transformation to image (and its corresponding bbox, mask, | |||
| segmentation). | |||
| Args: | |||
| level (int | float): The level should be in range (0,_MAX_LEVEL]. | |||
| scale (int | float): Isotropic scale factor. Same in | |||
| ``mmcv.imrotate``. | |||
| center (int | float | tuple[float]): Center point (w, h) of the | |||
| rotation in the source image. If None, the center of the | |||
| image will be used. Same in ``mmcv.imrotate``. | |||
| img_fill_val (int | float | tuple): The fill value for image border. | |||
| If float, the same value will be used for all the three | |||
| channels of image. If tuple, the should be 3 elements (e.g. | |||
| equals the number of channels for image). | |||
| seg_ignore_label (int): The fill value used for segmentation map. | |||
| Note this value must equals ``ignore_label`` in ``semantic_head`` | |||
| of the corresponding config. Default 255. | |||
| prob (float): The probability for perform transformation and | |||
| should be in range 0 to 1. | |||
| max_rotate_angle (int | float): The maximum angles for rotate | |||
| transformation. | |||
| random_negative_prob (float): The probability that turns the | |||
| offset negative. | |||
| """ | |||
| def __init__(self, | |||
| level, | |||
| scale=1, | |||
| center=None, | |||
| img_fill_val=128, | |||
| seg_ignore_label=255, | |||
| prob=0.5, | |||
| max_rotate_angle=30, | |||
| random_negative_prob=0.5): | |||
| assert isinstance(level, (int, float)), \ | |||
| f'The level must be type int or float. got {type(level)}.' | |||
| assert 0 <= level <= _MAX_LEVEL, \ | |||
| f'The level should be in range [0,{_MAX_LEVEL}]. got {level}.' | |||
| assert isinstance(scale, (int, float)), \ | |||
| f'The scale must be type int or float. got type {type(scale)}.' | |||
| if isinstance(center, (int, float)): | |||
| center = (center, center) | |||
| elif isinstance(center, tuple): | |||
| assert len(center) == 2, 'center with type tuple must have '\ | |||
| f'2 elements. got {len(center)} elements.' | |||
| else: | |||
| assert center is None, 'center must be None or type int, '\ | |||
| f'float or tuple, got type {type(center)}.' | |||
| if isinstance(img_fill_val, (float, int)): | |||
| img_fill_val = tuple([float(img_fill_val)] * 3) | |||
| elif isinstance(img_fill_val, tuple): | |||
| assert len(img_fill_val) == 3, 'img_fill_val as tuple must '\ | |||
| f'have 3 elements. got {len(img_fill_val)}.' | |||
| img_fill_val = tuple([float(val) for val in img_fill_val]) | |||
| else: | |||
| raise ValueError( | |||
| 'img_fill_val must be float or tuple with 3 elements.') | |||
| assert np.all([0 <= val <= 255 for val in img_fill_val]), \ | |||
| 'all elements of img_fill_val should between range [0,255]. '\ | |||
| f'got {img_fill_val}.' | |||
| assert 0 <= prob <= 1.0, 'The probability should be in range [0,1]. '\ | |||
| f'got {prob}.' | |||
| assert isinstance(max_rotate_angle, (int, float)), 'max_rotate_angle '\ | |||
| f'should be type int or float. got type {type(max_rotate_angle)}.' | |||
| self.level = level | |||
| self.scale = scale | |||
| # Rotation angle in degrees. Positive values mean | |||
| # clockwise rotation. | |||
| self.angle = level_to_value(level, max_rotate_angle) | |||
| self.center = center | |||
| self.img_fill_val = img_fill_val | |||
| self.seg_ignore_label = seg_ignore_label | |||
| self.prob = prob | |||
| self.max_rotate_angle = max_rotate_angle | |||
| self.random_negative_prob = random_negative_prob | |||
| def _rotate_img(self, results, angle, center=None, scale=1.0): | |||
| """Rotate the image. | |||
| Args: | |||
| results (dict): Result dict from loading pipeline. | |||
| angle (float): Rotation angle in degrees, positive values | |||
| mean clockwise rotation. Same in ``mmcv.imrotate``. | |||
| center (tuple[float], optional): Center point (w, h) of the | |||
| rotation. Same in ``mmcv.imrotate``. | |||
| scale (int | float): Isotropic scale factor. Same in | |||
| ``mmcv.imrotate``. | |||
| """ | |||
| for key in results.get('img_fields', ['img']): | |||
| img = results[key].copy() | |||
| img_rotated = mmcv.imrotate( | |||
| img, angle, center, scale, border_value=self.img_fill_val) | |||
| results[key] = img_rotated.astype(img.dtype) | |||
| results['img_shape'] = results[key].shape | |||
| def _rotate_bboxes(self, results, rotate_matrix): | |||
| """Rotate the bboxes.""" | |||
| h, w, c = results['img_shape'] | |||
| for key in results.get('bbox_fields', []): | |||
| min_x, min_y, max_x, max_y = np.split( | |||
| results[key], results[key].shape[-1], axis=-1) | |||
| coordinates = np.stack([[min_x, min_y], [max_x, min_y], | |||
| [min_x, max_y], | |||
| [max_x, max_y]]) # [4, 2, nb_bbox, 1] | |||
| # pad 1 to convert from format [x, y] to homogeneous | |||
| # coordinates format [x, y, 1] | |||
| coordinates = np.concatenate( | |||
| (coordinates, | |||
| np.ones((4, 1, coordinates.shape[2], 1), coordinates.dtype)), | |||
| axis=1) # [4, 3, nb_bbox, 1] | |||
| coordinates = coordinates.transpose( | |||
| (2, 0, 1, 3)) # [nb_bbox, 4, 3, 1] | |||
| rotated_coords = np.matmul(rotate_matrix, | |||
| coordinates) # [nb_bbox, 4, 2, 1] | |||
| rotated_coords = rotated_coords[..., 0] # [nb_bbox, 4, 2] | |||
| min_x, min_y = np.min( | |||
| rotated_coords[:, :, 0], axis=1), np.min( | |||
| rotated_coords[:, :, 1], axis=1) | |||
| max_x, max_y = np.max( | |||
| rotated_coords[:, :, 0], axis=1), np.max( | |||
| rotated_coords[:, :, 1], axis=1) | |||
| results[key] = np.stack([min_x, min_y, max_x, max_y], | |||
| axis=-1).astype(results[key].dtype) | |||
| def _rotate_keypoints90(self, results, angle): | |||
| """Rotate the keypoints, only valid when angle in [-90,90,-180,180]""" | |||
| if angle not in [-90, 90, 180, -180 | |||
| ] or self.scale != 1 or self.center is not None: | |||
| return | |||
| for key in results.get('keypoints_fields', []): | |||
| k = results[key] | |||
| if angle == 90: | |||
| w, h, c = results['img'].shape | |||
| new = np.stack([h - k[..., 1], k[..., 0], k[..., 2]], axis=-1) | |||
| elif angle == -90: | |||
| w, h, c = results['img'].shape | |||
| new = np.stack([k[..., 1], w - k[..., 0], k[..., 2]], axis=-1) | |||
| else: | |||
| h, w, c = results['img'].shape | |||
| new = np.stack([w - k[..., 0], h - k[..., 1], k[..., 2]], | |||
| axis=-1) | |||
| # a kps is invalid if the third value is -1 | |||
| kps_invalid = new[..., -1][:, -1] == -1 | |||
| new[kps_invalid] = np.zeros(new.shape[1:]) - 1 | |||
| results[key] = new | |||
| def _rotate_masks(self, | |||
| results, | |||
| angle, | |||
| center=None, | |||
| scale=1.0, | |||
| fill_val=0): | |||
| """Rotate the masks.""" | |||
| h, w, c = results['img_shape'] | |||
| for key in results.get('mask_fields', []): | |||
| masks = results[key] | |||
| results[key] = masks.rotate((h, w), angle, center, scale, fill_val) | |||
| def _rotate_seg(self, | |||
| results, | |||
| angle, | |||
| center=None, | |||
| scale=1.0, | |||
| fill_val=255): | |||
| """Rotate the segmentation map.""" | |||
| for key in results.get('seg_fields', []): | |||
| seg = results[key].copy() | |||
| results[key] = mmcv.imrotate( | |||
| seg, angle, center, scale, | |||
| border_value=fill_val).astype(seg.dtype) | |||
| def _filter_invalid(self, results, min_bbox_size=0): | |||
| """Filter bboxes and corresponding masks too small after rotate | |||
| augmentation.""" | |||
| bbox2label, bbox2mask, _ = bbox2fields() | |||
| for key in results.get('bbox_fields', []): | |||
| bbox_w = results[key][:, 2] - results[key][:, 0] | |||
| bbox_h = results[key][:, 3] - results[key][:, 1] | |||
| valid_inds = (bbox_w > min_bbox_size) & (bbox_h > min_bbox_size) | |||
| valid_inds = np.nonzero(valid_inds)[0] | |||
| results[key] = results[key][valid_inds] | |||
| # label fields. e.g. gt_labels and gt_labels_ignore | |||
| label_key = bbox2label.get(key) | |||
| if label_key in results: | |||
| results[label_key] = results[label_key][valid_inds] | |||
| # mask fields, e.g. gt_masks and gt_masks_ignore | |||
| mask_key = bbox2mask.get(key) | |||
| if mask_key in results: | |||
| results[mask_key] = results[mask_key][valid_inds] | |||
| def __call__(self, results): | |||
| """Call function to rotate images, bounding boxes, masks and semantic | |||
| segmentation maps. | |||
| Args: | |||
| results (dict): Result dict from loading pipeline. | |||
| Returns: | |||
| dict: Rotated results. | |||
| """ | |||
| if np.random.rand() > self.prob: | |||
| return results | |||
| h, w = results['img'].shape[:2] | |||
| center = self.center | |||
| if center is None: | |||
| center = ((w - 1) * 0.5, (h - 1) * 0.5) | |||
| angle = random_negative(self.angle, self.random_negative_prob) | |||
| self._rotate_img(results, angle, center, self.scale) | |||
| rotate_matrix = cv2.getRotationMatrix2D(center, -angle, self.scale) | |||
| self._rotate_bboxes(results, rotate_matrix) | |||
| self._rotate_keypoints90(results, angle) | |||
| self._rotate_masks(results, angle, center, self.scale, fill_val=0) | |||
| self._rotate_seg( | |||
| results, angle, center, self.scale, fill_val=self.seg_ignore_label) | |||
| self._filter_invalid(results) | |||
| return results | |||
| def __repr__(self): | |||
| repr_str = self.__class__.__name__ | |||
| repr_str += f'(level={self.level}, ' | |||
| repr_str += f'scale={self.scale}, ' | |||
| repr_str += f'center={self.center}, ' | |||
| repr_str += f'img_fill_val={self.img_fill_val}, ' | |||
| repr_str += f'seg_ignore_label={self.seg_ignore_label}, ' | |||
| repr_str += f'prob={self.prob}, ' | |||
| repr_str += f'max_rotate_angle={self.max_rotate_angle}, ' | |||
| repr_str += f'random_negative_prob={self.random_negative_prob})' | |||
| return repr_str | |||
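| # Usage sketch for RotateV2 on a minimal mmdet-style results dict; prob=1.0 | |||
| # forces the rotation (level 10 with max_rotate_angle=30 gives +/-30 degrees): | |||
| def _demo_rotate_v2(): | |||
|     rot = RotateV2(level=10, prob=1.0, max_rotate_angle=30) | |||
|     results = { | |||
|         'img': np.zeros((64, 64, 3), dtype=np.uint8), | |||
|         'img_shape': (64, 64, 3), | |||
|         'img_fields': ['img'], | |||
|         'bbox_fields': ['gt_bboxes'], | |||
|         'gt_bboxes': np.array([[10., 10., 30., 30.]], dtype=np.float32), | |||
|         'gt_labels': np.array([0]), | |||
|     } | |||
|     return rot(results) | |||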
| @@ -0,0 +1,113 @@ | |||
| """ | |||
| The implementation here is modified based on insightface, originally MIT-licensed and publicly available at | |||
| https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/formating.py | |||
| """ | |||
| from collections.abc import Sequence | |||
| import mmcv | |||
| import numpy as np | |||
| import torch | |||
| from mmcv.parallel import DataContainer as DC | |||
| from mmdet.datasets.builder import PIPELINES | |||
| def to_tensor(data): | |||
| """Convert objects of various python types to :obj:`torch.Tensor`. | |||
| Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, | |||
| :class:`Sequence`, :class:`int` and :class:`float`. | |||
| Args: | |||
| data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to | |||
| be converted. | |||
| """ | |||
| if isinstance(data, torch.Tensor): | |||
| return data | |||
| elif isinstance(data, np.ndarray): | |||
| return torch.from_numpy(data) | |||
| elif isinstance(data, Sequence) and not mmcv.is_str(data): | |||
| return torch.tensor(data) | |||
| elif isinstance(data, int): | |||
| return torch.LongTensor([data]) | |||
| elif isinstance(data, float): | |||
| return torch.FloatTensor([data]) | |||
| else: | |||
| raise TypeError(f'type {type(data)} cannot be converted to tensor.') | |||
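| # Dispatch sketch for to_tensor: arrays keep their shape, Python ints become | |||
| # LongTensors and floats become FloatTensors of length one. | |||
| def _demo_to_tensor(): | |||
|     assert to_tensor(np.ones((2, 2))).shape == (2, 2) | |||
|     assert to_tensor(3).dtype == torch.int64 | |||
|     assert to_tensor(0.5).dtype == torch.float32 | |||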
| @PIPELINES.register_module() | |||
| class DefaultFormatBundleV2(object): | |||
| """Default formatting bundle. | |||
| It simplifies the pipeline of formatting common fields, including "img", | |||
| "proposals", "gt_bboxes", "gt_labels", "gt_masks" and "gt_semantic_seg". | |||
| These fields are formatted as follows. | |||
| - img: (1)transpose, (2)to tensor, (3)to DataContainer (stack=True) | |||
| - proposals: (1)to tensor, (2)to DataContainer | |||
| - gt_bboxes: (1)to tensor, (2)to DataContainer | |||
| - gt_bboxes_ignore: (1)to tensor, (2)to DataContainer | |||
| - gt_labels: (1)to tensor, (2)to DataContainer | |||
| - gt_masks: (1)to tensor, (2)to DataContainer (cpu_only=True) | |||
| - gt_semantic_seg: (1)unsqueeze dim-0 (2)to tensor, \ | |||
| (3)to DataContainer (stack=True) | |||
| """ | |||
| def __call__(self, results): | |||
| """Call function to transform and format common fields in results. | |||
| Args: | |||
| results (dict): Result dict contains the data to convert. | |||
| Returns: | |||
| dict: The result dict contains the data that is formatted with \ | |||
| default bundle. | |||
| """ | |||
| if 'img' in results: | |||
| img = results['img'] | |||
| # add default meta keys | |||
| results = self._add_default_meta_keys(results) | |||
| if len(img.shape) < 3: | |||
| img = np.expand_dims(img, -1) | |||
| img = np.ascontiguousarray(img.transpose(2, 0, 1)) | |||
| results['img'] = DC(to_tensor(img), stack=True) | |||
| for key in [ | |||
| 'proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_keypointss', | |||
| 'gt_labels' | |||
| ]: | |||
| if key not in results: | |||
| continue | |||
| results[key] = DC(to_tensor(results[key])) | |||
| if 'gt_masks' in results: | |||
| results['gt_masks'] = DC(results['gt_masks'], cpu_only=True) | |||
| if 'gt_semantic_seg' in results: | |||
| results['gt_semantic_seg'] = DC( | |||
| to_tensor(results['gt_semantic_seg'][None, ...]), stack=True) | |||
| return results | |||
| def _add_default_meta_keys(self, results): | |||
| """Add default meta keys. | |||
| We set default meta keys including `pad_shape`, `scale_factor` and | |||
`img_norm_cfg` to cover the case where no `Resize`, `Normalize` or
`Pad` is applied in the pipeline.
| Args: | |||
| results (dict): Result dict contains the data to convert. | |||
| Returns: | |||
| results (dict): Updated result dict contains the data to convert. | |||
| """ | |||
| img = results['img'] | |||
| results.setdefault('pad_shape', img.shape) | |||
| results.setdefault('scale_factor', 1.0) | |||
| num_channels = 1 if len(img.shape) < 3 else img.shape[2] | |||
| results.setdefault( | |||
| 'img_norm_cfg', | |||
| dict( | |||
| mean=np.zeros(num_channels, dtype=np.float32), | |||
| std=np.ones(num_channels, dtype=np.float32), | |||
| to_rgb=False)) | |||
| return results | |||
| def __repr__(self): | |||
| return self.__class__.__name__ | |||
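# Hedged usage sketch (not part of the patch): the bundle usually closes a
# training pipeline config; the surrounding steps shown are illustrative.
#   train_pipeline = [
#       dict(type='LoadAnnotationsV2', with_bbox=True, with_keypoints=True),
#       dict(type='RandomFlipV2', flip_ratio=0.5),
#       dict(type='DefaultFormatBundleV2'),
#   ]
# Afterwards results['img'] is a DataContainer wrapping a CHW tensor, and
# 'gt_bboxes' / 'gt_labels' / 'gt_keypointss' wrap tensors as well.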
| @@ -0,0 +1,225 @@ | |||
| """ | |||
The implementation here is modified from insightface, originally released under the MIT license and publicly available at
| https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/loading.py | |||
| """ | |||
| import os.path as osp | |||
| import numpy as np | |||
| import pycocotools.mask as maskUtils | |||
| from mmdet.core import BitmapMasks, PolygonMasks | |||
| from mmdet.datasets.builder import PIPELINES | |||
| @PIPELINES.register_module() | |||
| class LoadAnnotationsV2(object): | |||
| """Load mutiple types of annotations. | |||
| Args: | |||
| with_bbox (bool): Whether to parse and load the bbox annotation. | |||
| Default: True. | |||
| with_label (bool): Whether to parse and load the label annotation. | |||
| Default: True. | |||
| with_keypoints (bool): Whether to parse and load the keypoints annotation. | |||
| Default: False. | |||
| with_mask (bool): Whether to parse and load the mask annotation. | |||
| Default: False. | |||
| with_seg (bool): Whether to parse and load the semantic segmentation | |||
| annotation. Default: False. | |||
| poly2mask (bool): Whether to convert the instance masks from polygons | |||
| to bitmaps. Default: True. | |||
| file_client_args (dict): Arguments to instantiate a FileClient. | |||
| See :class:`mmcv.fileio.FileClient` for details. | |||
| Defaults to ``dict(backend='disk')``. | |||
| """ | |||
| def __init__(self, | |||
| with_bbox=True, | |||
| with_label=True, | |||
| with_keypoints=False, | |||
| with_mask=False, | |||
| with_seg=False, | |||
| poly2mask=True, | |||
| file_client_args=dict(backend='disk')): | |||
| self.with_bbox = with_bbox | |||
| self.with_label = with_label | |||
| self.with_keypoints = with_keypoints | |||
| self.with_mask = with_mask | |||
| self.with_seg = with_seg | |||
| self.poly2mask = poly2mask | |||
| self.file_client_args = file_client_args.copy() | |||
| self.file_client = None | |||
| def _load_bboxes(self, results): | |||
| """Private function to load bounding box annotations. | |||
| Args: | |||
| results (dict): Result dict from :obj:`mmdet.CustomDataset`. | |||
| Returns: | |||
| dict: The dict contains loaded bounding box annotations. | |||
| """ | |||
| ann_info = results['ann_info'] | |||
| results['gt_bboxes'] = ann_info['bboxes'].copy() | |||
| gt_bboxes_ignore = ann_info.get('bboxes_ignore', None) | |||
| if gt_bboxes_ignore is not None: | |||
| results['gt_bboxes_ignore'] = gt_bboxes_ignore.copy() | |||
| results['bbox_fields'].append('gt_bboxes_ignore') | |||
| results['bbox_fields'].append('gt_bboxes') | |||
| return results | |||
| def _load_keypoints(self, results): | |||
| """Private function to load bounding box annotations. | |||
| Args: | |||
| results (dict): Result dict from :obj:`mmdet.CustomDataset`. | |||
| Returns: | |||
| dict: The dict contains loaded bounding box annotations. | |||
| """ | |||
| ann_info = results['ann_info'] | |||
| results['gt_keypointss'] = ann_info['keypointss'].copy() | |||
| results['keypoints_fields'] = ['gt_keypointss'] | |||
| return results | |||
| def _load_labels(self, results): | |||
| """Private function to load label annotations. | |||
| Args: | |||
| results (dict): Result dict from :obj:`mmdet.CustomDataset`. | |||
| Returns: | |||
| dict: The dict contains loaded label annotations. | |||
| """ | |||
| results['gt_labels'] = results['ann_info']['labels'].copy() | |||
| return results | |||
| def _poly2mask(self, mask_ann, img_h, img_w): | |||
| """Private function to convert masks represented with polygon to | |||
| bitmaps. | |||
| Args: | |||
| mask_ann (list | dict): Polygon mask annotation input. | |||
| img_h (int): The height of output mask. | |||
| img_w (int): The width of output mask. | |||
| Returns: | |||
numpy.ndarray: The decoded bitmap mask of shape (img_h, img_w).
| """ | |||
| if isinstance(mask_ann, list): | |||
| # polygon -- a single object might consist of multiple parts | |||
| # we merge all parts into one mask rle code | |||
| rles = maskUtils.frPyObjects(mask_ann, img_h, img_w) | |||
| rle = maskUtils.merge(rles) | |||
| elif isinstance(mask_ann['counts'], list): | |||
| # uncompressed RLE | |||
| rle = maskUtils.frPyObjects(mask_ann, img_h, img_w) | |||
| else: | |||
| # rle | |||
| rle = mask_ann | |||
| mask = maskUtils.decode(rle) | |||
| return mask | |||
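# Hedged worked example (not part of the patch): a COCO-style polygon for a
# 4x4 square inside an 8x8 image decodes to a binary uint8 mask.
#   poly = [[0.0, 0.0, 4.0, 0.0, 4.0, 4.0, 0.0, 4.0]]  # one part, xy-interleaved
#   m = LoadAnnotationsV2()._poly2mask(poly, img_h=8, img_w=8)
#   m.shape == (8, 8) and m.dtype == np.uint8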
| def process_polygons(self, polygons): | |||
| """Convert polygons to list of ndarray and filter invalid polygons. | |||
| Args: | |||
| polygons (list[list]): Polygons of one instance. | |||
| Returns: | |||
| list[numpy.ndarray]: Processed polygons. | |||
| """ | |||
| polygons = [np.array(p) for p in polygons] | |||
| valid_polygons = [] | |||
| for polygon in polygons: | |||
| if len(polygon) % 2 == 0 and len(polygon) >= 6: | |||
| valid_polygons.append(polygon) | |||
| return valid_polygons | |||
| def _load_masks(self, results): | |||
| """Private function to load mask annotations. | |||
| Args: | |||
| results (dict): Result dict from :obj:`mmdet.CustomDataset`. | |||
| Returns: | |||
| dict: The dict contains loaded mask annotations. | |||
If ``self.poly2mask`` is set ``True``, `gt_mask` will contain
:obj:`BitmapMasks`. Otherwise, :obj:`PolygonMasks` is used.
| """ | |||
| h, w = results['img_info']['height'], results['img_info']['width'] | |||
| gt_masks = results['ann_info']['masks'] | |||
| if self.poly2mask: | |||
| gt_masks = BitmapMasks( | |||
| [self._poly2mask(mask, h, w) for mask in gt_masks], h, w) | |||
| else: | |||
| gt_masks = PolygonMasks( | |||
| [self.process_polygons(polygons) for polygons in gt_masks], h, | |||
| w) | |||
| results['gt_masks'] = gt_masks | |||
| results['mask_fields'].append('gt_masks') | |||
| return results | |||
| def _load_semantic_seg(self, results): | |||
| """Private function to load semantic segmentation annotations. | |||
| Args: | |||
| results (dict): Result dict from :obj:`dataset`. | |||
| Returns: | |||
| dict: The dict contains loaded semantic segmentation annotations. | |||
| """ | |||
| import mmcv | |||
| if self.file_client is None: | |||
| self.file_client = mmcv.FileClient(**self.file_client_args) | |||
| filename = osp.join(results['seg_prefix'], | |||
| results['ann_info']['seg_map']) | |||
| img_bytes = self.file_client.get(filename) | |||
| results['gt_semantic_seg'] = mmcv.imfrombytes( | |||
| img_bytes, flag='unchanged').squeeze() | |||
| results['seg_fields'].append('gt_semantic_seg') | |||
| return results | |||
| def __call__(self, results): | |||
| """Call function to load multiple types annotations. | |||
| Args: | |||
| results (dict): Result dict from :obj:`mmdet.CustomDataset`. | |||
| Returns: | |||
| dict: The dict contains loaded bounding box, label, mask and | |||
| semantic segmentation annotations. | |||
| """ | |||
| if self.with_bbox: | |||
| results = self._load_bboxes(results) | |||
| if results is None: | |||
| return None | |||
| if self.with_label: | |||
| results = self._load_labels(results) | |||
| if self.with_keypoints: | |||
| results = self._load_keypoints(results) | |||
| if self.with_mask: | |||
| results = self._load_masks(results) | |||
| if self.with_seg: | |||
| results = self._load_semantic_seg(results) | |||
| return results | |||
| def __repr__(self): | |||
| repr_str = self.__class__.__name__ | |||
| repr_str += f'(with_bbox={self.with_bbox}, ' | |||
| repr_str += f'with_label={self.with_label}, ' | |||
| repr_str += f'with_keypoints={self.with_keypoints}, ' | |||
| repr_str += f'with_mask={self.with_mask}, ' | |||
repr_str += f'with_seg={self.with_seg}, '
repr_str += f'poly2mask={self.poly2mask}, '
repr_str += f'file_client_args={self.file_client_args})'
| return repr_str | |||
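# Hedged usage sketch (not part of the patch): enabling keypoint loading; the
# ann_info payload shown is schematic.
#   loader = LoadAnnotationsV2(with_bbox=True, with_label=True,
#                              with_keypoints=True)
#   results = loader(dict(ann_info=dict(bboxes=..., labels=..., keypointss=...),
#                         bbox_fields=[]))
# The call adds 'gt_bboxes', 'gt_labels' and 'gt_keypointss', plus the matching
# '*_fields' bookkeeping entries consumed by later transforms.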
| @@ -0,0 +1,737 @@ | |||
| """ | |||
The implementation here is modified from insightface, originally released under the MIT license and publicly available at
| https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/transforms.py | |||
| """ | |||
| import mmcv | |||
| import numpy as np | |||
| from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps | |||
| from mmdet.datasets.builder import PIPELINES | |||
| from numpy import random | |||
| @PIPELINES.register_module() | |||
| class ResizeV2(object): | |||
| """Resize images & bbox & mask &kps. | |||
| This transform resizes the input image to some scale. Bboxes and masks are | |||
| then resized with the same scale factor. If the input dict contains the key | |||
| "scale", then the scale in the input dict is used, otherwise the specified | |||
| scale in the init method is used. If the input dict contains the key | |||
| "scale_factor" (if MultiScaleFlipAug does not give img_scale but | |||
| scale_factor), the actual scale will be computed by image shape and | |||
| scale_factor. | |||
| `img_scale` can either be a tuple (single-scale) or a list of tuple | |||
| (multi-scale). There are 3 multiscale modes: | |||
| - ``ratio_range is not None``: randomly sample a ratio from the ratio \ | |||
| range and multiply it with the image scale. | |||
| - ``ratio_range is None`` and ``multiscale_mode == "range"``: randomly \ | |||
| sample a scale from the multiscale range. | |||
| - ``ratio_range is None`` and ``multiscale_mode == "value"``: randomly \ | |||
| sample a scale from multiple scales. | |||
| Args: | |||
img_scale (tuple or list[tuple]): Image scales for resizing.
| multiscale_mode (str): Either "range" or "value". | |||
| ratio_range (tuple[float]): (min_ratio, max_ratio) | |||
| keep_ratio (bool): Whether to keep the aspect ratio when resizing the | |||
| image. | |||
bbox_clip_border (bool, optional): Whether to clip objects outside
    the border of the image. Defaults to True.
| backend (str): Image resize backend, choices are 'cv2' and 'pillow'. | |||
These two backends generate slightly different results. Defaults
| to 'cv2'. | |||
| override (bool, optional): Whether to override `scale` and | |||
| `scale_factor` so as to call resize twice. Default False. If True, | |||
after the first resizing, the existing `scale` and `scale_factor`
| will be ignored so the second resizing can be allowed. | |||
| This option is a work-around for multiple times of resize in DETR. | |||
| Defaults to False. | |||
| """ | |||
| def __init__(self, | |||
| img_scale=None, | |||
| multiscale_mode='range', | |||
| ratio_range=None, | |||
| keep_ratio=True, | |||
| bbox_clip_border=True, | |||
| backend='cv2', | |||
| override=False): | |||
| if img_scale is None: | |||
| self.img_scale = None | |||
| else: | |||
| if isinstance(img_scale, list): | |||
| self.img_scale = img_scale | |||
| else: | |||
| self.img_scale = [img_scale] | |||
| assert mmcv.is_list_of(self.img_scale, tuple) | |||
| if ratio_range is not None: | |||
| # mode 1: given a scale and a range of image ratio | |||
| assert len(self.img_scale) == 1 | |||
| else: | |||
| # mode 2: given multiple scales or a range of scales | |||
| assert multiscale_mode in ['value', 'range'] | |||
| self.backend = backend | |||
| self.multiscale_mode = multiscale_mode | |||
| self.ratio_range = ratio_range | |||
| self.keep_ratio = keep_ratio | |||
| # TODO: refactor the override option in Resize | |||
| self.override = override | |||
| self.bbox_clip_border = bbox_clip_border | |||
| @staticmethod | |||
| def random_select(img_scales): | |||
| """Randomly select an img_scale from given candidates. | |||
| Args: | |||
img_scales (list[tuple]): Image scales for selection.
| Returns: | |||
(tuple, int): Returns a tuple ``(img_scale, scale_idx)``, \
| where ``img_scale`` is the selected image scale and \ | |||
| ``scale_idx`` is the selected index in the given candidates. | |||
| """ | |||
| assert mmcv.is_list_of(img_scales, tuple) | |||
| scale_idx = np.random.randint(len(img_scales)) | |||
| img_scale = img_scales[scale_idx] | |||
| return img_scale, scale_idx | |||
| @staticmethod | |||
| def random_sample(img_scales): | |||
| """Randomly sample an img_scale when ``multiscale_mode=='range'``. | |||
| Args: | |||
img_scales (list[tuple]): Image scale range for sampling.
    There must be two tuples in img_scales, which specify the lower
    and upper bound of image scales.
| Returns: | |||
| (tuple, None): Returns a tuple ``(img_scale, None)``, where \ | |||
| ``img_scale`` is sampled scale and None is just a placeholder \ | |||
| to be consistent with :func:`random_select`. | |||
| """ | |||
| assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2 | |||
| img_scale_long = [max(s) for s in img_scales] | |||
| img_scale_short = [min(s) for s in img_scales] | |||
| long_edge = np.random.randint( | |||
| min(img_scale_long), | |||
| max(img_scale_long) + 1) | |||
| short_edge = np.random.randint( | |||
| min(img_scale_short), | |||
| max(img_scale_short) + 1) | |||
| img_scale = (long_edge, short_edge) | |||
| return img_scale, None | |||
| @staticmethod | |||
| def random_sample_ratio(img_scale, ratio_range): | |||
| """Randomly sample an img_scale when ``ratio_range`` is specified. | |||
| A ratio will be randomly sampled from the range specified by | |||
| ``ratio_range``. Then it would be multiplied with ``img_scale`` to | |||
| generate sampled scale. | |||
| Args: | |||
img_scale (tuple): Image scale base to multiply with ratio.
| ratio_range (tuple[float]): The minimum and maximum ratio to scale | |||
| the ``img_scale``. | |||
| Returns: | |||
| (tuple, None): Returns a tuple ``(scale, None)``, where \ | |||
| ``scale`` is sampled ratio multiplied with ``img_scale`` and \ | |||
| None is just a placeholder to be consistent with \ | |||
| :func:`random_select`. | |||
| """ | |||
| assert isinstance(img_scale, tuple) and len(img_scale) == 2 | |||
| min_ratio, max_ratio = ratio_range | |||
| assert min_ratio <= max_ratio | |||
| ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio | |||
| scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio) | |||
| return scale, None | |||
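# Hedged worked example (not part of the patch): with img_scale=(640, 640) and
# ratio_range=(0.5, 1.5), a sampled ratio of 0.75 gives
# scale = (int(640 * 0.75), int(640 * 0.75)) = (480, 480); the trailing None
# keeps the return shape consistent with random_select().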
| def _random_scale(self, results): | |||
| """Randomly sample an img_scale according to ``ratio_range`` and | |||
| ``multiscale_mode``. | |||
| If ``ratio_range`` is specified, a ratio will be sampled and be | |||
| multiplied with ``img_scale``. | |||
| If multiple scales are specified by ``img_scale``, a scale will be | |||
| sampled according to ``multiscale_mode``. | |||
Otherwise, a single scale will be used.
| Args: | |||
| results (dict): Result dict from :obj:`dataset`. | |||
| Returns: | |||
dict: Two new keys ``scale`` and ``scale_idx`` are added into \
| ``results``, which would be used by subsequent pipelines. | |||
| """ | |||
| if self.ratio_range is not None: | |||
| scale, scale_idx = self.random_sample_ratio( | |||
| self.img_scale[0], self.ratio_range) | |||
| elif len(self.img_scale) == 1: | |||
| scale, scale_idx = self.img_scale[0], 0 | |||
| elif self.multiscale_mode == 'range': | |||
| scale, scale_idx = self.random_sample(self.img_scale) | |||
| elif self.multiscale_mode == 'value': | |||
| scale, scale_idx = self.random_select(self.img_scale) | |||
| else: | |||
| raise NotImplementedError | |||
| results['scale'] = scale | |||
| results['scale_idx'] = scale_idx | |||
| def _resize_img(self, results): | |||
| """Resize images with ``results['scale']``.""" | |||
| for key in results.get('img_fields', ['img']): | |||
| if self.keep_ratio: | |||
| img, scale_factor = mmcv.imrescale( | |||
| results[key], | |||
| results['scale'], | |||
| return_scale=True, | |||
| backend=self.backend) | |||
# the w_scale and h_scale have a minor difference
| # a real fix should be done in the mmcv.imrescale in the future | |||
| new_h, new_w = img.shape[:2] | |||
| h, w = results[key].shape[:2] | |||
| w_scale = new_w / w | |||
| h_scale = new_h / h | |||
| else: | |||
| img, w_scale, h_scale = mmcv.imresize( | |||
| results[key], | |||
| results['scale'], | |||
| return_scale=True, | |||
| backend=self.backend) | |||
| results[key] = img | |||
| scale_factor = np.array([w_scale, h_scale, w_scale, h_scale], | |||
| dtype=np.float32) | |||
| results['img_shape'] = img.shape | |||
| # in case that there is no padding | |||
| results['pad_shape'] = img.shape | |||
| results['scale_factor'] = scale_factor | |||
| results['keep_ratio'] = self.keep_ratio | |||
| def _resize_bboxes(self, results): | |||
| """Resize bounding boxes with ``results['scale_factor']``.""" | |||
| for key in results.get('bbox_fields', []): | |||
| bboxes = results[key] * results['scale_factor'] | |||
| if self.bbox_clip_border: | |||
| img_shape = results['img_shape'] | |||
| bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1]) | |||
| bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0]) | |||
| results[key] = bboxes | |||
| def _resize_keypoints(self, results): | |||
| """Resize keypoints with ``results['scale_factor']``.""" | |||
| for key in results.get('keypoints_fields', []): | |||
| keypointss = results[key].copy() | |||
| factors = results['scale_factor'] | |||
| assert factors[0] == factors[2] | |||
| assert factors[1] == factors[3] | |||
| keypointss[:, :, 0] *= factors[0] | |||
| keypointss[:, :, 1] *= factors[1] | |||
| if self.bbox_clip_border: | |||
| img_shape = results['img_shape'] | |||
| keypointss[:, :, 0] = np.clip(keypointss[:, :, 0], 0, | |||
| img_shape[1]) | |||
| keypointss[:, :, 1] = np.clip(keypointss[:, :, 1], 0, | |||
| img_shape[0]) | |||
| results[key] = keypointss | |||
| def _resize_masks(self, results): | |||
| """Resize masks with ``results['scale']``""" | |||
| for key in results.get('mask_fields', []): | |||
| if results[key] is None: | |||
| continue | |||
| if self.keep_ratio: | |||
| results[key] = results[key].rescale(results['scale']) | |||
| else: | |||
| results[key] = results[key].resize(results['img_shape'][:2]) | |||
| def _resize_seg(self, results): | |||
| """Resize semantic segmentation map with ``results['scale']``.""" | |||
| for key in results.get('seg_fields', []): | |||
| if self.keep_ratio: | |||
| gt_seg = mmcv.imrescale( | |||
| results[key], | |||
| results['scale'], | |||
| interpolation='nearest', | |||
| backend=self.backend) | |||
| else: | |||
| gt_seg = mmcv.imresize( | |||
| results[key], | |||
| results['scale'], | |||
| interpolation='nearest', | |||
| backend=self.backend) | |||
| results['gt_semantic_seg'] = gt_seg | |||
| def __call__(self, results): | |||
| """Call function to resize images, bounding boxes, masks, semantic | |||
| segmentation map. | |||
| Args: | |||
| results (dict): Result dict from loading pipeline. | |||
| Returns: | |||
| dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor', \ | |||
| 'keep_ratio' keys are added into result dict. | |||
| """ | |||
| if 'scale' not in results: | |||
| if 'scale_factor' in results: | |||
| img_shape = results['img'].shape[:2] | |||
| scale_factor = results['scale_factor'] | |||
| assert isinstance(scale_factor, float) | |||
| results['scale'] = tuple( | |||
| [int(x * scale_factor) for x in img_shape][::-1]) | |||
| else: | |||
| self._random_scale(results) | |||
| else: | |||
| if not self.override: | |||
| assert 'scale_factor' not in results, ( | |||
| 'scale and scale_factor cannot be both set.') | |||
| else: | |||
| results.pop('scale') | |||
| if 'scale_factor' in results: | |||
| results.pop('scale_factor') | |||
| self._random_scale(results) | |||
| self._resize_img(results) | |||
| self._resize_bboxes(results) | |||
| self._resize_keypoints(results) | |||
| self._resize_masks(results) | |||
| self._resize_seg(results) | |||
| return results | |||
| def __repr__(self): | |||
| repr_str = self.__class__.__name__ | |||
| repr_str += f'(img_scale={self.img_scale}, ' | |||
| repr_str += f'multiscale_mode={self.multiscale_mode}, ' | |||
| repr_str += f'ratio_range={self.ratio_range}, ' | |||
repr_str += f'keep_ratio={self.keep_ratio}, '
repr_str += f'bbox_clip_border={self.bbox_clip_border})'
| return repr_str | |||
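# Hedged usage sketch (not part of the patch): discrete multi-scale resizing;
# the scales shown are illustrative.
#   resize = ResizeV2(img_scale=[(640, 640), (1100, 1100)],
#                     multiscale_mode='value', keep_ratio=True)
#   results = resize(dict(img=img, img_fields=['img'], bbox_fields=[]))
# The call adds 'scale', 'scale_idx', 'img_shape', 'pad_shape', 'scale_factor'
# and 'keep_ratio' to the results dict.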
| @PIPELINES.register_module() | |||
| class RandomFlipV2(object): | |||
| """Flip the image & bbox & mask & kps. | |||
| If the input dict contains the key "flip", then the flag will be used, | |||
| otherwise it will be randomly decided by a ratio specified in the init | |||
| method. | |||
| When random flip is enabled, ``flip_ratio``/``direction`` can either be a | |||
| float/string or tuple of float/string. There are 3 flip modes: | |||
| - ``flip_ratio`` is float, ``direction`` is string: the image will be | |||
| ``direction``ly flipped with probability of ``flip_ratio`` . | |||
| E.g., ``flip_ratio=0.5``, ``direction='horizontal'``, | |||
| then image will be horizontally flipped with probability of 0.5. | |||
- ``flip_ratio`` is float, ``direction`` is list of string: the image will
| be ``direction[i]``ly flipped with probability of | |||
| ``flip_ratio/len(direction)``. | |||
| E.g., ``flip_ratio=0.5``, ``direction=['horizontal', 'vertical']``, | |||
| then image will be horizontally flipped with probability of 0.25, | |||
| vertically with probability of 0.25. | |||
| - ``flip_ratio`` is list of float, ``direction`` is list of string: | |||
given ``len(flip_ratio) == len(direction)``, the image will
be ``direction[i]``ly flipped with probability of ``flip_ratio[i]``.
E.g., ``flip_ratio=[0.3, 0.5]``, ``direction=['horizontal',
'vertical']``, then image will be horizontally flipped with probability
of 0.3, vertically with probability of 0.5.
| Args: | |||
| flip_ratio (float | list[float], optional): The flipping probability. | |||
| Default: None. | |||
| direction(str | list[str], optional): The flipping direction. Options | |||
| are 'horizontal', 'vertical', 'diagonal'. Default: 'horizontal'. | |||
| If input is a list, the length must equal ``flip_ratio``. Each | |||
| element in ``flip_ratio`` indicates the flip probability of | |||
| corresponding direction. | |||
| """ | |||
| def __init__(self, flip_ratio=None, direction='horizontal'): | |||
| if isinstance(flip_ratio, list): | |||
| assert mmcv.is_list_of(flip_ratio, float) | |||
| assert 0 <= sum(flip_ratio) <= 1 | |||
| elif isinstance(flip_ratio, float): | |||
| assert 0 <= flip_ratio <= 1 | |||
| elif flip_ratio is None: | |||
| pass | |||
| else: | |||
| raise ValueError('flip_ratios must be None, float, ' | |||
| 'or list of float') | |||
| self.flip_ratio = flip_ratio | |||
| valid_directions = ['horizontal', 'vertical', 'diagonal'] | |||
| if isinstance(direction, str): | |||
| assert direction in valid_directions | |||
| elif isinstance(direction, list): | |||
| assert mmcv.is_list_of(direction, str) | |||
| assert set(direction).issubset(set(valid_directions)) | |||
| else: | |||
| raise ValueError('direction must be either str or list of str') | |||
| self.direction = direction | |||
| if isinstance(flip_ratio, list): | |||
| assert len(self.flip_ratio) == len(self.direction) | |||
| self.count = 0 | |||
| def bbox_flip(self, bboxes, img_shape, direction): | |||
| """Flip bboxes horizontally. | |||
| Args: | |||
| bboxes (numpy.ndarray): Bounding boxes, shape (..., 4*k) | |||
| img_shape (tuple[int]): Image shape (height, width) | |||
| direction (str): Flip direction. Options are 'horizontal', | |||
| 'vertical'. | |||
| Returns: | |||
| numpy.ndarray: Flipped bounding boxes. | |||
| """ | |||
| assert bboxes.shape[-1] % 4 == 0 | |||
| flipped = bboxes.copy() | |||
| if direction == 'horizontal': | |||
| w = img_shape[1] | |||
| flipped[..., 0::4] = w - bboxes[..., 2::4] | |||
| flipped[..., 2::4] = w - bboxes[..., 0::4] | |||
| elif direction == 'vertical': | |||
| h = img_shape[0] | |||
| flipped[..., 1::4] = h - bboxes[..., 3::4] | |||
| flipped[..., 3::4] = h - bboxes[..., 1::4] | |||
| elif direction == 'diagonal': | |||
| w = img_shape[1] | |||
| h = img_shape[0] | |||
| flipped[..., 0::4] = w - bboxes[..., 2::4] | |||
| flipped[..., 1::4] = h - bboxes[..., 3::4] | |||
| flipped[..., 2::4] = w - bboxes[..., 0::4] | |||
| flipped[..., 3::4] = h - bboxes[..., 1::4] | |||
| else: | |||
| raise ValueError(f"Invalid flipping direction '{direction}'") | |||
| return flipped | |||
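# Hedged worked example (not part of the patch): horizontally flipping the box
# (x1, y1, x2, y2) = (10, 20, 50, 60) in a width-100 image gives
# (w - x2, y1, w - x1, y2) = (50, 20, 90, 60); y-coordinates are unchanged.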
| def keypoints_flip(self, keypointss, img_shape, direction): | |||
| """Flip keypoints horizontally.""" | |||
| assert direction == 'horizontal' | |||
| assert keypointss.shape[-1] == 3 | |||
| num_kps = keypointss.shape[1] | |||
assert num_kps in [4, 5], f'Only support num_kps=4 or 5, got: {num_kps}'
| assert keypointss.ndim == 3 | |||
| flipped = keypointss.copy() | |||
| if num_kps == 5: | |||
| flip_order = [1, 0, 2, 4, 3] | |||
| elif num_kps == 4: | |||
| flip_order = [3, 2, 1, 0] | |||
| for idx, a in enumerate(flip_order): | |||
| flipped[:, idx, :] = keypointss[:, a, :] | |||
| w = img_shape[1] | |||
| flipped[..., 0] = w - flipped[..., 0] | |||
| return flipped | |||
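# Hedged note (not part of the patch): assuming the usual insightface 5-point
# order (left eye, right eye, nose, left mouth corner, right mouth corner),
# flip_order = [1, 0, 2, 4, 3] swaps the left/right pairs before the
# x-coordinates are mirrored, so the semantic labels survive the flip.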
| def __call__(self, results): | |||
| """Call function to flip bounding boxes, masks, semantic segmentation | |||
| maps. | |||
| Args: | |||
| results (dict): Result dict from loading pipeline. | |||
| Returns: | |||
| dict: Flipped results, 'flip', 'flip_direction' keys are added \ | |||
| into result dict. | |||
| """ | |||
| if 'flip' not in results: | |||
| if isinstance(self.direction, list): | |||
| # None means non-flip | |||
| direction_list = self.direction + [None] | |||
| else: | |||
| # None means non-flip | |||
| direction_list = [self.direction, None] | |||
| if isinstance(self.flip_ratio, list): | |||
| non_flip_ratio = 1 - sum(self.flip_ratio) | |||
| flip_ratio_list = self.flip_ratio + [non_flip_ratio] | |||
| else: | |||
| non_flip_ratio = 1 - self.flip_ratio | |||
| # exclude non-flip | |||
| single_ratio = self.flip_ratio / (len(direction_list) - 1) | |||
| flip_ratio_list = [single_ratio] * (len(direction_list) | |||
| - 1) + [non_flip_ratio] | |||
| cur_dir = np.random.choice(direction_list, p=flip_ratio_list) | |||
| results['flip'] = cur_dir is not None | |||
| if 'flip_direction' not in results: | |||
| results['flip_direction'] = cur_dir | |||
| if results['flip']: | |||
| # flip image | |||
| for key in results.get('img_fields', ['img']): | |||
| results[key] = mmcv.imflip( | |||
| results[key], direction=results['flip_direction']) | |||
| # flip bboxes | |||
| for key in results.get('bbox_fields', []): | |||
| results[key] = self.bbox_flip(results[key], | |||
| results['img_shape'], | |||
| results['flip_direction']) | |||
| # flip kps | |||
| for key in results.get('keypoints_fields', []): | |||
| results[key] = self.keypoints_flip(results[key], | |||
| results['img_shape'], | |||
| results['flip_direction']) | |||
| # flip masks | |||
| for key in results.get('mask_fields', []): | |||
| results[key] = results[key].flip(results['flip_direction']) | |||
| # flip segs | |||
| for key in results.get('seg_fields', []): | |||
| results[key] = mmcv.imflip( | |||
| results[key], direction=results['flip_direction']) | |||
| return results | |||
| def __repr__(self): | |||
| return self.__class__.__name__ + f'(flip_ratio={self.flip_ratio})' | |||
| @PIPELINES.register_module() | |||
| class RandomSquareCrop(object): | |||
| """Random crop the image & bboxes, the cropped patches have minimum IoU | |||
| requirement with original image & bboxes, the IoU threshold is randomly | |||
| selected from min_ious. | |||
| Args: | |||
| min_ious (tuple): minimum IoU threshold for all intersections with | |||
| bounding boxes | |||
| min_crop_size (float): minimum crop's size (i.e. h,w := a*h, a*w, | |||
| where a >= min_crop_size). | |||
| Note: | |||
| The keys for bboxes, labels and masks should be paired. That is, \ | |||
| `gt_bboxes` corresponds to `gt_labels` and `gt_masks`, and \ | |||
| `gt_bboxes_ignore` to `gt_labels_ignore` and `gt_masks_ignore`. | |||
| """ | |||
| def __init__(self, | |||
| crop_ratio_range=None, | |||
| crop_choice=None, | |||
| bbox_clip_border=True, | |||
| big_face_ratio=0, | |||
| big_face_crop_choice=None): | |||
| self.crop_ratio_range = crop_ratio_range | |||
| self.crop_choice = crop_choice | |||
| self.big_face_crop_choice = big_face_crop_choice | |||
| self.bbox_clip_border = bbox_clip_border | |||
| assert (self.crop_ratio_range is None) ^ (self.crop_choice is None) | |||
| if self.crop_ratio_range is not None: | |||
| self.crop_ratio_min, self.crop_ratio_max = self.crop_ratio_range | |||
| self.bbox2label = { | |||
| 'gt_bboxes': 'gt_labels', | |||
| 'gt_bboxes_ignore': 'gt_labels_ignore' | |||
| } | |||
| self.bbox2mask = { | |||
| 'gt_bboxes': 'gt_masks', | |||
| 'gt_bboxes_ignore': 'gt_masks_ignore' | |||
| } | |||
| assert big_face_ratio >= 0 and big_face_ratio <= 1.0 | |||
| self.big_face_ratio = big_face_ratio | |||
| def __call__(self, results): | |||
| """Call function to crop images and bounding boxes with minimum IoU | |||
| constraint. | |||
| Args: | |||
| results (dict): Result dict from loading pipeline. | |||
| Returns: | |||
| dict: Result dict with images and bounding boxes cropped, \ | |||
| 'img_shape' key is updated. | |||
| """ | |||
| if 'img_fields' in results: | |||
| assert results['img_fields'] == ['img'], \ | |||
| 'Only single img_fields is allowed' | |||
| img = results['img'] | |||
| assert 'bbox_fields' in results | |||
| assert 'gt_bboxes' in results | |||
| # try augment big face images | |||
| find_bigface = False | |||
| if np.random.random() < self.big_face_ratio: | |||
| min_size = 100 # h and w | |||
expand_ratio = 0.3  # expand ratio of the cropped face along both w and h
| bbox = results['gt_bboxes'].copy() | |||
| lmks = results['gt_keypointss'].copy() | |||
| label = results['gt_labels'].copy() | |||
| # filter small faces | |||
| size_mask = ((bbox[:, 2] - bbox[:, 0]) > min_size) * ( | |||
| (bbox[:, 3] - bbox[:, 1]) > min_size) | |||
| bbox = bbox[size_mask] | |||
| lmks = lmks[size_mask] | |||
| label = label[size_mask] | |||
| # randomly choose a face that has no overlap with others | |||
| if len(bbox) > 0: | |||
| overlaps = bbox_overlaps(bbox, bbox) | |||
| overlaps -= np.eye(overlaps.shape[0]) | |||
| iou_mask = np.sum(overlaps, axis=1) == 0 | |||
| bbox = bbox[iou_mask] | |||
| lmks = lmks[iou_mask] | |||
| label = label[iou_mask] | |||
| if len(bbox) > 0: | |||
| choice = np.random.randint(len(bbox)) | |||
| bbox = bbox[choice] | |||
| lmks = lmks[choice] | |||
| label = [label[choice]] | |||
| w = bbox[2] - bbox[0] | |||
| h = bbox[3] - bbox[1] | |||
| x1 = bbox[0] - w * expand_ratio | |||
| x2 = bbox[2] + w * expand_ratio | |||
| y1 = bbox[1] - h * expand_ratio | |||
| y2 = bbox[3] + h * expand_ratio | |||
| x1, x2 = np.clip([x1, x2], 0, img.shape[1]) | |||
| y1, y2 = np.clip([y1, y2], 0, img.shape[0]) | |||
| bbox -= np.tile([x1, y1], 2) | |||
| lmks -= (x1, y1, 0) | |||
| find_bigface = True | |||
| img = img[int(y1):int(y2), int(x1):int(x2), :] | |||
| results['gt_bboxes'] = np.expand_dims(bbox, axis=0) | |||
| results['gt_keypointss'] = np.expand_dims(lmks, axis=0) | |||
| results['gt_labels'] = np.array(label) | |||
| results['img'] = img | |||
| boxes = results['gt_bboxes'] | |||
| h, w, c = img.shape | |||
| if self.crop_ratio_range is not None: | |||
| max_scale = self.crop_ratio_max | |||
| else: | |||
| max_scale = np.amax(self.crop_choice) | |||
| scale_retry = 0 | |||
| while True: | |||
| scale_retry += 1 | |||
| if scale_retry == 1 or max_scale > 1.0: | |||
| if self.crop_ratio_range is not None: | |||
| scale = np.random.uniform(self.crop_ratio_min, | |||
| self.crop_ratio_max) | |||
| elif self.crop_choice is not None: | |||
| scale = np.random.choice(self.crop_choice) | |||
| else: | |||
| scale = scale * 1.2 | |||
| if find_bigface: | |||
| # select a scale from big_face_crop_choice if in big_face mode | |||
| scale = np.random.choice(self.big_face_crop_choice) | |||
| for i in range(250): | |||
| long_side = max(w, h) | |||
| cw = int(scale * long_side) | |||
| ch = cw | |||
| # TODO +1 | |||
| if w == cw: | |||
| left = 0 | |||
| elif w > cw: | |||
| left = random.randint(0, w - cw) | |||
| else: | |||
| left = random.randint(w - cw, 0) | |||
| if h == ch: | |||
| top = 0 | |||
| elif h > ch: | |||
| top = random.randint(0, h - ch) | |||
| else: | |||
| top = random.randint(h - ch, 0) | |||
| patch = np.array( | |||
| (int(left), int(top), int(left + cw), int(top + ch)), | |||
| dtype=np.int32) | |||
| # center of boxes should inside the crop img | |||
| # only adjust boxes and instance masks when the gt is not empty | |||
| # adjust boxes | |||
| def is_center_of_bboxes_in_patch(boxes, patch): | |||
| # TODO >= | |||
| center = (boxes[:, :2] + boxes[:, 2:]) / 2 | |||
| mask = \ | |||
| ((center[:, 0] > patch[0]) | |||
| * (center[:, 1] > patch[1]) | |||
| * (center[:, 0] < patch[2]) | |||
| * (center[:, 1] < patch[3])) | |||
| return mask | |||
| mask = is_center_of_bboxes_in_patch(boxes, patch) | |||
| if not mask.any(): | |||
| continue | |||
| for key in results.get('bbox_fields', []): | |||
| boxes = results[key].copy() | |||
| mask = is_center_of_bboxes_in_patch(boxes, patch) | |||
| boxes = boxes[mask] | |||
| if self.bbox_clip_border: | |||
| boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:]) | |||
| boxes[:, :2] = boxes[:, :2].clip(min=patch[:2]) | |||
| boxes -= np.tile(patch[:2], 2) | |||
| results[key] = boxes | |||
| # labels | |||
| label_key = self.bbox2label.get(key) | |||
| if label_key in results: | |||
| results[label_key] = results[label_key][mask] | |||
| # keypoints field | |||
| if key == 'gt_bboxes': | |||
| for kps_key in results.get('keypoints_fields', []): | |||
| keypointss = results[kps_key].copy() | |||
| keypointss = keypointss[mask, :, :] | |||
| if self.bbox_clip_border: | |||
| keypointss[:, :, : | |||
| 2] = keypointss[:, :, :2].clip( | |||
| max=patch[2:]) | |||
| keypointss[:, :, : | |||
| 2] = keypointss[:, :, :2].clip( | |||
| min=patch[:2]) | |||
| keypointss[:, :, 0] -= patch[0] | |||
| keypointss[:, :, 1] -= patch[1] | |||
| results[kps_key] = keypointss | |||
| # mask fields | |||
| mask_key = self.bbox2mask.get(key) | |||
| if mask_key in results: | |||
| results[mask_key] = results[mask_key][mask.nonzero() | |||
| [0]].crop(patch) | |||
| # adjust the img no matter whether the gt is empty before crop | |||
| rimg = np.ones((ch, cw, 3), dtype=img.dtype) * 128 | |||
| patch_from = patch.copy() | |||
| patch_from[0] = max(0, patch_from[0]) | |||
| patch_from[1] = max(0, patch_from[1]) | |||
| patch_from[2] = min(img.shape[1], patch_from[2]) | |||
| patch_from[3] = min(img.shape[0], patch_from[3]) | |||
| patch_to = patch.copy() | |||
| patch_to[0] = max(0, patch_to[0] * -1) | |||
| patch_to[1] = max(0, patch_to[1] * -1) | |||
| patch_to[2] = patch_to[0] + (patch_from[2] - patch_from[0]) | |||
| patch_to[3] = patch_to[1] + (patch_from[3] - patch_from[1]) | |||
| rimg[patch_to[1]:patch_to[3], | |||
| patch_to[0]:patch_to[2], :] = img[ | |||
| patch_from[1]:patch_from[3], | |||
| patch_from[0]:patch_from[2], :] | |||
| img = rimg | |||
| results['img'] = img | |||
| results['img_shape'] = img.shape | |||
| return results | |||
| def __repr__(self): | |||
| repr_str = self.__class__.__name__ | |||
repr_str += f'(crop_ratio_range={self.crop_ratio_range}, '
repr_str += f'crop_choice={self.crop_choice}, '
repr_str += f'bbox_clip_border={self.bbox_clip_border})'
| return repr_str | |||
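# Hedged usage sketch (not part of the patch): an SCRFD-style square crop with
# discrete scale choices; the values below are illustrative.
#   crop = RandomSquareCrop(crop_choice=[0.3, 0.45, 0.6, 0.8, 1.0],
#                           big_face_ratio=0.1,
#                           big_face_crop_choice=[1.5, 2.0])
#   results = crop(results)  # expects 'img', 'gt_bboxes', 'gt_labels' and
#                            # 'gt_keypointss' in results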
| @@ -13,7 +13,7 @@ class RetinaFaceDataset(CustomDataset): | |||
| CLASSES = ('FG', ) | |||
| def __init__(self, min_size=None, **kwargs): | |||
| self.NK = 5 | |||
| self.NK = kwargs.pop('num_kps', 5) | |||
| self.cat2label = {cat: i for i, cat in enumerate(self.CLASSES)} | |||
| self.min_size = min_size | |||
| self.gt_path = kwargs.get('gt_path') | |||
| @@ -33,7 +33,8 @@ class RetinaFaceDataset(CustomDataset): | |||
| if len(values) > 4: | |||
| if len(values) > 5: | |||
| kps = np.array( | |||
| values[4:19], dtype=np.float32).reshape((self.NK, 3)) | |||
| values[4:4 + self.NK * 3], dtype=np.float32).reshape( | |||
| (self.NK, 3)) | |||
| for li in range(kps.shape[0]): | |||
| if (kps[li, :] == -1).all(): | |||
| kps[li][2] = 0.0 # weight = 0, ignore | |||
| @@ -103,6 +103,7 @@ class SCRFDHead(AnchorHead): | |||
| scale_mode=1, | |||
| dw_conv=False, | |||
| use_kps=False, | |||
| num_kps=5, | |||
| loss_kps=dict( | |||
| type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.1), | |||
| **kwargs): | |||
| @@ -116,7 +117,7 @@ class SCRFDHead(AnchorHead): | |||
| self.scale_mode = scale_mode | |||
| self.use_dfl = True | |||
| self.dw_conv = dw_conv | |||
| self.NK = 5 | |||
| self.NK = num_kps | |||
| self.extra_flops = 0.0 | |||
| if loss_dfl is None or not loss_dfl: | |||
| self.use_dfl = False | |||
| @@ -323,8 +324,8 @@ class SCRFDHead(AnchorHead): | |||
| batch_size, -1, self.cls_out_channels).sigmoid() | |||
| bbox_pred = bbox_pred.permute(0, 2, 3, | |||
| 1).reshape(batch_size, -1, 4) | |||
| kps_pred = kps_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 10) | |||
| kps_pred = kps_pred.permute(0, 2, 3, | |||
| 1).reshape(batch_size, -1, self.NK * 2) | |||
| return cls_score, bbox_pred, kps_pred | |||
| def forward_train(self, | |||
| @@ -788,7 +789,7 @@ class SCRFDHead(AnchorHead): | |||
| if self.use_dfl: | |||
| kps_pred = self.integral(kps_pred) * stride[0] | |||
| else: | |||
| kps_pred = kps_pred.reshape((-1, 10)) * stride[0] | |||
| kps_pred = kps_pred.reshape((-1, self.NK * 2)) * stride[0] | |||
| nms_pre = cfg.get('nms_pre', -1) | |||
| if nms_pre > 0 and scores.shape[0] > nms_pre: | |||
| @@ -815,7 +816,7 @@ class SCRFDHead(AnchorHead): | |||
| mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor) | |||
| if mlvl_kps is not None: | |||
| scale_factor2 = torch.tensor( | |||
| [scale_factor[0], scale_factor[1]] * 5) | |||
| [scale_factor[0], scale_factor[1]] * self.NK) | |||
| mlvl_kps /= scale_factor2.to(mlvl_kps.device) | |||
| mlvl_scores = torch.cat(mlvl_scores) | |||
| @@ -54,7 +54,13 @@ class SCRFD(SingleStageDetector): | |||
| gt_bboxes_ignore) | |||
| return losses | |||
| def simple_test(self, img, img_metas, rescale=False): | |||
| def simple_test(self, | |||
| img, | |||
| img_metas, | |||
| rescale=False, | |||
| repeat_head=1, | |||
| output_kps_var=0, | |||
| output_results=1): | |||
| """Test function without test time augmentation. | |||
| Args: | |||
| @@ -62,6 +68,9 @@ class SCRFD(SingleStageDetector): | |||
| img_metas (list[dict]): List of image information. | |||
| rescale (bool, optional): Whether to rescale the results. | |||
| Defaults to False. | |||
repeat_head (int): number of repeated inference passes through the head
output_kps_var (int): whether to also return the keypoint variance used
    to estimate face quality
output_results (int): 0: return nothing; 1: bbox only; 2: both bbox and kps
| Returns: | |||
| list[list[np.ndarray]]: BBox results of each image and classes. | |||
| @@ -69,40 +78,71 @@ class SCRFD(SingleStageDetector): | |||
| corresponds to each class. | |||
| """ | |||
| x = self.extract_feat(img) | |||
| outs = self.bbox_head(x) | |||
| if torch.onnx.is_in_onnx_export(): | |||
| print('single_stage.py in-onnx-export') | |||
| print(outs.__class__) | |||
| cls_score, bbox_pred, kps_pred = outs | |||
| for c in cls_score: | |||
| print(c.shape) | |||
| for c in bbox_pred: | |||
| print(c.shape) | |||
| if self.bbox_head.use_kps: | |||
| for c in kps_pred: | |||
| assert repeat_head >= 1 | |||
| kps_out0 = [] | |||
| kps_out1 = [] | |||
| kps_out2 = [] | |||
| for i in range(repeat_head): | |||
| outs = self.bbox_head(x) | |||
| kps_out0 += [outs[2][0].detach().cpu().numpy()] | |||
| kps_out1 += [outs[2][1].detach().cpu().numpy()] | |||
| kps_out2 += [outs[2][2].detach().cpu().numpy()] | |||
| if output_kps_var: | |||
| var0 = np.var(np.vstack(kps_out0), axis=0).mean() | |||
| var1 = np.var(np.vstack(kps_out1), axis=0).mean() | |||
| var2 = np.var(np.vstack(kps_out2), axis=0).mean() | |||
| var = np.mean([var0, var1, var2]) | |||
| else: | |||
| var = None | |||
| if output_results > 0: | |||
| if torch.onnx.is_in_onnx_export(): | |||
| print('single_stage.py in-onnx-export') | |||
| print(outs.__class__) | |||
| cls_score, bbox_pred, kps_pred = outs | |||
| for c in cls_score: | |||
| print(c.shape) | |||
| for c in bbox_pred: | |||
| print(c.shape) | |||
| return (cls_score, bbox_pred, kps_pred) | |||
| else: | |||
| return (cls_score, bbox_pred) | |||
| bbox_list = self.bbox_head.get_bboxes( | |||
| *outs, img_metas, rescale=rescale) | |||
| if self.bbox_head.use_kps: | |||
| for c in kps_pred: | |||
| print(c.shape) | |||
| return (cls_score, bbox_pred, kps_pred) | |||
| else: | |||
| return (cls_score, bbox_pred) | |||
| bbox_list = self.bbox_head.get_bboxes( | |||
| *outs, img_metas, rescale=rescale) | |||
| # return kps if use_kps | |||
| if len(bbox_list[0]) == 2: | |||
| bbox_results = [ | |||
| bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes) | |||
| for det_bboxes, det_labels in bbox_list | |||
| ] | |||
| elif len(bbox_list[0]) == 3: | |||
| bbox_results = [ | |||
| bbox2result( | |||
| det_bboxes, | |||
| det_labels, | |||
| self.bbox_head.num_classes, | |||
| kps=det_kps) | |||
| for det_bboxes, det_labels, det_kps in bbox_list | |||
| ] | |||
| return bbox_results | |||
| # return kps if use_kps | |||
| if len(bbox_list[0]) == 2: | |||
| bbox_results = [ | |||
| bbox2result(det_bboxes, det_labels, | |||
| self.bbox_head.num_classes) | |||
| for det_bboxes, det_labels in bbox_list | |||
| ] | |||
| elif len(bbox_list[0]) == 3: | |||
| if output_results == 2: | |||
| bbox_results = [ | |||
| bbox2result( | |||
| det_bboxes, | |||
| det_labels, | |||
| self.bbox_head.num_classes, | |||
| kps=det_kps, | |||
| num_kps=self.bbox_head.NK) | |||
| for det_bboxes, det_labels, det_kps in bbox_list | |||
| ] | |||
| elif output_results == 1: | |||
| bbox_results = [ | |||
| bbox2result(det_bboxes, det_labels, | |||
| self.bbox_head.num_classes) | |||
| for det_bboxes, det_labels, _ in bbox_list | |||
| ] | |||
| else: | |||
| bbox_results = None | |||
| if var is not None: | |||
| return bbox_results, var | |||
| else: | |||
| return bbox_results | |||
| def feature_test(self, img): | |||
| x = self.extract_feat(img) | |||
| @@ -0,0 +1,71 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os.path as osp | |||
| from copy import deepcopy | |||
| from typing import Any, Dict | |||
| import torch | |||
| from modelscope.metainfo import Models | |||
| from modelscope.models.base import TorchModel | |||
| from modelscope.models.builder import MODELS | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| __all__ = ['ScrfdDetect'] | |||
| @MODELS.register_module(Tasks.face_detection, module_name=Models.scrfd) | |||
| class ScrfdDetect(TorchModel): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| """initialize the face detection model from the `model_dir` path. | |||
| Args: | |||
| model_dir (str): the model path. | |||
| """ | |||
| super().__init__(model_dir, *args, **kwargs) | |||
| from mmcv import Config | |||
| from mmcv.parallel import MMDataParallel | |||
| from mmcv.runner import load_checkpoint | |||
| from mmdet.models import build_detector | |||
| from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets import RetinaFaceDataset | |||
| from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets.pipelines import RandomSquareCrop | |||
| from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.backbones import ResNetV1e | |||
| from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.dense_heads import SCRFDHead | |||
| from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.detectors import SCRFD | |||
| cfg = Config.fromfile(osp.join(model_dir, 'mmcv_scrfd.py')) | |||
| ckpt_path = osp.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE) | |||
| cfg.model.test_cfg.score_thr = kwargs.get('score_thr', 0.3) | |||
| detector = build_detector(cfg.model) | |||
| logger.info(f'loading model from {ckpt_path}') | |||
device = torch.device(
    'cuda:0' if torch.cuda.is_available() else 'cpu')
| load_checkpoint(detector, ckpt_path, map_location=device) | |||
| detector = MMDataParallel(detector, device_ids=[0]) | |||
| detector.eval() | |||
| self.detector = detector | |||
| logger.info('load model done') | |||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
| result = self.detector( | |||
| return_loss=False, | |||
| rescale=True, | |||
| img=[input['img'][0].unsqueeze(0)], | |||
| img_metas=[[dict(input['img_metas'][0].data)]], | |||
| output_results=2) | |||
| assert result is not None | |||
| result = result[0][0] | |||
| bboxes = result[:, :4].tolist() | |||
| kpss = result[:, 5:].tolist() | |||
| scores = result[:, 4].tolist() | |||
| return { | |||
| OutputKeys.SCORES: scores, | |||
| OutputKeys.BOXES: bboxes, | |||
| OutputKeys.KEYPOINTS: kpss | |||
| } | |||
| def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]: | |||
| return input | |||
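# Hedged usage sketch (not part of the patch): calling the detector through the
# ModelScope pipeline API; the model id below is illustrative.
#   from modelscope.pipelines import pipeline
#   from modelscope.utils.constant import Tasks
#   face_detection = pipeline(Tasks.face_detection,
#                             model='damo/cv_resnet_facedetection_scrfd10gkps')
#   result = face_detection('path/to/image.jpg')
#   print(result[OutputKeys.SCORES], result[OutputKeys.BOXES])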