diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/mapper/RayInsDao.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/mapper/RayInsDao.java index ceeb581d..86801d04 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/mapper/RayInsDao.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/mapper/RayInsDao.java @@ -18,4 +18,6 @@ public interface RayInsDao { int insert(@Param("rayIns") RayIns rayIns); int update(@Param("rayIns") RayIns rayIns); + + List queryByRayInsIsNotTerminated(); } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/AutoMlInsStatusTask.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/AutoMlInsStatusTask.java index 5c94284f..fa07acf6 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/AutoMlInsStatusTask.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/AutoMlInsStatusTask.java @@ -1,5 +1,6 @@ package com.ruoyi.platform.scheduling; +import com.ruoyi.platform.constant.Constant; import com.ruoyi.platform.domain.AutoMl; import com.ruoyi.platform.domain.AutoMlIns; import com.ruoyi.platform.mapper.AutoMlDao; @@ -41,7 +42,7 @@ public class AutoMlInsStatusTask { try { autoMlIns = autoMlInsService.queryStatusFromArgo(autoMlIns); } catch (Exception e) { - autoMlIns.setStatus("Failed"); + autoMlIns.setStatus(Constant.Failed); } // 线程安全的添加操作 synchronized (autoMlIds) { diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/RayInsStatusTask.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/RayInsStatusTask.java new file mode 100644 index 00000000..23d03164 --- /dev/null +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/RayInsStatusTask.java @@ -0,0 +1,95 @@ +package com.ruoyi.platform.scheduling; + +import com.ruoyi.platform.constant.Constant; +import com.ruoyi.platform.domain.Ray; +import com.ruoyi.platform.domain.RayIns; +import com.ruoyi.platform.mapper.RayDao; +import com.ruoyi.platform.mapper.RayInsDao; +import com.ruoyi.platform.service.RayInsService; +import org.apache.commons.lang3.StringUtils; +import org.springframework.scheduling.annotation.Scheduled; +import org.springframework.stereotype.Component; + +import javax.annotation.Resource; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +@Component() +public class RayInsStatusTask { + + @Resource + private RayInsService rayInsService; + @Resource + private RayInsDao rayInsDao; + @Resource + private RayDao rayDao; + + private List rayIds = new ArrayList<>(); + + @Scheduled(cron = "0/30 * * * * ?") // 每30S执行一次 + public void executeRayInsStatus() { + List rayInsList = rayInsService.queryByRayInsIsNotTerminated(); + + // 去argo查询状态 + List updateList = new ArrayList<>(); + if (rayInsList != null && rayInsList.size() > 0) { + for (RayIns rayIns : rayInsList) { + //当原本状态为null或非终止态时才调用argo接口 + try { + rayIns = rayInsService.queryStatusFromArgo(rayIns); + } catch (Exception e) { + rayIns.setStatus(Constant.Failed); + } + // 线程安全的添加操作 + synchronized (rayIds) { + rayIds.add(rayIns.getRayId()); + } + updateList.add(rayIns); + } + if (updateList.size() > 0) { + for (RayIns rayIns : updateList) { + rayInsDao.update(rayIns); + } + } + } + } + + @Scheduled(cron = "0/30 * * * * ?") // / 每30S执行一次 + public void executeRayStatus() { + if (rayIds.size() == 0) { + return; + } + // 存储需要更新的实验对象列表 + List updateRays = new ArrayList<>(); + for (Long rayId : rayIds) { + // 获取当前实验的所有实例列表 + List insList = rayInsDao.getByRayId(rayId); + List statusList = new ArrayList<>(); + // 更新实验状态列表 + for (int i = 0; i < insList.size(); i++) { + statusList.add(insList.get(i).getStatus()); + } + String subStatus = statusList.toString().substring(1, statusList.toString().length() - 1); + Ray ray = rayDao.getRayById(rayId); + if (!StringUtils.equals(ray.getStatusList(), subStatus)) { + ray.setStatusList(subStatus); + updateRays.add(ray); + rayDao.edit(ray); + } + } + + if (!updateRays.isEmpty()) { + // 使用Iterator进行安全的删除操作 + Iterator iterator = rayIds.iterator(); + while (iterator.hasNext()) { + Long rayId = iterator.next(); + for (Ray ray : updateRays) { + if (ray.getId().equals(rayId)) { + iterator.remove(); + } + } + } + } + } +} diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/RayInsService.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/RayInsService.java index 867dd0a9..63559ffe 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/RayInsService.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/RayInsService.java @@ -18,4 +18,8 @@ public interface RayInsService { RayIns getDetailById(Long id) throws IOException; void updateRayStatus(Long rayId); + + RayIns queryStatusFromArgo(RayIns ins); + + List queryByRayInsIsNotTerminated(); } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/RayInsServiceImpl.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/RayInsServiceImpl.java index 976a43b5..f881faa4 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/RayInsServiceImpl.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/RayInsServiceImpl.java @@ -6,8 +6,11 @@ import com.ruoyi.platform.domain.RayIns; import com.ruoyi.platform.mapper.RayDao; import com.ruoyi.platform.mapper.RayInsDao; import com.ruoyi.platform.service.RayInsService; +import com.ruoyi.platform.utils.DateUtils; +import com.ruoyi.platform.utils.HttpUtils; import com.ruoyi.platform.utils.JsonUtils; import org.apache.commons.lang3.StringUtils; +import org.springframework.beans.factory.annotation.Value; import org.springframework.data.domain.Page; import org.springframework.data.domain.PageImpl; import org.springframework.data.domain.PageRequest; @@ -23,6 +26,13 @@ import java.util.stream.Collectors; @Service("rayInsService") public class RayInsServiceImpl implements RayInsService { + @Value("${argo.url}") + private String argoUrl; + @Value("${argo.workflowStatus}") + private String argoWorkflowStatus; + @Value("${argo.workflowTermination}") + private String argoWorkflowTermination; + @Resource private RayInsDao rayInsDao; @@ -49,7 +59,7 @@ public class RayInsServiceImpl implements RayInsService { return "实验实例不存在"; } if (StringUtils.isEmpty(rayIns.getStatus())) { - //todo queryStatusFromArgo + rayIns = queryStatusFromArgo(rayIns); } if (StringUtils.equals(rayIns.getStatus(), Constant.Running)) { return "实验实例正在运行,不可删除"; @@ -89,7 +99,7 @@ public class RayInsServiceImpl implements RayInsService { // 获取当前状态,如果为空,则从Argo查询 if (StringUtils.isEmpty(currentStatus)) { - // todo queryStatusFromArgo + currentStatus = queryStatusFromArgo(rayIns).getStatus(); } // 只有状态是"Running"时才能终止实例 @@ -97,20 +107,64 @@ public class RayInsServiceImpl implements RayInsService { throw new Exception("终止错误,只有运行状态的实例才能终止"); // 如果不是"Running"状态,则不执行终止操作 } - //todo terminateFromArgo + // 创建请求数据map + Map requestData = new HashMap<>(); + requestData.put("namespace", namespace); + requestData.put("name", name); + // 创建发送数据map,将请求数据作为"data"键的值 + Map res = new HashMap<>(); + res.put("data", requestData); + + try { + // 发送POST请求到Argo工作流状态查询接口,并将请求数据转换为JSON + String req = HttpUtils.sendPost(argoUrl + argoWorkflowTermination, null, JsonUtils.mapToJson(res)); + // 检查响应是否为空或无内容 + if (StringUtils.isEmpty(req)) { + throw new RuntimeException("终止响应内容为空。"); + + } + // 将响应的JSON字符串转换为Map对象 + Map runResMap = JsonUtils.jsonToMap(req); + // 从响应Map中直接获取"errCode"的值 + Integer errCode = (Integer) runResMap.get("errCode"); + if (errCode != null && errCode == 0) { + //更新autoMlIns,确保状态更新被保存到数据库 + RayIns ins = queryStatusFromArgo(rayIns); + String nodeStatus = ins.getNodeStatus(); + Map nodeMap = JsonUtils.jsonToMap(nodeStatus); - rayIns.setStatus(Constant.Terminated); - rayIns.setFinishTime(new Date()); - this.rayInsDao.update(rayIns); - updateRayStatus(rayIns.getRayId()); - return true; + // 遍历 map + for (Map.Entry entry : nodeMap.entrySet()) { + // 获取每个 Map 中的值并强制转换为 Map + Map innerMap = (Map) entry.getValue(); + // 检查 phase 的值 + if (innerMap.containsKey("phase")) { + String phaseValue = (String) innerMap.get("phase"); + // 如果值不等于 Succeeded,则赋值为 Failed + if (!StringUtils.equals("Succeeded", phaseValue)) { + innerMap.put("phase", "Failed"); + } + } + } + ins.setNodeStatus(JsonUtils.mapToJson(nodeMap)); + ins.setStatus(Constant.Terminated); + ins.setUpdateTime(new Date()); + rayInsDao.update(ins); + updateRayStatus(rayIns.getRayId()); + return true; + } else { + return false; + } + } catch (Exception e) { + throw new RuntimeException("终止实例错误: " + e.getMessage(), e); + } } @Override public RayIns getDetailById(Long id) throws IOException { RayIns rayIns = rayInsDao.queryById(id); if (Constant.Running.equals(rayIns.getStatus()) || Constant.Pending.equals(rayIns.getStatus())) { - //todo queryStatusFromArgo + rayIns = queryStatusFromArgo(rayIns); } rayIns.setTrialList(getTrialList(rayIns.getResultPath())); return rayIns; @@ -132,6 +186,80 @@ public class RayInsServiceImpl implements RayInsService { } } + @Override + public RayIns queryStatusFromArgo(RayIns ins) { + String namespace = ins.getArgoInsNs(); + String name = ins.getArgoInsName(); + + // 创建请求数据map + Map requestData = new HashMap<>(); + requestData.put("namespace", namespace); + requestData.put("name", name); + + // 创建发送数据map,将请求数据作为"data"键的值 + Map res = new HashMap<>(); + res.put("data", requestData); + + try { + // 发送POST请求到Argo工作流状态查询接口,并将请求数据转换为JSON + String req = HttpUtils.sendPost(argoUrl + argoWorkflowStatus, null, JsonUtils.mapToJson(res)); + // 检查响应是否为空或无内容 + if (req == null || StringUtils.isEmpty(req)) { + throw new RuntimeException("工作流状态响应为空。"); + } + // 将响应的JSON字符串转换为Map对象 + Map runResMap = JsonUtils.jsonToMap(req); + // 从响应Map中获取"data"部分 + Map data = (Map) runResMap.get("data"); + if (data == null || data.isEmpty()) { + throw new RuntimeException("工作流数据为空."); + } + // 从"data"中获取"status"部分,并返回"phase"的值 + Map status = (Map) data.get("status"); + if (status == null || status.isEmpty()) { + throw new RuntimeException("工作流状态为空。"); + } + + //解析流水线结束时间 + String finishedAtString = (String) status.get("finishedAt"); + if (finishedAtString != null && !finishedAtString.isEmpty()) { + Date finishTime = DateUtils.convertUTCtoShanghaiDate(finishedAtString); + ins.setFinishTime(finishTime); + } + + // 解析nodes字段,提取节点状态并转换为JSON字符串 + Map nodes = (Map) status.get("nodes"); + Map modifiedNodes = new LinkedHashMap<>(); + if (nodes != null) { + for (Map.Entry nodeEntry : nodes.entrySet()) { + Map nodeDetails = (Map) nodeEntry.getValue(); + String templateName = (String) nodeDetails.get("displayName"); + modifiedNodes.put(templateName, nodeDetails); + } + } + + + String nodeStatusJson = JsonUtils.mapToJson(modifiedNodes); + ins.setNodeStatus(nodeStatusJson); + + //终止态为终止不改 + if (!StringUtils.equals(ins.getStatus(), Constant.Terminated)) { + ins.setStatus(StringUtils.isNotEmpty((String) status.get("phase")) ? (String) status.get("phase") : Constant.Pending); + } + if (StringUtils.equals(ins.getStatus(), "Error")) { + ins.setStatus(Constant.Failed); + } + return ins; + } catch (Exception e) { + throw new RuntimeException("查询状态失败: " + e.getMessage(), e); + } + } + + @Override + public List queryByRayInsIsNotTerminated() { + return rayInsDao.queryByRayInsIsNotTerminated(); + } + public ArrayList> getTrialList(String directoryPath) throws IOException { // 获取指定路径下的所有文件 Path dirPath = Paths.get(directoryPath); diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/RayServiceImpl.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/RayServiceImpl.java index 32c19964..1d812659 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/RayServiceImpl.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/RayServiceImpl.java @@ -4,14 +4,22 @@ import com.google.gson.Gson; import com.google.gson.reflect.TypeToken; import com.ruoyi.common.security.utils.SecurityUtils; import com.ruoyi.platform.constant.Constant; +import com.ruoyi.platform.domain.AutoMlIns; import com.ruoyi.platform.domain.Ray; +import com.ruoyi.platform.domain.RayIns; import com.ruoyi.platform.mapper.RayDao; +import com.ruoyi.platform.mapper.RayInsDao; +import com.ruoyi.platform.service.RayInsService; import com.ruoyi.platform.service.RayService; +import com.ruoyi.platform.utils.HttpUtils; import com.ruoyi.platform.utils.JacksonUtil; import com.ruoyi.platform.utils.JsonUtils; +import com.ruoyi.platform.vo.RayParamVo; import com.ruoyi.platform.vo.RayVo; +import org.apache.commons.collections4.MapUtils; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.BeanUtils; +import org.springframework.beans.factory.annotation.Value; import org.springframework.data.domain.Page; import org.springframework.data.domain.PageImpl; import org.springframework.data.domain.PageRequest; @@ -20,13 +28,30 @@ import org.springframework.stereotype.Service; import javax.annotation.Resource; import java.io.IOException; import java.lang.reflect.Type; +import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; @Service("rayService") public class RayServiceImpl implements RayService { + + @Value("${argo.url}") + private String argoUrl; + @Value("${argo.convertRay}") + String convertRay; + @Value("${argo.workflowRun}") + private String argoWorkflowRun; + + @Value("${minio.endpoint}") + private String minioEndpoint; + @Resource private RayDao rayDao; + @Resource + private RayInsDao rayInsDao; + @Resource + private RayInsService rayInsService; @Override public Page queryByPage(String name, PageRequest pageRequest) { @@ -126,7 +151,61 @@ public class RayServiceImpl implements RayService { if (ray == null) { throw new Exception("自动超参数寻优配置不存在"); } - //todo argo - return null; + + RayParamVo rayParamVo = new RayParamVo(); + BeanUtils.copyProperties(ray, rayParamVo); + rayParamVo.setCodeConfig(JsonUtils.jsonToMap(ray.getCodeConfig())); + rayParamVo.setDataset(JsonUtils.jsonToMap(ray.getDataset())); + rayParamVo.setModel(JsonUtils.jsonToMap(ray.getModel())); + rayParamVo.setImage(JsonUtils.jsonToMap(ray.getImage())); + String param = JsonUtils.objectToJson(rayParamVo); + + // 调argo转换接口 + try { + String convertRes = HttpUtils.sendPost(argoUrl + convertRay, param); + if (convertRes == null || StringUtils.isEmpty(convertRes)) { + throw new RuntimeException("转换流水线失败"); + } + Map converMap = JsonUtils.jsonToMap(convertRes); + // 组装运行接口json + Map output = (Map) converMap.get("output"); + Map runReqMap = new HashMap<>(); + runReqMap.put("data", converMap.get("data")); + // 调argo运行接口 + String runRes = HttpUtils.sendPost(argoUrl + argoWorkflowRun, JsonUtils.mapToJson(runReqMap)); + if (runRes == null || StringUtils.isEmpty(runRes)) { + throw new RuntimeException("Failed to run workflow."); + } + Map runResMap = JsonUtils.jsonToMap(runRes); + Map data = (Map) runResMap.get("data"); + //判断data为空 + if (data == null || MapUtils.isEmpty(data)) { + throw new RuntimeException("Failed to run workflow."); + } + Map metadata = (Map) data.get("metadata"); + + // 插入记录到实验实例表 + RayIns rayIns = new RayIns(); + rayIns.setRayId(ray.getId()); + rayIns.setArgoInsNs((String) metadata.get("namespace")); + rayIns.setArgoInsName((String) metadata.get("name")); + rayIns.setParam(param); + rayIns.setStatus(Constant.Pending); + //替换argoInsName + String outputString = JsonUtils.mapToJson(output); + rayIns.setNodeResult(outputString.replace("{{workflow.name}}", (String) metadata.get("name"))); + + Map param_output = (Map) output.get("param_output"); + List output1 = (ArrayList) param_output.values().toArray()[0]; + Map output2 = (Map) output1.get(0); + String outputPath = minioEndpoint + "/" + output2.get("path").replace("{{workflow.name}}", (String) metadata.get("name")) + "/"; + + rayIns.setResultPath(outputPath); + rayInsDao.insert(rayIns); + rayInsService.updateRayStatus(id); + } catch (Exception e) { + throw new RuntimeException(e); + } + return "执行成功"; } } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/RayParamVo.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/RayParamVo.java new file mode 100644 index 00000000..4b6e1c5a --- /dev/null +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/RayParamVo.java @@ -0,0 +1,50 @@ +package com.ruoyi.platform.vo; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import io.swagger.annotations.ApiModel; +import lombok.Data; + +import java.util.Map; + +@Data +@JsonNaming(PropertyNamingStrategy.SnakeCaseStrategy.class) +@JsonInclude(JsonInclude.Include.NON_NULL) +@ApiModel(description = "超参数寻优参数") +public class RayParamVo { + + private Map codeConfig; + + private Map dataset; + + private Map image; + + private Map model; + + private String mainPy; + + private String name; + + private Integer numSamples; + + private String parameters; + + private String pointsToEvaluate; + + private String storagePath; + + private String searchAlg; + + private String scheduler; + + private String metric; + + private String mode; + + private Integer maxT; + + private Integer minSamplesRequired; + + private String resource; +} diff --git a/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/RayInsDaoMapper.xml b/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/RayInsDaoMapper.xml index ca4b7231..2b9c12ac 100644 --- a/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/RayInsDaoMapper.xml +++ b/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/RayInsDaoMapper.xml @@ -66,4 +66,12 @@ and state = 1 order by update_time DESC limit 5 + + \ No newline at end of file