Browse Source

新增指标对比

pull/95/head
fanshuai 1 year ago
parent
commit
0e6c7ce5e8
7 changed files with 162 additions and 85 deletions
  1. +4
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/mapper/ModelDependencyDao.java
  2. +32
    -38
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/ExperimentInstanceStatusTask.java
  3. +4
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/ModelDependencyService.java
  4. +92
    -40
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java
  5. +11
    -4
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelDependencyServiceImpl.java
  6. +3
    -3
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/InsMetricInfoVo.java
  7. +16
    -0
      ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/ModelDependencyDaoMapper.xml

+ 4
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/mapper/ModelDependencyDao.java View File

@@ -84,5 +84,9 @@ public interface ModelDependencyDao {
List<ModelDependency> queryByModelDependency(@Param("modelDependency") ModelDependency modelDependency);

List<ModelDependency> queryChildrenByVersionId(@Param("model_id")String modelId, @Param("version")String version);

List<ModelDependency> queryByIns(@Param("expInsId")Integer expInsId);

ModelDependency queryByInsAndTrainTaskId(@Param("expInsId")Integer expInsId,@Param("taskId") String taskId);
}


+ 32
- 38
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/ExperimentInstanceStatusTask.java View File

@@ -7,20 +7,15 @@ import com.ruoyi.platform.mapper.ExperimentDao;
import com.ruoyi.platform.mapper.ExperimentInsDao;
import com.ruoyi.platform.mapper.ModelDependencyDao;
import com.ruoyi.platform.service.ExperimentInsService;
import com.ruoyi.platform.service.ModelDependencyService;
import com.ruoyi.platform.utils.JacksonUtil;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.domain.Page;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;

import javax.annotation.Resource;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.*;

@Component()
public class ExperimentInstanceStatusTask {
@@ -34,7 +29,7 @@ public class ExperimentInstanceStatusTask {
private ModelDependencyDao modelDependencyDao;
private List<Integer> experimentIds = new ArrayList<>();

@Scheduled(cron = "0/14 * * * * ?") // 每30S执行一次
@Scheduled(cron = "0/30 * * * * ?") // 每30S执行一次
public void executeExperimentInsStatus() throws IOException {
// 首先查到所有非终止态的实验实例
List<ExperimentIns> experimentInsList = experimentInsService.queryByExperimentIsNotTerminated();
@@ -46,95 +41,94 @@ public class ExperimentInstanceStatusTask {
String oldStatus = experimentIns.getStatus();
try {
experimentIns = experimentInsService.queryStatusFromArgo(experimentIns);
}catch (Exception e){
} catch (Exception e) {
experimentIns.setStatus("Failed");
}
// if (!StringUtils.equals(oldStatus,experimentIns.getStatus())){
experimentIns.setUpdateTime(new Date());
// 线程安全的添加操作
synchronized (experimentIds) {
experimentIds.add(experimentIns.getExperimentId());
}
updateList.add(experimentIns);

// }
// experimentInsDao.update(experimentIns);
experimentIns.setUpdateTime(new Date());
// 线程安全的添加操作
synchronized (experimentIds) {
experimentIds.add(experimentIns.getExperimentId());
}
updateList.add(experimentIns);
}

}
if (updateList.size() > 0){
if (updateList.size() > 0) {
experimentInsDao.insertOrUpdateBatch(updateList);

//遍历模型关系表,找到
List<ModelDependency> modelDependencyList = new ArrayList<ModelDependency>();
for (ExperimentIns experimentIns : updateList){
for (ExperimentIns experimentIns : updateList) {
ModelDependency modelDependencyquery = new ModelDependency();
modelDependencyquery.setExpInsId(experimentIns.getId());
modelDependencyquery.setState(2);

List<ModelDependency> modelDependencyListquery = modelDependencyDao.queryByModelDependency(modelDependencyquery);
if (modelDependencyListquery==null||modelDependencyListquery.size()==0){
if (modelDependencyListquery == null || modelDependencyListquery.size() == 0) {
continue;
}
ModelDependency modelDependency = modelDependencyListquery.get(0);
//查看状态,
if (StringUtils.equals("Failed",experimentIns.getStatus())){
if (StringUtils.equals("Failed", experimentIns.getStatus())) {
//取出节点状态
String trainTask = modelDependency.getTrainTask();
Map<String, Object> trainMap = JacksonUtil.parseJSONStr2Map(trainTask);
String task_id = (String) trainMap.get("task_id");
if (StringUtils.isEmpty(task_id)){
if (StringUtils.isEmpty(task_id)) {
continue;
}
String nodesStatus = experimentIns.getNodesStatus();
Map<String, Object> nodeMaps = JacksonUtil.parseJSONStr2Map(nodesStatus);
Map<String, Object> nodeMap = JacksonUtil.parseJSONStr2Map(JacksonUtil.toJSONString(nodeMaps.get(task_id)));

if (nodeMap==null){
if (nodeMap == null) {
continue;
}
if (!StringUtils.equals("Succeeded",(String)nodeMap.get("phase"))){
if (!StringUtils.equals("Succeeded", (String) nodeMap.get("phase"))) {
modelDependency.setState(0);
modelDependencyList.add(modelDependency);
}
}
}
if (modelDependencyList.size()>0) {
if (modelDependencyList.size() > 0) {
modelDependencyDao.insertOrUpdateBatch(modelDependencyList);
}
}

}
@Scheduled(cron = "0/17 * * * * ?") // / 每30S执行一次

@Scheduled(cron = "0/30 * * * * ?") // / 每30S执行一次
public void executeExperimentStatus() throws IOException {
if (experimentIds.size()==0){
if (experimentIds.size() == 0) {
return;
}
// 存储需要更新的实验对象列表
List<Experiment> updateExperiments = new ArrayList<>();
for (Integer experimentId : experimentIds){
for (Integer experimentId : experimentIds) {
// 获取当前实验的所有实例列表
List<ExperimentIns> insList = experimentInsService.getByExperimentId(experimentId);
List<String> statusList = new ArrayList<String>();
// 更新实验状态列表
for (int i=0;i<insList.size();i++){
for (int i = 0; i < insList.size(); i++) {
statusList.add(insList.get(i).getStatus());
}
String subStatus = statusList.toString().substring(1, statusList.toString().length() - 1);
Experiment experiment = experimentDao.queryById(experimentId);
// 如果实验状态列表发生变化,则更新实验对象,并加入到需要更新的列表中
if (!StringUtils.equals(subStatus,experiment.getStatusList())){
if (!StringUtils.equals(subStatus, experiment.getStatusList())) {
experiment.setStatusList(subStatus);
updateExperiments.add(experiment);
}
}
if (!updateExperiments.isEmpty()) {
experimentDao.insertOrUpdateBatch(updateExperiments);
for (int index = 0; index < updateExperiments.size(); index++) {
// 线程安全的删除操作
synchronized (experimentIds) {
experimentIds.remove(index);
// 使用Iterator进行安全的删除操作
Iterator<Integer> iterator = experimentIds.iterator();
while (iterator.hasNext()) {
Integer experimentId = iterator.next();
for (Experiment experiment : updateExperiments) {
if (experiment.getId().equals(experimentId)) {
iterator.remove();
}
}
}
}


+ 4
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/ModelDependencyService.java View File

@@ -62,4 +62,8 @@ public interface ModelDependencyService {
List<ModelDependency> queryByModelDependency(ModelDependency modelDependency) throws IOException;

ModelDependcyTreeVo getModelDependencyTree(ModelDependency modelDependency) throws Exception;

List<ModelDependency> queryByIns(Integer expInsId);

ModelDependency queryByInsAndTrainTaskId(Integer expInsId, String taskId);
}

+ 92
- 40
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java View File

@@ -1,18 +1,17 @@
package com.ruoyi.platform.service.impl;

import com.alibaba.druid.util.StringUtils;
import com.ruoyi.platform.domain.ExperimentIns;
import com.ruoyi.platform.domain.ModelDependency;
import com.ruoyi.platform.service.AimService;
import com.ruoyi.platform.service.ExperimentInsService;
import com.ruoyi.platform.service.ExperimentService;
import com.ruoyi.platform.service.ModelDependencyService;
import com.ruoyi.platform.utils.AIM64EncoderUtil;
import com.ruoyi.platform.utils.HttpUtils;
import com.ruoyi.platform.utils.JacksonUtil;
import com.ruoyi.platform.utils.JsonUtils;
import com.ruoyi.platform.vo.InsMetricInfoVo;
import org.apache.dubbo.container.Main;
import org.json.JSONObject;
import org.json.JSONTokener;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;

import javax.annotation.Resource;
@@ -24,58 +23,66 @@ import java.util.stream.Collectors;
public class AimServiceImpl implements AimService {
@Resource
private ExperimentInsService experimentInsService;
@Resource
private ModelDependencyService modelDependencyService;

@Value("${aim.url}")
private String aimUrl;
@Value("${aim.proxyUrl}")
private String aimProxyUrl;

@Override
public List<InsMetricInfoVo> getExpTrainInfos(Integer experimentId) throws Exception {
String experimentName = "experiment-train-0"+experimentId;
return getAimRunInfos("",experimentId);
return getAimRunInfos(true,experimentId);
}

@Override
public List<InsMetricInfoVo> getExpEvaluateInfos(Integer experimentId) throws Exception {
String experimentName = "experiment-evaluate-0"+experimentId;
return getAimRunInfos("",experimentId);
return getAimRunInfos(false,experimentId);
}

@Override
public String getExpMetrics(List<String> runIds) throws Exception {
String decode = AIM64EncoderUtil.decode(runIds);
return "http://172.20.32.21:7123/api/runs/search/run?query="+decode;
return aimUrl+"/api/runs/search/run?query="+decode;
}

private List<InsMetricInfoVo> getAimRunInfos(String experimentName,Integer experimentId) throws Exception {
String encodedUrlString = URLEncoder.encode("run.experiment==\"experiment-0000\"", "UTF-8");
String url = "http://172.20.32.181:30123/api/runs/search/run?query="+encodedUrlString;
private List<InsMetricInfoVo> getAimRunInfos(boolean isTrain,Integer experimentId) throws Exception {
String experimentName = "experiment-"+experimentId+"-train";
if (!isTrain){
experimentName = "experiment-"+experimentId+"-evaluate";
}
String encodedUrlString = URLEncoder.encode("run.experiment==\""+experimentName+"\"", "UTF-8");
String url = aimProxyUrl+"/api/runs/search/run?query="+encodedUrlString;
String s = HttpUtils.sendGetRequest(url);
System.out.println(s);
List<Map<String, Object>> response = JacksonUtil.parseJSONStr2MapList(s);
// TODO: parse aim response to InsMetricInfoVo list
if (response == null || response.size() == 0){
return new ArrayList<>();
}
//查询实例数据
List<ExperimentIns> byExperimentId = experimentInsService.getByExperimentId(experimentId);

// if (byExperimentId == null || byExperimentId.size() == 0){
// return new ArrayList<>();
// }
if (byExperimentId == null || byExperimentId.size() == 0){
return new ArrayList<>();
}
List<InsMetricInfoVo> aimRunInfoList = new ArrayList<>();
for (Map<String, Object> run : response) {
InsMetricInfoVo aimRunInfo = new InsMetricInfoVo();
String runHash = (String) run.get("run_hash");

aimRunInfo.setRunId(runHash);

Map params= (Map) run.get("params");
Map<String, Object> paramMap = JsonUtils.flattenJson("", params);
aimRunInfo.setParams(paramMap);
Map<String, Object> tracesMap= (Map<String, Object>) run.get("params");
String aimrunId = (String) paramMap.get("id");
Map<String, Object> tracesMap= (Map<String, Object>) run.get("traces");
List<Map<String, Object>> metricList = (List<Map<String, Object>>) tracesMap.get("metric");
//过滤name为__system__开头的对象
aimRunInfo.setMetrics(new HashMap<>());
if (metricList != null && metricList.size() > 0){
List<Map<String, Object>> metricRelList = metricList.stream()
.filter(map -> !StringUtils.equals("__system__", (String) map.get("name")))
.filter(map -> !StringUtils.startsWith((String) map.get("name"),"__system__" ))
.collect(Collectors.toList());
if (metricRelList!= null && metricRelList.size() > 0){
Map<String, Object> relMetricMap = new HashMap<>();
@@ -85,39 +92,84 @@ public class AimServiceImpl implements AimService {
aimRunInfo.setMetrics(relMetricMap);
}
}



//找到ins

for (ExperimentIns ins : byExperimentId) {
String metricRecord = ins.getMetricRecord();
if (metricRecord.contains(runHash)){
String metricRecordString = ins.getMetricRecord();
if (StringUtils.isEmpty(metricRecordString)){
continue;
}
if (metricRecordString.contains(aimrunId)){
aimRunInfo.setExperimentInsId(ins.getId());
aimRunInfo.setStatus(ins.getStatus());
aimRunInfo.setStartTime(ins.getStartTime());
aimRunInfo.setStartTime(ins.getCreateTime());
Map<String, Object> metricRecordMap = JacksonUtil.parseJSONStr2Map(metricRecordString);

//metricRecord 格式为{"train":[{"task_id":"model-train-35303690","run_id":"5560d78f54314672b60304c8d6ba03b8","experiment_name":"experiment-30-train"}],"evaluate":[{"task_id":"model-train-35303690","run_id":"5560d78f54314672b60304c8d6ba03b8","experiment_name":"experiment-30-train"}]}
//遍历metricRecord,找到当前task_id对应的ModelDependency

if (isTrain){
List<Map<String, Object>> trainList = (List<Map<String, Object>>) metricRecordMap.get("train");
List<String> trainDateSet = getTrainDateSet(trainList, ins.getId(), isTrain);
aimRunInfo.setDataset(trainDateSet);
}else {
List<Map<String, Object>> trainList = (List<Map<String, Object>>) metricRecordMap.get("evaluate");
List<String> trainDateSet = getTrainDateSet(trainList, ins.getId(), isTrain);
aimRunInfo.setDataset(trainDateSet);
}

}
}
aimRunInfoList.add(aimRunInfo);
}
//判断哪个最长

Optional<InsMetricInfoVo> maxMetricsVo = aimRunInfoList.stream()
.max((vo1, vo2) -> Integer.compare(vo1.getMetrics().size(), vo2.getMetrics().size()));
// 获取所有 metrics 的 key 的并集
Set<String> metricsKeys = (Set<String>) aimRunInfoList.stream()
.map(InsMetricInfoVo::getMetrics)
.flatMap(metrics -> metrics.keySet().stream())
.collect(Collectors.toSet());
// 将并集赋值给每个 InsMetricInfoVo 的 metricsNames 属性
aimRunInfoList.forEach(vo -> vo.setMetricsNames(new ArrayList<>(metricsKeys)));

// 获取所有 params 的 key 的并集
Set<String> paramKeys = (Set<String>) aimRunInfoList.stream()
.map(InsMetricInfoVo::getParams)
.flatMap(params -> params.keySet().stream())
.collect(Collectors.toSet());
// 将并集赋值给每个 InsMetricInfoVo 的 paramsNames 属性
aimRunInfoList.forEach(vo -> vo.setParamsNames(new ArrayList<>(paramKeys)));

return aimRunInfoList;
}

// 如果找到了,设置 metricsFlag 为 true
if (maxMetricsVo.isPresent()) {
maxMetricsVo.get().setMetricsFlag(true);
}
Optional<InsMetricInfoVo> maxParamsVo = aimRunInfoList.stream()
.max((vo1, vo2) -> Integer.compare(vo1.getParams().size(), vo2.getParams().size()));

// 如果找到了,设置 metricsFlag 为 true
if (maxParamsVo.isPresent()) {
maxParamsVo.get().setMetricsFlag(true);
private List<String> getTrainDateSet(List<Map<String, Object>> trainList,Integer expInsId,boolean isTrain){
if (trainList == null || trainList.size() == 0){
return new ArrayList<>();
}
List<String> datasetList = new ArrayList<>();
for (Map<String, Object> trainMap : trainList) {
String task_id = (String) trainMap.get("task_id");
//modelDependency取到数据集文件
ModelDependency modelDependency = modelDependencyService.queryByInsAndTrainTaskId(expInsId, task_id);
//把数据集文件组装成String后放进List
String datasetString = "";
if (isTrain){
datasetString = modelDependency.getTrainDataset();
}else {
datasetString = modelDependency.getTestDataset();
}
List<Map<String, Object>> datasetListMap = JacksonUtil.parseJSONStr2MapList(datasetString);

return aimRunInfoList;
if (datasetListMap != null && datasetListMap.size() > 0){
for (Map<String, Object> datasetMap : datasetListMap) {
//[{"dataset_id":20,"dataset_version":"v0.1.0","dataset_name":"手写体识别模型依赖测试训练数据集"}]
String datasetName = (String) datasetMap.get("dataset_name")+":"+(String) datasetMap.get("dataset_version");
datasetList.add(datasetName);
}
}
}
return datasetList;
}

}

+ 11
- 4
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelDependencyServiceImpl.java View File

@@ -17,10 +17,7 @@ import org.springframework.data.domain.PageRequest;

import javax.annotation.Resource;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.*;
import java.util.stream.Collectors;

/**
@@ -97,6 +94,16 @@ public class ModelDependencyServiceImpl implements ModelDependencyService {
return modelDependcyTreeVo;
}

@Override
public List<ModelDependency> queryByIns(Integer expInsId) {
return modelDependencyDao.queryByIns(expInsId);
}

@Override
public ModelDependency queryByInsAndTrainTaskId(Integer expInsId, String taskId) {
return modelDependencyDao.queryByInsAndTrainTaskId(expInsId,taskId);
}

/**
* 递归父模型
* @param modelDependcyTreeVo


+ 3
- 3
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/InsMetricInfoVo.java View File

@@ -18,7 +18,7 @@ public class InsMetricInfoVo implements Serializable {
@ApiModelProperty(value = "实例运行状态")
private String status;
@ApiModelProperty(value = "使用数据集")
private List<Map<String, Object>> dataset;
private List<String> dataset;
@ApiModelProperty(value = "实例ID")
private Integer experimentInsId;
@ApiModelProperty(value = "训练指标")
@@ -27,6 +27,6 @@ public class InsMetricInfoVo implements Serializable {
private Map params;
@ApiModelProperty(value = "训练记录ID")
private String runId;
private Boolean metricsFlag = false;
private Boolean paramsFlag = false;
private List<String> metricsNames;
private List<String> paramsNames;
}

+ 16
- 0
ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/ModelDependencyDaoMapper.xml View File

@@ -22,6 +22,22 @@
<result property="state" column="state" jdbcType="INTEGER"/>
</resultMap>

<select id="queryByIns" resultMap="ModelDependencyMap">
select
id,current_model_id,exp_ins_id,parent_models,ref_item,train_task,train_dataset,train_params,train_image,test_dataset,project_dependency,version,create_by,create_time,update_by,update_time,state
from model_dependency
<where>
exp_ins_id = #{expInsId} and state = 1
</where>
</select>
<select id="queryByInsAndTrainTaskId" resultMap="ModelDependencyMap">
select
id,current_model_id,exp_ins_id,parent_models,ref_item,train_task,train_dataset,train_params,train_image,test_dataset,project_dependency,version,create_by,create_time,update_by,update_time,state
from model_dependency
<where>
exp_ins_id = #{expInsId} and train_task like concat('%', #{taskId}, '%') limit 1
</where>
</select>
<select id="queryChildrenByVersionId" resultMap="ModelDependencyMap">
select
id,current_model_id,exp_ins_id,parent_models,ref_item,train_task,train_dataset,train_params,train_image,test_dataset,project_dependency,version,create_by,create_time,update_by,update_time,state


Loading…
Cancel
Save