| @@ -126,7 +126,7 @@ const Dataset = () => { | |||
| }; | |||
| const getDatasetVersions = (params) => { | |||
| getDatasetVersionIdList(params).then((res) => { | |||
| setWordList(res.data); | |||
| setWordList(res?.data?.content ?? []); | |||
| }); | |||
| }; | |||
| const handleChange = (value) => { | |||
| @@ -277,7 +277,9 @@ const Dataset = () => { | |||
| </div> | |||
| </div> | |||
| <div style={{ marginBottom: '10px', fontSize: '14px' }}> | |||
| {wordList.length > 0 ? wordList[0].description : null} | |||
| {wordList.length > 0 && wordList[0].description | |||
| ? '版本描述:' + wordList[0].description | |||
| : null} | |||
| </div> | |||
| <Table columns={columns} dataSource={wordList} pagination={false} rowKey="id" /> | |||
| </div> | |||
| @@ -48,4 +48,12 @@ | |||
| margin-right: 10px; | |||
| } | |||
| } | |||
| .global_param_item { | |||
| max-height: 230px; | |||
| padding: 20px 12px; | |||
| overflow-y: auto; | |||
| border: 1px solid #e6e6e6; | |||
| border-radius: 6px; | |||
| } | |||
| } | |||
| @@ -1,4 +1,5 @@ | |||
| import { Form, Input, Modal, Select } from 'antd'; | |||
| import { type PipelineGlobalParam } from '@/types'; | |||
| import { Form, Input, Modal, Radio, Select, type FormRule } from 'antd'; | |||
| import { useState } from 'react'; | |||
| import styles from './addExperimentModal.less'; | |||
| @@ -6,6 +7,7 @@ type FormData = { | |||
| name?: string; | |||
| description?: string; | |||
| workflow_id?: string | number; | |||
| global_param?: PipelineGlobalParam[]; | |||
| }; | |||
| type AddExperimentModalProps = { | |||
| @@ -17,17 +19,55 @@ type AddExperimentModalProps = { | |||
| initialValues: FormData; | |||
| }; | |||
| interface GlobalParam { | |||
| param_name: string; | |||
| param_value: string; | |||
| } | |||
| interface Workflow { | |||
| id: string | number; | |||
| name: string; | |||
| global_param?: GlobalParam[] | null; | |||
| global_param?: PipelineGlobalParam[] | null; | |||
| } | |||
| // 根据参数设置输入组件 | |||
| export const getParamComponent = (paramType: number, isSensitive?: number): JSX.Element => { | |||
| // 防止后台返回不是 number 类型 | |||
| if (Number(paramType) === 3) { | |||
| return ( | |||
| <Radio.Group> | |||
| <Radio value={1}>是</Radio> | |||
| <Radio value={0}>否</Radio> | |||
| </Radio.Group> | |||
| ); | |||
| } | |||
| if (isSensitive && Number(isSensitive) === 1) { | |||
| return <Input.Password visibilityToggle={false} allowClear />; | |||
| } | |||
| return <Input placeholder="请输入值" allowClear />; | |||
| }; | |||
| // 根据参数设置校验规则 | |||
| export const getParamRules = (paramType: number, required: boolean = false): FormRule[] => { | |||
| const rules = []; | |||
| // 防止后台返回不是 number 类型 | |||
| if (Number(paramType) === 2) { | |||
| rules.push({ | |||
| pattern: /^-?\d+(\.\d+)?$/, | |||
| message: '整型必须是数字', | |||
| }); | |||
| } | |||
| if (required) { | |||
| rules.push({ required: true, message: '请输入值' }); | |||
| } | |||
| return rules; | |||
| }; | |||
| // 根据参数设置 label | |||
| const getParamType = (param: PipelineGlobalParam): string => { | |||
| const paramTypes: Readonly<Record<number, string>> = { | |||
| 1: '字符串', | |||
| 2: '整型', | |||
| 3: '布尔类型', | |||
| }; | |||
| return param.param_name + `(${paramTypes[param.param_type]})`; | |||
| }; | |||
| function AddExperimentModal({ | |||
| isAdd, | |||
| open, | |||
| @@ -38,20 +78,30 @@ function AddExperimentModal({ | |||
| }: AddExperimentModalProps) { | |||
| const dialogTitle = isAdd ? '新建实验' : '编辑实验'; | |||
| const workflowDisabled = isAdd ? false : true; | |||
| const [globalParam, setGlobalParam] = useState<GlobalParam[]>([]); | |||
| const [globalParam, setGlobalParam] = useState<PipelineGlobalParam[]>( | |||
| initialValues.global_param || [], | |||
| ); | |||
| const [form] = Form.useForm(); | |||
| const layout = { | |||
| labelCol: { span: 24 }, | |||
| wrapperCol: { span: 24 }, | |||
| }; | |||
| const tailLayout = { | |||
| labelCol: { span: 8 }, | |||
| wrapperCol: { span: 16 }, | |||
| }; | |||
| // 除了流水线选择发生变化 | |||
| const handleWorkflowChange = (id: string) => { | |||
| const handleWorkflowChange = (id: string | number) => { | |||
| const pipeline: Workflow | undefined = workflowList.find((v) => v.id === id); | |||
| if (pipeline && pipeline.global_param) { | |||
| setGlobalParam(pipeline.global_param); | |||
| const fields = pipeline.global_param.reduce((acc, item) => { | |||
| acc[item.param_name] = item.param_value; | |||
| return acc; | |||
| }, {} as Record<string, string>); | |||
| form.setFieldsValue(fields); | |||
| form.setFieldValue('global_param', pipeline.global_param); | |||
| } else { | |||
| setGlobalParam([]); | |||
| form.setFieldValue('global_param', []); | |||
| } | |||
| }; | |||
| return ( | |||
| @@ -73,11 +123,12 @@ function AddExperimentModal({ | |||
| > | |||
| <Form | |||
| name="form" | |||
| layout="vertical" | |||
| layout="horizontal" | |||
| initialValues={initialValues} | |||
| onFinish={onFinish} | |||
| autoComplete="off" | |||
| form={form} | |||
| {...layout} | |||
| > | |||
| <Form.Item | |||
| label="实验名称" | |||
| @@ -114,11 +165,32 @@ function AddExperimentModal({ | |||
| : null} | |||
| </Select> | |||
| </Form.Item> | |||
| {globalParam.map((item) => ( | |||
| <Form.Item label={item.param_name} name={item.param_name} key={item.param_name}> | |||
| <Input /> | |||
| {globalParam.length > 0 && ( | |||
| <Form.Item label="运行参数" tooltip="展示关联的流水线的参数,脱敏的参数以xxxx展示"> | |||
| <div className={styles.global_param_item}> | |||
| <Form.List name="global_param"> | |||
| {(fields) => | |||
| fields.map(({ key, name, ...restField }) => ( | |||
| <Form.Item | |||
| {...tailLayout} | |||
| {...restField} | |||
| key={key} | |||
| label={getParamType(globalParam[name])} | |||
| name={[name, 'param_value']} | |||
| labelAlign="left" | |||
| rules={getParamRules(globalParam[name]['param_type'])} | |||
| > | |||
| {getParamComponent( | |||
| globalParam[name]['param_type'], | |||
| globalParam[name]['is_sensitive'], | |||
| )} | |||
| </Form.Item> | |||
| )) | |||
| } | |||
| </Form.List> | |||
| </div> | |||
| </Form.Item> | |||
| ))} | |||
| )} | |||
| </Form> | |||
| </Modal> | |||
| ); | |||
| @@ -119,22 +119,12 @@ function Experiment() { | |||
| }; | |||
| // 创建或者编辑实验接口请求 | |||
| const handleAddExperiment = async (values) => { | |||
| const workflow_id = values['workflow_id']; | |||
| let global_param = undefined; | |||
| const pipeline = workflowList.find((v) => v.id === workflow_id); | |||
| if (pipeline && pipeline.global_param) { | |||
| const globalParamList = [...pipeline.global_param]; | |||
| for (const item of globalParamList) { | |||
| item.param_value = values[item.param_name]; | |||
| values[item.param_name] = undefined; | |||
| } | |||
| global_param = JSON.stringify(globalParamList); | |||
| } | |||
| const params = { | |||
| ...values, | |||
| global_param, | |||
| }; | |||
| const global_param = JSON.stringify(values.global_param); | |||
| if (!experimentId) { | |||
| const params = { | |||
| ...values, | |||
| global_param, | |||
| }; | |||
| const [res, _] = await to(postExperiment(params)); | |||
| if (res) { | |||
| message.success('新建实验成功'); | |||
| @@ -142,7 +132,7 @@ function Experiment() { | |||
| getList(); | |||
| } | |||
| } else { | |||
| const params = { ...values, id: experimentId }; | |||
| const params = { ...values, global_param, id: experimentId }; | |||
| const [res, _] = await to(putExperiment(params)); | |||
| if (res) { | |||
| message.success('编辑实验成功'); | |||
| @@ -431,14 +421,16 @@ function Experiment() { | |||
| rowExpandable: (record) => true, | |||
| }} | |||
| /> | |||
| <AddExperimentModal | |||
| isAdd={isAdd} | |||
| open={isModalOpen} | |||
| initialValues={addFormData} | |||
| onCancel={handleCancel} | |||
| onFinish={handleAddExperiment} | |||
| workflowList={workflowList} | |||
| /> | |||
| {isModalOpen && ( | |||
| <AddExperimentModal | |||
| isAdd={isAdd} | |||
| open={isModalOpen} | |||
| initialValues={addFormData} | |||
| onCancel={handleCancel} | |||
| onFinish={handleAddExperiment} | |||
| workflowList={workflowList} | |||
| /> | |||
| )} | |||
| </div> | |||
| ); | |||
| } | |||
| @@ -121,8 +121,7 @@ const Dataset = () => { | |||
| }; | |||
| const getModelVersions = (params) => { | |||
| getModelVersionIdList(params).then((ret) => { | |||
| console.log(ret); | |||
| setWordList(ret.data); | |||
| setWordList(ret?.data?.content ?? []); | |||
| }); | |||
| }; | |||
| const handleExport = async () => { | |||
| @@ -278,7 +277,9 @@ const Dataset = () => { | |||
| </div> | |||
| </div> | |||
| <div style={{ marginBottom: '10px', fontSize: '14px' }}> | |||
| {wordList.length > 0 ? wordList[0].description : null} | |||
| {wordList.length > 0 && wordList[0].description | |||
| ? '版本描述:' + wordList[0].description | |||
| : null} | |||
| </div> | |||
| <Table columns={columns} dataSource={wordList} pagination={false} rowKey="id" /> | |||
| </div> | |||
| @@ -8,7 +8,13 @@ | |||
| top: 5px; | |||
| right: 0; | |||
| } | |||
| .add_button { | |||
| .add_button_form_item { | |||
| margin-top: 15px; | |||
| &:first-child { | |||
| margin-top: 0; | |||
| } | |||
| } | |||
| .add_button_form_item .add_button { | |||
| padding: 0; | |||
| } | |||
| @@ -1,7 +1,12 @@ | |||
| import { | |||
| getParamComponent, | |||
| getParamRules, | |||
| } from '@/pages/Experiment/experimentText/addExperimentModal'; | |||
| import { type PipelineGlobalParam } from '@/types'; | |||
| import { to } from '@/utils/promise'; | |||
| import { DeleteOutlined, PlusOutlined } from '@ant-design/icons'; | |||
| import { Button, Drawer, Form, Input, Radio } from 'antd'; | |||
| import { NamePath } from 'antd/es/form/interface'; | |||
| import { forwardRef, useImperativeHandle } from 'react'; | |||
| import styles from './globalParamsDrawer.less'; | |||
| @@ -26,6 +31,11 @@ const GlobalParamsDrawer = forwardRef( | |||
| } | |||
| }, | |||
| })); | |||
| const handleTypeChange = (name: NamePath) => { | |||
| form.setFieldValue(name, null); | |||
| }; | |||
| return ( | |||
| <Drawer | |||
| rootStyle={{ marginTop: '45px' }} | |||
| @@ -77,26 +87,41 @@ const GlobalParamsDrawer = forwardRef( | |||
| label="类 型" | |||
| rules={[{ required: true, message: '请选择类型' }]} | |||
| > | |||
| <Radio.Group> | |||
| <Radio.Group | |||
| onChange={() => handleTypeChange(['global_param', name, 'param_value'])} | |||
| > | |||
| <Radio value={1}>字符串</Radio> | |||
| <Radio value={2}>整型</Radio> | |||
| <Radio value={3}>布尔类型</Radio> | |||
| </Radio.Group> | |||
| </Form.Item> | |||
| <Form.Item | |||
| {...restField} | |||
| name={[name, 'param_value']} | |||
| label="值" | |||
| rules={[{ required: true, message: '请输入值' }]} | |||
| noStyle | |||
| shouldUpdate={(prev, cur) => | |||
| prev.global_param?.[name]?.param_type !== | |||
| cur.global_param?.[name]?.param_type | |||
| } | |||
| > | |||
| <Input placeholder="请输入值" allowClear /> | |||
| {({ getFieldValue }) => ( | |||
| <Form.Item | |||
| {...restField} | |||
| name={[name, 'param_value']} | |||
| label="值" | |||
| rules={getParamRules( | |||
| getFieldValue(['global_param', name, 'param_type']), | |||
| true, | |||
| )} | |||
| > | |||
| {getParamComponent(getFieldValue(['global_param', name, 'param_type']))} | |||
| </Form.Item> | |||
| )} | |||
| </Form.Item> | |||
| <Form.Item | |||
| {...restField} | |||
| name={[name, 'is_sensitive']} | |||
| label="脱敏显示" | |||
| rules={[{ required: true, message: '请选择' }]} | |||
| tooltip="脱敏后的参数以*****显示" | |||
| tooltip="展示关联的流水线的参数,脱敏的参数以xxxx展示" | |||
| > | |||
| <Radio.Group> | |||
| <Radio value={1}>是</Radio> | |||
| @@ -111,11 +136,11 @@ const GlobalParamsDrawer = forwardRef( | |||
| ></Button> | |||
| </div> | |||
| ))} | |||
| <Form.Item> | |||
| <Form.Item className={styles.add_button_form_item}> | |||
| <Button | |||
| className={styles.add_button} | |||
| type="link" | |||
| onClick={add} | |||
| onClick={() => add()} | |||
| icon={<PlusOutlined />} | |||
| > | |||
| 流水线参数 | |||
| @@ -101,13 +101,14 @@ const EditPipeline = () => { | |||
| } | |||
| const data = graph.save(); | |||
| console.log(data); | |||
| let params = { | |||
| const params = { | |||
| ...locationParams, | |||
| dag: JSON.stringify(data), | |||
| global_param: JSON.stringify(res.global_param), | |||
| }; | |||
| saveWorkflow(params).then((ret) => { | |||
| message.success('保存成功'); | |||
| closeParamsDrawer(); | |||
| setTimeout(() => { | |||
| if (val) { | |||
| navgite({ pathname: `/pipeline` }); | |||
| @@ -1,6 +1,5 @@ | |||
| // 流水线全局参数 | |||
| export type PipelineGlobalParam = { | |||
| workflow_id: number; | |||
| param_name: string; | |||
| description: string; | |||
| param_type: number; | |||
| @@ -9,3 +9,4 @@ export function getNameByCode(list, code) { | |||
| }); | |||
| return name; | |||
| } | |||
| @@ -66,7 +66,7 @@ public class DatasetVersionController extends BaseController { | |||
| */ | |||
| @GetMapping("/versions") | |||
| @ApiOperation("通过数据集id和version查询版本文件列表") | |||
| public GenericsAjaxResult<List<DatasetVersion>> queryByDatasetIdAndVersion(@RequestParam("dataset_id") Integer datasetId, | |||
| public GenericsAjaxResult<Map<String,Object>> queryByDatasetIdAndVersion(@RequestParam("dataset_id") Integer datasetId, | |||
| @RequestParam("version") String version) { | |||
| return genericsSuccess(this.datasetVersionService.queryByDatasetIdAndVersion(datasetId, version)); | |||
| } | |||
| @@ -64,7 +64,7 @@ public class ModelsVersionController extends BaseController { | |||
| * @return 匹配的模型版本记录列表 | |||
| */ | |||
| @GetMapping("/versions") | |||
| public GenericsAjaxResult<List<ModelsVersion>> queryByModelsIdAndVersion(@RequestParam("models_id") Integer modelsId, | |||
| public GenericsAjaxResult<Map<String,Object>> queryByModelsIdAndVersion(@RequestParam("models_id") Integer modelsId, | |||
| @RequestParam("version") String version) { | |||
| return genericsSuccess(this.modelsVersionService.queryByModelsIdAndVersion(modelsId, version)); | |||
| } | |||
| @@ -66,7 +66,7 @@ public interface DatasetVersionService { | |||
| DatasetVersion queryByDatasetVersion(DatasetVersion datasetVersion); | |||
| List<DatasetVersion> queryByDatasetIdAndVersion(Integer datasetId, String version); | |||
| Map<String,Object> queryByDatasetIdAndVersion(Integer datasetId, String version); | |||
| Map<Integer,String> deleteDatasetVersion(Integer datasetId, String version); | |||
| @@ -66,7 +66,7 @@ public interface ModelsVersionService { | |||
| ModelsVersion queryByModelsVersion(ModelsVersion modelsVersion); | |||
| List<ModelsVersion> queryByModelsIdAndVersion(Integer modelsId, String version); | |||
| Map<String,Object> queryByModelsIdAndVersion(Integer modelsId, String version); | |||
| Map<Integer, String> deleteModelsVersion(Integer modelsId, String version); | |||
| @@ -210,6 +210,9 @@ public class DatasetServiceImpl implements DatasetService { | |||
| @Override | |||
| public List<Map<String, String>> uploadDataset(MultipartFile[] files) throws Exception { | |||
| List<Map<String, String>> results = new ArrayList<>(); | |||
| //时间戳统一定在外面,一次上传就定好 | |||
| Date createTime = new Date(); | |||
| String timestamp = new SimpleDateFormat("yyyyMMdd-HHmmss").format(createTime); | |||
| for (MultipartFile file:files){ | |||
| if (file.isEmpty()) { | |||
| @@ -222,8 +225,6 @@ public class DatasetServiceImpl implements DatasetService { | |||
| // 其余操作基于 modelsVersionToUse | |||
| String username = SecurityUtils.getLoginUser().getUsername(); | |||
| String fileName = file.getOriginalFilename(); | |||
| Date createTime = new Date(); | |||
| String timestamp = new SimpleDateFormat("yyyyMMdd-HHmmss").format(createTime); | |||
| String objectName = "datasets/" + username + "/" + timestamp + "/" + fileName; | |||
| // 上传文件到MinIO并将记录新增到数据库中 | |||
| @@ -33,6 +33,9 @@ public class DatasetVersionServiceImpl implements DatasetVersionService { | |||
| @Resource | |||
| private DatasetVersionDao datasetVersionDao; | |||
| // 固定存储桶名 | |||
| private final String bucketName = "platform-data"; | |||
| /** | |||
| * 通过ID查询单条数据 | |||
| * | |||
| @@ -131,9 +134,21 @@ public class DatasetVersionServiceImpl implements DatasetVersionService { | |||
| } | |||
| @Override | |||
| public Map<String,Object> queryByDatasetIdAndVersion(Integer datasetId, String version) { | |||
| Map<String, Object> response = new HashMap<>(); | |||
| List<DatasetVersion> datasetVersionList = this.datasetVersionDao.queryAllByDatasetVersion(datasetId, version); | |||
| datasetVersionList.stream(). | |||
| findFirst(). | |||
| ifPresent(datasetVersion -> { | |||
| String url = datasetVersion.getUrl(); | |||
| String path = bucketName + '/' + url.substring(0, url.lastIndexOf('/')); | |||
| response.put("path", path); | |||
| }); | |||
| public List<DatasetVersion> queryByDatasetIdAndVersion(Integer datasetId, String version) { | |||
| return this.datasetVersionDao.queryAllByDatasetVersion(datasetId, version); | |||
| response.put("content", datasetVersionList); | |||
| return response; | |||
| } | |||
| @Override | |||
| @@ -205,6 +205,9 @@ public class ModelsServiceImpl implements ModelsService { | |||
| public List<Map<String, String>> uploadModels(MultipartFile[] files) throws Exception { | |||
| List<Map<String, String>> results = new ArrayList<>(); | |||
| //时间戳统一定在外面,一次上传就定好 | |||
| Date createTime = new Date(); | |||
| String timestamp = new SimpleDateFormat("yyyyMMdd-HHmmss").format(createTime); | |||
| for (MultipartFile file:files){ | |||
| if (file.isEmpty()) { | |||
| @@ -217,8 +220,6 @@ public class ModelsServiceImpl implements ModelsService { | |||
| // 其余操作基于 modelsVersionToUse | |||
| String username = SecurityUtils.getLoginUser().getUsername(); | |||
| String fileName = file.getOriginalFilename(); | |||
| Date createTime = new Date(); | |||
| String timestamp = new SimpleDateFormat("yyyyMMdd-HHmmss").format(createTime); | |||
| String objectName = "models/" + username + "/" + timestamp + "/" + fileName; | |||
| // 上传文件到MinIO并将记录新增到数据库中 | |||
| @@ -34,6 +34,9 @@ public class ModelsVersionServiceImpl implements ModelsVersionService { | |||
| @Resource | |||
| private ModelsDao modelsDao; | |||
| // 固定存储桶名 | |||
| private final String bucketName = "platform-data"; | |||
| /** | |||
| * 通过ID查询单条数据 | |||
| * | |||
| @@ -159,8 +162,21 @@ public class ModelsVersionServiceImpl implements ModelsVersionService { | |||
| * @return 新的模型版本记录列表 | |||
| */ | |||
| @Override | |||
| public List<ModelsVersion> queryByModelsIdAndVersion(Integer modelsId, String version) { | |||
| return this.modelsVersionDao.queryAllByModelsVersion(modelsId, version) ; | |||
| public Map<String,Object> queryByModelsIdAndVersion(Integer modelsId, String version) { | |||
| Map<String,Object> response = new HashMap<>(); | |||
| List<ModelsVersion> modelsVersionList = this.modelsVersionDao.queryAllByModelsVersion(modelsId, version); | |||
| modelsVersionList.stream(). | |||
| findFirst(). | |||
| ifPresent(modelsVersion -> { | |||
| String url = modelsVersion.getUrl(); | |||
| String path = bucketName + '/' + url.substring(0, url.lastIndexOf('/')); | |||
| response.put("path", path); | |||
| }); | |||
| response.put("content", modelsVersionList); | |||
| return response; | |||
| } | |||