diff --git a/routers/repo/aisafety.go b/routers/repo/aisafety.go index b4ecbadf4..d4c162544 100644 --- a/routers/repo/aisafety.go +++ b/routers/repo/aisafety.go @@ -640,10 +640,8 @@ func AiSafetyCreateForPost(ctx *context.Context) { return } if taskType == models.TypeCloudBrainTwo { - ctx.Data["datasetType"] = models.TypeCloudBrainTwo err = createForNPU(ctx, jobName) } else if taskType == models.TypeCloudBrainOne { - ctx.Data["datasetType"] = models.TypeCloudBrainOne err = createForGPU(ctx, jobName) } if err != nil { @@ -971,11 +969,18 @@ func modelSafetyNewDataPrepare(ctx *context.Context) error { ctx.Data["model_version"] = ctx.Query("model_version") if ctx.QueryInt("type") == models.TypeCloudBrainOne { + ctx.Data["type"] = models.TypeCloudBrainOne + ctx.Data["compute_resource"] = models.GPUResource + ctx.Data["datasetType"] = models.TypeCloudBrainOne + ctx.Data["BaseDataSetName"] = setting.ModelSafetyTest.GPUBaseDataSetName ctx.Data["BaseDataSetUUID"] = setting.ModelSafetyTest.GPUBaseDataSetUUID ctx.Data["CombatDataSetName"] = setting.ModelSafetyTest.GPUCombatDataSetName ctx.Data["CombatDataSetUUID"] = setting.ModelSafetyTest.GPUCombatDataSetUUID } else { + ctx.Data["type"] = models.TypeCloudBrainTwo + ctx.Data["compute_resource"] = models.NPUResource + ctx.Data["datasetType"] = models.TypeCloudBrainTwo ctx.Data["BaseDataSetName"] = setting.ModelSafetyTest.NPUBaseDataSetName ctx.Data["BaseDataSetUUID"] = setting.ModelSafetyTest.NPUBaseDataSetUUID ctx.Data["CombatDataSetName"] = setting.ModelSafetyTest.NPUCombatDataSetName