From 2cc70b1f00633694ac7321330d2fe7a35abde668 Mon Sep 17 00:00:00 2001 From: cp3hnu Date: Mon, 1 Jul 2024 11:45:39 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=A8=A1=E5=9E=8B=E6=BC=94=E5=8C=96?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E7=82=B9=E5=87=BB=E8=8A=82=E7=82=B9=E5=B1=95?= =?UTF-8?q?=E5=BC=80=E6=95=B0=E6=8D=AE=E9=9B=86=E5=92=8C=E9=A1=B9=E7=9B=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../Model/components/ModelEvolution/index.tsx | 67 +++++-- .../Model/components/ModelEvolution/utils.tsx | 163 +++++++++++++----- .../Model/components/NodeTooltips/index.tsx | 14 +- .../src/pages/ModelDeployment/List/index.tsx | 2 +- 4 files changed, 180 insertions(+), 66 deletions(-) diff --git a/react-ui/src/pages/Model/components/ModelEvolution/index.tsx b/react-ui/src/pages/Model/components/ModelEvolution/index.tsx index f46d77ec..a0e9beaf 100644 --- a/react-ui/src/pages/Model/components/ModelEvolution/index.tsx +++ b/react-ui/src/pages/Model/components/ModelEvolution/index.tsx @@ -1,9 +1,15 @@ +/* + * @Author: 赵伟 + * @Date: 2024-06-07 11:24:10 + * @Description: 模型演化 + */ + import { useEffectWhen } from '@/hooks'; import { ResourceVersionData } from '@/pages/Dataset/config'; import { getModelAtlasReq } from '@/services/dataset/index.js'; import themes from '@/styles/theme.less'; import { to } from '@/utils/promise'; -import G6, { G6GraphEvent, Graph } from '@antv/g6'; +import G6, { G6GraphEvent, Graph, INode } from '@antv/g6'; // @ts-ignore import { Flex, Select } from 'antd'; import { useEffect, useRef, useState } from 'react'; @@ -11,7 +17,15 @@ import GraphLegend from '../GraphLegend'; import NodeTooltips from '../NodeTooltips'; import styles from './index.less'; import type { ModelDepsData, ProjectDependency, TrainDataset } from './utils'; -import { getGraphData, nodeFontSize, nodeHeight, nodeWidth, normalizeTreeData } from './utils'; +import { + NodeType, + getGraphData, + nodeFontSize, + nodeHeight, + nodeWidth, + normalizeTreeData, + traverseHierarchically, +} from './utils'; type modeModelEvolutionProps = { resourceId: number; @@ -37,6 +51,8 @@ function ModelEvolution({ const [hoverNodeData, setHoverNodeData] = useState< ModelDepsData | ProjectDependency | TrainDataset | undefined >(undefined); + const apiData = useRef(undefined); // 接口返回的树形结构 + const hierarchyNodes = useRef([]); // 层级迭代树形结构,得到的节点列表 useEffect(() => { initGraph(); @@ -111,18 +127,7 @@ function ModelEvolution({ }, }, modes: { - default: [ - 'drag-canvas', - 'zoom-canvas', - // { - // type: 'collapse-expand', - // onChange(item?: Item, collapsed?: boolean) { - // const data = item!.getModel(); - // data.collapsed = collapsed; - // return true; - // }, - // }, - ], + default: ['drag-canvas', 'zoom-canvas'], }, }); @@ -161,11 +166,26 @@ function ModelEvolution({ }); graph.on('node:click', (e: G6GraphEvent) => { - const nodeItem = e.item; + const nodeItem = e.item as INode; const model = nodeItem.getModel() as ModelDepsData | ProjectDependency | TrainDataset; const { model_type } = model; - switch (model_type) { + if ( + model_type === NodeType.Project || + model_type === NodeType.TrainDataset || + model_type === NodeType.TestDataset || + !apiData.current || + !hierarchyNodes.current + ) { + return; } + + setShowNodeTooltip(false); + setEnterTooltip(false); + toggleExpended(model.id); + const graphData = getGraphData(apiData.current, hierarchyNodes.current); + graph.data(graphData); + graph.render(); + graph.fitView(); }); // 鼠标滚轮缩放时,隐藏 tooltip @@ -175,6 +195,17 @@ function ModelEvolution({ }); }; + // toggle 展开 + const toggleExpended = (id: string) => { + const nodes = hierarchyNodes.current; + for (const node of nodes) { + if (node.id === id) { + node.expanded = !node.expanded; + break; + } + } + }; + const handleTooltipsMouseEnter = () => { setEnterTooltip(true); }; @@ -192,7 +223,9 @@ function ModelEvolution({ const [res] = await to(getModelAtlasReq(params)); if (res && res.data) { const data = normalizeTreeData(res.data); - const graphData = getGraphData(data); + apiData.current = data; + hierarchyNodes.current = traverseHierarchically(data); + const graphData = getGraphData(data, hierarchyNodes.current); graph.data(graphData); graph.render(); diff --git a/react-ui/src/pages/Model/components/ModelEvolution/utils.tsx b/react-ui/src/pages/Model/components/ModelEvolution/utils.tsx index a878321a..30a06817 100644 --- a/react-ui/src/pages/Model/components/ModelEvolution/utils.tsx +++ b/react-ui/src/pages/Model/components/ModelEvolution/utils.tsx @@ -6,21 +6,22 @@ import Hierarchy from '@antv/hierarchy'; export const nodeWidth = 90; export const nodeHeight = 40; export const vGap = nodeHeight + 20; -export const hGap = nodeWidth; +export const hGap = nodeHeight + 20; export const ellipseWidth = nodeWidth; export const labelPadding = 30; export const nodeFontSize = 8; +export const datasetHGap = 20; // 数据集节点 const datasetNodes: NodeConfig[] = []; export enum NodeType { - current = 'current', - parent = 'parent', - children = 'children', - project = 'project', - trainDataset = 'trainDataset', - testDataset = 'testDataset', + Current = 'Current', // 当前模型 + Parent = 'Parent', // 父模型 + Children = 'Children', // 子模型 + Project = 'Project', // 项目 + TrainDataset = 'TrainDataset', // 训练数据集 + TestDataset = 'TestDataset', // 测试数据集 } export type Rect = { @@ -40,14 +41,14 @@ export interface TrainDataset extends NodeConfig { dataset_id: number; dataset_name: string; dataset_version: string; - model_type: NodeType.testDataset | NodeType.trainDataset; + model_type: NodeType.TestDataset | NodeType.TrainDataset; } export interface ProjectDependency extends NodeConfig { url: string; name: string; branch: string; - model_type: NodeType.project; + model_type: NodeType.Project; } export type ModalDetail = { @@ -66,9 +67,9 @@ export interface ModelDepsAPIData { version: string; workflow_id: number; exp_ins_id: number; - model_type: NodeType.children | NodeType.current | NodeType.parent; + model_type: NodeType.Children | NodeType.Current | NodeType.Parent; current_model_name: string; - project_dependency: ProjectDependency; + project_dependency?: ProjectDependency; test_dataset: TrainDataset[]; train_dataset: TrainDataset[]; train_task: TrainTask; @@ -79,16 +80,22 @@ export interface ModelDepsAPIData { export interface ModelDepsData extends Omit, TreeGraphData { children: ModelDepsData[]; + expanded: boolean; // 是否展开 + level: number; // 层级,从 0 开始 + datasetLen: number; // 数据集数量 } // 规范化子数据 export function normalizeChildren(data: ModelDepsData[]) { if (Array.isArray(data)) { data.forEach((item) => { - item.model_type = NodeType.children; + item.model_type = NodeType.Children; + item.expanded = false; + item.level = 0; + item.datasetLen = item.train_dataset.length + item.test_dataset.length; item.id = `$M_${item.current_model_id}_${item.version}`; item.label = getLabel(item); - item.style = getStyle(NodeType.children); + item.style = getStyle(NodeType.Children); normalizeChildren(item.children); }); } @@ -111,22 +118,22 @@ export function getLabel(node: ModelDepsData | ModelDepsAPIData) { export function getStyle(model_type: NodeType) { let fill = ''; switch (model_type) { - case NodeType.current: + case NodeType.Current: fill = 'l(0) 0:#72a1ff 1:#1664ff'; break; - case NodeType.parent: + case NodeType.Parent: fill = 'l(0) 0:#93dfd1 1:#43c9b1'; break; - case NodeType.children: + case NodeType.Children: fill = 'l(0) 0:#72b4ff 1:#169aff'; break; - case NodeType.project: + case NodeType.Project: fill = 'l(0) 0:#b3a9ff 1:#8981ff'; break; - case NodeType.trainDataset: + case NodeType.TrainDataset: fill = '#a5d878'; break; - case NodeType.testDataset: + case NodeType.TestDataset: fill = '#d8b578'; break; default: @@ -145,11 +152,15 @@ export function normalizeTreeData(apiData: ModelDepsAPIData): ModelDepsData { }) as ModelDepsData; // 设置当前模型的数据 - normalizedData.model_type = NodeType.current; + normalizedData.model_type = NodeType.Current; normalizedData.id = `$M_${normalizedData.current_model_id}_${normalizedData.version}`; normalizedData.label = getLabel(normalizedData); - normalizedData.style = getStyle(NodeType.current); + normalizedData.style = getStyle(NodeType.Current); + normalizedData.expanded = true; + normalizedData.datasetLen = + normalizedData.train_dataset.length + normalizedData.test_dataset.length; normalizeChildren(normalizedData.children as ModelDepsData[]); + normalizedData.level = 0; // 将 parent_models 转换成树形结构 let parent_models = normalizedData.parent_models || []; @@ -157,10 +168,13 @@ export function normalizeTreeData(apiData: ModelDepsAPIData): ModelDepsData { const parent = parent_models[0]; normalizedData = { ...parent, - model_type: NodeType.parent, + expanded: false, + level: 0, + datasetLen: parent.train_dataset.length + parent.test_dataset.length, + model_type: NodeType.Parent, id: `$M_${parent.current_model_id}_${parent.version}`, label: getLabel(parent), - style: getStyle(NodeType.parent), + style: getStyle(NodeType.Parent), children: [ { ...normalizedData, @@ -174,13 +188,34 @@ export function normalizeTreeData(apiData: ModelDepsAPIData): ModelDepsData { } // 将树形数据,使用 Hierarchy 进行布局,计算出坐标,然后转换成 G6 的数据 -export function getGraphData(data: ModelDepsData): GraphData { +export function getGraphData(data: ModelDepsData, hierarchyNodes: ModelDepsData[]): GraphData { const config = { direction: 'LR', getHeight: () => nodeHeight, getWidth: () => nodeWidth, - getVGap: () => vGap / 2, - getHGap: () => hGap / 2, + getVGap: (node: NodeConfig) => { + const model = node as ModelDepsData; + const { model_type, expanded, project_dependency } = model; + if (model_type === NodeType.Current || model_type === NodeType.Parent) { + return vGap / 2; + } + const selfGap = expanded && project_dependency?.url ? nodeHeight + vGap : 0; + const nextNode = getSameHierarchyNextNode(model, hierarchyNodes); + if (!nextNode) { + return vGap / 2; + } + const nextGap = nextNode.expanded === true && nextNode.datasetLen > 0 ? nodeHeight + vGap : 0; + return (selfGap + nextGap + vGap) / 2; + }, + getHGap: (node: NodeConfig) => { + const model = node as ModelDepsData; + return ( + (getHierarchyWidth(model.level, hierarchyNodes) + + getHierarchyWidth(model.level + 1, hierarchyNodes) + + hGap) / + 2 + ); + }, }; // 树形布局计算出坐标 @@ -191,11 +226,11 @@ export function getGraphData(data: ModelDepsData): GraphData { Util.traverseTree(treeLayoutData, (node: NodeConfig, parent: NodeConfig) => { const data = node.data as ModelDepsData; // 当前模型显示数据集和项目 - if (data.model_type === NodeType.current) { + if (data.expanded === true) { addDatasetDependency(data, node, nodes, edges); addProjectDependency(data, node, nodes, edges); - } else if (data.model_type === NodeType.children) { - adjustDatasetPosition(node); + } else if (data.model_type === NodeType.Children) { + // adjustDatasetPosition(node); } nodes.push({ ...data, @@ -219,16 +254,16 @@ const addDatasetDependency = ( nodes: NodeConfig[], edges: EdgeConfig[], ) => { - const { train_dataset, test_dataset } = data; + const { train_dataset, test_dataset, id } = data; train_dataset.forEach((item) => { - item.id = `$DTrain_${item.dataset_id}_${item.dataset_version}`; - item.model_type = NodeType.trainDataset; - item.style = getStyle(NodeType.trainDataset); + item.id = `$DTrain_${id}_${item.dataset_id}_${item.dataset_version}`; + item.model_type = NodeType.TrainDataset; + item.style = getStyle(NodeType.TrainDataset); }); test_dataset.forEach((item) => { - item.id = `$DTest_${item.dataset_id}_${item.dataset_version}`; - item.model_type = NodeType.testDataset; - item.style = getStyle(NodeType.testDataset); + item.id = `$DTest_${id}_${item.dataset_id}_${item.dataset_version}`; + item.model_type = NodeType.TestDataset; + item.style = getStyle(NodeType.TestDataset); }); datasetNodes.length = 0; @@ -243,7 +278,7 @@ const addDatasetDependency = ( fittingString(node.dataset_version, ellipseWidth - labelPadding, nodeFontSize); const half = len / 2 - 0.5; - node.x = currentNode.x! - (half - index) * (ellipseWidth + 20); + node.x = currentNode.x! - (half - index) * (ellipseWidth + datasetHGap); node.y = currentNode.y! - nodeHeight - vGap; nodes.push(node); datasetNodes.push(node); @@ -264,14 +299,14 @@ const addProjectDependency = ( nodes: NodeConfig[], edges: EdgeConfig[], ) => { - const { project_dependency } = data; + const { project_dependency, id } = data; if (project_dependency?.url) { const node = { ...project_dependency }; - node.id = `$P_${node.url}_${node.branch}`; - node.model_type = NodeType.project; + node.id = `$P_${id}_${node.url}_${node.branch}`; + node.model_type = NodeType.Project; node.type = 'rect'; node.label = fittingString(node.name, nodeWidth - labelPadding, nodeFontSize); - node.style = getStyle(NodeType.project); + node.style = getStyle(NodeType.Project); node.style.radius = nodeHeight / 2; node.x = currentNode.x; node.y = currentNode.y! + nodeHeight + vGap; @@ -331,3 +366,49 @@ function adjustDatasetPosition(node: NodeConfig) { }); } } + +// 层级遍历树结构 +export function traverseHierarchically(data: ModelDepsData | undefined): ModelDepsData[] { + if (!data) return []; + let level = 0; + data.level = level; + const result: ModelDepsData[] = [data]; + let index = 0; + + while (index < result.length) { + const item = result[index]; + if (item.children) { + item.children.forEach((child) => { + child.level = item.level + 1; + result.push(child); + }); + } + index++; + } + + return result; +} + +// 找到同层次的下一个节点 +export function getSameHierarchyNextNode(node: ModelDepsData, nodes: ModelDepsData[]) { + const index = nodes.findIndex((item) => item.id === node.id); + if (index >= 0 && index < nodes.length - 1) { + const nextNode = nodes[index + 1]; + if (nextNode.level === node.level) { + return nextNode; + } + } + return null; +} + +// 得到层级的宽度 +export function getHierarchyWidth(level: number, nodes: ModelDepsData[]) { + const hierarchyNodes = nodes + .filter((item) => item.level === level && item.expanded === true) + .sort((a, b) => b.datasetLen - a.datasetLen); + const first = hierarchyNodes[0]; + if (first) { + return Math.max(((first.datasetLen - 1) * (nodeWidth + datasetHGap)) / 2, 0); + } + return 0; +} diff --git a/react-ui/src/pages/Model/components/NodeTooltips/index.tsx b/react-ui/src/pages/Model/components/NodeTooltips/index.tsx index 217222da..f5bb2c82 100644 --- a/react-ui/src/pages/Model/components/NodeTooltips/index.tsx +++ b/react-ui/src/pages/Model/components/NodeTooltips/index.tsx @@ -22,7 +22,7 @@ function ModelInfo({ resourceId, data, onVersionChange }: ModelInfoProps) { }; const gotoModelPage = () => { - if (data.model_type === NodeType.current) { + if (data.model_type === NodeType.Current) { return; } if (data.current_model_id === resourceId) { @@ -39,7 +39,7 @@ function ModelInfo({ resourceId, data, onVersionChange }: ModelInfoProps) {
模型名称: - {data.model_type === NodeType.current ? ( + {data.model_type === NodeType.Current ? ( {data.model_version_dependcy_vo?.name || '--'} @@ -199,14 +199,14 @@ function NodeTooltips({ if (!data) return null; let Component = null; const { model_type } = data; - if (model_type === NodeType.testDataset || model_type === NodeType.trainDataset) { + if (model_type === NodeType.TestDataset || model_type === NodeType.TrainDataset) { Component = ; - } else if (model_type === NodeType.project) { + } else if (model_type === NodeType.Project) { Component = ; } else if ( - model_type === NodeType.children || - model_type === NodeType.parent || - model_type === NodeType.current + model_type === NodeType.Children || + model_type === NodeType.Parent || + model_type === NodeType.Current ) { Component = ; } diff --git a/react-ui/src/pages/ModelDeployment/List/index.tsx b/react-ui/src/pages/ModelDeployment/List/index.tsx index 934b4cbd..af8fba44 100644 --- a/react-ui/src/pages/ModelDeployment/List/index.tsx +++ b/react-ui/src/pages/ModelDeployment/List/index.tsx @@ -223,7 +223,7 @@ function ModelDeployment() { { title: '操作', dataIndex: 'operation', - width: 350, + width: 250, key: 'operation', render: (_: any, record: ModelDeploymentData) => (