diff --git a/modules/cloudbrain/cloudbrain.go b/modules/cloudbrain/cloudbrain.go index 6111cf460..b70ca2ada 100755 --- a/modules/cloudbrain/cloudbrain.go +++ b/modules/cloudbrain/cloudbrain.go @@ -490,6 +490,22 @@ func RestartTask(ctx *context.Context, task *models.Cloudbrain, newID *string) e } } + if task.PreTrainModelUrl != "" { //预训练 + realPath := setting.Attachment.Minio.RealPath + task.PreTrainModelUrl + _, err := os.Stat(realPath) + if err != nil { + log.Warn("The model may be deleted", err) + } else { + volumes = append(volumes, models.Volume{ + HostPath: models.StHostPath{ + Path: realPath, + MountPath: ModelMountPath, + ReadOnly: true, + }, + }) + } + } + createTime := timeutil.TimeStampNow() jobResult, err := CreateJob(jobName, models.CreateJobParams{ JobName: jobName, diff --git a/modules/grampus/grampus.go b/modules/grampus/grampus.go index 1b69edd44..712afc337 100755 --- a/modules/grampus/grampus.go +++ b/modules/grampus/grampus.go @@ -1,7 +1,6 @@ package grampus import ( - "encoding/json" "fmt" "strings" @@ -40,7 +39,7 @@ var ( SpecialPools *models.SpecialPools - CommandPrepareScriptGpu = ";mkdir -p output;mkdir -p code;mkdir -p dataset;mkdir -p pretrainmodel;echo \"start loading script\";wget -q https://openi.pcl.ac.cn/OpenIOSSG/%s/archive/master.zip;" + + CommandPrepareScriptGpu = ";mkdir -p output;mkdir -p code;mkdir -p dataset;mkdir -p pretrainmodel;echo \"start loading script\";wget -q https://git.openi.org.cn/OpenIOSSG/%s/archive/master.zip;" + "echo \"finish loading script\";unzip -q master.zip;cd %s;chmod 777 downloader_for_obs uploader_for_npu downloader_for_minio uploader_for_gpu;" ) @@ -85,22 +84,28 @@ type GenerateTrainJobReq struct { } type GenerateNotebookJobReq struct { - JobName string - Command string - ImageUrl string - ImageId string - DisplayJobName string - Uuid string - Description string - CodeStoragePath string - CommitID string - BranchName string - ComputeResource string - ProcessType string - DatasetNames string - DatasetInfos map[string]models.DatasetInfo - Spec *models.Specification - CodeName string + JobName string + Command string + ImageUrl string + ImageId string + DisplayJobName string + Uuid string + Description string + CodeStoragePath string + CommitID string + BranchName string + ComputeResource string + ProcessType string + DatasetNames string + DatasetInfos map[string]models.DatasetInfo + ModelName string + LabelName string + CkptName string + ModelVersion string + PreTrainModelPath string + PreTrainModelUrl string + Spec *models.Specification + CodeName string } func getEndPoint() string { @@ -151,16 +156,37 @@ func GenerateNotebookJob(ctx *context.Context, req *GenerateNotebookJobReq) (job imageUrl := req.ImageUrl if ProcessorTypeNPU == req.ProcessType { datasetGrampus = getDatasetGrampus(req.DatasetInfos) + if len(req.ModelName) != 0 { + datasetGrampus = append(datasetGrampus, models.GrampusDataset{ + Name: req.ModelName, + Bucket: setting.Bucket, + EndPoint: getEndPoint(), + ReadOnly: true, + ObjectKey: req.PreTrainModelPath, + }) + } + codeGrampus = models.GrampusDataset{ Name: req.CodeName, Bucket: setting.Bucket, EndPoint: getEndPoint(), ObjectKey: req.CodeStoragePath + cloudbrain.DefaultBranchName + ".zip", + ReadOnly: false, } imageUrl = "" req.Command = "" } else { datasetGrampus, cpCommand = getDatasetGPUGrampus(req.DatasetInfos) + if len(req.ModelName) != 0 { + datasetGrampus = append(datasetGrampus, models.GrampusDataset{ + Name: req.ModelName, + Bucket: setting.Attachment.Minio.Bucket, + EndPoint: setting.Attachment.Minio.Endpoint, + ObjectKey: req.PreTrainModelPath, + ReadOnly: true, + ContainerPath: "/model", + }) + } codeGrampus = models.GrampusDataset{ Name: req.CodeName, Bucket: setting.Attachment.Minio.Bucket, @@ -218,6 +244,11 @@ func GenerateNotebookJob(ctx *context.Context, req *GenerateNotebookJobReq) (job CreatedUnix: createTime, UpdatedUnix: createTime, Spec: req.Spec, + ModelName: req.ModelName, + ModelVersion: req.ModelVersion, + LabelName: req.LabelName, + PreTrainModelUrl: req.PreTrainModelUrl, + CkptName: req.CkptName, }) if err != nil { @@ -406,11 +437,6 @@ func TransTrainJobStatus(status string) string { return strings.ToUpper(status) } -func InitSpecialPool() { - if SpecialPools == nil && setting.Grampus.SpecialPools != "" { - json.Unmarshal([]byte(setting.Grampus.SpecialPools), &SpecialPools) - } -} func GetNpuModelRemoteObsUrl(jobName string) string { return "s3:///" + BucketRemote + "/" + GetNpuModelObjectKey(jobName) diff --git a/routers/repo/grampus.go b/routers/repo/grampus.go index d5b525415..c89c0451b 100755 --- a/routers/repo/grampus.go +++ b/routers/repo/grampus.go @@ -252,6 +252,16 @@ func GrampusNotebookCreate(ctx *context.Context, form auth.CreateGrampusNotebook CodeName: strings.ToLower(repo.Name), } + if form.ModelName != "" { //使用预训练模型训练 + req.ModelName = form.ModelName + req.LabelName = form.LabelName + req.CkptName = form.CkptName + req.ModelVersion = form.ModelVersion + req.PreTrainModelUrl = form.PreTrainModelUrl + req.PreTrainModelPath = getPreTrainModelPath(form.PreTrainModelUrl, form.CkptName) + + } + _, err = grampus.GenerateNotebookJob(ctx, req) if err != nil { log.Error("GenerateNotebookJob failed:%v", err.Error(), ctx.Data["MsgID"])