From 3bc6e7fb511f5b0eed9575b28d99e277a0b5ea72 Mon Sep 17 00:00:00 2001 From: chenzhihang <709011834@qq.com> Date: Tue, 24 Sep 2024 11:00:55 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=9E=E9=AA=8C=E6=A8=A1=E5=9E=8B=E5=AF=BC?= =?UTF-8?q?=E5=87=BA=E5=85=83=E6=95=B0=E6=8D=AE=E8=AE=B0=E5=BD=95=E6=8C=87?= =?UTF-8?q?=E6=A0=87=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../controller/aim/AimController.java | 12 ++-- .../ruoyi/platform/service/AimService.java | 4 +- .../platform/service/impl/AimServiceImpl.java | 63 ++++++++++--------- .../service/impl/ModelsServiceImpl.java | 19 +++++- 4 files changed, 55 insertions(+), 43 deletions(-) diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/aim/AimController.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/aim/AimController.java index f6b0b863..bc800545 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/aim/AimController.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/aim/AimController.java @@ -3,9 +3,7 @@ package com.ruoyi.platform.controller.aim; import com.ruoyi.common.core.web.controller.BaseController; import com.ruoyi.common.core.web.domain.GenericsAjaxResult; import com.ruoyi.platform.service.AimService; -import com.ruoyi.platform.vo.FrameLogPathVo; import com.ruoyi.platform.vo.InsMetricInfoVo; -import com.ruoyi.platform.vo.PodStatusVo; import io.swagger.annotations.Api; import io.swagger.annotations.ApiOperation; import io.swagger.v3.oas.annotations.responses.ApiResponse; @@ -26,21 +24,21 @@ public class AimController extends BaseController { @GetMapping("/getExpTrainInfos/{experiment_id}") @ApiOperation("获取当前实验的模型训练指标信息") @ApiResponse - public GenericsAjaxResult> getExpTrainInfos(@PathVariable("experiment_id") Integer experimentId) throws Exception { - return genericsSuccess(aimService.getExpTrainInfos(experimentId)); + public GenericsAjaxResult> getExpTrainInfos(@PathVariable("experiment_id") Integer experimentId, @RequestParam("run_id") String runId) throws Exception { + return genericsSuccess(aimService.getExpTrainInfos(experimentId, runId)); } @GetMapping("/getExpEvaluateInfos/{experiment_id}") @ApiOperation("获取当前实验的模型推理指标信息") @ApiResponse - public GenericsAjaxResult> getExpEvaluateInfos(@PathVariable("experiment_id") Integer experimentId) throws Exception { - return genericsSuccess(aimService.getExpEvaluateInfos(experimentId)); + public GenericsAjaxResult> getExpEvaluateInfos(@PathVariable("experiment_id") Integer experimentId, @RequestParam("run_id") String runId) throws Exception { + return genericsSuccess(aimService.getExpEvaluateInfos(experimentId, runId)); } @PostMapping("/getExpMetrics") @ApiOperation("获取当前实验的指标对比地址") @ApiResponse public GenericsAjaxResult getExpMetrics(@RequestBody List runIds) throws Exception { - return genericsSuccess(aimService.getExpMetrics(runIds)); + return genericsSuccess(aimService.getExpMetrics(runIds)); } } \ No newline at end of file diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AimService.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AimService.java index c83a42af..c7f91d8f 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AimService.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AimService.java @@ -6,9 +6,9 @@ import java.util.List; public interface AimService { - List getExpTrainInfos(Integer experimentId) throws Exception; + List getExpTrainInfos(Integer experimentId, String runId) throws Exception; - List getExpEvaluateInfos(Integer experimentId) throws Exception; + List getExpEvaluateInfos(Integer experimentId, String runId) throws Exception; String getExpMetrics(List runIds) throws Exception; } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java index 943a1dc4..3db73a01 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java @@ -30,38 +30,39 @@ public class AimServiceImpl implements AimService { private NewHttpUtils httpUtils; @Override - public List getExpTrainInfos(Integer experimentId) throws Exception { - return getAimRunInfos(true,experimentId); + public List getExpTrainInfos(Integer experimentId, String runId) throws Exception { + return getAimRunInfos(true, experimentId, runId); } @Override - public List getExpEvaluateInfos(Integer experimentId) throws Exception { - return getAimRunInfos(false,experimentId); + public List getExpEvaluateInfos(Integer experimentId, String runId) throws Exception { + return getAimRunInfos(false, experimentId, runId); } @Override public String getExpMetrics(List runIds) throws Exception { String decode = AIM64EncoderUtil.decode(runIds); - return aimUrl+"/metrics?select="+decode; + return aimUrl + "/metrics?select=" + decode; } - private List getAimRunInfos(boolean isTrain,Integer experimentId) throws Exception { - String experimentName = "experiment-"+experimentId+"-train"; - if (!isTrain){ - experimentName = "experiment-"+experimentId+"-evaluate"; - } - String encodedUrlString = URLEncoder.encode("run.experiment==\""+experimentName+"\"", "UTF-8"); - String url = aimProxyUrl+"/api/runs/search/run?query="+encodedUrlString; - String s = httpUtils.sendGet(url,null); + private List getAimRunInfos(boolean isTrain, Integer experimentId, String runId) throws Exception { +// String experimentName = "experiment-" + experimentId + "-train"; +// if (!isTrain) { +// experimentName = "experiment-" + experimentId + "-evaluate"; +// } +// String encodedUrlString = URLEncoder.encode("run.experiment==\"" + experimentName + "\"", "UTF-8"); + String encodedUrlString = URLEncoder.encode("run.id==\"" + runId + "\"", "UTF-8"); + String url = aimProxyUrl + "/api/runs/search/run?query=" + encodedUrlString; + String s = httpUtils.sendGet(url, null); List> response = JacksonUtil.parseJSONStr2MapList(s); - System.out.println("response: "+JacksonUtil.toJSONString(response)); - if (response == null || response.size() == 0){ + System.out.println("response: " + JacksonUtil.toJSONString(response)); + if (response == null || response.size() == 0) { return new ArrayList<>(); } //查询实例数据 List byExperimentId = experimentInsService.queryByExperimentId(experimentId); - if (byExperimentId == null || byExperimentId.size() == 0){ + if (byExperimentId == null || byExperimentId.size() == 0) { return new ArrayList<>(); } List aimRunInfoList = new ArrayList<>(); @@ -71,22 +72,22 @@ public class AimServiceImpl implements AimService { aimRunInfo.setRunId(runHash); - Map params= (Map) run.get("params"); + Map params = (Map) run.get("params"); Map paramMap = JsonUtils.flattenJson("", params); aimRunInfo.setParams(paramMap); String aimrunId = (String) paramMap.get("id"); - Map tracesMap= (Map) run.get("traces"); + Map tracesMap = (Map) run.get("traces"); List> metricList = (List>) tracesMap.get("metric"); //过滤name为__system__开头的对象 aimRunInfo.setMetrics(new HashMap<>()); - if (metricList != null && metricList.size() > 0){ + if (metricList != null && metricList.size() > 0) { List> metricRelList = metricList.stream() - .filter(map -> !StringUtils.startsWith((String) map.get("name"),"__system__" )) + .filter(map -> !StringUtils.startsWith((String) map.get("name"), "__system__")) .collect(Collectors.toList()); - if (metricRelList!= null && metricRelList.size() > 0){ + if (metricRelList != null && metricRelList.size() > 0) { Map relMetricMap = new HashMap<>(); for (Map metricMap : metricRelList) { - relMetricMap.put((String)metricMap.get("name"), metricMap.get("last_value")); + relMetricMap.put((String) metricMap.get("name"), metricMap.get("last_value")); } aimRunInfo.setMetrics(relMetricMap); } @@ -94,19 +95,19 @@ public class AimServiceImpl implements AimService { //找到ins for (ExperimentIns ins : byExperimentId) { String metricRecordString = ins.getMetricRecord(); - if (StringUtils.isEmpty(metricRecordString)){ + if (StringUtils.isEmpty(metricRecordString)) { continue; } - if (metricRecordString.contains(aimrunId)){ + if (metricRecordString.contains(aimrunId)) { aimRunInfo.setExperimentInsId(ins.getId()); aimRunInfo.setStatus(ins.getStatus()); aimRunInfo.setStartTime(ins.getCreateTime()); Map metricRecordMap = JacksonUtil.parseJSONStr2Map(metricRecordString); - if (isTrain){ + if (isTrain) { List> records = (List>) metricRecordMap.get("train"); List datasetList = getTrainDateSet(records, aimrunId); aimRunInfo.setDataset(datasetList); - }else { + } else { List> records = (List>) metricRecordMap.get("evaluate"); List datasetList = getTrainDateSet(records, aimrunId); aimRunInfo.setDataset(datasetList); @@ -138,16 +139,16 @@ public class AimServiceImpl implements AimService { } - private List getTrainDateSet(List> records, String aimrunId){ + private List getTrainDateSet(List> records, String aimrunId) { List datasetList = new ArrayList<>(); for (Map record : records) { - if (StringUtils.equals(aimrunId, (String)record.get("run_id"))) { + if (StringUtils.equals(aimrunId, (String) record.get("run_id"))) { List> datasets = (List>) record.get("datasets"); - if (datasets == null || datasets.size() == 0){ + if (datasets == null || datasets.size() == 0) { continue; } - for (Map dataset : datasets){ - String datasetName = (String) dataset.get("dataset_name")+":"+(String) dataset.get("dataset_version"); + for (Map dataset : datasets) { + String datasetName = (String) dataset.get("dataset_name") + ":" + (String) dataset.get("dataset_version"); datasetList.add(datasetName); } break; diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelsServiceImpl.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelsServiceImpl.java index b06b95d9..64aa42a8 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelsServiceImpl.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelsServiceImpl.java @@ -1,6 +1,8 @@ package com.ruoyi.platform.service.impl; import com.alibaba.fastjson2.JSON; +import com.alibaba.fastjson2.JSONArray; +import com.alibaba.fastjson2.JSONObject; import com.ruoyi.common.core.utils.DateUtils; import com.ruoyi.common.security.utils.SecurityUtils; import com.ruoyi.platform.annotations.CheckDuplicate; @@ -1134,9 +1136,20 @@ public class ModelsServiceImpl implements ModelsService { } void getMetrics(ModelMetaVo modelMetaVo) throws Exception { - List expTrainInfos = aimsService.getExpTrainInfos(modelMetaVo.getTrainTask().getExperimentId()); - for (InsMetricInfoVo expTrainInfo : expTrainInfos) { - System.out.println(expTrainInfo.getMetrics()); + HashMap metrics = modelMetaVo.getMetrics(); + JSONArray trainMetrics = (JSONArray) metrics.get("train"); + for (int i = 0; i < trainMetrics.size(); i++) { + JSONObject jsonObject = trainMetrics.getJSONObject(i); + String runId = jsonObject.getString("run_id"); + List expTrainInfos = aimsService.getExpTrainInfos(modelMetaVo.getTrainTask().getExperimentId(), runId); + System.out.print(expTrainInfos); + } + + JSONArray testMetrics = (JSONArray) metrics.get("test"); + for (int i = 0; i < testMetrics.size(); i++) { + JSONObject jsonObject = testMetrics.getJSONObject(i); + String runId = jsonObject.getString("run_id"); + List expTestInfos = aimsService.getExpEvaluateInfos(modelMetaVo.getTrainTask().getExperimentId(), runId); } } }