Browse Source

[MNT] replace __exit__ by atexit.register

tags/v0.3.2
Gene 2 years ago
parent
commit
768bdf48a6
3 changed files with 22 additions and 20 deletions
  1. +11
    -11
      learnware/client/container.py
  2. +3
    -1
      learnware/client/learnware_client.py
  3. +8
    -8
      tests/test_client/test_download.py

+ 11
- 11
learnware/client/container.py View File

@@ -1,5 +1,6 @@
import os
import pickle
import atexit
import tempfile
import shortuuid
from concurrent.futures import ProcessPoolExecutor
@@ -126,6 +127,12 @@ class LearnwaresContainer:
)
for _learnware, _zippath in zip(learnware_list, learnware_zippaths)
]
model_list = [_learnware.get_model() for _learnware in self.learnware_list]
with ProcessPoolExecutor(max_workers=max(os.cpu_count() // 2, 1)) as executor:
executor.map(self._initialize_model_container, model_list)
atexit.register(self.cleanup)

@staticmethod
def _initialize_model_container(model: ModelEnvContainer):
@@ -135,16 +142,9 @@ class LearnwaresContainer:
def _destroy_model_container(model: ModelEnvContainer):
model.remove_env()

def __enter__(self):
model_list = [_learnware.get_model() for _learnware in self.learnware_list]
with ProcessPoolExecutor(max_workers=max(os.cpu_count() // 2, 1)) as executor:
executor.map(self._initialize_model_container, model_list)
return self

def __exit__(self, type, value, trace):
model_list = [_learnware.get_model() for _learnware in self.learnware_list]
with ProcessPoolExecutor(max_workers=max(os.cpu_count() // 2, 1)) as executor:
executor.map(self._destroy_model_container, model_list)

def get_learnware_list_with_container(self):
return self.learnware_list
def cleanup(self):
for _learnware in self.learnware_list:
self._destroy_model_container(_learnware.get_model())

+ 3
- 1
learnware/client/learnware_client.py View File

@@ -2,6 +2,7 @@ import os
import numpy as np
import yaml
import json
import atexit
import zipfile
import hashlib
import requests
@@ -71,6 +72,7 @@ class LearnwareClient:

self.chunk_size = 1024 * 1024
self.tempdir_list = []
atexit.register(self.cleanup)

def login(self, email, token):
url = f"{self.host}/auth/login_by_token"
@@ -439,6 +441,6 @@ class LearnwareClient:

return result

def __del__(self):
def cleanup(self):
for tempdir in self.tempdir_list:
tempdir.cleanup()

+ 8
- 8
tests/test_client/test_download.py View File

@@ -25,11 +25,11 @@ if __name__ == "__main__":

learnware_list = [client.load_learnware(file, load_model=False) for file in zip_paths]

with LearnwaresContainer(learnware_list, zip_paths) as env_container:
learnware_list = env_container.get_learnware_list_with_container()
reuser = AveragingReuser(learnware_list, mode="vote_by_label")
input_array = np.random.random(size=(20, 13))
print(reuser.predict(input_array))
for learnware in learnware_list:
print(learnware.id, learnware.predict(input_array))
env_container = LearnwaresContainer(learnware_list, zip_paths)
learnware_list = env_container.get_learnware_list_with_container()
reuser = AveragingReuser(learnware_list, mode="vote_by_label")
input_array = np.random.random(size=(20, 13))
print(reuser.predict(input_array))
for learnware in learnware_list:
print(learnware.id, learnware.predict(input_array))

Loading…
Cancel
Save