Browse Source

实验模型导出元数据记录指标修改

dev-lhz
chenzhihang 1 year ago
parent
commit
bd409e85b0
4 changed files with 110 additions and 19 deletions
  1. +4
    -4
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/aim/AimController.java
  2. +4
    -2
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AimService.java
  3. +100
    -11
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java
  4. +2
    -2
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelsServiceImpl.java

+ 4
- 4
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/aim/AimController.java View File

@@ -24,15 +24,15 @@ public class AimController extends BaseController {
@GetMapping("/getExpTrainInfos/{experiment_id}") @GetMapping("/getExpTrainInfos/{experiment_id}")
@ApiOperation("获取当前实验的模型训练指标信息") @ApiOperation("获取当前实验的模型训练指标信息")
@ApiResponse @ApiResponse
public GenericsAjaxResult<List<InsMetricInfoVo>> getExpTrainInfos(@PathVariable("experiment_id") Integer experimentId, @RequestParam("run_id") String runId) throws Exception {
return genericsSuccess(aimService.getExpTrainInfos(experimentId, runId));
public GenericsAjaxResult<List<InsMetricInfoVo>> getExpTrainInfos(@PathVariable("experiment_id") Integer experimentId) throws Exception {
return genericsSuccess(aimService.getExpTrainInfos(experimentId));
} }


@GetMapping("/getExpEvaluateInfos/{experiment_id}") @GetMapping("/getExpEvaluateInfos/{experiment_id}")
@ApiOperation("获取当前实验的模型推理指标信息") @ApiOperation("获取当前实验的模型推理指标信息")
@ApiResponse @ApiResponse
public GenericsAjaxResult<List<InsMetricInfoVo>> getExpEvaluateInfos(@PathVariable("experiment_id") Integer experimentId, @RequestParam("run_id") String runId) throws Exception {
return genericsSuccess(aimService.getExpEvaluateInfos(experimentId, runId));
public GenericsAjaxResult<List<InsMetricInfoVo>> getExpEvaluateInfos(@PathVariable("experiment_id") Integer experimentId) throws Exception {
return genericsSuccess(aimService.getExpEvaluateInfos(experimentId));
} }


@PostMapping("/getExpMetrics") @PostMapping("/getExpMetrics")


+ 4
- 2
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AimService.java View File

@@ -6,9 +6,11 @@ import java.util.List;


public interface AimService { public interface AimService {


List<InsMetricInfoVo> getExpTrainInfos(Integer experimentId, String runId) throws Exception;
List<InsMetricInfoVo> getExpTrainInfos(Integer experimentId) throws Exception;


List<InsMetricInfoVo> getExpEvaluateInfos(Integer experimentId, String runId) throws Exception;
List<InsMetricInfoVo> getExpTrainInfos1(boolean isTrain, Integer experimentId, String runId) throws Exception;

List<InsMetricInfoVo> getExpEvaluateInfos(Integer experimentId) throws Exception;


String getExpMetrics(List<String> runIds) throws Exception; String getExpMetrics(List<String> runIds) throws Exception;
} }

+ 100
- 11
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java View File

@@ -30,13 +30,13 @@ public class AimServiceImpl implements AimService {
private NewHttpUtils httpUtils; private NewHttpUtils httpUtils;


@Override @Override
public List<InsMetricInfoVo> getExpTrainInfos(Integer experimentId, String runId) throws Exception {
return getAimRunInfos(true, experimentId, runId);
public List<InsMetricInfoVo> getExpTrainInfos(Integer experimentId) throws Exception {
return getAimRunInfos(true, experimentId);
} }


@Override @Override
public List<InsMetricInfoVo> getExpEvaluateInfos(Integer experimentId, String runId) throws Exception {
return getAimRunInfos(false, experimentId, runId);
public List<InsMetricInfoVo> getExpEvaluateInfos(Integer experimentId) throws Exception {
return getAimRunInfos(false, experimentId);
} }


@Override @Override
@@ -45,13 +45,12 @@ public class AimServiceImpl implements AimService {
return aimUrl + "/metrics?select=" + decode; return aimUrl + "/metrics?select=" + decode;
} }


private List<InsMetricInfoVo> getAimRunInfos(boolean isTrain, Integer experimentId, String runId) throws Exception {
// String experimentName = "experiment-" + experimentId + "-train";
// if (!isTrain) {
// experimentName = "experiment-" + experimentId + "-evaluate";
// }
// String encodedUrlString = URLEncoder.encode("run.experiment==\"" + experimentName + "\"", "UTF-8");
String encodedUrlString = URLEncoder.encode("run.id==\"" + runId + "\"", "UTF-8");
private List<InsMetricInfoVo> getAimRunInfos(boolean isTrain, Integer experimentId) throws Exception {
String experimentName = "experiment-" + experimentId + "-train";
if (!isTrain) {
experimentName = "experiment-" + experimentId + "-evaluate";
}
String encodedUrlString = URLEncoder.encode("run.experiment==\"" + experimentName + "\"", "UTF-8");
String url = aimProxyUrl + "/api/runs/search/run?query=" + encodedUrlString; String url = aimProxyUrl + "/api/runs/search/run?query=" + encodedUrlString;
String s = httpUtils.sendGet(url, null); String s = httpUtils.sendGet(url, null);
List<Map<String, Object>> response = JacksonUtil.parseJSONStr2MapList(s); List<Map<String, Object>> response = JacksonUtil.parseJSONStr2MapList(s);
@@ -139,6 +138,96 @@ public class AimServiceImpl implements AimService {
} }




@Override
public List<InsMetricInfoVo> getExpTrainInfos1(boolean isTrain, Integer experimentId, String runId) throws Exception {
String encodedUrlString = URLEncoder.encode("run.id==\"" + runId + "\"", "UTF-8");
String url = aimProxyUrl + "/api/runs/search/run?query=" + encodedUrlString;
String s = httpUtils.sendGet(url, null);
List<Map<String, Object>> response = JacksonUtil.parseJSONStr2MapList(s);
System.out.println("response: " + JacksonUtil.toJSONString(response));
if (response == null || response.size() == 0) {
return new ArrayList<>();
}
//查询实例数据
List<ExperimentIns> byExperimentId = experimentInsService.queryByExperimentId(experimentId);

if (byExperimentId == null || byExperimentId.size() == 0) {
return new ArrayList<>();
}
List<InsMetricInfoVo> aimRunInfoList = new ArrayList<>();
for (Map<String, Object> run : response) {
InsMetricInfoVo aimRunInfo = new InsMetricInfoVo();
String runHash = (String) run.get("run_hash");

aimRunInfo.setRunId(runHash);

Map params = (Map) run.get("params");
Map<String, Object> paramMap = JsonUtils.flattenJson("", params);
aimRunInfo.setParams(paramMap);
String aimrunId = (String) paramMap.get("id");
Map<String, Object> tracesMap = (Map<String, Object>) run.get("traces");
List<Map<String, Object>> metricList = (List<Map<String, Object>>) tracesMap.get("metric");
//过滤name为__system__开头的对象
aimRunInfo.setMetrics(new HashMap<>());
if (metricList != null && metricList.size() > 0) {
List<Map<String, Object>> metricRelList = metricList.stream()
.filter(map -> !StringUtils.startsWith((String) map.get("name"), "__system__"))
.collect(Collectors.toList());
if (metricRelList != null && metricRelList.size() > 0) {
Map<String, Object> relMetricMap = new HashMap<>();
for (Map<String, Object> metricMap : metricRelList) {
relMetricMap.put((String) metricMap.get("name"), metricMap.get("last_value"));
}
aimRunInfo.setMetrics(relMetricMap);
}
}
//找到ins
for (ExperimentIns ins : byExperimentId) {
String metricRecordString = ins.getMetricRecord();
if (StringUtils.isEmpty(metricRecordString)) {
continue;
}
if (metricRecordString.contains(aimrunId)) {
aimRunInfo.setExperimentInsId(ins.getId());
aimRunInfo.setStatus(ins.getStatus());
aimRunInfo.setStartTime(ins.getCreateTime());
Map<String, Object> metricRecordMap = JacksonUtil.parseJSONStr2Map(metricRecordString);
if (isTrain) {
List<Map<String, Object>> records = (List<Map<String, Object>>) metricRecordMap.get("train");
List<String> datasetList = getTrainDateSet(records, aimrunId);
aimRunInfo.setDataset(datasetList);
} else {
List<Map<String, Object>> records = (List<Map<String, Object>>) metricRecordMap.get("evaluate");
List<String> datasetList = getTrainDateSet(records, aimrunId);
aimRunInfo.setDataset(datasetList);
}
aimRunInfoList.add(aimRunInfo);
}
}
}

//判断哪个最长

// 获取所有 metrics 的 key 的并集
Set<String> metricsKeys = (Set<String>) aimRunInfoList.stream()
.map(InsMetricInfoVo::getMetrics)
.flatMap(metrics -> metrics.keySet().stream())
.collect(Collectors.toSet());
// 将并集赋值给每个 InsMetricInfoVo 的 metricsNames 属性
aimRunInfoList.forEach(vo -> vo.setMetricsNames(new ArrayList<>(metricsKeys)));

// 获取所有 params 的 key 的并集
Set<String> paramKeys = (Set<String>) aimRunInfoList.stream()
.map(InsMetricInfoVo::getParams)
.flatMap(params -> params.keySet().stream())
.collect(Collectors.toSet());
// 将并集赋值给每个 InsMetricInfoVo 的 paramsNames 属性
aimRunInfoList.forEach(vo -> vo.setParamsNames(new ArrayList<>(paramKeys)));

return aimRunInfoList;
}


private List<String> getTrainDateSet(List<Map<String, Object>> records, String aimrunId) { private List<String> getTrainDateSet(List<Map<String, Object>> records, String aimrunId) {
List<String> datasetList = new ArrayList<>(); List<String> datasetList = new ArrayList<>();
for (Map<String, Object> record : records) { for (Map<String, Object> record : records) {


+ 2
- 2
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelsServiceImpl.java View File

@@ -1165,7 +1165,7 @@ public class ModelsServiceImpl implements ModelsService {
for (int i = 0; i < trainMetrics.size(); i++) { for (int i = 0; i < trainMetrics.size(); i++) {
JSONObject jsonObject = trainMetrics.getJSONObject(i); JSONObject jsonObject = trainMetrics.getJSONObject(i);
String runId = jsonObject.getString("run_id"); String runId = jsonObject.getString("run_id");
List<InsMetricInfoVo> expTrainInfos = aimsService.getExpTrainInfos(modelMetaVo.getTrainTask().getExperimentId(), runId);
List<InsMetricInfoVo> expTrainInfos = aimsService.getExpTrainInfos1(true, modelMetaVo.getTrainTask().getExperimentId(), runId);
for (InsMetricInfoVo expTrainInfo : expTrainInfos) { for (InsMetricInfoVo expTrainInfo : expTrainInfos) {
Map metrics1 = expTrainInfo.getMetrics(); Map metrics1 = expTrainInfo.getMetrics();
train.putAll(metrics1); train.putAll(metrics1);
@@ -1177,7 +1177,7 @@ public class ModelsServiceImpl implements ModelsService {
for (int i = 0; i < testMetrics.size(); i++) { for (int i = 0; i < testMetrics.size(); i++) {
JSONObject jsonObject = testMetrics.getJSONObject(i); JSONObject jsonObject = testMetrics.getJSONObject(i);
String runId = jsonObject.getString("run_id"); String runId = jsonObject.getString("run_id");
List<InsMetricInfoVo> expTestInfos = aimsService.getExpEvaluateInfos(modelMetaVo.getTrainTask().getExperimentId(), runId);
List<InsMetricInfoVo> expTestInfos = aimsService.getExpTrainInfos1(false, modelMetaVo.getTrainTask().getExperimentId(), runId);
for (InsMetricInfoVo expTestInfo : expTestInfos) { for (InsMetricInfoVo expTestInfo : expTestInfos) {
Map metrics1 = expTestInfo.getMetrics(); Map metrics1 = expTestInfo.getMetrics();
evaluate.putAll(metrics1); evaluate.putAll(metrics1);


Loading…
Cancel
Save