From 768bdf48a643b33f14423dadcf05cdb97a960ffa Mon Sep 17 00:00:00 2001 From: Gene Date: Fri, 13 Oct 2023 09:45:07 +0800 Subject: [PATCH] [MNT] replace __exit__ by atexit.register --- learnware/client/container.py | 22 +++++++++++----------- learnware/client/learnware_client.py | 4 +++- tests/test_client/test_download.py | 16 ++++++++-------- 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/learnware/client/container.py b/learnware/client/container.py index 8dcb195..c924722 100644 --- a/learnware/client/container.py +++ b/learnware/client/container.py @@ -1,5 +1,6 @@ import os import pickle +import atexit import tempfile import shortuuid from concurrent.futures import ProcessPoolExecutor @@ -126,6 +127,12 @@ class LearnwaresContainer: ) for _learnware, _zippath in zip(learnware_list, learnware_zippaths) ] + + model_list = [_learnware.get_model() for _learnware in self.learnware_list] + with ProcessPoolExecutor(max_workers=max(os.cpu_count() // 2, 1)) as executor: + executor.map(self._initialize_model_container, model_list) + + atexit.register(self.cleanup) @staticmethod def _initialize_model_container(model: ModelEnvContainer): @@ -135,16 +142,9 @@ class LearnwaresContainer: def _destroy_model_container(model: ModelEnvContainer): model.remove_env() - def __enter__(self): - model_list = [_learnware.get_model() for _learnware in self.learnware_list] - with ProcessPoolExecutor(max_workers=max(os.cpu_count() // 2, 1)) as executor: - executor.map(self._initialize_model_container, model_list) - return self - - def __exit__(self, type, value, trace): - model_list = [_learnware.get_model() for _learnware in self.learnware_list] - with ProcessPoolExecutor(max_workers=max(os.cpu_count() // 2, 1)) as executor: - executor.map(self._destroy_model_container, model_list) - def get_learnware_list_with_container(self): return self.learnware_list + + def cleanup(self): + for _learnware in self.learnware_list: + self._destroy_model_container(_learnware.get_model()) \ No newline at end of file diff --git a/learnware/client/learnware_client.py b/learnware/client/learnware_client.py index d1b1887..6760ae6 100644 --- a/learnware/client/learnware_client.py +++ b/learnware/client/learnware_client.py @@ -2,6 +2,7 @@ import os import numpy as np import yaml import json +import atexit import zipfile import hashlib import requests @@ -71,6 +72,7 @@ class LearnwareClient: self.chunk_size = 1024 * 1024 self.tempdir_list = [] + atexit.register(self.cleanup) def login(self, email, token): url = f"{self.host}/auth/login_by_token" @@ -439,6 +441,6 @@ class LearnwareClient: return result - def __del__(self): + def cleanup(self): for tempdir in self.tempdir_list: tempdir.cleanup() diff --git a/tests/test_client/test_download.py b/tests/test_client/test_download.py index 4a21286..d74749a 100644 --- a/tests/test_client/test_download.py +++ b/tests/test_client/test_download.py @@ -25,11 +25,11 @@ if __name__ == "__main__": learnware_list = [client.load_learnware(file, load_model=False) for file in zip_paths] - with LearnwaresContainer(learnware_list, zip_paths) as env_container: - learnware_list = env_container.get_learnware_list_with_container() - reuser = AveragingReuser(learnware_list, mode="vote_by_label") - input_array = np.random.random(size=(20, 13)) - print(reuser.predict(input_array)) - - for learnware in learnware_list: - print(learnware.id, learnware.predict(input_array)) + env_container = LearnwaresContainer(learnware_list, zip_paths) + learnware_list = env_container.get_learnware_list_with_container() + reuser = AveragingReuser(learnware_list, mode="vote_by_label") + input_array = np.random.random(size=(20, 13)) + print(reuser.predict(input_array)) + + for learnware in learnware_list: + print(learnware.id, learnware.predict(input_array))