diff --git a/statistics/handle/update.go b/statistics/handle/update.go index 54319fecdb308..98083228d4cc6 100644 --- a/statistics/handle/update.go +++ b/statistics/handle/update.go @@ -934,7 +934,7 @@ func (h *Handle) RecalculateExpectCount(q *statistics.QueryFeedback) error { expected *= idx.GetIncreaseFactor(t.Count) } else { c := t.Columns[id] - expected, err = c.GetColumnRowCount(sc, ranges, t.ModifyCount) + expected, err = c.GetColumnRowCount(sc, ranges, t.ModifyCount, true) expected *= c.GetIncreaseFactor(t.Count) } q.Expected = int64(expected) diff --git a/statistics/histogram.go b/statistics/histogram.go index 1bff31a4f3945..a5d07928dc1aa 100644 --- a/statistics/histogram.go +++ b/statistics/histogram.go @@ -731,7 +731,7 @@ func (c *Column) equalRowCount(sc *stmtctx.StatementContext, val types.Datum, mo } // GetColumnRowCount estimates the row count by a slice of Range. -func (c *Column) GetColumnRowCount(sc *stmtctx.StatementContext, ranges []*ranger.Range, modifyCount int64) (float64, error) { +func (c *Column) GetColumnRowCount(sc *stmtctx.StatementContext, ranges []*ranger.Range, modifyCount int64, pkIsHandle bool) (float64, error) { var rowCount float64 for _, rg := range ranges { cmp, err := rg.LowVal[0].CompareDatum(sc, &rg.HighVal[0]) @@ -741,6 +741,11 @@ func (c *Column) GetColumnRowCount(sc *stmtctx.StatementContext, ranges []*range if cmp == 0 { // the point case. if !rg.LowExclude && !rg.HighExclude { + // In this case, the row count is at most 1. + if pkIsHandle { + rowCount += 1 + continue + } var cnt float64 cnt, err = c.equalRowCount(sc, rg.LowVal[0], modifyCount) if err != nil { @@ -855,6 +860,11 @@ func (idx *Index) GetRowCount(sc *stmtctx.StatementContext, indexRanges []*range continue } if fullLen { + // At most 1 in this case. + if idx.Info.Unique { + totalCount += 1 + continue + } count, err := idx.equalRowCount(sc, lb, modifyCount) if err != nil { return 0, err diff --git a/statistics/histogram_test.go b/statistics/histogram_test.go index 51b18480fc26d..cc22d2f0762f6 100644 --- a/statistics/histogram_test.go +++ b/statistics/histogram_test.go @@ -50,9 +50,9 @@ func (s *testStatisticsSuite) TestNewHistogramBySelectivity(c *C) { node.Ranges = append(node.Ranges, &ranger.Range{LowVal: types.MakeDatums(25), HighVal: []types.Datum{types.MaxValueDatum()}}) intColResult := `column:1 ndv:16 totColSize:0 num: 30 lower_bound: 0 upper_bound: 2 repeats: 10 -num: 20 lower_bound: 6 upper_bound: 8 repeats: 0 +num: 11 lower_bound: 6 upper_bound: 8 repeats: 0 num: 30 lower_bound: 9 upper_bound: 11 repeats: 0 -num: 10 lower_bound: 12 upper_bound: 14 repeats: 0 +num: 1 lower_bound: 12 upper_bound: 14 repeats: 0 num: 30 lower_bound: 27 upper_bound: 29 repeats: 0` stringCol := &Column{} diff --git a/statistics/selectivity_test.go b/statistics/selectivity_test.go index 48aa2aa067133..3fcfd3221d017 100644 --- a/statistics/selectivity_test.go +++ b/statistics/selectivity_test.go @@ -424,6 +424,38 @@ func (s *testStatsSuite) TestEstimationForUnknownValues(c *C) { c.Assert(count, Equals, 0.0) } +func (s *testStatsSuite) TestEstimationUniqueKeyEqualConds(c *C) { + defer cleanEnv(c, s.store, s.do) + testKit := testkit.NewTestKit(c, s.store) + testKit.MustExec("use test") + testKit.MustExec("drop table if exists t") + testKit.MustExec("create table t(a int, b int, c int, unique key(b))") + testKit.MustExec("insert into t values (1,1,1),(2,2,2),(3,3,3),(4,4,4),(5,5,5),(6,6,6),(7,7,7)") + testKit.MustExec("analyze table t with 4 cmsketch width, 1 cmsketch depth;") + table, err := s.do.InfoSchema().TableByName(model.NewCIStr("test"), model.NewCIStr("t")) + c.Assert(err, IsNil) + statsTbl := s.do.StatsHandle().GetTableStats(table.Meta()) + + sc := &stmtctx.StatementContext{} + idxID := table.Meta().Indices[0].ID + count, err := statsTbl.GetRowCountByIndexRanges(sc, idxID, getRange(7, 7)) + c.Assert(err, IsNil) + c.Assert(count, Equals, 1.0) + + count, err = statsTbl.GetRowCountByIndexRanges(sc, idxID, getRange(6, 6)) + c.Assert(err, IsNil) + c.Assert(count, Equals, 1.0) + + colID := table.Meta().Columns[0].ID + count, err = statsTbl.GetRowCountByIntColumnRanges(sc, colID, getRange(7, 7)) + c.Assert(err, IsNil) + c.Assert(count, Equals, 1.0) + + count, err = statsTbl.GetRowCountByIntColumnRanges(sc, colID, getRange(6, 6)) + c.Assert(err, IsNil) + c.Assert(count, Equals, 1.0) +} + func (s *testStatsSuite) TestPrimaryKeySelectivity(c *C) { defer cleanEnv(c, s.store, s.do) testKit := testkit.NewTestKit(c, s.store) diff --git a/statistics/table.go b/statistics/table.go index da563044dc2ad..c3d79ad621fa9 100644 --- a/statistics/table.go +++ b/statistics/table.go @@ -240,7 +240,7 @@ func (coll *HistColl) GetRowCountByIntColumnRanges(sc *stmtctx.StatementContext, } return getPseudoRowCountByUnsignedIntRanges(intRanges, float64(coll.Count)), nil } - result, err := c.GetColumnRowCount(sc, intRanges, coll.ModifyCount) + result, err := c.GetColumnRowCount(sc, intRanges, coll.ModifyCount, true) result *= c.GetIncreaseFactor(coll.Count) return result, errors.Trace(err) } @@ -251,7 +251,7 @@ func (coll *HistColl) GetRowCountByColumnRanges(sc *stmtctx.StatementContext, co if !ok || c.IsInvalid(sc, coll.Pseudo) { return GetPseudoRowCountByColumnRanges(sc, float64(coll.Count), colRanges, 0) } - result, err := c.GetColumnRowCount(sc, colRanges, coll.ModifyCount) + result, err := c.GetColumnRowCount(sc, colRanges, coll.ModifyCount, false) result *= c.GetIncreaseFactor(coll.Count) return result, errors.Trace(err) } @@ -387,7 +387,11 @@ func isSingleColIdxNullRange(idx *Index, ran *ranger.Range) bool { // getEqualCondSelectivity gets the selectivity of the equal conditions. `coverAll` means if the conditions // have covered all the index columns. -func (coll *HistColl) getEqualCondSelectivity(idx *Index, bytes []byte, coverAll bool) float64 { +func (coll *HistColl) getEqualCondSelectivity(idx *Index, bytes []byte, coverAll bool, unique bool) float64 { + // In this case, the row count is at most 1. + if unique && coverAll { + return 1.0 / float64(idx.TotalCount()) + } val := types.NewBytesDatum(bytes) if idx.outOfRange(val) { // When the value is out of range, we could not found this value in the CM Sketch, @@ -434,7 +438,7 @@ func (coll *HistColl) getIndexRowCount(sc *stmtctx.StatementContext, idxID int64 if err != nil { return 0, errors.Trace(err) } - selectivity = coll.getEqualCondSelectivity(idx, bytes, coverAll) + selectivity = coll.getEqualCondSelectivity(idx, bytes, coverAll, idx.Info.Unique) } else { bytes, err := codec.EncodeKey(sc, nil, ran.LowVal[:rangePosition-1]...) if err != nil { @@ -447,7 +451,7 @@ func (coll *HistColl) getIndexRowCount(sc *stmtctx.StatementContext, idxID int64 if err != nil { return 0, err } - selectivity += coll.getEqualCondSelectivity(idx, bytes, coverAll) + selectivity += coll.getEqualCondSelectivity(idx, bytes, coverAll, idx.Info.Unique) } } // use histogram to estimate the range condition