Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10421808 * add face 2d keypoints & human wholebody keypoint finrtune test casemaster
| @@ -452,9 +452,9 @@ class Datasets(object): | |||||
| """ Names for different datasets. | """ Names for different datasets. | ||||
| """ | """ | ||||
| ClsDataset = 'ClsDataset' | ClsDataset = 'ClsDataset' | ||||
| Face2dKeypointsDataset = 'Face2dKeypointsDataset' | |||||
| Face2dKeypointsDataset = 'FaceKeypointDataset' | |||||
| HandCocoWholeBodyDataset = 'HandCocoWholeBodyDataset' | HandCocoWholeBodyDataset = 'HandCocoWholeBodyDataset' | ||||
| HumanWholeBodyKeypointDataset = 'HumanWholeBodyKeypointDataset' | |||||
| HumanWholeBodyKeypointDataset = 'WholeBodyCocoTopDownDataset' | |||||
| SegDataset = 'SegDataset' | SegDataset = 'SegDataset' | ||||
| DetDataset = 'DetDataset' | DetDataset = 'DetDataset' | ||||
| DetImagesMixDataset = 'DetImagesMixDataset' | DetImagesMixDataset = 'DetImagesMixDataset' | ||||
| @@ -26,11 +26,16 @@ class EasyCVBaseDataset(object): | |||||
| if self.split_config is not None: | if self.split_config is not None: | ||||
| self._update_data_source(kwargs['data_source']) | self._update_data_source(kwargs['data_source']) | ||||
| def _update_data_root(self, input_dict, data_root): | |||||
| for k, v in input_dict.items(): | |||||
| if isinstance(v, str) and self.DATA_ROOT_PATTERN in v: | |||||
| input_dict.update( | |||||
| {k: v.replace(self.DATA_ROOT_PATTERN, data_root)}) | |||||
| elif isinstance(v, dict): | |||||
| self._update_data_root(v, data_root) | |||||
| def _update_data_source(self, data_source): | def _update_data_source(self, data_source): | ||||
| data_root = next(iter(self.split_config.values())) | data_root = next(iter(self.split_config.values())) | ||||
| data_root = data_root.rstrip(osp.sep) | data_root = data_root.rstrip(osp.sep) | ||||
| for k, v in data_source.items(): | |||||
| if isinstance(v, str) and self.DATA_ROOT_PATTERN in v: | |||||
| data_source.update( | |||||
| {k: v.replace(self.DATA_ROOT_PATTERN, data_root)}) | |||||
| self._update_data_root(data_source, data_root) | |||||
| @@ -19,7 +19,7 @@ moviepy>=1.0.3 | |||||
| networkx>=2.5 | networkx>=2.5 | ||||
| numba | numba | ||||
| onnxruntime>=1.10 | onnxruntime>=1.10 | ||||
| pai-easycv>=0.6.3.7 | |||||
| pai-easycv>=0.6.3.9 | |||||
| pandas | pandas | ||||
| psutil | psutil | ||||
| regex | regex | ||||
| @@ -0,0 +1,71 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import glob | |||||
| import os | |||||
| import shutil | |||||
| import tempfile | |||||
| import unittest | |||||
| import torch | |||||
| from modelscope.metainfo import Trainers | |||||
| from modelscope.msdatasets import MsDataset | |||||
| from modelscope.trainers import build_trainer | |||||
| from modelscope.utils.constant import DownloadMode, LogKeys, Tasks | |||||
| from modelscope.utils.logger import get_logger | |||||
| from modelscope.utils.test_utils import test_level | |||||
| @unittest.skipIf(not torch.cuda.is_available(), 'cuda unittest') | |||||
| class EasyCVTrainerTestFace2DKeypoints(unittest.TestCase): | |||||
| model_id = 'damo/cv_mobilenet_face-2d-keypoints_alignment' | |||||
| def setUp(self): | |||||
| self.logger = get_logger() | |||||
| self.logger.info(('Testing %s.%s' % | |||||
| (type(self).__name__, self._testMethodName))) | |||||
| def _train(self, tmp_dir): | |||||
| cfg_options = {'train.max_epochs': 2} | |||||
| trainer_name = Trainers.easycv | |||||
| train_dataset = MsDataset.load( | |||||
| dataset_name='face_2d_keypoints_dataset', | |||||
| namespace='modelscope', | |||||
| split='train', | |||||
| download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS) | |||||
| eval_dataset = MsDataset.load( | |||||
| dataset_name='face_2d_keypoints_dataset', | |||||
| namespace='modelscope', | |||||
| split='train', | |||||
| download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS) | |||||
| kwargs = dict( | |||||
| model=self.model_id, | |||||
| train_dataset=train_dataset, | |||||
| eval_dataset=eval_dataset, | |||||
| work_dir=tmp_dir, | |||||
| cfg_options=cfg_options) | |||||
| trainer = build_trainer(trainer_name, kwargs) | |||||
| trainer.train() | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_trainer_single_gpu(self): | |||||
| temp_file_dir = tempfile.TemporaryDirectory() | |||||
| tmp_dir = temp_file_dir.name | |||||
| if not os.path.exists(tmp_dir): | |||||
| os.makedirs(tmp_dir) | |||||
| self._train(tmp_dir) | |||||
| results_files = os.listdir(tmp_dir) | |||||
| json_files = glob.glob(os.path.join(tmp_dir, '*.log.json')) | |||||
| self.assertEqual(len(json_files), 1) | |||||
| self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files) | |||||
| temp_file_dir.cleanup() | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||