Browse Source

优化主动学习

dev-active_learn
chenzhihang 9 months ago
parent
commit
f93805383f
9 changed files with 49 additions and 33 deletions
  1. +4
    -4
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/domain/ActiveLearn.java
  2. +2
    -2
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/mapper/ResourceOccupyDao.java
  3. +1
    -1
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ActiveLearnServiceImpl.java
  4. +1
    -1
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AutoMlServiceImpl.java
  5. +1
    -1
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/RayServiceImpl.java
  6. +17
    -4
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/JsonUtils.java
  7. +4
    -4
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/ActiveLearnParamVo.java
  8. +7
    -4
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/ActiveLearnVo.java
  9. +12
    -12
      ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/ActiveLearnDaoMapper.xml

+ 4
- 4
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/domain/ActiveLearn.java View File

@@ -69,13 +69,13 @@ public class ActiveLearn {
private Integer trainSize;

@ApiModelProperty(value = "初始训练数据量")
private Integer nInitial;
private Integer initialNum;

@ApiModelProperty(value = "查询次数")
private Integer nQueries;
private Integer queriesNum;

@ApiModelProperty(value = "每次查询数据量")
private Integer nInstances;
private Integer instancesNum;

@ApiModelProperty(value = "查询策略:uncertainty_sampling, uncertainty_batch_sampling, max_std_sampling, expected_improvement, upper_confidence_bound")
private String queryStrategy;
@@ -87,7 +87,7 @@ public class ActiveLearn {
private String lossClassName;

@ApiModelProperty(value = "多少轮查询保存一次模型参数")
private Integer nCheckpoint;
private Integer checkpointNum;

@ApiModelProperty(value = "batch_size")
private Integer batchSize;


+ 2
- 2
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/mapper/ResourceOccupyDao.java View File

@@ -14,7 +14,7 @@ public interface ResourceOccupyDao {

int edit(@Param("resourceOccupy") ResourceOccupy resourceOccupy);

List<ResourceOccupy> getResourceOccupyByTask(@Param("taskType") String taskType, @Param("taskId") Long taskId, @Param("taskInsId") Long taskInsId, @Param("nodeId") String nodeId);
List<ResourceOccupy> getResourceOccupyByTask(@Param("taskType") String taskType, @Param("taskId") Long taskId, @Param("taskInsId") Long taskInsId, @Param("nodeId") String nodeId);

int deduceCredit(@Param("credit") Double credit, @Param("userId") Long userId);

@@ -30,5 +30,5 @@ public interface ResourceOccupyDao {

Double getDeduceCredit(@Param("userId") Long userId);

int deleteTaskState(String taskType, Long taskId, Long taskInsId);
int deleteTaskState(@Param("taskType") String taskType, @Param("taskId") Long taskId, @Param("taskInsId") Long taskInsId);
}

+ 1
- 1
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ActiveLearnServiceImpl.java View File

@@ -146,8 +146,8 @@ public class ActiveLearnServiceImpl implements ActiveLearnService {
activeLearnParamVo.setDataset(JsonUtils.jsonToMap(activeLearn.getDataset()));
activeLearnParamVo.setModel(JsonUtils.jsonToMap(activeLearn.getModel()));
activeLearnParamVo.setImage(JsonUtils.jsonToMap(activeLearn.getImage()));
String param = JsonUtils.objectToJson(activeLearnParamVo);

String param = JsonUtils.getConvertParam(activeLearnParamVo);
// 调argo转换接口
try {
String convertRes = HttpUtils.sendPost(argoUrl + convertActiveLearn, param);


+ 1
- 1
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AutoMlServiceImpl.java View File

@@ -157,7 +157,7 @@ public class AutoMlServiceImpl implements AutoMlService {
AutoMlParamVo autoMlParam = new AutoMlParamVo();
BeanUtils.copyProperties(autoMl, autoMlParam);
autoMlParam.setDataset(JsonUtils.jsonToMap(autoMl.getDataset()));
String param = JsonUtils.objectToJson(autoMlParam);
String param = JsonUtils.getConvertParam(autoMlParam);
// 调argo转换接口
try {
String convertRes = HttpUtils.sendPost(argoUrl + convertAutoML, param);


+ 1
- 1
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/RayServiceImpl.java View File

@@ -163,7 +163,7 @@ public class RayServiceImpl implements RayService {
rayParamVo.setDataset(JsonUtils.jsonToMap(ray.getDataset()));
rayParamVo.setModel(JsonUtils.jsonToMap(ray.getModel()));
rayParamVo.setImage(JsonUtils.jsonToMap(ray.getImage()));
String param = JsonUtils.objectToJson(rayParamVo);
String param = JsonUtils.getConvertParam(rayParamVo);

// 调argo转换接口
try {


+ 17
- 4
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/JsonUtils.java View File

@@ -1,11 +1,12 @@
package com.ruoyi.platform.utils;

import com.alibaba.fastjson2.JSON;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.json.JSONObject;
import com.ruoyi.common.security.utils.SecurityUtils;

import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;

public class JsonUtils {
@@ -33,9 +34,8 @@ public class JsonUtils {
}



// 将JSON字符串转换为扁平化的Map
public static Map<String, Object> flattenJson(String prefix, Map<String, Object> map) {
public static Map<String, Object> flattenJson(String prefix, Map<String, Object> map) {
Map<String, Object> flatMap = new HashMap<>();

for (Map.Entry<String, Object> entry : map.entrySet()) {
@@ -56,4 +56,17 @@ public class JsonUtils {
public static Map<String, Object> objectToMap(Object object) throws IOException {
return objectMapper.convertValue(object, Map.class);
}

public static String getConvertParam(Object object) throws JsonProcessingException {
HashMap<Object, Object> paramMap = new HashMap<>();
paramMap.put("data", JSON.parseObject(objectToJson(object), Map.class));

HashMap<Object, Object> userInfoMap = new HashMap<>();
userInfoMap.put("name", SecurityUtils.getLoginUser().getUsername());
userInfoMap.put("token", SecurityUtils.getLoginUser().getSysUser().getOriginPassword());
HashMap<Object, Object> extraInfoMap = new HashMap<>();
extraInfoMap.put("user_info", userInfoMap);
paramMap.put("extra_info", extraInfoMap);
return objectToJson(paramMap);
}
}

+ 4
- 4
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/ActiveLearnParamVo.java View File

@@ -43,11 +43,11 @@ public class ActiveLearnParamVo {

private Integer trainSize;

private Integer nInitial;
private Integer initialNum;

private Integer nQueries;
private Integer queriesNum;

private Integer nInstances;
private Integer instancesNum;

private String queryStrategy;

@@ -55,7 +55,7 @@ public class ActiveLearnParamVo {

private String lossClassName;

private Integer nCheckpoint;
private Integer checkpointNum;

private Integer batchSize;



+ 7
- 4
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/ActiveLearnVo.java View File

@@ -24,6 +24,9 @@ public class ActiveLearnVo {
@ApiModelProperty(value = "任务类型:classification或regression")
private String taskType;

@ApiModelProperty(value = "框架类型: sklearn, keras, pytorch")
private String frameworkType;

@ApiModelProperty(value = "代码")
private Map<String, Object> codeConfig;

@@ -67,13 +70,13 @@ public class ActiveLearnVo {
private Integer trainSize;

@ApiModelProperty(value = "初始训练数据量")
private Integer nInitial;
private Integer initialNum;

@ApiModelProperty(value = "查询次数")
private Integer nQueries;
private Integer queriesNum;

@ApiModelProperty(value = "每次查询数据量")
private Integer nInstances;
private Integer instancesNum;

@ApiModelProperty(value = "查询策略:uncertainty_sampling, uncertainty_batch_sampling, max_std_sampling, expected_improvement, upper_confidence_bound")
private String queryStrategy;
@@ -85,7 +88,7 @@ public class ActiveLearnVo {
private String lossClassName;

@ApiModelProperty(value = "多少轮查询保存一次模型参数")
private Integer nCheckpoint;
private Integer checkpointNum;

@ApiModelProperty(value = "batch_size")
private Integer batchSize;


+ 12
- 12
ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/ActiveLearnDaoMapper.xml View File

@@ -25,8 +25,8 @@
model, model_py, model_class_name,
classifier_alg, regressor_alg, dataset_py, dataset_class_name,
dataset, data_size, image, computing_resource_id, shuffle,
train_size, n_initial, n_queries, n_instances, query_strategy,
loss_py, loss_class_name, n_checkpoint, batch_size, epochs,
train_size, initial_num, queries_num, instances_num, query_strategy,
loss_py, loss_class_name, checkpoint_num, batch_size, epochs,
lr, create_by, update_by)
values (#{activeLearn.name}, #{activeLearn.description}, #{activeLearn.taskType}, #{activeLearn.frameworkType}, #{activeLearn.codeConfig},
#{activeLearn.model},
@@ -36,9 +36,9 @@
#{activeLearn.dataset}, #{activeLearn.dataSize},
#{activeLearn.image},
#{activeLearn.computingResourceId}, #{activeLearn.shuffle},
#{activeLearn.trainSize}, #{activeLearn.nInitial}, #{activeLearn.nQueries}, #{activeLearn.nInstances},
#{activeLearn.trainSize}, #{activeLearn.initialNum}, #{activeLearn.queriesNum}, #{activeLearn.instancesNum},
#{activeLearn.queryStrategy}, #{activeLearn.lossPy},
#{activeLearn.lossClassName}, #{activeLearn.nCheckpoint}, #{activeLearn.batchSize},
#{activeLearn.lossClassName}, #{activeLearn.checkpointNum}, #{activeLearn.batchSize},
#{activeLearn.epochs}, #{activeLearn.lr}, #{activeLearn.createBy}, #{activeLearn.updateBy})
</insert>

@@ -99,14 +99,14 @@
<if test="activeLearn.trainSize != null">
train_size = #{activeLearn.trainSize},
</if>
<if test="activeLearn.nInitial != null">
n_initial = #{activeLearn.nInitial},
<if test="activeLearn.initialNum != null">
initial_num = #{activeLearn.initialNum},
</if>
<if test="activeLearn.nQueries != null">
n_queries = #{activeLearn.nQueries},
<if test="activeLearn.queriesNum != null">
queries_num = #{activeLearn.queriesNum},
</if>
<if test="activeLearn.nInstances != null">
n_instances = #{activeLearn.nInstances},
<if test="activeLearn.instancesNum != null">
instances_num = #{activeLearn.instancesNum},
</if>
<if test="activeLearn.queryStrategy != null and activeLearn.queryStrategy !=''">
query_strategy = #{activeLearn.queryStrategy},
@@ -117,8 +117,8 @@
<if test="activeLearn.lossClassName != null and activeLearn.lossClassName !=''">
loss_class_name = #{activeLearn.lossClassName},
</if>
<if test="activeLearn.nCheckpoint != null">
n_checkpoint = #{activeLearn.nCheckpoint},
<if test="activeLearn.checkpointNum != null">
checkpoint_num = #{activeLearn.checkpointNum},
</if>
<if test="activeLearn.batchSize != null">
batch_size = #{activeLearn.batchSize},


Loading…
Cancel
Save