Browse Source

查询模型版本指标对比

dev-lhz
chenzhihang 1 year ago
parent
commit
952f469fbc
10 changed files with 84 additions and 7 deletions
  1. +2
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/RuoYiManagementPlatformApplication.java
  2. +7
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/model/NewModelFromGitController.java
  3. +1
    -1
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/mapper/ModelDependency1Dao.java
  4. +6
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AimService.java
  5. +2
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/ModelsService.java
  6. +42
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java
  7. +1
    -1
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentServiceImpl.java
  8. +18
    -4
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelsServiceImpl.java
  9. +3
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/ModelMetaVo.java
  10. +2
    -1
      ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/ModelDependency1DaoMapper.xml

+ 2
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/RuoYiManagementPlatformApplication.java View File

@@ -5,6 +5,7 @@ import com.ruoyi.common.security.annotation.EnableRyFeignClients;
import com.ruoyi.common.swagger.annotation.EnableCustomSwagger2; import com.ruoyi.common.swagger.annotation.EnableCustomSwagger2;
import org.springframework.boot.SpringApplication; import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.scheduling.annotation.EnableAsync;
import org.springframework.scheduling.annotation.EnableScheduling; import org.springframework.scheduling.annotation.EnableScheduling;


/** /**
@@ -17,6 +18,7 @@ import org.springframework.scheduling.annotation.EnableScheduling;
@EnableRyFeignClients @EnableRyFeignClients
@SpringBootApplication @SpringBootApplication
@EnableScheduling @EnableScheduling
@EnableAsync
public class RuoYiManagementPlatformApplication { public class RuoYiManagementPlatformApplication {
public static void main(String[] args) { public static void main(String[] args) {
SpringApplication.run(RuoYiManagementPlatformApplication.class, args); SpringApplication.run(RuoYiManagementPlatformApplication.class, args);


+ 7
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/model/NewModelFromGitController.java View File

@@ -84,6 +84,13 @@ public class NewModelFromGitController {
return AjaxResult.success(this.modelsService.queryVersions(pageRequest, identifier, owner)); return AjaxResult.success(this.modelsService.queryVersions(pageRequest, identifier, owner));
} }


@GetMapping("/queryVersionsMetrics")
@ApiOperation("查询版本指标")
public AjaxResult queryVersionsMetrics(@RequestParam("runIds") List<String> runIds) throws Exception {
return AjaxResult.success(this.modelsService.queryVersionsMetrics(runIds));
}


@GetMapping("/getVersionList") @GetMapping("/getVersionList")
@ApiOperation(value = "获取模型分支列表") @ApiOperation(value = "获取模型分支列表")
public AjaxResult getVersionList(@RequestParam("identifier") String identifier, @RequestParam("owner") String owner) throws Exception { public AjaxResult getVersionList(@RequestParam("identifier") String identifier, @RequestParam("owner") String owner) throws Exception {


+ 1
- 1
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/mapper/ModelDependency1Dao.java View File

@@ -9,7 +9,7 @@ public interface ModelDependency1Dao {


int insert(ModelDependency1 modelDependency1); int insert(ModelDependency1 modelDependency1);


int updateState(@Param("repoId") Integer repoId, @Param("identifier") String identifier, @Param("version") String version, @Param("state") Integer state);
int updateState(@Param("repoId") Integer repoId, @Param("identifier") String identifier, @Param("version") String version, @Param("meta") String meta, @Param("state") Integer state);


List<ModelDependency1> queryModelDependency(@Param("modelName") String modelName, @Param("repoId") Integer repoId, @Param("owner") String owner); List<ModelDependency1> queryModelDependency(@Param("modelName") String modelName, @Param("repoId") Integer repoId, @Param("owner") String owner);




+ 6
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AimService.java View File

@@ -2,7 +2,9 @@ package com.ruoyi.platform.service;


import com.ruoyi.platform.vo.InsMetricInfoVo; import com.ruoyi.platform.vo.InsMetricInfoVo;


import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map;


public interface AimService { public interface AimService {


@@ -13,4 +15,8 @@ public interface AimService {
List<InsMetricInfoVo> getExpEvaluateInfos(Integer experimentId) throws Exception; List<InsMetricInfoVo> getExpEvaluateInfos(Integer experimentId) throws Exception;


String getExpMetrics(List<String> runIds) throws Exception; String getExpMetrics(List<String> runIds) throws Exception;

HashMap<String, Object> queryMetricsParams(String runId) throws Exception;

List<Map<String, Object>> getBatchMetric(String runHash, String params);
} }

+ 2
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/ModelsService.java View File

@@ -101,6 +101,8 @@ public interface ModelsService {


Page<Map<String, Object>> queryVersions(PageRequest pageRequest, String identifier, String owner) throws Exception; Page<Map<String, Object>> queryVersions(PageRequest pageRequest, String identifier, String owner) throws Exception;


List<List<Map<String, Object>>> queryVersionsMetrics(List<String> runIds) throws Exception;

List<Map<String, Object>> getVersionList(String identifier, String owner) throws Exception; List<Map<String, Object>> getVersionList(String identifier, String owner) throws Exception;


ModelsVo getModelDetail(Integer id, String identifier, String owner, String version) throws Exception; ModelsVo getModelDetail(Integer id, String identifier, String owner, String version) throws Exception;


+ 42
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java View File

@@ -1,5 +1,6 @@
package com.ruoyi.platform.service.impl; package com.ruoyi.platform.service.impl;


import com.alibaba.fastjson2.JSON;
import com.ruoyi.platform.domain.ExperimentIns; import com.ruoyi.platform.domain.ExperimentIns;
import com.ruoyi.platform.service.AimService; import com.ruoyi.platform.service.AimService;
import com.ruoyi.platform.service.ExperimentInsService; import com.ruoyi.platform.service.ExperimentInsService;
@@ -13,6 +14,7 @@ import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;


import javax.annotation.Resource; import javax.annotation.Resource;
import java.io.UnsupportedEncodingException;
import java.net.URLEncoder; import java.net.URLEncoder;
import java.util.*; import java.util.*;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@@ -245,4 +247,44 @@ public class AimServiceImpl implements AimService {
} }
return datasetList; return datasetList;
} }

@Override
public HashMap<String, Object> queryMetricsParams(String runId) throws UnsupportedEncodingException {
String encodedUrlString = URLEncoder.encode("run.id==\"" + runId + "\"", "UTF-8");
String url = aimProxyUrl + "/api/runs/search/run?query=" + encodedUrlString;
String s = httpUtils.sendGet(url, null);
List<Map<String, Object>> response = JacksonUtil.parseJSONStr2MapList(s);
if (response == null || response.size() == 0) {
return new HashMap<>();
}

HashMap<String, Object> resultMap = new HashMap<>();
List<Map<String, Object>> paramList = new ArrayList<>();

Map<String, Object> map = response.get(0);
LinkedHashMap<String, ArrayList> traces = (LinkedHashMap<String, ArrayList>) map.get("traces");
if (traces != null) {
List<Map<String, Object>> metrics = traces.get("metric");
for (Map<String, Object> metric : metrics) {
Map<String, Object> metricParam = new HashMap<>();
metricParam.put("context", metric.get("context"));
metricParam.put("name", metric.get("name"));
paramList.add(metricParam);
}
resultMap.put("params", JSON.toJSONString(paramList));
}
resultMap.put("run_hash", map.get("run_hash"));
return resultMap;
}

@Override
public List<Map<String, Object>> getBatchMetric(String runHash, String params) {
String url = aimUrl + "/api/runs/" + runHash + "/metric/get-batch";
String response = httpUtils.sendPost(url, null, params);
if (StringUtils.isNotEmpty(response)) {
return JacksonUtil.parseJSONStr2MapList(response);
}
return null;
}

} }

+ 1
- 1
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentServiceImpl.java View File

@@ -560,7 +560,7 @@ public class ExperimentServiceImpl implements ExperimentService {


//处理指标 //处理指标
HashMap<String, Object> metricMap = JSON.parseObject(metricRecord, HashMap.class); HashMap<String, Object> metricMap = JSON.parseObject(metricRecord, HashMap.class);
modelMetaVo.setMetrics(metricMap);
modelMetaVo.setMetricsParams(metricMap);


//训练数据集 //训练数据集
List<Map<String, Object>> trainDatasetList = (List<Map<String, Object>>) modelTrainMap.get("datasets"); List<Map<String, Object>> trainDatasetList = (List<Map<String, Object>>) modelTrainMap.get("datasets");


+ 18
- 4
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelsServiceImpl.java View File

@@ -747,7 +747,7 @@ public class ModelsServiceImpl implements ModelsService {
if (buildingModel != null) { if (buildingModel != null) {
modelMetaVo = JSON.parseObject(buildingModel.getMeta(), ModelMetaVo.class); modelMetaVo = JSON.parseObject(buildingModel.getMeta(), ModelMetaVo.class);
//获取指标 //获取指标
getMetrics(modelMetaVo);
transMetrics(modelMetaVo);
} }


//拼接生产的元数据后写入yaml文件 //拼接生产的元数据后写入yaml文件
@@ -806,7 +806,7 @@ public class ModelsServiceImpl implements ModelsService {
modelDependency1Dao.insert(modelDependency); modelDependency1Dao.insert(modelDependency);
} else { } else {
//更新模型依赖 //更新模型依赖
modelDependency1Dao.updateState(modelsVo.getId(), modelsVo.getIdentifier(), modelsVo.getVersion(), Constant.State_valid);
modelDependency1Dao.updateState(modelsVo.getId(), modelsVo.getIdentifier(), modelsVo.getVersion(), meta, Constant.State_valid);
} }
} else { } else {
//保存模型依赖 //保存模型依赖
@@ -991,10 +991,24 @@ public class ModelsServiceImpl implements ModelsService {
if (modelMetaVo.getMetrics() != null) { if (modelMetaVo.getMetrics() != null) {
branch.putAll(modelMetaVo.getMetrics()); branch.putAll(modelMetaVo.getMetrics());
} }
if (modelMetaVo.getMetricsParams() != null) {
branch.putAll(modelMetaVo.getMetricsParams());
}
} }
return new PageImpl<>(result, pageRequest, collect.size()); return new PageImpl<>(result, pageRequest, collect.size());
} }


@Override
public List<List<Map<String, Object>>> queryVersionsMetrics(List<String> runIds) throws Exception {
List<List<Map<String, Object>>> batchMetrics = new ArrayList<>();
for (String runId : runIds) {
HashMap<String, Object> map = aimsService.queryMetricsParams(runId);
List<Map<String, Object>> batchMetric = aimsService.getBatchMetric((String) map.get("run_hash"), (String) map.get("params"));
batchMetrics.add(batchMetric);
}
return batchMetrics;
}

@Override @Override
public List<Map<String, Object>> getVersionList(String identifier, String owner) throws Exception { public List<Map<String, Object>> getVersionList(String identifier, String owner) throws Exception {
String token = gitService.checkoutToken(); String token = gitService.checkoutToken();
@@ -1201,12 +1215,12 @@ public class ModelsServiceImpl implements ModelsService {
return userInfo; return userInfo;
} }


void getMetrics(ModelMetaVo modelMetaVo) throws Exception {
void transMetrics(ModelMetaVo modelMetaVo) throws Exception {
HashMap<String, Object> result = new HashMap<>(); HashMap<String, Object> result = new HashMap<>();
HashMap<String, Object> train = new HashMap<>(); HashMap<String, Object> train = new HashMap<>();
HashMap<String, Object> evaluate = new HashMap<>(); HashMap<String, Object> evaluate = new HashMap<>();


HashMap<String, Object> metrics = modelMetaVo.getMetrics();
HashMap<String, Object> metrics = modelMetaVo.getMetricsParams();
JSONArray trainMetrics = (JSONArray) metrics.get("train"); JSONArray trainMetrics = (JSONArray) metrics.get("train");
if (trainMetrics != null) { if (trainMetrics != null) {
for (int i = 0; i < trainMetrics.size(); i++) { for (int i = 0; i < trainMetrics.size(); i++) {


+ 3
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/ModelMetaVo.java View File

@@ -59,6 +59,9 @@ public class ModelMetaVo implements Serializable {
@ApiModelProperty(value = "指标") @ApiModelProperty(value = "指标")
private HashMap<String, Object> metrics; private HashMap<String, Object> metrics;


@ApiModelProperty(value = "指标查询参数")
private HashMap<String, Object> metricsParams;

@ApiModelProperty(value = "训练任务") @ApiModelProperty(value = "训练任务")
private TrainTaskDepency trainTask; private TrainTaskDepency trainTask;




+ 2
- 1
ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/ModelDependency1DaoMapper.xml View File

@@ -86,7 +86,8 @@


<update id="updateState"> <update id="updateState">
update model_dependency1 update model_dependency1
set state = 1
set state = 1,
meta = #{meta}
where repo_id = #{repoId} where repo_id = #{repoId}
and identifier = #{identifier} and identifier = #{identifier}
and version = #{version} and version = #{version}


Loading…
Cancel
Save