Browse Source

[FIX] fix bug in list_semantic_specification_values

tags/v0.3.2
Gene 2 years ago
parent
commit
a267fcefed
1 changed files with 5 additions and 21 deletions
  1. +5
    -21
      learnware/client/learnware_client.py

+ 5
- 21
learnware/client/learnware_client.py View File

@@ -40,16 +40,12 @@ def file_chunks(file_path):
if not chunk:
break
yield chunk
pass
pass
pass


def compute_file_hash(file_path):
file_hash = hashlib.md5()
for chunk in file_chunks(file_path):
file_hash.update(chunk)
pass
return file_hash.hexdigest()


@@ -58,7 +54,6 @@ class SemanticSpecificationKey(Enum):
TASK_TYPE = "Task"
LIBRARY_TYPE = "Library"
SENARIOES = "Scenario"
pass


class LearnwareClient:
@@ -85,7 +80,6 @@ class LearnwareClient:

token = result["data"]["token"]
self.headers = {"Authorization": f"Bearer {token}"}
pass

@require_login
def logout(self):
@@ -95,7 +89,6 @@ class LearnwareClient:
if result["code"] != 0:
raise Exception("logout failed: " + json.dumps(result))
self.headers = None
pass

@require_login
def upload_learnware(self, learnware_zip_path, semantic_specification):
@@ -126,7 +119,6 @@ class LearnwareClient:

begin += len(chunk)
bar.update(1)
pass
bar.close()

url_add = f"{self.host}/user/add_learnware_uploaded"
@@ -169,9 +161,6 @@ class LearnwareClient:
for chunk in response.iter_content(chunk_size=CHUNK_SIZE):
f.write(chunk)
bar.update(1)
pass
pass
pass

@require_login
def list_learnware(self):
@@ -199,19 +188,16 @@ class LearnwareClient:
stat_spec = list(stat_spec.values())[0]
else:
stat_spec = None
pass

returns = []
with tempfile.NamedTemporaryFile(prefix="learnware_stat_", suffix=".json") as ftemp:
if stat_spec is not None:
stat_spec.save(ftemp.name)
pass

with open(ftemp.name, "r") as fin:
semantic_specification = specification.get_semantic_spec()
if semantic_specification is None:
semantic_specification = {}
pass

semantic_specification.pop("Input", None)
semantic_specification.pop("Output", None)
@@ -220,7 +206,6 @@ class LearnwareClient:
files = None
else:
files = {"statistical_specification": fin}
pass

response = requests.post(
url,
@@ -246,9 +231,6 @@ class LearnwareClient:
"matching": learnware["matching"],
}
)
pass
pass
pass

return returns

@@ -261,7 +243,6 @@ class LearnwareClient:

if result["code"] != 0:
raise Exception("delete failed: " + json.dumps(result))
pass

def create_semantic_specification(
self,
@@ -295,8 +276,9 @@ class LearnwareClient:
response = requests.get(url, headers=self.headers)
result = response.json()
semantic_conf = result["data"]["semantic_specification"]
print("!" * 100, semantic_conf)

return semantic_conf[key]["Values"]
return semantic_conf[key.value]["Values"]

def load_learnware(
self,
@@ -412,7 +394,9 @@ class LearnwareClient:
semantic_specification = (
get_semantic_specification() if semantic_specification is None else semantic_specification
)
assert LearnwareClient._check_semantic_specification(semantic_specification), "Semantic specification check failed!"
assert LearnwareClient._check_semantic_specification(
semantic_specification
), "Semantic specification check failed!"

with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
with zipfile.ZipFile(learnware_zip_path, mode="r") as z_file:


Loading…
Cancel
Save