You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

sql.go 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. /*
  2. * Licensed to the Apache Software Foundation (ASF) under one or more
  3. * contributor license agreements. See the NOTICE file distributed with
  4. * this work for additional information regarding copyright ownership.
  5. * The ASF licenses this file to You under the Apache License, Version 2.0
  6. * (the "License"); you may not use this file except in compliance with
  7. * the License. You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. */
  17. package executor
  18. import (
  19. "strings"
  20. "github.com/seata/seata-go/pkg/datasource/sql/types"
  21. )
  // Dot separates a qualifier (such as a schema name) from the identifier after it.
const Dot = "."

// Identifier delimiters understood by the escape helpers in this file.
const (
	// EscapeStandard is the double-quote identifier delimiter.
	EscapeStandard = `"`
	// EscapeMysql is the backtick identifier delimiter used for MySQL.
	EscapeMysql = "`"
)
  27. // DelEscape del escape by db type
  28. func DelEscape(colName string, dbType types.DBType) string {
  29. newColName := delEscape(colName, EscapeStandard)
  30. if dbType == types.DBTypeMySQL {
  31. newColName = delEscape(newColName, EscapeMysql)
  32. }
  33. return newColName
  34. }
  35. // delEscape
  36. func delEscape(colName string, escape string) string {
  37. if colName == "" {
  38. return ""
  39. }
  40. if string(colName[0]) == escape && string(colName[len(colName)-1]) == escape {
  41. // like "scheme"."id" `scheme`.`id`
  42. str := escape + Dot + escape
  43. index := strings.Index(colName, str)
  44. if index > -1 {
  45. return colName[1:index] + Dot + colName[index+len(str):len(colName)-1]
  46. }
  47. return colName[1 : len(colName)-1]
  48. } else {
  49. // like "scheme".id `scheme`.id
  50. str := escape + Dot
  51. index := strings.Index(colName, str)
  52. if index > -1 && string(colName[0]) == escape {
  53. return colName[1:index] + Dot + colName[index+len(str):]
  54. }
  55. // like scheme."id" scheme.`id`
  56. str = Dot + escape
  57. index = strings.Index(colName, str)
  58. if index > -1 && string(colName[len(colName)-1]) == escape {
  59. return colName[0:index] + Dot + colName[index+len(str):len(colName)-1]
  60. }
  61. }
  62. return colName
  63. }
  64. // AddEscape if necessary, add escape by db type
  65. func AddEscape(colName string, dbType types.DBType) string {
  66. if dbType == types.DBTypeMySQL {
  67. return addEscape(colName, dbType, EscapeMysql)
  68. }
  69. return addEscape(colName, dbType, EscapeStandard)
  70. }
  71. func addEscape(colName string, dbType types.DBType, escape string) string {
  72. if colName == "" {
  73. return colName
  74. }
  75. if string(colName[0]) == escape && string(colName[len(colName)-1]) == escape {
  76. return colName
  77. }
  78. if !checkEscape(colName, dbType) {
  79. return colName
  80. }
  81. if strings.Contains(colName, Dot) {
  82. // like "scheme".id `scheme`.id
  83. str := escape + Dot
  84. dotIndex := strings.Index(colName, str)
  85. if dotIndex > -1 {
  86. tempStr := strings.Builder{}
  87. tempStr.WriteString(colName[0 : dotIndex+len(str)])
  88. tempStr.WriteString(escape)
  89. tempStr.WriteString(colName[dotIndex+len(str):])
  90. tempStr.WriteString(escape)
  91. return tempStr.String()
  92. }
  93. // like scheme."id" scheme.`id`
  94. str = Dot + escape
  95. dotIndex = strings.Index(colName, str)
  96. if dotIndex > -1 {
  97. tempStr := strings.Builder{}
  98. tempStr.WriteString(escape)
  99. tempStr.WriteString(colName[0:dotIndex])
  100. tempStr.WriteString(escape)
  101. tempStr.WriteString(colName[dotIndex:])
  102. return tempStr.String()
  103. }
  104. str = Dot
  105. dotIndex = strings.Index(colName, str)
  106. if dotIndex > -1 {
  107. tempStr := strings.Builder{}
  108. tempStr.WriteString(escape)
  109. tempStr.WriteString(colName[0:dotIndex])
  110. tempStr.WriteString(escape)
  111. tempStr.WriteString(Dot)
  112. tempStr.WriteString(escape)
  113. tempStr.WriteString(colName[dotIndex+len(str):])
  114. tempStr.WriteString(escape)
  115. return tempStr.String()
  116. }
  117. }
  118. buf := make([]byte, len(colName)+2)
  119. buf[0], buf[len(buf)-1] = escape[0], escape[0]
  120. for key, _ := range colName {
  121. buf[key+1] = colName[key]
  122. }
  123. return string(buf)
  124. }
  125. // checkEscape check whether given field or table name use keywords. the method has database special logic.
  126. func checkEscape(colName string, dbType types.DBType) bool {
  127. switch dbType {
  128. case types.DBTypeMySQL:
  129. if _, ok := types.GetMysqlKeyWord()[strings.ToUpper(colName)]; ok {
  130. return true
  131. }
  132. return false
  133. // TODO impl Oracle PG SQLServer ...
  134. default:
  135. return true
  136. }
  137. }