Browse Source

!31211 Clean iterators before forking

Merge pull request !31211 from h.farahat/pytest_problem
r1.7
i-robot Gitee 4 years ago
parent
commit
d4fc47ef0c
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 74 additions and 43 deletions
  1. +6
    -6
      mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/ir/datasetops/bindings.cc
  2. +16
    -16
      mindspore/ccsrc/minddata/dataset/api/python/python_mp.h
  3. +4
    -4
      mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc
  4. +4
    -4
      mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc
  5. +16
    -13
      mindspore/python/mindspore/dataset/engine/datasets.py
  6. +28
    -0
      tests/ut/python/dataset/conftest.py

+ 6
- 6
mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/ir/datasetops/bindings.cc View File

@@ -210,12 +210,12 @@ PYBIND_REGISTER(PythonMultiprocessingRuntime, 1, ([](const py::module *m) {
std::shared_ptr<PythonMultiprocessingRuntime>>(
*m, "PythonMultiprocessingRuntime", "to create a PythonMultiprocessingRuntime")
.def(py::init<>())
.def("Launch", &PythonMultiprocessingRuntime::Launch)
.def("Terminate", &PythonMultiprocessingRuntime::Terminate)
.def("IsMPEnabled", &PythonMultiprocessingRuntime::IsMPEnabled)
.def("AddNewWorkers", &PythonMultiprocessingRuntime::AddNewWorkers)
.def("RemoveWorkers", &PythonMultiprocessingRuntime::RemoveWorkers)
.def("GetPIDs", &PythonMultiprocessingRuntime::GetPIDs);
.def("launch", &PythonMultiprocessingRuntime::launch)
.def("terminate", &PythonMultiprocessingRuntime::terminate)
.def("is_mp_enabled", &PythonMultiprocessingRuntime::is_mp_enabled)
.def("add_new_workers", &PythonMultiprocessingRuntime::add_new_workers)
.def("remove_workers", &PythonMultiprocessingRuntime::remove_workers)
.def("get_pids", &PythonMultiprocessingRuntime::get_pids);
}));

PYBIND_REGISTER(ProjectNode, 2, ([](const py::module *m) {


+ 16
- 16
mindspore/ccsrc/minddata/dataset/api/python/python_mp.h View File

@@ -34,12 +34,12 @@ namespace mindspore {
namespace dataset {
class PythonMultiprocessingRuntime {
public:
virtual void Launch(int32_t id) = 0;
virtual void Terminate() = 0;
virtual bool IsMPEnabled() = 0;
virtual void AddNewWorkers(int32_t num_new_workers) = 0;
virtual void RemoveWorkers(int32_t num_removed_workers) = 0;
virtual std::vector<int32_t> GetPIDs() = 0;
virtual void launch(int32_t id) = 0;
virtual void terminate() = 0;
virtual bool is_mp_enabled() = 0;
virtual void add_new_workers(int32_t num_new_workers) = 0;
virtual void remove_workers(int32_t num_removed_workers) = 0;
virtual std::vector<int32_t> get_pids() = 0;
virtual ~PythonMultiprocessingRuntime() {}
};

@@ -51,19 +51,19 @@ class PyPythonMultiprocessingRuntime : public PythonMultiprocessingRuntime {
// Trampoline (need one for each virtual function)
// PYBIND11_OVERLOAD_PURE(void, /* Return type */
// PythonMultiprocessingRuntime, /* Parent class */
// Launch /* Name of function in C++ (must match Python name) */
// launch /* Name of function in C++ (must match Python name) */

void Launch(int32_t id) override { PYBIND11_OVERLOAD_PURE(void, PythonMultiprocessingRuntime, Launch, id); }
void Terminate() override { PYBIND11_OVERLOAD_PURE(void, PythonMultiprocessingRuntime, Terminate); }
bool IsMPEnabled() override { PYBIND11_OVERLOAD_PURE(bool, PythonMultiprocessingRuntime, IsMPEnabled); }
void AddNewWorkers(int32_t num_workers) override {
PYBIND11_OVERLOAD_PURE(void, PythonMultiprocessingRuntime, AddNewWorkers, num_workers);
void launch(int32_t id) override { PYBIND11_OVERLOAD_PURE(void, PythonMultiprocessingRuntime, launch, id); }
void terminate() override { PYBIND11_OVERLOAD_PURE(void, PythonMultiprocessingRuntime, terminate); }
bool is_mp_enabled() override { PYBIND11_OVERLOAD_PURE(bool, PythonMultiprocessingRuntime, is_mp_enabled); }
void add_new_workers(int32_t num_workers) override {
PYBIND11_OVERLOAD_PURE(void, PythonMultiprocessingRuntime, add_new_workers, num_workers);
}
void RemoveWorkers(int32_t num_workers) override {
PYBIND11_OVERLOAD_PURE(void, PythonMultiprocessingRuntime, RemoveWorkers, num_workers);
void remove_workers(int32_t num_workers) override {
PYBIND11_OVERLOAD_PURE(void, PythonMultiprocessingRuntime, remove_workers, num_workers);
}
std::vector<int32_t> GetPIDs() override {
PYBIND11_OVERLOAD_PURE(std::vector<int32_t>, PythonMultiprocessingRuntime, GetPIDs);
std::vector<int32_t> get_pids() override {
PYBIND11_OVERLOAD_PURE(std::vector<int32_t>, PythonMultiprocessingRuntime, get_pids);
}
};
#endif


+ 4
- 4
mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc View File

@@ -630,7 +630,7 @@ Status BatchOp::AddNewWorkers(int32_t num_new_workers) {
RETURN_IF_NOT_OK(ParallelOp::AddNewWorkers(num_new_workers));
if (python_mp_ != nullptr) {
CHECK_FAIL_RETURN_UNEXPECTED(num_new_workers > 0, "Number of workers added should be greater than 0.");
python_mp_->AddNewWorkers(num_new_workers);
python_mp_->add_new_workers(num_new_workers);
}
return Status::OK();
}
@@ -639,7 +639,7 @@ Status BatchOp::RemoveWorkers(int32_t num_workers) {
RETURN_IF_NOT_OK(ParallelOp::RemoveWorkers(num_workers));
if (python_mp_ != nullptr) {
CHECK_FAIL_RETURN_UNEXPECTED(num_workers > 0, "Number of workers removed should be greater than 0.");
python_mp_->RemoveWorkers(num_workers);
python_mp_->remove_workers(num_workers);
}
return Status::OK();
}
@@ -652,14 +652,14 @@ Status BatchOp::Launch() {
// Launch Python multiprocessing. This will create the MP pool and shared memory if needed.
if (python_mp_) {
MS_LOG(DEBUG) << "Launch Python Multiprocessing for BatchOp:" << id();
python_mp_->Launch(id());
python_mp_->launch(id());
}
return DatasetOp::Launch();
}

std::vector<int32_t> BatchOp::GetMPWorkerPIDs() const {
if (python_mp_ != nullptr) {
return python_mp_->GetPIDs();
return python_mp_->get_pids();
}
return DatasetOp::GetMPWorkerPIDs();
}


+ 4
- 4
mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.cc View File

@@ -393,7 +393,7 @@ Status MapOp::AddNewWorkers(int32_t num_new_workers) {
RETURN_IF_NOT_OK(ParallelOp::AddNewWorkers(num_new_workers));
if (python_mp_ != nullptr) {
CHECK_FAIL_RETURN_UNEXPECTED(num_new_workers > 0, "Number of workers added should be greater than 0.");
python_mp_->AddNewWorkers(num_new_workers);
python_mp_->add_new_workers(num_new_workers);
}
return Status::OK();
}
@@ -402,7 +402,7 @@ Status MapOp::RemoveWorkers(int32_t num_workers) {
RETURN_IF_NOT_OK(ParallelOp::RemoveWorkers(num_workers));
if (python_mp_ != nullptr) {
CHECK_FAIL_RETURN_UNEXPECTED(num_workers > 0, "Number of workers removed should be greater than 0.");
python_mp_->RemoveWorkers(num_workers);
python_mp_->remove_workers(num_workers);
}
return Status::OK();
}
@@ -412,14 +412,14 @@ Status MapOp::Launch() {
// launch python multiprocessing. This will create the MP pool and shared memory if needed.
if (python_mp_) {
MS_LOG(DEBUG) << "Launch Python Multiprocessing for MapOp:" << id();
python_mp_->Launch(id());
python_mp_->launch(id());
}
return DatasetOp::Launch();
}

std::vector<int32_t> MapOp::GetMPWorkerPIDs() const {
if (python_mp_ != nullptr) {
return python_mp_->GetPIDs();
return python_mp_->get_pids();
}
return DatasetOp::GetMPWorkerPIDs();
}


+ 16
- 13
mindspore/python/mindspore/dataset/engine/datasets.py View File

@@ -31,6 +31,8 @@ import json
import os
import signal
import stat

import gc
import time
import uuid
import multiprocessing
@@ -1680,7 +1682,6 @@ class Dataset:

logger.warning("Calculating dynamic shape of input data, this will take a few minutes...")
# Assume data1 shape is dynamic, data2 shape is fix
# {"data1": [batch_size, None, feat_len], "data2": [batch_size, feat_len]}
dynamic_columns = self.dynamic_setting[1]
# ["data1", "data2"]
dataset_columns = self.get_col_names()
@@ -2774,7 +2775,7 @@ class _PythonMultiprocessing(cde.PythonMultiprocessingRuntime):
self.ppid = os.getpid()
self.hook = None

def Launch(self, op_id=-1):
def launch(self, op_id=-1):
self.op_id = op_id
logger.info("Launching new Python Multiprocessing pool for Op:" + str(self.op_id))
self.create_pool()
@@ -2791,6 +2792,8 @@ class _PythonMultiprocessing(cde.PythonMultiprocessingRuntime):
if self.process_pool is not None:
raise Exception("Pool was already created, close it first.")

# Let gc collect unreferenced memory now, to avoid having the child processes in the pool do it
gc.collect()
# Construct python multiprocessing pool.
# The _pyfunc_worker_init is used to pass lambda function to subprocesses.
self.process_pool = multiprocessing.Pool(processes=self.num_parallel_workers,
@@ -2810,36 +2813,36 @@ class _PythonMultiprocessing(cde.PythonMultiprocessingRuntime):
if sys.version_info >= (3, 8):
atexit.register(self.process_pool.close)

def Terminate(self):
def terminate(self):
logger.info("Terminating Python Multiprocessing pool for Op:" + str(self.op_id))
self.close_pool()
self.abort_watchdog()
self.delete_shared_memory()
self.process_pool = None

def GetPIDs(self):
def get_pids(self):
# obtain process IDs from multiprocessing.pool
return [w.pid for w in self.workers]

def AddNewWorkers(self, num_new_workers):
def add_new_workers(self, num_new_workers):
logger.info(
"Increasing num_parallel_workers of Python Multiprocessing pool for Op:" + str(self.op_id) +
", old num_workers=" + str(self.num_parallel_workers) + " new num_workers" + str(self.num_parallel_workers +
num_new_workers) + ".")
self.Terminate()
self.terminate()
self.num_parallel_workers += num_new_workers
self.Launch(self.op_id)
self.launch(self.op_id)

def RemoveWorkers(self, num_removed_workers):
def remove_workers(self, num_removed_workers):
logger.info(
"Decreasing num_parallel_workers of Python Multiprocessing pool for Op:" + str(self.op_id) +
", old num_workers=" + str(self.num_parallel_workers) + " new num_workers" + str(self.num_parallel_workers -
num_removed_workers) + ".")
self.Terminate()
self.terminate()
self.num_parallel_workers -= num_removed_workers
self.Launch(self.op_id)
self.launch(self.op_id)

def IsMPEnabled(self):
def is_mp_enabled(self):
return self.process_pool is not None

def create_shared_memory(self):
@@ -2878,7 +2881,7 @@ class _PythonMultiprocessing(cde.PythonMultiprocessingRuntime):
Collect the PIDs of the children processes.
"""
self.workers = [w for w in self.process_pool._pool] # pylint: disable=W0212
pids = self.GetPIDs()
pids = self.get_pids()
logger.info("Op: " + str(self.op_id) + " Python multiprocessing pool workers' PIDs: " + str(pids))

def execute(self, py_callable, idx, *args):
@@ -3135,7 +3138,7 @@ class _PythonMultiprocessing(cde.PythonMultiprocessingRuntime):

def __del__(self):
# Cleanup when the iter had been deleted from ITERATORS_LIST
self.Terminate()
self.terminate()


class MapDataset(UnionBaseDataset):


+ 28
- 0
tests/ut/python/dataset/conftest.py View File

@@ -0,0 +1,28 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
@File : conftest.py
@Desc : common fixtures for pytest
"""

import pytest
from mindspore.dataset.engine.iterators import _cleanup, _unset_iterator_cleanup


@pytest.fixture(autouse=True)
def close_iterators():
    """Tear down dataset iterators after every test in this directory.

    Declared ``autouse`` so it wraps each test without being requested
    explicitly.  The code before ``yield`` is empty (no setup); everything
    after it runs as pytest teardown.
    """
    # Run the test body first; the lines below execute as teardown.
    yield
    # Clean up iterators left alive by the test — presumably those tracked
    # by mindspore.dataset.engine.iterators; avoids stale iterator state
    # leaking into later tests that fork worker processes (TODO confirm
    # against the iterators module).
    _cleanup()
    # Reset the module-level iterator-cleanup flag so the next test starts
    # from a fresh state.
    _unset_iterator_cleanup()

Loading…
Cancel
Save