diff --git a/python/sqlcommenter-python/google/cloud/sqlcommenter/django/middleware.py b/python/sqlcommenter-python/google/cloud/sqlcommenter/django/middleware.py index cf271a48..f92fed6f 100644 --- a/python/sqlcommenter-python/google/cloud/sqlcommenter/django/middleware.py +++ b/python/sqlcommenter-python/google/cloud/sqlcommenter/django/middleware.py @@ -15,9 +15,11 @@ # limitations under the License. import logging +from contextlib import ExitStack import django -from django.db import connection +from django.db import connections + from django.db.backends.utils import CursorDebugWrapper from google.cloud.sqlcommenter import add_sql_comment from google.cloud.sqlcommenter.opencensus import get_opencensus_values @@ -36,7 +38,9 @@ def __init__(self, get_response): self.get_response = get_response def __call__(self, request): - with connection.execute_wrapper(QueryWrapper(request)): + with ExitStack() as stack: + for db_alias in connections: + stack.enter_context(connections[db_alias].execute_wrapper(QueryWrapper(request))) return self.get_response(request) @@ -88,7 +92,7 @@ def __call__(self, execute, sql, params, many, context): # * https://github.com/basecamp/marginalia/pull/80 # Add the query to the query log if debugging. - if context['cursor'].__class__ is CursorDebugWrapper: + if isinstance(context['cursor'], CursorDebugWrapper): context['connection'].queries_log.append(sql) return execute(sql, params, many, context) diff --git a/python/sqlcommenter-python/tests/django/settings.py b/python/sqlcommenter-python/tests/django/settings.py index 04eedf66..d59a47ed 100644 --- a/python/sqlcommenter-python/tests/django/settings.py +++ b/python/sqlcommenter-python/tests/django/settings.py @@ -18,6 +18,10 @@ 'default': { 'ENGINE': 'django.db.backends.sqlite3', }, + + 'other': { + 'ENGINE': 'django.db.backends.sqlite3', + }, } INSTALLED_APPS = ['tests.django'] diff --git a/python/sqlcommenter-python/tests/django/tests.py b/python/sqlcommenter-python/tests/django/tests.py index b7c388c6..5f0f4b73 100644 --- a/python/sqlcommenter-python/tests/django/tests.py +++ b/python/sqlcommenter-python/tests/django/tests.py @@ -15,7 +15,7 @@ # limitations under the License. import django -from django.db import connection +from django.db import connection, connections from django.http import HttpRequest from django.test import TestCase, override_settings, modify_settings from django.urls import resolve, reverse @@ -41,7 +41,7 @@ def __call__(self, request): # Query log only active if DEBUG=True. @override_settings(DEBUG=True) class Tests(TestCase): - + databases = '__all__' @staticmethod def get_request(path): request = HttpRequest() @@ -55,6 +55,13 @@ def get_query(self, path='/'): self.assertEqual(len(connection.queries), 2) return connection.queries[0] + def get_query_other_db(self, path='/', connection_name='default'): + SqlCommenter(views.home_other_db)(self.get_request(path)) + # Query with comment added by QueryWrapper and unaltered query added + # by Django's CursorDebugWrapper. + self.assertEqual(len(connections[connection_name].queries), 2) + return connections[connection_name].queries[0] + def assertRoute(self, route, query): # route available in Django 2.2 and later. if django.VERSION < (2, 2): @@ -69,6 +76,13 @@ def test_basic(self): self.assertIn("framework='django%%3A" + django.get_version(), query) self.assertRoute('', query) + def test_basic_multiple_db_support(self): + query = self.get_query_other_db(path='/other/', connection_name='other') + self.assertIn("/*controller='some-other-db-path'", query) + # Expecting url_quoted("framework='django:'") + self.assertIn("framework='django%%3A" + django.get_version(), query) + self.assertRoute('other/', query) + def test_basic_disabled(self): with self.settings( SQLCOMMENTER_WITH_CONTROLLER=False, SQLCOMMENTER_WITH_ROUTE=False, diff --git a/python/sqlcommenter-python/tests/django/urls.py b/python/sqlcommenter-python/tests/django/urls.py index af364642..0a641444 100644 --- a/python/sqlcommenter-python/tests/django/urls.py +++ b/python/sqlcommenter-python/tests/django/urls.py @@ -21,5 +21,6 @@ urlpatterns = [ path('', views.home, name='home'), path('path/', views.home, name='some-path'), + path('other/', views.home_other_db, name='some-other-db-path'), path('app-urls/', include('tests.django.app_urls')), ] diff --git a/python/sqlcommenter-python/tests/django/views.py b/python/sqlcommenter-python/tests/django/views.py index 03e2392c..badc9685 100644 --- a/python/sqlcommenter-python/tests/django/views.py +++ b/python/sqlcommenter-python/tests/django/views.py @@ -20,5 +20,10 @@ def home(request): - list(Author.objects.all()) + list(Author.objects.all().using('default')) + return HttpResponse() + + +def home_other_db(request): + list(Author.objects.all().using('other')) return HttpResponse()