From 3a965b80138df25728e5519d97f5903aefb3f125 Mon Sep 17 00:00:00 2001 From: troyyyyy Date: Tue, 21 Mar 2023 10:11:50 +0800 Subject: [PATCH] Modify pool --- abl/abducer/abducer_base.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/abl/abducer/abducer_base.py b/abl/abducer/abducer_base.py index 96161a4..441223f 100644 --- a/abl/abducer/abducer_base.py +++ b/abl/abducer/abducer_base.py @@ -10,9 +10,10 @@ # # ================================================================# +import os import abc import numpy as np -import multiprocessing +from multiprocessing import Pool from zoopt import Dimension, Objective, Parameter, Opt from ..utils.utils import confidence_dist, flatten, reform_idx, hamming_dist @@ -110,10 +111,8 @@ class AbducerBase(abc.ABC): return self.abduce((z, prob, y), max_address, require_more_address) def batch_abduce(self, Z, Y, max_address=-1, require_more_address=0): - pool = multiprocessing.Pool(processes=len(Z)) - results = pool.map(self._batch_abduce_helper, [(z, prob, y, max_address, require_more_address) for z, prob, y in zip(Z['cls'], Z['prob'], Y)]) - pool.close() - pool.join() + with Pool(processes=os.cpu_count()) as pool: + results = pool.map(self._batch_abduce_helper, [(z, prob, y, max_address, require_more_address) for z, prob, y in zip(Z['cls'], Z['prob'], Y)]) return results def __call__(self, Z, Y, max_address=-1, require_more_address=0):