import { changePropertyName, fittingString } from '@/utils';
import { EdgeConfig, GraphData, LayoutConfig, NodeConfig, TreeGraphData, Util } from '@antv/g6';
// Type checking is suppressed for this import (presumably @antv/hierarchy ships no TypeScript declarations).
// @ts-ignore
import Hierarchy from '@antv/hierarchy';

export const nodeWidth = 90;
export const nodeHeight = 40;
export const vGap = nodeHeight + 20;
export const hGap = nodeHeight + 20;
export const ellipseWidth = nodeWidth;
export const labelPadding = 30;
export const nodeFontSize = 8;
export const datasetHGap = 20;

// Dataset nodes created by the most recent addDatasetDependency call (read by adjustDatasetPosition).
const datasetNodes: NodeConfig[] = [];

export enum NodeType {
  Current = 'Current', // the current model
  Parent = 'Parent', // a parent (ancestor) model
  Children = 'Children', // a child (descendant) model
  Project = 'Project', // the project the model depends on
  TrainDataset = 'TrainDataset', // a training dataset
  TestDataset = 'TestDataset', // a test dataset
}

export type Rect = {
  x: number; // x coordinate of the rectangle's centre
  y: number; // y coordinate of the rectangle's centre
  width: number;
  height: number;
};

export type TrainTask = {
  ins_id: number;
  name: string;
  task_id: string;
};

export interface TrainDataset extends NodeConfig {
  dataset_id: number;
  dataset_name: string;
  dataset_version: string;
  model_type: NodeType.TestDataset | NodeType.TrainDataset;
}

export interface ProjectDependency extends NodeConfig {
  url: string;
  name: string;
  branch: string;
  model_type: NodeType.Project;
}

export type ModalDetail = {
  name: string;
  available_range: number;
  file_name: string;
  file_size: string;
  description: string;
  model_type_name: string;
  model_tag_name: string;
  create_time: string;
};

/** Model dependency data as returned by the backend. */
export interface ModelDepsAPIData {
  current_model_id: number;
  version: string;
  workflow_id: number;
  exp_ins_id: number;
  model_type: NodeType.Children | NodeType.Current | NodeType.Parent;
  current_model_name: string;
  project_dependency?: ProjectDependency;
  test_dataset: TrainDataset[];
  train_dataset: TrainDataset[];
  train_task: TrainTask;
  model_version_dependcy_vo: ModalDetail;
  children_models: ModelDepsAPIData[];
  parent_models: ModelDepsAPIData[];
}

/** Tree node used for layout: the API data with `children_models` renamed to `children`. */
export interface ModelDepsData extends Omit<ModelDepsAPIData, 'children_models'>, TreeGraphData {
  children: ModelDepsData[];
  expanded: boolean; // whether the node is expanded (shows its datasets and project)
  level: number; // depth in the tree, starting from 0
  datasetLen: number; // number of datasets (train + test)
}

// Builds the graph node id for a model version.
function getModelId(model: ModelDepsData | ModelDepsAPIData) {
  return `$M_${model.current_model_id}_${model.version}`;
}

/**
 * Normalizes child-model nodes in place, recursing into each node's `children`.
 * Every node becomes a collapsed `Children` node with an id, label, style and dataset count.
 * `level` is reset to 0 here; traverseHierarchically can recompute real depths later.
 */
export function normalizeChildren(data: ModelDepsData[]) {
  if (!Array.isArray(data)) {
    return;
  }
  for (const item of data) {
    item.model_type = NodeType.Children;
    item.expanded = false;
    item.level = 0;
    item.datasetLen = item.train_dataset.length + item.test_dataset.length;
    item.id = getModelId(item);
    item.label = getLabel(item);
    item.style = getStyle(NodeType.Children);
    normalizeChildren(item.children);
  }
}

/**
 * Builds a two-line node label: the model name, then the version.
 * Each line is fitted to the node width via fittingString.
 * A missing `model_version_dependcy_vo` or name yields an empty first line.
 */
export function getLabel(node: ModelDepsData | ModelDepsAPIData) {
  return (
    fittingString(
      `${node.model_version_dependcy_vo?.name ?? ''}`,
      nodeWidth - labelPadding,
      nodeFontSize,
    ) +
    '\n' +
    fittingString(`${node.version}`, nodeWidth - labelPadding, nodeFontSize)
  );
}

/**
 * Returns the node style for a node type.
 * Model and project types get a horizontal gradient fill; dataset types get a solid fill.
 * Unknown types get an empty fill.
 */
export function getStyle(model_type: NodeType) {
  let fill = '';
  switch (model_type) {
    case NodeType.Current:
      fill = 'l(0) 0:#72a1ff 1:#1664ff';
      break;
    case NodeType.Parent:
      fill = 'l(0) 0:#93dfd1 1:#43c9b1';
      break;
    case NodeType.Children:
      fill = 'l(0) 0:#72b4ff 1:#169aff';
      break;
    case NodeType.Project:
      fill = 'l(0) 0:#b3a9ff 1:#8981ff';
      break;
    case NodeType.TrainDataset:
      fill = '#a5d878';
      break;
    case NodeType.TestDataset:
      fill = '#d8b578';
      break;
    default:
      break;
  }
  return {
    fill,
  };
}

/**
 * Converts backend dependency data into a single tree for layout.
 *
 * The current model is marked as the expanded `Current` node, and its children are normalized.
 * The function then walks up the ancestor chain, following only the first entry of each
 * `parent_models` list. Each ancestor wraps the tree built so far as its only child, so the
 * top-most ancestor becomes the returned root.
 *
 * @param apiData Backend data for the current model.
 * @returns The root of the normalized tree.
 */
export function normalizeTreeData(apiData: ModelDepsAPIData): ModelDepsData {
  // Presumably renames children_models -> children (see changePropertyName in @/utils).
  let normalizedData = changePropertyName(apiData, {
    children_models: 'children',
  }) as ModelDepsData;

  // The current model: expanded so its datasets and project are shown.
  normalizedData.model_type = NodeType.Current;
  normalizedData.id = getModelId(normalizedData);
  normalizedData.label = getLabel(normalizedData);
  normalizedData.style = getStyle(NodeType.Current);
  normalizedData.expanded = true;
  normalizedData.datasetLen =
    normalizedData.train_dataset.length + normalizedData.test_dataset.length;
  normalizeChildren(normalizedData.children as ModelDepsData[]);
  normalizedData.level = 0;

  // Ids already placed on the chain; stops the walk if the parent data ever contains a cycle,
  // which would otherwise loop forever.
  const visited = new Set<string>([normalizedData.id]);
  let parentModels = normalizedData.parent_models || [];
  while (parentModels.length > 0) {
    const parent = parentModels[0];
    const parentId = getModelId(parent);
    if (visited.has(parentId)) {
      break;
    }
    visited.add(parentId);

    normalizedData = {
      ...parent,
      expanded: false,
      level: 0,
      datasetLen: parent.train_dataset.length + parent.test_dataset.length,
      model_type: NodeType.Parent,
      id: parentId,
      label: getLabel(parent),
      style: getStyle(NodeType.Parent),
      children: [
        {
          ...normalizedData,
          // The chain is already being built upward; clear it on the wrapped child.
          parent_models: [],
        },
      ],
    };
    parentModels = normalizedData.parent_models || [];
  }
  return normalizedData;
}

/**
 * Lays out the tree with @antv/hierarchy's `compactBox` (left-to-right) and flattens the result
 * into G6 nodes and edges. Expanded nodes also get their dataset nodes (placed above) and their
 * project node (placed below).
 *
 * @param data Root of the normalized tree.
 * @param hierarchyNodes All tree nodes in level order (presumably from traverseHierarchically);
 *   used to find the next same-level node and per-level dataset widths.
 */
export function getGraphData(data: ModelDepsData, hierarchyNodes: ModelDepsData[]): GraphData {
  const config = {
    direction: 'LR',
    getHeight: () => nodeHeight,
    getWidth: () => nodeWidth,
    // Gap values are halved; presumably compactBox applies the gap on both sides of a node.
    getVGap: (node: NodeConfig) => {
      const model = node as ModelDepsData;
      const { model_type, expanded, project_dependency } = model;
      if (model_type === NodeType.Current || model_type === NodeType.Parent) {
        return vGap / 2;
      }
      // Room for this node's project node, drawn below it.
      const selfGap = expanded && project_dependency?.url ? nodeHeight + vGap : 0;
      const nextNode = getSameHierarchyNextNode(model, hierarchyNodes);
      if (!nextNode) {
        return vGap / 2;
      }
      // Room for the next same-level node's dataset row, drawn above it.
      const nextGap = nextNode.expanded === true && nextNode.datasetLen > 0 ? nodeHeight + vGap : 0;
      return (selfGap + nextGap + vGap) / 2;
    },
    // Horizontal gap widened by the dataset-row overhang of this level and the next one.
    getHGap: (node: NodeConfig) => {
      const model = node as ModelDepsData;
      return (
        (getHierarchyWidth(model.level, hierarchyNodes) +
          getHierarchyWidth(model.level + 1, hierarchyNodes) +
          hGap) /
        2
      );
    },
  };

  // Compute coordinates with the tree layout.
  const treeLayoutData: LayoutConfig = Hierarchy['compactBox'](data, config);

  const nodes: NodeConfig[] = [];
  const edges: EdgeConfig[] = [];
  Util.traverseTree(treeLayoutData, (node: NodeConfig, parent: NodeConfig) => {
    const model = node.data as ModelDepsData;
    // Expanded models also show their datasets and project.
    // NOTE: adjustDatasetPosition(node) for collapsed child nodes is currently disabled.
    if (model.expanded === true) {
      addDatasetDependency(model, node, nodes, edges);
      addProjectDependency(model, node, nodes, edges);
    }
    nodes.push({
      ...model,
      x: node.x,
      y: node.y,
    });
    if (parent) {
      edges.push({
        source: parent.id,
        target: node.id,
      });
    }
  });
  return { nodes, edges };
}

/**
 * Adds one ellipse node per train/test dataset of `data`, in a row centred horizontally
 * above `currentNode`, plus a cubic-vertical edge from `currentNode` to each.
 *
 * Side effects:
 * - Writes id/model_type/style onto the dataset items of `data`.
 * - Replaces the contents of the module-level `datasetNodes` with the new nodes.
 */
const addDatasetDependency = (
  data: ModelDepsData,
  currentNode: NodeConfig,
  nodes: NodeConfig[],
  edges: EdgeConfig[],
) => {
  const { train_dataset, test_dataset, id } = data;
  train_dataset.forEach((item) => {
    item.id = `$DTrain_${id}_${item.dataset_id}_${item.dataset_version}`;
    item.model_type = NodeType.TrainDataset;
    item.style = getStyle(NodeType.TrainDataset);
  });
  test_dataset.forEach((item) => {
    item.id = `$DTest_${id}_${item.dataset_id}_${item.dataset_version}`;
    item.model_type = NodeType.TestDataset;
    item.style = getStyle(NodeType.TestDataset);
  });

  datasetNodes.length = 0;
  const datasets = [...train_dataset, ...test_dataset];
  // Slot offset of the first dataset from the centre, so the row is centred on currentNode.x.
  const half = datasets.length / 2 - 0.5;
  datasets.forEach((item, index) => {
    const node = { ...item };
    node.type = 'ellipse';
    node.size = [ellipseWidth, nodeHeight];
    node.label =
      fittingString(node.dataset_name, ellipseWidth - labelPadding, nodeFontSize) +
      '\n' +
      fittingString(node.dataset_version, ellipseWidth - labelPadding, nodeFontSize);
    node.x = currentNode.x! - (half - index) * (ellipseWidth + datasetHGap);
    node.y = currentNode.y! - nodeHeight - vGap;
    nodes.push(node);
    datasetNodes.push(node);
    edges.push({
      source: currentNode.id,
      target: node.id,
      sourceAnchor: 2,
      targetAnchor: 3,
      type: 'cubic-vertical',
    });
  });
};

/**
 * Adds a rounded rect node for the model's project dependency, directly below `currentNode`,
 * plus a cubic-vertical edge to it. Does nothing when the project has no url.
 */
const addProjectDependency = (
  data: ModelDepsData,
  currentNode: NodeConfig,
  nodes: NodeConfig[],
  edges: EdgeConfig[],
) => {
  const { project_dependency, id } = data;
  if (project_dependency?.url) {
    const node = { ...project_dependency };
    node.id = `$P_${id}_${node.url}_${node.branch}`;
    node.model_type = NodeType.Project;
    node.type = 'rect';
    node.label = fittingString(node.name, nodeWidth - labelPadding, nodeFontSize);
    node.style = getStyle(NodeType.Project);
    node.style.radius = nodeHeight / 2;
    node.x = currentNode.x;
    node.y = currentNode.y! + nodeHeight + vGap;
    nodes.push(node);
    edges.push({
      source: currentNode.id,
      target: node.id,
      sourceAnchor: 3,
      targetAnchor: 2,
      type: 'cubic-vertical',
    });
  }
};

/**
 * Returns true when two centre-based rectangles intersect (touching edges count).
 * Two axis-aligned boxes overlap iff, on each axis, their centres are no farther apart than the
 * sum of their half-extents.
 */
function isRectanglesOverlap(rect1: Rect, rect2: Rect) {
  return (
    Math.abs(rect1.x - rect2.x) <= (rect1.width + rect2.width) / 2 &&
    Math.abs(rect1.y - rect2.y) <= (rect1.height + rect2.height) / 2
  );
}

/**
 * Returns `childrenRect` if it overlaps any of the given dataset nodes, otherwise null.
 */
function isChildrenOverlapDataset(nodes: NodeConfig[], childrenRect: Rect) {
  for (const node of nodes) {
    const rect = { x: node.x!, y: node.y!, width: ellipseWidth, height: nodeHeight };
    if (isRectanglesOverlap(rect, childrenRect)) {
      return childrenRect;
    }
  }
  return null;
}

/**
 * Shifts the current dataset row left when it overlaps `node`.
 * After the shift, the last dataset sits one node width plus half an hGap left of `node`.
 * Currently not called.
 */
function adjustDatasetPosition(node: NodeConfig) {
  const nodeRect = {
    x: node.x!,
    y: node.y!,
    width: nodeWidth,
    height: nodeHeight,
  };

  const overlapRect = isChildrenOverlapDataset(datasetNodes, nodeRect);
  if (overlapRect) {
    const adjustRect = {
      x: overlapRect.x - nodeWidth - hGap / 2,
      y: overlapRect.y,
      width: overlapRect.width,
      height: overlapRect.height,
    };
    // Non-empty here: an overlap was found among datasetNodes.
    const lastNode = datasetNodes[datasetNodes.length - 1];
    const distance = lastNode.x! - adjustRect.x;
    datasetNodes.forEach((item) => {
      item.x = item.x! - distance;
    });
  }
}

/**
 * Breadth-first traversal of the tree.
 * Sets each node's `level` (root = 0) and returns all nodes in level order.
 */
export function traverseHierarchically(data: ModelDepsData | undefined): ModelDepsData[] {
  if (!data) return [];
  data.level = 0;
  const result: ModelDepsData[] = [data];
  let index = 0;
  while (index < result.length) {
    const item = result[index];
    if (item.children) {
      item.children.forEach((child) => {
        child.level = item.level + 1;
        result.push(child);
      });
    }
    index++;
  }
  return result;
}

/**
 * Returns the node that immediately follows `node` in `nodes`, but only if it is on the same
 * level; otherwise null. Assumes `nodes` is in level order (as produced by traverseHierarchically).
 */
export function getSameHierarchyNextNode(node: ModelDepsData, nodes: ModelDepsData[]) {
  const index = nodes.findIndex((item) => item.id === node.id);
  if (index >= 0 && index < nodes.length - 1) {
    const nextNode = nodes[index + 1];
    if (nextNode.level === node.level) {
      return nextNode;
    }
  }
  return null;
}

/**
 * Returns how far the widest dataset row at `level` extends past its own node on each side.
 * Only expanded nodes count, and the result is 0 when there are none.
 * Uses the same slot width (ellipseWidth + datasetHGap) as addDatasetDependency.
 */
export function getHierarchyWidth(level: number, nodes: ModelDepsData[]) {
  let maxLen = 0;
  for (const item of nodes) {
    if (item.level === level && item.expanded === true) {
      maxLen = Math.max(maxLen, item.datasetLen);
    }
  }
  return Math.max(((maxLen - 1) * (ellipseWidth + datasetHGap)) / 2, 0);
}