Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9673243 * refactor inputs format of model forward (master)
| @@ -1,6 +1,6 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from abc import ABC, abstractmethod | |||
| from typing import Dict, Union | |||
| from typing import Any, Dict, Union | |||
| from modelscope.models.base.base_model import Model | |||
| from modelscope.utils.config import ConfigDict | |||
| @@ -22,25 +22,20 @@ class Head(ABC): | |||
| self.config = ConfigDict(kwargs) | |||
| @abstractmethod | |||
| def forward(self, input: Input) -> Dict[str, Tensor]: | |||
| def forward(self, *args, **kwargs) -> Dict[str, Any]: | |||
| """ | |||
| This method will use the output from backbone model to do any | |||
| downstream tasks | |||
| Args: | |||
| input: The tensor output or a model from backbone model | |||
| (text generation need a model as input) | |||
| Returns: The output from downstream task | |||
| downstream tasks. Receive the output from backbone model. | |||
| Returns (Dict[str, Any]): The output from downstream task. | |||
| """ | |||
| pass | |||
| @abstractmethod | |||
| def compute_loss(self, outputs: Dict[str, Tensor], | |||
| labels) -> Dict[str, Tensor]: | |||
| def compute_loss(self, *args, **kwargs) -> Dict[str, Any]: | |||
| """ | |||
| compute loss for head during the finetuning | |||
| compute loss for head during the finetuning. | |||
| Args: | |||
| outputs (Dict[str, Tensor]): the output from the model forward | |||
| Returns: the loss(Dict[str, Tensor]): | |||
| Returns (Dict[str, Any]): The loss dict | |||
| """ | |||
| pass | |||
| @@ -2,7 +2,7 @@ | |||
| import os | |||
| import os.path as osp | |||
| from abc import ABC, abstractmethod | |||
| from typing import Callable, Dict, List, Optional, Union | |||
| from typing import Any, Callable, Dict, List, Optional, Union | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.models.builder import build_model | |||
| @@ -10,8 +10,6 @@ from modelscope.utils.checkpoint import save_pretrained | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile | |||
| from modelscope.utils.device import device_placement, verify_device | |||
| from modelscope.utils.file_utils import func_receive_dict_inputs | |||
| from modelscope.utils.hub import parse_label_mapping | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| @@ -27,35 +25,31 @@ class Model(ABC): | |||
| verify_device(device_name) | |||
| self._device_name = device_name | |||
| def __call__(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||
| return self.postprocess(self.forward(input)) | |||
| def __call__(self, *args, **kwargs) -> Dict[str, Any]: | |||
| return self.postprocess(self.forward(*args, **kwargs)) | |||
| @abstractmethod | |||
| def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||
| def forward(self, *args, **kwargs) -> Dict[str, Any]: | |||
| """ | |||
| Run the forward pass for a model. | |||
| Args: | |||
| input (Dict[str, Tensor]): the dict of the model inputs for the forward method | |||
| Returns: | |||
| Dict[str, Tensor]: output from the model forward pass | |||
| Dict[str, Any]: output from the model forward pass | |||
| """ | |||
| pass | |||
| def postprocess(self, input: Dict[str, Tensor], | |||
| **kwargs) -> Dict[str, Tensor]: | |||
| def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]: | |||
| """ Model specific postprocess and convert model output to | |||
| standard model outputs. | |||
| Args: | |||
| input: input data | |||
| inputs: input data | |||
| Return: | |||
| dict of results: a dict containing outputs of model, each | |||
| output should have the standard output name. | |||
| """ | |||
| return input | |||
| return inputs | |||
| @classmethod | |||
| def _instantiate(cls, **kwargs): | |||
| @@ -1,5 +1,5 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import Dict | |||
| from typing import Any, Dict | |||
| import torch | |||
| @@ -18,10 +18,8 @@ class TorchHead(Head, torch.nn.Module): | |||
| super().__init__(**kwargs) | |||
| torch.nn.Module.__init__(self) | |||
| def forward(self, inputs: Dict[str, | |||
| torch.Tensor]) -> Dict[str, torch.Tensor]: | |||
| def forward(self, *args, **kwargs) -> Dict[str, Any]: | |||
| raise NotImplementedError | |||
| def compute_loss(self, outputs: Dict[str, torch.Tensor], | |||
| labels) -> Dict[str, torch.Tensor]: | |||
| def compute_loss(self, *args, **kwargs) -> Dict[str, Any]: | |||
| raise NotImplementedError | |||
| @@ -1,6 +1,6 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import Any, Dict, Optional, Union | |||
| from typing import Any, Dict | |||
| import torch | |||
| from torch import nn | |||
| @@ -21,15 +21,14 @@ class TorchModel(Model, torch.nn.Module): | |||
| super().__init__(model_dir, *args, **kwargs) | |||
| torch.nn.Module.__init__(self) | |||
| def __call__(self, input: Dict[str, | |||
| torch.Tensor]) -> Dict[str, torch.Tensor]: | |||
| def __call__(self, *args, **kwargs) -> Dict[str, Any]: | |||
| # Adapt a model whose forward takes only one dict arg; the arg name must be `input` or `inputs` | |||
| if func_receive_dict_inputs(self.forward): | |||
| return self.postprocess(self.forward(input)) | |||
| return self.postprocess(self.forward(args[0], **kwargs)) | |||
| else: | |||
| return self.postprocess(self.forward(**input)) | |||
| return self.postprocess(self.forward(*args, **kwargs)) | |||
| def forward(self, inputs: Dict[str, | |||
| torch.Tensor]) -> Dict[str, torch.Tensor]: | |||
| def forward(self, *args, **kwargs) -> Dict[str, Any]: | |||
| raise NotImplementedError | |||
| def post_init(self): | |||