Browse Source

实验模型导出元数据记录指标修改

dev-lhz
chenzhihang 1 year ago
parent
commit
3bc6e7fb51
4 changed files with 55 additions and 43 deletions
  1. +5
    -7
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/aim/AimController.java
  2. +2
    -2
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AimService.java
  3. +32
    -31
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java
  4. +16
    -3
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelsServiceImpl.java

+ 5
- 7
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/aim/AimController.java View File

@@ -3,9 +3,7 @@ package com.ruoyi.platform.controller.aim;
import com.ruoyi.common.core.web.controller.BaseController; import com.ruoyi.common.core.web.controller.BaseController;
import com.ruoyi.common.core.web.domain.GenericsAjaxResult; import com.ruoyi.common.core.web.domain.GenericsAjaxResult;
import com.ruoyi.platform.service.AimService; import com.ruoyi.platform.service.AimService;
import com.ruoyi.platform.vo.FrameLogPathVo;
import com.ruoyi.platform.vo.InsMetricInfoVo; import com.ruoyi.platform.vo.InsMetricInfoVo;
import com.ruoyi.platform.vo.PodStatusVo;
import io.swagger.annotations.Api; import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation; import io.swagger.annotations.ApiOperation;
import io.swagger.v3.oas.annotations.responses.ApiResponse; import io.swagger.v3.oas.annotations.responses.ApiResponse;
@@ -26,21 +24,21 @@ public class AimController extends BaseController {
@GetMapping("/getExpTrainInfos/{experiment_id}") @GetMapping("/getExpTrainInfos/{experiment_id}")
@ApiOperation("获取当前实验的模型训练指标信息") @ApiOperation("获取当前实验的模型训练指标信息")
@ApiResponse @ApiResponse
public GenericsAjaxResult<List<InsMetricInfoVo>> getExpTrainInfos(@PathVariable("experiment_id") Integer experimentId) throws Exception {
return genericsSuccess(aimService.getExpTrainInfos(experimentId));
public GenericsAjaxResult<List<InsMetricInfoVo>> getExpTrainInfos(@PathVariable("experiment_id") Integer experimentId, @RequestParam("run_id") String runId) throws Exception {
return genericsSuccess(aimService.getExpTrainInfos(experimentId, runId));
} }


@GetMapping("/getExpEvaluateInfos/{experiment_id}") @GetMapping("/getExpEvaluateInfos/{experiment_id}")
@ApiOperation("获取当前实验的模型推理指标信息") @ApiOperation("获取当前实验的模型推理指标信息")
@ApiResponse @ApiResponse
public GenericsAjaxResult<List<InsMetricInfoVo>> getExpEvaluateInfos(@PathVariable("experiment_id") Integer experimentId) throws Exception {
return genericsSuccess(aimService.getExpEvaluateInfos(experimentId));
public GenericsAjaxResult<List<InsMetricInfoVo>> getExpEvaluateInfos(@PathVariable("experiment_id") Integer experimentId, @RequestParam("run_id") String runId) throws Exception {
return genericsSuccess(aimService.getExpEvaluateInfos(experimentId, runId));
} }


@PostMapping("/getExpMetrics") @PostMapping("/getExpMetrics")
@ApiOperation("获取当前实验的指标对比地址") @ApiOperation("获取当前实验的指标对比地址")
@ApiResponse @ApiResponse
public GenericsAjaxResult<String> getExpMetrics(@RequestBody List<String> runIds) throws Exception { public GenericsAjaxResult<String> getExpMetrics(@RequestBody List<String> runIds) throws Exception {
return genericsSuccess(aimService.getExpMetrics(runIds));
return genericsSuccess(aimService.getExpMetrics(runIds));
} }
} }

+ 2
- 2
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AimService.java View File

@@ -6,9 +6,9 @@ import java.util.List;


public interface AimService { public interface AimService {


List<InsMetricInfoVo> getExpTrainInfos(Integer experimentId) throws Exception;
List<InsMetricInfoVo> getExpTrainInfos(Integer experimentId, String runId) throws Exception;


List<InsMetricInfoVo> getExpEvaluateInfos(Integer experimentId) throws Exception;
List<InsMetricInfoVo> getExpEvaluateInfos(Integer experimentId, String runId) throws Exception;


String getExpMetrics(List<String> runIds) throws Exception; String getExpMetrics(List<String> runIds) throws Exception;
} }

+ 32
- 31
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java View File

@@ -30,38 +30,39 @@ public class AimServiceImpl implements AimService {
private NewHttpUtils httpUtils; private NewHttpUtils httpUtils;


@Override @Override
public List<InsMetricInfoVo> getExpTrainInfos(Integer experimentId) throws Exception {
return getAimRunInfos(true,experimentId);
public List<InsMetricInfoVo> getExpTrainInfos(Integer experimentId, String runId) throws Exception {
return getAimRunInfos(true, experimentId, runId);
} }


@Override @Override
public List<InsMetricInfoVo> getExpEvaluateInfos(Integer experimentId) throws Exception {
return getAimRunInfos(false,experimentId);
public List<InsMetricInfoVo> getExpEvaluateInfos(Integer experimentId, String runId) throws Exception {
return getAimRunInfos(false, experimentId, runId);
} }


@Override @Override
public String getExpMetrics(List<String> runIds) throws Exception { public String getExpMetrics(List<String> runIds) throws Exception {
String decode = AIM64EncoderUtil.decode(runIds); String decode = AIM64EncoderUtil.decode(runIds);
return aimUrl+"/metrics?select="+decode;
return aimUrl + "/metrics?select=" + decode;
} }


private List<InsMetricInfoVo> getAimRunInfos(boolean isTrain,Integer experimentId) throws Exception {
String experimentName = "experiment-"+experimentId+"-train";
if (!isTrain){
experimentName = "experiment-"+experimentId+"-evaluate";
}
String encodedUrlString = URLEncoder.encode("run.experiment==\""+experimentName+"\"", "UTF-8");
String url = aimProxyUrl+"/api/runs/search/run?query="+encodedUrlString;
String s = httpUtils.sendGet(url,null);
private List<InsMetricInfoVo> getAimRunInfos(boolean isTrain, Integer experimentId, String runId) throws Exception {
// String experimentName = "experiment-" + experimentId + "-train";
// if (!isTrain) {
// experimentName = "experiment-" + experimentId + "-evaluate";
// }
// String encodedUrlString = URLEncoder.encode("run.experiment==\"" + experimentName + "\"", "UTF-8");
String encodedUrlString = URLEncoder.encode("run.id==\"" + runId + "\"", "UTF-8");
String url = aimProxyUrl + "/api/runs/search/run?query=" + encodedUrlString;
String s = httpUtils.sendGet(url, null);
List<Map<String, Object>> response = JacksonUtil.parseJSONStr2MapList(s); List<Map<String, Object>> response = JacksonUtil.parseJSONStr2MapList(s);
System.out.println("response: "+JacksonUtil.toJSONString(response));
if (response == null || response.size() == 0){
System.out.println("response: " + JacksonUtil.toJSONString(response));
if (response == null || response.size() == 0) {
return new ArrayList<>(); return new ArrayList<>();
} }
//查询实例数据 //查询实例数据
List<ExperimentIns> byExperimentId = experimentInsService.queryByExperimentId(experimentId); List<ExperimentIns> byExperimentId = experimentInsService.queryByExperimentId(experimentId);


if (byExperimentId == null || byExperimentId.size() == 0){
if (byExperimentId == null || byExperimentId.size() == 0) {
return new ArrayList<>(); return new ArrayList<>();
} }
List<InsMetricInfoVo> aimRunInfoList = new ArrayList<>(); List<InsMetricInfoVo> aimRunInfoList = new ArrayList<>();
@@ -71,22 +72,22 @@ public class AimServiceImpl implements AimService {


aimRunInfo.setRunId(runHash); aimRunInfo.setRunId(runHash);


Map params= (Map) run.get("params");
Map params = (Map) run.get("params");
Map<String, Object> paramMap = JsonUtils.flattenJson("", params); Map<String, Object> paramMap = JsonUtils.flattenJson("", params);
aimRunInfo.setParams(paramMap); aimRunInfo.setParams(paramMap);
String aimrunId = (String) paramMap.get("id"); String aimrunId = (String) paramMap.get("id");
Map<String, Object> tracesMap= (Map<String, Object>) run.get("traces");
Map<String, Object> tracesMap = (Map<String, Object>) run.get("traces");
List<Map<String, Object>> metricList = (List<Map<String, Object>>) tracesMap.get("metric"); List<Map<String, Object>> metricList = (List<Map<String, Object>>) tracesMap.get("metric");
//过滤name为__system__开头的对象 //过滤name为__system__开头的对象
aimRunInfo.setMetrics(new HashMap<>()); aimRunInfo.setMetrics(new HashMap<>());
if (metricList != null && metricList.size() > 0){
if (metricList != null && metricList.size() > 0) {
List<Map<String, Object>> metricRelList = metricList.stream() List<Map<String, Object>> metricRelList = metricList.stream()
.filter(map -> !StringUtils.startsWith((String) map.get("name"),"__system__" ))
.filter(map -> !StringUtils.startsWith((String) map.get("name"), "__system__"))
.collect(Collectors.toList()); .collect(Collectors.toList());
if (metricRelList!= null && metricRelList.size() > 0){
if (metricRelList != null && metricRelList.size() > 0) {
Map<String, Object> relMetricMap = new HashMap<>(); Map<String, Object> relMetricMap = new HashMap<>();
for (Map<String, Object> metricMap : metricRelList) { for (Map<String, Object> metricMap : metricRelList) {
relMetricMap.put((String)metricMap.get("name"), metricMap.get("last_value"));
relMetricMap.put((String) metricMap.get("name"), metricMap.get("last_value"));
} }
aimRunInfo.setMetrics(relMetricMap); aimRunInfo.setMetrics(relMetricMap);
} }
@@ -94,19 +95,19 @@ public class AimServiceImpl implements AimService {
//找到ins //找到ins
for (ExperimentIns ins : byExperimentId) { for (ExperimentIns ins : byExperimentId) {
String metricRecordString = ins.getMetricRecord(); String metricRecordString = ins.getMetricRecord();
if (StringUtils.isEmpty(metricRecordString)){
if (StringUtils.isEmpty(metricRecordString)) {
continue; continue;
} }
if (metricRecordString.contains(aimrunId)){
if (metricRecordString.contains(aimrunId)) {
aimRunInfo.setExperimentInsId(ins.getId()); aimRunInfo.setExperimentInsId(ins.getId());
aimRunInfo.setStatus(ins.getStatus()); aimRunInfo.setStatus(ins.getStatus());
aimRunInfo.setStartTime(ins.getCreateTime()); aimRunInfo.setStartTime(ins.getCreateTime());
Map<String, Object> metricRecordMap = JacksonUtil.parseJSONStr2Map(metricRecordString); Map<String, Object> metricRecordMap = JacksonUtil.parseJSONStr2Map(metricRecordString);
if (isTrain){
if (isTrain) {
List<Map<String, Object>> records = (List<Map<String, Object>>) metricRecordMap.get("train"); List<Map<String, Object>> records = (List<Map<String, Object>>) metricRecordMap.get("train");
List<String> datasetList = getTrainDateSet(records, aimrunId); List<String> datasetList = getTrainDateSet(records, aimrunId);
aimRunInfo.setDataset(datasetList); aimRunInfo.setDataset(datasetList);
}else {
} else {
List<Map<String, Object>> records = (List<Map<String, Object>>) metricRecordMap.get("evaluate"); List<Map<String, Object>> records = (List<Map<String, Object>>) metricRecordMap.get("evaluate");
List<String> datasetList = getTrainDateSet(records, aimrunId); List<String> datasetList = getTrainDateSet(records, aimrunId);
aimRunInfo.setDataset(datasetList); aimRunInfo.setDataset(datasetList);
@@ -138,16 +139,16 @@ public class AimServiceImpl implements AimService {
} }




private List<String> getTrainDateSet(List<Map<String, Object>> records, String aimrunId){
private List<String> getTrainDateSet(List<Map<String, Object>> records, String aimrunId) {
List<String> datasetList = new ArrayList<>(); List<String> datasetList = new ArrayList<>();
for (Map<String, Object> record : records) { for (Map<String, Object> record : records) {
if (StringUtils.equals(aimrunId, (String)record.get("run_id"))) {
if (StringUtils.equals(aimrunId, (String) record.get("run_id"))) {
List<Map<String, Object>> datasets = (List<Map<String, Object>>) record.get("datasets"); List<Map<String, Object>> datasets = (List<Map<String, Object>>) record.get("datasets");
if (datasets == null || datasets.size() == 0){
if (datasets == null || datasets.size() == 0) {
continue; continue;
} }
for (Map<String, Object> dataset : datasets){
String datasetName = (String) dataset.get("dataset_name")+":"+(String) dataset.get("dataset_version");
for (Map<String, Object> dataset : datasets) {
String datasetName = (String) dataset.get("dataset_name") + ":" + (String) dataset.get("dataset_version");
datasetList.add(datasetName); datasetList.add(datasetName);
} }
break; break;


+ 16
- 3
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelsServiceImpl.java View File

@@ -1,6 +1,8 @@
package com.ruoyi.platform.service.impl; package com.ruoyi.platform.service.impl;


import com.alibaba.fastjson2.JSON; import com.alibaba.fastjson2.JSON;
import com.alibaba.fastjson2.JSONArray;
import com.alibaba.fastjson2.JSONObject;
import com.ruoyi.common.core.utils.DateUtils; import com.ruoyi.common.core.utils.DateUtils;
import com.ruoyi.common.security.utils.SecurityUtils; import com.ruoyi.common.security.utils.SecurityUtils;
import com.ruoyi.platform.annotations.CheckDuplicate; import com.ruoyi.platform.annotations.CheckDuplicate;
@@ -1134,9 +1136,20 @@ public class ModelsServiceImpl implements ModelsService {
} }


void getMetrics(ModelMetaVo modelMetaVo) throws Exception { void getMetrics(ModelMetaVo modelMetaVo) throws Exception {
List<InsMetricInfoVo> expTrainInfos = aimsService.getExpTrainInfos(modelMetaVo.getTrainTask().getExperimentId());
for (InsMetricInfoVo expTrainInfo : expTrainInfos) {
System.out.println(expTrainInfo.getMetrics());
HashMap<String, Object> metrics = modelMetaVo.getMetrics();
JSONArray trainMetrics = (JSONArray) metrics.get("train");
for (int i = 0; i < trainMetrics.size(); i++) {
JSONObject jsonObject = trainMetrics.getJSONObject(i);
String runId = jsonObject.getString("run_id");
List<InsMetricInfoVo> expTrainInfos = aimsService.getExpTrainInfos(modelMetaVo.getTrainTask().getExperimentId(), runId);
System.out.print(expTrainInfos);
}

JSONArray testMetrics = (JSONArray) metrics.get("test");
for (int i = 0; i < testMetrics.size(); i++) {
JSONObject jsonObject = testMetrics.getJSONObject(i);
String runId = jsonObject.getString("run_id");
List<InsMetricInfoVo> expTestInfos = aimsService.getExpEvaluateInfos(modelMetaVo.getTrainTask().getExperimentId(), runId);
} }
} }
} }

Loading…
Cancel
Save