Browse Source

fix getnext error in pynative

tags/v1.1.0
chujinjin 5 years ago
parent
commit
07b965dad9
1 changed files with 7 additions and 2 deletions
  1. +7
    -2
      mindspore/train/dataset_helper.py

+ 7
- 2
mindspore/train/dataset_helper.py View File

@@ -93,8 +93,9 @@ def connect_network_with_dataset(network, dataset_helper):
raise RuntimeError("Dataset should be connected with network only in sink mode.") raise RuntimeError("Dataset should be connected with network only in sink mode.")


if (hasattr(dataset_iter, "sink_size") and dataset_iter.sink_size == 1) \ if (hasattr(dataset_iter, "sink_size") and dataset_iter.sink_size == 1) \
and (hasattr(dataset_iter, "sink_count") and dataset_iter.sink_count == 1) \
and context.get_context("device_target") == "Ascend":
and (hasattr(dataset_iter, "sink_count") and dataset_iter.sink_count == 1) \
and context.get_context("device_target") == "Ascend" \
and context.get_context("mode") == context.GRAPH_MODE:


if not hasattr(dataset, '__network__'): if not hasattr(dataset, '__network__'):
dataset.__network__ = network dataset.__network__ = network
@@ -206,6 +207,7 @@ class DatasetHelper:
def get_data_info(self): def get_data_info(self):
return self.iter.get_data_info() return self.iter.get_data_info()



class _DatasetIter: class _DatasetIter:
"""Base iter for dataset helper""" """Base iter for dataset helper"""


@@ -286,6 +288,7 @@ class _DatasetIterGE(_DatasetIter):


self.op = op self.op = op



class _DatasetIterPyNative(_DatasetIter): class _DatasetIterPyNative(_DatasetIter):
"""Iter for MS(enable_loop_sink=False).""" """Iter for MS(enable_loop_sink=False)."""


@@ -301,6 +304,7 @@ class _DatasetIterPyNative(_DatasetIter):


self.op = op self.op = op



class _DatasetIterMSLoopSink(_DatasetIter): class _DatasetIterMSLoopSink(_DatasetIter):
"""Iter for context (device_target=Ascend)""" """Iter for context (device_target=Ascend)"""


@@ -354,6 +358,7 @@ class _DatasetIterPSLite(_DatasetIter):


class _DatasetIterNormal: class _DatasetIterNormal:
"""Iter for normal(non sink) mode, feed the data from host.""" """Iter for normal(non sink) mode, feed the data from host."""

def __init__(self, dataset, epoch_num=-1): def __init__(self, dataset, epoch_num=-1):
self.dataset = dataset self.dataset = dataset
self.device_num = _get_device_num() self.device_num = _get_device_num()


Loading…
Cancel
Save