| @@ -205,6 +205,17 @@ | |||
| <groupId>org.springframework.boot</groupId> | |||
| <artifactId>spring-boot-starter-websocket</artifactId> | |||
| </dependency> | |||
| <dependency> | |||
| <groupId>org.json</groupId> | |||
| <artifactId>json</artifactId> | |||
| <version>20210307</version> | |||
| </dependency> | |||
| <dependency> | |||
| <groupId>org.apache.dubbo</groupId> | |||
| <artifactId>dubbo</artifactId> | |||
| <version>3.0.8</version> | |||
| <scope>compile</scope> | |||
| </dependency> | |||
| </dependencies> | |||
| @@ -4,16 +4,16 @@ import com.ruoyi.common.core.web.controller.BaseController; | |||
| import com.ruoyi.common.core.web.domain.GenericsAjaxResult; | |||
| import com.ruoyi.platform.service.AimService; | |||
| import com.ruoyi.platform.vo.FrameLogPathVo; | |||
| import com.ruoyi.platform.vo.InsMetricInfoVo; | |||
| import com.ruoyi.platform.vo.PodStatusVo; | |||
| import io.swagger.annotations.Api; | |||
| import io.swagger.annotations.ApiOperation; | |||
| import io.swagger.v3.oas.annotations.responses.ApiResponse; | |||
| import org.springframework.web.bind.annotation.PostMapping; | |||
| import org.springframework.web.bind.annotation.RequestBody; | |||
| import org.springframework.web.bind.annotation.RequestMapping; | |||
| import org.springframework.web.bind.annotation.RestController; | |||
| import org.springframework.web.bind.annotation.*; | |||
| import javax.annotation.Resource; | |||
| import java.util.List; | |||
| @RestController | |||
| @RequestMapping("aim") | |||
| @Api("Aim管理") | |||
| @@ -22,17 +22,25 @@ public class AimController extends BaseController { | |||
| @Resource | |||
| private AimService aimService; | |||
| /** | |||
| * 启动tensorBoard接口 | |||
| * | |||
| * @param frameLogPathVo 存储路径 | |||
| * @return url | |||
| */ | |||
| @PostMapping("/run") | |||
| @ApiOperation("启动aim`") | |||
| @GetMapping("/getExpTrainInfos/{experiment_id}") | |||
| @ApiOperation("获取当前实验的模型训练指标信息") | |||
| @ApiResponse | |||
| public GenericsAjaxResult<String> runAim(@RequestBody FrameLogPathVo frameLogPathVo) throws Exception { | |||
| return genericsSuccess(aimService.runAim(frameLogPathVo)); | |||
| public GenericsAjaxResult<List<InsMetricInfoVo>> getExpTrainInfos(@PathVariable("experiment_id") Integer experimentId) throws Exception { | |||
| return genericsSuccess(aimService.getExpTrainInfos(experimentId)); | |||
| } | |||
| @GetMapping("/getExpEvaluateInfos/{experiment_id}") | |||
| @ApiOperation("获取当前实验的模型推理指标信息") | |||
| @ApiResponse | |||
| public GenericsAjaxResult<List<InsMetricInfoVo>> getExpEvaluateInfos(@PathVariable("experiment_id") Integer experimentId) throws Exception { | |||
| return genericsSuccess(aimService.getExpEvaluateInfos(experimentId)); | |||
| } | |||
| @PostMapping("/getExpMetrics") | |||
| @ApiOperation("获取当前实验的指标对比地址") | |||
| @ApiResponse | |||
| public GenericsAjaxResult<String> getExpMetrics(@RequestBody List<String> runIds) throws Exception { | |||
| return genericsSuccess(aimService.getExpMetrics(runIds)); | |||
| } | |||
| } | |||
| @@ -1,7 +1,14 @@ | |||
| package com.ruoyi.platform.service; | |||
| import com.ruoyi.platform.vo.FrameLogPathVo; | |||
| import com.ruoyi.platform.vo.InsMetricInfoVo; | |||
| import java.util.List; | |||
| public interface AimService { | |||
| String runAim(FrameLogPathVo frameLogPathVo); | |||
| List<InsMetricInfoVo> getExpTrainInfos(Integer experimentId) throws Exception; | |||
| List<InsMetricInfoVo> getExpEvaluateInfos(Integer experimentId) throws Exception; | |||
| String getExpMetrics(List<String> runIds) throws Exception; | |||
| } | |||
| @@ -1,13 +1,123 @@ | |||
| package com.ruoyi.platform.service.impl; | |||
| import com.alibaba.druid.util.StringUtils; | |||
| import com.ruoyi.platform.domain.ExperimentIns; | |||
| import com.ruoyi.platform.service.AimService; | |||
| import com.ruoyi.platform.vo.FrameLogPathVo; | |||
| import com.ruoyi.platform.service.ExperimentInsService; | |||
| import com.ruoyi.platform.service.ExperimentService; | |||
| import com.ruoyi.platform.utils.AIM64EncoderUtil; | |||
| import com.ruoyi.platform.utils.HttpUtils; | |||
| import com.ruoyi.platform.utils.JacksonUtil; | |||
| import com.ruoyi.platform.utils.JsonUtils; | |||
| import com.ruoyi.platform.vo.InsMetricInfoVo; | |||
| import org.apache.dubbo.container.Main; | |||
| import org.json.JSONObject; | |||
| import org.json.JSONTokener; | |||
| import org.springframework.stereotype.Service; | |||
| import javax.annotation.Resource; | |||
| import java.net.URLEncoder; | |||
| import java.util.*; | |||
| import java.util.stream.Collectors; | |||
| @Service | |||
| public class AimServiceImpl implements AimService { | |||
| @Resource | |||
| private ExperimentInsService experimentInsService; | |||
| @Override | |||
| public List<InsMetricInfoVo> getExpTrainInfos(Integer experimentId) throws Exception { | |||
| String experimentName = "experiment-train-0"+experimentId; | |||
| return getAimRunInfos("",experimentId); | |||
| } | |||
| @Override | |||
| public String runAim(FrameLogPathVo frameLogPathVo) { | |||
| return null; | |||
| public List<InsMetricInfoVo> getExpEvaluateInfos(Integer experimentId) throws Exception { | |||
| String experimentName = "experiment-evaluate-0"+experimentId; | |||
| return getAimRunInfos("",experimentId); | |||
| } | |||
| @Override | |||
| public String getExpMetrics(List<String> runIds) throws Exception { | |||
| String decode = AIM64EncoderUtil.decode(runIds); | |||
| return "http://172.20.32.21:7123/api/runs/search/run?query="+decode; | |||
| } | |||
| private List<InsMetricInfoVo> getAimRunInfos(String experimentName,Integer experimentId) throws Exception { | |||
| String encodedUrlString = URLEncoder.encode("run.experiment==\"experiment-0000\"", "UTF-8"); | |||
| String url = "http://172.20.32.181:30123/api/runs/search/run?query="+encodedUrlString; | |||
| String s = HttpUtils.sendGetRequest(url); | |||
| System.out.println(s); | |||
| List<Map<String, Object>> response = JacksonUtil.parseJSONStr2MapList(s); | |||
| // TODO: parse aim response to InsMetricInfoVo list | |||
| if (response == null || response.size() == 0){ | |||
| return new ArrayList<>(); | |||
| } | |||
| //查询实例数据 | |||
| List<ExperimentIns> byExperimentId = experimentInsService.getByExperimentId(experimentId); | |||
| // if (byExperimentId == null || byExperimentId.size() == 0){ | |||
| // return new ArrayList<>(); | |||
| // } | |||
| List<InsMetricInfoVo> aimRunInfoList = new ArrayList<>(); | |||
| for (Map<String, Object> run : response) { | |||
| InsMetricInfoVo aimRunInfo = new InsMetricInfoVo(); | |||
| String runHash = (String) run.get("run_hash"); | |||
| aimRunInfo.setRunId(runHash); | |||
| Map params= (Map) run.get("params"); | |||
| Map<String, Object> paramMap = JsonUtils.flattenJson("", params); | |||
| aimRunInfo.setParams(paramMap); | |||
| Map<String, Object> tracesMap= (Map<String, Object>) run.get("params"); | |||
| List<Map<String, Object>> metricList = (List<Map<String, Object>>) tracesMap.get("metric"); | |||
| //过滤name为__system__开头的对象 | |||
| aimRunInfo.setMetrics(new HashMap<>()); | |||
| if (metricList != null && metricList.size() > 0){ | |||
| List<Map<String, Object>> metricRelList = metricList.stream() | |||
| .filter(map -> !StringUtils.equals("__system__", (String) map.get("name"))) | |||
| .collect(Collectors.toList()); | |||
| if (metricRelList!= null && metricRelList.size() > 0){ | |||
| Map<String, Object> relMetricMap = new HashMap<>(); | |||
| for (Map<String, Object> metricMap : metricRelList) { | |||
| relMetricMap.put((String)metricMap.get("name"), metricMap.get("last_value")); | |||
| } | |||
| aimRunInfo.setMetrics(relMetricMap); | |||
| } | |||
| } | |||
| //找到ins | |||
| for (ExperimentIns ins : byExperimentId) { | |||
| String metricRecord = ins.getMetricRecord(); | |||
| if (metricRecord.contains(runHash)){ | |||
| aimRunInfo.setExperimentInsId(ins.getId()); | |||
| aimRunInfo.setStatus(ins.getStatus()); | |||
| aimRunInfo.setStartTime(ins.getStartTime()); | |||
| } | |||
| } | |||
| aimRunInfoList.add(aimRunInfo); | |||
| } | |||
| //判断哪个最长 | |||
| Optional<InsMetricInfoVo> maxMetricsVo = aimRunInfoList.stream() | |||
| .max((vo1, vo2) -> Integer.compare(vo1.getMetrics().size(), vo2.getMetrics().size())); | |||
| // 如果找到了,设置 metricsFlag 为 true | |||
| if (maxMetricsVo.isPresent()) { | |||
| maxMetricsVo.get().setMetricsFlag(true); | |||
| } | |||
| Optional<InsMetricInfoVo> maxParamsVo = aimRunInfoList.stream() | |||
| .max((vo1, vo2) -> Integer.compare(vo1.getParams().size(), vo2.getParams().size())); | |||
| // 如果找到了,设置 metricsFlag 为 true | |||
| if (maxParamsVo.isPresent()) { | |||
| maxParamsVo.get().setMetricsFlag(true); | |||
| } | |||
| return aimRunInfoList; | |||
| } | |||
| } | |||
| @@ -254,9 +254,13 @@ public class ExperimentServiceImpl implements ExperimentService { | |||
| //获取训练参数 | |||
| Map<String, Object> metricRecord = (Map<String, Object>) runResMap.get("metric_record"); | |||
| Map<String, Object> metadata = (Map<String, Object>) data.get("metadata"); | |||
| // 插入记录到实验实例表 | |||
| ExperimentIns experimentIns = new ExperimentIns(); | |||
| if (metricRecord != null){ | |||
| experimentIns.setMetricRecord(JacksonUtil.toJSONString(metricRecord)); | |||
| } | |||
| experimentIns.setExperimentId(experiment.getId()); | |||
| experimentIns.setArgoInsNs((String) metadata.get("namespace")); | |||
| experimentIns.setArgoInsName((String) metadata.get("name")); | |||
| @@ -275,8 +279,9 @@ public class ExperimentServiceImpl implements ExperimentService { | |||
| Map<String, Object> converMap2 = JsonUtils.jsonToMap(JacksonUtil.replaceInAarry(convertRes, params)); | |||
| Map<String ,Object> dependendcy = (Map<String, Object>)converMap2.get("model_dependency"); | |||
| Map<String ,Object> trainInfo = (Map<String, Object>)converMap2.get("component_info"); | |||
| insertModelDependency(dependendcy,trainInfo,insert.getId(),experiment.getName()); | |||
| if (dependendcy != null && trainInfo != null){ | |||
| insertModelDependency(dependendcy,trainInfo,insert.getId(),experiment.getName()); | |||
| } | |||
| }catch (Exception e){ | |||
| throw new RuntimeException(e); | |||
| } | |||
| @@ -0,0 +1,76 @@ | |||
| package com.ruoyi.platform.utils; | |||
| import com.alibaba.fastjson.JSON; | |||
| import java.util.Base64; | |||
| import java.util.HashMap; | |||
| import java.util.List; | |||
| import java.util.Map; | |||
| public class AIM64EncoderUtil { | |||
| private static final String AIM64_ENCODING_PREFIX = "O-"; | |||
| private static final Map<String, String> BS64_REPLACE_CHARACTERS_ENCODING = new HashMap<>(); | |||
| static { | |||
| BS64_REPLACE_CHARACTERS_ENCODING.put("=", ""); | |||
| BS64_REPLACE_CHARACTERS_ENCODING.put("+", "-"); | |||
| BS64_REPLACE_CHARACTERS_ENCODING.put("/", "_"); | |||
| } | |||
| public static String aim64encode(Map<String, Object> value) { | |||
| String jsonEncoded = JSON.toJSONString(value); | |||
| String base64Encoded = Base64.getEncoder().encodeToString(jsonEncoded.getBytes()); | |||
| String aim64Encoded = base64Encoded; | |||
| for (Map.Entry<String, String> entry : BS64_REPLACE_CHARACTERS_ENCODING.entrySet()) { | |||
| aim64Encoded = aim64Encoded.replace(entry.getKey(), entry.getValue()); | |||
| } | |||
| return AIM64_ENCODING_PREFIX + aim64Encoded; | |||
| } | |||
| public static String encode(Map<String, Object> value, boolean oneWayHashing) { | |||
| if (oneWayHashing) { | |||
| return md5(JSON.toJSONString(value)); | |||
| } | |||
| return aim64encode(value); | |||
| } | |||
| private static String md5(String input) { | |||
| try { | |||
| java.security.MessageDigest md = java.security.MessageDigest.getInstance("MD5"); | |||
| byte[] array = md.digest(input.getBytes()); | |||
| StringBuilder sb = new StringBuilder(); | |||
| for (byte b : array) { | |||
| sb.append(Integer.toHexString((b & 0xFF) | 0x100).substring(1, 3)); | |||
| } | |||
| return sb.toString(); | |||
| } catch (java.security.NoSuchAlgorithmException e) { | |||
| e.printStackTrace(); | |||
| } | |||
| return null; | |||
| } | |||
| public static String decode(List<String> runIds) { | |||
| // 确保 runIds 列表的大小为 3 | |||
| if (runIds == null || runIds.size() == 0) { | |||
| throw new IllegalArgumentException("runIds 不能为空"); | |||
| } | |||
| // 构建查询字符串 | |||
| StringBuilder queryBuilder = new StringBuilder("run.hash in ["); | |||
| for (int i = 0; i < runIds.size(); i++) { | |||
| if (i > 0) { | |||
| queryBuilder.append(","); | |||
| } | |||
| queryBuilder.append("\"").append(runIds.get(i)).append("\""); | |||
| } | |||
| queryBuilder.append("]"); | |||
| String query = queryBuilder.toString(); | |||
| Map<String, Object> map = new HashMap<>(); | |||
| map.put("query", query); | |||
| map.put("advancedMode", true); | |||
| map.put("advancedQuery", query); | |||
| String searchQuery = encode(map, false); | |||
| return searchQuery; | |||
| } | |||
| } | |||
| @@ -25,6 +25,7 @@ import java.security.cert.X509Certificate; | |||
| import java.util.HashMap; | |||
| import java.util.List; | |||
| import java.util.Map; | |||
| import java.util.zip.GZIPInputStream; | |||
| /** | |||
| * HTTP请求工具类 | |||
| @@ -447,4 +448,38 @@ public class HttpUtils { | |||
| return true; | |||
| } | |||
| } | |||
| public static String sendGetRequestgzip(String url) throws Exception { | |||
| String resultStr = null; | |||
| HttpGet httpGet = new HttpGet(url); | |||
| httpGet.setHeader("Content-Type", "application/json"); | |||
| httpGet.setHeader("Accept-Encoding", "gzip, deflate"); | |||
| try { | |||
| HttpResponse response = httpClient.execute(httpGet); | |||
| int responseCode = response.getStatusLine().getStatusCode(); | |||
| if (responseCode != 200) { | |||
| throw new IOException("HTTP request failed with response code: " + responseCode); | |||
| } | |||
| // 获取响应内容 | |||
| InputStream responseStream = response.getEntity().getContent(); | |||
| // 检查响应是否被压缩 | |||
| if ("gzip".equalsIgnoreCase(response.getEntity().getContentEncoding().getValue())) { | |||
| responseStream = new GZIPInputStream(responseStream); | |||
| } | |||
| // 读取解压缩后的内容 | |||
| byte[] buffer = new byte[1024]; | |||
| int len; | |||
| StringBuilder decompressedString = new StringBuilder(); | |||
| while ((len = responseStream.read(buffer)) > 0) { | |||
| decompressedString.append(new String(buffer, 0, len)); | |||
| } | |||
| resultStr = decompressedString.toString(); | |||
| } catch (IOException e) { | |||
| e.printStackTrace(); | |||
| } | |||
| return resultStr; | |||
| } | |||
| } | |||
| @@ -1,8 +1,11 @@ | |||
| package com.ruoyi.platform.utils; | |||
| import com.fasterxml.jackson.core.JsonProcessingException; | |||
| import com.fasterxml.jackson.databind.ObjectMapper; | |||
| import org.json.JSONObject; | |||
| import java.io.IOException; | |||
| import java.util.HashMap; | |||
| import java.util.Iterator; | |||
| import java.util.Map; | |||
| public class JsonUtils { | |||
| @@ -28,4 +31,26 @@ public class JsonUtils { | |||
| public static <T> T jsonToObject(String json, Class<T> clazz) throws IOException { | |||
| return objectMapper.readValue(json, clazz); | |||
| } | |||
| // 将JSON字符串转换为扁平化的Map | |||
| public static Map<String, Object> flattenJson(String prefix, Map<String, Object> map) { | |||
| Map<String, Object> flatMap = new HashMap<>(); | |||
| Iterator<Map.Entry<String, Object>> entries = map.entrySet().iterator(); | |||
| while (entries.hasNext()) { | |||
| Map.Entry<String, Object> entry = entries.next(); | |||
| String key = entry.getKey(); | |||
| Object value = entry.getValue(); | |||
| if (value instanceof Map) { | |||
| flatMap.putAll(flattenJson(prefix + key + ".", (Map<String, Object>) value)); | |||
| } else { | |||
| flatMap.put(prefix + key, value); | |||
| } | |||
| } | |||
| return flatMap; | |||
| } | |||
| } | |||
| @@ -0,0 +1,32 @@ | |||
| package com.ruoyi.platform.vo; | |||
| import com.fasterxml.jackson.databind.PropertyNamingStrategy; | |||
| import com.fasterxml.jackson.databind.annotation.JsonNaming; | |||
| import io.swagger.annotations.ApiModelProperty; | |||
| import lombok.Data; | |||
| import java.io.Serializable; | |||
| import java.util.Date; | |||
| import java.util.List; | |||
| import java.util.Map; | |||
| @JsonNaming(PropertyNamingStrategy.SnakeCaseStrategy.class) | |||
| @Data | |||
| public class InsMetricInfoVo implements Serializable { | |||
| @ApiModelProperty(value = "开始时间") | |||
| private Date startTime; | |||
| @ApiModelProperty(value = "实例运行状态") | |||
| private String status; | |||
| @ApiModelProperty(value = "使用数据集") | |||
| private List<Map<String, Object>> dataset; | |||
| @ApiModelProperty(value = "实例ID") | |||
| private Integer experimentInsId; | |||
| @ApiModelProperty(value = "训练指标") | |||
| private Map metrics; | |||
| @ApiModelProperty(value = "训练参数") | |||
| private Map params; | |||
| @ApiModelProperty(value = "训练记录ID") | |||
| private String runId; | |||
| private Boolean metricsFlag = false; | |||
| private Boolean paramsFlag = false; | |||
| } | |||