deepspeech2.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
DeepSpeech2 model
"""
import math
import numpy as np
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore import nn, Tensor, ParameterTuple, Parameter
from mindspore.common.initializer import initializer


class SequenceWise(nn.Cell):
    """
    Sequence-wise FC layer: collapses the time and batch dimensions, applies the
    wrapped module, then restores the (time, batch, feature) layout.
    """

    def __init__(self, module):
        super(SequenceWise, self).__init__()
        self.module = module
        self.reshape_op = P.Reshape()
        self.shape_op = P.Shape()
        self._initialize_weights()

    def construct(self, x):
        sizes = self.shape_op(x)
        t, n = sizes[0], sizes[1]
        x = self.reshape_op(x, (t * n, -1))
        x = self.module(x)
        x = self.reshape_op(x, (t, n, -1))
        return x

    def _initialize_weights(self):
        self.init_parameters_data()
        for _, m in self.cells_and_names():
            if isinstance(m, nn.Dense):
                m.weight.set_data(Tensor(
                    np.random.uniform(-1. / m.in_channels, 1. / m.in_channels,
                                      m.weight.data.shape).astype("float32")))
                if m.bias is not None:
                    m.bias.set_data(Tensor(
                        np.random.uniform(-1. / m.in_channels, 1. / m.in_channels,
                                          m.bias.data.shape).astype("float32")))


class MaskConv(nn.Cell):
    """
    MaskConv architecture. The actual masking is not implemented here because some
    operations are not yet supported in MindSpore; `lengths` is kept for future use.
    """

    def __init__(self):
        super(MaskConv, self).__init__()
        self.zeros = P.ZerosLike()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=(41, 11), stride=(2, 2), pad_mode='pad', padding=(20, 20, 5, 5))
        self.bn1 = nn.BatchNorm2d(num_features=32)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=(21, 11), stride=(2, 1), pad_mode='pad', padding=(10, 10, 5, 5))
        self.bn2 = nn.BatchNorm2d(num_features=32)
        self.tanh = nn.Tanh()
        self._initialize_weights()
        self.module_list = nn.CellList([self.conv1, self.bn1, self.tanh, self.conv2, self.bn2, self.tanh])

    def construct(self, x, lengths):
        for module in self.module_list:
            x = module(x)
        return x

    def _initialize_weights(self):
        """
        Parameter initialization: He-normal weights for convolutions, identity
        (gamma=1, beta=0) for batch norm.
        """
        self.init_parameters_data()
        for _, m in self.cells_and_names():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.set_data(Tensor(np.random.normal(0, np.sqrt(2. / n),
                                                          m.weight.data.shape).astype("float32")))
                if m.bias is not None:
                    m.bias.set_data(
                        Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
            elif isinstance(m, nn.BatchNorm2d):
                m.gamma.set_data(
                    Tensor(np.ones(m.gamma.data.shape, dtype="float32")))
                m.beta.set_data(
                    Tensor(np.zeros(m.beta.data.shape, dtype="float32")))


class BatchRNN(nn.Cell):
    """
    BatchRNN architecture.

    Args:
        batch_size (int): number of samples per step in training
        input_size (int): dimension of the input tensor
        hidden_size (int): RNN hidden size
        num_layers (int): number of RNN layers
        bidirectional (bool): use a bidirectional RNN (default=False). Currently, only the
            bidirectional RNN is implemented.
        batch_norm (bool): whether to use batch norm between RNN layers. Currently, GPU does
            not support BatchNorm1d (default=False).
        rnn_type (str): RNN type to use (default='LSTM'). Currently, only LSTM is supported.
    """

    def __init__(self, batch_size, input_size, hidden_size, num_layers, bidirectional=False, batch_norm=False,
                 rnn_type='LSTM', device_target="GPU"):
        super(BatchRNN, self).__init__()
        self.batch_size = batch_size
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn_type = rnn_type
        self.bidirectional = bidirectional
        self.has_bias = True
        self.is_batch_norm = batch_norm
        self.num_directions = 2 if bidirectional else 1
        self.reshape_op = P.Reshape()
        self.shape_op = P.Shape()
        self.sum_op = P.ReduceSum()

        input_size_list = [input_size]
        for _ in range(num_layers - 1):
            input_size_list.append(hidden_size)
        layers = []
        for i in range(num_layers):
            layers.append(
                nn.LSTMCell(input_size=input_size_list[i], hidden_size=hidden_size, bidirectional=bidirectional,
                            has_bias=self.has_bias))

        # Flat weight buffer per LSTM layer: 4 gates over the concatenated
        # (input + hidden) features; the GPU backend stores two bias vectors per gate.
        weights = []
        for i in range(num_layers):
            weight_size = (input_size_list[i] + hidden_size) * hidden_size * self.num_directions * 4
            if self.has_bias:
                if device_target == "GPU":
                    bias_size = self.num_directions * hidden_size * 4 * 2
                else:
                    bias_size = self.num_directions * hidden_size * 4
                weight_size = weight_size + bias_size
            stdv = 1 / math.sqrt(hidden_size)
            w_np = np.random.uniform(-stdv, stdv, (weight_size, 1, 1)).astype(np.float32)
            weights.append(Parameter(initializer(Tensor(w_np), w_np.shape), name="weight" + str(i)))

        self.h, self.c = self.stack_lstm_default_state(batch_size, hidden_size, num_layers=num_layers,
                                                       bidirectional=bidirectional)
        self.lstms = layers
        self.weight = ParameterTuple(tuple(weights))

        if batch_norm:
            batch_norm_layer = []
            for _ in range(num_layers - 1):
                batch_norm_layer.append(nn.BatchNorm1d(hidden_size))
            self.batch_norm_list = batch_norm_layer

    def stack_lstm_default_state(self, batch_size, hidden_size, num_layers, bidirectional):
        """Init default zero states, one (h, c) pair per layer."""
        num_directions = 2 if bidirectional else 1
        h_list, c_list = [], []  # separate lists; aliasing them would interleave the h and c states
        for _ in range(num_layers):
            h_list.append(Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)))
            c_list.append(Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)))
        h, c = tuple(h_list), tuple(c_list)
        return h, c

    def construct(self, x):
        for i in range(self.num_layers):
            if self.is_batch_norm and i > 0:
                x = self.batch_norm_list[i - 1](x)
            x, _, _, _, _ = self.lstms[i](x, self.h[i], self.c[i], self.weight[i])
            if self.bidirectional:
                size = self.shape_op(x)
                x = self.reshape_op(x, (size[0], size[1], 2, -1))
                x = self.sum_op(x, 2)
        return x
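
# Shape sketch for the bidirectional sum in BatchRNN.construct (hypothetical sizes):
# an LSTM output of (T=251, N=8, 2 * 1024) is viewed as (251, 8, 2, 1024) and reduced
# over the direction axis, so forward and backward features are summed rather than
# concatenated and the layer width stays at hidden_size.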


class DeepSpeechModel(nn.Cell):
    """
    DeepSpeech2 architecture.

    Args:
        batch_size (int): number of samples per step in training (default=128)
        rnn_type (str): RNN type to use (default='LSTM')
        labels (list): list containing all the possible characters to map to
        rnn_hidden_size (int): RNN hidden size
        nb_layers (int): number of RNN layers
        audio_conf: config containing the sample rate and the window length/stride in seconds
        bidirectional (bool): use a bidirectional RNN (default=True)
    """

    def __init__(self, batch_size, labels, rnn_hidden_size, nb_layers, audio_conf, rnn_type='LSTM',
                 bidirectional=True, device_target='GPU'):
        super(DeepSpeechModel, self).__init__()
        self.batch_size = batch_size
        self.hidden_size = rnn_hidden_size
        self.hidden_layers = nb_layers
        self.rnn_type = rnn_type
        self.audio_conf = audio_conf
        self.labels = labels
        self.bidirectional = bidirectional
        self.reshape_op = P.Reshape()
        self.shape_op = P.Shape()
        self.transpose_op = P.Transpose()
        self.add = P.Add()
        self.div = P.Div()

        sample_rate = self.audio_conf.sample_rate
        window_size = self.audio_conf.window_size
        num_classes = len(self.labels)
        self.conv = MaskConv()
        # Per-convolution padding/stride terms used to compute output sequence lengths.
        self.pre, self.stride = self.get_conv_num()
        # RNN input size, derived from the spectrogram size and the two convolutions
        # above via the conv output formula (W - F + 2P) / S + 1.
        rnn_input_size = int(math.floor((sample_rate * window_size) / 2) + 1)
        rnn_input_size = int(math.floor(rnn_input_size + 2 * 20 - 41) / 2 + 1)
        rnn_input_size = int(math.floor(rnn_input_size + 2 * 10 - 21) / 2 + 1)
        rnn_input_size *= 32
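        # Worked example (hypothetical config): sample_rate=16000 and window_size=0.02
        # give 16000 * 0.02 / 2 + 1 = 161 spectrogram bins; conv1 maps this to
        # (161 + 2 * 20 - 41) / 2 + 1 = 81, conv2 to (81 + 2 * 10 - 21) / 2 + 1 = 41,
        # and 41 * 32 channels = 1312 RNN input features.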
        self.RNN = BatchRNN(batch_size=self.batch_size, input_size=rnn_input_size, num_layers=nb_layers,
                            hidden_size=rnn_hidden_size, bidirectional=bidirectional, batch_norm=False,
                            rnn_type=self.rnn_type, device_target=device_target)
        fully_connected = nn.Dense(rnn_hidden_size, num_classes, has_bias=False)
        self.fc = SequenceWise(fully_connected)

    def construct(self, x, lengths):
        """
        `lengths` is not actually used here, since MindSpore does not support dynamic shapes.
        """
        output_lengths = self.get_seq_lens(lengths)
        x = self.conv(x, lengths)
        sizes = self.shape_op(x)
        x = self.reshape_op(x, (sizes[0], sizes[1] * sizes[2], sizes[3]))
        x = self.transpose_op(x, (2, 0, 1))
        x = self.RNN(x)
        x = self.fc(x)
        return x, output_lengths

    def get_seq_lens(self, seq_len):
        """
        Given a 1D tensor containing integer sequence lengths, return a 1D tensor
        containing the sequence lengths that will be output by the network.
        """
        for i in range(len(self.stride)):
            seq_len = self.add(self.div(self.add(seq_len, self.pre[i]), self.stride[i]), 1)
        return seq_len
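    # Worked example for get_seq_lens (hypothetical length): get_conv_num() below
    # yields pre = [-1, -1] and stride = [2, 1] for the two convolutions in MaskConv,
    # so an input of 501 time steps maps to (501 - 1) / 2 + 1 = 251 after conv1 and
    # to (251 - 1) / 1 + 1 = 251 after conv2.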
    def get_conv_num(self):
        """Collect the effective padding and stride of each conv layer along the time axis."""
        p, s = [], []
        for _, cell in self.conv.cells_and_names():
            if isinstance(cell, nn.Conv2d):
                kernel_size = cell.kernel_size
                padding_1 = int((kernel_size[1] - 1) / 2)
                temp = 2 * padding_1 - cell.dilation[1] * (cell.kernel_size[1] - 1) - 1
                p.append(temp)
                s.append(cell.stride[1])
        return p, s


class NetWithLossClass(nn.Cell):
    """
    NetWithLossClass definition: wraps the network with a CTC loss.
    """

    def __init__(self, network):
        super(NetWithLossClass, self).__init__(auto_prefix=False)
        self.loss = P.CTCLoss(ctc_merge_repeated=True)
        self.network = network
        self.ReduceMean_false = P.ReduceMean(keep_dims=False)
        self.squeeze_op = P.Squeeze(0)
        self.cast_op = P.Cast()

    def construct(self, inputs, input_length, target_indices, label_values):
        predict, output_length = self.network(inputs, input_length)
        loss = self.loss(predict, target_indices, label_values, self.cast_op(output_length, mstype.int32))
        return self.ReduceMean_false(loss[0])


class PredictWithSoftmax(nn.Cell):
    """
    PredictWithSoftmax: wraps the network and applies a softmax over the class axis for inference.
    """

    def __init__(self, network):
        super(PredictWithSoftmax, self).__init__(auto_prefix=False)
        self.network = network
        self.inference_softmax = P.Softmax(axis=-1)
        self.transpose_op = P.Transpose()
        self.cast_op = P.Cast()

    def construct(self, inputs, input_length):
        x, output_sizes = self.network(inputs, self.cast_op(input_length, mstype.int32))
        x = self.inference_softmax(x)
        x = self.transpose_op(x, (1, 0, 2))
        return x, output_sizes
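

# A minimal usage sketch under stated assumptions: the `AudioConf` namedtuple, the
# label set, and all shapes/sizes below are illustrative, not part of this file or
# its configs. It only shows how the cells above are meant to be wired together.
if __name__ == "__main__":
    from collections import namedtuple

    from mindspore import context

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    AudioConf = namedtuple("AudioConf", ["sample_rate", "window_size"])  # assumed config shape
    labels = list("'abcdefghijklmnopqrstuvwxyz _")  # hypothetical character set
    net = DeepSpeechModel(batch_size=2, labels=labels, rnn_hidden_size=1024, nb_layers=5,
                          audio_conf=AudioConf(sample_rate=16000, window_size=0.02),
                          rnn_type='LSTM', bidirectional=True, device_target='GPU')
    infer_net = PredictWithSoftmax(net)
    # Dummy spectrogram batch: (N, 1, freq_bins, time_steps), freq_bins = 16000 * 0.02 / 2 + 1 = 161.
    spect = Tensor(np.random.randn(2, 1, 161, 501).astype(np.float32))
    lengths = Tensor(np.array([501, 501], dtype=np.int32))
    probs, out_lens = infer_net(spect, lengths)  # probs: (N, T', num_classes)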