Browse Source

use from_numpy and add do_copy option

tags/v1.1.0
ms_yan 5 years ago
parent
commit
deb1e6e965
13 changed files with 23 additions and 18 deletions
  1. +5
    -3
      mindspore/dataset/engine/datasets.py
  2. +7
    -4
      mindspore/dataset/engine/iterators.py
  3. +1
    -1
      mindspore/dataset/engine/validators.py
  4. +1
    -1
      mindspore/train/dataset_helper.py
  5. +1
    -1
      tests/dataset_mock.py
  6. +1
    -1
      tests/st/auto_parallel/optimizer_parallel.py
  7. +1
    -1
      tests/st/auto_parallel/parallel_strategy_search.py
  8. +1
    -1
      tests/st/pynative/loss_scale/test_loss_scale.py
  9. +1
    -1
      tests/ut/python/parallel/test_auto_parallel_double_subgraphs.py
  10. +1
    -1
      tests/ut/python/parallel/test_auto_parallel_resnet.py
  11. +1
    -1
      tests/ut/python/parallel/test_bias_add.py
  12. +1
    -1
      tests/ut/python/parallel/test_gather_v2_primitive.py
  13. +1
    -1
      tests/ut/python/parallel/test_pipeline_split.py

+ 5
- 3
mindspore/dataset/engine/datasets.py View File

@@ -1255,7 +1255,7 @@ class Dataset:
del api_tree

@check_tuple_iterator
def create_tuple_iterator(self, columns=None, num_epochs=-1, output_numpy=False):
def create_tuple_iterator(self, columns=None, num_epochs=-1, output_numpy=False, do_copy=True):
"""
Create an iterator over the dataset. The data retrieved will be a list of ndarrays of data.

@@ -1269,6 +1269,8 @@ class Dataset:
(default=-1, iterator can be iterated infinite number of epochs)
output_numpy (bool, optional): Whether or not to output NumPy datatype.
If output_numpy=False, iterator will output MSTensor (default=False).
do_copy (bool, optional): when the output data type is mindspore.Tensor,
this parameter selects the conversion method; set it to False for better performance (default=True).

Returns:
Iterator, list of ndarrays.
@@ -1290,7 +1292,7 @@ class Dataset:

if Dataset._noop_mode():
return DummyIterator(self, 'tuple')
return TupleIterator(self, columns, num_epochs, output_numpy)
return TupleIterator(self, columns, num_epochs, output_numpy, do_copy)

@check_dict_iterator
def create_dict_iterator(self, num_epochs=-1, output_numpy=False):
@@ -2788,7 +2790,7 @@ class TransferDataset(Dataset):
def create_dict_iterator(self, num_epochs=-1, output_numpy=False):
raise RuntimeError("TransferDataset is not iterable.")

def create_tuple_iterator(self, columns=None, num_epochs=-1, output_numpy=False):
def create_tuple_iterator(self, columns=None, num_epochs=-1, output_numpy=False, do_copy=True):
raise RuntimeError("TransferDataset is not iterable.")

def __iter__(self):


+ 7
- 4
mindspore/dataset/engine/iterators.py View File

@@ -63,7 +63,7 @@ class Iterator:
dataset: Dataset to be iterated over
"""

def __init__(self, dataset, num_epochs=-1, output_numpy=False):
def __init__(self, dataset, num_epochs=-1, output_numpy=False, do_copy=True):
self._col_names = None

# create a copy of tree and work on it.
@@ -80,7 +80,10 @@ class Iterator:

self._transform_tensor = lambda t: t.as_array()
if not output_numpy:
self._transform_tensor = lambda t: Tensor(t.as_array())
if do_copy:
self._transform_tensor = lambda t: Tensor(t.as_array())
else:
self._transform_tensor = lambda t: Tensor.from_numpy(t.as_array())
self._index = 0

# todo remove next when ContextManager is done
@@ -179,13 +182,13 @@ class TupleIterator(Iterator):
The derived class of Iterator with list type.
"""

def __init__(self, dataset, columns=None, num_epochs=-1, output_numpy=False):
def __init__(self, dataset, columns=None, num_epochs=-1, output_numpy=False, do_copy=True):
if columns is not None:
if not isinstance(columns, list):
columns = [columns]
# todo: move next to IR
dataset = dataset.project(columns)
super().__init__(dataset, num_epochs, output_numpy)
super().__init__(dataset, num_epochs, output_numpy, do_copy)

def _get_next(self):
"""


+ 1
- 1
mindspore/dataset/engine/validators.py View File

@@ -298,7 +298,7 @@ def check_tuple_iterator(method):

@wraps(method)
def new_method(self, *args, **kwargs):
[columns, num_epochs, _], param_dict = parse_user_args(method, *args, **kwargs)
[columns, num_epochs, _, _], param_dict = parse_user_args(method, *args, **kwargs)
nreq_param_bool = ['output_numpy']
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
if num_epochs is not None:


+ 1
- 1
mindspore/train/dataset_helper.py View File

@@ -394,7 +394,7 @@ class _DatasetIterNormal:
self.dataset = dataset
self.device_num = _get_device_num()
self.global_rank = _get_global_rank()
self.iter = self.dataset.create_tuple_iterator(num_epochs=epoch_num)
self.iter = self.dataset.create_tuple_iterator(num_epochs=epoch_num, do_copy=False)

def __iter__(self):
return self


+ 1
- 1
tests/dataset_mock.py View File

@@ -55,7 +55,7 @@ class MindData:
self.send_epoch_end = send_epoch_end
return self

def create_tuple_iterator(self, num_epochs=-1):
def create_tuple_iterator(self, num_epochs=-1, do_copy=True):
return self.__iter__()

def send(self, num_epochs=-1):


+ 1
- 1
tests/st/auto_parallel/optimizer_parallel.py View File

@@ -125,7 +125,7 @@ class FakeData:
def set_label_onehot(self, is_onehot=True):
self.is_onehot = is_onehot

def create_tuple_iterator(self, num_epochs=-1):
def create_tuple_iterator(self, num_epochs=-1, do_copy=True):
_ = num_epochs
return self



+ 1
- 1
tests/st/auto_parallel/parallel_strategy_search.py View File

@@ -128,7 +128,7 @@ class FakeData:
def set_label_onehot(self, is_onehot=True):
self.is_onehot = is_onehot

def create_tuple_iterator(self, num_epochs=-1):
def create_tuple_iterator(self, num_epochs=-1, do_copy=True):
_ = num_epochs
return self



+ 1
- 1
tests/st/pynative/loss_scale/test_loss_scale.py View File

@@ -60,7 +60,7 @@ class MindData:
def output_shapes(self):
return self._output_shapes

def create_tuple_iterator(self, num_epochs=-1):
def create_tuple_iterator(self, num_epochs=-1, do_copy=True):
return self

@property


+ 1
- 1
tests/ut/python/parallel/test_auto_parallel_double_subgraphs.py View File

@@ -152,7 +152,7 @@ class DatasetLenet():
def get_repeat_count(self):
return 1

def create_tuple_iterator(self, num_epochs=-1):
def create_tuple_iterator(self, num_epochs=-1, do_copy=True):
return self

def test_double_subgraphs_train():


+ 1
- 1
tests/ut/python/parallel/test_auto_parallel_resnet.py View File

@@ -275,7 +275,7 @@ class DatasetLenet():
def get_repeat_count(self):
return 1

def create_tuple_iterator(self, num_epochs=-1):
def create_tuple_iterator(self, num_epochs=-1, do_copy=True):
return self




+ 1
- 1
tests/ut/python/parallel/test_bias_add.py View File

@@ -61,7 +61,7 @@ class DatasetLenet():
def get_repeat_count(self):
return 1

def create_tuple_iterator(self, num_epochs=-1):
def create_tuple_iterator(self, num_epochs=-1, do_copy=True):
return self




+ 1
- 1
tests/ut/python/parallel/test_gather_v2_primitive.py View File

@@ -59,7 +59,7 @@ class Dataset():
def get_repeat_count(self):
return 1

def create_tuple_iterator(self, num_epochs=-1):
def create_tuple_iterator(self, num_epochs=-1, do_copy=True):
return self




+ 1
- 1
tests/ut/python/parallel/test_pipeline_split.py View File

@@ -51,7 +51,7 @@ class DatasetLenet():
def get_batch_size(self):
return 32

def create_tuple_iterator(self, num_epochs=1):
def create_tuple_iterator(self, num_epochs=1, do_copy=True):
return self




Loading…
Cancel
Save