Browse Source

新增指标对比

pull/95/head
fanshuai 1 year ago
parent
commit
884ad765bf
2 changed files with 63 additions and 44 deletions
  1. +18
    -36
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java
  2. +45
    -8
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentServiceImpl.java

+ 18
- 36
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java View File

@@ -23,8 +23,6 @@ import java.util.stream.Collectors;
public class AimServiceImpl implements AimService {
@Resource
private ExperimentInsService experimentInsService;
@Resource
private ModelDependencyService modelDependencyService;

@Value("${aim.url}")
private String aimUrl;
@@ -44,7 +42,7 @@ public class AimServiceImpl implements AimService {
@Override
public String getExpMetrics(List<String> runIds) throws Exception {
String decode = AIM64EncoderUtil.decode(runIds);
return aimUrl+"/api/runs/search/run?query="+decode;
return aimUrl+"/metrics?select="+decode;
}

private List<InsMetricInfoVo> getAimRunInfos(boolean isTrain,Integer experimentId) throws Exception {
@@ -56,6 +54,7 @@ public class AimServiceImpl implements AimService {
String url = aimProxyUrl+"/api/runs/search/run?query="+encodedUrlString;
String s = HttpUtils.sendGetRequest(url);
List<Map<String, Object>> response = JacksonUtil.parseJSONStr2MapList(s);
System.out.println("response: "+JacksonUtil.toJSONString(response));
if (response == null || response.size() == 0){
return new ArrayList<>();
}
@@ -103,20 +102,15 @@ public class AimServiceImpl implements AimService {
aimRunInfo.setStatus(ins.getStatus());
aimRunInfo.setStartTime(ins.getCreateTime());
Map<String, Object> metricRecordMap = JacksonUtil.parseJSONStr2Map(metricRecordString);

//metricRecord 格式为{"train":[{"task_id":"model-train-35303690","run_id":"5560d78f54314672b60304c8d6ba03b8","experiment_name":"experiment-30-train"}],"evaluate":[{"task_id":"model-train-35303690","run_id":"5560d78f54314672b60304c8d6ba03b8","experiment_name":"experiment-30-train"}]}
//遍历metricRecord,找到当前task_id对应的ModelDependency

if (isTrain){
List<Map<String, Object>> trainList = (List<Map<String, Object>>) metricRecordMap.get("train");
List<String> trainDateSet = getTrainDateSet(trainList, ins.getId(), isTrain);
aimRunInfo.setDataset(trainDateSet);
List<Map<String, Object>> records = (List<Map<String, Object>>) metricRecordMap.get("train");
List<String> datasetList = getTrainDateSet(records, aimrunId);
aimRunInfo.setDataset(datasetList);
}else {
List<Map<String, Object>> trainList = (List<Map<String, Object>>) metricRecordMap.get("evaluate");
List<String> trainDateSet = getTrainDateSet(trainList, ins.getId(), isTrain);
aimRunInfo.setDataset(trainDateSet);
List<Map<String, Object>> records = (List<Map<String, Object>>) metricRecordMap.get("evaluate");
List<String> datasetList = getTrainDateSet(records, aimrunId);
aimRunInfo.setDataset(datasetList);
}

}
}
aimRunInfoList.add(aimRunInfo);
@@ -143,33 +137,21 @@ public class AimServiceImpl implements AimService {
}


private List<String> getTrainDateSet(List<Map<String, Object>> trainList,Integer expInsId,boolean isTrain){
if (trainList == null || trainList.size() == 0){
return new ArrayList<>();
}
private List<String> getTrainDateSet(List<Map<String, Object>> records, String aimrunId){
List<String> datasetList = new ArrayList<>();
for (Map<String, Object> trainMap : trainList) {
String task_id = (String) trainMap.get("task_id");
//modelDependency取到数据集文件
ModelDependency modelDependency = modelDependencyService.queryByInsAndTrainTaskId(expInsId, task_id);
//把数据集文件组装成String后放进List
String datasetString = "";
if (isTrain){
datasetString = modelDependency.getTrainDataset();
}else {
datasetString = modelDependency.getTestDataset();
}
List<Map<String, Object>> datasetListMap = JacksonUtil.parseJSONStr2MapList(datasetString);

if (datasetListMap != null && datasetListMap.size() > 0){
for (Map<String, Object> datasetMap : datasetListMap) {
//[{"dataset_id":20,"dataset_version":"v0.1.0","dataset_name":"手写体识别模型依赖测试训练数据集"}]
String datasetName = (String) datasetMap.get("dataset_name")+":"+(String) datasetMap.get("dataset_version");
for (Map<String, Object> record : records) {
if (StringUtils.equals(aimrunId, (String)record.get("run_id"))) {
List<Map<String, Object>> datasets = (List<Map<String, Object>>) record.get("datasets");
if (datasets == null || datasets.size() == 0){
continue;
}
for (Map<String, Object> dataset : datasets){
String datasetName = (String) dataset.get("dataset_name")+":"+(String) dataset.get("dataset_version");
datasetList.add(datasetName);
}
break;
}
}
return datasetList;
}

}

+ 45
- 8
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentServiceImpl.java View File

@@ -251,16 +251,14 @@ public class ExperimentServiceImpl implements ExperimentService {
if (data == null || MapUtils.isEmpty(data)) {
throw new RuntimeException("Failed to run workflow.");
}
//获取训练参数
Map<String, Object> metricRecord = (Map<String, Object>) runResMap.get("metric_record");



Map<String, Object> metadata = (Map<String, Object>) data.get("metadata");
// 插入记录到实验实例表
ExperimentIns experimentIns = new ExperimentIns();
if (metricRecord != null){
experimentIns.setMetricRecord(JacksonUtil.toJSONString(metricRecord));
}
//获取训练参数

experimentIns.setExperimentId(experiment.getId());
experimentIns.setArgoInsNs((String) metadata.get("namespace"));
experimentIns.setArgoInsName((String) metadata.get("name"));
@@ -271,14 +269,22 @@ public class ExperimentServiceImpl implements ExperimentService {
//替换argoInsName
String outputString = JsonUtils.mapToJson(output);
experimentIns.setNodesResult(outputString.replace("{{workflow.name}}", (String) metadata.get("name")));
//插入ExperimentIns表中
ExperimentIns insert = experimentInsService.insert(experimentIns);

//插入到模型依赖关系表

//得到dependendcy
Map<String, Object> converMap2 = JsonUtils.jsonToMap(JacksonUtil.replaceInAarry(convertRes, params));
Map<String ,Object> dependendcy = (Map<String, Object>)converMap2.get("model_dependency");
Map<String ,Object> trainInfo = (Map<String, Object>)converMap2.get("component_info");

Map<String, Object> metricRecord = (Map<String, Object>) runResMap.get("metric_record");
if (metricRecord != null){
//把训练用的数据集也放进去
addDatesetToMetric(metricRecord, trainInfo);
experimentIns.setMetricRecord(JacksonUtil.toJSONString(metricRecord));
}
//插入ExperimentIns表中
ExperimentIns insert = experimentInsService.insert(experimentIns);
//插入到模型依赖关系表
if (dependendcy != null && trainInfo != null){
insertModelDependency(dependendcy,trainInfo,insert.getId(),experiment.getName());
}
@@ -289,6 +295,37 @@ public class ExperimentServiceImpl implements ExperimentService {
experiment.setExperimentInsList(updatedExperimentInsList);
return experiment;
}
private void addDatesetToMetric(Map<String, Object> metricRecord, Map<String, Object> trainInfo) {
processMetricPart(metricRecord, trainInfo, "train", "model_train");
processMetricPart(metricRecord, trainInfo, "evaluate", "model_evaluate");
}

private void processMetricPart(Map<String, Object> metricRecord, Map<String, Object> trainInfo, String metricKey, String trainInfoKey) {
List<Map<String, Object>> metricList = (List<Map<String, Object>>) metricRecord.get(metricKey);
if (metricList != null) {
for (Map<String, Object> metricRecordItem : metricList) {
String taskId = (String) metricRecordItem.get("task_id");
Map<String, Object> trainInfoPart = (Map<String, Object>) trainInfo.get(trainInfoKey);
if (trainInfoPart != null) {
Map<String, Object> trainInfoDetails = (Map<String, Object>) trainInfoPart.get(taskId);
if (trainInfoDetails != null) {
List<Map<String, Object>> datasets = (List<Map<String, Object>>) trainInfoDetails.get("datasets");
if (datasets != null) {
//查询名字再回填
for (int i = 0; i < datasets.size(); i++) {
Dataset dataset = datasetService.queryById((Integer) datasets.get(i).get("dataset_id"));
datasets.get(i).put("dataset_name", dataset.getName());
}
metricRecordItem.put("datasets", datasets);
}
}
}
}
}
}



private void insertModelDependency(Map<String ,Object> dependendcy,Map<String ,Object> trainInfo, Integer experimentInsId, String experimentName) throws Exception {
Iterator<Map.Entry<String, Object>> dependendcyIterator = dependendcy.entrySet().iterator();
Map<String, Object> modelTrain = (Map<String, Object>) trainInfo.get("model_train");


Loading…
Cancel
Save