Browse Source

fix bug of using pool

pull/13823/head
liyong 4 years ago
parent
commit
dc03f62f4b
1 changed files with 14 additions and 0 deletions
  1. +14
    -0
      mindspore/dataset/engine/datasets.py

+ 14
- 0
mindspore/dataset/engine/datasets.py View File

@@ -176,6 +176,15 @@ class Dataset:
_init_device_info()
return ir_tree, dataset

def close_pool(self):
"""
Close multiprocessing pool in dataset.
"""
if hasattr(self, 'process_pool') and self.process_pool is not None:
self.process_pool.close()
for child in self.children:
child.close_pool()

@staticmethod
def _get_operator_id(dataset):
"""
@@ -1448,6 +1457,7 @@ class Dataset:
if self._col_names is None:
runtime_getter = self._init_tree_getters()
self._col_names = runtime_getter[0].GetColumnNames()
self.close_pool()
return self._col_names

def output_shapes(self):
@@ -1461,6 +1471,7 @@ class Dataset:
runtime_getter = self._init_tree_getters()
self.saved_output_shapes = runtime_getter[0].GetOutputShapes()
self.saved_output_types = runtime_getter[0].GetOutputTypes()
self.close_pool()
return self.saved_output_shapes

def output_types(self):
@@ -1474,6 +1485,7 @@ class Dataset:
runtime_getter = self._init_tree_getters()
self.saved_output_shapes = runtime_getter[0].GetOutputShapes()
self.saved_output_types = runtime_getter[0].GetOutputTypes()
self.close_pool()
return self.saved_output_types

def get_dataset_size(self):
@@ -1486,6 +1498,7 @@ class Dataset:
if self.dataset_size is None:
runtime_getter = self._init_size_getter()
self.dataset_size = runtime_getter[0].GetDatasetSize(False)
self.close_pool()
return self.dataset_size

def num_classes(self):
@@ -1498,6 +1511,7 @@ class Dataset:
if self._num_classes is None:
runtime_getter = self._init_tree_getters()
self._num_classes = runtime_getter[0].GetNumClasses()
self.close_pool()
if self._num_classes == -1:
return None
return self._num_classes


Loading…
Cancel
Save