|
|
@@ -157,6 +157,12 @@ class Model: |
|
|
if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32']: |
|
|
if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32']: |
|
|
raise ValueError(f"Unsupported arg '{arg}'") |
|
|
raise ValueError(f"Unsupported arg '{arg}'") |
|
|
|
|
|
|
|
|
|
|
|
def _check_reuse_dataset(self, dataset): |
|
|
|
|
|
if not hasattr(dataset, '__model_hash__'): |
|
|
|
|
|
dataset.__model_hash__ = hash(self) |
|
|
|
|
|
if hasattr(dataset, '__model_hash__') and dataset.__model_hash__ != hash(self): |
|
|
|
|
|
raise RuntimeError('The Dataset cannot be bound to different models, please create a new dataset.') |
|
|
|
|
|
|
|
|
def _build_train_network(self): |
|
|
def _build_train_network(self): |
|
|
"""Build train network""" |
|
|
"""Build train network""" |
|
|
network = self._network |
|
|
network = self._network |
|
|
@@ -388,6 +394,7 @@ class Model: |
|
|
"So the training process will be performed with dataset not sink.") |
|
|
"So the training process will be performed with dataset not sink.") |
|
|
self._train_process(epoch, train_dataset, list_callback, cb_params) |
|
|
self._train_process(epoch, train_dataset, list_callback, cb_params) |
|
|
else: |
|
|
else: |
|
|
|
|
|
self._check_reuse_dataset(train_dataset) |
|
|
self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params, sink_size) |
|
|
self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params, sink_size) |
|
|
|
|
|
|
|
|
@staticmethod |
|
|
@staticmethod |
|
|
|