| @@ -4,6 +4,7 @@ | |||
| * @Description: 创建实验 | |||
| */ | |||
| import PageTitle from '@/components/PageTitle'; | |||
| import { AutoMLTaskType } from '@/enums'; | |||
| import { | |||
| addActiveLearnReq, | |||
| getActiveLearnInfoReq, | |||
| @@ -98,7 +99,7 @@ function CreateActiveLearn() { | |||
| return ( | |||
| <div className={styles['create-hyperparameter']}> | |||
| <PageTitle title={title} tooltip="仅支持二分类及多分类任务"></PageTitle> | |||
| <PageTitle title={title}></PageTitle> | |||
| <div className={styles['create-hyperparameter__content']}> | |||
| <div> | |||
| <Form | |||
| @@ -111,8 +112,8 @@ function CreateActiveLearn() { | |||
| autoComplete="off" | |||
| scrollToFirstError | |||
| initialValues={{ | |||
| test_ratio: 0.3, | |||
| initial_label_rate: 0.1, | |||
| task_type: AutoMLTaskType.Classification, | |||
| shuffle: false, | |||
| }} | |||
| > | |||
| <BasicConfig /> | |||
| @@ -1,16 +1,26 @@ | |||
| import ConfigInfo, { type BasicInfoData } from '@/components/ConfigInfo'; | |||
| import { AutoMLTaskType, autoMLTaskTypeOptions } from '@/enums'; | |||
| import { useComputingResource } from '@/hooks/useComputingResource'; | |||
| import { | |||
| classifierAlgorithms, | |||
| performanceMetrics, | |||
| FrameworkType, | |||
| frameworkTypeOptions, | |||
| queryStrategies, | |||
| stoppingCriterions, | |||
| StoppingCriterionsType, | |||
| regressorAlgorithms, | |||
| } from '@/pages/ActiveLearn/components/CreateForm/utils'; | |||
| import { ActiveLearnData } from '@/pages/ActiveLearn/types'; | |||
| import { experimentStatusInfo } from '@/pages/Experiment/status'; | |||
| import { type NodeStatus } from '@/types'; | |||
| import { elapsedTime } from '@/utils/date'; | |||
| import { formatDataset, formatDate, formatEnum } from '@/utils/format'; | |||
| import { | |||
| formatBoolean, | |||
| formatCodeConfig, | |||
| formatDataset, | |||
| formatDate, | |||
| formatEnum, | |||
| formatMirror, | |||
| formatModel, | |||
| } from '@/utils/format'; | |||
| import { Flex } from 'antd'; | |||
| import classNames from 'classnames'; | |||
| import { useMemo } from 'react'; | |||
| @@ -24,6 +34,7 @@ type BasicInfoProps = { | |||
| }; | |||
| function BasicInfo({ info, className, runStatus, isInstance = false }: BasicInfoProps) { | |||
| const getResourceDescription = useComputingResource()[1]; | |||
| const basicDatas: BasicInfoData[] = useMemo(() => { | |||
| if (!info) { | |||
| return []; | |||
| @@ -59,47 +70,68 @@ function BasicInfo({ info, className, runStatus, isInstance = false }: BasicInfo | |||
| if (!info) { | |||
| return []; | |||
| } | |||
| let stopping_criterion_params_title = ''; | |||
| let stopping_criterion_params_value: number | undefined = undefined; | |||
| if (info.stopping_criterion === StoppingCriterionsType.NumOfQueries) { | |||
| stopping_criterion_params_title = '查询次数'; | |||
| stopping_criterion_params_value = info.num_of_queries; | |||
| } else if (info.stopping_criterion === StoppingCriterionsType.PercentOfUnlabel) { | |||
| stopping_criterion_params_title = '未标记比例'; | |||
| stopping_criterion_params_value = info.percent_of_unlabel; | |||
| } else { | |||
| stopping_criterion_params_title = '时间限制'; | |||
| stopping_criterion_params_value = info.time_limit; | |||
| } | |||
| return [ | |||
| const modelInfo = [ | |||
| { | |||
| label: '分类算法', | |||
| value: info.classifier_type, | |||
| format: formatEnum(classifierAlgorithms), | |||
| label: '模型', | |||
| value: info.model, | |||
| format: formatModel, | |||
| }, | |||
| { | |||
| label: '停止判则', | |||
| value: info.stopping_criterion, | |||
| format: formatEnum(stoppingCriterions), | |||
| label: '模型文件路径', | |||
| value: info.model_py, | |||
| }, | |||
| { | |||
| label: stopping_criterion_params_title, | |||
| value: stopping_criterion_params_value, | |||
| label: '模型类名称', | |||
| value: info.model_class_name, | |||
| }, | |||
| ]; | |||
| const lossInfo = [ | |||
| { | |||
| label: '查询策略', | |||
| value: info.query_strategy, | |||
| format: formatEnum(queryStrategies), | |||
| label: 'loss文件路径', | |||
| value: info.loss_py, | |||
| }, | |||
| { | |||
| label: '试验次数', | |||
| value: info.num_of_experiment, | |||
| label: 'loss类名', | |||
| value: info.loss_class_name, | |||
| }, | |||
| ]; | |||
| const algorithmInfo = [ | |||
| { | |||
| label: '指标', | |||
| value: info.performance_metric, | |||
| format: formatEnum(performanceMetrics), | |||
| label: info.task_type === AutoMLTaskType.Regression ? '回归算法' : '分类算法', | |||
| value: | |||
| info.task_type === AutoMLTaskType.Regression ? info.regressor_alg : info.classifier_alg, | |||
| format: formatEnum( | |||
| info.task_type === AutoMLTaskType.Regression ? regressorAlgorithms : classifierAlgorithms, | |||
| ), | |||
| }, | |||
| ]; | |||
| const diffInfo = | |||
| info.framework_type === FrameworkType.Pytorch | |||
| ? [...modelInfo, ...lossInfo] | |||
| : info.framework_type === FrameworkType.Keras | |||
| ? modelInfo | |||
| : algorithmInfo; | |||
| return [ | |||
| { | |||
| label: '任务类型', | |||
| value: info.task_type, | |||
| format: formatEnum(autoMLTaskTypeOptions), | |||
| }, | |||
| { | |||
| label: '框架类型', | |||
| value: info.framework_type, | |||
| format: formatEnum(frameworkTypeOptions), | |||
| }, | |||
| ...diffInfo, | |||
| { | |||
| label: '代码配置', | |||
| value: info.code_config, | |||
| format: formatCodeConfig, | |||
| }, | |||
| { | |||
| label: '数据集', | |||
| @@ -107,19 +139,71 @@ function BasicInfo({ info, className, runStatus, isInstance = false }: BasicInfo | |||
| format: formatDataset, | |||
| }, | |||
| { | |||
| label: '预测目标列', | |||
| value: info.target_columns, | |||
| label: '数据集处理文件路径', | |||
| value: info.dataset_py, | |||
| }, | |||
| { | |||
| label: '数据集类名', | |||
| value: info.dataset_class_name, | |||
| }, | |||
| { | |||
| label: '镜像', | |||
| value: info.image, | |||
| format: formatMirror, | |||
| }, | |||
| { | |||
| label: '测试集比率', | |||
| value: info.test_ratio, | |||
| label: '资源规格', | |||
| value: info.computing_resource_id, | |||
| format: getResourceDescription, | |||
| }, | |||
| { | |||
| label: '初始标记数据比率', | |||
| value: info.initial_label_rate, | |||
| label: '是否打乱', | |||
| value: info.shuffle, | |||
| format: formatBoolean, | |||
| }, | |||
| { | |||
| label: '数据量', | |||
| value: info.data_size, | |||
| }, | |||
| { | |||
| label: '训练集数据量', | |||
| value: info.train_size, | |||
| }, | |||
| { | |||
| label: '初始训练数据量', | |||
| value: info.ninitial, | |||
| }, | |||
| { | |||
| label: '查询次数', | |||
| value: info.nqueries, | |||
| }, | |||
| { | |||
| label: '每次查询数据量', | |||
| value: info.ninstances, | |||
| }, | |||
| { | |||
| label: '查询策略', | |||
| value: info.query_strategy, | |||
| format: formatEnum(queryStrategies), | |||
| }, | |||
| { | |||
| label: '轮数', | |||
| value: info.ncheckpoint, | |||
| }, | |||
| { | |||
| label: 'batch_size', | |||
| value: info.batch_size, | |||
| }, | |||
| { | |||
| label: 'epochs', | |||
| value: info.epochs, | |||
| }, | |||
| { | |||
| label: '学习率', | |||
| value: info.lr, | |||
| }, | |||
| ]; | |||
| }, [info]); | |||
| }, [info, getResourceDescription]); | |||
| const instanceDatas = useMemo(() => { | |||
| if (!runStatus) { | |||
| @@ -1,18 +1,22 @@ | |||
| import CodeSelect from '@/components/CodeSelect'; | |||
| import ParameterSelect from '@/components/ParameterSelect'; | |||
| import ResourceSelect, { | |||
| requiredValidator, | |||
| ResourceSelectorType, | |||
| } from '@/components/ResourceSelect'; | |||
| import SubAreaTitle from '@/components/SubAreaTitle'; | |||
| import { Col, Form, Input, InputNumber, Row, Select } from 'antd'; | |||
| import { AutoMLTaskType, autoMLTaskTypeOptions } from '@/enums'; | |||
| import { Col, Form, Input, InputNumber, Radio, Row, Select, Switch } from 'antd'; | |||
| import { | |||
| classifierAlgorithms, | |||
| performanceMetrics, | |||
| FrameworkType, | |||
| frameworkTypeOptions, | |||
| queryStrategies, | |||
| stoppingCriterions, | |||
| StoppingCriterionsType, | |||
| regressorAlgorithms, | |||
| } from './utils'; | |||
| function ExecuteConfig() { | |||
| const form = Form.useFormInstance(); | |||
| return ( | |||
| <> | |||
| <SubAreaTitle | |||
| @@ -24,21 +28,14 @@ function ExecuteConfig() { | |||
| <Row gutter={8}> | |||
| <Col span={10}> | |||
| <Form.Item | |||
| label="分类算法" | |||
| name="classifier_type" | |||
| rules={[ | |||
| { | |||
| required: true, | |||
| message: '请选择分类算法', | |||
| }, | |||
| ]} | |||
| label="任务类型" | |||
| name="task_type" | |||
| rules={[{ required: true, message: '请选择任务类型' }]} | |||
| > | |||
| <Select | |||
| placeholder="请选择分类算法" | |||
| options={classifierAlgorithms} | |||
| showSearch | |||
| allowClear | |||
| /> | |||
| <Radio.Group | |||
| options={autoMLTaskTypeOptions} | |||
| onChange={() => form.resetFields(['metrics'])} | |||
| ></Radio.Group> | |||
| </Form.Item> | |||
| </Col> | |||
| </Row> | |||
| @@ -46,18 +43,18 @@ function ExecuteConfig() { | |||
| <Row gutter={8}> | |||
| <Col span={10}> | |||
| <Form.Item | |||
| label="停止判则" | |||
| name="stopping_criterion" | |||
| label="框架类型" | |||
| name="framework_type" | |||
| rules={[ | |||
| { | |||
| required: true, | |||
| message: '请选择停止判则', | |||
| message: '请选择框架类型', | |||
| }, | |||
| ]} | |||
| > | |||
| <Select | |||
| placeholder="请选择停止判则" | |||
| options={stoppingCriterions} | |||
| placeholder="请选择框架类型" | |||
| options={frameworkTypeOptions} | |||
| showSearch | |||
| allowClear | |||
| /> | |||
| @@ -65,66 +62,169 @@ function ExecuteConfig() { | |||
| </Col> | |||
| </Row> | |||
| <Form.Item dependencies={['stopping_criterion']} noStyle> | |||
| <Form.Item dependencies={['task_type', 'framework_type']} noStyle> | |||
| {({ getFieldValue }) => { | |||
| const stopping_criterion = getFieldValue('stopping_criterion'); | |||
| if (stopping_criterion === StoppingCriterionsType.NumOfQueries) { | |||
| return ( | |||
| <Row gutter={8}> | |||
| <Col span={10}> | |||
| <Form.Item | |||
| label="查询次数" | |||
| name="num_of_queries" | |||
| rules={[ | |||
| { | |||
| required: true, | |||
| message: '请输入查询次数', | |||
| }, | |||
| ]} | |||
| > | |||
| <InputNumber placeholder="请输入查询次数" min={0} precision={0} /> | |||
| </Form.Item> | |||
| </Col> | |||
| </Row> | |||
| ); | |||
| } else if (stopping_criterion === StoppingCriterionsType.PercentOfUnlabel) { | |||
| const taskType = getFieldValue('task_type'); | |||
| const frameworkType = getFieldValue('framework_type'); | |||
| if (frameworkType === FrameworkType.Keras || frameworkType === FrameworkType.Pytorch) { | |||
| return ( | |||
| <Row gutter={8}> | |||
| <Col span={10}> | |||
| <Form.Item | |||
| label="未标记比例" | |||
| name="percent_of_unlabel" | |||
| rules={[ | |||
| { | |||
| required: true, | |||
| message: '请输入未标记比例', | |||
| }, | |||
| ]} | |||
| > | |||
| <InputNumber placeholder="请输入未标记比例" min={0} /> | |||
| </Form.Item> | |||
| </Col> | |||
| </Row> | |||
| ); | |||
| } else if (stopping_criterion === StoppingCriterionsType.TimeLimit) { | |||
| return ( | |||
| <Row gutter={8}> | |||
| <Col span={10}> | |||
| <Form.Item | |||
| label="时间限制" | |||
| name="time_limit" | |||
| rules={[ | |||
| { | |||
| required: true, | |||
| message: '请输入时间限制', | |||
| }, | |||
| ]} | |||
| > | |||
| <InputNumber placeholder="请输入时间限制" min={0} /> | |||
| </Form.Item> | |||
| </Col> | |||
| </Row> | |||
| <> | |||
| <Row gutter={8}> | |||
| <Col span={10}> | |||
| <Form.Item | |||
| label="模型" | |||
| name="model" | |||
| rules={[ | |||
| { | |||
| validator: requiredValidator, | |||
| message: '请选择模型', | |||
| }, | |||
| ]} | |||
| required | |||
| > | |||
| <ResourceSelect | |||
| type={ResourceSelectorType.Model} | |||
| placeholder="请选择模型" | |||
| canInput={false} | |||
| size="large" | |||
| /> | |||
| </Form.Item> | |||
| </Col> | |||
| </Row> | |||
| <Row gutter={8}> | |||
| <Col span={10}> | |||
| <Form.Item | |||
| label="模型文件路径" | |||
| name="model_py" | |||
| rules={[ | |||
| { | |||
| required: true, | |||
| message: '请输入模型文件路径', | |||
| }, | |||
| ]} | |||
| > | |||
| <Input placeholder="请输入模型文件路径" maxLength={64} showCount allowClear /> | |||
| </Form.Item> | |||
| </Col> | |||
| </Row> | |||
| <Row gutter={8}> | |||
| <Col span={10}> | |||
| <Form.Item | |||
| label="模型类名称" | |||
| name="model_class_name" | |||
| rules={[ | |||
| { | |||
| required: true, | |||
| message: '请输入模型类名称', | |||
| }, | |||
| ]} | |||
| > | |||
| <Input placeholder="请输入模型类名称" maxLength={64} showCount allowClear /> | |||
| </Form.Item> | |||
| </Col> | |||
| </Row> | |||
| {frameworkType === FrameworkType.Pytorch ? ( | |||
| <> | |||
| <Row gutter={8}> | |||
| <Col span={10}> | |||
| <Form.Item | |||
| label="loss 文件路径" | |||
| name="loss_py" | |||
| rules={[ | |||
| { | |||
| required: true, | |||
| message: '请输入 loss 文件路径', | |||
| }, | |||
| ]} | |||
| > | |||
| <Input | |||
| placeholder="请输入 loss 文件路径" | |||
| maxLength={64} | |||
| showCount | |||
| allowClear | |||
| /> | |||
| </Form.Item> | |||
| </Col> | |||
| </Row> | |||
| <Row gutter={8}> | |||
| <Col span={10}> | |||
| <Form.Item | |||
| label="loss类名" | |||
| name="loss_class_name" | |||
| rules={[ | |||
| { | |||
| required: true, | |||
| message: '请输入 loss 类名', | |||
| }, | |||
| ]} | |||
| > | |||
| <Input | |||
| placeholder="请输入 loss 类名" | |||
| maxLength={64} | |||
| showCount | |||
| allowClear | |||
| /> | |||
| </Form.Item> | |||
| </Col> | |||
| </Row> | |||
| </> | |||
| ) : null} | |||
| </> | |||
| ); | |||
| } else if (frameworkType === FrameworkType.Sklearn) { | |||
| if (taskType === AutoMLTaskType.Classification) { | |||
| return ( | |||
| <> | |||
| <Row gutter={8}> | |||
| <Col span={10}> | |||
| <Form.Item | |||
| label="分类算法" | |||
| name="classifier_alg" | |||
| rules={[ | |||
| { | |||
| required: true, | |||
| message: '请选择分类算法', | |||
| }, | |||
| ]} | |||
| > | |||
| <Select | |||
| placeholder="请选择分类算法" | |||
| options={classifierAlgorithms} | |||
| showSearch | |||
| allowClear | |||
| /> | |||
| </Form.Item> | |||
| </Col> | |||
| </Row> | |||
| </> | |||
| ); | |||
| } else { | |||
| return ( | |||
| <> | |||
| <Row gutter={8}> | |||
| <Col span={10}> | |||
| <Form.Item | |||
| label="回归算法" | |||
| name="regressor_alg" | |||
| rules={[ | |||
| { | |||
| required: true, | |||
| message: '请选择回归算法', | |||
| }, | |||
| ]} | |||
| > | |||
| <Select | |||
| placeholder="请选择回归算法" | |||
| options={regressorAlgorithms} | |||
| showSearch | |||
| allowClear | |||
| /> | |||
| </Form.Item> | |||
| </Col> | |||
| </Row> | |||
| </> | |||
| ); | |||
| } | |||
| } else { | |||
| return null; | |||
| } | |||
| @@ -134,16 +234,57 @@ function ExecuteConfig() { | |||
| <Row gutter={8}> | |||
| <Col span={10}> | |||
| <Form.Item | |||
| label="查询策略" | |||
| name="query_strategy" | |||
| label="代码配置" | |||
| name="code_config" | |||
| rules={[ | |||
| { | |||
| validator: requiredValidator, | |||
| message: '请选择代码配置', | |||
| }, | |||
| ]} | |||
| required | |||
| > | |||
| <CodeSelect placeholder="请选择代码配置" canInput={false} size="large" /> | |||
| </Form.Item> | |||
| </Col> | |||
| </Row> | |||
| <Row gutter={8}> | |||
| <Col span={10}> | |||
| <Form.Item | |||
| label="数据集" | |||
| name="dataset" | |||
| rules={[ | |||
| { | |||
| validator: requiredValidator, | |||
| message: '请选择数据集', | |||
| }, | |||
| ]} | |||
| required | |||
| > | |||
| <ResourceSelect | |||
| type={ResourceSelectorType.Dataset} | |||
| placeholder="请选择数据集" | |||
| canInput={false} | |||
| size="large" | |||
| /> | |||
| </Form.Item> | |||
| </Col> | |||
| </Row> | |||
| <Row gutter={8}> | |||
| <Col span={10}> | |||
| <Form.Item | |||
| label="数据集处理文件路径" | |||
| name="dataset_py" | |||
| rules={[ | |||
| { | |||
| required: true, | |||
| message: '请选择查询策略', | |||
| message: '请输入数据集处理文件路径', | |||
| }, | |||
| ]} | |||
| > | |||
| <Select placeholder="请选择查询策略" options={queryStrategies} showSearch allowClear /> | |||
| <Input placeholder="请输入数据集处理文件路径" maxLength={64} showCount allowClear /> | |||
| </Form.Item> | |||
| </Col> | |||
| </Row> | |||
| @@ -151,16 +292,16 @@ function ExecuteConfig() { | |||
| <Row gutter={8}> | |||
| <Col span={10}> | |||
| <Form.Item | |||
| label="试验次数" | |||
| name="num_of_experiment" | |||
| label="数据集类名" | |||
| name="dataset_class_name" | |||
| rules={[ | |||
| { | |||
| required: true, | |||
| message: '请输入试验次数', | |||
| message: '请输入数据集类名', | |||
| }, | |||
| ]} | |||
| > | |||
| <InputNumber placeholder="请输入试验次数" min={0} precision={0} /> | |||
| <Input placeholder="请输入数据集类名" maxLength={64} showCount allowClear /> | |||
| </Form.Item> | |||
| </Col> | |||
| </Row> | |||
| @@ -168,16 +309,16 @@ function ExecuteConfig() { | |||
| <Row gutter={8}> | |||
| <Col span={10}> | |||
| <Form.Item | |||
| label="指标" | |||
| name="performance_metric" | |||
| label="数据量" | |||
| name="data_size" | |||
| rules={[ | |||
| { | |||
| required: true, | |||
| message: '请选择指标', | |||
| message: '请输入数据量', | |||
| }, | |||
| ]} | |||
| > | |||
| <Select placeholder="请选择指标" options={performanceMetrics} showSearch allowClear /> | |||
| <InputNumber placeholder="请输入数据量" min={0} precision={0} /> | |||
| </Form.Item> | |||
| </Col> | |||
| </Row> | |||
| @@ -185,21 +326,20 @@ function ExecuteConfig() { | |||
| <Row gutter={8}> | |||
| <Col span={10}> | |||
| <Form.Item | |||
| label="数据集" | |||
| name="dataset" | |||
| label="镜像" | |||
| name="image" | |||
| rules={[ | |||
| { | |||
| validator: requiredValidator, | |||
| message: '请选择数据集', | |||
| message: '请选择镜像', | |||
| }, | |||
| ]} | |||
| required | |||
| > | |||
| <ResourceSelect | |||
| type={ResourceSelectorType.Dataset} | |||
| placeholder="请选择数据集" | |||
| type={ResourceSelectorType.Mirror} | |||
| placeholder="请选择镜像" | |||
| canInput={false} | |||
| size="large" | |||
| /> | |||
| </Form.Item> | |||
| </Col> | |||
| @@ -208,16 +348,144 @@ function ExecuteConfig() { | |||
| <Row gutter={8}> | |||
| <Col span={10}> | |||
| <Form.Item | |||
| label="预测目标列" | |||
| name="target_columns" | |||
| label="资源规格" | |||
| name="computing_resource_id" | |||
| rules={[ | |||
| { | |||
| required: true, | |||
| message: '请选择资源规格', | |||
| }, | |||
| ]} | |||
| > | |||
| <ParameterSelect dataType="resource" placeholder="请选择资源规格" /> | |||
| </Form.Item> | |||
| </Col> | |||
| </Row> | |||
| <Row gutter={8}> | |||
| <Col span={10}> | |||
| <Form.Item label="是否随机打乱" name="shuffle" valuePropName="checked"> | |||
| <Switch /> | |||
| </Form.Item> | |||
| </Col> | |||
| </Row> | |||
| <Row gutter={8}> | |||
| <Col span={10}> | |||
| <Form.Item | |||
| label="训练集数据量" | |||
| name="train_size" | |||
| rules={[ | |||
| { | |||
| required: true, | |||
| message: '请输入训练集数据量', | |||
| }, | |||
| ]} | |||
| > | |||
| <InputNumber placeholder="请输入训练集数据量" min={0} precision={0} /> | |||
| </Form.Item> | |||
| </Col> | |||
| </Row> | |||
| <Row gutter={8}> | |||
| <Col span={10}> | |||
| <Form.Item | |||
| label="初始训练数据量" | |||
| name="ninitial" | |||
| rules={[ | |||
| { | |||
| required: true, | |||
| message: '请输入初始训练数据量', | |||
| }, | |||
| ]} | |||
| > | |||
| <InputNumber placeholder="请输入初始训练数据量" min={0} precision={0} /> | |||
| </Form.Item> | |||
| </Col> | |||
| </Row> | |||
| <Row gutter={8}> | |||
| <Col span={10}> | |||
| <Form.Item | |||
| label="查询次数" | |||
| name="nqueries" | |||
| rules={[ | |||
| { | |||
| required: true, | |||
| message: '请输入查询次数量', | |||
| }, | |||
| ]} | |||
| > | |||
| <InputNumber placeholder="请输入查询次数" min={0} precision={0} /> | |||
| </Form.Item> | |||
| </Col> | |||
| </Row> | |||
| <Row gutter={8}> | |||
| <Col span={10}> | |||
| <Form.Item | |||
| label="每次查询数据量" | |||
| name="ninstances" | |||
| rules={[ | |||
| { | |||
| required: true, | |||
| message: '请输入每次查询数据量', | |||
| }, | |||
| ]} | |||
| > | |||
| <InputNumber placeholder="请输入每次查询数据量" min={0} precision={0} /> | |||
| </Form.Item> | |||
| </Col> | |||
| </Row> | |||
| <Row gutter={8}> | |||
| <Col span={10}> | |||
| <Form.Item | |||
| label="查询策略" | |||
| name="query_strategy" | |||
| rules={[ | |||
| { | |||
| required: true, | |||
| message: '请选择查询策略', | |||
| }, | |||
| ]} | |||
| > | |||
| <Select placeholder="请选择查询策略" options={queryStrategies} showSearch allowClear /> | |||
| </Form.Item> | |||
| </Col> | |||
| </Row> | |||
| <Row gutter={8}> | |||
| <Col span={10}> | |||
| <Form.Item | |||
| label="轮数" | |||
| name="ncheckpoint" | |||
| rules={[ | |||
| { | |||
| required: true, | |||
| message: '请输入轮数', | |||
| }, | |||
| ]} | |||
| tooltip="多少轮查询保存一次模型参数" | |||
| > | |||
| <InputNumber placeholder="请输入轮数" min={0} precision={0} /> | |||
| </Form.Item> | |||
| </Col> | |||
| </Row> | |||
| <Row gutter={8}> | |||
| <Col span={10}> | |||
| <Form.Item | |||
| label="batch_size" | |||
| name="batch_size" | |||
| rules={[ | |||
| { | |||
| required: true, | |||
| message: '请输入预测目标列', | |||
| message: '请输入 batch_size', | |||
| }, | |||
| ]} | |||
| > | |||
| <Input placeholder="请输入预测目标列" maxLength={256} showCount allowClear /> | |||
| <InputNumber placeholder="请输入 batch_size" min={0} precision={0} /> | |||
| </Form.Item> | |||
| </Col> | |||
| </Row> | |||
| @@ -225,16 +493,16 @@ function ExecuteConfig() { | |||
| <Row gutter={8}> | |||
| <Col span={10}> | |||
| <Form.Item | |||
| label="测试集比率" | |||
| name="test_ratio" | |||
| label="epochs" | |||
| name="epochs" | |||
| rules={[ | |||
| { | |||
| required: true, | |||
| message: '请输入测试集比率', | |||
| message: '请输入epochs', | |||
| }, | |||
| ]} | |||
| > | |||
| <InputNumber placeholder="请输入测试集比率" min={0} /> | |||
| <InputNumber placeholder="请输入epochs" min={0} precision={0} /> | |||
| </Form.Item> | |||
| </Col> | |||
| </Row> | |||
| @@ -242,16 +510,16 @@ function ExecuteConfig() { | |||
| <Row gutter={8}> | |||
| <Col span={10}> | |||
| <Form.Item | |||
| label="初始标记数据比率" | |||
| name="initial_label_rate" | |||
| label="学习率" | |||
| name="lr" | |||
| rules={[ | |||
| { | |||
| required: true, | |||
| message: '请输入初始标记数据比率', | |||
| message: '请输入学习率', | |||
| }, | |||
| ]} | |||
| > | |||
| <InputNumber placeholder="请输入初始标记数据比率" min={0} /> | |||
| <InputNumber placeholder="请输入学习率" min={0} /> | |||
| </Form.Item> | |||
| </Col> | |||
| </Row> | |||
| @@ -26,72 +26,67 @@ export const classifierAlgorithms = [ | |||
| }, | |||
| ]; | |||
| export enum StoppingCriterionsType { | |||
| NumOfQueries = 'num_of_queries', | |||
| PercentOfUnlabel = 'percent_of_unlabel', | |||
| TimeLimit = 'time_limit', | |||
| } | |||
| // 停止判则 | |||
| export const stoppingCriterions = [ | |||
| // 回归算法 | |||
| export const regressorAlgorithms = [ | |||
| { | |||
| label: 'num_of_queries(查询次数)', | |||
| value: 'num_of_queries', | |||
| label: 'bayesian_ridge(岭回归)', | |||
| value: 'bayesian_ridge', | |||
| }, | |||
| { | |||
| label: 'percent_of_unlabel(未标记样本比例)', | |||
| value: 'percent_of_unlabel', | |||
| label: 'ARD_regression(自动相关性确定回归)', | |||
| value: 'ARD_regression', | |||
| }, | |||
| { | |||
| label: 'time_limit(时间限制)', | |||
| value: 'time_limit', | |||
| }, | |||
| label: 'gaussian_process(高斯回归)', | |||
| value: 'gaussian_process', | |||
| } | |||
| ]; | |||
| // 查询策略 | |||
| export const queryStrategies = [ | |||
| { | |||
| label: 'Uncertainty(不确定性)', | |||
| value: 'Uncertainty', | |||
| }, | |||
| // 框架类型 | |||
| export enum FrameworkType { | |||
| Sklearn = 'sklearn', | |||
| Keras = 'keras', | |||
| Pytorch = 'pytorch', | |||
| } | |||
| // 框架类型选项 | |||
| export const frameworkTypeOptions = [ | |||
| { | |||
| label: 'QBC(委员会查询)', | |||
| value: 'QBC', | |||
| label: FrameworkType.Sklearn, | |||
| value: FrameworkType.Sklearn, | |||
| }, | |||
| { | |||
| label: 'Random(随机)', | |||
| value: 'Random', | |||
| label: FrameworkType.Keras, | |||
| value: FrameworkType.Keras, | |||
| }, | |||
| { | |||
| label: 'GraphDensity(图密度)', | |||
| value: 'GraphDensity', | |||
| label: FrameworkType.Pytorch, | |||
| value: FrameworkType.Pytorch, | |||
| }, | |||
| ]; | |||
| // 指标 | |||
| export const performanceMetrics = [ | |||
| { | |||
| label: 'accuracy_score', | |||
| value: 'accuracy_score', | |||
| }, | |||
| // 查询策略 | |||
| export const queryStrategies = [ | |||
| { | |||
| label: 'roc_auc_score', | |||
| value: 'roc_auc_score', | |||
| label: 'uncertainty_sampling', | |||
| value: 'uncertainty_sampling', | |||
| }, | |||
| { | |||
| label: 'get_fps_tps_thresholds', | |||
| value: 'get_fps_tps_thresholds', | |||
| label: 'uncertainty_batch_sampling', | |||
| value: 'uncertainty_batch_sampling', | |||
| }, | |||
| { | |||
| label: 'hamming_loss', | |||
| value: 'hamming_loss', | |||
| label: 'max_std_sampling', | |||
| value: 'max_std_sampling', | |||
| }, | |||
| { | |||
| label: 'one_error', | |||
| value: 'one_error', | |||
| label: 'expected_improvement', | |||
| value: 'expected_improvement', | |||
| }, | |||
| { | |||
| label: 'coverage_error', | |||
| value: 'coverage_error', | |||
| }, | |||
| label: 'upper_confidence_bound', | |||
| value: 'upper_confidence_bound', | |||
| } | |||
| ]; | |||
| @@ -1,5 +1,14 @@ | |||
| /* | |||
| * @Author: error: error: git config user.name & please set dead value or install git && error: git config user.email & please set dead value or install git & please set dead value or install git | |||
| * @Date: 2025-04-18 08:40:03 | |||
| * @LastEditors: error: error: git config user.name & please set dead value or install git && error: git config user.email & please set dead value or install git & please set dead value or install git | |||
| * @LastEditTime: 2025-04-18 11:30:21 | |||
| * @FilePath: \ci4s\react-ui\src\pages\ActiveLearn\types.ts | |||
| * @Description: 这是默认设置,请设置`customMade`, 打开koroFileHeader查看配置 进行设置: https://github.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE | |||
| */ | |||
| import { type ParameterInputObject } from '@/components/ResourceSelect'; | |||
| import { type NodeStatus } from '@/types'; | |||
| import { AutoMLTaskType } from '@/enums'; | |||
| // 操作类型 | |||
| export enum OperationType { | |||
| @@ -11,18 +20,33 @@ export enum OperationType { | |||
| export type FormData = { | |||
| name: string; // 实验名称 | |||
| description: string; // 实验描述 | |||
| task_type: AutoMLTaskType; // 任务类型 | |||
| framework_type: string; // 框架类型 | |||
| code_config: ParameterInputObject; // 代码配置 | |||
| model?: ParameterInputObject; // 模型 | |||
| model_py?: string; // 模型文件路径 | |||
| model_class_name?: string; // 模型类名称 | |||
| loss_py?: string; // loss文件路径 | |||
| loss_class_name?: string; // loss类名 | |||
| classifier_alg?: string; // 分类算法 | |||
| regressor_alg?: string; // 回归算法 | |||
| dataset: ParameterInputObject; // 数据集 | |||
| classifier_type: string; // 分类算法 | |||
| stopping_criterion: string; // 停止判则 | |||
| dataset_py: string; // dataset处理文件路径 | |||
| dataset_class_name: string; // dataset类名 | |||
| data_size: number; // 数据量 | |||
| train_size: number; // 训练集数据量 | |||
| ninitial: number; // 初始训练数据量 | |||
| nqueries: number; // 查询次数 | |||
| ninstances: number; // 每次查询数据量 | |||
| computing_resource_id: number; // 资源规格 | |||
| image: ParameterInputObject; // 镜像 | |||
| shuffle: boolean; // 是否随机打乱 | |||
| query_strategy: string; // 查询策略 | |||
| num_of_experiment: number; // 试验次数 | |||
| performance_metric: string; // 指标 | |||
| target_columns: string; // 预测目标列 | |||
| test_ratio: number; // 测试集比率 | |||
| initial_label_rate: number; // 初始标记数据比率 | |||
| num_of_queries?: number; // 查询次数 | |||
| percent_of_unlabel: number; // 未标记比例 | |||
| time_limit: number; // 时间限制 | |||
| ncheckpoint: number; // 多少轮查询保存一次模型参数 | |||
| batch_size: number; // batch_size | |||
| epochs: number; // epochs | |||
| lr: number; // 学习率 | |||
| }; | |||
| // 主动学习 | |||
| @@ -63,7 +63,7 @@ function ExperimentList({ type }: ExperimentListProps) { | |||
| const params: Record<string, any> = { | |||
| page: pagination.current! - 1, | |||
| size: pagination.pageSize, | |||
| ml_name: searchText || undefined, | |||
| [config.nameProperty]: searchText || undefined, | |||
| }; | |||
| const request = config.getListReq; | |||
| const [res] = await to(request(params)); | |||
| @@ -248,7 +248,7 @@ function ExperimentList({ type }: ExperimentListProps) { | |||
| { | |||
| title: '实验名称', | |||
| dataIndex: config.nameProperty, | |||
| key: 'ml_name', | |||
| key: 'name', | |||
| width: '16%', | |||
| render: tableCellRender(false, TableCellValueType.Link, { | |||
| onClick: gotoDetail, | |||
| @@ -257,7 +257,7 @@ function ExperimentList({ type }: ExperimentListProps) { | |||
| { | |||
| title: '实验描述', | |||
| dataIndex: config.descProperty, | |||
| key: 'ml_description', | |||
| key: 'description', | |||
| render: tableCellRender(true), | |||
| }, | |||
| { | |||