| @@ -19,6 +19,7 @@ package org.dubhe.admin; | |||
| import org.mybatis.spring.annotation.MapperScan; | |||
| import org.springframework.boot.SpringApplication; | |||
| import org.springframework.boot.autoconfigure.SpringBootApplication; | |||
| import org.springframework.scheduling.annotation.EnableScheduling; | |||
| /** | |||
| @@ -27,6 +28,7 @@ import org.springframework.boot.autoconfigure.SpringBootApplication; | |||
| */ | |||
| @SpringBootApplication(scanBasePackages = "org.dubhe") | |||
| @MapperScan(basePackages = {"org.dubhe.**.dao"}) | |||
| @EnableScheduling | |||
| public class AdminApplication { | |||
| public static void main(String[] args) { | |||
| SpringApplication.run(AdminApplication.class, args); | |||
| @@ -42,10 +42,10 @@ public class ResourceSpecsCreateDTO implements Serializable { | |||
| @Pattern(regexp = StringConstant.REGEXP_SPECS, message = "规格名称支持字母、数字、汉字、英文横杠、下划线和空白字符") | |||
| private String specsName; | |||
| @ApiModelProperty(value = "所属业务场景(0:通用,1:dubhe-notebook,2:dubhe-train,3:dubhe-serving)", required = true) | |||
| @ApiModelProperty(value = "所属业务场景(0:通用,1:dubhe-notebook,2:dubhe-train,3:dubhe-serving, 4:dubhe-tadl)", required = true) | |||
| @NotNull(message = "所属业务场景不能为空") | |||
| @Min(value = MagicNumConstant.ZERO, message = "所属业务场景错误") | |||
| @Max(value = MagicNumConstant.THREE, message = "所属业务场景错误") | |||
| @Max(value = MagicNumConstant.FOUR, message = "所属业务场景错误") | |||
| private Integer module; | |||
| @ApiModelProperty(value = "CPU数量,单位:核", required = true) | |||
| @@ -37,6 +37,9 @@ public class ResourceSpecsQueryDTO extends PageQueryBase implements Serializable | |||
| private static final long serialVersionUID = 1L; | |||
| @ApiModelProperty(value = "多GPU,true:GPU数大于1核,false:GPU数等于1核") | |||
| private Boolean multiGpu; | |||
| @ApiModelProperty("规格名称") | |||
| @Length(max = MagicNumConstant.THIRTY_TWO, message = "规格名称错误") | |||
| private String specsName; | |||
| @@ -44,8 +47,8 @@ public class ResourceSpecsQueryDTO extends PageQueryBase implements Serializable | |||
| @ApiModelProperty("规格类型(0为CPU, 1为GPU)") | |||
| private Boolean resourcesPoolType; | |||
| @ApiModelProperty("所属业务场景(0:通用,1:dubhe-notebook,2:dubhe-train,3:dubhe-serving)") | |||
| @ApiModelProperty("所属业务场景(0:通用,1:dubhe-notebook,2:dubhe-train,3:dubhe-serving,4:dubhe-tadl)") | |||
| @Min(value = MagicNumConstant.ZERO, message = "所属业务场景错误") | |||
| @Max(value = MagicNumConstant.THREE, message = "所属业务场景错误") | |||
| @Max(value = MagicNumConstant.FOUR, message = "所属业务场景错误") | |||
| private Integer module; | |||
| } | |||
| @@ -49,10 +49,10 @@ public class ResourceSpecsUpdateDTO implements Serializable { | |||
| @Pattern(regexp = StringConstant.REGEXP_SPECS, message = "规格名称支持字母、数字、汉字、英文横杠、下划线和空白字符") | |||
| private String specsName; | |||
| @ApiModelProperty(value = "所属业务场景(0:通用,1:dubhe-notebook,2:dubhe-train,3:dubhe-serving)", required = true) | |||
| @ApiModelProperty(value = "所属业务场景(0:通用,1:dubhe-notebook,2:dubhe-train,3:dubhe-serving, 4:dubhe-tadl)", required = true) | |||
| @NotNull(message = "所属业务场景不能为空") | |||
| @Min(value = MagicNumConstant.ZERO, message = "所属业务场景错误") | |||
| @Max(value = MagicNumConstant.THREE, message = "所属业务场景错误") | |||
| @Max(value = MagicNumConstant.FOUR, message = "所属业务场景错误") | |||
| private Integer module; | |||
| @ApiModelProperty(value = "CPU数量,单位:核") | |||
| @@ -29,6 +29,7 @@ import org.dubhe.biz.base.vo.DataResponseBody; | |||
| import org.dubhe.biz.base.vo.QueryResourceSpecsVO; | |||
| import org.springframework.beans.factory.annotation.Autowired; | |||
| import org.springframework.security.access.prepost.PreAuthorize; | |||
| import org.springframework.validation.annotation.Validated; | |||
| import org.springframework.web.bind.annotation.*; | |||
| import javax.validation.Valid; | |||
| @@ -47,16 +48,23 @@ public class ResourceSpecsController { | |||
| @ApiOperation("查询资源规格") | |||
| @GetMapping | |||
| public DataResponseBody getResourceSpecs(ResourceSpecsQueryDTO resourceSpecsQueryDTO) { | |||
| public DataResponseBody getResourceSpecs(@Validated ResourceSpecsQueryDTO resourceSpecsQueryDTO) { | |||
| return new DataResponseBody(resourceSpecsService.getResourceSpecs(resourceSpecsQueryDTO)); | |||
| } | |||
| @ApiOperation("查询资源规格(远程调用)") | |||
| @ApiOperation("查询资源规格(训练远程调用)") | |||
| @GetMapping("/queryResourceSpecs") | |||
| public DataResponseBody<QueryResourceSpecsVO> queryResourceSpecs(QueryResourceSpecsDTO queryResourceSpecsDTO) { | |||
| public DataResponseBody<QueryResourceSpecsVO> queryResourceSpecs(@Validated QueryResourceSpecsDTO queryResourceSpecsDTO) { | |||
| return new DataResponseBody(resourceSpecsService.queryResourceSpecs(queryResourceSpecsDTO)); | |||
| } | |||
| @ApiOperation("查询资源规格(tadl远程调用)") | |||
| @GetMapping("/queryTadlResourceSpecs") | |||
| public DataResponseBody<QueryResourceSpecsVO> queryTadlResourceSpecs(Long id) { | |||
| return new DataResponseBody(resourceSpecsService.queryTadlResourceSpecs(id)); | |||
| } | |||
| @ApiOperation("新增资源规格") | |||
| @PostMapping | |||
| @PreAuthorize(Permissions.SPECS_CREATE) | |||
| @@ -62,4 +62,11 @@ public interface ResourceSpecsService { | |||
| * @return QueryResourceSpecsVO 资源规格返回结果实体类 | |||
| */ | |||
| QueryResourceSpecsVO queryResourceSpecs(QueryResourceSpecsDTO queryResourceSpecsDTO); | |||
| /** | |||
| * 查询资源规格 | |||
| * @param id 资源规格id | |||
| * @return QueryResourceSpecsVO 资源规格返回结果实体类 | |||
| */ | |||
| QueryResourceSpecsVO queryTadlResourceSpecs(Long id); | |||
| } | |||
| @@ -329,7 +329,7 @@ public class RecycleTaskServiceImpl implements RecycleTaskService { | |||
| } | |||
| String emptyDir = recycleFileTmpPath + randomPath + File.separator; | |||
| LogUtil.debug(LogEnum.GARBAGE_RECYCLE, "recycle task sourcePath:{},emptyDir:{}", sourcePath, emptyDir); | |||
| Process process = Runtime.getRuntime().exec(new String[]{"/bin/sh", "-c", String.format(ShellFileStoreApiImpl.DEL_COMMAND, userName, ip, emptyDir, emptyDir, sourcePath, emptyDir, sourcePath)}); | |||
| Process process = Runtime.getRuntime().exec(new String[]{"/bin/sh", "-c", String.format(ShellFileStoreApiImpl.DEL_COMMAND, emptyDir, emptyDir, sourcePath, emptyDir, sourcePath)}); | |||
| return processRecycle(process); | |||
| } else { | |||
| LogUtil.error(LogEnum.GARBAGE_RECYCLE, "file recycle is failed! sourcePath:{}", sourcePath); | |||
| @@ -460,7 +460,7 @@ public class RecycleTaskServiceImpl implements RecycleTaskService { | |||
| String delRealPath = fileStoreApi.formatPath(sourcePath + File.separator + fileName + File.separator + directoryName); | |||
| delRealPath = delRealPath.endsWith(File.separator) ? delRealPath : delRealPath + File.separator; | |||
| String emptyDir = invalidFileTmpPath + directoryName + File.separator; | |||
| Process process = Runtime.getRuntime().exec(new String[]{"/bin/sh", "-c", String.format(ShellFileStoreApiImpl.DEL_COMMAND, userName, ip, emptyDir, emptyDir, delRealPath, emptyDir, delRealPath)}); | |||
| Process process = Runtime.getRuntime().exec(new String[]{"/bin/sh", "-c", String.format(ShellFileStoreApiImpl.DEL_COMMAND, emptyDir, emptyDir, delRealPath, emptyDir, delRealPath)}); | |||
| Integer deleteStatus = process.waitFor(); | |||
| LogUtil.info(LogEnum.GARBAGE_RECYCLE, "recycle resources path:{},recycle status:{}", delRealPath, deleteStatus); | |||
| } catch (Exception e) { | |||
| @@ -27,6 +27,7 @@ import org.dubhe.admin.domain.dto.ResourceSpecsUpdateDTO; | |||
| import org.dubhe.admin.domain.entity.ResourceSpecs; | |||
| import org.dubhe.admin.domain.vo.ResourceSpecsQueryVO; | |||
| import org.dubhe.admin.service.ResourceSpecsService; | |||
| import org.dubhe.biz.base.constant.MagicNumConstant; | |||
| import org.dubhe.biz.base.constant.StringConstant; | |||
| import org.dubhe.biz.base.context.UserContext; | |||
| import org.dubhe.biz.base.dto.QueryResourceSpecsDTO; | |||
| @@ -72,6 +73,13 @@ public class ResourceSpecsServiceImpl implements ResourceSpecsService { | |||
| queryResourceSpecsWrapper.like(resourceSpecsQueryDTO.getSpecsName() != null, "specs_name", resourceSpecsQueryDTO.getSpecsName()) | |||
| .eq(resourceSpecsQueryDTO.getResourcesPoolType() != null, "resources_pool_type", resourceSpecsQueryDTO.getResourcesPoolType()) | |||
| .eq(resourceSpecsQueryDTO.getModule() != null, "module", resourceSpecsQueryDTO.getModule()); | |||
| if (resourceSpecsQueryDTO.getMultiGpu() != null) { | |||
| if (resourceSpecsQueryDTO.getMultiGpu()) { | |||
| queryResourceSpecsWrapper.gt("gpu_num", MagicNumConstant.ONE); | |||
| } else { | |||
| queryResourceSpecsWrapper.eq("gpu_num", MagicNumConstant.ONE); | |||
| } | |||
| } | |||
| if (StringConstant.SORT_ASC.equals(resourceSpecsQueryDTO.getOrder())) { | |||
| queryResourceSpecsWrapper.orderByAsc(StringUtils.humpToLine(sort)); | |||
| } else { | |||
| @@ -206,4 +214,23 @@ public class ResourceSpecsServiceImpl implements ResourceSpecsService { | |||
| BeanUtils.copyProperties(resourceSpecs, queryResourceSpecsVO); | |||
| return queryResourceSpecsVO; | |||
| } | |||
| /** | |||
| * 查询资源规格 | |||
| * @param id 资源规格id | |||
| * @return QueryResourceSpecsVO 资源规格返回结果实体类 | |||
| */ | |||
| @Override | |||
| public QueryResourceSpecsVO queryTadlResourceSpecs(Long id) { | |||
| LogUtil.info(LogEnum.BIZ_SYS,"Query resource specification information with resource id:{}",id); | |||
| ResourceSpecs resourceSpecs = resourceSpecsMapper.selectById(id); | |||
| LogUtil.info(LogEnum.BIZ_SYS,"Obtain resource specification information:{} ",resourceSpecs); | |||
| if (resourceSpecs == null) { | |||
| throw new BusinessException("资源规格不存在或已被删除"); | |||
| } | |||
| QueryResourceSpecsVO queryResourceSpecsVO = new QueryResourceSpecsVO(); | |||
| BeanUtils.copyProperties(resourceSpecs, queryResourceSpecsVO); | |||
| LogUtil.info(LogEnum.BIZ_SYS,"Return resource specification information :{} ",queryResourceSpecsVO); | |||
| return queryResourceSpecsVO; | |||
| } | |||
| } | |||
| @@ -81,6 +81,10 @@ public class ApplicationNameConst { | |||
| */ | |||
| public final static String SERVER_DATA_DCM = "dubhe-data-dcm"; | |||
| /** | |||
| * TADL | |||
| */ | |||
| public final static String SERVER_TADL = "dubhe-tadl"; | |||
| /** | |||
| * k8s | |||
| */ | |||
| @@ -32,6 +32,8 @@ public class NumberConstant { | |||
| public final static int NUMBER_6 = 6; | |||
| public final static int NUMBER_8 = 8; | |||
| public final static int NUMBER_10 = 10; | |||
| public final static int NUMBER_12 = 12; | |||
| public final static int NUMBER_24 = 24; | |||
| public final static int NUMBER_30 = 30; | |||
| public final static int NUMBER_32 = 32; | |||
| public final static int NUMBER_50 = 50; | |||
| @@ -53,6 +53,19 @@ public final class StringConstant { | |||
| * 整数匹配 | |||
| */ | |||
| public static final Pattern PATTERN_NUM = Pattern.compile("^[-\\+]?[\\d]*$"); | |||
| /** | |||
| * 数字匹配 | |||
| */ | |||
| public static final String NUMBER ="(\\d+)"; | |||
| /** | |||
| * 整数匹配 | |||
| */ | |||
| public static final Pattern PATTERN_NUMBER = Pattern.compile("(\\d+)"); | |||
| /** | |||
| * 小数匹配 | |||
| */ | |||
| public static final Pattern PATTERN_DECIMAL = Pattern.compile("(\\d+\\.\\d+)"); | |||
| /** | |||
| @@ -27,6 +27,7 @@ public class SymbolConstant { | |||
| public static final String COLON = ":"; | |||
| public static final String LINEBREAK = "\n"; | |||
| public static final String BLANK = ""; | |||
| public static final String SPACE = " "; | |||
| public static final String QUESTION = "?"; | |||
| public static final String ZERO = "0"; | |||
| public static final String DOT = "."; | |||
| @@ -45,10 +45,10 @@ public class QueryResourceSpecsDTO implements Serializable { | |||
| private String specsName; | |||
| /** | |||
| * 所属业务场景(0:通用,1:dubhe-notebook,2:dubhe-train,3:dubhe-serving) | |||
| * 所属业务场景(0:通用,1:dubhe-notebook,2:dubhe-train,3:dubhe-serving,4:dubhe-tadl) | |||
| */ | |||
| @NotNull(message = "所属业务场景不能为空") | |||
| @Min(value = MagicNumConstant.ZERO, message = "所属业务场景错误") | |||
| @Max(value = MagicNumConstant.THREE, message = "所属业务场景错误") | |||
| @Max(value = MagicNumConstant.FOUR, message = "所属业务场景错误") | |||
| private Integer module; | |||
| } | |||
| @@ -61,7 +61,10 @@ public enum BizEnum { | |||
| * 专业版终端 | |||
| */ | |||
| TERMINAL("专业版终端", "terminal", 7), | |||
| ; | |||
| /** | |||
| * TADL | |||
| */ | |||
| TADL("TADL服务", "tadl", 8); | |||
| /** | |||
| * 业务模块名称 | |||
| @@ -54,6 +54,8 @@ public class PtModelUtil { | |||
| public static final int MODEL_OPTIMIZATION = 2; | |||
| public static final int AUTOMATIC_MACHINE_LEARNING = 4; | |||
| public static final int RANDOM_LENGTH = 4; | |||
| } | |||
| @@ -48,7 +48,7 @@ public class QueryResourceSpecsVO implements Serializable { | |||
| private Boolean resourcesPoolType; | |||
| /** | |||
| *所属业务场景 | |||
| *所属业务场景(0:通用,1:dubhe-notebook,2:dubhe-train,3:dubhe-serving,4:dubhe-tadl) | |||
| */ | |||
| private Integer module; | |||
| @@ -0,0 +1,34 @@ | |||
/**
 * Copyright 2020 Zhejiang Lab. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================
 */
package org.dubhe.biz.db.base;

import lombok.Data;
import lombok.experimental.Accessors;

/**
 * @description Base query object for log lookups (line-window pagination)
 * @date 2021-03-11
 */
@Data
@Accessors(chain = true)
public class BaseLogQuery {

    // Line number at which to start reading — assumes 1-based indexing, TODO confirm with callers
    private Integer startLine;

    // Number of log lines to return for this query
    private Integer lines;
}
| @@ -72,6 +72,16 @@ | |||
| <groupId>org.apache.poi</groupId> | |||
| <artifactId>poi-ooxml</artifactId> | |||
| </dependency> | |||
| <dependency> | |||
| <groupId>junit</groupId> | |||
| <artifactId>junit</artifactId> | |||
| <version>4.13.1</version> | |||
| </dependency> | |||
| <dependency> | |||
| <groupId>io.minio</groupId> | |||
| <artifactId>minio</artifactId> | |||
| <version>7.0.2</version> | |||
| </dependency> | |||
| </dependencies> | |||
| <build> | |||
| @@ -62,7 +62,7 @@ public class ShellFileStoreApiImpl implements FileStoreApi { | |||
| * 删除服务器无效文件(大文件) | |||
| * 示例:rsync --delete-before -d /空目录 /需要回收的源目录 | |||
| */ | |||
| public static final String DEL_COMMAND = "ssh %s@%s \"mkdir -p %s; rsync --delete-before -d %s %s; rmdir %s %s\""; | |||
| public static final String DEL_COMMAND = "mkdir -p %s; rsync --delete-before -d %s %s; rmdir %s %s"; | |||
| /** | |||
| * 拷贝文件并重命名 | |||
| @@ -18,6 +18,7 @@ | |||
| package org.dubhe.biz.file.utils; | |||
| import cn.hutool.core.io.IoUtil; | |||
| import com.alibaba.fastjson.JSONObject; | |||
| import io.minio.CopyConditions; | |||
| import io.minio.MinioClient; | |||
| import io.minio.PutObjectOptions; | |||
| @@ -40,6 +41,8 @@ import java.io.IOException; | |||
| import java.io.InputStream; | |||
| import java.nio.charset.Charset; | |||
| import java.util.*; | |||
| import java.util.stream.Collectors; | |||
| import java.util.stream.Stream; | |||
| /** | |||
| * @description Minio工具类 | |||
| @@ -280,4 +283,33 @@ public class MinioUtil { | |||
| } | |||
| } | |||
| /** | |||
| * 生成给HTTP PUT请求用的presigned URLs。浏览器/移动端的客户端可以用这个URL进行上传, | |||
| * 即使其所在的存储桶是私有的。这个presigned URL可以设置一个失效时间,默认值是7天 | |||
| * | |||
| * @param bucketName 存储桶名称 | |||
| * @param objectNames 存储桶里的对象名称 | |||
| * @param expires 失效时间(以秒为单位),默认是7天,不得大于七天 | |||
| * @return String | |||
| */ | |||
| public JSONObject getEncryptedPutUrls(String bucketName,String objectNames, Integer expires) { | |||
| List<String> filePaths = JSONObject.parseObject(objectNames, List.class); | |||
| List<String> urls = new ArrayList<>(); | |||
| filePaths.stream().forEach(filePath->{ | |||
| if (StringUtils.isEmpty(filePath)) { | |||
| throw new BusinessException("filePath cannot be empty"); | |||
| } | |||
| try { | |||
| urls.add(client.presignedPutObject(bucketName, filePath, expires)); | |||
| } catch (Exception e) { | |||
| LogUtil.error(LogEnum.BIZ_DATASET, e.getMessage()); | |||
| throw new BusinessException("MinIO an error occurred, please contact the administrator"); | |||
| } | |||
| }); | |||
| JSONObject jsonObject = new JSONObject(); | |||
| jsonObject.put("preUrls",urls); | |||
| jsonObject.put("bucketName", bucketName); | |||
| return jsonObject; | |||
| } | |||
| } | |||
| @@ -79,7 +79,9 @@ public enum LogEnum { | |||
| //云端Serving | |||
| SERVING, | |||
| //专业版终端 | |||
| TERMINAL; | |||
| TERMINAL, | |||
| //tadl | |||
| TADL; | |||
| /** | |||
| * 判断日志类型不能为空 | |||
| @@ -24,6 +24,11 @@ | |||
| <groupId>com.liferay</groupId> | |||
| <artifactId>com.fasterxml.jackson.databind</artifactId> | |||
| </dependency> | |||
| <dependency> | |||
| <groupId>com.amazonaws</groupId> | |||
| <artifactId>aws-java-sdk</artifactId> | |||
| <version>1.12.35</version> | |||
| </dependency> | |||
| </dependencies> | |||
| <build> | |||
| @@ -24,10 +24,7 @@ import org.dubhe.biz.log.utils.LogUtil; | |||
| import org.springframework.beans.factory.annotation.Value; | |||
| import org.springframework.data.redis.connection.RedisConnection; | |||
| import org.springframework.data.redis.connection.RedisConnectionFactory; | |||
| import org.springframework.data.redis.core.Cursor; | |||
| import org.springframework.data.redis.core.RedisConnectionUtils; | |||
| import org.springframework.data.redis.core.RedisTemplate; | |||
| import org.springframework.data.redis.core.ScanOptions; | |||
| import org.springframework.data.redis.core.*; | |||
| import org.springframework.data.redis.core.script.DefaultRedisScript; | |||
| import org.springframework.data.redis.core.script.RedisScript; | |||
| import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer; | |||
| @@ -670,6 +667,40 @@ public class RedisUtils { | |||
| return zRangeByScorePop( key,0, max,0,1); | |||
| } | |||
| /** | |||
| * 根据键获取score值为 min 到 max 之间的所有 member 和 score | |||
| * @param key 健 | |||
| * @param min score最小值 | |||
| * @param max score最大值 | |||
| * @return | |||
| */ | |||
| public Set<ZSetOperations.TypedTuple<Object>> zRangeByScoreWithScores(String key, Long min, Long max){ | |||
| try { | |||
| return redisTemplate.opsForZSet().rangeWithScores(key, min, max); | |||
| } catch (Exception e) { | |||
| LogUtil.error(LogEnum.BIZ_DATASET, "RedisUtils rangeWithScores key {} error:{}", key, e.getMessage(), e); | |||
| return null; | |||
| } | |||
| } | |||
| /** | |||
| * 根据 key 和 member 移除元素 | |||
| * @param key | |||
| * @param member | |||
| * @return | |||
| */ | |||
| public Boolean zRem(String key,Object member){ | |||
| try{ | |||
| if (StringUtils.isEmpty(key) || null == member){ | |||
| return false; | |||
| } | |||
| redisTemplate.opsForZSet().remove(key,member); | |||
| return true; | |||
| }catch (Exception e){ | |||
| LogUtil.error(LogEnum.REDIS, "RedisUtils zrem key {} member {} error:{}", key, member, e); | |||
| return false; | |||
| } | |||
| } | |||
| // ===============================list================================= | |||
| @@ -32,12 +32,12 @@ public class OAuth2UserContextServiceImpl implements UserContextService { | |||
| @Override | |||
| public UserContext getCurUser() { | |||
| JwtUserDTO jwtUserDTO = JwtUtils.getCurUser(); | |||
| return jwtUserDTO == null?null:jwtUserDTO.getUser(); | |||
| return jwtUserDTO == null ? null : jwtUserDTO.getUser(); | |||
| } | |||
| @Override | |||
| public Long getCurUserId() { | |||
| UserContext userContext = getCurUser(); | |||
| return userContext == null?null:userContext.getId(); | |||
| return userContext == null ? null : userContext.getId(); | |||
| } | |||
| } | |||
| @@ -6,7 +6,7 @@ spring: | |||
| context-path: /nacos | |||
| config: | |||
| namespace: dubhe-server-cloud-dev | |||
| server-addr: 10.105.1.132:8848 | |||
| server-addr: 127.0.0.1:8848 | |||
| discovery: | |||
| namespace: dubhe-server-cloud-dev | |||
| server-addr: 10.105.1.132:8848 | |||
| server-addr: 127.0.0.1:8848 | |||
| @@ -0,0 +1,12 @@ | |||
| spring: | |||
| cloud: | |||
| nacos: | |||
| username: nacos | |||
| password: Tianshu | |||
| context-path: /nacos | |||
| config: | |||
| namespace: dubhe-server-cloud-open-dev | |||
| server-addr: 10.105.1.132:8848 | |||
| discovery: | |||
| namespace: dubhe-server-cloud-open-dev | |||
| server-addr: 10.105.1.132:8848 | |||
| @@ -0,0 +1,12 @@ | |||
| spring: | |||
| cloud: | |||
| nacos: | |||
| username: nacos | |||
| password: Tianshu | |||
| context-path: /nacos | |||
| config: | |||
| namespace: dubhe-server-cloud-open-dev | |||
| server-addr: 10.105.1.132:8848 | |||
| discovery: | |||
| namespace: dubhe-server-cloud-open-dev | |||
| server-addr: 10.105.1.132:8848 | |||
| @@ -3,7 +3,7 @@ spring: | |||
| nacos: | |||
| config: | |||
| namespace: dubhe-server-cloud-pre | |||
| server-addr: 10.105.1.133:8848 | |||
| server-addr: 127.0.0.1:8848 | |||
| discovery: | |||
| namespace: dubhe-server-cloud-pre | |||
| server-addr: 10.105.1.133:8848 | |||
| server-addr: 127.0.0.1:8848 | |||
| @@ -6,7 +6,7 @@ spring: | |||
| context-path: /nacos | |||
| config: | |||
| namespace: dubhe-server-cloud-test | |||
| server-addr: 10.105.1.132:8848 | |||
| server-addr: 127.0.0.1:8848 | |||
| discovery: | |||
| namespace: dubhe-server-cloud-test | |||
| server-addr: 10.105.1.132:8848 | |||
| server-addr: 127.0.0.1:8848 | |||
| @@ -74,6 +74,7 @@ public class DockerCallbackTool { | |||
| LogUtil.info(LogEnum.TERMINAL, "{} sendPushCallback {} count {} status:{}", url, dockerPushCallbackDTO,count,httpResponse.getStatus()); | |||
| //重试 | |||
| if (HttpStatus.HTTP_OK != httpResponse.getStatus() && count > MagicNumConstant.ZERO){ | |||
| Thread.sleep(MagicNumConstant.ONE_THOUSAND); | |||
| sendPushCallback(dockerPushCallbackDTO,url,--count); | |||
| } | |||
| }catch (Exception e){ | |||
| @@ -75,4 +75,31 @@ public interface LogMonitoringApi { | |||
| */ | |||
| long searchLogCountByPodName(LogMonitoringBO logMonitoringBo); | |||
| /** | |||
| * 日志查询方法 | |||
| * | |||
| * @param logMonitoringBo 日志查询bo | |||
| * @return LogMonitoringVO 日志查询结果类 | |||
| */ | |||
| LogMonitoringVO searchLog(LogMonitoringBO logMonitoringBo); | |||
| /** | |||
| * 添加 TADL 服务日志到 Elasticsearch | |||
| * | |||
| * @param experimentId Experiment ID | |||
| * @param log 日志 | |||
| * @return boolean 日志添加是否成功 | |||
| */ | |||
| boolean addTadlLogsToEs(long experimentId, String log); | |||
| /** | |||
| * TADL 服务日志查询方法 | |||
| * | |||
| * @param from 日志查询起始值,初始值为1,表示从第一条日志记录开始查询 | |||
| * @param size 日志查询记录数 | |||
| * @param experimentId TADL Experiment ID | |||
| * @return LogMonitoringVO 日志查询结果类 | |||
| */ | |||
| LogMonitoringVO searchTadlLogById(int from, int size, long experimentId); | |||
| } | |||
| @@ -18,12 +18,14 @@ | |||
| package org.dubhe.k8s.api.impl; | |||
| import com.alibaba.fastjson.JSON; | |||
| import com.alibaba.fastjson.JSONObject; | |||
| import com.baomidou.mybatisplus.core.toolkit.CollectionUtils; | |||
| import io.fabric8.kubernetes.api.model.DoneablePod; | |||
| import io.fabric8.kubernetes.api.model.Pod; | |||
| import io.fabric8.kubernetes.client.KubernetesClient; | |||
| import io.fabric8.kubernetes.client.dsl.PodResource; | |||
| import org.dubhe.biz.base.constant.MagicNumConstant; | |||
| import org.dubhe.biz.base.enums.BizEnum; | |||
| import org.dubhe.biz.base.utils.StringUtils; | |||
| import org.dubhe.biz.base.utils.TimeTransferUtil; | |||
| import org.dubhe.biz.log.enums.LogEnum; | |||
| @@ -39,6 +41,10 @@ import org.elasticsearch.action.search.SearchRequest; | |||
| import org.elasticsearch.action.search.SearchResponse; | |||
| import org.elasticsearch.client.RequestOptions; | |||
| import org.elasticsearch.client.RestHighLevelClient; | |||
| import org.elasticsearch.client.indices.CreateIndexRequest; | |||
| import org.elasticsearch.client.indices.CreateIndexResponse; | |||
| import org.elasticsearch.client.indices.GetIndexRequest; | |||
| import org.elasticsearch.common.settings.Settings; | |||
| import org.elasticsearch.index.query.BoolQueryBuilder; | |||
| import org.elasticsearch.index.query.Operator; | |||
| import org.elasticsearch.index.query.QueryBuilders; | |||
| @@ -48,12 +54,9 @@ import org.elasticsearch.search.builder.SearchSourceBuilder; | |||
| import org.elasticsearch.search.sort.SortOrder; | |||
| import org.springframework.beans.factory.annotation.Autowired; | |||
| import org.springframework.beans.factory.annotation.Value; | |||
| import java.io.IOException; | |||
| import java.text.SimpleDateFormat; | |||
| import java.util.*; | |||
| import java.util.stream.Collectors; | |||
| import static org.dubhe.biz.base.constant.MagicNumConstant.ZERO; | |||
| import static org.dubhe.biz.base.constant.MagicNumConstant.*; | |||
| import static org.dubhe.biz.base.constant.SymbolConstant.*; | |||
| @@ -78,6 +81,14 @@ public class LogMonitoringApiImpl implements LogMonitoringApi { | |||
| private KubernetesClient kubernetesClient; | |||
| private static final String INDEX_NAME = "kubelogs"; | |||
| private static final String TADL_INDEX_NAME = "tadllogs"; | |||
| private static final String INDEX_SHARDS_NUMBER = "index.number_of_shards"; | |||
| private static final String INDEX_REPLICAS_NUMBER = "index.number_of_replicas"; | |||
| private static final String TYPE = "type"; | |||
| private static final String TEXT = "text"; | |||
| private static final String DATE = "date"; | |||
| private static final String PROPERTIES = "properties"; | |||
| private static final String EXPERIMENT_ID = "experimentId"; | |||
| private static final String POD_NAME_KEY = "kubernetes.pod_name.keyword"; | |||
| private static final String POD_NAME = "kubernetes.pod_name"; | |||
| private static final String NAMESPACE_KEY = "kubernetes.namespace_name.keyword"; | |||
| @@ -86,6 +97,12 @@ public class LogMonitoringApiImpl implements LogMonitoringApi { | |||
| private static final String MESSAGE = "log"; | |||
| private static final String LOG_PREFIX = "[Dubhe Service Log] "; | |||
| private static final String INDEX_FORMAT = "yyyy.MM.dd"; | |||
| private static final String TIMESTAMP_FORMAT = "yyyy-MM-dd HH:mm:ss.sss"; | |||
| private static final String LOG_FORMAT = "[Dubhe Service Log]-[%s]-%s"; | |||
| private static final String KUBERNETES_KEY = "kubernetes"; | |||
| private static final String KUBERNETES_POD_NAME_KEY = "pod_name"; | |||
| private static final String BUSINESS_KEY = "kubernetes.labels.platform/business.keyword"; | |||
| private static final String BUSINESS_GROUP_ID_KEY = "kubernetes.labels.platform/business-group-id.keyword"; | |||
| public LogMonitoringApiImpl(K8sUtils k8sUtils) { | |||
| this.kubernetesClient = k8sUtils.getClient(); | |||
| @@ -151,7 +168,7 @@ public class LogMonitoringApiImpl implements LogMonitoringApi { | |||
| /**通过restHighLevelClient发送http的请求批量创建文档**/ | |||
| restHighLevelClient.bulk(bulkRequest, RequestOptions.DEFAULT); | |||
| } catch (IOException e) { | |||
| } catch (Exception e) { | |||
| LogUtil.error(LogEnum.BIZ_K8S, "LogMonitoringApi.addLogsToEs error:{}", e); | |||
| return false; | |||
| } | |||
| @@ -169,7 +186,7 @@ public class LogMonitoringApiImpl implements LogMonitoringApi { | |||
| @Override | |||
| public LogMonitoringVO searchLogByResName(int from, int size, LogMonitoringBO logMonitoringBo) { | |||
| List<String> logList = new ArrayList<>(); | |||
| LogMonitoringVO logMonitoringResult = new LogMonitoringVO(ZERO_LONG, logList); | |||
| LogMonitoringVO logMonitoringResult = new LogMonitoringVO(ZERO, logList); | |||
| String namespace = logMonitoringBo.getNamespace(); | |||
| String resourceName = logMonitoringBo.getResourceName(); | |||
| if (StringUtils.isBlank(resourceName) || StringUtils.isBlank(namespace)) { | |||
| @@ -195,7 +212,7 @@ public class LogMonitoringApiImpl implements LogMonitoringApi { | |||
| } | |||
| logMonitoringResult.setLogs(logList); | |||
| logMonitoringResult.setTotalLogs(Long.valueOf(logList.size())); | |||
| logMonitoringResult.setTotalLogs(logList.size()); | |||
| return logMonitoringResult; | |||
| } | |||
| @@ -212,7 +229,23 @@ public class LogMonitoringApiImpl implements LogMonitoringApi { | |||
| LogMonitoringVO logMonitoringResult = new LogMonitoringVO(); | |||
| List<String> logs = searchLogInfoByEs(from, size, logMonitoringBo); | |||
| logMonitoringResult.setLogs(logs); | |||
| logMonitoringResult.setTotalLogs(Long.valueOf(logs.size())); | |||
| logMonitoringResult.setTotalLogs(logs.size()); | |||
| return logMonitoringResult; | |||
| } | |||
| /** | |||
| * 日志查询方法 | |||
| * | |||
| * @param logMonitoringBo 日志查询bo | |||
| * @return LogMonitoringVO 日志查询结果类 | |||
| */ | |||
| @Override | |||
| public LogMonitoringVO searchLog(LogMonitoringBO logMonitoringBo) { | |||
| LogMonitoringVO logMonitoringResult = new LogMonitoringVO(); | |||
| List<String> logs = searchLogInfoByEs(logMonitoringBo); | |||
| logMonitoringResult.setLogs(logs); | |||
| logMonitoringResult.setTotalLogs(logs.size()); | |||
| return logMonitoringResult; | |||
| } | |||
| @@ -235,6 +268,127 @@ public class LogMonitoringApiImpl implements LogMonitoringApi { | |||
| } | |||
| } | |||
| /** | |||
| * 添加 TADL 服务日志到 Elasticsearch | |||
| * | |||
| * @param experimentId 日志查询起始值,初始值为1,表示从第一条日志记录开始查询 | |||
| * @param log 日志 | |||
| * @return boolean 日志添加是否成功 | |||
| */ | |||
| @Override | |||
| public boolean addTadlLogsToEs(long experimentId, String log) { | |||
| Date date = new Date(); | |||
| String timestamp = TimeTransferUtil.dateTransferToUtc(date); | |||
| BulkRequest bulkRequest = new BulkRequest(); | |||
| try { | |||
| /**查询索引是否存在, 不存在则创建**/ | |||
| GetIndexRequest getIndexRequest = new GetIndexRequest(TADL_INDEX_NAME); | |||
| boolean exists = restHighLevelClient.indices().exists(getIndexRequest, RequestOptions.DEFAULT); | |||
| if (!exists){ | |||
| CreateIndexRequest createIndexRequest = new CreateIndexRequest(TADL_INDEX_NAME); | |||
| createIndexRequest.settings(Settings.builder() | |||
| .put(INDEX_SHARDS_NUMBER, 3) | |||
| .put(INDEX_REPLICAS_NUMBER, 2) | |||
| ); | |||
| Map<String, String> timestampMapping = new HashMap<>(); | |||
| timestampMapping.put(TYPE, DATE); | |||
| Map<String, String> logMapping = new HashMap<>(); | |||
| logMapping.put(TYPE, TEXT); | |||
| Map<String, String> experimentIdMapping = new HashMap<>(); | |||
| experimentIdMapping.put(TYPE, TEXT); | |||
| Map<String, Object> properties = new HashMap<>(); | |||
| properties.put(TIMESTAMP,timestampMapping); | |||
| properties.put(EXPERIMENT_ID,experimentIdMapping); | |||
| properties.put(MESSAGE,logMapping); | |||
| Map<String, Object> mapping = new HashMap<>(); | |||
| mapping.put(PROPERTIES, properties); | |||
| createIndexRequest.mapping(mapping); | |||
| CreateIndexResponse createIndexResponse = restHighLevelClient.indices().create(createIndexRequest, RequestOptions.DEFAULT); | |||
| } | |||
| LinkedHashMap<String, Object> jsonMap = new LinkedHashMap() {{ | |||
| put(EXPERIMENT_ID, experimentId); | |||
| put(MESSAGE, new SimpleDateFormat(TIMESTAMP_FORMAT).format(date) + SPACE + log); | |||
| put(TIMESTAMP, timestamp); | |||
| }}; | |||
| /**添加索引创建对象到bulkRequest**/ | |||
| bulkRequest.add(new IndexRequest(TADL_INDEX_NAME).source(jsonMap)); | |||
| /**通过restHighLevelClient发送http的请求创建文档**/ | |||
| restHighLevelClient.bulk(bulkRequest, RequestOptions.DEFAULT); | |||
| } catch (Exception e) { | |||
| LogUtil.error(LogEnum.BIZ_K8S, "LogMonitoringApi.addTadlLogsToEs error:{}", e); | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| /** | |||
| * TADL 服务日志查询方法 | |||
| * | |||
| * @param from 日志查询起始值,初始值为1,表示从第一条日志记录开始查询 | |||
| * @param size 日志查询记录数 | |||
| * @param experimentId TADL Experiment ID | |||
| * @return LogMonitoringVO 日志查询结果类 | |||
| */ | |||
| @Override | |||
| public LogMonitoringVO searchTadlLogById(int from, int size, long experimentId) { | |||
| List<String> logList = new ArrayList<>(); | |||
| LogMonitoringVO logMonitoringResult = new LogMonitoringVO(ZERO, logList); | |||
| /**处理查询范围参数起始值**/ | |||
| from = from <= MagicNumConstant.ZERO ? MagicNumConstant.ZERO : --from; | |||
| size = size <= MagicNumConstant.ZERO || size > TEN_THOUSAND ? TEN_THOUSAND : size; | |||
| /**创建搜索请求对象**/ | |||
| SearchRequest searchRequest = new SearchRequest(); | |||
| searchRequest.indices(TADL_INDEX_NAME); | |||
| SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); | |||
| searchSourceBuilder.trackTotalHits(true).from(from).size(size); | |||
| /**根据时间戳排序**/ | |||
| searchSourceBuilder.sort(TIMESTAMP, SortOrder.ASC); | |||
| /**创建布尔查询对象**/ | |||
| BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery(); | |||
| boolQueryBuilder.filter(QueryBuilders.matchQuery(EXPERIMENT_ID, experimentId)); | |||
| /**设置boolQueryBuilder到searchSourceBuilder**/ | |||
| searchSourceBuilder.query(boolQueryBuilder); | |||
| searchRequest = searchRequest.source(searchSourceBuilder); | |||
| /**执行搜索**/ | |||
| SearchResponse searchResponse; | |||
| try { | |||
| searchResponse = restHighLevelClient.search(searchRequest, RequestOptions.DEFAULT); | |||
| } catch (Exception e) { | |||
| LogUtil.error(LogEnum.BIZ_K8S, "LogMonitoringApiImpl.searchTadlLogById error,param:[experimentId]={}, error:{}", experimentId, e); | |||
| return logMonitoringResult; | |||
| } | |||
| /**获取响应结果**/ | |||
| SearchHits hits = searchResponse.getHits(); | |||
| SearchHit[] searchHits = hits.getHits(); | |||
| if (searchHits.length == MagicNumConstant.ZERO) { | |||
| return logMonitoringResult; | |||
| } | |||
| for (SearchHit hit : searchHits) { | |||
| /**源文档**/ | |||
| Map<String, Object> sourceAsMap = hit.getSourceAsMap(); | |||
| /**取出message**/ | |||
| String message = (String) sourceAsMap.get(MESSAGE); | |||
| message = message.replace(LINEBREAK, BLANK); | |||
| /**添加日志信息到集合**/ | |||
| logList.add(message); | |||
| } | |||
| logMonitoringResult.setLogs(logList); | |||
| logMonitoringResult.setTotalLogs(logList.size()); | |||
| return logMonitoringResult; | |||
| } | |||
| /** | |||
| * 得到日志信息String | |||
| * | |||
| @@ -256,6 +410,51 @@ public class LogMonitoringApiImpl implements LogMonitoringApi { | |||
| return null; | |||
| } | |||
| /** | |||
| * 从Elasticsearch查询日志 | |||
| * | |||
| * @param logMonitoringBo 日志查询bo | |||
| * @return List<String> 日志集合 | |||
| */ | |||
| private List<String> searchLogInfoByEs(LogMonitoringBO logMonitoringBo) { | |||
| List<String> logList = new ArrayList<>(); | |||
| SearchRequest searchRequest = buildSearchRequest(logMonitoringBo); | |||
| /**执行搜索**/ | |||
| SearchResponse searchResponse; | |||
| try { | |||
| searchResponse = restHighLevelClient.search(searchRequest, RequestOptions.DEFAULT); | |||
| } catch (Exception e) { | |||
| LogUtil.error(LogEnum.BIZ_K8S, "LogMonitoringApiImpl.searchLogInfoByEs error,param:[logMonitoringBo]={}, error:{}", JSON.toJSONString(logMonitoringBo), e); | |||
| return logList; | |||
| } | |||
| /**获取响应结果**/ | |||
| SearchHits hits = searchResponse.getHits(); | |||
| SearchHit[] searchHits = hits.getHits(); | |||
| if (searchHits.length == MagicNumConstant.ZERO) { | |||
| return logList; | |||
| } | |||
| for (SearchHit hit : searchHits) { | |||
| String esResult = hit.getSourceAsString(); | |||
| JSONObject jsonObject = JSON.parseObject(esResult); | |||
| String message = jsonObject.getString(MESSAGE); | |||
| message = message.replace(LINEBREAK, BLANK); | |||
| String podName = jsonObject.getJSONObject(KUBERNETES_KEY). | |||
| getString(KUBERNETES_POD_NAME_KEY); | |||
| /**拼接日志信息**/ | |||
| String logString = String.format(LOG_FORMAT, podName, message); | |||
| /**添加日志信息到集合**/ | |||
| logList.add(logString); | |||
| } | |||
| return logList; | |||
| } | |||
| /** | |||
| * 从Elasticsearch查询日志 | |||
| @@ -366,4 +565,74 @@ public class LogMonitoringApiImpl implements LogMonitoringApi { | |||
| return searchRequest.source(searchSourceBuilder); | |||
| } | |||
| /** | |||
| * 构建搜索请求对象 | |||
| * | |||
| * @param logMonitoringBo 日志查询bo | |||
| * @return SearchRequest ES搜索请求对象 | |||
| */ | |||
| private SearchRequest buildSearchRequest(LogMonitoringBO logMonitoringBo) { | |||
| /**创建搜索请求对象**/ | |||
| SearchRequest searchRequest = new SearchRequest(); | |||
| searchRequest.indices(INDEX_NAME); | |||
| SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); | |||
| searchSourceBuilder.trackTotalHits(true).from(logMonitoringBo.getFrom()).size(logMonitoringBo.getSize()); | |||
| /**根据时间戳排序**/ | |||
| searchSourceBuilder.sort(TIMESTAMP, SortOrder.ASC); | |||
| /**过虑源字段**/ | |||
| String[] sourceFieldArray = sourceField.split(COMMA); | |||
| searchSourceBuilder.fetchSource(sourceFieldArray, new String[]{}); | |||
| /**创建布尔查询对象**/ | |||
| BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery(); | |||
| /**添加podName查询条件**/ | |||
| Set<String> podNames = logMonitoringBo.getPodNames(); | |||
| if (CollectionUtils.isNotEmpty(podNames)) { | |||
| boolQueryBuilder.filter(QueryBuilders.termsQuery(POD_NAME_KEY, podNames.toArray(new String[podNames.size()]))); | |||
| } | |||
| /**添加namespace查询条件**/ | |||
| String namespace = logMonitoringBo.getNamespace(); | |||
| if (StringUtils.isNotEmpty(namespace)) { | |||
| boolQueryBuilder.filter(QueryBuilders.matchQuery(NAMESPACE_KEY, namespace)); | |||
| } | |||
| /**添加业务查询条件**/ | |||
| BizEnum business = logMonitoringBo.getBusiness(); | |||
| if (null != business) { | |||
| boolQueryBuilder.filter(QueryBuilders.termQuery(BUSINESS_KEY, business.getBizCode())); | |||
| } | |||
| /**添加实验Id查询条件**/ | |||
| String businessGroupId = logMonitoringBo.getBusinessGroupId(); | |||
| if (StringUtils.isNotEmpty(businessGroupId)) { | |||
| boolQueryBuilder.filter(QueryBuilders.termQuery(BUSINESS_GROUP_ID_KEY, businessGroupId)); | |||
| } | |||
| /**添加关键字查询条件**/ | |||
| String logKeyword = logMonitoringBo.getLogKeyword(); | |||
| if (StringUtils.isNotEmpty(logKeyword)) { | |||
| boolQueryBuilder.filter(QueryBuilders.matchQuery(MESSAGE, logKeyword).operator(Operator.AND)); | |||
| } | |||
| /**添加时间范围查询条件**/ | |||
| Long beginTimeMillis = logMonitoringBo.getBeginTimeMillis(); | |||
| Long endTimeMillis = logMonitoringBo.getEndTimeMillis(); | |||
| if (beginTimeMillis != null || endTimeMillis != null){ | |||
| beginTimeMillis = beginTimeMillis == null ? ZERO_LONG : beginTimeMillis; | |||
| endTimeMillis = endTimeMillis == null ? System.currentTimeMillis() : endTimeMillis; | |||
| /**将毫秒值转换为UTC时间**/ | |||
| String beginUtcTime = TimeTransferUtil.dateTransferToUtc(new Date(beginTimeMillis)); | |||
| String endUtcTime = TimeTransferUtil.dateTransferToUtc(new Date(endTimeMillis)); | |||
| boolQueryBuilder.filter(QueryBuilders.rangeQuery(TIMESTAMP).gte(beginUtcTime).lte(endUtcTime)); | |||
| } | |||
| /**设置boolQueryBuilder到searchSourceBuilder**/ | |||
| searchSourceBuilder.query(boolQueryBuilder); | |||
| return searchRequest.source(searchSourceBuilder); | |||
| } | |||
| } | |||
| @@ -47,6 +47,7 @@ import org.dubhe.k8s.domain.bo.TerminalBO; | |||
| import org.dubhe.k8s.domain.vo.PtJupyterDeployVO; | |||
| import org.dubhe.k8s.domain.vo.TerminalResourceVO; | |||
| import org.dubhe.k8s.domain.vo.VolumeVO; | |||
| import org.dubhe.k8s.enums.ImagePullPolicyEnum; | |||
| import org.dubhe.k8s.enums.K8sKindEnum; | |||
| import org.dubhe.k8s.enums.K8sResponseEnum; | |||
| import org.dubhe.k8s.enums.LackOfResourcesEnum; | |||
| @@ -132,6 +133,7 @@ public class TerminalApiImpl implements TerminalApi { | |||
| Map<String, String> podLabels = LabelUtils.getChildLabels(bo.getResourceName(), deploymentName, K8sKindEnum.DEPLOYMENT.getKind(), bo.getBusinessLabel(), bo.getTaskIdentifyLabel()); | |||
| //部署deployment | |||
| bo.setImagePullPolicy(ImagePullPolicyEnum.ALWAYS.getPolicy()); | |||
| Deployment deployment = ResourceBuildUtils.buildDeployment(bo, volumeVO, deploymentName); | |||
| LogUtil.info(LogEnum.BIZ_K8S, "Ready to deploy {}, yaml信息为{}", deploymentName, YamlUtils.dumpAsYaml(deployment)); | |||
| resourceIisolationApi.addIisolationInfo(deployment); | |||
| @@ -250,8 +250,8 @@ public class TrainJobApiImpl implements TrainJobApi { | |||
| this.fsMounts = bo.getFsMounts(); | |||
| businessLabel = bo.getBusinessLabel(); | |||
| this.baseLabels = LabelUtils.getBaseLabels(baseName,bo.getBusinessLabel(),bo.getExtraLabelMap()); | |||
| this.taskIdentifyLabel = bo.getTaskIdentifyLabel(); | |||
| this.baseLabels = LabelUtils.getBaseLabels(baseName,bo.getBusinessLabel()); | |||
| this.volumeMounts = new ArrayList<>(); | |||
| this.volumes = new ArrayList<>(); | |||
| @@ -458,7 +458,7 @@ public class TrainJobApiImpl implements TrainJobApi { | |||
| .withNewTemplate() | |||
| .withNewMetadata() | |||
| .withName(jobName) | |||
| .addToLabels(LabelUtils.getChildLabels(baseName, jobName, K8sKindEnum.JOB.getKind(),businessLabel, taskIdentifyLabel)) | |||
| .addToLabels(LabelUtils.getChildLabels(baseName, jobName, K8sKindEnum.JOB.getKind(),businessLabel,taskIdentifyLabel,baseLabels)) | |||
| .withNamespace(namespace) | |||
| .endMetadata() | |||
| .withNewSpec() | |||
| @@ -474,7 +474,7 @@ public class TrainJobApiImpl implements TrainJobApi { | |||
| if (delayCreate == null || delayCreate == MagicNumConstant.ZERO){ | |||
| resourceIisolationApi.addIisolationInfo(job); | |||
| LogUtil.info(LogEnum.BIZ_K8S, "Ready to deploy {}", jobName); | |||
| job = client.batch().jobs().create(job); | |||
| job = client.batch().jobs().inNamespace(namespace).create(job); | |||
| LogUtil.info(LogEnum.BIZ_K8S, "{} deployed successfully", jobName); | |||
| } | |||
| if (delayCreate > MagicNumConstant.ZERO || delayDelete > MagicNumConstant.ZERO){ | |||
| @@ -93,6 +93,14 @@ public class DeploymentBO { | |||
| */ | |||
| private Set<Integer> ports; | |||
| /** | |||
| * 镜像拉取策略 | |||
| * IfNotPresent 默认值 | |||
| * Always | |||
| * Never | |||
| */ | |||
| private String imagePullPolicy; | |||
| /** | |||
| * 获取nfs路径 | |||
| * @return | |||
| @@ -17,6 +17,7 @@ | |||
| package org.dubhe.k8s.domain.bo; | |||
| import lombok.Data; | |||
| import lombok.experimental.Accessors; | |||
| import org.dubhe.biz.base.enums.BizEnum; | |||
| import org.dubhe.k8s.domain.dto.PodLogQueryDTO; | |||
| import java.util.Set; | |||
| @@ -56,6 +57,26 @@ public class LogMonitoringBO { | |||
| **/ | |||
| private Long endTimeMillis; | |||
| /** | |||
| * 日志查询起始行 | |||
| **/ | |||
| private Integer from; | |||
| /** | |||
| * 日志查询行数 | |||
| **/ | |||
| private Integer size; | |||
| /** | |||
| * 业务标签,用于标识一个组的业务模块 比如:TRAIN模块的trainId, TADL模块的experimentId | |||
| */ | |||
| private String businessGroupId; | |||
| /** | |||
| * 业务标签,用于标识业务模块 | |||
| */ | |||
| private BizEnum business; | |||
| public LogMonitoringBO(String namespace,String podName){ | |||
| this.namespace = namespace; | |||
| this.podName = podName; | |||
| @@ -65,6 +65,8 @@ public class PtJupyterJobBO { | |||
| private GraphicsCardTypeEnum graphicsCardType; | |||
| /**业务标签,用于标识业务模块**/ | |||
| private String businessLabel; | |||
| /**额外扩展的标签**/ | |||
| private Map<String, String> extraLabelMap; | |||
| /**任务身份标签,用于标识任务身份**/ | |||
| private String taskIdentifyLabel; | |||
| /**延时创建时间,单位:分钟**/ | |||
| @@ -22,6 +22,7 @@ import io.swagger.annotations.ApiModelProperty; | |||
| import lombok.Data; | |||
| import javax.validation.constraints.NotEmpty; | |||
| import java.util.Map; | |||
| /** | |||
| * @descripton 统一通用参数实现与校验 | |||
| @@ -58,6 +59,9 @@ public class BaseK8sPodCallbackCreateDTO { | |||
| @ApiModelProperty(value = "k8s pod containerStatuses state") | |||
| private String messages; | |||
| @ApiModelProperty(value = "k8s pod lables") | |||
| private Map<String,String> lables; | |||
| public BaseK8sPodCallbackCreateDTO(){ | |||
| } | |||
| @@ -30,7 +30,7 @@ import java.util.List; | |||
| @Data | |||
| @AllArgsConstructor | |||
| public class LogMonitoringVO extends PtBaseResult { | |||
| private Long totalLogs; | |||
| private Integer totalLogs; | |||
| private List<String> logs; | |||
| public LogMonitoringVO() { | |||
| @@ -52,7 +52,10 @@ public enum BusinessLabelServiceNameEnum { | |||
| * 专业版终端 | |||
| */ | |||
| TERMINAL(BizEnum.TERMINAL.getBizCode(), ApplicationNameConst.TERMINAL), | |||
| ; | |||
| /** | |||
| * TADL | |||
| */ | |||
| TADL(BizEnum.TADL.getBizCode(), ApplicationNameConst.SERVER_TADL); | |||
| /** | |||
| * 业务标签 | |||
| */ | |||
| @@ -74,18 +77,19 @@ public enum BusinessLabelServiceNameEnum { | |||
| this.businessLabel = businessLabel; | |||
| this.serviceName = serviceName; | |||
| } | |||
| public static String getServiceNameByBusinessLabel(String businessLabel){ | |||
| public static String getServiceNameByBusinessLabel(String businessLabel) { | |||
| for (BusinessLabelServiceNameEnum businessLabelServiceNameEnum : BusinessLabelServiceNameEnum.values()) { | |||
| if (StringUtils.equals(businessLabel, businessLabelServiceNameEnum.getBusinessLabel() )){ | |||
| if (StringUtils.equals(businessLabel, businessLabelServiceNameEnum.getBusinessLabel())) { | |||
| return businessLabelServiceNameEnum.getServiceName(); | |||
| } | |||
| } | |||
| return BLANK; | |||
| } | |||
| public static String getBusinessLabelByServiceName(String serviceName){ | |||
| public static String getBusinessLabelByServiceName(String serviceName) { | |||
| for (BusinessLabelServiceNameEnum businessLabelServiceNameEnum : BusinessLabelServiceNameEnum.values()) { | |||
| if (StringUtils.equals(serviceName, businessLabelServiceNameEnum.getServiceName() )){ | |||
| if (StringUtils.equals(serviceName, businessLabelServiceNameEnum.getServiceName())) { | |||
| return businessLabelServiceNameEnum.getBusinessLabel(); | |||
| } | |||
| } | |||
| @@ -73,7 +73,7 @@ public class K8sCallBackTool { | |||
| * k8s 回调路径 | |||
| */ | |||
| private static final String K8S_CALLBACK_PATH_DEPLOYMENT = "/api/k8s/callback/deployment/"; | |||
| public static final String K8S_CALLBACK_PATH_POD = StringConstant.K8S_CALLBACK_URI+ SymbolConstant.SLASH; | |||
| public static final String K8S_CALLBACK_PATH_POD = StringConstant.K8S_CALLBACK_URI + SymbolConstant.SLASH; | |||
| static { | |||
| K8S_CALLBACK_PATH = new ArrayList<>(); | |||
| @@ -113,7 +113,7 @@ public class K8sCallBackTool { | |||
| */ | |||
| public boolean validateToken(String token) { | |||
| String expireTime = AesUtil.decrypt(token, secretKey); | |||
| if (StringUtils.isEmpty(expireTime)){ | |||
| if (StringUtils.isEmpty(expireTime)) { | |||
| return false; | |||
| } | |||
| String nowTime = DateUtil.format( | |||
| @@ -141,7 +141,7 @@ public class K8sCallBackTool { | |||
| * @return String | |||
| */ | |||
| public String getPodCallbackUrl(String podLabel) { | |||
| return "http://"+BusinessLabelServiceNameEnum.getServiceNameByBusinessLabel(podLabel) + K8S_CALLBACK_PATH_POD + podLabel; | |||
| return "http://" + BusinessLabelServiceNameEnum.getServiceNameByBusinessLabel(podLabel) + K8S_CALLBACK_PATH_POD + podLabel; | |||
| } | |||
| /** | |||
| @@ -151,7 +151,7 @@ public class K8sCallBackTool { | |||
| * @return String | |||
| */ | |||
| public String getDeploymentCallbackUrl(String businessLabel) { | |||
| return "http://"+BusinessLabelServiceNameEnum.getServiceNameByBusinessLabel(businessLabel) + K8S_CALLBACK_PATH_DEPLOYMENT + businessLabel; | |||
| return "http://" + BusinessLabelServiceNameEnum.getServiceNameByBusinessLabel(businessLabel) + K8S_CALLBACK_PATH_DEPLOYMENT + businessLabel; | |||
| } | |||
| @@ -63,6 +63,8 @@ import java.util.List; | |||
| import java.util.Map; | |||
| import java.util.Optional; | |||
| import static org.dubhe.biz.base.constant.MagicNumConstant.ZERO_LONG; | |||
| /** | |||
| * @description 构建 Kubernetes 资源对象 | |||
| * @date 2020-09-10 | |||
| @@ -256,6 +258,7 @@ public class ResourceBuildUtils { | |||
| .withNamespace(bo.getNamespace()) | |||
| .endMetadata() | |||
| .withNewSpec() | |||
| .withTerminationGracePeriodSeconds(ZERO_LONG) | |||
| .addToNodeSelector(K8sUtils.gpuSelector(bo.getGpuNum())) | |||
| .addToContainers(buildContainer(bo, volumeVO, deploymentName)) | |||
| .addToVolumes(volumeVO.getVolumes().toArray(new Volume[0])) | |||
| @@ -281,7 +284,7 @@ public class ResourceBuildUtils { | |||
| Container container = new ContainerBuilder() | |||
| .withNewName(name) | |||
| .withNewImage(bo.getImage()) | |||
| .withNewImagePullPolicy(ImagePullPolicyEnum.IFNOTPRESENT.getPolicy()) | |||
| .withNewImagePullPolicy(StringUtils.isEmpty(bo.getImagePullPolicy())?ImagePullPolicyEnum.IFNOTPRESENT.getPolicy():bo.getImagePullPolicy()) | |||
| .withVolumeMounts(volumeVO.getVolumeMounts()) | |||
| .withNewResources().addToLimits(resourcesLimitsMap).endResources() | |||
| .build(); | |||
| @@ -0,0 +1,19 @@ | |||
| apiVersion: v1 | |||
| clusters: | |||
| - cluster: | |||
| certificate-authority-data: | |||
| server: https://127.0.0.1:6443 | |||
| name: kubernetes | |||
| contexts: | |||
| - context: | |||
| cluster: kubernetes | |||
| user: kubernetes-admin | |||
| name: kubernetes-admin@kubernetes | |||
| current-context: kubernetes-admin@kubernetes | |||
| kind: Config | |||
| preferences: {} | |||
| users: | |||
| - name: kubernetes-admin | |||
| user: | |||
| client-certificate-data: | |||
| client-key-data: | |||
| @@ -0,0 +1,19 @@ | |||
| apiVersion: v1 | |||
| clusters: | |||
| - cluster: | |||
| certificate-authority-data: | |||
| server: https://127.0.0.1:6443 | |||
| name: kubernetes | |||
| contexts: | |||
| - context: | |||
| cluster: kubernetes | |||
| user: kubernetes-admin | |||
| name: kubernetes-admin@kubernetes | |||
| current-context: kubernetes-admin@kubernetes | |||
| kind: Config | |||
| preferences: {} | |||
| users: | |||
| - name: kubernetes-admin | |||
| user: | |||
| client-certificate-data: | |||
| client-key-data: | |||
| @@ -1,7 +1,7 @@ | |||
| apiVersion: v1 | |||
| clusters: | |||
| - cluster: | |||
| certificate-authority-data: | |||
| certificate-authority-data: | |||
| server: https://127.0.0.1:6443 | |||
| name: kubernetes | |||
| contexts: | |||
| @@ -15,5 +15,5 @@ preferences: {} | |||
| users: | |||
| - name: kubernetes-admin | |||
| user: | |||
| client-certificate-data: | |||
| client-key-data: | |||
| client-certificate-data: | |||
| client-key-data: | |||
| @@ -0,0 +1,19 @@ | |||
| apiVersion: v1 | |||
| clusters: | |||
| - cluster: | |||
| certificate-authority-data: | |||
| server: https://127.0.0.1:6443 | |||
| name: kubernetes | |||
| contexts: | |||
| - context: | |||
| cluster: kubernetes | |||
| user: kubernetes-admin | |||
| name: kubernetes-admin@kubernetes | |||
| current-context: kubernetes-admin@kubernetes | |||
| kind: Config | |||
| preferences: {} | |||
| users: | |||
| - name: kubernetes-admin | |||
| user: | |||
| client-certificate-data: | |||
| client-key-data: | |||
| @@ -68,5 +68,9 @@ public class RecycleConfig { | |||
| * 回收serving相关文件后,回收文件最大有效时长,以天为单位 | |||
| */ | |||
| private Integer servingValid; | |||
| /** | |||
| * 用户删除tadl算法版本文件后,文件最大有效时长,以天为单位 | |||
| */ | |||
| private Integer tadlValid; | |||
| } | |||
| @@ -39,7 +39,8 @@ public enum RecycleModuleEnum { | |||
| BIZ_MODEL(7, "模型管理",SERVER_MODEL), | |||
| BIZ_DATAMEDICINE(8, "医学影像",SERVER_DATA_DCM), | |||
| BIZ_MEASURE(9, "度量管理",SERVER_MEASURE), | |||
| BIZ_SERVING(10, "云端Serving", SERVER_SERVING); | |||
| BIZ_SERVING(10, "云端部署", SERVER_SERVING), | |||
| BIZ_TADL(11,"自动机器学习",SERVER_TADL); | |||
| private Integer value; | |||
| @@ -50,6 +50,15 @@ public enum RecycleResourceEnum { | |||
| */ | |||
| BATCH_SERVING_RECYCLE_FILE("batchServingRecycleFile", "云端Serving批量服务文件回收"), | |||
| /** | |||
| * tadl算法文件回收 | |||
| */ | |||
| TADL_ALGORITHM_RECYCLE_FILE("tadlAlgorithmRecycleFile", "tadl算法文件回收"), | |||
| /** | |||
| * tadl实验文件回收 | |||
| */ | |||
| TADL_EXPERIMENT_RECYCLE_FILE("tadlExperimentRecycleFile","tadl实验文件回收"), | |||
| /** | |||
| * 标签组文件回收 | |||
| */ | |||
| @@ -197,7 +197,7 @@ public class RecycleTool { | |||
| if (sourcePath.length() > nfsBucket.length()) { | |||
| String emptyDir = recycleFileTmpPath + randomPath + StrUtil.SLASH; | |||
| LogUtil.info(LogEnum.GARBAGE_RECYCLE, "recycle task sourcePath:{},emptyDir:{}", sourcePath, emptyDir); | |||
| process = Runtime.getRuntime().exec(new String[]{"/bin/sh", "-c", String.format(ShellFileStoreApiImpl.DEL_COMMAND, userName, ip, emptyDir, emptyDir, sourcePath, emptyDir, sourcePath)}); | |||
| process = Runtime.getRuntime().exec(new String[]{"/bin/sh", "-c", String.format(ShellFileStoreApiImpl.DEL_COMMAND, emptyDir, emptyDir, sourcePath, emptyDir, sourcePath)}); | |||
| } | |||
| return processRecycle(process); | |||
| @@ -559,7 +559,7 @@ public class DataTaskExecuteThread implements Runnable { | |||
| for (List<File> el : lists) { | |||
| List<Long> fileIds = csvImportSaveDb(el, dataset); | |||
| LogUtil.info(LogEnum.BIZ_DATASET, "table import transport to es datasetid:{}", datasetId); | |||
| fileService.transportTextToEs(dataset, fileIds); | |||
| fileService.transportTextToEs(dataset, fileIds,Boolean.FALSE); | |||
| } | |||
| } | |||
| //------- 导入完成后 更改数据集状态 --------- | |||
| @@ -112,6 +112,7 @@ public enum ErrorEnum implements ErrorCode { | |||
| DATASET_NOT_ANNOTATION(1718, "数据集暂不支持自动标注"), | |||
| DATASET_NOT_OPERATIONS_BASE_DATASET(1719, "禁止操作内置的数据集"), | |||
| DATASET_PUBLISH_REJECT(1720, "文本暂不支持多版本发布"), | |||
| DATASET_CHECK_VERSION_ERROR(1721,"目标版本不存在"), | |||
| /** | |||
| * 数据集版本校验 | |||
| @@ -19,6 +19,7 @@ package org.dubhe.data.dao; | |||
| import com.baomidou.mybatisplus.core.mapper.BaseMapper; | |||
| import org.apache.ibatis.annotations.*; | |||
| import org.dubhe.data.domain.bo.FileUploadBO; | |||
| import org.dubhe.data.domain.dto.DatasetVersionFileDTO; | |||
| import org.dubhe.data.domain.entity.DataFileAnnotation; | |||
| import org.dubhe.data.domain.entity.Dataset; | |||
| @@ -302,4 +303,11 @@ public interface DatasetVersionFileMapper extends BaseMapper<DatasetVersionFile> | |||
| * @return Long 版本文件id | |||
| */ | |||
| Long getVersionFileIdByFileName(@Param("datasetId")Long datasetId, @Param("fileName")String fileName, @Param("versionName")String versionName); | |||
| /** | |||
| * 获取导入文件所需信息 | |||
| * @param datasetId 数据集id | |||
| * @return List<FileUploadBO> | |||
| */ | |||
| List<FileUploadBO> getFileUploadContent(@Param("datasetId")Long datasetId, @Param("fileIds")List<Long> fileIds); | |||
| } | |||
| @@ -164,7 +164,8 @@ public interface FileMapper extends BaseMapper<File> { | |||
| * @param fileIdsNotToEs 需要同步的文件ID | |||
| * @return List<EsTransportDTO> ES数据同步DTO | |||
| */ | |||
| List<EsTransportDTO> selectTextDataNoTransport(@Param("datasetId") Long datasetId,@Param("fileIdsNotToEs")List<Long> fileIdsNotToEs); | |||
| List<EsTransportDTO> selectTextDataNoTransport(@Param("datasetId") Long datasetId,@Param("fileIdsNotToEs")List<Long> fileIdsNotToEs, | |||
| @Param("ifImport") Boolean ifImport); | |||
| /** | |||
| * 更新同步es标志 | |||
| @@ -0,0 +1,40 @@ | |||
| /** | |||
| * Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| * ============================================================= | |||
| */ | |||
| package org.dubhe.data.domain.bo; | |||
| import lombok.*; | |||
| import java.io.Serializable; | |||
| @Builder | |||
| @Data | |||
| @ToString | |||
| @AllArgsConstructor | |||
| @NoArgsConstructor | |||
| public class FileUploadBO implements Serializable { | |||
| String fileUrl; | |||
| String fileName; | |||
| Long fileId; | |||
| Long versionFileId; | |||
| String annPath; | |||
| } | |||
| @@ -43,4 +43,5 @@ public class BatchFileCreateDTO implements Serializable { | |||
| @NotNull(message = "文件不能为空") | |||
| private List<FileCreateDTO> files; | |||
| Boolean ifImport; | |||
| } | |||
| @@ -187,6 +187,13 @@ public class FileController { | |||
| return new DataResponseBody(minioUtil.getEncryptedPutUrl(bucketName, objectName, expiry)); | |||
| } | |||
| @ApiOperation("MinIO生成put请求的上传路径列表") | |||
| @PostMapping(value = "/minio/getUrls") | |||
| @PreAuthorize(Permissions.DATA) | |||
| public DataResponseBody getEncryptedPutUrls(@RequestBody String objectNames) { | |||
| return new DataResponseBody(minioUtil.getEncryptedPutUrls(bucketName, objectNames, expiry)); | |||
| } | |||
| @ApiOperation("获取MinIO相关信息") | |||
| @GetMapping(value = "/minio/info") | |||
| public DataResponseBody getMinIOInfo() { | |||
| @@ -53,8 +53,7 @@ public class LabelGroupController { | |||
| @PostMapping(value = "/labelGroup") | |||
| @PreAuthorize(Permissions.DATA) | |||
| public DataResponseBody create(@Validated @RequestBody LabelGroupCreateDTO labelGroupCreateDTO) { | |||
| labelGroupService.creatLabelGroup(labelGroupCreateDTO); | |||
| return new DataResponseBody(); | |||
| return new DataResponseBody(labelGroupService.creatLabelGroup(labelGroupCreateDTO)); | |||
| } | |||
| @ApiOperation(value = "标签组分页列表") | |||
| @@ -103,8 +102,7 @@ public class LabelGroupController { | |||
| public DataResponseBody importLabelGroup( | |||
| @RequestParam(value = "file", required = false) MultipartFile file, | |||
| LabelGroupImportDTO labelGroupImportDTO) { | |||
| labelGroupService.importLabelGroup(labelGroupImportDTO, file); | |||
| return new DataResponseBody(); | |||
| return new DataResponseBody(labelGroupService.importLabelGroup(labelGroupImportDTO, file)); | |||
| } | |||
| @@ -18,6 +18,7 @@ | |||
| package org.dubhe.data.service; | |||
| import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; | |||
| import org.dubhe.data.domain.bo.FileUploadBO; | |||
| import org.dubhe.data.domain.dto.DatasetVersionFileDTO; | |||
| import org.dubhe.data.domain.entity.Dataset; | |||
| import org.dubhe.data.domain.entity.DatasetVersion; | |||
| @@ -342,4 +343,11 @@ public interface DatasetVersionFileService { | |||
| * @param versionName 版本名称 | |||
| */ | |||
| Long getVersionFileIdByFileName(Long datasetId, String fileName, String versionName); | |||
| /** | |||
| * 获取导入文件所需信息 | |||
| * @param datasetId 数据集id | |||
| * @return List<FileUploadBO> | |||
| */ | |||
| List<FileUploadBO> getFileUploadContent(Long datasetId, List<Long> fileIds); | |||
| } | |||
| @@ -315,7 +315,7 @@ public interface FileService { | |||
| * @param dataset 数据集 | |||
| * @param fileIdsNotToEs 需要同步的文件ID | |||
| */ | |||
| void transportTextToEs(Dataset dataset,List<Long> fileIdsNotToEs); | |||
| void transportTextToEs(Dataset dataset,List<Long> fileIdsNotToEs,Boolean ifImport); | |||
| /** | |||
| * 还原es_transport状态 | |||
| @@ -38,7 +38,7 @@ public interface LabelGroupService { | |||
| * | |||
| * @param labelGroupCreateDTO 创建标签组DTO | |||
| */ | |||
| void creatLabelGroup(LabelGroupCreateDTO labelGroupCreateDTO); | |||
| Long creatLabelGroup(LabelGroupCreateDTO labelGroupCreateDTO); | |||
| /** | |||
| * 更新(编辑)标签组 | |||
| @@ -94,7 +94,7 @@ public interface LabelGroupService { | |||
| * @param labelGroupImportDTO 标签组导入DTO | |||
| * @param file 导入文件 | |||
| */ | |||
| void importLabelGroup(LabelGroupImportDTO labelGroupImportDTO, MultipartFile file); | |||
| Long importLabelGroup(LabelGroupImportDTO labelGroupImportDTO, MultipartFile file); | |||
| /** | |||
| * 标签组复制 | |||
| @@ -22,6 +22,7 @@ import cn.hutool.core.util.ObjectUtil; | |||
| import cn.hutool.core.util.StrUtil; | |||
| import com.alibaba.fastjson.JSON; | |||
| import com.alibaba.fastjson.JSONArray; | |||
| import com.alibaba.fastjson.JSONObject; | |||
| import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; | |||
| import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; | |||
| import com.baomidou.mybatisplus.core.metadata.IPage; | |||
| @@ -56,6 +57,7 @@ import org.dubhe.data.constant.*; | |||
| import org.dubhe.data.dao.DatasetMapper; | |||
| import org.dubhe.data.dao.TaskMapper; | |||
| import org.dubhe.biz.base.vo.ProgressVO; | |||
| import org.dubhe.data.domain.bo.FileUploadBO; | |||
| import org.dubhe.data.domain.dto.*; | |||
| import org.dubhe.data.domain.entity.*; | |||
| import org.dubhe.data.domain.vo.*; | |||
| @@ -67,6 +69,7 @@ import org.dubhe.data.machine.utils.StateMachineUtil; | |||
| import org.dubhe.data.pool.BasePool; | |||
| import org.dubhe.data.service.*; | |||
| import org.dubhe.data.service.task.DatasetRecycleFile; | |||
| import org.dubhe.data.util.GeneratorKeyUtil; | |||
| import org.dubhe.data.util.ZipUtil; | |||
| import org.dubhe.recycle.domain.dto.RecycleCreateDTO; | |||
| import org.dubhe.recycle.domain.dto.RecycleDetailCreateDTO; | |||
| @@ -260,6 +263,9 @@ public class DatasetServiceImpl extends ServiceImpl<DatasetMapper, Dataset> impl | |||
| @Resource | |||
| private MinioUtil minioUtil; | |||
| @Autowired | |||
| private GeneratorKeyUtil generatorKeyUtil; | |||
| /** | |||
| * 线程池 | |||
| */ | |||
| @@ -603,25 +609,23 @@ public class DatasetServiceImpl extends ServiceImpl<DatasetMapper, Dataset> impl | |||
| } catch (DuplicateKeyException e) { | |||
| throw new BusinessException(ErrorEnum.DATASET_NAME_DUPLICATED_ERROR); | |||
| } | |||
| if (!dataset.isImport()) { | |||
| //新增数据标签关系 | |||
| List<Label> labels = labelService.listByGroupId(datasetCreateDTO.getLabelGroupId()); | |||
| if (!CollectionUtils.isEmpty(labels)) { | |||
| List<DatasetLabel> datasetLabels = labels.stream().map(a -> { | |||
| DatasetLabel datasetLabel = new DatasetLabel(); | |||
| datasetLabel.setDatasetId(dataset.getId()); | |||
| datasetLabel.setLabelId(a.getId()); | |||
| return datasetLabel; | |||
| }).collect(Collectors.toList()); | |||
| datasetLabelService.saveList(datasetLabels); | |||
| } | |||
| //预置标签处理 | |||
| if (datasetCreateDTO.getPresetLabelType() != null) { | |||
| presetLabel(datasetCreateDTO.getPresetLabelType(), dataset.getId()); | |||
| } | |||
| if (DatatypeEnum.VIDEO.getValue().equals(datasetCreateDTO.getDataType())) { | |||
| dataset.setStatus(DataStateCodeConstant.NOT_SAMPLED_STATE); | |||
| } | |||
| //新增数据标签关系 | |||
| List<Label> labels = labelService.listByGroupId(datasetCreateDTO.getLabelGroupId()); | |||
| if (!CollectionUtils.isEmpty(labels)) { | |||
| List<DatasetLabel> datasetLabels = labels.stream().map(a -> { | |||
| DatasetLabel datasetLabel = new DatasetLabel(); | |||
| datasetLabel.setDatasetId(dataset.getId()); | |||
| datasetLabel.setLabelId(a.getId()); | |||
| return datasetLabel; | |||
| }).collect(Collectors.toList()); | |||
| datasetLabelService.saveList(datasetLabels); | |||
| } | |||
| //预置标签处理 | |||
| if (datasetCreateDTO.getPresetLabelType() != null) { | |||
| presetLabel(datasetCreateDTO.getPresetLabelType(), dataset.getId()); | |||
| } | |||
| if (DatatypeEnum.VIDEO.getValue().equals(datasetCreateDTO.getDataType())) { | |||
| dataset.setStatus(DataStateCodeConstant.NOT_SAMPLED_STATE); | |||
| } | |||
| dataset.setUri(fileUtil.getDatasetAbsPath(dataset.getId())); | |||
| if (datasetCreateDTO.getDataType().equals(DatatypeEnum.AUTO_IMPORT.getValue())) { | |||
| @@ -631,6 +635,9 @@ public class DatasetServiceImpl extends ServiceImpl<DatasetMapper, Dataset> impl | |||
| dataset.setStatus(DataStateCodeConstant.ANNOTATION_COMPLETE_STATE); | |||
| dataset.setCurrentVersionName(DEFAULT_VERSION); | |||
| } | |||
| if(datasetCreateDTO.isImport()){ | |||
| dataset.setStatus(DataStateEnum.ANNOTATION_COMPLETE_STATE.getCode()); | |||
| } | |||
| updateById(dataset); | |||
| return dataset.getId(); | |||
| } | |||
| @@ -725,9 +732,10 @@ public class DatasetServiceImpl extends ServiceImpl<DatasetMapper, Dataset> impl | |||
| setEventMethodName(DataStateMachineConstant.DATA_DELETE_FILES_EVENT); | |||
| setStateMachineType(DataStateMachineConstant.DATA_STATE_MACHINE); | |||
| }}); | |||
| if(dataset.getDataType().equals(MagicNumConstant.TWO) || dataset.getDataType().equals(MagicNumConstant.THREE)){ | |||
| fileService.deleteEsData(fileDeleteDTO.getFileIds()); | |||
| } | |||
| } | |||
| fileService.deleteEsData(fileDeleteDTO.getFileIds()); | |||
| } | |||
| @@ -986,8 +994,46 @@ public class DatasetServiceImpl extends ServiceImpl<DatasetMapper, Dataset> impl | |||
| */ | |||
| @Override | |||
| public void uploadFiles(Long datasetId, BatchFileCreateDTO batchFileCreateDTO) { | |||
| List<Long> fileIds = saveDbForUploadFiles(datasetId, batchFileCreateDTO); | |||
| transportTextToEsForUploadFiles(datasetId, fileIds); | |||
| List<Long> fileIds = saveDbForUploadFiles(datasetId, batchFileCreateDTO, batchFileCreateDTO.getIfImport()); | |||
| if(batchFileCreateDTO.getIfImport()!=null && batchFileCreateDTO.getIfImport()){ | |||
| importFileAnnotation(datasetId, fileIds); | |||
| } | |||
| transportTextToEsForUploadFiles(datasetId, fileIds,batchFileCreateDTO.getIfImport()); | |||
| } | |||
| /** | |||
|  * Import pre-existing annotations for freshly uploaded files: for each file, | |||
|  * derive its annotation JSON path in MinIO from the origin file URL, parse the | |||
|  * annotation array and persist one DataFileAnnotation row per entry. | |||
|  * A file whose annotation object is missing or malformed is logged and skipped | |||
|  * (best effort), so one bad file does not abort the whole import. | |||
|  * | |||
|  * @param datasetId id of the dataset the files belong to | |||
|  * @param fileIds   ids of the files whose annotations should be imported | |||
|  */ | |||
| void importFileAnnotation(Long datasetId, List<Long> fileIds) { | |||
| List<FileUploadBO> fileUploadContent = datasetVersionFileService.getFileUploadContent(datasetId,fileIds); | |||
| List<DataFileAnnotation> dataFileAnnotations = new ArrayList<>(); | |||
| fileUploadContent.forEach(fileUploadBO -> { | |||
| // Map the origin object path to the annotation object path: drop the file | |||
| // extension, swap /origin/ for /annotation/ and strip the bucket prefix. | |||
| String annPath = StringUtils.substringBeforeLast(fileUploadBO.getFileUrl(), "."); | |||
| annPath = annPath.replace("/origin/","/annotation/").replace(bucket+"/",""); | |||
| try { | |||
| JSONArray annJsonArray = JSONObject.parseArray((minioUtil.readString(bucket, annPath))); | |||
| for (Object object : annJsonArray) { | |||
| JSONObject jsonObject = (JSONObject) object; | |||
| Long categoryId = Long.parseLong(jsonObject.getString("category_id")); | |||
| // "score" is optional in the annotation JSON; keep prediction null when absent. | |||
| Double score = jsonObject.getString("score")==null ? null : Double.parseDouble(jsonObject.getString("score")); | |||
| DataFileAnnotation dataFileAnnotation = DataFileAnnotation.builder().fileName(fileUploadBO.getFileName()) | |||
| .versionFileId(fileUploadBO.getVersionFileId()) | |||
| .datasetId(datasetId) | |||
| .labelId(categoryId) | |||
| .prediction(score).build(); | |||
| dataFileAnnotations.add(dataFileAnnotation); | |||
| } | |||
| } catch (Exception e) { | |||
| LogUtil.error(LogEnum.BIZ_DATASET, "导入数据集读取标注出错:{}",e); | |||
| } | |||
| }); | |||
| if(!CollectionUtils.isEmpty(dataFileAnnotations)){ | |||
| // Pre-allocate primary keys for the whole batch, then insert in one call. | |||
| Queue<Long> dataFileAnnotionIds = generatorKeyUtil.getSequenceByBusinessCode(Constant.DATA_FILE_ANNOTATION, dataFileAnnotations.size()); | |||
| for (DataFileAnnotation dataFileAnnotation : dataFileAnnotations) { | |||
| dataFileAnnotation.setId(dataFileAnnotionIds.poll()); | |||
| dataFileAnnotation.setStatus(MagicNumConstant.ZERO); | |||
| dataFileAnnotation.setInvariable(MagicNumConstant.ZERO); | |||
| } | |||
| dataFileAnnotationService.insertDataFileBatch(dataFileAnnotations); | |||
| } | |||
| } | |||
| /** | |||
| @@ -997,7 +1043,7 @@ public class DatasetServiceImpl extends ServiceImpl<DatasetMapper, Dataset> impl | |||
| * @param batchFileCreateDTO | |||
| */ | |||
| @Transactional(rollbackFor = Exception.class) | |||
| public List<Long> saveDbForUploadFiles(Long datasetId, BatchFileCreateDTO batchFileCreateDTO) { | |||
| public List<Long> saveDbForUploadFiles(Long datasetId, BatchFileCreateDTO batchFileCreateDTO,Boolean ifImport) { | |||
| Dataset dataset = getBaseMapper().selectById(datasetId); | |||
| if (null == dataset) { | |||
| throw new BusinessException(ErrorEnum.DATA_ABSENT_OR_NO_AUTH, "id:" + datasetId, null); | |||
| @@ -1010,9 +1056,11 @@ public class DatasetServiceImpl extends ServiceImpl<DatasetMapper, Dataset> impl | |||
| if (!CollectionUtils.isEmpty(list)) { | |||
| List<DatasetVersionFile> datasetVersionFiles = new ArrayList<>(); | |||
| for (File file : list) { | |||
| datasetVersionFiles.add( | |||
| new DatasetVersionFile(datasetId, dataset.getCurrentVersionName(), file.getId(), file.getName()) | |||
| ); | |||
| DatasetVersionFile datasetVersionFile = new DatasetVersionFile(datasetId, dataset.getCurrentVersionName(), file.getId(), file.getName()); | |||
| if(ifImport != null && ifImport){ | |||
| datasetVersionFile.setAnnotationStatus(FileTypeEnum.FINISHED.getValue()); | |||
| } | |||
| datasetVersionFiles.add(datasetVersionFile); | |||
| } | |||
| datasetVersionFileService.insertList(datasetVersionFiles); | |||
| } | |||
| @@ -1021,11 +1069,13 @@ public class DatasetServiceImpl extends ServiceImpl<DatasetMapper, Dataset> impl | |||
| return fileIds; | |||
| } | |||
| //改变数据集的状态 | |||
| StateMachineUtil.stateChange(new StateChangeDTO() {{ | |||
| setObjectParam(new Object[]{dataset}); | |||
| setEventMethodName(DataStateMachineConstant.DATA_UPLOAD_FILES_EVENT); | |||
| setStateMachineType(DataStateMachineConstant.DATA_STATE_MACHINE); | |||
| }}); | |||
| if(!dataset.isImport()){ | |||
| StateMachineUtil.stateChange(new StateChangeDTO() {{ | |||
| setObjectParam(new Object[]{dataset}); | |||
| setEventMethodName(DataStateMachineConstant.DATA_UPLOAD_FILES_EVENT); | |||
| setStateMachineType(DataStateMachineConstant.DATA_STATE_MACHINE); | |||
| }}); | |||
| } | |||
| return fileIds; | |||
| } | |||
| @@ -1034,10 +1084,10 @@ public class DatasetServiceImpl extends ServiceImpl<DatasetMapper, Dataset> impl | |||
| * | |||
| * @param datasetId 数据集ID | |||
| */ | |||
| public void transportTextToEsForUploadFiles(Long datasetId, List<Long> fileIds) { | |||
| public void transportTextToEsForUploadFiles(Long datasetId, List<Long> fileIds,Boolean ifImport) { | |||
| Dataset dataset = getBaseMapper().selectById(datasetId); | |||
| if (dataset.getDataType().equals(MagicNumConstant.TWO) || dataset.getDataType().equals(MagicNumConstant.THREE)) { | |||
| fileService.transportTextToEs(dataset, fileIds); | |||
| fileService.transportTextToEs(dataset, fileIds,ifImport); | |||
| } | |||
| } | |||
| @@ -1663,5 +1713,4 @@ public class DatasetServiceImpl extends ServiceImpl<DatasetMapper, Dataset> impl | |||
| .ne("deleted", MagicNumConstant.ONE); | |||
| return baseMapper.selectList(queryWrapper); | |||
| } | |||
| } | |||
| @@ -36,6 +36,7 @@ import org.dubhe.biz.log.utils.LogUtil; | |||
| import org.dubhe.biz.statemachine.dto.StateChangeDTO; | |||
| import org.dubhe.data.constant.*; | |||
| import org.dubhe.data.dao.DatasetVersionFileMapper; | |||
| import org.dubhe.data.domain.bo.FileUploadBO; | |||
| import org.dubhe.data.domain.dto.DatasetVersionFileDTO; | |||
| import org.dubhe.data.domain.entity.*; | |||
| import org.dubhe.data.machine.constant.FileStateCodeConstant; | |||
| @@ -721,4 +722,14 @@ public class DatasetVersionFileServiceImpl extends ServiceImpl<DatasetVersionFil | |||
| public Long getVersionFileIdByFileName(Long datasetId, String fileName, String versionName) { | |||
| return baseMapper.getVersionFileIdByFileName(datasetId, fileName, versionName); | |||
| } | |||
| /** | |||
|  * Fetch the upload metadata (version-file id, file id, url, name) needed to | |||
|  * import annotations for the given files of a dataset. | |||
|  * | |||
|  * @param datasetId dataset id | |||
|  * @param fileIds   ids of the files to look up | |||
|  * @return List<FileUploadBO> one entry per matching version file | |||
|  */ | |||
| @Override | |||
| public List<FileUploadBO> getFileUploadContent(Long datasetId, List<Long> fileIds) { | |||
| return baseMapper.getFileUploadContent(datasetId,fileIds); | |||
| } | |||
| } | |||
| @@ -87,6 +87,7 @@ import org.springframework.transaction.annotation.Transactional; | |||
| import org.springframework.transaction.support.TransactionCallbackWithoutResult; | |||
| import org.springframework.transaction.support.TransactionTemplate; | |||
| import org.springframework.util.CollectionUtils; | |||
| import springfox.documentation.spring.web.json.Json; | |||
| import javax.annotation.Resource; | |||
| import java.io.File; | |||
| @@ -609,6 +610,13 @@ public class DatasetVersionServiceImpl extends ServiceImpl<DatasetVersionMapper, | |||
| if (null == dataset) { | |||
| throw new BusinessException(ErrorEnum.DATASET_ABSENT, "id:" + datasetId, null); | |||
| } | |||
| //判断目标版本是否存在 | |||
| QueryWrapper<DatasetVersion> queryWrapper = new QueryWrapper<>(); | |||
| queryWrapper.lambda().eq(DatasetVersion::getDatasetId, datasetId).eq(DatasetVersion::getVersionName, versionName); | |||
| DatasetVersion datasetVersion = baseMapper.selectOne(queryWrapper); | |||
| if(datasetVersion == null) { | |||
| throw new BusinessException(ErrorEnum.DATASET_CHECK_VERSION_ERROR); | |||
| } | |||
| //判断数据集是否在发布中 | |||
| if (!StringUtils.isBlank(dataset.getCurrentVersionName())) { | |||
| if (getDatasetVersionSourceVersion(dataset).getDataConversion().equals(NumberConstant.NUMBER_4)) { | |||
| @@ -875,6 +883,7 @@ public class DatasetVersionServiceImpl extends ServiceImpl<DatasetVersionMapper, | |||
| } | |||
| try { | |||
| minioUtil.writeString(bucketName, targetDir + "/labels.text", Strings.join(labelStr, ',')); | |||
| minioUtil.writeString(bucketName, targetDir + "/labelsIds.text", JSONObject.toJSONString(labelMaps)); | |||
| } catch (Exception e) { | |||
| LogUtil.error(LogEnum.BIZ_DATASET, "MinIO file write exception, {}", e); | |||
| } | |||
| @@ -1528,8 +1528,8 @@ public class FileServiceImpl extends ServiceImpl<FileMapper, File> implements Fi | |||
| * @param dataset 数据集 | |||
| */ | |||
| @Override | |||
| public void transportTextToEs(Dataset dataset,List<Long> fileIdsNotToEs) { | |||
| List<EsTransportDTO> esTransportDTOList = fileMapper.selectTextDataNoTransport(dataset.getId(), fileIdsNotToEs); | |||
| public void transportTextToEs(Dataset dataset,List<Long> fileIdsNotToEs,Boolean ifImport) { | |||
| List<EsTransportDTO> esTransportDTOList = fileMapper.selectTextDataNoTransport(dataset.getId(), fileIdsNotToEs, ifImport); | |||
| esTransportDTOList.forEach(esTransportDTO -> { | |||
| FileInputStream fileInputStream = null; | |||
| InputStreamReader reader = null; | |||
| @@ -108,7 +108,7 @@ public class LabelGroupServiceImpl extends ServiceImpl<LabelGroupMapper, LabelGr | |||
| */ | |||
| @Transactional(rollbackFor = Exception.class) | |||
| @Override | |||
| public void creatLabelGroup(LabelGroupCreateDTO labelGroupCreateDTO) { | |||
| public Long creatLabelGroup(LabelGroupCreateDTO labelGroupCreateDTO) { | |||
| //1 标签组名称唯一校验 | |||
| labelGroupCreateDTO.setOriginUserId(JwtUtils.getCurUserId()); | |||
| @@ -132,7 +132,7 @@ public class LabelGroupServiceImpl extends ServiceImpl<LabelGroupMapper, LabelGr | |||
| if (!CollectionUtils.isEmpty(labelList)) { | |||
| buildLabelDataByCreate(labelGroup, labelList); | |||
| } | |||
| return labelGroup.getId(); | |||
| } | |||
| /** | |||
| @@ -488,7 +488,7 @@ public class LabelGroupServiceImpl extends ServiceImpl<LabelGroupMapper, LabelGr | |||
| */ | |||
| @Override | |||
| @Transactional(rollbackFor = Exception.class) | |||
| public void importLabelGroup(LabelGroupImportDTO labelGroupImportDTO, MultipartFile file) { | |||
| public Long importLabelGroup(LabelGroupImportDTO labelGroupImportDTO, MultipartFile file) { | |||
| //文件格式/大小/属性校验 | |||
| FileUtil.checkoutFile(file); | |||
| @@ -506,7 +506,7 @@ public class LabelGroupServiceImpl extends ServiceImpl<LabelGroupMapper, LabelGr | |||
| .remark(labelGroupImportDTO.getRemark()).build(); | |||
| //调用新增标签方法 | |||
| this.creatLabelGroup(createDTO); | |||
| return this.creatLabelGroup(createDTO); | |||
| } | |||
| /** | |||
| @@ -33,7 +33,10 @@ import org.springframework.stereotype.Component; | |||
| import java.io.File; | |||
| import java.math.BigDecimal; | |||
| import java.util.ArrayList; | |||
| import java.util.HashMap; | |||
| import java.util.List; | |||
| import java.util.Map; | |||
| import java.util.stream.Collectors; | |||
| /** | |||
| * @description oneflow文本格式转换 | |||
| @@ -71,6 +74,15 @@ public class ConversionUtil { | |||
| LogUtil.error(LogEnum.BIZ_DATASET, "getObjects is failed:{}", e); | |||
| return; | |||
| } | |||
| Map<String,Integer> labelMap = new HashMap<>(); | |||
| try { | |||
| String labelIdsPath = path.replace("/origin","/annotation/"); | |||
| String labelIdsString = minioUtil.readString(bucket, labelIdsPath + "labelsIds.text"); | |||
| Map<Integer,String> idLabelMap = JSONObject.parseObject(labelIdsString, Map.class); | |||
| labelMap = idLabelMap.entrySet().stream().collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey)); | |||
| } catch (Exception e) { | |||
| LogUtil.error(LogEnum.BIZ_DATASET, "ReadJson is failed:{}", e); | |||
| } | |||
| for (int n = 0; n < imagePaths.size(); n++) { | |||
| String imagePath = imagePaths.get(n); | |||
| if (imagePath.endsWith(TXT_FILE_FORMATS)) { | |||
| @@ -92,7 +104,8 @@ public class ConversionUtil { | |||
| StringBuffer content = new StringBuffer(); | |||
| for (Object object : objects) { | |||
| JSONObject jsonObject = (JSONObject) object; | |||
| Long categoryId = Long.valueOf(jsonObject.getString("category_id")); | |||
| String categoryName = jsonObject.getString("category_id"); | |||
| Integer categoryId = labelMap.get(categoryName); | |||
| JSONArray jsonArray = (JSONArray) jsonObject.get("bbox"); | |||
| BigDecimal[] bbox = new BigDecimal[ARRAY_LENGTH]; | |||
| for (int j = 0; j < ARRAY_LENGTH; j++) { | |||
| @@ -294,4 +294,29 @@ | |||
| select id from data_dataset_version_file where file_name = #{fileName} and version_name = #{versionName} | |||
| and dataset_id = #{datasetId} | |||
| </select> | |||
| <!-- Load upload info (version-file id, file id, url, name) for the given files; | |||
|      joins version-file rows to their backing data_file rows. Both sides are | |||
|      filtered by dataset_id so the join cannot leak rows from other datasets. --> | |||
| <select id="getFileUploadContent" resultMap="fileUploadContent"> | |||
| SELECT | |||
| ddvf.id version_file_id, | |||
| df.id file_id, | |||
| df.url, | |||
| df.`name` | |||
| FROM | |||
| data_dataset_version_file ddvf | |||
| LEFT JOIN data_file df ON ddvf.file_id = df.id | |||
| WHERE | |||
| ddvf.dataset_id = #{datasetId} | |||
| AND df.dataset_id = #{datasetId} | |||
| AND df.id in | |||
| <foreach item="item" collection="fileIds" separator="," open="(" close=")"> | |||
| #{item} | |||
| </foreach> | |||
| </select> | |||
| <!-- Maps the select above onto FileUploadBO properties. --> | |||
| <resultMap id="fileUploadContent" type="org.dubhe.data.domain.bo.FileUploadBO"> | |||
| <result column="version_file_id" property="versionFileId"/> | |||
| <result column="file_id" property="fileId"/> | |||
| <result column="url" property="fileUrl"/> | |||
| <result column="name" property="fileName"/> | |||
| </resultMap> | |||
| </mapper> | |||
| @@ -79,13 +79,24 @@ | |||
| df.enhance_type, | |||
| df.origin_user_id, | |||
| df.id | |||
| <if test="ifImport"> | |||
| , | |||
| dfa.label_id, | |||
| dfa.prediction | |||
| </if> | |||
| FROM | |||
| data_dataset_version_file ddvf | |||
| LEFT JOIN data_file df ON ddvf.file_id = df.id | |||
| <if test="ifImport"> | |||
| LEFT JOIN data_file_annotation dfa on ddvf.id = dfa.version_file_id | |||
| </if> | |||
| WHERE | |||
| df.es_transport = 0 | |||
| AND ddvf.dataset_id = #{datasetId} | |||
| AND df.dataset_id = #{datasetId} | |||
| <if test="ifImport"> | |||
| AND dfa.dataset_id = #{datasetId} | |||
| </if> | |||
| AND df.id in | |||
| <foreach collection="fileIdsNotToEs" item="fileId" open="(" close=")" separator=","> | |||
| #{fileId} | |||
| @@ -209,12 +209,15 @@ public class PtImageServiceImpl implements PtImageService { | |||
| .setRemark(ptImageUploadDTO.getRemark()) | |||
| .setImageTag(ptImageUploadDTO.getImageTag()) | |||
| .setCreateUserId(user.getId()); | |||
| if (ImageSourceEnum.PRE.getCode().equals(ptImageUploadDTO.getImageResource())) { | |||
| ptImage.setOriginUserId(MagicNumConstant.ZERO_LONG); | |||
| } else { | |||
| ptImage.setOriginUserId(user.getId()); | |||
| } | |||
| //设置notebook镜像为预置镜像 | |||
| if (ptImageUploadDTO.getProjectType().equals(ImageTypeEnum.NOTEBOOK.getCode()) && BaseService.isAdmin(user)) { | |||
| ptImage.setOriginUserId(0L); | |||
| } | |||
| int count = ptImageMapper.insert(ptImage); | |||
| if (count < 1) { | |||
| imagePushAsync.updateImageStatus(ptImage, ImageStateEnum.FAIL.getCode()); | |||
| @@ -96,6 +96,7 @@ public class PodCallback extends Observable { | |||
| if (StringUtils.isNotEmpty(businessLabel) && needCallback(watcherActionEnum,pod)){ | |||
| dealWithDeleted(watcherActionEnum,pod); | |||
| BaseK8sPodCallbackCreateDTO baseK8sPodCallbackCreateDTO = new BaseK8sPodCallbackCreateDTO(pod.getNamespace(), pod.getLabel(K8sLabelConstants.BASE_TAG_SOURCE),pod.getName(), pod.getLabel(K8sLabelConstants.BASE_TAG_P_KIND), pod.getLabel(K8sLabelConstants.BASE_TAG_P_NAME), pod.getPhase(), waitingReason); | |||
| baseK8sPodCallbackCreateDTO.setLables(pod.getLabels()); | |||
| String url = k8sCallBackTool.getPodCallbackUrl(businessLabel); | |||
| String token = k8sCallBackTool.generateToken(); | |||
| @@ -0,0 +1,70 @@ | |||
| /** | |||
| * Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| * ============================================================= | |||
| */ | |||
| package org.dubhe.dubhek8s.observer; | |||
| import org.dubhe.biz.base.constant.MagicNumConstant; | |||
| import org.dubhe.biz.base.enums.BizEnum; | |||
| import org.dubhe.biz.base.utils.SpringContextHolder; | |||
| import org.dubhe.biz.log.enums.LogEnum; | |||
| import org.dubhe.biz.log.utils.LogUtil; | |||
| import org.dubhe.dubhek8s.event.callback.PodCallback; | |||
| import org.dubhe.k8s.api.TrainJobApi; | |||
| import org.dubhe.k8s.constant.K8sLabelConstants; | |||
| import org.dubhe.k8s.domain.resource.BizPod; | |||
| import org.dubhe.k8s.enums.PodPhaseEnum; | |||
| import org.springframework.beans.factory.annotation.Autowired; | |||
| import org.springframework.stereotype.Component; | |||
| import java.util.Observable; | |||
| import java.util.Observer; | |||
| /** | |||
| * @description 观察者,处理Tadl模块trial pod变化 | |||
| * @date 2021-08-12 | |||
| */ | |||
| @Component | |||
| public class TadlTrialObserver implements Observer { | |||
| @Autowired | |||
| private TrainJobApi trainJobApi; | |||
| public TadlTrialObserver(PodCallback podCallback) { | |||
| podCallback.addObserver(this); | |||
| } | |||
| @Override | |||
| public void update(Observable o, Object arg) { | |||
| if (arg instanceof BizPod){ | |||
| BizPod pod = (BizPod)arg; | |||
| boolean trialSucceedOrFailed = (PodPhaseEnum.FAILED.getPhase().equals(pod.getPhase()) || PodPhaseEnum.SUCCEEDED.getPhase().equals(pod.getPhase())) && BizEnum.TADL.getBizCode().equals(pod.getBusinessLabel()) && SpringContextHolder.getActiveProfile().equals(pod.getLabel(K8sLabelConstants.PLATFORM_RUNTIME_ENV)); | |||
| if (trialSucceedOrFailed){ | |||
| new Thread(new Runnable(){ | |||
| @Override | |||
| public void run(){ | |||
| try { | |||
| Thread.sleep(MagicNumConstant.ONE_MINUTE); | |||
| } catch (InterruptedException e) { | |||
| LogUtil.error(LogEnum.BIZ_K8S,"TadlTrialObserver update error {}",e.getMessage(),e); | |||
| } | |||
| LogUtil.warn(LogEnum.BIZ_K8S,"delete succeed or failed trial resourceName {};phase {};podName {}",pod.getLabel(K8sLabelConstants.BASE_TAG_SOURCE),pod.getPhase(),pod.getName()); | |||
| trainJobApi.delete(pod.getNamespace(),pod.getLabel(K8sLabelConstants.BASE_TAG_SOURCE)); | |||
| } | |||
| }).start(); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -18,7 +18,6 @@ | |||
| package org.dubhe.dubhek8s.service; | |||
| import com.sun.org.apache.xpath.internal.operations.Bool; | |||
| import org.dubhe.dubhek8s.domain.dto.NodeDTO; | |||
| import org.dubhe.k8s.domain.dto.NodeIsolationDTO; | |||
| import org.dubhe.k8s.domain.resource.BizNode; | |||
| @@ -94,4 +94,14 @@ public class LogMonitoringApiTest { | |||
| System.out.println(count); | |||
| } | |||
| // Smoke test: verifies addTadlLogsToEs runs without throwing. | |||
| // NOTE(review): no assertions - consider verifying the log was indexed. | |||
| @Test | |||
| public void addTadlLog(){ | |||
| logMonitoringApi.addTadlLogsToEs(1L,"Once I was a wooden boy"); | |||
| } | |||
| // Smoke test: verifies searchTadlLogById runs without throwing. | |||
| // NOTE(review): result is ignored - consider asserting on logMonitoringVO. | |||
| @Test | |||
| public void searchTadlLogById(){ | |||
| LogMonitoringVO logMonitoringVO = logMonitoringApi.searchTadlLogById(1, 10, 1L); | |||
| } | |||
| } | |||
| @@ -45,9 +45,9 @@ public class PtModelBranchCreateDTO implements Serializable { | |||
| @Length(max = PtModelUtil.NUMBER_ONE_HUNDRED_TWENTY_EIGHT, message = "模型地址-输入长度不能超过128个字符") | |||
| private String modelAddress; | |||
| @ApiModelProperty("模型来源(0用户上传,1训练输出,2模型优化)") | |||
| @ApiModelProperty("模型来源(0用户上传,1训练输出,2模型优化,3模型转换,4自动机器学习)") | |||
| @Min(value = PtModelUtil.NUMBER_ZERO, message = "模型来源错误") | |||
| @Max(value = PtModelUtil.NUMBER_TWO, message = "模型来源错误") | |||
| @Max(value = PtModelUtil.NUMBER_FOUR, message = "模型来源错误") | |||
| @NotNull(message = "模型来源不能为空") | |||
| private Integer modelSource; | |||
| @@ -66,9 +66,9 @@ public class PtModelInfoCreateDTO implements Serializable { | |||
| @Length(max = PtModelUtil.NUMBER_ONE_HUNDRED_TWENTY_EIGHT, message = "模型地址-输入长度不能超过128个字符") | |||
| private String modelAddress; | |||
| @ApiModelProperty("模型来源(0用户上传,1训练输出,2模型优化)") | |||
| @ApiModelProperty("模型来源(0用户上传,1训练输出,2模型优化,3模型转换,4自动机器学习)") | |||
| @Min(value = PtModelUtil.NUMBER_ZERO, message = "模型来源错误") | |||
| @Max(value = PtModelUtil.NUMBER_TWO, message = "模型来源错误") | |||
| @Max(value = PtModelUtil.NUMBER_FOUR, message = "模型来源错误") | |||
| private Integer modelSource; | |||
| @ApiModelProperty("算法ID") | |||
| @@ -195,7 +195,8 @@ public class PtModelBranchServiceImpl implements PtModelBranchService { | |||
| LogUtil.error(LogEnum.BIZ_MODEL, "User {} failed to create new version", user.getUsername()); | |||
| throw new BusinessException("模型版本创建失败"); | |||
| } | |||
| } else if (ptModelBranchCreateDTO.getModelSource() == PtModelUtil.TRAINING_IMPORT || ptModelBranchCreateDTO.getModelSource() == PtModelUtil.MODEL_OPTIMIZATION) { | |||
| } else if (ptModelBranchCreateDTO.getModelSource() == PtModelUtil.TRAINING_IMPORT || ptModelBranchCreateDTO.getModelSource() == PtModelUtil.MODEL_OPTIMIZATION | |||
| ||ptModelBranchCreateDTO.getModelSource() == PtModelUtil.AUTOMATIC_MACHINE_LEARNING) { | |||
| //文件拷贝中 | |||
| ptModelBranch.setStatus(ModelCopyStatusEnum.COPING.getCode()); | |||
| //判断模型版本是否已存在 | |||
| @@ -778,6 +778,7 @@ public class NoteBookServiceImpl implements NoteBookService { | |||
| noteBook.setUrl(SymbolConstant.BLANK); | |||
| noteBook.setStatus(NoteBookStatusEnum.STOP.getCode()); | |||
| noteBook.setStatusDetail(SymbolConstant.BLANK); | |||
| jupyterResourceApi.delete(noteBook.getK8sNamespace(),noteBook.getK8sResourceName()); | |||
| processNotebookCommand.stop(noteBook); | |||
| updateById(noteBook); | |||
| return true; | |||
| @@ -28,7 +28,7 @@ import org.dubhe.serving.domain.entity.ServingInfo; | |||
| * @description 服务信息管理 | |||
| * @date 2020-08-25 | |||
| */ | |||
| @DataPermission(ignoresMethod = {"insert", "updateStatusById"}) | |||
| @DataPermission(ignoresMethod = {"insert", "rollbackById", "updateStatusDetail"}) | |||
| public interface ServingInfoMapper extends BaseMapper<ServingInfo> { | |||
| /** | |||
| @@ -39,5 +39,14 @@ public interface ServingInfoMapper extends BaseMapper<ServingInfo> { | |||
| * @return int 数量 | |||
| */ | |||
| @Update("update serving_info set deleted = #{deleteFlag} where id = #{id}") | |||
| int updateStatusById(@Param("id") Long id, @Param("deleteFlag") boolean deleteFlag); | |||
| int rollbackById(@Param("id") Long id, @Param("deleteFlag") boolean deleteFlag); | |||
| /** | |||
| * 修改状态详情 | |||
| * @param id serving id | |||
| * @param statusDetail 状态详情 | |||
| * @return int 数量 | |||
| */ | |||
| @Update("update serving_info set status_detail = #{statusDetail} where id = #{id}") | |||
| int updateStatusDetail(@Param("id") Long id, @Param("statusDetail") String statusDetail); | |||
| } | |||
| @@ -193,7 +193,7 @@ public class ServingServiceImpl implements ServingService { | |||
| private RedisUtils redisUtils; | |||
| @Autowired | |||
| private ResourceCache resourceCache; | |||
| @Value("Task:Serving:"+"${spring.profiles.active}_serving_id_") | |||
| @Value("Task:Serving:" + "${spring.profiles.active}_serving_id_") | |||
| private String servingIdPrefix; | |||
| /** | |||
| @@ -553,7 +553,7 @@ public class ServingServiceImpl implements ServingService { | |||
| throw new BusinessException(ServingErrorEnum.INTERNAL_SERVER_ERROR); | |||
| } | |||
| servingInfo.setStatusDetail(SymbolConstant.BRACKETS); | |||
| deployServingAsyncTask.deleteServing(user, servingInfo, oldModelConfigList); | |||
| deployServingAsyncTask.deleteServing(servingInfo, oldModelConfigList); | |||
| // 删除拷贝的文件 | |||
| for (ServingModelConfig oldModelConfig : oldModelConfigList) { | |||
| String recyclePath = k8sNameTool.getAbsolutePath(onlineRootPath + servingInfo.getCreateUserId() + File.separator + servingInfo.getId() + File.separator + oldModelConfig.getId()); | |||
| @@ -567,7 +567,7 @@ public class ServingServiceImpl implements ServingService { | |||
| } | |||
| } | |||
| List<ServingModelConfig> modelConfigList = updateServing(servingInfoUpdateDTO, user, servingInfo); | |||
| String taskIdentify = resourceCache.getTaskIdentify(servingInfo.getId(), servingInfo.getName(),servingIdPrefix); | |||
| String taskIdentify = resourceCache.getTaskIdentify(servingInfo.getId(), servingInfo.getName(), servingIdPrefix); | |||
| // 异步部署容器 | |||
| deployServingAsyncTask.deployServing(user, servingInfo, modelConfigList, taskIdentify); | |||
| return new ServingInfoUpdateVO(servingInfo.getId(), servingInfo.getStatus()); | |||
| @@ -641,10 +641,10 @@ public class ServingServiceImpl implements ServingService { | |||
| } | |||
| ServingInfo servingInfo = checkServingInfoExist(servingInfoDeleteDTO.getId(), user.getId()); | |||
| List<ServingModelConfig> modelConfigList = getModelConfigByServingId(servingInfo.getId()); | |||
| deployServingAsyncTask.deleteServing(user, servingInfo, modelConfigList); | |||
| deployServingAsyncTask.deleteServing(servingInfo, modelConfigList); | |||
| deleteServing(servingInfoDeleteDTO, user, servingInfo); | |||
| String taskIdentify = (String) redisUtils.get(servingIdPrefix + String.valueOf(servingInfo.getId())); | |||
| if (StringUtils.isNotEmpty(taskIdentify)){ | |||
| if (StringUtils.isNotEmpty(taskIdentify)) { | |||
| redisUtils.del(taskIdentify, servingIdPrefix + String.valueOf(servingInfo.getId())); | |||
| } | |||
| Map<String, Object> map = new HashMap<>(NumberConstant.NUMBER_2); | |||
| @@ -787,8 +787,8 @@ public class ServingServiceImpl implements ServingService { | |||
| throw new BusinessException(ServingErrorEnum.INTERNAL_SERVER_ERROR); | |||
| } | |||
| for (ServingModelConfig servingModelConfig : modelConfigList) { | |||
| servingModelConfig.setResourceInfo(null); | |||
| servingModelConfig.setUrl(null); | |||
| servingModelConfig.setResourceInfo(SymbolConstant.BLANK); | |||
| servingModelConfig.setUrl(SymbolConstant.BLANK); | |||
| if (servingModelConfigMapper.updateById(servingModelConfig) < NumberConstant.NUMBER_1) { | |||
| throw new BusinessException(ServingErrorEnum.DATABASE_ERROR); | |||
| } | |||
| @@ -871,7 +871,7 @@ public class ServingServiceImpl implements ServingService { | |||
| servingInfo.setStatus(ServingStatusEnum.STOP.getStatus()); | |||
| servingInfo.setRunningNode(NumberConstant.NUMBER_0); | |||
| List<ServingModelConfig> modelConfigList = getModelConfigByServingId(servingInfo.getId()); | |||
| deployServingAsyncTask.deleteServing(user, servingInfo, modelConfigList); | |||
| deployServingAsyncTask.deleteServing(servingInfo, modelConfigList); | |||
| updateServingStop(user, servingInfo, modelConfigList); | |||
| servingInfo.setStatusDetail(SymbolConstant.BRACKETS); | |||
| // 删除路由信息 | |||
| @@ -903,7 +903,7 @@ public class ServingServiceImpl implements ServingService { | |||
| servingModelConfigMapper.updateById(servingModelConfig); | |||
| }); | |||
| if (ServingTypeEnum.GRPC.getType().equals(servingInfo.getType())) { | |||
| GrpcClient.shutdownChannel(servingInfo.getId(), user); | |||
| GrpcClient.shutdownChannel(servingInfo.getId()); | |||
| } | |||
| } | |||
| @@ -1154,7 +1154,7 @@ public class ServingServiceImpl implements ServingService { | |||
| servingInfo.putStatusDetail(statusDetailKey, req.getMessages()); | |||
| } | |||
| LogUtil.info(LogEnum.SERVING, "The callback serving message:{} ,req message:{}", servingInfo, req); | |||
| return servingInfoMapper.updateById(servingInfo) < NumberConstant.NUMBER_1; | |||
| return servingInfoMapper.updateStatusDetail(servingInfo.getId(), servingInfo.getStatusDetail()) < NumberConstant.NUMBER_1; | |||
| } | |||
| @Transactional(rollbackFor = Exception.class) | |||
| @@ -1186,6 +1186,9 @@ public class ServingServiceImpl implements ServingService { | |||
| servingInfo.getRunningNode(), req); | |||
| servingInfo.putStatusDetail(statusDetailKey, "运行中的节点数为0"); | |||
| servingInfo.setStatus(ServingStatusEnum.EXCEPTION.getStatus()); | |||
| // 删除已创建的pod | |||
| List<ServingModelConfig> deleteList = getModelConfigByServingId(servingInfo.getId()); | |||
| deployServingAsyncTask.deleteServing(servingInfo, deleteList); | |||
| } | |||
| if (servingInfo.getRunningNode() > NumberConstant.NUMBER_0) { | |||
| LogUtil.info(LogEnum.SERVING, | |||
| @@ -1289,7 +1292,7 @@ public class ServingServiceImpl implements ServingService { | |||
| } | |||
| } | |||
| if (servingInfoId != null) { | |||
| servingInfoMapper.updateStatusById(servingInfoId, false); | |||
| servingInfoMapper.rollbackById(servingInfoId, false); | |||
| } | |||
| } | |||
| } | |||
| @@ -183,7 +183,7 @@ public class DeployServingAsyncTask { | |||
| // 删除已创建的pod | |||
| List<ServingModelConfig> deleteList = new ArrayList<>(); | |||
| deleteList.add(servingModelConfig); | |||
| deleteServing(user, servingInfo, deleteList); | |||
| deleteServing(servingInfo, deleteList); | |||
| } | |||
| } | |||
| } else { | |||
| @@ -309,52 +309,51 @@ public class DeployServingAsyncTask { | |||
| /** | |||
| * 删除pod | |||
| * | |||
| * @param user 用户信息 | |||
| * @param servingInfo 在线服务信息 | |||
| * @param modelConfigList 在线服务模型部署信息集合 | |||
| */ | |||
| @Async("servingExecutor") | |||
| @Transactional(rollbackFor = Exception.class) | |||
| public void deleteServing(UserContext user, ServingInfo servingInfo, List<ServingModelConfig> modelConfigList) { | |||
| public void deleteServing(ServingInfo servingInfo, List<ServingModelConfig> modelConfigList) { | |||
| boolean flag = true; | |||
| try { | |||
| for (ServingModelConfig servingModelConfig : modelConfigList) { | |||
| String namespace = k8sNameTool.getNamespace(servingInfo.getCreateUserId()); | |||
| String resourceName = k8sNameTool.generateResourceName(BizEnum.SERVING, servingModelConfig.getResourceInfo()); | |||
| LogUtil.info(LogEnum.SERVING, "User {} delete the service, namespace:{}, resourceName:{}", user.getUsername(), namespace, resourceName); | |||
| LogUtil.info(LogEnum.SERVING, "Delete the service, namespace:{}, resourceName:{}", namespace, resourceName); | |||
| String uniqueName = ServingStatusDetailDescUtil.getUniqueName(servingModelConfig.getModelName(), servingModelConfig.getModelVersion()); | |||
| String statusDetailKey = ServingStatusDetailDescUtil.getServingStatusDetailKey(ServingStatusDetailDescUtil.CONTAINER_DELETION_EXCEPTION, uniqueName); | |||
| PtBaseResult ptBaseResult = modelServingApi.delete(namespace, resourceName); | |||
| if (!ServingConstant.SUCCESS_CODE.equals(ptBaseResult.getCode())) { | |||
| servingInfo.putStatusDetail(statusDetailKey, ptBaseResult.getMessage()); | |||
| flag = false; | |||
| } else { | |||
| servingModelConfig.setResourceInfo(null); | |||
| if (ServingConstant.SUCCESS_CODE.equals(ptBaseResult.getCode())) { | |||
| servingModelConfig.setResourceInfo(SymbolConstant.BLANK); | |||
| if (servingModelConfigService.updateById(servingModelConfig)) { | |||
| LogUtil.info(LogEnum.SERVING, "User {} delete the service SUCCESS, namespace:{}, resourceName:{}", user.getUsername(), namespace, resourceName); | |||
| LogUtil.info(LogEnum.SERVING, "Delete the service SUCCESS, namespace:{}, resourceName:{}", namespace, resourceName); | |||
| } | |||
| } else { | |||
| servingInfo.putStatusDetail(statusDetailKey, ptBaseResult.getMessage()); | |||
| flag = false; | |||
| } | |||
| } | |||
| } catch (KubernetesClientException e) { | |||
| servingInfo.putStatusDetail(ServingStatusDetailDescUtil.getServingStatusDetailKey(ServingStatusDetailDescUtil.CONTAINER_DELETION_EXCEPTION, servingInfo.getName()), e.getMessage()); | |||
| LogUtil.error(LogEnum.SERVING, "An Exception occurred. Service id={}, service name:{},exception :{}", user.getUsername(), servingInfo.getId(), servingInfo.getName(), e); | |||
| LogUtil.error(LogEnum.SERVING, "An Exception occurred. Service id={}, service name:{},exception :{}", servingInfo.getId(), servingInfo.getName(), e); | |||
| } | |||
| if (!flag) { | |||
| servingInfo.setStatus(ServingStatusEnum.EXCEPTION.getStatus()); | |||
| LogUtil.error(LogEnum.SERVING, "An Exception occurred when user {} stopping the service, service name:{}", user.getUsername(), servingInfo.getName()); | |||
| LogUtil.error(LogEnum.SERVING, "An Exception occurred when stopping the service, service name:{}", servingInfo.getName()); | |||
| } | |||
| LogUtil.info(LogEnum.SERVING, "User {} stopped the service with SUCCESS, service name:{}", user.getUsername(), servingInfo.getName()); | |||
| LogUtil.info(LogEnum.SERVING, "Stopped the service with SUCCESS, service name:{}", servingInfo.getName()); | |||
| //grpc协议关闭对应通道 | |||
| if (ServingTypeEnum.GRPC.getType().equals(servingInfo.getType())) { | |||
| GrpcClient.shutdownChannel(servingInfo.getId(), user); | |||
| GrpcClient.shutdownChannel(servingInfo.getId()); | |||
| } | |||
| int result = servingInfoMapper.updateById(servingInfo); | |||
| if (result < NumberConstant.NUMBER_1) { | |||
| LogUtil.error(LogEnum.SERVING, "User {} FAILED stopping the online service. Database update FAILED. Service id={}, service name:{},service status:{}", user.getUsername(), servingInfo.getId(), servingInfo.getName(), servingInfo.getStatus()); | |||
| LogUtil.error(LogEnum.SERVING, "FAILED stopping the online service. Database update FAILED. Service id={}, service name:{},service status:{}", servingInfo.getId(), servingInfo.getName(), servingInfo.getStatus()); | |||
| throw new BusinessException(ServingErrorEnum.DATABASE_ERROR); | |||
| } | |||
| } | |||
| @@ -438,7 +437,7 @@ public class DeployServingAsyncTask { | |||
| } | |||
| } catch (KubernetesClientException e) { | |||
| batchServing.putStatusDetail(ServingStatusDetailDescUtil.getServingStatusDetailKey(ServingStatusDetailDescUtil.BULK_SERVICE_DELETE_EXCEPTION, batchServing.getName()), e.getMessage()); | |||
| LogUtil.error(LogEnum.SERVING, "An Exception occurred. BatchServing id={}, batchServing name:{},exception :{}", user.getUsername(), batchServing.getId(), batchServing.getName(), e); | |||
| LogUtil.error(LogEnum.SERVING, "An Exception occurred. BatchServing id={}, batchServing name:{},exception :{}", batchServing.getId(), batchServing.getName(), e); | |||
| } | |||
| batchServing.setStatus(ServingStatusEnum.EXCEPTION.getStatus()); | |||
| batchServingMapper.updateById(batchServing); | |||
| @@ -158,15 +158,14 @@ public class GrpcClient { | |||
| * 关闭grpc通道 | |||
| * | |||
| * @param servingId 在线服务id | |||
| * @param user 用户信息 | |||
| */ | |||
| public static void shutdownChannel(Long servingId, UserContext user) { | |||
| public static void shutdownChannel(Long servingId) { | |||
| if (channelMap.containsKey(servingId)) { | |||
| ManagedChannel channel = channelMap.get(servingId); | |||
| try { | |||
| GrpcClient.shutdown(channel); | |||
| } catch (InterruptedException e) { | |||
| LogUtil.error(LogEnum.SERVING, "An Exception occurred when user {} shutting down the grpc channel, service id:{}", user.getUsername(), servingId, e); | |||
| LogUtil.error(LogEnum.SERVING, "An Exception occurred when shutting down the grpc channel, service id:{}", servingId, e); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,147 @@ | |||
| <?xml version="1.0" encoding="UTF-8"?> | |||
| <project xmlns="http://maven.apache.org/POM/4.0.0" | |||
| xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" | |||
| xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> | |||
| <parent> | |||
| <groupId>org.dubhe</groupId> | |||
| <artifactId>server</artifactId> | |||
| <version>0.0.1-SNAPSHOT</version> | |||
| </parent> | |||
| <modelVersion>4.0.0</modelVersion> | |||
| <artifactId>dubhe-tadl</artifactId> | |||
| <version>0.0.1-SNAPSHOT</version> | |||
| <name>Dubhe TADL</name> | |||
| <description>Dubhe TADL</description> | |||
| <dependencies> | |||
| <!-- 注册中心 --> | |||
| <dependency> | |||
| <groupId>org.dubhe.cloud</groupId> | |||
| <artifactId>registration</artifactId> | |||
| <version>${org.dubhe.cloud.registration.version}</version> | |||
| </dependency> | |||
| <!-- 配置中心 --> | |||
| <dependency> | |||
| <groupId>org.dubhe.cloud</groupId> | |||
| <artifactId>configuration</artifactId> | |||
| <version>${org.dubhe.cloud.configuration.version}</version> | |||
| </dependency> | |||
| <!-- 统一Rest返回工具结构 --> | |||
| <dependency> | |||
| <groupId>org.dubhe.biz</groupId> | |||
| <artifactId>data-response</artifactId> | |||
| <version>${org.dubhe.biz.data-response.version}</version> | |||
| </dependency> | |||
| <!-- 统一权限配置 --> | |||
| <dependency> | |||
| <groupId>org.dubhe.cloud</groupId> | |||
| <artifactId>auth-config</artifactId> | |||
| <version>${org.dubhe.cloud.auth-config.version}</version> | |||
| </dependency> | |||
| <!-- Cloud swagger --> | |||
| <dependency> | |||
| <groupId>org.dubhe.cloud</groupId> | |||
| <artifactId>swagger</artifactId> | |||
| <version>${org.dubhe.cloud.swagger.version}</version> | |||
| </dependency> | |||
| <!--持久层操作--> | |||
| <dependency> | |||
| <groupId>org.dubhe.biz</groupId> | |||
| <artifactId>db</artifactId> | |||
| <version>${org.dubhe.biz.db.version}</version> | |||
| </dependency> | |||
| <!-- k8s API依赖--> | |||
| <dependency> | |||
| <groupId>org.dubhe</groupId> | |||
| <artifactId>common-k8s</artifactId> | |||
| <version>${org.dubhe.common-k8s.version}</version> | |||
| </dependency> | |||
| <!-- log依赖--> | |||
| <dependency> | |||
| <groupId>org.dubhe.biz</groupId> | |||
| <artifactId>log</artifactId> | |||
| <version>${org.dubhe.biz.log.version}</version> | |||
| </dependency> | |||
| <!-- 远程调用 --> | |||
| <dependency> | |||
| <groupId>org.dubhe.cloud</groupId> | |||
| <artifactId>remote-call</artifactId> | |||
| <version>${org.dubhe.cloud.remote-call.version}</version> | |||
| </dependency> | |||
| <!--mapStruct依赖--> | |||
| <dependency> | |||
| <groupId>org.mapstruct</groupId> | |||
| <artifactId>mapstruct-jdk8</artifactId> | |||
| <version>${mapstruct.version}</version> | |||
| </dependency> | |||
| <dependency> | |||
| <groupId>org.mapstruct</groupId> | |||
| <artifactId>mapstruct-processor</artifactId> | |||
| <version>${mapstruct.version}</version> | |||
| <scope>provided</scope> | |||
| </dependency> | |||
| <dependency> | |||
| <groupId>commons-beanutils</groupId> | |||
| <artifactId>commons-beanutils</artifactId> | |||
| <version>1.9.3</version> | |||
| </dependency> | |||
<!-- 状态机 / state machine (declared once; duplicate declaration removed;
     default "compile" scope dropped as redundant).
     TODO(review): version is hard-coded 0.0.1-SNAPSHOT instead of a ${...}
     property like the sibling org.dubhe.biz modules — confirm intended. -->
<dependency>
    <groupId>org.dubhe.biz</groupId>
    <artifactId>state-machine</artifactId>
    <version>0.0.1-SNAPSHOT</version>
</dependency>
<!-- test-only dependencies: scoped to "test" so they stay off the runtime classpath -->
<dependency>
    <groupId>junit</groupId>
    <artifactId>junit</artifactId>
    <scope>test</scope>
</dependency>
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-test</artifactId>
    <scope>test</scope>
</dependency>
<dependency>
    <groupId>org.springframework</groupId>
    <artifactId>spring-test</artifactId>
    <scope>test</scope>
</dependency>
| <dependency> | |||
| <groupId>org.redisson</groupId> | |||
| <artifactId>redisson-spring-boot-starter</artifactId> | |||
| <version>${org.redisson.version}</version> | |||
| </dependency> | |||
| <!-- common-recycle 垃圾回收--> | |||
| <dependency> | |||
| <groupId>org.dubhe</groupId> | |||
| <artifactId>common-recycle</artifactId> | |||
| <version>${org.dubhe.common-recycle.version}</version> | |||
| </dependency> | |||
| </dependencies> | |||
| <build> | |||
| <plugins> | |||
| <plugin> | |||
| <groupId>org.springframework.boot</groupId> | |||
| <artifactId>spring-boot-maven-plugin</artifactId> | |||
| <configuration> | |||
| <skip>false</skip> | |||
| <fork>true</fork> | |||
| <classifier>exec</classifier> | |||
| </configuration> | |||
| </plugin> | |||
| <!-- 跳过单元测试 --> | |||
| <plugin> | |||
| <groupId>org.apache.maven.plugins</groupId> | |||
| <artifactId>maven-surefire-plugin</artifactId> | |||
| <configuration> | |||
| <skipTests>true</skipTests> | |||
| </configuration> | |||
| </plugin> | |||
| </plugins> | |||
| </build> | |||
| </project> | |||
| @@ -0,0 +1,40 @@ | |||
| /** | |||
| * Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| * ============================================================= | |||
| */ | |||
| package org.dubhe.tadl; | |||
| import org.mybatis.spring.annotation.MapperScan; | |||
| import org.springframework.boot.SpringApplication; | |||
| import org.springframework.boot.autoconfigure.SpringBootApplication; | |||
| import org.springframework.scheduling.annotation.EnableAsync; | |||
| import org.springframework.scheduling.annotation.EnableScheduling; | |||
| /** | |||
| * @description TADL开发启动类 | |||
| * @date 2020-12-08 | |||
| */ | |||
| @SpringBootApplication(scanBasePackages = "org.dubhe") | |||
| @MapperScan(basePackages = {"org.dubhe.**.dao"}) | |||
| @EnableScheduling | |||
| //启动异步调用 | |||
| @EnableAsync | |||
| public class TadlApplication { | |||
| public static void main(String[] args) { | |||
| SpringApplication.run(TadlApplication.class, args); | |||
| } | |||
| } | |||
| @@ -0,0 +1,44 @@ | |||
| /** | |||
| * Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| * ============================================================= | |||
| */ | |||
| package org.dubhe.tadl.client; | |||
| import org.dubhe.biz.base.constant.ApplicationNameConst; | |||
| import org.dubhe.biz.base.vo.DataResponseBody; | |||
| import org.dubhe.biz.base.vo.QueryResourceSpecsVO; | |||
| import org.dubhe.tadl.client.fallback.AdminServiceFallback; | |||
| import org.springframework.cloud.openfeign.FeignClient; | |||
| import org.springframework.web.bind.annotation.GetMapping; | |||
| import org.springframework.web.bind.annotation.RequestParam; | |||
/**
 * @description Feign client for the admin service (TADL resource spec lookup).
 *              Falls back to {@link AdminServiceFallback} when the call fails.
 * @date 2020-11-04
 */
@FeignClient(value = ApplicationNameConst.SERVER_ADMIN, fallback = AdminServiceFallback.class)
public interface AdminServiceClient {

    /**
     * Query a TADL resource spec by id from the admin service.
     *
     * @param id resource spec id
     * @return resource spec info wrapped in a DataResponseBody
     */
    @GetMapping(value = "/resourceSpecs/queryTadlResourceSpecs")
    DataResponseBody<QueryResourceSpecsVO> queryTadlResourceSpecs(@RequestParam("id")Long id);
}
| @@ -0,0 +1,45 @@ | |||
| /** | |||
| * Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| * ============================================================= | |||
| */ | |||
| package org.dubhe.tadl.client.fallback; | |||
| import org.dubhe.biz.base.vo.DataResponseBody; | |||
| import org.dubhe.biz.base.vo.QueryResourceSpecsVO; | |||
| import org.dubhe.biz.dataresponse.factory.DataResponseFactory; | |||
| import org.dubhe.tadl.client.AdminServiceClient; | |||
| import org.springframework.stereotype.Component; | |||
| /** | |||
| * @description Feign 熔断处理类 | |||
| * @date 2020-11-04 | |||
| */ | |||
| @Component | |||
| public class AdminServiceFallback implements AdminServiceClient { | |||
| /** | |||
| * 获取资源规格信息 | |||
| * | |||
| * @param id 资源ID | |||
| * @return 资源信息 | |||
| */ | |||
| @Override | |||
| public DataResponseBody<QueryResourceSpecsVO> queryTadlResourceSpecs(Long id) { | |||
| return DataResponseFactory.failed("call admin server queryTadlResourceSpecs error "); | |||
| } | |||
| } | |||
| @@ -0,0 +1,68 @@ | |||
| /** | |||
| * Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| * ============================================================= | |||
| */ | |||
| package org.dubhe.tadl.config; | |||
| import lombok.AllArgsConstructor; | |||
| import lombok.Data; | |||
| import lombok.NoArgsConstructor; | |||
| import org.springframework.boot.context.properties.ConfigurationProperties; | |||
| import org.springframework.context.annotation.Configuration; | |||
| import java.util.Arrays; | |||
| import java.util.HashMap; | |||
| import java.util.List; | |||
| import java.util.Map; | |||
| @Data | |||
| @AllArgsConstructor | |||
| @NoArgsConstructor | |||
| @Configuration | |||
| @ConfigurationProperties("tadl") | |||
| public class CmdConf { | |||
| private Map<String,Algorithm> cmd; | |||
| public Map<String,Algorithm> getCmd(){ | |||
| return cmd; | |||
| } | |||
| @Data | |||
| @AllArgsConstructor | |||
| @NoArgsConstructor | |||
| public static class Algorithm{ | |||
| private Key key; | |||
| private Val val; | |||
| @Data | |||
| @AllArgsConstructor | |||
| @NoArgsConstructor | |||
| public static class Key{ | |||
| private String key; | |||
| } | |||
| @Data | |||
| @AllArgsConstructor | |||
| @NoArgsConstructor | |||
| public static class Val{ | |||
| private String val; | |||
| public List<String> getVal() { | |||
| return Arrays.asList(val.split(",")); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,126 @@ | |||
| /** | |||
| * Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| * ============================================================= | |||
| */ | |||
| package org.dubhe.tadl.config; | |||
| import org.dubhe.biz.base.utils.StringUtils; | |||
| import org.dubhe.tadl.domain.dto.ExperimentAndTrailDTO; | |||
| import org.dubhe.tadl.listener.RedisStreamListener; | |||
| import org.dubhe.tadl.constant.RedisKeyConstant; | |||
| import org.springframework.beans.factory.DisposableBean; | |||
| import org.springframework.beans.factory.annotation.Autowired; | |||
| import org.springframework.boot.ApplicationArguments; | |||
| import org.springframework.boot.ApplicationRunner; | |||
| import org.springframework.context.annotation.Bean; | |||
| import org.springframework.context.annotation.Configuration; | |||
| import org.springframework.data.redis.connection.ReactiveRedisConnectionFactory; | |||
| import org.springframework.data.redis.connection.RedisConnectionFactory; | |||
| import org.springframework.data.redis.connection.stream.*; | |||
| import org.springframework.data.redis.core.ReactiveRedisTemplate; | |||
| import org.springframework.data.redis.core.ReactiveStreamOperations; | |||
| import org.springframework.data.redis.serializer.RedisSerializationContext; | |||
| import org.springframework.data.redis.stream.StreamMessageListenerContainer; | |||
| import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; | |||
| import reactor.core.publisher.Mono; | |||
| import java.time.Duration; | |||
/**
 * @description Redis stream listener container configuration: starts a
 *              StreamMessageListenerContainer at application startup, stops it on
 *              shutdown, and offers helpers to create streams/groups and bind the
 *              TADL consumer.
 * @date 2021-03-05
 */
@Configuration
public class RedisStreamListenerContainerConfig implements ApplicationRunner, DisposableBean {

    @Autowired
    private RedisConnectionFactory redisConnectionFactory;

    // listener that handles ExperimentAndTrailDTO records read from the stream
    @Autowired
    private RedisStreamListener redisStreamListener;

    // created in run(); null until the application has fully started
    private StreamMessageListenerContainer<String, ObjectRecord<String, ExperimentAndTrailDTO>> streamMessageListenerContainer;

    @Bean
    public ThreadPoolTaskScheduler initTaskScheduler() {
        ThreadPoolTaskScheduler threadPoolTaskScheduler = new ThreadPoolTaskScheduler();
        threadPoolTaskScheduler.setPoolSize(20);
        return threadPoolTaskScheduler;
    }

    @Override
    public void run(ApplicationArguments args) {
        // build the container options
        StreamMessageListenerContainer.StreamMessageListenerContainerOptions<String, ObjectRecord<String, ExperimentAndTrailDTO>> options =
                StreamMessageListenerContainer.StreamMessageListenerContainerOptions.builder()
                        // maximum number of messages fetched per poll
                        .batchSize(10)
                        // executor that runs the polling tasks
                        .executor(initTaskScheduler())
                        // zero = poll without a timeout (a timeout would raise an exception)
                        .pollTimeout(Duration.ZERO)
                        .targetType(ExperimentAndTrailDTO.class)
                        .build();
        // create the listener container from the options
        streamMessageListenerContainer = StreamMessageListenerContainer.create(redisConnectionFactory, options);
        // start listening
        streamMessageListenerContainer.start();
    }

    @Override
    public void destroy() {
        // NOTE(review): NPEs if destroy() runs before run() assigned the container — confirm lifecycle.
        streamMessageListenerContainer.stop();
    }

    @Bean
    public ReactiveRedisTemplate<String, String> reactiveRedisTemplate(ReactiveRedisConnectionFactory factory) {
        return new ReactiveRedisTemplate<>(factory, RedisSerializationContext.string());
    }

    /**
     * Look up stream info; when the stream does not exist, create the consumer
     * group (which also creates the stream) and re-query the info.
     *
     * @param ops reactive stream operations
     * @param stream stream name
     * @param group consumer group name
     * @return stream info publisher
     */
    public Mono<StreamInfo.XInfoStream> prepareStreamAndGroup(ReactiveStreamOperations<String, ?, ?> ops, String stream, String group) {
        return ops.info(stream).onErrorResume(err -> ops.createGroup(stream, group).flatMap(s -> ops.info(stream)));
    }

    /**
     * Bind the fixed consumer to the stream, reading from the last consumed offset.
     *
     * @param mono stream info publisher whose emission triggers the binding
     * @param streamName stream name
     * @param group consumer group name
     */
    public void prepareDisposable(Mono<StreamInfo.XInfoStream> mono, String streamName, String group) {
        // NOTE(review): the Subscription returned by receive(...) is discarded — confirm it never needs disposing.
        mono.subscribe(stream -> streamMessageListenerContainer.receive(Consumer.from(group, RedisKeyConstant.CONSUMER),
                StreamOffset.create(streamName, ReadOffset.lastConsumed()),
                redisStreamListener));
    }

    /**
     * Create the Redis stream (if absent) and bind the consumer to it.
     *
     * @param ops reactive stream operations
     * @param stream stream name
     * @param group consumer group name
     */
    public void buildRedisStream(ReactiveStreamOperations<String, ?, ?> ops, String stream, String group) {
        if (!StringUtils.isEmpty(stream) && !StringUtils.isEmpty(group)) {
            Mono<StreamInfo.XInfoStream> mono = prepareStreamAndGroup(ops, stream, group);
            prepareDisposable(mono, stream, group);
        }
    }
}
| @@ -0,0 +1,44 @@ | |||
| /** | |||
| * Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| * ============================================================= | |||
| */ | |||
| package org.dubhe.tadl.config; | |||
| import lombok.Data; | |||
| import org.springframework.boot.context.properties.ConfigurationProperties; | |||
| import org.springframework.stereotype.Component; | |||
/**
 * @description Configuration bound from the "tadl" properties: job image and
 *              in-container paths.
 * @date 2021-03-23
 */
@Component
@Data
@ConfigurationProperties(prefix = "tadl")
public class TadlJobConfig {

    /**
     * Docker image used to run TADL jobs
     */
    private String image;

    /**
     * Dataset path inside the container
     */
    private String dockerDatasetPath;

    /**
     * Experiment path inside the container
     */
    private String dockerExperimentPath;
}
| @@ -0,0 +1,66 @@ | |||
| /** | |||
| * Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| * ============================================================= | |||
| */ | |||
| package org.dubhe.tadl.config; | |||
| import org.springframework.beans.factory.annotation.Value; | |||
| import org.springframework.context.annotation.Bean; | |||
| import org.springframework.context.annotation.Configuration; | |||
| import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; | |||
| import java.util.concurrent.Executor; | |||
| import java.util.concurrent.ThreadPoolExecutor; | |||
/**
 * @description Thread pool configuration for TADL trial async processing.
 * @date 2020-09-02
 */
@Configuration
public class TrialPoolConfig {

    // core pool size (basepool.corePoolSize, default 40)
    @Value("${basepool.corePoolSize:40}")
    private Integer corePoolSize;

    // maximum pool size (basepool.maximumPoolSize, default 60)
    @Value("${basepool.maximumPoolSize:60}")
    private Integer maximumPoolSize;

    // idle-thread keep-alive in seconds (basepool.keepAliveTime, default 120)
    @Value("${basepool.keepAliveTime:120}")
    private Integer keepAliveTime;

    // bounded work queue capacity (basepool.blockQueueSize, default 20)
    @Value("${basepool.blockQueueSize:20}")
    private Integer blockQueueSize;

    /**
     * Async executor for TADL trial processing, registered under bean name "tadlExecutor".
     * NOTE(review): the method name "servingAsyncExecutor" looks copy-pasted from the
     * serving module — consider renaming to tadlAsyncExecutor (the bean name is fixed
     * by the annotation, so a rename would not affect @Async("tadlExecutor") callers).
     *
     * @return Executor instance
     */
    @Bean("tadlExecutor")
    public Executor servingAsyncExecutor() {
        ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();
        // core thread count
        taskExecutor.setCorePoolSize(corePoolSize);
        taskExecutor.setAllowCoreThreadTimeOut(true);
        // maximum thread count
        taskExecutor.setMaxPoolSize(maximumPoolSize);
        // keep-alive seconds for idle threads
        taskExecutor.setKeepAliveSeconds(keepAliveTime);
        // work queue capacity
        taskExecutor.setQueueCapacity(blockQueueSize);
        // thread name prefix
        taskExecutor.setThreadNamePrefix("async-trial-");
        // rejection policy: abort with RejectedExecutionException when saturated
        taskExecutor.setRejectedExecutionHandler(new ThreadPoolExecutor.AbortPolicy());
        taskExecutor.initialize();
        return taskExecutor;
    }
}
| @@ -0,0 +1,121 @@ | |||
| /** | |||
| * Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this Trial except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| * ============================================================= | |||
| */ | |||
| package org.dubhe.tadl.constant; | |||
/**
 * @description Redis key definitions for the TADL module.
 * @date 2021-07-29
 */
public class RedisKeyConstant {

    /**
     * Colon separator
     */
    public static final String COLON = ":";

    /**
     * Underscore separator
     */
    public static final String UNDERLINE = "_";

    /**
     * Project name prefix
     */
    public static final String TADL = "tadl:";

    /**
     * Experiment prefix
     */
    public static final String EXPERIMENT = "experiment:";

    /**
     * Experiment stage prefix
     */
    public static final String EXPERIMENT_STAGE = "experiment_stage:";

    /**
     * Experiment trial prefix
     */
    public static final String TRIAL = "trial:";

    /**
     * Stream stage key prefix
     * tadl:experiment_stage:trial:run_param:
     */
    private static final String STREAM_STAGE_KEY = TADL + EXPERIMENT_STAGE + TRIAL + "run_param:";

    /**
     * Stream group key prefix
     * tadl:experiment_stage:group:
     */
    private static final String STREAM_GROUP_KEY = TADL + EXPERIMENT_STAGE + "group:";

    /**
     * Consumer name
     */
    public static final String CONSUMER = "consumer";

    /**
     * Paused key prefix
     */
    private static final String EXPERIMENT_PAUSED_KEY = TADL + EXPERIMENT + "paused:";

    /**
     * zset holding experiment stage expiration times
     * tadl:experiment_stage:expired_time_set
     */
    public static final String EXPERIMENT_STAGE_EXPIRED_TIME_SET = TADL + EXPERIMENT_STAGE + "expired_time_set";

    /**
     * Utility class: no instances.
     */
    private RedisKeyConstant() {
    }

    /**
     * Build the per-stage stream key used to store the k8s-related trial run parameters.
     *
     * @param indexId index id
     * @param stageId stage id
     * @return key of the form tadl:experiment_stage:trial:run_param:{indexId}_{stageId}
     */
    public static String buildStreamStageKey(long indexId, Long stageId) {
        return STREAM_STAGE_KEY + indexId + UNDERLINE + stageId;
    }

    /**
     * Build the per-stage stream group key.
     *
     * @param indexId index id
     * @param stageId stage id
     * @return key of the form tadl:experiment_stage:group:{indexId}:{stageId}
     */
    public static String buildStreamGroupStageKey(long indexId, Long stageId) {
        return STREAM_GROUP_KEY + indexId + COLON + stageId;
    }

    /**
     * Build the combined paused key for an experiment.
     *
     * @param experimentId experiment id
     * @return key of the form tadl:experiment:paused:{experimentId}
     */
    public static String buildPausedKey(Long experimentId) {
        return EXPERIMENT_PAUSED_KEY + experimentId;
    }

    /**
     * Build the combined deleted key for a trial.
     * NOTE(review): this reuses EXPERIMENT_PAUSED_KEY as the prefix — presumably a
     * copy-paste from buildPausedKey; confirm whether a dedicated "deleted:" prefix
     * was intended. The produced value is kept unchanged for compatibility.
     *
     * @param experimentId experiment id
     * @param stageId stage id
     * @param trialId trial id
     * @return key of the form tadl:experiment:paused:{experimentId}_{stageId}_{trialId}
     */
    public static String buildDeletedKey(Long experimentId, Long stageId, Long trialId) {
        return EXPERIMENT_PAUSED_KEY + experimentId + UNDERLINE + stageId + UNDERLINE + trialId;
    }
}
| @@ -0,0 +1,134 @@ | |||
| /** | |||
| * Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| * ============================================================= | |||
| */ | |||
| package org.dubhe.tadl.constant; | |||
| import java.io.File; | |||
| /** | |||
| * @description 常量 | |||
| * @date 2020-12-16 | |||
| */ | |||
| public class TadlConstant { | |||
| /** | |||
| * 模块路径 | |||
| */ | |||
| public static final String MODULE_URL_PREFIX = "/"; | |||
| /** | |||
| * 算法工程名 | |||
| */ | |||
| public static final String ALGORITHM_PROJECT_NAME = "TADL"; | |||
| /** | |||
| * 算法cmd符号 | |||
| */ | |||
| public static final String PARAM_SYMBOL = "--"; | |||
| /** | |||
| * 算法配置文件后缀名 | |||
| */ | |||
| public static final String ALGORITHM_CONFIGURATION_FILE_SUFFIX = ".yaml"; | |||
| /** | |||
| * yaml | |||
| */ | |||
| public static final String ALGORITHM_YAML = "yaml/"; | |||
| /** | |||
| * 算法文件转化python脚本文件名 | |||
| */ | |||
| public static final String ALGORITHM_TRANSFORM_FILE_NAME = "transformParam.py"; | |||
| /** | |||
| * zip后缀名 | |||
| */ | |||
| public static final String ZIP_SUFFIX = ".zip"; | |||
| /** | |||
| * 默认的版本 | |||
| */ | |||
| public static final String DEFAULT_VERSION = "V0001"; | |||
| /** | |||
| * 版本名称的首字母 | |||
| */ | |||
| public static final String DATASET_VERSION_PREFIX = "V"; | |||
| /** | |||
| * 常量数据 | |||
| */ | |||
| public static final int NUMBER_ZERO = 0; | |||
| public static final int NUMBER_ONE = 1; | |||
| public static final String EXECUTE_SCRIPT_PATH = File.separator + "algorithm" + File.separator + "TADL" + File.separator + "pytorch"; | |||
| public static final String MODEL_SELECTED_SPACE_PATH = File.separator + "model_selected_space" + File.separator + "model_selected_space.json"; | |||
| public static final String RESULT_PATH = File.separator + "result" + File.separator + "result.json"; | |||
| public static final String LOG_PATH = File.separator + "log"; | |||
| public static final String SEARCH_SPACE_PATH = File.separator + "search_space.json"; | |||
| public static final String BEST_SELECTED_SPACE_PATH = File.separator + "best_selected_space.json"; | |||
| public static final String BEST_CHECKPOINT_DIR = File.separator + "best_checkpoint" + File.separator; | |||
| public static final String AND = "&&"; | |||
| public static final String RUN_PARAMETER = "run_parameter"; | |||
| public static final String MODEL_SELECTED_SPACE_PATH_STRING = "model_selected_space_path"; | |||
| public static final String RESULT_PATH_STRING = "result_path"; | |||
| public static final String LOG_PATH_STRING = "log_path"; | |||
| public static final String EXPERIMENT_DIR_STRING = "experiment_dir"; | |||
| public static final String SEARCH_SPACE_PATH_STRING = "search_space_path"; | |||
| public static final String BEST_SELECTED_SPACE_PATH_STRING = "best_selected_space_path"; | |||
| public static final String BEST_CHECKPOINT_DIR_STRING = "best_checkpoint_dir"; | |||
| public static final String DATA_DIR_STRING = "data_dir"; | |||
| public static final String TRIAL_ID_STRING = "trial_id"; | |||
| public static final String LOCK = "lock"; | |||
| public static final String SEARCH_SPACE_FILENAME = "search_space.json"; | |||
| public static final String BEST_SELECTED_SPACE_FILENAME = "best_selected_space.json"; | |||
| public static final String RESULT_JSON_TYPE = "accuracy"; | |||
| /** | |||
| * 实验步骤流程日志 | |||
| */ | |||
| public static final String EXPERIMENT_STAGE_FLOW_LOG = "(stage_id = {})"; | |||
| public static final String EXPERIMENT_TRIAL_FLOW_LOG = "(trial_id = {})"; | |||
| public static final String PROCESS_TRIAL_KEYWORD_LOG = "The experiment id:{},stage id:{},trial id:{}."; | |||
| public static final String PROCESS_STAGE_KEYWORD_LOG = "The experiment id:{},stage id:{}."; | |||
| public static final String PROCESS_EXPERIMENT_FLOW_LOG = "The experiment id:{}."; | |||
| /** | |||
| * 状态详情记录 | |||
| */ | |||
| public static final String TRIAL_TASK_DELETE_EXCEPTION = "TRIAL任务删除异常"; | |||
| public static final String ADMIN_SERVER_EXCEPTION = "admin服务异常"; | |||
| public static final String TRIAL_STARTUP_COMMAND_ASSEMBLY_EXCEPTION = "TRIAL启动命令组装异常"; | |||
| public static final String ABNORMAL_EXPERIMENTAL_PROCESS = "实验流程异常"; | |||
| public static final String EXPERIMENT_RUN_FAILED = "实验运行失败"; | |||
| public static final String ABNORMAL_OPERATION_OF_ALGORITHM = "算法运行异常"; | |||
| public static final String TRIAL_STARTUP_FAILED = "TRIAL启动失败"; | |||
| public static final String TRIAL_STARTUP_EXCEPTION = "TRIAL启动异常"; | |||
| public static final String REDIS_STREAM_DATA_CONVERSION_EXCEPTION = "REDIS STREAM 数据转换异常"; | |||
| public static final String DISTRIBUTED_LOCK_ACQUISITION_FAILED = "分布式锁获取失败"; | |||
| public static final String STAGE_OVERTIME = "实验阶段超时"; | |||
| public static final String UNKNOWN_EXCEPTION= "未知异常"; | |||
| public static final String REDIS_MESSAGE_QUEUE_EXCEPTION = "REDIS消息队列异常"; | |||
| } | |||
| @@ -0,0 +1,38 @@ | |||
| /** | |||
| * Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| * ============================================================= | |||
| */ | |||
| package org.dubhe.tadl.dao; | |||
| import com.baomidou.mybatisplus.core.mapper.BaseMapper; | |||
| import org.apache.ibatis.annotations.Param; | |||
| import org.apache.ibatis.annotations.Select; | |||
| import org.dubhe.tadl.domain.entity.Algorithm; | |||
/**
 * @description Algorithm mapper.
 * @date 2021-03-10
 */
public interface AlgorithmMapper extends BaseMapper<Algorithm> {

    /**
     * Query an algorithm by its id.
     * NOTE(review): raw @Select SQL bypasses MyBatis-Plus default query wrappers
     * (e.g. logic-delete filtering) — presumably intentional; confirm against callers.
     *
     * @param id algorithm id
     * @return Algorithm the matching row, or null if none exists
     */
    @Select("select * from tadl_algorithm where id = #{id}")
    Algorithm getOneById(@Param("id") Long id);
}
| @@ -0,0 +1,46 @@ | |||
| /** | |||
| * Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| * ============================================================= | |||
| */ | |||
| package org.dubhe.tadl.dao; | |||
| import com.baomidou.mybatisplus.core.mapper.BaseMapper; | |||
| import org.apache.ibatis.annotations.Param; | |||
| import org.apache.ibatis.annotations.Select; | |||
| import org.apache.ibatis.annotations.Update; | |||
| import org.dubhe.tadl.domain.entity.AlgorithmStage; | |||
/**
 * @description Algorithm stage management mapper.
 * @date 2021-03-22
 */
public interface AlgorithmStageMapper extends BaseMapper<AlgorithmStage>{
    /**
     * Update the deleted flag of every stage belonging to an algorithm version.
     *
     * @param versionId algorithm version id
     * @param deleted   deleted flag to set
     * @return number of rows updated
     */
    @Update("update tadl_algorithm_stage set deleted=#{deleted} where algorithm_version_id = #{versionId}")
    int updateStageStatusByVersionId(@Param("versionId") Long versionId, @Param("deleted") Boolean deleted);

    /**
     * Query an algorithm stage by its id.
     * NOTE(review): raw @Select SQL does not apply logic-delete filtering — confirm intended.
     *
     * @param id algorithm stage id
     * @return AlgorithmStage the matching row, or null if none exists
     */
    @Select("select * from tadl_algorithm_stage where id = #{id}")
    AlgorithmStage getOneById(@Param("id") Long id);
}
| @@ -0,0 +1,57 @@ | |||
| /** | |||
| * Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| * ============================================================= | |||
| */ | |||
| package org.dubhe.tadl.dao; | |||
| import com.baomidou.mybatisplus.core.mapper.BaseMapper; | |||
| import org.apache.ibatis.annotations.Param; | |||
| import org.apache.ibatis.annotations.Select; | |||
| import org.apache.ibatis.annotations.Update; | |||
| import org.dubhe.tadl.domain.entity.AlgorithmVersion; | |||
/**
 * @description Algorithm version management mapper.
 * @date 2021-03-22
 */
public interface AlgorithmVersionMapper extends BaseMapper<AlgorithmVersion>{
    /**
     * Get the highest version name currently used by the given algorithm.
     * NOTE(review): relies on lexicographic max(version_name); correct only while
     * version names are zero-padded to a fixed width (e.g. "V0001") — confirm.
     *
     * @param algorithmId algorithm ID
     * @return String highest version name, or null if the algorithm has no versions
     */
    @Select("select max(version_name) from tadl_algorithm_version where algorithm_id=#{algorithmId} and version_name like 'V%'")
    String getMaxVersionName(@Param("algorithmId") Long algorithmId);

    /**
     * Update the deleted flag of an algorithm version.
     *
     * @param id      algorithm version id
     * @param deleted deleted flag to set
     * @return int number of rows updated
     */
    @Update("update tadl_algorithm_version set deleted = #{deleted} where id = #{id}")
    int updateAlgorithmVersionStatus(@Param("id") Long id,@Param("deleted") Boolean deleted);

    /**
     * Query an algorithm version by its id.
     *
     * @param id algorithm version id
     * @return AlgorithmVersion the matching row, or null if none exists
     */
    @Select("select * from tadl_algorithm_version where id = #{id}")
    AlgorithmVersion getOneById(@Param("id") Long id );
}
| @@ -0,0 +1,45 @@ | |||
| /** | |||
| * Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| * ============================================================= | |||
| */ | |||
| package org.dubhe.tadl.dao; | |||
| import com.baomidou.mybatisplus.core.mapper.BaseMapper; | |||
| import org.apache.ibatis.annotations.Param; | |||
| import org.dubhe.tadl.domain.entity.Experiment; | |||
/**
 * @description Experiment management mapper.
 *              Methods carry no SQL annotations; their statements are defined in the
 *              corresponding XML mapper file.
 * @date 2021-03-22
 */
public interface ExperimentMapper extends BaseMapper<Experiment>{
    /**
     * Mark an experiment as deleted by experiment id (per the original doc, the XML
     * statement also covers the experiment's stages and trial data — verify in the XML).
     *
     * @param id      experiment id
     * @param deleted deleted flag to set
     * @return number of rows updated
     */
    int updateExperimentDeletedById(@Param("id") Long id,@Param("deleted") Boolean deleted);

    /**
     * Mark an experiment as failed, starting from a trial id: updates the trial,
     * its stage and the owning experiment statuses in one statement (defined in XML).
     *
     * @param trialId          trial id
     * @param trialStatus      trial status to set
     * @param stageStatus      experiment stage status to set
     * @param experimentStatus experiment status to set
     * @param statusDetail     status detail message
     */
    void updateExperimentFailedByTrialId(@Param("trialId") Long trialId,@Param("trialStatus")Integer trialStatus,@Param("stageStatus") Integer stageStatus,@Param("experimentStatus") Integer experimentStatus,@Param("statusDetail")String statusDetail);
}
| @@ -0,0 +1,60 @@ | |||
| /** | |||
| * Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| * ============================================================= | |||
| */ | |||
| package org.dubhe.tadl.dao; | |||
| import com.baomidou.mybatisplus.core.mapper.BaseMapper; | |||
| import org.apache.ibatis.annotations.Param; | |||
| import org.dubhe.tadl.domain.entity.ExperimentStage; | |||
| import java.util.List; | |||
/**
 * @description Experiment stage management mapper.
 *              Methods carry no SQL annotations; their statements are defined in the
 *              corresponding XML mapper file.
 * @date 2021-03-22
 */
public interface ExperimentStageMapper extends BaseMapper<ExperimentStage>{
    /**
     * Query the status values of all stages belonging to an experiment.
     *
     * @param experimentId experiment id
     * @return List<Integer> stage status list
     */
    List<Integer> getExperimentStateByStage(@Param("experimentId") Long experimentId);

    /**
     * Update the status of one experiment stage.
     *
     * @param id     experiment stage id
     * @param status stage status to set
     */
    void updateExperimentStageStatus(@Param("id") Long id,@Param("status") Integer status);

    /**
     * Query one stage by experiment id and stage id.
     *
     * @param experimentId      experiment id
     * @param experimentStageId experiment stage id
     * @return ExperimentStage the matching row, or null if none exists
     */
    ExperimentStage getExperimentStateByExperimentIdAndStageId(@Param("experimentId") Long experimentId,@Param("experimentStageId") Long experimentStageId);

    /**
     * Batch-insert experiment stages.
     *
     * @param experimentStageList stages to insert
     * @return number of rows inserted
     */
    Integer insertExperimentStageList(@Param("experimentStageList") List<ExperimentStage> experimentStageList);
}
| @@ -0,0 +1,46 @@ | |||
| /** | |||
| * Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| * ============================================================= | |||
| */ | |||
| package org.dubhe.tadl.dao; | |||
| import com.baomidou.mybatisplus.core.mapper.BaseMapper; | |||
| import org.apache.ibatis.annotations.Param; | |||
| import org.dubhe.tadl.domain.entity.Trial; | |||
| import org.dubhe.tadl.domain.entity.TrialData; | |||
| import java.util.List; | |||
/**
 * @description Trial run-result data mapper.
 *              Methods carry no SQL annotations; their statements are defined in the
 *              corresponding XML mapper file.
 * @date 2021-03-22
 */
public interface TrialDataMapper extends BaseMapper<TrialData>{
    /**
     * Batch-insert trial data rows.
     *
     * @param trialDataList trial data rows to insert
     */
    void saveList(@Param("trialDataList") List<TrialData> trialDataList);

    /**
     * Update the recorded value for a trial (per the original doc, the maximum value —
     * verify the comparison logic in the XML statement).
     *
     * @param trialId trial id
     * @param value   value to record
     * @return number of rows updated
     */
    int updateValue(@Param("trialId") Long trialId,@Param("value") Double value);
}
| @@ -0,0 +1,97 @@ | |||
| /** | |||
| * Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| * ============================================================= | |||
| */ | |||
| package org.dubhe.tadl.dao; | |||
| import com.baomidou.mybatisplus.core.mapper.BaseMapper; | |||
| import org.apache.ibatis.annotations.Param; | |||
| import org.dubhe.tadl.domain.entity.Trial; | |||
| import org.dubhe.tadl.domain.entity.TrialData; | |||
| import java.util.List; | |||
| /** | |||
| * @description 试验管理服务Mapper | |||
| * @date 2021-03-22 | |||
| */ | |||
| public interface TrialMapper extends BaseMapper<Trial>{ | |||
| /** | |||
| * 根据实验阶段ID查询trial状态列表 | |||
| * | |||
| * @param experimentStageId 实验阶段id | |||
| * @return List<Integer> trial 状态set | |||
| */ | |||
| List<Integer> getExperimentStageStateByTrial(@Param("experimentStageId") Long experimentStageId); | |||
| /** | |||
| * 批量写入trial | |||
| * | |||
| * @param trials trial 列表 | |||
| */ | |||
| void saveList(@Param("trials")List<Trial> trials); | |||
| /** | |||
| * 获取当前阶段最佳的精度 | |||
| * | |||
| * @param experimentId 实验ID | |||
| * @param stageId 阶段ID | |||
| * @return 当前阶段最佳精度 | |||
| */ | |||
| double getBestData(@Param("experimentId") Long experimentId,@Param("stageId") Long stageId); | |||
| /** | |||
| * 更新trial实验状态 | |||
| * @param id trial ID | |||
| * @param status trial 状态 | |||
| */ | |||
| void updateTrialStatus(@Param("id") Long id ,@Param("status") Integer status); | |||
| /** | |||
| * 根据 实验id和阶段id 查询TrialData表 | |||
| * @param experimentId | |||
| * @param stageId | |||
| * @return TrialData | |||
| */ | |||
| List<TrialData> queryTrialDataById(@Param("experimentId") Long experimentId,@Param("stageId") Long stageId); | |||
| /** | |||
| * 根据 实验id和阶段id 查询Trial表 | |||
| * @param experimentId | |||
| * @param stageId | |||
| * @return | |||
| */ | |||
| List<Trial> queryTrialById(@Param("experimentId") Long experimentId,@Param("stageId") Long stageId,@Param("trialIds") List<Long> trialIds,@Param("statusList") List<Integer> statusList); | |||
| /** | |||
| * 根据id变更trial 为失败 | |||
| * @param trialId trial id | |||
| * @param trialStatus trial 实验状态 | |||
| * @param stageStatus 实验阶段状态 | |||
| * @param experimentStatus 实验状态 | |||
| * @param statusDetail 状态详情 | |||
| */ | |||
| void updateTrialFailed(@Param("id") Long trialId,@Param("trialStatus")Integer trialStatus,@Param("stageStatus") Integer stageStatus,@Param("experimentStatus") Integer experimentStatus,@Param("statusDetail")String statusDetail); | |||
| /** | |||
| * 获取成功的trial数量 | |||
| * @param experimentId | |||
| * @param stageId | |||
| * @return 成功的trial数量 | |||
| */ | |||
| Integer getTrialCountOfStatus(Long experimentId,Long stageId,Integer status); | |||
| } | |||
| @@ -0,0 +1,52 @@ | |||
| /** | |||
| * Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| * ============================================================= | |||
| */ | |||
| package org.dubhe.tadl.domain.bo; | |||
| import lombok.Data; | |||
| import java.util.Objects; | |||
| @Data | |||
| public class IntermediateAccuracy { | |||
| /** trial中间精度序号 */ | |||
| private Integer sequence; | |||
| /** 中间精度类目 */ | |||
| private String category; | |||
| /** trial中间精度 */ | |||
| private Double value; | |||
| @Override | |||
| public boolean equals(Object o) { | |||
| if (this == o) { | |||
| return true; | |||
| } | |||
| if (o != null && getClass() == o.getClass()) { | |||
| IntermediateAccuracy intermediateAccuracy = (IntermediateAccuracy) o; | |||
| return sequence.equals(intermediateAccuracy.sequence); | |||
| } else { | |||
| return false; | |||
| } | |||
| } | |||
| @Override | |||
| public int hashCode() { | |||
| return Objects.hash(sequence); | |||
| } | |||
| } | |||