Browse Source

Merge branch 'Dev' of https://github.com/AbductiveLearning/ABL-Package into Dev

pull/3/head
Gao Enhao 3 years ago
parent
commit
71d2500b4f
1 changed files with 4 additions and 5 deletions
  1. +4
    -5
      abl/abducer/abducer_base.py

+ 4
- 5
abl/abducer/abducer_base.py View File

@@ -10,9 +10,10 @@
#
# ================================================================#

import os
import abc
import numpy as np
import multiprocessing
from multiprocessing import Pool
from zoopt import Dimension, Objective, Parameter, Opt
from ..utils.utils import confidence_dist, flatten, reform_idx, hamming_dist

@@ -110,10 +111,8 @@ class AbducerBase(abc.ABC):
return self.abduce((z, prob, y), max_address, require_more_address)

def batch_abduce(self, Z, Y, max_address=-1, require_more_address=0):
pool = multiprocessing.Pool(processes=len(Z))
results = pool.map(self._batch_abduce_helper, [(z, prob, y, max_address, require_more_address) for z, prob, y in zip(Z['cls'], Z['prob'], Y)])
pool.close()
pool.join()
with Pool(processes=os.cpu_count()) as pool:
results = pool.map(self._batch_abduce_helper, [(z, prob, y, max_address, require_more_address) for z, prob, y in zip(Z['cls'], Z['prob'], Y)])
return results
def __call__(self, Z, Y, max_address=-1, require_more_address=0):


Loading…
Cancel
Save