diff --git a/models/cloudbrain.go b/models/cloudbrain.go index f053552db..d9db9ef23 100755 --- a/models/cloudbrain.go +++ b/models/cloudbrain.go @@ -165,6 +165,7 @@ type TaskInfo struct { Username string `json:"username"` TaskName string `json:"task_name"` CodeName string `json:"code_name"` + BenchmarkCategory string `json:"selected_category"` } func ConvertToTaskPod(input map[string]interface{}) (TaskPod, error) { diff --git a/modules/auth/cloudbrain.go b/modules/auth/cloudbrain.go index 2470c2ad6..e7268052d 100755 --- a/modules/auth/cloudbrain.go +++ b/modules/auth/cloudbrain.go @@ -11,6 +11,7 @@ type CreateCloudBrainForm struct { Command string `form:"command" binding:"Required"` Attachment string `form:"attachment" binding:"Required"` JobType string `form:"job_type" binding:"Required"` + BenchmarkCategory string `form:"benchmark_category"` } type CommitImageCloudBrainForm struct { diff --git a/routers/repo/cloudbrain.go b/routers/repo/cloudbrain.go index c710f8c5b..8b072e1e6 100755 --- a/routers/repo/cloudbrain.go +++ b/routers/repo/cloudbrain.go @@ -165,12 +165,12 @@ func CloudBrainCreate(ctx *context.Context, form auth.CreateCloudBrainForm) { benchmarkPath := setting.JobPath + jobName + cloudbrain.BenchMarkMountPath if setting.IsBenchmarkEnabled && jobType == string(models.JobTypeBenchmark) { - downloadRateCode(repo, jobName, setting.BenchmarkCode, benchmarkPath) + downloadRateCode(repo, jobName, setting.BenchmarkCode, benchmarkPath, form.BenchmarkCategory) } snn4imagenetPath := setting.JobPath + jobName + cloudbrain.Snn4imagenetMountPath if setting.IsSnn4imagenetEnabled && jobType == string(models.JobTypeSnn4imagenet) { - downloadRateCode(repo, jobName, setting.Snn4imagenetCode, snn4imagenetPath) + downloadRateCode(repo, jobName, setting.Snn4imagenetCode, snn4imagenetPath, "") } err = cloudbrain.GenerateTask(ctx, jobName, image, command, uuid, codePath, modelPath, benchmarkPath, snn4imagenetPath, jobType) @@ -340,7 +340,7 @@ func downloadCode(repo *models.Repository, codePath string) error { return nil } -func downloadRateCode(repo *models.Repository, taskName, gitPath, codePath string) error { +func downloadRateCode(repo *models.Repository, taskName, gitPath, codePath, benchmarkCategory string) error { err := os.MkdirAll(codePath, os.ModePerm) if err != nil { log.Error("mkdir codePath failed", err.Error()) @@ -369,6 +369,7 @@ func downloadRateCode(repo *models.Repository, taskName, gitPath, codePath strin Username: repo.Owner.Name, TaskName: taskName, CodeName: repo.Name, + BenchmarkCategory: benchmarkCategory, }) if err != nil { log.Error("json.Marshal failed", err.Error()) diff --git a/templates/repo/cloudbrain/new.tmpl b/templates/repo/cloudbrain/new.tmpl index 28e33cec2..2c4832736 100755 --- a/templates/repo/cloudbrain/new.tmpl +++ b/templates/repo/cloudbrain/new.tmpl @@ -129,7 +129,7 @@