| @@ -94,7 +94,7 @@ export default [ | |||
| { | |||
| name: '实验训练', | |||
| path: ':workflowId/:id', | |||
| component: './Experiment/training/index', | |||
| component: './Experiment/Info/index', | |||
| }, | |||
| { | |||
| name: '实验对比', | |||
| @@ -112,18 +112,18 @@ export default [ | |||
| { | |||
| name: '开发环境', | |||
| path: '', | |||
| component: './DevelopmentEnvironment/List', | |||
| }, | |||
| { | |||
| name: '创建编辑器', | |||
| path: 'create', | |||
| component: './DevelopmentEnvironment/Create', | |||
| }, | |||
| { | |||
| name: '编辑器', | |||
| path: 'editor', | |||
| component: './DevelopmentEnvironment/Editor', | |||
| }, | |||
| // { | |||
| // name: '创建编辑器', | |||
| // path: 'create', | |||
| // component: './DevelopmentEnvironment/Create', | |||
| // }, | |||
| // { | |||
| // name: '编辑器', | |||
| // path: 'editor', | |||
| // component: './DevelopmentEnvironment/Editor', | |||
| // }, | |||
| ], | |||
| }, | |||
| { | |||
| @@ -0,0 +1,37 @@ | |||
| /* | |||
| * @Author: 赵伟 | |||
| * @Date: 2024-04-28 14:18:11 | |||
| * @Description: 自定义 Table 数组类单元格 | |||
| */ | |||
| import { Tooltip } from 'antd'; | |||
| function ArrayTableCell(ellipsis: boolean = false, property?: string) { | |||
| return (value?: any | null) => { | |||
| if ( | |||
| value === undefined || | |||
| value === null || | |||
| Array.isArray(value) === false || | |||
| value.length === 0 | |||
| ) { | |||
| return <span>--</span>; | |||
| } | |||
| let list = value; | |||
| if (property && typeof value[0] === 'object') { | |||
| list = value.map((item) => item[property]); | |||
| } | |||
| const text = list.join(','); | |||
| if (ellipsis) { | |||
| return ( | |||
| <Tooltip title={text} placement="topLeft" overlayStyle={{ maxWidth: '400px' }}> | |||
| <span>{text}</span>; | |||
| </Tooltip> | |||
| ); | |||
| } else { | |||
| return <span>{text}</span>; | |||
| } | |||
| }; | |||
| } | |||
| export default ArrayTableCell; | |||
| @@ -6,13 +6,13 @@ | |||
| import { Tooltip } from 'antd'; | |||
| function renderCell(text?: string | null) { | |||
| function renderCell(text?: any | null) { | |||
| return <span>{text ?? '--'}</span>; | |||
| } | |||
| function CommonTableCell(ellipsis: boolean = false) { | |||
| if (ellipsis) { | |||
| return (text?: string | null) => ( | |||
| return (text?: any | null) => ( | |||
| <Tooltip title={text} placement="topLeft" overlayStyle={{ maxWidth: '400px' }}> | |||
| {renderCell(text)} | |||
| </Tooltip> | |||
| @@ -1,9 +1,36 @@ | |||
| /* | |||
| * @Author: 赵伟 | |||
| * @Date: 2024-06-07 11:22:28 | |||
| * @Description: 接口返回的枚举值和共用的枚举值定义在这里 | |||
| */ | |||
| // 公开还是私有 TabKey | |||
| export enum CommonTabKeys { | |||
| Private = 'Private', // 私有 | |||
| Public = 'Public', // 公开 | |||
| } | |||
| // 实验状态 | |||
| export enum ExperimentStatus { | |||
| Running = 'Running', // 运行中 | |||
| Succeeded = 'Succeeded', // 成功 | |||
| Pending = 'Pending', // 启动中 | |||
| Failed = 'Failed', // 失败 | |||
| Error = 'Error', // 错误 | |||
| Terminated = 'Terminated', // 终止 | |||
| Skipped = 'Skipped', // 跳过 | |||
| Omitted = 'Omitted', // 忽略 | |||
| } | |||
| // TensorBoard 状态 | |||
| export enum TensorBoardStatus { | |||
| Unknown = 'Unknown', // 未知 | |||
| Pending = 'Pending', // 启动中 | |||
| Running = 'Running', // 运行中 | |||
| Terminated = 'Terminated', // 未启动或者已终止 | |||
| Failed = 'Failed', // 失败 | |||
| } | |||
| // 镜像版本状态 | |||
| export enum MirrorVersionStatus { | |||
| Available = 'available', // 可用 | |||
| @@ -4,7 +4,7 @@ | |||
| * @Description: 覆盖 antd 样式 | |||
| */ | |||
| // 设置 Table 可以滑动 | |||
| // 设置 Table 可以滑动,带分页 | |||
| .vertical-scroll-table { | |||
| .ant-table-wrapper { | |||
| height: 100%; | |||
| @@ -30,6 +30,32 @@ | |||
| } | |||
| } | |||
| // 设置 Table 可以滑动,没有分页 | |||
| .vertical-scroll-table-no-page { | |||
| .ant-table-wrapper { | |||
| height: 100%; | |||
| .ant-spin-nested-loading { | |||
| height: 100%; | |||
| .ant-spin-container { | |||
| height: 100%; | |||
| .ant-table { | |||
| height: 100%; | |||
| .ant-table-container { | |||
| height: 100%; | |||
| .ant-table-body { | |||
| overflow-y: auto !important; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| // Tabs 样式 | |||
| // 删除底部白色横线 | |||
| .ant-tabs { | |||
| @@ -8,10 +8,11 @@ import { ResourceData, ResourceType, resourceConfig } from '../../config'; | |||
| import ResourceVersion from '../ResourceVersion'; | |||
| import styles from './index.less'; | |||
| // 这里值小写是因为值会写在 url 中 | |||
| export enum ResourceInfoTabKeys { | |||
| Introduction = 'introduction', | |||
| Version = 'version', | |||
| Evolution = 'evolution', | |||
| Introduction = 'introduction', // 简介 | |||
| Version = 'version', // 版本 | |||
| Evolution = 'evolution', // 演化 | |||
| } | |||
| type ResourceIntroProps = { | |||
| @@ -17,5 +17,17 @@ | |||
| padding: 20px 30px 0; | |||
| background-color: white; | |||
| border-radius: 10px; | |||
| :global { | |||
| .ant-table-container { | |||
| border: none !important; | |||
| } | |||
| .ant-table-tbody { | |||
| .ant-table-cell { | |||
| border-right: none !important; | |||
| border-left: none !important; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -1,32 +1,50 @@ | |||
| import CommonTableCell from '@/components/CommonTableCell'; | |||
| import { useCacheState } from '@/hooks/pageCacheState'; | |||
| import { getExpEvaluateInfosReq, getExpTrainInfosReq } from '@/services/experiment'; | |||
| // import { useCacheState } from '@/hooks/pageCacheState'; | |||
| import { | |||
| getExpEvaluateInfosReq, | |||
| getExpMetricsReq, | |||
| getExpTrainInfosReq, | |||
| } from '@/services/experiment'; | |||
| import { to } from '@/utils/promise'; | |||
| import tableCellRender, { arrayFormatter, dateFormatter } from '@/utils/table'; | |||
| import { useSearchParams } from '@umijs/max'; | |||
| import { Button, Table, TablePaginationConfig, TableProps } from 'antd'; | |||
| import { App, Button, Table, /*TablePaginationConfig,*/ TableProps } from 'antd'; | |||
| import classNames from 'classnames'; | |||
| import { useEffect, useState } from 'react'; | |||
| import { useEffect, useMemo, useState } from 'react'; | |||
| import ExperimentStatusCell from '../components/ExperimentStatusCell'; | |||
| import styles from './index.less'; | |||
| export enum ComparisonType { | |||
| Train = 'train', // 训练 | |||
| Evaluate = 'evaluate', // 评估 | |||
| Train = 'Train', // 训练 | |||
| Evaluate = 'Evaluate', // 评估 | |||
| } | |||
| type TableData = { | |||
| experiment_ins_id: number; | |||
| run_id: string; | |||
| dataset: string[]; | |||
| start_time: string; | |||
| status: string; | |||
| metrics_names: string[]; | |||
| metrics: Record<string, number>; | |||
| params_names: string[]; | |||
| params: Record<string, string>; | |||
| }; | |||
| function ExperimentComparison() { | |||
| const [searchParams] = useSearchParams(); | |||
| const comparisonType = searchParams.get('type'); | |||
| const experimentId = searchParams.get('experimentId'); | |||
| const [tableData, setTableData] = useState([]); | |||
| const [cacheState, setCacheState] = useCacheState(); | |||
| const [total, setTotal] = useState(0); | |||
| const experimentId = searchParams.get('id'); | |||
| const [tableData, setTableData] = useState<TableData[]>([]); | |||
| // const [cacheState, setCacheState] = useCacheState(); | |||
| // const [total, setTotal] = useState(0); | |||
| const [selectedRowKeys, setSelectedRowKeys] = useState<React.Key[]>([]); | |||
| const [pagination, setPagination] = useState<TablePaginationConfig>( | |||
| cacheState?.pagination ?? { | |||
| current: 1, | |||
| pageSize: 10, | |||
| }, | |||
| ); | |||
| const { message } = App.useApp(); | |||
| // const [pagination, setPagination] = useState<TablePaginationConfig>( | |||
| // cacheState?.pagination ?? { | |||
| // current: 1, | |||
| // pageSize: 10, | |||
| // }, | |||
| // ); | |||
| useEffect(() => { | |||
| getComparisonData(); | |||
| @@ -38,20 +56,39 @@ function ExperimentComparison() { | |||
| comparisonType === ComparisonType.Train ? getExpTrainInfosReq : getExpEvaluateInfosReq; | |||
| const [res] = await to(request(experimentId)); | |||
| if (res && res.data) { | |||
| const { content = [], totalElements = 0 } = res.data; | |||
| setTableData(content); | |||
| setTotal(totalElements); | |||
| // const { content = [], totalElements = 0 } = res.data; | |||
| setTableData(res.data); | |||
| // setTotal(totalElements); | |||
| } | |||
| }; | |||
| // 分页切换 | |||
| const handleTableChange: TableProps['onChange'] = (pagination, filters, sorter, { action }) => { | |||
| if (action === 'paginate') { | |||
| setPagination(pagination); | |||
| // 获取对比 url | |||
| const getExpMetrics = async () => { | |||
| const [res] = await to(getExpMetricsReq(selectedRowKeys)); | |||
| if (res && res.data) { | |||
| const url = res.data; | |||
| window.open(url, '_blank'); | |||
| } | |||
| // console.log(pagination, filters, sorter, action); | |||
| }; | |||
| // 对比按钮 click | |||
| const hanldeComparisonClick = () => { | |||
| if (selectedRowKeys.length < 2) { | |||
| message.error('请至少选择两项进行对比'); | |||
| return; | |||
| } | |||
| getExpMetrics(); | |||
| }; | |||
| // 分页切换 | |||
| // const handleTableChange: TableProps['onChange'] = (pagination, filters, sorter, { action }) => { | |||
| // if (action === 'paginate') { | |||
| // setPagination(pagination); | |||
| // } | |||
| // // console.log(pagination, filters, sorter, action); | |||
| // }; | |||
| // 选择行 | |||
| const rowSelection: TableProps['rowSelection'] = { | |||
| type: 'checkbox', | |||
| selectedRowKeys, | |||
| @@ -61,148 +98,96 @@ function ExperimentComparison() { | |||
| }, | |||
| }; | |||
| const columns: TableProps['columns'] = [ | |||
| { | |||
| title: '基本信息', | |||
| children: [ | |||
| { | |||
| title: '实例ID', | |||
| dataIndex: 'name', | |||
| key: 'name', | |||
| width: '30%', | |||
| render: CommonTableCell(), | |||
| }, | |||
| { | |||
| title: '运行时间', | |||
| dataIndex: 'name', | |||
| key: 'name', | |||
| width: '30%', | |||
| render: CommonTableCell(), | |||
| }, | |||
| { | |||
| title: '运行状态', | |||
| dataIndex: 'name', | |||
| key: 'name', | |||
| width: '30%', | |||
| render: CommonTableCell(), | |||
| }, | |||
| { | |||
| title: '训练数据集', | |||
| dataIndex: 'name', | |||
| key: 'name', | |||
| width: '30%', | |||
| render: CommonTableCell(), | |||
| }, | |||
| { | |||
| title: '增量训练', | |||
| dataIndex: 'name', | |||
| key: 'name', | |||
| width: '30%', | |||
| render: CommonTableCell(), | |||
| }, | |||
| ], | |||
| }, | |||
| { | |||
| title: '训练参数', | |||
| children: [ | |||
| { | |||
| title: 'batchsize', | |||
| dataIndex: 'name', | |||
| key: 'name', | |||
| width: '30%', | |||
| render: CommonTableCell(), | |||
| }, | |||
| { | |||
| title: 'config', | |||
| dataIndex: 'name', | |||
| key: 'name', | |||
| width: '30%', | |||
| render: CommonTableCell(), | |||
| }, | |||
| { | |||
| title: 'epoch', | |||
| dataIndex: 'name', | |||
| key: 'name', | |||
| width: '30%', | |||
| render: CommonTableCell(), | |||
| }, | |||
| { | |||
| title: 'lr', | |||
| dataIndex: 'name', | |||
| key: 'name', | |||
| width: '30%', | |||
| render: CommonTableCell(), | |||
| }, | |||
| { | |||
| title: 'warmup_iters', | |||
| dataIndex: 'name', | |||
| key: 'name', | |||
| width: '30%', | |||
| render: CommonTableCell(), | |||
| }, | |||
| ], | |||
| }, | |||
| { | |||
| title: '训练指标', | |||
| children: [ | |||
| { | |||
| title: 'metrc_name', | |||
| dataIndex: 'name', | |||
| key: 'name', | |||
| width: '30%', | |||
| render: CommonTableCell(), | |||
| }, | |||
| { | |||
| title: 'test_1', | |||
| dataIndex: 'name', | |||
| key: 'name', | |||
| width: '30%', | |||
| render: CommonTableCell(), | |||
| }, | |||
| { | |||
| title: 'test_2', | |||
| dataIndex: 'name', | |||
| key: 'name', | |||
| width: '30%', | |||
| render: CommonTableCell(), | |||
| }, | |||
| { | |||
| title: 'test_3', | |||
| dataIndex: 'name', | |||
| key: 'name', | |||
| width: '30%', | |||
| render: CommonTableCell(), | |||
| }, | |||
| { | |||
| title: 'test_4', | |||
| dataIndex: 'name', | |||
| key: 'name', | |||
| width: '30%', | |||
| render: CommonTableCell(), | |||
| }, | |||
| ], | |||
| }, | |||
| ]; | |||
| const columns: TableProps['columns'] = useMemo(() => { | |||
| const first: TableData | undefined = tableData[0]; | |||
| return [ | |||
| { | |||
| title: '基本信息', | |||
| children: [ | |||
| { | |||
| title: '实例 ID', | |||
| dataIndex: 'experiment_ins_id', | |||
| key: 'experiment_ins_id', | |||
| width: '20%', | |||
| render: tableCellRender(), | |||
| }, | |||
| { | |||
| title: '运行时间', | |||
| dataIndex: 'start_time', | |||
| key: 'start_time', | |||
| width: 180, | |||
| render: tableCellRender(false, dateFormatter), | |||
| }, | |||
| { | |||
| title: '运行状态', | |||
| dataIndex: 'status', | |||
| key: 'status', | |||
| width: '20%', | |||
| render: ExperimentStatusCell, | |||
| }, | |||
| { | |||
| title: '训练数据集', | |||
| dataIndex: 'dataset', | |||
| key: 'dataset', | |||
| width: '20%', | |||
| render: tableCellRender(true, arrayFormatter()), | |||
| ellipsis: { showTitle: false }, | |||
| }, | |||
| ], | |||
| }, | |||
| { | |||
| title: '训练参数', | |||
| children: first?.params_names.map((name) => ({ | |||
| title: name, | |||
| dataIndex: ['params', name], | |||
| key: name, | |||
| width: '20%', | |||
| render: tableCellRender(true), | |||
| ellipsis: { showTitle: false }, | |||
| })), | |||
| }, | |||
| { | |||
| title: '训练指标', | |||
| children: first?.metrics_names.map((name) => ({ | |||
| title: name, | |||
| dataIndex: ['metrics', name], | |||
| key: name, | |||
| width: '20%', | |||
| render: tableCellRender(true), | |||
| ellipsis: { showTitle: false }, | |||
| })), | |||
| }, | |||
| ]; | |||
| }, [tableData]); | |||
| return ( | |||
| <div className={styles['experiment-comparison']}> | |||
| <div className={styles['experiment-comparison__header']}> | |||
| <Button type="default">可视化对比</Button> | |||
| <Button type="default" onClick={hanldeComparisonClick}> | |||
| 可视化对比 | |||
| </Button> | |||
| </div> | |||
| <div className={classNames('vertical-scroll-table', styles['experiment-comparison__table'])}> | |||
| <div | |||
| className={classNames( | |||
| 'vertical-scroll-table-no-page', | |||
| styles['experiment-comparison__table'], | |||
| )} | |||
| > | |||
| <Table | |||
| dataSource={tableData} | |||
| columns={columns} | |||
| rowSelection={rowSelection} | |||
| scroll={{ y: 'calc(100% - 55px)' }} | |||
| pagination={{ | |||
| ...pagination, | |||
| total: total, | |||
| showSizeChanger: true, | |||
| showQuickJumper: true, | |||
| }} | |||
| onChange={handleTableChange} | |||
| rowKey="id" | |||
| pagination={false} | |||
| bordered={true} | |||
| // pagination={{ | |||
| // ...pagination, | |||
| // total: total, | |||
| // showSizeChanger: true, | |||
| // showQuickJumper: true, | |||
| // }} | |||
| // onChange={handleTableChange} | |||
| rowKey="run_id" | |||
| /> | |||
| </div> | |||
| </div> | |||
| @@ -0,0 +1,12 @@ | |||
| .experiment-status-cell { | |||
| height: 100%; | |||
| &__label { | |||
| display: none; | |||
| } | |||
| } | |||
| .experiment-status-cell:hover { | |||
| .experiment-status-cell__label { | |||
| display: inline; | |||
| } | |||
| } | |||
| @@ -0,0 +1,28 @@ | |||
| /* | |||
| * @Author: 赵伟 | |||
| * @Date: 2024-04-18 18:35:41 | |||
| * @Description: 实验状态 | |||
| */ | |||
| import { ExperimentStatus } from '@/enums'; | |||
| import { experimentStatusInfo as statusInfo } from '@/pages/Experiment/status'; | |||
| import styles from './index.less'; | |||
| function ExperimentStatusCell(status?: ExperimentStatus | null) { | |||
| if (status === null || status === undefined || !statusInfo[status]) { | |||
| return <span>--</span>; | |||
| } | |||
| return ( | |||
| <div className={styles['experiment-status-cell']}> | |||
| <img style={{ width: '17px', marginRight: '7px' }} src={statusInfo[status]?.icon} /> | |||
| <span | |||
| style={{ color: statusInfo[status]?.color }} | |||
| className={styles['experiment-status-cell__label']} | |||
| > | |||
| {statusInfo[status]?.label} | |||
| </span> | |||
| </div> | |||
| ); | |||
| } | |||
| export default ExperimentStatusCell; | |||
| @@ -4,9 +4,9 @@ | |||
| * @Description: 日志组件 | |||
| */ | |||
| import { ExperimentStatus } from '@/enums'; | |||
| import { useStateRef } from '@/hooks'; | |||
| import { ExperimentStatus } from '@/pages/Experiment/status'; | |||
| import { ExperimentLog } from '@/pages/Experiment/training/props'; | |||
| import { ExperimentLog } from '@/pages/Experiment/Info/props'; | |||
| import { getExperimentPodsLog } from '@/services/experiment/index.js'; | |||
| import { DoubleRightOutlined, DownOutlined, UpOutlined } from '@ant-design/icons'; | |||
| import { Button } from 'antd'; | |||
| @@ -47,7 +47,7 @@ function LogGroup({ | |||
| const [logList, setLogList, logListRef] = useStateRef<Log[]>([]); | |||
| const [completed, setCompleted] = useState(false); | |||
| // eslint-disable-next-line @typescript-eslint/no-unused-vars | |||
| const [isMouseDown, setIsMouseDown, isMouseDownRef] = useStateRef(false); | |||
| const [_isMouseDown, setIsMouseDown, isMouseDownRef] = useStateRef(false); | |||
| useEffect(() => { | |||
| scrollToBottom(false); | |||
| @@ -1,5 +1,5 @@ | |||
| import { ExperimentStatus } from '@/pages/Experiment/status'; | |||
| import { ExperimentLog } from '@/pages/Experiment/training/props'; | |||
| import { ExperimentStatus } from '@/enums'; | |||
| import { ExperimentLog } from '@/pages/Experiment/Info/props'; | |||
| import LogGroup from '../LogGroup'; | |||
| import styles from './index.less'; | |||
| @@ -5,16 +5,15 @@ import classNames from 'classnames'; | |||
| import styles from './index.less'; | |||
| // import stopImg from '@/assets/img/tensor-board-stop.png'; | |||
| import terminatedImg from '@/assets/img/tensor-board-terminated.png'; | |||
| import { TensorBoardStatus } from '@/enums'; | |||
| export enum TensorBoardStatusEnum { | |||
| Unknown = 'Unknown', // 未知 | |||
| Pending = 'Pending', // 启动中 | |||
| Running = 'Running', // 运行中 | |||
| Terminated = 'Terminated', // 未启动或者已终止 | |||
| Failed = 'Failed', // 失败 | |||
| } | |||
| type TensorBoardStatusInfo = { | |||
| label: string; | |||
| icon: string; | |||
| classname: string; | |||
| }; | |||
| const statusConfig = { | |||
| const statusConfig: Record<TensorBoardStatus, TensorBoardStatusInfo> = { | |||
| Unknown: { | |||
| label: '未知', | |||
| icon: terminatedImg, | |||
| @@ -43,12 +42,12 @@ const statusConfig = { | |||
| }; | |||
| type TensorBoardStatusProps = { | |||
| status: TensorBoardStatusEnum; | |||
| status: TensorBoardStatus; | |||
| onClick: () => void; | |||
| }; | |||
| function TensorBoardStatus({ | |||
| status = TensorBoardStatusEnum.Unknown, | |||
| function TensorBoardStatusCell({ | |||
| status = TensorBoardStatus.Unknown, | |||
| onClick, | |||
| }: TensorBoardStatusProps) { | |||
| return ( | |||
| @@ -64,7 +63,7 @@ function TensorBoardStatus({ | |||
| {statusConfig[status].icon ? ( | |||
| <> | |||
| <div style={{ margin: '0 6px' }}>|</div> | |||
| {status === TensorBoardStatusEnum.Pending ? ( | |||
| {status === TensorBoardStatus.Pending ? ( | |||
| <LoadingOutlined className={styles['tensorBoard-status__icon']} /> | |||
| ) : ( | |||
| <img | |||
| @@ -79,4 +78,4 @@ function TensorBoardStatus({ | |||
| ); | |||
| } | |||
| export default TensorBoardStatus; | |||
| export default TensorBoardStatusCell; | |||
| @@ -1,5 +1,6 @@ | |||
| import CommonTableCell from '@/components/CommonTableCell'; | |||
| import KFIcon from '@/components/KFIcon'; | |||
| import { TensorBoardStatus } from '@/enums'; | |||
| import { | |||
| deleteExperimentById, | |||
| deleteQueryByExperimentInsId, | |||
| @@ -24,7 +25,7 @@ import { useEffect, useRef, useState } from 'react'; | |||
| import { useNavigate } from 'react-router-dom'; | |||
| import { ComparisonType } from './Comparison'; | |||
| import AddExperimentModal from './components/AddExperimentModal'; | |||
| import TensorBoardStatus, { TensorBoardStatusEnum } from './components/TensorBoardStatus'; | |||
| import TensorBoardStatusCell from './components/TensorBoardStatus'; | |||
| import Styles from './index.less'; | |||
| import { experimentStatusInfo } from './status'; | |||
| @@ -260,12 +261,12 @@ function Experiment() { | |||
| const handleTensorboard = async (experimentIn) => { | |||
| if ( | |||
| experimentIn.tensorBoardStatus === TensorBoardStatusEnum.Terminated || | |||
| experimentIn.tensorBoardStatus === TensorBoardStatusEnum.Failed | |||
| experimentIn.tensorBoardStatus === TensorBoardStatus.Terminated || | |||
| experimentIn.tensorBoardStatus === TensorBoardStatus.Failed | |||
| ) { | |||
| await runTensorBoard(experimentIn); | |||
| } else if ( | |||
| experimentIn.tensorBoardStatus === TensorBoardStatusEnum.Running && | |||
| experimentIn.tensorBoardStatus === TensorBoardStatus.Running && | |||
| experimentIn.tensorboardUrl | |||
| ) { | |||
| window.open(experimentIn.tensorboardUrl, '_blank'); | |||
| @@ -457,12 +458,12 @@ function Experiment() { | |||
| </a> | |||
| <div className={Styles.tensorBoard}> | |||
| {item.nodes_result?.tensorboard_log ? ( | |||
| <TensorBoardStatus | |||
| <TensorBoardStatusCell | |||
| status={item.tensorBoardStatus} | |||
| onClick={() => handleTensorboard(item)} | |||
| ></TensorBoardStatus> | |||
| ></TensorBoardStatusCell> | |||
| ) : ( | |||
| '-' | |||
| '--' | |||
| )} | |||
| </div> | |||
| <div className={Styles.description}> | |||
| @@ -1,23 +1,13 @@ | |||
| import { ExperimentStatus } from '@/enums'; | |||
| import themes from '@/styles/theme.less'; | |||
| export interface StatusInfo { | |||
| export interface ExperimentStatusInfo { | |||
| label: string; | |||
| color: string; | |||
| icon: string; | |||
| } | |||
| export enum ExperimentStatus { | |||
| Running = 'Running', | |||
| Succeeded = 'Succeeded', | |||
| Pending = 'Pending', | |||
| Failed = 'Failed', | |||
| Error = 'Error', | |||
| Terminated = 'Terminated', | |||
| Skipped = 'Skipped', | |||
| Omitted = 'Omitted', | |||
| } | |||
| export const experimentStatusInfo: Record<ExperimentStatus, StatusInfo | undefined> = { | |||
| export const experimentStatusInfo: Record<ExperimentStatus, ExperimentStatusInfo> = { | |||
| Running: { | |||
| label: '运行中', | |||
| color: themes.primaryColor, | |||
| @@ -17,9 +17,9 @@ import { ModelDeploymentData } from '../types'; | |||
| import styles from './index.less'; | |||
| export enum ModelDeploymentTabKey { | |||
| Predict = 'Predict', | |||
| Guide = 'Guide', | |||
| Log = 'Log', | |||
| Predict = 'Predict', // 预测 | |||
| Guide = 'Guide', // 调用指南 | |||
| Log = 'Log', // 服务日志 | |||
| } | |||
| function ModelDeploymentInfo() { | |||
| @@ -166,7 +166,7 @@ function ModelDeployment() { | |||
| }; | |||
| // 分页切换 | |||
| const handleTableChange: TableProps['onChange'] = (pagination, filters, sorter, { action }) => { | |||
| const handleTableChange: TableProps['onChange'] = (pagination, _filters, _sorter, { action }) => { | |||
| if (action === 'paginate') { | |||
| setPagination(pagination); | |||
| } | |||
| @@ -179,7 +179,7 @@ function ModelDeployment() { | |||
| dataIndex: 'index', | |||
| key: 'index', | |||
| width: '20%', | |||
| render(text, record, index) { | |||
| render(_text, _record, index) { | |||
| return <span>{(pagination.current! - 1) * pagination.pageSize! + index + 1}</span>; | |||
| }, | |||
| }, | |||
| @@ -26,7 +26,7 @@ export type ModelDeploymentData = { | |||
| // 操作类型 | |||
| export enum ModelDeploymentOperationType { | |||
| Create = 'Create', | |||
| Update = 'Update', | |||
| Restart = 'Restart', | |||
| Create = 'Create', // 创建 | |||
| Update = 'Update', // 更新 | |||
| Restart = 'Restart', // 重启 | |||
| } | |||
| @@ -126,7 +126,15 @@ export function getExpEvaluateInfosReq(experimentId) { | |||
| // 获取当前实验的模型训练指标信息 | |||
| export function getExpTrainInfosReq(experimentId) { | |||
| return request(`/api/mmp//aim/getExpTrainInfos/${experimentId}`, { | |||
| return request(`/api/mmp/aim/getExpTrainInfos/${experimentId}`, { | |||
| method: 'GET', | |||
| }); | |||
| } | |||
| // 获取当前实验的指标对比地址 | |||
| export function getExpMetricsReq(data) { | |||
| return request(`/api/mmp/aim/getExpMetrics`, { | |||
| method: 'POST', | |||
| data | |||
| }); | |||
| } | |||
| @@ -4,7 +4,7 @@ | |||
| * @Description: 定义全局类型,比如无关联的页面都需要要的类型 | |||
| */ | |||
| import { ExperimentStatus } from '@/pages/Experiment/status'; | |||
| import { ExperimentStatus } from '@/enums'; | |||
| // 流水线全局参数 | |||
| export type PipelineGlobalParam = { | |||
| @@ -0,0 +1,68 @@ | |||
| /* | |||
| * @Author: 赵伟 | |||
| * @Date: 2024-06-26 10:05:52 | |||
| * @Description: 列表自定义 render | |||
| */ | |||
| import { formatDate } from '@/utils/date'; | |||
| import { Tooltip } from 'antd'; | |||
| import dayjs from 'dayjs'; | |||
| type TableCellFormatter = (value?: any | null) => string | undefined | null; | |||
| // 字符串转换函数 | |||
| export const stringFormatter: TableCellFormatter = (value?: any | null) => { | |||
| return value; | |||
| }; | |||
| // 日期转换函数 | |||
| export const dateFormatter: TableCellFormatter = (value?: any | null) => { | |||
| if (value === undefined || value === null || value === '') { | |||
| return null; | |||
| } | |||
| if (!dayjs(value).isValid()) { | |||
| return null; | |||
| } | |||
| return formatDate(value); | |||
| }; | |||
| // 数组转换函数 | |||
| export function arrayFormatter(property?: string) { | |||
| return (value?: any | null): ReturnType<TableCellFormatter> => { | |||
| if ( | |||
| value === undefined || | |||
| value === null || | |||
| Array.isArray(value) === false || | |||
| value.length === 0 | |||
| ) { | |||
| return null; | |||
| } | |||
| let list = value; | |||
| if (property && typeof value[0] === 'object') { | |||
| list = value.map((item) => item[property]); | |||
| } | |||
| return list.join(','); | |||
| }; | |||
| } | |||
| function tableCellRender(ellipsis: boolean = false, format: TableCellFormatter = stringFormatter) { | |||
| return (value?: any | null) => { | |||
| const text = format(value); | |||
| if (ellipsis && text) { | |||
| return ( | |||
| <Tooltip title={text} placement="topLeft" overlayStyle={{ maxWidth: '400px' }}> | |||
| {renderCell(text)} | |||
| </Tooltip> | |||
| ); | |||
| } else { | |||
| return renderCell(text); | |||
| } | |||
| }; | |||
| } | |||
| function renderCell(text?: any | null) { | |||
| return <span>{text ?? '--'}</span>; | |||
| } | |||
| export default tableCellRender; | |||
| @@ -205,6 +205,17 @@ | |||
| <groupId>org.springframework.boot</groupId> | |||
| <artifactId>spring-boot-starter-websocket</artifactId> | |||
| </dependency> | |||
| <dependency> | |||
| <groupId>org.json</groupId> | |||
| <artifactId>json</artifactId> | |||
| <version>20210307</version> | |||
| </dependency> | |||
| <dependency> | |||
| <groupId>org.apache.dubbo</groupId> | |||
| <artifactId>dubbo</artifactId> | |||
| <version>3.0.8</version> | |||
| <scope>compile</scope> | |||
| </dependency> | |||
| </dependencies> | |||
| @@ -4,16 +4,16 @@ import com.ruoyi.common.core.web.controller.BaseController; | |||
| import com.ruoyi.common.core.web.domain.GenericsAjaxResult; | |||
| import com.ruoyi.platform.service.AimService; | |||
| import com.ruoyi.platform.vo.FrameLogPathVo; | |||
| import com.ruoyi.platform.vo.InsMetricInfoVo; | |||
| import com.ruoyi.platform.vo.PodStatusVo; | |||
| import io.swagger.annotations.Api; | |||
| import io.swagger.annotations.ApiOperation; | |||
| import io.swagger.v3.oas.annotations.responses.ApiResponse; | |||
| import org.springframework.web.bind.annotation.PostMapping; | |||
| import org.springframework.web.bind.annotation.RequestBody; | |||
| import org.springframework.web.bind.annotation.RequestMapping; | |||
| import org.springframework.web.bind.annotation.RestController; | |||
| import org.springframework.web.bind.annotation.*; | |||
| import javax.annotation.Resource; | |||
| import java.util.List; | |||
| @RestController | |||
| @RequestMapping("aim") | |||
| @Api("Aim管理") | |||
| @@ -22,17 +22,25 @@ public class AimController extends BaseController { | |||
| @Resource | |||
| private AimService aimService; | |||
| /** | |||
| * 启动tensorBoard接口 | |||
| * | |||
| * @param frameLogPathVo 存储路径 | |||
| * @return url | |||
| */ | |||
| @PostMapping("/run") | |||
| @ApiOperation("启动aim`") | |||
| @GetMapping("/getExpTrainInfos/{experiment_id}") | |||
| @ApiOperation("获取当前实验的模型训练指标信息") | |||
| @ApiResponse | |||
| public GenericsAjaxResult<String> runAim(@RequestBody FrameLogPathVo frameLogPathVo) throws Exception { | |||
| return genericsSuccess(aimService.runAim(frameLogPathVo)); | |||
| public GenericsAjaxResult<List<InsMetricInfoVo>> getExpTrainInfos(@PathVariable("experiment_id") Integer experimentId) throws Exception { | |||
| return genericsSuccess(aimService.getExpTrainInfos(experimentId)); | |||
| } | |||
| @GetMapping("/getExpEvaluateInfos/{experiment_id}") | |||
| @ApiOperation("获取当前实验的模型推理指标信息") | |||
| @ApiResponse | |||
| public GenericsAjaxResult<List<InsMetricInfoVo>> getExpEvaluateInfos(@PathVariable("experiment_id") Integer experimentId) throws Exception { | |||
| return genericsSuccess(aimService.getExpEvaluateInfos(experimentId)); | |||
| } | |||
| @PostMapping("/getExpMetrics") | |||
| @ApiOperation("获取当前实验的指标对比地址") | |||
| @ApiResponse | |||
| public GenericsAjaxResult<String> getExpMetrics(@RequestBody List<String> runIds) throws Exception { | |||
| return genericsSuccess(aimService.getExpMetrics(runIds)); | |||
| } | |||
| } | |||
| @@ -3,6 +3,7 @@ package com.ruoyi.platform.controller.jupyter; | |||
| import com.ruoyi.common.core.web.controller.BaseController; | |||
| import com.ruoyi.common.core.web.domain.AjaxResult; | |||
| import com.ruoyi.common.core.web.domain.GenericsAjaxResult; | |||
| import com.ruoyi.platform.domain.DevEnvironment; | |||
| import com.ruoyi.platform.service.JupyterService; | |||
| import com.ruoyi.platform.vo.FrameLogPathVo; | |||
| import com.ruoyi.platform.vo.PodStatusVo; | |||
| @@ -60,8 +61,8 @@ public class JupyterController extends BaseController { | |||
| @PostMapping("/getStatus") | |||
| @ApiOperation("查询jupyter pod状态") | |||
| @ApiResponse | |||
| public GenericsAjaxResult<PodStatusVo> getStatus(@RequestBody FrameLogPathVo frameLogPathVo) throws Exception { | |||
| return genericsSuccess(this.jupyterService.getJupyterStatus(frameLogPathVo)); | |||
| public GenericsAjaxResult<PodStatusVo> getStatus(DevEnvironment devEnvironment) throws Exception { | |||
| return genericsSuccess(this.jupyterService.getJupyterStatus(devEnvironment)); | |||
| } | |||
| @@ -84,5 +84,9 @@ public interface ModelDependencyDao { | |||
| List<ModelDependency> queryByModelDependency(@Param("modelDependency") ModelDependency modelDependency); | |||
| List<ModelDependency> queryChildrenByVersionId(@Param("model_id")String modelId, @Param("version")String version); | |||
| List<ModelDependency> queryByIns(@Param("expInsId")Integer expInsId); | |||
| ModelDependency queryByInsAndTrainTaskId(@Param("expInsId")Integer expInsId,@Param("taskId") String taskId); | |||
| } | |||
| @@ -7,20 +7,15 @@ import com.ruoyi.platform.mapper.ExperimentDao; | |||
| import com.ruoyi.platform.mapper.ExperimentInsDao; | |||
| import com.ruoyi.platform.mapper.ModelDependencyDao; | |||
| import com.ruoyi.platform.service.ExperimentInsService; | |||
| import com.ruoyi.platform.service.ModelDependencyService; | |||
| import com.ruoyi.platform.utils.JacksonUtil; | |||
| import org.apache.commons.lang3.StringUtils; | |||
| import org.springframework.beans.factory.annotation.Autowired; | |||
| import org.springframework.data.domain.Page; | |||
| import org.springframework.scheduling.annotation.Scheduled; | |||
| import org.springframework.stereotype.Component; | |||
| import javax.annotation.Resource; | |||
| import java.io.IOException; | |||
| import java.util.ArrayList; | |||
| import java.util.Date; | |||
| import java.util.List; | |||
| import java.util.Map; | |||
| import java.util.*; | |||
| @Component() | |||
| public class ExperimentInstanceStatusTask { | |||
| @@ -34,7 +29,7 @@ public class ExperimentInstanceStatusTask { | |||
| private ModelDependencyDao modelDependencyDao; | |||
| private List<Integer> experimentIds = new ArrayList<>(); | |||
| @Scheduled(cron = "0/14 * * * * ?") // 每30S执行一次 | |||
| @Scheduled(cron = "0/30 * * * * ?") // 每30S执行一次 | |||
| public void executeExperimentInsStatus() throws IOException { | |||
| // 首先查到所有非终止态的实验实例 | |||
| List<ExperimentIns> experimentInsList = experimentInsService.queryByExperimentIsNotTerminated(); | |||
| @@ -46,95 +41,94 @@ public class ExperimentInstanceStatusTask { | |||
| String oldStatus = experimentIns.getStatus(); | |||
| try { | |||
| experimentIns = experimentInsService.queryStatusFromArgo(experimentIns); | |||
| }catch (Exception e){ | |||
| } catch (Exception e) { | |||
| experimentIns.setStatus("Failed"); | |||
| } | |||
| // if (!StringUtils.equals(oldStatus,experimentIns.getStatus())){ | |||
| experimentIns.setUpdateTime(new Date()); | |||
| // 线程安全的添加操作 | |||
| synchronized (experimentIds) { | |||
| experimentIds.add(experimentIns.getExperimentId()); | |||
| } | |||
| updateList.add(experimentIns); | |||
| // } | |||
| // experimentInsDao.update(experimentIns); | |||
| experimentIns.setUpdateTime(new Date()); | |||
| // 线程安全的添加操作 | |||
| synchronized (experimentIds) { | |||
| experimentIds.add(experimentIns.getExperimentId()); | |||
| } | |||
| updateList.add(experimentIns); | |||
| } | |||
| } | |||
| if (updateList.size() > 0){ | |||
| if (updateList.size() > 0) { | |||
| experimentInsDao.insertOrUpdateBatch(updateList); | |||
| //遍历模型关系表,找到 | |||
| List<ModelDependency> modelDependencyList = new ArrayList<ModelDependency>(); | |||
| for (ExperimentIns experimentIns : updateList){ | |||
| for (ExperimentIns experimentIns : updateList) { | |||
| ModelDependency modelDependencyquery = new ModelDependency(); | |||
| modelDependencyquery.setExpInsId(experimentIns.getId()); | |||
| modelDependencyquery.setState(2); | |||
| List<ModelDependency> modelDependencyListquery = modelDependencyDao.queryByModelDependency(modelDependencyquery); | |||
| if (modelDependencyListquery==null||modelDependencyListquery.size()==0){ | |||
| if (modelDependencyListquery == null || modelDependencyListquery.size() == 0) { | |||
| continue; | |||
| } | |||
| ModelDependency modelDependency = modelDependencyListquery.get(0); | |||
| //查看状态, | |||
| if (StringUtils.equals("Failed",experimentIns.getStatus())){ | |||
| if (StringUtils.equals("Failed", experimentIns.getStatus())) { | |||
| //取出节点状态 | |||
| String trainTask = modelDependency.getTrainTask(); | |||
| Map<String, Object> trainMap = JacksonUtil.parseJSONStr2Map(trainTask); | |||
| String task_id = (String) trainMap.get("task_id"); | |||
| if (StringUtils.isEmpty(task_id)){ | |||
| if (StringUtils.isEmpty(task_id)) { | |||
| continue; | |||
| } | |||
| String nodesStatus = experimentIns.getNodesStatus(); | |||
| Map<String, Object> nodeMaps = JacksonUtil.parseJSONStr2Map(nodesStatus); | |||
| Map<String, Object> nodeMap = JacksonUtil.parseJSONStr2Map(JacksonUtil.toJSONString(nodeMaps.get(task_id))); | |||
| if (nodeMap==null){ | |||
| if (nodeMap == null) { | |||
| continue; | |||
| } | |||
| if (!StringUtils.equals("Succeeded",(String)nodeMap.get("phase"))){ | |||
| if (!StringUtils.equals("Succeeded", (String) nodeMap.get("phase"))) { | |||
| modelDependency.setState(0); | |||
| modelDependencyList.add(modelDependency); | |||
| } | |||
| } | |||
| } | |||
| if (modelDependencyList.size()>0) { | |||
| if (modelDependencyList.size() > 0) { | |||
| modelDependencyDao.insertOrUpdateBatch(modelDependencyList); | |||
| } | |||
| } | |||
| } | |||
| @Scheduled(cron = "0/17 * * * * ?") // / 每30S执行一次 | |||
| @Scheduled(cron = "0/30 * * * * ?") // / 每30S执行一次 | |||
| public void executeExperimentStatus() throws IOException { | |||
| if (experimentIds.size()==0){ | |||
| if (experimentIds.size() == 0) { | |||
| return; | |||
| } | |||
| // 存储需要更新的实验对象列表 | |||
| List<Experiment> updateExperiments = new ArrayList<>(); | |||
| for (Integer experimentId : experimentIds){ | |||
| for (Integer experimentId : experimentIds) { | |||
| // 获取当前实验的所有实例列表 | |||
| List<ExperimentIns> insList = experimentInsService.getByExperimentId(experimentId); | |||
| List<String> statusList = new ArrayList<String>(); | |||
| // 更新实验状态列表 | |||
| for (int i=0;i<insList.size();i++){ | |||
| for (int i = 0; i < insList.size(); i++) { | |||
| statusList.add(insList.get(i).getStatus()); | |||
| } | |||
| String subStatus = statusList.toString().substring(1, statusList.toString().length() - 1); | |||
| Experiment experiment = experimentDao.queryById(experimentId); | |||
| // 如果实验状态列表发生变化,则更新实验对象,并加入到需要更新的列表中 | |||
| if (!StringUtils.equals(subStatus,experiment.getStatusList())){ | |||
| if (!StringUtils.equals(subStatus, experiment.getStatusList())) { | |||
| experiment.setStatusList(subStatus); | |||
| updateExperiments.add(experiment); | |||
| } | |||
| } | |||
| if (!updateExperiments.isEmpty()) { | |||
| experimentDao.insertOrUpdateBatch(updateExperiments); | |||
| for (int index = 0; index < updateExperiments.size(); index++) { | |||
| // 线程安全的删除操作 | |||
| synchronized (experimentIds) { | |||
| experimentIds.remove(index); | |||
| // 使用Iterator进行安全的删除操作 | |||
| Iterator<Integer> iterator = experimentIds.iterator(); | |||
| while (iterator.hasNext()) { | |||
| Integer experimentId = iterator.next(); | |||
| for (Experiment experiment : updateExperiments) { | |||
| if (experiment.getId().equals(experimentId)) { | |||
| iterator.remove(); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -1,7 +1,14 @@ | |||
| package com.ruoyi.platform.service; | |||
| import com.ruoyi.platform.vo.FrameLogPathVo; | |||
| import com.ruoyi.platform.vo.InsMetricInfoVo; | |||
| import java.util.List; | |||
| public interface AimService { | |||
| String runAim(FrameLogPathVo frameLogPathVo); | |||
| List<InsMetricInfoVo> getExpTrainInfos(Integer experimentId) throws Exception; | |||
| List<InsMetricInfoVo> getExpEvaluateInfos(Integer experimentId) throws Exception; | |||
| String getExpMetrics(List<String> runIds) throws Exception; | |||
| } | |||
| @@ -1,6 +1,6 @@ | |||
| package com.ruoyi.platform.service; | |||
| import com.ruoyi.platform.vo.FrameLogPathVo; | |||
| import com.ruoyi.platform.domain.DevEnvironment; | |||
| import com.ruoyi.platform.vo.PodStatusVo; | |||
| import java.io.InputStream; | |||
| @@ -16,5 +16,5 @@ public interface JupyterService { | |||
| String stopJupyterService(Integer id) throws Exception; | |||
| PodStatusVo getJupyterStatus(FrameLogPathVo frameLogPathVo); | |||
| PodStatusVo getJupyterStatus(DevEnvironment devEnvironment) throws Exception; | |||
| } | |||
| @@ -62,4 +62,8 @@ public interface ModelDependencyService { | |||
| List<ModelDependency> queryByModelDependency(ModelDependency modelDependency) throws IOException; | |||
| ModelDependcyTreeVo getModelDependencyTree(ModelDependency modelDependency) throws Exception; | |||
| List<ModelDependency> queryByIns(Integer expInsId); | |||
| ModelDependency queryByInsAndTrainTaskId(Integer expInsId, String taskId); | |||
| } | |||
| @@ -1,13 +1,157 @@ | |||
| package com.ruoyi.platform.service.impl; | |||
| import com.ruoyi.platform.domain.ExperimentIns; | |||
| import com.ruoyi.platform.domain.ModelDependency; | |||
| import com.ruoyi.platform.service.AimService; | |||
| import com.ruoyi.platform.vo.FrameLogPathVo; | |||
| import com.ruoyi.platform.service.ExperimentInsService; | |||
| import com.ruoyi.platform.service.ModelDependencyService; | |||
| import com.ruoyi.platform.utils.AIM64EncoderUtil; | |||
| import com.ruoyi.platform.utils.HttpUtils; | |||
| import com.ruoyi.platform.utils.JacksonUtil; | |||
| import com.ruoyi.platform.utils.JsonUtils; | |||
| import com.ruoyi.platform.vo.InsMetricInfoVo; | |||
| import org.apache.commons.lang3.StringUtils; | |||
| import org.springframework.beans.factory.annotation.Value; | |||
| import org.springframework.stereotype.Service; | |||
| import javax.annotation.Resource; | |||
| import java.net.URLEncoder; | |||
| import java.util.*; | |||
| import java.util.stream.Collectors; | |||
| @Service | |||
| public class AimServiceImpl implements AimService { | |||
| @Resource | |||
| private ExperimentInsService experimentInsService; | |||
| @Value("${aim.url}") | |||
| private String aimUrl; | |||
| @Value("${aim.proxyUrl}") | |||
| private String aimProxyUrl; | |||
| @Override | |||
| public List<InsMetricInfoVo> getExpTrainInfos(Integer experimentId) throws Exception { | |||
| return getAimRunInfos(true,experimentId); | |||
| } | |||
| @Override | |||
| public String runAim(FrameLogPathVo frameLogPathVo) { | |||
| return null; | |||
| public List<InsMetricInfoVo> getExpEvaluateInfos(Integer experimentId) throws Exception { | |||
| return getAimRunInfos(false,experimentId); | |||
| } | |||
| @Override | |||
| public String getExpMetrics(List<String> runIds) throws Exception { | |||
| String decode = AIM64EncoderUtil.decode(runIds); | |||
| return aimUrl+"/metrics?select="+decode; | |||
| } | |||
| private List<InsMetricInfoVo> getAimRunInfos(boolean isTrain,Integer experimentId) throws Exception { | |||
| String experimentName = "experiment-"+experimentId+"-train"; | |||
| if (!isTrain){ | |||
| experimentName = "experiment-"+experimentId+"-evaluate"; | |||
| } | |||
| String encodedUrlString = URLEncoder.encode("run.experiment==\""+experimentName+"\"", "UTF-8"); | |||
| String url = aimProxyUrl+"/api/runs/search/run?query="+encodedUrlString; | |||
| String s = HttpUtils.sendGetRequest(url); | |||
| List<Map<String, Object>> response = JacksonUtil.parseJSONStr2MapList(s); | |||
| System.out.println("response: "+JacksonUtil.toJSONString(response)); | |||
| if (response == null || response.size() == 0){ | |||
| return new ArrayList<>(); | |||
| } | |||
| //查询实例数据 | |||
| List<ExperimentIns> byExperimentId = experimentInsService.getByExperimentId(experimentId); | |||
| if (byExperimentId == null || byExperimentId.size() == 0){ | |||
| return new ArrayList<>(); | |||
| } | |||
| List<InsMetricInfoVo> aimRunInfoList = new ArrayList<>(); | |||
| for (Map<String, Object> run : response) { | |||
| InsMetricInfoVo aimRunInfo = new InsMetricInfoVo(); | |||
| String runHash = (String) run.get("run_hash"); | |||
| aimRunInfo.setRunId(runHash); | |||
| Map params= (Map) run.get("params"); | |||
| Map<String, Object> paramMap = JsonUtils.flattenJson("", params); | |||
| aimRunInfo.setParams(paramMap); | |||
| String aimrunId = (String) paramMap.get("id"); | |||
| Map<String, Object> tracesMap= (Map<String, Object>) run.get("traces"); | |||
| List<Map<String, Object>> metricList = (List<Map<String, Object>>) tracesMap.get("metric"); | |||
| //过滤name为__system__开头的对象 | |||
| aimRunInfo.setMetrics(new HashMap<>()); | |||
| if (metricList != null && metricList.size() > 0){ | |||
| List<Map<String, Object>> metricRelList = metricList.stream() | |||
| .filter(map -> !StringUtils.startsWith((String) map.get("name"),"__system__" )) | |||
| .collect(Collectors.toList()); | |||
| if (metricRelList!= null && metricRelList.size() > 0){ | |||
| Map<String, Object> relMetricMap = new HashMap<>(); | |||
| for (Map<String, Object> metricMap : metricRelList) { | |||
| relMetricMap.put((String)metricMap.get("name"), metricMap.get("last_value")); | |||
| } | |||
| aimRunInfo.setMetrics(relMetricMap); | |||
| } | |||
| } | |||
| //找到ins | |||
| for (ExperimentIns ins : byExperimentId) { | |||
| String metricRecordString = ins.getMetricRecord(); | |||
| if (StringUtils.isEmpty(metricRecordString)){ | |||
| continue; | |||
| } | |||
| if (metricRecordString.contains(aimrunId)){ | |||
| aimRunInfo.setExperimentInsId(ins.getId()); | |||
| aimRunInfo.setStatus(ins.getStatus()); | |||
| aimRunInfo.setStartTime(ins.getCreateTime()); | |||
| Map<String, Object> metricRecordMap = JacksonUtil.parseJSONStr2Map(metricRecordString); | |||
| if (isTrain){ | |||
| List<Map<String, Object>> records = (List<Map<String, Object>>) metricRecordMap.get("train"); | |||
| List<String> datasetList = getTrainDateSet(records, aimrunId); | |||
| aimRunInfo.setDataset(datasetList); | |||
| }else { | |||
| List<Map<String, Object>> records = (List<Map<String, Object>>) metricRecordMap.get("evaluate"); | |||
| List<String> datasetList = getTrainDateSet(records, aimrunId); | |||
| aimRunInfo.setDataset(datasetList); | |||
| } | |||
| } | |||
| } | |||
| aimRunInfoList.add(aimRunInfo); | |||
| } | |||
| //判断哪个最长 | |||
| // 获取所有 metrics 的 key 的并集 | |||
| Set<String> metricsKeys = (Set<String>) aimRunInfoList.stream() | |||
| .map(InsMetricInfoVo::getMetrics) | |||
| .flatMap(metrics -> metrics.keySet().stream()) | |||
| .collect(Collectors.toSet()); | |||
| // 将并集赋值给每个 InsMetricInfoVo 的 metricsNames 属性 | |||
| aimRunInfoList.forEach(vo -> vo.setMetricsNames(new ArrayList<>(metricsKeys))); | |||
| // 获取所有 params 的 key 的并集 | |||
| Set<String> paramKeys = (Set<String>) aimRunInfoList.stream() | |||
| .map(InsMetricInfoVo::getParams) | |||
| .flatMap(params -> params.keySet().stream()) | |||
| .collect(Collectors.toSet()); | |||
| // 将并集赋值给每个 InsMetricInfoVo 的 paramsNames 属性 | |||
| aimRunInfoList.forEach(vo -> vo.setParamsNames(new ArrayList<>(paramKeys))); | |||
| return aimRunInfoList; | |||
| } | |||
| private List<String> getTrainDateSet(List<Map<String, Object>> records, String aimrunId){ | |||
| List<String> datasetList = new ArrayList<>(); | |||
| for (Map<String, Object> record : records) { | |||
| if (StringUtils.equals(aimrunId, (String)record.get("run_id"))) { | |||
| List<Map<String, Object>> datasets = (List<Map<String, Object>>) record.get("datasets"); | |||
| if (datasets == null || datasets.size() == 0){ | |||
| continue; | |||
| } | |||
| for (Map<String, Object> dataset : datasets){ | |||
| String datasetName = (String) dataset.get("dataset_name")+":"+(String) dataset.get("dataset_version"); | |||
| datasetList.add(datasetName); | |||
| } | |||
| break; | |||
| } | |||
| } | |||
| return datasetList; | |||
| } | |||
| } | |||
| @@ -2,14 +2,17 @@ package com.ruoyi.platform.service.impl; | |||
| import com.ruoyi.common.security.utils.SecurityUtils; | |||
| import com.ruoyi.platform.domain.DevEnvironment; | |||
| import com.ruoyi.platform.domain.PodStatus; | |||
| import com.ruoyi.platform.mapper.DevEnvironmentDao; | |||
| import com.ruoyi.platform.service.DevEnvironmentService; | |||
| import com.ruoyi.platform.service.JupyterService; | |||
| import com.ruoyi.platform.utils.JacksonUtil; | |||
| import com.ruoyi.platform.vo.DevEnvironmentVo; | |||
| import com.ruoyi.platform.vo.PodStatusVo; | |||
| import com.ruoyi.system.api.model.LoginUser; | |||
| import io.kubernetes.client.openapi.models.V1PersistentVolumeClaim; | |||
| import org.apache.commons.lang3.StringUtils; | |||
| import org.springframework.context.annotation.Lazy; | |||
| import org.springframework.stereotype.Service; | |||
| import org.springframework.data.domain.Page; | |||
| import org.springframework.data.domain.PageImpl; | |||
| @@ -17,6 +20,7 @@ import org.springframework.data.domain.PageRequest; | |||
| import javax.annotation.Resource; | |||
| import java.util.Date; | |||
| import java.util.List; | |||
| import java.util.Map; | |||
| /** | |||
| @@ -30,6 +34,10 @@ public class DevEnvironmentServiceImpl implements DevEnvironmentService { | |||
| @Resource | |||
| private DevEnvironmentDao devEnvironmentDao; | |||
| @Resource | |||
| @Lazy | |||
| private JupyterService jupyterService; | |||
| /** | |||
| * 通过ID查询单条数据 | |||
| @@ -52,13 +60,29 @@ public class DevEnvironmentServiceImpl implements DevEnvironmentService { | |||
| @Override | |||
| public Page<DevEnvironment> queryByPage(DevEnvironment devEnvironment, PageRequest pageRequest) { | |||
| long total = this.devEnvironmentDao.count(devEnvironment); | |||
| return new PageImpl<>(this.devEnvironmentDao.queryAllByLimit(devEnvironment, pageRequest), pageRequest, total); | |||
| List<DevEnvironment> devEnvironmentList = this.devEnvironmentDao.queryAllByLimit(devEnvironment, pageRequest); | |||
| //查询每个开发环境的pod状态,注意:只有pod为非终止态时才去调状态接口 | |||
| devEnvironmentList.forEach(devEnv -> { | |||
| try{ | |||
| if (!devEnv.getStatus().equals(PodStatus.Terminated.getName()) && | |||
| !devEnv.getStatus().equals(PodStatus.Failed.getName())) { | |||
| PodStatusVo podStatusVo = this.jupyterService.getJupyterStatus(devEnv); | |||
| devEnv.setStatus(podStatusVo.getStatus()); | |||
| devEnv.setUrl(podStatusVo.getUrl()); | |||
| } | |||
| } catch (Exception e) { | |||
| devEnv.setStatus(PodStatus.Unknown.getName()); | |||
| } | |||
| }); | |||
| return new PageImpl<>(devEnvironmentList, pageRequest, total); | |||
| } | |||
| /** | |||
| * 新增数据 | |||
| * | |||
| * @param devEnvironment 实例对象 | |||
| * @param devEnvironmentVo 实例对象 | |||
| * @return 实例对象 | |||
| */ | |||
| @Override | |||
| @@ -102,7 +102,6 @@ public class ExperimentInsServiceImpl implements ExperimentInsService { | |||
| */ | |||
| @Override | |||
| public List<ExperimentIns> getByExperimentId(Integer experimentId) throws IOException { | |||
| List<ExperimentIns> experimentInsList = experimentInsDao.getByExperimentId(experimentId); | |||
| //代码全部迁移至定时任务 | |||
| //搞个标记,当状态改变才去改表 | |||
| @@ -138,7 +137,7 @@ public class ExperimentInsServiceImpl implements ExperimentInsService { | |||
| // experimentDao.update(experiment); | |||
| // } | |||
| return experimentInsList; | |||
| return experimentInsDao.getByExperimentId(experimentId); | |||
| } | |||
| @@ -251,12 +251,14 @@ public class ExperimentServiceImpl implements ExperimentService { | |||
| if (data == null || MapUtils.isEmpty(data)) { | |||
| throw new RuntimeException("Failed to run workflow."); | |||
| } | |||
| //获取训练参数 | |||
| Map<String, Object> metricRecord = (Map<String, Object>) runResMap.get("metric_record"); | |||
| Map<String, Object> metadata = (Map<String, Object>) data.get("metadata"); | |||
| // 插入记录到实验实例表 | |||
| ExperimentIns experimentIns = new ExperimentIns(); | |||
| //获取训练参数 | |||
| experimentIns.setExperimentId(experiment.getId()); | |||
| experimentIns.setArgoInsNs((String) metadata.get("namespace")); | |||
| experimentIns.setArgoInsName((String) metadata.get("name")); | |||
| @@ -267,16 +269,25 @@ public class ExperimentServiceImpl implements ExperimentService { | |||
| //替换argoInsName | |||
| String outputString = JsonUtils.mapToJson(output); | |||
| experimentIns.setNodesResult(outputString.replace("{{workflow.name}}", (String) metadata.get("name"))); | |||
| //插入ExperimentIns表中 | |||
| ExperimentIns insert = experimentInsService.insert(experimentIns); | |||
| //插入到模型依赖关系表 | |||
| //得到dependendcy | |||
| Map<String, Object> converMap2 = JsonUtils.jsonToMap(JacksonUtil.replaceInAarry(convertRes, params)); | |||
| Map<String ,Object> dependendcy = (Map<String, Object>)converMap2.get("model_dependency"); | |||
| Map<String ,Object> trainInfo = (Map<String, Object>)converMap2.get("component_info"); | |||
| insertModelDependency(dependendcy,trainInfo,insert.getId(),experiment.getName()); | |||
| Map<String, Object> metricRecord = (Map<String, Object>) runResMap.get("metric_record"); | |||
| if (metricRecord != null){ | |||
| //把训练用的数据集也放进去 | |||
| addDatesetToMetric(metricRecord, trainInfo); | |||
| experimentIns.setMetricRecord(JacksonUtil.toJSONString(metricRecord)); | |||
| } | |||
| //插入ExperimentIns表中 | |||
| ExperimentIns insert = experimentInsService.insert(experimentIns); | |||
| //插入到模型依赖关系表 | |||
| if (dependendcy != null && trainInfo != null){ | |||
| insertModelDependency(dependendcy,trainInfo,insert.getId(),experiment.getName()); | |||
| } | |||
| }catch (Exception e){ | |||
| throw new RuntimeException(e); | |||
| } | |||
| @@ -284,6 +295,37 @@ public class ExperimentServiceImpl implements ExperimentService { | |||
| experiment.setExperimentInsList(updatedExperimentInsList); | |||
| return experiment; | |||
| } | |||
| private void addDatesetToMetric(Map<String, Object> metricRecord, Map<String, Object> trainInfo) { | |||
| processMetricPart(metricRecord, trainInfo, "train", "model_train"); | |||
| processMetricPart(metricRecord, trainInfo, "evaluate", "model_evaluate"); | |||
| } | |||
| private void processMetricPart(Map<String, Object> metricRecord, Map<String, Object> trainInfo, String metricKey, String trainInfoKey) { | |||
| List<Map<String, Object>> metricList = (List<Map<String, Object>>) metricRecord.get(metricKey); | |||
| if (metricList != null) { | |||
| for (Map<String, Object> metricRecordItem : metricList) { | |||
| String taskId = (String) metricRecordItem.get("task_id"); | |||
| Map<String, Object> trainInfoPart = (Map<String, Object>) trainInfo.get(trainInfoKey); | |||
| if (trainInfoPart != null) { | |||
| Map<String, Object> trainInfoDetails = (Map<String, Object>) trainInfoPart.get(taskId); | |||
| if (trainInfoDetails != null) { | |||
| List<Map<String, Object>> datasets = (List<Map<String, Object>>) trainInfoDetails.get("datasets"); | |||
| if (datasets != null) { | |||
| //查询名字再回填 | |||
| for (int i = 0; i < datasets.size(); i++) { | |||
| Dataset dataset = datasetService.queryById((Integer) datasets.get(i).get("dataset_id")); | |||
| datasets.get(i).put("dataset_name", dataset.getName()); | |||
| } | |||
| metricRecordItem.put("datasets", datasets); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| private void insertModelDependency(Map<String ,Object> dependendcy,Map<String ,Object> trainInfo, Integer experimentInsId, String experimentName) throws Exception { | |||
| Iterator<Map.Entry<String, Object>> dependendcyIterator = dependendcy.entrySet().iterator(); | |||
| Map<String, Object> modelTrain = (Map<String, Object>) trainInfo.get("model_train"); | |||
| @@ -1,6 +1,5 @@ | |||
| package com.ruoyi.platform.service.impl; | |||
| import com.ruoyi.common.core.utils.StringUtils; | |||
| import com.ruoyi.common.redis.service.RedisService; | |||
| import com.ruoyi.common.security.utils.SecurityUtils; | |||
| import com.ruoyi.platform.domain.DevEnvironment; | |||
| @@ -12,7 +11,6 @@ import com.ruoyi.platform.utils.JacksonUtil; | |||
| import com.ruoyi.platform.utils.K8sClientUtil; | |||
| import com.ruoyi.platform.utils.MinioUtil; | |||
| import com.ruoyi.platform.utils.MlflowUtil; | |||
| import com.ruoyi.platform.vo.FrameLogPathVo; | |||
| import com.ruoyi.platform.vo.PodStatusVo; | |||
| import com.ruoyi.system.api.model.LoginUser; | |||
| import io.kubernetes.client.openapi.models.V1PersistentVolumeClaim; | |||
| @@ -101,17 +99,17 @@ public class JupyterServiceImpl implements JupyterService { | |||
| // 调用修改后的 createPod 方法,传入额外的参数 | |||
| Integer podPort = k8sClientUtil.createConfiguredPod(podName, namespace, port, mountPath, pvc, image, minioPvcName, datasetPath, modelPath); | |||
| // 简单的延迟,以便 Pod 有时间启动 | |||
| Thread.sleep(2500); | |||
| //查询pod状态,更新到数据库 | |||
| String podStatus = k8sClientUtil.getPodStatus(podName, namespace); | |||
| // // 简单的延迟,以便 Pod 有时间启动 | |||
| // Thread.sleep(2500); | |||
| // //查询pod状态,更新到数据库 | |||
| // String podStatus = k8sClientUtil.getPodStatus(podName, namespace); | |||
| String url = masterIp + ":" + podPort; | |||
| devEnvironment.setStatus(podStatus); | |||
| redisService.setCacheObject(podName,masterIp + ":" + podPort); | |||
| devEnvironment.setStatus("Pending"); | |||
| devEnvironment.setUrl(url); | |||
| this.devEnvironmentService.update(devEnvironment); | |||
| return url ; | |||
| } | |||
| @Override | |||
| @@ -132,25 +130,25 @@ public class JupyterServiceImpl implements JupyterService { | |||
| String deleteResult = k8sClientUtil.deletePod(podName, namespace); | |||
| devEnvironment.setStatus("Terminating"); | |||
| devEnvironment.setStatus("Terminated"); | |||
| this.devEnvironmentService.update(devEnvironment); | |||
| return deleteResult + ",编辑器已停止"; | |||
| } | |||
| @Override | |||
| public PodStatusVo getJupyterStatus(FrameLogPathVo frameLogPathVo) { | |||
| public PodStatusVo getJupyterStatus(DevEnvironment devEnvironment) throws Exception { | |||
| String status = PodStatus.Terminated.getName(); | |||
| PodStatusVo JupyterStatusVo = new PodStatusVo(); | |||
| JupyterStatusVo.setStatus(status); | |||
| if(StringUtils.isEmpty(frameLogPathVo.getPath())){ | |||
| if (devEnvironment==null){ | |||
| return JupyterStatusVo; | |||
| } | |||
| LoginUser loginUser = SecurityUtils.getLoginUser(); | |||
| String podName = loginUser.getUsername().toLowerCase() + "-editor-pod"; | |||
| String podName = loginUser.getUsername().toLowerCase() +"-editor-pod" + "-" + devEnvironment.getId(); | |||
| try { | |||
| // 查询相应pod状态 | |||
| String podStatus = k8sClientUtil.getPodStatus(podName, StringUtils.isEmpty(frameLogPathVo.getNamespace()) ? "default" : frameLogPathVo.getNamespace()); | |||
| String podStatus = k8sClientUtil.getPodStatus(podName, namespace); | |||
| for (PodStatus s : PodStatus.values()) { | |||
| if (s.getName().equals(podStatus)) { | |||
| status = s.getName(); | |||
| @@ -160,8 +158,6 @@ public class JupyterServiceImpl implements JupyterService { | |||
| } catch (Exception e) { | |||
| return JupyterStatusVo; | |||
| } | |||
| String url = redisService.getCacheObject(podName); | |||
| JupyterStatusVo.setStatus(status); | |||
| @@ -17,10 +17,7 @@ import org.springframework.data.domain.PageRequest; | |||
| import javax.annotation.Resource; | |||
| import java.io.IOException; | |||
| import java.util.ArrayList; | |||
| import java.util.Date; | |||
| import java.util.List; | |||
| import java.util.Map; | |||
| import java.util.*; | |||
| import java.util.stream.Collectors; | |||
| /** | |||
| @@ -97,6 +94,16 @@ public class ModelDependencyServiceImpl implements ModelDependencyService { | |||
| return modelDependcyTreeVo; | |||
| } | |||
| @Override | |||
| public List<ModelDependency> queryByIns(Integer expInsId) { | |||
| return modelDependencyDao.queryByIns(expInsId); | |||
| } | |||
| @Override | |||
| public ModelDependency queryByInsAndTrainTaskId(Integer expInsId, String taskId) { | |||
| return modelDependencyDao.queryByInsAndTrainTaskId(expInsId,taskId); | |||
| } | |||
| /** | |||
| * 递归父模型 | |||
| * @param modelDependcyTreeVo | |||
| @@ -0,0 +1,76 @@ | |||
| package com.ruoyi.platform.utils; | |||
| import com.alibaba.fastjson.JSON; | |||
| import java.util.Base64; | |||
| import java.util.HashMap; | |||
| import java.util.List; | |||
| import java.util.Map; | |||
| public class AIM64EncoderUtil { | |||
| private static final String AIM64_ENCODING_PREFIX = "O-"; | |||
| private static final Map<String, String> BS64_REPLACE_CHARACTERS_ENCODING = new HashMap<>(); | |||
| static { | |||
| BS64_REPLACE_CHARACTERS_ENCODING.put("=", ""); | |||
| BS64_REPLACE_CHARACTERS_ENCODING.put("+", "-"); | |||
| BS64_REPLACE_CHARACTERS_ENCODING.put("/", "_"); | |||
| } | |||
| public static String aim64encode(Map<String, Object> value) { | |||
| String jsonEncoded = JSON.toJSONString(value); | |||
| String base64Encoded = Base64.getEncoder().encodeToString(jsonEncoded.getBytes()); | |||
| String aim64Encoded = base64Encoded; | |||
| for (Map.Entry<String, String> entry : BS64_REPLACE_CHARACTERS_ENCODING.entrySet()) { | |||
| aim64Encoded = aim64Encoded.replace(entry.getKey(), entry.getValue()); | |||
| } | |||
| return AIM64_ENCODING_PREFIX + aim64Encoded; | |||
| } | |||
| public static String encode(Map<String, Object> value, boolean oneWayHashing) { | |||
| if (oneWayHashing) { | |||
| return md5(JSON.toJSONString(value)); | |||
| } | |||
| return aim64encode(value); | |||
| } | |||
| private static String md5(String input) { | |||
| try { | |||
| java.security.MessageDigest md = java.security.MessageDigest.getInstance("MD5"); | |||
| byte[] array = md.digest(input.getBytes()); | |||
| StringBuilder sb = new StringBuilder(); | |||
| for (byte b : array) { | |||
| sb.append(Integer.toHexString((b & 0xFF) | 0x100).substring(1, 3)); | |||
| } | |||
| return sb.toString(); | |||
| } catch (java.security.NoSuchAlgorithmException e) { | |||
| e.printStackTrace(); | |||
| } | |||
| return null; | |||
| } | |||
| public static String decode(List<String> runIds) { | |||
| // 确保 runIds 列表的大小为 3 | |||
| if (runIds == null || runIds.size() == 0) { | |||
| throw new IllegalArgumentException("runIds 不能为空"); | |||
| } | |||
| // 构建查询字符串 | |||
| StringBuilder queryBuilder = new StringBuilder("run.hash in ["); | |||
| for (int i = 0; i < runIds.size(); i++) { | |||
| if (i > 0) { | |||
| queryBuilder.append(","); | |||
| } | |||
| queryBuilder.append("\"").append(runIds.get(i)).append("\""); | |||
| } | |||
| queryBuilder.append("]"); | |||
| String query = queryBuilder.toString(); | |||
| Map<String, Object> map = new HashMap<>(); | |||
| map.put("query", query); | |||
| map.put("advancedMode", true); | |||
| map.put("advancedQuery", query); | |||
| String searchQuery = encode(map, false); | |||
| return searchQuery; | |||
| } | |||
| } | |||
| @@ -25,6 +25,7 @@ import java.security.cert.X509Certificate; | |||
| import java.util.HashMap; | |||
| import java.util.List; | |||
| import java.util.Map; | |||
| import java.util.zip.GZIPInputStream; | |||
| /** | |||
| * HTTP请求工具类 | |||
| @@ -447,4 +448,38 @@ public class HttpUtils { | |||
| return true; | |||
| } | |||
| } | |||
| public static String sendGetRequestgzip(String url) throws Exception { | |||
| String resultStr = null; | |||
| HttpGet httpGet = new HttpGet(url); | |||
| httpGet.setHeader("Content-Type", "application/json"); | |||
| httpGet.setHeader("Accept-Encoding", "gzip, deflate"); | |||
| try { | |||
| HttpResponse response = httpClient.execute(httpGet); | |||
| int responseCode = response.getStatusLine().getStatusCode(); | |||
| if (responseCode != 200) { | |||
| throw new IOException("HTTP request failed with response code: " + responseCode); | |||
| } | |||
| // 获取响应内容 | |||
| InputStream responseStream = response.getEntity().getContent(); | |||
| // 检查响应是否被压缩 | |||
| if ("gzip".equalsIgnoreCase(response.getEntity().getContentEncoding().getValue())) { | |||
| responseStream = new GZIPInputStream(responseStream); | |||
| } | |||
| // 读取解压缩后的内容 | |||
| byte[] buffer = new byte[1024]; | |||
| int len; | |||
| StringBuilder decompressedString = new StringBuilder(); | |||
| while ((len = responseStream.read(buffer)) > 0) { | |||
| decompressedString.append(new String(buffer, 0, len)); | |||
| } | |||
| resultStr = decompressedString.toString(); | |||
| } catch (IOException e) { | |||
| e.printStackTrace(); | |||
| } | |||
| return resultStr; | |||
| } | |||
| } | |||
| @@ -1,8 +1,11 @@ | |||
| package com.ruoyi.platform.utils; | |||
| import com.fasterxml.jackson.core.JsonProcessingException; | |||
| import com.fasterxml.jackson.databind.ObjectMapper; | |||
| import org.json.JSONObject; | |||
| import java.io.IOException; | |||
| import java.util.HashMap; | |||
| import java.util.Iterator; | |||
| import java.util.Map; | |||
| public class JsonUtils { | |||
| @@ -28,4 +31,26 @@ public class JsonUtils { | |||
| public static <T> T jsonToObject(String json, Class<T> clazz) throws IOException { | |||
| return objectMapper.readValue(json, clazz); | |||
| } | |||
| // 将JSON字符串转换为扁平化的Map | |||
| public static Map<String, Object> flattenJson(String prefix, Map<String, Object> map) { | |||
| Map<String, Object> flatMap = new HashMap<>(); | |||
| Iterator<Map.Entry<String, Object>> entries = map.entrySet().iterator(); | |||
| while (entries.hasNext()) { | |||
| Map.Entry<String, Object> entry = entries.next(); | |||
| String key = entry.getKey(); | |||
| Object value = entry.getValue(); | |||
| if (value instanceof Map) { | |||
| flatMap.putAll(flattenJson(prefix + key + ".", (Map<String, Object>) value)); | |||
| } else { | |||
| flatMap.put(prefix + key, value); | |||
| } | |||
| } | |||
| return flatMap; | |||
| } | |||
| } | |||
| @@ -0,0 +1,32 @@ | |||
| package com.ruoyi.platform.vo; | |||
| import com.fasterxml.jackson.databind.PropertyNamingStrategy; | |||
| import com.fasterxml.jackson.databind.annotation.JsonNaming; | |||
| import io.swagger.annotations.ApiModelProperty; | |||
| import lombok.Data; | |||
| import java.io.Serializable; | |||
| import java.util.Date; | |||
| import java.util.List; | |||
| import java.util.Map; | |||
| @JsonNaming(PropertyNamingStrategy.SnakeCaseStrategy.class) | |||
| @Data | |||
| public class InsMetricInfoVo implements Serializable { | |||
| @ApiModelProperty(value = "开始时间") | |||
| private Date startTime; | |||
| @ApiModelProperty(value = "实例运行状态") | |||
| private String status; | |||
| @ApiModelProperty(value = "使用数据集") | |||
| private List<String> dataset; | |||
| @ApiModelProperty(value = "实例ID") | |||
| private Integer experimentInsId; | |||
| @ApiModelProperty(value = "训练指标") | |||
| private Map metrics; | |||
| @ApiModelProperty(value = "训练参数") | |||
| private Map params; | |||
| @ApiModelProperty(value = "训练记录ID") | |||
| private String runId; | |||
| private List<String> metricsNames; | |||
| private List<String> paramsNames; | |||
| } | |||
| @@ -22,6 +22,22 @@ | |||
| <result property="state" column="state" jdbcType="INTEGER"/> | |||
| </resultMap> | |||
| <select id="queryByIns" resultMap="ModelDependencyMap"> | |||
| select | |||
| id,current_model_id,exp_ins_id,parent_models,ref_item,train_task,train_dataset,train_params,train_image,test_dataset,project_dependency,version,create_by,create_time,update_by,update_time,state | |||
| from model_dependency | |||
| <where> | |||
| exp_ins_id = #{expInsId} and state = 1 | |||
| </where> | |||
| </select> | |||
| <select id="queryByInsAndTrainTaskId" resultMap="ModelDependencyMap"> | |||
| select | |||
| id,current_model_id,exp_ins_id,parent_models,ref_item,train_task,train_dataset,train_params,train_image,test_dataset,project_dependency,version,create_by,create_time,update_by,update_time,state | |||
| from model_dependency | |||
| <where> | |||
| exp_ins_id = #{expInsId} and train_task like concat('%', #{taskId}, '%') limit 1 | |||
| </where> | |||
| </select> | |||
| <select id="queryChildrenByVersionId" resultMap="ModelDependencyMap"> | |||
| select | |||
| id,current_model_id,exp_ins_id,parent_models,ref_item,train_task,train_dataset,train_params,train_image,test_dataset,project_dependency,version,create_by,create_time,update_by,update_time,state | |||