You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dataset.py 8.9 kB

5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. Create train or eval dataset.
  17. """
  18. import math
  19. import numpy as np
  20. import mindspore.dataset.engine as de
  21. import librosa
  22. import soundfile as sf
  23. TRAIN_INPUT_PAD_LENGTH = 1501
  24. TRAIN_LABEL_PAD_LENGTH = 350
  25. TEST_INPUT_PAD_LENGTH = 3500
  26. class LoadAudioAndTranscript():
  27. """
  28. parse audio and transcript
  29. """
  30. def __init__(self,
  31. audio_conf=None,
  32. normalize=False,
  33. labels=None):
  34. super(LoadAudioAndTranscript, self).__init__()
  35. self.window_stride = audio_conf.window_stride
  36. self.window_size = audio_conf.window_size
  37. self.sample_rate = audio_conf.sample_rate
  38. self.window = audio_conf.window
  39. self.is_normalization = normalize
  40. self.labels = labels
  41. def load_audio(self, path):
  42. """
  43. load audio
  44. """
  45. sound, _ = sf.read(path, dtype='int16')
  46. sound = sound.astype('float32') / 32767
  47. if len(sound.shape) > 1:
  48. if sound.shape[1] == 1:
  49. sound = sound.squeeze()
  50. else:
  51. sound = sound.mean(axis=1)
  52. return sound
  53. def parse_audio(self, audio_path):
  54. """
  55. parse audio
  56. """
  57. audio = self.load_audio(audio_path)
  58. n_fft = int(self.sample_rate * self.window_size)
  59. win_length = n_fft
  60. hop_length = int(self.sample_rate * self.window_stride)
  61. D = librosa.stft(y=audio, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=self.window)
  62. mag, _ = librosa.magphase(D)
  63. mag = np.log1p(mag)
  64. if self.is_normalization:
  65. mean = mag.mean()
  66. std = mag.std()
  67. mag = (mag - mean) / std
  68. return mag
  69. def parse_transcript(self, transcript_path):
  70. with open(transcript_path, 'r', encoding='utf8') as transcript_file:
  71. transcript = transcript_file.read().replace('\n', '')
  72. transcript = list(filter(None, [self.labels.get(x) for x in list(transcript)]))
  73. return transcript
  74. class ASRDataset(LoadAudioAndTranscript):
  75. """
  76. create ASRDataset
  77. Args:
  78. audio_conf: Config containing the sample rate, window and the window length/stride in seconds
  79. manifest_filepath (str): manifest_file path.
  80. labels (list): List containing all the possible characters to map to
  81. normalize: Apply standard mean and deviation normalization to audio tensor
  82. batch_size (int): Dataset batch size (default=32)
  83. """
  84. def __init__(self, audio_conf=None,
  85. manifest_filepath='',
  86. labels=None,
  87. normalize=False,
  88. batch_size=32,
  89. is_training=True):
  90. with open(manifest_filepath) as f:
  91. ids = f.readlines()
  92. ids = [x.strip().split(',') for x in ids]
  93. self.is_training = is_training
  94. self.ids = ids
  95. self.blank_id = int(labels.index('_'))
  96. self.bins = [ids[i:i + batch_size] for i in range(0, len(ids), batch_size)]
  97. if len(self.ids) % batch_size != 0:
  98. self.bins = self.bins[:-1]
  99. self.bins.append(ids[-batch_size:])
  100. self.size = len(self.bins)
  101. self.batch_size = batch_size
  102. self.labels_map = {labels[i]: i for i in range(len(labels))}
  103. super(ASRDataset, self).__init__(audio_conf, normalize, self.labels_map)
  104. def __getitem__(self, index):
  105. batch_idx = self.bins[index]
  106. batch_size = len(batch_idx)
  107. batch_spect, batch_script, target_indices = [], [], []
  108. input_length = np.zeros(batch_size, np.float32)
  109. for data in batch_idx:
  110. audio_path, transcript_path = data[0], data[1]
  111. spect = self.parse_audio(audio_path)
  112. transcript = self.parse_transcript(transcript_path)
  113. batch_spect.append(spect)
  114. batch_script.append(transcript)
  115. freq_size = np.shape(batch_spect[-1])[0]
  116. if self.is_training:
  117. # 1501 is the max length in train dataset(LibriSpeech).
  118. # The length is fixed to this value because Mindspore does not support dynamic shape currently
  119. inputs = np.zeros((batch_size, 1, freq_size, TRAIN_INPUT_PAD_LENGTH), dtype=np.float32)
  120. # The target length is fixed to this value because Mindspore does not support dynamic shape currently
  121. # 350 may be greater than the max length of labels in train dataset(LibriSpeech).
  122. targets = np.ones((self.batch_size, TRAIN_LABEL_PAD_LENGTH), dtype=np.int32) * self.blank_id
  123. for k, spect_, scripts_ in zip(range(batch_size), batch_spect, batch_script):
  124. seq_length = np.shape(spect_)[1]
  125. input_length[k] = seq_length
  126. script_length = len(scripts_)
  127. targets[k, :script_length] = scripts_
  128. for m in range(350):
  129. target_indices.append([k, m])
  130. inputs[k, 0, :, 0:seq_length] = spect_
  131. targets = np.reshape(targets, (-1,))
  132. else:
  133. inputs = np.zeros((batch_size, 1, freq_size, TEST_INPUT_PAD_LENGTH), dtype=np.float32)
  134. targets = []
  135. for k, spect_, scripts_ in zip(range(batch_size), batch_spect, batch_script):
  136. seq_length = np.shape(spect_)[1]
  137. input_length[k] = seq_length
  138. targets.extend(scripts_)
  139. for m in range(len(scripts_)):
  140. target_indices.append([k, m])
  141. inputs[k, 0, :, 0:seq_length] = spect_
  142. return inputs, input_length, np.array(target_indices, dtype=np.int64), np.array(targets, dtype=np.int32)
  143. def __len__(self):
  144. return self.size
  145. class DistributedSampler():
  146. """
  147. function to distribute and shuffle sample
  148. """
  149. def __init__(self, dataset, rank, group_size, shuffle=True, seed=0):
  150. self.dataset = dataset
  151. self.rank = rank
  152. self.group_size = group_size
  153. self.dataset_len = len(self.dataset)
  154. self.num_samplers = int(math.ceil(self.dataset_len * 1.0 / self.group_size))
  155. self.total_size = self.num_samplers * self.group_size
  156. self.shuffle = shuffle
  157. self.seed = seed
  158. def __iter__(self):
  159. if self.shuffle:
  160. self.seed = (self.seed + 1) & 0xffffffff
  161. np.random.seed(self.seed)
  162. indices = np.random.permutation(self.dataset_len).tolist()
  163. else:
  164. indices = list(range(self.dataset_len))
  165. indices += indices[:(self.total_size - len(indices))]
  166. indices = indices[self.rank::self.group_size]
  167. return iter(indices)
  168. def __len__(self):
  169. return self.num_samplers
  170. def create_dataset(audio_conf, manifest_filepath, labels, normalize, batch_size, train_mode=True,
  171. rank=None, group_size=None):
  172. """
  173. create train dataset
  174. Args:
  175. audio_conf: Config containing the sample rate, window and the window length/stride in seconds
  176. manifest_filepath (str): manifest_file path.
  177. labels (list): list containing all the possible characters to map to
  178. normalize: Apply standard mean and deviation normalization to audio tensor
  179. train_mode (bool): Whether dataset is use for train or eval (default=True).
  180. batch_size (int): Dataset batch size
  181. rank (int): The shard ID within num_shards (default=None).
  182. group_size (int): Number of shards that the dataset should be divided into (default=None).
  183. Returns:
  184. Dataset.
  185. """
  186. dataset = ASRDataset(audio_conf=audio_conf, manifest_filepath=manifest_filepath, labels=labels, normalize=normalize,
  187. batch_size=batch_size, is_training=train_mode)
  188. sampler = DistributedSampler(dataset, rank, group_size, shuffle=True)
  189. ds = de.GeneratorDataset(dataset, ["inputs", "input_length", "target_indices", "label_values"], sampler=sampler)
  190. ds = ds.repeat(1)
  191. return ds