Browse Source

自动超参数寻优实验功能开发

dev-ray
chenzhihang 1 year ago
parent
commit
5b3daa7019
4 changed files with 47 additions and 6 deletions
  1. +1
    -1
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/ray/RayInsController.java
  2. +6
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/domain/RayIns.java
  3. +1
    -1
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/RayInsService.java
  4. +39
    -4
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/RayInsServiceImpl.java

+ 1
- 1
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/ray/RayInsController.java View File

@@ -54,7 +54,7 @@ public class RayInsController extends BaseController {

@GetMapping("{id}")
@ApiOperation("查看实验实例详情")
public GenericsAjaxResult<RayIns> getDetailById(@PathVariable("id") Long id) {
public GenericsAjaxResult<RayIns> getDetailById(@PathVariable("id") Long id) throws IOException {
return genericsSuccess(this.rayInsService.getDetailById(id));
}
}

+ 6
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/domain/RayIns.java View File

@@ -1,12 +1,15 @@
package com.ruoyi.platform.domain;

import com.baomidou.mybatisplus.annotation.TableField;
import com.fasterxml.jackson.databind.PropertyNamingStrategy;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;

import java.util.ArrayList;
import java.util.Date;
import java.util.Map;

@Data
@JsonNaming(PropertyNamingStrategy.SnakeCaseStrategy.class)
@@ -41,5 +44,8 @@ public class RayIns {
private Date updateTime;

private Date finishTime;

@TableField(exist = false)
private ArrayList<Map<String, Object>> trialList;
}


+ 1
- 1
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/RayInsService.java View File

@@ -15,7 +15,7 @@ public interface RayInsService {

boolean terminateRayIns(Long id) throws Exception;

RayIns getDetailById(Long id);
RayIns getDetailById(Long id) throws IOException;

void updateRayStatus(Long rayId);
}

+ 39
- 4
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/RayInsServiceImpl.java View File

@@ -6,6 +6,7 @@ import com.ruoyi.platform.domain.RayIns;
import com.ruoyi.platform.mapper.RayDao;
import com.ruoyi.platform.mapper.RayInsDao;
import com.ruoyi.platform.service.RayInsService;
import com.ruoyi.platform.utils.JsonUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageImpl;
@@ -14,9 +15,11 @@ import org.springframework.stereotype.Service;

import javax.annotation.Resource;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
import java.util.stream.Collectors;

@Service("rayInsService")
public class RayInsServiceImpl implements RayInsService {
@@ -104,11 +107,12 @@ public class RayInsServiceImpl implements RayInsService {
}

@Override
public RayIns getDetailById(Long id) {
public RayIns getDetailById(Long id) throws IOException {
RayIns rayIns = rayInsDao.queryById(id);
if (Constant.Running.equals(rayIns.getStatus()) || Constant.Pending.equals(rayIns.getStatus())) {
//todo queryStatusFromArgo
}
rayIns.setTrialList(getTrialList(rayIns.getResultPath()));
return rayIns;
}

@@ -127,4 +131,35 @@ public class RayInsServiceImpl implements RayInsService {
rayDao.edit(ray);
}
}

public ArrayList<Map<String, Object>> getTrialList(String directoryPath) throws IOException {
// 获取指定路径下的所有文件
Path dirPath = Paths.get(directoryPath);
Path experimentState = Files.list(dirPath).filter(path -> Files.isRegularFile(path) && path.getFileName().toString().startsWith("experiment_state")).collect(Collectors.toList()).get(0);
String content = new String(Files.readAllBytes(experimentState));
Map<String, Object> result = JsonUtils.jsonToMap(content);
ArrayList<ArrayList> trial_data_list = (ArrayList<ArrayList>) result.get("trial_data");

ArrayList<Map<String, Object>> trialList = new ArrayList<>();

for (ArrayList trial_data : trial_data_list) {
Map<String, Object> trial_data_0 = JsonUtils.jsonToMap((String) trial_data.get(0));
Map<String, Object> trial_data_1 = JsonUtils.jsonToMap((String) trial_data.get(1));

Map<String, Object> trial = new HashMap<>();
trial.put("trial_id", trial_data_0.get("trial_id"));
trial.put("config", trial_data_0.get("config"));
trial.put("status", trial_data_0.get("status"));

Map<String, Object> last_result = (Map<String, Object>) trial_data_1.get("last_result");
Map<String, Object> metric_analysis = (Map<String, Object>) trial_data_1.get("metric_analysis");
Map<String, Object> time_total_s = (Map<String, Object>) metric_analysis.get("time_total_s");

trial.put("training_iteration", last_result.get("training_iteration"));
trial.put("time", time_total_s.get("avg"));

trialList.add(trial);
}
return trialList;
}
}

Loading…
Cancel
Save