| @@ -23,8 +23,6 @@ import java.util.stream.Collectors; | |||||
| public class AimServiceImpl implements AimService { | public class AimServiceImpl implements AimService { | ||||
| @Resource | @Resource | ||||
| private ExperimentInsService experimentInsService; | private ExperimentInsService experimentInsService; | ||||
| @Resource | |||||
| private ModelDependencyService modelDependencyService; | |||||
| @Value("${aim.url}") | @Value("${aim.url}") | ||||
| private String aimUrl; | private String aimUrl; | ||||
| @@ -44,7 +42,7 @@ public class AimServiceImpl implements AimService { | |||||
| @Override | @Override | ||||
| public String getExpMetrics(List<String> runIds) throws Exception { | public String getExpMetrics(List<String> runIds) throws Exception { | ||||
| String decode = AIM64EncoderUtil.decode(runIds); | String decode = AIM64EncoderUtil.decode(runIds); | ||||
| return aimUrl+"/api/runs/search/run?query="+decode; | |||||
| return aimUrl+"/metrics?select="+decode; | |||||
| } | } | ||||
| private List<InsMetricInfoVo> getAimRunInfos(boolean isTrain,Integer experimentId) throws Exception { | private List<InsMetricInfoVo> getAimRunInfos(boolean isTrain,Integer experimentId) throws Exception { | ||||
| @@ -56,6 +54,7 @@ public class AimServiceImpl implements AimService { | |||||
| String url = aimProxyUrl+"/api/runs/search/run?query="+encodedUrlString; | String url = aimProxyUrl+"/api/runs/search/run?query="+encodedUrlString; | ||||
| String s = HttpUtils.sendGetRequest(url); | String s = HttpUtils.sendGetRequest(url); | ||||
| List<Map<String, Object>> response = JacksonUtil.parseJSONStr2MapList(s); | List<Map<String, Object>> response = JacksonUtil.parseJSONStr2MapList(s); | ||||
| System.out.println("response: "+JacksonUtil.toJSONString(response)); | |||||
| if (response == null || response.size() == 0){ | if (response == null || response.size() == 0){ | ||||
| return new ArrayList<>(); | return new ArrayList<>(); | ||||
| } | } | ||||
| @@ -103,20 +102,15 @@ public class AimServiceImpl implements AimService { | |||||
| aimRunInfo.setStatus(ins.getStatus()); | aimRunInfo.setStatus(ins.getStatus()); | ||||
| aimRunInfo.setStartTime(ins.getCreateTime()); | aimRunInfo.setStartTime(ins.getCreateTime()); | ||||
| Map<String, Object> metricRecordMap = JacksonUtil.parseJSONStr2Map(metricRecordString); | Map<String, Object> metricRecordMap = JacksonUtil.parseJSONStr2Map(metricRecordString); | ||||
| //metricRecord 格式为{"train":[{"task_id":"model-train-35303690","run_id":"5560d78f54314672b60304c8d6ba03b8","experiment_name":"experiment-30-train"}],"evaluate":[{"task_id":"model-train-35303690","run_id":"5560d78f54314672b60304c8d6ba03b8","experiment_name":"experiment-30-train"}]} | |||||
| //遍历metricRecord,找到当前task_id对应的ModelDependency | |||||
| if (isTrain){ | if (isTrain){ | ||||
| List<Map<String, Object>> trainList = (List<Map<String, Object>>) metricRecordMap.get("train"); | |||||
| List<String> trainDateSet = getTrainDateSet(trainList, ins.getId(), isTrain); | |||||
| aimRunInfo.setDataset(trainDateSet); | |||||
| List<Map<String, Object>> records = (List<Map<String, Object>>) metricRecordMap.get("train"); | |||||
| List<String> datasetList = getTrainDateSet(records, aimrunId); | |||||
| aimRunInfo.setDataset(datasetList); | |||||
| }else { | }else { | ||||
| List<Map<String, Object>> trainList = (List<Map<String, Object>>) metricRecordMap.get("evaluate"); | |||||
| List<String> trainDateSet = getTrainDateSet(trainList, ins.getId(), isTrain); | |||||
| aimRunInfo.setDataset(trainDateSet); | |||||
| List<Map<String, Object>> records = (List<Map<String, Object>>) metricRecordMap.get("evaluate"); | |||||
| List<String> datasetList = getTrainDateSet(records, aimrunId); | |||||
| aimRunInfo.setDataset(datasetList); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| aimRunInfoList.add(aimRunInfo); | aimRunInfoList.add(aimRunInfo); | ||||
| @@ -143,33 +137,21 @@ public class AimServiceImpl implements AimService { | |||||
| } | } | ||||
| private List<String> getTrainDateSet(List<Map<String, Object>> trainList,Integer expInsId,boolean isTrain){ | |||||
| if (trainList == null || trainList.size() == 0){ | |||||
| return new ArrayList<>(); | |||||
| } | |||||
| private List<String> getTrainDateSet(List<Map<String, Object>> records, String aimrunId){ | |||||
| List<String> datasetList = new ArrayList<>(); | List<String> datasetList = new ArrayList<>(); | ||||
| for (Map<String, Object> trainMap : trainList) { | |||||
| String task_id = (String) trainMap.get("task_id"); | |||||
| //modelDependency取到数据集文件 | |||||
| ModelDependency modelDependency = modelDependencyService.queryByInsAndTrainTaskId(expInsId, task_id); | |||||
| //把数据集文件组装成String后放进List | |||||
| String datasetString = ""; | |||||
| if (isTrain){ | |||||
| datasetString = modelDependency.getTrainDataset(); | |||||
| }else { | |||||
| datasetString = modelDependency.getTestDataset(); | |||||
| } | |||||
| List<Map<String, Object>> datasetListMap = JacksonUtil.parseJSONStr2MapList(datasetString); | |||||
| if (datasetListMap != null && datasetListMap.size() > 0){ | |||||
| for (Map<String, Object> datasetMap : datasetListMap) { | |||||
| //[{"dataset_id":20,"dataset_version":"v0.1.0","dataset_name":"手写体识别模型依赖测试训练数据集"}] | |||||
| String datasetName = (String) datasetMap.get("dataset_name")+":"+(String) datasetMap.get("dataset_version"); | |||||
| for (Map<String, Object> record : records) { | |||||
| if (StringUtils.equals(aimrunId, (String)record.get("run_id"))) { | |||||
| List<Map<String, Object>> datasets = (List<Map<String, Object>>) record.get("datasets"); | |||||
| if (datasets == null || datasets.size() == 0){ | |||||
| continue; | |||||
| } | |||||
| for (Map<String, Object> dataset : datasets){ | |||||
| String datasetName = (String) dataset.get("dataset_name")+":"+(String) dataset.get("dataset_version"); | |||||
| datasetList.add(datasetName); | datasetList.add(datasetName); | ||||
| } | } | ||||
| break; | |||||
| } | } | ||||
| } | } | ||||
| return datasetList; | return datasetList; | ||||
| } | } | ||||
| } | } | ||||
| @@ -251,16 +251,14 @@ public class ExperimentServiceImpl implements ExperimentService { | |||||
| if (data == null || MapUtils.isEmpty(data)) { | if (data == null || MapUtils.isEmpty(data)) { | ||||
| throw new RuntimeException("Failed to run workflow."); | throw new RuntimeException("Failed to run workflow."); | ||||
| } | } | ||||
| //获取训练参数 | |||||
| Map<String, Object> metricRecord = (Map<String, Object>) runResMap.get("metric_record"); | |||||
| Map<String, Object> metadata = (Map<String, Object>) data.get("metadata"); | Map<String, Object> metadata = (Map<String, Object>) data.get("metadata"); | ||||
| // 插入记录到实验实例表 | // 插入记录到实验实例表 | ||||
| ExperimentIns experimentIns = new ExperimentIns(); | ExperimentIns experimentIns = new ExperimentIns(); | ||||
| if (metricRecord != null){ | |||||
| experimentIns.setMetricRecord(JacksonUtil.toJSONString(metricRecord)); | |||||
| } | |||||
| //获取训练参数 | |||||
| experimentIns.setExperimentId(experiment.getId()); | experimentIns.setExperimentId(experiment.getId()); | ||||
| experimentIns.setArgoInsNs((String) metadata.get("namespace")); | experimentIns.setArgoInsNs((String) metadata.get("namespace")); | ||||
| experimentIns.setArgoInsName((String) metadata.get("name")); | experimentIns.setArgoInsName((String) metadata.get("name")); | ||||
| @@ -271,14 +269,22 @@ public class ExperimentServiceImpl implements ExperimentService { | |||||
| //替换argoInsName | //替换argoInsName | ||||
| String outputString = JsonUtils.mapToJson(output); | String outputString = JsonUtils.mapToJson(output); | ||||
| experimentIns.setNodesResult(outputString.replace("{{workflow.name}}", (String) metadata.get("name"))); | experimentIns.setNodesResult(outputString.replace("{{workflow.name}}", (String) metadata.get("name"))); | ||||
| //插入ExperimentIns表中 | |||||
| ExperimentIns insert = experimentInsService.insert(experimentIns); | |||||
| //插入到模型依赖关系表 | |||||
| //得到dependendcy | //得到dependendcy | ||||
| Map<String, Object> converMap2 = JsonUtils.jsonToMap(JacksonUtil.replaceInAarry(convertRes, params)); | Map<String, Object> converMap2 = JsonUtils.jsonToMap(JacksonUtil.replaceInAarry(convertRes, params)); | ||||
| Map<String ,Object> dependendcy = (Map<String, Object>)converMap2.get("model_dependency"); | Map<String ,Object> dependendcy = (Map<String, Object>)converMap2.get("model_dependency"); | ||||
| Map<String ,Object> trainInfo = (Map<String, Object>)converMap2.get("component_info"); | Map<String ,Object> trainInfo = (Map<String, Object>)converMap2.get("component_info"); | ||||
| Map<String, Object> metricRecord = (Map<String, Object>) runResMap.get("metric_record"); | |||||
| if (metricRecord != null){ | |||||
| //把训练用的数据集也放进去 | |||||
| addDatesetToMetric(metricRecord, trainInfo); | |||||
| experimentIns.setMetricRecord(JacksonUtil.toJSONString(metricRecord)); | |||||
| } | |||||
| //插入ExperimentIns表中 | |||||
| ExperimentIns insert = experimentInsService.insert(experimentIns); | |||||
| //插入到模型依赖关系表 | |||||
| if (dependendcy != null && trainInfo != null){ | if (dependendcy != null && trainInfo != null){ | ||||
| insertModelDependency(dependendcy,trainInfo,insert.getId(),experiment.getName()); | insertModelDependency(dependendcy,trainInfo,insert.getId(),experiment.getName()); | ||||
| } | } | ||||
| @@ -289,6 +295,37 @@ public class ExperimentServiceImpl implements ExperimentService { | |||||
| experiment.setExperimentInsList(updatedExperimentInsList); | experiment.setExperimentInsList(updatedExperimentInsList); | ||||
| return experiment; | return experiment; | ||||
| } | } | ||||
| private void addDatesetToMetric(Map<String, Object> metricRecord, Map<String, Object> trainInfo) { | |||||
| processMetricPart(metricRecord, trainInfo, "train", "model_train"); | |||||
| processMetricPart(metricRecord, trainInfo, "evaluate", "model_evaluate"); | |||||
| } | |||||
| private void processMetricPart(Map<String, Object> metricRecord, Map<String, Object> trainInfo, String metricKey, String trainInfoKey) { | |||||
| List<Map<String, Object>> metricList = (List<Map<String, Object>>) metricRecord.get(metricKey); | |||||
| if (metricList != null) { | |||||
| for (Map<String, Object> metricRecordItem : metricList) { | |||||
| String taskId = (String) metricRecordItem.get("task_id"); | |||||
| Map<String, Object> trainInfoPart = (Map<String, Object>) trainInfo.get(trainInfoKey); | |||||
| if (trainInfoPart != null) { | |||||
| Map<String, Object> trainInfoDetails = (Map<String, Object>) trainInfoPart.get(taskId); | |||||
| if (trainInfoDetails != null) { | |||||
| List<Map<String, Object>> datasets = (List<Map<String, Object>>) trainInfoDetails.get("datasets"); | |||||
| if (datasets != null) { | |||||
| //查询名字再回填 | |||||
| for (int i = 0; i < datasets.size(); i++) { | |||||
| Dataset dataset = datasetService.queryById((Integer) datasets.get(i).get("dataset_id")); | |||||
| datasets.get(i).put("dataset_name", dataset.getName()); | |||||
| } | |||||
| metricRecordItem.put("datasets", datasets); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| private void insertModelDependency(Map<String ,Object> dependendcy,Map<String ,Object> trainInfo, Integer experimentInsId, String experimentName) throws Exception { | private void insertModelDependency(Map<String ,Object> dependendcy,Map<String ,Object> trainInfo, Integer experimentInsId, String experimentName) throws Exception { | ||||
| Iterator<Map.Entry<String, Object>> dependendcyIterator = dependendcy.entrySet().iterator(); | Iterator<Map.Entry<String, Object>> dependendcyIterator = dependendcy.entrySet().iterator(); | ||||
| Map<String, Object> modelTrain = (Map<String, Object>) trainInfo.get("model_train"); | Map<String, Object> modelTrain = (Map<String, Object>) trainInfo.get("model_train"); | ||||