Browse Source

自动超参数寻优实验功能开发

dev-ray
chenzhihang 11 months ago
parent
commit
756967a6fd
1 changed files with 53 additions and 52 deletions
  1. +53
    -52
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/RayInsServiceImpl.java

+ 53
- 52
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/RayInsServiceImpl.java View File

@@ -280,8 +280,8 @@ public class RayInsServiceImpl implements RayInsService {
List<Map<String, Object>> responses = JacksonUtil.parseJSONStr2MapList(s);

List<String> runIds = new ArrayList<>();
for (Map response:responses) {
runIds.add((String)response.get("run_hash"));
for (Map response : responses) {
runIds.add((String) response.get("run_hash"));
}

String decode = AIM64EncoderUtil.decode(runIds);
@@ -294,64 +294,65 @@ public class RayInsServiceImpl implements RayInsService {

rayIns.setResultPath(endpoint + "/" + directoryPath);
rayIns.setResultTxt(endpoint + "/" + directoryPath + "/result.txt");

String bucketName = directoryPath.substring(0, directoryPath.indexOf("/"));
String prefix = directoryPath.substring(directoryPath.indexOf("/") + 1, directoryPath.length()) + "/";
List<Map> fileMaps = minioUtil.listRayFilesInDirectory(bucketName, prefix);

if (!fileMaps.isEmpty()) {
List<Map> collect = fileMaps.stream().filter(map -> map.get("name").toString().startsWith("experiment_state")).collect(Collectors.toList());
if (!collect.isEmpty()) {
Path experimentState = Paths.get(collect.get(0).get("name").toString());
String content = minioUtil.readObjectAsString(bucketName, prefix + "/" + experimentState);

String resultTxt = minioUtil.readObjectAsString(bucketName, prefix + "result.txt");
String bestMetrics = getStringBetween(resultTxt, "Best metrics:", "Best result_df");
Map<String, Object> bestMetricsMap = JsonUtils.jsonToMap(bestMetrics);
String bestTrialId = (String)bestMetricsMap.get("trial_id");

Map<String, Object> result = JsonUtils.jsonToMap(content);
ArrayList<ArrayList> trial_data_list = (ArrayList<ArrayList>) result.get("trial_data");
ArrayList<Map<String, Object>> trialList = new ArrayList<>();

for (ArrayList trial_data : trial_data_list) {
Map<String, Object> trial_data_0 = JsonUtils.jsonToMap((String) trial_data.get(0));
Map<String, Object> trial_data_1 = JsonUtils.jsonToMap((String) trial_data.get(1));

String trialId = (String) trial_data_0.get("trial_id");
Map<String, Object> trial = new HashMap<>();
trial.put("trial_id", trialId);
trial.put("config", trial_data_0.get("config"));
trial.put("status", trial_data_0.get("status"));

Map<String, Object> last_result = (Map<String, Object>) trial_data_1.get("last_result");
Map<String, Object> metric_analysis = (Map<String, Object>) trial_data_1.get("metric_analysis");
Map<String, Object> time_total_s = (Map<String, Object>) metric_analysis.get("time_total_s");

trial.put("training_iteration", last_result.get("training_iteration"));
trial.put("time_avg", time_total_s.get("avg"));

Map<String, Object> param = JsonUtils.jsonToMap(rayIns.getParam());
trial.put("metric_analysis", metric_analysis.get((String) param.get("metric")));
trial.put("metric", param.get("metric"));

for (Map fileMap : fileMaps) {
if (fileMap.get("name").toString().contains(trialId)) {
trial.put("file", fileMap);
try {
String bucketName = directoryPath.substring(0, directoryPath.indexOf("/"));
String prefix = directoryPath.substring(directoryPath.indexOf("/") + 1, directoryPath.length()) + "/";
List<Map> fileMaps = minioUtil.listRayFilesInDirectory(bucketName, prefix);

if (!fileMaps.isEmpty()) {
List<Map> collect = fileMaps.stream().filter(map -> map.get("name").toString().startsWith("experiment_state")).collect(Collectors.toList());
if (!collect.isEmpty()) {
Path experimentState = Paths.get(collect.get(0).get("name").toString());
String content = minioUtil.readObjectAsString(bucketName, prefix + "/" + experimentState);

String resultTxt = minioUtil.readObjectAsString(bucketName, prefix + "result.txt");
String bestMetrics = getStringBetween(resultTxt, "Best metrics:", "Best result_df");
Map<String, Object> bestMetricsMap = JsonUtils.jsonToMap(bestMetrics);
String bestTrialId = (String) bestMetricsMap.get("trial_id");

Map<String, Object> result = JsonUtils.jsonToMap(content);
ArrayList<ArrayList> trial_data_list = (ArrayList<ArrayList>) result.get("trial_data");
ArrayList<Map<String, Object>> trialList = new ArrayList<>();

for (ArrayList trial_data : trial_data_list) {
Map<String, Object> trial_data_0 = JsonUtils.jsonToMap((String) trial_data.get(0));
Map<String, Object> trial_data_1 = JsonUtils.jsonToMap((String) trial_data.get(1));

String trialId = (String) trial_data_0.get("trial_id");
Map<String, Object> trial = new HashMap<>();
trial.put("trial_id", trialId);
trial.put("config", trial_data_0.get("config"));
trial.put("status", trial_data_0.get("status"));

Map<String, Object> last_result = (Map<String, Object>) trial_data_1.get("last_result");
Map<String, Object> metric_analysis = (Map<String, Object>) trial_data_1.get("metric_analysis");
Map<String, Object> time_total_s = (Map<String, Object>) metric_analysis.get("time_total_s");

trial.put("training_iteration", last_result.get("training_iteration"));
trial.put("time_avg", time_total_s.get("avg"));

Map<String, Object> param = JsonUtils.jsonToMap(rayIns.getParam());
trial.put("metric_analysis", metric_analysis.get((String) param.get("metric")));
trial.put("metric", param.get("metric"));

for (Map fileMap : fileMaps) {
if (fileMap.get("name").toString().contains(trialId)) {
trial.put("file", fileMap);
}
}
}

try {
if (bestTrialId.equals(trialId)) {
trial.put("is_best", true);
trialList.add(0, trial);
} else {
trialList.add(trial);
}
} catch (Exception e) {
logger.error("未找到结果文件:result.txt");
}
trialList.add(trial);
rayIns.setTrialList(trialList);
}
rayIns.setTrialList(trialList);
}
} catch (Exception e) {
logger.error("未找到结果文件:result.txt");
}
}



Loading…
Cancel
Save