From c35fb15f650cea7d09e6433c6fc2139e4bda011c Mon Sep 17 00:00:00 2001 From: Harvey Hartwell Date: Thu, 28 Dec 2023 10:23:38 -0800 Subject: [PATCH] subscriptions --- ckc/serializers.py | 36 +++++++++++++++++- ckc/views.py | 18 +++++++-- testproject/urls.py | 4 +- tests/integration/test_payment_processing.py | 36 +++++++++++++++++- tests/integration/utils.py | 40 ++++++++++++++++++++ 5 files changed, 126 insertions(+), 8 deletions(-) create mode 100644 tests/integration/utils.py diff --git a/ckc/serializers.py b/ckc/serializers.py index b88510a..3e61f36 100644 --- a/ckc/serializers.py +++ b/ckc/serializers.py @@ -1,5 +1,5 @@ import stripe -from djstripe.models import PaymentMethod, Customer +from djstripe.models import PaymentMethod, Customer, Price, Product from rest_framework import serializers @@ -74,3 +74,37 @@ def create(self, validated_data): raise serializers.ValidationError(e) return payment_method + + +class ProductSerializer(serializers.ModelSerializer): + class Meta: + model = Product + fields = ( + 'id', + 'name', + 'description', + 'type', + ) + read_only_fields = ( + 'id', + 'name', + 'description', + 'type', + ) + + +class PriceSerializer(serializers.ModelSerializer): + class Meta: + model = Price + fields = ( + 'id', + 'unit_amount', + 'currency', + 'recurring', + ) + read_only_fields = ( + 'id', + 'unit_amount', + 'currency', + 'recurring', + ) diff --git a/ckc/views.py b/ckc/views.py index a348c6b..d044308 100644 --- a/ckc/views.py +++ b/ckc/views.py @@ -1,8 +1,8 @@ -from djstripe.models import PaymentMethod -from rest_framework import viewsets -from rest_framework.permissions import IsAuthenticated +from djstripe.models import PaymentMethod, Price, Plan +from rest_framework import viewsets, mixins +from rest_framework.permissions import IsAuthenticated, AllowAny -from ckc.serializers import PaymentMethodSerializer +from ckc.serializers import PaymentMethodSerializer, PriceSerializer class PaymentMethodViewSet(viewsets.ModelViewSet): @@ -13,3 +13,13 @@ class PaymentMethodViewSet(viewsets.ModelViewSet): def get_queryset(self): qs = PaymentMethod.objects.filter(customer__subscriber=self.request.user) return qs + + +class PriceViewSet(viewsets.GenericViewSet, mixins.RetrieveModelMixin, mixins.ListModelMixin): + queryset = Price.objects.all() + serializer_class = PriceSerializer + permission_classes = [AllowAny] + + def get_queryset(self): + qs = Price.objects.all() + return qs diff --git a/testproject/urls.py b/testproject/urls.py index 644739e..d324a0a 100644 --- a/testproject/urls.py +++ b/testproject/urls.py @@ -1,7 +1,7 @@ from django.urls import path from rest_framework import routers -from ckc.views import PaymentMethodViewSet +from ckc.views import PaymentMethodViewSet, PriceViewSet from testapp.views import TestExceptionsViewSet from testapp.viewsets import TestModelWithACreatorViewSet, TestModelWithADifferentNamedCreatorViewSet, BModelViewSet @@ -11,6 +11,8 @@ router.register(r'creators-alternative', TestModelWithADifferentNamedCreatorViewSet) router.register(r'bmodel', BModelViewSet) router.register(r'payment-methods', PaymentMethodViewSet, basename='payment-methods') +# router.register(r'subscription-plans', SubscriptionPlanViewSet, basename='subscription-plans') +router.register(r'prices', PriceViewSet, basename='prices') urlpatterns = router.urls + [ path('test-exceptions/', TestExceptionsViewSet.as_view(), name='test-exceptions'), diff --git a/tests/integration/test_payment_processing.py b/tests/integration/test_payment_processing.py index 549cb6e..3fcabd8 100644 --- a/tests/integration/test_payment_processing.py +++ b/tests/integration/test_payment_processing.py @@ -1,17 +1,20 @@ import json +import stripe from django.urls import reverse -from djstripe.models import PaymentMethod, Customer +from djstripe.models import PaymentMethod, Customer, Price, Product +# from djstripe.core import Price from rest_framework.test import APITestCase from django.contrib.auth import get_user_model from ckc.utils.payments import create_checkout_session, create_payment_intent, confirm_payment_intent +from tests.integration.utils import create_subscription_plan User = get_user_model() -class TestExceptions(APITestCase): +class TestPaymentProcessing(APITestCase): def setUp(self): self.user = User.objects.create_user(username="test", password="test") self.client.force_authenticate(user=self.user) @@ -77,4 +80,33 @@ def test_payment_intents(self): assert intent is not None assert intent.status == "succeeded" + def test_subscriptions(self): + # create the subscription plan through dj stripe price object + price = create_subscription_plan(2000, "month", product_name="Sample Product Name: 0", currency="usd") + assert price is not None + assert price.id is not None + customer, created = Customer.get_or_create(subscriber=self.user) + customer.add_payment_method("pm_card_visa") + # subscribe the customer to the plan + subscription = customer.subscribe(price=price.id) + + stripe_sub = stripe.Subscription.retrieve(subscription.id) + assert stripe_sub is not None + assert stripe_sub.status == "active" + assert stripe_sub.customer == customer.id + + # cancel the subscription + subscription.cancel() + stripe_sub = stripe.Subscription.retrieve(subscription.id) + assert stripe_sub is not None + assert stripe_sub.status == "canceled" + + def test_subscription_plan_list(self): + for i in range(3): + create_subscription_plan(2000 + i, "month", product_name=f"Sample Product Name: {i}", currency="usd") + + url = reverse('prices-list') + resp = self.client.get(url) + assert resp.status_code == 200 + assert len(resp.data) == 3 diff --git a/tests/integration/utils.py b/tests/integration/utils.py new file mode 100644 index 0000000..a5b9f45 --- /dev/null +++ b/tests/integration/utils.py @@ -0,0 +1,40 @@ +import stripe +from djstripe.models import Product, Price, Plan + + +def create_subscription_plan(amount, interval, interval_count=1, currency="usd", product_name="Sample Product Name"): + # product, created = Product.get_or_create( + # name=product_name, + # description="Sample Description", + # type="service", + # ) + stripe_product = stripe.Product.create( + name=product_name, + description="Sample Description", + ) + product = Product.sync_from_stripe_data(stripe_product) + + price = Price.create( + unit_amount=amount, + currency=currency, + recurring={ + "interval": interval, + "interval_count": interval_count, + }, + product=product, + active=True, + ) + from pprint import pprint + pprint(price) + + # print(price) + # print(created) + # plan, created = Plan.objects.get_or_create( + # active=True, + # amount=amount, + # interval=interval, + # interval_count=interval_count, + # product=product, + # currency=currency, + # ) + return price