You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

update_join_executor.go 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
  1. /*
  2. * Licensed to the Apache Software Foundation (ASF) under one or more
  3. * contributor license agreements. See the NOTICE file distributed with
  4. * this work for additional information regarding copyright ownership.
  5. * The ASF licenses this file to You under the Apache License, Version 2.0
  6. * (the "License"); you may not use this file except in compliance with
  7. * the License. You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. */
  17. package at
  18. import (
  19. "context"
  20. "database/sql/driver"
  21. "errors"
  22. "io"
  23. "reflect"
  24. "strings"
  25. "github.com/arana-db/parser/ast"
  26. "github.com/arana-db/parser/format"
  27. "github.com/arana-db/parser/model"
  28. "seata.apache.org/seata-go/pkg/datasource/sql/datasource"
  29. "seata.apache.org/seata-go/pkg/datasource/sql/exec"
  30. "seata.apache.org/seata-go/pkg/datasource/sql/types"
  31. "seata.apache.org/seata-go/pkg/datasource/sql/util"
  32. "seata.apache.org/seata-go/pkg/util/bytes"
  33. "seata.apache.org/seata-go/pkg/util/log"
  34. )
  35. const (
  36. LowerSupportGroupByPksVersion = "5.7.5"
  37. )
  38. // updateJoinExecutor execute update SQL
  39. type updateJoinExecutor struct {
  40. baseExecutor
  41. parserCtx *types.ParseContext
  42. execContext *types.ExecContext
  43. isLowerSupportGroupByPksVersion bool
  44. sqlMode string
  45. tableAliasesMap map[string]string
  46. }
  47. // NewUpdateJoinExecutor get executor
  48. func NewUpdateJoinExecutor(parserCtx *types.ParseContext, execContent *types.ExecContext, hooks []exec.SQLHook) executor {
  49. minimumVersion, _ := util.ConvertDbVersion(LowerSupportGroupByPksVersion)
  50. currentVersion, _ := util.ConvertDbVersion(execContent.DbVersion)
  51. return &updateJoinExecutor{
  52. parserCtx: parserCtx,
  53. execContext: execContent,
  54. baseExecutor: baseExecutor{hooks: hooks},
  55. isLowerSupportGroupByPksVersion: currentVersion < minimumVersion,
  56. tableAliasesMap: make(map[string]string, 0),
  57. }
  58. }
  59. // ExecContext exec SQL, and generate before image and after image
  60. func (u *updateJoinExecutor) ExecContext(ctx context.Context, f exec.CallbackWithNamedValue) (types.ExecResult, error) {
  61. u.beforeHooks(ctx, u.execContext)
  62. defer func() {
  63. u.afterHooks(ctx, u.execContext)
  64. }()
  65. if u.isAstStmtValid() {
  66. u.tableAliasesMap = u.parseTableName(u.parserCtx.UpdateStmt.TableRefs.TableRefs)
  67. }
  68. beforeImages, err := u.beforeImage(ctx)
  69. if err != nil {
  70. return nil, err
  71. }
  72. res, err := f(ctx, u.execContext.Query, u.execContext.NamedValues)
  73. if err != nil {
  74. return nil, err
  75. }
  76. afterImages, err := u.afterImage(ctx, beforeImages)
  77. if err != nil {
  78. return nil, err
  79. }
  80. if len(afterImages) != len(beforeImages) {
  81. return nil, errors.New("Before image size is not equaled to after image size, probably because you updated the primary keys.")
  82. }
  83. u.execContext.TxCtx.RoundImages.AppendBeofreImages(beforeImages)
  84. u.execContext.TxCtx.RoundImages.AppendAfterImages(afterImages)
  85. return res, nil
  86. }
  87. func (u *updateJoinExecutor) isAstStmtValid() bool {
  88. return u.parserCtx != nil && u.parserCtx.UpdateStmt != nil && u.parserCtx.UpdateStmt.TableRefs.TableRefs.Right != nil
  89. }
  90. func (u *updateJoinExecutor) beforeImage(ctx context.Context) ([]*types.RecordImage, error) {
  91. if !u.isAstStmtValid() {
  92. return nil, nil
  93. }
  94. var recordImages []*types.RecordImage
  95. for tbName, tableAliases := range u.tableAliasesMap {
  96. metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tbName)
  97. if err != nil {
  98. return nil, err
  99. }
  100. selectSQL, selectArgs, err := u.buildBeforeImageSQL(ctx, metaData, tableAliases, u.execContext.NamedValues)
  101. if err != nil {
  102. return nil, err
  103. }
  104. if selectSQL == "" {
  105. log.Debugf("Skip unused table [{%s}] when build select sql by update sourceQuery", tbName)
  106. continue
  107. }
  108. var image *types.RecordImage
  109. rowsi, err := u.rowsPrepare(ctx, u.execContext.Conn, selectSQL, selectArgs)
  110. if err == nil {
  111. image, err = u.buildRecordImages(rowsi, metaData, types.SQLTypeUpdate)
  112. }
  113. if rowsi != nil {
  114. if rowerr := rows.Close(); rowerr != nil {
  115. log.Errorf("rows close fail, err:%v", rowerr)
  116. return nil, rowerr
  117. }
  118. }
  119. if err != nil {
  120. // If one fail, all fails
  121. return nil, err
  122. }
  123. lockKey := u.buildLockKey(image, *metaData)
  124. u.execContext.TxCtx.LockKeys[lockKey] = struct{}{}
  125. image.SQLType = u.parserCtx.SQLType
  126. recordImages = append(recordImages, image)
  127. }
  128. return recordImages, nil
  129. }
  130. func (u *updateJoinExecutor) afterImage(ctx context.Context, beforeImages []*types.RecordImage) ([]*types.RecordImage, error) {
  131. if !u.isAstStmtValid() {
  132. return nil, nil
  133. }
  134. if len(beforeImages) == 0 {
  135. return nil, errors.New("empty beforeImages")
  136. }
  137. var recordImages []*types.RecordImage
  138. for _, beforeImage := range beforeImages {
  139. metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, beforeImage.TableName)
  140. if err != nil {
  141. return nil, err
  142. }
  143. selectSQL, selectArgs, err := u.buildAfterImageSQL(ctx, *beforeImage, metaData, u.tableAliasesMap[beforeImage.TableName])
  144. if err != nil {
  145. return nil, err
  146. }
  147. var image *types.RecordImage
  148. rowsi, err := u.rowsPrepare(ctx, u.execContext.Conn, selectSQL, selectArgs)
  149. if err == nil {
  150. image, err = u.buildRecordImages(rowsi, metaData, types.SQLTypeUpdate)
  151. }
  152. if rowsi != nil {
  153. if rowerr := rowsi.Close(); rowerr != nil {
  154. log.Errorf("rows close fail, err:%v", rowerr)
  155. return nil, rowerr
  156. }
  157. }
  158. if err != nil {
  159. // If one fail, all fails
  160. return nil, err
  161. }
  162. image.SQLType = u.parserCtx.SQLType
  163. recordImages = append(recordImages, image)
  164. }
  165. return recordImages, nil
  166. }
  167. // buildAfterImageSQL build the SQL to query before image data
  168. func (u *updateJoinExecutor) buildBeforeImageSQL(ctx context.Context, tableMeta *types.TableMeta, tableAliases string, args []driver.NamedValue) (string, []driver.NamedValue, error) {
  169. updateStmt := u.parserCtx.UpdateStmt
  170. fields, err := u.buildSelectFields(ctx, tableMeta, tableAliases, updateStmt.List)
  171. if err != nil {
  172. return "", nil, err
  173. }
  174. if len(fields) == 0 {
  175. return "", nil, err
  176. }
  177. selStmt := ast.SelectStmt{
  178. SelectStmtOpts: &ast.SelectStmtOpts{},
  179. From: updateStmt.TableRefs,
  180. Where: updateStmt.Where,
  181. Fields: &ast.FieldList{Fields: fields},
  182. OrderBy: updateStmt.Order,
  183. Limit: updateStmt.Limit,
  184. TableHints: updateStmt.TableHints,
  185. // maybe duplicate row for select join sql.remove duplicate row by 'group by' condition
  186. GroupBy: &ast.GroupByClause{
  187. Items: u.buildGroupByClause(ctx, tableMeta.TableName, tableAliases, tableMeta.GetPrimaryKeyOnlyName(), fields),
  188. },
  189. LockInfo: &ast.SelectLockInfo{
  190. LockType: ast.SelectLockForUpdate,
  191. },
  192. }
  193. b := bytes.NewByteBuffer([]byte{})
  194. _ = selStmt.Restore(format.NewRestoreCtx(format.RestoreKeyWordUppercase, b))
  195. sql := string(b.Bytes())
  196. log.Infof("build select sql by update sourceQuery, sql {%s}", sql)
  197. return sql, u.buildSelectArgs(&selStmt, args), nil
  198. }
  199. func (u *updateJoinExecutor) buildAfterImageSQL(ctx context.Context, beforeImage types.RecordImage, meta *types.TableMeta, tableAliases string) (string, []driver.NamedValue, error) {
  200. if len(beforeImage.Rows) == 0 {
  201. return "", nil, nil
  202. }
  203. fields, err := u.buildSelectFields(ctx, meta, tableAliases, u.parserCtx.UpdateStmt.List)
  204. if err != nil {
  205. return "", nil, err
  206. }
  207. if len(fields) == 0 {
  208. return "", nil, err
  209. }
  210. updateStmt := u.parserCtx.UpdateStmt
  211. selStmt := ast.SelectStmt{
  212. SelectStmtOpts: &ast.SelectStmtOpts{},
  213. From: updateStmt.TableRefs,
  214. Where: updateStmt.Where,
  215. Fields: &ast.FieldList{Fields: fields},
  216. OrderBy: updateStmt.Order,
  217. Limit: updateStmt.Limit,
  218. TableHints: updateStmt.TableHints,
  219. // maybe duplicate row for select join sql.remove duplicate row by 'group by' condition
  220. GroupBy: &ast.GroupByClause{
  221. Items: u.buildGroupByClause(ctx, meta.TableName, tableAliases, meta.GetPrimaryKeyOnlyName(), fields),
  222. },
  223. }
  224. b := bytes.NewByteBuffer([]byte{})
  225. _ = selStmt.Restore(format.NewRestoreCtx(format.RestoreKeyWordUppercase, b))
  226. sql := string(b.Bytes())
  227. log.Infof("build select sql by update sourceQuery, sql {%s}", sql)
  228. return sql, u.buildPKParams(beforeImage.Rows, meta.GetPrimaryKeyOnlyName()), nil
  229. }
  230. func (u *updateJoinExecutor) parseTableName(joinMate *ast.Join) map[string]string {
  231. tableNames := make(map[string]string, 0)
  232. if item, ok := joinMate.Left.(*ast.Join); ok {
  233. tableNames = u.parseTableName(item)
  234. } else {
  235. leftTableSource := joinMate.Left.(*ast.TableSource)
  236. leftName := leftTableSource.Source.(*ast.TableName)
  237. tableNames[leftName.Name.O] = leftTableSource.AsName.O
  238. }
  239. rightTableSource := joinMate.Right.(*ast.TableSource)
  240. rightName := rightTableSource.Source.(*ast.TableName)
  241. tableNames[rightName.Name.O] = rightTableSource.AsName.O
  242. return tableNames
  243. }
  244. // build group by condition which used for removing duplicate row in select join sql
  245. func (u *updateJoinExecutor) buildGroupByClause(ctx context.Context, tableName string, tableAliases string, pkColumns []string, allSelectColumns []*ast.SelectField) []*ast.ByItem {
  246. var groupByPks = true
  247. if tableAliases != "" {
  248. tableName = tableAliases
  249. }
  250. //only pks group by is valid when db version >= 5.7.5
  251. if u.isLowerSupportGroupByPksVersion {
  252. if u.sqlMode == "" {
  253. rowsi, err := u.rowsPrepare(ctx, u.execContext.Conn, "SELECT @@SQL_MODE", nil)
  254. defer func() {
  255. if rowsi != nil {
  256. if rowerr := rowsi.Close(); rowerr != nil {
  257. log.Errorf("rows close fail, err:%v", rowerr)
  258. }
  259. }
  260. }()
  261. if err != nil {
  262. groupByPks = false
  263. log.Warnf("determine group by pks or all columns error:%s", err)
  264. } else {
  265. // getString("@@SQL_MODE")
  266. mode := make([]driver.Value, 1)
  267. if err = rowsi.Next(mode); err != nil {
  268. if err != io.EOF && len(mode) == 1 {
  269. u.sqlMode = reflect.ValueOf(mode[0]).String()
  270. }
  271. }
  272. }
  273. }
  274. if strings.Contains(u.sqlMode, "ONLY_FULL_GROUP_BY") {
  275. groupByPks = false
  276. }
  277. }
  278. groupByColumns := make([]*ast.ByItem, 0)
  279. if groupByPks {
  280. for _, column := range pkColumns {
  281. groupByColumns = append(groupByColumns, &ast.ByItem{
  282. Expr: &ast.ColumnNameExpr{
  283. Name: &ast.ColumnName{
  284. Table: model.CIStr{
  285. O: tableName,
  286. L: strings.ToLower(tableName),
  287. },
  288. Name: model.CIStr{
  289. O: column,
  290. L: strings.ToLower(column),
  291. },
  292. },
  293. },
  294. })
  295. }
  296. } else {
  297. for _, column := range allSelectColumns {
  298. groupByColumns = append(groupByColumns, &ast.ByItem{
  299. Expr: column.Expr,
  300. })
  301. }
  302. }
  303. return groupByColumns
  304. }