|
|
|
@@ -43,7 +43,7 @@ class EasySemanticChecker(BaseChecker): |
|
|
|
assert int(k) >= 0 and int(k) < dim, f"Dimension number in [0, {dim})" |
|
|
|
assert isinstance(v, str), "Description must be string" |
|
|
|
|
|
|
|
if semantic_spec["Task"]["Values"][0] in ["Classification", "Regression", "Feature Extraction"]: |
|
|
|
if semantic_spec["Task"]["Values"][0] in ["Classification", "Regression"]: |
|
|
|
assert semantic_spec["Output"] is not None, "Lack of output semantics" |
|
|
|
dim = semantic_spec["Output"]["Dimension"] |
|
|
|
for k, v in semantic_spec["Output"]["Description"].items(): |
|
|
|
@@ -126,7 +126,7 @@ class EasyStatChecker(BaseChecker): |
|
|
|
logger.warning(f"learnware {learnware} prediction method is not valid!") |
|
|
|
return self.INVALID_LEARNWARE |
|
|
|
|
|
|
|
if semantic_spec["Task"]["Values"][0] in ("Classification", "Regression", "Feature Extraction"): |
|
|
|
if semantic_spec["Task"]["Values"][0] in ("Classification", "Regression"): |
|
|
|
# Check output type |
|
|
|
if isinstance(outputs, torch.Tensor): |
|
|
|
outputs = outputs.detach().cpu().numpy() |
|
|
|
|