From 69451e57d0fef40d23a0d1b172e93454f0dad83d Mon Sep 17 00:00:00 2001 From: zhaowei Date: Fri, 18 Apr 2025 14:09:46 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=88=9B=E5=BB=BA=E4=B8=BB=E5=8A=A8?= =?UTF-8?q?=E5=AD=A6=E4=B9=A0=E5=AE=9E=E9=AA=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/pages/ActiveLearn/Create/index.tsx | 7 +- .../components/BasicInfo/index.tsx | 164 ++++-- .../components/CreateForm/ExecuteConfig.tsx | 488 ++++++++++++++---- .../components/CreateForm/utils.ts | 83 ++- react-ui/src/pages/ActiveLearn/types.ts | 44 +- .../components/ExperimentList/index.tsx | 6 +- 6 files changed, 582 insertions(+), 210 deletions(-) diff --git a/react-ui/src/pages/ActiveLearn/Create/index.tsx b/react-ui/src/pages/ActiveLearn/Create/index.tsx index 002ad50d..8ad3f3b1 100644 --- a/react-ui/src/pages/ActiveLearn/Create/index.tsx +++ b/react-ui/src/pages/ActiveLearn/Create/index.tsx @@ -4,6 +4,7 @@ * @Description: 创建实验 */ import PageTitle from '@/components/PageTitle'; +import { AutoMLTaskType } from '@/enums'; import { addActiveLearnReq, getActiveLearnInfoReq, @@ -98,7 +99,7 @@ function CreateActiveLearn() { return (
- +
diff --git a/react-ui/src/pages/ActiveLearn/components/BasicInfo/index.tsx b/react-ui/src/pages/ActiveLearn/components/BasicInfo/index.tsx index 5388d5f7..6536b772 100644 --- a/react-ui/src/pages/ActiveLearn/components/BasicInfo/index.tsx +++ b/react-ui/src/pages/ActiveLearn/components/BasicInfo/index.tsx @@ -1,16 +1,26 @@ import ConfigInfo, { type BasicInfoData } from '@/components/ConfigInfo'; +import { AutoMLTaskType, autoMLTaskTypeOptions } from '@/enums'; +import { useComputingResource } from '@/hooks/useComputingResource'; import { classifierAlgorithms, - performanceMetrics, + FrameworkType, + frameworkTypeOptions, queryStrategies, - stoppingCriterions, - StoppingCriterionsType, + regressorAlgorithms, } from '@/pages/ActiveLearn/components/CreateForm/utils'; import { ActiveLearnData } from '@/pages/ActiveLearn/types'; import { experimentStatusInfo } from '@/pages/Experiment/status'; import { type NodeStatus } from '@/types'; import { elapsedTime } from '@/utils/date'; -import { formatDataset, formatDate, formatEnum } from '@/utils/format'; +import { + formatBoolean, + formatCodeConfig, + formatDataset, + formatDate, + formatEnum, + formatMirror, + formatModel, +} from '@/utils/format'; import { Flex } from 'antd'; import classNames from 'classnames'; import { useMemo } from 'react'; @@ -24,6 +34,7 @@ type BasicInfoProps = { }; function BasicInfo({ info, className, runStatus, isInstance = false }: BasicInfoProps) { + const getResourceDescription = useComputingResource()[1]; const basicDatas: BasicInfoData[] = useMemo(() => { if (!info) { return []; @@ -59,47 +70,68 @@ function BasicInfo({ info, className, runStatus, isInstance = false }: BasicInfo if (!info) { return []; } - let stopping_criterion_params_title = ''; - let stopping_criterion_params_value: number | undefined = undefined; - if (info.stopping_criterion === StoppingCriterionsType.NumOfQueries) { - stopping_criterion_params_title = '查询次数'; - stopping_criterion_params_value = info.num_of_queries; - } else if (info.stopping_criterion === StoppingCriterionsType.PercentOfUnlabel) { - stopping_criterion_params_title = '未标记比例'; - stopping_criterion_params_value = info.percent_of_unlabel; - } else { - stopping_criterion_params_title = '时间限制'; - stopping_criterion_params_value = info.time_limit; - } - return [ + const modelInfo = [ { - label: '分类算法', - value: info.classifier_type, - format: formatEnum(classifierAlgorithms), + label: '模型', + value: info.model, + format: formatModel, }, { - label: '停止判则', - value: info.stopping_criterion, - format: formatEnum(stoppingCriterions), + label: '模型文件路径', + value: info.model_py, }, { - label: stopping_criterion_params_title, - value: stopping_criterion_params_value, + label: '模型类名称', + value: info.model_class_name, }, + ]; + + const lossInfo = [ { - label: '查询策略', - value: info.query_strategy, - format: formatEnum(queryStrategies), + label: 'loss文件路径', + value: info.loss_py, }, { - label: '试验次数', - value: info.num_of_experiment, + label: 'loss类名', + value: info.loss_class_name, }, + ]; + + const algorithmInfo = [ { - label: '指标', - value: info.performance_metric, - format: formatEnum(performanceMetrics), + label: info.task_type === AutoMLTaskType.Regression ? '回归算法' : '分类算法', + value: + info.task_type === AutoMLTaskType.Regression ? info.regressor_alg : info.classifier_alg, + format: formatEnum( + info.task_type === AutoMLTaskType.Regression ? regressorAlgorithms : classifierAlgorithms, + ), + }, + ]; + + const diffInfo = + info.framework_type === FrameworkType.Pytorch + ? [...modelInfo, ...lossInfo] + : info.framework_type === FrameworkType.Keras + ? modelInfo + : algorithmInfo; + + return [ + { + label: '任务类型', + value: info.task_type, + format: formatEnum(autoMLTaskTypeOptions), + }, + { + label: '框架类型', + value: info.framework_type, + format: formatEnum(frameworkTypeOptions), + }, + ...diffInfo, + { + label: '代码配置', + value: info.code_config, + format: formatCodeConfig, }, { label: '数据集', @@ -107,19 +139,71 @@ function BasicInfo({ info, className, runStatus, isInstance = false }: BasicInfo format: formatDataset, }, { - label: '预测目标列', - value: info.target_columns, + label: '数据集处理文件路径', + value: info.dataset_py, + }, + { + label: '数据集类名', + value: info.dataset_class_name, + }, + { + label: '镜像', + value: info.image, + format: formatMirror, }, { - label: '测试集比率', - value: info.test_ratio, + label: '资源规格', + value: info.computing_resource_id, + format: getResourceDescription, }, { - label: '初始标记数据比率', - value: info.initial_label_rate, + label: '是否打乱', + value: info.shuffle, + format: formatBoolean, + }, + { + label: '数据量', + value: info.data_size, + }, + { + label: '训练集数据量', + value: info.train_size, + }, + { + label: '初始训练数据量', + value: info.ninitial, + }, + { + label: '查询次数', + value: info.nqueries, + }, + { + label: '每次查询数据量', + value: info.ninstances, + }, + { + label: '查询策略', + value: info.query_strategy, + format: formatEnum(queryStrategies), + }, + { + label: '轮数', + value: info.ncheckpoint, + }, + { + label: 'batch_size', + value: info.batch_size, + }, + { + label: 'epochs', + value: info.epochs, + }, + { + label: '学习率', + value: info.lr, }, ]; - }, [info]); + }, [info, getResourceDescription]); const instanceDatas = useMemo(() => { if (!runStatus) { diff --git a/react-ui/src/pages/ActiveLearn/components/CreateForm/ExecuteConfig.tsx b/react-ui/src/pages/ActiveLearn/components/CreateForm/ExecuteConfig.tsx index 8e7bda92..37d04ee9 100644 --- a/react-ui/src/pages/ActiveLearn/components/CreateForm/ExecuteConfig.tsx +++ b/react-ui/src/pages/ActiveLearn/components/CreateForm/ExecuteConfig.tsx @@ -1,18 +1,22 @@ +import CodeSelect from '@/components/CodeSelect'; +import ParameterSelect from '@/components/ParameterSelect'; import ResourceSelect, { requiredValidator, ResourceSelectorType, } from '@/components/ResourceSelect'; import SubAreaTitle from '@/components/SubAreaTitle'; -import { Col, Form, Input, InputNumber, Row, Select } from 'antd'; +import { AutoMLTaskType, autoMLTaskTypeOptions } from '@/enums'; +import { Col, Form, Input, InputNumber, Radio, Row, Select, Switch } from 'antd'; import { classifierAlgorithms, - performanceMetrics, + FrameworkType, + frameworkTypeOptions, queryStrategies, - stoppingCriterions, - StoppingCriterionsType, + regressorAlgorithms, } from './utils'; function ExecuteConfig() { + const form = Form.useFormInstance(); return ( <> - @@ -65,66 +62,169 @@ function ExecuteConfig() { - + {({ getFieldValue }) => { - const stopping_criterion = getFieldValue('stopping_criterion'); - if (stopping_criterion === StoppingCriterionsType.NumOfQueries) { - return ( - - - - - - - - ); - } else if (stopping_criterion === StoppingCriterionsType.PercentOfUnlabel) { + const taskType = getFieldValue('task_type'); + const frameworkType = getFieldValue('framework_type'); + if (frameworkType === FrameworkType.Keras || frameworkType === FrameworkType.Pytorch) { return ( - - - - - - - - ); - } else if (stopping_criterion === StoppingCriterionsType.TimeLimit) { - return ( - - - - - - - + <> + + + + + + + + + + + + + + + + + + + + + + {frameworkType === FrameworkType.Pytorch ? ( + <> + + + + + + + + + + + + + + + + ) : null} + ); + } else if (frameworkType === FrameworkType.Sklearn) { + if (taskType === AutoMLTaskType.Classification) { + return ( + <> + + + + + + + + + ); + } } else { return null; } @@ -134,16 +234,57 @@ function ExecuteConfig() { + + + + + + + + + + + + + + + + - @@ -151,16 +292,16 @@ function ExecuteConfig() { - + @@ -168,16 +309,16 @@ function ExecuteConfig() { - + + + + + + + + + + + + + + + - + @@ -225,16 +493,16 @@ function ExecuteConfig() { - + @@ -242,16 +510,16 @@ function ExecuteConfig() { - + diff --git a/react-ui/src/pages/ActiveLearn/components/CreateForm/utils.ts b/react-ui/src/pages/ActiveLearn/components/CreateForm/utils.ts index c7fd8567..0090f5df 100644 --- a/react-ui/src/pages/ActiveLearn/components/CreateForm/utils.ts +++ b/react-ui/src/pages/ActiveLearn/components/CreateForm/utils.ts @@ -26,72 +26,67 @@ export const classifierAlgorithms = [ }, ]; -export enum StoppingCriterionsType { - NumOfQueries = 'num_of_queries', - PercentOfUnlabel = 'percent_of_unlabel', - TimeLimit = 'time_limit', -} - -// 停止判则 -export const stoppingCriterions = [ +// 回归算法 +export const regressorAlgorithms = [ { - label: 'num_of_queries(查询次数)', - value: 'num_of_queries', + label: 'bayesian_ridge(岭回归)', + value: 'bayesian_ridge', }, { - label: 'percent_of_unlabel(未标记样本比例)', - value: 'percent_of_unlabel', + label: 'ARD_regression(自动相关性确定回归)', + value: 'ARD_regression', }, { - label: 'time_limit(时间限制)', - value: 'time_limit', - }, + label: 'gaussian_process(高斯回归)', + value: 'gaussian_process', + } ]; -// 查询策略 -export const queryStrategies = [ - { - label: 'Uncertainty(不确定性)', - value: 'Uncertainty', - }, +// 框架类型 +export enum FrameworkType { + Sklearn = 'sklearn', + Keras = 'keras', + Pytorch = 'pytorch', +} + +// 框架类型选项 +export const frameworkTypeOptions = [ { - label: 'QBC(委员会查询)', - value: 'QBC', + label: FrameworkType.Sklearn, + value: FrameworkType.Sklearn, }, { - label: 'Random(随机)', - value: 'Random', + label: FrameworkType.Keras, + value: FrameworkType.Keras, }, { - label: 'GraphDensity(图密度)', - value: 'GraphDensity', + label: FrameworkType.Pytorch, + value: FrameworkType.Pytorch, }, ]; -// 指标 -export const performanceMetrics = [ - { - label: 'accuracy_score', - value: 'accuracy_score', - }, + +// 查询策略 +export const queryStrategies = [ { - label: 'roc_auc_score', - value: 'roc_auc_score', + label: 'uncertainty_sampling', + value: 'uncertainty_sampling', }, { - label: 'get_fps_tps_thresholds', - value: 'get_fps_tps_thresholds', + label: 'uncertainty_batch_sampling', + value: 'uncertainty_batch_sampling', }, { - label: 'hamming_loss', - value: 'hamming_loss', + label: 'max_std_sampling', + value: 'max_std_sampling', }, { - label: 'one_error', - value: 'one_error', + label: 'expected_improvement', + value: 'expected_improvement', }, { - label: 'coverage_error', - value: 'coverage_error', - }, + label: 'upper_confidence_bound', + value: 'upper_confidence_bound', + } ]; + diff --git a/react-ui/src/pages/ActiveLearn/types.ts b/react-ui/src/pages/ActiveLearn/types.ts index b7421d92..acccdd94 100644 --- a/react-ui/src/pages/ActiveLearn/types.ts +++ b/react-ui/src/pages/ActiveLearn/types.ts @@ -1,5 +1,14 @@ +/* + * @Author: error: error: git config user.name & please set dead value or install git && error: git config user.email & please set dead value or install git & please set dead value or install git + * @Date: 2025-04-18 08:40:03 + * @LastEditors: error: error: git config user.name & please set dead value or install git && error: git config user.email & please set dead value or install git & please set dead value or install git + * @LastEditTime: 2025-04-18 11:30:21 + * @FilePath: \ci4s\react-ui\src\pages\ActiveLearn\types.ts + * @Description: 这是默认设置,请设置`customMade`, 打开koroFileHeader查看配置 进行设置: https://github.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE + */ import { type ParameterInputObject } from '@/components/ResourceSelect'; import { type NodeStatus } from '@/types'; +import { AutoMLTaskType } from '@/enums'; // 操作类型 export enum OperationType { @@ -11,18 +20,33 @@ export enum OperationType { export type FormData = { name: string; // 实验名称 description: string; // 实验描述 + task_type: AutoMLTaskType; // 任务类型 + framework_type: string; // 框架类型 + code_config: ParameterInputObject; // 代码配置 + model?: ParameterInputObject; // 模型 + model_py?: string; // 模型文件路径 + model_class_name?: string; // 模型类名称 + loss_py?: string; // loss文件路径 + loss_class_name?: string; // loss类名 + classifier_alg?: string; // 分类算法 + regressor_alg?: string; // 回归算法 dataset: ParameterInputObject; // 数据集 - classifier_type: string; // 分类算法 - stopping_criterion: string; // 停止判则 + dataset_py: string; // dataset处理文件路径 + dataset_class_name: string; // dataset类名 + data_size: number; // 数据量 + train_size: number; // 训练集数据量 + ninitial: number; // 初始训练数据量 + nqueries: number; // 查询次数 + ninstances: number; // 每次查询数据量 + computing_resource_id: number; // 资源规格 + image: ParameterInputObject; // 镜像 + shuffle: boolean; // 是否随机打乱 query_strategy: string; // 查询策略 - num_of_experiment: number; // 试验次数 - performance_metric: string; // 指标 - target_columns: string; // 预测目标列 - test_ratio: number; // 测试集比率 - initial_label_rate: number; // 初始标记数据比率 - num_of_queries?: number; // 查询次数 - percent_of_unlabel: number; // 未标记比例 - time_limit: number; // 时间限制 + ncheckpoint: number; // 多少轮查询保存一次模型参数 + batch_size: number; // batch_size + epochs: number; // epochs + lr: number; // 学习率 + }; // 主动学习 diff --git a/react-ui/src/pages/AutoML/components/ExperimentList/index.tsx b/react-ui/src/pages/AutoML/components/ExperimentList/index.tsx index d4628c24..a4b53b37 100644 --- a/react-ui/src/pages/AutoML/components/ExperimentList/index.tsx +++ b/react-ui/src/pages/AutoML/components/ExperimentList/index.tsx @@ -63,7 +63,7 @@ function ExperimentList({ type }: ExperimentListProps) { const params: Record = { page: pagination.current! - 1, size: pagination.pageSize, - ml_name: searchText || undefined, + [config.nameProperty]: searchText || undefined, }; const request = config.getListReq; const [res] = await to(request(params)); @@ -248,7 +248,7 @@ function ExperimentList({ type }: ExperimentListProps) { { title: '实验名称', dataIndex: config.nameProperty, - key: 'ml_name', + key: 'name', width: '16%', render: tableCellRender(false, TableCellValueType.Link, { onClick: gotoDetail, @@ -257,7 +257,7 @@ function ExperimentList({ type }: ExperimentListProps) { { title: '实验描述', dataIndex: config.descProperty, - key: 'ml_description', + key: 'description', render: tableCellRender(true), }, {