Browse Source

Merge branch 'dev' of git.nju.edu.cn:learnware/learnware-market into dev

tags/v0.3.2
chenzx 3 years ago
parent
commit
982da0ff6b
3 changed files with 43 additions and 44 deletions
  1. +8
    -3
      examples/example_m5/main.py
  2. +32
    -39
      learnware/market/easy.py
  3. +3
    -2
      learnware/model/base.py

+ 8
- 3
examples/example_m5/main.py View File

@@ -45,10 +45,10 @@ class M5DatasetWorkflow:

def _init_learnware_market(self):
"""initialize learnware market"""
database_ops.clear_learnware_table()
# database_ops.clear_learnware_table()
learnware.init()

easy_market = EasyMarket()
easy_market = EasyMarket(rebuild=True)
print("Total Item:", len(easy_market))

zip_path_list = []
@@ -130,7 +130,12 @@ class M5DatasetWorkflow:
user_info = BaseUserInfo(
id=f"user_{idx}", semantic_spec=user_senmantic, stat_info={"RKMEStatSpecification": user_spec}
)
sorted_score_list, single_learnware_list, mixture_learnware_list = easy_market.search_learnware(user_info)
(
sorted_score_list,
single_learnware_list,
mixture_score,
mixture_learnware_list,
) = easy_market.search_learnware(user_info)

print(f"search result of user{idx}:")
print(


+ 32
- 39
learnware/market/easy.py View File

@@ -19,9 +19,9 @@ logger = get_module_logger("market", "INFO")


class EasyMarket(BaseMarket):
INVALID_LEARNWARE = "INVALID"
NONUSABLE_LEARNWARE = "NONUSABLE"
USABLE_LEARWARE = "USABLE"
INVALID_LEARNWARE = -1
NONUSABLE_LEARNWARE = 0
USABLE_LEARWARE = 1

def __init__(self, market_id: str = "default", rebuild: bool = False):
"""Initialize Learnware Market.
@@ -82,15 +82,15 @@ class EasyMarket(BaseMarket):
return cls.NONUSABLE_LEARNWARE

try:
spec_data = learnware.specification.stat_spec["RKMEStatSpecification"].get_z()
except Exception:
logger.warning(f"The learnware [{learnware.id}] statistic specification is not avaliable!")
return cls.INVALID_LEARNWARE
learnware_model = learnware.get_model()
inputs = np.random.randn((10, *learnware_model.input_shape))
outputs = learnware.predict(inputs)
if outputs.shape[1:] != learnware_model.output_shape:
logger.warning(f"The learnware [{learnware.id}] input and output dimention is error")
return cls.NONUSABLE_LEARNWARE

try:
pred_spec = learnware.predict(spec_data)
except Exception:
logger.warning(f"The learnware [{learnware.id}] prediction is not avaliable")
except Exception as e:
logger.warning(f"The learnware [{learnware.id}] prediction is not avaliable! Due to {repr(e)}")
return cls.NONUSABLE_LEARNWARE

return cls.USABLE_LEARWARE
@@ -112,9 +112,9 @@ class EasyMarket(BaseMarket):

Returns
-------
Tuple[str, bool]
Tuple[str, int]
- str indicating model_id
- bool indicating whether the learnware is added successfully.
- int indicating what the flag of learnware is added.

"""
if not os.path.exists(zip_path):
@@ -160,42 +160,35 @@ class EasyMarket(BaseMarket):
with zipfile.ZipFile(target_zip_dir, "r") as z_file:
z_file.extractall(target_folder_dir)
logger.info("Learnware move to %s, and unzip to %s" % (target_zip_dir, target_folder_dir))

try:
new_learnware = get_learnware_from_dirpath(
id=id, semantic_spec=semantic_spec, learnware_dirpath=target_folder_dir
)
except:
new_learnware = None

if new_learnware is None:
try:
os.remove(target_zip_dir)
rmtree(target_folder_dir)
except:
pass
return None, False
else:
check_flag = self.check_learnware(new_learnware)
if not check_flag == self.INVALID_LEARNWARE:
try:
add_learnware_to_db(
market_id=self.market_id,
id=id,
semantic_spec=semantic_spec,
zip_path=target_zip_dir,
folder_path=target_folder_dir,
use_flag=check_flag,
)
self.learnware_list[id] = new_learnware
self.learnware_zip_list[id] = target_zip_dir
self.learnware_folder_list[id] = target_folder_dir
self.count += 1
return id, True
except Exception as e:
logger.warning(f"Add Learnware failed. Error msg: {e}")
return None, False
else:
return None, False
return None, self.INVALID_LEARNWARE

check_flag = self.check_learnware(new_learnware)

add_learnware_to_db(
market_id=self.market_id,
id=id,
semantic_spec=semantic_spec,
zip_path=target_zip_dir,
folder_path=target_folder_dir,
use_flag=check_flag,
)

self.learnware_list[id] = new_learnware
self.learnware_zip_list[id] = target_zip_dir
self.learnware_folder_list[id] = target_folder_dir
self.count += 1
return id, check_flag

def _convert_dist_to_score(
self, dist_list: List[float], dist_epsilon: float = 0.01, min_score: float = 0.92


+ 3
- 2
learnware/model/base.py View File

@@ -3,8 +3,9 @@ from abc import abstractmethod


class BaseModel:
def __init__(self):
pass
def __init__(self, input_shape: tuple, output_shape: tuple):
self.input_shape = input_shape
self.output_shape = output_shape

def fit(self, X: np.ndarray, y: np.ndarray):
pass


Loading…
Cancel
Save