diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/MachineLearnServiceImpl.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/MachineLearnServiceImpl.java index 32d7dd07..1ac75c1b 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/MachineLearnServiceImpl.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/MachineLearnServiceImpl.java @@ -12,6 +12,7 @@ import com.ruoyi.platform.utils.HttpUtils; import com.ruoyi.platform.utils.JsonUtils; import com.ruoyi.platform.vo.AutoMlParamVo; import com.ruoyi.platform.vo.TextClassificationParamVo; +import com.ruoyi.platform.vo.VideoClassificationParamVo; import com.ruoyi.system.api.constant.Constant; import org.apache.commons.collections4.MapUtils; import org.apache.commons.lang3.StringUtils; @@ -37,6 +38,8 @@ public class MachineLearnServiceImpl implements MachineLearnService { String convertAutoML; @Value("${argo.convertTextClassification}") String convertTextClassification; + @Value("${argo.convertVideoClassification}") + String convertVideoClassification; @Value("${argo.workflowRun}") private String argoWorkflowRun; @Value("${minio.endpointIp}") @@ -143,7 +146,12 @@ public class MachineLearnServiceImpl implements MachineLearnService { break; } case Constant.ML_VideoClassification: { - // todo + VideoClassificationParamVo paramVo = JsonUtils.jsonToObject(machineLearn.getParam(), VideoClassificationParamVo.class); + computingResourceId = paramVo.getComputingResourceId(); + if (resourceOccupyService.haveResource(computingResourceId, 1)) { + param = JsonUtils.getConvertParam(paramVo); + convertRes = HttpUtils.sendPost(argoUrl + convertVideoClassification, param); + } break; } } @@ -205,6 +213,7 @@ public class MachineLearnServiceImpl implements MachineLearnService { break; } case Constant.ML_VideoClassification: { + machineLearnIns.setResultPath(outputPath); break; } } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/VideoClassificationParamVo.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/VideoClassificationParamVo.java new file mode 100644 index 00000000..38d5c21e --- /dev/null +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/VideoClassificationParamVo.java @@ -0,0 +1,50 @@ +package com.ruoyi.platform.vo; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import io.swagger.annotations.ApiModel; +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; + +import java.util.Map; + +@Data +@JsonNaming(PropertyNamingStrategy.SnakeCaseStrategy.class) +@JsonInclude(JsonInclude.Include.NON_NULL) +@ApiModel(description = "视频分类参数") +public class VideoClassificationParamVo { + + @ApiModelProperty(value = "类别数量") + private Integer numClasses; + + @ApiModelProperty(value = "数据集") + private Map dataset; + + @ApiModelProperty(value = "epochs") + private Integer epochs; + + @ApiModelProperty(value = "batch_size") + private Integer batchSize; + + @ApiModelProperty(value = "学习率") + private Float lr; + + @ApiModelProperty(value = "是否验证") + private Boolean isValidate; + + @ApiModelProperty(value = "训练集路径") + private String trainDataPrefix; + + @ApiModelProperty(value = "验证集路径") + private String validDataPrefix; + + @ApiModelProperty(value = "训练集标注文件") + private String trainFilePath; + + @ApiModelProperty(value = "验证集标注文件") + private String validFilePath; + + private Integer computingResourceId; + +}