|
|
|
@@ -24,6 +24,7 @@ from mindspore.nn import WithLossCell, TrainOneStepCell |
|
|
|
from mindspore.nn.optim.momentum import Momentum |
|
|
|
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits |
|
|
|
from mindspore.common import dtype as mstype |
|
|
|
from mindspore.communication.management import get_rank |
|
|
|
from mindspore.train.model import Model |
|
|
|
from mindspore.train.loss_scale_manager import FixedLossScaleManager |
|
|
|
from mindspore.train.serialization import save_checkpoint |
|
|
|
@@ -94,10 +95,6 @@ if __name__ == '__main__': |
|
|
|
features_path = args_opt.dataset_path + '_features' |
|
|
|
idx_list = list(range(step_size)) |
|
|
|
|
|
|
|
if os.path.isdir(config.save_checkpoint_path): |
|
|
|
os.rename(config.save_checkpoint_path, "{}_{}".format(config.save_checkpoint_path, time.time())) |
|
|
|
os.mkdir(config.save_checkpoint_path) |
|
|
|
|
|
|
|
for epoch in range(epoch_size): |
|
|
|
random.shuffle(idx_list) |
|
|
|
epoch_start = time.time() |
|
|
|
@@ -112,7 +109,11 @@ if __name__ == '__main__': |
|
|
|
.format(epoch + 1, step_size, epoch_mseconds, per_step_mseconds, np.mean(np.array(losses))), \ |
|
|
|
end="") |
|
|
|
if (epoch + 1) % config.save_checkpoint_epochs == 0: |
|
|
|
save_checkpoint(network, os.path.join(config.save_checkpoint_path, \ |
|
|
|
rank = 0 |
|
|
|
if config.run_distribute: |
|
|
|
rank = get_rank() |
|
|
|
save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/') |
|
|
|
save_checkpoint(network, os.path.join(save_ckpt_path, \ |
|
|
|
f"mobilenetv2_head_{epoch+1}.ckpt")) |
|
|
|
print("total cost {:5.4f} s".format(time.time() - start)) |
|
|
|
|
|
|
|
|