From 7514f04f0c1a33070dc7148cff0584148de616ac Mon Sep 17 00:00:00 2001 From: cp3hnu Date: Mon, 24 Jun 2024 09:01:43 +0800 Subject: [PATCH 1/6] =?UTF-8?q?feat:=20=E5=AE=9E=E9=AA=8C=E5=AF=B9?= =?UTF-8?q?=E6=AF=94?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- react-ui/config/routes.ts | 5 + .../components/ResourceIntro/index.tsx | 2 +- .../DevelopmentEnvironment/List/index.tsx | 10 +- .../pages/Experiment/Comparison/index.less | 21 ++ .../src/pages/Experiment/Comparison/index.tsx | 212 ++++++++++++++++++ react-ui/src/pages/Experiment/index.jsx | 33 ++- react-ui/src/pages/Experiment/index.less | 2 +- react-ui/src/requestConfig.ts | 2 +- react-ui/src/services/experiment/index.js | 14 ++ react-ui/src/services/session.ts | 11 +- 10 files changed, 298 insertions(+), 14 deletions(-) create mode 100644 react-ui/src/pages/Experiment/Comparison/index.less create mode 100644 react-ui/src/pages/Experiment/Comparison/index.tsx diff --git a/react-ui/config/routes.ts b/react-ui/config/routes.ts index 7edf6db2..74390abf 100644 --- a/react-ui/config/routes.ts +++ b/react-ui/config/routes.ts @@ -96,6 +96,11 @@ export default [ path: ':workflowId/:id', component: './Experiment/training/index', }, + { + name: '实验对比', + path: 'compare', + component: './Experiment/Comparison/index', + }, ], }, ], diff --git a/react-ui/src/pages/Dataset/components/ResourceIntro/index.tsx b/react-ui/src/pages/Dataset/components/ResourceIntro/index.tsx index bcbd123f..e8bed08c 100644 --- a/react-ui/src/pages/Dataset/components/ResourceIntro/index.tsx +++ b/react-ui/src/pages/Dataset/components/ResourceIntro/index.tsx @@ -58,7 +58,7 @@ const ResourceIntro = ({ resourceType }: ResourceIntroProps) => { }; }), ); - if (versionParam) { + if (versionParam && res.data.includes(versionParam)) { setVersion(versionParam); versionParam = null; } else { diff --git a/react-ui/src/pages/DevelopmentEnvironment/List/index.tsx b/react-ui/src/pages/DevelopmentEnvironment/List/index.tsx index 9c4badb6..76f0dcb9 100644 --- a/react-ui/src/pages/DevelopmentEnvironment/List/index.tsx +++ b/react-ui/src/pages/DevelopmentEnvironment/List/index.tsx @@ -38,6 +38,7 @@ export type EditorData = { computing_resource: string; update_by: string; create_time: string; + url: string; }; function EditorList() { @@ -140,7 +141,14 @@ function EditorList() { dataIndex: 'name', key: 'name', width: '30%', - render: CommonTableCell(), + render: (text, record) => + record.url ? ( + + {text} + + ) : ( + {text ?? '--'} + ), }, { title: '状态', diff --git a/react-ui/src/pages/Experiment/Comparison/index.less b/react-ui/src/pages/Experiment/Comparison/index.less new file mode 100644 index 00000000..288ce2ed --- /dev/null +++ b/react-ui/src/pages/Experiment/Comparison/index.less @@ -0,0 +1,21 @@ +.experiment-comparison { + height: 100%; + &__header { + display: flex; + align-items: center; + height: 50px; + margin-bottom: 10px; + padding: 0 30px; + background-image: url(@/assets/img/page-title-bg.png); + background-repeat: no-repeat; + background-position: top center; + background-size: 100% 100%; + } + + &__table { + height: calc(100% - 60px); + padding: 20px 30px 0; + background-color: white; + border-radius: 10px; + } +} diff --git a/react-ui/src/pages/Experiment/Comparison/index.tsx b/react-ui/src/pages/Experiment/Comparison/index.tsx new file mode 100644 index 00000000..5fca638d --- /dev/null +++ b/react-ui/src/pages/Experiment/Comparison/index.tsx @@ -0,0 +1,212 @@ +import CommonTableCell from '@/components/CommonTableCell'; +import { useCacheState } from '@/hooks/pageCacheState'; +import { getExpEvaluateInfosReq, getExpTrainInfosReq } from '@/services/experiment'; +import { to } from '@/utils/promise'; +import { useSearchParams } from '@umijs/max'; +import { Button, Table, TablePaginationConfig, TableProps } from 'antd'; +import classNames from 'classnames'; +import { useEffect, useState } from 'react'; +import styles from './index.less'; + +export enum ComparisonType { + Train = 'train', + Evaluate = 'evaluate', +} + +function ExperimentComparison() { + const [searchParams] = useSearchParams(); + const comparisonType = searchParams.get('type'); + const experimentId = searchParams.get('experimentId'); + const [tableData, setTableData] = useState([]); + const [cacheState, setCacheState] = useCacheState(); + const [total, setTotal] = useState(0); + const [selectedRowKeys, setSelectedRowKeys] = useState([]); + const [pagination, setPagination] = useState( + cacheState?.pagination ?? { + current: 1, + pageSize: 10, + }, + ); + + useEffect(() => { + getComparisonData(); + }, [experimentId]); + + // 获取对比数据列表 + const getComparisonData = async () => { + const request = + comparisonType === ComparisonType.Train ? getExpTrainInfosReq : getExpEvaluateInfosReq; + const [res] = await to(request(experimentId)); + if (res && res.data) { + const { content = [], totalElements = 0 } = res.data; + setTableData(content); + setTotal(totalElements); + } + }; + + // 分页切换 + const handleTableChange: TableProps['onChange'] = (pagination, filters, sorter, { action }) => { + if (action === 'paginate') { + setPagination(pagination); + } + // console.log(pagination, filters, sorter, action); + }; + + const rowSelection: TableProps['rowSelection'] = { + type: 'checkbox', + selectedRowKeys, + onChange: (selectedRowKeys: React.Key[], selectedRows: any[]) => { + console.log(`selectedRowKeys: ${selectedRowKeys}`, 'selectedRows: ', selectedRows); + setSelectedRowKeys(selectedRowKeys); + }, + }; + + const columns: TableProps['columns'] = [ + { + title: '基本信息', + children: [ + { + title: '实例ID', + dataIndex: 'name', + key: 'name', + width: '30%', + render: CommonTableCell(), + }, + { + title: '运行时间', + dataIndex: 'name', + key: 'name', + width: '30%', + render: CommonTableCell(), + }, + { + title: '运行状态', + dataIndex: 'name', + key: 'name', + width: '30%', + render: CommonTableCell(), + }, + { + title: '训练数据集', + dataIndex: 'name', + key: 'name', + width: '30%', + render: CommonTableCell(), + }, + { + title: '增量训练', + dataIndex: 'name', + key: 'name', + width: '30%', + render: CommonTableCell(), + }, + ], + }, + { + title: '训练参数', + children: [ + { + title: 'batchsize', + dataIndex: 'name', + key: 'name', + width: '30%', + render: CommonTableCell(), + }, + { + title: 'config', + dataIndex: 'name', + key: 'name', + width: '30%', + render: CommonTableCell(), + }, + { + title: 'epoch', + dataIndex: 'name', + key: 'name', + width: '30%', + render: CommonTableCell(), + }, + { + title: 'lr', + dataIndex: 'name', + key: 'name', + width: '30%', + render: CommonTableCell(), + }, + { + title: 'warmup_iters', + dataIndex: 'name', + key: 'name', + width: '30%', + render: CommonTableCell(), + }, + ], + }, + { + title: '训练指标', + children: [ + { + title: 'metrc_name', + dataIndex: 'name', + key: 'name', + width: '30%', + render: CommonTableCell(), + }, + { + title: 'test_1', + dataIndex: 'name', + key: 'name', + width: '30%', + render: CommonTableCell(), + }, + { + title: 'test_2', + dataIndex: 'name', + key: 'name', + width: '30%', + render: CommonTableCell(), + }, + { + title: 'test_3', + dataIndex: 'name', + key: 'name', + width: '30%', + render: CommonTableCell(), + }, + { + title: 'test_4', + dataIndex: 'name', + key: 'name', + width: '30%', + render: CommonTableCell(), + }, + ], + }, + ]; + + return ( +
+
+ +
+
+ + + + ); +} + +export default ExperimentComparison; diff --git a/react-ui/src/pages/Experiment/index.jsx b/react-ui/src/pages/Experiment/index.jsx index 69648b49..7c7e3eaa 100644 --- a/react-ui/src/pages/Experiment/index.jsx +++ b/react-ui/src/pages/Experiment/index.jsx @@ -18,11 +18,12 @@ import themes from '@/styles/theme.less'; import { elapsedTime, formatDate } from '@/utils/date'; import { to } from '@/utils/promise'; import { modalConfirm } from '@/utils/ui'; -import { App, Button, ConfigProvider, Space, Table, Tooltip } from 'antd'; +import { App, Button, ConfigProvider, Dropdown, Space, Table, Tooltip } from 'antd'; import classNames from 'classnames'; import { useEffect, useRef, useState } from 'react'; import { useNavigate } from 'react-router-dom'; import AddExperimentModal from './components/AddExperimentModal'; +import { ComparisonType } from './components/ComparisonModal/config'; import TensorBoardStatus, { TensorBoardStatusEnum } from './components/TensorBoardStatus'; import Styles from './index.less'; import { experimentStatusInfo } from './status'; @@ -270,6 +271,26 @@ function Experiment() { window.open(experimentIn.tensorboardUrl, '_blank'); } }; + + // 实验对比菜单 + const getComparisonMenu = (experimentId) => { + return { + items: [ + { + label: 训练对比, + key: ComparisonType.Train, + }, + { + label: 评估对比, + key: ComparisonType.Evaluate, + }, + ], + onClick: ({ key }) => { + navgite(`/pipeline/experiment/compare?type=${key}&id=${experimentId}`); + }, + }; + }; + const columns = [ { title: '实验名称', @@ -320,7 +341,7 @@ function Experiment() { { title: '操作', key: 'action', - width: 300, + width: 350, render: (_, record) => ( + + e.preventDefault()}> + + + 实验对比 + + + { const { status, data } = response || {}; - console.log(message, data); + // console.log(message, data); if (status >= 200 && status < 300) { if (data && (data instanceof Blob || data.code === 200)) { return response; diff --git a/react-ui/src/services/experiment/index.js b/react-ui/src/services/experiment/index.js index 9d388b71..89028b24 100644 --- a/react-ui/src/services/experiment/index.js +++ b/react-ui/src/services/experiment/index.js @@ -116,3 +116,17 @@ export function getTensorBoardStatusReq(data) { data, }); } + +// 获取当前实验的模型推理指标信息 +export function getExpEvaluateInfosReq(experimentId) { + return request(`/api/mmp/aim/getExpEvaluateInfos/${experimentId}`, { + method: 'GET', + }); +} + +// 获取当前实验的模型训练指标信息 +export function getExpTrainInfosReq(experimentId) { + return request(`/api/mmp//aim/getExpTrainInfos/${experimentId}`, { + method: 'GET', + }); +} diff --git a/react-ui/src/services/session.ts b/react-ui/src/services/session.ts index 8430b21b..e1f00d75 100644 --- a/react-ui/src/services/session.ts +++ b/react-ui/src/services/session.ts @@ -60,8 +60,8 @@ function patchRouteItems(route: any, menu: any, parentPath: string) { element: React.createElement(lazy(() => import('@/pages/' + path))), path: parentPath + menuItem.path, }; - console.log(newRoute); - + // console.log(newRoute); + route.children.push(newRoute); route.routes.push(newRoute); } @@ -74,10 +74,7 @@ export function patchRouteWithRemoteMenus(routes: any) { } let proLayout = null; for (const routeItem of routes) { - if (routeItem.id === 'ant-design-pro-layout') { - - proLayout = routeItem; break; } @@ -101,7 +98,6 @@ export async function refreshToken() { } export function convertCompatRouters(childrens: API.RoutersMenuItem[]): any[] { - return childrens.map((item: API.RoutersMenuItem) => { return { path: item.path, @@ -149,12 +145,11 @@ export function getMatchMenuItem( const subpath = path.substr(item.path.length + 1); const subItem: MenuDataItem[] = getMatchMenuItem(subpath, item.routes); items = items.concat(subItem); - } else { const paths = path.split('/'); if (paths.length >= 2 && paths[0] === item.path && paths[1] === 'index') { console.log(item); - + items.push(item); } } From 29eb05cf5e7b70471969c42333ecacb395576068 Mon Sep 17 00:00:00 2001 From: cp3hnu Date: Mon, 24 Jun 2024 09:16:45 +0800 Subject: [PATCH 2/6] =?UTF-8?q?feat:=20=E5=AE=9E=E9=AA=8C=E5=AF=B9?= =?UTF-8?q?=E6=AF=94?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- react-ui/src/pages/Experiment/Comparison/index.tsx | 4 ++-- react-ui/src/pages/Experiment/index.jsx | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/react-ui/src/pages/Experiment/Comparison/index.tsx b/react-ui/src/pages/Experiment/Comparison/index.tsx index 5fca638d..c53a0a94 100644 --- a/react-ui/src/pages/Experiment/Comparison/index.tsx +++ b/react-ui/src/pages/Experiment/Comparison/index.tsx @@ -9,8 +9,8 @@ import { useEffect, useState } from 'react'; import styles from './index.less'; export enum ComparisonType { - Train = 'train', - Evaluate = 'evaluate', + Train = 'train', // 训练 + Evaluate = 'evaluate', // 评估 } function ExperimentComparison() { diff --git a/react-ui/src/pages/Experiment/index.jsx b/react-ui/src/pages/Experiment/index.jsx index 7c7e3eaa..8793da19 100644 --- a/react-ui/src/pages/Experiment/index.jsx +++ b/react-ui/src/pages/Experiment/index.jsx @@ -22,8 +22,8 @@ import { App, Button, ConfigProvider, Dropdown, Space, Table, Tooltip } from 'an import classNames from 'classnames'; import { useEffect, useRef, useState } from 'react'; import { useNavigate } from 'react-router-dom'; +import { ComparisonType } from './Comparison'; import AddExperimentModal from './components/AddExperimentModal'; -import { ComparisonType } from './components/ComparisonModal/config'; import TensorBoardStatus, { TensorBoardStatusEnum } from './components/TensorBoardStatus'; import Styles from './index.less'; import { experimentStatusInfo } from './status'; From b06a502ebe51434b8d4494e2b836e49271fda1ce Mon Sep 17 00:00:00 2001 From: cp3hnu Date: Mon, 24 Jun 2024 16:56:36 +0800 Subject: [PATCH 3/6] =?UTF-8?q?feat:=20=E5=BC=80=E5=8F=91=E7=8E=AF?= =?UTF-8?q?=E5=A2=83=E7=BC=96=E8=BE=91=E5=99=A8=E4=BD=BF=E7=94=A8iframe?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- react-ui/config/routes.ts | 5 ++++ .../DevelopmentEnvironment/Editor/index.tsx | 25 +++++++++++++++++++ .../DevelopmentEnvironment/List/index.tsx | 15 ++++++++--- .../pages/DevelopmentEnvironment/index.tsx | 22 ---------------- .../src/pages/Pipeline/editPipeline/index.jsx | 10 ++++---- .../services/developmentEnvironment/index.ts | 3 +-- react-ui/src/utils/sessionStorage.ts | 2 ++ react-ui/tsconfig.json | 1 + 8 files changed, 51 insertions(+), 32 deletions(-) create mode 100644 react-ui/src/pages/DevelopmentEnvironment/Editor/index.tsx delete mode 100644 react-ui/src/pages/DevelopmentEnvironment/index.tsx diff --git a/react-ui/config/routes.ts b/react-ui/config/routes.ts index 74390abf..81076627 100644 --- a/react-ui/config/routes.ts +++ b/react-ui/config/routes.ts @@ -119,6 +119,11 @@ export default [ path: 'create', component: './DevelopmentEnvironment/Create', }, + { + name: '编辑器', + path: 'editor', + component: './DevelopmentEnvironment/Editor', + }, ], }, { diff --git a/react-ui/src/pages/DevelopmentEnvironment/Editor/index.tsx b/react-ui/src/pages/DevelopmentEnvironment/Editor/index.tsx new file mode 100644 index 00000000..b113d76b --- /dev/null +++ b/react-ui/src/pages/DevelopmentEnvironment/Editor/index.tsx @@ -0,0 +1,25 @@ +import { editorUrl, getSessionStorageItem, removeSessionStorageItem } from '@/utils/sessionStorage'; +import { useEffect, useState } from 'react'; + +function DevEditor() { + const [iframeUrl, setIframeUrl] = useState(''); + useEffect(() => { + const url = getSessionStorageItem(editorUrl) || ''; + setIframeUrl(url); + return () => { + removeSessionStorageItem(editorUrl); + }; + }, []); + + // const requestJupyterUrl = async () => { + // const [res, error] = await to(getJupyterUrl()); + // if (res) { + // setIframeUrl(res.data as string); + // } else { + // console.log(error); + // } + // }; + + return ; +} +export default DevEditor; diff --git a/react-ui/src/pages/DevelopmentEnvironment/List/index.tsx b/react-ui/src/pages/DevelopmentEnvironment/List/index.tsx index 76f0dcb9..c7b22c6a 100644 --- a/react-ui/src/pages/DevelopmentEnvironment/List/index.tsx +++ b/react-ui/src/pages/DevelopmentEnvironment/List/index.tsx @@ -16,6 +16,7 @@ import { } from '@/services/developmentEnvironment'; import themes from '@/styles/theme.less'; import { to } from '@/utils/promise'; +import { editorUrl, setSessionStorageItem } from '@/utils/sessionStorage'; import { modalConfirm } from '@/utils/ui'; import { useNavigate } from '@umijs/max'; import { @@ -128,6 +129,16 @@ function EditorList() { }); }; + // 跳转编辑器页面 + const gotoEditorPage = (e: React.MouseEvent, record: EditorData) => { + e.stopPropagation(); + setSessionStorageItem(editorUrl, record.url); + navigate(`/developmentEnvironment/editor`); + setCacheState({ + pagination, + }); + }; + // 分页切换 const handleTableChange: TableProps['onChange'] = (pagination, filters, sorter, { action }) => { if (action === 'paginate') { @@ -143,9 +154,7 @@ function EditorList() { width: '30%', render: (text, record) => record.url ? ( - - {text} - + gotoEditorPage(e, record)}>{text} ) : ( {text ?? '--'} ), diff --git a/react-ui/src/pages/DevelopmentEnvironment/index.tsx b/react-ui/src/pages/DevelopmentEnvironment/index.tsx deleted file mode 100644 index c2e09d35..00000000 --- a/react-ui/src/pages/DevelopmentEnvironment/index.tsx +++ /dev/null @@ -1,22 +0,0 @@ -import { getJupyterUrl } from '@/services/developmentEnvironment'; -import { to } from '@/utils/promise'; -import { useEffect, useState } from 'react'; - -const DevelopmentEnvironment = () => { - const [iframeUrl, setIframeUrl] = useState(''); - useEffect(() => { - requestJupyterUrl(); - }, []); - - const requestJupyterUrl = async () => { - const [res, error] = await to(getJupyterUrl()); - if (res) { - setIframeUrl(res.data as string); - } else { - console.log(error); - } - }; - - return ; -}; -export default DevelopmentEnvironment; diff --git a/react-ui/src/pages/Pipeline/editPipeline/index.jsx b/react-ui/src/pages/Pipeline/editPipeline/index.jsx index c51ce935..2f434cb0 100644 --- a/react-ui/src/pages/Pipeline/editPipeline/index.jsx +++ b/react-ui/src/pages/Pipeline/editPipeline/index.jsx @@ -93,11 +93,11 @@ const EditPipeline = () => { return; } - // const [propsRes, propsError] = await to(propsRef.current.getFieldsValue()); - // if (propsError) { - // message.error('基本信息必填项需配置'); - // return; - // } + const [propsRes, propsError] = await to(propsRef.current.getFieldsValue()); + if (propsError) { + message.error('节点必填项必须配置'); + return; + } propsRef.current.propClose(); setTimeout(() => { const data = graph.save(); diff --git a/react-ui/src/services/developmentEnvironment/index.ts b/react-ui/src/services/developmentEnvironment/index.ts index dbe2717a..330f76d3 100644 --- a/react-ui/src/services/developmentEnvironment/index.ts +++ b/react-ui/src/services/developmentEnvironment/index.ts @@ -1,9 +1,8 @@ import { request } from '@umijs/max'; // 查询开发环境url -export function getJupyterUrl(params: any) { +export function getJupyterUrl() { return request(`/api/mmp/jupyter/getURL`, { method: 'GET', - params, }); } diff --git a/react-ui/src/utils/sessionStorage.ts b/react-ui/src/utils/sessionStorage.ts index 6018dbe7..a8632070 100644 --- a/react-ui/src/utils/sessionStorage.ts +++ b/react-ui/src/utils/sessionStorage.ts @@ -2,6 +2,8 @@ export const mirrorNameKey = 'mirror-name'; // 模型部署 export const modelDeploymentInfoKey = 'model-deployment-info'; +// 编辑器 url +export const editorUrl = 'editor-url'; export const getSessionStorageItem = (key: string, isObject: boolean = false) => { const jsonStr = sessionStorage.getItem(key); diff --git a/react-ui/tsconfig.json b/react-ui/tsconfig.json index 9922437a..0afa8788 100644 --- a/react-ui/tsconfig.json +++ b/react-ui/tsconfig.json @@ -20,6 +20,7 @@ "noUnusedParameters": true, // 报告未使用的参数错误 "incremental": true, // 通过读写磁盘上的文件来启用增量编译 "noFallthroughCasesInSwitch": true, // 报告switch语句中的fallthrough案例错误 + "strictNullChecks": true, // 启用严格的null检查 "baseUrl": "./", "paths": { "@/*": ["src/*"], From 01e1f7e951d4afa9531a4c99da1aea9a01602d2c Mon Sep 17 00:00:00 2001 From: fanshuai <1141904845@qq.com> Date: Mon, 24 Jun 2024 19:24:55 +0800 Subject: [PATCH 4/6] =?UTF-8?q?=E6=96=B0=E5=A2=9E=E6=8C=87=E6=A0=87?= =?UTF-8?q?=E5=AF=B9=E6=AF=94?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ruoyi-modules/management-platform/pom.xml | 11 ++ .../controller/aim/AimController.java | 36 +++--- .../ruoyi/platform/service/AimService.java | 11 +- .../platform/service/impl/AimServiceImpl.java | 116 +++++++++++++++++- .../service/impl/ExperimentServiceImpl.java | 9 +- .../platform/utils/AIM64EncoderUtil.java | 76 ++++++++++++ .../com/ruoyi/platform/utils/HttpUtils.java | 35 ++++++ .../com/ruoyi/platform/utils/JsonUtils.java | 25 ++++ .../ruoyi/platform/vo/InsMetricInfoVo.java | 32 +++++ 9 files changed, 330 insertions(+), 21 deletions(-) create mode 100644 ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/AIM64EncoderUtil.java create mode 100644 ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/InsMetricInfoVo.java diff --git a/ruoyi-modules/management-platform/pom.xml b/ruoyi-modules/management-platform/pom.xml index 4cb1ae7d..37234ad0 100644 --- a/ruoyi-modules/management-platform/pom.xml +++ b/ruoyi-modules/management-platform/pom.xml @@ -205,6 +205,17 @@ org.springframework.boot spring-boot-starter-websocket + + org.json + json + 20210307 + + + org.apache.dubbo + dubbo + 3.0.8 + compile + diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/aim/AimController.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/aim/AimController.java index f1750133..f6b0b863 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/aim/AimController.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/controller/aim/AimController.java @@ -4,16 +4,16 @@ import com.ruoyi.common.core.web.controller.BaseController; import com.ruoyi.common.core.web.domain.GenericsAjaxResult; import com.ruoyi.platform.service.AimService; import com.ruoyi.platform.vo.FrameLogPathVo; +import com.ruoyi.platform.vo.InsMetricInfoVo; import com.ruoyi.platform.vo.PodStatusVo; import io.swagger.annotations.Api; import io.swagger.annotations.ApiOperation; import io.swagger.v3.oas.annotations.responses.ApiResponse; -import org.springframework.web.bind.annotation.PostMapping; -import org.springframework.web.bind.annotation.RequestBody; -import org.springframework.web.bind.annotation.RequestMapping; -import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.bind.annotation.*; import javax.annotation.Resource; +import java.util.List; + @RestController @RequestMapping("aim") @Api("Aim管理") @@ -22,17 +22,25 @@ public class AimController extends BaseController { @Resource private AimService aimService; - /** - * 启动tensorBoard接口 - * - * @param frameLogPathVo 存储路径 - * @return url - */ - @PostMapping("/run") - @ApiOperation("启动aim`") + + @GetMapping("/getExpTrainInfos/{experiment_id}") + @ApiOperation("获取当前实验的模型训练指标信息") @ApiResponse - public GenericsAjaxResult runAim(@RequestBody FrameLogPathVo frameLogPathVo) throws Exception { - return genericsSuccess(aimService.runAim(frameLogPathVo)); + public GenericsAjaxResult> getExpTrainInfos(@PathVariable("experiment_id") Integer experimentId) throws Exception { + return genericsSuccess(aimService.getExpTrainInfos(experimentId)); } + @GetMapping("/getExpEvaluateInfos/{experiment_id}") + @ApiOperation("获取当前实验的模型推理指标信息") + @ApiResponse + public GenericsAjaxResult> getExpEvaluateInfos(@PathVariable("experiment_id") Integer experimentId) throws Exception { + return genericsSuccess(aimService.getExpEvaluateInfos(experimentId)); + } + + @PostMapping("/getExpMetrics") + @ApiOperation("获取当前实验的指标对比地址") + @ApiResponse + public GenericsAjaxResult getExpMetrics(@RequestBody List runIds) throws Exception { + return genericsSuccess(aimService.getExpMetrics(runIds)); + } } \ No newline at end of file diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AimService.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AimService.java index 60f74b90..c83a42af 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AimService.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/AimService.java @@ -1,7 +1,14 @@ package com.ruoyi.platform.service; -import com.ruoyi.platform.vo.FrameLogPathVo; +import com.ruoyi.platform.vo.InsMetricInfoVo; + +import java.util.List; public interface AimService { - String runAim(FrameLogPathVo frameLogPathVo); + + List getExpTrainInfos(Integer experimentId) throws Exception; + + List getExpEvaluateInfos(Integer experimentId) throws Exception; + + String getExpMetrics(List runIds) throws Exception; } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java index f66dc178..c26e101e 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java @@ -1,13 +1,123 @@ package com.ruoyi.platform.service.impl; +import com.alibaba.druid.util.StringUtils; +import com.ruoyi.platform.domain.ExperimentIns; import com.ruoyi.platform.service.AimService; -import com.ruoyi.platform.vo.FrameLogPathVo; +import com.ruoyi.platform.service.ExperimentInsService; +import com.ruoyi.platform.service.ExperimentService; +import com.ruoyi.platform.utils.AIM64EncoderUtil; +import com.ruoyi.platform.utils.HttpUtils; +import com.ruoyi.platform.utils.JacksonUtil; +import com.ruoyi.platform.utils.JsonUtils; +import com.ruoyi.platform.vo.InsMetricInfoVo; +import org.apache.dubbo.container.Main; +import org.json.JSONObject; +import org.json.JSONTokener; import org.springframework.stereotype.Service; +import javax.annotation.Resource; +import java.net.URLEncoder; +import java.util.*; +import java.util.stream.Collectors; + @Service public class AimServiceImpl implements AimService { + @Resource + private ExperimentInsService experimentInsService; + + @Override + public List getExpTrainInfos(Integer experimentId) throws Exception { + String experimentName = "experiment-train-0"+experimentId; + return getAimRunInfos("",experimentId); + } + @Override - public String runAim(FrameLogPathVo frameLogPathVo) { - return null; + public List getExpEvaluateInfos(Integer experimentId) throws Exception { + String experimentName = "experiment-evaluate-0"+experimentId; + return getAimRunInfos("",experimentId); + } + + @Override + public String getExpMetrics(List runIds) throws Exception { + String decode = AIM64EncoderUtil.decode(runIds); + return "http://172.20.32.21:7123/api/runs/search/run?query="+decode; + } + + private List getAimRunInfos(String experimentName,Integer experimentId) throws Exception { + String encodedUrlString = URLEncoder.encode("run.experiment==\"experiment-0000\"", "UTF-8"); + String url = "http://172.20.32.181:30123/api/runs/search/run?query="+encodedUrlString; + String s = HttpUtils.sendGetRequest(url); + System.out.println(s); + List> response = JacksonUtil.parseJSONStr2MapList(s); + // TODO: parse aim response to InsMetricInfoVo list + if (response == null || response.size() == 0){ + return new ArrayList<>(); + } + //查询实例数据 + List byExperimentId = experimentInsService.getByExperimentId(experimentId); + +// if (byExperimentId == null || byExperimentId.size() == 0){ +// return new ArrayList<>(); +// } + List aimRunInfoList = new ArrayList<>(); + for (Map run : response) { + InsMetricInfoVo aimRunInfo = new InsMetricInfoVo(); + String runHash = (String) run.get("run_hash"); + aimRunInfo.setRunId(runHash); + + Map params= (Map) run.get("params"); + Map paramMap = JsonUtils.flattenJson("", params); + aimRunInfo.setParams(paramMap); + + Map tracesMap= (Map) run.get("params"); + List> metricList = (List>) tracesMap.get("metric"); + //过滤name为__system__开头的对象 + aimRunInfo.setMetrics(new HashMap<>()); + if (metricList != null && metricList.size() > 0){ + List> metricRelList = metricList.stream() + .filter(map -> !StringUtils.equals("__system__", (String) map.get("name"))) + .collect(Collectors.toList()); + if (metricRelList!= null && metricRelList.size() > 0){ + Map relMetricMap = new HashMap<>(); + for (Map metricMap : metricRelList) { + relMetricMap.put((String)metricMap.get("name"), metricMap.get("last_value")); + } + aimRunInfo.setMetrics(relMetricMap); + } + } + + + + //找到ins + + for (ExperimentIns ins : byExperimentId) { + String metricRecord = ins.getMetricRecord(); + if (metricRecord.contains(runHash)){ + aimRunInfo.setExperimentInsId(ins.getId()); + aimRunInfo.setStatus(ins.getStatus()); + aimRunInfo.setStartTime(ins.getStartTime()); + } + } + aimRunInfoList.add(aimRunInfo); + } + //判断哪个最长 + + Optional maxMetricsVo = aimRunInfoList.stream() + .max((vo1, vo2) -> Integer.compare(vo1.getMetrics().size(), vo2.getMetrics().size())); + + // 如果找到了,设置 metricsFlag 为 true + if (maxMetricsVo.isPresent()) { + maxMetricsVo.get().setMetricsFlag(true); + } + Optional maxParamsVo = aimRunInfoList.stream() + .max((vo1, vo2) -> Integer.compare(vo1.getParams().size(), vo2.getParams().size())); + + // 如果找到了,设置 metricsFlag 为 true + if (maxParamsVo.isPresent()) { + maxParamsVo.get().setMetricsFlag(true); + } + + return aimRunInfoList; } + } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentServiceImpl.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentServiceImpl.java index 893d23e7..d1ae81c4 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentServiceImpl.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ExperimentServiceImpl.java @@ -254,9 +254,13 @@ public class ExperimentServiceImpl implements ExperimentService { //获取训练参数 Map metricRecord = (Map) runResMap.get("metric_record"); + Map metadata = (Map) data.get("metadata"); // 插入记录到实验实例表 ExperimentIns experimentIns = new ExperimentIns(); + if (metricRecord != null){ + experimentIns.setMetricRecord(JacksonUtil.toJSONString(metricRecord)); + } experimentIns.setExperimentId(experiment.getId()); experimentIns.setArgoInsNs((String) metadata.get("namespace")); experimentIns.setArgoInsName((String) metadata.get("name")); @@ -275,8 +279,9 @@ public class ExperimentServiceImpl implements ExperimentService { Map converMap2 = JsonUtils.jsonToMap(JacksonUtil.replaceInAarry(convertRes, params)); Map dependendcy = (Map)converMap2.get("model_dependency"); Map trainInfo = (Map)converMap2.get("component_info"); - insertModelDependency(dependendcy,trainInfo,insert.getId(),experiment.getName()); - + if (dependendcy != null && trainInfo != null){ + insertModelDependency(dependendcy,trainInfo,insert.getId(),experiment.getName()); + } }catch (Exception e){ throw new RuntimeException(e); } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/AIM64EncoderUtil.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/AIM64EncoderUtil.java new file mode 100644 index 00000000..0cffc705 --- /dev/null +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/AIM64EncoderUtil.java @@ -0,0 +1,76 @@ +package com.ruoyi.platform.utils; + +import com.alibaba.fastjson.JSON; + +import java.util.Base64; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class AIM64EncoderUtil { + + private static final String AIM64_ENCODING_PREFIX = "O-"; + + private static final Map BS64_REPLACE_CHARACTERS_ENCODING = new HashMap<>(); + static { + BS64_REPLACE_CHARACTERS_ENCODING.put("=", ""); + BS64_REPLACE_CHARACTERS_ENCODING.put("+", "-"); + BS64_REPLACE_CHARACTERS_ENCODING.put("/", "_"); + } + + public static String aim64encode(Map value) { + String jsonEncoded = JSON.toJSONString(value); + String base64Encoded = Base64.getEncoder().encodeToString(jsonEncoded.getBytes()); + String aim64Encoded = base64Encoded; + for (Map.Entry entry : BS64_REPLACE_CHARACTERS_ENCODING.entrySet()) { + aim64Encoded = aim64Encoded.replace(entry.getKey(), entry.getValue()); + } + return AIM64_ENCODING_PREFIX + aim64Encoded; + } + + public static String encode(Map value, boolean oneWayHashing) { + if (oneWayHashing) { + return md5(JSON.toJSONString(value)); + } + return aim64encode(value); + } + + private static String md5(String input) { + try { + java.security.MessageDigest md = java.security.MessageDigest.getInstance("MD5"); + byte[] array = md.digest(input.getBytes()); + StringBuilder sb = new StringBuilder(); + for (byte b : array) { + sb.append(Integer.toHexString((b & 0xFF) | 0x100).substring(1, 3)); + } + return sb.toString(); + } catch (java.security.NoSuchAlgorithmException e) { + e.printStackTrace(); + } + return null; + } + + public static String decode(List runIds) { + // 确保 runIds 列表的大小为 3 + if (runIds == null || runIds.size() == 0) { + throw new IllegalArgumentException("runIds 不能为空"); + } + // 构建查询字符串 + StringBuilder queryBuilder = new StringBuilder("run.hash in ["); + for (int i = 0; i < runIds.size(); i++) { + if (i > 0) { + queryBuilder.append(","); + } + queryBuilder.append("\"").append(runIds.get(i)).append("\""); + } + queryBuilder.append("]"); + String query = queryBuilder.toString(); + Map map = new HashMap<>(); + map.put("query", query); + map.put("advancedMode", true); + map.put("advancedQuery", query); + + String searchQuery = encode(map, false); + return searchQuery; + } +} diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/HttpUtils.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/HttpUtils.java index b382eda9..910d9981 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/HttpUtils.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/HttpUtils.java @@ -25,6 +25,7 @@ import java.security.cert.X509Certificate; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.zip.GZIPInputStream; /** * HTTP请求工具类 @@ -447,4 +448,38 @@ public class HttpUtils { return true; } } + + public static String sendGetRequestgzip(String url) throws Exception { + String resultStr = null; + HttpGet httpGet = new HttpGet(url); + httpGet.setHeader("Content-Type", "application/json"); + httpGet.setHeader("Accept-Encoding", "gzip, deflate"); + try { + HttpResponse response = httpClient.execute(httpGet); + int responseCode = response.getStatusLine().getStatusCode(); + if (responseCode != 200) { + throw new IOException("HTTP request failed with response code: " + responseCode); + } + + // 获取响应内容 + InputStream responseStream = response.getEntity().getContent(); + // 检查响应是否被压缩 + if ("gzip".equalsIgnoreCase(response.getEntity().getContentEncoding().getValue())) { + responseStream = new GZIPInputStream(responseStream); + } + + // 读取解压缩后的内容 + byte[] buffer = new byte[1024]; + int len; + StringBuilder decompressedString = new StringBuilder(); + while ((len = responseStream.read(buffer)) > 0) { + decompressedString.append(new String(buffer, 0, len)); + } + + resultStr = decompressedString.toString(); + } catch (IOException e) { + e.printStackTrace(); + } + return resultStr; + } } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/JsonUtils.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/JsonUtils.java index e1b41780..186173eb 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/JsonUtils.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/utils/JsonUtils.java @@ -1,8 +1,11 @@ package com.ruoyi.platform.utils; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; +import org.json.JSONObject; import java.io.IOException; +import java.util.HashMap; +import java.util.Iterator; import java.util.Map; public class JsonUtils { @@ -28,4 +31,26 @@ public class JsonUtils { public static T jsonToObject(String json, Class clazz) throws IOException { return objectMapper.readValue(json, clazz); } + + + + // 将JSON字符串转换为扁平化的Map + public static Map flattenJson(String prefix, Map map) { + Map flatMap = new HashMap<>(); + Iterator> entries = map.entrySet().iterator(); + + while (entries.hasNext()) { + Map.Entry entry = entries.next(); + String key = entry.getKey(); + Object value = entry.getValue(); + + if (value instanceof Map) { + flatMap.putAll(flattenJson(prefix + key + ".", (Map) value)); + } else { + flatMap.put(prefix + key, value); + } + } + + return flatMap; + } } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/InsMetricInfoVo.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/InsMetricInfoVo.java new file mode 100644 index 00000000..cd3943ed --- /dev/null +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/InsMetricInfoVo.java @@ -0,0 +1,32 @@ +package com.ruoyi.platform.vo; + +import com.fasterxml.jackson.databind.PropertyNamingStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; + +import java.io.Serializable; +import java.util.Date; +import java.util.List; +import java.util.Map; + +@JsonNaming(PropertyNamingStrategy.SnakeCaseStrategy.class) +@Data +public class InsMetricInfoVo implements Serializable { + @ApiModelProperty(value = "开始时间") + private Date startTime; + @ApiModelProperty(value = "实例运行状态") + private String status; + @ApiModelProperty(value = "使用数据集") + private List> dataset; + @ApiModelProperty(value = "实例ID") + private Integer experimentInsId; + @ApiModelProperty(value = "训练指标") + private Map metrics; + @ApiModelProperty(value = "训练参数") + private Map params; + @ApiModelProperty(value = "训练记录ID") + private String runId; + private Boolean metricsFlag = false; + private Boolean paramsFlag = false; +} From 90541b57b931e5cbbb4830495c35e018f9ae555c Mon Sep 17 00:00:00 2001 From: cp3hnu Date: Tue, 25 Jun 2024 09:34:36 +0800 Subject: [PATCH 5/6] =?UTF-8?q?chore:=20=E5=BC=80=E5=8F=91=E7=8E=AF?= =?UTF-8?q?=E5=A2=83=E5=9B=9E=E9=80=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- react-ui/config/routes.ts | 20 ++++++------ .../DevelopmentEnvironment/Editor/index.tsx | 31 ++++++++++--------- 2 files changed, 27 insertions(+), 24 deletions(-) diff --git a/react-ui/config/routes.ts b/react-ui/config/routes.ts index 81076627..3a58d07b 100644 --- a/react-ui/config/routes.ts +++ b/react-ui/config/routes.ts @@ -112,18 +112,18 @@ export default [ { name: '开发环境', path: '', - component: './DevelopmentEnvironment/List', - }, - { - name: '创建编辑器', - path: 'create', - component: './DevelopmentEnvironment/Create', - }, - { - name: '编辑器', - path: 'editor', component: './DevelopmentEnvironment/Editor', }, + // { + // name: '创建编辑器', + // path: 'create', + // component: './DevelopmentEnvironment/Create', + // }, + // { + // name: '编辑器', + // path: 'editor', + // component: './DevelopmentEnvironment/Editor', + // }, ], }, { diff --git a/react-ui/src/pages/DevelopmentEnvironment/Editor/index.tsx b/react-ui/src/pages/DevelopmentEnvironment/Editor/index.tsx index b113d76b..0b2c63de 100644 --- a/react-ui/src/pages/DevelopmentEnvironment/Editor/index.tsx +++ b/react-ui/src/pages/DevelopmentEnvironment/Editor/index.tsx @@ -1,24 +1,27 @@ -import { editorUrl, getSessionStorageItem, removeSessionStorageItem } from '@/utils/sessionStorage'; +// import { editorUrl, getSessionStorageItem, removeSessionStorageItem } from '@/utils/sessionStorage'; +import { getJupyterUrl } from '@/services/developmentEnvironment'; +import { to } from '@/utils/promise'; import { useEffect, useState } from 'react'; function DevEditor() { const [iframeUrl, setIframeUrl] = useState(''); useEffect(() => { - const url = getSessionStorageItem(editorUrl) || ''; - setIframeUrl(url); - return () => { - removeSessionStorageItem(editorUrl); - }; + // const url = getSessionStorageItem(editorUrl) || ''; + // setIframeUrl(url); + // return () => { + // removeSessionStorageItem(editorUrl); + // }; + requestJupyterUrl(); }, []); - // const requestJupyterUrl = async () => { - // const [res, error] = await to(getJupyterUrl()); - // if (res) { - // setIframeUrl(res.data as string); - // } else { - // console.log(error); - // } - // }; + const requestJupyterUrl = async () => { + const [res, error] = await to(getJupyterUrl()); + if (res) { + setIframeUrl(res.data as string); + } else { + console.log(error); + } + }; return ; } From 0e6c7ce5e86bcc1ea4a090287efe9b511af982a3 Mon Sep 17 00:00:00 2001 From: fanshuai <1141904845@qq.com> Date: Wed, 26 Jun 2024 09:14:41 +0800 Subject: [PATCH 6/6] =?UTF-8?q?=E6=96=B0=E5=A2=9E=E6=8C=87=E6=A0=87?= =?UTF-8?q?=E5=AF=B9=E6=AF=94?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../platform/mapper/ModelDependencyDao.java | 4 + .../ExperimentInstanceStatusTask.java | 70 +++++----- .../service/ModelDependencyService.java | 4 + .../platform/service/impl/AimServiceImpl.java | 132 ++++++++++++------ .../impl/ModelDependencyServiceImpl.java | 15 +- .../ruoyi/platform/vo/InsMetricInfoVo.java | 6 +- .../ModelDependencyDaoMapper.xml | 16 +++ 7 files changed, 162 insertions(+), 85 deletions(-) diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/mapper/ModelDependencyDao.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/mapper/ModelDependencyDao.java index 3a999886..ba1bc40b 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/mapper/ModelDependencyDao.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/mapper/ModelDependencyDao.java @@ -84,5 +84,9 @@ public interface ModelDependencyDao { List queryByModelDependency(@Param("modelDependency") ModelDependency modelDependency); List queryChildrenByVersionId(@Param("model_id")String modelId, @Param("version")String version); + + List queryByIns(@Param("expInsId")Integer expInsId); + + ModelDependency queryByInsAndTrainTaskId(@Param("expInsId")Integer expInsId,@Param("taskId") String taskId); } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/ExperimentInstanceStatusTask.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/ExperimentInstanceStatusTask.java index 4680285e..131dca48 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/ExperimentInstanceStatusTask.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/scheduling/ExperimentInstanceStatusTask.java @@ -7,20 +7,15 @@ import com.ruoyi.platform.mapper.ExperimentDao; import com.ruoyi.platform.mapper.ExperimentInsDao; import com.ruoyi.platform.mapper.ModelDependencyDao; import com.ruoyi.platform.service.ExperimentInsService; -import com.ruoyi.platform.service.ModelDependencyService; import com.ruoyi.platform.utils.JacksonUtil; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.data.domain.Page; import org.springframework.scheduling.annotation.Scheduled; import org.springframework.stereotype.Component; import javax.annotation.Resource; import java.io.IOException; -import java.util.ArrayList; -import java.util.Date; -import java.util.List; -import java.util.Map; +import java.util.*; @Component() public class ExperimentInstanceStatusTask { @@ -34,7 +29,7 @@ public class ExperimentInstanceStatusTask { private ModelDependencyDao modelDependencyDao; private List experimentIds = new ArrayList<>(); - @Scheduled(cron = "0/14 * * * * ?") // 每30S执行一次 + @Scheduled(cron = "0/30 * * * * ?") // 每30S执行一次 public void executeExperimentInsStatus() throws IOException { // 首先查到所有非终止态的实验实例 List experimentInsList = experimentInsService.queryByExperimentIsNotTerminated(); @@ -46,95 +41,94 @@ public class ExperimentInstanceStatusTask { String oldStatus = experimentIns.getStatus(); try { experimentIns = experimentInsService.queryStatusFromArgo(experimentIns); - }catch (Exception e){ + } catch (Exception e) { experimentIns.setStatus("Failed"); } -// if (!StringUtils.equals(oldStatus,experimentIns.getStatus())){ - experimentIns.setUpdateTime(new Date()); - // 线程安全的添加操作 - synchronized (experimentIds) { - experimentIds.add(experimentIns.getExperimentId()); - } - updateList.add(experimentIns); - -// } -// experimentInsDao.update(experimentIns); + experimentIns.setUpdateTime(new Date()); + // 线程安全的添加操作 + synchronized (experimentIds) { + experimentIds.add(experimentIns.getExperimentId()); + } + updateList.add(experimentIns); } - } - if (updateList.size() > 0){ + if (updateList.size() > 0) { experimentInsDao.insertOrUpdateBatch(updateList); //遍历模型关系表,找到 List modelDependencyList = new ArrayList(); - for (ExperimentIns experimentIns : updateList){ + for (ExperimentIns experimentIns : updateList) { ModelDependency modelDependencyquery = new ModelDependency(); modelDependencyquery.setExpInsId(experimentIns.getId()); modelDependencyquery.setState(2); List modelDependencyListquery = modelDependencyDao.queryByModelDependency(modelDependencyquery); - if (modelDependencyListquery==null||modelDependencyListquery.size()==0){ + if (modelDependencyListquery == null || modelDependencyListquery.size() == 0) { continue; } ModelDependency modelDependency = modelDependencyListquery.get(0); //查看状态, - if (StringUtils.equals("Failed",experimentIns.getStatus())){ + if (StringUtils.equals("Failed", experimentIns.getStatus())) { //取出节点状态 String trainTask = modelDependency.getTrainTask(); Map trainMap = JacksonUtil.parseJSONStr2Map(trainTask); String task_id = (String) trainMap.get("task_id"); - if (StringUtils.isEmpty(task_id)){ + if (StringUtils.isEmpty(task_id)) { continue; } String nodesStatus = experimentIns.getNodesStatus(); Map nodeMaps = JacksonUtil.parseJSONStr2Map(nodesStatus); Map nodeMap = JacksonUtil.parseJSONStr2Map(JacksonUtil.toJSONString(nodeMaps.get(task_id))); - if (nodeMap==null){ + if (nodeMap == null) { continue; } - if (!StringUtils.equals("Succeeded",(String)nodeMap.get("phase"))){ + if (!StringUtils.equals("Succeeded", (String) nodeMap.get("phase"))) { modelDependency.setState(0); modelDependencyList.add(modelDependency); } } } - if (modelDependencyList.size()>0) { + if (modelDependencyList.size() > 0) { modelDependencyDao.insertOrUpdateBatch(modelDependencyList); } } - } - @Scheduled(cron = "0/17 * * * * ?") // / 每30S执行一次 + + @Scheduled(cron = "0/30 * * * * ?") // / 每30S执行一次 public void executeExperimentStatus() throws IOException { - if (experimentIds.size()==0){ + if (experimentIds.size() == 0) { return; } // 存储需要更新的实验对象列表 List updateExperiments = new ArrayList<>(); - for (Integer experimentId : experimentIds){ + for (Integer experimentId : experimentIds) { // 获取当前实验的所有实例列表 List insList = experimentInsService.getByExperimentId(experimentId); List statusList = new ArrayList(); // 更新实验状态列表 - for (int i=0;i iterator = experimentIds.iterator(); + while (iterator.hasNext()) { + Integer experimentId = iterator.next(); + for (Experiment experiment : updateExperiments) { + if (experiment.getId().equals(experimentId)) { + iterator.remove(); + } } } } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/ModelDependencyService.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/ModelDependencyService.java index 5c8b9d1d..049d87d1 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/ModelDependencyService.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/ModelDependencyService.java @@ -62,4 +62,8 @@ public interface ModelDependencyService { List queryByModelDependency(ModelDependency modelDependency) throws IOException; ModelDependcyTreeVo getModelDependencyTree(ModelDependency modelDependency) throws Exception; + + List queryByIns(Integer expInsId); + + ModelDependency queryByInsAndTrainTaskId(Integer expInsId, String taskId); } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java index c26e101e..6ec8f43c 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/AimServiceImpl.java @@ -1,18 +1,17 @@ package com.ruoyi.platform.service.impl; -import com.alibaba.druid.util.StringUtils; import com.ruoyi.platform.domain.ExperimentIns; +import com.ruoyi.platform.domain.ModelDependency; import com.ruoyi.platform.service.AimService; import com.ruoyi.platform.service.ExperimentInsService; -import com.ruoyi.platform.service.ExperimentService; +import com.ruoyi.platform.service.ModelDependencyService; import com.ruoyi.platform.utils.AIM64EncoderUtil; import com.ruoyi.platform.utils.HttpUtils; import com.ruoyi.platform.utils.JacksonUtil; import com.ruoyi.platform.utils.JsonUtils; import com.ruoyi.platform.vo.InsMetricInfoVo; -import org.apache.dubbo.container.Main; -import org.json.JSONObject; -import org.json.JSONTokener; +import org.apache.commons.lang3.StringUtils; +import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; import javax.annotation.Resource; @@ -24,58 +23,66 @@ import java.util.stream.Collectors; public class AimServiceImpl implements AimService { @Resource private ExperimentInsService experimentInsService; + @Resource + private ModelDependencyService modelDependencyService; + + @Value("${aim.url}") + private String aimUrl; + @Value("${aim.proxyUrl}") + private String aimProxyUrl; @Override public List getExpTrainInfos(Integer experimentId) throws Exception { - String experimentName = "experiment-train-0"+experimentId; - return getAimRunInfos("",experimentId); + return getAimRunInfos(true,experimentId); } @Override public List getExpEvaluateInfos(Integer experimentId) throws Exception { - String experimentName = "experiment-evaluate-0"+experimentId; - return getAimRunInfos("",experimentId); + return getAimRunInfos(false,experimentId); } @Override public String getExpMetrics(List runIds) throws Exception { String decode = AIM64EncoderUtil.decode(runIds); - return "http://172.20.32.21:7123/api/runs/search/run?query="+decode; + return aimUrl+"/api/runs/search/run?query="+decode; } - private List getAimRunInfos(String experimentName,Integer experimentId) throws Exception { - String encodedUrlString = URLEncoder.encode("run.experiment==\"experiment-0000\"", "UTF-8"); - String url = "http://172.20.32.181:30123/api/runs/search/run?query="+encodedUrlString; + private List getAimRunInfos(boolean isTrain,Integer experimentId) throws Exception { + String experimentName = "experiment-"+experimentId+"-train"; + if (!isTrain){ + experimentName = "experiment-"+experimentId+"-evaluate"; + } + String encodedUrlString = URLEncoder.encode("run.experiment==\""+experimentName+"\"", "UTF-8"); + String url = aimProxyUrl+"/api/runs/search/run?query="+encodedUrlString; String s = HttpUtils.sendGetRequest(url); - System.out.println(s); List> response = JacksonUtil.parseJSONStr2MapList(s); - // TODO: parse aim response to InsMetricInfoVo list if (response == null || response.size() == 0){ return new ArrayList<>(); } //查询实例数据 List byExperimentId = experimentInsService.getByExperimentId(experimentId); -// if (byExperimentId == null || byExperimentId.size() == 0){ -// return new ArrayList<>(); -// } + if (byExperimentId == null || byExperimentId.size() == 0){ + return new ArrayList<>(); + } List aimRunInfoList = new ArrayList<>(); for (Map run : response) { InsMetricInfoVo aimRunInfo = new InsMetricInfoVo(); String runHash = (String) run.get("run_hash"); + aimRunInfo.setRunId(runHash); Map params= (Map) run.get("params"); Map paramMap = JsonUtils.flattenJson("", params); aimRunInfo.setParams(paramMap); - - Map tracesMap= (Map) run.get("params"); + String aimrunId = (String) paramMap.get("id"); + Map tracesMap= (Map) run.get("traces"); List> metricList = (List>) tracesMap.get("metric"); //过滤name为__system__开头的对象 aimRunInfo.setMetrics(new HashMap<>()); if (metricList != null && metricList.size() > 0){ List> metricRelList = metricList.stream() - .filter(map -> !StringUtils.equals("__system__", (String) map.get("name"))) + .filter(map -> !StringUtils.startsWith((String) map.get("name"),"__system__" )) .collect(Collectors.toList()); if (metricRelList!= null && metricRelList.size() > 0){ Map relMetricMap = new HashMap<>(); @@ -85,39 +92,84 @@ public class AimServiceImpl implements AimService { aimRunInfo.setMetrics(relMetricMap); } } - - - //找到ins - for (ExperimentIns ins : byExperimentId) { - String metricRecord = ins.getMetricRecord(); - if (metricRecord.contains(runHash)){ + String metricRecordString = ins.getMetricRecord(); + if (StringUtils.isEmpty(metricRecordString)){ + continue; + } + if (metricRecordString.contains(aimrunId)){ aimRunInfo.setExperimentInsId(ins.getId()); aimRunInfo.setStatus(ins.getStatus()); - aimRunInfo.setStartTime(ins.getStartTime()); + aimRunInfo.setStartTime(ins.getCreateTime()); + Map metricRecordMap = JacksonUtil.parseJSONStr2Map(metricRecordString); + + //metricRecord 格式为{"train":[{"task_id":"model-train-35303690","run_id":"5560d78f54314672b60304c8d6ba03b8","experiment_name":"experiment-30-train"}],"evaluate":[{"task_id":"model-train-35303690","run_id":"5560d78f54314672b60304c8d6ba03b8","experiment_name":"experiment-30-train"}]} + //遍历metricRecord,找到当前task_id对应的ModelDependency + + if (isTrain){ + List> trainList = (List>) metricRecordMap.get("train"); + List trainDateSet = getTrainDateSet(trainList, ins.getId(), isTrain); + aimRunInfo.setDataset(trainDateSet); + }else { + List> trainList = (List>) metricRecordMap.get("evaluate"); + List trainDateSet = getTrainDateSet(trainList, ins.getId(), isTrain); + aimRunInfo.setDataset(trainDateSet); + } + } } aimRunInfoList.add(aimRunInfo); } //判断哪个最长 - Optional maxMetricsVo = aimRunInfoList.stream() - .max((vo1, vo2) -> Integer.compare(vo1.getMetrics().size(), vo2.getMetrics().size())); + // 获取所有 metrics 的 key 的并集 + Set metricsKeys = (Set) aimRunInfoList.stream() + .map(InsMetricInfoVo::getMetrics) + .flatMap(metrics -> metrics.keySet().stream()) + .collect(Collectors.toSet()); + // 将并集赋值给每个 InsMetricInfoVo 的 metricsNames 属性 + aimRunInfoList.forEach(vo -> vo.setMetricsNames(new ArrayList<>(metricsKeys))); + + // 获取所有 params 的 key 的并集 + Set paramKeys = (Set) aimRunInfoList.stream() + .map(InsMetricInfoVo::getParams) + .flatMap(params -> params.keySet().stream()) + .collect(Collectors.toSet()); + // 将并集赋值给每个 InsMetricInfoVo 的 paramsNames 属性 + aimRunInfoList.forEach(vo -> vo.setParamsNames(new ArrayList<>(paramKeys))); + + return aimRunInfoList; + } - // 如果找到了,设置 metricsFlag 为 true - if (maxMetricsVo.isPresent()) { - maxMetricsVo.get().setMetricsFlag(true); - } - Optional maxParamsVo = aimRunInfoList.stream() - .max((vo1, vo2) -> Integer.compare(vo1.getParams().size(), vo2.getParams().size())); - // 如果找到了,设置 metricsFlag 为 true - if (maxParamsVo.isPresent()) { - maxParamsVo.get().setMetricsFlag(true); + private List getTrainDateSet(List> trainList,Integer expInsId,boolean isTrain){ + if (trainList == null || trainList.size() == 0){ + return new ArrayList<>(); } + List datasetList = new ArrayList<>(); + for (Map trainMap : trainList) { + String task_id = (String) trainMap.get("task_id"); + //modelDependency取到数据集文件 + ModelDependency modelDependency = modelDependencyService.queryByInsAndTrainTaskId(expInsId, task_id); + //把数据集文件组装成String后放进List + String datasetString = ""; + if (isTrain){ + datasetString = modelDependency.getTrainDataset(); + }else { + datasetString = modelDependency.getTestDataset(); + } + List> datasetListMap = JacksonUtil.parseJSONStr2MapList(datasetString); - return aimRunInfoList; + if (datasetListMap != null && datasetListMap.size() > 0){ + for (Map datasetMap : datasetListMap) { + //[{"dataset_id":20,"dataset_version":"v0.1.0","dataset_name":"手写体识别模型依赖测试训练数据集"}] + String datasetName = (String) datasetMap.get("dataset_name")+":"+(String) datasetMap.get("dataset_version"); + datasetList.add(datasetName); + } + } + } + return datasetList; } } diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelDependencyServiceImpl.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelDependencyServiceImpl.java index f3c48ebb..572a66a5 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelDependencyServiceImpl.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/service/impl/ModelDependencyServiceImpl.java @@ -17,10 +17,7 @@ import org.springframework.data.domain.PageRequest; import javax.annotation.Resource; import java.io.IOException; -import java.util.ArrayList; -import java.util.Date; -import java.util.List; -import java.util.Map; +import java.util.*; import java.util.stream.Collectors; /** @@ -97,6 +94,16 @@ public class ModelDependencyServiceImpl implements ModelDependencyService { return modelDependcyTreeVo; } + @Override + public List queryByIns(Integer expInsId) { + return modelDependencyDao.queryByIns(expInsId); + } + + @Override + public ModelDependency queryByInsAndTrainTaskId(Integer expInsId, String taskId) { + return modelDependencyDao.queryByInsAndTrainTaskId(expInsId,taskId); + } + /** * 递归父模型 * @param modelDependcyTreeVo diff --git a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/InsMetricInfoVo.java b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/InsMetricInfoVo.java index cd3943ed..6fe8caa4 100644 --- a/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/InsMetricInfoVo.java +++ b/ruoyi-modules/management-platform/src/main/java/com/ruoyi/platform/vo/InsMetricInfoVo.java @@ -18,7 +18,7 @@ public class InsMetricInfoVo implements Serializable { @ApiModelProperty(value = "实例运行状态") private String status; @ApiModelProperty(value = "使用数据集") - private List> dataset; + private List dataset; @ApiModelProperty(value = "实例ID") private Integer experimentInsId; @ApiModelProperty(value = "训练指标") @@ -27,6 +27,6 @@ public class InsMetricInfoVo implements Serializable { private Map params; @ApiModelProperty(value = "训练记录ID") private String runId; - private Boolean metricsFlag = false; - private Boolean paramsFlag = false; + private List metricsNames; + private List paramsNames; } diff --git a/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/ModelDependencyDaoMapper.xml b/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/ModelDependencyDaoMapper.xml index 2cd5dd7a..ea592ee2 100644 --- a/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/ModelDependencyDaoMapper.xml +++ b/ruoyi-modules/management-platform/src/main/resources/mapper/managementPlatform/ModelDependencyDaoMapper.xml @@ -22,6 +22,22 @@ + +