Browse Source

[MNT] replace __exit__ by atexit.register

tags/v0.3.2
Gene 2 years ago
parent
commit
768bdf48a6
3 changed files with 22 additions and 20 deletions
  1. +11
    -11
      learnware/client/container.py
  2. +3
    -1
      learnware/client/learnware_client.py
  3. +8
    -8
      tests/test_client/test_download.py

+ 11
- 11
learnware/client/container.py View File

@@ -1,5 +1,6 @@
import os import os
import pickle import pickle
import atexit
import tempfile import tempfile
import shortuuid import shortuuid
from concurrent.futures import ProcessPoolExecutor from concurrent.futures import ProcessPoolExecutor
@@ -126,6 +127,12 @@ class LearnwaresContainer:
) )
for _learnware, _zippath in zip(learnware_list, learnware_zippaths) for _learnware, _zippath in zip(learnware_list, learnware_zippaths)
] ]
model_list = [_learnware.get_model() for _learnware in self.learnware_list]
with ProcessPoolExecutor(max_workers=max(os.cpu_count() // 2, 1)) as executor:
executor.map(self._initialize_model_container, model_list)
atexit.register(self.cleanup)


@staticmethod @staticmethod
def _initialize_model_container(model: ModelEnvContainer): def _initialize_model_container(model: ModelEnvContainer):
@@ -135,16 +142,9 @@ class LearnwaresContainer:
def _destroy_model_container(model: ModelEnvContainer): def _destroy_model_container(model: ModelEnvContainer):
model.remove_env() model.remove_env()


def __enter__(self):
model_list = [_learnware.get_model() for _learnware in self.learnware_list]
with ProcessPoolExecutor(max_workers=max(os.cpu_count() // 2, 1)) as executor:
executor.map(self._initialize_model_container, model_list)
return self

def __exit__(self, type, value, trace):
model_list = [_learnware.get_model() for _learnware in self.learnware_list]
with ProcessPoolExecutor(max_workers=max(os.cpu_count() // 2, 1)) as executor:
executor.map(self._destroy_model_container, model_list)

def get_learnware_list_with_container(self): def get_learnware_list_with_container(self):
return self.learnware_list return self.learnware_list
def cleanup(self):
for _learnware in self.learnware_list:
self._destroy_model_container(_learnware.get_model())

+ 3
- 1
learnware/client/learnware_client.py View File

@@ -2,6 +2,7 @@ import os
import numpy as np import numpy as np
import yaml import yaml
import json import json
import atexit
import zipfile import zipfile
import hashlib import hashlib
import requests import requests
@@ -71,6 +72,7 @@ class LearnwareClient:


self.chunk_size = 1024 * 1024 self.chunk_size = 1024 * 1024
self.tempdir_list = [] self.tempdir_list = []
atexit.register(self.cleanup)


def login(self, email, token): def login(self, email, token):
url = f"{self.host}/auth/login_by_token" url = f"{self.host}/auth/login_by_token"
@@ -439,6 +441,6 @@ class LearnwareClient:


return result return result


def __del__(self):
def cleanup(self):
for tempdir in self.tempdir_list: for tempdir in self.tempdir_list:
tempdir.cleanup() tempdir.cleanup()

+ 8
- 8
tests/test_client/test_download.py View File

@@ -25,11 +25,11 @@ if __name__ == "__main__":


learnware_list = [client.load_learnware(file, load_model=False) for file in zip_paths] learnware_list = [client.load_learnware(file, load_model=False) for file in zip_paths]


with LearnwaresContainer(learnware_list, zip_paths) as env_container:
learnware_list = env_container.get_learnware_list_with_container()
reuser = AveragingReuser(learnware_list, mode="vote_by_label")
input_array = np.random.random(size=(20, 13))
print(reuser.predict(input_array))
for learnware in learnware_list:
print(learnware.id, learnware.predict(input_array))
env_container = LearnwaresContainer(learnware_list, zip_paths)
learnware_list = env_container.get_learnware_list_with_container()
reuser = AveragingReuser(learnware_list, mode="vote_by_label")
input_array = np.random.random(size=(20, 13))
print(reuser.predict(input_array))
for learnware in learnware_list:
print(learnware.id, learnware.predict(input_array))

Loading…
Cancel
Save