You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

basic_undo_log_builder.go 6.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. /*
  2. * Licensed to the Apache Software Foundation (ASF) under one or more
  3. * contributor license agreements. See the NOTICE file distributed with
  4. * this work for additional information regarding copyright ownership.
  5. * The ASF licenses this file to You under the Apache License, Version 2.0
  6. * (the "License"); you may not use this file except in compliance with
  7. * the License. You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. */
  17. package builder
  18. import (
  19. "database/sql"
  20. "database/sql/driver"
  21. "fmt"
  22. "io"
  23. "strings"
  24. "github.com/arana-db/parser/ast"
  25. "github.com/arana-db/parser/test_driver"
  26. gxsort "github.com/dubbogo/gost/sort"
  27. "github.com/seata/seata-go/pkg/datasource/sql/types"
  28. )
  29. type BasicUndoLogBuilder struct{}
  30. // getScanSlice get the column type for scann
  31. func (*BasicUndoLogBuilder) GetScanSlice(columnNames []string, tableMeta types.TableMeta) []driver.Value {
  32. scanSlice := make([]driver.Value, 0, len(columnNames))
  33. for _, columnNmae := range columnNames {
  34. var (
  35. scanVal interface{}
  36. // 从metData获取该列的元信息
  37. columnMeta = tableMeta.Columns[columnNmae]
  38. )
  39. switch columnMeta.Info.ScanType() {
  40. case types.ScanTypeFloat32:
  41. scanVal = float32(0)
  42. break
  43. case types.ScanTypeFloat64:
  44. scanVal = float64(0)
  45. break
  46. case types.ScanTypeInt8:
  47. scanVal = int8(0)
  48. break
  49. case types.ScanTypeInt16:
  50. scanVal = int16(0)
  51. break
  52. case types.ScanTypeInt32:
  53. scanVal = int32(0)
  54. break
  55. case types.ScanTypeInt64:
  56. scanVal = int64(0)
  57. break
  58. case types.ScanTypeNullFloat:
  59. scanVal = sql.NullFloat64{}
  60. break
  61. case types.ScanTypeNullInt:
  62. scanVal = sql.NullInt64{}
  63. break
  64. case types.ScanTypeNullTime:
  65. scanVal = sql.NullTime{}
  66. break
  67. case types.ScanTypeUint8:
  68. scanVal = uint8(0)
  69. break
  70. case types.ScanTypeUint16:
  71. scanVal = uint16(0)
  72. break
  73. case types.ScanTypeUint32:
  74. scanVal = uint32(0)
  75. break
  76. case types.ScanTypeUint64:
  77. scanVal = uint64(0)
  78. break
  79. case types.ScanTypeRawBytes:
  80. scanVal = sql.RawBytes{}
  81. break
  82. case types.ScanTypeUnknown:
  83. scanVal = new(interface{})
  84. break
  85. }
  86. scanSlice = append(scanSlice, &scanVal)
  87. }
  88. return scanSlice
  89. }
  90. func (b *BasicUndoLogBuilder) buildSelectArgs(stmt *ast.SelectStmt, args []driver.Value) []driver.Value {
  91. var (
  92. selectArgsIndexs = make([]int32, 0)
  93. selectArgs = make([]driver.Value, 0)
  94. )
  95. b.traversalArgs(stmt.Where, &selectArgsIndexs)
  96. if stmt.OrderBy != nil {
  97. for _, item := range stmt.OrderBy.Items {
  98. b.traversalArgs(item, &selectArgsIndexs)
  99. }
  100. }
  101. if stmt.Limit != nil {
  102. if stmt.Limit.Offset != nil {
  103. b.traversalArgs(stmt.Limit.Offset, &selectArgsIndexs)
  104. }
  105. if stmt.Limit.Count != nil {
  106. b.traversalArgs(stmt.Limit.Count, &selectArgsIndexs)
  107. }
  108. }
  109. // sort selectArgs index array
  110. gxsort.Int32(selectArgsIndexs)
  111. for _, index := range selectArgsIndexs {
  112. selectArgs = append(selectArgs, args[index])
  113. }
  114. return selectArgs
  115. }
  116. // todo perfect all sql operation
  117. func (b *BasicUndoLogBuilder) traversalArgs(node ast.Node, argsIndex *[]int32) {
  118. if node == nil {
  119. return
  120. }
  121. switch node.(type) {
  122. case *ast.BinaryOperationExpr:
  123. expr := node.(*ast.BinaryOperationExpr)
  124. b.traversalArgs(expr.L, argsIndex)
  125. b.traversalArgs(expr.R, argsIndex)
  126. break
  127. case *ast.BetweenExpr:
  128. expr := node.(*ast.BetweenExpr)
  129. b.traversalArgs(expr.Left, argsIndex)
  130. b.traversalArgs(expr.Right, argsIndex)
  131. break
  132. case *ast.PatternInExpr:
  133. exprs := node.(*ast.PatternInExpr).List
  134. for i := 0; i < len(exprs); i++ {
  135. b.traversalArgs(exprs[i], argsIndex)
  136. }
  137. break
  138. case *test_driver.ParamMarkerExpr:
  139. *argsIndex = append(*argsIndex, int32(node.(*test_driver.ParamMarkerExpr).Order))
  140. break
  141. }
  142. }
  143. func (b *BasicUndoLogBuilder) buildRecordImages(rowsi driver.Rows, tableMetaData types.TableMeta) (*types.RecordImage, error) {
  144. // select column names
  145. columnNames := rowsi.Columns()
  146. rowImages := make([]types.RowImage, 0)
  147. ss := b.GetScanSlice(columnNames, tableMetaData)
  148. for {
  149. err := rowsi.Next(ss)
  150. if err == io.EOF {
  151. break
  152. }
  153. columns := make([]types.ColumnImage, 0)
  154. // build record image
  155. for i, name := range columnNames {
  156. columnMeta := tableMetaData.Columns[name]
  157. keyType := types.IndexTypeNull
  158. if data, ok := tableMetaData.Indexs[name]; ok {
  159. keyType = data.IType
  160. }
  161. jdbcType := types.GetJDBCTypeByTypeName(columnMeta.Info.DatabaseTypeName())
  162. columns = append(columns, types.ColumnImage{
  163. KeyType: keyType,
  164. Name: name,
  165. Type: int16(jdbcType),
  166. Value: ss[i],
  167. })
  168. }
  169. rowImages = append(rowImages, types.RowImage{Columns: columns})
  170. }
  171. return &types.RecordImage{TableName: tableMetaData.Name, Rows: rowImages}, nil
  172. }
  173. // buildWhereConditionByPKs build where condition by primary keys
  174. // each pk is a condition.the result will like :" (id,userCode) in ((?,?),(?,?)) or (id,userCode) in ((?,?),(?,?) ) or (id,userCode) in ((?,?))"
  175. func (b *BasicUndoLogBuilder) buildWhereConditionByPKs(pkNameList []string, rowSize int, dbType string, maxInSize int) string {
  176. var (
  177. whereStr = &strings.Builder{}
  178. batchSize = rowSize/maxInSize + 1
  179. )
  180. if rowSize%maxInSize == 0 {
  181. batchSize = rowSize / maxInSize
  182. }
  183. for batch := 0; batch < batchSize; batch++ {
  184. if batch > 0 {
  185. whereStr.WriteString(" OR ")
  186. }
  187. whereStr.WriteString("(")
  188. for i := 0; i < len(pkNameList); i++ {
  189. if i > 0 {
  190. whereStr.WriteString(",")
  191. }
  192. // todo add escape
  193. whereStr.WriteString(fmt.Sprintf("`%s`", pkNameList[i]))
  194. }
  195. whereStr.WriteString(") IN (")
  196. var eachSize int
  197. if batch == batchSize-1 {
  198. if rowSize%maxInSize == 0 {
  199. eachSize = maxInSize
  200. } else {
  201. eachSize = rowSize % maxInSize
  202. }
  203. } else {
  204. eachSize = maxInSize
  205. }
  206. for i := 0; i < eachSize; i++ {
  207. if i > 0 {
  208. whereStr.WriteString(",")
  209. }
  210. whereStr.WriteString("(")
  211. for j := 0; j < len(pkNameList); j++ {
  212. if j > 0 {
  213. whereStr.WriteString(",")
  214. }
  215. whereStr.WriteString("?")
  216. }
  217. whereStr.WriteString(")")
  218. }
  219. whereStr.WriteString(")")
  220. }
  221. return whereStr.String()
  222. }
  223. func (b *BasicUndoLogBuilder) buildPKParams(rows []types.RowImage, pkNameList []string) []driver.Value {
  224. params := make([]driver.Value, 0)
  225. for _, row := range rows {
  226. coumnMap := row.GetColumnMap()
  227. for _, pk := range pkNameList {
  228. col := coumnMap[pk]
  229. if col != nil {
  230. params = append(params, col.Value)
  231. }
  232. }
  233. }
  234. return params
  235. }