Browse Source

查询模型版本指标对比

dev-czh
chenzhihang 1 year ago
parent
commit
2caa5286e5
4 changed files with 65 additions and 22 deletions
  1. +3
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/constant/Constant.java
  2. +6
    -4
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/model/NewModelFromGitController.java
  3. +3
    -2
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/ModelsService.java
  4. +53
    -16
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelsServiceImpl.java

+ 3
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/constant/Constant.java View File

@@ -31,4 +31,7 @@ public class Constant {
public final static String Pending = "Pending";
public final static String Init = "Init";
public final static String Stopped = "Stopped";

public final static String Type_Train = "train";
public final static String Type_Evaluate = "evaluate";
}

+ 6
- 4
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/model/NewModelFromGitController.java View File

@@ -1,6 +1,7 @@
package com.ruoyi.platform.controller.model;

import com.ruoyi.common.core.web.domain.AjaxResult;
import com.ruoyi.platform.domain.ModelDependency1;
import com.ruoyi.platform.service.ModelsService;
import com.ruoyi.platform.vo.ModelsVo;
import io.swagger.annotations.Api;
@@ -79,15 +80,16 @@ public class NewModelFromGitController {
public AjaxResult queryVersions(@RequestParam(value = "page") int page,
@RequestParam(value = "size") int size,
@RequestParam("identifier") String identifier,
@RequestParam("owner") String owner) throws Exception {
@RequestParam("owner") String owner,
@RequestParam("type") String type) throws Exception {
PageRequest pageRequest = PageRequest.of(page, size);
return AjaxResult.success(this.modelsService.queryVersions(pageRequest, identifier, owner));
return AjaxResult.success(this.modelsService.queryVersions(pageRequest, identifier, owner, type));
}

@GetMapping("/queryVersionsMetrics")
@ApiOperation("查询版本指标")
public AjaxResult queryVersionsMetrics(@RequestParam("runIds") List<String> runIds) throws Exception {
return AjaxResult.success(this.modelsService.queryVersionsMetrics(runIds));
public AjaxResult queryVersionsMetrics(@RequestParam("params") List<ModelDependency1> params, @RequestParam("type") String type) throws Exception {
return AjaxResult.success(this.modelsService.queryVersionsMetrics(params, type));
}




+ 3
- 2
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/ModelsService.java View File

@@ -1,6 +1,7 @@
package com.ruoyi.platform.service;


import com.ruoyi.platform.domain.ModelDependency1;
import com.ruoyi.platform.domain.Models;
import com.ruoyi.platform.domain.ModelsVersion;
import com.ruoyi.platform.vo.ModelDependency1TreeVo;
@@ -99,9 +100,9 @@ public interface ModelsService {

Page<ModelsVo> newPersonalQueryByPage(ModelsVo modelsVo, PageRequest pageRequest) throws Exception;

Page<Map<String, Object>> queryVersions(PageRequest pageRequest, String identifier, String owner) throws Exception;
Page<Map<String, Object>> queryVersions(PageRequest pageRequest, String identifier, String owner, String type) throws Exception;

List<List<Map<String, Object>>> queryVersionsMetrics(List<String> runIds) throws Exception;
List<Object> queryVersionsMetrics(List<ModelDependency1> params, String type) throws Exception;

List<Map<String, Object>> getVersionList(String identifier, String owner) throws Exception;



+ 53
- 16
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelsServiceImpl.java View File

@@ -974,7 +974,7 @@ public class ModelsServiceImpl implements ModelsService {
}

@Override
public Page<Map<String, Object>> queryVersions(PageRequest pageRequest, String identifier, String owner) throws Exception {
public Page<Map<String, Object>> queryVersions(PageRequest pageRequest, String identifier, String owner, String type) throws Exception {
String token = gitService.checkoutToken();
List<Map<String, Object>> collect = gitService.getBrancheList(token, owner, identifier);
List<Map<String, Object>> result = collect.stream()
@@ -984,27 +984,45 @@ public class ModelsServiceImpl implements ModelsService {

for (Map<String, Object> branch : result) {
String meta = modelDependency1Dao.getMeta(identifier, owner, (String) branch.get("name"));
ModelMetaVo modelMetaVo = JSON.parseObject(meta, ModelMetaVo.class);
if (modelMetaVo.getParams() != null) {
branch.putAll(modelMetaVo.getParams());
}
if (modelMetaVo.getMetrics() != null) {
branch.putAll(modelMetaVo.getMetrics());
}
if (modelMetaVo.getMetricsParams() != null) {
branch.putAll(modelMetaVo.getMetricsParams());
if (StringUtils.isNotEmpty(meta)) {
ModelMetaVo modelMetaVo = JSON.parseObject(meta, ModelMetaVo.class);
if (modelMetaVo.getParams() != null) {
HashMap<String, Object> params = modelMetaVo.getParams();
branch.putAll(params);
ArrayList<String> params_names = new ArrayList<>();
for (String key : params.keySet()) {
params_names.add(key);
}
branch.put("params_names", params_names);
}
if (modelMetaVo.getMetrics() != null) {
HashMap<String, Object> metrics = modelMetaVo.getMetrics();
if (Constant.Type_Train.equals(type)) {
Map<String, Object> trainMetrics = (Map<String, Object>) metrics.get(Constant.Type_Train);
branch.putAll(trainMetrics);
} else {
Map<String, Object> evaluateMetrics = (Map<String, Object>) metrics.get(Constant.Type_Evaluate);
branch.putAll(evaluateMetrics);
}
}
}
}
return new PageImpl<>(result, pageRequest, collect.size());
}

@Override
public List<List<Map<String, Object>>> queryVersionsMetrics(List<String> runIds) throws Exception {
List<List<Map<String, Object>>> batchMetrics = new ArrayList<>();
for (String runId : runIds) {
HashMap<String, Object> map = aimsService.queryMetricsParams(runId);
List<Map<String, Object>> batchMetric = aimsService.getBatchMetric((String) map.get("run_hash"), (String) map.get("params"));
batchMetrics.add(batchMetric);
public List<Object> queryVersionsMetrics(List<ModelDependency1> params, String type) {
List<Object> batchMetrics = new ArrayList<>();
for (ModelDependency1 model : params) {

ModelDependency1 modelDependency1 = modelDependency1Dao.queryByRepoAndVersion(model.getRepoId(), model.getIdentifier(), model.getVersion());
ModelMetaVo modelMetaVo = JSON.parseObject(modelDependency1.getMeta(), ModelMetaVo.class);
HashMap<String, Object> metrics = modelMetaVo.getMetrics();
if (Constant.Type_Train.equals(type)) {
batchMetrics.add(metrics.get("tarinDetail"));
} else {
batchMetrics.add(metrics.get("evaluateDetail"));
}
}
return batchMetrics;
}
@@ -1220,6 +1238,9 @@ public class ModelsServiceImpl implements ModelsService {
HashMap<String, Object> train = new HashMap<>();
HashMap<String, Object> evaluate = new HashMap<>();

List<List<Map<String, Object>>> trainBatchMetrics = new ArrayList<>();
List<List<Map<String, Object>>> evaluateBatchMetrics = new ArrayList<>();

HashMap<String, Object> metrics = modelMetaVo.getMetricsParams();
JSONArray trainMetrics = (JSONArray) metrics.get("train");
if (trainMetrics != null) {
@@ -1231,8 +1252,16 @@ public class ModelsServiceImpl implements ModelsService {
Map metrics1 = expTrainInfo.getMetrics();
train.putAll(metrics1);
}

// 记录训练指标详情
HashMap<String, Object> map = aimsService.queryMetricsParams(runId);
List<Map<String, Object>> batchMetric = aimsService.getBatchMetric((String) map.get("run_hash"), (String) map.get("params"));
trainBatchMetrics.add(batchMetric);

}
result.put("train", train);
result.put("tarinDetail", trainBatchMetrics);

}

JSONArray testMetrics = (JSONArray) metrics.get("evaluate");
@@ -1245,8 +1274,16 @@ public class ModelsServiceImpl implements ModelsService {
Map metrics1 = expTestInfo.getMetrics();
evaluate.putAll(metrics1);
}

// 记录验证指标详情
HashMap<String, Object> map = aimsService.queryMetricsParams(runId);
List<Map<String, Object>> batchMetric = aimsService.getBatchMetric((String) map.get("run_hash"), (String) map.get("params"));
evaluateBatchMetrics.add(batchMetric);

}
result.put("evaluate", evaluate);
result.put("evaluateDetail", evaluateBatchMetrics);

}
modelMetaVo.setMetrics(result);
}


Loading…
Cancel
Save