| @@ -24,19 +24,19 @@ public class AimController extends BaseController { | |||||
| @GetMapping("/getExpTrainInfos/{experiment_id}") | @GetMapping("/getExpTrainInfos/{experiment_id}") | ||||
| @ApiOperation("获取当前实验的模型训练指标信息") | @ApiOperation("获取当前实验的模型训练指标信息") | ||||
| @ApiResponse | @ApiResponse | ||||
| public GenericsAjaxResult<List<InsMetricInfoVo>> getExpTrainInfos(@RequestParam(value = "offset", required = false) String offset, | |||||
| @RequestParam(value = "limit") int limit, | |||||
| @PathVariable("experiment_id") Integer experimentId) throws Exception { | |||||
| return genericsSuccess(aimService.getExpTrainInfos(experimentId, offset, limit)); | |||||
| public GenericsAjaxResult<List<InsMetricInfoVo>> getExpTrainInfos(@RequestParam(value = "page") int page, | |||||
| @RequestParam(value = "size") int size, | |||||
| @PathVariable("experiment_id") Integer experimentId) { | |||||
| return genericsSuccess(aimService.getExpInfos(true, experimentId, page, size)); | |||||
| } | } | ||||
| @GetMapping("/getExpEvaluateInfos/{experiment_id}") | @GetMapping("/getExpEvaluateInfos/{experiment_id}") | ||||
| @ApiOperation("获取当前实验的模型推理指标信息") | @ApiOperation("获取当前实验的模型推理指标信息") | ||||
| @ApiResponse | @ApiResponse | ||||
| public GenericsAjaxResult<List<InsMetricInfoVo>> getExpEvaluateInfos(@RequestParam(value = "offset", required = false) String offset, | |||||
| @RequestParam(value = "limit") int limit, | |||||
| @PathVariable("experiment_id") Integer experimentId) throws Exception { | |||||
| return genericsSuccess(aimService.getExpEvaluateInfos(experimentId, offset, limit)); | |||||
| public GenericsAjaxResult<List<InsMetricInfoVo>> getExpEvaluateInfos(@RequestParam(value = "page") int page, | |||||
| @RequestParam(value = "size") int size, | |||||
| @PathVariable("experiment_id") Integer experimentId) { | |||||
| return genericsSuccess(aimService.getExpInfos(false, experimentId, page, size)); | |||||
| } | } | ||||
| @PostMapping("/getExpMetrics") | @PostMapping("/getExpMetrics") | ||||
| @@ -53,11 +53,12 @@ public class ExperimentInstanceStatusTask { | |||||
| List<Map<String, Object>> evaluateMetricRecords = (List<Map<String, Object>>) metricRecord.get("evaluate"); | List<Map<String, Object>> evaluateMetricRecords = (List<Map<String, Object>>) metricRecord.get("evaluate"); | ||||
| HashMap<String, Object> metricValue = new HashMap<>(); | HashMap<String, Object> metricValue = new HashMap<>(); | ||||
| HashMap<String, Object> trainMetricValue = new HashMap<>(); | |||||
| HashMap<String, Object> evaluateMetricValue = new HashMap<>(); | |||||
| HashMap<String, Object> trainMetricValues = new HashMap<>(); | |||||
| HashMap<String, Object> evaluateMetricValues = new HashMap<>(); | |||||
| if (trainMetricRecords != null && !trainMetricRecords.isEmpty()) { | if (trainMetricRecords != null && !trainMetricRecords.isEmpty()) { | ||||
| for (Map<String, Object> trainMetricRecord : trainMetricRecords) { | for (Map<String, Object> trainMetricRecord : trainMetricRecords) { | ||||
| HashMap<String, Object> trainMetricValue = new HashMap<>(); | |||||
| String taskId = (String) trainMetricRecord.get("task_id"); | String taskId = (String) trainMetricRecord.get("task_id"); | ||||
| if (taskId.startsWith("model-train")) { | if (taskId.startsWith("model-train")) { | ||||
| String runId = (String) trainMetricRecord.get("run_id"); | String runId = (String) trainMetricRecord.get("run_id"); | ||||
| @@ -65,14 +66,16 @@ public class ExperimentInstanceStatusTask { | |||||
| for (InsMetricInfoVo expTrainInfo : expTrainInfos) { | for (InsMetricInfoVo expTrainInfo : expTrainInfos) { | ||||
| Map metrics = expTrainInfo.getMetrics(); | Map metrics = expTrainInfo.getMetrics(); | ||||
| trainMetricValue.putAll(metrics); | trainMetricValue.putAll(metrics); | ||||
| trainMetricValue.put("run_hash",expTrainInfo.getRunId()); | |||||
| trainMetricValue.put("run_hash", expTrainInfo.getRunId()); | |||||
| } | } | ||||
| } | } | ||||
| trainMetricValues.put(taskId, trainMetricValue); | |||||
| } | } | ||||
| } | } | ||||
| if (evaluateMetricRecords != null && !evaluateMetricRecords.isEmpty()) { | if (evaluateMetricRecords != null && !evaluateMetricRecords.isEmpty()) { | ||||
| for (Map<String, Object> evaluateMetricRecord : evaluateMetricRecords) { | for (Map<String, Object> evaluateMetricRecord : evaluateMetricRecords) { | ||||
| HashMap<String, Object> evaluateMetricValue = new HashMap<>(); | |||||
| String taskId = (String) evaluateMetricRecord.get("task_id"); | String taskId = (String) evaluateMetricRecord.get("task_id"); | ||||
| if (taskId.startsWith("model-evaluate")) { | if (taskId.startsWith("model-evaluate")) { | ||||
| String runId = (String) evaluateMetricRecord.get("run_id"); | String runId = (String) evaluateMetricRecord.get("run_id"); | ||||
| @@ -80,13 +83,14 @@ public class ExperimentInstanceStatusTask { | |||||
| for (InsMetricInfoVo expTrainInfo : expTrainInfos) { | for (InsMetricInfoVo expTrainInfo : expTrainInfos) { | ||||
| Map metrics = expTrainInfo.getMetrics(); | Map metrics = expTrainInfo.getMetrics(); | ||||
| evaluateMetricValue.putAll(metrics); | evaluateMetricValue.putAll(metrics); | ||||
| evaluateMetricValue.put("run_hash",expTrainInfo.getRunId()); | |||||
| evaluateMetricValue.put("run_hash", expTrainInfo.getRunId()); | |||||
| } | } | ||||
| } | } | ||||
| evaluateMetricValues.put(taskId, evaluateMetricValue); | |||||
| } | } | ||||
| } | } | ||||
| metricValue.put("train", trainMetricValue); | |||||
| metricValue.put("evaluate", evaluateMetricValue); | |||||
| metricValue.put("train", trainMetricValues); | |||||
| metricValue.put("evaluate", evaluateMetricValues); | |||||
| experimentIns.setMetricValue(JsonUtils.mapToJson(metricValue)); | experimentIns.setMetricValue(JsonUtils.mapToJson(metricValue)); | ||||
| } | } | ||||
| experimentIns.setUpdateTime(new Date()); | experimentIns.setUpdateTime(new Date()); | ||||
| @@ -12,6 +12,8 @@ public interface AimService { | |||||
| List<InsMetricInfoVo> getExpEvaluateInfos(Integer experimentId, String offset, int limit) throws Exception; | List<InsMetricInfoVo> getExpEvaluateInfos(Integer experimentId, String offset, int limit) throws Exception; | ||||
| List<InsMetricInfoVo> getExpInfos(boolean isTrain, Integer experimentId, int page, int size); | |||||
| List<InsMetricInfoVo> getExpInfos1(boolean isTrain, Integer experimentId, String runId) throws Exception; | List<InsMetricInfoVo> getExpInfos1(boolean isTrain, Integer experimentId, String runId) throws Exception; | ||||
| String getExpMetrics(List<String> runIds) throws Exception; | String getExpMetrics(List<String> runIds) throws Exception; | ||||
| @@ -1,7 +1,10 @@ | |||||
| package com.ruoyi.platform.service.impl; | package com.ruoyi.platform.service.impl; | ||||
| import com.alibaba.fastjson2.JSON; | import com.alibaba.fastjson2.JSON; | ||||
| import com.alibaba.fastjson2.JSONArray; | |||||
| import com.alibaba.fastjson2.JSONObject; | |||||
| import com.ruoyi.platform.domain.ExperimentIns; | import com.ruoyi.platform.domain.ExperimentIns; | ||||
| import com.ruoyi.platform.mapper.ExperimentInsDao; | |||||
| import com.ruoyi.platform.service.AimService; | import com.ruoyi.platform.service.AimService; | ||||
| import com.ruoyi.platform.service.ExperimentInsService; | import com.ruoyi.platform.service.ExperimentInsService; | ||||
| import com.ruoyi.platform.utils.AIM64EncoderUtil; | import com.ruoyi.platform.utils.AIM64EncoderUtil; | ||||
| @@ -24,6 +27,8 @@ import java.util.stream.Collectors; | |||||
| public class AimServiceImpl implements AimService { | public class AimServiceImpl implements AimService { | ||||
| @Resource | @Resource | ||||
| private ExperimentInsService experimentInsService; | private ExperimentInsService experimentInsService; | ||||
| @Resource | |||||
| private ExperimentInsDao experimentInsDao; | |||||
| @Value("${aim.url}") | @Value("${aim.url}") | ||||
| private String aimUrl; | private String aimUrl; | ||||
| @@ -146,10 +151,56 @@ public class AimServiceImpl implements AimService { | |||||
| return aimRunInfoList; | return aimRunInfoList; | ||||
| } | } | ||||
| // private List<InsMetricInfoVo> getExpInfos(Integer experimentId, int page, int size){ | |||||
| // PageRequest pageRequest = PageRequest.of(page,size); | |||||
| // | |||||
| // } | |||||
| public List<InsMetricInfoVo> getExpInfos(boolean isTrain, Integer experimentId, int page, int size) { | |||||
| PageRequest pageRequest = PageRequest.of(page, size); | |||||
| ExperimentIns query = new ExperimentIns(); | |||||
| query.setExperimentId(experimentId); | |||||
| List<ExperimentIns> experimentInsList = experimentInsDao.queryAllByLimit(query, pageRequest); | |||||
| if (experimentInsList == null || experimentInsList.size() == 0) { | |||||
| return new ArrayList<>(); | |||||
| } | |||||
| List<InsMetricInfoVo> aimRunInfoList = new ArrayList<>(); | |||||
| for (ExperimentIns experimentIns : experimentInsList) { | |||||
| InsMetricInfoVo aimRunInfo = new InsMetricInfoVo(); | |||||
| aimRunInfo.setExperimentInsId(experimentIns.getId()); | |||||
| aimRunInfo.setStartTime(experimentIns.getCreateTime()); | |||||
| aimRunInfo.setStatus(experimentIns.getStatus()); | |||||
| //解析参数 | |||||
| JSONArray params = JSON.parseArray(experimentIns.getGlobalParam()); | |||||
| HashMap<String, Object> paramsMap = new HashMap<>(); | |||||
| List<String> paramsNames = new ArrayList<>(); | |||||
| for (int i = 0; i < params.size(); i++) { | |||||
| JSONObject jsonObject = params.getJSONObject(i); | |||||
| String paramName = jsonObject.getString("param_name"); | |||||
| String paramValue = jsonObject.getString("param_value"); | |||||
| paramsMap.put(paramName, paramValue); | |||||
| paramsNames.add(paramName); | |||||
| } | |||||
| aimRunInfo.setParams(paramsMap); | |||||
| aimRunInfo.setParamsNames(paramsNames); | |||||
| //解析数据集 | |||||
| Map<String, Object> metricRecord = JacksonUtil.parseJSONStr2Map(experimentIns.getMetricRecord()); | |||||
| if (isTrain) { | |||||
| aimRunInfo.setDataset(getDataset("train", metricRecord)); | |||||
| } else { | |||||
| aimRunInfo.setDataset(getDataset("evaluate", metricRecord)); | |||||
| } | |||||
| //解析指标 | |||||
| Map<String, Object> metricValue = JacksonUtil.parseJSONStr2Map(experimentIns.getMetricValue()); | |||||
| if (isTrain) { | |||||
| setMetricValue("train",metricValue,aimRunInfo); | |||||
| } else { | |||||
| setMetricValue("evaluate",metricValue,aimRunInfo); | |||||
| } | |||||
| aimRunInfoList.add(aimRunInfo); | |||||
| } | |||||
| return aimRunInfoList; | |||||
| } | |||||
| @Override | @Override | ||||
| public List<InsMetricInfoVo> getExpInfos1(boolean isTrain, Integer experimentId, String runId) throws Exception { | public List<InsMetricInfoVo> getExpInfos1(boolean isTrain, Integer experimentId, String runId) throws Exception { | ||||
| @@ -240,6 +291,40 @@ public class AimServiceImpl implements AimService { | |||||
| return aimRunInfoList; | return aimRunInfoList; | ||||
| } | } | ||||
| private List<String> getDataset(String isTrain, Map<String, Object> metricRecord) { | |||||
| List<String> datasetList = new ArrayList<>(); | |||||
| List<Map<String, Object>> trainMetricRecords = (List<Map<String, Object>>) metricRecord.get(isTrain); | |||||
| for (Map<String, Object> trainMetricRecord : trainMetricRecords) { | |||||
| String taskId = (String) trainMetricRecord.get("task_id"); | |||||
| if (taskId.startsWith("model-" + isTrain)) { | |||||
| List<Map<String, Object>> datasets = (List<Map<String, Object>>) trainMetricRecord.get("datasets"); | |||||
| for (Map<String, Object> dataset : datasets) { | |||||
| String datasetName = dataset.get("dataset_name") + ":" + dataset.get("dataset_version"); | |||||
| datasetList.add(datasetName); | |||||
| } | |||||
| } | |||||
| } | |||||
| return datasetList; | |||||
| } | |||||
| private void setMetricValue(String isTrain, Map<String, Object> metricValue, InsMetricInfoVo aimRunInfo){ | |||||
| Map<String, Object> metricValues = (Map<String, Object>) metricValue.get(isTrain); | |||||
| HashMap<String, Object> metrics = new HashMap<>(); | |||||
| List<String> metricsNames = new ArrayList<>(); | |||||
| for (String key : metricValues.keySet()) { | |||||
| Map<String, Object> valueMap = (Map<String, Object>) metricValues.get(key); | |||||
| aimRunInfo.setRunId((String) valueMap.get("run_hash")); | |||||
| metrics.putAll(valueMap.entrySet().stream().filter(entry -> !entry.getKey().equals("run_hash")).collect(Collectors.toMap( | |||||
| Map.Entry::getKey, | |||||
| Map.Entry::getValue | |||||
| ))); | |||||
| metricsNames.addAll(valueMap.keySet().stream().filter(entry -> !entry.equals("run_hash")).collect(Collectors.toList())); | |||||
| } | |||||
| aimRunInfo.setMetrics(metrics); | |||||
| aimRunInfo.setMetricsNames(metricsNames); | |||||
| } | |||||
| private List<String> getTrainDateSet(List<Map<String, Object>> records, String aimrunId) { | private List<String> getTrainDateSet(List<Map<String, Object>> records, String aimrunId) { | ||||
| List<String> datasetList = new ArrayList<>(); | List<String> datasetList = new ArrayList<>(); | ||||