Skip to content

Commit e417531

Browse files
committed
refactor[py]: Add type hints to remote_connection.py
- Added explicit type annotations to selenium.webdriver.remote.remote_connection.py - Improved code clarity and static type checking - Ensured compatibility with modern type checkers like Pyright and Mypy
1 parent 5c2959b commit e417531

File tree

1 file changed

+22
-75
lines changed

1 file changed

+22
-75
lines changed

Diff for: py/selenium/webdriver/remote/remote_connection.py

+22-75
Original file line numberDiff line numberDiff line change
@@ -275,19 +275,11 @@ class RemoteConnection:
275275
import certifi
276276

277277
_timeout = (
278-
float(
279-
os.getenv(
280-
"GLOBAL_DEFAULT_TIMEOUT", str(socket.getdefaulttimeout())
281-
)
282-
)
278+
float(os.getenv("GLOBAL_DEFAULT_TIMEOUT", str(socket.getdefaulttimeout())))
283279
if os.getenv("GLOBAL_DEFAULT_TIMEOUT") is not None
284280
else socket.getdefaulttimeout()
285281
)
286-
_ca_certs = (
287-
os.getenv("REQUESTS_CA_BUNDLE")
288-
if "REQUESTS_CA_BUNDLE" in os.environ
289-
else certifi.where()
290-
)
282+
_ca_certs = os.getenv("REQUESTS_CA_BUNDLE") if "REQUESTS_CA_BUNDLE" in os.environ else certifi.where()
291283
_client_config: ClientConfig = None
292284

293285
system = platform.system().lower()
@@ -368,9 +360,7 @@ def set_certificate_bundle_path(cls, path: str):
368360
cls._client_config.ca_certs = path
369361

370362
@classmethod
371-
def get_remote_connection_headers(
372-
cls, parsed_url: str, keep_alive: bool = False
373-
) -> dict[str, Any]:
363+
def get_remote_connection_headers(cls, parsed_url: str, keep_alive: bool = False) -> dict[str, Any]:
374364
"""Get headers for remote request.
375365
376366
:Args:
@@ -389,12 +379,8 @@ def get_remote_connection_headers(
389379
"Embedding username and password in URL could be insecure, use ClientConfig instead",
390380
stacklevel=2,
391381
)
392-
base64string = b64encode(
393-
f"{parsed_url.username}:{parsed_url.password}".encode()
394-
)
395-
headers.update(
396-
{"Authorization": f"Basic {base64string.decode()}"}
397-
)
382+
base64string = b64encode(f"{parsed_url.username}:{parsed_url.password}".encode())
383+
headers.update({"Authorization": f"Basic {base64string.decode()}"})
398384

399385
if keep_alive:
400386
headers.update({"Connection": "keep-alive"})
@@ -411,25 +397,19 @@ def _identify_http_proxy_auth(self):
411397

412398
def _separate_http_proxy_auth(self):
413399
parsed_url = urlparse(self._proxy_url)
414-
proxy_without_auth = (
415-
f"{parsed_url.scheme}://{parsed_url.hostname}:{parsed_url.port}"
416-
)
400+
proxy_without_auth = f"{parsed_url.scheme}://{parsed_url.hostname}:{parsed_url.port}"
417401
auth = f"{parsed_url.username}:{parsed_url.password}"
418402
return proxy_without_auth, auth
419403

420404
def _get_connection_manager(self):
421405
pool_manager_init_args = {"timeout": self._client_config.timeout}
422406
pool_manager_init_args.update(
423-
self._client_config.init_args_for_pool_manager.get(
424-
"init_args_for_pool_manager", {}
425-
)
407+
self._client_config.init_args_for_pool_manager.get("init_args_for_pool_manager", {})
426408
)
427409

428410
if self._client_config.ignore_certificates:
429411
pool_manager_init_args["cert_reqs"] = "CERT_NONE"
430-
urllib3.disable_warnings(
431-
urllib3.exceptions.InsecureRequestWarning
432-
)
412+
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
433413
elif self._client_config.ca_certs:
434414
pool_manager_init_args["cert_reqs"] = "CERT_REQUIRED"
435415
pool_manager_init_args["ca_certs"] = self._client_config.ca_certs
@@ -438,21 +418,11 @@ def _get_connection_manager(self):
438418
if self._proxy_url.lower().startswith("sock"):
439419
from urllib3.contrib.socks import SOCKSProxyManager
440420

441-
return SOCKSProxyManager(
442-
self._proxy_url, **pool_manager_init_args
443-
)
421+
return SOCKSProxyManager(self._proxy_url, **pool_manager_init_args)
444422
if self._identify_http_proxy_auth():
445-
self._proxy_url, self._basic_proxy_auth = (
446-
self._separate_http_proxy_auth()
447-
)
448-
pool_manager_init_args["proxy_headers"] = (
449-
urllib3.make_headers(
450-
proxy_basic_auth=self._basic_proxy_auth
451-
)
452-
)
453-
return urllib3.ProxyManager(
454-
self._proxy_url, **pool_manager_init_args
455-
)
423+
self._proxy_url, self._basic_proxy_auth = self._separate_http_proxy_auth()
424+
pool_manager_init_args["proxy_headers"] = urllib3.make_headers(proxy_basic_auth=self._basic_proxy_auth)
425+
return urllib3.ProxyManager(self._proxy_url, **pool_manager_init_args)
456426

457427
return urllib3.PoolManager(**pool_manager_init_args)
458428

@@ -476,13 +446,8 @@ def __init__(
476446
RemoteConnection._timeout = self._client_config.timeout
477447
RemoteConnection._ca_certs = self._client_config.ca_certs
478448
RemoteConnection._client_config = self._client_config
479-
RemoteConnection.extra_headers = (
480-
self._client_config.extra_headers
481-
or RemoteConnection.extra_headers
482-
)
483-
RemoteConnection.user_agent = (
484-
self._client_config.user_agent or RemoteConnection.user_agent
485-
)
449+
RemoteConnection.extra_headers = self._client_config.extra_headers or RemoteConnection.extra_headers
450+
RemoteConnection.user_agent = self._client_config.user_agent or RemoteConnection.user_agent
486451

487452
if remote_server_addr:
488453
warnings.warn(
@@ -547,17 +512,11 @@ def execute(self, command: str, params: dict[Any, Any]) -> dict[str, Any]:
547512
- params - A dictionary of named parameters to send with the command as
548513
its JSON payload.
549514
"""
550-
command_info = self._commands.get(command) or self.extra_commands.get(
551-
command
552-
)
515+
command_info = self._commands.get(command) or self.extra_commands.get(command)
553516
assert command_info is not None, f"Unrecognised command {command}"
554517
path_string = command_info[1]
555518
path = string.Template(path_string).substitute(params)
556-
substitute_params = {
557-
word[1:]
558-
for word in path_string.split("/")
559-
if word.startswith("$")
560-
} # remove dollar sign
519+
substitute_params = {word[1:] for word in path_string.split("/") if word.startswith("$")} # remove dollar sign
561520
if isinstance(params, dict) and substitute_params:
562521
for word in substitute_params:
563522
del params[word]
@@ -567,9 +526,7 @@ def execute(self, command: str, params: dict[Any, Any]) -> dict[str, Any]:
567526
LOGGER.debug("%s %s %s", command_info[0], url, str(trimmed))
568527
return self._request(command_info[0], url, body=data)
569528

570-
def _request(
571-
self, method: str, url: str, body: str | None = None
572-
) -> dict[Any, Any]:
529+
def _request(self, method: str, url: str, body: str | None = None) -> dict[Any, Any]:
573530
"""Send an HTTP request to the remote server.
574531
575532
:Args:
@@ -581,9 +538,7 @@ def _request(
581538
A dictionary with the server's parsed JSON response.
582539
"""
583540
parsed_url = parse.urlparse(url)
584-
headers = self.get_remote_connection_headers(
585-
parsed_url, self._client_config.keep_alive
586-
)
541+
headers = self.get_remote_connection_headers(parsed_url, self._client_config.keep_alive)
587542
auth_header = self._client_config.get_auth_header()
588543

589544
if auth_header:
@@ -621,9 +576,7 @@ def _request(
621576
)
622577
try:
623578
if 300 <= statuscode < 304:
624-
return self._request(
625-
"GET", response.headers.get("location", None)
626-
)
579+
return self._request("GET", response.headers.get("location", None))
627580
if 399 < statuscode <= 500:
628581
if statuscode == 401:
629582
return {
@@ -636,9 +589,7 @@ def _request(
636589
}
637590
content_type = []
638591
if response.headers.get("Content-Type", None):
639-
content_type = response.headers.get(
640-
"Content-Type", None
641-
).split(";")
592+
content_type = response.headers.get("Content-Type", None).split(";")
642593
if not any([x.startswith("image/png") for x in content_type]):
643594
try:
644595
data = utils.load_json(data.strip())
@@ -665,9 +616,7 @@ def close(self):
665616
if hasattr(self, "_conn"):
666617
self._conn.clear()
667618

668-
def _trim_large_entries(
669-
self, input_dict: dict[Any, Any], max_length: int = 100
670-
) -> dict[str, str]:
619+
def _trim_large_entries(self, input_dict: dict[Any, Any], max_length: int = 100) -> dict[str, str]:
671620
"""Truncate string values in a dictionary if they exceed max_length.
672621
673622
:param dict: Dictionary with potentially large values
@@ -677,9 +626,7 @@ def _trim_large_entries(
677626
output_dictionary = {}
678627
for key, value in input_dict.items():
679628
if isinstance(value, dict):
680-
output_dictionary[key] = self._trim_large_entries(
681-
value, max_length
682-
)
629+
output_dictionary[key] = self._trim_large_entries(value, max_length)
683630
elif isinstance(value, str) and len(value) > max_length:
684631
output_dictionary[key] = value[:max_length] + "..."
685632
else:

0 commit comments

Comments
 (0)