Browse Source

Merge remote-tracking branch 'origin/master' into ofa/finetune

master
行嗔 3 years ago
parent
commit
92a04b010c
4 changed files with 62 additions and 13 deletions
  1. +2
    -0
      modelscope/pipelines/base.py
  2. +16
    -5
      modelscope/pipelines/cv/body_3d_keypoints_pipeline.py
  3. +1
    -1
      modelscope/trainers/hooks/logger/text_logger_hook.py
  4. +43
    -7
      tests/pipelines/test_text_generation.py

+ 2
- 0
modelscope/pipelines/base.py View File

@@ -433,6 +433,8 @@ def collate_fn(data, device):
if isinstance(data, dict) or isinstance(data, Mapping):
return type(data)({k: collate_fn(v, device) for k, v in data.items()})
elif isinstance(data, (tuple, list)):
if 0 == len(data):
return torch.Tensor([])
if isinstance(data[0], (int, float)):
return default_collate(data).to(device)
else:


+ 16
- 5
modelscope/pipelines/cv/body_3d_keypoints_pipeline.py View File

@@ -143,6 +143,13 @@ class Body3DKeypointsPipeline(Pipeline):
max_frame = self.keypoint_model_3d.cfg.model.INPUT.MAX_FRAME # max video frame number to be predicted 3D joints
for i, frame in enumerate(video_frames):
kps_2d = self.human_body_2d_kps_detector(frame)
if [] == kps_2d.get('boxes'):
res = {
'success': False,
'msg': f'fail to detect person at image frame {i}'
}
return res

box = kps_2d['boxes'][
0] # box: [[[x1, y1], [x2, y2]]], N human boxes per frame, [0] represent using first detected bbox
pose = kps_2d['keypoints'][0] # keypoints: [15, 2]
@@ -180,7 +187,15 @@ class Body3DKeypointsPipeline(Pipeline):
return res

def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]:
res = {OutputKeys.KEYPOINTS: [], OutputKeys.TIMESTAMPS: []}
output_video_path = kwargs.get('output_video', None)
if output_video_path is None:
output_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name

res = {
OutputKeys.KEYPOINTS: [],
OutputKeys.TIMESTAMPS: [],
OutputKeys.OUTPUT_VIDEO: output_video_path
}

if not input['success']:
pass
@@ -189,10 +204,6 @@ class Body3DKeypointsPipeline(Pipeline):
pred_3d_pose = poses.data.cpu().numpy()[
0] # [frame_num, joint_num, joint_dim]

output_video_path = kwargs.get('output_video', None)
if output_video_path is None:
output_video_path = tempfile.NamedTemporaryFile(
suffix='.mp4').name
if 'render' in self.keypoint_model_3d.cfg.keys():
self.render_prediction(pred_3d_pose, output_video_path)
res[OutputKeys.OUTPUT_VIDEO] = output_video_path


+ 1
- 1
modelscope/trainers/hooks/logger/text_logger_hook.py View File

@@ -51,7 +51,7 @@ class TextLoggerHook(LoggerHook):
if self.out_dir is None:
self.out_dir = trainer.work_dir

if not osp.exists(self.out_dir):
if not osp.exists(self.out_dir) and is_master():
os.makedirs(self.out_dir)

trainer.logger.info('Text logs will be saved to {}'.format(


+ 43
- 7
tests/pipelines/test_text_generation.py View File

@@ -15,12 +15,17 @@ from modelscope.utils.test_utils import test_level
class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):

def setUp(self) -> None:
self.palm_model_id_zh = 'damo/nlp_palm2.0_text-generation_chinese-base'
self.palm_model_id_zh_base = 'damo/nlp_palm2.0_text-generation_chinese-base'
self.palm_model_id_zh_large = 'damo/nlp_palm2.0_text-generation_chinese-large'
self.palm_model_id_zh_commodity = 'damo/nlp_palm2.0_text-generation_commodity_chinese-base'
self.palm_model_id_zh_weather = 'damo/nlp_palm2.0_text-generation_weather_chinese-base'
self.palm_model_id_en = 'damo/nlp_palm2.0_text-generation_english-base'
self.palm_input_zh = """
本文总结了十个可穿戴产品的设计原则,而这些原则,同样也是笔者认为是这个行业最吸引人的地方:
1.为人们解决重复性问题;2.从人开始,而不是从机器开始;3.要引起注意,但不要刻意;4.提升用户能力,而不是取代
"""
self.palm_input_commodity = '垃圾桶,双层,可拆卸,加高,加高双层,把手,垃圾桶,内附,万向轮'
self.palm_input_weather = "今日天气类型='浮尘'&空气质量等级='重度污染'&紫外线强度指数='中等'"
self.palm_input_en = """
The Director of Public Prosecutions who let off Lord Janner over alleged child sex abuse started
her career at a legal chambers when the disgraced Labour peer was a top QC there . Alison Saunders ,
@@ -51,8 +56,8 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):
print(pipeline_ins(input))

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_palm_zh_with_model_name(self):
self.run_pipeline_with_model_id(self.palm_model_id_zh,
def test_palm_zh_base_with_model_name(self):
self.run_pipeline_with_model_id(self.palm_model_id_zh_base,
self.palm_input_zh)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@@ -71,10 +76,40 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):
self.gpt3_input)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_palm_zh_with_model_instance(self):
self.run_pipeline_with_model_instance(self.palm_model_id_zh,
def test_palm_zh_large_with_model_name(self):
self.run_pipeline_with_model_id(self.palm_model_id_zh_large,
self.palm_input_zh)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_palm_zh_commodity_with_model_name(self):
self.run_pipeline_with_model_id(self.palm_model_id_zh_commodity,
self.palm_input_commodity)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_palm_zh_weather_with_model_name(self):
self.run_pipeline_with_model_id(self.palm_model_id_zh_weather,
self.palm_input_weather)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_palm_zh_base_with_model_instance(self):
self.run_pipeline_with_model_instance(self.palm_model_id_zh_base,
self.palm_input_zh)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_palm_zh_large_with_model_instance(self):
self.run_pipeline_with_model_instance(self.palm_model_id_zh_large,
self.palm_input_zh)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_palm_zh_commodity_with_model_instance(self):
self.run_pipeline_with_model_instance(self.palm_model_id_zh_commodity,
self.palm_input_commodity)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_palm_zh_weather_with_model_instance(self):
self.run_pipeline_with_model_instance(self.palm_model_id_zh_weather,
self.palm_input_weather)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_palm_en_with_model_instance(self):
self.run_pipeline_with_model_instance(self.palm_model_id_en,
@@ -92,8 +127,9 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_palm(self):
for model_id, input in ((self.palm_model_id_zh, self.palm_input_zh),
(self.palm_model_id_en, self.palm_input_en)):
for model_id, input in ((self.palm_model_id_zh_base,
self.palm_input_zh), (self.palm_model_id_en,
self.palm_input_en)):
cache_path = snapshot_download(model_id)
model = PalmForTextGeneration.from_pretrained(cache_path)
preprocessor = TextGenerationPreprocessor(


Loading…
Cancel
Save