|
|
|
@@ -69,11 +69,14 @@ class KWSFarfieldTrainer(BaseTrainer): |
|
|
|
|
|
|
|
super().__init__(cfg_file, arg_parse_fn) |
|
|
|
|
|
|
|
self.model = self.build_model() |
|
|
|
self.work_dir = work_dir |
|
|
|
# the number of model output dimension |
|
|
|
# should update config outside the trainer, if user need more wake word |
|
|
|
num_syn = kwargs.get('num_syn', None) |
|
|
|
if num_syn: |
|
|
|
self.cfg.model.num_syn = num_syn |
|
|
|
self._num_classes = self.cfg.model.num_syn |
|
|
|
self.model = self.build_model() |
|
|
|
self.work_dir = work_dir |
|
|
|
|
|
|
|
if kwargs.get('launcher', None) is not None: |
|
|
|
init_dist(kwargs['launcher']) |
|
|
|
|