diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/model/ModelsController.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/model/ModelsController.java index 2424b234..b2624c9c 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/model/ModelsController.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/model/ModelsController.java @@ -7,7 +7,9 @@ import com.ruoyi.platform.domain.ModelsVersion; import com.ruoyi.platform.service.ModelsService; import io.swagger.annotations.Api; import io.swagger.annotations.ApiOperation; +import org.springframework.core.io.InputStreamResource; import org.springframework.data.domain.PageRequest; +import org.springframework.http.ResponseEntity; import org.springframework.web.bind.annotation.*; import org.springframework.web.multipart.MultipartFile; @@ -127,20 +129,21 @@ public class ModelsController { */ @GetMapping("/download/{models_version_id}") @ApiOperation(value = "下载模型", notes = "根据模型版本表id下载模型文件。") - public AjaxResult downloadModels(@PathVariable("models_version_id") Integer models_version_id) { - return AjaxResult.success(modelsService.downloadModels(models_version_id)); + public ResponseEntity downloadModels(@PathVariable("models_version_id") Integer models_version_id) { + return modelsService.downloadModels(models_version_id); } /** * 模型打包下载 * - *@param models_version_id 模型版本表主键 + * @param modelsId 模型版本表主键 + * @param version 模型版本表主键 * @return 单条数据 */ - @GetMapping("/downloadAllFiles/{models_version_id}") - @ApiOperation(value = "下载模型", notes = "根据模型版本表id下载模型文件。") - public AjaxResult downloadAllModelFiles(@PathVariable("models_version_id") Integer models_version_id) { - return AjaxResult.success(modelsService.downloadAllModelFiles(models_version_id)); + @GetMapping("/downloadAllFiles") + @ApiOperation(value = "下载模型压缩包", notes = "根据模型ID和版本下载所有模型文件,并打包。") + public ResponseEntity downloadAllModelFiles(@RequestParam("models_id") Integer modelsId, @RequestParam("version") String version) { + return modelsService.downloadAllModelFiles(modelsId, version); } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/mapper/ModelsVersionDao.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/mapper/ModelsVersionDao.java index d4a843aa..a51ae5ac 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/mapper/ModelsVersionDao.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/mapper/ModelsVersionDao.java @@ -82,6 +82,8 @@ public interface ModelsVersionDao { List queryByModelsId(Integer modelsId); - ModelsVersion queryByModelsVersion(ModelsVersion modelsVersion); + ModelsVersion queryByModelsVersion(@Param("modelsVersion") ModelsVersion modelsVersion); + + List queryAllByModelsVersion(@Param("modelsId") Integer modelsId, @Param("version") String version); } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/ModelsService.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/ModelsService.java index 3461c596..5fc7ea9d 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/ModelsService.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/ModelsService.java @@ -72,5 +72,5 @@ public interface ModelsService { Map uploadModelsPipeline(ModelsVersion modelsVersion) throws Exception; - ResponseEntity downloadAllModelFiles(Integer modelsVersionId); + ResponseEntity downloadAllModelFiles(Integer modelsId, String version); } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelsServiceImpl.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelsServiceImpl.java index 14087f71..1452cb10 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelsServiceImpl.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelsServiceImpl.java @@ -1,8 +1,6 @@ package com.ruoyi.platform.service.impl; import com.ruoyi.common.security.utils.SecurityUtils; -import com.ruoyi.platform.domain.Dataset; -import com.ruoyi.platform.domain.DatasetVersion; import com.ruoyi.platform.domain.Models; import com.ruoyi.platform.domain.ModelsVersion; import com.ruoyi.platform.mapper.ModelsDao; @@ -14,7 +12,6 @@ import com.ruoyi.platform.utils.MinioUtil; import com.ruoyi.system.api.model.LoginUser; import io.minio.MinioClient; import org.apache.commons.lang3.StringUtils; -import org.springframework.beans.BeanUtils; import org.springframework.beans.factory.annotation.Value; import org.springframework.core.io.InputStreamResource; import org.springframework.data.domain.Page; @@ -31,10 +28,12 @@ import javax.annotation.Resource; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.InputStream; -import java.text.SimpleDateFormat; import java.util.Date; import java.util.HashMap; +import java.util.List; import java.util.Map; +import java.util.zip.ZipEntry; +import java.util.zip.ZipOutputStream; /** * (Models)表服务实现类 @@ -308,12 +307,50 @@ public class ModelsServiceImpl implements ModelsService { /** * 下载所有模型文件,用压缩包的方式返回 * - * @param modelsVersionId models_version表的主键 + * @param modelsId 模型ID + * @param version 模型版本号 * @return 文件内容 */ @Override - public ResponseEntity downloadAllModelFiles(Integer modelsVersionId) { - return null; + public ResponseEntity downloadAllModelFiles(Integer modelsId, String version) { + // 查询特定模型和版本对应的所有文件 + List modelsVersionList = this.modelsVersionDao.queryAllByModelsVersion(modelsId, version); + if (modelsVersionList == null || modelsVersionList.isEmpty()) { + return ResponseEntity.status(HttpStatus.NOT_FOUND).body(null); + } + + // 创建ZIP文件,准备打包下载 + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + try (ZipOutputStream zos = new ZipOutputStream(baos)) { + //遍历每一个version数据项 + + for (ModelsVersion modelsVersion : modelsVersionList) { + String objectName = modelsVersion.getUrl(); + if (objectName != null && !objectName.isEmpty()) { + // 使用ByteArrayOutputStream来捕获MinIO中的文件 + ByteArrayOutputStream fileOutputStream = new ByteArrayOutputStream(); + minioUtil.downloadObject(bucketName, objectName, fileOutputStream); + // 添加到ZIP文件 + zos.putNextEntry(new ZipEntry(extractFileName(objectName))); + fileOutputStream.writeTo(zos); + zos.closeEntry(); + } + } + + // 转换为输入流 + ByteArrayInputStream inputStream = new ByteArrayInputStream(baos.toByteArray()); + InputStreamResource resource = new InputStreamResource(inputStream); + + // 设置响应 + return ResponseEntity.ok() + .header(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"models_" + modelsId + "_version_" + version + ".zip\"") + .contentType(MediaType.APPLICATION_OCTET_STREAM) + .body(resource); + } catch (Exception e) { + e.printStackTrace(); + return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).body(null); + } + } private String extractFileName(String urlStr) { diff --git a/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/ModelsVersionDaoMapper.xml b/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/ModelsVersionDaoMapper.xml index bf38bf71..b7048474 100644 --- a/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/ModelsVersionDaoMapper.xml +++ b/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/ModelsVersionDaoMapper.xml @@ -31,6 +31,23 @@ + + + + +