|
|
|
@@ -1,7 +1,10 @@ |
|
|
|
package com.ruoyi.platform.service.impl; |
|
|
|
|
|
|
|
import com.alibaba.fastjson2.JSON; |
|
|
|
import com.alibaba.fastjson2.JSONArray; |
|
|
|
import com.alibaba.fastjson2.JSONObject; |
|
|
|
import com.ruoyi.platform.domain.ExperimentIns; |
|
|
|
import com.ruoyi.platform.mapper.ExperimentInsDao; |
|
|
|
import com.ruoyi.platform.service.AimService; |
|
|
|
import com.ruoyi.platform.service.ExperimentInsService; |
|
|
|
import com.ruoyi.platform.utils.AIM64EncoderUtil; |
|
|
|
@@ -24,6 +27,8 @@ import java.util.stream.Collectors; |
|
|
|
public class AimServiceImpl implements AimService { |
|
|
|
@Resource |
|
|
|
private ExperimentInsService experimentInsService; |
|
|
|
@Resource |
|
|
|
private ExperimentInsDao experimentInsDao; |
|
|
|
|
|
|
|
@Value("${aim.url}") |
|
|
|
private String aimUrl; |
|
|
|
@@ -146,10 +151,56 @@ public class AimServiceImpl implements AimService { |
|
|
|
return aimRunInfoList; |
|
|
|
} |
|
|
|
|
|
|
|
// private List<InsMetricInfoVo> getExpInfos(Integer experimentId, int page, int size){ |
|
|
|
// PageRequest pageRequest = PageRequest.of(page,size); |
|
|
|
// |
|
|
|
// } |
|
|
|
public List<InsMetricInfoVo> getExpInfos(boolean isTrain, Integer experimentId, int page, int size) { |
|
|
|
PageRequest pageRequest = PageRequest.of(page, size); |
|
|
|
ExperimentIns query = new ExperimentIns(); |
|
|
|
query.setExperimentId(experimentId); |
|
|
|
List<ExperimentIns> experimentInsList = experimentInsDao.queryAllByLimit(query, pageRequest); |
|
|
|
if (experimentInsList == null || experimentInsList.size() == 0) { |
|
|
|
return new ArrayList<>(); |
|
|
|
} |
|
|
|
List<InsMetricInfoVo> aimRunInfoList = new ArrayList<>(); |
|
|
|
for (ExperimentIns experimentIns : experimentInsList) { |
|
|
|
InsMetricInfoVo aimRunInfo = new InsMetricInfoVo(); |
|
|
|
aimRunInfo.setExperimentInsId(experimentIns.getId()); |
|
|
|
aimRunInfo.setStartTime(experimentIns.getCreateTime()); |
|
|
|
aimRunInfo.setStatus(experimentIns.getStatus()); |
|
|
|
|
|
|
|
//解析参数 |
|
|
|
JSONArray params = JSON.parseArray(experimentIns.getGlobalParam()); |
|
|
|
HashMap<String, Object> paramsMap = new HashMap<>(); |
|
|
|
List<String> paramsNames = new ArrayList<>(); |
|
|
|
|
|
|
|
for (int i = 0; i < params.size(); i++) { |
|
|
|
JSONObject jsonObject = params.getJSONObject(i); |
|
|
|
String paramName = jsonObject.getString("param_name"); |
|
|
|
String paramValue = jsonObject.getString("param_value"); |
|
|
|
paramsMap.put(paramName, paramValue); |
|
|
|
paramsNames.add(paramName); |
|
|
|
} |
|
|
|
aimRunInfo.setParams(paramsMap); |
|
|
|
aimRunInfo.setParamsNames(paramsNames); |
|
|
|
|
|
|
|
//解析数据集 |
|
|
|
Map<String, Object> metricRecord = JacksonUtil.parseJSONStr2Map(experimentIns.getMetricRecord()); |
|
|
|
if (isTrain) { |
|
|
|
aimRunInfo.setDataset(getDataset("train", metricRecord)); |
|
|
|
} else { |
|
|
|
aimRunInfo.setDataset(getDataset("evaluate", metricRecord)); |
|
|
|
} |
|
|
|
|
|
|
|
//解析指标 |
|
|
|
Map<String, Object> metricValue = JacksonUtil.parseJSONStr2Map(experimentIns.getMetricValue()); |
|
|
|
if (isTrain) { |
|
|
|
setMetricValue("train",metricValue,aimRunInfo); |
|
|
|
} else { |
|
|
|
setMetricValue("evaluate",metricValue,aimRunInfo); |
|
|
|
} |
|
|
|
aimRunInfoList.add(aimRunInfo); |
|
|
|
} |
|
|
|
return aimRunInfoList; |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
@Override |
|
|
|
public List<InsMetricInfoVo> getExpInfos1(boolean isTrain, Integer experimentId, String runId) throws Exception { |
|
|
|
@@ -240,6 +291,40 @@ public class AimServiceImpl implements AimService { |
|
|
|
return aimRunInfoList; |
|
|
|
} |
|
|
|
|
|
|
|
private List<String> getDataset(String isTrain, Map<String, Object> metricRecord) { |
|
|
|
List<String> datasetList = new ArrayList<>(); |
|
|
|
List<Map<String, Object>> trainMetricRecords = (List<Map<String, Object>>) metricRecord.get(isTrain); |
|
|
|
for (Map<String, Object> trainMetricRecord : trainMetricRecords) { |
|
|
|
String taskId = (String) trainMetricRecord.get("task_id"); |
|
|
|
if (taskId.startsWith("model-" + isTrain)) { |
|
|
|
List<Map<String, Object>> datasets = (List<Map<String, Object>>) trainMetricRecord.get("datasets"); |
|
|
|
for (Map<String, Object> dataset : datasets) { |
|
|
|
String datasetName = dataset.get("dataset_name") + ":" + dataset.get("dataset_version"); |
|
|
|
datasetList.add(datasetName); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return datasetList; |
|
|
|
} |
|
|
|
|
|
|
|
private void setMetricValue(String isTrain, Map<String, Object> metricValue, InsMetricInfoVo aimRunInfo){ |
|
|
|
Map<String, Object> metricValues = (Map<String, Object>) metricValue.get(isTrain); |
|
|
|
|
|
|
|
HashMap<String, Object> metrics = new HashMap<>(); |
|
|
|
List<String> metricsNames = new ArrayList<>(); |
|
|
|
for (String key : metricValues.keySet()) { |
|
|
|
Map<String, Object> valueMap = (Map<String, Object>) metricValues.get(key); |
|
|
|
aimRunInfo.setRunId((String) valueMap.get("run_hash")); |
|
|
|
|
|
|
|
metrics.putAll(valueMap.entrySet().stream().filter(entry -> !entry.getKey().equals("run_hash")).collect(Collectors.toMap( |
|
|
|
Map.Entry::getKey, |
|
|
|
Map.Entry::getValue |
|
|
|
))); |
|
|
|
metricsNames.addAll(valueMap.keySet().stream().filter(entry -> !entry.equals("run_hash")).collect(Collectors.toList())); |
|
|
|
} |
|
|
|
aimRunInfo.setMetrics(metrics); |
|
|
|
aimRunInfo.setMetricsNames(metricsNames); |
|
|
|
} |
|
|
|
|
|
|
|
private List<String> getTrainDateSet(List<Map<String, Object>> records, String aimrunId) { |
|
|
|
List<String> datasetList = new ArrayList<>(); |
|
|
|
|