Browse Source

Merge remote-tracking branch 'origin/dev-automl' into dev

dev-complex-computation
chenzhihang 1 year ago
parent
commit
2b43452dbc
19 changed files with 1606 additions and 8 deletions
  1. +3
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/constant/Constant.java
  2. +71
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/autoML/AutoMlController.java
  3. +61
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/autoML/AutoMlInsController.java
  4. +0
    -1
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/codeConfig/CodeConfigController.java
  5. +184
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/domain/AutoMl.java
  6. +50
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/domain/AutoMlIns.java
  7. +21
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/mapper/AutoMlDao.java
  8. +23
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/mapper/AutoMlInsDao.java
  9. +97
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/AutoMlInsStatusTask.java
  10. +29
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AutoMlInsService.java
  11. +26
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AutoMlService.java
  12. +250
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AutoMlInsServiceImpl.java
  13. +216
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AutoMlServiceImpl.java
  14. +0
    -2
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/JupyterServiceImpl.java
  15. +155
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/AutoMlParamVo.java
  16. +183
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/AutoMlVo.java
  17. +5
    -5
      ruoyi-modules/management-platform/src/main/resources/bootstrap.yml
  18. +146
    -0
      ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/AutoMlDao.xml
  19. +86
    -0
      ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/AutoMlInsDao.xml

+ 3
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/constant/Constant.java View File

@@ -29,10 +29,13 @@ public class Constant {
// Run-status strings. Running, Failed, Pending and Terminated are compared
// against / written to AutoMlIns.status by AutoMlInsServiceImpl.
public final static String Running = "Running";
public final static String Failed = "Failed";
public final static String Pending = "Pending";
public final static String Terminated = "Terminated";
public final static String Init = "Init";
public final static String Stopped = "Stopped";
public final static String Succeeded = "Succeeded";

// NOTE(review): presumably task-type discriminators (training vs. evaluation);
// no usage is visible in this change set - confirm against callers.
public final static String Type_Train = "train";
public final static String Type_Evaluate = "evaluate";

// Same spelling as the "classification" option documented on AutoMl.taskType.
public final static String AutoMl_Classification = "classification";
}

+ 71
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/autoML/AutoMlController.java View File

@@ -0,0 +1,71 @@
package com.ruoyi.platform.controller.autoML;

import com.ruoyi.common.core.web.controller.BaseController;
import com.ruoyi.common.core.web.domain.AjaxResult;
import com.ruoyi.common.core.web.domain.GenericsAjaxResult;
import com.ruoyi.platform.domain.AutoMl;
import com.ruoyi.platform.service.AutoMlService;
import com.ruoyi.platform.vo.AutoMlVo;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageRequest;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;

import javax.annotation.Resource;
import java.io.IOException;

/**
 * REST endpoints for AutoML experiment definitions, mounted under {@code autoML}.
 *
 * <p>Each handler forwards its arguments to {@link AutoMlService} and wraps the
 * returned value in the platform's standard response envelope; no business
 * logic lives in this class.
 */
@RestController
@RequestMapping("autoML")
@Api("自动机器学习")
public class AutoMlController extends BaseController {

    @Resource
    private AutoMlService autoMlService;

    /**
     * Lists experiments one page at a time.
     *
     * @param page   page index handed to {@link PageRequest#of(int, int)}
     * @param size   page size
     * @param mlName optional experiment-name filter (query parameter {@code ml_name})
     */
    @GetMapping
    @ApiOperation("分页查询")
    public GenericsAjaxResult<Page<AutoMl>> queryByPage(@RequestParam("page") int page,
                                                        @RequestParam("size") int size,
                                                        @RequestParam(value = "ml_name", required = false) String mlName) {
        final PageRequest paging = PageRequest.of(page, size);
        final Page<AutoMl> experiments = autoMlService.queryByPage(mlName, paging);
        return genericsSuccess(experiments);
    }

    /** Creates a new experiment from the posted definition and returns the stored entity. */
    @PostMapping
    @ApiOperation("新增自动机器学习")
    public GenericsAjaxResult<AutoMl> addAutoMl(@RequestBody AutoMlVo autoMlVo) throws Exception {
        final AutoMl created = autoMlService.save(autoMlVo);
        return genericsSuccess(created);
    }

    /** Updates an existing experiment and returns the service's result message. */
    @PutMapping
    @ApiOperation("编辑自动机器学习")
    public GenericsAjaxResult<String> editAutoMl(@RequestBody AutoMlVo autoMlVo) throws Exception {
        final String outcome = autoMlService.edit(autoMlVo);
        return genericsSuccess(outcome);
    }

    /** Returns the full definition of the experiment with the given id. */
    @GetMapping("/getAutoMlDetail")
    @ApiOperation("获取自动机器学习详细信息")
    public GenericsAjaxResult<AutoMlVo> getAutoMlDetail(@RequestParam("id") Long id) throws IOException {
        final AutoMlVo detail = autoMlService.getAutoMlDetail(id);
        return genericsSuccess(detail);
    }

    /** Deletes the experiment with the given id and returns the service's result message. */
    @DeleteMapping("{id}")
    @ApiOperation("删除自动机器学习")
    public GenericsAjaxResult<String> deleteAutoMl(@PathVariable("id") Long id) {
        final String outcome = autoMlService.delete(id);
        return genericsSuccess(outcome);
    }

    /**
     * Accepts a CSV data file upload.
     *
     * @param file the uploaded file part
     * @param uuid caller-chosen identifier passed through to the service
     * @return the service's description of the stored file, wrapped in an {@link AjaxResult}
     */
    @CrossOrigin(origins = "*", allowedHeaders = "*")
    @PostMapping("/upload")
    @ApiOperation(value = "上传数据文件csv", notes = "上传数据文件csv,并将信息存入数据库。")
    public AjaxResult upload(@RequestParam("file") MultipartFile file, @RequestParam("uuid") String uuid) throws Exception {
        return AjaxResult.success(autoMlService.upload(file, uuid));
    }

    /** Starts a run of the experiment with the given id and returns the service's result message. */
    @PostMapping("/run/{id}")
    @ApiOperation("运行自动机器学习实验")
    public GenericsAjaxResult<String> runAutoML(@PathVariable("id") Long id) throws Exception {
        final String outcome = autoMlService.runAutoMlIns(id);
        return genericsSuccess(outcome);
    }
}

+ 61
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/autoML/AutoMlInsController.java View File

@@ -0,0 +1,61 @@
package com.ruoyi.platform.controller.autoML;

import com.ruoyi.common.core.web.controller.BaseController;
import com.ruoyi.common.core.web.domain.GenericsAjaxResult;
import com.ruoyi.platform.domain.AutoMlIns;
import com.ruoyi.platform.service.AutoMlInsService;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageRequest;
import org.springframework.web.bind.annotation.*;

import javax.annotation.Resource;
import java.io.IOException;
import java.util.List;

/**
 * REST endpoints for AutoML experiment run instances, mounted under {@code autoMLIns}.
 *
 * <p>All handlers delegate to {@link AutoMlInsService} and wrap the outcome in
 * the platform's generic response envelope.
 */
@RestController
@RequestMapping("autoMLIns")
@Api("自动机器学习实验实例")
public class AutoMlInsController extends BaseController {

    @Resource
    private AutoMlInsService autoMLInsService;

    /**
     * Lists run instances one page at a time.
     *
     * @param autoMlIns filter object bound from the request parameters
     * @param page      page index handed to {@link PageRequest#of(int, int)}
     * @param size      page size
     */
    @GetMapping
    @ApiOperation("分页查询")
    public GenericsAjaxResult<Page<AutoMlIns>> queryByPage(AutoMlIns autoMlIns, int page, int size) throws IOException {
        final PageRequest paging = PageRequest.of(page, size);
        final Page<AutoMlIns> instances = autoMLInsService.queryByPage(autoMlIns, paging);
        return genericsSuccess(instances);
    }

    /** Stores a new run instance and echoes it back. */
    @PostMapping
    @ApiOperation("新增实验实例")
    public GenericsAjaxResult<AutoMlIns> add(@RequestBody AutoMlIns autoMlIns) {
        final AutoMlIns stored = autoMLInsService.insert(autoMlIns);
        return genericsSuccess(stored);
    }

    /** Removes one run instance; the payload is the service's result message. */
    @DeleteMapping("{id}")
    @ApiOperation("删除实验实例")
    public GenericsAjaxResult<String> deleteById(@PathVariable("id") Long id) {
        final String outcome = autoMLInsService.removeById(id);
        return genericsSuccess(outcome);
    }

    /** Removes several run instances; the payload is the service's result message. */
    @DeleteMapping("batchDelete")
    @ApiOperation("批量删除实验实例")
    public GenericsAjaxResult<String> batchDelete(@RequestBody List<Long> ids) {
        final String outcome = autoMLInsService.batchDelete(ids);
        return genericsSuccess(outcome);
    }

    /** Asks the service to terminate a run instance; the payload tells whether it was terminated. */
    @PutMapping("{id}")
    @ApiOperation("终止实验实例")
    public GenericsAjaxResult<Boolean> terminateAutoMlIns(@PathVariable("id") Long id) {
        final boolean terminated = autoMLInsService.terminateAutoMlIns(id);
        return genericsSuccess(terminated);
    }

    /** Returns the stored record of one run instance. */
    @GetMapping("{id}")
    @ApiOperation("查看实验实例详情")
    public GenericsAjaxResult<AutoMlIns> getDetailById(@PathVariable("id") Long id) {
        final AutoMlIns detail = autoMLInsService.getDetailById(id);
        return genericsSuccess(detail);
    }
}

+ 0
- 1
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/codeConfig/CodeConfigController.java View File

@@ -7,7 +7,6 @@ import com.ruoyi.platform.service.CodeConfigService;
import io.swagger.annotations.Api;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageRequest;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;

import javax.annotation.Resource;


+ 184
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/domain/AutoMl.java View File

@@ -0,0 +1,184 @@
package com.ruoyi.platform.domain;

import com.baomidou.mybatisplus.annotation.TableField;
import com.fasterxml.jackson.databind.PropertyNamingStrategy;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import com.ruoyi.platform.vo.VersionVo;
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;

import java.util.Date;
import java.util.Map;

/**
 * Persistent definition of one AutoML experiment.
 *
 * <p>Accessors are generated by Lombok ({@code @Data}); JSON property names are
 * the snake_case form of the field names ({@code @JsonNaming}). Most fields are
 * tuning options whose meaning is given in the accompanying
 * {@code @ApiModelProperty} text.
 */
@Data
@JsonNaming(PropertyNamingStrategy.SnakeCaseStrategy.class)
@ApiModel(description = "自动机器学习")
public class AutoMl {
// Primary key of the experiment row.
private Long id;

@ApiModelProperty(value = "实验名称")
private String mlName;

@ApiModelProperty(value = "实验描述")
private String mlDescription;

@ApiModelProperty(value = "任务类型:classification或regression")
private String taskType;

@ApiModelProperty(value = "搜索合适模型的时间限制(以秒为单位)。通过增加这个值,auto-sklearn有更高的机会找到更好的模型。默认3600,非必传。")
private Integer timeLeftForThisTask;

@ApiModelProperty(value = "单次调用机器学习模型的时间限制(以秒为单位)。如果机器学习算法运行超过时间限制,将终止模型拟合。将这个值设置得足够高,这样典型的机器学习算法就可以适用于训练数据。默认600,非必传。")
private Integer perRunTimeLimit;

@ApiModelProperty(value = "集成模型数量,如果设置为0,则没有集成。默认50,非必传。")
private Integer ensembleSize;

@ApiModelProperty(value = "设置为None将禁用集成构建,设置为SingleBest仅使用单个最佳模型而不是集成,设置为default,它将对单目标问题使用EnsembleSelection,对多目标问题使用MultiObjectiveDummyEnsemble。默认default,非必传。")
private String ensembleClass;

@ApiModelProperty(value = "在构建集成时只考虑ensemble_nbest模型。这是受到了“最大限度地利用集成选择”中引入的库修剪概念的启发。这是独立于ensemble_class参数的,并且这个修剪步骤是在构造集成之前完成的。默认50,非必传。")
private Integer ensembleNbest;

@ApiModelProperty(value = "定义在磁盘中保存的模型的最大数量。额外的模型数量将被永久删除。由于这个变量的性质,它设置了一个集成可以使用多少个模型的上限。必须是大于等于1的整数。如果设置为None,则所有模型都保留在磁盘上。默认50,非必传。")
private Integer maxModelsOnDisc;

@ApiModelProperty(value = "随机种子,将决定输出文件名。默认1,非必传。")
private Integer seed;

@ApiModelProperty(value = "机器学习算法的内存限制(MB)。如果auto-sklearn试图分配超过memory_limit MB,它将停止拟合机器学习算法。默认3072,非必传。")
private Integer memoryLimit;

// Comma-separated allow-list of classifier components (see the property text for the accepted names).
@ApiModelProperty(value = "如果为None,则使用所有可能的分类算法。否则,指定搜索中包含的步骤和组件。有关可用组件,请参见/pipeline/components/<step>/*。与参数exclude不兼容。多选,逗号分隔。包含:adaboost\n" +
"bernoulli_nb\n" +
"decision_tree\n" +
"extra_trees\n" +
"gaussian_nb\n" +
"gradient_boosting\n" +
"k_nearest_neighbors\n" +
"lda\n" +
"liblinear_svc\n" +
"libsvm_svc\n" +
"mlp\n" +
"multinomial_nb\n" +
"passive_aggressive\n" +
"qda\n" +
"random_forest\n" +
"sgd")
private String includeClassifier;

// Comma-separated allow-list of feature-preprocessor components.
@ApiModelProperty(value = "如果为None,则使用所有可能的特征预处理算法。否则,指定搜索中包含的步骤和组件。有关可用组件,请参见/pipeline/components/<step>/*。与参数exclude不兼容。多选,逗号分隔。包含:densifier\n" +
"extra_trees_preproc_for_classification\n" +
"extra_trees_preproc_for_regression\n" +
"fast_ica\n" +
"feature_agglomeration\n" +
"kernel_pca\n" +
"kitchen_sinks\n" +
"liblinear_svc_preprocessor\n" +
"no_preprocessing\n" +
"nystroem_sampler\n" +
"pca\n" +
"polynomial\n" +
"random_trees_embedding\n" +
"select_percentile_classification\n" +
"select_percentile_regression\n" +
"select_rates_classification\n" +
"select_rates_regression\n" +
"truncatedSVD")
private String includeFeaturePreprocessor;

// Comma-separated allow-list of regressor components.
@ApiModelProperty(value = "如果为None,则使用所有可能的回归算法。否则,指定搜索中包含的步骤和组件。有关可用组件,请参见/pipeline/components/<step>/*。与参数exclude不兼容。多选,逗号分隔。包含:adaboost,\n" +
"ard_regression,\n" +
"decision_tree,\n" +
"extra_trees,\n" +
"gaussian_process,\n" +
"gradient_boosting,\n" +
"k_nearest_neighbors,\n" +
"liblinear_svr,\n" +
"libsvm_svr,\n" +
"mlp,\n" +
"random_forest,\n" +
"sgd")
private String includeRegressor;

// NOTE(review): the three exclude* fields carry no API description; presumably the
// deny-list counterparts of the include* fields above - confirm with the pipeline side.
private String excludeClassifier;

private String excludeRegressor;

private String excludeFeaturePreprocessor;

@ApiModelProperty(value = "测试集的比率,0到1之间")
private Float testSize;

@ApiModelProperty(value = "如何处理过拟合,如果使用基于“cv”的方法或Splitter对象,可能需要使用resampling_strategy_arguments。holdout或crossValid")
private String resamplingStrategy;

@ApiModelProperty(value = "重采样划分训练集和验证集,训练集的比率,0到1之间")
private Float trainSize;

@ApiModelProperty(value = "拆分数据前是否进行shuffle")
private Boolean shuffle;

@ApiModelProperty(value = "交叉验证的折数,当resamplingStrategy为crossValid时,此项必填,为整数")
private Integer folds;

@ApiModelProperty(value = "文件夹存放配置输出和日志文件,默认/tmp/automl")
private String tmpFolder;

@ApiModelProperty(value = "数据集csv文件中哪几列是预测目标列,逗号分隔")
private String targetColumns;

@ApiModelProperty(value = "自定义指标名称")
private String metricName;

@ApiModelProperty(value = "模型优化目标指标及权重,json格式。分类的指标包含:accuracy\n" +
"balanced_accuracy\n" +
"roc_auc\n" +
"average_precision\n" +
"log_loss\n" +
"precision_macro\n" +
"precision_micro\n" +
"precision_samples\n" +
"precision_weighted\n" +
"recall_macro\n" +
"recall_micro\n" +
"recall_samples\n" +
"recall_weighted\n" +
"f1_macro\n" +
"f1_micro\n" +
"f1_samples\n" +
"f1_weighted\n" +
"回归的指标包含:mean_absolute_error\n" +
"mean_squared_error\n" +
"root_mean_squared_error\n" +
"mean_squared_log_error\n" +
"median_absolute_error\n" +
"r2")
private String metrics;

@ApiModelProperty(value = "指标优化方向,是越大越好还是越小越好")
private Boolean greaterIsBetter;

@ApiModelProperty(value = "模型计算并打印指标")
private String scoringFunctions;

// Record validity flag; AutoMlServiceImpl.delete sets it to Constant.State_invalid instead of removing the row.
private Integer state;

// NOTE(review): runState and progress are not read or written in the code visible here - meaning to be confirmed.
private String runState;

private Double progress;

// Audit columns; createBy/updateBy are filled with the login username by AutoMlServiceImpl.
private String createBy;

private Date createTime;

private String updateBy;

private Date updateTime;

// Dataset descriptor stored as a JSON string (serialised from AutoMlVo.dataset on save/edit).
private String dataset;

// Comma-and-space separated statuses of this experiment's instances, maintained by the status-sync code.
@ApiModelProperty(value = "状态列表")
private String statusList;
}

+ 50
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/domain/AutoMlIns.java View File

@@ -0,0 +1,50 @@
package com.ruoyi.platform.domain;

import com.fasterxml.jackson.databind.PropertyNamingStrategy;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;

import java.util.Date;

/**
 * One run instance of an AutoML experiment.
 *
 * <p>Accessors are generated by Lombok ({@code @Data}); JSON property names are
 * the snake_case form of the field names ({@code @JsonNaming}).
 */
@Data
@JsonNaming(PropertyNamingStrategy.SnakeCaseStrategy.class)
@ApiModel(description = "自动机器学习实验实例")
public class AutoMlIns {
// Primary key of the instance row.
private Long id;

// Id of the owning AutoMl experiment.
private Long autoMlId;

// NOTE(review): the four *Path fields are not touched by the code visible here;
// presumably output locations produced by the run - confirm.
private String resultPath;

private String modelPath;

private String imgPath;

private String runHistoryPath;

// Record validity flag; AutoMlInsServiceImpl.removeById sets it to Constant.State_invalid.
private Integer state;

// Run status: the Argo workflow phase as synced by queryStatusFromArgo
// ("Error" is stored as Failed; Pending when Argo reports no phase), or Terminated after a manual stop.
private String status;

// JSON string of the Argo workflow nodes, keyed by each node's displayName.
private String nodeStatus;

private String nodeResult;

private String param;

private String source;

@ApiModelProperty(value = "Argo实例名称")
private String argoInsName;

@ApiModelProperty(value = "Argo命名空间")
private String argoInsNs;

private Date createTime;

private Date updateTime;

// Workflow end time, converted from Argo's "finishedAt" value when present.
private Date finishTime;
}

+ 21
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/mapper/AutoMlDao.java View File

@@ -0,0 +1,21 @@
package com.ruoyi.platform.mapper;

import com.ruoyi.platform.domain.AutoMl;
import org.apache.ibatis.annotations.Param;
import org.springframework.data.domain.Pageable;

import java.util.List;

/**
 * MyBatis mapper for {@link AutoMl} experiment rows.
 *
 * <p>The SQL lives in a mapper XML that is not part of this file, so the notes
 * below describe only how the visible callers use each method.
 */
public interface AutoMlDao {
// Number of rows matching the optional name filter; used as the page total.
long count(@Param("mlName") String mlName);

// One page of rows for the same filter as count().
List<AutoMl> queryByPage(@Param("mlName") String mlName, @Param("pageable") Pageable pageable);

// Looks up one experiment by id; callers treat a null result as "not found".
AutoMl getAutoMlById(@Param("id") Long id);

// Looks up one experiment by name; used for the duplicate-name check on save/edit.
AutoMl getAutoMlByName(@Param("mlName") String mlName);

// Inserts a new experiment; returns an int (presumably the affected row count - confirm in the XML).
int save(@Param("autoMl") AutoMl autoMl);

// Updates an experiment; callers treat a result > 0 as success. Also used for soft delete and status-list updates.
int edit(@Param("autoMl") AutoMl autoMl);
}

+ 23
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/mapper/AutoMlInsDao.java View File

@@ -0,0 +1,23 @@
package com.ruoyi.platform.mapper;

import com.ruoyi.platform.domain.AutoMlIns;
import org.apache.ibatis.annotations.Param;
import org.springframework.data.domain.Pageable;

import java.util.List;

/**
 * MyBatis mapper for {@link AutoMlIns} run-instance rows.
 *
 * <p>The SQL lives in a mapper XML that is not part of this file, so the notes
 * below describe only how the visible callers use each method.
 */
public interface AutoMlInsDao {
    /** Number of rows matching the filter object; used as the page total. */
    long count(@Param("autoMlIns") AutoMlIns autoMlIns);

    /** One page of rows for the same filter as {@link #count(AutoMlIns)}. */
    List<AutoMlIns> queryAllByLimit(@Param("autoMlIns") AutoMlIns autoMlIns, @Param("pageable") Pageable pageable);

    /**
     * All instances that belong to one experiment.
     *
     * @param autoMlId id of the owning experiment (bound in SQL as {@code autoMlId})
     */
    List<AutoMlIns> getByAutoMlId(@Param("autoMlId") Long autoMlId);

    /** Inserts a new instance row. */
    int insert(@Param("autoMlIns") AutoMlIns autoMlIns);

    /** Updates an instance row; callers treat a result &gt; 0 as success. */
    int update(@Param("autoMlIns") AutoMlIns autoMlIns);

    /** Looks up one instance by id; callers treat a null result as "not found". */
    AutoMlIns queryById(@Param("id") Long id);

    /** Instances that have not reached a terminal state yet; polled by the status-sync task. */
    List<AutoMlIns> queryByAutoMlInsIsNotTerminated();
}

+ 97
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/AutoMlInsStatusTask.java View File

@@ -0,0 +1,97 @@
package com.ruoyi.platform.scheduling;

import com.ruoyi.platform.domain.AutoMl;
import com.ruoyi.platform.domain.AutoMlIns;
import com.ruoyi.platform.mapper.AutoMlDao;
import com.ruoyi.platform.mapper.AutoMlInsDao;
import com.ruoyi.platform.service.AutoMlInsService;
import org.apache.commons.lang3.StringUtils;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;

import javax.annotation.Resource;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/**
 * Scheduled jobs that keep AutoML run statuses in step with Argo.
 *
 * <p>{@link #executeAutoMlInsStatus()} polls Argo for every non-terminated
 * instance and persists the result; {@link #executeAutoMlStatus()} then
 * refreshes the aggregated status list of the experiments those instances
 * belong to. The two jobs communicate through {@link #autoMlIds}.
 *
 * <p>Thread-safety: every access to {@link #autoMlIds} happens while holding
 * its monitor, so the jobs may run on different scheduler threads.
 */
@Component()
public class AutoMlInsStatusTask {

    @Resource
    private AutoMlInsService autoMlInsService;

    @Resource
    private AutoMlInsDao autoMlInsDao;

    @Resource
    private AutoMlDao autoMlDao;

    /**
     * Ids of experiments whose status list still has to be refreshed.
     * Holds each id at most once and is guarded by its own monitor.
     */
    private final List<Long> autoMlIds = new ArrayList<>();

    /**
     * Polls Argo for all non-terminated instances and stores the refreshed rows.
     * An instance whose status query fails is marked {@code "Failed"}.
     */
    @Scheduled(cron = "0/30 * * * * ?") // runs every 30 seconds
    public void executeAutoMlInsStatus() throws Exception {
        List<AutoMlIns> autoMlInsList = autoMlInsService.queryByAutoMlInsIsNotTerminated();
        if (autoMlInsList == null || autoMlInsList.isEmpty()) {
            return;
        }

        List<AutoMlIns> updateList = new ArrayList<>(autoMlInsList.size());
        for (AutoMlIns autoMlIns : autoMlInsList) {
            AutoMlIns refreshed = autoMlIns;
            try {
                refreshed = autoMlInsService.queryStatusFromArgo(autoMlIns);
            } catch (Exception e) {
                // Argo could not be queried for this instance: record it as failed and carry on.
                autoMlIns.setStatus("Failed");
            }
            markForRefresh(refreshed.getAutoMlId());
            updateList.add(refreshed);
        }

        // Persist only after every instance has been queried, as before.
        for (AutoMlIns autoMlIns : updateList) {
            autoMlInsDao.update(autoMlIns);
        }
    }

    /**
     * Recomputes the status list of every queued experiment.
     *
     * <p>An id leaves the queue once its experiment row was actually updated, or
     * when the experiment can no longer be found. Ids whose status list did not
     * change stay queued and are checked again on the next run, so an instance
     * update that lands after this job has read the database is still picked up.
     */
    @Scheduled(cron = "0/30 * * * * ?") // runs every 30 seconds
    public void executeAutoMlStatus() throws Exception {
        // Work on a snapshot so the other job can keep queueing ids meanwhile.
        List<Long> pendingIds;
        synchronized (autoMlIds) {
            if (autoMlIds.isEmpty()) {
                return;
            }
            pendingIds = new ArrayList<>(autoMlIds);
        }

        List<Long> finishedIds = new ArrayList<>();
        for (Long autoMlId : pendingIds) {
            AutoMl autoMl = autoMlDao.getAutoMlById(autoMlId);
            if (autoMl == null) {
                // The experiment is gone; without this the id would fail on every run.
                finishedIds.add(autoMlId);
                continue;
            }

            String subStatus = joinStatuses(autoMlInsDao.getByAutoMlId(autoMlId));
            if (!StringUtils.equals(autoMl.getStatusList(), subStatus)) {
                autoMl.setStatusList(subStatus);
                autoMlDao.edit(autoMl);
                finishedIds.add(autoMlId);
            }
        }

        if (!finishedIds.isEmpty()) {
            synchronized (autoMlIds) {
                autoMlIds.removeAll(finishedIds);
            }
        }
    }

    /** Queues an experiment id for a status-list refresh unless it is null or already queued. */
    private void markForRefresh(Long autoMlId) {
        if (autoMlId == null) {
            return;
        }
        synchronized (autoMlIds) {
            if (!autoMlIds.contains(autoMlId)) {
                autoMlIds.add(autoMlId);
            }
        }
    }

    /** Renders the instance statuses as {@code "a, b, c"}: the list's {@code toString()} without its brackets. */
    private static String joinStatuses(List<AutoMlIns> insList) {
        List<String> statusList = new ArrayList<>();
        if (insList != null) {
            for (AutoMlIns ins : insList) {
                statusList.add(ins.getStatus());
            }
        }
        String rendered = statusList.toString();
        return rendered.substring(1, rendered.length() - 1);
    }
}

+ 29
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AutoMlInsService.java View File

@@ -0,0 +1,29 @@
package com.ruoyi.platform.service;

import com.ruoyi.platform.domain.AutoMlIns;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageRequest;

import java.io.IOException;
import java.util.List;

/**
 * Operations on AutoML experiment run instances ({@link AutoMlIns}).
 */
public interface AutoMlInsService {

// Returns one page of instances matching the filter object.
Page<AutoMlIns> queryByPage(AutoMlIns autoMlIns, PageRequest pageRequest) throws IOException;

// Stores a new instance and returns it.
AutoMlIns insert(AutoMlIns autoMlIns);

// Removes one instance; the returned String is a human-readable result message.
String removeById(Long id);

// Removes several instances; the returned String is a human-readable result message.
String batchDelete(List<Long> ids);

// Lists the instances that have not reached a terminal state yet.
List<AutoMlIns> queryByAutoMlInsIsNotTerminated();

// Refreshes the given instance's status fields from Argo and returns it.
AutoMlIns queryStatusFromArgo(AutoMlIns autoMlIns);

// Asks Argo to stop the instance's workflow; the boolean reports whether it was terminated.
boolean terminateAutoMlIns(Long id);

// Returns the stored record of one instance.
AutoMlIns getDetailById(Long id);

// Recomputes the status list of the given experiment from its instances.
void updateAutoMlStatus(Long autoMlId);
}

+ 26
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AutoMlService.java View File

@@ -0,0 +1,26 @@
package com.ruoyi.platform.service;

import com.ruoyi.platform.domain.AutoMl;
import com.ruoyi.platform.vo.AutoMlVo;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageRequest;
import org.springframework.web.multipart.MultipartFile;

import java.io.IOException;
import java.util.Map;

/**
 * Operations on AutoML experiment definitions ({@link AutoMl}).
 */
public interface AutoMlService {
// Returns one page of experiments, optionally filtered by name.
Page<AutoMl> queryByPage(String mlName, PageRequest pageRequest);

// Creates an experiment from the given definition and returns the stored entity.
AutoMl save(AutoMlVo autoMlVo) throws Exception;

// Updates an experiment; the returned String is a human-readable result message.
String edit(AutoMlVo autoMlVo) throws Exception;

// Deletes an experiment; the returned String is a human-readable result message.
String delete(Long id);

// Returns the full definition of one experiment.
AutoMlVo getAutoMlDetail(Long id) throws IOException;

// Stores an uploaded data file; the returned map describes the stored file.
Map<String, String> upload(MultipartFile file, String uuid) throws Exception;

// Starts a run of the experiment with the given id.
String runAutoMlIns(Long id) throws Exception;
}

+ 250
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AutoMlInsServiceImpl.java View File

@@ -0,0 +1,250 @@
package com.ruoyi.platform.service.impl;

import com.ruoyi.platform.constant.Constant;
import com.ruoyi.platform.domain.AutoMl;
import com.ruoyi.platform.domain.AutoMlIns;
import com.ruoyi.platform.mapper.AutoMlDao;
import com.ruoyi.platform.mapper.AutoMlInsDao;
import com.ruoyi.platform.service.AutoMlInsService;
import com.ruoyi.platform.utils.DateUtils;
import com.ruoyi.platform.utils.HttpUtils;
import com.ruoyi.platform.utils.JsonUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageImpl;
import org.springframework.data.domain.PageRequest;
import org.springframework.stereotype.Service;

import javax.annotation.Resource;
import java.io.IOException;
import java.util.*;

/**
 * Default {@link AutoMlInsService}: persists run instances through
 * {@link AutoMlInsDao} and talks to the Argo gateway over HTTP to read workflow
 * status and to terminate workflows.
 */
@Service
public class AutoMlInsServiceImpl implements AutoMlInsService {
    @Value("${argo.url}")
    private String argoUrl;
    @Value("${argo.workflowStatus}")
    private String argoWorkflowStatus;
    @Value("${argo.workflowTermination}")
    private String argoWorkflowTermination;

    @Resource
    private AutoMlInsDao autoMlInsDao;
    @Resource
    private AutoMlDao autoMlDao;

    /** Returns one page of instances matching the filter, with the total taken from a count query. */
    @Override
    public Page<AutoMlIns> queryByPage(AutoMlIns autoMlIns, PageRequest pageRequest) throws IOException {
        long total = this.autoMlInsDao.count(autoMlIns);
        List<AutoMlIns> autoMlInsList = this.autoMlInsDao.queryAllByLimit(autoMlIns, pageRequest);
        return new PageImpl<>(autoMlInsList, pageRequest, total);
    }

    /** Inserts the instance and returns the same object. */
    @Override
    public AutoMlIns insert(AutoMlIns autoMlIns) {
        this.autoMlInsDao.insert(autoMlIns);
        return autoMlIns;
    }

    /**
     * Soft-deletes one instance by marking its state invalid.
     *
     * <p>A running instance is never deleted. When the stored status is empty it
     * is first refreshed from Argo.
     *
     * @return a result message describing the outcome
     */
    @Override
    public String removeById(Long id) {
        AutoMlIns autoMlIns = autoMlInsDao.queryById(id);
        if (autoMlIns == null) {
            return "实验实例不存在";
        }

        if (StringUtils.isEmpty(autoMlIns.getStatus())) {
            autoMlIns = queryStatusFromArgo(autoMlIns);
        }
        if (StringUtils.equals(autoMlIns.getStatus(), Constant.Running)) {
            return "实验实例正在运行,不可删除";
        }

        autoMlIns.setState(Constant.State_invalid);
        if (autoMlInsDao.update(autoMlIns) > 0) {
            updateAutoMlStatus(autoMlIns.getAutoMlId());
            return "删除成功";
        }
        return "删除失败";
    }

    /**
     * Deletes the given instances one by one and stops at the first one that
     * cannot be deleted, returning that instance's message.
     */
    @Override
    public String batchDelete(List<Long> ids) {
        for (Long id : ids) {
            String result = removeById(id);
            if (!"删除成功".equals(result)) {
                return result;
            }
        }
        return "删除成功";
    }

    /** Lists the instances that have not reached a terminal state yet. */
    @Override
    public List<AutoMlIns> queryByAutoMlInsIsNotTerminated() {
        return autoMlInsDao.queryByAutoMlInsIsNotTerminated();
    }

    /**
     * Reads the workflow status from Argo and copies it onto {@code ins}:
     * finish time (when Argo reports one), per-node status JSON keyed by display
     * name, and the overall status.
     *
     * <p>A status of {@code Terminated} is kept as is; Argo's {@code "Error"}
     * phase is stored as {@code Failed}; a missing phase is stored as
     * {@code Pending}.
     *
     * @return the same {@code ins} object, updated in place (nothing is persisted here)
     * @throws RuntimeException if the request fails or the response lacks the expected parts
     */
    @Override
    @SuppressWarnings("unchecked") // the Argo response is parsed into untyped nested maps
    public AutoMlIns queryStatusFromArgo(AutoMlIns ins) {
        Map<String, Object> request = buildWorkflowRequest(ins);

        try {
            String response = HttpUtils.sendPost(argoUrl + argoWorkflowStatus, null, JsonUtils.mapToJson(request));
            if (StringUtils.isEmpty(response)) {
                throw new RuntimeException("工作流状态响应为空。");
            }
            Map<String, Object> responseMap = JsonUtils.jsonToMap(response);
            Map<String, Object> data = (Map<String, Object>) responseMap.get("data");
            if (data == null || data.isEmpty()) {
                throw new RuntimeException("工作流数据为空.");
            }
            Map<String, Object> status = (Map<String, Object>) data.get("status");
            if (status == null || status.isEmpty()) {
                throw new RuntimeException("工作流状态为空。");
            }

            // Workflow end time, if the workflow has finished.
            String finishedAtString = (String) status.get("finishedAt");
            if (StringUtils.isNotEmpty(finishedAtString)) {
                ins.setFinishTime(DateUtils.convertUTCtoShanghaiDate(finishedAtString));
            }

            // Re-key the nodes by display name, keeping Argo's order.
            Map<String, Object> nodes = (Map<String, Object>) status.get("nodes");
            Map<String, Object> modifiedNodes = new LinkedHashMap<>();
            if (nodes != null) {
                for (Map.Entry<String, Object> nodeEntry : nodes.entrySet()) {
                    Map<String, Object> nodeDetails = (Map<String, Object>) nodeEntry.getValue();
                    String templateName = (String) nodeDetails.get("displayName");
                    modifiedNodes.put(templateName, nodeDetails);
                }
            }
            ins.setNodeStatus(JsonUtils.mapToJson(modifiedNodes));

            // A manually terminated instance keeps its status.
            if (!StringUtils.equals(ins.getStatus(), Constant.Terminated)) {
                ins.setStatus(StringUtils.isNotEmpty((String) status.get("phase")) ? (String) status.get("phase") : Constant.Pending);
            }
            if (StringUtils.equals(ins.getStatus(), "Error")) {
                ins.setStatus(Constant.Failed);
            }
            return ins;
        } catch (Exception e) {
            throw new RuntimeException("查询状态失败: " + e.getMessage(), e);
        }
    }

    /**
     * Terminates the Argo workflow of a running instance.
     *
     * <p>On success every node that has not succeeded is marked failed, the
     * instance is stored as {@code Terminated}, and the owning experiment's
     * status list is refreshed.
     *
     * @return {@code true} if Argo confirmed the termination; {@code false} if the
     *         instance is not running or Argo did not report {@code errCode} 0
     * @throws IllegalStateException if no instance with this id exists
     * @throws RuntimeException      if the Argo call or the follow-up update fails
     */
    @Override
    @SuppressWarnings("unchecked") // node status JSON is parsed into untyped nested maps
    public boolean terminateAutoMlIns(Long id) {
        AutoMlIns autoMlIns = autoMlInsDao.queryById(id);
        if (autoMlIns == null) {
            throw new IllegalStateException("实验实例未查询到,id: " + id);
        }

        // Fall back to Argo when no status has been stored yet.
        String currentStatus = autoMlIns.getStatus();
        if (StringUtils.isEmpty(currentStatus)) {
            currentStatus = queryStatusFromArgo(autoMlIns).getStatus();
        }
        // Only a running instance can be terminated.
        if (!StringUtils.equalsIgnoreCase(currentStatus, Constant.Running)) {
            return false;
        }

        Map<String, Object> request = buildWorkflowRequest(autoMlIns);

        try {
            String response = HttpUtils.sendPost(argoUrl + argoWorkflowTermination, null, JsonUtils.mapToJson(request));
            if (StringUtils.isEmpty(response)) {
                throw new RuntimeException("终止响应内容为空。");
            }
            Map<String, Object> responseMap = JsonUtils.jsonToMap(response);
            // Accept any numeric zero so the check does not depend on the parser's boxed type.
            Object errCode = responseMap.get("errCode");
            if (!(errCode instanceof Number) || ((Number) errCode).intValue() != 0) {
                return false;
            }

            // Pull the latest node details, then mark everything unfinished as failed.
            AutoMlIns ins = queryStatusFromArgo(autoMlIns);
            Map<String, Object> nodeMap = JsonUtils.jsonToMap(ins.getNodeStatus());
            for (Map.Entry<String, Object> entry : nodeMap.entrySet()) {
                Map<String, Object> innerMap = (Map<String, Object>) entry.getValue();
                if (innerMap.containsKey("phase")) {
                    String phaseValue = (String) innerMap.get("phase");
                    if (!StringUtils.equals(Constant.Succeeded, phaseValue)) {
                        innerMap.put("phase", Constant.Failed);
                    }
                }
            }
            ins.setNodeStatus(JsonUtils.mapToJson(nodeMap));
            ins.setStatus(Constant.Terminated);
            ins.setUpdateTime(new Date());
            this.autoMlInsDao.update(ins);
            updateAutoMlStatus(autoMlIns.getAutoMlId());
            return true;
        } catch (Exception e) {
            throw new RuntimeException("终止实例错误: " + e.getMessage(), e);
        }
    }

    /** Returns the stored record of one instance, or whatever the mapper yields when it is absent. */
    @Override
    public AutoMlIns getDetailById(Long id) {
        return this.autoMlInsDao.queryById(id);
    }

    /**
     * Rebuilds the owning experiment's status list ({@code "a, b, c"}) from its
     * instances and stores it when it changed.
     *
     * <p>Does nothing when the experiment cannot be found, so that an instance
     * change that has already been persisted is not turned into an error.
     */
    @Override
    public void updateAutoMlStatus(Long autoMlId) {
        AutoMl autoMl = autoMlDao.getAutoMlById(autoMlId);
        if (autoMl == null) {
            return;
        }

        List<AutoMlIns> insList = autoMlInsDao.getByAutoMlId(autoMlId);
        List<String> statusList = new ArrayList<>();
        for (AutoMlIns ins : insList) {
            statusList.add(ins.getStatus());
        }
        // List.toString() minus the surrounding brackets.
        String rendered = statusList.toString();
        String subStatus = rendered.substring(1, rendered.length() - 1);

        if (!StringUtils.equals(autoMl.getStatusList(), subStatus)) {
            autoMl.setStatusList(subStatus);
            autoMlDao.edit(autoMl);
        }
    }

    /** Builds the {@code {"data": {"namespace": ..., "name": ...}}} payload that the Argo endpoints expect. */
    private static Map<String, Object> buildWorkflowRequest(AutoMlIns ins) {
        Map<String, Object> requestData = new HashMap<>();
        requestData.put("namespace", ins.getArgoInsNs());
        requestData.put("name", ins.getArgoInsName());

        Map<String, Object> envelope = new HashMap<>();
        envelope.put("data", requestData);
        return envelope;
    }
}

+ 216
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AutoMlServiceImpl.java View File

@@ -0,0 +1,216 @@
package com.ruoyi.platform.service.impl;

import com.ruoyi.common.security.utils.SecurityUtils;
import com.ruoyi.platform.constant.Constant;
import com.ruoyi.platform.domain.AutoMl;
import com.ruoyi.platform.domain.AutoMlIns;
import com.ruoyi.platform.mapper.AutoMlDao;
import com.ruoyi.platform.mapper.AutoMlInsDao;
import com.ruoyi.platform.service.AutoMlInsService;
import com.ruoyi.platform.service.AutoMlService;
import com.ruoyi.platform.utils.FileUtil;
import com.ruoyi.platform.utils.HttpUtils;
import com.ruoyi.platform.utils.JacksonUtil;
import com.ruoyi.platform.utils.JsonUtils;
import com.ruoyi.platform.vo.AutoMlParamVo;
import com.ruoyi.platform.vo.AutoMlVo;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageImpl;
import org.springframework.data.domain.PageRequest;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;

import javax.annotation.Resource;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

@Service("autoMLService")
public class AutoMlServiceImpl implements AutoMlService {
@Value("${git.localPath}")
String localPath;

@Value("${argo.url}")
private String argoUrl;
@Value("${argo.convertAutoML}")
String convertAutoML;
@Value("${argo.workflowRun}")
private String argoWorkflowRun;

@Value("${minio.endpoint}")
private String minioEndpoint;

@Resource
private AutoMlDao autoMlDao;

@Resource
private AutoMlInsDao autoMlInsDao;

@Resource
private AutoMlInsService autoMlInsService;

/**
 * Returns one page of experiments, optionally filtered by name.
 * The page total comes from a separate count query issued first.
 */
@Override
public Page<AutoMl> queryByPage(String mlName, PageRequest pageRequest) {
    final long matching = autoMlDao.count(mlName);
    final List<AutoMl> rows = autoMlDao.queryByPage(mlName, pageRequest);
    return new PageImpl<>(rows, pageRequest, matching);
}

/**
 * Creates a new experiment from the given definition.
 *
 * <p>The current login user is recorded as both creator and updater, and the
 * dataset descriptor is stored as a JSON string.
 *
 * @return the entity that was handed to the mapper
 * @throws RuntimeException if an experiment with the same name already exists
 */
@Override
public AutoMl save(AutoMlVo autoMlVo) throws Exception {
    if (autoMlDao.getAutoMlByName(autoMlVo.getMlName()) != null) {
        throw new RuntimeException("实验名称已存在");
    }

    final AutoMl entity = new AutoMl();
    BeanUtils.copyProperties(autoMlVo, entity);

    final String operator = SecurityUtils.getLoginUser().getUsername();
    entity.setCreateBy(operator);
    entity.setUpdateBy(operator);
    entity.setDataset(JacksonUtil.toJSONString(autoMlVo.getDataset()));

    autoMlDao.save(entity);
    return entity;
}

/**
 * Updates an existing experiment from the given definition.
 *
 * <p>The current login user is recorded as updater, and the dataset descriptor
 * is stored as a JSON string.
 *
 * @return a fixed success message
 * @throws RuntimeException if a different experiment already uses the name
 */
@Override
public String edit(AutoMlVo autoMlVo) throws Exception {
    final AutoMl sameName = autoMlDao.getAutoMlByName(autoMlVo.getMlName());
    if (sameName != null) {
        final boolean isSameRecord = sameName.getId().equals(autoMlVo.getId());
        if (!isSameRecord) {
            throw new RuntimeException("实验名称已存在");
        }
    }

    final AutoMl entity = new AutoMl();
    BeanUtils.copyProperties(autoMlVo, entity);
    entity.setUpdateBy(SecurityUtils.getLoginUser().getUsername());
    entity.setDataset(JacksonUtil.toJSONString(autoMlVo.getDataset()));

    autoMlDao.edit(entity);
    return "修改成功";
}

/**
 * Soft-deletes an experiment by marking its state invalid.
 *
 * <p>Only the user named {@code admin} or the experiment's creator may delete it.
 *
 * @return a message telling whether the update affected any row
 * @throws RuntimeException if the experiment does not exist or the caller may not delete it
 */
@Override
public String delete(Long id) {
    final AutoMl target = autoMlDao.getAutoMlById(id);
    if (target == null) {
        throw new RuntimeException("服务不存在");
    }

    final String caller = SecurityUtils.getLoginUser().getUsername();
    final boolean isAdmin = StringUtils.equals(caller, "admin");
    final boolean isOwner = StringUtils.equals(caller, target.getCreateBy());
    if (!isAdmin && !isOwner) {
        throw new RuntimeException("无权限删除该服务");
    }

    target.setState(Constant.State_invalid);
    final int affected = autoMlDao.edit(target);
    if (affected > 0) {
        return "删除成功";
    }
    return "删除失败";
}

/**
 * Loads one AutoML experiment and converts it to its view object.
 *
 * <p>The entity's dataset column holds a JSON string; when it is non-empty it is parsed
 * back into a map on the returned object.
 *
 * @param id primary key of the experiment
 * @return the populated view object
 * @throws RuntimeException if no experiment exists for the id (previously this surfaced as an
 *                          opaque failure when the missing entity was dereferenced)
 * @throws IOException      if the stored dataset JSON cannot be parsed
 */
@Override
public AutoMlVo getAutoMlDetail(Long id) throws IOException {
    AutoMl autoMl = autoMlDao.getAutoMlById(id);
    if (autoMl == null) {
        // Fail with an explicit message, like the other lookups in this service do.
        throw new RuntimeException("自动机器学习配置不存在");
    }

    AutoMlVo autoMlVo = new AutoMlVo();
    BeanUtils.copyProperties(autoMl, autoMlVo);
    if (StringUtils.isNotEmpty(autoMl.getDataset())) {
        autoMlVo.setDataset(JsonUtils.jsonToMap(autoMl.getDataset()));
    }
    return autoMlVo;
}

/**
 * Stores an uploaded dataset file under the current user's temporary AutoML directory.
 *
 * <p>The target directory is {@code localPath + "temp/" + username + "/automl_data/" + uuid}.
 * The client-supplied file name is reduced to its last path segment, and the resolved target
 * must stay inside the user's {@code automl_data} directory, so neither the file name nor
 * {@code uuid} can be used to write outside of it.
 *
 * @param file uploaded file
 * @param uuid client-chosen sub-directory name for this upload
 * @return map with {@code fileName} (the name the file was stored under), {@code url}
 *         (the target directory) and {@code fileSize} (formatted size)
 * @throws RuntimeException if the file name is missing or the target escapes the user directory
 * @throws Exception        if the file cannot be written
 */
@Override
public Map<String, String> upload(MultipartFile file, String uuid) throws Exception {
    String originalName = file.getOriginalFilename();
    if (originalName == null || originalName.isEmpty()) {
        throw new RuntimeException("上传文件名不能为空");
    }
    // Some clients send a full path; keep only the last segment (both separator styles).
    String fileName = new File(originalName.replace('\\', '/')).getName();
    if (fileName.isEmpty()) {
        throw new RuntimeException("上传文件名不能为空");
    }

    String username = SecurityUtils.getLoginUser().getUsername();
    String userRoot = localPath + "temp/" + username + "/automl_data/";
    String path = userRoot + uuid;
    File targetFile = new File(path, fileName);

    // Canonical paths resolve "." and ".." segments, so this catches traversal via uuid or name.
    String canonicalRoot = new File(userRoot).getCanonicalPath();
    if (!targetFile.getCanonicalPath().startsWith(canonicalRoot + File.separator)) {
        throw new RuntimeException("非法的上传路径");
    }

    String formattedSize = FileUtil.formatFileSize(file.getSize());

    // Make sure the directory exists.
    targetFile.getParentFile().mkdirs();
    // Save the upload to the target path.
    FileUtils.copyInputStreamToFile(file.getInputStream(), targetFile);

    // Report where the file was stored.
    Map<String, String> result = new HashMap<>();
    result.put("fileName", fileName);
    result.put("url", path); // directory of the stored file
    result.put("fileSize", formattedSize);
    return result;
}

/**
 * Starts one run of an AutoML experiment and records it as an experiment instance.
 *
 * <p>Steps: serialize the experiment parameters, post them to the conversion endpoint,
 * submit the converted workflow to the run endpoint, then insert an instance row whose result
 * paths are derived from the first {@code param_output} entry of the conversion response with
 * the {@code {{workflow.name}}} placeholder replaced by the actual workflow name.
 *
 * <p>The conversion response is validated completely <em>before</em> the workflow is
 * submitted, so a malformed response no longer starts a workflow that has no instance row.
 * Runtime exceptions raised inside the call sequence are propagated unchanged; only checked
 * exceptions are wrapped.
 *
 * @param id primary key of the experiment to run
 * @return the fixed success message
 * @throws Exception        if no experiment exists for the id
 * @throws RuntimeException if conversion or submission fails or returns an unusable response
 */
@Override
public String runAutoMlIns(Long id) throws Exception {
    AutoMl autoMl = autoMlDao.getAutoMlById(id);
    if (autoMl == null) {
        throw new Exception("自动机器学习配置不存在");
    }

    AutoMlParamVo autoMlParam = new AutoMlParamVo();
    BeanUtils.copyProperties(autoMl, autoMlParam);
    autoMlParam.setDataset(JsonUtils.jsonToMap(autoMl.getDataset()));
    String param = JsonUtils.objectToJson(autoMlParam);
    try {
        // Call the conversion endpoint.
        String convertRes = HttpUtils.sendPost(argoUrl + convertAutoML, param);
        if (convertRes == null || StringUtils.isEmpty(convertRes)) {
            throw new RuntimeException("转换流水线失败");
        }
        Map<String, Object> converMap = JsonUtils.jsonToMap(convertRes);
        Map<String, Object> output = (Map<String, Object>) converMap.get("output");
        // Validate everything needed from the conversion result before anything is launched.
        String outputPathTemplate = resolveOutputPathTemplate(output);

        // Build the request body for the run endpoint and submit the workflow.
        Map<String, Object> runReqMap = new HashMap<>();
        runReqMap.put("data", converMap.get("data"));
        String runRes = HttpUtils.sendPost(argoUrl + argoWorkflowRun, JsonUtils.mapToJson(runReqMap));
        if (runRes == null || StringUtils.isEmpty(runRes)) {
            throw new RuntimeException("Failed to run workflow.");
        }
        Map<String, Object> runResMap = JsonUtils.jsonToMap(runRes);
        Map<String, Object> data = (Map<String, Object>) runResMap.get("data");
        if (data == null || MapUtils.isEmpty(data)) {
            throw new RuntimeException("Failed to run workflow.");
        }
        Map<String, Object> metadata = (Map<String, Object>) data.get("metadata");
        if (metadata == null || metadata.get("name") == null) {
            // The workflow name is required to resolve every result path below.
            throw new RuntimeException("Failed to run workflow.");
        }
        String workflowName = (String) metadata.get("name");

        // Insert a row into the experiment-instance table.
        AutoMlIns autoMlIns = new AutoMlIns();
        autoMlIns.setAutoMlId(autoMl.getId());
        autoMlIns.setArgoInsNs((String) metadata.get("namespace"));
        autoMlIns.setArgoInsName(workflowName);
        autoMlIns.setParam(param);
        autoMlIns.setStatus(Constant.Pending);
        // Replace the workflow-name placeholder in the node result.
        autoMlIns.setNodeResult(JsonUtils.mapToJson(output).replace("{{workflow.name}}", workflowName));

        String outputPath = minioEndpoint + "/" + outputPathTemplate.replace("{{workflow.name}}", workflowName) + "/";
        autoMlIns.setModelPath(outputPath + "save_model.joblib");
        if (Constant.AutoMl_Classification.equals(autoMl.getTaskType())) {
            autoMlIns.setImgPath(outputPath + "Auto-sklearn_accuracy_over_time.png" + "," + outputPath + "Train_Confusion_Matrix.png" + "," + outputPath + "Test_Confusion_Matrix.png");
        } else {
            autoMlIns.setImgPath(outputPath + "Auto-sklearn_accuracy_over_time.png" + "," + outputPath + "regression.png");
        }
        autoMlIns.setResultPath(outputPath + "result.txt");
        // The run-history directory name contains the seed; "1" is used when none is configured.
        String seed = autoMl.getSeed() != null ? String.valueOf(autoMl.getSeed()) : "1";
        autoMlIns.setRunHistoryPath(outputPath + "smac3-output/run_" + seed + "/runhistory.json");
        autoMlInsDao.insert(autoMlIns);
        autoMlInsService.updateAutoMlStatus(id);
    } catch (RuntimeException e) {
        // Already unchecked: rethrow as-is so the original message is not buried in a wrapper.
        throw e;
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
    return "执行成功";
}

/**
 * Extracts the {@code path} of the first element of the first {@code param_output} entry
 * from the conversion result, failing with an explicit message when the structure is missing.
 */
private static String resolveOutputPathTemplate(Map<String, Object> output) {
    if (output == null || !(output.get("param_output") instanceof Map)) {
        throw new RuntimeException("转换流水线失败");
    }
    Map<?, ?> paramOutput = (Map<?, ?>) output.get("param_output");
    if (paramOutput.isEmpty()) {
        throw new RuntimeException("转换流水线失败");
    }
    Object firstEntry = paramOutput.values().iterator().next();
    if (!(firstEntry instanceof List) || ((List<?>) firstEntry).isEmpty()) {
        throw new RuntimeException("转换流水线失败");
    }
    Object firstOutput = ((List<?>) firstEntry).get(0);
    if (!(firstOutput instanceof Map) || !(((Map<?, ?>) firstOutput).get("path") instanceof String)) {
        throw new RuntimeException("转换流水线失败");
    }
    return (String) ((Map<?, ?>) firstOutput).get("path");
}
}

+ 0
- 2
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/JupyterServiceImpl.java View File

@@ -104,8 +104,6 @@ public class JupyterServiceImpl implements JupyterService {
// String pvcName = loginUser.getUsername().toLowerCase() + "-editor-pvc";
// V1PersistentVolumeClaim pvc = k8sClientUtil.createPvc(namespace, pvcName, storage, storageClassName);

//TODO 设置镜像可配置,这里先用默认镜像启动pod

// 调用修改后的 createPod 方法,传入额外的参数
// Integer podPort = k8sClientUtil.createConfiguredPod(podName, namespace, port, mountPath, pvc, devEnvironment, minioPvcName, datasetPath, modelPath);
Integer podPort = k8sClientUtil.createConfiguredPod(podName, namespace, port, mountPath, null, devEnvironment, minioPvcName, datasetPath, modelPath);


+ 155
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/AutoMlParamVo.java View File

@@ -0,0 +1,155 @@
package com.ruoyi.platform.vo;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.databind.PropertyNamingStrategy;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;

import java.util.Map;

/**
 * Run parameters of an AutoML experiment.
 *
 * <p>Serialized with snake_case property names and with null properties omitted (see the
 * Jackson annotations below). AutoMlServiceImpl#runAutoMlIns copies an AutoMl entity into
 * this object and sends its JSON form to the AutoML conversion endpoint.
 */
@Data
@JsonNaming(PropertyNamingStrategy.SnakeCaseStrategy.class)
@JsonInclude(JsonInclude.Include.NON_NULL)
@ApiModel(description = "自动机器学习参数")
public class AutoMlParamVo {
// Task type: "classification" or "regression".
@ApiModelProperty(value = "任务类型:classification或regression")
private String taskType;

// Overall time budget for the model search, in seconds.
@ApiModelProperty(value = "搜索合适模型的时间限制(以秒为单位)。通过增加这个值,auto-sklearn有更高的机会找到更好的模型。默认3600,非必传。")
private Integer timeLeftForThisTask;

// Time limit for a single model fit, in seconds.
@ApiModelProperty(value = "单次调用机器学习模型的时间限制(以秒为单位)。如果机器学习算法运行超过时间限制,将终止模型拟合。将这个值设置得足够高,这样典型的机器学习算法就可以适用于训练数据。默认600,非必传。")
private Integer perRunTimeLimit;

// Number of models in the ensemble; 0 disables ensembling.
@ApiModelProperty(value = "集成模型数量,如果设置为0,则没有集成。默认50,非必传。")
private Integer ensembleSize;

// Ensemble strategy selector (None / SingleBest / default).
@ApiModelProperty(value = "设置为None将禁用集成构建,设置为SingleBest仅使用单个最佳模型而不是集成,设置为default,它将对单目标问题使用EnsembleSelection,对多目标问题使用MultiObjectiveDummyEnsemble。默认default,非必传。")
private String ensembleClass;

// Only the best N models are considered when the ensemble is built.
@ApiModelProperty(value = "在构建集成时只考虑ensemble_nbest模型。这是受到了“最大限度地利用集成选择”中引入的库修剪概念的启发。这是独立于ensemble_class参数的,并且这个修剪步骤是在构造集成之前完成的。默认50,非必传。")
private Integer ensembleNbest;

// Maximum number of models kept on disk.
@ApiModelProperty(value = "定义在磁盘中保存的模型的最大数量。额外的模型数量将被永久删除。由于这个变量的性质,它设置了一个集成可以使用多少个模型的上限。必须是大于等于1的整数。如果设置为None,则所有模型都保留在磁盘上。默认50,非必传。")
private Integer maxModelsOnDisc;

// Random seed; it also determines output file names.
@ApiModelProperty(value = "随机种子,将决定输出文件名。默认1,非必传。")
private Integer seed;

// Memory limit for the learning algorithm, in MB.
@ApiModelProperty(value = "机器学习算法的内存限制(MB)。如果auto-sklearn试图分配超过memory_limit MB,它将停止拟合机器学习算法。默认3072,非必传。")
private Integer memoryLimit;

// Comma-separated classifier components to include in the search.
@ApiModelProperty(value = "如果为None,则使用所有可能的分类算法。否则,指定搜索中包含的步骤和组件。有关可用组件,请参见/pipeline/components/<step>/*。与参数exclude不兼容。多选,逗号分隔。包含:adaboost\n" +
"bernoulli_nb\n" +
"decision_tree\n" +
"extra_trees\n" +
"gaussian_nb\n" +
"gradient_boosting\n" +
"k_nearest_neighbors\n" +
"lda\n" +
"liblinear_svc\n" +
"libsvm_svc\n" +
"mlp\n" +
"multinomial_nb\n" +
"passive_aggressive\n" +
"qda\n" +
"random_forest\n" +
"sgd")
private String includeClassifier;

// Comma-separated feature-preprocessor components to include in the search.
@ApiModelProperty(value = "如果为None,则使用所有可能的特征预处理算法。否则,指定搜索中包含的步骤和组件。有关可用组件,请参见/pipeline/components/<step>/*。与参数exclude不兼容。多选,逗号分隔。包含:densifier\n" +
"extra_trees_preproc_for_classification\n" +
"extra_trees_preproc_for_regression\n" +
"fast_ica\n" +
"feature_agglomeration\n" +
"kernel_pca\n" +
"kitchen_sinks\n" +
"liblinear_svc_preprocessor\n" +
"no_preprocessing\n" +
"nystroem_sampler\n" +
"pca\n" +
"polynomial\n" +
"random_trees_embedding\n" +
"select_percentile_classification\n" +
"select_percentile_regression\n" +
"select_rates_classification\n" +
"select_rates_regression\n" +
"truncatedSVD")
private String includeFeaturePreprocessor;

// Comma-separated regressor components to include in the search.
@ApiModelProperty(value = "如果为None,则使用所有可能的回归算法。否则,指定搜索中包含的步骤和组件。有关可用组件,请参见/pipeline/components/<step>/*。与参数exclude不兼容。多选,逗号分隔。包含:adaboost,\n" +
"ard_regression,\n" +
"decision_tree,\n" +
"extra_trees,\n" +
"gaussian_process,\n" +
"gradient_boosting,\n" +
"k_nearest_neighbors,\n" +
"liblinear_svr,\n" +
"libsvm_svr,\n" +
"mlp,\n" +
"random_forest,\n" +
"sgd")
private String includeRegressor;

// NOTE(review): presumably the exclusion counterpart of includeClassifier (same
// comma-separated format) - not documented here, confirm against the consumer.
private String excludeClassifier;

// NOTE(review): presumably the exclusion counterpart of includeRegressor - confirm.
private String excludeRegressor;

// NOTE(review): presumably the exclusion counterpart of includeFeaturePreprocessor - confirm.
private String excludeFeaturePreprocessor;

// Fraction of the data used as test set, between 0 and 1.
@ApiModelProperty(value = "测试集的比率,0到1之间")
private Float testSize;

// Resampling strategy: "holdout" or "crossValid".
@ApiModelProperty(value = "如何处理过拟合,如果使用基于“cv”的方法或Splitter对象,可能需要使用resampling_strategy_arguments。holdout或crossValid")
private String resamplingStrategy;

// Fraction of the data used for training when resampling, between 0 and 1.
@ApiModelProperty(value = "重采样划分训练集和验证集,训练集的比率,0到1之间")
private Float trainSize;

// Whether the data is shuffled before it is split.
@ApiModelProperty(value = "拆分数据前是否进行shuffle")
private Boolean shuffle;

// Number of cross-validation folds; required when resamplingStrategy is "crossValid".
@ApiModelProperty(value = "交叉验证的折数,当resamplingStrategy为crossValid时,此项必填,为整数")
private Integer folds;

// Comma-separated target columns of the dataset CSV file.
@ApiModelProperty(value = "数据集csv文件中哪几列是预测目标列,逗号分隔")
private String targetColumns;

// Name of the custom metric.
@ApiModelProperty(value = "自定义指标名称")
private String metricName;

// Optimization metrics and their weights, as a JSON string.
@ApiModelProperty(value = "模型优化目标指标及权重,json格式。分类的指标包含:accuracy\n" +
"balanced_accuracy\n" +
"roc_auc\n" +
"average_precision\n" +
"log_loss\n" +
"precision_macro\n" +
"precision_micro\n" +
"precision_samples\n" +
"precision_weighted\n" +
"recall_macro\n" +
"recall_micro\n" +
"recall_samples\n" +
"recall_weighted\n" +
"f1_macro\n" +
"f1_micro\n" +
"f1_samples\n" +
"f1_weighted\n" +
"回归的指标包含:mean_absolute_error\n" +
"mean_squared_error\n" +
"root_mean_squared_error\n" +
"mean_squared_log_error\n" +
"median_absolute_error\n" +
"r2")
private String metrics;

// Optimization direction of the metric: whether larger values are better.
@ApiModelProperty(value = "指标优化方向,是越大越好还是越小越好")
private Boolean greaterIsBetter;

// Metrics the model computes and prints.
@ApiModelProperty(value = "模型计算并打印指标")
private String scoringFunctions;

// Dataset description as a free-form map; AutoMlServiceImpl#runAutoMlIns fills it by
// parsing the entity's JSON dataset string.
private Map<String,Object> dataset;
}

+ 183
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/AutoMlVo.java View File

@@ -0,0 +1,183 @@
package com.ruoyi.platform.vo;

import com.baomidou.mybatisplus.annotation.TableField;
import com.fasterxml.jackson.databind.PropertyNamingStrategy;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;

import java.util.Date;
import java.util.Map;

/**
 * View object of an AutoML experiment, exchanged with API clients using snake_case
 * property names (see the Jackson annotation below).
 *
 * <p>AutoMlServiceImpl copies it to and from the AutoMl entity; the dataset is a map here
 * and is stored on the entity as a JSON string.
 */
@Data
@JsonNaming(PropertyNamingStrategy.SnakeCaseStrategy.class)
@ApiModel(description = "自动机器学习")
public class AutoMlVo {
// Primary key of the experiment; used by edit to identify the row to update.
private Long id;

// Experiment name; the service rejects duplicates on save and edit.
@ApiModelProperty(value = "实验名称")
private String mlName;

// Experiment description.
@ApiModelProperty(value = "实验描述")
private String mlDescription;

// Task type: "classification" or "regression".
@ApiModelProperty(value = "任务类型:classification或regression")
private String taskType;

// Overall time budget for the model search, in seconds.
@ApiModelProperty(value = "搜索合适模型的时间限制(以秒为单位)。通过增加这个值,auto-sklearn有更高的机会找到更好的模型。默认3600,非必传。")
private Integer timeLeftForThisTask;

// Time limit for a single model fit, in seconds.
@ApiModelProperty(value = "单次调用机器学习模型的时间限制(以秒为单位)。如果机器学习算法运行超过时间限制,将终止模型拟合。将这个值设置得足够高,这样典型的机器学习算法就可以适用于训练数据。默认600,非必传。")
private Integer perRunTimeLimit;

// Number of models in the ensemble; 0 disables ensembling.
@ApiModelProperty(value = "集成模型数量,如果设置为0,则没有集成。默认50,非必传。")
private Integer ensembleSize;

// Ensemble strategy selector (None / SingleBest / default).
@ApiModelProperty(value = "设置为None将禁用集成构建,设置为SingleBest仅使用单个最佳模型而不是集成,设置为default,它将对单目标问题使用EnsembleSelection,对多目标问题使用MultiObjectiveDummyEnsemble。默认default,非必传。")
private String ensembleClass;

// Only the best N models are considered when the ensemble is built.
@ApiModelProperty(value = "在构建集成时只考虑ensemble_nbest模型。这是受到了“最大限度地利用集成选择”中引入的库修剪概念的启发。这是独立于ensemble_class参数的,并且这个修剪步骤是在构造集成之前完成的。默认50,非必传。")
private Integer ensembleNbest;

// Maximum number of models kept on disk.
@ApiModelProperty(value = "定义在磁盘中保存的模型的最大数量。额外的模型数量将被永久删除。由于这个变量的性质,它设置了一个集成可以使用多少个模型的上限。必须是大于等于1的整数。如果设置为None,则所有模型都保留在磁盘上。默认50,非必传。")
private Integer maxModelsOnDisc;

// Random seed; it also determines output file names.
@ApiModelProperty(value = "随机种子,将决定输出文件名。默认1,非必传。")
private Integer seed;

// Memory limit for the learning algorithm, in MB.
@ApiModelProperty(value = "机器学习算法的内存限制(MB)。如果auto-sklearn试图分配超过memory_limit MB,它将停止拟合机器学习算法。默认3072,非必传。")
private Integer memoryLimit;

// Comma-separated classifier components to include in the search.
@ApiModelProperty(value = "如果为None,则使用所有可能的分类算法。否则,指定搜索中包含的步骤和组件。有关可用组件,请参见/pipeline/components/<step>/*。与参数exclude不兼容。多选,逗号分隔。包含:adaboost\n" +
"bernoulli_nb\n" +
"decision_tree\n" +
"extra_trees\n" +
"gaussian_nb\n" +
"gradient_boosting\n" +
"k_nearest_neighbors\n" +
"lda\n" +
"liblinear_svc\n" +
"libsvm_svc\n" +
"mlp\n" +
"multinomial_nb\n" +
"passive_aggressive\n" +
"qda\n" +
"random_forest\n" +
"sgd")
private String includeClassifier;

// Comma-separated feature-preprocessor components to include in the search.
@ApiModelProperty(value = "如果为None,则使用所有可能的特征预处理算法。否则,指定搜索中包含的步骤和组件。有关可用组件,请参见/pipeline/components/<step>/*。与参数exclude不兼容。多选,逗号分隔。包含:densifier\n" +
"extra_trees_preproc_for_classification\n" +
"extra_trees_preproc_for_regression\n" +
"fast_ica\n" +
"feature_agglomeration\n" +
"kernel_pca\n" +
"kitchen_sinks\n" +
"liblinear_svc_preprocessor\n" +
"no_preprocessing\n" +
"nystroem_sampler\n" +
"pca\n" +
"polynomial\n" +
"random_trees_embedding\n" +
"select_percentile_classification\n" +
"select_percentile_regression\n" +
"select_rates_classification\n" +
"select_rates_regression\n" +
"truncatedSVD")
private String includeFeaturePreprocessor;

// Comma-separated regressor components to include in the search.
@ApiModelProperty(value = "如果为None,则使用所有可能的回归算法。否则,指定搜索中包含的步骤和组件。有关可用组件,请参见/pipeline/components/<step>/*。与参数exclude不兼容。多选,逗号分隔。包含:adaboost,\n" +
"ard_regression,\n" +
"decision_tree,\n" +
"extra_trees,\n" +
"gaussian_process,\n" +
"gradient_boosting,\n" +
"k_nearest_neighbors,\n" +
"liblinear_svr,\n" +
"libsvm_svr,\n" +
"mlp,\n" +
"random_forest,\n" +
"sgd")
private String includeRegressor;

// NOTE(review): presumably the exclusion counterpart of includeClassifier (same
// comma-separated format) - not documented here, confirm against the consumer.
private String excludeClassifier;

// NOTE(review): presumably the exclusion counterpart of includeRegressor - confirm.
private String excludeRegressor;

// NOTE(review): presumably the exclusion counterpart of includeFeaturePreprocessor - confirm.
private String excludeFeaturePreprocessor;

// Fraction of the data used as test set, between 0 and 1.
@ApiModelProperty(value = "测试集的比率,0到1之间")
private Float testSize;

// Resampling strategy: "holdout" or "crossValid".
@ApiModelProperty(value = "如何处理过拟合,如果使用基于“cv”的方法或Splitter对象,可能需要使用resampling_strategy_arguments。holdout或crossValid")
private String resamplingStrategy;

// Fraction of the data used for training when resampling, between 0 and 1.
@ApiModelProperty(value = "重采样划分训练集和验证集,训练集的比率,0到1之间")
private Float trainSize;

// Whether the data is shuffled before it is split.
@ApiModelProperty(value = "拆分数据前是否进行shuffle")
private Boolean shuffle;

// Number of cross-validation folds; required when resamplingStrategy is "crossValid".
@ApiModelProperty(value = "交叉验证的折数,当resamplingStrategy为crossValid时,此项必填,为整数")
private Integer folds;

// Folder for configuration output and log files; documented default is /tmp/automl.
@ApiModelProperty(value = "文件夹存放配置输出和日志文件,默认/tmp/automl")
private String tmpFolder;

// Comma-separated target columns of the dataset CSV file.
@ApiModelProperty(value = "数据集csv文件中哪几列是预测目标列,逗号分隔")
private String targetColumns;

// Name of the custom metric.
@ApiModelProperty(value = "自定义指标名称")
private String metricName;

// Optimization metrics and their weights, as a JSON string.
@ApiModelProperty(value = "模型优化目标指标及权重,json格式。分类的指标包含:accuracy\n" +
"balanced_accuracy\n" +
"roc_auc\n" +
"average_precision\n" +
"log_loss\n" +
"precision_macro\n" +
"precision_micro\n" +
"precision_samples\n" +
"precision_weighted\n" +
"recall_macro\n" +
"recall_micro\n" +
"recall_samples\n" +
"recall_weighted\n" +
"f1_macro\n" +
"f1_micro\n" +
"f1_samples\n" +
"f1_weighted\n" +
"回归的指标包含:mean_absolute_error\n" +
"mean_squared_error\n" +
"root_mean_squared_error\n" +
"mean_squared_log_error\n" +
"median_absolute_error\n" +
"r2")
private String metrics;

// Optimization direction of the metric: whether larger values are better.
@ApiModelProperty(value = "指标优化方向,是越大越好还是越小越好")
private Boolean greaterIsBetter;

// Metrics the model computes and prints.
@ApiModelProperty(value = "模型计算并打印指标")
private String scoringFunctions;

// Record state flag; the AutoMlDao queries treat 1 as active and the service's delete
// switches it to Constant.State_invalid.
private Integer state;

// NOTE(review): presumably the run status of the experiment - not set in the visible
// service code, confirm where it is populated.
private String runState;

// NOTE(review): presumably the run progress - the matching mapper column update is
// commented out in AutoMlDao.xml, confirm whether it is still used.
private Double progress;

// User name of the creator (set from the login user on save).
private String createBy;

// NOTE(review): creation timestamp, presumably filled by the database - confirm.
private Date createTime;

// User name of the last updater (set from the login user on save and edit).
private String updateBy;

// NOTE(review): last-update timestamp, presumably filled by the database - confirm.
private Date updateTime;

/**
 * Dataset associated with the experiment, as a free-form map.
 */
private Map<String,Object> dataset;
}

+ 5
- 5
ruoyi-modules/management-platform/src/main/resources/bootstrap.yml View File

@@ -14,16 +14,16 @@ spring:
nacos:
discovery:
# 服务注册地址
server-addr: nacos-ci4s.argo.svc:8848
server-addr: 172.20.32.181:18848
username: nacos
password: h1n2x3j4y5@
password: nacos
retry:
enabled: true
# namespace: 6caf5d79-c4ce-4e3b-a357-141b74e52a01
config:
username: nacos
password: h1n2x3j4y5@
# namespace: 6caf5d79-c4ce-4e3b-a357-141b74e52a01
# 配置中心地址
server-addr: nacos-ci4s.argo.svc:8848
server-addr: 172.20.32.181:18848
# 配置文件格式
file-extension: yml
# 共享配置


+ 146
- 0
ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/AutoMlDao.xml View File

@@ -0,0 +1,146 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="com.ruoyi.platform.mapper.AutoMlDao">
<!-- Inserts a new experiment row from the autoMl parameter; every listed column is written as given. -->
<insert id="save">
insert into auto_ml(ml_name, ml_description, task_type, dataset, time_left_for_this_task,
per_run_time_limit, ensemble_size, ensemble_class, ensemble_nbest, max_models_on_disc, seed,
memory_limit,
include_classifier, include_feature_preprocessor, include_regressor, exclude_classifier,
exclude_regressor, exclude_feature_preprocessor, test_size, resampling_strategy, train_size,
shuffle, folds, target_columns, metric_name, metrics, greater_is_better, scoring_functions,
tmp_folder,
create_by, update_by)
values (#{autoMl.mlName}, #{autoMl.mlDescription}, #{autoMl.taskType}, #{autoMl.dataset},
#{autoMl.timeLeftForThisTask}, #{autoMl.perRunTimeLimit},
#{autoMl.ensembleSize}, #{autoMl.ensembleClass}, #{autoMl.ensembleNbest},
#{autoMl.maxModelsOnDisc}, #{autoMl.seed},
#{autoMl.memoryLimit}, #{autoMl.includeClassifier}, #{autoMl.includeFeaturePreprocessor},
#{autoMl.includeRegressor}, #{autoMl.excludeClassifier},
#{autoMl.excludeRegressor}, #{autoMl.excludeFeaturePreprocessor}, #{autoMl.testSize},
#{autoMl.resamplingStrategy},
#{autoMl.trainSize}, #{autoMl.shuffle},
#{autoMl.folds},
#{autoMl.targetColumns}, #{autoMl.metricName}, #{autoMl.metrics}, #{autoMl.greaterIsBetter},
#{autoMl.scoringFunctions}, #{autoMl.tmpFolder},
#{autoMl.createBy}, #{autoMl.updateBy})
</insert>

<!--
Updates the row identified by autoMl.id.
Columns guarded by an if-test are only written when a value is supplied. The include/exclude
lists, scoring_functions and metrics are always written, so they can be cleared by sending null.
Numeric properties are tested with "!= null" only: in OGNL an empty string compares equal to
the number 0, so an additional "!= ''" test would silently skip a legitimate value of 0.
-->
<update id="edit">
update auto_ml
<set>
<if test="autoMl.mlName != null and autoMl.mlName !=''">
ml_name = #{autoMl.mlName},
</if>
<if test="autoMl.mlDescription != null and autoMl.mlDescription !=''">
ml_description = #{autoMl.mlDescription},
</if>
<if test="autoMl.statusList != null and autoMl.statusList !=''">
status_list = #{autoMl.statusList},
</if>
<!-- <if test="autoMl.progress != null">-->
<!-- progress = #{autoMl.progress},-->
<!-- </if>-->
<if test="autoMl.taskType != null and autoMl.taskType !=''">
task_type = #{autoMl.taskType},
</if>
<if test="autoMl.dataset != null and autoMl.dataset !=''">
dataset = #{autoMl.dataset},
</if>
<if test="autoMl.timeLeftForThisTask != null">
time_left_for_this_task = #{autoMl.timeLeftForThisTask},
</if>
<if test="autoMl.perRunTimeLimit != null">
per_run_time_limit = #{autoMl.perRunTimeLimit},
</if>
<if test="autoMl.ensembleSize != null">
ensemble_size = #{autoMl.ensembleSize},
</if>
<if test="autoMl.ensembleClass != null and autoMl.ensembleClass !=''">
ensemble_class = #{autoMl.ensembleClass},
</if>
<if test="autoMl.ensembleNbest != null">
ensemble_nbest = #{autoMl.ensembleNbest},
</if>
<if test="autoMl.maxModelsOnDisc != null">
max_models_on_disc = #{autoMl.maxModelsOnDisc},
</if>
<if test="autoMl.seed != null">
seed = #{autoMl.seed},
</if>
<if test="autoMl.memoryLimit != null">
memory_limit = #{autoMl.memoryLimit},
</if>
include_classifier = #{autoMl.includeClassifier},
include_feature_preprocessor = #{autoMl.includeFeaturePreprocessor},
include_regressor = #{autoMl.includeRegressor},
exclude_classifier = #{autoMl.excludeClassifier},
exclude_regressor = #{autoMl.excludeRegressor},
exclude_feature_preprocessor = #{autoMl.excludeFeaturePreprocessor},
scoring_functions = #{autoMl.scoringFunctions},
metrics = #{autoMl.metrics},
<if test="autoMl.testSize != null">
test_size = #{autoMl.testSize},
</if>
<if test="autoMl.resamplingStrategy != null and autoMl.resamplingStrategy !=''">
resampling_strategy = #{autoMl.resamplingStrategy},
</if>
<if test="autoMl.trainSize != null">
train_size = #{autoMl.trainSize},
</if>
<if test="autoMl.shuffle != null">
shuffle = #{autoMl.shuffle},
</if>
<if test="autoMl.folds != null">
folds = #{autoMl.folds},
</if>
<if test="autoMl.tmpFolder != null and autoMl.tmpFolder !=''">
tmp_folder = #{autoMl.tmpFolder},
</if>
<if test="autoMl.metricName != null and autoMl.metricName !=''">
metric_name = #{autoMl.metricName},
</if>
<if test="autoMl.greaterIsBetter != null">
greater_is_better = #{autoMl.greaterIsBetter},
</if>
<if test="autoMl.targetColumns != null and autoMl.targetColumns !=''">
target_columns = #{autoMl.targetColumns},
</if>
<if test="autoMl.state != null">
state = #{autoMl.state},
</if>
</set>
where id = #{autoMl.id}
</update>

<!-- Counts active experiments matching the optional name filter. -->
<select id="count" resultType="java.lang.Long">
select count(1) from auto_ml
<include refid="common_condition"></include>
</select>

<!--
Lists active experiments matching the optional name filter.
NOTE(review): the service passes a page request to this DAO method, but the statement has no
ORDER BY or LIMIT clause. Confirm whether paging is applied elsewhere (for example by a plugin);
otherwise every matching row is returned for each page.
-->
<select id="queryByPage" resultType="com.ruoyi.platform.domain.AutoMl">
select * from auto_ml
<include refid="common_condition"></include>
</select>

<!-- Loads one experiment by primary key, regardless of its state. -->
<select id="getAutoMlById" resultType="com.ruoyi.platform.domain.AutoMl">
select *
from auto_ml
where id = #{id}
</select>

<!-- Loads the active experiment with exactly this name; used for the duplicate-name check. -->
<select id="getAutoMlByName" resultType="com.ruoyi.platform.domain.AutoMl">
select *
from auto_ml
where ml_name = #{mlName}
and state = 1
</select>

<!-- Shared filter: active rows only (state = 1), optionally narrowed by a name substring. -->
<sql id="common_condition">
<where>
state = 1
<if test="mlName != null and mlName != ''">
and ml_name like concat('%', #{mlName}, '%')
</if>
</where>
</sql>
</mapper>

+ 86
- 0
ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/AutoMlInsDao.xml View File

@@ -0,0 +1,86 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="com.ruoyi.platform.mapper.AutoMlInsDao">

<!-- Inserts one experiment-instance row; the generated key is written back to the id property. -->
<insert id="insert" keyProperty="id" useGeneratedKeys="true">
insert into auto_ml_ins(auto_ml_id, result_path, model_path, img_path, run_history_path, node_status,
node_result, param, source, argo_ins_name, argo_ins_ns, status)
values (#{autoMlIns.autoMlId}, #{autoMlIns.resultPath}, #{autoMlIns.modelPath}, #{autoMlIns.imgPath},
#{autoMlIns.runHistoryPath}, #{autoMlIns.nodeStatus},
#{autoMlIns.nodeResult}, #{autoMlIns.param}, #{autoMlIns.source}, #{autoMlIns.argoInsName},
#{autoMlIns.argoInsNs}, #{autoMlIns.status})
</insert>

<!-- Updates the instance identified by autoMlIns.id; only columns with a supplied value are written. -->
<update id="update">
update auto_ml_ins
<set>
<if test="autoMlIns.modelPath != null and autoMlIns.modelPath != ''">
model_path = #{autoMlIns.modelPath},
</if>
<if test="autoMlIns.imgPath != null and autoMlIns.imgPath != ''">
img_path = #{autoMlIns.imgPath},
</if>
<if test="autoMlIns.status != null and autoMlIns.status != ''">
status = #{autoMlIns.status},
</if>
<if test="autoMlIns.nodeStatus != null and autoMlIns.nodeStatus != ''">
node_status = #{autoMlIns.nodeStatus},
</if>
<if test="autoMlIns.nodeResult != null and autoMlIns.nodeResult != ''">
node_result = #{autoMlIns.nodeResult},
</if>
<if test="autoMlIns.state != null">
state = #{autoMlIns.state},
</if>
<if test="autoMlIns.updateTime != null">
update_time = #{autoMlIns.updateTime},
</if>
<if test="autoMlIns.finishTime != null">
finish_time = #{autoMlIns.finishTime},
</if>
</set>
where id = #{autoMlIns.id}
</update>

<!-- Counts the active instances (state = 1) of the experiment given by autoMlIns.autoMlId. -->
<select id="count" resultType="java.lang.Long">
select count(1)
from auto_ml_ins
<where>
state = 1
and auto_ml_id = #{autoMlIns.autoMlId}
</where>
</select>

<!-- Returns one page of active instances of an experiment, most recently updated first. -->
<select id="queryAllByLimit" resultType="com.ruoyi.platform.domain.AutoMlIns">
select * from auto_ml_ins
<where>
state = 1
and auto_ml_id = #{autoMlIns.autoMlId}
</where>
order by update_time DESC
limit #{pageable.offset}, #{pageable.pageSize}
</select>

<!-- Loads one active instance by primary key. -->
<select id="queryById" resultType="com.ruoyi.platform.domain.AutoMlIns">
select * from auto_ml_ins
<where>
state = 1 and id = #{id}
</where>
</select>

<!-- Returns active instances whose status is neither Terminated, Succeeded nor Failed, or is null. -->
<select id="queryByAutoMlInsIsNotTerminated" resultType="com.ruoyi.platform.domain.AutoMlIns">
select *
from auto_ml_ins
where (status NOT IN ('Terminated', 'Succeeded', 'Failed')
OR status IS NULL)
and state = 1
</select>

<!-- Returns the five most recently updated active instances of an experiment. -->
<select id="getByAutoMlId" resultType="com.ruoyi.platform.domain.AutoMlIns">
select *
from auto_ml_ins
where auto_ml_id = #{autoMlId}
and state = 1
order by update_time DESC limit 5
</select>
</mapper>

Loading…
Cancel
Save