Browse Source

fix dataset sink size error in pynative

pull/15500/head
chujinjin 4 years ago
parent
commit
a6db82aeee
2 changed files with 12 additions and 7 deletions
  1. +10
    -0
      mindspore/train/dataset_helper.py
  2. +2
    -7
      mindspore/train/model.py

+ 10
- 0
mindspore/train/dataset_helper.py View File

@@ -154,6 +154,16 @@ def connect_network_with_dataset(network, dataset_helper):
dataset.__me_inited__ = True
dataset_types, dataset_shapes = dataset_helper.types_shapes()
network = _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name)

if hasattr(dataset_iter, "sink_size") and \
dataset_iter.sink_size == 1 and \
dataset.get_dataset_size() != 1 and \
hasattr(dataset_iter, "sink_count") and \
dataset_iter.sink_count == 1 and \
context.get_context("device_target") == "Ascend" and \
context.get_context("mode") == context.PYNATIVE_MODE:
dataset_helper.get_data_info()

return network




+ 2
- 7
mindspore/train/model.py View File

@@ -442,13 +442,8 @@ class Model:
if sink_size == -1:
epoch_num = epoch
else:
if is_graph:
epoch_num = math.ceil(epoch * sink_size / train_dataset.get_dataset_size())
train_dataset.__total_batch__ = epoch * sink_size
else:
sink_size = -1
epoch_num = epoch
logger.warning("Loop sink is not supported in PyNative mode, so it will be performed with no loop sink")
epoch_num = math.ceil(epoch * sink_size / train_dataset.get_dataset_size())
train_dataset.__total_batch__ = epoch * sink_size

cb_params.cur_step_num = 0
cb_params.dataset_sink_mode = True


Loading…
Cancel
Save