Browse Source

修改实验对比分页查询

dev-czh
chenzhihang 1 year ago
parent
commit
1451f50c46
3 changed files with 26 additions and 21 deletions
  1. +7
    -6
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/aim/AimController.java
  2. +2
    -1
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AimService.java
  3. +17
    -14
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java

+ 7
- 6
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/aim/AimController.java View File

@@ -1,6 +1,7 @@
package com.ruoyi.platform.controller.aim; package com.ruoyi.platform.controller.aim;


import com.ruoyi.common.core.web.controller.BaseController; import com.ruoyi.common.core.web.controller.BaseController;
import com.ruoyi.common.core.web.domain.AjaxResult;
import com.ruoyi.common.core.web.domain.GenericsAjaxResult; import com.ruoyi.common.core.web.domain.GenericsAjaxResult;
import com.ruoyi.platform.service.AimService; import com.ruoyi.platform.service.AimService;
import com.ruoyi.platform.vo.InsMetricInfoVo; import com.ruoyi.platform.vo.InsMetricInfoVo;
@@ -24,19 +25,19 @@ public class AimController extends BaseController {
@GetMapping("/getExpTrainInfos/{experiment_id}") @GetMapping("/getExpTrainInfos/{experiment_id}")
@ApiOperation("获取当前实验的模型训练指标信息") @ApiOperation("获取当前实验的模型训练指标信息")
@ApiResponse @ApiResponse
public GenericsAjaxResult<List<InsMetricInfoVo>> getExpTrainInfos(@RequestParam(value = "page") int page,
@RequestParam(value = "size") int size,
@PathVariable("experiment_id") Integer experimentId) {
return genericsSuccess(aimService.getExpInfos(true, experimentId, page, size));
public AjaxResult getExpTrainInfos(@RequestParam(value = "page") int page,
@RequestParam(value = "size") int size,
@PathVariable("experiment_id") Integer experimentId) {
return AjaxResult.success(aimService.getExpInfos(true, experimentId, page, size));
} }


@GetMapping("/getExpEvaluateInfos/{experiment_id}") @GetMapping("/getExpEvaluateInfos/{experiment_id}")
@ApiOperation("获取当前实验的模型推理指标信息") @ApiOperation("获取当前实验的模型推理指标信息")
@ApiResponse @ApiResponse
public GenericsAjaxResult<List<InsMetricInfoVo>> getExpEvaluateInfos(@RequestParam(value = "page") int page,
public AjaxResult getExpEvaluateInfos(@RequestParam(value = "page") int page,
@RequestParam(value = "size") int size, @RequestParam(value = "size") int size,
@PathVariable("experiment_id") Integer experimentId) { @PathVariable("experiment_id") Integer experimentId) {
return genericsSuccess(aimService.getExpInfos(false, experimentId, page, size));
return AjaxResult.success(aimService.getExpInfos(false, experimentId, page, size));
} }


@PostMapping("/getExpMetrics") @PostMapping("/getExpMetrics")


+ 2
- 1
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AimService.java View File

@@ -1,6 +1,7 @@
package com.ruoyi.platform.service; package com.ruoyi.platform.service;


import com.ruoyi.platform.vo.InsMetricInfoVo; import com.ruoyi.platform.vo.InsMetricInfoVo;
import org.springframework.data.domain.Page;


import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
@@ -12,7 +13,7 @@ public interface AimService {


List<InsMetricInfoVo> getExpEvaluateInfos(Integer experimentId, String offset, int limit) throws Exception; List<InsMetricInfoVo> getExpEvaluateInfos(Integer experimentId, String offset, int limit) throws Exception;


List<InsMetricInfoVo> getExpInfos(boolean isTrain, Integer experimentId, int page, int size);
Page<InsMetricInfoVo> getExpInfos(boolean isTrain, Integer experimentId, int page, int size);


List<InsMetricInfoVo> getExpInfos1(boolean isTrain, Integer experimentId, String runId) throws Exception; List<InsMetricInfoVo> getExpInfos1(boolean isTrain, Integer experimentId, String runId) throws Exception;




+ 17
- 14
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java View File

@@ -14,6 +14,8 @@ import com.ruoyi.platform.utils.JsonUtils;
import com.ruoyi.platform.vo.InsMetricInfoVo; import com.ruoyi.platform.vo.InsMetricInfoVo;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageImpl;
import org.springframework.data.domain.PageRequest; import org.springframework.data.domain.PageRequest;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;


@@ -151,14 +153,12 @@ public class AimServiceImpl implements AimService {
return aimRunInfoList; return aimRunInfoList;
} }


public List<InsMetricInfoVo> getExpInfos(boolean isTrain, Integer experimentId, int page, int size) {
public Page<InsMetricInfoVo> getExpInfos(boolean isTrain, Integer experimentId, int page, int size) {
PageRequest pageRequest = PageRequest.of(page, size); PageRequest pageRequest = PageRequest.of(page, size);
ExperimentIns query = new ExperimentIns(); ExperimentIns query = new ExperimentIns();
query.setExperimentId(experimentId); query.setExperimentId(experimentId);
long count = experimentInsDao.count(query);
List<ExperimentIns> experimentInsList = experimentInsDao.queryAllByLimit(query, pageRequest); List<ExperimentIns> experimentInsList = experimentInsDao.queryAllByLimit(query, pageRequest);
if (experimentInsList == null || experimentInsList.size() == 0) {
return new ArrayList<>();
}
List<InsMetricInfoVo> aimRunInfoList = new ArrayList<>(); List<InsMetricInfoVo> aimRunInfoList = new ArrayList<>();
for (ExperimentIns experimentIns : experimentInsList) { for (ExperimentIns experimentIns : experimentInsList) {
InsMetricInfoVo aimRunInfo = new InsMetricInfoVo(); InsMetricInfoVo aimRunInfo = new InsMetricInfoVo();
@@ -200,7 +200,7 @@ public class AimServiceImpl implements AimService {
} }
aimRunInfoList.add(aimRunInfo); aimRunInfoList.add(aimRunInfo);
} }
return aimRunInfoList;
return new PageImpl<>(aimRunInfoList, pageRequest, count);
} }




@@ -314,15 +314,18 @@ public class AimServiceImpl implements AimService {


HashMap<String, Object> metrics = new HashMap<>(); HashMap<String, Object> metrics = new HashMap<>();
List<String> metricsNames = new ArrayList<>(); List<String> metricsNames = new ArrayList<>();
for (String key : metricValues.keySet()) {
Map<String, Object> valueMap = (Map<String, Object>) metricValues.get(key);
aimRunInfo.setRunId((String) valueMap.get("run_hash"));

metrics.putAll(valueMap.entrySet().stream().filter(entry -> !entry.getKey().equals("run_hash")).collect(Collectors.toMap(
Map.Entry::getKey,
Map.Entry::getValue
)));
metricsNames.addAll(valueMap.keySet().stream().filter(entry -> !entry.equals("run_hash")).collect(Collectors.toList()));

if(metricValues != null){
for (String key : metricValues.keySet()) {
Map<String, Object> valueMap = (Map<String, Object>) metricValues.get(key);
aimRunInfo.setRunId((String) valueMap.get("run_hash"));

metrics.putAll(valueMap.entrySet().stream().filter(entry -> !entry.getKey().equals("run_hash")).collect(Collectors.toMap(
Map.Entry::getKey,
Map.Entry::getValue
)));
metricsNames.addAll(valueMap.keySet().stream().filter(entry -> !entry.equals("run_hash")).collect(Collectors.toList()));
}
} }
aimRunInfo.setMetrics(metrics); aimRunInfo.setMetrics(metrics);
aimRunInfo.setMetricsNames(metricsNames); aimRunInfo.setMetricsNames(metricsNames);


Loading…
Cancel
Save