|
- // @Author: Ciusyan 5/20/24
-
- package postgres
-
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net/url"
	"strings"
	"time"

	"github.com/ccfos/nightingale/v6/dskit/pool"
	"github.com/ccfos/nightingale/v6/dskit/sqlbase"
	"github.com/ccfos/nightingale/v6/dskit/types"

	_ "github.com/lib/pq" // PostgreSQL driver
	"github.com/mitchellh/mapstructure"
	"gorm.io/driver/postgres"
	"gorm.io/gorm"
)
-
// PostgreSQL is a PostgreSQL datasource. All of its settings live in the
// embedded Shard, whose fields are flattened into the same level when the
// value is decoded with mapstructure (",squash").
type PostgreSQL struct {
	Shard `json:",inline" mapstructure:",squash"`
}

// Shard holds the connection settings of a PostgreSQL instance. The json and
// mapstructure keys all carry the "pgsql." prefix.
type Shard struct {
	// Addr is placed in the host position of the connection URL
	// (presumably "host:port" — confirm against callers).
	Addr string `json:"pgsql.addr" mapstructure:"pgsql.addr"`
	// DB is the database used by methods that are not given an explicit one.
	DB string `json:"pgsql.db" mapstructure:"pgsql.db"`
	// User and Password are the connection credentials.
	User     string `json:"pgsql.user" mapstructure:"pgsql.user"`
	Password string `json:"pgsql.password" mapstructure:"pgsql.password"`
	// Timeout is defaulted to 60 by NewConn when zero.
	// NOTE(review): no visible code in this file consumes it — confirm usage.
	Timeout int `json:"pgsql.timeout" mapstructure:"pgsql.timeout"`
	// MaxIdleConns is defaulted to 10 by NewConn when zero.
	MaxIdleConns int `json:"pgsql.max_idle_conns" mapstructure:"pgsql.max_idle_conns"`
	// MaxOpenConns is defaulted to 100 by NewConn when zero.
	MaxOpenConns int `json:"pgsql.max_open_conns" mapstructure:"pgsql.max_open_conns"`
	// ConnMaxLifetime is expressed in seconds; NewConn defaults it to 14400
	// when zero.
	ConnMaxLifetime int `json:"pgsql.conn_max_lifetime" mapstructure:"pgsql.conn_max_lifetime"`
	// MaxQueryRows is decoded from settings.
	// NOTE(review): not read anywhere in this file — confirm where it is enforced.
	MaxQueryRows int `json:"pgsql.max_query_rows" mapstructure:"pgsql.max_query_rows"`
}
-
- // NewPostgreSQLWithSettings initializes a new PostgreSQL instance with the given settings
- func NewPostgreSQLWithSettings(ctx context.Context, settings interface{}) (*PostgreSQL, error) {
- newest := new(PostgreSQL)
- settingsMap := map[string]interface{}{}
-
- switch s := settings.(type) {
- case string:
- if err := json.Unmarshal([]byte(s), &settingsMap); err != nil {
- return nil, err
- }
- case map[string]interface{}:
- settingsMap = s
- case *PostgreSQL:
- return s, nil
- case PostgreSQL:
- return &s, nil
- case Shard:
- newest.Shard = s
- return newest, nil
- case *Shard:
- newest.Shard = *s
- return newest, nil
- default:
- return nil, errors.New("unsupported settings type")
- }
-
- if err := mapstructure.Decode(settingsMap, newest); err != nil {
- return nil, err
- }
-
- return newest, nil
- }
-
- // NewConn establishes a new connection to PostgreSQL
- func (p *PostgreSQL) NewConn(ctx context.Context, database string) (*gorm.DB, error) {
- if len(p.DB) == 0 && len(database) == 0 {
- return nil, errors.New("empty pgsql database") // 兼容阿里实时数仓Holgres, 连接时必须指定db名字
- }
-
- if p.Shard.Timeout == 0 {
- p.Shard.Timeout = 60
- }
-
- if p.Shard.MaxIdleConns == 0 {
- p.Shard.MaxIdleConns = 10
- }
-
- if p.Shard.MaxOpenConns == 0 {
- p.Shard.MaxOpenConns = 100
- }
-
- if p.Shard.ConnMaxLifetime == 0 {
- p.Shard.ConnMaxLifetime = 14400
- }
-
- if len(p.Shard.Addr) == 0 {
- return nil, errors.New("empty fe-node addr")
- }
- var keys []string
- var err error
- keys = append(keys, p.Shard.Addr)
-
- keys = append(keys, p.Shard.Password, p.Shard.User)
- if len(database) > 0 {
- keys = append(keys, database)
- }
- cachedKey := strings.Join(keys, ":")
- // cache conn with database
- conn, ok := pool.PoolClient.Load(cachedKey)
- if ok {
- return conn.(*gorm.DB), nil
- }
-
- var db *gorm.DB
- defer func() {
- if db != nil && err == nil {
- pool.PoolClient.Store(cachedKey, db)
- }
- }()
-
- // Simplified connection logic for PostgreSQL
- dsn := fmt.Sprintf("postgres://%s:%s@%s/%s?sslmode=disable&TimeZone=Asia/Shanghai", p.Shard.User, p.Shard.Password, p.Shard.Addr, database)
- db, err = sqlbase.NewDB(
- ctx,
- postgres.Open(dsn),
- p.Shard.MaxIdleConns,
- p.Shard.MaxOpenConns,
- time.Duration(p.Shard.ConnMaxLifetime)*time.Second,
- )
-
- if err != nil {
- if db != nil {
- sqlDB, _ := db.DB()
- if sqlDB != nil {
- sqlDB.Close()
- }
- }
- return nil, err
- }
-
- return db, nil
- }
-
- // ShowDatabases lists all databases in PostgreSQL
- func (p *PostgreSQL) ShowDatabases(ctx context.Context, searchKeyword string) ([]string, error) {
- db, err := p.NewConn(ctx, "postgres")
- if err != nil {
- return nil, err
- }
- sql := fmt.Sprintf("SELECT datname FROM pg_database WHERE datistemplate = false AND datname LIKE %s",
- "'%"+searchKeyword+"%'")
- return sqlbase.ShowDatabases(ctx, db, sql)
- }
-
- // ShowTables lists all tables in a given database
- func (p *PostgreSQL) ShowTables(ctx context.Context, searchKeyword string) (map[string][]string, error) {
- db, err := p.NewConn(ctx, p.DB)
- if err != nil {
- return nil, err
- }
- sql := fmt.Sprintf("SELECT schemaname, tablename FROM pg_tables WHERE schemaname !='information_schema' and schemaname !='pg_catalog' and tablename LIKE %s",
- "'%"+searchKeyword+"%'")
- rets, err := sqlbase.ExecQuery(ctx, db, sql)
- if err != nil {
- return nil, err
- }
- tabs := make(map[string][]string, 3)
- for _, row := range rets {
- if val, ok := row["schemaname"].(string); ok {
- tabs[val] = append(tabs[val], row["tablename"].(string))
- }
- }
- return tabs, nil
- }
-
- // DescTable describes the schema of a specified table in PostgreSQL
- // scheme default: public if not specified
- func (p *PostgreSQL) DescTable(ctx context.Context, scheme, table string) ([]*types.ColumnProperty, error) {
- db, err := p.NewConn(ctx, p.DB)
- if err != nil {
- return nil, err
- }
- if scheme == "" {
- scheme = "public"
- }
-
- query := fmt.Sprintf("SELECT column_name, data_type, is_nullable, column_default FROM information_schema.columns WHERE table_name = '%s' AND table_schema = '%s'", table, scheme)
- return sqlbase.DescTable(ctx, db, query)
- }
-
- // SelectRows selects rows from a specified table in PostgreSQL based on a given query
- func (p *PostgreSQL) SelectRows(ctx context.Context, table, where string) ([]map[string]interface{}, error) {
- db, err := p.NewConn(ctx, p.DB)
- if err != nil {
- return nil, err
- }
-
- return sqlbase.SelectRows(ctx, db, table, where)
- }
-
- // ExecQuery executes a SQL query in PostgreSQL
- func (p *PostgreSQL) ExecQuery(ctx context.Context, sql string) ([]map[string]interface{}, error) {
- db, err := p.NewConn(ctx, p.DB)
- if err != nil {
- return nil, err
- }
-
- return sqlbase.ExecQuery(ctx, db, sql)
- }
|