Browse Source

Fix copy _batch_size of dataset

tags/v1.1.0
luoyang 5 years ago
parent
commit
47f20f6547
1 changed files with 4 additions and 0 deletions
  1. +4
    -0
      mindspore/dataset/engine/datasets.py

+ 4
- 0
mindspore/dataset/engine/datasets.py View File

@@ -1359,6 +1359,9 @@ class Dataset:
def input_indexs(self, value):
self._input_indexs = value

def copy_batch_size(self, value):
self._batch_size = value

def _init_tree_getters(self):
"""
Get pipeline information.
@@ -1931,6 +1934,7 @@ class BatchDataset(Dataset):
new_op.saved_output_types = self.saved_output_types
new_op.saved_output_shapes = self.saved_output_shapes
new_op.input_indexs = copy.deepcopy(self._input_indexs, memodict)
new_op.copy_batch_size(copy.deepcopy(self._batch_size, memodict))
new_op.dataset_size = self.dataset_size
new_op.pad = self.pad
new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict)


Loading…
Cancel
Save