
merge with master

Branch: master
智丞 committed 3 years ago
commit c2afb63b1e
38 changed files with 1954 additions and 68 deletions
  1. modelscope/metainfo.py (+3, -0)
  2. modelscope/models/__init__.py (+2, -1)
  3. modelscope/models/audio/aec/__init__.py (+0, -0)
  4. modelscope/models/audio/aec/layers/__init__.py (+0, -0)
  5. modelscope/models/audio/aec/layers/activations.py (+0, -0)
  6. modelscope/models/audio/aec/layers/affine_transform.py (+0, -0)
  7. modelscope/models/audio/aec/layers/deep_fsmn.py (+0, -0)
  8. modelscope/models/audio/aec/layers/layer_base.py (+0, -0)
  9. modelscope/models/audio/aec/layers/uni_deep_fsmn.py (+0, -0)
  10. modelscope/models/audio/aec/network/__init__.py (+0, -0)
  11. modelscope/models/audio/aec/network/loss.py (+0, -0)
  12. modelscope/models/audio/aec/network/modulation_loss.py (+0, -0)
  13. modelscope/models/audio/aec/network/se_net.py (+0, -0)
  14. modelscope/models/audio/ans/__init__.py (+0, -0)
  15. modelscope/models/audio/ans/complex_nn.py (+248, -0)
  16. modelscope/models/audio/ans/conv_stft.py (+112, -0)
  17. modelscope/models/audio/ans/frcrn.py (+309, -0)
  18. modelscope/models/audio/ans/se_module_complex.py (+26, -0)
  19. modelscope/models/audio/ans/unet.py (+269, -0)
  20. modelscope/models/cv/animal_recognition/__init__.py (+0, -0)
  21. modelscope/models/cv/animal_recognition/resnet.py (+430, -0)
  22. modelscope/models/cv/animal_recognition/splat.py (+125, -0)
  23. modelscope/models/nlp/masked_language_model.py (+9, -1)
  24. modelscope/pipelines/__init__.py (+1, -0)
  25. modelscope/pipelines/audio/ans_pipeline.py (+117, -0)
  26. modelscope/pipelines/cv/__init__.py (+1, -0)
  27. modelscope/pipelines/cv/animal_recog_pipeline.py (+127, -0)
  28. modelscope/pipelines/cv/ocr_detection_pipeline.py (+52, -42)
  29. modelscope/pipelines/cv/ocr_utils/model_resnet_mutex_v4_linewithchar.py (+5, -1)
  30. modelscope/pipelines/cv/ocr_utils/resnet18_v1.py (+5, -1)
  31. modelscope/pipelines/cv/ocr_utils/resnet_utils.py (+5, -1)
  32. modelscope/pipelines/nlp/fill_mask_pipeline.py (+14, -10)
  33. modelscope/preprocessors/nlp.py (+7, -4)
  34. requirements/audio.txt (+1, -0)
  35. tests/pipelines/test_animal_recognation.py (+20, -0)
  36. tests/pipelines/test_fill_mask.py (+35, -1)
  37. tests/pipelines/test_ocr_detection.py (+5, -0)
  38. tests/pipelines/test_speech_signal_process.py (+26, -6)

modelscope/metainfo.py (+3, -0)

@@ -22,6 +22,7 @@ class Models(object):
     sambert_hifi_16k = 'sambert-hifi-16k'
     generic_tts_frontend = 'generic-tts-frontend'
     hifigan16k = 'hifigan16k'
+    speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
     kws_kwsbp = 'kws-kwsbp'

     # multi-modal models
@@ -44,6 +45,7 @@ class Pipelines(object):
     person_image_cartoon = 'unet-person-image-cartoon'
     ocr_detection = 'resnet18-ocr-detection'
     action_recognition = 'TAdaConv_action-recognition'
+    animal_recognation = 'resnet101-animal_recog'

     # nlp tasks
     sentence_similarity = 'sentence-similarity'
@@ -59,6 +61,7 @@ class Pipelines(object):
     # audio tasks
     sambert_hifigan_16k_tts = 'sambert-hifigan-16k-tts'
     speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k'
+    speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
     kws_kwsbp = 'kws-kwsbp'

     # multi-modal tasks


modelscope/models/__init__.py (+2, -1)

@@ -1,12 +1,13 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.

+from .audio.ans.frcrn import FRCRNModel
 from .audio.kws import GenericKeyWordSpotting
 from .audio.tts.am import SambertNetHifi16k
 from .audio.tts.vocoder import Hifigan16k
 from .base import Model
 from .builder import MODELS, build_model
 from .multi_modal import OfaForImageCaptioning
-from .nlp import (BertForSequenceClassification, SbertForNLI,
+from .nlp import (BertForMaskedLM, BertForSequenceClassification, SbertForNLI,
                   SbertForSentenceSimilarity, SbertForSentimentClassification,
                   SbertForTokenClassification, StructBertForMaskedLM,
                   VecoForMaskedLM)

modelscope/models/audio/layers/__init__.py → modelscope/models/audio/aec/__init__.py

modelscope/models/audio/network/__init__.py → modelscope/models/audio/aec/layers/__init__.py

modelscope/models/audio/layers/activations.py → modelscope/models/audio/aec/layers/activations.py

modelscope/models/audio/layers/affine_transform.py → modelscope/models/audio/aec/layers/affine_transform.py

modelscope/models/audio/layers/deep_fsmn.py → modelscope/models/audio/aec/layers/deep_fsmn.py

modelscope/models/audio/layers/layer_base.py → modelscope/models/audio/aec/layers/layer_base.py

modelscope/models/audio/layers/uni_deep_fsmn.py → modelscope/models/audio/aec/layers/uni_deep_fsmn.py

modelscope/models/audio/aec/network/__init__.py (+0, -0)

modelscope/models/audio/network/loss.py → modelscope/models/audio/aec/network/loss.py

modelscope/models/audio/network/modulation_loss.py → modelscope/models/audio/aec/network/modulation_loss.py

modelscope/models/audio/network/se_net.py → modelscope/models/audio/aec/network/se_net.py

modelscope/models/audio/ans/__init__.py (+0, -0)

modelscope/models/audio/ans/complex_nn.py (+248, -0)

@@ -0,0 +1,248 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class UniDeepFsmn(nn.Module):

def __init__(self, input_dim, output_dim, lorder=None, hidden_size=None):
super(UniDeepFsmn, self).__init__()

self.input_dim = input_dim
self.output_dim = output_dim

if lorder is None:
return

self.lorder = lorder
self.hidden_size = hidden_size

self.linear = nn.Linear(input_dim, hidden_size)

self.project = nn.Linear(hidden_size, output_dim, bias=False)

self.conv1 = nn.Conv2d(
output_dim,
output_dim, [lorder, 1], [1, 1],
groups=output_dim,
bias=False)

def forward(self, input):
r"""

Args:
input: torch.Tensor with shape batch (b) x sequence (T) x feature (h)

Returns:
torch.Tensor with the same shape, batch (b) x sequence (T) x feature (h): the input plus the FSMN memory-block output as a residual
"""
f1 = F.relu(self.linear(input))

p1 = self.project(f1)

x = torch.unsqueeze(p1, 1)
# x: batch (b) x channel (c) x sequence(T) x feature (h)
x_per = x.permute(0, 3, 2, 1)
# x_per: batch (b) x feature (h) x sequence(T) x channel (c)
y = F.pad(x_per, [0, 0, self.lorder - 1, 0])

out = x_per + self.conv1(y)

out1 = out.permute(0, 3, 2, 1)
# out1: batch (b) x channel (c) x sequence(T) x feature (h)
return input + out1.squeeze()


class ComplexUniDeepFsmn(nn.Module):

def __init__(self, nIn, nHidden=128, nOut=128):
super(ComplexUniDeepFsmn, self).__init__()

self.fsmn_re_L1 = UniDeepFsmn(nIn, nHidden, 20, nHidden)
self.fsmn_im_L1 = UniDeepFsmn(nIn, nHidden, 20, nHidden)
self.fsmn_re_L2 = UniDeepFsmn(nHidden, nOut, 20, nHidden)
self.fsmn_im_L2 = UniDeepFsmn(nHidden, nOut, 20, nHidden)

def forward(self, x):
r"""

Args:
x: torch.Tensor with shape [batch, channel, feature, sequence, 2], e.g. [6, 256, 1, 106, 2]

Returns:
torch.Tensor with the same layout as the input, [batch, channel, feature, sequence, 2]
"""
#
b, c, h, T, d = x.size()
x = torch.reshape(x, (b, c * h, T, d))
# x: [b,h,T,2], [6, 256, 106, 2]
x = torch.transpose(x, 1, 2)
# x: [b,T,h,2], [6, 106, 256, 2]

real_L1 = self.fsmn_re_L1(x[..., 0]) - self.fsmn_im_L1(x[..., 1])
imaginary_L1 = self.fsmn_re_L1(x[..., 1]) + self.fsmn_im_L1(x[..., 0])
# GRU output: [99, 6, 128]
real = self.fsmn_re_L2(real_L1) - self.fsmn_im_L2(imaginary_L1)
imaginary = self.fsmn_re_L2(imaginary_L1) + self.fsmn_im_L2(real_L1)
# output: [b,T,h,2], [99, 6, 1024, 2]
output = torch.stack((real, imaginary), dim=-1)

# output: [b,h,T,2], [6, 99, 1024, 2]
output = torch.transpose(output, 1, 2)
output = torch.reshape(output, (b, c, h, T, d))

return output


class ComplexUniDeepFsmn_L1(nn.Module):

def __init__(self, nIn, nHidden=128, nOut=128):
super(ComplexUniDeepFsmn_L1, self).__init__()
self.fsmn_re_L1 = UniDeepFsmn(nIn, nHidden, 20, nHidden)
self.fsmn_im_L1 = UniDeepFsmn(nIn, nHidden, 20, nHidden)

def forward(self, x):
r"""

Args:
x: torch.Tensor with shape [batch, channel, feature, sequence, 2], e.g. [6, 256, 1, 106, 2]
"""
b, c, h, T, d = x.size()
# x : [b,T,h,c,2]
x = torch.transpose(x, 1, 3)
x = torch.reshape(x, (b * T, h, c, d))

real = self.fsmn_re_L1(x[..., 0]) - self.fsmn_im_L1(x[..., 1])
imaginary = self.fsmn_re_L1(x[..., 1]) + self.fsmn_im_L1(x[..., 0])
# output: [b*T,h,c,2], [6*106, h, 256, 2]
output = torch.stack((real, imaginary), dim=-1)

output = torch.reshape(output, (b, T, h, c, d))
output = torch.transpose(output, 1, 3)
return output


class ComplexConv2d(nn.Module):
# https://github.com/litcoderr/ComplexCNN/blob/master/complexcnn/modules.py
def __init__(self,
in_channel,
out_channel,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
**kwargs):
super().__init__()

# Model components
self.conv_re = nn.Conv2d(
in_channel,
out_channel,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
**kwargs)
self.conv_im = nn.Conv2d(
in_channel,
out_channel,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
**kwargs)

def forward(self, x):
r"""

Args:
x: torch.Tensor with shape [batch, channel, axis1, axis2, 2]
"""
real = self.conv_re(x[..., 0]) - self.conv_im(x[..., 1])
imaginary = self.conv_re(x[..., 1]) + self.conv_im(x[..., 0])
output = torch.stack((real, imaginary), dim=-1)
return output


class ComplexConvTranspose2d(nn.Module):

def __init__(self,
in_channel,
out_channel,
kernel_size,
stride=1,
padding=0,
output_padding=0,
dilation=1,
groups=1,
bias=True,
**kwargs):
super().__init__()

# Model components
self.tconv_re = nn.ConvTranspose2d(
in_channel,
out_channel,
kernel_size=kernel_size,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
bias=bias,
dilation=dilation,
**kwargs)
self.tconv_im = nn.ConvTranspose2d(
in_channel,
out_channel,
kernel_size=kernel_size,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
bias=bias,
dilation=dilation,
**kwargs)

def forward(self, x):  # shape of x: [batch, channel, axis1, axis2, 2]
real = self.tconv_re(x[..., 0]) - self.tconv_im(x[..., 1])
imaginary = self.tconv_re(x[..., 1]) + self.tconv_im(x[..., 0])
output = torch.stack((real, imaginary), dim=-1)
return output


class ComplexBatchNorm2d(nn.Module):

def __init__(self,
num_features,
eps=1e-5,
momentum=0.1,
affine=True,
track_running_stats=True,
**kwargs):
super().__init__()
self.bn_re = nn.BatchNorm2d(
num_features=num_features,
momentum=momentum,
affine=affine,
eps=eps,
track_running_stats=track_running_stats,
**kwargs)
self.bn_im = nn.BatchNorm2d(
num_features=num_features,
momentum=momentum,
affine=affine,
eps=eps,
track_running_stats=track_running_stats,
**kwargs)

def forward(self, x):
real = self.bn_re(x[..., 0])
imag = self.bn_im(x[..., 1])
output = torch.stack((real, imag), dim=-1)
return output
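
For orientation, a minimal shape check of the complex-as-last-dimension convention used throughout this file (a sketch, assuming the package is installed with the modules added in this commit; tensor sizes are arbitrary):

import torch
from modelscope.models.audio.ans.complex_nn import (ComplexBatchNorm2d,
                                                    ComplexConv2d)

x = torch.randn(2, 3, 16, 16, 2)  # [batch, channel, H, W, real/imag]
conv = ComplexConv2d(3, 8, kernel_size=3, padding=1)
bn = ComplexBatchNorm2d(8)
y = bn(conv(x))
print(y.shape)  # torch.Size([2, 8, 16, 16, 2])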

modelscope/models/audio/ans/conv_stft.py (+112, -0)

@@ -0,0 +1,112 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.signal import get_window


def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False):
if win_type == 'None' or win_type is None:
window = np.ones(win_len)
else:
window = get_window(win_type, win_len, fftbins=True)**0.5

N = fft_len
fourier_basis = np.fft.rfft(np.eye(N))[:win_len]
real_kernel = np.real(fourier_basis)
imag_kernel = np.imag(fourier_basis)
kernel = np.concatenate([real_kernel, imag_kernel], 1).T

if invers:
kernel = np.linalg.pinv(kernel).T

kernel = kernel * window
kernel = kernel[:, None, :]
return torch.from_numpy(kernel.astype(np.float32)), torch.from_numpy(
window[None, :, None].astype(np.float32))


class ConvSTFT(nn.Module):

def __init__(self,
win_len,
win_inc,
fft_len=None,
win_type='hamming',
feature_type='real',
fix=True):
super(ConvSTFT, self).__init__()

if fft_len is None:
self.fft_len = int(2**np.ceil(np.log2(win_len)))  # np.int was removed in NumPy 1.24+; int() is equivalent
else:
self.fft_len = fft_len

kernel, _ = init_kernels(win_len, win_inc, self.fft_len, win_type)
self.weight = nn.Parameter(kernel, requires_grad=(not fix))
self.feature_type = feature_type
self.stride = win_inc
self.win_len = win_len
self.dim = self.fft_len

def forward(self, inputs):
if inputs.dim() == 2:
inputs = torch.unsqueeze(inputs, 1)

outputs = F.conv1d(inputs, self.weight, stride=self.stride)

if self.feature_type == 'complex':
return outputs
else:
dim = self.dim // 2 + 1
real = outputs[:, :dim, :]
imag = outputs[:, dim:, :]
mags = torch.sqrt(real**2 + imag**2)
phase = torch.atan2(imag, real)
return mags, phase


class ConviSTFT(nn.Module):

def __init__(self,
win_len,
win_inc,
fft_len=None,
win_type='hamming',
feature_type='real',
fix=True):
super(ConviSTFT, self).__init__()
if fft_len is None:
self.fft_len = int(2**np.ceil(np.log2(win_len)))
else:
self.fft_len = fft_len
kernel, window = init_kernels(
win_len, win_inc, self.fft_len, win_type, invers=True)
self.weight = nn.Parameter(kernel, requires_grad=(not fix))
self.feature_type = feature_type
self.win_type = win_type
self.win_len = win_len
self.win_inc = win_inc
self.stride = win_inc
self.dim = self.fft_len
self.register_buffer('window', window)
self.register_buffer('enframe', torch.eye(win_len)[:, None, :])

def forward(self, inputs, phase=None):
"""
Args:
inputs : [B, N+2, T] (complex spec) or [B, N//2+1, T] (mags)
phase: [B, N//2+1, T] (if not none)
"""

if phase is not None:
real = inputs * torch.cos(phase)
imag = inputs * torch.sin(phase)
inputs = torch.cat([real, imag], 1)
outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride)

# this is from torch-stft: https://github.com/pseeth/torch-stft
t = self.window.repeat(1, 1, inputs.size(-1))**2
coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
outputs = outputs / (coff + 1e-8)
return outputs
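
A quick round-trip sketch of the two modules above (assuming 16 kHz audio and the same window settings FRCRN uses later in this commit; the shapes in the comments follow from win_len=400, win_inc=100, fft_len=512):

import torch
from modelscope.models.audio.ans.conv_stft import ConviSTFT, ConvSTFT

stft = ConvSTFT(400, 100, 512, win_type='hanning', feature_type='complex')
istft = ConviSTFT(400, 100, 512, win_type='hanning', feature_type='complex')

wav = torch.randn(1, 16000)  # one second of audio at 16 kHz
spec = stft(wav)             # [1, 514, 157]: real and imaginary bins stacked on dim 1
recon = istft(spec)          # [1, 1, 16000]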

modelscope/models/audio/ans/frcrn.py (+309, -0)

@@ -0,0 +1,309 @@
import os
from typing import Dict

import torch
import torch.nn as nn
import torch.nn.functional as F

from modelscope.metainfo import Models
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from ...base import Model, Tensor
from .conv_stft import ConviSTFT, ConvSTFT
from .unet import UNet


class FTB(nn.Module):

def __init__(self, input_dim=257, in_channel=9, r_channel=5):

super(FTB, self).__init__()
self.in_channel = in_channel
self.conv1 = nn.Sequential(
nn.Conv2d(in_channel, r_channel, kernel_size=[1, 1]),
nn.BatchNorm2d(r_channel), nn.ReLU())

self.conv1d = nn.Sequential(
nn.Conv1d(
r_channel * input_dim, in_channel, kernel_size=9, padding=4),
nn.BatchNorm1d(in_channel), nn.ReLU())
self.freq_fc = nn.Linear(input_dim, input_dim, bias=False)

self.conv2 = nn.Sequential(
nn.Conv2d(in_channel * 2, in_channel, kernel_size=[1, 1]),
nn.BatchNorm2d(in_channel), nn.ReLU())

def forward(self, inputs):
'''
inputs should be [Batch, Ca, Dim, Time]
'''
# T-F attention
conv1_out = self.conv1(inputs)
B, C, D, T = conv1_out.size()
reshape1_out = torch.reshape(conv1_out, [B, C * D, T])
conv1d_out = self.conv1d(reshape1_out)
conv1d_out = torch.reshape(conv1d_out, [B, self.in_channel, 1, T])

# now is also [B,C,D,T]
att_out = conv1d_out * inputs

# tranpose to [B,C,T,D]
att_out = torch.transpose(att_out, 2, 3)
freqfc_out = self.freq_fc(att_out)
att_out = torch.transpose(freqfc_out, 2, 3)

cat_out = torch.cat([att_out, inputs], 1)
outputs = self.conv2(cat_out)
return outputs


@MODELS.register_module(
Tasks.speech_signal_process, module_name=Models.speech_frcrn_ans_cirm_16k)
class FRCRNModel(Model):
r""" A decorator of FRCRN for integrating into modelscope framework """

def __init__(self, model_dir: str, *args, **kwargs):
"""initialize the frcrn model from the `model_dir` path.

Args:
model_dir (str): the model path.
"""
super().__init__(model_dir, *args, **kwargs)
self._model = FRCRN(*args, **kwargs)
model_bin_file = os.path.join(model_dir,
ModelFile.TORCH_MODEL_BIN_FILE)
if os.path.exists(model_bin_file):
checkpoint = torch.load(model_bin_file)
self._model.load_state_dict(checkpoint, strict=False)

def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
output = self._model.forward(input)
return {
'spec_l1': output[0],
'wav_l1': output[1],
'mask_l1': output[2],
'spec_l2': output[3],
'wav_l2': output[4],
'mask_l2': output[5]
}

def to(self, *args, **kwargs):
self._model = self._model.to(*args, **kwargs)
return self

def eval(self):
self._model = self._model.train(False)
return self


class FRCRN(nn.Module):
r""" Frequency Recurrent CRN """

def __init__(self,
complex,
model_complexity,
model_depth,
log_amp,
padding_mode,
win_len=400,
win_inc=100,
fft_len=512,
win_type='hanning'):
r"""
Args:
complex: Whether to use complex networks.
model_complexity: controls the channel width of the UNet encoder/decoder layers
model_depth: Only two options are available: 10, 20
log_amp: Whether to use log amplitude to estimate signals
padding_mode: Encoder's convolution filter. 'zeros', 'reflect'
win_len: length of window used for defining one frame of sample points
win_inc: length of window shifting (equivalent to hop_size)
fft_len: number of Short Time Fourier Transform (STFT) points
win_type: windowing type used in STFT, e.g. 'hanning', 'hamming'
"""
super().__init__()
self.feat_dim = fft_len // 2 + 1

self.win_len = win_len
self.win_inc = win_inc
self.fft_len = fft_len
self.win_type = win_type

fix = True
self.stft = ConvSTFT(
self.win_len,
self.win_inc,
self.fft_len,
self.win_type,
feature_type='complex',
fix=fix)
self.istft = ConviSTFT(
self.win_len,
self.win_inc,
self.fft_len,
self.win_type,
feature_type='complex',
fix=fix)
self.unet = UNet(
1,
complex=complex,
model_complexity=model_complexity,
model_depth=model_depth,
padding_mode=padding_mode)
self.unet2 = UNet(
1,
complex=complex,
model_complexity=model_complexity,
model_depth=model_depth,
padding_mode=padding_mode)

def forward(self, inputs):
out_list = []
# [B, D*2, T]
cmp_spec = self.stft(inputs)
# [B, 1, D*2, T]
cmp_spec = torch.unsqueeze(cmp_spec, 1)

# to [B, 2, D, T] real_part/imag_part
cmp_spec = torch.cat([
cmp_spec[:, :, :self.feat_dim, :],
cmp_spec[:, :, self.feat_dim:, :],
], 1)

# [B, 2, D, T]
cmp_spec = torch.unsqueeze(cmp_spec, 4)
# [B, 1, D, T, 2]
cmp_spec = torch.transpose(cmp_spec, 1, 4)
unet1_out = self.unet(cmp_spec)
cmp_mask1 = torch.tanh(unet1_out)
unet2_out = self.unet2(unet1_out)
cmp_mask2 = torch.tanh(unet2_out)
est_spec, est_wav, est_mask = self.apply_mask(cmp_spec, cmp_mask1)
out_list.append(est_spec)
out_list.append(est_wav)
out_list.append(est_mask)
cmp_mask2 = cmp_mask2 + cmp_mask1
est_spec, est_wav, est_mask = self.apply_mask(cmp_spec, cmp_mask2)
out_list.append(est_spec)
out_list.append(est_wav)
out_list.append(est_mask)
return out_list

def apply_mask(self, cmp_spec, cmp_mask):
est_spec = torch.cat([
cmp_spec[:, :, :, :, 0] * cmp_mask[:, :, :, :, 0]
- cmp_spec[:, :, :, :, 1] * cmp_mask[:, :, :, :, 1],
cmp_spec[:, :, :, :, 0] * cmp_mask[:, :, :, :, 1]
+ cmp_spec[:, :, :, :, 1] * cmp_mask[:, :, :, :, 0]
], 1)
est_spec = torch.cat([est_spec[:, 0, :, :], est_spec[:, 1, :, :]], 1)
cmp_mask = torch.squeeze(cmp_mask, 1)
cmp_mask = torch.cat([cmp_mask[:, :, :, 0], cmp_mask[:, :, :, 1]], 1)

est_wav = self.istft(est_spec)
est_wav = torch.squeeze(est_wav, 1)
return est_spec, est_wav, cmp_mask

def get_params(self, weight_decay=0.0):
# add L2 penalty
weights, biases = [], []
for name, param in self.named_parameters():
if 'bias' in name:
biases += [param]
else:
weights += [param]
params = [{
'params': weights,
'weight_decay': weight_decay,
}, {
'params': biases,
'weight_decay': 0.0,
}]
return params

def loss(self, noisy, labels, out_list, mode='Mix'):
if mode == 'SiSNR':
count = 0
while count < len(out_list):
est_spec = out_list[count]
count = count + 1
est_wav = out_list[count]
count = count + 1
est_mask = out_list[count]
count = count + 1
if count != 3:
loss = self.loss_1layer(noisy, est_spec, est_wav, labels,
est_mask, mode)
return loss

elif mode == 'Mix':
count = 0
while count < len(out_list):
est_spec = out_list[count]
count = count + 1
est_wav = out_list[count]
count = count + 1
est_mask = out_list[count]
count = count + 1
if count != 3:
amp_loss, phase_loss, SiSNR_loss = self.loss_1layer(
noisy, est_spec, est_wav, labels, est_mask, mode)
loss = amp_loss + phase_loss + SiSNR_loss
return loss, amp_loss, phase_loss

def loss_1layer(self, noisy, est, est_wav, labels, cmp_mask, mode='Mix'):
r""" Compute the loss by mode
mode == 'Mix'
est: [B, F*2, T]
labels: [B, F*2,T]
mode == 'SiSNR'
est: [B, T]
labels: [B, T]
"""
if mode == 'SiSNR':
if labels.dim() == 3:
labels = torch.squeeze(labels, 1)
if est_wav.dim() == 3:
est_wav = torch.squeeze(est_wav, 1)
return -si_snr(est_wav, labels)
elif mode == 'Mix':

if labels.dim() == 3:
labels = torch.squeeze(labels, 1)
if est_wav.dim() == 3:
est_wav = torch.squeeze(est_wav, 1)
SiSNR_loss = -si_snr(est_wav, labels)

b, d, t = est.size()
S = self.stft(labels)
Sr = S[:, :self.feat_dim, :]
Si = S[:, self.feat_dim:, :]
Y = self.stft(noisy)
Yr = Y[:, :self.feat_dim, :]
Yi = Y[:, self.feat_dim:, :]
Y_pow = Yr**2 + Yi**2
gth_mask = torch.cat([(Sr * Yr + Si * Yi) / (Y_pow + 1e-8),
(Si * Yr - Sr * Yi) / (Y_pow + 1e-8)], 1)
gth_mask[gth_mask > 2] = 1
gth_mask[gth_mask < -2] = -1
amp_loss = F.mse_loss(gth_mask[:, :self.feat_dim, :],
cmp_mask[:, :self.feat_dim, :]) * d
phase_loss = F.mse_loss(gth_mask[:, self.feat_dim:, :],
cmp_mask[:, self.feat_dim:, :]) * d
return amp_loss, phase_loss, SiSNR_loss


def l2_norm(s1, s2):
norm = torch.sum(s1 * s2, -1, keepdim=True)
return norm


def si_snr(s1, s2, eps=1e-8):
s1_s2_norm = l2_norm(s1, s2)
s2_s2_norm = l2_norm(s2, s2)
s_target = s1_s2_norm / (s2_s2_norm + eps) * s2
e_noise = s1 - s_target
target_norm = l2_norm(s_target, s_target)
noise_norm = l2_norm(e_noise, e_noise)
snr = 10 * torch.log10((target_norm) / (noise_norm + eps) + eps)
return torch.mean(snr)
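
As a small sanity check of si_snr (a sketch, assuming it runs with the functions defined above in scope), adding 10% Gaussian noise to a unit-amplitude sine gives roughly 17 dB, i.e. about 10 * log10(0.5 / 0.01):

import torch

s = torch.sin(torch.linspace(0, 100, 16000))  # clean reference
noisy = s + 0.1 * torch.randn(16000)          # noisy estimate
print(si_snr(noisy, s))                       # roughly 17 dB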

modelscope/models/audio/ans/se_module_complex.py (+26, -0)

@@ -0,0 +1,26 @@
import torch
from torch import nn


class SELayer(nn.Module):

def __init__(self, channel, reduction=16):
super(SELayer, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc_r = nn.Sequential(
nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True),
nn.Linear(channel // reduction, channel), nn.Sigmoid())
self.fc_i = nn.Sequential(
nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True),
nn.Linear(channel // reduction, channel), nn.Sigmoid())

def forward(self, x):
b, c, _, _, _ = x.size()
x_r = self.avg_pool(x[:, :, :, :, 0]).view(b, c)
x_i = self.avg_pool(x[:, :, :, :, 1]).view(b, c)
y_r = self.fc_r(x_r).view(b, c, 1, 1, 1) - self.fc_i(x_i).view(
b, c, 1, 1, 1)
y_i = self.fc_r(x_i).view(b, c, 1, 1, 1) + self.fc_i(x_r).view(
b, c, 1, 1, 1)
y = torch.cat([y_r, y_i], 4)
return x * y
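
A one-line shape check of this complex squeeze-and-excitation layer (a sketch; the dimensions are arbitrary):

import torch
from modelscope.models.audio.ans.se_module_complex import SELayer

se = SELayer(channel=32, reduction=8)
x = torch.randn(2, 32, 64, 10, 2)  # [batch, channel, freq, time, real/imag]
print(se(x).shape)                 # torch.Size([2, 32, 64, 10, 2])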

modelscope/models/audio/ans/unet.py (+269, -0)

@@ -0,0 +1,269 @@
import torch
import torch.nn as nn

from . import complex_nn
from .se_module_complex import SELayer


class Encoder(nn.Module):

def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding=None,
complex=False,
padding_mode='zeros'):
super().__init__()
if padding is None:
padding = [(i - 1) // 2 for i in kernel_size] # 'SAME' padding

if complex:
conv = complex_nn.ComplexConv2d
bn = complex_nn.ComplexBatchNorm2d
else:
conv = nn.Conv2d
bn = nn.BatchNorm2d

self.conv = conv(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
padding_mode=padding_mode)
self.bn = bn(out_channels)
self.relu = nn.LeakyReLU(inplace=True)

def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x


class Decoder(nn.Module):

def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding=(0, 0),
complex=False):
super().__init__()
if complex:
tconv = complex_nn.ComplexConvTranspose2d
bn = complex_nn.ComplexBatchNorm2d
else:
tconv = nn.ConvTranspose2d
bn = nn.BatchNorm2d

self.transconv = tconv(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding)
self.bn = bn(out_channels)
self.relu = nn.LeakyReLU(inplace=True)

def forward(self, x):
x = self.transconv(x)
x = self.bn(x)
x = self.relu(x)
return x


class UNet(nn.Module):

def __init__(self,
input_channels=1,
complex=False,
model_complexity=45,
model_depth=20,
padding_mode='zeros'):
super().__init__()

if complex:
model_complexity = int(model_complexity // 1.414)

self.set_size(
model_complexity=model_complexity,
input_channels=input_channels,
model_depth=model_depth)
self.encoders = []
self.model_length = model_depth // 2
self.fsmn = complex_nn.ComplexUniDeepFsmn(128, 128, 128)
self.se_layers_enc = []
self.fsmn_enc = []
for i in range(self.model_length):
fsmn_enc = complex_nn.ComplexUniDeepFsmn_L1(128, 128, 128)
self.add_module('fsmn_enc{}'.format(i), fsmn_enc)
self.fsmn_enc.append(fsmn_enc)
module = Encoder(
self.enc_channels[i],
self.enc_channels[i + 1],
kernel_size=self.enc_kernel_sizes[i],
stride=self.enc_strides[i],
padding=self.enc_paddings[i],
complex=complex,
padding_mode=padding_mode)
self.add_module('encoder{}'.format(i), module)
self.encoders.append(module)
se_layer_enc = SELayer(self.enc_channels[i + 1], 8)
self.add_module('se_layer_enc{}'.format(i), se_layer_enc)
self.se_layers_enc.append(se_layer_enc)
self.decoders = []
self.fsmn_dec = []
self.se_layers_dec = []
for i in range(self.model_length):
fsmn_dec = complex_nn.ComplexUniDeepFsmn_L1(128, 128, 128)
self.add_module('fsmn_dec{}'.format(i), fsmn_dec)
self.fsmn_dec.append(fsmn_dec)
module = Decoder(
self.dec_channels[i] * 2,
self.dec_channels[i + 1],
kernel_size=self.dec_kernel_sizes[i],
stride=self.dec_strides[i],
padding=self.dec_paddings[i],
complex=complex)
self.add_module('decoder{}'.format(i), module)
self.decoders.append(module)
if i < self.model_length - 1:
se_layer_dec = SELayer(self.dec_channels[i + 1], 8)
self.add_module('se_layer_dec{}'.format(i), se_layer_dec)
self.se_layers_dec.append(se_layer_dec)
if complex:
conv = complex_nn.ComplexConv2d
else:
conv = nn.Conv2d

linear = conv(self.dec_channels[-1], 1, 1)

self.add_module('linear', linear)
self.complex = complex
self.padding_mode = padding_mode

self.decoders = nn.ModuleList(self.decoders)
self.encoders = nn.ModuleList(self.encoders)
self.se_layers_enc = nn.ModuleList(self.se_layers_enc)
self.se_layers_dec = nn.ModuleList(self.se_layers_dec)
self.fsmn_enc = nn.ModuleList(self.fsmn_enc)
self.fsmn_dec = nn.ModuleList(self.fsmn_dec)

def forward(self, inputs):
x = inputs
# go down
xs = []
xs_se = []
xs_se.append(x)
for i, encoder in enumerate(self.encoders):
xs.append(x)
if i > 0:
x = self.fsmn_enc[i](x)
x = encoder(x)
xs_se.append(self.se_layers_enc[i](x))
# xs : x0=input x1 ... x9
x = self.fsmn(x)

p = x
for i, decoder in enumerate(self.decoders):
p = decoder(p)
if i < self.model_length - 1:
p = self.fsmn_dec[i](p)
if i == self.model_length - 1:
break
if i < self.model_length - 2:
p = self.se_layers_dec[i](p)
p = torch.cat([p, xs_se[self.model_length - 1 - i]], dim=1)

# cmp_spec: [12, 1, 513, 64, 2]
cmp_spec = self.linear(p)
return cmp_spec

def set_size(self, model_complexity, model_depth=20, input_channels=1):

if model_depth == 14:
self.enc_channels = [
input_channels, 128, 128, 128, 128, 128, 128, 128
]
self.enc_kernel_sizes = [(5, 2), (5, 2), (5, 2), (5, 2), (5, 2),
(5, 2), (2, 2)]
self.enc_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1), (2, 1),
(2, 1)]
self.enc_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1),
(0, 1), (0, 1)]
self.dec_channels = [64, 128, 128, 128, 128, 128, 128, 1]
self.dec_kernel_sizes = [(2, 2), (5, 2), (5, 2), (5, 2), (6, 2),
(5, 2), (5, 2)]
self.dec_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1), (2, 1),
(2, 1)]
self.dec_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1),
(0, 1), (0, 1)]

elif model_depth == 10:
self.enc_channels = [
input_channels,
16,
32,
64,
128,
256,
]
self.enc_kernel_sizes = [(3, 3), (3, 3), (3, 3), (3, 3), (3, 3)]
self.enc_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1)]
self.enc_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1)]
self.dec_channels = [128, 128, 64, 32, 16, 1]
self.dec_kernel_sizes = [(3, 3), (3, 3), (3, 3), (4, 3), (3, 3)]
self.dec_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1)]
self.dec_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1)]

elif model_depth == 20:
self.enc_channels = [
input_channels, model_complexity, model_complexity,
model_complexity * 2, model_complexity * 2,
model_complexity * 2, model_complexity * 2,
model_complexity * 2, model_complexity * 2,
model_complexity * 2, 128
]

self.enc_kernel_sizes = [(7, 1), (1, 7), (6, 4), (7, 5), (5, 3),
(5, 3), (5, 3), (5, 3), (5, 3), (5, 3)]

self.enc_strides = [(1, 1), (1, 1), (2, 2), (2, 1), (2, 2), (2, 1),
(2, 2), (2, 1), (2, 2), (2, 1)]

self.enc_paddings = [
(3, 0),
(0, 3),
None, # (0, 2),
None,
None, # (3,1),
None, # (3,1),
None, # (1,2),
None,
None,
None
]

self.dec_channels = [
0, model_complexity * 2, model_complexity * 2,
model_complexity * 2, model_complexity * 2,
model_complexity * 2, model_complexity * 2,
model_complexity * 2, model_complexity * 2,
model_complexity * 2, model_complexity * 2,
model_complexity * 2
]

self.dec_kernel_sizes = [(4, 3), (4, 2), (4, 3), (4, 2), (4, 3),
(4, 2), (6, 3), (7, 4), (1, 7), (7, 1)]

self.dec_strides = [(2, 1), (2, 2), (2, 1), (2, 2), (2, 1), (2, 2),
(2, 1), (2, 2), (1, 1), (1, 1)]

self.dec_paddings = [(1, 1), (1, 0), (1, 1), (1, 0), (1, 1),
(1, 0), (2, 1), (2, 1), (0, 3), (3, 0)]
else:
raise ValueError('Unknown model depth : {}'.format(model_depth))
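
The encoder/decoder tables above are easiest to read by instantiating the network; a construction-only sketch (no forward pass, so no claims are made here about valid input shapes):

from modelscope.models.audio.ans.unet import UNet

net = UNet(input_channels=1, complex=True, model_complexity=45,
           model_depth=14, padding_mode='zeros')
n_params = sum(p.numel() for p in net.parameters())
print(f'encoders: {len(net.encoders)}, decoders: {len(net.decoders)}, '
      f'params: {n_params / 1e6:.1f} M')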

modelscope/models/cv/animal_recognition/__init__.py (+0, -0)


modelscope/models/cv/animal_recognition/resnet.py (+430, -0)

@@ -0,0 +1,430 @@
import math

import torch
import torch.nn as nn

from .splat import SplAtConv2d

__all__ = ['ResNet', 'Bottleneck']


class DropBlock2D(object):

def __init__(self, *args, **kwargs):
raise NotImplementedError


class GlobalAvgPool2d(nn.Module):

def __init__(self):
"""Global average pooling over the input's spatial dimensions"""
super(GlobalAvgPool2d, self).__init__()

def forward(self, inputs):
return nn.functional.adaptive_avg_pool2d(inputs,
1).view(inputs.size(0), -1)


class Bottleneck(nn.Module):
expansion = 4

def __init__(self,
inplanes,
planes,
stride=1,
downsample=None,
radix=1,
cardinality=1,
bottleneck_width=64,
avd=False,
avd_first=False,
dilation=1,
is_first=False,
rectified_conv=False,
rectify_avg=False,
norm_layer=None,
dropblock_prob=0.0,
last_gamma=False):
super(Bottleneck, self).__init__()
group_width = int(planes * (bottleneck_width / 64.)) * cardinality
self.conv1 = nn.Conv2d(
inplanes, group_width, kernel_size=1, bias=False)
self.bn1 = norm_layer(group_width)
self.dropblock_prob = dropblock_prob
self.radix = radix
self.avd = avd and (stride > 1 or is_first)
self.avd_first = avd_first

if self.avd:
self.avd_layer = nn.AvgPool2d(3, stride, padding=1)
stride = 1

if dropblock_prob > 0.0:
self.dropblock1 = DropBlock2D(dropblock_prob, 3)
if radix == 1:
self.dropblock2 = DropBlock2D(dropblock_prob, 3)
self.dropblock3 = DropBlock2D(dropblock_prob, 3)

if radix >= 1:
self.conv2 = SplAtConv2d(
group_width,
group_width,
kernel_size=3,
stride=stride,
padding=dilation,
dilation=dilation,
groups=cardinality,
bias=False,
radix=radix,
rectify=rectified_conv,
rectify_avg=rectify_avg,
norm_layer=norm_layer,
dropblock_prob=dropblock_prob)
elif rectified_conv:
from rfconv import RFConv2d
self.conv2 = RFConv2d(
group_width,
group_width,
kernel_size=3,
stride=stride,
padding=dilation,
dilation=dilation,
groups=cardinality,
bias=False,
average_mode=rectify_avg)
self.bn2 = norm_layer(group_width)
else:
self.conv2 = nn.Conv2d(
group_width,
group_width,
kernel_size=3,
stride=stride,
padding=dilation,
dilation=dilation,
groups=cardinality,
bias=False)
self.bn2 = norm_layer(group_width)

self.conv3 = nn.Conv2d(
group_width, planes * 4, kernel_size=1, bias=False)
self.bn3 = norm_layer(planes * 4)

if last_gamma:
from torch.nn.init import zeros_
zeros_(self.bn3.weight)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.dilation = dilation
self.stride = stride

def forward(self, x):
residual = x

out = self.conv1(x)
out = self.bn1(out)
if self.dropblock_prob > 0.0:
out = self.dropblock1(out)
out = self.relu(out)

if self.avd and self.avd_first:
out = self.avd_layer(out)

out = self.conv2(out)
if self.radix == 0:
out = self.bn2(out)
if self.dropblock_prob > 0.0:
out = self.dropblock2(out)
out = self.relu(out)

if self.avd and not self.avd_first:
out = self.avd_layer(out)

out = self.conv3(out)
out = self.bn3(out)
if self.dropblock_prob > 0.0:
out = self.dropblock3(out)

if self.downsample is not None:
residual = self.downsample(x)

out += residual
out = self.relu(out)

return out


class ResNet(nn.Module):

def __init__(self,
block,
layers,
radix=1,
groups=1,
bottleneck_width=64,
num_classes=1000,
dilated=False,
dilation=1,
deep_stem=False,
stem_width=64,
avg_down=False,
rectified_conv=False,
rectify_avg=False,
avd=False,
avd_first=False,
final_drop=0.0,
dropblock_prob=0,
last_gamma=False,
norm_layer=nn.BatchNorm2d):
self.cardinality = groups
self.bottleneck_width = bottleneck_width
# ResNet-D params
self.inplanes = stem_width * 2 if deep_stem else 64
self.avg_down = avg_down
self.last_gamma = last_gamma
# ResNeSt params
self.radix = radix
self.avd = avd
self.avd_first = avd_first

super(ResNet, self).__init__()
self.rectified_conv = rectified_conv
self.rectify_avg = rectify_avg
if rectified_conv:
from rfconv import RFConv2d
conv_layer = RFConv2d
else:
conv_layer = nn.Conv2d
conv_kwargs = {'average_mode': rectify_avg} if rectified_conv else {}
if deep_stem:
self.conv1 = nn.Sequential(
conv_layer(
3,
stem_width,
kernel_size=3,
stride=2,
padding=1,
bias=False,
**conv_kwargs),
norm_layer(stem_width),
nn.ReLU(inplace=True),
conv_layer(
stem_width,
stem_width,
kernel_size=3,
stride=1,
padding=1,
bias=False,
**conv_kwargs),
norm_layer(stem_width),
nn.ReLU(inplace=True),
conv_layer(
stem_width,
stem_width * 2,
kernel_size=3,
stride=1,
padding=1,
bias=False,
**conv_kwargs),
)
else:
self.conv1 = conv_layer(
3,
64,
kernel_size=7,
stride=2,
padding=3,
bias=False,
**conv_kwargs)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(
block, 64, layers[0], norm_layer=norm_layer, is_first=False)
self.layer2 = self._make_layer(
block, 128, layers[1], stride=2, norm_layer=norm_layer)
if dilated or dilation == 4:
self.layer3 = self._make_layer(
block,
256,
layers[2],
stride=1,
dilation=2,
norm_layer=norm_layer,
dropblock_prob=dropblock_prob)
self.layer4 = self._make_layer(
block,
512,
layers[3],
stride=1,
dilation=4,
norm_layer=norm_layer,
dropblock_prob=dropblock_prob)
elif dilation == 2:
self.layer3 = self._make_layer(
block,
256,
layers[2],
stride=2,
dilation=1,
norm_layer=norm_layer,
dropblock_prob=dropblock_prob)
self.layer4 = self._make_layer(
block,
512,
layers[3],
stride=1,
dilation=2,
norm_layer=norm_layer,
dropblock_prob=dropblock_prob)
else:
self.layer3 = self._make_layer(
block,
256,
layers[2],
stride=2,
norm_layer=norm_layer,
dropblock_prob=dropblock_prob)
self.layer4 = self._make_layer(
block,
512,
layers[3],
stride=2,
norm_layer=norm_layer,
dropblock_prob=dropblock_prob)
self.avgpool = GlobalAvgPool2d()
self.drop = nn.Dropout(final_drop) if final_drop > 0.0 else None
self.fc = nn.Linear(512 * block.expansion, num_classes)

for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, norm_layer):
m.weight.data.fill_(1)
m.bias.data.zero_()

def _make_layer(self,
block,
planes,
blocks,
stride=1,
dilation=1,
norm_layer=None,
dropblock_prob=0.0,
is_first=True):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
down_layers = []
if self.avg_down:
if dilation == 1:
down_layers.append(
nn.AvgPool2d(
kernel_size=stride,
stride=stride,
ceil_mode=True,
count_include_pad=False))
else:
down_layers.append(
nn.AvgPool2d(
kernel_size=1,
stride=1,
ceil_mode=True,
count_include_pad=False))
down_layers.append(
nn.Conv2d(
self.inplanes,
planes * block.expansion,
kernel_size=1,
stride=1,
bias=False))
else:
down_layers.append(
nn.Conv2d(
self.inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
bias=False))
down_layers.append(norm_layer(planes * block.expansion))
downsample = nn.Sequential(*down_layers)

layers = []
if dilation == 1 or dilation == 2:
layers.append(
block(
self.inplanes,
planes,
stride,
downsample=downsample,
radix=self.radix,
cardinality=self.cardinality,
bottleneck_width=self.bottleneck_width,
avd=self.avd,
avd_first=self.avd_first,
dilation=1,
is_first=is_first,
rectified_conv=self.rectified_conv,
rectify_avg=self.rectify_avg,
norm_layer=norm_layer,
dropblock_prob=dropblock_prob,
last_gamma=self.last_gamma))
elif dilation == 4:
layers.append(
block(
self.inplanes,
planes,
stride,
downsample=downsample,
radix=self.radix,
cardinality=self.cardinality,
bottleneck_width=self.bottleneck_width,
avd=self.avd,
avd_first=self.avd_first,
dilation=2,
is_first=is_first,
rectified_conv=self.rectified_conv,
rectify_avg=self.rectify_avg,
norm_layer=norm_layer,
dropblock_prob=dropblock_prob,
last_gamma=self.last_gamma))
else:
raise RuntimeError('=> unknown dilation size: {}'.format(dilation))

self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(
block(
self.inplanes,
planes,
radix=self.radix,
cardinality=self.cardinality,
bottleneck_width=self.bottleneck_width,
avd=self.avd,
avd_first=self.avd_first,
dilation=dilation,
rectified_conv=self.rectified_conv,
rectify_avg=self.rectify_avg,
norm_layer=norm_layer,
dropblock_prob=dropblock_prob,
last_gamma=self.last_gamma))

return nn.Sequential(*layers)

def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)

x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)

x = self.avgpool(x)
x = torch.flatten(x, 1)
if self.drop:
x = self.drop(x)
x = self.fc(x)

return x
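
The configuration that the animal-recognition pipeline later builds from this class (see animal_recog_pipeline.py below) corresponds to a ResNeSt-101-style backbone; a sketch with random weights:

import torch
from modelscope.models.cv.animal_recognition import resnet

net = resnet.ResNet(
    resnet.Bottleneck, [3, 4, 23, 3],
    radix=2, groups=1, bottleneck_width=64,
    deep_stem=True, stem_width=64, avg_down=True,
    avd=True, avd_first=False, num_classes=8288)
net.eval()
with torch.no_grad():
    logits = net(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 8288])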

modelscope/models/cv/animal_recognition/splat.py (+125, -0)

@@ -0,0 +1,125 @@
"""Split-Attention"""

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import BatchNorm2d, Conv2d, Linear, Module, ReLU
from torch.nn.modules.utils import _pair

__all__ = ['SplAtConv2d']


class SplAtConv2d(Module):
"""Split-Attention Conv2d
"""

def __init__(self,
in_channels,
channels,
kernel_size,
stride=(1, 1),
padding=(0, 0),
dilation=(1, 1),
groups=1,
bias=True,
radix=2,
reduction_factor=4,
rectify=False,
rectify_avg=False,
norm_layer=None,
dropblock_prob=0.0,
**kwargs):
super(SplAtConv2d, self).__init__()
padding = _pair(padding)
self.rectify = rectify and (padding[0] > 0 or padding[1] > 0)
self.rectify_avg = rectify_avg
inter_channels = max(in_channels * radix // reduction_factor, 32)
self.radix = radix
self.cardinality = groups
self.channels = channels
self.dropblock_prob = dropblock_prob
if self.rectify:
from rfconv import RFConv2d
self.conv = RFConv2d(
in_channels,
channels * radix,
kernel_size,
stride,
padding,
dilation,
groups=groups * radix,
bias=bias,
average_mode=rectify_avg,
**kwargs)
else:
self.conv = Conv2d(
in_channels,
channels * radix,
kernel_size,
stride,
padding,
dilation,
groups=groups * radix,
bias=bias,
**kwargs)
self.use_bn = norm_layer is not None
if self.use_bn:
self.bn0 = norm_layer(channels * radix)
self.relu = ReLU(inplace=True)
self.fc1 = Conv2d(channels, inter_channels, 1, groups=self.cardinality)
if self.use_bn:
self.bn1 = norm_layer(inter_channels)
self.fc2 = Conv2d(
inter_channels, channels * radix, 1, groups=self.cardinality)
if dropblock_prob > 0.0:
self.dropblock = DropBlock2D(dropblock_prob, 3)
self.rsoftmax = rSoftMax(radix, groups)

def forward(self, x):
x = self.conv(x)
if self.use_bn:
x = self.bn0(x)
if self.dropblock_prob > 0.0:
x = self.dropblock(x)
x = self.relu(x)

batch, rchannel = x.shape[:2]
if self.radix > 1:
splited = torch.split(x, rchannel // self.radix, dim=1)
gap = sum(splited)
else:
gap = x
gap = F.adaptive_avg_pool2d(gap, 1)
gap = self.fc1(gap)

if self.use_bn:
gap = self.bn1(gap)
gap = self.relu(gap)

atten = self.fc2(gap)
atten = self.rsoftmax(atten).view(batch, -1, 1, 1)

if self.radix > 1:
attens = torch.split(atten, rchannel // self.radix, dim=1)
out = sum([att * split for (att, split) in zip(attens, splited)])
else:
out = atten * x
return out.contiguous()


class rSoftMax(nn.Module):

def __init__(self, radix, cardinality):
super().__init__()
self.radix = radix
self.cardinality = cardinality

def forward(self, x):
batch = x.size(0)
if self.radix > 1:
x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
x = F.softmax(x, dim=1)
x = x.reshape(batch, -1)
else:
x = torch.sigmoid(x)
return x
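
A minimal sketch of a single split-attention convolution as used by the Bottleneck above (radix=2, cardinality=1; the sizes are arbitrary):

import torch
from torch import nn
from modelscope.models.cv.animal_recognition.splat import SplAtConv2d

conv = SplAtConv2d(64, 64, kernel_size=3, stride=1, padding=1,
                   radix=2, norm_layer=nn.BatchNorm2d, bias=False)
x = torch.randn(2, 64, 56, 56)
print(conv(x).shape)  # torch.Size([2, 64, 56, 56])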

modelscope/models/nlp/masked_language_model.py (+9, -1)

@@ -7,7 +7,7 @@ from ...utils.constant import Tasks
 from ..base import Model, Tensor
 from ..builder import MODELS

-__all__ = ['StructBertForMaskedLM', 'VecoForMaskedLM']
+__all__ = ['BertForMaskedLM', 'StructBertForMaskedLM', 'VecoForMaskedLM']


 class MaskedLanguageModelBase(Model):
@@ -61,3 +61,11 @@ class VecoForMaskedLM(MaskedLanguageModelBase):
     def build_model(self):
         from sofa import VecoForMaskedLM
         return VecoForMaskedLM.from_pretrained(self.model_dir)
+
+
+@MODELS.register_module(Tasks.fill_mask, module_name=Models.bert)
+class BertForMaskedLM(MaskedLanguageModelBase):
+
+    def build_model(self):
+        from transformers import BertForMaskedLM
+        return BertForMaskedLM.from_pretrained(self.model_dir)
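
With this registration, a hub model whose configuration resolves to model type 'bert' for the fill-mask task is wrapped by the new class. A usage sketch (the model id below is a placeholder, not part of this commit):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# '<fill-mask-model-id-or-local-dir>' is a placeholder for a BERT fill-mask model.
fill_mask = pipeline(Tasks.fill_mask, model='<fill-mask-model-id-or-local-dir>')
print(fill_mask('Everything is [MASK] with the new release.'))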

modelscope/pipelines/__init__.py (+1, -0)

@@ -1,4 +1,5 @@
 from .audio import LinearAECPipeline
+from .audio.ans_pipeline import ANSPipeline
 from .base import Pipeline
 from .builder import pipeline
 from .cv import *  # noqa F403


modelscope/pipelines/audio/ans_pipeline.py (+117, -0)

@@ -0,0 +1,117 @@
import os.path
from typing import Any, Dict

import librosa
import numpy as np
import soundfile as sf
import torch

from modelscope.metainfo import Pipelines
from modelscope.utils.constant import Tasks
from ..base import Input, Pipeline
from ..builder import PIPELINES


def audio_norm(x):
rms = (x**2).mean()**0.5
scalar = 10**(-25 / 20) / rms
x = x * scalar
pow_x = x**2
avg_pow_x = pow_x.mean()
rmsx = pow_x[pow_x > avg_pow_x].mean()**0.5
scalarx = 10**(-25 / 20) / rmsx
x = x * scalarx
return x


@PIPELINES.register_module(
Tasks.speech_signal_process,
module_name=Pipelines.speech_frcrn_ans_cirm_16k)
class ANSPipeline(Pipeline):
r"""ANS (Acoustic Noise Suppression) Inference Pipeline .

When invoke the class with pipeline.__call__(), it accept only one parameter:
inputs(str): the path of wav file
"""
SAMPLE_RATE = 16000

def __init__(self, model):
r"""
Args:
model: model id on modelscope hub.
"""
super().__init__(model=model)
self.device = torch.device(
'cuda' if torch.cuda.is_available() else 'cpu')
self.model = self.model.to(self.device)
self.model.eval()

def preprocess(self, inputs: Input) -> Dict[str, Any]:
assert isinstance(inputs, str) and os.path.exists(inputs) and os.path.isfile(inputs), \
f'Input file does not exist: {inputs}'
data1, fs = sf.read(inputs)
data1 = audio_norm(data1)
if fs != self.SAMPLE_RATE:
data1 = librosa.resample(data1, fs, self.SAMPLE_RATE)
if len(data1.shape) > 1:
data1 = data1[:, 0]
data = data1.astype(np.float32)
inputs = np.reshape(data, [1, data.shape[0]])
return {'ndarray': inputs, 'nsamples': data.shape[0]}

def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
ndarray = inputs['ndarray']
nsamples = inputs['nsamples']
decode_do_segement = False
window = 16000
stride = int(window * 0.75)
print('inputs:{}'.format(ndarray.shape))
b, t = ndarray.shape # size()
if t > window * 120:
decode_do_segement = True

if t < window:
ndarray = np.concatenate(
[ndarray, np.zeros((ndarray.shape[0], window - t))], 1)
elif t < window + stride:
padding = window + stride - t
print('padding: {}'.format(padding))
ndarray = np.concatenate(
[ndarray, np.zeros((ndarray.shape[0], padding))], 1)
else:
if (t - window) % stride != 0:
padding = t - (t - window) // stride * stride
print('padding: {}'.format(padding))
ndarray = np.concatenate(
[ndarray, np.zeros((ndarray.shape[0], padding))], 1)
print('inputs after padding:{}'.format(ndarray.shape))
with torch.no_grad():
ndarray = torch.from_numpy(np.float32(ndarray)).to(self.device)
b, t = ndarray.shape
if decode_do_segement:
outputs = np.zeros(t)
give_up_length = (window - stride) // 2
current_idx = 0
while current_idx + window <= t:
print('current_idx: {}'.format(current_idx))
tmp_input = ndarray[:, current_idx:current_idx + window]
tmp_output = self.model(
tmp_input, )['wav_l2'][0].cpu().numpy()
end_index = current_idx + window - give_up_length
if current_idx == 0:
outputs[current_idx:
end_index] = tmp_output[:-give_up_length]
else:
outputs[current_idx
+ give_up_length:end_index] = tmp_output[
give_up_length:-give_up_length]
current_idx += stride
else:
outputs = self.model(ndarray)['wav_l2'][0].cpu().numpy()
return {'output_pcm': outputs[:nsamples]}

def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
if 'output_path' in kwargs.keys():
sf.write(kwargs['output_path'], inputs['output_pcm'],
self.SAMPLE_RATE)
return inputs
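
Putting the registration together with the postprocess hook above, a usage sketch (the model id is a placeholder; output_path is assumed to be forwarded to postprocess by the base Pipeline, as with the other audio pipelines):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

ans = pipeline(
    Tasks.speech_signal_process,
    model='<ans-model-id-or-local-dir>')  # placeholder
result = ans('speech_with_noise.wav', output_path='speech_denoised.wav')
print(result['output_pcm'].shape)  # denoised samples, 16 kHz mono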

modelscope/pipelines/cv/__init__.py (+1, -0)

@@ -1,4 +1,5 @@
 from .action_recognition_pipeline import ActionRecognitionPipeline
+from .animal_recog_pipeline import AnimalRecogPipeline
 from .image_cartoon_pipeline import ImageCartoonPipeline
 from .image_matting_pipeline import ImageMattingPipeline
 from .ocr_detection_pipeline import OCRDetectionPipeline

modelscope/pipelines/cv/animal_recog_pipeline.py (+127, -0)

@@ -0,0 +1,127 @@
import os.path as osp
import tempfile
from typing import Any, Dict

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from modelscope.fileio import File
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Pipelines
from modelscope.models.cv.animal_recognition import resnet
from modelscope.pipelines.base import Input
from modelscope.preprocessors import load_image
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from ..base import Pipeline
from ..builder import PIPELINES

logger = get_logger()


@PIPELINES.register_module(
Tasks.image_classification, module_name=Pipelines.animal_recognation)
class AnimalRecogPipeline(Pipeline):

def __init__(self, model: str):
super().__init__(model=model)
import torch

def resnest101(**kwargs):
model = resnet.ResNet(
resnet.Bottleneck, [3, 4, 23, 3],
radix=2,
groups=1,
bottleneck_width=64,
deep_stem=True,
stem_width=64,
avg_down=True,
avd=True,
avd_first=False,
**kwargs)
return model

def filter_param(src_params, own_state):
copied_keys = []
for name, param in src_params.items():
if 'module.' == name[0:7]:
name = name[7:]
if '.module.' not in list(own_state.keys())[0]:
name = name.replace('.module.', '.')
if (name in own_state) and (own_state[name].shape
== param.shape):
own_state[name].copy_(param)
copied_keys.append(name)

def load_pretrained(model, src_params):
if 'state_dict' in src_params:
src_params = src_params['state_dict']
own_state = model.state_dict()
filter_param(src_params, own_state)
model.load_state_dict(own_state)

self.model = resnest101(num_classes=8288)
local_model_dir = model
if osp.exists(model):
local_model_dir = model
else:
local_model_dir = snapshot_download(model)
self.local_path = local_model_dir
src_params = torch.load(
osp.join(local_model_dir, 'pytorch_model.pt'), 'cpu')
load_pretrained(self.model, src_params)
logger.info('load model done')

def preprocess(self, input: Input) -> Dict[str, Any]:
if isinstance(input, str):
img = load_image(input)
elif isinstance(input, Image.Image):
img = input.convert('RGB')
elif isinstance(input, np.ndarray):
if len(input.shape) == 2:
input = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
img = input[:, :, ::-1]
img = Image.fromarray(img.astype('uint8')).convert('RGB')
else:
raise TypeError(f'input should be either str, PIL.Image,'
f' np.array, but got {type(input)}')

normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
test_transforms = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(), normalize
])
img = test_transforms(img)
result = {'img': img}
return result

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:

def set_phase(model, is_train):
if is_train:
model.train()
else:
model.eval()

is_train = False
set_phase(self.model, is_train)
img = input['img']
input_img = torch.unsqueeze(img, 0)
outputs = self.model(input_img)
return {'outputs': outputs}

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
label_mapping_path = osp.join(self.local_path, 'label_mapping.txt')
with open(label_mapping_path, 'r') as f:
label_mapping = f.readlines()
score = torch.max(inputs['outputs'])
inputs = {
'scores': score.item(),
'labels': label_mapping[inputs['outputs'].argmax()].split('\t')[1]
}
return inputs
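
A usage sketch for the pipeline above (the model id is a placeholder; the directory only needs the pytorch_model.pt and label_mapping.txt files read in __init__ and postprocess):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

animal_recog = pipeline(
    Tasks.image_classification,
    model='<animal-recognition-model-id-or-local-dir>')  # placeholder
print(animal_recog('cat.jpg'))  # {'scores': ..., 'labels': ...}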

modelscope/pipelines/cv/ocr_detection_pipeline.py (+52, -42)

@@ -8,7 +8,6 @@ import cv2
 import numpy as np
 import PIL
 import tensorflow as tf
-import tf_slim as slim

 from modelscope.metainfo import Pipelines
 from modelscope.pipelines.base import Input
@@ -19,6 +18,11 @@ from ..base import Pipeline
 from ..builder import PIPELINES
 from .ocr_utils import model_resnet_mutex_v4_linewithchar, ops, utils

+if tf.__version__ >= '2.0':
+    import tf_slim as slim
+else:
+    from tensorflow.contrib import slim
+
 if tf.__version__ >= '2.0':
     tf = tf.compat.v1
     tf.compat.v1.disable_eager_execution()
@@ -44,6 +48,7 @@

     def __init__(self, model: str):
         super().__init__(model=model)
+        tf.reset_default_graph()
         model_path = osp.join(
             osp.join(self.model, ModelFile.TF_CHECKPOINT_FOLDER),
             'checkpoint-80000')
@@ -51,51 +56,56 @@ class OCRDetectionPipeline(Pipeline):
         config = tf.ConfigProto(allow_soft_placement=True)
         config.gpu_options.allow_growth = True
         self._session = tf.Session(config=config)
-        global_step = tf.get_variable(
-            'global_step', [],
-            initializer=tf.constant_initializer(0),
-            dtype=tf.int64,
-            trainable=False)
-        variable_averages = tf.train.ExponentialMovingAverage(
-            0.997, global_step)
         self.input_images = tf.placeholder(
             tf.float32, shape=[1, 1024, 1024, 3], name='input_images')
         self.output = {}

-        # detector
-        detector = model_resnet_mutex_v4_linewithchar.SegLinkDetector()
-        all_maps = detector.build_model(self.input_images, is_training=False)
-
-        # decode local predictions
-        all_nodes, all_links, all_reg = [], [], []
-        for i, maps in enumerate(all_maps):
-            cls_maps, lnk_maps, reg_maps = maps[0], maps[1], maps[2]
-            reg_maps = tf.multiply(reg_maps, OFFSET_VARIANCE)
-
-            cls_prob = tf.nn.softmax(tf.reshape(cls_maps, [-1, 2]))
-
-            lnk_prob_pos = tf.nn.softmax(tf.reshape(lnk_maps, [-1, 4])[:, :2])
-            lnk_prob_mut = tf.nn.softmax(tf.reshape(lnk_maps, [-1, 4])[:, 2:])
-            lnk_prob = tf.concat([lnk_prob_pos, lnk_prob_mut], axis=1)
-
-            all_nodes.append(cls_prob)
-            all_links.append(lnk_prob)
-            all_reg.append(reg_maps)
-
-        # decode segments and links
-        image_size = tf.shape(self.input_images)[1:3]
-        segments, group_indices, segment_counts, _ = ops.decode_segments_links_python(
-            image_size,
-            all_nodes,
-            all_links,
-            all_reg,
-            anchor_sizes=list(detector.anchor_sizes))
-
-        # combine segments
-        combined_rboxes, combined_counts = ops.combine_segments_python(
-            segments, group_indices, segment_counts)
-        self.output['combined_rboxes'] = combined_rboxes
-        self.output['combined_counts'] = combined_counts
+        with tf.variable_scope('', reuse=tf.AUTO_REUSE):
+            global_step = tf.get_variable(
+                'global_step', [],
+                initializer=tf.constant_initializer(0),
+                dtype=tf.int64,
+                trainable=False)
+            variable_averages = tf.train.ExponentialMovingAverage(
+                0.997, global_step)
+
+            # detector
+            detector = model_resnet_mutex_v4_linewithchar.SegLinkDetector()
+            all_maps = detector.build_model(
+                self.input_images, is_training=False)
+
+            # decode local predictions
+            all_nodes, all_links, all_reg = [], [], []
+            for i, maps in enumerate(all_maps):
+                cls_maps, lnk_maps, reg_maps = maps[0], maps[1], maps[2]
+                reg_maps = tf.multiply(reg_maps, OFFSET_VARIANCE)
+
+                cls_prob = tf.nn.softmax(tf.reshape(cls_maps, [-1, 2]))
+
+                lnk_prob_pos = tf.nn.softmax(
+                    tf.reshape(lnk_maps, [-1, 4])[:, :2])
+                lnk_prob_mut = tf.nn.softmax(
+                    tf.reshape(lnk_maps, [-1, 4])[:, 2:])
+                lnk_prob = tf.concat([lnk_prob_pos, lnk_prob_mut], axis=1)
+
+                all_nodes.append(cls_prob)
+                all_links.append(lnk_prob)
+                all_reg.append(reg_maps)
+
+            # decode segments and links
+            image_size = tf.shape(self.input_images)[1:3]
+            segments, group_indices, segment_counts, _ = ops.decode_segments_links_python(
+                image_size,
+                all_nodes,
+                all_links,
+                all_reg,
+                anchor_sizes=list(detector.anchor_sizes))
+
+            # combine segments
+            combined_rboxes, combined_counts = ops.combine_segments_python(
+                segments, group_indices, segment_counts)
+            self.output['combined_rboxes'] = combined_rboxes
+            self.output['combined_counts'] = combined_counts

         with self._session.as_default() as sess:
             logger.info(f'loading model from {model_path}')


modelscope/pipelines/cv/ocr_utils/model_resnet_mutex_v4_linewithchar.py (+5, -1)

@@ -1,8 +1,12 @@
 import tensorflow as tf
-import tf_slim as slim

 from . import ops, resnet18_v1, resnet_utils

+if tf.__version__ >= '2.0':
+    import tf_slim as slim
+else:
+    from tensorflow.contrib import slim
+
 if tf.__version__ >= '2.0':
     tf = tf.compat.v1




modelscope/pipelines/cv/ocr_utils/resnet18_v1.py (+5, -1)

@@ -30,10 +30,14 @@ ResNet-101 for semantic segmentation into 21 classes:
     output_stride=16)
 """
 import tensorflow as tf
-import tf_slim as slim

 from . import resnet_utils

+if tf.__version__ >= '2.0':
+    import tf_slim as slim
+else:
+    from tensorflow.contrib import slim
+
 if tf.__version__ >= '2.0':
     tf = tf.compat.v1




modelscope/pipelines/cv/ocr_utils/resnet_utils.py (+5, -1)

@@ -19,7 +19,11 @@ implementation is more memory efficient.
 import collections

 import tensorflow as tf
-import tf_slim as slim

+if tf.__version__ >= '2.0':
+    import tf_slim as slim
+else:
+    from tensorflow.contrib import slim
+
 if tf.__version__ >= '2.0':
     tf = tf.compat.v1


+ 14
- 10
modelscope/pipelines/nlp/fill_mask_pipeline.py View File

@@ -1,3 +1,4 @@
import os
from typing import Any, Dict, Optional, Union


import torch
@@ -6,11 +7,13 @@ from ...metainfo import Pipelines
from ...models import Model
from ...models.nlp.masked_language_model import MaskedLanguageModelBase
from ...preprocessors import FillMaskPreprocessor
from ...utils.constant import Tasks
from ...utils.config import Config
from ...utils.constant import ModelFile, Tasks
from ..base import Pipeline, Tensor
from ..builder import PIPELINES


__all__ = ['FillMaskPipeline']
_type_map = {'veco': 'roberta', 'sbert': 'bert'}




@PIPELINES.register_module(Tasks.fill_mask, module_name=Pipelines.fill_mask)
@@ -29,7 +32,6 @@ class FillMaskPipeline(Pipeline):
""" """
fill_mask_model = model if isinstance( fill_mask_model = model if isinstance(
model, MaskedLanguageModelBase) else Model.from_pretrained(model) model, MaskedLanguageModelBase) else Model.from_pretrained(model)
assert fill_mask_model.config is not None


if preprocessor is None:
preprocessor = FillMaskPreprocessor(
@@ -41,11 +43,13 @@ class FillMaskPipeline(Pipeline):
model=fill_mask_model, preprocessor=preprocessor, **kwargs)


self.preprocessor = preprocessor
self.config = Config.from_file(
os.path.join(fill_mask_model.model_dir, ModelFile.CONFIGURATION))
self.tokenizer = preprocessor.tokenizer
self.mask_id = {'veco': 250001, 'sbert': 103}
self.mask_id = {'roberta': 250001, 'bert': 103}


self.rep_map = {
'sbert': {
'bert': {
'[unused0]': '',
'[PAD]': '',
'[unused1]': '',
@@ -55,7 +59,7 @@ class FillMaskPipeline(Pipeline):
'[CLS]': '',
'[UNK]': ''
},
'veco': {
'roberta': {
r' +': ' ',
'<mask>': '<q>',
'<pad>': '',
@@ -84,7 +88,9 @@ class FillMaskPipeline(Pipeline):
input_ids = inputs['input_ids'].detach().numpy()
pred_ids = np.argmax(logits, axis=-1)
model_type = self.model.config.model_type
rst_ids = np.where(input_ids == self.mask_id[model_type], pred_ids,
process_type = model_type if model_type in self.mask_id else _type_map[
model_type]
rst_ids = np.where(input_ids == self.mask_id[process_type], pred_ids,
input_ids)


def rep_tokens(string, rep_map):
@@ -94,14 +100,12 @@ class FillMaskPipeline(Pipeline):


pred_strings = []
for ids in rst_ids: # batch
# TODO vocab size is not stable

if self.model.config.vocab_size == 21128: # zh bert
if 'language' in self.config.model and self.config.model.language == 'zh':
pred_string = self.tokenizer.convert_ids_to_tokens(ids)
pred_string = ''.join(pred_string)
else:
pred_string = self.tokenizer.decode(ids)
pred_string = rep_tokens(pred_string, self.rep_map[model_type])
pred_string = rep_tokens(pred_string, self.rep_map[process_type])
pred_strings.append(pred_string)


return {'text': pred_strings}
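The mask-id lookup and the replacement tables are now keyed by base architecture ('roberta', 'bert') rather than by model name, with _type_map translating 'veco' and 'sbert' to their base types. A standalone sketch of that resolution logic, purely illustrative and independent of the pipeline class:

_type_map = {'veco': 'roberta', 'sbert': 'bert'}
mask_id = {'roberta': 250001, 'bert': 103}

def resolve_process_type(model_type: str) -> str:
    # Use the model type directly when it already names a base architecture,
    # otherwise fall back to the mapped entry (e.g. 'veco' -> 'roberta').
    return model_type if model_type in mask_id else _type_map[model_type]

print(resolve_process_type('veco'), mask_id[resolve_process_type('veco')])    # roberta 250001
print(resolve_process_type('sbert'), mask_id[resolve_process_type('sbert')])  # bert 103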

+ 7
- 4
modelscope/preprocessors/nlp.py View File

@@ -326,14 +326,17 @@ class FillMaskPreprocessor(Preprocessor):
model_dir (str): model path
"""
super().__init__(*args, **kwargs)
from sofa.utils.backend import AutoTokenizer
self.model_dir = model_dir
self.first_sequence: str = kwargs.pop('first_sequence',
'first_sequence')
self.sequence_length = kwargs.pop('sequence_length', 128)

self.tokenizer = AutoTokenizer.from_pretrained(
model_dir, use_fast=False)
try:
from transformers import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
except KeyError:
from sofa.utils.backend import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(
model_dir, use_fast=False)


@type_assert(object, str)
def __call__(self, data: str) -> Dict[str, Any]:


+ 1
- 0
requirements/audio.txt View File

@@ -16,6 +16,7 @@ protobuf>3,<=3.20
ptflops
PyWavelets>=1.0.0
scikit-learn
SoundFile>0.10
sox
tensorboard
tensorflow==1.15.*
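SoundFile is added alongside the new ANS pipeline, presumably for reading the noisy 16 kHz input wav and writing the denoised result. A minimal sketch of the two library calls involved, with placeholder processing standing in for the actual model:

import numpy as np
import soundfile as sf

# Read the noisy speech as float samples plus the file's sample rate.
samples, sample_rate = sf.read('speech_with_noise.wav')

denoised = samples * 0.5  # placeholder; the real pipeline runs the FRCRN model here

# Write the processed audio back out at the same sample rate.
sf.write('output.wav', denoised.astype(np.float32), sample_rate)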


+ 20
- 0
tests/pipelines/test_animal_recognation.py View File

@@ -0,0 +1,20 @@
import unittest

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class MultiModalFeatureTest(unittest.TestCase):

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run(self):
animal_recog = pipeline(
Tasks.image_classification,
model='damo/cv_resnest101_animal_recognation')
result = animal_recog('data/test/images/image1.jpg')
print(result)


if __name__ == '__main__':
unittest.main()

+ 35
- 1
tests/pipelines/test_fill_mask.py View File

@@ -3,7 +3,8 @@ import unittest


from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import StructBertForMaskedLM, VecoForMaskedLM
from modelscope.models.nlp import (BertForMaskedLM, StructBertForMaskedLM,
VecoForMaskedLM)
from modelscope.pipelines import FillMaskPipeline, pipeline
from modelscope.preprocessors import FillMaskPreprocessor
from modelscope.utils.constant import Tasks
@@ -16,6 +17,7 @@ class FillMaskTest(unittest.TestCase):
'en': 'damo/nlp_structbert_fill-mask_english-large'
}
model_id_veco = 'damo/nlp_veco_fill-mask-large'
model_id_bert = 'damo/nlp_bert_fill-mask_chinese-base'


ori_texts = {
'zh':
@@ -69,6 +71,20 @@ class FillMaskTest(unittest.TestCase):
f'{pipeline1(test_input)}\npipeline2: {pipeline2(test_input)}\n'
)


# zh bert
language = 'zh'
model_dir = snapshot_download(self.model_id_bert)
preprocessor = FillMaskPreprocessor(
model_dir, first_sequence='sentence', second_sequence=None)
model = BertForMaskedLM(model_dir)
pipeline1 = FillMaskPipeline(model, preprocessor)
pipeline2 = pipeline(
Tasks.fill_mask, model=model, preprocessor=preprocessor)
ori_text = self.ori_texts[language]
test_input = self.test_inputs[language]
print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline1: '
f'{pipeline1(test_input)}\npipeline2: {pipeline2(test_input)}\n')

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
# sbert
@@ -97,6 +113,18 @@ class FillMaskTest(unittest.TestCase):
print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: '
f'{pipeline_ins(test_input)}\n')


# zh bert
model = Model.from_pretrained(self.model_id_bert)
preprocessor = FillMaskPreprocessor(
model.model_dir, first_sequence='sentence', second_sequence=None)
pipeline_ins = pipeline(
Tasks.fill_mask, model=model, preprocessor=preprocessor)
language = 'zh'
ori_text = self.ori_texts[language]
test_input = self.test_inputs[language]
print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: '
f'{pipeline_ins(test_input)}\n')

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_name(self):
# veco
@@ -115,6 +143,12 @@ class FillMaskTest(unittest.TestCase):
f'\nori_text: {self.ori_texts[language]}\ninput: {self.test_inputs[language]}\npipeline: '
f'{pipeline_ins(self.test_inputs[language])}\n')


# bert
pipeline_ins = pipeline(task=Tasks.fill_mask, model=self.model_id_bert)
print(
f'\nori_text: {self.ori_texts[language]}\ninput: {self.test_inputs[language]}\npipeline: '
f'{pipeline_ins(self.test_inputs[language])}\n')

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self):
pipeline_ins = pipeline(task=Tasks.fill_mask)


+ 5
- 0
tests/pipelines/test_ocr_detection.py View File

@@ -27,6 +27,11 @@ class OCRDetectionTest(unittest.TestCase):
print('ocr detection results: ')
print(result)


@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
ocr_detection = pipeline(Tasks.ocr_detection, model=self.model_id)
self.pipeline_inference(ocr_detection, self.test_image)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_modelhub_default_model(self):
ocr_detection = pipeline(Tasks.ocr_detection)


+ 26
- 6
tests/pipelines/test_speech_signal_process.py View File

@@ -17,6 +17,9 @@ AEC_LIB_URL = 'http://isv-data.oss-cn-hangzhou.aliyuncs.com/ics%2FMaaS%2FAEC%2Fl
'?Expires=1664085465&OSSAccessKeyId=LTAIxjQyZNde90zh&Signature=Y7gelmGEsQAJRK4yyHSYMrdWizk%3D'
AEC_LIB_FILE = 'libmitaec_pyio.so'


NOISE_SPEECH_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ANS/sample_audio/speech_with_noise.wav'
NOISE_SPEECH_FILE = 'speech_with_noise.wav'



def download(remote_path, local_path):
local_dir = os.path.dirname(local_path)
@@ -30,23 +33,40 @@ def download(remote_path, local_path):
class SpeechSignalProcessTest(unittest.TestCase):


def setUp(self) -> None:
self.model_id = 'damo/speech_dfsmn_aec_psm_16k'
pass

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_aec(self):
# A temporary hack to provide c++ lib. Download it first.
download(AEC_LIB_URL, AEC_LIB_FILE)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run(self):
# Download audio files
download(NEAREND_MIC_URL, NEAREND_MIC_FILE)
download(FAREND_SPEECH_URL, FAREND_SPEECH_FILE)
model_id = 'damo/speech_dfsmn_aec_psm_16k'
input = {
'nearend_mic': NEAREND_MIC_FILE,
'farend_speech': FAREND_SPEECH_FILE
}
aec = pipeline(
Tasks.speech_signal_process,
model=self.model_id,
model=model_id,
pipeline_name=Pipelines.speech_dfsmn_aec_psm_16k)
aec(input, output_path='output.wav')
output_path = os.path.abspath('output.wav')
aec(input, output_path=output_path)
print(f'Processed audio saved to {output_path}')

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_ans(self):
# Download audio files
download(NOISE_SPEECH_URL, NOISE_SPEECH_FILE)
model_id = 'damo/speech_frcrn_ans_cirm_16k'
ans = pipeline(
Tasks.speech_signal_process,
model=model_id,
pipeline_name=Pipelines.speech_frcrn_ans_cirm_16k)
output_path = os.path.abspath('output.wav')
ans(NOISE_SPEECH_FILE, output_path=output_path)
print(f'Processed audio saved to {output_path}')




if __name__ == '__main__':

