|
|
|
@@ -16,6 +16,7 @@ |
|
|
|
from collections.abc import Iterable |
|
|
|
|
|
|
|
import os |
|
|
|
import math |
|
|
|
import numpy as np |
|
|
|
|
|
|
|
from mindspore import log as logger |
|
|
|
@@ -402,7 +403,7 @@ class Model: |
|
|
|
if sink_size == -1: |
|
|
|
epoch_num = epoch |
|
|
|
else: |
|
|
|
epoch_num = epoch * sink_size // train_dataset.get_dataset_size() |
|
|
|
epoch_num = math.ceil(epoch * sink_size / train_dataset.get_dataset_size()) |
|
|
|
|
|
|
|
dataset_helper, train_network = self._exec_preprocess(self._train_network, |
|
|
|
is_train=True, |
|
|
|
|