Browse Source

Merge remote-tracking branch 'origin/dev' into dev-service-czh

dev-lhz
chenzhihang 1 year ago
parent
commit
9c9291c25a
9 changed files with 251 additions and 57 deletions
  1. +0
    -1
      ruoyi-modules/management-platform/pom.xml
  2. +4
    -4
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/aim/AimController.java
  3. +4
    -2
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AimService.java
  4. +100
    -11
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java
  5. +19
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentInsServiceImpl.java
  6. +29
    -13
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentServiceImpl.java
  7. +22
    -18
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelsServiceImpl.java
  8. +71
    -8
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/YamlUtils.java
  9. +2
    -0
      ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/ModelDependency1DaoMapper.xml

+ 0
- 1
ruoyi-modules/management-platform/pom.xml View File

@@ -242,7 +242,6 @@
<artifactId>jedis</artifactId>
<version>3.6.0</version>
</dependency>

</dependencies>

<build>


+ 4
- 4
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/aim/AimController.java View File

@@ -24,15 +24,15 @@ public class AimController extends BaseController {
@GetMapping("/getExpTrainInfos/{experiment_id}")
@ApiOperation("获取当前实验的模型训练指标信息")
@ApiResponse
public GenericsAjaxResult<List<InsMetricInfoVo>> getExpTrainInfos(@PathVariable("experiment_id") Integer experimentId, @RequestParam("run_id") String runId) throws Exception {
return genericsSuccess(aimService.getExpTrainInfos(experimentId, runId));
public GenericsAjaxResult<List<InsMetricInfoVo>> getExpTrainInfos(@PathVariable("experiment_id") Integer experimentId) throws Exception {
return genericsSuccess(aimService.getExpTrainInfos(experimentId));
}

@GetMapping("/getExpEvaluateInfos/{experiment_id}")
@ApiOperation("获取当前实验的模型推理指标信息")
@ApiResponse
public GenericsAjaxResult<List<InsMetricInfoVo>> getExpEvaluateInfos(@PathVariable("experiment_id") Integer experimentId, @RequestParam("run_id") String runId) throws Exception {
return genericsSuccess(aimService.getExpEvaluateInfos(experimentId, runId));
public GenericsAjaxResult<List<InsMetricInfoVo>> getExpEvaluateInfos(@PathVariable("experiment_id") Integer experimentId) throws Exception {
return genericsSuccess(aimService.getExpEvaluateInfos(experimentId));
}

@PostMapping("/getExpMetrics")


+ 4
- 2
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AimService.java View File

@@ -6,9 +6,11 @@ import java.util.List;

public interface AimService {

List<InsMetricInfoVo> getExpTrainInfos(Integer experimentId, String runId) throws Exception;
List<InsMetricInfoVo> getExpTrainInfos(Integer experimentId) throws Exception;

List<InsMetricInfoVo> getExpEvaluateInfos(Integer experimentId, String runId) throws Exception;
List<InsMetricInfoVo> getExpTrainInfos1(boolean isTrain, Integer experimentId, String runId) throws Exception;

List<InsMetricInfoVo> getExpEvaluateInfos(Integer experimentId) throws Exception;

String getExpMetrics(List<String> runIds) throws Exception;
}

+ 100
- 11
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java View File

@@ -30,13 +30,13 @@ public class AimServiceImpl implements AimService {
private NewHttpUtils httpUtils;

@Override
public List<InsMetricInfoVo> getExpTrainInfos(Integer experimentId, String runId) throws Exception {
return getAimRunInfos(true, experimentId, runId);
public List<InsMetricInfoVo> getExpTrainInfos(Integer experimentId) throws Exception {
return getAimRunInfos(true, experimentId);
}

@Override
public List<InsMetricInfoVo> getExpEvaluateInfos(Integer experimentId, String runId) throws Exception {
return getAimRunInfos(false, experimentId, runId);
public List<InsMetricInfoVo> getExpEvaluateInfos(Integer experimentId) throws Exception {
return getAimRunInfos(false, experimentId);
}

@Override
@@ -45,13 +45,12 @@ public class AimServiceImpl implements AimService {
return aimUrl + "/metrics?select=" + decode;
}

private List<InsMetricInfoVo> getAimRunInfos(boolean isTrain, Integer experimentId, String runId) throws Exception {
// String experimentName = "experiment-" + experimentId + "-train";
// if (!isTrain) {
// experimentName = "experiment-" + experimentId + "-evaluate";
// }
// String encodedUrlString = URLEncoder.encode("run.experiment==\"" + experimentName + "\"", "UTF-8");
String encodedUrlString = URLEncoder.encode("run.id==\"" + runId + "\"", "UTF-8");
private List<InsMetricInfoVo> getAimRunInfos(boolean isTrain, Integer experimentId) throws Exception {
String experimentName = "experiment-" + experimentId + "-train";
if (!isTrain) {
experimentName = "experiment-" + experimentId + "-evaluate";
}
String encodedUrlString = URLEncoder.encode("run.experiment==\"" + experimentName + "\"", "UTF-8");
String url = aimProxyUrl + "/api/runs/search/run?query=" + encodedUrlString;
String s = httpUtils.sendGet(url, null);
List<Map<String, Object>> response = JacksonUtil.parseJSONStr2MapList(s);
@@ -139,6 +138,96 @@ public class AimServiceImpl implements AimService {
}


@Override
public List<InsMetricInfoVo> getExpTrainInfos1(boolean isTrain, Integer experimentId, String runId) throws Exception {
String encodedUrlString = URLEncoder.encode("run.id==\"" + runId + "\"", "UTF-8");
String url = aimProxyUrl + "/api/runs/search/run?query=" + encodedUrlString;
String s = httpUtils.sendGet(url, null);
List<Map<String, Object>> response = JacksonUtil.parseJSONStr2MapList(s);
System.out.println("response: " + JacksonUtil.toJSONString(response));
if (response == null || response.size() == 0) {
return new ArrayList<>();
}
//查询实例数据
List<ExperimentIns> byExperimentId = experimentInsService.queryByExperimentId(experimentId);

if (byExperimentId == null || byExperimentId.size() == 0) {
return new ArrayList<>();
}
List<InsMetricInfoVo> aimRunInfoList = new ArrayList<>();
for (Map<String, Object> run : response) {
InsMetricInfoVo aimRunInfo = new InsMetricInfoVo();
String runHash = (String) run.get("run_hash");

aimRunInfo.setRunId(runHash);

Map params = (Map) run.get("params");
Map<String, Object> paramMap = JsonUtils.flattenJson("", params);
aimRunInfo.setParams(paramMap);
String aimrunId = (String) paramMap.get("id");
Map<String, Object> tracesMap = (Map<String, Object>) run.get("traces");
List<Map<String, Object>> metricList = (List<Map<String, Object>>) tracesMap.get("metric");
//过滤name为__system__开头的对象
aimRunInfo.setMetrics(new HashMap<>());
if (metricList != null && metricList.size() > 0) {
List<Map<String, Object>> metricRelList = metricList.stream()
.filter(map -> !StringUtils.startsWith((String) map.get("name"), "__system__"))
.collect(Collectors.toList());
if (metricRelList != null && metricRelList.size() > 0) {
Map<String, Object> relMetricMap = new HashMap<>();
for (Map<String, Object> metricMap : metricRelList) {
relMetricMap.put((String) metricMap.get("name"), metricMap.get("last_value"));
}
aimRunInfo.setMetrics(relMetricMap);
}
}
//找到ins
for (ExperimentIns ins : byExperimentId) {
String metricRecordString = ins.getMetricRecord();
if (StringUtils.isEmpty(metricRecordString)) {
continue;
}
if (metricRecordString.contains(aimrunId)) {
aimRunInfo.setExperimentInsId(ins.getId());
aimRunInfo.setStatus(ins.getStatus());
aimRunInfo.setStartTime(ins.getCreateTime());
Map<String, Object> metricRecordMap = JacksonUtil.parseJSONStr2Map(metricRecordString);
if (isTrain) {
List<Map<String, Object>> records = (List<Map<String, Object>>) metricRecordMap.get("train");
List<String> datasetList = getTrainDateSet(records, aimrunId);
aimRunInfo.setDataset(datasetList);
} else {
List<Map<String, Object>> records = (List<Map<String, Object>>) metricRecordMap.get("evaluate");
List<String> datasetList = getTrainDateSet(records, aimrunId);
aimRunInfo.setDataset(datasetList);
}
aimRunInfoList.add(aimRunInfo);
}
}
}

//判断哪个最长

// 获取所有 metrics 的 key 的并集
Set<String> metricsKeys = (Set<String>) aimRunInfoList.stream()
.map(InsMetricInfoVo::getMetrics)
.flatMap(metrics -> metrics.keySet().stream())
.collect(Collectors.toSet());
// 将并集赋值给每个 InsMetricInfoVo 的 metricsNames 属性
aimRunInfoList.forEach(vo -> vo.setMetricsNames(new ArrayList<>(metricsKeys)));

// 获取所有 params 的 key 的并集
Set<String> paramKeys = (Set<String>) aimRunInfoList.stream()
.map(InsMetricInfoVo::getParams)
.flatMap(params -> params.keySet().stream())
.collect(Collectors.toSet());
// 将并集赋值给每个 InsMetricInfoVo 的 paramsNames 属性
aimRunInfoList.forEach(vo -> vo.setParamsNames(new ArrayList<>(paramKeys)));

return aimRunInfoList;
}


private List<String> getTrainDateSet(List<Map<String, Object>> records, String aimrunId) {
List<String> datasetList = new ArrayList<>();
for (Map<String, Object> record : records) {


+ 19
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentInsServiceImpl.java View File

@@ -372,6 +372,25 @@ public class ExperimentInsServiceImpl implements ExperimentInsService {
if (errCode != null && errCode == 0) {
//更新experimentIns,确保状态更新被保存到数据库
ExperimentIns ins = queryStatusFromArgo(experimentIns);
String nodesStatus = ins.getNodesStatus();
Map<String, Object> nodeMap = JsonUtils.jsonToMap(nodesStatus);

// 遍历 map
for (Map.Entry<String, Object> entry : nodeMap.entrySet()) {
// 获取每个 Map 中的值并强制转换为 Map
Map<String, Object> innerMap = (Map<String, Object>) entry.getValue();

// 检查 phase 的值
if (innerMap.containsKey("phase")) {
String phaseValue = (String) innerMap.get("phase");

// 如果值不等于 Succeeded,则赋值为 Failed
if (!StringUtils.equals("Succeeded", phaseValue)) {
innerMap.put("phase", "Failed");
}
}
}
ins.setNodesStatus(JsonUtils.mapToJson(nodeMap));
ins.setStatus("Terminated");
ins.setFinishTime(new Date());
this.experimentInsDao.update(ins);


+ 29
- 13
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentServiceImpl.java View File

@@ -13,13 +13,12 @@ import com.ruoyi.platform.mapper.ExperimentDao;
import com.ruoyi.platform.mapper.ExperimentInsDao;
import com.ruoyi.platform.mapper.ModelDependency1Dao;
import com.ruoyi.platform.service.*;
import com.ruoyi.platform.utils.FileUtil;
import com.ruoyi.platform.utils.HttpUtils;
import com.ruoyi.platform.utils.JacksonUtil;
import com.ruoyi.platform.utils.JsonUtils;
import com.ruoyi.platform.utils.YamlUtils;
import com.ruoyi.platform.vo.ModelsVo;
import com.ruoyi.platform.vo.NewDatasetVo;
import com.ruoyi.platform.vo.VersionVo;
import com.ruoyi.system.api.model.LoginUser;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang3.StringUtils;
@@ -34,6 +33,7 @@ import org.springframework.stereotype.Service;
import javax.annotation.Resource;
import java.io.IOException;
import java.lang.reflect.Field;
import java.math.BigDecimal;
import java.util.*;

/**
@@ -457,8 +457,6 @@ public class ExperimentServiceImpl implements ExperimentService {
* 存储数据集元数据到临时表
*/
private void insertDatasetTempStorage(Map<String, Object> datasetDependendcy, Map<String, Object> trainInfo, Integer experimentId, Integer experimentInsId, String experimentName) {
DatasetTempStorage datasetTempStorage = new DatasetTempStorage();

Iterator<Map.Entry<String, Object>> dependendcyIterator = datasetDependendcy.entrySet().iterator();
Map<String, Object> datasetExport = (Map<String, Object>) trainInfo.get("dataset_export");
Map<String, Object> datasetPreprocess = (Map<String, Object>) trainInfo.get("general-data-process");
@@ -472,20 +470,33 @@ public class ExperimentServiceImpl implements ExperimentService {
String sourceTaskId = (String) source.get("task_id");
Map<String, Object> datasetPreprocessMap = (Map<String, Object>) datasetPreprocess.get(sourceTaskId);
//处理project数据
Map<String, Object> projectMap = (Map<String, Object>) datasetPreprocessMap.get("project");
Map<String, Object> datasets = (Map<String, Object>) datasetPreprocessMap.get("datasets");
datasetTempStorage.setName((String) datasets.get("dataset_identifier"));
datasetTempStorage.setVersion((String) datasets.get("dataset_version"));
// 拼接需要的参数
Map<String, Object> projectMap = (Map<String, Object>) datasetPreprocessMap.get("project");
Map<String, Object> sourceParams = new HashMap<>();
sourceParams.put("experiment_name", experimentName);
sourceParams.put("experiment_ins_id", experimentInsId);
sourceParams.put("experiment_id", experimentId);
sourceParams.put("train_name", sourceTaskId);
sourceParams.put("preprocess_code", projectMap);
datasetTempStorage.setSource(JacksonUtil.toJSONString(sourceParams));
datasetTempStorage.setState(1);
datasetTempStorageService.insert(datasetTempStorage);

if (target != null && target.size() > 0) {
for (Map<String, Object> targetMap : target) {
String targetTaskId = (String) targetMap.get("task_id");
Map<String, Object> datasetExportMap = (Map<String, Object>) datasetExport.get(targetTaskId);
List<Map> datasets = (List<Map>) datasetExportMap.get("datasets");
if (datasets != null) {
for (Map<String, Object> dataset : datasets) {
DatasetTempStorage datasetTempStorage = new DatasetTempStorage();
datasetTempStorage.setName((String) dataset.get("dataset_identifier"));
datasetTempStorage.setVersion((String) dataset.get("dataset_version"));
datasetTempStorage.setSource(JacksonUtil.toJSONString(sourceParams));
datasetTempStorage.setState(1);
datasetTempStorageService.insert(datasetTempStorage);
}
}

}
}
}

}
@@ -534,8 +545,13 @@ public class ExperimentServiceImpl implements ExperimentService {
for (int i = 0; i < jsonArray.size(); i++) {
JSONObject jsonObject = jsonArray.getJSONObject(i);
String paramName = jsonObject.getString("param_name");
Double paramValue = jsonObject.getDouble("param_value");
trainParam.put(paramName, paramValue);
String paramValue = jsonObject.getString("param_value");
if (YamlUtils.isNumeric(paramValue)) {
BigDecimal bigDecimal = new BigDecimal(paramValue);
trainParam.put(paramName, bigDecimal);
} else {
trainParam.put(paramName, paramValue);
}
}
modelMetaVo.setParams(trainParam);



+ 22
- 18
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelsServiceImpl.java View File

@@ -985,8 +985,8 @@ public class ModelsServiceImpl implements ModelsService {
String jsonString = JSON.toJSONString(stringObjectMap);
ModelsVo modelsVo = JSON.parseObject(jsonString, ModelsVo.class);

List<VersionVo> versionVos = new ArrayList<>();
if (!fileDetailsAfterGitPull.isEmpty()) {
List<VersionVo> versionVos = new ArrayList<>();
for (Map<String, Object> fileDetail : fileDetailsAfterGitPull) {
VersionVo versionVo = new VersionVo();
versionVo.setUrl((String) fileDetail.get("filePath"));
@@ -995,8 +995,8 @@ public class ModelsServiceImpl implements ModelsService {
versionVo.setFileSize(FileUtil.formatFileSize(size));
versionVos.add(versionVo);
}
modelsVo.setModelVersionVos(versionVos);
}
modelsVo.setModelVersionVos(versionVos);

return modelsVo;
}
@@ -1162,28 +1162,32 @@ public class ModelsServiceImpl implements ModelsService {

HashMap<String, Object> metrics = modelMetaVo.getMetrics();
JSONArray trainMetrics = (JSONArray) metrics.get("train");
for (int i = 0; i < trainMetrics.size(); i++) {
JSONObject jsonObject = trainMetrics.getJSONObject(i);
String runId = jsonObject.getString("run_id");
List<InsMetricInfoVo> expTrainInfos = aimsService.getExpTrainInfos(modelMetaVo.getTrainTask().getExperimentId(), runId);
for (InsMetricInfoVo expTrainInfo : expTrainInfos) {
Map metrics1 = expTrainInfo.getMetrics();
train.putAll(metrics1);
if (trainMetrics != null) {
for (int i = 0; i < trainMetrics.size(); i++) {
JSONObject jsonObject = trainMetrics.getJSONObject(i);
String runId = jsonObject.getString("run_id");
List<InsMetricInfoVo> expTrainInfos = aimsService.getExpTrainInfos1(true, modelMetaVo.getTrainTask().getExperimentId(), runId);
for (InsMetricInfoVo expTrainInfo : expTrainInfos) {
Map metrics1 = expTrainInfo.getMetrics();
train.putAll(metrics1);
}
}
result.put("train", train);
}
result.put("train", train);

JSONArray testMetrics = (JSONArray) metrics.get("evaluate");
for (int i = 0; i < testMetrics.size(); i++) {
JSONObject jsonObject = testMetrics.getJSONObject(i);
String runId = jsonObject.getString("run_id");
List<InsMetricInfoVo> expTestInfos = aimsService.getExpEvaluateInfos(modelMetaVo.getTrainTask().getExperimentId(), runId);
for (InsMetricInfoVo expTestInfo : expTestInfos) {
Map metrics1 = expTestInfo.getMetrics();
evaluate.putAll(metrics1);
if (testMetrics != null) {
for (int i = 0; i < testMetrics.size(); i++) {
JSONObject jsonObject = testMetrics.getJSONObject(i);
String runId = jsonObject.getString("run_id");
List<InsMetricInfoVo> expTestInfos = aimsService.getExpTrainInfos1(false, modelMetaVo.getTrainTask().getExperimentId(), runId);
for (InsMetricInfoVo expTestInfo : expTestInfos) {
Map metrics1 = expTestInfo.getMetrics();
evaluate.putAll(metrics1);
}
}
result.put("evaluate", evaluate);
}
result.put("evaluate", evaluate);
modelMetaVo.setMetrics(result);
}
}

+ 71
- 8
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/YamlUtils.java View File

@@ -1,29 +1,37 @@
package com.ruoyi.platform.utils;

import com.alibaba.fastjson2.JSON;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
import com.fasterxml.jackson.dataformat.yaml.YAMLGenerator;
import org.springframework.stereotype.Component;
import org.yaml.snakeyaml.DumperOptions;
import org.yaml.snakeyaml.Yaml;
import org.yaml.snakeyaml.nodes.Node;
import org.yaml.snakeyaml.representer.Represent;
import org.yaml.snakeyaml.representer.Representer;

import java.io.*;
import java.util.Iterator;
import java.util.Map;
import org.yaml.snakeyaml.nodes.Tag;


//import org.ho.yaml.Yaml;

@Component
public class YamlUtils {

/**
* 将Map对象转换为YAML格式并写入指定路径的文件中
*
* @param data Map对象
* @param path 文件路径
* @param fileName 文件名
* @param data Map对象
* @param path 文件路径
* @param fileName 文件名
*/
public static void generateYamlFile(Map<String, Object> data, String path, String fileName) {
DumperOptions options = new DumperOptions();
options.setDefaultFlowStyle(DumperOptions.FlowStyle.BLOCK);
options.setDefaultScalarStyle(DumperOptions.ScalarStyle.PLAIN);


// 创建Yaml实例
Yaml yaml = new Yaml(options);

@@ -38,12 +46,58 @@ public class YamlUtils {
String fullPath = path + "/" + fileName + ".yaml";

try (FileWriter writer = new FileWriter(fullPath)) {
String dump = yaml.dump(data);

yaml.dump(data, writer);
} catch (IOException e) {
e.printStackTrace();
}
}


// public static void generateYamlFile1(Object data, String path, String fileName) {
// try {
// YAMLFactory yamlFactory = new YAMLFactory();
// yamlFactory.enable(YAMLGenerator.Feature.MINIMIZE_QUOTES);
// yamlFactory.disable(YAMLGenerator.Feature.WRITE_DOC_START_MARKER);
////
// ObjectMapper objectMapper = new ObjectMapper(yamlFactory);
//// ObjectMapper objectMapper = new ObjectMapper();
// objectMapper.setPropertyNamingStrategy(PropertyNamingStrategies.SNAKE_CASE);
// String s = objectMapper.writeValueAsString(data);
//
// File directory = new File(path);
// if (!directory.exists()) {
// boolean isCreated = directory.mkdirs();
// if (!isCreated) {
// throw new RuntimeException("创建路径失败: " + path);
// }
// }
//
// try {
// DumperOptions options = new DumperOptions();
// options.setDefaultFlowStyle(DumperOptions.FlowStyle.BLOCK);
// // 创建Yaml实例
// Yaml yaml = new Yaml(options);
//
// String fullPath = path + "/" + fileName + ".yaml";
// FileWriter writer = new FileWriter(fullPath);
//
// yaml.dump(s, writer);
// } catch (FileNotFoundException e) {
// e.printStackTrace();
//
// } catch (IOException e) {
// throw new RuntimeException(e);
// }
//
//
// } catch (JsonProcessingException e) {
// throw new RuntimeException(e);
// }
// }


/**
* 读取YAML文件并将其内容转换为Map<String, Object>
*
@@ -59,4 +113,13 @@ public class YamlUtils {
return null;
}
}

public static boolean isNumeric(String str) {
try {
double num = Double.parseDouble(str);
return true;
} catch (NumberFormatException e) {
return false;
}
}
}

+ 2
- 0
ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/ModelDependency1DaoMapper.xml View File

@@ -47,6 +47,7 @@
and identifier = #{identifier}
and version = #{version}
and state = 2
order by create_time desc limit 1
</select>

<select id="queryByTrainTask" resultType="com.ruoyi.platform.domain.ModelDependency1">
@@ -54,6 +55,7 @@
from model_dependency1
where JSON_CONTAINS(meta, #{trainTask})
and state = 2
order by create_time desc limit 1
</select>

<update id="deleteModel">


Loading…
Cancel
Save