| @@ -7,7 +7,9 @@ import com.ruoyi.platform.domain.ModelsVersion; | |||||
| import com.ruoyi.platform.service.ModelsService; | import com.ruoyi.platform.service.ModelsService; | ||||
| import io.swagger.annotations.Api; | import io.swagger.annotations.Api; | ||||
| import io.swagger.annotations.ApiOperation; | import io.swagger.annotations.ApiOperation; | ||||
| import org.springframework.core.io.InputStreamResource; | |||||
| import org.springframework.data.domain.PageRequest; | import org.springframework.data.domain.PageRequest; | ||||
| import org.springframework.http.ResponseEntity; | |||||
| import org.springframework.web.bind.annotation.*; | import org.springframework.web.bind.annotation.*; | ||||
| import org.springframework.web.multipart.MultipartFile; | import org.springframework.web.multipart.MultipartFile; | ||||
| @@ -127,20 +129,21 @@ public class ModelsController { | |||||
| */ | */ | ||||
| @GetMapping("/download/{models_version_id}") | @GetMapping("/download/{models_version_id}") | ||||
| @ApiOperation(value = "下载模型", notes = "根据模型版本表id下载模型文件。") | @ApiOperation(value = "下载模型", notes = "根据模型版本表id下载模型文件。") | ||||
| public AjaxResult downloadModels(@PathVariable("models_version_id") Integer models_version_id) { | |||||
| return AjaxResult.success(modelsService.downloadModels(models_version_id)); | |||||
| public ResponseEntity<InputStreamResource> downloadModels(@PathVariable("models_version_id") Integer models_version_id) { | |||||
| return modelsService.downloadModels(models_version_id); | |||||
| } | } | ||||
| /** | /** | ||||
| * 模型打包下载 | * 模型打包下载 | ||||
| * | * | ||||
| *@param models_version_id 模型版本表主键 | |||||
| * @param modelsId 模型版本表主键 | |||||
| * @param version 模型版本表主键 | |||||
| * @return 单条数据 | * @return 单条数据 | ||||
| */ | */ | ||||
| @GetMapping("/downloadAllFiles/{models_version_id}") | |||||
| @ApiOperation(value = "下载模型", notes = "根据模型版本表id下载模型文件。") | |||||
| public AjaxResult downloadAllModelFiles(@PathVariable("models_version_id") Integer models_version_id) { | |||||
| return AjaxResult.success(modelsService.downloadAllModelFiles(models_version_id)); | |||||
| @GetMapping("/downloadAllFiles") | |||||
| @ApiOperation(value = "下载模型压缩包", notes = "根据模型ID和版本下载所有模型文件,并打包。") | |||||
| public ResponseEntity<InputStreamResource> downloadAllModelFiles(@RequestParam("models_id") Integer modelsId, @RequestParam("version") String version) { | |||||
| return modelsService.downloadAllModelFiles(modelsId, version); | |||||
| } | } | ||||
| @@ -82,6 +82,8 @@ public interface ModelsVersionDao { | |||||
| List<ModelsVersion> queryByModelsId(Integer modelsId); | List<ModelsVersion> queryByModelsId(Integer modelsId); | ||||
| ModelsVersion queryByModelsVersion(ModelsVersion modelsVersion); | |||||
| ModelsVersion queryByModelsVersion(@Param("modelsVersion") ModelsVersion modelsVersion); | |||||
| List<ModelsVersion> queryAllByModelsVersion(@Param("modelsId") Integer modelsId, @Param("version") String version); | |||||
| } | } | ||||
| @@ -72,5 +72,5 @@ public interface ModelsService { | |||||
| Map uploadModelsPipeline(ModelsVersion modelsVersion) throws Exception; | Map uploadModelsPipeline(ModelsVersion modelsVersion) throws Exception; | ||||
| ResponseEntity<InputStreamResource> downloadAllModelFiles(Integer modelsVersionId); | |||||
| ResponseEntity<InputStreamResource> downloadAllModelFiles(Integer modelsId, String version); | |||||
| } | } | ||||
| @@ -1,8 +1,6 @@ | |||||
| package com.ruoyi.platform.service.impl; | package com.ruoyi.platform.service.impl; | ||||
| import com.ruoyi.common.security.utils.SecurityUtils; | import com.ruoyi.common.security.utils.SecurityUtils; | ||||
| import com.ruoyi.platform.domain.Dataset; | |||||
| import com.ruoyi.platform.domain.DatasetVersion; | |||||
| import com.ruoyi.platform.domain.Models; | import com.ruoyi.platform.domain.Models; | ||||
| import com.ruoyi.platform.domain.ModelsVersion; | import com.ruoyi.platform.domain.ModelsVersion; | ||||
| import com.ruoyi.platform.mapper.ModelsDao; | import com.ruoyi.platform.mapper.ModelsDao; | ||||
| @@ -14,7 +12,6 @@ import com.ruoyi.platform.utils.MinioUtil; | |||||
| import com.ruoyi.system.api.model.LoginUser; | import com.ruoyi.system.api.model.LoginUser; | ||||
| import io.minio.MinioClient; | import io.minio.MinioClient; | ||||
| import org.apache.commons.lang3.StringUtils; | import org.apache.commons.lang3.StringUtils; | ||||
| import org.springframework.beans.BeanUtils; | |||||
| import org.springframework.beans.factory.annotation.Value; | import org.springframework.beans.factory.annotation.Value; | ||||
| import org.springframework.core.io.InputStreamResource; | import org.springframework.core.io.InputStreamResource; | ||||
| import org.springframework.data.domain.Page; | import org.springframework.data.domain.Page; | ||||
| @@ -31,10 +28,12 @@ import javax.annotation.Resource; | |||||
| import java.io.ByteArrayInputStream; | import java.io.ByteArrayInputStream; | ||||
| import java.io.ByteArrayOutputStream; | import java.io.ByteArrayOutputStream; | ||||
| import java.io.InputStream; | import java.io.InputStream; | ||||
| import java.text.SimpleDateFormat; | |||||
| import java.util.Date; | import java.util.Date; | ||||
| import java.util.HashMap; | import java.util.HashMap; | ||||
| import java.util.List; | |||||
| import java.util.Map; | import java.util.Map; | ||||
| import java.util.zip.ZipEntry; | |||||
| import java.util.zip.ZipOutputStream; | |||||
| /** | /** | ||||
| * (Models)表服务实现类 | * (Models)表服务实现类 | ||||
| @@ -308,12 +307,50 @@ public class ModelsServiceImpl implements ModelsService { | |||||
| /** | /** | ||||
| * 下载所有模型文件,用压缩包的方式返回 | * 下载所有模型文件,用压缩包的方式返回 | ||||
| * | * | ||||
| * @param modelsVersionId models_version表的主键 | |||||
| * @param modelsId 模型ID | |||||
| * @param version 模型版本号 | |||||
| * @return 文件内容 | * @return 文件内容 | ||||
| */ | */ | ||||
| @Override | @Override | ||||
| public ResponseEntity<InputStreamResource> downloadAllModelFiles(Integer modelsVersionId) { | |||||
| return null; | |||||
| public ResponseEntity<InputStreamResource> downloadAllModelFiles(Integer modelsId, String version) { | |||||
| // 查询特定模型和版本对应的所有文件 | |||||
| List<ModelsVersion> modelsVersionList = this.modelsVersionDao.queryAllByModelsVersion(modelsId, version); | |||||
| if (modelsVersionList == null || modelsVersionList.isEmpty()) { | |||||
| return ResponseEntity.status(HttpStatus.NOT_FOUND).body(null); | |||||
| } | |||||
| // 创建ZIP文件,准备打包下载 | |||||
| ByteArrayOutputStream baos = new ByteArrayOutputStream(); | |||||
| try (ZipOutputStream zos = new ZipOutputStream(baos)) { | |||||
| //遍历每一个version数据项 | |||||
| for (ModelsVersion modelsVersion : modelsVersionList) { | |||||
| String objectName = modelsVersion.getUrl(); | |||||
| if (objectName != null && !objectName.isEmpty()) { | |||||
| // 使用ByteArrayOutputStream来捕获MinIO中的文件 | |||||
| ByteArrayOutputStream fileOutputStream = new ByteArrayOutputStream(); | |||||
| minioUtil.downloadObject(bucketName, objectName, fileOutputStream); | |||||
| // 添加到ZIP文件 | |||||
| zos.putNextEntry(new ZipEntry(extractFileName(objectName))); | |||||
| fileOutputStream.writeTo(zos); | |||||
| zos.closeEntry(); | |||||
| } | |||||
| } | |||||
| // 转换为输入流 | |||||
| ByteArrayInputStream inputStream = new ByteArrayInputStream(baos.toByteArray()); | |||||
| InputStreamResource resource = new InputStreamResource(inputStream); | |||||
| // 设置响应 | |||||
| return ResponseEntity.ok() | |||||
| .header(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"models_" + modelsId + "_version_" + version + ".zip\"") | |||||
| .contentType(MediaType.APPLICATION_OCTET_STREAM) | |||||
| .body(resource); | |||||
| } catch (Exception e) { | |||||
| e.printStackTrace(); | |||||
| return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).body(null); | |||||
| } | |||||
| } | } | ||||
| private String extractFileName(String urlStr) { | private String extractFileName(String urlStr) { | ||||
| @@ -31,6 +31,23 @@ | |||||
| </select> | </select> | ||||
| <select id="queryByModelsVersion" resultMap="ModelsVersionMap"> | <select id="queryByModelsVersion" resultMap="ModelsVersionMap"> | ||||
| select | |||||
| id,models_id,version,url,file_name,file_size,status,create_by,create_time,update_by,update_time,state | |||||
| from models_version | |||||
| <where> | |||||
| state = 1 | |||||
| <if test="modelsId != null"> | |||||
| and models_id = #{modelsVersion.modelsId} | |||||
| </if> | |||||
| <if test="version != null and version != ''"> | |||||
| and version = #{modelsVersion.version} | |||||
| </if> | |||||
| </where> | |||||
| limit 1 | |||||
| </select> | |||||
| <!-- 查询模型同一个版本下的所有文件 --> | |||||
| <select id="queryAllByModelsVersion" resultMap="ModelsVersionMap"> | |||||
| select | select | ||||
| id,models_id,version,url,file_name,file_size,status,create_by,create_time,update_by,update_time,state | id,models_id,version,url,file_name,file_size,status,create_by,create_time,update_by,update_time,state | ||||
| from models_version | from models_version | ||||
| @@ -43,9 +60,10 @@ | |||||
| and version = #{version} | and version = #{version} | ||||
| </if> | </if> | ||||
| </where> | </where> | ||||
| limit 1 | |||||
| </select> | </select> | ||||
| <!--查询指定行数据--> | <!--查询指定行数据--> | ||||
| <select id="queryAllByLimit" resultMap="ModelsVersionMap"> | <select id="queryAllByLimit" resultMap="ModelsVersionMap"> | ||||
| select | select | ||||