diff --git a/react-ui/config/routes.ts b/react-ui/config/routes.ts index 3a58d07b..e89d5d60 100644 --- a/react-ui/config/routes.ts +++ b/react-ui/config/routes.ts @@ -94,7 +94,7 @@ export default [ { name: '实验训练', path: ':workflowId/:id', - component: './Experiment/training/index', + component: './Experiment/Info/index', }, { name: '实验对比', @@ -112,18 +112,18 @@ export default [ { name: '开发环境', path: '', + component: './DevelopmentEnvironment/List', + }, + { + name: '创建编辑器', + path: 'create', + component: './DevelopmentEnvironment/Create', + }, + { + name: '编辑器', + path: 'editor', component: './DevelopmentEnvironment/Editor', }, - // { - // name: '创建编辑器', - // path: 'create', - // component: './DevelopmentEnvironment/Create', - // }, - // { - // name: '编辑器', - // path: 'editor', - // component: './DevelopmentEnvironment/Editor', - // }, ], }, { diff --git a/react-ui/src/components/ArrayTableCell/index.tsx b/react-ui/src/components/ArrayTableCell/index.tsx new file mode 100644 index 00000000..de109ff4 --- /dev/null +++ b/react-ui/src/components/ArrayTableCell/index.tsx @@ -0,0 +1,37 @@ +/* + * @Author: 赵伟 + * @Date: 2024-04-28 14:18:11 + * @Description: 自定义 Table 数组类单元格 + */ + +import { Tooltip } from 'antd'; + +function ArrayTableCell(ellipsis: boolean = false, property?: string) { + return (value?: any | null) => { + if ( + value === undefined || + value === null || + Array.isArray(value) === false || + value.length === 0 + ) { + return --; + } + + let list = value; + if (property && typeof value[0] === 'object') { + list = value.map((item) => item[property]); + } + const text = list.join(','); + if (ellipsis) { + return ( + + {text}; + + ); + } else { + return {text}; + } + }; +} + +export default ArrayTableCell; diff --git a/react-ui/src/components/CommonTableCell/index.tsx b/react-ui/src/components/CommonTableCell/index.tsx index afa5d8d3..c86ef9a9 100644 --- a/react-ui/src/components/CommonTableCell/index.tsx +++ b/react-ui/src/components/CommonTableCell/index.tsx @@ -6,13 +6,13 @@ import { Tooltip } from 'antd'; -function renderCell(text?: string | null) { +function renderCell(text?: any | null) { return {text ?? '--'}; } function CommonTableCell(ellipsis: boolean = false) { if (ellipsis) { - return (text?: string | null) => ( + return (text?: any | null) => ( {renderCell(text)} diff --git a/react-ui/src/enums/index.ts b/react-ui/src/enums/index.ts index b31aee3a..b75eeca4 100644 --- a/react-ui/src/enums/index.ts +++ b/react-ui/src/enums/index.ts @@ -1,9 +1,36 @@ +/* + * @Author: 赵伟 + * @Date: 2024-06-07 11:22:28 + * @Description: 接口返回的枚举值和共用的枚举值定义在这里 + */ + // 公开还是私有 TabKey export enum CommonTabKeys { Private = 'Private', // 私有 Public = 'Public', // 公开 } +// 实验状态 +export enum ExperimentStatus { + Running = 'Running', // 运行中 + Succeeded = 'Succeeded', // 成功 + Pending = 'Pending', // 启动中 + Failed = 'Failed', // 失败 + Error = 'Error', // 错误 + Terminated = 'Terminated', // 终止 + Skipped = 'Skipped', // 跳过 + Omitted = 'Omitted', // 忽略 +} + +// TensorBoard 状态 +export enum TensorBoardStatus { + Unknown = 'Unknown', // 未知 + Pending = 'Pending', // 启动中 + Running = 'Running', // 运行中 + Terminated = 'Terminated', // 未启动或者已终止 + Failed = 'Failed', // 失败 +} + // 镜像版本状态 export enum MirrorVersionStatus { Available = 'available', // 可用 diff --git a/react-ui/src/overrides.less b/react-ui/src/overrides.less index 61102f12..0f84889d 100644 --- a/react-ui/src/overrides.less +++ b/react-ui/src/overrides.less @@ -4,7 +4,7 @@ * @Description: 覆盖 antd 样式 */ -// 设置 Table 可以滑动 +// 设置 Table 可以滑动,带分页 .vertical-scroll-table { .ant-table-wrapper { height: 100%; @@ -30,6 +30,32 @@ } } +// 设置 Table 可以滑动,没有分页 +.vertical-scroll-table-no-page { + .ant-table-wrapper { + height: 100%; + .ant-spin-nested-loading { + height: 100%; + + .ant-spin-container { + height: 100%; + + .ant-table { + height: 100%; + + .ant-table-container { + height: 100%; + + .ant-table-body { + overflow-y: auto !important; + } + } + } + } + } + } +} + // Tabs 样式 // 删除底部白色横线 .ant-tabs { diff --git a/react-ui/src/pages/Dataset/components/ResourceIntro/index.tsx b/react-ui/src/pages/Dataset/components/ResourceIntro/index.tsx index e8bed08c..27908de7 100644 --- a/react-ui/src/pages/Dataset/components/ResourceIntro/index.tsx +++ b/react-ui/src/pages/Dataset/components/ResourceIntro/index.tsx @@ -8,10 +8,11 @@ import { ResourceData, ResourceType, resourceConfig } from '../../config'; import ResourceVersion from '../ResourceVersion'; import styles from './index.less'; +// 这里值小写是因为值会写在 url 中 export enum ResourceInfoTabKeys { - Introduction = 'introduction', - Version = 'version', - Evolution = 'evolution', + Introduction = 'introduction', // 简介 + Version = 'version', // 版本 + Evolution = 'evolution', // 演化 } type ResourceIntroProps = { diff --git a/react-ui/src/pages/Experiment/Comparison/index.less b/react-ui/src/pages/Experiment/Comparison/index.less index 288ce2ed..a491c621 100644 --- a/react-ui/src/pages/Experiment/Comparison/index.less +++ b/react-ui/src/pages/Experiment/Comparison/index.less @@ -17,5 +17,17 @@ padding: 20px 30px 0; background-color: white; border-radius: 10px; + + :global { + .ant-table-container { + border: none !important; + } + .ant-table-tbody { + .ant-table-cell { + border-right: none !important; + border-left: none !important; + } + } + } } } diff --git a/react-ui/src/pages/Experiment/Comparison/index.tsx b/react-ui/src/pages/Experiment/Comparison/index.tsx index c53a0a94..59cd3f8b 100644 --- a/react-ui/src/pages/Experiment/Comparison/index.tsx +++ b/react-ui/src/pages/Experiment/Comparison/index.tsx @@ -1,32 +1,50 @@ -import CommonTableCell from '@/components/CommonTableCell'; -import { useCacheState } from '@/hooks/pageCacheState'; -import { getExpEvaluateInfosReq, getExpTrainInfosReq } from '@/services/experiment'; +// import { useCacheState } from '@/hooks/pageCacheState'; +import { + getExpEvaluateInfosReq, + getExpMetricsReq, + getExpTrainInfosReq, +} from '@/services/experiment'; import { to } from '@/utils/promise'; +import tableCellRender, { arrayFormatter, dateFormatter } from '@/utils/table'; import { useSearchParams } from '@umijs/max'; -import { Button, Table, TablePaginationConfig, TableProps } from 'antd'; +import { App, Button, Table, /*TablePaginationConfig,*/ TableProps } from 'antd'; import classNames from 'classnames'; -import { useEffect, useState } from 'react'; +import { useEffect, useMemo, useState } from 'react'; +import ExperimentStatusCell from '../components/ExperimentStatusCell'; import styles from './index.less'; export enum ComparisonType { - Train = 'train', // 训练 - Evaluate = 'evaluate', // 评估 + Train = 'Train', // 训练 + Evaluate = 'Evaluate', // 评估 } +type TableData = { + experiment_ins_id: number; + run_id: string; + dataset: string[]; + start_time: string; + status: string; + metrics_names: string[]; + metrics: Record; + params_names: string[]; + params: Record; +}; + function ExperimentComparison() { const [searchParams] = useSearchParams(); const comparisonType = searchParams.get('type'); - const experimentId = searchParams.get('experimentId'); - const [tableData, setTableData] = useState([]); - const [cacheState, setCacheState] = useCacheState(); - const [total, setTotal] = useState(0); + const experimentId = searchParams.get('id'); + const [tableData, setTableData] = useState([]); + // const [cacheState, setCacheState] = useCacheState(); + // const [total, setTotal] = useState(0); const [selectedRowKeys, setSelectedRowKeys] = useState([]); - const [pagination, setPagination] = useState( - cacheState?.pagination ?? { - current: 1, - pageSize: 10, - }, - ); + const { message } = App.useApp(); + // const [pagination, setPagination] = useState( + // cacheState?.pagination ?? { + // current: 1, + // pageSize: 10, + // }, + // ); useEffect(() => { getComparisonData(); @@ -38,20 +56,39 @@ function ExperimentComparison() { comparisonType === ComparisonType.Train ? getExpTrainInfosReq : getExpEvaluateInfosReq; const [res] = await to(request(experimentId)); if (res && res.data) { - const { content = [], totalElements = 0 } = res.data; - setTableData(content); - setTotal(totalElements); + // const { content = [], totalElements = 0 } = res.data; + setTableData(res.data); + // setTotal(totalElements); } }; - // 分页切换 - const handleTableChange: TableProps['onChange'] = (pagination, filters, sorter, { action }) => { - if (action === 'paginate') { - setPagination(pagination); + // 获取对比 url + const getExpMetrics = async () => { + const [res] = await to(getExpMetricsReq(selectedRowKeys)); + if (res && res.data) { + const url = res.data; + window.open(url, '_blank'); } - // console.log(pagination, filters, sorter, action); }; + // 对比按钮 click + const hanldeComparisonClick = () => { + if (selectedRowKeys.length < 2) { + message.error('请至少选择两项进行对比'); + return; + } + getExpMetrics(); + }; + + // 分页切换 + // const handleTableChange: TableProps['onChange'] = (pagination, filters, sorter, { action }) => { + // if (action === 'paginate') { + // setPagination(pagination); + // } + // // console.log(pagination, filters, sorter, action); + // }; + + // 选择行 const rowSelection: TableProps['rowSelection'] = { type: 'checkbox', selectedRowKeys, @@ -61,148 +98,96 @@ function ExperimentComparison() { }, }; - const columns: TableProps['columns'] = [ - { - title: '基本信息', - children: [ - { - title: '实例ID', - dataIndex: 'name', - key: 'name', - width: '30%', - render: CommonTableCell(), - }, - { - title: '运行时间', - dataIndex: 'name', - key: 'name', - width: '30%', - render: CommonTableCell(), - }, - { - title: '运行状态', - dataIndex: 'name', - key: 'name', - width: '30%', - render: CommonTableCell(), - }, - { - title: '训练数据集', - dataIndex: 'name', - key: 'name', - width: '30%', - render: CommonTableCell(), - }, - { - title: '增量训练', - dataIndex: 'name', - key: 'name', - width: '30%', - render: CommonTableCell(), - }, - ], - }, - { - title: '训练参数', - children: [ - { - title: 'batchsize', - dataIndex: 'name', - key: 'name', - width: '30%', - render: CommonTableCell(), - }, - { - title: 'config', - dataIndex: 'name', - key: 'name', - width: '30%', - render: CommonTableCell(), - }, - { - title: 'epoch', - dataIndex: 'name', - key: 'name', - width: '30%', - render: CommonTableCell(), - }, - { - title: 'lr', - dataIndex: 'name', - key: 'name', - width: '30%', - render: CommonTableCell(), - }, - { - title: 'warmup_iters', - dataIndex: 'name', - key: 'name', - width: '30%', - render: CommonTableCell(), - }, - ], - }, - { - title: '训练指标', - children: [ - { - title: 'metrc_name', - dataIndex: 'name', - key: 'name', - width: '30%', - render: CommonTableCell(), - }, - { - title: 'test_1', - dataIndex: 'name', - key: 'name', - width: '30%', - render: CommonTableCell(), - }, - { - title: 'test_2', - dataIndex: 'name', - key: 'name', - width: '30%', - render: CommonTableCell(), - }, - { - title: 'test_3', - dataIndex: 'name', - key: 'name', - width: '30%', - render: CommonTableCell(), - }, - { - title: 'test_4', - dataIndex: 'name', - key: 'name', - width: '30%', - render: CommonTableCell(), - }, - ], - }, - ]; + const columns: TableProps['columns'] = useMemo(() => { + const first: TableData | undefined = tableData[0]; + return [ + { + title: '基本信息', + children: [ + { + title: '实例 ID', + dataIndex: 'experiment_ins_id', + key: 'experiment_ins_id', + width: '20%', + render: tableCellRender(), + }, + { + title: '运行时间', + dataIndex: 'start_time', + key: 'start_time', + width: 180, + render: tableCellRender(false, dateFormatter), + }, + { + title: '运行状态', + dataIndex: 'status', + key: 'status', + width: '20%', + render: ExperimentStatusCell, + }, + { + title: '训练数据集', + dataIndex: 'dataset', + key: 'dataset', + width: '20%', + render: tableCellRender(true, arrayFormatter()), + ellipsis: { showTitle: false }, + }, + ], + }, + { + title: '训练参数', + children: first?.params_names.map((name) => ({ + title: name, + dataIndex: ['params', name], + key: name, + width: '20%', + render: tableCellRender(true), + ellipsis: { showTitle: false }, + })), + }, + { + title: '训练指标', + children: first?.metrics_names.map((name) => ({ + title: name, + dataIndex: ['metrics', name], + key: name, + width: '20%', + render: tableCellRender(true), + ellipsis: { showTitle: false }, + })), + }, + ]; + }, [tableData]); return (
- +
-
+
diff --git a/react-ui/src/pages/Experiment/training/index.jsx b/react-ui/src/pages/Experiment/Info/index.jsx similarity index 100% rename from react-ui/src/pages/Experiment/training/index.jsx rename to react-ui/src/pages/Experiment/Info/index.jsx diff --git a/react-ui/src/pages/Experiment/training/index.less b/react-ui/src/pages/Experiment/Info/index.less similarity index 100% rename from react-ui/src/pages/Experiment/training/index.less rename to react-ui/src/pages/Experiment/Info/index.less diff --git a/react-ui/src/pages/Experiment/training/props.less b/react-ui/src/pages/Experiment/Info/props.less similarity index 100% rename from react-ui/src/pages/Experiment/training/props.less rename to react-ui/src/pages/Experiment/Info/props.less diff --git a/react-ui/src/pages/Experiment/training/props.tsx b/react-ui/src/pages/Experiment/Info/props.tsx similarity index 100% rename from react-ui/src/pages/Experiment/training/props.tsx rename to react-ui/src/pages/Experiment/Info/props.tsx diff --git a/react-ui/src/pages/Experiment/components/ExperimentStatusCell/index.less b/react-ui/src/pages/Experiment/components/ExperimentStatusCell/index.less new file mode 100644 index 00000000..a5df71aa --- /dev/null +++ b/react-ui/src/pages/Experiment/components/ExperimentStatusCell/index.less @@ -0,0 +1,12 @@ +.experiment-status-cell { + height: 100%; + &__label { + display: none; + } +} + +.experiment-status-cell:hover { + .experiment-status-cell__label { + display: inline; + } +} diff --git a/react-ui/src/pages/Experiment/components/ExperimentStatusCell/index.tsx b/react-ui/src/pages/Experiment/components/ExperimentStatusCell/index.tsx new file mode 100644 index 00000000..373b6ffa --- /dev/null +++ b/react-ui/src/pages/Experiment/components/ExperimentStatusCell/index.tsx @@ -0,0 +1,28 @@ +/* + * @Author: 赵伟 + * @Date: 2024-04-18 18:35:41 + * @Description: 实验状态 + */ + +import { ExperimentStatus } from '@/enums'; +import { experimentStatusInfo as statusInfo } from '@/pages/Experiment/status'; +import styles from './index.less'; + +function ExperimentStatusCell(status?: ExperimentStatus | null) { + if (status === null || status === undefined || !statusInfo[status]) { + return --; + } + return ( +
+ + + {statusInfo[status]?.label} + +
+ ); +} + +export default ExperimentStatusCell; diff --git a/react-ui/src/pages/Experiment/components/LogGroup/index.tsx b/react-ui/src/pages/Experiment/components/LogGroup/index.tsx index a3da3044..cd84406d 100644 --- a/react-ui/src/pages/Experiment/components/LogGroup/index.tsx +++ b/react-ui/src/pages/Experiment/components/LogGroup/index.tsx @@ -4,9 +4,9 @@ * @Description: 日志组件 */ +import { ExperimentStatus } from '@/enums'; import { useStateRef } from '@/hooks'; -import { ExperimentStatus } from '@/pages/Experiment/status'; -import { ExperimentLog } from '@/pages/Experiment/training/props'; +import { ExperimentLog } from '@/pages/Experiment/Info/props'; import { getExperimentPodsLog } from '@/services/experiment/index.js'; import { DoubleRightOutlined, DownOutlined, UpOutlined } from '@ant-design/icons'; import { Button } from 'antd'; @@ -47,7 +47,7 @@ function LogGroup({ const [logList, setLogList, logListRef] = useStateRef([]); const [completed, setCompleted] = useState(false); // eslint-disable-next-line @typescript-eslint/no-unused-vars - const [isMouseDown, setIsMouseDown, isMouseDownRef] = useStateRef(false); + const [_isMouseDown, setIsMouseDown, isMouseDownRef] = useStateRef(false); useEffect(() => { scrollToBottom(false); diff --git a/react-ui/src/pages/Experiment/components/LogList/index.tsx b/react-ui/src/pages/Experiment/components/LogList/index.tsx index 8a2ade14..0d833107 100644 --- a/react-ui/src/pages/Experiment/components/LogList/index.tsx +++ b/react-ui/src/pages/Experiment/components/LogList/index.tsx @@ -1,5 +1,5 @@ -import { ExperimentStatus } from '@/pages/Experiment/status'; -import { ExperimentLog } from '@/pages/Experiment/training/props'; +import { ExperimentStatus } from '@/enums'; +import { ExperimentLog } from '@/pages/Experiment/Info/props'; import LogGroup from '../LogGroup'; import styles from './index.less'; diff --git a/react-ui/src/pages/Experiment/components/TensorBoardStatus/index.tsx b/react-ui/src/pages/Experiment/components/TensorBoardStatus/index.tsx index d8dad901..8a6f5b7c 100644 --- a/react-ui/src/pages/Experiment/components/TensorBoardStatus/index.tsx +++ b/react-ui/src/pages/Experiment/components/TensorBoardStatus/index.tsx @@ -5,16 +5,15 @@ import classNames from 'classnames'; import styles from './index.less'; // import stopImg from '@/assets/img/tensor-board-stop.png'; import terminatedImg from '@/assets/img/tensor-board-terminated.png'; +import { TensorBoardStatus } from '@/enums'; -export enum TensorBoardStatusEnum { - Unknown = 'Unknown', // 未知 - Pending = 'Pending', // 启动中 - Running = 'Running', // 运行中 - Terminated = 'Terminated', // 未启动或者已终止 - Failed = 'Failed', // 失败 -} +type TensorBoardStatusInfo = { + label: string; + icon: string; + classname: string; +}; -const statusConfig = { +const statusConfig: Record = { Unknown: { label: '未知', icon: terminatedImg, @@ -43,12 +42,12 @@ const statusConfig = { }; type TensorBoardStatusProps = { - status: TensorBoardStatusEnum; + status: TensorBoardStatus; onClick: () => void; }; -function TensorBoardStatus({ - status = TensorBoardStatusEnum.Unknown, +function TensorBoardStatusCell({ + status = TensorBoardStatus.Unknown, onClick, }: TensorBoardStatusProps) { return ( @@ -64,7 +63,7 @@ function TensorBoardStatus({ {statusConfig[status].icon ? ( <>
|
- {status === TensorBoardStatusEnum.Pending ? ( + {status === TensorBoardStatus.Pending ? ( ) : ( { if ( - experimentIn.tensorBoardStatus === TensorBoardStatusEnum.Terminated || - experimentIn.tensorBoardStatus === TensorBoardStatusEnum.Failed + experimentIn.tensorBoardStatus === TensorBoardStatus.Terminated || + experimentIn.tensorBoardStatus === TensorBoardStatus.Failed ) { await runTensorBoard(experimentIn); } else if ( - experimentIn.tensorBoardStatus === TensorBoardStatusEnum.Running && + experimentIn.tensorBoardStatus === TensorBoardStatus.Running && experimentIn.tensorboardUrl ) { window.open(experimentIn.tensorboardUrl, '_blank'); @@ -457,12 +458,12 @@ function Experiment() {
{item.nodes_result?.tensorboard_log ? ( - handleTensorboard(item)} - > + > ) : ( - '-' + '--' )}
diff --git a/react-ui/src/pages/Experiment/status.ts b/react-ui/src/pages/Experiment/status.ts index b9c45af6..02b68f53 100644 --- a/react-ui/src/pages/Experiment/status.ts +++ b/react-ui/src/pages/Experiment/status.ts @@ -1,23 +1,13 @@ +import { ExperimentStatus } from '@/enums'; import themes from '@/styles/theme.less'; -export interface StatusInfo { +export interface ExperimentStatusInfo { label: string; color: string; icon: string; } -export enum ExperimentStatus { - Running = 'Running', - Succeeded = 'Succeeded', - Pending = 'Pending', - Failed = 'Failed', - Error = 'Error', - Terminated = 'Terminated', - Skipped = 'Skipped', - Omitted = 'Omitted', -} - -export const experimentStatusInfo: Record = { +export const experimentStatusInfo: Record = { Running: { label: '运行中', color: themes.primaryColor, diff --git a/react-ui/src/pages/ModelDeployment/Info/index.tsx b/react-ui/src/pages/ModelDeployment/Info/index.tsx index a548e93a..4f73f46f 100644 --- a/react-ui/src/pages/ModelDeployment/Info/index.tsx +++ b/react-ui/src/pages/ModelDeployment/Info/index.tsx @@ -17,9 +17,9 @@ import { ModelDeploymentData } from '../types'; import styles from './index.less'; export enum ModelDeploymentTabKey { - Predict = 'Predict', - Guide = 'Guide', - Log = 'Log', + Predict = 'Predict', // 预测 + Guide = 'Guide', // 调用指南 + Log = 'Log', // 服务日志 } function ModelDeploymentInfo() { diff --git a/react-ui/src/pages/ModelDeployment/List/index.tsx b/react-ui/src/pages/ModelDeployment/List/index.tsx index bc1f03dd..934b4cbd 100644 --- a/react-ui/src/pages/ModelDeployment/List/index.tsx +++ b/react-ui/src/pages/ModelDeployment/List/index.tsx @@ -166,7 +166,7 @@ function ModelDeployment() { }; // 分页切换 - const handleTableChange: TableProps['onChange'] = (pagination, filters, sorter, { action }) => { + const handleTableChange: TableProps['onChange'] = (pagination, _filters, _sorter, { action }) => { if (action === 'paginate') { setPagination(pagination); } @@ -179,7 +179,7 @@ function ModelDeployment() { dataIndex: 'index', key: 'index', width: '20%', - render(text, record, index) { + render(_text, _record, index) { return {(pagination.current! - 1) * pagination.pageSize! + index + 1}; }, }, diff --git a/react-ui/src/pages/ModelDeployment/types.ts b/react-ui/src/pages/ModelDeployment/types.ts index c8dfe808..4bdf28c8 100644 --- a/react-ui/src/pages/ModelDeployment/types.ts +++ b/react-ui/src/pages/ModelDeployment/types.ts @@ -26,7 +26,7 @@ export type ModelDeploymentData = { // 操作类型 export enum ModelDeploymentOperationType { - Create = 'Create', - Update = 'Update', - Restart = 'Restart', + Create = 'Create', // 创建 + Update = 'Update', // 更新 + Restart = 'Restart', // 重启 } diff --git a/react-ui/src/services/experiment/index.js b/react-ui/src/services/experiment/index.js index 89028b24..66d27eb6 100644 --- a/react-ui/src/services/experiment/index.js +++ b/react-ui/src/services/experiment/index.js @@ -126,7 +126,15 @@ export function getExpEvaluateInfosReq(experimentId) { // 获取当前实验的模型训练指标信息 export function getExpTrainInfosReq(experimentId) { - return request(`/api/mmp//aim/getExpTrainInfos/${experimentId}`, { + return request(`/api/mmp/aim/getExpTrainInfos/${experimentId}`, { method: 'GET', }); } + +// 获取当前实验的指标对比地址 +export function getExpMetricsReq(data) { + return request(`/api/mmp/aim/getExpMetrics`, { + method: 'POST', + data + }); +} diff --git a/react-ui/src/types.ts b/react-ui/src/types.ts index 57dc3856..aa9b0e5e 100644 --- a/react-ui/src/types.ts +++ b/react-ui/src/types.ts @@ -4,7 +4,7 @@ * @Description: 定义全局类型,比如无关联的页面都需要要的类型 */ -import { ExperimentStatus } from '@/pages/Experiment/status'; +import { ExperimentStatus } from '@/enums'; // 流水线全局参数 export type PipelineGlobalParam = { diff --git a/react-ui/src/utils/table.tsx b/react-ui/src/utils/table.tsx new file mode 100644 index 00000000..3058284a --- /dev/null +++ b/react-ui/src/utils/table.tsx @@ -0,0 +1,68 @@ +/* + * @Author: 赵伟 + * @Date: 2024-06-26 10:05:52 + * @Description: 列表自定义 render + */ + +import { formatDate } from '@/utils/date'; +import { Tooltip } from 'antd'; +import dayjs from 'dayjs'; + +type TableCellFormatter = (value?: any | null) => string | undefined | null; + +// 字符串转换函数 +export const stringFormatter: TableCellFormatter = (value?: any | null) => { + return value; +}; + +// 日期转换函数 +export const dateFormatter: TableCellFormatter = (value?: any | null) => { + if (value === undefined || value === null || value === '') { + return null; + } + if (!dayjs(value).isValid()) { + return null; + } + return formatDate(value); +}; + +// 数组转换函数 +export function arrayFormatter(property?: string) { + return (value?: any | null): ReturnType => { + if ( + value === undefined || + value === null || + Array.isArray(value) === false || + value.length === 0 + ) { + return null; + } + + let list = value; + if (property && typeof value[0] === 'object') { + list = value.map((item) => item[property]); + } + return list.join(','); + }; +} + +function tableCellRender(ellipsis: boolean = false, format: TableCellFormatter = stringFormatter) { + return (value?: any | null) => { + const text = format(value); + if (ellipsis && text) { + return ( + + {renderCell(text)} + + ); + } else { + return renderCell(text); + } + }; +} + +function renderCell(text?: any | null) { + return {text ?? '--'}; +} + +export default tableCellRender; diff --git a/ruoyi-modules/management-platform/pom.xml b/ruoyi-modules/management-platform/pom.xml index 4cb1ae7d..37234ad0 100644 --- a/ruoyi-modules/management-platform/pom.xml +++ b/ruoyi-modules/management-platform/pom.xml @@ -205,6 +205,17 @@ org.springframework.boot spring-boot-starter-websocket + + org.json + json + 20210307 + + + org.apache.dubbo + dubbo + 3.0.8 + compile + diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/aim/AimController.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/aim/AimController.java index f1750133..f6b0b863 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/aim/AimController.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/aim/AimController.java @@ -4,16 +4,16 @@ import com.ruoyi.common.core.web.controller.BaseController; import com.ruoyi.common.core.web.domain.GenericsAjaxResult; import com.ruoyi.platform.service.AimService; import com.ruoyi.platform.vo.FrameLogPathVo; +import com.ruoyi.platform.vo.InsMetricInfoVo; import com.ruoyi.platform.vo.PodStatusVo; import io.swagger.annotations.Api; import io.swagger.annotations.ApiOperation; import io.swagger.v3.oas.annotations.responses.ApiResponse; -import org.springframework.web.bind.annotation.PostMapping; -import org.springframework.web.bind.annotation.RequestBody; -import org.springframework.web.bind.annotation.RequestMapping; -import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.bind.annotation.*; import javax.annotation.Resource; +import java.util.List; + @RestController @RequestMapping("aim") @Api("Aim管理") @@ -22,17 +22,25 @@ public class AimController extends BaseController { @Resource private AimService aimService; - /** - * 启动tensorBoard接口 - * - * @param frameLogPathVo 存储路径 - * @return url - */ - @PostMapping("/run") - @ApiOperation("启动aim`") + + @GetMapping("/getExpTrainInfos/{experiment_id}") + @ApiOperation("获取当前实验的模型训练指标信息") @ApiResponse - public GenericsAjaxResult runAim(@RequestBody FrameLogPathVo frameLogPathVo) throws Exception { - return genericsSuccess(aimService.runAim(frameLogPathVo)); + public GenericsAjaxResult> getExpTrainInfos(@PathVariable("experiment_id") Integer experimentId) throws Exception { + return genericsSuccess(aimService.getExpTrainInfos(experimentId)); } + @GetMapping("/getExpEvaluateInfos/{experiment_id}") + @ApiOperation("获取当前实验的模型推理指标信息") + @ApiResponse + public GenericsAjaxResult> getExpEvaluateInfos(@PathVariable("experiment_id") Integer experimentId) throws Exception { + return genericsSuccess(aimService.getExpEvaluateInfos(experimentId)); + } + + @PostMapping("/getExpMetrics") + @ApiOperation("获取当前实验的指标对比地址") + @ApiResponse + public GenericsAjaxResult getExpMetrics(@RequestBody List runIds) throws Exception { + return genericsSuccess(aimService.getExpMetrics(runIds)); + } } \ No newline at end of file diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/jupyter/JupyterController.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/jupyter/JupyterController.java index 0abcae75..3167e2ec 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/jupyter/JupyterController.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/jupyter/JupyterController.java @@ -3,6 +3,7 @@ package com.ruoyi.platform.controller.jupyter; import com.ruoyi.common.core.web.controller.BaseController; import com.ruoyi.common.core.web.domain.AjaxResult; import com.ruoyi.common.core.web.domain.GenericsAjaxResult; +import com.ruoyi.platform.domain.DevEnvironment; import com.ruoyi.platform.service.JupyterService; import com.ruoyi.platform.vo.FrameLogPathVo; import com.ruoyi.platform.vo.PodStatusVo; @@ -60,8 +61,8 @@ public class JupyterController extends BaseController { @PostMapping("/getStatus") @ApiOperation("查询jupyter pod状态") @ApiResponse - public GenericsAjaxResult getStatus(@RequestBody FrameLogPathVo frameLogPathVo) throws Exception { - return genericsSuccess(this.jupyterService.getJupyterStatus(frameLogPathVo)); + public GenericsAjaxResult getStatus(DevEnvironment devEnvironment) throws Exception { + return genericsSuccess(this.jupyterService.getJupyterStatus(devEnvironment)); } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/mapper/ModelDependencyDao.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/mapper/ModelDependencyDao.java index 3a999886..ba1bc40b 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/mapper/ModelDependencyDao.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/mapper/ModelDependencyDao.java @@ -84,5 +84,9 @@ public interface ModelDependencyDao { List queryByModelDependency(@Param("modelDependency") ModelDependency modelDependency); List queryChildrenByVersionId(@Param("model_id")String modelId, @Param("version")String version); + + List queryByIns(@Param("expInsId")Integer expInsId); + + ModelDependency queryByInsAndTrainTaskId(@Param("expInsId")Integer expInsId,@Param("taskId") String taskId); } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/ExperimentInstanceStatusTask.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/ExperimentInstanceStatusTask.java index 4680285e..131dca48 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/ExperimentInstanceStatusTask.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/ExperimentInstanceStatusTask.java @@ -7,20 +7,15 @@ import com.ruoyi.platform.mapper.ExperimentDao; import com.ruoyi.platform.mapper.ExperimentInsDao; import com.ruoyi.platform.mapper.ModelDependencyDao; import com.ruoyi.platform.service.ExperimentInsService; -import com.ruoyi.platform.service.ModelDependencyService; import com.ruoyi.platform.utils.JacksonUtil; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.data.domain.Page; import org.springframework.scheduling.annotation.Scheduled; import org.springframework.stereotype.Component; import javax.annotation.Resource; import java.io.IOException; -import java.util.ArrayList; -import java.util.Date; -import java.util.List; -import java.util.Map; +import java.util.*; @Component() public class ExperimentInstanceStatusTask { @@ -34,7 +29,7 @@ public class ExperimentInstanceStatusTask { private ModelDependencyDao modelDependencyDao; private List experimentIds = new ArrayList<>(); - @Scheduled(cron = "0/14 * * * * ?") // 每30S执行一次 + @Scheduled(cron = "0/30 * * * * ?") // 每30S执行一次 public void executeExperimentInsStatus() throws IOException { // 首先查到所有非终止态的实验实例 List experimentInsList = experimentInsService.queryByExperimentIsNotTerminated(); @@ -46,95 +41,94 @@ public class ExperimentInstanceStatusTask { String oldStatus = experimentIns.getStatus(); try { experimentIns = experimentInsService.queryStatusFromArgo(experimentIns); - }catch (Exception e){ + } catch (Exception e) { experimentIns.setStatus("Failed"); } -// if (!StringUtils.equals(oldStatus,experimentIns.getStatus())){ - experimentIns.setUpdateTime(new Date()); - // 线程安全的添加操作 - synchronized (experimentIds) { - experimentIds.add(experimentIns.getExperimentId()); - } - updateList.add(experimentIns); - -// } -// experimentInsDao.update(experimentIns); + experimentIns.setUpdateTime(new Date()); + // 线程安全的添加操作 + synchronized (experimentIds) { + experimentIds.add(experimentIns.getExperimentId()); + } + updateList.add(experimentIns); } - } - if (updateList.size() > 0){ + if (updateList.size() > 0) { experimentInsDao.insertOrUpdateBatch(updateList); //遍历模型关系表,找到 List modelDependencyList = new ArrayList(); - for (ExperimentIns experimentIns : updateList){ + for (ExperimentIns experimentIns : updateList) { ModelDependency modelDependencyquery = new ModelDependency(); modelDependencyquery.setExpInsId(experimentIns.getId()); modelDependencyquery.setState(2); List modelDependencyListquery = modelDependencyDao.queryByModelDependency(modelDependencyquery); - if (modelDependencyListquery==null||modelDependencyListquery.size()==0){ + if (modelDependencyListquery == null || modelDependencyListquery.size() == 0) { continue; } ModelDependency modelDependency = modelDependencyListquery.get(0); //查看状态, - if (StringUtils.equals("Failed",experimentIns.getStatus())){ + if (StringUtils.equals("Failed", experimentIns.getStatus())) { //取出节点状态 String trainTask = modelDependency.getTrainTask(); Map trainMap = JacksonUtil.parseJSONStr2Map(trainTask); String task_id = (String) trainMap.get("task_id"); - if (StringUtils.isEmpty(task_id)){ + if (StringUtils.isEmpty(task_id)) { continue; } String nodesStatus = experimentIns.getNodesStatus(); Map nodeMaps = JacksonUtil.parseJSONStr2Map(nodesStatus); Map nodeMap = JacksonUtil.parseJSONStr2Map(JacksonUtil.toJSONString(nodeMaps.get(task_id))); - if (nodeMap==null){ + if (nodeMap == null) { continue; } - if (!StringUtils.equals("Succeeded",(String)nodeMap.get("phase"))){ + if (!StringUtils.equals("Succeeded", (String) nodeMap.get("phase"))) { modelDependency.setState(0); modelDependencyList.add(modelDependency); } } } - if (modelDependencyList.size()>0) { + if (modelDependencyList.size() > 0) { modelDependencyDao.insertOrUpdateBatch(modelDependencyList); } } - } - @Scheduled(cron = "0/17 * * * * ?") // / 每30S执行一次 + + @Scheduled(cron = "0/30 * * * * ?") // / 每30S执行一次 public void executeExperimentStatus() throws IOException { - if (experimentIds.size()==0){ + if (experimentIds.size() == 0) { return; } // 存储需要更新的实验对象列表 List updateExperiments = new ArrayList<>(); - for (Integer experimentId : experimentIds){ + for (Integer experimentId : experimentIds) { // 获取当前实验的所有实例列表 List insList = experimentInsService.getByExperimentId(experimentId); List statusList = new ArrayList(); // 更新实验状态列表 - for (int i=0;i iterator = experimentIds.iterator(); + while (iterator.hasNext()) { + Integer experimentId = iterator.next(); + for (Experiment experiment : updateExperiments) { + if (experiment.getId().equals(experimentId)) { + iterator.remove(); + } } } } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AimService.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AimService.java index 60f74b90..c83a42af 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AimService.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AimService.java @@ -1,7 +1,14 @@ package com.ruoyi.platform.service; -import com.ruoyi.platform.vo.FrameLogPathVo; +import com.ruoyi.platform.vo.InsMetricInfoVo; + +import java.util.List; public interface AimService { - String runAim(FrameLogPathVo frameLogPathVo); + + List getExpTrainInfos(Integer experimentId) throws Exception; + + List getExpEvaluateInfos(Integer experimentId) throws Exception; + + String getExpMetrics(List runIds) throws Exception; } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/JupyterService.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/JupyterService.java index 09a7fc68..b2af9cca 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/JupyterService.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/JupyterService.java @@ -1,6 +1,6 @@ package com.ruoyi.platform.service; -import com.ruoyi.platform.vo.FrameLogPathVo; +import com.ruoyi.platform.domain.DevEnvironment; import com.ruoyi.platform.vo.PodStatusVo; import java.io.InputStream; @@ -16,5 +16,5 @@ public interface JupyterService { String stopJupyterService(Integer id) throws Exception; - PodStatusVo getJupyterStatus(FrameLogPathVo frameLogPathVo); + PodStatusVo getJupyterStatus(DevEnvironment devEnvironment) throws Exception; } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/ModelDependencyService.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/ModelDependencyService.java index 5c8b9d1d..049d87d1 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/ModelDependencyService.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/ModelDependencyService.java @@ -62,4 +62,8 @@ public interface ModelDependencyService { List queryByModelDependency(ModelDependency modelDependency) throws IOException; ModelDependcyTreeVo getModelDependencyTree(ModelDependency modelDependency) throws Exception; + + List queryByIns(Integer expInsId); + + ModelDependency queryByInsAndTrainTaskId(Integer expInsId, String taskId); } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java index f66dc178..fdd7b5c9 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java @@ -1,13 +1,157 @@ package com.ruoyi.platform.service.impl; +import com.ruoyi.platform.domain.ExperimentIns; +import com.ruoyi.platform.domain.ModelDependency; import com.ruoyi.platform.service.AimService; -import com.ruoyi.platform.vo.FrameLogPathVo; +import com.ruoyi.platform.service.ExperimentInsService; +import com.ruoyi.platform.service.ModelDependencyService; +import com.ruoyi.platform.utils.AIM64EncoderUtil; +import com.ruoyi.platform.utils.HttpUtils; +import com.ruoyi.platform.utils.JacksonUtil; +import com.ruoyi.platform.utils.JsonUtils; +import com.ruoyi.platform.vo.InsMetricInfoVo; +import org.apache.commons.lang3.StringUtils; +import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; +import javax.annotation.Resource; +import java.net.URLEncoder; +import java.util.*; +import java.util.stream.Collectors; + @Service public class AimServiceImpl implements AimService { + @Resource + private ExperimentInsService experimentInsService; + + @Value("${aim.url}") + private String aimUrl; + @Value("${aim.proxyUrl}") + private String aimProxyUrl; + + @Override + public List getExpTrainInfos(Integer experimentId) throws Exception { + return getAimRunInfos(true,experimentId); + } + @Override - public String runAim(FrameLogPathVo frameLogPathVo) { - return null; + public List getExpEvaluateInfos(Integer experimentId) throws Exception { + return getAimRunInfos(false,experimentId); + } + + @Override + public String getExpMetrics(List runIds) throws Exception { + String decode = AIM64EncoderUtil.decode(runIds); + return aimUrl+"/metrics?select="+decode; + } + + private List getAimRunInfos(boolean isTrain,Integer experimentId) throws Exception { + String experimentName = "experiment-"+experimentId+"-train"; + if (!isTrain){ + experimentName = "experiment-"+experimentId+"-evaluate"; + } + String encodedUrlString = URLEncoder.encode("run.experiment==\""+experimentName+"\"", "UTF-8"); + String url = aimProxyUrl+"/api/runs/search/run?query="+encodedUrlString; + String s = HttpUtils.sendGetRequest(url); + List> response = JacksonUtil.parseJSONStr2MapList(s); + System.out.println("response: "+JacksonUtil.toJSONString(response)); + if (response == null || response.size() == 0){ + return new ArrayList<>(); + } + //查询实例数据 + List byExperimentId = experimentInsService.getByExperimentId(experimentId); + + if (byExperimentId == null || byExperimentId.size() == 0){ + return new ArrayList<>(); + } + List aimRunInfoList = new ArrayList<>(); + for (Map run : response) { + InsMetricInfoVo aimRunInfo = new InsMetricInfoVo(); + String runHash = (String) run.get("run_hash"); + + aimRunInfo.setRunId(runHash); + + Map params= (Map) run.get("params"); + Map paramMap = JsonUtils.flattenJson("", params); + aimRunInfo.setParams(paramMap); + String aimrunId = (String) paramMap.get("id"); + Map tracesMap= (Map) run.get("traces"); + List> metricList = (List>) tracesMap.get("metric"); + //过滤name为__system__开头的对象 + aimRunInfo.setMetrics(new HashMap<>()); + if (metricList != null && metricList.size() > 0){ + List> metricRelList = metricList.stream() + .filter(map -> !StringUtils.startsWith((String) map.get("name"),"__system__" )) + .collect(Collectors.toList()); + if (metricRelList!= null && metricRelList.size() > 0){ + Map relMetricMap = new HashMap<>(); + for (Map metricMap : metricRelList) { + relMetricMap.put((String)metricMap.get("name"), metricMap.get("last_value")); + } + aimRunInfo.setMetrics(relMetricMap); + } + } + //找到ins + for (ExperimentIns ins : byExperimentId) { + String metricRecordString = ins.getMetricRecord(); + if (StringUtils.isEmpty(metricRecordString)){ + continue; + } + if (metricRecordString.contains(aimrunId)){ + aimRunInfo.setExperimentInsId(ins.getId()); + aimRunInfo.setStatus(ins.getStatus()); + aimRunInfo.setStartTime(ins.getCreateTime()); + Map metricRecordMap = JacksonUtil.parseJSONStr2Map(metricRecordString); + if (isTrain){ + List> records = (List>) metricRecordMap.get("train"); + List datasetList = getTrainDateSet(records, aimrunId); + aimRunInfo.setDataset(datasetList); + }else { + List> records = (List>) metricRecordMap.get("evaluate"); + List datasetList = getTrainDateSet(records, aimrunId); + aimRunInfo.setDataset(datasetList); + } + } + } + aimRunInfoList.add(aimRunInfo); + } + //判断哪个最长 + + // 获取所有 metrics 的 key 的并集 + Set metricsKeys = (Set) aimRunInfoList.stream() + .map(InsMetricInfoVo::getMetrics) + .flatMap(metrics -> metrics.keySet().stream()) + .collect(Collectors.toSet()); + // 将并集赋值给每个 InsMetricInfoVo 的 metricsNames 属性 + aimRunInfoList.forEach(vo -> vo.setMetricsNames(new ArrayList<>(metricsKeys))); + + // 获取所有 params 的 key 的并集 + Set paramKeys = (Set) aimRunInfoList.stream() + .map(InsMetricInfoVo::getParams) + .flatMap(params -> params.keySet().stream()) + .collect(Collectors.toSet()); + // 将并集赋值给每个 InsMetricInfoVo 的 paramsNames 属性 + aimRunInfoList.forEach(vo -> vo.setParamsNames(new ArrayList<>(paramKeys))); + + return aimRunInfoList; + } + + + private List getTrainDateSet(List> records, String aimrunId){ + List datasetList = new ArrayList<>(); + for (Map record : records) { + if (StringUtils.equals(aimrunId, (String)record.get("run_id"))) { + List> datasets = (List>) record.get("datasets"); + if (datasets == null || datasets.size() == 0){ + continue; + } + for (Map dataset : datasets){ + String datasetName = (String) dataset.get("dataset_name")+":"+(String) dataset.get("dataset_version"); + datasetList.add(datasetName); + } + break; + } + } + return datasetList; } } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/DevEnvironmentServiceImpl.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/DevEnvironmentServiceImpl.java index 4b8973b9..7992f089 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/DevEnvironmentServiceImpl.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/DevEnvironmentServiceImpl.java @@ -2,14 +2,17 @@ package com.ruoyi.platform.service.impl; import com.ruoyi.common.security.utils.SecurityUtils; import com.ruoyi.platform.domain.DevEnvironment; +import com.ruoyi.platform.domain.PodStatus; import com.ruoyi.platform.mapper.DevEnvironmentDao; import com.ruoyi.platform.service.DevEnvironmentService; import com.ruoyi.platform.service.JupyterService; import com.ruoyi.platform.utils.JacksonUtil; import com.ruoyi.platform.vo.DevEnvironmentVo; +import com.ruoyi.platform.vo.PodStatusVo; import com.ruoyi.system.api.model.LoginUser; import io.kubernetes.client.openapi.models.V1PersistentVolumeClaim; import org.apache.commons.lang3.StringUtils; +import org.springframework.context.annotation.Lazy; import org.springframework.stereotype.Service; import org.springframework.data.domain.Page; import org.springframework.data.domain.PageImpl; @@ -17,6 +20,7 @@ import org.springframework.data.domain.PageRequest; import javax.annotation.Resource; import java.util.Date; +import java.util.List; import java.util.Map; /** @@ -30,6 +34,10 @@ public class DevEnvironmentServiceImpl implements DevEnvironmentService { @Resource private DevEnvironmentDao devEnvironmentDao; + @Resource + @Lazy + private JupyterService jupyterService; + /** * 通过ID查询单条数据 @@ -52,13 +60,29 @@ public class DevEnvironmentServiceImpl implements DevEnvironmentService { @Override public Page queryByPage(DevEnvironment devEnvironment, PageRequest pageRequest) { long total = this.devEnvironmentDao.count(devEnvironment); - return new PageImpl<>(this.devEnvironmentDao.queryAllByLimit(devEnvironment, pageRequest), pageRequest, total); + List devEnvironmentList = this.devEnvironmentDao.queryAllByLimit(devEnvironment, pageRequest); + + //查询每个开发环境的pod状态,注意:只有pod为非终止态时才去调状态接口 + devEnvironmentList.forEach(devEnv -> { + try{ + if (!devEnv.getStatus().equals(PodStatus.Terminated.getName()) && + !devEnv.getStatus().equals(PodStatus.Failed.getName())) { + PodStatusVo podStatusVo = this.jupyterService.getJupyterStatus(devEnv); + devEnv.setStatus(podStatusVo.getStatus()); + devEnv.setUrl(podStatusVo.getUrl()); + } + } catch (Exception e) { + devEnv.setStatus(PodStatus.Unknown.getName()); + } + }); + + return new PageImpl<>(devEnvironmentList, pageRequest, total); } /** * 新增数据 * - * @param devEnvironment 实例对象 + * @param devEnvironmentVo 实例对象 * @return 实例对象 */ @Override diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentInsServiceImpl.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentInsServiceImpl.java index 5a71ac66..ece171ac 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentInsServiceImpl.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentInsServiceImpl.java @@ -102,7 +102,6 @@ public class ExperimentInsServiceImpl implements ExperimentInsService { */ @Override public List getByExperimentId(Integer experimentId) throws IOException { - List experimentInsList = experimentInsDao.getByExperimentId(experimentId); //代码全部迁移至定时任务 //搞个标记,当状态改变才去改表 @@ -138,7 +137,7 @@ public class ExperimentInsServiceImpl implements ExperimentInsService { // experimentDao.update(experiment); // } - return experimentInsList; + return experimentInsDao.getByExperimentId(experimentId); } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentServiceImpl.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentServiceImpl.java index 893d23e7..388155d6 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentServiceImpl.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentServiceImpl.java @@ -251,12 +251,14 @@ public class ExperimentServiceImpl implements ExperimentService { if (data == null || MapUtils.isEmpty(data)) { throw new RuntimeException("Failed to run workflow."); } - //获取训练参数 - Map metricRecord = (Map) runResMap.get("metric_record"); + + Map metadata = (Map) data.get("metadata"); // 插入记录到实验实例表 ExperimentIns experimentIns = new ExperimentIns(); + //获取训练参数 + experimentIns.setExperimentId(experiment.getId()); experimentIns.setArgoInsNs((String) metadata.get("namespace")); experimentIns.setArgoInsName((String) metadata.get("name")); @@ -267,16 +269,25 @@ public class ExperimentServiceImpl implements ExperimentService { //替换argoInsName String outputString = JsonUtils.mapToJson(output); experimentIns.setNodesResult(outputString.replace("{{workflow.name}}", (String) metadata.get("name"))); - //插入ExperimentIns表中 - ExperimentIns insert = experimentInsService.insert(experimentIns); - //插入到模型依赖关系表 + //得到dependendcy Map converMap2 = JsonUtils.jsonToMap(JacksonUtil.replaceInAarry(convertRes, params)); Map dependendcy = (Map)converMap2.get("model_dependency"); Map trainInfo = (Map)converMap2.get("component_info"); - insertModelDependency(dependendcy,trainInfo,insert.getId(),experiment.getName()); + Map metricRecord = (Map) runResMap.get("metric_record"); + if (metricRecord != null){ + //把训练用的数据集也放进去 + addDatesetToMetric(metricRecord, trainInfo); + experimentIns.setMetricRecord(JacksonUtil.toJSONString(metricRecord)); + } + //插入ExperimentIns表中 + ExperimentIns insert = experimentInsService.insert(experimentIns); + //插入到模型依赖关系表 + if (dependendcy != null && trainInfo != null){ + insertModelDependency(dependendcy,trainInfo,insert.getId(),experiment.getName()); + } }catch (Exception e){ throw new RuntimeException(e); } @@ -284,6 +295,37 @@ public class ExperimentServiceImpl implements ExperimentService { experiment.setExperimentInsList(updatedExperimentInsList); return experiment; } + private void addDatesetToMetric(Map metricRecord, Map trainInfo) { + processMetricPart(metricRecord, trainInfo, "train", "model_train"); + processMetricPart(metricRecord, trainInfo, "evaluate", "model_evaluate"); + } + + private void processMetricPart(Map metricRecord, Map trainInfo, String metricKey, String trainInfoKey) { + List> metricList = (List>) metricRecord.get(metricKey); + if (metricList != null) { + for (Map metricRecordItem : metricList) { + String taskId = (String) metricRecordItem.get("task_id"); + Map trainInfoPart = (Map) trainInfo.get(trainInfoKey); + if (trainInfoPart != null) { + Map trainInfoDetails = (Map) trainInfoPart.get(taskId); + if (trainInfoDetails != null) { + List> datasets = (List>) trainInfoDetails.get("datasets"); + if (datasets != null) { + //查询名字再回填 + for (int i = 0; i < datasets.size(); i++) { + Dataset dataset = datasetService.queryById((Integer) datasets.get(i).get("dataset_id")); + datasets.get(i).put("dataset_name", dataset.getName()); + } + metricRecordItem.put("datasets", datasets); + } + } + } + } + } + } + + + private void insertModelDependency(Map dependendcy,Map trainInfo, Integer experimentInsId, String experimentName) throws Exception { Iterator> dependendcyIterator = dependendcy.entrySet().iterator(); Map modelTrain = (Map) trainInfo.get("model_train"); diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/JupyterServiceImpl.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/JupyterServiceImpl.java index 46071e7e..038e6d91 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/JupyterServiceImpl.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/JupyterServiceImpl.java @@ -1,6 +1,5 @@ package com.ruoyi.platform.service.impl; -import com.ruoyi.common.core.utils.StringUtils; import com.ruoyi.common.redis.service.RedisService; import com.ruoyi.common.security.utils.SecurityUtils; import com.ruoyi.platform.domain.DevEnvironment; @@ -12,7 +11,6 @@ import com.ruoyi.platform.utils.JacksonUtil; import com.ruoyi.platform.utils.K8sClientUtil; import com.ruoyi.platform.utils.MinioUtil; import com.ruoyi.platform.utils.MlflowUtil; -import com.ruoyi.platform.vo.FrameLogPathVo; import com.ruoyi.platform.vo.PodStatusVo; import com.ruoyi.system.api.model.LoginUser; import io.kubernetes.client.openapi.models.V1PersistentVolumeClaim; @@ -101,17 +99,17 @@ public class JupyterServiceImpl implements JupyterService { // 调用修改后的 createPod 方法,传入额外的参数 Integer podPort = k8sClientUtil.createConfiguredPod(podName, namespace, port, mountPath, pvc, image, minioPvcName, datasetPath, modelPath); - // 简单的延迟,以便 Pod 有时间启动 - Thread.sleep(2500); - //查询pod状态,更新到数据库 - String podStatus = k8sClientUtil.getPodStatus(podName, namespace); +// // 简单的延迟,以便 Pod 有时间启动 +// Thread.sleep(2500); +// //查询pod状态,更新到数据库 +// String podStatus = k8sClientUtil.getPodStatus(podName, namespace); String url = masterIp + ":" + podPort; - devEnvironment.setStatus(podStatus); + redisService.setCacheObject(podName,masterIp + ":" + podPort); + devEnvironment.setStatus("Pending"); devEnvironment.setUrl(url); this.devEnvironmentService.update(devEnvironment); return url ; - } @Override @@ -132,25 +130,25 @@ public class JupyterServiceImpl implements JupyterService { String deleteResult = k8sClientUtil.deletePod(podName, namespace); - devEnvironment.setStatus("Terminating"); + devEnvironment.setStatus("Terminated"); this.devEnvironmentService.update(devEnvironment); return deleteResult + ",编辑器已停止"; } @Override - public PodStatusVo getJupyterStatus(FrameLogPathVo frameLogPathVo) { + public PodStatusVo getJupyterStatus(DevEnvironment devEnvironment) throws Exception { String status = PodStatus.Terminated.getName(); PodStatusVo JupyterStatusVo = new PodStatusVo(); JupyterStatusVo.setStatus(status); - if(StringUtils.isEmpty(frameLogPathVo.getPath())){ + if (devEnvironment==null){ return JupyterStatusVo; } LoginUser loginUser = SecurityUtils.getLoginUser(); - String podName = loginUser.getUsername().toLowerCase() + "-editor-pod"; + String podName = loginUser.getUsername().toLowerCase() +"-editor-pod" + "-" + devEnvironment.getId(); try { // 查询相应pod状态 - String podStatus = k8sClientUtil.getPodStatus(podName, StringUtils.isEmpty(frameLogPathVo.getNamespace()) ? "default" : frameLogPathVo.getNamespace()); + String podStatus = k8sClientUtil.getPodStatus(podName, namespace); for (PodStatus s : PodStatus.values()) { if (s.getName().equals(podStatus)) { status = s.getName(); @@ -160,8 +158,6 @@ public class JupyterServiceImpl implements JupyterService { } catch (Exception e) { return JupyterStatusVo; - - } String url = redisService.getCacheObject(podName); JupyterStatusVo.setStatus(status); diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelDependencyServiceImpl.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelDependencyServiceImpl.java index f3c48ebb..572a66a5 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelDependencyServiceImpl.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelDependencyServiceImpl.java @@ -17,10 +17,7 @@ import org.springframework.data.domain.PageRequest; import javax.annotation.Resource; import java.io.IOException; -import java.util.ArrayList; -import java.util.Date; -import java.util.List; -import java.util.Map; +import java.util.*; import java.util.stream.Collectors; /** @@ -97,6 +94,16 @@ public class ModelDependencyServiceImpl implements ModelDependencyService { return modelDependcyTreeVo; } + @Override + public List queryByIns(Integer expInsId) { + return modelDependencyDao.queryByIns(expInsId); + } + + @Override + public ModelDependency queryByInsAndTrainTaskId(Integer expInsId, String taskId) { + return modelDependencyDao.queryByInsAndTrainTaskId(expInsId,taskId); + } + /** * 递归父模型 * @param modelDependcyTreeVo diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/AIM64EncoderUtil.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/AIM64EncoderUtil.java new file mode 100644 index 00000000..0cffc705 --- /dev/null +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/AIM64EncoderUtil.java @@ -0,0 +1,76 @@ +package com.ruoyi.platform.utils; + +import com.alibaba.fastjson.JSON; + +import java.util.Base64; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class AIM64EncoderUtil { + + private static final String AIM64_ENCODING_PREFIX = "O-"; + + private static final Map BS64_REPLACE_CHARACTERS_ENCODING = new HashMap<>(); + static { + BS64_REPLACE_CHARACTERS_ENCODING.put("=", ""); + BS64_REPLACE_CHARACTERS_ENCODING.put("+", "-"); + BS64_REPLACE_CHARACTERS_ENCODING.put("/", "_"); + } + + public static String aim64encode(Map value) { + String jsonEncoded = JSON.toJSONString(value); + String base64Encoded = Base64.getEncoder().encodeToString(jsonEncoded.getBytes()); + String aim64Encoded = base64Encoded; + for (Map.Entry entry : BS64_REPLACE_CHARACTERS_ENCODING.entrySet()) { + aim64Encoded = aim64Encoded.replace(entry.getKey(), entry.getValue()); + } + return AIM64_ENCODING_PREFIX + aim64Encoded; + } + + public static String encode(Map value, boolean oneWayHashing) { + if (oneWayHashing) { + return md5(JSON.toJSONString(value)); + } + return aim64encode(value); + } + + private static String md5(String input) { + try { + java.security.MessageDigest md = java.security.MessageDigest.getInstance("MD5"); + byte[] array = md.digest(input.getBytes()); + StringBuilder sb = new StringBuilder(); + for (byte b : array) { + sb.append(Integer.toHexString((b & 0xFF) | 0x100).substring(1, 3)); + } + return sb.toString(); + } catch (java.security.NoSuchAlgorithmException e) { + e.printStackTrace(); + } + return null; + } + + public static String decode(List runIds) { + // 确保 runIds 列表的大小为 3 + if (runIds == null || runIds.size() == 0) { + throw new IllegalArgumentException("runIds 不能为空"); + } + // 构建查询字符串 + StringBuilder queryBuilder = new StringBuilder("run.hash in ["); + for (int i = 0; i < runIds.size(); i++) { + if (i > 0) { + queryBuilder.append(","); + } + queryBuilder.append("\"").append(runIds.get(i)).append("\""); + } + queryBuilder.append("]"); + String query = queryBuilder.toString(); + Map map = new HashMap<>(); + map.put("query", query); + map.put("advancedMode", true); + map.put("advancedQuery", query); + + String searchQuery = encode(map, false); + return searchQuery; + } +} diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/HttpUtils.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/HttpUtils.java index b382eda9..910d9981 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/HttpUtils.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/HttpUtils.java @@ -25,6 +25,7 @@ import java.security.cert.X509Certificate; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.zip.GZIPInputStream; /** * HTTP请求工具类 @@ -447,4 +448,38 @@ public class HttpUtils { return true; } } + + public static String sendGetRequestgzip(String url) throws Exception { + String resultStr = null; + HttpGet httpGet = new HttpGet(url); + httpGet.setHeader("Content-Type", "application/json"); + httpGet.setHeader("Accept-Encoding", "gzip, deflate"); + try { + HttpResponse response = httpClient.execute(httpGet); + int responseCode = response.getStatusLine().getStatusCode(); + if (responseCode != 200) { + throw new IOException("HTTP request failed with response code: " + responseCode); + } + + // 获取响应内容 + InputStream responseStream = response.getEntity().getContent(); + // 检查响应是否被压缩 + if ("gzip".equalsIgnoreCase(response.getEntity().getContentEncoding().getValue())) { + responseStream = new GZIPInputStream(responseStream); + } + + // 读取解压缩后的内容 + byte[] buffer = new byte[1024]; + int len; + StringBuilder decompressedString = new StringBuilder(); + while ((len = responseStream.read(buffer)) > 0) { + decompressedString.append(new String(buffer, 0, len)); + } + + resultStr = decompressedString.toString(); + } catch (IOException e) { + e.printStackTrace(); + } + return resultStr; + } } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/JsonUtils.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/JsonUtils.java index e1b41780..186173eb 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/JsonUtils.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/JsonUtils.java @@ -1,8 +1,11 @@ package com.ruoyi.platform.utils; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; +import org.json.JSONObject; import java.io.IOException; +import java.util.HashMap; +import java.util.Iterator; import java.util.Map; public class JsonUtils { @@ -28,4 +31,26 @@ public class JsonUtils { public static T jsonToObject(String json, Class clazz) throws IOException { return objectMapper.readValue(json, clazz); } + + + + // 将JSON字符串转换为扁平化的Map + public static Map flattenJson(String prefix, Map map) { + Map flatMap = new HashMap<>(); + Iterator> entries = map.entrySet().iterator(); + + while (entries.hasNext()) { + Map.Entry entry = entries.next(); + String key = entry.getKey(); + Object value = entry.getValue(); + + if (value instanceof Map) { + flatMap.putAll(flattenJson(prefix + key + ".", (Map) value)); + } else { + flatMap.put(prefix + key, value); + } + } + + return flatMap; + } } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/InsMetricInfoVo.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/InsMetricInfoVo.java new file mode 100644 index 00000000..6fe8caa4 --- /dev/null +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/InsMetricInfoVo.java @@ -0,0 +1,32 @@ +package com.ruoyi.platform.vo; + +import com.fasterxml.jackson.databind.PropertyNamingStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; + +import java.io.Serializable; +import java.util.Date; +import java.util.List; +import java.util.Map; + +@JsonNaming(PropertyNamingStrategy.SnakeCaseStrategy.class) +@Data +public class InsMetricInfoVo implements Serializable { + @ApiModelProperty(value = "开始时间") + private Date startTime; + @ApiModelProperty(value = "实例运行状态") + private String status; + @ApiModelProperty(value = "使用数据集") + private List dataset; + @ApiModelProperty(value = "实例ID") + private Integer experimentInsId; + @ApiModelProperty(value = "训练指标") + private Map metrics; + @ApiModelProperty(value = "训练参数") + private Map params; + @ApiModelProperty(value = "训练记录ID") + private String runId; + private List metricsNames; + private List paramsNames; +} diff --git a/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/ModelDependencyDaoMapper.xml b/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/ModelDependencyDaoMapper.xml index 2cd5dd7a..ea592ee2 100644 --- a/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/ModelDependencyDaoMapper.xml +++ b/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/ModelDependencyDaoMapper.xml @@ -22,6 +22,22 @@ + +