Browse Source

新增指标对比

pull/95/head
fanshuai 1 year ago
parent
commit
01e1f7e951
9 changed files with 330 additions and 21 deletions
  1. +11
    -0
      ruoyi-modules/management-platform/pom.xml
  2. +22
    -14
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/aim/AimController.java
  3. +9
    -2
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AimService.java
  4. +113
    -3
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java
  5. +7
    -2
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentServiceImpl.java
  6. +76
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/AIM64EncoderUtil.java
  7. +35
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/HttpUtils.java
  8. +25
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/JsonUtils.java
  9. +32
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/InsMetricInfoVo.java

+ 11
- 0
ruoyi-modules/management-platform/pom.xml View File

@@ -205,6 +205,17 @@
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-websocket</artifactId>
</dependency>
<dependency>
<groupId>org.json</groupId>
<artifactId>json</artifactId>
<version>20210307</version>
</dependency>
<dependency>
<groupId>org.apache.dubbo</groupId>
<artifactId>dubbo</artifactId>
<version>3.0.8</version>
<scope>compile</scope>
</dependency>


</dependencies>


+ 22
- 14
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/aim/AimController.java View File

@@ -4,16 +4,16 @@ import com.ruoyi.common.core.web.controller.BaseController;
import com.ruoyi.common.core.web.domain.GenericsAjaxResult;
import com.ruoyi.platform.service.AimService;
import com.ruoyi.platform.vo.FrameLogPathVo;
import com.ruoyi.platform.vo.InsMetricInfoVo;
import com.ruoyi.platform.vo.PodStatusVo;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.bind.annotation.*;

import javax.annotation.Resource;
import java.util.List;

@RestController
@RequestMapping("aim")
@Api("Aim管理")
@@ -22,17 +22,25 @@ public class AimController extends BaseController {
@Resource
private AimService aimService;

/**
* 启动tensorBoard接口
*
* @param frameLogPathVo 存储路径
* @return url
*/
@PostMapping("/run")
@ApiOperation("启动aim`")

@GetMapping("/getExpTrainInfos/{experiment_id}")
@ApiOperation("获取当前实验的模型训练指标信息")
@ApiResponse
public GenericsAjaxResult<String> runAim(@RequestBody FrameLogPathVo frameLogPathVo) throws Exception {
return genericsSuccess(aimService.runAim(frameLogPathVo));
public GenericsAjaxResult<List<InsMetricInfoVo>> getExpTrainInfos(@PathVariable("experiment_id") Integer experimentId) throws Exception {
return genericsSuccess(aimService.getExpTrainInfos(experimentId));
}

@GetMapping("/getExpEvaluateInfos/{experiment_id}")
@ApiOperation("获取当前实验的模型推理指标信息")
@ApiResponse
public GenericsAjaxResult<List<InsMetricInfoVo>> getExpEvaluateInfos(@PathVariable("experiment_id") Integer experimentId) throws Exception {
return genericsSuccess(aimService.getExpEvaluateInfos(experimentId));
}

@PostMapping("/getExpMetrics")
@ApiOperation("获取当前实验的指标对比地址")
@ApiResponse
public GenericsAjaxResult<String> getExpMetrics(@RequestBody List<String> runIds) throws Exception {
return genericsSuccess(aimService.getExpMetrics(runIds));
}
}

+ 9
- 2
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AimService.java View File

@@ -1,7 +1,14 @@
package com.ruoyi.platform.service;

import com.ruoyi.platform.vo.FrameLogPathVo;
import com.ruoyi.platform.vo.InsMetricInfoVo;

import java.util.List;

public interface AimService {
String runAim(FrameLogPathVo frameLogPathVo);

List<InsMetricInfoVo> getExpTrainInfos(Integer experimentId) throws Exception;

List<InsMetricInfoVo> getExpEvaluateInfos(Integer experimentId) throws Exception;

String getExpMetrics(List<String> runIds) throws Exception;
}

+ 113
- 3
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java View File

@@ -1,13 +1,123 @@
package com.ruoyi.platform.service.impl;

import com.alibaba.druid.util.StringUtils;
import com.ruoyi.platform.domain.ExperimentIns;
import com.ruoyi.platform.service.AimService;
import com.ruoyi.platform.vo.FrameLogPathVo;
import com.ruoyi.platform.service.ExperimentInsService;
import com.ruoyi.platform.service.ExperimentService;
import com.ruoyi.platform.utils.AIM64EncoderUtil;
import com.ruoyi.platform.utils.HttpUtils;
import com.ruoyi.platform.utils.JacksonUtil;
import com.ruoyi.platform.utils.JsonUtils;
import com.ruoyi.platform.vo.InsMetricInfoVo;
import org.apache.dubbo.container.Main;
import org.json.JSONObject;
import org.json.JSONTokener;
import org.springframework.stereotype.Service;

import javax.annotation.Resource;
import java.net.URLEncoder;
import java.util.*;
import java.util.stream.Collectors;

@Service
public class AimServiceImpl implements AimService {
@Resource
private ExperimentInsService experimentInsService;

@Override
public List<InsMetricInfoVo> getExpTrainInfos(Integer experimentId) throws Exception {
String experimentName = "experiment-train-0"+experimentId;
return getAimRunInfos("",experimentId);
}

@Override
public String runAim(FrameLogPathVo frameLogPathVo) {
return null;
public List<InsMetricInfoVo> getExpEvaluateInfos(Integer experimentId) throws Exception {
String experimentName = "experiment-evaluate-0"+experimentId;
return getAimRunInfos("",experimentId);
}

@Override
public String getExpMetrics(List<String> runIds) throws Exception {
String decode = AIM64EncoderUtil.decode(runIds);
return "http://172.20.32.21:7123/api/runs/search/run?query="+decode;
}

private List<InsMetricInfoVo> getAimRunInfos(String experimentName,Integer experimentId) throws Exception {
String encodedUrlString = URLEncoder.encode("run.experiment==\"experiment-0000\"", "UTF-8");
String url = "http://172.20.32.181:30123/api/runs/search/run?query="+encodedUrlString;
String s = HttpUtils.sendGetRequest(url);
System.out.println(s);
List<Map<String, Object>> response = JacksonUtil.parseJSONStr2MapList(s);
// TODO: parse aim response to InsMetricInfoVo list
if (response == null || response.size() == 0){
return new ArrayList<>();
}
//查询实例数据
List<ExperimentIns> byExperimentId = experimentInsService.getByExperimentId(experimentId);

// if (byExperimentId == null || byExperimentId.size() == 0){
// return new ArrayList<>();
// }
List<InsMetricInfoVo> aimRunInfoList = new ArrayList<>();
for (Map<String, Object> run : response) {
InsMetricInfoVo aimRunInfo = new InsMetricInfoVo();
String runHash = (String) run.get("run_hash");
aimRunInfo.setRunId(runHash);

Map params= (Map) run.get("params");
Map<String, Object> paramMap = JsonUtils.flattenJson("", params);
aimRunInfo.setParams(paramMap);

Map<String, Object> tracesMap= (Map<String, Object>) run.get("params");
List<Map<String, Object>> metricList = (List<Map<String, Object>>) tracesMap.get("metric");
//过滤name为__system__开头的对象
aimRunInfo.setMetrics(new HashMap<>());
if (metricList != null && metricList.size() > 0){
List<Map<String, Object>> metricRelList = metricList.stream()
.filter(map -> !StringUtils.equals("__system__", (String) map.get("name")))
.collect(Collectors.toList());
if (metricRelList!= null && metricRelList.size() > 0){
Map<String, Object> relMetricMap = new HashMap<>();
for (Map<String, Object> metricMap : metricRelList) {
relMetricMap.put((String)metricMap.get("name"), metricMap.get("last_value"));
}
aimRunInfo.setMetrics(relMetricMap);
}
}



//找到ins

for (ExperimentIns ins : byExperimentId) {
String metricRecord = ins.getMetricRecord();
if (metricRecord.contains(runHash)){
aimRunInfo.setExperimentInsId(ins.getId());
aimRunInfo.setStatus(ins.getStatus());
aimRunInfo.setStartTime(ins.getStartTime());
}
}
aimRunInfoList.add(aimRunInfo);
}
//判断哪个最长

Optional<InsMetricInfoVo> maxMetricsVo = aimRunInfoList.stream()
.max((vo1, vo2) -> Integer.compare(vo1.getMetrics().size(), vo2.getMetrics().size()));

// 如果找到了,设置 metricsFlag 为 true
if (maxMetricsVo.isPresent()) {
maxMetricsVo.get().setMetricsFlag(true);
}
Optional<InsMetricInfoVo> maxParamsVo = aimRunInfoList.stream()
.max((vo1, vo2) -> Integer.compare(vo1.getParams().size(), vo2.getParams().size()));

// 如果找到了,设置 metricsFlag 为 true
if (maxParamsVo.isPresent()) {
maxParamsVo.get().setMetricsFlag(true);
}

return aimRunInfoList;
}

}

+ 7
- 2
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentServiceImpl.java View File

@@ -254,9 +254,13 @@ public class ExperimentServiceImpl implements ExperimentService {
//获取训练参数
Map<String, Object> metricRecord = (Map<String, Object>) runResMap.get("metric_record");


Map<String, Object> metadata = (Map<String, Object>) data.get("metadata");
// 插入记录到实验实例表
ExperimentIns experimentIns = new ExperimentIns();
if (metricRecord != null){
experimentIns.setMetricRecord(JacksonUtil.toJSONString(metricRecord));
}
experimentIns.setExperimentId(experiment.getId());
experimentIns.setArgoInsNs((String) metadata.get("namespace"));
experimentIns.setArgoInsName((String) metadata.get("name"));
@@ -275,8 +279,9 @@ public class ExperimentServiceImpl implements ExperimentService {
Map<String, Object> converMap2 = JsonUtils.jsonToMap(JacksonUtil.replaceInAarry(convertRes, params));
Map<String ,Object> dependendcy = (Map<String, Object>)converMap2.get("model_dependency");
Map<String ,Object> trainInfo = (Map<String, Object>)converMap2.get("component_info");
insertModelDependency(dependendcy,trainInfo,insert.getId(),experiment.getName());

if (dependendcy != null && trainInfo != null){
insertModelDependency(dependendcy,trainInfo,insert.getId(),experiment.getName());
}
}catch (Exception e){
throw new RuntimeException(e);
}


+ 76
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/AIM64EncoderUtil.java View File

@@ -0,0 +1,76 @@
package com.ruoyi.platform.utils;

import com.alibaba.fastjson.JSON;

import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class AIM64EncoderUtil {

private static final String AIM64_ENCODING_PREFIX = "O-";

private static final Map<String, String> BS64_REPLACE_CHARACTERS_ENCODING = new HashMap<>();
static {
BS64_REPLACE_CHARACTERS_ENCODING.put("=", "");
BS64_REPLACE_CHARACTERS_ENCODING.put("+", "-");
BS64_REPLACE_CHARACTERS_ENCODING.put("/", "_");
}

public static String aim64encode(Map<String, Object> value) {
String jsonEncoded = JSON.toJSONString(value);
String base64Encoded = Base64.getEncoder().encodeToString(jsonEncoded.getBytes());
String aim64Encoded = base64Encoded;
for (Map.Entry<String, String> entry : BS64_REPLACE_CHARACTERS_ENCODING.entrySet()) {
aim64Encoded = aim64Encoded.replace(entry.getKey(), entry.getValue());
}
return AIM64_ENCODING_PREFIX + aim64Encoded;
}

public static String encode(Map<String, Object> value, boolean oneWayHashing) {
if (oneWayHashing) {
return md5(JSON.toJSONString(value));
}
return aim64encode(value);
}

private static String md5(String input) {
try {
java.security.MessageDigest md = java.security.MessageDigest.getInstance("MD5");
byte[] array = md.digest(input.getBytes());
StringBuilder sb = new StringBuilder();
for (byte b : array) {
sb.append(Integer.toHexString((b & 0xFF) | 0x100).substring(1, 3));
}
return sb.toString();
} catch (java.security.NoSuchAlgorithmException e) {
e.printStackTrace();
}
return null;
}

public static String decode(List<String> runIds) {
// 确保 runIds 列表的大小为 3
if (runIds == null || runIds.size() == 0) {
throw new IllegalArgumentException("runIds 不能为空");
}
// 构建查询字符串
StringBuilder queryBuilder = new StringBuilder("run.hash in [");
for (int i = 0; i < runIds.size(); i++) {
if (i > 0) {
queryBuilder.append(",");
}
queryBuilder.append("\"").append(runIds.get(i)).append("\"");
}
queryBuilder.append("]");
String query = queryBuilder.toString();
Map<String, Object> map = new HashMap<>();
map.put("query", query);
map.put("advancedMode", true);
map.put("advancedQuery", query);

String searchQuery = encode(map, false);
return searchQuery;
}
}

+ 35
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/HttpUtils.java View File

@@ -25,6 +25,7 @@ import java.security.cert.X509Certificate;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.zip.GZIPInputStream;

/**
* HTTP请求工具类
@@ -447,4 +448,38 @@ public class HttpUtils {
return true;
}
}

public static String sendGetRequestgzip(String url) throws Exception {
String resultStr = null;
HttpGet httpGet = new HttpGet(url);
httpGet.setHeader("Content-Type", "application/json");
httpGet.setHeader("Accept-Encoding", "gzip, deflate");
try {
HttpResponse response = httpClient.execute(httpGet);
int responseCode = response.getStatusLine().getStatusCode();
if (responseCode != 200) {
throw new IOException("HTTP request failed with response code: " + responseCode);
}

// 获取响应内容
InputStream responseStream = response.getEntity().getContent();
// 检查响应是否被压缩
if ("gzip".equalsIgnoreCase(response.getEntity().getContentEncoding().getValue())) {
responseStream = new GZIPInputStream(responseStream);
}

// 读取解压缩后的内容
byte[] buffer = new byte[1024];
int len;
StringBuilder decompressedString = new StringBuilder();
while ((len = responseStream.read(buffer)) > 0) {
decompressedString.append(new String(buffer, 0, len));
}

resultStr = decompressedString.toString();
} catch (IOException e) {
e.printStackTrace();
}
return resultStr;
}
}

+ 25
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/JsonUtils.java View File

@@ -1,8 +1,11 @@
package com.ruoyi.platform.utils;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.json.JSONObject;

import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;

public class JsonUtils {
@@ -28,4 +31,26 @@ public class JsonUtils {
public static <T> T jsonToObject(String json, Class<T> clazz) throws IOException {
return objectMapper.readValue(json, clazz);
}



// 将JSON字符串转换为扁平化的Map
public static Map<String, Object> flattenJson(String prefix, Map<String, Object> map) {
Map<String, Object> flatMap = new HashMap<>();
Iterator<Map.Entry<String, Object>> entries = map.entrySet().iterator();

while (entries.hasNext()) {
Map.Entry<String, Object> entry = entries.next();
String key = entry.getKey();
Object value = entry.getValue();

if (value instanceof Map) {
flatMap.putAll(flattenJson(prefix + key + ".", (Map<String, Object>) value));
} else {
flatMap.put(prefix + key, value);
}
}

return flatMap;
}
}

+ 32
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/InsMetricInfoVo.java View File

@@ -0,0 +1,32 @@
package com.ruoyi.platform.vo;

import com.fasterxml.jackson.databind.PropertyNamingStrategy;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;

import java.io.Serializable;
import java.util.Date;
import java.util.List;
import java.util.Map;

@JsonNaming(PropertyNamingStrategy.SnakeCaseStrategy.class)
@Data
public class InsMetricInfoVo implements Serializable {
@ApiModelProperty(value = "开始时间")
private Date startTime;
@ApiModelProperty(value = "实例运行状态")
private String status;
@ApiModelProperty(value = "使用数据集")
private List<Map<String, Object>> dataset;
@ApiModelProperty(value = "实例ID")
private Integer experimentInsId;
@ApiModelProperty(value = "训练指标")
private Map metrics;
@ApiModelProperty(value = "训练参数")
private Map params;
@ApiModelProperty(value = "训练记录ID")
private String runId;
private Boolean metricsFlag = false;
private Boolean paramsFlag = false;
}

Loading…
Cancel
Save