|
- export enum ParameterType {
- Uniform = 'uniform',
- QUniform = 'quniform',
- LogUniform = 'loguniform',
- QLogUniform = 'qloguniform',
- Randn = 'randn',
- QRandn = 'qrandn',
- RandInt = 'randint',
- QRandInt = 'qrandint',
- LogRandInt = 'lograndint',
- QLogRandInt = 'qlograndint',
- Choice = 'choice',
- Grid = 'grid',
- Range = 'range',
- Fixed = 'fixed',
- }
-
- export const parameterOptions = [
- 'uniform',
- 'quniform',
- 'loguniform',
- 'qloguniform',
- 'randn',
- 'qrandn',
- 'randint',
- 'qrandint',
- 'lograndint',
- 'qlograndint',
- 'choice',
- 'grid',
- ].map((name) => ({
- label: name,
- value: name,
- }));
-
- export const axParameterOptions = ['fixed', 'range', 'choice'].map((name) => ({
- label: name,
- value: name,
- }));
-
- export const parameterTooltip: Record<ParameterType, string> = {
- [ParameterType.Uniform]: '在 low 和 high 之间均匀采样浮点数',
- [ParameterType.QUniform]: '在 low 和 high 之间均匀采样浮点数,四舍五入到 q 的倍数',
- [ParameterType.LogUniform]: '在 low 和 high 之间均匀采样浮点数,对数空间采样',
- [ParameterType.QLogUniform]:
- '在 low 和 high 之间均匀采样浮点数,对数空间采样并四舍五入到 q 的倍数',
- [ParameterType.Randn]: '在均值为 m,方差为 s 的正态分布中进行随机浮点数抽样',
- [ParameterType.QRandn]:
- '在均值为 m,方差为 s 的正态分布中进行随机浮点数抽样,四舍五入到 q 的倍数',
- [ParameterType.RandInt]: '在 low(包括)到 high(不包括)之间均匀采样整数',
- [ParameterType.QRandInt]:
- '在 low(包括)到 high(不包括)之间均匀采样整数,四舍五入到 q 的倍数(包括 high)',
- [ParameterType.LogRandInt]: '在 low(包括)到 high(不包括)之间对数空间上均匀采样整数',
- [ParameterType.QLogRandInt]:
- '在 low(包括)到 high(不包括)之间对数空间上均匀采样整数,并四舍五入到 q 的倍数',
- [ParameterType.Choice]: '从指定的选项中采样一个选项',
- [ParameterType.Grid]: '对选项进行网格搜索,每个值都将被采样',
- [ParameterType.Range]: '在 low 和 high 范围内采样取值',
- [ParameterType.Fixed]: '固定取值',
- };
-
- export type ParameterData = {
- label: string;
- name: string;
- value?: number;
- };
-
- // 参数表单数据
- export type FormParameter = {
- name: string; // 参数名称
- type: ParameterType; // 参数类型
- range: any; // 参数值
- [key: string]: any;
- };
-
- export const getFormOptions = (type?: ParameterType, value?: number[]): ParameterData[] => {
- const numbers =
- value?.map((item) => {
- const num = Number(item);
- if (isNaN(num)) {
- return undefined;
- }
- return num;
- }) ?? [];
- switch (type) {
- case ParameterType.Uniform:
- case ParameterType.LogUniform:
- case ParameterType.RandInt:
- case ParameterType.LogRandInt:
- case ParameterType.Range:
- return [
- {
- name: 'low',
- label: '最小值',
- value: numbers?.[0],
- },
- {
- name: 'high',
- label: '最大值',
- value: numbers?.[1],
- },
- ];
- case ParameterType.QUniform:
- case ParameterType.QLogUniform:
- case ParameterType.QRandInt:
- case ParameterType.QLogRandInt:
- return [
- {
- name: 'low',
- label: '最小值',
- value: numbers?.[0],
- },
- {
- name: 'high',
- label: '最大值',
- value: numbers?.[1],
- },
- {
- name: 'q',
- label: '间隔',
- value: numbers?.[2],
- },
- ];
- case ParameterType.Randn:
- return [
- {
- name: 'm',
- label: '均值',
- value: numbers?.[0],
- },
- {
- name: 's',
- label: '方差',
- value: numbers?.[1],
- },
- ];
- case ParameterType.QRandn:
- return [
- {
- name: 'm',
- label: '均值',
- value: numbers?.[0],
- },
- {
- name: 's',
- label: '方差',
- value: numbers?.[1],
- },
- {
- name: 'q',
- label: '间隔',
- value: numbers?.[2],
- },
- ];
- case ParameterType.Fixed:
- return [
- {
- name: 'value',
- label: '值',
- value: numbers?.[0],
- },
- ];
- default:
- return [];
- }
- };
-
- export const getReqParamName = (type: ParameterType) => {
- if (type === ParameterType.Fixed) {
- return 'value';
- } else if (type === ParameterType.Choice || type === ParameterType.Grid) {
- return 'values';
- } else {
- return 'bounds';
- }
- };
-
- // 搜索算法
- export const searchAlgorithms = [
- {
- label: 'HyperOpt(分布式异步超参数优化)',
- value: 'HyperOpt',
- },
- {
- label: 'HEBO(异方差进化贝叶斯优化)',
- value: 'HEBO',
- },
- {
- label: 'BayesOpt(贝叶斯优化)',
- value: 'BayesOpt',
- },
- {
- label: 'Optuna',
- value: 'Optuna',
- },
- {
- label: 'ZOOpt',
- value: 'ZOOpt',
- },
- {
- label: 'Ax',
- value: 'Ax',
- },
- ];
-
- // 调度算法
- export const schedulerAlgorithms = [
- {
- label: 'ASHA(异步连续减半)',
- value: 'ASHA',
- },
- {
- label: 'HyperBand(HyperBand 早停算法)',
- value: 'HyperBand',
- },
- {
- label: 'MedianStopping(中值停止规则)',
- value: 'MedianStopping',
- },
- {
- label: 'PopulationBased(基于种群训练)',
- value: 'PopulationBased',
- },
- {
- label: 'PB2(Population Based Bandits)',
- value: 'PB2',
- },
- ];
|