diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/autoML/AutoMlController.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/autoML/AutoMlController.java index 981b9fb1..9c1028bb 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/autoML/AutoMlController.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/autoML/AutoMlController.java @@ -63,7 +63,7 @@ public class AutoMlController extends BaseController { return AjaxResult.success(this.autoMlService.upload(file, uuid)); } - @PostMapping("{id}") + @PostMapping("/run/{id}") @ApiOperation("运行自动机器学习实验") public GenericsAjaxResult runAutoML(@PathVariable("id") Long id) throws Exception { return genericsSuccess(this.autoMlService.runAutoMlIns(id)); diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/autoML/AutoMlInsController.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/autoML/AutoMlInsController.java index 6d7b68d3..125fd668 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/autoML/AutoMlInsController.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/autoML/AutoMlInsController.java @@ -46,4 +46,10 @@ public class AutoMlInsController extends BaseController { public GenericsAjaxResult batchDelete(@RequestBody List ids) { return genericsSuccess(this.autoMLInsService.batchDelete(ids)); } + + @PutMapping("{id}") + @ApiOperation("终止实验实例") + public GenericsAjaxResult terminateAutoMlIns(@PathVariable("id") Long id) { + return genericsSuccess(this.autoMLInsService.terminateAutoMlIns(id)); + } } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/domain/AutoMl.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/domain/AutoMl.java index 811168da..a0f130cc 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/domain/AutoMl.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/domain/AutoMl.java @@ -179,4 +179,6 @@ public class AutoMl { private String dataset; + @ApiModelProperty(value = "状态列表") + private String statusList; } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/domain/AutoMlIns.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/domain/AutoMlIns.java index d31711bd..ba425dcf 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/domain/AutoMlIns.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/domain/AutoMlIns.java @@ -3,6 +3,7 @@ package com.ruoyi.platform.domain; import com.fasterxml.jackson.databind.PropertyNamingStrategy; import com.fasterxml.jackson.databind.annotation.JsonNaming; import io.swagger.annotations.ApiModel; +import io.swagger.annotations.ApiModelProperty; import lombok.Data; import java.util.Date; @@ -31,6 +32,12 @@ public class AutoMlIns { private String source; + @ApiModelProperty(value = "Argo实例名称") + private String argoInsName; + + @ApiModelProperty(value = "Argo命名空间") + private String argoInsNs; + private Date createTime; private Date updateTime; diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/mapper/AutoMlInsDao.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/mapper/AutoMlInsDao.java index 7468c8f3..7fcc3ca9 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/mapper/AutoMlInsDao.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/mapper/AutoMlInsDao.java @@ -11,9 +11,13 @@ public interface AutoMlInsDao { List queryAllByLimit(@Param("autoMlIns") AutoMlIns autoMlIns, @Param("pageable") Pageable pageable); + List getByAutoMlId(@Param("autoMlId") Long AutoMlId); + int insert(@Param("autoMlIns") AutoMlIns autoMlIns); int update(@Param("autoMlIns") AutoMlIns autoMlIns); AutoMlIns queryById(@Param("id") Long id); + + List queryByAutoMlInsIsNotTerminated(); } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/AutoMlInsStatusTask.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/AutoMlInsStatusTask.java new file mode 100644 index 00000000..5c94284f --- /dev/null +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/AutoMlInsStatusTask.java @@ -0,0 +1,97 @@ +package com.ruoyi.platform.scheduling; + +import com.ruoyi.platform.domain.AutoMl; +import com.ruoyi.platform.domain.AutoMlIns; +import com.ruoyi.platform.mapper.AutoMlDao; +import com.ruoyi.platform.mapper.AutoMlInsDao; +import com.ruoyi.platform.service.AutoMlInsService; +import org.apache.commons.lang3.StringUtils; +import org.springframework.scheduling.annotation.Scheduled; +import org.springframework.stereotype.Component; + +import javax.annotation.Resource; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +@Component() +public class AutoMlInsStatusTask { + + @Resource + private AutoMlInsService autoMlInsService; + + @Resource + private AutoMlInsDao autoMlInsDao; + + @Resource + private AutoMlDao autoMlDao; + + private List autoMlIds = new ArrayList<>(); + + @Scheduled(cron = "0/30 * * * * ?") // 每30S执行一次 + public void executeAutoMlInsStatus() throws Exception { + // 首先查到所有非终止态的实验实例 + List autoMlInsList = autoMlInsService.queryByAutoMlInsIsNotTerminated(); + + // 去argo查询状态 + List updateList = new ArrayList<>(); + if (autoMlInsList != null && autoMlInsList.size() > 0) { + for (AutoMlIns autoMlIns : autoMlInsList) { + //当原本状态为null或非终止态时才调用argo接口 + try { + autoMlIns = autoMlInsService.queryStatusFromArgo(autoMlIns); + } catch (Exception e) { + autoMlIns.setStatus("Failed"); + } + // 线程安全的添加操作 + synchronized (autoMlIds) { + autoMlIds.add(autoMlIns.getAutoMlId()); + } + updateList.add(autoMlIns); + } + if (updateList.size() > 0) { + for (AutoMlIns autoMlIns : updateList) { + autoMlInsDao.update(autoMlIns); + } + } + } + } + + @Scheduled(cron = "0/30 * * * * ?") // / 每30S执行一次 + public void executeAutoMlStatus() throws Exception { + if (autoMlIds.size() == 0) { + return; + } + // 存储需要更新的实验对象列表 + List updateAutoMls = new ArrayList<>(); + for (Long autoMlId : autoMlIds) { + // 获取当前实验的所有实例列表 + List insList = autoMlInsDao.getByAutoMlId(autoMlId); + List statusList = new ArrayList<>(); + // 更新实验状态列表 + for (int i = 0; i < insList.size(); i++) { + statusList.add(insList.get(i).getStatus()); + } + String subStatus = statusList.toString().substring(1, statusList.toString().length() - 1); + AutoMl autoMl = autoMlDao.getAutoMlById(autoMlId); + if (!StringUtils.equals(autoMl.getStatusList(), subStatus)) { + autoMl.setStatusList(subStatus); + updateAutoMls.add(autoMl); + autoMlDao.edit(autoMl); + } + } + + if (!updateAutoMls.isEmpty()) { + // 使用Iterator进行安全的删除操作 + Iterator iterator = autoMlIds.iterator(); + while (iterator.hasNext()) { + Long autoMlId = iterator.next(); + for (AutoMl autoMl : updateAutoMls) { + if (autoMl.getId().equals(autoMlId)) { + iterator.remove(); + } + } + } + } + } +} diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AutoMlInsService.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AutoMlInsService.java index 4b7a8601..d4956a22 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AutoMlInsService.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AutoMlInsService.java @@ -15,4 +15,10 @@ public interface AutoMlInsService { String removeById(Long id); String batchDelete(List ids); + + List queryByAutoMlInsIsNotTerminated(); + + AutoMlIns queryStatusFromArgo(AutoMlIns autoMlIns); + + boolean terminateAutoMlIns(Long id); } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AutoMlInsServiceImpl.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AutoMlInsServiceImpl.java index 5b190226..3d11123e 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AutoMlInsServiceImpl.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AutoMlInsServiceImpl.java @@ -4,6 +4,11 @@ import com.ruoyi.platform.constant.Constant; import com.ruoyi.platform.domain.AutoMlIns; import com.ruoyi.platform.mapper.AutoMlInsDao; import com.ruoyi.platform.service.AutoMlInsService; +import com.ruoyi.platform.utils.DateUtils; +import com.ruoyi.platform.utils.HttpUtils; +import com.ruoyi.platform.utils.JsonUtils; +import org.apache.commons.lang3.StringUtils; +import org.springframework.beans.factory.annotation.Value; import org.springframework.data.domain.Page; import org.springframework.data.domain.PageImpl; import org.springframework.data.domain.PageRequest; @@ -11,10 +16,17 @@ import org.springframework.stereotype.Service; import javax.annotation.Resource; import java.io.IOException; -import java.util.List; +import java.util.*; @Service public class AutoMlInsServiceImpl implements AutoMlInsService { + @Value("${argo.url}") + private String argoUrl; + @Value("${argo.workflowStatus}") + private String argoWorkflowStatus; + @Value("${argo.workflowTermination}") + private String argoWorkflowTermination; + @Resource private AutoMlInsDao autoMlInsDao; @@ -38,6 +50,14 @@ public class AutoMlInsServiceImpl implements AutoMlInsService { if (autoMlIns == null) { return "实验实例不存在"; } + + if (StringUtils.isEmpty(autoMlIns.getStatus())) { + autoMlIns = queryStatusFromArgo(autoMlIns); + } + if (StringUtils.equals(autoMlIns.getStatus(), "Running")) { + return "实验实例正在运行,不可删除"; + } + autoMlIns.setState(Constant.State_invalid); int update = autoMlInsDao.update(autoMlIns); if (update > 0) { @@ -57,4 +77,150 @@ public class AutoMlInsServiceImpl implements AutoMlInsService { } return "删除成功"; } + + @Override + public List queryByAutoMlInsIsNotTerminated() { + return autoMlInsDao.queryByAutoMlInsIsNotTerminated(); + } + + @Override + public AutoMlIns queryStatusFromArgo(AutoMlIns ins) { + String namespace = ins.getArgoInsNs(); + String name = ins.getArgoInsName(); + Long id = ins.getId(); + + // 创建请求数据map + Map requestData = new HashMap<>(); + requestData.put("namespace", namespace); + requestData.put("name", name); + + // 创建发送数据map,将请求数据作为"data"键的值 + Map res = new HashMap<>(); + res.put("data", requestData); + + try { + // 发送POST请求到Argo工作流状态查询接口,并将请求数据转换为JSON + String req = HttpUtils.sendPost(argoUrl + argoWorkflowStatus, null, JsonUtils.mapToJson(res)); + // 检查响应是否为空或无内容 + if (req == null || StringUtils.isEmpty(req)) { + throw new RuntimeException("工作流状态响应为空。"); + } + // 将响应的JSON字符串转换为Map对象 + Map runResMap = JsonUtils.jsonToMap(req); + // 从响应Map中获取"data"部分 + Map data = (Map) runResMap.get("data"); + if (data == null || data.isEmpty()) { + throw new RuntimeException("工作流数据为空."); + } + // 从"data"中获取"status"部分,并返回"phase"的值 + Map status = (Map) data.get("status"); + if (status == null || status.isEmpty()) { + throw new RuntimeException("工作流状态为空。"); + } + + //解析流水线结束时间 + String finishedAtString = (String) status.get("finishedAt"); + if (finishedAtString != null && !finishedAtString.isEmpty()) { + Date finishTime = DateUtils.convertUTCtoShanghaiDate(finishedAtString); + ins.setUpdateTime(finishTime); + } + + // 解析nodes字段,提取节点状态并转换为JSON字符串 + Map nodes = (Map) status.get("nodes"); + Map modifiedNodes = new LinkedHashMap<>(); + if (nodes != null) { + for (Map.Entry nodeEntry : nodes.entrySet()) { + Map nodeDetails = (Map) nodeEntry.getValue(); + String templateName = (String) nodeDetails.get("displayName"); + modifiedNodes.put(templateName, nodeDetails); + } + } + + String nodeStatusJson = JsonUtils.mapToJson(modifiedNodes); + ins.setNodeStatus(nodeStatusJson); + + //终止态为终止不改 + if (!StringUtils.equals(ins.getStatus(), "Terminated")) { + ins.setStatus(StringUtils.isNotEmpty((String) status.get("phase")) ? (String) status.get("phase") : "Pending"); + } + if (StringUtils.equals(ins.getStatus(), "Error")) { + ins.setStatus("Failed"); + } + return ins; + } catch (Exception e) { + throw new RuntimeException("查询状态失败: " + e.getMessage(), e); + } + } + + @Override + public boolean terminateAutoMlIns(Long id) { + AutoMlIns autoMlIns = autoMlInsDao.queryById(id); + if (autoMlIns == null) { + throw new IllegalStateException("实验实例未查询到,id: " + id); + } + + String currentStatus = autoMlIns.getStatus(); + String name = autoMlIns.getArgoInsName(); + String namespace = autoMlIns.getArgoInsNs(); + + // 获取当前状态,如果为空,则从Argo查询 + if (StringUtils.isEmpty(currentStatus)) { + currentStatus = queryStatusFromArgo(autoMlIns).getStatus(); + } + // 只有状态是"Running"时才能终止实例 + if (!currentStatus.equalsIgnoreCase(Constant.Running)) { + return false; // 如果不是"Running"状态,则不执行终止操作 + } + + // 创建请求数据map + Map requestData = new HashMap<>(); + requestData.put("namespace", namespace); + requestData.put("name", name); + // 创建发送数据map,将请求数据作为"data"键的值 + Map res = new HashMap<>(); + res.put("data", requestData); + + try { + // 发送POST请求到Argo工作流状态查询接口,并将请求数据转换为JSON + String req = HttpUtils.sendPost(argoUrl + argoWorkflowTermination, null, JsonUtils.mapToJson(res)); + // 检查响应是否为空或无内容 + if (StringUtils.isEmpty(req)) { + throw new RuntimeException("终止响应内容为空。"); + + } + // 将响应的JSON字符串转换为Map对象 + Map runResMap = JsonUtils.jsonToMap(req); + // 从响应Map中直接获取"errCode"的值 + Integer errCode = (Integer) runResMap.get("errCode"); + if (errCode != null && errCode == 0) { + //更新autoMlIns,确保状态更新被保存到数据库 + AutoMlIns ins = queryStatusFromArgo(autoMlIns); + String nodeStatus = ins.getNodeStatus(); + Map nodeMap = JsonUtils.jsonToMap(nodeStatus); + + // 遍历 map + for (Map.Entry entry : nodeMap.entrySet()) { + // 获取每个 Map 中的值并强制转换为 Map + Map innerMap = (Map) entry.getValue(); + // 检查 phase 的值 + if (innerMap.containsKey("phase")) { + String phaseValue = (String) innerMap.get("phase"); + // 如果值不等于 Succeeded,则赋值为 Failed + if (!StringUtils.equals("Succeeded", phaseValue)) { + innerMap.put("phase", "Failed"); + } + } + } + ins.setNodeStatus(JsonUtils.mapToJson(nodeMap)); + ins.setStatus(Constant.Terminated); + ins.setUpdateTime(new Date()); + this.autoMlInsDao.update(ins); + return true; + } else { + return false; + } + } catch (Exception e) { + throw new RuntimeException("终止实例错误: " + e.getMessage(), e); + } + } } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AutoMlServiceImpl.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AutoMlServiceImpl.java index 7669232a..4b7e58e9 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AutoMlServiceImpl.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AutoMlServiceImpl.java @@ -3,15 +3,19 @@ package com.ruoyi.platform.service.impl; import com.ruoyi.common.security.utils.SecurityUtils; import com.ruoyi.platform.constant.Constant; import com.ruoyi.platform.domain.AutoMl; +import com.ruoyi.platform.domain.AutoMlIns; import com.ruoyi.platform.mapper.AutoMlDao; +import com.ruoyi.platform.mapper.AutoMlInsDao; import com.ruoyi.platform.service.AutoMlService; -import com.ruoyi.platform.utils.*; +import com.ruoyi.platform.utils.FileUtil; +import com.ruoyi.platform.utils.HttpUtils; +import com.ruoyi.platform.utils.JacksonUtil; +import com.ruoyi.platform.utils.JsonUtils; +import com.ruoyi.platform.vo.AutoMlParamVo; import com.ruoyi.platform.vo.AutoMlVo; -import io.kubernetes.client.openapi.models.V1Pod; +import org.apache.commons.collections4.MapUtils; import org.apache.commons.io.FileUtils; import org.apache.commons.lang3.StringUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import org.springframework.beans.BeanUtils; import org.springframework.beans.factory.annotation.Value; import org.springframework.data.domain.Page; @@ -26,29 +30,24 @@ import java.io.IOException; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.concurrent.CompletableFuture; @Service("autoMLService") public class AutoMlServiceImpl implements AutoMlService { - @Value("${harbor.serviceNS}") - private String serviceNS; - @Value("${dockerpush.proxyUrl}") - private String proxyUrl; - @Value("${dockerpush.mountPath}") - private String mountPath; - @Value("${minio.pvcName}") - private String pvcName; - @Value("${automl.image}") - private String image; @Value("${git.localPath}") String localPath; - private static final Logger logger = LoggerFactory.getLogger(ModelsServiceImpl.class); + @Value("${argo.url}") + private String argoUrl; + @Value("${argo.convertAutoML}") + String convertAutoML; + @Value("${argo.workflowRun}") + private String argoWorkflowRun; @Resource private AutoMlDao autoMlDao; + @Resource - private K8sClientUtil k8sClientUtil; + private AutoMlInsDao autoMlInsDao; @Override public Page queryByPage(String mlName, PageRequest pageRequest) { @@ -146,86 +145,52 @@ public class AutoMlServiceImpl implements AutoMlService { public String runAutoMlIns(Long id) throws Exception { AutoMl autoMl = autoMlDao.getAutoMlById(id); if (autoMl == null) { - throw new Exception("开发环境配置不存在"); - } - - StringBuffer command = new StringBuffer(); - command.append("nohup python /opt/automl.py --task_type " + autoMl.getTaskType()); - if (StringUtils.isNotEmpty(autoMl.getTargetColumns())) { - command.append(" --target_columns " + autoMl.getTargetColumns()); - } else { - throw new Exception("目标列为空"); - } - -// String username = SecurityUtils.getLoginUser().getUsername().toLowerCase(); - String username = "admin"; - //构造pod名称 - String podName = username + "-autoMlIns-pod-" + id; - V1Pod pod = k8sClientUtil.createPodWithEnv(podName, serviceNS, proxyUrl, mountPath, pvcName, image); - - if (autoMl.getTimeLeftForThisTask() != null) { - command.append(" --time_left_for_this_task " + autoMl.getTimeLeftForThisTask()); - } - if (autoMl.getPerRunTimeLimit() != null) { - command.append(" --per_run_time_limit " + autoMl.getPerRunTimeLimit()); - } - if (autoMl.getEnsembleSize() != null) { - command.append(" --ensemble_size " + autoMl.getEnsembleSize()); - } - if (StringUtils.isNotEmpty(autoMl.getEnsembleClass())) { - command.append(" --ensemble_class " + autoMl.getEnsembleClass()); - } - if (autoMl.getEnsembleNbest() != null) { - command.append(" --ensemble_nbest " + autoMl.getEnsembleNbest()); - } - if (autoMl.getMaxModelsOnDisc() != null) { - command.append(" --max_models_on_disc " + autoMl.getMaxModelsOnDisc()); - } - if (autoMl.getSeed() != null) { - command.append(" --seed " + autoMl.getSeed()); - } - if (autoMl.getMemoryLimit() != null) { - command.append(" --memory_limit " + autoMl.getMemoryLimit()); - } - if (StringUtils.isNotEmpty(autoMl.getIncludeClassifier())) { - command.append(" --include_classifier " + autoMl.getIncludeClassifier()); - } - if (StringUtils.isNotEmpty(autoMl.getIncludeRegressor())) { - command.append(" --include_regressor " + autoMl.getIncludeRegressor()); - } - if (StringUtils.isNotEmpty(autoMl.getIncludeFeaturePreprocessor())) { - command.append(" --include_feature_preprocessor " + autoMl.getIncludeFeaturePreprocessor()); - } - if (StringUtils.isNotEmpty(autoMl.getExcludeClassifier())) { - command.append(" --exclude_classifier " + autoMl.getExcludeClassifier()); - } - if (StringUtils.isNotEmpty(autoMl.getExcludeRegressor())) { - command.append(" --exclude_regressor " + autoMl.getExcludeRegressor()); - } - if (StringUtils.isNotEmpty(autoMl.getExcludeFeaturePreprocessor())) { - command.append(" --exclude_feature_preprocessor " + autoMl.getExcludeFeaturePreprocessor()); - } - if (StringUtils.isNotEmpty(autoMl.getResamplingStrategy())) { - command.append(" --resampling_strategy " + autoMl.getResamplingStrategy()); - } - if (autoMl.getTrainSize() != null) { - command.append(" --train_size " + autoMl.getTrainSize()); - } - if (autoMl.getShuffle() != null) { - command.append(" --shuffle " + autoMl.getShuffle()); - } - if (autoMl.getFolds() != null) { - command.append(" --folds " + autoMl.getFolds()); - } - command.append(" &"); - CompletableFuture.supplyAsync(() -> { - try { - String log = k8sClientUtil.executeCommand(pod, String.valueOf(command)); - } catch (Exception e) { - logger.error(e.getMessage(), e); + throw new Exception("自动机器学习配置不存在"); + } + + AutoMlParamVo autoMlParam = new AutoMlParamVo(); + BeanUtils.copyProperties(autoMl, autoMlParam); + autoMlParam.setDataset(JsonUtils.jsonToMap(autoMl.getDataset())); + String param = JsonUtils.objectToJson(autoMlParam); + // 调argo转换接口 + try { + String convertRes = HttpUtils.sendPost(argoUrl + convertAutoML, param); + if (convertRes == null || StringUtils.isEmpty(convertRes)) { + throw new RuntimeException("转换流水线失败"); } - return null; - }); + Map converMap = JsonUtils.jsonToMap(convertRes); + // 组装运行接口json + Map output = (Map) converMap.get("output"); + Map runReqMap = new HashMap<>(); + runReqMap.put("data", converMap.get("data")); + // 调argo运行接口 + String runRes = HttpUtils.sendPost(argoUrl + argoWorkflowRun, JsonUtils.mapToJson(runReqMap)); + + if (runRes == null || StringUtils.isEmpty(runRes)) { + throw new RuntimeException("Failed to run workflow."); + } + Map runResMap = JsonUtils.jsonToMap(runRes); + Map data = (Map) runResMap.get("data"); + //判断data为空 + if (data == null || MapUtils.isEmpty(data)) { + throw new RuntimeException("Failed to run workflow."); + } + Map metadata = (Map) data.get("metadata"); + // 插入记录到实验实例表 + AutoMlIns autoMlIns = new AutoMlIns(); + autoMlIns.setAutoMlId(autoMl.getId()); + autoMlIns.setArgoInsNs((String) metadata.get("namespace")); + autoMlIns.setArgoInsName((String) metadata.get("name")); + autoMlIns.setParam(param); + autoMlIns.setStatus(Constant.Pending); + //替换argoInsName + String outputString = JsonUtils.mapToJson(output); + autoMlIns.setNodeResult(outputString.replace("{{workflow.name}}", (String) metadata.get("name"))); + autoMlInsDao.insert(autoMlIns); + + } catch (Exception e) { + throw new RuntimeException(e); + } return "执行成功"; } } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/AutoMlParamVo.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/AutoMlParamVo.java new file mode 100644 index 00000000..540631af --- /dev/null +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/AutoMlParamVo.java @@ -0,0 +1,155 @@ +package com.ruoyi.platform.vo; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import io.swagger.annotations.ApiModel; +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; + +import java.util.Map; + +@Data +@JsonNaming(PropertyNamingStrategy.SnakeCaseStrategy.class) +@JsonInclude(JsonInclude.Include.NON_NULL) +@ApiModel(description = "自动机器学习参数") +public class AutoMlParamVo { + @ApiModelProperty(value = "任务类型:classification或regression") + private String taskType; + + @ApiModelProperty(value = "搜索合适模型的时间限制(以秒为单位)。通过增加这个值,auto-sklearn有更高的机会找到更好的模型。默认3600,非必传。") + private Integer timeLeftForThisTask; + + @ApiModelProperty(value = "单次调用机器学习模型的时间限制(以秒为单位)。如果机器学习算法运行超过时间限制,将终止模型拟合。将这个值设置得足够高,这样典型的机器学习算法就可以适用于训练数据。默认600,非必传。") + private Integer perRunTimeLimit; + + @ApiModelProperty(value = "集成模型数量,如果设置为0,则没有集成。默认50,非必传。") + private Integer ensembleSize; + + @ApiModelProperty(value = "设置为None将禁用集成构建,设置为SingleBest仅使用单个最佳模型而不是集成,设置为default,它将对单目标问题使用EnsembleSelection,对多目标问题使用MultiObjectiveDummyEnsemble。默认default,非必传。") + private String ensembleClass; + + @ApiModelProperty(value = "在构建集成时只考虑ensemble_nbest模型。这是受到了“最大限度地利用集成选择”中引入的库修剪概念的启发。这是独立于ensemble_class参数的,并且这个修剪步骤是在构造集成之前完成的。默认50,非必传。") + private Integer ensembleNbest; + + @ApiModelProperty(value = "定义在磁盘中保存的模型的最大数量。额外的模型数量将被永久删除。由于这个变量的性质,它设置了一个集成可以使用多少个模型的上限。必须是大于等于1的整数。如果设置为None,则所有模型都保留在磁盘上。默认50,非必传。") + private Integer maxModelsOnDisc; + + @ApiModelProperty(value = "随机种子,将决定输出文件名。默认1,非必传。") + private Integer seed; + + @ApiModelProperty(value = "机器学习算法的内存限制(MB)。如果auto-sklearn试图分配超过memory_limit MB,它将停止拟合机器学习算法。默认3072,非必传。") + private Integer memoryLimit; + + @ApiModelProperty(value = "如果为None,则使用所有可能的分类算法。否则,指定搜索中包含的步骤和组件。有关可用组件,请参见/pipeline/components//*。与参数exclude不兼容。多选,逗号分隔。包含:adaboost\n" + + "bernoulli_nb\n" + + "decision_tree\n" + + "extra_trees\n" + + "gaussian_nb\n" + + "gradient_boosting\n" + + "k_nearest_neighbors\n" + + "lda\n" + + "liblinear_svc\n" + + "libsvm_svc\n" + + "mlp\n" + + "multinomial_nb\n" + + "passive_aggressive\n" + + "qda\n" + + "random_forest\n" + + "sgd") + private String includeClassifier; + + @ApiModelProperty(value = "如果为None,则使用所有可能的特征预处理算法。否则,指定搜索中包含的步骤和组件。有关可用组件,请参见/pipeline/components//*。与参数exclude不兼容。多选,逗号分隔。包含:densifier\n" + + "extra_trees_preproc_for_classification\n" + + "extra_trees_preproc_for_regression\n" + + "fast_ica\n" + + "feature_agglomeration\n" + + "kernel_pca\n" + + "kitchen_sinks\n" + + "liblinear_svc_preprocessor\n" + + "no_preprocessing\n" + + "nystroem_sampler\n" + + "pca\n" + + "polynomial\n" + + "random_trees_embedding\n" + + "select_percentile_classification\n" + + "select_percentile_regression\n" + + "select_rates_classification\n" + + "select_rates_regression\n" + + "truncatedSVD") + private String includeFeaturePreprocessor; + + @ApiModelProperty(value = "如果为None,则使用所有可能的回归算法。否则,指定搜索中包含的步骤和组件。有关可用组件,请参见/pipeline/components//*。与参数exclude不兼容。多选,逗号分隔。包含:adaboost,\n" + + "ard_regression,\n" + + "decision_tree,\n" + + "extra_trees,\n" + + "gaussian_process,\n" + + "gradient_boosting,\n" + + "k_nearest_neighbors,\n" + + "liblinear_svr,\n" + + "libsvm_svr,\n" + + "mlp,\n" + + "random_forest,\n" + + "sgd") + private String includeRegressor; + + private String excludeClassifier; + + private String excludeRegressor; + + private String excludeFeaturePreprocessor; + + @ApiModelProperty(value = "测试集的比率,0到1之间") + private Float testSize; + + @ApiModelProperty(value = "如何处理过拟合,如果使用基于“cv”的方法或Splitter对象,可能需要使用resampling_strategy_arguments。holdout或crossValid") + private String resamplingStrategy; + + @ApiModelProperty(value = "重采样划分训练集和验证集,训练集的比率,0到1之间") + private Float trainSize; + + @ApiModelProperty(value = "拆分数据前是否进行shuffle") + private Boolean shuffle; + + @ApiModelProperty(value = "交叉验证的折数,当resamplingStrategy为crossValid时,此项必填,为整数") + private Integer folds; + + @ApiModelProperty(value = "数据集csv文件中哪几列是预测目标列,逗号分隔") + private String targetColumns; + + @ApiModelProperty(value = "自定义指标名称") + private String metricName; + + @ApiModelProperty(value = "模型优化目标指标及权重,json格式。分类的指标包含:accuracy\n" + + "balanced_accuracy\n" + + "roc_auc\n" + + "average_precision\n" + + "log_loss\n" + + "precision_macro\n" + + "precision_micro\n" + + "precision_samples\n" + + "precision_weighted\n" + + "recall_macro\n" + + "recall_micro\n" + + "recall_samples\n" + + "recall_weighted\n" + + "f1_macro\n" + + "f1_micro\n" + + "f1_samples\n" + + "f1_weighted\n" + + "回归的指标包含:mean_absolute_error\n" + + "mean_squared_error\n" + + "root_mean_squared_error\n" + + "mean_squared_log_error\n" + + "median_absolute_error\n" + + "r2") + private String metrics; + + @ApiModelProperty(value = "指标优化方向,是越大越好还是越小越好") + private Boolean greaterIsBetter; + + @ApiModelProperty(value = "模型计算并打印指标") + private String scoringFunctions; + + private Map dataset; +} diff --git a/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/AutoMlDao.xml b/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/AutoMlDao.xml index b6a3c4e7..610fe143 100644 --- a/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/AutoMlDao.xml +++ b/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/AutoMlDao.xml @@ -34,9 +34,9 @@ ml_description = #{autoMl.mlDescription}, - - - + + status_list = #{autoMl.statusList}, + diff --git a/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/AutoMlInsDao.xml b/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/AutoMlInsDao.xml index 05168691..ca1300d3 100644 --- a/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/AutoMlInsDao.xml +++ b/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/AutoMlInsDao.xml @@ -3,9 +3,9 @@ - insert into auto_ml_ins(auto_ml_id, model_path, img_path, node_status, node_result, param, source) + insert into auto_ml_ins(auto_ml_id, model_path, img_path, node_status, node_result, param, source, argo_ins_name, argo_ins_ns) values (#{autoMlIns.autoMlId}, #{autoMlIns.modelPath}, #{autoMlIns.imgPath}, #{autoMlIns.nodeStatus}, - #{autoMlIns.nodeResult}, #{autoMlIns.param}, #{autoMlIns.source}) + #{autoMlIns.nodeResult}, #{autoMlIns.param}, #{autoMlIns.source}, #{autoMlIns.argoInsName}, #{autoMlIns.argoInsNs}) @@ -29,7 +29,11 @@ state = #{autoMlIns.state}, + + update_time = #{autoMlIns.updateTime}, + + where id = #{autoMlIns.id} + + + + \ No newline at end of file