Browse Source

Merge remote-tracking branch 'origin/dev-czh' into dev

dev-DXTZYK
chenzhihang 1 year ago
parent
commit
9e2d288fb8
1 changed files with 187 additions and 66 deletions
  1. +187
    -66
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentServiceImpl.java

+ 187
- 66
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentServiceImpl.java View File

@@ -1,5 +1,6 @@
package com.ruoyi.platform.service.impl; package com.ruoyi.platform.service.impl;


import com.alibaba.fastjson2.JSON;
import com.ruoyi.common.security.utils.SecurityUtils; import com.ruoyi.common.security.utils.SecurityUtils;
import com.ruoyi.platform.annotations.CheckDuplicate; import com.ruoyi.platform.annotations.CheckDuplicate;
import com.ruoyi.platform.domain.*; import com.ruoyi.platform.domain.*;
@@ -7,14 +8,18 @@ import com.ruoyi.platform.domain.dependencydomain.ProjectDepency;
import com.ruoyi.platform.domain.dependencydomain.TrainTaskDepency; import com.ruoyi.platform.domain.dependencydomain.TrainTaskDepency;
import com.ruoyi.platform.mapper.ExperimentDao; import com.ruoyi.platform.mapper.ExperimentDao;
import com.ruoyi.platform.mapper.ExperimentInsDao; import com.ruoyi.platform.mapper.ExperimentInsDao;
import com.ruoyi.platform.mapper.ModelDependency1Dao;
import com.ruoyi.platform.service.*; import com.ruoyi.platform.service.*;
import com.ruoyi.platform.utils.HttpUtils; import com.ruoyi.platform.utils.HttpUtils;
import com.ruoyi.platform.utils.JacksonUtil; import com.ruoyi.platform.utils.JacksonUtil;
import com.ruoyi.platform.utils.JsonUtils; import com.ruoyi.platform.utils.JsonUtils;
import com.ruoyi.platform.utils.NewHttpUtils;
import com.ruoyi.platform.vo.ModelMetaVo;
import com.ruoyi.platform.vo.ModelsVo;
import com.ruoyi.platform.vo.NewDatasetVo;
import com.ruoyi.system.api.model.LoginUser; import com.ruoyi.system.api.model.LoginUser;
import org.apache.commons.collections4.MapUtils; import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Lazy; import org.springframework.context.annotation.Lazy;
import org.springframework.data.domain.Page; import org.springframework.data.domain.Page;
@@ -56,15 +61,20 @@ public class ExperimentServiceImpl implements ExperimentService {
@Resource @Resource
@Lazy @Lazy
private ExperimentInsService experimentInsService; private ExperimentInsService experimentInsService;

@Resource
private ModelDependency1Dao modelDependency1Dao;

@Value("${argo.url}") @Value("${argo.url}")
private String argoUrl;
private String argoUrl;
@Value("${argo.convert}") @Value("${argo.convert}")
private String argoConvert;
private String argoConvert;
@Value("${argo.workflowRun}") @Value("${argo.workflowRun}")
private String argoWorkflowRun;
private String argoWorkflowRun;
@Value("${argo.workflowStatus}") @Value("${argo.workflowStatus}")
private String argoWorkflowStatus;

private String argoWorkflowStatus;
@Value("${git.localPath}")
String localPath;


/** /**
* 通过ID查询单条数据 * 通过ID查询单条数据
@@ -95,22 +105,20 @@ public class ExperimentServiceImpl implements ExperimentService {
List<Experiment> experimentList = this.experimentDao.queryAllByLimit(experiment, pageRequest); List<Experiment> experimentList = this.experimentDao.queryAllByLimit(experiment, pageRequest);
long total = this.experimentDao.count(experiment); long total = this.experimentDao.count(experiment);
// 存储所有实验的ID列表,查询实验对应的流水线 // 存储所有实验的ID列表,查询实验对应的流水线
for(Experiment exp: experimentList){
for (Experiment exp : experimentList) {
Long workflowId = exp.getWorkflowId(); Long workflowId = exp.getWorkflowId();
Workflow correspondingWorkflow = this.workflowService.queryById(workflowId); Workflow correspondingWorkflow = this.workflowService.queryById(workflowId);
String workflowName = correspondingWorkflow.getName(); String workflowName = correspondingWorkflow.getName();
exp.setWorkflowName(workflowName); exp.setWorkflowName(workflowName);
} }


return new PageImpl<>(experimentList,pageRequest,total);
return new PageImpl<>(experimentList, pageRequest, total);
} }




/** /**
* 分页查询实验状态 * 分页查询实验状态
* *
*
*
* @param experiment 筛选条件 * @param experiment 筛选条件
* @param pageRequest 分页对象 * @param pageRequest 分页对象
* @return 查询结果 * @return 查询结果
@@ -120,7 +128,7 @@ public class ExperimentServiceImpl implements ExperimentService {
// 存储所有实验的ID列表 // 存储所有实验的ID列表
List<Integer> experimentIds = new ArrayList<>(); List<Integer> experimentIds = new ArrayList<>();
//对于每一个从Experiment表中查询出来的id,可能有多个实例,每个实例的实验id是不同的 //对于每一个从Experiment表中查询出来的id,可能有多个实例,每个实例的实验id是不同的
for (Experiment exp: experimentList) {
for (Experiment exp : experimentList) {
//返回所有实验ID相同的实例列表 //返回所有实验ID相同的实例列表
List<ExperimentIns> experimentInsList = this.experimentInsService.getByExperimentId(exp.getId()); List<ExperimentIns> experimentInsList = this.experimentInsService.getByExperimentId(exp.getId());
exp.setExperimentInsList(experimentInsList); exp.setExperimentInsList(experimentInsList);
@@ -178,7 +186,7 @@ public class ExperimentServiceImpl implements ExperimentService {
@Override @Override
public String removeById(Integer id) throws Exception { public String removeById(Integer id) throws Exception {
Experiment experiment = experimentDao.queryById(id); Experiment experiment = experimentDao.queryById(id);
if (experiment==null){
if (experiment == null) {
throw new Exception("实验不存在"); throw new Exception("实验不存在");
} }


@@ -186,17 +194,16 @@ public class ExperimentServiceImpl implements ExperimentService {
LoginUser loginUser = SecurityUtils.getLoginUser(); LoginUser loginUser = SecurityUtils.getLoginUser();
String username = loginUser.getUsername(); String username = loginUser.getUsername();
String createdBy = experiment.getCreateBy(); String createdBy = experiment.getCreateBy();
if (!(StringUtils.equals(username,"admin") || StringUtils.equals(username,createdBy))){
if (!(StringUtils.equals(username, "admin") || StringUtils.equals(username, createdBy))) {
throw new Exception("无权限删除该实验"); throw new Exception("无权限删除该实验");
} }


List<ExperimentIns> experimentInsList = experimentInsService.queryByExperimentId(experiment.getId()); List<ExperimentIns> experimentInsList = experimentInsService.queryByExperimentId(experiment.getId());
if (experimentInsList!=null&&experimentInsList.size()>0){
if (experimentInsList != null && experimentInsList.size() > 0) {
throw new Exception("该实验存在实例,无法删除"); throw new Exception("该实验存在实例,无法删除");
} }
experiment.setState(0); experiment.setState(0);
return this.experimentDao.update(experiment)>0?"删除成功":"删除失败";

return this.experimentDao.update(experiment) > 0 ? "删除成功" : "删除失败";




} }
@@ -215,7 +222,7 @@ public class ExperimentServiceImpl implements ExperimentService {
System.out.println("No experiment"); System.out.println("No experiment");
} }
Workflow workflow = workflowService.queryById(experiment.getWorkflowId()); Workflow workflow = workflowService.queryById(experiment.getWorkflowId());
if(workflow == null) {
if (workflow == null) {
throw new RuntimeException("流水线不存在,请先创建流水线"); throw new RuntimeException("流水线不存在,请先创建流水线");
} }


@@ -236,10 +243,10 @@ public class ExperimentServiceImpl implements ExperimentService {
runReqMap.put("params", params); runReqMap.put("params", params);
// 实验字段的Map,不要写成一行!否则会返回null // 实验字段的Map,不要写成一行!否则会返回null
Map<String, Object> experimentMap = new HashMap<>(); Map<String, Object> experimentMap = new HashMap<>();
experimentMap.put("name", "experiment-"+experiment.getId());
experimentMap.put("name", "experiment-" + experiment.getId());
runReqMap.put("experiment", experimentMap); runReqMap.put("experiment", experimentMap);


Map<String ,Object> output = (Map<String, Object>) converMap.get("output");
Map<String, Object> output = (Map<String, Object>) converMap.get("output");
// 调argo运行接口 // 调argo运行接口
String runRes = HttpUtils.sendPost(argoUrl + argoWorkflowRun, JsonUtils.mapToJson(runReqMap)); String runRes = HttpUtils.sendPost(argoUrl + argoWorkflowRun, JsonUtils.mapToJson(runReqMap));


@@ -254,7 +261,6 @@ public class ExperimentServiceImpl implements ExperimentService {
} }





Map<String, Object> metadata = (Map<String, Object>) data.get("metadata"); Map<String, Object> metadata = (Map<String, Object>) data.get("metadata");
// 插入记录到实验实例表 // 插入记录到实验实例表
ExperimentIns experimentIns = new ExperimentIns(); ExperimentIns experimentIns = new ExperimentIns();
@@ -274,11 +280,11 @@ public class ExperimentServiceImpl implements ExperimentService {


//得到dependendcy //得到dependendcy
Map<String, Object> converMap2 = JsonUtils.jsonToMap(JacksonUtil.replaceInAarry(convertRes, params)); Map<String, Object> converMap2 = JsonUtils.jsonToMap(JacksonUtil.replaceInAarry(convertRes, params));
Map<String ,Object> dependendcy = (Map<String, Object>)converMap2.get("model_dependency");
Map<String ,Object> trainInfo = (Map<String, Object>)converMap2.get("component_info");
Map<String, Object> dependendcy = (Map<String, Object>) converMap2.get("model_dependency");
Map<String, Object> trainInfo = (Map<String, Object>) converMap2.get("component_info");


Map<String, Object> metricRecord = (Map<String, Object>) runResMap.get("metric_record"); Map<String, Object> metricRecord = (Map<String, Object>) runResMap.get("metric_record");
if (metricRecord != null){
if (metricRecord != null) {
//把训练用的数据集也放进去 //把训练用的数据集也放进去
addDatesetToMetric(metricRecord, trainInfo); addDatesetToMetric(metricRecord, trainInfo);
experimentIns.setMetricRecord(JacksonUtil.toJSONString(metricRecord)); experimentIns.setMetricRecord(JacksonUtil.toJSONString(metricRecord));
@@ -286,18 +292,18 @@ public class ExperimentServiceImpl implements ExperimentService {
//插入ExperimentIns表中 //插入ExperimentIns表中
ExperimentIns insert = experimentInsService.insert(experimentIns); ExperimentIns insert = experimentInsService.insert(experimentIns);
//插入到模型依赖关系表 //插入到模型依赖关系表
if (dependendcy != null && trainInfo != null){
insertModelDependency(dependendcy,trainInfo,insert.getId(),experiment.getName());
if (dependendcy != null && trainInfo != null) {
insertModelDependencyNew(dependendcy, trainInfo, insert.getId(), experiment.getName());
} }


Map<String ,Object> datasetDependendcy = (Map<String, Object>)converMap2.get("dataset_dependency");
Map<String, Object> datasetDependendcy = (Map<String, Object>) converMap2.get("dataset_dependency");
//暂存数据集元数据{} //暂存数据集元数据{}
if (datasetDependendcy != null && trainInfo != null){
insertDatasetTempStorage(datasetDependendcy,trainInfo,experiment.getId(),insert.getId(),experiment.getName());
if (datasetDependendcy != null && trainInfo != null) {
insertDatasetTempStorage(datasetDependendcy, trainInfo, experiment.getId(), insert.getId(), experiment.getName());
} }




}catch (Exception e){
} catch (Exception e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
List<ExperimentIns> updatedExperimentInsList = experimentInsService.getByExperimentId(id); List<ExperimentIns> updatedExperimentInsList = experimentInsService.getByExperimentId(id);
@@ -305,12 +311,12 @@ public class ExperimentServiceImpl implements ExperimentService {
return experiment; return experiment;
} }


private void addDatesetToMetric(Map<String, Object> metricRecord, Map<String, Object> trainInfo) {
private void addDatesetToMetric(Map<String, Object> metricRecord, Map<String, Object> trainInfo) {
processMetricPart(metricRecord, trainInfo, "train", "model_train"); processMetricPart(metricRecord, trainInfo, "train", "model_train");
processMetricPart(metricRecord, trainInfo, "evaluate", "model_evaluate"); processMetricPart(metricRecord, trainInfo, "evaluate", "model_evaluate");
} }


private void processMetricPart(Map<String, Object> metricRecord, Map<String, Object> trainInfo, String metricKey, String trainInfoKey) {
private void processMetricPart(Map<String, Object> metricRecord, Map<String, Object> trainInfo, String metricKey, String trainInfoKey) {
List<Map<String, Object>> metricList = (List<Map<String, Object>>) metricRecord.get(metricKey); List<Map<String, Object>> metricList = (List<Map<String, Object>>) metricRecord.get(metricKey);
if (metricList != null) { if (metricList != null) {
for (Map<String, Object> metricRecordItem : metricList) { for (Map<String, Object> metricRecordItem : metricList) {
@@ -335,24 +341,23 @@ public class ExperimentServiceImpl implements ExperimentService {
} }





private void insertModelDependency(Map<String ,Object> dependendcy,Map<String ,Object> trainInfo, Integer experimentInsId, String experimentName) throws Exception {
private void insertModelDependency(Map<String, Object> dependendcy, Map<String, Object> trainInfo, Integer experimentInsId, String experimentName) throws Exception {
Iterator<Map.Entry<String, Object>> dependendcyIterator = dependendcy.entrySet().iterator(); Iterator<Map.Entry<String, Object>> dependendcyIterator = dependendcy.entrySet().iterator();
Map<String, Object> modelTrain = (Map<String, Object>) trainInfo.get("model_train");
Map<String, Object> modelEvaluate = (Map<String, Object>) trainInfo.get("model_evaluate");
Map<String, Object> modelExport = (Map<String, Object>) trainInfo.get("model_export");
Map<String, Object> modelTrain = (Map<String, Object>) trainInfo.get("model_train");
Map<String, Object> modelEvaluate = (Map<String, Object>) trainInfo.get("model_evaluate");
Map<String, Object> modelExport = (Map<String, Object>) trainInfo.get("model_export");
while (dependendcyIterator.hasNext()) { while (dependendcyIterator.hasNext()) {
ModelDependency modelDependency = new ModelDependency(); ModelDependency modelDependency = new ModelDependency();
Map.Entry<String, Object> entry = dependendcyIterator.next(); Map.Entry<String, Object> entry = dependendcyIterator.next();
Map<String, Object> modelDel = (Map<String, Object>) entry.getValue(); Map<String, Object> modelDel = (Map<String, Object>) entry.getValue();
Map<String, Object> source = (Map<String, Object>) modelDel.get("source");
Map<String, Object> source = (Map<String, Object>) modelDel.get("source");
List<Map<String, Object>> test = (List<Map<String, Object>>) modelDel.get("test"); List<Map<String, Object>> test = (List<Map<String, Object>>) modelDel.get("test");
List<Map<String, Object>> target = (List<Map<String, Object>>) modelDel.get("target"); List<Map<String, Object>> target = (List<Map<String, Object>>) modelDel.get("target");
String sourceTaskId = (String) source.get("task_id"); String sourceTaskId = (String) source.get("task_id");


Map<String, Object> modelTrainMap = (Map<String, Object>)modelTrain.get(sourceTaskId);
Map<String, Object> modelTrainMap = (Map<String, Object>) modelTrain.get(sourceTaskId);
//处理project数据 //处理project数据
Map<String, String> projectMap = (Map<String, String>) modelTrainMap.get("project");
Map<String, String> projectMap = (Map<String, String>) modelTrainMap.get("project");
ProjectDepency projectDepency = new ProjectDepency(); ProjectDepency projectDepency = new ProjectDepency();
projectDepency.setBranch(projectMap.get("branch")); projectDepency.setBranch(projectMap.get("branch"));
String projectUrl = projectMap.get("url"); String projectUrl = projectMap.get("url");
@@ -361,7 +366,7 @@ public class ExperimentServiceImpl implements ExperimentService {
//依赖项目 //依赖项目
modelDependency.setProjectDependency(JsonUtils.objectToJson(projectDepency)); modelDependency.setProjectDependency(JsonUtils.objectToJson(projectDepency));
//处理镜像 //处理镜像
Map<String, String> imagesMap = (Map<String, String>) modelTrainMap.get("image");
Map<String, String> imagesMap = (Map<String, String>) modelTrainMap.get("image");
modelDependency.setTrainImage(imagesMap.get("name")); modelDependency.setTrainImage(imagesMap.get("name"));
List<Map<String, Object>> trainParamList = (List<Map<String, Object>>) modelTrainMap.get("params"); List<Map<String, Object>> trainParamList = (List<Map<String, Object>>) modelTrainMap.get("params");
modelDependency.setTrainParams(JsonUtils.objectToJson(trainParamList)); modelDependency.setTrainParams(JsonUtils.objectToJson(trainParamList));
@@ -402,24 +407,24 @@ public class ExperimentServiceImpl implements ExperimentService {
List<Map<String, Object>> resultTestDatasets = new ArrayList<Map<String, Object>>(); List<Map<String, Object>> resultTestDatasets = new ArrayList<Map<String, Object>>();
//处理test数据 //处理test数据
if (test != null) { if (test != null) {
for(int i=0;i<test.size();i++){
for (int i = 0; i < test.size(); i++) {
Map<String, Object> testMap = test.get(i); Map<String, Object> testMap = test.get(i);
String testTaskId = (String) testMap.get("task_id"); String testTaskId = (String) testMap.get("task_id");
Map<String, Object> evaluateMap = (Map<String, Object>) modelEvaluate.get(testTaskId);
Map<String, Object> evaluateMap = (Map<String, Object>) modelEvaluate.get(testTaskId);
List<Map<String, Object>> realDataSetList = (List<Map<String, Object>>) evaluateMap.get("datasets"); List<Map<String, Object>> realDataSetList = (List<Map<String, Object>>) evaluateMap.get("datasets");
for(int j=0;j<realDataSetList.size();j++){
for (int j = 0; j < realDataSetList.size(); j++) {
Map<String, Object> realDataSet = realDataSetList.get(j); Map<String, Object> realDataSet = realDataSetList.get(j);
Dataset dataset = datasetService.queryById((Integer) realDataSet.get("dataset_id")); Dataset dataset = datasetService.queryById((Integer) realDataSet.get("dataset_id"));
if (dataset == null){
if (dataset == null) {
throw new Exception("源数据集不存在"); throw new Exception("源数据集不存在");
} }
realDataSet.put("dataset_name", dataset.getName()); realDataSet.put("dataset_name", dataset.getName());
resultTestDatasets.add(realDataSet); resultTestDatasets.add(realDataSet);
} }
}
}


//测试数据集
modelDependency.setTestDataset(JsonUtils.objectToJson(resultTestDatasets));
//测试数据集
modelDependency.setTestDataset(JsonUtils.objectToJson(resultTestDatasets));
} }
//处理target数据 //处理target数据
if (target != null) { if (target != null) {
@@ -437,7 +442,7 @@ public class ExperimentServiceImpl implements ExperimentService {
modelDependencyService.insert(modelDependency); modelDependencyService.insert(modelDependency);
} }
} }
}else {
} else {
modelDependency.setState(2); modelDependency.setState(2);
modelDependencyService.insert(modelDependency); modelDependencyService.insert(modelDependency);
} }
@@ -447,27 +452,25 @@ public class ExperimentServiceImpl implements ExperimentService {


/** /**
* 存储数据集元数据到临时表 * 存储数据集元数据到临时表
*
*
*/
private void insertDatasetTempStorage(Map<String, Object> datasetDependendcy, Map<String, Object> trainInfo, Integer experimentId,Integer experimentInsId, String experimentName) {
*/
private void insertDatasetTempStorage(Map<String, Object> datasetDependendcy, Map<String, Object> trainInfo, Integer experimentId, Integer experimentInsId, String experimentName) {
DatasetTempStorage datasetTempStorage = new DatasetTempStorage(); DatasetTempStorage datasetTempStorage = new DatasetTempStorage();


Iterator<Map.Entry<String, Object>> dependendcyIterator = datasetDependendcy.entrySet().iterator(); Iterator<Map.Entry<String, Object>> dependendcyIterator = datasetDependendcy.entrySet().iterator();
Map<String, Object> datasetExport = (Map<String, Object>) trainInfo.get("dataset_export");
Map<String, Object> datasetPreprocess = (Map<String, Object>) trainInfo.get("dataset_preprocess");
Map<String, Object> datasetExport = (Map<String, Object>) trainInfo.get("dataset_export");
Map<String, Object> datasetPreprocess = (Map<String, Object>) trainInfo.get("dataset_preprocess");
while (dependendcyIterator.hasNext()) { while (dependendcyIterator.hasNext()) {
Map.Entry<String, Object> entry = dependendcyIterator.next(); Map.Entry<String, Object> entry = dependendcyIterator.next();
Map<String, Object> modelDel = (Map<String, Object>) entry.getValue(); Map<String, Object> modelDel = (Map<String, Object>) entry.getValue();
Map<String, Object> source = (Map<String, Object>) modelDel.get("source"); //被处理数据集
Map<String, Object> source = (Map<String, Object>) modelDel.get("source"); //被处理数据集
// List<Map<String, Object>> test = (List<Map<String, Object>>) modelDel.get("test"); // List<Map<String, Object>> test = (List<Map<String, Object>>) modelDel.get("test");
List<Map<String, Object>> target = (List<Map<String, Object>>) modelDel.get("target"); //导出的数据集 List<Map<String, Object>> target = (List<Map<String, Object>>) modelDel.get("target"); //导出的数据集


String sourceTaskId = (String) source.get("task_id"); String sourceTaskId = (String) source.get("task_id");
Map<String, Object> datasetPreprocessMap = (Map<String, Object>)datasetPreprocess.get(sourceTaskId);
Map<String, Object> datasetPreprocessMap = (Map<String, Object>) datasetPreprocess.get(sourceTaskId);
//处理project数据 //处理project数据
Map<String, Object> projectMap = (Map<String, Object>) datasetPreprocessMap.get("project");
Map<String, Object> datasets = (Map<String, Object>) datasetPreprocessMap.get("datasets");
Map<String, Object> projectMap = (Map<String, Object>) datasetPreprocessMap.get("project");
Map<String, Object> datasets = (Map<String, Object>) datasetPreprocessMap.get("datasets");
datasetTempStorage.setName((String) datasets.get("dataset_identifier")); datasetTempStorage.setName((String) datasets.get("dataset_identifier"));
datasetTempStorage.setVersion((String) datasets.get("dataset_version")); datasetTempStorage.setVersion((String) datasets.get("dataset_version"));
// 拼接需要的参数 // 拼接需要的参数
@@ -475,8 +478,8 @@ public class ExperimentServiceImpl implements ExperimentService {
sourceParams.put("experiment_name", experimentName); sourceParams.put("experiment_name", experimentName);
sourceParams.put("experiment_ins_id", experimentInsId); sourceParams.put("experiment_ins_id", experimentInsId);
sourceParams.put("experiment_id", experimentId); sourceParams.put("experiment_id", experimentId);
sourceParams.put("train_name",sourceTaskId);
sourceParams.put("preprocess_code",projectMap);
sourceParams.put("train_name", sourceTaskId);
sourceParams.put("preprocess_code", projectMap);
datasetTempStorage.setSource(JacksonUtil.toJSONString(sourceParams)); datasetTempStorage.setSource(JacksonUtil.toJSONString(sourceParams));
datasetTempStorage.setState(1); datasetTempStorage.setState(1);
datasetTempStorageService.insert(datasetTempStorage); datasetTempStorageService.insert(datasetTempStorage);
@@ -484,9 +487,131 @@ public class ExperimentServiceImpl implements ExperimentService {


} }


private void insertModelDependencyNew(Map<String, Object> dependendcy, Map<String, Object> trainInfo, Integer experimentInsId, String experimentName) throws Exception {
Iterator<Map.Entry<String, Object>> dependendcyIterator = dependendcy.entrySet().iterator();
Map<String, Object> modelTrain = (Map<String, Object>) trainInfo.get("model_train");
Map<String, Object> modelEvaluate = (Map<String, Object>) trainInfo.get("model_evaluate");
Map<String, Object> modelExport = (Map<String, Object>) trainInfo.get("model_export");
while (dependendcyIterator.hasNext()) {
ModelsVo modelMetaVo = new ModelsVo();
ModelDependency1 modelDependency = new ModelDependency1();

Map.Entry<String, Object> entry = dependendcyIterator.next();
Map<String, Object> modelDel = (Map<String, Object>) entry.getValue();
Map<String, Object> source = (Map<String, Object>) modelDel.get("source");
List<Map<String, Object>> test = (List<Map<String, Object>>) modelDel.get("test");
List<Map<String, Object>> target = (List<Map<String, Object>>) modelDel.get("target");
String sourceTaskId = (String) source.get("task_id");

Map<String, Object> modelTrainMap = (Map<String, Object>) modelTrain.get(sourceTaskId);
//处理project数据
Map<String, String> projectMap = (Map<String, String>) modelTrainMap.get("project");
ProjectDepency projectDepency = new ProjectDepency();
projectDepency.setBranch(projectMap.get("branch"));
String projectUrl = projectMap.get("url");
projectDepency.setUrl(projectUrl);
projectDepency.setName(projectUrl.substring(projectUrl.lastIndexOf('/') + 1, projectUrl.length() - 4));
modelMetaVo.setProjectDepency(projectDepency);
//处理镜像
Map<String, String> imagesMap = (Map<String, String>) modelTrainMap.get("image");
modelMetaVo.setImage(imagesMap.get("name"));
//处理训练参数 todo
HashMap<String, Object> trainParam = (HashMap<String, Object>) modelTrainMap.get("params");

//处理source数据
List<Map<String, Object>> modelsList = (List<Map<String, Object>>) modelTrainMap.get("models");
if (modelsList != null) {
Map<String, Object> parentModelMap = modelsList.get(0);
String id = (String) parentModelMap.get("model_id");
String identifier = (String) parentModelMap.get("model_identifier");
String version = (String) parentModelMap.get("model_version");

HashMap<String, Object> map = new HashMap<>();
map.put("repoId", id);
map.put("identifier", identifier);
map.put("version", version);
String parentModel = JSON.toJSONString(map);
modelMetaVo.setParentModel(parentModel);
modelDependency.setParentModel(parentModel);
}

//训练数据集
List<Map<String, Object>> trainDatasetList = (List<Map<String, Object>>) modelTrainMap.get("datasets");
if (trainDatasetList != null) {
List<NewDatasetVo> trainDatasets = new ArrayList<>();
for (Map<String, Object> dataset : trainDatasetList) {
NewDatasetVo newDatasetVo = new NewDatasetVo();
newDatasetVo.setId((Integer) dataset.get("dataset_id"));
newDatasetVo.setName((String) dataset.get("dataset_name"));
newDatasetVo.setVersion((String) dataset.get("dataset_version"));
newDatasetVo.setIdentifier((String) dataset.get("dataset_identifier"));
//todo newDatasetVo.setowner
trainDatasets.add(newDatasetVo);
}
modelMetaVo.setTrainDatasets(trainDatasets);
}
//训练任务
TrainTaskDepency trainTask = new TrainTaskDepency();
trainTask.setTaskId(sourceTaskId);
trainTask.setInsId(experimentInsId);
trainTask.setName(experimentName);
modelMetaVo.setTrainTask(trainTask);

//处理test数据
if (test != null) {
for (Map<String, Object> testMap : test) {
String testTaskId = (String) testMap.get("task_id");
Map<String, Object> evaluateMap = (Map<String, Object>) modelEvaluate.get(testTaskId);
List<Map<String, Object>> testDatasetList = (List<Map<String, Object>>) evaluateMap.get("datasets");
List<NewDatasetVo> testDatasets = new ArrayList<>();
for (Map<String, Object> dataset : testDatasetList) {
NewDatasetVo newDatasetVo = new NewDatasetVo();
newDatasetVo.setId((Integer) dataset.get("dataset_id"));
newDatasetVo.setName((String) dataset.get("dataset_name"));
newDatasetVo.setVersion((String) dataset.get("dataset_version"));
newDatasetVo.setIdentifier((String) dataset.get("dataset_identifier"));
testDatasets.add(newDatasetVo);
}
modelMetaVo.setTestDatasets(testDatasets);
}
}

//处理target数据
LoginUser loginUser = SecurityUtils.getLoginUser();
String gitLinkUsername = loginUser.getSysUser().getGitLinkUsername();
modelMetaVo.setOwner(gitLinkUsername);
if (target != null) {
for (int i = 0; i < target.size(); i++) {
Map<String, Object> targetMap = target.get(i);
String targetTaskId = (String) targetMap.get("task_id");
Map<String, Object> exportMap = (Map<String, Object>) modelExport.get(targetTaskId);
List<Map<String, Object>> modelTargetList = (List<Map<String, Object>>) exportMap.get("models");
for (int j = 0; j < modelTargetList.size(); j++) {
Map<String, Object> targetModel = modelTargetList.get(i);
modelMetaVo.setId((Integer) targetModel.get("model_id"));
modelMetaVo.setIdentifier((String) targetModel.get("model_identifier"));
modelMetaVo.setName((String) targetModel.get("model_name"));
modelMetaVo.setVersionDesc((String) targetModel.get("model_version"));
modelMetaVo.setOwner(gitLinkUsername);

modelsService.newCreateVersion(modelMetaVo);
}
}
} else {
String meta = JSON.toJSONString(modelMetaVo);
modelDependency.setMeta(meta);
modelDependency.setOwner(gitLinkUsername);
modelDependency1Dao.insert(modelDependency);
}

}
}


/** /**
* 被废弃的旧JSON * 被废弃的旧JSON
* @param experiment
*
* @param experiment
* @return * @return
* @throws Exception * @throws Exception
*/ */
@@ -585,7 +710,6 @@ public class ExperimentServiceImpl implements ExperimentService {
// } // }
// } // }
// } // }

@Override @Override
public Experiment addAndRunExperiment(Experiment experiment) throws Exception { public Experiment addAndRunExperiment(Experiment experiment) throws Exception {
// 第一步: 调用add方法插入实验记录到数据库 // 第一步: 调用add方法插入实验记录到数据库
@@ -596,7 +720,7 @@ public class ExperimentServiceImpl implements ExperimentService {
} }
// 调用runExperiment方法运行实验 // 调用runExperiment方法运行实验
try { try {
newExperiment = this.runExperiment(newExperiment.getId());
newExperiment = this.runExperiment(newExperiment.getId());
} catch (Exception e) { } catch (Exception e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
@@ -605,7 +729,6 @@ public class ExperimentServiceImpl implements ExperimentService {
} }





/** /**
* 返回实验配置参数 * 返回实验配置参数
* *
@@ -648,6 +771,4 @@ public class ExperimentServiceImpl implements ExperimentService {
} }






} }

Loading…
Cancel
Save