You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

config.py 9.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ===========================================================================
  15. """Config setting, will be used in train.py and eval.py"""
  16. from src.utils import prepare_words_list
  17. def data_config(parser):
  18. '''config for data.'''
  19. parser.add_argument('--data_url', type=str,
  20. default='http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz',
  21. help='Location of speech training data archive on the web.')
  22. parser.add_argument('--data_dir', type=str, default='data',
  23. help='Where to download the dataset.')
  24. parser.add_argument('--feat_dir', type=str, default='feat',
  25. help='Where to save the feature of audios')
  26. parser.add_argument('--background_volume', type=float, default=0.1,
  27. help='How loud the background noise should be, between 0 and 1.')
  28. parser.add_argument('--background_frequency', type=float, default=0.8,
  29. help='How many of the training samples have background noise mixed in.')
  30. parser.add_argument('--silence_percentage', type=float, default=10.0,
  31. help='How much of the training data should be silence.')
  32. parser.add_argument('--unknown_percentage', type=float, default=10.0,
  33. help='How much of the training data should be unknown words.')
  34. parser.add_argument('--time_shift_ms', type=float, default=100.0,
  35. help='Range to randomly shift the training audio by in time.')
  36. parser.add_argument('--testing_percentage', type=int, default=10,
  37. help='What percentage of wavs to use as a test set.')
  38. parser.add_argument('--validation_percentage', type=int, default=10,
  39. help='What percentage of wavs to use as a validation set.')
  40. parser.add_argument('--wanted_words', type=str, default='yes,no,up,down,left,right,on,off,stop,go',
  41. help='Words to use (others will be added to an unknown label)')
  42. parser.add_argument('--sample_rate', type=int, default=16000, help='Expected sample rate of the wavs')
  43. parser.add_argument('--clip_duration_ms', type=int, default=1000,
  44. help='Expected duration in milliseconds of the wavs')
  45. parser.add_argument('--window_size_ms', type=float, default=40.0, help='How long each spectrogram timeslice is')
  46. parser.add_argument('--window_stride_ms', type=float, default=20.0, help='How long each spectrogram timeslice is')
  47. parser.add_argument('--dct_coefficient_count', type=int, default=20,
  48. help='How many bins to use for the MFCC fingerprint')
  49. def train_config(parser):
  50. '''config for train.'''
  51. data_config(parser)
  52. # network related
  53. parser.add_argument('--model_size_info', type=int, nargs="+",
  54. default=[6, 276, 10, 4, 2, 1, 276, 3, 3, 2, 2, 276, 3, 3, 1, 1, 276, 3, 3, 1, 1, 276, 3, 3, 1,
  55. 1, 276, 3, 3, 1, 1],
  56. help='Model dimensions - different for various models')
  57. parser.add_argument('--drop', type=float, default=0.9, help='dropout')
  58. parser.add_argument('--pretrained', type=str, default='', help='model_path, local pretrained model to load')
  59. # training related
  60. parser.add_argument('--use_graph_mode', default=1, type=int, help='use graph mode or feed mode')
  61. parser.add_argument('--val_interval', type=int, default=1, help='validate interval')
  62. # dataset related
  63. parser.add_argument('--per_batch_size', default=100, type=int, help='batch size for per gpu')
  64. # optimizer and lr related
  65. parser.add_argument('--lr_scheduler', default='multistep', type=str,
  66. help='lr-scheduler, option type: multistep, cosine_annealing')
  67. parser.add_argument('--lr', default=0.1, type=float, help='learning rate of the training')
  68. parser.add_argument('--lr_epochs', type=str, default='20,40,60,80', help='epoch of lr changing')
  69. parser.add_argument('--lr_gamma', type=float, default=0.1,
  70. help='decrease lr by a factor of exponential lr_scheduler')
  71. parser.add_argument('--eta_min', type=float, default=0., help='eta_min in cosine_annealing scheduler')
  72. parser.add_argument('--T_max', type=int, default=80, help='T-max in cosine_annealing scheduler')
  73. parser.add_argument('--max_epoch', type=int, default=80, help='max epoch num to train the model')
  74. parser.add_argument('--warmup_epochs', default=0, type=float, help='warmup epoch')
  75. parser.add_argument('--weight_decay', type=float, default=0.001, help='weight decay')
  76. parser.add_argument('--momentum', type=float, default=0.98, help='momentum')
  77. # logging related
  78. parser.add_argument('--log_interval', type=int, default=100, help='logging interval')
  79. parser.add_argument('--ckpt_path', type=str, default='train_outputs/', help='checkpoint save location')
  80. parser.add_argument('--ckpt_interval', type=int, default=100, help='save ckpt_interval')
  81. flags, _ = parser.parse_known_args()
  82. flags.dataset_sink_mode = bool(flags.use_graph_mode)
  83. flags.lr_epochs = list(map(int, flags.lr_epochs.split(',')))
  84. model_settings = prepare_model_settings(
  85. len(prepare_words_list(flags.wanted_words.split(','))),
  86. flags.sample_rate, flags.clip_duration_ms, flags.window_size_ms,
  87. flags.window_stride_ms, flags.dct_coefficient_count)
  88. model_settings['dropout1'] = flags.drop
  89. return flags, model_settings
  90. def eval_config(parser):
  91. '''config for eval.'''
  92. parser.add_argument('--feat_dir', type=str, default='feat',
  93. help='Where to save the feature of audios')
  94. parser.add_argument('--model_dir', type=str,
  95. default='outputs',
  96. help='which folder the models are saved in or specific path of one model')
  97. parser.add_argument('--wanted_words', type=str, default='yes,no,up,down,left,right,on,off,stop,go',
  98. help='Words to use (others will be added to an unknown label)')
  99. parser.add_argument('--sample_rate', type=int, default=16000, help='Expected sample rate of the wavs')
  100. parser.add_argument('--clip_duration_ms', type=int, default=1000,
  101. help='Expected duration in milliseconds of the wavs')
  102. parser.add_argument('--window_size_ms', type=float, default=40.0, help='How long each spectrogram timeslice is')
  103. parser.add_argument('--window_stride_ms', type=float, default=20.0, help='How long each spectrogram timeslice is')
  104. parser.add_argument('--dct_coefficient_count', type=int, default=20,
  105. help='How many bins to use for the MFCC fingerprint')
  106. parser.add_argument('--model_size_info', type=int, nargs="+",
  107. default=[6, 276, 10, 4, 2, 1, 276, 3, 3, 2, 2, 276, 3, 3, 1, 1, 276, 3, 3, 1, 1, 276, 3, 3, 1,
  108. 1, 276, 3, 3, 1, 1],
  109. help='Model dimensions - different for various models')
  110. parser.add_argument('--per_batch_size', default=100, type=int, help='batch size for per gpu')
  111. parser.add_argument('--drop', type=float, default=0.9, help='dropout')
  112. # logging related
  113. parser.add_argument('--log_path', type=str, default='eval_outputs/', help='path to save eval log')
  114. flags, _ = parser.parse_known_args()
  115. model_settings = prepare_model_settings(
  116. len(prepare_words_list(flags.wanted_words.split(','))),
  117. flags.sample_rate, flags.clip_duration_ms, flags.window_size_ms,
  118. flags.window_stride_ms, flags.dct_coefficient_count)
  119. model_settings['dropout1'] = flags.drop
  120. return flags, model_settings
  121. def prepare_model_settings(label_count, sample_rate, clip_duration_ms,
  122. window_size_ms, window_stride_ms,
  123. dct_coefficient_count):
  124. '''Prepare model setting.'''
  125. desired_samples = int(sample_rate * clip_duration_ms / 1000)
  126. window_size_samples = int(sample_rate * window_size_ms / 1000)
  127. window_stride_samples = int(sample_rate * window_stride_ms / 1000)
  128. length_minus_window = (desired_samples - window_size_samples)
  129. if length_minus_window < 0:
  130. spectrogram_length = 0
  131. else:
  132. spectrogram_length = 1 + int(length_minus_window / window_stride_samples)
  133. fingerprint_size = dct_coefficient_count * spectrogram_length
  134. return {
  135. 'desired_samples': desired_samples,
  136. 'window_size_samples': window_size_samples,
  137. 'window_stride_samples': window_stride_samples,
  138. 'spectrogram_length': spectrogram_length,
  139. 'dct_coefficient_count': dct_coefficient_count,
  140. 'fingerprint_size': fingerprint_size,
  141. 'label_count': label_count,
  142. 'sample_rate': sample_rate,
  143. }