diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/aim/AimController.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/aim/AimController.java index bc800545..c61fb789 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/aim/AimController.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/aim/AimController.java @@ -24,15 +24,15 @@ public class AimController extends BaseController { @GetMapping("/getExpTrainInfos/{experiment_id}") @ApiOperation("获取当前实验的模型训练指标信息") @ApiResponse - public GenericsAjaxResult> getExpTrainInfos(@PathVariable("experiment_id") Integer experimentId, @RequestParam("run_id") String runId) throws Exception { - return genericsSuccess(aimService.getExpTrainInfos(experimentId, runId)); + public GenericsAjaxResult> getExpTrainInfos(@PathVariable("experiment_id") Integer experimentId) throws Exception { + return genericsSuccess(aimService.getExpTrainInfos(experimentId)); } @GetMapping("/getExpEvaluateInfos/{experiment_id}") @ApiOperation("获取当前实验的模型推理指标信息") @ApiResponse - public GenericsAjaxResult> getExpEvaluateInfos(@PathVariable("experiment_id") Integer experimentId, @RequestParam("run_id") String runId) throws Exception { - return genericsSuccess(aimService.getExpEvaluateInfos(experimentId, runId)); + public GenericsAjaxResult> getExpEvaluateInfos(@PathVariable("experiment_id") Integer experimentId) throws Exception { + return genericsSuccess(aimService.getExpEvaluateInfos(experimentId)); } @PostMapping("/getExpMetrics") diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AimService.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AimService.java index c7f91d8f..9f3868f5 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AimService.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AimService.java @@ -6,9 +6,11 @@ import java.util.List; public interface AimService { - List getExpTrainInfos(Integer experimentId, String runId) throws Exception; + List getExpTrainInfos(Integer experimentId) throws Exception; - List getExpEvaluateInfos(Integer experimentId, String runId) throws Exception; + List getExpTrainInfos1(boolean isTrain, Integer experimentId, String runId) throws Exception; + + List getExpEvaluateInfos(Integer experimentId) throws Exception; String getExpMetrics(List runIds) throws Exception; } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java index 3db73a01..1b754404 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java @@ -30,13 +30,13 @@ public class AimServiceImpl implements AimService { private NewHttpUtils httpUtils; @Override - public List getExpTrainInfos(Integer experimentId, String runId) throws Exception { - return getAimRunInfos(true, experimentId, runId); + public List getExpTrainInfos(Integer experimentId) throws Exception { + return getAimRunInfos(true, experimentId); } @Override - public List getExpEvaluateInfos(Integer experimentId, String runId) throws Exception { - return getAimRunInfos(false, experimentId, runId); + public List getExpEvaluateInfos(Integer experimentId) throws Exception { + return getAimRunInfos(false, experimentId); } @Override @@ -45,13 +45,12 @@ public class AimServiceImpl implements AimService { return aimUrl + "/metrics?select=" + decode; } - private List getAimRunInfos(boolean isTrain, Integer experimentId, String runId) throws Exception { -// String experimentName = "experiment-" + experimentId + "-train"; -// if (!isTrain) { -// experimentName = "experiment-" + experimentId + "-evaluate"; -// } -// String encodedUrlString = URLEncoder.encode("run.experiment==\"" + experimentName + "\"", "UTF-8"); - String encodedUrlString = URLEncoder.encode("run.id==\"" + runId + "\"", "UTF-8"); + private List getAimRunInfos(boolean isTrain, Integer experimentId) throws Exception { + String experimentName = "experiment-" + experimentId + "-train"; + if (!isTrain) { + experimentName = "experiment-" + experimentId + "-evaluate"; + } + String encodedUrlString = URLEncoder.encode("run.experiment==\"" + experimentName + "\"", "UTF-8"); String url = aimProxyUrl + "/api/runs/search/run?query=" + encodedUrlString; String s = httpUtils.sendGet(url, null); List> response = JacksonUtil.parseJSONStr2MapList(s); @@ -139,6 +138,96 @@ public class AimServiceImpl implements AimService { } + @Override + public List getExpTrainInfos1(boolean isTrain, Integer experimentId, String runId) throws Exception { + String encodedUrlString = URLEncoder.encode("run.id==\"" + runId + "\"", "UTF-8"); + String url = aimProxyUrl + "/api/runs/search/run?query=" + encodedUrlString; + String s = httpUtils.sendGet(url, null); + List> response = JacksonUtil.parseJSONStr2MapList(s); + System.out.println("response: " + JacksonUtil.toJSONString(response)); + if (response == null || response.size() == 0) { + return new ArrayList<>(); + } + //查询实例数据 + List byExperimentId = experimentInsService.queryByExperimentId(experimentId); + + if (byExperimentId == null || byExperimentId.size() == 0) { + return new ArrayList<>(); + } + List aimRunInfoList = new ArrayList<>(); + for (Map run : response) { + InsMetricInfoVo aimRunInfo = new InsMetricInfoVo(); + String runHash = (String) run.get("run_hash"); + + aimRunInfo.setRunId(runHash); + + Map params = (Map) run.get("params"); + Map paramMap = JsonUtils.flattenJson("", params); + aimRunInfo.setParams(paramMap); + String aimrunId = (String) paramMap.get("id"); + Map tracesMap = (Map) run.get("traces"); + List> metricList = (List>) tracesMap.get("metric"); + //过滤name为__system__开头的对象 + aimRunInfo.setMetrics(new HashMap<>()); + if (metricList != null && metricList.size() > 0) { + List> metricRelList = metricList.stream() + .filter(map -> !StringUtils.startsWith((String) map.get("name"), "__system__")) + .collect(Collectors.toList()); + if (metricRelList != null && metricRelList.size() > 0) { + Map relMetricMap = new HashMap<>(); + for (Map metricMap : metricRelList) { + relMetricMap.put((String) metricMap.get("name"), metricMap.get("last_value")); + } + aimRunInfo.setMetrics(relMetricMap); + } + } + //找到ins + for (ExperimentIns ins : byExperimentId) { + String metricRecordString = ins.getMetricRecord(); + if (StringUtils.isEmpty(metricRecordString)) { + continue; + } + if (metricRecordString.contains(aimrunId)) { + aimRunInfo.setExperimentInsId(ins.getId()); + aimRunInfo.setStatus(ins.getStatus()); + aimRunInfo.setStartTime(ins.getCreateTime()); + Map metricRecordMap = JacksonUtil.parseJSONStr2Map(metricRecordString); + if (isTrain) { + List> records = (List>) metricRecordMap.get("train"); + List datasetList = getTrainDateSet(records, aimrunId); + aimRunInfo.setDataset(datasetList); + } else { + List> records = (List>) metricRecordMap.get("evaluate"); + List datasetList = getTrainDateSet(records, aimrunId); + aimRunInfo.setDataset(datasetList); + } + aimRunInfoList.add(aimRunInfo); + } + } + } + + //判断哪个最长 + + // 获取所有 metrics 的 key 的并集 + Set metricsKeys = (Set) aimRunInfoList.stream() + .map(InsMetricInfoVo::getMetrics) + .flatMap(metrics -> metrics.keySet().stream()) + .collect(Collectors.toSet()); + // 将并集赋值给每个 InsMetricInfoVo 的 metricsNames 属性 + aimRunInfoList.forEach(vo -> vo.setMetricsNames(new ArrayList<>(metricsKeys))); + + // 获取所有 params 的 key 的并集 + Set paramKeys = (Set) aimRunInfoList.stream() + .map(InsMetricInfoVo::getParams) + .flatMap(params -> params.keySet().stream()) + .collect(Collectors.toSet()); + // 将并集赋值给每个 InsMetricInfoVo 的 paramsNames 属性 + aimRunInfoList.forEach(vo -> vo.setParamsNames(new ArrayList<>(paramKeys))); + + return aimRunInfoList; + } + + private List getTrainDateSet(List> records, String aimrunId) { List datasetList = new ArrayList<>(); for (Map record : records) { diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelsServiceImpl.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelsServiceImpl.java index 7b51ee35..7230b362 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelsServiceImpl.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelsServiceImpl.java @@ -1165,7 +1165,7 @@ public class ModelsServiceImpl implements ModelsService { for (int i = 0; i < trainMetrics.size(); i++) { JSONObject jsonObject = trainMetrics.getJSONObject(i); String runId = jsonObject.getString("run_id"); - List expTrainInfos = aimsService.getExpTrainInfos(modelMetaVo.getTrainTask().getExperimentId(), runId); + List expTrainInfos = aimsService.getExpTrainInfos1(true, modelMetaVo.getTrainTask().getExperimentId(), runId); for (InsMetricInfoVo expTrainInfo : expTrainInfos) { Map metrics1 = expTrainInfo.getMetrics(); train.putAll(metrics1); @@ -1177,7 +1177,7 @@ public class ModelsServiceImpl implements ModelsService { for (int i = 0; i < testMetrics.size(); i++) { JSONObject jsonObject = testMetrics.getJSONObject(i); String runId = jsonObject.getString("run_id"); - List expTestInfos = aimsService.getExpEvaluateInfos(modelMetaVo.getTrainTask().getExperimentId(), runId); + List expTestInfos = aimsService.getExpTrainInfos1(false, modelMetaVo.getTrainTask().getExperimentId(), runId); for (InsMetricInfoVo expTestInfo : expTestInfos) { Map metrics1 = expTestInfo.getMetrics(); evaluate.putAll(metrics1);