Browse Source

remove ofa_file_dataset

master
行嗔 3 years ago
parent
commit
a799dd237d
2 changed files with 16 additions and 2 deletions
  1. +16
    -1
      modelscope/models/multi_modal/ofa_for_all_tasks.py
  2. +0
    -1
      modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py

+ 16
- 1
modelscope/models/multi_modal/ofa_for_all_tasks.py View File

@@ -1,8 +1,10 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os
import string
from functools import partial
from os import path as osp
from typing import Any, Dict
from typing import Any, Callable, Dict, List, Optional, Union

import json
import torch.cuda
@@ -295,3 +297,16 @@ class OfaForAllTasks(TorchModel):
self.cfg.model.answer2label)
with open(ans2label_file, 'r') as reader:
self.ans2label_dict = json.load(reader)

def save_pretrained(self,
target_folder: Union[str, os.PathLike],
save_checkpoint_names: Union[str, List[str]] = None,
save_function: Callable = None,
config: Optional[dict] = None,
**kwargs):
super(OfaForAllTasks, self). \
save_pretrained(target_folder=target_folder,
save_checkpoint_names=save_checkpoint_names,
save_function=partial(save_function, with_meta=False),
config=config,
**kwargs)

+ 0
- 1
modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py View File

@@ -12,7 +12,6 @@ from torch.nn.modules.loss import _Loss
from torch.utils.data import Dataset

from modelscope.preprocessors.multi_modal import OfaPreprocessor
from .ofa_file_dataset import OFAFileDataset


class OFADataset(Dataset):


Loading…
Cancel
Save