# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""Config setting, will be used in train.py and eval.py"""
from src.utils import prepare_words_list

# Default value of --model_size_info, shared by the train and eval parsers.
# Stored as a tuple so it cannot be mutated; each parser receives a fresh list.
_DEFAULT_MODEL_SIZE_INFO = (6, 276, 10, 4, 2, 1,
                            276, 3, 3, 2, 2,
                            276, 3, 3, 1, 1,
                            276, 3, 3, 1, 1,
                            276, 3, 3, 1, 1,
                            276, 3, 3, 1, 1)


def _add_feature_args(parser):
    '''Register the label/audio/MFCC options shared by train and eval configs.'''
    parser.add_argument('--wanted_words', type=str, default='yes,no,up,down,left,right,on,off,stop,go',
                        help='Words to use (others will be added to an unknown label)')
    parser.add_argument('--sample_rate', type=int, default=16000, help='Expected sample rate of the wavs')
    parser.add_argument('--clip_duration_ms', type=int, default=1000,
                        help='Expected duration in milliseconds of the wavs')
    parser.add_argument('--window_size_ms', type=float, default=40.0, help='How long each spectrogram timeslice is')
    parser.add_argument('--window_stride_ms', type=float, default=20.0,
                        help='How far to move in time between spectrogram timeslices')
    parser.add_argument('--dct_coefficient_count', type=int, default=20,
                        help='How many bins to use for the MFCC fingerprint')


def _add_model_size_arg(parser):
    '''Register --model_size_info (a flat list of ints describing the network layout).'''
    parser.add_argument('--model_size_info', type=int, nargs="+", default=list(_DEFAULT_MODEL_SIZE_INFO),
                        help='Model dimensions - different for various models')


def _build_model_settings(flags):
    '''Derive the model-settings dict from parsed flags and attach the dropout value.

    The label count is the length of prepare_words_list(wanted_words); that
    helper lives in src.utils (presumably it adds the silence/unknown labels —
    confirm there).
    '''
    model_settings = prepare_model_settings(
        len(prepare_words_list(flags.wanted_words.split(','))),
        flags.sample_rate, flags.clip_duration_ms, flags.window_size_ms,
        flags.window_stride_ms, flags.dct_coefficient_count)
    model_settings['dropout1'] = flags.drop
    return model_settings


def data_config(parser):
    '''config for data.

    Registers dataset location, augmentation, split and feature-extraction
    options on `parser` (an argparse.ArgumentParser). Returns nothing.
    '''
    parser.add_argument('--data_url', type=str,
                        default='http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz',
                        help='Location of speech training data archive on the web.')
    parser.add_argument('--data_dir', type=str, default='data', help='Where to download the dataset.')
    parser.add_argument('--feat_dir', type=str, default='feat', help='Where to save the feature of audios')
    parser.add_argument('--background_volume', type=float, default=0.1,
                        help='How loud the background noise should be, between 0 and 1.')
    parser.add_argument('--background_frequency', type=float, default=0.8,
                        help='How many of the training samples have background noise mixed in.')
    parser.add_argument('--silence_percentage', type=float, default=10.0,
                        help='How much of the training data should be silence.')
    parser.add_argument('--unknown_percentage', type=float, default=10.0,
                        help='How much of the training data should be unknown words.')
    parser.add_argument('--time_shift_ms', type=float, default=100.0,
                        help='Range to randomly shift the training audio by in time.')
    parser.add_argument('--testing_percentage', type=int, default=10,
                        help='What percentage of wavs to use as a test set.')
    parser.add_argument('--validation_percentage', type=int, default=10,
                        help='What percentage of wavs to use as a validation set.')
    _add_feature_args(parser)


def train_config(parser):
    '''config for train.

    Registers data, network, optimizer/lr and logging options, then parses the
    command line (unknown arguments are ignored via parse_known_args).

    Returns:
        (flags, model_settings): the parsed namespace, with `dataset_sink_mode`
        set from `use_graph_mode` and `lr_epochs` converted to a list of ints,
        and the dict produced by prepare_model_settings plus 'dropout1'.
    '''
    data_config(parser)

    # network related
    _add_model_size_arg(parser)
    parser.add_argument('--drop', type=float, default=0.9, help='dropout')
    parser.add_argument('--pretrained', type=str, default='', help='model_path, local pretrained model to load')

    # training related
    parser.add_argument('--use_graph_mode', default=1, type=int, help='use graph mode or feed mode')
    parser.add_argument('--val_interval', type=int, default=1, help='validate interval')

    # dataset related
    parser.add_argument('--per_batch_size', default=100, type=int, help='batch size for per gpu')

    # optimizer and lr related
    parser.add_argument('--lr_scheduler', default='multistep', type=str,
                        help='lr-scheduler, option type: multistep, cosine_annealing')
    parser.add_argument('--lr', default=0.1, type=float, help='learning rate of the training')
    parser.add_argument('--lr_epochs', type=str, default='20,40,60,80', help='epoch of lr changing')
    parser.add_argument('--lr_gamma', type=float, default=0.1,
                        help='decrease lr by a factor of exponential lr_scheduler')
    parser.add_argument('--eta_min', type=float, default=0., help='eta_min in cosine_annealing scheduler')
    parser.add_argument('--T_max', type=int, default=80, help='T-max in cosine_annealing scheduler')
    parser.add_argument('--max_epoch', type=int, default=80, help='max epoch num to train the model')
    parser.add_argument('--warmup_epochs', default=0, type=float, help='warmup epoch')
    parser.add_argument('--weight_decay', type=float, default=0.001, help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.98, help='momentum')

    # logging related
    parser.add_argument('--log_interval', type=int, default=100, help='logging interval')
    parser.add_argument('--ckpt_path', type=str, default='train_outputs/', help='checkpoint save location')
    parser.add_argument('--ckpt_interval', type=int, default=100, help='save ckpt_interval')

    flags, _ = parser.parse_known_args()
    flags.dataset_sink_mode = bool(flags.use_graph_mode)
    # Skip empty items so inputs like "20,40," do not crash int().
    flags.lr_epochs = [int(epoch) for epoch in flags.lr_epochs.split(',') if epoch.strip()]
    return flags, _build_model_settings(flags)


def eval_config(parser):
    '''config for eval.

    Registers feature, model and logging options needed for evaluation, then
    parses the command line (unknown arguments are ignored).

    Returns:
        (flags, model_settings): the parsed namespace and the dict produced by
        prepare_model_settings plus 'dropout1'.
    '''
    parser.add_argument('--feat_dir', type=str, default='feat', help='Where to save the feature of audios')
    parser.add_argument('--model_dir', type=str, default='outputs',
                        help='which folder the models are saved in or specific path of one model')
    _add_feature_args(parser)
    _add_model_size_arg(parser)
    parser.add_argument('--per_batch_size', default=100, type=int, help='batch size for per gpu')
    parser.add_argument('--drop', type=float, default=0.9, help='dropout')

    # logging related
    parser.add_argument('--log_path', type=str, default='eval_outputs/', help='path to save eval log')

    flags, _ = parser.parse_known_args()
    return flags, _build_model_settings(flags)


def prepare_model_settings(label_count, sample_rate, clip_duration_ms,
                           window_size_ms, window_stride_ms, dct_coefficient_count):
    '''Prepare model setting.

    Converts millisecond durations to sample counts (truncating toward zero)
    and derives the number of spectrogram frames and the flattened
    fingerprint size (dct_coefficient_count * spectrogram_length).

    Returns:
        dict with keys desired_samples, window_size_samples,
        window_stride_samples, spectrogram_length, dct_coefficient_count,
        fingerprint_size, label_count and sample_rate.

    Raises:
        ValueError: if the clip is at least one window long but the stride
            rounds to fewer than one sample (frames could not be counted).
    '''
    desired_samples = int(sample_rate * clip_duration_ms / 1000)
    window_size_samples = int(sample_rate * window_size_ms / 1000)
    window_stride_samples = int(sample_rate * window_stride_ms / 1000)
    length_minus_window = desired_samples - window_size_samples
    if length_minus_window < 0:
        # Clip shorter than a single window: no full frame fits.
        spectrogram_length = 0
    else:
        if window_stride_samples <= 0:
            raise ValueError(
                'window_stride_ms=%r gives a stride of %d samples at sample_rate=%r; '
                'it must be at least one sample' % (window_stride_ms, window_stride_samples, sample_rate))
        # One frame at offset 0 plus one per full stride that still fits.
        spectrogram_length = 1 + length_minus_window // window_stride_samples
    fingerprint_size = dct_coefficient_count * spectrogram_length
    return {
        'desired_samples': desired_samples,
        'window_size_samples': window_size_samples,
        'window_stride_samples': window_stride_samples,
        'spectrogram_length': spectrogram_length,
        'dct_coefficient_count': dct_coefficient_count,
        'fingerprint_size': fingerprint_size,
        'label_count': label_count,
        'sample_rate': sample_rate,
    }