
merge with master

master · 智丞 · 3 years ago · parent commit 1bcdc05fb8
14 changed files with 956 additions and 12 deletions
 1. +1   -0   modelscope/metainfo.py
 2. +3   -0   modelscope/models/cv/cmdssl_video_embedding/__init__.py
 3. +121 -0   modelscope/models/cv/cmdssl_video_embedding/c3d.py
 4. +339 -0   modelscope/models/cv/cmdssl_video_embedding/resnet2p1d.py
 5. +284 -0   modelscope/models/cv/cmdssl_video_embedding/resnet3d.py
 6. +5   -0   modelscope/pipelines/builder.py
 7. +1   -0   modelscope/pipelines/cv/__init__.py
 8. +4   -5   modelscope/pipelines/cv/action_recognition_pipeline.py
 9. +157 -0   modelscope/pipelines/cv/cmdssl_video_embedding_pipleline.py
10. +2   -2   modelscope/pipelines/multi_modal/text_to_image_synthesis_pipeline.py
11. +7   -0   modelscope/pipelines/outputs.py
12. +1   -0   modelscope/utils/constant.py
13. +1   -5   tests/pipelines/test_action_recognition.py
14. +30  -0   tests/pipelines/test_cmdssl_video_embedding.py

+1 -0  modelscope/metainfo.py

@@ -49,6 +49,7 @@ class Pipelines(object):
     ocr_detection = 'resnet18-ocr-detection'
     action_recognition = 'TAdaConv_action-recognition'
     animal_recognation = 'resnet101-animal_recog'
+    cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding'
 
     # nlp tasks
     sentence_similarity = 'sentence-similarity'


+3 -0  modelscope/models/cv/cmdssl_video_embedding/__init__.py

@@ -0,0 +1,3 @@
from .c3d import C3D
from .resnet2p1d import resnet26_2p1d
from .resnet3d import resnet26_3d

+121 -0  modelscope/models/cv/cmdssl_video_embedding/c3d.py

@@ -0,0 +1,121 @@
import torch
import torch.nn as nn


def conv3x3x3(in_planes, out_planes, stride=1):
    return nn.Conv3d(
        in_planes, out_planes, kernel_size=3, stride=stride, padding=1)


class C3D(nn.Module):

    def __init__(self,
                 num_classes=1000,
                 dropout=0.5,
                 inplanes=3,
                 norm_layer=None,
                 last_pool=True):
        super(C3D, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm3d
        if not last_pool and num_classes is not None:
            raise ValueError('num_classes should be None when last_pool=False')

        self.conv1 = conv3x3x3(inplanes, 64)
        self.bn1 = norm_layer(64)
        self.relu1 = nn.ReLU(inplace=True)
        self.pool1 = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))

        self.conv2 = conv3x3x3(64, 128)
        self.bn2 = norm_layer(128)
        self.relu2 = nn.ReLU(inplace=True)
        self.pool2 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))

        self.conv3a = conv3x3x3(128, 256)
        self.bn3a = norm_layer(256)
        self.relu3a = nn.ReLU(inplace=True)

        self.conv3b = conv3x3x3(256, 256)
        self.bn3b = norm_layer(256)
        self.relu3b = nn.ReLU(inplace=True)
        self.pool3 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))

        self.conv4a = conv3x3x3(256, 512)
        self.bn4a = norm_layer(512)
        self.relu4a = nn.ReLU(inplace=True)

        self.conv4b = conv3x3x3(512, 512)
        self.bn4b = norm_layer(512)
        self.relu4b = nn.ReLU(inplace=True)
        self.pool4 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))

        self.conv5a = conv3x3x3(512, 512)
        self.bn5a = norm_layer(512)
        self.relu5a = nn.ReLU(inplace=True)

        self.conv5b = conv3x3x3(512, 512)
        self.bn5b = norm_layer(512)
        self.relu5b = nn.ReLU(inplace=True)
        self.pool5 = nn.AdaptiveAvgPool3d((1, 1, 1)) if last_pool else None

        if num_classes is None:
            self.dropout = None
            self.fc = None
        else:
            self.dropout = nn.Dropout(dropout)
            self.fc = nn.Linear(512, num_classes)
        self.out_planes = 512

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.pool1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.pool2(x)

        x = self.conv3a(x)
        x = self.bn3a(x)
        x = self.relu3a(x)

        x = self.conv3b(x)
        x = self.bn3b(x)
        x = self.relu3b(x)
        x = self.pool3(x)

        x = self.conv4a(x)
        x = self.bn4a(x)
        x = self.relu4a(x)

        x = self.conv4b(x)
        x = self.bn4b(x)
        x = self.relu4b(x)
        x = self.pool4(x)

        x = self.conv5a(x)
        x = self.bn5a(x)
        x = self.relu5a(x)

        x = self.conv5b(x)
        x = self.bn5b(x)
        x = self.relu5b(x)

        if self.pool5:
            x = self.pool5(x)
            x = torch.flatten(x, 1)
        if self.dropout and self.fc:
            x = self.dropout(x)
            x = self.fc(x)

        return x

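A quick way to sanity-check the new backbone (not part of the commit; a minimal sketch assuming only the modules added above):

    import torch

    from modelscope.models.cv.cmdssl_video_embedding import C3D

    # num_classes=None switches C3D to feature-extractor mode: dropout/fc
    # are skipped and forward() returns the pooled 512-d feature per clip.
    model = C3D(num_classes=None, last_pool=True).eval()
    clip = torch.randn(2, 3, 16, 112, 112)  # (batch, channels, frames, H, W)
    with torch.no_grad():
        feat = model(clip)
    print(feat.shape)  # torch.Size([2, 512])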
+339 -0  modelscope/models/cv/cmdssl_video_embedding/resnet2p1d.py

@@ -0,0 +1,339 @@
import torch
import torch.nn as nn


def conv1x3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    return nn.Conv3d(
        in_planes,
        out_planes,
        kernel_size=(1, 3, 3),
        stride=(1, stride, stride),
        padding=(0, dilation, dilation),
        groups=groups,
        bias=False,
        dilation=(1, dilation, dilation))


def conv3x1x1(in_planes, out_planes, stride=1, groups=1, dilation=1):
    return nn.Conv3d(
        in_planes,
        out_planes,
        kernel_size=(3, 1, 1),
        stride=(stride, 1, 1),
        padding=(dilation, 0, 0),
        groups=groups,
        bias=False,
        dilation=(dilation, 1, 1))


def conv1x1x1(in_planes, out_planes, stride=1):
    return nn.Conv3d(
        in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1,
                 norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm3d
        if groups != 1 or base_width != 64:
            raise ValueError(
                'BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError(
                'Dilation > 1 not supported in BasicBlock')

        midplanes1 = (inplanes * planes * 3 * 3 * 3) // (
            inplanes * 3 * 3 + planes * 3)
        self.conv1_s = conv1x3x3(inplanes, midplanes1, stride)
        self.bn1_s = norm_layer(midplanes1)
        self.conv1_t = conv3x1x1(midplanes1, planes, stride)
        self.bn1_t = norm_layer(planes)

        midplanes2 = (planes * planes * 3 * 3 * 3) // (
            planes * 3 * 3 + planes * 3)
        self.conv2_s = conv1x3x3(planes, midplanes2)
        self.bn2_s = norm_layer(midplanes2)
        self.conv2_t = conv3x1x1(midplanes2, planes)
        self.bn2_t = norm_layer(planes)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1_s(x)
        out = self.bn1_s(out)
        out = self.relu(out)
        out = self.conv1_t(out)
        out = self.bn1_t(out)
        out = self.relu(out)

        out = self.conv2_s(out)
        out = self.bn2_s(out)
        out = self.relu(out)
        out = self.conv2_t(out)
        out = self.bn2_t(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1,
                 norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm3d
        width = int(planes * (base_width / 64.)) * groups

        self.conv1 = conv1x1x1(inplanes, width)
        self.bn1 = norm_layer(width)

        midplanes = (width * width * 3 * 3 * 3) // (width * 3 * 3 + width * 3)
        self.conv2_s = conv1x3x3(width, midplanes, stride, groups, dilation)
        self.bn2_s = norm_layer(midplanes)
        self.conv2_t = conv3x1x1(midplanes, width, stride, groups, dilation)
        self.bn2_t = norm_layer(width)

        self.conv3 = conv1x1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2_s(out)
        out = self.bn2_s(out)
        out = self.relu(out)
        out = self.conv2_t(out)
        out = self.bn2_t(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet2p1d(nn.Module):

    def __init__(self,
                 block,
                 layers,
                 num_classes=None,
                 zero_init_residual=True,
                 groups=1,
                 width_per_group=64,
                 replace_stride_with_dilation=None,
                 dropout=0.5,
                 inplanes=3,
                 first_stride=2,
                 norm_layer=None,
                 last_pool=True):
        super(ResNet2p1d, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm3d
        if not last_pool and num_classes is not None:
            raise ValueError('num_classes should be None when last_pool=False')
        self._norm_layer = norm_layer
        self.first_stride = first_stride

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError('replace_stride_with_dilation should be None '
                             'or a 3-element tuple, got {}'.format(
                                 replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group

        midplanes = (3 * self.inplanes * 3 * 7 * 7) // (3 * 7 * 7
                                                        + self.inplanes * 3)
        self.conv1_s = nn.Conv3d(
            inplanes,
            midplanes,
            kernel_size=(1, 7, 7),
            stride=(1, first_stride, first_stride),
            padding=(0, 3, 3),
            bias=False)
        self.bn1_s = norm_layer(midplanes)
        self.conv1_t = nn.Conv3d(
            midplanes,
            self.inplanes,
            kernel_size=(3, 1, 1),
            stride=(1, 1, 1),
            padding=(1, 0, 0),
            bias=False)
        self.bn1_t = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(
            kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(
            block,
            128,
            layers[1],
            stride=2,
            dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(
            block,
            256,
            layers[2],
            stride=2,
            dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(
            block,
            512,
            layers[3],
            stride=2,
            dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1)) if last_pool else None
        if num_classes is None:
            self.dropout = None
            self.fc = None
        else:
            self.dropout = nn.Dropout(dropout)
            self.fc = nn.Linear(512 * block.expansion, num_classes)
        self.out_planes = 512 * block.expansion

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2_t.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion))

        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, self.groups,
                  self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1_s(x)
        x = self.bn1_s(x)
        x = self.relu(x)
        x = self.conv1_t(x)
        x = self.bn1_t(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.avgpool:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
        if self.dropout and self.fc:
            x = self.dropout(x)
            x = self.fc(x)

        return x


def resnet10_2p1d(**kwargs):
    return ResNet2p1d(BasicBlock, [1, 1, 1, 1], **kwargs)


def resnet18_2p1d(**kwargs):
    return ResNet2p1d(BasicBlock, [2, 2, 2, 2], **kwargs)


def resnet26_2p1d(**kwargs):
    return ResNet2p1d(Bottleneck, [2, 2, 2, 2], **kwargs)


def resnet34_2p1d(**kwargs):
    return ResNet2p1d(BasicBlock, [3, 4, 6, 3], **kwargs)


def resnet50_2p1d(**kwargs):
    return ResNet2p1d(Bottleneck, [3, 4, 6, 3], **kwargs)


def resnet101_2p1d(**kwargs):
    return ResNet2p1d(Bottleneck, [3, 4, 23, 3], **kwargs)


def resnet152_2p1d(**kwargs):
    return ResNet2p1d(Bottleneck, [3, 8, 36, 3], **kwargs)


def resnet200_2p1d(**kwargs):
    return ResNet2p1d(Bottleneck, [3, 24, 36, 3], **kwargs)

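The midplanes expressions above follow the R(2+1)D rule: pick the hidden width so that the (1x3x3) + (3x1x1) pair matches the parameter count of the single 3x3x3 convolution it replaces. A small illustration of the arithmetic (not part of the commit):

    # Parameter-matching rule used by the blocks above (illustrative only).
    inplanes, planes = 64, 64
    full_3d = inplanes * planes * 3 * 3 * 3               # one 3x3x3 conv
    midplanes = full_3d // (inplanes * 3 * 3 + planes * 3)
    factorized = inplanes * midplanes * 3 * 3 + midplanes * planes * 3
    print(midplanes, full_3d, factorized)                 # 144 110592 110592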
+284 -0  modelscope/models/cv/cmdssl_video_embedding/resnet3d.py

@@ -0,0 +1,284 @@
import torch
import torch.nn as nn


def conv3x3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    return nn.Conv3d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation)


def conv1x1x1(in_planes, out_planes, stride=1):
    return nn.Conv3d(
        in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1,
                 norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm3d
        if groups != 1 or base_width != 64:
            raise ValueError(
                'BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError(
                'Dilation > 1 not supported in BasicBlock')
        self.conv1 = conv3x3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1,
                 norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm3d
        width = int(planes * (base_width / 64.)) * groups
        self.conv1 = conv1x1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet3d(nn.Module):

    def __init__(self,
                 block,
                 layers,
                 num_classes=1000,
                 zero_init_residual=True,
                 groups=1,
                 width_per_group=64,
                 replace_stride_with_dilation=None,
                 dropout=0.5,
                 inplanes=3,
                 first_stride=2,
                 norm_layer=None,
                 last_pool=True):
        super(ResNet3d, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm3d
        if not last_pool and num_classes is not None:
            raise ValueError('num_classes should be None when last_pool=False')
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError('replace_stride_with_dilation should be None '
                             'or a 3-element tuple, got {}'.format(
                                 replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv3d(
            inplanes,
            self.inplanes,
            kernel_size=(3, 7, 7),
            stride=(1, first_stride, first_stride),
            padding=(1, 3, 3),
            bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(
            kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(
            block,
            128,
            layers[1],
            stride=2,
            dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(
            block,
            256,
            layers[2],
            stride=2,
            dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(
            block,
            512,
            layers[3],
            stride=2,
            dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1)) if last_pool else None
        if num_classes is None:
            self.dropout = None
            self.fc = None
        else:
            self.dropout = nn.Dropout(dropout)
            self.fc = nn.Linear(512 * block.expansion, num_classes)
        self.out_planes = 512 * block.expansion

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion))

        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, self.groups,
                  self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.avgpool:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
        if self.dropout and self.fc:
            x = self.dropout(x)
            x = self.fc(x)

        return x


def resnet10_3d(**kwargs):
    return ResNet3d(BasicBlock, [1, 1, 1, 1], **kwargs)


def resnet18_3d(**kwargs):
    return ResNet3d(BasicBlock, [2, 2, 2, 2], **kwargs)


def resnet26_3d(**kwargs):
    return ResNet3d(Bottleneck, [2, 2, 2, 2], **kwargs)


def resnet34_3d(**kwargs):
    return ResNet3d(BasicBlock, [3, 4, 6, 3], **kwargs)


def resnet50_3d(**kwargs):
    return ResNet3d(Bottleneck, [3, 4, 6, 3], **kwargs)


def resnet101_3d(**kwargs):
    return ResNet3d(Bottleneck, [3, 4, 23, 3], **kwargs)


def resnet152_3d(**kwargs):
    return ResNet3d(Bottleneck, [3, 8, 36, 3], **kwargs)


def resnet200_3d(**kwargs):
    return ResNet3d(Bottleneck, [3, 24, 36, 3], **kwargs)

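As with the 2+1D variant, the factory functions at the end select depth and block type; resnet26_3d stacks Bottleneck blocks [2, 2, 2, 2] (expansion 4), so the pooled feature is 512 * 4 = 2048-d. A hypothetical smoke test, not part of the commit:

    import torch

    from modelscope.models.cv.cmdssl_video_embedding import resnet26_3d

    # Feature-extractor mode: the fc head is disabled with num_classes=None.
    model = resnet26_3d(num_classes=None, last_pool=True).eval()
    clip = torch.randn(1, 3, 16, 112, 112)
    with torch.no_grad():
        feat = model(clip)
    print(feat.shape)  # torch.Size([1, 2048])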
+5 -0  modelscope/pipelines/builder.py

@@ -59,6 +59,11 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     Tasks.visual_question_answering:
     (Pipelines.visual_question_answering,
      'damo/mplug_visual-question-answering_coco_large_en'),
+    Tasks.video_embedding: (Pipelines.cmdssl_video_embedding,
+                            'damo/cv_r2p1d_video_embedding'),
     Tasks.text_to_image_synthesis:
     (Pipelines.text_to_image_synthesis,
      'damo/cv_imagen_text-to-image-synthesis_tiny')
 }

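With this default registered, the task name alone resolves to the r2p1d model; a hypothetical one-liner (not in the diff):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # Falls back to 'damo/cv_r2p1d_video_embedding' via DEFAULT_MODEL_FOR_PIPELINE.
    video_embedding = pipeline(Tasks.video_embedding)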



+1 -0  modelscope/pipelines/cv/__init__.py

@@ -1,6 +1,7 @@
 try:
     from .action_recognition_pipeline import ActionRecognitionPipeline
     from .animal_recog_pipeline import AnimalRecogPipeline
+    from .cmdssl_video_embedding_pipleline import CMDSSLVideoEmbeddingPipeline
 except ModuleNotFoundError as e:
     if str(e) == "No module named 'torch'":
         pass


+4 -5  modelscope/pipelines/cv/action_recognition_pipeline.py

@@ -2,9 +2,6 @@ import math
 import os.path as osp
 from typing import Any, Dict
 
-import cv2
-import numpy as np
-import PIL
 import torch
 
 from modelscope.metainfo import Pipelines
@@ -32,7 +29,9 @@ class ActionRecognitionPipeline(Pipeline):
         config_path = osp.join(self.model, ModelFile.CONFIGURATION)
         logger.info(f'loading config from {config_path}')
         self.cfg = Config.from_file(config_path)
-        self.infer_model = BaseVideoModel(cfg=self.cfg).cuda()
+        self.device = torch.device(
+            'cuda' if torch.cuda.is_available() else 'cpu')
+        self.infer_model = BaseVideoModel(cfg=self.cfg).to(self.device)
         self.infer_model.eval()
         self.infer_model.load_state_dict(torch.load(model_path)['model_state'])
         self.label_mapping = self.cfg.label_mapping
@@ -40,7 +39,7 @@ class ActionRecognitionPipeline(Pipeline):
 
     def preprocess(self, input: Input) -> Dict[str, Any]:
         if isinstance(input, str):
-            video_input_data = ReadVideoData(self.cfg, input).cuda()
+            video_input_data = ReadVideoData(self.cfg, input).to(self.device)
         else:
             raise TypeError(f'input should be a str,'
                             f' but got {type(input)}')


+157 -0  modelscope/pipelines/cv/cmdssl_video_embedding_pipleline.py

@@ -0,0 +1,157 @@
import math
import os.path as osp
from typing import Any, Dict

import cv2
import decord
import numpy as np
import PIL
import torch
import torchvision.transforms.functional as TF
from decord import VideoReader, cpu
from PIL import Image

from modelscope.metainfo import Pipelines
from modelscope.models.cv.cmdssl_video_embedding.resnet2p1d import \
    resnet26_2p1d
from modelscope.pipelines.base import Input
from modelscope.pipelines.outputs import OutputKeys
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from ..base import Pipeline
from ..builder import PIPELINES

logger = get_logger()


@PIPELINES.register_module(
    Tasks.video_embedding, module_name=Pipelines.cmdssl_video_embedding)
class CMDSSLVideoEmbeddingPipeline(Pipeline):

    def __init__(self, model: str):
        super().__init__(model=model)
        model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
        logger.info(f'loading model from {model_path}')
        config_path = osp.join(self.model, ModelFile.CONFIGURATION)
        logger.info(f'loading config from {config_path}')
        self.cfg = Config.from_file(config_path)
        self.model = resnet26_2p1d(num_classes=None, last_pool=True)

        if torch.cuda.is_available():
            self._device = torch.device('cuda')
        else:
            self._device = torch.device('cpu')
        self.model = self.model.to(self._device).eval().requires_grad_(False)
        self.model.load_state_dict(torch.load(model_path))
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        decord.bridge.set_bridge('native')

        transforms = VCompose([
            VRescale(size=self.cfg.DATA.scale_size),
            VCenterCrop(size=self.cfg.DATA.crop_size),
            VToTensor(),
            VNormalize(mean=self.cfg.DATA.mean, std=self.cfg.DATA.std)
        ])

        clip_len = (self.cfg.DATA.video_frames
                    - 1) * self.cfg.DATA.video_stride + 1
        vr = VideoReader(input, ctx=cpu(0))
        if len(vr) <= clip_len:
            init_frames = np.zeros(self.cfg.DATA.multi_crop, dtype=int)
        else:
            init_frames = np.linspace(0,
                                      len(vr) - clip_len,
                                      self.cfg.DATA.multi_crop + 1)
            init_frames = ((init_frames[1:] + init_frames[:-1])
                           / 2.).astype(int)

        indices = np.arange(0, clip_len, self.cfg.DATA.video_stride)
        indices = (init_frames[:, None] + indices[None, :]).reshape(-1)
        indices[indices >= len(vr)] = 0

        frames = torch.from_numpy(vr.get_batch(indices).asnumpy()).chunk(
            self.cfg.DATA.multi_crop, dim=0)
        frames = [
            transforms([Image.fromarray(f) for f in u.numpy()]) for u in frames
        ]
        frames = torch.stack(frames, dim=0)
        result = {'video_data': frames}
        return result

    @torch.no_grad()
    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        frames = input['video_data'].to(self._device)
        feature = self.model(frames)
        feature = feature.mean(0)
        return {OutputKeys.VIDEO_EMBEDDING: feature.data.cpu().numpy()}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs


class VCompose(object):

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, item):
        for t in self.transforms:
            item = t(item)
        return item


class VRescale(object):

    def __init__(self, size=128):
        self.size = size

    def __call__(self, vclip):
        w, h = vclip[0].size
        scale = self.size / min(w, h)
        out_w, out_h = int(round(w * scale)), int(round(h * scale))
        vclip = [u.resize((out_w, out_h), Image.BILINEAR) for u in vclip]
        return vclip


class VCenterCrop(object):

    def __init__(self, size=112):
        self.size = size

    def __call__(self, vclip):
        w, h = vclip[0].size
        assert min(w, h) >= self.size
        x1 = (w - self.size) // 2
        y1 = (h - self.size) // 2
        vclip = [
            u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in vclip
        ]
        return vclip


class VToTensor(object):

    def __call__(self, vclip):
        vclip = torch.stack([TF.to_tensor(u) for u in vclip], dim=1)
        return vclip


class VNormalize(object):

    def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
        self.mean = mean
        self.std = std

    def __call__(self, vclip):
        assert vclip.min() > -0.1 and vclip.max() < 1.1, \
            'vclip values should be in [0, 1]'
        vclip = vclip.clone()
        if not isinstance(self.mean, torch.Tensor):
            self.mean = vclip.new_tensor(self.mean).view(-1, 1, 1, 1)
        if not isinstance(self.std, torch.Tensor):
            self.std = vclip.new_tensor(self.std).view(-1, 1, 1, 1)
        vclip.sub_(self.mean).div_(self.std)
        return vclip

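To make the clip sampling in preprocess() concrete, here is the index arithmetic with assumed values (video_frames=16, video_stride=2, multi_crop=4 are illustrative, not taken from the model's configuration file):

    import numpy as np

    video_frames, video_stride, multi_crop, num_frames = 16, 2, 4, 300

    # A clip spans (16 - 1) * 2 + 1 = 31 source frames.
    clip_len = (video_frames - 1) * video_stride + 1
    # Clip start frames: midpoints of multi_crop even segments of the usable range.
    init_frames = np.linspace(0, num_frames - clip_len, multi_crop + 1)
    init_frames = ((init_frames[1:] + init_frames[:-1]) / 2.).astype(int)
    print(init_frames)  # [ 33 100 168 235]

    # 16 strided offsets per clip, tiled across the 4 start frames.
    indices = np.arange(0, clip_len, video_stride)
    indices = (init_frames[:, None] + indices[None, :]).reshape(-1)
    print(indices.shape)  # (64,)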
+2 -2  modelscope/pipelines/multi_modal/text_to_image_synthesis_pipeline.py

@@ -1,4 +1,4 @@
-from typing import Any, Dict, Union
+from typing import Any, Dict
 
 from modelscope.metainfo import Pipelines
 from modelscope.pipelines.base import Input
@@ -23,7 +23,7 @@ class TextToImageSynthesisPipeline(Pipeline):
             pipe_model = model
         else:
             raise NotImplementedError(
-                f'execpting a Model instance or str, but get {type(model)}.')
+                f'expecting a Model instance or str, but get {type(model)}.')
 
         super().__init__(model=pipe_model)


+7 -0  modelscope/pipelines/outputs.py

@@ -22,6 +22,7 @@ class OutputKeys(object):
     RESPONSE = 'response'
     PREDICTION = 'prediction'
     DIALOG_STATES = 'dialog_states'
+    VIDEO_EMBEDDING = 'video_embedding'
 
 
 TASK_OUTPUTS = {
@@ -91,6 +92,12 @@ TASK_OUTPUTS = {
     #   }
     Tasks.ocr_detection: [OutputKeys.POLYGONS],
 
+    # video embedding result for single video
+    # {
+    #   "video_embedding": np.array with shape [D],
+    # }
+    Tasks.video_embedding: [OutputKeys.VIDEO_EMBEDDING],
+
     # ============ nlp tasks ===================
 
     # text classification result for single sample

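Per the contract above, a consumer reads the embedding back by key; a hypothetical sketch (model id and video path taken from the tests in this commit):

    from modelscope.pipelines import pipeline
    from modelscope.pipelines.outputs import OutputKeys
    from modelscope.utils.constant import Tasks

    p = pipeline(Tasks.video_embedding, model='damo/cv_r2p1d_video_embedding')
    result = p('data/test/videos/action_recognition_test_video.mp4')
    embedding = result[OutputKeys.VIDEO_EMBEDDING]
    print(embedding.shape)  # (D,) -- 2048 for the resnet26_2p1d backbone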

+1 -0  modelscope/utils/constant.py

@@ -31,6 +31,7 @@ class Tasks(object):
     image_matting = 'image-matting'
     ocr_detection = 'ocr-detection'
     action_recognition = 'action-recognition'
+    video_embedding = 'video-embedding'
 
     # nlp tasks
     word_segmentation = 'word-segmentation'


+1 -5  tests/pipelines/test_action_recognition.py

@@ -1,14 +1,10 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 # !/usr/bin/env python
 import os.path as osp
 import shutil
 import tempfile
 import unittest
 
-import cv2
-
-from modelscope.fileio import File
-from modelscope.msdatasets import MsDataset
 from modelscope.pipelines import pipeline
 from modelscope.utils.constant import ModelFile, Tasks
 from modelscope.utils.test_utils import test_level
@@ -45,7 +41,7 @@
 
         print(f'recognition output: {result}.')
 
-    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_modelhub_default_model(self):
         recognition_pipeline = pipeline(Tasks.action_recognition)
         result = recognition_pipeline(


+30 -0  tests/pipelines/test_cmdssl_video_embedding.py

@@ -0,0 +1,30 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# !/usr/bin/env python
import os.path as osp
import shutil
import tempfile
import unittest

import cv2

from modelscope.fileio import File
from modelscope.msdatasets import MsDataset
from modelscope.pipelines import pipeline
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.test_utils import test_level


class CMDSSLVideoEmbeddingTest(unittest.TestCase):

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        videossl_pipeline = pipeline(
            Tasks.video_embedding, model='damo/cv_r2p1d_video_embedding')
        result = videossl_pipeline(
            'data/test/videos/action_recognition_test_video.mp4')

        print(f'video embedding output: {result}.')


if __name__ == '__main__':
    unittest.main()
