| @@ -16,6 +16,7 @@ | |||||
| from collections.abc import Iterable | from collections.abc import Iterable | ||||
| import os | import os | ||||
| import math | |||||
| import numpy as np | import numpy as np | ||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| @@ -402,7 +403,7 @@ class Model: | |||||
| if sink_size == -1: | if sink_size == -1: | ||||
| epoch_num = epoch | epoch_num = epoch | ||||
| else: | else: | ||||
| epoch_num = epoch * sink_size // train_dataset.get_dataset_size() | |||||
| epoch_num = math.ceil(epoch * sink_size / train_dataset.get_dataset_size()) | |||||
| dataset_helper, train_network = self._exec_preprocess(self._train_network, | dataset_helper, train_network = self._exec_preprocess(self._train_network, | ||||
| is_train=True, | is_train=True, | ||||