diff --git a/models/cloudbrain_temp.go b/models/cloudbrain_temp.go index e7d424006..f5d04fa97 100755 --- a/models/cloudbrain_temp.go +++ b/models/cloudbrain_temp.go @@ -1,6 +1,7 @@ package models import ( + "code.gitea.io/gitea/modules/setting" "time" "code.gitea.io/gitea/modules/timeutil" @@ -15,13 +16,13 @@ type CloudbrainTemp struct { CloudbrainID int64 `xorm:"pk"` JobName string Type int - JobType string `xorm:"INDEX NOT NULL DEFAULT 'DEBUG'"` - Status string `xorm:"INDEX NOT NULL DEFAULT 'TEMP'"` - VersionCount int `xorm:"NOT NULL DEFAULT 0"` - QueryTimes int `xorm:"INDEX NOT NULL DEFAULT 0"` - CreatedUnix timeutil.TimeStamp `xorm:"INDEX"` - UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"` - DeletedAt time.Time `xorm:"deleted"` + JobType string `xorm:"INDEX NOT NULL DEFAULT 'DEBUG'"` + Status string `xorm:"INDEX NOT NULL DEFAULT 'TEMP'"` + //VersionCount int `xorm:"NOT NULL DEFAULT 0"` + QueryTimes int `xorm:"INDEX NOT NULL DEFAULT 0"` + CreatedUnix timeutil.TimeStamp `xorm:"INDEX"` + UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"` + DeletedAt time.Time `xorm:"deleted"` } func InsertCloudbrainTemp(temp *CloudbrainTemp) (err error) { @@ -47,6 +48,14 @@ func GetCloudbrainTempByCloudbrainID(id int64) (*CloudbrainTemp, error) { return getCloudBrainTemp(temp) } +func GetCloudBrainTempJobs() ([]*CloudbrainTemp, error) { + jobs := make([]*CloudbrainTemp, 0, 10) + return jobs, x. + Where("status = ? AND query_times < ?", JobStatusTemp, setting.MaxTempQueryTimes). + Limit(100). + Find(&jobs) +} + func DeleteCloudbrainTemp(temp *CloudbrainTemp) error { return deleteCloudbrainTemp(x, temp) } @@ -55,3 +64,14 @@ func deleteCloudbrainTemp(e Engine, temp *CloudbrainTemp) error { _, err := e.Where("cloudbrain_id = ?", temp.CloudbrainID).Delete(temp) return err } + +func IncreaseCloudbrainTempQueryTimes(temp *CloudbrainTemp) error { + times := temp.QueryTimes + 1 + if times >= setting.MaxTempQueryTimes { + temp.Status = string(ModelArtsTrainJobFailed) + } + + _, err := x.Exec("update cloudbrain_temp set query_times=?, status = ? where cloudbrain_id=?", temp.QueryTimes+1, temp.Status, temp.CloudbrainID) + + return err +} diff --git a/modules/cron/tasks_basic.go b/modules/cron/tasks_basic.go index b3a6c02a1..080f5bd81 100755 --- a/modules/cron/tasks_basic.go +++ b/modules/cron/tasks_basic.go @@ -5,6 +5,7 @@ package cron import ( + "code.gitea.io/gitea/modules/modelarts" "context" "time" @@ -207,6 +208,17 @@ func registerSyncCloudbrainStatus() { }) } +func registerSyncModelArtsTempJobs() { + RegisterTaskFatal("sync_model_arts_temp_jobs", &BaseConfig{ + Enabled: true, + RunAtStart: false, + Schedule: "@every 1m", + }, func(ctx context.Context, _ *models.User, _ Config) error { + modelarts.SyncTempStatusJob() + return nil + }) +} + func initBasicTasks() { registerUpdateMirrorTask() registerRepoHealthCheck() @@ -227,4 +239,5 @@ func initBasicTasks() { registerSyncCloudbrainStatus() registerHandleOrgStatistic() + registerSyncModelArtsTempJobs() } diff --git a/modules/modelarts/modelarts.go b/modules/modelarts/modelarts.go index 1591e1dd5..cadd83779 100755 --- a/modules/modelarts/modelarts.go +++ b/modules/modelarts/modelarts.go @@ -849,7 +849,7 @@ func HandleTrainJobInfo(task *models.Cloudbrain) error { if isTempJob(task.JobID, task.Status) { if task.VersionCount > VersionCountOne { //multi version - result, err := GetTrainJobVersionList(1000, 1, strings.TrimPrefix(task.JobID, models.TempJobIdPrefix)) + result, err := GetTrainJobVersionList(1000, 1, task.JobID) if err != nil { log.Error("GetTrainJobVersionList failed:%v", err) return err @@ -1043,3 +1043,149 @@ func isTempJob(jobID, status string) bool { } return false } + +func SyncTempStatusJob() { + jobs, err := models.GetCloudBrainTempJobs() + if err != nil { + log.Error("GetCloudBrainTempJobs failed:%v", err.Error()) + return + } + + for _, temp := range jobs { + task, err := models.GetCloudbrainByID(strconv.FormatInt(temp.CloudbrainID, 10)) + if err != nil { + log.Error("GetCloudbrainByID failed:%v", err) + continue + } + + if temp.Type == models.TypeCloudBrainTwo { + if temp.JobType == string(models.JobTypeDebug) { + result, err := GetNotebookList(1000, 0, "createTime", "DESC", temp.JobName) + if err != nil { + log.Error("GetNotebookList failed:%v", err) + continue + } + + err = models.IncreaseCloudbrainTempQueryTimes(temp) + if err != nil { + log.Error("IncreaseCloudbrainTempQueryTimes failed:%v", err) + } + + if result != nil { + count, err := models.GetCloudbrainCountByJobName(temp.JobName, temp.JobType) + if err != nil { + log.Error("GetCloudbrainCountByJobName failed:%v", err) + continue + } + + if len(result.NotebookList) == count { + if result.NotebookList[0].JobName == temp.JobName { + log.Info("find the record(%s)", temp.JobName) + task.Status = result.NotebookList[0].Status + task.JobID = result.NotebookList[0].JobID + + err = models.UpdateJob(task) + if err != nil { + log.Error("UpdateJob(%s) failed:%v", task.JobName, err) + continue + } + + err = models.DeleteCloudbrainTemp(temp) + if err != nil { + log.Error("DeleteCloudbrainTemp(%s) failed:%v", task.DisplayJobName, err) + continue + } + } else { + log.Error("can not find the record(%s) until now", temp.JobName) + } + } else { + log.Error("can not find the record(%s) until now", temp.JobName) + } + } else { + log.Error("can not find the record(%s) until now", temp.JobName) + } + } else if temp.JobType == string(models.JobTypeTrain) || temp.JobType == string(models.JobTypeInference) { + if task.VersionCount > VersionCountOne { + //multi version + result, err := GetTrainJobVersionList(1000, 1, task.JobID) + if err != nil { + log.Error("GetTrainJobVersionList failed:%v", err) + continue + } + + err = models.IncreaseCloudbrainTempQueryTimes(temp) + if err != nil { + log.Error("IncreaseCloudbrainTempQueryTimes failed:%v", err) + } + + if result != nil { + if strconv.FormatInt(result.JobID, 10) == task.JobID && result.JobName == task.JobName { + if result.VersionCount == int64(task.VersionCount) { + log.Info("find the record(%s)", task.DisplayJobName) + task.Status = TransTrainJobStatus(result.JobVersionList[0].IntStatus) + task.VersionName = result.JobVersionList[0].VersionName + task.VersionID = result.JobVersionList[0].VersionID + + err = models.UpdateJob(task) + if err != nil { + log.Error("UpdateJob(%s) failed:%v", task.JobName, err) + continue + } + + err = models.DeleteCloudbrainTemp(temp) + if err != nil { + log.Error("DeleteCloudbrainTemp(%s) failed:%v", task.DisplayJobName, err) + continue + } + } else { + log.Error("can not find the record(%s) until now", task.DisplayJobName) + } + } else { + log.Error("can not find the record(%s) until now", task.DisplayJobName) + } + } + } else { + //inference or one version + result, err := GetTrainJobList(1000, 1, "create_time", "desc", task.JobName) + if err != nil { + log.Error("GetTrainJobList failed:%v", err) + continue + } + + err = models.IncreaseCloudbrainTempQueryTimes(temp) + if err != nil { + log.Error("IncreaseCloudbrainTempQueryTimes failed:%v", err) + } + + if result != nil { + for _, job := range result.JobList { + if task.JobName == job.JobName { + log.Info("find the record(%s)", task.DisplayJobName) + task.Status = TransTrainJobStatus(job.IntStatus) + task.JobID = strconv.FormatInt(job.JobID, 10) + + err = models.UpdateJob(task) + if err != nil { + log.Error("UpdateJob(%s) failed:%v", task.DisplayJobName, err) + continue + } + + err = models.DeleteCloudbrainTemp(temp) + if err != nil { + log.Error("DeleteCloudbrainTemp(%s) failed:%v", task.DisplayJobName, err) + continue + } + } + } + } + + } + } + } else { + log.Error("invalid job_type(%d)", temp.Type) + continue + } + } + + return +} diff --git a/modules/setting/setting.go b/modules/setting/setting.go index f63088091..ff6a02e17 100755 --- a/modules/setting/setting.go +++ b/modules/setting/setting.go @@ -539,6 +539,7 @@ var ( DebugHost string ImageInfos string Capacity int + MaxTempQueryTimes int //train-job ResourcePools string Engines string @@ -1417,6 +1418,7 @@ func NewContext() { Flavor = sec.Key("FLAVOR").MustString("") ImageInfos = sec.Key("IMAGE_INFOS").MustString("") Capacity = sec.Key("IMAGE_INFOS").MustInt(100) + MaxTempQueryTimes = sec.Key("MAX_TEMP_QUERY_TIMES").MustInt(10) ResourcePools = sec.Key("Resource_Pools").MustString("") Engines = sec.Key("Engines").MustString("") EngineVersions = sec.Key("Engine_Versions").MustString("")