You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

driver.go 5.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. /*
  2. * Licensed to the Apache Software Foundation (ASF) under one or more
  3. * contributor license agreements. See the NOTICE file distributed with
  4. * this work for additional information regarding copyright ownership.
  5. * The ASF licenses this file to You under the Apache License, Version 2.0
  6. * (the "License"); you may not use this file except in compliance with
  7. * the License. You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. */
  17. package sql
  18. import (
  19. "context"
  20. "database/sql"
  21. "database/sql/driver"
  22. "fmt"
  23. "strings"
  24. mysql2 "github.com/seata/seata-go/pkg/datasource/sql/datasource/mysql"
  25. "github.com/go-sql-driver/mysql"
  26. "github.com/seata/seata-go/pkg/datasource/sql/datasource"
  27. "github.com/seata/seata-go/pkg/datasource/sql/types"
  28. "github.com/seata/seata-go/pkg/protocol/branch"
  29. "github.com/seata/seata-go/pkg/util/log"
  30. )
  31. const (
  32. // SeataATMySQLDriver MySQL driver for AT mode
  33. SeataATMySQLDriver = "seata-at-mysql"
  34. // SeataXAMySQLDriver MySQL driver for XA mode
  35. SeataXAMySQLDriver = "seata-xa-mysql"
  36. )
  37. func init() {
  38. sql.Register(SeataATMySQLDriver, &seataATDriver{
  39. seataDriver: &seataDriver{
  40. transType: types.ATMode,
  41. target: mysql.MySQLDriver{},
  42. },
  43. })
  44. sql.Register(SeataXAMySQLDriver, &seataXADriver{
  45. seataDriver: &seataDriver{
  46. transType: types.XAMode,
  47. target: mysql.MySQLDriver{},
  48. },
  49. })
  50. }
  51. type seataATDriver struct {
  52. *seataDriver
  53. }
  54. func (d *seataATDriver) OpenConnector(name string) (c driver.Connector, err error) {
  55. connector, err := d.seataDriver.OpenConnector(name)
  56. if err != nil {
  57. return nil, err
  58. }
  59. _connector, _ := connector.(*seataConnector)
  60. _connector.transType = types.ATMode
  61. cfg, _ := mysql.ParseDSN(name)
  62. _connector.cfg = cfg
  63. return &seataATConnector{
  64. seataConnector: _connector,
  65. }, nil
  66. }
  67. type seataXADriver struct {
  68. *seataDriver
  69. }
  70. func (d *seataXADriver) OpenConnector(name string) (c driver.Connector, err error) {
  71. connector, err := d.seataDriver.OpenConnector(name)
  72. if err != nil {
  73. return nil, err
  74. }
  75. _connector, _ := connector.(*seataConnector)
  76. _connector.transType = types.XAMode
  77. cfg, _ := mysql.ParseDSN(name)
  78. _connector.cfg = cfg
  79. return &seataXAConnector{
  80. seataConnector: _connector,
  81. }, nil
  82. }
  83. type seataDriver struct {
  84. transType types.TransactionMode
  85. target driver.Driver
  86. }
  87. func (d *seataDriver) Open(name string) (driver.Conn, error) {
  88. conn, err := d.target.Open(name)
  89. if err != nil {
  90. log.Errorf("open db connection: %w", err)
  91. return nil, err
  92. }
  93. return conn, nil
  94. }
  95. func (d *seataDriver) OpenConnector(name string) (c driver.Connector, err error) {
  96. c = &dsnConnector{dsn: name, driver: d.target}
  97. if driverCtx, ok := d.target.(driver.DriverContext); ok {
  98. c, err = driverCtx.OpenConnector(name)
  99. if err != nil {
  100. log.Errorf("open connector: %w", err)
  101. return nil, err
  102. }
  103. }
  104. dbType := types.ParseDBType(d.getTargetDriverName())
  105. if dbType == types.DBTypeUnknown {
  106. return nil, fmt.Errorf("unsupport conn type %s", d.getTargetDriverName())
  107. }
  108. proxy, err := getOpenConnectorProxy(c, dbType, sql.OpenDB(c), name)
  109. if err != nil {
  110. log.Errorf("register resource: %w", err)
  111. return nil, err
  112. }
  113. return proxy, nil
  114. }
  115. func (d *seataDriver) getTargetDriverName() string {
  116. return "mysql"
  117. }
  118. type dsnConnector struct {
  119. dsn string
  120. driver driver.Driver
  121. }
  122. func (t *dsnConnector) Connect(_ context.Context) (driver.Conn, error) {
  123. return t.driver.Open(t.dsn)
  124. }
  125. func (t *dsnConnector) Driver() driver.Driver {
  126. return t.driver
  127. }
  128. func getOpenConnectorProxy(connector driver.Connector, dbType types.DBType, db *sql.DB,
  129. dataSourceName string, opts ...seataOption) (driver.Connector, error) {
  130. conf := loadConfig()
  131. for i := range opts {
  132. opts[i](conf)
  133. }
  134. if err := conf.validate(); err != nil {
  135. log.Errorf("invalid conf: %w", err)
  136. return connector, err
  137. }
  138. cfg, _ := mysql.ParseDSN(dataSourceName)
  139. options := []dbOption{
  140. withGroupID(conf.GroupID),
  141. withResourceID(parseResourceID(dataSourceName)),
  142. withConf(conf),
  143. withTarget(db),
  144. withDBType(dbType),
  145. withDBName(cfg.DBName),
  146. }
  147. res, err := newResource(options...)
  148. if err != nil {
  149. log.Errorf("create new resource: %w", err)
  150. return nil, err
  151. }
  152. datasource.RegisterTableCache(types.DBTypeMySQL, mysql2.NewTableMetaInstance(db))
  153. if err = datasource.GetDataSourceManager(conf.BranchType).RegisterResource(res); err != nil {
  154. log.Errorf("regisiter resource: %w", err)
  155. return nil, err
  156. }
  157. return &seataConnector{
  158. res: res,
  159. target: connector,
  160. conf: conf,
  161. cfg: cfg,
  162. }, nil
  163. }
  164. type (
  165. seataOption func(cfg *seataServerConfig)
  166. // seataServerConfig
  167. seataServerConfig struct {
  168. // GroupID
  169. GroupID string `yaml:"groupID"`
  170. // BranchType
  171. BranchType branch.BranchType
  172. // Endpoints
  173. Endpoints []string `yaml:"endpoints" json:"endpoints"`
  174. }
  175. )
  176. func (c *seataServerConfig) validate() error {
  177. return nil
  178. }
  179. // loadConfig
  180. // TODO wait finish
  181. func loadConfig() *seataServerConfig {
  182. // 先设置默认配置
  183. // 从默认文件获取
  184. return &seataServerConfig{
  185. GroupID: "DEFAULT_GROUP",
  186. BranchType: branch.BranchTypeAT,
  187. Endpoints: []string{"127.0.0.1:8888"},
  188. }
  189. }
  190. func parseResourceID(dsn string) string {
  191. i := strings.Index(dsn, "?")
  192. res := dsn
  193. if i > 0 {
  194. res = dsn[:i]
  195. }
  196. return strings.ReplaceAll(res, ",", "|")
  197. }