Merged
30 changes: 23 additions & 7 deletions monty/aiohttp_session.py
@@ -3,14 +3,15 @@
import socket
import sys
from datetime import timedelta
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, TypedDict
from unittest.mock import Mock

import aiohttp
from multidict import CIMultiDict, CIMultiDictProxy

from monty import constants
from monty.log import get_logger
from monty.utils import helpers
from monty.utils.caching import RedisCache


@@ -45,14 +46,29 @@ async def _on_request_end(
)


class SessionArgs(TypedDict):
proxy: str | None
connector: aiohttp.BaseConnector


def session_args_for_proxy(proxy: str | None) -> SessionArgs:
"""Create a dict with `proxy` and `connector` items, to be passed to aiohttp.ClientSession."""
connector = aiohttp.TCPConnector(
resolver=aiohttp.AsyncResolver(),
family=socket.AF_INET,
ssl=(
helpers._SSL_CONTEXT_UNVERIFIED
if (proxy and proxy.startswith("http://"))
else helpers._SSL_CONTEXT_VERIFIED
),
)
return {"proxy": proxy or None, "connector": connector}


class CachingClientSession(aiohttp.ClientSession):
def __init__(self, *args: Any, **kwargs: Any) -> None:
if "connector" not in kwargs:
kwargs["connector"] = aiohttp.TCPConnector(
resolver=aiohttp.AsyncResolver(),
family=socket.AF_INET,
verify_ssl=not bool(constants.Client.proxy and constants.Client.proxy.startswith("http://")),
)
kwargs.update(session_args_for_proxy(kwargs.get("proxy")))

if "trace_configs" not in kwargs:
trace_config = aiohttp.TraceConfig()
trace_config.on_request_end.append(_on_request_end)
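For illustration, a minimal sketch of using the new helper directly with a plain aiohttp.ClientSession. The proxy URL and target URL are hypothetical, and this assumes an aiohttp version whose ClientSession constructor accepts a proxy keyword, which the change above relies on:

import asyncio

import aiohttp

from monty.aiohttp_session import session_args_for_proxy


async def main() -> None:
    # An http:// proxy switches the connector to the unverified SSL context;
    # passing None keeps certificate verification and disables the proxy.
    args = session_args_for_proxy("http://localhost:8080")  # hypothetical proxy URL
    async with aiohttp.ClientSession(**args) as session:
        async with session.get("https://example.com") as resp:  # illustrative URL
            print(resp.status)


asyncio.run(main())
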
6 changes: 5 additions & 1 deletion monty/bot.py
@@ -14,7 +14,7 @@
from sqlalchemy.orm import selectinload

from monty import constants
from monty.aiohttp_session import CachingClientSession
from monty.aiohttp_session import CachingClientSession, session_args_for_proxy
from monty.database import Feature, Guild, GuildConfig
from monty.database.rollouts import Rollout
from monty.log import get_logger
@@ -60,6 +60,10 @@ def __init__(
if TEST_GUILDS:
kwargs["test_guilds"] = TEST_GUILDS
log.warning("registering as test_guilds")

# pass proxy and connector to disnake client
kwargs.update(session_args_for_proxy(proxy))

super().__init__(**kwargs)

self.redis_session = redis_session
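A rough sketch of what this wiring gives the client. It assumes (as the change above does) that disnake, like discord.py, accepts proxy and connector keyword arguments and forwards them to its internal HTTP machinery; the proxy URL is hypothetical:

import disnake

from monty.aiohttp_session import session_args_for_proxy

# The same dict that configures CachingClientSession can be splatted into the
# disnake client constructor, so gateway/API traffic and the caching HTTP
# session share one proxy and connector configuration.
client = disnake.Client(
    intents=disnake.Intents.default(),
    **session_args_for_proxy("http://localhost:8080"),  # hypothetical proxy URL
)
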
7 changes: 2 additions & 5 deletions monty/exts/info/docs/_batch_parser.py
@@ -11,7 +11,7 @@
from monty import constants
from monty.bot import Monty
from monty.log import get_logger
from monty.utils import helpers, scheduling
from monty.utils import scheduling
from monty.utils.html_parsing import get_symbol_markdown

from . import _cog, doc_cache
@@ -120,11 +120,8 @@ async def get_markdown(self, doc_item: "_cog.DocItem") -> str | None:
if doc_item not in self._item_futures and doc_item not in self._queue:
self._item_futures[doc_item].user_requested = True

# providing a context is workaround for cloudflare issues
try:
async with self._bot.http_session.get(
doc_item.url, raise_for_status=True, ssl=helpers.ssl_create_default_context()
) as response:
async with self._bot.http_session.get(doc_item.url, raise_for_status=True) as response:
soup = await self._bot.loop.run_in_executor(
None,
BeautifulSoup,
9 changes: 8 additions & 1 deletion monty/exts/info/github_info.py
@@ -23,6 +23,7 @@

import monty.utils.services
from monty import constants
from monty.aiohttp_session import session_args_for_proxy
from monty.bot import Monty
from monty.constants import Feature
from monty.errors import MontyCommandError
@@ -221,7 +222,13 @@ class GithubInfo(
def __init__(self, bot: Monty) -> None:
self.bot = bot

transport = AIOHTTPTransport(url="https://api.github.com/graphql", timeout=20, headers=GITHUB_REQUEST_HEADERS)
transport = AIOHTTPTransport(
url="https://api.github.com/graphql",
timeout=20,
headers=GITHUB_REQUEST_HEADERS,
# copy because invariance
client_session_args=dict(session_args_for_proxy(bot.http.proxy)),
)

self.gql_client = gql.Client(transport=transport, fetch_schema_from_transport=True)

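The dict(...) copy is presumably there because a TypedDict such as SessionArgs is not assignable where a plain dict[str, Any] is expected (mutable dicts are invariant in their value type), and gql's client_session_args expects an ordinary dict. A small illustration of the type-checker behaviour, using a hypothetical stand-in function:

from typing import Any

from monty.aiohttp_session import SessionArgs, session_args_for_proxy


def takes_session_args(client_session_args: dict[str, Any]) -> None:
    """Stand-in for a parameter typed like gql's client_session_args."""


args: SessionArgs = session_args_for_proxy(None)
takes_session_args(dict(args))  # OK: a plain dict copy matches dict[str, Any]
# takes_session_args(args)      # a type checker such as mypy would reject this
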
4 changes: 2 additions & 2 deletions monty/utils/converters.py
@@ -16,7 +16,7 @@
from monty.bot import Monty
from monty.database import Feature, Rollout
from monty.log import get_logger
from monty.utils import helpers, inventory_parser
from monty.utils import inventory_parser
from monty.utils.extensions import EXTENSIONS, unqualify
from monty.utils.features import NAME_REGEX as FEATURE_NAME_REGEX

@@ -204,7 +204,7 @@ class ValidURL(commands.Converter):
async def convert(ctx: commands.Context, url: str) -> str:
"""This converter checks whether the given URL can be reached with a status code of 200."""
try:
async with ctx.bot.http_session.get(url, ssl=helpers.ssl_create_default_context()) as resp:
async with ctx.bot.http_session.get(url) as resp:
if resp.status != 200:
msg = f"HTTP GET on `{url}` returned status `{resp.status}`, expected 200"
raise commands.BadArgument(msg)
20 changes: 15 additions & 5 deletions monty/utils/helpers.py
@@ -144,11 +144,21 @@ def fromisoformat(timestamp: str) -> datetime.datetime:
return dt


def ssl_create_default_context() -> ssl.SSLContext:
"""Return an ssl context that CloudFlare shouldn't flag."""
ssl_context = ssl.create_default_context()
ssl_context.post_handshake_auth = True
return ssl_context
def _create_ssl_context(*, verify: bool) -> ssl.SSLContext:
if verify:
ctx = ssl.create_default_context()
else:
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE
ctx.set_alpn_protocols(["http/1.1"])
# change tls fingerprint to avoid being flagged by cloudflare
ctx.post_handshake_auth = True
return ctx


_SSL_CONTEXT_VERIFIED = _create_ssl_context(verify=True)
_SSL_CONTEXT_UNVERIFIED = _create_ssl_context(verify=False)


@overload
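Building the two contexts once at import time lets every connector (and any remaining per-request override) reuse the same objects instead of constructing a fresh ssl.SSLContext per call. A minimal sketch of a per-request override with the shared verified context; the URL is illustrative and the module-level names are private to helpers:

import asyncio

import aiohttp

from monty.utils import helpers


async def fetch_status(url: str) -> int:
    async with aiohttp.ClientSession() as session:
        # Explicit ssl= override for a single request; most callers can now rely
        # on the connector built by session_args_for_proxy instead.
        async with session.get(url, ssl=helpers._SSL_CONTEXT_VERIFIED) as resp:
            return resp.status


print(asyncio.run(fetch_status("https://example.com")))  # illustrative URL
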
5 changes: 1 addition & 4 deletions monty/utils/inventory_parser.py
@@ -9,7 +9,6 @@
import aiohttp

from monty.log import get_logger
from monty.utils import helpers
from monty.utils.caching import redis_cache


@@ -93,9 +92,7 @@ async def _load_v2(stream: aiohttp.StreamReader) -> InventoryDict:
async def _fetch_inventory(bot: Monty, url: str) -> InventoryDict:
"""Fetch, parse and return an intersphinx inventory file from an url."""
timeout = aiohttp.ClientTimeout(sock_connect=5, sock_read=5)
async with bot.http_session.get(
url, timeout=timeout, raise_for_status=True, use_cache=False, ssl=helpers.ssl_create_default_context()
) as response:
async with bot.http_session.get(url, timeout=timeout, raise_for_status=True, use_cache=False) as response:
stream = response.content

inventory_header = (await stream.readline()).decode().rstrip()
Expand Down