diff --git a/mindspore/train/model.py b/mindspore/train/model.py index b99dffadce..b9014aece0 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -438,18 +438,23 @@ class Model: cb_params (_InternalCallbackParam): Callback parameters. Default: None. sink_size (int): Control the amount of data in each sink. Default: -1. """ + is_graph = (context.get_context("mode") == context.GRAPH_MODE) if sink_size == -1: epoch_num = epoch else: - epoch_num = math.ceil(epoch * sink_size / train_dataset.get_dataset_size()) - train_dataset.__total_batch__ = epoch * sink_size + if is_graph: + epoch_num = math.ceil(epoch * sink_size / train_dataset.get_dataset_size()) + train_dataset.__total_batch__ = epoch * sink_size + else: + sink_size = -1 + epoch_num = epoch + logger.warning("Loop sink is not supported in PyNative mode, so it will be performed with no loop sink") cb_params.cur_step_num = 0 cb_params.dataset_sink_mode = True run_context = RunContext(cb_params) list_callback.begin(run_context) - is_graph = (context.get_context("mode") == context.GRAPH_MODE) # used to stop training for early stop, such as stopAtTIme or stopATStep should_stop = False dataset_helper = None