You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

parser_factory.go 2.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. /*
  2. * Licensed to the Apache Software Foundation (ASF) under one or more
  3. * contributor license agreements. See the NOTICE file distributed with
  4. * this work for additional information regarding copyright ownership.
  5. * The ASF licenses this file to You under the Apache License, Version 2.0
  6. * (the "License"); you may not use this file except in compliance with
  7. * the License. You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. */
  17. package parser
  18. import (
  19. aparser "github.com/arana-db/parser"
  20. "github.com/arana-db/parser/ast"
  21. "seata.apache.org/seata-go/pkg/datasource/sql/types"
  22. )
  23. func DoParser(query string) (*types.ParseContext, error) {
  24. p := aparser.New()
  25. stmtNodes, _, err := p.Parse(query, "", "")
  26. if err != nil {
  27. return nil, err
  28. }
  29. if len(stmtNodes) == 1 {
  30. return parseParseContext(stmtNodes[0]), err
  31. }
  32. parserCtx := types.ParseContext{
  33. SQLType: types.SQLTypeMulti,
  34. ExecutorType: types.MultiExecutor,
  35. MultiStmt: make([]*types.ParseContext, 0, len(stmtNodes)),
  36. }
  37. for _, node := range stmtNodes {
  38. parserCtx.MultiStmt = append(parserCtx.MultiStmt, parseParseContext(node))
  39. }
  40. return &parserCtx, nil
  41. }
  42. func parseParseContext(stmtNode ast.StmtNode) *types.ParseContext {
  43. parserCtx := new(types.ParseContext)
  44. switch stmt := stmtNode.(type) {
  45. case *ast.InsertStmt:
  46. parserCtx.SQLType = types.SQLTypeInsert
  47. parserCtx.InsertStmt = stmt
  48. parserCtx.ExecutorType = types.InsertExecutor
  49. if stmt.IsReplace {
  50. parserCtx.ExecutorType = types.ReplaceIntoExecutor
  51. }
  52. if len(stmt.OnDuplicate) != 0 {
  53. parserCtx.SQLType = types.SQLTypeInsertOnDuplicateUpdate
  54. parserCtx.ExecutorType = types.InsertOnDuplicateExecutor
  55. }
  56. case *ast.UpdateStmt:
  57. parserCtx.SQLType = types.SQLTypeUpdate
  58. parserCtx.UpdateStmt = stmt
  59. parserCtx.ExecutorType = types.UpdateExecutor
  60. case *ast.SelectStmt:
  61. if stmt.LockInfo != nil && stmt.LockInfo.LockType == ast.SelectLockForUpdate {
  62. parserCtx.SQLType = types.SQLTypeSelectForUpdate
  63. parserCtx.SelectStmt = stmt
  64. parserCtx.ExecutorType = types.SelectForUpdateExecutor
  65. } else {
  66. parserCtx.SQLType = types.SQLTypeSelect
  67. parserCtx.SelectStmt = stmt
  68. parserCtx.ExecutorType = types.SelectExecutor
  69. }
  70. case *ast.DeleteStmt:
  71. parserCtx.SQLType = types.SQLTypeDelete
  72. parserCtx.DeleteStmt = stmt
  73. parserCtx.ExecutorType = types.DeleteExecutor
  74. }
  75. return parserCtx
  76. }