diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java index 6ec8f43c..fdd7b5c9 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java @@ -23,8 +23,6 @@ import java.util.stream.Collectors; public class AimServiceImpl implements AimService { @Resource private ExperimentInsService experimentInsService; - @Resource - private ModelDependencyService modelDependencyService; @Value("${aim.url}") private String aimUrl; @@ -44,7 +42,7 @@ public class AimServiceImpl implements AimService { @Override public String getExpMetrics(List runIds) throws Exception { String decode = AIM64EncoderUtil.decode(runIds); - return aimUrl+"/api/runs/search/run?query="+decode; + return aimUrl+"/metrics?select="+decode; } private List getAimRunInfos(boolean isTrain,Integer experimentId) throws Exception { @@ -56,6 +54,7 @@ public class AimServiceImpl implements AimService { String url = aimProxyUrl+"/api/runs/search/run?query="+encodedUrlString; String s = HttpUtils.sendGetRequest(url); List> response = JacksonUtil.parseJSONStr2MapList(s); + System.out.println("response: "+JacksonUtil.toJSONString(response)); if (response == null || response.size() == 0){ return new ArrayList<>(); } @@ -103,20 +102,15 @@ public class AimServiceImpl implements AimService { aimRunInfo.setStatus(ins.getStatus()); aimRunInfo.setStartTime(ins.getCreateTime()); Map metricRecordMap = JacksonUtil.parseJSONStr2Map(metricRecordString); - - //metricRecord 格式为{"train":[{"task_id":"model-train-35303690","run_id":"5560d78f54314672b60304c8d6ba03b8","experiment_name":"experiment-30-train"}],"evaluate":[{"task_id":"model-train-35303690","run_id":"5560d78f54314672b60304c8d6ba03b8","experiment_name":"experiment-30-train"}]} - //遍历metricRecord,找到当前task_id对应的ModelDependency - if (isTrain){ - List> trainList = (List>) metricRecordMap.get("train"); - List trainDateSet = getTrainDateSet(trainList, ins.getId(), isTrain); - aimRunInfo.setDataset(trainDateSet); + List> records = (List>) metricRecordMap.get("train"); + List datasetList = getTrainDateSet(records, aimrunId); + aimRunInfo.setDataset(datasetList); }else { - List> trainList = (List>) metricRecordMap.get("evaluate"); - List trainDateSet = getTrainDateSet(trainList, ins.getId(), isTrain); - aimRunInfo.setDataset(trainDateSet); + List> records = (List>) metricRecordMap.get("evaluate"); + List datasetList = getTrainDateSet(records, aimrunId); + aimRunInfo.setDataset(datasetList); } - } } aimRunInfoList.add(aimRunInfo); @@ -143,33 +137,21 @@ public class AimServiceImpl implements AimService { } - private List getTrainDateSet(List> trainList,Integer expInsId,boolean isTrain){ - if (trainList == null || trainList.size() == 0){ - return new ArrayList<>(); - } + private List getTrainDateSet(List> records, String aimrunId){ List datasetList = new ArrayList<>(); - for (Map trainMap : trainList) { - String task_id = (String) trainMap.get("task_id"); - //modelDependency取到数据集文件 - ModelDependency modelDependency = modelDependencyService.queryByInsAndTrainTaskId(expInsId, task_id); - //把数据集文件组装成String后放进List - String datasetString = ""; - if (isTrain){ - datasetString = modelDependency.getTrainDataset(); - }else { - datasetString = modelDependency.getTestDataset(); - } - List> datasetListMap = JacksonUtil.parseJSONStr2MapList(datasetString); - - if (datasetListMap != null && datasetListMap.size() > 0){ - for (Map datasetMap : datasetListMap) { - //[{"dataset_id":20,"dataset_version":"v0.1.0","dataset_name":"手写体识别模型依赖测试训练数据集"}] - String datasetName = (String) datasetMap.get("dataset_name")+":"+(String) datasetMap.get("dataset_version"); + for (Map record : records) { + if (StringUtils.equals(aimrunId, (String)record.get("run_id"))) { + List> datasets = (List>) record.get("datasets"); + if (datasets == null || datasets.size() == 0){ + continue; + } + for (Map dataset : datasets){ + String datasetName = (String) dataset.get("dataset_name")+":"+(String) dataset.get("dataset_version"); datasetList.add(datasetName); } + break; } } return datasetList; } - } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentServiceImpl.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentServiceImpl.java index d1ae81c4..388155d6 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentServiceImpl.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentServiceImpl.java @@ -251,16 +251,14 @@ public class ExperimentServiceImpl implements ExperimentService { if (data == null || MapUtils.isEmpty(data)) { throw new RuntimeException("Failed to run workflow."); } - //获取训练参数 - Map metricRecord = (Map) runResMap.get("metric_record"); + Map metadata = (Map) data.get("metadata"); // 插入记录到实验实例表 ExperimentIns experimentIns = new ExperimentIns(); - if (metricRecord != null){ - experimentIns.setMetricRecord(JacksonUtil.toJSONString(metricRecord)); - } + //获取训练参数 + experimentIns.setExperimentId(experiment.getId()); experimentIns.setArgoInsNs((String) metadata.get("namespace")); experimentIns.setArgoInsName((String) metadata.get("name")); @@ -271,14 +269,22 @@ public class ExperimentServiceImpl implements ExperimentService { //替换argoInsName String outputString = JsonUtils.mapToJson(output); experimentIns.setNodesResult(outputString.replace("{{workflow.name}}", (String) metadata.get("name"))); - //插入ExperimentIns表中 - ExperimentIns insert = experimentInsService.insert(experimentIns); - //插入到模型依赖关系表 + //得到dependendcy Map converMap2 = JsonUtils.jsonToMap(JacksonUtil.replaceInAarry(convertRes, params)); Map dependendcy = (Map)converMap2.get("model_dependency"); Map trainInfo = (Map)converMap2.get("component_info"); + + Map metricRecord = (Map) runResMap.get("metric_record"); + if (metricRecord != null){ + //把训练用的数据集也放进去 + addDatesetToMetric(metricRecord, trainInfo); + experimentIns.setMetricRecord(JacksonUtil.toJSONString(metricRecord)); + } + //插入ExperimentIns表中 + ExperimentIns insert = experimentInsService.insert(experimentIns); + //插入到模型依赖关系表 if (dependendcy != null && trainInfo != null){ insertModelDependency(dependendcy,trainInfo,insert.getId(),experiment.getName()); } @@ -289,6 +295,37 @@ public class ExperimentServiceImpl implements ExperimentService { experiment.setExperimentInsList(updatedExperimentInsList); return experiment; } + private void addDatesetToMetric(Map metricRecord, Map trainInfo) { + processMetricPart(metricRecord, trainInfo, "train", "model_train"); + processMetricPart(metricRecord, trainInfo, "evaluate", "model_evaluate"); + } + + private void processMetricPart(Map metricRecord, Map trainInfo, String metricKey, String trainInfoKey) { + List> metricList = (List>) metricRecord.get(metricKey); + if (metricList != null) { + for (Map metricRecordItem : metricList) { + String taskId = (String) metricRecordItem.get("task_id"); + Map trainInfoPart = (Map) trainInfo.get(trainInfoKey); + if (trainInfoPart != null) { + Map trainInfoDetails = (Map) trainInfoPart.get(taskId); + if (trainInfoDetails != null) { + List> datasets = (List>) trainInfoDetails.get("datasets"); + if (datasets != null) { + //查询名字再回填 + for (int i = 0; i < datasets.size(); i++) { + Dataset dataset = datasetService.queryById((Integer) datasets.get(i).get("dataset_id")); + datasets.get(i).put("dataset_name", dataset.getName()); + } + metricRecordItem.put("datasets", datasets); + } + } + } + } + } + } + + + private void insertModelDependency(Map dependendcy,Map trainInfo, Integer experimentInsId, String experimentName) throws Exception { Iterator> dependendcyIterator = dependendcy.entrySet().iterator(); Map modelTrain = (Map) trainInfo.get("model_train");