Browse Source

修改实验对比分页查询

dev-czh
chenzhihang 1 year ago
parent
commit
3002e7b0c5
4 changed files with 109 additions and 18 deletions
  1. +8
    -8
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/aim/AimController.java
  2. +10
    -6
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/ExperimentInstanceStatusTask.java
  3. +2
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AimService.java
  4. +89
    -4
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java

+ 8
- 8
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/aim/AimController.java View File

@@ -24,19 +24,19 @@ public class AimController extends BaseController {
@GetMapping("/getExpTrainInfos/{experiment_id}")
@ApiOperation("获取当前实验的模型训练指标信息")
@ApiResponse
public GenericsAjaxResult<List<InsMetricInfoVo>> getExpTrainInfos(@RequestParam(value = "offset", required = false) String offset,
@RequestParam(value = "limit") int limit,
@PathVariable("experiment_id") Integer experimentId) throws Exception {
return genericsSuccess(aimService.getExpTrainInfos(experimentId, offset, limit));
public GenericsAjaxResult<List<InsMetricInfoVo>> getExpTrainInfos(@RequestParam(value = "page") int page,
@RequestParam(value = "size") int size,
@PathVariable("experiment_id") Integer experimentId) {
return genericsSuccess(aimService.getExpInfos(true, experimentId, page, size));
}

@GetMapping("/getExpEvaluateInfos/{experiment_id}")
@ApiOperation("获取当前实验的模型推理指标信息")
@ApiResponse
public GenericsAjaxResult<List<InsMetricInfoVo>> getExpEvaluateInfos(@RequestParam(value = "offset", required = false) String offset,
@RequestParam(value = "limit") int limit,
@PathVariable("experiment_id") Integer experimentId) throws Exception {
return genericsSuccess(aimService.getExpEvaluateInfos(experimentId, offset, limit));
public GenericsAjaxResult<List<InsMetricInfoVo>> getExpEvaluateInfos(@RequestParam(value = "page") int page,
@RequestParam(value = "size") int size,
@PathVariable("experiment_id") Integer experimentId) {
return genericsSuccess(aimService.getExpInfos(false, experimentId, page, size));
}

@PostMapping("/getExpMetrics")


+ 10
- 6
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/ExperimentInstanceStatusTask.java View File

@@ -53,11 +53,12 @@ public class ExperimentInstanceStatusTask {
List<Map<String, Object>> evaluateMetricRecords = (List<Map<String, Object>>) metricRecord.get("evaluate");

HashMap<String, Object> metricValue = new HashMap<>();
HashMap<String, Object> trainMetricValue = new HashMap<>();
HashMap<String, Object> evaluateMetricValue = new HashMap<>();
HashMap<String, Object> trainMetricValues = new HashMap<>();
HashMap<String, Object> evaluateMetricValues = new HashMap<>();

if (trainMetricRecords != null && !trainMetricRecords.isEmpty()) {
for (Map<String, Object> trainMetricRecord : trainMetricRecords) {
HashMap<String, Object> trainMetricValue = new HashMap<>();
String taskId = (String) trainMetricRecord.get("task_id");
if (taskId.startsWith("model-train")) {
String runId = (String) trainMetricRecord.get("run_id");
@@ -65,14 +66,16 @@ public class ExperimentInstanceStatusTask {
for (InsMetricInfoVo expTrainInfo : expTrainInfos) {
Map metrics = expTrainInfo.getMetrics();
trainMetricValue.putAll(metrics);
trainMetricValue.put("run_hash",expTrainInfo.getRunId());
trainMetricValue.put("run_hash", expTrainInfo.getRunId());
}
}
trainMetricValues.put(taskId, trainMetricValue);
}
}

if (evaluateMetricRecords != null && !evaluateMetricRecords.isEmpty()) {
for (Map<String, Object> evaluateMetricRecord : evaluateMetricRecords) {
HashMap<String, Object> evaluateMetricValue = new HashMap<>();
String taskId = (String) evaluateMetricRecord.get("task_id");
if (taskId.startsWith("model-evaluate")) {
String runId = (String) evaluateMetricRecord.get("run_id");
@@ -80,13 +83,14 @@ public class ExperimentInstanceStatusTask {
for (InsMetricInfoVo expTrainInfo : expTrainInfos) {
Map metrics = expTrainInfo.getMetrics();
evaluateMetricValue.putAll(metrics);
evaluateMetricValue.put("run_hash",expTrainInfo.getRunId());
evaluateMetricValue.put("run_hash", expTrainInfo.getRunId());
}
}
evaluateMetricValues.put(taskId, evaluateMetricValue);
}
}
metricValue.put("train", trainMetricValue);
metricValue.put("evaluate", evaluateMetricValue);
metricValue.put("train", trainMetricValues);
metricValue.put("evaluate", evaluateMetricValues);
experimentIns.setMetricValue(JsonUtils.mapToJson(metricValue));
}
experimentIns.setUpdateTime(new Date());


+ 2
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AimService.java View File

@@ -12,6 +12,8 @@ public interface AimService {

List<InsMetricInfoVo> getExpEvaluateInfos(Integer experimentId, String offset, int limit) throws Exception;

List<InsMetricInfoVo> getExpInfos(boolean isTrain, Integer experimentId, int page, int size);

List<InsMetricInfoVo> getExpInfos1(boolean isTrain, Integer experimentId, String runId) throws Exception;

String getExpMetrics(List<String> runIds) throws Exception;


+ 89
- 4
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java View File

@@ -1,7 +1,10 @@
package com.ruoyi.platform.service.impl;

import com.alibaba.fastjson2.JSON;
import com.alibaba.fastjson2.JSONArray;
import com.alibaba.fastjson2.JSONObject;
import com.ruoyi.platform.domain.ExperimentIns;
import com.ruoyi.platform.mapper.ExperimentInsDao;
import com.ruoyi.platform.service.AimService;
import com.ruoyi.platform.service.ExperimentInsService;
import com.ruoyi.platform.utils.AIM64EncoderUtil;
@@ -24,6 +27,8 @@ import java.util.stream.Collectors;
public class AimServiceImpl implements AimService {
@Resource
private ExperimentInsService experimentInsService;
@Resource
private ExperimentInsDao experimentInsDao;

@Value("${aim.url}")
private String aimUrl;
@@ -146,10 +151,56 @@ public class AimServiceImpl implements AimService {
return aimRunInfoList;
}

// private List<InsMetricInfoVo> getExpInfos(Integer experimentId, int page, int size){
// PageRequest pageRequest = PageRequest.of(page,size);
//
// }
public List<InsMetricInfoVo> getExpInfos(boolean isTrain, Integer experimentId, int page, int size) {
PageRequest pageRequest = PageRequest.of(page, size);
ExperimentIns query = new ExperimentIns();
query.setExperimentId(experimentId);
List<ExperimentIns> experimentInsList = experimentInsDao.queryAllByLimit(query, pageRequest);
if (experimentInsList == null || experimentInsList.size() == 0) {
return new ArrayList<>();
}
List<InsMetricInfoVo> aimRunInfoList = new ArrayList<>();
for (ExperimentIns experimentIns : experimentInsList) {
InsMetricInfoVo aimRunInfo = new InsMetricInfoVo();
aimRunInfo.setExperimentInsId(experimentIns.getId());
aimRunInfo.setStartTime(experimentIns.getCreateTime());
aimRunInfo.setStatus(experimentIns.getStatus());

//解析参数
JSONArray params = JSON.parseArray(experimentIns.getGlobalParam());
HashMap<String, Object> paramsMap = new HashMap<>();
List<String> paramsNames = new ArrayList<>();

for (int i = 0; i < params.size(); i++) {
JSONObject jsonObject = params.getJSONObject(i);
String paramName = jsonObject.getString("param_name");
String paramValue = jsonObject.getString("param_value");
paramsMap.put(paramName, paramValue);
paramsNames.add(paramName);
}
aimRunInfo.setParams(paramsMap);
aimRunInfo.setParamsNames(paramsNames);

//解析数据集
Map<String, Object> metricRecord = JacksonUtil.parseJSONStr2Map(experimentIns.getMetricRecord());
if (isTrain) {
aimRunInfo.setDataset(getDataset("train", metricRecord));
} else {
aimRunInfo.setDataset(getDataset("evaluate", metricRecord));
}

//解析指标
Map<String, Object> metricValue = JacksonUtil.parseJSONStr2Map(experimentIns.getMetricValue());
if (isTrain) {
setMetricValue("train",metricValue,aimRunInfo);
} else {
setMetricValue("evaluate",metricValue,aimRunInfo);
}
aimRunInfoList.add(aimRunInfo);
}
return aimRunInfoList;
}


@Override
public List<InsMetricInfoVo> getExpInfos1(boolean isTrain, Integer experimentId, String runId) throws Exception {
@@ -240,6 +291,40 @@ public class AimServiceImpl implements AimService {
return aimRunInfoList;
}

private List<String> getDataset(String isTrain, Map<String, Object> metricRecord) {
List<String> datasetList = new ArrayList<>();
List<Map<String, Object>> trainMetricRecords = (List<Map<String, Object>>) metricRecord.get(isTrain);
for (Map<String, Object> trainMetricRecord : trainMetricRecords) {
String taskId = (String) trainMetricRecord.get("task_id");
if (taskId.startsWith("model-" + isTrain)) {
List<Map<String, Object>> datasets = (List<Map<String, Object>>) trainMetricRecord.get("datasets");
for (Map<String, Object> dataset : datasets) {
String datasetName = dataset.get("dataset_name") + ":" + dataset.get("dataset_version");
datasetList.add(datasetName);
}
}
}
return datasetList;
}

private void setMetricValue(String isTrain, Map<String, Object> metricValue, InsMetricInfoVo aimRunInfo){
Map<String, Object> metricValues = (Map<String, Object>) metricValue.get(isTrain);

HashMap<String, Object> metrics = new HashMap<>();
List<String> metricsNames = new ArrayList<>();
for (String key : metricValues.keySet()) {
Map<String, Object> valueMap = (Map<String, Object>) metricValues.get(key);
aimRunInfo.setRunId((String) valueMap.get("run_hash"));

metrics.putAll(valueMap.entrySet().stream().filter(entry -> !entry.getKey().equals("run_hash")).collect(Collectors.toMap(
Map.Entry::getKey,
Map.Entry::getValue
)));
metricsNames.addAll(valueMap.keySet().stream().filter(entry -> !entry.equals("run_hash")).collect(Collectors.toList()));
}
aimRunInfo.setMetrics(metrics);
aimRunInfo.setMetricsNames(metricsNames);
}

private List<String> getTrainDateSet(List<Map<String, Object>> records, String aimrunId) {
List<String> datasetList = new ArrayList<>();


Loading…
Cancel
Save