diff --git a/ruoyi-api/ruoyi-api-system/src/main/java/com/ruoyi/system/api/constant/Constant.java b/ruoyi-api/ruoyi-api-system/src/main/java/com/ruoyi/system/api/constant/Constant.java index 8af1124b..8303ebf7 100644 --- a/ruoyi-api/ruoyi-api-system/src/main/java/com/ruoyi/system/api/constant/Constant.java +++ b/ruoyi-api/ruoyi-api-system/src/main/java/com/ruoyi/system/api/constant/Constant.java @@ -67,6 +67,10 @@ public class Constant { public final static String ML_TextClassification = "text_classification"; public final static String ML_VideoClassification = "video_classification"; + public final static String AL_PYTORCH = "pytorch"; + public final static String AL_SKLEARN = "sklearn"; + public final static String AL_KERAS = "keras"; + public final static String DelFlag = "2"; public final static String Code = "123123"; diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ActiveLearnInsServiceImpl.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ActiveLearnInsServiceImpl.java index 197cc9db..cb51341e 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ActiveLearnInsServiceImpl.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ActiveLearnInsServiceImpl.java @@ -291,9 +291,19 @@ public class ActiveLearnInsServiceImpl implements ActiveLearnInsService { return aimUrl + "/metrics?select=" + decode; } - public void getTrialList(ActiveLearnIns ins) { + public void getTrialList(ActiveLearnIns ins) throws IOException { String directoryPath = ins.getResultPath(); - ins.setResultPath(endpoint + "/" + directoryPath + "/final_checkpoint/final_model_weights.pth"); + switch ((String) JsonUtils.jsonToMap(ins.getParam()).get("framework_type")) { + case Constant.AL_PYTORCH: { + ins.setResultPath(endpoint + "/" + directoryPath + "/final_checkpoint/final_model_weights.pth"); + } + case Constant.AL_SKLEARN: { + ins.setResultPath(endpoint + "/" + directoryPath + "/final_checkpoint/final_model.joblib"); + } + case Constant.AL_KERAS: { + ins.setResultPath(endpoint + "/" + directoryPath + "/final_checkpoint/model.h5"); + } + } try { String bucketName = directoryPath.substring(0, directoryPath.indexOf("/"));