Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 86 additions & 83 deletions posthog/test/test_consumer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import json
import time
import unittest
from typing import Any

import mock
from parameterized import parameterized

try:
from queue import Queue
Expand All @@ -14,15 +16,19 @@
from posthog.test.test_utils import TEST_API_KEY


def _track_event(event_name: str = "python event") -> dict[str, str]:
return {"type": "track", "event": event_name, "distinct_id": "distinct_id"}


class TestConsumer(unittest.TestCase):
def test_next(self):
def test_next(self) -> None:
    """A single queued message comes back from next() as a one-item batch."""
    queue = Queue()
    consumer = Consumer(queue, "")
    queue.put(1)
    batch = consumer.next()
    self.assertEqual(batch, [1])

def test_next_limit(self):
def test_next_limit(self) -> None:
q = Queue()
flush_at = 50
consumer = Consumer(q, "", flush_at)
Expand All @@ -31,7 +37,7 @@ def test_next_limit(self):
next = consumer.next()
self.assertEqual(next, list(range(flush_at)))

def test_dropping_oversize_msg(self):
def test_dropping_oversize_msg(self) -> None:
q = Queue()
consumer = Consumer(q, "")
oversize_msg = {"m": "x" * MAX_MSG_SIZE}
Expand All @@ -40,15 +46,14 @@ def test_dropping_oversize_msg(self):
self.assertEqual(next, [])
self.assertTrue(q.empty())

def test_upload(self):
def test_upload(self) -> None:
q = Queue()
consumer = Consumer(q, TEST_API_KEY)
track = {"type": "track", "event": "python event", "distinct_id": "distinct_id"}
q.put(track)
q.put(_track_event())
success = consumer.upload()
self.assertTrue(success)

def test_flush_interval(self):
def test_flush_interval(self) -> None:
# Put _n_ items in the queue, pausing a little bit more than
# _flush_interval_ after each one.
# The consumer should upload _n_ times.
Expand All @@ -57,17 +62,12 @@ def test_flush_interval(self):
consumer = Consumer(q, TEST_API_KEY, flush_at=10, flush_interval=flush_interval)
with mock.patch("posthog.consumer.batch_post") as mock_post:
consumer.start()
for i in range(0, 3):
track = {
"type": "track",
"event": "python event %d" % i,
"distinct_id": "distinct_id",
}
q.put(track)
for i in range(3):
q.put(_track_event("python event %d" % i))
time.sleep(flush_interval * 1.1)
self.assertEqual(mock_post.call_count, 3)

def test_multiple_uploads_per_interval(self):
def test_multiple_uploads_per_interval(self) -> None:
# Put _flush_at*2_ items in the queue at once, then pause for
# _flush_interval_. The consumer should upload 2 times.
q = Queue()
Expand All @@ -78,88 +78,60 @@ def test_multiple_uploads_per_interval(self):
)
with mock.patch("posthog.consumer.batch_post") as mock_post:
consumer.start()
for i in range(0, flush_at * 2):
track = {
"type": "track",
"event": "python event %d" % i,
"distinct_id": "distinct_id",
}
q.put(track)
for i in range(flush_at * 2):
q.put(_track_event("python event %d" % i))
time.sleep(flush_interval * 1.1)
self.assertEqual(mock_post.call_count, 2)

def test_request(self):
def test_request(self) -> None:
consumer = Consumer(None, TEST_API_KEY)
track = {"type": "track", "event": "python event", "distinct_id": "distinct_id"}
consumer.request([track])
consumer.request([_track_event()])

def _test_request_retry(self, consumer, expected_exception, exception_count):
def mock_post(*args, **kwargs):
mock_post.call_count += 1
if mock_post.call_count <= exception_count:
raise expected_exception
def _run_retry_test(
self, exception: Exception, exception_count: int, retries: int = 10
) -> None:
call_count = [0]

mock_post.call_count = 0
def mock_post(*args: Any, **kwargs: Any) -> None:
call_count[0] += 1
if call_count[0] <= exception_count:
raise exception

consumer = Consumer(None, TEST_API_KEY, retries=retries)
with mock.patch(
"posthog.consumer.batch_post", mock.Mock(side_effect=mock_post)
):
track = {
"type": "track",
"event": "python event",
"distinct_id": "distinct_id",
}
# request() should succeed if the number of exceptions raised is
# less than the retries paramater.
if exception_count <= consumer.retries:
consumer.request([track])
if exception_count <= retries:
consumer.request([_track_event()])
else:
# if exceptions are raised more times than the retries
# parameter, we expect the exception to be returned to
# the caller.
try:
consumer.request([track])
except type(expected_exception) as exc:
self.assertEqual(exc, expected_exception)
else:
self.fail(
"request() should raise an exception if still failing after %d retries"
% consumer.retries
)

def test_request_retry(self):
# we should retry on general errors
consumer = Consumer(None, TEST_API_KEY)
self._test_request_retry(consumer, Exception("generic exception"), 2)

# we should retry on server errors
consumer = Consumer(None, TEST_API_KEY)
self._test_request_retry(consumer, APIError(500, "Internal Server Error"), 2)

# we should retry on HTTP 429 errors
consumer = Consumer(None, TEST_API_KEY)
self._test_request_retry(consumer, APIError(429, "Too Many Requests"), 2)

# we should NOT retry on other client errors
consumer = Consumer(None, TEST_API_KEY)
api_error = APIError(400, "Client Errors")
try:
self._test_request_retry(consumer, api_error, 1)
except APIError:
pass
else:
self.fail("request() should not retry on client errors")

# test for number of exceptions raise > retries value
consumer = Consumer(None, TEST_API_KEY, retries=3)
self._test_request_retry(consumer, APIError(500, "Internal Server Error"), 3)

def test_pause(self):
with self.assertRaises(type(exception)):
consumer.request([_track_event()])

@parameterized.expand(
    [
        ("general_errors", Exception("generic exception"), 2),
        ("server_errors", APIError(500, "Internal Server Error"), 2),
        ("rate_limit_errors", APIError(429, "Too Many Requests"), 2),
    ]
)
def test_request_retries_on_retriable_errors(
    self, _name: str, exception: Exception, exception_count: int
) -> None:
    """Transient failures (generic exceptions, 5xx, 429) are retried until they stop."""
    self._run_retry_test(exception, exception_count=exception_count)

def test_request_does_not_retry_client_errors(self) -> None:
    """A 4xx client error must surface immediately instead of being retried."""
    client_error = APIError(400, "Client Errors")
    self.assertRaises(APIError, self._run_retry_test, client_error, 1)

def test_request_fails_when_exceptions_exceed_retries(self) -> None:
    """Once failures outnumber the configured retry budget, the error propagates."""
    server_error = APIError(500, "Internal Server Error")
    self._run_retry_test(server_error, exception_count=4, retries=3)

def test_pause(self) -> None:
    """pause() flips the consumer's running flag off."""
    idle_consumer = Consumer(None, TEST_API_KEY)
    idle_consumer.pause()
    self.assertFalse(idle_consumer.running)

def test_max_batch_size(self):
def test_max_batch_size(self) -> None:
q = Queue()
consumer = Consumer(q, TEST_API_KEY, flush_at=100000, flush_interval=3)
properties = {}
Expand All @@ -175,7 +147,7 @@ def test_max_batch_size(self):
# Let's capture 8MB of data to trigger two batches
n_msgs = int(8_000_000 / msg_size)

def mock_post_fn(_, data, **kwargs):
def mock_post_fn(_: str, data: str, **kwargs: Any) -> mock.Mock:
res = mock.Mock()
res.status_code = 200
request_size = len(data.encode())
Expand All @@ -194,3 +166,34 @@ def mock_post_fn(_, data, **kwargs):
q.put(track)
q.join()
self.assertEqual(mock_post.call_count, 2)

@parameterized.expand(
    [
        ("on_error_succeeds", False),
        ("on_error_raises", True),
    ]
)
def test_upload_exception_calls_on_error_and_does_not_raise(
    self, _name: str, on_error_raises: bool
) -> None:
    """upload() reports a failed batch through on_error and returns False,
    without propagating — even when the on_error callback itself raises."""
    captured: list[tuple[Exception, list[dict[str, str]]]] = []

    def failing_on_error(e: Exception, batch: list[dict[str, str]]) -> None:
        # Record the callback invocation, then optionally misbehave.
        captured.append((e, batch))
        if on_error_raises:
            raise Exception("on_error failed")

    queue = Queue()
    consumer = Consumer(queue, TEST_API_KEY, on_error=failing_on_error)
    event = _track_event()
    queue.put(event)

    # Force every request() call to fail so upload() hits its error path.
    patched_request = mock.patch.object(
        consumer, "request", side_effect=Exception("request failed")
    )
    with patched_request:
        succeeded = consumer.upload()

    self.assertFalse(succeeded)
    self.assertEqual(len(captured), 1)
    error, failed_batch = captured[0]
    self.assertEqual(str(error), "request failed")
    self.assertEqual(failed_batch, [event])
Loading