You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

utils.tsx 9.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. import { changePropertyName, fittingString } from '@/utils';
  2. import { EdgeConfig, GraphData, LayoutConfig, NodeConfig, TreeGraphData, Util } from '@antv/g6';
  3. // @ts-ignore
  4. import Hierarchy from '@antv/hierarchy';
  /** Width of a model / project node. */
  export const nodeWidth = 90;
  /** Height shared by every node shape (model, project, dataset). */
  export const nodeHeight = 40;
  /** Vertical spacing unit: one node height plus a 20-unit margin. */
  export const vGap = nodeHeight + 20;
  /** Horizontal spacing unit: exactly one node width. */
  export const hGap = nodeWidth;
  /** Width of the ellipse that represents a dataset node (same as a model node). */
  export const ellipseWidth = nodeWidth;
  /** Width subtracted from a node's width before a label is truncated to fit it. */
  export const labelPadding = 30;
  /** Font size handed to `fittingString` when measuring label text. */
  export const nodeFontSize = 8;

  // Dataset nodes created for the current model. Kept at module level so that child-model
  // nodes visited later can shift the dataset row out of their way.
  const datasetNodes: NodeConfig[] = [];
  /** Role of a node in the dependency graph; `getStyle` picks the node's fill from it. */
  export enum NodeType {
    /** The model whose dependencies are being displayed. */
    current = 'current',
    /** An ancestor of the current model. */
    parent = 'parent',
    /** A descendant of the current model. */
    children = 'children',
    /** The project the current model depends on. */
    project = 'project',
    /** A training dataset of the current model. */
    trainDataset = 'trainDataset',
    /** A test dataset of the current model. */
    testDataset = 'testDataset',
  }
  /** Axis-aligned rectangle described by its centre point and its size. */
  export type Rect = {
    /** x coordinate of the rectangle's centre. */
    x: number;
    /** y coordinate of the rectangle's centre. */
    y: number;
    width: number;
    height: number;
  };

  /** Training-task record attached to a model (`ModelDepsAPIData.train_task`). */
  export type TrainTask = {
    ins_id: number;
    name: string;
    task_id: string;
  };

  /**
   * A dataset the model depends on. The same shape is used for `train_dataset` and
   * `test_dataset` entries; `model_type` tells the two apart once they become graph nodes.
   */
  export interface TrainDataset extends NodeConfig {
    dataset_id: number;
    dataset_name: string;
    dataset_version: string;
    model_type: NodeType.testDataset | NodeType.trainDataset;
  }

  /** Project dependency of a model, identified by its `url` and `branch`. */
  export interface ProjectDependency extends NodeConfig {
    url: string;
    name: string;
    branch: string;
    model_type: NodeType.project;
  }

  /** Descriptive details of a model version; `name` is shown on the node label. */
  export type ModalDetail = {
    name: string;
    available_range: number;
    file_name: string;
    file_size: string;
    description: string;
    model_type_name: string;
    model_tag_name: string;
    create_time: string;
  };

  /**
   * Model-dependency record as returned by the backend: one model version together with its
   * datasets, project, training task and (recursively) its parent and child models.
   */
  export interface ModelDepsAPIData {
    current_model_id: number;
    version: string;
    workflow_id: number;
    exp_ins_id: number;
    model_type: NodeType.children | NodeType.current | NodeType.parent;
    current_model_name: string;
    project_dependency: ProjectDependency;
    test_dataset: TrainDataset[];
    train_dataset: TrainDataset[];
    train_task: TrainTask;
    model_version_dependcy_vo: ModalDetail;
    children_models: ModelDepsAPIData[];
    parent_models: ModelDepsAPIData[];
  }

  /**
   * Tree form of `ModelDepsAPIData` used for layout: `children_models` becomes `children`,
   * and the G6 `TreeGraphData` fields (`id`, `label`, `style`, …) are filled in during normalization.
   */
  export interface ModelDepsData extends Omit<ModelDepsAPIData, 'children_models'>, TreeGraphData {
    children: ModelDepsData[];
  }
  /**
   * Recursively tags every node of a child-model subtree as a `children` node and fills in the
   * G6 fields it needs (`id`, `label`, `style`). The nodes are mutated in place.
   *
   * @param data - child models; a non-array value (e.g. a missing `children`) is ignored.
   */
  export function normalizeChildren(data: ModelDepsData[]) {
    if (!Array.isArray(data)) {
      return;
    }
    data.forEach((child) => {
      child.model_type = NodeType.children;
      child.id = `$M_${child.current_model_id}_${child.version}`;
      child.label = getLabel(child);
      child.style = getStyle(NodeType.children);
      // Descend into grandchildren, and so on.
      normalizeChildren(child.children);
    });
  }
  /**
   * Builds the two-line label of a model node: the model name on the first line and the
   * version on the second. Each line is truncated with `fittingString` to the node width minus
   * `labelPadding`.
   *
   * Missing data yields an empty line. A missing `model_version_dependcy_vo` no longer throws,
   * and a missing `version` no longer prints the text "undefined".
   *
   * @param node - raw or normalized model record.
   * @returns the label text, lines separated by `\n`.
   */
  export function getLabel(node: ModelDepsData | ModelDepsAPIData) {
    const maxWidth = nodeWidth - labelPadding;
    const name = `${node.model_version_dependcy_vo?.name ?? ''}`;
    const version = `${node.version ?? ''}`;
    return (
      fittingString(name, maxWidth, nodeFontSize) +
      '\n' +
      fittingString(version, maxWidth, nodeFontSize)
    );
  }
  /**
   * Returns the G6 style for a node of the given type. Only `fill` is set:
   * - model and project nodes get a G6 linear-gradient string (`l(0) …`);
   * - dataset nodes get a solid colour;
   * - unknown types get an empty string.
   *
   * A new object is returned on every call, so callers may add properties (such as `radius`)
   * without affecting other nodes.
   *
   * @param model_type - role of the node.
   */
  export function getStyle(model_type: NodeType) {
    const fillByType = new Map<NodeType, string>([
      [NodeType.current, 'l(0) 0:#72a1ff 1:#1664ff'],
      [NodeType.parent, 'l(0) 0:#93dfd1 1:#43c9b1'],
      [NodeType.children, 'l(0) 0:#72b4ff 1:#169aff'],
      [NodeType.project, 'l(0) 0:#b3a9ff 1:#8981ff'],
      [NodeType.trainDataset, '#a5d878'],
      [NodeType.testDataset, '#d8b578'],
    ]);
    const fill = fillByType.get(model_type) ?? '';
    return { fill };
  }
  /**
   * Converts the backend response into the tree consumed by `getGraphData`.
   *
   * 1. The requested model becomes the `current` node. Its `children_models` key is renamed to
   *    `children` via `changePropertyName`, and the children are normalized recursively.
   *    NOTE(review): nested levels rely on `changePropertyName` renaming deeply; confirm.
   * 2. The ancestry is then folded upward. The first entry of each `parent_models` list becomes
   *    a `parent` node whose only child is the tree built so far (with its `parent_models`
   *    cleared). This repeats until a model without parents is reached.
   *
   * @param apiData - backend response for the current model.
   * @returns root of the normalized tree; this is the current model itself when it has no parents.
   */
  export function normalizeTreeData(apiData: ModelDepsAPIData): ModelDepsData {
    const current = changePropertyName(apiData, {
      children_models: 'children',
    }) as ModelDepsData;
    current.model_type = NodeType.current;
    current.id = `$M_${current.current_model_id}_${current.version}`;
    current.label = getLabel(current);
    current.style = getStyle(NodeType.current);
    normalizeChildren(current.children);

    let root = current;
    for (
      let ancestors = root.parent_models || [];
      ancestors.length > 0;
      ancestors = root.parent_models || []
    ) {
      const parent = ancestors[0];
      // Only the first parent of each level is followed, so the ancestry is a single chain.
      root = {
        ...parent,
        model_type: NodeType.parent,
        id: `$M_${parent.current_model_id}_${parent.version}`,
        label: getLabel(parent),
        style: getStyle(NodeType.parent),
        children: [{ ...root, parent_models: [] }],
      };
    }
    return root;
  }
  /**
   * Lays out a normalized model-dependency tree and converts it into G6 graph data.
   *
   * Positions come from `@antv/hierarchy`'s compactBox layout with direction `LR`. While the
   * laid-out tree is walked:
   * - the current model also gets its dataset nodes (one row at a smaller y) and its project
   *   node (at a larger y);
   * - every child model is checked against that dataset row, and the row is shifted when they
   *   overlap;
   * - every tree node is emitted with its layout coordinates, plus an edge from its parent.
   *
   * @param data - tree produced by `normalizeTreeData`.
   * @returns the nodes and edges to hand to a G6 graph.
   */
  export function getGraphData(data: ModelDepsData): GraphData {
    // Clear dataset nodes left over from a previous call, so a tree without a current node can
    // never move nodes that belong to an earlier graph.
    datasetNodes.length = 0;

    const config = {
      direction: 'LR',
      getHeight: () => nodeHeight,
      getWidth: () => nodeWidth,
      getVGap: () => vGap / 2,
      getHGap: () => hGap / 2,
    };
    // Compute coordinates for every tree node.
    const treeLayoutData: LayoutConfig = Hierarchy.compactBox(data, config);
    const nodes: NodeConfig[] = [];
    const edges: EdgeConfig[] = [];
    // NOTE(review): assumes traverseTree visits a node before its descendants, so the dataset
    // row exists before the current model's children are checked against it. Confirm.
    Util.traverseTree(treeLayoutData, (node: NodeConfig, parent: NodeConfig) => {
      const model = node.data as ModelDepsData;
      if (model.model_type === NodeType.current) {
        // Only the current model shows its datasets and project.
        addDatasetDependency(model, node, nodes, edges);
        addProjectDependency(model, node, nodes, edges);
      } else if (model.model_type === NodeType.children) {
        adjustDatasetPosition(node);
      }
      nodes.push({
        ...model,
        x: node.x,
        y: node.y,
      });
      if (parent) {
        edges.push({
          source: parent.id,
          target: node.id,
        });
      }
    });
    return { nodes, edges };
  }
  /**
   * Adds the current model's training and test datasets to the graph. Each dataset becomes an
   * ellipse node. All of them sit in one row at a smaller y than the current node, centred on it
   * horizontally, and each is linked to the current node with a `cubic-vertical` edge.
   *
   * Side effects:
   * - every dataset item in `data` gets its G6 `id`, `model_type` and `style`;
   * - the module-level `datasetNodes` list is refilled with the created nodes, so that
   *   `adjustDatasetPosition` can move them later.
   *
   * @param data - the current model; a missing (`null`/`undefined`) dataset list is treated as empty.
   * @param currentNode - laid-out node of the current model (provides `id`, `x`, `y`).
   * @param nodes - output node list; appended to.
   * @param edges - output edge list; appended to.
   */
  const addDatasetDependency = (
    data: ModelDepsData,
    currentNode: NodeConfig,
    nodes: NodeConfig[],
    edges: EdgeConfig[],
  ) => {
    // NOTE(review): the backend may omit these lists, just as `project_dependency` is
    // optional-chained in addProjectDependency. Confirm against the API.
    const trainDatasets = data.train_dataset ?? [];
    const testDatasets = data.test_dataset ?? [];
    trainDatasets.forEach((item) => {
      item.id = `$DTrain_${item.dataset_id}_${item.dataset_version}`;
      item.model_type = NodeType.trainDataset;
      item.style = getStyle(NodeType.trainDataset);
    });
    testDatasets.forEach((item) => {
      item.id = `$DTest_${item.dataset_id}_${item.dataset_version}`;
      item.model_type = NodeType.testDataset;
      item.style = getStyle(NodeType.testDataset);
    });

    datasetNodes.length = 0;
    const datasets = [...trainDatasets, ...testDatasets];
    // Index of the row's centre (a half-integer for even counts), used to centre the row.
    const half = datasets.length / 2 - 0.5;
    datasets.forEach((item, index) => {
      const node = { ...item };
      node.type = 'ellipse';
      node.size = [ellipseWidth, nodeHeight];
      node.label =
        fittingString(node.dataset_name, ellipseWidth - labelPadding, nodeFontSize) +
        '\n' +
        fittingString(node.dataset_version, ellipseWidth - labelPadding, nodeFontSize);
      // Neighbouring datasets are one ellipse width plus 20 apart.
      node.x = currentNode.x! - (half - index) * (ellipseWidth + 20);
      node.y = currentNode.y! - nodeHeight - vGap;
      nodes.push(node);
      datasetNodes.push(node);
      edges.push({
        source: currentNode.id,
        target: node.id,
        sourceAnchor: 2,
        targetAnchor: 3,
        type: 'cubic-vertical',
      });
    });
  };
  /**
   * Adds the current model's project dependency to the graph. The project is drawn as a
   * rounded `rect` node directly in line with the current node horizontally, at a larger y,
   * and is linked to it with a `cubic-vertical` edge. Nothing is added when the project has
   * no `url`.
   *
   * @param data - the current model.
   * @param currentNode - laid-out node of the current model (provides `id`, `x`, `y`).
   * @param nodes - output node list; appended to.
   * @param edges - output edge list; appended to.
   */
  const addProjectDependency = (
    data: ModelDepsData,
    currentNode: NodeConfig,
    nodes: NodeConfig[],
    edges: EdgeConfig[],
  ) => {
    const project = data.project_dependency;
    if (!project?.url) {
      return;
    }
    const projectNode: ProjectDependency = {
      ...project,
      id: `$P_${project.url}_${project.branch}`,
      model_type: NodeType.project,
      type: 'rect',
      label: fittingString(project.name, nodeWidth - labelPadding, nodeFontSize),
      // Corner radius of half the height gives a pill shape.
      style: { ...getStyle(NodeType.project), radius: nodeHeight / 2 },
      x: currentNode.x,
      y: currentNode.y! + nodeHeight + vGap,
    };
    nodes.push(projectNode);
    edges.push({
      source: currentNode.id,
      target: projectNode.id,
      sourceAnchor: 3,
      targetAnchor: 2,
      type: 'cubic-vertical',
    });
  };
  /**
   * Returns whether two axis-aligned rectangles, each given by its centre and size, intersect.
   * Rectangles that only touch along an edge or at a corner count as overlapping.
   *
   * The previous version compared only rect2's minimum edges with rect1's maximum edges. It
   * therefore reported a rect2 lying entirely before rect1 on either axis as overlapping.
   */
  function isRectanglesOverlap(rect1: Rect, rect2: Rect) {
    const minX1 = rect1.x - rect1.width / 2;
    const maxX1 = rect1.x + rect1.width / 2;
    const minY1 = rect1.y - rect1.height / 2;
    const maxY1 = rect1.y + rect1.height / 2;
    const minX2 = rect2.x - rect2.width / 2;
    const maxX2 = rect2.x + rect2.width / 2;
    const minY2 = rect2.y - rect2.height / 2;
    const maxY2 = rect2.y + rect2.height / 2;
    // Two intervals overlap iff each starts no later than the other ends; both axes must overlap.
    return minX2 <= maxX1 && minX1 <= maxX2 && minY2 <= maxY1 && minY1 <= maxY2;
  }
  /**
   * Checks whether a child-model node overlaps any dataset node.
   *
   * Dataset nodes are drawn as ellipses of size `ellipseWidth` × `nodeHeight` (see
   * addDatasetDependency). Their bounding box is therefore measured with `ellipseWidth`, not
   * `nodeWidth`.
   *
   * @param nodes - dataset nodes, positioned by their centre.
   * @param childrenRect - footprint of the child node, centred on its position.
   * @returns `childrenRect` itself when it overlaps at least one dataset node, otherwise `null`.
   */
  function isChildrenOverlapDataset(nodes: NodeConfig[], childrenRect: Rect) {
    const overlapsAny = nodes.some((node) =>
      isRectanglesOverlap(
        { x: node.x!, y: node.y!, width: ellipseWidth, height: nodeHeight },
        childrenRect,
      ),
    );
    return overlapsAny ? childrenRect : null;
  }
  /**
   * If the given child-model node overlaps the dataset row, moves every dataset node
   * horizontally by the same amount. After the move, the last dataset node (the one with the
   * largest x, since addDatasetDependency places them left to right) sits `nodeWidth + hGap / 2`
   * to the left of the child's centre.
   *
   * @param node - laid-out child-model node; its footprint is `nodeWidth` × `nodeHeight`.
   */
  function adjustDatasetPosition(node: NodeConfig) {
    const childRect: Rect = {
      x: node.x!,
      y: node.y!,
      width: nodeWidth,
      height: nodeHeight,
    };
    const overlapRect = isChildrenOverlapDataset(datasetNodes, childRect);
    if (!overlapRect) {
      return;
    }
    // Where the last dataset node should end up: one node width plus half a gap left of the child.
    const targetX = overlapRect.x - nodeWidth - hGap / 2;
    const shift = datasetNodes[datasetNodes.length - 1].x! - targetX;
    for (const dataset of datasetNodes) {
      dataset.x = dataset.x! - shift;
    }
  }