Skip to content

Commit

Permalink
expression: refactor JSON_QUOTE / JSON_UNQUOTE (pingcap#13688)
Browse files Browse the repository at this point in the history
  • Loading branch information
Deardrops authored Dec 5, 2019
1 parent 13400ee commit 6b6a698
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 56 deletions.
44 changes: 33 additions & 11 deletions expression/builtin_json.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package expression

import (
json2 "encoding/json"
"strconv"
"strings"

"github.com/pingcap/errors"
Expand Down Expand Up @@ -182,26 +183,38 @@ func (b *builtinJSONUnquoteSig) Clone() builtinFunc {
return newSig
}

func (c *jsonUnquoteFunctionClass) verifyArgs(args []Expression) error {
if err := c.baseFunctionClass.verifyArgs(args); err != nil {
return err
}
if evalType := args[0].GetType().EvalType(); evalType != types.ETString && evalType != types.ETJson {
return ErrIncorrectType.GenWithStackByArgs("1", "json_unquote")
}
return nil
}

func (c *jsonUnquoteFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
if err := c.verifyArgs(args); err != nil {
return nil, err
}
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, types.ETJson)
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, types.ETString)
bf.tp.Flen = mysql.MaxFieldVarCharLength
DisableParseJSONFlag4Expr(args[0])
sig := &builtinJSONUnquoteSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_JsonUnquoteSig)
return sig, nil
}

func (b *builtinJSONUnquoteSig) evalString(row chunk.Row) (res string, isNull bool, err error) {
var j json.BinaryJSON
j, isNull, err = b.args[0].EvalJSON(b.ctx, row)
func (b *builtinJSONUnquoteSig) evalString(row chunk.Row) (string, bool, error) {
str, isNull, err := b.args[0].EvalString(b.ctx, row)
if isNull || err != nil {
return "", isNull, err
}
res, err = j.Unquote()
return res, err != nil, err
str, err = json.UnquoteString(str)
if err != nil {
return "", false, err
}
return str, false, nil
}

type jsonSetFunctionClass struct {
Expand Down Expand Up @@ -1022,24 +1035,33 @@ func (b *builtinJSONQuoteSig) Clone() builtinFunc {
return newSig
}

func (c *jsonQuoteFunctionClass) verifyArgs(args []Expression) error {
if err := c.baseFunctionClass.verifyArgs(args); err != nil {
return err
}
if evalType := args[0].GetType().EvalType(); evalType != types.ETString {
return ErrIncorrectType.GenWithStackByArgs("1", "json_quote")
}
return nil
}

func (c *jsonQuoteFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
if err := c.verifyArgs(args); err != nil {
return nil, err
}
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, types.ETJson)
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, types.ETString)
DisableParseJSONFlag4Expr(args[0])
sig := &builtinJSONQuoteSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_JsonQuoteSig)
return sig, nil
}

func (b *builtinJSONQuoteSig) evalString(row chunk.Row) (res string, isNull bool, err error) {
var j json.BinaryJSON
j, isNull, err = b.args[0].EvalJSON(b.ctx, row)
func (b *builtinJSONQuoteSig) evalString(row chunk.Row) (string, bool, error) {
str, isNull, err := b.args[0].EvalString(b.ctx, row)
if isNull || err != nil {
return "", isNull, err
}
return j.Quote(), false, nil
return strconv.Quote(str), false, nil
}

type jsonSearchFunctionClass struct {
Expand Down
16 changes: 9 additions & 7 deletions expression/builtin_json_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
package expression

import (
"strconv"

"github.com/pingcap/errors"
"github.com/pingcap/parser/ast"
"github.com/pingcap/tidb/sessionctx"
Expand Down Expand Up @@ -299,12 +301,12 @@ func (b *builtinJSONQuoteSig) vectorized() bool {

func (b *builtinJSONQuoteSig) vecEvalString(input *chunk.Chunk, result *chunk.Column) error {
n := input.NumRows()
buf, err := b.bufAllocator.get(types.ETJson, n)
buf, err := b.bufAllocator.get(types.ETString, n)
if err != nil {
return err
}
defer b.bufAllocator.put(buf)
if err := b.args[0].VecEvalJSON(b.ctx, input, buf); err != nil {
if err := b.args[0].VecEvalString(b.ctx, input, buf); err != nil {
return err
}

Expand All @@ -314,7 +316,7 @@ func (b *builtinJSONQuoteSig) vecEvalString(input *chunk.Chunk, result *chunk.Co
result.AppendNull()
continue
}
result.AppendString(buf.GetJSON(i).Quote())
result.AppendString(strconv.Quote(buf.GetString(i)))
}
return nil
}
Expand Down Expand Up @@ -811,12 +813,12 @@ func (b *builtinJSONUnquoteSig) vectorized() bool {

func (b *builtinJSONUnquoteSig) vecEvalString(input *chunk.Chunk, result *chunk.Column) error {
n := input.NumRows()
buf, err := b.bufAllocator.get(types.ETJson, n)
buf, err := b.bufAllocator.get(types.ETString, n)
if err != nil {
return err
}
defer b.bufAllocator.put(buf)
if err := b.args[0].VecEvalJSON(b.ctx, input, buf); err != nil {
if err := b.args[0].VecEvalString(b.ctx, input, buf); err != nil {
return err
}

Expand All @@ -826,11 +828,11 @@ func (b *builtinJSONUnquoteSig) vecEvalString(input *chunk.Chunk, result *chunk.
result.AppendNull()
continue
}
res, err := buf.GetJSON(i).Unquote()
str, err := json.UnquoteString(buf.GetString(i))
if err != nil {
return err
}
result.AppendString(res)
result.AppendString(str)
}
return nil
}
2 changes: 1 addition & 1 deletion expression/builtin_json_vec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ var vecBuiltinJSONCases = map[string][]vecExprBenchCase{
{retEvalType: types.ETJson, childrenTypes: []types.EvalType{types.ETJson, types.ETString, types.ETJson, types.ETString, types.ETJson}, geners: []dataGenerator{nil, &constStrGener{"$.aaa"}, nil, &constStrGener{"$.bbb"}, nil}},
},
ast.JSONQuote: {
{retEvalType: types.ETString, childrenTypes: []types.EvalType{types.ETJson}},
{retEvalType: types.ETString, childrenTypes: []types.EvalType{types.ETString}},
},
}

Expand Down
2 changes: 2 additions & 0 deletions expression/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ var (
ErrOperandColumns = terror.ClassExpression.New(mysql.ErrOperandColumns, mysql.MySQLErrName[mysql.ErrOperandColumns])
ErrCutValueGroupConcat = terror.ClassExpression.New(mysql.ErrCutValueGroupConcat, mysql.MySQLErrName[mysql.ErrCutValueGroupConcat])
ErrFunctionsNoopImpl = terror.ClassExpression.New(mysql.ErrNotSupportedYet, "function %s has only noop implementation in tidb now, use tidb_enable_noop_functions to enable these functions")
ErrIncorrectType = terror.ClassExpression.New(mysql.ErrIncorrectType, mysql.MySQLErrName[mysql.ErrIncorrectType])

// All the un-exported errors are defined here:
errFunctionNotExists = terror.ClassExpression.New(mysql.ErrSpDoesNotExist, mysql.MySQLErrName[mysql.ErrSpDoesNotExist])
Expand Down Expand Up @@ -67,6 +68,7 @@ func init() {
mysql.ErrUnknownLocale: mysql.ErrUnknownLocale,
mysql.ErrBadField: mysql.ErrBadField,
mysql.ErrNonUniq: mysql.ErrNonUniq,
mysql.ErrIncorrectType: mysql.ErrIncorrectType,
}
terror.ErrClassToMySQLCodes[terror.ClassExpression] = expressionMySQLErrCodes
}
Expand Down
85 changes: 67 additions & 18 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3905,24 +3905,73 @@ func (s *testIntegrationSuite) TestFuncJSON(c *C) {
r := tk.MustQuery(`select json_type(a), json_type(b) from table_json`)
r.Check(testkit.Rows("OBJECT OBJECT", "ARRAY ARRAY"))

r = tk.MustQuery(`select json_unquote('hello'), json_unquote('world')`)
r.Check(testkit.Rows("hello world"))

r = tk.MustQuery(`select
json_quote(''),
json_quote('""'),
json_quote('a'),
json_quote('3'),
json_quote('{"a": "b"}'),
json_quote('{"a": "b"}'),
json_quote('hello,"quoted string",world'),
json_quote('hello,"宽字符",world'),
json_quote('Invalid Json string is OK'),
json_quote('1\u2232\u22322')
`)
r.Check(testkit.Rows(
`"" "\"\"" "a" "3" "{\"a\": \"b\"}" "{\"a\": \"b\"}" "hello,\"quoted string\",world" "hello,\"宽字符\",world" "Invalid Json string\tis OK" "1u2232u22322"`,
))
tk.MustGetErrCode("select json_quote();", mysql.ErrWrongParamcountToNativeFct)
tk.MustGetErrCode("select json_quote('abc', 'def');", mysql.ErrWrongParamcountToNativeFct)
tk.MustGetErrCode("select json_quote(NULL, 'def');", mysql.ErrWrongParamcountToNativeFct)
tk.MustGetErrCode("select json_quote('abc', NULL);", mysql.ErrWrongParamcountToNativeFct)

tk.MustGetErrCode("select json_unquote();", mysql.ErrWrongParamcountToNativeFct)
tk.MustGetErrCode("select json_unquote('abc', 'def');", mysql.ErrWrongParamcountToNativeFct)
tk.MustGetErrCode("select json_unquote(NULL, 'def');", mysql.ErrWrongParamcountToNativeFct)
tk.MustGetErrCode("select json_unquote('abc', NULL);", mysql.ErrWrongParamcountToNativeFct)

tk.MustQuery("select json_quote(NULL);").Check(testkit.Rows("<nil>"))
tk.MustQuery("select json_unquote(NULL);").Check(testkit.Rows("<nil>"))

tk.MustQuery("select json_quote('abc');").Check(testkit.Rows(`"abc"`))
tk.MustQuery(`select json_quote(convert('"abc"' using ascii));`).Check(testkit.Rows(`"\"abc\""`))
tk.MustQuery(`select json_quote(convert('"abc"' using latin1));`).Check(testkit.Rows(`"\"abc\""`))
tk.MustQuery(`select json_quote(convert('"abc"' using utf8));`).Check(testkit.Rows(`"\"abc\""`))
tk.MustQuery(`select json_quote(convert('"abc"' using utf8mb4));`).Check(testkit.Rows(`"\"abc\""`))

tk.MustQuery("select json_unquote('abc');").Check(testkit.Rows("abc"))
tk.MustQuery(`select json_unquote('"abc"');`).Check(testkit.Rows("abc"))
tk.MustQuery(`select json_unquote(convert('"abc"' using ascii));`).Check(testkit.Rows("abc"))
tk.MustQuery(`select json_unquote(convert('"abc"' using latin1));`).Check(testkit.Rows("abc"))
tk.MustQuery(`select json_unquote(convert('"abc"' using utf8));`).Check(testkit.Rows("abc"))
tk.MustQuery(`select json_unquote(convert('"abc"' using utf8mb4));`).Check(testkit.Rows("abc"))

tk.MustQuery(`select json_quote('"');`).Check(testkit.Rows(`"\""`))
tk.MustQuery(`select json_unquote('"');`).Check(testkit.Rows(`"`))

tk.MustQuery(`select json_unquote('""');`).Check(testkit.Rows(``))
tk.MustQuery(`select char_length(json_unquote('""'));`).Check(testkit.Rows(`0`))
tk.MustQuery(`select json_unquote('"" ');`).Check(testkit.Rows(`"" `))
tk.MustQuery(`select json_unquote(cast(json_quote('abc') as json));`).Check(testkit.Rows("abc"))

tk.MustQuery(`select json_unquote(cast('{"abc": "foo"}' as json));`).Check(testkit.Rows(`{"abc": "foo"}`))
tk.MustQuery(`select json_unquote(json_extract(cast('{"abc": "foo"}' as json), '$.abc'));`).Check(testkit.Rows("foo"))
tk.MustQuery(`select json_unquote('["a", "b", "c"]');`).Check(testkit.Rows(`["a", "b", "c"]`))
tk.MustQuery(`select json_unquote(cast('["a", "b", "c"]' as json));`).Check(testkit.Rows(`["a", "b", "c"]`))
tk.MustQuery(`select json_quote(convert(X'e68891' using utf8));`).Check(testkit.Rows(`"我"`))
tk.MustQuery(`select json_quote(convert(X'e68891' using utf8mb4));`).Check(testkit.Rows(`"我"`))
tk.MustQuery(`select cast(json_quote(convert(X'e68891' using utf8)) as json);`).Check(testkit.Rows(`"我"`))
tk.MustQuery(`select json_unquote(convert(X'e68891' using utf8));`).Check(testkit.Rows("我"))

tk.MustQuery(`select json_quote(json_quote(json_quote('abc')));`).Check(testkit.Rows(`"\"\\\"abc\\\"\""`))
tk.MustQuery(`select json_unquote(json_unquote(json_unquote(json_quote(json_quote(json_quote('abc'))))));`).Check(testkit.Rows("abc"))

tk.MustGetErrCode("select json_quote(123)", mysql.ErrIncorrectType)
tk.MustGetErrCode("select json_quote(-100)", mysql.ErrIncorrectType)
tk.MustGetErrCode("select json_quote(123.123)", mysql.ErrIncorrectType)
tk.MustGetErrCode("select json_quote(-100.000)", mysql.ErrIncorrectType)
tk.MustGetErrCode(`select json_quote(true);`, mysql.ErrIncorrectType)
tk.MustGetErrCode(`select json_quote(false);`, mysql.ErrIncorrectType)
tk.MustGetErrCode(`select json_quote(cast("{}" as JSON));`, mysql.ErrIncorrectType)
tk.MustGetErrCode(`select json_quote(cast("[]" as JSON));`, mysql.ErrIncorrectType)
tk.MustGetErrCode(`select json_quote(cast("2015-07-29" as date));`, mysql.ErrIncorrectType)
tk.MustGetErrCode(`select json_quote(cast("12:18:29.000000" as time));`, mysql.ErrIncorrectType)
tk.MustGetErrCode(`select json_quote(cast("2015-07-29 12:18:29.000000" as datetime));`, mysql.ErrIncorrectType)

tk.MustGetErrCode("select json_unquote(123)", mysql.ErrIncorrectType)
tk.MustGetErrCode("select json_unquote(-100)", mysql.ErrIncorrectType)
tk.MustGetErrCode("select json_unquote(123.123)", mysql.ErrIncorrectType)
tk.MustGetErrCode("select json_unquote(-100.000)", mysql.ErrIncorrectType)
tk.MustGetErrCode(`select json_unquote(true);`, mysql.ErrIncorrectType)
tk.MustGetErrCode(`select json_unquote(false);`, mysql.ErrIncorrectType)
tk.MustGetErrCode(`select json_unquote(cast("2015-07-29" as date));`, mysql.ErrIncorrectType)
tk.MustGetErrCode(`select json_unquote(cast("12:18:29.000000" as time));`, mysql.ErrIncorrectType)
tk.MustGetErrCode(`select json_unquote(cast("2015-07-29 12:18:29.000000" as datetime));`, mysql.ErrIncorrectType)

r = tk.MustQuery(`select json_extract(a, '$.a[1]'), json_extract(b, '$.b') from table_json`)
r.Check(testkit.Rows("\"2\" true", "<nil> <nil>"))
Expand Down
37 changes: 18 additions & 19 deletions types/json/binary_functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (
"encoding/hex"
"fmt"
"sort"
"strconv"
"unicode/utf8"
"unsafe"

Expand Down Expand Up @@ -55,33 +54,33 @@ func (bj BinaryJSON) Type() string {
}
}

// Quote is for JSON_QUOTE
func (bj BinaryJSON) Quote() string {
str := hack.String(bj.GetString())
return strconv.Quote(string(str))
}

// Unquote is for JSON_UNQUOTE.
func (bj BinaryJSON) Unquote() (string, error) {
switch bj.TypeCode {
case TypeCodeString:
tmp := string(hack.String(bj.GetString()))
tlen := len(tmp)
if tlen < 2 {
return tmp, nil
}
head, tail := tmp[0], tmp[tlen-1]
if head == '"' && tail == '"' {
// Remove prefix and suffix '"' before unquoting
return unquoteString(tmp[1 : tlen-1])
}
// if value is not double quoted, do nothing
return tmp, nil
str := string(hack.String(bj.GetString()))
return UnquoteString(str)
default:
return bj.String(), nil
}
}

// UnquoteString remove quotes in a string,
// including the quotes at the head and tail of string.
func UnquoteString(str string) (string, error) {
strLen := len(str)
if strLen < 2 {
return str, nil
}
head, tail := str[0], str[strLen-1]
if head == '"' && tail == '"' {
// Remove prefix and suffix '"' before unquoting
return unquoteString(str[1 : strLen-1])
}
// if value is not double quoted, do nothing
return str, nil
}

// unquoteString recognizes the escape sequences shown in:
// https://dev.mysql.com/doc/refman/5.7/en/json-modification-functions.html#json-unquote-character-escape-sequences
func unquoteString(s string) (string, error) {
Expand Down

0 comments on commit 6b6a698

Please sign in to comment.