From 65f43fa820645c632db7a591ed1a71467fffdabb Mon Sep 17 00:00:00 2001 From: ychao_1983 Date: Mon, 5 Sep 2022 09:50:47 +0800 Subject: [PATCH] =?UTF-8?q?=E6=8F=90=E4=BA=A4=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- models/cloudbrain.go | 37 ++++++++++++++++---------------- modules/cloudbrain/cloudbrain.go | 37 +++++++++++++++++++++----------- modules/modelarts/modelarts.go | 1 + routers/repo/cloudbrain.go | 14 ++++++++++++ 4 files changed, 59 insertions(+), 30 deletions(-) diff --git a/models/cloudbrain.go b/models/cloudbrain.go index 5be5d566e..650fba242 100755 --- a/models/cloudbrain.go +++ b/models/cloudbrain.go @@ -168,24 +168,25 @@ type Cloudbrain struct { ImageID string //grampus image_id AiCenter string //grampus ai center: center_id+center_name - TrainUrl string //输出模型的obs路径 - BranchName string //分支名称 - Parameters string //传给modelarts的param参数 - BootFile string //启动文件 - DataUrl string //数据集的obs路径 - LogUrl string //日志输出的obs路径 - PreVersionId int64 //父版本的版本id - FlavorCode string //modelarts上的规格id - Description string `xorm:"varchar(256)"` //描述 - WorkServerNumber int //节点数 - FlavorName string //规格名称 - EngineName string //引擎名称 - TotalVersionCount int //任务的所有版本数量,包括删除的 - LabelName string //标签名称 - ModelName string //模型名称 - ModelVersion string //模型版本 - CkptName string //权重文件名称 - ResultUrl string //推理结果的obs路径 + TrainUrl string //输出模型的obs路径 + BranchName string //分支名称 + Parameters string //传给modelarts的param参数 + BootFile string //启动文件 + DataUrl string //数据集的obs路径 + LogUrl string //日志输出的obs路径 + PreVersionId int64 //父版本的版本id + FlavorCode string //modelarts上的规格id + Description string `xorm:"varchar(256)"` //描述 + WorkServerNumber int //节点数 + FlavorName string //规格名称 + EngineName string //引擎名称 + TotalVersionCount int //任务的所有版本数量,包括删除的 + LabelName string //标签名称 + ModelName string //模型名称 + ModelVersion string //模型版本 + CkptName string //权重文件名称 + PreTrainingModelUrl string //预训练模型地址 + ResultUrl string //推理结果的obs路径 User *User `xorm:"-"` Repo *Repository `xorm:"-"` diff --git a/modules/cloudbrain/cloudbrain.go b/modules/cloudbrain/cloudbrain.go index 85dcb1ac5..1ee0bd47b 100755 --- a/modules/cloudbrain/cloudbrain.go +++ b/modules/cloudbrain/cloudbrain.go @@ -20,18 +20,19 @@ import ( const ( //Command = `pip3 install jupyterlab==2.2.5 -i https://pypi.tuna.tsinghua.edu.cn/simple;service ssh stop;jupyter lab --no-browser --ip=0.0.0.0 --allow-root --notebook-dir="/code" --port=80 --LabApp.token="" --LabApp.allow_origin="self https://cloudbrain.pcl.ac.cn"` //CommandBenchmark = `echo "start benchmark";python /code/test.py;echo "end benchmark"` - CommandBenchmark = `echo "start benchmark";cd /benchmark && bash run_bk.sh;echo "end benchmark"` - CodeMountPath = "/code" - DataSetMountPath = "/dataset" - ModelMountPath = "/model" - LogFile = "log.txt" - BenchMarkMountPath = "/benchmark" - BenchMarkResourceID = 1 - Snn4imagenetMountPath = "/snn4imagenet" - BrainScoreMountPath = "/brainscore" - TaskInfoName = "/taskInfo" - Snn4imagenetCommand = `/opt/conda/bin/python /snn4imagenet/testSNN_script.py --modelname '%s' --modelpath '/dataset' --modeldescription '%s'` - BrainScoreCommand = `bash /brainscore/brainscore_test_par4shSrcipt.sh -b '%s' -n '%s' -p '/dataset' -d '%s'` + CommandBenchmark = `echo "start benchmark";cd /benchmark && bash run_bk.sh;echo "end benchmark"` + CodeMountPath = "/code" + DataSetMountPath = "/dataset" + ModelMountPath = "/model" + PretrainModelMountPath = "/pretrainmodel" + LogFile = "log.txt" + BenchMarkMountPath = "/benchmark" + BenchMarkResourceID = 1 + Snn4imagenetMountPath = "/snn4imagenet" + BrainScoreMountPath = "/brainscore" + TaskInfoName = "/taskInfo" + Snn4imagenetCommand = `/opt/conda/bin/python /snn4imagenet/testSNN_script.py --modelname '%s' --modelpath '/dataset' --modeldescription '%s'` + BrainScoreCommand = `bash /brainscore/brainscore_test_par4shSrcipt.sh -b '%s' -n '%s' -p '/dataset' -d '%s'` SubTaskName = "task1" @@ -77,6 +78,8 @@ type GenerateCloudBrainTaskReq struct { ModelVersion string CkptName string LabelName string + PreTrainModelPath string + PreTrainingModelUrl string Spec *models.Specification } @@ -276,6 +279,16 @@ func GenerateTask(req GenerateCloudBrainTaskReq) error { }, } + if req.PreTrainingModelUrl != "" { //预训练 + volumes = append(volumes, models.Volume{ + HostPath: models.StHostPath{ + Path: req.PreTrainModelPath, + MountPath: PretrainModelMountPath, + ReadOnly: true, + }, + }) + } + if len(req.DatasetInfos) == 1 { volumes = append(volumes, models.Volume{ HostPath: models.StHostPath{ diff --git a/modules/modelarts/modelarts.go b/modules/modelarts/modelarts.go index 2887b625d..81dfcc551 100755 --- a/modules/modelarts/modelarts.go +++ b/modules/modelarts/modelarts.go @@ -54,6 +54,7 @@ const ( MultiDataUrl = "multi_data_url" ResultUrl = "result_url" CkptUrl = "ckpt_url" + PreTrainModelUrl = "pretrain_model_url" DeviceTarget = "device_target" Ascend = "Ascend" PerPage = 10 diff --git a/routers/repo/cloudbrain.go b/routers/repo/cloudbrain.go index c2804d319..7264a50aa 100755 --- a/routers/repo/cloudbrain.go +++ b/routers/repo/cloudbrain.go @@ -327,6 +327,17 @@ func CloudBrainCreate(ctx *context.Context, form auth.CreateCloudBrainForm) { ResultPath: storage.GetMinioPath(jobName, cloudbrain.ResultPath+"/"), Spec: spec, } + + if form.ModelName != "" { //使用预训练模型训练 + req.ModelName = form.ModelName + req.LabelName = form.LabelName + req.CkptName = form.CkptName + req.ModelVersion = form.ModelVersion + req.PreTrainModelPath = setting.Attachment.Minio.RealPath + form.PreTrainingModelUrl + req.PreTrainingModelUrl = form.PreTrainingModelUrl + + } + err = cloudbrain.GenerateTask(req) if err != nil { @@ -2626,6 +2637,9 @@ func getTrainJobCommand(form auth.CreateCloudBrainForm) (string, error) { param += " --" + parameter.Label + "=" + parameter.Value } } + if form.CkptName != "" { + param += " --pretrainmodelname" + "=" + form.CkptName + } command += "python /code/" + bootFile + param + " | tee " + cloudbrain.ModelMountPath + "/" + form.DisplayJobName + "-" + cloudbrain.LogFile