Browse Source

[to #42322933] support finetune on cv/hand_2d_keypoints

添加2d手部关键点检测finetune功能
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10371710
master
hanyuan.chy yingda.chen 3 years ago
parent
commit
2d50c812df
6 changed files with 170 additions and 0 deletions
  1. +2
    -0
      modelscope/metainfo.py
  2. +20
    -0
      modelscope/models/cv/hand_2d_keypoints/__init__.py
  3. +16
    -0
      modelscope/models/cv/hand_2d_keypoints/hand_2d_keypoints.py
  4. +22
    -0
      modelscope/msdatasets/cv/hand_2d_keypoints/__init__.py
  5. +38
    -0
      modelscope/msdatasets/cv/hand_2d_keypoints/hand_2d_keypoints_dataset.py
  6. +72
    -0
      tests/trainers/easycv/test_easycv_trainer_hand_2d_keypoints.py

+ 2
- 0
modelscope/metainfo.py View File

@@ -50,6 +50,7 @@ class Models(object):
# EasyCV models # EasyCV models
yolox = 'YOLOX' yolox = 'YOLOX'
segformer = 'Segformer' segformer = 'Segformer'
hand_2d_keypoints = 'HRNet-Hand2D-Keypoints'
image_object_detection_auto = 'image-object-detection-auto' image_object_detection_auto = 'image-object-detection-auto'


# nlp models # nlp models
@@ -439,6 +440,7 @@ class Datasets(object):
""" """
ClsDataset = 'ClsDataset' ClsDataset = 'ClsDataset'
Face2dKeypointsDataset = 'Face2dKeypointsDataset' Face2dKeypointsDataset = 'Face2dKeypointsDataset'
HandCocoWholeBodyDataset = 'HandCocoWholeBodyDataset'
HumanWholeBodyKeypointDataset = 'HumanWholeBodyKeypointDataset' HumanWholeBodyKeypointDataset = 'HumanWholeBodyKeypointDataset'
SegDataset = 'SegDataset' SegDataset = 'SegDataset'
DetDataset = 'DetDataset' DetDataset = 'DetDataset'


+ 20
- 0
modelscope/models/cv/hand_2d_keypoints/__init__.py View File

@@ -0,0 +1,20 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
from .hand_2d_keypoints import Hand2dKeyPoints

else:
_import_structure = {'hand_2d_keypoints': ['Hand2dKeyPoints']}

import sys

sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

+ 16
- 0
modelscope/models/cv/hand_2d_keypoints/hand_2d_keypoints.py View File

@@ -0,0 +1,16 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from easycv.models.pose import TopDown

from modelscope.metainfo import Models
from modelscope.models.builder import MODELS
from modelscope.models.cv.easycv_base import EasyCVBaseModel
from modelscope.utils.constant import Tasks


@MODELS.register_module(
group_key=Tasks.hand_2d_keypoints, module_name=Models.hand_2d_keypoints)
class Hand2dKeyPoints(EasyCVBaseModel, TopDown):

def __init__(self, model_dir=None, *args, **kwargs):
EasyCVBaseModel.__init__(self, model_dir, args, kwargs)
TopDown.__init__(self, *args, **kwargs)

+ 22
- 0
modelscope/msdatasets/cv/hand_2d_keypoints/__init__.py View File

@@ -0,0 +1,22 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
from .hand_2d_keypoints_dataset import Hand2DKeypointDataset

else:
_import_structure = {
'hand_2d_keypoints_dataset': ['Hand2DKeypointDataset']
}

import sys

sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

+ 38
- 0
modelscope/msdatasets/cv/hand_2d_keypoints/hand_2d_keypoints_dataset.py View File

@@ -0,0 +1,38 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from easycv.datasets.pose import \
HandCocoWholeBodyDataset as _HandCocoWholeBodyDataset

from modelscope.metainfo import Datasets
from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
from modelscope.utils.constant import Tasks


@TASK_DATASETS.register_module(
group_key=Tasks.hand_2d_keypoints,
module_name=Datasets.HandCocoWholeBodyDataset)
class HandCocoWholeBodyDataset(EasyCVBaseDataset, _HandCocoWholeBodyDataset):
"""EasyCV dataset for human hand 2d keypoints.

Args:
split_config (dict): Dataset root path from MSDataset, e.g.
{"train":"local cache path"} or {"evaluation":"local cache path"}.
preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for
the model if supplied. Not support yet.
mode: Training or Evaluation.
"""

def __init__(self,
split_config=None,
preprocessor=None,
mode=None,
*args,
**kwargs) -> None:
EasyCVBaseDataset.__init__(
self,
split_config=split_config,
preprocessor=preprocessor,
mode=mode,
args=args,
kwargs=kwargs)
_HandCocoWholeBodyDataset.__init__(self, *args, **kwargs)

+ 72
- 0
tests/trainers/easycv/test_easycv_trainer_hand_2d_keypoints.py View File

@@ -0,0 +1,72 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import glob
import os
import shutil
import tempfile
import unittest

import torch

from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.utils.constant import DownloadMode, LogKeys, Tasks
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import test_level


@unittest.skipIf(not torch.cuda.is_available(), 'cuda unittest')
class EasyCVTrainerTestHand2dKeypoints(unittest.TestCase):
model_id = 'damo/cv_hrnetw18_hand-pose-keypoints_coco-wholebody'

def setUp(self):
self.logger = get_logger()
self.logger.info(('Testing %s.%s' %
(type(self).__name__, self._testMethodName)))
self.tmp_dir = tempfile.TemporaryDirectory().name
if not os.path.exists(self.tmp_dir):
os.makedirs(self.tmp_dir)

def tearDown(self):
super().tearDown()
shutil.rmtree(self.tmp_dir, ignore_errors=True)

def _train(self):
cfg_options = {'train.max_epochs': 20}

trainer_name = Trainers.easycv

train_dataset = MsDataset.load(
dataset_name='cv_hand_2d_keypoints_coco_wholebody',
namespace='chenhyer',
split='subtrain',
download_mode=DownloadMode.FORCE_REDOWNLOAD)
eval_dataset = MsDataset.load(
dataset_name='cv_hand_2d_keypoints_coco_wholebody',
namespace='chenhyer',
split='subtrain',
download_mode=DownloadMode.FORCE_REDOWNLOAD)

kwargs = dict(
model=self.model_id,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
work_dir=self.tmp_dir,
cfg_options=cfg_options)

trainer = build_trainer(trainer_name, kwargs)
trainer.train()

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_trainer_single_gpu(self):
self._train()

results_files = os.listdir(self.tmp_dir)
json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json'))
self.assertEqual(len(json_files), 1)
self.assertIn(f'{LogKeys.EPOCH}_10.pth', results_files)
self.assertIn(f'{LogKeys.EPOCH}_20.pth', results_files)


if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save