| @@ -5,6 +5,7 @@ import com.ruoyi.common.security.annotation.EnableRyFeignClients; | |||||
| import com.ruoyi.common.swagger.annotation.EnableCustomSwagger2; | import com.ruoyi.common.swagger.annotation.EnableCustomSwagger2; | ||||
| import org.springframework.boot.SpringApplication; | import org.springframework.boot.SpringApplication; | ||||
| import org.springframework.boot.autoconfigure.SpringBootApplication; | import org.springframework.boot.autoconfigure.SpringBootApplication; | ||||
| import org.springframework.scheduling.annotation.EnableAsync; | |||||
| import org.springframework.scheduling.annotation.EnableScheduling; | import org.springframework.scheduling.annotation.EnableScheduling; | ||||
| /** | /** | ||||
| @@ -17,6 +18,7 @@ import org.springframework.scheduling.annotation.EnableScheduling; | |||||
| @EnableRyFeignClients | @EnableRyFeignClients | ||||
| @SpringBootApplication | @SpringBootApplication | ||||
| @EnableScheduling | @EnableScheduling | ||||
| @EnableAsync | |||||
| public class RuoYiManagementPlatformApplication { | public class RuoYiManagementPlatformApplication { | ||||
| public static void main(String[] args) { | public static void main(String[] args) { | ||||
| SpringApplication.run(RuoYiManagementPlatformApplication.class, args); | SpringApplication.run(RuoYiManagementPlatformApplication.class, args); | ||||
| @@ -84,6 +84,13 @@ public class NewModelFromGitController { | |||||
| return AjaxResult.success(this.modelsService.queryVersions(pageRequest, identifier, owner)); | return AjaxResult.success(this.modelsService.queryVersions(pageRequest, identifier, owner)); | ||||
| } | } | ||||
| @GetMapping("/queryVersionsMetrics") | |||||
| @ApiOperation("查询版本指标") | |||||
| public AjaxResult queryVersionsMetrics(@RequestParam("runIds") List<String> runIds) throws Exception { | |||||
| return AjaxResult.success(this.modelsService.queryVersionsMetrics(runIds)); | |||||
| } | |||||
| @GetMapping("/getVersionList") | @GetMapping("/getVersionList") | ||||
| @ApiOperation(value = "获取模型分支列表") | @ApiOperation(value = "获取模型分支列表") | ||||
| public AjaxResult getVersionList(@RequestParam("identifier") String identifier, @RequestParam("owner") String owner) throws Exception { | public AjaxResult getVersionList(@RequestParam("identifier") String identifier, @RequestParam("owner") String owner) throws Exception { | ||||
| @@ -9,7 +9,7 @@ public interface ModelDependency1Dao { | |||||
| int insert(ModelDependency1 modelDependency1); | int insert(ModelDependency1 modelDependency1); | ||||
| int updateState(@Param("repoId") Integer repoId, @Param("identifier") String identifier, @Param("version") String version, @Param("state") Integer state); | |||||
| int updateState(@Param("repoId") Integer repoId, @Param("identifier") String identifier, @Param("version") String version, @Param("meta") String meta, @Param("state") Integer state); | |||||
| List<ModelDependency1> queryModelDependency(@Param("modelName") String modelName, @Param("repoId") Integer repoId, @Param("owner") String owner); | List<ModelDependency1> queryModelDependency(@Param("modelName") String modelName, @Param("repoId") Integer repoId, @Param("owner") String owner); | ||||
| @@ -2,7 +2,9 @@ package com.ruoyi.platform.service; | |||||
| import com.ruoyi.platform.vo.InsMetricInfoVo; | import com.ruoyi.platform.vo.InsMetricInfoVo; | ||||
| import java.util.HashMap; | |||||
| import java.util.List; | import java.util.List; | ||||
| import java.util.Map; | |||||
| public interface AimService { | public interface AimService { | ||||
| @@ -13,4 +15,8 @@ public interface AimService { | |||||
| List<InsMetricInfoVo> getExpEvaluateInfos(Integer experimentId) throws Exception; | List<InsMetricInfoVo> getExpEvaluateInfos(Integer experimentId) throws Exception; | ||||
| String getExpMetrics(List<String> runIds) throws Exception; | String getExpMetrics(List<String> runIds) throws Exception; | ||||
| HashMap<String, Object> queryMetricsParams(String runId) throws Exception; | |||||
| List<Map<String, Object>> getBatchMetric(String runHash, String params); | |||||
| } | } | ||||
| @@ -101,6 +101,8 @@ public interface ModelsService { | |||||
| Page<Map<String, Object>> queryVersions(PageRequest pageRequest, String identifier, String owner) throws Exception; | Page<Map<String, Object>> queryVersions(PageRequest pageRequest, String identifier, String owner) throws Exception; | ||||
| List<List<Map<String, Object>>> queryVersionsMetrics(List<String> runIds) throws Exception; | |||||
| List<Map<String, Object>> getVersionList(String identifier, String owner) throws Exception; | List<Map<String, Object>> getVersionList(String identifier, String owner) throws Exception; | ||||
| ModelsVo getModelDetail(Integer id, String identifier, String owner, String version) throws Exception; | ModelsVo getModelDetail(Integer id, String identifier, String owner, String version) throws Exception; | ||||
| @@ -1,5 +1,6 @@ | |||||
| package com.ruoyi.platform.service.impl; | package com.ruoyi.platform.service.impl; | ||||
| import com.alibaba.fastjson2.JSON; | |||||
| import com.ruoyi.platform.domain.ExperimentIns; | import com.ruoyi.platform.domain.ExperimentIns; | ||||
| import com.ruoyi.platform.service.AimService; | import com.ruoyi.platform.service.AimService; | ||||
| import com.ruoyi.platform.service.ExperimentInsService; | import com.ruoyi.platform.service.ExperimentInsService; | ||||
| @@ -13,6 +14,7 @@ import org.springframework.beans.factory.annotation.Value; | |||||
| import org.springframework.stereotype.Service; | import org.springframework.stereotype.Service; | ||||
| import javax.annotation.Resource; | import javax.annotation.Resource; | ||||
| import java.io.UnsupportedEncodingException; | |||||
| import java.net.URLEncoder; | import java.net.URLEncoder; | ||||
| import java.util.*; | import java.util.*; | ||||
| import java.util.stream.Collectors; | import java.util.stream.Collectors; | ||||
| @@ -245,4 +247,44 @@ public class AimServiceImpl implements AimService { | |||||
| } | } | ||||
| return datasetList; | return datasetList; | ||||
| } | } | ||||
| @Override | |||||
| public HashMap<String, Object> queryMetricsParams(String runId) throws UnsupportedEncodingException { | |||||
| String encodedUrlString = URLEncoder.encode("run.id==\"" + runId + "\"", "UTF-8"); | |||||
| String url = aimProxyUrl + "/api/runs/search/run?query=" + encodedUrlString; | |||||
| String s = httpUtils.sendGet(url, null); | |||||
| List<Map<String, Object>> response = JacksonUtil.parseJSONStr2MapList(s); | |||||
| if (response == null || response.size() == 0) { | |||||
| return new HashMap<>(); | |||||
| } | |||||
| HashMap<String, Object> resultMap = new HashMap<>(); | |||||
| List<Map<String, Object>> paramList = new ArrayList<>(); | |||||
| Map<String, Object> map = response.get(0); | |||||
| LinkedHashMap<String, ArrayList> traces = (LinkedHashMap<String, ArrayList>) map.get("traces"); | |||||
| if (traces != null) { | |||||
| List<Map<String, Object>> metrics = traces.get("metric"); | |||||
| for (Map<String, Object> metric : metrics) { | |||||
| Map<String, Object> metricParam = new HashMap<>(); | |||||
| metricParam.put("context", metric.get("context")); | |||||
| metricParam.put("name", metric.get("name")); | |||||
| paramList.add(metricParam); | |||||
| } | |||||
| resultMap.put("params", JSON.toJSONString(paramList)); | |||||
| } | |||||
| resultMap.put("run_hash", map.get("run_hash")); | |||||
| return resultMap; | |||||
| } | |||||
| @Override | |||||
| public List<Map<String, Object>> getBatchMetric(String runHash, String params) { | |||||
| String url = aimUrl + "/api/runs/" + runHash + "/metric/get-batch"; | |||||
| String response = httpUtils.sendPost(url, null, params); | |||||
| if (StringUtils.isNotEmpty(response)) { | |||||
| return JacksonUtil.parseJSONStr2MapList(response); | |||||
| } | |||||
| return null; | |||||
| } | |||||
| } | } | ||||
| @@ -560,7 +560,7 @@ public class ExperimentServiceImpl implements ExperimentService { | |||||
| //处理指标 | //处理指标 | ||||
| HashMap<String, Object> metricMap = JSON.parseObject(metricRecord, HashMap.class); | HashMap<String, Object> metricMap = JSON.parseObject(metricRecord, HashMap.class); | ||||
| modelMetaVo.setMetrics(metricMap); | |||||
| modelMetaVo.setMetricsParams(metricMap); | |||||
| //训练数据集 | //训练数据集 | ||||
| List<Map<String, Object>> trainDatasetList = (List<Map<String, Object>>) modelTrainMap.get("datasets"); | List<Map<String, Object>> trainDatasetList = (List<Map<String, Object>>) modelTrainMap.get("datasets"); | ||||
| @@ -747,7 +747,7 @@ public class ModelsServiceImpl implements ModelsService { | |||||
| if (buildingModel != null) { | if (buildingModel != null) { | ||||
| modelMetaVo = JSON.parseObject(buildingModel.getMeta(), ModelMetaVo.class); | modelMetaVo = JSON.parseObject(buildingModel.getMeta(), ModelMetaVo.class); | ||||
| //获取指标 | //获取指标 | ||||
| getMetrics(modelMetaVo); | |||||
| transMetrics(modelMetaVo); | |||||
| } | } | ||||
| //拼接生产的元数据后写入yaml文件 | //拼接生产的元数据后写入yaml文件 | ||||
| @@ -806,7 +806,7 @@ public class ModelsServiceImpl implements ModelsService { | |||||
| modelDependency1Dao.insert(modelDependency); | modelDependency1Dao.insert(modelDependency); | ||||
| } else { | } else { | ||||
| //更新模型依赖 | //更新模型依赖 | ||||
| modelDependency1Dao.updateState(modelsVo.getId(), modelsVo.getIdentifier(), modelsVo.getVersion(), Constant.State_valid); | |||||
| modelDependency1Dao.updateState(modelsVo.getId(), modelsVo.getIdentifier(), modelsVo.getVersion(), meta, Constant.State_valid); | |||||
| } | } | ||||
| } else { | } else { | ||||
| //保存模型依赖 | //保存模型依赖 | ||||
| @@ -991,10 +991,24 @@ public class ModelsServiceImpl implements ModelsService { | |||||
| if (modelMetaVo.getMetrics() != null) { | if (modelMetaVo.getMetrics() != null) { | ||||
| branch.putAll(modelMetaVo.getMetrics()); | branch.putAll(modelMetaVo.getMetrics()); | ||||
| } | } | ||||
| if (modelMetaVo.getMetricsParams() != null) { | |||||
| branch.putAll(modelMetaVo.getMetricsParams()); | |||||
| } | |||||
| } | } | ||||
| return new PageImpl<>(result, pageRequest, collect.size()); | return new PageImpl<>(result, pageRequest, collect.size()); | ||||
| } | } | ||||
| @Override | |||||
| public List<List<Map<String, Object>>> queryVersionsMetrics(List<String> runIds) throws Exception { | |||||
| List<List<Map<String, Object>>> batchMetrics = new ArrayList<>(); | |||||
| for (String runId : runIds) { | |||||
| HashMap<String, Object> map = aimsService.queryMetricsParams(runId); | |||||
| List<Map<String, Object>> batchMetric = aimsService.getBatchMetric((String) map.get("run_hash"), (String) map.get("params")); | |||||
| batchMetrics.add(batchMetric); | |||||
| } | |||||
| return batchMetrics; | |||||
| } | |||||
| @Override | @Override | ||||
| public List<Map<String, Object>> getVersionList(String identifier, String owner) throws Exception { | public List<Map<String, Object>> getVersionList(String identifier, String owner) throws Exception { | ||||
| String token = gitService.checkoutToken(); | String token = gitService.checkoutToken(); | ||||
| @@ -1201,12 +1215,12 @@ public class ModelsServiceImpl implements ModelsService { | |||||
| return userInfo; | return userInfo; | ||||
| } | } | ||||
| void getMetrics(ModelMetaVo modelMetaVo) throws Exception { | |||||
| void transMetrics(ModelMetaVo modelMetaVo) throws Exception { | |||||
| HashMap<String, Object> result = new HashMap<>(); | HashMap<String, Object> result = new HashMap<>(); | ||||
| HashMap<String, Object> train = new HashMap<>(); | HashMap<String, Object> train = new HashMap<>(); | ||||
| HashMap<String, Object> evaluate = new HashMap<>(); | HashMap<String, Object> evaluate = new HashMap<>(); | ||||
| HashMap<String, Object> metrics = modelMetaVo.getMetrics(); | |||||
| HashMap<String, Object> metrics = modelMetaVo.getMetricsParams(); | |||||
| JSONArray trainMetrics = (JSONArray) metrics.get("train"); | JSONArray trainMetrics = (JSONArray) metrics.get("train"); | ||||
| if (trainMetrics != null) { | if (trainMetrics != null) { | ||||
| for (int i = 0; i < trainMetrics.size(); i++) { | for (int i = 0; i < trainMetrics.size(); i++) { | ||||
| @@ -59,6 +59,9 @@ public class ModelMetaVo implements Serializable { | |||||
| @ApiModelProperty(value = "指标") | @ApiModelProperty(value = "指标") | ||||
| private HashMap<String, Object> metrics; | private HashMap<String, Object> metrics; | ||||
| @ApiModelProperty(value = "指标查询参数") | |||||
| private HashMap<String, Object> metricsParams; | |||||
| @ApiModelProperty(value = "训练任务") | @ApiModelProperty(value = "训练任务") | ||||
| private TrainTaskDepency trainTask; | private TrainTaskDepency trainTask; | ||||
| @@ -86,7 +86,8 @@ | |||||
| <update id="updateState"> | <update id="updateState"> | ||||
| update model_dependency1 | update model_dependency1 | ||||
| set state = 1 | |||||
| set state = 1, | |||||
| meta = #{meta} | |||||
| where repo_id = #{repoId} | where repo_id = #{repoId} | ||||
| and identifier = #{identifier} | and identifier = #{identifier} | ||||
| and version = #{version} | and version = #{version} | ||||