From 0e6c7ce5e86bcc1ea4a090287efe9b511af982a3 Mon Sep 17 00:00:00 2001 From: fanshuai <1141904845@qq.com> Date: Wed, 26 Jun 2024 09:14:41 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9E=E6=8C=87=E6=A0=87=E5=AF=B9?= =?UTF-8?q?=E6=AF=94?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../platform/mapper/ModelDependencyDao.java | 4 + .../ExperimentInstanceStatusTask.java | 70 +++++----- .../service/ModelDependencyService.java | 4 + .../platform/service/impl/AimServiceImpl.java | 132 ++++++++++++------ .../impl/ModelDependencyServiceImpl.java | 15 +- .../ruoyi/platform/vo/InsMetricInfoVo.java | 6 +- .../ModelDependencyDaoMapper.xml | 16 +++ 7 files changed, 162 insertions(+), 85 deletions(-) diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/mapper/ModelDependencyDao.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/mapper/ModelDependencyDao.java index 3a999886..ba1bc40b 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/mapper/ModelDependencyDao.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/mapper/ModelDependencyDao.java @@ -84,5 +84,9 @@ public interface ModelDependencyDao { List queryByModelDependency(@Param("modelDependency") ModelDependency modelDependency); List queryChildrenByVersionId(@Param("model_id")String modelId, @Param("version")String version); + + List queryByIns(@Param("expInsId")Integer expInsId); + + ModelDependency queryByInsAndTrainTaskId(@Param("expInsId")Integer expInsId,@Param("taskId") String taskId); } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/ExperimentInstanceStatusTask.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/ExperimentInstanceStatusTask.java index 4680285e..131dca48 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/ExperimentInstanceStatusTask.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/ExperimentInstanceStatusTask.java @@ -7,20 +7,15 @@ import com.ruoyi.platform.mapper.ExperimentDao; import com.ruoyi.platform.mapper.ExperimentInsDao; import com.ruoyi.platform.mapper.ModelDependencyDao; import com.ruoyi.platform.service.ExperimentInsService; -import com.ruoyi.platform.service.ModelDependencyService; import com.ruoyi.platform.utils.JacksonUtil; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.data.domain.Page; import org.springframework.scheduling.annotation.Scheduled; import org.springframework.stereotype.Component; import javax.annotation.Resource; import java.io.IOException; -import java.util.ArrayList; -import java.util.Date; -import java.util.List; -import java.util.Map; +import java.util.*; @Component() public class ExperimentInstanceStatusTask { @@ -34,7 +29,7 @@ public class ExperimentInstanceStatusTask { private ModelDependencyDao modelDependencyDao; private List experimentIds = new ArrayList<>(); - @Scheduled(cron = "0/14 * * * * ?") // 每30S执行一次 + @Scheduled(cron = "0/30 * * * * ?") // 每30S执行一次 public void executeExperimentInsStatus() throws IOException { // 首先查到所有非终止态的实验实例 List experimentInsList = experimentInsService.queryByExperimentIsNotTerminated(); @@ -46,95 +41,94 @@ public class ExperimentInstanceStatusTask { String oldStatus = experimentIns.getStatus(); try { experimentIns = experimentInsService.queryStatusFromArgo(experimentIns); - }catch (Exception e){ + } catch (Exception e) { experimentIns.setStatus("Failed"); } -// if (!StringUtils.equals(oldStatus,experimentIns.getStatus())){ - experimentIns.setUpdateTime(new Date()); - // 线程安全的添加操作 - synchronized (experimentIds) { - experimentIds.add(experimentIns.getExperimentId()); - } - updateList.add(experimentIns); - -// } -// experimentInsDao.update(experimentIns); + experimentIns.setUpdateTime(new Date()); + // 线程安全的添加操作 + synchronized (experimentIds) { + experimentIds.add(experimentIns.getExperimentId()); + } + updateList.add(experimentIns); } - } - if (updateList.size() > 0){ + if (updateList.size() > 0) { experimentInsDao.insertOrUpdateBatch(updateList); //遍历模型关系表,找到 List modelDependencyList = new ArrayList(); - for (ExperimentIns experimentIns : updateList){ + for (ExperimentIns experimentIns : updateList) { ModelDependency modelDependencyquery = new ModelDependency(); modelDependencyquery.setExpInsId(experimentIns.getId()); modelDependencyquery.setState(2); List modelDependencyListquery = modelDependencyDao.queryByModelDependency(modelDependencyquery); - if (modelDependencyListquery==null||modelDependencyListquery.size()==0){ + if (modelDependencyListquery == null || modelDependencyListquery.size() == 0) { continue; } ModelDependency modelDependency = modelDependencyListquery.get(0); //查看状态, - if (StringUtils.equals("Failed",experimentIns.getStatus())){ + if (StringUtils.equals("Failed", experimentIns.getStatus())) { //取出节点状态 String trainTask = modelDependency.getTrainTask(); Map trainMap = JacksonUtil.parseJSONStr2Map(trainTask); String task_id = (String) trainMap.get("task_id"); - if (StringUtils.isEmpty(task_id)){ + if (StringUtils.isEmpty(task_id)) { continue; } String nodesStatus = experimentIns.getNodesStatus(); Map nodeMaps = JacksonUtil.parseJSONStr2Map(nodesStatus); Map nodeMap = JacksonUtil.parseJSONStr2Map(JacksonUtil.toJSONString(nodeMaps.get(task_id))); - if (nodeMap==null){ + if (nodeMap == null) { continue; } - if (!StringUtils.equals("Succeeded",(String)nodeMap.get("phase"))){ + if (!StringUtils.equals("Succeeded", (String) nodeMap.get("phase"))) { modelDependency.setState(0); modelDependencyList.add(modelDependency); } } } - if (modelDependencyList.size()>0) { + if (modelDependencyList.size() > 0) { modelDependencyDao.insertOrUpdateBatch(modelDependencyList); } } - } - @Scheduled(cron = "0/17 * * * * ?") // / 每30S执行一次 + + @Scheduled(cron = "0/30 * * * * ?") // / 每30S执行一次 public void executeExperimentStatus() throws IOException { - if (experimentIds.size()==0){ + if (experimentIds.size() == 0) { return; } // 存储需要更新的实验对象列表 List updateExperiments = new ArrayList<>(); - for (Integer experimentId : experimentIds){ + for (Integer experimentId : experimentIds) { // 获取当前实验的所有实例列表 List insList = experimentInsService.getByExperimentId(experimentId); List statusList = new ArrayList(); // 更新实验状态列表 - for (int i=0;i iterator = experimentIds.iterator(); + while (iterator.hasNext()) { + Integer experimentId = iterator.next(); + for (Experiment experiment : updateExperiments) { + if (experiment.getId().equals(experimentId)) { + iterator.remove(); + } } } } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/ModelDependencyService.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/ModelDependencyService.java index 5c8b9d1d..049d87d1 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/ModelDependencyService.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/ModelDependencyService.java @@ -62,4 +62,8 @@ public interface ModelDependencyService { List queryByModelDependency(ModelDependency modelDependency) throws IOException; ModelDependcyTreeVo getModelDependencyTree(ModelDependency modelDependency) throws Exception; + + List queryByIns(Integer expInsId); + + ModelDependency queryByInsAndTrainTaskId(Integer expInsId, String taskId); } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java index c26e101e..6ec8f43c 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java @@ -1,18 +1,17 @@ package com.ruoyi.platform.service.impl; -import com.alibaba.druid.util.StringUtils; import com.ruoyi.platform.domain.ExperimentIns; +import com.ruoyi.platform.domain.ModelDependency; import com.ruoyi.platform.service.AimService; import com.ruoyi.platform.service.ExperimentInsService; -import com.ruoyi.platform.service.ExperimentService; +import com.ruoyi.platform.service.ModelDependencyService; import com.ruoyi.platform.utils.AIM64EncoderUtil; import com.ruoyi.platform.utils.HttpUtils; import com.ruoyi.platform.utils.JacksonUtil; import com.ruoyi.platform.utils.JsonUtils; import com.ruoyi.platform.vo.InsMetricInfoVo; -import org.apache.dubbo.container.Main; -import org.json.JSONObject; -import org.json.JSONTokener; +import org.apache.commons.lang3.StringUtils; +import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; import javax.annotation.Resource; @@ -24,58 +23,66 @@ import java.util.stream.Collectors; public class AimServiceImpl implements AimService { @Resource private ExperimentInsService experimentInsService; + @Resource + private ModelDependencyService modelDependencyService; + + @Value("${aim.url}") + private String aimUrl; + @Value("${aim.proxyUrl}") + private String aimProxyUrl; @Override public List getExpTrainInfos(Integer experimentId) throws Exception { - String experimentName = "experiment-train-0"+experimentId; - return getAimRunInfos("",experimentId); + return getAimRunInfos(true,experimentId); } @Override public List getExpEvaluateInfos(Integer experimentId) throws Exception { - String experimentName = "experiment-evaluate-0"+experimentId; - return getAimRunInfos("",experimentId); + return getAimRunInfos(false,experimentId); } @Override public String getExpMetrics(List runIds) throws Exception { String decode = AIM64EncoderUtil.decode(runIds); - return "http://172.20.32.21:7123/api/runs/search/run?query="+decode; + return aimUrl+"/api/runs/search/run?query="+decode; } - private List getAimRunInfos(String experimentName,Integer experimentId) throws Exception { - String encodedUrlString = URLEncoder.encode("run.experiment==\"experiment-0000\"", "UTF-8"); - String url = "http://172.20.32.181:30123/api/runs/search/run?query="+encodedUrlString; + private List getAimRunInfos(boolean isTrain,Integer experimentId) throws Exception { + String experimentName = "experiment-"+experimentId+"-train"; + if (!isTrain){ + experimentName = "experiment-"+experimentId+"-evaluate"; + } + String encodedUrlString = URLEncoder.encode("run.experiment==\""+experimentName+"\"", "UTF-8"); + String url = aimProxyUrl+"/api/runs/search/run?query="+encodedUrlString; String s = HttpUtils.sendGetRequest(url); - System.out.println(s); List> response = JacksonUtil.parseJSONStr2MapList(s); - // TODO: parse aim response to InsMetricInfoVo list if (response == null || response.size() == 0){ return new ArrayList<>(); } //查询实例数据 List byExperimentId = experimentInsService.getByExperimentId(experimentId); -// if (byExperimentId == null || byExperimentId.size() == 0){ -// return new ArrayList<>(); -// } + if (byExperimentId == null || byExperimentId.size() == 0){ + return new ArrayList<>(); + } List aimRunInfoList = new ArrayList<>(); for (Map run : response) { InsMetricInfoVo aimRunInfo = new InsMetricInfoVo(); String runHash = (String) run.get("run_hash"); + aimRunInfo.setRunId(runHash); Map params= (Map) run.get("params"); Map paramMap = JsonUtils.flattenJson("", params); aimRunInfo.setParams(paramMap); - - Map tracesMap= (Map) run.get("params"); + String aimrunId = (String) paramMap.get("id"); + Map tracesMap= (Map) run.get("traces"); List> metricList = (List>) tracesMap.get("metric"); //过滤name为__system__开头的对象 aimRunInfo.setMetrics(new HashMap<>()); if (metricList != null && metricList.size() > 0){ List> metricRelList = metricList.stream() - .filter(map -> !StringUtils.equals("__system__", (String) map.get("name"))) + .filter(map -> !StringUtils.startsWith((String) map.get("name"),"__system__" )) .collect(Collectors.toList()); if (metricRelList!= null && metricRelList.size() > 0){ Map relMetricMap = new HashMap<>(); @@ -85,39 +92,84 @@ public class AimServiceImpl implements AimService { aimRunInfo.setMetrics(relMetricMap); } } - - - //找到ins - for (ExperimentIns ins : byExperimentId) { - String metricRecord = ins.getMetricRecord(); - if (metricRecord.contains(runHash)){ + String metricRecordString = ins.getMetricRecord(); + if (StringUtils.isEmpty(metricRecordString)){ + continue; + } + if (metricRecordString.contains(aimrunId)){ aimRunInfo.setExperimentInsId(ins.getId()); aimRunInfo.setStatus(ins.getStatus()); - aimRunInfo.setStartTime(ins.getStartTime()); + aimRunInfo.setStartTime(ins.getCreateTime()); + Map metricRecordMap = JacksonUtil.parseJSONStr2Map(metricRecordString); + + //metricRecord 格式为{"train":[{"task_id":"model-train-35303690","run_id":"5560d78f54314672b60304c8d6ba03b8","experiment_name":"experiment-30-train"}],"evaluate":[{"task_id":"model-train-35303690","run_id":"5560d78f54314672b60304c8d6ba03b8","experiment_name":"experiment-30-train"}]} + //遍历metricRecord,找到当前task_id对应的ModelDependency + + if (isTrain){ + List> trainList = (List>) metricRecordMap.get("train"); + List trainDateSet = getTrainDateSet(trainList, ins.getId(), isTrain); + aimRunInfo.setDataset(trainDateSet); + }else { + List> trainList = (List>) metricRecordMap.get("evaluate"); + List trainDateSet = getTrainDateSet(trainList, ins.getId(), isTrain); + aimRunInfo.setDataset(trainDateSet); + } + } } aimRunInfoList.add(aimRunInfo); } //判断哪个最长 - Optional maxMetricsVo = aimRunInfoList.stream() - .max((vo1, vo2) -> Integer.compare(vo1.getMetrics().size(), vo2.getMetrics().size())); + // 获取所有 metrics 的 key 的并集 + Set metricsKeys = (Set) aimRunInfoList.stream() + .map(InsMetricInfoVo::getMetrics) + .flatMap(metrics -> metrics.keySet().stream()) + .collect(Collectors.toSet()); + // 将并集赋值给每个 InsMetricInfoVo 的 metricsNames 属性 + aimRunInfoList.forEach(vo -> vo.setMetricsNames(new ArrayList<>(metricsKeys))); + + // 获取所有 params 的 key 的并集 + Set paramKeys = (Set) aimRunInfoList.stream() + .map(InsMetricInfoVo::getParams) + .flatMap(params -> params.keySet().stream()) + .collect(Collectors.toSet()); + // 将并集赋值给每个 InsMetricInfoVo 的 paramsNames 属性 + aimRunInfoList.forEach(vo -> vo.setParamsNames(new ArrayList<>(paramKeys))); + + return aimRunInfoList; + } - // 如果找到了,设置 metricsFlag 为 true - if (maxMetricsVo.isPresent()) { - maxMetricsVo.get().setMetricsFlag(true); - } - Optional maxParamsVo = aimRunInfoList.stream() - .max((vo1, vo2) -> Integer.compare(vo1.getParams().size(), vo2.getParams().size())); - // 如果找到了,设置 metricsFlag 为 true - if (maxParamsVo.isPresent()) { - maxParamsVo.get().setMetricsFlag(true); + private List getTrainDateSet(List> trainList,Integer expInsId,boolean isTrain){ + if (trainList == null || trainList.size() == 0){ + return new ArrayList<>(); } + List datasetList = new ArrayList<>(); + for (Map trainMap : trainList) { + String task_id = (String) trainMap.get("task_id"); + //modelDependency取到数据集文件 + ModelDependency modelDependency = modelDependencyService.queryByInsAndTrainTaskId(expInsId, task_id); + //把数据集文件组装成String后放进List + String datasetString = ""; + if (isTrain){ + datasetString = modelDependency.getTrainDataset(); + }else { + datasetString = modelDependency.getTestDataset(); + } + List> datasetListMap = JacksonUtil.parseJSONStr2MapList(datasetString); - return aimRunInfoList; + if (datasetListMap != null && datasetListMap.size() > 0){ + for (Map datasetMap : datasetListMap) { + //[{"dataset_id":20,"dataset_version":"v0.1.0","dataset_name":"手写体识别模型依赖测试训练数据集"}] + String datasetName = (String) datasetMap.get("dataset_name")+":"+(String) datasetMap.get("dataset_version"); + datasetList.add(datasetName); + } + } + } + return datasetList; } } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelDependencyServiceImpl.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelDependencyServiceImpl.java index f3c48ebb..572a66a5 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelDependencyServiceImpl.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelDependencyServiceImpl.java @@ -17,10 +17,7 @@ import org.springframework.data.domain.PageRequest; import javax.annotation.Resource; import java.io.IOException; -import java.util.ArrayList; -import java.util.Date; -import java.util.List; -import java.util.Map; +import java.util.*; import java.util.stream.Collectors; /** @@ -97,6 +94,16 @@ public class ModelDependencyServiceImpl implements ModelDependencyService { return modelDependcyTreeVo; } + @Override + public List queryByIns(Integer expInsId) { + return modelDependencyDao.queryByIns(expInsId); + } + + @Override + public ModelDependency queryByInsAndTrainTaskId(Integer expInsId, String taskId) { + return modelDependencyDao.queryByInsAndTrainTaskId(expInsId,taskId); + } + /** * 递归父模型 * @param modelDependcyTreeVo diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/InsMetricInfoVo.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/InsMetricInfoVo.java index cd3943ed..6fe8caa4 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/InsMetricInfoVo.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/InsMetricInfoVo.java @@ -18,7 +18,7 @@ public class InsMetricInfoVo implements Serializable { @ApiModelProperty(value = "实例运行状态") private String status; @ApiModelProperty(value = "使用数据集") - private List> dataset; + private List dataset; @ApiModelProperty(value = "实例ID") private Integer experimentInsId; @ApiModelProperty(value = "训练指标") @@ -27,6 +27,6 @@ public class InsMetricInfoVo implements Serializable { private Map params; @ApiModelProperty(value = "训练记录ID") private String runId; - private Boolean metricsFlag = false; - private Boolean paramsFlag = false; + private List metricsNames; + private List paramsNames; } diff --git a/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/ModelDependencyDaoMapper.xml b/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/ModelDependencyDaoMapper.xml index 2cd5dd7a..ea592ee2 100644 --- a/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/ModelDependencyDaoMapper.xml +++ b/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/ModelDependencyDaoMapper.xml @@ -22,6 +22,22 @@ + +