Browse Source

提交代码

tags/v1.22.9.2^2
ychao_1983 3 years ago
parent
commit
9df8fa9f24
2 changed files with 87 additions and 61 deletions
  1. +71
    -61
      modules/modelarts/modelarts.go
  2. +16
    -0
      routers/repo/modelarts.go

+ 71
- 61
modules/modelarts/modelarts.go View File

@@ -75,35 +75,40 @@ var (
)

type GenerateTrainJobReq struct {
JobName string
DisplayJobName string
Uuid string
Description string
CodeObsPath string
BootFile string
BootFileUrl string
DataUrl string
TrainUrl string
LogUrl string
PoolID string
WorkServerNumber int
EngineID int64
Parameters []models.Parameter
CommitID string
IsLatestVersion string
Params string
BranchName string
PreVersionId int64
PreVersionName string
FlavorCode string
FlavorName string
VersionCount int
EngineName string
TotalVersionCount int
UserImageUrl string
UserCommand string
DatasetName string
Spec *models.Specification
JobName string
DisplayJobName string
Uuid string
Description string
CodeObsPath string
BootFile string
BootFileUrl string
DataUrl string
TrainUrl string
LogUrl string
PoolID string
WorkServerNumber int
EngineID int64
Parameters []models.Parameter
CommitID string
IsLatestVersion string
Params string
BranchName string
PreVersionId int64
PreVersionName string
FlavorCode string
FlavorName string
VersionCount int
EngineName string
TotalVersionCount int
UserImageUrl string
UserCommand string
DatasetName string
Spec *models.Specification
ModelName string
LabelName string
CkptName string
ModelVersion string
PreTrainingModelUrl string
}

type GenerateInferenceJobReq struct {
@@ -407,38 +412,43 @@ func GenerateTrainJob(ctx *context.Context, req *GenerateTrainJobReq) (err error
}
jobId := strconv.FormatInt(jobResult.JobID, 10)
createErr = models.CreateCloudbrain(&models.Cloudbrain{
Status: TransTrainJobStatus(jobResult.Status),
UserID: ctx.User.ID,
RepoID: ctx.Repo.Repository.ID,
JobID: jobId,
JobName: req.JobName,
DisplayJobName: req.DisplayJobName,
JobType: string(models.JobTypeTrain),
Type: models.TypeCloudBrainTwo,
VersionID: jobResult.VersionID,
VersionName: jobResult.VersionName,
Uuid: req.Uuid,
DatasetName: req.DatasetName,
CommitID: req.CommitID,
IsLatestVersion: req.IsLatestVersion,
ComputeResource: models.NPUResource,
EngineID: req.EngineID,
TrainUrl: req.TrainUrl,
BranchName: req.BranchName,
Parameters: req.Params,
BootFile: req.BootFile,
DataUrl: req.DataUrl,
LogUrl: req.LogUrl,
FlavorCode: req.Spec.SourceSpecId,
Description: req.Description,
WorkServerNumber: req.WorkServerNumber,
FlavorName: req.FlavorName,
EngineName: req.EngineName,
VersionCount: req.VersionCount,
TotalVersionCount: req.TotalVersionCount,
CreatedUnix: createTime,
UpdatedUnix: createTime,
Spec: req.Spec,
Status: TransTrainJobStatus(jobResult.Status),
UserID: ctx.User.ID,
RepoID: ctx.Repo.Repository.ID,
JobID: jobId,
JobName: req.JobName,
DisplayJobName: req.DisplayJobName,
JobType: string(models.JobTypeTrain),
Type: models.TypeCloudBrainTwo,
VersionID: jobResult.VersionID,
VersionName: jobResult.VersionName,
Uuid: req.Uuid,
DatasetName: req.DatasetName,
CommitID: req.CommitID,
IsLatestVersion: req.IsLatestVersion,
ComputeResource: models.NPUResource,
EngineID: req.EngineID,
TrainUrl: req.TrainUrl,
BranchName: req.BranchName,
Parameters: req.Params,
BootFile: req.BootFile,
DataUrl: req.DataUrl,
LogUrl: req.LogUrl,
FlavorCode: req.Spec.SourceSpecId,
Description: req.Description,
WorkServerNumber: req.WorkServerNumber,
FlavorName: req.FlavorName,
EngineName: req.EngineName,
VersionCount: req.VersionCount,
TotalVersionCount: req.TotalVersionCount,
CreatedUnix: createTime,
UpdatedUnix: createTime,
Spec: req.Spec,
ModelName: req.ModelName,
ModelVersion: req.ModelVersion,
LabelName: req.LabelName,
PreTrainingModelUrl: req.PreTrainingModelUrl,
CkptName: req.CkptName,
})

if createErr != nil {


+ 16
- 0
routers/repo/modelarts.go View File

@@ -1290,6 +1290,13 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm)
Value: string(jsondatas),
})
}
if form.ModelName != "" { //使用预训练模型训练
ckptUrl := "/" + form.PreTrainModelUrl + form.CkptName
param = append(param, models.Parameter{
Label: modelarts.CkptUrl,
Value: "s3:/" + ckptUrl,
})
}

//save param config
// if isSaveParam == "on" {
@@ -1358,6 +1365,15 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm)
DatasetName: datasetNames,
Spec: spec,
}
if form.ModelName != "" { //使用预训练模型训练
req.ModelName = form.ModelName
req.LabelName = form.LabelName
req.CkptName = form.CkptName
req.ModelVersion = form.ModelVersion
req.PreTrainingModelUrl = form.PreTrainModelUrl

}

userCommand, userImageUrl := getUserCommand(engineID, req)
req.UserCommand = userCommand
req.UserImageUrl = userImageUrl


Loading…
Cancel
Save