Browse Source

feat: 验证超参数自动寻优的手动运行参数

pull/171/head
cp3hnu 1 year ago
parent
commit
4cecd72fc5
10 changed files with 195 additions and 100 deletions
  1. +19
    -4
      react-ui/src/pages/HyperParameter/Create/index.tsx
  2. +47
    -34
      react-ui/src/pages/HyperParameter/components/CreateForm/ExecuteConfig.tsx
  3. +7
    -1
      react-ui/src/pages/HyperParameter/components/CreateForm/index.less
  4. +52
    -0
      react-ui/src/pages/HyperParameter/components/CreateForm/utils.ts
  5. +28
    -6
      react-ui/src/pages/HyperParameter/components/HyperParameterBasic/index.tsx
  6. +14
    -16
      react-ui/src/pages/HyperParameter/components/ParameterInfo/index.tsx
  7. +1
    -1
      react-ui/src/pages/HyperParameter/types.ts
  8. +2
    -2
      react-ui/src/pages/ModelDeployment/components/VersionBasicInfo/index.tsx
  9. +23
    -35
      react-ui/src/utils/format.ts
  10. +2
    -1
      react-ui/src/utils/table.tsx

+ 19
- 4
react-ui/src/pages/HyperParameter/Create/index.tsx View File

@@ -61,13 +61,28 @@ function CreateHyperparameter() {
// 创建、更新、复制实验 // 创建、更新、复制实验
const createExperiment = async (formData: FormData) => { const createExperiment = async (formData: FormData) => {
// 按后台接口要求,修改参数表单数据结构,将 "value" 参数改为 "bounds"/"values"/"value" // 按后台接口要求,修改参数表单数据结构,将 "value" 参数改为 "bounds"/"values"/"value"
const parameters = formData['parameters'];
parameters.forEach((item) => {
const formParameters = formData['parameters'];

const parameters = formParameters.map((item) => {
const paramName = getReqParamName(item.type); const paramName = getReqParamName(item.type);
item[paramName] = item.range;
delete item.range;
const range = item.range;
return {
...item,
[paramName]: range,
range: undefined,
};
}); });


// const runParameters = formData['points_to_evaluate'];
// for (const item of parameters) {
// const name = item.name;
// const arr = runParameters.filter((item) => isEmpty(item[name]));
// if (arr.length > 0 && arr.length < runParameters.length) {
// message.error(`手动运行参数 ${name} 必须全部填写或者都不填写`);
// return;
// }
// }

// 根据后台要求,修改表单数据 // 根据后台要求,修改表单数据
const object = { const object = {
...formData, ...formData,


+ 47
- 34
react-ui/src/pages/HyperParameter/components/CreateForm/ExecuteConfig.tsx View File

@@ -7,24 +7,20 @@ import ResourceSelect, {
import SubAreaTitle from '@/components/SubAreaTitle'; import SubAreaTitle from '@/components/SubAreaTitle';
import { hyperParameterOptimizedModeOptions } from '@/enums'; import { hyperParameterOptimizedModeOptions } from '@/enums';
import { useComputingResource } from '@/hooks/resource'; import { useComputingResource } from '@/hooks/resource';
import { isEmpty } from '@/utils';
import { modalConfirm } from '@/utils/ui'; import { modalConfirm } from '@/utils/ui';
import { MinusCircleOutlined, PlusCircleOutlined, QuestionCircleOutlined } from '@ant-design/icons'; import { MinusCircleOutlined, PlusCircleOutlined, QuestionCircleOutlined } from '@ant-design/icons';
import { Button, Col, Flex, Form, Input, InputNumber, Radio, Row, Select, Tooltip } from 'antd'; import { Button, Col, Flex, Form, Input, InputNumber, Radio, Row, Select, Tooltip } from 'antd';
import { isEqual } from 'lodash'; import { isEqual } from 'lodash';
import PopParameterRange from './PopParameterRange'; import PopParameterRange from './PopParameterRange';
import styles from './index.less'; import styles from './index.less';
import { axParameterOptions, parameterOptions, type FormParameter } from './utils';

// 搜索算法
const searchAlgorithms = ['HyperOpt', 'HEBO', 'BayesOpt', 'Optuna', 'ZOOpt', 'Ax'].map((name) => ({
label: name,
value: name,
}));

// 调度算法
const schedulerAlgorithms = ['ASHA', 'HyperBand', 'MedianStopping', 'PopulationBased', 'PB2'].map(
(name) => ({ label: name, value: name }),
);
import {
axParameterOptions,
parameterOptions,
schedulerAlgorithms,
searchAlgorithms,
type FormParameter,
} from './utils';


const parameterTooltip = `uniform(-5, -1) const parameterTooltip = `uniform(-5, -1)
在 -5.0 和 -1.0 之间均匀采样浮点数 在 -5.0 和 -1.0 之间均匀采样浮点数
@@ -146,17 +142,13 @@ function ExecuteConfig() {


<Row gutter={8}> <Row gutter={8}>
<Col span={10}> <Col span={10}>
<Form.Item
label="数据集挂载路径"
name="dataset_path"
rules={[
{
required: true,
message: '请输入数据集挂载路径',
},
]}
>
<Input placeholder="请输入数据集挂载路径" maxLength={64} showCount allowClear />
<Form.Item label="模型" name="model">
<ResourceSelect
type={ResourceSelectorType.Model}
placeholder="请选择模型"
canInput={false}
size="large"
/>
</Form.Item> </Form.Item>
</Col> </Col>
</Row> </Row>
@@ -405,15 +397,36 @@ function ExecuteConfig() {
} }


return ( return (
<Form.List name="points_to_evaluate">
{(fields, { add, remove }) => (
<Form.List
name="points_to_evaluate"
rules={[
{
validator: (_, runParameters) => {
const parameters = form.getFieldValue('parameters');
for (const item of parameters) {
const name = item.name;
const arr = runParameters.filter((item?: Record<string, any>) =>
isEmpty(item?.[name]),
);
if (arr.length > 0 && arr.length < runParameters.length) {
return Promise.reject(
new Error(`手动运行参数 ${name} 必须全部填写或者都不填写`),
);
}
}

return Promise.resolve();
},
},
]}
>
{(fields, { add, remove }, { errors }) => (
<> <>
<Row gutter={8}> <Row gutter={8}>
<Col span={10}> <Col span={10}>
<Form.Item <Form.Item
label="手动运行参数" label="手动运行参数"
style={{ marginBottom: 0, marginTop: '-14px' }} style={{ marginBottom: 0, marginTop: '-14px' }}
required
></Form.Item> ></Form.Item>
</Col> </Col>
</Row> </Row>
@@ -429,15 +442,14 @@ function ExecuteConfig() {
labelCol={{ flex: '140px' }} labelCol={{ flex: '140px' }}
name={[name, item.name]} name={[name, item.name]}
preserve={false} preserve={false}
required
rules={[
{
required: true,
message: '请输入',
},
]}
> >
<Input placeholder="请输入" maxLength={64} showCount allowClear />
<Input
placeholder="请输入"
maxLength={64}
showCount
allowClear
onChange={() => form.validateFields(['points_to_evaluate'])}
/>
</Form.Item> </Form.Item>
))} ))}
</div> </div>
@@ -472,6 +484,7 @@ function ExecuteConfig() {
</div> </div>
</Flex> </Flex>
))} ))}
<Form.ErrorList errors={errors} className={styles['run-parameter__error']} />
</div> </div>
</> </>
)} )}


+ 7
- 1
react-ui/src/pages/HyperParameter/components/CreateForm/index.less View File

@@ -119,13 +119,19 @@
flex: 1; flex: 1;
margin-right: 10px; margin-right: 10px;
padding: 20px 20px 0; padding: 20px 20px 0;
border: 1px dashed #dddddd;
border: 1px dashed #e0e0e0;
border-radius: 8px; border-radius: 8px;
} }

&__operation { &__operation {
display: flex; display: flex;
flex: none; flex: none;
align-items: center; align-items: center;
width: 100px; width: 100px;
} }

&__error {
margin-top: -20px;
color: @error-color;
}
} }

+ 52
- 0
react-ui/src/pages/HyperParameter/components/CreateForm/utils.ts View File

@@ -153,3 +153,55 @@ export const getReqParamName = (type: ParameterType) => {
return 'bounds'; return 'bounds';
} }
}; };

// 搜索算法
export const searchAlgorithms = [
{
label: 'HyperOpt(分布式异步超参数优化)',
value: 'HyperOpt',
},
{
label: 'HEBO(异方差进化贝叶斯优化)',
value: 'HEBO',
},
{
label: 'BayesOpt(贝叶斯优化)',
value: 'BayesOpt',
},
{
label: 'Optuna',
value: 'Optuna',
},
{
label: 'ZOOpt',
value: 'ZOOpt',
},
{
label: 'Ax',
value: 'Ax',
},
];

// 调度算法
export const schedulerAlgorithms = [
{
label: 'ASHA(异步连续减半)',
value: 'ASHA',
},
{
label: 'HyperBand(HyperBand 早停算法)',
value: 'HyperBand',
},
{
label: 'MedianStopping(中值停止规则)',
value: 'MedianStopping',
},
{
label: 'PopulationBased(基于种群训练)',
value: 'PopulationBased',
},
{
label: 'PB2(Population Based Bandits)',
value: 'PB2',
},
];

+ 28
- 6
react-ui/src/pages/HyperParameter/components/HyperParameterBasic/index.tsx View File

@@ -2,10 +2,20 @@ import { hyperParameterOptimizedMode } from '@/enums';
import { useComputingResource } from '@/hooks/resource'; import { useComputingResource } from '@/hooks/resource';
import ConfigInfo, { type BasicInfoData } from '@/pages/AutoML/components/ConfigInfo'; import ConfigInfo, { type BasicInfoData } from '@/pages/AutoML/components/ConfigInfo';
import { experimentStatusInfo } from '@/pages/Experiment/status'; import { experimentStatusInfo } from '@/pages/Experiment/status';
import {
schedulerAlgorithms,
searchAlgorithms,
} from '@/pages/HyperParameter/components/CreateForm/utils';
import { HyperparameterData } from '@/pages/HyperParameter/types'; import { HyperparameterData } from '@/pages/HyperParameter/types';
import { type NodeStatus } from '@/types'; import { type NodeStatus } from '@/types';
import { elapsedTime } from '@/utils/date'; import { elapsedTime } from '@/utils/date';
import { formatDataset, formatDate, formatSelectCodeConfig } from '@/utils/format';
import {
formatCodeConfig,
formatDataset,
formatDate,
formatEnum,
formatModel,
} from '@/utils/format';
import { Flex } from 'antd'; import { Flex } from 'antd';
import classNames from 'classnames'; import classNames from 'classnames';
import { useMemo } from 'react'; import { useMemo } from 'react';
@@ -86,7 +96,7 @@ function HyperParameterBasic({
label: '代码', label: '代码',
value: info.code, value: info.code,
ellipsis: true, ellipsis: true,
format: formatSelectCodeConfig,
format: formatCodeConfig,
}, },
{ {
label: '主函数代码文件', label: '主函数代码文件',
@@ -99,11 +109,11 @@ function HyperParameterBasic({
ellipsis: true, ellipsis: true,
format: formatDataset, format: formatDataset,
}, },

{ {
label: '数据集挂载路径',
value: info.dataset_path,
label: '模型',
value: info.model,
ellipsis: true, ellipsis: true,
format: formatModel,
}, },
{ {
label: '总实验次数', label: '总实验次数',
@@ -113,11 +123,23 @@ function HyperParameterBasic({
{ {
label: '搜索算法', label: '搜索算法',
value: info.search_alg, value: info.search_alg,
format: formatEnum(searchAlgorithms),
ellipsis: true, ellipsis: true,
}, },
{ {
label: '调度算法', label: '调度算法',
value: info.scheduler, value: info.scheduler,
format: formatEnum(schedulerAlgorithms),
ellipsis: true,
},
{
label: '单次试验最大时间',
value: info.max_t,
ellipsis: true,
},
{
label: '最小试验数',
value: info.min_samples_required,
ellipsis: true, ellipsis: true,
}, },
{ {
@@ -203,7 +225,7 @@ function HyperParameterBasic({
<ConfigInfo <ConfigInfo
title="配置信息" title="配置信息"
data={configDatas} data={configDatas}
labelWidth={150}
labelWidth={120}
style={{ marginBottom: '20px' }} style={{ marginBottom: '20px' }}
> >
{info && <ParameterInfo info={info} />} {info && <ParameterInfo info={info} />}


+ 14
- 16
react-ui/src/pages/HyperParameter/components/ParameterInfo/index.tsx View File

@@ -70,22 +70,20 @@ function ParameterInfo({ info }: ParameterInfoProps) {


const runColumns: TableProps<Record<string, any>>['columns'] = const runColumns: TableProps<Record<string, any>>['columns'] =
runParameters.length > 0 runParameters.length > 0
? Object.keys(runParameters[0])
.filter((key) => key !== 'id')
.map((key) => {
return {
title: (
<Tooltip title={key}>
<span>{key}</span>
</Tooltip>
),
dataIndex: key,
key: key,
width: 150,
render: tableCellRender(true),
ellipsis: { showTitle: false },
};
})
? parameters.map(({ name }) => {
return {
title: (
<Tooltip title={name}>
<span>{name}</span>
</Tooltip>
),
dataIndex: name,
key: name,
width: 150,
render: tableCellRender(true),
ellipsis: { showTitle: false },
};
})
: []; : [];


return ( return (


+ 1
- 1
react-ui/src/pages/HyperParameter/types.ts View File

@@ -14,7 +14,7 @@ export type FormData = {
description: string; // 实验描述 description: string; // 实验描述
code: ParameterInputObject; // 代码 code: ParameterInputObject; // 代码
dataset: ParameterInputObject; // 数据集 dataset: ParameterInputObject; // 数据集
dataset_path: string; // 数据集路径
model: ParameterInputObject; // 模型
main_py: string; // 主函数代码文件 main_py: string; // 主函数代码文件
metric: string; // 指标 metric: string; // 指标
mode: string; // 优化方向 mode: string; // 优化方向


+ 2
- 2
react-ui/src/pages/ModelDeployment/components/VersionBasicInfo/index.tsx View File

@@ -3,7 +3,7 @@ import { ServiceRunStatus } from '@/enums';
import { useComputingResource } from '@/hooks/resource'; import { useComputingResource } from '@/hooks/resource';
import { ServiceVersionData } from '@/pages/ModelDeployment/types'; import { ServiceVersionData } from '@/pages/ModelDeployment/types';
import { formatDate } from '@/utils/date'; import { formatDate } from '@/utils/date';
import { formatModel, formatSelectCodeConfig } from '@/utils/format';
import { formatCodeConfig, formatModel } from '@/utils/format';
import { Flex } from 'antd'; import { Flex } from 'antd';
import ModelDeployStatusCell from '../ModelDeployStatusCell'; import ModelDeployStatusCell from '../ModelDeployStatusCell';


@@ -61,7 +61,7 @@ function VersionBasicInfo({ info }: BasicInfoProps) {
{ {
label: '代码配置', label: '代码配置',
value: info?.code_config, value: info?.code_config,
format: formatSelectCodeConfig,
format: formatCodeConfig,
ellipsis: true, ellipsis: true,
}, },
{ {


+ 23
- 35
react-ui/src/utils/format.ts View File

@@ -10,6 +10,13 @@ import { getGitUrl } from '@/utils';
// 格式化日期 // 格式化日期
export { formatDate } from '@/utils/date'; export { formatDate } from '@/utils/date';


type SelectedCodeConfig = {
code_path: string;
branch: string;
showValue?: string; // 前端使用的
show_value?: string; // 后端使用的
};

// 格式化数据集数组 // 格式化数据集数组
export const formatDatasets = (datasets?: DatasetData[]) => { export const formatDatasets = (datasets?: DatasetData[]) => {
if (!datasets || datasets.length === 0) { if (!datasets || datasets.length === 0) {
@@ -37,51 +44,32 @@ export const formatModel = (model: ModelData) => {
if (!model) { if (!model) {
return undefined; return undefined;
} }

return { return {
value: model.name, value: model.name,
link: `/dataset/model/info/${model.id}?tab=${ResourceInfoTabKeys.Introduction}&version=${model.version}&name=${model.name}&owner=${model.owner}&identifier=${model.identifier}`, link: `/dataset/model/info/${model.id}?tab=${ResourceInfoTabKeys.Introduction}&version=${model.version}&name=${model.name}&owner=${model.owner}&identifier=${model.identifier}`,
}; };
}; };


// 获取代码配置的仓库的 url
export const getRepoUrl = (project?: ProjectDependency) => {
if (!project) {
return undefined;
}
const { url, branch } = project;
return getGitUrl(url, branch);
};

// 格式化代码配置 // 格式化代码配置
export const formatCodeConfig = (project?: ProjectDependency) => {
export const formatCodeConfig = (project?: ProjectDependency | SelectedCodeConfig) => {
if (!project) { if (!project) {
return undefined; return undefined;
} }
return {
value: project.name,
url: getRepoUrl(project),
};
};

// 格式化选中的代码配置
export const formatSelectCodeConfig = (value?: {
code_path: string;
branch: string;
showValue?: string;
show_value?: string;
}) => {
if (!value) {
return undefined;
// 创建表单,CodeSelect 组件返回,目前有流水线、模型部署、超参数自动寻优创建时选择了代码配置
if ('code_path' in project) {
const { showValue, show_value, code_path, branch } = project;
return {
value: showValue || show_value,
url: getGitUrl(code_path, branch),
};
} else {
// 数据集和模型的代码配置
const { url, branch, name } = project;
return {
value: name,
url: getGitUrl(url, branch),
};
} }
const { showValue, show_value, code_path, branch } = value;
return {
value: showValue || show_value,
url: getRepoUrl({
url: code_path,
branch,
} as ProjectDependency),
};
}; };


// 格式化训练任务(实验实例) // 格式化训练任务(实验实例)
@@ -107,7 +95,7 @@ export const formatSource = (source?: string) => {
return source; return source;
}; };


// 格式化字符串数组
// 格式化字符串数组,以逗号分隔
export const formatList = (value: string[] | null | undefined): string => { export const formatList = (value: string[] | null | undefined): string => {
if ( if (
value === undefined || value === undefined ||


+ 2
- 1
react-ui/src/utils/table.tsx View File

@@ -4,6 +4,7 @@
* @Description: Table cell 自定义 render * @Description: Table cell 自定义 render
*/ */


import { isEmpty } from '@/utils';
import { formatDate } from '@/utils/date'; import { formatDate } from '@/utils/date';
import { Tooltip } from 'antd'; import { Tooltip } from 'antd';
import dayjs from 'dayjs'; import dayjs from 'dayjs';
@@ -113,7 +114,7 @@ function renderCell<T>(
} }


function renderText(text: any | undefined | null) { function renderText(text: any | undefined | null) {
return <span>{text ?? '--'}</span>;
return <span>{!isEmpty(text) ? text : '--'}</span>;
} }


function renderLink<T>( function renderLink<T>(


Loading…
Cancel
Save