Browse Source

refactor inputs format of model forward

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9673243

    * refactor inputs format of model forward
master
jiangnana.jnn 3 years ago
parent
commit
652ec697b7
4 changed files with 25 additions and 39 deletions
  1. +8
    -13
      modelscope/models/base/base_head.py
  2. +8
    -14
      modelscope/models/base/base_model.py
  3. +3
    -5
      modelscope/models/base/base_torch_head.py
  4. +6
    -7
      modelscope/models/base/base_torch_model.py

+ 8
- 13
modelscope/models/base/base_head.py View File

@@ -1,6 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, Union
from typing import Any, Dict, Union


from modelscope.models.base.base_model import Model from modelscope.models.base.base_model import Model
from modelscope.utils.config import ConfigDict from modelscope.utils.config import ConfigDict
@@ -22,25 +22,20 @@ class Head(ABC):
self.config = ConfigDict(kwargs) self.config = ConfigDict(kwargs)


@abstractmethod @abstractmethod
def forward(self, input: Input) -> Dict[str, Tensor]:
def forward(self, *args, **kwargs) -> Dict[str, Any]:
""" """
This method will use the output from backbone model to do any This method will use the output from backbone model to do any
downstream tasks
Args:
input: The tensor output or a model from backbone model
(text generation need a model as input)
Returns: The output from downstream taks
downstream tasks. Recieve The output from backbone model.

Returns (Dict[str, Any]): The output from downstream task.
""" """
pass pass


@abstractmethod @abstractmethod
def compute_loss(self, outputs: Dict[str, Tensor],
labels) -> Dict[str, Tensor]:
def compute_loss(self, *args, **kwargs) -> Dict[str, Any]:
""" """
compute loss for head during the finetuning
compute loss for head during the finetuning.


Args:
outputs (Dict[str, Tensor]): the output from the model forward
Returns: the loss(Dict[str, Tensor]):
Returns (Dict[str, Any]): The loss dict
""" """
pass pass

+ 8
- 14
modelscope/models/base/base_model.py View File

@@ -2,7 +2,7 @@
import os import os
import os.path as osp import os.path as osp
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union


from modelscope.hub.snapshot_download import snapshot_download from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models.builder import build_model from modelscope.models.builder import build_model
@@ -10,8 +10,6 @@ from modelscope.utils.checkpoint import save_pretrained
from modelscope.utils.config import Config from modelscope.utils.config import Config
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile
from modelscope.utils.device import device_placement, verify_device from modelscope.utils.device import device_placement, verify_device
from modelscope.utils.file_utils import func_receive_dict_inputs
from modelscope.utils.hub import parse_label_mapping
from modelscope.utils.logger import get_logger from modelscope.utils.logger import get_logger


logger = get_logger() logger = get_logger()
@@ -27,35 +25,31 @@ class Model(ABC):
verify_device(device_name) verify_device(device_name)
self._device_name = device_name self._device_name = device_name


def __call__(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
return self.postprocess(self.forward(input))
def __call__(self, *args, **kwargs) -> Dict[str, Any]:
return self.postprocess(self.forward(*args, **kwargs))


@abstractmethod @abstractmethod
def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
def forward(self, *args, **kwargs) -> Dict[str, Any]:
""" """
Run the forward pass for a model. Run the forward pass for a model.


Args:
input (Dict[str, Tensor]): the dict of the model inputs for the forward method

Returns: Returns:
Dict[str, Tensor]: output from the model forward pass
Dict[str, Any]: output from the model forward pass
""" """
pass pass


def postprocess(self, input: Dict[str, Tensor],
**kwargs) -> Dict[str, Tensor]:
def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
""" Model specific postprocess and convert model output to """ Model specific postprocess and convert model output to
standard model outputs. standard model outputs.


Args: Args:
input: input data
inputs: input data


Return: Return:
dict of results: a dict containing outputs of model, each dict of results: a dict containing outputs of model, each
output should have the standard output name. output should have the standard output name.
""" """
return input
return inputs


@classmethod @classmethod
def _instantiate(cls, **kwargs): def _instantiate(cls, **kwargs):


+ 3
- 5
modelscope/models/base/base_torch_head.py View File

@@ -1,5 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Dict
from typing import Any, Dict


import torch import torch


@@ -18,10 +18,8 @@ class TorchHead(Head, torch.nn.Module):
super().__init__(**kwargs) super().__init__(**kwargs)
torch.nn.Module.__init__(self) torch.nn.Module.__init__(self)


def forward(self, inputs: Dict[str,
torch.Tensor]) -> Dict[str, torch.Tensor]:
def forward(self, *args, **kwargs) -> Dict[str, Any]:
raise NotImplementedError raise NotImplementedError


def compute_loss(self, outputs: Dict[str, torch.Tensor],
labels) -> Dict[str, torch.Tensor]:
def compute_loss(self, *args, **kwargs) -> Dict[str, Any]:
raise NotImplementedError raise NotImplementedError

+ 6
- 7
modelscope/models/base/base_torch_model.py View File

@@ -1,6 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.


from typing import Any, Dict, Optional, Union
from typing import Any, Dict


import torch import torch
from torch import nn from torch import nn
@@ -21,15 +21,14 @@ class TorchModel(Model, torch.nn.Module):
super().__init__(model_dir, *args, **kwargs) super().__init__(model_dir, *args, **kwargs)
torch.nn.Module.__init__(self) torch.nn.Module.__init__(self)


def __call__(self, input: Dict[str,
torch.Tensor]) -> Dict[str, torch.Tensor]:
def __call__(self, *args, **kwargs) -> Dict[str, Any]:
# Adapting a model with only one dict arg, and the arg name must be input or inputs
if func_receive_dict_inputs(self.forward): if func_receive_dict_inputs(self.forward):
return self.postprocess(self.forward(input))
return self.postprocess(self.forward(args[0], **kwargs))
else: else:
return self.postprocess(self.forward(**input))
return self.postprocess(self.forward(*args, **kwargs))


def forward(self, inputs: Dict[str,
torch.Tensor]) -> Dict[str, torch.Tensor]:
def forward(self, *args, **kwargs) -> Dict[str, Any]:
raise NotImplementedError raise NotImplementedError


def post_init(self): def post_init(self):


Loading…
Cancel
Save