Browse Source

修改实验对比分页查询

dev-czh
cp3hnu 1 year ago
parent
commit
c0b74fafd2
4 changed files with 59 additions and 49 deletions
  1. +2
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/mapper/ExperimentInsDao.java
  2. +39
    -36
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/ExperimentInstanceStatusTask.java
  3. +2
    -13
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java
  4. +16
    -0
      ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/ExperimentInsDaoMapper.xml

+ 2
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/mapper/ExperimentInsDao.java View File

@@ -41,6 +41,8 @@ public interface ExperimentInsDao {

long countTorE(@Param("experimentId") Integer experimentId, @Param("isTrain") Boolean isTrain);

List<ExperimentIns> queryTorE(@Param("experimentId") Integer experimentId, @Param("isTrain") Boolean isTrain, @Param("pageable") Pageable pageable);

/*
统计实验实例总数



+ 39
- 36
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/ExperimentInstanceStatusTask.java View File

@@ -1,6 +1,5 @@
package com.ruoyi.platform.scheduling;

import com.ruoyi.platform.constant.Constant;
import com.ruoyi.platform.domain.Experiment;
import com.ruoyi.platform.domain.ExperimentIns;
import com.ruoyi.platform.mapper.ExperimentDao;
@@ -48,50 +47,54 @@ public class ExperimentInstanceStatusTask {
}
//运行成功的实验实例记录指标数值
// if (Constant.Succeeded.equals(experimentIns.getStatus())) {
Map<String, Object> metricRecord = JacksonUtil.parseJSONStr2Map(experimentIns.getMetricRecord());
List<Map<String, Object>> trainMetricRecords = (List<Map<String, Object>>) metricRecord.get("train");
List<Map<String, Object>> evaluateMetricRecords = (List<Map<String, Object>>) metricRecord.get("evaluate");
Map<String, Object> metricRecord = JacksonUtil.parseJSONStr2Map(experimentIns.getMetricRecord());
List<Map<String, Object>> trainMetricRecords = (List<Map<String, Object>>) metricRecord.get("train");
List<Map<String, Object>> evaluateMetricRecords = (List<Map<String, Object>>) metricRecord.get("evaluate");

HashMap<String, Object> metricValue = new HashMap<>();
HashMap<String, Object> trainMetricValues = new HashMap<>();
HashMap<String, Object> evaluateMetricValues = new HashMap<>();
HashMap<String, Object> metricValue = new HashMap<>();
HashMap<String, Object> trainMetricValues = new HashMap<>();
HashMap<String, Object> evaluateMetricValues = new HashMap<>();

if (trainMetricRecords != null && !trainMetricRecords.isEmpty()) {
for (Map<String, Object> trainMetricRecord : trainMetricRecords) {
HashMap<String, Object> trainMetricValue = new HashMap<>();
String taskId = (String) trainMetricRecord.get("task_id");
if (taskId.startsWith("model-train")) {
String runId = (String) trainMetricRecord.get("run_id");
List<InsMetricInfoVo> expTrainInfos = aimService.getExpInfos1(true, experimentIns.getExperimentId(), runId);
for (InsMetricInfoVo expTrainInfo : expTrainInfos) {
Map metrics = expTrainInfo.getMetrics();
trainMetricValue.putAll(metrics);
trainMetricValue.put("run_hash", expTrainInfo.getRunId());
}
if (trainMetricRecords != null && !trainMetricRecords.isEmpty()) {
for (Map<String, Object> trainMetricRecord : trainMetricRecords) {
HashMap<String, Object> trainMetricValue = new HashMap<>();
String taskId = (String) trainMetricRecord.get("task_id");
if (taskId.startsWith("model-train")) {
String runId = (String) trainMetricRecord.get("run_id");
List<InsMetricInfoVo> expTrainInfos = aimService.getExpInfos1(true, experimentIns.getExperimentId(), runId);
for (InsMetricInfoVo expTrainInfo : expTrainInfos) {
Map metrics = expTrainInfo.getMetrics();
trainMetricValue.putAll(metrics);
trainMetricValue.put("run_hash", expTrainInfo.getRunId());
}
if (trainMetricValue.size() > 0) {
trainMetricValues.put(taskId, trainMetricValue);
}
trainMetricValues.put(taskId, trainMetricValue);
}
}
}

if (evaluateMetricRecords != null && !evaluateMetricRecords.isEmpty()) {
for (Map<String, Object> evaluateMetricRecord : evaluateMetricRecords) {
HashMap<String, Object> evaluateMetricValue = new HashMap<>();
String taskId = (String) evaluateMetricRecord.get("task_id");
if (taskId.startsWith("model-evaluate")) {
String runId = (String) evaluateMetricRecord.get("run_id");
List<InsMetricInfoVo> expTrainInfos = aimService.getExpInfos1(false, experimentIns.getExperimentId(), runId);
for (InsMetricInfoVo expTrainInfo : expTrainInfos) {
Map metrics = expTrainInfo.getMetrics();
evaluateMetricValue.putAll(metrics);
evaluateMetricValue.put("run_hash", expTrainInfo.getRunId());
}
if (evaluateMetricRecords != null && !evaluateMetricRecords.isEmpty()) {
for (Map<String, Object> evaluateMetricRecord : evaluateMetricRecords) {
HashMap<String, Object> evaluateMetricValue = new HashMap<>();
String taskId = (String) evaluateMetricRecord.get("task_id");
if (taskId.startsWith("model-evaluate")) {
String runId = (String) evaluateMetricRecord.get("run_id");
List<InsMetricInfoVo> expTrainInfos = aimService.getExpInfos1(false, experimentIns.getExperimentId(), runId);
for (InsMetricInfoVo expTrainInfo : expTrainInfos) {
Map metrics = expTrainInfo.getMetrics();
evaluateMetricValue.putAll(metrics);
evaluateMetricValue.put("run_hash", expTrainInfo.getRunId());
}
if (evaluateMetricValue.size() > 0) {
evaluateMetricValues.put(taskId, evaluateMetricValue);
}
evaluateMetricValues.put(taskId, evaluateMetricValue);
}
}
metricValue.put("train", trainMetricValues);
metricValue.put("evaluate", evaluateMetricValues);
experimentIns.setMetricValue(JsonUtils.mapToJson(metricValue));
}
metricValue.put("train", trainMetricValues);
metricValue.put("evaluate", evaluateMetricValues);
experimentIns.setMetricValue(JsonUtils.mapToJson(metricValue));
// }
experimentIns.setUpdateTime(new Date());
// 线程安全的添加操作


+ 2
- 13
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java View File

@@ -155,22 +155,11 @@ public class AimServiceImpl implements AimService {

public Page<InsMetricInfoVo> getExpInfos(boolean isTrain, Integer experimentId, int page, int size) {
PageRequest pageRequest = PageRequest.of(page, size);
ExperimentIns query = new ExperimentIns();
query.setExperimentId(experimentId);
long count = experimentInsDao.countTorE(experimentId, isTrain);
List<ExperimentIns> experimentInsList = experimentInsDao.queryAllByLimit(query, pageRequest);

List<ExperimentIns> collect = experimentInsList.stream().filter(ins -> {
Map<String, Object> metricRecord = JacksonUtil.parseJSONStr2Map(ins.getMetricRecord());
if (isTrain) {
return metricRecord.get("train") != null;
} else {
return metricRecord.get("evaluate") != null;
}
}).collect(Collectors.toList());
List<ExperimentIns> experimentInsList = experimentInsDao.queryTorE(experimentId, isTrain, pageRequest);

List<InsMetricInfoVo> aimRunInfoList = new ArrayList<>();
for (ExperimentIns experimentIns : collect) {
for (ExperimentIns experimentIns : experimentInsList) {
InsMetricInfoVo aimRunInfo = new InsMetricInfoVo();
aimRunInfo.setExperimentInsId(experimentIns.getId());
aimRunInfo.setStartTime(experimentIns.getCreateTime());


+ 16
- 0
ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/ExperimentInsDaoMapper.xml View File

@@ -263,6 +263,22 @@
</if>
</select>

<select id="queryTorE" resultType="com.ruoyi.platform.domain.ExperimentIns">
select
id, experiment_id, argo_ins_name, argo_ins_ns, status, nodes_status,nodes_result,
nodes_logs,global_param,metric_record, metric_value, start_time, finish_time, create_by, create_time, update_by,
update_time, state
from experiment_ins
where state = 1
and experiment_id = #{experimentId}
<if test="isTrain">
and not JSON_CONTAINS(metric_record, 'null', '$.train')
</if>
<if test="! isTrain">
and not JSON_CONTAINS(metric_record, 'null', '$.evaluate')
</if>
</select>

<!--新增所有列-->
<insert id="insert" keyProperty="id" useGeneratedKeys="true">
insert into experiment_ins(experiment_id,argo_ins_name,argo_ins_ns,status,nodes_status,nodes_result,nodes_logs,global_param,metric_record,metric_value,start_time,finish_time,create_by,create_time,update_by,update_time,state)


Loading…
Cancel
Save