Browse Source

自动超参数寻优实验功能开发

dev-ray
chenzhihang 11 months ago
parent
commit
29124867f5
3 changed files with 28 additions and 27 deletions
  1. +27
    -23
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/RayInsServiceImpl.java
  2. +1
    -2
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/RayServiceImpl.java
  3. +0
    -2
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/RayParamVo.java

+ 27
- 23
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/RayInsServiceImpl.java View File

@@ -19,7 +19,6 @@ import org.springframework.stereotype.Service;

import javax.annotation.Resource;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
@@ -33,6 +32,8 @@ public class RayInsServiceImpl implements RayInsService {
private String argoWorkflowStatus;
@Value("${argo.workflowTermination}")
private String argoWorkflowTermination;
@Value("${minio.dataReleaseBucketName}")
private String bucketName;

@Resource
private RayInsDao rayInsDao;
@@ -266,37 +267,40 @@ public class RayInsServiceImpl implements RayInsService {

public ArrayList<Map<String, Object>> getTrialList(String directoryPath) throws Exception {
// 获取指定路径下的所有文件
String bucketName = directoryPath.substring(0, directoryPath.indexOf("/"));
String prefix = directoryPath.substring(directoryPath.indexOf("/") + 1, directoryPath.length()) + "/";

List<Map> maps = minioUtil.listFilesInDirectory(bucketName, prefix);

Path dirPath = Paths.get(directoryPath);
Path experimentState = Files.list(dirPath).filter(path -> Files.isRegularFile(path) && path.getFileName().toString().startsWith("experiment_state")).collect(Collectors.toList()).get(0);
String content = new String(Files.readAllBytes(experimentState));
Map<String, Object> result = JsonUtils.jsonToMap(content);
ArrayList<ArrayList> trial_data_list = (ArrayList<ArrayList>) result.get("trial_data");
if (!maps.isEmpty()) {
List<Map> collect = maps.stream().filter(map -> map.get("name").toString().startsWith("experiment_state")).collect(Collectors.toList());
if (!collect.isEmpty()) {
Path experimentState = Paths.get(collect.get(0).get("name").toString());
String content = minioUtil.readObjectAsString(bucketName, prefix + "/" + experimentState);

ArrayList<Map<String, Object>> trialList = new ArrayList<>();
Map<String, Object> result = JsonUtils.jsonToMap(content);
ArrayList<ArrayList> trial_data_list = (ArrayList<ArrayList>) result.get("trial_data");
ArrayList<Map<String, Object>> trialList = new ArrayList<>();

for (ArrayList trial_data : trial_data_list) {
Map<String, Object> trial_data_0 = JsonUtils.jsonToMap((String) trial_data.get(0));
Map<String, Object> trial_data_1 = JsonUtils.jsonToMap((String) trial_data.get(1));
for (ArrayList trial_data : trial_data_list) {
Map<String, Object> trial_data_0 = JsonUtils.jsonToMap((String) trial_data.get(0));
Map<String, Object> trial_data_1 = JsonUtils.jsonToMap((String) trial_data.get(1));

Map<String, Object> trial = new HashMap<>();
trial.put("trial_id", trial_data_0.get("trial_id"));
trial.put("config", trial_data_0.get("config"));
trial.put("status", trial_data_0.get("status"));
Map<String, Object> trial = new HashMap<>();
trial.put("trial_id", trial_data_0.get("trial_id"));
trial.put("config", trial_data_0.get("config"));
trial.put("status", trial_data_0.get("status"));

Map<String, Object> last_result = (Map<String, Object>) trial_data_1.get("last_result");
Map<String, Object> metric_analysis = (Map<String, Object>) trial_data_1.get("metric_analysis");
Map<String, Object> time_total_s = (Map<String, Object>) metric_analysis.get("time_total_s");
Map<String, Object> last_result = (Map<String, Object>) trial_data_1.get("last_result");
Map<String, Object> metric_analysis = (Map<String, Object>) trial_data_1.get("metric_analysis");
Map<String, Object> time_total_s = (Map<String, Object>) metric_analysis.get("time_total_s");

trial.put("training_iteration", last_result.get("training_iteration"));
trial.put("time", time_total_s.get("avg"));
trial.put("training_iteration", last_result.get("training_iteration"));
trial.put("time", time_total_s.get("avg"));

trialList.add(trial);
trialList.add(trial);
}
return trialList;
}
}
return trialList;
return null;
}
}

+ 1
- 2
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/RayServiceImpl.java View File

@@ -197,8 +197,7 @@ public class RayServiceImpl implements RayService {
Map<String, Object> param_output = (Map<String, Object>) output.get("param_output");
List output1 = (ArrayList) param_output.values().toArray()[0];
Map<String, String> output2 = (Map<String, String>) output1.get(0);
String outputPath = minioEndpoint + "/" + output2.get("path").replace("{{workflow.name}}", (String) metadata.get("name")) + "/" + ray.getName();

String outputPath = output2.get("path").replace("{{workflow.name}}", (String) metadata.get("name")) + "/hpo";
rayIns.setResultPath(outputPath);
rayInsDao.insert(rayIns);
rayInsService.updateRayStatus(id);


+ 0
- 2
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/RayParamVo.java View File

@@ -24,8 +24,6 @@ public class RayParamVo {

private String mainPy;

private String name;

private Integer numSamples;

private String parameters;


Loading…
Cancel
Save