Browse Source

从sqlx迁移到gorm

gitlink
JeshuaRen 1 year ago
parent
commit
22c953d03b
22 changed files with 1413 additions and 98 deletions
  1. +137
    -0
      common/pkgs/db2/bucket.go
  2. +114
    -0
      common/pkgs/db2/cache.go
  3. +4
    -4
      common/pkgs/db2/db2.go
  4. +35
    -0
      common/pkgs/db2/location.go
  5. +2
    -2
      common/pkgs/db2/node.go
  6. +37
    -0
      common/pkgs/db2/node_connectivity.go
  7. +346
    -0
      common/pkgs/db2/object.go
  8. +107
    -0
      common/pkgs/db2/object_access_stat.go
  9. +107
    -0
      common/pkgs/db2/object_block.go
  10. +186
    -0
      common/pkgs/db2/package.go
  11. +66
    -0
      common/pkgs/db2/package_access_stat.go
  12. +83
    -0
      common/pkgs/db2/storage_package.go
  13. +20
    -0
      common/pkgs/db2/user.go
  14. +22
    -0
      common/pkgs/db2/user_bucket.go
  15. +57
    -0
      common/pkgs/db2/utils.go
  16. +3
    -1
      common/pkgs/storage/temp/local.go
  17. +11
    -12
      coordinator/internal/mq/bucket.go
  18. +9
    -10
      coordinator/internal/mq/cache.go
  19. +6
    -7
      coordinator/internal/mq/node.go
  20. +30
    -30
      coordinator/internal/mq/object.go
  21. +24
    -24
      coordinator/internal/mq/package.go
  22. +7
    -8
      coordinator/internal/mq/storage.go

+ 137
- 0
common/pkgs/db2/bucket.go View File

@@ -0,0 +1,137 @@
package db2

import (
"errors"
"fmt"
"gorm.io/gorm"

cdssdk "gitlink.org.cn/cloudream/common/sdks/storage"
"gitlink.org.cn/cloudream/storage/common/pkgs/db/model"
)

type BucketDB struct {
*DB
}

func (db *DB) Bucket() *BucketDB {
return &BucketDB{DB: db}
}

func (db *BucketDB) GetByID(ctx SQLContext, bucketID cdssdk.BucketID) (cdssdk.Bucket, error) {
var ret cdssdk.Bucket
err := ctx.Table("Bucket").Where("BucketID = ?", bucketID).First(&ret).Error
return ret, err
}

// GetIDByName 根据BucketName查询BucketID
func (db *BucketDB) GetIDByName(ctx SQLContext, bucketName string) (int64, error) {
var result struct {
BucketID int64 `gorm:"column:BucketID"`
BucketName string `gorm:"column:BucketName"`
}

err := ctx.Table("Bucket").Select("BucketID, BucketName").Where("BucketName = ?", bucketName).Scan(&result).Error
if err != nil {
return 0, err
}

return result.BucketID, nil
}

// IsAvailable 判断用户是否有指定Bucekt的权限
func (db *BucketDB) IsAvailable(ctx SQLContext, bucketID cdssdk.BucketID, userID cdssdk.UserID) (bool, error) {
_, err := db.GetUserBucket(ctx, userID, bucketID)
if errors.Is(err, gorm.ErrRecordNotFound) {
return false, nil
}

if err != nil {
return false, fmt.Errorf("find bucket failed, err: %w", err)
}

return true, nil
}

func (*BucketDB) GetUserBucket(ctx SQLContext, userID cdssdk.UserID, bucketID cdssdk.BucketID) (model.Bucket, error) {
var ret model.Bucket
err := ctx.Table("UserBucket").
Select("Bucket.*").
Joins("JOIN Bucket ON UserBucket.BucketID = Bucket.BucketID").
Where("UserBucket.UserID = ? AND Bucket.BucketID = ?", userID, bucketID).
First(&ret).Error
return ret, err
}

func (*BucketDB) GetUserBucketByName(ctx SQLContext, userID cdssdk.UserID, bucketName string) (model.Bucket, error) {
var ret model.Bucket
err := ctx.Table("UserBucket").
Select("Bucket.*").
Joins("JOIN Bucket ON UserBucket.BucketID = Bucket.BucketID").
Where("UserBucket.UserID = ? AND Bucket.Name = ?", userID, bucketName).
First(&ret).Error
return ret, err
}

func (*BucketDB) GetUserBuckets(ctx SQLContext, userID cdssdk.UserID) ([]model.Bucket, error) {
var ret []model.Bucket
err := ctx.Table("UserBucket").
Select("Bucket.*").
Joins("JOIN Bucket ON UserBucket.BucketID = Bucket.BucketID").
Where("UserBucket.UserID = ?", userID).
Find(&ret).Error
return ret, err
}

func (db *BucketDB) Create(ctx SQLContext, userID cdssdk.UserID, bucketName string) (cdssdk.BucketID, error) {
var bucketID int64
err := ctx.Table("UserBucket").
Select("Bucket.BucketID").
Joins("JOIN Bucket ON UserBucket.BucketID = Bucket.BucketID").
Where("UserBucket.UserID = ? AND Bucket.Name = ?", userID, bucketName).
Scan(&bucketID).Error

if err == nil {
return 0, fmt.Errorf("bucket name exists")
}

if !errors.Is(err, gorm.ErrRecordNotFound) {
return 0, err
}

newBucket := cdssdk.Bucket{Name: bucketName, CreatorID: userID}
if err := ctx.Create(&newBucket).Error; err != nil {
return 0, fmt.Errorf("insert bucket failed, err: %w", err)
}

err = ctx.Exec("insert into UserBucket(UserID,BucketID) values(?,?)", userID, bucketID).Error
if err := ctx.Create(&newBucket).Error; err != nil {
return 0, fmt.Errorf("insert bucket failed, err: %w", err)
}

return newBucket.BucketID, nil
}

func (db *BucketDB) Delete(ctx SQLContext, bucketID cdssdk.BucketID) error {
if err := ctx.Exec("DELETE FROM UserBucket WHERE BucketID = ?", bucketID).Error; err != nil {
return fmt.Errorf("delete user bucket failed, err: %w", err)
}

if err := ctx.Exec("DELETE FROM Bucket WHERE BucketID = ?", bucketID).Error; err != nil {
return fmt.Errorf("delete bucket failed, err: %w", err)
}

var pkgIDs []cdssdk.PackageID
if err := ctx.Table("Package").Select("PackageID").Where("BucketID = ?", bucketID).Find(&pkgIDs).Error; err != nil {
return fmt.Errorf("query package failed, err: %w", err)
}

for _, pkgID := range pkgIDs {
if err := db.Package().SoftDelete(ctx, pkgID); err != nil {
return fmt.Errorf("set package selected failed, err: %w", err)
}

// 失败也没关系,会有定时任务再次尝试
db.Package().DeleteUnused(ctx, pkgID)
}
return nil
}

+ 114
- 0
common/pkgs/db2/cache.go View File

@@ -0,0 +1,114 @@
package db2

import (
"time"

cdssdk "gitlink.org.cn/cloudream/common/sdks/storage"
"gitlink.org.cn/cloudream/storage/common/pkgs/db/model"
)

type CacheDB struct {
*DB
}

func (db *DB) Cache() *CacheDB {
return &CacheDB{DB: db}
}

func (*CacheDB) Get(ctx SQLContext, fileHash string, nodeID cdssdk.NodeID) (model.Cache, error) {
var ret model.Cache
err := ctx.Table("Cache").Where("FileHash = ? AND NodeID = ?", fileHash, nodeID).First(&ret).Error
return ret, err
}

func (*CacheDB) BatchGetAllFileHashes(ctx SQLContext, start int, count int) ([]string, error) {
var ret []string
err := ctx.Table("Cache").Distinct("FileHash").Offset(start).Limit(count).Pluck("FileHash", &ret).Error
return ret, err
}

func (*CacheDB) GetByNodeID(ctx SQLContext, nodeID cdssdk.NodeID) ([]model.Cache, error) {
var ret []model.Cache
err := ctx.Table("Cache").Where("NodeID = ?", nodeID).Find(&ret).Error
return ret, err
}

// Create 创建一条缓存记录,如果已有则不进行操作
func (*CacheDB) Create(ctx SQLContext, fileHash string, nodeID cdssdk.NodeID, priority int) error {
cache := model.Cache{FileHash: fileHash, NodeID: nodeID, CreateTime: time.Now(), Priority: priority}
return ctx.Where(cache).Attrs(cache).FirstOrCreate(&cache).Error
}

// 批量创建缓存记录
func (*CacheDB) BatchCreate(ctx SQLContext, caches []model.Cache) error {
if len(caches) == 0 {
return nil
}
return BatchNamedExec(
ctx,
"insert into Cache(FileHash,NodeID,CreateTime,Priority) values(:FileHash,:NodeID,:CreateTime,:Priority)"+
" on duplicate key update CreateTime=values(CreateTime), Priority=values(Priority)",
4,
caches,
nil,
)
}

func (*CacheDB) BatchCreateOnSameNode(ctx SQLContext, fileHashes []string, nodeID cdssdk.NodeID, priority int) error {
if len(fileHashes) == 0 {
return nil
}

var caches []model.Cache
var nowTime = time.Now()
for _, hash := range fileHashes {
caches = append(caches, model.Cache{
FileHash: hash,
NodeID: nodeID,
CreateTime: nowTime,
Priority: priority,
})
}

return BatchNamedExec(ctx,
"insert into Cache(FileHash,NodeID,CreateTime,Priority) values(:FileHash,:NodeID,:CreateTime,:Priority)"+
" on duplicate key update CreateTime=values(CreateTime), Priority=values(Priority)",
4,
caches,
nil,
)
}

func (*CacheDB) NodeBatchDelete(ctx SQLContext, nodeID cdssdk.NodeID, fileHashes []string) error {
if len(fileHashes) == 0 {
return nil
}

return ctx.Table("Cache").Where("NodeID = ? AND FileHash IN (?)", nodeID, fileHashes).Delete(&model.Cache{}).Error
}

// GetCachingFileNodes 查找缓存了指定文件的节点
func (*CacheDB) GetCachingFileNodes(ctx SQLContext, fileHash string) ([]cdssdk.Node, error) {
var nodes []cdssdk.Node
err := ctx.Table("Cache").Select("Node.*").
Joins("JOIN Node ON Cache.NodeID = Node.NodeID").
Where("Cache.FileHash = ?", fileHash).
Find(&nodes).Error
return nodes, err
}

// DeleteNodeAll 删除一个节点所有的记录
func (*CacheDB) DeleteNodeAll(ctx SQLContext, nodeID cdssdk.NodeID) error {
return ctx.Where("NodeID = ?", nodeID).Delete(&model.Cache{}).Error
}

// FindCachingFileUserNodes 在缓存表中查询指定数据所在的节点
func (*CacheDB) FindCachingFileUserNodes(ctx SQLContext, userID cdssdk.NodeID, fileHash string) ([]cdssdk.Node, error) {
var nodes []cdssdk.Node
err := ctx.Table("Cache").Select("Node.*").
Joins("JOIN UserNode ON Cache.NodeID = UserNode.NodeID").
Joins("JOIN Node ON UserNode.NodeID = Node.NodeID").
Where("Cache.FileHash = ? AND UserNode.UserID = ?", fileHash, userID).
Find(&nodes).Error
return nodes, err
}

+ 4
- 4
common/pkgs/db2/db2.go View File

@@ -23,8 +23,8 @@ func NewDB(cfg *config.Config) (*DB, error) {
}, nil
}

func (s *DB) DoTx(do func(tx SQLContext) error) error {
return s.db.Transaction(func(tx *gorm.DB) error {
func (db *DB) DoTx(do func(tx SQLContext) error) error {
return db.db.Transaction(func(tx *gorm.DB) error {
return do(SQLContext{tx})
})
}
@@ -33,6 +33,6 @@ type SQLContext struct {
*gorm.DB
}

func (d *DB) DefCtx() SQLContext {
return SQLContext{d.db}
func (db *DB) DefCtx() SQLContext {
return SQLContext{db.db}
}

+ 35
- 0
common/pkgs/db2/location.go View File

@@ -0,0 +1,35 @@
package db2

import (
"fmt"
"gitlink.org.cn/cloudream/storage/common/pkgs/db/model"
)

type LocationDB struct {
*DB
}

func (db *DB) Location() *LocationDB {
return &LocationDB{DB: db}
}

func (*LocationDB) GetByID(ctx SQLContext, id int64) (model.Location, error) {
var ret model.Location
err := ctx.First(&ret, id).Error
return ret, err
}

func (db *LocationDB) FindLocationByExternalIP(ctx SQLContext, ip string) (model.Location, error) {
var locID int64
err := ctx.Table("Node").Select("LocationID").Where("ExternalIP = ?", ip).Scan(&locID).Error
if err != nil {
return model.Location{}, fmt.Errorf("finding node by external ip: %w", err)
}

loc, err := db.GetByID(ctx, locID)
if err != nil {
return model.Location{}, fmt.Errorf("getting location by id: %w", err)
}

return loc, nil
}

+ 2
- 2
common/pkgs/db2/node.go View File

@@ -10,8 +10,8 @@ type NodeDB struct {
*DB
}

func (nodeDB *DB) Node() *NodeDB {
return &NodeDB{DB: nodeDB}
func (db *DB) Node() *NodeDB {
return &NodeDB{DB: db}
}

func (*NodeDB) GetAllNodes(ctx SQLContext) ([]cdssdk.Node, error) {


+ 37
- 0
common/pkgs/db2/node_connectivity.go View File

@@ -0,0 +1,37 @@
package db2

import (
cdssdk "gitlink.org.cn/cloudream/common/sdks/storage"
"gitlink.org.cn/cloudream/storage/common/pkgs/db/model"
"gorm.io/gorm/clause"
)

type NodeConnectivityDB struct {
*DB
}

func (db *DB) NodeConnectivity() *NodeConnectivityDB {
return &NodeConnectivityDB{DB: db}
}

func (db *NodeConnectivityDB) BatchGetByFromNode(ctx SQLContext, fromNodeIDs []cdssdk.NodeID) ([]model.NodeConnectivity, error) {
if len(fromNodeIDs) == 0 {
return nil, nil
}

var ret []model.NodeConnectivity

err := ctx.Table("NodeConnectivity").Where("FromNodeID IN (?)", fromNodeIDs).Find(&ret).Error
return ret, err
}

func (db *NodeConnectivityDB) BatchUpdateOrCreate(ctx SQLContext, cons []model.NodeConnectivity) error {
if len(cons) == 0 {
return nil
}

// 使用 GORM 的批量插入或更新
return ctx.Table("NodeConnectivity").Clauses(clause.OnConflict{
UpdateAll: true,
}).Create(&cons).Error
}

+ 346
- 0
common/pkgs/db2/object.go View File

@@ -0,0 +1,346 @@
package db2

import (
"fmt"
"gitlink.org.cn/cloudream/common/utils/sort2"
"strings"
"time"

"github.com/samber/lo"
cdssdk "gitlink.org.cn/cloudream/common/sdks/storage"
stgmod "gitlink.org.cn/cloudream/storage/common/models"
"gitlink.org.cn/cloudream/storage/common/pkgs/db/model"
coormq "gitlink.org.cn/cloudream/storage/common/pkgs/mq/coordinator"
)

type ObjectDB struct {
*DB
}

func (db *DB) Object() *ObjectDB {
return &ObjectDB{DB: db}
}

func (db *ObjectDB) GetByID(ctx SQLContext, objectID cdssdk.ObjectID) (model.Object, error) {
var ret model.TempObject
err := ctx.Table("Object").Where("ObjectID = ?", objectID).First(&ret).Error
return ret.ToObject(), err
}

func (db *ObjectDB) BatchTestObjectID(ctx SQLContext, objectIDs []cdssdk.ObjectID) (map[cdssdk.ObjectID]bool, error) {
if len(objectIDs) == 0 {
return make(map[cdssdk.ObjectID]bool), nil
}

var avaiIDs []cdssdk.ObjectID
err := ctx.Table("Object").Where("ObjectID IN ?", objectIDs).Pluck("ObjectID", &avaiIDs).Error
if err != nil {
return nil, err
}

avaiIDMap := make(map[cdssdk.ObjectID]bool)
for _, pkgID := range avaiIDs {
avaiIDMap[pkgID] = true
}

return avaiIDMap, nil
}

func (db *ObjectDB) BatchGet(ctx SQLContext, objectIDs []cdssdk.ObjectID) ([]model.Object, error) {
if len(objectIDs) == 0 {
return nil, nil
}

var objs []model.TempObject
err := ctx.Table("Object").Where("ObjectID IN ?", objectIDs).Order("ObjectID ASC").Find(&objs).Error
if err != nil {
return nil, err
}

return lo.Map(objs, func(o model.TempObject, idx int) cdssdk.Object { return o.ToObject() }), nil
}

func (db *ObjectDB) BatchGetByPackagePath(ctx SQLContext, pkgID cdssdk.PackageID, pathes []string) ([]cdssdk.Object, error) {
if len(pathes) == 0 {
return nil, nil
}

var objs []model.TempObject
err := ctx.Table("Object").Where("PackageID = ? AND Path IN ?", pkgID, pathes).Find(&objs).Error
if err != nil {
return nil, err
}

return lo.Map(objs, func(o model.TempObject, idx int) cdssdk.Object { return o.ToObject() }), nil
}

func (db *ObjectDB) Create(ctx SQLContext, obj cdssdk.Object) (cdssdk.ObjectID, error) {
err := ctx.Table("Object").Create(&obj).Error
if err != nil {
return 0, fmt.Errorf("insert object failed, err: %w", err)
}
return obj.ObjectID, nil
}

func (db *ObjectDB) BatchUpsertByPackagePath(ctx SQLContext, objs []cdssdk.Object) error {
if len(objs) == 0 {
return nil
}

// 使用 GORM 的 Save 方法,插入或更新对象
return ctx.Table("Object").Save(&objs).Error
}

func (db *ObjectDB) BatchUpert(ctx SQLContext, objs []cdssdk.Object) error {
if len(objs) == 0 {
return nil
}

// 直接更新或插入
return ctx.Table("Object").Save(&objs).Error
}

func (db *ObjectDB) GetPackageObjects(ctx SQLContext, packageID cdssdk.PackageID) ([]model.Object, error) {
var ret []model.TempObject
err := ctx.Table("Object").Where("PackageID = ?", packageID).Order("ObjectID ASC").Find(&ret).Error
return lo.Map(ret, func(o model.TempObject, idx int) model.Object { return o.ToObject() }), err
}

func (db *ObjectDB) GetPackageObjectDetails(ctx SQLContext, packageID cdssdk.PackageID) ([]stgmod.ObjectDetail, error) {
var objs []model.TempObject
err := ctx.Table("Object").Where("PackageID = ?", packageID).Order("ObjectID ASC").Find(&objs).Error
if err != nil {
return nil, fmt.Errorf("getting objects: %w", err)
}

// 获取所有的 ObjectBlock
var allBlocks []stgmod.ObjectBlock
err = ctx.Table("ObjectBlock").
Select("ObjectBlock.*").
Joins("JOIN Object ON ObjectBlock.ObjectID = Object.ObjectID").
Where("Object.PackageID = ?", packageID).
Order("ObjectBlock.ObjectID, `Index` ASC").
Find(&allBlocks).Error
if err != nil {
return nil, fmt.Errorf("getting all object blocks: %w", err)
}

// 获取所有的 PinnedObject
var allPinnedObjs []cdssdk.PinnedObject
err = ctx.Table("PinnedObject").
Select("PinnedObject.*").
Joins("JOIN Object ON PinnedObject.ObjectID = Object.ObjectID").
Where("Object.PackageID = ?", packageID).
Order("PinnedObject.ObjectID").
Find(&allPinnedObjs).Error
if err != nil {
return nil, fmt.Errorf("getting all pinned objects: %w", err)
}

details := make([]stgmod.ObjectDetail, len(objs))
for i, obj := range objs {
details[i] = stgmod.ObjectDetail{
Object: obj.ToObject(),
}
}

stgmod.DetailsFillObjectBlocks(details, allBlocks)
stgmod.DetailsFillPinnedAt(details, allPinnedObjs)
return details, nil
}

func (db *ObjectDB) GetObjectsIfAnyBlockOnNode(ctx SQLContext, nodeID cdssdk.NodeID) ([]cdssdk.Object, error) {
var temps []model.TempObject
err := ctx.Table("Object").Where("ObjectID IN (SELECT ObjectID FROM ObjectBlock WHERE NodeID = ?)", nodeID).Order("ObjectID ASC").Find(&temps).Error
if err != nil {
return nil, fmt.Errorf("getting objects: %w", err)
}

objs := make([]cdssdk.Object, len(temps))
for i := range temps {
objs[i] = temps[i].ToObject()
}

return objs, nil
}

func (db *ObjectDB) BatchAdd(ctx SQLContext, packageID cdssdk.PackageID, adds []coormq.AddObjectEntry) ([]cdssdk.Object, error) {
if len(adds) == 0 {
return nil, nil
}

objs := make([]cdssdk.Object, 0, len(adds))
for _, add := range adds {
objs = append(objs, cdssdk.Object{
PackageID: packageID,
Path: add.Path,
Size: add.Size,
FileHash: add.FileHash,
Redundancy: cdssdk.NewNoneRedundancy(), // 首次上传默认使用不分块的none模式
CreateTime: add.UploadTime,
UpdateTime: add.UploadTime,
})
}

err := db.BatchUpsertByPackagePath(ctx, objs)
if err != nil {
return nil, fmt.Errorf("batch create or update objects: %w", err)
}

// 收集所有路径
pathes := make([]string, 0, len(adds))
for _, add := range adds {
pathes = append(pathes, add.Path)
}

// 批量获取对象
addedObjs := []cdssdk.Object{}
err = ctx.Table("Object").Where("PackageID = ? AND Path IN ?", packageID, pathes).Find(&addedObjs).Error
if err != nil {
return nil, fmt.Errorf("batch get object ids: %w", err)
}

// 对添加的对象和获取的对象进行排序
adds = sort2.Sort(adds, func(l, r coormq.AddObjectEntry) int { return strings.Compare(l.Path, r.Path) })
addedObjs = sort2.Sort(addedObjs, func(l, r cdssdk.Object) int { return strings.Compare(l.Path, r.Path) })

// 收集对象 ID
addedObjIDs := make([]cdssdk.ObjectID, len(addedObjs))
for i := range addedObjs {
addedObjIDs[i] = addedObjs[i].ObjectID
}

// 批量删除 ObjectBlock
if err := ctx.Table("ObjectBlock").Where("ObjectID IN ?", addedObjIDs).Delete(&stgmod.ObjectBlock{}).Error; err != nil {
return nil, fmt.Errorf("batch delete object blocks: %w", err)
}

// 批量删除 PinnedObject
if err := ctx.Table("PinnedObject").Where("ObjectID IN ?", addedObjIDs).Delete(&cdssdk.PinnedObject{}).Error; err != nil {
return nil, fmt.Errorf("batch delete pinned objects: %w", err)
}

// 创建 ObjectBlock
objBlocks := make([]stgmod.ObjectBlock, len(adds))
for i, add := range adds {
objBlocks[i] = stgmod.ObjectBlock{
ObjectID: addedObjIDs[i],
Index: 0,
NodeID: add.NodeID,
FileHash: add.FileHash,
}
}
if err := ctx.Table("ObjectBlock").Create(&objBlocks).Error; err != nil {
return nil, fmt.Errorf("batch create object blocks: %w", err)
}

// 创建 Cache
caches := make([]model.Cache, len(adds))
for _, add := range adds {
caches = append(caches, model.Cache{
FileHash: add.FileHash,
NodeID: add.NodeID,
CreateTime: time.Now(),
Priority: 0,
})
}
if err := ctx.Table("Cache").Create(&caches).Error; err != nil {
return nil, fmt.Errorf("batch create caches: %w", err)
}

return addedObjs, nil
}

func (db *ObjectDB) BatchUpdateRedundancy(ctx SQLContext, objs []coormq.UpdatingObjectRedundancy) error {
if len(objs) == 0 {
return nil
}

nowTime := time.Now()
objIDs := make([]cdssdk.ObjectID, 0, len(objs))
dummyObjs := make([]cdssdk.Object, 0, len(objs))
for _, obj := range objs {
objIDs = append(objIDs, obj.ObjectID)
dummyObjs = append(dummyObjs, cdssdk.Object{
ObjectID: obj.ObjectID,
Redundancy: obj.Redundancy,
CreateTime: nowTime,
UpdateTime: nowTime,
})
}

// 目前只能使用这种方式来同时更新大量数据
err := BatchNamedExec(ctx,
"insert into Object(ObjectID, PackageID, Path, Size, FileHash, Redundancy, CreateTime, UpdateTime)"+
" values(:ObjectID, :PackageID, :Path, :Size, :FileHash, :Redundancy, :CreateTime, :UpdateTime) as new"+
" on duplicate key update Redundancy=new.Redundancy", 8, dummyObjs, nil)
if err != nil {
return fmt.Errorf("batch update object redundancy: %w", err)
}

// 删除原本所有的编码块记录,重新添加
err = db.ObjectBlock().BatchDeleteByObjectID(ctx, objIDs)
if err != nil {
return fmt.Errorf("batch delete object blocks: %w", err)
}

// 删除原本Pin住的Object。暂不考虑FileHash没有变化的情况
err = db.PinnedObject().BatchDeleteByObjectID(ctx, objIDs)
if err != nil {
return fmt.Errorf("batch delete pinned object: %w", err)
}

blocks := make([]stgmod.ObjectBlock, 0, len(objs))
for _, obj := range objs {
blocks = append(blocks, obj.Blocks...)
}
err = db.ObjectBlock().BatchCreate(ctx, blocks)
if err != nil {
return fmt.Errorf("batch create object blocks: %w", err)
}

caches := make([]model.Cache, 0, len(objs))
for _, obj := range objs {
for _, blk := range obj.Blocks {
caches = append(caches, model.Cache{
FileHash: blk.FileHash,
NodeID: blk.NodeID,
CreateTime: time.Now(),
Priority: 0,
})
}
}
err = db.Cache().BatchCreate(ctx, caches)
if err != nil {
return fmt.Errorf("batch create object caches: %w", err)
}

pinneds := make([]cdssdk.PinnedObject, 0, len(objs))
for _, obj := range objs {
for _, p := range obj.PinnedAt {
pinneds = append(pinneds, cdssdk.PinnedObject{
ObjectID: obj.ObjectID,
StorageID: p,
CreateTime: time.Now(),
})
}
}
err = db.PinnedObject().BatchTryCreate(ctx, pinneds)
if err != nil {
return fmt.Errorf("batch create pinned objects: %w", err)
}

return nil
}

func (db *ObjectDB) BatchDelete(ctx SQLContext, ids []cdssdk.ObjectID) error {
if len(ids) == 0 {
return nil
}

return ctx.Table("Object").Where("ObjectID IN ?", ids).Delete(&model.TempObject{}).Error
}

func (db *ObjectDB) DeleteInPackage(ctx SQLContext, packageID cdssdk.PackageID) error {
return ctx.Table("Object").Where("PackageID = ?", packageID).Delete(&model.TempObject{}).Error
}

+ 107
- 0
common/pkgs/db2/object_access_stat.go View File

@@ -0,0 +1,107 @@
package db2

import (
cdssdk "gitlink.org.cn/cloudream/common/sdks/storage"
stgmod "gitlink.org.cn/cloudream/storage/common/models"
coormq "gitlink.org.cn/cloudream/storage/common/pkgs/mq/coordinator"
"gorm.io/gorm/clause"
)

type ObjectAccessStatDB struct {
*DB
}

func (db *DB) ObjectAccessStat() *ObjectAccessStatDB {
return &ObjectAccessStatDB{db}
}

func (*ObjectAccessStatDB) Get(ctx SQLContext, objID cdssdk.ObjectID, nodeID cdssdk.NodeID) (stgmod.ObjectAccessStat, error) {
var ret stgmod.ObjectAccessStat
err := ctx.Table("ObjectAccessStat").
Where("ObjectID = ? AND NodeID = ?", objID, nodeID).
First(&ret).Error
return ret, err
}

func (*ObjectAccessStatDB) GetByObjectID(ctx SQLContext, objID cdssdk.ObjectID) ([]stgmod.ObjectAccessStat, error) {
var ret []stgmod.ObjectAccessStat
err := ctx.Table("ObjectAccessStat").
Where("ObjectID = ?", objID).
Find(&ret).Error
return ret, err
}

func (*ObjectAccessStatDB) BatchGetByObjectID(ctx SQLContext, objIDs []cdssdk.ObjectID) ([]stgmod.ObjectAccessStat, error) {
if len(objIDs) == 0 {
return nil, nil
}

var ret []stgmod.ObjectAccessStat
err := ctx.Table("ObjectAccessStat").
Where("ObjectID IN ?", objIDs).
Find(&ret).Error
return ret, err
}

func (*ObjectAccessStatDB) BatchGetByObjectIDOnNode(ctx SQLContext, objIDs []cdssdk.ObjectID, nodeID cdssdk.NodeID) ([]stgmod.ObjectAccessStat, error) {
if len(objIDs) == 0 {
return nil, nil
}

var ret []stgmod.ObjectAccessStat
err := ctx.Table("ObjectAccessStat").
Where("ObjectID IN ? AND NodeID = ?", objIDs, nodeID).
Find(&ret).Error
return ret, err
}

func (*ObjectAccessStatDB) BatchAddCounter(ctx SQLContext, entries []coormq.AddAccessStatEntry) error {
if len(entries) == 0 {
return nil
}

for _, entry := range entries {
err := ctx.Table("ObjectAccessStat").
Clauses(clause.OnConflict{
UpdateAll: true,
}).
Create(&entry).Error
if err != nil {
return err
}
}
return nil
}

func (*ObjectAccessStatDB) BatchUpdateAmountInPackage(ctx SQLContext, pkgIDs []cdssdk.PackageID, historyWeight float64) error {
if len(pkgIDs) == 0 {
return nil
}

err := ctx.Exec("UPDATE ObjectAccessStat AS o INNER JOIN Object AS obj ON o.ObjectID = obj.ObjectID SET o.Amount = o.Amount * ? + o.Counter * (1 - ?), o.Counter = 0 WHERE obj.PackageID IN ?", historyWeight, historyWeight, pkgIDs).Error
return err
}

func (*ObjectAccessStatDB) UpdateAllAmount(ctx SQLContext, historyWeight float64) error {
err := ctx.Exec("UPDATE ObjectAccessStat SET Amount = Amount * ? + Counter * (1 - ?), Counter = 0", historyWeight, historyWeight).Error
return err
}

func (*ObjectAccessStatDB) DeleteByObjectID(ctx SQLContext, objID cdssdk.ObjectID) error {
err := ctx.Table("ObjectAccessStat").Where("ObjectID = ?", objID).Delete(nil).Error
return err
}

func (*ObjectAccessStatDB) BatchDeleteByObjectID(ctx SQLContext, objIDs []cdssdk.ObjectID) error {
if len(objIDs) == 0 {
return nil
}

err := ctx.Table("ObjectAccessStat").Where("ObjectID IN ?", objIDs).Delete(nil).Error
return err
}

func (*ObjectAccessStatDB) DeleteInPackage(ctx SQLContext, packageID cdssdk.PackageID) error {
err := ctx.Exec("DELETE o FROM ObjectAccessStat o INNER JOIN Object obj ON o.ObjectID = obj.ObjectID WHERE obj.PackageID = ?", packageID).Error
return err
}

+ 107
- 0
common/pkgs/db2/object_block.go View File

@@ -0,0 +1,107 @@
package db2

import (
"strconv"
"strings"

cdssdk "gitlink.org.cn/cloudream/common/sdks/storage"
stgmod "gitlink.org.cn/cloudream/storage/common/models"
)

type ObjectBlockDB struct {
*DB
}

func (db *DB) ObjectBlock() *ObjectBlockDB {
return &ObjectBlockDB{DB: db}
}

func (db *ObjectBlockDB) GetByNodeID(ctx SQLContext, nodeID cdssdk.NodeID) ([]stgmod.ObjectBlock, error) {
var rets []stgmod.ObjectBlock
err := ctx.Table("ObjectBlock").Where("NodeID = ?", nodeID).Find(&rets).Error
return rets, err
}

func (db *ObjectBlockDB) BatchGetByObjectID(ctx SQLContext, objectIDs []cdssdk.ObjectID) ([]stgmod.ObjectBlock, error) {
if len(objectIDs) == 0 {
return nil, nil
}

var blocks []stgmod.ObjectBlock
err := ctx.Table("ObjectBlock").Where("ObjectID IN (?)", objectIDs).Order("ObjectID, `Index` ASC").Find(&blocks).Error
return blocks, err
}

func (db *ObjectBlockDB) Create(ctx SQLContext, objectID cdssdk.ObjectID, index int, nodeID cdssdk.NodeID, fileHash string) error {
block := stgmod.ObjectBlock{ObjectID: objectID, Index: index, NodeID: nodeID, FileHash: fileHash}
return ctx.Table("ObjectBlock").Create(&block).Error
}

func (db *ObjectBlockDB) BatchCreate(ctx SQLContext, blocks []stgmod.ObjectBlock) error {
if len(blocks) == 0 {
return nil
}

return ctx.Table("ObjectBlock").Create(&blocks).Error
}

func (db *ObjectBlockDB) DeleteByObjectID(ctx SQLContext, objectID cdssdk.ObjectID) error {
return ctx.Table("ObjectBlock").Where("ObjectID = ?", objectID).Delete(&stgmod.ObjectBlock{}).Error
}

func (db *ObjectBlockDB) BatchDeleteByObjectID(ctx SQLContext, objectIDs []cdssdk.ObjectID) error {
if len(objectIDs) == 0 {
return nil
}

return ctx.Table("ObjectBlock").Where("ObjectID IN (?)", objectIDs).Delete(&stgmod.ObjectBlock{}).Error
}

func (db *ObjectBlockDB) DeleteInPackage(ctx SQLContext, packageID cdssdk.PackageID) error {
return ctx.Table("ObjectBlock").Where("ObjectID IN (SELECT ObjectID FROM Object WHERE PackageID = ?)", packageID).Delete(&stgmod.ObjectBlock{}).Error
}

func (db *ObjectBlockDB) NodeBatchDelete(ctx SQLContext, nodeID cdssdk.NodeID, fileHashes []string) error {
if len(fileHashes) == 0 {
return nil
}

return ctx.Table("ObjectBlock").Where("NodeID = ? AND FileHash IN (?)", nodeID, fileHashes).Delete(&stgmod.ObjectBlock{}).Error
}

func (db *ObjectBlockDB) CountBlockWithHash(ctx SQLContext, fileHash string) (int, error) {
var cnt int64
err := ctx.Table("ObjectBlock").
Select("COUNT(FileHash)").
Joins("INNER JOIN Object ON ObjectBlock.ObjectID = Object.ObjectID").
Joins("INNER JOIN Package ON Object.PackageID = Package.PackageID").
Where("FileHash = ? AND Package.State = ?", fileHash, cdssdk.PackageStateNormal).
Scan(&cnt).Error

if err != nil {
return 0, err
}

return int(cnt), nil
}

// 按逗号切割字符串,并将每一个部分解析为一个int64的ID。
// 注:需要外部保证分隔的每一个部分都是正确的10进制数字格式
func splitConcatedNodeID(idStr string) []cdssdk.NodeID {
idStrs := strings.Split(idStr, ",")
ids := make([]cdssdk.NodeID, 0, len(idStrs))

for _, str := range idStrs {
// 假设传入的ID是正确的数字格式
id, _ := strconv.ParseInt(str, 10, 64)
ids = append(ids, cdssdk.NodeID(id))
}

return ids
}

// 按逗号切割字符串
func splitConcatedFileHash(idStr string) []string {
idStrs := strings.Split(idStr, ",")
return idStrs
}

+ 186
- 0
common/pkgs/db2/package.go View File

@@ -0,0 +1,186 @@
package db2

import (
"errors"
"fmt"
"gorm.io/gorm"

cdssdk "gitlink.org.cn/cloudream/common/sdks/storage"
"gitlink.org.cn/cloudream/storage/common/pkgs/db/model"
)

type PackageDB struct {
*DB
}

func (db *DB) Package() *PackageDB {
return &PackageDB{DB: db}
}

func (db *PackageDB) GetByID(ctx SQLContext, packageID cdssdk.PackageID) (model.Package, error) {
var ret model.Package
err := ctx.Table("Package").Where("PackageID = ?", packageID).First(&ret).Error
return ret, err
}

func (db *PackageDB) GetByName(ctx SQLContext, bucketID cdssdk.BucketID, name string) (model.Package, error) {
var ret model.Package
err := ctx.Table("Package").Where("BucketID = ? AND Name = ?", bucketID, name).First(&ret).Error
return ret, err
}

func (db *PackageDB) BatchTestPackageID(ctx SQLContext, pkgIDs []cdssdk.PackageID) (map[cdssdk.PackageID]bool, error) {
if len(pkgIDs) == 0 {
return make(map[cdssdk.PackageID]bool), nil
}

var avaiIDs []cdssdk.PackageID
err := ctx.Table("Package").
Select("PackageID").
Where("PackageID IN ?", pkgIDs).
Find(&avaiIDs).Error
if err != nil {
return nil, err
}

avaiIDMap := make(map[cdssdk.PackageID]bool)
for _, pkgID := range avaiIDs {
avaiIDMap[pkgID] = true
}

return avaiIDMap, nil
}

func (*PackageDB) BatchGetAllPackageIDs(ctx SQLContext, start int, count int) ([]cdssdk.PackageID, error) {
var ret []cdssdk.PackageID
err := ctx.Table("Package").Select("PackageID").Limit(count).Offset(start).Find(&ret).Error
return ret, err
}

func (db *PackageDB) GetBucketPackages(ctx SQLContext, userID cdssdk.UserID, bucketID cdssdk.BucketID) ([]model.Package, error) {
var ret []model.Package
err := ctx.Table("UserBucket").
Select("Package.*").
Joins("JOIN Package ON UserBucket.BucketID = Package.BucketID").
Where("UserBucket.UserID = ? AND UserBucket.BucketID = ?", userID, bucketID).
Find(&ret).Error
return ret, err
}

// IsAvailable 判断一个用户是否拥有指定对象
func (db *PackageDB) IsAvailable(ctx SQLContext, userID cdssdk.UserID, packageID cdssdk.PackageID) (bool, error) {
var pkgID cdssdk.PackageID
err := ctx.Table("Package").
Select("Package.PackageID").
Joins("JOIN UserBucket ON Package.BucketID = UserBucket.BucketID").
Where("Package.PackageID = ? AND UserBucket.UserID = ?", packageID, userID).
Scan(&pkgID).Error

if err == gorm.ErrRecordNotFound {
return false, nil
}

if err != nil {
return false, fmt.Errorf("find package failed, err: %w", err)
}

return true, nil
}

// GetUserPackage 获得Package,如果用户没有权限访问,则不会获得结果
func (db *PackageDB) GetUserPackage(ctx SQLContext, userID cdssdk.UserID, packageID cdssdk.PackageID) (model.Package, error) {
var ret model.Package
err := ctx.Table("Package").
Select("Package.*").
Joins("JOIN UserBucket ON Package.BucketID = UserBucket.BucketID").
Where("Package.PackageID = ? AND UserBucket.UserID = ?", packageID, userID).
First(&ret).Error
return ret, err
}

// 在指定名称的Bucket中查找指定名称的Package
func (*PackageDB) GetUserPackageByName(ctx SQLContext, userID cdssdk.UserID, bucketName string, packageName string) (model.Package, error) {
var ret model.Package
err := ctx.Table("Package").
Select("Package.*").
Joins("JOIN Bucket ON Package.BucketID = Bucket.BucketID").
Joins("JOIN UserBucket ON Bucket.BucketID = UserBucket.BucketID").
Where("Package.Name = ? AND Bucket.Name = ? AND UserBucket.UserID = ?", packageName, bucketName, userID).
First(&ret).Error
return ret, err
}

func (db *PackageDB) Create(ctx SQLContext, bucketID cdssdk.BucketID, name string) (cdssdk.PackageID, error) {
var packageID int64
err := ctx.Table("Package").
Select("PackageID").
Where("Name = ? AND BucketID = ?", name, bucketID).
First(&packageID).Error

if err == nil {
return 0, fmt.Errorf("package with given Name and BucketID already exists")
}
if !errors.Is(err, gorm.ErrRecordNotFound) {
return 0, fmt.Errorf("query Package by PackageName and BucketID failed, err: %w", err)
}

newPackage := model.Package{Name: name, BucketID: bucketID, State: cdssdk.PackageStateNormal}
if err := ctx.Create(&newPackage).Error; err != nil {
return 0, fmt.Errorf("insert package failed, err: %w", err)
}

return cdssdk.PackageID(newPackage.PackageID), nil
}

// SoftDelete 设置一个对象被删除,并将相关数据删除
func (db *PackageDB) SoftDelete(ctx SQLContext, packageID cdssdk.PackageID) error {
obj, err := db.GetByID(ctx, packageID)
if err != nil {
return fmt.Errorf("get package failed, err: %w", err)
}

if obj.State != cdssdk.PackageStateNormal {
return nil
}

if err := db.ChangeState(ctx, packageID, cdssdk.PackageStateDeleted); err != nil {
return fmt.Errorf("change package state failed, err: %w", err)
}

if err := db.ObjectAccessStat().DeleteInPackage(ctx, packageID); err != nil {
return fmt.Errorf("delete from object access stat: %w", err)
}

if err := db.ObjectBlock().DeleteInPackage(ctx, packageID); err != nil {
return fmt.Errorf("delete from object block failed, err: %w", err)
}

if err := db.PinnedObject().DeleteInPackage(ctx, packageID); err != nil {
return fmt.Errorf("deleting pinned objects in package: %w", err)
}

if err := db.Object().DeleteInPackage(ctx, packageID); err != nil {
return fmt.Errorf("deleting objects in package: %w", err)
}

if _, err := db.StoragePackage().SetAllPackageDeleted(ctx, packageID); err != nil {
return fmt.Errorf("set storage package deleted failed, err: %w", err)
}

return nil
}

// DeleteUnused 删除一个已经是Deleted状态,且不再被使用的对象
func (PackageDB) DeleteUnused(ctx SQLContext, packageID cdssdk.PackageID) error {
err := ctx.Exec("DELETE FROM Package WHERE PackageID = ? AND State = ? AND NOT EXISTS (SELECT StorageID FROM StoragePackage WHERE PackageID = ?)",
packageID,
cdssdk.PackageStateDeleted,
packageID,
).Error
return err
}

func (*PackageDB) ChangeState(ctx SQLContext, packageID cdssdk.PackageID, state string) error {
err := ctx.Exec("UPDATE Package SET State = ? WHERE PackageID = ?", state, packageID).Error
return err
}

+ 66
- 0
common/pkgs/db2/package_access_stat.go View File

@@ -0,0 +1,66 @@
package db2

import (
cdssdk "gitlink.org.cn/cloudream/common/sdks/storage"
stgmod "gitlink.org.cn/cloudream/storage/common/models"
coormq "gitlink.org.cn/cloudream/storage/common/pkgs/mq/coordinator"
)

type PackageAccessStatDB struct {
*DB
}

func (db *DB) PackageAccessStat() *PackageAccessStatDB {
return &PackageAccessStatDB{db}
}

func (*PackageAccessStatDB) Get(ctx SQLContext, pkgID cdssdk.PackageID, nodeID cdssdk.NodeID) (stgmod.PackageAccessStat, error) {
var ret stgmod.PackageAccessStat
err := ctx.Table("PackageAccessStat").Where("PackageID = ? AND NodeID = ?", pkgID, nodeID).First(&ret).Error
return ret, err
}

func (*PackageAccessStatDB) GetByPackageID(ctx SQLContext, pkgID cdssdk.PackageID) ([]stgmod.PackageAccessStat, error) {
var ret []stgmod.PackageAccessStat
err := ctx.Table("PackageAccessStat").Where("PackageID = ?", pkgID).Find(&ret).Error
return ret, err
}

func (*PackageAccessStatDB) BatchGetByPackageID(ctx SQLContext, pkgIDs []cdssdk.PackageID) ([]stgmod.PackageAccessStat, error) {
if len(pkgIDs) == 0 {
return nil, nil
}

var ret []stgmod.PackageAccessStat
err := ctx.Table("PackageAccessStat").Where("PackageID IN (?)", pkgIDs).Find(&ret).Error
return ret, err
}

func (*PackageAccessStatDB) BatchAddCounter(ctx SQLContext, entries []coormq.AddAccessStatEntry) error {
if len(entries) == 0 {
return nil
}

sql := "INSERT INTO PackageAccessStat(PackageID, NodeID, Counter, Amount) " +
"VALUES(:PackageID, :NodeID, :Counter, 0) ON DUPLICATE KEY UPDATE Counter = Counter + VALUES(Counter)"

return ctx.Exec(sql, entries).Error
}

func (*PackageAccessStatDB) BatchUpdateAmount(ctx SQLContext, pkgIDs []cdssdk.PackageID, historyWeight float64) error {
if len(pkgIDs) == 0 {
return nil
}

sql := "UPDATE PackageAccessStat SET Amount = Amount * ? + Counter * (1 - ?), Counter = 0 WHERE PackageID IN (?)"
return ctx.Exec(sql, historyWeight, historyWeight, pkgIDs).Error
}

func (*PackageAccessStatDB) UpdateAllAmount(ctx SQLContext, historyWeight float64) error {
sql := "UPDATE PackageAccessStat SET Amount = Amount * ? + Counter * (1 - ?), Counter = 0"
return ctx.Exec(sql, historyWeight, historyWeight).Error
}

func (*PackageAccessStatDB) DeleteByPackageID(ctx SQLContext, pkgID cdssdk.PackageID) error {
return ctx.Table("PackageAccessStat").Where("PackageID = ?", pkgID).Delete(&stgmod.PackageAccessStat{}).Error
}

+ 83
- 0
common/pkgs/db2/storage_package.go View File

@@ -0,0 +1,83 @@
package db2

import (
cdssdk "gitlink.org.cn/cloudream/common/sdks/storage"
"gitlink.org.cn/cloudream/storage/common/pkgs/db/model"
)

type StoragePackageDB struct {
*DB
}

func (db *DB) StoragePackage() *StoragePackageDB {
return &StoragePackageDB{DB: db}
}

func (*StoragePackageDB) Get(ctx SQLContext, storageID cdssdk.StorageID, packageID cdssdk.PackageID, userID cdssdk.UserID) (model.StoragePackage, error) {
var ret model.StoragePackage
err := ctx.Table("StoragePackage").Where("StorageID = ? AND PackageID = ? AND UserID = ?", storageID, packageID, userID).First(&ret).Error
return ret, err
}

func (*StoragePackageDB) GetAllByStorageAndPackageID(ctx SQLContext, storageID cdssdk.StorageID, packageID cdssdk.PackageID) ([]model.StoragePackage, error) {
var ret []model.StoragePackage
err := ctx.Table("StoragePackage").Where("StorageID = ? AND PackageID = ?", storageID, packageID).Find(&ret).Error
return ret, err
}

func (*StoragePackageDB) GetAllByStorageID(ctx SQLContext, storageID cdssdk.StorageID) ([]model.StoragePackage, error) {
var ret []model.StoragePackage
err := ctx.Table("StoragePackage").Where("StorageID = ?", storageID).Find(&ret).Error
return ret, err
}

func (*StoragePackageDB) CreateOrUpdate(ctx SQLContext, storageID cdssdk.StorageID, packageID cdssdk.PackageID, userID cdssdk.UserID) error {
sql := "INSERT INTO StoragePackage (StorageID, PackageID, UserID, State) VALUES (?, ?, ?, ?) " +
"ON DUPLICATE KEY UPDATE State = VALUES(State)"
return ctx.Exec(sql, storageID, packageID, userID, model.StoragePackageStateNormal).Error
}

func (*StoragePackageDB) ChangeState(ctx SQLContext, storageID cdssdk.StorageID, packageID cdssdk.PackageID, userID cdssdk.UserID, state string) error {
return ctx.Table("StoragePackage").Where("StorageID = ? AND PackageID = ? AND UserID = ?", storageID, packageID, userID).Update("State", state).Error
}

// SetStateNormal 将状态设置为Normal,如果记录状态是Deleted,则不进行操作
func (*StoragePackageDB) SetStateNormal(ctx SQLContext, storageID cdssdk.StorageID, packageID cdssdk.PackageID, userID cdssdk.UserID) error {
return ctx.Table("StoragePackage").Where("StorageID = ? AND PackageID = ? AND UserID = ? AND State <> ?",
storageID, packageID, userID, model.StoragePackageStateDeleted).Update("State", model.StoragePackageStateNormal).Error
}

func (*StoragePackageDB) SetAllPackageState(ctx SQLContext, packageID cdssdk.PackageID, state string) (int64, error) {
ret := ctx.Table("StoragePackage").Where("PackageID = ?", packageID).Update("State", state)
if err := ret.Error; err != nil {
return 0, err
}
return ret.RowsAffected, nil
}

// SetAllPackageOutdated 将Storage中指定对象设置为已过期。只会设置Normal状态的对象
func (*StoragePackageDB) SetAllPackageOutdated(ctx SQLContext, packageID cdssdk.PackageID) (int64, error) {
ret := ctx.Table("StoragePackage").Where("State = ? AND PackageID = ?", model.StoragePackageStateNormal, packageID).Update("State", model.StoragePackageStateOutdated)
if err := ret.Error; err != nil {
return 0, err
}
return ret.RowsAffected, nil
}

func (db *StoragePackageDB) SetAllPackageDeleted(ctx SQLContext, packageID cdssdk.PackageID) (int64, error) {
return db.SetAllPackageState(ctx, packageID, model.StoragePackageStateDeleted)
}

func (*StoragePackageDB) Delete(ctx SQLContext, storageID cdssdk.StorageID, packageID cdssdk.PackageID, userID cdssdk.UserID) error {
return ctx.Table("StoragePackage").Where("StorageID = ? AND PackageID = ? AND UserID = ?", storageID, packageID, userID).Delete(&model.StoragePackage{}).Error
}

// FindPackageStorages 查询存储了指定对象的Storage
func (*StoragePackageDB) FindPackageStorages(ctx SQLContext, packageID cdssdk.PackageID) ([]model.Storage, error) {
var ret []model.Storage
err := ctx.Table("StoragePackage").Select("Storage.*").
Joins("JOIN Storage ON StoragePackage.StorageID = Storage.StorageID").
Where("PackageID = ?", packageID).
Scan(&ret).Error
return ret, err
}

+ 20
- 0
common/pkgs/db2/user.go View File

@@ -0,0 +1,20 @@
package db2

import (
cdssdk "gitlink.org.cn/cloudream/common/sdks/storage"
"gitlink.org.cn/cloudream/storage/common/pkgs/db/model"
)

type UserDB struct {
*DB
}

func (db *DB) User() *UserDB {
return &UserDB{DB: db}
}

func (db *UserDB) GetByID(ctx SQLContext, userID cdssdk.UserID) (model.User, error) {
var ret model.User
err := ctx.Table("User").Where("UserID = ?", userID).First(&ret).Error
return ret, err
}

+ 22
- 0
common/pkgs/db2/user_bucket.go View File

@@ -0,0 +1,22 @@
package db2

import (
cdssdk "gitlink.org.cn/cloudream/common/sdks/storage"
"gitlink.org.cn/cloudream/storage/common/pkgs/db/model"
)

type UserBucketDB struct {
*DB
}

func (db *DB) UserBucket() *UserBucketDB {
return &UserBucketDB{DB: db}
}

func (*UserBucketDB) Create(ctx SQLContext, userID int64, bucketID int64) error {
userBucket := model.UserBucket{
UserID: cdssdk.UserID(userID),
BucketID: cdssdk.BucketID(bucketID),
}
return ctx.Table("UserBucket").Create(&userBucket).Error
}

+ 57
- 0
common/pkgs/db2/utils.go View File

@@ -0,0 +1,57 @@
package db2

import (
"gorm.io/gorm"
)

const (
maxPlaceholderCount = 65535
)

func BatchNamedExec[T any](ctx SQLContext, sql string, argCnt int, arr []T, callback func(result *gorm.DB) bool) error {
if argCnt == 0 {
result := ctx.Exec(sql, toInterfaceSlice(arr)...)
if result.Error != nil {
return result.Error
}

if callback != nil {
callback(result)
}

return nil
}

batchSize := maxPlaceholderCount / argCnt
for len(arr) > 0 {
curBatchSize := min(batchSize, len(arr))

result := ctx.Exec(sql, toInterfaceSlice(arr[:curBatchSize])...)
if result.Error != nil {
return result.Error
}
if callback != nil && !callback(result) {
return nil
}

arr = arr[curBatchSize:]
}

return nil
}

// 将 []T 转换为 []interface{}
func toInterfaceSlice[T any](arr []T) []interface{} {
interfaceSlice := make([]interface{}, len(arr))
for i, v := range arr {
interfaceSlice[i] = v
}
return interfaceSlice
}

func min(a, b int) int {
if a < b {
return a
}
return b
}

+ 3
- 1
common/pkgs/storage/temp/local.go View File

@@ -4,10 +4,12 @@ import cdssdk "gitlink.org.cn/cloudream/common/sdks/storage"

type LocalTempStore struct {
cfg cdssdk.BypassUploadFeature
stg cdssdk.Storage
}

func NewLocalTempStore(cfg cdssdk.BypassUploadFeature) *LocalTempStore {
func NewLocalTempStore(stg cdssdk.Storage, cfg cdssdk.BypassUploadFeature) *LocalTempStore {
return &LocalTempStore{
cfg: cfg,
stg: stg,
}
}

+ 11
- 12
coordinator/internal/mq/bucket.go View File

@@ -1,10 +1,9 @@
package mq

import (
"database/sql"
"fmt"
"gitlink.org.cn/cloudream/storage/common/pkgs/db2"

"github.com/jmoiron/sqlx"
"gitlink.org.cn/cloudream/common/consts/errorcode"
"gitlink.org.cn/cloudream/common/pkgs/logger"
"gitlink.org.cn/cloudream/common/pkgs/mq"
@@ -19,7 +18,7 @@ func (svc *Service) GetBucket(userID cdssdk.UserID, bucketID cdssdk.BucketID) (m
}

func (svc *Service) GetBucketByName(msg *coormq.GetBucketByName) (*coormq.GetBucketByNameResp, *mq.CodeMessage) {
bucket, err := svc.db.Bucket().GetUserBucketByName(svc.db.SQLCtx(), msg.UserID, msg.Name)
bucket, err := svc.db2.Bucket().GetUserBucketByName(svc.db2.DefCtx(), msg.UserID, msg.Name)
if err != nil {
logger.WithField("UserID", msg.UserID).
WithField("Name", msg.Name).
@@ -31,7 +30,7 @@ func (svc *Service) GetBucketByName(msg *coormq.GetBucketByName) (*coormq.GetBuc
}

func (svc *Service) GetUserBuckets(msg *coormq.GetUserBuckets) (*coormq.GetUserBucketsResp, *mq.CodeMessage) {
buckets, err := svc.db.Bucket().GetUserBuckets(svc.db.SQLCtx(), msg.UserID)
buckets, err := svc.db2.Bucket().GetUserBuckets(svc.db2.DefCtx(), msg.UserID)

if err != nil {
logger.WithField("UserID", msg.UserID).
@@ -43,7 +42,7 @@ func (svc *Service) GetUserBuckets(msg *coormq.GetUserBuckets) (*coormq.GetUserB
}

func (svc *Service) GetBucketPackages(msg *coormq.GetBucketPackages) (*coormq.GetBucketPackagesResp, *mq.CodeMessage) {
packages, err := svc.db.Package().GetBucketPackages(svc.db.SQLCtx(), msg.UserID, msg.BucketID)
packages, err := svc.db2.Package().GetBucketPackages(svc.db2.DefCtx(), msg.UserID, msg.BucketID)

if err != nil {
logger.WithField("UserID", msg.UserID).
@@ -57,18 +56,18 @@ func (svc *Service) GetBucketPackages(msg *coormq.GetBucketPackages) (*coormq.Ge

func (svc *Service) CreateBucket(msg *coormq.CreateBucket) (*coormq.CreateBucketResp, *mq.CodeMessage) {
var bucket cdssdk.Bucket
err := svc.db.DoTx(sql.LevelSerializable, func(tx *sqlx.Tx) error {
_, err := svc.db.User().GetByID(tx, msg.UserID)
err := svc.db2.DoTx(func(tx db2.SQLContext) error {
_, err := svc.db2.User().GetByID(tx, msg.UserID)
if err != nil {
return fmt.Errorf("getting user by id: %w", err)
}

bucketID, err := svc.db.Bucket().Create(tx, msg.UserID, msg.BucketName)
bucketID, err := svc.db2.Bucket().Create(tx, msg.UserID, msg.BucketName)
if err != nil {
return fmt.Errorf("creating bucket: %w", err)
}

bucket, err = svc.db.Bucket().GetByID(tx, bucketID)
bucket, err = svc.db2.Bucket().GetByID(tx, bucketID)
if err != nil {
return fmt.Errorf("getting bucket by id: %w", err)
}
@@ -85,13 +84,13 @@ func (svc *Service) CreateBucket(msg *coormq.CreateBucket) (*coormq.CreateBucket
}

func (svc *Service) DeleteBucket(msg *coormq.DeleteBucket) (*coormq.DeleteBucketResp, *mq.CodeMessage) {
err := svc.db.DoTx(sql.LevelSerializable, func(tx *sqlx.Tx) error {
isAvai, _ := svc.db.Bucket().IsAvailable(tx, msg.BucketID, msg.UserID)
err := svc.db2.DoTx(func(tx db2.SQLContext) error {
isAvai, _ := svc.db2.Bucket().IsAvailable(tx, msg.BucketID, msg.UserID)
if !isAvai {
return fmt.Errorf("bucket is not avaiable to the user")
}

err := svc.db.Bucket().Delete(tx, msg.BucketID)
err := svc.db2.Bucket().Delete(tx, msg.BucketID)
if err != nil {
return fmt.Errorf("deleting bucket: %w", err)
}


+ 9
- 10
coordinator/internal/mq/cache.go View File

@@ -1,10 +1,9 @@
package mq

import (
"database/sql"
"fmt"
"gitlink.org.cn/cloudream/storage/common/pkgs/db2"

"github.com/jmoiron/sqlx"
"gitlink.org.cn/cloudream/common/consts/errorcode"
"gitlink.org.cn/cloudream/common/pkgs/logger"
"gitlink.org.cn/cloudream/common/pkgs/mq"
@@ -12,18 +11,18 @@ import (
)

func (svc *Service) CachePackageMoved(msg *coormq.CachePackageMoved) (*coormq.CachePackageMovedResp, *mq.CodeMessage) {
err := svc.db.DoTx(sql.LevelSerializable, func(tx *sqlx.Tx) error {
_, err := svc.db.Package().GetByID(tx, msg.PackageID)
err := svc.db2.DoTx(func(tx db2.SQLContext) error {
_, err := svc.db2.Package().GetByID(tx, msg.PackageID)
if err != nil {
return fmt.Errorf("getting package by id: %w", err)
}

_, err = svc.db.Node().GetByID(tx, msg.StorageID)
_, err = svc.db2.Node().GetByID(tx, msg.StorageID)
if err != nil {
return fmt.Errorf("getting node by id: %w", err)
}

err = svc.db.PinnedObject().CreateFromPackage(tx, msg.PackageID, msg.StorageID)
err = svc.db2.PinnedObject().CreateFromPackage(tx, msg.PackageID, msg.StorageID)
if err != nil {
return fmt.Errorf("creating pinned objects from package: %w", err)
}
@@ -39,18 +38,18 @@ func (svc *Service) CachePackageMoved(msg *coormq.CachePackageMoved) (*coormq.Ca
}

func (svc *Service) CacheRemovePackage(msg *coormq.CacheRemovePackage) (*coormq.CacheRemovePackageResp, *mq.CodeMessage) {
err := svc.db.DoTx(sql.LevelSerializable, func(tx *sqlx.Tx) error {
_, err := svc.db.Package().GetByID(tx, msg.PackageID)
err := svc.db2.DoTx(func(tx db2.SQLContext) error {
_, err := svc.db2.Package().GetByID(tx, msg.PackageID)
if err != nil {
return fmt.Errorf("getting package by id: %w", err)
}

_, err = svc.db.Node().GetByID(tx, msg.NodeID)
_, err = svc.db2.Node().GetByID(tx, msg.NodeID)
if err != nil {
return fmt.Errorf("getting node by id: %w", err)
}

err = svc.db.PinnedObject().DeleteInPackageAtNode(tx, msg.PackageID, msg.NodeID)
err = svc.db2.PinnedObject().DeleteInPackageAtNode(tx, msg.PackageID, msg.NodeID)
if err != nil {
return fmt.Errorf("delete pinned objects in package at node: %w", err)
}


+ 6
- 7
coordinator/internal/mq/node.go View File

@@ -1,10 +1,9 @@
package mq

import (
"database/sql"
"fmt"
"gitlink.org.cn/cloudream/storage/common/pkgs/db2"

"github.com/jmoiron/sqlx"
"gitlink.org.cn/cloudream/common/consts/errorcode"
"gitlink.org.cn/cloudream/common/pkgs/logger"
"gitlink.org.cn/cloudream/common/pkgs/mq"
@@ -13,7 +12,7 @@ import (
)

func (svc *Service) GetUserNodes(msg *coormq.GetUserNodes) (*coormq.GetUserNodesResp, *mq.CodeMessage) {
nodes, err := svc.db.Node().GetUserNodes(svc.db.SQLCtx(), msg.UserID)
nodes, err := svc.db2.Node().GetUserNodes(svc.db2.DefCtx(), msg.UserID)
if err != nil {
logger.WithField("UserID", msg.UserID).
Warnf("query user nodes failed, err: %s", err.Error())
@@ -52,7 +51,7 @@ func (svc *Service) GetNodes(msg *coormq.GetNodes) (*coormq.GetNodesResp, *mq.Co
}

func (svc *Service) GetNodeConnectivities(msg *coormq.GetNodeConnectivities) (*coormq.GetNodeConnectivitiesResp, *mq.CodeMessage) {
cons, err := svc.db.NodeConnectivity().BatchGetByFromNode(svc.db.SQLCtx(), msg.NodeIDs)
cons, err := svc.db2.NodeConnectivity().BatchGetByFromNode(svc.db2.DefCtx(), msg.NodeIDs)
if err != nil {
logger.Warnf("batch get node connectivities by from node: %s", err.Error())
return nil, mq.Failed(errorcode.OperationFailed, "batch get node connectivities by from node failed")
@@ -62,9 +61,9 @@ func (svc *Service) GetNodeConnectivities(msg *coormq.GetNodeConnectivities) (*c
}

func (svc *Service) UpdateNodeConnectivities(msg *coormq.UpdateNodeConnectivities) (*coormq.UpdateNodeConnectivitiesResp, *mq.CodeMessage) {
err := svc.db.DoTx(sql.LevelSerializable, func(tx *sqlx.Tx) error {
err := svc.db2.DoTx(func(tx db2.SQLContext) error {
// 只有发起节点和目的节点都存在,才能插入这条记录到数据库
allNodes, err := svc.db.Node().GetAllNodes(tx)
allNodes, err := svc.db2.Node().GetAllNodes(tx)
if err != nil {
return fmt.Errorf("getting all nodes: %w", err)
}
@@ -81,7 +80,7 @@ func (svc *Service) UpdateNodeConnectivities(msg *coormq.UpdateNodeConnectivitie
}
}

err = svc.db.NodeConnectivity().BatchUpdateOrCreate(tx, avaiCons)
err = svc.db2.NodeConnectivity().BatchUpdateOrCreate(tx, avaiCons)
if err != nil {
return fmt.Errorf("batch update or create node connectivities: %s", err)
}


+ 30
- 30
coordinator/internal/mq/object.go View File

@@ -3,8 +3,8 @@ package mq
import (
"database/sql"
"fmt"
"gitlink.org.cn/cloudream/storage/common/pkgs/db2"

"github.com/jmoiron/sqlx"
"github.com/samber/lo"
"gitlink.org.cn/cloudream/common/consts/errorcode"
"gitlink.org.cn/cloudream/common/pkgs/logger"
@@ -18,13 +18,13 @@ import (

func (svc *Service) GetPackageObjects(msg *coormq.GetPackageObjects) (*coormq.GetPackageObjectsResp, *mq.CodeMessage) {
var objs []cdssdk.Object
err := svc.db.DoTx(sql.LevelSerializable, func(tx *sqlx.Tx) error {
_, err := svc.db.Package().GetUserPackage(tx, msg.UserID, msg.PackageID)
err := svc.db2.DoTx(func(tx db2.SQLContext) error {
_, err := svc.db2.Package().GetUserPackage(tx, msg.UserID, msg.PackageID)
if err != nil {
return fmt.Errorf("getting package by id: %w", err)
}

objs, err = svc.db.Object().GetPackageObjects(svc.db.SQLCtx(), msg.PackageID)
objs, err = svc.db2.Object().GetPackageObjects(tx, msg.PackageID)
if err != nil {
return fmt.Errorf("getting package objects: %w", err)
}
@@ -44,14 +44,14 @@ func (svc *Service) GetPackageObjects(msg *coormq.GetPackageObjects) (*coormq.Ge
func (svc *Service) GetPackageObjectDetails(msg *coormq.GetPackageObjectDetails) (*coormq.GetPackageObjectDetailsResp, *mq.CodeMessage) {
var details []stgmod.ObjectDetail
// 必须放在事务里进行,因为GetPackageBlockDetails是由多次数据库操作组成,必须保证数据的一致性
err := svc.db.DoTx(sql.LevelSerializable, func(tx *sqlx.Tx) error {
err := svc.db2.DoTx(func(tx db2.SQLContext) error {
var err error
_, err = svc.db.Package().GetByID(tx, msg.PackageID)
_, err = svc.db2.Package().GetByID(tx, msg.PackageID)
if err != nil {
return fmt.Errorf("getting package by id: %w", err)
}

details, err = svc.db.Object().GetPackageObjectDetails(tx, msg.PackageID)
details, err = svc.db2.Object().GetPackageObjectDetails(tx, msg.PackageID)
if err != nil {
return fmt.Errorf("getting package block details: %w", err)
}
@@ -69,13 +69,13 @@ func (svc *Service) GetPackageObjectDetails(msg *coormq.GetPackageObjectDetails)

func (svc *Service) GetObjectDetails(msg *coormq.GetObjectDetails) (*coormq.GetObjectDetailsResp, *mq.CodeMessage) {
details := make([]*stgmod.ObjectDetail, len(msg.ObjectIDs))
err := svc.db.DoTx(sql.LevelSerializable, func(tx *sqlx.Tx) error {
err := svc.db2.DoTx(func(tx db2.SQLContext) error {
var err error

msg.ObjectIDs = sort2.SortAsc(msg.ObjectIDs)

// 根据ID依次查询Object,ObjectBlock,PinnedObject,并根据升序的特点进行合并
objs, err := svc.db.Object().BatchGet(tx, msg.ObjectIDs)
objs, err := svc.db2.Object().BatchGet(tx, msg.ObjectIDs)
if err != nil {
return fmt.Errorf("batch get objects: %w", err)
}
@@ -98,7 +98,7 @@ func (svc *Service) GetObjectDetails(msg *coormq.GetObjectDetails) (*coormq.GetO
}

// 查询合并
blocks, err := svc.db.ObjectBlock().BatchGetByObjectID(tx, msg.ObjectIDs)
blocks, err := svc.db2.ObjectBlock().BatchGetByObjectID(tx, msg.ObjectIDs)
if err != nil {
return fmt.Errorf("batch get object blocks: %w", err)
}
@@ -121,7 +121,7 @@ func (svc *Service) GetObjectDetails(msg *coormq.GetObjectDetails) (*coormq.GetO
}

// 查询合并
pinneds, err := svc.db.PinnedObject().BatchGetByObjectID(tx, msg.ObjectIDs)
pinneds, err := svc.db2.PinnedObject().BatchGetByObjectID(tx, msg.ObjectIDs)
if err != nil {
return fmt.Errorf("batch get pinned objects: %w", err)
}
@@ -154,8 +154,8 @@ func (svc *Service) GetObjectDetails(msg *coormq.GetObjectDetails) (*coormq.GetO
}

func (svc *Service) UpdateObjectRedundancy(msg *coormq.UpdateObjectRedundancy) (*coormq.UpdateObjectRedundancyResp, *mq.CodeMessage) {
err := svc.db.DoTx(sql.LevelSerializable, func(tx *sqlx.Tx) error {
return svc.db.Object().BatchUpdateRedundancy(tx, msg.Updatings)
err := svc.db2.DoTx(func(tx db2.SQLContext) error {
return svc.db2.Object().BatchUpdateRedundancy(tx, msg.Updatings)
})
if err != nil {
logger.Warnf("batch updating redundancy: %s", err.Error())
@@ -167,7 +167,7 @@ func (svc *Service) UpdateObjectRedundancy(msg *coormq.UpdateObjectRedundancy) (

func (svc *Service) UpdateObjectInfos(msg *coormq.UpdateObjectInfos) (*coormq.UpdateObjectInfosResp, *mq.CodeMessage) {
var sucs []cdssdk.ObjectID
err := svc.db.DoTx(sql.LevelSerializable, func(tx *sqlx.Tx) error {
err := svc.db2.DoTx(func(tx db2.SQLContext) error {
msg.Updatings = sort2.Sort(msg.Updatings, func(o1, o2 cdsapi.UpdatingObject) int {
return sort2.Cmp(o1.ObjectID, o2.ObjectID)
})
@@ -177,7 +177,7 @@ func (svc *Service) UpdateObjectInfos(msg *coormq.UpdateObjectInfos) (*coormq.Up
objIDs[i] = obj.ObjectID
}

oldObjs, err := svc.db.Object().BatchGet(tx, objIDs)
oldObjs, err := svc.db2.Object().BatchGet(tx, objIDs)
if err != nil {
return fmt.Errorf("batch getting objects: %w", err)
}
@@ -197,7 +197,7 @@ func (svc *Service) UpdateObjectInfos(msg *coormq.UpdateObjectInfos) (*coormq.Up
avaiUpdatings[i].ApplyTo(&newObjs[i])
}

err = svc.db.Object().BatchUpsertByPackagePath(tx, newObjs)
err = svc.db2.Object().BatchUpsertByPackagePath(tx, newObjs)
if err != nil {
return fmt.Errorf("batch create or update: %w", err)
}
@@ -237,7 +237,7 @@ func pickByObjectIDs[T any](objs []T, objIDs []cdssdk.ObjectID, getID func(T) cd

func (svc *Service) MoveObjects(msg *coormq.MoveObjects) (*coormq.MoveObjectsResp, *mq.CodeMessage) {
var sucs []cdssdk.ObjectID
err := svc.db.DoTx(sql.LevelSerializable, func(tx *sqlx.Tx) error {
err := svc.db2.DoTx(func(tx db2.SQLContext) error {
msg.Movings = sort2.Sort(msg.Movings, func(o1, o2 cdsapi.MovingObject) int {
return sort2.Cmp(o1.ObjectID, o2.ObjectID)
})
@@ -247,7 +247,7 @@ func (svc *Service) MoveObjects(msg *coormq.MoveObjects) (*coormq.MoveObjectsRes
objIDs[i] = obj.ObjectID
}

oldObjs, err := svc.db.Object().BatchGet(tx, objIDs)
oldObjs, err := svc.db2.Object().BatchGet(tx, objIDs)
if err != nil {
return fmt.Errorf("batch getting objects: %w", err)
}
@@ -291,7 +291,7 @@ func (svc *Service) MoveObjects(msg *coormq.MoveObjects) (*coormq.MoveObjectsRes
}
newObjs = append(newObjs, ensuredObjs...)

err = svc.db.Object().BatchUpert(tx, newObjs)
err = svc.db2.Object().BatchUpert(tx, newObjs)
if err != nil {
return fmt.Errorf("batch create or update: %w", err)
}
@@ -307,7 +307,7 @@ func (svc *Service) MoveObjects(msg *coormq.MoveObjects) (*coormq.MoveObjectsRes
return mq.ReplyOK(coormq.RespMoveObjects(sucs))
}

func (svc *Service) ensurePackageChangedObjects(tx *sqlx.Tx, userID cdssdk.UserID, objs []cdssdk.Object) ([]cdssdk.Object, error) {
func (svc *Service) ensurePackageChangedObjects(tx db2.SQLContext, userID cdssdk.UserID, objs []cdssdk.Object) ([]cdssdk.Object, error) {
if len(objs) == 0 {
return nil, nil
}
@@ -338,7 +338,7 @@ func (svc *Service) ensurePackageChangedObjects(tx *sqlx.Tx, userID cdssdk.UserI

var willUpdateObjs []cdssdk.Object
for _, pkg := range packages {
_, err := svc.db.Package().GetUserPackage(tx, userID, pkg.PackageID)
_, err := svc.db2.Package().GetUserPackage(tx, userID, pkg.PackageID)
if err == sql.ErrNoRows {
continue
}
@@ -346,7 +346,7 @@ func (svc *Service) ensurePackageChangedObjects(tx *sqlx.Tx, userID cdssdk.UserI
return nil, fmt.Errorf("getting user package by id: %w", err)
}

existsObjs, err := svc.db.Object().BatchGetByPackagePath(tx, pkg.PackageID, lo.Keys(pkg.ObjectByPath))
existsObjs, err := svc.db2.Object().BatchGetByPackagePath(tx, pkg.PackageID, lo.Keys(pkg.ObjectByPath))
if err != nil {
return nil, fmt.Errorf("batch getting objects by package path: %w", err)
}
@@ -368,7 +368,7 @@ func (svc *Service) ensurePackageChangedObjects(tx *sqlx.Tx, userID cdssdk.UserI
return willUpdateObjs, nil
}

func (svc *Service) ensurePathChangedObjects(tx *sqlx.Tx, userID cdssdk.UserID, objs []cdssdk.Object) ([]cdssdk.Object, error) {
func (svc *Service) ensurePathChangedObjects(tx db2.SQLContext, userID cdssdk.UserID, objs []cdssdk.Object) ([]cdssdk.Object, error) {
if len(objs) == 0 {
return nil, nil
}
@@ -384,7 +384,7 @@ func (svc *Service) ensurePathChangedObjects(tx *sqlx.Tx, userID cdssdk.UserID,

}

_, err := svc.db.Package().GetUserPackage(tx, userID, objs[0].PackageID)
_, err := svc.db2.Package().GetUserPackage(tx, userID, objs[0].PackageID)
if err == sql.ErrNoRows {
return nil, nil
}
@@ -392,7 +392,7 @@ func (svc *Service) ensurePathChangedObjects(tx *sqlx.Tx, userID cdssdk.UserID,
return nil, fmt.Errorf("getting user package by id: %w", err)
}

existsObjs, err := svc.db.Object().BatchGetByPackagePath(tx, objs[0].PackageID, lo.Map(objs, func(obj cdssdk.Object, idx int) string { return obj.Path }))
existsObjs, err := svc.db2.Object().BatchGetByPackagePath(tx, objs[0].PackageID, lo.Map(objs, func(obj cdssdk.Object, idx int) string { return obj.Path }))
if err != nil {
return nil, fmt.Errorf("batch getting objects by package path: %w", err)
}
@@ -414,23 +414,23 @@ func (svc *Service) ensurePathChangedObjects(tx *sqlx.Tx, userID cdssdk.UserID,
}

func (svc *Service) DeleteObjects(msg *coormq.DeleteObjects) (*coormq.DeleteObjectsResp, *mq.CodeMessage) {
err := svc.db.DoTx(sql.LevelSerializable, func(tx *sqlx.Tx) error {
err := svc.db.Object().BatchDelete(tx, msg.ObjectIDs)
err := svc.db2.DoTx(func(tx db2.SQLContext) error {
err := svc.db2.Object().BatchDelete(tx, msg.ObjectIDs)
if err != nil {
return fmt.Errorf("batch deleting objects: %w", err)
}

err = svc.db.ObjectBlock().BatchDeleteByObjectID(tx, msg.ObjectIDs)
err = svc.db2.ObjectBlock().BatchDeleteByObjectID(tx, msg.ObjectIDs)
if err != nil {
return fmt.Errorf("batch deleting object blocks: %w", err)
}

err = svc.db.PinnedObject().BatchDeleteByObjectID(tx, msg.ObjectIDs)
err = svc.db2.PinnedObject().BatchDeleteByObjectID(tx, msg.ObjectIDs)
if err != nil {
return fmt.Errorf("batch deleting pinned objects: %w", err)
}

err = svc.db.ObjectAccessStat().BatchDeleteByObjectID(tx, msg.ObjectIDs)
err = svc.db2.ObjectAccessStat().BatchDeleteByObjectID(tx, msg.ObjectIDs)
if err != nil {
return fmt.Errorf("batch deleting object access stats: %w", err)
}


+ 24
- 24
coordinator/internal/mq/package.go View File

@@ -3,9 +3,9 @@ package mq
import (
"database/sql"
"fmt"
"gitlink.org.cn/cloudream/storage/common/pkgs/db2"
"sort"

"github.com/jmoiron/sqlx"
"gitlink.org.cn/cloudream/common/consts/errorcode"
"gitlink.org.cn/cloudream/common/pkgs/logger"
"gitlink.org.cn/cloudream/common/pkgs/mq"
@@ -14,7 +14,7 @@ import (
)

func (svc *Service) GetPackage(msg *coormq.GetPackage) (*coormq.GetPackageResp, *mq.CodeMessage) {
pkg, err := svc.db.Package().GetByID(svc.db.SQLCtx(), msg.PackageID)
pkg, err := svc.db2.Package().GetByID(svc.db2.DefCtx(), msg.PackageID)
if err != nil {
logger.WithField("PackageID", msg.PackageID).
Warnf("get package: %s", err.Error())
@@ -26,7 +26,7 @@ func (svc *Service) GetPackage(msg *coormq.GetPackage) (*coormq.GetPackageResp,
}

func (svc *Service) GetPackageByName(msg *coormq.GetPackageByName) (*coormq.GetPackageByNameResp, *mq.CodeMessage) {
pkg, err := svc.db.Package().GetUserPackageByName(svc.db.SQLCtx(), msg.UserID, msg.BucketName, msg.PackageName)
pkg, err := svc.db2.Package().GetUserPackageByName(svc.db2.DefCtx(), msg.UserID, msg.BucketName, msg.PackageName)
if err != nil {
logger.WithField("UserID", msg.UserID).
WithField("BucketName", msg.BucketName).
@@ -45,20 +45,20 @@ func (svc *Service) GetPackageByName(msg *coormq.GetPackageByName) (*coormq.GetP

func (svc *Service) CreatePackage(msg *coormq.CreatePackage) (*coormq.CreatePackageResp, *mq.CodeMessage) {
var pkg cdssdk.Package
err := svc.db.DoTx(sql.LevelSerializable, func(tx *sqlx.Tx) error {
err := svc.db2.DoTx(func(tx db2.SQLContext) error {
var err error

isAvai, _ := svc.db.Bucket().IsAvailable(tx, msg.BucketID, msg.UserID)
isAvai, _ := svc.db2.Bucket().IsAvailable(tx, msg.BucketID, msg.UserID)
if !isAvai {
return fmt.Errorf("bucket is not avaiable to the user")
}

pkgID, err := svc.db.Package().Create(tx, msg.BucketID, msg.Name)
pkgID, err := svc.db2.Package().Create(tx, msg.BucketID, msg.Name)
if err != nil {
return fmt.Errorf("creating package: %w", err)
}

pkg, err = svc.db.Package().GetByID(tx, pkgID)
pkg, err = svc.db2.Package().GetByID(tx, pkgID)
if err != nil {
return fmt.Errorf("getting package by id: %w", err)
}
@@ -77,22 +77,22 @@ func (svc *Service) CreatePackage(msg *coormq.CreatePackage) (*coormq.CreatePack

func (svc *Service) UpdatePackage(msg *coormq.UpdatePackage) (*coormq.UpdatePackageResp, *mq.CodeMessage) {
var added []cdssdk.Object
err := svc.db.DoTx(sql.LevelSerializable, func(tx *sqlx.Tx) error {
_, err := svc.db.Package().GetByID(tx, msg.PackageID)
err := svc.db2.DoTx(func(tx db2.SQLContext) error {
_, err := svc.db2.Package().GetByID(tx, msg.PackageID)
if err != nil {
return fmt.Errorf("getting package by id: %w", err)
}

// 先执行删除操作
if len(msg.Deletes) > 0 {
if err := svc.db.Object().BatchDelete(tx, msg.Deletes); err != nil {
if err := svc.db2.Object().BatchDelete(tx, msg.Deletes); err != nil {
return fmt.Errorf("deleting objects: %w", err)
}
}

// 再执行添加操作
if len(msg.Adds) > 0 {
ad, err := svc.db.Object().BatchAdd(tx, msg.PackageID, msg.Adds)
ad, err := svc.db2.Object().BatchAdd(tx, msg.PackageID, msg.Adds)
if err != nil {
return fmt.Errorf("adding objects: %w", err)
}
@@ -110,25 +110,25 @@ func (svc *Service) UpdatePackage(msg *coormq.UpdatePackage) (*coormq.UpdatePack
}

func (svc *Service) DeletePackage(msg *coormq.DeletePackage) (*coormq.DeletePackageResp, *mq.CodeMessage) {
err := svc.db.DoTx(sql.LevelSerializable, func(tx *sqlx.Tx) error {
isAvai, _ := svc.db.Package().IsAvailable(tx, msg.UserID, msg.PackageID)
err := svc.db2.DoTx(func(tx db2.SQLContext) error {
isAvai, _ := svc.db2.Package().IsAvailable(tx, msg.UserID, msg.PackageID)
if !isAvai {
return fmt.Errorf("package is not available to the user")
}

err := svc.db.Package().SoftDelete(tx, msg.PackageID)
err := svc.db2.Package().SoftDelete(tx, msg.PackageID)
if err != nil {
return fmt.Errorf("soft delete package: %w", err)
}

err = svc.db.Package().DeleteUnused(tx, msg.PackageID)
err = svc.db2.Package().DeleteUnused(tx, msg.PackageID)
if err != nil {
logger.WithField("UserID", msg.UserID).
WithField("PackageID", msg.PackageID).
Warnf("deleting unused package: %w", err.Error())
}

err = svc.db.PackageAccessStat().DeleteByPackageID(tx, msg.PackageID)
err = svc.db2.PackageAccessStat().DeleteByPackageID(tx, msg.PackageID)
if err != nil {
logger.WithField("UserID", msg.UserID).
WithField("PackageID", msg.PackageID).
@@ -148,7 +148,7 @@ func (svc *Service) DeletePackage(msg *coormq.DeletePackage) (*coormq.DeletePack
}

func (svc *Service) GetPackageCachedNodes(msg *coormq.GetPackageCachedNodes) (*coormq.GetPackageCachedNodesResp, *mq.CodeMessage) {
isAva, err := svc.db.Package().IsAvailable(svc.db.SQLCtx(), msg.UserID, msg.PackageID)
isAva, err := svc.db2.Package().IsAvailable(svc.db2.DefCtx(), msg.UserID, msg.PackageID)
if err != nil {
logger.WithField("UserID", msg.UserID).
WithField("PackageID", msg.PackageID).
@@ -163,7 +163,7 @@ func (svc *Service) GetPackageCachedNodes(msg *coormq.GetPackageCachedNodes) (*c
}

// 这个函数只是统计哪些节点缓存了Package中的数据,不需要多么精确,所以可以不用事务
objDetails, err := svc.db.Object().GetPackageObjectDetails(svc.db.SQLCtx(), msg.PackageID)
objDetails, err := svc.db2.Object().GetPackageObjectDetails(svc.db2.DefCtx(), msg.PackageID)
if err != nil {
logger.WithField("PackageID", msg.PackageID).
Warnf("get package block details: %s", err.Error())
@@ -202,7 +202,7 @@ func (svc *Service) GetPackageCachedNodes(msg *coormq.GetPackageCachedNodes) (*c
}

func (svc *Service) GetPackageLoadedNodes(msg *coormq.GetPackageLoadedNodes) (*coormq.GetPackageLoadedNodesResp, *mq.CodeMessage) {
storages, err := svc.db.StoragePackage().FindPackageStorages(svc.db.SQLCtx(), msg.PackageID)
storages, err := svc.db2.StoragePackage().FindPackageStorages(svc.db2.DefCtx(), msg.PackageID)
if err != nil {
logger.WithField("PackageID", msg.PackageID).
Warnf("get storages by packageID failed, err: %s", err.Error())
@@ -229,13 +229,13 @@ func (svc *Service) AddAccessStat(msg *coormq.AddAccessStat) {
objIDs[i] = e.ObjectID
}

err := svc.db.DoTx(sql.LevelSerializable, func(tx *sqlx.Tx) error {
avaiPkgIDs, err := svc.db.Package().BatchTestPackageID(tx, pkgIDs)
err := svc.db2.DoTx(func(tx db2.SQLContext) error {
avaiPkgIDs, err := svc.db2.Package().BatchTestPackageID(tx, pkgIDs)
if err != nil {
return fmt.Errorf("batch test package id: %w", err)
}

avaiObjIDs, err := svc.db.Object().BatchTestObjectID(tx, objIDs)
avaiObjIDs, err := svc.db2.Object().BatchTestObjectID(tx, objIDs)
if err != nil {
return fmt.Errorf("batch test object id: %w", err)
}
@@ -248,12 +248,12 @@ func (svc *Service) AddAccessStat(msg *coormq.AddAccessStat) {
}

if len(willAdds) > 0 {
err := svc.db.PackageAccessStat().BatchAddCounter(tx, willAdds)
err := svc.db2.PackageAccessStat().BatchAddCounter(tx, willAdds)
if err != nil {
return fmt.Errorf("batch add package access stat counter: %w", err)
}

err = svc.db.ObjectAccessStat().BatchAddCounter(tx, willAdds)
err = svc.db2.ObjectAccessStat().BatchAddCounter(tx, willAdds)
if err != nil {
return fmt.Errorf("batch add object access stat counter: %w", err)
}


+ 7
- 8
coordinator/internal/mq/storage.go View File

@@ -4,7 +4,6 @@ import (
"database/sql"
"fmt"

"github.com/jmoiron/sqlx"
"gitlink.org.cn/cloudream/common/consts/errorcode"
"gitlink.org.cn/cloudream/common/pkgs/logger"
"gorm.io/gorm"
@@ -70,33 +69,33 @@ func (svc *Service) GetStorageByName(msg *coormq.GetStorageByName) (*coormq.GetS
}

func (svc *Service) StoragePackageLoaded(msg *coormq.StoragePackageLoaded) (*coormq.StoragePackageLoadedResp, *mq.CodeMessage) {
err := svc.db.DoTx(sql.LevelSerializable, func(tx *sqlx.Tx) error {
err := svc.db2.DoTx(func(tx db2.SQLContext) error {
// 可以不用检查用户是否存在
if ok, _ := svc.db.Package().IsAvailable(tx, msg.UserID, msg.PackageID); !ok {
if ok, _ := svc.db2.Package().IsAvailable(tx, msg.UserID, msg.PackageID); !ok {
return fmt.Errorf("package is not available to user")
}

if ok, _ := svc.db.Storage().IsAvailable(tx, msg.UserID, msg.StorageID); !ok {
if ok, _ := svc.db2.Storage().IsAvailable(tx, msg.UserID, msg.StorageID); !ok {
return fmt.Errorf("storage is not available to user")
}

err := svc.db.StoragePackage().CreateOrUpdate(tx, msg.StorageID, msg.PackageID, msg.UserID)
err := svc.db2.StoragePackage().CreateOrUpdate(tx, msg.StorageID, msg.PackageID, msg.UserID)
if err != nil {
return fmt.Errorf("creating storage package: %w", err)
}

stg, err := svc.db.Storage().GetByID(tx, msg.StorageID)
stg, err := svc.db2.Storage().GetByID(tx, msg.StorageID)
if err != nil {
return fmt.Errorf("getting storage: %w", err)
}

err = svc.db.PinnedObject().CreateFromPackage(tx, msg.PackageID, stg.NodeID)
err = svc.db2.PinnedObject().CreateFromPackage(tx, msg.PackageID, stg.NodeID)
if err != nil {
return fmt.Errorf("creating pinned object from package: %w", err)
}

if len(msg.PinnedBlocks) > 0 {
err = svc.db.ObjectBlock().BatchCreate(tx, msg.PinnedBlocks)
err = svc.db2.ObjectBlock().BatchCreate(tx, msg.PinnedBlocks)
if err != nil {
return fmt.Errorf("batch creating object block: %w", err)
}


Loading…
Cancel
Save