Browse Source

Merge remote-tracking branch 'origin/master' into ofa/finetune

master
行嗔 3 years ago
parent
commit
92a04b010c
4 changed files with 62 additions and 13 deletions
  1. +2
    -0
      modelscope/pipelines/base.py
  2. +16
    -5
      modelscope/pipelines/cv/body_3d_keypoints_pipeline.py
  3. +1
    -1
      modelscope/trainers/hooks/logger/text_logger_hook.py
  4. +43
    -7
      tests/pipelines/test_text_generation.py

+ 2
- 0
modelscope/pipelines/base.py View File

@@ -433,6 +433,8 @@ def collate_fn(data, device):
if isinstance(data, dict) or isinstance(data, Mapping): if isinstance(data, dict) or isinstance(data, Mapping):
return type(data)({k: collate_fn(v, device) for k, v in data.items()}) return type(data)({k: collate_fn(v, device) for k, v in data.items()})
elif isinstance(data, (tuple, list)): elif isinstance(data, (tuple, list)):
if 0 == len(data):
return torch.Tensor([])
if isinstance(data[0], (int, float)): if isinstance(data[0], (int, float)):
return default_collate(data).to(device) return default_collate(data).to(device)
else: else:


+ 16
- 5
modelscope/pipelines/cv/body_3d_keypoints_pipeline.py View File

@@ -143,6 +143,13 @@ class Body3DKeypointsPipeline(Pipeline):
max_frame = self.keypoint_model_3d.cfg.model.INPUT.MAX_FRAME # max video frame number to be predicted 3D joints max_frame = self.keypoint_model_3d.cfg.model.INPUT.MAX_FRAME # max video frame number to be predicted 3D joints
for i, frame in enumerate(video_frames): for i, frame in enumerate(video_frames):
kps_2d = self.human_body_2d_kps_detector(frame) kps_2d = self.human_body_2d_kps_detector(frame)
if [] == kps_2d.get('boxes'):
res = {
'success': False,
'msg': f'fail to detect person at image frame {i}'
}
return res

box = kps_2d['boxes'][ box = kps_2d['boxes'][
0] # box: [[[x1, y1], [x2, y2]]], N human boxes per frame, [0] represent using first detected bbox 0] # box: [[[x1, y1], [x2, y2]]], N human boxes per frame, [0] represent using first detected bbox
pose = kps_2d['keypoints'][0] # keypoints: [15, 2] pose = kps_2d['keypoints'][0] # keypoints: [15, 2]
@@ -180,7 +187,15 @@ class Body3DKeypointsPipeline(Pipeline):
return res return res


def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]: def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]:
res = {OutputKeys.KEYPOINTS: [], OutputKeys.TIMESTAMPS: []}
output_video_path = kwargs.get('output_video', None)
if output_video_path is None:
output_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name

res = {
OutputKeys.KEYPOINTS: [],
OutputKeys.TIMESTAMPS: [],
OutputKeys.OUTPUT_VIDEO: output_video_path
}


if not input['success']: if not input['success']:
pass pass
@@ -189,10 +204,6 @@ class Body3DKeypointsPipeline(Pipeline):
pred_3d_pose = poses.data.cpu().numpy()[ pred_3d_pose = poses.data.cpu().numpy()[
0] # [frame_num, joint_num, joint_dim] 0] # [frame_num, joint_num, joint_dim]


output_video_path = kwargs.get('output_video', None)
if output_video_path is None:
output_video_path = tempfile.NamedTemporaryFile(
suffix='.mp4').name
if 'render' in self.keypoint_model_3d.cfg.keys(): if 'render' in self.keypoint_model_3d.cfg.keys():
self.render_prediction(pred_3d_pose, output_video_path) self.render_prediction(pred_3d_pose, output_video_path)
res[OutputKeys.OUTPUT_VIDEO] = output_video_path res[OutputKeys.OUTPUT_VIDEO] = output_video_path


+ 1
- 1
modelscope/trainers/hooks/logger/text_logger_hook.py View File

@@ -51,7 +51,7 @@ class TextLoggerHook(LoggerHook):
if self.out_dir is None: if self.out_dir is None:
self.out_dir = trainer.work_dir self.out_dir = trainer.work_dir


if not osp.exists(self.out_dir):
if not osp.exists(self.out_dir) and is_master():
os.makedirs(self.out_dir) os.makedirs(self.out_dir)


trainer.logger.info('Text logs will be saved to {}'.format( trainer.logger.info('Text logs will be saved to {}'.format(


+ 43
- 7
tests/pipelines/test_text_generation.py View File

@@ -15,12 +15,17 @@ from modelscope.utils.test_utils import test_level
class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):


def setUp(self) -> None: def setUp(self) -> None:
self.palm_model_id_zh = 'damo/nlp_palm2.0_text-generation_chinese-base'
self.palm_model_id_zh_base = 'damo/nlp_palm2.0_text-generation_chinese-base'
self.palm_model_id_zh_large = 'damo/nlp_palm2.0_text-generation_chinese-large'
self.palm_model_id_zh_commodity = 'damo/nlp_palm2.0_text-generation_commodity_chinese-base'
self.palm_model_id_zh_weather = 'damo/nlp_palm2.0_text-generation_weather_chinese-base'
self.palm_model_id_en = 'damo/nlp_palm2.0_text-generation_english-base' self.palm_model_id_en = 'damo/nlp_palm2.0_text-generation_english-base'
self.palm_input_zh = """ self.palm_input_zh = """
本文总结了十个可穿戴产品的设计原则,而这些原则,同样也是笔者认为是这个行业最吸引人的地方: 本文总结了十个可穿戴产品的设计原则,而这些原则,同样也是笔者认为是这个行业最吸引人的地方:
1.为人们解决重复性问题;2.从人开始,而不是从机器开始;3.要引起注意,但不要刻意;4.提升用户能力,而不是取代 1.为人们解决重复性问题;2.从人开始,而不是从机器开始;3.要引起注意,但不要刻意;4.提升用户能力,而不是取代
""" """
self.palm_input_commodity = '垃圾桶,双层,可拆卸,加高,加高双层,把手,垃圾桶,内附,万向轮'
self.palm_input_weather = "今日天气类型='浮尘'&空气质量等级='重度污染'&紫外线强度指数='中等'"
self.palm_input_en = """ self.palm_input_en = """
The Director of Public Prosecutions who let off Lord Janner over alleged child sex abuse started The Director of Public Prosecutions who let off Lord Janner over alleged child sex abuse started
her career at a legal chambers when the disgraced Labour peer was a top QC there . Alison Saunders , her career at a legal chambers when the disgraced Labour peer was a top QC there . Alison Saunders ,
@@ -51,8 +56,8 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):
print(pipeline_ins(input)) print(pipeline_ins(input))


@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_palm_zh_with_model_name(self):
self.run_pipeline_with_model_id(self.palm_model_id_zh,
def test_palm_zh_base_with_model_name(self):
self.run_pipeline_with_model_id(self.palm_model_id_zh_base,
self.palm_input_zh) self.palm_input_zh)


@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@@ -71,10 +76,40 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):
self.gpt3_input) self.gpt3_input)


@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_palm_zh_with_model_instance(self):
self.run_pipeline_with_model_instance(self.palm_model_id_zh,
def test_palm_zh_large_with_model_name(self):
self.run_pipeline_with_model_id(self.palm_model_id_zh_large,
self.palm_input_zh)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_palm_zh_commodity_with_model_name(self):
self.run_pipeline_with_model_id(self.palm_model_id_zh_commodity,
self.palm_input_commodity)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_palm_zh_weather_with_model_name(self):
self.run_pipeline_with_model_id(self.palm_model_id_zh_weather,
self.palm_input_weather)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_palm_zh_base_with_model_instance(self):
self.run_pipeline_with_model_instance(self.palm_model_id_zh_base,
self.palm_input_zh) self.palm_input_zh)


@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_palm_zh_large_with_model_instance(self):
self.run_pipeline_with_model_instance(self.palm_model_id_zh_large,
self.palm_input_zh)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_palm_zh_commodity_with_model_instance(self):
self.run_pipeline_with_model_instance(self.palm_model_id_zh_commodity,
self.palm_input_commodity)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_palm_zh_weather_with_model_instance(self):
self.run_pipeline_with_model_instance(self.palm_model_id_zh_weather,
self.palm_input_weather)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_palm_en_with_model_instance(self): def test_palm_en_with_model_instance(self):
self.run_pipeline_with_model_instance(self.palm_model_id_en, self.run_pipeline_with_model_instance(self.palm_model_id_en,
@@ -92,8 +127,9 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):


@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_palm(self): def test_run_palm(self):
for model_id, input in ((self.palm_model_id_zh, self.palm_input_zh),
(self.palm_model_id_en, self.palm_input_en)):
for model_id, input in ((self.palm_model_id_zh_base,
self.palm_input_zh), (self.palm_model_id_en,
self.palm_input_en)):
cache_path = snapshot_download(model_id) cache_path = snapshot_download(model_id)
model = PalmForTextGeneration.from_pretrained(cache_path) model = PalmForTextGeneration.from_pretrained(cache_path)
preprocessor = TextGenerationPreprocessor( preprocessor = TextGenerationPreprocessor(


Loading…
Cancel
Save