diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/aim/AimController.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/aim/AimController.java index 28632dcb..9390e156 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/aim/AimController.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/aim/AimController.java @@ -1,6 +1,7 @@ package com.ruoyi.platform.controller.aim; import com.ruoyi.common.core.web.controller.BaseController; +import com.ruoyi.common.core.web.domain.AjaxResult; import com.ruoyi.common.core.web.domain.GenericsAjaxResult; import com.ruoyi.platform.service.AimService; import com.ruoyi.platform.vo.InsMetricInfoVo; @@ -24,19 +25,19 @@ public class AimController extends BaseController { @GetMapping("/getExpTrainInfos/{experiment_id}") @ApiOperation("获取当前实验的模型训练指标信息") @ApiResponse - public GenericsAjaxResult> getExpTrainInfos(@RequestParam(value = "page") int page, - @RequestParam(value = "size") int size, - @PathVariable("experiment_id") Integer experimentId) { - return genericsSuccess(aimService.getExpInfos(true, experimentId, page, size)); + public AjaxResult getExpTrainInfos(@RequestParam(value = "page") int page, + @RequestParam(value = "size") int size, + @PathVariable("experiment_id") Integer experimentId) { + return AjaxResult.success(aimService.getExpInfos(true, experimentId, page, size)); } @GetMapping("/getExpEvaluateInfos/{experiment_id}") @ApiOperation("获取当前实验的模型推理指标信息") @ApiResponse - public GenericsAjaxResult> getExpEvaluateInfos(@RequestParam(value = "page") int page, + public AjaxResult getExpEvaluateInfos(@RequestParam(value = "page") int page, @RequestParam(value = "size") int size, @PathVariable("experiment_id") Integer experimentId) { - return genericsSuccess(aimService.getExpInfos(false, experimentId, page, size)); + return AjaxResult.success(aimService.getExpInfos(false, experimentId, page, size)); } @PostMapping("/getExpMetrics") diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AimService.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AimService.java index abcbc645..c7e383e3 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AimService.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AimService.java @@ -1,6 +1,7 @@ package com.ruoyi.platform.service; import com.ruoyi.platform.vo.InsMetricInfoVo; +import org.springframework.data.domain.Page; import java.util.HashMap; import java.util.List; @@ -12,7 +13,7 @@ public interface AimService { List getExpEvaluateInfos(Integer experimentId, String offset, int limit) throws Exception; - List getExpInfos(boolean isTrain, Integer experimentId, int page, int size); + Page getExpInfos(boolean isTrain, Integer experimentId, int page, int size); List getExpInfos1(boolean isTrain, Integer experimentId, String runId) throws Exception; diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java index 5c07567f..a2e912e4 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java @@ -14,6 +14,8 @@ import com.ruoyi.platform.utils.JsonUtils; import com.ruoyi.platform.vo.InsMetricInfoVo; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.factory.annotation.Value; +import org.springframework.data.domain.Page; +import org.springframework.data.domain.PageImpl; import org.springframework.data.domain.PageRequest; import org.springframework.stereotype.Service; @@ -151,14 +153,12 @@ public class AimServiceImpl implements AimService { return aimRunInfoList; } - public List getExpInfos(boolean isTrain, Integer experimentId, int page, int size) { + public Page getExpInfos(boolean isTrain, Integer experimentId, int page, int size) { PageRequest pageRequest = PageRequest.of(page, size); ExperimentIns query = new ExperimentIns(); query.setExperimentId(experimentId); + long count = experimentInsDao.count(query); List experimentInsList = experimentInsDao.queryAllByLimit(query, pageRequest); - if (experimentInsList == null || experimentInsList.size() == 0) { - return new ArrayList<>(); - } List aimRunInfoList = new ArrayList<>(); for (ExperimentIns experimentIns : experimentInsList) { InsMetricInfoVo aimRunInfo = new InsMetricInfoVo(); @@ -200,7 +200,7 @@ public class AimServiceImpl implements AimService { } aimRunInfoList.add(aimRunInfo); } - return aimRunInfoList; + return new PageImpl<>(aimRunInfoList, pageRequest, count); } @@ -314,15 +314,18 @@ public class AimServiceImpl implements AimService { HashMap metrics = new HashMap<>(); List metricsNames = new ArrayList<>(); - for (String key : metricValues.keySet()) { - Map valueMap = (Map) metricValues.get(key); - aimRunInfo.setRunId((String) valueMap.get("run_hash")); - - metrics.putAll(valueMap.entrySet().stream().filter(entry -> !entry.getKey().equals("run_hash")).collect(Collectors.toMap( - Map.Entry::getKey, - Map.Entry::getValue - ))); - metricsNames.addAll(valueMap.keySet().stream().filter(entry -> !entry.equals("run_hash")).collect(Collectors.toList())); + + if(metricValues != null){ + for (String key : metricValues.keySet()) { + Map valueMap = (Map) metricValues.get(key); + aimRunInfo.setRunId((String) valueMap.get("run_hash")); + + metrics.putAll(valueMap.entrySet().stream().filter(entry -> !entry.getKey().equals("run_hash")).collect(Collectors.toMap( + Map.Entry::getKey, + Map.Entry::getValue + ))); + metricsNames.addAll(valueMap.keySet().stream().filter(entry -> !entry.equals("run_hash")).collect(Collectors.toList())); + } } aimRunInfo.setMetrics(metrics); aimRunInfo.setMetricsNames(metricsNames);