Browse Source

add inter

fix-2419
lewis 3 years ago
parent
commit
561e37b239
4 changed files with 47 additions and 71 deletions
  1. +43
    -2
      modules/modelarts/modelarts.go
  2. +0
    -65
      routers/api/v1/repo/modelarts.go
  3. +2
    -2
      routers/repo/grampus.go
  4. +2
    -2
      routers/repo/modelarts.go

+ 43
- 2
modules/modelarts/modelarts.go View File

@@ -62,7 +62,7 @@ const (
PerPage = 10
IsLatestVersion = "1"
NotLatestVersion = "0"
VersionCount = 1
VersionCountOne = 1

SortByCreateTime = "create_time"
ConfigTypeCustom = "custom"
@@ -470,7 +470,7 @@ func GenerateTrainJobVersion(ctx *context.Context, req *GenerateTrainJobReq, job
Status: string(models.ModelArtsTrainJobWaiting),
UserID: ctx.User.ID,
RepoID: ctx.Repo.Repository.ID,
JobID: models.TempJobIdPrefix + req.JobName + strconv.Itoa(int(rand.New(rand.NewSource(time.Now().UnixNano())).Int31n(100000))),
JobID: models.TempJobIdPrefix + jobId,
JobName: req.JobName,
DisplayJobName: req.DisplayJobName,
JobType: string(models.JobTypeTrain),
@@ -767,3 +767,44 @@ func GetNotebookImageName(imageId string) (string, error) {

return imageName, nil
}

func ProcessTrainJobInfo(task *models.Cloudbrain) error {
if strings.HasPrefix(task.JobID, models.TempJobIdPrefix) {
if task.VersionCount > VersionCountOne {
//multi version
} else {
//inference or one version

}
} else {
//normal
}
result, err := GetTrainJob(task.JobID, strconv.FormatInt(task.VersionID, 10))
if err != nil {
log.Error("GetTrainJob(%s) failed:%v", task.JobName, err)
return err
}

if result != nil {
task.Status = TransTrainJobStatus(result.IntStatus)
task.Duration = result.Duration / 1000
task.TrainJobDuration = result.TrainJobDuration

if task.StartTime == 0 && result.StartTime > 0 {
task.StartTime = timeutil.TimeStamp(result.StartTime / 1000)
}
task.TrainJobDuration = models.ConvertDurationToStr(task.Duration)
if task.EndTime == 0 && models.IsTrainJobTerminal(task.Status) && task.StartTime > 0 {
task.EndTime = task.StartTime.Add(task.Duration)
}
task.CorrectCreateUnix()
err = models.UpdateJob(task)
if err != nil {
log.Error("UpdateJob(%s) failed:%v", task.JobName, err)
return err
}
}

//temp
return nil
}

+ 0
- 65
routers/api/v1/repo/modelarts.go View File

@@ -24,37 +24,6 @@ import (
routerRepo "code.gitea.io/gitea/routers/repo"
)

func GetModelArtsNotebook(ctx *context.APIContext) {
var (
err error
)

jobID := ctx.Params(":jobid")
repoID := ctx.Repo.Repository.ID
job, err := models.GetRepoCloudBrainByJobID(repoID, jobID)
if err != nil {
ctx.NotFound(err)
return
}
result, err := modelarts.GetJob(jobID)
if err != nil {
ctx.NotFound(err)
return
}

job.Status = result.Status
err = models.UpdateJob(job)
if err != nil {
log.Error("UpdateJob failed:", err)
}

ctx.JSON(http.StatusOK, map[string]interface{}{
"JobID": jobID,
"JobStatus": result.Status,
})

}

func GetModelArtsNotebook2(ctx *context.APIContext) {
var (
err error
@@ -93,40 +62,6 @@ func GetModelArtsNotebook2(ctx *context.APIContext) {

}

func GetModelArtsTrainJob(ctx *context.APIContext) {
var (
err error
)

jobID := ctx.Params(":jobid")
repoID := ctx.Repo.Repository.ID
job, err := models.GetRepoCloudBrainByJobID(repoID, jobID)
if err != nil {
ctx.NotFound(err)
return
}
result, err := modelarts.GetTrainJob(jobID, strconv.FormatInt(job.VersionID, 10))
if err != nil {
ctx.NotFound(err)
return
}

job.Status = modelarts.TransTrainJobStatus(result.IntStatus)
job.Duration = result.Duration
job.TrainJobDuration = result.TrainJobDuration
err = models.UpdateJob(job)
if err != nil {
log.Error("UpdateJob failed:", err)
}

ctx.JSON(http.StatusOK, map[string]interface{}{
"JobID": jobID,
"JobStatus": job.Status,
"JobDuration": job.Duration,
})

}

func GetModelArtsTrainJobVersion(ctx *context.APIContext) {
var (
err error


+ 2
- 2
routers/repo/grampus.go View File

@@ -332,7 +332,7 @@ func GrampusTrainJobGpuCreate(ctx *context.Context, form auth.CreateGrampusTrain
EngineName: image,
DatasetName: attachment.Name,
IsLatestVersion: modelarts.IsLatestVersion,
VersionCount: modelarts.VersionCount,
VersionCount: modelarts.VersionCountOne,
WorkServerNumber: 1,
}

@@ -382,7 +382,7 @@ func GrampusTrainJobNpuCreate(ctx *context.Context, form auth.CreateGrampusTrain
branchName := form.BranchName
isLatestVersion := modelarts.IsLatestVersion
flavorName := form.FlavorName
versionCount := modelarts.VersionCount
versionCount := modelarts.VersionCountOne
engineName := form.EngineName

if !jobNamePattern.MatchString(displayJobName) {


+ 2
- 2
routers/repo/modelarts.go View File

@@ -1020,7 +1020,7 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm)
branch_name := form.BranchName
isLatestVersion := modelarts.IsLatestVersion
FlavorName := form.FlavorName
VersionCount := modelarts.VersionCount
VersionCount := modelarts.VersionCountOne
EngineName := form.EngineName

count, err := models.GetCloudbrainTrainJobCountByUserID(ctx.User.ID)
@@ -1902,7 +1902,7 @@ func InferenceJobCreate(ctx *context.Context, form auth.CreateModelArtsInference
EngineName := form.EngineName
LabelName := form.LabelName
isLatestVersion := modelarts.IsLatestVersion
VersionCount := modelarts.VersionCount
VersionCount := modelarts.VersionCountOne
trainUrl := form.TrainUrl
modelName := form.ModelName
modelVersion := form.ModelVersion


Loading…
Cancel
Save