From 1b497e1743aec4559c5df91635f3eadcd65170bb Mon Sep 17 00:00:00 2001 From: zhaowei Date: Fri, 30 May 2025 11:46:12 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E8=87=AA=E5=8A=A8=E6=9C=BA=E5=99=A8?= =?UTF-8?q?=E5=AD=A6=E4=B9=A0=E6=B7=BB=E5=8A=A0=E7=AE=97=E6=B3=95=E6=8F=8F?= =?UTF-8?q?=E8=BF=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../AutoML/components/AutoMLBasic/index.tsx | 40 +++++++++- .../components/CreateForm/ExecuteConfig.tsx | 65 +--------------- .../AutoML/components/CreateForm/utils.ts | 78 +++++++++++++++++++ 3 files changed, 118 insertions(+), 65 deletions(-) create mode 100644 react-ui/src/pages/AutoML/components/CreateForm/utils.ts diff --git a/react-ui/src/pages/AutoML/components/AutoMLBasic/index.tsx b/react-ui/src/pages/AutoML/components/AutoMLBasic/index.tsx index 56525f00..6f7870d8 100644 --- a/react-ui/src/pages/AutoML/components/AutoMLBasic/index.tsx +++ b/react-ui/src/pages/AutoML/components/AutoMLBasic/index.tsx @@ -7,10 +7,21 @@ import { autoMLTaskTypeOptions, } from '@/enums'; import { useComputingResource } from '@/hooks/useComputingResource'; +import { + classificationAlgorithms, + featureAlgorithms, + regressorAlgorithms, +} from '@/pages/AutoML/components/CreateForm/utils'; import { AutoMLData } from '@/pages/AutoML/types'; import { type NodeStatus } from '@/types'; import { parseJsonText } from '@/utils'; -import { formatBoolean, formatDataset, formatDate, formatEnum } from '@/utils/format'; +import { + formatBoolean, + formatDataset, + formatDate, + formatEnum, + type EnumOptions, +} from '@/utils/format'; import classNames from 'classnames'; import { useMemo } from 'react'; import ExperimentRunBasic from '../ExperimentRunBasic'; @@ -21,6 +32,7 @@ const formatOptimizeMode = (value: boolean) => { return value ? '越大越好' : '越小越好'; }; +// 格式化权重 const formatMetricsWeight = (value: string) => { if (!value) { return '--'; @@ -34,6 +46,20 @@ const formatMetricsWeight = (value: string) => { .join('\n'); }; +// 格式化算法 +const formatAlgorithm = (algorithms: EnumOptions[]) => { + return (value: string) => { + if (!value) { + return '--'; + } + const list = value + .split(',') + .filter((v) => v !== '') + .map((v) => v.trim()); + return list.map((v) => formatEnum(algorithms)(v)).join(','); + }; +}; + type AutoMLBasicProps = { info?: AutoMLData; className?: string; @@ -96,10 +122,12 @@ function AutoMLBasic({ { label: '特征预处理算法', value: info.include_feature_preprocessor, + format: formatAlgorithm(featureAlgorithms), }, { label: '排除的特征预处理算法', value: info.exclude_feature_preprocessor, + format: formatAlgorithm(featureAlgorithms), }, { label: info.task_type === AutoMLTaskType.Regression ? '回归算法' : '分类算法', @@ -107,6 +135,11 @@ function AutoMLBasic({ info.task_type === AutoMLTaskType.Regression ? info.include_regressor : info.include_classifier, + format: formatAlgorithm( + info.task_type === AutoMLTaskType.Regression + ? regressorAlgorithms + : classificationAlgorithms, + ), }, { label: info.task_type === AutoMLTaskType.Regression ? '排除的回归算法' : '排除的分类算法', @@ -114,6 +147,11 @@ function AutoMLBasic({ info.task_type === AutoMLTaskType.Regression ? info.exclude_regressor : info.exclude_classifier, + format: formatAlgorithm( + info.task_type === AutoMLTaskType.Regression + ? regressorAlgorithms + : classificationAlgorithms, + ), }, { label: '集成方式', diff --git a/react-ui/src/pages/AutoML/components/CreateForm/ExecuteConfig.tsx b/react-ui/src/pages/AutoML/components/CreateForm/ExecuteConfig.tsx index fded3102..b68fdc9a 100644 --- a/react-ui/src/pages/AutoML/components/CreateForm/ExecuteConfig.tsx +++ b/react-ui/src/pages/AutoML/components/CreateForm/ExecuteConfig.tsx @@ -8,70 +8,7 @@ import { autoMLTaskTypeOptions, } from '@/enums'; import { Col, Form, InputNumber, Radio, Row, Select, Switch } from 'antd'; - -// 分类算法 -const classificationAlgorithms = [ - 'adaboost', - 'bernoulli_nb', - 'decision_tree', - 'extra_trees', - 'gaussian_nb', - 'gradient_boosting', - 'k_nearest_neighbors', - 'lda', - 'liblinear_svc', - 'libsvm_svc', - 'tablenet', - 'mlp', - 'multinomial_nb', - 'passive_aggressive', - 'qda', - 'random_forest', - 'sgd', - 'LightGBMClassification', - 'XGBoostClassification', - 'StackingClassification', -].map((name) => ({ label: name, value: name })); - -// 回归算法 -const regressorAlgorithms = [ - 'adaboost', - 'ard_regression', - 'decision_tree', - 'extra_trees', - 'gaussian_process', - 'gradient_boosting', - 'k_nearest_neighbors', - 'liblinear_svr', - 'libsvm_svr', - 'mlp', - 'random_forest', - 'sgd', - 'LightGBMRegression', - 'XGBoostRegression', -].map((name) => ({ label: name, value: name })); - -// 特征预处理算法 -const featureAlgorithms = [ - 'densifier', - 'extra_trees_preproc_for_classification', - 'extra_trees_preproc_for_regression', - 'fast_ica', - 'feature_agglomeration', - 'kernel_pca', - 'kitchen_sinks', - 'liblinear_svc_preprocessor', - 'no_preprocessing', - 'nystroem_sampler', - 'pca', - 'polynomial', - 'random_trees_embedding', - 'select_percentile_classification', - 'select_percentile_regression', - 'select_rates_classification', - 'select_rates_regression', - 'truncatedSVD', -].map((name) => ({ label: name, value: name })); +import { classificationAlgorithms, featureAlgorithms, regressorAlgorithms } from './utils'; // 分类指标 export const classificationMetrics = [ diff --git a/react-ui/src/pages/AutoML/components/CreateForm/utils.ts b/react-ui/src/pages/AutoML/components/CreateForm/utils.ts new file mode 100644 index 00000000..16553178 --- /dev/null +++ b/react-ui/src/pages/AutoML/components/CreateForm/utils.ts @@ -0,0 +1,78 @@ +// 分类算法 +export const classificationAlgorithms = [ + { label: 'adaboost (自适应提升算法)', value: 'adaboost' }, + { label: 'bernoulli_nb (伯努利朴素贝叶斯)', value: 'bernoulli_nb' }, + { label: 'decision_tree (决策树)', value: 'decision_tree' }, + { label: 'extra_trees (极端随机树)', value: 'extra_trees' }, + { label: 'gaussian_nb (高斯朴素贝叶斯)', value: 'gaussian_nb' }, + { label: 'gradient_boosting (梯度提升)', value: 'gradient_boosting' }, + { label: 'k_nearest_neighbors (k近邻)', value: 'k_nearest_neighbors' }, + { label: 'lda (线性判别分析)', value: 'lda' }, + { label: 'liblinear_svc (liblinear支持向量分类)', value: 'liblinear_svc' }, + { label: 'libsvm_svc (libsvm支持向量分类)', value: 'libsvm_svc' }, + { label: 'mlp (多层感知器)', value: 'mlp' }, + { label: 'multinomial_nb (多项式朴素贝叶斯)', value: 'multinomial_nb' }, + { label: 'passive_aggressive (被动攻击算法)', value: 'passive_aggressive' }, + { label: 'qda (二次判别式分析)', value: 'qda' }, + { label: 'random_forest (随机森林)', value: 'random_forest' }, + { label: 'sgd (随机梯度下降)', value: 'sgd' }, + { label: 'tablenet (表格网络)', value: 'tablenet' }, + { label: 'LightGBMClassification (轻量梯度提升机分类)', value: 'LightGBMClassification' }, + { label: 'XGBoostClassification (极端梯度提升机分类)', value: 'XGBoostClassification' }, + { label: 'StackingClassification (堆叠泛化)', value: 'StackingClassification' }, +]; + +// 回归算法 +export const regressorAlgorithms = [ + { label: 'adaboost (自适应提升算法)', value: 'adaboost' }, + { label: 'ard_regression (自动相关性确定回归)', value: 'ard_regression' }, + { label: 'decision_tree (决策树)', value: 'decision_tree' }, + { label: 'extra_trees (极端随机树)', value: 'extra_trees' }, + { label: 'gaussian_process (高斯过程回归)', value: 'gaussian_process' }, + { label: 'gradient_boosting (梯度提升)', value: 'gradient_boosting' }, + { label: 'k_nearest_neighbors (梯度提升)', value: 'k_nearest_neighbors' }, + { label: 'liblinear_svr (liblinear支持向量回归)', value: 'liblinear_svr' }, + { label: 'libsvm_svr (libsvm支持向量回归)', value: 'libsvm_svr' }, + { label: 'mlp (多层感知器)', value: 'mlp' }, + { label: 'random_forest (随机森林)', value: 'random_forest' }, + { label: 'sgd (随机梯度下降)', value: 'sgd' }, + { label: 'LightGBMRegression (轻量梯度提升机回归)', value: 'LightGBMRegression' }, + { label: 'XGBoostRegression (极端梯度提升机回归)', value: 'XGBoostRegression' }, +]; + +// 特征预处理算法 +export const featureAlgorithms = [ + { label: 'densifier (数据增稠)', value: 'densifier' }, + { + label: 'extra_trees_preproc_for_classification (分类任务极端随机树)', + value: 'extra_trees_preproc_for_classification', + }, + { + label: 'extra_trees_preproc_for_regression (回归任务极端随机树)', + value: 'extra_trees_preproc_for_regression', + }, + { label: 'fast_ica (快速独立成分分析)', value: 'fast_ica' }, + { label: 'feature_agglomeration (特征聚合)', value: 'feature_agglomeration' }, + { label: 'kernel_pca (核主成分分析)', value: 'kernel_pca' }, + { label: 'kitchen_sinks (随机特征映射)', value: 'kitchen_sinks' }, + { label: 'liblinear_svc_preprocessor (线性svc预处理器)', value: 'liblinear_svc_preprocessor' }, + { label: 'no_preprocessing (无预处理)', value: 'no_preprocessing' }, + { label: 'nystroem_sampler (尼斯特罗姆采样器)', value: 'nystroem_sampler' }, + { label: 'pca (主成分分析)', value: 'pca' }, + { label: 'polynomial (多项式特征扩展)', value: 'polynomial' }, + { label: 'random_trees_embedding (随机森林特征嵌入)', value: 'random_trees_embedding' }, + { + label: 'select_percentile_classification (基于百分位的分类特征选择)', + value: 'select_percentile_classification', + }, + { + label: 'select_percentile_regression (基于百分位的回归特征选择)', + value: 'select_percentile_regression', + }, + { + label: 'select_rates_classification (基于比率的分类特征选择)', + value: 'select_rates_classification', + }, + { label: 'select_rates_regression (基于比率的回归特征选择)', value: 'select_rates_regression' }, + { label: 'truncatedSVD (截断奇异值分解)', value: 'truncatedSVD' }, +];