| @@ -49,7 +49,7 @@ public class AutoMlInsController extends BaseController { | |||||
| @PutMapping("{id}") | @PutMapping("{id}") | ||||
| @ApiOperation("终止实验实例") | @ApiOperation("终止实验实例") | ||||
| public GenericsAjaxResult<Boolean> terminateAutoMlIns(@PathVariable("id") Long id) { | |||||
| public GenericsAjaxResult<Boolean> terminateAutoMlIns(@PathVariable("id") Long id) throws Exception { | |||||
| return genericsSuccess(this.autoMLInsService.terminateAutoMlIns(id)); | return genericsSuccess(this.autoMLInsService.terminateAutoMlIns(id)); | ||||
| } | } | ||||
| @@ -0,0 +1,62 @@ | |||||
| package com.ruoyi.platform.controller.ray; | |||||
| import com.ruoyi.common.core.web.controller.BaseController; | |||||
| import com.ruoyi.common.core.web.domain.GenericsAjaxResult; | |||||
| import com.ruoyi.platform.domain.Ray; | |||||
| import com.ruoyi.platform.service.RayService; | |||||
| import com.ruoyi.platform.vo.RayVo; | |||||
| import io.swagger.annotations.Api; | |||||
| import io.swagger.annotations.ApiOperation; | |||||
| import org.springframework.data.domain.Page; | |||||
| import org.springframework.data.domain.PageRequest; | |||||
| import org.springframework.web.bind.annotation.*; | |||||
| import javax.annotation.Resource; | |||||
| import java.io.IOException; | |||||
| @RestController | |||||
| @RequestMapping("ray") | |||||
| @Api("自动超参数寻优") | |||||
| public class RayController extends BaseController { | |||||
| @Resource | |||||
| private RayService rayService; | |||||
| @GetMapping | |||||
| @ApiOperation("分页查询") | |||||
| public GenericsAjaxResult<Page<Ray>> queryByPage(@RequestParam("page") int page, | |||||
| @RequestParam("size") int size, | |||||
| @RequestParam(value = "name", required = false) String name) { | |||||
| PageRequest pageRequest = PageRequest.of(page, size); | |||||
| return genericsSuccess(this.rayService.queryByPage(name, pageRequest)); | |||||
| } | |||||
| @PostMapping | |||||
| @ApiOperation("新增自动超参数寻优") | |||||
| public GenericsAjaxResult<Ray> addRay(@RequestBody RayVo rayVo) throws Exception { | |||||
| return genericsSuccess(this.rayService.save(rayVo)); | |||||
| } | |||||
| @PutMapping | |||||
| @ApiOperation("编辑自动超参数寻优") | |||||
| public GenericsAjaxResult<String> editRay(@RequestBody RayVo rayVo) throws Exception{ | |||||
| return genericsSuccess(this.rayService.edit(rayVo)); | |||||
| } | |||||
| @GetMapping("/getRayDetail") | |||||
| @ApiOperation("获取自动超参数寻优详细信息") | |||||
| public GenericsAjaxResult<RayVo> getRayDetail(@RequestParam("id") Long id) throws IOException { | |||||
| return genericsSuccess(this.rayService.getRayDetail(id)); | |||||
| } | |||||
| @DeleteMapping("{id}") | |||||
| @ApiOperation("删除自动超参数寻优") | |||||
| public GenericsAjaxResult<String> deleteRay(@PathVariable("id") Long id) { | |||||
| return genericsSuccess(this.rayService.delete(id)); | |||||
| } | |||||
| @PostMapping("/run/{id}") | |||||
| @ApiOperation("运行自动超参数寻优实验") | |||||
| public GenericsAjaxResult<String> runRay(@PathVariable("id") Long id) throws Exception { | |||||
| return genericsSuccess(this.rayService.runRayIns(id)); | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,60 @@ | |||||
| package com.ruoyi.platform.controller.ray; | |||||
| import com.ruoyi.common.core.web.controller.BaseController; | |||||
| import com.ruoyi.common.core.web.domain.GenericsAjaxResult; | |||||
| import com.ruoyi.platform.domain.RayIns; | |||||
| import com.ruoyi.platform.service.RayInsService; | |||||
| import io.swagger.annotations.Api; | |||||
| import io.swagger.annotations.ApiOperation; | |||||
| import org.springframework.web.bind.annotation.*; | |||||
| import org.springframework.data.domain.Page; | |||||
| import org.springframework.data.domain.PageRequest; | |||||
| import javax.annotation.Resource; | |||||
| import java.io.IOException; | |||||
| import java.util.List; | |||||
| @RestController | |||||
| @RequestMapping("rayIns") | |||||
| @Api("自动超参数寻优实验实例") | |||||
| public class RayInsController extends BaseController { | |||||
| @Resource | |||||
| private RayInsService rayInsService; | |||||
| @GetMapping | |||||
| @ApiOperation("分页查询") | |||||
| public GenericsAjaxResult<Page<RayIns>> queryByPage(Long rayId, int page, int size) throws IOException { | |||||
| PageRequest pageRequest = PageRequest.of(page, size); | |||||
| return genericsSuccess(this.rayInsService.queryByPage(rayId, pageRequest)); | |||||
| } | |||||
| @PostMapping | |||||
| @ApiOperation("新增实验实例") | |||||
| public GenericsAjaxResult<RayIns> add(@RequestBody RayIns rayIns) { | |||||
| return genericsSuccess(this.rayInsService.insert(rayIns)); | |||||
| } | |||||
| @DeleteMapping("{id}") | |||||
| @ApiOperation("删除实验实例") | |||||
| public GenericsAjaxResult<String> deleteById(@PathVariable("id") Long id) { | |||||
| return genericsSuccess(this.rayInsService.deleteById(id)); | |||||
| } | |||||
| @DeleteMapping("batchDelete") | |||||
| @ApiOperation("批量删除实验实例") | |||||
| public GenericsAjaxResult<String> batchDelete(@RequestBody List<Long> ids) { | |||||
| return genericsSuccess(this.rayInsService.batchDelete(ids)); | |||||
| } | |||||
| @PutMapping("{id}") | |||||
| @ApiOperation("终止实验实例") | |||||
| public GenericsAjaxResult<Boolean> terminateRayIns(@PathVariable("id") Long id) throws Exception { | |||||
| return genericsSuccess(this.rayInsService.terminateRayIns(id)); | |||||
| } | |||||
| @GetMapping("{id}") | |||||
| @ApiOperation("查看实验实例详情") | |||||
| public GenericsAjaxResult<RayIns> getDetailById(@PathVariable("id") Long id) throws IOException { | |||||
| return genericsSuccess(this.rayInsService.getDetailById(id)); | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,82 @@ | |||||
| package com.ruoyi.platform.domain; | |||||
| import com.fasterxml.jackson.databind.PropertyNamingStrategy; | |||||
| import com.fasterxml.jackson.databind.annotation.JsonNaming; | |||||
| import io.swagger.annotations.ApiModel; | |||||
| import io.swagger.annotations.ApiModelProperty; | |||||
| import lombok.Data; | |||||
| import java.util.Date; | |||||
| @Data | |||||
| @JsonNaming(PropertyNamingStrategy.SnakeCaseStrategy.class) | |||||
| @ApiModel(description = "自动超参数寻优") | |||||
| public class Ray { | |||||
| private Long id; | |||||
| @ApiModelProperty(value = "实验名称") | |||||
| private String name; | |||||
| @ApiModelProperty(value = "实验描述") | |||||
| private String description; | |||||
| @ApiModelProperty(value = "数据集") | |||||
| private String dataset; | |||||
| @ApiModelProperty(value = "数据集挂载路径") | |||||
| private String model; | |||||
| @ApiModelProperty(value = "代码配置") | |||||
| private String codeConfig; | |||||
| @ApiModelProperty(value = "镜像") | |||||
| private String image; | |||||
| @ApiModelProperty(value = "主函数代码文件") | |||||
| private String mainPy; | |||||
| @ApiModelProperty(value = "总实验次数") | |||||
| private Integer numSamples; | |||||
| @ApiModelProperty(value = "参数") | |||||
| private String parameters; | |||||
| @ApiModelProperty(value = "手动指定需要运行的参数") | |||||
| private String pointsToEvaluate; | |||||
| @ApiModelProperty(value = "保存路径") | |||||
| private String storagePath; | |||||
| @ApiModelProperty(value = "搜索算法") | |||||
| private String searchAlg; | |||||
| @ApiModelProperty(value = "调度算法") | |||||
| private String scheduler; | |||||
| @ApiModelProperty(value = "指标") | |||||
| private String metric; | |||||
| @ApiModelProperty(value = "指标最大化或最小化,min or max") | |||||
| private String mode; | |||||
| @ApiModelProperty(value = "搜索算法为ASHA,HyperBand时传入,每次试验的最大时间单位。测试将在max_t时间单位后停止。") | |||||
| private Integer maxT; | |||||
| @ApiModelProperty(value = "搜索算法为MedianStopping时传入,计算中位数的最小试验数。") | |||||
| private Integer minSamplesRequired; | |||||
| private String resource; | |||||
| private Integer state; | |||||
| private String createBy; | |||||
| private String updateBy; | |||||
| private Date createTime; | |||||
| private Date updateTime; | |||||
| @ApiModelProperty(value = "状态列表") | |||||
| private String statusList; | |||||
| } | |||||
| @@ -0,0 +1,51 @@ | |||||
| package com.ruoyi.platform.domain; | |||||
| import com.baomidou.mybatisplus.annotation.TableField; | |||||
| import com.fasterxml.jackson.databind.PropertyNamingStrategy; | |||||
| import com.fasterxml.jackson.databind.annotation.JsonNaming; | |||||
| import io.swagger.annotations.ApiModel; | |||||
| import io.swagger.annotations.ApiModelProperty; | |||||
| import lombok.Data; | |||||
| import java.util.ArrayList; | |||||
| import java.util.Date; | |||||
| import java.util.Map; | |||||
| @Data | |||||
| @JsonNaming(PropertyNamingStrategy.SnakeCaseStrategy.class) | |||||
| @ApiModel(description = "自动超参数寻优实验实例") | |||||
| public class RayIns { | |||||
| private Long id; | |||||
| private Long rayId; | |||||
| private String resultPath; | |||||
| private Integer state; | |||||
| private String status; | |||||
| private String nodeStatus; | |||||
| private String nodeResult; | |||||
| private String param; | |||||
| private String source; | |||||
| @ApiModelProperty(value = "Argo实例名称") | |||||
| private String argoInsName; | |||||
| @ApiModelProperty(value = "Argo命名空间") | |||||
| private String argoInsNs; | |||||
| private Date createTime; | |||||
| private Date updateTime; | |||||
| private Date finishTime; | |||||
| @TableField(exist = false) | |||||
| private ArrayList<Map<String, Object>> trialList; | |||||
| } | |||||
| @@ -0,0 +1,21 @@ | |||||
| package com.ruoyi.platform.mapper; | |||||
| import com.ruoyi.platform.domain.Ray; | |||||
| import org.apache.ibatis.annotations.Param; | |||||
| import org.springframework.data.domain.PageRequest; | |||||
| import java.util.List; | |||||
| public interface RayDao { | |||||
| long count(@Param("name") String name); | |||||
| List<Ray> queryByPage(@Param("name") String name, @Param("pageable") PageRequest pageRequest); | |||||
| Ray getRayByName(@Param("name") String name); | |||||
| Ray getRayById(@Param("id") Long id); | |||||
| int save(@Param("ray") Ray ray); | |||||
| int edit(@Param("ray") Ray ray); | |||||
| } | |||||
| @@ -0,0 +1,23 @@ | |||||
| package com.ruoyi.platform.mapper; | |||||
| import com.ruoyi.platform.domain.RayIns; | |||||
| import org.apache.ibatis.annotations.Param; | |||||
| import org.springframework.data.domain.Pageable; | |||||
| import java.util.List; | |||||
| public interface RayInsDao { | |||||
| long count(@Param("rayId") Long rayId); | |||||
| List<RayIns> queryAllByLimit(@Param("rayId") Long rayId, @Param("pageable") Pageable pageable); | |||||
| RayIns queryById(@Param("id") Long id); | |||||
| List<RayIns> getByRayId(@Param("rayId") Long rayId); | |||||
| int insert(@Param("rayIns") RayIns rayIns); | |||||
| int update(@Param("rayIns") RayIns rayIns); | |||||
| List<RayIns> queryByRayInsIsNotTerminated(); | |||||
| } | |||||
| @@ -1,5 +1,6 @@ | |||||
| package com.ruoyi.platform.scheduling; | package com.ruoyi.platform.scheduling; | ||||
| import com.ruoyi.platform.constant.Constant; | |||||
| import com.ruoyi.platform.domain.AutoMl; | import com.ruoyi.platform.domain.AutoMl; | ||||
| import com.ruoyi.platform.domain.AutoMlIns; | import com.ruoyi.platform.domain.AutoMlIns; | ||||
| import com.ruoyi.platform.mapper.AutoMlDao; | import com.ruoyi.platform.mapper.AutoMlDao; | ||||
| @@ -41,7 +42,7 @@ public class AutoMlInsStatusTask { | |||||
| try { | try { | ||||
| autoMlIns = autoMlInsService.queryStatusFromArgo(autoMlIns); | autoMlIns = autoMlInsService.queryStatusFromArgo(autoMlIns); | ||||
| } catch (Exception e) { | } catch (Exception e) { | ||||
| autoMlIns.setStatus("Failed"); | |||||
| autoMlIns.setStatus(Constant.Failed); | |||||
| } | } | ||||
| // 线程安全的添加操作 | // 线程安全的添加操作 | ||||
| synchronized (autoMlIds) { | synchronized (autoMlIds) { | ||||
| @@ -0,0 +1,95 @@ | |||||
| package com.ruoyi.platform.scheduling; | |||||
| import com.ruoyi.platform.constant.Constant; | |||||
| import com.ruoyi.platform.domain.Ray; | |||||
| import com.ruoyi.platform.domain.RayIns; | |||||
| import com.ruoyi.platform.mapper.RayDao; | |||||
| import com.ruoyi.platform.mapper.RayInsDao; | |||||
| import com.ruoyi.platform.service.RayInsService; | |||||
| import org.apache.commons.lang3.StringUtils; | |||||
| import org.springframework.scheduling.annotation.Scheduled; | |||||
| import org.springframework.stereotype.Component; | |||||
| import javax.annotation.Resource; | |||||
| import java.util.ArrayList; | |||||
| import java.util.Iterator; | |||||
| import java.util.List; | |||||
| @Component() | |||||
| public class RayInsStatusTask { | |||||
| @Resource | |||||
| private RayInsService rayInsService; | |||||
| @Resource | |||||
| private RayInsDao rayInsDao; | |||||
| @Resource | |||||
| private RayDao rayDao; | |||||
| private List<Long> rayIds = new ArrayList<>(); | |||||
| @Scheduled(cron = "0/30 * * * * ?") // 每30S执行一次 | |||||
| public void executeRayInsStatus() { | |||||
| List<RayIns> rayInsList = rayInsService.queryByRayInsIsNotTerminated(); | |||||
| // 去argo查询状态 | |||||
| List<RayIns> updateList = new ArrayList<>(); | |||||
| if (rayInsList != null && rayInsList.size() > 0) { | |||||
| for (RayIns rayIns : rayInsList) { | |||||
| //当原本状态为null或非终止态时才调用argo接口 | |||||
| try { | |||||
| rayIns = rayInsService.queryStatusFromArgo(rayIns); | |||||
| } catch (Exception e) { | |||||
| rayIns.setStatus(Constant.Failed); | |||||
| } | |||||
| // 线程安全的添加操作 | |||||
| synchronized (rayIds) { | |||||
| rayIds.add(rayIns.getRayId()); | |||||
| } | |||||
| updateList.add(rayIns); | |||||
| } | |||||
| if (updateList.size() > 0) { | |||||
| for (RayIns rayIns : updateList) { | |||||
| rayInsDao.update(rayIns); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @Scheduled(cron = "0/30 * * * * ?") // / 每30S执行一次 | |||||
| public void executeRayStatus() { | |||||
| if (rayIds.size() == 0) { | |||||
| return; | |||||
| } | |||||
| // 存储需要更新的实验对象列表 | |||||
| List<Ray> updateRays = new ArrayList<>(); | |||||
| for (Long rayId : rayIds) { | |||||
| // 获取当前实验的所有实例列表 | |||||
| List<RayIns> insList = rayInsDao.getByRayId(rayId); | |||||
| List<String> statusList = new ArrayList<>(); | |||||
| // 更新实验状态列表 | |||||
| for (int i = 0; i < insList.size(); i++) { | |||||
| statusList.add(insList.get(i).getStatus()); | |||||
| } | |||||
| String subStatus = statusList.toString().substring(1, statusList.toString().length() - 1); | |||||
| Ray ray = rayDao.getRayById(rayId); | |||||
| if (!StringUtils.equals(ray.getStatusList(), subStatus)) { | |||||
| ray.setStatusList(subStatus); | |||||
| updateRays.add(ray); | |||||
| rayDao.edit(ray); | |||||
| } | |||||
| } | |||||
| if (!updateRays.isEmpty()) { | |||||
| // 使用Iterator进行安全的删除操作 | |||||
| Iterator<Long> iterator = rayIds.iterator(); | |||||
| while (iterator.hasNext()) { | |||||
| Long rayId = iterator.next(); | |||||
| for (Ray ray : updateRays) { | |||||
| if (ray.getId().equals(rayId)) { | |||||
| iterator.remove(); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -21,7 +21,7 @@ public interface AutoMlInsService { | |||||
| AutoMlIns queryStatusFromArgo(AutoMlIns autoMlIns); | AutoMlIns queryStatusFromArgo(AutoMlIns autoMlIns); | ||||
| boolean terminateAutoMlIns(Long id); | |||||
| boolean terminateAutoMlIns(Long id) throws Exception; | |||||
| AutoMlIns getDetailById(Long id); | AutoMlIns getDetailById(Long id); | ||||
| @@ -0,0 +1,25 @@ | |||||
| package com.ruoyi.platform.service; | |||||
| import org.springframework.data.domain.Page; | |||||
| import org.springframework.data.domain.PageRequest; | |||||
| import com.ruoyi.platform.domain.RayIns; | |||||
| import java.io.IOException; | |||||
| import java.util.List; | |||||
| public interface RayInsService { | |||||
| Page<RayIns> queryByPage(Long rayId, PageRequest pageRequest) throws IOException; | |||||
| RayIns insert(RayIns rayIns); | |||||
| String deleteById(Long id); | |||||
| String batchDelete(List<Long> ids); | |||||
| boolean terminateRayIns(Long id) throws Exception; | |||||
| RayIns getDetailById(Long id) throws IOException; | |||||
| void updateRayStatus(Long rayId); | |||||
| RayIns queryStatusFromArgo(RayIns ins); | |||||
| List<RayIns> queryByRayInsIsNotTerminated(); | |||||
| } | |||||
| @@ -0,0 +1,22 @@ | |||||
| package com.ruoyi.platform.service; | |||||
| import com.ruoyi.platform.domain.Ray; | |||||
| import com.ruoyi.platform.vo.RayVo; | |||||
| import org.springframework.data.domain.Page; | |||||
| import org.springframework.data.domain.PageRequest; | |||||
| import java.io.IOException; | |||||
| public interface RayService { | |||||
| Page<Ray> queryByPage(String name, PageRequest pageRequest); | |||||
| Ray save(RayVo rayVo) throws Exception; | |||||
| String edit(RayVo rayVo) throws Exception; | |||||
| RayVo getRayDetail(Long id) throws IOException; | |||||
| String delete(Long id); | |||||
| String runRayIns(Long id) throws Exception; | |||||
| } | |||||
| @@ -57,7 +57,7 @@ public class AutoMlInsServiceImpl implements AutoMlInsService { | |||||
| if (StringUtils.isEmpty(autoMlIns.getStatus())) { | if (StringUtils.isEmpty(autoMlIns.getStatus())) { | ||||
| autoMlIns = queryStatusFromArgo(autoMlIns); | autoMlIns = queryStatusFromArgo(autoMlIns); | ||||
| } | } | ||||
| if (StringUtils.equals(autoMlIns.getStatus(), "Running")) { | |||||
| if (StringUtils.equals(autoMlIns.getStatus(), Constant.Running)) { | |||||
| return "实验实例正在运行,不可删除"; | return "实验实例正在运行,不可删除"; | ||||
| } | } | ||||
| @@ -156,7 +156,7 @@ public class AutoMlInsServiceImpl implements AutoMlInsService { | |||||
| } | } | ||||
| @Override | @Override | ||||
| public boolean terminateAutoMlIns(Long id) { | |||||
| public boolean terminateAutoMlIns(Long id) throws Exception { | |||||
| AutoMlIns autoMlIns = autoMlInsDao.queryById(id); | AutoMlIns autoMlIns = autoMlInsDao.queryById(id); | ||||
| if (autoMlIns == null) { | if (autoMlIns == null) { | ||||
| throw new IllegalStateException("实验实例未查询到,id: " + id); | throw new IllegalStateException("实验实例未查询到,id: " + id); | ||||
| @@ -172,7 +172,7 @@ public class AutoMlInsServiceImpl implements AutoMlInsService { | |||||
| } | } | ||||
| // 只有状态是"Running"时才能终止实例 | // 只有状态是"Running"时才能终止实例 | ||||
| if (!currentStatus.equalsIgnoreCase(Constant.Running)) { | if (!currentStatus.equalsIgnoreCase(Constant.Running)) { | ||||
| return false; // 如果不是"Running"状态,则不执行终止操作 | |||||
| throw new Exception("终止错误,只有运行状态的实例才能终止"); // 如果不是"Running"状态,则不执行终止操作 | |||||
| } | } | ||||
| // 创建请求数据map | // 创建请求数据map | ||||
| @@ -102,13 +102,13 @@ public class AutoMlServiceImpl implements AutoMlService { | |||||
| public String delete(Long id) { | public String delete(Long id) { | ||||
| AutoMl autoMl = autoMlDao.getAutoMlById(id); | AutoMl autoMl = autoMlDao.getAutoMlById(id); | ||||
| if (autoMl == null) { | if (autoMl == null) { | ||||
| throw new RuntimeException("服务不存在"); | |||||
| throw new RuntimeException("实验不存在"); | |||||
| } | } | ||||
| String username = SecurityUtils.getLoginUser().getUsername(); | String username = SecurityUtils.getLoginUser().getUsername(); | ||||
| String createBy = autoMl.getCreateBy(); | String createBy = autoMl.getCreateBy(); | ||||
| if (!(StringUtils.equals(username, "admin") || StringUtils.equals(username, createBy))) { | if (!(StringUtils.equals(username, "admin") || StringUtils.equals(username, createBy))) { | ||||
| throw new RuntimeException("无权限删除该服务"); | |||||
| throw new RuntimeException("无权限删除该实验"); | |||||
| } | } | ||||
| autoMl.setState(Constant.State_invalid); | autoMl.setState(Constant.State_invalid); | ||||
| @@ -0,0 +1,293 @@ | |||||
| package com.ruoyi.platform.service.impl; | |||||
| import com.ruoyi.platform.constant.Constant; | |||||
| import com.ruoyi.platform.domain.Ray; | |||||
| import com.ruoyi.platform.domain.RayIns; | |||||
| import com.ruoyi.platform.mapper.RayDao; | |||||
| import com.ruoyi.platform.mapper.RayInsDao; | |||||
| import com.ruoyi.platform.service.RayInsService; | |||||
| import com.ruoyi.platform.utils.DateUtils; | |||||
| import com.ruoyi.platform.utils.HttpUtils; | |||||
| import com.ruoyi.platform.utils.JsonUtils; | |||||
| import org.apache.commons.lang3.StringUtils; | |||||
| import org.springframework.beans.factory.annotation.Value; | |||||
| import org.springframework.data.domain.Page; | |||||
| import org.springframework.data.domain.PageImpl; | |||||
| import org.springframework.data.domain.PageRequest; | |||||
| import org.springframework.stereotype.Service; | |||||
| import javax.annotation.Resource; | |||||
| import java.io.IOException; | |||||
| import java.nio.file.Files; | |||||
| import java.nio.file.Path; | |||||
| import java.nio.file.Paths; | |||||
| import java.util.*; | |||||
| import java.util.stream.Collectors; | |||||
| @Service("rayInsService") | |||||
| public class RayInsServiceImpl implements RayInsService { | |||||
| @Value("${argo.url}") | |||||
| private String argoUrl; | |||||
| @Value("${argo.workflowStatus}") | |||||
| private String argoWorkflowStatus; | |||||
| @Value("${argo.workflowTermination}") | |||||
| private String argoWorkflowTermination; | |||||
| @Resource | |||||
| private RayInsDao rayInsDao; | |||||
| @Resource | |||||
| private RayDao rayDao; | |||||
| @Override | |||||
| public Page<RayIns> queryByPage(Long rayId, PageRequest pageRequest) throws IOException { | |||||
| long total = this.rayInsDao.count(rayId); | |||||
| List<RayIns> rayInsList = this.rayInsDao.queryAllByLimit(rayId, pageRequest); | |||||
| return new PageImpl<>(rayInsList, pageRequest, total); | |||||
| } | |||||
| @Override | |||||
| public RayIns insert(RayIns rayIns) { | |||||
| this.rayInsDao.insert(rayIns); | |||||
| return rayIns; | |||||
| } | |||||
| @Override | |||||
| public String deleteById(Long id) { | |||||
| RayIns rayIns = rayInsDao.queryById(id); | |||||
| if (rayIns == null) { | |||||
| return "实验实例不存在"; | |||||
| } | |||||
| if (StringUtils.isEmpty(rayIns.getStatus())) { | |||||
| rayIns = queryStatusFromArgo(rayIns); | |||||
| } | |||||
| if (StringUtils.equals(rayIns.getStatus(), Constant.Running)) { | |||||
| return "实验实例正在运行,不可删除"; | |||||
| } | |||||
| rayIns.setState(Constant.State_invalid); | |||||
| int update = rayInsDao.update(rayIns); | |||||
| if (update > 0) { | |||||
| updateRayStatus(rayIns.getRayId()); | |||||
| return "删除成功"; | |||||
| } else { | |||||
| return "删除失败"; | |||||
| } | |||||
| } | |||||
| @Override | |||||
| public String batchDelete(List<Long> ids) { | |||||
| for (Long id : ids) { | |||||
| String result = deleteById(id); | |||||
| if (!"删除成功".equals(result)) { | |||||
| return result; | |||||
| } | |||||
| } | |||||
| return "删除成功"; | |||||
| } | |||||
| @Override | |||||
| public boolean terminateRayIns(Long id) throws Exception { | |||||
| RayIns rayIns = rayInsDao.queryById(id); | |||||
| if (rayIns == null) { | |||||
| throw new IllegalStateException("实验实例未查询到,id: " + id); | |||||
| } | |||||
| String currentStatus = rayIns.getStatus(); | |||||
| String name = rayIns.getArgoInsName(); | |||||
| String namespace = rayIns.getArgoInsNs(); | |||||
| // 获取当前状态,如果为空,则从Argo查询 | |||||
| if (StringUtils.isEmpty(currentStatus)) { | |||||
| currentStatus = queryStatusFromArgo(rayIns).getStatus(); | |||||
| } | |||||
| // 只有状态是"Running"时才能终止实例 | |||||
| if (!currentStatus.equalsIgnoreCase(Constant.Running)) { | |||||
| throw new Exception("终止错误,只有运行状态的实例才能终止"); // 如果不是"Running"状态,则不执行终止操作 | |||||
| } | |||||
| // 创建请求数据map | |||||
| Map<String, Object> requestData = new HashMap<>(); | |||||
| requestData.put("namespace", namespace); | |||||
| requestData.put("name", name); | |||||
| // 创建发送数据map,将请求数据作为"data"键的值 | |||||
| Map<String, Object> res = new HashMap<>(); | |||||
| res.put("data", requestData); | |||||
| try { | |||||
| // 发送POST请求到Argo工作流状态查询接口,并将请求数据转换为JSON | |||||
| String req = HttpUtils.sendPost(argoUrl + argoWorkflowTermination, null, JsonUtils.mapToJson(res)); | |||||
| // 检查响应是否为空或无内容 | |||||
| if (StringUtils.isEmpty(req)) { | |||||
| throw new RuntimeException("终止响应内容为空。"); | |||||
| } | |||||
| // 将响应的JSON字符串转换为Map对象 | |||||
| Map<String, Object> runResMap = JsonUtils.jsonToMap(req); | |||||
| // 从响应Map中直接获取"errCode"的值 | |||||
| Integer errCode = (Integer) runResMap.get("errCode"); | |||||
| if (errCode != null && errCode == 0) { | |||||
| //更新autoMlIns,确保状态更新被保存到数据库 | |||||
| RayIns ins = queryStatusFromArgo(rayIns); | |||||
| String nodeStatus = ins.getNodeStatus(); | |||||
| Map<String, Object> nodeMap = JsonUtils.jsonToMap(nodeStatus); | |||||
| // 遍历 map | |||||
| for (Map.Entry<String, Object> entry : nodeMap.entrySet()) { | |||||
| // 获取每个 Map 中的值并强制转换为 Map | |||||
| Map<String, Object> innerMap = (Map<String, Object>) entry.getValue(); | |||||
| // 检查 phase 的值 | |||||
| if (innerMap.containsKey("phase")) { | |||||
| String phaseValue = (String) innerMap.get("phase"); | |||||
| // 如果值不等于 Succeeded,则赋值为 Failed | |||||
| if (!StringUtils.equals("Succeeded", phaseValue)) { | |||||
| innerMap.put("phase", "Failed"); | |||||
| } | |||||
| } | |||||
| } | |||||
| ins.setNodeStatus(JsonUtils.mapToJson(nodeMap)); | |||||
| ins.setStatus(Constant.Terminated); | |||||
| ins.setUpdateTime(new Date()); | |||||
| rayInsDao.update(ins); | |||||
| updateRayStatus(rayIns.getRayId()); | |||||
| return true; | |||||
| } else { | |||||
| return false; | |||||
| } | |||||
| } catch (Exception e) { | |||||
| throw new RuntimeException("终止实例错误: " + e.getMessage(), e); | |||||
| } | |||||
| } | |||||
| @Override | |||||
| public RayIns getDetailById(Long id) throws IOException { | |||||
| RayIns rayIns = rayInsDao.queryById(id); | |||||
| if (Constant.Running.equals(rayIns.getStatus()) || Constant.Pending.equals(rayIns.getStatus())) { | |||||
| rayIns = queryStatusFromArgo(rayIns); | |||||
| } | |||||
| rayIns.setTrialList(getTrialList(rayIns.getResultPath())); | |||||
| return rayIns; | |||||
| } | |||||
| @Override | |||||
| public void updateRayStatus(Long rayId) { | |||||
| List<RayIns> insList = rayInsDao.getByRayId(rayId); | |||||
| List<String> statusList = new ArrayList<>(); | |||||
| // 更新实验状态列表 | |||||
| for (int i = 0; i < insList.size(); i++) { | |||||
| statusList.add(insList.get(i).getStatus()); | |||||
| } | |||||
| String subStatus = statusList.toString().substring(1, statusList.toString().length() - 1); | |||||
| Ray ray = rayDao.getRayById(rayId); | |||||
| if (!StringUtils.equals(ray.getStatusList(), subStatus)) { | |||||
| ray.setStatusList(subStatus); | |||||
| rayDao.edit(ray); | |||||
| } | |||||
| } | |||||
| @Override | |||||
| public RayIns queryStatusFromArgo(RayIns ins) { | |||||
| String namespace = ins.getArgoInsNs(); | |||||
| String name = ins.getArgoInsName(); | |||||
| // 创建请求数据map | |||||
| Map<String, Object> requestData = new HashMap<>(); | |||||
| requestData.put("namespace", namespace); | |||||
| requestData.put("name", name); | |||||
| // 创建发送数据map,将请求数据作为"data"键的值 | |||||
| Map<String, Object> res = new HashMap<>(); | |||||
| res.put("data", requestData); | |||||
| try { | |||||
| // 发送POST请求到Argo工作流状态查询接口,并将请求数据转换为JSON | |||||
| String req = HttpUtils.sendPost(argoUrl + argoWorkflowStatus, null, JsonUtils.mapToJson(res)); | |||||
| // 检查响应是否为空或无内容 | |||||
| if (req == null || StringUtils.isEmpty(req)) { | |||||
| throw new RuntimeException("工作流状态响应为空。"); | |||||
| } | |||||
| // 将响应的JSON字符串转换为Map对象 | |||||
| Map<String, Object> runResMap = JsonUtils.jsonToMap(req); | |||||
| // 从响应Map中获取"data"部分 | |||||
| Map<String, Object> data = (Map<String, Object>) runResMap.get("data"); | |||||
| if (data == null || data.isEmpty()) { | |||||
| throw new RuntimeException("工作流数据为空."); | |||||
| } | |||||
| // 从"data"中获取"status"部分,并返回"phase"的值 | |||||
| Map<String, Object> status = (Map<String, Object>) data.get("status"); | |||||
| if (status == null || status.isEmpty()) { | |||||
| throw new RuntimeException("工作流状态为空。"); | |||||
| } | |||||
| //解析流水线结束时间 | |||||
| String finishedAtString = (String) status.get("finishedAt"); | |||||
| if (finishedAtString != null && !finishedAtString.isEmpty()) { | |||||
| Date finishTime = DateUtils.convertUTCtoShanghaiDate(finishedAtString); | |||||
| ins.setFinishTime(finishTime); | |||||
| } | |||||
| // 解析nodes字段,提取节点状态并转换为JSON字符串 | |||||
| Map<String, Object> nodes = (Map<String, Object>) status.get("nodes"); | |||||
| Map<String, Object> modifiedNodes = new LinkedHashMap<>(); | |||||
| if (nodes != null) { | |||||
| for (Map.Entry<String, Object> nodeEntry : nodes.entrySet()) { | |||||
| Map<String, Object> nodeDetails = (Map<String, Object>) nodeEntry.getValue(); | |||||
| String templateName = (String) nodeDetails.get("displayName"); | |||||
| modifiedNodes.put(templateName, nodeDetails); | |||||
| } | |||||
| } | |||||
| String nodeStatusJson = JsonUtils.mapToJson(modifiedNodes); | |||||
| ins.setNodeStatus(nodeStatusJson); | |||||
| //终止态为终止不改 | |||||
| if (!StringUtils.equals(ins.getStatus(), Constant.Terminated)) { | |||||
| ins.setStatus(StringUtils.isNotEmpty((String) status.get("phase")) ? (String) status.get("phase") : Constant.Pending); | |||||
| } | |||||
| if (StringUtils.equals(ins.getStatus(), "Error")) { | |||||
| ins.setStatus(Constant.Failed); | |||||
| } | |||||
| return ins; | |||||
| } catch (Exception e) { | |||||
| throw new RuntimeException("查询状态失败: " + e.getMessage(), e); | |||||
| } | |||||
| } | |||||
| @Override | |||||
| public List<RayIns> queryByRayInsIsNotTerminated() { | |||||
| return rayInsDao.queryByRayInsIsNotTerminated(); | |||||
| } | |||||
| public ArrayList<Map<String, Object>> getTrialList(String directoryPath) throws IOException { | |||||
| // 获取指定路径下的所有文件 | |||||
| Path dirPath = Paths.get(directoryPath); | |||||
| Path experimentState = Files.list(dirPath).filter(path -> Files.isRegularFile(path) && path.getFileName().toString().startsWith("experiment_state")).collect(Collectors.toList()).get(0); | |||||
| String content = new String(Files.readAllBytes(experimentState)); | |||||
| Map<String, Object> result = JsonUtils.jsonToMap(content); | |||||
| ArrayList<ArrayList> trial_data_list = (ArrayList<ArrayList>) result.get("trial_data"); | |||||
| ArrayList<Map<String, Object>> trialList = new ArrayList<>(); | |||||
| for (ArrayList trial_data : trial_data_list) { | |||||
| Map<String, Object> trial_data_0 = JsonUtils.jsonToMap((String) trial_data.get(0)); | |||||
| Map<String, Object> trial_data_1 = JsonUtils.jsonToMap((String) trial_data.get(1)); | |||||
| Map<String, Object> trial = new HashMap<>(); | |||||
| trial.put("trial_id", trial_data_0.get("trial_id")); | |||||
| trial.put("config", trial_data_0.get("config")); | |||||
| trial.put("status", trial_data_0.get("status")); | |||||
| Map<String, Object> last_result = (Map<String, Object>) trial_data_1.get("last_result"); | |||||
| Map<String, Object> metric_analysis = (Map<String, Object>) trial_data_1.get("metric_analysis"); | |||||
| Map<String, Object> time_total_s = (Map<String, Object>) metric_analysis.get("time_total_s"); | |||||
| trial.put("training_iteration", last_result.get("training_iteration")); | |||||
| trial.put("time", time_total_s.get("avg")); | |||||
| trialList.add(trial); | |||||
| } | |||||
| return trialList; | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,210 @@ | |||||
| package com.ruoyi.platform.service.impl; | |||||
| import com.google.gson.Gson; | |||||
| import com.google.gson.reflect.TypeToken; | |||||
| import com.ruoyi.common.security.utils.SecurityUtils; | |||||
| import com.ruoyi.platform.constant.Constant; | |||||
| import com.ruoyi.platform.domain.Ray; | |||||
| import com.ruoyi.platform.domain.RayIns; | |||||
| import com.ruoyi.platform.mapper.RayDao; | |||||
| import com.ruoyi.platform.mapper.RayInsDao; | |||||
| import com.ruoyi.platform.service.RayInsService; | |||||
| import com.ruoyi.platform.service.RayService; | |||||
| import com.ruoyi.platform.utils.HttpUtils; | |||||
| import com.ruoyi.platform.utils.JacksonUtil; | |||||
| import com.ruoyi.platform.utils.JsonUtils; | |||||
| import com.ruoyi.platform.vo.RayParamVo; | |||||
| import com.ruoyi.platform.vo.RayVo; | |||||
| import org.apache.commons.collections4.MapUtils; | |||||
| import org.apache.commons.lang3.StringUtils; | |||||
| import org.springframework.beans.BeanUtils; | |||||
| import org.springframework.beans.factory.annotation.Value; | |||||
| import org.springframework.data.domain.Page; | |||||
| import org.springframework.data.domain.PageImpl; | |||||
| import org.springframework.data.domain.PageRequest; | |||||
| import org.springframework.stereotype.Service; | |||||
| import javax.annotation.Resource; | |||||
| import java.io.IOException; | |||||
| import java.lang.reflect.Type; | |||||
| import java.util.ArrayList; | |||||
| import java.util.HashMap; | |||||
| import java.util.List; | |||||
| import java.util.Map; | |||||
| @Service("rayService") | |||||
| public class RayServiceImpl implements RayService { | |||||
| @Value("${argo.url}") | |||||
| private String argoUrl; | |||||
| @Value("${argo.convertRay}") | |||||
| String convertRay; | |||||
| @Value("${argo.workflowRun}") | |||||
| private String argoWorkflowRun; | |||||
| @Value("${minio.endpoint}") | |||||
| private String minioEndpoint; | |||||
| @Resource | |||||
| private RayDao rayDao; | |||||
| @Resource | |||||
| private RayInsDao rayInsDao; | |||||
| @Resource | |||||
| private RayInsService rayInsService; | |||||
| @Override | |||||
| public Page<Ray> queryByPage(String name, PageRequest pageRequest) { | |||||
| long total = rayDao.count(name); | |||||
| List<Ray> rays = rayDao.queryByPage(name, pageRequest); | |||||
| return new PageImpl<>(rays, pageRequest, total); | |||||
| } | |||||
| @Override | |||||
| public Ray save(RayVo rayVo) throws Exception { | |||||
| Ray rayByName = rayDao.getRayByName(rayVo.getName()); | |||||
| if (rayByName != null) { | |||||
| throw new RuntimeException("实验名称已存在"); | |||||
| } | |||||
| Ray ray = new Ray(); | |||||
| BeanUtils.copyProperties(rayVo, ray); | |||||
| String username = SecurityUtils.getLoginUser().getUsername(); | |||||
| ray.setCreateBy(username); | |||||
| ray.setUpdateBy(username); | |||||
| ray.setDataset(JacksonUtil.toJSONString(rayVo.getDataset())); | |||||
| ray.setCodeConfig(JacksonUtil.toJSONString(rayVo.getCodeConfig())); | |||||
| ray.setModel(JacksonUtil.toJSONString(rayVo.getModel())); | |||||
| ray.setImage(JacksonUtil.toJSONString(rayVo.getImage())); | |||||
| ray.setParameters(JacksonUtil.toJSONString(rayVo.getParameters())); | |||||
| ray.setPointsToEvaluate(JacksonUtil.toJSONString(rayVo.getPointsToEvaluate())); | |||||
| rayDao.save(ray); | |||||
| return ray; | |||||
| } | |||||
| @Override | |||||
| public String edit(RayVo rayVo) throws Exception { | |||||
| Ray oldRay = rayDao.getRayByName(rayVo.getName()); | |||||
| if (oldRay != null && !oldRay.getId().equals(rayVo.getId())) { | |||||
| throw new RuntimeException("实验名称已存在"); | |||||
| } | |||||
| Ray ray = new Ray(); | |||||
| BeanUtils.copyProperties(rayVo, ray); | |||||
| ray.setUpdateBy(SecurityUtils.getLoginUser().getUsername()); | |||||
| ray.setParameters(JacksonUtil.toJSONString(rayVo.getParameters())); | |||||
| ray.setPointsToEvaluate(JacksonUtil.toJSONString(rayVo.getPointsToEvaluate())); | |||||
| ray.setDataset(JacksonUtil.toJSONString(rayVo.getDataset())); | |||||
| ray.setCodeConfig(JacksonUtil.toJSONString(rayVo.getCodeConfig())); | |||||
| ray.setModel(JacksonUtil.toJSONString(rayVo.getModel())); | |||||
| ray.setImage(JacksonUtil.toJSONString(rayVo.getImage())); | |||||
| rayDao.edit(ray); | |||||
| return "修改成功"; | |||||
| } | |||||
| @Override | |||||
| public RayVo getRayDetail(Long id) throws IOException { | |||||
| Ray ray = rayDao.getRayById(id); | |||||
| RayVo rayVo = new RayVo(); | |||||
| BeanUtils.copyProperties(ray, rayVo); | |||||
| Gson gson = new Gson(); | |||||
| Type listType = new TypeToken<List<Map<String, Object>>>() { | |||||
| }.getType(); | |||||
| if (StringUtils.isNotEmpty(ray.getParameters())) { | |||||
| rayVo.setParameters(gson.fromJson(ray.getParameters(), listType)); | |||||
| } | |||||
| if (StringUtils.isNotEmpty(ray.getPointsToEvaluate())) { | |||||
| rayVo.setPointsToEvaluate(gson.fromJson(ray.getPointsToEvaluate(), listType)); | |||||
| } | |||||
| if (StringUtils.isNotEmpty(ray.getDataset())) { | |||||
| rayVo.setDataset(JsonUtils.jsonToMap(ray.getDataset())); | |||||
| } | |||||
| if (StringUtils.isNotEmpty(ray.getCodeConfig())) { | |||||
| rayVo.setCodeConfig(JsonUtils.jsonToMap(ray.getCodeConfig())); | |||||
| } | |||||
| if (StringUtils.isNotEmpty(ray.getModel())) { | |||||
| rayVo.setModel(JsonUtils.jsonToMap(ray.getModel())); | |||||
| } | |||||
| if (StringUtils.isNotEmpty(ray.getImage())) { | |||||
| rayVo.setImage(JsonUtils.jsonToMap(ray.getImage())); | |||||
| } | |||||
| return rayVo; | |||||
| } | |||||
| @Override | |||||
| public String delete(Long id) { | |||||
| Ray ray = rayDao.getRayById(id); | |||||
| if (ray == null) { | |||||
| throw new RuntimeException("实验不存在"); | |||||
| } | |||||
| String username = SecurityUtils.getLoginUser().getUsername(); | |||||
| String createBy = ray.getCreateBy(); | |||||
| if (!(StringUtils.equals(username, "admin") || StringUtils.equals(username, createBy))) { | |||||
| throw new RuntimeException("无权限删除该实验"); | |||||
| } | |||||
| ray.setState(Constant.State_invalid); | |||||
| return rayDao.edit(ray) > 0 ? "删除成功" : "删除失败"; | |||||
| } | |||||
| @Override | |||||
| public String runRayIns(Long id) throws Exception { | |||||
| Ray ray = rayDao.getRayById(id); | |||||
| if (ray == null) { | |||||
| throw new Exception("自动超参数寻优配置不存在"); | |||||
| } | |||||
| RayParamVo rayParamVo = new RayParamVo(); | |||||
| BeanUtils.copyProperties(ray, rayParamVo); | |||||
| rayParamVo.setCodeConfig(JsonUtils.jsonToMap(ray.getCodeConfig())); | |||||
| rayParamVo.setDataset(JsonUtils.jsonToMap(ray.getDataset())); | |||||
| rayParamVo.setModel(JsonUtils.jsonToMap(ray.getModel())); | |||||
| rayParamVo.setImage(JsonUtils.jsonToMap(ray.getImage())); | |||||
| String param = JsonUtils.objectToJson(rayParamVo); | |||||
| // 调argo转换接口 | |||||
| try { | |||||
| String convertRes = HttpUtils.sendPost(argoUrl + convertRay, param); | |||||
| if (convertRes == null || StringUtils.isEmpty(convertRes)) { | |||||
| throw new RuntimeException("转换流水线失败"); | |||||
| } | |||||
| Map<String, Object> converMap = JsonUtils.jsonToMap(convertRes); | |||||
| // 组装运行接口json | |||||
| Map<String, Object> output = (Map<String, Object>) converMap.get("output"); | |||||
| Map<String, Object> runReqMap = new HashMap<>(); | |||||
| runReqMap.put("data", converMap.get("data")); | |||||
| // 调argo运行接口 | |||||
| String runRes = HttpUtils.sendPost(argoUrl + argoWorkflowRun, JsonUtils.mapToJson(runReqMap)); | |||||
| if (runRes == null || StringUtils.isEmpty(runRes)) { | |||||
| throw new RuntimeException("Failed to run workflow."); | |||||
| } | |||||
| Map<String, Object> runResMap = JsonUtils.jsonToMap(runRes); | |||||
| Map<String, Object> data = (Map<String, Object>) runResMap.get("data"); | |||||
| //判断data为空 | |||||
| if (data == null || MapUtils.isEmpty(data)) { | |||||
| throw new RuntimeException("Failed to run workflow."); | |||||
| } | |||||
| Map<String, Object> metadata = (Map<String, Object>) data.get("metadata"); | |||||
| // 插入记录到实验实例表 | |||||
| RayIns rayIns = new RayIns(); | |||||
| rayIns.setRayId(ray.getId()); | |||||
| rayIns.setArgoInsNs((String) metadata.get("namespace")); | |||||
| rayIns.setArgoInsName((String) metadata.get("name")); | |||||
| rayIns.setParam(param); | |||||
| rayIns.setStatus(Constant.Pending); | |||||
| //替换argoInsName | |||||
| String outputString = JsonUtils.mapToJson(output); | |||||
| rayIns.setNodeResult(outputString.replace("{{workflow.name}}", (String) metadata.get("name"))); | |||||
| Map<String, Object> param_output = (Map<String, Object>) output.get("param_output"); | |||||
| List output1 = (ArrayList) param_output.values().toArray()[0]; | |||||
| Map<String, String> output2 = (Map<String, String>) output1.get(0); | |||||
| String outputPath = minioEndpoint + "/" + output2.get("path").replace("{{workflow.name}}", (String) metadata.get("name")) + "/" + ray.getName(); | |||||
| rayIns.setResultPath(outputPath); | |||||
| rayInsDao.insert(rayIns); | |||||
| rayInsService.updateRayStatus(id); | |||||
| } catch (Exception e) { | |||||
| throw new RuntimeException(e); | |||||
| } | |||||
| return "执行成功"; | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,50 @@ | |||||
| package com.ruoyi.platform.vo; | |||||
| import com.fasterxml.jackson.annotation.JsonInclude; | |||||
| import com.fasterxml.jackson.databind.PropertyNamingStrategy; | |||||
| import com.fasterxml.jackson.databind.annotation.JsonNaming; | |||||
| import io.swagger.annotations.ApiModel; | |||||
| import lombok.Data; | |||||
| import java.util.Map; | |||||
| @Data | |||||
| @JsonNaming(PropertyNamingStrategy.SnakeCaseStrategy.class) | |||||
| @JsonInclude(JsonInclude.Include.NON_NULL) | |||||
| @ApiModel(description = "超参数寻优参数") | |||||
| public class RayParamVo { | |||||
| private Map<String,Object> codeConfig; | |||||
| private Map<String,Object> dataset; | |||||
| private Map<String,Object> image; | |||||
| private Map<String,Object> model; | |||||
| private String mainPy; | |||||
| private String name; | |||||
| private Integer numSamples; | |||||
| private String parameters; | |||||
| private String pointsToEvaluate; | |||||
| private String storagePath; | |||||
| private String searchAlg; | |||||
| private String scheduler; | |||||
| private String metric; | |||||
| private String mode; | |||||
| private Integer maxT; | |||||
| private Integer minSamplesRequired; | |||||
| private String resource; | |||||
| } | |||||
| @@ -0,0 +1,82 @@ | |||||
| package com.ruoyi.platform.vo; | |||||
| import com.fasterxml.jackson.databind.PropertyNamingStrategy; | |||||
| import com.fasterxml.jackson.databind.annotation.JsonNaming; | |||||
| import io.swagger.annotations.ApiModel; | |||||
| import io.swagger.annotations.ApiModelProperty; | |||||
| import lombok.Data; | |||||
| import java.util.Date; | |||||
| import java.util.List; | |||||
| import java.util.Map; | |||||
| @Data | |||||
| @JsonNaming(PropertyNamingStrategy.SnakeCaseStrategy.class) | |||||
| @ApiModel(description = "自动超参数寻优") | |||||
| public class RayVo { | |||||
| private Long id; | |||||
| @ApiModelProperty(value = "实验名称") | |||||
| private String name; | |||||
| @ApiModelProperty(value = "实验描述") | |||||
| private String description; | |||||
| @ApiModelProperty(value = "主函数代码文件") | |||||
| private String mainPy; | |||||
| @ApiModelProperty(value = "总实验次数") | |||||
| private Integer numSamples; | |||||
| @ApiModelProperty(value = "参数") | |||||
| private List<Map<String, Object>> parameters; | |||||
| @ApiModelProperty(value = "手动指定需要运行的参数") | |||||
| private List<Map<String, Object>> pointsToEvaluate; | |||||
| @ApiModelProperty(value = "保存路径") | |||||
| private String storagePath; | |||||
| @ApiModelProperty(value = "搜索算法") | |||||
| private String searchAlg; | |||||
| @ApiModelProperty(value = "调度算法") | |||||
| private String scheduler; | |||||
| @ApiModelProperty(value = "指标") | |||||
| private String metric; | |||||
| @ApiModelProperty(value = "指标最大化或最小化,min or max") | |||||
| private String mode; | |||||
| @ApiModelProperty(value = "单次试验最大时间:单位秒,调度算法为ASHA,HyperBand时传入,每次试验的最大时间单位。测试将在max_t时间单位后停止。") | |||||
| private Integer maxT; | |||||
| @ApiModelProperty(value = "计算中位数的最小试验数:调度算法为MedianStopping时传入,计算中位数的最小试验数。") | |||||
| private Integer minSamplesRequired; | |||||
| private String resource; | |||||
| private String createBy; | |||||
| private Date createTime; | |||||
| private String updateBy; | |||||
| private Date updateTime; | |||||
| private Integer state; | |||||
| private String runState; | |||||
| @ApiModelProperty(value = "代码") | |||||
| private Map<String, Object> codeConfig; | |||||
| private Map<String, Object> dataset; | |||||
| @ApiModelProperty(value = "模型") | |||||
| private Map<String, Object> model; | |||||
| @ApiModelProperty(value = "镜像") | |||||
| private Map<String, Object> image; | |||||
| } | |||||
| @@ -0,0 +1,115 @@ | |||||
| <?xml version="1.0" encoding="UTF-8"?> | |||||
| <!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd"> | |||||
| <mapper namespace="com.ruoyi.platform.mapper.RayDao"> | |||||
| <insert id="save"> | |||||
| insert into ray(name, description, dataset, model, code_config, main_py, num_samples, parameters, points_to_evaluate, storage_path, | |||||
| search_alg, scheduler, metric, mode, max_t, | |||||
| min_samples_required, resource, image, create_by, update_by) | |||||
| values (#{ray.name}, #{ray.description}, #{ray.dataset}, #{ray.model}, #{ray.codeConfig}, #{ray.mainPy}, #{ray.numSamples}, #{ray.parameters}, | |||||
| #{ray.pointsToEvaluate}, #{ray.storagePath}, | |||||
| #{ray.searchAlg}, #{ray.scheduler}, #{ray.metric}, #{ray.mode}, #{ray.maxT}, #{ray.minSamplesRequired}, | |||||
| #{ray.resource}, #{ray.image}, #{ray.createBy}, #{ray.updateBy}) | |||||
| </insert> | |||||
| <update id="edit"> | |||||
| update ray | |||||
| <set> | |||||
| <if test="ray.name != null and ray.name !=''"> | |||||
| name = #{ray.name}, | |||||
| </if> | |||||
| <if test="ray.description != null and ray.description !=''"> | |||||
| description = #{ray.description}, | |||||
| </if> | |||||
| <if test="ray.dataset != null and ray.dataset !=''"> | |||||
| dataset = #{ray.dataset}, | |||||
| </if> | |||||
| <if test="ray.model != null and ray.model !=''"> | |||||
| model = #{ray.model}, | |||||
| </if> | |||||
| <if test="ray.codeConfig != null and ray.codeConfig !=''"> | |||||
| code_config = #{ray.codeConfig}, | |||||
| </if> | |||||
| <if test="ray.image != null and ray.image !=''"> | |||||
| image = #{ray.image}, | |||||
| </if> | |||||
| <if test="ray.mainPy != null and ray.mainPy !=''"> | |||||
| main_py = #{ray.mainPy}, | |||||
| </if> | |||||
| <if test="ray.numSamples != null"> | |||||
| num_samples = #{ray.numSamples}, | |||||
| </if> | |||||
| <if test="ray.parameters != null and ray.parameters !=''"> | |||||
| parameters = #{ray.parameters}, | |||||
| </if> | |||||
| <if test="ray.pointsToEvaluate != null and ray.pointsToEvaluate !=''"> | |||||
| points_to_evaluate = #{ray.pointsToEvaluate}, | |||||
| </if> | |||||
| <if test="ray.storagePath != null and ray.storagePath !=''"> | |||||
| storage_path = #{ray.storagePath}, | |||||
| </if> | |||||
| <if test="ray.searchAlg != null and ray.searchAlg !=''"> | |||||
| search_alg = #{ray.searchAlg}, | |||||
| </if> | |||||
| <if test="ray.scheduler != null and ray.scheduler !=''"> | |||||
| scheduler = #{ray.scheduler}, | |||||
| </if> | |||||
| <if test="ray.metric != null and ray.metric !=''"> | |||||
| metric = #{ray.metric}, | |||||
| </if> | |||||
| <if test="ray.mode != null and ray.mode !=''"> | |||||
| mode = #{ray.mode}, | |||||
| </if> | |||||
| <if test="ray.maxT != null"> | |||||
| max_t = #{ray.maxT}, | |||||
| </if> | |||||
| <if test="ray.minSamplesRequired != null"> | |||||
| min_samples_required = #{ray.minSamplesRequired}, | |||||
| </if> | |||||
| <if test="ray.resource != null"> | |||||
| resource = #{ray.resource}, | |||||
| </if> | |||||
| <if test="ray.updateBy != null and ray.updateBy !=''"> | |||||
| update_by = #{ray.updateBy}, | |||||
| </if> | |||||
| <if test="ray.statusList != null and ray.statusList !=''"> | |||||
| status_list = #{ray.statusList}, | |||||
| </if> | |||||
| <if test="ray.state != null"> | |||||
| state = #{ray.state}, | |||||
| </if> | |||||
| </set> | |||||
| where id = #{ray.id} | |||||
| </update> | |||||
| <select id="count" resultType="java.lang.Long"> | |||||
| select count(1) from ray | |||||
| <include refid="common_condition"></include> | |||||
| </select> | |||||
| <select id="queryByPage" resultType="com.ruoyi.platform.domain.Ray"> | |||||
| select * from ray | |||||
| <include refid="common_condition"></include> | |||||
| </select> | |||||
| <select id="getRayByName" resultType="com.ruoyi.platform.domain.Ray"> | |||||
| select * | |||||
| from ray _ | |||||
| where name = #{name} | |||||
| and state = 1 | |||||
| </select> | |||||
| <select id="getRayById" resultType="com.ruoyi.platform.domain.Ray"> | |||||
| select * | |||||
| from ray | |||||
| where id = #{id} | |||||
| </select> | |||||
| <sql id="common_condition"> | |||||
| <where> | |||||
| state = 1 | |||||
| <if test="name != null and name != ''"> | |||||
| and name like concat('%', #{name}, '%') | |||||
| </if> | |||||
| </where> | |||||
| </sql> | |||||
| </mapper> | |||||
| @@ -0,0 +1,77 @@ | |||||
| <?xml version="1.0" encoding="UTF-8"?> | |||||
| <!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd"> | |||||
| <mapper namespace="com.ruoyi.platform.mapper.RayInsDao"> | |||||
| <insert id="insert"> | |||||
| insert into ray_ins(ray_id, result_path, argo_ins_name, argo_ins_ns, node_status, node_result, param, source, | |||||
| status) | |||||
| values (#{rayIns.rayId}, #{rayIns.resultPath}, #{rayIns.argoInsName}, #{rayIns.argoInsNs}, | |||||
| #{rayIns.nodeStatus}, #{rayIns.nodeResult}, #{rayIns.param}, #{rayIns.source}, #{rayIns.status}) | |||||
| </insert> | |||||
| <update id="update"> | |||||
| update ray_ins | |||||
| <set> | |||||
| <if test="rayIns.resultPath != null and rayIns.resultPath != ''"> | |||||
| result_path = #{rayIns.resultPath}, | |||||
| </if> | |||||
| <if test="rayIns.status != null and rayIns.status != ''"> | |||||
| status = #{rayIns.status}, | |||||
| </if> | |||||
| <if test="rayIns.nodeStatus != null and rayIns.nodeStatus != ''"> | |||||
| node_status = #{rayIns.nodeStatus}, | |||||
| </if> | |||||
| <if test="rayIns.nodeResult != null and rayIns.nodeResult != ''"> | |||||
| node_result = #{rayIns.nodeResult}, | |||||
| </if> | |||||
| <if test="rayIns.state != null"> | |||||
| state = #{rayIns.state}, | |||||
| </if> | |||||
| <if test="rayIns.finishTime != null"> | |||||
| finish_time = #{rayIns.finishTime}, | |||||
| </if> | |||||
| </set> | |||||
| where id = #{rayIns.id} | |||||
| </update> | |||||
| <select id="count" resultType="java.lang.Long"> | |||||
| select count(1) | |||||
| from ray_ins | |||||
| <where> | |||||
| state = 1 | |||||
| and ray_id = #{rayId} | |||||
| </where> | |||||
| </select> | |||||
| <select id="queryAllByLimit" resultType="com.ruoyi.platform.domain.RayIns"> | |||||
| select * from ray_ins | |||||
| <where> | |||||
| state = 1 | |||||
| and ray_id = #{rayId} | |||||
| </where> | |||||
| order by update_time DESC | |||||
| limit #{pageable.offset}, #{pageable.pageSize} | |||||
| </select> | |||||
| <select id="queryById" resultType="com.ruoyi.platform.domain.RayIns"> | |||||
| select * from ray_ins | |||||
| <where> | |||||
| state = 1 and id = #{id} | |||||
| </where> | |||||
| </select> | |||||
| <select id="getByRayId" resultType="com.ruoyi.platform.domain.RayIns"> | |||||
| select * | |||||
| from ray_ins | |||||
| where ray_id = #{rayId} | |||||
| and state = 1 | |||||
| order by update_time DESC limit 5 | |||||
| </select> | |||||
| <select id="queryByRayInsIsNotTerminated" resultType="com.ruoyi.platform.domain.RayIns"> | |||||
| select * | |||||
| from ray_ins | |||||
| where (status NOT IN ('Terminated', 'Succeeded', 'Failed') | |||||
| OR status IS NULL) | |||||
| and state = 1 | |||||
| </select> | |||||
| </mapper> | |||||