You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

ExecuteConfig.tsx 19 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556
  1. import SubAreaTitle from '@/components/SubAreaTitle';
  2. import {
  3. AutoMLEnsembleClass,
  4. AutoMLResamplingStrategy,
  5. AutoMLTaskType,
  6. autoMLEnsembleClassOptions,
  7. autoMLResamplingStrategyOptions,
  8. autoMLTaskTypeOptions,
  9. } from '@/enums';
  10. import { Col, Form, InputNumber, Radio, Row, Select, Switch } from 'antd';
  11. // 分类算法
  12. const classificationAlgorithms = [
  13. 'adaboost',
  14. 'bernoulli_nb',
  15. 'decision_tree',
  16. 'extra_trees',
  17. 'gaussian_nb',
  18. 'gradient_boosting',
  19. 'k_nearest_neighbors',
  20. 'lda',
  21. 'liblinear_svc',
  22. 'libsvm_svc',
  23. 'mlp',
  24. 'multinomial_nb',
  25. 'passive_aggressive',
  26. 'qda',
  27. 'random_forest',
  28. 'sgd',
  29. ].map((name) => ({ label: name, value: name }));
  30. // 回归算法
  31. const regressorAlgorithms = [
  32. 'adaboost',
  33. 'ard_regression',
  34. 'decision_tree',
  35. 'extra_trees',
  36. 'gaussian_process',
  37. 'gradient_boosting',
  38. 'k_nearest_neighbors',
  39. 'liblinear_svr',
  40. 'libsvm_svr',
  41. 'mlp',
  42. 'random_forest',
  43. 'sgd',
  44. ].map((name) => ({ label: name, value: name }));
  45. // 特征预处理算法
  46. const featureAlgorithms = [
  47. 'densifier',
  48. 'extra_trees_preproc_for_classification',
  49. 'extra_trees_preproc_for_regression',
  50. 'fast_ica',
  51. 'feature_agglomeration',
  52. 'kernel_pca',
  53. 'kitchen_sinks',
  54. 'liblinear_svc_preprocessor',
  55. 'no_preprocessing',
  56. 'nystroem_sampler',
  57. 'pca',
  58. 'polynomial',
  59. 'random_trees_embedding',
  60. 'select_percentile_classification',
  61. 'select_percentile_regression',
  62. 'select_rates_classification',
  63. 'select_rates_regression',
  64. 'truncatedSVD',
  65. ].map((name) => ({ label: name, value: name }));
  66. // 分类指标
  67. export const classificationMetrics = [
  68. 'accuracy',
  69. 'balanced_accuracy',
  70. 'roc_auc',
  71. 'average_precision',
  72. 'log_loss',
  73. 'precision_macro',
  74. 'precision_micro',
  75. 'precision_samples',
  76. 'precision_weighted',
  77. 'recall_macro',
  78. 'recall_micro',
  79. 'recall_samples',
  80. 'recall_weighted',
  81. 'f1_macro',
  82. 'f1_micro',
  83. 'f1_samples',
  84. 'f1_weighted',
  85. ].map((name) => ({ label: name, value: name }));
  86. // 回归指标
  87. export const regressionMetrics = [
  88. 'mean_absolute_error',
  89. 'mean_squared_error',
  90. 'root_mean_squared_error',
  91. 'mean_squared_log_error',
  92. 'median_absolute_error',
  93. 'r2',
  94. ].map((name) => ({ label: name, value: name }));
  95. function ExecuteConfig() {
  96. const form = Form.useFormInstance();
  97. const task_type = Form.useWatch('task_type', form);
  98. const include_classifier = Form.useWatch('include_classifier', form);
  99. const exclude_classifier = Form.useWatch('exclude_classifier', form);
  100. const include_regressor = Form.useWatch('include_regressor', form);
  101. const exclude_regressor = Form.useWatch('exclude_regressor', form);
  102. const include_feature_preprocessor = Form.useWatch('include_feature_preprocessor', form);
  103. const exclude_feature_preprocessor = Form.useWatch('exclude_feature_preprocessor', form);
  104. return (
  105. <>
  106. <SubAreaTitle
  107. title="执行配置"
  108. image={require('@/assets/img/model-deployment.png')}
  109. style={{ marginTop: '20px', marginBottom: '24px' }}
  110. ></SubAreaTitle>
  111. <Row gutter={8}>
  112. <Col span={10}>
  113. <Form.Item
  114. label="任务类型"
  115. name="task_type"
  116. rules={[{ required: true, message: '请选择任务类型' }]}
  117. >
  118. <Radio.Group
  119. options={autoMLTaskTypeOptions}
  120. onChange={() => form.resetFields(['metrics'])}
  121. ></Radio.Group>
  122. </Form.Item>
  123. </Col>
  124. </Row>
  125. <Row gutter={8}>
  126. <Col span={10}>
  127. <Form.Item
  128. label="特征预处理算法"
  129. name="include_feature_preprocessor"
  130. tooltip="如果不选,则使用所有可能的特征预处理算法。否则,将只使用包含的特征预处理算法"
  131. >
  132. <Select
  133. allowClear
  134. placeholder="请选择特征预处理算法"
  135. options={featureAlgorithms}
  136. disabled={exclude_feature_preprocessor?.length > 0}
  137. mode="multiple"
  138. showSearch
  139. />
  140. </Form.Item>
  141. </Col>
  142. </Row>
  143. <Row gutter={8}>
  144. <Col span={10}>
  145. <Form.Item
  146. label="排除特征预处理算法"
  147. name="exclude_feature_preprocessor"
  148. tooltip="如果不选,则使用所有可能的特征预处理算法。否则,将排除包含的特征预处理算法"
  149. >
  150. <Select
  151. allowClear
  152. placeholder="排除特征预处理算法"
  153. options={featureAlgorithms}
  154. disabled={include_feature_preprocessor?.length > 0}
  155. mode="multiple"
  156. showSearch
  157. />
  158. </Form.Item>
  159. </Col>
  160. </Row>
  161. <Form.Item dependencies={['task_type']} noStyle>
  162. {({ getFieldValue }) => {
  163. return getFieldValue('task_type') === AutoMLTaskType.Classification ? (
  164. <>
  165. <Row gutter={8}>
  166. <Col span={10}>
  167. <Form.Item
  168. label="分类算法"
  169. name="include_classifier"
  170. tooltip="如果不选,则使用所有可能的分类算法。否则,将只使用包含的算法"
  171. >
  172. <Select
  173. allowClear
  174. placeholder="请选择分类算法"
  175. options={classificationAlgorithms}
  176. mode="multiple"
  177. disabled={exclude_classifier?.length > 0}
  178. showSearch
  179. />
  180. </Form.Item>
  181. </Col>
  182. </Row>
  183. <Row gutter={8}>
  184. <Col span={10}>
  185. <Form.Item
  186. label="排除分类算法"
  187. name="exclude_classifier"
  188. tooltip="如果不选,则使用所有可能的分类算法。否则,将排除包含的算法"
  189. >
  190. <Select
  191. allowClear
  192. placeholder="排除分类算法"
  193. options={classificationAlgorithms}
  194. mode="multiple"
  195. disabled={include_classifier?.length > 0}
  196. showSearch
  197. />
  198. </Form.Item>
  199. </Col>
  200. </Row>
  201. </>
  202. ) : (
  203. <>
  204. <Row gutter={8}>
  205. <Col span={10}>
  206. <Form.Item
  207. label="回归算法"
  208. name="include_regressor"
  209. tooltip="如果不选,则使用所有可能的回归算法。否则,将只使用包含的算法"
  210. >
  211. <Select
  212. allowClear
  213. placeholder="请选择回归算法"
  214. options={regressorAlgorithms}
  215. mode="multiple"
  216. disabled={exclude_regressor?.length > 0}
  217. showSearch
  218. />
  219. </Form.Item>
  220. </Col>
  221. </Row>
  222. <Row gutter={8}>
  223. <Col span={10}>
  224. <Form.Item
  225. label="排除的回归算法"
  226. name="exclude_regressor"
  227. tooltip="如果不选,则使用所有可能的回归算法。否则,将排除包含的算法"
  228. >
  229. <Select
  230. allowClear
  231. placeholder="排除回归算法"
  232. options={regressorAlgorithms}
  233. mode="multiple"
  234. disabled={include_regressor?.length > 0}
  235. showSearch
  236. />
  237. </Form.Item>
  238. </Col>
  239. </Row>
  240. </>
  241. );
  242. }}
  243. </Form.Item>
  244. <Row gutter={8}>
  245. <Col span={10}>
  246. <Form.Item
  247. label="集成方式"
  248. name="ensemble_class"
  249. tooltip="仅使用单个最佳模型还是集成模型"
  250. >
  251. <Radio.Group options={autoMLEnsembleClassOptions}></Radio.Group>
  252. </Form.Item>
  253. </Col>
  254. </Row>
  255. <Form.Item dependencies={['ensemble_class']} noStyle>
  256. {({ getFieldValue }) => {
  257. return getFieldValue('ensemble_class') === AutoMLEnsembleClass.Default ? (
  258. <>
  259. <Row gutter={8}>
  260. <Col span={10}>
  261. <Form.Item
  262. label="集成模型数量"
  263. name="ensemble_size"
  264. tooltip="集成模型数量,如果设置为0,则没有集成。默认50"
  265. >
  266. <InputNumber placeholder="请输入集成模型数量" min={0} precision={0} />
  267. </Form.Item>
  268. </Col>
  269. </Row>
  270. <Row gutter={8}>
  271. <Col span={10}>
  272. <Form.Item
  273. label="集成最佳模型数量"
  274. name="ensemble_nbest"
  275. tooltip="仅集成最佳的N个模型"
  276. >
  277. <InputNumber placeholder="请输入集成最佳模型数量" min={0} precision={0} />
  278. </Form.Item>
  279. </Col>
  280. </Row>
  281. </>
  282. ) : null;
  283. }}
  284. </Form.Item>
  285. <Row gutter={8}>
  286. <Col span={10}>
  287. <Form.Item
  288. label="最大数量"
  289. name="max_models_on_disc"
  290. tooltip="定义在磁盘中保存的模型的最大数量。额外的模型数量将被永久删除,它设置了一个集成可以使用多少个模型的上限。必须是大于等于1的整数,默认50"
  291. >
  292. <InputNumber placeholder="请输入最大数量" min={0} precision={0} />
  293. </Form.Item>
  294. </Col>
  295. </Row>
  296. <Row gutter={8}>
  297. <Col span={10}>
  298. <Form.Item
  299. label="内存限制(MB)"
  300. name="memory_limit"
  301. tooltip="机器学习算法的内存限制(MB)。如果自动机器学习试图分配超过memory_limit MB,它将停止拟合机器学习算法。默认3072"
  302. >
  303. <InputNumber placeholder="请输入内存限制" min={0} precision={0} />
  304. </Form.Item>
  305. </Col>
  306. </Row>
  307. <Row gutter={8}>
  308. <Col span={10}>
  309. <Form.Item
  310. label="单次时间限制(秒)"
  311. name="per_run_time_limit"
  312. tooltip="单次调用机器学习模型的时间限制(以秒为单位)。如果机器学习算法运行超过时间限制,将终止模型拟合,默认600"
  313. >
  314. <InputNumber placeholder="请输入时间限制" min={0} precision={0} />
  315. </Form.Item>
  316. </Col>
  317. </Row>
  318. <Row gutter={8}>
  319. <Col span={10}>
  320. <Form.Item
  321. label="搜索时间限制(秒)"
  322. name="time_left_for_this_task"
  323. tooltip="搜索合适模型的时间限制(以秒为单位)。通过增加这个值,自动机器学习有更高的机会找到更好的模型。默认3600。"
  324. >
  325. <InputNumber placeholder="请输入搜索时间限制" min={0} precision={0} />
  326. </Form.Item>
  327. </Col>
  328. </Row>
  329. <Row gutter={8}>
  330. <Col span={10}>
  331. <Form.Item
  332. label="测试集比率"
  333. name="test_size"
  334. tooltip="将数据划分为训练数据和测试数据,测试数据集所占比例,0到1之间"
  335. >
  336. <InputNumber placeholder="请输入测试集比率" min={0} max={1} />
  337. </Form.Item>
  338. </Col>
  339. </Row>
  340. <Row gutter={8}>
  341. <Col span={10}>
  342. <Form.Item label="计算指标" name="scoring_functions" tooltip="需要计算并打印的指标">
  343. <Select
  344. allowClear
  345. placeholder="请选择计算指标"
  346. options={
  347. task_type === AutoMLTaskType.Classification
  348. ? classificationMetrics
  349. : regressionMetrics
  350. }
  351. showSearch
  352. />
  353. </Form.Item>
  354. </Col>
  355. </Row>
  356. <Row gutter={8}>
  357. <Col span={10}>
  358. <Form.Item label="随机种子" name="seed" tooltip="随机种子,将决定输出文件名">
  359. <InputNumber placeholder="请输入随机种子" min={0} precision={0} />
  360. </Form.Item>
  361. </Col>
  362. </Row>
  363. <SubAreaTitle
  364. title="重采样策略"
  365. image={require('@/assets/img/resample-icon.png')}
  366. style={{ marginTop: '20px', marginBottom: '24px' }}
  367. ></SubAreaTitle>
  368. <Row gutter={8}>
  369. <Col span={10}>
  370. <Form.Item
  371. label="重采样策略"
  372. name="resampling_strategy"
  373. tooltip="重采样策略,分为holdout和crossValid。holdout指定训练数据划分为训练集和验证集的比例。crossValid为交叉验证。"
  374. >
  375. <Select
  376. allowClear
  377. placeholder="请选择重采样策略"
  378. options={autoMLResamplingStrategyOptions}
  379. showSearch
  380. />
  381. </Form.Item>
  382. </Col>
  383. </Row>
  384. <Form.Item dependencies={['resampling_strategy']} noStyle>
  385. {({ getFieldValue }) => {
  386. return getFieldValue('resampling_strategy') === AutoMLResamplingStrategy.CrossValid ? (
  387. <Row gutter={8}>
  388. <Col span={10}>
  389. <Form.Item
  390. label="交叉验证折数"
  391. name="folds"
  392. rules={[
  393. {
  394. required: true,
  395. message: '请输入交叉验证折数',
  396. },
  397. ]}
  398. >
  399. <InputNumber placeholder="请输入交叉验证折数" min={0} precision={0} />
  400. </Form.Item>
  401. </Col>
  402. </Row>
  403. ) : null;
  404. }}
  405. </Form.Item>
  406. <Row gutter={8}>
  407. <Col span={10}>
  408. <Form.Item label="是否打乱" name="shuffle" tooltip="拆分数据前是否打乱顺序">
  409. <Switch />
  410. </Form.Item>
  411. </Col>
  412. </Row>
  413. <Row gutter={8}>
  414. <Col span={10}>
  415. <Form.Item
  416. label="训练集比率"
  417. name="train_size"
  418. tooltip="重采样划分训练集和验证集,训练集的比率,0到1之间"
  419. >
  420. <InputNumber placeholder="请输入训练集比率" min={0} max={1} />
  421. </Form.Item>
  422. </Col>
  423. </Row>
  424. {/* <Row gutter={8}>
  425. <Col span={10}>
  426. <Form.Item
  427. label="文件夹路径"
  428. name="tmp_folder"
  429. tooltip="存放配置输出和日志文件的文件夹"
  430. rules={[
  431. {
  432. pattern: /^\/[a-zA-Z0-9._/-]+$/,
  433. message:
  434. '请输入正确的文件夹路径,以 / 开头,只支持字母、数字、点、下划线、中横线、斜杠',
  435. },
  436. ]}
  437. >
  438. <Input placeholder="请输入文件夹路径" maxLength={64} showCount allowClear />
  439. </Form.Item>
  440. </Col>
  441. </Row> */}
  442. {/* <Form.List name="hyper-parameter">
  443. {(fields, { add, remove }) => (
  444. <>
  445. <Row gutter={8}>
  446. <Col span={10}>
  447. <Form.Item
  448. label="超参数"
  449. style={{ marginBottom: 0, marginTop: '-14px' }}
  450. tooltip="超参数"
  451. ></Form.Item>
  452. </Col>
  453. </Row>
  454. <div className={styles['hyper-parameter']}>
  455. <Flex align="center" className={styles['hyper-parameter__header']}>
  456. <div className={styles['hyper-parameter__header__name']}>参数名称</div>
  457. <div className={styles['hyper-parameter__header__type']}>约束类型</div>
  458. <div className={styles['hyper-parameter__header__space']}>搜索空间</div>
  459. <div className={styles['hyper-parameter__header__operation']}>操作</div>
  460. </Flex>
  461. {fields.map(({ key, name, ...restField }, index) => (
  462. <Flex key={key} align="center" className={styles['hyper-parameter__body']}>
  463. <Form.Item
  464. className={styles['hyper-parameter__body__name']}
  465. {...restField}
  466. name={[name, 'name']}
  467. rules={[{ required: true, message: 'Missing first name' }]}
  468. >
  469. <Input placeholder="Key" />
  470. </Form.Item>
  471. <Form.Item
  472. className={styles['hyper-parameter__body__name']}
  473. {...restField}
  474. name={[name, 'type']}
  475. rules={[{ required: true, message: 'Missing last name' }]}
  476. >
  477. <Input placeholder="Value" />
  478. </Form.Item>
  479. <Form.Item
  480. className={styles['hyper-parameter__body__name']}
  481. {...restField}
  482. name={[name, 'space']}
  483. rules={[{ required: true, message: 'Missing last name' }]}
  484. >
  485. <Input placeholder="Value" />
  486. </Form.Item>
  487. <div className={styles['hyper-parameter__body__operation']}>
  488. <Button
  489. style={{
  490. marginRight: '3px',
  491. }}
  492. shape="circle"
  493. disabled={fields.length === 1}
  494. type="text"
  495. size="middle"
  496. onClick={() => remove(name)}
  497. icon={<MinusCircleOutlined />}
  498. ></Button>
  499. {index === fields.length - 1 && (
  500. <Button
  501. shape="circle"
  502. size="middle"
  503. type="text"
  504. onClick={() => add()}
  505. icon={<PlusCircleOutlined />}
  506. ></Button>
  507. )}
  508. </div>
  509. </Flex>
  510. ))}
  511. {fields.length === 0 && (
  512. <div className={styles['hyper-parameter__add']}>
  513. <Button type="link" onClick={() => add()} icon={<KFIcon type="icon-xinjian2" />}>
  514. 添加一行
  515. </Button>
  516. </div>
  517. )}
  518. </div>
  519. </>
  520. )}
  521. </Form.List> */}
  522. </>
  523. );
  524. }
  525. export default ExecuteConfig;