You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ExecuteConfig.tsx 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472
  1. import SubAreaTitle from '@/components/SubAreaTitle';
  2. import {
  3. AutoMLEnsembleClass,
  4. AutoMLResamplingStrategy,
  5. AutoMLTaskType,
  6. autoMLEnsembleClassOptions,
  7. autoMLResamplingStrategyOptions,
  8. autoMLTaskTypeOptions,
  9. } from '@/enums';
  10. import { Col, Form, InputNumber, Radio, Row, Select, Switch } from 'antd';
  // Classification algorithms, as { label, value } options for the
  // include/exclude classifier selects. The value is always the raw
  // algorithm id; only `mlp` is displayed under a different label ("tablenet").
  const classificationAlgorithms = [
    'adaboost',
    'bernoulli_nb',
    'decision_tree',
    'extra_trees',
    'gaussian_nb',
    'gradient_boosting',
    'k_nearest_neighbors',
    'lda',
    'liblinear_svc',
    'libsvm_svc',
    'mlp',
    'multinomial_nb',
    'passive_aggressive',
    'qda',
    'random_forest',
    'sgd',
    'LightGBMClassification',
    'XGBoostClassification',
    'StackingClassification',
  ].map((algorithm) => ({
    label: algorithm === 'mlp' ? 'tablenet' : algorithm,
    value: algorithm,
  }));
  // Regression algorithms, as { label, value } options for the
  // include/exclude regressor selects. Label and value are both the raw id.
  const regressorAlgorithms = Array.from(
    [
      'adaboost',
      'ard_regression',
      'decision_tree',
      'extra_trees',
      'gaussian_process',
      'gradient_boosting',
      'k_nearest_neighbors',
      'liblinear_svr',
      'libsvm_svr',
      'mlp',
      'random_forest',
      'sgd',
      'LightGBMRegression',
      'XGBoostRegression',
    ],
    (algorithm) => ({ label: algorithm, value: algorithm }),
  );
  // Feature preprocessing algorithms, as { label, value } options shared by
  // the include/exclude feature preprocessor selects.
  const featureAlgorithms = Array.from(
    [
      'densifier',
      'extra_trees_preproc_for_classification',
      'extra_trees_preproc_for_regression',
      'fast_ica',
      'feature_agglomeration',
      'kernel_pca',
      'kitchen_sinks',
      'liblinear_svc_preprocessor',
      'no_preprocessing',
      'nystroem_sampler',
      'pca',
      'polynomial',
      'random_trees_embedding',
      'select_percentile_classification',
      'select_percentile_regression',
      'select_rates_classification',
      'select_rates_regression',
      'truncatedSVD',
    ],
    (preprocessor) => ({ label: preprocessor, value: preprocessor }),
  );
  // Classification metrics, exported as { label, value } select options.
  export const classificationMetrics = Array.from(
    [
      'accuracy',
      'balanced_accuracy',
      'roc_auc',
      'average_precision',
      'log_loss',
      'precision_macro',
      'precision_micro',
      'precision_samples',
      'precision_weighted',
      'recall_macro',
      'recall_micro',
      'recall_samples',
      'recall_weighted',
      'f1_macro',
      'f1_micro',
      'f1_samples',
      'f1_weighted',
    ],
    (metric) => ({ label: metric, value: metric }),
  );
  // Regression metrics, exported as { label, value } select options.
  export const regressionMetrics = Array.from(
    [
      'mean_absolute_error',
      'mean_squared_error',
      'root_mean_squared_error',
      'mean_squared_log_error',
      'median_absolute_error',
      'r2',
    ],
    (metric) => ({ label: metric, value: metric }),
  );
  /**
   * Form section holding the AutoML "execute configuration" and
   * "resampling strategy" fields.
   *
   * The component takes no props. It obtains the form through
   * `Form.useFormInstance()`, so it has to be rendered inside an antd `<Form>`.
   *
   * Behavior worth knowing when editing:
   * - Every include/exclude pair is mutually exclusive: a select is disabled
   *   while its counterpart has at least one selected value.
   * - Classifier fields are rendered when `task_type` equals
   *   `AutoMLTaskType.Classification`; regressor fields are rendered otherwise.
   * - The ensemble size fields are rendered only when `ensemble_class` equals
   *   `AutoMLEnsembleClass.Default`.
   * - The folds field is rendered only when `resampling_strategy` equals
   *   `AutoMLResamplingStrategy.CrossValid`.
   */
  function ExecuteConfig() {
    const form = Form.useFormInstance();

    // Watched values drive the option lists and the disabled state below.
    const task_type = Form.useWatch('task_type', form);
    const include_classifier = Form.useWatch('include_classifier', form);
    const exclude_classifier = Form.useWatch('exclude_classifier', form);
    const include_regressor = Form.useWatch('include_regressor', form);
    const exclude_regressor = Form.useWatch('exclude_regressor', form);
    const include_feature_preprocessor = Form.useWatch('include_feature_preprocessor', form);
    const exclude_feature_preprocessor = Form.useWatch('exclude_feature_preprocessor', form);

    return (
      <>
        <SubAreaTitle
          title="执行配置"
          image={require('@/assets/img/model-deployment.png')}
          style={{ marginTop: '20px', marginBottom: '24px' }}
        ></SubAreaTitle>
        <Row gutter={8}>
          <Col span={10}>
            <Form.Item
              label="任务类型"
              name="task_type"
              rules={[{ required: true, message: '请选择任务类型' }]}
            >
              {/*
                The metric option lists depend on the task type, so fields that
                hold a metric are reset when it changes. `scoring_functions` is
                included because its options below switch with `task_type`;
                without the reset a metric of the other task type stays selected.
              */}
              <Radio.Group
                options={autoMLTaskTypeOptions}
                onChange={() => form.resetFields(['metrics', 'scoring_functions'])}
              ></Radio.Group>
            </Form.Item>
          </Col>
        </Row>
        <Row gutter={8}>
          <Col span={10}>
            <Form.Item
              label="特征预处理算法"
              name="include_feature_preprocessor"
              tooltip="如果不选,则使用所有可能的特征预处理算法。否则,将只使用包含的特征预处理算法"
            >
              <Select
                allowClear
                placeholder="请选择特征预处理算法"
                options={featureAlgorithms}
                disabled={exclude_feature_preprocessor?.length > 0}
                mode="multiple"
                showSearch
              />
            </Form.Item>
          </Col>
        </Row>
        <Row gutter={8}>
          <Col span={10}>
            <Form.Item
              label="排除特征预处理算法"
              name="exclude_feature_preprocessor"
              tooltip="如果不选,则使用所有可能的特征预处理算法。否则,将排除包含的特征预处理算法"
            >
              <Select
                allowClear
                placeholder="排除特征预处理算法"
                options={featureAlgorithms}
                disabled={include_feature_preprocessor?.length > 0}
                mode="multiple"
                showSearch
              />
            </Form.Item>
          </Col>
        </Row>
        {/* Classifier fields for classification tasks, regressor fields otherwise. */}
        <Form.Item dependencies={['task_type']} noStyle>
          {({ getFieldValue }) => {
            return getFieldValue('task_type') === AutoMLTaskType.Classification ? (
              <>
                <Row gutter={8}>
                  <Col span={10}>
                    <Form.Item
                      label="分类算法"
                      name="include_classifier"
                      tooltip="如果不选,则使用所有可能的分类算法。否则,将只使用包含的算法"
                    >
                      <Select
                        allowClear
                        placeholder="请选择分类算法"
                        options={classificationAlgorithms}
                        mode="multiple"
                        disabled={exclude_classifier?.length > 0}
                        showSearch
                      />
                    </Form.Item>
                  </Col>
                </Row>
                <Row gutter={8}>
                  <Col span={10}>
                    <Form.Item
                      label="排除分类算法"
                      name="exclude_classifier"
                      tooltip="如果不选,则使用所有可能的分类算法。否则,将排除包含的算法"
                    >
                      <Select
                        allowClear
                        placeholder="排除分类算法"
                        options={classificationAlgorithms}
                        mode="multiple"
                        disabled={include_classifier?.length > 0}
                        showSearch
                      />
                    </Form.Item>
                  </Col>
                </Row>
              </>
            ) : (
              <>
                <Row gutter={8}>
                  <Col span={10}>
                    <Form.Item
                      label="回归算法"
                      name="include_regressor"
                      tooltip="如果不选,则使用所有可能的回归算法。否则,将只使用包含的算法"
                    >
                      <Select
                        allowClear
                        placeholder="请选择回归算法"
                        options={regressorAlgorithms}
                        mode="multiple"
                        disabled={exclude_regressor?.length > 0}
                        showSearch
                      />
                    </Form.Item>
                  </Col>
                </Row>
                <Row gutter={8}>
                  <Col span={10}>
                    <Form.Item
                      label="排除的回归算法"
                      name="exclude_regressor"
                      tooltip="如果不选,则使用所有可能的回归算法。否则,将排除包含的算法"
                    >
                      <Select
                        allowClear
                        placeholder="排除回归算法"
                        options={regressorAlgorithms}
                        mode="multiple"
                        disabled={include_regressor?.length > 0}
                        showSearch
                      />
                    </Form.Item>
                  </Col>
                </Row>
              </>
            );
          }}
        </Form.Item>
        <Row gutter={8}>
          <Col span={10}>
            <Form.Item
              label="集成方式"
              name="ensemble_class"
              tooltip="仅使用单个最佳模型还是集成模型"
            >
              <Radio.Group options={autoMLEnsembleClassOptions}></Radio.Group>
            </Form.Item>
          </Col>
        </Row>
        {/* Ensemble sizing is only offered for the default ensemble class. */}
        <Form.Item dependencies={['ensemble_class']} noStyle>
          {({ getFieldValue }) => {
            return getFieldValue('ensemble_class') === AutoMLEnsembleClass.Default ? (
              <>
                <Row gutter={8}>
                  <Col span={10}>
                    <Form.Item
                      label="集成模型数量"
                      name="ensemble_size"
                      tooltip="集成模型数量,必须是大于等于1的整数,默认50"
                    >
                      <InputNumber placeholder="请输入集成模型数量" min={1} precision={0} />
                    </Form.Item>
                  </Col>
                </Row>
                <Row gutter={8}>
                  <Col span={10}>
                    <Form.Item
                      label="集成最佳模型数量"
                      name="ensemble_nbest"
                      tooltip="仅集成最佳的N个模型,必须是大于等于1的整数"
                    >
                      <InputNumber placeholder="请输入集成最佳模型数量" min={1} precision={0} />
                    </Form.Item>
                  </Col>
                </Row>
              </>
            ) : null;
          }}
        </Form.Item>
        <Row gutter={8}>
          <Col span={10}>
            <Form.Item
              label="最大数量"
              name="max_models_on_disc"
              tooltip="定义在磁盘中保存的模型的最大数量。额外的模型数量将被永久删除,它设置了一个集成可以使用多少个模型的上限。必须是大于等于1的整数,默认50"
            >
              <InputNumber placeholder="请输入最大数量" min={1} precision={0} />
            </Form.Item>
          </Col>
        </Row>
        <Row gutter={8}>
          <Col span={10}>
            <Form.Item
              label="内存限制(MB)"
              name="memory_limit"
              tooltip="机器学习算法的内存限制(MB)。如果自动机器学习试图分配超过memory_limit MB,它将停止拟合机器学习算法。默认3072"
            >
              <InputNumber placeholder="请输入内存限制" min={0} precision={0} />
            </Form.Item>
          </Col>
        </Row>
        <Row gutter={8}>
          <Col span={10}>
            <Form.Item
              label="单次时间限制(秒)"
              name="per_run_time_limit"
              tooltip="单次调用机器学习模型的时间限制(以秒为单位)。如果机器学习算法运行超过时间限制,将终止模型拟合,默认600"
            >
              <InputNumber placeholder="请输入时间限制" min={0} precision={0} />
            </Form.Item>
          </Col>
        </Row>
        <Row gutter={8}>
          <Col span={10}>
            <Form.Item
              label="搜索时间限制(秒)"
              name="time_left_for_this_task"
              tooltip="搜索合适模型的时间限制(以秒为单位)。通过增加这个值,自动机器学习有更高的机会找到更好的模型。默认3600。"
            >
              <InputNumber placeholder="请输入搜索时间限制" min={0} precision={0} />
            </Form.Item>
          </Col>
        </Row>
        <Row gutter={8}>
          <Col span={10}>
            <Form.Item
              label="测试集比率"
              name="test_size"
              tooltip="将数据划分为训练数据和测试数据,测试数据集所占比例,0到1之间"
            >
              {/* A fractional step keeps the steppers usable inside the 0..1 range. */}
              <InputNumber placeholder="请输入测试集比率" min={0} max={1} step={0.1} />
            </Form.Item>
          </Col>
        </Row>
        <Row gutter={8}>
          <Col span={10}>
            <Form.Item label="计算指标" name="scoring_functions" tooltip="需要计算并打印的指标">
              <Select
                allowClear
                placeholder="请选择计算指标"
                options={
                  task_type === AutoMLTaskType.Classification
                    ? classificationMetrics
                    : regressionMetrics
                }
                showSearch
              />
            </Form.Item>
          </Col>
        </Row>
        <Row gutter={8}>
          <Col span={10}>
            <Form.Item label="随机种子" name="seed" tooltip="随机种子,将决定输出文件名">
              <InputNumber placeholder="请输入随机种子" min={0} precision={0} />
            </Form.Item>
          </Col>
        </Row>
        <SubAreaTitle
          title="重采样策略"
          image={require('@/assets/img/resample-icon.png')}
          style={{ marginTop: '20px', marginBottom: '24px' }}
        ></SubAreaTitle>
        <Row gutter={8}>
          <Col span={10}>
            <Form.Item
              label="重采样策略"
              name="resampling_strategy"
              tooltip="重采样策略,分为holdout和crossValid。holdout指定训练数据划分为训练集和验证集的比例。crossValid为交叉验证。"
            >
              <Select
                allowClear
                placeholder="请选择重采样策略"
                options={autoMLResamplingStrategyOptions}
                showSearch
              />
            </Form.Item>
          </Col>
        </Row>
        {/* The fold count is only requested (and required) for cross validation. */}
        <Form.Item dependencies={['resampling_strategy']} noStyle>
          {({ getFieldValue }) => {
            return getFieldValue('resampling_strategy') === AutoMLResamplingStrategy.CrossValid ? (
              <Row gutter={8}>
                <Col span={10}>
                  <Form.Item
                    label="交叉验证折数"
                    name="folds"
                    tooltip="交叉验证折数必须是大于等于2的整数"
                    rules={[
                      {
                        required: true,
                        message: '请输入交叉验证折数',
                      },
                    ]}
                  >
                    <InputNumber placeholder="请输入交叉验证折数" min={2} precision={0} />
                  </Form.Item>
                </Col>
              </Row>
            ) : null;
          }}
        </Form.Item>
        <Row gutter={8}>
          <Col span={10}>
            <Form.Item
              label="是否打乱"
              name="shuffle"
              tooltip="拆分数据前是否打乱顺序"
              valuePropName="checked"
            >
              <Switch />
            </Form.Item>
          </Col>
        </Row>
        <Row gutter={8}>
          <Col span={10}>
            <Form.Item
              label="训练集比率"
              name="train_size"
              tooltip="重采样划分训练集和验证集,训练集的比率,0到1之间"
            >
              {/* A fractional step keeps the steppers usable inside the 0..1 range. */}
              <InputNumber placeholder="请输入训练集比率" min={0} max={1} step={0.1} />
            </Form.Item>
          </Col>
        </Row>
      </>
    );
  }

  export default ExecuteConfig;