Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions app_users/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,23 +82,28 @@ def _update_user_balance_in_txn(txn: Transaction):
# avoid updating twice for same invoice
return

with transaction.atomic():
self.balance += amount
self.save()
obj = self.add_balance_direct(amount)

# create invoice entry
txn.create(
invoice_ref,
{
"amount": amount,
"end_balance": self.balance,
"end_balance": obj.balance,
"timestamp": datetime.datetime.utcnow(),
**invoice_items,
},
)

_update_user_balance_in_txn(db.get_client().transaction())

@transaction.atomic
def add_balance_direct(self, amount):
obj: AppUser = self.__class__.objects.select_for_update().get(pk=self.pk)
obj.balance += amount
obj.save(update_fields=["balance"])
return obj

def copy_from_firebase_user(self, user: auth.UserRecord) -> "AppUser":
# copy data from firebase user
self.uid = user.uid
Expand Down
46 changes: 46 additions & 0 deletions bots/migrations/0023_alter_savedrun_workflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Generated by Django 4.2.1 on 2023-07-14 11:38

from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
("bots", "0022_remove_botintegration_analysis_url_and_more"),
]

operations = [
migrations.AlterField(
model_name="savedrun",
name="workflow",
field=models.IntegerField(
choices=[
(1, "doc-search"),
(2, "doc-summary"),
(3, "google-gpt"),
(4, "video-bots"),
(5, "LipsyncTTS"),
(6, "TextToSpeech"),
(7, "asr"),
(8, "Lipsync"),
(9, "DeforumSD"),
(10, "CompareText2Img"),
(11, "text2audio"),
(12, "Img2Img"),
(13, "FaceInpainting"),
(14, "GoogleImageGen"),
(15, "compare-ai-upscalers"),
(16, "SEOSummary"),
(17, "EmailFaceInpainting"),
(18, "SocialLookupEmail"),
(19, "ObjectInpainting"),
(20, "ImageSegmentation"),
(21, "CompareLLM"),
(22, "ChyronPlant"),
(23, "LetterWriter"),
(24, "SmartGPT"),
(25, "QRCodeGenerator"),
],
default=4,
),
),
]
1 change: 1 addition & 0 deletions bots/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class Workflow(models.IntegerChoices):
CHYRONPLANT = (22, "ChyronPlant")
LETTERWRITER = (23, "LetterWriter")
SMARTGPT = (24, "SmartGPT")
QRCODE = (25, "QRCodeGenerator")

def get_app_url(self, example_id: str, run_id: str, uid: str):
"""return the url to the gooey app"""
Expand Down
156 changes: 66 additions & 90 deletions bots/tests.py
Original file line number Diff line number Diff line change
@@ -1,102 +1,78 @@
from django.test import TestCase
import random

from app_users.models import AppUser
from daras_ai_v2.functional import map_parallel
from .models import (
Message,
Conversation,
BotIntegration,
Platform,
Workflow,
ConvoState,
)
from django.db import transaction
from django.contrib import messages

CHATML_ROLE_USER = "user"
CHATML_ROLE_ASSISSTANT = "assistant"

# python manage.py test


class MessageModelTest(TestCase):

"""def test_create_and_save_message(self):

# Create a new conversation
conversation = Conversation.objects.create()

# Create and save a new message
message = Message(content="Hello, world!", conversation=conversation)
message.save()

# Retrieve all messages from the database
all_messages = Message.objects.all()
self.assertEqual(len(all_messages), 1)

# Check that the message's content is correct
only_message = all_messages[0]
self.assertEqual(only_message, message)

# Check the content
self.assertEqual(only_message.content, "Hello, world!")"""


class BotIntegrationTest(TestCase):
@classmethod
def setUpClass(cls):
super(BotIntegrationTest, cls).setUpClass()
cls.keepdb = True

@transaction.atomic
def test_create_bot_integration_conversation_message(self):
# Create a new BotIntegration with WhatsApp as the platform
bot_integration = BotIntegration.objects.create(
name="My Bot Integration",
saved_run=None,
billing_account_uid="fdnacsFSBQNKVW8z6tzhBLHKpAm1", # digital green's account id
user_language="en",
show_feedback_buttons=True,
platform=Platform.WHATSAPP,
wa_phone_number="my_whatsapp_number",
wa_phone_number_id="my_whatsapp_number_id",
)

# Create a Conversation that uses the BotIntegration
conversation = Conversation.objects.create(
bot_integration=bot_integration,
state=ConvoState.INITIAL,
wa_phone_number="user_whatsapp_number",
)

# Create a User Message within the Conversation
message_u = Message.objects.create(
conversation=conversation,
role=CHATML_ROLE_USER,
content="What types of chilies can be grown in Mumbai?",
display_content="What types of chilies can be grown in Mumbai?",
)

# Create a Bot Message within the Conversation
message_b = Message.objects.create(
conversation=conversation,
role=CHATML_ROLE_ASSISSTANT,
content="Red, green, and yellow grow the best.",
display_content="Red, green, and yellow grow the best.",
)

# Assert that the User Message was created successfully
self.assertEqual(Message.objects.count(), 2)
self.assertEqual(message_u.conversation, conversation)
self.assertEqual(message_u.role, CHATML_ROLE_USER)
self.assertEqual(
message_u.content, "What types of chilies can be grown in Mumbai?"
)
self.assertEqual(
message_u.display_content, "What types of chilies can be grown in Mumbai?"
)

# Assert that the Bot Message was created successfully
self.assertEqual(message_b.conversation, conversation)
self.assertEqual(message_b.role, CHATML_ROLE_ASSISSTANT)
self.assertEqual(message_b.content, "Red, green, and yellow grow the best.")
self.assertEqual(
message_b.display_content, "Red, green, and yellow grow the best."
)
def test_add_balance_direct():
pk = AppUser.objects.create(balance=0, is_anonymous=False).pk
amounts = [[random.randint(-100, 10_000) for _ in range(100)] for _ in range(5)]

def worker(amts):
user = AppUser.objects.get(pk=pk)
for amt in amts:
user.add_balance_direct(amt)

map_parallel(worker, amounts)

assert AppUser.objects.get(pk=pk).balance == sum(map(sum, amounts))


def test_create_bot_integration_conversation_message():
# Create a new BotIntegration with WhatsApp as the platform
bot_integration = BotIntegration.objects.create(
name="My Bot Integration",
saved_run=None,
billing_account_uid="fdnacsFSBQNKVW8z6tzhBLHKpAm1", # digital green's account id
user_language="en",
show_feedback_buttons=True,
platform=Platform.WHATSAPP,
wa_phone_number="my_whatsapp_number",
wa_phone_number_id="my_whatsapp_number_id",
)

# Create a Conversation that uses the BotIntegration
conversation = Conversation.objects.create(
bot_integration=bot_integration,
state=ConvoState.INITIAL,
wa_phone_number="user_whatsapp_number",
)

# Create a User Message within the Conversation
message_u = Message.objects.create(
conversation=conversation,
role=CHATML_ROLE_USER,
content="What types of chilies can be grown in Mumbai?",
display_content="What types of chilies can be grown in Mumbai?",
)

# Create a Bot Message within the Conversation
message_b = Message.objects.create(
conversation=conversation,
role=CHATML_ROLE_ASSISSTANT,
content="Red, green, and yellow grow the best.",
display_content="Red, green, and yellow grow the best.",
)

# Assert that the User Message was created successfully
assert Message.objects.count() == 2
assert message_u.conversation == conversation
assert message_u.role == CHATML_ROLE_USER
assert message_u.content == "What types of chilies can be grown in Mumbai?"
assert message_u.display_content == "What types of chilies can be grown in Mumbai?"

# Assert that the Bot Message was created successfully
assert message_b.conversation == conversation
assert message_b.role == CHATML_ROLE_ASSISSTANT
assert message_b.content == "Red, green, and yellow grow the best."
assert message_b.display_content == "Red, green, and yellow grow the best."
3 changes: 3 additions & 0 deletions celeryapp/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from celeryapp.celeryconfig import app
from daras_ai_v2.base import StateKeys, err_msg_for_exc, BasePage
from gooey_ui.pubsub import realtime_push
from gooey_ui.state import set_query_params


@app.task
Expand All @@ -21,6 +22,7 @@ def gui_runner(
uid: str,
state: dict,
channel: str,
query_params: dict = None,
):
self = page_cls(request=SimpleNamespace(user=AppUser.objects.get(id=user_id)))

Expand All @@ -29,6 +31,7 @@ def gui_runner(
yield_val = None
error_msg = None
url = self.app_url(run_id=run_id, uid=uid)
set_query_params(query_params or {})

def save(done=False):
if done:
Expand Down
1 change: 1 addition & 0 deletions daras_ai_v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,7 @@ def _render_output_col(self, submitted: bool):
uid=uid,
state=st.session_state,
channel=f"gooey-outputs/{self.doc_name}/{uid}/{run_id}",
query_params=gooey_get_query_params(),
)

raise QueryParamsRedirectException(
Expand Down
4 changes: 4 additions & 0 deletions daras_ai_v2/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
else:
SECRET_KEY = config("SECRET_KEY")

# https://hashids.org/
HASHIDS_SALT = config("HASHIDS_SALT", default="")

ALLOWED_HOSTS = ["*"]
INTERNAL_IPS = ["127.0.0.1"]
SECURE_PROXY_SSL_HEADER = ("HTTP_X_FORWARDED_PROTO", "https")
Expand All @@ -48,6 +51,7 @@
"django.forms", # needed to override admin forms
"django.contrib.admin",
"app_users",
"url_shortener",
]

MIDDLEWARE = [
Expand Down
2 changes: 1 addition & 1 deletion pytest.ini
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[pytest]
addopts = --tb=native -vv -n 8
addopts = --tb=native -vv -n 8 --disable-warnings
DJANGO_SETTINGS_MODULE = daras_ai_v2.settings
python_files = tests.py test_*.py *_tests.py
Loading