diff --git a/mindspore/train/model.py b/mindspore/train/model.py index 46e4f421f7..8ce51562fa 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -254,7 +254,8 @@ class Model: """ # remove later to deal with loop sink need_wrap = False - if not hasattr(train_dataset, '__ME_INITED__') and context.get_context("enable_loop_sink"): + if not hasattr(train_dataset, '__ME_INITED__') and context.get_context("enable_loop_sink") \ + and not context.get_context("enable_ge"): need_wrap = True dataset_helper = DatasetHelper(train_dataset) @@ -418,7 +419,8 @@ class Model: # remove later to deal with loop sink need_wrap = False - if not hasattr(valid_dataset, '__ME_INITED__') and context.get_context("enable_loop_sink"): + if not hasattr(valid_dataset, '__ME_INITED__') and context.get_context("enable_loop_sink") \ + and not context.get_context("enable_ge"): need_wrap = True valid_dataset.__loop_size__ = 1