@@ -49,6 +49,7 @@ class Pipelines(object):
    ocr_detection = 'resnet18-ocr-detection'
    action_recognition = 'TAdaConv_action-recognition'
    animal_recognation = 'resnet101-animal_recog'
    cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding'

    # nlp tasks
    sentence_similarity = 'sentence-similarity'
@@ -0,0 +1,3 @@
from .c3d import C3D
from .resnet2p1d import resnet26_2p1d
from .resnet3d import resnet26_3d
@@ -0,0 +1,121 @@
import torch
import torch.nn as nn


def conv3x3x3(in_planes, out_planes, stride=1):
    return nn.Conv3d(
        in_planes, out_planes, kernel_size=3, stride=stride, padding=1)
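

# C3D-style backbone: five stages of 3x3x3 convolutions with BatchNorm/ReLU
# and max pooling (spatial-only pooling in the first stage), followed by an
# optional global average pool and classifier head. With num_classes=None the
# classifier is omitted and the pooled 512-d feature (or the raw feature map
# when last_pool=False) is returned.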
class C3D(nn.Module):

    def __init__(self,
                 num_classes=1000,
                 dropout=0.5,
                 inplanes=3,
                 norm_layer=None,
                 last_pool=True):
        super(C3D, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm3d
        if not last_pool and num_classes is not None:
            raise ValueError('num_classes should be None when last_pool=False')
        self.conv1 = conv3x3x3(inplanes, 64)
        self.bn1 = norm_layer(64)
        self.relu1 = nn.ReLU(inplace=True)
        self.pool1 = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))

        self.conv2 = conv3x3x3(64, 128)
        self.bn2 = norm_layer(128)
        self.relu2 = nn.ReLU(inplace=True)
        self.pool2 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))

        self.conv3a = conv3x3x3(128, 256)
        self.bn3a = norm_layer(256)
        self.relu3a = nn.ReLU(inplace=True)
        self.conv3b = conv3x3x3(256, 256)
        self.bn3b = norm_layer(256)
        self.relu3b = nn.ReLU(inplace=True)
        self.pool3 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))

        self.conv4a = conv3x3x3(256, 512)
        self.bn4a = norm_layer(512)
        self.relu4a = nn.ReLU(inplace=True)
        self.conv4b = conv3x3x3(512, 512)
        self.bn4b = norm_layer(512)
        self.relu4b = nn.ReLU(inplace=True)
        self.pool4 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))

        self.conv5a = conv3x3x3(512, 512)
        self.bn5a = norm_layer(512)
        self.relu5a = nn.ReLU(inplace=True)
        self.conv5b = conv3x3x3(512, 512)
        self.bn5b = norm_layer(512)
        self.relu5b = nn.ReLU(inplace=True)
        self.pool5 = nn.AdaptiveAvgPool3d((1, 1, 1)) if last_pool else None

        if num_classes is None:
            self.dropout = None
            self.fc = None
        else:
            self.dropout = nn.Dropout(dropout)
            self.fc = nn.Linear(512, num_classes)
        self.out_planes = 512

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.pool1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.pool2(x)

        x = self.conv3a(x)
        x = self.bn3a(x)
        x = self.relu3a(x)
        x = self.conv3b(x)
        x = self.bn3b(x)
        x = self.relu3b(x)
        x = self.pool3(x)

        x = self.conv4a(x)
        x = self.bn4a(x)
        x = self.relu4a(x)
        x = self.conv4b(x)
        x = self.bn4b(x)
        x = self.relu4b(x)
        x = self.pool4(x)

        x = self.conv5a(x)
        x = self.bn5a(x)
        x = self.relu5a(x)
        x = self.conv5b(x)
        x = self.bn5b(x)
        x = self.relu5b(x)

        if self.pool5:
            x = self.pool5(x)
            x = torch.flatten(x, 1)
        if self.dropout and self.fc:
            x = self.dropout(x)
            x = self.fc(x)
        return x
@@ -0,0 +1,339 @@
import torch
import torch.nn as nn


def conv1x3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    return nn.Conv3d(
        in_planes,
        out_planes,
        kernel_size=(1, 3, 3),
        stride=(1, stride, stride),
        padding=(0, dilation, dilation),
        groups=groups,
        bias=False,
        dilation=(1, dilation, dilation))


def conv3x1x1(in_planes, out_planes, stride=1, groups=1, dilation=1):
    return nn.Conv3d(
        in_planes,
        out_planes,
        kernel_size=(3, 1, 1),
        stride=(stride, 1, 1),
        padding=(dilation, 0, 0),
        groups=groups,
        bias=False,
        dilation=(dilation, 1, 1))


def conv1x1x1(in_planes, out_planes, stride=1):
    return nn.Conv3d(
        in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
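

# R(2+1)D residual blocks: each 3x3x3 convolution of the ResNet block is
# factorized into a spatial 1x3x3 convolution followed by a temporal 3x1x1
# convolution, with BatchNorm/ReLU in between. The intermediate channel count
# (`midplanes`) is chosen so that the factorized pair has roughly the same
# number of parameters as the original 3D convolution, following the R(2+1)D
# formulation (Tran et al., "A Closer Look at Spatiotemporal Convolutions for
# Action Recognition", CVPR 2018).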
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1,
                 norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm3d
        if groups != 1 or base_width != 64:
            raise ValueError(
                'BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError(
                'Dilation > 1 not supported in BasicBlock')
        midplanes1 = (inplanes * planes * 3 * 3 * 3) // (
            inplanes * 3 * 3 + planes * 3)
        self.conv1_s = conv1x3x3(inplanes, midplanes1, stride)
        self.bn1_s = norm_layer(midplanes1)
        self.conv1_t = conv3x1x1(midplanes1, planes, stride)
        self.bn1_t = norm_layer(planes)
        midplanes2 = (planes * planes * 3 * 3 * 3) // (
            planes * 3 * 3 + planes * 3)
        self.conv2_s = conv1x3x3(planes, midplanes2)
        self.bn2_s = norm_layer(midplanes2)
        self.conv2_t = conv3x1x1(midplanes2, planes)
        self.bn2_t = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1_s(x)
        out = self.bn1_s(out)
        out = self.relu(out)
        out = self.conv1_t(out)
        out = self.bn1_t(out)
        out = self.relu(out)

        out = self.conv2_s(out)
        out = self.bn2_s(out)
        out = self.relu(out)
        out = self.conv2_t(out)
        out = self.bn2_t(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1,
                 norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm3d
        width = int(planes * (base_width / 64.)) * groups
        self.conv1 = conv1x1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        midplanes = (width * width * 3 * 3 * 3) // (width * 3 * 3 + width * 3)
        self.conv2_s = conv1x3x3(width, midplanes, stride, groups, dilation)
        self.bn2_s = norm_layer(midplanes)
        self.conv2_t = conv3x1x1(midplanes, width, stride, groups, dilation)
        self.bn2_t = norm_layer(width)
        self.conv3 = conv1x1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2_s(out)
        out = self.bn2_s(out)
        out = self.relu(out)
        out = self.conv2_t(out)
        out = self.bn2_t(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


class ResNet2p1d(nn.Module):

    def __init__(self,
                 block,
                 layers,
                 num_classes=None,
                 zero_init_residual=True,
                 groups=1,
                 width_per_group=64,
                 replace_stride_with_dilation=None,
                 dropout=0.5,
                 inplanes=3,
                 first_stride=2,
                 norm_layer=None,
                 last_pool=True):
        super(ResNet2p1d, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm3d
        if not last_pool and num_classes is not None:
            raise ValueError('num_classes should be None when last_pool=False')
        self._norm_layer = norm_layer
        self.first_stride = first_stride
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError('replace_stride_with_dilation should be None '
                             'or a 3-element tuple, got {}'.format(
                                 replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
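        # Decomposed (2+1)D stem: a 1x7x7 spatial convolution followed by a
        # 3x1x1 temporal convolution, with a parameter-matched intermediate
        # width (same rule as in the residual blocks above).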
        midplanes = (3 * self.inplanes * 3 * 7 * 7) // (3 * 7 * 7
                                                        + self.inplanes * 3)
        self.conv1_s = nn.Conv3d(
            inplanes,
            midplanes,
            kernel_size=(1, 7, 7),
            stride=(1, first_stride, first_stride),
            padding=(0, 3, 3),
            bias=False)
        self.bn1_s = norm_layer(midplanes)
        self.conv1_t = nn.Conv3d(
            midplanes,
            self.inplanes,
            kernel_size=(3, 1, 1),
            stride=(1, 1, 1),
            padding=(1, 0, 0),
            bias=False)
        self.bn1_t = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(
            kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(
            block,
            128,
            layers[1],
            stride=2,
            dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(
            block,
            256,
            layers[2],
            stride=2,
            dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(
            block,
            512,
            layers[3],
            stride=2,
            dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1)) if last_pool else None

        if num_classes is None:
            self.dropout = None
            self.fc = None
        else:
            self.dropout = nn.Dropout(dropout)
            self.fc = nn.Linear(512 * block.expansion, num_classes)
        self.out_planes = 512 * block.expansion

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2_t.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion))

        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, self.groups,
                  self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1_s(x)
        x = self.bn1_s(x)
        x = self.relu(x)
        x = self.conv1_t(x)
        x = self.bn1_t(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.avgpool:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
        if self.dropout and self.fc:
            x = self.dropout(x)
            x = self.fc(x)
        return x
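

# Depth variants following torchvision's ResNet naming. resnet26_2p1d
# (Bottleneck blocks, [2, 2, 2, 2]) is the backbone used by the CMDSSL
# video embedding pipeline below.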
def resnet10_2p1d(**kwargs):
    return ResNet2p1d(BasicBlock, [1, 1, 1, 1], **kwargs)


def resnet18_2p1d(**kwargs):
    return ResNet2p1d(BasicBlock, [2, 2, 2, 2], **kwargs)


def resnet26_2p1d(**kwargs):
    return ResNet2p1d(Bottleneck, [2, 2, 2, 2], **kwargs)


def resnet34_2p1d(**kwargs):
    return ResNet2p1d(BasicBlock, [3, 4, 6, 3], **kwargs)


def resnet50_2p1d(**kwargs):
    return ResNet2p1d(Bottleneck, [3, 4, 6, 3], **kwargs)


def resnet101_2p1d(**kwargs):
    return ResNet2p1d(Bottleneck, [3, 4, 23, 3], **kwargs)


def resnet152_2p1d(**kwargs):
    return ResNet2p1d(Bottleneck, [3, 8, 36, 3], **kwargs)


def resnet200_2p1d(**kwargs):
    return ResNet2p1d(Bottleneck, [3, 24, 36, 3], **kwargs)
@@ -0,0 +1,284 @@
import torch
import torch.nn as nn


def conv3x3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    return nn.Conv3d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation)


def conv1x1x1(in_planes, out_planes, stride=1):
    return nn.Conv3d(
        in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
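

# Plain 3D ResNet blocks: the standard torchvision BasicBlock/Bottleneck with
# every convolution extended to a full spatio-temporal (3x3x3 or 1x1x1) kernel.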
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1,
                 norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm3d
        if groups != 1 or base_width != 64:
            raise ValueError(
                'BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError(
                'Dilation > 1 not supported in BasicBlock')
        self.conv1 = conv3x3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1,
                 norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm3d
        width = int(planes * (base_width / 64.)) * groups
        self.conv1 = conv1x1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


class ResNet3d(nn.Module):

    def __init__(self,
                 block,
                 layers,
                 num_classes=1000,
                 zero_init_residual=True,
                 groups=1,
                 width_per_group=64,
                 replace_stride_with_dilation=None,
                 dropout=0.5,
                 inplanes=3,
                 first_stride=2,
                 norm_layer=None,
                 last_pool=True):
        super(ResNet3d, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm3d
        if not last_pool and num_classes is not None:
            raise ValueError('num_classes should be None when last_pool=False')
        self._norm_layer = norm_layer
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError('replace_stride_with_dilation should be None '
                             'or a 3-element tuple, got {}'.format(
                                 replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv3d(
            inplanes,
            self.inplanes,
            kernel_size=(3, 7, 7),
            stride=(1, first_stride, first_stride),
            padding=(1, 3, 3),
            bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(
            kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(
            block,
            128,
            layers[1],
            stride=2,
            dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(
            block,
            256,
            layers[2],
            stride=2,
            dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(
            block,
            512,
            layers[3],
            stride=2,
            dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1)) if last_pool else None

        if num_classes is None:
            self.dropout = None
            self.fc = None
        else:
            self.dropout = nn.Dropout(dropout)
            self.fc = nn.Linear(512 * block.expansion, num_classes)
        self.out_planes = 512 * block.expansion

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion))

        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, self.groups,
                  self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.avgpool:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
        if self.dropout and self.fc:
            x = self.dropout(x)
            x = self.fc(x)
        return x
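

# Depth variants following torchvision's ResNet naming; resnet26_3d is the
# variant re-exported from the package __init__.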
def resnet10_3d(**kwargs):
    return ResNet3d(BasicBlock, [1, 1, 1, 1], **kwargs)


def resnet18_3d(**kwargs):
    return ResNet3d(BasicBlock, [2, 2, 2, 2], **kwargs)


def resnet26_3d(**kwargs):
    return ResNet3d(Bottleneck, [2, 2, 2, 2], **kwargs)


def resnet34_3d(**kwargs):
    return ResNet3d(BasicBlock, [3, 4, 6, 3], **kwargs)


def resnet50_3d(**kwargs):
    return ResNet3d(Bottleneck, [3, 4, 6, 3], **kwargs)


def resnet101_3d(**kwargs):
    return ResNet3d(Bottleneck, [3, 4, 23, 3], **kwargs)


def resnet152_3d(**kwargs):
    return ResNet3d(Bottleneck, [3, 8, 36, 3], **kwargs)


def resnet200_3d(**kwargs):
    return ResNet3d(Bottleneck, [3, 24, 36, 3], **kwargs)
@@ -59,6 +59,11 @@ DEFAULT_MODEL_FOR_PIPELINE = {
    Tasks.visual_question_answering:
    (Pipelines.visual_question_answering,
     'damo/mplug_visual-question-answering_coco_large_en'),
    Tasks.video_embedding: (Pipelines.cmdssl_video_embedding,
                            'damo/cv_r2p1d_video_embedding'),
    Tasks.text_to_image_synthesis:
    (Pipelines.text_to_image_synthesis,
     'damo/cv_imagen_text-to-image-synthesis_tiny')
}
@@ -1,6 +1,7 @@
try:
    from .action_recognition_pipeline import ActionRecognitionPipeline
    from .animal_recog_pipeline import AnimalRecogPipeline
    from .cmdssl_video_embedding_pipleline import CMDSSLVideoEmbeddingPipeline
except ModuleNotFoundError as e:
    if str(e) == "No module named 'torch'":
        pass
@@ -2,9 +2,6 @@ import math
import os.path as osp
from typing import Any, Dict

import cv2
import numpy as np
import PIL
import torch

from modelscope.metainfo import Pipelines

@@ -32,7 +29,9 @@ class ActionRecognitionPipeline(Pipeline):
        config_path = osp.join(self.model, ModelFile.CONFIGURATION)
        logger.info(f'loading config from {config_path}')
        self.cfg = Config.from_file(config_path)
        self.infer_model = BaseVideoModel(cfg=self.cfg).cuda()
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.infer_model = BaseVideoModel(cfg=self.cfg).to(self.device)
        self.infer_model.eval()
        self.infer_model.load_state_dict(torch.load(model_path)['model_state'])
        self.label_mapping = self.cfg.label_mapping

@@ -40,7 +39,7 @@ class ActionRecognitionPipeline(Pipeline):
    def preprocess(self, input: Input) -> Dict[str, Any]:
        if isinstance(input, str):
            video_input_data = ReadVideoData(self.cfg, input).cuda()
            video_input_data = ReadVideoData(self.cfg, input).to(self.device)
        else:
            raise TypeError(f'input should be a str,'
                            f' but got {type(input)}')
@@ -0,0 +1,157 @@
import math
import os.path as osp
from typing import Any, Dict

import cv2
import decord
import numpy as np
import PIL
import torch
import torchvision.transforms.functional as TF
from decord import VideoReader, cpu
from PIL import Image

from modelscope.metainfo import Pipelines
from modelscope.models.cv.cmdssl_video_embedding.resnet2p1d import \
    resnet26_2p1d
from modelscope.pipelines.base import Input
from modelscope.pipelines.outputs import OutputKeys
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from ..base import Pipeline
from ..builder import PIPELINES

logger = get_logger()


@PIPELINES.register_module(
    Tasks.video_embedding, module_name=Pipelines.cmdssl_video_embedding)
class CMDSSLVideoEmbeddingPipeline(Pipeline):

    def __init__(self, model: str):
        super().__init__(model=model)
        model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
        logger.info(f'loading model from {model_path}')
        config_path = osp.join(self.model, ModelFile.CONFIGURATION)
        logger.info(f'loading config from {config_path}')
        self.cfg = Config.from_file(config_path)
        self.model = resnet26_2p1d(num_classes=None, last_pool=True)
        if torch.cuda.is_available():
            self._device = torch.device('cuda')
        else:
            self._device = torch.device('cpu')
        self.model = self.model.to(self._device).eval().requires_grad_(False)
        self.model.load_state_dict(torch.load(model_path))
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        decord.bridge.set_bridge('native')
        transforms = VCompose([
            VRescale(size=self.cfg.DATA.scale_size),
            VCenterCrop(size=self.cfg.DATA.crop_size),
            VToTensor(),
            VNormalize(mean=self.cfg.DATA.mean, std=self.cfg.DATA.std)
        ])
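        # Sample cfg.DATA.multi_crop clips whose start frames are evenly
        # spaced over the video; each clip takes video_frames frames at
        # video_stride spacing, and out-of-range indices fall back to frame 0.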
        clip_len = (self.cfg.DATA.video_frames
                    - 1) * self.cfg.DATA.video_stride + 1
        vr = VideoReader(input, ctx=cpu(0))
        if len(vr) <= clip_len:
            init_frames = np.zeros(self.cfg.DATA.multi_crop, dtype=int)
        else:
            init_frames = np.linspace(0,
                                      len(vr) - clip_len,
                                      self.cfg.DATA.multi_crop + 1)
            init_frames = ((init_frames[1:] + init_frames[:-1])
                           / 2.).astype(int)
        indices = np.arange(0, clip_len, self.cfg.DATA.video_stride)
        indices = (init_frames[:, None] + indices[None, :]).reshape(-1)
        indices[indices >= len(vr)] = 0
        frames = torch.from_numpy(vr.get_batch(indices).asnumpy()).chunk(
            self.cfg.DATA.multi_crop, dim=0)
        frames = [
            transforms([Image.fromarray(f) for f in u.numpy()]) for u in frames
        ]
        frames = torch.stack(frames, dim=0)
        result = {'video_data': frames}
        return result
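
    # Embeds each sampled clip with the R(2+1)D backbone and averages the
    # clip embeddings into one video-level vector (2048-d for resnet26_2p1d).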
    @torch.no_grad()
    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        frames = input['video_data'].to(self._device)
        feature = self.model(frames)
        feature = feature.mean(0)
        return {OutputKeys.VIDEO_EMBEDDING: feature.data.cpu().numpy()}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
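

# Minimal clip-level transforms, applied to a list of PIL frames; they mirror
# torchvision's Compose / Resize / CenterCrop / ToTensor / Normalize but keep
# the frame dimension, yielding a (C, T, H, W) tensor per clip.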
class VCompose(object):

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, item):
        for t in self.transforms:
            item = t(item)
        return item


class VRescale(object):

    def __init__(self, size=128):
        self.size = size

    def __call__(self, vclip):
        w, h = vclip[0].size
        scale = self.size / min(w, h)
        out_w, out_h = int(round(w * scale)), int(round(h * scale))
        vclip = [u.resize((out_w, out_h), Image.BILINEAR) for u in vclip]
        return vclip


class VCenterCrop(object):

    def __init__(self, size=112):
        self.size = size

    def __call__(self, vclip):
        w, h = vclip[0].size
        assert min(w, h) >= self.size
        x1 = (w - self.size) // 2
        y1 = (h - self.size) // 2
        vclip = [
            u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in vclip
        ]
        return vclip


class VToTensor(object):

    def __call__(self, vclip):
        vclip = torch.stack([TF.to_tensor(u) for u in vclip], dim=1)
        return vclip


class VNormalize(object):

    def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
        self.mean = mean
        self.std = std

    def __call__(self, vclip):
        assert vclip.min() > -0.1 and vclip.max() < 1.1, \
            'vclip values should be in [0, 1]'
        vclip = vclip.clone()
        if not isinstance(self.mean, torch.Tensor):
            self.mean = vclip.new_tensor(self.mean).view(-1, 1, 1, 1)
        if not isinstance(self.std, torch.Tensor):
            self.std = vclip.new_tensor(self.std).view(-1, 1, 1, 1)
        vclip.sub_(self.mean).div_(self.std)
        return vclip
@@ -1,4 +1,4 @@
from typing import Any, Dict, Union
from typing import Any, Dict

from modelscope.metainfo import Pipelines
from modelscope.pipelines.base import Input

@@ -23,7 +23,7 @@ class TextToImageSynthesisPipeline(Pipeline):
            pipe_model = model
        else:
            raise NotImplementedError(
                f'execpting a Model instance or str, but get {type(model)}.')
                f'expecting a Model instance or str, but got {type(model)}.')
        super().__init__(model=pipe_model)
@@ -22,6 +22,7 @@ class OutputKeys(object):
    RESPONSE = 'response'
    PREDICTION = 'prediction'
    DIALOG_STATES = 'dialog_states'
    VIDEO_EMBEDDING = 'video_embedding'


TASK_OUTPUTS = {

@@ -91,6 +92,12 @@ TASK_OUTPUTS = {
    # }
    Tasks.ocr_detection: [OutputKeys.POLYGONS],

    # video embedding result for single video
    # {
    #     "video_embedding": np.array with shape [D],
    # }
    Tasks.video_embedding: [OutputKeys.VIDEO_EMBEDDING],

    # ============ nlp tasks ===================

    # text classification result for single sample
@@ -31,6 +31,7 @@ class Tasks(object):
    image_matting = 'image-matting'
    ocr_detection = 'ocr-detection'
    action_recognition = 'action-recognition'
    video_embedding = 'video-embedding'

    # nlp tasks
    word_segmentation = 'word-segmentation'
@@ -1,14 +1,10 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# !/usr/bin/env python
import os.path as osp
import shutil
import tempfile
import unittest

import cv2

from modelscope.fileio import File
from modelscope.msdatasets import MsDataset
from modelscope.pipelines import pipeline
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.test_utils import test_level

@@ -45,7 +41,7 @@ class ActionRecognitionTest(unittest.TestCase):
        print(f'recognition output: {result}.')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_modelhub_default_model(self):
        recognition_pipeline = pipeline(Tasks.action_recognition)
        result = recognition_pipeline(
@@ -0,0 +1,30 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# !/usr/bin/env python
import os.path as osp
import shutil
import tempfile
import unittest

import cv2

from modelscope.fileio import File
from modelscope.msdatasets import MsDataset
from modelscope.pipelines import pipeline
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.test_utils import test_level


class CMDSSLVideoEmbeddingTest(unittest.TestCase):

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        videossl_pipeline = pipeline(
            Tasks.video_embedding, model='damo/cv_r2p1d_video_embedding')
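        # result[OutputKeys.VIDEO_EMBEDDING] holds a [D]-shaped np.ndarray
        # (D = 2048 for the resnet26_2p1d backbone behind this pipeline).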
        result = videossl_pipeline(
            'data/test/videos/action_recognition_test_video.mp4')
        print(f'video embedding output: {result}.')


if __name__ == '__main__':
    unittest.main()