From 4cecd72fc51e47229056aa2c4b74d48ffa6dceb1 Mon Sep 17 00:00:00 2001 From: cp3hnu Date: Thu, 16 Jan 2025 15:23:57 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E9=AA=8C=E8=AF=81=E8=B6=85=E5=8F=82?= =?UTF-8?q?=E6=95=B0=E8=87=AA=E5=8A=A8=E5=AF=BB=E4=BC=98=E7=9A=84=E6=89=8B?= =?UTF-8?q?=E5=8A=A8=E8=BF=90=E8=A1=8C=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/pages/HyperParameter/Create/index.tsx | 23 +++++- .../components/CreateForm/ExecuteConfig.tsx | 81 +++++++++++-------- .../components/CreateForm/index.less | 8 +- .../components/CreateForm/utils.ts | 52 ++++++++++++ .../components/HyperParameterBasic/index.tsx | 34 ++++++-- .../components/ParameterInfo/index.tsx | 30 ++++--- react-ui/src/pages/HyperParameter/types.ts | 2 +- .../components/VersionBasicInfo/index.tsx | 4 +- react-ui/src/utils/format.ts | 58 ++++++------- react-ui/src/utils/table.tsx | 3 +- 10 files changed, 195 insertions(+), 100 deletions(-) diff --git a/react-ui/src/pages/HyperParameter/Create/index.tsx b/react-ui/src/pages/HyperParameter/Create/index.tsx index fa5fc8b8..fa94809b 100644 --- a/react-ui/src/pages/HyperParameter/Create/index.tsx +++ b/react-ui/src/pages/HyperParameter/Create/index.tsx @@ -61,13 +61,28 @@ function CreateHyperparameter() { // 创建、更新、复制实验 const createExperiment = async (formData: FormData) => { // 按后台接口要求,修改参数表单数据结构,将 "value" 参数改为 "bounds"/"values"/"value" - const parameters = formData['parameters']; - parameters.forEach((item) => { + const formParameters = formData['parameters']; + + const parameters = formParameters.map((item) => { const paramName = getReqParamName(item.type); - item[paramName] = item.range; - delete item.range; + const range = item.range; + return { + ...item, + [paramName]: range, + range: undefined, + }; }); + // const runParameters = formData['points_to_evaluate']; + // for (const item of parameters) { + // const name = item.name; + // const arr = runParameters.filter((item) => isEmpty(item[name])); + // if (arr.length > 0 && arr.length < runParameters.length) { + // message.error(`手动运行参数 ${name} 必须全部填写或者都不填写`); + // return; + // } + // } + // 根据后台要求,修改表单数据 const object = { ...formData, diff --git a/react-ui/src/pages/HyperParameter/components/CreateForm/ExecuteConfig.tsx b/react-ui/src/pages/HyperParameter/components/CreateForm/ExecuteConfig.tsx index b900c2e5..056ae7e4 100644 --- a/react-ui/src/pages/HyperParameter/components/CreateForm/ExecuteConfig.tsx +++ b/react-ui/src/pages/HyperParameter/components/CreateForm/ExecuteConfig.tsx @@ -7,24 +7,20 @@ import ResourceSelect, { import SubAreaTitle from '@/components/SubAreaTitle'; import { hyperParameterOptimizedModeOptions } from '@/enums'; import { useComputingResource } from '@/hooks/resource'; +import { isEmpty } from '@/utils'; import { modalConfirm } from '@/utils/ui'; import { MinusCircleOutlined, PlusCircleOutlined, QuestionCircleOutlined } from '@ant-design/icons'; import { Button, Col, Flex, Form, Input, InputNumber, Radio, Row, Select, Tooltip } from 'antd'; import { isEqual } from 'lodash'; import PopParameterRange from './PopParameterRange'; import styles from './index.less'; -import { axParameterOptions, parameterOptions, type FormParameter } from './utils'; - -// 搜索算法 -const searchAlgorithms = ['HyperOpt', 'HEBO', 'BayesOpt', 'Optuna', 'ZOOpt', 'Ax'].map((name) => ({ - label: name, - value: name, -})); - -// 调度算法 -const schedulerAlgorithms = ['ASHA', 'HyperBand', 'MedianStopping', 'PopulationBased', 'PB2'].map( - (name) => ({ label: name, value: name }), -); +import { + axParameterOptions, + parameterOptions, + schedulerAlgorithms, + searchAlgorithms, + type FormParameter, +} from './utils'; const parameterTooltip = `uniform(-5, -1) 在 -5.0 和 -1.0 之间均匀采样浮点数 @@ -146,17 +142,13 @@ function ExecuteConfig() { - - + + @@ -405,15 +397,36 @@ function ExecuteConfig() { } return ( - - {(fields, { add, remove }) => ( + { + const parameters = form.getFieldValue('parameters'); + for (const item of parameters) { + const name = item.name; + const arr = runParameters.filter((item?: Record) => + isEmpty(item?.[name]), + ); + if (arr.length > 0 && arr.length < runParameters.length) { + return Promise.reject( + new Error(`手动运行参数 ${name} 必须全部填写或者都不填写`), + ); + } + } + + return Promise.resolve(); + }, + }, + ]} + > + {(fields, { add, remove }, { errors }) => ( <> @@ -429,15 +442,14 @@ function ExecuteConfig() { labelCol={{ flex: '140px' }} name={[name, item.name]} preserve={false} - required - rules={[ - { - required: true, - message: '请输入', - }, - ]} > - + form.validateFields(['points_to_evaluate'])} + /> ))} @@ -472,6 +484,7 @@ function ExecuteConfig() { ))} + )} diff --git a/react-ui/src/pages/HyperParameter/components/CreateForm/index.less b/react-ui/src/pages/HyperParameter/components/CreateForm/index.less index 7d945b21..7c218f63 100644 --- a/react-ui/src/pages/HyperParameter/components/CreateForm/index.less +++ b/react-ui/src/pages/HyperParameter/components/CreateForm/index.less @@ -119,13 +119,19 @@ flex: 1; margin-right: 10px; padding: 20px 20px 0; - border: 1px dashed #dddddd; + border: 1px dashed #e0e0e0; border-radius: 8px; } + &__operation { display: flex; flex: none; align-items: center; width: 100px; } + + &__error { + margin-top: -20px; + color: @error-color; + } } diff --git a/react-ui/src/pages/HyperParameter/components/CreateForm/utils.ts b/react-ui/src/pages/HyperParameter/components/CreateForm/utils.ts index 95aa3651..f3f92555 100644 --- a/react-ui/src/pages/HyperParameter/components/CreateForm/utils.ts +++ b/react-ui/src/pages/HyperParameter/components/CreateForm/utils.ts @@ -153,3 +153,55 @@ export const getReqParamName = (type: ParameterType) => { return 'bounds'; } }; + +// 搜索算法 +export const searchAlgorithms = [ + { + label: 'HyperOpt(分布式异步超参数优化)', + value: 'HyperOpt', + }, + { + label: 'HEBO(异方差进化贝叶斯优化)', + value: 'HEBO', + }, + { + label: 'BayesOpt(贝叶斯优化)', + value: 'BayesOpt', + }, + { + label: 'Optuna', + value: 'Optuna', + }, + { + label: 'ZOOpt', + value: 'ZOOpt', + }, + { + label: 'Ax', + value: 'Ax', + }, +]; + +// 调度算法 +export const schedulerAlgorithms = [ + { + label: 'ASHA(异步连续减半)', + value: 'ASHA', + }, + { + label: 'HyperBand(HyperBand 早停算法)', + value: 'HyperBand', + }, + { + label: 'MedianStopping(中值停止规则)', + value: 'MedianStopping', + }, + { + label: 'PopulationBased(基于种群训练)', + value: 'PopulationBased', + }, + { + label: 'PB2(Population Based Bandits)', + value: 'PB2', + }, +]; diff --git a/react-ui/src/pages/HyperParameter/components/HyperParameterBasic/index.tsx b/react-ui/src/pages/HyperParameter/components/HyperParameterBasic/index.tsx index 4daa2298..88a88800 100644 --- a/react-ui/src/pages/HyperParameter/components/HyperParameterBasic/index.tsx +++ b/react-ui/src/pages/HyperParameter/components/HyperParameterBasic/index.tsx @@ -2,10 +2,20 @@ import { hyperParameterOptimizedMode } from '@/enums'; import { useComputingResource } from '@/hooks/resource'; import ConfigInfo, { type BasicInfoData } from '@/pages/AutoML/components/ConfigInfo'; import { experimentStatusInfo } from '@/pages/Experiment/status'; +import { + schedulerAlgorithms, + searchAlgorithms, +} from '@/pages/HyperParameter/components/CreateForm/utils'; import { HyperparameterData } from '@/pages/HyperParameter/types'; import { type NodeStatus } from '@/types'; import { elapsedTime } from '@/utils/date'; -import { formatDataset, formatDate, formatSelectCodeConfig } from '@/utils/format'; +import { + formatCodeConfig, + formatDataset, + formatDate, + formatEnum, + formatModel, +} from '@/utils/format'; import { Flex } from 'antd'; import classNames from 'classnames'; import { useMemo } from 'react'; @@ -86,7 +96,7 @@ function HyperParameterBasic({ label: '代码', value: info.code, ellipsis: true, - format: formatSelectCodeConfig, + format: formatCodeConfig, }, { label: '主函数代码文件', @@ -99,11 +109,11 @@ function HyperParameterBasic({ ellipsis: true, format: formatDataset, }, - { - label: '数据集挂载路径', - value: info.dataset_path, + label: '模型', + value: info.model, ellipsis: true, + format: formatModel, }, { label: '总实验次数', @@ -113,11 +123,23 @@ function HyperParameterBasic({ { label: '搜索算法', value: info.search_alg, + format: formatEnum(searchAlgorithms), ellipsis: true, }, { label: '调度算法', value: info.scheduler, + format: formatEnum(schedulerAlgorithms), + ellipsis: true, + }, + { + label: '单次试验最大时间', + value: info.max_t, + ellipsis: true, + }, + { + label: '最小试验数', + value: info.min_samples_required, ellipsis: true, }, { @@ -203,7 +225,7 @@ function HyperParameterBasic({ {info && } diff --git a/react-ui/src/pages/HyperParameter/components/ParameterInfo/index.tsx b/react-ui/src/pages/HyperParameter/components/ParameterInfo/index.tsx index 9b415d85..95a24e1f 100644 --- a/react-ui/src/pages/HyperParameter/components/ParameterInfo/index.tsx +++ b/react-ui/src/pages/HyperParameter/components/ParameterInfo/index.tsx @@ -70,22 +70,20 @@ function ParameterInfo({ info }: ParameterInfoProps) { const runColumns: TableProps>['columns'] = runParameters.length > 0 - ? Object.keys(runParameters[0]) - .filter((key) => key !== 'id') - .map((key) => { - return { - title: ( - - {key} - - ), - dataIndex: key, - key: key, - width: 150, - render: tableCellRender(true), - ellipsis: { showTitle: false }, - }; - }) + ? parameters.map(({ name }) => { + return { + title: ( + + {name} + + ), + dataIndex: name, + key: name, + width: 150, + render: tableCellRender(true), + ellipsis: { showTitle: false }, + }; + }) : []; return ( diff --git a/react-ui/src/pages/HyperParameter/types.ts b/react-ui/src/pages/HyperParameter/types.ts index 68f77fb2..568dbd4a 100644 --- a/react-ui/src/pages/HyperParameter/types.ts +++ b/react-ui/src/pages/HyperParameter/types.ts @@ -14,7 +14,7 @@ export type FormData = { description: string; // 实验描述 code: ParameterInputObject; // 代码 dataset: ParameterInputObject; // 数据集 - dataset_path: string; // 数据集路径 + model: ParameterInputObject; // 模型 main_py: string; // 主函数代码文件 metric: string; // 指标 mode: string; // 优化方向 diff --git a/react-ui/src/pages/ModelDeployment/components/VersionBasicInfo/index.tsx b/react-ui/src/pages/ModelDeployment/components/VersionBasicInfo/index.tsx index 3ddfb0e2..28da0127 100644 --- a/react-ui/src/pages/ModelDeployment/components/VersionBasicInfo/index.tsx +++ b/react-ui/src/pages/ModelDeployment/components/VersionBasicInfo/index.tsx @@ -3,7 +3,7 @@ import { ServiceRunStatus } from '@/enums'; import { useComputingResource } from '@/hooks/resource'; import { ServiceVersionData } from '@/pages/ModelDeployment/types'; import { formatDate } from '@/utils/date'; -import { formatModel, formatSelectCodeConfig } from '@/utils/format'; +import { formatCodeConfig, formatModel } from '@/utils/format'; import { Flex } from 'antd'; import ModelDeployStatusCell from '../ModelDeployStatusCell'; @@ -61,7 +61,7 @@ function VersionBasicInfo({ info }: BasicInfoProps) { { label: '代码配置', value: info?.code_config, - format: formatSelectCodeConfig, + format: formatCodeConfig, ellipsis: true, }, { diff --git a/react-ui/src/utils/format.ts b/react-ui/src/utils/format.ts index c7d34e7d..8e7bbbf4 100644 --- a/react-ui/src/utils/format.ts +++ b/react-ui/src/utils/format.ts @@ -10,6 +10,13 @@ import { getGitUrl } from '@/utils'; // 格式化日期 export { formatDate } from '@/utils/date'; +type SelectedCodeConfig = { + code_path: string; + branch: string; + showValue?: string; // 前端使用的 + show_value?: string; // 后端使用的 +}; + // 格式化数据集数组 export const formatDatasets = (datasets?: DatasetData[]) => { if (!datasets || datasets.length === 0) { @@ -37,51 +44,32 @@ export const formatModel = (model: ModelData) => { if (!model) { return undefined; } - return { value: model.name, link: `/dataset/model/info/${model.id}?tab=${ResourceInfoTabKeys.Introduction}&version=${model.version}&name=${model.name}&owner=${model.owner}&identifier=${model.identifier}`, }; }; -// 获取代码配置的仓库的 url -export const getRepoUrl = (project?: ProjectDependency) => { - if (!project) { - return undefined; - } - const { url, branch } = project; - return getGitUrl(url, branch); -}; - // 格式化代码配置 -export const formatCodeConfig = (project?: ProjectDependency) => { +export const formatCodeConfig = (project?: ProjectDependency | SelectedCodeConfig) => { if (!project) { return undefined; } - return { - value: project.name, - url: getRepoUrl(project), - }; -}; - -// 格式化选中的代码配置 -export const formatSelectCodeConfig = (value?: { - code_path: string; - branch: string; - showValue?: string; - show_value?: string; -}) => { - if (!value) { - return undefined; + // 创建表单,CodeSelect 组件返回,目前有流水线、模型部署、超参数自动寻优创建时选择了代码配置 + if ('code_path' in project) { + const { showValue, show_value, code_path, branch } = project; + return { + value: showValue || show_value, + url: getGitUrl(code_path, branch), + }; + } else { + // 数据集和模型的代码配置 + const { url, branch, name } = project; + return { + value: name, + url: getGitUrl(url, branch), + }; } - const { showValue, show_value, code_path, branch } = value; - return { - value: showValue || show_value, - url: getRepoUrl({ - url: code_path, - branch, - } as ProjectDependency), - }; }; // 格式化训练任务(实验实例) @@ -107,7 +95,7 @@ export const formatSource = (source?: string) => { return source; }; -// 格式化字符串数组 +// 格式化字符串数组,以逗号分隔 export const formatList = (value: string[] | null | undefined): string => { if ( value === undefined || diff --git a/react-ui/src/utils/table.tsx b/react-ui/src/utils/table.tsx index 0d4b1927..d3ec10d6 100644 --- a/react-ui/src/utils/table.tsx +++ b/react-ui/src/utils/table.tsx @@ -4,6 +4,7 @@ * @Description: Table cell 自定义 render */ +import { isEmpty } from '@/utils'; import { formatDate } from '@/utils/date'; import { Tooltip } from 'antd'; import dayjs from 'dayjs'; @@ -113,7 +114,7 @@ function renderCell( } function renderText(text: any | undefined | null) { - return {text ?? '--'}; + return {!isEmpty(text) ? text : '--'}; } function renderLink(