From 2caa5286e5afe4b99d14ede3036d8b2d20b61c59 Mon Sep 17 00:00:00 2001 From: chenzhihang <709011834@qq.com> Date: Thu, 17 Oct 2024 11:27:42 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9F=A5=E8=AF=A2=E6=A8=A1=E5=9E=8B=E7=89=88?= =?UTF-8?q?=E6=9C=AC=E6=8C=87=E6=A0=87=E5=AF=B9=E6=AF=94?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../com/ruoyi/platform/constant/Constant.java | 3 + .../model/NewModelFromGitController.java | 10 +-- .../ruoyi/platform/service/ModelsService.java | 5 +- .../service/impl/ModelsServiceImpl.java | 69 ++++++++++++++----- 4 files changed, 65 insertions(+), 22 deletions(-) diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/constant/Constant.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/constant/Constant.java index 12300bb1..b325cc5d 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/constant/Constant.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/constant/Constant.java @@ -31,4 +31,7 @@ public class Constant { public final static String Pending = "Pending"; public final static String Init = "Init"; public final static String Stopped = "Stopped"; + + public final static String Type_Train = "train"; + public final static String Type_Evaluate = "evaluate"; } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/model/NewModelFromGitController.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/model/NewModelFromGitController.java index ab4c5925..958a78bf 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/model/NewModelFromGitController.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/model/NewModelFromGitController.java @@ -1,6 +1,7 @@ package com.ruoyi.platform.controller.model; import com.ruoyi.common.core.web.domain.AjaxResult; +import com.ruoyi.platform.domain.ModelDependency1; import com.ruoyi.platform.service.ModelsService; import com.ruoyi.platform.vo.ModelsVo; import io.swagger.annotations.Api; @@ -79,15 +80,16 @@ public class NewModelFromGitController { public AjaxResult queryVersions(@RequestParam(value = "page") int page, @RequestParam(value = "size") int size, @RequestParam("identifier") String identifier, - @RequestParam("owner") String owner) throws Exception { + @RequestParam("owner") String owner, + @RequestParam("type") String type) throws Exception { PageRequest pageRequest = PageRequest.of(page, size); - return AjaxResult.success(this.modelsService.queryVersions(pageRequest, identifier, owner)); + return AjaxResult.success(this.modelsService.queryVersions(pageRequest, identifier, owner, type)); } @GetMapping("/queryVersionsMetrics") @ApiOperation("查询版本指标") - public AjaxResult queryVersionsMetrics(@RequestParam("runIds") List runIds) throws Exception { - return AjaxResult.success(this.modelsService.queryVersionsMetrics(runIds)); + public AjaxResult queryVersionsMetrics(@RequestParam("params") List params, @RequestParam("type") String type) throws Exception { + return AjaxResult.success(this.modelsService.queryVersionsMetrics(params, type)); } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/ModelsService.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/ModelsService.java index 6c5f6a58..72ad75dd 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/ModelsService.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/ModelsService.java @@ -1,6 +1,7 @@ package com.ruoyi.platform.service; +import com.ruoyi.platform.domain.ModelDependency1; import com.ruoyi.platform.domain.Models; import com.ruoyi.platform.domain.ModelsVersion; import com.ruoyi.platform.vo.ModelDependency1TreeVo; @@ -99,9 +100,9 @@ public interface ModelsService { Page newPersonalQueryByPage(ModelsVo modelsVo, PageRequest pageRequest) throws Exception; - Page> queryVersions(PageRequest pageRequest, String identifier, String owner) throws Exception; + Page> queryVersions(PageRequest pageRequest, String identifier, String owner, String type) throws Exception; - List>> queryVersionsMetrics(List runIds) throws Exception; + List queryVersionsMetrics(List params, String type) throws Exception; List> getVersionList(String identifier, String owner) throws Exception; diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelsServiceImpl.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelsServiceImpl.java index f4d06f2f..431077b4 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelsServiceImpl.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelsServiceImpl.java @@ -974,7 +974,7 @@ public class ModelsServiceImpl implements ModelsService { } @Override - public Page> queryVersions(PageRequest pageRequest, String identifier, String owner) throws Exception { + public Page> queryVersions(PageRequest pageRequest, String identifier, String owner, String type) throws Exception { String token = gitService.checkoutToken(); List> collect = gitService.getBrancheList(token, owner, identifier); List> result = collect.stream() @@ -984,27 +984,45 @@ public class ModelsServiceImpl implements ModelsService { for (Map branch : result) { String meta = modelDependency1Dao.getMeta(identifier, owner, (String) branch.get("name")); - ModelMetaVo modelMetaVo = JSON.parseObject(meta, ModelMetaVo.class); - if (modelMetaVo.getParams() != null) { - branch.putAll(modelMetaVo.getParams()); - } - if (modelMetaVo.getMetrics() != null) { - branch.putAll(modelMetaVo.getMetrics()); - } - if (modelMetaVo.getMetricsParams() != null) { - branch.putAll(modelMetaVo.getMetricsParams()); + if (StringUtils.isNotEmpty(meta)) { + ModelMetaVo modelMetaVo = JSON.parseObject(meta, ModelMetaVo.class); + if (modelMetaVo.getParams() != null) { + HashMap params = modelMetaVo.getParams(); + branch.putAll(params); + ArrayList params_names = new ArrayList<>(); + for (String key : params.keySet()) { + params_names.add(key); + } + branch.put("params_names", params_names); + } + if (modelMetaVo.getMetrics() != null) { + HashMap metrics = modelMetaVo.getMetrics(); + if (Constant.Type_Train.equals(type)) { + Map trainMetrics = (Map) metrics.get(Constant.Type_Train); + branch.putAll(trainMetrics); + } else { + Map evaluateMetrics = (Map) metrics.get(Constant.Type_Evaluate); + branch.putAll(evaluateMetrics); + } + } } } return new PageImpl<>(result, pageRequest, collect.size()); } @Override - public List>> queryVersionsMetrics(List runIds) throws Exception { - List>> batchMetrics = new ArrayList<>(); - for (String runId : runIds) { - HashMap map = aimsService.queryMetricsParams(runId); - List> batchMetric = aimsService.getBatchMetric((String) map.get("run_hash"), (String) map.get("params")); - batchMetrics.add(batchMetric); + public List queryVersionsMetrics(List params, String type) { + List batchMetrics = new ArrayList<>(); + for (ModelDependency1 model : params) { + + ModelDependency1 modelDependency1 = modelDependency1Dao.queryByRepoAndVersion(model.getRepoId(), model.getIdentifier(), model.getVersion()); + ModelMetaVo modelMetaVo = JSON.parseObject(modelDependency1.getMeta(), ModelMetaVo.class); + HashMap metrics = modelMetaVo.getMetrics(); + if (Constant.Type_Train.equals(type)) { + batchMetrics.add(metrics.get("tarinDetail")); + } else { + batchMetrics.add(metrics.get("evaluateDetail")); + } } return batchMetrics; } @@ -1220,6 +1238,9 @@ public class ModelsServiceImpl implements ModelsService { HashMap train = new HashMap<>(); HashMap evaluate = new HashMap<>(); + List>> trainBatchMetrics = new ArrayList<>(); + List>> evaluateBatchMetrics = new ArrayList<>(); + HashMap metrics = modelMetaVo.getMetricsParams(); JSONArray trainMetrics = (JSONArray) metrics.get("train"); if (trainMetrics != null) { @@ -1231,8 +1252,16 @@ public class ModelsServiceImpl implements ModelsService { Map metrics1 = expTrainInfo.getMetrics(); train.putAll(metrics1); } + + // 记录训练指标详情 + HashMap map = aimsService.queryMetricsParams(runId); + List> batchMetric = aimsService.getBatchMetric((String) map.get("run_hash"), (String) map.get("params")); + trainBatchMetrics.add(batchMetric); + } result.put("train", train); + result.put("tarinDetail", trainBatchMetrics); + } JSONArray testMetrics = (JSONArray) metrics.get("evaluate"); @@ -1245,8 +1274,16 @@ public class ModelsServiceImpl implements ModelsService { Map metrics1 = expTestInfo.getMetrics(); evaluate.putAll(metrics1); } + + // 记录验证指标详情 + HashMap map = aimsService.queryMetricsParams(runId); + List> batchMetric = aimsService.getBatchMetric((String) map.get("run_hash"), (String) map.get("params")); + evaluateBatchMetrics.add(batchMetric); + } result.put("evaluate", evaluate); + result.put("evaluateDetail", evaluateBatchMetrics); + } modelMetaVo.setMetrics(result); }