
!10271 Fix copy _batch_size of dataset

From: @luoyang42
Reviewed-by: @heleiwang,@liucunwei
Signed-off-by: @liucunwei
tags/v1.1.0
Committed by mindspore-ci-bot
commit 12b3a4325c
1 changed file with 4 additions and 0 deletions

mindspore/dataset/engine/datasets.py

@@ -1359,6 +1359,9 @@ class Dataset:
     def input_indexs(self, value):
         self._input_indexs = value
 
+    def copy_batch_size(self, value):
+        self._batch_size = value
+
     def _init_tree_getters(self):
         """
         Get pipeline information.
@@ -1931,6 +1934,7 @@ class BatchDataset(Dataset):
         new_op.saved_output_types = self.saved_output_types
         new_op.saved_output_shapes = self.saved_output_shapes
         new_op.input_indexs = copy.deepcopy(self._input_indexs, memodict)
+        new_op.copy_batch_size(copy.deepcopy(self._batch_size, memodict))
         new_op.dataset_size = self.dataset_size
         new_op.pad = self.pad
         new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict)
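
The patch adds a small setter, Dataset.copy_batch_size, and calls it from BatchDataset's deepcopy routine so that the private _batch_size attribute is carried over to the copied operator. Below is a minimal, self-contained sketch (toy classes, not the real MindSpore Dataset/BatchDataset) of the pattern the patch follows: a copy routine that transfers attributes explicitly, where any attribute that is not copied, as _batch_size was before this fix, is silently lost on the clone.

import copy

class Dataset:
    """Toy stand-in for the Dataset base class (illustration only)."""

    def __init__(self):
        self._batch_size = None

    def copy_batch_size(self, value):
        # Mirrors the helper added by the patch: one place that assigns
        # the private _batch_size attribute on a copied operator.
        self._batch_size = value


class BatchDataset(Dataset):
    """Toy stand-in for an operator that copies its attributes one by one."""

    def __init__(self, batch_size):
        super().__init__()
        self._batch_size = batch_size

    def __deepcopy__(self, memodict):
        # Build a bare copy and transfer attributes explicitly, as the real
        # operator does; forgetting one (e.g. _batch_size) would silently
        # drop it on the copy.
        new_op = self.__class__.__new__(self.__class__)
        Dataset.__init__(new_op)
        new_op.copy_batch_size(copy.deepcopy(self._batch_size, memodict))
        return new_op


original = BatchDataset(batch_size=32)
clone = copy.deepcopy(original)
assert clone._batch_size == 32  # with the explicit copy, the batch size survives

Without the explicit copy_batch_size call in the sketch above, the clone's _batch_size would stay at its default (None), which is the symptom this commit addresses for deep-copied BatchDataset operators.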

