You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

update_executor.go 8.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. /*
  2. * Licensed to the Apache Software Foundation (ASF) under one or more
  3. * contributor license agreements. See the NOTICE file distributed with
  4. * this work for additional information regarding copyright ownership.
  5. * The ASF licenses this file to You under the Apache License, Version 2.0
  6. * (the "License"); you may not use this file except in compliance with
  7. * the License. You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. */
  17. package at
  18. import (
  19. "context"
  20. "database/sql/driver"
  21. "fmt"
  22. "github.com/arana-db/parser/model"
  23. "seata.apache.org/seata-go/pkg/datasource/sql/util"
  24. "strings"
  25. "github.com/arana-db/parser/ast"
  26. "github.com/arana-db/parser/format"
  27. "seata.apache.org/seata-go/pkg/datasource/sql/datasource"
  28. "seata.apache.org/seata-go/pkg/datasource/sql/exec"
  29. "seata.apache.org/seata-go/pkg/datasource/sql/types"
  30. "seata.apache.org/seata-go/pkg/datasource/sql/undo"
  31. "seata.apache.org/seata-go/pkg/util/bytes"
  32. "seata.apache.org/seata-go/pkg/util/log"
  33. )
  // maxInSize is passed as the last argument of buildWhereConditionByPKs when
// the after-image query is built; presumably it bounds how many primary-key
// tuples go into a single IN (...) list — confirm against that helper.
var maxInSize = 1000
  37. // updateExecutor execute update SQL
  38. type updateExecutor struct {
  39. baseExecutor
  40. parserCtx *types.ParseContext
  41. execContext *types.ExecContext
  42. }
  43. // NewUpdateExecutor get update executor
  44. func NewUpdateExecutor(parserCtx *types.ParseContext, execContent *types.ExecContext, hooks []exec.SQLHook) executor {
  45. // Because update join cannot be clearly identified when SQL cannot be parsed
  46. if parserCtx.UpdateStmt.TableRefs.TableRefs.Right != nil {
  47. return NewUpdateJoinExecutor(parserCtx, execContent, hooks)
  48. }
  49. return &updateExecutor{parserCtx: parserCtx, execContext: execContent, baseExecutor: baseExecutor{hooks: hooks}}
  50. }
  51. // ExecContext exec SQL, and generate before image and after image
  52. func (u *updateExecutor) ExecContext(ctx context.Context, f exec.CallbackWithNamedValue) (types.ExecResult, error) {
  53. u.beforeHooks(ctx, u.execContext)
  54. defer func() {
  55. u.afterHooks(ctx, u.execContext)
  56. }()
  57. beforeImage, err := u.beforeImage(ctx)
  58. if err != nil {
  59. return nil, err
  60. }
  61. res, err := f(ctx, u.execContext.Query, u.execContext.NamedValues)
  62. if err != nil {
  63. return nil, err
  64. }
  65. afterImage, err := u.afterImage(ctx, *beforeImage)
  66. if err != nil {
  67. return nil, err
  68. }
  69. if len(beforeImage.Rows) != len(afterImage.Rows) {
  70. return nil, fmt.Errorf("Before image size is not equaled to after image size, probably because you updated the primary keys.")
  71. }
  72. u.execContext.TxCtx.RoundImages.AppendBeofreImage(beforeImage)
  73. u.execContext.TxCtx.RoundImages.AppendAfterImage(afterImage)
  74. return res, nil
  75. }
  76. // beforeImage build before image
  77. func (u *updateExecutor) beforeImage(ctx context.Context) (*types.RecordImage, error) {
  78. if !u.isAstStmtValid() {
  79. return nil, nil
  80. }
  81. selectSQL, selectArgs, err := u.buildBeforeImageSQL(ctx, u.execContext.NamedValues)
  82. if err != nil {
  83. return nil, err
  84. }
  85. tableName, _ := u.parserCtx.GetTableName()
  86. metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tableName)
  87. if err != nil {
  88. return nil, err
  89. }
  90. var rowsi driver.Rows
  91. queryerCtx, ok := u.execContext.Conn.(driver.QueryerContext)
  92. var queryer driver.Queryer
  93. if !ok {
  94. queryer, ok = u.execContext.Conn.(driver.Queryer)
  95. }
  96. if ok {
  97. rowsi, err = util.CtxDriverQuery(ctx, queryerCtx, queryer, selectSQL, selectArgs)
  98. defer func() {
  99. if rowsi != nil {
  100. rowsi.Close()
  101. }
  102. }()
  103. if err != nil {
  104. log.Errorf("ctx driver query: %+v", err)
  105. return nil, err
  106. }
  107. } else {
  108. log.Errorf("target conn should been driver.QueryerContext or driver.Queryer")
  109. return nil, fmt.Errorf("invalid conn")
  110. }
  111. image, err := u.buildRecordImages(rowsi, metaData, types.SQLTypeUpdate)
  112. if err != nil {
  113. return nil, err
  114. }
  115. lockKey := u.buildLockKey(image, *metaData)
  116. u.execContext.TxCtx.LockKeys[lockKey] = struct{}{}
  117. image.SQLType = u.parserCtx.SQLType
  118. return image, nil
  119. }
  120. // afterImage build after image
  121. func (u *updateExecutor) afterImage(ctx context.Context, beforeImage types.RecordImage) (*types.RecordImage, error) {
  122. if !u.isAstStmtValid() {
  123. return nil, nil
  124. }
  125. if len(beforeImage.Rows) == 0 {
  126. return &types.RecordImage{}, nil
  127. }
  128. tableName, _ := u.parserCtx.GetTableName()
  129. metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tableName)
  130. if err != nil {
  131. return nil, err
  132. }
  133. selectSQL, selectArgs := u.buildAfterImageSQL(beforeImage, metaData)
  134. var rowsi driver.Rows
  135. queryerCtx, ok := u.execContext.Conn.(driver.QueryerContext)
  136. var queryer driver.Queryer
  137. if !ok {
  138. queryer, ok = u.execContext.Conn.(driver.Queryer)
  139. }
  140. if ok {
  141. rowsi, err = util.CtxDriverQuery(ctx, queryerCtx, queryer, selectSQL, selectArgs)
  142. defer func() {
  143. if rowsi != nil {
  144. rowsi.Close()
  145. }
  146. }()
  147. if err != nil {
  148. log.Errorf("ctx driver query: %+v", err)
  149. return nil, err
  150. }
  151. } else {
  152. log.Errorf("target conn should been driver.QueryerContext or driver.Queryer")
  153. return nil, fmt.Errorf("invalid conn")
  154. }
  155. afterImage, err := u.buildRecordImages(rowsi, metaData, types.SQLTypeUpdate)
  156. if err != nil {
  157. return nil, err
  158. }
  159. afterImage.SQLType = u.parserCtx.SQLType
  160. return afterImage, nil
  161. }
  162. func (u *updateExecutor) isAstStmtValid() bool {
  163. return u.parserCtx != nil && u.parserCtx.UpdateStmt != nil
  164. }
  165. // buildAfterImageSQL build the SQL to query after image data
  166. func (u *updateExecutor) buildAfterImageSQL(beforeImage types.RecordImage, meta *types.TableMeta) (string, []driver.NamedValue) {
  167. if len(beforeImage.Rows) == 0 {
  168. return "", nil
  169. }
  170. sb := strings.Builder{}
  171. // todo: OnlyCareUpdateColumns should load from config first
  172. var selectFields string
  173. var separator = ","
  174. if undo.UndoConfig.OnlyCareUpdateColumns {
  175. for _, row := range beforeImage.Rows {
  176. for _, column := range row.Columns {
  177. selectFields += column.ColumnName + separator
  178. }
  179. }
  180. selectFields = strings.TrimSuffix(selectFields, separator)
  181. } else {
  182. selectFields = "*"
  183. }
  184. sb.WriteString("SELECT " + selectFields + " FROM " + meta.TableName + " WHERE ")
  185. whereSQL := u.buildWhereConditionByPKs(meta.GetPrimaryKeyOnlyName(), len(beforeImage.Rows), "mysql", maxInSize)
  186. sb.WriteString(" " + whereSQL + " ")
  187. return sb.String(), u.buildPKParams(beforeImage.Rows, meta.GetPrimaryKeyOnlyName())
  188. }
  189. // buildAfterImageSQL build the SQL to query before image data
  190. func (u *updateExecutor) buildBeforeImageSQL(ctx context.Context, args []driver.NamedValue) (string, []driver.NamedValue, error) {
  191. if !u.isAstStmtValid() {
  192. log.Errorf("invalid update stmt")
  193. return "", nil, fmt.Errorf("invalid update stmt")
  194. }
  195. updateStmt := u.parserCtx.UpdateStmt
  196. fields := make([]*ast.SelectField, 0, len(updateStmt.List))
  197. if undo.UndoConfig.OnlyCareUpdateColumns {
  198. for _, column := range updateStmt.List {
  199. fields = append(fields, &ast.SelectField{
  200. Expr: &ast.ColumnNameExpr{
  201. Name: column.Column,
  202. },
  203. })
  204. }
  205. // select indexes columns
  206. tableName, _ := u.parserCtx.GetTableName()
  207. metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tableName)
  208. if err != nil {
  209. return "", nil, err
  210. }
  211. for _, columnName := range metaData.GetPrimaryKeyOnlyName() {
  212. fields = append(fields, &ast.SelectField{
  213. Expr: &ast.ColumnNameExpr{
  214. Name: &ast.ColumnName{
  215. Name: model.CIStr{
  216. O: columnName,
  217. L: columnName,
  218. },
  219. },
  220. },
  221. })
  222. }
  223. } else {
  224. fields = append(fields, &ast.SelectField{
  225. Expr: &ast.ColumnNameExpr{
  226. Name: &ast.ColumnName{
  227. Name: model.CIStr{
  228. O: "*",
  229. L: "*",
  230. },
  231. },
  232. },
  233. })
  234. }
  235. selStmt := ast.SelectStmt{
  236. SelectStmtOpts: &ast.SelectStmtOpts{},
  237. From: updateStmt.TableRefs,
  238. Where: updateStmt.Where,
  239. Fields: &ast.FieldList{Fields: fields},
  240. OrderBy: updateStmt.Order,
  241. Limit: updateStmt.Limit,
  242. TableHints: updateStmt.TableHints,
  243. LockInfo: &ast.SelectLockInfo{
  244. LockType: ast.SelectLockForUpdate,
  245. },
  246. }
  247. b := bytes.NewByteBuffer([]byte{})
  248. _ = selStmt.Restore(format.NewRestoreCtx(format.RestoreKeyWordUppercase, b))
  249. sql := string(b.Bytes())
  250. log.Infof("build select sql by update sourceQuery, sql {%s}", sql)
  251. return sql, u.buildSelectArgs(&selStmt, args), nil
  252. }