|
|
|
@@ -310,9 +310,6 @@ def pipeline(task: str = None, |
|
|
|
model[0], revision=model_revision) |
|
|
|
check_config(cfg) |
|
|
|
pipeline_name = cfg.pipeline.type |
|
|
|
else: |
|
|
|
# used for test case, when model is str and is not hub path |
|
|
|
pipeline_name = get_pipeline_by_model_name(task, model) |
|
|
|
elif model is not None: |
|
|
|
# get pipeline info from Model object |
|
|
|
first_model = model[0] if isinstance(model, list) else model |
|
|
|
@@ -375,19 +372,3 @@ def get_default_pipeline_info(task): |
|
|
|
else: |
|
|
|
pipeline_name, default_model = DEFAULT_MODEL_FOR_PIPELINE[task] |
|
|
|
return pipeline_name, default_model |
|
|
|
|
|
|
|
|
|
|
|
def get_pipeline_by_model_name(task: str, model: Union[str, List[str]]): |
|
|
|
""" Get pipeline name by task name and model name |
|
|
|
|
|
|
|
Args: |
|
|
|
task (str): task name. |
|
|
|
model (str| list[str]): model names |
|
|
|
""" |
|
|
|
if isinstance(model, str): |
|
|
|
model_key = model |
|
|
|
else: |
|
|
|
model_key = '_'.join(model) |
|
|
|
assert model_key in PIPELINES.modules[task], \ |
|
|
|
f'pipeline for task {task} model {model_key} not found.' |
|
|
|
return model_key |