import { changePropertyName, fittingString } from '@/utils';
import { EdgeConfig, GraphData, LayoutConfig, NodeConfig, TreeGraphData, Util } from '@antv/g6';
// @ts-ignore
import Hierarchy from '@antv/hierarchy';

/** Width of a model node. */
export const nodeWidth = 90;
/** Height of a model node. */
export const nodeHeight = 40;
/** Vertical gap used when placing dataset / project nodes and (halved) by the tree layout. */
export const vGap = nodeHeight + 20;
/** Horizontal gap used (halved) by the tree layout and when shifting dataset nodes. */
export const hGap = nodeWidth;
/** Width of a dataset ellipse node. */
export const ellipseWidth = nodeWidth;
/** Amount subtracted from a node's width to get the width available to its label. */
export const labelPadding = 30;
/** Font size handed to `fittingString` when fitting labels. */
export const nodeFontSize = 8;

// Dataset nodes created for the current model during the latest getGraphData run.
// Shared between addDatasetDependency (fills it) and adjustDatasetPosition (moves its items).
const datasetNodes: NodeConfig[] = [];

/** Kinds of node that can appear in the dependency graph. */
export enum NodeType {
  current = 'current',
  parent = 'parent',
  children = 'children',
  project = 'project',
  trainDataset = 'trainDataset',
  testDataset = 'testDataset',
}

/** Axis-aligned rectangle described by its center and size. */
export type Rect = {
  x: number; // x coordinate of the rectangle's center
  y: number; // y coordinate of the rectangle's center
  width: number;
  height: number;
};

export type TrainTask = {
  ins_id: number;
  name: string;
  task_id: string;
};

export interface TrainDataset extends NodeConfig {
  dataset_id: number;
  dataset_name: string;
  dataset_version: string;
  model_type: NodeType.testDataset | NodeType.trainDataset;
}

export interface ProjectDependency extends NodeConfig {
  url: string;
  name: string;
  branch: string;
  model_type: NodeType.project;
}

export type ModalDetail = {
  name: string;
  available_range: number;
  file_name: string;
  file_size: string;
  description: string;
  model_type_name: string;
  model_tag_name: string;
  create_time: string;
};

/** Shape of the model-dependency data as returned by the backend. */
export interface ModelDepsAPIData {
  current_model_id: number;
  version: string;
  workflow_id: number;
  exp_ins_id: number;
  model_type: NodeType.children | NodeType.current | NodeType.parent;
  current_model_name: string;
  project_dependency: ProjectDependency;
  test_dataset: TrainDataset[];
  train_dataset: TrainDataset[];
  train_task: TrainTask;
  model_version_dependcy_vo: ModalDetail;
  children_models: ModelDepsAPIData[];
  parent_models: ModelDepsAPIData[];
}

/**
 * Tree-shaped model data: the API shape with `children_models` replaced by `children`
 * (see normalizeTreeData), combined with G6's TreeGraphData.
 */
export interface ModelDepsData
  extends Omit<ModelDepsAPIData, 'children_models'>,
    TreeGraphData {
  children: ModelDepsData[];
}

/**
 * Normalizes descendant models in place: marks each one as a `children` node and
 * assigns its id, label and style, then recurses into its own `children`.
 *
 * @param data Descendant models; anything that is not an array is ignored.
 */
export function normalizeChildren(data: ModelDepsData[]) {
  if (!Array.isArray(data)) {
    return;
  }
  data.forEach((item) => {
    item.model_type = NodeType.children;
    item.id = `$M_${item.current_model_id}_${item.version}`;
    item.label = getLabel(item);
    item.style = getStyle(NodeType.children);
    normalizeChildren(item.children);
  });
}

/**
 * Builds the two-line label of a model node: the model name on the first line and the
 * version on the second, each passed through `fittingString` with the label width
 * (`nodeWidth - labelPadding`) and `nodeFontSize`.
 *
 * @param node Model whose label is built.
 * @returns The label text, with the two parts separated by a newline.
 */
export function getLabel(node: ModelDepsData | ModelDepsAPIData) {
  const availableWidth = nodeWidth - labelPadding;
  // Optional chaining: a missing detail object yields an empty name instead of a TypeError.
  const name = `${node.model_version_dependcy_vo?.name ?? ''}`;
  return (
    fittingString(name, availableWidth, nodeFontSize) +
    '\n' +
    fittingString(`${node.version}`, availableWidth, nodeFontSize)
  );
}

/**
 * Returns a fresh style object holding the fill for the given node type.
 * Callers may mutate the result (addProjectDependency adds `radius`), so a new object
 * is created on every call.
 *
 * @param model_type Node type to style.
 * @returns `{ fill }`, where `fill` is an empty string for an unrecognized type.
 */
export function getStyle(model_type: NodeType) {
  let fill = '';
  switch (model_type) {
    case NodeType.current:
      fill = 'l(0) 0:#72a1ff 1:#1664ff';
      break;
    case NodeType.parent:
      fill = 'l(0) 0:#93dfd1 1:#43c9b1';
      break;
    case NodeType.children:
      fill = 'l(0) 0:#72b4ff 1:#169aff';
      break;
    case NodeType.project:
      fill = 'l(0) 0:#b3a9ff 1:#8981ff';
      break;
    case NodeType.trainDataset:
      fill = '#a5d878';
      break;
    case NodeType.testDataset:
      fill = '#d8b578';
      break;
    default:
      break;
  }
  return {
    fill,
  };
}

/**
 * Converts the data returned by the backend into tree-shaped data.
 *
 * The current model's `children_models` become `children`; each ancestor found through
 * `parent_models[0]` is then wrapped around the tree built so far, so the outermost
 * ancestor ends up as the root.
 *
 * @param apiData Backend data for the current model.
 * @returns The root of the resulting tree.
 */
export function normalizeTreeData(apiData: ModelDepsAPIData): ModelDepsData {
  // Rename children_models to children.
  let normalizedData = changePropertyName(apiData, {
    children_models: 'children',
  }) as ModelDepsData;

  // Fill in the fields of the current model.
  normalizedData.model_type = NodeType.current;
  normalizedData.id = `$M_${normalizedData.current_model_id}_${normalizedData.version}`;
  normalizedData.label = getLabel(normalizedData);
  normalizedData.style = getStyle(NodeType.current);
  normalizeChildren(normalizedData.children as ModelDepsData[]);

  // Turn parent_models into tree structure: only the first parent of each level is followed.
  let parent_models = normalizedData.parent_models || [];
  while (parent_models.length > 0) {
    const parent = parent_models[0];
    normalizedData = {
      ...parent,
      model_type: NodeType.parent,
      id: `$M_${parent.current_model_id}_${parent.version}`,
      label: getLabel(parent),
      style: getStyle(NodeType.parent),
      children: [
        {
          ...normalizedData,
          parent_models: [],
        },
      ],
    };
    parent_models = normalizedData.parent_models || [];
  }

  // The value must stay on the same line as `return`; otherwise automatic semicolon
  // insertion turns this into `return;` and the function yields undefined.
  return normalizedData;
}

/**
 * Lays the tree out with Hierarchy's `compactBox` to compute coordinates, then converts
 * it into G6 graph data (flat node and edge lists).
 *
 * For the current model, dataset and project nodes are added as well; for every
 * `children` node, the dataset nodes are shifted if the child overlaps one of them.
 *
 * @param data Root of the tree produced by normalizeTreeData.
 * @returns The nodes and edges to render.
 */
export function getGraphData(data: ModelDepsData): GraphData {
  const config = {
    direction: 'LR',
    getHeight: () => nodeHeight,
    getWidth: () => nodeWidth,
    getVGap: () => vGap / 2,
    getHGap: () => hGap / 2,
  };

  // Start from a clean slate so dataset nodes of a previous call cannot affect this one.
  datasetNodes.length = 0;

  // Tree layout computes the coordinates.
  const treeLayoutData: LayoutConfig = Hierarchy['compactBox'](data, config);
  const nodes: NodeConfig[] = [];
  const edges: EdgeConfig[] = [];

  Util.traverseTree(treeLayoutData, (node: NodeConfig, parent: NodeConfig) => {
    const model = node.data as ModelDepsData;

    // Only the current model shows its datasets and project.
    if (model.model_type === NodeType.current) {
      addDatasetDependency(model, node, nodes, edges);
      addProjectDependency(model, node, nodes, edges);
    } else if (model.model_type === NodeType.children) {
      adjustDatasetPosition(node);
    }

    nodes.push({
      ...model,
      x: node.x,
      y: node.y,
    });

    if (parent) {
      edges.push({
        source: parent.id,
        target: node.id,
      });
    }
  });

  return { nodes, edges };
}

/**
 * Converts the current model's datasets into G6 nodes and edges.
 *
 * The original dataset items get their id, model_type and style assigned in place.
 * A copy of each is then positioned in a single row centered on the current node's x,
 * at `nodeHeight + vGap` less than its y, and recorded in `datasetNodes`.
 *
 * @param data Current model data.
 * @param currentNode Laid-out node of the current model.
 * @param nodes Node list to append to.
 * @param edges Edge list to append to.
 */
const addDatasetDependency = (
  data: ModelDepsData,
  currentNode: NodeConfig,
  nodes: NodeConfig[],
  edges: EdgeConfig[],
) => {
  // Treat a missing dataset list as empty instead of throwing.
  const train_dataset = data.train_dataset ?? [];
  const test_dataset = data.test_dataset ?? [];

  train_dataset.forEach((item) => {
    item.id = `$DTrain_${item.dataset_id}_${item.dataset_version}`;
    item.model_type = NodeType.trainDataset;
    item.style = getStyle(NodeType.trainDataset);
  });
  test_dataset.forEach((item) => {
    item.id = `$DTest_${item.dataset_id}_${item.dataset_version}`;
    item.model_type = NodeType.testDataset;
    item.style = getStyle(NodeType.testDataset);
  });

  datasetNodes.length = 0;

  const allDatasets = [...train_dataset, ...test_dataset];
  // Index (possibly fractional) of the row's midpoint; the item there sits at currentNode.x.
  const half = allDatasets.length / 2 - 0.5;
  const labelWidth = ellipseWidth - labelPadding;

  allDatasets.forEach((item, index) => {
    const node = { ...item };
    node.type = 'ellipse';
    node.size = [ellipseWidth, nodeHeight];
    node.label =
      fittingString(node.dataset_name, labelWidth, nodeFontSize) +
      '\n' +
      fittingString(node.dataset_version, labelWidth, nodeFontSize);
    node.x = currentNode.x! - (half - index) * (ellipseWidth + 20);
    node.y = currentNode.y! - nodeHeight - vGap;

    nodes.push(node);
    datasetNodes.push(node);
    edges.push({
      source: currentNode.id,
      target: node.id,
      sourceAnchor: 2,
      targetAnchor: 3,
      type: 'cubic-vertical',
    });
  });
};

/**
 * Converts the current model's project dependency into a G6 node and edge.
 * Nothing is added when there is no project dependency or it has no url.
 *
 * The node shares the current node's x and is placed at `nodeHeight + vGap` more than its y.
 *
 * @param data Current model data.
 * @param currentNode Laid-out node of the current model.
 * @param nodes Node list to append to.
 * @param edges Edge list to append to.
 */
const addProjectDependency = (
  data: ModelDepsData,
  currentNode: NodeConfig,
  nodes: NodeConfig[],
  edges: EdgeConfig[],
) => {
  const { project_dependency } = data;
  if (project_dependency?.url) {
    const node = { ...project_dependency };
    node.id = `$P_${node.url}_${node.branch}`;
    node.model_type = NodeType.project;
    node.type = 'rect';
    node.label = fittingString(node.name, nodeWidth - labelPadding, nodeFontSize);
    node.style = getStyle(NodeType.project);
    node.style.radius = nodeHeight / 2;
    node.x = currentNode.x;
    node.y = currentNode.y! + nodeHeight + vGap;

    nodes.push(node);
    edges.push({
      source: currentNode.id,
      target: node.id,
      sourceAnchor: 3,
      targetAnchor: 2,
      type: 'cubic-vertical',
    });
  }
};

/**
 * Tells whether two center-based rectangles intersect; touching edges count as intersecting.
 *
 * All four sides are compared. Checking only rect2's top-left corner against rect1's
 * bottom-right corner would also report rectangles lying entirely above or to the left
 * of rect1 as overlapping.
 */
function isRectanglesOverlap(rect1: Rect, rect2: Rect) {
  const a1x = rect1.x - rect1.width / 2;
  const a1y = rect1.y - rect1.height / 2;
  const a2x = rect1.x + rect1.width / 2;
  const a2y = rect1.y + rect1.height / 2;

  const b1x = rect2.x - rect2.width / 2;
  const b1y = rect2.y - rect2.height / 2;
  const b2x = rect2.x + rect2.width / 2;
  const b2y = rect2.y + rect2.height / 2;

  return a1x <= b2x && b1x <= a2x && a1y <= b2y && b1y <= a2y;
}

/**
 * Checks whether a child node's rectangle overlaps any of the given nodes, each taken as a
 * nodeWidth x nodeHeight rectangle centered on its x / y.
 *
 * @returns `childrenRect` itself when an overlap is found, otherwise `null`.
 */
function isChildrenOverlapDataset(nodes: NodeConfig[], childrenRect: Rect) {
  for (const node of nodes) {
    const rect = { x: node.x!, y: node.y!, width: nodeWidth, height: nodeHeight };
    if (isRectanglesOverlap(rect, childrenRect)) {
      return childrenRect;
    }
  }
  return null;
}

/**
 * Shifts the dataset nodes when the given child node overlaps one of them.
 *
 * Every dataset node is moved by the same distance, chosen so that the last dataset node
 * ends up at `node.x - nodeWidth - hGap / 2`.
 *
 * @param node Laid-out child model node.
 */
function adjustDatasetPosition(node: NodeConfig) {
  const nodeRect = {
    x: node.x!,
    y: node.y!,
    width: nodeWidth,
    height: nodeHeight,
  };
  const overlapRect = isChildrenOverlapDataset(datasetNodes, nodeRect);
  if (overlapRect) {
    const adjustRect = {
      x: overlapRect.x - nodeWidth - hGap / 2,
      y: overlapRect.y,
      width: overlapRect.width,
      height: overlapRect.height,
    };
    // An overlap implies datasetNodes is non-empty, so the last element exists.
    const lastNode = datasetNodes[datasetNodes.length - 1];
    const distance = lastNode.x! - adjustRect.x;
    datasetNodes.forEach((item) => {
      item.x = item.x! - distance;
    });
  }
}