diff --git a/learnware/client/learnware_client.py b/learnware/client/learnware_client.py index 040fd45..100af98 100644 --- a/learnware/client/learnware_client.py +++ b/learnware/client/learnware_client.py @@ -98,15 +98,15 @@ class LearnwareClient: pass @require_login - def upload_learnware(self, semantic_specification, learnware_file): + def upload_learnware(self, learnware_zip_path, semantic_specification): assert self._check_semantic_specification(semantic_specification) - file_hash = compute_file_hash(learnware_file) + file_hash = compute_file_hash(learnware_zip_path) url_upload = f"{self.host}/user/chunked_upload" - num_chunks = os.path.getsize(learnware_file) // CHUNK_SIZE + 1 + num_chunks = os.path.getsize(learnware_zip_path) // CHUNK_SIZE + 1 bar = tqdm(total=num_chunks, desc="Uploading", unit="MB") begin = 0 - for chunk in file_chunks(learnware_file): + for chunk in file_chunks(learnware_zip_path): response = requests.post( url_upload, files={ @@ -270,7 +270,7 @@ class LearnwareClient: data_type, task_type, library_type, - senarioes, + scenarios, input_description=None, output_description=None, ): @@ -278,7 +278,7 @@ class LearnwareClient: semantic_specification["Data"] = {"Type": "Class", "Values": [data_type]} semantic_specification["Task"] = {"Type": "Class", "Values": [task_type]} semantic_specification["Library"] = {"Type": "Class", "Values": [library_type]} - semantic_specification["Scenario"] = {"Type": "Tag", "Values": senarioes} + semantic_specification["Scenario"] = {"Type": "Tag", "Values": scenarios} semantic_specification["Name"] = {"Type": "String", "Values": name} semantic_specification["Description"] = {"Type": "String", "Values": description} semantic_specification["Input"] = input_description @@ -441,14 +441,14 @@ class LearnwareClient: return False @staticmethod - def check_learnware(zip_path, semantic_specification=None): + def check_learnware(learnware_zip_path, semantic_specification=None): if semantic_specification is None: semantic_specification = get_semantic_specification() else: - self._check_semantic_specification(semantic_specification) + LearnwareClient._check_semantic_specification(semantic_specification) with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: - with zipfile.ZipFile(zip_path, mode="r") as z_file: + with zipfile.ZipFile(learnware_zip_path, mode="r") as z_file: z_file.extractall(tempdir) learnware = get_learnware_from_dirpath( @@ -458,7 +458,7 @@ class LearnwareClient: if learnware is None: raise Exception("The learnware is not valid.") - with LearnwaresContainer(learnware, zip_path) as env_container: + with LearnwaresContainer(learnware, learnware_zip_path) as env_container: learnware = env_container.get_learnwares_with_container()[0] if EasyMarket.check_learnware(learnware) == EasyMarket.USABLE_LEARWARE: logger.info("The learnware passed the local test.") diff --git a/tests/test_learnware_client/test_upload.py b/tests/test_learnware_client/test_upload.py index f111cf5..7bd3128 100644 --- a/tests/test_learnware_client/test_upload.py +++ b/tests/test_learnware_client/test_upload.py @@ -20,11 +20,9 @@ class TestAllLearnware(unittest.TestCase): "Description": {"0": "age", "1": "weight", "2": "body length", "3": "animal type", "4": "claw length"}, } output_description = { - "Dimension": 3, + "Dimension": 1, "Description": { "0": "the probability of being a cat", - "1": "the probability of being a dog", - "2": "the probability of being a bird", }, } semantic_spec = self.client.create_semantic_specification( @@ -33,7 +31,7 @@ class TestAllLearnware(unittest.TestCase): data_type="Table", task_type="Classification", library_type="Scikit-learn", - senarioes=["Business", "Financial"], + scenarios=["Business", "Financial"], input_description=input_description, output_description=output_description, ) @@ -43,7 +41,9 @@ class TestAllLearnware(unittest.TestCase): with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: zip_path = os.path.join(tempdir, f"test.zip") self.client.download_learnware(download_learnware_id, zip_path) - learnware_id = self.client.upload_learnware(semantic_specification=semantic_spec, learnware_file=zip_path) + learnware_id = self.client.upload_learnware( + learnware_zip_path=zip_path, semantic_specification=semantic_spec + ) uploaded_ids = [learnware["learnware_id"] for learnware in self.client.list_learnware()] assert learnware_id in uploaded_ids