|
|
|
@@ -8,7 +8,8 @@ import torch.cuda |
|
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
from modelscope.metainfo import Models |
|
|
|
from modelscope.models.base import Model, Tensor |
|
|
|
from modelscope.models import TorchModel |
|
|
|
from modelscope.models.base import Tensor |
|
|
|
from modelscope.models.builder import MODELS |
|
|
|
from modelscope.outputs import OutputKeys |
|
|
|
from modelscope.preprocessors.ofa.utils.collate import collate_tokens |
|
|
|
@@ -32,7 +33,7 @@ __all__ = ['OfaForAllTasks'] |
|
|
|
@MODELS.register_module(Tasks.image_classification, module_name=Models.ofa) |
|
|
|
@MODELS.register_module(Tasks.summarization, module_name=Models.ofa) |
|
|
|
@MODELS.register_module(Tasks.text_classification, module_name=Models.ofa) |
|
|
|
class OfaForAllTasks(Model): |
|
|
|
class OfaForAllTasks(TorchModel): |
|
|
|
|
|
|
|
def __init__(self, model_dir, *args, **kwargs): |
|
|
|
super().__init__(model_dir=model_dir, *args, **kwargs) |
|
|
|
|