diff --git a/modules/modelarts/modelarts.go b/modules/modelarts/modelarts.go index 9e8447978..97791e25a 100755 --- a/modules/modelarts/modelarts.go +++ b/modules/modelarts/modelarts.go @@ -71,7 +71,8 @@ var ( FlavorInfos *models.FlavorInfos ImageInfos *models.ImageInfosModelArts TrainFlavorInfos *Flavor - SpecialPools *models.SpecialPools + SpecialPools *models.SpecialPools + MultiNodeConfig *MultiNodes ) type GenerateTrainJobReq struct { @@ -166,6 +167,14 @@ type ResourcePool struct { } `json:"resource_pool"` } +type MultiNodes struct{ + Info []OrgMultiNode `json:"multinode"` +} +type OrgMultiNode struct{ + Org string `json:"org"` + Node []int `json:"node"` +} + // type Parameter struct { // Label string `json:"label"` // Value string `json:"value"` @@ -773,6 +782,13 @@ func InitSpecialPool() { } } +func InitMultiNode(){ + if MultiNodeConfig ==nil && setting.ModelArtsMultiNode!=""{ + json.Unmarshal([]byte(setting.ModelArtsMultiNode), &MultiNodeConfig) + } + +} + func HandleTrainJobInfo(task *models.Cloudbrain) error { result, err := GetTrainJob(task.JobID, strconv.FormatInt(task.VersionID, 10)) diff --git a/modules/setting/setting.go b/modules/setting/setting.go index 1e96ff9da..3b8a1d8cf 100755 --- a/modules/setting/setting.go +++ b/modules/setting/setting.go @@ -547,6 +547,7 @@ var ( FlavorInfos string TrainJobFLAVORINFOS string ModelArtsSpecialPools string + ModelArtsMultiNode string //grampus config Grampus = struct { @@ -1432,6 +1433,7 @@ func NewContext() { FlavorInfos = sec.Key("FLAVOR_INFOS").MustString("") TrainJobFLAVORINFOS = sec.Key("TrainJob_FLAVOR_INFOS").MustString("") ModelArtsSpecialPools = sec.Key("SPECIAL_POOL").MustString("") + ModelArtsMultiNode=sec.Key("MULTI_NODE").MustString("") sec = Cfg.Section("elk") ElkUrl = sec.Key("ELKURL").MustString("") diff --git a/routers/repo/modelarts.go b/routers/repo/modelarts.go index 847e831f6..10843e683 100755 --- a/routers/repo/modelarts.go +++ b/routers/repo/modelarts.go @@ -763,9 +763,23 @@ func trainJobNewDataPrepare(ctx *context.Context) error { waitCount := cloudbrain.GetWaitingCloudbrainCount(models.TypeCloudBrainTwo, "") ctx.Data["WaitCount"] = waitCount + setMultiNodeIfConfigureMatch(ctx) + return nil } +func setMultiNodeIfConfigureMatch(ctx *context.Context) { + modelarts.InitMultiNode() + if modelarts.MultiNodeConfig != nil { + for _, info := range modelarts.MultiNodeConfig.Info { + if isInOrg, _ := models.IsOrganizationMemberByOrgName(info.Org, ctx.User.ID); isInOrg { + ctx.Data["WorkNode"] = info.Node + break + } + } + } +} + func setSpecBySpecialPoolConfig(ctx *context.Context, jobType string) { modelarts.InitSpecialPool() @@ -880,6 +894,7 @@ func trainJobErrorNewDataPrepare(ctx *context.Context, form auth.CreateModelArts ctx.Data["datasetType"] = models.TypeCloudBrainTwo waitCount := cloudbrain.GetWaitingCloudbrainCount(models.TypeCloudBrainTwo, "") ctx.Data["WaitCount"] = waitCount + setMultiNodeIfConfigureMatch(ctx) return nil } diff --git a/templates/repo/modelarts/trainjob/new.tmpl b/templates/repo/modelarts/trainjob/new.tmpl index 7818938d3..2b6ea923b 100755 --- a/templates/repo/modelarts/trainjob/new.tmpl +++ b/templates/repo/modelarts/trainjob/new.tmpl @@ -287,8 +287,15 @@ id="trainjob_work_server_num" tabindex="3" autofocus required maxlength="255" value="1" readonly>