Browse Source

[to #42322933] test: use custom config to reduce test time

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10011826
master
bin.xue yingda.chen 3 years ago
parent
commit
b870e4eed5
3 changed files with 15 additions and 6 deletions
  1. +3
    -3
      modelscope/models/audio/ans/complex_nn.py
  2. +3
    -2
      modelscope/models/audio/ans/unet.py
  3. +9
    -1
      tests/trainers/audio/test_ans_trainer.py

+ 3
- 3
modelscope/models/audio/ans/complex_nn.py View File

@@ -1,7 +1,7 @@
"""
class ComplexConv2d, ComplexConvTranspose2d and ComplexBatchNorm2d are the work of
Jongho Choi(sweetcocoa@snu.ac.kr / Seoul National Univ., ESTsoft ).
from https://github.com/sweetcocoa/DeepComplexUNetPyTorch
The implementation of class ComplexConv2d, ComplexConvTranspose2d and ComplexBatchNorm2d
here is modified based on Jongho Choi(sweetcocoa@snu.ac.kr / Seoul National Univ., ESTsoft )
and publicly available at https://github.com/sweetcocoa/DeepComplexUNetPyTorch

"""
import torch


+ 3
- 2
modelscope/models/audio/ans/unet.py View File

@@ -1,6 +1,7 @@
"""
Based on the work of Jongho Choi(sweetcocoa@snu.ac.kr / Seoul National Univ., ESTsoft ).
from https://github.com/sweetcocoa/DeepComplexUNetPyTorch
The implementation here is modified based on
Jongho Choi(sweetcocoa@snu.ac.kr / Seoul National Univ., ESTsoft )
and publicly available at https://github.com/sweetcocoa/DeepComplexUNetPyTorch
"""
import torch
import torch.nn as nn


+ 9
- 1
tests/trainers/audio/test_ans_trainer.py View File

@@ -8,12 +8,14 @@ from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.utils.audio.audio_utils import to_segment
from modelscope.utils.hub import read_config
from modelscope.utils.test_utils import test_level

SEGMENT_LENGTH_TEST = 640


class TestANSTrainer(unittest.TestCase):
REVISION = 'beta'

def setUp(self):
self.tmp_dir = tempfile.TemporaryDirectory().name
@@ -21,6 +23,11 @@ class TestANSTrainer(unittest.TestCase):
os.makedirs(self.tmp_dir)

self.model_id = 'damo/speech_frcrn_ans_cirm_16k'
cfg = read_config(self.model_id, revision=self.REVISION)
cfg.train.max_epochs = 2
cfg.train.dataloader.batch_size_per_gpu = 1
self.cfg_file = os.path.join(self.tmp_dir, 'train_config.json')
cfg.dump(self.cfg_file)

hf_ds = MsDataset.load(
'ICASSP_2021_DNS_Challenge', split='test').to_hf_dataset()
@@ -39,12 +46,13 @@ class TestANSTrainer(unittest.TestCase):
def test_trainer(self):
kwargs = dict(
model=self.model_id,
model_revision='beta',
model_revision=self.REVISION,
train_dataset=self.dataset,
eval_dataset=self.dataset,
max_epochs=2,
train_iters_per_epoch=2,
val_iters_per_epoch=1,
cfg_file=self.cfg_file,
work_dir=self.tmp_dir)

trainer = build_trainer(


Loading…
Cancel
Save