Browse Source

自动超参数寻优实验功能开发

dev-ray
chenzhihang 11 months ago
parent
commit
3430691d08
6 changed files with 85 additions and 16 deletions
  1. +10
    -1
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/minio/MinioStorageController.java
  2. +1
    -4
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/domain/RayIns.java
  3. +3
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/MinioService.java
  4. +26
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/MinioServiceImpl.java
  5. +44
    -9
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/RayInsServiceImpl.java
  6. +1
    -2
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/MinioUtil.java

+ 10
- 1
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/minio/MinioStorageController.java View File

@@ -23,11 +23,20 @@ public class MinioStorageController {
@Resource @Resource
private MinioService minioService; private MinioService minioService;



@GetMapping("/downloadFile")
@ApiOperation("下载单个文件")
public ResponseEntity<InputStreamResource> downloadFile(@RequestParam("path") String path) throws Exception {
String bucketName = path.substring(0, path.indexOf("/"));
String prefix = path.substring(path.indexOf("/")+1,path.length());
return minioService.downloadFile(bucketName, prefix);
}

@GetMapping("/download") @GetMapping("/download")
@ApiOperation(value = "minio存储下载", notes = "minio存储下载文件为zip包") @ApiOperation(value = "minio存储下载", notes = "minio存储下载文件为zip包")
public ResponseEntity<InputStreamResource> downloadDataset(@RequestParam("path") String path) { public ResponseEntity<InputStreamResource> downloadDataset(@RequestParam("path") String path) {
String bucketName = path.substring(0, path.indexOf("/")); String bucketName = path.substring(0, path.indexOf("/"));
String prefix = path.substring(path.indexOf("/")+1,path.length())+"/";
String prefix = path.substring(path.indexOf("/")+1,path.length());
return minioService.downloadZipFile(bucketName,prefix); return minioService.downloadZipFile(bucketName,prefix);
} }




+ 1
- 4
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/domain/RayIns.java View File

@@ -22,9 +22,6 @@ public class RayIns {


private String resultPath; private String resultPath;


@TableField(exist = false)
private String resultTxt;

private Integer state; private Integer state;


private String status; private String status;
@@ -50,7 +47,7 @@ public class RayIns {
private Date finishTime; private Date finishTime;


@TableField(exist = false) @TableField(exist = false)
private List<Map> fileList;
private String resultTxt;


@TableField(exist = false) @TableField(exist = false)
private ArrayList<Map<String, Object>> trialList; private ArrayList<Map<String, Object>> trialList;


+ 3
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/MinioService.java View File

@@ -9,6 +9,9 @@ import java.util.List;
import java.util.Map; import java.util.Map;


public interface MinioService { public interface MinioService {

ResponseEntity<InputStreamResource> downloadFile(String bucketName , String path) throws Exception;

ResponseEntity<InputStreamResource> downloadZipFile(String bucketName , String path); ResponseEntity<InputStreamResource> downloadZipFile(String bucketName , String path);


Map<String, String> uploadFile(String bucketName, String objectName, MultipartFile file ) throws Exception; Map<String, String> uploadFile(String bucketName, String objectName, MultipartFile file ) throws Exception;


+ 26
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/MinioServiceImpl.java View File

@@ -31,6 +31,32 @@ public class MinioServiceImpl implements MinioService {
public MinioServiceImpl(MinioUtil minioUtil) { public MinioServiceImpl(MinioUtil minioUtil) {
this.minioUtil = minioUtil; this.minioUtil = minioUtil;
} }

@Override
public ResponseEntity<InputStreamResource> downloadFile(String bucketName , String url) throws Exception {
try {
// 使用ByteArrayOutputStream来捕获下载的数据
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
minioUtil.downloadObject(bucketName, url, outputStream);

ByteArrayInputStream inputStream = new ByteArrayInputStream(outputStream.toByteArray());
InputStreamResource resource = new InputStreamResource(inputStream);

return ResponseEntity.ok()
.header(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"" + extractFileName(url) + "\"")
.contentType(MediaType.APPLICATION_OCTET_STREAM)
.body(resource);
} catch (Exception e) {
e.printStackTrace();
throw new Exception("下载文件错误");
}
}

private String extractFileName(String urlStr) {
return urlStr.substring(urlStr.lastIndexOf('/') + 1);
}


@Override @Override
public ResponseEntity<InputStreamResource> downloadZipFile(String bucketName,String path) { public ResponseEntity<InputStreamResource> downloadZipFile(String bucketName,String path) {
try { try {


+ 44
- 9
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/RayInsServiceImpl.java View File

@@ -11,6 +11,8 @@ import com.ruoyi.platform.utils.HttpUtils;
import com.ruoyi.platform.utils.JsonUtils; import com.ruoyi.platform.utils.JsonUtils;
import com.ruoyi.platform.utils.MinioUtil; import com.ruoyi.platform.utils.MinioUtil;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.data.domain.Page; import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageImpl; import org.springframework.data.domain.PageImpl;
@@ -26,6 +28,8 @@ import java.util.stream.Collectors;


@Service("rayInsService") @Service("rayInsService")
public class RayInsServiceImpl implements RayInsService { public class RayInsServiceImpl implements RayInsService {
private static final Logger logger = LoggerFactory.getLogger(RayInsServiceImpl.class);

@Value("${argo.url}") @Value("${argo.url}")
private String argoUrl; private String argoUrl;
@Value("${argo.workflowStatus}") @Value("${argo.workflowStatus}")
@@ -34,8 +38,6 @@ public class RayInsServiceImpl implements RayInsService {
private String argoWorkflowTermination; private String argoWorkflowTermination;
@Value("${minio.endpointIp}") @Value("${minio.endpointIp}")
String endpoint; String endpoint;
@Value("${minio.dataReleaseBucketName}")
private String bucketName;


@Resource @Resource
private RayInsDao rayInsDao; private RayInsDao rayInsDao;
@@ -174,8 +176,6 @@ public class RayInsServiceImpl implements RayInsService {
rayIns = queryStatusFromArgo(rayIns); rayIns = queryStatusFromArgo(rayIns);
} }
getTrialList(rayIns); getTrialList(rayIns);

// rayIns.setTrialList(getTrialList(rayIns.getResultPath()));
return rayIns; return rayIns;
} }


@@ -276,12 +276,12 @@ public class RayInsServiceImpl implements RayInsService {
rayIns.setResultPath(endpoint + "/" + directoryPath); rayIns.setResultPath(endpoint + "/" + directoryPath);
rayIns.setResultTxt(endpoint + "/" + directoryPath + "/result.txt"); rayIns.setResultTxt(endpoint + "/" + directoryPath + "/result.txt");


String bucketName = directoryPath.substring(0, directoryPath.indexOf("/"));
String prefix = directoryPath.substring(directoryPath.indexOf("/") + 1, directoryPath.length()) + "/"; String prefix = directoryPath.substring(directoryPath.indexOf("/") + 1, directoryPath.length()) + "/";
List<Map> maps = minioUtil.listRayFilesInDirectory(bucketName, prefix);
List<Map> fileMaps = minioUtil.listRayFilesInDirectory(bucketName, prefix);


if (!maps.isEmpty()) {
rayIns.setFileList(maps);
List<Map> collect = maps.stream().filter(map -> map.get("name").toString().startsWith("experiment_state")).collect(Collectors.toList());
if (!fileMaps.isEmpty()) {
List<Map> collect = fileMaps.stream().filter(map -> map.get("name").toString().startsWith("experiment_state")).collect(Collectors.toList());
if (!collect.isEmpty()) { if (!collect.isEmpty()) {
Path experimentState = Paths.get(collect.get(0).get("name").toString()); Path experimentState = Paths.get(collect.get(0).get("name").toString());
String content = minioUtil.readObjectAsString(bucketName, prefix + "/" + experimentState); String content = minioUtil.readObjectAsString(bucketName, prefix + "/" + experimentState);
@@ -294,8 +294,9 @@ public class RayInsServiceImpl implements RayInsService {
Map<String, Object> trial_data_0 = JsonUtils.jsonToMap((String) trial_data.get(0)); Map<String, Object> trial_data_0 = JsonUtils.jsonToMap((String) trial_data.get(0));
Map<String, Object> trial_data_1 = JsonUtils.jsonToMap((String) trial_data.get(1)); Map<String, Object> trial_data_1 = JsonUtils.jsonToMap((String) trial_data.get(1));


String trialId = (String) trial_data_0.get("trial_id");
Map<String, Object> trial = new HashMap<>(); Map<String, Object> trial = new HashMap<>();
trial.put("trial_id", trial_data_0.get("trial_id"));
trial.put("trial_id", trialId);
trial.put("config", trial_data_0.get("config")); trial.put("config", trial_data_0.get("config"));
trial.put("status", trial_data_0.get("status")); trial.put("status", trial_data_0.get("status"));


@@ -310,10 +311,44 @@ public class RayInsServiceImpl implements RayInsService {
trial.put("metric_analysis", metric_analysis.get((String) param.get("metric"))); trial.put("metric_analysis", metric_analysis.get((String) param.get("metric")));
trial.put("metric", param.get("metric")); trial.put("metric", param.get("metric"));


for (Map fileMap : fileMaps) {
if (fileMap.get("name").toString().contains(trialId)) {
trial.put("file", fileMap);
}
}

try {
String resultTxt = minioUtil.readObjectAsString(bucketName, prefix + "result.txt");
String bestTrialId = getStringBetween(resultTxt, "'trial_id': '", "'");
if (bestTrialId.equals(trialId)) {
trial.put("is_best", true);
}
} catch (Exception e) {
logger.error("未找到结果文件:result.txt");
}
trialList.add(trial); trialList.add(trial);
} }
rayIns.setTrialList(trialList); rayIns.setTrialList(trialList);
} }
} }
} }

String getStringBetween(String input, String startMarker, String endMarker) {
int startIndex = input.indexOf(startMarker);
if (startIndex == -1) {
return ""; // 如果未找到起始标记,返回空字符串
}

// 跳过起始标记
startIndex += startMarker.length();

int endIndex = input.indexOf(endMarker, startIndex);
if (endIndex == -1) {
return ""; // 如果未找到结束标记,返回空字符串
}

return input.substring(startIndex, endIndex);
}
} }



+ 1
- 2
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/MinioUtil.java View File

@@ -355,9 +355,8 @@ public class MinioUtil {
map.put("children", listRayFilesInDirectory(bucketName, fullPath)); map.put("children", listRayFilesInDirectory(bucketName, fullPath));
} else { } else {
map.put("isFile", true); map.put("isFile", true);
map.put("url", minioEndpoint + "/" + bucketName + "/" + fullPath);
} }
map.put("url", bucketName + "/" + fullPath);
fileInfoList.add(map); fileInfoList.add(map);
} }




Loading…
Cancel
Save