You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

driver.go 5.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. /*
  2. * Licensed to the Apache Software Foundation (ASF) under one or more
  3. * contributor license agreements. See the NOTICE file distributed with
  4. * this work for additional information regarding copyright ownership.
  5. * The ASF licenses this file to You under the Apache License, Version 2.0
  6. * (the "License"); you may not use this file except in compliance with
  7. * the License. You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. */
  17. package sql
  18. import (
  19. "context"
  20. "database/sql"
  21. "database/sql/driver"
  22. "fmt"
  23. "strings"
  24. "github.com/go-sql-driver/mysql"
  25. "github.com/seata/seata-go/pkg/datasource/sql/datasource"
  26. "github.com/seata/seata-go/pkg/datasource/sql/types"
  27. "github.com/seata/seata-go/pkg/protocol/branch"
  28. "github.com/seata/seata-go/pkg/util/log"
  29. )
  // Names under which this package registers its Seata drivers with
  // database/sql.
  const (
  	// SeataATMySQLDriver is the driver name of the MySQL driver in AT mode.
  	SeataATMySQLDriver = "seata-at-mysql"
  	// SeataXAMySQLDriver is the driver name of the MySQL driver in XA mode.
  	SeataXAMySQLDriver = "seata-xa-mysql"
  )
  36. func init() {
  37. sql.Register(SeataATMySQLDriver, &seataATDriver{
  38. seataDriver: &seataDriver{
  39. transType: types.ATMode,
  40. target: mysql.MySQLDriver{},
  41. },
  42. })
  43. sql.Register(SeataXAMySQLDriver, &seataXADriver{
  44. seataDriver: &seataDriver{
  45. transType: types.XAMode,
  46. target: mysql.MySQLDriver{},
  47. },
  48. })
  49. }
  50. type seataATDriver struct {
  51. *seataDriver
  52. }
  53. func (d *seataATDriver) OpenConnector(name string) (c driver.Connector, err error) {
  54. connector, err := d.seataDriver.OpenConnector(name)
  55. if err != nil {
  56. return nil, err
  57. }
  58. _connector, _ := connector.(*seataConnector)
  59. _connector.transType = types.ATMode
  60. return &seataATConnector{
  61. seataConnector: _connector,
  62. }, nil
  63. }
  64. type seataXADriver struct {
  65. *seataDriver
  66. }
  67. func (d *seataXADriver) OpenConnector(name string) (c driver.Connector, err error) {
  68. connector, err := d.seataDriver.OpenConnector(name)
  69. if err != nil {
  70. return nil, err
  71. }
  72. _connector, _ := connector.(*seataConnector)
  73. _connector.transType = types.XAMode
  74. return &seataXAConnector{
  75. seataConnector: _connector,
  76. }, nil
  77. }
  78. type seataDriver struct {
  79. transType types.TransactionType
  80. target driver.Driver
  81. }
  82. func (d *seataDriver) Open(name string) (driver.Conn, error) {
  83. conn, err := d.target.Open(name)
  84. if err != nil {
  85. log.Errorf("open target connection: %w", err)
  86. return nil, err
  87. }
  88. return conn, nil
  89. }
  90. func (d *seataDriver) OpenConnector(name string) (c driver.Connector, err error) {
  91. c = &dsnConnector{dsn: name, driver: d}
  92. if driverCtx, ok := d.target.(driver.DriverContext); ok {
  93. c, err = driverCtx.OpenConnector(name)
  94. if err != nil {
  95. log.Errorf("open connector: %w", err)
  96. return nil, err
  97. }
  98. }
  99. dbType := types.ParseDBType(d.getTargetDriverName())
  100. if dbType == types.DBTypeUnknown {
  101. return nil, fmt.Errorf("unsupport conn type %s", d.getTargetDriverName())
  102. }
  103. proxy, err := registerResource(c, d.transType, dbType, sql.OpenDB(c), name)
  104. if err != nil {
  105. log.Errorf("register resource: %w", err)
  106. return nil, err
  107. }
  108. return proxy, nil
  109. }
  110. func (d *seataDriver) getTargetDriverName() string {
  111. return "mysql"
  112. }
  // dsnConnector is a connector that opens connections by passing a fixed data
  // source name to a driver.
  type dsnConnector struct {
  	// dsn is the data source name given to driver.Open on every Connect.
  	dsn string
  	// driver opens the connections and is what Driver reports.
  	driver driver.Driver
  }

  // Connect opens a connection by calling Open on the stored driver with the
  // stored data source name. The context argument is ignored.
  func (c *dsnConnector) Connect(_ context.Context) (driver.Conn, error) {
  	conn, err := c.driver.Open(c.dsn)
  	return conn, err
  }

  // Driver returns the driver this connector opens connections with.
  func (c *dsnConnector) Driver() driver.Driver {
  	return c.driver
  }
  123. func registerResource(connector driver.Connector, txType types.TransactionType, dbType types.DBType, db *sql.DB,
  124. dataSourceName string, opts ...seataOption) (driver.Connector, error) {
  125. conf := loadConfig()
  126. for i := range opts {
  127. opts[i](conf)
  128. }
  129. if err := conf.validate(); err != nil {
  130. log.Errorf("invalid conf: %w", err)
  131. return connector, err
  132. }
  133. options := []dbOption{
  134. withGroupID(conf.GroupID),
  135. withResourceID(parseResourceID(dataSourceName)),
  136. withConf(conf),
  137. withTarget(db),
  138. withDBType(dbType),
  139. }
  140. res, err := newResource(options...)
  141. if err != nil {
  142. log.Errorf("create new resource: %w", err)
  143. return nil, err
  144. }
  145. if err = datasource.GetDataSourceManager(conf.BranchType).RegisterResource(res); err != nil {
  146. log.Errorf("regisiter resource: %w", err)
  147. return nil, err
  148. }
  149. return &seataConnector{
  150. res: res,
  151. target: connector,
  152. conf: conf,
  153. }, nil
  154. }
  155. type (
  156. seataOption func(cfg *seataServerConfig)
  157. // seataServerConfig
  158. seataServerConfig struct {
  159. // GroupID
  160. GroupID string `yaml:"groupID"`
  161. // BranchType
  162. BranchType branch.BranchType
  163. // Endpoints
  164. Endpoints []string `yaml:"endpoints" json:"endpoints"`
  165. }
  166. )
  167. func (c *seataServerConfig) validate() error {
  168. return nil
  169. }
  170. // loadConfig
  171. // TODO wait finish
  172. func loadConfig() *seataServerConfig {
  173. // 先设置默认配置
  174. // 从默认文件获取
  175. return &seataServerConfig{
  176. GroupID: "DEFAULT_GROUP",
  177. BranchType: branch.BranchTypeAT,
  178. Endpoints: []string{"127.0.0.1:8888"},
  179. }
  180. }
  // parseResourceID derives a resource ID from a data source name: everything
  // from the first '?' onwards is dropped and every ',' is replaced by '|'.
  //
  // A '?' in the very first position is not treated as a separator, so such a
  // DSN is kept whole (only the comma replacement applies).
  func parseResourceID(dsn string) string {
  	id := dsn
  	if pos := strings.IndexByte(dsn, '?'); pos > 0 {
  		id = dsn[:pos]
  	}
  	return strings.ReplaceAll(id, ",", "|")
  }