|
|
|
@@ -8,12 +8,14 @@ from modelscope.metainfo import Trainers |
|
|
|
from modelscope.msdatasets import MsDataset |
|
|
|
from modelscope.trainers import build_trainer |
|
|
|
from modelscope.utils.audio.audio_utils import to_segment |
|
|
|
from modelscope.utils.hub import read_config |
|
|
|
from modelscope.utils.test_utils import test_level |
|
|
|
|
|
|
|
SEGMENT_LENGTH_TEST = 640 |
|
|
|
|
|
|
|
|
|
|
|
class TestANSTrainer(unittest.TestCase): |
|
|
|
REVISION = 'beta' |
|
|
|
|
|
|
|
def setUp(self): |
|
|
|
self.tmp_dir = tempfile.TemporaryDirectory().name |
|
|
|
@@ -21,6 +23,11 @@ class TestANSTrainer(unittest.TestCase): |
|
|
|
os.makedirs(self.tmp_dir) |
|
|
|
|
|
|
|
self.model_id = 'damo/speech_frcrn_ans_cirm_16k' |
|
|
|
cfg = read_config(self.model_id, revision=self.REVISION) |
|
|
|
cfg.train.max_epochs = 2 |
|
|
|
cfg.train.dataloader.batch_size_per_gpu = 1 |
|
|
|
self.cfg_file = os.path.join(self.tmp_dir, 'train_config.json') |
|
|
|
cfg.dump(self.cfg_file) |
|
|
|
|
|
|
|
hf_ds = MsDataset.load( |
|
|
|
'ICASSP_2021_DNS_Challenge', split='test').to_hf_dataset() |
|
|
|
@@ -39,12 +46,13 @@ class TestANSTrainer(unittest.TestCase): |
|
|
|
def test_trainer(self): |
|
|
|
kwargs = dict( |
|
|
|
model=self.model_id, |
|
|
|
model_revision='beta', |
|
|
|
model_revision=self.REVISION, |
|
|
|
train_dataset=self.dataset, |
|
|
|
eval_dataset=self.dataset, |
|
|
|
max_epochs=2, |
|
|
|
train_iters_per_epoch=2, |
|
|
|
val_iters_per_epoch=1, |
|
|
|
cfg_file=self.cfg_file, |
|
|
|
work_dir=self.tmp_dir) |
|
|
|
|
|
|
|
trainer = build_trainer( |
|
|
|
|