diff --git a/react-ui/src/pages/HyperParameter/Create/index.tsx b/react-ui/src/pages/HyperParameter/Create/index.tsx
index fa5fc8b8..fa94809b 100644
--- a/react-ui/src/pages/HyperParameter/Create/index.tsx
+++ b/react-ui/src/pages/HyperParameter/Create/index.tsx
@@ -61,13 +61,28 @@ function CreateHyperparameter() {
// 创建、更新、复制实验
const createExperiment = async (formData: FormData) => {
// 按后台接口要求,修改参数表单数据结构,将 "value" 参数改为 "bounds"/"values"/"value"
- const parameters = formData['parameters'];
- parameters.forEach((item) => {
+ const formParameters = formData['parameters'];
+
+ const parameters = formParameters.map((item) => {
const paramName = getReqParamName(item.type);
- item[paramName] = item.range;
- delete item.range;
+ const range = item.range;
+ return {
+ ...item,
+ [paramName]: range,
+ range: undefined,
+ };
});
+ // const runParameters = formData['points_to_evaluate'];
+ // for (const item of parameters) {
+ // const name = item.name;
+ // const arr = runParameters.filter((item) => isEmpty(item[name]));
+ // if (arr.length > 0 && arr.length < runParameters.length) {
+ // message.error(`手动运行参数 ${name} 必须全部填写或者都不填写`);
+ // return;
+ // }
+ // }
+
// 根据后台要求,修改表单数据
const object = {
...formData,
diff --git a/react-ui/src/pages/HyperParameter/components/CreateForm/ExecuteConfig.tsx b/react-ui/src/pages/HyperParameter/components/CreateForm/ExecuteConfig.tsx
index b900c2e5..056ae7e4 100644
--- a/react-ui/src/pages/HyperParameter/components/CreateForm/ExecuteConfig.tsx
+++ b/react-ui/src/pages/HyperParameter/components/CreateForm/ExecuteConfig.tsx
@@ -7,24 +7,20 @@ import ResourceSelect, {
import SubAreaTitle from '@/components/SubAreaTitle';
import { hyperParameterOptimizedModeOptions } from '@/enums';
import { useComputingResource } from '@/hooks/resource';
+import { isEmpty } from '@/utils';
import { modalConfirm } from '@/utils/ui';
import { MinusCircleOutlined, PlusCircleOutlined, QuestionCircleOutlined } from '@ant-design/icons';
import { Button, Col, Flex, Form, Input, InputNumber, Radio, Row, Select, Tooltip } from 'antd';
import { isEqual } from 'lodash';
import PopParameterRange from './PopParameterRange';
import styles from './index.less';
-import { axParameterOptions, parameterOptions, type FormParameter } from './utils';
-
-// 搜索算法
-const searchAlgorithms = ['HyperOpt', 'HEBO', 'BayesOpt', 'Optuna', 'ZOOpt', 'Ax'].map((name) => ({
- label: name,
- value: name,
-}));
-
-// 调度算法
-const schedulerAlgorithms = ['ASHA', 'HyperBand', 'MedianStopping', 'PopulationBased', 'PB2'].map(
- (name) => ({ label: name, value: name }),
-);
+import {
+ axParameterOptions,
+ parameterOptions,
+ schedulerAlgorithms,
+ searchAlgorithms,
+ type FormParameter,
+} from './utils';
const parameterTooltip = `uniform(-5, -1)
在 -5.0 和 -1.0 之间均匀采样浮点数
@@ -146,17 +142,13 @@ function ExecuteConfig() {
-
-
+
+
@@ -405,15 +397,36 @@ function ExecuteConfig() {
}
return (
-
- {(fields, { add, remove }) => (
+ {
+ const parameters = form.getFieldValue('parameters');
+ for (const item of parameters) {
+ const name = item.name;
+ const arr = runParameters.filter((item?: Record) =>
+ isEmpty(item?.[name]),
+ );
+ if (arr.length > 0 && arr.length < runParameters.length) {
+ return Promise.reject(
+ new Error(`手动运行参数 ${name} 必须全部填写或者都不填写`),
+ );
+ }
+ }
+
+ return Promise.resolve();
+ },
+ },
+ ]}
+ >
+ {(fields, { add, remove }, { errors }) => (
<>
@@ -429,15 +442,14 @@ function ExecuteConfig() {
labelCol={{ flex: '140px' }}
name={[name, item.name]}
preserve={false}
- required
- rules={[
- {
- required: true,
- message: '请输入',
- },
- ]}
>
-
+ form.validateFields(['points_to_evaluate'])}
+ />
))}
@@ -472,6 +484,7 @@ function ExecuteConfig() {
))}
+
>
)}
diff --git a/react-ui/src/pages/HyperParameter/components/CreateForm/index.less b/react-ui/src/pages/HyperParameter/components/CreateForm/index.less
index 7d945b21..7c218f63 100644
--- a/react-ui/src/pages/HyperParameter/components/CreateForm/index.less
+++ b/react-ui/src/pages/HyperParameter/components/CreateForm/index.less
@@ -119,13 +119,19 @@
flex: 1;
margin-right: 10px;
padding: 20px 20px 0;
- border: 1px dashed #dddddd;
+ border: 1px dashed #e0e0e0;
border-radius: 8px;
}
+
&__operation {
display: flex;
flex: none;
align-items: center;
width: 100px;
}
+
+ &__error {
+ margin-top: -20px;
+ color: @error-color;
+ }
}
diff --git a/react-ui/src/pages/HyperParameter/components/CreateForm/utils.ts b/react-ui/src/pages/HyperParameter/components/CreateForm/utils.ts
index 95aa3651..f3f92555 100644
--- a/react-ui/src/pages/HyperParameter/components/CreateForm/utils.ts
+++ b/react-ui/src/pages/HyperParameter/components/CreateForm/utils.ts
@@ -153,3 +153,55 @@ export const getReqParamName = (type: ParameterType) => {
return 'bounds';
}
};
+
+// 搜索算法
+export const searchAlgorithms = [
+ {
+ label: 'HyperOpt(分布式异步超参数优化)',
+ value: 'HyperOpt',
+ },
+ {
+ label: 'HEBO(异方差进化贝叶斯优化)',
+ value: 'HEBO',
+ },
+ {
+ label: 'BayesOpt(贝叶斯优化)',
+ value: 'BayesOpt',
+ },
+ {
+ label: 'Optuna',
+ value: 'Optuna',
+ },
+ {
+ label: 'ZOOpt',
+ value: 'ZOOpt',
+ },
+ {
+ label: 'Ax',
+ value: 'Ax',
+ },
+];
+
+// 调度算法
+export const schedulerAlgorithms = [
+ {
+ label: 'ASHA(异步连续减半)',
+ value: 'ASHA',
+ },
+ {
+ label: 'HyperBand(HyperBand 早停算法)',
+ value: 'HyperBand',
+ },
+ {
+ label: 'MedianStopping(中值停止规则)',
+ value: 'MedianStopping',
+ },
+ {
+ label: 'PopulationBased(基于种群训练)',
+ value: 'PopulationBased',
+ },
+ {
+ label: 'PB2(Population Based Bandits)',
+ value: 'PB2',
+ },
+];
diff --git a/react-ui/src/pages/HyperParameter/components/HyperParameterBasic/index.tsx b/react-ui/src/pages/HyperParameter/components/HyperParameterBasic/index.tsx
index 4daa2298..88a88800 100644
--- a/react-ui/src/pages/HyperParameter/components/HyperParameterBasic/index.tsx
+++ b/react-ui/src/pages/HyperParameter/components/HyperParameterBasic/index.tsx
@@ -2,10 +2,20 @@ import { hyperParameterOptimizedMode } from '@/enums';
import { useComputingResource } from '@/hooks/resource';
import ConfigInfo, { type BasicInfoData } from '@/pages/AutoML/components/ConfigInfo';
import { experimentStatusInfo } from '@/pages/Experiment/status';
+import {
+ schedulerAlgorithms,
+ searchAlgorithms,
+} from '@/pages/HyperParameter/components/CreateForm/utils';
import { HyperparameterData } from '@/pages/HyperParameter/types';
import { type NodeStatus } from '@/types';
import { elapsedTime } from '@/utils/date';
-import { formatDataset, formatDate, formatSelectCodeConfig } from '@/utils/format';
+import {
+ formatCodeConfig,
+ formatDataset,
+ formatDate,
+ formatEnum,
+ formatModel,
+} from '@/utils/format';
import { Flex } from 'antd';
import classNames from 'classnames';
import { useMemo } from 'react';
@@ -86,7 +96,7 @@ function HyperParameterBasic({
label: '代码',
value: info.code,
ellipsis: true,
- format: formatSelectCodeConfig,
+ format: formatCodeConfig,
},
{
label: '主函数代码文件',
@@ -99,11 +109,11 @@ function HyperParameterBasic({
ellipsis: true,
format: formatDataset,
},
-
{
- label: '数据集挂载路径',
- value: info.dataset_path,
+ label: '模型',
+ value: info.model,
ellipsis: true,
+ format: formatModel,
},
{
label: '总实验次数',
@@ -113,11 +123,23 @@ function HyperParameterBasic({
{
label: '搜索算法',
value: info.search_alg,
+ format: formatEnum(searchAlgorithms),
ellipsis: true,
},
{
label: '调度算法',
value: info.scheduler,
+ format: formatEnum(schedulerAlgorithms),
+ ellipsis: true,
+ },
+ {
+ label: '单次试验最大时间',
+ value: info.max_t,
+ ellipsis: true,
+ },
+ {
+ label: '最小试验数',
+ value: info.min_samples_required,
ellipsis: true,
},
{
@@ -203,7 +225,7 @@ function HyperParameterBasic({
{info && }
diff --git a/react-ui/src/pages/HyperParameter/components/ParameterInfo/index.tsx b/react-ui/src/pages/HyperParameter/components/ParameterInfo/index.tsx
index 9b415d85..95a24e1f 100644
--- a/react-ui/src/pages/HyperParameter/components/ParameterInfo/index.tsx
+++ b/react-ui/src/pages/HyperParameter/components/ParameterInfo/index.tsx
@@ -70,22 +70,20 @@ function ParameterInfo({ info }: ParameterInfoProps) {
const runColumns: TableProps>['columns'] =
runParameters.length > 0
- ? Object.keys(runParameters[0])
- .filter((key) => key !== 'id')
- .map((key) => {
- return {
- title: (
-
- {key}
-
- ),
- dataIndex: key,
- key: key,
- width: 150,
- render: tableCellRender(true),
- ellipsis: { showTitle: false },
- };
- })
+ ? parameters.map(({ name }) => {
+ return {
+ title: (
+
+ {name}
+
+ ),
+ dataIndex: name,
+ key: name,
+ width: 150,
+ render: tableCellRender(true),
+ ellipsis: { showTitle: false },
+ };
+ })
: [];
return (
diff --git a/react-ui/src/pages/HyperParameter/types.ts b/react-ui/src/pages/HyperParameter/types.ts
index 68f77fb2..568dbd4a 100644
--- a/react-ui/src/pages/HyperParameter/types.ts
+++ b/react-ui/src/pages/HyperParameter/types.ts
@@ -14,7 +14,7 @@ export type FormData = {
description: string; // 实验描述
code: ParameterInputObject; // 代码
dataset: ParameterInputObject; // 数据集
- dataset_path: string; // 数据集路径
+ model: ParameterInputObject; // 模型
main_py: string; // 主函数代码文件
metric: string; // 指标
mode: string; // 优化方向
diff --git a/react-ui/src/pages/ModelDeployment/components/VersionBasicInfo/index.tsx b/react-ui/src/pages/ModelDeployment/components/VersionBasicInfo/index.tsx
index 3ddfb0e2..28da0127 100644
--- a/react-ui/src/pages/ModelDeployment/components/VersionBasicInfo/index.tsx
+++ b/react-ui/src/pages/ModelDeployment/components/VersionBasicInfo/index.tsx
@@ -3,7 +3,7 @@ import { ServiceRunStatus } from '@/enums';
import { useComputingResource } from '@/hooks/resource';
import { ServiceVersionData } from '@/pages/ModelDeployment/types';
import { formatDate } from '@/utils/date';
-import { formatModel, formatSelectCodeConfig } from '@/utils/format';
+import { formatCodeConfig, formatModel } from '@/utils/format';
import { Flex } from 'antd';
import ModelDeployStatusCell from '../ModelDeployStatusCell';
@@ -61,7 +61,7 @@ function VersionBasicInfo({ info }: BasicInfoProps) {
{
label: '代码配置',
value: info?.code_config,
- format: formatSelectCodeConfig,
+ format: formatCodeConfig,
ellipsis: true,
},
{
diff --git a/react-ui/src/utils/format.ts b/react-ui/src/utils/format.ts
index c7d34e7d..8e7bbbf4 100644
--- a/react-ui/src/utils/format.ts
+++ b/react-ui/src/utils/format.ts
@@ -10,6 +10,13 @@ import { getGitUrl } from '@/utils';
// 格式化日期
export { formatDate } from '@/utils/date';
+type SelectedCodeConfig = {
+ code_path: string;
+ branch: string;
+ showValue?: string; // 前端使用的
+ show_value?: string; // 后端使用的
+};
+
// 格式化数据集数组
export const formatDatasets = (datasets?: DatasetData[]) => {
if (!datasets || datasets.length === 0) {
@@ -37,51 +44,32 @@ export const formatModel = (model: ModelData) => {
if (!model) {
return undefined;
}
-
return {
value: model.name,
link: `/dataset/model/info/${model.id}?tab=${ResourceInfoTabKeys.Introduction}&version=${model.version}&name=${model.name}&owner=${model.owner}&identifier=${model.identifier}`,
};
};
-// 获取代码配置的仓库的 url
-export const getRepoUrl = (project?: ProjectDependency) => {
- if (!project) {
- return undefined;
- }
- const { url, branch } = project;
- return getGitUrl(url, branch);
-};
-
// 格式化代码配置
-export const formatCodeConfig = (project?: ProjectDependency) => {
+export const formatCodeConfig = (project?: ProjectDependency | SelectedCodeConfig) => {
if (!project) {
return undefined;
}
- return {
- value: project.name,
- url: getRepoUrl(project),
- };
-};
-
-// 格式化选中的代码配置
-export const formatSelectCodeConfig = (value?: {
- code_path: string;
- branch: string;
- showValue?: string;
- show_value?: string;
-}) => {
- if (!value) {
- return undefined;
+ // 创建表单,CodeSelect 组件返回,目前有流水线、模型部署、超参数自动寻优创建时选择了代码配置
+ if ('code_path' in project) {
+ const { showValue, show_value, code_path, branch } = project;
+ return {
+ value: showValue || show_value,
+ url: getGitUrl(code_path, branch),
+ };
+ } else {
+ // 数据集和模型的代码配置
+ const { url, branch, name } = project;
+ return {
+ value: name,
+ url: getGitUrl(url, branch),
+ };
}
- const { showValue, show_value, code_path, branch } = value;
- return {
- value: showValue || show_value,
- url: getRepoUrl({
- url: code_path,
- branch,
- } as ProjectDependency),
- };
};
// 格式化训练任务(实验实例)
@@ -107,7 +95,7 @@ export const formatSource = (source?: string) => {
return source;
};
-// 格式化字符串数组
+// 格式化字符串数组,以逗号分隔
export const formatList = (value: string[] | null | undefined): string => {
if (
value === undefined ||
diff --git a/react-ui/src/utils/table.tsx b/react-ui/src/utils/table.tsx
index 0d4b1927..d3ec10d6 100644
--- a/react-ui/src/utils/table.tsx
+++ b/react-ui/src/utils/table.tsx
@@ -4,6 +4,7 @@
* @Description: Table cell 自定义 render
*/
+import { isEmpty } from '@/utils';
import { formatDate } from '@/utils/date';
import { Tooltip } from 'antd';
import dayjs from 'dayjs';
@@ -113,7 +114,7 @@ function renderCell(
}
function renderText(text: any | undefined | null) {
- return {text ?? '--'};
+ return {!isEmpty(text) ? text : '--'};
}
function renderLink(