Skip to content

Commit

Permalink
expression: support expression reverse evaluation framework (pingcap#…
Browse files Browse the repository at this point in the history
  • Loading branch information
lzmhhh123 authored and sre-bot committed Dec 5, 2019
1 parent 6b6a698 commit 7de6200
Show file tree
Hide file tree
Showing 11 changed files with 405 additions and 73 deletions.
41 changes: 41 additions & 0 deletions expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"github.com/pingcap/parser/mysql"
"github.com/pingcap/parser/opcode"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/json"
"github.com/pingcap/tidb/util/chunk"
Expand All @@ -48,6 +49,9 @@ type baseBuiltinFunc struct {

childrenVectorizedOnce *sync.Once
childrenVectorized bool

childrenReversedOnce *sync.Once
childrenReversed bool
}

func (b *baseBuiltinFunc) PbCode() tipb.ScalarFuncSig {
Expand All @@ -74,6 +78,7 @@ func newBaseBuiltinFunc(ctx sessionctx.Context, args []Expression) baseBuiltinFu
return baseBuiltinFunc{
bufAllocator: newLocalSliceBuffer(len(args)),
childrenVectorizedOnce: new(sync.Once),
childrenReversedOnce: new(sync.Once),

args: args,
ctx: ctx,
Expand Down Expand Up @@ -179,6 +184,7 @@ func newBaseBuiltinFuncWithTp(ctx sessionctx.Context, args []Expression, retType
return baseBuiltinFunc{
bufAllocator: newLocalSliceBuffer(len(args)),
childrenVectorizedOnce: new(sync.Once),
childrenReversedOnce: new(sync.Once),

args: args,
ctx: ctx,
Expand Down Expand Up @@ -250,6 +256,27 @@ func (b *baseBuiltinFunc) vectorized() bool {
return false
}

// supportReverseEval reports whether this builtin function supports reverse evaluation.
// The base implementation always returns false.
// NOTE(review): signatures that can be reversed presumably override this method — confirm.
func (b *baseBuiltinFunc) supportReverseEval() bool {
return false
}

// isChildrenReversed reports whether every argument of the function supports
// reverse evaluation. The arguments are inspected only on the first call; the
// outcome is stored in childrenReversed and returned on every later call.
func (b *baseBuiltinFunc) isChildrenReversed() bool {
	b.childrenReversedOnce.Do(func() {
		// Start optimistic and flip to false on the first unsupported argument.
		b.childrenReversed = true
		for i := range b.args {
			if b.args[i].SupportReverseEval() {
				continue
			}
			b.childrenReversed = false
			return
		}
	})
	return b.childrenReversed
}

// reverseEval is the fallback reverse evaluation of baseBuiltinFunc.
// It ignores sc, res and rType and always returns an empty Datum together
// with an error stating that this method should never be called.
func (b *baseBuiltinFunc) reverseEval(sc *stmtctx.StatementContext, res types.Datum, rType types.RoundingType) (types.Datum, error) {
	// The message names this method (reverseEval) so the error can be traced to its origin.
	return types.Datum{}, errors.Errorf("baseBuiltinFunc.reverseEval() should never be called, please contact the TiDB team for help")
}

func (b *baseBuiltinFunc) isChildrenVectorized() bool {
b.childrenVectorizedOnce.Do(func() {
b.childrenVectorized = true
Expand Down Expand Up @@ -305,6 +332,7 @@ func (b *baseBuiltinFunc) cloneFrom(from *baseBuiltinFunc) {
b.pbCode = from.pbCode
b.bufAllocator = newLocalSliceBuffer(len(b.args))
b.childrenVectorizedOnce = new(sync.Once)
b.childrenReversedOnce = new(sync.Once)
}

func (b *baseBuiltinFunc) Clone() builtinFunc {
Expand Down Expand Up @@ -372,9 +400,22 @@ type vecBuiltinFunc interface {
vecEvalJSON(input *chunk.Chunk, result *chunk.Column) error
}

// reverseBuiltinFunc computes the value of the single column inside a function when given the result of the expression.
// For example, if the builtinFunc is builtinArithmeticPlusRealSig(2.3, builtinArithmeticMinusRealSig(Column, 3.4))
// and the given result is 1.0, then reverse evaluation should yield the column value 1.0 - 2.3 + 3.4 = 2.1.
type reverseBuiltinFunc interface {
// supportReverseEval checks whether the builtinFunc supports reverse evaluation.
supportReverseEval() bool
// isChildrenReversed checks whether all of the builtinFunc's children support reverse evaluation.
isChildrenReversed() bool
// reverseEval evaluates the single column value from the given function result.
reverseEval(sc *stmtctx.StatementContext, res types.Datum, rType types.RoundingType) (val types.Datum, err error)
}

// builtinFunc stands for a particular function signature.
type builtinFunc interface {
vecBuiltinFunc
reverseBuiltinFunc

// evalInt evaluates int result of builtinFunc by given row.
evalInt(row chunk.Row) (val int64, isNull bool, err error)
Expand Down
15 changes: 15 additions & 0 deletions expression/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -581,3 +581,18 @@ idLoop:
func (col *Column) EvalVirtualColumn(row chunk.Row) (types.Datum, error) {
// Delegate to the column's VirtualExpr, evaluated against the given row.
return col.VirtualExpr.Eval(row)
}

// SupportReverseEval reports whether the column supports reverse evaluation.
// It returns true only when the column's return type is one of the listed
// integer types (short, long, longlong) or real types (float, double, new decimal).
func (col *Column) SupportReverseEval() bool {
	tp := col.RetType.Tp
	isInteger := tp == mysql.TypeShort || tp == mysql.TypeLong || tp == mysql.TypeLonglong
	isReal := tp == mysql.TypeFloat || tp == mysql.TypeDouble || tp == mysql.TypeNewDecimal
	return isInteger || isReal
}

// ReverseEval derives the column value from the given expression result res.
// It passes sc, the column's RetType, res and rType to types.ChangeReverseResultByUpperLowerBound
// and returns that helper's datum and error unchanged.
// NOTE(review): judging by the helper's name, res is adjusted to the column type's upper/lower bound — confirm.
func (col *Column) ReverseEval(sc *stmtctx.StatementContext, res types.Datum, rType types.RoundingType) (val types.Datum, err error) {
return types.ChangeReverseResultByUpperLowerBound(sc, col.RetType, res, rType)
}
13 changes: 13 additions & 0 deletions expression/constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -448,3 +448,16 @@ func (c *Constant) Vectorized() bool {
}
return true
}

// SupportReverseEval reports whether the constant supports reverse evaluation.
// A constant without a deferred expression always does; otherwise the answer
// is taken from the deferred expression.
func (c *Constant) SupportReverseEval() bool {
	if c.DeferredExpr == nil {
		return true
	}
	return c.DeferredExpr.SupportReverseEval()
}

// ReverseEval returns the constant's stored Value with a nil error.
// The sc, res and rType arguments are not used, and DeferredExpr is not consulted here.
func (c *Constant) ReverseEval(sc *stmtctx.StatementContext, res types.Datum, rType types.RoundingType) (val types.Datum, err error) {
return c.Value, nil
}
10 changes: 10 additions & 0 deletions expression/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,21 @@ type VecExpr interface {
VecEvalJSON(ctx sessionctx.Context, input *chunk.Chunk, result *chunk.Column) error
}

// ReverseExpr contains all reverse evaluation methods.
type ReverseExpr interface {
// SupportReverseEval checks whether the expression supports reverse evaluation.
SupportReverseEval() bool

// ReverseEval evaluates the single column value from the given function result.
ReverseEval(sc *stmtctx.StatementContext, res types.Datum, rType types.RoundingType) (val types.Datum, err error)
}

// Expression represents all scalar expression in SQL.
type Expression interface {
fmt.Stringer
goJSON.Marshaler
VecExpr
ReverseExpr

// Eval evaluates an expression through a row.
Eval(row chunk.Row) (types.Datum, error)
Expand Down
15 changes: 15 additions & 0 deletions expression/scalar_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,21 @@ func (sf *ScalarFunction) Vectorized() bool {
return sf.Function.vectorized() && sf.Function.isChildrenVectorized()
}

// SupportReverseEval reports whether this scalar function supports reverse evaluation.
// The return type must be one of the listed integer or real types, the underlying
// builtin function must support reverse evaluation, and so must all of its children.
func (sf *ScalarFunction) SupportReverseEval() bool {
	tp := sf.RetType.Tp
	isInteger := tp == mysql.TypeShort || tp == mysql.TypeLong || tp == mysql.TypeLonglong
	isReal := tp == mysql.TypeFloat || tp == mysql.TypeDouble || tp == mysql.TypeNewDecimal
	if !isInteger && !isReal {
		return false
	}
	return sf.Function.supportReverseEval() && sf.Function.isChildrenReversed()
}

// ReverseEval evaluates the single column value from the given function result.
// It forwards sc, res and rType to the underlying builtin function's reverseEval
// and returns that call's datum and error unchanged.
func (sf *ScalarFunction) ReverseEval(sc *stmtctx.StatementContext, res types.Datum, rType types.RoundingType) (val types.Datum, err error) {
return sf.Function.reverseEval(sc, res, rType)
}

// GetCtx gets the context of function.
func (sf *ScalarFunction) GetCtx() sessionctx.Context {
return sf.Function.getCtx()
Expand Down
4 changes: 4 additions & 0 deletions expression/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,9 @@ func (m *MockExpr) EvalJSON(ctx sessionctx.Context, row chunk.Row) (val json.Bin
}
return json.BinaryJSON{}, m.i == nil, m.err
}
// ReverseEval is the mock implementation of reverse evaluation: it ignores its
// arguments and returns an empty Datum together with the mock's configured err.
func (m *MockExpr) ReverseEval(sc *stmtctx.StatementContext, res types.Datum, rType types.RoundingType) (val types.Datum, err error) {
return types.Datum{}, m.err
}
func (m *MockExpr) GetType() *types.FieldType { return m.t }
func (m *MockExpr) Clone() Expression { return nil }
func (m *MockExpr) Equal(ctx sessionctx.Context, e Expression) bool { return false }
Expand All @@ -513,3 +516,4 @@ func (m *MockExpr) ExplainInfo() string { return "
func (m *MockExpr) ExplainNormalizedInfo() string { return "" }
func (m *MockExpr) HashCode(sc *stmtctx.StatementContext) []byte { return nil }
func (m *MockExpr) Vectorized() bool { return false }
// SupportReverseEval always returns false for the mock expression.
func (m *MockExpr) SupportReverseEval() bool { return false }
72 changes: 1 addition & 71 deletions statistics/feedback.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ func buildBucketFeedback(h *Histogram, feedback *QueryFeedback) (map[int]*Bucket
}
total := 0
sc := &stmtctx.StatementContext{TimeZone: time.UTC}
min, max := GetMinValue(h.Tp), GetMaxValue(h.Tp)
min, max := types.GetMinValue(h.Tp), types.GetMaxValue(h.Tp)
for _, fb := range feedback.Feedback {
skip, err := fb.adjustFeedbackBoundaries(sc, &min, &max)
if err != nil {
Expand Down Expand Up @@ -927,73 +927,3 @@ func SupportColumnType(ft *types.FieldType) bool {
}
return false
}

// GetMaxValue builds the maximum datum for the given field type.
// Field types not covered by the switch leave the result as the zero-value Datum.
func GetMaxValue(ft *types.FieldType) (max types.Datum) {
	switch ft.Tp {
	case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong:
		// Signed and unsigned integers have different upper bounds.
		if !mysql.HasUnsignedFlag(ft.Flag) {
			max.SetInt64(types.IntergerSignedUpperBound(ft.Tp))
		} else {
			max.SetUint64(types.IntergerUnsignedUpperBound(ft.Tp))
		}
	case mysql.TypeFloat:
		upper := types.GetMaxFloat(ft.Flen, ft.Decimal)
		max.SetFloat32(float32(upper))
	case mysql.TypeDouble:
		upper := types.GetMaxFloat(ft.Flen, ft.Decimal)
		max.SetFloat64(upper)
	case mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob:
		// String-like types use the encoded key of the max-value datum.
		maxDatum := types.MaxValueDatum()
		encoded, err := codec.EncodeKey(nil, nil, maxDatum)
		if err != nil {
			// should not happen
			logutil.BgLogger().Error("encode key fail", zap.Error(err))
		}
		max.SetBytes(encoded)
	case mysql.TypeNewDecimal:
		max.SetMysqlDecimal(types.NewMaxOrMinDec(false, ft.Flen, ft.Decimal))
	case mysql.TypeDuration:
		max.SetMysqlDuration(types.Duration{Duration: types.MaxTime})
	case mysql.TypeDate, mysql.TypeDatetime:
		max.SetMysqlTime(types.Time{Time: types.MaxDatetime, Type: ft.Tp})
	case mysql.TypeTimestamp:
		max.SetMysqlTime(types.MaxTimestamp)
	}
	return max
}

// GetMinValue builds the minimum datum for the given field type.
// Field types not covered by the switch leave the result as the zero-value Datum.
func GetMinValue(ft *types.FieldType) (min types.Datum) {
	switch ft.Tp {
	case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong:
		// Unsigned integers bottom out at zero; signed ones use the type's lower bound.
		if !mysql.HasUnsignedFlag(ft.Flag) {
			min.SetInt64(types.IntergerSignedLowerBound(ft.Tp))
		} else {
			min.SetUint64(0)
		}
	case mysql.TypeFloat:
		lower := -types.GetMaxFloat(ft.Flen, ft.Decimal)
		min.SetFloat32(float32(lower))
	case mysql.TypeDouble:
		lower := -types.GetMaxFloat(ft.Flen, ft.Decimal)
		min.SetFloat64(lower)
	case mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob:
		// String-like types use the encoded key of the min-not-null datum.
		minDatum := types.MinNotNullDatum()
		encoded, err := codec.EncodeKey(nil, nil, minDatum)
		if err != nil {
			// should not happen
			logutil.BgLogger().Error("encode key fail", zap.Error(err))
		}
		min.SetBytes(encoded)
	case mysql.TypeNewDecimal:
		min.SetMysqlDecimal(types.NewMaxOrMinDec(true, ft.Flen, ft.Decimal))
	case mysql.TypeDuration:
		min.SetMysqlDuration(types.Duration{Duration: types.MinTime})
	case mysql.TypeDate, mysql.TypeDatetime:
		min.SetMysqlTime(types.Time{Time: types.MinDatetime, Type: ft.Tp})
	case mysql.TypeTimestamp:
		min.SetMysqlTime(types.MinTimestamp)
	}
	return min
}
4 changes: 2 additions & 2 deletions statistics/handle/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -948,10 +948,10 @@ func (h *Handle) dumpRangeFeedback(sc *stmtctx.StatementContext, ran *ranger.Ran
return nil
}
if ran.LowVal[0].Kind() == types.KindMinNotNull {
ran.LowVal[0] = statistics.GetMinValue(q.Hist.Tp)
ran.LowVal[0] = types.GetMinValue(q.Hist.Tp)
}
if ran.HighVal[0].Kind() == types.KindMaxValue {
ran.HighVal[0] = statistics.GetMaxValue(q.Hist.Tp)
ran.HighVal[0] = types.GetMaxValue(q.Hist.Tp)
}
}
ranges, ok := q.Hist.SplitRange(sc, []*ranger.Range{ran}, q.Tp == statistics.IndexType)
Expand Down
Loading

0 comments on commit 7de6200

Please sign in to comment.