| @@ -26,6 +26,8 @@ public class TextClassification { | |||
| @ApiModelProperty(value = "数据集") | |||
| private String dataset; | |||
| private Integer computingResourceId; | |||
| @ApiModelProperty(value = "epochs") | |||
| private Integer epochs; | |||
| @@ -30,7 +30,7 @@ public class AutoMlInsStatusTask { | |||
| private List<Long> autoMlIds = new ArrayList<>(); | |||
| @Scheduled(cron = "0/30 * * * * ?") // 每30S执行一次 | |||
| public void executeAutoMlInsStatus() throws Exception { | |||
| public void executeAutoMlInsStatus() { | |||
| // 首先查到所有非终止态的实验实例 | |||
| List<AutoMlIns> autoMlInsList = autoMlInsService.queryByAutoMlInsIsNotTerminated(); | |||
| @@ -59,7 +59,7 @@ public class AutoMlInsStatusTask { | |||
| } | |||
| @Scheduled(cron = "0/30 * * * * ?") // / 每30S执行一次 | |||
| public void executeAutoMlStatus() throws Exception { | |||
| public void executeAutoMlStatus() { | |||
| if (autoMlIds.size() == 0) { | |||
| return; | |||
| } | |||
| @@ -0,0 +1,116 @@ | |||
| package com.ruoyi.platform.scheduling; | |||
| import com.ruoyi.platform.domain.TextClassification; | |||
| import com.ruoyi.platform.domain.TextClassificationIns; | |||
| import com.ruoyi.platform.mapper.ResourceOccupyDao; | |||
| import com.ruoyi.platform.mapper.TextClassificationDao; | |||
| import com.ruoyi.platform.mapper.TextClassificationInsDao; | |||
| import com.ruoyi.platform.service.ResourceOccupyService; | |||
| import com.ruoyi.platform.service.TextClassificationInsService; | |||
| import com.ruoyi.system.api.constant.Constant; | |||
| import org.apache.commons.lang3.StringUtils; | |||
| import org.springframework.scheduling.annotation.Scheduled; | |||
| import org.springframework.stereotype.Component; | |||
| import javax.annotation.Resource; | |||
| import java.util.ArrayList; | |||
| import java.util.Iterator; | |||
| import java.util.List; | |||
| @Component() | |||
| public class TextClassificationInsTask { | |||
| @Resource | |||
| private TextClassificationInsService textClassificationInsService; | |||
| @Resource | |||
| private TextClassificationInsDao textClassificationInsDao; | |||
| @Resource | |||
| private TextClassificationDao textClassificationDao; | |||
| @Resource | |||
| private ResourceOccupyDao resourceOccupyDao; | |||
| @Resource | |||
| private ResourceOccupyService resourceOccupyService; | |||
| private List<Long> textClassificationIds = new ArrayList<>(); | |||
| @Scheduled(cron = "0/30 * * * * ?") // 每30S执行一次 | |||
| public void executeTextClassificationInsStatus() { | |||
| // 首先查到所有非终止态的实验实例 | |||
| List<TextClassificationIns> insList = textClassificationInsService.queryByNotTerminated(); | |||
| // 去argo查询状态 | |||
| List<TextClassificationIns> updateList = new ArrayList<>(); | |||
| if (insList != null && insList.size() > 0) { | |||
| for (TextClassificationIns ins : insList) { | |||
| //当原本状态为null或非终止态时才调用argo接口 | |||
| try { | |||
| Long userId = resourceOccupyDao.getResourceOccupyByTask(Constant.TaskType_TextClassification, ins.getTextClassificationId(), ins.getId(), null).get(0).getUserId(); | |||
| if (resourceOccupyDao.getUserCredit(userId) <= 0) { | |||
| ins.setStatus(Constant.Failed); | |||
| textClassificationInsService.terminateTextClassificationIns(ins.getId()); | |||
| } else { | |||
| ins = textClassificationInsService.queryStatusFromArgo(ins); | |||
| // 扣除积分 | |||
| if (Constant.Running.equals(ins.getStatus())) { | |||
| resourceOccupyService.deducing(Constant.TaskType_TextClassification, null, ins.getId(), null, null); | |||
| } else if (Constant.Failed.equals(ins.getStatus()) || Constant.Terminated.equals(ins.getStatus()) | |||
| || Constant.Succeeded.equals(ins.getStatus())) { | |||
| resourceOccupyService.endDeduce(Constant.TaskType_TextClassification, null, ins.getId(), null, null); | |||
| } | |||
| } | |||
| } catch (Exception e) { | |||
| ins.setStatus(Constant.Failed); | |||
| // 结束扣除积分 | |||
| resourceOccupyService.endDeduce(Constant.TaskType_TextClassification, null, ins.getId(), null, null); | |||
| } | |||
| // 线程安全的添加操作 | |||
| synchronized (textClassificationIds) { | |||
| textClassificationIds.add(ins.getTextClassificationId()); | |||
| } | |||
| updateList.add(ins); | |||
| } | |||
| if (updateList.size() > 0) { | |||
| for (TextClassificationIns ins : updateList) { | |||
| textClassificationInsDao.update(ins); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @Scheduled(cron = "0/30 * * * * ?") // / 每30S执行一次 | |||
| public void executeTextClassificationStatus() { | |||
| if (textClassificationIds.size() == 0) { | |||
| return; | |||
| } | |||
| // 存储需要更新的实验对象列表 | |||
| List<TextClassification> updateTextClassifications = new ArrayList<>(); | |||
| for (Long textClassificationId : textClassificationIds) { | |||
| // 获取当前实验的所有实例列表 | |||
| List<TextClassificationIns> insList = textClassificationInsDao.getByTextClassificationId(textClassificationId); | |||
| List<String> statusList = new ArrayList<>(); | |||
| // 更新实验状态列表 | |||
| for (int i = 0; i < insList.size(); i++) { | |||
| statusList.add(insList.get(i).getStatus()); | |||
| } | |||
| String subStatus = statusList.toString().substring(1, statusList.toString().length() - 1); | |||
| TextClassification textClassification = textClassificationDao.getById(textClassificationId); | |||
| if (!StringUtils.equals(textClassification.getStatusList(), subStatus)) { | |||
| textClassification.setStatusList(subStatus); | |||
| updateTextClassifications.add(textClassification); | |||
| textClassificationDao.edit(textClassification); | |||
| } | |||
| } | |||
| if (!updateTextClassifications.isEmpty()) { | |||
| // 使用Iterator进行安全的删除操作 | |||
| Iterator<Long> iterator = textClassificationIds.iterator(); | |||
| while (iterator.hasNext()) { | |||
| Long textClassificationId = iterator.next(); | |||
| for (TextClassification textClassification : updateTextClassifications) { | |||
| if (textClassification.getId().equals(textClassificationId)) { | |||
| iterator.remove(); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -53,6 +53,8 @@ public class NewDatasetServiceImpl implements NewDatasetService { | |||
| @Resource | |||
| private AutoMlDao autoMlDao; | |||
| @Resource | |||
| private TextClassificationDao textClassificationDao; | |||
| @Resource | |||
| private RayDao rayDao; | |||
| @Resource | |||
| private ActiveLearnDao activeLearnDao; | |||
| @@ -426,6 +428,12 @@ public class NewDatasetServiceImpl implements NewDatasetService { | |||
| throw new Exception("该数据集被自动机器学习:" + autoMls + "使用,不能删除,请先删除自动机器学习"); | |||
| } | |||
| List<TextClassification> textClassificationList = textClassificationDao.queryByDatasetId(JSON.toJSONString(map)); | |||
| if (textClassificationList != null && !textClassificationList.isEmpty()) { | |||
| String textClassifications = String.join(",", textClassificationList.stream().map(TextClassification::getName).collect(Collectors.toSet())); | |||
| throw new Exception("该数据集被自动机器学习文本分类:" + textClassifications + "使用,不能删除,请先删除自动机器学习文本分类"); | |||
| } | |||
| List<Ray> rayList = rayDao.queryByDatasetId(JSON.toJSONString(map)); | |||
| if (rayList != null && !rayList.isEmpty()) { | |||
| String rays = String.join(",", rayList.stream().map(Ray::getName).collect(Collectors.toSet())); | |||
| @@ -469,6 +477,12 @@ public class NewDatasetServiceImpl implements NewDatasetService { | |||
| throw new Exception("该数据集版本被自动机器学习:" + autoMls + "使用,不能删除,请先删除自动机器学习"); | |||
| } | |||
| List<TextClassification> textClassificationList = textClassificationDao.queryByDatasetId(JSON.toJSONString(map)); | |||
| if (textClassificationList != null && !textClassificationList.isEmpty()) { | |||
| String textClassifications = String.join(",", textClassificationList.stream().map(TextClassification::getName).collect(Collectors.toSet())); | |||
| throw new Exception("该数据集版本被自动机器学习文本分类:" + textClassifications + "使用,不能删除,请先删除自动机器学习文本分类"); | |||
| } | |||
| List<Ray> rayList = rayDao.queryByDatasetId(JSON.toJSONString(map)); | |||
| if (rayList != null && !rayList.isEmpty()) { | |||
| String rays = String.join(",", rayList.stream().map(Ray::getName).collect(Collectors.toSet())); | |||
| @@ -4,6 +4,7 @@ import com.ruoyi.platform.domain.TextClassification; | |||
| import com.ruoyi.platform.domain.TextClassificationIns; | |||
| import com.ruoyi.platform.mapper.TextClassificationDao; | |||
| import com.ruoyi.platform.mapper.TextClassificationInsDao; | |||
| import com.ruoyi.platform.service.ResourceOccupyService; | |||
| import com.ruoyi.platform.service.TextClassificationInsService; | |||
| import com.ruoyi.platform.utils.DateUtils; | |||
| import com.ruoyi.platform.utils.HttpUtils; | |||
| @@ -15,6 +16,7 @@ import org.springframework.data.domain.Page; | |||
| import org.springframework.data.domain.PageImpl; | |||
| import org.springframework.data.domain.PageRequest; | |||
| import org.springframework.stereotype.Service; | |||
| import org.springframework.transaction.annotation.Transactional; | |||
| import javax.annotation.Resource; | |||
| import java.util.*; | |||
| @@ -32,6 +34,8 @@ public class TextClassificationInsServiceImpl implements TextClassificationInsSe | |||
| private TextClassificationDao textClassificationDao; | |||
| @Resource | |||
| private TextClassificationInsDao textClassificationInsDao; | |||
| @Resource | |||
| private ResourceOccupyService resourceOccupyService; | |||
| @Override | |||
| public Page<TextClassificationIns> queryByPage(TextClassificationIns textClassificationIns, PageRequest pageRequest) { | |||
| @@ -47,6 +51,7 @@ public class TextClassificationInsServiceImpl implements TextClassificationInsSe | |||
| } | |||
| @Override | |||
| @Transactional | |||
| public String removeById(Long id) { | |||
| TextClassificationIns textClassificationIns = textClassificationInsDao.queryById(id); | |||
| if (textClassificationIns == null) { | |||
| @@ -61,6 +66,7 @@ public class TextClassificationInsServiceImpl implements TextClassificationInsSe | |||
| textClassificationIns.setState(Constant.State_invalid); | |||
| int update = textClassificationInsDao.update(textClassificationIns); | |||
| if (update > 0) { | |||
| resourceOccupyService.deleteTaskState(Constant.TaskType_TextClassification, textClassificationIns.getTextClassificationId(), id); | |||
| updateTextClassificationStatus(textClassificationIns.getTextClassificationId()); | |||
| return "删除成功"; | |||
| } else { | |||
| @@ -143,6 +149,8 @@ public class TextClassificationInsServiceImpl implements TextClassificationInsSe | |||
| ins.setFinishTime(new Date()); | |||
| this.textClassificationInsDao.update(ins); | |||
| updateTextClassificationStatus(textClassificationIns.getTextClassificationId()); | |||
| // 结束扣积分 | |||
| resourceOccupyService.endDeduce(Constant.TaskType_TextClassification, null, id, null, null); | |||
| return true; | |||
| } else { | |||
| return false; | |||
| @@ -5,6 +5,7 @@ import com.ruoyi.platform.domain.TextClassification; | |||
| import com.ruoyi.platform.domain.TextClassificationIns; | |||
| import com.ruoyi.platform.mapper.TextClassificationDao; | |||
| import com.ruoyi.platform.mapper.TextClassificationInsDao; | |||
| import com.ruoyi.platform.service.ResourceOccupyService; | |||
| import com.ruoyi.platform.service.TextClassificationInsService; | |||
| import com.ruoyi.platform.service.TextClassificationService; | |||
| import com.ruoyi.platform.utils.HttpUtils; | |||
| @@ -21,6 +22,7 @@ import org.springframework.data.domain.Page; | |||
| import org.springframework.data.domain.PageImpl; | |||
| import org.springframework.data.domain.PageRequest; | |||
| import org.springframework.stereotype.Service; | |||
| import org.springframework.transaction.annotation.Transactional; | |||
| import javax.annotation.Resource; | |||
| import java.io.IOException; | |||
| @@ -48,6 +50,8 @@ public class TextClassificationServiceImpl implements TextClassificationService | |||
| private TextClassificationInsDao textClassificationInsDao; | |||
| @Resource | |||
| private TextClassificationInsService textClassificationInsService; | |||
| @Resource | |||
| private ResourceOccupyService resourceOccupyService; | |||
| @Override | |||
| public Page<TextClassification> queryByPage(String name, PageRequest pageRequest) { | |||
| @@ -94,6 +98,7 @@ public class TextClassificationServiceImpl implements TextClassificationService | |||
| } | |||
| @Override | |||
| @Transactional | |||
| public String delete(Long id) { | |||
| TextClassification textClassification = textClassificationDao.getById(id); | |||
| if (textClassification == null) { | |||
| @@ -105,6 +110,7 @@ public class TextClassificationServiceImpl implements TextClassificationService | |||
| throw new RuntimeException("无权限删除该实验"); | |||
| } | |||
| textClassification.setState(Constant.State_invalid); | |||
| resourceOccupyService.deleteTaskState(Constant.TaskType_TextClassification, id, null); | |||
| return textClassificationDao.edit(textClassification) > 0 ? "删除成功" : "删除失败"; | |||
| } | |||
| @@ -125,54 +131,60 @@ public class TextClassificationServiceImpl implements TextClassificationService | |||
| if (textClassification == null) { | |||
| throw new Exception("文本分类配置不存在"); | |||
| } | |||
| TextClassificationParamVo paramVo = new TextClassificationParamVo(); | |||
| BeanUtils.copyProperties(textClassification, paramVo); | |||
| paramVo.setDataset(JsonUtils.jsonToMap(textClassification.getDataset())); | |||
| String param = JsonUtils.objectToJson(paramVo); | |||
| // 调argo转换接口 | |||
| try { | |||
| String convertRes = HttpUtils.sendPost(argoUrl + convertTextClassification, param); | |||
| if (convertRes == null || StringUtils.isEmpty(convertRes)) { | |||
| throw new RuntimeException("转换流水线失败"); | |||
| } | |||
| Map<String, Object> converMap = JsonUtils.jsonToMap(convertRes); | |||
| // 组装运行接口json | |||
| Map<String, Object> output = (Map<String, Object>) converMap.get("output"); | |||
| Map<String, Object> runReqMap = new HashMap<>(); | |||
| runReqMap.put("data", converMap.get("data")); | |||
| // 调argo运行接口 | |||
| String runRes = HttpUtils.sendPost(argoUrl + argoWorkflowRun, JsonUtils.mapToJson(runReqMap)); | |||
| if (runRes == null || StringUtils.isEmpty(runRes)) { | |||
| throw new RuntimeException("运行流水线失败"); | |||
| } | |||
| Map<String, Object> runResMap = JsonUtils.jsonToMap(runRes); | |||
| Map<String, Object> data = (Map<String, Object>) runResMap.get("data"); | |||
| //判断data为空 | |||
| if (data == null || MapUtils.isEmpty(data)) { | |||
| throw new RuntimeException("运行流水线失败"); | |||
| } | |||
| Map<String, Object> metadata = (Map<String, Object>) data.get("metadata"); | |||
| // 插入记录到实验实例表 | |||
| TextClassificationIns textClassificationIns = new TextClassificationIns(); | |||
| textClassificationIns.setTextClassificationId(id); | |||
| textClassificationIns.setArgoInsNs((String) metadata.get("namespace")); | |||
| textClassificationIns.setArgoInsName((String) metadata.get("name")); | |||
| textClassificationIns.setParam(param); | |||
| textClassificationIns.setStatus(Constant.Pending); | |||
| //替换argoInsName | |||
| String outputString = JsonUtils.mapToJson(output); | |||
| textClassificationIns.setNodeResult(outputString.replace("{{workflow.name}}", (String) metadata.get("name"))); | |||
| // 记录开始扣积分 | |||
| if (resourceOccupyService.haveResource(textClassification.getComputingResourceId(), 1)) { | |||
| TextClassificationParamVo paramVo = new TextClassificationParamVo(); | |||
| BeanUtils.copyProperties(textClassification, paramVo); | |||
| paramVo.setDataset(JsonUtils.jsonToMap(textClassification.getDataset())); | |||
| String param = JsonUtils.objectToJson(paramVo); | |||
| // 调argo转换接口 | |||
| try { | |||
| String convertRes = HttpUtils.sendPost(argoUrl + convertTextClassification, param); | |||
| if (convertRes == null || StringUtils.isEmpty(convertRes)) { | |||
| throw new RuntimeException("转换流水线失败"); | |||
| } | |||
| Map<String, Object> converMap = JsonUtils.jsonToMap(convertRes); | |||
| // 组装运行接口json | |||
| Map<String, Object> output = (Map<String, Object>) converMap.get("output"); | |||
| Map<String, Object> runReqMap = new HashMap<>(); | |||
| runReqMap.put("data", converMap.get("data")); | |||
| // 调argo运行接口 | |||
| String runRes = HttpUtils.sendPost(argoUrl + argoWorkflowRun, JsonUtils.mapToJson(runReqMap)); | |||
| if (runRes == null || StringUtils.isEmpty(runRes)) { | |||
| throw new RuntimeException("运行流水线失败"); | |||
| } | |||
| Map<String, Object> runResMap = JsonUtils.jsonToMap(runRes); | |||
| Map<String, Object> data = (Map<String, Object>) runResMap.get("data"); | |||
| //判断data为空 | |||
| if (data == null || MapUtils.isEmpty(data)) { | |||
| throw new RuntimeException("运行流水线失败"); | |||
| } | |||
| Map<String, Object> metadata = (Map<String, Object>) data.get("metadata"); | |||
| // 插入记录到实验实例表 | |||
| TextClassificationIns textClassificationIns = new TextClassificationIns(); | |||
| textClassificationIns.setTextClassificationId(id); | |||
| textClassificationIns.setArgoInsNs((String) metadata.get("namespace")); | |||
| textClassificationIns.setArgoInsName((String) metadata.get("name")); | |||
| textClassificationIns.setParam(param); | |||
| textClassificationIns.setStatus(Constant.Pending); | |||
| //替换argoInsName | |||
| String outputString = JsonUtils.mapToJson(output); | |||
| textClassificationIns.setNodeResult(outputString.replace("{{workflow.name}}", (String) metadata.get("name"))); | |||
| Map<String, Object> param_output = (Map<String, Object>) output.get("param_output"); | |||
| List output1 = (ArrayList) param_output.values().toArray()[0]; | |||
| Map<String, String> output2 = (Map<String, String>) output1.get(0); | |||
| String outputPath = minioEndpoint + "/" + output2.get("path").replace("{{workflow.name}}", (String) metadata.get("name")) + "/"; | |||
| textClassificationIns.setModelPath(outputPath + "/saved_dict/" + textClassification.getModel() + ".ckpt"); | |||
| textClassificationInsDao.insert(textClassificationIns); | |||
| textClassificationInsService.updateTextClassificationStatus(id); | |||
| } catch (Exception e) { | |||
| throw new RuntimeException(e); | |||
| Map<String, Object> param_output = (Map<String, Object>) output.get("param_output"); | |||
| List output1 = (ArrayList) param_output.values().toArray()[0]; | |||
| Map<String, String> output2 = (Map<String, String>) output1.get(0); | |||
| String outputPath = minioEndpoint + "/" + output2.get("path").replace("{{workflow.name}}", (String) metadata.get("name")) + "/"; | |||
| textClassificationIns.setModelPath(outputPath + "/saved_dict/" + textClassification.getModel() + ".ckpt"); | |||
| textClassificationInsDao.insert(textClassificationIns); | |||
| textClassificationInsService.updateTextClassificationStatus(id); | |||
| // 记录开始扣除积分 | |||
| resourceOccupyService.startDeduce(textClassification.getComputingResourceId(), 1, Constant.TaskType_TextClassification, id, textClassificationIns.getId(), null, textClassification.getName(), null, null); | |||
| } catch (Exception e) { | |||
| throw new RuntimeException(e); | |||
| } | |||
| } | |||
| return "执行成功"; | |||
| } | |||
| @@ -26,10 +26,10 @@ | |||
| </select> | |||
| <insert id="save" keyProperty="id" useGeneratedKeys="true"> | |||
| insert into text_classification(name, description, model, dataset, epochs, batch_size, lr, create_by, update_by) | |||
| insert into text_classification(name, description, model, dataset, epochs, batch_size, lr, create_by, update_by, computing_resource_id) | |||
| values (#{textClassification.name}, #{textClassification.description}, #{textClassification.model}, | |||
| #{textClassification.dataset}, #{textClassification.epochs}, #{textClassification.batchSize}, | |||
| #{textClassification.lr}, #{textClassification.createBy}, #{textClassification.updateBy},) | |||
| #{textClassification.lr}, #{textClassification.createBy}, #{textClassification.updateBy}, #{textClassification.computingResourceId}) | |||
| </insert> | |||
| <update id="edit"> | |||
| @@ -47,6 +47,9 @@ | |||
| <if test="textClassification.dataset != null and textClassification.dataset !=''"> | |||
| dataset = #{textClassification.dataset}, | |||
| </if> | |||
| <if test="textClassification.computingResourceId != null"> | |||
| computing_resource_id = #{textClassification.computingResourceId}, | |||
| </if> | |||
| <if test="textClassification.epochs != null"> | |||
| epochs = #{textClassification.epochs}, | |||
| </if> | |||
| @@ -75,7 +75,7 @@ | |||
| where id = #{textClassificationIns.id} | |||
| </update> | |||
| <select id="queryByAutoMlInsIsNotTerminated" resultType="com.ruoyi.platform.domain.TextClassificationIns"> | |||
| <select id="queryByNotTerminated" resultType="com.ruoyi.platform.domain.TextClassificationIns"> | |||
| select * | |||
| from text_classification_ins | |||
| where (status NOT IN ('Terminated', 'Succeeded', 'Failed') | |||