diff --git a/expression/builtin_json.go b/expression/builtin_json.go index aefbe93d7421e..0d7844ec0264f 100644 --- a/expression/builtin_json.go +++ b/expression/builtin_json.go @@ -15,6 +15,7 @@ package expression import ( json2 "encoding/json" + "strconv" "strings" "github.com/pingcap/errors" @@ -182,11 +183,21 @@ func (b *builtinJSONUnquoteSig) Clone() builtinFunc { return newSig } +func (c *jsonUnquoteFunctionClass) verifyArgs(args []Expression) error { + if err := c.baseFunctionClass.verifyArgs(args); err != nil { + return err + } + if evalType := args[0].GetType().EvalType(); evalType != types.ETString && evalType != types.ETJson { + return ErrIncorrectType.GenWithStackByArgs("1", "json_unquote") + } + return nil +} + func (c *jsonUnquoteFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) { if err := c.verifyArgs(args); err != nil { return nil, err } - bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, types.ETJson) + bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, types.ETString) bf.tp.Flen = mysql.MaxFieldVarCharLength DisableParseJSONFlag4Expr(args[0]) sig := &builtinJSONUnquoteSig{bf} @@ -194,14 +205,16 @@ func (c *jsonUnquoteFunctionClass) getFunction(ctx sessionctx.Context, args []Ex return sig, nil } -func (b *builtinJSONUnquoteSig) evalString(row chunk.Row) (res string, isNull bool, err error) { - var j json.BinaryJSON - j, isNull, err = b.args[0].EvalJSON(b.ctx, row) +func (b *builtinJSONUnquoteSig) evalString(row chunk.Row) (string, bool, error) { + str, isNull, err := b.args[0].EvalString(b.ctx, row) if isNull || err != nil { return "", isNull, err } - res, err = j.Unquote() - return res, err != nil, err + str, err = json.UnquoteString(str) + if err != nil { + return "", false, err + } + return str, false, nil } type jsonSetFunctionClass struct { @@ -1022,24 +1035,33 @@ func (b *builtinJSONQuoteSig) Clone() builtinFunc { return newSig } +func (c *jsonQuoteFunctionClass) verifyArgs(args 
[]Expression) error { + if err := c.baseFunctionClass.verifyArgs(args); err != nil { + return err + } + if evalType := args[0].GetType().EvalType(); evalType != types.ETString { + return ErrIncorrectType.GenWithStackByArgs("1", "json_quote") + } + return nil +} + func (c *jsonQuoteFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) { if err := c.verifyArgs(args); err != nil { return nil, err } - bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, types.ETJson) + bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, types.ETString) DisableParseJSONFlag4Expr(args[0]) sig := &builtinJSONQuoteSig{bf} sig.setPbCode(tipb.ScalarFuncSig_JsonQuoteSig) return sig, nil } -func (b *builtinJSONQuoteSig) evalString(row chunk.Row) (res string, isNull bool, err error) { - var j json.BinaryJSON - j, isNull, err = b.args[0].EvalJSON(b.ctx, row) +func (b *builtinJSONQuoteSig) evalString(row chunk.Row) (string, bool, error) { + str, isNull, err := b.args[0].EvalString(b.ctx, row) if isNull || err != nil { return "", isNull, err } - return j.Quote(), false, nil + return strconv.Quote(str), false, nil } type jsonSearchFunctionClass struct { diff --git a/expression/builtin_json_vec.go b/expression/builtin_json_vec.go index 903a15e82cf6e..53db3b56ab779 100644 --- a/expression/builtin_json_vec.go +++ b/expression/builtin_json_vec.go @@ -14,6 +14,8 @@ package expression import ( + "strconv" + "github.com/pingcap/errors" "github.com/pingcap/parser/ast" "github.com/pingcap/tidb/sessionctx" @@ -299,12 +301,12 @@ func (b *builtinJSONQuoteSig) vectorized() bool { func (b *builtinJSONQuoteSig) vecEvalString(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() - buf, err := b.bufAllocator.get(types.ETJson, n) + buf, err := b.bufAllocator.get(types.ETString, n) if err != nil { return err } defer b.bufAllocator.put(buf) - if err := b.args[0].VecEvalJSON(b.ctx, input, buf); err != nil { + if err := b.args[0].VecEvalString(b.ctx, 
input, buf); err != nil { return err } @@ -314,7 +316,7 @@ func (b *builtinJSONQuoteSig) vecEvalString(input *chunk.Chunk, result *chunk.Co result.AppendNull() continue } - result.AppendString(buf.GetJSON(i).Quote()) + result.AppendString(strconv.Quote(buf.GetString(i))) } return nil } @@ -811,12 +813,12 @@ func (b *builtinJSONUnquoteSig) vectorized() bool { func (b *builtinJSONUnquoteSig) vecEvalString(input *chunk.Chunk, result *chunk.Column) error { n := input.NumRows() - buf, err := b.bufAllocator.get(types.ETJson, n) + buf, err := b.bufAllocator.get(types.ETString, n) if err != nil { return err } defer b.bufAllocator.put(buf) - if err := b.args[0].VecEvalJSON(b.ctx, input, buf); err != nil { + if err := b.args[0].VecEvalString(b.ctx, input, buf); err != nil { return err } @@ -826,11 +828,11 @@ func (b *builtinJSONUnquoteSig) vecEvalString(input *chunk.Chunk, result *chunk. result.AppendNull() continue } - res, err := buf.GetJSON(i).Unquote() + str, err := json.UnquoteString(buf.GetString(i)) if err != nil { return err } - result.AppendString(res) + result.AppendString(str) } return nil } diff --git a/expression/builtin_json_vec_test.go b/expression/builtin_json_vec_test.go index 9640330abaa2f..5b0ffb4fb5fc1 100644 --- a/expression/builtin_json_vec_test.go +++ b/expression/builtin_json_vec_test.go @@ -102,7 +102,7 @@ var vecBuiltinJSONCases = map[string][]vecExprBenchCase{ {retEvalType: types.ETJson, childrenTypes: []types.EvalType{types.ETJson, types.ETString, types.ETJson, types.ETString, types.ETJson}, geners: []dataGenerator{nil, &constStrGener{"$.aaa"}, nil, &constStrGener{"$.bbb"}, nil}}, }, ast.JSONQuote: { - {retEvalType: types.ETString, childrenTypes: []types.EvalType{types.ETJson}}, + {retEvalType: types.ETString, childrenTypes: []types.EvalType{types.ETString}}, }, } diff --git a/expression/errors.go b/expression/errors.go index d0f7d128784fd..5db1acbd0b494 100644 --- a/expression/errors.go +++ b/expression/errors.go @@ -29,6 +29,7 @@ var ( 
ErrOperandColumns = terror.ClassExpression.New(mysql.ErrOperandColumns, mysql.MySQLErrName[mysql.ErrOperandColumns]) ErrCutValueGroupConcat = terror.ClassExpression.New(mysql.ErrCutValueGroupConcat, mysql.MySQLErrName[mysql.ErrCutValueGroupConcat]) ErrFunctionsNoopImpl = terror.ClassExpression.New(mysql.ErrNotSupportedYet, "function %s has only noop implementation in tidb now, use tidb_enable_noop_functions to enable these functions") + ErrIncorrectType = terror.ClassExpression.New(mysql.ErrIncorrectType, mysql.MySQLErrName[mysql.ErrIncorrectType]) // All the un-exported errors are defined here: errFunctionNotExists = terror.ClassExpression.New(mysql.ErrSpDoesNotExist, mysql.MySQLErrName[mysql.ErrSpDoesNotExist]) @@ -67,6 +68,7 @@ func init() { mysql.ErrUnknownLocale: mysql.ErrUnknownLocale, mysql.ErrBadField: mysql.ErrBadField, mysql.ErrNonUniq: mysql.ErrNonUniq, + mysql.ErrIncorrectType: mysql.ErrIncorrectType, } terror.ErrClassToMySQLCodes[terror.ClassExpression] = expressionMySQLErrCodes } diff --git a/expression/integration_test.go b/expression/integration_test.go index aa75a4d02ea5b..c29b3b982408a 100755 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -3905,24 +3905,73 @@ func (s *testIntegrationSuite) TestFuncJSON(c *C) { r := tk.MustQuery(`select json_type(a), json_type(b) from table_json`) r.Check(testkit.Rows("OBJECT OBJECT", "ARRAY ARRAY")) - r = tk.MustQuery(`select json_unquote('hello'), json_unquote('world')`) - r.Check(testkit.Rows("hello world")) - - r = tk.MustQuery(`select - json_quote(''), - json_quote('""'), - json_quote('a'), - json_quote('3'), - json_quote('{"a": "b"}'), - json_quote('{"a": "b"}'), - json_quote('hello,"quoted string",world'), - json_quote('hello,"宽字符",world'), - json_quote('Invalid Json string is OK'), - json_quote('1\u2232\u22322') - `) - r.Check(testkit.Rows( - `"" "\"\"" "a" "3" "{\"a\": \"b\"}" "{\"a\": \"b\"}" "hello,\"quoted string\",world" "hello,\"宽字符\",world" "Invalid Json string\tis OK" 
"1u2232u22322"`, - )) + tk.MustGetErrCode("select json_quote();", mysql.ErrWrongParamcountToNativeFct) + tk.MustGetErrCode("select json_quote('abc', 'def');", mysql.ErrWrongParamcountToNativeFct) + tk.MustGetErrCode("select json_quote(NULL, 'def');", mysql.ErrWrongParamcountToNativeFct) + tk.MustGetErrCode("select json_quote('abc', NULL);", mysql.ErrWrongParamcountToNativeFct) + + tk.MustGetErrCode("select json_unquote();", mysql.ErrWrongParamcountToNativeFct) + tk.MustGetErrCode("select json_unquote('abc', 'def');", mysql.ErrWrongParamcountToNativeFct) + tk.MustGetErrCode("select json_unquote(NULL, 'def');", mysql.ErrWrongParamcountToNativeFct) + tk.MustGetErrCode("select json_unquote('abc', NULL);", mysql.ErrWrongParamcountToNativeFct) + + tk.MustQuery("select json_quote(NULL);").Check(testkit.Rows("")) + tk.MustQuery("select json_unquote(NULL);").Check(testkit.Rows("")) + + tk.MustQuery("select json_quote('abc');").Check(testkit.Rows(`"abc"`)) + tk.MustQuery(`select json_quote(convert('"abc"' using ascii));`).Check(testkit.Rows(`"\"abc\""`)) + tk.MustQuery(`select json_quote(convert('"abc"' using latin1));`).Check(testkit.Rows(`"\"abc\""`)) + tk.MustQuery(`select json_quote(convert('"abc"' using utf8));`).Check(testkit.Rows(`"\"abc\""`)) + tk.MustQuery(`select json_quote(convert('"abc"' using utf8mb4));`).Check(testkit.Rows(`"\"abc\""`)) + + tk.MustQuery("select json_unquote('abc');").Check(testkit.Rows("abc")) + tk.MustQuery(`select json_unquote('"abc"');`).Check(testkit.Rows("abc")) + tk.MustQuery(`select json_unquote(convert('"abc"' using ascii));`).Check(testkit.Rows("abc")) + tk.MustQuery(`select json_unquote(convert('"abc"' using latin1));`).Check(testkit.Rows("abc")) + tk.MustQuery(`select json_unquote(convert('"abc"' using utf8));`).Check(testkit.Rows("abc")) + tk.MustQuery(`select json_unquote(convert('"abc"' using utf8mb4));`).Check(testkit.Rows("abc")) + + tk.MustQuery(`select json_quote('"');`).Check(testkit.Rows(`"\""`)) + tk.MustQuery(`select 
json_unquote('"');`).Check(testkit.Rows(`"`)) + + tk.MustQuery(`select json_unquote('""');`).Check(testkit.Rows(``)) + tk.MustQuery(`select char_length(json_unquote('""'));`).Check(testkit.Rows(`0`)) + tk.MustQuery(`select json_unquote('"" ');`).Check(testkit.Rows(`"" `)) + tk.MustQuery(`select json_unquote(cast(json_quote('abc') as json));`).Check(testkit.Rows("abc")) + + tk.MustQuery(`select json_unquote(cast('{"abc": "foo"}' as json));`).Check(testkit.Rows(`{"abc": "foo"}`)) + tk.MustQuery(`select json_unquote(json_extract(cast('{"abc": "foo"}' as json), '$.abc'));`).Check(testkit.Rows("foo")) + tk.MustQuery(`select json_unquote('["a", "b", "c"]');`).Check(testkit.Rows(`["a", "b", "c"]`)) + tk.MustQuery(`select json_unquote(cast('["a", "b", "c"]' as json));`).Check(testkit.Rows(`["a", "b", "c"]`)) + tk.MustQuery(`select json_quote(convert(X'e68891' using utf8));`).Check(testkit.Rows(`"我"`)) + tk.MustQuery(`select json_quote(convert(X'e68891' using utf8mb4));`).Check(testkit.Rows(`"我"`)) + tk.MustQuery(`select cast(json_quote(convert(X'e68891' using utf8)) as json);`).Check(testkit.Rows(`"我"`)) + tk.MustQuery(`select json_unquote(convert(X'e68891' using utf8));`).Check(testkit.Rows("我")) + + tk.MustQuery(`select json_quote(json_quote(json_quote('abc')));`).Check(testkit.Rows(`"\"\\\"abc\\\"\""`)) + tk.MustQuery(`select json_unquote(json_unquote(json_unquote(json_quote(json_quote(json_quote('abc'))))));`).Check(testkit.Rows("abc")) + + tk.MustGetErrCode("select json_quote(123)", mysql.ErrIncorrectType) + tk.MustGetErrCode("select json_quote(-100)", mysql.ErrIncorrectType) + tk.MustGetErrCode("select json_quote(123.123)", mysql.ErrIncorrectType) + tk.MustGetErrCode("select json_quote(-100.000)", mysql.ErrIncorrectType) + tk.MustGetErrCode(`select json_quote(true);`, mysql.ErrIncorrectType) + tk.MustGetErrCode(`select json_quote(false);`, mysql.ErrIncorrectType) + tk.MustGetErrCode(`select json_quote(cast("{}" as JSON));`, mysql.ErrIncorrectType) + 
tk.MustGetErrCode(`select json_quote(cast("[]" as JSON));`, mysql.ErrIncorrectType) + tk.MustGetErrCode(`select json_quote(cast("2015-07-29" as date));`, mysql.ErrIncorrectType) + tk.MustGetErrCode(`select json_quote(cast("12:18:29.000000" as time));`, mysql.ErrIncorrectType) + tk.MustGetErrCode(`select json_quote(cast("2015-07-29 12:18:29.000000" as datetime));`, mysql.ErrIncorrectType) + + tk.MustGetErrCode("select json_unquote(123)", mysql.ErrIncorrectType) + tk.MustGetErrCode("select json_unquote(-100)", mysql.ErrIncorrectType) + tk.MustGetErrCode("select json_unquote(123.123)", mysql.ErrIncorrectType) + tk.MustGetErrCode("select json_unquote(-100.000)", mysql.ErrIncorrectType) + tk.MustGetErrCode(`select json_unquote(true);`, mysql.ErrIncorrectType) + tk.MustGetErrCode(`select json_unquote(false);`, mysql.ErrIncorrectType) + tk.MustGetErrCode(`select json_unquote(cast("2015-07-29" as date));`, mysql.ErrIncorrectType) + tk.MustGetErrCode(`select json_unquote(cast("12:18:29.000000" as time));`, mysql.ErrIncorrectType) + tk.MustGetErrCode(`select json_unquote(cast("2015-07-29 12:18:29.000000" as datetime));`, mysql.ErrIncorrectType) r = tk.MustQuery(`select json_extract(a, '$.a[1]'), json_extract(b, '$.b') from table_json`) r.Check(testkit.Rows("\"2\" true", " ")) diff --git a/types/json/binary_functions.go b/types/json/binary_functions.go index 4a0cfd50d82ba..b936b0eafd765 100644 --- a/types/json/binary_functions.go +++ b/types/json/binary_functions.go @@ -19,7 +19,6 @@ import ( "encoding/hex" "fmt" "sort" - "strconv" "unicode/utf8" "unsafe" @@ -55,33 +54,33 @@ func (bj BinaryJSON) Type() string { } } -// Quote is for JSON_QUOTE -func (bj BinaryJSON) Quote() string { - str := hack.String(bj.GetString()) - return strconv.Quote(string(str)) -} - // Unquote is for JSON_UNQUOTE. 
func (bj BinaryJSON) Unquote() (string, error) { switch bj.TypeCode { case TypeCodeString: - tmp := string(hack.String(bj.GetString())) - tlen := len(tmp) - if tlen < 2 { - return tmp, nil - } - head, tail := tmp[0], tmp[tlen-1] - if head == '"' && tail == '"' { - // Remove prefix and suffix '"' before unquoting - return unquoteString(tmp[1 : tlen-1]) - } - // if value is not double quoted, do nothing - return tmp, nil + str := string(hack.String(bj.GetString())) + return UnquoteString(str) default: return bj.String(), nil } } +// UnquoteString strips the enclosing double quotes from str and unquotes the content in between; +// a string that is not wrapped in double quotes is returned unchanged. +func UnquoteString(str string) (string, error) { + strLen := len(str) + if strLen < 2 { + return str, nil + } + head, tail := str[0], str[strLen-1] + if head == '"' && tail == '"' { + // Remove prefix and suffix '"' before unquoting + return unquoteString(str[1 : strLen-1]) + } + // if value is not double quoted, do nothing + return str, nil +} + // unquoteString recognizes the escape sequences shown in: // https://dev.mysql.com/doc/refman/5.7/en/json-modification-functions.html#json-unquote-character-escape-sequences func unquoteString(s string) (string, error) {