You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

index.tsx 6.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. import TableColTitle from '@/components/TableColTitle';
  2. import TrialStatusCell from '@/pages/HyperParameter/components/TrialStatusCell';
  3. import { HyperParameterTrial } from '@/pages/HyperParameter/types';
  4. import { getExpMetricsReq } from '@/services/hyperParameter';
  5. import { to } from '@/utils/promise';
  6. import SessionStorage from '@/utils/sessionStorage';
  7. import tableCellRender, { TableCellValueType } from '@/utils/table';
  8. import { useNavigate } from '@umijs/max';
  9. import { App, Button, Table, type TableProps } from 'antd';
  10. import classNames from 'classnames';
  11. import { useEffect, useState } from 'react';
  12. import TrialFileTree from '../TrialFileTree';
  13. import styles from './index.less';
  14. type ExperimentHistoryProps = {
  15. trialList?: HyperParameterTrial[];
  16. };
  17. function ExperimentHistory({ trialList = [] }: ExperimentHistoryProps) {
  18. const [expandedRowKeys, setExpandedRowKeys] = useState<string[]>([]);
  19. const [selectedRowKeys, setSelectedRowKeys] = useState<React.Key[]>([]);
  20. const { message } = App.useApp();
  21. const [tableData, setTableData] = useState<HyperParameterTrial[]>([]);
  22. const [loading, setLoading] = useState(false);
  23. // 防止 Tabs 卡顿
  24. useEffect(() => {
  25. setLoading(true);
  26. setTimeout(() => {
  27. setTableData(trialList);
  28. setLoading(false);
  29. }, 500);
  30. }, [trialList]);
  31. // 计算 column
  32. const first: HyperParameterTrial | undefined = trialList ? trialList[0] : undefined;
  33. const config: Record<string, any> = first?.config ?? {};
  34. const metricAnalysis: Record<string, any> = first?.metric_analysis ?? {};
  35. const paramsNames = Object.keys(config);
  36. const metricNames = Object.keys(metricAnalysis);
  37. const navigate = useNavigate();
  38. const trialColumns: TableProps<HyperParameterTrial>['columns'] = [
  39. {
  40. title: '序号',
  41. dataIndex: 'index',
  42. key: 'index',
  43. width: 100,
  44. fixed: 'left',
  45. render: (_text, record, index: number) => {
  46. return (
  47. <div className={styles['cell-index']}>
  48. <span className={styles['cell-index__text']}>{index + 1}</span>
  49. {record.is_best && <span className={styles['cell-index__best-tag']}>最佳</span>}
  50. </div>
  51. );
  52. },
  53. },
  54. {
  55. title: '基本信息',
  56. align: 'center',
  57. children: [
  58. {
  59. title: '运行次数',
  60. dataIndex: 'training_iteration',
  61. key: 'training_iteration',
  62. width: 120,
  63. fixed: 'left',
  64. render: tableCellRender(false),
  65. },
  66. {
  67. title: '平均时长(秒)',
  68. dataIndex: 'time_avg',
  69. key: 'time_avg',
  70. width: 150,
  71. fixed: 'left',
  72. render: tableCellRender(false, TableCellValueType.Custom, {
  73. format: (value = 0) => Number(value).toFixed(2),
  74. }),
  75. },
  76. {
  77. title: '状态',
  78. dataIndex: 'status',
  79. key: 'status',
  80. width: 120,
  81. fixed: 'left',
  82. render: TrialStatusCell,
  83. },
  84. ],
  85. },
  86. ];
  87. if (paramsNames.length) {
  88. trialColumns.push({
  89. title: '运行参数',
  90. dataIndex: 'config',
  91. key: 'config',
  92. align: 'center',
  93. children: paramsNames.map((name) => ({
  94. title: <TableColTitle title={name} />,
  95. dataIndex: ['config', name],
  96. key: name,
  97. width: 120,
  98. align: 'center',
  99. render: tableCellRender(true),
  100. })),
  101. });
  102. }
  103. if (metricNames.length) {
  104. trialColumns.push({
  105. title: `指标分析(${first?.metric ?? ''})`,
  106. dataIndex: 'metrics',
  107. key: 'metrics',
  108. align: 'center',
  109. children: metricNames.map((name) => ({
  110. title: <TableColTitle title={name} />,
  111. dataIndex: ['metric_analysis', name],
  112. key: name,
  113. width: 120,
  114. align: 'center',
  115. render: tableCellRender(true),
  116. })),
  117. });
  118. }
  119. // 自定义展开视图
  120. const expandedRowRender = (record: HyperParameterTrial) => {
  121. return <TrialFileTree title="寻优结果" file={record.file}></TrialFileTree>;
  122. };
  123. // 展开实例
  124. const handleExpandChange = (expanded: boolean, record: HyperParameterTrial) => {
  125. if (expanded) {
  126. setExpandedRowKeys([record.trial_id]);
  127. } else {
  128. setExpandedRowKeys([]);
  129. }
  130. };
  131. // 选择行
  132. const rowSelection: TableProps<HyperParameterTrial>['rowSelection'] = {
  133. type: 'checkbox',
  134. columnWidth: 48,
  135. fixed: 'left',
  136. selectedRowKeys,
  137. onChange: (selectedRowKeys: React.Key[]) => {
  138. setSelectedRowKeys(selectedRowKeys);
  139. },
  140. };
  141. // 对比
  142. const handleComparisonClick = () => {
  143. if (selectedRowKeys.length < 1) {
  144. message.error('请至少选择一项');
  145. return;
  146. }
  147. getExpMetrics();
  148. };
  149. // 获取对比 url
  150. const getExpMetrics = async () => {
  151. const [res] = await to(getExpMetricsReq(selectedRowKeys));
  152. if (res && res.data) {
  153. const url = res.data;
  154. SessionStorage.setItem(SessionStorage.aimUrlKey, url);
  155. navigate('compare-visual');
  156. }
  157. };
  158. return (
  159. <div className={styles['experiment-history']}>
  160. <div className={styles['experiment-history__content']}>
  161. <Button type="default" onClick={handleComparisonClick}>
  162. 可视化对比
  163. </Button>
  164. <div
  165. className={classNames(
  166. 'vertical-scroll-table-no-page',
  167. styles['experiment-history__content__table'],
  168. )}
  169. >
  170. <Table
  171. loading={loading}
  172. rowClassName={(record) => (record.is_best ? styles['table-best-row'] : '')}
  173. dataSource={tableData}
  174. columns={trialColumns}
  175. pagination={false}
  176. bordered={true}
  177. scroll={{ y: 'calc(100% - 110px)', x: '100%' }}
  178. rowKey="trial_id"
  179. expandable={{
  180. expandedRowRender: expandedRowRender,
  181. onExpand: handleExpandChange,
  182. expandedRowKeys: expandedRowKeys,
  183. rowExpandable: (record: HyperParameterTrial) => !!record.file,
  184. }}
  185. rowSelection={rowSelection}
  186. />
  187. </div>
  188. </div>
  189. </div>
  190. );
  191. }
  192. export default ExperimentHistory;