|
|
|
@@ -91,7 +91,6 @@ class Model(ABC): |
|
|
|
osp.join(local_model_dir, ModelFile.CONFIGURATION)) |
|
|
|
task_name = cfg.task |
|
|
|
model_cfg = cfg.model |
|
|
|
framework = cfg.framework |
|
|
|
|
|
|
|
if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'): |
|
|
|
model_cfg.type = model_cfg.model_type |
|
|
|
@@ -101,9 +100,8 @@ class Model(ABC): |
|
|
|
model_cfg[k] = v |
|
|
|
if device is not None: |
|
|
|
model_cfg.device = device |
|
|
|
with device_placement(framework, device): |
|
|
|
model = build_model( |
|
|
|
model_cfg, task_name=task_name, default_args=kwargs) |
|
|
|
model = build_model( |
|
|
|
model_cfg, task_name=task_name, default_args=kwargs) |
|
|
|
else: |
|
|
|
model = build_model( |
|
|
|
model_cfg, task_name=task_name, default_args=kwargs) |
|
|
|
|