Browse Source

Merge remote-tracking branch 'origin/dev-czh' into dev

dev-DXTZYK
chenzhihang 1 year ago
parent
commit
9e2d288fb8
1 changed files with 187 additions and 66 deletions
  1. +187
    -66
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentServiceImpl.java

+ 187
- 66
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentServiceImpl.java View File

@@ -1,5 +1,6 @@
package com.ruoyi.platform.service.impl;

import com.alibaba.fastjson2.JSON;
import com.ruoyi.common.security.utils.SecurityUtils;
import com.ruoyi.platform.annotations.CheckDuplicate;
import com.ruoyi.platform.domain.*;
@@ -7,14 +8,18 @@ import com.ruoyi.platform.domain.dependencydomain.ProjectDepency;
import com.ruoyi.platform.domain.dependencydomain.TrainTaskDepency;
import com.ruoyi.platform.mapper.ExperimentDao;
import com.ruoyi.platform.mapper.ExperimentInsDao;
import com.ruoyi.platform.mapper.ModelDependency1Dao;
import com.ruoyi.platform.service.*;
import com.ruoyi.platform.utils.HttpUtils;
import com.ruoyi.platform.utils.JacksonUtil;
import com.ruoyi.platform.utils.JsonUtils;
import com.ruoyi.platform.utils.NewHttpUtils;
import com.ruoyi.platform.vo.ModelMetaVo;
import com.ruoyi.platform.vo.ModelsVo;
import com.ruoyi.platform.vo.NewDatasetVo;
import com.ruoyi.system.api.model.LoginUser;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Lazy;
import org.springframework.data.domain.Page;
@@ -56,15 +61,20 @@ public class ExperimentServiceImpl implements ExperimentService {
@Resource
@Lazy
private ExperimentInsService experimentInsService;

@Resource
private ModelDependency1Dao modelDependency1Dao;

@Value("${argo.url}")
private String argoUrl;
private String argoUrl;
@Value("${argo.convert}")
private String argoConvert;
private String argoConvert;
@Value("${argo.workflowRun}")
private String argoWorkflowRun;
private String argoWorkflowRun;
@Value("${argo.workflowStatus}")
private String argoWorkflowStatus;

private String argoWorkflowStatus;
@Value("${git.localPath}")
String localPath;

/**
* 通过ID查询单条数据
@@ -95,22 +105,20 @@ public class ExperimentServiceImpl implements ExperimentService {
List<Experiment> experimentList = this.experimentDao.queryAllByLimit(experiment, pageRequest);
long total = this.experimentDao.count(experiment);
// 存储所有实验的ID列表,查询实验对应的流水线
for(Experiment exp: experimentList){
for (Experiment exp : experimentList) {
Long workflowId = exp.getWorkflowId();
Workflow correspondingWorkflow = this.workflowService.queryById(workflowId);
String workflowName = correspondingWorkflow.getName();
exp.setWorkflowName(workflowName);
}

return new PageImpl<>(experimentList,pageRequest,total);
return new PageImpl<>(experimentList, pageRequest, total);
}


/**
* 分页查询实验状态
*
*
*
* @param experiment 筛选条件
* @param pageRequest 分页对象
* @return 查询结果
@@ -120,7 +128,7 @@ public class ExperimentServiceImpl implements ExperimentService {
// 存储所有实验的ID列表
List<Integer> experimentIds = new ArrayList<>();
//对于每一个从Experiment表中查询出来的id,可能有多个实例,每个实例的实验id是不同的
for (Experiment exp: experimentList) {
for (Experiment exp : experimentList) {
//返回所有实验ID相同的实例列表
List<ExperimentIns> experimentInsList = this.experimentInsService.getByExperimentId(exp.getId());
exp.setExperimentInsList(experimentInsList);
@@ -178,7 +186,7 @@ public class ExperimentServiceImpl implements ExperimentService {
@Override
public String removeById(Integer id) throws Exception {
Experiment experiment = experimentDao.queryById(id);
if (experiment==null){
if (experiment == null) {
throw new Exception("实验不存在");
}

@@ -186,17 +194,16 @@ public class ExperimentServiceImpl implements ExperimentService {
LoginUser loginUser = SecurityUtils.getLoginUser();
String username = loginUser.getUsername();
String createdBy = experiment.getCreateBy();
if (!(StringUtils.equals(username,"admin") || StringUtils.equals(username,createdBy))){
if (!(StringUtils.equals(username, "admin") || StringUtils.equals(username, createdBy))) {
throw new Exception("无权限删除该实验");
}

List<ExperimentIns> experimentInsList = experimentInsService.queryByExperimentId(experiment.getId());
if (experimentInsList!=null&&experimentInsList.size()>0){
if (experimentInsList != null && experimentInsList.size() > 0) {
throw new Exception("该实验存在实例,无法删除");
}
experiment.setState(0);
return this.experimentDao.update(experiment)>0?"删除成功":"删除失败";

return this.experimentDao.update(experiment) > 0 ? "删除成功" : "删除失败";


}
@@ -215,7 +222,7 @@ public class ExperimentServiceImpl implements ExperimentService {
System.out.println("No experiment");
}
Workflow workflow = workflowService.queryById(experiment.getWorkflowId());
if(workflow == null) {
if (workflow == null) {
throw new RuntimeException("流水线不存在,请先创建流水线");
}

@@ -236,10 +243,10 @@ public class ExperimentServiceImpl implements ExperimentService {
runReqMap.put("params", params);
// 实验字段的Map,不要写成一行!否则会返回null
Map<String, Object> experimentMap = new HashMap<>();
experimentMap.put("name", "experiment-"+experiment.getId());
experimentMap.put("name", "experiment-" + experiment.getId());
runReqMap.put("experiment", experimentMap);

Map<String ,Object> output = (Map<String, Object>) converMap.get("output");
Map<String, Object> output = (Map<String, Object>) converMap.get("output");
// 调argo运行接口
String runRes = HttpUtils.sendPost(argoUrl + argoWorkflowRun, JsonUtils.mapToJson(runReqMap));

@@ -254,7 +261,6 @@ public class ExperimentServiceImpl implements ExperimentService {
}



Map<String, Object> metadata = (Map<String, Object>) data.get("metadata");
// 插入记录到实验实例表
ExperimentIns experimentIns = new ExperimentIns();
@@ -274,11 +280,11 @@ public class ExperimentServiceImpl implements ExperimentService {

//得到dependendcy
Map<String, Object> converMap2 = JsonUtils.jsonToMap(JacksonUtil.replaceInAarry(convertRes, params));
Map<String ,Object> dependendcy = (Map<String, Object>)converMap2.get("model_dependency");
Map<String ,Object> trainInfo = (Map<String, Object>)converMap2.get("component_info");
Map<String, Object> dependendcy = (Map<String, Object>) converMap2.get("model_dependency");
Map<String, Object> trainInfo = (Map<String, Object>) converMap2.get("component_info");

Map<String, Object> metricRecord = (Map<String, Object>) runResMap.get("metric_record");
if (metricRecord != null){
if (metricRecord != null) {
//把训练用的数据集也放进去
addDatesetToMetric(metricRecord, trainInfo);
experimentIns.setMetricRecord(JacksonUtil.toJSONString(metricRecord));
@@ -286,18 +292,18 @@ public class ExperimentServiceImpl implements ExperimentService {
//插入ExperimentIns表中
ExperimentIns insert = experimentInsService.insert(experimentIns);
//插入到模型依赖关系表
if (dependendcy != null && trainInfo != null){
insertModelDependency(dependendcy,trainInfo,insert.getId(),experiment.getName());
if (dependendcy != null && trainInfo != null) {
insertModelDependencyNew(dependendcy, trainInfo, insert.getId(), experiment.getName());
}

Map<String ,Object> datasetDependendcy = (Map<String, Object>)converMap2.get("dataset_dependency");
Map<String, Object> datasetDependendcy = (Map<String, Object>) converMap2.get("dataset_dependency");
//暂存数据集元数据{}
if (datasetDependendcy != null && trainInfo != null){
insertDatasetTempStorage(datasetDependendcy,trainInfo,experiment.getId(),insert.getId(),experiment.getName());
if (datasetDependendcy != null && trainInfo != null) {
insertDatasetTempStorage(datasetDependendcy, trainInfo, experiment.getId(), insert.getId(), experiment.getName());
}


}catch (Exception e){
} catch (Exception e) {
throw new RuntimeException(e);
}
List<ExperimentIns> updatedExperimentInsList = experimentInsService.getByExperimentId(id);
@@ -305,12 +311,12 @@ public class ExperimentServiceImpl implements ExperimentService {
return experiment;
}

private void addDatesetToMetric(Map<String, Object> metricRecord, Map<String, Object> trainInfo) {
private void addDatesetToMetric(Map<String, Object> metricRecord, Map<String, Object> trainInfo) {
processMetricPart(metricRecord, trainInfo, "train", "model_train");
processMetricPart(metricRecord, trainInfo, "evaluate", "model_evaluate");
}

private void processMetricPart(Map<String, Object> metricRecord, Map<String, Object> trainInfo, String metricKey, String trainInfoKey) {
private void processMetricPart(Map<String, Object> metricRecord, Map<String, Object> trainInfo, String metricKey, String trainInfoKey) {
List<Map<String, Object>> metricList = (List<Map<String, Object>>) metricRecord.get(metricKey);
if (metricList != null) {
for (Map<String, Object> metricRecordItem : metricList) {
@@ -335,24 +341,23 @@ public class ExperimentServiceImpl implements ExperimentService {
}



private void insertModelDependency(Map<String ,Object> dependendcy,Map<String ,Object> trainInfo, Integer experimentInsId, String experimentName) throws Exception {
private void insertModelDependency(Map<String, Object> dependendcy, Map<String, Object> trainInfo, Integer experimentInsId, String experimentName) throws Exception {
Iterator<Map.Entry<String, Object>> dependendcyIterator = dependendcy.entrySet().iterator();
Map<String, Object> modelTrain = (Map<String, Object>) trainInfo.get("model_train");
Map<String, Object> modelEvaluate = (Map<String, Object>) trainInfo.get("model_evaluate");
Map<String, Object> modelExport = (Map<String, Object>) trainInfo.get("model_export");
Map<String, Object> modelTrain = (Map<String, Object>) trainInfo.get("model_train");
Map<String, Object> modelEvaluate = (Map<String, Object>) trainInfo.get("model_evaluate");
Map<String, Object> modelExport = (Map<String, Object>) trainInfo.get("model_export");
while (dependendcyIterator.hasNext()) {
ModelDependency modelDependency = new ModelDependency();
Map.Entry<String, Object> entry = dependendcyIterator.next();
Map<String, Object> modelDel = (Map<String, Object>) entry.getValue();
Map<String, Object> source = (Map<String, Object>) modelDel.get("source");
Map<String, Object> source = (Map<String, Object>) modelDel.get("source");
List<Map<String, Object>> test = (List<Map<String, Object>>) modelDel.get("test");
List<Map<String, Object>> target = (List<Map<String, Object>>) modelDel.get("target");
String sourceTaskId = (String) source.get("task_id");

Map<String, Object> modelTrainMap = (Map<String, Object>)modelTrain.get(sourceTaskId);
Map<String, Object> modelTrainMap = (Map<String, Object>) modelTrain.get(sourceTaskId);
//处理project数据
Map<String, String> projectMap = (Map<String, String>) modelTrainMap.get("project");
Map<String, String> projectMap = (Map<String, String>) modelTrainMap.get("project");
ProjectDepency projectDepency = new ProjectDepency();
projectDepency.setBranch(projectMap.get("branch"));
String projectUrl = projectMap.get("url");
@@ -361,7 +366,7 @@ public class ExperimentServiceImpl implements ExperimentService {
//依赖项目
modelDependency.setProjectDependency(JsonUtils.objectToJson(projectDepency));
//处理镜像
Map<String, String> imagesMap = (Map<String, String>) modelTrainMap.get("image");
Map<String, String> imagesMap = (Map<String, String>) modelTrainMap.get("image");
modelDependency.setTrainImage(imagesMap.get("name"));
List<Map<String, Object>> trainParamList = (List<Map<String, Object>>) modelTrainMap.get("params");
modelDependency.setTrainParams(JsonUtils.objectToJson(trainParamList));
@@ -402,24 +407,24 @@ public class ExperimentServiceImpl implements ExperimentService {
List<Map<String, Object>> resultTestDatasets = new ArrayList<Map<String, Object>>();
//处理test数据
if (test != null) {
for(int i=0;i<test.size();i++){
for (int i = 0; i < test.size(); i++) {
Map<String, Object> testMap = test.get(i);
String testTaskId = (String) testMap.get("task_id");
Map<String, Object> evaluateMap = (Map<String, Object>) modelEvaluate.get(testTaskId);
Map<String, Object> evaluateMap = (Map<String, Object>) modelEvaluate.get(testTaskId);
List<Map<String, Object>> realDataSetList = (List<Map<String, Object>>) evaluateMap.get("datasets");
for(int j=0;j<realDataSetList.size();j++){
for (int j = 0; j < realDataSetList.size(); j++) {
Map<String, Object> realDataSet = realDataSetList.get(j);
Dataset dataset = datasetService.queryById((Integer) realDataSet.get("dataset_id"));
if (dataset == null){
if (dataset == null) {
throw new Exception("源数据集不存在");
}
realDataSet.put("dataset_name", dataset.getName());
resultTestDatasets.add(realDataSet);
}
}
}

//测试数据集
modelDependency.setTestDataset(JsonUtils.objectToJson(resultTestDatasets));
//测试数据集
modelDependency.setTestDataset(JsonUtils.objectToJson(resultTestDatasets));
}
//处理target数据
if (target != null) {
@@ -437,7 +442,7 @@ public class ExperimentServiceImpl implements ExperimentService {
modelDependencyService.insert(modelDependency);
}
}
}else {
} else {
modelDependency.setState(2);
modelDependencyService.insert(modelDependency);
}
@@ -447,27 +452,25 @@ public class ExperimentServiceImpl implements ExperimentService {

/**
* 存储数据集元数据到临时表
*
*
*/
private void insertDatasetTempStorage(Map<String, Object> datasetDependendcy, Map<String, Object> trainInfo, Integer experimentId,Integer experimentInsId, String experimentName) {
*/
private void insertDatasetTempStorage(Map<String, Object> datasetDependendcy, Map<String, Object> trainInfo, Integer experimentId, Integer experimentInsId, String experimentName) {
DatasetTempStorage datasetTempStorage = new DatasetTempStorage();

Iterator<Map.Entry<String, Object>> dependendcyIterator = datasetDependendcy.entrySet().iterator();
Map<String, Object> datasetExport = (Map<String, Object>) trainInfo.get("dataset_export");
Map<String, Object> datasetPreprocess = (Map<String, Object>) trainInfo.get("dataset_preprocess");
Map<String, Object> datasetExport = (Map<String, Object>) trainInfo.get("dataset_export");
Map<String, Object> datasetPreprocess = (Map<String, Object>) trainInfo.get("dataset_preprocess");
while (dependendcyIterator.hasNext()) {
Map.Entry<String, Object> entry = dependendcyIterator.next();
Map<String, Object> modelDel = (Map<String, Object>) entry.getValue();
Map<String, Object> source = (Map<String, Object>) modelDel.get("source"); //被处理数据集
Map<String, Object> source = (Map<String, Object>) modelDel.get("source"); //被处理数据集
// List<Map<String, Object>> test = (List<Map<String, Object>>) modelDel.get("test");
List<Map<String, Object>> target = (List<Map<String, Object>>) modelDel.get("target"); //导出的数据集

String sourceTaskId = (String) source.get("task_id");
Map<String, Object> datasetPreprocessMap = (Map<String, Object>)datasetPreprocess.get(sourceTaskId);
Map<String, Object> datasetPreprocessMap = (Map<String, Object>) datasetPreprocess.get(sourceTaskId);
//处理project数据
Map<String, Object> projectMap = (Map<String, Object>) datasetPreprocessMap.get("project");
Map<String, Object> datasets = (Map<String, Object>) datasetPreprocessMap.get("datasets");
Map<String, Object> projectMap = (Map<String, Object>) datasetPreprocessMap.get("project");
Map<String, Object> datasets = (Map<String, Object>) datasetPreprocessMap.get("datasets");
datasetTempStorage.setName((String) datasets.get("dataset_identifier"));
datasetTempStorage.setVersion((String) datasets.get("dataset_version"));
// 拼接需要的参数
@@ -475,8 +478,8 @@ public class ExperimentServiceImpl implements ExperimentService {
sourceParams.put("experiment_name", experimentName);
sourceParams.put("experiment_ins_id", experimentInsId);
sourceParams.put("experiment_id", experimentId);
sourceParams.put("train_name",sourceTaskId);
sourceParams.put("preprocess_code",projectMap);
sourceParams.put("train_name", sourceTaskId);
sourceParams.put("preprocess_code", projectMap);
datasetTempStorage.setSource(JacksonUtil.toJSONString(sourceParams));
datasetTempStorage.setState(1);
datasetTempStorageService.insert(datasetTempStorage);
@@ -484,9 +487,131 @@ public class ExperimentServiceImpl implements ExperimentService {

}

private void insertModelDependencyNew(Map<String, Object> dependendcy, Map<String, Object> trainInfo, Integer experimentInsId, String experimentName) throws Exception {
Iterator<Map.Entry<String, Object>> dependendcyIterator = dependendcy.entrySet().iterator();
Map<String, Object> modelTrain = (Map<String, Object>) trainInfo.get("model_train");
Map<String, Object> modelEvaluate = (Map<String, Object>) trainInfo.get("model_evaluate");
Map<String, Object> modelExport = (Map<String, Object>) trainInfo.get("model_export");
while (dependendcyIterator.hasNext()) {
ModelsVo modelMetaVo = new ModelsVo();
ModelDependency1 modelDependency = new ModelDependency1();

Map.Entry<String, Object> entry = dependendcyIterator.next();
Map<String, Object> modelDel = (Map<String, Object>) entry.getValue();
Map<String, Object> source = (Map<String, Object>) modelDel.get("source");
List<Map<String, Object>> test = (List<Map<String, Object>>) modelDel.get("test");
List<Map<String, Object>> target = (List<Map<String, Object>>) modelDel.get("target");
String sourceTaskId = (String) source.get("task_id");

Map<String, Object> modelTrainMap = (Map<String, Object>) modelTrain.get(sourceTaskId);
//处理project数据
Map<String, String> projectMap = (Map<String, String>) modelTrainMap.get("project");
ProjectDepency projectDepency = new ProjectDepency();
projectDepency.setBranch(projectMap.get("branch"));
String projectUrl = projectMap.get("url");
projectDepency.setUrl(projectUrl);
projectDepency.setName(projectUrl.substring(projectUrl.lastIndexOf('/') + 1, projectUrl.length() - 4));
modelMetaVo.setProjectDepency(projectDepency);
//处理镜像
Map<String, String> imagesMap = (Map<String, String>) modelTrainMap.get("image");
modelMetaVo.setImage(imagesMap.get("name"));
//处理训练参数 todo
HashMap<String, Object> trainParam = (HashMap<String, Object>) modelTrainMap.get("params");

//处理source数据
List<Map<String, Object>> modelsList = (List<Map<String, Object>>) modelTrainMap.get("models");
if (modelsList != null) {
Map<String, Object> parentModelMap = modelsList.get(0);
String id = (String) parentModelMap.get("model_id");
String identifier = (String) parentModelMap.get("model_identifier");
String version = (String) parentModelMap.get("model_version");

HashMap<String, Object> map = new HashMap<>();
map.put("repoId", id);
map.put("identifier", identifier);
map.put("version", version);
String parentModel = JSON.toJSONString(map);
modelMetaVo.setParentModel(parentModel);
modelDependency.setParentModel(parentModel);
}

//训练数据集
List<Map<String, Object>> trainDatasetList = (List<Map<String, Object>>) modelTrainMap.get("datasets");
if (trainDatasetList != null) {
List<NewDatasetVo> trainDatasets = new ArrayList<>();
for (Map<String, Object> dataset : trainDatasetList) {
NewDatasetVo newDatasetVo = new NewDatasetVo();
newDatasetVo.setId((Integer) dataset.get("dataset_id"));
newDatasetVo.setName((String) dataset.get("dataset_name"));
newDatasetVo.setVersion((String) dataset.get("dataset_version"));
newDatasetVo.setIdentifier((String) dataset.get("dataset_identifier"));
//todo newDatasetVo.setowner
trainDatasets.add(newDatasetVo);
}
modelMetaVo.setTrainDatasets(trainDatasets);
}
//训练任务
TrainTaskDepency trainTask = new TrainTaskDepency();
trainTask.setTaskId(sourceTaskId);
trainTask.setInsId(experimentInsId);
trainTask.setName(experimentName);
modelMetaVo.setTrainTask(trainTask);

//处理test数据
if (test != null) {
for (Map<String, Object> testMap : test) {
String testTaskId = (String) testMap.get("task_id");
Map<String, Object> evaluateMap = (Map<String, Object>) modelEvaluate.get(testTaskId);
List<Map<String, Object>> testDatasetList = (List<Map<String, Object>>) evaluateMap.get("datasets");
List<NewDatasetVo> testDatasets = new ArrayList<>();
for (Map<String, Object> dataset : testDatasetList) {
NewDatasetVo newDatasetVo = new NewDatasetVo();
newDatasetVo.setId((Integer) dataset.get("dataset_id"));
newDatasetVo.setName((String) dataset.get("dataset_name"));
newDatasetVo.setVersion((String) dataset.get("dataset_version"));
newDatasetVo.setIdentifier((String) dataset.get("dataset_identifier"));
testDatasets.add(newDatasetVo);
}
modelMetaVo.setTestDatasets(testDatasets);
}
}

//处理target数据
LoginUser loginUser = SecurityUtils.getLoginUser();
String gitLinkUsername = loginUser.getSysUser().getGitLinkUsername();
modelMetaVo.setOwner(gitLinkUsername);
if (target != null) {
for (int i = 0; i < target.size(); i++) {
Map<String, Object> targetMap = target.get(i);
String targetTaskId = (String) targetMap.get("task_id");
Map<String, Object> exportMap = (Map<String, Object>) modelExport.get(targetTaskId);
List<Map<String, Object>> modelTargetList = (List<Map<String, Object>>) exportMap.get("models");
for (int j = 0; j < modelTargetList.size(); j++) {
Map<String, Object> targetModel = modelTargetList.get(i);
modelMetaVo.setId((Integer) targetModel.get("model_id"));
modelMetaVo.setIdentifier((String) targetModel.get("model_identifier"));
modelMetaVo.setName((String) targetModel.get("model_name"));
modelMetaVo.setVersionDesc((String) targetModel.get("model_version"));
modelMetaVo.setOwner(gitLinkUsername);

modelsService.newCreateVersion(modelMetaVo);
}
}
} else {
String meta = JSON.toJSONString(modelMetaVo);
modelDependency.setMeta(meta);
modelDependency.setOwner(gitLinkUsername);
modelDependency1Dao.insert(modelDependency);
}

}
}


/**
* 被废弃的旧JSON
* @param experiment
*
* @param experiment
* @return
* @throws Exception
*/
@@ -585,7 +710,6 @@ public class ExperimentServiceImpl implements ExperimentService {
// }
// }
// }

@Override
public Experiment addAndRunExperiment(Experiment experiment) throws Exception {
// 第一步: 调用add方法插入实验记录到数据库
@@ -596,7 +720,7 @@ public class ExperimentServiceImpl implements ExperimentService {
}
// 调用runExperiment方法运行实验
try {
newExperiment = this.runExperiment(newExperiment.getId());
newExperiment = this.runExperiment(newExperiment.getId());
} catch (Exception e) {
throw new RuntimeException(e);
}
@@ -605,7 +729,6 @@ public class ExperimentServiceImpl implements ExperimentService {
}



/**
* 返回实验配置参数
*
@@ -648,6 +771,4 @@ public class ExperimentServiceImpl implements ExperimentService {
}




}

Loading…
Cancel
Save