|
- // Copyright 2015 PingCAP, Inc.
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // See the License for the specific language governing permissions and
- // limitations under the License.
-
- package optimizer
-
- import (
- "github.com/juju/errors"
- "github.com/pingcap/tidb/ast"
- "github.com/pingcap/tidb/mysql"
- "github.com/pingcap/tidb/parser"
- "github.com/pingcap/tidb/parser/opcode"
- )
-
- // Validate checkes whether the node is valid.
- func Validate(node ast.Node, inPrepare bool) error {
- v := validator{inPrepare: inPrepare}
- node.Accept(&v)
- return v.err
- }
-
- // validator is an ast.Visitor that validates
- // ast Nodes parsed from parser.
- type validator struct {
- err error
- wildCardCount int
- inPrepare bool
- inAggregate bool
- }
-
- func (v *validator) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
- switch in.(type) {
- case *ast.AggregateFuncExpr:
- if v.inAggregate {
- // Aggregate function can not contain aggregate function.
- v.err = ErrInvalidGroupFuncUse
- return in, true
- }
- v.inAggregate = true
- }
- return in, false
- }
-
- func (v *validator) Leave(in ast.Node) (out ast.Node, ok bool) {
- switch x := in.(type) {
- case *ast.AggregateFuncExpr:
- v.inAggregate = false
- case *ast.BetweenExpr:
- v.checkAllOneColumn(x.Expr, x.Left, x.Right)
- case *ast.BinaryOperationExpr:
- v.checkBinaryOperation(x)
- case *ast.ByItem:
- v.checkAllOneColumn(x.Expr)
- case *ast.CreateTableStmt:
- v.checkAutoIncrement(x)
- case *ast.CompareSubqueryExpr:
- v.checkSameColumns(x.L, x.R)
- case *ast.FieldList:
- v.checkFieldList(x)
- case *ast.HavingClause:
- v.checkAllOneColumn(x.Expr)
- case *ast.IsNullExpr:
- v.checkAllOneColumn(x.Expr)
- case *ast.IsTruthExpr:
- v.checkAllOneColumn(x.Expr)
- case *ast.ParamMarkerExpr:
- if !v.inPrepare {
- v.err = parser.ErrSyntax.Gen("syntax error, unexpected '?'")
- }
- case *ast.PatternInExpr:
- v.checkSameColumns(append(x.List, x.Expr)...)
- }
-
- return in, v.err == nil
- }
-
- // checkAllOneColumn checks that all expressions have one column.
- // Expression may have more than one column when it is a rowExpr or
- // a Subquery with more than one result fields.
- func (v *validator) checkAllOneColumn(exprs ...ast.ExprNode) {
- for _, expr := range exprs {
- switch x := expr.(type) {
- case *ast.RowExpr:
- v.err = ErrOneColumn
- case *ast.SubqueryExpr:
- if len(x.Query.GetResultFields()) != 1 {
- v.err = ErrOneColumn
- }
- }
- }
- return
- }
-
- func checkAutoIncrementOp(colDef *ast.ColumnDef, num int) (bool, error) {
- var hasAutoIncrement bool
-
- if colDef.Options[num].Tp == ast.ColumnOptionAutoIncrement {
- hasAutoIncrement = true
- if len(colDef.Options) == num+1 {
- return hasAutoIncrement, nil
- }
- for _, op := range colDef.Options[num+1:] {
- if op.Tp == ast.ColumnOptionDefaultValue {
- return hasAutoIncrement, errors.Errorf("Invalid default value for '%s'", colDef.Name.Name.O)
- }
- }
- }
- if colDef.Options[num].Tp == ast.ColumnOptionDefaultValue && len(colDef.Options) != num+1 {
- for _, op := range colDef.Options[num+1:] {
- if op.Tp == ast.ColumnOptionAutoIncrement {
- return hasAutoIncrement, errors.Errorf("Invalid default value for '%s'", colDef.Name.Name.O)
- }
- }
- }
-
- return hasAutoIncrement, nil
- }
-
- func isConstraintKeyTp(constraints []*ast.Constraint, colDef *ast.ColumnDef) bool {
- for _, c := range constraints {
- if len(c.Keys) < 1 {
- }
- // If the constraint as follows: primary key(c1, c2)
- // we only support c1 column can be auto_increment.
- if colDef.Name.Name.L != c.Keys[0].Column.Name.L {
- continue
- }
- switch c.Tp {
- case ast.ConstraintPrimaryKey, ast.ConstraintKey, ast.ConstraintIndex,
- ast.ConstraintUniq, ast.ConstraintUniqIndex, ast.ConstraintUniqKey:
- return true
- }
- }
-
- return false
- }
-
- func (v *validator) checkAutoIncrement(stmt *ast.CreateTableStmt) {
- var (
- isKey bool
- count int
- autoIncrementCol *ast.ColumnDef
- )
-
- for _, colDef := range stmt.Cols {
- var hasAutoIncrement bool
- for i, op := range colDef.Options {
- ok, err := checkAutoIncrementOp(colDef, i)
- if err != nil {
- v.err = err
- return
- }
- if ok {
- hasAutoIncrement = true
- }
- switch op.Tp {
- case ast.ColumnOptionPrimaryKey, ast.ColumnOptionUniqKey, ast.ColumnOptionUniqIndex,
- ast.ColumnOptionUniq, ast.ColumnOptionKey, ast.ColumnOptionIndex:
- isKey = true
- }
- }
- if hasAutoIncrement {
- count++
- autoIncrementCol = colDef
- }
- }
-
- if count < 1 {
- return
- }
-
- if !isKey {
- isKey = isConstraintKeyTp(stmt.Constraints, autoIncrementCol)
- }
- if !isKey || count > 1 {
- v.err = errors.New("Incorrect table definition; there can be only one auto column and it must be defined as a key")
- }
-
- switch autoIncrementCol.Tp.Tp {
- case mysql.TypeTiny, mysql.TypeShort, mysql.TypeLong,
- mysql.TypeFloat, mysql.TypeDouble, mysql.TypeLonglong, mysql.TypeInt24:
- default:
- v.err = errors.Errorf("Incorrect column specifier for column '%s'", autoIncrementCol.Name.Name.O)
- }
- }
-
- func (v *validator) checkBinaryOperation(x *ast.BinaryOperationExpr) {
- // row constructor only supports comparison operation.
- switch x.Op {
- case opcode.LT, opcode.LE, opcode.GE, opcode.GT, opcode.EQ, opcode.NE, opcode.NullEQ:
- v.checkSameColumns(x.L, x.R)
- default:
- v.checkAllOneColumn(x.L, x.R)
- }
- }
-
- func columnCount(ex ast.ExprNode) int {
- switch x := ex.(type) {
- case *ast.RowExpr:
- return len(x.Values)
- case *ast.SubqueryExpr:
- return len(x.Query.GetResultFields())
- default:
- return 1
- }
- }
-
- func (v *validator) checkSameColumns(exprs ...ast.ExprNode) {
- if len(exprs) == 0 {
- return
- }
- count := columnCount(exprs[0])
- for i := 1; i < len(exprs); i++ {
- if columnCount(exprs[i]) != count {
- v.err = ErrSameColumns
- return
- }
- }
- }
-
- // checkFieldList checks if there is only one '*' and each field has only one column.
- func (v *validator) checkFieldList(x *ast.FieldList) {
- var hasWildCard bool
- for _, val := range x.Fields {
- if val.WildCard != nil && val.WildCard.Table.L == "" {
- if hasWildCard {
- v.err = ErrMultiWildCard
- return
- }
- hasWildCard = true
- }
- v.checkAllOneColumn(val.Expr)
- if v.err != nil {
- return
- }
- }
- }
|