Browse Source

Merge remote-tracking branch 'origin/dev' into test

dev-czh
chenzhihang 1 year ago
parent
commit
984c1bbd7f
1 changed files with 18 additions and 6 deletions
  1. +18
    -6
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelsServiceImpl.java

+ 18
- 6
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelsServiceImpl.java View File

@@ -988,7 +988,8 @@ public class ModelsServiceImpl implements ModelsService {
ModelMetaVo modelMetaVo = JSON.parseObject(meta, ModelMetaVo.class); ModelMetaVo modelMetaVo = JSON.parseObject(meta, ModelMetaVo.class);
if (modelMetaVo.getParams() != null) { if (modelMetaVo.getParams() != null) {
HashMap<String, Object> params = modelMetaVo.getParams(); HashMap<String, Object> params = modelMetaVo.getParams();
branch.putAll(params);
branch.put("params", params);

ArrayList<String> params_names = new ArrayList<>(); ArrayList<String> params_names = new ArrayList<>();
for (String key : params.keySet()) { for (String key : params.keySet()) {
params_names.add(key); params_names.add(key);
@@ -999,10 +1000,20 @@ public class ModelsServiceImpl implements ModelsService {
HashMap<String, Object> metrics = modelMetaVo.getMetrics(); HashMap<String, Object> metrics = modelMetaVo.getMetrics();
if (Constant.Type_Train.equals(type)) { if (Constant.Type_Train.equals(type)) {
Map<String, Object> trainMetrics = (Map<String, Object>) metrics.get(Constant.Type_Train); Map<String, Object> trainMetrics = (Map<String, Object>) metrics.get(Constant.Type_Train);
branch.putAll(trainMetrics);
ArrayList<String> metrics_names = new ArrayList<>();
for (String key : trainMetrics.keySet()) {
metrics_names.add(key);
}
branch.put("metrics_names", metrics_names);
branch.put("metrics",trainMetrics);
} else { } else {
Map<String, Object> evaluateMetrics = (Map<String, Object>) metrics.get(Constant.Type_Evaluate); Map<String, Object> evaluateMetrics = (Map<String, Object>) metrics.get(Constant.Type_Evaluate);
branch.putAll(evaluateMetrics);
ArrayList<String> metrics_names = new ArrayList<>();
for (String key : evaluateMetrics.keySet()) {
metrics_names.add(key);
}
branch.put("metrics_names", metrics_names);
branch.put("metrics",evaluateMetrics);
} }
} }
} }
@@ -1017,7 +1028,7 @@ public class ModelsServiceImpl implements ModelsService {


ModelDependency1 modelDependency1 = modelDependency1Dao.queryByRepoAndVersion(model.getRepoId(), model.getIdentifier(), model.getVersion()); ModelDependency1 modelDependency1 = modelDependency1Dao.queryByRepoAndVersion(model.getRepoId(), model.getIdentifier(), model.getVersion());
ModelMetaVo modelMetaVo = JSON.parseObject(modelDependency1.getMeta(), ModelMetaVo.class); ModelMetaVo modelMetaVo = JSON.parseObject(modelDependency1.getMeta(), ModelMetaVo.class);
HashMap<String, Object> metrics = modelMetaVo.getMetrics();
HashMap<String, Object> metrics = modelMetaVo.getMetricsParams();
if (Constant.Type_Train.equals(type)) { if (Constant.Type_Train.equals(type)) {
batchMetrics.add(metrics.get("tarinDetail")); batchMetrics.add(metrics.get("tarinDetail"));
} else { } else {
@@ -1260,7 +1271,7 @@ public class ModelsServiceImpl implements ModelsService {


} }
result.put("train", train); result.put("train", train);
result.put("tarinDetail", trainBatchMetrics);
metrics.put("tarinDetail", trainBatchMetrics);


} }


@@ -1282,10 +1293,11 @@ public class ModelsServiceImpl implements ModelsService {


} }
result.put("evaluate", evaluate); result.put("evaluate", evaluate);
result.put("evaluateDetail", evaluateBatchMetrics);
metrics.put("evaluateDetail", evaluateBatchMetrics);


} }
modelMetaVo.setMetrics(result); modelMetaVo.setMetrics(result);
modelMetaVo.setMetricsParams(metrics);
} }


} }

Loading…
Cancel
Save