diff --git a/learnware/market/base.py b/learnware/market/base.py index 01957e3..e4ca3c4 100644 --- a/learnware/market/base.py +++ b/learnware/market/base.py @@ -58,7 +58,28 @@ class BaseMarket: """Initializing an empty market""" self.learnware_list = {} # id: Learnware self.count = 0 - self.property_list = None + self.property_list = { + 'Data': { + 'Values': ['Tabular', 'Image', 'Video', 'Text', 'Audio'], + 'Type' : 'Class', # Choose only one class + }, + 'Task': { + 'Values': ['Classification','Regression','Clustering','Feature Extraction','Generation','Segmentation','Object Detection'], + 'Type': 'Class', # Choose only one class + }, + 'Device': { + 'Values': ['CPU', 'GPU'], + 'Type': 'Tag', # Choose one or more tags + }, + 'Scenario': { + 'Values': ['Business', 'Financial', 'Health', 'Politics', 'Computer', 'Internet', 'Traffic', 'Nature', 'Fashion', 'Industry', 'Agriculture', 'Education', 'Entertainment', 'Architecture'], + 'Type': 'Tag', # Choose one or more tags + }, + 'Description': { + 'Values': str, + 'Type': 'Description', + }, + } def reload_market(self, market_path: str, property_list_path: str, load_mode: str = "database") -> bool: """Reload the market when server restared. @@ -165,6 +186,29 @@ class BaseMarket: - first is recommended combination, None when no recommended combination is calculated or statistical specification is not provided. - second is a list of matched learnwares """ + def search_by_property(): + def match_property(property1, property2): + if property1.keys() != property2.keys(): + raise Exception("property key error".format(property1.keys(), property2.keys())) + for key in property1.keys(): + if property1[key]['Type'] == 'Class': + if property1[key]['Values'] != property2[key]['Values']: + return False + elif property1[key]['Type'] == 'Tag': + if not (set(property1[key]['Values']) & set(property2[key]['Values'])): + return False + return True + + match_learnwares = [] + for learnware in self.learnware_list: + learnware_property = learnware.get_specification().get_property() + user_property = user_info.get_property() + if match_property(learnware_property, user_property): + match_learnwares.append(learnware) + return match_learnwares + + match_learnwares = search_by_property() + pass def get_learnware_by_ids(self, id: Union[str, List[str]]) -> Union[Learnware, List[Learnware]]: diff --git a/learnware/specification/base.py b/learnware/specification/base.py index c9f9e60..8982b31 100644 --- a/learnware/specification/base.py +++ b/learnware/specification/base.py @@ -16,8 +16,8 @@ class BaseStatSpecification: class Specification: - def __init__(self): - self.property = None + def __init__(self, property=None): + self.property = property self.stat_spec = {} # stat_spec should be dict def get_stat_spec(self):