| @@ -29,10 +29,13 @@ public class Constant { | |||
| public final static String Running = "Running"; | |||
| public final static String Failed = "Failed"; | |||
| public final static String Pending = "Pending"; | |||
| public final static String Terminated = "Terminated"; | |||
| public final static String Init = "Init"; | |||
| public final static String Stopped = "Stopped"; | |||
| public final static String Succeeded = "Succeeded"; | |||
| public final static String Type_Train = "train"; | |||
| public final static String Type_Evaluate = "evaluate"; | |||
| public final static String AutoMl_Classification = "classification"; | |||
| } | |||
| @@ -0,0 +1,71 @@ | |||
| package com.ruoyi.platform.controller.autoML; | |||
| import com.ruoyi.common.core.web.controller.BaseController; | |||
| import com.ruoyi.common.core.web.domain.AjaxResult; | |||
| import com.ruoyi.common.core.web.domain.GenericsAjaxResult; | |||
| import com.ruoyi.platform.domain.AutoMl; | |||
| import com.ruoyi.platform.service.AutoMlService; | |||
| import com.ruoyi.platform.vo.AutoMlVo; | |||
| import io.swagger.annotations.Api; | |||
| import io.swagger.annotations.ApiOperation; | |||
| import org.springframework.data.domain.Page; | |||
| import org.springframework.data.domain.PageRequest; | |||
| import org.springframework.web.bind.annotation.*; | |||
| import org.springframework.web.multipart.MultipartFile; | |||
| import javax.annotation.Resource; | |||
| import java.io.IOException; | |||
| @RestController | |||
| @RequestMapping("autoML") | |||
| @Api("自动机器学习") | |||
| public class AutoMlController extends BaseController { | |||
| @Resource | |||
| private AutoMlService autoMlService; | |||
| @GetMapping | |||
| @ApiOperation("分页查询") | |||
| public GenericsAjaxResult<Page<AutoMl>> queryByPage(@RequestParam("page") int page, | |||
| @RequestParam("size") int size, | |||
| @RequestParam(value = "ml_name", required = false) String mlName) { | |||
| PageRequest pageRequest = PageRequest.of(page, size); | |||
| return genericsSuccess(this.autoMlService.queryByPage(mlName, pageRequest)); | |||
| } | |||
| @PostMapping | |||
| @ApiOperation("新增自动机器学习") | |||
| public GenericsAjaxResult<AutoMl> addAutoMl(@RequestBody AutoMlVo autoMlVo) throws Exception { | |||
| return genericsSuccess(this.autoMlService.save(autoMlVo)); | |||
| } | |||
| @PutMapping | |||
| @ApiOperation("编辑自动机器学习") | |||
| public GenericsAjaxResult<String> editAutoMl(@RequestBody AutoMlVo autoMlVo) throws Exception { | |||
| return genericsSuccess(this.autoMlService.edit(autoMlVo)); | |||
| } | |||
| @GetMapping("/getAutoMlDetail") | |||
| @ApiOperation("获取自动机器学习详细信息") | |||
| public GenericsAjaxResult<AutoMlVo> getAutoMlDetail(@RequestParam("id") Long id) throws IOException { | |||
| return genericsSuccess(this.autoMlService.getAutoMlDetail(id)); | |||
| } | |||
| @DeleteMapping("{id}") | |||
| @ApiOperation("删除自动机器学习") | |||
| public GenericsAjaxResult<String> deleteAutoMl(@PathVariable("id") Long id) { | |||
| return genericsSuccess(this.autoMlService.delete(id)); | |||
| } | |||
| @CrossOrigin(origins = "*", allowedHeaders = "*") | |||
| @PostMapping("/upload") | |||
| @ApiOperation(value = "上传数据文件csv", notes = "上传数据文件csv,并将信息存入数据库。") | |||
| public AjaxResult upload(@RequestParam("file") MultipartFile file, @RequestParam("uuid") String uuid) throws Exception { | |||
| return AjaxResult.success(this.autoMlService.upload(file, uuid)); | |||
| } | |||
| @PostMapping("/run/{id}") | |||
| @ApiOperation("运行自动机器学习实验") | |||
| public GenericsAjaxResult<String> runAutoML(@PathVariable("id") Long id) throws Exception { | |||
| return genericsSuccess(this.autoMlService.runAutoMlIns(id)); | |||
| } | |||
| } | |||
| @@ -0,0 +1,61 @@ | |||
| package com.ruoyi.platform.controller.autoML; | |||
| import com.ruoyi.common.core.web.controller.BaseController; | |||
| import com.ruoyi.common.core.web.domain.GenericsAjaxResult; | |||
| import com.ruoyi.platform.domain.AutoMlIns; | |||
| import com.ruoyi.platform.service.AutoMlInsService; | |||
| import io.swagger.annotations.Api; | |||
| import io.swagger.annotations.ApiOperation; | |||
| import org.springframework.data.domain.Page; | |||
| import org.springframework.data.domain.PageRequest; | |||
| import org.springframework.web.bind.annotation.*; | |||
| import javax.annotation.Resource; | |||
| import java.io.IOException; | |||
| import java.util.List; | |||
| @RestController | |||
| @RequestMapping("autoMLIns") | |||
| @Api("自动机器学习实验实例") | |||
| public class AutoMlInsController extends BaseController { | |||
| @Resource | |||
| private AutoMlInsService autoMLInsService; | |||
| @GetMapping | |||
| @ApiOperation("分页查询") | |||
| public GenericsAjaxResult<Page<AutoMlIns>> queryByPage(AutoMlIns autoMlIns, int page, int size) throws IOException { | |||
| PageRequest pageRequest = PageRequest.of(page, size); | |||
| return genericsSuccess(this.autoMLInsService.queryByPage(autoMlIns, pageRequest)); | |||
| } | |||
| @PostMapping | |||
| @ApiOperation("新增实验实例") | |||
| public GenericsAjaxResult<AutoMlIns> add(@RequestBody AutoMlIns autoMlIns) { | |||
| return genericsSuccess(this.autoMLInsService.insert(autoMlIns)); | |||
| } | |||
| @DeleteMapping("{id}") | |||
| @ApiOperation("删除实验实例") | |||
| public GenericsAjaxResult<String> deleteById(@PathVariable("id") Long id) { | |||
| return genericsSuccess(this.autoMLInsService.removeById(id)); | |||
| } | |||
| @DeleteMapping("batchDelete") | |||
| @ApiOperation("批量删除实验实例") | |||
| public GenericsAjaxResult<String> batchDelete(@RequestBody List<Long> ids) { | |||
| return genericsSuccess(this.autoMLInsService.batchDelete(ids)); | |||
| } | |||
| @PutMapping("{id}") | |||
| @ApiOperation("终止实验实例") | |||
| public GenericsAjaxResult<Boolean> terminateAutoMlIns(@PathVariable("id") Long id) { | |||
| return genericsSuccess(this.autoMLInsService.terminateAutoMlIns(id)); | |||
| } | |||
| @GetMapping("{id}") | |||
| @ApiOperation("查看实验实例详情") | |||
| public GenericsAjaxResult<AutoMlIns> getDetailById(@PathVariable("id") Long id) { | |||
| return genericsSuccess(this.autoMLInsService.getDetailById(id)); | |||
| } | |||
| } | |||
| @@ -7,7 +7,6 @@ import com.ruoyi.platform.service.CodeConfigService; | |||
| import io.swagger.annotations.Api; | |||
| import org.springframework.data.domain.Page; | |||
| import org.springframework.data.domain.PageRequest; | |||
| import org.springframework.http.ResponseEntity; | |||
| import org.springframework.web.bind.annotation.*; | |||
| import javax.annotation.Resource; | |||
| @@ -0,0 +1,184 @@ | |||
| package com.ruoyi.platform.domain; | |||
| import com.baomidou.mybatisplus.annotation.TableField; | |||
| import com.fasterxml.jackson.databind.PropertyNamingStrategy; | |||
| import com.fasterxml.jackson.databind.annotation.JsonNaming; | |||
| import com.ruoyi.platform.vo.VersionVo; | |||
| import io.swagger.annotations.ApiModel; | |||
| import io.swagger.annotations.ApiModelProperty; | |||
| import lombok.Data; | |||
| import java.util.Date; | |||
| import java.util.Map; | |||
| @Data | |||
| @JsonNaming(PropertyNamingStrategy.SnakeCaseStrategy.class) | |||
| @ApiModel(description = "自动机器学习") | |||
| public class AutoMl { | |||
| private Long id; | |||
| @ApiModelProperty(value = "实验名称") | |||
| private String mlName; | |||
| @ApiModelProperty(value = "实验描述") | |||
| private String mlDescription; | |||
| @ApiModelProperty(value = "任务类型:classification或regression") | |||
| private String taskType; | |||
| @ApiModelProperty(value = "搜索合适模型的时间限制(以秒为单位)。通过增加这个值,auto-sklearn有更高的机会找到更好的模型。默认3600,非必传。") | |||
| private Integer timeLeftForThisTask; | |||
| @ApiModelProperty(value = "单次调用机器学习模型的时间限制(以秒为单位)。如果机器学习算法运行超过时间限制,将终止模型拟合。将这个值设置得足够高,这样典型的机器学习算法就可以适用于训练数据。默认600,非必传。") | |||
| private Integer perRunTimeLimit; | |||
| @ApiModelProperty(value = "集成模型数量,如果设置为0,则没有集成。默认50,非必传。") | |||
| private Integer ensembleSize; | |||
| @ApiModelProperty(value = "设置为None将禁用集成构建,设置为SingleBest仅使用单个最佳模型而不是集成,设置为default,它将对单目标问题使用EnsembleSelection,对多目标问题使用MultiObjectiveDummyEnsemble。默认default,非必传。") | |||
| private String ensembleClass; | |||
| @ApiModelProperty(value = "在构建集成时只考虑ensemble_nbest模型。这是受到了“最大限度地利用集成选择”中引入的库修剪概念的启发。这是独立于ensemble_class参数的,并且这个修剪步骤是在构造集成之前完成的。默认50,非必传。") | |||
| private Integer ensembleNbest; | |||
| @ApiModelProperty(value = "定义在磁盘中保存的模型的最大数量。额外的模型数量将被永久删除。由于这个变量的性质,它设置了一个集成可以使用多少个模型的上限。必须是大于等于1的整数。如果设置为None,则所有模型都保留在磁盘上。默认50,非必传。") | |||
| private Integer maxModelsOnDisc; | |||
| @ApiModelProperty(value = "随机种子,将决定输出文件名。默认1,非必传。") | |||
| private Integer seed; | |||
| @ApiModelProperty(value = "机器学习算法的内存限制(MB)。如果auto-sklearn试图分配超过memory_limit MB,它将停止拟合机器学习算法。默认3072,非必传。") | |||
| private Integer memoryLimit; | |||
| @ApiModelProperty(value = "如果为None,则使用所有可能的分类算法。否则,指定搜索中包含的步骤和组件。有关可用组件,请参见/pipeline/components/<step>/*。与参数exclude不兼容。多选,逗号分隔。包含:adaboost\n" + | |||
| "bernoulli_nb\n" + | |||
| "decision_tree\n" + | |||
| "extra_trees\n" + | |||
| "gaussian_nb\n" + | |||
| "gradient_boosting\n" + | |||
| "k_nearest_neighbors\n" + | |||
| "lda\n" + | |||
| "liblinear_svc\n" + | |||
| "libsvm_svc\n" + | |||
| "mlp\n" + | |||
| "multinomial_nb\n" + | |||
| "passive_aggressive\n" + | |||
| "qda\n" + | |||
| "random_forest\n" + | |||
| "sgd") | |||
| private String includeClassifier; | |||
| @ApiModelProperty(value = "如果为None,则使用所有可能的特征预处理算法。否则,指定搜索中包含的步骤和组件。有关可用组件,请参见/pipeline/components/<step>/*。与参数exclude不兼容。多选,逗号分隔。包含:densifier\n" + | |||
| "extra_trees_preproc_for_classification\n" + | |||
| "extra_trees_preproc_for_regression\n" + | |||
| "fast_ica\n" + | |||
| "feature_agglomeration\n" + | |||
| "kernel_pca\n" + | |||
| "kitchen_sinks\n" + | |||
| "liblinear_svc_preprocessor\n" + | |||
| "no_preprocessing\n" + | |||
| "nystroem_sampler\n" + | |||
| "pca\n" + | |||
| "polynomial\n" + | |||
| "random_trees_embedding\n" + | |||
| "select_percentile_classification\n" + | |||
| "select_percentile_regression\n" + | |||
| "select_rates_classification\n" + | |||
| "select_rates_regression\n" + | |||
| "truncatedSVD") | |||
| private String includeFeaturePreprocessor; | |||
| @ApiModelProperty(value = "如果为None,则使用所有可能的回归算法。否则,指定搜索中包含的步骤和组件。有关可用组件,请参见/pipeline/components/<step>/*。与参数exclude不兼容。多选,逗号分隔。包含:adaboost,\n" + | |||
| "ard_regression,\n" + | |||
| "decision_tree,\n" + | |||
| "extra_trees,\n" + | |||
| "gaussian_process,\n" + | |||
| "gradient_boosting,\n" + | |||
| "k_nearest_neighbors,\n" + | |||
| "liblinear_svr,\n" + | |||
| "libsvm_svr,\n" + | |||
| "mlp,\n" + | |||
| "random_forest,\n" + | |||
| "sgd") | |||
| private String includeRegressor; | |||
| private String excludeClassifier; | |||
| private String excludeRegressor; | |||
| private String excludeFeaturePreprocessor; | |||
| @ApiModelProperty(value = "测试集的比率,0到1之间") | |||
| private Float testSize; | |||
| @ApiModelProperty(value = "如何处理过拟合,如果使用基于“cv”的方法或Splitter对象,可能需要使用resampling_strategy_arguments。holdout或crossValid") | |||
| private String resamplingStrategy; | |||
| @ApiModelProperty(value = "重采样划分训练集和验证集,训练集的比率,0到1之间") | |||
| private Float trainSize; | |||
| @ApiModelProperty(value = "拆分数据前是否进行shuffle") | |||
| private Boolean shuffle; | |||
| @ApiModelProperty(value = "交叉验证的折数,当resamplingStrategy为crossValid时,此项必填,为整数") | |||
| private Integer folds; | |||
| @ApiModelProperty(value = "文件夹存放配置输出和日志文件,默认/tmp/automl") | |||
| private String tmpFolder; | |||
| @ApiModelProperty(value = "数据集csv文件中哪几列是预测目标列,逗号分隔") | |||
| private String targetColumns; | |||
| @ApiModelProperty(value = "自定义指标名称") | |||
| private String metricName; | |||
| @ApiModelProperty(value = "模型优化目标指标及权重,json格式。分类的指标包含:accuracy\n" + | |||
| "balanced_accuracy\n" + | |||
| "roc_auc\n" + | |||
| "average_precision\n" + | |||
| "log_loss\n" + | |||
| "precision_macro\n" + | |||
| "precision_micro\n" + | |||
| "precision_samples\n" + | |||
| "precision_weighted\n" + | |||
| "recall_macro\n" + | |||
| "recall_micro\n" + | |||
| "recall_samples\n" + | |||
| "recall_weighted\n" + | |||
| "f1_macro\n" + | |||
| "f1_micro\n" + | |||
| "f1_samples\n" + | |||
| "f1_weighted\n" + | |||
| "回归的指标包含:mean_absolute_error\n" + | |||
| "mean_squared_error\n" + | |||
| "root_mean_squared_error\n" + | |||
| "mean_squared_log_error\n" + | |||
| "median_absolute_error\n" + | |||
| "r2") | |||
| private String metrics; | |||
| @ApiModelProperty(value = "指标优化方向,是越大越好还是越小越好") | |||
| private Boolean greaterIsBetter; | |||
| @ApiModelProperty(value = "模型计算并打印指标") | |||
| private String scoringFunctions; | |||
| private Integer state; | |||
| private String runState; | |||
| private Double progress; | |||
| private String createBy; | |||
| private Date createTime; | |||
| private String updateBy; | |||
| private Date updateTime; | |||
| private String dataset; | |||
| @ApiModelProperty(value = "状态列表") | |||
| private String statusList; | |||
| } | |||
| @@ -0,0 +1,50 @@ | |||
| package com.ruoyi.platform.domain; | |||
| import com.fasterxml.jackson.databind.PropertyNamingStrategy; | |||
| import com.fasterxml.jackson.databind.annotation.JsonNaming; | |||
| import io.swagger.annotations.ApiModel; | |||
| import io.swagger.annotations.ApiModelProperty; | |||
| import lombok.Data; | |||
| import java.util.Date; | |||
| @Data | |||
| @JsonNaming(PropertyNamingStrategy.SnakeCaseStrategy.class) | |||
| @ApiModel(description = "自动机器学习实验实例") | |||
| public class AutoMlIns { | |||
| private Long id; | |||
| private Long autoMlId; | |||
| private String resultPath; | |||
| private String modelPath; | |||
| private String imgPath; | |||
| private String runHistoryPath; | |||
| private Integer state; | |||
| private String status; | |||
| private String nodeStatus; | |||
| private String nodeResult; | |||
| private String param; | |||
| private String source; | |||
| @ApiModelProperty(value = "Argo实例名称") | |||
| private String argoInsName; | |||
| @ApiModelProperty(value = "Argo命名空间") | |||
| private String argoInsNs; | |||
| private Date createTime; | |||
| private Date updateTime; | |||
| private Date finishTime; | |||
| } | |||
| @@ -0,0 +1,21 @@ | |||
| package com.ruoyi.platform.mapper; | |||
| import com.ruoyi.platform.domain.AutoMl; | |||
| import org.apache.ibatis.annotations.Param; | |||
| import org.springframework.data.domain.Pageable; | |||
| import java.util.List; | |||
| public interface AutoMlDao { | |||
| long count(@Param("mlName") String mlName); | |||
| List<AutoMl> queryByPage(@Param("mlName") String mlName, @Param("pageable") Pageable pageable); | |||
| AutoMl getAutoMlById(@Param("id") Long id); | |||
| AutoMl getAutoMlByName(@Param("mlName") String mlName); | |||
| int save(@Param("autoMl") AutoMl autoMl); | |||
| int edit(@Param("autoMl") AutoMl autoMl); | |||
| } | |||
| @@ -0,0 +1,23 @@ | |||
| package com.ruoyi.platform.mapper; | |||
| import com.ruoyi.platform.domain.AutoMlIns; | |||
| import org.apache.ibatis.annotations.Param; | |||
| import org.springframework.data.domain.Pageable; | |||
| import java.util.List; | |||
| public interface AutoMlInsDao { | |||
| long count(@Param("autoMlIns") AutoMlIns autoMlIns); | |||
| List<AutoMlIns> queryAllByLimit(@Param("autoMlIns") AutoMlIns autoMlIns, @Param("pageable") Pageable pageable); | |||
| List<AutoMlIns> getByAutoMlId(@Param("autoMlId") Long AutoMlId); | |||
| int insert(@Param("autoMlIns") AutoMlIns autoMlIns); | |||
| int update(@Param("autoMlIns") AutoMlIns autoMlIns); | |||
| AutoMlIns queryById(@Param("id") Long id); | |||
| List<AutoMlIns> queryByAutoMlInsIsNotTerminated(); | |||
| } | |||
| @@ -0,0 +1,97 @@ | |||
| package com.ruoyi.platform.scheduling; | |||
| import com.ruoyi.platform.domain.AutoMl; | |||
| import com.ruoyi.platform.domain.AutoMlIns; | |||
| import com.ruoyi.platform.mapper.AutoMlDao; | |||
| import com.ruoyi.platform.mapper.AutoMlInsDao; | |||
| import com.ruoyi.platform.service.AutoMlInsService; | |||
| import org.apache.commons.lang3.StringUtils; | |||
| import org.springframework.scheduling.annotation.Scheduled; | |||
| import org.springframework.stereotype.Component; | |||
| import javax.annotation.Resource; | |||
| import java.util.ArrayList; | |||
| import java.util.Iterator; | |||
| import java.util.List; | |||
| @Component() | |||
| public class AutoMlInsStatusTask { | |||
| @Resource | |||
| private AutoMlInsService autoMlInsService; | |||
| @Resource | |||
| private AutoMlInsDao autoMlInsDao; | |||
| @Resource | |||
| private AutoMlDao autoMlDao; | |||
| private List<Long> autoMlIds = new ArrayList<>(); | |||
| @Scheduled(cron = "0/30 * * * * ?") // 每30S执行一次 | |||
| public void executeAutoMlInsStatus() throws Exception { | |||
| // 首先查到所有非终止态的实验实例 | |||
| List<AutoMlIns> autoMlInsList = autoMlInsService.queryByAutoMlInsIsNotTerminated(); | |||
| // 去argo查询状态 | |||
| List<AutoMlIns> updateList = new ArrayList<>(); | |||
| if (autoMlInsList != null && autoMlInsList.size() > 0) { | |||
| for (AutoMlIns autoMlIns : autoMlInsList) { | |||
| //当原本状态为null或非终止态时才调用argo接口 | |||
| try { | |||
| autoMlIns = autoMlInsService.queryStatusFromArgo(autoMlIns); | |||
| } catch (Exception e) { | |||
| autoMlIns.setStatus("Failed"); | |||
| } | |||
| // 线程安全的添加操作 | |||
| synchronized (autoMlIds) { | |||
| autoMlIds.add(autoMlIns.getAutoMlId()); | |||
| } | |||
| updateList.add(autoMlIns); | |||
| } | |||
| if (updateList.size() > 0) { | |||
| for (AutoMlIns autoMlIns : updateList) { | |||
| autoMlInsDao.update(autoMlIns); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @Scheduled(cron = "0/30 * * * * ?") // / 每30S执行一次 | |||
| public void executeAutoMlStatus() throws Exception { | |||
| if (autoMlIds.size() == 0) { | |||
| return; | |||
| } | |||
| // 存储需要更新的实验对象列表 | |||
| List<AutoMl> updateAutoMls = new ArrayList<>(); | |||
| for (Long autoMlId : autoMlIds) { | |||
| // 获取当前实验的所有实例列表 | |||
| List<AutoMlIns> insList = autoMlInsDao.getByAutoMlId(autoMlId); | |||
| List<String> statusList = new ArrayList<>(); | |||
| // 更新实验状态列表 | |||
| for (int i = 0; i < insList.size(); i++) { | |||
| statusList.add(insList.get(i).getStatus()); | |||
| } | |||
| String subStatus = statusList.toString().substring(1, statusList.toString().length() - 1); | |||
| AutoMl autoMl = autoMlDao.getAutoMlById(autoMlId); | |||
| if (!StringUtils.equals(autoMl.getStatusList(), subStatus)) { | |||
| autoMl.setStatusList(subStatus); | |||
| updateAutoMls.add(autoMl); | |||
| autoMlDao.edit(autoMl); | |||
| } | |||
| } | |||
| if (!updateAutoMls.isEmpty()) { | |||
| // 使用Iterator进行安全的删除操作 | |||
| Iterator<Long> iterator = autoMlIds.iterator(); | |||
| while (iterator.hasNext()) { | |||
| Long autoMlId = iterator.next(); | |||
| for (AutoMl autoMl : updateAutoMls) { | |||
| if (autoMl.getId().equals(autoMlId)) { | |||
| iterator.remove(); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,29 @@ | |||
| package com.ruoyi.platform.service; | |||
| import com.ruoyi.platform.domain.AutoMlIns; | |||
| import org.springframework.data.domain.Page; | |||
| import org.springframework.data.domain.PageRequest; | |||
| import java.io.IOException; | |||
| import java.util.List; | |||
| public interface AutoMlInsService { | |||
| Page<AutoMlIns> queryByPage(AutoMlIns autoMlIns, PageRequest pageRequest) throws IOException; | |||
| AutoMlIns insert(AutoMlIns autoMlIns); | |||
| String removeById(Long id); | |||
| String batchDelete(List<Long> ids); | |||
| List<AutoMlIns> queryByAutoMlInsIsNotTerminated(); | |||
| AutoMlIns queryStatusFromArgo(AutoMlIns autoMlIns); | |||
| boolean terminateAutoMlIns(Long id); | |||
| AutoMlIns getDetailById(Long id); | |||
| void updateAutoMlStatus(Long autoMlId); | |||
| } | |||
| @@ -0,0 +1,26 @@ | |||
| package com.ruoyi.platform.service; | |||
| import com.ruoyi.platform.domain.AutoMl; | |||
| import com.ruoyi.platform.vo.AutoMlVo; | |||
| import org.springframework.data.domain.Page; | |||
| import org.springframework.data.domain.PageRequest; | |||
| import org.springframework.web.multipart.MultipartFile; | |||
| import java.io.IOException; | |||
| import java.util.Map; | |||
| public interface AutoMlService { | |||
| Page<AutoMl> queryByPage(String mlName, PageRequest pageRequest); | |||
| AutoMl save(AutoMlVo autoMlVo) throws Exception; | |||
| String edit(AutoMlVo autoMlVo) throws Exception; | |||
| String delete(Long id); | |||
| AutoMlVo getAutoMlDetail(Long id) throws IOException; | |||
| Map<String, String> upload(MultipartFile file, String uuid) throws Exception; | |||
| String runAutoMlIns(Long id) throws Exception; | |||
| } | |||
| @@ -0,0 +1,250 @@ | |||
| package com.ruoyi.platform.service.impl; | |||
| import com.ruoyi.platform.constant.Constant; | |||
| import com.ruoyi.platform.domain.AutoMl; | |||
| import com.ruoyi.platform.domain.AutoMlIns; | |||
| import com.ruoyi.platform.mapper.AutoMlDao; | |||
| import com.ruoyi.platform.mapper.AutoMlInsDao; | |||
| import com.ruoyi.platform.service.AutoMlInsService; | |||
| import com.ruoyi.platform.utils.DateUtils; | |||
| import com.ruoyi.platform.utils.HttpUtils; | |||
| import com.ruoyi.platform.utils.JsonUtils; | |||
| import org.apache.commons.lang3.StringUtils; | |||
| import org.springframework.beans.factory.annotation.Value; | |||
| import org.springframework.data.domain.Page; | |||
| import org.springframework.data.domain.PageImpl; | |||
| import org.springframework.data.domain.PageRequest; | |||
| import org.springframework.stereotype.Service; | |||
| import javax.annotation.Resource; | |||
| import java.io.IOException; | |||
| import java.util.*; | |||
| @Service | |||
| public class AutoMlInsServiceImpl implements AutoMlInsService { | |||
| @Value("${argo.url}") | |||
| private String argoUrl; | |||
| @Value("${argo.workflowStatus}") | |||
| private String argoWorkflowStatus; | |||
| @Value("${argo.workflowTermination}") | |||
| private String argoWorkflowTermination; | |||
| @Resource | |||
| private AutoMlInsDao autoMlInsDao; | |||
| @Resource | |||
| private AutoMlDao autoMlDao; | |||
| @Override | |||
| public Page<AutoMlIns> queryByPage(AutoMlIns autoMlIns, PageRequest pageRequest) throws IOException { | |||
| long total = this.autoMlInsDao.count(autoMlIns); | |||
| List<AutoMlIns> autoMlInsList = this.autoMlInsDao.queryAllByLimit(autoMlIns, pageRequest); | |||
| return new PageImpl<>(autoMlInsList, pageRequest, total); | |||
| } | |||
| @Override | |||
| public AutoMlIns insert(AutoMlIns autoMlIns) { | |||
| this.autoMlInsDao.insert(autoMlIns); | |||
| return autoMlIns; | |||
| } | |||
| @Override | |||
| public String removeById(Long id) { | |||
| AutoMlIns autoMlIns = autoMlInsDao.queryById(id); | |||
| if (autoMlIns == null) { | |||
| return "实验实例不存在"; | |||
| } | |||
| if (StringUtils.isEmpty(autoMlIns.getStatus())) { | |||
| autoMlIns = queryStatusFromArgo(autoMlIns); | |||
| } | |||
| if (StringUtils.equals(autoMlIns.getStatus(), "Running")) { | |||
| return "实验实例正在运行,不可删除"; | |||
| } | |||
| autoMlIns.setState(Constant.State_invalid); | |||
| int update = autoMlInsDao.update(autoMlIns); | |||
| if (update > 0) { | |||
| updateAutoMlStatus(autoMlIns.getAutoMlId()); | |||
| return "删除成功"; | |||
| } else { | |||
| return "删除失败"; | |||
| } | |||
| } | |||
| @Override | |||
| public String batchDelete(List<Long> ids) { | |||
| for (Long id : ids) { | |||
| String result = removeById(id); | |||
| if (!"删除成功".equals(result)) { | |||
| return result; | |||
| } | |||
| } | |||
| return "删除成功"; | |||
| } | |||
| @Override | |||
| public List<AutoMlIns> queryByAutoMlInsIsNotTerminated() { | |||
| return autoMlInsDao.queryByAutoMlInsIsNotTerminated(); | |||
| } | |||
| @Override | |||
| public AutoMlIns queryStatusFromArgo(AutoMlIns ins) { | |||
| String namespace = ins.getArgoInsNs(); | |||
| String name = ins.getArgoInsName(); | |||
| // 创建请求数据map | |||
| Map<String, Object> requestData = new HashMap<>(); | |||
| requestData.put("namespace", namespace); | |||
| requestData.put("name", name); | |||
| // 创建发送数据map,将请求数据作为"data"键的值 | |||
| Map<String, Object> res = new HashMap<>(); | |||
| res.put("data", requestData); | |||
| try { | |||
| // 发送POST请求到Argo工作流状态查询接口,并将请求数据转换为JSON | |||
| String req = HttpUtils.sendPost(argoUrl + argoWorkflowStatus, null, JsonUtils.mapToJson(res)); | |||
| // 检查响应是否为空或无内容 | |||
| if (req == null || StringUtils.isEmpty(req)) { | |||
| throw new RuntimeException("工作流状态响应为空。"); | |||
| } | |||
| // 将响应的JSON字符串转换为Map对象 | |||
| Map<String, Object> runResMap = JsonUtils.jsonToMap(req); | |||
| // 从响应Map中获取"data"部分 | |||
| Map<String, Object> data = (Map<String, Object>) runResMap.get("data"); | |||
| if (data == null || data.isEmpty()) { | |||
| throw new RuntimeException("工作流数据为空."); | |||
| } | |||
| // 从"data"中获取"status"部分,并返回"phase"的值 | |||
| Map<String, Object> status = (Map<String, Object>) data.get("status"); | |||
| if (status == null || status.isEmpty()) { | |||
| throw new RuntimeException("工作流状态为空。"); | |||
| } | |||
| //解析流水线结束时间 | |||
| String finishedAtString = (String) status.get("finishedAt"); | |||
| if (finishedAtString != null && !finishedAtString.isEmpty()) { | |||
| Date finishTime = DateUtils.convertUTCtoShanghaiDate(finishedAtString); | |||
| ins.setFinishTime(finishTime); | |||
| } | |||
| // 解析nodes字段,提取节点状态并转换为JSON字符串 | |||
| Map<String, Object> nodes = (Map<String, Object>) status.get("nodes"); | |||
| Map<String, Object> modifiedNodes = new LinkedHashMap<>(); | |||
| if (nodes != null) { | |||
| for (Map.Entry<String, Object> nodeEntry : nodes.entrySet()) { | |||
| Map<String, Object> nodeDetails = (Map<String, Object>) nodeEntry.getValue(); | |||
| String templateName = (String) nodeDetails.get("displayName"); | |||
| modifiedNodes.put(templateName, nodeDetails); | |||
| } | |||
| } | |||
| String nodeStatusJson = JsonUtils.mapToJson(modifiedNodes); | |||
| ins.setNodeStatus(nodeStatusJson); | |||
| //终止态为终止不改 | |||
| if (!StringUtils.equals(ins.getStatus(), Constant.Terminated)) { | |||
| ins.setStatus(StringUtils.isNotEmpty((String) status.get("phase")) ? (String) status.get("phase") : Constant.Pending); | |||
| } | |||
| if (StringUtils.equals(ins.getStatus(), "Error")) { | |||
| ins.setStatus(Constant.Failed); | |||
| } | |||
| return ins; | |||
| } catch (Exception e) { | |||
| throw new RuntimeException("查询状态失败: " + e.getMessage(), e); | |||
| } | |||
| } | |||
| @Override | |||
| public boolean terminateAutoMlIns(Long id) { | |||
| AutoMlIns autoMlIns = autoMlInsDao.queryById(id); | |||
| if (autoMlIns == null) { | |||
| throw new IllegalStateException("实验实例未查询到,id: " + id); | |||
| } | |||
| String currentStatus = autoMlIns.getStatus(); | |||
| String name = autoMlIns.getArgoInsName(); | |||
| String namespace = autoMlIns.getArgoInsNs(); | |||
| // 获取当前状态,如果为空,则从Argo查询 | |||
| if (StringUtils.isEmpty(currentStatus)) { | |||
| currentStatus = queryStatusFromArgo(autoMlIns).getStatus(); | |||
| } | |||
| // 只有状态是"Running"时才能终止实例 | |||
| if (!currentStatus.equalsIgnoreCase(Constant.Running)) { | |||
| return false; // 如果不是"Running"状态,则不执行终止操作 | |||
| } | |||
| // 创建请求数据map | |||
| Map<String, Object> requestData = new HashMap<>(); | |||
| requestData.put("namespace", namespace); | |||
| requestData.put("name", name); | |||
| // 创建发送数据map,将请求数据作为"data"键的值 | |||
| Map<String, Object> res = new HashMap<>(); | |||
| res.put("data", requestData); | |||
| try { | |||
| // 发送POST请求到Argo工作流状态查询接口,并将请求数据转换为JSON | |||
| String req = HttpUtils.sendPost(argoUrl + argoWorkflowTermination, null, JsonUtils.mapToJson(res)); | |||
| // 检查响应是否为空或无内容 | |||
| if (StringUtils.isEmpty(req)) { | |||
| throw new RuntimeException("终止响应内容为空。"); | |||
| } | |||
| // 将响应的JSON字符串转换为Map对象 | |||
| Map<String, Object> runResMap = JsonUtils.jsonToMap(req); | |||
| // 从响应Map中直接获取"errCode"的值 | |||
| Integer errCode = (Integer) runResMap.get("errCode"); | |||
| if (errCode != null && errCode == 0) { | |||
| //更新autoMlIns,确保状态更新被保存到数据库 | |||
| AutoMlIns ins = queryStatusFromArgo(autoMlIns); | |||
| String nodeStatus = ins.getNodeStatus(); | |||
| Map<String, Object> nodeMap = JsonUtils.jsonToMap(nodeStatus); | |||
| // 遍历 map | |||
| for (Map.Entry<String, Object> entry : nodeMap.entrySet()) { | |||
| // 获取每个 Map 中的值并强制转换为 Map | |||
| Map<String, Object> innerMap = (Map<String, Object>) entry.getValue(); | |||
| // 检查 phase 的值 | |||
| if (innerMap.containsKey("phase")) { | |||
| String phaseValue = (String) innerMap.get("phase"); | |||
| // 如果值不等于 Succeeded,则赋值为 Failed | |||
| if (!StringUtils.equals("Succeeded", phaseValue)) { | |||
| innerMap.put("phase", "Failed"); | |||
| } | |||
| } | |||
| } | |||
| ins.setNodeStatus(JsonUtils.mapToJson(nodeMap)); | |||
| ins.setStatus(Constant.Terminated); | |||
| ins.setUpdateTime(new Date()); | |||
| this.autoMlInsDao.update(ins); | |||
| updateAutoMlStatus(autoMlIns.getAutoMlId()); | |||
| return true; | |||
| } else { | |||
| return false; | |||
| } | |||
| } catch (Exception e) { | |||
| throw new RuntimeException("终止实例错误: " + e.getMessage(), e); | |||
| } | |||
| } | |||
| @Override | |||
| public AutoMlIns getDetailById(Long id) { | |||
| return this.autoMlInsDao.queryById(id); | |||
| } | |||
| public void updateAutoMlStatus(Long autoMlId) { | |||
| List<AutoMlIns> insList = autoMlInsDao.getByAutoMlId(autoMlId); | |||
| List<String> statusList = new ArrayList<>(); | |||
| // 更新实验状态列表 | |||
| for (int i = 0; i < insList.size(); i++) { | |||
| statusList.add(insList.get(i).getStatus()); | |||
| } | |||
| String subStatus = statusList.toString().substring(1, statusList.toString().length() - 1); | |||
| AutoMl autoMl = autoMlDao.getAutoMlById(autoMlId); | |||
| if (!StringUtils.equals(autoMl.getStatusList(), subStatus)) { | |||
| autoMl.setStatusList(subStatus); | |||
| autoMlDao.edit(autoMl); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,216 @@ | |||
| package com.ruoyi.platform.service.impl; | |||
| import com.ruoyi.common.security.utils.SecurityUtils; | |||
| import com.ruoyi.platform.constant.Constant; | |||
| import com.ruoyi.platform.domain.AutoMl; | |||
| import com.ruoyi.platform.domain.AutoMlIns; | |||
| import com.ruoyi.platform.mapper.AutoMlDao; | |||
| import com.ruoyi.platform.mapper.AutoMlInsDao; | |||
| import com.ruoyi.platform.service.AutoMlInsService; | |||
| import com.ruoyi.platform.service.AutoMlService; | |||
| import com.ruoyi.platform.utils.FileUtil; | |||
| import com.ruoyi.platform.utils.HttpUtils; | |||
| import com.ruoyi.platform.utils.JacksonUtil; | |||
| import com.ruoyi.platform.utils.JsonUtils; | |||
| import com.ruoyi.platform.vo.AutoMlParamVo; | |||
| import com.ruoyi.platform.vo.AutoMlVo; | |||
| import org.apache.commons.collections4.MapUtils; | |||
| import org.apache.commons.io.FileUtils; | |||
| import org.apache.commons.lang3.StringUtils; | |||
| import org.springframework.beans.BeanUtils; | |||
| import org.springframework.beans.factory.annotation.Value; | |||
| import org.springframework.data.domain.Page; | |||
| import org.springframework.data.domain.PageImpl; | |||
| import org.springframework.data.domain.PageRequest; | |||
| import org.springframework.stereotype.Service; | |||
| import org.springframework.web.multipart.MultipartFile; | |||
| import javax.annotation.Resource; | |||
| import java.io.File; | |||
| import java.io.IOException; | |||
| import java.util.ArrayList; | |||
| import java.util.HashMap; | |||
| import java.util.List; | |||
| import java.util.Map; | |||
| @Service("autoMLService") | |||
| public class AutoMlServiceImpl implements AutoMlService { | |||
| @Value("${git.localPath}") | |||
| String localPath; | |||
| @Value("${argo.url}") | |||
| private String argoUrl; | |||
| @Value("${argo.convertAutoML}") | |||
| String convertAutoML; | |||
| @Value("${argo.workflowRun}") | |||
| private String argoWorkflowRun; | |||
| @Value("${minio.endpoint}") | |||
| private String minioEndpoint; | |||
| @Resource | |||
| private AutoMlDao autoMlDao; | |||
| @Resource | |||
| private AutoMlInsDao autoMlInsDao; | |||
| @Resource | |||
| private AutoMlInsService autoMlInsService; | |||
| @Override | |||
| public Page<AutoMl> queryByPage(String mlName, PageRequest pageRequest) { | |||
| long total = autoMlDao.count(mlName); | |||
| List<AutoMl> autoMls = autoMlDao.queryByPage(mlName, pageRequest); | |||
| return new PageImpl<>(autoMls, pageRequest, total); | |||
| } | |||
| @Override | |||
| public AutoMl save(AutoMlVo autoMlVo) throws Exception { | |||
| AutoMl autoMlByName = autoMlDao.getAutoMlByName(autoMlVo.getMlName()); | |||
| if (autoMlByName != null) { | |||
| throw new RuntimeException("实验名称已存在"); | |||
| } | |||
| AutoMl autoMl = new AutoMl(); | |||
| BeanUtils.copyProperties(autoMlVo, autoMl); | |||
| String username = SecurityUtils.getLoginUser().getUsername(); | |||
| autoMl.setCreateBy(username); | |||
| autoMl.setUpdateBy(username); | |||
| String datasetJson = JacksonUtil.toJSONString(autoMlVo.getDataset()); | |||
| autoMl.setDataset(datasetJson); | |||
| autoMlDao.save(autoMl); | |||
| return autoMl; | |||
| } | |||
| @Override | |||
| public String edit(AutoMlVo autoMlVo) throws Exception { | |||
| AutoMl oldAutoMl = autoMlDao.getAutoMlByName(autoMlVo.getMlName()); | |||
| if (oldAutoMl != null && !oldAutoMl.getId().equals(autoMlVo.getId())) { | |||
| throw new RuntimeException("实验名称已存在"); | |||
| } | |||
| AutoMl autoMl = new AutoMl(); | |||
| BeanUtils.copyProperties(autoMlVo, autoMl); | |||
| String username = SecurityUtils.getLoginUser().getUsername(); | |||
| autoMl.setUpdateBy(username); | |||
| String datasetJson = JacksonUtil.toJSONString(autoMlVo.getDataset()); | |||
| autoMl.setDataset(datasetJson); | |||
| autoMlDao.edit(autoMl); | |||
| return "修改成功"; | |||
| } | |||
| @Override | |||
| public String delete(Long id) { | |||
| AutoMl autoMl = autoMlDao.getAutoMlById(id); | |||
| if (autoMl == null) { | |||
| throw new RuntimeException("服务不存在"); | |||
| } | |||
| String username = SecurityUtils.getLoginUser().getUsername(); | |||
| String createBy = autoMl.getCreateBy(); | |||
| if (!(StringUtils.equals(username, "admin") || StringUtils.equals(username, createBy))) { | |||
| throw new RuntimeException("无权限删除该服务"); | |||
| } | |||
| autoMl.setState(Constant.State_invalid); | |||
| return autoMlDao.edit(autoMl) > 0 ? "删除成功" : "删除失败"; | |||
| } | |||
| @Override | |||
| public AutoMlVo getAutoMlDetail(Long id) throws IOException { | |||
| AutoMl autoMl = autoMlDao.getAutoMlById(id); | |||
| AutoMlVo autoMlVo = new AutoMlVo(); | |||
| BeanUtils.copyProperties(autoMl, autoMlVo); | |||
| if (StringUtils.isNotEmpty(autoMl.getDataset())) { | |||
| autoMlVo.setDataset(JsonUtils.jsonToMap(autoMl.getDataset())); | |||
| } | |||
| return autoMlVo; | |||
| } | |||
| @Override | |||
| public Map<String, String> upload(MultipartFile file, String uuid) throws Exception { | |||
| Map<String, String> result = new HashMap<>(); | |||
| String username = SecurityUtils.getLoginUser().getUsername(); | |||
| String fileName = file.getOriginalFilename(); | |||
| String path = localPath + "temp/" + username + "/automl_data/" + uuid; | |||
| long sizeInBytes = file.getSize(); | |||
| String formattedSize = FileUtil.formatFileSize(sizeInBytes); | |||
| File targetFile = new File(path, file.getOriginalFilename()); | |||
| // 确保目录存在 | |||
| targetFile.getParentFile().mkdirs(); | |||
| // 保存文件到目标路径 | |||
| FileUtils.copyInputStreamToFile(file.getInputStream(), targetFile); | |||
| // 返回上传文件的路径 | |||
| result.put("fileName", fileName); | |||
| result.put("url", path); // objectName根据实际情况定义 | |||
| result.put("fileSize", formattedSize); | |||
| return result; | |||
| } | |||
| @Override | |||
| public String runAutoMlIns(Long id) throws Exception { | |||
| AutoMl autoMl = autoMlDao.getAutoMlById(id); | |||
| if (autoMl == null) { | |||
| throw new Exception("自动机器学习配置不存在"); | |||
| } | |||
| AutoMlParamVo autoMlParam = new AutoMlParamVo(); | |||
| BeanUtils.copyProperties(autoMl, autoMlParam); | |||
| autoMlParam.setDataset(JsonUtils.jsonToMap(autoMl.getDataset())); | |||
| String param = JsonUtils.objectToJson(autoMlParam); | |||
| // 调argo转换接口 | |||
| try { | |||
| String convertRes = HttpUtils.sendPost(argoUrl + convertAutoML, param); | |||
| if (convertRes == null || StringUtils.isEmpty(convertRes)) { | |||
| throw new RuntimeException("转换流水线失败"); | |||
| } | |||
| Map<String, Object> converMap = JsonUtils.jsonToMap(convertRes); | |||
| // 组装运行接口json | |||
| Map<String, Object> output = (Map<String, Object>) converMap.get("output"); | |||
| Map<String, Object> runReqMap = new HashMap<>(); | |||
| runReqMap.put("data", converMap.get("data")); | |||
| // 调argo运行接口 | |||
| String runRes = HttpUtils.sendPost(argoUrl + argoWorkflowRun, JsonUtils.mapToJson(runReqMap)); | |||
| if (runRes == null || StringUtils.isEmpty(runRes)) { | |||
| throw new RuntimeException("Failed to run workflow."); | |||
| } | |||
| Map<String, Object> runResMap = JsonUtils.jsonToMap(runRes); | |||
| Map<String, Object> data = (Map<String, Object>) runResMap.get("data"); | |||
| //判断data为空 | |||
| if (data == null || MapUtils.isEmpty(data)) { | |||
| throw new RuntimeException("Failed to run workflow."); | |||
| } | |||
| Map<String, Object> metadata = (Map<String, Object>) data.get("metadata"); | |||
| // 插入记录到实验实例表 | |||
| AutoMlIns autoMlIns = new AutoMlIns(); | |||
| autoMlIns.setAutoMlId(autoMl.getId()); | |||
| autoMlIns.setArgoInsNs((String) metadata.get("namespace")); | |||
| autoMlIns.setArgoInsName((String) metadata.get("name")); | |||
| autoMlIns.setParam(param); | |||
| autoMlIns.setStatus(Constant.Pending); | |||
| //替换argoInsName | |||
| String outputString = JsonUtils.mapToJson(output); | |||
| autoMlIns.setNodeResult(outputString.replace("{{workflow.name}}", (String) metadata.get("name"))); | |||
| Map<String, Object> param_output = (Map<String, Object>) output.get("param_output"); | |||
| List output1 = (ArrayList) param_output.values().toArray()[0]; | |||
| Map<String, String> output2 = (Map<String, String>) output1.get(0); | |||
| String outputPath = minioEndpoint + "/" + output2.get("path").replace("{{workflow.name}}", (String) metadata.get("name")) + "/"; | |||
| autoMlIns.setModelPath(outputPath + "save_model.joblib"); | |||
| if (Constant.AutoMl_Classification.equals(autoMl.getTaskType())) { | |||
| autoMlIns.setImgPath(outputPath + "Auto-sklearn_accuracy_over_time.png" + "," + outputPath + "Train_Confusion_Matrix.png" + "," + outputPath + "Test_Confusion_Matrix.png"); | |||
| } else { | |||
| autoMlIns.setImgPath(outputPath + "Auto-sklearn_accuracy_over_time.png" + "," + outputPath + "regression.png"); | |||
| } | |||
| autoMlIns.setResultPath(outputPath + "result.txt"); | |||
| String seed = autoMl.getSeed() != null ? String.valueOf(autoMl.getSeed()) : "1"; | |||
| autoMlIns.setRunHistoryPath(outputPath + "smac3-output/run_" + seed + "/runhistory.json"); | |||
| autoMlInsDao.insert(autoMlIns); | |||
| autoMlInsService.updateAutoMlStatus(id); | |||
| } catch (Exception e) { | |||
| throw new RuntimeException(e); | |||
| } | |||
| return "执行成功"; | |||
| } | |||
| } | |||
| @@ -104,8 +104,6 @@ public class JupyterServiceImpl implements JupyterService { | |||
| // String pvcName = loginUser.getUsername().toLowerCase() + "-editor-pvc"; | |||
| // V1PersistentVolumeClaim pvc = k8sClientUtil.createPvc(namespace, pvcName, storage, storageClassName); | |||
| //TODO 设置镜像可配置,这里先用默认镜像启动pod | |||
| // 调用修改后的 createPod 方法,传入额外的参数 | |||
| // Integer podPort = k8sClientUtil.createConfiguredPod(podName, namespace, port, mountPath, pvc, devEnvironment, minioPvcName, datasetPath, modelPath); | |||
| Integer podPort = k8sClientUtil.createConfiguredPod(podName, namespace, port, mountPath, null, devEnvironment, minioPvcName, datasetPath, modelPath); | |||
| @@ -0,0 +1,155 @@ | |||
| package com.ruoyi.platform.vo; | |||
| import com.fasterxml.jackson.annotation.JsonInclude; | |||
| import com.fasterxml.jackson.databind.PropertyNamingStrategy; | |||
| import com.fasterxml.jackson.databind.annotation.JsonNaming; | |||
| import io.swagger.annotations.ApiModel; | |||
| import io.swagger.annotations.ApiModelProperty; | |||
| import lombok.Data; | |||
| import java.util.Map; | |||
| @Data | |||
| @JsonNaming(PropertyNamingStrategy.SnakeCaseStrategy.class) | |||
| @JsonInclude(JsonInclude.Include.NON_NULL) | |||
| @ApiModel(description = "自动机器学习参数") | |||
| public class AutoMlParamVo { | |||
| @ApiModelProperty(value = "任务类型:classification或regression") | |||
| private String taskType; | |||
| @ApiModelProperty(value = "搜索合适模型的时间限制(以秒为单位)。通过增加这个值,auto-sklearn有更高的机会找到更好的模型。默认3600,非必传。") | |||
| private Integer timeLeftForThisTask; | |||
| @ApiModelProperty(value = "单次调用机器学习模型的时间限制(以秒为单位)。如果机器学习算法运行超过时间限制,将终止模型拟合。将这个值设置得足够高,这样典型的机器学习算法就可以适用于训练数据。默认600,非必传。") | |||
| private Integer perRunTimeLimit; | |||
| @ApiModelProperty(value = "集成模型数量,如果设置为0,则没有集成。默认50,非必传。") | |||
| private Integer ensembleSize; | |||
| @ApiModelProperty(value = "设置为None将禁用集成构建,设置为SingleBest仅使用单个最佳模型而不是集成,设置为default,它将对单目标问题使用EnsembleSelection,对多目标问题使用MultiObjectiveDummyEnsemble。默认default,非必传。") | |||
| private String ensembleClass; | |||
| @ApiModelProperty(value = "在构建集成时只考虑ensemble_nbest模型。这是受到了“最大限度地利用集成选择”中引入的库修剪概念的启发。这是独立于ensemble_class参数的,并且这个修剪步骤是在构造集成之前完成的。默认50,非必传。") | |||
| private Integer ensembleNbest; | |||
| @ApiModelProperty(value = "定义在磁盘中保存的模型的最大数量。额外的模型数量将被永久删除。由于这个变量的性质,它设置了一个集成可以使用多少个模型的上限。必须是大于等于1的整数。如果设置为None,则所有模型都保留在磁盘上。默认50,非必传。") | |||
| private Integer maxModelsOnDisc; | |||
| @ApiModelProperty(value = "随机种子,将决定输出文件名。默认1,非必传。") | |||
| private Integer seed; | |||
| @ApiModelProperty(value = "机器学习算法的内存限制(MB)。如果auto-sklearn试图分配超过memory_limit MB,它将停止拟合机器学习算法。默认3072,非必传。") | |||
| private Integer memoryLimit; | |||
| @ApiModelProperty(value = "如果为None,则使用所有可能的分类算法。否则,指定搜索中包含的步骤和组件。有关可用组件,请参见/pipeline/components/<step>/*。与参数exclude不兼容。多选,逗号分隔。包含:adaboost\n" + | |||
| "bernoulli_nb\n" + | |||
| "decision_tree\n" + | |||
| "extra_trees\n" + | |||
| "gaussian_nb\n" + | |||
| "gradient_boosting\n" + | |||
| "k_nearest_neighbors\n" + | |||
| "lda\n" + | |||
| "liblinear_svc\n" + | |||
| "libsvm_svc\n" + | |||
| "mlp\n" + | |||
| "multinomial_nb\n" + | |||
| "passive_aggressive\n" + | |||
| "qda\n" + | |||
| "random_forest\n" + | |||
| "sgd") | |||
| private String includeClassifier; | |||
| @ApiModelProperty(value = "如果为None,则使用所有可能的特征预处理算法。否则,指定搜索中包含的步骤和组件。有关可用组件,请参见/pipeline/components/<step>/*。与参数exclude不兼容。多选,逗号分隔。包含:densifier\n" + | |||
| "extra_trees_preproc_for_classification\n" + | |||
| "extra_trees_preproc_for_regression\n" + | |||
| "fast_ica\n" + | |||
| "feature_agglomeration\n" + | |||
| "kernel_pca\n" + | |||
| "kitchen_sinks\n" + | |||
| "liblinear_svc_preprocessor\n" + | |||
| "no_preprocessing\n" + | |||
| "nystroem_sampler\n" + | |||
| "pca\n" + | |||
| "polynomial\n" + | |||
| "random_trees_embedding\n" + | |||
| "select_percentile_classification\n" + | |||
| "select_percentile_regression\n" + | |||
| "select_rates_classification\n" + | |||
| "select_rates_regression\n" + | |||
| "truncatedSVD") | |||
| private String includeFeaturePreprocessor; | |||
| @ApiModelProperty(value = "如果为None,则使用所有可能的回归算法。否则,指定搜索中包含的步骤和组件。有关可用组件,请参见/pipeline/components/<step>/*。与参数exclude不兼容。多选,逗号分隔。包含:adaboost,\n" + | |||
| "ard_regression,\n" + | |||
| "decision_tree,\n" + | |||
| "extra_trees,\n" + | |||
| "gaussian_process,\n" + | |||
| "gradient_boosting,\n" + | |||
| "k_nearest_neighbors,\n" + | |||
| "liblinear_svr,\n" + | |||
| "libsvm_svr,\n" + | |||
| "mlp,\n" + | |||
| "random_forest,\n" + | |||
| "sgd") | |||
| private String includeRegressor; | |||
| private String excludeClassifier; | |||
| private String excludeRegressor; | |||
| private String excludeFeaturePreprocessor; | |||
| @ApiModelProperty(value = "测试集的比率,0到1之间") | |||
| private Float testSize; | |||
| @ApiModelProperty(value = "如何处理过拟合,如果使用基于“cv”的方法或Splitter对象,可能需要使用resampling_strategy_arguments。holdout或crossValid") | |||
| private String resamplingStrategy; | |||
| @ApiModelProperty(value = "重采样划分训练集和验证集,训练集的比率,0到1之间") | |||
| private Float trainSize; | |||
| @ApiModelProperty(value = "拆分数据前是否进行shuffle") | |||
| private Boolean shuffle; | |||
| @ApiModelProperty(value = "交叉验证的折数,当resamplingStrategy为crossValid时,此项必填,为整数") | |||
| private Integer folds; | |||
| @ApiModelProperty(value = "数据集csv文件中哪几列是预测目标列,逗号分隔") | |||
| private String targetColumns; | |||
| @ApiModelProperty(value = "自定义指标名称") | |||
| private String metricName; | |||
| @ApiModelProperty(value = "模型优化目标指标及权重,json格式。分类的指标包含:accuracy\n" + | |||
| "balanced_accuracy\n" + | |||
| "roc_auc\n" + | |||
| "average_precision\n" + | |||
| "log_loss\n" + | |||
| "precision_macro\n" + | |||
| "precision_micro\n" + | |||
| "precision_samples\n" + | |||
| "precision_weighted\n" + | |||
| "recall_macro\n" + | |||
| "recall_micro\n" + | |||
| "recall_samples\n" + | |||
| "recall_weighted\n" + | |||
| "f1_macro\n" + | |||
| "f1_micro\n" + | |||
| "f1_samples\n" + | |||
| "f1_weighted\n" + | |||
| "回归的指标包含:mean_absolute_error\n" + | |||
| "mean_squared_error\n" + | |||
| "root_mean_squared_error\n" + | |||
| "mean_squared_log_error\n" + | |||
| "median_absolute_error\n" + | |||
| "r2") | |||
| private String metrics; | |||
| @ApiModelProperty(value = "指标优化方向,是越大越好还是越小越好") | |||
| private Boolean greaterIsBetter; | |||
| @ApiModelProperty(value = "模型计算并打印指标") | |||
| private String scoringFunctions; | |||
| private Map<String,Object> dataset; | |||
| } | |||
| @@ -0,0 +1,183 @@ | |||
| package com.ruoyi.platform.vo; | |||
| import com.baomidou.mybatisplus.annotation.TableField; | |||
| import com.fasterxml.jackson.databind.PropertyNamingStrategy; | |||
| import com.fasterxml.jackson.databind.annotation.JsonNaming; | |||
| import io.swagger.annotations.ApiModel; | |||
| import io.swagger.annotations.ApiModelProperty; | |||
| import lombok.Data; | |||
| import java.util.Date; | |||
| import java.util.Map; | |||
| @Data | |||
| @JsonNaming(PropertyNamingStrategy.SnakeCaseStrategy.class) | |||
| @ApiModel(description = "自动机器学习") | |||
| public class AutoMlVo { | |||
| private Long id; | |||
| @ApiModelProperty(value = "实验名称") | |||
| private String mlName; | |||
| @ApiModelProperty(value = "实验描述") | |||
| private String mlDescription; | |||
| @ApiModelProperty(value = "任务类型:classification或regression") | |||
| private String taskType; | |||
| @ApiModelProperty(value = "搜索合适模型的时间限制(以秒为单位)。通过增加这个值,auto-sklearn有更高的机会找到更好的模型。默认3600,非必传。") | |||
| private Integer timeLeftForThisTask; | |||
| @ApiModelProperty(value = "单次调用机器学习模型的时间限制(以秒为单位)。如果机器学习算法运行超过时间限制,将终止模型拟合。将这个值设置得足够高,这样典型的机器学习算法就可以适用于训练数据。默认600,非必传。") | |||
| private Integer perRunTimeLimit; | |||
| @ApiModelProperty(value = "集成模型数量,如果设置为0,则没有集成。默认50,非必传。") | |||
| private Integer ensembleSize; | |||
| @ApiModelProperty(value = "设置为None将禁用集成构建,设置为SingleBest仅使用单个最佳模型而不是集成,设置为default,它将对单目标问题使用EnsembleSelection,对多目标问题使用MultiObjectiveDummyEnsemble。默认default,非必传。") | |||
| private String ensembleClass; | |||
| @ApiModelProperty(value = "在构建集成时只考虑ensemble_nbest模型。这是受到了“最大限度地利用集成选择”中引入的库修剪概念的启发。这是独立于ensemble_class参数的,并且这个修剪步骤是在构造集成之前完成的。默认50,非必传。") | |||
| private Integer ensembleNbest; | |||
| @ApiModelProperty(value = "定义在磁盘中保存的模型的最大数量。额外的模型数量将被永久删除。由于这个变量的性质,它设置了一个集成可以使用多少个模型的上限。必须是大于等于1的整数。如果设置为None,则所有模型都保留在磁盘上。默认50,非必传。") | |||
| private Integer maxModelsOnDisc; | |||
| @ApiModelProperty(value = "随机种子,将决定输出文件名。默认1,非必传。") | |||
| private Integer seed; | |||
| @ApiModelProperty(value = "机器学习算法的内存限制(MB)。如果auto-sklearn试图分配超过memory_limit MB,它将停止拟合机器学习算法。默认3072,非必传。") | |||
| private Integer memoryLimit; | |||
| @ApiModelProperty(value = "如果为None,则使用所有可能的分类算法。否则,指定搜索中包含的步骤和组件。有关可用组件,请参见/pipeline/components/<step>/*。与参数exclude不兼容。多选,逗号分隔。包含:adaboost\n" + | |||
| "bernoulli_nb\n" + | |||
| "decision_tree\n" + | |||
| "extra_trees\n" + | |||
| "gaussian_nb\n" + | |||
| "gradient_boosting\n" + | |||
| "k_nearest_neighbors\n" + | |||
| "lda\n" + | |||
| "liblinear_svc\n" + | |||
| "libsvm_svc\n" + | |||
| "mlp\n" + | |||
| "multinomial_nb\n" + | |||
| "passive_aggressive\n" + | |||
| "qda\n" + | |||
| "random_forest\n" + | |||
| "sgd") | |||
| private String includeClassifier; | |||
| @ApiModelProperty(value = "如果为None,则使用所有可能的特征预处理算法。否则,指定搜索中包含的步骤和组件。有关可用组件,请参见/pipeline/components/<step>/*。与参数exclude不兼容。多选,逗号分隔。包含:densifier\n" + | |||
| "extra_trees_preproc_for_classification\n" + | |||
| "extra_trees_preproc_for_regression\n" + | |||
| "fast_ica\n" + | |||
| "feature_agglomeration\n" + | |||
| "kernel_pca\n" + | |||
| "kitchen_sinks\n" + | |||
| "liblinear_svc_preprocessor\n" + | |||
| "no_preprocessing\n" + | |||
| "nystroem_sampler\n" + | |||
| "pca\n" + | |||
| "polynomial\n" + | |||
| "random_trees_embedding\n" + | |||
| "select_percentile_classification\n" + | |||
| "select_percentile_regression\n" + | |||
| "select_rates_classification\n" + | |||
| "select_rates_regression\n" + | |||
| "truncatedSVD") | |||
| private String includeFeaturePreprocessor; | |||
| @ApiModelProperty(value = "如果为None,则使用所有可能的回归算法。否则,指定搜索中包含的步骤和组件。有关可用组件,请参见/pipeline/components/<step>/*。与参数exclude不兼容。多选,逗号分隔。包含:adaboost,\n" + | |||
| "ard_regression,\n" + | |||
| "decision_tree,\n" + | |||
| "extra_trees,\n" + | |||
| "gaussian_process,\n" + | |||
| "gradient_boosting,\n" + | |||
| "k_nearest_neighbors,\n" + | |||
| "liblinear_svr,\n" + | |||
| "libsvm_svr,\n" + | |||
| "mlp,\n" + | |||
| "random_forest,\n" + | |||
| "sgd") | |||
| private String includeRegressor; | |||
| private String excludeClassifier; | |||
| private String excludeRegressor; | |||
| private String excludeFeaturePreprocessor; | |||
| @ApiModelProperty(value = "测试集的比率,0到1之间") | |||
| private Float testSize; | |||
| @ApiModelProperty(value = "如何处理过拟合,如果使用基于“cv”的方法或Splitter对象,可能需要使用resampling_strategy_arguments。holdout或crossValid") | |||
| private String resamplingStrategy; | |||
| @ApiModelProperty(value = "重采样划分训练集和验证集,训练集的比率,0到1之间") | |||
| private Float trainSize; | |||
| @ApiModelProperty(value = "拆分数据前是否进行shuffle") | |||
| private Boolean shuffle; | |||
| @ApiModelProperty(value = "交叉验证的折数,当resamplingStrategy为crossValid时,此项必填,为整数") | |||
| private Integer folds; | |||
| @ApiModelProperty(value = "文件夹存放配置输出和日志文件,默认/tmp/automl") | |||
| private String tmpFolder; | |||
| @ApiModelProperty(value = "数据集csv文件中哪几列是预测目标列,逗号分隔") | |||
| private String targetColumns; | |||
| @ApiModelProperty(value = "自定义指标名称") | |||
| private String metricName; | |||
| @ApiModelProperty(value = "模型优化目标指标及权重,json格式。分类的指标包含:accuracy\n" + | |||
| "balanced_accuracy\n" + | |||
| "roc_auc\n" + | |||
| "average_precision\n" + | |||
| "log_loss\n" + | |||
| "precision_macro\n" + | |||
| "precision_micro\n" + | |||
| "precision_samples\n" + | |||
| "precision_weighted\n" + | |||
| "recall_macro\n" + | |||
| "recall_micro\n" + | |||
| "recall_samples\n" + | |||
| "recall_weighted\n" + | |||
| "f1_macro\n" + | |||
| "f1_micro\n" + | |||
| "f1_samples\n" + | |||
| "f1_weighted\n" + | |||
| "回归的指标包含:mean_absolute_error\n" + | |||
| "mean_squared_error\n" + | |||
| "root_mean_squared_error\n" + | |||
| "mean_squared_log_error\n" + | |||
| "median_absolute_error\n" + | |||
| "r2") | |||
| private String metrics; | |||
| @ApiModelProperty(value = "指标优化方向,是越大越好还是越小越好") | |||
| private Boolean greaterIsBetter; | |||
| @ApiModelProperty(value = "模型计算并打印指标") | |||
| private String scoringFunctions; | |||
| private Integer state; | |||
| private String runState; | |||
| private Double progress; | |||
| private String createBy; | |||
| private Date createTime; | |||
| private String updateBy; | |||
| private Date updateTime; | |||
| /** | |||
| * 对应数据集 | |||
| */ | |||
| private Map<String,Object> dataset; | |||
| } | |||
| @@ -14,16 +14,16 @@ spring: | |||
| nacos: | |||
| discovery: | |||
| # 服务注册地址 | |||
| server-addr: nacos-ci4s.argo.svc:8848 | |||
| server-addr: 172.20.32.181:18848 | |||
| username: nacos | |||
| password: h1n2x3j4y5@ | |||
| password: nacos | |||
| retry: | |||
| enabled: true | |||
| # namespace: 6caf5d79-c4ce-4e3b-a357-141b74e52a01 | |||
| config: | |||
| username: nacos | |||
| password: h1n2x3j4y5@ | |||
| # namespace: 6caf5d79-c4ce-4e3b-a357-141b74e52a01 | |||
| # 配置中心地址 | |||
| server-addr: nacos-ci4s.argo.svc:8848 | |||
| server-addr: 172.20.32.181:18848 | |||
| # 配置文件格式 | |||
| file-extension: yml | |||
| # 共享配置 | |||
| @@ -0,0 +1,146 @@ | |||
| <?xml version="1.0" encoding="UTF-8"?> | |||
| <!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd"> | |||
| <mapper namespace="com.ruoyi.platform.mapper.AutoMlDao"> | |||
| <insert id="save"> | |||
| insert into auto_ml(ml_name, ml_description, task_type, dataset, time_left_for_this_task, | |||
| per_run_time_limit, ensemble_size, ensemble_class, ensemble_nbest, max_models_on_disc, seed, | |||
| memory_limit, | |||
| include_classifier, include_feature_preprocessor, include_regressor, exclude_classifier, | |||
| exclude_regressor, exclude_feature_preprocessor, test_size, resampling_strategy, train_size, | |||
| shuffle, folds, target_columns, metric_name, metrics, greater_is_better, scoring_functions, | |||
| tmp_folder, | |||
| create_by, update_by) | |||
| values (#{autoMl.mlName}, #{autoMl.mlDescription}, #{autoMl.taskType}, #{autoMl.dataset}, | |||
| #{autoMl.timeLeftForThisTask}, #{autoMl.perRunTimeLimit}, | |||
| #{autoMl.ensembleSize}, #{autoMl.ensembleClass}, #{autoMl.ensembleNbest}, | |||
| #{autoMl.maxModelsOnDisc}, #{autoMl.seed}, | |||
| #{autoMl.memoryLimit}, #{autoMl.includeClassifier}, #{autoMl.includeFeaturePreprocessor}, | |||
| #{autoMl.includeRegressor}, #{autoMl.excludeClassifier}, | |||
| #{autoMl.excludeRegressor}, #{autoMl.excludeFeaturePreprocessor}, #{autoMl.testSize}, | |||
| #{autoMl.resamplingStrategy}, | |||
| #{autoMl.trainSize}, #{autoMl.shuffle}, | |||
| #{autoMl.folds}, | |||
| #{autoMl.targetColumns}, #{autoMl.metricName}, #{autoMl.metrics}, #{autoMl.greaterIsBetter}, | |||
| #{autoMl.scoringFunctions}, #{autoMl.tmpFolder}, | |||
| #{autoMl.createBy}, #{autoMl.updateBy}) | |||
| </insert> | |||
| <update id="edit"> | |||
| update auto_ml | |||
| <set> | |||
| <if test="autoMl.mlName != null and autoMl.mlName !=''"> | |||
| ml_name = #{autoMl.mlName}, | |||
| </if> | |||
| <if test="autoMl.mlDescription != null and autoMl.mlDescription !=''"> | |||
| ml_description = #{autoMl.mlDescription}, | |||
| </if> | |||
| <if test="autoMl.statusList != null and autoMl.statusList !=''"> | |||
| status_list = #{autoMl.statusList}, | |||
| </if> | |||
| <!-- <if test="autoMl.progress != null">--> | |||
| <!-- progress = #{autoMl.progress},--> | |||
| <!-- </if>--> | |||
| <if test="autoMl.taskType != null and autoMl.taskType !=''"> | |||
| task_type = #{autoMl.taskType}, | |||
| </if> | |||
| <if test="autoMl.dataset != null and autoMl.dataset !=''"> | |||
| dataset = #{autoMl.dataset}, | |||
| </if> | |||
| <if test="autoMl.timeLeftForThisTask != null"> | |||
| time_left_for_this_task = #{autoMl.timeLeftForThisTask}, | |||
| </if> | |||
| <if test="autoMl.perRunTimeLimit != null"> | |||
| per_run_time_limit = #{autoMl.perRunTimeLimit}, | |||
| </if> | |||
| <if test="autoMl.ensembleSize != null"> | |||
| ensemble_size = #{autoMl.ensembleSize}, | |||
| </if> | |||
| <if test="autoMl.ensembleClass != null and autoMl.ensembleClass !=''"> | |||
| ensemble_class = #{autoMl.ensembleClass}, | |||
| </if> | |||
| <if test="autoMl.ensembleNbest != null"> | |||
| ensemble_nbest = #{autoMl.ensembleNbest}, | |||
| </if> | |||
| <if test="autoMl.maxModelsOnDisc != null"> | |||
| max_models_on_disc = #{autoMl.maxModelsOnDisc}, | |||
| </if> | |||
| <if test="autoMl.seed != null"> | |||
| seed = #{autoMl.seed}, | |||
| </if> | |||
| <if test="autoMl.memoryLimit != null"> | |||
| memory_limit = #{autoMl.memoryLimit}, | |||
| </if> | |||
| include_classifier = #{autoMl.includeClassifier}, | |||
| include_feature_preprocessor = #{autoMl.includeFeaturePreprocessor}, | |||
| include_regressor = #{autoMl.includeRegressor}, | |||
| exclude_classifier = #{autoMl.excludeClassifier}, | |||
| exclude_regressor = #{autoMl.excludeRegressor}, | |||
| exclude_feature_preprocessor = #{autoMl.excludeFeaturePreprocessor}, | |||
| scoring_functions = #{autoMl.scoringFunctions}, | |||
| metrics = #{autoMl.metrics}, | |||
| <if test="autoMl.testSize != null and autoMl.testSize !=''"> | |||
| test_size = #{autoMl.testSize}, | |||
| </if> | |||
| <if test="autoMl.resamplingStrategy != null and autoMl.resamplingStrategy !=''"> | |||
| resampling_strategy = #{autoMl.resamplingStrategy}, | |||
| </if> | |||
| <if test="autoMl.trainSize != null and autoMl.trainSize !=''"> | |||
| train_size = #{autoMl.trainSize}, | |||
| </if> | |||
| <if test="autoMl.shuffle != null"> | |||
| shuffle = #{autoMl.shuffle}, | |||
| </if> | |||
| <if test="autoMl.folds != null"> | |||
| folds = #{autoMl.folds}, | |||
| </if> | |||
| <if test="autoMl.tmpFolder != null and autoMl.tmpFolder !=''"> | |||
| tmp_folder = #{autoMl.tmpFolder}, | |||
| </if> | |||
| <if test="autoMl.metricName != null and autoMl.metricName !=''"> | |||
| metric_name = #{autoMl.metricName}, | |||
| </if> | |||
| <if test="autoMl.greaterIsBetter != null"> | |||
| greater_is_better = #{autoMl.greaterIsBetter}, | |||
| </if> | |||
| <if test="autoMl.targetColumns != null and autoMl.targetColumns !=''"> | |||
| target_columns = #{autoMl.targetColumns}, | |||
| </if> | |||
| <if test="autoMl.state != null"> | |||
| state = #{autoMl.state}, | |||
| </if> | |||
| </set> | |||
| where id = #{autoMl.id} | |||
| </update> | |||
| <select id="count" resultType="java.lang.Long"> | |||
| select count(1) from auto_ml | |||
| <include refid="common_condition"></include> | |||
| </select> | |||
| <select id="queryByPage" resultType="com.ruoyi.platform.domain.AutoMl"> | |||
| select * from auto_ml | |||
| <include refid="common_condition"></include> | |||
| </select> | |||
| <select id="getAutoMlById" resultType="com.ruoyi.platform.domain.AutoMl"> | |||
| select * | |||
| from auto_ml | |||
| where id = #{id} | |||
| </select> | |||
| <select id="getAutoMlByName" resultType="com.ruoyi.platform.domain.AutoMl"> | |||
| select * | |||
| from auto_ml | |||
| where ml_name = #{mlName} | |||
| and state = 1 | |||
| </select> | |||
| <sql id="common_condition"> | |||
| <where> | |||
| state = 1 | |||
| <if test="mlName != null and mlName != ''"> | |||
| and ml_name like concat('%', #{mlName}, '%') | |||
| </if> | |||
| </where> | |||
| </sql> | |||
| </mapper> | |||
| @@ -0,0 +1,86 @@ | |||
| <?xml version="1.0" encoding="UTF-8"?> | |||
| <!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd"> | |||
| <mapper namespace="com.ruoyi.platform.mapper.AutoMlInsDao"> | |||
| <insert id="insert" keyProperty="id" useGeneratedKeys="true"> | |||
| insert into auto_ml_ins(auto_ml_id, result_path, model_path, img_path, run_history_path, node_status, | |||
| node_result, param, source, argo_ins_name, argo_ins_ns, status) | |||
| values (#{autoMlIns.autoMlId}, #{autoMlIns.resultPath}, #{autoMlIns.modelPath}, #{autoMlIns.imgPath}, | |||
| #{autoMlIns.runHistoryPath}, #{autoMlIns.nodeStatus}, | |||
| #{autoMlIns.nodeResult}, #{autoMlIns.param}, #{autoMlIns.source}, #{autoMlIns.argoInsName}, | |||
| #{autoMlIns.argoInsNs}, #{autoMlIns.status}) | |||
| </insert> | |||
| <update id="update"> | |||
| update auto_ml_ins | |||
| <set> | |||
| <if test="autoMlIns.modelPath != null and autoMlIns.modelPath != ''"> | |||
| model_path = #{autoMlIns.modelPath}, | |||
| </if> | |||
| <if test="autoMlIns.imgPath != null and autoMlIns.imgPath != ''"> | |||
| img_path = #{autoMlIns.imgPath}, | |||
| </if> | |||
| <if test="autoMlIns.status != null and autoMlIns.status != ''"> | |||
| status = #{autoMlIns.status}, | |||
| </if> | |||
| <if test="autoMlIns.nodeStatus != null and autoMlIns.nodeStatus != ''"> | |||
| node_status = #{autoMlIns.nodeStatus}, | |||
| </if> | |||
| <if test="autoMlIns.nodeResult != null and autoMlIns.nodeResult != ''"> | |||
| node_result = #{autoMlIns.nodeResult}, | |||
| </if> | |||
| <if test="autoMlIns.state != null"> | |||
| state = #{autoMlIns.state}, | |||
| </if> | |||
| <if test="autoMlIns.updateTime != null"> | |||
| update_time = #{autoMlIns.updateTime}, | |||
| </if> | |||
| <if test="autoMlIns.finishTime != null"> | |||
| finish_time = #{autoMlIns.finishTime}, | |||
| </if> | |||
| </set> | |||
| where id = #{autoMlIns.id} | |||
| </update> | |||
| <select id="count" resultType="java.lang.Long"> | |||
| select count(1) | |||
| from auto_ml_ins | |||
| <where> | |||
| state = 1 | |||
| and auto_ml_id = #{autoMlIns.autoMlId} | |||
| </where> | |||
| </select> | |||
| <select id="queryAllByLimit" resultType="com.ruoyi.platform.domain.AutoMlIns"> | |||
| select * from auto_ml_ins | |||
| <where> | |||
| state = 1 | |||
| and auto_ml_id = #{autoMlIns.autoMlId} | |||
| </where> | |||
| order by update_time DESC | |||
| limit #{pageable.offset}, #{pageable.pageSize} | |||
| </select> | |||
| <select id="queryById" resultType="com.ruoyi.platform.domain.AutoMlIns"> | |||
| select * from auto_ml_ins | |||
| <where> | |||
| state = 1 and id = #{id} | |||
| </where> | |||
| </select> | |||
| <select id="queryByAutoMlInsIsNotTerminated" resultType="com.ruoyi.platform.domain.AutoMlIns"> | |||
| select * | |||
| from auto_ml_ins | |||
| where (status NOT IN ('Terminated', 'Succeeded', 'Failed') | |||
| OR status IS NULL) | |||
| and state = 1 | |||
| </select> | |||
| <select id="getByAutoMlId" resultType="com.ruoyi.platform.domain.AutoMlIns"> | |||
| select * | |||
| from auto_ml_ins | |||
| where auto_ml_id = #{autoMlId} | |||
| and state = 1 | |||
| order by update_time DESC limit 5 | |||
| </select> | |||
| </mapper> | |||