unify chat.py file
leehuwuj committed Feb 25, 2025
1 parent 087a45e commit d38eb3c
Showing 4 changed files with 11 additions and 104 deletions.
57 changes: 0 additions & 57 deletions templates/components/multiagent/python/app/api/routers/chat.py

This file was deleted.

54 changes: 9 additions & 45 deletions templates/types/streaming/fastapi/app/api/routers/chat.py
@@ -1,28 +1,22 @@
-import json
 import logging
 
 from fastapi import APIRouter, BackgroundTasks, HTTPException, Request, status
-from llama_index.core.agent.workflow import AgentOutput
-from llama_index.core.llms import MessageRole
-
 from app.api.callbacks.llamacloud import LlamaCloudFileDownload
 from app.api.callbacks.next_question import SuggestNextQuestions
+from app.api.callbacks.source_nodes import AddNodeUrl
 from app.api.callbacks.stream_handler import StreamHandler
-from app.api.callbacks.source_nodes import AddNodeUrl
 from app.api.routers.models import (
     ChatData,
-    Message,
-    Result,
 )
-from app.engine.engine import get_engine
 from app.engine.query_filter import generate_filters
+from app.workflows import create_workflow
 
 chat_router = r = APIRouter()
 
 logger = logging.getLogger("uvicorn")
 
 
 # streaming endpoint - delete if not needed
 @r.post("")
 async def chat(
     request: Request,
@@ -31,16 +25,18 @@ async def chat(
 ):
     try:
         last_message_content = data.get_last_message_content()
-        messages = data.get_history_messages()
+        messages = data.get_history_messages(include_agent_messages=True)
 
         doc_ids = data.get_chat_document_ids()
         filters = generate_filters(doc_ids)
         params = data.data or {}
-        logger.info(
-            f"Creating chat engine with filters: {str(filters)}",
+
+        workflow = create_workflow(
+            params=params,
+            filters=filters,
         )
-        engine = get_engine(filters=filters, params=params)
-        handler = engine.run(
+
+        handler = workflow.run(
             user_msg=last_message_content,
             chat_history=messages,
             stream=True,
@@ -59,35 +55,3 @@
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
             detail=f"Error in chat engine: {e}",
         ) from e
-
-
-# non-streaming endpoint - delete if not needed
-@r.post("/request")
-async def chat_request(
-    data: ChatData,
-) -> Result:
-    last_message_content = data.get_last_message_content()
-    messages = data.get_history_messages()
-
-    doc_ids = data.get_chat_document_ids()
-    filters = generate_filters(doc_ids)
-    params = data.data or {}
-    logger.info(
-        f"Creating chat engine with filters: {str(filters)}",
-    )
-    engine = get_engine(filters=filters, params=params)
-
-    response = await engine.run(
-        user_msg=last_message_content,
-        chat_history=messages,
-        stream=False,
-    )
-    output = response
-    if isinstance(output, AgentOutput):
-        content = output.response.content
-    else:
-        content = json.dumps(output)
-
-    return Result(
-        result=Message(role=MessageRole.ASSISTANT, content=content),
-    )
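
With the non-streaming /request endpoint gone, the handler returned by workflow.run() is the template's single chat entry point. As a rough sketch of driving it directly, assuming llama-index's WorkflowHandler.stream_events() iterator and its AgentStream token-delta event (the create_workflow arguments mirror the diff above; none of this code is part of the commit):

import asyncio

from llama_index.core.agent.workflow import AgentStream

from app.workflows import create_workflow


async def main() -> None:
    # Same wiring as the endpoint: params and filters go to the factory.
    workflow = create_workflow(params={}, filters=None)
    handler = workflow.run(user_msg="Hello!", chat_history=[])

    # Print token deltas as the agent streams them.
    async for event in handler.stream_events():
        if isinstance(event, AgentStream):
            print(event.delta, end="", flush=True)

    await handler  # block until the final result resolves


asyncio.run(main())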

[new file: path not captured in this view] 1 addition & 0 deletions

@@ -0,0 +1 @@
+from .agent import create_workflow
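
This one-line re-export lets the router import the factory from the package root without depending on the module that implements it. Illustrative usage, assuming the new file is the package's __init__.py:

from app.workflows import create_workflow  # resolves to app.workflows.agent.create_workflow

workflow = create_workflow(params={})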

[modified file: path not captured in this view] 1 addition & 2 deletions

@@ -2,7 +2,6 @@
 from typing import List
 
 from llama_index.core.agent.workflow import AgentWorkflow
-
 from llama_index.core.settings import Settings
 from llama_index.core.tools import BaseTool
 
@@ -11,7 +10,7 @@
 from app.engine.tools.query_engine import get_query_engine_tool
 
 
-def get_engine(params=None, **kwargs):
+def create_workflow(params=None, **kwargs):
     if params is None:
         params = {}
     system_prompt = os.getenv("SYSTEM_PROMPT")
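
The rest of the function is collapsed in this view. Purely as an illustrative reconstruction from the imports visible above (os, AgentWorkflow, Settings, BaseTool, get_query_engine_tool), and not the commit's actual body, the renamed factory plausibly assembles an AgentWorkflow along these lines:

import os
from typing import List

from llama_index.core.agent.workflow import AgentWorkflow
from llama_index.core.settings import Settings
from llama_index.core.tools import BaseTool

from app.engine.tools.query_engine import get_query_engine_tool


def create_workflow(params=None, **kwargs) -> AgentWorkflow:
    if params is None:
        params = {}
    system_prompt = os.getenv("SYSTEM_PROMPT")

    # Hypothetical wiring: how params and filters reach the query-engine
    # tool is below the fold in this diff.
    tools: List[BaseTool] = [get_query_engine_tool(**kwargs)]

    return AgentWorkflow.from_tools_or_functions(
        tools,
        llm=Settings.llm,
        system_prompt=system_prompt,
    )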