Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9673243 * refactor inputs format of model forwardmaster
| @@ -1,6 +1,6 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| from abc import ABC, abstractmethod | from abc import ABC, abstractmethod | ||||
| from typing import Dict, Union | |||||
| from typing import Any, Dict, Union | |||||
| from modelscope.models.base.base_model import Model | from modelscope.models.base.base_model import Model | ||||
| from modelscope.utils.config import ConfigDict | from modelscope.utils.config import ConfigDict | ||||
| @@ -22,25 +22,20 @@ class Head(ABC): | |||||
| self.config = ConfigDict(kwargs) | self.config = ConfigDict(kwargs) | ||||
| @abstractmethod | @abstractmethod | ||||
| def forward(self, input: Input) -> Dict[str, Tensor]: | |||||
| def forward(self, *args, **kwargs) -> Dict[str, Any]: | |||||
| """ | """ | ||||
| This method will use the output from backbone model to do any | This method will use the output from backbone model to do any | ||||
| downstream tasks | |||||
| Args: | |||||
| input: The tensor output or a model from backbone model | |||||
| (text generation need a model as input) | |||||
| Returns: The output from downstream taks | |||||
| downstream tasks. Recieve The output from backbone model. | |||||
| Returns (Dict[str, Any]): The output from downstream task. | |||||
| """ | """ | ||||
| pass | pass | ||||
| @abstractmethod | @abstractmethod | ||||
| def compute_loss(self, outputs: Dict[str, Tensor], | |||||
| labels) -> Dict[str, Tensor]: | |||||
| def compute_loss(self, *args, **kwargs) -> Dict[str, Any]: | |||||
| """ | """ | ||||
| compute loss for head during the finetuning | |||||
| compute loss for head during the finetuning. | |||||
| Args: | |||||
| outputs (Dict[str, Tensor]): the output from the model forward | |||||
| Returns: the loss(Dict[str, Tensor]): | |||||
| Returns (Dict[str, Any]): The loss dict | |||||
| """ | """ | ||||
| pass | pass | ||||
| @@ -2,7 +2,7 @@ | |||||
| import os | import os | ||||
| import os.path as osp | import os.path as osp | ||||
| from abc import ABC, abstractmethod | from abc import ABC, abstractmethod | ||||
| from typing import Callable, Dict, List, Optional, Union | |||||
| from typing import Any, Callable, Dict, List, Optional, Union | |||||
| from modelscope.hub.snapshot_download import snapshot_download | from modelscope.hub.snapshot_download import snapshot_download | ||||
| from modelscope.models.builder import build_model | from modelscope.models.builder import build_model | ||||
| @@ -10,8 +10,6 @@ from modelscope.utils.checkpoint import save_pretrained | |||||
| from modelscope.utils.config import Config | from modelscope.utils.config import Config | ||||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile | from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile | ||||
| from modelscope.utils.device import device_placement, verify_device | from modelscope.utils.device import device_placement, verify_device | ||||
| from modelscope.utils.file_utils import func_receive_dict_inputs | |||||
| from modelscope.utils.hub import parse_label_mapping | |||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| logger = get_logger() | logger = get_logger() | ||||
| @@ -27,35 +25,31 @@ class Model(ABC): | |||||
| verify_device(device_name) | verify_device(device_name) | ||||
| self._device_name = device_name | self._device_name = device_name | ||||
| def __call__(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||||
| return self.postprocess(self.forward(input)) | |||||
| def __call__(self, *args, **kwargs) -> Dict[str, Any]: | |||||
| return self.postprocess(self.forward(*args, **kwargs)) | |||||
| @abstractmethod | @abstractmethod | ||||
| def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||||
| def forward(self, *args, **kwargs) -> Dict[str, Any]: | |||||
| """ | """ | ||||
| Run the forward pass for a model. | Run the forward pass for a model. | ||||
| Args: | |||||
| input (Dict[str, Tensor]): the dict of the model inputs for the forward method | |||||
| Returns: | Returns: | ||||
| Dict[str, Tensor]: output from the model forward pass | |||||
| Dict[str, Any]: output from the model forward pass | |||||
| """ | """ | ||||
| pass | pass | ||||
| def postprocess(self, input: Dict[str, Tensor], | |||||
| **kwargs) -> Dict[str, Tensor]: | |||||
| def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]: | |||||
| """ Model specific postprocess and convert model output to | """ Model specific postprocess and convert model output to | ||||
| standard model outputs. | standard model outputs. | ||||
| Args: | Args: | ||||
| input: input data | |||||
| inputs: input data | |||||
| Return: | Return: | ||||
| dict of results: a dict containing outputs of model, each | dict of results: a dict containing outputs of model, each | ||||
| output should have the standard output name. | output should have the standard output name. | ||||
| """ | """ | ||||
| return input | |||||
| return inputs | |||||
| @classmethod | @classmethod | ||||
| def _instantiate(cls, **kwargs): | def _instantiate(cls, **kwargs): | ||||
| @@ -1,5 +1,5 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| from typing import Dict | |||||
| from typing import Any, Dict | |||||
| import torch | import torch | ||||
| @@ -18,10 +18,8 @@ class TorchHead(Head, torch.nn.Module): | |||||
| super().__init__(**kwargs) | super().__init__(**kwargs) | ||||
| torch.nn.Module.__init__(self) | torch.nn.Module.__init__(self) | ||||
| def forward(self, inputs: Dict[str, | |||||
| torch.Tensor]) -> Dict[str, torch.Tensor]: | |||||
| def forward(self, *args, **kwargs) -> Dict[str, Any]: | |||||
| raise NotImplementedError | raise NotImplementedError | ||||
| def compute_loss(self, outputs: Dict[str, torch.Tensor], | |||||
| labels) -> Dict[str, torch.Tensor]: | |||||
| def compute_loss(self, *args, **kwargs) -> Dict[str, Any]: | |||||
| raise NotImplementedError | raise NotImplementedError | ||||
| @@ -1,6 +1,6 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| from typing import Any, Dict, Optional, Union | |||||
| from typing import Any, Dict | |||||
| import torch | import torch | ||||
| from torch import nn | from torch import nn | ||||
| @@ -21,15 +21,14 @@ class TorchModel(Model, torch.nn.Module): | |||||
| super().__init__(model_dir, *args, **kwargs) | super().__init__(model_dir, *args, **kwargs) | ||||
| torch.nn.Module.__init__(self) | torch.nn.Module.__init__(self) | ||||
| def __call__(self, input: Dict[str, | |||||
| torch.Tensor]) -> Dict[str, torch.Tensor]: | |||||
| def __call__(self, *args, **kwargs) -> Dict[str, Any]: | |||||
| # Adapting a model with only one dict arg, and the arg name must be input or inputs | |||||
| if func_receive_dict_inputs(self.forward): | if func_receive_dict_inputs(self.forward): | ||||
| return self.postprocess(self.forward(input)) | |||||
| return self.postprocess(self.forward(args[0], **kwargs)) | |||||
| else: | else: | ||||
| return self.postprocess(self.forward(**input)) | |||||
| return self.postprocess(self.forward(*args, **kwargs)) | |||||
| def forward(self, inputs: Dict[str, | |||||
| torch.Tensor]) -> Dict[str, torch.Tensor]: | |||||
| def forward(self, *args, **kwargs) -> Dict[str, Any]: | |||||
| raise NotImplementedError | raise NotImplementedError | ||||
| def post_init(self): | def post_init(self): | ||||