Skip to content

[py] Set user_agent and extra_headers via ClientConfig #14718

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Nov 9, 2024
55 changes: 44 additions & 11 deletions py/selenium/webdriver/remote/client_config.py
Original file line number Diff line number Diff line change
@@ -17,6 +17,7 @@
import base64
import os
import socket
from enum import Enum
from typing import Optional
from urllib import parse

@@ -26,6 +27,12 @@
from selenium.webdriver.common.proxy import ProxyType


class AuthType(Enum):
BASIC = "Basic"
BEARER = "Bearer"
X_API_KEY = "X-API-Key"


class ClientConfig:
def __init__(
self,
@@ -38,8 +45,10 @@ def __init__(
ca_certs: Optional[str] = None,
username: Optional[str] = None,
password: Optional[str] = None,
auth_type: Optional[str] = "Basic",
auth_type: Optional[AuthType] = AuthType.BASIC,
token: Optional[str] = None,
user_agent: Optional[str] = None,
extra_headers: Optional[dict] = None,
) -> None:
self.remote_server_addr = remote_server_addr
self.keep_alive = keep_alive
@@ -51,6 +60,8 @@ def __init__(
self.password = password
self.auth_type = auth_type
self.token = token
self.user_agent = user_agent
self.extra_headers = extra_headers

self.timeout = (
(
@@ -198,14 +209,17 @@ def password(self, value: str) -> None:
self._password = value

@property
def auth_type(self) -> str:
def auth_type(self) -> AuthType:
"""Returns the type of authentication to the remote server."""
return self._auth_type

@auth_type.setter
def auth_type(self, value: str) -> None:
def auth_type(self, value: AuthType) -> None:
"""Sets the type of authentication to the remote server if it is not
using basic with username and password."""
using basic with username and password.
:Args: value - AuthType enum value. For others, please use `extra_headers` instead
"""
self._auth_type = value

@property
@@ -219,6 +233,26 @@ def token(self, value: str) -> None:
auth_type is not basic."""
self._token = value

@property
def user_agent(self) -> str:
"""Returns user agent to be added to the request headers."""
return self._user_agent

@user_agent.setter
def user_agent(self, value: str) -> None:
"""Sets user agent to be added to the request headers."""
self._user_agent = value

@property
def extra_headers(self) -> dict:
"""Returns extra headers to be added to the request."""
return self._extra_headers

@extra_headers.setter
def extra_headers(self, value: dict) -> None:
"""Sets extra headers to be added to the request."""
self._extra_headers = value

def get_proxy_url(self) -> Optional[str]:
"""Returns the proxy URL to use for the connection."""
proxy_type = self.proxy.proxy_type
@@ -246,13 +280,12 @@ def get_proxy_url(self) -> Optional[str]:

def get_auth_header(self) -> Optional[dict]:
"""Returns the authorization to add to the request headers."""
auth_type = self.auth_type.lower()
if auth_type == "basic" and self.username and self.password:
if self.auth_type is AuthType.BASIC and self.username and self.password:
credentials = f"{self.username}:{self.password}"
encoded_credentials = base64.b64encode(credentials.encode("utf-8")).decode("utf-8")
return {"Authorization": f"Basic {encoded_credentials}"}
if auth_type == "bearer" and self.token:
return {"Authorization": f"Bearer {self.token}"}
if auth_type == "oauth" and self.token:
return {"Authorization": f"OAuth {self.token}"}
return {"Authorization": f"{AuthType.BASIC.value} {encoded_credentials}"}
if self.auth_type is AuthType.BEARER and self.token:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this looks so much better!

return {"Authorization": f"{AuthType.BEARER.value} {self.token}"}
if self.auth_type is AuthType.X_API_KEY and self.token:
return {f"{AuthType.X_API_KEY.value}": f"{self.token}"}
return None
20 changes: 12 additions & 8 deletions py/selenium/webdriver/remote/remote_connection.py
Original file line number Diff line number Diff line change
@@ -22,6 +22,7 @@
from base64 import b64encode
from typing import Optional
from urllib import parse
from urllib.parse import urlparse

import urllib3

@@ -243,6 +244,9 @@ def get_remote_connection_headers(cls, parsed_url, keep_alive=False):
}

if parsed_url.username:
warnings.warn(
"Embedding username and password in URL could be insecure, use ClientConfig instead", stacklevel=2
)
base64string = b64encode(f"{parsed_url.username}:{parsed_url.password}".encode())
headers.update({"Authorization": f"Basic {base64string.decode()}"})

@@ -255,16 +259,14 @@ def get_remote_connection_headers(cls, parsed_url, keep_alive=False):
return headers

def _identify_http_proxy_auth(self):
url = self._proxy_url
url = url[url.find(":") + 3 :]
return "@" in url and len(url[: url.find("@")]) > 0
parsed_url = urlparse(self._proxy_url)
if parsed_url.username and parsed_url.password:
return True

def _separate_http_proxy_auth(self):
url = self._proxy_url
protocol = url[: url.find(":") + 3]
no_protocol = url[len(protocol) :]
auth = no_protocol[: no_protocol.find("@")]
proxy_without_auth = protocol + no_protocol[len(auth) + 1 :]
parsed_url = urlparse(self._proxy_url)
proxy_without_auth = f"{parsed_url.scheme}://{parsed_url.hostname}:{parsed_url.port}"
auth = f"{parsed_url.username}:{parsed_url.password}"
return proxy_without_auth, auth

def _get_connection_manager(self):
@@ -312,6 +314,8 @@ def __init__(
RemoteConnection._timeout = self._client_config.timeout
RemoteConnection._ca_certs = self._client_config.ca_certs
RemoteConnection._client_config = self._client_config
RemoteConnection.extra_headers = self._client_config.extra_headers or RemoteConnection.extra_headers
RemoteConnection.user_agent = self._client_config.user_agent or RemoteConnection.user_agent

if remote_server_addr:
warnings.warn(
160 changes: 149 additions & 11 deletions py/test/unit/selenium/webdriver/remote/remote_connection_tests.py
Original file line number Diff line number Diff line change
@@ -15,13 +15,19 @@
# specific language governing permissions and limitations
# under the License.

import os
from unittest.mock import patch
from urllib import parse

import pytest
import urllib3
from urllib3.util import Retry
from urllib3.util import Timeout

from selenium import __version__
from selenium.webdriver import Proxy
from selenium.webdriver.common.proxy import ProxyType
from selenium.webdriver.remote.client_config import AuthType
from selenium.webdriver.remote.remote_connection import ClientConfig
from selenium.webdriver.remote.remote_connection import RemoteConnection

@@ -64,8 +70,13 @@ def test_get_remote_connection_headers_defaults():

def test_get_remote_connection_headers_adds_auth_header_if_pass():
url = "http://user:pass@remote"
headers = RemoteConnection.get_remote_connection_headers(parse.urlparse(url))
with pytest.warns(None) as record:
headers = RemoteConnection.get_remote_connection_headers(parse.urlparse(url))
assert headers.get("Authorization") == "Basic dXNlcjpwYXNz"
assert (
record[0].message.args[0]
== "Embedding username and password in URL could be insecure, use ClientConfig instead"
)


def test_get_remote_connection_headers_adds_keep_alive_if_requested():
@@ -81,22 +92,96 @@ def test_get_proxy_url_http(mock_proxy_settings):
assert proxy_url == proxy


def test_get_auth_header_if_client_config_pass():
def test_get_auth_header_if_client_config_pass_basic_auth():
custom_config = ClientConfig(
remote_server_addr="http://remote", keep_alive=True, username="user", password="pass", auth_type="Basic"
remote_server_addr="http://remote", keep_alive=True, username="user", password="pass", auth_type=AuthType.BASIC
)
remote_connection = RemoteConnection(custom_config.remote_server_addr, client_config=custom_config)
headers = remote_connection._client_config.get_auth_header()
assert headers.get("Authorization") == "Basic dXNlcjpwYXNz"


def test_get_auth_header_if_client_config_pass_bearer_token():
custom_config = ClientConfig(
remote_server_addr="http://remote", keep_alive=True, auth_type=AuthType.BEARER, token="dXNlcjpwYXNz"
)
remote_connection = RemoteConnection(custom_config.remote_server_addr, client_config=custom_config)
headers = remote_connection._client_config.get_auth_header()
assert headers.get("Authorization") == "Bearer dXNlcjpwYXNz"


def test_get_auth_header_if_client_config_pass_x_api_key():
custom_config = ClientConfig(
remote_server_addr="http://remote", keep_alive=True, auth_type=AuthType.X_API_KEY, token="abcdefgh123456789"
)
remote_connection = RemoteConnection(custom_config.remote_server_addr, client_config=custom_config)
headers = remote_connection._client_config.get_auth_header()
assert headers.get("X-API-Key") == "abcdefgh123456789"


def test_get_proxy_url_https(mock_proxy_settings):
proxy = "http://https_proxy.com:8080"
remote_connection = RemoteConnection("https://remote", keep_alive=False)
proxy_url = remote_connection._client_config.get_proxy_url()
assert proxy_url == proxy


def test_get_proxy_url_https_via_client_config():
client_config = ClientConfig(
remote_server_addr="https://localhost:4444",
proxy=Proxy({"proxyType": ProxyType.MANUAL, "sslProxy": "https://admin:admin@http_proxy.com:8080"}),
)
remote_connection = RemoteConnection(client_config=client_config)
conn = remote_connection._get_connection_manager()
assert isinstance(conn, urllib3.ProxyManager)
conn.proxy_url = "https://http_proxy.com:8080"
conn.connection_pool_kw["proxy_headers"] = urllib3.make_headers(proxy_basic_auth="admin:admin")


def test_get_proxy_url_http_via_client_config():
client_config = ClientConfig(
remote_server_addr="http://localhost:4444",
proxy=Proxy(
{
"proxyType": ProxyType.MANUAL,
"httpProxy": "http://admin:admin@http_proxy.com:8080",
"sslProxy": "https://admin:admin@http_proxy.com:8080",
}
),
)
remote_connection = RemoteConnection(client_config=client_config)
conn = remote_connection._get_connection_manager()
assert isinstance(conn, urllib3.ProxyManager)
conn.proxy_url = "http://http_proxy.com:8080"
conn.connection_pool_kw["proxy_headers"] = urllib3.make_headers(proxy_basic_auth="admin:admin")


def test_get_proxy_direct_via_client_config():
client_config = ClientConfig(
remote_server_addr="http://localhost:4444", proxy=Proxy({"proxyType": ProxyType.DIRECT})
)
remote_connection = RemoteConnection(client_config=client_config)
conn = remote_connection._get_connection_manager()
assert isinstance(conn, urllib3.PoolManager)
proxy_url = remote_connection._client_config.get_proxy_url()
assert proxy_url is None


def test_get_proxy_system_matches_no_proxy_via_client_config():
os.environ["HTTP_PROXY"] = "http://admin:admin@system_proxy.com:8080"
os.environ["NO_PROXY"] = "localhost,127.0.0.1"
client_config = ClientConfig(
remote_server_addr="http://localhost:4444", proxy=Proxy({"proxyType": ProxyType.SYSTEM})
)
remote_connection = RemoteConnection(client_config=client_config)
conn = remote_connection._get_connection_manager()
assert isinstance(conn, urllib3.PoolManager)
proxy_url = remote_connection._client_config.get_proxy_url()
assert proxy_url is None
os.environ.pop("HTTP_PROXY")
os.environ.pop("NO_PROXY")


def test_get_proxy_url_none(mock_proxy_settings_missing):
remote_connection = RemoteConnection("https://remote", keep_alive=False)
proxy_url = remote_connection._client_config.get_proxy_url()
@@ -295,6 +380,28 @@ def test_override_user_agent_in_headers(mock_get_remote_connection_headers, remo
assert headers.get("Content-Type") == "application/json;charset=UTF-8"


@patch("selenium.webdriver.remote.remote_connection.RemoteConnection.get_remote_connection_headers")
def test_override_user_agent_via_client_config(mock_get_remote_connection_headers):
client_config = ClientConfig(
remote_server_addr="http://localhost:4444",
user_agent="custom-agent/1.0 (python 3.8)",
extra_headers={"Content-Type": "application/xml;charset=UTF-8"},
)
remote_connection = RemoteConnection(client_config=client_config)

mock_get_remote_connection_headers.return_value = {
"Accept": "application/json",
"Content-Type": "application/xml;charset=UTF-8",
"User-Agent": "custom-agent/1.0 (python 3.8)",
}

headers = remote_connection.get_remote_connection_headers(parse.urlparse("http://localhost:4444"))

assert headers.get("User-Agent") == "custom-agent/1.0 (python 3.8)"
assert headers.get("Accept") == "application/json"
assert headers.get("Content-Type") == "application/xml;charset=UTF-8"


@patch("selenium.webdriver.remote.remote_connection.RemoteConnection._request")
def test_register_extra_headers(mock_request, remote_connection):
RemoteConnection.extra_headers = {"Foo": "bar"}
@@ -307,6 +414,26 @@ def test_register_extra_headers(mock_request, remote_connection):
assert headers["Foo"] == "bar"


@patch("selenium.webdriver.remote.remote_connection.RemoteConnection._request")
def test_register_extra_headers_via_client_config(mock_request):
client_config = ClientConfig(
remote_server_addr="http://localhost:4444",
extra_headers={
"Authorization": "AWS4-HMAC-SHA256",
"Credential": "abc/20200618/us-east-1/execute-api/aws4_request",
},
)
remote_connection = RemoteConnection(client_config=client_config)

mock_request.return_value = {"status": 200, "value": "OK"}
remote_connection.execute("newSession", {})

mock_request.assert_called_once_with("POST", "http://localhost:4444/session", body="{}")
headers = remote_connection.get_remote_connection_headers(parse.urlparse("http://localhost:4444"), False)
assert headers["Authorization"] == "AWS4-HMAC-SHA256"
assert headers["Credential"] == "abc/20200618/us-east-1/execute-api/aws4_request"


def test_backwards_compatibility_with_appium_connection():
# Keep backward compatibility for AppiumConnection - https://github.com/SeleniumHQ/selenium/issues/14694
client_config = ClientConfig(remote_server_addr="http://remote", ca_certs="/path/to/cacert.pem", timeout=300)
@@ -328,14 +455,16 @@ def test_get_connection_manager_with_timeout_from_client_config():
assert conn.connection_pool_kw["timeout"] == 10
assert isinstance(conn, urllib3.PoolManager)


def test_connection_manager_with_timeout_via_client_config():
client_config = ClientConfig("http://remote", timeout=300)
remote_connection = RemoteConnection(client_config=client_config)
conn = remote_connection._get_connection_manager()
assert conn.connection_pool_kw["timeout"] == 300
assert isinstance(conn, urllib3.PoolManager)


def test_get_connection_manager_with_ca_certs_from_client_config():
def test_get_connection_manager_with_ca_certs():
remote_connection = RemoteConnection(remote_server_addr="http://remote")
remote_connection.set_certificate_bundle_path("/path/to/cacert.pem")
conn = remote_connection._get_connection_manager()
@@ -344,6 +473,8 @@ def test_get_connection_manager_with_ca_certs_from_client_config():
assert conn.connection_pool_kw["ca_certs"] == "/path/to/cacert.pem"
assert isinstance(conn, urllib3.PoolManager)


def test_connection_manager_with_ca_certs_via_client_config():
client_config = ClientConfig(remote_server_addr="http://remote", ca_certs="/path/to/cacert.pem")
remote_connection = RemoteConnection(client_config=client_config)
conn = remote_connection._get_connection_manager()
@@ -361,15 +492,17 @@ def test_get_connection_manager_ignores_certificates():
assert conn.connection_pool_kw["cert_reqs"] == "CERT_NONE"
assert isinstance(conn, urllib3.PoolManager)

remote_connection.reset_timeout()
assert remote_connection.get_timeout() is None


def test_connection_manager_ignores_certificates_via_client_config():
client_config = ClientConfig(remote_server_addr="http://remote", ignore_certificates=True, timeout=10)
remote_connection = RemoteConnection(client_config=client_config)
conn = remote_connection._get_connection_manager()
assert isinstance(conn, urllib3.PoolManager)
assert conn.connection_pool_kw["timeout"] == 10
assert conn.connection_pool_kw["cert_reqs"] == "CERT_NONE"
assert isinstance(conn, urllib3.PoolManager)

remote_connection.reset_timeout()
assert remote_connection.get_timeout() is None


def test_get_connection_manager_with_custom_args():
@@ -383,11 +516,16 @@ def test_get_connection_manager_with_custom_args():
assert conn.connection_pool_kw["retries"] == 3
assert conn.connection_pool_kw["block"] is True


def test_connection_manager_with_custom_args_via_client_config():
retries = Retry(connect=2, read=2, redirect=2)
timeout = Timeout(connect=300, read=3600)
client_config = ClientConfig(
remote_server_addr="http://remote", keep_alive=False, init_args_for_pool_manager=custom_args
remote_server_addr="http://localhost:4444",
init_args_for_pool_manager={"init_args_for_pool_manager": {"retries": retries, "timeout": timeout}},
)
remote_connection = RemoteConnection(client_config=client_config)
conn = remote_connection._get_connection_manager()
assert isinstance(conn, urllib3.PoolManager)
assert conn.connection_pool_kw["retries"] == 3
assert conn.connection_pool_kw["block"] is True
assert conn.connection_pool_kw["retries"] == retries
assert conn.connection_pool_kw["timeout"] == timeout