diff --git a/modelscope/models/audio/ans/complex_nn.py b/modelscope/models/audio/ans/complex_nn.py index c61446c2..9768eff7 100644 --- a/modelscope/models/audio/ans/complex_nn.py +++ b/modelscope/models/audio/ans/complex_nn.py @@ -1,7 +1,7 @@ """ -class ComplexConv2d, ComplexConvTranspose2d and ComplexBatchNorm2d are the work of -Jongho Choi(sweetcocoa@snu.ac.kr / Seoul National Univ., ESTsoft ). -from https://github.com/sweetcocoa/DeepComplexUNetPyTorch +The implementation of classes ComplexConv2d, ComplexConvTranspose2d and ComplexBatchNorm2d +here is modified from the work of Jongho Choi (sweetcocoa@snu.ac.kr / Seoul National Univ., ESTsoft), +which is publicly available at https://github.com/sweetcocoa/DeepComplexUNetPyTorch """ import torch diff --git a/modelscope/models/audio/ans/unet.py b/modelscope/models/audio/ans/unet.py index ae66eb69..3a9c5549 100644 --- a/modelscope/models/audio/ans/unet.py +++ b/modelscope/models/audio/ans/unet.py @@ -1,6 +1,7 @@ """ -Based on the work of Jongho Choi(sweetcocoa@snu.ac.kr / Seoul National Univ., ESTsoft ). 
-from https://github.com/sweetcocoa/DeepComplexUNetPyTorch +The implementation here is modified from the work of +Jongho Choi (sweetcocoa@snu.ac.kr / Seoul National Univ., ESTsoft), +which is publicly available at https://github.com/sweetcocoa/DeepComplexUNetPyTorch """ import torch import torch.nn as nn diff --git a/tests/trainers/audio/test_ans_trainer.py b/tests/trainers/audio/test_ans_trainer.py index 176c811f..ed8cd1fe 100644 --- a/tests/trainers/audio/test_ans_trainer.py +++ b/tests/trainers/audio/test_ans_trainer.py @@ -8,12 +8,14 @@ from modelscope.metainfo import Trainers from modelscope.msdatasets import MsDataset from modelscope.trainers import build_trainer from modelscope.utils.audio.audio_utils import to_segment +from modelscope.utils.hub import read_config from modelscope.utils.test_utils import test_level SEGMENT_LENGTH_TEST = 640 class TestANSTrainer(unittest.TestCase): + REVISION = 'beta' def setUp(self): self.tmp_dir = tempfile.TemporaryDirectory().name @@ -21,6 +23,11 @@ class TestANSTrainer(unittest.TestCase): os.makedirs(self.tmp_dir) self.model_id = 'damo/speech_frcrn_ans_cirm_16k' + cfg = read_config(self.model_id, revision=self.REVISION) + cfg.train.max_epochs = 2 + cfg.train.dataloader.batch_size_per_gpu = 1 + self.cfg_file = os.path.join(self.tmp_dir, 'train_config.json') + cfg.dump(self.cfg_file) hf_ds = MsDataset.load( 'ICASSP_2021_DNS_Challenge', split='test').to_hf_dataset() @@ -39,12 +46,13 @@ class TestANSTrainer(unittest.TestCase): def test_trainer(self): kwargs = dict( model=self.model_id, - model_revision='beta', + model_revision=self.REVISION, train_dataset=self.dataset, eval_dataset=self.dataset, max_epochs=2, train_iters_per_epoch=2, val_iters_per_epoch=1, + cfg_file=self.cfg_file, work_dir=self.tmp_dir) trainer = build_trainer(