|
|
|
@@ -30,13 +30,13 @@ public class AimServiceImpl implements AimService { |
|
|
|
private NewHttpUtils httpUtils; |
|
|
|
|
|
|
|
@Override |
|
|
|
public List<InsMetricInfoVo> getExpTrainInfos(Integer experimentId, String runId) throws Exception { |
|
|
|
return getAimRunInfos(true, experimentId, runId); |
|
|
|
public List<InsMetricInfoVo> getExpTrainInfos(Integer experimentId) throws Exception { |
|
|
|
return getAimRunInfos(true, experimentId); |
|
|
|
} |
|
|
|
|
|
|
|
@Override |
|
|
|
public List<InsMetricInfoVo> getExpEvaluateInfos(Integer experimentId, String runId) throws Exception { |
|
|
|
return getAimRunInfos(false, experimentId, runId); |
|
|
|
public List<InsMetricInfoVo> getExpEvaluateInfos(Integer experimentId) throws Exception { |
|
|
|
return getAimRunInfos(false, experimentId); |
|
|
|
} |
|
|
|
|
|
|
|
@Override |
|
|
|
@@ -45,13 +45,12 @@ public class AimServiceImpl implements AimService { |
|
|
|
return aimUrl + "/metrics?select=" + decode; |
|
|
|
} |
|
|
|
|
|
|
|
private List<InsMetricInfoVo> getAimRunInfos(boolean isTrain, Integer experimentId, String runId) throws Exception { |
|
|
|
// String experimentName = "experiment-" + experimentId + "-train"; |
|
|
|
// if (!isTrain) { |
|
|
|
// experimentName = "experiment-" + experimentId + "-evaluate"; |
|
|
|
// } |
|
|
|
// String encodedUrlString = URLEncoder.encode("run.experiment==\"" + experimentName + "\"", "UTF-8"); |
|
|
|
String encodedUrlString = URLEncoder.encode("run.id==\"" + runId + "\"", "UTF-8"); |
|
|
|
private List<InsMetricInfoVo> getAimRunInfos(boolean isTrain, Integer experimentId) throws Exception { |
|
|
|
String experimentName = "experiment-" + experimentId + "-train"; |
|
|
|
if (!isTrain) { |
|
|
|
experimentName = "experiment-" + experimentId + "-evaluate"; |
|
|
|
} |
|
|
|
String encodedUrlString = URLEncoder.encode("run.experiment==\"" + experimentName + "\"", "UTF-8"); |
|
|
|
String url = aimProxyUrl + "/api/runs/search/run?query=" + encodedUrlString; |
|
|
|
String s = httpUtils.sendGet(url, null); |
|
|
|
List<Map<String, Object>> response = JacksonUtil.parseJSONStr2MapList(s); |
|
|
|
@@ -139,6 +138,96 @@ public class AimServiceImpl implements AimService { |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
@Override |
|
|
|
public List<InsMetricInfoVo> getExpTrainInfos1(boolean isTrain, Integer experimentId, String runId) throws Exception { |
|
|
|
String encodedUrlString = URLEncoder.encode("run.id==\"" + runId + "\"", "UTF-8"); |
|
|
|
String url = aimProxyUrl + "/api/runs/search/run?query=" + encodedUrlString; |
|
|
|
String s = httpUtils.sendGet(url, null); |
|
|
|
List<Map<String, Object>> response = JacksonUtil.parseJSONStr2MapList(s); |
|
|
|
System.out.println("response: " + JacksonUtil.toJSONString(response)); |
|
|
|
if (response == null || response.size() == 0) { |
|
|
|
return new ArrayList<>(); |
|
|
|
} |
|
|
|
//查询实例数据 |
|
|
|
List<ExperimentIns> byExperimentId = experimentInsService.queryByExperimentId(experimentId); |
|
|
|
|
|
|
|
if (byExperimentId == null || byExperimentId.size() == 0) { |
|
|
|
return new ArrayList<>(); |
|
|
|
} |
|
|
|
List<InsMetricInfoVo> aimRunInfoList = new ArrayList<>(); |
|
|
|
for (Map<String, Object> run : response) { |
|
|
|
InsMetricInfoVo aimRunInfo = new InsMetricInfoVo(); |
|
|
|
String runHash = (String) run.get("run_hash"); |
|
|
|
|
|
|
|
aimRunInfo.setRunId(runHash); |
|
|
|
|
|
|
|
Map params = (Map) run.get("params"); |
|
|
|
Map<String, Object> paramMap = JsonUtils.flattenJson("", params); |
|
|
|
aimRunInfo.setParams(paramMap); |
|
|
|
String aimrunId = (String) paramMap.get("id"); |
|
|
|
Map<String, Object> tracesMap = (Map<String, Object>) run.get("traces"); |
|
|
|
List<Map<String, Object>> metricList = (List<Map<String, Object>>) tracesMap.get("metric"); |
|
|
|
//过滤name为__system__开头的对象 |
|
|
|
aimRunInfo.setMetrics(new HashMap<>()); |
|
|
|
if (metricList != null && metricList.size() > 0) { |
|
|
|
List<Map<String, Object>> metricRelList = metricList.stream() |
|
|
|
.filter(map -> !StringUtils.startsWith((String) map.get("name"), "__system__")) |
|
|
|
.collect(Collectors.toList()); |
|
|
|
if (metricRelList != null && metricRelList.size() > 0) { |
|
|
|
Map<String, Object> relMetricMap = new HashMap<>(); |
|
|
|
for (Map<String, Object> metricMap : metricRelList) { |
|
|
|
relMetricMap.put((String) metricMap.get("name"), metricMap.get("last_value")); |
|
|
|
} |
|
|
|
aimRunInfo.setMetrics(relMetricMap); |
|
|
|
} |
|
|
|
} |
|
|
|
//找到ins |
|
|
|
for (ExperimentIns ins : byExperimentId) { |
|
|
|
String metricRecordString = ins.getMetricRecord(); |
|
|
|
if (StringUtils.isEmpty(metricRecordString)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (metricRecordString.contains(aimrunId)) { |
|
|
|
aimRunInfo.setExperimentInsId(ins.getId()); |
|
|
|
aimRunInfo.setStatus(ins.getStatus()); |
|
|
|
aimRunInfo.setStartTime(ins.getCreateTime()); |
|
|
|
Map<String, Object> metricRecordMap = JacksonUtil.parseJSONStr2Map(metricRecordString); |
|
|
|
if (isTrain) { |
|
|
|
List<Map<String, Object>> records = (List<Map<String, Object>>) metricRecordMap.get("train"); |
|
|
|
List<String> datasetList = getTrainDateSet(records, aimrunId); |
|
|
|
aimRunInfo.setDataset(datasetList); |
|
|
|
} else { |
|
|
|
List<Map<String, Object>> records = (List<Map<String, Object>>) metricRecordMap.get("evaluate"); |
|
|
|
List<String> datasetList = getTrainDateSet(records, aimrunId); |
|
|
|
aimRunInfo.setDataset(datasetList); |
|
|
|
} |
|
|
|
aimRunInfoList.add(aimRunInfo); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
//判断哪个最长 |
|
|
|
|
|
|
|
// 获取所有 metrics 的 key 的并集 |
|
|
|
Set<String> metricsKeys = (Set<String>) aimRunInfoList.stream() |
|
|
|
.map(InsMetricInfoVo::getMetrics) |
|
|
|
.flatMap(metrics -> metrics.keySet().stream()) |
|
|
|
.collect(Collectors.toSet()); |
|
|
|
// 将并集赋值给每个 InsMetricInfoVo 的 metricsNames 属性 |
|
|
|
aimRunInfoList.forEach(vo -> vo.setMetricsNames(new ArrayList<>(metricsKeys))); |
|
|
|
|
|
|
|
// 获取所有 params 的 key 的并集 |
|
|
|
Set<String> paramKeys = (Set<String>) aimRunInfoList.stream() |
|
|
|
.map(InsMetricInfoVo::getParams) |
|
|
|
.flatMap(params -> params.keySet().stream()) |
|
|
|
.collect(Collectors.toSet()); |
|
|
|
// 将并集赋值给每个 InsMetricInfoVo 的 paramsNames 属性 |
|
|
|
aimRunInfoList.forEach(vo -> vo.setParamsNames(new ArrayList<>(paramKeys))); |
|
|
|
|
|
|
|
return aimRunInfoList; |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
private List<String> getTrainDateSet(List<Map<String, Object>> records, String aimrunId) { |
|
|
|
List<String> datasetList = new ArrayList<>(); |
|
|
|
for (Map<String, Object> record : records) { |
|
|
|
|