| @@ -1,6 +1,6 @@ | |||
| repos: | |||
| - repo: https://gitlab.com/pycqa/flake8.git | |||
| rev: 3.8.3 | |||
| rev: 4.0.0 | |||
| hooks: | |||
| - id: flake8 | |||
| exclude: thirdparty/|examples/ | |||
| @@ -1,6 +1,6 @@ | |||
| repos: | |||
| - repo: /home/admin/pre-commit/flake8 | |||
| rev: 3.8.3 | |||
| rev: 4.0.0 | |||
| hooks: | |||
| - id: flake8 | |||
| exclude: thirdparty/|examples/ | |||
| @@ -390,11 +390,13 @@ class HubApi: | |||
| return resp['Data'] | |||
| def list_oss_dataset_objects(self, dataset_name, namespace, max_limit, | |||
| is_recursive, is_filter_dir, revision, | |||
| cookies): | |||
| is_recursive, is_filter_dir, revision): | |||
| url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss/tree/?' \ | |||
| f'MaxLimit={max_limit}&Revision={revision}&Recursive={is_recursive}&FilterDir={is_filter_dir}' | |||
| cookies = requests.utils.dict_from_cookiejar(cookies) | |||
| cookies = ModelScopeConfig.get_cookies() | |||
| if cookies: | |||
| cookies = requests.utils.dict_from_cookiejar(cookies) | |||
| resp = requests.get(url=url, cookies=cookies) | |||
| resp = resp.json() | |||
| @@ -9,7 +9,9 @@ class Models(object): | |||
| Model name should only contain model info but not task info. | |||
| """ | |||
| # tinynas models | |||
| tinynas_detection = 'tinynas-detection' | |||
| tinynas_damoyolo = 'tinynas-damoyolo' | |||
| # vision models | |||
| detection = 'detection' | |||
| @@ -7,10 +7,12 @@ from modelscope.utils.import_utils import LazyImportModule | |||
| if TYPE_CHECKING: | |||
| from .tinynas_detector import Tinynas_detector | |||
| from .tinynas_damoyolo import DamoYolo | |||
| else: | |||
| _import_structure = { | |||
| 'tinynas_detector': ['TinynasDetector'], | |||
| 'tinynas_damoyolo': ['DamoYolo'], | |||
| } | |||
| import sys | |||
| @@ -4,6 +4,7 @@ | |||
| import torch | |||
| import torch.nn as nn | |||
| from modelscope.utils.file_utils import read_file | |||
| from ..core.base_ops import Focus, SPPBottleneck, get_activation | |||
| from ..core.repvgg_block import RepVggBlock | |||
| @@ -49,12 +50,16 @@ class ResConvK1KX(nn.Module): | |||
| kernel_size, | |||
| stride, | |||
| force_resproj=False, | |||
| act='silu'): | |||
| act='silu', | |||
| reparam=False): | |||
| super(ResConvK1KX, self).__init__() | |||
| self.stride = stride | |||
| self.conv1 = ConvKXBN(in_c, btn_c, 1, 1) | |||
| self.conv2 = RepVggBlock( | |||
| btn_c, out_c, kernel_size, stride, act='identity') | |||
| if not reparam: | |||
| self.conv2 = ConvKXBN(btn_c, out_c, 3, stride) | |||
| else: | |||
| self.conv2 = RepVggBlock( | |||
| btn_c, out_c, kernel_size, stride, act='identity') | |||
| if act is None: | |||
| self.activation_function = torch.relu | |||
| @@ -97,7 +102,8 @@ class SuperResConvK1KX(nn.Module): | |||
| stride, | |||
| num_blocks, | |||
| with_spp=False, | |||
| act='silu'): | |||
| act='silu', | |||
| reparam=False): | |||
| super(SuperResConvK1KX, self).__init__() | |||
| if act is None: | |||
| self.act = torch.relu | |||
| @@ -124,7 +130,8 @@ class SuperResConvK1KX(nn.Module): | |||
| this_kernel_size, | |||
| this_stride, | |||
| force_resproj, | |||
| act=act) | |||
| act=act, | |||
| reparam=reparam) | |||
| self.block_list.append(the_block) | |||
| if block_id == 0 and with_spp: | |||
| self.block_list.append( | |||
| @@ -248,7 +255,8 @@ class TinyNAS(nn.Module): | |||
| with_spp=False, | |||
| use_focus=False, | |||
| need_conv1=True, | |||
| act='silu'): | |||
| act='silu', | |||
| reparam=False): | |||
| super(TinyNAS, self).__init__() | |||
| assert len(out_indices) == len(out_channels) | |||
| self.out_indices = out_indices | |||
| @@ -281,7 +289,8 @@ class TinyNAS(nn.Module): | |||
| block_info['s'], | |||
| block_info['L'], | |||
| spp, | |||
| act=act) | |||
| act=act, | |||
| reparam=reparam) | |||
| self.block_list.append(the_block) | |||
| elif the_block_class == 'SuperResConvKXKX': | |||
| spp = with_spp if idx == len(structure_info) - 1 else False | |||
| @@ -325,8 +334,8 @@ class TinyNAS(nn.Module): | |||
| def load_tinynas_net(backbone_cfg): | |||
| # load masternet model to path | |||
| import ast | |||
| struct_str = ''.join([x.strip() for x in backbone_cfg.net_structure_str]) | |||
| net_structure_str = read_file(backbone_cfg.structure_file) | |||
| struct_str = ''.join([x.strip() for x in net_structure_str]) | |||
| struct_info = ast.literal_eval(struct_str) | |||
| for layer in struct_info: | |||
| if 'nbitsA' in layer: | |||
| @@ -342,6 +351,6 @@ def load_tinynas_net(backbone_cfg): | |||
| use_focus=backbone_cfg.use_focus, | |||
| act=backbone_cfg.act, | |||
| need_conv1=backbone_cfg.need_conv1, | |||
| ) | |||
| reparam=backbone_cfg.reparam) | |||
| return model | |||
| @@ -30,7 +30,7 @@ class SingleStageDetector(TorchModel): | |||
| """ | |||
| super().__init__(model_dir, *args, **kwargs) | |||
| config_path = osp.join(model_dir, 'airdet_s.py') | |||
| config_path = osp.join(model_dir, self.config_name) | |||
| config = parse_config(config_path) | |||
| self.cfg = config | |||
| model_path = osp.join(model_dir, config.model.name) | |||
| @@ -41,6 +41,9 @@ class SingleStageDetector(TorchModel): | |||
| self.conf_thre = config.model.head.nms_conf_thre | |||
| self.nms_thre = config.model.head.nms_iou_thre | |||
| if self.cfg.model.backbone.name == 'TinyNAS': | |||
| self.cfg.model.backbone.structure_file = osp.join( | |||
| model_dir, self.cfg.model.backbone.structure_file) | |||
| self.backbone = build_backbone(self.cfg.model.backbone) | |||
| self.neck = build_neck(self.cfg.model.neck) | |||
| self.head = build_head(self.cfg.model.head) | |||
| @@ -124,11 +124,13 @@ class GFocalHead_Tiny(nn.Module): | |||
| simOTA_iou_weight=3.0, | |||
| octbase=8, | |||
| simlqe=False, | |||
| use_lqe=True, | |||
| **kwargs): | |||
| self.simlqe = simlqe | |||
| self.num_classes = num_classes | |||
| self.in_channels = in_channels | |||
| self.strides = strides | |||
| self.use_lqe = use_lqe | |||
| self.feat_channels = feat_channels if isinstance(feat_channels, list) \ | |||
| else [feat_channels] * len(self.strides) | |||
| @@ -181,15 +183,20 @@ class GFocalHead_Tiny(nn.Module): | |||
| groups=self.conv_groups, | |||
| norm=self.norm, | |||
| act=self.act)) | |||
| if not self.simlqe: | |||
| conf_vector = [nn.Conv2d(4 * self.total_dim, self.reg_channels, 1)] | |||
| if self.use_lqe: | |||
| if not self.simlqe: | |||
| conf_vector = [ | |||
| nn.Conv2d(4 * self.total_dim, self.reg_channels, 1) | |||
| ] | |||
| else: | |||
| conf_vector = [ | |||
| nn.Conv2d(4 * (self.reg_max + 1), self.reg_channels, 1) | |||
| ] | |||
| conf_vector += [self.relu] | |||
| conf_vector += [nn.Conv2d(self.reg_channels, 1, 1), nn.Sigmoid()] | |||
| reg_conf = nn.Sequential(*conf_vector) | |||
| else: | |||
| conf_vector = [ | |||
| nn.Conv2d(4 * (self.reg_max + 1), self.reg_channels, 1) | |||
| ] | |||
| conf_vector += [self.relu] | |||
| conf_vector += [nn.Conv2d(self.reg_channels, 1, 1), nn.Sigmoid()] | |||
| reg_conf = nn.Sequential(*conf_vector) | |||
| reg_conf = None | |||
| return cls_convs, reg_convs, reg_conf | |||
| @@ -290,21 +297,27 @@ class GFocalHead_Tiny(nn.Module): | |||
| N, C, H, W = bbox_pred.size() | |||
| prob = F.softmax( | |||
| bbox_pred.reshape(N, 4, self.reg_max + 1, H, W), dim=2) | |||
| if not self.simlqe: | |||
| prob_topk, _ = prob.topk(self.reg_topk, dim=2) | |||
| if self.add_mean: | |||
| stat = torch.cat( | |||
| [prob_topk, prob_topk.mean(dim=2, keepdim=True)], dim=2) | |||
| if self.use_lqe: | |||
| if not self.simlqe: | |||
| prob_topk, _ = prob.topk(self.reg_topk, dim=2) | |||
| if self.add_mean: | |||
| stat = torch.cat( | |||
| [prob_topk, | |||
| prob_topk.mean(dim=2, keepdim=True)], | |||
| dim=2) | |||
| else: | |||
| stat = prob_topk | |||
| quality_score = reg_conf( | |||
| stat.reshape(N, 4 * self.total_dim, H, W)) | |||
| else: | |||
| stat = prob_topk | |||
| quality_score = reg_conf( | |||
| bbox_pred.reshape(N, 4 * (self.reg_max + 1), H, W)) | |||
| quality_score = reg_conf(stat.reshape(N, 4 * self.total_dim, H, W)) | |||
| cls_score = gfl_cls(cls_feat).sigmoid() * quality_score | |||
| else: | |||
| quality_score = reg_conf( | |||
| bbox_pred.reshape(N, 4 * (self.reg_max + 1), H, W)) | |||
| cls_score = gfl_cls(cls_feat).sigmoid() * quality_score | |||
| cls_score = gfl_cls(cls_feat).sigmoid() | |||
| flatten_cls_score = cls_score.flatten(start_dim=2).transpose(1, 2) | |||
| flatten_bbox_pred = bbox_pred.flatten(start_dim=2).transpose(1, 2) | |||
| @@ -14,7 +14,6 @@ class GiraffeNeckV2(nn.Module): | |||
| self, | |||
| depth=1.0, | |||
| width=1.0, | |||
| in_features=[2, 3, 4], | |||
| in_channels=[256, 512, 1024], | |||
| out_channels=[256, 512, 1024], | |||
| depthwise=False, | |||
| @@ -24,7 +23,6 @@ class GiraffeNeckV2(nn.Module): | |||
| block_name='BasicBlock', | |||
| ): | |||
| super().__init__() | |||
| self.in_features = in_features | |||
| self.in_channels = in_channels | |||
| Conv = DWConv if depthwise else BaseConv | |||
| @@ -169,8 +167,7 @@ class GiraffeNeckV2(nn.Module): | |||
| """ | |||
| # backbone | |||
| features = [out_features[f] for f in self.in_features] | |||
| [x2, x1, x0] = features | |||
| [x2, x1, x0] = out_features | |||
| # node x3 | |||
| x13 = self.bu_conv13(x1) | |||
| @@ -0,0 +1,15 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from modelscope.metainfo import Models | |||
| from modelscope.models.builder import MODELS | |||
| from modelscope.utils.constant import Tasks | |||
| from .detector import SingleStageDetector | |||
| @MODELS.register_module( | |||
| Tasks.image_object_detection, module_name=Models.tinynas_damoyolo) | |||
| class DamoYolo(SingleStageDetector): | |||
| def __init__(self, model_dir, *args, **kwargs): | |||
| self.config_name = 'damoyolo_s.py' | |||
| super(DamoYolo, self).__init__(model_dir, *args, **kwargs) | |||
| @@ -12,5 +12,5 @@ from .detector import SingleStageDetector | |||
| class TinynasDetector(SingleStageDetector): | |||
| def __init__(self, model_dir, *args, **kwargs): | |||
| self.config_name = 'airdet_s.py' | |||
| super(TinynasDetector, self).__init__(model_dir, *args, **kwargs) | |||
| @@ -7,7 +7,7 @@ from typing import Any, Mapping, Optional, Sequence, Union | |||
| from datasets.builder import DatasetBuilder | |||
| from modelscope.hub.api import HubApi | |||
| from modelscope.utils.constant import DEFAULT_DATASET_REVISION, DownloadParams | |||
| from modelscope.utils.constant import DEFAULT_DATASET_REVISION | |||
| from modelscope.utils.logger import get_logger | |||
| from .dataset_builder import MsCsvDatasetBuilder, TaskSpecificDatasetBuilder | |||
| @@ -95,15 +95,13 @@ def list_dataset_objects(hub_api: HubApi, max_limit: int, is_recursive: bool, | |||
| res (list): List of objects, i.e., ['train/images/001.png', 'train/images/002.png', 'val/images/001.png', ...] | |||
| """ | |||
| res = [] | |||
| cookies = hub_api.check_cookies_upload_data(use_cookies=True) | |||
| objects = hub_api.list_oss_dataset_objects( | |||
| dataset_name=dataset_name, | |||
| namespace=namespace, | |||
| max_limit=max_limit, | |||
| is_recursive=is_recursive, | |||
| is_filter_dir=True, | |||
| revision=version, | |||
| cookies=cookies) | |||
| revision=version) | |||
| for item in objects: | |||
| object_key = item.get('Key') | |||
| @@ -174,7 +172,7 @@ def get_dataset_files(subset_split_into: dict, | |||
| modelscope_api = HubApi() | |||
| objects = list_dataset_objects( | |||
| hub_api=modelscope_api, | |||
| max_limit=DownloadParams.MAX_LIST_OBJECTS_NUM.value, | |||
| max_limit=-1, | |||
| is_recursive=True, | |||
| dataset_name=dataset_name, | |||
| namespace=namespace, | |||
| @@ -12,6 +12,8 @@ from modelscope.pipelines.base import Input, Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import LoadImage | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.cv.image_utils import \ | |||
| show_image_object_detection_auto_result | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| @@ -52,10 +54,18 @@ class TinynasDetectionPipeline(Pipeline): | |||
| bboxes, scores, labels = self.model.postprocess(inputs['data']) | |||
| if bboxes is None: | |||
| return None | |||
| outputs = { | |||
| OutputKeys.SCORES: scores, | |||
| OutputKeys.LABELS: labels, | |||
| OutputKeys.BOXES: bboxes | |||
| } | |||
| outputs = { | |||
| OutputKeys.SCORES: [], | |||
| OutputKeys.LABELS: [], | |||
| OutputKeys.BOXES: [] | |||
| } | |||
| else: | |||
| outputs = { | |||
| OutputKeys.SCORES: scores, | |||
| OutputKeys.LABELS: labels, | |||
| OutputKeys.BOXES: bboxes | |||
| } | |||
| return outputs | |||
| def show_result(self, img_path, result, save_path=None): | |||
| show_image_object_detection_auto_result(img_path, result, save_path) | |||
| @@ -1,5 +1,10 @@ | |||
| import math | |||
| import os | |||
| import random | |||
| import uuid | |||
| from os.path import exists | |||
| from tempfile import TemporaryDirectory | |||
| from urllib.parse import urlparse | |||
| import numpy as np | |||
| import torch | |||
| @@ -9,6 +14,7 @@ import torchvision.transforms._transforms_video as transforms | |||
| from decord import VideoReader | |||
| from torchvision.transforms import Compose | |||
| from modelscope.hub.file_download import http_get_file | |||
| from modelscope.metainfo import Preprocessors | |||
| from modelscope.utils.constant import Fields, ModeKeys | |||
| from modelscope.utils.type_assert import type_assert | |||
| @@ -30,7 +36,22 @@ def ReadVideoData(cfg, | |||
| Returns: | |||
| data (Tensor): the normalized video clips for model inputs | |||
| """ | |||
| data = _decode_video(cfg, video_path, num_temporal_views_override) | |||
| url_parsed = urlparse(video_path) | |||
| if url_parsed.scheme in ('file', '') and exists( | |||
| url_parsed.path): # Possibly a local file | |||
| data = _decode_video(cfg, video_path, num_temporal_views_override) | |||
| else: | |||
| with TemporaryDirectory() as temporary_cache_dir: | |||
| random_str = uuid.uuid4().hex | |||
| http_get_file( | |||
| url=video_path, | |||
| local_dir=temporary_cache_dir, | |||
| file_name=random_str, | |||
| cookies=None) | |||
| temp_file_path = os.path.join(temporary_cache_dir, random_str) | |||
| data = _decode_video(cfg, temp_file_path, | |||
| num_temporal_views_override) | |||
| if num_spatial_crops_override is not None: | |||
| num_spatial_crops = num_spatial_crops_override | |||
| transform = kinetics400_tranform(cfg, num_spatial_crops_override) | |||
| @@ -231,13 +231,6 @@ class DownloadMode(enum.Enum): | |||
| FORCE_REDOWNLOAD = 'force_redownload' | |||
| class DownloadParams(enum.Enum): | |||
| """ | |||
| Parameters for downloading dataset. | |||
| """ | |||
| MAX_LIST_OBJECTS_NUM = 50000 | |||
| class DatasetFormations(enum.Enum): | |||
| """ How a dataset is organized and interpreted | |||
| """ | |||
| @@ -61,8 +61,8 @@ def device_placement(framework, device_name='gpu:0'): | |||
| if framework == Frameworks.tf: | |||
| import tensorflow as tf | |||
| if device_type == Devices.gpu and not tf.test.is_gpu_available(): | |||
| logger.warning( | |||
| 'tensorflow cuda is not available, using cpu instead.') | |||
| logger.debug( | |||
| 'tensorflow: cuda is not available, using cpu instead.') | |||
| device_type = Devices.cpu | |||
| if device_type == Devices.cpu: | |||
| with tf.device('/CPU:0'): | |||
| @@ -78,7 +78,8 @@ def device_placement(framework, device_name='gpu:0'): | |||
| if torch.cuda.is_available(): | |||
| torch.cuda.set_device(f'cuda:{device_id}') | |||
| else: | |||
| logger.warning('cuda is not available, using cpu instead.') | |||
| logger.debug( | |||
| 'pytorch: cuda is not available, using cpu instead.') | |||
| yield | |||
| else: | |||
| yield | |||
| @@ -96,9 +97,7 @@ def create_device(device_name): | |||
| if device_type == Devices.gpu: | |||
| use_cuda = True | |||
| if not torch.cuda.is_available(): | |||
| logger.warning( | |||
| 'cuda is not available, create gpu device failed, using cpu instead.' | |||
| ) | |||
| logger.info('cuda is not available, using cpu instead.') | |||
| use_cuda = False | |||
| if use_cuda: | |||
| @@ -1,6 +1,7 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import inspect | |||
| import os | |||
| from pathlib import Path | |||
| @@ -35,3 +36,10 @@ def get_default_cache_dir(): | |||
| """ | |||
| default_cache_dir = Path.home().joinpath('.cache', 'modelscope') | |||
| return default_cache_dir | |||
| def read_file(path): | |||
| with open(path, 'r') as f: | |||
| text = f.read() | |||
| return text | |||
| @@ -176,7 +176,7 @@ def build_from_cfg(cfg, | |||
| raise TypeError('default_args must be a dict or None, ' | |||
| f'but got {type(default_args)}') | |||
| # dynamic load installation reqruiements for this module | |||
| # dynamic load installation requirements for this module | |||
| from modelscope.utils.import_utils import LazyImportModule | |||
| sig = (registry.name.upper(), group_key, cfg['type']) | |||
| LazyImportModule.import_module(sig) | |||
| @@ -193,8 +193,11 @@ def build_from_cfg(cfg, | |||
| if isinstance(obj_type, str): | |||
| obj_cls = registry.get(obj_type, group_key=group_key) | |||
| if obj_cls is None: | |||
| raise KeyError(f'{obj_type} is not in the {registry.name}' | |||
| f' registry group {group_key}') | |||
| raise KeyError( | |||
| f'{obj_type} is not in the {registry.name}' | |||
| f' registry group {group_key}. Please make' | |||
| f' sure the correct version of 1qqQModelScope library is used.' | |||
| ) | |||
| obj_cls.group_key = group_key | |||
| elif inspect.isclass(obj_type) or inspect.isfunction(obj_type): | |||
| obj_cls = obj_type | |||
| @@ -4,22 +4,45 @@ import unittest | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.demo_utils import DemoCompatibilityCheck | |||
| from modelscope.utils.test_utils import test_level | |||
| class TinynasObjectDetectionTest(unittest.TestCase): | |||
| class TinynasObjectDetectionTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| def setUp(self) -> None: | |||
| self.task = Tasks.image_object_detection | |||
| self.model_id = 'damo/cv_tinynas_object-detection_damoyolo' | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run(self): | |||
| def test_run_airdet(self): | |||
| tinynas_object_detection = pipeline( | |||
| Tasks.image_object_detection, model='damo/cv_tinynas_detection') | |||
| result = tinynas_object_detection( | |||
| 'data/test/images/image_detection.jpg') | |||
| print(result) | |||
| @unittest.skip('will be enabled after damoyolo officially released') | |||
| def test_run_damoyolo(self): | |||
| tinynas_object_detection = pipeline( | |||
| Tasks.image_object_detection, | |||
| model='damo/cv_tinynas_object-detection_damoyolo') | |||
| result = tinynas_object_detection( | |||
| 'data/test/images/image_detection.jpg') | |||
| print(result) | |||
| @unittest.skip('demo compatibility test is only enabled on a needed-basis') | |||
| def test_demo_compatibility(self): | |||
| self.test_demo() | |||
| self.compatibility_check() | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_image_object_detection_auto_pipeline(self): | |||
| test_image = 'data/test/images/image_detection.jpg' | |||
| tinynas_object_detection = pipeline( | |||
| Tasks.image_object_detection, model='damo/cv_tinynas_detection') | |||
| result = tinynas_object_detection(test_image) | |||
| tinynas_object_detection.show_result(test_image, result, | |||
| 'demo_ret.jpg') | |||
| if __name__ == '__main__': | |||