diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentServiceImpl.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentServiceImpl.java index e14e2009..85c2f54f 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentServiceImpl.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentServiceImpl.java @@ -1,5 +1,6 @@ package com.ruoyi.platform.service.impl; +import com.alibaba.fastjson2.JSON; import com.ruoyi.common.security.utils.SecurityUtils; import com.ruoyi.platform.annotations.CheckDuplicate; import com.ruoyi.platform.domain.*; @@ -7,14 +8,18 @@ import com.ruoyi.platform.domain.dependencydomain.ProjectDepency; import com.ruoyi.platform.domain.dependencydomain.TrainTaskDepency; import com.ruoyi.platform.mapper.ExperimentDao; import com.ruoyi.platform.mapper.ExperimentInsDao; +import com.ruoyi.platform.mapper.ModelDependency1Dao; import com.ruoyi.platform.service.*; import com.ruoyi.platform.utils.HttpUtils; import com.ruoyi.platform.utils.JacksonUtil; import com.ruoyi.platform.utils.JsonUtils; -import com.ruoyi.platform.utils.NewHttpUtils; +import com.ruoyi.platform.vo.ModelMetaVo; +import com.ruoyi.platform.vo.ModelsVo; +import com.ruoyi.platform.vo.NewDatasetVo; import com.ruoyi.system.api.model.LoginUser; import org.apache.commons.collections4.MapUtils; import org.apache.commons.lang3.StringUtils; +import org.springframework.beans.BeanUtils; import org.springframework.beans.factory.annotation.Value; import org.springframework.context.annotation.Lazy; import org.springframework.data.domain.Page; @@ -56,15 +61,20 @@ public class ExperimentServiceImpl implements ExperimentService { @Resource @Lazy private ExperimentInsService experimentInsService; + + @Resource + private ModelDependency1Dao modelDependency1Dao; + @Value("${argo.url}") - private String argoUrl; + private String argoUrl; @Value("${argo.convert}") - private String argoConvert; + private String argoConvert; @Value("${argo.workflowRun}") - private String argoWorkflowRun; + private String argoWorkflowRun; @Value("${argo.workflowStatus}") - private String argoWorkflowStatus; - + private String argoWorkflowStatus; + @Value("${git.localPath}") + String localPath; /** * 通过ID查询单条数据 @@ -95,22 +105,20 @@ public class ExperimentServiceImpl implements ExperimentService { List experimentList = this.experimentDao.queryAllByLimit(experiment, pageRequest); long total = this.experimentDao.count(experiment); // 存储所有实验的ID列表,查询实验对应的流水线 - for(Experiment exp: experimentList){ + for (Experiment exp : experimentList) { Long workflowId = exp.getWorkflowId(); Workflow correspondingWorkflow = this.workflowService.queryById(workflowId); String workflowName = correspondingWorkflow.getName(); exp.setWorkflowName(workflowName); } - return new PageImpl<>(experimentList,pageRequest,total); + return new PageImpl<>(experimentList, pageRequest, total); } /** * 分页查询实验状态 * - * - * * @param experiment 筛选条件 * @param pageRequest 分页对象 * @return 查询结果 @@ -120,7 +128,7 @@ public class ExperimentServiceImpl implements ExperimentService { // 存储所有实验的ID列表 List experimentIds = new ArrayList<>(); //对于每一个从Experiment表中查询出来的id,可能有多个实例,每个实例的实验id是不同的 - for (Experiment exp: experimentList) { + for (Experiment exp : experimentList) { //返回所有实验ID相同的实例列表 List experimentInsList = this.experimentInsService.getByExperimentId(exp.getId()); exp.setExperimentInsList(experimentInsList); @@ -178,7 +186,7 @@ public class ExperimentServiceImpl implements ExperimentService { @Override public String removeById(Integer id) throws Exception { Experiment experiment = experimentDao.queryById(id); - if (experiment==null){ + if (experiment == null) { throw new Exception("实验不存在"); } @@ -186,17 +194,16 @@ public class ExperimentServiceImpl implements ExperimentService { LoginUser loginUser = SecurityUtils.getLoginUser(); String username = loginUser.getUsername(); String createdBy = experiment.getCreateBy(); - if (!(StringUtils.equals(username,"admin") || StringUtils.equals(username,createdBy))){ + if (!(StringUtils.equals(username, "admin") || StringUtils.equals(username, createdBy))) { throw new Exception("无权限删除该实验"); } List experimentInsList = experimentInsService.queryByExperimentId(experiment.getId()); - if (experimentInsList!=null&&experimentInsList.size()>0){ + if (experimentInsList != null && experimentInsList.size() > 0) { throw new Exception("该实验存在实例,无法删除"); } experiment.setState(0); - return this.experimentDao.update(experiment)>0?"删除成功":"删除失败"; - + return this.experimentDao.update(experiment) > 0 ? "删除成功" : "删除失败"; } @@ -215,7 +222,7 @@ public class ExperimentServiceImpl implements ExperimentService { System.out.println("No experiment"); } Workflow workflow = workflowService.queryById(experiment.getWorkflowId()); - if(workflow == null) { + if (workflow == null) { throw new RuntimeException("流水线不存在,请先创建流水线"); } @@ -236,10 +243,10 @@ public class ExperimentServiceImpl implements ExperimentService { runReqMap.put("params", params); // 实验字段的Map,不要写成一行!否则会返回null Map experimentMap = new HashMap<>(); - experimentMap.put("name", "experiment-"+experiment.getId()); + experimentMap.put("name", "experiment-" + experiment.getId()); runReqMap.put("experiment", experimentMap); - Map output = (Map) converMap.get("output"); + Map output = (Map) converMap.get("output"); // 调argo运行接口 String runRes = HttpUtils.sendPost(argoUrl + argoWorkflowRun, JsonUtils.mapToJson(runReqMap)); @@ -254,7 +261,6 @@ public class ExperimentServiceImpl implements ExperimentService { } - Map metadata = (Map) data.get("metadata"); // 插入记录到实验实例表 ExperimentIns experimentIns = new ExperimentIns(); @@ -274,11 +280,11 @@ public class ExperimentServiceImpl implements ExperimentService { //得到dependendcy Map converMap2 = JsonUtils.jsonToMap(JacksonUtil.replaceInAarry(convertRes, params)); - Map dependendcy = (Map)converMap2.get("model_dependency"); - Map trainInfo = (Map)converMap2.get("component_info"); + Map dependendcy = (Map) converMap2.get("model_dependency"); + Map trainInfo = (Map) converMap2.get("component_info"); Map metricRecord = (Map) runResMap.get("metric_record"); - if (metricRecord != null){ + if (metricRecord != null) { //把训练用的数据集也放进去 addDatesetToMetric(metricRecord, trainInfo); experimentIns.setMetricRecord(JacksonUtil.toJSONString(metricRecord)); @@ -286,18 +292,18 @@ public class ExperimentServiceImpl implements ExperimentService { //插入ExperimentIns表中 ExperimentIns insert = experimentInsService.insert(experimentIns); //插入到模型依赖关系表 - if (dependendcy != null && trainInfo != null){ - insertModelDependency(dependendcy,trainInfo,insert.getId(),experiment.getName()); + if (dependendcy != null && trainInfo != null) { + insertModelDependencyNew(dependendcy, trainInfo, insert.getId(), experiment.getName()); } - Map datasetDependendcy = (Map)converMap2.get("dataset_dependency"); + Map datasetDependendcy = (Map) converMap2.get("dataset_dependency"); //暂存数据集元数据{} - if (datasetDependendcy != null && trainInfo != null){ - insertDatasetTempStorage(datasetDependendcy,trainInfo,experiment.getId(),insert.getId(),experiment.getName()); + if (datasetDependendcy != null && trainInfo != null) { + insertDatasetTempStorage(datasetDependendcy, trainInfo, experiment.getId(), insert.getId(), experiment.getName()); } - }catch (Exception e){ + } catch (Exception e) { throw new RuntimeException(e); } List updatedExperimentInsList = experimentInsService.getByExperimentId(id); @@ -305,12 +311,12 @@ public class ExperimentServiceImpl implements ExperimentService { return experiment; } - private void addDatesetToMetric(Map metricRecord, Map trainInfo) { + private void addDatesetToMetric(Map metricRecord, Map trainInfo) { processMetricPart(metricRecord, trainInfo, "train", "model_train"); processMetricPart(metricRecord, trainInfo, "evaluate", "model_evaluate"); } - private void processMetricPart(Map metricRecord, Map trainInfo, String metricKey, String trainInfoKey) { + private void processMetricPart(Map metricRecord, Map trainInfo, String metricKey, String trainInfoKey) { List> metricList = (List>) metricRecord.get(metricKey); if (metricList != null) { for (Map metricRecordItem : metricList) { @@ -335,24 +341,23 @@ public class ExperimentServiceImpl implements ExperimentService { } - - private void insertModelDependency(Map dependendcy,Map trainInfo, Integer experimentInsId, String experimentName) throws Exception { + private void insertModelDependency(Map dependendcy, Map trainInfo, Integer experimentInsId, String experimentName) throws Exception { Iterator> dependendcyIterator = dependendcy.entrySet().iterator(); - Map modelTrain = (Map) trainInfo.get("model_train"); - Map modelEvaluate = (Map) trainInfo.get("model_evaluate"); - Map modelExport = (Map) trainInfo.get("model_export"); + Map modelTrain = (Map) trainInfo.get("model_train"); + Map modelEvaluate = (Map) trainInfo.get("model_evaluate"); + Map modelExport = (Map) trainInfo.get("model_export"); while (dependendcyIterator.hasNext()) { ModelDependency modelDependency = new ModelDependency(); Map.Entry entry = dependendcyIterator.next(); Map modelDel = (Map) entry.getValue(); - Map source = (Map) modelDel.get("source"); + Map source = (Map) modelDel.get("source"); List> test = (List>) modelDel.get("test"); List> target = (List>) modelDel.get("target"); String sourceTaskId = (String) source.get("task_id"); - Map modelTrainMap = (Map)modelTrain.get(sourceTaskId); + Map modelTrainMap = (Map) modelTrain.get(sourceTaskId); //处理project数据 - Map projectMap = (Map) modelTrainMap.get("project"); + Map projectMap = (Map) modelTrainMap.get("project"); ProjectDepency projectDepency = new ProjectDepency(); projectDepency.setBranch(projectMap.get("branch")); String projectUrl = projectMap.get("url"); @@ -361,7 +366,7 @@ public class ExperimentServiceImpl implements ExperimentService { //依赖项目 modelDependency.setProjectDependency(JsonUtils.objectToJson(projectDepency)); //处理镜像 - Map imagesMap = (Map) modelTrainMap.get("image"); + Map imagesMap = (Map) modelTrainMap.get("image"); modelDependency.setTrainImage(imagesMap.get("name")); List> trainParamList = (List>) modelTrainMap.get("params"); modelDependency.setTrainParams(JsonUtils.objectToJson(trainParamList)); @@ -402,24 +407,24 @@ public class ExperimentServiceImpl implements ExperimentService { List> resultTestDatasets = new ArrayList>(); //处理test数据 if (test != null) { - for(int i=0;i testMap = test.get(i); String testTaskId = (String) testMap.get("task_id"); - Map evaluateMap = (Map) modelEvaluate.get(testTaskId); + Map evaluateMap = (Map) modelEvaluate.get(testTaskId); List> realDataSetList = (List>) evaluateMap.get("datasets"); - for(int j=0;j realDataSet = realDataSetList.get(j); Dataset dataset = datasetService.queryById((Integer) realDataSet.get("dataset_id")); - if (dataset == null){ + if (dataset == null) { throw new Exception("源数据集不存在"); } realDataSet.put("dataset_name", dataset.getName()); resultTestDatasets.add(realDataSet); } - } + } - //测试数据集 - modelDependency.setTestDataset(JsonUtils.objectToJson(resultTestDatasets)); + //测试数据集 + modelDependency.setTestDataset(JsonUtils.objectToJson(resultTestDatasets)); } //处理target数据 if (target != null) { @@ -437,7 +442,7 @@ public class ExperimentServiceImpl implements ExperimentService { modelDependencyService.insert(modelDependency); } } - }else { + } else { modelDependency.setState(2); modelDependencyService.insert(modelDependency); } @@ -447,27 +452,25 @@ public class ExperimentServiceImpl implements ExperimentService { /** * 存储数据集元数据到临时表 - * - * - */ - private void insertDatasetTempStorage(Map datasetDependendcy, Map trainInfo, Integer experimentId,Integer experimentInsId, String experimentName) { + */ + private void insertDatasetTempStorage(Map datasetDependendcy, Map trainInfo, Integer experimentId, Integer experimentInsId, String experimentName) { DatasetTempStorage datasetTempStorage = new DatasetTempStorage(); Iterator> dependendcyIterator = datasetDependendcy.entrySet().iterator(); - Map datasetExport = (Map) trainInfo.get("dataset_export"); - Map datasetPreprocess = (Map) trainInfo.get("dataset_preprocess"); + Map datasetExport = (Map) trainInfo.get("dataset_export"); + Map datasetPreprocess = (Map) trainInfo.get("dataset_preprocess"); while (dependendcyIterator.hasNext()) { Map.Entry entry = dependendcyIterator.next(); Map modelDel = (Map) entry.getValue(); - Map source = (Map) modelDel.get("source"); //被处理数据集 + Map source = (Map) modelDel.get("source"); //被处理数据集 // List> test = (List>) modelDel.get("test"); List> target = (List>) modelDel.get("target"); //导出的数据集 String sourceTaskId = (String) source.get("task_id"); - Map datasetPreprocessMap = (Map)datasetPreprocess.get(sourceTaskId); + Map datasetPreprocessMap = (Map) datasetPreprocess.get(sourceTaskId); //处理project数据 - Map projectMap = (Map) datasetPreprocessMap.get("project"); - Map datasets = (Map) datasetPreprocessMap.get("datasets"); + Map projectMap = (Map) datasetPreprocessMap.get("project"); + Map datasets = (Map) datasetPreprocessMap.get("datasets"); datasetTempStorage.setName((String) datasets.get("dataset_identifier")); datasetTempStorage.setVersion((String) datasets.get("dataset_version")); // 拼接需要的参数 @@ -475,8 +478,8 @@ public class ExperimentServiceImpl implements ExperimentService { sourceParams.put("experiment_name", experimentName); sourceParams.put("experiment_ins_id", experimentInsId); sourceParams.put("experiment_id", experimentId); - sourceParams.put("train_name",sourceTaskId); - sourceParams.put("preprocess_code",projectMap); + sourceParams.put("train_name", sourceTaskId); + sourceParams.put("preprocess_code", projectMap); datasetTempStorage.setSource(JacksonUtil.toJSONString(sourceParams)); datasetTempStorage.setState(1); datasetTempStorageService.insert(datasetTempStorage); @@ -484,9 +487,131 @@ public class ExperimentServiceImpl implements ExperimentService { } + private void insertModelDependencyNew(Map dependendcy, Map trainInfo, Integer experimentInsId, String experimentName) throws Exception { + Iterator> dependendcyIterator = dependendcy.entrySet().iterator(); + Map modelTrain = (Map) trainInfo.get("model_train"); + Map modelEvaluate = (Map) trainInfo.get("model_evaluate"); + Map modelExport = (Map) trainInfo.get("model_export"); + while (dependendcyIterator.hasNext()) { + ModelsVo modelMetaVo = new ModelsVo(); + ModelDependency1 modelDependency = new ModelDependency1(); + + Map.Entry entry = dependendcyIterator.next(); + Map modelDel = (Map) entry.getValue(); + Map source = (Map) modelDel.get("source"); + List> test = (List>) modelDel.get("test"); + List> target = (List>) modelDel.get("target"); + String sourceTaskId = (String) source.get("task_id"); + + Map modelTrainMap = (Map) modelTrain.get(sourceTaskId); + //处理project数据 + Map projectMap = (Map) modelTrainMap.get("project"); + ProjectDepency projectDepency = new ProjectDepency(); + projectDepency.setBranch(projectMap.get("branch")); + String projectUrl = projectMap.get("url"); + projectDepency.setUrl(projectUrl); + projectDepency.setName(projectUrl.substring(projectUrl.lastIndexOf('/') + 1, projectUrl.length() - 4)); + modelMetaVo.setProjectDepency(projectDepency); + //处理镜像 + Map imagesMap = (Map) modelTrainMap.get("image"); + modelMetaVo.setImage(imagesMap.get("name")); + //处理训练参数 todo + HashMap trainParam = (HashMap) modelTrainMap.get("params"); + + //处理source数据 + List> modelsList = (List>) modelTrainMap.get("models"); + if (modelsList != null) { + Map parentModelMap = modelsList.get(0); + String id = (String) parentModelMap.get("model_id"); + String identifier = (String) parentModelMap.get("model_identifier"); + String version = (String) parentModelMap.get("model_version"); + + HashMap map = new HashMap<>(); + map.put("repoId", id); + map.put("identifier", identifier); + map.put("version", version); + String parentModel = JSON.toJSONString(map); + modelMetaVo.setParentModel(parentModel); + modelDependency.setParentModel(parentModel); + } + + //训练数据集 + List> trainDatasetList = (List>) modelTrainMap.get("datasets"); + if (trainDatasetList != null) { + List trainDatasets = new ArrayList<>(); + for (Map dataset : trainDatasetList) { + NewDatasetVo newDatasetVo = new NewDatasetVo(); + newDatasetVo.setId((Integer) dataset.get("dataset_id")); + newDatasetVo.setName((String) dataset.get("dataset_name")); + newDatasetVo.setVersion((String) dataset.get("dataset_version")); + newDatasetVo.setIdentifier((String) dataset.get("dataset_identifier")); + //todo newDatasetVo.setowner + trainDatasets.add(newDatasetVo); + } + modelMetaVo.setTrainDatasets(trainDatasets); + } + //训练任务 + TrainTaskDepency trainTask = new TrainTaskDepency(); + trainTask.setTaskId(sourceTaskId); + trainTask.setInsId(experimentInsId); + trainTask.setName(experimentName); + modelMetaVo.setTrainTask(trainTask); + + //处理test数据 + if (test != null) { + for (Map testMap : test) { + String testTaskId = (String) testMap.get("task_id"); + Map evaluateMap = (Map) modelEvaluate.get(testTaskId); + List> testDatasetList = (List>) evaluateMap.get("datasets"); + List testDatasets = new ArrayList<>(); + for (Map dataset : testDatasetList) { + NewDatasetVo newDatasetVo = new NewDatasetVo(); + newDatasetVo.setId((Integer) dataset.get("dataset_id")); + newDatasetVo.setName((String) dataset.get("dataset_name")); + newDatasetVo.setVersion((String) dataset.get("dataset_version")); + newDatasetVo.setIdentifier((String) dataset.get("dataset_identifier")); + testDatasets.add(newDatasetVo); + } + modelMetaVo.setTestDatasets(testDatasets); + } + } + + //处理target数据 + LoginUser loginUser = SecurityUtils.getLoginUser(); + String gitLinkUsername = loginUser.getSysUser().getGitLinkUsername(); + modelMetaVo.setOwner(gitLinkUsername); + if (target != null) { + for (int i = 0; i < target.size(); i++) { + Map targetMap = target.get(i); + String targetTaskId = (String) targetMap.get("task_id"); + Map exportMap = (Map) modelExport.get(targetTaskId); + List> modelTargetList = (List>) exportMap.get("models"); + for (int j = 0; j < modelTargetList.size(); j++) { + Map targetModel = modelTargetList.get(i); + modelMetaVo.setId((Integer) targetModel.get("model_id")); + modelMetaVo.setIdentifier((String) targetModel.get("model_identifier")); + modelMetaVo.setName((String) targetModel.get("model_name")); + modelMetaVo.setVersionDesc((String) targetModel.get("model_version")); + modelMetaVo.setOwner(gitLinkUsername); + + modelsService.newCreateVersion(modelMetaVo); + } + } + } else { + String meta = JSON.toJSONString(modelMetaVo); + modelDependency.setMeta(meta); + modelDependency.setOwner(gitLinkUsername); + modelDependency1Dao.insert(modelDependency); + } + + } + } + + /** * 被废弃的旧JSON - * @param experiment + * + * @param experiment * @return * @throws Exception */ @@ -585,7 +710,6 @@ public class ExperimentServiceImpl implements ExperimentService { // } // } // } - @Override public Experiment addAndRunExperiment(Experiment experiment) throws Exception { // 第一步: 调用add方法插入实验记录到数据库 @@ -596,7 +720,7 @@ public class ExperimentServiceImpl implements ExperimentService { } // 调用runExperiment方法运行实验 try { - newExperiment = this.runExperiment(newExperiment.getId()); + newExperiment = this.runExperiment(newExperiment.getId()); } catch (Exception e) { throw new RuntimeException(e); } @@ -605,7 +729,6 @@ public class ExperimentServiceImpl implements ExperimentService { } - /** * 返回实验配置参数 * @@ -648,6 +771,4 @@ public class ExperimentServiceImpl implements ExperimentService { } - - }