| @@ -205,6 +205,17 @@ | |||||
| <groupId>org.springframework.boot</groupId> | <groupId>org.springframework.boot</groupId> | ||||
| <artifactId>spring-boot-starter-websocket</artifactId> | <artifactId>spring-boot-starter-websocket</artifactId> | ||||
| </dependency> | </dependency> | ||||
| <dependency> | |||||
| <groupId>org.json</groupId> | |||||
| <artifactId>json</artifactId> | |||||
| <version>20210307</version> | |||||
| </dependency> | |||||
| <dependency> | |||||
| <groupId>org.apache.dubbo</groupId> | |||||
| <artifactId>dubbo</artifactId> | |||||
| <version>3.0.8</version> | |||||
| <scope>compile</scope> | |||||
| </dependency> | |||||
| </dependencies> | </dependencies> | ||||
| @@ -4,16 +4,16 @@ import com.ruoyi.common.core.web.controller.BaseController; | |||||
| import com.ruoyi.common.core.web.domain.GenericsAjaxResult; | import com.ruoyi.common.core.web.domain.GenericsAjaxResult; | ||||
| import com.ruoyi.platform.service.AimService; | import com.ruoyi.platform.service.AimService; | ||||
| import com.ruoyi.platform.vo.FrameLogPathVo; | import com.ruoyi.platform.vo.FrameLogPathVo; | ||||
| import com.ruoyi.platform.vo.InsMetricInfoVo; | |||||
| import com.ruoyi.platform.vo.PodStatusVo; | import com.ruoyi.platform.vo.PodStatusVo; | ||||
| import io.swagger.annotations.Api; | import io.swagger.annotations.Api; | ||||
| import io.swagger.annotations.ApiOperation; | import io.swagger.annotations.ApiOperation; | ||||
| import io.swagger.v3.oas.annotations.responses.ApiResponse; | import io.swagger.v3.oas.annotations.responses.ApiResponse; | ||||
| import org.springframework.web.bind.annotation.PostMapping; | |||||
| import org.springframework.web.bind.annotation.RequestBody; | |||||
| import org.springframework.web.bind.annotation.RequestMapping; | |||||
| import org.springframework.web.bind.annotation.RestController; | |||||
| import org.springframework.web.bind.annotation.*; | |||||
| import javax.annotation.Resource; | import javax.annotation.Resource; | ||||
| import java.util.List; | |||||
| @RestController | @RestController | ||||
| @RequestMapping("aim") | @RequestMapping("aim") | ||||
| @Api("Aim管理") | @Api("Aim管理") | ||||
| @@ -22,17 +22,25 @@ public class AimController extends BaseController { | |||||
| @Resource | @Resource | ||||
| private AimService aimService; | private AimService aimService; | ||||
| /** | |||||
| * 启动tensorBoard接口 | |||||
| * | |||||
| * @param frameLogPathVo 存储路径 | |||||
| * @return url | |||||
| */ | |||||
| @PostMapping("/run") | |||||
| @ApiOperation("启动aim`") | |||||
| @GetMapping("/getExpTrainInfos/{experiment_id}") | |||||
| @ApiOperation("获取当前实验的模型训练指标信息") | |||||
| @ApiResponse | @ApiResponse | ||||
| public GenericsAjaxResult<String> runAim(@RequestBody FrameLogPathVo frameLogPathVo) throws Exception { | |||||
| return genericsSuccess(aimService.runAim(frameLogPathVo)); | |||||
| public GenericsAjaxResult<List<InsMetricInfoVo>> getExpTrainInfos(@PathVariable("experiment_id") Integer experimentId) throws Exception { | |||||
| return genericsSuccess(aimService.getExpTrainInfos(experimentId)); | |||||
| } | } | ||||
| @GetMapping("/getExpEvaluateInfos/{experiment_id}") | |||||
| @ApiOperation("获取当前实验的模型推理指标信息") | |||||
| @ApiResponse | |||||
| public GenericsAjaxResult<List<InsMetricInfoVo>> getExpEvaluateInfos(@PathVariable("experiment_id") Integer experimentId) throws Exception { | |||||
| return genericsSuccess(aimService.getExpEvaluateInfos(experimentId)); | |||||
| } | |||||
| @PostMapping("/getExpMetrics") | |||||
| @ApiOperation("获取当前实验的指标对比地址") | |||||
| @ApiResponse | |||||
| public GenericsAjaxResult<String> getExpMetrics(@RequestBody List<String> runIds) throws Exception { | |||||
| return genericsSuccess(aimService.getExpMetrics(runIds)); | |||||
| } | |||||
| } | } | ||||
| @@ -1,7 +1,14 @@ | |||||
| package com.ruoyi.platform.service; | package com.ruoyi.platform.service; | ||||
| import com.ruoyi.platform.vo.FrameLogPathVo; | |||||
| import com.ruoyi.platform.vo.InsMetricInfoVo; | |||||
| import java.util.List; | |||||
| public interface AimService { | public interface AimService { | ||||
| String runAim(FrameLogPathVo frameLogPathVo); | |||||
| List<InsMetricInfoVo> getExpTrainInfos(Integer experimentId) throws Exception; | |||||
| List<InsMetricInfoVo> getExpEvaluateInfos(Integer experimentId) throws Exception; | |||||
| String getExpMetrics(List<String> runIds) throws Exception; | |||||
| } | } | ||||
| @@ -1,13 +1,123 @@ | |||||
| package com.ruoyi.platform.service.impl; | package com.ruoyi.platform.service.impl; | ||||
| import com.alibaba.druid.util.StringUtils; | |||||
| import com.ruoyi.platform.domain.ExperimentIns; | |||||
| import com.ruoyi.platform.service.AimService; | import com.ruoyi.platform.service.AimService; | ||||
| import com.ruoyi.platform.vo.FrameLogPathVo; | |||||
| import com.ruoyi.platform.service.ExperimentInsService; | |||||
| import com.ruoyi.platform.service.ExperimentService; | |||||
| import com.ruoyi.platform.utils.AIM64EncoderUtil; | |||||
| import com.ruoyi.platform.utils.HttpUtils; | |||||
| import com.ruoyi.platform.utils.JacksonUtil; | |||||
| import com.ruoyi.platform.utils.JsonUtils; | |||||
| import com.ruoyi.platform.vo.InsMetricInfoVo; | |||||
| import org.apache.dubbo.container.Main; | |||||
| import org.json.JSONObject; | |||||
| import org.json.JSONTokener; | |||||
| import org.springframework.stereotype.Service; | import org.springframework.stereotype.Service; | ||||
| import javax.annotation.Resource; | |||||
| import java.net.URLEncoder; | |||||
| import java.util.*; | |||||
| import java.util.stream.Collectors; | |||||
| @Service | @Service | ||||
| public class AimServiceImpl implements AimService { | public class AimServiceImpl implements AimService { | ||||
| @Resource | |||||
| private ExperimentInsService experimentInsService; | |||||
| @Override | |||||
| public List<InsMetricInfoVo> getExpTrainInfos(Integer experimentId) throws Exception { | |||||
| String experimentName = "experiment-train-0"+experimentId; | |||||
| return getAimRunInfos("",experimentId); | |||||
| } | |||||
| @Override | @Override | ||||
| public String runAim(FrameLogPathVo frameLogPathVo) { | |||||
| return null; | |||||
| public List<InsMetricInfoVo> getExpEvaluateInfos(Integer experimentId) throws Exception { | |||||
| String experimentName = "experiment-evaluate-0"+experimentId; | |||||
| return getAimRunInfos("",experimentId); | |||||
| } | |||||
| @Override | |||||
| public String getExpMetrics(List<String> runIds) throws Exception { | |||||
| String decode = AIM64EncoderUtil.decode(runIds); | |||||
| return "http://172.20.32.21:7123/api/runs/search/run?query="+decode; | |||||
| } | |||||
| private List<InsMetricInfoVo> getAimRunInfos(String experimentName,Integer experimentId) throws Exception { | |||||
| String encodedUrlString = URLEncoder.encode("run.experiment==\"experiment-0000\"", "UTF-8"); | |||||
| String url = "http://172.20.32.181:30123/api/runs/search/run?query="+encodedUrlString; | |||||
| String s = HttpUtils.sendGetRequest(url); | |||||
| System.out.println(s); | |||||
| List<Map<String, Object>> response = JacksonUtil.parseJSONStr2MapList(s); | |||||
| // TODO: parse aim response to InsMetricInfoVo list | |||||
| if (response == null || response.size() == 0){ | |||||
| return new ArrayList<>(); | |||||
| } | |||||
| //查询实例数据 | |||||
| List<ExperimentIns> byExperimentId = experimentInsService.getByExperimentId(experimentId); | |||||
| // if (byExperimentId == null || byExperimentId.size() == 0){ | |||||
| // return new ArrayList<>(); | |||||
| // } | |||||
| List<InsMetricInfoVo> aimRunInfoList = new ArrayList<>(); | |||||
| for (Map<String, Object> run : response) { | |||||
| InsMetricInfoVo aimRunInfo = new InsMetricInfoVo(); | |||||
| String runHash = (String) run.get("run_hash"); | |||||
| aimRunInfo.setRunId(runHash); | |||||
| Map params= (Map) run.get("params"); | |||||
| Map<String, Object> paramMap = JsonUtils.flattenJson("", params); | |||||
| aimRunInfo.setParams(paramMap); | |||||
| Map<String, Object> tracesMap= (Map<String, Object>) run.get("params"); | |||||
| List<Map<String, Object>> metricList = (List<Map<String, Object>>) tracesMap.get("metric"); | |||||
| //过滤name为__system__开头的对象 | |||||
| aimRunInfo.setMetrics(new HashMap<>()); | |||||
| if (metricList != null && metricList.size() > 0){ | |||||
| List<Map<String, Object>> metricRelList = metricList.stream() | |||||
| .filter(map -> !StringUtils.equals("__system__", (String) map.get("name"))) | |||||
| .collect(Collectors.toList()); | |||||
| if (metricRelList!= null && metricRelList.size() > 0){ | |||||
| Map<String, Object> relMetricMap = new HashMap<>(); | |||||
| for (Map<String, Object> metricMap : metricRelList) { | |||||
| relMetricMap.put((String)metricMap.get("name"), metricMap.get("last_value")); | |||||
| } | |||||
| aimRunInfo.setMetrics(relMetricMap); | |||||
| } | |||||
| } | |||||
| //找到ins | |||||
| for (ExperimentIns ins : byExperimentId) { | |||||
| String metricRecord = ins.getMetricRecord(); | |||||
| if (metricRecord.contains(runHash)){ | |||||
| aimRunInfo.setExperimentInsId(ins.getId()); | |||||
| aimRunInfo.setStatus(ins.getStatus()); | |||||
| aimRunInfo.setStartTime(ins.getStartTime()); | |||||
| } | |||||
| } | |||||
| aimRunInfoList.add(aimRunInfo); | |||||
| } | |||||
| //判断哪个最长 | |||||
| Optional<InsMetricInfoVo> maxMetricsVo = aimRunInfoList.stream() | |||||
| .max((vo1, vo2) -> Integer.compare(vo1.getMetrics().size(), vo2.getMetrics().size())); | |||||
| // 如果找到了,设置 metricsFlag 为 true | |||||
| if (maxMetricsVo.isPresent()) { | |||||
| maxMetricsVo.get().setMetricsFlag(true); | |||||
| } | |||||
| Optional<InsMetricInfoVo> maxParamsVo = aimRunInfoList.stream() | |||||
| .max((vo1, vo2) -> Integer.compare(vo1.getParams().size(), vo2.getParams().size())); | |||||
| // 如果找到了,设置 metricsFlag 为 true | |||||
| if (maxParamsVo.isPresent()) { | |||||
| maxParamsVo.get().setMetricsFlag(true); | |||||
| } | |||||
| return aimRunInfoList; | |||||
| } | } | ||||
| } | } | ||||
| @@ -254,9 +254,13 @@ public class ExperimentServiceImpl implements ExperimentService { | |||||
| //获取训练参数 | //获取训练参数 | ||||
| Map<String, Object> metricRecord = (Map<String, Object>) runResMap.get("metric_record"); | Map<String, Object> metricRecord = (Map<String, Object>) runResMap.get("metric_record"); | ||||
| Map<String, Object> metadata = (Map<String, Object>) data.get("metadata"); | Map<String, Object> metadata = (Map<String, Object>) data.get("metadata"); | ||||
| // 插入记录到实验实例表 | // 插入记录到实验实例表 | ||||
| ExperimentIns experimentIns = new ExperimentIns(); | ExperimentIns experimentIns = new ExperimentIns(); | ||||
| if (metricRecord != null){ | |||||
| experimentIns.setMetricRecord(JacksonUtil.toJSONString(metricRecord)); | |||||
| } | |||||
| experimentIns.setExperimentId(experiment.getId()); | experimentIns.setExperimentId(experiment.getId()); | ||||
| experimentIns.setArgoInsNs((String) metadata.get("namespace")); | experimentIns.setArgoInsNs((String) metadata.get("namespace")); | ||||
| experimentIns.setArgoInsName((String) metadata.get("name")); | experimentIns.setArgoInsName((String) metadata.get("name")); | ||||
| @@ -275,8 +279,9 @@ public class ExperimentServiceImpl implements ExperimentService { | |||||
| Map<String, Object> converMap2 = JsonUtils.jsonToMap(JacksonUtil.replaceInAarry(convertRes, params)); | Map<String, Object> converMap2 = JsonUtils.jsonToMap(JacksonUtil.replaceInAarry(convertRes, params)); | ||||
| Map<String ,Object> dependendcy = (Map<String, Object>)converMap2.get("model_dependency"); | Map<String ,Object> dependendcy = (Map<String, Object>)converMap2.get("model_dependency"); | ||||
| Map<String ,Object> trainInfo = (Map<String, Object>)converMap2.get("component_info"); | Map<String ,Object> trainInfo = (Map<String, Object>)converMap2.get("component_info"); | ||||
| insertModelDependency(dependendcy,trainInfo,insert.getId(),experiment.getName()); | |||||
| if (dependendcy != null && trainInfo != null){ | |||||
| insertModelDependency(dependendcy,trainInfo,insert.getId(),experiment.getName()); | |||||
| } | |||||
| }catch (Exception e){ | }catch (Exception e){ | ||||
| throw new RuntimeException(e); | throw new RuntimeException(e); | ||||
| } | } | ||||
| @@ -0,0 +1,76 @@ | |||||
| package com.ruoyi.platform.utils; | |||||
| import com.alibaba.fastjson.JSON; | |||||
| import java.util.Base64; | |||||
| import java.util.HashMap; | |||||
| import java.util.List; | |||||
| import java.util.Map; | |||||
| public class AIM64EncoderUtil { | |||||
| private static final String AIM64_ENCODING_PREFIX = "O-"; | |||||
| private static final Map<String, String> BS64_REPLACE_CHARACTERS_ENCODING = new HashMap<>(); | |||||
| static { | |||||
| BS64_REPLACE_CHARACTERS_ENCODING.put("=", ""); | |||||
| BS64_REPLACE_CHARACTERS_ENCODING.put("+", "-"); | |||||
| BS64_REPLACE_CHARACTERS_ENCODING.put("/", "_"); | |||||
| } | |||||
| public static String aim64encode(Map<String, Object> value) { | |||||
| String jsonEncoded = JSON.toJSONString(value); | |||||
| String base64Encoded = Base64.getEncoder().encodeToString(jsonEncoded.getBytes()); | |||||
| String aim64Encoded = base64Encoded; | |||||
| for (Map.Entry<String, String> entry : BS64_REPLACE_CHARACTERS_ENCODING.entrySet()) { | |||||
| aim64Encoded = aim64Encoded.replace(entry.getKey(), entry.getValue()); | |||||
| } | |||||
| return AIM64_ENCODING_PREFIX + aim64Encoded; | |||||
| } | |||||
| public static String encode(Map<String, Object> value, boolean oneWayHashing) { | |||||
| if (oneWayHashing) { | |||||
| return md5(JSON.toJSONString(value)); | |||||
| } | |||||
| return aim64encode(value); | |||||
| } | |||||
| private static String md5(String input) { | |||||
| try { | |||||
| java.security.MessageDigest md = java.security.MessageDigest.getInstance("MD5"); | |||||
| byte[] array = md.digest(input.getBytes()); | |||||
| StringBuilder sb = new StringBuilder(); | |||||
| for (byte b : array) { | |||||
| sb.append(Integer.toHexString((b & 0xFF) | 0x100).substring(1, 3)); | |||||
| } | |||||
| return sb.toString(); | |||||
| } catch (java.security.NoSuchAlgorithmException e) { | |||||
| e.printStackTrace(); | |||||
| } | |||||
| return null; | |||||
| } | |||||
| public static String decode(List<String> runIds) { | |||||
| // 确保 runIds 列表的大小为 3 | |||||
| if (runIds == null || runIds.size() == 0) { | |||||
| throw new IllegalArgumentException("runIds 不能为空"); | |||||
| } | |||||
| // 构建查询字符串 | |||||
| StringBuilder queryBuilder = new StringBuilder("run.hash in ["); | |||||
| for (int i = 0; i < runIds.size(); i++) { | |||||
| if (i > 0) { | |||||
| queryBuilder.append(","); | |||||
| } | |||||
| queryBuilder.append("\"").append(runIds.get(i)).append("\""); | |||||
| } | |||||
| queryBuilder.append("]"); | |||||
| String query = queryBuilder.toString(); | |||||
| Map<String, Object> map = new HashMap<>(); | |||||
| map.put("query", query); | |||||
| map.put("advancedMode", true); | |||||
| map.put("advancedQuery", query); | |||||
| String searchQuery = encode(map, false); | |||||
| return searchQuery; | |||||
| } | |||||
| } | |||||
| @@ -25,6 +25,7 @@ import java.security.cert.X509Certificate; | |||||
| import java.util.HashMap; | import java.util.HashMap; | ||||
| import java.util.List; | import java.util.List; | ||||
| import java.util.Map; | import java.util.Map; | ||||
| import java.util.zip.GZIPInputStream; | |||||
| /** | /** | ||||
| * HTTP请求工具类 | * HTTP请求工具类 | ||||
| @@ -447,4 +448,38 @@ public class HttpUtils { | |||||
| return true; | return true; | ||||
| } | } | ||||
| } | } | ||||
| public static String sendGetRequestgzip(String url) throws Exception { | |||||
| String resultStr = null; | |||||
| HttpGet httpGet = new HttpGet(url); | |||||
| httpGet.setHeader("Content-Type", "application/json"); | |||||
| httpGet.setHeader("Accept-Encoding", "gzip, deflate"); | |||||
| try { | |||||
| HttpResponse response = httpClient.execute(httpGet); | |||||
| int responseCode = response.getStatusLine().getStatusCode(); | |||||
| if (responseCode != 200) { | |||||
| throw new IOException("HTTP request failed with response code: " + responseCode); | |||||
| } | |||||
| // 获取响应内容 | |||||
| InputStream responseStream = response.getEntity().getContent(); | |||||
| // 检查响应是否被压缩 | |||||
| if ("gzip".equalsIgnoreCase(response.getEntity().getContentEncoding().getValue())) { | |||||
| responseStream = new GZIPInputStream(responseStream); | |||||
| } | |||||
| // 读取解压缩后的内容 | |||||
| byte[] buffer = new byte[1024]; | |||||
| int len; | |||||
| StringBuilder decompressedString = new StringBuilder(); | |||||
| while ((len = responseStream.read(buffer)) > 0) { | |||||
| decompressedString.append(new String(buffer, 0, len)); | |||||
| } | |||||
| resultStr = decompressedString.toString(); | |||||
| } catch (IOException e) { | |||||
| e.printStackTrace(); | |||||
| } | |||||
| return resultStr; | |||||
| } | |||||
| } | } | ||||
| @@ -1,8 +1,11 @@ | |||||
| package com.ruoyi.platform.utils; | package com.ruoyi.platform.utils; | ||||
| import com.fasterxml.jackson.core.JsonProcessingException; | import com.fasterxml.jackson.core.JsonProcessingException; | ||||
| import com.fasterxml.jackson.databind.ObjectMapper; | import com.fasterxml.jackson.databind.ObjectMapper; | ||||
| import org.json.JSONObject; | |||||
| import java.io.IOException; | import java.io.IOException; | ||||
| import java.util.HashMap; | |||||
| import java.util.Iterator; | |||||
| import java.util.Map; | import java.util.Map; | ||||
| public class JsonUtils { | public class JsonUtils { | ||||
| @@ -28,4 +31,26 @@ public class JsonUtils { | |||||
| public static <T> T jsonToObject(String json, Class<T> clazz) throws IOException { | public static <T> T jsonToObject(String json, Class<T> clazz) throws IOException { | ||||
| return objectMapper.readValue(json, clazz); | return objectMapper.readValue(json, clazz); | ||||
| } | } | ||||
| // 将JSON字符串转换为扁平化的Map | |||||
| public static Map<String, Object> flattenJson(String prefix, Map<String, Object> map) { | |||||
| Map<String, Object> flatMap = new HashMap<>(); | |||||
| Iterator<Map.Entry<String, Object>> entries = map.entrySet().iterator(); | |||||
| while (entries.hasNext()) { | |||||
| Map.Entry<String, Object> entry = entries.next(); | |||||
| String key = entry.getKey(); | |||||
| Object value = entry.getValue(); | |||||
| if (value instanceof Map) { | |||||
| flatMap.putAll(flattenJson(prefix + key + ".", (Map<String, Object>) value)); | |||||
| } else { | |||||
| flatMap.put(prefix + key, value); | |||||
| } | |||||
| } | |||||
| return flatMap; | |||||
| } | |||||
| } | } | ||||
| @@ -0,0 +1,32 @@ | |||||
| package com.ruoyi.platform.vo; | |||||
| import com.fasterxml.jackson.databind.PropertyNamingStrategy; | |||||
| import com.fasterxml.jackson.databind.annotation.JsonNaming; | |||||
| import io.swagger.annotations.ApiModelProperty; | |||||
| import lombok.Data; | |||||
| import java.io.Serializable; | |||||
| import java.util.Date; | |||||
| import java.util.List; | |||||
| import java.util.Map; | |||||
| @JsonNaming(PropertyNamingStrategy.SnakeCaseStrategy.class) | |||||
| @Data | |||||
| public class InsMetricInfoVo implements Serializable { | |||||
| @ApiModelProperty(value = "开始时间") | |||||
| private Date startTime; | |||||
| @ApiModelProperty(value = "实例运行状态") | |||||
| private String status; | |||||
| @ApiModelProperty(value = "使用数据集") | |||||
| private List<Map<String, Object>> dataset; | |||||
| @ApiModelProperty(value = "实例ID") | |||||
| private Integer experimentInsId; | |||||
| @ApiModelProperty(value = "训练指标") | |||||
| private Map metrics; | |||||
| @ApiModelProperty(value = "训练参数") | |||||
| private Map params; | |||||
| @ApiModelProperty(value = "训练记录ID") | |||||
| private String runId; | |||||
| private Boolean metricsFlag = false; | |||||
| private Boolean paramsFlag = false; | |||||
| } | |||||