|
|
|
@@ -1,8 +1,10 @@ |
|
|
|
# Copyright (c) Alibaba, Inc. and its affiliates. |
|
|
|
import math |
|
|
|
import os |
|
|
|
import string |
|
|
|
from functools import partial |
|
|
|
from os import path as osp |
|
|
|
from typing import Any, Dict |
|
|
|
from typing import Any, Callable, Dict, List, Optional, Union |
|
|
|
|
|
|
|
import json |
|
|
|
import torch.cuda |
|
|
|
@@ -295,3 +297,16 @@ class OfaForAllTasks(TorchModel): |
|
|
|
self.cfg.model.answer2label) |
|
|
|
with open(ans2label_file, 'r') as reader: |
|
|
|
self.ans2label_dict = json.load(reader) |
|
|
|
|
|
|
|
def save_pretrained(self, |
|
|
|
target_folder: Union[str, os.PathLike], |
|
|
|
save_checkpoint_names: Union[str, List[str]] = None, |
|
|
|
save_function: Callable = None, |
|
|
|
config: Optional[dict] = None, |
|
|
|
**kwargs): |
|
|
|
super(OfaForAllTasks, self). \ |
|
|
|
save_pretrained(target_folder=target_folder, |
|
|
|
save_checkpoint_names=save_checkpoint_names, |
|
|
|
save_function=partial(save_function, with_meta=False), |
|
|
|
config=config, |
|
|
|
**kwargs) |