diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index e2d43a02..a963cdb2 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -22,6 +22,7 @@ class Models(object): sambert_hifi_16k = 'sambert-hifi-16k' generic_tts_frontend = 'generic-tts-frontend' hifigan16k = 'hifigan16k' + speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' kws_kwsbp = 'kws-kwsbp' # multi-modal models @@ -44,6 +45,7 @@ class Pipelines(object): person_image_cartoon = 'unet-person-image-cartoon' ocr_detection = 'resnet18-ocr-detection' action_recognition = 'TAdaConv_action-recognition' + animal_recognation = 'resnet101-animal_recog' # nlp tasks sentence_similarity = 'sentence-similarity' @@ -59,6 +61,7 @@ class Pipelines(object): # audio tasks sambert_hifigan_16k_tts = 'sambert-hifigan-16k-tts' speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k' + speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' kws_kwsbp = 'kws-kwsbp' # multi-modal tasks diff --git a/modelscope/models/__init__.py b/modelscope/models/__init__.py index 7e2535d3..e0331513 100644 --- a/modelscope/models/__init__.py +++ b/modelscope/models/__init__.py @@ -1,12 +1,13 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +from .audio.ans.frcrn import FRCRNModel from .audio.kws import GenericKeyWordSpotting from .audio.tts.am import SambertNetHifi16k from .audio.tts.vocoder import Hifigan16k from .base import Model from .builder import MODELS, build_model from .multi_modal import OfaForImageCaptioning -from .nlp import (BertForSequenceClassification, SbertForNLI, +from .nlp import (BertForMaskedLM, BertForSequenceClassification, SbertForNLI, SbertForSentenceSimilarity, SbertForSentimentClassification, SbertForTokenClassification, StructBertForMaskedLM, VecoForMaskedLM) diff --git a/modelscope/models/audio/layers/__init__.py b/modelscope/models/audio/aec/__init__.py similarity index 100% rename from modelscope/models/audio/layers/__init__.py rename to modelscope/models/audio/aec/__init__.py diff --git a/modelscope/models/audio/network/__init__.py b/modelscope/models/audio/aec/layers/__init__.py similarity index 100% rename from modelscope/models/audio/network/__init__.py rename to modelscope/models/audio/aec/layers/__init__.py diff --git a/modelscope/models/audio/layers/activations.py b/modelscope/models/audio/aec/layers/activations.py similarity index 100% rename from modelscope/models/audio/layers/activations.py rename to modelscope/models/audio/aec/layers/activations.py diff --git a/modelscope/models/audio/layers/affine_transform.py b/modelscope/models/audio/aec/layers/affine_transform.py similarity index 100% rename from modelscope/models/audio/layers/affine_transform.py rename to modelscope/models/audio/aec/layers/affine_transform.py diff --git a/modelscope/models/audio/layers/deep_fsmn.py b/modelscope/models/audio/aec/layers/deep_fsmn.py similarity index 100% rename from modelscope/models/audio/layers/deep_fsmn.py rename to modelscope/models/audio/aec/layers/deep_fsmn.py diff --git a/modelscope/models/audio/layers/layer_base.py b/modelscope/models/audio/aec/layers/layer_base.py similarity index 100% rename from modelscope/models/audio/layers/layer_base.py rename to modelscope/models/audio/aec/layers/layer_base.py diff --git a/modelscope/models/audio/layers/uni_deep_fsmn.py b/modelscope/models/audio/aec/layers/uni_deep_fsmn.py similarity index 100% rename from modelscope/models/audio/layers/uni_deep_fsmn.py rename to modelscope/models/audio/aec/layers/uni_deep_fsmn.py diff --git 
a/modelscope/models/audio/aec/network/__init__.py b/modelscope/models/audio/aec/network/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/audio/network/loss.py b/modelscope/models/audio/aec/network/loss.py similarity index 100% rename from modelscope/models/audio/network/loss.py rename to modelscope/models/audio/aec/network/loss.py diff --git a/modelscope/models/audio/network/modulation_loss.py b/modelscope/models/audio/aec/network/modulation_loss.py similarity index 100% rename from modelscope/models/audio/network/modulation_loss.py rename to modelscope/models/audio/aec/network/modulation_loss.py diff --git a/modelscope/models/audio/network/se_net.py b/modelscope/models/audio/aec/network/se_net.py similarity index 100% rename from modelscope/models/audio/network/se_net.py rename to modelscope/models/audio/aec/network/se_net.py diff --git a/modelscope/models/audio/ans/__init__.py b/modelscope/models/audio/ans/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/audio/ans/complex_nn.py b/modelscope/models/audio/ans/complex_nn.py new file mode 100644 index 00000000..69dec41e --- /dev/null +++ b/modelscope/models/audio/ans/complex_nn.py @@ -0,0 +1,248 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class UniDeepFsmn(nn.Module): + + def __init__(self, input_dim, output_dim, lorder=None, hidden_size=None): + super(UniDeepFsmn, self).__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + + if lorder is None: + return + + self.lorder = lorder + self.hidden_size = hidden_size + + self.linear = nn.Linear(input_dim, hidden_size) + + self.project = nn.Linear(hidden_size, output_dim, bias=False) + + self.conv1 = nn.Conv2d( + output_dim, + output_dim, [lorder, 1], [1, 1], + groups=output_dim, + bias=False) + + def forward(self, input): + r""" + + Args: + input: torch with shape: batch (b) x sequence(T) x feature (h) + + Returns: + batch (b) x channel (c) x sequence(T) x feature (h) + """ + f1 = F.relu(self.linear(input)) + + p1 = self.project(f1) + + x = torch.unsqueeze(p1, 1) + # x: batch (b) x channel (c) x sequence(T) x feature (h) + x_per = x.permute(0, 3, 2, 1) + # x_per: batch (b) x feature (h) x sequence(T) x channel (c) + y = F.pad(x_per, [0, 0, self.lorder - 1, 0]) + + out = x_per + self.conv1(y) + + out1 = out.permute(0, 3, 2, 1) + # out1: batch (b) x channel (c) x sequence(T) x feature (h) + return input + out1.squeeze() + + +class ComplexUniDeepFsmn(nn.Module): + + def __init__(self, nIn, nHidden=128, nOut=128): + super(ComplexUniDeepFsmn, self).__init__() + + self.fsmn_re_L1 = UniDeepFsmn(nIn, nHidden, 20, nHidden) + self.fsmn_im_L1 = UniDeepFsmn(nIn, nHidden, 20, nHidden) + self.fsmn_re_L2 = UniDeepFsmn(nHidden, nOut, 20, nHidden) + self.fsmn_im_L2 = UniDeepFsmn(nHidden, nOut, 20, nHidden) + + def forward(self, x): + r""" + + Args: + x: torch with shape [batch, channel, feature, sequence, 2], eg: [6, 256, 1, 106, 2] + + Returns: + [batch, feature, sequence, 2], eg: [6, 99, 1024, 2] + """ + # + b, c, h, T, d = x.size() + x = torch.reshape(x, (b, c * h, T, d)) + # x: [b,h,T,2], [6, 256, 106, 2] + x = torch.transpose(x, 1, 2) + # x: [b,T,h,2], [6, 106, 256, 2] + + real_L1 = self.fsmn_re_L1(x[..., 0]) - self.fsmn_im_L1(x[..., 1]) + imaginary_L1 = self.fsmn_re_L1(x[..., 1]) + self.fsmn_im_L1(x[..., 0]) + # GRU output: [99, 6, 128] + real = self.fsmn_re_L2(real_L1) - self.fsmn_im_L2(imaginary_L1) + imaginary = self.fsmn_re_L2(imaginary_L1) + self.fsmn_im_L2(real_L1) + # 
output: [b,T,h,2], [99, 6, 1024, 2] + output = torch.stack((real, imaginary), dim=-1) + + # output: [b,h,T,2], [6, 99, 1024, 2] + output = torch.transpose(output, 1, 2) + output = torch.reshape(output, (b, c, h, T, d)) + + return output + + +class ComplexUniDeepFsmn_L1(nn.Module): + + def __init__(self, nIn, nHidden=128, nOut=128): + super(ComplexUniDeepFsmn_L1, self).__init__() + self.fsmn_re_L1 = UniDeepFsmn(nIn, nHidden, 20, nHidden) + self.fsmn_im_L1 = UniDeepFsmn(nIn, nHidden, 20, nHidden) + + def forward(self, x): + r""" + + Args: + x: torch with shape [batch, channel, feature, sequence, 2], eg: [6, 256, 1, 106, 2] + """ + b, c, h, T, d = x.size() + # x : [b,T,h,c,2] + x = torch.transpose(x, 1, 3) + x = torch.reshape(x, (b * T, h, c, d)) + + real = self.fsmn_re_L1(x[..., 0]) - self.fsmn_im_L1(x[..., 1]) + imaginary = self.fsmn_re_L1(x[..., 1]) + self.fsmn_im_L1(x[..., 0]) + # output: [b*T,h,c,2], [6*106, h, 256, 2] + output = torch.stack((real, imaginary), dim=-1) + + output = torch.reshape(output, (b, T, h, c, d)) + output = torch.transpose(output, 1, 3) + return output + + +class ComplexConv2d(nn.Module): + # https://github.com/litcoderr/ComplexCNN/blob/master/complexcnn/modules.py + def __init__(self, + in_channel, + out_channel, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + **kwargs): + super().__init__() + + # Model components + self.conv_re = nn.Conv2d( + in_channel, + out_channel, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + **kwargs) + self.conv_im = nn.Conv2d( + in_channel, + out_channel, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + **kwargs) + + def forward(self, x): + r""" + + Args: + x: torch with shape: [batch,channel,axis1,axis2,2] + """ + real = self.conv_re(x[..., 0]) - self.conv_im(x[..., 1]) + imaginary = self.conv_re(x[..., 1]) + self.conv_im(x[..., 0]) + output = torch.stack((real, imaginary), dim=-1) + return output + + +class ComplexConvTranspose2d(nn.Module): + + def __init__(self, + in_channel, + out_channel, + kernel_size, + stride=1, + padding=0, + output_padding=0, + dilation=1, + groups=1, + bias=True, + **kwargs): + super().__init__() + + # Model components + self.tconv_re = nn.ConvTranspose2d( + in_channel, + out_channel, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + bias=bias, + dilation=dilation, + **kwargs) + self.tconv_im = nn.ConvTranspose2d( + in_channel, + out_channel, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + bias=bias, + dilation=dilation, + **kwargs) + + def forward(self, x): # shpae of x : [batch,channel,axis1,axis2,2] + real = self.tconv_re(x[..., 0]) - self.tconv_im(x[..., 1]) + imaginary = self.tconv_re(x[..., 1]) + self.tconv_im(x[..., 0]) + output = torch.stack((real, imaginary), dim=-1) + return output + + +class ComplexBatchNorm2d(nn.Module): + + def __init__(self, + num_features, + eps=1e-5, + momentum=0.1, + affine=True, + track_running_stats=True, + **kwargs): + super().__init__() + self.bn_re = nn.BatchNorm2d( + num_features=num_features, + momentum=momentum, + affine=affine, + eps=eps, + track_running_stats=track_running_stats, + **kwargs) + self.bn_im = nn.BatchNorm2d( + num_features=num_features, + momentum=momentum, + affine=affine, + eps=eps, + track_running_stats=track_running_stats, + **kwargs) + + def forward(self, x): + real = 
self.bn_re(x[..., 0]) + imag = self.bn_im(x[..., 1]) + output = torch.stack((real, imag), dim=-1) + return output diff --git a/modelscope/models/audio/ans/conv_stft.py b/modelscope/models/audio/ans/conv_stft.py new file mode 100644 index 00000000..a47d7817 --- /dev/null +++ b/modelscope/models/audio/ans/conv_stft.py @@ -0,0 +1,112 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from scipy.signal import get_window + + +def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False): + if win_type == 'None' or win_type is None: + window = np.ones(win_len) + else: + window = get_window(win_type, win_len, fftbins=True)**0.5 + + N = fft_len + fourier_basis = np.fft.rfft(np.eye(N))[:win_len] + real_kernel = np.real(fourier_basis) + imag_kernel = np.imag(fourier_basis) + kernel = np.concatenate([real_kernel, imag_kernel], 1).T + + if invers: + kernel = np.linalg.pinv(kernel).T + + kernel = kernel * window + kernel = kernel[:, None, :] + return torch.from_numpy(kernel.astype(np.float32)), torch.from_numpy( + window[None, :, None].astype(np.float32)) + + +class ConvSTFT(nn.Module): + + def __init__(self, + win_len, + win_inc, + fft_len=None, + win_type='hamming', + feature_type='real', + fix=True): + super(ConvSTFT, self).__init__() + + if fft_len is None: + self.fft_len = np.int(2**np.ceil(np.log2(win_len))) + else: + self.fft_len = fft_len + + kernel, _ = init_kernels(win_len, win_inc, self.fft_len, win_type) + self.weight = nn.Parameter(kernel, requires_grad=(not fix)) + self.feature_type = feature_type + self.stride = win_inc + self.win_len = win_len + self.dim = self.fft_len + + def forward(self, inputs): + if inputs.dim() == 2: + inputs = torch.unsqueeze(inputs, 1) + + outputs = F.conv1d(inputs, self.weight, stride=self.stride) + + if self.feature_type == 'complex': + return outputs + else: + dim = self.dim // 2 + 1 + real = outputs[:, :dim, :] + imag = outputs[:, dim:, :] + mags = torch.sqrt(real**2 + imag**2) + phase = torch.atan2(imag, real) + return mags, phase + + +class ConviSTFT(nn.Module): + + def __init__(self, + win_len, + win_inc, + fft_len=None, + win_type='hamming', + feature_type='real', + fix=True): + super(ConviSTFT, self).__init__() + if fft_len is None: + self.fft_len = np.int(2**np.ceil(np.log2(win_len))) + else: + self.fft_len = fft_len + kernel, window = init_kernels( + win_len, win_inc, self.fft_len, win_type, invers=True) + self.weight = nn.Parameter(kernel, requires_grad=(not fix)) + self.feature_type = feature_type + self.win_type = win_type + self.win_len = win_len + self.win_inc = win_inc + self.stride = win_inc + self.dim = self.fft_len + self.register_buffer('window', window) + self.register_buffer('enframe', torch.eye(win_len)[:, None, :]) + + def forward(self, inputs, phase=None): + """ + Args: + inputs : [B, N+2, T] (complex spec) or [B, N//2+1, T] (mags) + phase: [B, N//2+1, T] (if not none) + """ + + if phase is not None: + real = inputs * torch.cos(phase) + imag = inputs * torch.sin(phase) + inputs = torch.cat([real, imag], 1) + outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride) + + # this is from torch-stft: https://github.com/pseeth/torch-stft + t = self.window.repeat(1, 1, inputs.size(-1))**2 + coff = F.conv_transpose1d(t, self.enframe, stride=self.stride) + outputs = outputs / (coff + 1e-8) + return outputs diff --git a/modelscope/models/audio/ans/frcrn.py b/modelscope/models/audio/ans/frcrn.py new file mode 100644 index 00000000..c56b8773 --- /dev/null +++ 
b/modelscope/models/audio/ans/frcrn.py @@ -0,0 +1,309 @@ +import os +from typing import Dict + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modelscope.metainfo import Models +from modelscope.models.builder import MODELS +from modelscope.utils.constant import ModelFile, Tasks +from ...base import Model, Tensor +from .conv_stft import ConviSTFT, ConvSTFT +from .unet import UNet + + +class FTB(nn.Module): + + def __init__(self, input_dim=257, in_channel=9, r_channel=5): + + super(FTB, self).__init__() + self.in_channel = in_channel + self.conv1 = nn.Sequential( + nn.Conv2d(in_channel, r_channel, kernel_size=[1, 1]), + nn.BatchNorm2d(r_channel), nn.ReLU()) + + self.conv1d = nn.Sequential( + nn.Conv1d( + r_channel * input_dim, in_channel, kernel_size=9, padding=4), + nn.BatchNorm1d(in_channel), nn.ReLU()) + self.freq_fc = nn.Linear(input_dim, input_dim, bias=False) + + self.conv2 = nn.Sequential( + nn.Conv2d(in_channel * 2, in_channel, kernel_size=[1, 1]), + nn.BatchNorm2d(in_channel), nn.ReLU()) + + def forward(self, inputs): + ''' + inputs should be [Batch, Ca, Dim, Time] + ''' + # T-F attention + conv1_out = self.conv1(inputs) + B, C, D, T = conv1_out.size() + reshape1_out = torch.reshape(conv1_out, [B, C * D, T]) + conv1d_out = self.conv1d(reshape1_out) + conv1d_out = torch.reshape(conv1d_out, [B, self.in_channel, 1, T]) + + # now is also [B,C,D,T] + att_out = conv1d_out * inputs + + # tranpose to [B,C,T,D] + att_out = torch.transpose(att_out, 2, 3) + freqfc_out = self.freq_fc(att_out) + att_out = torch.transpose(freqfc_out, 2, 3) + + cat_out = torch.cat([att_out, inputs], 1) + outputs = self.conv2(cat_out) + return outputs + + +@MODELS.register_module( + Tasks.speech_signal_process, module_name=Models.speech_frcrn_ans_cirm_16k) +class FRCRNModel(Model): + r""" A decorator of FRCRN for integrating into modelscope framework """ + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the frcrn model from the `model_dir` path. + + Args: + model_dir (str): the model path. + """ + super().__init__(model_dir, *args, **kwargs) + self._model = FRCRN(*args, **kwargs) + model_bin_file = os.path.join(model_dir, + ModelFile.TORCH_MODEL_BIN_FILE) + if os.path.exists(model_bin_file): + checkpoint = torch.load(model_bin_file) + self._model.load_state_dict(checkpoint, strict=False) + + def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: + output = self._model.forward(input) + return { + 'spec_l1': output[0], + 'wav_l1': output[1], + 'mask_l1': output[2], + 'spec_l2': output[3], + 'wav_l2': output[4], + 'mask_l2': output[5] + } + + def to(self, *args, **kwargs): + self._model = self._model.to(*args, **kwargs) + return self + + def eval(self): + self._model = self._model.train(False) + return self + + +class FRCRN(nn.Module): + r""" Frequency Recurrent CRN """ + + def __init__(self, + complex, + model_complexity, + model_depth, + log_amp, + padding_mode, + win_len=400, + win_inc=100, + fft_len=512, + win_type='hanning'): + r""" + Args: + complex: Whether to use complex networks. + model_complexity: define the model complexity with the number of layers + model_depth: Only two options are available : 10, 20 + log_amp: Whether to use log amplitude to estimate signals + padding_mode: Encoder's convolution filter. 
'zeros', 'reflect' + win_len: length of window used for defining one frame of sample points + win_inc: length of window shifting (equivalent to hop_size) + fft_len: number of Short Time Fourier Transform (STFT) points + win_type: windowing type used in STFT, eg. 'hanning', 'hamming' + """ + super().__init__() + self.feat_dim = fft_len // 2 + 1 + + self.win_len = win_len + self.win_inc = win_inc + self.fft_len = fft_len + self.win_type = win_type + + fix = True + self.stft = ConvSTFT( + self.win_len, + self.win_inc, + self.fft_len, + self.win_type, + feature_type='complex', + fix=fix) + self.istft = ConviSTFT( + self.win_len, + self.win_inc, + self.fft_len, + self.win_type, + feature_type='complex', + fix=fix) + self.unet = UNet( + 1, + complex=complex, + model_complexity=model_complexity, + model_depth=model_depth, + padding_mode=padding_mode) + self.unet2 = UNet( + 1, + complex=complex, + model_complexity=model_complexity, + model_depth=model_depth, + padding_mode=padding_mode) + + def forward(self, inputs): + out_list = [] + # [B, D*2, T] + cmp_spec = self.stft(inputs) + # [B, 1, D*2, T] + cmp_spec = torch.unsqueeze(cmp_spec, 1) + + # to [B, 2, D, T] real_part/imag_part + cmp_spec = torch.cat([ + cmp_spec[:, :, :self.feat_dim, :], + cmp_spec[:, :, self.feat_dim:, :], + ], 1) + + # [B, 2, D, T] + cmp_spec = torch.unsqueeze(cmp_spec, 4) + # [B, 1, D, T, 2] + cmp_spec = torch.transpose(cmp_spec, 1, 4) + unet1_out = self.unet(cmp_spec) + cmp_mask1 = torch.tanh(unet1_out) + unet2_out = self.unet2(unet1_out) + cmp_mask2 = torch.tanh(unet2_out) + est_spec, est_wav, est_mask = self.apply_mask(cmp_spec, cmp_mask1) + out_list.append(est_spec) + out_list.append(est_wav) + out_list.append(est_mask) + cmp_mask2 = cmp_mask2 + cmp_mask1 + est_spec, est_wav, est_mask = self.apply_mask(cmp_spec, cmp_mask2) + out_list.append(est_spec) + out_list.append(est_wav) + out_list.append(est_mask) + return out_list + + def apply_mask(self, cmp_spec, cmp_mask): + est_spec = torch.cat([ + cmp_spec[:, :, :, :, 0] * cmp_mask[:, :, :, :, 0] + - cmp_spec[:, :, :, :, 1] * cmp_mask[:, :, :, :, 1], + cmp_spec[:, :, :, :, 0] * cmp_mask[:, :, :, :, 1] + + cmp_spec[:, :, :, :, 1] * cmp_mask[:, :, :, :, 0] + ], 1) + est_spec = torch.cat([est_spec[:, 0, :, :], est_spec[:, 1, :, :]], 1) + cmp_mask = torch.squeeze(cmp_mask, 1) + cmp_mask = torch.cat([cmp_mask[:, :, :, 0], cmp_mask[:, :, :, 1]], 1) + + est_wav = self.istft(est_spec) + est_wav = torch.squeeze(est_wav, 1) + return est_spec, est_wav, cmp_mask + + def get_params(self, weight_decay=0.0): + # add L2 penalty + weights, biases = [], [] + for name, param in self.named_parameters(): + if 'bias' in name: + biases += [param] + else: + weights += [param] + params = [{ + 'params': weights, + 'weight_decay': weight_decay, + }, { + 'params': biases, + 'weight_decay': 0.0, + }] + return params + + def loss(self, noisy, labels, out_list, mode='Mix'): + if mode == 'SiSNR': + count = 0 + while count < len(out_list): + est_spec = out_list[count] + count = count + 1 + est_wav = out_list[count] + count = count + 1 + est_mask = out_list[count] + count = count + 1 + if count != 3: + loss = self.loss_1layer(noisy, est_spec, est_wav, labels, + est_mask, mode) + return loss + + elif mode == 'Mix': + count = 0 + while count < len(out_list): + est_spec = out_list[count] + count = count + 1 + est_wav = out_list[count] + count = count + 1 + est_mask = out_list[count] + count = count + 1 + if count != 3: + amp_loss, phase_loss, SiSNR_loss = self.loss_1layer( + noisy, est_spec, est_wav, labels, 
est_mask, mode) + loss = amp_loss + phase_loss + SiSNR_loss + return loss, amp_loss, phase_loss + + def loss_1layer(self, noisy, est, est_wav, labels, cmp_mask, mode='Mix'): + r""" Compute the loss by mode + mode == 'Mix' + est: [B, F*2, T] + labels: [B, F*2,T] + mode == 'SiSNR' + est: [B, T] + labels: [B, T] + """ + if mode == 'SiSNR': + if labels.dim() == 3: + labels = torch.squeeze(labels, 1) + if est_wav.dim() == 3: + est_wav = torch.squeeze(est_wav, 1) + return -si_snr(est_wav, labels) + elif mode == 'Mix': + + if labels.dim() == 3: + labels = torch.squeeze(labels, 1) + if est_wav.dim() == 3: + est_wav = torch.squeeze(est_wav, 1) + SiSNR_loss = -si_snr(est_wav, labels) + + b, d, t = est.size() + S = self.stft(labels) + Sr = S[:, :self.feat_dim, :] + Si = S[:, self.feat_dim:, :] + Y = self.stft(noisy) + Yr = Y[:, :self.feat_dim, :] + Yi = Y[:, self.feat_dim:, :] + Y_pow = Yr**2 + Yi**2 + gth_mask = torch.cat([(Sr * Yr + Si * Yi) / (Y_pow + 1e-8), + (Si * Yr - Sr * Yi) / (Y_pow + 1e-8)], 1) + gth_mask[gth_mask > 2] = 1 + gth_mask[gth_mask < -2] = -1 + amp_loss = F.mse_loss(gth_mask[:, :self.feat_dim, :], + cmp_mask[:, :self.feat_dim, :]) * d + phase_loss = F.mse_loss(gth_mask[:, self.feat_dim:, :], + cmp_mask[:, self.feat_dim:, :]) * d + return amp_loss, phase_loss, SiSNR_loss + + +def l2_norm(s1, s2): + norm = torch.sum(s1 * s2, -1, keepdim=True) + return norm + + +def si_snr(s1, s2, eps=1e-8): + s1_s2_norm = l2_norm(s1, s2) + s2_s2_norm = l2_norm(s2, s2) + s_target = s1_s2_norm / (s2_s2_norm + eps) * s2 + e_nosie = s1 - s_target + target_norm = l2_norm(s_target, s_target) + noise_norm = l2_norm(e_nosie, e_nosie) + snr = 10 * torch.log10((target_norm) / (noise_norm + eps) + eps) + return torch.mean(snr) diff --git a/modelscope/models/audio/ans/se_module_complex.py b/modelscope/models/audio/ans/se_module_complex.py new file mode 100644 index 00000000..f62fe523 --- /dev/null +++ b/modelscope/models/audio/ans/se_module_complex.py @@ -0,0 +1,26 @@ +import torch +from torch import nn + + +class SELayer(nn.Module): + + def __init__(self, channel, reduction=16): + super(SELayer, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc_r = nn.Sequential( + nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel), nn.Sigmoid()) + self.fc_i = nn.Sequential( + nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel), nn.Sigmoid()) + + def forward(self, x): + b, c, _, _, _ = x.size() + x_r = self.avg_pool(x[:, :, :, :, 0]).view(b, c) + x_i = self.avg_pool(x[:, :, :, :, 1]).view(b, c) + y_r = self.fc_r(x_r).view(b, c, 1, 1, 1) - self.fc_i(x_i).view( + b, c, 1, 1, 1) + y_i = self.fc_r(x_i).view(b, c, 1, 1, 1) + self.fc_i(x_r).view( + b, c, 1, 1, 1) + y = torch.cat([y_r, y_i], 4) + return x * y diff --git a/modelscope/models/audio/ans/unet.py b/modelscope/models/audio/ans/unet.py new file mode 100644 index 00000000..aa5a4254 --- /dev/null +++ b/modelscope/models/audio/ans/unet.py @@ -0,0 +1,269 @@ +import torch +import torch.nn as nn + +from . 
import complex_nn +from .se_module_complex import SELayer + + +class Encoder(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + padding=None, + complex=False, + padding_mode='zeros'): + super().__init__() + if padding is None: + padding = [(i - 1) // 2 for i in kernel_size] # 'SAME' padding + + if complex: + conv = complex_nn.ComplexConv2d + bn = complex_nn.ComplexBatchNorm2d + else: + conv = nn.Conv2d + bn = nn.BatchNorm2d + + self.conv = conv( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + padding_mode=padding_mode) + self.bn = bn(out_channels) + self.relu = nn.LeakyReLU(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class Decoder(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + padding=(0, 0), + complex=False): + super().__init__() + if complex: + tconv = complex_nn.ComplexConvTranspose2d + bn = complex_nn.ComplexBatchNorm2d + else: + tconv = nn.ConvTranspose2d + bn = nn.BatchNorm2d + + self.transconv = tconv( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding) + self.bn = bn(out_channels) + self.relu = nn.LeakyReLU(inplace=True) + + def forward(self, x): + x = self.transconv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class UNet(nn.Module): + + def __init__(self, + input_channels=1, + complex=False, + model_complexity=45, + model_depth=20, + padding_mode='zeros'): + super().__init__() + + if complex: + model_complexity = int(model_complexity // 1.414) + + self.set_size( + model_complexity=model_complexity, + input_channels=input_channels, + model_depth=model_depth) + self.encoders = [] + self.model_length = model_depth // 2 + self.fsmn = complex_nn.ComplexUniDeepFsmn(128, 128, 128) + self.se_layers_enc = [] + self.fsmn_enc = [] + for i in range(self.model_length): + fsmn_enc = complex_nn.ComplexUniDeepFsmn_L1(128, 128, 128) + self.add_module('fsmn_enc{}'.format(i), fsmn_enc) + self.fsmn_enc.append(fsmn_enc) + module = Encoder( + self.enc_channels[i], + self.enc_channels[i + 1], + kernel_size=self.enc_kernel_sizes[i], + stride=self.enc_strides[i], + padding=self.enc_paddings[i], + complex=complex, + padding_mode=padding_mode) + self.add_module('encoder{}'.format(i), module) + self.encoders.append(module) + se_layer_enc = SELayer(self.enc_channels[i + 1], 8) + self.add_module('se_layer_enc{}'.format(i), se_layer_enc) + self.se_layers_enc.append(se_layer_enc) + self.decoders = [] + self.fsmn_dec = [] + self.se_layers_dec = [] + for i in range(self.model_length): + fsmn_dec = complex_nn.ComplexUniDeepFsmn_L1(128, 128, 128) + self.add_module('fsmn_dec{}'.format(i), fsmn_dec) + self.fsmn_dec.append(fsmn_dec) + module = Decoder( + self.dec_channels[i] * 2, + self.dec_channels[i + 1], + kernel_size=self.dec_kernel_sizes[i], + stride=self.dec_strides[i], + padding=self.dec_paddings[i], + complex=complex) + self.add_module('decoder{}'.format(i), module) + self.decoders.append(module) + if i < self.model_length - 1: + se_layer_dec = SELayer(self.dec_channels[i + 1], 8) + self.add_module('se_layer_dec{}'.format(i), se_layer_dec) + self.se_layers_dec.append(se_layer_dec) + if complex: + conv = complex_nn.ComplexConv2d + else: + conv = nn.Conv2d + + linear = conv(self.dec_channels[-1], 1, 1) + + self.add_module('linear', linear) + self.complex = complex + self.padding_mode = padding_mode + + self.decoders = nn.ModuleList(self.decoders) + self.encoders 
= nn.ModuleList(self.encoders) + self.se_layers_enc = nn.ModuleList(self.se_layers_enc) + self.se_layers_dec = nn.ModuleList(self.se_layers_dec) + self.fsmn_enc = nn.ModuleList(self.fsmn_enc) + self.fsmn_dec = nn.ModuleList(self.fsmn_dec) + + def forward(self, inputs): + x = inputs + # go down + xs = [] + xs_se = [] + xs_se.append(x) + for i, encoder in enumerate(self.encoders): + xs.append(x) + if i > 0: + x = self.fsmn_enc[i](x) + x = encoder(x) + xs_se.append(self.se_layers_enc[i](x)) + # xs : x0=input x1 ... x9 + x = self.fsmn(x) + + p = x + for i, decoder in enumerate(self.decoders): + p = decoder(p) + if i < self.model_length - 1: + p = self.fsmn_dec[i](p) + if i == self.model_length - 1: + break + if i < self.model_length - 2: + p = self.se_layers_dec[i](p) + p = torch.cat([p, xs_se[self.model_length - 1 - i]], dim=1) + + # cmp_spec: [12, 1, 513, 64, 2] + cmp_spec = self.linear(p) + return cmp_spec + + def set_size(self, model_complexity, model_depth=20, input_channels=1): + + if model_depth == 14: + self.enc_channels = [ + input_channels, 128, 128, 128, 128, 128, 128, 128 + ] + self.enc_kernel_sizes = [(5, 2), (5, 2), (5, 2), (5, 2), (5, 2), + (5, 2), (2, 2)] + self.enc_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1), (2, 1), + (2, 1)] + self.enc_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1), + (0, 1), (0, 1)] + self.dec_channels = [64, 128, 128, 128, 128, 128, 128, 1] + self.dec_kernel_sizes = [(2, 2), (5, 2), (5, 2), (5, 2), (6, 2), + (5, 2), (5, 2)] + self.dec_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1), (2, 1), + (2, 1)] + self.dec_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1), + (0, 1), (0, 1)] + + elif model_depth == 10: + self.enc_channels = [ + input_channels, + 16, + 32, + 64, + 128, + 256, + ] + self.enc_kernel_sizes = [(3, 3), (3, 3), (3, 3), (3, 3), (3, 3)] + self.enc_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1)] + self.enc_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1)] + self.dec_channels = [128, 128, 64, 32, 16, 1] + self.dec_kernel_sizes = [(3, 3), (3, 3), (3, 3), (4, 3), (3, 3)] + self.dec_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1)] + self.dec_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1)] + + elif model_depth == 20: + self.enc_channels = [ + input_channels, model_complexity, model_complexity, + model_complexity * 2, model_complexity * 2, + model_complexity * 2, model_complexity * 2, + model_complexity * 2, model_complexity * 2, + model_complexity * 2, 128 + ] + + self.enc_kernel_sizes = [(7, 1), (1, 7), (6, 4), (7, 5), (5, 3), + (5, 3), (5, 3), (5, 3), (5, 3), (5, 3)] + + self.enc_strides = [(1, 1), (1, 1), (2, 2), (2, 1), (2, 2), (2, 1), + (2, 2), (2, 1), (2, 2), (2, 1)] + + self.enc_paddings = [ + (3, 0), + (0, 3), + None, # (0, 2), + None, + None, # (3,1), + None, # (3,1), + None, # (1,2), + None, + None, + None + ] + + self.dec_channels = [ + 0, model_complexity * 2, model_complexity * 2, + model_complexity * 2, model_complexity * 2, + model_complexity * 2, model_complexity * 2, + model_complexity * 2, model_complexity * 2, + model_complexity * 2, model_complexity * 2, + model_complexity * 2 + ] + + self.dec_kernel_sizes = [(4, 3), (4, 2), (4, 3), (4, 2), (4, 3), + (4, 2), (6, 3), (7, 4), (1, 7), (7, 1)] + + self.dec_strides = [(2, 1), (2, 2), (2, 1), (2, 2), (2, 1), (2, 2), + (2, 1), (2, 2), (1, 1), (1, 1)] + + self.dec_paddings = [(1, 1), (1, 0), (1, 1), (1, 0), (1, 1), + (1, 0), (2, 1), (2, 1), (0, 3), (3, 0)] + else: + raise ValueError('Unknown model depth : {}'.format(model_depth)) diff --git 
a/modelscope/models/cv/animal_recognition/__init__.py b/modelscope/models/cv/animal_recognition/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/animal_recognition/resnet.py b/modelscope/models/cv/animal_recognition/resnet.py new file mode 100644 index 00000000..1fd4b93e --- /dev/null +++ b/modelscope/models/cv/animal_recognition/resnet.py @@ -0,0 +1,430 @@ +import math + +import torch +import torch.nn as nn + +from .splat import SplAtConv2d + +__all__ = ['ResNet', 'Bottleneck'] + + +class DropBlock2D(object): + + def __init__(self, *args, **kwargs): + raise NotImplementedError + + +class GlobalAvgPool2d(nn.Module): + + def __init__(self): + """Global average pooling over the input's spatial dimensions""" + super(GlobalAvgPool2d, self).__init__() + + def forward(self, inputs): + return nn.functional.adaptive_avg_pool2d(inputs, + 1).view(inputs.size(0), -1) + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, + inplanes, + planes, + stride=1, + downsample=None, + radix=1, + cardinality=1, + bottleneck_width=64, + avd=False, + avd_first=False, + dilation=1, + is_first=False, + rectified_conv=False, + rectify_avg=False, + norm_layer=None, + dropblock_prob=0.0, + last_gamma=False): + super(Bottleneck, self).__init__() + group_width = int(planes * (bottleneck_width / 64.)) * cardinality + self.conv1 = nn.Conv2d( + inplanes, group_width, kernel_size=1, bias=False) + self.bn1 = norm_layer(group_width) + self.dropblock_prob = dropblock_prob + self.radix = radix + self.avd = avd and (stride > 1 or is_first) + self.avd_first = avd_first + + if self.avd: + self.avd_layer = nn.AvgPool2d(3, stride, padding=1) + stride = 1 + + if dropblock_prob > 0.0: + self.dropblock1 = DropBlock2D(dropblock_prob, 3) + if radix == 1: + self.dropblock2 = DropBlock2D(dropblock_prob, 3) + self.dropblock3 = DropBlock2D(dropblock_prob, 3) + + if radix >= 1: + self.conv2 = SplAtConv2d( + group_width, + group_width, + kernel_size=3, + stride=stride, + padding=dilation, + dilation=dilation, + groups=cardinality, + bias=False, + radix=radix, + rectify=rectified_conv, + rectify_avg=rectify_avg, + norm_layer=norm_layer, + dropblock_prob=dropblock_prob) + elif rectified_conv: + from rfconv import RFConv2d + self.conv2 = RFConv2d( + group_width, + group_width, + kernel_size=3, + stride=stride, + padding=dilation, + dilation=dilation, + groups=cardinality, + bias=False, + average_mode=rectify_avg) + self.bn2 = norm_layer(group_width) + else: + self.conv2 = nn.Conv2d( + group_width, + group_width, + kernel_size=3, + stride=stride, + padding=dilation, + dilation=dilation, + groups=cardinality, + bias=False) + self.bn2 = norm_layer(group_width) + + self.conv3 = nn.Conv2d( + group_width, planes * 4, kernel_size=1, bias=False) + self.bn3 = norm_layer(planes * 4) + + if last_gamma: + from torch.nn.init import zeros_ + zeros_(self.bn3.weight) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.dilation = dilation + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + if self.dropblock_prob > 0.0: + out = self.dropblock1(out) + out = self.relu(out) + + if self.avd and self.avd_first: + out = self.avd_layer(out) + + out = self.conv2(out) + if self.radix == 0: + out = self.bn2(out) + if self.dropblock_prob > 0.0: + out = self.dropblock2(out) + out = self.relu(out) + + if self.avd and not self.avd_first: + out = self.avd_layer(out) + + out = self.conv3(out) + out = self.bn3(out) + if self.dropblock_prob 
> 0.0: + out = self.dropblock3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, + block, + layers, + radix=1, + groups=1, + bottleneck_width=64, + num_classes=1000, + dilated=False, + dilation=1, + deep_stem=False, + stem_width=64, + avg_down=False, + rectified_conv=False, + rectify_avg=False, + avd=False, + avd_first=False, + final_drop=0.0, + dropblock_prob=0, + last_gamma=False, + norm_layer=nn.BatchNorm2d): + self.cardinality = groups + self.bottleneck_width = bottleneck_width + # ResNet-D params + self.inplanes = stem_width * 2 if deep_stem else 64 + self.avg_down = avg_down + self.last_gamma = last_gamma + # ResNeSt params + self.radix = radix + self.avd = avd + self.avd_first = avd_first + + super(ResNet, self).__init__() + self.rectified_conv = rectified_conv + self.rectify_avg = rectify_avg + if rectified_conv: + from rfconv import RFConv2d + conv_layer = RFConv2d + else: + conv_layer = nn.Conv2d + conv_kwargs = {'average_mode': rectify_avg} if rectified_conv else {} + if deep_stem: + self.conv1 = nn.Sequential( + conv_layer( + 3, + stem_width, + kernel_size=3, + stride=2, + padding=1, + bias=False, + **conv_kwargs), + norm_layer(stem_width), + nn.ReLU(inplace=True), + conv_layer( + stem_width, + stem_width, + kernel_size=3, + stride=1, + padding=1, + bias=False, + **conv_kwargs), + norm_layer(stem_width), + nn.ReLU(inplace=True), + conv_layer( + stem_width, + stem_width * 2, + kernel_size=3, + stride=1, + padding=1, + bias=False, + **conv_kwargs), + ) + else: + self.conv1 = conv_layer( + 3, + 64, + kernel_size=7, + stride=2, + padding=3, + bias=False, + **conv_kwargs) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer( + block, 64, layers[0], norm_layer=norm_layer, is_first=False) + self.layer2 = self._make_layer( + block, 128, layers[1], stride=2, norm_layer=norm_layer) + if dilated or dilation == 4: + self.layer3 = self._make_layer( + block, + 256, + layers[2], + stride=1, + dilation=2, + norm_layer=norm_layer, + dropblock_prob=dropblock_prob) + self.layer4 = self._make_layer( + block, + 512, + layers[3], + stride=1, + dilation=4, + norm_layer=norm_layer, + dropblock_prob=dropblock_prob) + elif dilation == 2: + self.layer3 = self._make_layer( + block, + 256, + layers[2], + stride=2, + dilation=1, + norm_layer=norm_layer, + dropblock_prob=dropblock_prob) + self.layer4 = self._make_layer( + block, + 512, + layers[3], + stride=1, + dilation=2, + norm_layer=norm_layer, + dropblock_prob=dropblock_prob) + else: + self.layer3 = self._make_layer( + block, + 256, + layers[2], + stride=2, + norm_layer=norm_layer, + dropblock_prob=dropblock_prob) + self.layer4 = self._make_layer( + block, + 512, + layers[3], + stride=2, + norm_layer=norm_layer, + dropblock_prob=dropblock_prob) + self.avgpool = GlobalAvgPool2d() + self.drop = nn.Dropout(final_drop) if final_drop > 0.0 else None + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + elif isinstance(m, norm_layer): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, + block, + planes, + blocks, + stride=1, + dilation=1, + norm_layer=None, + dropblock_prob=0.0, + is_first=True): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + down_layers = [] + if self.avg_down: + if dilation == 1: + down_layers.append( + nn.AvgPool2d( + kernel_size=stride, + stride=stride, + ceil_mode=True, + count_include_pad=False)) + else: + down_layers.append( + nn.AvgPool2d( + kernel_size=1, + stride=1, + ceil_mode=True, + count_include_pad=False)) + down_layers.append( + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=1, + bias=False)) + else: + down_layers.append( + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False)) + down_layers.append(norm_layer(planes * block.expansion)) + downsample = nn.Sequential(*down_layers) + + layers = [] + if dilation == 1 or dilation == 2: + layers.append( + block( + self.inplanes, + planes, + stride, + downsample=downsample, + radix=self.radix, + cardinality=self.cardinality, + bottleneck_width=self.bottleneck_width, + avd=self.avd, + avd_first=self.avd_first, + dilation=1, + is_first=is_first, + rectified_conv=self.rectified_conv, + rectify_avg=self.rectify_avg, + norm_layer=norm_layer, + dropblock_prob=dropblock_prob, + last_gamma=self.last_gamma)) + elif dilation == 4: + layers.append( + block( + self.inplanes, + planes, + stride, + downsample=downsample, + radix=self.radix, + cardinality=self.cardinality, + bottleneck_width=self.bottleneck_width, + avd=self.avd, + avd_first=self.avd_first, + dilation=2, + is_first=is_first, + rectified_conv=self.rectified_conv, + rectify_avg=self.rectify_avg, + norm_layer=norm_layer, + dropblock_prob=dropblock_prob, + last_gamma=self.last_gamma)) + else: + raise RuntimeError('=> unknown dilation size: {}'.format(dilation)) + + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + radix=self.radix, + cardinality=self.cardinality, + bottleneck_width=self.bottleneck_width, + avd=self.avd, + avd_first=self.avd_first, + dilation=dilation, + rectified_conv=self.rectified_conv, + rectify_avg=self.rectify_avg, + norm_layer=norm_layer, + dropblock_prob=dropblock_prob, + last_gamma=self.last_gamma)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + if self.drop: + x = self.drop(x) + x = self.fc(x) + + return x diff --git a/modelscope/models/cv/animal_recognition/splat.py b/modelscope/models/cv/animal_recognition/splat.py new file mode 100644 index 00000000..b12bf154 --- /dev/null +++ b/modelscope/models/cv/animal_recognition/splat.py @@ -0,0 +1,125 @@ +"""Split-Attention""" + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import BatchNorm2d, Conv2d, Linear, Module, ReLU +from torch.nn.modules.utils import _pair + +__all__ = ['SplAtConv2d'] + + +class SplAtConv2d(Module): + """Split-Attention Conv2d + """ + + def __init__(self, + in_channels, + channels, + kernel_size, + stride=(1, 1), + padding=(0, 0), + dilation=(1, 1), + groups=1, + bias=True, + radix=2, + reduction_factor=4, + rectify=False, + rectify_avg=False, + norm_layer=None, + dropblock_prob=0.0, + 
**kwargs): + super(SplAtConv2d, self).__init__() + padding = _pair(padding) + self.rectify = rectify and (padding[0] > 0 or padding[1] > 0) + self.rectify_avg = rectify_avg + inter_channels = max(in_channels * radix // reduction_factor, 32) + self.radix = radix + self.cardinality = groups + self.channels = channels + self.dropblock_prob = dropblock_prob + if self.rectify: + from rfconv import RFConv2d + self.conv = RFConv2d( + in_channels, + channels * radix, + kernel_size, + stride, + padding, + dilation, + groups=groups * radix, + bias=bias, + average_mode=rectify_avg, + **kwargs) + else: + self.conv = Conv2d( + in_channels, + channels * radix, + kernel_size, + stride, + padding, + dilation, + groups=groups * radix, + bias=bias, + **kwargs) + self.use_bn = norm_layer is not None + if self.use_bn: + self.bn0 = norm_layer(channels * radix) + self.relu = ReLU(inplace=True) + self.fc1 = Conv2d(channels, inter_channels, 1, groups=self.cardinality) + if self.use_bn: + self.bn1 = norm_layer(inter_channels) + self.fc2 = Conv2d( + inter_channels, channels * radix, 1, groups=self.cardinality) + if dropblock_prob > 0.0: + self.dropblock = DropBlock2D(dropblock_prob, 3) + self.rsoftmax = rSoftMax(radix, groups) + + def forward(self, x): + x = self.conv(x) + if self.use_bn: + x = self.bn0(x) + if self.dropblock_prob > 0.0: + x = self.dropblock(x) + x = self.relu(x) + + batch, rchannel = x.shape[:2] + if self.radix > 1: + splited = torch.split(x, rchannel // self.radix, dim=1) + gap = sum(splited) + else: + gap = x + gap = F.adaptive_avg_pool2d(gap, 1) + gap = self.fc1(gap) + + if self.use_bn: + gap = self.bn1(gap) + gap = self.relu(gap) + + atten = self.fc2(gap) + atten = self.rsoftmax(atten).view(batch, -1, 1, 1) + + if self.radix > 1: + attens = torch.split(atten, rchannel // self.radix, dim=1) + out = sum([att * split for (att, split) in zip(attens, splited)]) + else: + out = atten * x + return out.contiguous() + + +class rSoftMax(nn.Module): + + def __init__(self, radix, cardinality): + super().__init__() + self.radix = radix + self.cardinality = cardinality + + def forward(self, x): + batch = x.size(0) + if self.radix > 1: + x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) + x = F.softmax(x, dim=1) + x = x.reshape(batch, -1) + else: + x = torch.sigmoid(x) + return x diff --git a/modelscope/models/nlp/masked_language_model.py b/modelscope/models/nlp/masked_language_model.py index bb255c9c..0410f73c 100644 --- a/modelscope/models/nlp/masked_language_model.py +++ b/modelscope/models/nlp/masked_language_model.py @@ -7,7 +7,7 @@ from ...utils.constant import Tasks from ..base import Model, Tensor from ..builder import MODELS -__all__ = ['StructBertForMaskedLM', 'VecoForMaskedLM'] +__all__ = ['BertForMaskedLM', 'StructBertForMaskedLM', 'VecoForMaskedLM'] class MaskedLanguageModelBase(Model): @@ -61,3 +61,11 @@ class VecoForMaskedLM(MaskedLanguageModelBase): def build_model(self): from sofa import VecoForMaskedLM return VecoForMaskedLM.from_pretrained(self.model_dir) + + +@MODELS.register_module(Tasks.fill_mask, module_name=Models.bert) +class BertForMaskedLM(MaskedLanguageModelBase): + + def build_model(self): + from transformers import BertForMaskedLM + return BertForMaskedLM.from_pretrained(self.model_dir) diff --git a/modelscope/pipelines/__init__.py b/modelscope/pipelines/__init__.py index 14865872..74f5507f 100644 --- a/modelscope/pipelines/__init__.py +++ b/modelscope/pipelines/__init__.py @@ -1,4 +1,5 @@ from .audio import LinearAECPipeline +from .audio.ans_pipeline import 
ANSPipeline
 from .base import Pipeline
 from .builder import pipeline
 from .cv import *  # noqa F403
diff --git a/modelscope/pipelines/audio/ans_pipeline.py b/modelscope/pipelines/audio/ans_pipeline.py
new file mode 100644
index 00000000..d9a04a29
--- /dev/null
+++ b/modelscope/pipelines/audio/ans_pipeline.py
@@ -0,0 +1,117 @@
+import os.path
+from typing import Any, Dict
+
+import librosa
+import numpy as np
+import soundfile as sf
+import torch
+
+from modelscope.metainfo import Pipelines
+from modelscope.utils.constant import Tasks
+from ..base import Input, Pipeline
+from ..builder import PIPELINES
+
+
+def audio_norm(x):
+    rms = (x**2).mean()**0.5
+    scalar = 10**(-25 / 20) / rms
+    x = x * scalar
+    pow_x = x**2
+    avg_pow_x = pow_x.mean()
+    rmsx = pow_x[pow_x > avg_pow_x].mean()**0.5
+    scalarx = 10**(-25 / 20) / rmsx
+    x = x * scalarx
+    return x
+
+
+@PIPELINES.register_module(
+    Tasks.speech_signal_process,
+    module_name=Pipelines.speech_frcrn_ans_cirm_16k)
+class ANSPipeline(Pipeline):
+    r"""ANS (Acoustic Noise Suppression) inference pipeline.
+
+    When the pipeline is invoked via pipeline.__call__(), it accepts only one parameter:
+        inputs (str): the path of the input wav file
+    """
+    SAMPLE_RATE = 16000
+
+    def __init__(self, model):
+        r"""
+        Args:
+            model: model id on the ModelScope hub.
+        """
+        super().__init__(model=model)
+        self.device = torch.device(
+            'cuda' if torch.cuda.is_available() else 'cpu')
+        self.model = self.model.to(self.device)
+        self.model.eval()
+
+    def preprocess(self, inputs: Input) -> Dict[str, Any]:
+        assert isinstance(inputs, str) and os.path.exists(inputs) and os.path.isfile(inputs), \
+            f'Input file does not exist: {inputs}'
+        data1, fs = sf.read(inputs)
+        data1 = audio_norm(data1)
+        if fs != self.SAMPLE_RATE:
+            data1 = librosa.resample(data1, fs, self.SAMPLE_RATE)
+        if len(data1.shape) > 1:
+            data1 = data1[:, 0]
+        data = data1.astype(np.float32)
+        inputs = np.reshape(data, [1, data.shape[0]])
+        return {'ndarray': inputs, 'nsamples': data.shape[0]}
+
+    def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        ndarray = inputs['ndarray']
+        nsamples = inputs['nsamples']
+        decode_do_segement = False
+        window = 16000
+        stride = int(window * 0.75)
+        print('inputs:{}'.format(ndarray.shape))
+        b, t = ndarray.shape  # size()
+        if t > window * 120:
+            decode_do_segement = True
+
+        if t < window:
+            ndarray = np.concatenate(
+                [ndarray, np.zeros((ndarray.shape[0], window - t))], 1)
+        elif t < window + stride:
+            padding = window + stride - t
+            print('padding: {}'.format(padding))
+            ndarray = np.concatenate(
+                [ndarray, np.zeros((ndarray.shape[0], padding))], 1)
+        else:
+            if (t - window) % stride != 0:
+                padding = t - (t - window) // stride * stride
+                print('padding: {}'.format(padding))
+                ndarray = np.concatenate(
+                    [ndarray, np.zeros((ndarray.shape[0], padding))], 1)
+        print('inputs after padding:{}'.format(ndarray.shape))
+        with torch.no_grad():
+            ndarray = torch.from_numpy(np.float32(ndarray)).to(self.device)
+            b, t = ndarray.shape
+            if decode_do_segement:
+                outputs = np.zeros(t)
+                give_up_length = (window - stride) // 2
+                current_idx = 0
+                while current_idx + window <= t:
+                    print('current_idx: {}'.format(current_idx))
+                    tmp_input = ndarray[:, current_idx:current_idx + window]
+                    tmp_output = self.model(
+                        tmp_input, )['wav_l2'][0].cpu().numpy()
+                    end_index = current_idx + window - give_up_length
+                    if current_idx == 0:
+                        outputs[current_idx:
+                                end_index] = tmp_output[:-give_up_length]
+                    else:
+                        outputs[current_idx
+                            + give_up_length:end_index] = tmp_output[
+
give_up_length:-give_up_length] + current_idx += stride + else: + outputs = self.model(ndarray)['wav_l2'][0].cpu().numpy() + return {'output_pcm': outputs[:nsamples]} + + def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]: + if 'output_path' in kwargs.keys(): + sf.write(kwargs['output_path'], inputs['output_pcm'], + self.SAMPLE_RATE) + return inputs diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index 68d875ec..b046e076 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -1,4 +1,5 @@ from .action_recognition_pipeline import ActionRecognitionPipeline +from .animal_recog_pipeline import AnimalRecogPipeline from .image_cartoon_pipeline import ImageCartoonPipeline from .image_matting_pipeline import ImageMattingPipeline from .ocr_detection_pipeline import OCRDetectionPipeline diff --git a/modelscope/pipelines/cv/animal_recog_pipeline.py b/modelscope/pipelines/cv/animal_recog_pipeline.py new file mode 100644 index 00000000..eee9e844 --- /dev/null +++ b/modelscope/pipelines/cv/animal_recog_pipeline.py @@ -0,0 +1,127 @@ +import os.path as osp +import tempfile +from typing import Any, Dict + +import cv2 +import numpy as np +import torch +from PIL import Image +from torchvision import transforms + +from modelscope.fileio import File +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Pipelines +from modelscope.models.cv.animal_recognition import resnet +from modelscope.pipelines.base import Input +from modelscope.preprocessors import load_image +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger +from ..base import Pipeline +from ..builder import PIPELINES + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_classification, module_name=Pipelines.animal_recognation) +class AnimalRecogPipeline(Pipeline): + + def __init__(self, model: str): + super().__init__(model=model) + import torch + + def resnest101(**kwargs): + model = resnet.ResNet( + resnet.Bottleneck, [3, 4, 23, 3], + radix=2, + groups=1, + bottleneck_width=64, + deep_stem=True, + stem_width=64, + avg_down=True, + avd=True, + avd_first=False, + **kwargs) + return model + + def filter_param(src_params, own_state): + copied_keys = [] + for name, param in src_params.items(): + if 'module.' == name[0:7]: + name = name[7:] + if '.module.' 
not in list(own_state.keys())[0]:
+                    name = name.replace('.module.', '.')
+                if (name in own_state) and (own_state[name].shape
+                                            == param.shape):
+                    own_state[name].copy_(param)
+                    copied_keys.append(name)
+
+        def load_pretrained(model, src_params):
+            if 'state_dict' in src_params:
+                src_params = src_params['state_dict']
+            own_state = model.state_dict()
+            filter_param(src_params, own_state)
+            model.load_state_dict(own_state)
+
+        self.model = resnest101(num_classes=8288)
+        local_model_dir = model
+        if osp.exists(model):
+            local_model_dir = model
+        else:
+            local_model_dir = snapshot_download(model)
+        self.local_path = local_model_dir
+        src_params = torch.load(
+            osp.join(local_model_dir, 'pytorch_model.pt'), 'cpu')
+        load_pretrained(self.model, src_params)
+        logger.info('load model done')
+
+    def preprocess(self, input: Input) -> Dict[str, Any]:
+        if isinstance(input, str):
+            img = load_image(input)
+        elif isinstance(input, Image.Image):
+            img = input.convert('RGB')
+        elif isinstance(input, np.ndarray):
+            if len(input.shape) == 2:
+                input = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
+            img = input[:, :, ::-1]
+            img = Image.fromarray(img.astype('uint8')).convert('RGB')
+        else:
+            raise TypeError(f'input should be either str, PIL.Image,'
+                            f' np.array, but got {type(input)}')
+
+        normalize = transforms.Normalize(
+            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        test_transforms = transforms.Compose([
+            transforms.Resize(256),
+            transforms.CenterCrop(224),
+            transforms.ToTensor(), normalize
+        ])
+        img = test_transforms(img)
+        result = {'img': img}
+        return result
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+
+        def set_phase(model, is_train):
+            if is_train:
+                model.train()
+            else:
+                model.eval()
+
+        is_train = False
+        set_phase(self.model, is_train)
+        img = input['img']
+        input_img = torch.unsqueeze(img, 0)
+        outputs = self.model(input_img)
+        return {'outputs': outputs}
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        label_mapping_path = osp.join(self.local_path, 'label_mapping.txt')
+        with open(label_mapping_path, 'r') as f:
+            label_mapping = f.readlines()
+        score = torch.max(inputs['outputs'])
+        inputs = {
+            'scores': score.item(),
+            'labels': label_mapping[inputs['outputs'].argmax()].split('\t')[1]
+        }
+        return inputs
diff --git a/modelscope/pipelines/cv/ocr_detection_pipeline.py b/modelscope/pipelines/cv/ocr_detection_pipeline.py
index 0502fe36..4856b06b 100644
--- a/modelscope/pipelines/cv/ocr_detection_pipeline.py
+++ b/modelscope/pipelines/cv/ocr_detection_pipeline.py
@@ -8,7 +8,6 @@ import cv2
 import numpy as np
 import PIL
 import tensorflow as tf
-import tf_slim as slim
 from modelscope.metainfo import Pipelines
 from modelscope.pipelines.base import Input
@@ -19,6 +18,11 @@ from ..base import Pipeline
 from ..builder import PIPELINES
 from .ocr_utils import model_resnet_mutex_v4_linewithchar, ops, utils
+if tf.__version__ >= '2.0':
+    import tf_slim as slim
+else:
+    from tensorflow.contrib import slim
+
 if tf.__version__ >= '2.0':
     tf = tf.compat.v1
     tf.compat.v1.disable_eager_execution()
@@ -44,6 +48,7 @@ class OCRDetectionPipeline(Pipeline):
     def __init__(self, model: str):
         super().__init__(model=model)
+        tf.reset_default_graph()
         model_path = osp.join(
             osp.join(self.model, ModelFile.TF_CHECKPOINT_FOLDER),
             'checkpoint-80000')
@@ -51,51 +56,56 @@
         config = tf.ConfigProto(allow_soft_placement=True)
         config.gpu_options.allow_growth = True
         self._session = tf.Session(config=config)
-        global_step = tf.get_variable(
-
'global_step', [], - initializer=tf.constant_initializer(0), - dtype=tf.int64, - trainable=False) - variable_averages = tf.train.ExponentialMovingAverage( - 0.997, global_step) self.input_images = tf.placeholder( tf.float32, shape=[1, 1024, 1024, 3], name='input_images') self.output = {} - # detector - detector = model_resnet_mutex_v4_linewithchar.SegLinkDetector() - all_maps = detector.build_model(self.input_images, is_training=False) - - # decode local predictions - all_nodes, all_links, all_reg = [], [], [] - for i, maps in enumerate(all_maps): - cls_maps, lnk_maps, reg_maps = maps[0], maps[1], maps[2] - reg_maps = tf.multiply(reg_maps, OFFSET_VARIANCE) - - cls_prob = tf.nn.softmax(tf.reshape(cls_maps, [-1, 2])) - - lnk_prob_pos = tf.nn.softmax(tf.reshape(lnk_maps, [-1, 4])[:, :2]) - lnk_prob_mut = tf.nn.softmax(tf.reshape(lnk_maps, [-1, 4])[:, 2:]) - lnk_prob = tf.concat([lnk_prob_pos, lnk_prob_mut], axis=1) - - all_nodes.append(cls_prob) - all_links.append(lnk_prob) - all_reg.append(reg_maps) - - # decode segments and links - image_size = tf.shape(self.input_images)[1:3] - segments, group_indices, segment_counts, _ = ops.decode_segments_links_python( - image_size, - all_nodes, - all_links, - all_reg, - anchor_sizes=list(detector.anchor_sizes)) - - # combine segments - combined_rboxes, combined_counts = ops.combine_segments_python( - segments, group_indices, segment_counts) - self.output['combined_rboxes'] = combined_rboxes - self.output['combined_counts'] = combined_counts + with tf.variable_scope('', reuse=tf.AUTO_REUSE): + global_step = tf.get_variable( + 'global_step', [], + initializer=tf.constant_initializer(0), + dtype=tf.int64, + trainable=False) + variable_averages = tf.train.ExponentialMovingAverage( + 0.997, global_step) + + # detector + detector = model_resnet_mutex_v4_linewithchar.SegLinkDetector() + all_maps = detector.build_model( + self.input_images, is_training=False) + + # decode local predictions + all_nodes, all_links, all_reg = [], [], [] + for i, maps in enumerate(all_maps): + cls_maps, lnk_maps, reg_maps = maps[0], maps[1], maps[2] + reg_maps = tf.multiply(reg_maps, OFFSET_VARIANCE) + + cls_prob = tf.nn.softmax(tf.reshape(cls_maps, [-1, 2])) + + lnk_prob_pos = tf.nn.softmax( + tf.reshape(lnk_maps, [-1, 4])[:, :2]) + lnk_prob_mut = tf.nn.softmax( + tf.reshape(lnk_maps, [-1, 4])[:, 2:]) + lnk_prob = tf.concat([lnk_prob_pos, lnk_prob_mut], axis=1) + + all_nodes.append(cls_prob) + all_links.append(lnk_prob) + all_reg.append(reg_maps) + + # decode segments and links + image_size = tf.shape(self.input_images)[1:3] + segments, group_indices, segment_counts, _ = ops.decode_segments_links_python( + image_size, + all_nodes, + all_links, + all_reg, + anchor_sizes=list(detector.anchor_sizes)) + + # combine segments + combined_rboxes, combined_counts = ops.combine_segments_python( + segments, group_indices, segment_counts) + self.output['combined_rboxes'] = combined_rboxes + self.output['combined_counts'] = combined_counts with self._session.as_default() as sess: logger.info(f'loading model from {model_path}') diff --git a/modelscope/pipelines/cv/ocr_utils/model_resnet_mutex_v4_linewithchar.py b/modelscope/pipelines/cv/ocr_utils/model_resnet_mutex_v4_linewithchar.py index 50b8ba02..d03ff405 100644 --- a/modelscope/pipelines/cv/ocr_utils/model_resnet_mutex_v4_linewithchar.py +++ b/modelscope/pipelines/cv/ocr_utils/model_resnet_mutex_v4_linewithchar.py @@ -1,8 +1,12 @@ import tensorflow as tf -import tf_slim as slim from . 
import ops, resnet18_v1, resnet_utils +if tf.__version__ >= '2.0': + import tf_slim as slim +else: + from tensorflow.contrib import slim + if tf.__version__ >= '2.0': tf = tf.compat.v1 diff --git a/modelscope/pipelines/cv/ocr_utils/resnet18_v1.py b/modelscope/pipelines/cv/ocr_utils/resnet18_v1.py index 6371d4e5..7930c5a3 100644 --- a/modelscope/pipelines/cv/ocr_utils/resnet18_v1.py +++ b/modelscope/pipelines/cv/ocr_utils/resnet18_v1.py @@ -30,10 +30,14 @@ ResNet-101 for semantic segmentation into 21 classes: output_stride=16) """ import tensorflow as tf -import tf_slim as slim from . import resnet_utils +if tf.__version__ >= '2.0': + import tf_slim as slim +else: + from tensorflow.contrib import slim + if tf.__version__ >= '2.0': tf = tf.compat.v1 diff --git a/modelscope/pipelines/cv/ocr_utils/resnet_utils.py b/modelscope/pipelines/cv/ocr_utils/resnet_utils.py index e0e240c8..0a9af224 100644 --- a/modelscope/pipelines/cv/ocr_utils/resnet_utils.py +++ b/modelscope/pipelines/cv/ocr_utils/resnet_utils.py @@ -19,7 +19,11 @@ implementation is more memory efficient. import collections import tensorflow as tf -import tf_slim as slim + +if tf.__version__ >= '2.0': + import tf_slim as slim +else: + from tensorflow.contrib import slim if tf.__version__ >= '2.0': tf = tf.compat.v1 diff --git a/modelscope/pipelines/nlp/fill_mask_pipeline.py b/modelscope/pipelines/nlp/fill_mask_pipeline.py index 596d65f7..256f867a 100644 --- a/modelscope/pipelines/nlp/fill_mask_pipeline.py +++ b/modelscope/pipelines/nlp/fill_mask_pipeline.py @@ -1,3 +1,4 @@ +import os from typing import Any, Dict, Optional, Union import torch @@ -6,11 +7,13 @@ from ...metainfo import Pipelines from ...models import Model from ...models.nlp.masked_language_model import MaskedLanguageModelBase from ...preprocessors import FillMaskPreprocessor -from ...utils.constant import Tasks +from ...utils.config import Config +from ...utils.constant import ModelFile, Tasks from ..base import Pipeline, Tensor from ..builder import PIPELINES __all__ = ['FillMaskPipeline'] +_type_map = {'veco': 'roberta', 'sbert': 'bert'} @PIPELINES.register_module(Tasks.fill_mask, module_name=Pipelines.fill_mask) @@ -29,7 +32,6 @@ class FillMaskPipeline(Pipeline): """ fill_mask_model = model if isinstance( model, MaskedLanguageModelBase) else Model.from_pretrained(model) - assert fill_mask_model.config is not None if preprocessor is None: preprocessor = FillMaskPreprocessor( @@ -41,11 +43,13 @@ class FillMaskPipeline(Pipeline): model=fill_mask_model, preprocessor=preprocessor, **kwargs) self.preprocessor = preprocessor + self.config = Config.from_file( + os.path.join(fill_mask_model.model_dir, ModelFile.CONFIGURATION)) self.tokenizer = preprocessor.tokenizer - self.mask_id = {'veco': 250001, 'sbert': 103} + self.mask_id = {'roberta': 250001, 'bert': 103} self.rep_map = { - 'sbert': { + 'bert': { '[unused0]': '', '[PAD]': '', '[unused1]': '', @@ -55,7 +59,7 @@ class FillMaskPipeline(Pipeline): '[CLS]': '', '[UNK]': '' }, - 'veco': { + 'roberta': { r' +': ' ', '': '', '': '', @@ -84,7 +88,9 @@ class FillMaskPipeline(Pipeline): input_ids = inputs['input_ids'].detach().numpy() pred_ids = np.argmax(logits, axis=-1) model_type = self.model.config.model_type - rst_ids = np.where(input_ids == self.mask_id[model_type], pred_ids, + process_type = model_type if model_type in self.mask_id else _type_map[ + model_type] + rst_ids = np.where(input_ids == self.mask_id[process_type], pred_ids, input_ids) def rep_tokens(string, rep_map): @@ -94,14 +100,12 @@ class 
FillMaskPipeline(Pipeline): pred_strings = [] for ids in rst_ids: # batch - # TODO vocab size is not stable - - if self.model.config.vocab_size == 21128: # zh bert + if 'language' in self.config.model and self.config.model.language == 'zh': pred_string = self.tokenizer.convert_ids_to_tokens(ids) pred_string = ''.join(pred_string) else: pred_string = self.tokenizer.decode(ids) - pred_string = rep_tokens(pred_string, self.rep_map[model_type]) + pred_string = rep_tokens(pred_string, self.rep_map[process_type]) pred_strings.append(pred_string) return {'text': pred_strings} diff --git a/modelscope/preprocessors/nlp.py b/modelscope/preprocessors/nlp.py index f998da37..3bd1f110 100644 --- a/modelscope/preprocessors/nlp.py +++ b/modelscope/preprocessors/nlp.py @@ -326,14 +326,17 @@ class FillMaskPreprocessor(Preprocessor): model_dir (str): model path """ super().__init__(*args, **kwargs) - from sofa.utils.backend import AutoTokenizer self.model_dir = model_dir self.first_sequence: str = kwargs.pop('first_sequence', 'first_sequence') self.sequence_length = kwargs.pop('sequence_length', 128) - - self.tokenizer = AutoTokenizer.from_pretrained( - model_dir, use_fast=False) + try: + from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained(model_dir) + except KeyError: + from sofa.utils.backend import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained( + model_dir, use_fast=False) @type_assert(object, str) def __call__(self, data: str) -> Dict[str, Any]: diff --git a/requirements/audio.txt b/requirements/audio.txt index c7b2b239..1f5984ca 100644 --- a/requirements/audio.txt +++ b/requirements/audio.txt @@ -16,6 +16,7 @@ protobuf>3,<=3.20 ptflops PyWavelets>=1.0.0 scikit-learn +SoundFile>0.10 sox tensorboard tensorflow==1.15.* diff --git a/tests/pipelines/test_animal_recognation.py b/tests/pipelines/test_animal_recognation.py new file mode 100644 index 00000000..d0f42dc3 --- /dev/null +++ b/tests/pipelines/test_animal_recognation.py @@ -0,0 +1,20 @@ +import unittest + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class AnimalRecognitionTest(unittest.TestCase): + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run(self): + animal_recog = pipeline( + Tasks.image_classification, + model='damo/cv_resnest101_animal_recognation') + result = animal_recog('data/test/images/image1.jpg') + print(result) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_fill_mask.py b/tests/pipelines/test_fill_mask.py index 49c5dc8a..d44ba4c8 100644 --- a/tests/pipelines/test_fill_mask.py +++ b/tests/pipelines/test_fill_mask.py @@ -3,7 +3,8 @@ import unittest from modelscope.hub.snapshot_download import snapshot_download from modelscope.models import Model -from modelscope.models.nlp import StructBertForMaskedLM, VecoForMaskedLM +from modelscope.models.nlp import (BertForMaskedLM, StructBertForMaskedLM, + VecoForMaskedLM) from modelscope.pipelines import FillMaskPipeline, pipeline from modelscope.preprocessors import FillMaskPreprocessor from modelscope.utils.constant import Tasks @@ -16,6 +17,7 @@ class FillMaskTest(unittest.TestCase): 'en': 'damo/nlp_structbert_fill-mask_english-large' } model_id_veco = 'damo/nlp_veco_fill-mask-large' + model_id_bert = 'damo/nlp_bert_fill-mask_chinese-base' ori_texts = { 'zh': @@ -69,6 +71,20 @@ class FillMaskTest(unittest.TestCase): f'{pipeline1(test_input)}\npipeline2: 
{pipeline2(test_input)}\n' ) + # zh bert + language = 'zh' + model_dir = snapshot_download(self.model_id_bert) + preprocessor = FillMaskPreprocessor( + model_dir, first_sequence='sentence', second_sequence=None) + model = BertForMaskedLM(model_dir) + pipeline1 = FillMaskPipeline(model, preprocessor) + pipeline2 = pipeline( + Tasks.fill_mask, model=model, preprocessor=preprocessor) + ori_text = self.ori_texts[language] + test_input = self.test_inputs[language] + print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline1: ' + f'{pipeline1(test_input)}\npipeline2: {pipeline2(test_input)}\n') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_with_model_from_modelhub(self): # sbert @@ -97,6 +113,18 @@ class FillMaskTest(unittest.TestCase): print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: ' f'{pipeline_ins(test_input)}\n') + # zh bert + model = Model.from_pretrained(self.model_id_bert) + preprocessor = FillMaskPreprocessor( + model.model_dir, first_sequence='sentence', second_sequence=None) + pipeline_ins = pipeline( + Tasks.fill_mask, model=model, preprocessor=preprocessor) + language = 'zh' + ori_text = self.ori_texts[language] + test_input = self.test_inputs[language] + print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: ' + f'{pipeline_ins(test_input)}\n') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_with_model_name(self): # veco @@ -115,6 +143,12 @@ class FillMaskTest(unittest.TestCase): f'\nori_text: {self.ori_texts[language]}\ninput: {self.test_inputs[language]}\npipeline: ' f'{pipeline_ins(self.test_inputs[language])}\n') + # bert + pipeline_ins = pipeline(task=Tasks.fill_mask, model=self.model_id_bert) + print( + f'\nori_text: {self.ori_texts[language]}\ninput: {self.test_inputs[language]}\npipeline: ' + f'{pipeline_ins(self.test_inputs[language])}\n') + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_with_default_model(self): pipeline_ins = pipeline(task=Tasks.fill_mask) diff --git a/tests/pipelines/test_ocr_detection.py b/tests/pipelines/test_ocr_detection.py index 986961b7..d1ecd4e4 100644 --- a/tests/pipelines/test_ocr_detection.py +++ b/tests/pipelines/test_ocr_detection.py @@ -27,6 +27,11 @@ class OCRDetectionTest(unittest.TestCase): print('ocr detection results: ') print(result) + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + ocr_detection = pipeline(Tasks.ocr_detection, model=self.model_id) + self.pipeline_inference(ocr_detection, self.test_image) + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_modelhub_default_model(self): ocr_detection = pipeline(Tasks.ocr_detection) diff --git a/tests/pipelines/test_speech_signal_process.py b/tests/pipelines/test_speech_signal_process.py index bc3a542e..f317bc07 100644 --- a/tests/pipelines/test_speech_signal_process.py +++ b/tests/pipelines/test_speech_signal_process.py @@ -17,6 +17,9 @@ AEC_LIB_URL = 'http://isv-data.oss-cn-hangzhou.aliyuncs.com/ics%2FMaaS%2FAEC%2Fl '?Expires=1664085465&OSSAccessKeyId=LTAIxjQyZNde90zh&Signature=Y7gelmGEsQAJRK4yyHSYMrdWizk%3D' AEC_LIB_FILE = 'libmitaec_pyio.so' +NOISE_SPEECH_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ANS/sample_audio/speech_with_noise.wav' +NOISE_SPEECH_FILE = 'speech_with_noise.wav' + def download(remote_path, local_path): local_dir = os.path.dirname(local_path) @@ -30,23 +33,40 @@ def 
download(remote_path, local_path): class SpeechSignalProcessTest(unittest.TestCase): def setUp(self) -> None: - self.model_id = 'damo/speech_dfsmn_aec_psm_16k' + pass + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_aec(self): # A temporary hack to provide c++ lib. Download it first. download(AEC_LIB_URL, AEC_LIB_FILE) - - @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') - def test_run(self): + # Download audio files download(NEAREND_MIC_URL, NEAREND_MIC_FILE) download(FAREND_SPEECH_URL, FAREND_SPEECH_FILE) + model_id = 'damo/speech_dfsmn_aec_psm_16k' input = { 'nearend_mic': NEAREND_MIC_FILE, 'farend_speech': FAREND_SPEECH_FILE } aec = pipeline( Tasks.speech_signal_process, - model=self.model_id, + model=model_id, pipeline_name=Pipelines.speech_dfsmn_aec_psm_16k) - aec(input, output_path='output.wav') + output_path = os.path.abspath('output.wav') + aec(input, output_path=output_path) + print(f'Processed audio saved to {output_path}') + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_ans(self): + # Download audio files + download(NOISE_SPEECH_URL, NOISE_SPEECH_FILE) + model_id = 'damo/speech_frcrn_ans_cirm_16k' + ans = pipeline( + Tasks.speech_signal_process, + model=model_id, + pipeline_name=Pipelines.speech_frcrn_ans_cirm_16k) + output_path = os.path.abspath('output.wav') + ans(NOISE_SPEECH_FILE, output_path=output_path) + print(f'Processed audio saved to {output_path}') if __name__ == '__main__':
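Reviewer note on the new ANS feature exercised by test_ans above: the denoising pipeline can also be driven directly, outside the unittest harness. A minimal sketch, assuming the model 'damo/speech_frcrn_ans_cirm_16k' is available on the hub and that a local noisy-speech wav file exists; the file names below are placeholders, not part of this patch:

    from modelscope.metainfo import Pipelines
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # Build the acoustic noise suppression (ANS) pipeline registered by this patch.
    ans = pipeline(
        Tasks.speech_signal_process,
        model='damo/speech_frcrn_ans_cirm_16k',
        pipeline_name=Pipelines.speech_frcrn_ans_cirm_16k)
    # Input is a path to a noisy speech recording; the enhanced audio is written
    # to output_path, mirroring the call in test_ans above.
    ans('speech_with_noise.wav', output_path='denoised_output.wav')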