Browse Source

自动机器学习开发

dev-automl
chenzhihang 1 year ago
parent
commit
2f5c3d4476
12 changed files with 531 additions and 103 deletions
  1. +1
    -1
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/autoML/AutoMlController.java
  2. +6
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/autoML/AutoMlInsController.java
  3. +2
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/domain/AutoMl.java
  4. +7
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/domain/AutoMlIns.java
  5. +4
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/mapper/AutoMlInsDao.java
  6. +97
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/AutoMlInsStatusTask.java
  7. +6
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AutoMlInsService.java
  8. +167
    -1
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AutoMlInsServiceImpl.java
  9. +61
    -96
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AutoMlServiceImpl.java
  10. +155
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/AutoMlParamVo.java
  11. +3
    -3
      ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/AutoMlDao.xml
  12. +22
    -2
      ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/AutoMlInsDao.xml

+ 1
- 1
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/autoML/AutoMlController.java View File

@@ -63,7 +63,7 @@ public class AutoMlController extends BaseController {
return AjaxResult.success(this.autoMlService.upload(file, uuid)); return AjaxResult.success(this.autoMlService.upload(file, uuid));
} }


@PostMapping("{id}")
@PostMapping("/run/{id}")
@ApiOperation("运行自动机器学习实验") @ApiOperation("运行自动机器学习实验")
public GenericsAjaxResult<String> runAutoML(@PathVariable("id") Long id) throws Exception { public GenericsAjaxResult<String> runAutoML(@PathVariable("id") Long id) throws Exception {
return genericsSuccess(this.autoMlService.runAutoMlIns(id)); return genericsSuccess(this.autoMlService.runAutoMlIns(id));


+ 6
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/autoML/AutoMlInsController.java View File

@@ -46,4 +46,10 @@ public class AutoMlInsController extends BaseController {
public GenericsAjaxResult<String> batchDelete(@RequestBody List<Long> ids) { public GenericsAjaxResult<String> batchDelete(@RequestBody List<Long> ids) {
return genericsSuccess(this.autoMLInsService.batchDelete(ids)); return genericsSuccess(this.autoMLInsService.batchDelete(ids));
} }

/**
 * Terminates the experiment instance identified by {@code id}.
 *
 * @param id primary key of the experiment instance, taken from the request path
 * @return success envelope carrying the boolean reported by the service
 */
@PutMapping("{id}")
@ApiOperation("终止实验实例")
public GenericsAjaxResult<Boolean> terminateAutoMlIns(@PathVariable("id") Long id) {
    boolean terminated = this.autoMLInsService.terminateAutoMlIns(id);
    return genericsSuccess(terminated);
}
} }

+ 2
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/domain/AutoMl.java View File

@@ -179,4 +179,6 @@ public class AutoMl {


private String dataset; private String dataset;


@ApiModelProperty(value = "状态列表")
private String statusList;
} }

+ 7
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/domain/AutoMlIns.java View File

@@ -3,6 +3,7 @@ package com.ruoyi.platform.domain;
import com.fasterxml.jackson.databind.PropertyNamingStrategy; import com.fasterxml.jackson.databind.PropertyNamingStrategy;
import com.fasterxml.jackson.databind.annotation.JsonNaming; import com.fasterxml.jackson.databind.annotation.JsonNaming;
import io.swagger.annotations.ApiModel; import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data; import lombok.Data;


import java.util.Date; import java.util.Date;
@@ -31,6 +32,12 @@ public class AutoMlIns {


private String source; private String source;


@ApiModelProperty(value = "Argo实例名称")
private String argoInsName;

@ApiModelProperty(value = "Argo命名空间")
private String argoInsNs;

private Date createTime; private Date createTime;


private Date updateTime; private Date updateTime;


+ 4
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/mapper/AutoMlInsDao.java View File

@@ -11,9 +11,13 @@ public interface AutoMlInsDao {


List<AutoMlIns> queryAllByLimit(@Param("autoMlIns") AutoMlIns autoMlIns, @Param("pageable") Pageable pageable); List<AutoMlIns> queryAllByLimit(@Param("autoMlIns") AutoMlIns autoMlIns, @Param("pageable") Pageable pageable);


List<AutoMlIns> getByAutoMlId(@Param("autoMlId") Long AutoMlId);

int insert(@Param("autoMlIns") AutoMlIns autoMlIns); int insert(@Param("autoMlIns") AutoMlIns autoMlIns);


int update(@Param("autoMlIns") AutoMlIns autoMlIns); int update(@Param("autoMlIns") AutoMlIns autoMlIns);


AutoMlIns queryById(@Param("id") Long id); AutoMlIns queryById(@Param("id") Long id);

List<AutoMlIns> queryByAutoMlInsIsNotTerminated();
} }

+ 97
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/AutoMlInsStatusTask.java View File

@@ -0,0 +1,97 @@
package com.ruoyi.platform.scheduling;

import com.ruoyi.platform.domain.AutoMl;
import com.ruoyi.platform.domain.AutoMlIns;
import com.ruoyi.platform.mapper.AutoMlDao;
import com.ruoyi.platform.mapper.AutoMlInsDao;
import com.ruoyi.platform.service.AutoMlInsService;
import org.apache.commons.lang3.StringUtils;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;

import javax.annotation.Resource;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Scheduled jobs that keep AutoML experiment instances, and the status summary
 * stored on their parent experiment, in sync with Argo.
 *
 * <p>{@link #executeAutoMlInsStatus()} refreshes every non-terminated instance and
 * records the ids of the experiments it touched; {@link #executeAutoMlStatus()}
 * consumes those ids and rewrites each experiment's comma-separated status list.
 */
@Component()
public class AutoMlInsStatusTask {

    @Resource
    private AutoMlInsService autoMlInsService;

    @Resource
    private AutoMlInsDao autoMlInsDao;

    @Resource
    private AutoMlDao autoMlDao;

    /**
     * Ids of experiments whose status list still has to be recomputed.
     *
     * <p>A concurrent set is used so that (a) the same experiment is queued only once no
     * matter how many of its instances are running, and (b) adding and draining stay safe
     * even if the two scheduled methods run on different scheduler threads.
     */
    private final Set<Long> autoMlIds = ConcurrentHashMap.newKeySet();

    /**
     * Refreshes the status of every non-terminated experiment instance from Argo,
     * persists the result and queues the owning experiment for a status-list refresh.
     *
     * @throws Exception if loading the instances or persisting an update fails
     */
    @Scheduled(cron = "0/30 * * * * ?") // runs every 30 seconds
    public void executeAutoMlInsStatus() throws Exception {
        // Start from all instances that have not reached a terminal state yet.
        List<AutoMlIns> autoMlInsList = autoMlInsService.queryByAutoMlInsIsNotTerminated();
        if (autoMlInsList == null || autoMlInsList.isEmpty()) {
            return;
        }

        // Phase 1: ask Argo for the current state of each instance.
        List<AutoMlIns> updateList = new ArrayList<>(autoMlInsList.size());
        for (AutoMlIns autoMlIns : autoMlInsList) {
            AutoMlIns refreshed = autoMlIns;
            try {
                refreshed = autoMlInsService.queryStatusFromArgo(autoMlIns);
            } catch (Exception e) {
                // Existing policy: an instance whose status cannot be queried is marked as failed.
                autoMlIns.setStatus("Failed");
            }
            updateList.add(refreshed);
        }

        // Phase 2: persist, and only then queue the experiment id, so the status-list job
        // never reads an instance row that has not been written yet.
        for (AutoMlIns autoMlIns : updateList) {
            autoMlInsDao.update(autoMlIns);
            Long autoMlId = autoMlIns.getAutoMlId();
            if (autoMlId != null) { // the concurrent set rejects null keys
                autoMlIds.add(autoMlId);
            }
        }
    }

    /**
     * Recomputes the status list of every queued experiment and stores it when it changed.
     *
     * <p>Each id is removed from the queue before it is processed, so an id that is queued
     * again while it is being processed is handled on the next run. If processing fails the
     * id is put back and the exception is propagated.
     *
     * @throws Exception if reading the instances or updating the experiment fails
     */
    @Scheduled(cron = "0/30 * * * * ?") // runs every 30 seconds
    public void executeAutoMlStatus() throws Exception {
        // Iterate over a snapshot; the live set may receive new ids meanwhile.
        for (Long autoMlId : new ArrayList<>(autoMlIds)) {
            autoMlIds.remove(autoMlId);
            try {
                refreshStatusList(autoMlId);
            } catch (Exception e) {
                autoMlIds.add(autoMlId); // retry on the next run
                throw e;
            }
        }
    }

    /**
     * Builds the ", "-separated status list of one experiment from its instances and
     * persists it when it differs from the stored value.
     */
    private void refreshStatusList(Long autoMlId) throws Exception {
        List<AutoMlIns> insList = autoMlInsDao.getByAutoMlId(autoMlId);
        List<String> statusList = new ArrayList<>();
        if (insList != null) {
            for (AutoMlIns ins : insList) {
                statusList.add(ins.getStatus());
            }
        }
        // Same text as List.toString() without the surrounding brackets (null prints as "null").
        String subStatus = String.join(", ", statusList);

        AutoMl autoMl = autoMlDao.getAutoMlById(autoMlId);
        if (autoMl == null) {
            // The experiment is no longer available; nothing to update.
            return;
        }
        if (!StringUtils.equals(autoMl.getStatusList(), subStatus)) {
            autoMl.setStatusList(subStatus);
            autoMlDao.edit(autoMl);
        }
    }
}

+ 6
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AutoMlInsService.java View File

@@ -15,4 +15,10 @@ public interface AutoMlInsService {
String removeById(Long id); String removeById(Long id);


String batchDelete(List<Long> ids); String batchDelete(List<Long> ids);

List<AutoMlIns> queryByAutoMlInsIsNotTerminated();

AutoMlIns queryStatusFromArgo(AutoMlIns autoMlIns);

boolean terminateAutoMlIns(Long id);
} }

+ 167
- 1
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AutoMlInsServiceImpl.java View File

@@ -4,6 +4,11 @@ import com.ruoyi.platform.constant.Constant;
import com.ruoyi.platform.domain.AutoMlIns; import com.ruoyi.platform.domain.AutoMlIns;
import com.ruoyi.platform.mapper.AutoMlInsDao; import com.ruoyi.platform.mapper.AutoMlInsDao;
import com.ruoyi.platform.service.AutoMlInsService; import com.ruoyi.platform.service.AutoMlInsService;
import com.ruoyi.platform.utils.DateUtils;
import com.ruoyi.platform.utils.HttpUtils;
import com.ruoyi.platform.utils.JsonUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.data.domain.Page; import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageImpl; import org.springframework.data.domain.PageImpl;
import org.springframework.data.domain.PageRequest; import org.springframework.data.domain.PageRequest;
@@ -11,10 +16,17 @@ import org.springframework.stereotype.Service;


import javax.annotation.Resource; import javax.annotation.Resource;
import java.io.IOException; import java.io.IOException;
import java.util.List;
import java.util.*;


@Service @Service
public class AutoMlInsServiceImpl implements AutoMlInsService { public class AutoMlInsServiceImpl implements AutoMlInsService {
@Value("${argo.url}")
private String argoUrl;
@Value("${argo.workflowStatus}")
private String argoWorkflowStatus;
@Value("${argo.workflowTermination}")
private String argoWorkflowTermination;

@Resource @Resource
private AutoMlInsDao autoMlInsDao; private AutoMlInsDao autoMlInsDao;


@@ -38,6 +50,14 @@ public class AutoMlInsServiceImpl implements AutoMlInsService {
if (autoMlIns == null) { if (autoMlIns == null) {
return "实验实例不存在"; return "实验实例不存在";
} }

if (StringUtils.isEmpty(autoMlIns.getStatus())) {
autoMlIns = queryStatusFromArgo(autoMlIns);
}
if (StringUtils.equals(autoMlIns.getStatus(), "Running")) {
return "实验实例正在运行,不可删除";
}

autoMlIns.setState(Constant.State_invalid); autoMlIns.setState(Constant.State_invalid);
int update = autoMlInsDao.update(autoMlIns); int update = autoMlInsDao.update(autoMlIns);
if (update > 0) { if (update > 0) {
@@ -57,4 +77,150 @@ public class AutoMlInsServiceImpl implements AutoMlInsService {
} }
return "删除成功"; return "删除成功";
} }

/**
 * Returns the experiment instances that the DAO reports as not yet terminated.
 *
 * @return the DAO result, passed through unchanged
 */
@Override
public List<AutoMlIns> queryByAutoMlInsIsNotTerminated() {
    final List<AutoMlIns> unfinished = autoMlInsDao.queryByAutoMlInsIsNotTerminated();
    return unfinished;
}

/**
 * Queries the Argo workflow-status endpoint for the given instance and copies the
 * result onto it.
 *
 * <p>The instance is updated in place: its update time is set from the workflow's
 * {@code finishedAt} (when present), its node status becomes the JSON of the workflow
 * nodes keyed by {@code displayName}, and its status becomes the workflow {@code phase}
 * ({@code "Pending"} when the phase is empty). An instance already marked
 * {@code "Terminated"} keeps that status, and {@code "Error"} is reported as {@code "Failed"}.
 *
 * @param ins instance carrying the Argo workflow name and namespace
 * @return the same {@code ins} object, updated
 * @throws RuntimeException if the request fails or the response lacks data/status
 */
@Override
public AutoMlIns queryStatusFromArgo(AutoMlIns ins) {
    // Payload shape: {"data": {"namespace": ..., "name": ...}}
    Map<String, Object> workflowRef = new HashMap<>();
    workflowRef.put("namespace", ins.getArgoInsNs());
    workflowRef.put("name", ins.getArgoInsName());
    Map<String, Object> payload = new HashMap<>();
    payload.put("data", workflowRef);

    try {
        String response = HttpUtils.sendPost(argoUrl + argoWorkflowStatus, null, JsonUtils.mapToJson(payload));
        if (StringUtils.isEmpty(response)) {
            throw new RuntimeException("工作流状态响应为空。");
        }

        Map<String, Object> body = JsonUtils.jsonToMap(response);
        Map<String, Object> data = (Map<String, Object>) body.get("data");
        if (data == null || data.isEmpty()) {
            throw new RuntimeException("工作流数据为空.");
        }
        Map<String, Object> status = (Map<String, Object>) data.get("status");
        if (status == null || status.isEmpty()) {
            throw new RuntimeException("工作流状态为空。");
        }

        // Workflow end time, converted from UTC by the project helper.
        String finishedAt = (String) status.get("finishedAt");
        if (StringUtils.isNotEmpty(finishedAt)) {
            ins.setUpdateTime(DateUtils.convertUTCtoShanghaiDate(finishedAt));
        }

        // Re-key the node details by their display name, preserving iteration order.
        Map<String, Object> nodesByDisplayName = new LinkedHashMap<>();
        Map<String, Object> nodes = (Map<String, Object>) status.get("nodes");
        if (nodes != null) {
            for (Object rawNode : nodes.values()) {
                Map<String, Object> nodeDetails = (Map<String, Object>) rawNode;
                nodesByDisplayName.put((String) nodeDetails.get("displayName"), nodeDetails);
            }
        }
        ins.setNodeStatus(JsonUtils.mapToJson(nodesByDisplayName));

        // A status of "Terminated" is final and is not overwritten by the Argo phase.
        if (!"Terminated".equals(ins.getStatus())) {
            String phase = (String) status.get("phase");
            ins.setStatus(StringUtils.isEmpty(phase) ? "Pending" : phase);
        }
        if ("Error".equals(ins.getStatus())) {
            ins.setStatus("Failed");
        }
        return ins;
    } catch (Exception e) {
        throw new RuntimeException("查询状态失败: " + e.getMessage(), e);
    }
}

/**
 * Asks Argo to terminate the workflow behind an experiment instance and, on success,
 * stores the instance as terminated.
 *
 * <p>Only an instance whose status equals {@code Constant.Running} (case-insensitive) is
 * terminated; when the stored status is empty it is first fetched from Argo. After Argo
 * answers with {@code errCode == 0} the node status is refreshed, every node phase other
 * than {@code "Succeeded"} is rewritten to {@code "Failed"}, and the instance is saved
 * with status {@code Constant.Terminated} and the current time as update time.
 *
 * @param id primary key of the experiment instance
 * @return {@code true} if the instance was terminated and saved; {@code false} if it was
 *         not running or Argo did not report {@code errCode == 0}
 * @throws IllegalStateException if no instance exists for {@code id}
 * @throws RuntimeException if the termination request or the follow-up processing fails
 */
@Override
public boolean terminateAutoMlIns(Long id) {
    AutoMlIns target = autoMlInsDao.queryById(id);
    if (target == null) {
        throw new IllegalStateException("实验实例未查询到,id: " + id);
    }

    String statusNow = target.getStatus();
    String workflowName = target.getArgoInsName();
    String workflowNs = target.getArgoInsNs();

    // Fall back to Argo when no status has been stored yet.
    if (StringUtils.isEmpty(statusNow)) {
        statusNow = queryStatusFromArgo(target).getStatus();
    }
    // Only a running instance can be terminated.
    if (!statusNow.equalsIgnoreCase(Constant.Running)) {
        return false;
    }

    // Payload shape: {"data": {"namespace": ..., "name": ...}}
    Map<String, Object> workflowRef = new HashMap<>();
    workflowRef.put("namespace", workflowNs);
    workflowRef.put("name", workflowName);
    Map<String, Object> payload = new HashMap<>();
    payload.put("data", workflowRef);

    try {
        String response = HttpUtils.sendPost(argoUrl + argoWorkflowTermination, null, JsonUtils.mapToJson(payload));
        if (StringUtils.isEmpty(response)) {
            throw new RuntimeException("终止响应内容为空。");
        }

        Map<String, Object> body = JsonUtils.jsonToMap(response);
        Integer errCode = (Integer) body.get("errCode");
        if (errCode == null || errCode != 0) {
            return false;
        }

        // Pull the latest node status so the stored record reflects the terminated run.
        AutoMlIns refreshed = queryStatusFromArgo(target);
        Map<String, Object> nodesByName = JsonUtils.jsonToMap(refreshed.getNodeStatus());
        for (Object rawNode : nodesByName.values()) {
            Map<String, Object> node = (Map<String, Object>) rawNode;
            if (!node.containsKey("phase")) {
                continue;
            }
            String phase = (String) node.get("phase");
            // Any node that did not succeed is reported as failed.
            if (!"Succeeded".equals(phase)) {
                node.put("phase", "Failed");
            }
        }

        refreshed.setNodeStatus(JsonUtils.mapToJson(nodesByName));
        refreshed.setStatus(Constant.Terminated);
        refreshed.setUpdateTime(new Date());
        this.autoMlInsDao.update(refreshed);
        return true;
    } catch (Exception e) {
        throw new RuntimeException("终止实例错误: " + e.getMessage(), e);
    }
}
} }

+ 61
- 96
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AutoMlServiceImpl.java View File

@@ -3,15 +3,19 @@ package com.ruoyi.platform.service.impl;
import com.ruoyi.common.security.utils.SecurityUtils; import com.ruoyi.common.security.utils.SecurityUtils;
import com.ruoyi.platform.constant.Constant; import com.ruoyi.platform.constant.Constant;
import com.ruoyi.platform.domain.AutoMl; import com.ruoyi.platform.domain.AutoMl;
import com.ruoyi.platform.domain.AutoMlIns;
import com.ruoyi.platform.mapper.AutoMlDao; import com.ruoyi.platform.mapper.AutoMlDao;
import com.ruoyi.platform.mapper.AutoMlInsDao;
import com.ruoyi.platform.service.AutoMlService; import com.ruoyi.platform.service.AutoMlService;
import com.ruoyi.platform.utils.*;
import com.ruoyi.platform.utils.FileUtil;
import com.ruoyi.platform.utils.HttpUtils;
import com.ruoyi.platform.utils.JacksonUtil;
import com.ruoyi.platform.utils.JsonUtils;
import com.ruoyi.platform.vo.AutoMlParamVo;
import com.ruoyi.platform.vo.AutoMlVo; import com.ruoyi.platform.vo.AutoMlVo;
import io.kubernetes.client.openapi.models.V1Pod;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeanUtils; import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.data.domain.Page; import org.springframework.data.domain.Page;
@@ -26,29 +30,24 @@ import java.io.IOException;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.CompletableFuture;


@Service("autoMLService") @Service("autoMLService")
public class AutoMlServiceImpl implements AutoMlService { public class AutoMlServiceImpl implements AutoMlService {
@Value("${harbor.serviceNS}")
private String serviceNS;
@Value("${dockerpush.proxyUrl}")
private String proxyUrl;
@Value("${dockerpush.mountPath}")
private String mountPath;
@Value("${minio.pvcName}")
private String pvcName;
@Value("${automl.image}")
private String image;
@Value("${git.localPath}") @Value("${git.localPath}")
String localPath; String localPath;


private static final Logger logger = LoggerFactory.getLogger(ModelsServiceImpl.class);
@Value("${argo.url}")
private String argoUrl;
@Value("${argo.convertAutoML}")
String convertAutoML;
@Value("${argo.workflowRun}")
private String argoWorkflowRun;


@Resource @Resource
private AutoMlDao autoMlDao; private AutoMlDao autoMlDao;

@Resource @Resource
private K8sClientUtil k8sClientUtil;
private AutoMlInsDao autoMlInsDao;


@Override @Override
public Page<AutoMl> queryByPage(String mlName, PageRequest pageRequest) { public Page<AutoMl> queryByPage(String mlName, PageRequest pageRequest) {
@@ -146,86 +145,52 @@ public class AutoMlServiceImpl implements AutoMlService {
public String runAutoMlIns(Long id) throws Exception { public String runAutoMlIns(Long id) throws Exception {
AutoMl autoMl = autoMlDao.getAutoMlById(id); AutoMl autoMl = autoMlDao.getAutoMlById(id);
if (autoMl == null) { if (autoMl == null) {
throw new Exception("开发环境配置不存在");
}

StringBuffer command = new StringBuffer();
command.append("nohup python /opt/automl.py --task_type " + autoMl.getTaskType());
if (StringUtils.isNotEmpty(autoMl.getTargetColumns())) {
command.append(" --target_columns " + autoMl.getTargetColumns());
} else {
throw new Exception("目标列为空");
}

// String username = SecurityUtils.getLoginUser().getUsername().toLowerCase();
String username = "admin";
//构造pod名称
String podName = username + "-autoMlIns-pod-" + id;
V1Pod pod = k8sClientUtil.createPodWithEnv(podName, serviceNS, proxyUrl, mountPath, pvcName, image);

if (autoMl.getTimeLeftForThisTask() != null) {
command.append(" --time_left_for_this_task " + autoMl.getTimeLeftForThisTask());
}
if (autoMl.getPerRunTimeLimit() != null) {
command.append(" --per_run_time_limit " + autoMl.getPerRunTimeLimit());
}
if (autoMl.getEnsembleSize() != null) {
command.append(" --ensemble_size " + autoMl.getEnsembleSize());
}
if (StringUtils.isNotEmpty(autoMl.getEnsembleClass())) {
command.append(" --ensemble_class " + autoMl.getEnsembleClass());
}
if (autoMl.getEnsembleNbest() != null) {
command.append(" --ensemble_nbest " + autoMl.getEnsembleNbest());
}
if (autoMl.getMaxModelsOnDisc() != null) {
command.append(" --max_models_on_disc " + autoMl.getMaxModelsOnDisc());
}
if (autoMl.getSeed() != null) {
command.append(" --seed " + autoMl.getSeed());
}
if (autoMl.getMemoryLimit() != null) {
command.append(" --memory_limit " + autoMl.getMemoryLimit());
}
if (StringUtils.isNotEmpty(autoMl.getIncludeClassifier())) {
command.append(" --include_classifier " + autoMl.getIncludeClassifier());
}
if (StringUtils.isNotEmpty(autoMl.getIncludeRegressor())) {
command.append(" --include_regressor " + autoMl.getIncludeRegressor());
}
if (StringUtils.isNotEmpty(autoMl.getIncludeFeaturePreprocessor())) {
command.append(" --include_feature_preprocessor " + autoMl.getIncludeFeaturePreprocessor());
}
if (StringUtils.isNotEmpty(autoMl.getExcludeClassifier())) {
command.append(" --exclude_classifier " + autoMl.getExcludeClassifier());
}
if (StringUtils.isNotEmpty(autoMl.getExcludeRegressor())) {
command.append(" --exclude_regressor " + autoMl.getExcludeRegressor());
}
if (StringUtils.isNotEmpty(autoMl.getExcludeFeaturePreprocessor())) {
command.append(" --exclude_feature_preprocessor " + autoMl.getExcludeFeaturePreprocessor());
}
if (StringUtils.isNotEmpty(autoMl.getResamplingStrategy())) {
command.append(" --resampling_strategy " + autoMl.getResamplingStrategy());
}
if (autoMl.getTrainSize() != null) {
command.append(" --train_size " + autoMl.getTrainSize());
}
if (autoMl.getShuffle() != null) {
command.append(" --shuffle " + autoMl.getShuffle());
}
if (autoMl.getFolds() != null) {
command.append(" --folds " + autoMl.getFolds());
}
command.append(" &");
CompletableFuture.supplyAsync(() -> {
try {
String log = k8sClientUtil.executeCommand(pod, String.valueOf(command));
} catch (Exception e) {
logger.error(e.getMessage(), e);
throw new Exception("自动机器学习配置不存在");
}

AutoMlParamVo autoMlParam = new AutoMlParamVo();
BeanUtils.copyProperties(autoMl, autoMlParam);
autoMlParam.setDataset(JsonUtils.jsonToMap(autoMl.getDataset()));
String param = JsonUtils.objectToJson(autoMlParam);
// 调argo转换接口
try {
String convertRes = HttpUtils.sendPost(argoUrl + convertAutoML, param);
if (convertRes == null || StringUtils.isEmpty(convertRes)) {
throw new RuntimeException("转换流水线失败");
} }
return null;
});
Map<String, Object> converMap = JsonUtils.jsonToMap(convertRes);
// 组装运行接口json
Map<String, Object> output = (Map<String, Object>) converMap.get("output");
Map<String, Object> runReqMap = new HashMap<>();
runReqMap.put("data", converMap.get("data"));
// 调argo运行接口
String runRes = HttpUtils.sendPost(argoUrl + argoWorkflowRun, JsonUtils.mapToJson(runReqMap));

if (runRes == null || StringUtils.isEmpty(runRes)) {
throw new RuntimeException("Failed to run workflow.");
}
Map<String, Object> runResMap = JsonUtils.jsonToMap(runRes);
Map<String, Object> data = (Map<String, Object>) runResMap.get("data");
//判断data为空
if (data == null || MapUtils.isEmpty(data)) {
throw new RuntimeException("Failed to run workflow.");
}
Map<String, Object> metadata = (Map<String, Object>) data.get("metadata");
// 插入记录到实验实例表
AutoMlIns autoMlIns = new AutoMlIns();
autoMlIns.setAutoMlId(autoMl.getId());
autoMlIns.setArgoInsNs((String) metadata.get("namespace"));
autoMlIns.setArgoInsName((String) metadata.get("name"));
autoMlIns.setParam(param);
autoMlIns.setStatus(Constant.Pending);
//替换argoInsName
String outputString = JsonUtils.mapToJson(output);
autoMlIns.setNodeResult(outputString.replace("{{workflow.name}}", (String) metadata.get("name")));
autoMlInsDao.insert(autoMlIns);

} catch (Exception e) {
throw new RuntimeException(e);
}
return "执行成功"; return "执行成功";
} }
} }

+ 155
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/AutoMlParamVo.java View File

@@ -0,0 +1,155 @@
package com.ruoyi.platform.vo;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.databind.PropertyNamingStrategy;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;

import java.util.Map;

/**
 * Parameter object describing one AutoML run.
 *
 * <p>Per the Jackson annotations below, properties are written with snake_case names
 * and {@code null} properties are omitted. {@code AutoMlServiceImpl.runAutoMlIns} copies
 * the matching properties of an {@code AutoMl} into this object, serializes it to JSON
 * and posts it to the Argo conversion endpoint.
 *
 * <p>The {@code @ApiModelProperty} texts are runtime strings shown in the generated API
 * documentation; they are intentionally left untouched.
 */
@Data
@JsonNaming(PropertyNamingStrategy.SnakeCaseStrategy.class)
@JsonInclude(JsonInclude.Include.NON_NULL)
@ApiModel(description = "自动机器学习参数")
public class AutoMlParamVo {
// Task type: "classification" or "regression" (per the annotation text).
@ApiModelProperty(value = "任务类型:classification或regression")
private String taskType;

// Overall search time limit in seconds (annotation text states a default of 3600).
@ApiModelProperty(value = "搜索合适模型的时间限制(以秒为单位)。通过增加这个值,auto-sklearn有更高的机会找到更好的模型。默认3600,非必传。")
private Integer timeLeftForThisTask;

// Time limit in seconds for a single model fit (annotation text states a default of 600).
@ApiModelProperty(value = "单次调用机器学习模型的时间限制(以秒为单位)。如果机器学习算法运行超过时间限制,将终止模型拟合。将这个值设置得足够高,这样典型的机器学习算法就可以适用于训练数据。默认600,非必传。")
private Integer perRunTimeLimit;

// Number of models in the ensemble; 0 disables ensembling (per the annotation text).
@ApiModelProperty(value = "集成模型数量,如果设置为0,则没有集成。默认50,非必传。")
private Integer ensembleSize;

// Ensemble construction strategy: None, SingleBest or default (per the annotation text).
@ApiModelProperty(value = "设置为None将禁用集成构建,设置为SingleBest仅使用单个最佳模型而不是集成,设置为default,它将对单目标问题使用EnsembleSelection,对多目标问题使用MultiObjectiveDummyEnsemble。默认default,非必传。")
private String ensembleClass;

// Only the best N models are considered when building the ensemble.
@ApiModelProperty(value = "在构建集成时只考虑ensemble_nbest模型。这是受到了“最大限度地利用集成选择”中引入的库修剪概念的启发。这是独立于ensemble_class参数的,并且这个修剪步骤是在构造集成之前完成的。默认50,非必传。")
private Integer ensembleNbest;

// Maximum number of models kept on disk.
@ApiModelProperty(value = "定义在磁盘中保存的模型的最大数量。额外的模型数量将被永久删除。由于这个变量的性质,它设置了一个集成可以使用多少个模型的上限。必须是大于等于1的整数。如果设置为None,则所有模型都保留在磁盘上。默认50,非必传。")
private Integer maxModelsOnDisc;

// Random seed (annotation text states a default of 1).
@ApiModelProperty(value = "随机种子,将决定输出文件名。默认1,非必传。")
private Integer seed;

// Memory limit in MB for the learning algorithm (annotation text states a default of 3072).
@ApiModelProperty(value = "机器学习算法的内存限制(MB)。如果auto-sklearn试图分配超过memory_limit MB,它将停止拟合机器学习算法。默认3072,非必传。")
private Integer memoryLimit;

// Comma-separated classifiers to include in the search; the annotation text lists the allowed names.
@ApiModelProperty(value = "如果为None,则使用所有可能的分类算法。否则,指定搜索中包含的步骤和组件。有关可用组件,请参见/pipeline/components/<step>/*。与参数exclude不兼容。多选,逗号分隔。包含:adaboost\n" +
"bernoulli_nb\n" +
"decision_tree\n" +
"extra_trees\n" +
"gaussian_nb\n" +
"gradient_boosting\n" +
"k_nearest_neighbors\n" +
"lda\n" +
"liblinear_svc\n" +
"libsvm_svc\n" +
"mlp\n" +
"multinomial_nb\n" +
"passive_aggressive\n" +
"qda\n" +
"random_forest\n" +
"sgd")
private String includeClassifier;

// Comma-separated feature preprocessors to include; the annotation text lists the allowed names.
@ApiModelProperty(value = "如果为None,则使用所有可能的特征预处理算法。否则,指定搜索中包含的步骤和组件。有关可用组件,请参见/pipeline/components/<step>/*。与参数exclude不兼容。多选,逗号分隔。包含:densifier\n" +
"extra_trees_preproc_for_classification\n" +
"extra_trees_preproc_for_regression\n" +
"fast_ica\n" +
"feature_agglomeration\n" +
"kernel_pca\n" +
"kitchen_sinks\n" +
"liblinear_svc_preprocessor\n" +
"no_preprocessing\n" +
"nystroem_sampler\n" +
"pca\n" +
"polynomial\n" +
"random_trees_embedding\n" +
"select_percentile_classification\n" +
"select_percentile_regression\n" +
"select_rates_classification\n" +
"select_rates_regression\n" +
"truncatedSVD")
private String includeFeaturePreprocessor;

// Comma-separated regressors to include in the search; the annotation text lists the allowed names.
@ApiModelProperty(value = "如果为None,则使用所有可能的回归算法。否则,指定搜索中包含的步骤和组件。有关可用组件,请参见/pipeline/components/<step>/*。与参数exclude不兼容。多选,逗号分隔。包含:adaboost,\n" +
"ard_regression,\n" +
"decision_tree,\n" +
"extra_trees,\n" +
"gaussian_process,\n" +
"gradient_boosting,\n" +
"k_nearest_neighbors,\n" +
"liblinear_svr,\n" +
"libsvm_svr,\n" +
"mlp,\n" +
"random_forest,\n" +
"sgd")
private String includeRegressor;

// NOTE(review): undocumented; presumably the exclusion counterpart of includeClassifier — confirm.
private String excludeClassifier;

// NOTE(review): undocumented; presumably the exclusion counterpart of includeRegressor — confirm.
private String excludeRegressor;

// NOTE(review): undocumented; presumably the exclusion counterpart of includeFeaturePreprocessor — confirm.
private String excludeFeaturePreprocessor;

// Fraction of the data used as the test set, between 0 and 1.
@ApiModelProperty(value = "测试集的比率,0到1之间")
private Float testSize;

// Resampling strategy: "holdout" or "crossValid" (per the annotation text).
@ApiModelProperty(value = "如何处理过拟合,如果使用基于“cv”的方法或Splitter对象,可能需要使用resampling_strategy_arguments。holdout或crossValid")
private String resamplingStrategy;

// Fraction of the data used for training when resampling, between 0 and 1.
@ApiModelProperty(value = "重采样划分训练集和验证集,训练集的比率,0到1之间")
private Float trainSize;

// Whether to shuffle the data before splitting.
@ApiModelProperty(value = "拆分数据前是否进行shuffle")
private Boolean shuffle;

// Number of cross-validation folds; required when resamplingStrategy is "crossValid" (per the annotation text).
@ApiModelProperty(value = "交叉验证的折数,当resamplingStrategy为crossValid时,此项必填,为整数")
private Integer folds;

// Comma-separated names of the target columns in the dataset CSV.
@ApiModelProperty(value = "数据集csv文件中哪几列是预测目标列,逗号分隔")
private String targetColumns;

// Name of the custom metric.
@ApiModelProperty(value = "自定义指标名称")
private String metricName;

// Optimization metrics and weights as a JSON string; the annotation text lists the supported metric names.
@ApiModelProperty(value = "模型优化目标指标及权重,json格式。分类的指标包含:accuracy\n" +
"balanced_accuracy\n" +
"roc_auc\n" +
"average_precision\n" +
"log_loss\n" +
"precision_macro\n" +
"precision_micro\n" +
"precision_samples\n" +
"precision_weighted\n" +
"recall_macro\n" +
"recall_micro\n" +
"recall_samples\n" +
"recall_weighted\n" +
"f1_macro\n" +
"f1_micro\n" +
"f1_samples\n" +
"f1_weighted\n" +
"回归的指标包含:mean_absolute_error\n" +
"mean_squared_error\n" +
"root_mean_squared_error\n" +
"mean_squared_log_error\n" +
"median_absolute_error\n" +
"r2")
private String metrics;

// Optimization direction: whether a larger metric value is better.
@ApiModelProperty(value = "指标优化方向,是越大越好还是越小越好")
private Boolean greaterIsBetter;

// Metrics the model should compute and print.
@ApiModelProperty(value = "模型计算并打印指标")
private String scoringFunctions;

// Dataset descriptor as a generic map; AutoMlServiceImpl.runAutoMlIns fills it by parsing
// the JSON string stored in AutoMl.dataset. The map's keys are not visible here.
private Map<String,Object> dataset;
}

+ 3
- 3
ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/AutoMlDao.xml View File

@@ -34,9 +34,9 @@
<if test="autoMl.mlDescription != null and autoMl.mlDescription !=''"> <if test="autoMl.mlDescription != null and autoMl.mlDescription !=''">
ml_description = #{autoMl.mlDescription}, ml_description = #{autoMl.mlDescription},
</if> </if>
<!-- <if test="autoMl.runState != null and autoMl.runState !=''">-->
<!-- run_state = #{autoMl.runState},-->
<!-- </if>-->
<if test="autoMl.statusList != null and autoMl.statusList !=''">
status_list = #{autoMl.statusList},
</if>
<!-- <if test="autoMl.progress != null">--> <!-- <if test="autoMl.progress != null">-->
<!-- progress = #{autoMl.progress},--> <!-- progress = #{autoMl.progress},-->
<!-- </if>--> <!-- </if>-->


+ 22
- 2
ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/AutoMlInsDao.xml View File

@@ -3,9 +3,9 @@
<mapper namespace="com.ruoyi.platform.mapper.AutoMlInsDao"> <mapper namespace="com.ruoyi.platform.mapper.AutoMlInsDao">


<insert id="insert" keyProperty="id" useGeneratedKeys="true"> <insert id="insert" keyProperty="id" useGeneratedKeys="true">
insert into auto_ml_ins(auto_ml_id, model_path, img_path, node_status, node_result, param, source)
insert into auto_ml_ins(auto_ml_id, model_path, img_path, node_status, node_result, param, source, argo_ins_name, argo_ins_ns)
values (#{autoMlIns.autoMlId}, #{autoMlIns.modelPath}, #{autoMlIns.imgPath}, #{autoMlIns.nodeStatus}, values (#{autoMlIns.autoMlId}, #{autoMlIns.modelPath}, #{autoMlIns.imgPath}, #{autoMlIns.nodeStatus},
#{autoMlIns.nodeResult}, #{autoMlIns.param}, #{autoMlIns.source})
#{autoMlIns.nodeResult}, #{autoMlIns.param}, #{autoMlIns.source}, #{autoMlIns.argoInsName}, #{autoMlIns.argoInsNs})
</insert> </insert>


<update id="update"> <update id="update">
@@ -29,7 +29,11 @@
<if test="autoMlIns.state != null"> <if test="autoMlIns.state != null">
state = #{autoMlIns.state}, state = #{autoMlIns.state},
</if> </if>
<if test="autoMlIns.updateTime != null">
update_time = #{autoMlIns.updateTime},
</if>
</set> </set>
where id = #{autoMlIns.id}
</update> </update>


<select id="count" resultType="java.lang.Long"> <select id="count" resultType="java.lang.Long">
@@ -61,4 +65,20 @@
state = 1 and id = #{id} state = 1 and id = #{id}
</where> </where>
</select> </select>

<!-- Selects the instances with state = 1 whose status is NULL or is none of
     'Terminated', 'Succeeded', 'Failed', i.e. the instances the status task still polls. -->
<select id="queryByAutoMlInsIsNotTerminated" resultType="com.ruoyi.platform.domain.AutoMlIns">
select *
from auto_ml_ins
where (status NOT IN ('Terminated', 'Succeeded', 'Failed')
OR status IS NULL)
and state = 1
</select>

<!-- Selects at most 5 instances (state = 1) of the given experiment, most recently
     updated first. NOTE(review): the cap of 5 presumably matches what the UI shows
     in the experiment's status list; confirm before changing it. -->
<select id="getByAutoMlId" resultType="com.ruoyi.platform.domain.AutoMlIns">
select *
from auto_ml_ins
where auto_ml_id = #{autoMlId}
and state = 1
order by update_time DESC limit 5
</select>
</mapper> </mapper>

Loading…
Cancel
Save