Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feature] Upgrade to channels 4.x #332

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions openwisp_notifications/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import pytest
import asyncio

@pytest.fixture(scope='session')
def event_loop():
"""Create an instance of the default event loop for each test case."""
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
34 changes: 18 additions & 16 deletions openwisp_notifications/tests/test_websockets.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import sys
import uuid
import asyncio
from datetime import timedelta
from unittest.mock import patch

import pytest
from channels.db import database_sync_to_async
from channels.testing import WebsocketCommunicator
from channels.layers import get_channel_layer
from django.conf import settings
from django.contrib.auth import get_user_model
from django.contrib.contenttypes.models import ContentType
Expand Down Expand Up @@ -73,6 +75,15 @@ def create_object_notification(admin_user):
@pytest.mark.django_db(transaction=True)
class TestNotificationSockets:
application = import_string(getattr(settings, 'ASGI_APPLICATION'))
channel_layer = get_channel_layer()

@pytest.fixture(autouse=True)
async def setup_test(self):
# Clear channel layer before each test
await self.channel_layer.flush()
yield
# Cleanup after each test
await self.channel_layer.flush()

async def _get_communicator(self, admin_client):
session_id = admin_client.cookies['sessionid'].value
Expand All @@ -86,22 +97,13 @@ async def _get_communicator(self, admin_client):
)
],
)
connected, subprotocol = await communicator.connect()
assert connected is True
return communicator

async def test_new_notification_created(self, admin_user, admin_client):
communicator = await self._get_communicator(admin_client)
n = await create_notification(admin_user)
response = await communicator.receive_json_from()
expected_response = {
'type': 'notification',
'notification_count': 1,
'reload_widget': True,
'notification': NotificationListSerializer(n).data,
}
assert response == expected_response
await communicator.disconnect()
try:
connected, _ = await communicator.connect()
assert connected is True
return communicator
except Exception as e:
await communicator.disconnect()
raise e

async def test_read_notification(self, admin_user, admin_client):
n = await create_notification(admin_user)
Expand Down
56 changes: 30 additions & 26 deletions openwisp_notifications/websockets/consumers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json

from asgiref.sync import async_to_sync
from channels.generic.websocket import WebsocketConsumer
from channels.generic.websocket import AsyncWebsocketConsumer
from asgiref.sync import sync_to_async
from django.contrib.contenttypes.models import ContentType
from django.utils.timezone import now, timedelta

Expand All @@ -14,35 +14,36 @@
IgnoreObjectNotification = load_model('IgnoreObjectNotification')


class NotificationConsumer(WebsocketConsumer):
class NotificationConsumer(AsyncWebsocketConsumer):
_initial_backoff = app_settings.NOTIFICATION_STORM_PREVENTION['initial_backoff']
_backoff_increment = app_settings.NOTIFICATION_STORM_PREVENTION['backoff_increment']
_max_allowed_backoff = app_settings.NOTIFICATION_STORM_PREVENTION[
'max_allowed_backoff'
]

def _is_user_authenticated(self):
async def _is_user_authenticated(self):
try:
assert self.scope['user'].is_authenticated is True
except (KeyError, AssertionError):
self.close()
await self.close()
return False
else:
return True

def connect(self):
if self._is_user_authenticated():
async_to_sync(self.channel_layer.group_add)(
'ow-notification-{0}'.format(self.scope['user'].pk), self.channel_name
async def connect(self):
if await self._is_user_authenticated():
await self.channel_layer.group_add(
f'ow-notification-{self.scope["user"].pk}', self.channel_name
)
self.accept()
await self.accept()
self.scope['last_update_datetime'] = now()
self.scope['backoff'] = self._initial_backoff

def disconnect(self, close_code):
async_to_sync(self.channel_layer.group_discard)(
'ow-notification-{0}'.format(self.scope['user'].pk), self.channel_name
)
async def disconnect(self, close_code):
if hasattr(self, 'scope') and 'user' in self.scope:
await self.channel_layer.group_discard(
f'ow-notification-{self.scope["user"].pk}', self.channel_name
)

def process_event_for_notification_storm(self, event):
if not event['in_notification_storm']:
Expand Down Expand Up @@ -75,9 +76,9 @@ def process_event_for_notification_storm(self, event):
self.scope['backoff'] = self._initial_backoff
return event

def send_updates(self, event):
async def send_updates(self, event):
event = self.process_event_for_notification_storm(event)
self.send(
await self.send(
json.dumps(
{
'type': 'notification',
Expand All @@ -88,27 +89,28 @@ def send_updates(self, event):
)
)

def receive(self, text_data):
if self._is_user_authenticated():
async def receive(self, text_data):
if await self._is_user_authenticated():
try:
json_data = json.loads(text_data)
except json.JSONDecodeError:
return

try:
if json_data['type'] == 'notification':
self._notification_handler(
await self._notification_handler(
notification_id=json_data['notification_id']
)
elif json_data['type'] == 'object_notification':
self._object_notification_handler(
await self._object_notification_handler(
object_id=json_data['object_id'],
app_label=json_data['app_label'],
model_name=json_data['model_name'],
)
except KeyError:
return

@sync_to_async
def _notification_handler(self, notification_id):
try:
notification = Notification.objects.get(
Expand All @@ -118,18 +120,20 @@ def _notification_handler(self, notification_id):
except Notification.DoesNotExist:
return

def _object_notification_handler(self, object_id, app_label, model_name):
async def _object_notification_handler(self, object_id, app_label, model_name):
try:
object_notification = IgnoreObjectNotification.objects.get(
object_notification = await sync_to_async(IgnoreObjectNotification.objects.get)(
user=self.scope['user'],
object_id=object_id,
object_content_type_id=ContentType.objects.get_by_natural_key(
app_label=app_label,
model=model_name,
object_content_type_id=(
await sync_to_async(ContentType.objects.get_by_natural_key)(
app_label=app_label,
model=model_name,
)
).pk,
)
serialized_data = IgnoreObjectNotificationSerializer(object_notification)
self.send(
await self.send(
json.dumps(
{
'type': 'object_notification',
Expand Down
9 changes: 5 additions & 4 deletions openwisp_notifications/websockets/handlers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from asgiref.sync import async_to_sync
from channels import layers
from channels.layers import get_channel_layer
from django.core.cache import cache
from django.utils.timezone import now, timedelta

Expand Down Expand Up @@ -52,14 +52,15 @@ def user_in_notification_storm(user):
return in_notification_storm


def notification_update_handler(reload_widget=False, notification=None, recipient=None):
channel_layer = layers.get_channel_layer()
async def notification_update_handler(reload_widget=False, notification=None, recipient=None):
channel_layer = get_channel_layer()
try:
assert notification is not None
notification = NotificationListSerializer(notification).data
except (NotFound, AssertionError):
pass
async_to_sync(channel_layer.group_send)(

await channel_layer.group_send(
f'ow-notification-{recipient.pk}',
{
'type': 'send.updates',
Expand Down
18 changes: 8 additions & 10 deletions openwisp_notifications/websockets/routing.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
# chat/routing.py
from channels.routing import ProtocolTypeRouter, URLRouter
from django.urls import path
from . import consumers

from . import consumers as ow_consumers


def get_routes(consumer=None):
if not consumer:
consumer = ow_consumers
return [
path('ws/notification/', consumer.NotificationConsumer.as_asgi()),
]
def get_routes():
return ProtocolTypeRouter({
"websocket": URLRouter([
path("ws/notifications/", consumers.NotificationConsumer.as_asgi()),
]),
})
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
django-notifications-hq~=1.8.3
channels~=3.0.2
channels>=4.0.0,<5.0.0
openwisp-users @ https://github.com/openwisp/openwisp-users/tarball/1.2
openwisp-utils[rest,celery] @ https://github.com/openwisp/openwisp-utils/tarball/1.2
markdown~=3.7.0
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def get_install_requires():
packages=find_packages(exclude=['tests*', 'docs*']),
include_package_data=True,
zip_safe=False,
install_requires=get_install_requires(),
install_requires=get_install_requires() + ['channels>=4.0.0,<5.0.0'],
classifiers=[
'Development Status :: 5 - Production/Stable',
'Environment :: Web Environment',
Expand Down
21 changes: 18 additions & 3 deletions tests/openwisp2/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,37 @@

from channels.auth import AuthMiddlewareStack
from channels.routing import ProtocolTypeRouter, URLRouter
from channels.security.websocket import AllowedHostsOriginValidator
from django.core.asgi import get_asgi_application

from openwisp_notifications.websockets.routing import get_routes

os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'openwisp2.settings')
django_asgi_app = get_asgi_application()

if os.environ.get('SAMPLE_APP', False):
# Load custom routes:
# This should be done when you are extending the app and modifying
# the web socket consumer in your extended app.
from .sample_notifications import consumers

application = ProtocolTypeRouter(
{'websocket': AuthMiddlewareStack(URLRouter(get_routes(consumers)))}
{
'http': django_asgi_app,
'websocket': AllowedHostsOriginValidator(
AuthMiddlewareStack(URLRouter(get_routes(consumers)))
),
}
)
else:
# Load openwisp_notifications consumers:
# This can be used when you are extending the app but not making
# any changes in the web socket consumer.
application = ProtocolTypeRouter(
{'websocket': AuthMiddlewareStack(URLRouter(get_routes()))}
)
{
'http': django_asgi_app,
'websocket': AllowedHostsOriginValidator(
AuthMiddlewareStack(URLRouter(get_routes()))
),
}
)
Loading