Merge master internal to GitHub master
@@ -1,26 +1,32 @@
-awk -F: '/^[^#]/ { print $1 }' requirements/framework.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
-awk -F: '/^[^#]/ { print $1 }' requirements/audio.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
-awk -F: '/^[^#]/ { print $1 }' requirements/cv.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
-awk -F: '/^[^#]/ { print $1 }' requirements/multi-modal.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
-awk -F: '/^[^#]/ { print $1 }' requirements/nlp.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
-pip install -r requirements/tests.txt
-git config --global --add safe.directory /Maas-lib
-git config --global user.email tmp
-git config --global user.name tmp.com
-# linter test
-# use internal project for pre-commit due to the network problem
-if [ `git remote -v | grep alibaba | wc -l` -gt 1 ]; then
-    pre-commit run -c .pre-commit-config_local.yaml --all-files
-fi
-if [ $? -ne 0 ]; then
-    echo "linter test failed, please run 'pre-commit run --all-files' to check"
-    exit -1
+echo "Testing envs"
+printenv
+echo "ENV END"
+if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then
+    awk -F: '/^[^#]/ { print $1 }' requirements/framework.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
+    awk -F: '/^[^#]/ { print $1 }' requirements/audio.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
+    awk -F: '/^[^#]/ { print $1 }' requirements/cv.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
+    awk -F: '/^[^#]/ { print $1 }' requirements/multi-modal.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
+    awk -F: '/^[^#]/ { print $1 }' requirements/nlp.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
+    pip install -r requirements/tests.txt
+    git config --global --add safe.directory /Maas-lib
+    git config --global user.email tmp
+    git config --global user.name tmp.com
+    # linter test
+    # use internal project for pre-commit due to the network problem
+    if [ `git remote -v | grep alibaba | wc -l` -gt 1 ]; then
+        pre-commit run -c .pre-commit-config_local.yaml --all-files
+        if [ $? -ne 0 ]; then
+            echo "linter test failed, please run 'pre-commit run --all-files' to check"
+            exit -1
+        fi
+    fi
+    # test with install
+    python setup.py install
+else
+    echo "Running case in release image, run case directly!"
 fi
-# test with install
-python setup.py install
 if [ $# -eq 0 ]; then
     ci_command="python tests/run.py --subprocess"
 else
@@ -20,28 +20,52 @@ do
     # pull image if there are update
     docker pull ${IMAGE_NAME}:${IMAGE_VERSION}
-    docker run --rm --name $CONTAINER_NAME --shm-size=16gb \
-            --cpuset-cpus=${cpu_sets_arr[$gpu]} \
-            --gpus="device=$gpu" \
-            -v $CODE_DIR:$CODE_DIR_IN_CONTAINER \
-            -v $MODELSCOPE_CACHE:$MODELSCOPE_CACHE_DIR_IN_CONTAINER \
-            -v $MODELSCOPE_HOME_CACHE/$gpu:/root \
-            -v /home/admin/pre-commit:/home/admin/pre-commit \
-            -e CI_TEST=True \
-            -e TEST_LEVEL=$TEST_LEVEL \
-            -e MODELSCOPE_CACHE=$MODELSCOPE_CACHE_DIR_IN_CONTAINER \
-            -e MODELSCOPE_DOMAIN=$MODELSCOPE_DOMAIN \
-            -e HUB_DATASET_ENDPOINT=$HUB_DATASET_ENDPOINT \
-            -e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \
-            -e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \
-            -e TEST_LEVEL=$TEST_LEVEL \
-            -e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \
-            -e MODEL_TAG_URL=$MODEL_TAG_URL \
-            --workdir=$CODE_DIR_IN_CONTAINER \
-            --net host \
-            ${IMAGE_NAME}:${IMAGE_VERSION} \
-            $CI_COMMAND
+    if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then
+        docker run --rm --name $CONTAINER_NAME --shm-size=16gb \
+                --cpuset-cpus=${cpu_sets_arr[$gpu]} \
+                --gpus="device=$gpu" \
+                -v $CODE_DIR:$CODE_DIR_IN_CONTAINER \
+                -v $MODELSCOPE_CACHE:$MODELSCOPE_CACHE_DIR_IN_CONTAINER \
+                -v $MODELSCOPE_HOME_CACHE/$gpu:/root \
+                -v /home/admin/pre-commit:/home/admin/pre-commit \
+                -e CI_TEST=True \
+                -e TEST_LEVEL=$TEST_LEVEL \
+                -e MODELSCOPE_CACHE=$MODELSCOPE_CACHE_DIR_IN_CONTAINER \
+                -e MODELSCOPE_DOMAIN=$MODELSCOPE_DOMAIN \
+                -e MODELSCOPE_SDK_DEBUG=True \
+                -e HUB_DATASET_ENDPOINT=$HUB_DATASET_ENDPOINT \
+                -e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \
+                -e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \
+                -e TEST_LEVEL=$TEST_LEVEL \
+                -e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \
+                -e MODEL_TAG_URL=$MODEL_TAG_URL \
+                --workdir=$CODE_DIR_IN_CONTAINER \
+                --net host \
+                ${IMAGE_NAME}:${IMAGE_VERSION} \
+                $CI_COMMAND
+    else
+        docker run --rm --name $CONTAINER_NAME --shm-size=16gb \
+                --cpuset-cpus=${cpu_sets_arr[$gpu]} \
+                --gpus="device=$gpu" \
+                -v $CODE_DIR:$CODE_DIR_IN_CONTAINER \
+                -v $MODELSCOPE_CACHE:$MODELSCOPE_CACHE_DIR_IN_CONTAINER \
+                -v $MODELSCOPE_HOME_CACHE/$gpu:/root \
+                -v /home/admin/pre-commit:/home/admin/pre-commit \
+                -e CI_TEST=True \
+                -e TEST_LEVEL=$TEST_LEVEL \
+                -e MODELSCOPE_CACHE=$MODELSCOPE_CACHE_DIR_IN_CONTAINER \
+                -e MODELSCOPE_DOMAIN=$MODELSCOPE_DOMAIN \
+                -e HUB_DATASET_ENDPOINT=$HUB_DATASET_ENDPOINT \
+                -e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \
+                -e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \
+                -e TEST_LEVEL=$TEST_LEVEL \
+                -e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \
+                -e MODEL_TAG_URL=$MODEL_TAG_URL \
+                --workdir=$CODE_DIR_IN_CONTAINER \
+                --net host \
+                ${IMAGE_NAME}:${IMAGE_VERSION} \
+                $CI_COMMAND
+    fi
     if [ $? -ne 0 ]; then
         echo "Running test case failed, please check the log!"
         exit -1
@@ -11,10 +11,10 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v2
-      - name: Set up Python 3.6
+      - name: Set up Python 3.7
        uses: actions/setup-python@v2
        with:
-          python-version: 3.6
+          python-version: 3.7
      - name: Install pre-commit hook
        run: |
          pip install pre-commit
@@ -1,6 +1,6 @@
 repos:
   - repo: https://gitlab.com/pycqa/flake8.git
-    rev: 3.8.3
+    rev: 4.0.0
     hooks:
       - id: flake8
         exclude: thirdparty/|examples/

@@ -1,6 +1,6 @@
 repos:
   - repo: /home/admin/pre-commit/flake8
-    rev: 3.8.3
+    rev: 4.0.0
     hooks:
       - id: flake8
         exclude: thirdparty/|examples/
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e999c247bfebb03d556a31722f0ce7145cac20a67fac9da813ad336e1f549f9f
+size 38954

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:32eb8d4d537941bf0edea69cd6723e8ba489fa3df64e13e29f96e4fae0b856f4
+size 93676

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f57aee13ade70be6b2c6e4f5e5c7404bdb03057b63828baefbaadcf23855a4cb
+size 472012

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fee8e0460ca707f108782be0d93c555bf34fb6b1cb297e5fceed70192cc65f9b
+size 71244

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:450e31f9df8c5b48c617900625f01cb64c484f079a9843179fe9feaa7d163e61
+size 181964

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:255494c41bc1dfb0c954d827ec6ce775900e4f7a55fb0a7881bdf9d66a03b425
+size 112078

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:22a55277908bbc3ef60a0cf56b230eb507b9e837574e8f493e93644b1d21c281
+size 200556

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ee92191836c76412463d8b282a7ab4e1aa57386ba699ec011a3e2c4d64f32f4b
+size 162636

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:77d1537fc584c1505d8aa10ec8c86af57ab661199e4f28fd7ffee3c22d1e4e61
+size 160204

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e8d653a9a1ee49789c3df38e8da96af7118e0d8336d6ed12cd6458efa015071d
+size 2327764

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c589d77404ea17d4d24daeb8624dce7e1ac919dc75e6bed44ea9d116f0514150
+size 68524

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:76bf84536edbaf192a8a699efc62ba2b06056bac12c426ecfcc2e003d91fbd32
+size 53219

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ecbc9d0827cfb92e93e7d75868b1724142685dc20d3b32023c3c657a7b688a9c
+size 254845

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d510ab26ddc58ffea882c8ef850c1f9bd4444772f2bce7ebea3e76944536c3ae
+size 48909

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b2c1119e3d521cf2e583b1e85fc9c9afd1d44954b433135039a98050a730932d
+size 1127557

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:46db348eae61448f1668ce282caec21375e96c3268d53da44aa67ec32cbf4fa5
+size 2747938

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:709c1828ed2d56badf2f19a40194da9a5e5e6db2fb73ef55d047407f49bc7a15
+size 27616

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:772b19f76c98044e39330853928624f10e085106a4292b4dd19f865531080747
+size 959

@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:379e11d7fc3734d3ec95afd0d86460b4653fbf4bb1f57f993610d6a6fd30fd3d
-size 1702339

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dec0fbb931cb609bf481e56b89cd2fbbab79839f22832c3bbe69a8fae2769cdd
+size 167407

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:4fd6fa6b23c2fdaf876606a767d9b64b1924e1acddfc06ac42db73ba86083280
-size 119940
+oid sha256:4eae921001139d7e3c06331c9ef2213f8fc1c23512acd95751559866fb770e96
+size 121855

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:4d37672a0e299a08d2daf5c7fc29bfce96bb15701fe5e5e68f068861ac2ee705
-size 119619
+oid sha256:f97d34d7450d17d0a93647129ab10d16b1f6e70c34a73b6f7687b79519ee4f71
+size 121563

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:c692e0753cfe349e520511427727a8252f141fa10e85f9a61562845e8d731f9a
-size 119619
+oid sha256:a8355f27a3235209f206b5e75f4400353e5989e94cf4d71270b42ded8821d536
+size 121563

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:344ef971bdf310b76c6571d1f4994ab6abc5edc659654d71a4f75b14a30960c2
+size 152926

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:44e3925c15d86d8596baeb6bd1d153d86f57b7489798b2cf988a1248e110fd62
-size 62231
+oid sha256:f0aeb07b6c9b40a0cfa7492e839431764e9bece93c906833a07c05e83520a399
+size 63161

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:1ff17a0272752de4c88d4254b2e881f97f8ef022f03609d03ee1de0ae964368a
-size 62235
+oid sha256:7aa5c7a2565ccf0d2eea4baf8adbd0e020dbe36a7159b31156c53141cc9b2df2
+size 63165

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:9103ce2bc89212f67fb49ce70783b7667e376900d0f70fb8f5c4432eb74bc572
-size 60801
+oid sha256:cc6de82a8485fbfa008f6c2d5411cd07ba03e4a780bcb4e67efc6fba3c6ce92f
+size 63597

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:2d4dee34c7e83b77db04fb2f0d1200bfd37c7c24954c58e185da5cb96445975c
-size 60801
+oid sha256:7d98ac11a4e9e2744a7402a5cc912da991a41938bbc5dd60f15ee5c6b3196030
+size 63349

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:9e3ecc2c30d382641d561f84849b199c12bb1a9418e8099a191153f6f5275a85
-size 61589
+oid sha256:01f9b9bf6f8bbf9bb377d4cb6f399b2e5e065381f5b7332343e0db7b4fae72a5
+size 62519

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a49c9bc74a60860c360a4bf4509fe9db915279aaabd953f354f2c38e9be1e6cb
+size 2924691

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f58df1d25590c158ae0a04b3999bd44b610cdaddb17d78afd84c34b3f00d4e87
+size 4068783
@@ -76,7 +76,7 @@ RUN pip install --no-cache-dir --upgrade pip && \
 ENV SHELL=/bin/bash
 # install special package
-RUN pip install --no-cache-dir mmcls>=0.21.0 mmdet>=2.25.0 decord>=0.6.0 datasets==2.1.0 numpy==1.18.5 ipykernel fairseq
+RUN pip install --no-cache-dir mmcls>=0.21.0 mmdet>=2.25.0 decord>=0.6.0 datasets==2.1.0 numpy==1.18.5 ipykernel fairseq fasttext https://modelscope.oss-cn-beijing.aliyuncs.com/releases/dependencies/xtcocotools-1.12-cp37-cp37m-linux_x86_64.whl
 RUN if [ "$USE_GPU" = "True" ] ; then \
     pip install --no-cache-dir dgl-cu113 dglgo -f https://data.dgl.ai/wheels/repo.html; \
@@ -1,4 +1,4 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-from .version import __version__
+from .version import __release_datetime__, __version__
-__all__ = ['__version__']
+__all__ = ['__version__', '__release_datetime__']
@@ -19,10 +19,13 @@ class Exporter(ABC):
     def from_model(cls, model: Model, **kwargs):
         """Build the Exporter instance.

-        @param model: A model instance. it will be used to output the generated file,
+        Args:
+            model: A Model instance. It will be used to generate the intermediate format file,
         and the configuration.json in its model_dir field will be used to create the exporter instance.
-        @param kwargs: Extra kwargs used to create the Exporter instance.
-        @return: The Exporter instance
+            kwargs: Extra kwargs used to create the Exporter instance.
+        Returns:
+            The Exporter instance.
         """
         cfg = Config.from_file(
             os.path.join(model.model_dir, ModelFile.CONFIGURATION))
@@ -44,10 +47,13 @@ class Exporter(ABC):
         In some cases, several files may be generated,
         So please return a dict which contains the generated name with the file path.

-        @param opset: The version of the ONNX operator set to use.
-        @param outputs: The output dir.
-        @param kwargs: In this default implementation,
-        kwargs will be carried to generate_dummy_inputs as extra arguments (like input shape).
-        @return: A dict contains the model name with the model file path.
+        Args:
+            opset: The version of the ONNX operator set to use.
+            outputs: The output dir.
+            kwargs: In this default implementation,
+                kwargs will be carried to generate_dummy_inputs as extra arguments (like input shape).
+        Returns:
+            A dict containing the model name with the model file path.
         """
         pass
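For illustration, a minimal usage sketch of the from_model/export_onnx contract documented above; the `modelscope.exporters` import path and the model id are assumptions, not asserted by this diff:

    # Hypothetical usage; model id and import path are assumptions
    from modelscope.models import Model
    from modelscope.exporters import Exporter

    model = Model.from_pretrained('damo/nlp_structbert_sentence-similarity_chinese-base')
    exporter = Exporter.from_model(model)
    result = exporter.export_onnx(output_dir='/tmp/onnx_out', opset=13)
    print(result)  # per the docstring, a dict such as {'model': '/tmp/onnx_out/model.onnx'}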
@@ -23,13 +23,18 @@ class SbertForSequenceClassificationExporter(TorchModelExporter):
     def generate_dummy_inputs(self,
                               shape: Tuple = None,
+                              pair: bool = False,
                               **kwargs) -> Dict[str, Any]:
         """Generate dummy inputs for model exportation to onnx or other formats by tracing.

-        @param shape: A tuple of input shape which should have at most two dimensions.
-        shape = (1, ) batch_size=1, sequence_length will be taken from the preprocessor.
-        shape = (8, 128) batch_size=1, sequence_length=128, which will cover the config of the preprocessor.
-        @return: Dummy inputs.
+        Args:
+            shape: A tuple of input shape which should have at most two dimensions.
+                shape = (1, ) batch_size=1, sequence_length will be taken from the preprocessor.
+                shape = (8, 128) batch_size=8, sequence_length=128, which will override the config of the preprocessor.
+            pair(bool, `optional`): Whether to generate sentence pairs or single sentences.
+        Returns:
+            Dummy inputs.
         """
         cfg = Config.from_file(
@@ -55,7 +60,7 @@ class SbertForSequenceClassificationExporter(TorchModelExporter):
             **sequence_length
         })
         preprocessor: Preprocessor = build_preprocessor(cfg, field_name)
-        if preprocessor.pair:
+        if pair:
             first_sequence = preprocessor.tokenizer.unk_token
             second_sequence = preprocessor.tokenizer.unk_token
         else:
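A hedged sketch of the new `pair` flag together with `shape`; the exporter class name comes from this diff, while `model` is assumed to be a loaded sequence-classification model:

    # Hypothetical call; the dummy input key names depend on the preprocessor
    exporter = SbertForSequenceClassificationExporter.from_model(model)
    dummy_inputs = exporter.generate_dummy_inputs(shape=(8, 128), pair=True)
    # dummy_inputs: a dict of traced tensors, e.g. input_ids / attention_mask at batch=8, seq_len=128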
@@ -13,8 +13,8 @@ from modelscope.models import TorchModel
 from modelscope.pipelines.base import collate_fn
 from modelscope.utils.constant import ModelFile
 from modelscope.utils.logger import get_logger
-from modelscope.utils.regress_test_utils import compare_arguments_nested
-from modelscope.utils.tensor_utils import torch_nested_numpify
+from modelscope.utils.regress_test_utils import (compare_arguments_nested,
+                                                 numpify_tensor_nested)
 from .base import Exporter

 logger = get_logger(__name__)
@@ -28,49 +28,61 @@ class TorchModelExporter(Exporter):
     and to provide implementations for generate_dummy_inputs/inputs/outputs methods.
     """

-    def export_onnx(self, outputs: str, opset=11, **kwargs):
+    def export_onnx(self, output_dir: str, opset=13, **kwargs):
         """Export the model as onnx format files.

         In some cases, several files may be generated,
         So please return a dict which contains the generated name with the file path.

-        @param opset: The version of the ONNX operator set to use.
-        @param outputs: The output dir.
-        @param kwargs: In this default implementation,
-        you can pass the arguments needed by _torch_export_onnx, other unrecognized args
-        will be carried to generate_dummy_inputs as extra arguments (such as input shape).
-        @return: A dict containing the model key - model file path pairs.
+        Args:
+            opset: The version of the ONNX operator set to use.
+            output_dir: The output dir.
+            kwargs:
+                model: A model instance which will replace the exporting of self.model.
+                In this default implementation,
+                you can pass the arguments needed by _torch_export_onnx, other unrecognized args
+                will be carried to generate_dummy_inputs as extra arguments (such as input shape).
+        Returns:
+            A dict containing the model key - model file path pairs.
         """
-        model = self.model
+        model = self.model if 'model' not in kwargs else kwargs.pop('model')
         if not isinstance(model, nn.Module) and hasattr(model, 'model'):
             model = model.model
-        onnx_file = os.path.join(outputs, ModelFile.ONNX_MODEL_FILE)
+        onnx_file = os.path.join(output_dir, ModelFile.ONNX_MODEL_FILE)
         self._torch_export_onnx(model, onnx_file, opset=opset, **kwargs)
         return {'model': onnx_file}

-    def export_torch_script(self, outputs: str, **kwargs):
+    def export_torch_script(self, output_dir: str, **kwargs):
         """Export the model as torch script files.

         In some cases, several files may be generated,
         So please return a dict which contains the generated name with the file path.

-        @param outputs: The output dir.
-        @param kwargs: In this default implementation,
+        Args:
+            output_dir: The output dir.
+            kwargs:
+                model: A model instance which will replace the exporting of self.model.
+                In this default implementation,
                 you can pass the arguments needed by _torch_export_torch_script, other unrecognized args
                 will be carried to generate_dummy_inputs as extra arguments (like input shape).
-        @return: A dict contains the model name with the model file path.
+        Returns:
+            A dict containing the model name with the model file path.
         """
-        model = self.model
+        model = self.model if 'model' not in kwargs else kwargs.pop('model')
         if not isinstance(model, nn.Module) and hasattr(model, 'model'):
             model = model.model
-        ts_file = os.path.join(outputs, ModelFile.TS_MODEL_FILE)
+        ts_file = os.path.join(output_dir, ModelFile.TS_MODEL_FILE)
         # generate ts by tracing
         self._torch_export_torch_script(model, ts_file, **kwargs)
         return {'model': ts_file}

     def generate_dummy_inputs(self, **kwargs) -> Dict[str, Any]:
         """Generate dummy inputs for model exportation to onnx or other formats by tracing.

-        @return: Dummy inputs.
+        Returns:
+            Dummy inputs.
         """
         return None
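Together with the rename to output_dir and the opset bump, the new `model` kwarg lets callers override the module being exported; a sketch under assumed paths:

    # Hypothetical override of the exported module via the popped `model` kwarg
    exporter = TorchModelExporter.from_model(model)
    exporter.export_onnx(output_dir='/tmp/out', opset=13)  # default: export self.model
    exporter.export_torch_script(output_dir='/tmp/out', model=model.model)  # trace the inner nn.Module instead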
@@ -93,7 +105,7 @@ class TorchModelExporter(Exporter):
     def _torch_export_onnx(self,
                            model: nn.Module,
                            output: str,
-                           opset: int = 11,
+                           opset: int = 13,
                            device: str = 'cpu',
                            validation: bool = True,
                            rtol: float = None,
@@ -101,18 +113,27 @@ class TorchModelExporter(Exporter):
                            **kwargs):
         """Export the model to an onnx format file.

-        @param model: A torch.nn.Module instance to export.
-        @param output: The output file.
-        @param opset: The version of the ONNX operator set to use.
-        @param device: The device used to forward.
-        @param validation: Whether validate the export file.
-        @param rtol: The rtol used to regress the outputs.
-        @param atol: The atol used to regress the outputs.
+        Args:
+            model: A torch.nn.Module instance to export.
+            output: The output file.
+            opset: The version of the ONNX operator set to use.
+            device: The device used to forward.
+            validation: Whether to validate the exported file.
+            rtol: The rtol used to regress the outputs.
+            atol: The atol used to regress the outputs.
+            kwargs:
+                dummy_inputs: Dummy inputs which will replace the call to self.generate_dummy_inputs().
+                inputs: An inputs structure which will replace the call to self.inputs.
+                outputs: An outputs structure which will replace the call to self.outputs.
         """
-        dummy_inputs = self.generate_dummy_inputs(**kwargs)
-        inputs = self.inputs
-        outputs = self.outputs
+        dummy_inputs = self.generate_dummy_inputs(
+            **kwargs) if 'dummy_inputs' not in kwargs else kwargs.pop(
+                'dummy_inputs')
+        inputs = self.inputs if 'inputs' not in kwargs else kwargs.pop(
+            'inputs')
+        outputs = self.outputs if 'outputs' not in kwargs else kwargs.pop(
+            'outputs')
         if dummy_inputs is None or inputs is None or outputs is None:
             raise NotImplementedError(
                 'Model property dummy_inputs,inputs,outputs must be set.')
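The three popped kwargs allow external control of the traced export. A sketch, assuming `inputs`/`outputs` follow the usual name-to-dynamic-axes mapping used for ONNX export (an assumption; the exact structure is not stated in this diff):

    import torch
    from collections import OrderedDict

    # Hypothetical tensors and axes; adapt the names to the model being exported
    dummy = OrderedDict(input_ids=torch.ones(1, 16, dtype=torch.long))
    dynamic_in = OrderedDict(input_ids={0: 'batch', 1: 'sequence'})
    dynamic_out = OrderedDict(logits={0: 'batch'})
    exporter.export_onnx(output_dir='/tmp/out', dummy_inputs=dummy,
                         inputs=dynamic_in, outputs=dynamic_out)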
@@ -125,7 +146,7 @@ class TorchModelExporter(Exporter):
         if isinstance(dummy_inputs, Mapping):
             dummy_inputs = dict(dummy_inputs)
-        onnx_outputs = list(self.outputs.keys())
+        onnx_outputs = list(outputs.keys())

         with replace_call():
             onnx_export(
@@ -160,11 +181,13 @@ class TorchModelExporter(Exporter):
             outputs_origin = model.forward(
                 *_decide_input_format(model, dummy_inputs))
             if isinstance(outputs_origin, Mapping):
-                outputs_origin = torch_nested_numpify(
+                outputs_origin = numpify_tensor_nested(
                     list(outputs_origin.values()))
+            elif isinstance(outputs_origin, (tuple, list)):
+                outputs_origin = numpify_tensor_nested(outputs_origin)
             outputs = ort_session.run(
                 onnx_outputs,
-                torch_nested_numpify(dummy_inputs),
+                numpify_tensor_nested(dummy_inputs),
             )
         tols = {}
@@ -184,19 +207,26 @@ class TorchModelExporter(Exporter):
                                   validation: bool = True,
                                   rtol: float = None,
                                   atol: float = None,
+                                  strict: bool = True,
                                   **kwargs):
         """Export the model to a torch script file.

-        @param model: A torch.nn.Module instance to export.
-        @param output: The output file.
-        @param device: The device used to forward.
-        @param validation: Whether validate the export file.
-        @param rtol: The rtol used to regress the outputs.
-        @param atol: The atol used to regress the outputs.
+        Args:
+            model: A torch.nn.Module instance to export.
+            output: The output file.
+            device: The device used to forward.
+            validation: Whether to validate the exported file.
+            rtol: The rtol used to regress the outputs.
+            atol: The atol used to regress the outputs.
+            strict: Strict mode in torch script tracing.
+            kwargs:
+                dummy_inputs: Dummy inputs which will replace the call to self.generate_dummy_inputs().
         """
         model.eval()
-        dummy_inputs = self.generate_dummy_inputs(**kwargs)
+        dummy_param = 'dummy_inputs' not in kwargs
+        dummy_inputs = self.generate_dummy_inputs(
+            **kwargs) if dummy_param else kwargs.pop('dummy_inputs')
         if dummy_inputs is None:
             raise NotImplementedError(
                 'Model property dummy_inputs must be set.')
@@ -207,7 +237,7 @@ class TorchModelExporter(Exporter):
             model.eval()
             with replace_call():
                 traced_model = torch.jit.trace(
-                    model, dummy_inputs, strict=False)
+                    model, dummy_inputs, strict=strict)
         torch.jit.save(traced_model, output)

         if validation:
| if validation: | if validation: | ||||
| @@ -216,9 +246,9 @@ class TorchModelExporter(Exporter): | |||||
| model.eval() | model.eval() | ||||
| ts_model.eval() | ts_model.eval() | ||||
| outputs = ts_model.forward(*dummy_inputs) | outputs = ts_model.forward(*dummy_inputs) | ||||
| outputs = torch_nested_numpify(outputs) | |||||
| outputs = numpify_tensor_nested(outputs) | |||||
| outputs_origin = model.forward(*dummy_inputs) | outputs_origin = model.forward(*dummy_inputs) | ||||
| outputs_origin = torch_nested_numpify(outputs_origin) | |||||
| outputs_origin = numpify_tensor_nested(outputs_origin) | |||||
| tols = {} | tols = {} | ||||
| if rtol is not None: | if rtol is not None: | ||||
| tols['rtol'] = rtol | tols['rtol'] = rtol | ||||
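Since tracing now defaults to strict=True, models whose forward returns a dict or other mutable container may need the old behavior restored explicitly; a one-line sketch:

    # Hypothetical: relax tracing for a dict-returning forward, as the previous strict=False did
    exporter.export_torch_script(output_dir='/tmp/out', strict=False)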
@@ -240,7 +270,6 @@ def replace_call():
     problems. Here we recover the call method to the default implementation of torch.nn.Module, and change it
     back after the tracing was done.
     """
     TorchModel.call_origin, TorchModel.__call__ = TorchModel.__call__, TorchModel._call_impl
     yield
     TorchModel.__call__ = TorchModel.call_origin
@@ -1,32 +1,47 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+# yapf: disable
+import datetime
 import os
 import pickle
+import platform
 import shutil
+import tempfile
+import uuid
 from collections import defaultdict
 from http import HTTPStatus
 from http.cookiejar import CookieJar
 from os.path import expanduser
-from typing import List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union

 import requests

+from modelscope import __version__
 from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA,
                                       API_RESPONSE_FIELD_EMAIL,
                                       API_RESPONSE_FIELD_GIT_ACCESS_TOKEN,
                                       API_RESPONSE_FIELD_MESSAGE,
                                       API_RESPONSE_FIELD_USERNAME,
-                                      DEFAULT_CREDENTIALS_PATH)
+                                      DEFAULT_CREDENTIALS_PATH,
+                                      MODELSCOPE_ENVIRONMENT, ONE_YEAR_SECONDS,
+                                      Licenses, ModelVisibility)
+from modelscope.hub.errors import (InvalidParameter, NotExistError,
+                                   NotLoginException, NoValidRevisionError,
+                                   RequestError, datahub_raise_on_error,
+                                   handle_http_post_error,
+                                   handle_http_response, is_ok,
+                                   raise_for_http_status, raise_on_error)
+from modelscope.hub.git import GitCommandWrapper
+from modelscope.hub.repository import Repository
 from modelscope.utils.config_ds import DOWNLOADED_DATASETS_PATH
 from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
                                        DEFAULT_MODEL_REVISION,
-                                       DatasetFormations, DatasetMetaFormats,
-                                       DownloadMode)
+                                       DEFAULT_REPOSITORY_REVISION,
+                                       MASTER_MODEL_BRANCH, DatasetFormations,
+                                       DatasetMetaFormats, DownloadMode,
+                                       ModelFile)
 from modelscope.utils.logger import get_logger
-from .errors import (InvalidParameter, NotExistError, RequestError,
-                     datahub_raise_on_error, handle_http_response, is_ok,
-                     raise_on_error)
-from .utils.utils import (get_dataset_hub_endpoint, get_endpoint,
+from .utils.utils import (get_endpoint, get_release_datetime,
                           model_id_to_group_owner_name)

 logger = get_logger()
@@ -34,10 +49,9 @@ logger = get_logger()

 class HubApi:

-    def __init__(self, endpoint=None, dataset_endpoint=None):
+    def __init__(self, endpoint=None):
         self.endpoint = endpoint if endpoint is not None else get_endpoint()
-        self.dataset_endpoint = dataset_endpoint if dataset_endpoint is not None else get_dataset_hub_endpoint(
-        )
+        self.headers = {'user-agent': ModelScopeConfig.get_user_agent()}

     def login(
         self,
@@ -57,8 +71,9 @@ class HubApi:
         </Tip>
         """
         path = f'{self.endpoint}/api/v1/login'
-        r = requests.post(path, json={'AccessToken': access_token})
-        r.raise_for_status()
+        r = requests.post(
+            path, json={'AccessToken': access_token}, headers=self.headers)
+        raise_for_http_status(r)
         d = r.json()
         raise_on_error(d)
@@ -105,17 +120,16 @@ class HubApi:
         path = f'{self.endpoint}/api/v1/models'
         owner_or_group, name = model_id_to_group_owner_name(model_id)
+        body = {
+            'Path': owner_or_group,
+            'Name': name,
+            'ChineseName': chinese_name,
+            'Visibility': visibility,  # server check
+            'License': license
+        }
         r = requests.post(
-            path,
-            json={
-                'Path': owner_or_group,
-                'Name': name,
-                'ChineseName': chinese_name,
-                'Visibility': visibility,  # server check
-                'License': license
-            },
-            cookies=cookies)
-        r.raise_for_status()
+            path, json=body, cookies=cookies, headers=self.headers)
+        handle_http_post_error(r, path, body)
         raise_on_error(r.json())
         model_repo_url = f'{get_endpoint()}/{model_id}'
         return model_repo_url
@@ -134,8 +148,8 @@ class HubApi:
             raise ValueError('Token does not exist, please login first.')

         path = f'{self.endpoint}/api/v1/models/{model_id}'
-        r = requests.delete(path, cookies=cookies)
-        r.raise_for_status()
+        r = requests.delete(path, cookies=cookies, headers=self.headers)
+        raise_for_http_status(r)
         raise_on_error(r.json())

     def get_model_url(self, model_id):
@@ -164,7 +178,7 @@ class HubApi:
         owner_or_group, name = model_id_to_group_owner_name(model_id)
         path = f'{self.endpoint}/api/v1/models/{owner_or_group}/{name}?Revision={revision}'

-        r = requests.get(path, cookies=cookies)
+        r = requests.get(path, cookies=cookies, headers=self.headers)
         handle_http_response(r, logger, cookies, model_id)
         if r.status_code == HTTPStatus.OK:
             if is_ok(r.json()):
@@ -172,13 +186,116 @@ class HubApi:
             else:
                 raise NotExistError(r.json()[API_RESPONSE_FIELD_MESSAGE])
         else:
-            r.raise_for_status()
+            raise_for_http_status(r)

+    def push_model(self,
+                   model_id: str,
+                   model_dir: str,
+                   visibility: int = ModelVisibility.PUBLIC,
+                   license: str = Licenses.APACHE_V2,
+                   chinese_name: Optional[str] = None,
+                   commit_message: Optional[str] = 'upload model',
+                   revision: Optional[str] = DEFAULT_REPOSITORY_REVISION):
+        """
+        Upload model from a given directory to a given repository. A valid model directory
+        must contain a configuration.json file.
+
+        This function uploads the files in the given directory to the given repository. If the
+        given repository does not exist on the remote, it will be created automatically with the
+        given visibility, license and chinese_name parameters. If the revision does not exist
+        in the remote repository either, a new branch will be created for it.
+
+        HubApi's login must be called with a valid token, which can be obtained from
+        ModelScope's website, before calling this function.
+
+        Args:
+            model_id (`str`):
+                The model id to be uploaded, caller must have write permission for it.
+            model_dir(`str`):
+                The absolute path of the finetune result.
+            visibility(`int`, defaults to `5`):
+                Visibility of the newly created model (1-private, 5-public). If the model does
+                not exist in ModelScope, this function will create a new model with this
+                visibility and this parameter is required. You can ignore this parameter
+                if you are sure the model already exists.
+            license(`str`, defaults to `Licenses.APACHE_V2`):
+                License of the newly created model (see License). If the model does not exist
+                in ModelScope, this function will create a new model with this license
+                and this parameter is required. You can ignore this parameter if you are
+                sure the model already exists.
+            chinese_name(`str`, *optional*, defaults to `None`):
+                Chinese name of the newly created model.
+            commit_message(`str`, *optional*, defaults to `'upload model'`):
+                Commit message of the push request.
+            revision (`str`, *optional*, defaults to DEFAULT_REPOSITORY_REVISION):
+                Which branch to push to. If the branch does not exist, a new
+                branch will be created and pushed to.
+        """
+        if model_id is None:
+            raise InvalidParameter('model_id cannot be empty!')
+        if model_dir is None:
+            raise InvalidParameter('model_dir cannot be empty!')
+        if not os.path.exists(model_dir) or os.path.isfile(model_dir):
+            raise InvalidParameter('model_dir must be a valid directory.')
+        cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION)
+        if not os.path.exists(cfg_file):
+            raise ValueError(f'{model_dir} must contain a configuration.json.')
+        cookies = ModelScopeConfig.get_cookies()
+        if cookies is None:
+            raise NotLoginException('Must login before upload!')
+        files_to_save = os.listdir(model_dir)
+        try:
+            self.get_model(model_id=model_id)
+        except Exception:
+            if visibility is None or license is None:
+                raise InvalidParameter(
+                    'visibility and license cannot be empty if want to create new repo'
+                )
+            logger.info('Create new model %s' % model_id)
+            self.create_model(
+                model_id=model_id,
+                visibility=visibility,
+                license=license,
+                chinese_name=chinese_name)
+        tmp_dir = tempfile.mkdtemp()
+        git_wrapper = GitCommandWrapper()
+        try:
+            repo = Repository(model_dir=tmp_dir, clone_from=model_id)
+            branches = git_wrapper.get_remote_branches(tmp_dir)
+            if revision not in branches:
+                logger.info('Create new branch %s' % revision)
+                git_wrapper.new_branch(tmp_dir, revision)
+            git_wrapper.checkout(tmp_dir, revision)
+            files_in_repo = os.listdir(tmp_dir)
+            for f in files_in_repo:
+                if f[0] != '.':
+                    src = os.path.join(tmp_dir, f)
+                    if os.path.isfile(src):
+                        os.remove(src)
+                    else:
+                        shutil.rmtree(src, ignore_errors=True)
+            for f in files_to_save:
+                if f[0] != '.':
+                    src = os.path.join(model_dir, f)
+                    if os.path.isdir(src):
+                        shutil.copytree(src, os.path.join(tmp_dir, f))
+                    else:
+                        shutil.copy(src, tmp_dir)
+            if not commit_message:
+                date = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
+                commit_message = '[automsg] push model %s to hub at %s' % (
+                    model_id, date)
+            repo.push(
+                commit_message=commit_message,
+                local_branch=revision,
+                remote_branch=revision)
+        except Exception:
+            raise
+        finally:
+            shutil.rmtree(tmp_dir, ignore_errors=True)

-    def list_model(self,
-                   owner_or_group: str,
-                   page_number=1,
-                   page_size=10) -> dict:
-        """List model in owner or group.
+    def list_models(self,
+                    owner_or_group: str,
+                    page_number=1,
+                    page_size=10) -> dict:
+        """List models in owner or group.

         Args:
             owner_or_group(`str`): owner or group.
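An end-to-end sketch of the new push_model flow; the token, ids and paths are placeholders:

    from modelscope.hub.api import HubApi

    api = HubApi()
    api.login('your-access-token')  # obtain the token from the ModelScope website
    api.push_model(
        model_id='your_group/your_model',      # requires write permission
        model_dir='/path/to/finetune/output',  # must contain configuration.json
        revision='v1.0.0')                     # the branch is created if missing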
@@ -193,7 +310,8 @@ class HubApi:
             path,
             data='{"Path":"%s", "PageNumber":%s, "PageSize": %s}' %
             (owner_or_group, page_number, page_size),
-            cookies=cookies)
+            cookies=cookies,
+            headers=self.headers)
         handle_http_response(r, logger, cookies, 'list_model')
         if r.status_code == HTTPStatus.OK:
             if is_ok(r.json()):

@@ -202,7 +320,7 @@ class HubApi:
             else:
                 raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE])
         else:
-            r.raise_for_status()
+            raise_for_http_status(r)
         return None

     def _check_cookie(self,
@@ -217,10 +335,70 @@ class HubApi:
             raise ValueError('Token does not exist, please login first.')
         return cookies

+    def list_model_revisions(
+            self,
+            model_id: str,
+            cutoff_timestamp: int = None,
+            use_cookies: Union[bool, CookieJar] = False) -> List[str]:
+        """Get model tags.
+
+        Args:
+            model_id (str): The model id.
+            cutoff_timestamp (int): Tags created before the cutoff will be included.
+                The timestamp is represented by the seconds elapsed from the epoch time.
+            use_cookies (Union[bool, CookieJar], optional): If a CookieJar, this cookie is used;
+                if True, cookies will be loaded from local storage. Defaults to False.
+
+        Returns:
+            List[str]: A list of tags (revisions), ordered by create-time.
+        """
+        cookies = self._check_cookie(use_cookies)
+        if cutoff_timestamp is None:
+            cutoff_timestamp = get_release_datetime()
+        path = f'{self.endpoint}/api/v1/models/{model_id}/revisions?EndTime=%s' % cutoff_timestamp
+        r = requests.get(path, cookies=cookies, headers=self.headers)
+        handle_http_response(r, logger, cookies, model_id)
+        d = r.json()
+        raise_on_error(d)
+        info = d[API_RESPONSE_FIELD_DATA]
+        # tags returned from backend are guaranteed to be ordered by create-time
+        tags = [x['Revision'] for x in info['RevisionMap']['Tags']
+                ] if info['RevisionMap']['Tags'] else []
+        return tags
+
+    def get_valid_revision(self,
+                           model_id: str,
+                           revision=None,
+                           cookies: Optional[CookieJar] = None):
+        release_timestamp = get_release_datetime()
+        current_timestamp = int(round(datetime.datetime.now().timestamp()))
+        # for active development in library codes (non-release branches), release_timestamp
+        # is set to a time far in the future, to ensure that we get the master-HEAD
+        # version from the model repo by default (when no revision is provided)
+        if release_timestamp > current_timestamp + ONE_YEAR_SECONDS:
+            branches, tags = self.get_model_branches_and_tags(
+                model_id, use_cookies=False if cookies is None else cookies)
+            if revision is None:
+                revision = MASTER_MODEL_BRANCH
+                logger.info('Model revision not specified, use default: %s in development mode' % revision)
+            if revision not in branches and revision not in tags:
+                raise NotExistError('The model: %s has no branch or tag: %s !' % (model_id, revision))
+        else:
+            revisions = self.list_model_revisions(
+                model_id, cutoff_timestamp=release_timestamp, use_cookies=False if cookies is None else cookies)
+            if revision is None:
+                if len(revisions) == 0:
+                    raise NoValidRevisionError('The model: %s has no valid revision!' % model_id)
+                # tags (revisions) returned from backend are guaranteed to be ordered by create-time,
+                # so we obtain the latest revision created earlier than the release version of this branch
+                revision = revisions[0]
+                logger.info('Model revision not specified, use the latest revision: %s' % revision)
+            else:
+                if revision not in revisions:
+                    raise NotExistError(
+                        'The model: %s has no revision: %s !' % (model_id, revision))
+        return revision
+
     def get_model_branches_and_tags(
         self,
         model_id: str,
-        use_cookies: Union[bool, CookieJar] = False
+        use_cookies: Union[bool, CookieJar] = False,
     ) -> Tuple[List[str], List[str]]:
         """Get model branch and tags.
@@ -234,7 +412,7 @@ class HubApi:
         cookies = self._check_cookie(use_cookies)

         path = f'{self.endpoint}/api/v1/models/{model_id}/revisions'
-        r = requests.get(path, cookies=cookies)
+        r = requests.get(path, cookies=cookies, headers=self.headers)
         handle_http_response(r, logger, cookies, model_id)
         d = r.json()
         raise_on_error(d)
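A sketch of the revision-resolution behavior added above; the model id is a placeholder:

    api = HubApi()
    # tags created before this library's release datetime, newest first
    revisions = api.list_model_revisions('some_group/some_model')
    # resolves the requested revision, or falls back to the latest valid tag
    revision = api.get_valid_revision('some_group/some_model', revision=None)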
@@ -275,7 +453,11 @@ class HubApi:
         if root is not None:
             path = path + f'&Root={root}'

-        r = requests.get(path, cookies=cookies, headers=headers)
+        r = requests.get(
+            path, cookies=cookies, headers={
+                **headers,
+                **self.headers
+            })
         handle_http_response(r, logger, cookies, model_id)
         d = r.json()
@@ -290,11 +472,10 @@ class HubApi:
         return files

     def list_datasets(self):
-        path = f'{self.dataset_endpoint}/api/v1/datasets'
-        headers = None
+        path = f'{self.endpoint}/api/v1/datasets'
         params = {}
-        r = requests.get(path, params=params, headers=headers)
-        r.raise_for_status()
+        r = requests.get(path, params=params, headers=self.headers)
+        raise_for_http_status(r)
         dataset_list = r.json()[API_RESPONSE_FIELD_DATA]
         return [x['Name'] for x in dataset_list]
@@ -317,14 +498,14 @@ class HubApi:
                 cache_dir):
             shutil.rmtree(cache_dir)
         os.makedirs(cache_dir, exist_ok=True)
-        datahub_url = f'{self.dataset_endpoint}/api/v1/datasets/{namespace}/{dataset_name}'
+        datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}'
         r = requests.get(datahub_url)
         resp = r.json()
         datahub_raise_on_error(datahub_url, resp)
         dataset_id = resp['Data']['Id']
         dataset_type = resp['Data']['Type']
-        datahub_url = f'{self.dataset_endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={revision}'
-        r = requests.get(datahub_url)
+        datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={revision}'
+        r = requests.get(datahub_url, headers=self.headers)
         resp = r.json()
         datahub_raise_on_error(datahub_url, resp)
         file_list = resp['Data']
@@ -341,10 +522,10 @@ class HubApi:
             file_path = file_info['Path']
             extension = os.path.splitext(file_path)[-1]
             if extension in dataset_meta_format:
-                datahub_url = f'{self.dataset_endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \
+                datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \
                               f'Revision={revision}&FilePath={file_path}'
                 r = requests.get(datahub_url)
-                r.raise_for_status()
+                raise_for_http_status(r)
                 local_path = os.path.join(cache_dir, file_path)
                 if os.path.exists(local_path):
                     logger.warning(
@@ -365,7 +546,7 @@ class HubApi:
                               namespace: str,
                               revision: Optional[str] = DEFAULT_DATASET_REVISION):
         if file_name.endswith('.csv'):
-            file_name = f'{self.dataset_endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \
+            file_name = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \
                         f'Revision={revision}&FilePath={file_name}'
         return file_name
@@ -374,7 +555,7 @@ class HubApi:
                                   dataset_name: str,
                                   namespace: str,
                                   revision: Optional[str] = DEFAULT_DATASET_REVISION):
-        datahub_url = f'{self.dataset_endpoint}/api/v1/datasets/{namespace}/{dataset_name}/' \
+        datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/' \
                       f'ststoken?Revision={revision}'
         return self.datahub_remote_call(datahub_url)
| @@ -385,23 +566,39 @@ class HubApi: | |||||
| namespace: str, | namespace: str, | ||||
| revision: Optional[str] = DEFAULT_DATASET_REVISION): | revision: Optional[str] = DEFAULT_DATASET_REVISION): | ||||
| datahub_url = f'{self.dataset_endpoint}/api/v1/datasets/{namespace}/{dataset_name}/' \ | |||||
| datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/' \ | |||||
| f'ststoken?Revision={revision}' | f'ststoken?Revision={revision}' | ||||
| cookies = requests.utils.dict_from_cookiejar(cookies) | cookies = requests.utils.dict_from_cookiejar(cookies) | ||||
| r = requests.get(url=datahub_url, cookies=cookies) | |||||
| r = requests.get( | |||||
| url=datahub_url, cookies=cookies, headers=self.headers) | |||||
| resp = r.json() | resp = r.json() | ||||
| raise_on_error(resp) | raise_on_error(resp) | ||||
| return resp['Data'] | return resp['Data'] | ||||
| def list_oss_dataset_objects(self, dataset_name, namespace, max_limit, | |||||
| is_recursive, is_filter_dir, revision): | |||||
| url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss/tree/?' \ | |||||
| f'MaxLimit={max_limit}&Revision={revision}&Recursive={is_recursive}&FilterDir={is_filter_dir}' | |||||
| cookies = ModelScopeConfig.get_cookies() | |||||
| if cookies: | |||||
| cookies = requests.utils.dict_from_cookiejar(cookies) | |||||
| resp = requests.get(url=url, cookies=cookies) | |||||
| resp = resp.json() | |||||
| raise_on_error(resp) | |||||
| resp = resp['Data'] | |||||
| return resp | |||||
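A minimal usage sketch for the OSS listing call above (the dataset name and namespace below are placeholders):

    api = HubApi()
    objects = api.list_oss_dataset_objects(
        dataset_name='some-dataset',    # hypothetical dataset
        namespace='some-namespace',     # hypothetical namespace
        max_limit=100,
        is_recursive=True,
        is_filter_dir=True,
        revision='master')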
| def on_dataset_download(self, dataset_name: str, namespace: str) -> None: | def on_dataset_download(self, dataset_name: str, namespace: str) -> None: | ||||
| url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/increase' | url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/increase' | ||||
| r = requests.post(url) | |||||
| r.raise_for_status() | |||||
| r = requests.post(url, headers=self.headers) | |||||
| raise_for_http_status(r) | |||||
| @staticmethod | @staticmethod | ||||
| def datahub_remote_call(url): | def datahub_remote_call(url): | ||||
| r = requests.get(url) | |||||
| r = requests.get(url, headers={'user-agent': ModelScopeConfig.get_user_agent()}) | |||||
| resp = r.json() | resp = r.json() | ||||
| datahub_raise_on_error(url, resp) | datahub_raise_on_error(url, resp) | ||||
| return resp['Data'] | return resp['Data'] | ||||
| @@ -415,6 +612,7 @@ class ModelScopeConfig: | |||||
| COOKIES_FILE_NAME = 'cookies' | COOKIES_FILE_NAME = 'cookies' | ||||
| GIT_TOKEN_FILE_NAME = 'git_token' | GIT_TOKEN_FILE_NAME = 'git_token' | ||||
| USER_INFO_FILE_NAME = 'user' | USER_INFO_FILE_NAME = 'user' | ||||
| USER_SESSION_ID_FILE_NAME = 'session' | |||||
| @staticmethod | @staticmethod | ||||
| def make_sure_credential_path_exist(): | def make_sure_credential_path_exist(): | ||||
| @@ -443,6 +641,23 @@ class ModelScopeConfig: | |||||
| return cookies | return cookies | ||||
| return None | return None | ||||
| @staticmethod | |||||
| def get_user_session_id(): | |||||
| session_path = os.path.join(ModelScopeConfig.path_credential, | |||||
| ModelScopeConfig.USER_SESSION_ID_FILE_NAME) | |||||
| session_id = '' | |||||
| if os.path.exists(session_path): | |||||
| with open(session_path, 'rb') as f: | |||||
| session_id = str(f.readline().strip(), encoding='utf-8') | |||||
| if session_id == '' or len(session_id) != 32: | |||||
| session_id = uuid.uuid4().hex | |||||
| ModelScopeConfig.make_sure_credential_path_exist() | |||||
| with open(session_path, 'w+') as wf: | |||||
| wf.write(session_id) | |||||
| return session_id | |||||
| @staticmethod | @staticmethod | ||||
| def save_token(token: str): | def save_token(token: str): | ||||
| ModelScopeConfig.make_sure_credential_path_exist() | ModelScopeConfig.make_sure_credential_path_exist() | ||||
| @@ -491,3 +706,32 @@ class ModelScopeConfig: | |||||
| except FileNotFoundError: | except FileNotFoundError: | ||||
| pass | pass | ||||
| return token | return token | ||||
| @staticmethod | |||||
| def get_user_agent(user_agent: Union[Dict, str, None] = None) -> str: | |||||
| """Formats a user-agent string with basic info about a request. | |||||
| Args: | |||||
| user_agent (`str`, `dict`, *optional*): | |||||
| The user agent info in the form of a dictionary or a single string. | |||||
| Returns: | |||||
| The formatted user-agent string. | |||||
| """ | |||||
| env = 'custom' | |||||
| if MODELSCOPE_ENVIRONMENT in os.environ: | |||||
| env = os.environ[MODELSCOPE_ENVIRONMENT] | |||||
| ua = 'modelscope/%s; python/%s; session_id/%s; platform/%s; processor/%s; env/%s' % ( | |||||
| __version__, | |||||
| platform.python_version(), | |||||
| ModelScopeConfig.get_user_session_id(), | |||||
| platform.platform(), | |||||
| platform.processor(), | |||||
| env, | |||||
| ) | |||||
| if isinstance(user_agent, dict): | |||||
| ua += '; ' + '; '.join(f'{k}/{v}' for k, v in user_agent.items()) | |||||
| elif isinstance(user_agent, str): | |||||
| ua += '; ' + user_agent | |||||
| return ua | |||||
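For illustration, a sketch of the formatted user-agent for custom info; every field value below is a placeholder that depends on the local environment:

    # dict form appends key/value pairs to the base identifier string
    ua = ModelScopeConfig.get_user_agent(user_agent={'invoked_by': 'trainer'})
    # -> 'modelscope/<version>; python/<py-version>; session_id/<32 hex>;
    #     platform/<platform>; processor/<cpu>; env/custom; invoked_by/trainer'
    # str form appends the raw string
    ua = ModelScopeConfig.get_user_agent(user_agent='ci-test')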
| @@ -16,6 +16,9 @@ API_RESPONSE_FIELD_GIT_ACCESS_TOKEN = 'AccessToken' | |||||
| API_RESPONSE_FIELD_USERNAME = 'Username' | API_RESPONSE_FIELD_USERNAME = 'Username' | ||||
| API_RESPONSE_FIELD_EMAIL = 'Email' | API_RESPONSE_FIELD_EMAIL = 'Email' | ||||
| API_RESPONSE_FIELD_MESSAGE = 'Message' | API_RESPONSE_FIELD_MESSAGE = 'Message' | ||||
| MODELSCOPE_ENVIRONMENT = 'MODELSCOPE_ENVIRONMENT' | |||||
| MODELSCOPE_SDK_DEBUG = 'MODELSCOPE_SDK_DEBUG' | |||||
| ONE_YEAR_SECONDS = 365 * 24 * 60 * 60 | |||||
| class Licenses(object): | class Licenses(object): | ||||
| @@ -0,0 +1,339 @@ | |||||
| import urllib.parse | |||||
| from abc import ABC | |||||
| from http import HTTPStatus | |||||
| from typing import Optional | |||||
| import json | |||||
| import requests | |||||
| from attrs import asdict, define, field, validators | |||||
| from modelscope.hub.api import ModelScopeConfig | |||||
| from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, | |||||
| API_RESPONSE_FIELD_MESSAGE) | |||||
| from modelscope.hub.errors import (NotLoginException, NotSupportError, | |||||
| RequestError, handle_http_response, is_ok, | |||||
| raise_for_http_status) | |||||
| from modelscope.hub.utils.utils import get_endpoint | |||||
| from modelscope.utils.logger import get_logger | |||||
| # yapf: enable | |||||
| logger = get_logger() | |||||
| class Accelerator(object): | |||||
| CPU = 'cpu' | |||||
| GPU = 'gpu' | |||||
| class Vendor(object): | |||||
| EAS = 'eas' | |||||
| class EASRegion(object): | |||||
| beijing = 'cn-beijing' | |||||
| hangzhou = 'cn-hangzhou' | |||||
| class EASCpuInstanceType(object): | |||||
| """EAS Cpu Instance TYpe, ref(https://help.aliyun.com/document_detail/144261.html) | |||||
| """ | |||||
| tiny = 'ecs.c6.2xlarge' | |||||
| small = 'ecs.c6.4xlarge' | |||||
| medium = 'ecs.c6.6xlarge' | |||||
| large = 'ecs.c6.8xlarge' | |||||
| class EASGpuInstanceType(object): | |||||
| """EAS Cpu Instance TYpe, ref(https://help.aliyun.com/document_detail/144261.html) | |||||
| """ | |||||
| tiny = 'ecs.gn5-c28g1.7xlarge' | |||||
| small = 'ecs.gn5-c8g1.4xlarge' | |||||
| medium = 'ecs.gn6i-c24g1.12xlarge' | |||||
| large = 'ecs.gn6e-c12g1.3xlarge' | |||||
| def min_smaller_than_max(instance, attribute, value): | |||||
| if value > instance.max_replica: | |||||
| raise ValueError( | |||||
| "'min_replica' value: %s has to be smaller than 'max_replica' value: %s!" | |||||
| % (value, instance.max_replica)) | |||||
| @define | |||||
| class ServiceScalingConfig(object): | |||||
| """Resource scaling config | |||||
| Currently we ignore max_replica | |||||
| Args: | |||||
| max_replica: maximum replica | |||||
| min_replica: minimum replica | |||||
| """ | |||||
| max_replica: int = field(default=1, validator=validators.ge(1)) | |||||
| min_replica: int = field( | |||||
| default=1, validator=[validators.ge(1), min_smaller_than_max]) | |||||
| @define | |||||
| class ServiceResourceConfig(object): | |||||
| """Eas Resource request. | |||||
| Args: | |||||
| accelerator: the accelerator(cpu|gpu) | |||||
| instance_type: the instance type. | |||||
| scaling: The instance scaling config. | |||||
| """ | |||||
| instance_type: str | |||||
| scaling: ServiceScalingConfig | |||||
| accelerator: str = field( | |||||
| default=Accelerator.CPU, | |||||
| validator=validators.in_([Accelerator.CPU, Accelerator.GPU])) | |||||
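A minimal sketch of composing the configs above into a resource request; the values are illustrative and use the instance-type constants defined earlier:

    scaling = ServiceScalingConfig(max_replica=2, min_replica=1)
    resource = ServiceResourceConfig(
        instance_type=EASGpuInstanceType.small,  # one of the EAS GPU types above
        scaling=scaling,
        accelerator=Accelerator.GPU)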
| @define | |||||
| class ServiceProviderParameters(ABC): | |||||
| pass | |||||
| @define | |||||
| class EASDeployParameters(ServiceProviderParameters): | |||||
| """Parameters for EAS Deployment. | |||||
| Args: | |||||
| resource_group: the resource group to deploy to; currently the default group is used. | |||||
| region: The EAS instance region (e.g. cn-hangzhou). | |||||
| access_key_id: The EAS account access key id. | |||||
| access_key_secret: The EAS account access key secret. | |||||
| vendor: must be 'eas'. | |||||
| """ | |||||
| region: str | |||||
| access_key_id: str | |||||
| access_key_secret: str | |||||
| resource_group: Optional[str] = None | |||||
| vendor: str = field( | |||||
| default=Vendor.EAS, validator=validators.in_([Vendor.EAS])) | |||||
| @define | |||||
| class EASListParameters(ServiceProviderParameters): | |||||
| """EAS instance list parameters. | |||||
| Args: | |||||
| resource_group: the resource group to query; currently the default group is used. | |||||
| region: The EAS instance region (e.g. cn-hangzhou). | |||||
| access_key_id: The EAS account access key id. | |||||
| access_key_secret: The EAS account access key secret. | |||||
| vendor: must be 'eas'. | |||||
| """ | |||||
| access_key_id: str | |||||
| access_key_secret: str | |||||
| region: Optional[str] = None | |||||
| resource_group: Optional[str] = None | |||||
| vendor: str = field( | |||||
| default=Vendor.EAS, validator=validators.in_([Vendor.EAS])) | |||||
| @define | |||||
| class DeployServiceParameters(object): | |||||
| """Deploy service parameters | |||||
| Args: | |||||
| instance_name: the name of the service. | |||||
| model_id: the modelscope model_id | |||||
| revision: the modelscope model revision | |||||
| resource: the resource requirement. | |||||
| provider: the cloud service provider. | |||||
| """ | |||||
| instance_name: str | |||||
| model_id: str | |||||
| revision: str | |||||
| resource: ServiceResourceConfig | |||||
| provider: ServiceProviderParameters | |||||
| class AttrsToQueryString(ABC): | |||||
| """Convert the attrs class to json string. | |||||
| Args: | |||||
| """ | |||||
| def to_query_str(self): | |||||
| self_dict = asdict( | |||||
| self.provider, filter=lambda attr, value: value is not None) | |||||
| json_str = json.dumps(self_dict) | |||||
| logger.debug(json_str) | |||||
| safe_str = urllib.parse.quote_plus(json_str) | |||||
| logger.debug(safe_str) | |||||
| query_param = 'provider=%s' % safe_str | |||||
| return query_param | |||||
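To make the encoding concrete, a sketch of the query string this produces; the access keys are placeholders, and None-valued fields are filtered out before JSON encoding:

    params = ListServiceParameters(
        provider=EASListParameters(
            access_key_id='AK_ID', access_key_secret='AK_SECRET'))
    query = params.to_query_str()
    # -> 'provider=%7B%22access_key_id%22...%7D', i.e. provider=<url-encoded
    #    JSON of the non-None provider fields, including vendor='eas'>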
| @define | |||||
| class ListServiceParameters(AttrsToQueryString): | |||||
| provider: ServiceProviderParameters | |||||
| skip: int = 0 | |||||
| limit: int = 100 | |||||
| @define | |||||
| class GetServiceParameters(AttrsToQueryString): | |||||
| provider: ServiceProviderParameters | |||||
| @define | |||||
| class DeleteServiceParameters(AttrsToQueryString): | |||||
| provider: ServiceProviderParameters | |||||
| class ServiceDeployer(object): | |||||
| def __init__(self, endpoint=None): | |||||
| self.endpoint = endpoint if endpoint is not None else get_endpoint() | |||||
| self.headers = {'user-agent': ModelScopeConfig.get_user_agent()} | |||||
| self.cookies = ModelScopeConfig.get_cookies() | |||||
| if self.cookies is None: | |||||
| raise NotLoginException( | |||||
| 'Token does not exist, please login with HubApi first.') | |||||
| # deploy_model | |||||
| def create(self, model_id: str, revision: str, instance_name: str, | |||||
| resource: ServiceResourceConfig, | |||||
| provider: ServiceProviderParameters): | |||||
| """Deploy model to cloud, current we only support PAI EAS, this is an async API , | |||||
| and the deployment could take a while to finish remotely. Please check deploy instance | |||||
| status separately via checking the status. | |||||
| Args: | |||||
| model_id (str): The deployed model id | |||||
| revision (str): The model revision | |||||
| instance_name (str): The deployed model instance name. | |||||
| resource (ServiceResourceConfig): The service resource information. | |||||
| provider (ServiceProviderParameters): The service provider parameter | |||||
| Raises: | |||||
| NotLoginException: To use this API, you need to login first. | |||||
| NotSupportError: The platform is not supported. | |||||
| RequestError: The server returned an error. | |||||
| Returns: | |||||
| Dict: The information of the deployed service instance. | |||||
| """ | |||||
| if provider.vendor != Vendor.EAS: | |||||
| raise NotSupportError( | |||||
| 'Unsupported vendor: %s, only EAS is currently supported.' % | |||||
| (provider.vendor)) | |||||
| create_params = DeployServiceParameters( | |||||
| instance_name=instance_name, | |||||
| model_id=model_id, | |||||
| revision=revision, | |||||
| resource=resource, | |||||
| provider=provider) | |||||
| path = f'{self.endpoint}/api/v1/deployer/endpoint' | |||||
| body = asdict(create_params) | |||||
| r = requests.post( | |||||
| path, json=body, cookies=self.cookies, headers=self.headers) | |||||
| handle_http_response(r, logger, self.cookies, 'create_service') | |||||
| if r.status_code >= HTTPStatus.OK and r.status_code < HTTPStatus.MULTIPLE_CHOICES: | |||||
| if is_ok(r.json()): | |||||
| data = r.json()[API_RESPONSE_FIELD_DATA] | |||||
| return data | |||||
| else: | |||||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||||
| else: | |||||
| raise_for_http_status(r) | |||||
| return None | |||||
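A usage sketch for the deploy call above; the model id, revision, service name, region and access keys are placeholders, and a prior login (so that cookies exist) is assumed:

    deployer = ServiceDeployer()
    info = deployer.create(
        model_id='damo/some-model',      # hypothetical model id
        revision='v1.0.0',               # hypothetical revision tag
        instance_name='my-service',
        resource=ServiceResourceConfig(
            instance_type=EASCpuInstanceType.tiny,
            scaling=ServiceScalingConfig(max_replica=1, min_replica=1)),
        provider=EASDeployParameters(
            region=EASRegion.hangzhou,
            access_key_id='AK_ID',
            access_key_secret='AK_SECRET'))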
| def get(self, instance_name: str, provider: ServiceProviderParameters): | |||||
| """Query the specified instance information. | |||||
| Args: | |||||
| instance_name (str): The deployed instance name. | |||||
| provider (ServiceProviderParameters): The cloud provider information; for EAS, | |||||
| region (e.g. cn-hangzhou), access_key_id and access_key_secret are required. | |||||
| Raises: | |||||
| NotLoginException: To use this API, you need to login first. | |||||
| RequestError: The request failed on the server side. | |||||
| Returns: | |||||
| Dict: The information of the requested service instance. | |||||
| """ | |||||
| params = GetServiceParameters(provider=provider) | |||||
| path = '%s/api/v1/deployer/endpoint/%s?%s' % ( | |||||
| self.endpoint, instance_name, params.to_query_str()) | |||||
| r = requests.get(path, cookies=self.cookies, headers=self.headers) | |||||
| handle_http_response(r, logger, self.cookies, 'get_service') | |||||
| if r.status_code == HTTPStatus.OK: | |||||
| if is_ok(r.json()): | |||||
| data = r.json()[API_RESPONSE_FIELD_DATA] | |||||
| return data | |||||
| else: | |||||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||||
| else: | |||||
| raise_for_http_status(r) | |||||
| return None | |||||
| def delete(self, instance_name: str, provider: ServiceProviderParameters): | |||||
| """Delete deployed model, this api send delete command and return, it will take | |||||
| some to delete, please check through the cloud console. | |||||
| Args: | |||||
| instance_name (str): The instance name you want to delete. | |||||
| provider (ServiceProviderParameters): The cloud provider information; for EAS, | |||||
| region (e.g. cn-hangzhou), access_key_id and access_key_secret are required. | |||||
| Raises: | |||||
| NotLoginException: To call this API, you need to login first. | |||||
| RequestError: The request failed. | |||||
| Returns: | |||||
| Dict: The deleted instance information. | |||||
| """ | |||||
| params = DeleteServiceParameters(provider=provider) | |||||
| path = '%s/api/v1/deployer/endpoint/%s?%s' % ( | |||||
| self.endpoint, instance_name, params.to_query_str()) | |||||
| r = requests.delete(path, cookies=self.cookies, headers=self.headers) | |||||
| handle_http_response(r, logger, self.cookies, 'delete_service') | |||||
| if r.status_code == HTTPStatus.OK: | |||||
| if is_ok(r.json()): | |||||
| data = r.json()[API_RESPONSE_FIELD_DATA] | |||||
| return data | |||||
| else: | |||||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||||
| else: | |||||
| raise_for_http_status(r) | |||||
| return None | |||||
| def list(self, | |||||
| provider: ServiceProviderParameters, | |||||
| skip: int = 0, | |||||
| limit: int = 100): | |||||
| """List deployed model instances. | |||||
| Args: | |||||
| provider (ServiceProviderParameters): The cloud service provider parameter; | |||||
| for EAS, access_key_id and access_key_secret are required. | |||||
| skip: start offset of the list; currently not supported. | |||||
| limit: maximum number of instances to return; currently not supported. | |||||
| Raises: | |||||
| NotLoginException: To use this API, you need to login first. | |||||
| RequestError: The request failed on the server side. | |||||
| Returns: | |||||
| List: List of instance information | |||||
| """ | |||||
| params = ListServiceParameters( | |||||
| provider=provider, skip=skip, limit=limit) | |||||
| path = '%s/api/v1/deployer/endpoint?%s' % (self.endpoint, | |||||
| params.to_query_str()) | |||||
| r = requests.get(path, cookies=self.cookies, headers=self.headers) | |||||
| handle_http_response(r, logger, self.cookies, 'list_service_instances') | |||||
| if r.status_code == HTTPStatus.OK: | |||||
| if is_ok(r.json()): | |||||
| data = r.json()[API_RESPONSE_FIELD_DATA] | |||||
| return data | |||||
| else: | |||||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||||
| else: | |||||
| raise_for_http_status(r) | |||||
| return None | |||||
| @@ -4,6 +4,18 @@ from http import HTTPStatus | |||||
| from requests.exceptions import HTTPError | from requests.exceptions import HTTPError | ||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | |||||
| class NotSupportError(Exception): | |||||
| pass | |||||
| class NoValidRevisionError(Exception): | |||||
| pass | |||||
| class NotExistError(Exception): | class NotExistError(Exception): | ||||
| pass | pass | ||||
| @@ -45,15 +57,25 @@ def is_ok(rsp): | |||||
| return rsp['Code'] == HTTPStatus.OK and rsp['Success'] | return rsp['Code'] == HTTPStatus.OK and rsp['Success'] | ||||
| def handle_http_post_error(response, url, request_body): | |||||
| try: | |||||
| response.raise_for_status() | |||||
| except HTTPError as error: | |||||
| logger.error('Request %s with body: %s raised an exception' % | |||||
| (url, request_body)) | |||||
| raise error | |||||
| def handle_http_response(response, logger, cookies, model_id): | def handle_http_response(response, logger, cookies, model_id): | ||||
| try: | try: | ||||
| response.raise_for_status() | response.raise_for_status() | ||||
| except HTTPError: | |||||
| except HTTPError as error: | |||||
| if cookies is None: | if cookies is None: | ||||
| logger.error( | logger.error( | ||||
| f'Authentication token does not exist, failed to access model {model_id} which may not exist or may be \ | f'Authentication token does not exist, failed to access model {model_id} which may not exist or may be \ | ||||
| private. Please login first.') | private. Please login first.') | ||||
| raise | |||||
| logger.error('Response details: %s' % response.content) | |||||
| raise error | |||||
| def raise_on_error(rsp): | def raise_on_error(rsp): | ||||
| @@ -81,3 +103,33 @@ def datahub_raise_on_error(url, rsp): | |||||
| raise RequestError( | raise RequestError( | ||||
| f"Url = {url}, Status = {rsp.get('status')}, error = {rsp.get('error')}, message = {rsp.get('message')}" | f"Url = {url}, Status = {rsp.get('status')}, error = {rsp.get('error')}, message = {rsp.get('message')}" | ||||
| ) | ) | ||||
| def raise_for_http_status(rsp): | |||||
| """ | |||||
| Attempt to decode utf-8 first since some servers | |||||
| localize reason strings, for invalid utf-8, fall back | |||||
| to decoding with iso-8859-1. | |||||
| """ | |||||
| http_error_msg = '' | |||||
| if isinstance(rsp.reason, bytes): | |||||
| try: | |||||
| reason = rsp.reason.decode('utf-8') | |||||
| except UnicodeDecodeError: | |||||
| reason = rsp.reason.decode('iso-8859-1') | |||||
| else: | |||||
| reason = rsp.reason | |||||
| if 400 <= rsp.status_code < 500: | |||||
| http_error_msg = u'%s Client Error: %s for url: %s' % (rsp.status_code, | |||||
| reason, rsp.url) | |||||
| elif 500 <= rsp.status_code < 600: | |||||
| http_error_msg = u'%s Server Error: %s for url: %s' % (rsp.status_code, | |||||
| reason, rsp.url) | |||||
| if http_error_msg: | |||||
| req = rsp.request | |||||
| if req.method == 'POST': | |||||
| http_error_msg = u'%s, body: %s' % (http_error_msg, req.body) | |||||
| raise HTTPError(http_error_msg, response=rsp) | |||||
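A sketch of the intended call pattern, replacing response.raise_for_status() so that error messages carry the decoded reason (and the request body for POSTs); the URL is a placeholder:

    import requests
    from requests.exceptions import HTTPError

    r = requests.post('https://example.com/api/v1/thing', json={'k': 'v'})
    try:
        raise_for_http_status(r)
    except HTTPError as e:
        print(e)  # '<code> Client Error: <reason> for url: ..., body: ...'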
| @@ -2,29 +2,25 @@ | |||||
| import copy | import copy | ||||
| import os | import os | ||||
| import sys | |||||
| import tempfile | import tempfile | ||||
| from functools import partial | from functools import partial | ||||
| from http.cookiejar import CookieJar | from http.cookiejar import CookieJar | ||||
| from pathlib import Path | from pathlib import Path | ||||
| from typing import Dict, Optional, Union | from typing import Dict, Optional, Union | ||||
| from uuid import uuid4 | |||||
| import requests | import requests | ||||
| from filelock import FileLock | |||||
| from tqdm import tqdm | from tqdm import tqdm | ||||
| from modelscope import __version__ | from modelscope import __version__ | ||||
| from modelscope.hub.api import HubApi, ModelScopeConfig | |||||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION | from modelscope.utils.constant import DEFAULT_MODEL_REVISION | ||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from .api import HubApi, ModelScopeConfig | |||||
| from .constants import FILE_HASH | from .constants import FILE_HASH | ||||
| from .errors import FileDownloadError, NotExistError | from .errors import FileDownloadError, NotExistError | ||||
| from .utils.caching import ModelFileSystemCache | from .utils.caching import ModelFileSystemCache | ||||
| from .utils.utils import (file_integrity_validation, get_cache_dir, | from .utils.utils import (file_integrity_validation, get_cache_dir, | ||||
| get_endpoint, model_id_to_group_owner_name) | get_endpoint, model_id_to_group_owner_name) | ||||
| SESSION_ID = uuid4().hex | |||||
| logger = get_logger() | logger = get_logger() | ||||
| @@ -35,6 +31,7 @@ def model_file_download( | |||||
| cache_dir: Optional[str] = None, | cache_dir: Optional[str] = None, | ||||
| user_agent: Union[Dict, str, None] = None, | user_agent: Union[Dict, str, None] = None, | ||||
| local_files_only: Optional[bool] = False, | local_files_only: Optional[bool] = False, | ||||
| cookies: Optional[CookieJar] = None, | |||||
| ) -> Optional[str]: # pragma: no cover | ) -> Optional[str]: # pragma: no cover | ||||
| """ | """ | ||||
| Download from a given URL and cache it if it's not already present in the | Download from a given URL and cache it if it's not already present in the | ||||
| @@ -105,54 +102,47 @@ def model_file_download( | |||||
| " online, set 'local_files_only' to False.") | " online, set 'local_files_only' to False.") | ||||
| _api = HubApi() | _api = HubApi() | ||||
| headers = {'user-agent': http_user_agent(user_agent=user_agent, )} | |||||
| cookies = ModelScopeConfig.get_cookies() | |||||
| branches, tags = _api.get_model_branches_and_tags( | |||||
| model_id, use_cookies=False if cookies is None else cookies) | |||||
| headers = { | |||||
| 'user-agent': ModelScopeConfig.get_user_agent(user_agent=user_agent) | |||||
| } | |||||
| if cookies is None: | |||||
| cookies = ModelScopeConfig.get_cookies() | |||||
| revision = _api.get_valid_revision( | |||||
| model_id, revision=revision, cookies=cookies) | |||||
| file_to_download_info = None | file_to_download_info = None | ||||
| is_commit_id = False | |||||
| if revision in branches or revision in tags: # The revision is version or tag, | |||||
| # we need to confirm the version is up to date | |||||
| # we need to get the file list to check if the lateast version is cached, if so return, otherwise download | |||||
| model_files = _api.get_model_files( | |||||
| model_id=model_id, | |||||
| revision=revision, | |||||
| recursive=True, | |||||
| use_cookies=False if cookies is None else cookies) | |||||
| for model_file in model_files: | |||||
| if model_file['Type'] == 'tree': | |||||
| continue | |||||
| if model_file['Path'] == file_path: | |||||
| if cache.exists(model_file): | |||||
| return cache.get_file_by_info(model_file) | |||||
| else: | |||||
| file_to_download_info = model_file | |||||
| break | |||||
| if file_to_download_info is None: | |||||
| raise NotExistError('The file path: %s not exist in: %s' % | |||||
| (file_path, model_id)) | |||||
| else: # the revision is commit id. | |||||
| cached_file_path = cache.get_file_by_path_and_commit_id( | |||||
| file_path, revision) | |||||
| if cached_file_path is not None: | |||||
| file_name = os.path.basename(cached_file_path) | |||||
| logger.info( | |||||
| f'File {file_name} already in cache, skip downloading!') | |||||
| return cached_file_path # the file is in cache. | |||||
| is_commit_id = True | |||||
| # we need to confirm the version is up-to-date | |||||
| # we need to get the file list to check if the latest version is cached, if so return, otherwise download | |||||
| model_files = _api.get_model_files( | |||||
| model_id=model_id, | |||||
| revision=revision, | |||||
| recursive=True, | |||||
| use_cookies=False if cookies is None else cookies) | |||||
| for model_file in model_files: | |||||
| if model_file['Type'] == 'tree': | |||||
| continue | |||||
| if model_file['Path'] == file_path: | |||||
| if cache.exists(model_file): | |||||
| logger.info( | |||||
| f'File {model_file["Name"]} already in cache, skip downloading!' | |||||
| ) | |||||
| return cache.get_file_by_info(model_file) | |||||
| else: | |||||
| file_to_download_info = model_file | |||||
| break | |||||
| if file_to_download_info is None: | |||||
| raise NotExistError('The file path: %s not exist in: %s' % | |||||
| (file_path, model_id)) | |||||
| # we need to download again | # we need to download again | ||||
| url_to_download = get_file_download_url(model_id, file_path, revision) | url_to_download = get_file_download_url(model_id, file_path, revision) | ||||
| file_to_download_info = { | file_to_download_info = { | ||||
| 'Path': | |||||
| file_path, | |||||
| 'Revision': | |||||
| revision if is_commit_id else file_to_download_info['Revision'], | |||||
| FILE_HASH: | |||||
| None if (is_commit_id or FILE_HASH not in file_to_download_info) else | |||||
| file_to_download_info[FILE_HASH] | |||||
| 'Path': file_path, | |||||
| 'Revision': file_to_download_info['Revision'], | |||||
| FILE_HASH: file_to_download_info[FILE_HASH] | |||||
| } | } | ||||
| temp_file_name = next(tempfile._get_candidate_names()) | temp_file_name = next(tempfile._get_candidate_names()) | ||||
| @@ -171,25 +161,6 @@ def model_file_download( | |||||
| os.path.join(temporary_cache_dir, temp_file_name)) | os.path.join(temporary_cache_dir, temp_file_name)) | ||||
| def http_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str: | |||||
| """Formats a user-agent string with basic info about a request. | |||||
| Args: | |||||
| user_agent (`str`, `dict`, *optional*): | |||||
| The user agent info in the form of a dictionary or a single string. | |||||
| Returns: | |||||
| The formatted user-agent string. | |||||
| """ | |||||
| ua = f'modelscope/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}' | |||||
| if isinstance(user_agent, dict): | |||||
| ua = '; '.join(f'{k}/{v}' for k, v in user_agent.items()) | |||||
| elif isinstance(user_agent, str): | |||||
| ua = user_agent | |||||
| return ua | |||||
| def get_file_download_url(model_id: str, file_path: str, revision: str): | def get_file_download_url(model_id: str, file_path: str, revision: str): | ||||
| """ | """ | ||||
| Format file download url according to `model_id`, `revision` and `file_path`. | Format file download url according to `model_id`, `revision` and `file_path`. | ||||
| @@ -3,10 +3,9 @@ | |||||
| import os | import os | ||||
| import subprocess | import subprocess | ||||
| from typing import List | from typing import List | ||||
| from xmlrpc.client import Boolean | |||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from .api import ModelScopeConfig | |||||
| from ..utils.constant import MASTER_MODEL_BRANCH | |||||
| from .errors import GitError | from .errors import GitError | ||||
| logger = get_logger() | logger = get_logger() | ||||
| @@ -131,6 +130,7 @@ class GitCommandWrapper(metaclass=Singleton): | |||||
| return response | return response | ||||
| def add_user_info(self, repo_base_dir, repo_name): | def add_user_info(self, repo_base_dir, repo_name): | ||||
| from modelscope.hub.api import ModelScopeConfig | |||||
| user_name, user_email = ModelScopeConfig.get_user_info() | user_name, user_email = ModelScopeConfig.get_user_info() | ||||
| if user_name and user_email: | if user_name and user_email: | ||||
| # config user.name and user.email if exist | # config user.name and user.email if exist | ||||
| @@ -138,8 +138,8 @@ class GitCommandWrapper(metaclass=Singleton): | |||||
| repo_base_dir, repo_name, user_name) | repo_base_dir, repo_name, user_name) | ||||
| response = self._run_git_command(*config_user_name_args.split(' ')) | response = self._run_git_command(*config_user_name_args.split(' ')) | ||||
| logger.debug(response.stdout.decode('utf8')) | logger.debug(response.stdout.decode('utf8')) | ||||
| config_user_email_args = '-C %s/%s config user.name %s' % ( | |||||
| repo_base_dir, repo_name, user_name) | |||||
| config_user_email_args = '-C %s/%s config user.email %s' % ( | |||||
| repo_base_dir, repo_name, user_email) | |||||
| response = self._run_git_command( | response = self._run_git_command( | ||||
| *config_user_email_args.split(' ')) | *config_user_email_args.split(' ')) | ||||
| logger.debug(response.stdout.decode('utf8')) | logger.debug(response.stdout.decode('utf8')) | ||||
| @@ -177,6 +177,18 @@ class GitCommandWrapper(metaclass=Singleton): | |||||
| cmds = ['-C', '%s' % repo_dir, 'checkout', '-b', revision] | cmds = ['-C', '%s' % repo_dir, 'checkout', '-b', revision] | ||||
| return self._run_git_command(*cmds) | return self._run_git_command(*cmds) | ||||
| def get_remote_branches(self, repo_dir: str): | |||||
| cmds = ['-C', '%s' % repo_dir, 'branch', '-r'] | |||||
| rsp = self._run_git_command(*cmds) | |||||
| info = [ | |||||
| line.strip() | |||||
| for line in rsp.stdout.decode('utf8').strip().split(os.linesep) | |||||
| ] | |||||
| if len(info) == 1: | |||||
| return ['/'.join(info[0].split('/')[1:])] | |||||
| else: | |||||
| return ['/'.join(line.split('/')[1:]) for line in info[1:]] | |||||
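For clarity, a sketch of what the parsing above yields, assuming `git branch -r` prints the usual remote listing (the repo path is a placeholder):

    # raw `git branch -r` output:
    #   origin/HEAD -> origin/master
    #   origin/master
    #   origin/v1.0
    # With multiple lines, the first entry (typically the HEAD pointer) is
    # skipped and the remote prefix is stripped:
    branches = GitCommandWrapper().get_remote_branches('/path/to/repo')
    # -> ['master', 'v1.0']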
| def pull(self, repo_dir: str): | def pull(self, repo_dir: str): | ||||
| cmds = ['-C', repo_dir, 'pull'] | cmds = ['-C', repo_dir, 'pull'] | ||||
| return self._run_git_command(*cmds) | return self._run_git_command(*cmds) | ||||
| @@ -216,3 +228,22 @@ class GitCommandWrapper(metaclass=Singleton): | |||||
| files.append(line.split(' ')[-1]) | files.append(line.split(' ')[-1]) | ||||
| return files | return files | ||||
| def tag(self, | |||||
| repo_dir: str, | |||||
| tag_name: str, | |||||
| message: str, | |||||
| ref: str = MASTER_MODEL_BRANCH): | |||||
| cmd_args = [ | |||||
| '-C', repo_dir, 'tag', tag_name, '-m', | |||||
| '"%s"' % message, ref | |||||
| ] | |||||
| rsp = self._run_git_command(*cmd_args) | |||||
| logger.debug(rsp.stdout.decode('utf8')) | |||||
| return rsp | |||||
| def push_tag(self, repo_dir: str, tag_name): | |||||
| cmd_args = ['-C', repo_dir, 'push', 'origin', tag_name] | |||||
| rsp = self._run_git_command(*cmd_args) | |||||
| logger.debug(rsp.stdout.decode('utf8')) | |||||
| return rsp | |||||
| @@ -5,9 +5,9 @@ from typing import Optional | |||||
| from modelscope.hub.errors import GitError, InvalidParameter, NotLoginException | from modelscope.hub.errors import GitError, InvalidParameter, NotLoginException | ||||
| from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, | from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, | ||||
| DEFAULT_MODEL_REVISION) | |||||
| DEFAULT_REPOSITORY_REVISION, | |||||
| MASTER_MODEL_BRANCH) | |||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from .api import ModelScopeConfig | |||||
| from .git import GitCommandWrapper | from .git import GitCommandWrapper | ||||
| from .utils.utils import get_endpoint | from .utils.utils import get_endpoint | ||||
| @@ -21,7 +21,7 @@ class Repository: | |||||
| def __init__(self, | def __init__(self, | ||||
| model_dir: str, | model_dir: str, | ||||
| clone_from: str, | clone_from: str, | ||||
| revision: Optional[str] = DEFAULT_MODEL_REVISION, | |||||
| revision: Optional[str] = DEFAULT_REPOSITORY_REVISION, | |||||
| auth_token: Optional[str] = None, | auth_token: Optional[str] = None, | ||||
| git_path: Optional[str] = None): | git_path: Optional[str] = None): | ||||
| """ | """ | ||||
| @@ -47,6 +47,7 @@ class Repository: | |||||
| err_msg = 'a non-default value of revision cannot be empty.' | err_msg = 'a non-default value of revision cannot be empty.' | ||||
| raise InvalidParameter(err_msg) | raise InvalidParameter(err_msg) | ||||
| from modelscope.hub.api import ModelScopeConfig | |||||
| if auth_token: | if auth_token: | ||||
| self.auth_token = auth_token | self.auth_token = auth_token | ||||
| else: | else: | ||||
| @@ -89,7 +90,8 @@ class Repository: | |||||
| def push(self, | def push(self, | ||||
| commit_message: str, | commit_message: str, | ||||
| branch: Optional[str] = DEFAULT_MODEL_REVISION, | |||||
| local_branch: Optional[str] = DEFAULT_REPOSITORY_REVISION, | |||||
| remote_branch: Optional[str] = DEFAULT_REPOSITORY_REVISION, | |||||
| force: bool = False): | force: bool = False): | ||||
| """Push local files to remote, this method will do. | """Push local files to remote, this method will do. | ||||
| git pull | git pull | ||||
| @@ -116,14 +118,48 @@ class Repository: | |||||
| url = self.git_wrapper.get_repo_remote_url(self.model_dir) | url = self.git_wrapper.get_repo_remote_url(self.model_dir) | ||||
| self.git_wrapper.pull(self.model_dir) | self.git_wrapper.pull(self.model_dir) | ||||
| self.git_wrapper.add(self.model_dir, all_files=True) | self.git_wrapper.add(self.model_dir, all_files=True) | ||||
| self.git_wrapper.commit(self.model_dir, commit_message) | self.git_wrapper.commit(self.model_dir, commit_message) | ||||
| self.git_wrapper.push( | self.git_wrapper.push( | ||||
| repo_dir=self.model_dir, | repo_dir=self.model_dir, | ||||
| token=self.auth_token, | token=self.auth_token, | ||||
| url=url, | url=url, | ||||
| local_branch=branch, | |||||
| remote_branch=branch) | |||||
| local_branch=local_branch, | |||||
| remote_branch=remote_branch) | |||||
| def tag(self, tag_name: str, message: str, ref: str = MASTER_MODEL_BRANCH): | |||||
| """Create a new tag. | |||||
| Args: | |||||
| tag_name (str): The name of the tag | |||||
| message (str): The tag message. | |||||
| ref (str): The tag reference, can be commit id or branch. | |||||
| """ | |||||
| if tag_name is None or tag_name == '': | |||||
| msg = 'We use tag-based revision, therefore tag_name cannot be None or empty.' | |||||
| raise InvalidParameter(msg) | |||||
| if message is None or message == '': | |||||
| msg = 'We use annotated tags, therefore message cannot be None or empty.' | |||||
| raise InvalidParameter(msg) | |||||
| self.git_wrapper.tag( | |||||
| repo_dir=self.model_dir, | |||||
| tag_name=tag_name, | |||||
| message=message, | |||||
| ref=ref) | |||||
| def tag_and_push(self, | |||||
| tag_name: str, | |||||
| message: str, | |||||
| ref: str = MASTER_MODEL_BRANCH): | |||||
| """Create tag and push to remote | |||||
| Args: | |||||
| tag_name (str): The name of the tag | |||||
| message (str): The tag message. | |||||
| ref (str, optional): The tag ref, can be commit id or branch. Defaults to MASTER_MODEL_BRANCH. | |||||
| """ | |||||
| self.tag(tag_name, message, ref) | |||||
| self.git_wrapper.push_tag(repo_dir=self.model_dir, tag_name=tag_name) | |||||
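A usage sketch for the tagging helpers above; the local dir, model id and tag values are placeholders:

    repo = Repository(
        model_dir='/tmp/my_model',       # hypothetical local directory
        clone_from='damo/some-model')    # hypothetical model id
    repo.tag_and_push(
        tag_name='v1.0.2',
        message='release v1.0.2',
        ref=MASTER_MODEL_BRANCH)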
| class DatasetRepository: | class DatasetRepository: | ||||
| @@ -166,7 +202,7 @@ class DatasetRepository: | |||||
| err_msg = 'a non-default value of revision cannot be empty.' | err_msg = 'a non-default value of revision cannot be empty.' | ||||
| raise InvalidParameter(err_msg) | raise InvalidParameter(err_msg) | ||||
| self.revision = revision | self.revision = revision | ||||
| from modelscope.hub.api import ModelScopeConfig | |||||
| if auth_token: | if auth_token: | ||||
| self.auth_token = auth_token | self.auth_token = auth_token | ||||
| else: | else: | ||||
| @@ -2,16 +2,15 @@ | |||||
| import os | import os | ||||
| import tempfile | import tempfile | ||||
| from http.cookiejar import CookieJar | |||||
| from pathlib import Path | from pathlib import Path | ||||
| from typing import Dict, Optional, Union | from typing import Dict, Optional, Union | ||||
| from modelscope.hub.api import HubApi, ModelScopeConfig | |||||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION | from modelscope.utils.constant import DEFAULT_MODEL_REVISION | ||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from .api import HubApi, ModelScopeConfig | |||||
| from .constants import FILE_HASH | from .constants import FILE_HASH | ||||
| from .errors import NotExistError | |||||
| from .file_download import (get_file_download_url, http_get_file, | |||||
| http_user_agent) | |||||
| from .file_download import get_file_download_url, http_get_file | |||||
| from .utils.caching import ModelFileSystemCache | from .utils.caching import ModelFileSystemCache | ||||
| from .utils.utils import (file_integrity_validation, get_cache_dir, | from .utils.utils import (file_integrity_validation, get_cache_dir, | ||||
| model_id_to_group_owner_name) | model_id_to_group_owner_name) | ||||
| @@ -23,7 +22,8 @@ def snapshot_download(model_id: str, | |||||
| revision: Optional[str] = DEFAULT_MODEL_REVISION, | revision: Optional[str] = DEFAULT_MODEL_REVISION, | ||||
| cache_dir: Union[str, Path, None] = None, | cache_dir: Union[str, Path, None] = None, | ||||
| user_agent: Optional[Union[Dict, str]] = None, | user_agent: Optional[Union[Dict, str]] = None, | ||||
| local_files_only: Optional[bool] = False) -> str: | |||||
| local_files_only: Optional[bool] = False, | |||||
| cookies: Optional[CookieJar] = None) -> str: | |||||
| """Download all files of a repo. | """Download all files of a repo. | ||||
| Downloads a whole snapshot of a repo's files at the specified revision. This | Downloads a whole snapshot of a repo's files at the specified revision. This | ||||
| is useful when you want all files from a repo, because you don't know which | is useful when you want all files from a repo, because you don't know which | ||||
| @@ -81,15 +81,15 @@ def snapshot_download(model_id: str, | |||||
| ) # we can not confirm the cached file is for snapshot 'revision' | ) # we can not confirm the cached file is for snapshot 'revision' | ||||
| else: | else: | ||||
| # make headers | # make headers | ||||
| headers = {'user-agent': http_user_agent(user_agent=user_agent, )} | |||||
| headers = { | |||||
| 'user-agent': | |||||
| ModelScopeConfig.get_user_agent(user_agent=user_agent) | |||||
| } | |||||
| _api = HubApi() | _api = HubApi() | ||||
| cookies = ModelScopeConfig.get_cookies() | |||||
| # get file list from model repo | |||||
| branches, tags = _api.get_model_branches_and_tags( | |||||
| model_id, use_cookies=False if cookies is None else cookies) | |||||
| if revision not in branches and revision not in tags: | |||||
| raise NotExistError('The specified branch or tag : %s not exist!' | |||||
| % revision) | |||||
| if cookies is None: | |||||
| cookies = ModelScopeConfig.get_cookies() | |||||
| revision = _api.get_valid_revision( | |||||
| model_id, revision=revision, cookies=cookies) | |||||
| snapshot_header = headers if 'CI_TEST' in os.environ else { | snapshot_header = headers if 'CI_TEST' in os.environ else { | ||||
| **headers, | **headers, | ||||
| @@ -110,7 +110,7 @@ def snapshot_download(model_id: str, | |||||
| for model_file in model_files: | for model_file in model_files: | ||||
| if model_file['Type'] == 'tree': | if model_file['Type'] == 'tree': | ||||
| continue | continue | ||||
| # check model_file is exist in cache, if exist, skip download, otherwise download | |||||
| # check whether model_file exists in cache; if it does, skip the download, otherwise download | |||||
| if cache.exists(model_file): | if cache.exists(model_file): | ||||
| file_name = os.path.basename(model_file['Name']) | file_name = os.path.basename(model_file['Name']) | ||||
| logger.info( | logger.info( | ||||
| @@ -2,12 +2,12 @@ | |||||
| import hashlib | import hashlib | ||||
| import os | import os | ||||
| from datetime import datetime | |||||
| from typing import Optional | from typing import Optional | ||||
| from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DATA_ENDPOINT, | |||||
| DEFAULT_MODELSCOPE_DOMAIN, | |||||
| from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN, | |||||
| DEFAULT_MODELSCOPE_GROUP, | DEFAULT_MODELSCOPE_GROUP, | ||||
| MODEL_ID_SEPARATOR, | |||||
| MODEL_ID_SEPARATOR, MODELSCOPE_SDK_DEBUG, | |||||
| MODELSCOPE_URL_SCHEME) | MODELSCOPE_URL_SCHEME) | ||||
| from modelscope.hub.errors import FileIntegrityError | from modelscope.hub.errors import FileIntegrityError | ||||
| from modelscope.utils.file_utils import get_default_cache_dir | from modelscope.utils.file_utils import get_default_cache_dir | ||||
| @@ -38,17 +38,24 @@ def get_cache_dir(model_id: Optional[str] = None): | |||||
| base_path, model_id + '/') | base_path, model_id + '/') | ||||
| def get_release_datetime(): | |||||
| if MODELSCOPE_SDK_DEBUG in os.environ: | |||||
| rt = int(round(datetime.now().timestamp())) | |||||
| else: | |||||
| from modelscope import version | |||||
| rt = int( | |||||
| round( | |||||
| datetime.strptime(version.__release_datetime__, | |||||
| '%Y-%m-%d %H:%M:%S').timestamp())) | |||||
| return rt | |||||
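A worked sketch of the timestamp computed above; the release string is illustrative, the real value comes from modelscope.version.__release_datetime__:

    from datetime import datetime
    rt = int(round(datetime.strptime('2022-09-01 12:00:00',
                                     '%Y-%m-%d %H:%M:%S').timestamp()))
    # epoch seconds for the release datetime, compared against model revisions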
| def get_endpoint(): | def get_endpoint(): | ||||
| modelscope_domain = os.getenv('MODELSCOPE_DOMAIN', | modelscope_domain = os.getenv('MODELSCOPE_DOMAIN', | ||||
| DEFAULT_MODELSCOPE_DOMAIN) | DEFAULT_MODELSCOPE_DOMAIN) | ||||
| return MODELSCOPE_URL_SCHEME + modelscope_domain | return MODELSCOPE_URL_SCHEME + modelscope_domain | ||||
| def get_dataset_hub_endpoint(): | |||||
| return os.environ.get('HUB_DATASET_ENDPOINT', | |||||
| DEFAULT_MODELSCOPE_DATA_ENDPOINT) | |||||
| def compute_hash(file_path): | def compute_hash(file_path): | ||||
| BUFFER_SIZE = 1024 * 64 # 64k buffer size | BUFFER_SIZE = 1024 * 64 # 64k buffer size | ||||
| sha256_hash = hashlib.sha256() | sha256_hash = hashlib.sha256() | ||||
| @@ -9,11 +9,14 @@ class Models(object): | |||||
| Model name should only contain model info but not task info. | Model name should only contain model info but not task info. | ||||
| """ | """ | ||||
| # tinynas models | |||||
| tinynas_detection = 'tinynas-detection' | tinynas_detection = 'tinynas-detection' | ||||
| tinynas_damoyolo = 'tinynas-damoyolo' | |||||
| # vision models | # vision models | ||||
| detection = 'detection' | detection = 'detection' | ||||
| realtime_object_detection = 'realtime-object-detection' | realtime_object_detection = 'realtime-object-detection' | ||||
| realtime_video_object_detection = 'realtime-video-object-detection' | |||||
| scrfd = 'scrfd' | scrfd = 'scrfd' | ||||
| classification_model = 'ClassificationModel' | classification_model = 'ClassificationModel' | ||||
| nafnet = 'nafnet' | nafnet = 'nafnet' | ||||
| @@ -27,11 +30,13 @@ class Models(object): | |||||
| face_2d_keypoints = 'face-2d-keypoints' | face_2d_keypoints = 'face-2d-keypoints' | ||||
| panoptic_segmentation = 'swinL-panoptic-segmentation' | panoptic_segmentation = 'swinL-panoptic-segmentation' | ||||
| image_reid_person = 'passvitb' | image_reid_person = 'passvitb' | ||||
| image_inpainting = 'FFTInpainting' | |||||
| video_summarization = 'pgl-video-summarization' | video_summarization = 'pgl-video-summarization' | ||||
| swinL_semantic_segmentation = 'swinL-semantic-segmentation' | swinL_semantic_segmentation = 'swinL-semantic-segmentation' | ||||
| vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation' | vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation' | ||||
| text_driven_segmentation = 'text-driven-segmentation' | text_driven_segmentation = 'text-driven-segmentation' | ||||
| resnet50_bert = 'resnet50-bert' | resnet50_bert = 'resnet50-bert' | ||||
| referring_video_object_segmentation = 'swinT-referring-video-object-segmentation' | |||||
| fer = 'fer' | fer = 'fer' | ||||
| retinaface = 'retinaface' | retinaface = 'retinaface' | ||||
| shop_segmentation = 'shop-segmentation' | shop_segmentation = 'shop-segmentation' | ||||
| @@ -39,14 +44,18 @@ class Models(object): | |||||
| mtcnn = 'mtcnn' | mtcnn = 'mtcnn' | ||||
| ulfd = 'ulfd' | ulfd = 'ulfd' | ||||
| video_inpainting = 'video-inpainting' | video_inpainting = 'video-inpainting' | ||||
| human_wholebody_keypoint = 'human-wholebody-keypoint' | |||||
| hand_static = 'hand-static' | hand_static = 'hand-static' | ||||
| face_human_hand_detection = 'face-human-hand-detection' | face_human_hand_detection = 'face-human-hand-detection' | ||||
| face_emotion = 'face-emotion' | face_emotion = 'face-emotion' | ||||
| product_segmentation = 'product-segmentation' | product_segmentation = 'product-segmentation' | ||||
| image_body_reshaping = 'image-body-reshaping' | |||||
| # EasyCV models | # EasyCV models | ||||
| yolox = 'YOLOX' | yolox = 'YOLOX' | ||||
| segformer = 'Segformer' | segformer = 'Segformer' | ||||
| hand_2d_keypoints = 'HRNet-Hand2D-Keypoints' | |||||
| image_object_detection_auto = 'image-object-detection-auto' | |||||
| # nlp models | # nlp models | ||||
| bert = 'bert' | bert = 'bert' | ||||
| @@ -58,18 +67,22 @@ class Models(object): | |||||
| space_dst = 'space-dst' | space_dst = 'space-dst' | ||||
| space_intent = 'space-intent' | space_intent = 'space-intent' | ||||
| space_modeling = 'space-modeling' | space_modeling = 'space-modeling' | ||||
| star = 'star' | |||||
| star3 = 'star3' | |||||
| space_T_en = 'space-T-en' | |||||
| space_T_cn = 'space-T-cn' | |||||
| tcrf = 'transformer-crf' | tcrf = 'transformer-crf' | ||||
| tcrf_wseg = 'transformer-crf-for-word-segmentation' | |||||
| transformer_softmax = 'transformer-softmax' | transformer_softmax = 'transformer-softmax' | ||||
| lcrf = 'lstm-crf' | lcrf = 'lstm-crf' | ||||
| lcrf_wseg = 'lstm-crf-for-word-segmentation' | |||||
| gcnncrf = 'gcnn-crf' | gcnncrf = 'gcnn-crf' | ||||
| bart = 'bart' | bart = 'bart' | ||||
| gpt3 = 'gpt3' | gpt3 = 'gpt3' | ||||
| gpt_neo = 'gpt-neo' | |||||
| plug = 'plug' | plug = 'plug' | ||||
| bert_for_ds = 'bert-for-document-segmentation' | bert_for_ds = 'bert-for-document-segmentation' | ||||
| ponet = 'ponet' | ponet = 'ponet' | ||||
| T5 = 'T5' | T5 = 'T5' | ||||
| bloom = 'bloom' | |||||
| # audio models | # audio models | ||||
| sambert_hifigan = 'sambert-hifigan' | sambert_hifigan = 'sambert-hifigan' | ||||
| @@ -88,6 +101,10 @@ class Models(object): | |||||
| team = 'team-multi-modal-similarity' | team = 'team-multi-modal-similarity' | ||||
| video_clip = 'video-clip-multi-modal-embedding' | video_clip = 'video-clip-multi-modal-embedding' | ||||
| # science models | |||||
| unifold = 'unifold' | |||||
| unifold_symmetry = 'unifold-symmetry' | |||||
| class TaskModels(object): | class TaskModels(object): | ||||
| # nlp task | # nlp task | ||||
| @@ -96,6 +113,7 @@ class TaskModels(object): | |||||
| information_extraction = 'information-extraction' | information_extraction = 'information-extraction' | ||||
| fill_mask = 'fill-mask' | fill_mask = 'fill-mask' | ||||
| feature_extraction = 'feature-extraction' | feature_extraction = 'feature-extraction' | ||||
| text_generation = 'text-generation' | |||||
| class Heads(object): | class Heads(object): | ||||
| @@ -111,6 +129,8 @@ class Heads(object): | |||||
| token_classification = 'token-classification' | token_classification = 'token-classification' | ||||
| # extraction | # extraction | ||||
| information_extraction = 'information-extraction' | information_extraction = 'information-extraction' | ||||
| # text gen | |||||
| text_generation = 'text-generation' | |||||
| class Pipelines(object): | class Pipelines(object): | ||||
| @@ -144,6 +164,7 @@ class Pipelines(object): | |||||
| salient_detection = 'u2net-salient-detection' | salient_detection = 'u2net-salient-detection' | ||||
| image_classification = 'image-classification' | image_classification = 'image-classification' | ||||
| face_detection = 'resnet-face-detection-scrfd10gkps' | face_detection = 'resnet-face-detection-scrfd10gkps' | ||||
| card_detection = 'resnet-card-detection-scrfd34gkps' | |||||
| ulfd_face_detection = 'manual-face-detection-ulfd' | ulfd_face_detection = 'manual-face-detection-ulfd' | ||||
| facial_expression_recognition = 'vgg19-facial-expression-recognition-fer' | facial_expression_recognition = 'vgg19-facial-expression-recognition-fer' | ||||
| retina_face_detection = 'resnet50-face-detection-retinaface' | retina_face_detection = 'resnet50-face-detection-retinaface' | ||||
| @@ -160,6 +181,7 @@ class Pipelines(object): | |||||
| face_image_generation = 'gan-face-image-generation' | face_image_generation = 'gan-face-image-generation' | ||||
| product_retrieval_embedding = 'resnet50-product-retrieval-embedding' | product_retrieval_embedding = 'resnet50-product-retrieval-embedding' | ||||
| realtime_object_detection = 'cspnet_realtime-object-detection_yolox' | realtime_object_detection = 'cspnet_realtime-object-detection_yolox' | ||||
| realtime_video_object_detection = 'cspnet_realtime-video-object-detection_streamyolo' | |||||
| face_recognition = 'ir101-face-recognition-cfglint' | face_recognition = 'ir101-face-recognition-cfglint' | ||||
| image_instance_segmentation = 'cascade-mask-rcnn-swin-image-instance-segmentation' | image_instance_segmentation = 'cascade-mask-rcnn-swin-image-instance-segmentation' | ||||
| image2image_translation = 'image-to-image-translation' | image2image_translation = 'image-to-image-translation' | ||||
| @@ -168,6 +190,7 @@ class Pipelines(object): | |||||
| ocr_recognition = 'convnextTiny-ocr-recognition' | ocr_recognition = 'convnextTiny-ocr-recognition' | ||||
| image_portrait_enhancement = 'gpen-image-portrait-enhancement' | image_portrait_enhancement = 'gpen-image-portrait-enhancement' | ||||
| image_to_image_generation = 'image-to-image-generation' | image_to_image_generation = 'image-to-image-generation' | ||||
| image_object_detection_auto = 'yolox_image-object-detection-auto' | |||||
| skin_retouching = 'unet-skin-retouching' | skin_retouching = 'unet-skin-retouching' | ||||
| tinynas_classification = 'tinynas-classification' | tinynas_classification = 'tinynas-classification' | ||||
| tinynas_detection = 'tinynas-detection' | tinynas_detection = 'tinynas-detection' | ||||
| @@ -178,21 +201,32 @@ class Pipelines(object): | |||||
| video_summarization = 'googlenet_pgl_video_summarization' | video_summarization = 'googlenet_pgl_video_summarization' | ||||
| image_semantic_segmentation = 'image-semantic-segmentation' | image_semantic_segmentation = 'image-semantic-segmentation' | ||||
| image_reid_person = 'passvitb-image-reid-person' | image_reid_person = 'passvitb-image-reid-person' | ||||
| image_inpainting = 'fft-inpainting' | |||||
| text_driven_segmentation = 'text-driven-segmentation' | text_driven_segmentation = 'text-driven-segmentation' | ||||
| movie_scene_segmentation = 'resnet50-bert-movie-scene-segmentation' | movie_scene_segmentation = 'resnet50-bert-movie-scene-segmentation' | ||||
| shop_segmentation = 'shop-segmentation' | shop_segmentation = 'shop-segmentation' | ||||
| video_inpainting = 'video-inpainting' | video_inpainting = 'video-inpainting' | ||||
| human_wholebody_keypoint = 'hrnetw48_human-wholebody-keypoint_image' | |||||
| pst_action_recognition = 'patchshift-action-recognition' | pst_action_recognition = 'patchshift-action-recognition' | ||||
| hand_static = 'hand-static' | hand_static = 'hand-static' | ||||
| face_human_hand_detection = 'face-human-hand-detection' | face_human_hand_detection = 'face-human-hand-detection' | ||||
| face_emotion = 'face-emotion' | face_emotion = 'face-emotion' | ||||
| product_segmentation = 'product-segmentation' | product_segmentation = 'product-segmentation' | ||||
| image_body_reshaping = 'flow-based-body-reshaping' | |||||
| referring_video_object_segmentation = 'referring-video-object-segmentation' | |||||
| # nlp tasks | # nlp tasks | ||||
| automatic_post_editing = 'automatic-post-editing' | |||||
| translation_quality_estimation = 'translation-quality-estimation' | |||||
| domain_classification = 'domain-classification' | |||||
| sentence_similarity = 'sentence-similarity' | sentence_similarity = 'sentence-similarity' | ||||
| word_segmentation = 'word-segmentation' | word_segmentation = 'word-segmentation' | ||||
| multilingual_word_segmentation = 'multilingual-word-segmentation' | |||||
| word_segmentation_thai = 'word-segmentation-thai' | |||||
| part_of_speech = 'part-of-speech' | part_of_speech = 'part-of-speech' | ||||
| named_entity_recognition = 'named-entity-recognition' | named_entity_recognition = 'named-entity-recognition' | ||||
| named_entity_recognition_thai = 'named-entity-recognition-thai' | |||||
| named_entity_recognition_viet = 'named-entity-recognition-viet' | |||||
| text_generation = 'text-generation' | text_generation = 'text-generation' | ||||
| text2text_generation = 'text2text-generation' | text2text_generation = 'text2text-generation' | ||||
| sentiment_analysis = 'sentiment-analysis' | sentiment_analysis = 'sentiment-analysis' | ||||
| @@ -208,14 +242,18 @@ class Pipelines(object): | |||||
| zero_shot_classification = 'zero-shot-classification' | zero_shot_classification = 'zero-shot-classification' | ||||
| text_error_correction = 'text-error-correction' | text_error_correction = 'text-error-correction' | ||||
| plug_generation = 'plug-generation' | plug_generation = 'plug-generation' | ||||
| gpt3_generation = 'gpt3-generation' | |||||
| faq_question_answering = 'faq-question-answering' | faq_question_answering = 'faq-question-answering' | ||||
| conversational_text_to_sql = 'conversational-text-to-sql' | conversational_text_to_sql = 'conversational-text-to-sql' | ||||
| table_question_answering_pipeline = 'table-question-answering-pipeline' | table_question_answering_pipeline = 'table-question-answering-pipeline' | ||||
| sentence_embedding = 'sentence-embedding' | sentence_embedding = 'sentence-embedding' | ||||
| passage_ranking = 'passage-ranking' | |||||
| text_ranking = 'text-ranking' | |||||
| relation_extraction = 'relation-extraction' | relation_extraction = 'relation-extraction' | ||||
| document_segmentation = 'document-segmentation' | document_segmentation = 'document-segmentation' | ||||
| feature_extraction = 'feature-extraction' | feature_extraction = 'feature-extraction' | ||||
| translation_en_to_de = 'translation_en_to_de' # keep it underscore | |||||
| translation_en_to_ro = 'translation_en_to_ro' # keep it underscore | |||||
| translation_en_to_fr = 'translation_en_to_fr' # keep it underscore | |||||
| # audio tasks | # audio tasks | ||||
| sambert_hifigan_tts = 'sambert-hifigan-tts' | sambert_hifigan_tts = 'sambert-hifigan-tts' | ||||
| @@ -236,6 +274,10 @@ class Pipelines(object): | |||||
| text_to_image_synthesis = 'text-to-image-synthesis' | text_to_image_synthesis = 'text-to-image-synthesis' | ||||
| video_multi_modal_embedding = 'video-multi-modal-embedding' | video_multi_modal_embedding = 'video-multi-modal-embedding' | ||||
| image_text_retrieval = 'image-text-retrieval' | image_text_retrieval = 'image-text-retrieval' | ||||
| ofa_ocr_recognition = 'ofa-ocr-recognition' | |||||
| # science tasks | |||||
| protein_structure = 'unifold-protein-structure' | |||||
| class Trainers(object): | class Trainers(object): | ||||
| @@ -253,12 +295,16 @@ class Trainers(object): | |||||
| # multi-modal trainers | # multi-modal trainers | ||||
| clip_multi_modal_embedding = 'clip-multi-modal-embedding' | clip_multi_modal_embedding = 'clip-multi-modal-embedding' | ||||
| ofa = 'ofa' | |||||
| # cv trainers | # cv trainers | ||||
| image_instance_segmentation = 'image-instance-segmentation' | image_instance_segmentation = 'image-instance-segmentation' | ||||
| image_portrait_enhancement = 'image-portrait-enhancement' | image_portrait_enhancement = 'image-portrait-enhancement' | ||||
| video_summarization = 'video-summarization' | video_summarization = 'video-summarization' | ||||
| movie_scene_segmentation = 'movie-scene-segmentation' | movie_scene_segmentation = 'movie-scene-segmentation' | ||||
| face_detection_scrfd = 'face-detection-scrfd' | |||||
| card_detection_scrfd = 'card-detection-scrfd' | |||||
| image_inpainting = 'image-inpainting' | |||||
| # nlp trainers | # nlp trainers | ||||
| bert_sentiment_analysis = 'bert-sentiment-analysis' | bert_sentiment_analysis = 'bert-sentiment-analysis' | ||||
| @@ -266,10 +312,12 @@ class Trainers(object): | |||||
| dialog_intent_trainer = 'dialog-intent-trainer' | dialog_intent_trainer = 'dialog-intent-trainer' | ||||
| nlp_base_trainer = 'nlp-base-trainer' | nlp_base_trainer = 'nlp-base-trainer' | ||||
| nlp_veco_trainer = 'nlp-veco-trainer' | nlp_veco_trainer = 'nlp-veco-trainer' | ||||
| nlp_passage_ranking_trainer = 'nlp-passage-ranking-trainer' | |||||
| nlp_text_ranking_trainer = 'nlp-text-ranking-trainer' | |||||
| text_generation_trainer = 'text-generation-trainer' | |||||
| # audio trainers | # audio trainers | ||||
| speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' | speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' | ||||
| speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield' | |||||
| class Preprocessors(object): | class Preprocessors(object): | ||||
| @@ -298,8 +346,12 @@ class Preprocessors(object): | |||||
| bert_seq_cls_tokenizer = 'bert-seq-cls-tokenizer' | bert_seq_cls_tokenizer = 'bert-seq-cls-tokenizer' | ||||
| text_gen_tokenizer = 'text-gen-tokenizer' | text_gen_tokenizer = 'text-gen-tokenizer' | ||||
| text2text_gen_preprocessor = 'text2text-gen-preprocessor' | text2text_gen_preprocessor = 'text2text-gen-preprocessor' | ||||
| text_gen_jieba_tokenizer = 'text-gen-jieba-tokenizer' | |||||
| text2text_translate_preprocessor = 'text2text-translate-preprocessor' | |||||
| token_cls_tokenizer = 'token-cls-tokenizer' | token_cls_tokenizer = 'token-cls-tokenizer' | ||||
| ner_tokenizer = 'ner-tokenizer' | ner_tokenizer = 'ner-tokenizer' | ||||
| thai_ner_tokenizer = 'thai-ner-tokenizer' | |||||
| viet_ner_tokenizer = 'viet-ner-tokenizer' | |||||
| nli_tokenizer = 'nli-tokenizer' | nli_tokenizer = 'nli-tokenizer' | ||||
| sen_cls_tokenizer = 'sen-cls-tokenizer' | sen_cls_tokenizer = 'sen-cls-tokenizer' | ||||
| dialog_intent_preprocessor = 'dialog-intent-preprocessor' | dialog_intent_preprocessor = 'dialog-intent-preprocessor' | ||||
| @@ -309,9 +361,10 @@ class Preprocessors(object): | |||||
| zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer' | zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer' | ||||
| text_error_correction = 'text-error-correction' | text_error_correction = 'text-error-correction' | ||||
| sentence_embedding = 'sentence-embedding' | sentence_embedding = 'sentence-embedding' | ||||
| passage_ranking = 'passage-ranking' | |||||
| text_ranking = 'text-ranking' | |||||
| sequence_labeling_tokenizer = 'sequence-labeling-tokenizer' | sequence_labeling_tokenizer = 'sequence-labeling-tokenizer' | ||||
| word_segment_text_to_label_preprocessor = 'word-segment-text-to-label-preprocessor' | word_segment_text_to_label_preprocessor = 'word-segment-text-to-label-preprocessor' | ||||
| thai_wseg_tokenizer = 'thai-wseg-tokenizer' | |||||
| fill_mask = 'fill-mask' | fill_mask = 'fill-mask' | ||||
| fill_mask_ponet = 'fill-mask-ponet' | fill_mask_ponet = 'fill-mask-ponet' | ||||
| faq_question_answering_preprocessor = 'faq-question-answering-preprocessor' | faq_question_answering_preprocessor = 'faq-question-answering-preprocessor' | ||||
| @@ -320,6 +373,7 @@ class Preprocessors(object): | |||||
| re_tokenizer = 're-tokenizer' | re_tokenizer = 're-tokenizer' | ||||
| document_segmentation = 'document-segmentation' | document_segmentation = 'document-segmentation' | ||||
| feature_extraction = 'feature-extraction' | feature_extraction = 'feature-extraction' | ||||
| sentence_piece = 'sentence-piece' | |||||
| # audio preprocessor | # audio preprocessor | ||||
| linear_aec_fbank = 'linear-aec-fbank' | linear_aec_fbank = 'linear-aec-fbank' | ||||
| @@ -331,6 +385,9 @@ class Preprocessors(object): | |||||
| ofa_tasks_preprocessor = 'ofa-tasks-preprocessor' | ofa_tasks_preprocessor = 'ofa-tasks-preprocessor' | ||||
| mplug_tasks_preprocessor = 'mplug-tasks-preprocessor' | mplug_tasks_preprocessor = 'mplug-tasks-preprocessor' | ||||
| # science preprocessor | |||||
| unifold_preprocessor = 'unifold-preprocessor' | |||||
| class Metrics(object): | class Metrics(object): | ||||
| """ Names for different metrics. | """ Names for different metrics. | ||||
| @@ -340,6 +397,9 @@ class Metrics(object): | |||||
| accuracy = 'accuracy' | accuracy = 'accuracy' | ||||
| audio_noise_metric = 'audio-noise-metric' | audio_noise_metric = 'audio-noise-metric' | ||||
| # text gen | |||||
| BLEU = 'bleu' | |||||
| # metrics for image denoise task | # metrics for image denoise task | ||||
| image_denoise_metric = 'image-denoise-metric' | image_denoise_metric = 'image-denoise-metric' | ||||
| @@ -358,6 +418,10 @@ class Metrics(object): | |||||
| video_summarization_metric = 'video-summarization-metric' | video_summarization_metric = 'video-summarization-metric' | ||||
| # metric for movie-scene-segmentation task | # metric for movie-scene-segmentation task | ||||
| movie_scene_segmentation_metric = 'movie-scene-segmentation-metric' | movie_scene_segmentation_metric = 'movie-scene-segmentation-metric' | ||||
| # metric for inpainting task | |||||
| image_inpainting_metric = 'image-inpainting-metric' | |||||
| # metric for ocr | |||||
| NED = 'ned' | |||||
| class Optimizers(object): | class Optimizers(object): | ||||
| @@ -399,6 +463,9 @@ class Hooks(object): | |||||
| IterTimerHook = 'IterTimerHook' | IterTimerHook = 'IterTimerHook' | ||||
| EvaluationHook = 'EvaluationHook' | EvaluationHook = 'EvaluationHook' | ||||
| # Compression | |||||
| SparsityHook = 'SparsityHook' | |||||
| class LR_Schedulers(object): | class LR_Schedulers(object): | ||||
| """learning rate scheduler is defined here | """learning rate scheduler is defined here | ||||
| @@ -413,7 +480,10 @@ class Datasets(object): | |||||
| """ Names for different datasets. | """ Names for different datasets. | ||||
| """ | """ | ||||
| ClsDataset = 'ClsDataset' | ClsDataset = 'ClsDataset' | ||||
| Face2dKeypointsDataset = 'Face2dKeypointsDataset' | |||||
| Face2dKeypointsDataset = 'FaceKeypointDataset' | |||||
| HandCocoWholeBodyDataset = 'HandCocoWholeBodyDataset' | |||||
| HumanWholeBodyKeypointDataset = 'WholeBodyCocoTopDownDataset' | |||||
| SegDataset = 'SegDataset' | SegDataset = 'SegDataset' | ||||
| DetDataset = 'DetDataset' | DetDataset = 'DetDataset' | ||||
| DetImagesMixDataset = 'DetImagesMixDataset' | DetImagesMixDataset = 'DetImagesMixDataset' | ||||
| PairedDataset = 'PairedDataset' | |||||
| @@ -17,6 +17,9 @@ if TYPE_CHECKING: | |||||
| from .token_classification_metric import TokenClassificationMetric | from .token_classification_metric import TokenClassificationMetric | ||||
| from .video_summarization_metric import VideoSummarizationMetric | from .video_summarization_metric import VideoSummarizationMetric | ||||
| from .movie_scene_segmentation_metric import MovieSceneSegmentationMetric | from .movie_scene_segmentation_metric import MovieSceneSegmentationMetric | ||||
| from .accuracy_metric import AccuracyMetric | |||||
| from .bleu_metric import BleuMetric | |||||
| from .image_inpainting_metric import ImageInpaintingMetric | |||||
| else: | else: | ||||
| _import_structure = { | _import_structure = { | ||||
| @@ -34,6 +37,9 @@ else: | |||||
| 'token_classification_metric': ['TokenClassificationMetric'], | 'token_classification_metric': ['TokenClassificationMetric'], | ||||
| 'video_summarization_metric': ['VideoSummarizationMetric'], | 'video_summarization_metric': ['VideoSummarizationMetric'], | ||||
| 'movie_scene_segmentation_metric': ['MovieSceneSegmentationMetric'], | 'movie_scene_segmentation_metric': ['MovieSceneSegmentationMetric'], | ||||
| 'image_inpainting_metric': ['ImageInpaintingMetric'], | |||||
| 'accuracy_metric': ['AccuracyMetric'], | |||||
| 'bleu_metric': ['BleuMetric'], | |||||
| } | } | ||||
| import sys | import sys | ||||
| @@ -0,0 +1,52 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from typing import Dict | |||||
| import numpy as np | |||||
| from modelscope.metainfo import Metrics | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.utils.registry import default_group | |||||
| from .base import Metric | |||||
| from .builder import METRICS, MetricKeys | |||||
| @METRICS.register_module(group_key=default_group, module_name=Metrics.accuracy) | |||||
| class AccuracyMetric(Metric): | |||||
| """The metric computation class for classification classes. | |||||
| This metric class calculates accuracy for the whole input batches. | |||||
| """ | |||||
| def __init__(self, *args, **kwargs): | |||||
| super().__init__(*args, **kwargs) | |||||
| self.preds = [] | |||||
| self.labels = [] | |||||
| def add(self, outputs: Dict, inputs: Dict): | |||||
| label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS | |||||
| ground_truths = inputs[label_name] | |||||
| eval_results = outputs[label_name] | |||||
| for key in [ | |||||
| OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES, | |||||
| OutputKeys.LABELS, OutputKeys.SCORES | |||||
| ]: | |||||
| if key in outputs and outputs[key] is not None: | |||||
| eval_results = outputs[key] | |||||
| break | |||||
| assert type(ground_truths) == type(eval_results) | |||||
| for truth in ground_truths: | |||||
| self.labels.append(truth) | |||||
| for result in eval_results: | |||||
| if isinstance(result, str):  # normalize string predictions; check the result itself, not the leaked loop variable `truth` | |||||
| self.preds.append(result.strip().replace(' ', '')) | |||||
| else: | |||||
| self.preds.append(result) | |||||
| def evaluate(self): | |||||
| assert len(self.preds) == len(self.labels) | |||||
| return { | |||||
| MetricKeys.ACCURACY: (np.asarray([ | |||||
| pred == ref for pred, ref in zip(self.preds, self.labels) | |||||
| ])).mean().item() | |||||
| } | |||||
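For context, a minimal standalone sketch of the accumulate-then-reduce pattern this metric follows (plain Python/numpy, no modelscope imports; all names here are illustrative):

```python
import numpy as np

preds, labels = [], []

def add(outputs: dict, inputs: dict) -> None:
    # collect one batch of predictions and ground truths
    labels.extend(inputs['labels'])
    preds.extend(outputs['labels'])

def evaluate() -> dict:
    assert len(preds) == len(labels)
    acc = np.asarray([p == r for p, r in zip(preds, labels)]).mean().item()
    return {'accuracy': acc}

add({'labels': ['cat', 'dog']}, {'labels': ['cat', 'cat']})
print(evaluate())  # {'accuracy': 0.5}
```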
| @@ -35,6 +35,8 @@ class AudioNoiseMetric(Metric): | |||||
| total_loss = avg_loss + avg_amp + avg_phase + avg_sisnr | total_loss = avg_loss + avg_amp + avg_phase + avg_sisnr | ||||
| return { | return { | ||||
| 'total_loss': total_loss.item(), | 'total_loss': total_loss.item(), | ||||
| 'avg_sisnr': avg_sisnr.item(), | |||||
| # the model uses the negative of SI-SNR as a training shortcut; | |||||
| # revert the sign in the evaluation result | |||||
| 'avg_sisnr': -avg_sisnr.item(), | |||||
| MetricKeys.AVERAGE_LOSS: avg_loss.item() | MetricKeys.AVERAGE_LOSS: avg_loss.item() | ||||
| } | } | ||||
| @@ -10,8 +10,8 @@ class Metric(ABC): | |||||
| complex metrics for a specific task with or without other Metric subclasses. | complex metrics for a specific task with or without other Metric subclasses. | ||||
| """ | """ | ||||
| def __init__(self, trainer=None, *args, **kwargs): | |||||
| self.trainer = trainer | |||||
| def __init__(self, *args, **kwargs): | |||||
| pass | |||||
| @abstractmethod | @abstractmethod | ||||
| def add(self, outputs: Dict, inputs: Dict): | def add(self, outputs: Dict, inputs: Dict): | ||||
| @@ -0,0 +1,42 @@ | |||||
| from itertools import zip_longest | |||||
| from typing import Dict | |||||
| import sacrebleu | |||||
| from modelscope.metainfo import Metrics | |||||
| from modelscope.utils.registry import default_group | |||||
| from .base import Metric | |||||
| from .builder import METRICS, MetricKeys | |||||
| EVAL_BLEU_ORDER = 4 | |||||
| @METRICS.register_module(group_key=default_group, module_name=Metrics.BLEU) | |||||
| class BleuMetric(Metric): | |||||
| """The metric computation bleu for text generation classes. | |||||
| This metric class calculates accuracy for the whole input batches. | |||||
| """ | |||||
| def __init__(self, *args, **kwargs): | |||||
| super().__init__(*args, **kwargs) | |||||
| self.eval_tokenized_bleu = kwargs.get('eval_tokenized_bleu', False) | |||||
| self.hyp_name = kwargs.get('hyp_name', 'hyp') | |||||
| self.ref_name = kwargs.get('ref_name', 'ref') | |||||
| self.refs = list() | |||||
| self.hyps = list() | |||||
| def add(self, outputs: Dict, inputs: Dict): | |||||
| self.refs.extend(inputs[self.ref_name]) | |||||
| self.hyps.extend(outputs[self.hyp_name]) | |||||
| def evaluate(self): | |||||
| if self.eval_tokenized_bleu: | |||||
| bleu = sacrebleu.corpus_bleu( | |||||
| self.hyps, list(zip_longest(*self.refs)), tokenize='none') | |||||
| else: | |||||
| bleu = sacrebleu.corpus_bleu(self.hyps, | |||||
| list(zip_longest(*self.refs))) | |||||
| return { | |||||
| MetricKeys.BLEU_4: bleu.score, | |||||
| } | |||||
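The `zip_longest(*self.refs)` call is easy to misread: `sacrebleu.corpus_bleu` expects one reference *stream* per reference set (each aligned with the hypotheses), not one list per sentence, so per-sentence reference lists must be transposed. A small sketch, assuming `sacrebleu` is installed:

```python
from itertools import zip_longest

import sacrebleu

hyps = ['the cat sat on the mat', 'a quick brown fox']
refs = [['the cat sat on the mat'], ['the quick brown fox']]  # one list per sentence

# transpose: one stream of first references, one of second references, ...
bleu = sacrebleu.corpus_bleu(hyps, list(zip_longest(*refs)))
print(bleu.score)  # corpus-level BLEU in [0, 100]
```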
| @@ -18,10 +18,12 @@ class MetricKeys(object): | |||||
| SSIM = 'ssim' | SSIM = 'ssim' | ||||
| AVERAGE_LOSS = 'avg_loss' | AVERAGE_LOSS = 'avg_loss' | ||||
| FScore = 'fscore' | FScore = 'fscore' | ||||
| FID = 'fid' | |||||
| BLEU_1 = 'bleu-1' | BLEU_1 = 'bleu-1' | ||||
| BLEU_4 = 'bleu-4' | BLEU_4 = 'bleu-4' | ||||
| ROUGE_1 = 'rouge-1' | ROUGE_1 = 'rouge-1' | ||||
| ROUGE_L = 'rouge-l' | ROUGE_L = 'rouge-l' | ||||
| NED = 'ned' # ocr metric | |||||
| task_default_metrics = { | task_default_metrics = { | ||||
| @@ -31,6 +33,7 @@ task_default_metrics = { | |||||
| Tasks.sentiment_classification: [Metrics.seq_cls_metric], | Tasks.sentiment_classification: [Metrics.seq_cls_metric], | ||||
| Tasks.token_classification: [Metrics.token_cls_metric], | Tasks.token_classification: [Metrics.token_cls_metric], | ||||
| Tasks.text_generation: [Metrics.text_gen_metric], | Tasks.text_generation: [Metrics.text_gen_metric], | ||||
| Tasks.text_classification: [Metrics.seq_cls_metric], | |||||
| Tasks.image_denoising: [Metrics.image_denoise_metric], | Tasks.image_denoising: [Metrics.image_denoise_metric], | ||||
| Tasks.image_color_enhancement: [Metrics.image_color_enhance_metric], | Tasks.image_color_enhancement: [Metrics.image_color_enhance_metric], | ||||
| Tasks.image_portrait_enhancement: | Tasks.image_portrait_enhancement: | ||||
| @@ -39,6 +42,7 @@ task_default_metrics = { | |||||
| Tasks.image_captioning: [Metrics.text_gen_metric], | Tasks.image_captioning: [Metrics.text_gen_metric], | ||||
| Tasks.visual_question_answering: [Metrics.text_gen_metric], | Tasks.visual_question_answering: [Metrics.text_gen_metric], | ||||
| Tasks.movie_scene_segmentation: [Metrics.movie_scene_segmentation_metric], | Tasks.movie_scene_segmentation: [Metrics.movie_scene_segmentation_metric], | ||||
| Tasks.image_inpainting: [Metrics.image_inpainting_metric], | |||||
| } | } | ||||
| @@ -0,0 +1 @@ | |||||
| __author__ = 'tylin' | |||||
| @@ -0,0 +1,57 @@ | |||||
| # Filename: ciderD.py | |||||
| # | |||||
| # Description: Describes the class to compute the CIDEr-D (Consensus-Based Image Description Evaluation) Metric | |||||
| # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) | |||||
| # | |||||
| # Creation Date: Sun Feb 8 14:16:54 2015 | |||||
| # | |||||
| # Authors: Ramakrishna Vedantam <vrama91@vt.edu> and Tsung-Yi Lin <tl483@cornell.edu> | |||||
| from __future__ import absolute_import, division, print_function | |||||
| from .ciderD_scorer import CiderScorer | |||||
| class CiderD: | |||||
| """ | |||||
| Main Class to compute the CIDEr metric | |||||
| """ | |||||
| def __init__(self, n=4, sigma=6.0, df='corpus'): | |||||
| # set cider to sum over 1 to 4-grams | |||||
| self._n = n | |||||
| # set the standard deviation parameter for gaussian penalty | |||||
| self._sigma = sigma | |||||
| # set which where to compute document frequencies from | |||||
| self._df = df | |||||
| self.cider_scorer = CiderScorer(n=self._n, df_mode=self._df) | |||||
| def compute_score(self, gts, res): | |||||
| """ | |||||
| Main function to compute CIDEr score | |||||
| :param gts (dict) : dictionary mapping <image_id> to a list of tokenized reference sentences | |||||
| res (list of dict) : each entry holds an <image_id> and a single-element 'caption' list with the tokenized hypothesis | |||||
| :return: cider (float) : computed CIDEr score for the corpus | |||||
| """ # noqa | |||||
| # clear all the previous hypos and refs | |||||
| tmp_cider_scorer = self.cider_scorer.copy_empty() | |||||
| tmp_cider_scorer.clear() | |||||
| for res_id in res: | |||||
| hypo = res_id['caption'] | |||||
| ref = gts[res_id['image_id']] | |||||
| # Sanity check. | |||||
| assert (type(hypo) is list) | |||||
| assert (len(hypo) == 1) | |||||
| assert (type(ref) is list) | |||||
| assert (len(ref) > 0) | |||||
| tmp_cider_scorer += (hypo[0], ref) | |||||
| (score, scores) = tmp_cider_scorer.compute_score() | |||||
| return score, scores | |||||
| def method(self): | |||||
| return 'CIDEr-D' | |||||
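A minimal usage sketch of the input format that `compute_score` validates; the import path is an assumption (point it at wherever this file and `ciderD_scorer.py` live on your `PYTHONPATH`):

```python
from ciderD import CiderD  # assumed module name for the file above

scorer = CiderD(n=4, sigma=6.0, df='corpus')
gts = {0: ['a man rides a horse', 'a person riding a horse'],
       1: ['a cat sleeps on a sofa']}
res = [{'image_id': 0, 'caption': ['a man riding a horse']},
       {'image_id': 1, 'caption': ['a cat sleeping on a couch']}]

corpus_score, per_image = scorer.compute_score(gts, res)
print(corpus_score, per_image.shape)  # mean CIDEr-D score, one score per image
```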
| @@ -0,0 +1,233 @@ | |||||
| #!/usr/bin/env python | |||||
| # Tsung-Yi Lin <tl483@cornell.edu> | |||||
| # Ramakrishna Vedantam <vrama91@vt.edu> | |||||
| from __future__ import absolute_import, division, print_function | |||||
| import copy | |||||
| import math | |||||
| import os | |||||
| import pdb | |||||
| from collections import defaultdict | |||||
| import numpy as np | |||||
| import six | |||||
| from six.moves import cPickle | |||||
| def precook(s, n=4, out=False): | |||||
| """ | |||||
| Takes a string as input and returns an object that can be given to | |||||
| either cook_refs or cook_test. This is optional: cook_refs and cook_test | |||||
| can take string arguments as well. | |||||
| :param s: string : sentence to be converted into ngrams | |||||
| :param n: int : number of ngrams for which representation is calculated | |||||
| :return: term frequency vector for occuring ngrams | |||||
| """ | |||||
| words = s.split() | |||||
| counts = defaultdict(int) | |||||
| for k in range(1, n + 1): | |||||
| for i in range(len(words) - k + 1): | |||||
| ngram = tuple(words[i:i + k]) | |||||
| counts[ngram] += 1 | |||||
| return counts | |||||
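`precook` simply counts every n-gram up to order `n`; a standalone rerun of the same logic makes the output shape obvious:

```python
from collections import defaultdict

def ngram_counts(s: str, n: int = 4) -> dict:
    # same logic as precook() above, reproduced for a self-contained demo
    words = s.split()
    counts = defaultdict(int)
    for k in range(1, n + 1):
        for i in range(len(words) - k + 1):
            counts[tuple(words[i:i + k])] += 1
    return counts

counts = ngram_counts('the cat sat', n=2)
print(counts[('the',)], counts[('the', 'cat')])  # 1 1
```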
| def cook_refs(refs, n=4): # lhuang: oracle will call with "average" | |||||
| '''Takes a list of reference sentences for a single segment | |||||
| and returns an object that encapsulates everything that BLEU | |||||
| needs to know about them. | |||||
| :param refs: list of string : reference sentences for some image | |||||
| :param n: int : number of ngrams for which (ngram) representation is calculated | |||||
| :return: result (list of dict) | |||||
| ''' | |||||
| return [precook(ref, n) for ref in refs] | |||||
| def cook_test(test, n=4): | |||||
| '''Takes a test sentence and returns an object that | |||||
| encapsulates everything that BLEU needs to know about it. | |||||
| :param test: list of string : hypothesis sentence for some image | |||||
| :param n: int : number of ngrams for which (ngram) representation is calculated | |||||
| :return: result (dict) | |||||
| ''' | |||||
| return precook(test, n, True) | |||||
| class CiderScorer(object): | |||||
| """CIDEr scorer. | |||||
| """ | |||||
| def copy(self): | |||||
| ''' copy the refs.''' | |||||
| new = CiderScorer(n=self.n) | |||||
| new.ctest = copy.copy(self.ctest) | |||||
| new.crefs = copy.copy(self.crefs) | |||||
| return new | |||||
| def copy_empty(self): | |||||
| new = CiderScorer(df_mode='corpus', n=self.n, sigma=self.sigma) | |||||
| new.df_mode = self.df_mode | |||||
| new.ref_len = self.ref_len | |||||
| new.document_frequency = self.document_frequency | |||||
| return new | |||||
| def __init__(self, df_mode='corpus', test=None, refs=None, n=4, sigma=6.0): | |||||
| ''' singular instance ''' | |||||
| self.n = n | |||||
| self.sigma = sigma | |||||
| self.crefs = [] | |||||
| self.ctest = [] | |||||
| self.df_mode = df_mode | |||||
| self.ref_len = None | |||||
| if self.df_mode != 'corpus': | |||||
| pkl_file = cPickle.load( | |||||
| open(df_mode, 'rb'), | |||||
| **(dict(encoding='latin1') if six.PY3 else {})) | |||||
| self.ref_len = np.log(float(pkl_file['ref_len'])) | |||||
| self.document_frequency = pkl_file['document_frequency'] | |||||
| else: | |||||
| self.document_frequency = None | |||||
| self.cook_append(test, refs) | |||||
| def clear(self): | |||||
| self.crefs = [] | |||||
| self.ctest = [] | |||||
| def cook_append(self, test, refs): | |||||
| '''called by constructor and __iadd__ to avoid creating new instances.''' | |||||
| if refs is not None: | |||||
| self.crefs.append(cook_refs(refs)) | |||||
| if test is not None: | |||||
| self.ctest.append(cook_test(test)) # N.B.: -1 | |||||
| else: | |||||
| self.ctest.append( | |||||
| None) # lens of crefs and ctest have to match | |||||
| def size(self): | |||||
| assert len(self.crefs) == len( | |||||
| self.ctest), 'refs/test mismatch! %d<>%d' % (len( | |||||
| self.crefs), len(self.ctest)) | |||||
| return len(self.crefs) | |||||
| def __iadd__(self, other): | |||||
| '''add an instance (e.g., from another sentence).''' | |||||
| if type(other) is tuple: | |||||
| # avoid creating new CiderScorer instances | |||||
| self.cook_append(other[0], other[1]) | |||||
| else: | |||||
| self.ctest.extend(other.ctest) | |||||
| self.crefs.extend(other.crefs) | |||||
| return self | |||||
| def compute_doc_freq(self): | |||||
| """ | |||||
| Compute document frequency over the reference data. | |||||
| This will later be used to compute idf (inverse document frequency). | |||||
| The document frequency is stored in the object. | |||||
| :return: None | |||||
| """ | |||||
| for refs in self.crefs: | |||||
| # refs, k ref captions of one image | |||||
| for ngram in set([ | |||||
| ngram for ref in refs for (ngram, count) in ref.items() | |||||
| ]): # noqa | |||||
| self.document_frequency[ngram] += 1 | |||||
| def compute_cider(self): | |||||
| def counts2vec(cnts): | |||||
| """ | |||||
| Function maps ngram counts to a vector of tf-idf weights. | |||||
| The function returns vec, an array of dictionaries that store the mapping from n-grams to tf-idf weights. | |||||
| Each entry of the array corresponds to one n-gram order. | |||||
| :param cnts: | |||||
| :return: vec (array of dict), norm (array of float), length (int) | |||||
| """ | |||||
| vec = [defaultdict(float) for _ in range(self.n)] | |||||
| length = 0 | |||||
| norm = [0.0 for _ in range(self.n)] | |||||
| for (ngram, term_freq) in cnts.items(): | |||||
| # give word count 1 if it doesn't appear in reference corpus | |||||
| df = np.log(max(1.0, self.document_frequency[ngram])) | |||||
| # ngram index | |||||
| n = len(ngram) - 1 | |||||
| # tf (term_freq) * idf (precomputed idf) for n-grams | |||||
| vec[n][ngram] = float(term_freq) * (self.ref_len - df) | |||||
| # compute norm for the vector. the norm will be used for computing similarity | |||||
| norm[n] += pow(vec[n][ngram], 2) | |||||
| if n == 1: | |||||
| length += term_freq | |||||
| norm = [np.sqrt(n) for n in norm] | |||||
| return vec, norm, length | |||||
| def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): | |||||
| ''' | |||||
| Compute the cosine similarity of two vectors. | |||||
| :param vec_hyp: array of dictionary for vector corresponding to hypothesis | |||||
| :param vec_ref: array of dictionary for vector corresponding to reference | |||||
| :param norm_hyp: array of float for vector corresponding to hypothesis | |||||
| :param norm_ref: array of float for vector corresponding to reference | |||||
| :param length_hyp: int containing length of hypothesis | |||||
| :param length_ref: int containing length of reference | |||||
| :return: array of score for each n-grams cosine similarity | |||||
| ''' | |||||
| delta = float(length_hyp - length_ref) | |||||
| # measure cosine similarity | |||||
| val = np.array([0.0 for _ in range(self.n)]) | |||||
| for n in range(self.n): | |||||
| # ngram | |||||
| for (ngram, count) in vec_hyp[n].items(): | |||||
| # vrama91 : added clipping | |||||
| val[n] += min(vec_hyp[n][ngram], | |||||
| vec_ref[n][ngram]) * vec_ref[n][ngram] | |||||
| if (norm_hyp[n] != 0) and (norm_ref[n] != 0): | |||||
| val[n] /= (norm_hyp[n] * norm_ref[n]) | |||||
| assert (not math.isnan(val[n])) | |||||
| # vrama91: added a length based gaussian penalty | |||||
| val[n] *= np.e**(-(delta**2) / (2 * self.sigma**2)) | |||||
| return val | |||||
| # compute log reference length | |||||
| if self.df_mode == 'corpus': | |||||
| self.ref_len = np.log(float(len(self.crefs))) | |||||
| # elif self.df_mode == "coco-val-df": | |||||
| # if coco option selected, use length of coco-val set | |||||
| # self.ref_len = np.log(float(40504)) | |||||
| scores = [] | |||||
| for test, refs in zip(self.ctest, self.crefs): | |||||
| # compute vector for test captions | |||||
| vec, norm, length = counts2vec(test) | |||||
| # compute vector for ref captions | |||||
| score = np.array([0.0 for _ in range(self.n)]) | |||||
| for ref in refs: | |||||
| vec_ref, norm_ref, length_ref = counts2vec(ref) | |||||
| score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) | |||||
| # change by vrama91 - mean of ngram scores, instead of sum | |||||
| score_avg = np.mean(score) | |||||
| # divide by number of references | |||||
| score_avg /= len(refs) | |||||
| # multiply score by 10 | |||||
| score_avg *= 10.0 | |||||
| # append score of an image to the score list | |||||
| scores.append(score_avg) | |||||
| return scores | |||||
| def compute_score(self, option=None, verbose=0): | |||||
| # compute idf | |||||
| if self.df_mode == 'corpus': | |||||
| self.document_frequency = defaultdict(float) | |||||
| self.compute_doc_freq() | |||||
| # assert to check document frequency | |||||
| assert (len(self.ctest) >= max(self.document_frequency.values())) | |||||
| # import json for now and write the corresponding files | |||||
| # compute cider score | |||||
| score = self.compute_cider() | |||||
| # debug | |||||
| # print score | |||||
| return np.mean(np.array(score)), np.array(score) | |||||
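The length-based Gaussian penalty in `sim()` above is worth a quick numeric check: with the default `sigma = 6.0`, a hypothesis/reference length gap of a few words barely matters, while large gaps collapse the score:

```python
import numpy as np

sigma = 6.0
for delta in (0, 3, 6, 12):  # |len(hyp) - len(ref)| in words
    penalty = np.e ** (-(delta ** 2) / (2 * sigma ** 2))
    print(delta, round(float(penalty), 3))
# 0 -> 1.0, 3 -> 0.882, 6 -> 0.607, 12 -> 0.135
```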
| @@ -1,12 +1,16 @@ | |||||
| # ------------------------------------------------------------------------ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| # ------------------------------------------------------------------------ | |||||
| # modified from https://github.com/megvii-research/NAFNet/blob/main/basicsr/metrics/psnr_ssim.py | |||||
| # ------------------------------------------------------------------------ | |||||
| from typing import Dict | from typing import Dict | ||||
| import cv2 | |||||
| import numpy as np | import numpy as np | ||||
| from skimage.metrics import peak_signal_noise_ratio, structural_similarity | |||||
| import torch | |||||
| from modelscope.metainfo import Metrics | from modelscope.metainfo import Metrics | ||||
| from modelscope.utils.registry import default_group | from modelscope.utils.registry import default_group | ||||
| from modelscope.utils.tensor_utils import (torch_nested_detach, | |||||
| torch_nested_numpify) | |||||
| from .base import Metric | from .base import Metric | ||||
| from .builder import METRICS, MetricKeys | from .builder import METRICS, MetricKeys | ||||
| @@ -20,26 +24,249 @@ class ImageDenoiseMetric(Metric): | |||||
| label_name = 'target' | label_name = 'target' | ||||
| def __init__(self): | def __init__(self): | ||||
| super(ImageDenoiseMetric, self).__init__() | |||||
| self.preds = [] | self.preds = [] | ||||
| self.labels = [] | self.labels = [] | ||||
| def add(self, outputs: Dict, inputs: Dict): | def add(self, outputs: Dict, inputs: Dict): | ||||
| ground_truths = outputs[ImageDenoiseMetric.label_name] | ground_truths = outputs[ImageDenoiseMetric.label_name] | ||||
| eval_results = outputs[ImageDenoiseMetric.pred_name] | eval_results = outputs[ImageDenoiseMetric.pred_name] | ||||
| self.preds.append( | |||||
| torch_nested_numpify(torch_nested_detach(eval_results))) | |||||
| self.labels.append( | |||||
| torch_nested_numpify(torch_nested_detach(ground_truths))) | |||||
| self.preds.append(eval_results) | |||||
| self.labels.append(ground_truths) | |||||
| def evaluate(self): | def evaluate(self): | ||||
| psnr_list, ssim_list = [], [] | psnr_list, ssim_list = [], [] | ||||
| for (pred, label) in zip(self.preds, self.labels): | for (pred, label) in zip(self.preds, self.labels): | ||||
| psnr_list.append( | |||||
| peak_signal_noise_ratio(label[0], pred[0], data_range=255)) | |||||
| ssim_list.append( | |||||
| structural_similarity( | |||||
| label[0], pred[0], multichannel=True, data_range=255)) | |||||
| psnr_list.append(calculate_psnr(label[0], pred[0], crop_border=0)) | |||||
| ssim_list.append(calculate_ssim(label[0], pred[0], crop_border=0)) | |||||
| return { | return { | ||||
| MetricKeys.PSNR: np.mean(psnr_list), | MetricKeys.PSNR: np.mean(psnr_list), | ||||
| MetricKeys.SSIM: np.mean(ssim_list) | MetricKeys.SSIM: np.mean(ssim_list) | ||||
| } | } | ||||
| def reorder_image(img, input_order='HWC'): | |||||
| """Reorder images to 'HWC' order. | |||||
| If the input_order is (h, w), return (h, w, 1); | |||||
| If the input_order is (c, h, w), return (h, w, c); | |||||
| If the input_order is (h, w, c), return as it is. | |||||
| Args: | |||||
| img (ndarray): Input image. | |||||
| input_order (str): Whether the input order is 'HWC' or 'CHW'. | |||||
| If the input image shape is (h, w), input_order will not have | |||||
| effects. Default: 'HWC'. | |||||
| Returns: | |||||
| ndarray: reordered image. | |||||
| """ | |||||
| if input_order not in ['HWC', 'CHW']: | |||||
| raise ValueError( | |||||
| f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'" | |||||
| ) | |||||
| if len(img.shape) == 2: | |||||
| img = img[..., None] | |||||
| if input_order == 'CHW': | |||||
| img = img.transpose(1, 2, 0) | |||||
| return img | |||||
| def calculate_psnr(img1, img2, crop_border, input_order='HWC'): | |||||
| """Calculate PSNR (Peak Signal-to-Noise Ratio). | |||||
| Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio | |||||
| Args: | |||||
| img1 (ndarray/tensor): Images with range [0, 255]/[0, 1]. | |||||
| img2 (ndarray/tensor): Images with range [0, 255]/[0, 1]. | |||||
| crop_border (int): Cropped pixels in each edge of an image. These | |||||
| pixels are not involved in the PSNR calculation. | |||||
| input_order (str): Whether the input order is 'HWC' or 'CHW'. | |||||
| Default: 'HWC'. | |||||
| Returns: | |||||
| float: psnr result. | |||||
| """ | |||||
| assert img1.shape == img2.shape, ( | |||||
| f'Image shapes are different: {img1.shape}, {img2.shape}.') | |||||
| if input_order not in ['HWC', 'CHW']: | |||||
| raise ValueError( | |||||
| f'Wrong input_order {input_order}. Supported input_orders are ' | |||||
| '"HWC" and "CHW"') | |||||
| if type(img1) == torch.Tensor: | |||||
| if len(img1.shape) == 4: | |||||
| img1 = img1.squeeze(0) | |||||
| img1 = img1.detach().cpu().numpy().transpose(1, 2, 0) | |||||
| if type(img2) == torch.Tensor: | |||||
| if len(img2.shape) == 4: | |||||
| img2 = img2.squeeze(0) | |||||
| img2 = img2.detach().cpu().numpy().transpose(1, 2, 0) | |||||
| img1 = reorder_image(img1, input_order=input_order) | |||||
| img2 = reorder_image(img2, input_order=input_order) | |||||
| img1 = img1.astype(np.float64) | |||||
| img2 = img2.astype(np.float64) | |||||
| if crop_border != 0: | |||||
| img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] | |||||
| img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] | |||||
| def _psnr(img1, img2): | |||||
| mse = np.mean((img1 - img2)**2) | |||||
| if mse == 0: | |||||
| return float('inf') | |||||
| max_value = 1. if img1.max() <= 1 else 255. | |||||
| return 20. * np.log10(max_value / np.sqrt(mse)) | |||||
| return _psnr(img1, img2) | |||||
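A quick self-contained check of the formula in `_psnr` (data range inferred from the image maximum, `20 * log10(max / sqrt(mse))`):

```python
import numpy as np

rng = np.random.default_rng(0)
img1 = rng.integers(0, 256, size=(32, 32, 3)).astype(np.float64)
img2 = np.clip(img1 + rng.normal(0, 5, size=img1.shape), 0, 255)

mse = np.mean((img1 - img2) ** 2)
psnr = 20.0 * np.log10(255.0 / np.sqrt(mse))
print(round(float(psnr), 2))  # ~34 dB for sigma=5 Gaussian noise
```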
| def calculate_ssim(img1, img2, crop_border, input_order='HWC', ssim3d=True): | |||||
| """Calculate SSIM (structural similarity). | |||||
| Ref: | |||||
| Image quality assessment: From error visibility to structural similarity | |||||
| The results are the same as that of the official released MATLAB code in | |||||
| https://ece.uwaterloo.ca/~z70wang/research/ssim/. | |||||
| For three-channel images, SSIM is calculated for each channel and then | |||||
| averaged. | |||||
| Args: | |||||
| img1 (ndarray): Images with range [0, 255]. | |||||
| img2 (ndarray): Images with range [0, 255]. | |||||
| crop_border (int): Cropped pixels in each edge of an image. These | |||||
| pixels are not involved in the SSIM calculation. | |||||
| input_order (str): Whether the input order is 'HWC' or 'CHW'. | |||||
| Default: 'HWC'. | |||||
| ssim3d (bool): Whether to compute SSIM over all channels jointly with a 3D Gaussian window. Default: True. | |||||
| Returns: | |||||
| float: ssim result. | |||||
| """ | |||||
| assert img1.shape == img2.shape, ( | |||||
| f'Image shapes are different: {img1.shape}, {img2.shape}.') | |||||
| if input_order not in ['HWC', 'CHW']: | |||||
| raise ValueError( | |||||
| f'Wrong input_order {input_order}. Supported input_orders are ' | |||||
| '"HWC" and "CHW"') | |||||
| if type(img1) == torch.Tensor: | |||||
| if len(img1.shape) == 4: | |||||
| img1 = img1.squeeze(0) | |||||
| img1 = img1.detach().cpu().numpy().transpose(1, 2, 0) | |||||
| if type(img2) == torch.Tensor: | |||||
| if len(img2.shape) == 4: | |||||
| img2 = img2.squeeze(0) | |||||
| img2 = img2.detach().cpu().numpy().transpose(1, 2, 0) | |||||
| img1 = reorder_image(img1, input_order=input_order) | |||||
| img2 = reorder_image(img2, input_order=input_order) | |||||
| img1 = img1.astype(np.float64) | |||||
| img2 = img2.astype(np.float64) | |||||
| if crop_border != 0: | |||||
| img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] | |||||
| img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] | |||||
| def _cal_ssim(img1, img2): | |||||
| ssims = [] | |||||
| max_value = 1 if img1.max() <= 1 else 255 | |||||
| with torch.no_grad(): | |||||
| final_ssim = _ssim_3d(img1, img2, max_value) if ssim3d else _ssim( | |||||
| img1, img2, max_value) | |||||
| ssims.append(final_ssim) | |||||
| return np.array(ssims).mean() | |||||
| return _cal_ssim(img1, img2) | |||||
| def _ssim(img, img2, max_value): | |||||
| """Calculate SSIM (structural similarity) for one channel images. | |||||
| It is called by func:`calculate_ssim`. | |||||
| Args: | |||||
| img (ndarray): Images with range [0, 255] with order 'HWC'. | |||||
| img2 (ndarray): Images with range [0, 255] with order 'HWC'. | |||||
| Returns: | |||||
| float: SSIM result. | |||||
| """ | |||||
| c1 = (0.01 * max_value)**2 | |||||
| c2 = (0.03 * max_value)**2 | |||||
| img = img.astype(np.float64) | |||||
| img2 = img2.astype(np.float64) | |||||
| kernel = cv2.getGaussianKernel(11, 1.5) | |||||
| window = np.outer(kernel, kernel.transpose()) | |||||
| mu1 = cv2.filter2D(img, -1, window)[5:-5, | |||||
| 5:-5] # valid mode for window size 11 | |||||
| mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] | |||||
| mu1_sq = mu1**2 | |||||
| mu2_sq = mu2**2 | |||||
| mu1_mu2 = mu1 * mu2 | |||||
| sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq | |||||
| sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq | |||||
| sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 | |||||
| tmp1 = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2) | |||||
| tmp2 = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2) | |||||
| ssim_map = tmp1 / tmp2 | |||||
| return ssim_map.mean() | |||||
| def _3d_gaussian_calculator(img, conv3d): | |||||
| out = conv3d(img.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0) | |||||
| return out | |||||
| def _generate_3d_gaussian_kernel(): | |||||
| kernel = cv2.getGaussianKernel(11, 1.5) | |||||
| window = np.outer(kernel, kernel.transpose()) | |||||
| kernel_3 = cv2.getGaussianKernel(11, 1.5) | |||||
| kernel = torch.tensor(np.stack([window * k for k in kernel_3], axis=0)) | |||||
| conv3d = torch.nn.Conv3d( | |||||
| 1, | |||||
| 1, (11, 11, 11), | |||||
| stride=1, | |||||
| padding=(5, 5, 5), | |||||
| bias=False, | |||||
| padding_mode='replicate') | |||||
| conv3d.weight.requires_grad = False | |||||
| conv3d.weight[0, 0, :, :, :] = kernel | |||||
| return conv3d | |||||
| def _ssim_3d(img1, img2, max_value): | |||||
| """Calculate SSIM (structural similarity) over all channels jointly with a 3D Gaussian window. | |||||
| It is called by func:`calculate_ssim`. | |||||
| Args: | |||||
| img1 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'. | |||||
| img2 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'. | |||||
| Returns: | |||||
| float: ssim result. | |||||
| """ | |||||
| C1 = (0.01 * max_value)**2 | |||||
| C2 = (0.03 * max_value)**2 | |||||
| img1 = img1.astype(np.float64) | |||||
| img2 = img2.astype(np.float64) | |||||
| kernel = _generate_3d_gaussian_kernel().cuda() | |||||
| img1 = torch.tensor(img1).float().cuda() | |||||
| img2 = torch.tensor(img2).float().cuda() | |||||
| mu1 = _3d_gaussian_calculator(img1, kernel) | |||||
| mu2 = _3d_gaussian_calculator(img2, kernel) | |||||
| mu1_sq = mu1**2 | |||||
| mu2_sq = mu2**2 | |||||
| mu1_mu2 = mu1 * mu2 | |||||
| sigma1_sq = _3d_gaussian_calculator(img1**2, kernel) - mu1_sq | |||||
| sigma2_sq = _3d_gaussian_calculator(img2**2, kernel) - mu2_sq | |||||
| sigma12 = _3d_gaussian_calculator(img1 * img2, kernel) - mu1_mu2 | |||||
| tmp1 = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2) | |||||
| tmp2 = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) | |||||
| ssim_map = tmp1 / tmp2 | |||||
| return float(ssim_map.mean()) | |||||
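Note that `_ssim_3d` builds its Gaussian kernel on CUDA, so `calculate_ssim(..., ssim3d=False)` is the CPU-friendly path. A sanity check, assuming the functions above are in scope and `opencv-python` plus `torch` are installed:

```python
import numpy as np

rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(64, 64, 3)).astype(np.float64)

# identical images must score exactly 1.0 on the 2D path
print(calculate_ssim(img, img.copy(), crop_border=0, ssim3d=False))  # 1.0
```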
| @@ -0,0 +1,210 @@ | |||||
| """ | |||||
| Part of the implementation is borrowed and modified from LaMa, publicly available at | |||||
| https://github.com/saic-mdal/lama | |||||
| """ | |||||
| from typing import Dict | |||||
| import numpy as np | |||||
| import torch | |||||
| import torch.nn.functional as F | |||||
| from scipy import linalg | |||||
| from modelscope.metainfo import Metrics | |||||
| from modelscope.models.cv.image_inpainting.modules.inception import InceptionV3 | |||||
| from modelscope.utils.registry import default_group | |||||
| from modelscope.utils.tensor_utils import (torch_nested_detach, | |||||
| torch_nested_numpify) | |||||
| from .base import Metric | |||||
| from .builder import METRICS, MetricKeys | |||||
| def fid_calculate_activation_statistics(act): | |||||
| mu = np.mean(act, axis=0) | |||||
| sigma = np.cov(act, rowvar=False) | |||||
| return mu, sigma | |||||
| def calculate_frechet_distance(activations_pred, activations_target, eps=1e-6): | |||||
| mu1, sigma1 = fid_calculate_activation_statistics(activations_pred) | |||||
| mu2, sigma2 = fid_calculate_activation_statistics(activations_target) | |||||
| diff = mu1 - mu2 | |||||
| # Product might be almost singular | |||||
| covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) | |||||
| if not np.isfinite(covmean).all(): | |||||
| offset = np.eye(sigma1.shape[0]) * eps | |||||
| covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) | |||||
| # Numerical error might give slight imaginary component | |||||
| if np.iscomplexobj(covmean): | |||||
| # if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): | |||||
| if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-2): | |||||
| m = np.max(np.abs(covmean.imag)) | |||||
| raise ValueError('Imaginary component {}'.format(m)) | |||||
| covmean = covmean.real | |||||
| tr_covmean = np.trace(covmean) | |||||
| return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) | |||||
| - 2 * tr_covmean) | |||||
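`calculate_frechet_distance` implements FID's closed form, ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^(1/2)). Two quick checks, assuming the helper above is in scope: identical activations give ~0, and a pure mean shift contributes exactly its squared norm:

```python
import numpy as np

rng = np.random.default_rng(0)
act = rng.normal(size=(512, 8))   # 512 samples of 8-dim activations
shifted = act + 3.0               # shift every dimension's mean by 3

print(abs(calculate_frechet_distance(act, act)) < 1e-6)           # True
print(round(float(calculate_frechet_distance(act, shifted)), 1))  # ~72.0 (= 8 * 3**2)
```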
| class FIDScore(torch.nn.Module): | |||||
| def __init__(self, dims=2048, eps=1e-6): | |||||
| super().__init__() | |||||
| if getattr(FIDScore, '_MODEL', None) is None: | |||||
| block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] | |||||
| FIDScore._MODEL = InceptionV3([block_idx]).eval() | |||||
| self.model = FIDScore._MODEL | |||||
| self.eps = eps | |||||
| self.reset() | |||||
| def forward(self, pred_batch, target_batch, mask=None): | |||||
| activations_pred = self._get_activations(pred_batch) | |||||
| activations_target = self._get_activations(target_batch) | |||||
| self.activations_pred.append(activations_pred.detach().cpu()) | |||||
| self.activations_target.append(activations_target.detach().cpu()) | |||||
| def get_value(self): | |||||
| activations_pred, activations_target = (self.activations_pred, | |||||
| self.activations_target) | |||||
| activations_pred = torch.cat(activations_pred).cpu().numpy() | |||||
| activations_target = torch.cat(activations_target).cpu().numpy() | |||||
| total_distance = calculate_frechet_distance( | |||||
| activations_pred, activations_target, eps=self.eps) | |||||
| self.reset() | |||||
| return total_distance | |||||
| def reset(self): | |||||
| self.activations_pred = [] | |||||
| self.activations_target = [] | |||||
| def _get_activations(self, batch): | |||||
| activations = self.model(batch)[0] | |||||
| if activations.shape[2] != 1 or activations.shape[3] != 1: | |||||
| assert False, \ | |||||
| 'We should not have got here, because Inception always scales inputs to 299x299' | |||||
| activations = activations.squeeze(-1).squeeze(-1) | |||||
| return activations | |||||
| class SSIM(torch.nn.Module): | |||||
| """SSIM. Modified from: | |||||
| https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py | |||||
| """ | |||||
| def __init__(self, window_size=11, size_average=True): | |||||
| super().__init__() | |||||
| self.window_size = window_size | |||||
| self.size_average = size_average | |||||
| self.channel = 1 | |||||
| self.register_buffer('window', | |||||
| self._create_window(window_size, self.channel)) | |||||
| def forward(self, img1, img2): | |||||
| assert len(img1.shape) == 4 | |||||
| channel = img1.size()[1] | |||||
| if channel == self.channel and self.window.data.type( | |||||
| ) == img1.data.type(): | |||||
| window = self.window | |||||
| else: | |||||
| window = self._create_window(self.window_size, channel) | |||||
| window = window.type_as(img1) | |||||
| self.window = window | |||||
| self.channel = channel | |||||
| return self._ssim(img1, img2, window, self.window_size, channel, | |||||
| self.size_average) | |||||
| def _gaussian(self, window_size, sigma): | |||||
| gauss = torch.Tensor([ | |||||
| np.exp(-(x - (window_size // 2))**2 / float(2 * sigma**2)) | |||||
| for x in range(window_size) | |||||
| ]) | |||||
| return gauss / gauss.sum() | |||||
| def _create_window(self, window_size, channel): | |||||
| _1D_window = self._gaussian(window_size, 1.5).unsqueeze(1) | |||||
| _2D_window = _1D_window.mm( | |||||
| _1D_window.t()).float().unsqueeze(0).unsqueeze(0) | |||||
| return _2D_window.expand(channel, 1, window_size, | |||||
| window_size).contiguous() | |||||
| def _ssim(self, | |||||
| img1, | |||||
| img2, | |||||
| window, | |||||
| window_size, | |||||
| channel, | |||||
| size_average=True): | |||||
| mu1 = F.conv2d( | |||||
| img1, window, padding=(window_size // 2), groups=channel) | |||||
| mu2 = F.conv2d( | |||||
| img2, window, padding=(window_size // 2), groups=channel) | |||||
| mu1_sq = mu1.pow(2) | |||||
| mu2_sq = mu2.pow(2) | |||||
| mu1_mu2 = mu1 * mu2 | |||||
| sigma1_sq = F.conv2d( | |||||
| img1 * img1, window, padding=(window_size // 2), | |||||
| groups=channel) - mu1_sq | |||||
| sigma2_sq = F.conv2d( | |||||
| img2 * img2, window, padding=(window_size // 2), | |||||
| groups=channel) - mu2_sq | |||||
| sigma12 = F.conv2d( | |||||
| img1 * img2, window, padding=(window_size // 2), | |||||
| groups=channel) - mu1_mu2 | |||||
| C1 = 0.01**2 | |||||
| C2 = 0.03**2 | |||||
| ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \ | |||||
| ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) | |||||
| if size_average: | |||||
| return ssim_map.mean() | |||||
| return ssim_map.mean(1).mean(1).mean(1) | |||||
| def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, | |||||
| missing_keys, unexpected_keys, error_msgs): | |||||
| return | |||||
| @METRICS.register_module( | |||||
| group_key=default_group, module_name=Metrics.image_inpainting_metric) | |||||
| class ImageInpaintingMetric(Metric): | |||||
| """The metric computation class for image inpainting classes. | |||||
| """ | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.preds = [] | |||||
| self.targets = [] | |||||
| self.SSIM = SSIM(window_size=11, size_average=False).eval() | |||||
| device = 'cuda' if torch.cuda.is_available() else 'cpu' | |||||
| self.FID = FIDScore().to(device) | |||||
| def add(self, outputs: Dict, inputs: Dict): | |||||
| pred = outputs['inpainted'] | |||||
| target = inputs['image'] | |||||
| self.preds.append(torch_nested_detach(pred)) | |||||
| self.targets.append(torch_nested_detach(target)) | |||||
| def evaluate(self): | |||||
| ssim_list = [] | |||||
| for (pred, target) in zip(self.preds, self.targets): | |||||
| ssim_list.append(self.SSIM(pred, target)) | |||||
| self.FID(pred, target) | |||||
| ssim_list = torch_nested_numpify(ssim_list) | |||||
| fid = self.FID.get_value() | |||||
| return {MetricKeys.SSIM: np.mean(ssim_list), MetricKeys.FID: fid} | |||||
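The `SSIM` module above keeps its Gaussian window as a buffer (and skips it during checkpoint loading via the `_load_from_state_dict` no-op). A quick CPU sanity check, assuming torch is installed and the class is in scope:

```python
import torch

ssim = SSIM(window_size=11, size_average=True).eval()
img = torch.rand(1, 3, 64, 64)

with torch.no_grad():
    print(float(ssim(img, img)))                          # identical images -> ~1.0
    print(float(ssim(img, torch.rand_like(img))) < 1.0)   # True
```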
| @@ -1,5 +1,8 @@ | |||||
| # Part of the implementation is borrowed and modified from BasicSR, publicly available at | |||||
| # https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/metrics/psnr_ssim.py | |||||
| from typing import Dict | from typing import Dict | ||||
| import cv2 | |||||
| import numpy as np | import numpy as np | ||||
| from modelscope.metainfo import Metrics | from modelscope.metainfo import Metrics | ||||
| @@ -35,6 +38,7 @@ class ImagePortraitEnhancementMetric(Metric): | |||||
| def add(self, outputs: Dict, inputs: Dict): | def add(self, outputs: Dict, inputs: Dict): | ||||
| ground_truths = outputs['target'] | ground_truths = outputs['target'] | ||||
| eval_results = outputs['pred'] | eval_results = outputs['pred'] | ||||
| self.preds.extend(eval_results) | self.preds.extend(eval_results) | ||||
| self.targets.extend(ground_truths) | self.targets.extend(ground_truths) | ||||
| @@ -0,0 +1,87 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from typing import Dict | |||||
| import numpy as np | |||||
| from modelscope.metainfo import Metrics | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.utils.registry import default_group | |||||
| from .base import Metric | |||||
| from .builder import METRICS, MetricKeys | |||||
| @METRICS.register_module(group_key=default_group, module_name=Metrics.NED) | |||||
| class NedMetric(Metric): | |||||
| """The ned metric computation class for classification classes. | |||||
| This metric class calculates the levenshtein distance between sentences for the whole input batches. | |||||
| """ | |||||
| def __init__(self, *args, **kwargs): | |||||
| super().__init__(*args, **kwargs) | |||||
| self.preds = [] | |||||
| self.labels = [] | |||||
| def add(self, outputs: Dict, inputs: Dict): | |||||
| label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS | |||||
| ground_truths = inputs[label_name] | |||||
| eval_results = outputs[label_name] | |||||
| for key in [ | |||||
| OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES, | |||||
| OutputKeys.LABELS, OutputKeys.SCORES | |||||
| ]: | |||||
| if key in outputs and outputs[key] is not None: | |||||
| eval_results = outputs[key] | |||||
| break | |||||
| assert type(ground_truths) == type(eval_results) | |||||
| if isinstance(ground_truths, list): | |||||
| self.preds.extend(eval_results) | |||||
| self.labels.extend(ground_truths) | |||||
| elif isinstance(ground_truths, np.ndarray): | |||||
| self.preds.extend(eval_results.tolist()) | |||||
| self.labels.extend(ground_truths.tolist()) | |||||
| else: | |||||
| raise Exception('only support list or np.ndarray') | |||||
| def evaluate(self): | |||||
| assert len(self.preds) == len(self.labels) | |||||
| return { | |||||
| MetricKeys.NED: (np.asarray([ | |||||
| 1.0 - NedMetric._distance(pred, ref) | |||||
| for pred, ref in zip(self.preds, self.labels) | |||||
| ])).mean().item() | |||||
| } | |||||
| @staticmethod | |||||
| def _distance(pred, ref): | |||||
| if pred is None or ref is None: | |||||
| raise TypeError('Argument (pred or ref) is NoneType.') | |||||
| if pred == ref: | |||||
| return 0.0 | |||||
| if len(pred) == 0 or len(ref) == 0: | |||||
| # one side is empty, so the normalized distance is maximal | |||||
| return 1.0 | |||||
| m_len = max(len(pred), len(ref)) | |||||
| if m_len == 0: | |||||
| return 0.0 | |||||
| def levenshtein(s0, s1): | |||||
| v0 = [0] * (len(s1) + 1) | |||||
| v1 = [0] * (len(s1) + 1) | |||||
| for i in range(len(v0)): | |||||
| v0[i] = i | |||||
| for i in range(len(s0)): | |||||
| v1[0] = i + 1 | |||||
| for j in range(len(s1)): | |||||
| cost = 1 | |||||
| if s0[i] == s1[j]: | |||||
| cost = 0 | |||||
| v1[j + 1] = min(v1[j] + 1, v0[j + 1] + 1, v0[j] + cost) | |||||
| v0, v1 = v1, v0 | |||||
| return v0[len(s1)] | |||||
| return levenshtein(pred, ref) / m_len | |||||
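A worked example of the normalization: `levenshtein('kitten', 'sitting')` is 3 and the longer string has length 7, so the metric reports a similarity of 1 - 3/7 ≈ 0.571. Standalone rerun of the same algorithm:

```python
def levenshtein(s0: str, s1: str) -> int:
    # same two-row dynamic program as in NedMetric._distance above
    v0 = list(range(len(s1) + 1))
    for i, c0 in enumerate(s0):
        v1 = [i + 1]
        for j, c1 in enumerate(s1):
            v1.append(min(v1[j] + 1, v0[j + 1] + 1, v0[j] + (c0 != c1)))
        v0 = v1
    return v0[-1]

pred, ref = 'kitten', 'sitting'
ned = levenshtein(pred, ref) / max(len(pred), len(ref))
print(round(1.0 - ned, 3))  # 0.571
```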
| @@ -36,20 +36,31 @@ class TextGenerationMetric(Metric): | |||||
| for char in string | for char in string | ||||
| ]).split()) | ]).split()) | ||||
| def add(self, outputs: Dict[str, List[str]], inputs: Dict = None): | |||||
| ground_truths = outputs['tgts'] | |||||
| def add(self, outputs: Dict[str, List[str]], inputs: Dict[str, List[str]]): | |||||
| ground_truths = inputs['tgts'] | |||||
| eval_results = outputs['preds'] | eval_results = outputs['preds'] | ||||
| for truth in ground_truths: | for truth in ground_truths: | ||||
| self.tgts.append(self.rebuild_str(truth)) | self.tgts.append(self.rebuild_str(truth)) | ||||
| for result in eval_results: | for result in eval_results: | ||||
| self.preds.append(self.rebuild_str(result)) | self.preds.append(self.rebuild_str(result)) | ||||
| def _check(self, pred: str, tgt: str) -> bool: | |||||
| def remove_useless(string: str) -> str: | |||||
| return string.replace(' ', '').replace('.', '') | |||||
| return bool(remove_useless(pred)) and bool(remove_useless(tgt)) | |||||
| def evaluate(self): | def evaluate(self): | ||||
| assert self.preds, 'preds in TextGenerationMetric must not be empty!' | |||||
| tmp = [(pred, tgt) for pred, tgt in zip(self.preds, self.tgts) | |||||
| if self._check(pred, tgt)] | |||||
| preds, tgts = zip(*tmp) | |||||
| def mean(iter: Iterable) -> float: | def mean(iter: Iterable) -> float: | ||||
| return sum(iter) / len(self.preds) | return sum(iter) / len(self.preds) | ||||
| rouge_scores = self.rouge.get_scores(hyps=self.preds, refs=self.tgts) | |||||
| rouge_scores = self.rouge.get_scores(hyps=preds, refs=tgts) | |||||
| rouge_1 = mean(map(lambda score: score['rouge-1']['f'], rouge_scores)) | rouge_1 = mean(map(lambda score: score['rouge-1']['f'], rouge_scores)) | ||||
| rouge_l = mean(map(lambda score: score['rouge-l']['f'], rouge_scores)) | rouge_l = mean(map(lambda score: score['rouge-l']['f'], rouge_scores)) | ||||
| pred_split = tuple(pred.split(' ') for pred in self.preds) | pred_split = tuple(pred.split(' ') for pred in self.preds) | ||||
| @@ -34,17 +34,24 @@ class TokenClassificationMetric(Metric): | |||||
| self.labels.append( | self.labels.append( | ||||
| torch_nested_numpify(torch_nested_detach(ground_truths))) | torch_nested_numpify(torch_nested_detach(ground_truths))) | ||||
| def __init__(self, return_entity_level_metrics=False, *args, **kwargs): | |||||
| def __init__(self, | |||||
| return_entity_level_metrics=False, | |||||
| label2id=None, | |||||
| *args, | |||||
| **kwargs): | |||||
| super().__init__(*args, **kwargs) | super().__init__(*args, **kwargs) | ||||
| self.return_entity_level_metrics = return_entity_level_metrics | self.return_entity_level_metrics = return_entity_level_metrics | ||||
| self.preds = [] | self.preds = [] | ||||
| self.labels = [] | self.labels = [] | ||||
| self.label2id = label2id | |||||
| def evaluate(self): | def evaluate(self): | ||||
| self.id2label = { | |||||
| id: label | |||||
| for label, id in self.trainer.label2id.items() | |||||
| } | |||||
| label2id = self.label2id | |||||
| if label2id is None: | |||||
| assert hasattr(self, 'trainer') | |||||
| label2id = self.trainer.label2id | |||||
| self.id2label = {id: label for label, id in label2id.items()} | |||||
| self.preds = np.concatenate(self.preds, axis=0) | self.preds = np.concatenate(self.preds, axis=0) | ||||
| self.labels = np.concatenate(self.labels, axis=0) | self.labels = np.concatenate(self.labels, axis=0) | ||||
| predictions = np.argmax(self.preds, axis=-1) | predictions = np.argmax(self.preds, axis=-1) | ||||
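The change above decouples the metric from the trainer: `label2id` can now be injected directly, and the inversion it performs is a plain dict flip:

```python
label2id = {'O': 0, 'B-PER': 1, 'I-PER': 2}  # illustrative mapping
id2label = {i: label for label, i in label2id.items()}
print(id2label[1])  # B-PER
```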
| @@ -1,3 +1,6 @@ | |||||
| # Part of the implementation is borrowed and modified from PGL-SUM, | |||||
| # publicly available at https://github.com/e-apostolidis/PGL-SUM | |||||
| from typing import Dict | from typing import Dict | ||||
| import numpy as np | import numpy as np | ||||
| @@ -1,3 +1,5 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | import os | ||||
| from typing import Any, Dict | from typing import Any, Dict | ||||
| @@ -1,15 +1,14 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import os | import os | ||||
| from typing import Dict | |||||
| import torch | |||||
| from typing import Dict, Optional | |||||
| from modelscope.metainfo import Models | from modelscope.metainfo import Models | ||||
| from modelscope.models import TorchModel | from modelscope.models import TorchModel | ||||
| from modelscope.models.base import Tensor | from modelscope.models.base import Tensor | ||||
| from modelscope.models.builder import MODELS | from modelscope.models.builder import MODELS | ||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
| from modelscope.utils.audio.audio_utils import update_conf | |||||
| from modelscope.utils.constant import Tasks | |||||
| from .fsmn_sele_v2 import FSMNSeleNetV2 | from .fsmn_sele_v2 import FSMNSeleNetV2 | ||||
| @@ -20,48 +19,38 @@ class FSMNSeleNetV2Decorator(TorchModel): | |||||
| MODEL_TXT = 'model.txt' | MODEL_TXT = 'model.txt' | ||||
| SC_CONFIG = 'sound_connect.conf' | SC_CONFIG = 'sound_connect.conf' | ||||
| SC_CONF_ITEM_KWS_MODEL = '${kws_model}' | |||||
| def __init__(self, model_dir: str, *args, **kwargs): | |||||
| def __init__(self, | |||||
| model_dir: str, | |||||
| training: Optional[bool] = False, | |||||
| *args, | |||||
| **kwargs): | |||||
| """initialize the dfsmn model from the `model_dir` path. | """initialize the dfsmn model from the `model_dir` path. | ||||
| Args: | Args: | ||||
| model_dir (str): the model path. | model_dir (str): the model path. | ||||
| """ | """ | ||||
| super().__init__(model_dir, *args, **kwargs) | super().__init__(model_dir, *args, **kwargs) | ||||
| sc_config_file = os.path.join(model_dir, self.SC_CONFIG) | |||||
| model_txt_file = os.path.join(model_dir, self.MODEL_TXT) | |||||
| model_bin_file = os.path.join(model_dir, | |||||
| ModelFile.TORCH_MODEL_BIN_FILE) | |||||
| self._model = None | |||||
| if os.path.exists(model_bin_file): | |||||
| kwargs.pop('device') | |||||
| self._model = FSMNSeleNetV2(*args, **kwargs) | |||||
| checkpoint = torch.load(model_bin_file) | |||||
| self._model.load_state_dict(checkpoint, strict=False) | |||||
| self._sc = None | |||||
| if os.path.exists(model_txt_file): | |||||
| with open(sc_config_file) as f: | |||||
| lines = f.readlines() | |||||
| with open(sc_config_file, 'w') as f: | |||||
| for line in lines: | |||||
| if self.SC_CONF_ITEM_KWS_MODEL in line: | |||||
| line = line.replace(self.SC_CONF_ITEM_KWS_MODEL, | |||||
| model_txt_file) | |||||
| f.write(line) | |||||
| import py_sound_connect | |||||
| self._sc = py_sound_connect.SoundConnect(sc_config_file) | |||||
| self.size_in = self._sc.bytesPerBlockIn() | |||||
| self.size_out = self._sc.bytesPerBlockOut() | |||||
| if self._model is None and self._sc is None: | |||||
| raise Exception( | |||||
| f'Invalid model directory! Neither {model_txt_file} nor {model_bin_file} exists.' | |||||
| ) | |||||
| if training: | |||||
| self.model = FSMNSeleNetV2(*args, **kwargs) | |||||
| else: | |||||
| sc_config_file = os.path.join(model_dir, self.SC_CONFIG) | |||||
| model_txt_file = os.path.join(model_dir, self.MODEL_TXT) | |||||
| self._sc = None | |||||
| if os.path.exists(model_txt_file): | |||||
| conf_dict = dict(mode=56542, kws_model=model_txt_file) | |||||
| update_conf(sc_config_file, sc_config_file, conf_dict) | |||||
| import py_sound_connect | |||||
| self._sc = py_sound_connect.SoundConnect(sc_config_file) | |||||
| self.size_in = self._sc.bytesPerBlockIn() | |||||
| self.size_out = self._sc.bytesPerBlockOut() | |||||
| else: | |||||
| raise Exception( | |||||
| f'Invalid model directory! Failed to load model file: {model_txt_file}.' | |||||
| ) | |||||
| def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | ||||
| ... | |||||
| return self.model.forward(input) | |||||
| def forward_decode(self, data: bytes): | def forward_decode(self, data: bytes): | ||||
| result = {'pcm': self._sc.process(data, self.size_out)} | result = {'pcm': self._sc.process(data, self.size_out)} | ||||
| @@ -1,3 +1,5 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | import os | ||||
| from typing import Any, Dict | from typing import Any, Dict | ||||
| @@ -2,6 +2,7 @@ | |||||
| import os | import os | ||||
| import pickle as pkl | import pickle as pkl | ||||
| from threading import Lock | |||||
| import json | import json | ||||
| import numpy as np | import numpy as np | ||||
| @@ -27,6 +28,7 @@ class Voice: | |||||
| self.__am_config = AttrDict(**am_config) | self.__am_config = AttrDict(**am_config) | ||||
| self.__voc_config = AttrDict(**voc_config) | self.__voc_config = AttrDict(**voc_config) | ||||
| self.__model_loaded = False | self.__model_loaded = False | ||||
| self.__lock = Lock() | |||||
| if 'am' not in self.__am_config: | if 'am' not in self.__am_config: | ||||
| raise TtsModelConfigurationException( | raise TtsModelConfigurationException( | ||||
| 'modelscope error: am configuration invalid') | 'modelscope error: am configuration invalid') | ||||
| @@ -71,34 +73,35 @@ class Voice: | |||||
| self.__generator.remove_weight_norm() | self.__generator.remove_weight_norm() | ||||
| def __am_forward(self, symbol_seq): | def __am_forward(self, symbol_seq): | ||||
| with torch.no_grad(): | |||||
| inputs_feat_lst = self.__ling_unit.encode_symbol_sequence( | |||||
| symbol_seq) | |||||
| inputs_sy = torch.from_numpy(inputs_feat_lst[0]).long().to( | |||||
| self.__device) | |||||
| inputs_tone = torch.from_numpy(inputs_feat_lst[1]).long().to( | |||||
| self.__device) | |||||
| inputs_syllable = torch.from_numpy(inputs_feat_lst[2]).long().to( | |||||
| self.__device) | |||||
| inputs_ws = torch.from_numpy(inputs_feat_lst[3]).long().to( | |||||
| self.__device) | |||||
| inputs_ling = torch.stack( | |||||
| [inputs_sy, inputs_tone, inputs_syllable, inputs_ws], | |||||
| dim=-1).unsqueeze(0) | |||||
| inputs_emo = torch.from_numpy(inputs_feat_lst[4]).long().to( | |||||
| self.__device).unsqueeze(0) | |||||
| inputs_spk = torch.from_numpy(inputs_feat_lst[5]).long().to( | |||||
| self.__device).unsqueeze(0) | |||||
| inputs_len = torch.zeros(1).to(self.__device).long( | |||||
| ) + inputs_emo.size(1) - 1 # minus 1 for "~" | |||||
| res = self.__am_net(inputs_ling[:, :-1, :], inputs_emo[:, :-1], | |||||
| inputs_spk[:, :-1], inputs_len) | |||||
| postnet_outputs = res['postnet_outputs'] | |||||
| LR_length_rounded = res['LR_length_rounded'] | |||||
| valid_length = int(LR_length_rounded[0].item()) | |||||
| postnet_outputs = postnet_outputs[ | |||||
| 0, :valid_length, :].cpu().numpy() | |||||
| return postnet_outputs | |||||
| with self.__lock: | |||||
| with torch.no_grad(): | |||||
| inputs_feat_lst = self.__ling_unit.encode_symbol_sequence( | |||||
| symbol_seq) | |||||
| inputs_sy = torch.from_numpy(inputs_feat_lst[0]).long().to( | |||||
| self.__device) | |||||
| inputs_tone = torch.from_numpy(inputs_feat_lst[1]).long().to( | |||||
| self.__device) | |||||
| inputs_syllable = torch.from_numpy( | |||||
| inputs_feat_lst[2]).long().to(self.__device) | |||||
| inputs_ws = torch.from_numpy(inputs_feat_lst[3]).long().to( | |||||
| self.__device) | |||||
| inputs_ling = torch.stack( | |||||
| [inputs_sy, inputs_tone, inputs_syllable, inputs_ws], | |||||
| dim=-1).unsqueeze(0) | |||||
| inputs_emo = torch.from_numpy(inputs_feat_lst[4]).long().to( | |||||
| self.__device).unsqueeze(0) | |||||
| inputs_spk = torch.from_numpy(inputs_feat_lst[5]).long().to( | |||||
| self.__device).unsqueeze(0) | |||||
| inputs_len = torch.zeros(1).to(self.__device).long( | |||||
| ) + inputs_emo.size(1) - 1 # minus 1 for "~" | |||||
| res = self.__am_net(inputs_ling[:, :-1, :], inputs_emo[:, :-1], | |||||
| inputs_spk[:, :-1], inputs_len) | |||||
| postnet_outputs = res['postnet_outputs'] | |||||
| LR_length_rounded = res['LR_length_rounded'] | |||||
| valid_length = int(LR_length_rounded[0].item()) | |||||
| postnet_outputs = postnet_outputs[ | |||||
| 0, :valid_length, :].cpu().numpy() | |||||
| return postnet_outputs | |||||
| def __vocoder_forward(self, melspec): | def __vocoder_forward(self, melspec): | ||||
| dim0 = list(melspec.shape)[-1] | dim0 = list(melspec.shape)[-1] | ||||
| @@ -118,14 +121,15 @@ class Voice: | |||||
| return audio | return audio | ||||
| def forward(self, symbol_seq): | def forward(self, symbol_seq): | ||||
| if not self.__model_loaded: | |||||
| torch.manual_seed(self.__am_config.seed) | |||||
| if torch.cuda.is_available(): | |||||
| with self.__lock: | |||||
| if not self.__model_loaded: | |||||
| torch.manual_seed(self.__am_config.seed) | torch.manual_seed(self.__am_config.seed) | ||||
| self.__device = torch.device('cuda') | |||||
| else: | |||||
| self.__device = torch.device('cpu') | |||||
| self.__load_am() | |||||
| self.__load_vocoder() | |||||
| self.__model_loaded = True | |||||
| if torch.cuda.is_available(): | |||||
| torch.cuda.manual_seed(self.__am_config.seed) | |||||
| self.__device = torch.device('cuda') | |||||
| else: | |||||
| self.__device = torch.device('cpu') | |||||
| self.__load_am() | |||||
| self.__load_vocoder() | |||||
| self.__model_loaded = True | |||||
| return self.__vocoder_forward(self.__am_forward(symbol_seq)) | return self.__vocoder_forward(self.__am_forward(symbol_seq)) | ||||
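The change above wraps both lazy loading and AM inference in a shared `threading.Lock`, so the first caller initializes the models while concurrent callers block. A stripped-down sketch of the same pattern, with stand-ins for the real load and synthesis steps:

    from threading import Lock

    class LazyVoice:
        def __init__(self):
            self._lock = Lock()
            self._loaded = False
            self._model = None

        def forward(self, x):
            with self._lock:
                if not self._loaded:
                    self._model = object()  # stand-in for loading AM/vocoder
                    self._loaded = True
            return x  # stand-in for the real synthesis call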
| @@ -5,11 +5,11 @@ from abc import ABC, abstractmethod | |||||
| from typing import Any, Callable, Dict, List, Optional, Union | from typing import Any, Callable, Dict, List, Optional, Union | ||||
| from modelscope.hub.snapshot_download import snapshot_download | from modelscope.hub.snapshot_download import snapshot_download | ||||
| from modelscope.models.builder import build_model | |||||
| from modelscope.utils.checkpoint import save_pretrained | |||||
| from modelscope.models.builder import MODELS, build_model | |||||
| from modelscope.utils.checkpoint import save_checkpoint, save_pretrained | |||||
| from modelscope.utils.config import Config | from modelscope.utils.config import Config | ||||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile | |||||
| from modelscope.utils.device import device_placement, verify_device | |||||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile, Tasks | |||||
| from modelscope.utils.device import verify_device | |||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| logger = get_logger() | logger = get_logger() | ||||
| @@ -66,7 +66,6 @@ class Model(ABC): | |||||
| revision: Optional[str] = DEFAULT_MODEL_REVISION, | revision: Optional[str] = DEFAULT_MODEL_REVISION, | ||||
| cfg_dict: Config = None, | cfg_dict: Config = None, | ||||
| device: str = None, | device: str = None, | ||||
| *model_args, | |||||
| **kwargs): | **kwargs): | ||||
| """ Instantiate a model from local directory or remote model repo. Note | """ Instantiate a model from local directory or remote model repo. Note | ||||
| that when loading from remote, the model revision can be specified. | that when loading from remote, the model revision can be specified. | ||||
| @@ -90,11 +89,11 @@ class Model(ABC): | |||||
| cfg = Config.from_file( | cfg = Config.from_file( | ||||
| osp.join(local_model_dir, ModelFile.CONFIGURATION)) | osp.join(local_model_dir, ModelFile.CONFIGURATION)) | ||||
| task_name = cfg.task | task_name = cfg.task | ||||
| if 'task' in kwargs: | |||||
| task_name = kwargs.pop('task') | |||||
| model_cfg = cfg.model | model_cfg = cfg.model | ||||
| if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'): | if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'): | ||||
| model_cfg.type = model_cfg.model_type | model_cfg.type = model_cfg.model_type | ||||
| model_cfg.model_dir = local_model_dir | model_cfg.model_dir = local_model_dir | ||||
| for k, v in kwargs.items(): | for k, v in kwargs.items(): | ||||
| model_cfg[k] = v | model_cfg[k] = v | ||||
| @@ -109,15 +108,19 @@ class Model(ABC): | |||||
| # dynamically add pipeline info to model for pipeline inference | # dynamically add pipeline info to model for pipeline inference | ||||
| if hasattr(cfg, 'pipeline'): | if hasattr(cfg, 'pipeline'): | ||||
| model.pipeline = cfg.pipeline | model.pipeline = cfg.pipeline | ||||
| if not hasattr(model, 'cfg'): | |||||
| model.cfg = cfg | |||||
| return model | return model | ||||
| def save_pretrained(self, | def save_pretrained(self, | ||||
| target_folder: Union[str, os.PathLike], | target_folder: Union[str, os.PathLike], | ||||
| save_checkpoint_names: Union[str, List[str]] = None, | save_checkpoint_names: Union[str, List[str]] = None, | ||||
| save_function: Callable = None, | |||||
| save_function: Callable = save_checkpoint, | |||||
| config: Optional[dict] = None, | config: Optional[dict] = None, | ||||
| **kwargs): | **kwargs): | ||||
| """save the pretrained model, its configuration and other related files to a directory, so that it can be re-loaded | |||||
| """save the pretrained model, its configuration and other related files to a directory, | |||||
| so that it can be re-loaded | |||||
| Args: | Args: | ||||
| target_folder (Union[str, os.PathLike]): | target_folder (Union[str, os.PathLike]): | ||||
| @@ -133,5 +136,10 @@ class Model(ABC): | |||||
| The config for the configuration.json, might not be identical with model.config | The config for the configuration.json, might not be identical with model.config | ||||
| """ | """ | ||||
| if config is None and hasattr(self, 'cfg'): | |||||
| config = self.cfg | |||||
| assert config is not None, 'Cannot save the model because the model config is empty.' | |||||
| if isinstance(config, Config): | |||||
| config = config.to_dict() | |||||
| save_pretrained(self, target_folder, save_checkpoint_names, | save_pretrained(self, target_folder, save_checkpoint_names, | ||||
| save_function, config, **kwargs) | save_function, config, **kwargs) | ||||
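With the new fallback, `save_pretrained` can be called without an explicit config, because `from_pretrained` now stashes `cfg` on the instance. A hedged usage sketch; the model id is a placeholder, not a real repo:

    from modelscope.models import Model

    model = Model.from_pretrained('damo/some-model')  # placeholder id
    # config=None is now acceptable: save_pretrained falls back to model.cfg
    # and converts a Config instance to a plain dict before saving.
    model.save_pretrained('./exported')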
| @@ -1,12 +1,20 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| from modelscope.utils.config import ConfigDict | from modelscope.utils.config import ConfigDict | ||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.import_utils import INDEX_KEY, LazyImportModule | |||||
| from modelscope.utils.registry import TYPE_NAME, Registry, build_from_cfg | from modelscope.utils.registry import TYPE_NAME, Registry, build_from_cfg | ||||
| MODELS = Registry('models') | MODELS = Registry('models') | ||||
| BACKBONES = Registry('backbones') | |||||
| BACKBONES = MODELS | |||||
| HEADS = Registry('heads') | HEADS = Registry('heads') | ||||
| modules = LazyImportModule.AST_INDEX[INDEX_KEY] | |||||
| for module_index in list(modules.keys()): | |||||
| if module_index[1] == Tasks.backbone and module_index[0] == 'BACKBONES': | |||||
| modules[(MODELS.name.upper(), module_index[1], | |||||
| module_index[2])] = modules[module_index] | |||||
| def build_model(cfg: ConfigDict, | def build_model(cfg: ConfigDict, | ||||
| task_name: str = None, | task_name: str = None, | ||||
| @@ -23,30 +31,27 @@ def build_model(cfg: ConfigDict, | |||||
| cfg, MODELS, group_key=task_name, default_args=default_args) | cfg, MODELS, group_key=task_name, default_args=default_args) | ||||
| def build_backbone(cfg: ConfigDict, | |||||
| field: str = None, | |||||
| default_args: dict = None): | |||||
| def build_backbone(cfg: ConfigDict, default_args: dict = None): | |||||
| """ build backbone given backbone config dict | """ build backbone given backbone config dict | ||||
| Args: | Args: | ||||
| cfg (:obj:`ConfigDict`): config dict for backbone object. | cfg (:obj:`ConfigDict`): config dict for backbone object. | ||||
| field (str, optional): field, such as CV, NLP's backbone | |||||
| default_args (dict, optional): Default initialization arguments. | default_args (dict, optional): Default initialization arguments. | ||||
| """ | """ | ||||
| return build_from_cfg( | return build_from_cfg( | ||||
| cfg, BACKBONES, group_key=field, default_args=default_args) | |||||
| cfg, BACKBONES, group_key=Tasks.backbone, default_args=default_args) | |||||
| def build_head(cfg: ConfigDict, | def build_head(cfg: ConfigDict, | ||||
| group_key: str = None, | |||||
| task_name: str = None, | |||||
| default_args: dict = None): | default_args: dict = None): | ||||
| """ build head given config dict | """ build head given config dict | ||||
| Args: | Args: | ||||
| cfg (:obj:`ConfigDict`): config dict for head object. | cfg (:obj:`ConfigDict`): config dict for head object. | ||||
| task_name (str, optional): task name, refer to | |||||
| :obj:`Tasks` for more details | |||||
| default_args (dict, optional): Default initialization arguments. | default_args (dict, optional): Default initialization arguments. | ||||
| """ | """ | ||||
| if group_key is None: | |||||
| group_key = cfg[TYPE_NAME] | |||||
| return build_from_cfg( | return build_from_cfg( | ||||
| cfg, HEADS, group_key=group_key, default_args=default_args) | |||||
| cfg, HEADS, group_key=task_name, default_args=default_args) | |||||
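Since BACKBONES is now an alias of MODELS, backbones register into the models registry under the `Tasks.backbone` group key, and `build_backbone` no longer needs a `field` argument. A hedged sketch of a registration, assuming modelscope's `Registry.register_module(group_key, module_name=...)` decorator; the class and module name are made up:

    from modelscope.models.builder import BACKBONES, build_backbone
    from modelscope.utils.config import ConfigDict
    from modelscope.utils.constant import Tasks

    @BACKBONES.register_module(Tasks.backbone, module_name='my-backbone')
    class MyBackbone:
        def __init__(self, **kwargs):
            self.kwargs = kwargs

    backbone = build_backbone(ConfigDict(type='my-backbone'))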
| @@ -4,14 +4,16 @@ | |||||
| from . import (action_recognition, animal_recognition, body_2d_keypoints, | from . import (action_recognition, animal_recognition, body_2d_keypoints, | ||||
| body_3d_keypoints, cartoon, cmdssl_video_embedding, | body_3d_keypoints, cartoon, cmdssl_video_embedding, | ||||
| crowd_counting, face_2d_keypoints, face_detection, | crowd_counting, face_2d_keypoints, face_detection, | ||||
| face_generation, image_classification, image_color_enhance, | |||||
| image_colorization, image_denoise, image_instance_segmentation, | |||||
| face_generation, human_wholebody_keypoint, image_classification, | |||||
| image_color_enhance, image_colorization, image_denoise, | |||||
| image_inpainting, image_instance_segmentation, | |||||
| image_panoptic_segmentation, image_portrait_enhancement, | image_panoptic_segmentation, image_portrait_enhancement, | ||||
| image_reid_person, image_semantic_segmentation, | image_reid_person, image_semantic_segmentation, | ||||
| image_to_image_generation, image_to_image_translation, | image_to_image_generation, image_to_image_translation, | ||||
| movie_scene_segmentation, object_detection, | movie_scene_segmentation, object_detection, | ||||
| product_retrieval_embedding, realtime_object_detection, | product_retrieval_embedding, realtime_object_detection, | ||||
| salient_detection, shop_segmentation, super_resolution, | |||||
| referring_video_object_segmentation, salient_detection, | |||||
| shop_segmentation, super_resolution, | |||||
| video_single_object_tracking, video_summarization, virual_tryon) | video_single_object_tracking, video_summarization, virual_tryon) | ||||
| # yapf: enable | # yapf: enable | ||||
| @@ -4,6 +4,7 @@ import os | |||||
| import os.path as osp | import os.path as osp | ||||
| import shutil | import shutil | ||||
| import subprocess | import subprocess | ||||
| import uuid | |||||
| import cv2 | import cv2 | ||||
| import numpy as np | import numpy as np | ||||
| @@ -84,7 +85,9 @@ class ActionDetONNX(Model): | |||||
| def forward_video(self, video_name, scale): | def forward_video(self, video_name, scale): | ||||
| min_size, max_size = self._get_sizes(scale) | min_size, max_size = self._get_sizes(scale) | ||||
| tmp_dir = osp.join(self.tmp_dir, osp.basename(video_name)[:-4]) | |||||
| tmp_dir = osp.join( | |||||
| self.tmp_dir, | |||||
| str(uuid.uuid1()) + '_' + osp.basename(video_name)[:-4]) | |||||
| if osp.exists(tmp_dir): | if osp.exists(tmp_dir): | ||||
| shutil.rmtree(tmp_dir) | shutil.rmtree(tmp_dir) | ||||
| os.makedirs(tmp_dir) | os.makedirs(tmp_dir) | ||||
| @@ -110,6 +113,7 @@ class ActionDetONNX(Model): | |||||
| len(frame_names) * self.temporal_stride, | len(frame_names) * self.temporal_stride, | ||||
| self.temporal_stride)) | self.temporal_stride)) | ||||
| batch_imgs = [self.parse_frames(names) for names in frame_names] | batch_imgs = [self.parse_frames(names) for names in frame_names] | ||||
| shutil.rmtree(tmp_dir) | |||||
| N, _, T, H, W = batch_imgs[0].shape | N, _, T, H, W = batch_imgs[0].shape | ||||
| scale_min = min_size / min(H, W) | scale_min = min_size / min(H, W) | ||||
| @@ -128,7 +132,6 @@ class ActionDetONNX(Model): | |||||
| 'timestamp': t, | 'timestamp': t, | ||||
| 'actions': res | 'actions': res | ||||
| } for t, res in zip(timestamp, results)] | } for t, res in zip(timestamp, results)] | ||||
| shutil.rmtree(tmp_dir) | |||||
| return results | return results | ||||
| def forward(self, video_name): | def forward(self, video_name): | ||||
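The uuid prefix added above keeps two concurrent requests on the same video from sharing (and deleting) one another's frame directory, and the `shutil.rmtree` is moved up so frames are removed as soon as they are decoded into memory. A minimal sketch of the collision-safe directory pattern; the function name is mine:

    import os
    import os.path as osp
    import shutil
    import uuid

    def make_frame_dir(base_dir: str, video_name: str) -> str:
        # uuid1 prefix -> unique per call, even for identical video names
        frame_dir = osp.join(
            base_dir,
            str(uuid.uuid1()) + '_' + osp.basename(video_name)[:-4])
        if osp.exists(frame_dir):
            shutil.rmtree(frame_dir)
        os.makedirs(frame_dir)
        return frame_dir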
| @@ -1,3 +1,5 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | import os | ||||
| from typing import Any, Dict, Optional, Union | from typing import Any, Dict, Optional, Union | ||||
| @@ -1,10 +1,10 @@ | |||||
| # ------------------------------------------------------------------------------ | |||||
| # Copyright (c) Microsoft | |||||
| # Licensed under the MIT License. | |||||
| # Written by Bin Xiao (Bin.Xiao@microsoft.com) | |||||
| # Modified by Ke Sun (sunk@mail.ustc.edu.cn) | |||||
| # https://github.com/HRNet/HRNet-Image-Classification/blob/master/lib/models/cls_hrnet.py | |||||
| # ------------------------------------------------------------------------------ | |||||
| """ | |||||
| Copyright (c) Microsoft | |||||
| Licensed under the MIT License. | |||||
| Written by Bin Xiao (Bin.Xiao@microsoft.com) | |||||
| Modified by Ke Sun (sunk@mail.ustc.edu.cn) | |||||
| https://github.com/HRNet/HRNet-Image-Classification/blob/master/lib/models/cls_hrnet.py | |||||
| """ | |||||
| import functools | import functools | ||||
| import logging | import logging | ||||
| @@ -8,12 +8,14 @@ if TYPE_CHECKING: | |||||
| from .mtcnn import MtcnnFaceDetector | from .mtcnn import MtcnnFaceDetector | ||||
| from .retinaface import RetinaFaceDetection | from .retinaface import RetinaFaceDetection | ||||
| from .ulfd_slim import UlfdFaceDetector | from .ulfd_slim import UlfdFaceDetector | ||||
| from .scrfd import ScrfdDetect | |||||
| else: | else: | ||||
| _import_structure = { | _import_structure = { | ||||
| 'ulfd_slim': ['UlfdFaceDetector'], | 'ulfd_slim': ['UlfdFaceDetector'], | ||||
| 'retinaface': ['RetinaFaceDetection'], | 'retinaface': ['RetinaFaceDetection'], | ||||
| 'mtcnn': ['MtcnnFaceDetector'], | 'mtcnn': ['MtcnnFaceDetector'], | ||||
| 'mogface': ['MogFaceDetector'] | |||||
| 'mogface': ['MogFaceDetector'], | |||||
| 'scrfd': ['ScrfdDetect'] | |||||
| } | } | ||||
| import sys | import sys | ||||
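The `scrfd` entry follows the module's lazy-import convention: symbols are imported for type checkers only and resolved at runtime on first access. A generic sketch of that pattern using PEP 562's module-level `__getattr__`; modelscope's actual LazyImportModule machinery differs in detail:

    import importlib
    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        from .scrfd import ScrfdDetect
    else:
        _import_structure = {'scrfd': ['ScrfdDetect']}

        def __getattr__(name):
            # Resolve the owning submodule lazily on first attribute access.
            for module, symbols in _import_structure.items():
                if name in symbols:
                    mod = importlib.import_module('.' + module, __name__)
                    return getattr(mod, name)
            raise AttributeError(name)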
| @@ -1,189 +0,0 @@ | |||||
| """ | |||||
| The implementation here is modified based on insightface, originally MIT license and publicly available at | |||||
| https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/transforms.py | |||||
| """ | |||||
| import numpy as np | |||||
| from mmdet.datasets.builder import PIPELINES | |||||
| from numpy import random | |||||
| @PIPELINES.register_module() | |||||
| class RandomSquareCrop(object): | |||||
| """Random crop the image & bboxes, the cropped patches have minimum IoU | |||||
| requirement with original image & bboxes, the IoU threshold is randomly | |||||
| selected from min_ious. | |||||
| Args: | |||||
| min_ious (tuple): minimum IoU threshold for all intersections with | |||||
| bounding boxes | |||||
| min_crop_size (float): minimum crop's size (i.e. h,w := a*h, a*w, | |||||
| where a >= min_crop_size). | |||||
| Note: | |||||
| The keys for bboxes, labels and masks should be paired. That is, \ | |||||
| `gt_bboxes` corresponds to `gt_labels` and `gt_masks`, and \ | |||||
| `gt_bboxes_ignore` to `gt_labels_ignore` and `gt_masks_ignore`. | |||||
| """ | |||||
| def __init__(self, | |||||
| crop_ratio_range=None, | |||||
| crop_choice=None, | |||||
| bbox_clip_border=True): | |||||
| self.crop_ratio_range = crop_ratio_range | |||||
| self.crop_choice = crop_choice | |||||
| self.bbox_clip_border = bbox_clip_border | |||||
| assert (self.crop_ratio_range is None) ^ (self.crop_choice is None) | |||||
| if self.crop_ratio_range is not None: | |||||
| self.crop_ratio_min, self.crop_ratio_max = self.crop_ratio_range | |||||
| self.bbox2label = { | |||||
| 'gt_bboxes': 'gt_labels', | |||||
| 'gt_bboxes_ignore': 'gt_labels_ignore' | |||||
| } | |||||
| self.bbox2mask = { | |||||
| 'gt_bboxes': 'gt_masks', | |||||
| 'gt_bboxes_ignore': 'gt_masks_ignore' | |||||
| } | |||||
| def __call__(self, results): | |||||
| """Call function to crop images and bounding boxes with minimum IoU | |||||
| constraint. | |||||
| Args: | |||||
| results (dict): Result dict from loading pipeline. | |||||
| Returns: | |||||
| dict: Result dict with images and bounding boxes cropped, \ | |||||
| 'img_shape' key is updated. | |||||
| """ | |||||
| if 'img_fields' in results: | |||||
| assert results['img_fields'] == ['img'], \ | |||||
| 'Only single img_fields is allowed' | |||||
| img = results['img'] | |||||
| assert 'bbox_fields' in results | |||||
| assert 'gt_bboxes' in results | |||||
| boxes = results['gt_bboxes'] | |||||
| h, w, c = img.shape | |||||
| scale_retry = 0 | |||||
| if self.crop_ratio_range is not None: | |||||
| max_scale = self.crop_ratio_max | |||||
| else: | |||||
| max_scale = np.amax(self.crop_choice) | |||||
| while True: | |||||
| scale_retry += 1 | |||||
| if scale_retry == 1 or max_scale > 1.0: | |||||
| if self.crop_ratio_range is not None: | |||||
| scale = np.random.uniform(self.crop_ratio_min, | |||||
| self.crop_ratio_max) | |||||
| elif self.crop_choice is not None: | |||||
| scale = np.random.choice(self.crop_choice) | |||||
| else: | |||||
| scale = scale * 1.2 | |||||
| for i in range(250): | |||||
| short_side = min(w, h) | |||||
| cw = int(scale * short_side) | |||||
| ch = cw | |||||
| # TODO +1 | |||||
| if w == cw: | |||||
| left = 0 | |||||
| elif w > cw: | |||||
| left = random.randint(0, w - cw) | |||||
| else: | |||||
| left = random.randint(w - cw, 0) | |||||
| if h == ch: | |||||
| top = 0 | |||||
| elif h > ch: | |||||
| top = random.randint(0, h - ch) | |||||
| else: | |||||
| top = random.randint(h - ch, 0) | |||||
| patch = np.array( | |||||
| (int(left), int(top), int(left + cw), int(top + ch)), | |||||
| dtype=np.int) | |||||
| # center of boxes should inside the crop img | |||||
| # only adjust boxes and instance masks when the gt is not empty | |||||
| # adjust boxes | |||||
| def is_center_of_bboxes_in_patch(boxes, patch): | |||||
| # TODO >= | |||||
| center = (boxes[:, :2] + boxes[:, 2:]) / 2 | |||||
| mask = \ | |||||
| ((center[:, 0] > patch[0]) | |||||
| * (center[:, 1] > patch[1]) | |||||
| * (center[:, 0] < patch[2]) | |||||
| * (center[:, 1] < patch[3])) | |||||
| return mask | |||||
| mask = is_center_of_bboxes_in_patch(boxes, patch) | |||||
| if not mask.any(): | |||||
| continue | |||||
| for key in results.get('bbox_fields', []): | |||||
| boxes = results[key].copy() | |||||
| mask = is_center_of_bboxes_in_patch(boxes, patch) | |||||
| boxes = boxes[mask] | |||||
| if self.bbox_clip_border: | |||||
| boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:]) | |||||
| boxes[:, :2] = boxes[:, :2].clip(min=patch[:2]) | |||||
| boxes -= np.tile(patch[:2], 2) | |||||
| results[key] = boxes | |||||
| # labels | |||||
| label_key = self.bbox2label.get(key) | |||||
| if label_key in results: | |||||
| results[label_key] = results[label_key][mask] | |||||
| # keypoints field | |||||
| if key == 'gt_bboxes': | |||||
| for kps_key in results.get('keypoints_fields', []): | |||||
| keypointss = results[kps_key].copy() | |||||
| keypointss = keypointss[mask, :, :] | |||||
| if self.bbox_clip_border: | |||||
| keypointss[:, :, : | |||||
| 2] = keypointss[:, :, :2].clip( | |||||
| max=patch[2:]) | |||||
| keypointss[:, :, : | |||||
| 2] = keypointss[:, :, :2].clip( | |||||
| min=patch[:2]) | |||||
| keypointss[:, :, 0] -= patch[0] | |||||
| keypointss[:, :, 1] -= patch[1] | |||||
| results[kps_key] = keypointss | |||||
| # mask fields | |||||
| mask_key = self.bbox2mask.get(key) | |||||
| if mask_key in results: | |||||
| results[mask_key] = results[mask_key][mask.nonzero() | |||||
| [0]].crop(patch) | |||||
| # adjust the img no matter whether the gt is empty before crop | |||||
| rimg = np.ones((ch, cw, 3), dtype=img.dtype) * 128 | |||||
| patch_from = patch.copy() | |||||
| patch_from[0] = max(0, patch_from[0]) | |||||
| patch_from[1] = max(0, patch_from[1]) | |||||
| patch_from[2] = min(img.shape[1], patch_from[2]) | |||||
| patch_from[3] = min(img.shape[0], patch_from[3]) | |||||
| patch_to = patch.copy() | |||||
| patch_to[0] = max(0, patch_to[0] * -1) | |||||
| patch_to[1] = max(0, patch_to[1] * -1) | |||||
| patch_to[2] = patch_to[0] + (patch_from[2] - patch_from[0]) | |||||
| patch_to[3] = patch_to[1] + (patch_from[3] - patch_from[1]) | |||||
| rimg[patch_to[1]:patch_to[3], | |||||
| patch_to[0]:patch_to[2], :] = img[ | |||||
| patch_from[1]:patch_from[3], | |||||
| patch_from[0]:patch_from[2], :] | |||||
| img = rimg | |||||
| results['img'] = img | |||||
| results['img_shape'] = img.shape | |||||
| return results | |||||
| def __repr__(self): | |||||
| repr_str = self.__class__.__name__ | |||||
| repr_str += f'(crop_ratio_range={self.crop_ratio_range}, ' | |||||
| repr_str += f'crop_choice={self.crop_choice})' | |||||
| return repr_str | |||||
| @@ -1,3 +1,5 @@ | |||||
| # The implementation is based on MogFace, available at | |||||
| # https://github.com/damo-cv/MogFace | |||||
| import os | import os | ||||
| import cv2 | import cv2 | ||||
| @@ -0,0 +1,2 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from .scrfd_detect import ScrfdDetect | |||||
| @@ -6,7 +6,7 @@ import numpy as np | |||||
| import torch | import torch | ||||
| def bbox2result(bboxes, labels, num_classes, kps=None): | |||||
| def bbox2result(bboxes, labels, num_classes, kps=None, num_kps=5): | |||||
| """Convert detection results to a list of numpy arrays. | """Convert detection results to a list of numpy arrays. | ||||
| Args: | Args: | ||||
| @@ -17,7 +17,7 @@ def bbox2result(bboxes, labels, num_classes, kps=None): | |||||
| Returns: | Returns: | ||||
| list(ndarray): bbox results of each class | list(ndarray): bbox results of each class | ||||
| """ | """ | ||||
| bbox_len = 5 if kps is None else 5 + 10 # if has kps, add 10 kps into bbox | |||||
| bbox_len = 5 if kps is None else 5 + num_kps * 2  # if kps is present, append num_kps * 2 keypoint values to each bbox | |||||
| if bboxes.shape[0] == 0: | if bboxes.shape[0] == 0: | ||||
| return [ | return [ | ||||
| np.zeros((0, bbox_len), dtype=np.float32) | np.zeros((0, bbox_len), dtype=np.float32) | ||||
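Each result row is [x1, y1, x2, y2, score] plus two values per keypoint, so `bbox_len` now tracks the configured keypoint count instead of hard-coding ten. A quick worked check with the default of five facial keypoints:

    num_kps = 5
    bbox_len = 5 + num_kps * 2  # 4 box coords + 1 score + (x, y) per keypoint
    assert bbox_len == 15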
| @@ -17,6 +17,7 @@ def multiclass_nms(multi_bboxes, | |||||
| Args: | Args: | ||||
| multi_bboxes (Tensor): shape (n, #class*4) or (n, 4) | multi_bboxes (Tensor): shape (n, #class*4) or (n, 4) | ||||
| multi_kps (Tensor): shape (n, #class*num_kps*2) or (n, num_kps*2) | |||||
| multi_scores (Tensor): shape (n, #class), where the last column | multi_scores (Tensor): shape (n, #class), where the last column | ||||
| contains scores of the background class, but this will be ignored. | contains scores of the background class, but this will be ignored. | ||||
| score_thr (float): bbox threshold, bboxes with scores lower than it | score_thr (float): bbox threshold, bboxes with scores lower than it | ||||
| @@ -36,16 +37,18 @@ def multiclass_nms(multi_bboxes, | |||||
| num_classes = multi_scores.size(1) - 1 | num_classes = multi_scores.size(1) - 1 | ||||
| # exclude background category | # exclude background category | ||||
| kps = None | kps = None | ||||
| if multi_kps is not None: | |||||
| num_kps = int((multi_kps.shape[1] / num_classes) / 2) | |||||
| if multi_bboxes.shape[1] > 4: | if multi_bboxes.shape[1] > 4: | ||||
| bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4) | bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4) | ||||
| if multi_kps is not None: | if multi_kps is not None: | ||||
| kps = multi_kps.view(multi_scores.size(0), -1, 10) | |||||
| kps = multi_kps.view(multi_scores.size(0), -1, num_kps * 2) | |||||
| else: | else: | ||||
| bboxes = multi_bboxes[:, None].expand( | bboxes = multi_bboxes[:, None].expand( | ||||
| multi_scores.size(0), num_classes, 4) | multi_scores.size(0), num_classes, 4) | ||||
| if multi_kps is not None: | if multi_kps is not None: | ||||
| kps = multi_kps[:, None].expand( | kps = multi_kps[:, None].expand( | ||||
| multi_scores.size(0), num_classes, 10) | |||||
| multi_scores.size(0), num_classes, num_kps * 2) | |||||
| scores = multi_scores[:, :-1] | scores = multi_scores[:, :-1] | ||||
| if score_factors is not None: | if score_factors is not None: | ||||
| @@ -56,7 +59,7 @@ def multiclass_nms(multi_bboxes, | |||||
| bboxes = bboxes.reshape(-1, 4) | bboxes = bboxes.reshape(-1, 4) | ||||
| if kps is not None: | if kps is not None: | ||||
| kps = kps.reshape(-1, 10) | |||||
| kps = kps.reshape(-1, num_kps * 2) | |||||
| scores = scores.reshape(-1) | scores = scores.reshape(-1) | ||||
| labels = labels.reshape(-1) | labels = labels.reshape(-1) | ||||
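The shape bookkeeping behind the `num_kps` derivation above, spelled out on a toy tensor; with one foreground class and five keypoints, `multi_kps` has `1 * 5 * 2` columns:

    import torch

    num_classes = 1
    multi_kps = torch.zeros(8, num_classes * 5 * 2)
    num_kps = int((multi_kps.shape[1] / num_classes) / 2)
    assert num_kps == 5
    kps = multi_kps.view(8, -1, num_kps * 2)  # (n, #class, num_kps * 2)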
| @@ -2,6 +2,12 @@ | |||||
| The implementation here is modified based on insightface, originally MIT license and publicly available at | The implementation here is modified based on insightface, originally MIT license and publicly available at | ||||
| https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines | https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines | ||||
| """ | """ | ||||
| from .auto_augment import RotateV2 | |||||
| from .formating import DefaultFormatBundleV2 | |||||
| from .loading import LoadAnnotationsV2 | |||||
| from .transforms import RandomSquareCrop | from .transforms import RandomSquareCrop | ||||
| __all__ = ['RandomSquareCrop'] | |||||
| __all__ = [ | |||||
| 'RandomSquareCrop', 'LoadAnnotationsV2', 'RotateV2', | |||||
| 'DefaultFormatBundleV2' | |||||
| ] | |||||
| @@ -0,0 +1,271 @@ | |||||
| """ | |||||
| The implementation here is modified based on insightface, originally MIT license and publicly available at | |||||
| https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/auto_augment.py | |||||
| """ | |||||
| import copy | |||||
| import cv2 | |||||
| import mmcv | |||||
| import numpy as np | |||||
| from mmdet.datasets.builder import PIPELINES | |||||
| _MAX_LEVEL = 10 | |||||
| def level_to_value(level, max_value): | |||||
| """Map from level to values based on max_value.""" | |||||
| return (level / _MAX_LEVEL) * max_value | |||||
| def random_negative(value, random_negative_prob): | |||||
| """Randomly negate value based on random_negative_prob.""" | |||||
| return -value if np.random.rand() < random_negative_prob else value | |||||
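A worked check of the level mapping used by RotateV2 below: the rotation angle is `(level / _MAX_LEVEL) * max_rotate_angle`, then negated with probability `random_negative_prob`:

    _MAX_LEVEL = 10
    level, max_rotate_angle = 5, 30
    angle = (level / _MAX_LEVEL) * max_rotate_angle
    assert angle == 15.0  # level 5 of 10 -> half of the 30-degree maximum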
| def bbox2fields(): | |||||
| """The key correspondence from bboxes to labels, masks and | |||||
| segmentations.""" | |||||
| bbox2label = { | |||||
| 'gt_bboxes': 'gt_labels', | |||||
| 'gt_bboxes_ignore': 'gt_labels_ignore' | |||||
| } | |||||
| bbox2mask = { | |||||
| 'gt_bboxes': 'gt_masks', | |||||
| 'gt_bboxes_ignore': 'gt_masks_ignore' | |||||
| } | |||||
| bbox2seg = { | |||||
| 'gt_bboxes': 'gt_semantic_seg', | |||||
| } | |||||
| return bbox2label, bbox2mask, bbox2seg | |||||
| @PIPELINES.register_module() | |||||
| class RotateV2(object): | |||||
| """Apply Rotate Transformation to image (and its corresponding bbox, mask, | |||||
| segmentation). | |||||
| Args: | |||||
| level (int | float): The level should be in range [0,_MAX_LEVEL]. | |||||
| scale (int | float): Isotropic scale factor. Same in | |||||
| ``mmcv.imrotate``. | |||||
| center (int | float | tuple[float]): Center point (w, h) of the | |||||
| rotation in the source image. If None, the center of the | |||||
| image will be used. Same in ``mmcv.imrotate``. | |||||
| img_fill_val (int | float | tuple): The fill value for image border. | |||||
| If float, the same value will be used for all three | |||||
| channels of the image. If tuple, it should have 3 elements | |||||
| (i.e. one per image channel). | |||||
| seg_ignore_label (int): The fill value used for the segmentation map. | |||||
| Note this value must equal ``ignore_label`` in ``semantic_head`` | |||||
| of the corresponding config. Default 255. | |||||
| prob (float): The probability of performing the transformation, | |||||
| in the range [0, 1]. | |||||
| max_rotate_angle (int | float): The maximum angle for the rotate | |||||
| transformation. | |||||
| random_negative_prob (float): The probability of negating the | |||||
| rotation angle. | |||||
| """ | |||||
| def __init__(self, | |||||
| level, | |||||
| scale=1, | |||||
| center=None, | |||||
| img_fill_val=128, | |||||
| seg_ignore_label=255, | |||||
| prob=0.5, | |||||
| max_rotate_angle=30, | |||||
| random_negative_prob=0.5): | |||||
| assert isinstance(level, (int, float)), \ | |||||
| f'The level must be type int or float. got {type(level)}.' | |||||
| assert 0 <= level <= _MAX_LEVEL, \ | |||||
| f'The level should be in range [0,{_MAX_LEVEL}]. got {level}.' | |||||
| assert isinstance(scale, (int, float)), \ | |||||
| f'The scale must be type int or float. got type {type(scale)}.' | |||||
| if isinstance(center, (int, float)): | |||||
| center = (center, center) | |||||
| elif isinstance(center, tuple): | |||||
| assert len(center) == 2, 'center with type tuple must have '\ | |||||
| f'2 elements. got {len(center)} elements.' | |||||
| else: | |||||
| assert center is None, 'center must be None or type int, '\ | |||||
| f'float or tuple, got type {type(center)}.' | |||||
| if isinstance(img_fill_val, (float, int)): | |||||
| img_fill_val = tuple([float(img_fill_val)] * 3) | |||||
| elif isinstance(img_fill_val, tuple): | |||||
| assert len(img_fill_val) == 3, 'img_fill_val as tuple must '\ | |||||
| f'have 3 elements. got {len(img_fill_val)}.' | |||||
| img_fill_val = tuple([float(val) for val in img_fill_val]) | |||||
| else: | |||||
| raise ValueError( | |||||
| 'img_fill_val must be float or tuple with 3 elements.') | |||||
| assert np.all([0 <= val <= 255 for val in img_fill_val]), \ | |||||
| 'all elements of img_fill_val should be within range [0,255]. '\ | |||||
| f'got {img_fill_val}.' | |||||
| assert 0 <= prob <= 1.0, 'The probability should be in range [0,1]. '\ | |||||
| f'got {prob}.' | |||||
| assert isinstance(max_rotate_angle, (int, float)), 'max_rotate_angle '\ | |||||
| f'should be type int or float. got type {type(max_rotate_angle)}.' | |||||
| self.level = level | |||||
| self.scale = scale | |||||
| # Rotation angle in degrees. Positive values mean | |||||
| # clockwise rotation. | |||||
| self.angle = level_to_value(level, max_rotate_angle) | |||||
| self.center = center | |||||
| self.img_fill_val = img_fill_val | |||||
| self.seg_ignore_label = seg_ignore_label | |||||
| self.prob = prob | |||||
| self.max_rotate_angle = max_rotate_angle | |||||
| self.random_negative_prob = random_negative_prob | |||||
| def _rotate_img(self, results, angle, center=None, scale=1.0): | |||||
| """Rotate the image. | |||||
| Args: | |||||
| results (dict): Result dict from loading pipeline. | |||||
| angle (float): Rotation angle in degrees, positive values | |||||
| mean clockwise rotation. Same in ``mmcv.imrotate``. | |||||
| center (tuple[float], optional): Center point (w, h) of the | |||||
| rotation. Same in ``mmcv.imrotate``. | |||||
| scale (int | float): Isotropic scale factor. Same in | |||||
| ``mmcv.imrotate``. | |||||
| """ | |||||
| for key in results.get('img_fields', ['img']): | |||||
| img = results[key].copy() | |||||
| img_rotated = mmcv.imrotate( | |||||
| img, angle, center, scale, border_value=self.img_fill_val) | |||||
| results[key] = img_rotated.astype(img.dtype) | |||||
| results['img_shape'] = results[key].shape | |||||
| def _rotate_bboxes(self, results, rotate_matrix): | |||||
| """Rotate the bboxes.""" | |||||
| h, w, c = results['img_shape'] | |||||
| for key in results.get('bbox_fields', []): | |||||
| min_x, min_y, max_x, max_y = np.split( | |||||
| results[key], results[key].shape[-1], axis=-1) | |||||
| coordinates = np.stack([[min_x, min_y], [max_x, min_y], | |||||
| [min_x, max_y], | |||||
| [max_x, max_y]]) # [4, 2, nb_bbox, 1] | |||||
| # pad 1 to convert from format [x, y] to homogeneous | |||||
| # coordinates format [x, y, 1] | |||||
| coordinates = np.concatenate( | |||||
| (coordinates, | |||||
| np.ones((4, 1, coordinates.shape[2], 1), coordinates.dtype)), | |||||
| axis=1) # [4, 3, nb_bbox, 1] | |||||
| coordinates = coordinates.transpose( | |||||
| (2, 0, 1, 3)) # [nb_bbox, 4, 3, 1] | |||||
| rotated_coords = np.matmul(rotate_matrix, | |||||
| coordinates) # [nb_bbox, 4, 2, 1] | |||||
| rotated_coords = rotated_coords[..., 0] # [nb_bbox, 4, 2] | |||||
| min_x, min_y = np.min( | |||||
| rotated_coords[:, :, 0], axis=1), np.min( | |||||
| rotated_coords[:, :, 1], axis=1) | |||||
| max_x, max_y = np.max( | |||||
| rotated_coords[:, :, 0], axis=1), np.max( | |||||
| rotated_coords[:, :, 1], axis=1) | |||||
| results[key] = np.stack([min_x, min_y, max_x, max_y], | |||||
| axis=-1).astype(results[key].dtype) | |||||
| def _rotate_keypoints90(self, results, angle): | |||||
| """Rotate the keypoints, only valid when angle in [-90,90,-180,180]""" | |||||
| if angle not in [-90, 90, 180, -180 | |||||
| ] or self.scale != 1 or self.center is not None: | |||||
| return | |||||
| for key in results.get('keypoints_fields', []): | |||||
| k = results[key] | |||||
| if angle == 90: | |||||
| w, h, c = results['img'].shape | |||||
| new = np.stack([h - k[..., 1], k[..., 0], k[..., 2]], axis=-1) | |||||
| elif angle == -90: | |||||
| w, h, c = results['img'].shape | |||||
| new = np.stack([k[..., 1], w - k[..., 0], k[..., 2]], axis=-1) | |||||
| else: | |||||
| h, w, c = results['img'].shape | |||||
| new = np.stack([w - k[..., 0], h - k[..., 1], k[..., 2]], | |||||
| axis=-1) | |||||
| # a keypoint set is invalid if its third value is -1 | |||||
| kps_invalid = new[..., -1][:, -1] == -1 | |||||
| new[kps_invalid] = np.zeros(new.shape[1:]) - 1 | |||||
| results[key] = new | |||||
| def _rotate_masks(self, | |||||
| results, | |||||
| angle, | |||||
| center=None, | |||||
| scale=1.0, | |||||
| fill_val=0): | |||||
| """Rotate the masks.""" | |||||
| h, w, c = results['img_shape'] | |||||
| for key in results.get('mask_fields', []): | |||||
| masks = results[key] | |||||
| results[key] = masks.rotate((h, w), angle, center, scale, fill_val) | |||||
| def _rotate_seg(self, | |||||
| results, | |||||
| angle, | |||||
| center=None, | |||||
| scale=1.0, | |||||
| fill_val=255): | |||||
| """Rotate the segmentation map.""" | |||||
| for key in results.get('seg_fields', []): | |||||
| seg = results[key].copy() | |||||
| results[key] = mmcv.imrotate( | |||||
| seg, angle, center, scale, | |||||
| border_value=fill_val).astype(seg.dtype) | |||||
| def _filter_invalid(self, results, min_bbox_size=0): | |||||
| """Filter bboxes and corresponding masks too small after rotate | |||||
| augmentation.""" | |||||
| bbox2label, bbox2mask, _ = bbox2fields() | |||||
| for key in results.get('bbox_fields', []): | |||||
| bbox_w = results[key][:, 2] - results[key][:, 0] | |||||
| bbox_h = results[key][:, 3] - results[key][:, 1] | |||||
| valid_inds = (bbox_w > min_bbox_size) & (bbox_h > min_bbox_size) | |||||
| valid_inds = np.nonzero(valid_inds)[0] | |||||
| results[key] = results[key][valid_inds] | |||||
| # label fields. e.g. gt_labels and gt_labels_ignore | |||||
| label_key = bbox2label.get(key) | |||||
| if label_key in results: | |||||
| results[label_key] = results[label_key][valid_inds] | |||||
| # mask fields, e.g. gt_masks and gt_masks_ignore | |||||
| mask_key = bbox2mask.get(key) | |||||
| if mask_key in results: | |||||
| results[mask_key] = results[mask_key][valid_inds] | |||||
| def __call__(self, results): | |||||
| """Call function to rotate images, bounding boxes, masks and semantic | |||||
| segmentation maps. | |||||
| Args: | |||||
| results (dict): Result dict from loading pipeline. | |||||
| Returns: | |||||
| dict: Rotated results. | |||||
| """ | |||||
| if np.random.rand() > self.prob: | |||||
| return results | |||||
| h, w = results['img'].shape[:2] | |||||
| center = self.center | |||||
| if center is None: | |||||
| center = ((w - 1) * 0.5, (h - 1) * 0.5) | |||||
| angle = random_negative(self.angle, self.random_negative_prob) | |||||
| self._rotate_img(results, angle, center, self.scale) | |||||
| rotate_matrix = cv2.getRotationMatrix2D(center, -angle, self.scale) | |||||
| self._rotate_bboxes(results, rotate_matrix) | |||||
| self._rotate_keypoints90(results, angle) | |||||
| self._rotate_masks(results, angle, center, self.scale, fill_val=0) | |||||
| self._rotate_seg( | |||||
| results, angle, center, self.scale, fill_val=self.seg_ignore_label) | |||||
| self._filter_invalid(results) | |||||
| return results | |||||
| def __repr__(self): | |||||
| repr_str = self.__class__.__name__ | |||||
| repr_str += f'(level={self.level}, ' | |||||
| repr_str += f'scale={self.scale}, ' | |||||
| repr_str += f'center={self.center}, ' | |||||
| repr_str += f'img_fill_val={self.img_fill_val}, ' | |||||
| repr_str += f'seg_ignore_label={self.seg_ignore_label}, ' | |||||
| repr_str += f'prob={self.prob}, ' | |||||
| repr_str += f'max_rotate_angle={self.max_rotate_angle}, ' | |||||
| repr_str += f'random_negative_prob={self.random_negative_prob})' | |||||
| return repr_str | |||||
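A hedged sketch of how RotateV2 might be wired into an mmdet-style train pipeline once registered; everything except the V2 transforms added in this diff is a standard mmdet pipeline entry:

    train_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(type='LoadAnnotationsV2', with_bbox=True, with_keypoints=True),
        dict(type='RotateV2', level=5, max_rotate_angle=90, prob=0.5),
        dict(type='DefaultFormatBundleV2'),
    ]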
| @@ -0,0 +1,113 @@ | |||||
| """ | |||||
| The implementation here is modified based on insightface, originally MIT license and publicly available at | |||||
| https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/formating.py | |||||
| """ | |||||
| from collections.abc import Sequence | |||||
| import mmcv | |||||
| import numpy as np | |||||
| import torch | |||||
| from mmcv.parallel import DataContainer as DC | |||||
| from mmdet.datasets.builder import PIPELINES | |||||
| def to_tensor(data): | |||||
| """Convert objects of various python types to :obj:`torch.Tensor`. | |||||
| Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, | |||||
| :class:`Sequence`, :class:`int` and :class:`float`. | |||||
| Args: | |||||
| data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to | |||||
| be converted. | |||||
| """ | |||||
| if isinstance(data, torch.Tensor): | |||||
| return data | |||||
| elif isinstance(data, np.ndarray): | |||||
| return torch.from_numpy(data) | |||||
| elif isinstance(data, Sequence) and not mmcv.is_str(data): | |||||
| return torch.tensor(data) | |||||
| elif isinstance(data, int): | |||||
| return torch.LongTensor([data]) | |||||
| elif isinstance(data, float): | |||||
| return torch.FloatTensor([data]) | |||||
| else: | |||||
| raise TypeError(f'type {type(data)} cannot be converted to tensor.') | |||||
| @PIPELINES.register_module() | |||||
| class DefaultFormatBundleV2(object): | |||||
| """Default formatting bundle. | |||||
| It simplifies the pipeline of formatting common fields, including "img", | |||||
| "proposals", "gt_bboxes", "gt_labels", "gt_masks" and "gt_semantic_seg". | |||||
| These fields are formatted as follows. | |||||
| - img: (1)transpose, (2)to tensor, (3)to DataContainer (stack=True) | |||||
| - proposals: (1)to tensor, (2)to DataContainer | |||||
| - gt_bboxes: (1)to tensor, (2)to DataContainer | |||||
| - gt_bboxes_ignore: (1)to tensor, (2)to DataContainer | |||||
| - gt_labels: (1)to tensor, (2)to DataContainer | |||||
| - gt_masks: (1)to tensor, (2)to DataContainer (cpu_only=True) | |||||
| - gt_semantic_seg: (1)unsqueeze dim-0 (2)to tensor, \ | |||||
| (3)to DataContainer (stack=True) | |||||
| """ | |||||
| def __call__(self, results): | |||||
| """Call function to transform and format common fields in results. | |||||
| Args: | |||||
| results (dict): Result dict contains the data to convert. | |||||
| Returns: | |||||
| dict: The result dict contains the data that is formatted with \ | |||||
| default bundle. | |||||
| """ | |||||
| if 'img' in results: | |||||
| img = results['img'] | |||||
| # add default meta keys | |||||
| results = self._add_default_meta_keys(results) | |||||
| if len(img.shape) < 3: | |||||
| img = np.expand_dims(img, -1) | |||||
| img = np.ascontiguousarray(img.transpose(2, 0, 1)) | |||||
| results['img'] = DC(to_tensor(img), stack=True) | |||||
| for key in [ | |||||
| 'proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_keypointss', | |||||
| 'gt_labels' | |||||
| ]: | |||||
| if key not in results: | |||||
| continue | |||||
| results[key] = DC(to_tensor(results[key])) | |||||
| if 'gt_masks' in results: | |||||
| results['gt_masks'] = DC(results['gt_masks'], cpu_only=True) | |||||
| if 'gt_semantic_seg' in results: | |||||
| results['gt_semantic_seg'] = DC( | |||||
| to_tensor(results['gt_semantic_seg'][None, ...]), stack=True) | |||||
| return results | |||||
| def _add_default_meta_keys(self, results): | |||||
| """Add default meta keys. | |||||
| We set default meta keys including `pad_shape`, `scale_factor` and | |||||
| `img_norm_cfg` to avoid the case where no `Resize`, `Normalize` and | |||||
| `Pad` are implemented during the whole pipeline. | |||||
| Args: | |||||
| results (dict): Result dict contains the data to convert. | |||||
| Returns: | |||||
| results (dict): Updated result dict contains the data to convert. | |||||
| """ | |||||
| img = results['img'] | |||||
| results.setdefault('pad_shape', img.shape) | |||||
| results.setdefault('scale_factor', 1.0) | |||||
| num_channels = 1 if len(img.shape) < 3 else img.shape[2] | |||||
| results.setdefault( | |||||
| 'img_norm_cfg', | |||||
| dict( | |||||
| mean=np.zeros(num_channels, dtype=np.float32), | |||||
| std=np.ones(num_channels, dtype=np.float32), | |||||
| to_rgb=False)) | |||||
| return results | |||||
| def __repr__(self): | |||||
| return self.__class__.__name__ | |||||
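A spot check of what the bundle does to an image, using the `DefaultFormatBundleV2` defined above: the HWC array is transposed to CHW and wrapped in a stackable DataContainer, and default meta keys are filled in:

    import numpy as np

    results = {
        'img': np.zeros((4, 4, 3), dtype=np.uint8),
        'gt_bboxes': np.zeros((0, 4), dtype=np.float32),
    }
    out = DefaultFormatBundleV2()(results)
    print(out['img'].stack, out['img'].data.shape)  # True torch.Size([3, 4, 4])
    print(out['scale_factor'])                      # 1.0 (default meta key)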
| @@ -0,0 +1,225 @@ | |||||
| """ | |||||
| The implementation here is modified based on insightface, originally MIT license and publicly available at | |||||
| https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/loading.py | |||||
| """ | |||||
| import os.path as osp | |||||
| import mmcv | |||||
| import numpy as np | |||||
| import pycocotools.mask as maskUtils | |||||
| from mmdet.core import BitmapMasks, PolygonMasks | |||||
| from mmdet.datasets.builder import PIPELINES | |||||
| @PIPELINES.register_module() | |||||
| class LoadAnnotationsV2(object): | |||||
| """Load mutiple types of annotations. | |||||
| Args: | |||||
| with_bbox (bool): Whether to parse and load the bbox annotation. | |||||
| Default: True. | |||||
| with_label (bool): Whether to parse and load the label annotation. | |||||
| Default: True. | |||||
| with_keypoints (bool): Whether to parse and load the keypoints annotation. | |||||
| Default: False. | |||||
| with_mask (bool): Whether to parse and load the mask annotation. | |||||
| Default: False. | |||||
| with_seg (bool): Whether to parse and load the semantic segmentation | |||||
| annotation. Default: False. | |||||
| poly2mask (bool): Whether to convert the instance masks from polygons | |||||
| to bitmaps. Default: True. | |||||
| file_client_args (dict): Arguments to instantiate a FileClient. | |||||
| See :class:`mmcv.fileio.FileClient` for details. | |||||
| Defaults to ``dict(backend='disk')``. | |||||
| """ | |||||
| def __init__(self, | |||||
| with_bbox=True, | |||||
| with_label=True, | |||||
| with_keypoints=False, | |||||
| with_mask=False, | |||||
| with_seg=False, | |||||
| poly2mask=True, | |||||
| file_client_args=dict(backend='disk')): | |||||
| self.with_bbox = with_bbox | |||||
| self.with_label = with_label | |||||
| self.with_keypoints = with_keypoints | |||||
| self.with_mask = with_mask | |||||
| self.with_seg = with_seg | |||||
| self.poly2mask = poly2mask | |||||
| self.file_client_args = file_client_args.copy() | |||||
| self.file_client = None | |||||
| def _load_bboxes(self, results): | |||||
| """Private function to load bounding box annotations. | |||||
| Args: | |||||
| results (dict): Result dict from :obj:`mmdet.CustomDataset`. | |||||
| Returns: | |||||
| dict: The dict contains loaded bounding box annotations. | |||||
| """ | |||||
| ann_info = results['ann_info'] | |||||
| results['gt_bboxes'] = ann_info['bboxes'].copy() | |||||
| gt_bboxes_ignore = ann_info.get('bboxes_ignore', None) | |||||
| if gt_bboxes_ignore is not None: | |||||
| results['gt_bboxes_ignore'] = gt_bboxes_ignore.copy() | |||||
| results['bbox_fields'].append('gt_bboxes_ignore') | |||||
| results['bbox_fields'].append('gt_bboxes') | |||||
| return results | |||||
| def _load_keypoints(self, results): | |||||
| """Private function to load bounding box annotations. | |||||
| Args: | |||||
| results (dict): Result dict from :obj:`mmdet.CustomDataset`. | |||||
| Returns: | |||||
| dict: The dict contains loaded keypoint annotations. | |||||
| """ | |||||
| ann_info = results['ann_info'] | |||||
| results['gt_keypointss'] = ann_info['keypointss'].copy() | |||||
| results['keypoints_fields'] = ['gt_keypointss'] | |||||
| return results | |||||
| def _load_labels(self, results): | |||||
| """Private function to load label annotations. | |||||
| Args: | |||||
| results (dict): Result dict from :obj:`mmdet.CustomDataset`. | |||||
| Returns: | |||||
| dict: The dict contains loaded label annotations. | |||||
| """ | |||||
| results['gt_labels'] = results['ann_info']['labels'].copy() | |||||
| return results | |||||
| def _poly2mask(self, mask_ann, img_h, img_w): | |||||
| """Private function to convert masks represented with polygon to | |||||
| bitmaps. | |||||
| Args: | |||||
| mask_ann (list | dict): Polygon mask annotation input. | |||||
| img_h (int): The height of output mask. | |||||
| img_w (int): The width of output mask. | |||||
| Returns: | |||||
| numpy.ndarray: The decoded bitmap mask of shape (img_h, img_w). | |||||
| """ | |||||
| if isinstance(mask_ann, list): | |||||
| # polygon -- a single object might consist of multiple parts | |||||
| # we merge all parts into one mask rle code | |||||
| rles = maskUtils.frPyObjects(mask_ann, img_h, img_w) | |||||
| rle = maskUtils.merge(rles) | |||||
| elif isinstance(mask_ann['counts'], list): | |||||
| # uncompressed RLE | |||||
| rle = maskUtils.frPyObjects(mask_ann, img_h, img_w) | |||||
| else: | |||||
| # rle | |||||
| rle = mask_ann | |||||
| mask = maskUtils.decode(rle) | |||||
| return mask | |||||
| def process_polygons(self, polygons): | |||||
| """Convert polygons to list of ndarray and filter invalid polygons. | |||||
| Args: | |||||
| polygons (list[list]): Polygons of one instance. | |||||
| Returns: | |||||
| list[numpy.ndarray]: Processed polygons. | |||||
| """ | |||||
| polygons = [np.array(p) for p in polygons] | |||||
| valid_polygons = [] | |||||
| for polygon in polygons: | |||||
| if len(polygon) % 2 == 0 and len(polygon) >= 6: | |||||
| valid_polygons.append(polygon) | |||||
| return valid_polygons | |||||
| def _load_masks(self, results): | |||||
| """Private function to load mask annotations. | |||||
| Args: | |||||
| results (dict): Result dict from :obj:`mmdet.CustomDataset`. | |||||
| Returns: | |||||
| dict: The dict contains loaded mask annotations. | |||||
| If ``self.poly2mask`` is set ``True``, `gt_mask` will contain | |||||
| :obj:`BitmapMasks`. Otherwise, :obj:`PolygonMasks` is used. | |||||
| """ | |||||
| h, w = results['img_info']['height'], results['img_info']['width'] | |||||
| gt_masks = results['ann_info']['masks'] | |||||
| if self.poly2mask: | |||||
| gt_masks = BitmapMasks( | |||||
| [self._poly2mask(mask, h, w) for mask in gt_masks], h, w) | |||||
| else: | |||||
| gt_masks = PolygonMasks( | |||||
| [self.process_polygons(polygons) for polygons in gt_masks], h, | |||||
| w) | |||||
| results['gt_masks'] = gt_masks | |||||
| results['mask_fields'].append('gt_masks') | |||||
| return results | |||||
| def _load_semantic_seg(self, results): | |||||
| """Private function to load semantic segmentation annotations. | |||||
| Args: | |||||
| results (dict): Result dict from :obj:`dataset`. | |||||
| Returns: | |||||
| dict: The dict contains loaded semantic segmentation annotations. | |||||
| """ | |||||
| if self.file_client is None: | |||||
| self.file_client = mmcv.FileClient(**self.file_client_args) | |||||
| filename = osp.join(results['seg_prefix'], | |||||
| results['ann_info']['seg_map']) | |||||
| img_bytes = self.file_client.get(filename) | |||||
| results['gt_semantic_seg'] = mmcv.imfrombytes( | |||||
| img_bytes, flag='unchanged').squeeze() | |||||
| results['seg_fields'].append('gt_semantic_seg') | |||||
| return results | |||||
| def __call__(self, results): | |||||
| """Call function to load multiple types annotations. | |||||
| Args: | |||||
| results (dict): Result dict from :obj:`mmdet.CustomDataset`. | |||||
| Returns: | |||||
| dict: The dict contains loaded bounding box, label, mask and | |||||
| semantic segmentation annotations. | |||||
| """ | |||||
| if self.with_bbox: | |||||
| results = self._load_bboxes(results) | |||||
| if results is None: | |||||
| return None | |||||
| if self.with_label: | |||||
| results = self._load_labels(results) | |||||
| if self.with_keypoints: | |||||
| results = self._load_keypoints(results) | |||||
| if self.with_mask: | |||||
| results = self._load_masks(results) | |||||
| if self.with_seg: | |||||
| results = self._load_semantic_seg(results) | |||||
| return results | |||||
| def __repr__(self): | |||||
| repr_str = self.__class__.__name__ | |||||
| repr_str += f'(with_bbox={self.with_bbox}, ' | |||||
| repr_str += f'with_label={self.with_label}, ' | |||||
| repr_str += f'with_keypoints={self.with_keypoints}, ' | |||||
| repr_str += f'with_mask={self.with_mask}, ' | |||||
| repr_str += f'with_seg={self.with_seg}, ' | |||||
| repr_str += f'poly2mask={self.poly2mask}, ' | |||||
| repr_str += f'file_client_args={self.file_client_args})' | |||||
| return repr_str | |||||
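The loader's input contract, sketched on a toy results dict with the `LoadAnnotationsV2` defined above; shapes follow the scrfd convention of five keypoints stored as (x, y, visibility) triples:

    import numpy as np

    results = {
        'ann_info': {
            'bboxes': np.zeros((1, 4), dtype=np.float32),
            'labels': np.zeros((1, ), dtype=np.int64),
            'keypointss': np.zeros((1, 5, 3), dtype=np.float32),
        },
        'bbox_fields': [],
    }
    out = LoadAnnotationsV2(with_keypoints=True)(results)
    print(out['bbox_fields'])          # ['gt_bboxes']
    print(out['gt_keypointss'].shape)  # (1, 5, 3)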
| @@ -0,0 +1,737 @@ | |||||
| """ | |||||
| The implementation here is modified based on insightface, originally MIT license and publicly available at | |||||
| https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/transforms.py | |||||
| """ | |||||
| import mmcv | |||||
| import numpy as np | |||||
| from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps | |||||
| from mmdet.datasets.builder import PIPELINES | |||||
| from numpy import random | |||||
| @PIPELINES.register_module() | |||||
| class ResizeV2(object): | |||||
| """Resize images & bbox & mask &kps. | |||||
| This transform resizes the input image to some scale. Bboxes and masks are | |||||
| then resized with the same scale factor. If the input dict contains the key | |||||
| "scale", then the scale in the input dict is used, otherwise the specified | |||||
| scale in the init method is used. If the input dict contains the key | |||||
| "scale_factor" (if MultiScaleFlipAug does not give img_scale but | |||||
| scale_factor), the actual scale will be computed by image shape and | |||||
| scale_factor. | |||||
| `img_scale` can either be a tuple (single-scale) or a list of tuples | |||||
| (multi-scale). There are 3 multiscale modes: | |||||
| - ``ratio_range is not None``: randomly sample a ratio from the ratio \ | |||||
| range and multiply it with the image scale. | |||||
| - ``ratio_range is None`` and ``multiscale_mode == "range"``: randomly \ | |||||
| sample a scale from the multiscale range. | |||||
| - ``ratio_range is None`` and ``multiscale_mode == "value"``: randomly \ | |||||
| sample a scale from multiple scales. | |||||
| Args: | |||||
| img_scale (tuple or list[tuple]): Image scales for resizing. | |||||
| multiscale_mode (str): Either "range" or "value". | |||||
| ratio_range (tuple[float]): (min_ratio, max_ratio) | |||||
| keep_ratio (bool): Whether to keep the aspect ratio when resizing the | |||||
| image. | |||||
| bbox_clip_border (bool, optional): Whether clip the objects outside | |||||
| the border of the image. Defaults to True. | |||||
| backend (str): Image resize backend, choices are 'cv2' and 'pillow'. | |||||
| These two backends generate slightly different results. Defaults | |||||
| to 'cv2'. | |||||
| override (bool, optional): Whether to override `scale` and | |||||
| `scale_factor` so as to call resize twice. Default False. If True, | |||||
| after the first resizing, the existing `scale` and `scale_factor` | |||||
| will be ignored so the second resizing can be allowed. | |||||
| This option is a work-around for multiple times of resize in DETR. | |||||
| Defaults to False. | |||||
| """ | |||||
| def __init__(self, | |||||
| img_scale=None, | |||||
| multiscale_mode='range', | |||||
| ratio_range=None, | |||||
| keep_ratio=True, | |||||
| bbox_clip_border=True, | |||||
| backend='cv2', | |||||
| override=False): | |||||
| if img_scale is None: | |||||
| self.img_scale = None | |||||
| else: | |||||
| if isinstance(img_scale, list): | |||||
| self.img_scale = img_scale | |||||
| else: | |||||
| self.img_scale = [img_scale] | |||||
| assert mmcv.is_list_of(self.img_scale, tuple) | |||||
| if ratio_range is not None: | |||||
| # mode 1: given a scale and a range of image ratio | |||||
| assert len(self.img_scale) == 1 | |||||
| else: | |||||
| # mode 2: given multiple scales or a range of scales | |||||
| assert multiscale_mode in ['value', 'range'] | |||||
| self.backend = backend | |||||
| self.multiscale_mode = multiscale_mode | |||||
| self.ratio_range = ratio_range | |||||
| self.keep_ratio = keep_ratio | |||||
| # TODO: refactor the override option in Resize | |||||
| self.override = override | |||||
| self.bbox_clip_border = bbox_clip_border | |||||
| @staticmethod | |||||
| def random_select(img_scales): | |||||
| """Randomly select an img_scale from given candidates. | |||||
| Args: | |||||
| img_scales (list[tuple]): Images scales for selection. | |||||
| Returns: | |||||
| (tuple, int): Returns a tuple ``(img_scale, scale_idx)``, \ | |||||
| where ``img_scale`` is the selected image scale and \ | |||||
| ``scale_idx`` is the selected index in the given candidates. | |||||
| """ | |||||
| assert mmcv.is_list_of(img_scales, tuple) | |||||
| scale_idx = np.random.randint(len(img_scales)) | |||||
| img_scale = img_scales[scale_idx] | |||||
| return img_scale, scale_idx | |||||
| @staticmethod | |||||
| def random_sample(img_scales): | |||||
| """Randomly sample an img_scale when ``multiscale_mode=='range'``. | |||||
| Args: | |||||
| img_scales (list[tuple]): Image scale range for sampling. | |||||
| There must be exactly two tuples in img_scales, which specify | |||||
| the lower and upper bounds of the image scales. | |||||
| Returns: | |||||
| (tuple, None): Returns a tuple ``(img_scale, None)``, where \ | |||||
| ``img_scale`` is sampled scale and None is just a placeholder \ | |||||
| to be consistent with :func:`random_select`. | |||||
| """ | |||||
| assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2 | |||||
| img_scale_long = [max(s) for s in img_scales] | |||||
| img_scale_short = [min(s) for s in img_scales] | |||||
| long_edge = np.random.randint( | |||||
| min(img_scale_long), | |||||
| max(img_scale_long) + 1) | |||||
| short_edge = np.random.randint( | |||||
| min(img_scale_short), | |||||
| max(img_scale_short) + 1) | |||||
| img_scale = (long_edge, short_edge) | |||||
| return img_scale, None | |||||
| @staticmethod | |||||
| def random_sample_ratio(img_scale, ratio_range): | |||||
| """Randomly sample an img_scale when ``ratio_range`` is specified. | |||||
| A ratio will be randomly sampled from the range specified by | |||||
| ``ratio_range``. Then it would be multiplied with ``img_scale`` to | |||||
| generate sampled scale. | |||||
| Args: | |||||
| img_scale (tuple): Image scale base to multiply with the ratio. | |||||
| ratio_range (tuple[float]): The minimum and maximum ratio to scale | |||||
| the ``img_scale``. | |||||
| Returns: | |||||
| (tuple, None): Returns a tuple ``(scale, None)``, where \ | |||||
| ``scale`` is sampled ratio multiplied with ``img_scale`` and \ | |||||
| None is just a placeholder to be consistent with \ | |||||
| :func:`random_select`. | |||||
| """ | |||||
| assert isinstance(img_scale, tuple) and len(img_scale) == 2 | |||||
| min_ratio, max_ratio = ratio_range | |||||
| assert min_ratio <= max_ratio | |||||
| ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio | |||||
| scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio) | |||||
| return scale, None | |||||
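A quick sketch of how the three sampling helpers above behave, assuming the class is importable from the module this diff adds; the candidate scales are made-up values, not from this diff:

    scale, idx = ResizeV2.random_select([(1333, 800), (1333, 640)])   # multiscale_mode='value'
    scale, _ = ResizeV2.random_sample([(1333, 640), (1333, 800)])     # multiscale_mode='range'
    scale, _ = ResizeV2.random_sample_ratio((640, 640), (0.5, 1.5))   # ratio_range mode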
| def _random_scale(self, results): | |||||
| """Randomly sample an img_scale according to ``ratio_range`` and | |||||
| ``multiscale_mode``. | |||||
| If ``ratio_range`` is specified, a ratio will be sampled and be | |||||
| multiplied with ``img_scale``. | |||||
| If multiple scales are specified by ``img_scale``, a scale will be | |||||
| sampled according to ``multiscale_mode``. | |||||
| Otherwise, single scale will be used. | |||||
| Args: | |||||
| results (dict): Result dict from :obj:`dataset`. | |||||
| Returns: | |||||
| dict: Two new keys ``scale`` and ``scale_idx`` are added into \ | |||||
| ``results``, which would be used by subsequent pipelines. | |||||
| """ | |||||
| if self.ratio_range is not None: | |||||
| scale, scale_idx = self.random_sample_ratio( | |||||
| self.img_scale[0], self.ratio_range) | |||||
| elif len(self.img_scale) == 1: | |||||
| scale, scale_idx = self.img_scale[0], 0 | |||||
| elif self.multiscale_mode == 'range': | |||||
| scale, scale_idx = self.random_sample(self.img_scale) | |||||
| elif self.multiscale_mode == 'value': | |||||
| scale, scale_idx = self.random_select(self.img_scale) | |||||
| else: | |||||
| raise NotImplementedError | |||||
| results['scale'] = scale | |||||
| results['scale_idx'] = scale_idx | |||||
| def _resize_img(self, results): | |||||
| """Resize images with ``results['scale']``.""" | |||||
| for key in results.get('img_fields', ['img']): | |||||
| if self.keep_ratio: | |||||
| img, scale_factor = mmcv.imrescale( | |||||
| results[key], | |||||
| results['scale'], | |||||
| return_scale=True, | |||||
| backend=self.backend) | |||||
| # w_scale and h_scale may differ slightly from the requested scale; | |||||
| # a real fix should be done in mmcv.imrescale in the future | |||||
| new_h, new_w = img.shape[:2] | |||||
| h, w = results[key].shape[:2] | |||||
| w_scale = new_w / w | |||||
| h_scale = new_h / h | |||||
| else: | |||||
| img, w_scale, h_scale = mmcv.imresize( | |||||
| results[key], | |||||
| results['scale'], | |||||
| return_scale=True, | |||||
| backend=self.backend) | |||||
| results[key] = img | |||||
| scale_factor = np.array([w_scale, h_scale, w_scale, h_scale], | |||||
| dtype=np.float32) | |||||
| results['img_shape'] = img.shape | |||||
| # in case that there is no padding | |||||
| results['pad_shape'] = img.shape | |||||
| results['scale_factor'] = scale_factor | |||||
| results['keep_ratio'] = self.keep_ratio | |||||
| def _resize_bboxes(self, results): | |||||
| """Resize bounding boxes with ``results['scale_factor']``.""" | |||||
| for key in results.get('bbox_fields', []): | |||||
| bboxes = results[key] * results['scale_factor'] | |||||
| if self.bbox_clip_border: | |||||
| img_shape = results['img_shape'] | |||||
| bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1]) | |||||
| bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0]) | |||||
| results[key] = bboxes | |||||
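For reference, `scale_factor` is laid out as [w_scale, h_scale, w_scale, h_scale], so a single multiply rescales (x1, y1, x2, y2) boxes. A tiny numeric check (values assumed, not from this diff):

    import numpy as np
    scale_factor = np.array([0.5, 0.5, 0.5, 0.5], dtype=np.float32)
    bboxes = np.array([[100., 40., 300., 200.]])
    print(bboxes * scale_factor)  # [[ 50.  20. 150. 100.]]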
| def _resize_keypoints(self, results): | |||||
| """Resize keypoints with ``results['scale_factor']``.""" | |||||
| for key in results.get('keypoints_fields', []): | |||||
| keypointss = results[key].copy() | |||||
| factors = results['scale_factor'] | |||||
| assert factors[0] == factors[2] | |||||
| assert factors[1] == factors[3] | |||||
| keypointss[:, :, 0] *= factors[0] | |||||
| keypointss[:, :, 1] *= factors[1] | |||||
| if self.bbox_clip_border: | |||||
| img_shape = results['img_shape'] | |||||
| keypointss[:, :, 0] = np.clip(keypointss[:, :, 0], 0, | |||||
| img_shape[1]) | |||||
| keypointss[:, :, 1] = np.clip(keypointss[:, :, 1], 0, | |||||
| img_shape[0]) | |||||
| results[key] = keypointss | |||||
| def _resize_masks(self, results): | |||||
| """Resize masks with ``results['scale']``""" | |||||
| for key in results.get('mask_fields', []): | |||||
| if results[key] is None: | |||||
| continue | |||||
| if self.keep_ratio: | |||||
| results[key] = results[key].rescale(results['scale']) | |||||
| else: | |||||
| results[key] = results[key].resize(results['img_shape'][:2]) | |||||
| def _resize_seg(self, results): | |||||
| """Resize semantic segmentation map with ``results['scale']``.""" | |||||
| for key in results.get('seg_fields', []): | |||||
| if self.keep_ratio: | |||||
| gt_seg = mmcv.imrescale( | |||||
| results[key], | |||||
| results['scale'], | |||||
| interpolation='nearest', | |||||
| backend=self.backend) | |||||
| else: | |||||
| gt_seg = mmcv.imresize( | |||||
| results[key], | |||||
| results['scale'], | |||||
| interpolation='nearest', | |||||
| backend=self.backend) | |||||
| results['gt_semantic_seg'] = gt_seg | |||||
| def __call__(self, results): | |||||
| """Call function to resize images, bounding boxes, masks, semantic | |||||
| segmentation map. | |||||
| Args: | |||||
| results (dict): Result dict from loading pipeline. | |||||
| Returns: | |||||
| dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor', \ | |||||
| 'keep_ratio' keys are added into result dict. | |||||
| """ | |||||
| if 'scale' not in results: | |||||
| if 'scale_factor' in results: | |||||
| img_shape = results['img'].shape[:2] | |||||
| scale_factor = results['scale_factor'] | |||||
| assert isinstance(scale_factor, float) | |||||
| results['scale'] = tuple( | |||||
| [int(x * scale_factor) for x in img_shape][::-1]) | |||||
| else: | |||||
| self._random_scale(results) | |||||
| else: | |||||
| if not self.override: | |||||
| assert 'scale_factor' not in results, ( | |||||
| 'scale and scale_factor cannot be both set.') | |||||
| else: | |||||
| results.pop('scale') | |||||
| if 'scale_factor' in results: | |||||
| results.pop('scale_factor') | |||||
| self._random_scale(results) | |||||
| self._resize_img(results) | |||||
| self._resize_bboxes(results) | |||||
| self._resize_keypoints(results) | |||||
| self._resize_masks(results) | |||||
| self._resize_seg(results) | |||||
| return results | |||||
| def __repr__(self): | |||||
| repr_str = self.__class__.__name__ | |||||
| repr_str += f'(img_scale={self.img_scale}, ' | |||||
| repr_str += f'multiscale_mode={self.multiscale_mode}, ' | |||||
| repr_str += f'ratio_range={self.ratio_range}, ' | |||||
| repr_str += f'keep_ratio={self.keep_ratio}, ' | |||||
| repr_str += f'bbox_clip_border={self.bbox_clip_border})' | |||||
| return repr_str | |||||
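A hypothetical pipeline entry using the registered name; the parameter values are assumptions for illustration, not taken from this diff:

    train_pipeline = [
        dict(type='ResizeV2', img_scale=(640, 640), ratio_range=(0.5, 1.5), keep_ratio=False),
    ]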
| @PIPELINES.register_module() | |||||
| class RandomFlipV2(object): | |||||
| """Flip the image & bbox & mask & kps. | |||||
| If the input dict contains the key "flip", then the flag will be used, | |||||
| otherwise it will be randomly decided by a ratio specified in the init | |||||
| method. | |||||
| When random flip is enabled, ``flip_ratio``/``direction`` can either be a | |||||
| float/string or tuple of float/string. There are 3 flip modes: | |||||
| - ``flip_ratio`` is float, ``direction`` is string: the image will be | |||||
| flipped along ``direction`` with probability ``flip_ratio``. | |||||
| E.g., ``flip_ratio=0.5``, ``direction='horizontal'``, | |||||
| then image will be horizontally flipped with probability of 0.5. | |||||
| - ``flip_ratio`` is float, ``direction`` is list of string: the image | |||||
| will be flipped along ``direction[i]`` with probability | |||||
| ``flip_ratio/len(direction)``. | |||||
| E.g., ``flip_ratio=0.5``, ``direction=['horizontal', 'vertical']``, | |||||
| then image will be horizontally flipped with probability of 0.25, | |||||
| vertically with probability of 0.25. | |||||
| - ``flip_ratio`` is list of float, ``direction`` is list of string: | |||||
| given ``len(flip_ratio) == len(direction)``, the image will | |||||
| be flipped along ``direction[i]`` with probability ``flip_ratio[i]``. | |||||
| E.g., ``flip_ratio=[0.3, 0.5]``, ``direction=['horizontal', | |||||
| 'vertical']``, then image will be horizontally flipped with probability | |||||
| of 0.3, vertically with probability of 0.5. | |||||
| Args: | |||||
| flip_ratio (float | list[float], optional): The flipping probability. | |||||
| Default: None. | |||||
| direction(str | list[str], optional): The flipping direction. Options | |||||
| are 'horizontal', 'vertical', 'diagonal'. Default: 'horizontal'. | |||||
| If input is a list, the length must equal ``flip_ratio``. Each | |||||
| element in ``flip_ratio`` indicates the flip probability of | |||||
| corresponding direction. | |||||
| """ | |||||
| def __init__(self, flip_ratio=None, direction='horizontal'): | |||||
| if isinstance(flip_ratio, list): | |||||
| assert mmcv.is_list_of(flip_ratio, float) | |||||
| assert 0 <= sum(flip_ratio) <= 1 | |||||
| elif isinstance(flip_ratio, float): | |||||
| assert 0 <= flip_ratio <= 1 | |||||
| elif flip_ratio is None: | |||||
| pass | |||||
| else: | |||||
| raise ValueError('flip_ratio must be None, float, ' | |||||
| 'or list of float') | |||||
| self.flip_ratio = flip_ratio | |||||
| valid_directions = ['horizontal', 'vertical', 'diagonal'] | |||||
| if isinstance(direction, str): | |||||
| assert direction in valid_directions | |||||
| elif isinstance(direction, list): | |||||
| assert mmcv.is_list_of(direction, str) | |||||
| assert set(direction).issubset(set(valid_directions)) | |||||
| else: | |||||
| raise ValueError('direction must be either str or list of str') | |||||
| self.direction = direction | |||||
| if isinstance(flip_ratio, list): | |||||
| assert len(self.flip_ratio) == len(self.direction) | |||||
| self.count = 0 | |||||
| def bbox_flip(self, bboxes, img_shape, direction): | |||||
| """Flip bboxes horizontally. | |||||
| Args: | |||||
| bboxes (numpy.ndarray): Bounding boxes, shape (..., 4*k) | |||||
| img_shape (tuple[int]): Image shape (height, width) | |||||
| direction (str): Flip direction. Options are 'horizontal', | |||||
| 'vertical'. | |||||
| Returns: | |||||
| numpy.ndarray: Flipped bounding boxes. | |||||
| """ | |||||
| assert bboxes.shape[-1] % 4 == 0 | |||||
| flipped = bboxes.copy() | |||||
| if direction == 'horizontal': | |||||
| w = img_shape[1] | |||||
| flipped[..., 0::4] = w - bboxes[..., 2::4] | |||||
| flipped[..., 2::4] = w - bboxes[..., 0::4] | |||||
| elif direction == 'vertical': | |||||
| h = img_shape[0] | |||||
| flipped[..., 1::4] = h - bboxes[..., 3::4] | |||||
| flipped[..., 3::4] = h - bboxes[..., 1::4] | |||||
| elif direction == 'diagonal': | |||||
| w = img_shape[1] | |||||
| h = img_shape[0] | |||||
| flipped[..., 0::4] = w - bboxes[..., 2::4] | |||||
| flipped[..., 1::4] = h - bboxes[..., 3::4] | |||||
| flipped[..., 2::4] = w - bboxes[..., 0::4] | |||||
| flipped[..., 3::4] = h - bboxes[..., 1::4] | |||||
| else: | |||||
| raise ValueError(f"Invalid flipping direction '{direction}'") | |||||
| return flipped | |||||
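A small numeric check of the horizontal branch above; image width and box values are assumed:

    import numpy as np
    boxes = np.array([[10., 20., 110., 120.]])
    w = 640
    flipped = boxes.copy()
    flipped[..., 0::4] = w - boxes[..., 2::4]  # new x1 = w - old x2
    flipped[..., 2::4] = w - boxes[..., 0::4]  # new x2 = w - old x1
    print(flipped)  # [[530.  20. 630. 120.]]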
| def keypoints_flip(self, keypointss, img_shape, direction): | |||||
| """Flip keypoints horizontally.""" | |||||
| assert direction == 'horizontal' | |||||
| assert keypointss.shape[-1] == 3 | |||||
| num_kps = keypointss.shape[1] | |||||
| assert num_kps in [4, 5], f'Only support num_kps=4 or 5, got: {num_kps}' | |||||
| assert keypointss.ndim == 3 | |||||
| flipped = keypointss.copy() | |||||
| if num_kps == 5: | |||||
| flip_order = [1, 0, 2, 4, 3] | |||||
| elif num_kps == 4: | |||||
| flip_order = [3, 2, 1, 0] | |||||
| for idx, a in enumerate(flip_order): | |||||
| flipped[:, idx, :] = keypointss[:, a, :] | |||||
| w = img_shape[1] | |||||
| flipped[..., 0] = w - flipped[..., 0] | |||||
| return flipped | |||||
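The flip orders appear to assume the usual landmark layout (for 5 points: left eye, right eye, nose, left mouth corner, right mouth corner — an assumption, not stated in the diff). A sketch of the 5-point case with made-up coordinates:

    import numpy as np
    kps = np.array([[[100., 50., 1.], [140., 50., 1.], [120., 70., 1.],
                     [105., 95., 1.], [135., 95., 1.]]])  # shape (1, 5, 3)
    flipped = kps[:, [1, 0, 2, 4, 3], :].copy()  # swap left/right pairs, keep nose
    flipped[..., 0] = 640 - flipped[..., 0]      # mirror x, image width assumed 640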
| def __call__(self, results): | |||||
| """Call function to flip bounding boxes, masks, semantic segmentation | |||||
| maps. | |||||
| Args: | |||||
| results (dict): Result dict from loading pipeline. | |||||
| Returns: | |||||
| dict: Flipped results, 'flip', 'flip_direction' keys are added \ | |||||
| into result dict. | |||||
| """ | |||||
| if 'flip' not in results: | |||||
| if isinstance(self.direction, list): | |||||
| # None means non-flip | |||||
| direction_list = self.direction + [None] | |||||
| else: | |||||
| # None means non-flip | |||||
| direction_list = [self.direction, None] | |||||
| if isinstance(self.flip_ratio, list): | |||||
| non_flip_ratio = 1 - sum(self.flip_ratio) | |||||
| flip_ratio_list = self.flip_ratio + [non_flip_ratio] | |||||
| else: | |||||
| non_flip_ratio = 1 - self.flip_ratio | |||||
| # exclude non-flip | |||||
| single_ratio = self.flip_ratio / (len(direction_list) - 1) | |||||
| flip_ratio_list = [single_ratio] * (len(direction_list) | |||||
| - 1) + [non_flip_ratio] | |||||
| cur_dir = np.random.choice(direction_list, p=flip_ratio_list) | |||||
| results['flip'] = cur_dir is not None | |||||
| if 'flip_direction' not in results: | |||||
| results['flip_direction'] = cur_dir | |||||
| if results['flip']: | |||||
| # flip image | |||||
| for key in results.get('img_fields', ['img']): | |||||
| results[key] = mmcv.imflip( | |||||
| results[key], direction=results['flip_direction']) | |||||
| # flip bboxes | |||||
| for key in results.get('bbox_fields', []): | |||||
| results[key] = self.bbox_flip(results[key], | |||||
| results['img_shape'], | |||||
| results['flip_direction']) | |||||
| # flip kps | |||||
| for key in results.get('keypoints_fields', []): | |||||
| results[key] = self.keypoints_flip(results[key], | |||||
| results['img_shape'], | |||||
| results['flip_direction']) | |||||
| # flip masks | |||||
| for key in results.get('mask_fields', []): | |||||
| results[key] = results[key].flip(results['flip_direction']) | |||||
| # flip segs | |||||
| for key in results.get('seg_fields', []): | |||||
| results[key] = mmcv.imflip( | |||||
| results[key], direction=results['flip_direction']) | |||||
| return results | |||||
| def __repr__(self): | |||||
| return self.__class__.__name__ + f'(flip_ratio={self.flip_ratio})' | |||||
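The probability bookkeeping in __call__ reduces to the following; a sketch with assumed values:

    flip_ratio, directions = 0.5, ['horizontal', 'vertical']
    single_ratio = flip_ratio / len(directions)                   # 0.25 per direction
    flip_ratio_list = [single_ratio] * len(directions) + [1 - flip_ratio]
    # -> [0.25, 0.25, 0.5]: flip horizontally 25%, vertically 25%, no flip 50%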
| @PIPELINES.register_module() | |||||
| class RandomSquareCrop(object): | |||||
| """Random crop the image & bboxes, the cropped patches have minimum IoU | |||||
| requirement with original image & bboxes, the IoU threshold is randomly | |||||
| selected from min_ious. | |||||
| Args: | |||||
| min_ious (tuple): minimum IoU threshold for all intersections with | |||||
| bounding boxes | |||||
| min_crop_size (float): minimum crop's size (i.e. h,w := a*h, a*w, | |||||
| where a >= min_crop_size). | |||||
| Note: | |||||
| The keys for bboxes, labels and masks should be paired. That is, \ | |||||
| `gt_bboxes` corresponds to `gt_labels` and `gt_masks`, and \ | |||||
| `gt_bboxes_ignore` to `gt_labels_ignore` and `gt_masks_ignore`. | |||||
| """ | |||||
| def __init__(self, | |||||
| crop_ratio_range=None, | |||||
| crop_choice=None, | |||||
| bbox_clip_border=True, | |||||
| big_face_ratio=0, | |||||
| big_face_crop_choice=None): | |||||
| self.crop_ratio_range = crop_ratio_range | |||||
| self.crop_choice = crop_choice | |||||
| self.big_face_crop_choice = big_face_crop_choice | |||||
| self.bbox_clip_border = bbox_clip_border | |||||
| assert (self.crop_ratio_range is None) ^ (self.crop_choice is None) | |||||
| if self.crop_ratio_range is not None: | |||||
| self.crop_ratio_min, self.crop_ratio_max = self.crop_ratio_range | |||||
| self.bbox2label = { | |||||
| 'gt_bboxes': 'gt_labels', | |||||
| 'gt_bboxes_ignore': 'gt_labels_ignore' | |||||
| } | |||||
| self.bbox2mask = { | |||||
| 'gt_bboxes': 'gt_masks', | |||||
| 'gt_bboxes_ignore': 'gt_masks_ignore' | |||||
| } | |||||
| assert big_face_ratio >= 0 and big_face_ratio <= 1.0 | |||||
| self.big_face_ratio = big_face_ratio | |||||
| def __call__(self, results): | |||||
| """Call function to crop images and bounding boxes with minimum IoU | |||||
| constraint. | |||||
| Args: | |||||
| results (dict): Result dict from loading pipeline. | |||||
| Returns: | |||||
| dict: Result dict with images and bounding boxes cropped, \ | |||||
| 'img_shape' key is updated. | |||||
| """ | |||||
| if 'img_fields' in results: | |||||
| assert results['img_fields'] == ['img'], \ | |||||
| 'Only single img_fields is allowed' | |||||
| img = results['img'] | |||||
| assert 'bbox_fields' in results | |||||
| assert 'gt_bboxes' in results | |||||
| # try augment big face images | |||||
| find_bigface = False | |||||
| if np.random.random() < self.big_face_ratio: | |||||
| min_size = 100 # h and w | |||||
| expand_ratio = 0.3 # expand ratio of the cropped face along both w and h | |||||
| bbox = results['gt_bboxes'].copy() | |||||
| lmks = results['gt_keypointss'].copy() | |||||
| label = results['gt_labels'].copy() | |||||
| # filter small faces | |||||
| size_mask = ((bbox[:, 2] - bbox[:, 0]) > min_size) * ( | |||||
| (bbox[:, 3] - bbox[:, 1]) > min_size) | |||||
| bbox = bbox[size_mask] | |||||
| lmks = lmks[size_mask] | |||||
| label = label[size_mask] | |||||
| # randomly choose a face that has no overlap with others | |||||
| if len(bbox) > 0: | |||||
| overlaps = bbox_overlaps(bbox, bbox) | |||||
| overlaps -= np.eye(overlaps.shape[0]) | |||||
| iou_mask = np.sum(overlaps, axis=1) == 0 | |||||
| bbox = bbox[iou_mask] | |||||
| lmks = lmks[iou_mask] | |||||
| label = label[iou_mask] | |||||
| if len(bbox) > 0: | |||||
| choice = np.random.randint(len(bbox)) | |||||
| bbox = bbox[choice] | |||||
| lmks = lmks[choice] | |||||
| label = [label[choice]] | |||||
| w = bbox[2] - bbox[0] | |||||
| h = bbox[3] - bbox[1] | |||||
| x1 = bbox[0] - w * expand_ratio | |||||
| x2 = bbox[2] + w * expand_ratio | |||||
| y1 = bbox[1] - h * expand_ratio | |||||
| y2 = bbox[3] + h * expand_ratio | |||||
| x1, x2 = np.clip([x1, x2], 0, img.shape[1]) | |||||
| y1, y2 = np.clip([y1, y2], 0, img.shape[0]) | |||||
| bbox -= np.tile([x1, y1], 2) | |||||
| lmks -= (x1, y1, 0) | |||||
| find_bigface = True | |||||
| img = img[int(y1):int(y2), int(x1):int(x2), :] | |||||
| results['gt_bboxes'] = np.expand_dims(bbox, axis=0) | |||||
| results['gt_keypointss'] = np.expand_dims(lmks, axis=0) | |||||
| results['gt_labels'] = np.array(label) | |||||
| results['img'] = img | |||||
| boxes = results['gt_bboxes'] | |||||
| h, w, c = img.shape | |||||
| if self.crop_ratio_range is not None: | |||||
| max_scale = self.crop_ratio_max | |||||
| else: | |||||
| max_scale = np.amax(self.crop_choice) | |||||
| scale_retry = 0 | |||||
| while True: | |||||
| scale_retry += 1 | |||||
| if scale_retry == 1 or max_scale > 1.0: | |||||
| if self.crop_ratio_range is not None: | |||||
| scale = np.random.uniform(self.crop_ratio_min, | |||||
| self.crop_ratio_max) | |||||
| elif self.crop_choice is not None: | |||||
| scale = np.random.choice(self.crop_choice) | |||||
| else: | |||||
| scale = scale * 1.2 | |||||
| if find_bigface: | |||||
| # select a scale from big_face_crop_choice if in big_face mode | |||||
| scale = np.random.choice(self.big_face_crop_choice) | |||||
| for i in range(250): | |||||
| long_side = max(w, h) | |||||
| cw = int(scale * long_side) | |||||
| ch = cw | |||||
| # TODO +1 | |||||
| if w == cw: | |||||
| left = 0 | |||||
| elif w > cw: | |||||
| left = random.randint(0, w - cw) | |||||
| else: | |||||
| left = random.randint(w - cw, 0) | |||||
| if h == ch: | |||||
| top = 0 | |||||
| elif h > ch: | |||||
| top = random.randint(0, h - ch) | |||||
| else: | |||||
| top = random.randint(h - ch, 0) | |||||
| patch = np.array( | |||||
| (int(left), int(top), int(left + cw), int(top + ch)), | |||||
| dtype=np.int32) | |||||
| # center of boxes should inside the crop img | |||||
| # only adjust boxes and instance masks when the gt is not empty | |||||
| # adjust boxes | |||||
| def is_center_of_bboxes_in_patch(boxes, patch): | |||||
| # TODO >= | |||||
| center = (boxes[:, :2] + boxes[:, 2:]) / 2 | |||||
| mask = \ | |||||
| ((center[:, 0] > patch[0]) | |||||
| * (center[:, 1] > patch[1]) | |||||
| * (center[:, 0] < patch[2]) | |||||
| * (center[:, 1] < patch[3])) | |||||
| return mask | |||||
| mask = is_center_of_bboxes_in_patch(boxes, patch) | |||||
| if not mask.any(): | |||||
| continue | |||||
| for key in results.get('bbox_fields', []): | |||||
| boxes = results[key].copy() | |||||
| mask = is_center_of_bboxes_in_patch(boxes, patch) | |||||
| boxes = boxes[mask] | |||||
| if self.bbox_clip_border: | |||||
| boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:]) | |||||
| boxes[:, :2] = boxes[:, :2].clip(min=patch[:2]) | |||||
| boxes -= np.tile(patch[:2], 2) | |||||
| results[key] = boxes | |||||
| # labels | |||||
| label_key = self.bbox2label.get(key) | |||||
| if label_key in results: | |||||
| results[label_key] = results[label_key][mask] | |||||
| # keypoints field | |||||
| if key == 'gt_bboxes': | |||||
| for kps_key in results.get('keypoints_fields', []): | |||||
| keypointss = results[kps_key].copy() | |||||
| keypointss = keypointss[mask, :, :] | |||||
| if self.bbox_clip_border: | |||||
| keypointss[:, :, : | |||||
| 2] = keypointss[:, :, :2].clip( | |||||
| max=patch[2:]) | |||||
| keypointss[:, :, : | |||||
| 2] = keypointss[:, :, :2].clip( | |||||
| min=patch[:2]) | |||||
| keypointss[:, :, 0] -= patch[0] | |||||
| keypointss[:, :, 1] -= patch[1] | |||||
| results[kps_key] = keypointss | |||||
| # mask fields | |||||
| mask_key = self.bbox2mask.get(key) | |||||
| if mask_key in results: | |||||
| results[mask_key] = results[mask_key][mask.nonzero() | |||||
| [0]].crop(patch) | |||||
| # adjust the img no matter whether the gt is empty before crop | |||||
| rimg = np.ones((ch, cw, 3), dtype=img.dtype) * 128 | |||||
| patch_from = patch.copy() | |||||
| patch_from[0] = max(0, patch_from[0]) | |||||
| patch_from[1] = max(0, patch_from[1]) | |||||
| patch_from[2] = min(img.shape[1], patch_from[2]) | |||||
| patch_from[3] = min(img.shape[0], patch_from[3]) | |||||
| patch_to = patch.copy() | |||||
| patch_to[0] = max(0, patch_to[0] * -1) | |||||
| patch_to[1] = max(0, patch_to[1] * -1) | |||||
| patch_to[2] = patch_to[0] + (patch_from[2] - patch_from[0]) | |||||
| patch_to[3] = patch_to[1] + (patch_from[3] - patch_from[1]) | |||||
| rimg[patch_to[1]:patch_to[3], | |||||
| patch_to[0]:patch_to[2], :] = img[ | |||||
| patch_from[1]:patch_from[3], | |||||
| patch_from[0]:patch_from[2], :] | |||||
| img = rimg | |||||
| results['img'] = img | |||||
| results['img_shape'] = img.shape | |||||
| return results | |||||
| def __repr__(self): | |||||
| repr_str = self.__class__.__name__ | |||||
| repr_str += f'(crop_ratio_range={self.crop_ratio_range}, ' | |||||
| repr_str += f'crop_choice={self.crop_choice}, ' | |||||
| repr_str += f'bbox_clip_border={self.bbox_clip_border})' | |||||
| return repr_str | |||||
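A hypothetical pipeline entry for the crop transform; the ratios are assumed values, and note that `crop_choice` and `crop_ratio_range` are mutually exclusive per the assert in __init__:

    dict(
        type='RandomSquareCrop',
        crop_choice=[0.3, 0.45, 0.6, 0.8, 1.0],
        big_face_ratio=0.1,
        big_face_crop_choice=[0.7, 0.85, 1.0])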
| @@ -13,7 +13,7 @@ class RetinaFaceDataset(CustomDataset): | |||||
| CLASSES = ('FG', ) | CLASSES = ('FG', ) | ||||
| def __init__(self, min_size=None, **kwargs): | def __init__(self, min_size=None, **kwargs): | ||||
| self.NK = 5 | |||||
| self.NK = kwargs.pop('num_kps', 5) | |||||
| self.cat2label = {cat: i for i, cat in enumerate(self.CLASSES)} | self.cat2label = {cat: i for i, cat in enumerate(self.CLASSES)} | ||||
| self.min_size = min_size | self.min_size = min_size | ||||
| self.gt_path = kwargs.get('gt_path') | self.gt_path = kwargs.get('gt_path') | ||||
| @@ -33,7 +33,8 @@ class RetinaFaceDataset(CustomDataset): | |||||
| if len(values) > 4: | if len(values) > 4: | ||||
| if len(values) > 5: | if len(values) > 5: | ||||
| kps = np.array( | kps = np.array( | ||||
| values[4:19], dtype=np.float32).reshape((self.NK, 3)) | |||||
| values[4:4 + self.NK * 3], dtype=np.float32).reshape( | |||||
| (self.NK, 3)) | |||||
| for li in range(kps.shape[0]): | for li in range(kps.shape[0]): | ||||
| if (kps[li, :] == -1).all(): | if (kps[li, :] == -1).all(): | ||||
| kps[li][2] = 0.0 # weight = 0, ignore | kps[li][2] = 0.0 # weight = 0, ignore | ||||
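The slice change above generalizes keypoint parsing from the hard-coded 15 values to NK groups of (x, y, visibility). A standalone sketch under an assumed annotation layout:

    import numpy as np
    NK = 5
    line = '10 20 110 120' + ' 30 40 1' * NK  # x1 y1 x2 y2, then NK (x, y, vis) triples
    values = line.split()
    kps = np.array(values[4:4 + NK * 3], dtype=np.float32).reshape((NK, 3))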
| @@ -103,6 +103,7 @@ class SCRFDHead(AnchorHead): | |||||
| scale_mode=1, | scale_mode=1, | ||||
| dw_conv=False, | dw_conv=False, | ||||
| use_kps=False, | use_kps=False, | ||||
| num_kps=5, | |||||
| loss_kps=dict( | loss_kps=dict( | ||||
| type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.1), | type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.1), | ||||
| **kwargs): | **kwargs): | ||||
| @@ -116,7 +117,7 @@ class SCRFDHead(AnchorHead): | |||||
| self.scale_mode = scale_mode | self.scale_mode = scale_mode | ||||
| self.use_dfl = True | self.use_dfl = True | ||||
| self.dw_conv = dw_conv | self.dw_conv = dw_conv | ||||
| self.NK = 5 | |||||
| self.NK = num_kps | |||||
| self.extra_flops = 0.0 | self.extra_flops = 0.0 | ||||
| if loss_dfl is None or not loss_dfl: | if loss_dfl is None or not loss_dfl: | ||||
| self.use_dfl = False | self.use_dfl = False | ||||
| @@ -323,8 +324,8 @@ class SCRFDHead(AnchorHead): | |||||
| batch_size, -1, self.cls_out_channels).sigmoid() | batch_size, -1, self.cls_out_channels).sigmoid() | ||||
| bbox_pred = bbox_pred.permute(0, 2, 3, | bbox_pred = bbox_pred.permute(0, 2, 3, | ||||
| 1).reshape(batch_size, -1, 4) | 1).reshape(batch_size, -1, 4) | ||||
| kps_pred = kps_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 10) | |||||
| kps_pred = kps_pred.permute(0, 2, 3, | |||||
| 1).reshape(batch_size, -1, self.NK * 2) | |||||
| return cls_score, bbox_pred, kps_pred | return cls_score, bbox_pred, kps_pred | ||||
| def forward_train(self, | def forward_train(self, | ||||
| @@ -788,7 +789,7 @@ class SCRFDHead(AnchorHead): | |||||
| if self.use_dfl: | if self.use_dfl: | ||||
| kps_pred = self.integral(kps_pred) * stride[0] | kps_pred = self.integral(kps_pred) * stride[0] | ||||
| else: | else: | ||||
| kps_pred = kps_pred.reshape((-1, 10)) * stride[0] | |||||
| kps_pred = kps_pred.reshape((-1, self.NK * 2)) * stride[0] | |||||
| nms_pre = cfg.get('nms_pre', -1) | nms_pre = cfg.get('nms_pre', -1) | ||||
| if nms_pre > 0 and scores.shape[0] > nms_pre: | if nms_pre > 0 and scores.shape[0] > nms_pre: | ||||
| @@ -815,7 +816,7 @@ class SCRFDHead(AnchorHead): | |||||
| mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor) | mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor) | ||||
| if mlvl_kps is not None: | if mlvl_kps is not None: | ||||
| scale_factor2 = torch.tensor( | scale_factor2 = torch.tensor( | ||||
| [scale_factor[0], scale_factor[1]] * 5) | |||||
| [scale_factor[0], scale_factor[1]] * self.NK) | |||||
| mlvl_kps /= scale_factor2.to(mlvl_kps.device) | mlvl_kps /= scale_factor2.to(mlvl_kps.device) | ||||
| mlvl_scores = torch.cat(mlvl_scores) | mlvl_scores = torch.cat(mlvl_scores) | ||||
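Likewise in SCRFDHead, the per-coordinate divisor for keypoints now scales with num_kps instead of the hard-coded 5 points; a minimal sketch with assumed scale factors:

    import torch
    scale_factor = (0.5, 0.5, 0.5, 0.5)  # w_scale, h_scale, w_scale, h_scale
    NK = 4
    scale_factor2 = torch.tensor([scale_factor[0], scale_factor[1]] * NK)  # shape (2*NK,)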