|
|
|
@@ -36,7 +36,6 @@ from src.warmup_cosine_annealing_lr import warmup_cosine_annealing_lr |
|
|
|
from src.utils.logging import get_logger |
|
|
|
from src.utils.optimizers__init__ import get_param_groups |
|
|
|
from src.image_classification import get_network |
|
|
|
from src.utils.auto_mixed_precision import auto_mixed_precision |
|
|
|
from src.config import config |
|
|
|
|
|
|
|
|
|
|
|
@@ -273,8 +272,8 @@ def train(cloud_args=None): |
|
|
|
model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager, |
|
|
|
metrics={'acc'}, amp_level="O3") |
|
|
|
else: |
|
|
|
auto_mixed_precision(network) |
|
|
|
model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager, metrics={'acc'}) |
|
|
|
model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager, |
|
|
|
metrics={'acc'}, amp_level="O2") |
|
|
|
|
|
|
|
# checkpoint save |
|
|
|
progress_cb = ProgressMonitor(args) |
|
|
|
|