| @@ -63,7 +63,7 @@ public class AutoMlController extends BaseController { | |||
| return AjaxResult.success(this.autoMlService.upload(file, uuid)); | |||
| } | |||
| @PostMapping("{id}") | |||
| @PostMapping("/run/{id}") | |||
| @ApiOperation("运行自动机器学习实验") | |||
| public GenericsAjaxResult<String> runAutoML(@PathVariable("id") Long id) throws Exception { | |||
| return genericsSuccess(this.autoMlService.runAutoMlIns(id)); | |||
| @@ -46,4 +46,10 @@ public class AutoMlInsController extends BaseController { | |||
| public GenericsAjaxResult<String> batchDelete(@RequestBody List<Long> ids) { | |||
| return genericsSuccess(this.autoMLInsService.batchDelete(ids)); | |||
| } | |||
| @PutMapping("{id}") | |||
| @ApiOperation("终止实验实例") | |||
| public GenericsAjaxResult<Boolean> terminateAutoMlIns(@PathVariable("id") Long id) { | |||
| return genericsSuccess(this.autoMLInsService.terminateAutoMlIns(id)); | |||
| } | |||
| } | |||
| @@ -179,4 +179,6 @@ public class AutoMl { | |||
| private String dataset; | |||
| @ApiModelProperty(value = "状态列表") | |||
| private String statusList; | |||
| } | |||
| @@ -3,6 +3,7 @@ package com.ruoyi.platform.domain; | |||
| import com.fasterxml.jackson.databind.PropertyNamingStrategy; | |||
| import com.fasterxml.jackson.databind.annotation.JsonNaming; | |||
| import io.swagger.annotations.ApiModel; | |||
| import io.swagger.annotations.ApiModelProperty; | |||
| import lombok.Data; | |||
| import java.util.Date; | |||
| @@ -31,6 +32,12 @@ public class AutoMlIns { | |||
| private String source; | |||
| @ApiModelProperty(value = "Argo实例名称") | |||
| private String argoInsName; | |||
| @ApiModelProperty(value = "Argo命名空间") | |||
| private String argoInsNs; | |||
| private Date createTime; | |||
| private Date updateTime; | |||
| @@ -11,9 +11,13 @@ public interface AutoMlInsDao { | |||
| List<AutoMlIns> queryAllByLimit(@Param("autoMlIns") AutoMlIns autoMlIns, @Param("pageable") Pageable pageable); | |||
| List<AutoMlIns> getByAutoMlId(@Param("autoMlId") Long AutoMlId); | |||
| int insert(@Param("autoMlIns") AutoMlIns autoMlIns); | |||
| int update(@Param("autoMlIns") AutoMlIns autoMlIns); | |||
| AutoMlIns queryById(@Param("id") Long id); | |||
| List<AutoMlIns> queryByAutoMlInsIsNotTerminated(); | |||
| } | |||
| @@ -0,0 +1,97 @@ | |||
| package com.ruoyi.platform.scheduling; | |||
| import com.ruoyi.platform.domain.AutoMl; | |||
| import com.ruoyi.platform.domain.AutoMlIns; | |||
| import com.ruoyi.platform.mapper.AutoMlDao; | |||
| import com.ruoyi.platform.mapper.AutoMlInsDao; | |||
| import com.ruoyi.platform.service.AutoMlInsService; | |||
| import org.apache.commons.lang3.StringUtils; | |||
| import org.springframework.scheduling.annotation.Scheduled; | |||
| import org.springframework.stereotype.Component; | |||
| import javax.annotation.Resource; | |||
| import java.util.ArrayList; | |||
| import java.util.Iterator; | |||
| import java.util.List; | |||
| @Component() | |||
| public class AutoMlInsStatusTask { | |||
| @Resource | |||
| private AutoMlInsService autoMlInsService; | |||
| @Resource | |||
| private AutoMlInsDao autoMlInsDao; | |||
| @Resource | |||
| private AutoMlDao autoMlDao; | |||
| private List<Long> autoMlIds = new ArrayList<>(); | |||
| @Scheduled(cron = "0/30 * * * * ?") // 每30S执行一次 | |||
| public void executeAutoMlInsStatus() throws Exception { | |||
| // 首先查到所有非终止态的实验实例 | |||
| List<AutoMlIns> autoMlInsList = autoMlInsService.queryByAutoMlInsIsNotTerminated(); | |||
| // 去argo查询状态 | |||
| List<AutoMlIns> updateList = new ArrayList<>(); | |||
| if (autoMlInsList != null && autoMlInsList.size() > 0) { | |||
| for (AutoMlIns autoMlIns : autoMlInsList) { | |||
| //当原本状态为null或非终止态时才调用argo接口 | |||
| try { | |||
| autoMlIns = autoMlInsService.queryStatusFromArgo(autoMlIns); | |||
| } catch (Exception e) { | |||
| autoMlIns.setStatus("Failed"); | |||
| } | |||
| // 线程安全的添加操作 | |||
| synchronized (autoMlIds) { | |||
| autoMlIds.add(autoMlIns.getAutoMlId()); | |||
| } | |||
| updateList.add(autoMlIns); | |||
| } | |||
| if (updateList.size() > 0) { | |||
| for (AutoMlIns autoMlIns : updateList) { | |||
| autoMlInsDao.update(autoMlIns); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @Scheduled(cron = "0/30 * * * * ?") // / 每30S执行一次 | |||
| public void executeAutoMlStatus() throws Exception { | |||
| if (autoMlIds.size() == 0) { | |||
| return; | |||
| } | |||
| // 存储需要更新的实验对象列表 | |||
| List<AutoMl> updateAutoMls = new ArrayList<>(); | |||
| for (Long autoMlId : autoMlIds) { | |||
| // 获取当前实验的所有实例列表 | |||
| List<AutoMlIns> insList = autoMlInsDao.getByAutoMlId(autoMlId); | |||
| List<String> statusList = new ArrayList<>(); | |||
| // 更新实验状态列表 | |||
| for (int i = 0; i < insList.size(); i++) { | |||
| statusList.add(insList.get(i).getStatus()); | |||
| } | |||
| String subStatus = statusList.toString().substring(1, statusList.toString().length() - 1); | |||
| AutoMl autoMl = autoMlDao.getAutoMlById(autoMlId); | |||
| if (!StringUtils.equals(autoMl.getStatusList(), subStatus)) { | |||
| autoMl.setStatusList(subStatus); | |||
| updateAutoMls.add(autoMl); | |||
| autoMlDao.edit(autoMl); | |||
| } | |||
| } | |||
| if (!updateAutoMls.isEmpty()) { | |||
| // 使用Iterator进行安全的删除操作 | |||
| Iterator<Long> iterator = autoMlIds.iterator(); | |||
| while (iterator.hasNext()) { | |||
| Long autoMlId = iterator.next(); | |||
| for (AutoMl autoMl : updateAutoMls) { | |||
| if (autoMl.getId().equals(autoMlId)) { | |||
| iterator.remove(); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -15,4 +15,10 @@ public interface AutoMlInsService { | |||
| String removeById(Long id); | |||
| String batchDelete(List<Long> ids); | |||
| List<AutoMlIns> queryByAutoMlInsIsNotTerminated(); | |||
| AutoMlIns queryStatusFromArgo(AutoMlIns autoMlIns); | |||
| boolean terminateAutoMlIns(Long id); | |||
| } | |||
| @@ -4,6 +4,11 @@ import com.ruoyi.platform.constant.Constant; | |||
| import com.ruoyi.platform.domain.AutoMlIns; | |||
| import com.ruoyi.platform.mapper.AutoMlInsDao; | |||
| import com.ruoyi.platform.service.AutoMlInsService; | |||
| import com.ruoyi.platform.utils.DateUtils; | |||
| import com.ruoyi.platform.utils.HttpUtils; | |||
| import com.ruoyi.platform.utils.JsonUtils; | |||
| import org.apache.commons.lang3.StringUtils; | |||
| import org.springframework.beans.factory.annotation.Value; | |||
| import org.springframework.data.domain.Page; | |||
| import org.springframework.data.domain.PageImpl; | |||
| import org.springframework.data.domain.PageRequest; | |||
| @@ -11,10 +16,17 @@ import org.springframework.stereotype.Service; | |||
| import javax.annotation.Resource; | |||
| import java.io.IOException; | |||
| import java.util.List; | |||
| import java.util.*; | |||
| @Service | |||
| public class AutoMlInsServiceImpl implements AutoMlInsService { | |||
| @Value("${argo.url}") | |||
| private String argoUrl; | |||
| @Value("${argo.workflowStatus}") | |||
| private String argoWorkflowStatus; | |||
| @Value("${argo.workflowTermination}") | |||
| private String argoWorkflowTermination; | |||
| @Resource | |||
| private AutoMlInsDao autoMlInsDao; | |||
| @@ -38,6 +50,14 @@ public class AutoMlInsServiceImpl implements AutoMlInsService { | |||
| if (autoMlIns == null) { | |||
| return "实验实例不存在"; | |||
| } | |||
| if (StringUtils.isEmpty(autoMlIns.getStatus())) { | |||
| autoMlIns = queryStatusFromArgo(autoMlIns); | |||
| } | |||
| if (StringUtils.equals(autoMlIns.getStatus(), "Running")) { | |||
| return "实验实例正在运行,不可删除"; | |||
| } | |||
| autoMlIns.setState(Constant.State_invalid); | |||
| int update = autoMlInsDao.update(autoMlIns); | |||
| if (update > 0) { | |||
| @@ -57,4 +77,150 @@ public class AutoMlInsServiceImpl implements AutoMlInsService { | |||
| } | |||
| return "删除成功"; | |||
| } | |||
| @Override | |||
| public List<AutoMlIns> queryByAutoMlInsIsNotTerminated() { | |||
| return autoMlInsDao.queryByAutoMlInsIsNotTerminated(); | |||
| } | |||
| @Override | |||
| public AutoMlIns queryStatusFromArgo(AutoMlIns ins) { | |||
| String namespace = ins.getArgoInsNs(); | |||
| String name = ins.getArgoInsName(); | |||
| Long id = ins.getId(); | |||
| // 创建请求数据map | |||
| Map<String, Object> requestData = new HashMap<>(); | |||
| requestData.put("namespace", namespace); | |||
| requestData.put("name", name); | |||
| // 创建发送数据map,将请求数据作为"data"键的值 | |||
| Map<String, Object> res = new HashMap<>(); | |||
| res.put("data", requestData); | |||
| try { | |||
| // 发送POST请求到Argo工作流状态查询接口,并将请求数据转换为JSON | |||
| String req = HttpUtils.sendPost(argoUrl + argoWorkflowStatus, null, JsonUtils.mapToJson(res)); | |||
| // 检查响应是否为空或无内容 | |||
| if (req == null || StringUtils.isEmpty(req)) { | |||
| throw new RuntimeException("工作流状态响应为空。"); | |||
| } | |||
| // 将响应的JSON字符串转换为Map对象 | |||
| Map<String, Object> runResMap = JsonUtils.jsonToMap(req); | |||
| // 从响应Map中获取"data"部分 | |||
| Map<String, Object> data = (Map<String, Object>) runResMap.get("data"); | |||
| if (data == null || data.isEmpty()) { | |||
| throw new RuntimeException("工作流数据为空."); | |||
| } | |||
| // 从"data"中获取"status"部分,并返回"phase"的值 | |||
| Map<String, Object> status = (Map<String, Object>) data.get("status"); | |||
| if (status == null || status.isEmpty()) { | |||
| throw new RuntimeException("工作流状态为空。"); | |||
| } | |||
| //解析流水线结束时间 | |||
| String finishedAtString = (String) status.get("finishedAt"); | |||
| if (finishedAtString != null && !finishedAtString.isEmpty()) { | |||
| Date finishTime = DateUtils.convertUTCtoShanghaiDate(finishedAtString); | |||
| ins.setUpdateTime(finishTime); | |||
| } | |||
| // 解析nodes字段,提取节点状态并转换为JSON字符串 | |||
| Map<String, Object> nodes = (Map<String, Object>) status.get("nodes"); | |||
| Map<String, Object> modifiedNodes = new LinkedHashMap<>(); | |||
| if (nodes != null) { | |||
| for (Map.Entry<String, Object> nodeEntry : nodes.entrySet()) { | |||
| Map<String, Object> nodeDetails = (Map<String, Object>) nodeEntry.getValue(); | |||
| String templateName = (String) nodeDetails.get("displayName"); | |||
| modifiedNodes.put(templateName, nodeDetails); | |||
| } | |||
| } | |||
| String nodeStatusJson = JsonUtils.mapToJson(modifiedNodes); | |||
| ins.setNodeStatus(nodeStatusJson); | |||
| //终止态为终止不改 | |||
| if (!StringUtils.equals(ins.getStatus(), "Terminated")) { | |||
| ins.setStatus(StringUtils.isNotEmpty((String) status.get("phase")) ? (String) status.get("phase") : "Pending"); | |||
| } | |||
| if (StringUtils.equals(ins.getStatus(), "Error")) { | |||
| ins.setStatus("Failed"); | |||
| } | |||
| return ins; | |||
| } catch (Exception e) { | |||
| throw new RuntimeException("查询状态失败: " + e.getMessage(), e); | |||
| } | |||
| } | |||
| @Override | |||
| public boolean terminateAutoMlIns(Long id) { | |||
| AutoMlIns autoMlIns = autoMlInsDao.queryById(id); | |||
| if (autoMlIns == null) { | |||
| throw new IllegalStateException("实验实例未查询到,id: " + id); | |||
| } | |||
| String currentStatus = autoMlIns.getStatus(); | |||
| String name = autoMlIns.getArgoInsName(); | |||
| String namespace = autoMlIns.getArgoInsNs(); | |||
| // 获取当前状态,如果为空,则从Argo查询 | |||
| if (StringUtils.isEmpty(currentStatus)) { | |||
| currentStatus = queryStatusFromArgo(autoMlIns).getStatus(); | |||
| } | |||
| // 只有状态是"Running"时才能终止实例 | |||
| if (!currentStatus.equalsIgnoreCase(Constant.Running)) { | |||
| return false; // 如果不是"Running"状态,则不执行终止操作 | |||
| } | |||
| // 创建请求数据map | |||
| Map<String, Object> requestData = new HashMap<>(); | |||
| requestData.put("namespace", namespace); | |||
| requestData.put("name", name); | |||
| // 创建发送数据map,将请求数据作为"data"键的值 | |||
| Map<String, Object> res = new HashMap<>(); | |||
| res.put("data", requestData); | |||
| try { | |||
| // 发送POST请求到Argo工作流状态查询接口,并将请求数据转换为JSON | |||
| String req = HttpUtils.sendPost(argoUrl + argoWorkflowTermination, null, JsonUtils.mapToJson(res)); | |||
| // 检查响应是否为空或无内容 | |||
| if (StringUtils.isEmpty(req)) { | |||
| throw new RuntimeException("终止响应内容为空。"); | |||
| } | |||
| // 将响应的JSON字符串转换为Map对象 | |||
| Map<String, Object> runResMap = JsonUtils.jsonToMap(req); | |||
| // 从响应Map中直接获取"errCode"的值 | |||
| Integer errCode = (Integer) runResMap.get("errCode"); | |||
| if (errCode != null && errCode == 0) { | |||
| //更新autoMlIns,确保状态更新被保存到数据库 | |||
| AutoMlIns ins = queryStatusFromArgo(autoMlIns); | |||
| String nodeStatus = ins.getNodeStatus(); | |||
| Map<String, Object> nodeMap = JsonUtils.jsonToMap(nodeStatus); | |||
| // 遍历 map | |||
| for (Map.Entry<String, Object> entry : nodeMap.entrySet()) { | |||
| // 获取每个 Map 中的值并强制转换为 Map | |||
| Map<String, Object> innerMap = (Map<String, Object>) entry.getValue(); | |||
| // 检查 phase 的值 | |||
| if (innerMap.containsKey("phase")) { | |||
| String phaseValue = (String) innerMap.get("phase"); | |||
| // 如果值不等于 Succeeded,则赋值为 Failed | |||
| if (!StringUtils.equals("Succeeded", phaseValue)) { | |||
| innerMap.put("phase", "Failed"); | |||
| } | |||
| } | |||
| } | |||
| ins.setNodeStatus(JsonUtils.mapToJson(nodeMap)); | |||
| ins.setStatus(Constant.Terminated); | |||
| ins.setUpdateTime(new Date()); | |||
| this.autoMlInsDao.update(ins); | |||
| return true; | |||
| } else { | |||
| return false; | |||
| } | |||
| } catch (Exception e) { | |||
| throw new RuntimeException("终止实例错误: " + e.getMessage(), e); | |||
| } | |||
| } | |||
| } | |||
| @@ -3,15 +3,19 @@ package com.ruoyi.platform.service.impl; | |||
| import com.ruoyi.common.security.utils.SecurityUtils; | |||
| import com.ruoyi.platform.constant.Constant; | |||
| import com.ruoyi.platform.domain.AutoMl; | |||
| import com.ruoyi.platform.domain.AutoMlIns; | |||
| import com.ruoyi.platform.mapper.AutoMlDao; | |||
| import com.ruoyi.platform.mapper.AutoMlInsDao; | |||
| import com.ruoyi.platform.service.AutoMlService; | |||
| import com.ruoyi.platform.utils.*; | |||
| import com.ruoyi.platform.utils.FileUtil; | |||
| import com.ruoyi.platform.utils.HttpUtils; | |||
| import com.ruoyi.platform.utils.JacksonUtil; | |||
| import com.ruoyi.platform.utils.JsonUtils; | |||
| import com.ruoyi.platform.vo.AutoMlParamVo; | |||
| import com.ruoyi.platform.vo.AutoMlVo; | |||
| import io.kubernetes.client.openapi.models.V1Pod; | |||
| import org.apache.commons.collections4.MapUtils; | |||
| import org.apache.commons.io.FileUtils; | |||
| import org.apache.commons.lang3.StringUtils; | |||
| import org.slf4j.Logger; | |||
| import org.slf4j.LoggerFactory; | |||
| import org.springframework.beans.BeanUtils; | |||
| import org.springframework.beans.factory.annotation.Value; | |||
| import org.springframework.data.domain.Page; | |||
| @@ -26,29 +30,24 @@ import java.io.IOException; | |||
| import java.util.HashMap; | |||
| import java.util.List; | |||
| import java.util.Map; | |||
| import java.util.concurrent.CompletableFuture; | |||
| @Service("autoMLService") | |||
| public class AutoMlServiceImpl implements AutoMlService { | |||
| @Value("${harbor.serviceNS}") | |||
| private String serviceNS; | |||
| @Value("${dockerpush.proxyUrl}") | |||
| private String proxyUrl; | |||
| @Value("${dockerpush.mountPath}") | |||
| private String mountPath; | |||
| @Value("${minio.pvcName}") | |||
| private String pvcName; | |||
| @Value("${automl.image}") | |||
| private String image; | |||
| @Value("${git.localPath}") | |||
| String localPath; | |||
| private static final Logger logger = LoggerFactory.getLogger(ModelsServiceImpl.class); | |||
| @Value("${argo.url}") | |||
| private String argoUrl; | |||
| @Value("${argo.convertAutoML}") | |||
| String convertAutoML; | |||
| @Value("${argo.workflowRun}") | |||
| private String argoWorkflowRun; | |||
| @Resource | |||
| private AutoMlDao autoMlDao; | |||
| @Resource | |||
| private K8sClientUtil k8sClientUtil; | |||
| private AutoMlInsDao autoMlInsDao; | |||
| @Override | |||
| public Page<AutoMl> queryByPage(String mlName, PageRequest pageRequest) { | |||
| @@ -146,86 +145,52 @@ public class AutoMlServiceImpl implements AutoMlService { | |||
| public String runAutoMlIns(Long id) throws Exception { | |||
| AutoMl autoMl = autoMlDao.getAutoMlById(id); | |||
| if (autoMl == null) { | |||
| throw new Exception("开发环境配置不存在"); | |||
| } | |||
| StringBuffer command = new StringBuffer(); | |||
| command.append("nohup python /opt/automl.py --task_type " + autoMl.getTaskType()); | |||
| if (StringUtils.isNotEmpty(autoMl.getTargetColumns())) { | |||
| command.append(" --target_columns " + autoMl.getTargetColumns()); | |||
| } else { | |||
| throw new Exception("目标列为空"); | |||
| } | |||
| // String username = SecurityUtils.getLoginUser().getUsername().toLowerCase(); | |||
| String username = "admin"; | |||
| //构造pod名称 | |||
| String podName = username + "-autoMlIns-pod-" + id; | |||
| V1Pod pod = k8sClientUtil.createPodWithEnv(podName, serviceNS, proxyUrl, mountPath, pvcName, image); | |||
| if (autoMl.getTimeLeftForThisTask() != null) { | |||
| command.append(" --time_left_for_this_task " + autoMl.getTimeLeftForThisTask()); | |||
| } | |||
| if (autoMl.getPerRunTimeLimit() != null) { | |||
| command.append(" --per_run_time_limit " + autoMl.getPerRunTimeLimit()); | |||
| } | |||
| if (autoMl.getEnsembleSize() != null) { | |||
| command.append(" --ensemble_size " + autoMl.getEnsembleSize()); | |||
| } | |||
| if (StringUtils.isNotEmpty(autoMl.getEnsembleClass())) { | |||
| command.append(" --ensemble_class " + autoMl.getEnsembleClass()); | |||
| } | |||
| if (autoMl.getEnsembleNbest() != null) { | |||
| command.append(" --ensemble_nbest " + autoMl.getEnsembleNbest()); | |||
| } | |||
| if (autoMl.getMaxModelsOnDisc() != null) { | |||
| command.append(" --max_models_on_disc " + autoMl.getMaxModelsOnDisc()); | |||
| } | |||
| if (autoMl.getSeed() != null) { | |||
| command.append(" --seed " + autoMl.getSeed()); | |||
| } | |||
| if (autoMl.getMemoryLimit() != null) { | |||
| command.append(" --memory_limit " + autoMl.getMemoryLimit()); | |||
| } | |||
| if (StringUtils.isNotEmpty(autoMl.getIncludeClassifier())) { | |||
| command.append(" --include_classifier " + autoMl.getIncludeClassifier()); | |||
| } | |||
| if (StringUtils.isNotEmpty(autoMl.getIncludeRegressor())) { | |||
| command.append(" --include_regressor " + autoMl.getIncludeRegressor()); | |||
| } | |||
| if (StringUtils.isNotEmpty(autoMl.getIncludeFeaturePreprocessor())) { | |||
| command.append(" --include_feature_preprocessor " + autoMl.getIncludeFeaturePreprocessor()); | |||
| } | |||
| if (StringUtils.isNotEmpty(autoMl.getExcludeClassifier())) { | |||
| command.append(" --exclude_classifier " + autoMl.getExcludeClassifier()); | |||
| } | |||
| if (StringUtils.isNotEmpty(autoMl.getExcludeRegressor())) { | |||
| command.append(" --exclude_regressor " + autoMl.getExcludeRegressor()); | |||
| } | |||
| if (StringUtils.isNotEmpty(autoMl.getExcludeFeaturePreprocessor())) { | |||
| command.append(" --exclude_feature_preprocessor " + autoMl.getExcludeFeaturePreprocessor()); | |||
| } | |||
| if (StringUtils.isNotEmpty(autoMl.getResamplingStrategy())) { | |||
| command.append(" --resampling_strategy " + autoMl.getResamplingStrategy()); | |||
| } | |||
| if (autoMl.getTrainSize() != null) { | |||
| command.append(" --train_size " + autoMl.getTrainSize()); | |||
| } | |||
| if (autoMl.getShuffle() != null) { | |||
| command.append(" --shuffle " + autoMl.getShuffle()); | |||
| } | |||
| if (autoMl.getFolds() != null) { | |||
| command.append(" --folds " + autoMl.getFolds()); | |||
| } | |||
| command.append(" &"); | |||
| CompletableFuture.supplyAsync(() -> { | |||
| try { | |||
| String log = k8sClientUtil.executeCommand(pod, String.valueOf(command)); | |||
| } catch (Exception e) { | |||
| logger.error(e.getMessage(), e); | |||
| throw new Exception("自动机器学习配置不存在"); | |||
| } | |||
| AutoMlParamVo autoMlParam = new AutoMlParamVo(); | |||
| BeanUtils.copyProperties(autoMl, autoMlParam); | |||
| autoMlParam.setDataset(JsonUtils.jsonToMap(autoMl.getDataset())); | |||
| String param = JsonUtils.objectToJson(autoMlParam); | |||
| // 调argo转换接口 | |||
| try { | |||
| String convertRes = HttpUtils.sendPost(argoUrl + convertAutoML, param); | |||
| if (convertRes == null || StringUtils.isEmpty(convertRes)) { | |||
| throw new RuntimeException("转换流水线失败"); | |||
| } | |||
| return null; | |||
| }); | |||
| Map<String, Object> converMap = JsonUtils.jsonToMap(convertRes); | |||
| // 组装运行接口json | |||
| Map<String, Object> output = (Map<String, Object>) converMap.get("output"); | |||
| Map<String, Object> runReqMap = new HashMap<>(); | |||
| runReqMap.put("data", converMap.get("data")); | |||
| // 调argo运行接口 | |||
| String runRes = HttpUtils.sendPost(argoUrl + argoWorkflowRun, JsonUtils.mapToJson(runReqMap)); | |||
| if (runRes == null || StringUtils.isEmpty(runRes)) { | |||
| throw new RuntimeException("Failed to run workflow."); | |||
| } | |||
| Map<String, Object> runResMap = JsonUtils.jsonToMap(runRes); | |||
| Map<String, Object> data = (Map<String, Object>) runResMap.get("data"); | |||
| //判断data为空 | |||
| if (data == null || MapUtils.isEmpty(data)) { | |||
| throw new RuntimeException("Failed to run workflow."); | |||
| } | |||
| Map<String, Object> metadata = (Map<String, Object>) data.get("metadata"); | |||
| // 插入记录到实验实例表 | |||
| AutoMlIns autoMlIns = new AutoMlIns(); | |||
| autoMlIns.setAutoMlId(autoMl.getId()); | |||
| autoMlIns.setArgoInsNs((String) metadata.get("namespace")); | |||
| autoMlIns.setArgoInsName((String) metadata.get("name")); | |||
| autoMlIns.setParam(param); | |||
| autoMlIns.setStatus(Constant.Pending); | |||
| //替换argoInsName | |||
| String outputString = JsonUtils.mapToJson(output); | |||
| autoMlIns.setNodeResult(outputString.replace("{{workflow.name}}", (String) metadata.get("name"))); | |||
| autoMlInsDao.insert(autoMlIns); | |||
| } catch (Exception e) { | |||
| throw new RuntimeException(e); | |||
| } | |||
| return "执行成功"; | |||
| } | |||
| } | |||
| @@ -0,0 +1,155 @@ | |||
| package com.ruoyi.platform.vo; | |||
| import com.fasterxml.jackson.annotation.JsonInclude; | |||
| import com.fasterxml.jackson.databind.PropertyNamingStrategy; | |||
| import com.fasterxml.jackson.databind.annotation.JsonNaming; | |||
| import io.swagger.annotations.ApiModel; | |||
| import io.swagger.annotations.ApiModelProperty; | |||
| import lombok.Data; | |||
| import java.util.Map; | |||
| @Data | |||
| @JsonNaming(PropertyNamingStrategy.SnakeCaseStrategy.class) | |||
| @JsonInclude(JsonInclude.Include.NON_NULL) | |||
| @ApiModel(description = "自动机器学习参数") | |||
| public class AutoMlParamVo { | |||
| @ApiModelProperty(value = "任务类型:classification或regression") | |||
| private String taskType; | |||
| @ApiModelProperty(value = "搜索合适模型的时间限制(以秒为单位)。通过增加这个值,auto-sklearn有更高的机会找到更好的模型。默认3600,非必传。") | |||
| private Integer timeLeftForThisTask; | |||
| @ApiModelProperty(value = "单次调用机器学习模型的时间限制(以秒为单位)。如果机器学习算法运行超过时间限制,将终止模型拟合。将这个值设置得足够高,这样典型的机器学习算法就可以适用于训练数据。默认600,非必传。") | |||
| private Integer perRunTimeLimit; | |||
| @ApiModelProperty(value = "集成模型数量,如果设置为0,则没有集成。默认50,非必传。") | |||
| private Integer ensembleSize; | |||
| @ApiModelProperty(value = "设置为None将禁用集成构建,设置为SingleBest仅使用单个最佳模型而不是集成,设置为default,它将对单目标问题使用EnsembleSelection,对多目标问题使用MultiObjectiveDummyEnsemble。默认default,非必传。") | |||
| private String ensembleClass; | |||
| @ApiModelProperty(value = "在构建集成时只考虑ensemble_nbest模型。这是受到了“最大限度地利用集成选择”中引入的库修剪概念的启发。这是独立于ensemble_class参数的,并且这个修剪步骤是在构造集成之前完成的。默认50,非必传。") | |||
| private Integer ensembleNbest; | |||
| @ApiModelProperty(value = "定义在磁盘中保存的模型的最大数量。额外的模型数量将被永久删除。由于这个变量的性质,它设置了一个集成可以使用多少个模型的上限。必须是大于等于1的整数。如果设置为None,则所有模型都保留在磁盘上。默认50,非必传。") | |||
| private Integer maxModelsOnDisc; | |||
| @ApiModelProperty(value = "随机种子,将决定输出文件名。默认1,非必传。") | |||
| private Integer seed; | |||
| @ApiModelProperty(value = "机器学习算法的内存限制(MB)。如果auto-sklearn试图分配超过memory_limit MB,它将停止拟合机器学习算法。默认3072,非必传。") | |||
| private Integer memoryLimit; | |||
| @ApiModelProperty(value = "如果为None,则使用所有可能的分类算法。否则,指定搜索中包含的步骤和组件。有关可用组件,请参见/pipeline/components/<step>/*。与参数exclude不兼容。多选,逗号分隔。包含:adaboost\n" + | |||
| "bernoulli_nb\n" + | |||
| "decision_tree\n" + | |||
| "extra_trees\n" + | |||
| "gaussian_nb\n" + | |||
| "gradient_boosting\n" + | |||
| "k_nearest_neighbors\n" + | |||
| "lda\n" + | |||
| "liblinear_svc\n" + | |||
| "libsvm_svc\n" + | |||
| "mlp\n" + | |||
| "multinomial_nb\n" + | |||
| "passive_aggressive\n" + | |||
| "qda\n" + | |||
| "random_forest\n" + | |||
| "sgd") | |||
| private String includeClassifier; | |||
| @ApiModelProperty(value = "如果为None,则使用所有可能的特征预处理算法。否则,指定搜索中包含的步骤和组件。有关可用组件,请参见/pipeline/components/<step>/*。与参数exclude不兼容。多选,逗号分隔。包含:densifier\n" + | |||
| "extra_trees_preproc_for_classification\n" + | |||
| "extra_trees_preproc_for_regression\n" + | |||
| "fast_ica\n" + | |||
| "feature_agglomeration\n" + | |||
| "kernel_pca\n" + | |||
| "kitchen_sinks\n" + | |||
| "liblinear_svc_preprocessor\n" + | |||
| "no_preprocessing\n" + | |||
| "nystroem_sampler\n" + | |||
| "pca\n" + | |||
| "polynomial\n" + | |||
| "random_trees_embedding\n" + | |||
| "select_percentile_classification\n" + | |||
| "select_percentile_regression\n" + | |||
| "select_rates_classification\n" + | |||
| "select_rates_regression\n" + | |||
| "truncatedSVD") | |||
| private String includeFeaturePreprocessor; | |||
| @ApiModelProperty(value = "如果为None,则使用所有可能的回归算法。否则,指定搜索中包含的步骤和组件。有关可用组件,请参见/pipeline/components/<step>/*。与参数exclude不兼容。多选,逗号分隔。包含:adaboost,\n" + | |||
| "ard_regression,\n" + | |||
| "decision_tree,\n" + | |||
| "extra_trees,\n" + | |||
| "gaussian_process,\n" + | |||
| "gradient_boosting,\n" + | |||
| "k_nearest_neighbors,\n" + | |||
| "liblinear_svr,\n" + | |||
| "libsvm_svr,\n" + | |||
| "mlp,\n" + | |||
| "random_forest,\n" + | |||
| "sgd") | |||
| private String includeRegressor; | |||
| private String excludeClassifier; | |||
| private String excludeRegressor; | |||
| private String excludeFeaturePreprocessor; | |||
| @ApiModelProperty(value = "测试集的比率,0到1之间") | |||
| private Float testSize; | |||
| @ApiModelProperty(value = "如何处理过拟合,如果使用基于“cv”的方法或Splitter对象,可能需要使用resampling_strategy_arguments。holdout或crossValid") | |||
| private String resamplingStrategy; | |||
| @ApiModelProperty(value = "重采样划分训练集和验证集,训练集的比率,0到1之间") | |||
| private Float trainSize; | |||
| @ApiModelProperty(value = "拆分数据前是否进行shuffle") | |||
| private Boolean shuffle; | |||
| @ApiModelProperty(value = "交叉验证的折数,当resamplingStrategy为crossValid时,此项必填,为整数") | |||
| private Integer folds; | |||
| @ApiModelProperty(value = "数据集csv文件中哪几列是预测目标列,逗号分隔") | |||
| private String targetColumns; | |||
| @ApiModelProperty(value = "自定义指标名称") | |||
| private String metricName; | |||
| @ApiModelProperty(value = "模型优化目标指标及权重,json格式。分类的指标包含:accuracy\n" + | |||
| "balanced_accuracy\n" + | |||
| "roc_auc\n" + | |||
| "average_precision\n" + | |||
| "log_loss\n" + | |||
| "precision_macro\n" + | |||
| "precision_micro\n" + | |||
| "precision_samples\n" + | |||
| "precision_weighted\n" + | |||
| "recall_macro\n" + | |||
| "recall_micro\n" + | |||
| "recall_samples\n" + | |||
| "recall_weighted\n" + | |||
| "f1_macro\n" + | |||
| "f1_micro\n" + | |||
| "f1_samples\n" + | |||
| "f1_weighted\n" + | |||
| "回归的指标包含:mean_absolute_error\n" + | |||
| "mean_squared_error\n" + | |||
| "root_mean_squared_error\n" + | |||
| "mean_squared_log_error\n" + | |||
| "median_absolute_error\n" + | |||
| "r2") | |||
| private String metrics; | |||
| @ApiModelProperty(value = "指标优化方向,是越大越好还是越小越好") | |||
| private Boolean greaterIsBetter; | |||
| @ApiModelProperty(value = "模型计算并打印指标") | |||
| private String scoringFunctions; | |||
| private Map<String,Object> dataset; | |||
| } | |||
| @@ -34,9 +34,9 @@ | |||
| <if test="autoMl.mlDescription != null and autoMl.mlDescription !=''"> | |||
| ml_description = #{autoMl.mlDescription}, | |||
| </if> | |||
| <!-- <if test="autoMl.runState != null and autoMl.runState !=''">--> | |||
| <!-- run_state = #{autoMl.runState},--> | |||
| <!-- </if>--> | |||
| <if test="autoMl.statusList != null and autoMl.statusList !=''"> | |||
| status_list = #{autoMl.statusList}, | |||
| </if> | |||
| <!-- <if test="autoMl.progress != null">--> | |||
| <!-- progress = #{autoMl.progress},--> | |||
| <!-- </if>--> | |||
| @@ -3,9 +3,9 @@ | |||
| <mapper namespace="com.ruoyi.platform.mapper.AutoMlInsDao"> | |||
| <insert id="insert" keyProperty="id" useGeneratedKeys="true"> | |||
| insert into auto_ml_ins(auto_ml_id, model_path, img_path, node_status, node_result, param, source) | |||
| insert into auto_ml_ins(auto_ml_id, model_path, img_path, node_status, node_result, param, source, argo_ins_name, argo_ins_ns) | |||
| values (#{autoMlIns.autoMlId}, #{autoMlIns.modelPath}, #{autoMlIns.imgPath}, #{autoMlIns.nodeStatus}, | |||
| #{autoMlIns.nodeResult}, #{autoMlIns.param}, #{autoMlIns.source}) | |||
| #{autoMlIns.nodeResult}, #{autoMlIns.param}, #{autoMlIns.source}, #{autoMlIns.argoInsName}, #{autoMlIns.argoInsNs}) | |||
| </insert> | |||
| <update id="update"> | |||
| @@ -29,7 +29,11 @@ | |||
| <if test="autoMlIns.state != null"> | |||
| state = #{autoMlIns.state}, | |||
| </if> | |||
| <if test="autoMlIns.updateTime != null"> | |||
| update_time = #{autoMlIns.updateTime}, | |||
| </if> | |||
| </set> | |||
| where id = #{autoMlIns.id} | |||
| </update> | |||
| <select id="count" resultType="java.lang.Long"> | |||
| @@ -61,4 +65,20 @@ | |||
| state = 1 and id = #{id} | |||
| </where> | |||
| </select> | |||
| <select id="queryByAutoMlInsIsNotTerminated" resultType="com.ruoyi.platform.domain.AutoMlIns"> | |||
| select * | |||
| from auto_ml_ins | |||
| where (status NOT IN ('Terminated', 'Succeeded', 'Failed') | |||
| OR status IS NULL) | |||
| and state = 1 | |||
| </select> | |||
| <select id="getByAutoMlId" resultType="com.ruoyi.platform.domain.AutoMlIns"> | |||
| select * | |||
| from auto_ml_ins | |||
| where auto_ml_id = #{autoMlId} | |||
| and state = 1 | |||
| order by update_time DESC limit 5 | |||
| </select> | |||
| </mapper> | |||