diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 659eac23c3..1a00ee7d50 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -176,6 +176,15 @@ class Dataset: _init_device_info() return ir_tree, dataset + def close_pool(self): + """ + Close multiprocessing pool in dataset. + """ + if hasattr(self, 'process_pool') and self.process_pool is not None: + self.process_pool.close() + for child in self.children: + child.close_pool() + @staticmethod def _get_operator_id(dataset): """ @@ -1448,6 +1457,7 @@ class Dataset: if self._col_names is None: runtime_getter = self._init_tree_getters() self._col_names = runtime_getter[0].GetColumnNames() + self.close_pool() return self._col_names def output_shapes(self): @@ -1461,6 +1471,7 @@ class Dataset: runtime_getter = self._init_tree_getters() self.saved_output_shapes = runtime_getter[0].GetOutputShapes() self.saved_output_types = runtime_getter[0].GetOutputTypes() + self.close_pool() return self.saved_output_shapes def output_types(self): @@ -1474,6 +1485,7 @@ class Dataset: runtime_getter = self._init_tree_getters() self.saved_output_shapes = runtime_getter[0].GetOutputShapes() self.saved_output_types = runtime_getter[0].GetOutputTypes() + self.close_pool() return self.saved_output_types def get_dataset_size(self): @@ -1486,6 +1498,7 @@ class Dataset: if self.dataset_size is None: runtime_getter = self._init_size_getter() self.dataset_size = runtime_getter[0].GetDatasetSize(False) + self.close_pool() return self.dataset_size def num_classes(self): @@ -1498,6 +1511,7 @@ class Dataset: if self._num_classes is None: runtime_getter = self._init_tree_getters() self._num_classes = runtime_getter[0].GetNumClasses() + self.close_pool() if self._num_classes == -1: return None return self._num_classes