| @@ -4,6 +4,7 @@ | |||||
| * @Description: 创建实验 | * @Description: 创建实验 | ||||
| */ | */ | ||||
| import PageTitle from '@/components/PageTitle'; | import PageTitle from '@/components/PageTitle'; | ||||
| import { AutoMLTaskType } from '@/enums'; | |||||
| import { | import { | ||||
| addActiveLearnReq, | addActiveLearnReq, | ||||
| getActiveLearnInfoReq, | getActiveLearnInfoReq, | ||||
| @@ -98,7 +99,7 @@ function CreateActiveLearn() { | |||||
| return ( | return ( | ||||
| <div className={styles['create-hyperparameter']}> | <div className={styles['create-hyperparameter']}> | ||||
| <PageTitle title={title} tooltip="仅支持二分类及多分类任务"></PageTitle> | |||||
| <PageTitle title={title}></PageTitle> | |||||
| <div className={styles['create-hyperparameter__content']}> | <div className={styles['create-hyperparameter__content']}> | ||||
| <div> | <div> | ||||
| <Form | <Form | ||||
| @@ -111,8 +112,8 @@ function CreateActiveLearn() { | |||||
| autoComplete="off" | autoComplete="off" | ||||
| scrollToFirstError | scrollToFirstError | ||||
| initialValues={{ | initialValues={{ | ||||
| test_ratio: 0.3, | |||||
| initial_label_rate: 0.1, | |||||
| task_type: AutoMLTaskType.Classification, | |||||
| shuffle: false, | |||||
| }} | }} | ||||
| > | > | ||||
| <BasicConfig /> | <BasicConfig /> | ||||
| @@ -1,16 +1,26 @@ | |||||
| import ConfigInfo, { type BasicInfoData } from '@/components/ConfigInfo'; | import ConfigInfo, { type BasicInfoData } from '@/components/ConfigInfo'; | ||||
| import { AutoMLTaskType, autoMLTaskTypeOptions } from '@/enums'; | |||||
| import { useComputingResource } from '@/hooks/useComputingResource'; | |||||
| import { | import { | ||||
| classifierAlgorithms, | classifierAlgorithms, | ||||
| performanceMetrics, | |||||
| FrameworkType, | |||||
| frameworkTypeOptions, | |||||
| queryStrategies, | queryStrategies, | ||||
| stoppingCriterions, | |||||
| StoppingCriterionsType, | |||||
| regressorAlgorithms, | |||||
| } from '@/pages/ActiveLearn/components/CreateForm/utils'; | } from '@/pages/ActiveLearn/components/CreateForm/utils'; | ||||
| import { ActiveLearnData } from '@/pages/ActiveLearn/types'; | import { ActiveLearnData } from '@/pages/ActiveLearn/types'; | ||||
| import { experimentStatusInfo } from '@/pages/Experiment/status'; | import { experimentStatusInfo } from '@/pages/Experiment/status'; | ||||
| import { type NodeStatus } from '@/types'; | import { type NodeStatus } from '@/types'; | ||||
| import { elapsedTime } from '@/utils/date'; | import { elapsedTime } from '@/utils/date'; | ||||
| import { formatDataset, formatDate, formatEnum } from '@/utils/format'; | |||||
| import { | |||||
| formatBoolean, | |||||
| formatCodeConfig, | |||||
| formatDataset, | |||||
| formatDate, | |||||
| formatEnum, | |||||
| formatMirror, | |||||
| formatModel, | |||||
| } from '@/utils/format'; | |||||
| import { Flex } from 'antd'; | import { Flex } from 'antd'; | ||||
| import classNames from 'classnames'; | import classNames from 'classnames'; | ||||
| import { useMemo } from 'react'; | import { useMemo } from 'react'; | ||||
| @@ -24,6 +34,7 @@ type BasicInfoProps = { | |||||
| }; | }; | ||||
| function BasicInfo({ info, className, runStatus, isInstance = false }: BasicInfoProps) { | function BasicInfo({ info, className, runStatus, isInstance = false }: BasicInfoProps) { | ||||
| const getResourceDescription = useComputingResource()[1]; | |||||
| const basicDatas: BasicInfoData[] = useMemo(() => { | const basicDatas: BasicInfoData[] = useMemo(() => { | ||||
| if (!info) { | if (!info) { | ||||
| return []; | return []; | ||||
| @@ -59,47 +70,68 @@ function BasicInfo({ info, className, runStatus, isInstance = false }: BasicInfo | |||||
| if (!info) { | if (!info) { | ||||
| return []; | return []; | ||||
| } | } | ||||
| let stopping_criterion_params_title = ''; | |||||
| let stopping_criterion_params_value: number | undefined = undefined; | |||||
| if (info.stopping_criterion === StoppingCriterionsType.NumOfQueries) { | |||||
| stopping_criterion_params_title = '查询次数'; | |||||
| stopping_criterion_params_value = info.num_of_queries; | |||||
| } else if (info.stopping_criterion === StoppingCriterionsType.PercentOfUnlabel) { | |||||
| stopping_criterion_params_title = '未标记比例'; | |||||
| stopping_criterion_params_value = info.percent_of_unlabel; | |||||
| } else { | |||||
| stopping_criterion_params_title = '时间限制'; | |||||
| stopping_criterion_params_value = info.time_limit; | |||||
| } | |||||
| return [ | |||||
| const modelInfo = [ | |||||
| { | { | ||||
| label: '分类算法', | |||||
| value: info.classifier_type, | |||||
| format: formatEnum(classifierAlgorithms), | |||||
| label: '模型', | |||||
| value: info.model, | |||||
| format: formatModel, | |||||
| }, | }, | ||||
| { | { | ||||
| label: '停止判则', | |||||
| value: info.stopping_criterion, | |||||
| format: formatEnum(stoppingCriterions), | |||||
| label: '模型文件路径', | |||||
| value: info.model_py, | |||||
| }, | }, | ||||
| { | { | ||||
| label: stopping_criterion_params_title, | |||||
| value: stopping_criterion_params_value, | |||||
| label: '模型类名称', | |||||
| value: info.model_class_name, | |||||
| }, | }, | ||||
| ]; | |||||
| const lossInfo = [ | |||||
| { | { | ||||
| label: '查询策略', | |||||
| value: info.query_strategy, | |||||
| format: formatEnum(queryStrategies), | |||||
| label: 'loss文件路径', | |||||
| value: info.loss_py, | |||||
| }, | }, | ||||
| { | { | ||||
| label: '试验次数', | |||||
| value: info.num_of_experiment, | |||||
| label: 'loss类名', | |||||
| value: info.loss_class_name, | |||||
| }, | }, | ||||
| ]; | |||||
| const algorithmInfo = [ | |||||
| { | { | ||||
| label: '指标', | |||||
| value: info.performance_metric, | |||||
| format: formatEnum(performanceMetrics), | |||||
| label: info.task_type === AutoMLTaskType.Regression ? '回归算法' : '分类算法', | |||||
| value: | |||||
| info.task_type === AutoMLTaskType.Regression ? info.regressor_alg : info.classifier_alg, | |||||
| format: formatEnum( | |||||
| info.task_type === AutoMLTaskType.Regression ? regressorAlgorithms : classifierAlgorithms, | |||||
| ), | |||||
| }, | |||||
| ]; | |||||
| const diffInfo = | |||||
| info.framework_type === FrameworkType.Pytorch | |||||
| ? [...modelInfo, ...lossInfo] | |||||
| : info.framework_type === FrameworkType.Keras | |||||
| ? modelInfo | |||||
| : algorithmInfo; | |||||
| return [ | |||||
| { | |||||
| label: '任务类型', | |||||
| value: info.task_type, | |||||
| format: formatEnum(autoMLTaskTypeOptions), | |||||
| }, | |||||
| { | |||||
| label: '框架类型', | |||||
| value: info.framework_type, | |||||
| format: formatEnum(frameworkTypeOptions), | |||||
| }, | |||||
| ...diffInfo, | |||||
| { | |||||
| label: '代码配置', | |||||
| value: info.code_config, | |||||
| format: formatCodeConfig, | |||||
| }, | }, | ||||
| { | { | ||||
| label: '数据集', | label: '数据集', | ||||
| @@ -107,19 +139,71 @@ function BasicInfo({ info, className, runStatus, isInstance = false }: BasicInfo | |||||
| format: formatDataset, | format: formatDataset, | ||||
| }, | }, | ||||
| { | { | ||||
| label: '预测目标列', | |||||
| value: info.target_columns, | |||||
| label: '数据集处理文件路径', | |||||
| value: info.dataset_py, | |||||
| }, | |||||
| { | |||||
| label: '数据集类名', | |||||
| value: info.dataset_class_name, | |||||
| }, | |||||
| { | |||||
| label: '镜像', | |||||
| value: info.image, | |||||
| format: formatMirror, | |||||
| }, | }, | ||||
| { | { | ||||
| label: '测试集比率', | |||||
| value: info.test_ratio, | |||||
| label: '资源规格', | |||||
| value: info.computing_resource_id, | |||||
| format: getResourceDescription, | |||||
| }, | }, | ||||
| { | { | ||||
| label: '初始标记数据比率', | |||||
| value: info.initial_label_rate, | |||||
| label: '是否打乱', | |||||
| value: info.shuffle, | |||||
| format: formatBoolean, | |||||
| }, | |||||
| { | |||||
| label: '数据量', | |||||
| value: info.data_size, | |||||
| }, | |||||
| { | |||||
| label: '训练集数据量', | |||||
| value: info.train_size, | |||||
| }, | |||||
| { | |||||
| label: '初始训练数据量', | |||||
| value: info.ninitial, | |||||
| }, | |||||
| { | |||||
| label: '查询次数', | |||||
| value: info.nqueries, | |||||
| }, | |||||
| { | |||||
| label: '每次查询数据量', | |||||
| value: info.ninstances, | |||||
| }, | |||||
| { | |||||
| label: '查询策略', | |||||
| value: info.query_strategy, | |||||
| format: formatEnum(queryStrategies), | |||||
| }, | |||||
| { | |||||
| label: '轮数', | |||||
| value: info.ncheckpoint, | |||||
| }, | |||||
| { | |||||
| label: 'batch_size', | |||||
| value: info.batch_size, | |||||
| }, | |||||
| { | |||||
| label: 'epochs', | |||||
| value: info.epochs, | |||||
| }, | |||||
| { | |||||
| label: '学习率', | |||||
| value: info.lr, | |||||
| }, | }, | ||||
| ]; | ]; | ||||
| }, [info]); | |||||
| }, [info, getResourceDescription]); | |||||
| const instanceDatas = useMemo(() => { | const instanceDatas = useMemo(() => { | ||||
| if (!runStatus) { | if (!runStatus) { | ||||
| @@ -1,18 +1,22 @@ | |||||
| import CodeSelect from '@/components/CodeSelect'; | |||||
| import ParameterSelect from '@/components/ParameterSelect'; | |||||
| import ResourceSelect, { | import ResourceSelect, { | ||||
| requiredValidator, | requiredValidator, | ||||
| ResourceSelectorType, | ResourceSelectorType, | ||||
| } from '@/components/ResourceSelect'; | } from '@/components/ResourceSelect'; | ||||
| import SubAreaTitle from '@/components/SubAreaTitle'; | import SubAreaTitle from '@/components/SubAreaTitle'; | ||||
| import { Col, Form, Input, InputNumber, Row, Select } from 'antd'; | |||||
| import { AutoMLTaskType, autoMLTaskTypeOptions } from '@/enums'; | |||||
| import { Col, Form, Input, InputNumber, Radio, Row, Select, Switch } from 'antd'; | |||||
| import { | import { | ||||
| classifierAlgorithms, | classifierAlgorithms, | ||||
| performanceMetrics, | |||||
| FrameworkType, | |||||
| frameworkTypeOptions, | |||||
| queryStrategies, | queryStrategies, | ||||
| stoppingCriterions, | |||||
| StoppingCriterionsType, | |||||
| regressorAlgorithms, | |||||
| } from './utils'; | } from './utils'; | ||||
| function ExecuteConfig() { | function ExecuteConfig() { | ||||
| const form = Form.useFormInstance(); | |||||
| return ( | return ( | ||||
| <> | <> | ||||
| <SubAreaTitle | <SubAreaTitle | ||||
| @@ -24,21 +28,14 @@ function ExecuteConfig() { | |||||
| <Row gutter={8}> | <Row gutter={8}> | ||||
| <Col span={10}> | <Col span={10}> | ||||
| <Form.Item | <Form.Item | ||||
| label="分类算法" | |||||
| name="classifier_type" | |||||
| rules={[ | |||||
| { | |||||
| required: true, | |||||
| message: '请选择分类算法', | |||||
| }, | |||||
| ]} | |||||
| label="任务类型" | |||||
| name="task_type" | |||||
| rules={[{ required: true, message: '请选择任务类型' }]} | |||||
| > | > | ||||
| <Select | |||||
| placeholder="请选择分类算法" | |||||
| options={classifierAlgorithms} | |||||
| showSearch | |||||
| allowClear | |||||
| /> | |||||
| <Radio.Group | |||||
| options={autoMLTaskTypeOptions} | |||||
| onChange={() => form.resetFields(['metrics'])} | |||||
| ></Radio.Group> | |||||
| </Form.Item> | </Form.Item> | ||||
| </Col> | </Col> | ||||
| </Row> | </Row> | ||||
| @@ -46,18 +43,18 @@ function ExecuteConfig() { | |||||
| <Row gutter={8}> | <Row gutter={8}> | ||||
| <Col span={10}> | <Col span={10}> | ||||
| <Form.Item | <Form.Item | ||||
| label="停止判则" | |||||
| name="stopping_criterion" | |||||
| label="框架类型" | |||||
| name="framework_type" | |||||
| rules={[ | rules={[ | ||||
| { | { | ||||
| required: true, | required: true, | ||||
| message: '请选择停止判则', | |||||
| message: '请选择框架类型', | |||||
| }, | }, | ||||
| ]} | ]} | ||||
| > | > | ||||
| <Select | <Select | ||||
| placeholder="请选择停止判则" | |||||
| options={stoppingCriterions} | |||||
| placeholder="请选择框架类型" | |||||
| options={frameworkTypeOptions} | |||||
| showSearch | showSearch | ||||
| allowClear | allowClear | ||||
| /> | /> | ||||
| @@ -65,66 +62,169 @@ function ExecuteConfig() { | |||||
| </Col> | </Col> | ||||
| </Row> | </Row> | ||||
| <Form.Item dependencies={['stopping_criterion']} noStyle> | |||||
| <Form.Item dependencies={['task_type', 'framework_type']} noStyle> | |||||
| {({ getFieldValue }) => { | {({ getFieldValue }) => { | ||||
| const stopping_criterion = getFieldValue('stopping_criterion'); | |||||
| if (stopping_criterion === StoppingCriterionsType.NumOfQueries) { | |||||
| return ( | |||||
| <Row gutter={8}> | |||||
| <Col span={10}> | |||||
| <Form.Item | |||||
| label="查询次数" | |||||
| name="num_of_queries" | |||||
| rules={[ | |||||
| { | |||||
| required: true, | |||||
| message: '请输入查询次数', | |||||
| }, | |||||
| ]} | |||||
| > | |||||
| <InputNumber placeholder="请输入查询次数" min={0} precision={0} /> | |||||
| </Form.Item> | |||||
| </Col> | |||||
| </Row> | |||||
| ); | |||||
| } else if (stopping_criterion === StoppingCriterionsType.PercentOfUnlabel) { | |||||
| const taskType = getFieldValue('task_type'); | |||||
| const frameworkType = getFieldValue('framework_type'); | |||||
| if (frameworkType === FrameworkType.Keras || frameworkType === FrameworkType.Pytorch) { | |||||
| return ( | return ( | ||||
| <Row gutter={8}> | |||||
| <Col span={10}> | |||||
| <Form.Item | |||||
| label="未标记比例" | |||||
| name="percent_of_unlabel" | |||||
| rules={[ | |||||
| { | |||||
| required: true, | |||||
| message: '请输入未标记比例', | |||||
| }, | |||||
| ]} | |||||
| > | |||||
| <InputNumber placeholder="请输入未标记比例" min={0} /> | |||||
| </Form.Item> | |||||
| </Col> | |||||
| </Row> | |||||
| ); | |||||
| } else if (stopping_criterion === StoppingCriterionsType.TimeLimit) { | |||||
| return ( | |||||
| <Row gutter={8}> | |||||
| <Col span={10}> | |||||
| <Form.Item | |||||
| label="时间限制" | |||||
| name="time_limit" | |||||
| rules={[ | |||||
| { | |||||
| required: true, | |||||
| message: '请输入时间限制', | |||||
| }, | |||||
| ]} | |||||
| > | |||||
| <InputNumber placeholder="请输入时间限制" min={0} /> | |||||
| </Form.Item> | |||||
| </Col> | |||||
| </Row> | |||||
| <> | |||||
| <Row gutter={8}> | |||||
| <Col span={10}> | |||||
| <Form.Item | |||||
| label="模型" | |||||
| name="model" | |||||
| rules={[ | |||||
| { | |||||
| validator: requiredValidator, | |||||
| message: '请选择模型', | |||||
| }, | |||||
| ]} | |||||
| required | |||||
| > | |||||
| <ResourceSelect | |||||
| type={ResourceSelectorType.Model} | |||||
| placeholder="请选择模型" | |||||
| canInput={false} | |||||
| size="large" | |||||
| /> | |||||
| </Form.Item> | |||||
| </Col> | |||||
| </Row> | |||||
| <Row gutter={8}> | |||||
| <Col span={10}> | |||||
| <Form.Item | |||||
| label="模型文件路径" | |||||
| name="model_py" | |||||
| rules={[ | |||||
| { | |||||
| required: true, | |||||
| message: '请输入模型文件路径', | |||||
| }, | |||||
| ]} | |||||
| > | |||||
| <Input placeholder="请输入模型文件路径" maxLength={64} showCount allowClear /> | |||||
| </Form.Item> | |||||
| </Col> | |||||
| </Row> | |||||
| <Row gutter={8}> | |||||
| <Col span={10}> | |||||
| <Form.Item | |||||
| label="模型类名称" | |||||
| name="model_class_name" | |||||
| rules={[ | |||||
| { | |||||
| required: true, | |||||
| message: '请输入模型类名称', | |||||
| }, | |||||
| ]} | |||||
| > | |||||
| <Input placeholder="请输入模型类名称" maxLength={64} showCount allowClear /> | |||||
| </Form.Item> | |||||
| </Col> | |||||
| </Row> | |||||
| {frameworkType === FrameworkType.Pytorch ? ( | |||||
| <> | |||||
| <Row gutter={8}> | |||||
| <Col span={10}> | |||||
| <Form.Item | |||||
| label="loss 文件路径" | |||||
| name="loss_py" | |||||
| rules={[ | |||||
| { | |||||
| required: true, | |||||
| message: '请输入 loss 文件路径', | |||||
| }, | |||||
| ]} | |||||
| > | |||||
| <Input | |||||
| placeholder="请输入 loss 文件路径" | |||||
| maxLength={64} | |||||
| showCount | |||||
| allowClear | |||||
| /> | |||||
| </Form.Item> | |||||
| </Col> | |||||
| </Row> | |||||
| <Row gutter={8}> | |||||
| <Col span={10}> | |||||
| <Form.Item | |||||
| label="loss类名" | |||||
| name="loss_class_name" | |||||
| rules={[ | |||||
| { | |||||
| required: true, | |||||
| message: '请输入 loss 类名', | |||||
| }, | |||||
| ]} | |||||
| > | |||||
| <Input | |||||
| placeholder="请输入 loss 类名" | |||||
| maxLength={64} | |||||
| showCount | |||||
| allowClear | |||||
| /> | |||||
| </Form.Item> | |||||
| </Col> | |||||
| </Row> | |||||
| </> | |||||
| ) : null} | |||||
| </> | |||||
| ); | ); | ||||
| } else if (frameworkType === FrameworkType.Sklearn) { | |||||
| if (taskType === AutoMLTaskType.Classification) { | |||||
| return ( | |||||
| <> | |||||
| <Row gutter={8}> | |||||
| <Col span={10}> | |||||
| <Form.Item | |||||
| label="分类算法" | |||||
| name="classifier_alg" | |||||
| rules={[ | |||||
| { | |||||
| required: true, | |||||
| message: '请选择分类算法', | |||||
| }, | |||||
| ]} | |||||
| > | |||||
| <Select | |||||
| placeholder="请选择分类算法" | |||||
| options={classifierAlgorithms} | |||||
| showSearch | |||||
| allowClear | |||||
| /> | |||||
| </Form.Item> | |||||
| </Col> | |||||
| </Row> | |||||
| </> | |||||
| ); | |||||
| } else { | |||||
| return ( | |||||
| <> | |||||
| <Row gutter={8}> | |||||
| <Col span={10}> | |||||
| <Form.Item | |||||
| label="回归算法" | |||||
| name="regressor_alg" | |||||
| rules={[ | |||||
| { | |||||
| required: true, | |||||
| message: '请选择回归算法', | |||||
| }, | |||||
| ]} | |||||
| > | |||||
| <Select | |||||
| placeholder="请选择回归算法" | |||||
| options={regressorAlgorithms} | |||||
| showSearch | |||||
| allowClear | |||||
| /> | |||||
| </Form.Item> | |||||
| </Col> | |||||
| </Row> | |||||
| </> | |||||
| ); | |||||
| } | |||||
| } else { | } else { | ||||
| return null; | return null; | ||||
| } | } | ||||
| @@ -134,16 +234,57 @@ function ExecuteConfig() { | |||||
| <Row gutter={8}> | <Row gutter={8}> | ||||
| <Col span={10}> | <Col span={10}> | ||||
| <Form.Item | <Form.Item | ||||
| label="查询策略" | |||||
| name="query_strategy" | |||||
| label="代码配置" | |||||
| name="code_config" | |||||
| rules={[ | |||||
| { | |||||
| validator: requiredValidator, | |||||
| message: '请选择代码配置', | |||||
| }, | |||||
| ]} | |||||
| required | |||||
| > | |||||
| <CodeSelect placeholder="请选择代码配置" canInput={false} size="large" /> | |||||
| </Form.Item> | |||||
| </Col> | |||||
| </Row> | |||||
| <Row gutter={8}> | |||||
| <Col span={10}> | |||||
| <Form.Item | |||||
| label="数据集" | |||||
| name="dataset" | |||||
| rules={[ | |||||
| { | |||||
| validator: requiredValidator, | |||||
| message: '请选择数据集', | |||||
| }, | |||||
| ]} | |||||
| required | |||||
| > | |||||
| <ResourceSelect | |||||
| type={ResourceSelectorType.Dataset} | |||||
| placeholder="请选择数据集" | |||||
| canInput={false} | |||||
| size="large" | |||||
| /> | |||||
| </Form.Item> | |||||
| </Col> | |||||
| </Row> | |||||
| <Row gutter={8}> | |||||
| <Col span={10}> | |||||
| <Form.Item | |||||
| label="数据集处理文件路径" | |||||
| name="dataset_py" | |||||
| rules={[ | rules={[ | ||||
| { | { | ||||
| required: true, | required: true, | ||||
| message: '请选择查询策略', | |||||
| message: '请输入数据集处理文件路径', | |||||
| }, | }, | ||||
| ]} | ]} | ||||
| > | > | ||||
| <Select placeholder="请选择查询策略" options={queryStrategies} showSearch allowClear /> | |||||
| <Input placeholder="请输入数据集处理文件路径" maxLength={64} showCount allowClear /> | |||||
| </Form.Item> | </Form.Item> | ||||
| </Col> | </Col> | ||||
| </Row> | </Row> | ||||
| @@ -151,16 +292,16 @@ function ExecuteConfig() { | |||||
| <Row gutter={8}> | <Row gutter={8}> | ||||
| <Col span={10}> | <Col span={10}> | ||||
| <Form.Item | <Form.Item | ||||
| label="试验次数" | |||||
| name="num_of_experiment" | |||||
| label="数据集类名" | |||||
| name="dataset_class_name" | |||||
| rules={[ | rules={[ | ||||
| { | { | ||||
| required: true, | required: true, | ||||
| message: '请输入试验次数', | |||||
| message: '请输入数据集类名', | |||||
| }, | }, | ||||
| ]} | ]} | ||||
| > | > | ||||
| <InputNumber placeholder="请输入试验次数" min={0} precision={0} /> | |||||
| <Input placeholder="请输入数据集类名" maxLength={64} showCount allowClear /> | |||||
| </Form.Item> | </Form.Item> | ||||
| </Col> | </Col> | ||||
| </Row> | </Row> | ||||
| @@ -168,16 +309,16 @@ function ExecuteConfig() { | |||||
| <Row gutter={8}> | <Row gutter={8}> | ||||
| <Col span={10}> | <Col span={10}> | ||||
| <Form.Item | <Form.Item | ||||
| label="指标" | |||||
| name="performance_metric" | |||||
| label="数据量" | |||||
| name="data_size" | |||||
| rules={[ | rules={[ | ||||
| { | { | ||||
| required: true, | required: true, | ||||
| message: '请选择指标', | |||||
| message: '请输入数据量', | |||||
| }, | }, | ||||
| ]} | ]} | ||||
| > | > | ||||
| <Select placeholder="请选择指标" options={performanceMetrics} showSearch allowClear /> | |||||
| <InputNumber placeholder="请输入数据量" min={0} precision={0} /> | |||||
| </Form.Item> | </Form.Item> | ||||
| </Col> | </Col> | ||||
| </Row> | </Row> | ||||
| @@ -185,21 +326,20 @@ function ExecuteConfig() { | |||||
| <Row gutter={8}> | <Row gutter={8}> | ||||
| <Col span={10}> | <Col span={10}> | ||||
| <Form.Item | <Form.Item | ||||
| label="数据集" | |||||
| name="dataset" | |||||
| label="镜像" | |||||
| name="image" | |||||
| rules={[ | rules={[ | ||||
| { | { | ||||
| validator: requiredValidator, | validator: requiredValidator, | ||||
| message: '请选择数据集', | |||||
| message: '请选择镜像', | |||||
| }, | }, | ||||
| ]} | ]} | ||||
| required | required | ||||
| > | > | ||||
| <ResourceSelect | <ResourceSelect | ||||
| type={ResourceSelectorType.Dataset} | |||||
| placeholder="请选择数据集" | |||||
| type={ResourceSelectorType.Mirror} | |||||
| placeholder="请选择镜像" | |||||
| canInput={false} | canInput={false} | ||||
| size="large" | |||||
| /> | /> | ||||
| </Form.Item> | </Form.Item> | ||||
| </Col> | </Col> | ||||
| @@ -208,16 +348,144 @@ function ExecuteConfig() { | |||||
| <Row gutter={8}> | <Row gutter={8}> | ||||
| <Col span={10}> | <Col span={10}> | ||||
| <Form.Item | <Form.Item | ||||
| label="预测目标列" | |||||
| name="target_columns" | |||||
| label="资源规格" | |||||
| name="computing_resource_id" | |||||
| rules={[ | |||||
| { | |||||
| required: true, | |||||
| message: '请选择资源规格', | |||||
| }, | |||||
| ]} | |||||
| > | |||||
| <ParameterSelect dataType="resource" placeholder="请选择资源规格" /> | |||||
| </Form.Item> | |||||
| </Col> | |||||
| </Row> | |||||
| <Row gutter={8}> | |||||
| <Col span={10}> | |||||
| <Form.Item label="是否随机打乱" name="shuffle" valuePropName="checked"> | |||||
| <Switch /> | |||||
| </Form.Item> | |||||
| </Col> | |||||
| </Row> | |||||
| <Row gutter={8}> | |||||
| <Col span={10}> | |||||
| <Form.Item | |||||
| label="训练集数据量" | |||||
| name="train_size" | |||||
| rules={[ | |||||
| { | |||||
| required: true, | |||||
| message: '请输入训练集数据量', | |||||
| }, | |||||
| ]} | |||||
| > | |||||
| <InputNumber placeholder="请输入训练集数据量" min={0} precision={0} /> | |||||
| </Form.Item> | |||||
| </Col> | |||||
| </Row> | |||||
| <Row gutter={8}> | |||||
| <Col span={10}> | |||||
| <Form.Item | |||||
| label="初始训练数据量" | |||||
| name="ninitial" | |||||
| rules={[ | |||||
| { | |||||
| required: true, | |||||
| message: '请输入初始训练数据量', | |||||
| }, | |||||
| ]} | |||||
| > | |||||
| <InputNumber placeholder="请输入初始训练数据量" min={0} precision={0} /> | |||||
| </Form.Item> | |||||
| </Col> | |||||
| </Row> | |||||
| <Row gutter={8}> | |||||
| <Col span={10}> | |||||
| <Form.Item | |||||
| label="查询次数" | |||||
| name="nqueries" | |||||
| rules={[ | |||||
| { | |||||
| required: true, | |||||
| message: '请输入查询次数量', | |||||
| }, | |||||
| ]} | |||||
| > | |||||
| <InputNumber placeholder="请输入查询次数" min={0} precision={0} /> | |||||
| </Form.Item> | |||||
| </Col> | |||||
| </Row> | |||||
| <Row gutter={8}> | |||||
| <Col span={10}> | |||||
| <Form.Item | |||||
| label="每次查询数据量" | |||||
| name="ninstances" | |||||
| rules={[ | |||||
| { | |||||
| required: true, | |||||
| message: '请输入每次查询数据量', | |||||
| }, | |||||
| ]} | |||||
| > | |||||
| <InputNumber placeholder="请输入每次查询数据量" min={0} precision={0} /> | |||||
| </Form.Item> | |||||
| </Col> | |||||
| </Row> | |||||
| <Row gutter={8}> | |||||
| <Col span={10}> | |||||
| <Form.Item | |||||
| label="查询策略" | |||||
| name="query_strategy" | |||||
| rules={[ | |||||
| { | |||||
| required: true, | |||||
| message: '请选择查询策略', | |||||
| }, | |||||
| ]} | |||||
| > | |||||
| <Select placeholder="请选择查询策略" options={queryStrategies} showSearch allowClear /> | |||||
| </Form.Item> | |||||
| </Col> | |||||
| </Row> | |||||
| <Row gutter={8}> | |||||
| <Col span={10}> | |||||
| <Form.Item | |||||
| label="轮数" | |||||
| name="ncheckpoint" | |||||
| rules={[ | |||||
| { | |||||
| required: true, | |||||
| message: '请输入轮数', | |||||
| }, | |||||
| ]} | |||||
| tooltip="多少轮查询保存一次模型参数" | |||||
| > | |||||
| <InputNumber placeholder="请输入轮数" min={0} precision={0} /> | |||||
| </Form.Item> | |||||
| </Col> | |||||
| </Row> | |||||
| <Row gutter={8}> | |||||
| <Col span={10}> | |||||
| <Form.Item | |||||
| label="batch_size" | |||||
| name="batch_size" | |||||
| rules={[ | rules={[ | ||||
| { | { | ||||
| required: true, | required: true, | ||||
| message: '请输入预测目标列', | |||||
| message: '请输入 batch_size', | |||||
| }, | }, | ||||
| ]} | ]} | ||||
| > | > | ||||
| <Input placeholder="请输入预测目标列" maxLength={256} showCount allowClear /> | |||||
| <InputNumber placeholder="请输入 batch_size" min={0} precision={0} /> | |||||
| </Form.Item> | </Form.Item> | ||||
| </Col> | </Col> | ||||
| </Row> | </Row> | ||||
| @@ -225,16 +493,16 @@ function ExecuteConfig() { | |||||
| <Row gutter={8}> | <Row gutter={8}> | ||||
| <Col span={10}> | <Col span={10}> | ||||
| <Form.Item | <Form.Item | ||||
| label="测试集比率" | |||||
| name="test_ratio" | |||||
| label="epochs" | |||||
| name="epochs" | |||||
| rules={[ | rules={[ | ||||
| { | { | ||||
| required: true, | required: true, | ||||
| message: '请输入测试集比率', | |||||
| message: '请输入epochs', | |||||
| }, | }, | ||||
| ]} | ]} | ||||
| > | > | ||||
| <InputNumber placeholder="请输入测试集比率" min={0} /> | |||||
| <InputNumber placeholder="请输入epochs" min={0} precision={0} /> | |||||
| </Form.Item> | </Form.Item> | ||||
| </Col> | </Col> | ||||
| </Row> | </Row> | ||||
| @@ -242,16 +510,16 @@ function ExecuteConfig() { | |||||
| <Row gutter={8}> | <Row gutter={8}> | ||||
| <Col span={10}> | <Col span={10}> | ||||
| <Form.Item | <Form.Item | ||||
| label="初始标记数据比率" | |||||
| name="initial_label_rate" | |||||
| label="学习率" | |||||
| name="lr" | |||||
| rules={[ | rules={[ | ||||
| { | { | ||||
| required: true, | required: true, | ||||
| message: '请输入初始标记数据比率', | |||||
| message: '请输入学习率', | |||||
| }, | }, | ||||
| ]} | ]} | ||||
| > | > | ||||
| <InputNumber placeholder="请输入初始标记数据比率" min={0} /> | |||||
| <InputNumber placeholder="请输入学习率" min={0} /> | |||||
| </Form.Item> | </Form.Item> | ||||
| </Col> | </Col> | ||||
| </Row> | </Row> | ||||
| @@ -26,72 +26,67 @@ export const classifierAlgorithms = [ | |||||
| }, | }, | ||||
| ]; | ]; | ||||
| export enum StoppingCriterionsType { | |||||
| NumOfQueries = 'num_of_queries', | |||||
| PercentOfUnlabel = 'percent_of_unlabel', | |||||
| TimeLimit = 'time_limit', | |||||
| } | |||||
| // 停止判则 | |||||
| export const stoppingCriterions = [ | |||||
| // 回归算法 | |||||
| export const regressorAlgorithms = [ | |||||
| { | { | ||||
| label: 'num_of_queries(查询次数)', | |||||
| value: 'num_of_queries', | |||||
| label: 'bayesian_ridge(岭回归)', | |||||
| value: 'bayesian_ridge', | |||||
| }, | }, | ||||
| { | { | ||||
| label: 'percent_of_unlabel(未标记样本比例)', | |||||
| value: 'percent_of_unlabel', | |||||
| label: 'ARD_regression(自动相关性确定回归)', | |||||
| value: 'ARD_regression', | |||||
| }, | }, | ||||
| { | { | ||||
| label: 'time_limit(时间限制)', | |||||
| value: 'time_limit', | |||||
| }, | |||||
| label: 'gaussian_process(高斯回归)', | |||||
| value: 'gaussian_process', | |||||
| } | |||||
| ]; | ]; | ||||
| // 查询策略 | |||||
| export const queryStrategies = [ | |||||
| { | |||||
| label: 'Uncertainty(不确定性)', | |||||
| value: 'Uncertainty', | |||||
| }, | |||||
| // 框架类型 | |||||
| export enum FrameworkType { | |||||
| Sklearn = 'sklearn', | |||||
| Keras = 'keras', | |||||
| Pytorch = 'pytorch', | |||||
| } | |||||
| // 框架类型选项 | |||||
| export const frameworkTypeOptions = [ | |||||
| { | { | ||||
| label: 'QBC(委员会查询)', | |||||
| value: 'QBC', | |||||
| label: FrameworkType.Sklearn, | |||||
| value: FrameworkType.Sklearn, | |||||
| }, | }, | ||||
| { | { | ||||
| label: 'Random(随机)', | |||||
| value: 'Random', | |||||
| label: FrameworkType.Keras, | |||||
| value: FrameworkType.Keras, | |||||
| }, | }, | ||||
| { | { | ||||
| label: 'GraphDensity(图密度)', | |||||
| value: 'GraphDensity', | |||||
| label: FrameworkType.Pytorch, | |||||
| value: FrameworkType.Pytorch, | |||||
| }, | }, | ||||
| ]; | ]; | ||||
| // 指标 | |||||
| export const performanceMetrics = [ | |||||
| { | |||||
| label: 'accuracy_score', | |||||
| value: 'accuracy_score', | |||||
| }, | |||||
| // 查询策略 | |||||
| export const queryStrategies = [ | |||||
| { | { | ||||
| label: 'roc_auc_score', | |||||
| value: 'roc_auc_score', | |||||
| label: 'uncertainty_sampling', | |||||
| value: 'uncertainty_sampling', | |||||
| }, | }, | ||||
| { | { | ||||
| label: 'get_fps_tps_thresholds', | |||||
| value: 'get_fps_tps_thresholds', | |||||
| label: 'uncertainty_batch_sampling', | |||||
| value: 'uncertainty_batch_sampling', | |||||
| }, | }, | ||||
| { | { | ||||
| label: 'hamming_loss', | |||||
| value: 'hamming_loss', | |||||
| label: 'max_std_sampling', | |||||
| value: 'max_std_sampling', | |||||
| }, | }, | ||||
| { | { | ||||
| label: 'one_error', | |||||
| value: 'one_error', | |||||
| label: 'expected_improvement', | |||||
| value: 'expected_improvement', | |||||
| }, | }, | ||||
| { | { | ||||
| label: 'coverage_error', | |||||
| value: 'coverage_error', | |||||
| }, | |||||
| label: 'upper_confidence_bound', | |||||
| value: 'upper_confidence_bound', | |||||
| } | |||||
| ]; | ]; | ||||
| @@ -1,5 +1,14 @@ | |||||
| /* | |||||
| * @Author: error: error: git config user.name & please set dead value or install git && error: git config user.email & please set dead value or install git & please set dead value or install git | |||||
| * @Date: 2025-04-18 08:40:03 | |||||
| * @LastEditors: error: error: git config user.name & please set dead value or install git && error: git config user.email & please set dead value or install git & please set dead value or install git | |||||
| * @LastEditTime: 2025-04-18 11:30:21 | |||||
| * @FilePath: \ci4s\react-ui\src\pages\ActiveLearn\types.ts | |||||
| * @Description: 这是默认设置,请设置`customMade`, 打开koroFileHeader查看配置 进行设置: https://github.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE | |||||
| */ | |||||
| import { type ParameterInputObject } from '@/components/ResourceSelect'; | import { type ParameterInputObject } from '@/components/ResourceSelect'; | ||||
| import { type NodeStatus } from '@/types'; | import { type NodeStatus } from '@/types'; | ||||
| import { AutoMLTaskType } from '@/enums'; | |||||
| // 操作类型 | // 操作类型 | ||||
| export enum OperationType { | export enum OperationType { | ||||
| @@ -11,18 +20,33 @@ export enum OperationType { | |||||
| export type FormData = { | export type FormData = { | ||||
| name: string; // 实验名称 | name: string; // 实验名称 | ||||
| description: string; // 实验描述 | description: string; // 实验描述 | ||||
| task_type: AutoMLTaskType; // 任务类型 | |||||
| framework_type: string; // 框架类型 | |||||
| code_config: ParameterInputObject; // 代码配置 | |||||
| model?: ParameterInputObject; // 模型 | |||||
| model_py?: string; // 模型文件路径 | |||||
| model_class_name?: string; // 模型类名称 | |||||
| loss_py?: string; // loss文件路径 | |||||
| loss_class_name?: string; // loss类名 | |||||
| classifier_alg?: string; // 分类算法 | |||||
| regressor_alg?: string; // 回归算法 | |||||
| dataset: ParameterInputObject; // 数据集 | dataset: ParameterInputObject; // 数据集 | ||||
| classifier_type: string; // 分类算法 | |||||
| stopping_criterion: string; // 停止判则 | |||||
| dataset_py: string; // dataset处理文件路径 | |||||
| dataset_class_name: string; // dataset类名 | |||||
| data_size: number; // 数据量 | |||||
| train_size: number; // 训练集数据量 | |||||
| ninitial: number; // 初始训练数据量 | |||||
| nqueries: number; // 查询次数 | |||||
| ninstances: number; // 每次查询数据量 | |||||
| computing_resource_id: number; // 资源规格 | |||||
| image: ParameterInputObject; // 镜像 | |||||
| shuffle: boolean; // 是否随机打乱 | |||||
| query_strategy: string; // 查询策略 | query_strategy: string; // 查询策略 | ||||
| num_of_experiment: number; // 试验次数 | |||||
| performance_metric: string; // 指标 | |||||
| target_columns: string; // 预测目标列 | |||||
| test_ratio: number; // 测试集比率 | |||||
| initial_label_rate: number; // 初始标记数据比率 | |||||
| num_of_queries?: number; // 查询次数 | |||||
| percent_of_unlabel: number; // 未标记比例 | |||||
| time_limit: number; // 时间限制 | |||||
| ncheckpoint: number; // 多少轮查询保存一次模型参数 | |||||
| batch_size: number; // batch_size | |||||
| epochs: number; // epochs | |||||
| lr: number; // 学习率 | |||||
| }; | }; | ||||
| // 主动学习 | // 主动学习 | ||||
| @@ -63,7 +63,7 @@ function ExperimentList({ type }: ExperimentListProps) { | |||||
| const params: Record<string, any> = { | const params: Record<string, any> = { | ||||
| page: pagination.current! - 1, | page: pagination.current! - 1, | ||||
| size: pagination.pageSize, | size: pagination.pageSize, | ||||
| ml_name: searchText || undefined, | |||||
| [config.nameProperty]: searchText || undefined, | |||||
| }; | }; | ||||
| const request = config.getListReq; | const request = config.getListReq; | ||||
| const [res] = await to(request(params)); | const [res] = await to(request(params)); | ||||
| @@ -248,7 +248,7 @@ function ExperimentList({ type }: ExperimentListProps) { | |||||
| { | { | ||||
| title: '实验名称', | title: '实验名称', | ||||
| dataIndex: config.nameProperty, | dataIndex: config.nameProperty, | ||||
| key: 'ml_name', | |||||
| key: 'name', | |||||
| width: '16%', | width: '16%', | ||||
| render: tableCellRender(false, TableCellValueType.Link, { | render: tableCellRender(false, TableCellValueType.Link, { | ||||
| onClick: gotoDetail, | onClick: gotoDetail, | ||||
| @@ -257,7 +257,7 @@ function ExperimentList({ type }: ExperimentListProps) { | |||||
| { | { | ||||
| title: '实验描述', | title: '实验描述', | ||||
| dataIndex: config.descProperty, | dataIndex: config.descProperty, | ||||
| key: 'ml_description', | |||||
| key: 'description', | |||||
| render: tableCellRender(true), | render: tableCellRender(true), | ||||
| }, | }, | ||||
| { | { | ||||