| @@ -96,6 +96,11 @@ export default [ | |||||
| path: ':workflowId/:id', | path: ':workflowId/:id', | ||||
| component: './Experiment/training/index', | component: './Experiment/training/index', | ||||
| }, | }, | ||||
| { | |||||
| name: '实验对比', | |||||
| path: 'compare', | |||||
| component: './Experiment/Comparison/index', | |||||
| }, | |||||
| ], | ], | ||||
| }, | }, | ||||
| ], | ], | ||||
| @@ -107,13 +112,18 @@ export default [ | |||||
| { | { | ||||
| name: '开发环境', | name: '开发环境', | ||||
| path: '', | path: '', | ||||
| component: './DevelopmentEnvironment/List', | |||||
| }, | |||||
| { | |||||
| name: '创建编辑器', | |||||
| path: 'create', | |||||
| component: './DevelopmentEnvironment/Create', | |||||
| component: './DevelopmentEnvironment/Editor', | |||||
| }, | }, | ||||
| // { | |||||
| // name: '创建编辑器', | |||||
| // path: 'create', | |||||
| // component: './DevelopmentEnvironment/Create', | |||||
| // }, | |||||
| // { | |||||
| // name: '编辑器', | |||||
| // path: 'editor', | |||||
| // component: './DevelopmentEnvironment/Editor', | |||||
| // }, | |||||
| ], | ], | ||||
| }, | }, | ||||
| { | { | ||||
| @@ -58,7 +58,7 @@ const ResourceIntro = ({ resourceType }: ResourceIntroProps) => { | |||||
| }; | }; | ||||
| }), | }), | ||||
| ); | ); | ||||
| if (versionParam) { | |||||
| if (versionParam && res.data.includes(versionParam)) { | |||||
| setVersion(versionParam); | setVersion(versionParam); | ||||
| versionParam = null; | versionParam = null; | ||||
| } else { | } else { | ||||
| @@ -1,10 +1,16 @@ | |||||
| // import { editorUrl, getSessionStorageItem, removeSessionStorageItem } from '@/utils/sessionStorage'; | |||||
| import { getJupyterUrl } from '@/services/developmentEnvironment'; | import { getJupyterUrl } from '@/services/developmentEnvironment'; | ||||
| import { to } from '@/utils/promise'; | import { to } from '@/utils/promise'; | ||||
| import { useEffect, useState } from 'react'; | import { useEffect, useState } from 'react'; | ||||
| const DevelopmentEnvironment = () => { | |||||
| function DevEditor() { | |||||
| const [iframeUrl, setIframeUrl] = useState(''); | const [iframeUrl, setIframeUrl] = useState(''); | ||||
| useEffect(() => { | useEffect(() => { | ||||
| // const url = getSessionStorageItem(editorUrl) || ''; | |||||
| // setIframeUrl(url); | |||||
| // return () => { | |||||
| // removeSessionStorageItem(editorUrl); | |||||
| // }; | |||||
| requestJupyterUrl(); | requestJupyterUrl(); | ||||
| }, []); | }, []); | ||||
| @@ -18,5 +24,5 @@ const DevelopmentEnvironment = () => { | |||||
| }; | }; | ||||
| return <iframe style={{ width: '100%', height: '100%', border: 0 }} src={iframeUrl}></iframe>; | return <iframe style={{ width: '100%', height: '100%', border: 0 }} src={iframeUrl}></iframe>; | ||||
| }; | |||||
| export default DevelopmentEnvironment; | |||||
| } | |||||
| export default DevEditor; | |||||
| @@ -16,6 +16,7 @@ import { | |||||
| } from '@/services/developmentEnvironment'; | } from '@/services/developmentEnvironment'; | ||||
| import themes from '@/styles/theme.less'; | import themes from '@/styles/theme.less'; | ||||
| import { to } from '@/utils/promise'; | import { to } from '@/utils/promise'; | ||||
| import { editorUrl, setSessionStorageItem } from '@/utils/sessionStorage'; | |||||
| import { modalConfirm } from '@/utils/ui'; | import { modalConfirm } from '@/utils/ui'; | ||||
| import { useNavigate } from '@umijs/max'; | import { useNavigate } from '@umijs/max'; | ||||
| import { | import { | ||||
| @@ -38,6 +39,7 @@ export type EditorData = { | |||||
| computing_resource: string; | computing_resource: string; | ||||
| update_by: string; | update_by: string; | ||||
| create_time: string; | create_time: string; | ||||
| url: string; | |||||
| }; | }; | ||||
| function EditorList() { | function EditorList() { | ||||
| @@ -127,6 +129,16 @@ function EditorList() { | |||||
| }); | }); | ||||
| }; | }; | ||||
| // 跳转编辑器页面 | |||||
| const gotoEditorPage = (e: React.MouseEvent, record: EditorData) => { | |||||
| e.stopPropagation(); | |||||
| setSessionStorageItem(editorUrl, record.url); | |||||
| navigate(`/developmentEnvironment/editor`); | |||||
| setCacheState({ | |||||
| pagination, | |||||
| }); | |||||
| }; | |||||
| // 分页切换 | // 分页切换 | ||||
| const handleTableChange: TableProps['onChange'] = (pagination, filters, sorter, { action }) => { | const handleTableChange: TableProps['onChange'] = (pagination, filters, sorter, { action }) => { | ||||
| if (action === 'paginate') { | if (action === 'paginate') { | ||||
| @@ -140,7 +152,12 @@ function EditorList() { | |||||
| dataIndex: 'name', | dataIndex: 'name', | ||||
| key: 'name', | key: 'name', | ||||
| width: '30%', | width: '30%', | ||||
| render: CommonTableCell(), | |||||
| render: (text, record) => | |||||
| record.url ? ( | |||||
| <a onClick={(e) => gotoEditorPage(e, record)}>{text}</a> | |||||
| ) : ( | |||||
| <span>{text ?? '--'}</span> | |||||
| ), | |||||
| }, | }, | ||||
| { | { | ||||
| title: '状态', | title: '状态', | ||||
| @@ -0,0 +1,21 @@ | |||||
| .experiment-comparison { | |||||
| height: 100%; | |||||
| &__header { | |||||
| display: flex; | |||||
| align-items: center; | |||||
| height: 50px; | |||||
| margin-bottom: 10px; | |||||
| padding: 0 30px; | |||||
| background-image: url(@/assets/img/page-title-bg.png); | |||||
| background-repeat: no-repeat; | |||||
| background-position: top center; | |||||
| background-size: 100% 100%; | |||||
| } | |||||
| &__table { | |||||
| height: calc(100% - 60px); | |||||
| padding: 20px 30px 0; | |||||
| background-color: white; | |||||
| border-radius: 10px; | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,212 @@ | |||||
| import CommonTableCell from '@/components/CommonTableCell'; | |||||
| import { useCacheState } from '@/hooks/pageCacheState'; | |||||
| import { getExpEvaluateInfosReq, getExpTrainInfosReq } from '@/services/experiment'; | |||||
| import { to } from '@/utils/promise'; | |||||
| import { useSearchParams } from '@umijs/max'; | |||||
| import { Button, Table, TablePaginationConfig, TableProps } from 'antd'; | |||||
| import classNames from 'classnames'; | |||||
| import { useEffect, useState } from 'react'; | |||||
| import styles from './index.less'; | |||||
| export enum ComparisonType { | |||||
| Train = 'train', // 训练 | |||||
| Evaluate = 'evaluate', // 评估 | |||||
| } | |||||
| function ExperimentComparison() { | |||||
| const [searchParams] = useSearchParams(); | |||||
| const comparisonType = searchParams.get('type'); | |||||
| const experimentId = searchParams.get('experimentId'); | |||||
| const [tableData, setTableData] = useState([]); | |||||
| const [cacheState, setCacheState] = useCacheState(); | |||||
| const [total, setTotal] = useState(0); | |||||
| const [selectedRowKeys, setSelectedRowKeys] = useState<React.Key[]>([]); | |||||
| const [pagination, setPagination] = useState<TablePaginationConfig>( | |||||
| cacheState?.pagination ?? { | |||||
| current: 1, | |||||
| pageSize: 10, | |||||
| }, | |||||
| ); | |||||
| useEffect(() => { | |||||
| getComparisonData(); | |||||
| }, [experimentId]); | |||||
| // 获取对比数据列表 | |||||
| const getComparisonData = async () => { | |||||
| const request = | |||||
| comparisonType === ComparisonType.Train ? getExpTrainInfosReq : getExpEvaluateInfosReq; | |||||
| const [res] = await to(request(experimentId)); | |||||
| if (res && res.data) { | |||||
| const { content = [], totalElements = 0 } = res.data; | |||||
| setTableData(content); | |||||
| setTotal(totalElements); | |||||
| } | |||||
| }; | |||||
| // 分页切换 | |||||
| const handleTableChange: TableProps['onChange'] = (pagination, filters, sorter, { action }) => { | |||||
| if (action === 'paginate') { | |||||
| setPagination(pagination); | |||||
| } | |||||
| // console.log(pagination, filters, sorter, action); | |||||
| }; | |||||
| const rowSelection: TableProps['rowSelection'] = { | |||||
| type: 'checkbox', | |||||
| selectedRowKeys, | |||||
| onChange: (selectedRowKeys: React.Key[], selectedRows: any[]) => { | |||||
| console.log(`selectedRowKeys: ${selectedRowKeys}`, 'selectedRows: ', selectedRows); | |||||
| setSelectedRowKeys(selectedRowKeys); | |||||
| }, | |||||
| }; | |||||
| const columns: TableProps['columns'] = [ | |||||
| { | |||||
| title: '基本信息', | |||||
| children: [ | |||||
| { | |||||
| title: '实例ID', | |||||
| dataIndex: 'name', | |||||
| key: 'name', | |||||
| width: '30%', | |||||
| render: CommonTableCell(), | |||||
| }, | |||||
| { | |||||
| title: '运行时间', | |||||
| dataIndex: 'name', | |||||
| key: 'name', | |||||
| width: '30%', | |||||
| render: CommonTableCell(), | |||||
| }, | |||||
| { | |||||
| title: '运行状态', | |||||
| dataIndex: 'name', | |||||
| key: 'name', | |||||
| width: '30%', | |||||
| render: CommonTableCell(), | |||||
| }, | |||||
| { | |||||
| title: '训练数据集', | |||||
| dataIndex: 'name', | |||||
| key: 'name', | |||||
| width: '30%', | |||||
| render: CommonTableCell(), | |||||
| }, | |||||
| { | |||||
| title: '增量训练', | |||||
| dataIndex: 'name', | |||||
| key: 'name', | |||||
| width: '30%', | |||||
| render: CommonTableCell(), | |||||
| }, | |||||
| ], | |||||
| }, | |||||
| { | |||||
| title: '训练参数', | |||||
| children: [ | |||||
| { | |||||
| title: 'batchsize', | |||||
| dataIndex: 'name', | |||||
| key: 'name', | |||||
| width: '30%', | |||||
| render: CommonTableCell(), | |||||
| }, | |||||
| { | |||||
| title: 'config', | |||||
| dataIndex: 'name', | |||||
| key: 'name', | |||||
| width: '30%', | |||||
| render: CommonTableCell(), | |||||
| }, | |||||
| { | |||||
| title: 'epoch', | |||||
| dataIndex: 'name', | |||||
| key: 'name', | |||||
| width: '30%', | |||||
| render: CommonTableCell(), | |||||
| }, | |||||
| { | |||||
| title: 'lr', | |||||
| dataIndex: 'name', | |||||
| key: 'name', | |||||
| width: '30%', | |||||
| render: CommonTableCell(), | |||||
| }, | |||||
| { | |||||
| title: 'warmup_iters', | |||||
| dataIndex: 'name', | |||||
| key: 'name', | |||||
| width: '30%', | |||||
| render: CommonTableCell(), | |||||
| }, | |||||
| ], | |||||
| }, | |||||
| { | |||||
| title: '训练指标', | |||||
| children: [ | |||||
| { | |||||
| title: 'metrc_name', | |||||
| dataIndex: 'name', | |||||
| key: 'name', | |||||
| width: '30%', | |||||
| render: CommonTableCell(), | |||||
| }, | |||||
| { | |||||
| title: 'test_1', | |||||
| dataIndex: 'name', | |||||
| key: 'name', | |||||
| width: '30%', | |||||
| render: CommonTableCell(), | |||||
| }, | |||||
| { | |||||
| title: 'test_2', | |||||
| dataIndex: 'name', | |||||
| key: 'name', | |||||
| width: '30%', | |||||
| render: CommonTableCell(), | |||||
| }, | |||||
| { | |||||
| title: 'test_3', | |||||
| dataIndex: 'name', | |||||
| key: 'name', | |||||
| width: '30%', | |||||
| render: CommonTableCell(), | |||||
| }, | |||||
| { | |||||
| title: 'test_4', | |||||
| dataIndex: 'name', | |||||
| key: 'name', | |||||
| width: '30%', | |||||
| render: CommonTableCell(), | |||||
| }, | |||||
| ], | |||||
| }, | |||||
| ]; | |||||
| return ( | |||||
| <div className={styles['experiment-comparison']}> | |||||
| <div className={styles['experiment-comparison__header']}> | |||||
| <Button type="default">可视化对比</Button> | |||||
| </div> | |||||
| <div className={classNames('vertical-scroll-table', styles['experiment-comparison__table'])}> | |||||
| <Table | |||||
| dataSource={tableData} | |||||
| columns={columns} | |||||
| rowSelection={rowSelection} | |||||
| scroll={{ y: 'calc(100% - 55px)' }} | |||||
| pagination={{ | |||||
| ...pagination, | |||||
| total: total, | |||||
| showSizeChanger: true, | |||||
| showQuickJumper: true, | |||||
| }} | |||||
| onChange={handleTableChange} | |||||
| rowKey="id" | |||||
| /> | |||||
| </div> | |||||
| </div> | |||||
| ); | |||||
| } | |||||
| export default ExperimentComparison; | |||||
| @@ -18,10 +18,11 @@ import themes from '@/styles/theme.less'; | |||||
| import { elapsedTime, formatDate } from '@/utils/date'; | import { elapsedTime, formatDate } from '@/utils/date'; | ||||
| import { to } from '@/utils/promise'; | import { to } from '@/utils/promise'; | ||||
| import { modalConfirm } from '@/utils/ui'; | import { modalConfirm } from '@/utils/ui'; | ||||
| import { App, Button, ConfigProvider, Space, Table, Tooltip } from 'antd'; | |||||
| import { App, Button, ConfigProvider, Dropdown, Space, Table, Tooltip } from 'antd'; | |||||
| import classNames from 'classnames'; | import classNames from 'classnames'; | ||||
| import { useEffect, useRef, useState } from 'react'; | import { useEffect, useRef, useState } from 'react'; | ||||
| import { useNavigate } from 'react-router-dom'; | import { useNavigate } from 'react-router-dom'; | ||||
| import { ComparisonType } from './Comparison'; | |||||
| import AddExperimentModal from './components/AddExperimentModal'; | import AddExperimentModal from './components/AddExperimentModal'; | ||||
| import TensorBoardStatus, { TensorBoardStatusEnum } from './components/TensorBoardStatus'; | import TensorBoardStatus, { TensorBoardStatusEnum } from './components/TensorBoardStatus'; | ||||
| import Styles from './index.less'; | import Styles from './index.less'; | ||||
| @@ -270,6 +271,26 @@ function Experiment() { | |||||
| window.open(experimentIn.tensorboardUrl, '_blank'); | window.open(experimentIn.tensorboardUrl, '_blank'); | ||||
| } | } | ||||
| }; | }; | ||||
| // 实验对比菜单 | |||||
| const getComparisonMenu = (experimentId) => { | |||||
| return { | |||||
| items: [ | |||||
| { | |||||
| label: <span>训练对比</span>, | |||||
| key: ComparisonType.Train, | |||||
| }, | |||||
| { | |||||
| label: <span>评估对比</span>, | |||||
| key: ComparisonType.Evaluate, | |||||
| }, | |||||
| ], | |||||
| onClick: ({ key }) => { | |||||
| navgite(`/pipeline/experiment/compare?type=${key}&id=${experimentId}`); | |||||
| }, | |||||
| }; | |||||
| }; | |||||
| const columns = [ | const columns = [ | ||||
| { | { | ||||
| title: '实验名称', | title: '实验名称', | ||||
| @@ -320,7 +341,7 @@ function Experiment() { | |||||
| { | { | ||||
| title: '操作', | title: '操作', | ||||
| key: 'action', | key: 'action', | ||||
| width: 300, | |||||
| width: 350, | |||||
| render: (_, record) => ( | render: (_, record) => ( | ||||
| <Space size="small"> | <Space size="small"> | ||||
| <Button | <Button | ||||
| @@ -345,6 +366,14 @@ function Experiment() { | |||||
| > | > | ||||
| 编辑 | 编辑 | ||||
| </Button> | </Button> | ||||
| <Dropdown key="comparison" menu={getComparisonMenu(record.id)}> | |||||
| <a onClick={(e) => e.preventDefault()}> | |||||
| <Space style={{ padding: '0 7px' }}> | |||||
| <KFIcon type="icon-shiyanduibi" /> | |||||
| 实验对比 | |||||
| </Space> | |||||
| </a> | |||||
| </Dropdown> | |||||
| <ConfigProvider | <ConfigProvider | ||||
| theme={{ | theme={{ | ||||
| token: { | token: { | ||||
| @@ -58,7 +58,7 @@ | |||||
| } | } | ||||
| .operation { | .operation { | ||||
| width: 284px; | |||||
| width: 334px; | |||||
| } | } | ||||
| } | } | ||||
| .tableExpandBoxContent { | .tableExpandBoxContent { | ||||
| @@ -93,11 +93,11 @@ const EditPipeline = () => { | |||||
| return; | return; | ||||
| } | } | ||||
| // const [propsRes, propsError] = await to(propsRef.current.getFieldsValue()); | |||||
| // if (propsError) { | |||||
| // message.error('基本信息必填项需配置'); | |||||
| // return; | |||||
| // } | |||||
| const [propsRes, propsError] = await to(propsRef.current.getFieldsValue()); | |||||
| if (propsError) { | |||||
| message.error('节点必填项必须配置'); | |||||
| return; | |||||
| } | |||||
| propsRef.current.propClose(); | propsRef.current.propClose(); | ||||
| setTimeout(() => { | setTimeout(() => { | ||||
| const data = graph.save(); | const data = graph.save(); | ||||
| @@ -40,7 +40,7 @@ export const requestConfig: RequestConfig = { | |||||
| [ | [ | ||||
| (response: AxiosResponse) => { | (response: AxiosResponse) => { | ||||
| const { status, data } = response || {}; | const { status, data } = response || {}; | ||||
| console.log(message, data); | |||||
| // console.log(message, data); | |||||
| if (status >= 200 && status < 300) { | if (status >= 200 && status < 300) { | ||||
| if (data && (data instanceof Blob || data.code === 200)) { | if (data && (data instanceof Blob || data.code === 200)) { | ||||
| return response; | return response; | ||||
| @@ -1,9 +1,8 @@ | |||||
| import { request } from '@umijs/max'; | import { request } from '@umijs/max'; | ||||
| // 查询开发环境url | // 查询开发环境url | ||||
| export function getJupyterUrl(params: any) { | |||||
| export function getJupyterUrl() { | |||||
| return request(`/api/mmp/jupyter/getURL`, { | return request(`/api/mmp/jupyter/getURL`, { | ||||
| method: 'GET', | method: 'GET', | ||||
| params, | |||||
| }); | }); | ||||
| } | } | ||||
| @@ -116,3 +116,17 @@ export function getTensorBoardStatusReq(data) { | |||||
| data, | data, | ||||
| }); | }); | ||||
| } | } | ||||
| // 获取当前实验的模型推理指标信息 | |||||
| export function getExpEvaluateInfosReq(experimentId) { | |||||
| return request(`/api/mmp/aim/getExpEvaluateInfos/${experimentId}`, { | |||||
| method: 'GET', | |||||
| }); | |||||
| } | |||||
| // 获取当前实验的模型训练指标信息 | |||||
| export function getExpTrainInfosReq(experimentId) { | |||||
| return request(`/api/mmp//aim/getExpTrainInfos/${experimentId}`, { | |||||
| method: 'GET', | |||||
| }); | |||||
| } | |||||
| @@ -60,8 +60,8 @@ function patchRouteItems(route: any, menu: any, parentPath: string) { | |||||
| element: React.createElement(lazy(() => import('@/pages/' + path))), | element: React.createElement(lazy(() => import('@/pages/' + path))), | ||||
| path: parentPath + menuItem.path, | path: parentPath + menuItem.path, | ||||
| }; | }; | ||||
| console.log(newRoute); | |||||
| // console.log(newRoute); | |||||
| route.children.push(newRoute); | route.children.push(newRoute); | ||||
| route.routes.push(newRoute); | route.routes.push(newRoute); | ||||
| } | } | ||||
| @@ -74,10 +74,7 @@ export function patchRouteWithRemoteMenus(routes: any) { | |||||
| } | } | ||||
| let proLayout = null; | let proLayout = null; | ||||
| for (const routeItem of routes) { | for (const routeItem of routes) { | ||||
| if (routeItem.id === 'ant-design-pro-layout') { | if (routeItem.id === 'ant-design-pro-layout') { | ||||
| proLayout = routeItem; | proLayout = routeItem; | ||||
| break; | break; | ||||
| } | } | ||||
| @@ -101,7 +98,6 @@ export async function refreshToken() { | |||||
| } | } | ||||
| export function convertCompatRouters(childrens: API.RoutersMenuItem[]): any[] { | export function convertCompatRouters(childrens: API.RoutersMenuItem[]): any[] { | ||||
| return childrens.map((item: API.RoutersMenuItem) => { | return childrens.map((item: API.RoutersMenuItem) => { | ||||
| return { | return { | ||||
| path: item.path, | path: item.path, | ||||
| @@ -149,12 +145,11 @@ export function getMatchMenuItem( | |||||
| const subpath = path.substr(item.path.length + 1); | const subpath = path.substr(item.path.length + 1); | ||||
| const subItem: MenuDataItem[] = getMatchMenuItem(subpath, item.routes); | const subItem: MenuDataItem[] = getMatchMenuItem(subpath, item.routes); | ||||
| items = items.concat(subItem); | items = items.concat(subItem); | ||||
| } else { | } else { | ||||
| const paths = path.split('/'); | const paths = path.split('/'); | ||||
| if (paths.length >= 2 && paths[0] === item.path && paths[1] === 'index') { | if (paths.length >= 2 && paths[0] === item.path && paths[1] === 'index') { | ||||
| console.log(item); | console.log(item); | ||||
| items.push(item); | items.push(item); | ||||
| } | } | ||||
| } | } | ||||
| @@ -2,6 +2,8 @@ | |||||
| export const mirrorNameKey = 'mirror-name'; | export const mirrorNameKey = 'mirror-name'; | ||||
| // 模型部署 | // 模型部署 | ||||
| export const modelDeploymentInfoKey = 'model-deployment-info'; | export const modelDeploymentInfoKey = 'model-deployment-info'; | ||||
| // 编辑器 url | |||||
| export const editorUrl = 'editor-url'; | |||||
| export const getSessionStorageItem = (key: string, isObject: boolean = false) => { | export const getSessionStorageItem = (key: string, isObject: boolean = false) => { | ||||
| const jsonStr = sessionStorage.getItem(key); | const jsonStr = sessionStorage.getItem(key); | ||||
| @@ -20,6 +20,7 @@ | |||||
| "noUnusedParameters": true, // 报告未使用的参数错误 | "noUnusedParameters": true, // 报告未使用的参数错误 | ||||
| "incremental": true, // 通过读写磁盘上的文件来启用增量编译 | "incremental": true, // 通过读写磁盘上的文件来启用增量编译 | ||||
| "noFallthroughCasesInSwitch": true, // 报告switch语句中的fallthrough案例错误 | "noFallthroughCasesInSwitch": true, // 报告switch语句中的fallthrough案例错误 | ||||
| "strictNullChecks": true, // 启用严格的null检查 | |||||
| "baseUrl": "./", | "baseUrl": "./", | ||||
| "paths": { | "paths": { | ||||
| "@/*": ["src/*"], | "@/*": ["src/*"], | ||||
| @@ -205,6 +205,17 @@ | |||||
| <groupId>org.springframework.boot</groupId> | <groupId>org.springframework.boot</groupId> | ||||
| <artifactId>spring-boot-starter-websocket</artifactId> | <artifactId>spring-boot-starter-websocket</artifactId> | ||||
| </dependency> | </dependency> | ||||
| <dependency> | |||||
| <groupId>org.json</groupId> | |||||
| <artifactId>json</artifactId> | |||||
| <version>20210307</version> | |||||
| </dependency> | |||||
| <dependency> | |||||
| <groupId>org.apache.dubbo</groupId> | |||||
| <artifactId>dubbo</artifactId> | |||||
| <version>3.0.8</version> | |||||
| <scope>compile</scope> | |||||
| </dependency> | |||||
| </dependencies> | </dependencies> | ||||
| @@ -4,16 +4,16 @@ import com.ruoyi.common.core.web.controller.BaseController; | |||||
| import com.ruoyi.common.core.web.domain.GenericsAjaxResult; | import com.ruoyi.common.core.web.domain.GenericsAjaxResult; | ||||
| import com.ruoyi.platform.service.AimService; | import com.ruoyi.platform.service.AimService; | ||||
| import com.ruoyi.platform.vo.FrameLogPathVo; | import com.ruoyi.platform.vo.FrameLogPathVo; | ||||
| import com.ruoyi.platform.vo.InsMetricInfoVo; | |||||
| import com.ruoyi.platform.vo.PodStatusVo; | import com.ruoyi.platform.vo.PodStatusVo; | ||||
| import io.swagger.annotations.Api; | import io.swagger.annotations.Api; | ||||
| import io.swagger.annotations.ApiOperation; | import io.swagger.annotations.ApiOperation; | ||||
| import io.swagger.v3.oas.annotations.responses.ApiResponse; | import io.swagger.v3.oas.annotations.responses.ApiResponse; | ||||
| import org.springframework.web.bind.annotation.PostMapping; | |||||
| import org.springframework.web.bind.annotation.RequestBody; | |||||
| import org.springframework.web.bind.annotation.RequestMapping; | |||||
| import org.springframework.web.bind.annotation.RestController; | |||||
| import org.springframework.web.bind.annotation.*; | |||||
| import javax.annotation.Resource; | import javax.annotation.Resource; | ||||
| import java.util.List; | |||||
| @RestController | @RestController | ||||
| @RequestMapping("aim") | @RequestMapping("aim") | ||||
| @Api("Aim管理") | @Api("Aim管理") | ||||
| @@ -22,17 +22,25 @@ public class AimController extends BaseController { | |||||
| @Resource | @Resource | ||||
| private AimService aimService; | private AimService aimService; | ||||
| /** | |||||
| * 启动tensorBoard接口 | |||||
| * | |||||
| * @param frameLogPathVo 存储路径 | |||||
| * @return url | |||||
| */ | |||||
| @PostMapping("/run") | |||||
| @ApiOperation("启动aim`") | |||||
| @GetMapping("/getExpTrainInfos/{experiment_id}") | |||||
| @ApiOperation("获取当前实验的模型训练指标信息") | |||||
| @ApiResponse | @ApiResponse | ||||
| public GenericsAjaxResult<String> runAim(@RequestBody FrameLogPathVo frameLogPathVo) throws Exception { | |||||
| return genericsSuccess(aimService.runAim(frameLogPathVo)); | |||||
| public GenericsAjaxResult<List<InsMetricInfoVo>> getExpTrainInfos(@PathVariable("experiment_id") Integer experimentId) throws Exception { | |||||
| return genericsSuccess(aimService.getExpTrainInfos(experimentId)); | |||||
| } | } | ||||
| @GetMapping("/getExpEvaluateInfos/{experiment_id}") | |||||
| @ApiOperation("获取当前实验的模型推理指标信息") | |||||
| @ApiResponse | |||||
| public GenericsAjaxResult<List<InsMetricInfoVo>> getExpEvaluateInfos(@PathVariable("experiment_id") Integer experimentId) throws Exception { | |||||
| return genericsSuccess(aimService.getExpEvaluateInfos(experimentId)); | |||||
| } | |||||
| @PostMapping("/getExpMetrics") | |||||
| @ApiOperation("获取当前实验的指标对比地址") | |||||
| @ApiResponse | |||||
| public GenericsAjaxResult<String> getExpMetrics(@RequestBody List<String> runIds) throws Exception { | |||||
| return genericsSuccess(aimService.getExpMetrics(runIds)); | |||||
| } | |||||
| } | } | ||||
| @@ -84,5 +84,9 @@ public interface ModelDependencyDao { | |||||
| List<ModelDependency> queryByModelDependency(@Param("modelDependency") ModelDependency modelDependency); | List<ModelDependency> queryByModelDependency(@Param("modelDependency") ModelDependency modelDependency); | ||||
| List<ModelDependency> queryChildrenByVersionId(@Param("model_id")String modelId, @Param("version")String version); | List<ModelDependency> queryChildrenByVersionId(@Param("model_id")String modelId, @Param("version")String version); | ||||
| List<ModelDependency> queryByIns(@Param("expInsId")Integer expInsId); | |||||
| ModelDependency queryByInsAndTrainTaskId(@Param("expInsId")Integer expInsId,@Param("taskId") String taskId); | |||||
| } | } | ||||
| @@ -7,20 +7,15 @@ import com.ruoyi.platform.mapper.ExperimentDao; | |||||
| import com.ruoyi.platform.mapper.ExperimentInsDao; | import com.ruoyi.platform.mapper.ExperimentInsDao; | ||||
| import com.ruoyi.platform.mapper.ModelDependencyDao; | import com.ruoyi.platform.mapper.ModelDependencyDao; | ||||
| import com.ruoyi.platform.service.ExperimentInsService; | import com.ruoyi.platform.service.ExperimentInsService; | ||||
| import com.ruoyi.platform.service.ModelDependencyService; | |||||
| import com.ruoyi.platform.utils.JacksonUtil; | import com.ruoyi.platform.utils.JacksonUtil; | ||||
| import org.apache.commons.lang3.StringUtils; | import org.apache.commons.lang3.StringUtils; | ||||
| import org.springframework.beans.factory.annotation.Autowired; | import org.springframework.beans.factory.annotation.Autowired; | ||||
| import org.springframework.data.domain.Page; | |||||
| import org.springframework.scheduling.annotation.Scheduled; | import org.springframework.scheduling.annotation.Scheduled; | ||||
| import org.springframework.stereotype.Component; | import org.springframework.stereotype.Component; | ||||
| import javax.annotation.Resource; | import javax.annotation.Resource; | ||||
| import java.io.IOException; | import java.io.IOException; | ||||
| import java.util.ArrayList; | |||||
| import java.util.Date; | |||||
| import java.util.List; | |||||
| import java.util.Map; | |||||
| import java.util.*; | |||||
| @Component() | @Component() | ||||
| public class ExperimentInstanceStatusTask { | public class ExperimentInstanceStatusTask { | ||||
| @@ -34,7 +29,7 @@ public class ExperimentInstanceStatusTask { | |||||
| private ModelDependencyDao modelDependencyDao; | private ModelDependencyDao modelDependencyDao; | ||||
| private List<Integer> experimentIds = new ArrayList<>(); | private List<Integer> experimentIds = new ArrayList<>(); | ||||
| @Scheduled(cron = "0/14 * * * * ?") // 每30S执行一次 | |||||
| @Scheduled(cron = "0/30 * * * * ?") // 每30S执行一次 | |||||
| public void executeExperimentInsStatus() throws IOException { | public void executeExperimentInsStatus() throws IOException { | ||||
| // 首先查到所有非终止态的实验实例 | // 首先查到所有非终止态的实验实例 | ||||
| List<ExperimentIns> experimentInsList = experimentInsService.queryByExperimentIsNotTerminated(); | List<ExperimentIns> experimentInsList = experimentInsService.queryByExperimentIsNotTerminated(); | ||||
| @@ -46,95 +41,94 @@ public class ExperimentInstanceStatusTask { | |||||
| String oldStatus = experimentIns.getStatus(); | String oldStatus = experimentIns.getStatus(); | ||||
| try { | try { | ||||
| experimentIns = experimentInsService.queryStatusFromArgo(experimentIns); | experimentIns = experimentInsService.queryStatusFromArgo(experimentIns); | ||||
| }catch (Exception e){ | |||||
| } catch (Exception e) { | |||||
| experimentIns.setStatus("Failed"); | experimentIns.setStatus("Failed"); | ||||
| } | } | ||||
| // if (!StringUtils.equals(oldStatus,experimentIns.getStatus())){ | |||||
| experimentIns.setUpdateTime(new Date()); | |||||
| // 线程安全的添加操作 | |||||
| synchronized (experimentIds) { | |||||
| experimentIds.add(experimentIns.getExperimentId()); | |||||
| } | |||||
| updateList.add(experimentIns); | |||||
| // } | |||||
| // experimentInsDao.update(experimentIns); | |||||
| experimentIns.setUpdateTime(new Date()); | |||||
| // 线程安全的添加操作 | |||||
| synchronized (experimentIds) { | |||||
| experimentIds.add(experimentIns.getExperimentId()); | |||||
| } | |||||
| updateList.add(experimentIns); | |||||
| } | } | ||||
| } | } | ||||
| if (updateList.size() > 0){ | |||||
| if (updateList.size() > 0) { | |||||
| experimentInsDao.insertOrUpdateBatch(updateList); | experimentInsDao.insertOrUpdateBatch(updateList); | ||||
| //遍历模型关系表,找到 | //遍历模型关系表,找到 | ||||
| List<ModelDependency> modelDependencyList = new ArrayList<ModelDependency>(); | List<ModelDependency> modelDependencyList = new ArrayList<ModelDependency>(); | ||||
| for (ExperimentIns experimentIns : updateList){ | |||||
| for (ExperimentIns experimentIns : updateList) { | |||||
| ModelDependency modelDependencyquery = new ModelDependency(); | ModelDependency modelDependencyquery = new ModelDependency(); | ||||
| modelDependencyquery.setExpInsId(experimentIns.getId()); | modelDependencyquery.setExpInsId(experimentIns.getId()); | ||||
| modelDependencyquery.setState(2); | modelDependencyquery.setState(2); | ||||
| List<ModelDependency> modelDependencyListquery = modelDependencyDao.queryByModelDependency(modelDependencyquery); | List<ModelDependency> modelDependencyListquery = modelDependencyDao.queryByModelDependency(modelDependencyquery); | ||||
| if (modelDependencyListquery==null||modelDependencyListquery.size()==0){ | |||||
| if (modelDependencyListquery == null || modelDependencyListquery.size() == 0) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| ModelDependency modelDependency = modelDependencyListquery.get(0); | ModelDependency modelDependency = modelDependencyListquery.get(0); | ||||
| //查看状态, | //查看状态, | ||||
| if (StringUtils.equals("Failed",experimentIns.getStatus())){ | |||||
| if (StringUtils.equals("Failed", experimentIns.getStatus())) { | |||||
| //取出节点状态 | //取出节点状态 | ||||
| String trainTask = modelDependency.getTrainTask(); | String trainTask = modelDependency.getTrainTask(); | ||||
| Map<String, Object> trainMap = JacksonUtil.parseJSONStr2Map(trainTask); | Map<String, Object> trainMap = JacksonUtil.parseJSONStr2Map(trainTask); | ||||
| String task_id = (String) trainMap.get("task_id"); | String task_id = (String) trainMap.get("task_id"); | ||||
| if (StringUtils.isEmpty(task_id)){ | |||||
| if (StringUtils.isEmpty(task_id)) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| String nodesStatus = experimentIns.getNodesStatus(); | String nodesStatus = experimentIns.getNodesStatus(); | ||||
| Map<String, Object> nodeMaps = JacksonUtil.parseJSONStr2Map(nodesStatus); | Map<String, Object> nodeMaps = JacksonUtil.parseJSONStr2Map(nodesStatus); | ||||
| Map<String, Object> nodeMap = JacksonUtil.parseJSONStr2Map(JacksonUtil.toJSONString(nodeMaps.get(task_id))); | Map<String, Object> nodeMap = JacksonUtil.parseJSONStr2Map(JacksonUtil.toJSONString(nodeMaps.get(task_id))); | ||||
| if (nodeMap==null){ | |||||
| if (nodeMap == null) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (!StringUtils.equals("Succeeded",(String)nodeMap.get("phase"))){ | |||||
| if (!StringUtils.equals("Succeeded", (String) nodeMap.get("phase"))) { | |||||
| modelDependency.setState(0); | modelDependency.setState(0); | ||||
| modelDependencyList.add(modelDependency); | modelDependencyList.add(modelDependency); | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| if (modelDependencyList.size()>0) { | |||||
| if (modelDependencyList.size() > 0) { | |||||
| modelDependencyDao.insertOrUpdateBatch(modelDependencyList); | modelDependencyDao.insertOrUpdateBatch(modelDependencyList); | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @Scheduled(cron = "0/17 * * * * ?") // / 每30S执行一次 | |||||
| @Scheduled(cron = "0/30 * * * * ?") // / 每30S执行一次 | |||||
| public void executeExperimentStatus() throws IOException { | public void executeExperimentStatus() throws IOException { | ||||
| if (experimentIds.size()==0){ | |||||
| if (experimentIds.size() == 0) { | |||||
| return; | return; | ||||
| } | } | ||||
| // 存储需要更新的实验对象列表 | // 存储需要更新的实验对象列表 | ||||
| List<Experiment> updateExperiments = new ArrayList<>(); | List<Experiment> updateExperiments = new ArrayList<>(); | ||||
| for (Integer experimentId : experimentIds){ | |||||
| for (Integer experimentId : experimentIds) { | |||||
| // 获取当前实验的所有实例列表 | // 获取当前实验的所有实例列表 | ||||
| List<ExperimentIns> insList = experimentInsService.getByExperimentId(experimentId); | List<ExperimentIns> insList = experimentInsService.getByExperimentId(experimentId); | ||||
| List<String> statusList = new ArrayList<String>(); | List<String> statusList = new ArrayList<String>(); | ||||
| // 更新实验状态列表 | // 更新实验状态列表 | ||||
| for (int i=0;i<insList.size();i++){ | |||||
| for (int i = 0; i < insList.size(); i++) { | |||||
| statusList.add(insList.get(i).getStatus()); | statusList.add(insList.get(i).getStatus()); | ||||
| } | } | ||||
| String subStatus = statusList.toString().substring(1, statusList.toString().length() - 1); | String subStatus = statusList.toString().substring(1, statusList.toString().length() - 1); | ||||
| Experiment experiment = experimentDao.queryById(experimentId); | Experiment experiment = experimentDao.queryById(experimentId); | ||||
| // 如果实验状态列表发生变化,则更新实验对象,并加入到需要更新的列表中 | // 如果实验状态列表发生变化,则更新实验对象,并加入到需要更新的列表中 | ||||
| if (!StringUtils.equals(subStatus,experiment.getStatusList())){ | |||||
| if (!StringUtils.equals(subStatus, experiment.getStatusList())) { | |||||
| experiment.setStatusList(subStatus); | experiment.setStatusList(subStatus); | ||||
| updateExperiments.add(experiment); | updateExperiments.add(experiment); | ||||
| } | } | ||||
| } | } | ||||
| if (!updateExperiments.isEmpty()) { | if (!updateExperiments.isEmpty()) { | ||||
| experimentDao.insertOrUpdateBatch(updateExperiments); | experimentDao.insertOrUpdateBatch(updateExperiments); | ||||
| for (int index = 0; index < updateExperiments.size(); index++) { | |||||
| // 线程安全的删除操作 | |||||
| synchronized (experimentIds) { | |||||
| experimentIds.remove(index); | |||||
| // 使用Iterator进行安全的删除操作 | |||||
| Iterator<Integer> iterator = experimentIds.iterator(); | |||||
| while (iterator.hasNext()) { | |||||
| Integer experimentId = iterator.next(); | |||||
| for (Experiment experiment : updateExperiments) { | |||||
| if (experiment.getId().equals(experimentId)) { | |||||
| iterator.remove(); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -1,7 +1,14 @@ | |||||
| package com.ruoyi.platform.service; | package com.ruoyi.platform.service; | ||||
| import com.ruoyi.platform.vo.FrameLogPathVo; | |||||
| import com.ruoyi.platform.vo.InsMetricInfoVo; | |||||
| import java.util.List; | |||||
| public interface AimService { | public interface AimService { | ||||
| String runAim(FrameLogPathVo frameLogPathVo); | |||||
| List<InsMetricInfoVo> getExpTrainInfos(Integer experimentId) throws Exception; | |||||
| List<InsMetricInfoVo> getExpEvaluateInfos(Integer experimentId) throws Exception; | |||||
| String getExpMetrics(List<String> runIds) throws Exception; | |||||
| } | } | ||||
| @@ -62,4 +62,8 @@ public interface ModelDependencyService { | |||||
| List<ModelDependency> queryByModelDependency(ModelDependency modelDependency) throws IOException; | List<ModelDependency> queryByModelDependency(ModelDependency modelDependency) throws IOException; | ||||
| ModelDependcyTreeVo getModelDependencyTree(ModelDependency modelDependency) throws Exception; | ModelDependcyTreeVo getModelDependencyTree(ModelDependency modelDependency) throws Exception; | ||||
| List<ModelDependency> queryByIns(Integer expInsId); | |||||
| ModelDependency queryByInsAndTrainTaskId(Integer expInsId, String taskId); | |||||
| } | } | ||||
| @@ -1,13 +1,175 @@ | |||||
| package com.ruoyi.platform.service.impl; | package com.ruoyi.platform.service.impl; | ||||
| import com.ruoyi.platform.domain.ExperimentIns; | |||||
| import com.ruoyi.platform.domain.ModelDependency; | |||||
| import com.ruoyi.platform.service.AimService; | import com.ruoyi.platform.service.AimService; | ||||
| import com.ruoyi.platform.vo.FrameLogPathVo; | |||||
| import com.ruoyi.platform.service.ExperimentInsService; | |||||
| import com.ruoyi.platform.service.ModelDependencyService; | |||||
| import com.ruoyi.platform.utils.AIM64EncoderUtil; | |||||
| import com.ruoyi.platform.utils.HttpUtils; | |||||
| import com.ruoyi.platform.utils.JacksonUtil; | |||||
| import com.ruoyi.platform.utils.JsonUtils; | |||||
| import com.ruoyi.platform.vo.InsMetricInfoVo; | |||||
| import org.apache.commons.lang3.StringUtils; | |||||
| import org.springframework.beans.factory.annotation.Value; | |||||
| import org.springframework.stereotype.Service; | import org.springframework.stereotype.Service; | ||||
| import javax.annotation.Resource; | |||||
| import java.net.URLEncoder; | |||||
| import java.util.*; | |||||
| import java.util.stream.Collectors; | |||||
| @Service | @Service | ||||
| public class AimServiceImpl implements AimService { | public class AimServiceImpl implements AimService { | ||||
| @Resource | |||||
| private ExperimentInsService experimentInsService; | |||||
| @Resource | |||||
| private ModelDependencyService modelDependencyService; | |||||
| @Value("${aim.url}") | |||||
| private String aimUrl; | |||||
| @Value("${aim.proxyUrl}") | |||||
| private String aimProxyUrl; | |||||
| @Override | @Override | ||||
| public String runAim(FrameLogPathVo frameLogPathVo) { | |||||
| return null; | |||||
| public List<InsMetricInfoVo> getExpTrainInfos(Integer experimentId) throws Exception { | |||||
| return getAimRunInfos(true,experimentId); | |||||
| } | } | ||||
| @Override | |||||
| public List<InsMetricInfoVo> getExpEvaluateInfos(Integer experimentId) throws Exception { | |||||
| return getAimRunInfos(false,experimentId); | |||||
| } | |||||
| @Override | |||||
| public String getExpMetrics(List<String> runIds) throws Exception { | |||||
| String decode = AIM64EncoderUtil.decode(runIds); | |||||
| return aimUrl+"/api/runs/search/run?query="+decode; | |||||
| } | |||||
| private List<InsMetricInfoVo> getAimRunInfos(boolean isTrain,Integer experimentId) throws Exception { | |||||
| String experimentName = "experiment-"+experimentId+"-train"; | |||||
| if (!isTrain){ | |||||
| experimentName = "experiment-"+experimentId+"-evaluate"; | |||||
| } | |||||
| String encodedUrlString = URLEncoder.encode("run.experiment==\""+experimentName+"\"", "UTF-8"); | |||||
| String url = aimProxyUrl+"/api/runs/search/run?query="+encodedUrlString; | |||||
| String s = HttpUtils.sendGetRequest(url); | |||||
| List<Map<String, Object>> response = JacksonUtil.parseJSONStr2MapList(s); | |||||
| if (response == null || response.size() == 0){ | |||||
| return new ArrayList<>(); | |||||
| } | |||||
| //查询实例数据 | |||||
| List<ExperimentIns> byExperimentId = experimentInsService.getByExperimentId(experimentId); | |||||
| if (byExperimentId == null || byExperimentId.size() == 0){ | |||||
| return new ArrayList<>(); | |||||
| } | |||||
| List<InsMetricInfoVo> aimRunInfoList = new ArrayList<>(); | |||||
| for (Map<String, Object> run : response) { | |||||
| InsMetricInfoVo aimRunInfo = new InsMetricInfoVo(); | |||||
| String runHash = (String) run.get("run_hash"); | |||||
| aimRunInfo.setRunId(runHash); | |||||
| Map params= (Map) run.get("params"); | |||||
| Map<String, Object> paramMap = JsonUtils.flattenJson("", params); | |||||
| aimRunInfo.setParams(paramMap); | |||||
| String aimrunId = (String) paramMap.get("id"); | |||||
| Map<String, Object> tracesMap= (Map<String, Object>) run.get("traces"); | |||||
| List<Map<String, Object>> metricList = (List<Map<String, Object>>) tracesMap.get("metric"); | |||||
| //过滤name为__system__开头的对象 | |||||
| aimRunInfo.setMetrics(new HashMap<>()); | |||||
| if (metricList != null && metricList.size() > 0){ | |||||
| List<Map<String, Object>> metricRelList = metricList.stream() | |||||
| .filter(map -> !StringUtils.startsWith((String) map.get("name"),"__system__" )) | |||||
| .collect(Collectors.toList()); | |||||
| if (metricRelList!= null && metricRelList.size() > 0){ | |||||
| Map<String, Object> relMetricMap = new HashMap<>(); | |||||
| for (Map<String, Object> metricMap : metricRelList) { | |||||
| relMetricMap.put((String)metricMap.get("name"), metricMap.get("last_value")); | |||||
| } | |||||
| aimRunInfo.setMetrics(relMetricMap); | |||||
| } | |||||
| } | |||||
| //找到ins | |||||
| for (ExperimentIns ins : byExperimentId) { | |||||
| String metricRecordString = ins.getMetricRecord(); | |||||
| if (StringUtils.isEmpty(metricRecordString)){ | |||||
| continue; | |||||
| } | |||||
| if (metricRecordString.contains(aimrunId)){ | |||||
| aimRunInfo.setExperimentInsId(ins.getId()); | |||||
| aimRunInfo.setStatus(ins.getStatus()); | |||||
| aimRunInfo.setStartTime(ins.getCreateTime()); | |||||
| Map<String, Object> metricRecordMap = JacksonUtil.parseJSONStr2Map(metricRecordString); | |||||
| //metricRecord 格式为{"train":[{"task_id":"model-train-35303690","run_id":"5560d78f54314672b60304c8d6ba03b8","experiment_name":"experiment-30-train"}],"evaluate":[{"task_id":"model-train-35303690","run_id":"5560d78f54314672b60304c8d6ba03b8","experiment_name":"experiment-30-train"}]} | |||||
| //遍历metricRecord,找到当前task_id对应的ModelDependency | |||||
| if (isTrain){ | |||||
| List<Map<String, Object>> trainList = (List<Map<String, Object>>) metricRecordMap.get("train"); | |||||
| List<String> trainDateSet = getTrainDateSet(trainList, ins.getId(), isTrain); | |||||
| aimRunInfo.setDataset(trainDateSet); | |||||
| }else { | |||||
| List<Map<String, Object>> trainList = (List<Map<String, Object>>) metricRecordMap.get("evaluate"); | |||||
| List<String> trainDateSet = getTrainDateSet(trainList, ins.getId(), isTrain); | |||||
| aimRunInfo.setDataset(trainDateSet); | |||||
| } | |||||
| } | |||||
| } | |||||
| aimRunInfoList.add(aimRunInfo); | |||||
| } | |||||
| //判断哪个最长 | |||||
| // 获取所有 metrics 的 key 的并集 | |||||
| Set<String> metricsKeys = (Set<String>) aimRunInfoList.stream() | |||||
| .map(InsMetricInfoVo::getMetrics) | |||||
| .flatMap(metrics -> metrics.keySet().stream()) | |||||
| .collect(Collectors.toSet()); | |||||
| // 将并集赋值给每个 InsMetricInfoVo 的 metricsNames 属性 | |||||
| aimRunInfoList.forEach(vo -> vo.setMetricsNames(new ArrayList<>(metricsKeys))); | |||||
| // 获取所有 params 的 key 的并集 | |||||
| Set<String> paramKeys = (Set<String>) aimRunInfoList.stream() | |||||
| .map(InsMetricInfoVo::getParams) | |||||
| .flatMap(params -> params.keySet().stream()) | |||||
| .collect(Collectors.toSet()); | |||||
| // 将并集赋值给每个 InsMetricInfoVo 的 paramsNames 属性 | |||||
| aimRunInfoList.forEach(vo -> vo.setParamsNames(new ArrayList<>(paramKeys))); | |||||
| return aimRunInfoList; | |||||
| } | |||||
| private List<String> getTrainDateSet(List<Map<String, Object>> trainList,Integer expInsId,boolean isTrain){ | |||||
| if (trainList == null || trainList.size() == 0){ | |||||
| return new ArrayList<>(); | |||||
| } | |||||
| List<String> datasetList = new ArrayList<>(); | |||||
| for (Map<String, Object> trainMap : trainList) { | |||||
| String task_id = (String) trainMap.get("task_id"); | |||||
| //modelDependency取到数据集文件 | |||||
| ModelDependency modelDependency = modelDependencyService.queryByInsAndTrainTaskId(expInsId, task_id); | |||||
| //把数据集文件组装成String后放进List | |||||
| String datasetString = ""; | |||||
| if (isTrain){ | |||||
| datasetString = modelDependency.getTrainDataset(); | |||||
| }else { | |||||
| datasetString = modelDependency.getTestDataset(); | |||||
| } | |||||
| List<Map<String, Object>> datasetListMap = JacksonUtil.parseJSONStr2MapList(datasetString); | |||||
| if (datasetListMap != null && datasetListMap.size() > 0){ | |||||
| for (Map<String, Object> datasetMap : datasetListMap) { | |||||
| //[{"dataset_id":20,"dataset_version":"v0.1.0","dataset_name":"手写体识别模型依赖测试训练数据集"}] | |||||
| String datasetName = (String) datasetMap.get("dataset_name")+":"+(String) datasetMap.get("dataset_version"); | |||||
| datasetList.add(datasetName); | |||||
| } | |||||
| } | |||||
| } | |||||
| return datasetList; | |||||
| } | |||||
| } | } | ||||
| @@ -254,9 +254,13 @@ public class ExperimentServiceImpl implements ExperimentService { | |||||
| //获取训练参数 | //获取训练参数 | ||||
| Map<String, Object> metricRecord = (Map<String, Object>) runResMap.get("metric_record"); | Map<String, Object> metricRecord = (Map<String, Object>) runResMap.get("metric_record"); | ||||
| Map<String, Object> metadata = (Map<String, Object>) data.get("metadata"); | Map<String, Object> metadata = (Map<String, Object>) data.get("metadata"); | ||||
| // 插入记录到实验实例表 | // 插入记录到实验实例表 | ||||
| ExperimentIns experimentIns = new ExperimentIns(); | ExperimentIns experimentIns = new ExperimentIns(); | ||||
| if (metricRecord != null){ | |||||
| experimentIns.setMetricRecord(JacksonUtil.toJSONString(metricRecord)); | |||||
| } | |||||
| experimentIns.setExperimentId(experiment.getId()); | experimentIns.setExperimentId(experiment.getId()); | ||||
| experimentIns.setArgoInsNs((String) metadata.get("namespace")); | experimentIns.setArgoInsNs((String) metadata.get("namespace")); | ||||
| experimentIns.setArgoInsName((String) metadata.get("name")); | experimentIns.setArgoInsName((String) metadata.get("name")); | ||||
| @@ -275,8 +279,9 @@ public class ExperimentServiceImpl implements ExperimentService { | |||||
| Map<String, Object> converMap2 = JsonUtils.jsonToMap(JacksonUtil.replaceInAarry(convertRes, params)); | Map<String, Object> converMap2 = JsonUtils.jsonToMap(JacksonUtil.replaceInAarry(convertRes, params)); | ||||
| Map<String ,Object> dependendcy = (Map<String, Object>)converMap2.get("model_dependency"); | Map<String ,Object> dependendcy = (Map<String, Object>)converMap2.get("model_dependency"); | ||||
| Map<String ,Object> trainInfo = (Map<String, Object>)converMap2.get("component_info"); | Map<String ,Object> trainInfo = (Map<String, Object>)converMap2.get("component_info"); | ||||
| insertModelDependency(dependendcy,trainInfo,insert.getId(),experiment.getName()); | |||||
| if (dependendcy != null && trainInfo != null){ | |||||
| insertModelDependency(dependendcy,trainInfo,insert.getId(),experiment.getName()); | |||||
| } | |||||
| }catch (Exception e){ | }catch (Exception e){ | ||||
| throw new RuntimeException(e); | throw new RuntimeException(e); | ||||
| } | } | ||||
| @@ -17,10 +17,7 @@ import org.springframework.data.domain.PageRequest; | |||||
| import javax.annotation.Resource; | import javax.annotation.Resource; | ||||
| import java.io.IOException; | import java.io.IOException; | ||||
| import java.util.ArrayList; | |||||
| import java.util.Date; | |||||
| import java.util.List; | |||||
| import java.util.Map; | |||||
| import java.util.*; | |||||
| import java.util.stream.Collectors; | import java.util.stream.Collectors; | ||||
| /** | /** | ||||
| @@ -97,6 +94,16 @@ public class ModelDependencyServiceImpl implements ModelDependencyService { | |||||
| return modelDependcyTreeVo; | return modelDependcyTreeVo; | ||||
| } | } | ||||
| @Override | |||||
| public List<ModelDependency> queryByIns(Integer expInsId) { | |||||
| return modelDependencyDao.queryByIns(expInsId); | |||||
| } | |||||
| @Override | |||||
| public ModelDependency queryByInsAndTrainTaskId(Integer expInsId, String taskId) { | |||||
| return modelDependencyDao.queryByInsAndTrainTaskId(expInsId,taskId); | |||||
| } | |||||
| /** | /** | ||||
| * 递归父模型 | * 递归父模型 | ||||
| * @param modelDependcyTreeVo | * @param modelDependcyTreeVo | ||||
| @@ -0,0 +1,76 @@ | |||||
| package com.ruoyi.platform.utils; | |||||
| import com.alibaba.fastjson.JSON; | |||||
| import java.util.Base64; | |||||
| import java.util.HashMap; | |||||
| import java.util.List; | |||||
| import java.util.Map; | |||||
| public class AIM64EncoderUtil { | |||||
| private static final String AIM64_ENCODING_PREFIX = "O-"; | |||||
| private static final Map<String, String> BS64_REPLACE_CHARACTERS_ENCODING = new HashMap<>(); | |||||
| static { | |||||
| BS64_REPLACE_CHARACTERS_ENCODING.put("=", ""); | |||||
| BS64_REPLACE_CHARACTERS_ENCODING.put("+", "-"); | |||||
| BS64_REPLACE_CHARACTERS_ENCODING.put("/", "_"); | |||||
| } | |||||
| public static String aim64encode(Map<String, Object> value) { | |||||
| String jsonEncoded = JSON.toJSONString(value); | |||||
| String base64Encoded = Base64.getEncoder().encodeToString(jsonEncoded.getBytes()); | |||||
| String aim64Encoded = base64Encoded; | |||||
| for (Map.Entry<String, String> entry : BS64_REPLACE_CHARACTERS_ENCODING.entrySet()) { | |||||
| aim64Encoded = aim64Encoded.replace(entry.getKey(), entry.getValue()); | |||||
| } | |||||
| return AIM64_ENCODING_PREFIX + aim64Encoded; | |||||
| } | |||||
| public static String encode(Map<String, Object> value, boolean oneWayHashing) { | |||||
| if (oneWayHashing) { | |||||
| return md5(JSON.toJSONString(value)); | |||||
| } | |||||
| return aim64encode(value); | |||||
| } | |||||
| private static String md5(String input) { | |||||
| try { | |||||
| java.security.MessageDigest md = java.security.MessageDigest.getInstance("MD5"); | |||||
| byte[] array = md.digest(input.getBytes()); | |||||
| StringBuilder sb = new StringBuilder(); | |||||
| for (byte b : array) { | |||||
| sb.append(Integer.toHexString((b & 0xFF) | 0x100).substring(1, 3)); | |||||
| } | |||||
| return sb.toString(); | |||||
| } catch (java.security.NoSuchAlgorithmException e) { | |||||
| e.printStackTrace(); | |||||
| } | |||||
| return null; | |||||
| } | |||||
| public static String decode(List<String> runIds) { | |||||
| // 确保 runIds 列表的大小为 3 | |||||
| if (runIds == null || runIds.size() == 0) { | |||||
| throw new IllegalArgumentException("runIds 不能为空"); | |||||
| } | |||||
| // 构建查询字符串 | |||||
| StringBuilder queryBuilder = new StringBuilder("run.hash in ["); | |||||
| for (int i = 0; i < runIds.size(); i++) { | |||||
| if (i > 0) { | |||||
| queryBuilder.append(","); | |||||
| } | |||||
| queryBuilder.append("\"").append(runIds.get(i)).append("\""); | |||||
| } | |||||
| queryBuilder.append("]"); | |||||
| String query = queryBuilder.toString(); | |||||
| Map<String, Object> map = new HashMap<>(); | |||||
| map.put("query", query); | |||||
| map.put("advancedMode", true); | |||||
| map.put("advancedQuery", query); | |||||
| String searchQuery = encode(map, false); | |||||
| return searchQuery; | |||||
| } | |||||
| } | |||||
| @@ -25,6 +25,7 @@ import java.security.cert.X509Certificate; | |||||
| import java.util.HashMap; | import java.util.HashMap; | ||||
| import java.util.List; | import java.util.List; | ||||
| import java.util.Map; | import java.util.Map; | ||||
| import java.util.zip.GZIPInputStream; | |||||
| /** | /** | ||||
| * HTTP请求工具类 | * HTTP请求工具类 | ||||
| @@ -447,4 +448,38 @@ public class HttpUtils { | |||||
| return true; | return true; | ||||
| } | } | ||||
| } | } | ||||
| public static String sendGetRequestgzip(String url) throws Exception { | |||||
| String resultStr = null; | |||||
| HttpGet httpGet = new HttpGet(url); | |||||
| httpGet.setHeader("Content-Type", "application/json"); | |||||
| httpGet.setHeader("Accept-Encoding", "gzip, deflate"); | |||||
| try { | |||||
| HttpResponse response = httpClient.execute(httpGet); | |||||
| int responseCode = response.getStatusLine().getStatusCode(); | |||||
| if (responseCode != 200) { | |||||
| throw new IOException("HTTP request failed with response code: " + responseCode); | |||||
| } | |||||
| // 获取响应内容 | |||||
| InputStream responseStream = response.getEntity().getContent(); | |||||
| // 检查响应是否被压缩 | |||||
| if ("gzip".equalsIgnoreCase(response.getEntity().getContentEncoding().getValue())) { | |||||
| responseStream = new GZIPInputStream(responseStream); | |||||
| } | |||||
| // 读取解压缩后的内容 | |||||
| byte[] buffer = new byte[1024]; | |||||
| int len; | |||||
| StringBuilder decompressedString = new StringBuilder(); | |||||
| while ((len = responseStream.read(buffer)) > 0) { | |||||
| decompressedString.append(new String(buffer, 0, len)); | |||||
| } | |||||
| resultStr = decompressedString.toString(); | |||||
| } catch (IOException e) { | |||||
| e.printStackTrace(); | |||||
| } | |||||
| return resultStr; | |||||
| } | |||||
| } | } | ||||
| @@ -1,8 +1,11 @@ | |||||
| package com.ruoyi.platform.utils; | package com.ruoyi.platform.utils; | ||||
| import com.fasterxml.jackson.core.JsonProcessingException; | import com.fasterxml.jackson.core.JsonProcessingException; | ||||
| import com.fasterxml.jackson.databind.ObjectMapper; | import com.fasterxml.jackson.databind.ObjectMapper; | ||||
| import org.json.JSONObject; | |||||
| import java.io.IOException; | import java.io.IOException; | ||||
| import java.util.HashMap; | |||||
| import java.util.Iterator; | |||||
| import java.util.Map; | import java.util.Map; | ||||
| public class JsonUtils { | public class JsonUtils { | ||||
| @@ -28,4 +31,26 @@ public class JsonUtils { | |||||
| public static <T> T jsonToObject(String json, Class<T> clazz) throws IOException { | public static <T> T jsonToObject(String json, Class<T> clazz) throws IOException { | ||||
| return objectMapper.readValue(json, clazz); | return objectMapper.readValue(json, clazz); | ||||
| } | } | ||||
| // 将JSON字符串转换为扁平化的Map | |||||
| public static Map<String, Object> flattenJson(String prefix, Map<String, Object> map) { | |||||
| Map<String, Object> flatMap = new HashMap<>(); | |||||
| Iterator<Map.Entry<String, Object>> entries = map.entrySet().iterator(); | |||||
| while (entries.hasNext()) { | |||||
| Map.Entry<String, Object> entry = entries.next(); | |||||
| String key = entry.getKey(); | |||||
| Object value = entry.getValue(); | |||||
| if (value instanceof Map) { | |||||
| flatMap.putAll(flattenJson(prefix + key + ".", (Map<String, Object>) value)); | |||||
| } else { | |||||
| flatMap.put(prefix + key, value); | |||||
| } | |||||
| } | |||||
| return flatMap; | |||||
| } | |||||
| } | } | ||||
| @@ -0,0 +1,32 @@ | |||||
| package com.ruoyi.platform.vo; | |||||
| import com.fasterxml.jackson.databind.PropertyNamingStrategy; | |||||
| import com.fasterxml.jackson.databind.annotation.JsonNaming; | |||||
| import io.swagger.annotations.ApiModelProperty; | |||||
| import lombok.Data; | |||||
| import java.io.Serializable; | |||||
| import java.util.Date; | |||||
| import java.util.List; | |||||
| import java.util.Map; | |||||
| @JsonNaming(PropertyNamingStrategy.SnakeCaseStrategy.class) | |||||
| @Data | |||||
| public class InsMetricInfoVo implements Serializable { | |||||
| @ApiModelProperty(value = "开始时间") | |||||
| private Date startTime; | |||||
| @ApiModelProperty(value = "实例运行状态") | |||||
| private String status; | |||||
| @ApiModelProperty(value = "使用数据集") | |||||
| private List<String> dataset; | |||||
| @ApiModelProperty(value = "实例ID") | |||||
| private Integer experimentInsId; | |||||
| @ApiModelProperty(value = "训练指标") | |||||
| private Map metrics; | |||||
| @ApiModelProperty(value = "训练参数") | |||||
| private Map params; | |||||
| @ApiModelProperty(value = "训练记录ID") | |||||
| private String runId; | |||||
| private List<String> metricsNames; | |||||
| private List<String> paramsNames; | |||||
| } | |||||
| @@ -22,6 +22,22 @@ | |||||
| <result property="state" column="state" jdbcType="INTEGER"/> | <result property="state" column="state" jdbcType="INTEGER"/> | ||||
| </resultMap> | </resultMap> | ||||
| <select id="queryByIns" resultMap="ModelDependencyMap"> | |||||
| select | |||||
| id,current_model_id,exp_ins_id,parent_models,ref_item,train_task,train_dataset,train_params,train_image,test_dataset,project_dependency,version,create_by,create_time,update_by,update_time,state | |||||
| from model_dependency | |||||
| <where> | |||||
| exp_ins_id = #{expInsId} and state = 1 | |||||
| </where> | |||||
| </select> | |||||
| <select id="queryByInsAndTrainTaskId" resultMap="ModelDependencyMap"> | |||||
| select | |||||
| id,current_model_id,exp_ins_id,parent_models,ref_item,train_task,train_dataset,train_params,train_image,test_dataset,project_dependency,version,create_by,create_time,update_by,update_time,state | |||||
| from model_dependency | |||||
| <where> | |||||
| exp_ins_id = #{expInsId} and train_task like concat('%', #{taskId}, '%') limit 1 | |||||
| </where> | |||||
| </select> | |||||
| <select id="queryChildrenByVersionId" resultMap="ModelDependencyMap"> | <select id="queryChildrenByVersionId" resultMap="ModelDependencyMap"> | ||||
| select | select | ||||
| id,current_model_id,exp_ins_id,parent_models,ref_item,train_task,train_dataset,train_params,train_image,test_dataset,project_dependency,version,create_by,create_time,update_by,update_time,state | id,current_model_id,exp_ins_id,parent_models,ref_item,train_task,train_dataset,train_params,train_image,test_dataset,project_dependency,version,create_by,create_time,update_by,update_time,state | ||||