Browse Source

自动超参数寻优实验功能开发

dev-ray
chenzhihang 11 months ago
parent
commit
3430691d08
6 changed files with 85 additions and 16 deletions
  1. +10
    -1
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/minio/MinioStorageController.java
  2. +1
    -4
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/domain/RayIns.java
  3. +3
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/MinioService.java
  4. +26
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/MinioServiceImpl.java
  5. +44
    -9
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/RayInsServiceImpl.java
  6. +1
    -2
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/MinioUtil.java

+ 10
- 1
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/minio/MinioStorageController.java View File

@@ -23,11 +23,20 @@ public class MinioStorageController {
@Resource
private MinioService minioService;


@GetMapping("/downloadFile")
@ApiOperation("下载单个文件")
public ResponseEntity<InputStreamResource> downloadFile(@RequestParam("path") String path) throws Exception {
String bucketName = path.substring(0, path.indexOf("/"));
String prefix = path.substring(path.indexOf("/")+1,path.length());
return minioService.downloadFile(bucketName, prefix);
}

@GetMapping("/download")
@ApiOperation(value = "minio存储下载", notes = "minio存储下载文件为zip包")
public ResponseEntity<InputStreamResource> downloadDataset(@RequestParam("path") String path) {
String bucketName = path.substring(0, path.indexOf("/"));
String prefix = path.substring(path.indexOf("/")+1,path.length())+"/";
String prefix = path.substring(path.indexOf("/")+1,path.length());
return minioService.downloadZipFile(bucketName,prefix);
}



+ 1
- 4
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/domain/RayIns.java View File

@@ -22,9 +22,6 @@ public class RayIns {

private String resultPath;

@TableField(exist = false)
private String resultTxt;

private Integer state;

private String status;
@@ -50,7 +47,7 @@ public class RayIns {
private Date finishTime;

@TableField(exist = false)
private List<Map> fileList;
private String resultTxt;

@TableField(exist = false)
private ArrayList<Map<String, Object>> trialList;


+ 3
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/MinioService.java View File

@@ -9,6 +9,9 @@ import java.util.List;
import java.util.Map;

public interface MinioService {

ResponseEntity<InputStreamResource> downloadFile(String bucketName , String path) throws Exception;

ResponseEntity<InputStreamResource> downloadZipFile(String bucketName , String path);

Map<String, String> uploadFile(String bucketName, String objectName, MultipartFile file ) throws Exception;


+ 26
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/MinioServiceImpl.java View File

@@ -31,6 +31,32 @@ public class MinioServiceImpl implements MinioService {
public MinioServiceImpl(MinioUtil minioUtil) {
this.minioUtil = minioUtil;
}

@Override
public ResponseEntity<InputStreamResource> downloadFile(String bucketName , String url) throws Exception {
try {
// 使用ByteArrayOutputStream来捕获下载的数据
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
minioUtil.downloadObject(bucketName, url, outputStream);

ByteArrayInputStream inputStream = new ByteArrayInputStream(outputStream.toByteArray());
InputStreamResource resource = new InputStreamResource(inputStream);

return ResponseEntity.ok()
.header(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"" + extractFileName(url) + "\"")
.contentType(MediaType.APPLICATION_OCTET_STREAM)
.body(resource);
} catch (Exception e) {
e.printStackTrace();
throw new Exception("下载文件错误");
}
}

private String extractFileName(String urlStr) {
return urlStr.substring(urlStr.lastIndexOf('/') + 1);
}


@Override
public ResponseEntity<InputStreamResource> downloadZipFile(String bucketName,String path) {
try {


+ 44
- 9
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/RayInsServiceImpl.java View File

@@ -11,6 +11,8 @@ import com.ruoyi.platform.utils.HttpUtils;
import com.ruoyi.platform.utils.JsonUtils;
import com.ruoyi.platform.utils.MinioUtil;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageImpl;
@@ -26,6 +28,8 @@ import java.util.stream.Collectors;

@Service("rayInsService")
public class RayInsServiceImpl implements RayInsService {
private static final Logger logger = LoggerFactory.getLogger(RayInsServiceImpl.class);

@Value("${argo.url}")
private String argoUrl;
@Value("${argo.workflowStatus}")
@@ -34,8 +38,6 @@ public class RayInsServiceImpl implements RayInsService {
private String argoWorkflowTermination;
@Value("${minio.endpointIp}")
String endpoint;
@Value("${minio.dataReleaseBucketName}")
private String bucketName;

@Resource
private RayInsDao rayInsDao;
@@ -174,8 +176,6 @@ public class RayInsServiceImpl implements RayInsService {
rayIns = queryStatusFromArgo(rayIns);
}
getTrialList(rayIns);

// rayIns.setTrialList(getTrialList(rayIns.getResultPath()));
return rayIns;
}

@@ -276,12 +276,12 @@ public class RayInsServiceImpl implements RayInsService {
rayIns.setResultPath(endpoint + "/" + directoryPath);
rayIns.setResultTxt(endpoint + "/" + directoryPath + "/result.txt");

String bucketName = directoryPath.substring(0, directoryPath.indexOf("/"));
String prefix = directoryPath.substring(directoryPath.indexOf("/") + 1, directoryPath.length()) + "/";
List<Map> maps = minioUtil.listRayFilesInDirectory(bucketName, prefix);
List<Map> fileMaps = minioUtil.listRayFilesInDirectory(bucketName, prefix);

if (!maps.isEmpty()) {
rayIns.setFileList(maps);
List<Map> collect = maps.stream().filter(map -> map.get("name").toString().startsWith("experiment_state")).collect(Collectors.toList());
if (!fileMaps.isEmpty()) {
List<Map> collect = fileMaps.stream().filter(map -> map.get("name").toString().startsWith("experiment_state")).collect(Collectors.toList());
if (!collect.isEmpty()) {
Path experimentState = Paths.get(collect.get(0).get("name").toString());
String content = minioUtil.readObjectAsString(bucketName, prefix + "/" + experimentState);
@@ -294,8 +294,9 @@ public class RayInsServiceImpl implements RayInsService {
Map<String, Object> trial_data_0 = JsonUtils.jsonToMap((String) trial_data.get(0));
Map<String, Object> trial_data_1 = JsonUtils.jsonToMap((String) trial_data.get(1));

String trialId = (String) trial_data_0.get("trial_id");
Map<String, Object> trial = new HashMap<>();
trial.put("trial_id", trial_data_0.get("trial_id"));
trial.put("trial_id", trialId);
trial.put("config", trial_data_0.get("config"));
trial.put("status", trial_data_0.get("status"));

@@ -310,10 +311,44 @@ public class RayInsServiceImpl implements RayInsService {
trial.put("metric_analysis", metric_analysis.get((String) param.get("metric")));
trial.put("metric", param.get("metric"));

for (Map fileMap : fileMaps) {
if (fileMap.get("name").toString().contains(trialId)) {
trial.put("file", fileMap);
}
}

try {
String resultTxt = minioUtil.readObjectAsString(bucketName, prefix + "result.txt");
String bestTrialId = getStringBetween(resultTxt, "'trial_id': '", "'");
if (bestTrialId.equals(trialId)) {
trial.put("is_best", true);
}
} catch (Exception e) {
logger.error("未找到结果文件:result.txt");
}
trialList.add(trial);
}
rayIns.setTrialList(trialList);
}
}
}

String getStringBetween(String input, String startMarker, String endMarker) {
int startIndex = input.indexOf(startMarker);
if (startIndex == -1) {
return ""; // 如果未找到起始标记,返回空字符串
}

// 跳过起始标记
startIndex += startMarker.length();

int endIndex = input.indexOf(endMarker, startIndex);
if (endIndex == -1) {
return ""; // 如果未找到结束标记,返回空字符串
}

return input.substring(startIndex, endIndex);
}
}



+ 1
- 2
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/MinioUtil.java View File

@@ -355,9 +355,8 @@ public class MinioUtil {
map.put("children", listRayFilesInDirectory(bucketName, fullPath));
} else {
map.put("isFile", true);
map.put("url", minioEndpoint + "/" + bucketName + "/" + fullPath);
}
map.put("url", bucketName + "/" + fullPath);
fileInfoList.add(map);
}



Loading…
Cancel
Save