diff --git a/expression/builtin.go b/expression/builtin.go index 001f91e471597..9b192d2d2bb29 100644 --- a/expression/builtin.go +++ b/expression/builtin.go @@ -32,6 +32,7 @@ import ( "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/opcode" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/types/json" "github.com/pingcap/tidb/util/chunk" @@ -48,6 +49,9 @@ type baseBuiltinFunc struct { childrenVectorizedOnce *sync.Once childrenVectorized bool + + childrenReversedOnce *sync.Once + childrenReversed bool } func (b *baseBuiltinFunc) PbCode() tipb.ScalarFuncSig { @@ -74,6 +78,7 @@ func newBaseBuiltinFunc(ctx sessionctx.Context, args []Expression) baseBuiltinFu return baseBuiltinFunc{ bufAllocator: newLocalSliceBuffer(len(args)), childrenVectorizedOnce: new(sync.Once), + childrenReversedOnce: new(sync.Once), args: args, ctx: ctx, @@ -179,6 +184,7 @@ func newBaseBuiltinFuncWithTp(ctx sessionctx.Context, args []Expression, retType return baseBuiltinFunc{ bufAllocator: newLocalSliceBuffer(len(args)), childrenVectorizedOnce: new(sync.Once), + childrenReversedOnce: new(sync.Once), args: args, ctx: ctx, @@ -250,6 +256,27 @@ func (b *baseBuiltinFunc) vectorized() bool { return false } +func (b *baseBuiltinFunc) supportReverseEval() bool { + return false +} + +func (b *baseBuiltinFunc) isChildrenReversed() bool { + b.childrenReversedOnce.Do(func() { + b.childrenReversed = true + for _, arg := range b.args { + if !arg.SupportReverseEval() { + b.childrenReversed = false + break + } + } + }) + return b.childrenReversed +} + +func (b *baseBuiltinFunc) reverseEval(sc *stmtctx.StatementContext, res types.Datum, rType types.RoundingType) (types.Datum, error) { + return types.Datum{}, errors.Errorf("baseBuiltinFunc.reverseEvalInt() should never be called, please contact the TiDB team for help") +} + func (b *baseBuiltinFunc) isChildrenVectorized() bool { b.childrenVectorizedOnce.Do(func() { b.childrenVectorized = true @@ -305,6 +332,7 @@ func (b *baseBuiltinFunc) cloneFrom(from *baseBuiltinFunc) { b.pbCode = from.pbCode b.bufAllocator = newLocalSliceBuffer(len(b.args)) b.childrenVectorizedOnce = new(sync.Once) + b.childrenReversedOnce = new(sync.Once) } func (b *baseBuiltinFunc) Clone() builtinFunc { @@ -372,9 +400,22 @@ type vecBuiltinFunc interface { vecEvalJSON(input *chunk.Chunk, result *chunk.Column) error } +// reverseBuiltinFunc evaluates the exactly one column value in the function when given a result for expression. +// For example, the buitinFunc is builtinArithmeticPlusRealSig(2.3, builtinArithmeticMinusRealSig(Column, 3.4)) +// when given the result like 1.0, then the ReverseEval should evaluate the column value 1.0 - 2.3 + 3.4 = 2.1 +type reverseBuiltinFunc interface { + // supportReverseEval checks whether the builtinFunc support reverse evaluation. + supportReverseEval() bool + // isChildrenReversed checks whether the builtinFunc's children support reverse evaluation. + isChildrenReversed() bool + // reverseEval evaluates the only one column value with given function result. + reverseEval(sc *stmtctx.StatementContext, res types.Datum, rType types.RoundingType) (val types.Datum, err error) +} + // builtinFunc stands for a particular function signature. type builtinFunc interface { vecBuiltinFunc + reverseBuiltinFunc // evalInt evaluates int result of builtinFunc by given row. evalInt(row chunk.Row) (val int64, isNull bool, err error) diff --git a/expression/column.go b/expression/column.go index 8b0fa3f96d3b8..98601053c328d 100644 --- a/expression/column.go +++ b/expression/column.go @@ -581,3 +581,18 @@ idLoop: func (col *Column) EvalVirtualColumn(row chunk.Row) (types.Datum, error) { return col.VirtualExpr.Eval(row) } + +// SupportReverseEval checks whether the builtinFunc support reverse evaluation. +func (col *Column) SupportReverseEval() bool { + switch col.RetType.Tp { + case mysql.TypeShort, mysql.TypeLong, mysql.TypeLonglong, + mysql.TypeFloat, mysql.TypeDouble, mysql.TypeNewDecimal: + return true + } + return false +} + +// ReverseEval evaluates the only one column value with given function result. +func (col *Column) ReverseEval(sc *stmtctx.StatementContext, res types.Datum, rType types.RoundingType) (val types.Datum, err error) { + return types.ChangeReverseResultByUpperLowerBound(sc, col.RetType, res, rType) +} diff --git a/expression/constant.go b/expression/constant.go index be33d1f17ce41..5104b564072fc 100644 --- a/expression/constant.go +++ b/expression/constant.go @@ -448,3 +448,16 @@ func (c *Constant) Vectorized() bool { } return true } + +// SupportReverseEval checks whether the builtinFunc support reverse evaluation. +func (c *Constant) SupportReverseEval() bool { + if c.DeferredExpr != nil { + return c.DeferredExpr.SupportReverseEval() + } + return true +} + +// ReverseEval evaluates the only one column value with given function result. +func (c *Constant) ReverseEval(sc *stmtctx.StatementContext, res types.Datum, rType types.RoundingType) (val types.Datum, err error) { + return c.Value, nil +} diff --git a/expression/expression.go b/expression/expression.go index f8ab93986b245..2bb31b995c9f5 100644 --- a/expression/expression.go +++ b/expression/expression.go @@ -68,11 +68,21 @@ type VecExpr interface { VecEvalJSON(ctx sessionctx.Context, input *chunk.Chunk, result *chunk.Column) error } +// ReverseExpr contains all resersed evaluation methods. +type ReverseExpr interface { + // SupportReverseEval checks whether the builtinFunc support reverse evaluation. + SupportReverseEval() bool + + // ReverseEval evaluates the only one column value with given function result. + ReverseEval(sc *stmtctx.StatementContext, res types.Datum, rType types.RoundingType) (val types.Datum, err error) +} + // Expression represents all scalar expression in SQL. type Expression interface { fmt.Stringer goJSON.Marshaler VecExpr + ReverseExpr // Eval evaluates an expression through a row. Eval(row chunk.Row) (types.Datum, error) diff --git a/expression/scalar_function.go b/expression/scalar_function.go index 29a5933741b03..a274215f0c88b 100755 --- a/expression/scalar_function.go +++ b/expression/scalar_function.go @@ -86,6 +86,21 @@ func (sf *ScalarFunction) Vectorized() bool { return sf.Function.vectorized() && sf.Function.isChildrenVectorized() } +// SupportReverseEval returns if this expression supports reversed evaluation. +func (sf *ScalarFunction) SupportReverseEval() bool { + switch sf.RetType.Tp { + case mysql.TypeShort, mysql.TypeLong, mysql.TypeLonglong, + mysql.TypeFloat, mysql.TypeDouble, mysql.TypeNewDecimal: + return sf.Function.supportReverseEval() && sf.Function.isChildrenReversed() + } + return false +} + +// ReverseEval evaluates the only one column value with given function result. +func (sf *ScalarFunction) ReverseEval(sc *stmtctx.StatementContext, res types.Datum, rType types.RoundingType) (val types.Datum, err error) { + return sf.Function.reverseEval(sc, res, rType) +} + // GetCtx gets the context of function. func (sf *ScalarFunction) GetCtx() sessionctx.Context { return sf.Function.getCtx() diff --git a/expression/util_test.go b/expression/util_test.go index 13bea0dc9ec8c..4abda426a4414 100644 --- a/expression/util_test.go +++ b/expression/util_test.go @@ -501,6 +501,9 @@ func (m *MockExpr) EvalJSON(ctx sessionctx.Context, row chunk.Row) (val json.Bin } return json.BinaryJSON{}, m.i == nil, m.err } +func (m *MockExpr) ReverseEval(sc *stmtctx.StatementContext, res types.Datum, rType types.RoundingType) (val types.Datum, err error) { + return types.Datum{}, m.err +} func (m *MockExpr) GetType() *types.FieldType { return m.t } func (m *MockExpr) Clone() Expression { return nil } func (m *MockExpr) Equal(ctx sessionctx.Context, e Expression) bool { return false } @@ -513,3 +516,4 @@ func (m *MockExpr) ExplainInfo() string { return " func (m *MockExpr) ExplainNormalizedInfo() string { return "" } func (m *MockExpr) HashCode(sc *stmtctx.StatementContext) []byte { return nil } func (m *MockExpr) Vectorized() bool { return false } +func (m *MockExpr) SupportReverseEval() bool { return false } diff --git a/statistics/feedback.go b/statistics/feedback.go index 3db3ae09e43e4..cfc371947e23c 100644 --- a/statistics/feedback.go +++ b/statistics/feedback.go @@ -302,7 +302,7 @@ func buildBucketFeedback(h *Histogram, feedback *QueryFeedback) (map[int]*Bucket } total := 0 sc := &stmtctx.StatementContext{TimeZone: time.UTC} - min, max := GetMinValue(h.Tp), GetMaxValue(h.Tp) + min, max := types.GetMinValue(h.Tp), types.GetMaxValue(h.Tp) for _, fb := range feedback.Feedback { skip, err := fb.adjustFeedbackBoundaries(sc, &min, &max) if err != nil { @@ -927,73 +927,3 @@ func SupportColumnType(ft *types.FieldType) bool { } return false } - -// GetMaxValue returns the max value datum for each type. -func GetMaxValue(ft *types.FieldType) (max types.Datum) { - switch ft.Tp { - case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong: - if mysql.HasUnsignedFlag(ft.Flag) { - max.SetUint64(types.IntergerUnsignedUpperBound(ft.Tp)) - } else { - max.SetInt64(types.IntergerSignedUpperBound(ft.Tp)) - } - case mysql.TypeFloat: - max.SetFloat32(float32(types.GetMaxFloat(ft.Flen, ft.Decimal))) - case mysql.TypeDouble: - max.SetFloat64(types.GetMaxFloat(ft.Flen, ft.Decimal)) - case mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: - val := types.MaxValueDatum() - bytes, err := codec.EncodeKey(nil, nil, val) - // should not happen - if err != nil { - logutil.BgLogger().Error("encode key fail", zap.Error(err)) - } - max.SetBytes(bytes) - case mysql.TypeNewDecimal: - max.SetMysqlDecimal(types.NewMaxOrMinDec(false, ft.Flen, ft.Decimal)) - case mysql.TypeDuration: - max.SetMysqlDuration(types.Duration{Duration: types.MaxTime}) - case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp: - if ft.Tp == mysql.TypeDate || ft.Tp == mysql.TypeDatetime { - max.SetMysqlTime(types.Time{Time: types.MaxDatetime, Type: ft.Tp}) - } else { - max.SetMysqlTime(types.MaxTimestamp) - } - } - return -} - -// GetMinValue returns the min value datum for each type. -func GetMinValue(ft *types.FieldType) (min types.Datum) { - switch ft.Tp { - case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong: - if mysql.HasUnsignedFlag(ft.Flag) { - min.SetUint64(0) - } else { - min.SetInt64(types.IntergerSignedLowerBound(ft.Tp)) - } - case mysql.TypeFloat: - min.SetFloat32(float32(-types.GetMaxFloat(ft.Flen, ft.Decimal))) - case mysql.TypeDouble: - min.SetFloat64(-types.GetMaxFloat(ft.Flen, ft.Decimal)) - case mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: - val := types.MinNotNullDatum() - bytes, err := codec.EncodeKey(nil, nil, val) - // should not happen - if err != nil { - logutil.BgLogger().Error("encode key fail", zap.Error(err)) - } - min.SetBytes(bytes) - case mysql.TypeNewDecimal: - min.SetMysqlDecimal(types.NewMaxOrMinDec(true, ft.Flen, ft.Decimal)) - case mysql.TypeDuration: - min.SetMysqlDuration(types.Duration{Duration: types.MinTime}) - case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp: - if ft.Tp == mysql.TypeDate || ft.Tp == mysql.TypeDatetime { - min.SetMysqlTime(types.Time{Time: types.MinDatetime, Type: ft.Tp}) - } else { - min.SetMysqlTime(types.MinTimestamp) - } - } - return -} diff --git a/statistics/handle/update.go b/statistics/handle/update.go index 766fd739f5f76..73187a49dfe44 100644 --- a/statistics/handle/update.go +++ b/statistics/handle/update.go @@ -948,10 +948,10 @@ func (h *Handle) dumpRangeFeedback(sc *stmtctx.StatementContext, ran *ranger.Ran return nil } if ran.LowVal[0].Kind() == types.KindMinNotNull { - ran.LowVal[0] = statistics.GetMinValue(q.Hist.Tp) + ran.LowVal[0] = types.GetMinValue(q.Hist.Tp) } if ran.HighVal[0].Kind() == types.KindMaxValue { - ran.HighVal[0] = statistics.GetMaxValue(q.Hist.Tp) + ran.HighVal[0] = types.GetMaxValue(q.Hist.Tp) } } ranges, ok := q.Hist.SplitRange(sc, []*ranger.Range{ran}, q.Tp == statistics.IndexType) diff --git a/types/datum.go b/types/datum.go index 0b3e9148d1809..525ce3be7b2b4 100644 --- a/types/datum.go +++ b/types/datum.go @@ -1878,3 +1878,182 @@ func CloneRow(dr []Datum) []Datum { } return c } + +// GetMaxValue returns the max value datum for each type. +func GetMaxValue(ft *FieldType) (max Datum) { + switch ft.Tp { + case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong: + if mysql.HasUnsignedFlag(ft.Flag) { + max.SetUint64(IntergerUnsignedUpperBound(ft.Tp)) + } else { + max.SetInt64(IntergerSignedUpperBound(ft.Tp)) + } + case mysql.TypeFloat: + max.SetFloat32(float32(GetMaxFloat(ft.Flen, ft.Decimal))) + case mysql.TypeDouble: + max.SetFloat64(GetMaxFloat(ft.Flen, ft.Decimal)) + case mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: + // codec.Encode KindMaxValue, to avoid import circle + bytes := []byte{250} + max.SetBytes(bytes) + case mysql.TypeNewDecimal: + max.SetMysqlDecimal(NewMaxOrMinDec(false, ft.Flen, ft.Decimal)) + case mysql.TypeDuration: + max.SetMysqlDuration(Duration{Duration: MaxTime}) + case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp: + if ft.Tp == mysql.TypeDate || ft.Tp == mysql.TypeDatetime { + max.SetMysqlTime(Time{Time: MaxDatetime, Type: ft.Tp}) + } else { + max.SetMysqlTime(MaxTimestamp) + } + } + return +} + +// GetMinValue returns the min value datum for each type. +func GetMinValue(ft *FieldType) (min Datum) { + switch ft.Tp { + case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong: + if mysql.HasUnsignedFlag(ft.Flag) { + min.SetUint64(0) + } else { + min.SetInt64(IntergerSignedLowerBound(ft.Tp)) + } + case mysql.TypeFloat: + min.SetFloat32(float32(-GetMaxFloat(ft.Flen, ft.Decimal))) + case mysql.TypeDouble: + min.SetFloat64(-GetMaxFloat(ft.Flen, ft.Decimal)) + case mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: + // codec.Encode KindMinNotNull, to avoid import circle + bytes := []byte{1} + min.SetBytes(bytes) + case mysql.TypeNewDecimal: + min.SetMysqlDecimal(NewMaxOrMinDec(true, ft.Flen, ft.Decimal)) + case mysql.TypeDuration: + min.SetMysqlDuration(Duration{Duration: MinTime}) + case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp: + if ft.Tp == mysql.TypeDate || ft.Tp == mysql.TypeDatetime { + min.SetMysqlTime(Time{Time: MinDatetime, Type: ft.Tp}) + } else { + min.SetMysqlTime(MinTimestamp) + } + } + return +} + +// RoundingType is used to indicate the rounding type for reversing evaluation. +type RoundingType uint8 + +const ( + // Ceiling means rounding up. + Ceiling RoundingType = iota + // Floor means rounding down. + Floor +) + +func getDatumBound(retType *FieldType, rType RoundingType) Datum { + if rType == Ceiling { + return GetMaxValue(retType) + } + return GetMinValue(retType) +} + +// ChangeReverseResultByUpperLowerBound is for expression's reverse evaluation. +// Here is an example for what's effort for the function: CastRealAsInt(t.a), +// if the type of column `t.a` is mysql.TypeDouble, and there is a row that t.a == MaxFloat64 +// then the cast function will arrive a result MaxInt64. But when we do the reverse evaluation, +// if the result is MaxInt64, and the rounding type is ceiling. Then we should get the MaxFloat64 +// instead of float64(MaxInt64). +// Another example: cast(1.1 as signed) = 1, +// when we get the answer 1, we can only reversely evaluate 1.0 as the column value. So in this +// case, we should judge whether the rounding type are ceiling. If it is, then we should plus one for +// 1.0 and get the reverse result 2.0. +func ChangeReverseResultByUpperLowerBound( + sc *stmtctx.StatementContext, + retType *FieldType, + res Datum, + rType RoundingType) (Datum, error) { + d, err := res.ConvertTo(sc, retType) + if terror.ErrorEqual(err, ErrOverflow) { + return d, nil + } + if err != nil { + return d, err + } + resRetType := FieldType{} + switch res.Kind() { + case KindInt64: + resRetType.Tp = mysql.TypeLonglong + case KindUint64: + resRetType.Tp = mysql.TypeLonglong + resRetType.Flag |= mysql.UnsignedFlag + case KindFloat32: + resRetType.Tp = mysql.TypeFloat + case KindFloat64: + resRetType.Tp = mysql.TypeDouble + case KindMysqlDecimal: + resRetType.Tp = mysql.TypeNewDecimal + resRetType.Flen = int(res.GetMysqlDecimal().GetDigitsFrac() + res.GetMysqlDecimal().GetDigitsInt()) + resRetType.Decimal = int(res.GetMysqlDecimal().GetDigitsInt()) + } + bound := getDatumBound(&resRetType, rType) + cmp, err := d.CompareDatum(sc, &bound) + if err != nil { + return d, err + } + if cmp == 0 { + d = getDatumBound(retType, rType) + } else if rType == Ceiling { + switch retType.Tp { + case mysql.TypeShort: + if mysql.HasUnsignedFlag(retType.Flag) { + if d.GetUint64() != math.MaxUint16 { + d.SetUint64(d.GetUint64() + 1) + } + } else { + if d.GetInt64() != math.MaxInt16 { + d.SetInt64(d.GetInt64() + 1) + } + } + case mysql.TypeLong: + if mysql.HasUnsignedFlag(retType.Flag) { + if d.GetUint64() != math.MaxUint32 { + d.SetUint64(d.GetUint64() + 1) + } + } else { + if d.GetInt64() != math.MaxInt32 { + d.SetInt64(d.GetInt64() + 1) + } + } + case mysql.TypeLonglong: + if mysql.HasUnsignedFlag(retType.Flag) { + if d.GetUint64() != math.MaxUint64 { + d.SetUint64(d.GetUint64() + 1) + } + } else { + if d.GetInt64() != math.MaxInt64 { + d.SetInt64(d.GetInt64() + 1) + } + } + case mysql.TypeFloat: + if d.GetFloat32() != math.MaxFloat32 { + d.SetFloat32(d.GetFloat32() + 1.0) + } + case mysql.TypeDouble: + if d.GetFloat64() != math.MaxFloat64 { + d.SetFloat64(d.GetFloat64() + 1.0) + } + case mysql.TypeNewDecimal: + if d.GetMysqlDecimal().Compare(NewMaxOrMinDec(false, retType.Flen, retType.Decimal)) != 0 { + var decimalOne, newD MyDecimal + one := decimalOne.FromInt(1) + err = DecimalAdd(d.GetMysqlDecimal(), one, &newD) + if err != nil { + return d, err + } + d = NewDecimalDatum(&newD) + } + } + } + return d, nil +} diff --git a/types/datum_test.go b/types/datum_test.go index 40d6ae0dd36a0..09b612e47e681 100644 --- a/types/datum_test.go +++ b/types/datum_test.go @@ -15,7 +15,9 @@ package types import ( "fmt" + "math" "reflect" + "strconv" "testing" "time" @@ -373,6 +375,124 @@ func (ts *testDatumSuite) TestCloneDatum(c *C) { } } +func newTypeWithFlag(tp byte, flag uint) *FieldType { + t := NewFieldType(tp) + t.Flag |= flag + return t +} + +func newMyDecimal(val string, c *C) *MyDecimal { + t := MyDecimal{} + err := t.FromString([]byte(val)) + c.Assert(err, IsNil) + return &t +} + +func newRetTypeWithFlenDecimal(tp byte, flen int, decimal int) *FieldType { + return &FieldType{ + Tp: tp, + Flen: flen, + Decimal: decimal, + } +} + +func (ts *testDatumSuite) TestChangeReverseResultByUpperLowerBound(c *C) { + sc := new(stmtctx.StatementContext) + sc.IgnoreTruncate = true + sc.OverflowAsWarning = true + // TODO: add more reserve convert tests for each pair of convert type. + testData := []struct { + a Datum + res Datum + retType *FieldType + roundType RoundingType + }{ + // int64 reserve to uint64 + { + NewIntDatum(1), + NewUintDatum(2), + newTypeWithFlag(mysql.TypeLonglong, mysql.UnsignedFlag), + Ceiling, + }, + { + NewIntDatum(1), + NewUintDatum(1), + newTypeWithFlag(mysql.TypeLonglong, mysql.UnsignedFlag), + Floor, + }, + { + NewIntDatum(math.MaxInt64), + NewUintDatum(math.MaxUint64), + newTypeWithFlag(mysql.TypeLonglong, mysql.UnsignedFlag), + Ceiling, + }, + { + NewIntDatum(math.MaxInt64), + NewUintDatum(math.MaxInt64), + newTypeWithFlag(mysql.TypeLonglong, mysql.UnsignedFlag), + Floor, + }, + // int64 reserve to float64 + { + NewIntDatum(1), + NewFloat64Datum(2), + newRetTypeWithFlenDecimal(mysql.TypeDouble, mysql.MaxRealWidth, UnspecifiedLength), + Ceiling, + }, + { + NewIntDatum(1), + NewFloat64Datum(1), + newRetTypeWithFlenDecimal(mysql.TypeDouble, mysql.MaxRealWidth, UnspecifiedLength), + Floor, + }, + { + NewIntDatum(math.MaxInt64), + GetMaxValue(newRetTypeWithFlenDecimal(mysql.TypeDouble, mysql.MaxRealWidth, UnspecifiedLength)), + newRetTypeWithFlenDecimal(mysql.TypeDouble, mysql.MaxRealWidth, UnspecifiedLength), + Ceiling, + }, + { + NewIntDatum(math.MaxInt64), + NewFloat64Datum(float64(math.MaxInt64)), + newRetTypeWithFlenDecimal(mysql.TypeDouble, mysql.MaxRealWidth, UnspecifiedLength), + Floor, + }, + // int64 reserve to Decimal + { + NewIntDatum(1), + NewDecimalDatum(newMyDecimal("2", c)), + newRetTypeWithFlenDecimal(mysql.TypeNewDecimal, 30, 3), + Ceiling, + }, + { + NewIntDatum(1), + NewDecimalDatum(newMyDecimal("1", c)), + newRetTypeWithFlenDecimal(mysql.TypeNewDecimal, 30, 3), + Floor, + }, + { + NewIntDatum(math.MaxInt64), + GetMaxValue(newRetTypeWithFlenDecimal(mysql.TypeNewDecimal, 30, 3)), + newRetTypeWithFlenDecimal(mysql.TypeNewDecimal, 30, 3), + Ceiling, + }, + { + NewIntDatum(math.MaxInt64), + NewDecimalDatum(newMyDecimal(strconv.FormatInt(math.MaxInt64, 10), c)), + newRetTypeWithFlenDecimal(mysql.TypeNewDecimal, 30, 3), + Floor, + }, + } + for ith, test := range testData { + reverseRes, err := ChangeReverseResultByUpperLowerBound(sc, test.retType, test.a, test.roundType) + c.Assert(err, IsNil) + var cmp int + cmp, err = reverseRes.CompareDatum(sc, &test.res) + c.Assert(err, IsNil) + c.Assert(cmp, Equals, 0, Commentf("%dth got:%#v, expect:%#v", ith, reverseRes, test.res)) + } +} + func prepareCompareDatums() ([]Datum, []Datum) { vals := make([]Datum, 0, 5) vals = append(vals, NewIntDatum(1)) diff --git a/types/mydecimal.go b/types/mydecimal.go index 5085cbedd1505..4e4ffcd5e84aa 100644 --- a/types/mydecimal.go +++ b/types/mydecimal.go @@ -250,6 +250,11 @@ func (d *MyDecimal) GetDigitsFrac() int8 { return d.digitsFrac } +// GetDigitsInt returns the digitsInt. +func (d *MyDecimal) GetDigitsInt() int8 { + return d.digitsInt +} + // String returns the decimal string representation rounded to resultFrac. func (d *MyDecimal) String() string { tmp := *d