Skip to content

Commit ede1956

Browse files
authored
chore: refactor init and extract Query API (#93)
* refactor: extract query API * test: add pytest.ini with custom marker definition * fix: polars module check in query API
1 parent a030485 commit ede1956

File tree

6 files changed

+126
-73
lines changed

6 files changed

+126
-73
lines changed

influxdb_client_3/__init__.py

+8-56
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,17 @@
1-
import json
21
import urllib.parse
32

43
import pyarrow as pa
5-
from pyarrow.flight import FlightClient, Ticket, FlightCallOptions
4+
import importlib.util
65

6+
from influxdb_client_3.query.query_api import QueryApi as _QueryApi
77
from influxdb_client_3.read_file import UploadFile
88
from influxdb_client_3.write_client import InfluxDBClient as _InfluxDBClient, WriteOptions, Point
99
from influxdb_client_3.write_client.client.exceptions import InfluxDBError
1010
from influxdb_client_3.write_client.client.write_api import WriteApi as _WriteApi, SYNCHRONOUS, ASYNCHRONOUS, \
1111
PointSettings
1212
from influxdb_client_3.write_client.domain.write_precision import WritePrecision
13-
from influxdb_client_3.version import USER_AGENT
1413

15-
try:
16-
import polars as pl
17-
18-
polars = True
19-
except ImportError:
20-
polars = False
14+
polars = importlib.util.find_spec("polars") is not None
2115

2216

2317
def write_client_options(**kwargs):
@@ -144,23 +138,15 @@ def __init__(
144138
**kwargs)
145139

146140
self._write_api = _WriteApi(influxdb_client=self._client, **self._write_client_options)
147-
self._flight_client_options = flight_client_options or {}
148141

149142
if query_port_overwrite is not None:
150143
port = query_port_overwrite
151-
152-
gen_opts = [
153-
("grpc.secondary_user_agent", USER_AGENT)
154-
]
155-
156-
self._flight_client_options["generic_options"] = gen_opts
157-
158144
if scheme == 'https':
159145
connection_string = f"grpc+tls://{hostname}:{port}"
160146
else:
161147
connection_string = f"grpc+tcp://{hostname}:{port}"
162-
163-
self._flight_client = FlightClient(connection_string, **self._flight_client_options)
148+
self._query_api = _QueryApi(connection_string=connection_string, token=token,
149+
flight_client_options=flight_client_options)
164150

165151
def write(self, record=None, database=None, **kwargs):
166152
"""
@@ -258,48 +244,14 @@ def query(self, query: str, language: str = "sql", mode: str = "all", database:
258244
database = self._database
259245

260246
try:
261-
# Create an authorization header
262-
optargs = {
263-
"headers": [(b"authorization", f"Bearer {self._token}".encode('utf-8'))],
264-
"timeout": 300
265-
}
266-
opts = _merge_options(optargs, exclude_keys=['query_parameters'], custom=kwargs)
267-
_options = FlightCallOptions(**opts)
268-
269-
#
270-
# Ticket data
271-
#
272-
ticket_data = {
273-
"database": database,
274-
"sql_query": query,
275-
"query_type": language
276-
}
277-
# add query parameters
278-
query_parameters = kwargs.get("query_parameters", None)
279-
if query_parameters:
280-
ticket_data["params"] = query_parameters
281-
282-
ticket = Ticket(json.dumps(ticket_data).encode('utf-8'))
283-
flight_reader = self._flight_client.do_get(ticket, _options)
284-
285-
mode_func = {
286-
"all": flight_reader.read_all,
287-
"pandas": flight_reader.read_pandas,
288-
"polars": lambda: pl.from_arrow(flight_reader.read_all()),
289-
"chunk": lambda: flight_reader,
290-
"reader": flight_reader.to_reader,
291-
"schema": lambda: flight_reader.schema
292-
293-
}.get(mode, flight_reader.read_all)
294-
295-
return mode_func() if callable(mode_func) else mode_func
296-
except Exception as e:
247+
return self._query_api.query(query=query, language=language, mode=mode, database=database, **kwargs)
248+
except InfluxDBError as e:
297249
raise e
298250

299251
def close(self):
300252
"""Close the client and clean up resources."""
301253
self._write_api.close()
302-
self._flight_client.close()
254+
self._query_api.close()
303255
self._client.close()
304256

305257
def __enter__(self):

influxdb_client_3/query/query_api.py

+104
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
"""Query data in InfluxDB 3."""
2+
3+
# coding: utf-8
4+
import json
5+
6+
from pyarrow.flight import FlightClient, Ticket, FlightCallOptions, FlightStreamReader
7+
from influxdb_client_3.version import USER_AGENT
8+
9+
10+
class QueryApi(object):
11+
"""
12+
Implementation for '/api/v2/query' endpoint.
13+
14+
Example:
15+
.. code-block:: python
16+
17+
from influxdb_client import InfluxDBClient
18+
19+
20+
# Initialize instance of QueryApi
21+
with InfluxDBClient(url="http://localhost:8086", token="my-token", org="my-org") as client:
22+
query_api = client.query_api()
23+
"""
24+
25+
def __init__(self,
26+
connection_string,
27+
token,
28+
flight_client_options) -> None:
29+
"""
30+
Initialize defaults.
31+
32+
:param connection_string: Flight/gRPC connection string
33+
:param token: access token
34+
:param flight_client_options: Flight client options
35+
"""
36+
self._token = token
37+
self._flight_client_options = flight_client_options or {}
38+
self._flight_client_options["generic_options"] = [
39+
("grpc.secondary_user_agent", USER_AGENT)
40+
]
41+
self._flight_client = FlightClient(connection_string, **self._flight_client_options)
42+
43+
def query(self, query: str, language: str, mode: str, database: str, **kwargs):
44+
"""Query data from InfluxDB.
45+
46+
:param query: The query to execute on the database.
47+
:param language: The query language.
48+
:param mode: The mode to use for the query.
49+
It should be one of "all", "pandas", "polars", "chunk", "reader" or "schema".
50+
:param database: The database to query from.
51+
:param kwargs: Additional arguments to pass to the ``FlightCallOptions headers``.
52+
For example, it can be used to set up per request headers.
53+
:keyword query_parameters: The query parameters to use in the query.
54+
It should be a ``dictionary`` of key-value pairs.
55+
:return: The query result in the specified mode.
56+
"""
57+
from influxdb_client_3 import polars as has_polars, _merge_options as merge_options
58+
try:
59+
# Create an authorization header
60+
optargs = {
61+
"headers": [(b"authorization", f"Bearer {self._token}".encode('utf-8'))],
62+
"timeout": 300
63+
}
64+
opts = merge_options(optargs, exclude_keys=['query_parameters'], custom=kwargs)
65+
_options = FlightCallOptions(**opts)
66+
67+
#
68+
# Ticket data
69+
#
70+
ticket_data = {
71+
"database": database,
72+
"sql_query": query,
73+
"query_type": language
74+
}
75+
# add query parameters
76+
query_parameters = kwargs.get("query_parameters", None)
77+
if query_parameters:
78+
ticket_data["params"] = query_parameters
79+
80+
ticket = Ticket(json.dumps(ticket_data).encode('utf-8'))
81+
flight_reader = self._do_get(ticket, _options)
82+
83+
mode_funcs = {
84+
"all": flight_reader.read_all,
85+
"pandas": flight_reader.read_pandas,
86+
"chunk": lambda: flight_reader,
87+
"reader": flight_reader.to_reader,
88+
"schema": lambda: flight_reader.schema
89+
}
90+
if has_polars:
91+
import polars as pl
92+
mode_funcs["polars"] = lambda: pl.from_arrow(flight_reader.read_all())
93+
mode_func = mode_funcs.get(mode, flight_reader.read_all)
94+
95+
return mode_func() if callable(mode_func) else mode_func
96+
except Exception as e:
97+
raise e
98+
99+
def _do_get(self, ticket: Ticket, options: FlightCallOptions = None) -> FlightStreamReader:
100+
return self._flight_client.do_get(ticket, options)
101+
102+
def close(self):
103+
"""Close the Flight client."""
104+
self._flight_client.close()

pytest.ini

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[pytest]
2+
markers =
3+
integration: marks integration tests (deselect with '-m "not integration"')

tests/test_influxdb_client_3.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@ class TestInfluxDBClient3(unittest.TestCase):
88

99
@patch('influxdb_client_3._InfluxDBClient')
1010
@patch('influxdb_client_3._WriteApi')
11-
@patch('influxdb_client_3.FlightClient')
12-
def setUp(self, mock_flight_client, mock_write_api, mock_influx_db_client):
11+
@patch('influxdb_client_3._QueryApi')
12+
def setUp(self, mock_query_api, mock_write_api, mock_influx_db_client):
1313
self.mock_influx_db_client = mock_influx_db_client
1414
self.mock_write_api = mock_write_api
15-
self.mock_flight_client = mock_flight_client
15+
self.mock_query_api = mock_query_api
1616
self.client = InfluxDBClient3(
1717
host="localhost",
1818
org="my_org",
@@ -25,7 +25,7 @@ def test_init(self):
2525
self.assertEqual(self.client._database, "my_db")
2626
self.assertEqual(self.client._client, self.mock_influx_db_client.return_value)
2727
self.assertEqual(self.client._write_api, self.mock_write_api.return_value)
28-
self.assertEqual(self.client._flight_client, self.mock_flight_client.return_value)
28+
self.assertEqual(self.client._query_api, self.mock_query_api.return_value)
2929

3030

3131
if __name__ == '__main__':

tests/test_influxdb_client_3_integration.py

+1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def test_write_and_query(self):
3939

4040
df = self.client.query(sql, mode="pandas", query_parameters={'type': 'used', 'test_id': test_id})
4141

42+
self.assertIsNotNone(df)
4243
self.assertEqual(1, len(df))
4344
self.assertEqual(test_id, df['test_id'][0])
4445
self.assertEqual(123.0, df['value'][0])

tests/test_query.py

+6-13
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import unittest
22
import struct
3-
from unittest.mock import Mock, patch, ANY
3+
from unittest.mock import Mock, ANY
44

55
from pyarrow import (
66
array,
@@ -66,7 +66,8 @@ class HeaderCheckServerMiddleware(ServerMiddleware):
6666
Middleware needed to catch request headers via factory
6767
N.B. As found in pyarrow tests
6868
"""
69-
def __init__(self, token):
69+
def __init__(self, token, *args, **kwargs):
70+
super().__init__(*args, **kwargs)
7071
self.token = token
7172

7273
def sending_headers(self):
@@ -114,25 +115,17 @@ def test_influx_default_query_headers():
114115

115116
class TestQuery(unittest.TestCase):
116117

117-
@patch('influxdb_client_3._InfluxDBClient')
118-
@patch('influxdb_client_3._WriteApi')
119-
@patch('influxdb_client_3.FlightClient')
120-
def setUp(self, mock_flight_client, mock_write_api, mock_influx_db_client):
121-
self.mock_influx_db_client = mock_influx_db_client
122-
self.mock_write_api = mock_write_api
123-
self.mock_flight_client = mock_flight_client
118+
def setUp(self):
124119
self.client = InfluxDBClient3(
125120
host="localhost",
126121
org="my_org",
127122
database="my_db",
128123
token="my_token"
129124
)
130-
self.client._flight_client = mock_flight_client
131-
self.client._write_api = mock_write_api
132125

133126
def test_query_without_parameters(self):
134127
mock_do_get = Mock()
135-
self.client._flight_client.do_get = mock_do_get
128+
self.client._query_api._do_get = mock_do_get
136129

137130
self.client.query('SELECT * FROM measurement')
138131

@@ -146,7 +139,7 @@ def test_query_without_parameters(self):
146139

147140
def test_query_with_parameters(self):
148141
mock_do_get = Mock()
149-
self.client._flight_client.do_get = mock_do_get
142+
self.client._query_api._do_get = mock_do_get
150143

151144
self.client.query('SELECT * FROM measurement WHERE time > $time', query_parameters={"time": "2021-01-01"})
152145

0 commit comments

Comments
 (0)