|
/**
 * Kinds of search-space parameter a user can configure.
 *
 * Each member's value is the lowercase string used for that kind; the same
 * strings appear as option values in `parameterOptions` and
 * `axParameterOptions`.
 *
 * NOTE(review): the names look like they mirror a tuning library's sampling
 * functions — confirm against the backend before relying on that.
 */
export enum ParameterType {
  // Kinds edited as min / max (plus an interval `q` for the Q-prefixed ones).
  Uniform = 'uniform',
  QUniform = 'quniform',
  LogUniform = 'loguniform',
  QLogUniform = 'qloguniform',

  // Kinds edited as mean / std (plus an interval `q` for QRandn).
  Randn = 'randn',
  QRandn = 'qrandn',

  // Integer-named kinds, edited as min / max (plus `q` for the Q-prefixed ones).
  RandInt = 'randint',
  QRandInt = 'qrandint',
  LogRandInt = 'lograndint',
  QLogRandInt = 'qlograndint',

  // Kinds whose request payload is a list of values.
  Choice = 'choice',
  Grid = 'grid',

  // Kinds offered only through `axParameterOptions` (together with Choice).
  Range = 'range',
  Fixed = 'fixed',
}
-
/**
 * Label/value option entries for the general parameter kinds.
 *
 * Every entry uses the kind's string both as its label and as its value.
 * 'range' and 'fixed' are intentionally absent here; they are only listed in
 * `axParameterOptions`.
 */
export const parameterOptions = Array.from(
  [
    'uniform',
    'quniform',
    'loguniform',
    'qloguniform',
    'randn',
    'qrandn',
    'randint',
    'qrandint',
    'lograndint',
    'qlograndint',
    'choice',
    'grid',
  ],
  (typeName) => ({ label: typeName, value: typeName }),
);
-
/**
 * Label/value option entries for the reduced set of parameter kinds
 * ('fixed', 'range', 'choice'). Label and value are both the kind's string.
 *
 * NOTE(review): presumably the set offered when the "Ax" search algorithm is
 * selected — confirm against the form that consumes it.
 */
export const axParameterOptions = Array.from(['fixed', 'range', 'choice'], (typeName) => ({
  label: typeName,
  value: typeName,
}));
-
/**
 * Describes one numeric input belonging to a parameter, as produced by
 * `getFormOptions`.
 */
export type ParameterData = {
  /** Field key, e.g. 'min', 'max', 'q', 'mean', 'std' or 'value'. */
  name: string;
  /** Caption shown for the field. */
  label: string;
  /** Current number for the field; omitted/undefined when none is available. */
  value?: number;
};
-
/**
 * Form data for a single parameter.
 *
 * Besides the three named fields, arbitrary extra string keys are accepted
 * through the index signature.
 */
export type FormParameter = {
  /** Parameter name. */
  name: string;
  /** Parameter type. */
  type: ParameterType;
  /** Parameter value; its shape is not constrained by this type. */
  range: any;
  [key: string]: any;
};
-
/**
 * Builds the list of numeric input fields to show for a parameter type,
 * pre-filled positionally from `value`.
 *
 * Layout per type:
 * - Uniform / LogUniform / RandInt / LogRandInt / Range: min, max
 * - QUniform / QLogUniform / QRandInt / QLogRandInt: min, max, q
 * - Randn: mean, std
 * - QRandn: mean, std, q
 * - Fixed: value
 * - anything else (including Choice, Grid and `undefined`): no fields
 *
 * @param type  Parameter type that selects the field layout.
 * @param value Existing numbers, matched to fields by index. Entries that do
 *              not convert to a number, and missing entries, leave the
 *              corresponding field's `value` undefined.
 * @returns The field descriptors in display order.
 */
export const getFormOptions = (type?: ParameterType, value?: number[]): ParameterData[] => {
  // Coerce defensively: callers may hand over form strings despite the declared type.
  const numbers = (value ?? []).map((item) => {
    const num = Number(item);
    return Number.isNaN(num) ? undefined : num;
  });

  // Single place that shapes a field, so the layouts below stay declarative.
  const field = (name: string, label: string, index: number): ParameterData => ({
    name,
    label,
    value: numbers[index],
  });

  switch (type) {
    case ParameterType.Uniform:
    case ParameterType.LogUniform:
    case ParameterType.RandInt:
    case ParameterType.LogRandInt:
    case ParameterType.Range:
      return [field('min', '最小值', 0), field('max', '最大值', 1)];
    case ParameterType.QUniform:
    case ParameterType.QLogUniform:
    case ParameterType.QRandInt:
    case ParameterType.QLogRandInt:
      return [field('min', '最小值', 0), field('max', '最大值', 1), field('q', '间隔', 2)];
    case ParameterType.Randn:
      // The field is a standard deviation, so it is labelled as such (not "variance").
      return [field('mean', '均值', 0), field('std', '标准差', 1)];
    case ParameterType.QRandn:
      return [field('mean', '均值', 0), field('std', '标准差', 1), field('q', '间隔', 2)];
    case ParameterType.Fixed:
      return [field('value', '值', 0)];
    default:
      return [];
  }
};
-
/**
 * Maps a parameter type to a key name: 'value' for Fixed, 'values' for
 * Choice and Grid, and 'bounds' for every other type.
 *
 * NOTE(review): judging by the function name this is the key used in the
 * request payload — confirm against the caller.
 *
 * @param type Parameter type to look up.
 * @returns 'value', 'values' or 'bounds'.
 */
export const getReqParamName = (type: ParameterType) => {
  switch (type) {
    case ParameterType.Fixed:
      return 'value';
    case ParameterType.Choice:
    case ParameterType.Grid:
      return 'values';
    default:
      return 'bounds';
  }
};
-
/**
 * Search algorithms, as label/value option entries.
 * `value` is the bare algorithm identifier; `label` is the display text,
 * which for some entries adds a Chinese description in parentheses.
 */
export const searchAlgorithms = [
  { label: 'HyperOpt(分布式异步超参数优化)', value: 'HyperOpt' },
  { label: 'HEBO(异方差进化贝叶斯优化)', value: 'HEBO' },
  { label: 'BayesOpt(贝叶斯优化)', value: 'BayesOpt' },
  { label: 'Optuna', value: 'Optuna' },
  { label: 'ZOOpt', value: 'ZOOpt' },
  { label: 'Ax', value: 'Ax' },
];
-
/**
 * Scheduling algorithms, as label/value option entries.
 * `value` is the bare algorithm identifier; `label` is the display text with
 * a short description in parentheses.
 */
export const schedulerAlgorithms = [
  { label: 'ASHA(异步连续减半)', value: 'ASHA' },
  { label: 'HyperBand(HyperBand 早停算法)', value: 'HyperBand' },
  { label: 'MedianStopping(中值停止规则)', value: 'MedianStopping' },
  { label: 'PopulationBased(基于种群训练)', value: 'PopulationBased' },
  { label: 'PB2(Population Based Bandits)', value: 'PB2' },
];
|