| @@ -31,6 +31,7 @@ public class Constant { | |||
| public final static String Pending = "Pending"; | |||
| public final static String Init = "Init"; | |||
| public final static String Stopped = "Stopped"; | |||
| public final static String Succeeded = "Succeeded"; | |||
| public final static String Type_Train = "train"; | |||
| public final static String Type_Evaluate = "evaluate"; | |||
| @@ -6,6 +6,7 @@ import com.fasterxml.jackson.databind.PropertyNamingStrategy; | |||
| import com.fasterxml.jackson.databind.annotation.JsonNaming; | |||
| import io.swagger.annotations.ApiModel; | |||
| import io.swagger.annotations.ApiModelProperty; | |||
| import lombok.Data; | |||
| import java.io.Serializable; | |||
| import java.util.Date; | |||
| @@ -18,6 +19,7 @@ import java.util.Date; | |||
| */ | |||
| @JsonNaming(PropertyNamingStrategy.SnakeCaseStrategy.class) | |||
| @ApiModel("实验实例对象") | |||
| @Data | |||
| public class ExperimentIns implements Serializable { | |||
| private static final long serialVersionUID = 623464560240790680L; | |||
| @ApiModelProperty(name = "id") | |||
| @@ -53,6 +55,10 @@ public class ExperimentIns implements Serializable { | |||
| @JsonRawValue | |||
| private String metricRecord; | |||
| @ApiModelProperty(value = "指标数值", notes = "以JSON字符串格式存储") | |||
| @JsonRawValue | |||
| private String metricValue; | |||
| @ApiModelProperty(value = "开始时间") | |||
| private Date startTime; | |||
| @@ -81,162 +87,5 @@ public class ExperimentIns implements Serializable { | |||
| @TableField(exist = false) | |||
| private String experimentName; | |||
| /** No-argument constructor; the body is empty and performs no initialization. */ public ExperimentIns() { | |||
| } | |||
| /** Returns the current value of the {@code id} field. */ public Integer getId() { | |||
| return id; | |||
| } | |||
| /** Assigns the given value to the {@code id} field. */ public void setId(Integer id) { | |||
| this.id = id; | |||
| } | |||
| /** Returns the current value of the {@code experimentId} field. */ public Integer getExperimentId() { | |||
| return experimentId; | |||
| } | |||
| /** Assigns the given value to the {@code experimentId} field. */ public void setExperimentId(Integer experimentId) { | |||
| this.experimentId = experimentId; | |||
| } | |||
| /** Returns the current value of the {@code argoInsName} field. */ public String getArgoInsName() { | |||
| return argoInsName; | |||
| } | |||
| /** Assigns the given value to the {@code argoInsName} field. */ public void setArgoInsName(String argoInsName) { | |||
| this.argoInsName = argoInsName; | |||
| } | |||
| /** Returns the current value of the {@code argoInsNs} field. */ public String getArgoInsNs() { | |||
| return argoInsNs; | |||
| } | |||
| /** Assigns the given value to the {@code argoInsNs} field. */ public void setArgoInsNs(String argoInsNs) { | |||
| this.argoInsNs = argoInsNs; | |||
| } | |||
| /** Returns the current value of the {@code status} field. */ public String getStatus() { | |||
| return status; | |||
| } | |||
| /** Assigns the given value to the {@code status} field. */ public void setStatus(String status) { | |||
| this.status = status; | |||
| } | |||
| /** Returns the current value of the {@code nodesStatus} field. */ public String getNodesStatus() { | |||
| return nodesStatus; | |||
| } | |||
| /** Assigns the given value to the {@code nodesStatus} field. */ public void setNodesStatus(String nodesStatus) { | |||
| this.nodesStatus = nodesStatus; | |||
| } | |||
| /** Returns the current value of the {@code nodesResult} field. */ public String getNodesResult() { | |||
| return nodesResult; | |||
| } | |||
| /** Assigns the given value to the {@code nodesResult} field. */ public void setNodesResult(String nodesResult) { | |||
| this.nodesResult = nodesResult; | |||
| } | |||
| /** Returns the current value of the {@code nodesLogs} field. */ public String getNodesLogs() { | |||
| return nodesLogs; | |||
| } | |||
| /** Assigns the given value to the {@code nodesLogs} field. */ public void setNodesLogs(String nodesLogs) { | |||
| this.nodesLogs = nodesLogs; | |||
| } | |||
| /** Returns the current value of the {@code globalParam} field. */ public String getGlobalParam() { | |||
| return globalParam; | |||
| } | |||
| /** Assigns the given value to the {@code globalParam} field. */ public void setGlobalParam(String globalParam) { | |||
| this.globalParam = globalParam; | |||
| } | |||
| /** Assigns the given value to the {@code startTime} field. */ public void setStartTime(Date startTime) { | |||
| this.startTime = startTime; | |||
| } | |||
| /** Returns the current value of the {@code startTime} field. */ public Date getStartTime() { | |||
| return startTime; | |||
| } | |||
| /** Assigns the given value to the {@code finishTime} field. */ public void setFinishTime(Date finishTime) { | |||
| this.finishTime = finishTime; | |||
| } | |||
| /** Returns the current value of the {@code finishTime} field. */ public Date getFinishTime() { | |||
| return finishTime; | |||
| } | |||
| /** Returns the current value of the {@code createBy} field. */ public String getCreateBy() { | |||
| return createBy; | |||
| } | |||
| /** Assigns the given value to the {@code createBy} field. */ public void setCreateBy(String createBy) { | |||
| this.createBy = createBy; | |||
| } | |||
| /** Returns the current value of the {@code createTime} field. */ public Date getCreateTime() { | |||
| return createTime; | |||
| } | |||
| /** Assigns the given value to the {@code createTime} field. */ public void setCreateTime(Date createTime) { | |||
| this.createTime = createTime; | |||
| } | |||
| /** Returns the current value of the {@code updateBy} field. */ public String getUpdateBy() { | |||
| return updateBy; | |||
| } | |||
| /** Assigns the given value to the {@code updateBy} field. */ public void setUpdateBy(String updateBy) { | |||
| this.updateBy = updateBy; | |||
| } | |||
| /** Returns the current value of the {@code updateTime} field. */ public Date getUpdateTime() { | |||
| return updateTime; | |||
| } | |||
| /** Assigns the given value to the {@code updateTime} field. */ public void setUpdateTime(Date updateTime) { | |||
| this.updateTime = updateTime; | |||
| } | |||
| /** Returns the current value of the {@code state} field. */ public Integer getState() { | |||
| return state; | |||
| } | |||
| /** Assigns the given value to the {@code state} field. */ public void setState(Integer state) { | |||
| this.state = state; | |||
| } | |||
| /** Returns the current value of the {@code workflowId} field. */ public Long getWorkflowId() { | |||
| return workflowId; | |||
| } | |||
| /** Assigns the given value to the {@code workflowId} field. */ public void setWorkflowId(Long workflowId) { | |||
| this.workflowId = workflowId; | |||
| } | |||
| /** Returns the current value of the {@code metricRecord} field (a {@code String} annotated with {@code @JsonRawValue}). */ public String getMetricRecord() { | |||
| return metricRecord; | |||
| } | |||
| /** Assigns the given value to the {@code metricRecord} field. */ public void setMetricRecord(String metricRecord) { | |||
| this.metricRecord = metricRecord; | |||
| } | |||
| /** Returns the current value of the {@code experimentName} field (declared with {@code @TableField(exist = false)}). */ public String getExperimentName() { | |||
| return experimentName; | |||
| } | |||
| /** Assigns the given value to the {@code experimentName} field. */ public void setExperimentName(String experimentName) { | |||
| this.experimentName = experimentName; | |||
| } | |||
| } | |||
| @@ -1,13 +1,15 @@ | |||
| package com.ruoyi.platform.scheduling; | |||
| import com.ruoyi.platform.constant.Constant; | |||
| import com.ruoyi.platform.domain.Experiment; | |||
| import com.ruoyi.platform.domain.ExperimentIns; | |||
| import com.ruoyi.platform.domain.ModelDependency; | |||
| import com.ruoyi.platform.mapper.ExperimentDao; | |||
| import com.ruoyi.platform.mapper.ExperimentInsDao; | |||
| import com.ruoyi.platform.mapper.ModelDependencyDao; | |||
| import com.ruoyi.platform.service.AimService; | |||
| import com.ruoyi.platform.service.ExperimentInsService; | |||
| import com.ruoyi.platform.utils.JacksonUtil; | |||
| import com.ruoyi.platform.utils.JsonUtils; | |||
| import com.ruoyi.platform.vo.InsMetricInfoVo; | |||
| import org.apache.commons.lang3.StringUtils; | |||
| import org.springframework.beans.factory.annotation.Autowired; | |||
| import org.springframework.scheduling.annotation.Scheduled; | |||
| @@ -26,11 +28,12 @@ public class ExperimentInstanceStatusTask { | |||
| @Resource | |||
| private ExperimentInsDao experimentInsDao; | |||
| @Resource | |||
| private ModelDependencyDao modelDependencyDao; | |||
| private AimService aimService; | |||
| private List<Integer> experimentIds = new ArrayList<>(); | |||
| @Scheduled(cron = "0/30 * * * * ?") // 每30S执行一次 | |||
| public void executeExperimentInsStatus() throws IOException { | |||
| public void executeExperimentInsStatus() throws Exception { | |||
| // First, fetch every experiment instance that is not yet in a terminal state | |||
| List<ExperimentIns> experimentInsList = experimentInsService.queryByExperimentIsNotTerminated(); | |||
| // Then query Argo for the current status | |||
| @@ -38,12 +41,42 @@ public class ExperimentInstanceStatusTask { | |||
| if (experimentInsList != null && experimentInsList.size() > 0) { | |||
| for (ExperimentIns experimentIns : experimentInsList) { | |||
| // Only call the Argo API when the previous status is null or non-terminal | |||
| String oldStatus = experimentIns.getStatus(); | |||
| try { | |||
| experimentIns = experimentInsService.queryStatusFromArgo(experimentIns); | |||
| } catch (Exception e) { | |||
| experimentIns.setStatus("Failed"); | |||
| } | |||
| // For experiment instances that ran successfully, record their metric values | |||
| if (Constant.Succeeded.equals(experimentIns.getStatus())) { | |||
| Map<String, Object> metricRecord = JacksonUtil.parseJSONStr2Map(experimentIns.getMetricRecord()); | |||
| List<Map<String, Object>> trainMetricRecord = (List<Map<String, Object>>) metricRecord.get("train"); | |||
| List<Map<String, Object>> evaluateMetricRecord = (List<Map<String, Object>>) metricRecord.get("evaluate"); | |||
| HashMap<String, Object> metricValue = new HashMap<>(); | |||
| HashMap<String, Object> trainMetricValue = new HashMap<>(); | |||
| HashMap<String, Object> evaluateMetricValue = new HashMap<>(); | |||
| if (trainMetricRecord != null && !trainMetricRecord.isEmpty()) { | |||
| String runId = (String) trainMetricRecord.get(0).get("run_id"); | |||
| List<InsMetricInfoVo> expTrainInfos = aimService.getExpTrainInfos1(true, experimentIns.getExperimentId(), runId); | |||
| for (InsMetricInfoVo expTrainInfo : expTrainInfos) { | |||
| Map metrics = expTrainInfo.getMetrics(); | |||
| trainMetricValue.putAll(metrics); | |||
| } | |||
| } | |||
| if (evaluateMetricRecord != null && !evaluateMetricRecord.isEmpty()) { | |||
| String runId = (String) evaluateMetricRecord.get(0).get("run_id"); | |||
| List<InsMetricInfoVo> expTrainInfos = aimService.getExpTrainInfos1(false, experimentIns.getExperimentId(), runId); | |||
| for (InsMetricInfoVo expTrainInfo : expTrainInfos) { | |||
| Map metrics = expTrainInfo.getMetrics(); | |||
| evaluateMetricValue.putAll(metrics); | |||
| } | |||
| } | |||
| metricValue.put("train", trainMetricValue); | |||
| metricValue.put("evaluate", evaluateMetricValue); | |||
| experimentIns.setMetricValue(JsonUtils.mapToJson(metricValue)); | |||
| } | |||
| experimentIns.setUpdateTime(new Date()); | |||
| // Thread-safe add operation | |||
| synchronized (experimentIds) { | |||