diff --git a/agent/internal/http/hub_io.go b/agent/internal/http/hub_io.go index 8772bbf..bae2953 100644 --- a/agent/internal/http/hub_io.go +++ b/agent/internal/http/hub_io.go @@ -4,17 +4,18 @@ import ( "bytes" "context" "fmt" + "io" + "io/ioutil" + "net/http" + "time" + "github.com/gin-gonic/gin" "github.com/inhies/go-bytesize" "gitlink.org.cn/cloudream/common/consts/errorcode" "gitlink.org.cn/cloudream/common/pkgs/ioswitch/exec" "gitlink.org.cn/cloudream/common/pkgs/logger" - cdssdk "gitlink.org.cn/cloudream/common/sdks/storage" + "gitlink.org.cn/cloudream/common/sdks/storage/cdsapi" "gitlink.org.cn/cloudream/common/utils/serder" - "io" - "io/ioutil" - "net/http" - "time" ) type IOService struct { @@ -28,7 +29,7 @@ func (s *Server) IOSvc() *IOService { } func (s *IOService) GetStream(ctx *gin.Context) { - var req cdssdk.GetStreamReq + var req cdsapi.GetStreamReq if err := ctx.ShouldBindJSON(&req); err != nil { logger.Warnf("binding body: %s", err.Error()) ctx.JSON(http.StatusBadRequest, Failed(errorcode.BadArgument, "missing argument or invalid argument")) @@ -126,7 +127,7 @@ func (s *IOService) SendStream(ctx *gin.Context) { //planID := ctx.PostForm("plan_id") //varID := ctx.PostForm("var_id") - var req cdssdk.SendStreamReq + var req cdsapi.SendStreamReq if err := ctx.ShouldBindJSON(&req); err != nil { logger.Warnf("binding body: %s", err.Error()) ctx.JSON(http.StatusBadRequest, Failed(errorcode.BadArgument, "missing argument or invalid argument")) @@ -201,9 +202,9 @@ func (s *IOService) ExecuteIOPlan(ctx *gin.Context) { return } println("Received body: %s", string(bodyBytes)) - ctx.Request.Body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes)) // Reset body for subsequent reads + ctx.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // Reset body for subsequent reads - var req cdssdk.ExecuteIOPlanReq + var req cdsapi.ExecuteIOPlanReq if err := ctx.ShouldBindJSON(&req); err != nil { logger.Warnf("binding body: %s", err.Error()) ctx.JSON(http.StatusBadRequest, Failed(errorcode.BadArgument, "missing argument or invalid argument")) @@ -240,7 +241,7 @@ func (s *IOService) ExecuteIOPlan(ctx *gin.Context) { } func (s *IOService) SendVar(ctx *gin.Context) { - var req cdssdk.SendVarReq + var req cdsapi.SendVarReq if err := ctx.ShouldBindJSON(&req); err != nil { logger.Warnf("binding body: %s", err.Error()) ctx.JSON(http.StatusBadRequest, Failed(errorcode.BadArgument, "missing argument or invalid argument")) @@ -268,7 +269,7 @@ func (s *IOService) SendVar(ctx *gin.Context) { } func (s *IOService) GetVar(ctx *gin.Context) { - var req cdssdk.GetVarReq + var req cdsapi.GetVarReq if err := ctx.ShouldBindJSON(&req); err != nil { logger.Warnf("binding body: %s", err.Error()) ctx.JSON(http.StatusBadRequest, Failed(errorcode.BadArgument, "missing argument or invalid argument")) diff --git a/agent/internal/http/server.go b/agent/internal/http/server.go index 3224b88..70550ee 100644 --- a/agent/internal/http/server.go +++ b/agent/internal/http/server.go @@ -3,7 +3,7 @@ package http import ( "github.com/gin-gonic/gin" "gitlink.org.cn/cloudream/common/pkgs/logger" - cdssdk "gitlink.org.cn/cloudream/common/sdks/storage" + "gitlink.org.cn/cloudream/common/sdks/storage/cdsapi" ) type Server struct { @@ -38,9 +38,9 @@ func (s *Server) Serve() error { } func (s *Server) initRouters() { - s.engine.GET(cdssdk.GetStreamPath, s.IOSvc().GetStream) - s.engine.POST(cdssdk.SendStreamPath, s.IOSvc().SendStream) - s.engine.POST(cdssdk.ExecuteIOPlanPath, s.IOSvc().ExecuteIOPlan) - s.engine.POST(cdssdk.SendVarPath, s.IOSvc().SendVar) - s.engine.GET(cdssdk.GetVarPath, s.IOSvc().GetVar) + s.engine.GET(cdsapi.GetStreamPath, s.IOSvc().GetStream) + s.engine.POST(cdsapi.SendStreamPath, s.IOSvc().SendStream) + s.engine.POST(cdsapi.ExecuteIOPlanPath, s.IOSvc().ExecuteIOPlan) + s.engine.POST(cdsapi.SendVarPath, s.IOSvc().SendVar) + s.engine.GET(cdsapi.GetVarPath, s.IOSvc().GetVar) } diff --git a/common/pkgs/db2/db2.go b/common/pkgs/db2/db2.go index a0447b5..af825c8 100644 --- a/common/pkgs/db2/db2.go +++ b/common/pkgs/db2/db2.go @@ -22,3 +22,17 @@ func NewDB(cfg *config.Config) (*DB, error) { db: mydb, }, nil } + +func (s *DB) DoTx(do func(tx SQLContext) error) error { + return s.db.Transaction(func(tx *gorm.DB) error { + return do(SQLContext{tx}) + }) +} + +type SQLContext struct { + *gorm.DB +} + +func (d *DB) DefCtx() SQLContext { + return SQLContext{d.db} +} diff --git a/common/pkgs/db2/node.go b/common/pkgs/db2/node.go index baeb7d3..1af106c 100644 --- a/common/pkgs/db2/node.go +++ b/common/pkgs/db2/node.go @@ -1,8 +1,9 @@ package db2 import ( - cdssdk "gitlink.org.cn/cloudream/common/sdks/storage" "time" + + cdssdk "gitlink.org.cn/cloudream/common/sdks/storage" ) type NodeDB struct { @@ -13,24 +14,24 @@ func (nodeDB *DB) Node() *NodeDB { return &NodeDB{DB: nodeDB} } -func (nodeDB *NodeDB) GetAllNodes() ([]cdssdk.Node, error) { +func (*NodeDB) GetAllNodes(ctx SQLContext) ([]cdssdk.Node, error) { var ret []cdssdk.Node - err := nodeDB.DB.db.Table("node").Find(&ret).Error + err := ctx.Table("node").Find(&ret).Error return ret, err } -func (nodeDB *NodeDB) GetByID(nodeID cdssdk.NodeID) (cdssdk.Node, error) { +func (*NodeDB) GetByID(ctx SQLContext, nodeID cdssdk.NodeID) (cdssdk.Node, error) { var ret cdssdk.Node - err := nodeDB.DB.db.Table("node").Where("NodeID = ?", nodeID).Find(&ret).Error + err := ctx.Table("node").Where("NodeID = ?", nodeID).Find(&ret).Error return ret, err } // GetUserNodes 根据用户id查询可用node -func (nodeDB *NodeDB) GetUserNodes(userID cdssdk.UserID) ([]cdssdk.Node, error) { +func (*NodeDB) GetUserNodes(ctx SQLContext, userID cdssdk.UserID) ([]cdssdk.Node, error) { var nodes []cdssdk.Node - err := nodeDB.DB.db. + err := ctx. Table("Node"). Select("Node.*"). Joins("JOIN UserNode ON UserNode.NodeID = Node.NodeID"). @@ -40,8 +41,8 @@ func (nodeDB *NodeDB) GetUserNodes(userID cdssdk.UserID) ([]cdssdk.Node, error) } // UpdateState 更新状态,并且设置上次上报时间为现在 -func (nodeDB *NodeDB) UpdateState(nodeID cdssdk.NodeID, state string) error { - err := nodeDB.DB.db. +func (*NodeDB) UpdateState(ctx SQLContext, nodeID cdssdk.NodeID, state string) error { + err := ctx. Model(&cdssdk.Node{}). Where("NodeID = ?", nodeID). Updates(map[string]interface{}{ diff --git a/common/pkgs/db2/shard_storage.go b/common/pkgs/db2/shard_storage.go new file mode 100644 index 0000000..e52e59d --- /dev/null +++ b/common/pkgs/db2/shard_storage.go @@ -0,0 +1,19 @@ +package db2 + +import ( + cdssdk "gitlink.org.cn/cloudream/common/sdks/storage" +) + +type ShardStorageDB struct { + *DB +} + +func (db *DB) ShardStorage() *ShardStorageDB { + return &ShardStorageDB{DB: db} +} + +func (*ShardStorageDB) GetByStorageID(ctx SQLContext, stgID cdssdk.StorageID) (cdssdk.ShardStorage, error) { + var ret cdssdk.ShardStorage + err := ctx.Table("ShardStorage").First(&ret, stgID).Error + return ret, err +} diff --git a/common/pkgs/db2/shared_storage.go b/common/pkgs/db2/shared_storage.go new file mode 100644 index 0000000..d2a61e9 --- /dev/null +++ b/common/pkgs/db2/shared_storage.go @@ -0,0 +1,19 @@ +package db2 + +import ( + cdssdk "gitlink.org.cn/cloudream/common/sdks/storage" +) + +type SharedStorageDB struct { + *DB +} + +func (db *DB) SharedStorage() *SharedStorageDB { + return &SharedStorageDB{DB: db} +} + +func (*SharedStorageDB) GetByStorageID(ctx SQLContext, stgID cdssdk.StorageID) (cdssdk.SharedStorage, error) { + var ret cdssdk.SharedStorage + err := ctx.Table("SharedStorage").First(&ret, stgID).Error + return ret, err +} diff --git a/common/pkgs/db2/storage.go b/common/pkgs/db2/storage.go new file mode 100644 index 0000000..56f1786 --- /dev/null +++ b/common/pkgs/db2/storage.go @@ -0,0 +1,62 @@ +package db2 + +import ( + "fmt" + + cdssdk "gitlink.org.cn/cloudream/common/sdks/storage" + "gitlink.org.cn/cloudream/storage/common/pkgs/db/model" +) + +type StorageDB struct { + *DB +} + +func (db *DB) Storage() *StorageDB { + return &StorageDB{DB: db} +} + +func (db *StorageDB) GetByID(ctx SQLContext, stgID cdssdk.StorageID) (model.Storage, error) { + var stg model.Storage + err := ctx.Table("Storage").First(&stg, stgID).Error + return stg, err +} + +func (db *StorageDB) BatchGetAllStorageIDs(ctx SQLContext, start int, count int) ([]cdssdk.StorageID, error) { + var ret []cdssdk.StorageID + err := ctx.Table("Storage").Select("StorageID").Find(ret).Limit(count).Offset(start).Error + return ret, err +} + +func (db *StorageDB) IsAvailable(ctx SQLContext, userID cdssdk.UserID, storageID cdssdk.StorageID) (bool, error) { + rows, err := ctx.Table("Storage").Select("Storage.StorageID"). + Joins("inner join UserStorage on Storage.StorageID = UserStorage.StorageID"). + Where("UserID = ? and StorageID = ?", userID, storageID).Rows() + if err != nil { + return false, fmt.Errorf("execute sql: %w", err) + } + defer rows.Close() + + return rows.Next(), nil +} + +func (db *StorageDB) GetUserStorage(ctx SQLContext, userID cdssdk.UserID, storageID cdssdk.StorageID) (model.Storage, error) { + var stg model.Storage + err := ctx.Table("Storage").Select("Storage.*"). + Joins("inner join UserStorage on Storage.StorageID = UserStorage.StorageID"). + Where("UserID = ? and StorageID = ?", userID, storageID).First(&stg).Error + + return stg, err +} + +func (db *StorageDB) GetUserStorageByName(ctx SQLContext, userID cdssdk.UserID, name string) (model.Storage, error) { + var stg model.Storage + err := ctx.Table("Storage").Select("Storage.*"). + Joins("inner join UserStorage on Storage.StorageID = UserStorage.StorageID"). + Where("UserID = ? and Name = ?", userID, name).First(&stg).Error + + return stg, err +} + +// func (db *StorageDB) ChangeState(ctx SQLContext, storageID cdssdk.StorageID, state string) error { +// return ctx.Table("Storage").Where("StorageID = ?", storageID).Update("State", state).Error +// } diff --git a/coordinator/internal/cmd/migrate.go b/coordinator/internal/cmd/migrate.go index 53a38ca..b47708a 100644 --- a/coordinator/internal/cmd/migrate.go +++ b/coordinator/internal/cmd/migrate.go @@ -39,6 +39,12 @@ func migrate(configPath string) { os.Exit(1) } + err = db.AutoMigrate(&cdssdk.Node{}) + if err != nil { + fmt.Printf("migratting model Node: %v\n", err) + os.Exit(1) + } + err = db.AutoMigrate(&cdssdk.Storage{}) if err != nil { fmt.Printf("migratting model Storage: %v\n", err)