Skip to content

Commit f8ee382

Browse files
authored
Merge pull request #28 Add basic JSON support from LuckySting/json-type
2 parents db10250 + 720c17a commit f8ee382

File tree

6 files changed

+110
-2
lines changed

6 files changed

+110
-2
lines changed

test/test_suite.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import ctypes
2+
13
import pytest
24
import sqlalchemy as sa
35
import sqlalchemy.testing.suite.test_types
@@ -65,6 +67,7 @@
6567
)
6668
from sqlalchemy.testing.suite.test_types import DateTimeTest as _DateTimeTest
6769
from sqlalchemy.testing.suite.test_types import IntegerTest as _IntegerTest
70+
from sqlalchemy.testing.suite.test_types import JSONTest as _JSONTest
6871
from sqlalchemy.testing.suite.test_types import NativeUUIDTest as _NativeUUIDTest
6972
from sqlalchemy.testing.suite.test_types import NumericTest as _NumericTest
7073
from sqlalchemy.testing.suite.test_types import StringTest as _StringTest
@@ -430,6 +433,25 @@ class TimeTest(_TimeTest):
430433
pass
431434

432435

436+
class JSONTest(_JSONTest):
437+
@classmethod
438+
def define_tables(cls, metadata):
439+
Table(
440+
"data_table",
441+
metadata,
442+
Column("id", Integer, primary_key=True, default=1),
443+
Column("name", String(30), primary_key=True, nullable=False),
444+
Column("data", cls.datatype, nullable=False),
445+
Column("nulldata", cls.datatype(none_as_null=True)),
446+
)
447+
448+
def _json_value_insert(self, connection, datatype, value, data_element):
449+
if datatype == "float" and value is not None:
450+
# As python's float is stored as C double, it needs to be shrank
451+
value = ctypes.c_float(value).value
452+
return super()._json_value_insert(connection, datatype, value, data_element)
453+
454+
433455
class StringTest(_StringTest):
434456
@requirements.unbounded_varchar
435457
def test_nolength_string(self):

ydb_sqlalchemy/dbapi/connection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def _get_table_client_settings(self) -> ydb.TableClientSettings:
149149
.with_native_datetime_in_result_sets(True)
150150
.with_native_timestamp_in_result_sets(True)
151151
.with_native_interval_in_result_sets(True)
152-
.with_native_json_in_result_sets(True)
152+
.with_native_json_in_result_sets(False)
153153
)
154154

155155
def _create_driver(self):

ydb_sqlalchemy/sqlalchemy/__init__.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Experimental
33
Work in progress, breaking changes are possible.
44
"""
5+
56
import collections
67
import collections.abc
78
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union
@@ -63,6 +64,9 @@ def __init__(self, dialect):
6364

6465

6566
class YqlTypeCompiler(StrSQLTypeCompiler):
67+
def visit_JSON(self, type_: Union[sa.JSON, types.YqlJSON], **kw):
68+
return "JSON"
69+
6670
def visit_CHAR(self, type_: sa.CHAR, **kw):
6771
return "UTF8"
6872

@@ -171,8 +175,21 @@ def get_ydb_type(
171175
ydb_type = ydb.PrimitiveType.Int64
172176
# Integers
173177

178+
# Json
174179
elif isinstance(type_, sa.JSON):
175180
ydb_type = ydb.PrimitiveType.Json
181+
elif isinstance(type_, sa.JSON.JSONStrIndexType):
182+
ydb_type = ydb.PrimitiveType.Utf8
183+
elif isinstance(type_, sa.JSON.JSONIntIndexType):
184+
ydb_type = ydb.PrimitiveType.Int64
185+
elif isinstance(type_, sa.JSON.JSONPathType):
186+
ydb_type = ydb.PrimitiveType.Utf8
187+
elif isinstance(type_, types.YqlJSON):
188+
ydb_type = ydb.PrimitiveType.Json
189+
elif isinstance(type_, types.YqlJSON.YqlJSONPathType):
190+
ydb_type = ydb.PrimitiveType.Utf8
191+
# Json
192+
176193
elif isinstance(type_, sa.DateTime):
177194
ydb_type = ydb.PrimitiveType.Timestamp
178195
elif isinstance(type_, sa.Date):
@@ -326,6 +343,24 @@ def visit_function(self, func, add_to_result_map=None, **kwargs):
326343
+ [name]
327344
) % {"expr": self.function_argspec(func, **kwargs)}
328345

346+
def _yson_convert_to(self, statement: str, target_type: sa.types.TypeEngine) -> str:
347+
type_name = target_type.compile(self.dialect)
348+
if isinstance(target_type, sa.Numeric) and not isinstance(target_type, (sa.Float, sa.Double)):
349+
# Since Decimal is stored in JSON either as String or as Float
350+
string_value = f"Yson::ConvertTo({statement}, Optional<String>, Yson::Options(true AS AutoConvert))"
351+
return f"CAST({string_value} AS Optional<{type_name}>)"
352+
return f"Yson::ConvertTo({statement}, Optional<{type_name}>)"
353+
354+
def visit_json_getitem_op_binary(self, binary: sa.BinaryExpression, operator, **kw) -> str:
355+
json_field = self.process(binary.left, **kw)
356+
index = self.process(binary.right, **kw)
357+
return self._yson_convert_to(f"{json_field}[{index}]", binary.type)
358+
359+
def visit_json_path_getitem_op_binary(self, binary: sa.BinaryExpression, operator, **kw) -> str:
360+
json_field = self.process(binary.left, **kw)
361+
path = self.process(binary.right, **kw)
362+
return self._yson_convert_to(f"Yson::YPath({json_field}, {path})", binary.type)
363+
329364
def visit_regexp_match_op_binary(self, binary, operator, **kw):
330365
return self._generate_generic_binary(binary, " REGEXP ", **kw)
331366

@@ -336,7 +371,7 @@ def _is_bound_to_nullable_column(self, bind_name: str) -> bool:
336371
if bind_name in self.column_keys and hasattr(self.compile_state, "dml_table"):
337372
if bind_name in self.compile_state.dml_table.c:
338373
column = self.compile_state.dml_table.c[bind_name]
339-
return not column.primary_key
374+
return column.nullable and not column.primary_key
340375
return False
341376

342377
def _guess_bound_variable_type_by_parameters(
@@ -503,6 +538,7 @@ class YqlDialect(StrCompileDialect):
503538
supports_smallserial = False
504539
supports_schemas = False
505540
supports_constraint_comments = False
541+
supports_json_type = True
506542

507543
insert_returning = False
508544
update_returning = False
@@ -524,6 +560,10 @@ class YqlDialect(StrCompileDialect):
524560
statement_compiler = YqlCompiler
525561
ddl_compiler = YqlDDLCompiler
526562
type_compiler = YqlTypeCompiler
563+
colspecs = {
564+
sa.types.JSON: types.YqlJSON,
565+
sa.types.JSON.JSONPathType: types.YqlJSON.YqlJSONPathType,
566+
}
527567

528568
construct_arguments = [
529569
(
@@ -544,6 +584,12 @@ class YqlDialect(StrCompileDialect):
544584
def import_dbapi(cls: Any):
545585
return dbapi.YdbDBApi()
546586

587+
def __init__(self, json_serializer=None, json_deserializer=None, **kwargs):
588+
super().__init__(**kwargs)
589+
590+
self._json_deserializer = json_deserializer
591+
self._json_serializer = json_serializer
592+
547593
def _describe_table(self, connection, table_name, schema=None):
548594
if schema is not None:
549595
raise dbapi.NotSupportedError("unsupported on non empty schema")

ydb_sqlalchemy/sqlalchemy/json.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from typing import Tuple, Union
2+
3+
from sqlalchemy import types as sqltypes
4+
5+
6+
class YqlJSON(sqltypes.JSON):
7+
class YqlJSONPathType(sqltypes.JSON.JSONPathType):
8+
def _format_value(self, value: Tuple[Union[str, int]]) -> str:
9+
path = "/"
10+
for elem in value:
11+
path += f"/{elem}"
12+
return path
13+
14+
def bind_processor(self, dialect):
15+
super_proc = self.string_bind_processor(dialect)
16+
17+
def process(value: Tuple[Union[str, int]]):
18+
value = self._format_value(value)
19+
if super_proc:
20+
value = super_proc(value)
21+
return value
22+
23+
return process
24+
25+
def literal_processor(self, dialect):
26+
super_proc = self.string_literal_processor(dialect)
27+
28+
def process(value: Tuple[Union[str, int]]):
29+
value = self._format_value(value)
30+
if super_proc:
31+
value = super_proc(value)
32+
return value
33+
34+
return process

ydb_sqlalchemy/sqlalchemy/requirements.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33

44

55
class Requirements(SuiteRequirements):
6+
@property
7+
def json_type(self):
8+
return exclusions.open()
9+
610
@property
711
def array_type(self):
812
return exclusions.closed()

ydb_sqlalchemy/sqlalchemy/types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from sqlalchemy import ARRAY, ColumnElement, exc, types
44
from sqlalchemy.sql import type_api
55

6+
from .json import YqlJSON # noqa: F401
7+
68

79
class UInt64(types.Integer):
810
__visit_name__ = "uint64"

0 commit comments

Comments
 (0)