Browse Source

refactor inputs format of model forward

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9673243

    * refactor inputs format of model forward
master
jiangnana.jnn 3 years ago
parent
commit
652ec697b7
4 changed files with 25 additions and 39 deletions
  1. +8
    -13
      modelscope/models/base/base_head.py
  2. +8
    -14
      modelscope/models/base/base_model.py
  3. +3
    -5
      modelscope/models/base/base_torch_head.py
  4. +6
    -7
      modelscope/models/base/base_torch_model.py

+ 8
- 13
modelscope/models/base/base_head.py View File

@@ -1,6 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from abc import ABC, abstractmethod
from typing import Dict, Union
from typing import Any, Dict, Union

from modelscope.models.base.base_model import Model
from modelscope.utils.config import ConfigDict
@@ -22,25 +22,20 @@ class Head(ABC):
self.config = ConfigDict(kwargs)

@abstractmethod
def forward(self, input: Input) -> Dict[str, Tensor]:
def forward(self, *args, **kwargs) -> Dict[str, Any]:
"""
This method will use the output from backbone model to do any
downstream tasks
Args:
input: The tensor output or a model from backbone model
(text generation need a model as input)
Returns: The output from downstream tasks
downstream tasks. Receive the output from backbone model.

Returns (Dict[str, Any]): The output from downstream task.
"""
pass

@abstractmethod
def compute_loss(self, outputs: Dict[str, Tensor],
labels) -> Dict[str, Tensor]:
def compute_loss(self, *args, **kwargs) -> Dict[str, Any]:
"""
compute loss for head during the finetuning
Compute loss for the head during finetuning.

Args:
outputs (Dict[str, Tensor]): the output from the model forward
Returns: the loss(Dict[str, Tensor]):
Returns (Dict[str, Any]): The loss dict
"""
pass

+ 8
- 14
modelscope/models/base/base_model.py View File

@@ -2,7 +2,7 @@
import os
import os.path as osp
from abc import ABC, abstractmethod
from typing import Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models.builder import build_model
@@ -10,8 +10,6 @@ from modelscope.utils.checkpoint import save_pretrained
from modelscope.utils.config import Config
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile
from modelscope.utils.device import device_placement, verify_device
from modelscope.utils.file_utils import func_receive_dict_inputs
from modelscope.utils.hub import parse_label_mapping
from modelscope.utils.logger import get_logger

logger = get_logger()
@@ -27,35 +25,31 @@ class Model(ABC):
verify_device(device_name)
self._device_name = device_name

def __call__(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
return self.postprocess(self.forward(input))
def __call__(self, *args, **kwargs) -> Dict[str, Any]:
return self.postprocess(self.forward(*args, **kwargs))

@abstractmethod
def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
def forward(self, *args, **kwargs) -> Dict[str, Any]:
"""
Run the forward pass for a model.

Args:
input (Dict[str, Tensor]): the dict of the model inputs for the forward method

Returns:
Dict[str, Tensor]: output from the model forward pass
Dict[str, Any]: output from the model forward pass
"""
pass

def postprocess(self, input: Dict[str, Tensor],
**kwargs) -> Dict[str, Tensor]:
def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
""" Model specific postprocess and convert model output to
standard model outputs.

Args:
input: input data
inputs: input data

Return:
dict of results: a dict containing outputs of model, each
output should have the standard output name.
"""
return input
return inputs

@classmethod
def _instantiate(cls, **kwargs):


+ 3
- 5
modelscope/models/base/base_torch_head.py View File

@@ -1,5 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Dict
from typing import Any, Dict

import torch

@@ -18,10 +18,8 @@ class TorchHead(Head, torch.nn.Module):
super().__init__(**kwargs)
torch.nn.Module.__init__(self)

def forward(self, inputs: Dict[str,
torch.Tensor]) -> Dict[str, torch.Tensor]:
def forward(self, *args, **kwargs) -> Dict[str, Any]:
raise NotImplementedError

def compute_loss(self, outputs: Dict[str, torch.Tensor],
labels) -> Dict[str, torch.Tensor]:
def compute_loss(self, *args, **kwargs) -> Dict[str, Any]:
raise NotImplementedError

+ 6
- 7
modelscope/models/base/base_torch_model.py View File

@@ -1,6 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import Any, Dict, Optional, Union
from typing import Any, Dict

import torch
from torch import nn
@@ -21,15 +21,14 @@ class TorchModel(Model, torch.nn.Module):
super().__init__(model_dir, *args, **kwargs)
torch.nn.Module.__init__(self)

def __call__(self, input: Dict[str,
torch.Tensor]) -> Dict[str, torch.Tensor]:
def __call__(self, *args, **kwargs) -> Dict[str, Any]:
# Adapting a model with only one dict arg, and the arg name must be input or inputs
if func_receive_dict_inputs(self.forward):
return self.postprocess(self.forward(input))
return self.postprocess(self.forward(args[0], **kwargs))
else:
return self.postprocess(self.forward(**input))
return self.postprocess(self.forward(*args, **kwargs))

def forward(self, inputs: Dict[str,
torch.Tensor]) -> Dict[str, torch.Tensor]:
def forward(self, *args, **kwargs) -> Dict[str, Any]:
raise NotImplementedError

def post_init(self):


Loading…
Cancel
Save