Browse Source

Merge pull request #54 from Learnware-LAMDA/rm_tags

[MNT] remove unnecessary tags
tags/v0.3.2
Gene GitHub 2 years ago
parent
commit
ea310a9b66
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 8 deletions
  1. +0
    -6
      learnware/config.py
  2. +2
    -2
      learnware/market/easy2/checker.py

+ 0
- 6
learnware/config.py View File

@@ -80,19 +80,13 @@ semantic_config = {
"Values": [
"Classification",
"Regression",
"Clustering",
"Feature Extraction",
# "Generation",
"Segmentation",
"Object Detection",
"Others",
],
"Type": "Class", # Choose only one class
},
# "Device": {
# "Values": ["CPU", "GPU"],
# "Type": "Tag",
# }, # Choose one or more tags
"Library": {
"Values": ["Scikit-learn", "PyTorch", "TensorFlow", "Others"],
"Type": "Class",


+ 2
- 2
learnware/market/easy2/checker.py View File

@@ -43,7 +43,7 @@ class EasySemanticChecker(BaseChecker):
assert int(k) >= 0 and int(k) < dim, f"Dimension number in [0, {dim})"
assert isinstance(v, str), "Description must be string"

if semantic_spec["Task"]["Values"][0] in ["Classification", "Regression", "Feature Extraction"]:
if semantic_spec["Task"]["Values"][0] in ["Classification", "Regression"]:
assert semantic_spec["Output"] is not None, "Lack of output semantics"
dim = semantic_spec["Output"]["Dimension"]
for k, v in semantic_spec["Output"]["Description"].items():
@@ -126,7 +126,7 @@ class EasyStatChecker(BaseChecker):
logger.warning(f"learnware {learnware} prediction method is not valid!")
return self.INVALID_LEARNWARE

if semantic_spec["Task"]["Values"][0] in ("Classification", "Regression", "Feature Extraction"):
if semantic_spec["Task"]["Values"][0] in ("Classification", "Regression"):
# Check output type
if isinstance(outputs, torch.Tensor):
outputs = outputs.detach().cpu().numpy()


Loading…
Cancel
Save