Browse Source

[MNT] modify check for output dimension

tags/v0.3.2
Gene 2 years ago
parent
commit
71ebe0f1e6
1 changed files with 22 additions and 7 deletions
  1. +22
    -7
      learnware/market/easy/checker.py

+ 22
- 7
learnware/market/easy/checker.py View File

@@ -164,19 +164,34 @@ class EasyStatChecker(BaseChecker):


# Check output shape # Check output shape
if outputs[0].shape != learnware_model.output_shape: if outputs[0].shape != learnware_model.output_shape:
message = f"The learnware [{learnware.id}] output dimension mismatch!, where pred_shape={outputs[0].shape}, model_shape={learnware_model.output_shape}"
message = f"The learnware [{learnware.id}] output dimension mismatch, where pred_shape={outputs[0].shape}, model_shape={learnware_model.output_shape}"
logger.warning(message) logger.warning(message)
return self.INVALID_LEARNWARE, message return self.INVALID_LEARNWARE, message


# Check output dimension
if semantic_spec["Task"]["Values"][0] in [
"Classification",
"Regression",
] and learnware_model.output_shape[0] != int(semantic_spec["Output"]["Dimension"]):
message = f"The learnware [{learnware.id}] output dimension mismatch!, where model_shape={learnware_model.output_shape}, semantic_shape={(int(semantic_spec['Output']['Dimension']), )}"
# Check output dimension for regression
if semantic_spec["Task"]["Values"][0] == "Regression" and learnware_model.output_shape[0] != int(
semantic_spec["Output"]["Dimension"]
):
message = f"The learnware [{learnware.id}] output dimension mismatch, where model_shape={learnware_model.output_shape}, semantic_shape={(int(semantic_spec['Output']['Dimension']), )}"
logger.warning(message) logger.warning(message)
return self.INVALID_LEARNWARE, message return self.INVALID_LEARNWARE, message


# Check output dimension for classification
if semantic_spec["Task"]["Values"][0] == "Classification":
model_output_shape = learnware_model.output_shape[0]
semantic_output_shape = int(semantic_spec["Output"]["Dimension"])

if model_output_shape == 1:
if not all(int(item) >= 0 and int(item) < semantic_output_shape for item in outputs):
message = f"The learnware [{learnware.id}] output label mismatch, where outputs of model is {outputs}, semantic_shape={(semantic_output_shape, )}"
logger.warning(message)
return self.INVALID_LEARNWARE, message
else:
if model_output_shape != semantic_output_shape:
message = f"The learnware [{learnware.id}] output dimension mismatch, where model_shape={learnware_model.output_shape}, semantic_shape={(semantic_output_shape, )}"
logger.warning(message)
return self.INVALID_LEARNWARE, message

except Exception as e: except Exception as e:
message = f"The learnware [{learnware.id}] is not valid! Due to {repr(e)}." message = f"The learnware [{learnware.id}] is not valid! Due to {repr(e)}."
logger.warning(message) logger.warning(message)


Loading…
Cancel
Save