Skip to content

Commit 3d209cb

Browse files
committed
[py] get auth header from client config
Signed-off-by: Viet Nguyen Duc <[email protected]>
1 parent 5600cc7 commit 3d209cb

File tree

3 files changed

+88
-30
lines changed

3 files changed

+88
-30
lines changed

py/selenium/webdriver/remote/client_config.py

+73-30
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17+
import base64
1718
import os
1819
from urllib import parse
1920

@@ -26,11 +27,19 @@ def __init__(
2627
self,
2728
remote_server_addr: str,
2829
keep_alive: bool = True,
29-
proxy=None,
30+
proxy: Proxy = Proxy(raw={"proxyType": ProxyType.SYSTEM}),
31+
username: str = None,
32+
password: str = None,
33+
auth_type: str = "Basic",
34+
token: str = None,
3035
) -> None:
3136
self.remote_server_addr = remote_server_addr
3237
self.keep_alive = keep_alive
3338
self.proxy = proxy
39+
self.username = username
40+
self.password = password
41+
self.auth_type = auth_type
42+
self.token = token
3443

3544
@property
3645
def remote_server_addr(self) -> str:
@@ -57,8 +66,6 @@ def keep_alive(self, value: bool) -> None:
5766
@property
5867
def proxy(self) -> Proxy:
5968
""":Returns: The proxy used for communicating to the driver/server."""
60-
61-
self._proxy = self._proxy or Proxy(raw={"proxyType": ProxyType.SYSTEM})
6269
return self._proxy
6370

6471
@proxy.setter
@@ -71,34 +78,70 @@ def proxy(self, proxy: Proxy) -> None:
7178
"""
7279
self._proxy = proxy
7380

74-
def get_proxy_url(self):
75-
if self.proxy.proxy_type == ProxyType.DIRECT:
81+
@property
82+
def username(self) -> str:
83+
return self._username
84+
85+
@username.setter
86+
def username(self, value: str) -> None:
87+
self._username = value
88+
89+
@property
90+
def password(self) -> str:
91+
return self._password
92+
93+
@password.setter
94+
def password(self, value: str) -> None:
95+
self._password = value
96+
97+
@property
98+
def auth_type(self) -> str:
99+
return self._auth_type
100+
101+
@auth_type.setter
102+
def auth_type(self, value: str) -> None:
103+
self._auth_type = value
104+
105+
@property
106+
def token(self) -> str:
107+
return self._token
108+
109+
@token.setter
110+
def token(self, value: str) -> None:
111+
self._token = value
112+
113+
def get_proxy_url(self) -> str:
114+
proxy_type = self.proxy.proxy_type
115+
remote_add = parse.urlparse(self.remote_server_addr)
116+
if proxy_type == ProxyType.DIRECT:
76117
return None
77-
elif self.proxy.proxy_type == ProxyType.SYSTEM:
118+
if proxy_type == ProxyType.SYSTEM:
78119
_no_proxy = os.environ.get("no_proxy", os.environ.get("NO_PROXY"))
79120
if _no_proxy:
80-
for npu in _no_proxy.split(","):
81-
npu = npu.strip()
82-
if npu == "*":
121+
for entry in map(str.strip, _no_proxy.split(",")):
122+
if entry == "*":
83123
return None
84-
n_url = parse.urlparse(npu)
85-
remote_add = parse.urlparse(self.remote_server_addr)
86-
if n_url.netloc:
87-
if remote_add.netloc == n_url.netloc:
88-
return None
89-
else:
90-
if n_url.path in remote_add.netloc:
91-
return None
92-
if self.remote_server_addr.startswith("https://"):
93-
return os.environ.get("https_proxy", os.environ.get("HTTPS_PROXY"))
94-
if self.remote_server_addr.startswith("http://"):
95-
return os.environ.get("http_proxy", os.environ.get("HTTP_PROXY"))
96-
elif self.proxy.proxy_type == ProxyType.MANUAL:
97-
if self.remote_server_addr.startswith("https://"):
98-
return self.proxy.sslProxy
99-
elif self.remote_server_addr.startswith("http://"):
100-
return self.proxy.http_proxy
101-
else:
102-
return None
103-
else:
104-
return None
124+
n_url = parse.urlparse(entry)
125+
if n_url.netloc and remote_add.netloc == n_url.netloc:
126+
return None
127+
if n_url.path in remote_add.netloc:
128+
return None
129+
return os.environ.get(
130+
"https_proxy" if self.remote_server_addr.startswith("https://") else "http_proxy",
131+
os.environ.get("HTTPS_PROXY" if self.remote_server_addr.startswith("https://") else "HTTP_PROXY"),
132+
)
133+
if proxy_type == ProxyType.MANUAL:
134+
return self.proxy.sslProxy if self.remote_server_addr.startswith("https://") else self.proxy.http_proxy
135+
return None
136+
137+
def get_auth_header(self):
138+
auth_type = self.auth_type.lower()
139+
if auth_type == "basic" and self.username and self.password:
140+
credentials = f"{self.username}:{self.password}"
141+
encoded_credentials = base64.b64encode(credentials.encode()).decode()
142+
return {"Authorization": f"Basic {encoded_credentials}"}
143+
elif auth_type == "bearer" and self.token:
144+
return {"Authorization": f"Bearer {self.token}"}
145+
elif auth_type == "oauth" and self.token:
146+
return {"Authorization": f"OAuth {self.token}"}
147+
return None

py/selenium/webdriver/remote/remote_connection.py

+5
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,11 @@ def _request(self, method, url, body=None):
323323
"""
324324
parsed_url = parse.urlparse(url)
325325
headers = self.get_remote_connection_headers(parsed_url, self._client_config.keep_alive)
326+
auth_header = self._client_config.get_auth_header()
327+
328+
if auth_header:
329+
headers.update(auth_header)
330+
326331
if body and method not in ("POST", "PUT"):
327332
body = None
328333

py/test/unit/selenium/webdriver/remote/remote_connection_tests.py

+10
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import urllib3
2222

2323
from selenium import __version__
24+
from selenium.webdriver.remote.remote_connection import ClientConfig
2425
from selenium.webdriver.remote.remote_connection import RemoteConnection
2526

2627

@@ -54,6 +55,15 @@ def test_get_proxy_url_http(mock_proxy_settings):
5455
assert proxy_url == proxy
5556

5657

58+
def test_get_auth_header_if_client_config_pass():
59+
custom_config = ClientConfig(
60+
remote_server_addr="http://remote", keep_alive=True, username="user", password="pass", auth_type="Basic"
61+
)
62+
remote_connection = RemoteConnection(custom_config.remote_server_addr, client_config=custom_config)
63+
headers = remote_connection._client_config.get_auth_header()
64+
assert headers.get("Authorization") == "Basic dXNlcjpwYXNz"
65+
66+
5767
def test_get_proxy_url_https(mock_proxy_settings):
5868
proxy = "http://https_proxy.com:8080"
5969
remote_connection = RemoteConnection("https://remote", keep_alive=False)

0 commit comments

Comments
 (0)