diff --git a/models/cloudbrain.go b/models/cloudbrain.go
index f053552db..d9db9ef23 100755
--- a/models/cloudbrain.go
+++ b/models/cloudbrain.go
@@ -165,6 +165,7 @@ type TaskInfo struct {
Username string `json:"username"`
TaskName string `json:"task_name"`
CodeName string `json:"code_name"`
+ BenchmarkCategory string `json:"selected_category"`
}
func ConvertToTaskPod(input map[string]interface{}) (TaskPod, error) {
diff --git a/modules/auth/cloudbrain.go b/modules/auth/cloudbrain.go
index 2470c2ad6..e7268052d 100755
--- a/modules/auth/cloudbrain.go
+++ b/modules/auth/cloudbrain.go
@@ -11,6 +11,7 @@ type CreateCloudBrainForm struct {
Command string `form:"command" binding:"Required"`
Attachment string `form:"attachment" binding:"Required"`
JobType string `form:"job_type" binding:"Required"`
+ BenchmarkCategory string `form:"benchmark_category"`
}
type CommitImageCloudBrainForm struct {
diff --git a/routers/repo/cloudbrain.go b/routers/repo/cloudbrain.go
index c710f8c5b..8b072e1e6 100755
--- a/routers/repo/cloudbrain.go
+++ b/routers/repo/cloudbrain.go
@@ -165,12 +165,12 @@ func CloudBrainCreate(ctx *context.Context, form auth.CreateCloudBrainForm) {
benchmarkPath := setting.JobPath + jobName + cloudbrain.BenchMarkMountPath
if setting.IsBenchmarkEnabled && jobType == string(models.JobTypeBenchmark) {
- downloadRateCode(repo, jobName, setting.BenchmarkCode, benchmarkPath)
+ downloadRateCode(repo, jobName, setting.BenchmarkCode, benchmarkPath, form.BenchmarkCategory)
}
snn4imagenetPath := setting.JobPath + jobName + cloudbrain.Snn4imagenetMountPath
if setting.IsSnn4imagenetEnabled && jobType == string(models.JobTypeSnn4imagenet) {
- downloadRateCode(repo, jobName, setting.Snn4imagenetCode, snn4imagenetPath)
+ downloadRateCode(repo, jobName, setting.Snn4imagenetCode, snn4imagenetPath, "")
}
err = cloudbrain.GenerateTask(ctx, jobName, image, command, uuid, codePath, modelPath, benchmarkPath, snn4imagenetPath, jobType)
@@ -340,7 +340,7 @@ func downloadCode(repo *models.Repository, codePath string) error {
return nil
}
-func downloadRateCode(repo *models.Repository, taskName, gitPath, codePath string) error {
+func downloadRateCode(repo *models.Repository, taskName, gitPath, codePath, benchmarkCategory string) error {
err := os.MkdirAll(codePath, os.ModePerm)
if err != nil {
log.Error("mkdir codePath failed", err.Error())
@@ -369,6 +369,7 @@ func downloadRateCode(repo *models.Repository, taskName, gitPath, codePath strin
Username: repo.Owner.Name,
TaskName: taskName,
CodeName: repo.Name,
+ BenchmarkCategory: benchmarkCategory,
})
if err != nil {
log.Error("json.Marshal failed", err.Error())
diff --git a/templates/repo/cloudbrain/new.tmpl b/templates/repo/cloudbrain/new.tmpl
index 28e33cec2..2c4832736 100755
--- a/templates/repo/cloudbrain/new.tmpl
+++ b/templates/repo/cloudbrain/new.tmpl
@@ -129,7 +129,7 @@