Browse Source

查询模型版本指标对比

dev-czh
chenzhihang 1 year ago
parent
commit
2caa5286e5
4 changed files with 65 additions and 22 deletions
  1. +3
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/constant/Constant.java
  2. +6
    -4
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/model/NewModelFromGitController.java
  3. +3
    -2
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/ModelsService.java
  4. +53
    -16
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelsServiceImpl.java

+ 3
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/constant/Constant.java View File

@@ -31,4 +31,7 @@ public class Constant {
public final static String Pending = "Pending"; public final static String Pending = "Pending";
public final static String Init = "Init"; public final static String Init = "Init";
public final static String Stopped = "Stopped"; public final static String Stopped = "Stopped";

public final static String Type_Train = "train";
public final static String Type_Evaluate = "evaluate";
} }

+ 6
- 4
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/model/NewModelFromGitController.java View File

@@ -1,6 +1,7 @@
package com.ruoyi.platform.controller.model; package com.ruoyi.platform.controller.model;


import com.ruoyi.common.core.web.domain.AjaxResult; import com.ruoyi.common.core.web.domain.AjaxResult;
import com.ruoyi.platform.domain.ModelDependency1;
import com.ruoyi.platform.service.ModelsService; import com.ruoyi.platform.service.ModelsService;
import com.ruoyi.platform.vo.ModelsVo; import com.ruoyi.platform.vo.ModelsVo;
import io.swagger.annotations.Api; import io.swagger.annotations.Api;
@@ -79,15 +80,16 @@ public class NewModelFromGitController {
public AjaxResult queryVersions(@RequestParam(value = "page") int page, public AjaxResult queryVersions(@RequestParam(value = "page") int page,
@RequestParam(value = "size") int size, @RequestParam(value = "size") int size,
@RequestParam("identifier") String identifier, @RequestParam("identifier") String identifier,
@RequestParam("owner") String owner) throws Exception {
@RequestParam("owner") String owner,
@RequestParam("type") String type) throws Exception {
PageRequest pageRequest = PageRequest.of(page, size); PageRequest pageRequest = PageRequest.of(page, size);
return AjaxResult.success(this.modelsService.queryVersions(pageRequest, identifier, owner));
return AjaxResult.success(this.modelsService.queryVersions(pageRequest, identifier, owner, type));
} }


@GetMapping("/queryVersionsMetrics") @GetMapping("/queryVersionsMetrics")
@ApiOperation("查询版本指标") @ApiOperation("查询版本指标")
public AjaxResult queryVersionsMetrics(@RequestParam("runIds") List<String> runIds) throws Exception {
return AjaxResult.success(this.modelsService.queryVersionsMetrics(runIds));
public AjaxResult queryVersionsMetrics(@RequestParam("params") List<ModelDependency1> params, @RequestParam("type") String type) throws Exception {
return AjaxResult.success(this.modelsService.queryVersionsMetrics(params, type));
} }






+ 3
- 2
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/ModelsService.java View File

@@ -1,6 +1,7 @@
package com.ruoyi.platform.service; package com.ruoyi.platform.service;




import com.ruoyi.platform.domain.ModelDependency1;
import com.ruoyi.platform.domain.Models; import com.ruoyi.platform.domain.Models;
import com.ruoyi.platform.domain.ModelsVersion; import com.ruoyi.platform.domain.ModelsVersion;
import com.ruoyi.platform.vo.ModelDependency1TreeVo; import com.ruoyi.platform.vo.ModelDependency1TreeVo;
@@ -99,9 +100,9 @@ public interface ModelsService {


Page<ModelsVo> newPersonalQueryByPage(ModelsVo modelsVo, PageRequest pageRequest) throws Exception; Page<ModelsVo> newPersonalQueryByPage(ModelsVo modelsVo, PageRequest pageRequest) throws Exception;


Page<Map<String, Object>> queryVersions(PageRequest pageRequest, String identifier, String owner) throws Exception;
Page<Map<String, Object>> queryVersions(PageRequest pageRequest, String identifier, String owner, String type) throws Exception;


List<List<Map<String, Object>>> queryVersionsMetrics(List<String> runIds) throws Exception;
List<Object> queryVersionsMetrics(List<ModelDependency1> params, String type) throws Exception;


List<Map<String, Object>> getVersionList(String identifier, String owner) throws Exception; List<Map<String, Object>> getVersionList(String identifier, String owner) throws Exception;




+ 53
- 16
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelsServiceImpl.java View File

@@ -974,7 +974,7 @@ public class ModelsServiceImpl implements ModelsService {
} }


@Override @Override
public Page<Map<String, Object>> queryVersions(PageRequest pageRequest, String identifier, String owner) throws Exception {
public Page<Map<String, Object>> queryVersions(PageRequest pageRequest, String identifier, String owner, String type) throws Exception {
String token = gitService.checkoutToken(); String token = gitService.checkoutToken();
List<Map<String, Object>> collect = gitService.getBrancheList(token, owner, identifier); List<Map<String, Object>> collect = gitService.getBrancheList(token, owner, identifier);
List<Map<String, Object>> result = collect.stream() List<Map<String, Object>> result = collect.stream()
@@ -984,27 +984,45 @@ public class ModelsServiceImpl implements ModelsService {


for (Map<String, Object> branch : result) { for (Map<String, Object> branch : result) {
String meta = modelDependency1Dao.getMeta(identifier, owner, (String) branch.get("name")); String meta = modelDependency1Dao.getMeta(identifier, owner, (String) branch.get("name"));
ModelMetaVo modelMetaVo = JSON.parseObject(meta, ModelMetaVo.class);
if (modelMetaVo.getParams() != null) {
branch.putAll(modelMetaVo.getParams());
}
if (modelMetaVo.getMetrics() != null) {
branch.putAll(modelMetaVo.getMetrics());
}
if (modelMetaVo.getMetricsParams() != null) {
branch.putAll(modelMetaVo.getMetricsParams());
if (StringUtils.isNotEmpty(meta)) {
ModelMetaVo modelMetaVo = JSON.parseObject(meta, ModelMetaVo.class);
if (modelMetaVo.getParams() != null) {
HashMap<String, Object> params = modelMetaVo.getParams();
branch.putAll(params);
ArrayList<String> params_names = new ArrayList<>();
for (String key : params.keySet()) {
params_names.add(key);
}
branch.put("params_names", params_names);
}
if (modelMetaVo.getMetrics() != null) {
HashMap<String, Object> metrics = modelMetaVo.getMetrics();
if (Constant.Type_Train.equals(type)) {
Map<String, Object> trainMetrics = (Map<String, Object>) metrics.get(Constant.Type_Train);
branch.putAll(trainMetrics);
} else {
Map<String, Object> evaluateMetrics = (Map<String, Object>) metrics.get(Constant.Type_Evaluate);
branch.putAll(evaluateMetrics);
}
}
} }
} }
return new PageImpl<>(result, pageRequest, collect.size()); return new PageImpl<>(result, pageRequest, collect.size());
} }


@Override @Override
public List<List<Map<String, Object>>> queryVersionsMetrics(List<String> runIds) throws Exception {
List<List<Map<String, Object>>> batchMetrics = new ArrayList<>();
for (String runId : runIds) {
HashMap<String, Object> map = aimsService.queryMetricsParams(runId);
List<Map<String, Object>> batchMetric = aimsService.getBatchMetric((String) map.get("run_hash"), (String) map.get("params"));
batchMetrics.add(batchMetric);
public List<Object> queryVersionsMetrics(List<ModelDependency1> params, String type) {
List<Object> batchMetrics = new ArrayList<>();
for (ModelDependency1 model : params) {

ModelDependency1 modelDependency1 = modelDependency1Dao.queryByRepoAndVersion(model.getRepoId(), model.getIdentifier(), model.getVersion());
ModelMetaVo modelMetaVo = JSON.parseObject(modelDependency1.getMeta(), ModelMetaVo.class);
HashMap<String, Object> metrics = modelMetaVo.getMetrics();
if (Constant.Type_Train.equals(type)) {
batchMetrics.add(metrics.get("tarinDetail"));
} else {
batchMetrics.add(metrics.get("evaluateDetail"));
}
} }
return batchMetrics; return batchMetrics;
} }
@@ -1220,6 +1238,9 @@ public class ModelsServiceImpl implements ModelsService {
HashMap<String, Object> train = new HashMap<>(); HashMap<String, Object> train = new HashMap<>();
HashMap<String, Object> evaluate = new HashMap<>(); HashMap<String, Object> evaluate = new HashMap<>();


List<List<Map<String, Object>>> trainBatchMetrics = new ArrayList<>();
List<List<Map<String, Object>>> evaluateBatchMetrics = new ArrayList<>();

HashMap<String, Object> metrics = modelMetaVo.getMetricsParams(); HashMap<String, Object> metrics = modelMetaVo.getMetricsParams();
JSONArray trainMetrics = (JSONArray) metrics.get("train"); JSONArray trainMetrics = (JSONArray) metrics.get("train");
if (trainMetrics != null) { if (trainMetrics != null) {
@@ -1231,8 +1252,16 @@ public class ModelsServiceImpl implements ModelsService {
Map metrics1 = expTrainInfo.getMetrics(); Map metrics1 = expTrainInfo.getMetrics();
train.putAll(metrics1); train.putAll(metrics1);
} }

// 记录训练指标详情
HashMap<String, Object> map = aimsService.queryMetricsParams(runId);
List<Map<String, Object>> batchMetric = aimsService.getBatchMetric((String) map.get("run_hash"), (String) map.get("params"));
trainBatchMetrics.add(batchMetric);

} }
result.put("train", train); result.put("train", train);
result.put("tarinDetail", trainBatchMetrics);

} }


JSONArray testMetrics = (JSONArray) metrics.get("evaluate"); JSONArray testMetrics = (JSONArray) metrics.get("evaluate");
@@ -1245,8 +1274,16 @@ public class ModelsServiceImpl implements ModelsService {
Map metrics1 = expTestInfo.getMetrics(); Map metrics1 = expTestInfo.getMetrics();
evaluate.putAll(metrics1); evaluate.putAll(metrics1);
} }

// 记录验证指标详情
HashMap<String, Object> map = aimsService.queryMetricsParams(runId);
List<Map<String, Object>> batchMetric = aimsService.getBatchMetric((String) map.get("run_hash"), (String) map.get("params"));
evaluateBatchMetrics.add(batchMetric);

} }
result.put("evaluate", evaluate); result.put("evaluate", evaluate);
result.put("evaluateDetail", evaluateBatchMetrics);

} }
modelMetaVo.setMetrics(result); modelMetaVo.setMetrics(result);
} }


Loading…
Cancel
Save