You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

eval.py 5.5 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ===========================================================================
  15. """
  16. Eval DeepSpeech2
  17. """
  18. import argparse
  19. import json
  20. import pickle
  21. import numpy as np
  22. from src.config import eval_config
  23. from src.deepspeech2 import DeepSpeechModel, PredictWithSoftmax
  24. from src.dataset import create_dataset
  25. from src.greedydecoder import MSGreedyDecoder
  26. from mindspore import context
  27. from mindspore.train.serialization import load_checkpoint, load_param_into_net
  28. parser = argparse.ArgumentParser(description='DeepSpeech evaluation')
  29. parser.add_argument('--bidirectional', action="store_false", default=True, help='Use bidirectional RNN')
  30. parser.add_argument('--pretrain_ckpt', type=str, default='', help='Pretrained checkpoint path')
  31. parser.add_argument('--device_target', type=str, default="GPU", choices=("GPU", "CPU"),
  32. help='Device target, support GPU and CPU, Default: GPU')
  33. args = parser.parse_args()
  34. if __name__ == '__main__':
  35. context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=False)
  36. config = eval_config
  37. with open(config.DataConfig.labels_path) as label_file:
  38. labels = json.load(label_file)
  39. model = PredictWithSoftmax(DeepSpeechModel(batch_size=config.DataConfig.batch_size,
  40. rnn_hidden_size=config.ModelConfig.hidden_size,
  41. nb_layers=config.ModelConfig.hidden_layers,
  42. labels=labels,
  43. rnn_type=config.ModelConfig.rnn_type,
  44. audio_conf=config.DataConfig.SpectConfig,
  45. bidirectional=args.bidirectional))
  46. ds_eval = create_dataset(audio_conf=config.DataConfig.SpectConfig,
  47. manifest_filepath=config.DataConfig.test_manifest,
  48. labels=labels, normalize=True, train_mode=False,
  49. batch_size=config.DataConfig.batch_size, rank=0, group_size=1)
  50. param_dict = load_checkpoint(args.pretrain_ckpt)
  51. load_param_into_net(model, param_dict)
  52. print('Successfully loading the pre-trained model')
  53. if config.LMConfig.decoder_type == 'greedy':
  54. decoder = MSGreedyDecoder(labels=labels, blank_index=labels.index('_'))
  55. else:
  56. raise NotImplementedError("Only greedy decoder is supported now")
  57. target_decoder = MSGreedyDecoder(labels, blank_index=labels.index('_'))
  58. model.set_train(False)
  59. total_cer, total_wer, num_tokens, num_chars = 0, 0, 0, 0
  60. output_data = []
  61. for data in ds_eval.create_dict_iterator():
  62. inputs, input_length, target_indices, targets = data['inputs'], data['input_length'], data['target_indices'], \
  63. data['label_values']
  64. split_targets = []
  65. start, count, last_id = 0, 0, 0
  66. target_indices, targets = target_indices.asnumpy(), targets.asnumpy()
  67. for i in range(np.shape(targets)[0]):
  68. if target_indices[i, 0] == last_id:
  69. count += 1
  70. else:
  71. split_targets.append(list(targets[start:count]))
  72. last_id += 1
  73. start = count
  74. count += 1
  75. split_targets.append(list(targets[start:]))
  76. out, output_sizes = model(inputs, input_length)
  77. decoded_output, _ = decoder.decode(out, output_sizes)
  78. target_strings = target_decoder.convert_to_strings(split_targets)
  79. if config.save_output is not None:
  80. output_data.append((out.asnumpy(), output_sizes.asnumpy(), target_strings))
  81. for doutput, toutput in zip(decoded_output, target_strings):
  82. transcript, reference = doutput[0], toutput[0]
  83. wer_inst = decoder.wer(transcript, reference)
  84. cer_inst = decoder.cer(transcript, reference)
  85. total_wer += wer_inst
  86. total_cer += cer_inst
  87. num_tokens += len(reference.split())
  88. num_chars += len(reference.replace(' ', ''))
  89. if config.verbose:
  90. print("Ref:", reference.lower())
  91. print("Hyp:", transcript.lower())
  92. print("WER:", float(wer_inst) / len(reference.split()),
  93. "CER:", float(cer_inst) / len(reference.replace(' ', '')), "\n")
  94. wer = float(total_wer) / num_tokens
  95. cer = float(total_cer) / num_chars
  96. print('Test Summary \t'
  97. 'Average WER {wer:.3f}\t'
  98. 'Average CER {cer:.3f}\t'.format(wer=wer * 100, cer=cer * 100))
  99. if config.save_output is not None:
  100. with open(config.save_output + '.bin', 'wb') as output:
  101. pickle.dump(output_data, output)