From 2f9a18bf1c833dedef3942dde7326c5b10ba9115 Mon Sep 17 00:00:00 2001 From: chenzhihang <709011834@qq.com> Date: Wed, 23 Apr 2025 14:26:04 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=87=E6=9C=AC=E5=88=86=E7=B1=BB=E5=8A=9F?= =?UTF-8?q?=E8=83=BD=E5=BC=80=E5=8F=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../platform/domain/TextClassification.java | 2 + .../scheduling/AutoMlInsStatusTask.java | 4 +- .../scheduling/TextClassificationInsTask.java | 116 ++++++++++++++++++ .../service/impl/NewDatasetServiceImpl.java | 14 +++ .../TextClassificationInsServiceImpl.java | 8 ++ .../impl/TextClassificationServiceImpl.java | 104 +++++++++------- .../TextClassificationDaoMapper.xml | 7 +- .../TextClassificationInsDaoMapper.xml | 2 +- 8 files changed, 206 insertions(+), 51 deletions(-) create mode 100644 ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/TextClassificationInsTask.java diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/domain/TextClassification.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/domain/TextClassification.java index f6a239c5..fd11eedc 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/domain/TextClassification.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/domain/TextClassification.java @@ -26,6 +26,8 @@ public class TextClassification { @ApiModelProperty(value = "数据集") private String dataset; + private Integer computingResourceId; + @ApiModelProperty(value = "epochs") private Integer epochs; diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/AutoMlInsStatusTask.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/AutoMlInsStatusTask.java index 0c71cb7d..c59d67fc 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/AutoMlInsStatusTask.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/AutoMlInsStatusTask.java @@ -30,7 +30,7 @@ public class AutoMlInsStatusTask { private List autoMlIds = new ArrayList<>(); @Scheduled(cron = "0/30 * * * * ?") // 每30S执行一次 - public void executeAutoMlInsStatus() throws Exception { + public void executeAutoMlInsStatus() { // 首先查到所有非终止态的实验实例 List autoMlInsList = autoMlInsService.queryByAutoMlInsIsNotTerminated(); @@ -59,7 +59,7 @@ public class AutoMlInsStatusTask { } @Scheduled(cron = "0/30 * * * * ?") // / 每30S执行一次 - public void executeAutoMlStatus() throws Exception { + public void executeAutoMlStatus() { if (autoMlIds.size() == 0) { return; } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/TextClassificationInsTask.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/TextClassificationInsTask.java new file mode 100644 index 00000000..5499c868 --- /dev/null +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/TextClassificationInsTask.java @@ -0,0 +1,116 @@ +package com.ruoyi.platform.scheduling; + +import com.ruoyi.platform.domain.TextClassification; +import com.ruoyi.platform.domain.TextClassificationIns; +import com.ruoyi.platform.mapper.ResourceOccupyDao; +import com.ruoyi.platform.mapper.TextClassificationDao; +import com.ruoyi.platform.mapper.TextClassificationInsDao; +import com.ruoyi.platform.service.ResourceOccupyService; +import com.ruoyi.platform.service.TextClassificationInsService; +import com.ruoyi.system.api.constant.Constant; +import org.apache.commons.lang3.StringUtils; +import org.springframework.scheduling.annotation.Scheduled; +import org.springframework.stereotype.Component; + +import javax.annotation.Resource; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +@Component() +public class TextClassificationInsTask { + @Resource + private TextClassificationInsService textClassificationInsService; + @Resource + private TextClassificationInsDao textClassificationInsDao; + @Resource + private TextClassificationDao textClassificationDao; + @Resource + private ResourceOccupyDao resourceOccupyDao; + @Resource + private ResourceOccupyService resourceOccupyService; + + private List textClassificationIds = new ArrayList<>(); + + @Scheduled(cron = "0/30 * * * * ?") // 每30S执行一次 + public void executeTextClassificationInsStatus() { + // 首先查到所有非终止态的实验实例 + List insList = textClassificationInsService.queryByNotTerminated(); + + // 去argo查询状态 + List updateList = new ArrayList<>(); + if (insList != null && insList.size() > 0) { + for (TextClassificationIns ins : insList) { + //当原本状态为null或非终止态时才调用argo接口 + try { + Long userId = resourceOccupyDao.getResourceOccupyByTask(Constant.TaskType_TextClassification, ins.getTextClassificationId(), ins.getId(), null).get(0).getUserId(); + if (resourceOccupyDao.getUserCredit(userId) <= 0) { + ins.setStatus(Constant.Failed); + textClassificationInsService.terminateTextClassificationIns(ins.getId()); + } else { + ins = textClassificationInsService.queryStatusFromArgo(ins); + // 扣除积分 + if (Constant.Running.equals(ins.getStatus())) { + resourceOccupyService.deducing(Constant.TaskType_TextClassification, null, ins.getId(), null, null); + } else if (Constant.Failed.equals(ins.getStatus()) || Constant.Terminated.equals(ins.getStatus()) + || Constant.Succeeded.equals(ins.getStatus())) { + resourceOccupyService.endDeduce(Constant.TaskType_TextClassification, null, ins.getId(), null, null); + } + } + } catch (Exception e) { + ins.setStatus(Constant.Failed); + // 结束扣除积分 + resourceOccupyService.endDeduce(Constant.TaskType_TextClassification, null, ins.getId(), null, null); + } + // 线程安全的添加操作 + synchronized (textClassificationIds) { + textClassificationIds.add(ins.getTextClassificationId()); + } + updateList.add(ins); + } + if (updateList.size() > 0) { + for (TextClassificationIns ins : updateList) { + textClassificationInsDao.update(ins); + } + } + } + } + + @Scheduled(cron = "0/30 * * * * ?") // / 每30S执行一次 + public void executeTextClassificationStatus() { + if (textClassificationIds.size() == 0) { + return; + } + // 存储需要更新的实验对象列表 + List updateTextClassifications = new ArrayList<>(); + for (Long textClassificationId : textClassificationIds) { + // 获取当前实验的所有实例列表 + List insList = textClassificationInsDao.getByTextClassificationId(textClassificationId); + List statusList = new ArrayList<>(); + // 更新实验状态列表 + for (int i = 0; i < insList.size(); i++) { + statusList.add(insList.get(i).getStatus()); + } + String subStatus = statusList.toString().substring(1, statusList.toString().length() - 1); + TextClassification textClassification = textClassificationDao.getById(textClassificationId); + if (!StringUtils.equals(textClassification.getStatusList(), subStatus)) { + textClassification.setStatusList(subStatus); + updateTextClassifications.add(textClassification); + textClassificationDao.edit(textClassification); + } + } + + if (!updateTextClassifications.isEmpty()) { + // 使用Iterator进行安全的删除操作 + Iterator iterator = textClassificationIds.iterator(); + while (iterator.hasNext()) { + Long textClassificationId = iterator.next(); + for (TextClassification textClassification : updateTextClassifications) { + if (textClassification.getId().equals(textClassificationId)) { + iterator.remove(); + } + } + } + } + } +} diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/NewDatasetServiceImpl.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/NewDatasetServiceImpl.java index caeffbbb..5432f17b 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/NewDatasetServiceImpl.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/NewDatasetServiceImpl.java @@ -53,6 +53,8 @@ public class NewDatasetServiceImpl implements NewDatasetService { @Resource private AutoMlDao autoMlDao; @Resource + private TextClassificationDao textClassificationDao; + @Resource private RayDao rayDao; @Resource private ActiveLearnDao activeLearnDao; @@ -426,6 +428,12 @@ public class NewDatasetServiceImpl implements NewDatasetService { throw new Exception("该数据集被自动机器学习:" + autoMls + "使用,不能删除,请先删除自动机器学习"); } + List textClassificationList = textClassificationDao.queryByDatasetId(JSON.toJSONString(map)); + if (textClassificationList != null && !textClassificationList.isEmpty()) { + String textClassifications = String.join(",", textClassificationList.stream().map(TextClassification::getName).collect(Collectors.toSet())); + throw new Exception("该数据集被自动机器学习文本分类:" + textClassifications + "使用,不能删除,请先删除自动机器学习文本分类"); + } + List rayList = rayDao.queryByDatasetId(JSON.toJSONString(map)); if (rayList != null && !rayList.isEmpty()) { String rays = String.join(",", rayList.stream().map(Ray::getName).collect(Collectors.toSet())); @@ -469,6 +477,12 @@ public class NewDatasetServiceImpl implements NewDatasetService { throw new Exception("该数据集版本被自动机器学习:" + autoMls + "使用,不能删除,请先删除自动机器学习"); } + List textClassificationList = textClassificationDao.queryByDatasetId(JSON.toJSONString(map)); + if (textClassificationList != null && !textClassificationList.isEmpty()) { + String textClassifications = String.join(",", textClassificationList.stream().map(TextClassification::getName).collect(Collectors.toSet())); + throw new Exception("该数据集版本被自动机器学习文本分类:" + textClassifications + "使用,不能删除,请先删除自动机器学习文本分类"); + } + List rayList = rayDao.queryByDatasetId(JSON.toJSONString(map)); if (rayList != null && !rayList.isEmpty()) { String rays = String.join(",", rayList.stream().map(Ray::getName).collect(Collectors.toSet())); diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/TextClassificationInsServiceImpl.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/TextClassificationInsServiceImpl.java index 6dd52394..afdd779f 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/TextClassificationInsServiceImpl.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/TextClassificationInsServiceImpl.java @@ -4,6 +4,7 @@ import com.ruoyi.platform.domain.TextClassification; import com.ruoyi.platform.domain.TextClassificationIns; import com.ruoyi.platform.mapper.TextClassificationDao; import com.ruoyi.platform.mapper.TextClassificationInsDao; +import com.ruoyi.platform.service.ResourceOccupyService; import com.ruoyi.platform.service.TextClassificationInsService; import com.ruoyi.platform.utils.DateUtils; import com.ruoyi.platform.utils.HttpUtils; @@ -15,6 +16,7 @@ import org.springframework.data.domain.Page; import org.springframework.data.domain.PageImpl; import org.springframework.data.domain.PageRequest; import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; import javax.annotation.Resource; import java.util.*; @@ -32,6 +34,8 @@ public class TextClassificationInsServiceImpl implements TextClassificationInsSe private TextClassificationDao textClassificationDao; @Resource private TextClassificationInsDao textClassificationInsDao; + @Resource + private ResourceOccupyService resourceOccupyService; @Override public Page queryByPage(TextClassificationIns textClassificationIns, PageRequest pageRequest) { @@ -47,6 +51,7 @@ public class TextClassificationInsServiceImpl implements TextClassificationInsSe } @Override + @Transactional public String removeById(Long id) { TextClassificationIns textClassificationIns = textClassificationInsDao.queryById(id); if (textClassificationIns == null) { @@ -61,6 +66,7 @@ public class TextClassificationInsServiceImpl implements TextClassificationInsSe textClassificationIns.setState(Constant.State_invalid); int update = textClassificationInsDao.update(textClassificationIns); if (update > 0) { + resourceOccupyService.deleteTaskState(Constant.TaskType_TextClassification, textClassificationIns.getTextClassificationId(), id); updateTextClassificationStatus(textClassificationIns.getTextClassificationId()); return "删除成功"; } else { @@ -143,6 +149,8 @@ public class TextClassificationInsServiceImpl implements TextClassificationInsSe ins.setFinishTime(new Date()); this.textClassificationInsDao.update(ins); updateTextClassificationStatus(textClassificationIns.getTextClassificationId()); + // 结束扣积分 + resourceOccupyService.endDeduce(Constant.TaskType_TextClassification, null, id, null, null); return true; } else { return false; diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/TextClassificationServiceImpl.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/TextClassificationServiceImpl.java index 3de94587..3c4f34e5 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/TextClassificationServiceImpl.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/TextClassificationServiceImpl.java @@ -5,6 +5,7 @@ import com.ruoyi.platform.domain.TextClassification; import com.ruoyi.platform.domain.TextClassificationIns; import com.ruoyi.platform.mapper.TextClassificationDao; import com.ruoyi.platform.mapper.TextClassificationInsDao; +import com.ruoyi.platform.service.ResourceOccupyService; import com.ruoyi.platform.service.TextClassificationInsService; import com.ruoyi.platform.service.TextClassificationService; import com.ruoyi.platform.utils.HttpUtils; @@ -21,6 +22,7 @@ import org.springframework.data.domain.Page; import org.springframework.data.domain.PageImpl; import org.springframework.data.domain.PageRequest; import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; import javax.annotation.Resource; import java.io.IOException; @@ -48,6 +50,8 @@ public class TextClassificationServiceImpl implements TextClassificationService private TextClassificationInsDao textClassificationInsDao; @Resource private TextClassificationInsService textClassificationInsService; + @Resource + private ResourceOccupyService resourceOccupyService; @Override public Page queryByPage(String name, PageRequest pageRequest) { @@ -94,6 +98,7 @@ public class TextClassificationServiceImpl implements TextClassificationService } @Override + @Transactional public String delete(Long id) { TextClassification textClassification = textClassificationDao.getById(id); if (textClassification == null) { @@ -105,6 +110,7 @@ public class TextClassificationServiceImpl implements TextClassificationService throw new RuntimeException("无权限删除该实验"); } textClassification.setState(Constant.State_invalid); + resourceOccupyService.deleteTaskState(Constant.TaskType_TextClassification, id, null); return textClassificationDao.edit(textClassification) > 0 ? "删除成功" : "删除失败"; } @@ -125,54 +131,60 @@ public class TextClassificationServiceImpl implements TextClassificationService if (textClassification == null) { throw new Exception("文本分类配置不存在"); } - TextClassificationParamVo paramVo = new TextClassificationParamVo(); - BeanUtils.copyProperties(textClassification, paramVo); - paramVo.setDataset(JsonUtils.jsonToMap(textClassification.getDataset())); - String param = JsonUtils.objectToJson(paramVo); - // 调argo转换接口 - try { - String convertRes = HttpUtils.sendPost(argoUrl + convertTextClassification, param); - if (convertRes == null || StringUtils.isEmpty(convertRes)) { - throw new RuntimeException("转换流水线失败"); - } - Map converMap = JsonUtils.jsonToMap(convertRes); - // 组装运行接口json - Map output = (Map) converMap.get("output"); - Map runReqMap = new HashMap<>(); - runReqMap.put("data", converMap.get("data")); - // 调argo运行接口 - String runRes = HttpUtils.sendPost(argoUrl + argoWorkflowRun, JsonUtils.mapToJson(runReqMap)); - if (runRes == null || StringUtils.isEmpty(runRes)) { - throw new RuntimeException("运行流水线失败"); - } - Map runResMap = JsonUtils.jsonToMap(runRes); - Map data = (Map) runResMap.get("data"); - //判断data为空 - if (data == null || MapUtils.isEmpty(data)) { - throw new RuntimeException("运行流水线失败"); - } - Map metadata = (Map) data.get("metadata"); - // 插入记录到实验实例表 - TextClassificationIns textClassificationIns = new TextClassificationIns(); - textClassificationIns.setTextClassificationId(id); - textClassificationIns.setArgoInsNs((String) metadata.get("namespace")); - textClassificationIns.setArgoInsName((String) metadata.get("name")); - textClassificationIns.setParam(param); - textClassificationIns.setStatus(Constant.Pending); - //替换argoInsName - String outputString = JsonUtils.mapToJson(output); - textClassificationIns.setNodeResult(outputString.replace("{{workflow.name}}", (String) metadata.get("name"))); + // 记录开始扣积分 + if (resourceOccupyService.haveResource(textClassification.getComputingResourceId(), 1)) { + TextClassificationParamVo paramVo = new TextClassificationParamVo(); + BeanUtils.copyProperties(textClassification, paramVo); + paramVo.setDataset(JsonUtils.jsonToMap(textClassification.getDataset())); + String param = JsonUtils.objectToJson(paramVo); + // 调argo转换接口 + try { + String convertRes = HttpUtils.sendPost(argoUrl + convertTextClassification, param); + if (convertRes == null || StringUtils.isEmpty(convertRes)) { + throw new RuntimeException("转换流水线失败"); + } + Map converMap = JsonUtils.jsonToMap(convertRes); + // 组装运行接口json + Map output = (Map) converMap.get("output"); + Map runReqMap = new HashMap<>(); + runReqMap.put("data", converMap.get("data")); + // 调argo运行接口 + String runRes = HttpUtils.sendPost(argoUrl + argoWorkflowRun, JsonUtils.mapToJson(runReqMap)); + + if (runRes == null || StringUtils.isEmpty(runRes)) { + throw new RuntimeException("运行流水线失败"); + } + Map runResMap = JsonUtils.jsonToMap(runRes); + Map data = (Map) runResMap.get("data"); + //判断data为空 + if (data == null || MapUtils.isEmpty(data)) { + throw new RuntimeException("运行流水线失败"); + } + Map metadata = (Map) data.get("metadata"); + // 插入记录到实验实例表 + TextClassificationIns textClassificationIns = new TextClassificationIns(); + textClassificationIns.setTextClassificationId(id); + textClassificationIns.setArgoInsNs((String) metadata.get("namespace")); + textClassificationIns.setArgoInsName((String) metadata.get("name")); + textClassificationIns.setParam(param); + textClassificationIns.setStatus(Constant.Pending); + //替换argoInsName + String outputString = JsonUtils.mapToJson(output); + textClassificationIns.setNodeResult(outputString.replace("{{workflow.name}}", (String) metadata.get("name"))); - Map param_output = (Map) output.get("param_output"); - List output1 = (ArrayList) param_output.values().toArray()[0]; - Map output2 = (Map) output1.get(0); - String outputPath = minioEndpoint + "/" + output2.get("path").replace("{{workflow.name}}", (String) metadata.get("name")) + "/"; - textClassificationIns.setModelPath(outputPath + "/saved_dict/" + textClassification.getModel() + ".ckpt"); - textClassificationInsDao.insert(textClassificationIns); - textClassificationInsService.updateTextClassificationStatus(id); - } catch (Exception e) { - throw new RuntimeException(e); + Map param_output = (Map) output.get("param_output"); + List output1 = (ArrayList) param_output.values().toArray()[0]; + Map output2 = (Map) output1.get(0); + String outputPath = minioEndpoint + "/" + output2.get("path").replace("{{workflow.name}}", (String) metadata.get("name")) + "/"; + textClassificationIns.setModelPath(outputPath + "/saved_dict/" + textClassification.getModel() + ".ckpt"); + textClassificationInsDao.insert(textClassificationIns); + textClassificationInsService.updateTextClassificationStatus(id); + // 记录开始扣除积分 + resourceOccupyService.startDeduce(textClassification.getComputingResourceId(), 1, Constant.TaskType_TextClassification, id, textClassificationIns.getId(), null, textClassification.getName(), null, null); + } catch (Exception e) { + throw new RuntimeException(e); + } } return "执行成功"; } diff --git a/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/TextClassificationDaoMapper.xml b/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/TextClassificationDaoMapper.xml index b600dec9..5f0f3b0c 100644 --- a/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/TextClassificationDaoMapper.xml +++ b/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/TextClassificationDaoMapper.xml @@ -26,10 +26,10 @@ - insert into text_classification(name, description, model, dataset, epochs, batch_size, lr, create_by, update_by) + insert into text_classification(name, description, model, dataset, epochs, batch_size, lr, create_by, update_by, computing_resource_id) values (#{textClassification.name}, #{textClassification.description}, #{textClassification.model}, #{textClassification.dataset}, #{textClassification.epochs}, #{textClassification.batchSize}, - #{textClassification.lr}, #{textClassification.createBy}, #{textClassification.updateBy},) + #{textClassification.lr}, #{textClassification.createBy}, #{textClassification.updateBy}, #{textClassification.computingResourceId}) @@ -47,6 +47,9 @@ dataset = #{textClassification.dataset}, + + computing_resource_id = #{textClassification.computingResourceId}, + epochs = #{textClassification.epochs}, diff --git a/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/TextClassificationInsDaoMapper.xml b/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/TextClassificationInsDaoMapper.xml index fd481dae..79254f76 100644 --- a/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/TextClassificationInsDaoMapper.xml +++ b/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/TextClassificationInsDaoMapper.xml @@ -75,7 +75,7 @@ where id = #{textClassificationIns.id} - select * from text_classification_ins where (status NOT IN ('Terminated', 'Succeeded', 'Failed')