| @@ -84,5 +84,9 @@ public interface ModelDependencyDao { | |||||
| List<ModelDependency> queryByModelDependency(@Param("modelDependency") ModelDependency modelDependency); | List<ModelDependency> queryByModelDependency(@Param("modelDependency") ModelDependency modelDependency); | ||||
| List<ModelDependency> queryChildrenByVersionId(@Param("model_id")String modelId, @Param("version")String version); | List<ModelDependency> queryChildrenByVersionId(@Param("model_id")String modelId, @Param("version")String version); | ||||
| List<ModelDependency> queryByIns(@Param("expInsId")Integer expInsId); | |||||
| ModelDependency queryByInsAndTrainTaskId(@Param("expInsId")Integer expInsId,@Param("taskId") String taskId); | |||||
| } | } | ||||
| @@ -7,20 +7,15 @@ import com.ruoyi.platform.mapper.ExperimentDao; | |||||
| import com.ruoyi.platform.mapper.ExperimentInsDao; | import com.ruoyi.platform.mapper.ExperimentInsDao; | ||||
| import com.ruoyi.platform.mapper.ModelDependencyDao; | import com.ruoyi.platform.mapper.ModelDependencyDao; | ||||
| import com.ruoyi.platform.service.ExperimentInsService; | import com.ruoyi.platform.service.ExperimentInsService; | ||||
| import com.ruoyi.platform.service.ModelDependencyService; | |||||
| import com.ruoyi.platform.utils.JacksonUtil; | import com.ruoyi.platform.utils.JacksonUtil; | ||||
| import org.apache.commons.lang3.StringUtils; | import org.apache.commons.lang3.StringUtils; | ||||
| import org.springframework.beans.factory.annotation.Autowired; | import org.springframework.beans.factory.annotation.Autowired; | ||||
| import org.springframework.data.domain.Page; | |||||
| import org.springframework.scheduling.annotation.Scheduled; | import org.springframework.scheduling.annotation.Scheduled; | ||||
| import org.springframework.stereotype.Component; | import org.springframework.stereotype.Component; | ||||
| import javax.annotation.Resource; | import javax.annotation.Resource; | ||||
| import java.io.IOException; | import java.io.IOException; | ||||
| import java.util.ArrayList; | |||||
| import java.util.Date; | |||||
| import java.util.List; | |||||
| import java.util.Map; | |||||
| import java.util.*; | |||||
| @Component() | @Component() | ||||
| public class ExperimentInstanceStatusTask { | public class ExperimentInstanceStatusTask { | ||||
| @@ -34,7 +29,7 @@ public class ExperimentInstanceStatusTask { | |||||
| private ModelDependencyDao modelDependencyDao; | private ModelDependencyDao modelDependencyDao; | ||||
| private List<Integer> experimentIds = new ArrayList<>(); | private List<Integer> experimentIds = new ArrayList<>(); | ||||
| @Scheduled(cron = "0/14 * * * * ?") // 每30S执行一次 | |||||
| @Scheduled(cron = "0/30 * * * * ?") // 每30S执行一次 | |||||
| public void executeExperimentInsStatus() throws IOException { | public void executeExperimentInsStatus() throws IOException { | ||||
| // 首先查到所有非终止态的实验实例 | // 首先查到所有非终止态的实验实例 | ||||
| List<ExperimentIns> experimentInsList = experimentInsService.queryByExperimentIsNotTerminated(); | List<ExperimentIns> experimentInsList = experimentInsService.queryByExperimentIsNotTerminated(); | ||||
| @@ -46,95 +41,94 @@ public class ExperimentInstanceStatusTask { | |||||
| String oldStatus = experimentIns.getStatus(); | String oldStatus = experimentIns.getStatus(); | ||||
| try { | try { | ||||
| experimentIns = experimentInsService.queryStatusFromArgo(experimentIns); | experimentIns = experimentInsService.queryStatusFromArgo(experimentIns); | ||||
| }catch (Exception e){ | |||||
| } catch (Exception e) { | |||||
| experimentIns.setStatus("Failed"); | experimentIns.setStatus("Failed"); | ||||
| } | } | ||||
| // if (!StringUtils.equals(oldStatus,experimentIns.getStatus())){ | |||||
| experimentIns.setUpdateTime(new Date()); | |||||
| // 线程安全的添加操作 | |||||
| synchronized (experimentIds) { | |||||
| experimentIds.add(experimentIns.getExperimentId()); | |||||
| } | |||||
| updateList.add(experimentIns); | |||||
| // } | |||||
| // experimentInsDao.update(experimentIns); | |||||
| experimentIns.setUpdateTime(new Date()); | |||||
| // 线程安全的添加操作 | |||||
| synchronized (experimentIds) { | |||||
| experimentIds.add(experimentIns.getExperimentId()); | |||||
| } | |||||
| updateList.add(experimentIns); | |||||
| } | } | ||||
| } | } | ||||
| if (updateList.size() > 0){ | |||||
| if (updateList.size() > 0) { | |||||
| experimentInsDao.insertOrUpdateBatch(updateList); | experimentInsDao.insertOrUpdateBatch(updateList); | ||||
| //遍历模型关系表,找到 | //遍历模型关系表,找到 | ||||
| List<ModelDependency> modelDependencyList = new ArrayList<ModelDependency>(); | List<ModelDependency> modelDependencyList = new ArrayList<ModelDependency>(); | ||||
| for (ExperimentIns experimentIns : updateList){ | |||||
| for (ExperimentIns experimentIns : updateList) { | |||||
| ModelDependency modelDependencyquery = new ModelDependency(); | ModelDependency modelDependencyquery = new ModelDependency(); | ||||
| modelDependencyquery.setExpInsId(experimentIns.getId()); | modelDependencyquery.setExpInsId(experimentIns.getId()); | ||||
| modelDependencyquery.setState(2); | modelDependencyquery.setState(2); | ||||
| List<ModelDependency> modelDependencyListquery = modelDependencyDao.queryByModelDependency(modelDependencyquery); | List<ModelDependency> modelDependencyListquery = modelDependencyDao.queryByModelDependency(modelDependencyquery); | ||||
| if (modelDependencyListquery==null||modelDependencyListquery.size()==0){ | |||||
| if (modelDependencyListquery == null || modelDependencyListquery.size() == 0) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| ModelDependency modelDependency = modelDependencyListquery.get(0); | ModelDependency modelDependency = modelDependencyListquery.get(0); | ||||
| //查看状态, | //查看状态, | ||||
| if (StringUtils.equals("Failed",experimentIns.getStatus())){ | |||||
| if (StringUtils.equals("Failed", experimentIns.getStatus())) { | |||||
| //取出节点状态 | //取出节点状态 | ||||
| String trainTask = modelDependency.getTrainTask(); | String trainTask = modelDependency.getTrainTask(); | ||||
| Map<String, Object> trainMap = JacksonUtil.parseJSONStr2Map(trainTask); | Map<String, Object> trainMap = JacksonUtil.parseJSONStr2Map(trainTask); | ||||
| String task_id = (String) trainMap.get("task_id"); | String task_id = (String) trainMap.get("task_id"); | ||||
| if (StringUtils.isEmpty(task_id)){ | |||||
| if (StringUtils.isEmpty(task_id)) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| String nodesStatus = experimentIns.getNodesStatus(); | String nodesStatus = experimentIns.getNodesStatus(); | ||||
| Map<String, Object> nodeMaps = JacksonUtil.parseJSONStr2Map(nodesStatus); | Map<String, Object> nodeMaps = JacksonUtil.parseJSONStr2Map(nodesStatus); | ||||
| Map<String, Object> nodeMap = JacksonUtil.parseJSONStr2Map(JacksonUtil.toJSONString(nodeMaps.get(task_id))); | Map<String, Object> nodeMap = JacksonUtil.parseJSONStr2Map(JacksonUtil.toJSONString(nodeMaps.get(task_id))); | ||||
| if (nodeMap==null){ | |||||
| if (nodeMap == null) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (!StringUtils.equals("Succeeded",(String)nodeMap.get("phase"))){ | |||||
| if (!StringUtils.equals("Succeeded", (String) nodeMap.get("phase"))) { | |||||
| modelDependency.setState(0); | modelDependency.setState(0); | ||||
| modelDependencyList.add(modelDependency); | modelDependencyList.add(modelDependency); | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| if (modelDependencyList.size()>0) { | |||||
| if (modelDependencyList.size() > 0) { | |||||
| modelDependencyDao.insertOrUpdateBatch(modelDependencyList); | modelDependencyDao.insertOrUpdateBatch(modelDependencyList); | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @Scheduled(cron = "0/17 * * * * ?") // / 每30S执行一次 | |||||
| @Scheduled(cron = "0/30 * * * * ?") // / 每30S执行一次 | |||||
| public void executeExperimentStatus() throws IOException { | public void executeExperimentStatus() throws IOException { | ||||
| if (experimentIds.size()==0){ | |||||
| if (experimentIds.size() == 0) { | |||||
| return; | return; | ||||
| } | } | ||||
| // 存储需要更新的实验对象列表 | // 存储需要更新的实验对象列表 | ||||
| List<Experiment> updateExperiments = new ArrayList<>(); | List<Experiment> updateExperiments = new ArrayList<>(); | ||||
| for (Integer experimentId : experimentIds){ | |||||
| for (Integer experimentId : experimentIds) { | |||||
| // 获取当前实验的所有实例列表 | // 获取当前实验的所有实例列表 | ||||
| List<ExperimentIns> insList = experimentInsService.getByExperimentId(experimentId); | List<ExperimentIns> insList = experimentInsService.getByExperimentId(experimentId); | ||||
| List<String> statusList = new ArrayList<String>(); | List<String> statusList = new ArrayList<String>(); | ||||
| // 更新实验状态列表 | // 更新实验状态列表 | ||||
| for (int i=0;i<insList.size();i++){ | |||||
| for (int i = 0; i < insList.size(); i++) { | |||||
| statusList.add(insList.get(i).getStatus()); | statusList.add(insList.get(i).getStatus()); | ||||
| } | } | ||||
| String subStatus = statusList.toString().substring(1, statusList.toString().length() - 1); | String subStatus = statusList.toString().substring(1, statusList.toString().length() - 1); | ||||
| Experiment experiment = experimentDao.queryById(experimentId); | Experiment experiment = experimentDao.queryById(experimentId); | ||||
| // 如果实验状态列表发生变化,则更新实验对象,并加入到需要更新的列表中 | // 如果实验状态列表发生变化,则更新实验对象,并加入到需要更新的列表中 | ||||
| if (!StringUtils.equals(subStatus,experiment.getStatusList())){ | |||||
| if (!StringUtils.equals(subStatus, experiment.getStatusList())) { | |||||
| experiment.setStatusList(subStatus); | experiment.setStatusList(subStatus); | ||||
| updateExperiments.add(experiment); | updateExperiments.add(experiment); | ||||
| } | } | ||||
| } | } | ||||
| if (!updateExperiments.isEmpty()) { | if (!updateExperiments.isEmpty()) { | ||||
| experimentDao.insertOrUpdateBatch(updateExperiments); | experimentDao.insertOrUpdateBatch(updateExperiments); | ||||
| for (int index = 0; index < updateExperiments.size(); index++) { | |||||
| // 线程安全的删除操作 | |||||
| synchronized (experimentIds) { | |||||
| experimentIds.remove(index); | |||||
| // 使用Iterator进行安全的删除操作 | |||||
| Iterator<Integer> iterator = experimentIds.iterator(); | |||||
| while (iterator.hasNext()) { | |||||
| Integer experimentId = iterator.next(); | |||||
| for (Experiment experiment : updateExperiments) { | |||||
| if (experiment.getId().equals(experimentId)) { | |||||
| iterator.remove(); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -62,4 +62,8 @@ public interface ModelDependencyService { | |||||
| List<ModelDependency> queryByModelDependency(ModelDependency modelDependency) throws IOException; | List<ModelDependency> queryByModelDependency(ModelDependency modelDependency) throws IOException; | ||||
| ModelDependcyTreeVo getModelDependencyTree(ModelDependency modelDependency) throws Exception; | ModelDependcyTreeVo getModelDependencyTree(ModelDependency modelDependency) throws Exception; | ||||
| List<ModelDependency> queryByIns(Integer expInsId); | |||||
| ModelDependency queryByInsAndTrainTaskId(Integer expInsId, String taskId); | |||||
| } | } | ||||
| @@ -1,18 +1,17 @@ | |||||
| package com.ruoyi.platform.service.impl; | package com.ruoyi.platform.service.impl; | ||||
| import com.alibaba.druid.util.StringUtils; | |||||
| import com.ruoyi.platform.domain.ExperimentIns; | import com.ruoyi.platform.domain.ExperimentIns; | ||||
| import com.ruoyi.platform.domain.ModelDependency; | |||||
| import com.ruoyi.platform.service.AimService; | import com.ruoyi.platform.service.AimService; | ||||
| import com.ruoyi.platform.service.ExperimentInsService; | import com.ruoyi.platform.service.ExperimentInsService; | ||||
| import com.ruoyi.platform.service.ExperimentService; | |||||
| import com.ruoyi.platform.service.ModelDependencyService; | |||||
| import com.ruoyi.platform.utils.AIM64EncoderUtil; | import com.ruoyi.platform.utils.AIM64EncoderUtil; | ||||
| import com.ruoyi.platform.utils.HttpUtils; | import com.ruoyi.platform.utils.HttpUtils; | ||||
| import com.ruoyi.platform.utils.JacksonUtil; | import com.ruoyi.platform.utils.JacksonUtil; | ||||
| import com.ruoyi.platform.utils.JsonUtils; | import com.ruoyi.platform.utils.JsonUtils; | ||||
| import com.ruoyi.platform.vo.InsMetricInfoVo; | import com.ruoyi.platform.vo.InsMetricInfoVo; | ||||
| import org.apache.dubbo.container.Main; | |||||
| import org.json.JSONObject; | |||||
| import org.json.JSONTokener; | |||||
| import org.apache.commons.lang3.StringUtils; | |||||
| import org.springframework.beans.factory.annotation.Value; | |||||
| import org.springframework.stereotype.Service; | import org.springframework.stereotype.Service; | ||||
| import javax.annotation.Resource; | import javax.annotation.Resource; | ||||
| @@ -24,58 +23,66 @@ import java.util.stream.Collectors; | |||||
| public class AimServiceImpl implements AimService { | public class AimServiceImpl implements AimService { | ||||
| @Resource | @Resource | ||||
| private ExperimentInsService experimentInsService; | private ExperimentInsService experimentInsService; | ||||
| @Resource | |||||
| private ModelDependencyService modelDependencyService; | |||||
| @Value("${aim.url}") | |||||
| private String aimUrl; | |||||
| @Value("${aim.proxyUrl}") | |||||
| private String aimProxyUrl; | |||||
| @Override | @Override | ||||
| public List<InsMetricInfoVo> getExpTrainInfos(Integer experimentId) throws Exception { | public List<InsMetricInfoVo> getExpTrainInfos(Integer experimentId) throws Exception { | ||||
| String experimentName = "experiment-train-0"+experimentId; | |||||
| return getAimRunInfos("",experimentId); | |||||
| return getAimRunInfos(true,experimentId); | |||||
| } | } | ||||
| @Override | @Override | ||||
| public List<InsMetricInfoVo> getExpEvaluateInfos(Integer experimentId) throws Exception { | public List<InsMetricInfoVo> getExpEvaluateInfos(Integer experimentId) throws Exception { | ||||
| String experimentName = "experiment-evaluate-0"+experimentId; | |||||
| return getAimRunInfos("",experimentId); | |||||
| return getAimRunInfos(false,experimentId); | |||||
| } | } | ||||
| @Override | @Override | ||||
| public String getExpMetrics(List<String> runIds) throws Exception { | public String getExpMetrics(List<String> runIds) throws Exception { | ||||
| String decode = AIM64EncoderUtil.decode(runIds); | String decode = AIM64EncoderUtil.decode(runIds); | ||||
| return "http://172.20.32.21:7123/api/runs/search/run?query="+decode; | |||||
| return aimUrl+"/api/runs/search/run?query="+decode; | |||||
| } | } | ||||
| private List<InsMetricInfoVo> getAimRunInfos(String experimentName,Integer experimentId) throws Exception { | |||||
| String encodedUrlString = URLEncoder.encode("run.experiment==\"experiment-0000\"", "UTF-8"); | |||||
| String url = "http://172.20.32.181:30123/api/runs/search/run?query="+encodedUrlString; | |||||
| private List<InsMetricInfoVo> getAimRunInfos(boolean isTrain,Integer experimentId) throws Exception { | |||||
| String experimentName = "experiment-"+experimentId+"-train"; | |||||
| if (!isTrain){ | |||||
| experimentName = "experiment-"+experimentId+"-evaluate"; | |||||
| } | |||||
| String encodedUrlString = URLEncoder.encode("run.experiment==\""+experimentName+"\"", "UTF-8"); | |||||
| String url = aimProxyUrl+"/api/runs/search/run?query="+encodedUrlString; | |||||
| String s = HttpUtils.sendGetRequest(url); | String s = HttpUtils.sendGetRequest(url); | ||||
| System.out.println(s); | |||||
| List<Map<String, Object>> response = JacksonUtil.parseJSONStr2MapList(s); | List<Map<String, Object>> response = JacksonUtil.parseJSONStr2MapList(s); | ||||
| // TODO: parse aim response to InsMetricInfoVo list | |||||
| if (response == null || response.size() == 0){ | if (response == null || response.size() == 0){ | ||||
| return new ArrayList<>(); | return new ArrayList<>(); | ||||
| } | } | ||||
| //查询实例数据 | //查询实例数据 | ||||
| List<ExperimentIns> byExperimentId = experimentInsService.getByExperimentId(experimentId); | List<ExperimentIns> byExperimentId = experimentInsService.getByExperimentId(experimentId); | ||||
| // if (byExperimentId == null || byExperimentId.size() == 0){ | |||||
| // return new ArrayList<>(); | |||||
| // } | |||||
| if (byExperimentId == null || byExperimentId.size() == 0){ | |||||
| return new ArrayList<>(); | |||||
| } | |||||
| List<InsMetricInfoVo> aimRunInfoList = new ArrayList<>(); | List<InsMetricInfoVo> aimRunInfoList = new ArrayList<>(); | ||||
| for (Map<String, Object> run : response) { | for (Map<String, Object> run : response) { | ||||
| InsMetricInfoVo aimRunInfo = new InsMetricInfoVo(); | InsMetricInfoVo aimRunInfo = new InsMetricInfoVo(); | ||||
| String runHash = (String) run.get("run_hash"); | String runHash = (String) run.get("run_hash"); | ||||
| aimRunInfo.setRunId(runHash); | aimRunInfo.setRunId(runHash); | ||||
| Map params= (Map) run.get("params"); | Map params= (Map) run.get("params"); | ||||
| Map<String, Object> paramMap = JsonUtils.flattenJson("", params); | Map<String, Object> paramMap = JsonUtils.flattenJson("", params); | ||||
| aimRunInfo.setParams(paramMap); | aimRunInfo.setParams(paramMap); | ||||
| Map<String, Object> tracesMap= (Map<String, Object>) run.get("params"); | |||||
| String aimrunId = (String) paramMap.get("id"); | |||||
| Map<String, Object> tracesMap= (Map<String, Object>) run.get("traces"); | |||||
| List<Map<String, Object>> metricList = (List<Map<String, Object>>) tracesMap.get("metric"); | List<Map<String, Object>> metricList = (List<Map<String, Object>>) tracesMap.get("metric"); | ||||
| //过滤name为__system__开头的对象 | //过滤name为__system__开头的对象 | ||||
| aimRunInfo.setMetrics(new HashMap<>()); | aimRunInfo.setMetrics(new HashMap<>()); | ||||
| if (metricList != null && metricList.size() > 0){ | if (metricList != null && metricList.size() > 0){ | ||||
| List<Map<String, Object>> metricRelList = metricList.stream() | List<Map<String, Object>> metricRelList = metricList.stream() | ||||
| .filter(map -> !StringUtils.equals("__system__", (String) map.get("name"))) | |||||
| .filter(map -> !StringUtils.startsWith((String) map.get("name"),"__system__" )) | |||||
| .collect(Collectors.toList()); | .collect(Collectors.toList()); | ||||
| if (metricRelList!= null && metricRelList.size() > 0){ | if (metricRelList!= null && metricRelList.size() > 0){ | ||||
| Map<String, Object> relMetricMap = new HashMap<>(); | Map<String, Object> relMetricMap = new HashMap<>(); | ||||
| @@ -85,39 +92,84 @@ public class AimServiceImpl implements AimService { | |||||
| aimRunInfo.setMetrics(relMetricMap); | aimRunInfo.setMetrics(relMetricMap); | ||||
| } | } | ||||
| } | } | ||||
| //找到ins | //找到ins | ||||
| for (ExperimentIns ins : byExperimentId) { | for (ExperimentIns ins : byExperimentId) { | ||||
| String metricRecord = ins.getMetricRecord(); | |||||
| if (metricRecord.contains(runHash)){ | |||||
| String metricRecordString = ins.getMetricRecord(); | |||||
| if (StringUtils.isEmpty(metricRecordString)){ | |||||
| continue; | |||||
| } | |||||
| if (metricRecordString.contains(aimrunId)){ | |||||
| aimRunInfo.setExperimentInsId(ins.getId()); | aimRunInfo.setExperimentInsId(ins.getId()); | ||||
| aimRunInfo.setStatus(ins.getStatus()); | aimRunInfo.setStatus(ins.getStatus()); | ||||
| aimRunInfo.setStartTime(ins.getStartTime()); | |||||
| aimRunInfo.setStartTime(ins.getCreateTime()); | |||||
| Map<String, Object> metricRecordMap = JacksonUtil.parseJSONStr2Map(metricRecordString); | |||||
| //metricRecord 格式为{"train":[{"task_id":"model-train-35303690","run_id":"5560d78f54314672b60304c8d6ba03b8","experiment_name":"experiment-30-train"}],"evaluate":[{"task_id":"model-train-35303690","run_id":"5560d78f54314672b60304c8d6ba03b8","experiment_name":"experiment-30-train"}]} | |||||
| //遍历metricRecord,找到当前task_id对应的ModelDependency | |||||
| if (isTrain){ | |||||
| List<Map<String, Object>> trainList = (List<Map<String, Object>>) metricRecordMap.get("train"); | |||||
| List<String> trainDateSet = getTrainDateSet(trainList, ins.getId(), isTrain); | |||||
| aimRunInfo.setDataset(trainDateSet); | |||||
| }else { | |||||
| List<Map<String, Object>> trainList = (List<Map<String, Object>>) metricRecordMap.get("evaluate"); | |||||
| List<String> trainDateSet = getTrainDateSet(trainList, ins.getId(), isTrain); | |||||
| aimRunInfo.setDataset(trainDateSet); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| aimRunInfoList.add(aimRunInfo); | aimRunInfoList.add(aimRunInfo); | ||||
| } | } | ||||
| //判断哪个最长 | //判断哪个最长 | ||||
| Optional<InsMetricInfoVo> maxMetricsVo = aimRunInfoList.stream() | |||||
| .max((vo1, vo2) -> Integer.compare(vo1.getMetrics().size(), vo2.getMetrics().size())); | |||||
| // 获取所有 metrics 的 key 的并集 | |||||
| Set<String> metricsKeys = (Set<String>) aimRunInfoList.stream() | |||||
| .map(InsMetricInfoVo::getMetrics) | |||||
| .flatMap(metrics -> metrics.keySet().stream()) | |||||
| .collect(Collectors.toSet()); | |||||
| // 将并集赋值给每个 InsMetricInfoVo 的 metricsNames 属性 | |||||
| aimRunInfoList.forEach(vo -> vo.setMetricsNames(new ArrayList<>(metricsKeys))); | |||||
| // 获取所有 params 的 key 的并集 | |||||
| Set<String> paramKeys = (Set<String>) aimRunInfoList.stream() | |||||
| .map(InsMetricInfoVo::getParams) | |||||
| .flatMap(params -> params.keySet().stream()) | |||||
| .collect(Collectors.toSet()); | |||||
| // 将并集赋值给每个 InsMetricInfoVo 的 paramsNames 属性 | |||||
| aimRunInfoList.forEach(vo -> vo.setParamsNames(new ArrayList<>(paramKeys))); | |||||
| return aimRunInfoList; | |||||
| } | |||||
| // 如果找到了,设置 metricsFlag 为 true | |||||
| if (maxMetricsVo.isPresent()) { | |||||
| maxMetricsVo.get().setMetricsFlag(true); | |||||
| } | |||||
| Optional<InsMetricInfoVo> maxParamsVo = aimRunInfoList.stream() | |||||
| .max((vo1, vo2) -> Integer.compare(vo1.getParams().size(), vo2.getParams().size())); | |||||
| // 如果找到了,设置 metricsFlag 为 true | |||||
| if (maxParamsVo.isPresent()) { | |||||
| maxParamsVo.get().setMetricsFlag(true); | |||||
| private List<String> getTrainDateSet(List<Map<String, Object>> trainList,Integer expInsId,boolean isTrain){ | |||||
| if (trainList == null || trainList.size() == 0){ | |||||
| return new ArrayList<>(); | |||||
| } | } | ||||
| List<String> datasetList = new ArrayList<>(); | |||||
| for (Map<String, Object> trainMap : trainList) { | |||||
| String task_id = (String) trainMap.get("task_id"); | |||||
| //modelDependency取到数据集文件 | |||||
| ModelDependency modelDependency = modelDependencyService.queryByInsAndTrainTaskId(expInsId, task_id); | |||||
| //把数据集文件组装成String后放进List | |||||
| String datasetString = ""; | |||||
| if (isTrain){ | |||||
| datasetString = modelDependency.getTrainDataset(); | |||||
| }else { | |||||
| datasetString = modelDependency.getTestDataset(); | |||||
| } | |||||
| List<Map<String, Object>> datasetListMap = JacksonUtil.parseJSONStr2MapList(datasetString); | |||||
| return aimRunInfoList; | |||||
| if (datasetListMap != null && datasetListMap.size() > 0){ | |||||
| for (Map<String, Object> datasetMap : datasetListMap) { | |||||
| //[{"dataset_id":20,"dataset_version":"v0.1.0","dataset_name":"手写体识别模型依赖测试训练数据集"}] | |||||
| String datasetName = (String) datasetMap.get("dataset_name")+":"+(String) datasetMap.get("dataset_version"); | |||||
| datasetList.add(datasetName); | |||||
| } | |||||
| } | |||||
| } | |||||
| return datasetList; | |||||
| } | } | ||||
| } | } | ||||
| @@ -17,10 +17,7 @@ import org.springframework.data.domain.PageRequest; | |||||
| import javax.annotation.Resource; | import javax.annotation.Resource; | ||||
| import java.io.IOException; | import java.io.IOException; | ||||
| import java.util.ArrayList; | |||||
| import java.util.Date; | |||||
| import java.util.List; | |||||
| import java.util.Map; | |||||
| import java.util.*; | |||||
| import java.util.stream.Collectors; | import java.util.stream.Collectors; | ||||
| /** | /** | ||||
| @@ -97,6 +94,16 @@ public class ModelDependencyServiceImpl implements ModelDependencyService { | |||||
| return modelDependcyTreeVo; | return modelDependcyTreeVo; | ||||
| } | } | ||||
| @Override | |||||
| public List<ModelDependency> queryByIns(Integer expInsId) { | |||||
| return modelDependencyDao.queryByIns(expInsId); | |||||
| } | |||||
| @Override | |||||
| public ModelDependency queryByInsAndTrainTaskId(Integer expInsId, String taskId) { | |||||
| return modelDependencyDao.queryByInsAndTrainTaskId(expInsId,taskId); | |||||
| } | |||||
| /** | /** | ||||
| * 递归父模型 | * 递归父模型 | ||||
| * @param modelDependcyTreeVo | * @param modelDependcyTreeVo | ||||
| @@ -18,7 +18,7 @@ public class InsMetricInfoVo implements Serializable { | |||||
| @ApiModelProperty(value = "实例运行状态") | @ApiModelProperty(value = "实例运行状态") | ||||
| private String status; | private String status; | ||||
| @ApiModelProperty(value = "使用数据集") | @ApiModelProperty(value = "使用数据集") | ||||
| private List<Map<String, Object>> dataset; | |||||
| private List<String> dataset; | |||||
| @ApiModelProperty(value = "实例ID") | @ApiModelProperty(value = "实例ID") | ||||
| private Integer experimentInsId; | private Integer experimentInsId; | ||||
| @ApiModelProperty(value = "训练指标") | @ApiModelProperty(value = "训练指标") | ||||
| @@ -27,6 +27,6 @@ public class InsMetricInfoVo implements Serializable { | |||||
| private Map params; | private Map params; | ||||
| @ApiModelProperty(value = "训练记录ID") | @ApiModelProperty(value = "训练记录ID") | ||||
| private String runId; | private String runId; | ||||
| private Boolean metricsFlag = false; | |||||
| private Boolean paramsFlag = false; | |||||
| private List<String> metricsNames; | |||||
| private List<String> paramsNames; | |||||
| } | } | ||||
| @@ -22,6 +22,22 @@ | |||||
| <result property="state" column="state" jdbcType="INTEGER"/> | <result property="state" column="state" jdbcType="INTEGER"/> | ||||
| </resultMap> | </resultMap> | ||||
| <select id="queryByIns" resultMap="ModelDependencyMap"> | |||||
| select | |||||
| id,current_model_id,exp_ins_id,parent_models,ref_item,train_task,train_dataset,train_params,train_image,test_dataset,project_dependency,version,create_by,create_time,update_by,update_time,state | |||||
| from model_dependency | |||||
| <where> | |||||
| exp_ins_id = #{expInsId} and state = 1 | |||||
| </where> | |||||
| </select> | |||||
| <select id="queryByInsAndTrainTaskId" resultMap="ModelDependencyMap"> | |||||
| select | |||||
| id,current_model_id,exp_ins_id,parent_models,ref_item,train_task,train_dataset,train_params,train_image,test_dataset,project_dependency,version,create_by,create_time,update_by,update_time,state | |||||
| from model_dependency | |||||
| <where> | |||||
| exp_ins_id = #{expInsId} and train_task like concat('%', #{taskId}, '%') limit 1 | |||||
| </where> | |||||
| </select> | |||||
| <select id="queryChildrenByVersionId" resultMap="ModelDependencyMap"> | <select id="queryChildrenByVersionId" resultMap="ModelDependencyMap"> | ||||
| select | select | ||||
| id,current_model_id,exp_ins_id,parent_models,ref_item,train_task,train_dataset,train_params,train_image,test_dataset,project_dependency,version,create_by,create_time,update_by,update_time,state | id,current_model_id,exp_ins_id,parent_models,ref_item,train_task,train_dataset,train_params,train_image,test_dataset,project_dependency,version,create_by,create_time,update_by,update_time,state | ||||