99from __future__ import unicode_literals
1010
1111from builtins import object
12+ from decimal import Decimal
13+
1214from pyhive import common
1315from pyhive .common import DBAPITypeObject
1416# Make all exceptions visible in this module per DB-API
3436
3537_logger = logging .getLogger (__name__ )
3638
39+ TYPES_CONVERTER = {
40+ "decimal" : Decimal ,
41+ # As of Presto 0.69, binary data is returned as the varbinary type in base64 format
42+ "varbinary" : base64 .b64decode
43+ }
3744
3845class PrestoParamEscaper (common .ParamEscaper ):
3946 def escape_datetime (self , item , format ):
@@ -307,14 +314,13 @@ def _fetch_more(self):
307314 """Fetch the next URI and update state"""
308315 self ._process_response (self ._requests_session .get (self ._nextUri , ** self ._requests_kwargs ))
309316
310- def _decode_binary (self , rows ):
311- # As of Presto 0.69, binary data is returned as the varbinary type in base64 format
312- # This function decodes base64 data in place
317+ def _process_data (self , rows ):
313318 for i , col in enumerate (self .description ):
314- if col [1 ] == 'varbinary' :
319+ col_type = col [1 ].split ("(" )[0 ].lower ()
320+ if col_type in TYPES_CONVERTER :
315321 for row in rows :
316322 if row [i ] is not None :
317- row [i ] = base64 . b64decode (row [i ])
323+ row [i ] = TYPES_CONVERTER [ col_type ] (row [i ])
318324
319325 def _process_response (self , response ):
320326 """Given the JSON response from Presto's REST API, update the internal state with the next
@@ -341,7 +347,7 @@ def _process_response(self, response):
341347 if 'data' in response_json :
342348 assert self ._columns
343349 new_data = response_json ['data' ]
344- self ._decode_binary (new_data )
350+ self ._process_data (new_data )
345351 self ._data += map (tuple , new_data )
346352 if 'nextUri' not in response_json :
347353 self ._state = self ._STATE_FINISHED
0 commit comments