diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/domain/ActiveLearn.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/domain/ActiveLearn.java index 4bee4f88..7c0f51c7 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/domain/ActiveLearn.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/domain/ActiveLearn.java @@ -20,47 +20,68 @@ public class ActiveLearn { @ApiModelProperty(value = "实验描述") private String description; - @ApiModelProperty(value = "数据集") - private String dataset; + @ApiModelProperty(value = "任务类型:classification, regression") + private String taskType; + + @ApiModelProperty(value = "模型文件路径") + private String modelPath; - @ApiModelProperty(value = "数据集csv文件中哪几列是预测目标列,逗号分隔") - private String targetColumns; + @ApiModelProperty(value = "模型类名称") + private String modelClassName; @ApiModelProperty(value = "分类算法") - private String classifierType; + private String classifierAlg; + + @ApiModelProperty(value = "回归算法") + private String regressorAlg; + + @ApiModelProperty(value = "dataset文件路径") + private String dataset; - @ApiModelProperty(value = "每次查询个数") - private Integer queryBatchSize; + @ApiModelProperty(value = "dataset类名") + private String datasetClassName; - @ApiModelProperty(value = "停止判则") - private String stoppingCriterion; + @ApiModelProperty(value = "数据集文件路径") + private String datasetPath; - @ApiModelProperty(value = "stopping_criterion为num_of_queries时传入,查询次数") - private Integer numOfQueries; + @ApiModelProperty(value = "数据量") + private Integer dataSize; - @ApiModelProperty(value = "stopping_criterion为cost_limit时传入,成本限制") - private Double costLimit; + @ApiModelProperty(value = "是否随机打乱") + private Boolean shuffle; - @ApiModelProperty(value = "stopping_criterion为percent_of_unlabel时传入,未标记比例") - private Double percentOfUnlabel; + @ApiModelProperty(value = "训练集数据量") + private Integer trainSize; - @ApiModelProperty(value = "stopping_criterion为time_limit时传入,时间限制") - private Double timeLimit; + @ApiModelProperty(value = "初始训练数据量") + private Integer nInitial; - @ApiModelProperty(value = "查询策略") + @ApiModelProperty(value = "查询次数") + private Integer nQueries; + + @ApiModelProperty(value = "每次查询数据量") + private Integer nInstances; + + @ApiModelProperty(value = "查询策略:uncertainty_sampling, uncertainty_batch_sampling, max_std_sampling, expected_improvement, upper_confidence_bound") private String queryStrategy; - @ApiModelProperty(value = "实验次数") - private Integer numOfExperiment; + @ApiModelProperty(value = "loss文件路径") + private String lossPath; + + @ApiModelProperty(value = "loss类名") + private String lossClassName; + + @ApiModelProperty(value = "多少轮查询保存一次模型参数") + private Integer nCheckpoint; - @ApiModelProperty(value = "测试集比率") - private Double testRatio; + @ApiModelProperty(value = "batch_size") + private Integer batchSize; - @ApiModelProperty(value = "初始使用标记数据比率") - private Double initialLabelRate; + @ApiModelProperty(value = "epochs") + private Integer epochs; - @ApiModelProperty(value = "指标") - private String performanceMetric; + @ApiModelProperty(value = "学习率") + private Float lr; private Integer state; diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ActiveLearnServiceImpl.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ActiveLearnServiceImpl.java index 3bcdee48..48ba0232 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ActiveLearnServiceImpl.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ActiveLearnServiceImpl.java @@ -42,8 +42,16 @@ public class ActiveLearnServiceImpl implements ActiveLearnService { String username = SecurityUtils.getLoginUser().getUsername(); activeLearn.setCreateBy(username); activeLearn.setUpdateBy(username); + String datasetJson = JacksonUtil.toJSONString(activeLearnVo.getDataset()); activeLearn.setDataset(datasetJson); + String modelJson = JacksonUtil.toJSONString(activeLearnVo.getModelPath()); + activeLearn.setModelPath(modelJson); + String datasetPathJson = JacksonUtil.toJSONString(activeLearnVo.getDatasetPath()); + activeLearn.setDatasetPath(datasetPathJson); + String lossPathJson = JacksonUtil.toJSONString(activeLearnVo.getLossPath()); + activeLearn.setLossPath(lossPathJson); + activeLearnDao.save(activeLearn); return activeLearn; } @@ -59,8 +67,16 @@ public class ActiveLearnServiceImpl implements ActiveLearnService { BeanUtils.copyProperties(activeLearnVo, activeLearn); String username = SecurityUtils.getLoginUser().getUsername(); activeLearn.setUpdateBy(username); + String datasetJson = JacksonUtil.toJSONString(activeLearnVo.getDataset()); activeLearn.setDataset(datasetJson); + String modelJson = JacksonUtil.toJSONString(activeLearnVo.getModelPath()); + activeLearn.setModelPath(modelJson); + String datasetPathJson = JacksonUtil.toJSONString(activeLearnVo.getDatasetPath()); + activeLearn.setDatasetPath(datasetPathJson); + String lossPathJson = JacksonUtil.toJSONString(activeLearnVo.getLossPath()); + activeLearn.setLossPath(lossPathJson); + activeLearnDao.edit(activeLearn); return "修改成功"; @@ -71,9 +87,18 @@ public class ActiveLearnServiceImpl implements ActiveLearnService { ActiveLearn activeLearn = activeLearnDao.getActiveLearnById(id); ActiveLearnVo activeLearnVo = new ActiveLearnVo(); BeanUtils.copyProperties(activeLearn, activeLearnVo); + if (StringUtils.isNotEmpty(activeLearn.getDataset())) { + activeLearnVo.setDatasetPath(JsonUtils.jsonToMap(activeLearn.getDatasetPath())); + } + if (StringUtils.isNotEmpty(activeLearn.getModelPath())) { + activeLearnVo.setModelPath(JsonUtils.jsonToMap(activeLearn.getModelPath())); + } if (StringUtils.isNotEmpty(activeLearn.getDataset())) { activeLearnVo.setDataset(JsonUtils.jsonToMap(activeLearn.getDataset())); } + if (StringUtils.isNotEmpty(activeLearn.getLossPath())) { + activeLearnVo.setLossPath(JsonUtils.jsonToMap(activeLearn.getLossPath())); + } return activeLearnVo; } @@ -101,7 +126,11 @@ public class ActiveLearnServiceImpl implements ActiveLearnService { ActiveLearnVo activeLearnParam = new ActiveLearnVo(); BeanUtils.copyProperties(activeLearn, activeLearnParam); + activeLearnParam.setDatasetPath(JsonUtils.jsonToMap(activeLearn.getDatasetPath())); activeLearnParam.setDataset(JsonUtils.jsonToMap(activeLearn.getDataset())); + activeLearnParam.setModelPath(JsonUtils.jsonToMap(activeLearn.getModelPath())); + activeLearnParam.setLossPath(JsonUtils.jsonToMap(activeLearn.getLossPath())); + String param = JsonUtils.objectToJson(activeLearnParam); // todo 调argo转换接口 diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/ActiveLearnVo.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/ActiveLearnVo.java index 538fdb33..b9e6adf1 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/ActiveLearnVo.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/ActiveLearnVo.java @@ -21,49 +21,68 @@ public class ActiveLearnVo { @ApiModelProperty(value = "实验描述") private String description; - /** - * 对应数据集 - */ + @ApiModelProperty(value = "任务类型:classification, regression") + private String taskType; + + @ApiModelProperty(value = "模型文件路径") + private Map modelPath; + + @ApiModelProperty(value = "模型类名称") + private String modelClassName; + + @ApiModelProperty(value = "分类算法") + private String classifierAlg; + + @ApiModelProperty(value = "分类算法") + private String regressorAlg; + + @ApiModelProperty(value = "dataset文件路径") private Map dataset; - @ApiModelProperty(value = "数据集csv文件中哪一列是预测目标列,逗号分隔") - private String targetColumns; + @ApiModelProperty(value = "dataset类名") + private String datasetClassName; - @ApiModelProperty(value = "分类算法:logistic_regression(逻辑回归),decision_tree(决策树),random_forest(随机森林),SVM(支持向量机),naive_bayes(朴素贝叶斯),GBM(梯度提升机)") - private String classifierType; + @ApiModelProperty(value = "数据集") + private Map datasetPath; - @ApiModelProperty(value = "每次查询个数") - private Integer queryBatchSize; + @ApiModelProperty(value = "数据量") + private Integer dataSize; - @ApiModelProperty(value = "停止判则:num_of_queries(查询次数),percent_of_unlabel(未标记样本比例),time_limit(时间限制)") - private String stoppingCriterion; + @ApiModelProperty(value = "是否随机打乱") + private Boolean shuffle; - @ApiModelProperty(value = "stopping_criterion为num_of_queries时传入,查询次数") - private Integer numOfQueries; + @ApiModelProperty(value = "训练集数据量") + private Integer trainSize; -// @ApiModelProperty(value = "stopping_criterion为cost_limit时传入,成本限制") -// private Double costLimit; + @ApiModelProperty(value = "初始训练数据量") + private Integer nInitial; - @ApiModelProperty(value = "stopping_criterion为percent_of_unlabel时传入,未标记比例") - private Double percentOfUnlabel; + @ApiModelProperty(value = "查询次数") + private Integer nQueries; - @ApiModelProperty(value = "stopping_criterion为time_limit时传入,时间限制") - private Double timeLimit; + @ApiModelProperty(value = "每次查询数据量") + private Integer nInstances; - @ApiModelProperty(value = "查询策略:Uncertainty(不确定性),QBC(委员会查询),Random(随机),GraphDensity(图密度)") + @ApiModelProperty(value = "查询策略:uncertainty_sampling, uncertainty_batch_sampling, max_std_sampling, expected_improvement, upper_confidence_bound") private String queryStrategy; - @ApiModelProperty(value = "实验次数") - private Integer numOfExperiment; + @ApiModelProperty(value = "loss文件路径") + private Map lossPath; + + @ApiModelProperty(value = "loss类名") + private String lossClassName; + + @ApiModelProperty(value = "多少轮查询保存一次模型参数") + private Integer nCheckpoint; - @ApiModelProperty(value = "测试集比率") - private Double testRatio; + @ApiModelProperty(value = "batch_size") + private Integer batchSize; - @ApiModelProperty(value = "初始使用标记数据比率") - private Double initialLabelRate; + @ApiModelProperty(value = "epochs") + private Integer epochs; - @ApiModelProperty(value = "指标:accuracy_score,roc_auc_score,get_fps_tps_thresholds,hamming_loss,one_error,coverage_error") - private String performanceMetric; + @ApiModelProperty(value = "学习率") + private Float lr; private Integer state; diff --git a/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/ActiveLearnDaoMapper.xml b/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/ActiveLearnDaoMapper.xml index 92b5cc23..70b3cc01 100644 --- a/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/ActiveLearnDaoMapper.xml +++ b/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/ActiveLearnDaoMapper.xml @@ -2,18 +2,18 @@ - insert into active_learn(name, description, dataset, target_columns, classifier_type, query_batch_size, - stopping_criterion, - num_of_queries, cost_limit, percent_of_unlabel, time_limit, query_strategy, - num_of_experiment, test_ratio, - initial_label_rate, performance_metric, create_by, update_by) - values (#{activeLearn.name}, #{activeLearn.description}, #{activeLearn.dataset}, #{activeLearn.targetColumns}, - #{activeLearn.classifierType}, #{activeLearn.queryBatchSize}, #{activeLearn.stoppingCriterion}, - #{activeLearn.numOfQueries}, - #{activeLearn.costLimit}, #{activeLearn.percentOfUnlabel}, #{activeLearn.timeLimit}, + insert into active_learn(name, description, task_type, model_path, model_class_name, classifier_alg, + regressor_alg, dataset, dataset_class_name, dataset_path, data_size, + shuffle, train_size, n_initial, n_queries, n_instances, query_strategy, + loss_path, loss_class_name, n_checkpoint, batch_size, epochs, lr, create_by, update_by) + values (#{activeLearn.name}, #{activeLearn.description}, #{activeLearn.taskType}, #{activeLearn.modelPath}, + #{activeLearn.modelClassName}, #{activeLearn.classifierAlg}, #{activeLearn.regressorAlg}, + #{activeLearn.dataset}, #{activeLearn.datasetClassName}, #{activeLearn.datasetPath}, + #{activeLearn.dataSize}, #{activeLearn.shuffle}, #{activeLearn.trainSize}, + #{activeLearn.nInitial}, #{activeLearn.nQueries}, #{activeLearn.nInstances}, #{activeLearn.queryStrategy}, - #{activeLearn.numOfExperiment}, #{activeLearn.testRatio}, #{activeLearn.initialLabelRate}, - #{activeLearn.performanceMetric}, + #{activeLearn.lossPath}, #{activeLearn.lossClassName}, #{activeLearn.nCheckpoint}, + #{activeLearn.batchSize}, #{activeLearn.epochs}, #{activeLearn.lr}, #{activeLearn.createBy}, #{activeLearn.updateBy}) @@ -26,47 +26,68 @@ description = #{activeLearn.description}, + + task_type = #{activeLearn.taskType}, + + + model_path = #{activeLearn.modelPath}, + + + model_class_name = #{activeLearn.modelClassName}, + + + classifier_alg = #{activeLearn.classifierAlg}, + + + regressor_alg = #{activeLearn.regressorAlg}, + dataset = #{activeLearn.dataset}, - - target_columns = #{activeLearn.targetColumns}, + + dataset_class_name = #{activeLearn.datasetClassName}, - - classifier_type = #{activeLearn.classifierType}, + + dataset_path = #{activeLearn.datasetPath}, - - query_batch_size = #{activeLearn.queryBatchSize}, + + data_size = #{activeLearn.dataSize}, - - stopping_criterion = #{activeLearn.stoppingCriterion}, + + shuffle = #{activeLearn.shuffle}, - - num_of_queries = #{activeLearn.numOfQueries}, + + train_size = #{activeLearn.trainSize}, - - cost_limit = #{activeLearn.costLimit}, + + n_initial = #{activeLearn.nInitial}, - - percent_of_unlabel = #{activeLearn.percentOfUnlabel}, + + n_queries = #{activeLearn.nQueries}, - - time_limit = #{activeLearn.timeLimit}, + + n_instances = #{activeLearn.nInstances}, query_strategy = #{activeLearn.queryStrategy}, - - num_of_experiment = #{activeLearn.numOfExperiment}, + + loss_path = #{activeLearn.lossPath}, + + + loss_class_name = #{activeLearn.lossClassName}, + + + n_checkpoint = #{activeLearn.nCheckpoint}, - - test_ratio = #{activeLearn.testRatio}, + + batch_size = #{activeLearn.batchSize}, - - initial_label_rate = #{activeLearn.initialLabelRate}, + + epochs = #{activeLearn.epochs}, - - performance_metric = #{activeLearn.performanceMetric}, + + lr = #{activeLearn.lr}, update_by = #{activeLearn.updateBy},