添加2d手部关键点检测finetune功能
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10371710
master
| @@ -50,6 +50,7 @@ class Models(object): | |||||
| # EasyCV models | # EasyCV models | ||||
| yolox = 'YOLOX' | yolox = 'YOLOX' | ||||
| segformer = 'Segformer' | segformer = 'Segformer' | ||||
| hand_2d_keypoints = 'HRNet-Hand2D-Keypoints' | |||||
| image_object_detection_auto = 'image-object-detection-auto' | image_object_detection_auto = 'image-object-detection-auto' | ||||
| # nlp models | # nlp models | ||||
| @@ -439,6 +440,7 @@ class Datasets(object): | |||||
| """ | """ | ||||
| ClsDataset = 'ClsDataset' | ClsDataset = 'ClsDataset' | ||||
| Face2dKeypointsDataset = 'Face2dKeypointsDataset' | Face2dKeypointsDataset = 'Face2dKeypointsDataset' | ||||
| HandCocoWholeBodyDataset = 'HandCocoWholeBodyDataset' | |||||
| HumanWholeBodyKeypointDataset = 'HumanWholeBodyKeypointDataset' | HumanWholeBodyKeypointDataset = 'HumanWholeBodyKeypointDataset' | ||||
| SegDataset = 'SegDataset' | SegDataset = 'SegDataset' | ||||
| DetDataset = 'DetDataset' | DetDataset = 'DetDataset' | ||||
| @@ -0,0 +1,20 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from typing import TYPE_CHECKING | |||||
| from modelscope.utils.import_utils import LazyImportModule | |||||
| if TYPE_CHECKING: | |||||
| from .hand_2d_keypoints import Hand2dKeyPoints | |||||
| else: | |||||
| _import_structure = {'hand_2d_keypoints': ['Hand2dKeyPoints']} | |||||
| import sys | |||||
| sys.modules[__name__] = LazyImportModule( | |||||
| __name__, | |||||
| globals()['__file__'], | |||||
| _import_structure, | |||||
| module_spec=__spec__, | |||||
| extra_objects={}, | |||||
| ) | |||||
| @@ -0,0 +1,16 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from easycv.models.pose import TopDown | |||||
| from modelscope.metainfo import Models | |||||
| from modelscope.models.builder import MODELS | |||||
| from modelscope.models.cv.easycv_base import EasyCVBaseModel | |||||
| from modelscope.utils.constant import Tasks | |||||
| @MODELS.register_module( | |||||
| group_key=Tasks.hand_2d_keypoints, module_name=Models.hand_2d_keypoints) | |||||
| class Hand2dKeyPoints(EasyCVBaseModel, TopDown): | |||||
| def __init__(self, model_dir=None, *args, **kwargs): | |||||
| EasyCVBaseModel.__init__(self, model_dir, args, kwargs) | |||||
| TopDown.__init__(self, *args, **kwargs) | |||||
| @@ -0,0 +1,22 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from typing import TYPE_CHECKING | |||||
| from modelscope.utils.import_utils import LazyImportModule | |||||
| if TYPE_CHECKING: | |||||
| from .hand_2d_keypoints_dataset import Hand2DKeypointDataset | |||||
| else: | |||||
| _import_structure = { | |||||
| 'hand_2d_keypoints_dataset': ['Hand2DKeypointDataset'] | |||||
| } | |||||
| import sys | |||||
| sys.modules[__name__] = LazyImportModule( | |||||
| __name__, | |||||
| globals()['__file__'], | |||||
| _import_structure, | |||||
| module_spec=__spec__, | |||||
| extra_objects={}, | |||||
| ) | |||||
| @@ -0,0 +1,38 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from easycv.datasets.pose import \ | |||||
| HandCocoWholeBodyDataset as _HandCocoWholeBodyDataset | |||||
| from modelscope.metainfo import Datasets | |||||
| from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset | |||||
| from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS | |||||
| from modelscope.utils.constant import Tasks | |||||
| @TASK_DATASETS.register_module( | |||||
| group_key=Tasks.hand_2d_keypoints, | |||||
| module_name=Datasets.HandCocoWholeBodyDataset) | |||||
| class HandCocoWholeBodyDataset(EasyCVBaseDataset, _HandCocoWholeBodyDataset): | |||||
| """EasyCV dataset for human hand 2d keypoints. | |||||
| Args: | |||||
| split_config (dict): Dataset root path from MSDataset, e.g. | |||||
| {"train":"local cache path"} or {"evaluation":"local cache path"}. | |||||
| preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for | |||||
| the model if supplied. Not support yet. | |||||
| mode: Training or Evaluation. | |||||
| """ | |||||
| def __init__(self, | |||||
| split_config=None, | |||||
| preprocessor=None, | |||||
| mode=None, | |||||
| *args, | |||||
| **kwargs) -> None: | |||||
| EasyCVBaseDataset.__init__( | |||||
| self, | |||||
| split_config=split_config, | |||||
| preprocessor=preprocessor, | |||||
| mode=mode, | |||||
| args=args, | |||||
| kwargs=kwargs) | |||||
| _HandCocoWholeBodyDataset.__init__(self, *args, **kwargs) | |||||
| @@ -0,0 +1,72 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import glob | |||||
| import os | |||||
| import shutil | |||||
| import tempfile | |||||
| import unittest | |||||
| import torch | |||||
| from modelscope.metainfo import Trainers | |||||
| from modelscope.msdatasets import MsDataset | |||||
| from modelscope.trainers import build_trainer | |||||
| from modelscope.utils.constant import DownloadMode, LogKeys, Tasks | |||||
| from modelscope.utils.logger import get_logger | |||||
| from modelscope.utils.test_utils import test_level | |||||
| @unittest.skipIf(not torch.cuda.is_available(), 'cuda unittest') | |||||
| class EasyCVTrainerTestHand2dKeypoints(unittest.TestCase): | |||||
| model_id = 'damo/cv_hrnetw18_hand-pose-keypoints_coco-wholebody' | |||||
| def setUp(self): | |||||
| self.logger = get_logger() | |||||
| self.logger.info(('Testing %s.%s' % | |||||
| (type(self).__name__, self._testMethodName))) | |||||
| self.tmp_dir = tempfile.TemporaryDirectory().name | |||||
| if not os.path.exists(self.tmp_dir): | |||||
| os.makedirs(self.tmp_dir) | |||||
| def tearDown(self): | |||||
| super().tearDown() | |||||
| shutil.rmtree(self.tmp_dir, ignore_errors=True) | |||||
| def _train(self): | |||||
| cfg_options = {'train.max_epochs': 20} | |||||
| trainer_name = Trainers.easycv | |||||
| train_dataset = MsDataset.load( | |||||
| dataset_name='cv_hand_2d_keypoints_coco_wholebody', | |||||
| namespace='chenhyer', | |||||
| split='subtrain', | |||||
| download_mode=DownloadMode.FORCE_REDOWNLOAD) | |||||
| eval_dataset = MsDataset.load( | |||||
| dataset_name='cv_hand_2d_keypoints_coco_wholebody', | |||||
| namespace='chenhyer', | |||||
| split='subtrain', | |||||
| download_mode=DownloadMode.FORCE_REDOWNLOAD) | |||||
| kwargs = dict( | |||||
| model=self.model_id, | |||||
| train_dataset=train_dataset, | |||||
| eval_dataset=eval_dataset, | |||||
| work_dir=self.tmp_dir, | |||||
| cfg_options=cfg_options) | |||||
| trainer = build_trainer(trainer_name, kwargs) | |||||
| trainer.train() | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_trainer_single_gpu(self): | |||||
| self._train() | |||||
| results_files = os.listdir(self.tmp_dir) | |||||
| json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json')) | |||||
| self.assertEqual(len(json_files), 1) | |||||
| self.assertIn(f'{LogKeys.EPOCH}_10.pth', results_files) | |||||
| self.assertIn(f'{LogKeys.EPOCH}_20.pth', results_files) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||