Browse Source

Merge remote-tracking branch 'origin/dev-czh' into dev

dev-restore_mount
chenzhihang 1 year ago
parent
commit
a912b24ad2
6 changed files with 433 additions and 124 deletions
  1. +1
    -1
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/dataset/NewDatasetFromGitController.java
  2. +86
    -0
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/model/NewModelFromGitController.java
  3. +16
    -5
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/ModelsService.java
  4. +314
    -48
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelsServiceImpl.java
  5. +16
    -68
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/ModelsVo.java
  6. +0
    -2
      ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/NewDatasetVo.java

+ 1
- 1
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/dataset/NewDatasetFromGitController.java View File

@@ -15,7 +15,7 @@ import org.springframework.web.multipart.MultipartFile;
import javax.annotation.Resource;
@RestController
@RequestMapping("newdataset")
@Api(value = "新数据集管理")
//@Api(value = "新数据集管理")
public class NewDatasetFromGitController {
/**
* 服务对象


+ 86
- 0
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/model/NewModelFromGitController.java View File

@@ -0,0 +1,86 @@
package com.ruoyi.platform.controller.model;

import com.ruoyi.common.core.web.domain.AjaxResult;
import com.ruoyi.platform.service.ModelsService;
import com.ruoyi.platform.vo.ModelsVo;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;
import org.springframework.core.io.InputStreamResource;
import org.springframework.data.domain.PageRequest;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;

import javax.annotation.Resource;

@RestController
@RequestMapping("newmodel")
@Api(value = "新模型管理")
public class NewModelFromGitController {

@Resource
private ModelsService modelsService;

@PostMapping("/addModelAndVersion")
@ApiOperation("添加模型和版本")
public AjaxResult addModelAndVersion(@RequestBody ModelsVo modelsVo) throws Exception {
return AjaxResult.success(this.modelsService.newCreateModel(modelsVo));
}

@PostMapping("/addVersion")
@ApiOperation("添加版本")
public AjaxResult addVersion(@RequestBody ModelsVo modelsVo) throws Exception {
return AjaxResult.success(this.modelsService.newCreateVersion(modelsVo));
}

/**
* 模型打包下载
*
* @param version 模型版本
* @return 模型
*/
@GetMapping("/downloadAllFiles")
@ApiOperation(value = "下载同一版本下所有模型,并打包")
public ResponseEntity<InputStreamResource> downloadAllDatasetFiles(@RequestParam(value = "repository_name") String repositoryName,
@RequestParam(value = "version") String version,
@RequestParam(value = "git_link_username") String gitLinkUsername,
@RequestParam(value = "git_link_password") String gitLinkPassword) throws Exception {
return modelsService.downloadAllModelFilesNew(repositoryName, version, gitLinkUsername, gitLinkPassword);
}


/**
* 下载模型
*
* @param model_version_id ps:这里的id是model_version表的主键
* @return 模型
*/

@GetMapping("/download/{model_version_id}")
@ApiOperation(value = "下载单个模型文件", notes = "根据模型版本表id下载单个模型文件")
public ResponseEntity<InputStreamResource> downloadModel(@PathVariable("model_version_id") Integer model_version_id) throws Exception {
return modelsService.downloadModels(model_version_id);
}

@GetMapping("/queryDatasets")
@ApiOperation("模型广场公开模型分页查询,根据model_type,model_tag筛选,true公开false私有")
public AjaxResult queryDatasets(@RequestParam(value = "page") int page,
@RequestParam(value = "size") int size,
@RequestParam(value = "is_public") Boolean isPublic,
@RequestParam(value = "model_type", required = false) String modelType,
@RequestParam(value = "model_tag", required = false) String modelTag,
@RequestParam(value = "git_link_username") String gitLinkUsername,
@RequestParam(value = "git_link_password") String gitLinkPassword) throws Exception {
PageRequest pageRequest = PageRequest.of(page, size);
ModelsVo modelsVo = new ModelsVo();
modelsVo.setModelType(modelType);
modelsVo.setModelTag(modelTag);
modelsVo.setGitLinkUsername(gitLinkUsername);
modelsVo.setGitLinkPassword(gitLinkPassword);
if (isPublic) {
return AjaxResult.success(this.modelsService.newPubilcQueryByPage(modelsVo, pageRequest));
} else {
return AjaxResult.success(this.modelsService.newPersonalQueryByPage(modelsVo, pageRequest));
}
}

}

+ 16
- 5
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/ModelsService.java View File

@@ -1,8 +1,6 @@
package com.ruoyi.platform.service;



import com.ruoyi.platform.domain.Dataset;
import com.ruoyi.platform.domain.Models;
import com.ruoyi.platform.domain.ModelsVersion;
import com.ruoyi.platform.vo.ModelsVo;
@@ -12,8 +10,10 @@ import org.springframework.data.domain.PageRequest;
import org.springframework.http.ResponseEntity;
import org.springframework.web.multipart.MultipartFile;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;

/**
* (Models)表服务接口
@@ -34,8 +34,8 @@ public interface ModelsService {
/**
* 分页查询
*
* @param models 筛选条件
* @param pageRequest 分页对象
* @param models 筛选条件
* @param pageRequest 分页对象
* @return 查询结果
*/
Page<Models> queryByPage(Models models, PageRequest pageRequest);
@@ -70,12 +70,12 @@ public interface ModelsService {
ResponseEntity<InputStreamResource> downloadModels(Integer id) throws Exception;



List<Map<String, String>> uploadModels(MultipartFile[] files, String uuid) throws Exception;

Map uploadModelsPipeline(ModelsVersion modelsVersion) throws Exception;

ResponseEntity<InputStreamResource> downloadAllModelFiles(Integer modelsId, String version) throws Exception;

List<String> getModelVersions(Integer modelId) throws Exception;

String insertModelAndVersion(ModelsVo modelsVo) throws Exception;
@@ -85,4 +85,15 @@ public interface ModelsService {
public void checkDeclaredName(Models insert) throws Exception;

List<Map<String, String>> exportModels(String path, String uuid) throws Exception;


CompletableFuture<String> newCreateModel(ModelsVo modelsVo) throws Exception;

CompletableFuture<String> newCreateVersion(ModelsVo modelsVo);

ResponseEntity<InputStreamResource> downloadAllModelFilesNew(String repositoryName, String version, String gitLinkUsername, String gitLinkPassword) throws IOException, Exception;

Page<ModelsVo> newPubilcQueryByPage(ModelsVo modelsVo, PageRequest pageRequest) throws Exception;

Page<ModelsVo> newPersonalQueryByPage(ModelsVo modelsVo, PageRequest pageRequest) throws Exception;
}

+ 314
- 48
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelsServiceImpl.java View File

@@ -1,26 +1,23 @@
package com.ruoyi.platform.service.impl;

import com.ruoyi.common.core.utils.DateUtils;
import com.ruoyi.common.security.utils.SecurityUtils;
import com.ruoyi.platform.annotations.CheckDuplicate;
import com.ruoyi.platform.domain.AssetIcon;
import com.ruoyi.platform.domain.Dataset;
import com.ruoyi.platform.domain.Models;
import com.ruoyi.platform.domain.ModelsVersion;
import com.ruoyi.platform.mapper.ModelsDao;
import com.ruoyi.platform.mapper.ModelsVersionDao;
import com.ruoyi.platform.service.AssetIconService;
import com.ruoyi.platform.service.MinioService;
import com.ruoyi.platform.service.ModelsService;
import com.ruoyi.platform.service.ModelsVersionService;
import com.ruoyi.platform.utils.BeansUtils;
import com.ruoyi.platform.utils.FileUtil;
import com.ruoyi.platform.utils.MinioUtil;
import com.ruoyi.platform.service.*;
import com.ruoyi.platform.utils.*;
import com.ruoyi.platform.vo.GitProjectVo;
import com.ruoyi.platform.vo.ModelsVo;
import com.ruoyi.platform.vo.VersionVo;
import com.ruoyi.system.api.model.LoginUser;
import io.minio.messages.Item;
import io.netty.util.Version;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.InputStreamResource;
import org.springframework.data.domain.Page;
@@ -32,14 +29,17 @@ import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.web.multipart.MultipartFile;
import redis.clients.jedis.Jedis;

import javax.annotation.Resource;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.InputStream;
import java.io.*;
import java.lang.reflect.Field;
import java.net.URLEncoder;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import java.util.zip.ZipEntry;
import java.util.zip.ZipOutputStream;
@@ -52,6 +52,9 @@ import java.util.zip.ZipOutputStream;
*/
@Service("modelsService")
public class ModelsServiceImpl implements ModelsService {

private static final Logger logger = LoggerFactory.getLogger(ModelsServiceImpl.class);

@Resource
private ModelsDao modelsDao;
@Resource
@@ -63,6 +66,9 @@ public class ModelsServiceImpl implements ModelsService {
@Resource
private MinioService minioService;

@Resource
private GitService gitService;

@Resource
private AssetIconService assetIconService;

@@ -71,8 +77,20 @@ public class ModelsServiceImpl implements ModelsService {
@Value("${minio.dataReleaseBucketName}")
private String bucketName;
@Resource
private MinioUtil minioUtil;

private MinioUtil minioUtil;

@Value("${spring.redis.host}")
private String redisHost;
@Value("${git.endpoint}")
String gitendpoint;
@Value("${git.localPath}")
String localPath;
@Value("${minio.accessKey}")
String accessKeyId;
@Value("${minio.secretKey}")
String secretAccessKey;
@Value("${minio.endpoint}")
String endpoint;

/**
* 通过ID查询单条数据
@@ -86,15 +104,15 @@ public class ModelsServiceImpl implements ModelsService {
String modelType = models.getModelType();
String modelTag = models.getModelTag();
//去资产管理表中查询对应的图标名,注意判空逻辑,只有当dataType和dataTag不为空时,才进行查询
if(modelType != null && !modelType.isEmpty()){
if (modelType != null && !modelType.isEmpty()) {
AssetIcon modelTypeAssetIcon = assetIconService.queryById(Integer.valueOf(modelType));
if (modelTypeAssetIcon != null){
if (modelTypeAssetIcon != null) {
models.setModelTypeName(modelTypeAssetIcon.getName());
}
}
if(modelTag != null && !modelTag.isEmpty()){
if (modelTag != null && !modelTag.isEmpty()) {
AssetIcon modelTagAssetIcon = assetIconService.queryById(Integer.valueOf(modelTag));
if (modelTagAssetIcon != null){
if (modelTagAssetIcon != null) {
models.setModelTagName(modelTagAssetIcon.getName());
}
}
@@ -104,8 +122,8 @@ public class ModelsServiceImpl implements ModelsService {
/**
* 分页查询
*
* @param models 筛选条件
* @param pageRequest 分页对象
* @param models 筛选条件
* @param pageRequest 分页对象
* @return 查询结果
*/
@Override
@@ -142,7 +160,7 @@ public class ModelsServiceImpl implements ModelsService {
@Override
public Models update(Models models) {
int currentState = models.getState();
if(currentState == 0){
if (currentState == 0) {
throw new RuntimeException("模型已被删除,无法更新。");
}
LoginUser loginUser = SecurityUtils.getLoginUser();
@@ -166,7 +184,7 @@ public class ModelsServiceImpl implements ModelsService {
@Override
public String removeById(Integer id) throws Exception {
Models models = this.modelsDao.queryById(id);
if (models == null){
if (models == null) {
throw new Exception("模型不存在");
}

@@ -176,15 +194,15 @@ public class ModelsServiceImpl implements ModelsService {


String createdBy = models.getCreateBy();
if (!(StringUtils.equals(username,"admin") || StringUtils.equals(username,createdBy))){
if (!(StringUtils.equals(username, "admin") || StringUtils.equals(username, createdBy))) {
throw new Exception("无权限删除该模型");
}
//判断是否有版本文件
if (!modelsVersionService.queryByModelsId(id).isEmpty()){
if (!modelsVersionService.queryByModelsId(id).isEmpty()) {
throw new Exception("请先删除该模型下的版本文件");
}
models.setState(0);
return this.modelsDao.update(models)>0?"删除成功":"删除失败";
return this.modelsDao.update(models) > 0 ? "删除成功" : "删除失败";
}

/**
@@ -192,7 +210,6 @@ public class ModelsServiceImpl implements ModelsService {
*
* @param id models_version表的主键
* @return 文件内容
*
*/

@Override
@@ -204,19 +221,19 @@ public class ModelsServiceImpl implements ModelsService {
}
// 从数据库中获取存储路径(即MinIO中的对象名称)
String objectName = modelsVersion.getUrl();
if(objectName == null || objectName.isEmpty() ){
if (objectName == null || objectName.isEmpty()) {
throw new Exception("未找到该版本模型文件");
}

try {
// 使用ByteArrayOutputStream来捕获下载的数据
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
minioUtil.downloadObject(bucketName,objectName,outputStream);
minioUtil.downloadObject(bucketName, objectName, outputStream);
ByteArrayInputStream inputStream = new ByteArrayInputStream(outputStream.toByteArray());
InputStreamResource resource = new InputStreamResource(inputStream);

return ResponseEntity.ok()
.header(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"" + URLEncoder.encode(extractFileName(objectName),"UTF-8") + "\"")
.header(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"" + URLEncoder.encode(extractFileName(objectName), "UTF-8") + "\"")
.contentType(MediaType.APPLICATION_OCTET_STREAM)
.body(resource);

@@ -231,13 +248,13 @@ public class ModelsServiceImpl implements ModelsService {
* 上传模型文件
*
* @param files 文件
* @param uuid 唯一标识
* @param uuid 唯一标识
* @return 是否成功
*/
@Override
public List<Map<String, String>> uploadModels(MultipartFile[] files, String uuid) throws Exception {
List<Map<String, String>> results = new ArrayList<>();
for (MultipartFile file:files) {
for (MultipartFile file : files) {
// 构建objectName
String username = SecurityUtils.getLoginUser().getUsername();
String fileName = file.getOriginalFilename();
@@ -265,7 +282,7 @@ public class ModelsServiceImpl implements ModelsService {
//插表,因为这里是一次直接插表所以这里定掉date,然后用DAO插入
Date date = new Date();
//String timestamp = new SimpleDateFormat("yyyyMMdd-HHmmss").format(date);
url = "models/" + username + "/" + models.getName() + "-" + "/" + modelsVersion.getVersion() + "/" + modelsVersion.getFileName();
url = "models/" + username + "/" + models.getName() + "-" + "/" + modelsVersion.getVersion() + "/" + modelsVersion.getFileName();
modelsVersion.setUrl(url);
modelsVersion.setCreateBy(username);
modelsVersion.setUpdateBy(username);
@@ -273,7 +290,7 @@ public class ModelsServiceImpl implements ModelsService {
modelsVersion.setUpdateTime(date);
modelsVersion.setState(1);
modelsVersionDao.insert(modelsVersion);
}else {
} else {
//改表
BeansUtils.copyPropertiesIgnoreNull(modelsVersion, version);
Date createTime = version.getCreateTime();
@@ -283,21 +300,21 @@ public class ModelsServiceImpl implements ModelsService {
modelsVersionService.update(version);
}
Map<String, String> result = new HashMap<String, String>();
result.put("url",url);
result.put("url", url);
return result;
}

/**
* 下载所有模型文件,用压缩包的方式返回
*
* @param modelsId 模型ID
* @param version 模型版本号
* @param modelsId 模型ID
* @param version 模型版本号
* @return 文件内容
*/
@Override
public ResponseEntity<InputStreamResource> downloadAllModelFiles(Integer modelsId, String version) throws Exception {
// 根据模型id查模型名
Models model = this.modelsDao.queryById(modelsId);
Models model = this.modelsDao.queryById(modelsId);
String modelName = model.getName();
// 查询特定模型和版本对应的所有文件
List<ModelsVersion> modelsVersionList = this.modelsVersionDao.queryAllByModelsVersion(modelsId, version);
@@ -328,7 +345,7 @@ public class ModelsServiceImpl implements ModelsService {

// 设置响应
return ResponseEntity.ok()
.header(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"" + modelName + "_" + version + ".zip\"")
.header(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"" + modelName + "_" + version + ".zip\"")
.contentType(MediaType.APPLICATION_OCTET_STREAM)
.body(resource);
} catch (Exception e) {
@@ -361,7 +378,7 @@ public class ModelsServiceImpl implements ModelsService {
public String insertModelAndVersion(ModelsVo modelsVo) throws Exception {
List<VersionVo> modelsVersionVos = modelsVo.getModelsVersionVos();

if (modelsVersionVos==null || modelsVersionVos.isEmpty()){
if (modelsVersionVos == null || modelsVersionVos.isEmpty()) {
throw new Exception("模型版本信息错误");
}
Models models = new Models();
@@ -371,11 +388,11 @@ public class ModelsServiceImpl implements ModelsService {
models.setModelType(modelsVo.getModelType());
models.setModelTag(modelsVo.getModelTag());
Models modelsInsert = this.insert(models);
if (modelsInsert == null){
if (modelsInsert == null) {
throw new Exception("新增模型失败");
}
//遍历版本信息列表,把文件信息插入数据库
for(VersionVo modelsVersionVo : modelsVersionVos){
for (VersionVo modelsVersionVo : modelsVersionVos) {
ModelsVersion modelsVersion = new ModelsVersion();
modelsVersion.setModelsId(modelsInsert.getId());
modelsVersion.setVersion(modelsVo.getVersion());
@@ -394,16 +411,16 @@ public class ModelsServiceImpl implements ModelsService {


/**
* 根据模型id和版本读取文件内容
* 根据模型id和版本读取文件内容
*
* @param modelsId 模型ID
* @param version 模型版本号
* @param modelsId 模型ID
* @param version 模型版本号
* @return 文件内容
*/
@Override
public String readFileContent(Integer modelsId, String version) throws Exception {
// 根据模型id查模型名
Models model = this.modelsDao.queryById(modelsId);
Models model = this.modelsDao.queryById(modelsId);
if (model == null) {
throw new Exception("模型不存在");
}
@@ -415,9 +432,9 @@ public class ModelsServiceImpl implements ModelsService {
throw new Exception("对应模型版本不存在");
}
//遍历文件列表
for(ModelsVersion modelsVersion : modelsVersionList){
for (ModelsVersion modelsVersion : modelsVersionList) {
String fileName = modelsVersion.getFileName();
if("readme.md".equalsIgnoreCase(fileName)){
if ("readme.md".equalsIgnoreCase(fileName)) {
//如果存在readme文件,读取minio中的url
objectName = modelsVersion.getUrl();
// 读取MinIO中的对象为字符串
@@ -473,7 +490,7 @@ public class ModelsServiceImpl implements ModelsService {
minioUtil.copyDirectory(srcBucketName, srcDir, bucketName, targetDir);
List<Item> movedItems = minioUtil.getAllObjectsByPrefix(bucketName, targetDir, true);
for (Item movedItem : movedItems) {
if(!movedItem.isDir() && movedItem.size() > 0){ // 检查是否为非目录且文件大小大于0
if (!movedItem.isDir() && movedItem.size() > 0) { // 检查是否为非目录且文件大小大于0
Map<String, String> result = new HashMap<>();
String url = movedItem.objectName();
String fileName = extractFileName(url);
@@ -493,5 +510,254 @@ public class ModelsServiceImpl implements ModelsService {
return urlStr.substring(urlStr.lastIndexOf('/') + 1);
}

@Override
public CompletableFuture<String> newCreateModel(ModelsVo modelsVo) throws Exception {
return CompletableFuture.supplyAsync(() -> {
try {
String token = gitService.login(modelsVo.getGitLinkUsername(), modelsVo.getGitLinkPassword());
LoginUser loginUser = SecurityUtils.getLoginUser();
String ci4sUsername = loginUser.getUsername();
Jedis jedis = new Jedis(redisHost);
String userReq = jedis.get(ci4sUsername + "_gitUserInfo");
Map<String, Object> userInfo = JsonUtils.jsonToMap(userReq);
Integer userId = (Integer) userInfo.get("user_id");

// 拼接project
String repositoryName = ci4sUsername + "_model_" + DateUtils.dateTimeNow();
GitProjectVo gitProjectVo = new GitProjectVo();
gitProjectVo.setRepositoryName(repositoryName);
gitProjectVo.setName(modelsVo.getName());
gitProjectVo.setDescription(modelsVo.getDescription());
gitProjectVo.setPrivate(modelsVo.getAvailableRange() == 0);
gitProjectVo.setUserId(userId);

// 创建项目
Map project = gitService.createProject(token, gitProjectVo);

// 创建分支
String branchName = modelsVo.getVersion();
gitService.createBranch(token, (String) userInfo.get("login"), repositoryName, branchName, "master");
// 定义标签 标签1:ci4s_model 标签2:ModelTag 标签3:ModelType
gitService.createTopic(token, (Integer) project.get("id"), "ci4s_model");
gitService.createTopic(token, (Integer) project.get("id"), "ModelTag_" + modelsVo.getModelTag());
gitService.createTopic(token, (Integer) project.get("id"), "ModelType_" + modelsVo.getModelType());
// 得到项目地址
String projectUrl = gitendpoint + "/" + (String) userInfo.get("login") + "/" + repositoryName + ".git";
// 得到用户操作的路径
String url = modelsVo.getModelsVersionVos().get(0).getUrl();
String localPath1 = localPath + modelsVo.getName();
String sourcePath = url.substring(0, url.lastIndexOf("/"));
// 命令行操作 git clone 项目地址
DVCUtils.gitClone(localPath1, projectUrl, branchName, modelsVo.getGitLinkUsername(), modelsVo.getGitLinkPassword());
String s3Path = "management-platform-files/" + ci4sUsername + "/models/" + repositoryName + "/" + branchName;

// 拼接生产的元数据后写入yaml文件
YamlUtils.generateYamlFile(JsonUtils.objectToMap(modelsVo), sourcePath, "model");

DVCUtils.moveFiles(sourcePath, localPath1 + "/model");

// dvc init 初始化
DVCUtils.dvcInit(localPath1);
// 配置远程S3地址
DVCUtils.dvcRemoteAdd(localPath1, s3Path);
DVCUtils.dvcConfigS3Credentials(localPath1, endpoint);
DVCUtils.dvcConfigS3Credentials2(localPath1, accessKeyId);
DVCUtils.dvcConfigS3Credentials3(localPath1, secretAccessKey);
// dvc 跟踪
DVCUtils.dvcAdd(localPath1, "model");
// git commit
DVCUtils.gitAdd(localPath1, ".");
DVCUtils.gitCommit(localPath1, "commit from ci4s with " + ci4sUsername);
DVCUtils.gitPush(localPath1, modelsVo.getGitLinkUsername(), modelsVo.getGitLinkPassword());
// dvc push 到远程S3
DVCUtils.dvcPush(localPath1);
return "新增模型成功";
} catch (Exception e) {
logger.error(e.getMessage());
throw new RuntimeException(e);
}
});
}

@Override
public CompletableFuture<String> newCreateVersion(ModelsVo modelsVo) {
return CompletableFuture.supplyAsync(() -> {
try {
String token = gitService.login(modelsVo.getGitLinkUsername(), modelsVo.getGitLinkPassword());
LoginUser loginUser = SecurityUtils.getLoginUser();
String ci4sUsername = loginUser.getUsername();
Jedis jedis = new Jedis(redisHost);
String userReq = jedis.get(ci4sUsername + "_gitUserInfo");
Map<String, Object> userInfo = JsonUtils.jsonToMap(userReq);
// 创建分支
String branchName = modelsVo.getVersion();
String repositoryName = modelsVo.getRepositoryName();
gitService.createBranch(token, (String) userInfo.get("login"), repositoryName, branchName, "master");
// 得到项目地址
String projectUrl = gitendpoint + "/" + (String) userInfo.get("login") + "/" + repositoryName + ".git";
// 得到用户操作的路径
String url = modelsVo.getModelVersionVos().get(0).getUrl();
String localPath1 = localPath + ci4sUsername + "/model/" + modelsVo.getName();
String sourcePath = url.substring(0, url.lastIndexOf("/"));
// 命令行操作 git clone 项目地址
DVCUtils.gitClone(localPath1, projectUrl, branchName, modelsVo.getGitLinkUsername(), modelsVo.getGitLinkPassword());
String s3Path = "management-platform-files/" + ci4sUsername + "/model/" + repositoryName + "/" + branchName;
//拼接生产的元数据后写入yaml文件
YamlUtils.generateYamlFile(JsonUtils.objectToMap(modelsVo), sourcePath, "dataset");

DVCUtils.moveFiles(sourcePath, localPath1 + "/model");
// dvc init 初始化
DVCUtils.dvcInit(localPath1);
// 配置远程S3地址
DVCUtils.dvcRemoteAdd(localPath1, s3Path);
DVCUtils.dvcConfigS3Credentials(localPath1, endpoint);
DVCUtils.dvcConfigS3Credentials2(localPath1, accessKeyId);
DVCUtils.dvcConfigS3Credentials3(localPath1, secretAccessKey);
// dvc 跟踪
DVCUtils.dvcAdd(localPath1, "model");
// git commit
DVCUtils.gitAdd(localPath1, ".");
DVCUtils.gitCommit(localPath1, "commit from ci4s with " + ci4sUsername);
DVCUtils.gitPush(localPath1, modelsVo.getGitLinkUsername(), modelsVo.getGitLinkPassword());
// dvc push 到远程S3
DVCUtils.dvcPush(localPath1);
return "新增模型成功";
} catch (Exception e) {
throw new RuntimeException(e);
}
});
}

@Override
public ResponseEntity<InputStreamResource> downloadAllModelFilesNew(String repositoryName, String version, String gitLinkUsername, String gitLinkPassword) throws IOException, Exception {
// 命令行操作 git clone 项目地址
LoginUser loginUser = SecurityUtils.getLoginUser();
String ci4sUsername = loginUser.getUsername();
String token = gitService.login(gitLinkUsername, gitLinkPassword);
Jedis jedis = new Jedis(redisHost);
String userReq = jedis.get(ci4sUsername + "_gitUserInfo");
Map<String, Object> userInfo = JsonUtils.jsonToMap(userReq);
Integer userId = (Integer) userInfo.get("user_id");
String projectUrl = gitendpoint + "/" + (String) userInfo.get("login") + "/" + repositoryName + ".git";
String localPath1 = localPath + ci4sUsername + "/model/" + repositoryName;
File folder = new File(localPath1);
if (folder.exists() && folder.isDirectory()) {
//切换分支
DVCUtils.gitCheckoutBranch(localPath1, version);
//pull
DVCUtils.gitPull(localPath1, gitLinkUsername, gitLinkPassword);
//dvc pull
DVCUtils.dvcPull(localPath1);
} else {
DVCUtils.gitClone(localPath1, projectUrl, version, gitLinkUsername, gitLinkPassword);
}
// 打包 data 文件夹
String dataFolderPath = localPath1 + "/model";
String zipFilePath = localPath1 + "/model.zip";
try (FileOutputStream fos = new FileOutputStream(zipFilePath);
ZipOutputStream zos = new ZipOutputStream(fos)) {
Path sourcePath = Paths.get(dataFolderPath);
Files.walk(sourcePath).forEach(path -> {
if (!Files.isDirectory(path)) {
ZipEntry zipEntry = new ZipEntry(sourcePath.relativize(path).toString());
try {
zos.putNextEntry(zipEntry);
Files.copy(path, zos);
zos.closeEntry();
} catch (IOException e) {
throw new RuntimeException("Error while zipping: " + path, e);
}
}
});
}

// 返回压缩文件的输入流
File zipFile = new File(zipFilePath);
InputStreamResource resource = new InputStreamResource(new FileInputStream(zipFile));

return ResponseEntity.ok()
.header(HttpHeaders.CONTENT_DISPOSITION, "attachment;filename=data.zip")
.contentType(MediaType.APPLICATION_OCTET_STREAM)
.contentLength(zipFile.length())
.body(resource);
}

@Override
public Page<ModelsVo> newPubilcQueryByPage(ModelsVo modelsVo, PageRequest pageRequest) throws Exception {
LoginUser loginUser = SecurityUtils.getLoginUser();
String ci4sUsername = loginUser.getUsername();
String token = gitService.login(modelsVo.getGitLinkUsername(), modelsVo.getGitLinkPassword());
Jedis jedis = new Jedis(redisHost);
String userReq = jedis.get(ci4sUsername + "_gitUserInfo");
Map<String, Object> userInfo = JsonUtils.jsonToMap(userReq);
//拼接查询url
String modelTagName = modelsVo.getModelTag();
String modelTypeName = modelsVo.getModelType();
String topic_name = "ci4s_model";
topic_name = StringUtils.isEmpty(modelTagName) ? topic_name : topic_name + ",modeltag_" + modelTypeName;
topic_name = StringUtils.isEmpty(modelTagName) ? topic_name : topic_name + ",modeltype_" + modelTypeName;
String url = gitendpoint + "/api/users/" + (String) userInfo.get("login") + "/projects.json?page=" + pageRequest.getPageNumber() + "&limit=" + pageRequest.getPageSize() + "&category=manage&topic_name=" + topic_name;
String req = HttpUtils.sendGetWithToken(url, null, token);
Map<String, Object> stringObjectMap = JacksonUtil.parseJSONStr2Map(req);
Integer total = (Integer) stringObjectMap.get("count");
List<Map<String, Object>> projects = (List<Map<String, Object>>) stringObjectMap.get("projects");
return new PageImpl<>(convert(projects), pageRequest, total);
}

@Override
public Page<ModelsVo> newPersonalQueryByPage(ModelsVo modelsVo, PageRequest pageRequest) throws Exception {
LoginUser loginUser = SecurityUtils.getLoginUser();
String ci4sUsername = loginUser.getUsername();
String token = gitService.login(modelsVo.getGitLinkUsername(), modelsVo.getGitLinkPassword());
Jedis jedis = new Jedis(redisHost);
String userReq = jedis.get(ci4sUsername + "_gitUserInfo");
Map<String, Object> userInfo = JsonUtils.jsonToMap(userReq);
Integer userId = (Integer) userInfo.get("user_id");
//拼接查询url
String modelTagName = modelsVo.getModelTag();
String modelTypeName = modelsVo.getModelType();
String topic_name = "ci4s_model";
topic_name = StringUtils.isEmpty(modelTagName) ? topic_name : topic_name + ",modeltag_" + modelTagName;
topic_name = StringUtils.isEmpty(modelTypeName) ? topic_name : topic_name + ",modeltype_" + modelTypeName;

String url = gitendpoint + "/api/projects.json?user_id=" + userId + "&page=" + pageRequest.getPageNumber() + "&limit=" + pageRequest.getPageSize() + "&sort_by=praises_count&topic_name=" + topic_name;
String req = HttpUtils.sendGetWithToken(url, null, token);
Map<String, Object> stringObjectMap = JacksonUtil.parseJSONStr2Map(req);
Integer total = (Integer) stringObjectMap.get("total_count");
List<Map<String, Object>> projects = (List<Map<String, Object>>) stringObjectMap.get("projects");
return new PageImpl<>(convert(projects), pageRequest, total);
}

public List<ModelsVo> convert(List<Map<String, Object>> lst) {
if (lst != null && lst.size() > 0) {
List<ModelsVo> newModelVos = ConvertUtil.convertListMapToObjectList(lst, ModelsVo.class);

for (ModelsVo newModelVo : newModelVos) {
Map<String, Object> map = lst.stream()
.filter(m -> m.get("repo_id").equals(newModelVo.getRepoId()))
.findFirst()
.orElse(null);

if (map != null) {
List<Map<String, Object>> topics = (List<Map<String, Object>>) map.get("topics");
if (topics != null) {
topics.forEach(topic -> {
String name = (String) topic.get("name");
if (name != null) {
if (name.startsWith("modeltag_")) {
newModelVo.setModelTag(name.substring("modeltag_".length()));
} else if (name.startsWith("modeltype_")) {
newModelVo.setModelType(name.substring("modeltype_".length()));
}
}
});
}
}
}

return newModelVos;
}
return new ArrayList<>();
}

}

+ 16
- 68
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/ModelsVo.java View File

@@ -2,13 +2,14 @@ package com.ruoyi.platform.vo;

import com.fasterxml.jackson.databind.PropertyNamingStrategy;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import com.ruoyi.platform.domain.ModelsVersion;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;

import java.io.Serializable;
import java.util.List;

@JsonNaming(PropertyNamingStrategy.SnakeCaseStrategy.class)
@Data
public class ModelsVo implements Serializable {


@@ -25,7 +26,6 @@ public class ModelsVo implements Serializable {
@ApiModelProperty(name = "available_range")
private int availableRange;


// private String url;
@ApiModelProperty(name = "model_type")
private String modelType;
@@ -39,80 +39,28 @@ public class ModelsVo implements Serializable {
@ApiModelProperty(name = "version")
private String version;

@ApiModelProperty(name = "model_version_vos")
private List<VersionVo> modelVersionVos;

/**
* 状态
*/
@ApiModelProperty(name = "status")
private Integer status;

@ApiModelProperty(name = "models_version_vos")
private List<VersionVo> modelsVersionVos;


public String getName() {
return name;
}

public void setName(String name) {
this.name = name;
}

public String getDescription() {
return description;
}

public void setDescription(String description) {
this.description = description;
}

public int getAvailableRange() {
return availableRange;
}

public void setAvailableRange(int availableRange) {
this.availableRange = availableRange;
}

private String gitLinkUsername;

public String getModelType() {
return modelType;
}
private String gitLinkPassword;

public void setModelType(String modelType) {
this.modelType = modelType;
}

public String getModelTag() {
return modelTag;
}

public void setModelTag(String modelTag) {
this.modelTag = modelTag;
}

public String getVersion() {
return version;
}

public void setVersion(String version) {
this.version = version;
}

public Integer getStatus() {
return status;
}

public void setStatus(Integer status) {
this.status = status;
}

public List<VersionVo> getModelsVersionVos() {
return modelsVersionVos;
}

public void setModelsVersionVos(List<VersionVo> modelsVersionVos) {
this.modelsVersionVos = modelsVersionVos;
}
@ApiModelProperty(name = "models_version_vos")
private List<VersionVo> modelsVersionVos;

/**
* 数据集仓库名称
*/
@ApiModelProperty(name = "repository_name")
private String repositoryName;

@ApiModelProperty(name = "repo_id")
private Integer repoId;
}

+ 0
- 2
ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/NewDatasetVo.java View File

@@ -4,10 +4,8 @@ import com.fasterxml.jackson.databind.PropertyNamingStrategy;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
import org.apache.ibatis.annotations.MapKey;

import java.io.Serializable;
import java.util.List;

@JsonNaming(PropertyNamingStrategy.SnakeCaseStrategy.class)
@Data


Loading…
Cancel
Save