|
18 | 18 | import logging
|
19 | 19 | import socket
|
20 | 20 | import time
|
| 21 | +from functools import cached_property |
21 | 22 | from types import TracebackType
|
22 | 23 | from typing import (
|
23 | 24 | TYPE_CHECKING,
|
@@ -143,40 +144,47 @@ class _HiveClient:
|
143 | 144 | """Helper class to nicely open and close the transport."""
|
144 | 145 |
|
145 | 146 | _transport: TTransport
|
146 |
| - _client: Client |
147 | 147 | _ugi: Optional[List[str]]
|
148 | 148 |
|
149 | 149 | def __init__(self, uri: str, ugi: Optional[str] = None, kerberos_auth: Optional[bool] = HIVE_KERBEROS_AUTH_DEFAULT):
|
150 | 150 | self._uri = uri
|
151 | 151 | self._kerberos_auth = kerberos_auth
|
152 | 152 | self._ugi = ugi.split(":") if ugi else None
|
| 153 | + self._transport = self._init_thrift_transport() |
153 | 154 |
|
154 |
| - self._init_thrift_client() |
155 |
| - |
156 |
| - def _init_thrift_client(self) -> None: |
| 155 | + def _init_thrift_transport(self) -> TTransport: |
157 | 156 | url_parts = urlparse(self._uri)
|
158 |
| - |
159 | 157 | socket = TSocket.TSocket(url_parts.hostname, url_parts.port)
|
160 |
| - |
161 | 158 | if not self._kerberos_auth:
|
162 |
| - self._transport = TTransport.TBufferedTransport(socket) |
| 159 | + return TTransport.TBufferedTransport(socket) |
163 | 160 | else:
|
164 |
| - self._transport = TTransport.TSaslClientTransport(socket, host=url_parts.hostname, service="hive") |
| 161 | + return TTransport.TSaslClientTransport(socket, host=url_parts.hostname, service="hive") |
165 | 162 |
|
| 163 | + @cached_property |
| 164 | + def _client(self) -> Client: |
166 | 165 | protocol = TBinaryProtocol.TBinaryProtocol(self._transport)
|
167 |
| - |
168 |
| - self._client = Client(protocol) |
| 166 | + client = Client(protocol) |
| 167 | + if self._ugi: |
| 168 | + client.set_ugi(*self._ugi) |
| 169 | + return client |
169 | 170 |
|
170 | 171 | def __enter__(self) -> Client:
|
171 |
| - self._transport.open() |
172 |
| - if self._ugi: |
173 |
| - self._client.set_ugi(*self._ugi) |
| 172 | + """Make sure the transport is initialized and open.""" |
| 173 | + if not self._transport.isOpen(): |
| 174 | + try: |
| 175 | + self._transport.open() |
| 176 | + except TTransport.TTransportException: |
| 177 | + # reinitialize _transport |
| 178 | + self._transport = self._init_thrift_transport() |
| 179 | + self._transport.open() |
174 | 180 | return self._client
|
175 | 181 |
|
176 | 182 | def __exit__(
|
177 | 183 | self, exctype: Optional[Type[BaseException]], excinst: Optional[BaseException], exctb: Optional[TracebackType]
|
178 | 184 | ) -> None:
|
179 |
| - self._transport.close() |
| 185 | + """Close transport if it was opened.""" |
| 186 | + if self._transport.isOpen(): |
| 187 | + self._transport.close() |
180 | 188 |
|
181 | 189 |
|
182 | 190 | def _construct_hive_storage_descriptor(
|
|
0 commit comments