diff --git a/ruoyi-modules/management-platform/pom.xml b/ruoyi-modules/management-platform/pom.xml index e7c559fd..ba689d0e 100644 --- a/ruoyi-modules/management-platform/pom.xml +++ b/ruoyi-modules/management-platform/pom.xml @@ -242,7 +242,6 @@ jedis 3.6.0 - diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/aim/AimController.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/aim/AimController.java index bc800545..c61fb789 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/aim/AimController.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/aim/AimController.java @@ -24,15 +24,15 @@ public class AimController extends BaseController { @GetMapping("/getExpTrainInfos/{experiment_id}") @ApiOperation("获取当前实验的模型训练指标信息") @ApiResponse - public GenericsAjaxResult> getExpTrainInfos(@PathVariable("experiment_id") Integer experimentId, @RequestParam("run_id") String runId) throws Exception { - return genericsSuccess(aimService.getExpTrainInfos(experimentId, runId)); + public GenericsAjaxResult> getExpTrainInfos(@PathVariable("experiment_id") Integer experimentId) throws Exception { + return genericsSuccess(aimService.getExpTrainInfos(experimentId)); } @GetMapping("/getExpEvaluateInfos/{experiment_id}") @ApiOperation("获取当前实验的模型推理指标信息") @ApiResponse - public GenericsAjaxResult> getExpEvaluateInfos(@PathVariable("experiment_id") Integer experimentId, @RequestParam("run_id") String runId) throws Exception { - return genericsSuccess(aimService.getExpEvaluateInfos(experimentId, runId)); + public GenericsAjaxResult> getExpEvaluateInfos(@PathVariable("experiment_id") Integer experimentId) throws Exception { + return genericsSuccess(aimService.getExpEvaluateInfos(experimentId)); } @PostMapping("/getExpMetrics") diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AimService.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AimService.java index c7f91d8f..9f3868f5 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AimService.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AimService.java @@ -6,9 +6,11 @@ import java.util.List; public interface AimService { - List getExpTrainInfos(Integer experimentId, String runId) throws Exception; + List getExpTrainInfos(Integer experimentId) throws Exception; - List getExpEvaluateInfos(Integer experimentId, String runId) throws Exception; + List getExpTrainInfos1(boolean isTrain, Integer experimentId, String runId) throws Exception; + + List getExpEvaluateInfos(Integer experimentId) throws Exception; String getExpMetrics(List runIds) throws Exception; } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java index 3db73a01..1b754404 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java @@ -30,13 +30,13 @@ public class AimServiceImpl implements AimService { private NewHttpUtils httpUtils; @Override - public List getExpTrainInfos(Integer experimentId, String runId) throws Exception { - return getAimRunInfos(true, experimentId, runId); + public List getExpTrainInfos(Integer experimentId) throws Exception { + return getAimRunInfos(true, experimentId); } @Override - public List getExpEvaluateInfos(Integer experimentId, String runId) throws Exception { - return getAimRunInfos(false, experimentId, runId); + public List getExpEvaluateInfos(Integer experimentId) throws Exception { + return getAimRunInfos(false, experimentId); } @Override @@ -45,13 +45,12 @@ public class AimServiceImpl implements AimService { return aimUrl + "/metrics?select=" + decode; } - private List getAimRunInfos(boolean isTrain, Integer experimentId, String runId) throws Exception { -// String experimentName = "experiment-" + experimentId + "-train"; -// if (!isTrain) { -// experimentName = "experiment-" + experimentId + "-evaluate"; -// } -// String encodedUrlString = URLEncoder.encode("run.experiment==\"" + experimentName + "\"", "UTF-8"); - String encodedUrlString = URLEncoder.encode("run.id==\"" + runId + "\"", "UTF-8"); + private List getAimRunInfos(boolean isTrain, Integer experimentId) throws Exception { + String experimentName = "experiment-" + experimentId + "-train"; + if (!isTrain) { + experimentName = "experiment-" + experimentId + "-evaluate"; + } + String encodedUrlString = URLEncoder.encode("run.experiment==\"" + experimentName + "\"", "UTF-8"); String url = aimProxyUrl + "/api/runs/search/run?query=" + encodedUrlString; String s = httpUtils.sendGet(url, null); List> response = JacksonUtil.parseJSONStr2MapList(s); @@ -139,6 +138,96 @@ public class AimServiceImpl implements AimService { } + @Override + public List getExpTrainInfos1(boolean isTrain, Integer experimentId, String runId) throws Exception { + String encodedUrlString = URLEncoder.encode("run.id==\"" + runId + "\"", "UTF-8"); + String url = aimProxyUrl + "/api/runs/search/run?query=" + encodedUrlString; + String s = httpUtils.sendGet(url, null); + List> response = JacksonUtil.parseJSONStr2MapList(s); + System.out.println("response: " + JacksonUtil.toJSONString(response)); + if (response == null || response.size() == 0) { + return new ArrayList<>(); + } + //查询实例数据 + List byExperimentId = experimentInsService.queryByExperimentId(experimentId); + + if (byExperimentId == null || byExperimentId.size() == 0) { + return new ArrayList<>(); + } + List aimRunInfoList = new ArrayList<>(); + for (Map run : response) { + InsMetricInfoVo aimRunInfo = new InsMetricInfoVo(); + String runHash = (String) run.get("run_hash"); + + aimRunInfo.setRunId(runHash); + + Map params = (Map) run.get("params"); + Map paramMap = JsonUtils.flattenJson("", params); + aimRunInfo.setParams(paramMap); + String aimrunId = (String) paramMap.get("id"); + Map tracesMap = (Map) run.get("traces"); + List> metricList = (List>) tracesMap.get("metric"); + //过滤name为__system__开头的对象 + aimRunInfo.setMetrics(new HashMap<>()); + if (metricList != null && metricList.size() > 0) { + List> metricRelList = metricList.stream() + .filter(map -> !StringUtils.startsWith((String) map.get("name"), "__system__")) + .collect(Collectors.toList()); + if (metricRelList != null && metricRelList.size() > 0) { + Map relMetricMap = new HashMap<>(); + for (Map metricMap : metricRelList) { + relMetricMap.put((String) metricMap.get("name"), metricMap.get("last_value")); + } + aimRunInfo.setMetrics(relMetricMap); + } + } + //找到ins + for (ExperimentIns ins : byExperimentId) { + String metricRecordString = ins.getMetricRecord(); + if (StringUtils.isEmpty(metricRecordString)) { + continue; + } + if (metricRecordString.contains(aimrunId)) { + aimRunInfo.setExperimentInsId(ins.getId()); + aimRunInfo.setStatus(ins.getStatus()); + aimRunInfo.setStartTime(ins.getCreateTime()); + Map metricRecordMap = JacksonUtil.parseJSONStr2Map(metricRecordString); + if (isTrain) { + List> records = (List>) metricRecordMap.get("train"); + List datasetList = getTrainDateSet(records, aimrunId); + aimRunInfo.setDataset(datasetList); + } else { + List> records = (List>) metricRecordMap.get("evaluate"); + List datasetList = getTrainDateSet(records, aimrunId); + aimRunInfo.setDataset(datasetList); + } + aimRunInfoList.add(aimRunInfo); + } + } + } + + //判断哪个最长 + + // 获取所有 metrics 的 key 的并集 + Set metricsKeys = (Set) aimRunInfoList.stream() + .map(InsMetricInfoVo::getMetrics) + .flatMap(metrics -> metrics.keySet().stream()) + .collect(Collectors.toSet()); + // 将并集赋值给每个 InsMetricInfoVo 的 metricsNames 属性 + aimRunInfoList.forEach(vo -> vo.setMetricsNames(new ArrayList<>(metricsKeys))); + + // 获取所有 params 的 key 的并集 + Set paramKeys = (Set) aimRunInfoList.stream() + .map(InsMetricInfoVo::getParams) + .flatMap(params -> params.keySet().stream()) + .collect(Collectors.toSet()); + // 将并集赋值给每个 InsMetricInfoVo 的 paramsNames 属性 + aimRunInfoList.forEach(vo -> vo.setParamsNames(new ArrayList<>(paramKeys))); + + return aimRunInfoList; + } + + private List getTrainDateSet(List> records, String aimrunId) { List datasetList = new ArrayList<>(); for (Map record : records) { diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentInsServiceImpl.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentInsServiceImpl.java index e9457bcb..7b9c57f7 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentInsServiceImpl.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentInsServiceImpl.java @@ -372,6 +372,25 @@ public class ExperimentInsServiceImpl implements ExperimentInsService { if (errCode != null && errCode == 0) { //更新experimentIns,确保状态更新被保存到数据库 ExperimentIns ins = queryStatusFromArgo(experimentIns); + String nodesStatus = ins.getNodesStatus(); + Map nodeMap = JsonUtils.jsonToMap(nodesStatus); + + // 遍历 map + for (Map.Entry entry : nodeMap.entrySet()) { + // 获取每个 Map 中的值并强制转换为 Map + Map innerMap = (Map) entry.getValue(); + + // 检查 phase 的值 + if (innerMap.containsKey("phase")) { + String phaseValue = (String) innerMap.get("phase"); + + // 如果值不等于 Succeeded,则赋值为 Failed + if (!StringUtils.equals("Succeeded", phaseValue)) { + innerMap.put("phase", "Failed"); + } + } + } + ins.setNodesStatus(JsonUtils.mapToJson(nodeMap)); ins.setStatus("Terminated"); ins.setFinishTime(new Date()); this.experimentInsDao.update(ins); diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentServiceImpl.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentServiceImpl.java index 6738a01c..f91e1a2f 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentServiceImpl.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentServiceImpl.java @@ -13,13 +13,12 @@ import com.ruoyi.platform.mapper.ExperimentDao; import com.ruoyi.platform.mapper.ExperimentInsDao; import com.ruoyi.platform.mapper.ModelDependency1Dao; import com.ruoyi.platform.service.*; -import com.ruoyi.platform.utils.FileUtil; import com.ruoyi.platform.utils.HttpUtils; import com.ruoyi.platform.utils.JacksonUtil; import com.ruoyi.platform.utils.JsonUtils; +import com.ruoyi.platform.utils.YamlUtils; import com.ruoyi.platform.vo.ModelsVo; import com.ruoyi.platform.vo.NewDatasetVo; -import com.ruoyi.platform.vo.VersionVo; import com.ruoyi.system.api.model.LoginUser; import org.apache.commons.collections4.MapUtils; import org.apache.commons.lang3.StringUtils; @@ -34,6 +33,7 @@ import org.springframework.stereotype.Service; import javax.annotation.Resource; import java.io.IOException; import java.lang.reflect.Field; +import java.math.BigDecimal; import java.util.*; /** @@ -457,8 +457,6 @@ public class ExperimentServiceImpl implements ExperimentService { * 存储数据集元数据到临时表 */ private void insertDatasetTempStorage(Map datasetDependendcy, Map trainInfo, Integer experimentId, Integer experimentInsId, String experimentName) { - DatasetTempStorage datasetTempStorage = new DatasetTempStorage(); - Iterator> dependendcyIterator = datasetDependendcy.entrySet().iterator(); Map datasetExport = (Map) trainInfo.get("dataset_export"); Map datasetPreprocess = (Map) trainInfo.get("general-data-process"); @@ -472,20 +470,33 @@ public class ExperimentServiceImpl implements ExperimentService { String sourceTaskId = (String) source.get("task_id"); Map datasetPreprocessMap = (Map) datasetPreprocess.get(sourceTaskId); //处理project数据 - Map projectMap = (Map) datasetPreprocessMap.get("project"); - Map datasets = (Map) datasetPreprocessMap.get("datasets"); - datasetTempStorage.setName((String) datasets.get("dataset_identifier")); - datasetTempStorage.setVersion((String) datasets.get("dataset_version")); // 拼接需要的参数 + Map projectMap = (Map) datasetPreprocessMap.get("project"); Map sourceParams = new HashMap<>(); sourceParams.put("experiment_name", experimentName); sourceParams.put("experiment_ins_id", experimentInsId); sourceParams.put("experiment_id", experimentId); sourceParams.put("train_name", sourceTaskId); sourceParams.put("preprocess_code", projectMap); - datasetTempStorage.setSource(JacksonUtil.toJSONString(sourceParams)); - datasetTempStorage.setState(1); - datasetTempStorageService.insert(datasetTempStorage); + + if (target != null && target.size() > 0) { + for (Map targetMap : target) { + String targetTaskId = (String) targetMap.get("task_id"); + Map datasetExportMap = (Map) datasetExport.get(targetTaskId); + List datasets = (List) datasetExportMap.get("datasets"); + if (datasets != null) { + for (Map dataset : datasets) { + DatasetTempStorage datasetTempStorage = new DatasetTempStorage(); + datasetTempStorage.setName((String) dataset.get("dataset_identifier")); + datasetTempStorage.setVersion((String) dataset.get("dataset_version")); + datasetTempStorage.setSource(JacksonUtil.toJSONString(sourceParams)); + datasetTempStorage.setState(1); + datasetTempStorageService.insert(datasetTempStorage); + } + } + + } + } } } @@ -534,8 +545,13 @@ public class ExperimentServiceImpl implements ExperimentService { for (int i = 0; i < jsonArray.size(); i++) { JSONObject jsonObject = jsonArray.getJSONObject(i); String paramName = jsonObject.getString("param_name"); - Double paramValue = jsonObject.getDouble("param_value"); - trainParam.put(paramName, paramValue); + String paramValue = jsonObject.getString("param_value"); + if (YamlUtils.isNumeric(paramValue)) { + BigDecimal bigDecimal = new BigDecimal(paramValue); + trainParam.put(paramName, bigDecimal); + } else { + trainParam.put(paramName, paramValue); + } } modelMetaVo.setParams(trainParam); diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelsServiceImpl.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelsServiceImpl.java index 7b51ee35..514a0d36 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelsServiceImpl.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelsServiceImpl.java @@ -985,8 +985,8 @@ public class ModelsServiceImpl implements ModelsService { String jsonString = JSON.toJSONString(stringObjectMap); ModelsVo modelsVo = JSON.parseObject(jsonString, ModelsVo.class); - List versionVos = new ArrayList<>(); if (!fileDetailsAfterGitPull.isEmpty()) { + List versionVos = new ArrayList<>(); for (Map fileDetail : fileDetailsAfterGitPull) { VersionVo versionVo = new VersionVo(); versionVo.setUrl((String) fileDetail.get("filePath")); @@ -995,8 +995,8 @@ public class ModelsServiceImpl implements ModelsService { versionVo.setFileSize(FileUtil.formatFileSize(size)); versionVos.add(versionVo); } + modelsVo.setModelVersionVos(versionVos); } - modelsVo.setModelVersionVos(versionVos); return modelsVo; } @@ -1162,28 +1162,32 @@ public class ModelsServiceImpl implements ModelsService { HashMap metrics = modelMetaVo.getMetrics(); JSONArray trainMetrics = (JSONArray) metrics.get("train"); - for (int i = 0; i < trainMetrics.size(); i++) { - JSONObject jsonObject = trainMetrics.getJSONObject(i); - String runId = jsonObject.getString("run_id"); - List expTrainInfos = aimsService.getExpTrainInfos(modelMetaVo.getTrainTask().getExperimentId(), runId); - for (InsMetricInfoVo expTrainInfo : expTrainInfos) { - Map metrics1 = expTrainInfo.getMetrics(); - train.putAll(metrics1); + if (trainMetrics != null) { + for (int i = 0; i < trainMetrics.size(); i++) { + JSONObject jsonObject = trainMetrics.getJSONObject(i); + String runId = jsonObject.getString("run_id"); + List expTrainInfos = aimsService.getExpTrainInfos1(true, modelMetaVo.getTrainTask().getExperimentId(), runId); + for (InsMetricInfoVo expTrainInfo : expTrainInfos) { + Map metrics1 = expTrainInfo.getMetrics(); + train.putAll(metrics1); + } } + result.put("train", train); } - result.put("train", train); JSONArray testMetrics = (JSONArray) metrics.get("evaluate"); - for (int i = 0; i < testMetrics.size(); i++) { - JSONObject jsonObject = testMetrics.getJSONObject(i); - String runId = jsonObject.getString("run_id"); - List expTestInfos = aimsService.getExpEvaluateInfos(modelMetaVo.getTrainTask().getExperimentId(), runId); - for (InsMetricInfoVo expTestInfo : expTestInfos) { - Map metrics1 = expTestInfo.getMetrics(); - evaluate.putAll(metrics1); + if (testMetrics != null) { + for (int i = 0; i < testMetrics.size(); i++) { + JSONObject jsonObject = testMetrics.getJSONObject(i); + String runId = jsonObject.getString("run_id"); + List expTestInfos = aimsService.getExpTrainInfos1(false, modelMetaVo.getTrainTask().getExperimentId(), runId); + for (InsMetricInfoVo expTestInfo : expTestInfos) { + Map metrics1 = expTestInfo.getMetrics(); + evaluate.putAll(metrics1); + } } + result.put("evaluate", evaluate); } - result.put("evaluate", evaluate); modelMetaVo.setMetrics(result); } } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/YamlUtils.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/YamlUtils.java index 3f235af1..142406cf 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/YamlUtils.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/YamlUtils.java @@ -1,29 +1,37 @@ package com.ruoyi.platform.utils; +import com.alibaba.fastjson2.JSON; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; +import com.fasterxml.jackson.dataformat.yaml.YAMLGenerator; +import org.springframework.stereotype.Component; import org.yaml.snakeyaml.DumperOptions; import org.yaml.snakeyaml.Yaml; -import org.yaml.snakeyaml.nodes.Node; -import org.yaml.snakeyaml.representer.Represent; -import org.yaml.snakeyaml.representer.Representer; import java.io.*; -import java.util.Iterator; import java.util.Map; -import org.yaml.snakeyaml.nodes.Tag; + + +//import org.ho.yaml.Yaml; + +@Component public class YamlUtils { /** * 将Map对象转换为YAML格式并写入指定路径的文件中 * - * @param data Map对象 - * @param path 文件路径 - * @param fileName 文件名 + * @param data Map对象 + * @param path 文件路径 + * @param fileName 文件名 */ public static void generateYamlFile(Map data, String path, String fileName) { DumperOptions options = new DumperOptions(); options.setDefaultFlowStyle(DumperOptions.FlowStyle.BLOCK); options.setDefaultScalarStyle(DumperOptions.ScalarStyle.PLAIN); + // 创建Yaml实例 Yaml yaml = new Yaml(options); @@ -38,12 +46,58 @@ public class YamlUtils { String fullPath = path + "/" + fileName + ".yaml"; try (FileWriter writer = new FileWriter(fullPath)) { + String dump = yaml.dump(data); + yaml.dump(data, writer); } catch (IOException e) { e.printStackTrace(); } } + +// public static void generateYamlFile1(Object data, String path, String fileName) { +// try { +// YAMLFactory yamlFactory = new YAMLFactory(); +// yamlFactory.enable(YAMLGenerator.Feature.MINIMIZE_QUOTES); +// yamlFactory.disable(YAMLGenerator.Feature.WRITE_DOC_START_MARKER); +//// +// ObjectMapper objectMapper = new ObjectMapper(yamlFactory); +//// ObjectMapper objectMapper = new ObjectMapper(); +// objectMapper.setPropertyNamingStrategy(PropertyNamingStrategies.SNAKE_CASE); +// String s = objectMapper.writeValueAsString(data); +// +// File directory = new File(path); +// if (!directory.exists()) { +// boolean isCreated = directory.mkdirs(); +// if (!isCreated) { +// throw new RuntimeException("创建路径失败: " + path); +// } +// } +// +// try { +// DumperOptions options = new DumperOptions(); +// options.setDefaultFlowStyle(DumperOptions.FlowStyle.BLOCK); +// // 创建Yaml实例 +// Yaml yaml = new Yaml(options); +// +// String fullPath = path + "/" + fileName + ".yaml"; +// FileWriter writer = new FileWriter(fullPath); +// +// yaml.dump(s, writer); +// } catch (FileNotFoundException e) { +// e.printStackTrace(); +// +// } catch (IOException e) { +// throw new RuntimeException(e); +// } +// +// +// } catch (JsonProcessingException e) { +// throw new RuntimeException(e); +// } +// } + + /** * 读取YAML文件并将其内容转换为Map * @@ -59,4 +113,13 @@ public class YamlUtils { return null; } } + + public static boolean isNumeric(String str) { + try { + double num = Double.parseDouble(str); + return true; + } catch (NumberFormatException e) { + return false; + } + } } diff --git a/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/ModelDependency1DaoMapper.xml b/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/ModelDependency1DaoMapper.xml index c92d842c..596c931a 100644 --- a/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/ModelDependency1DaoMapper.xml +++ b/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/ModelDependency1DaoMapper.xml @@ -47,6 +47,7 @@ and identifier = #{identifier} and version = #{version} and state = 2 + order by create_time desc limit 1