Browse Source

Merge pull request #26 from Learnware-LAMDA/test_upload_learnware

[FIX] fix details for learnware upload in client
tags/v0.3.2
Gene GitHub 2 years ago
parent
commit
23c92bb69e
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 15 additions and 15 deletions
  1. +10
    -10
      learnware/client/learnware_client.py
  2. +5
    -5
      tests/test_learnware_client/test_upload.py

+ 10
- 10
learnware/client/learnware_client.py View File

@@ -98,15 +98,15 @@ class LearnwareClient:
pass

@require_login
def upload_learnware(self, semantic_specification, learnware_file):
def upload_learnware(self, learnware_zip_path, semantic_specification):
assert self._check_semantic_specification(semantic_specification)
file_hash = compute_file_hash(learnware_file)
file_hash = compute_file_hash(learnware_zip_path)
url_upload = f"{self.host}/user/chunked_upload"

num_chunks = os.path.getsize(learnware_file) // CHUNK_SIZE + 1
num_chunks = os.path.getsize(learnware_zip_path) // CHUNK_SIZE + 1
bar = tqdm(total=num_chunks, desc="Uploading", unit="MB")
begin = 0
for chunk in file_chunks(learnware_file):
for chunk in file_chunks(learnware_zip_path):
response = requests.post(
url_upload,
files={
@@ -270,7 +270,7 @@ class LearnwareClient:
data_type,
task_type,
library_type,
senarioes,
scenarios,
input_description=None,
output_description=None,
):
@@ -278,7 +278,7 @@ class LearnwareClient:
semantic_specification["Data"] = {"Type": "Class", "Values": [data_type]}
semantic_specification["Task"] = {"Type": "Class", "Values": [task_type]}
semantic_specification["Library"] = {"Type": "Class", "Values": [library_type]}
semantic_specification["Scenario"] = {"Type": "Tag", "Values": senarioes}
semantic_specification["Scenario"] = {"Type": "Tag", "Values": scenarios}
semantic_specification["Name"] = {"Type": "String", "Values": name}
semantic_specification["Description"] = {"Type": "String", "Values": description}
semantic_specification["Input"] = input_description
@@ -441,14 +441,14 @@ class LearnwareClient:
return False

@staticmethod
def check_learnware(zip_path, semantic_specification=None):
def check_learnware(learnware_zip_path, semantic_specification=None):
if semantic_specification is None:
semantic_specification = get_semantic_specification()
else:
self._check_semantic_specification(semantic_specification)
LearnwareClient._check_semantic_specification(semantic_specification)

with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
with zipfile.ZipFile(zip_path, mode="r") as z_file:
with zipfile.ZipFile(learnware_zip_path, mode="r") as z_file:
z_file.extractall(tempdir)

learnware = get_learnware_from_dirpath(
@@ -458,7 +458,7 @@ class LearnwareClient:
if learnware is None:
raise Exception("The learnware is not valid.")

with LearnwaresContainer(learnware, zip_path) as env_container:
with LearnwaresContainer(learnware, learnware_zip_path) as env_container:
learnware = env_container.get_learnwares_with_container()[0]
if EasyMarket.check_learnware(learnware) == EasyMarket.USABLE_LEARWARE:
logger.info("The learnware passed the local test.")


+ 5
- 5
tests/test_learnware_client/test_upload.py View File

@@ -20,11 +20,9 @@ class TestAllLearnware(unittest.TestCase):
"Description": {"0": "age", "1": "weight", "2": "body length", "3": "animal type", "4": "claw length"},
}
output_description = {
"Dimension": 3,
"Dimension": 1,
"Description": {
"0": "the probability of being a cat",
"1": "the probability of being a dog",
"2": "the probability of being a bird",
},
}
semantic_spec = self.client.create_semantic_specification(
@@ -33,7 +31,7 @@ class TestAllLearnware(unittest.TestCase):
data_type="Table",
task_type="Classification",
library_type="Scikit-learn",
senarioes=["Business", "Financial"],
scenarios=["Business", "Financial"],
input_description=input_description,
output_description=output_description,
)
@@ -43,7 +41,9 @@ class TestAllLearnware(unittest.TestCase):
with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
zip_path = os.path.join(tempdir, f"test.zip")
self.client.download_learnware(download_learnware_id, zip_path)
learnware_id = self.client.upload_learnware(semantic_specification=semantic_spec, learnware_file=zip_path)
learnware_id = self.client.upload_learnware(
learnware_zip_path=zip_path, semantic_specification=semantic_spec
)

uploaded_ids = [learnware["learnware_id"] for learnware in self.client.list_learnware()]
assert learnware_id in uploaded_ids


Loading…
Cancel
Save