Refactor tts task inputs
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9412937
Target branch: master
@@ -25,22 +25,19 @@ class TextToSpeechSambertHifiganPipeline(Pipeline):
         """
         super().__init__(model=model, **kwargs)
 
-    def forward(self, inputs: Dict[str, str]) -> Dict[str, np.ndarray]:
+    def forward(self, input: str, **forward_params) -> Dict[str, np.ndarray]:
         """synthesis text from inputs with pipeline
         Args:
-            inputs (Dict[str, str]): a dictionary that key is the name of
-                certain testcase and value is the text to synthesis.
+            input (str): text to synthesize
+            forward_params: valid param is 'voice', used to set the speaker voice
         Returns:
-            Dict[str, np.ndarray]: a dictionary with key and value. The key
-                is the same as inputs' key which is the label of the testcase
-                and the value is the pcm audio data.
+            Dict[str, np.ndarray]: {OutputKeys.OUTPUT_PCM: np.ndarray (16-bit pcm data)}
         """
-        output_wav = {}
-        for label, text in inputs.items():
-            output_wav[label] = self.model.forward(text, inputs.get('voice'))
+        output_wav = self.model.forward(input, forward_params.get('voice'))
         return {OutputKeys.OUTPUT_PCM: output_wav}
 
-    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+    def postprocess(self, inputs: Dict[str, Any],
+                    **postprocess_params) -> Dict[str, Any]:
         return inputs
 
     def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
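For reference, a minimal usage sketch of the refactored interface, assuming the standard ModelScope pipeline entry points and the model id and voice used in the test below (illustrative only, not part of this change):

# Sketch of calling the refactored TTS pipeline with a plain text string.
# Model id and voice are taken from the test case below; the rest is
# assumed boilerplate.
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

tts = pipeline(
    task=Tasks.text_to_speech,
    model='damo/speech_sambert-hifigan_tts_zhcn_16k')

# 'input' is now a single string; 'voice' is passed as a forward param.
result = tts(input='今天北京天气怎么样?', voice='zhitian_emo')
pcm = result[OutputKeys.OUTPUT_PCM]  # one np.ndarray of 16-bit PCM samples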
@@ -10,7 +10,7 @@ nara_wpe
 numpy<=1.18
 protobuf>3,<=3.20
 ptflops
-pytorch_wavelets==1.3.0
+pytorch_wavelets
 PyWavelets>=1.0.0
 scikit-learn
 SoundFile>0.10
@@ -24,7 +24,6 @@ class TextToSpeechSambertHifigan16kPipelineTest(unittest.TestCase):
 
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_pipeline(self):
-        single_test_case_label = 'test_case_label_0'
         text = '今天北京天气怎么样?'
         model_id = 'damo/speech_sambert-hifigan_tts_zhcn_16k'
         voice = 'zhitian_emo'
@@ -32,10 +31,9 @@ class TextToSpeechSambertHifigan16kPipelineTest(unittest.TestCase):
         sambert_hifigan_tts = pipeline(
             task=Tasks.text_to_speech, model=model_id)
         self.assertTrue(sambert_hifigan_tts is not None)
-        inputs = {single_test_case_label: text, 'voice': voice}
-        output = sambert_hifigan_tts(inputs)
+        output = sambert_hifigan_tts(input=text, voice=voice)
         self.assertIsNotNone(output[OutputKeys.OUTPUT_PCM])
-        pcm = output[OutputKeys.OUTPUT_PCM][single_test_case_label]
+        pcm = output[OutputKeys.OUTPUT_PCM]
         write('output.wav', 16000, pcm)
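Callers that previously batched several labeled test cases through one dict can loop over the new single-text call instead; a rough sketch under the same assumptions as above (the labels and the second sentence are made up for illustration):

# Sketch: reproducing the old dict-of-testcases behaviour on top of the
# refactored single-string API. Labels and the second text are illustrative.
from scipy.io.wavfile import write

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

tts = pipeline(
    task=Tasks.text_to_speech,
    model='damo/speech_sambert-hifigan_tts_zhcn_16k')

texts = {
    'case_0': '今天北京天气怎么样?',  # "How is the weather in Beijing today?"
    'case_1': '明天天气好吗?',        # illustrative second test sentence
}
for label, text in texts.items():
    pcm = tts(input=text, voice='zhitian_emo')[OutputKeys.OUTPUT_PCM]
    write(f'{label}.wav', 16000, pcm)  # 16 kHz 16-bit PCM, as in the test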