# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""DeepSpeech2 model."""
import math
import numpy as np
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore import nn, Tensor, ParameterTuple, Parameter
from mindspore.common.initializer import initializer


class SequenceWise(nn.Cell):
    """
    SequenceWise FC layers: folds the (time, batch) dimensions of a (T, N, *) input together,
    applies the wrapped module, and restores the original layout.
    """

    def __init__(self, module):
        super(SequenceWise, self).__init__()
        self.module = module
        self.reshape_op = P.Reshape()
        self.shape_op = P.Shape()
        self._initialize_weights()

    def construct(self, x):
        sizes = self.shape_op(x)
        t, n = sizes[0], sizes[1]
        x = self.reshape_op(x, (t * n, -1))
        x = self.module(x)
        x = self.reshape_op(x, (t, n, -1))
        return x

    def _initialize_weights(self):
        self.init_parameters_data()
        for _, m in self.cells_and_names():
            if isinstance(m, nn.Dense):
                m.weight.set_data(Tensor(
                    np.random.uniform(-1. / m.in_channels, 1. / m.in_channels,
                                      m.weight.data.shape).astype("float32")))
                if m.bias is not None:
                    m.bias.set_data(Tensor(
                        np.random.uniform(-1. / m.in_channels, 1. / m.in_channels,
                                          m.bias.data.shape).astype("float32")))


class MaskConv(nn.Cell):
    """
    MaskConv architecture. The masking itself is not implemented here because some
    operations are not supported in MindSpore; `lengths` is kept for future use.
    """

    def __init__(self):
        super(MaskConv, self).__init__()
        self.zeros = P.ZerosLike()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=(41, 11), stride=(2, 2),
                               pad_mode='pad', padding=(20, 20, 5, 5))
        self.bn1 = nn.BatchNorm2d(num_features=32)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=(21, 11), stride=(2, 1),
                               pad_mode='pad', padding=(10, 10, 5, 5))
        self.bn2 = nn.BatchNorm2d(num_features=32)
        self.tanh = nn.Tanh()
        self._initialize_weights()
        self.module_list = nn.CellList([self.conv1, self.bn1, self.tanh,
                                        self.conv2, self.bn2, self.tanh])

    def construct(self, x, lengths):
        for module in self.module_list:
            x = module(x)
        return x

    def _initialize_weights(self):
        """Parameter initialization."""
        self.init_parameters_data()
        for _, m in self.cells_and_names():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.set_data(Tensor(np.random.normal(0, np.sqrt(2. / n),
                                                          m.weight.data.shape).astype("float32")))
                if m.bias is not None:
                    m.bias.set_data(Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
            elif isinstance(m, nn.BatchNorm2d):
                m.gamma.set_data(Tensor(np.ones(m.gamma.data.shape, dtype="float32")))
                m.beta.set_data(Tensor(np.zeros(m.beta.data.shape, dtype="float32")))
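

# The masking that MaskConv is named after (zeroing convolution outputs beyond each
# utterance's valid length) is not performed above. The helper below is a minimal
# NumPy sketch of that idea, for reference only: `mask_conv_output` is an illustrative
# name, it is not used by this model, and it is not part of the MindSpore graph.
def mask_conv_output(features, lengths):
    """Zero out `features` of shape (N, C, D, T) beyond each valid length in `lengths` (N,)."""
    masked = features.copy()
    for n, valid_t in enumerate(lengths):
        masked[n, :, :, int(valid_t):] = 0
    return masked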


class BatchRNN(nn.Cell):
    """
    BatchRNN architecture.

    Args:
        batch_size (int): number of samples per step in training.
        input_size (int): dimension of the input tensor.
        hidden_size (int): RNN hidden size.
        num_layers (int): number of RNN layers.
        bidirectional (bool): whether to use a bidirectional RNN (default=False).
            Currently, only the bidirectional case is implemented.
        batch_norm (bool): whether to use batch normalization between RNN layers (default=False).
            Currently, GPU does not support BatchNorm1d here.
        rnn_type (str): RNN type to use (default='LSTM'). Currently, only LSTM is supported.
    """

    def __init__(self, batch_size, input_size, hidden_size, num_layers, bidirectional=False, batch_norm=False,
                 rnn_type='LSTM', device_target="GPU"):
        super(BatchRNN, self).__init__()
        self.batch_size = batch_size
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn_type = rnn_type
        self.bidirectional = bidirectional
        self.has_bias = True
        self.is_batch_norm = batch_norm
        self.num_directions = 2 if bidirectional else 1
        self.reshape_op = P.Reshape()
        self.shape_op = P.Shape()
        self.sum_op = P.ReduceSum()

        input_size_list = [input_size]
        for i in range(num_layers - 1):
            input_size_list.append(hidden_size)
        layers = []
        for i in range(num_layers):
            layers.append(nn.LSTMCell(input_size=input_size_list[i], hidden_size=hidden_size,
                                      bidirectional=bidirectional, has_bias=self.has_bias))

        weights = []
        for i in range(num_layers):
            # Flattened LSTM weight: input-hidden and hidden-hidden matrices for the four gates.
            weight_size = (input_size_list[i] + hidden_size) * hidden_size * self.num_directions * 4
            if self.has_bias:
                if device_target == "GPU":
                    bias_size = self.num_directions * hidden_size * 4 * 2
                else:
                    bias_size = self.num_directions * hidden_size * 4
                weight_size = weight_size + bias_size
            stdv = 1 / math.sqrt(hidden_size)
            w_np = np.random.uniform(-stdv, stdv, (weight_size, 1, 1)).astype(np.float32)
            weights.append(Parameter(initializer(Tensor(w_np), w_np.shape), name="weight" + str(i)))

        self.h, self.c = self.stack_lstm_default_state(batch_size, hidden_size,
                                                       num_layers=num_layers, bidirectional=bidirectional)
        self.lstms = layers
        self.weight = ParameterTuple(tuple(weights))

        if batch_norm:
            batch_norm_layer = []
            for i in range(num_layers - 1):
                batch_norm_layer.append(nn.BatchNorm1d(hidden_size))
            self.batch_norm_list = batch_norm_layer

    def stack_lstm_default_state(self, batch_size, hidden_size, num_layers, bidirectional):
        """Init default hidden and cell states."""
        num_directions = 2 if bidirectional else 1
        h_list, c_list = [], []
        for _ in range(num_layers):
            h_list.append(Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)))
            c_list.append(Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)))
        h, c = tuple(h_list), tuple(c_list)
        return h, c

    def construct(self, x):
        for i in range(self.num_layers):
            if self.is_batch_norm and i > 0:
                x = self.batch_norm_list[i - 1](x)
            x, _, _, _, _ = self.lstms[i](x, self.h[i], self.c[i], self.weight[i])
            if self.bidirectional:
                # Sum the forward and backward outputs so the feature size stays hidden_size.
                size = self.shape_op(x)
                x = self.reshape_op(x, (size[0], size[1], 2, -1))
                x = self.sum_op(x, 2)
        return x
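

# Worked example of the rnn_input_size computation in DeepSpeechModel.__init__ below,
# assuming a 16 kHz sample rate and a 20 ms window (illustrative values, not defaults
# shipped with this file):
#   spectrogram bins: floor(16000 * 0.02 / 2) + 1 = 161
#   after conv1 (kernel 41, padding 20, stride 2): floor((161 + 2 * 20 - 41) / 2) + 1 = 81
#   after conv2 (kernel 21, padding 10, stride 2): floor((81 + 2 * 10 - 21) / 2) + 1 = 41
#   flattened with 32 output channels: 41 * 32 = 1312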
class DeepSpeechModel(nn.Cell):
    """
    DeepSpeech2 architecture.

    Args:
        batch_size (int): number of samples per step in training (default=128).
        rnn_type (str): RNN type to use (default='LSTM').
        labels (list): list containing all the possible characters to map to.
        rnn_hidden_size (int): RNN hidden size.
        nb_layers (int): number of RNN layers.
        audio_conf: config containing the sample rate, window and the window length/stride in seconds.
        bidirectional (bool): whether to use a bidirectional RNN (default=True).
    """

    def __init__(self, batch_size, labels, rnn_hidden_size, nb_layers, audio_conf, rnn_type='LSTM',
                 bidirectional=True, device_target='GPU'):
        super(DeepSpeechModel, self).__init__()
        self.batch_size = batch_size
        self.hidden_size = rnn_hidden_size
        self.hidden_layers = nb_layers
        self.rnn_type = rnn_type
        self.audio_conf = audio_conf
        self.labels = labels
        self.bidirectional = bidirectional
        self.reshape_op = P.Reshape()
        self.shape_op = P.Shape()
        self.transpose_op = P.Transpose()
        self.add = P.Add()
        self.div = P.Div()

        sample_rate = self.audio_conf.sample_rate
        window_size = self.audio_conf.window_size
        num_classes = len(self.labels)

        self.conv = MaskConv()
        # Collect the per-layer padding and stride terms used by get_seq_lens().
        self.pre, self.stride = self.get_conv_num()

        # Based on the convolutions above and the spectrogram size,
        # using the conv output formula (W - F + 2P) / S + 1.
        rnn_input_size = int(math.floor((sample_rate * window_size) / 2) + 1)
        rnn_input_size = int(math.floor(rnn_input_size + 2 * 20 - 41) / 2 + 1)
        rnn_input_size = int(math.floor(rnn_input_size + 2 * 10 - 21) / 2 + 1)
        rnn_input_size *= 32

        self.RNN = BatchRNN(batch_size=self.batch_size, input_size=rnn_input_size, num_layers=nb_layers,
                            hidden_size=rnn_hidden_size, bidirectional=bidirectional, batch_norm=False,
                            rnn_type=self.rnn_type, device_target=device_target)
        fully_connected = nn.Dense(rnn_hidden_size, num_classes, has_bias=False)
        self.fc = SequenceWise(fully_connected)

    def construct(self, x, lengths):
        """
        `lengths` is not actually used in this part, since MindSpore does not support
        dynamic shape; it only serves to derive the output lengths.
        """
        output_lengths = self.get_seq_lens(lengths)
        x = self.conv(x, lengths)
        sizes = self.shape_op(x)
        x = self.reshape_op(x, (sizes[0], sizes[1] * sizes[2], sizes[3]))
        x = self.transpose_op(x, (2, 0, 1))
        x = self.RNN(x)
        x = self.fc(x)
        return x, output_lengths

    def get_seq_lens(self, seq_len):
        """
        Given a 1D tensor containing integer sequence lengths, return a 1D tensor
        containing the corresponding lengths that will be output by the network.
""" for i in range(len(self.stride)): seq_len = self.add(self.div(self.add(seq_len, self.pre[i]), self.stride[i]), 1) return seq_len def get_conv_num(self): p, s = [], [] for _, cell in self.conv.cells_and_names(): if isinstance(cell, nn.Conv2d): kernel_size = cell.kernel_size padding_1 = int((kernel_size[1] - 1) / 2) temp = 2 * padding_1 - cell.dilation[1] * (cell.kernel_size[1] - 1) - 1 p.append(temp) s.append(cell.stride[1]) return p, s class NetWithLossClass(nn.Cell): """ NetWithLossClass definition """ def __init__(self, network): super(NetWithLossClass, self).__init__(auto_prefix=False) self.loss = P.CTCLoss(ctc_merge_repeated=True) self.network = network self.ReduceMean_false = P.ReduceMean(keep_dims=False) self.squeeze_op = P.Squeeze(0) self.cast_op = P.Cast() def construct(self, inputs, input_length, target_indices, label_values): predict, output_length = self.network(inputs, input_length) loss = self.loss(predict, target_indices, label_values, self.cast_op(output_length, mstype.int32)) return self.ReduceMean_false(loss[0]) class PredictWithSoftmax(nn.Cell): """ PredictWithSoftmax """ def __init__(self, network): super(PredictWithSoftmax, self).__init__(auto_prefix=False) self.network = network self.inference_softmax = P.Softmax(axis=-1) self.transpose_op = P.Transpose() self.cast_op = P.Cast() def construct(self, inputs, input_length): x, output_sizes = self.network(inputs, self.cast_op(input_length, mstype.int32)) x = self.inference_softmax(x) x = self.transpose_op(x, (1, 0, 2)) return x, output_sizes