
Commit 6dc916b

Tidy up code and use a provider model with base
1 parent 16d05ff commit 6dc916b

6 files changed: +199, -85 lines


src/mockllm/providers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+# Empty init file to make this a package

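The package marker added above is left empty. If the providers were later meant to be importable straight from the package, a re-export along these lines could be added; a sketch only, not part of this commit:

# Hypothetical future providers/__init__.py that re-exports the classes added below.
from .anthropic import AnthropicProvider
from .base import LLMProvider
from .openai import OpenAIProvider

__all__ = ["AnthropicProvider", "LLMProvider", "OpenAIProvider"]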
src/mockllm/providers/anthropic.py

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+from typing import Any, AsyncGenerator, Dict, Union
+
+from fastapi import HTTPException
+from fastapi.responses import StreamingResponse
+
+from ..config import ResponseConfig
+from ..models import (
+    AnthropicChatRequest,
+    AnthropicChatResponse,
+    AnthropicStreamDelta,
+    AnthropicStreamResponse,
+)
+from ..utils import count_tokens
+from .base import LLMProvider
+
+
+class AnthropicProvider(LLMProvider):
+    def __init__(self, response_config: ResponseConfig):
+        self.response_config = response_config
+
+    async def generate_stream_response(
+        self, content: str, model: str
+    ) -> AsyncGenerator[str, None]:
+        async for chunk in self.response_config.get_streaming_response_with_lag(
+            content
+        ):
+            stream_response = AnthropicStreamResponse(
+                delta=AnthropicStreamDelta(delta={"text": chunk})
+            )
+            yield f"data: {stream_response.model_dump_json()}\n\n"
+
+        yield "data: [DONE]\n\n"
+
+    async def handle_chat_completion(
+        self, request: AnthropicChatRequest
+    ) -> Union[Dict[str, Any], StreamingResponse]:
+        last_message = next(
+            (msg for msg in reversed(request.messages) if msg.role == "user"), None
+        )
+
+        if not last_message:
+            raise HTTPException(
+                status_code=400, detail="No user message found in request"
+            )
+
+        if request.stream:
+            return StreamingResponse(
+                self.generate_stream_response(last_message.content, request.model),
+                media_type="text/event-stream",
+            )
+
+        response_content = await self.response_config.get_response_with_lag(
+            last_message.content
+        )
+
+        prompt_tokens = count_tokens(str(request.messages), request.model)
+        completion_tokens = count_tokens(response_content, request.model)
+        total_tokens = prompt_tokens + completion_tokens
+
+        return AnthropicChatResponse(
+            model=request.model,
+            content=[{"type": "text", "text": response_content}],
+            usage={
+                "input_tokens": prompt_tokens,
+                "output_tokens": completion_tokens,
+                "total_tokens": total_tokens,
+            },
+        ).model_dump()

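For reference, the new provider can be exercised without the HTTP layer. A minimal sketch, assuming the package is importable as mockllm, that ResponseConfig() works with its defaults (as it is constructed in server.py below), and that the default response configuration is available:

import asyncio

from mockllm.config import ResponseConfig
from mockllm.providers.anthropic import AnthropicProvider


async def demo() -> None:
    provider = AnthropicProvider(ResponseConfig())
    # Each yielded item is one SSE frame: JSON deltas, then the [DONE] sentinel.
    async for frame in provider.generate_stream_response(
        "hello", "claude-3-sonnet-20240229"
    ):
        print(frame, end="")


asyncio.run(demo())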
src/mockllm/providers/base.py

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+from abc import ABC, abstractmethod
+from typing import Any, AsyncGenerator, Dict, Union
+
+from fastapi.responses import StreamingResponse
+
+
+class LLMProvider(ABC):
+    @abstractmethod
+    async def handle_chat_completion(
+        self, request: Any
+    ) -> Union[Dict[str, Any], StreamingResponse]:
+        pass
+
+    @abstractmethod
+    async def generate_stream_response(
+        self, content: str, model: str
+    ) -> AsyncGenerator[str, None]:
+        """Generate streaming response"""
+        yield ""  # pragma: no cover

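This abstract base is the core of the new provider model: a new backend only needs to implement these two coroutines to plug in alongside the OpenAI and Anthropic providers. A minimal sketch of a hypothetical extra provider (not part of this commit) that simply echoes the last user message:

from typing import Any, AsyncGenerator, Dict, Union

from fastapi.responses import StreamingResponse

from .base import LLMProvider


class EchoProvider(LLMProvider):
    """Hypothetical provider that echoes the last user message back."""

    async def generate_stream_response(
        self, content: str, model: str
    ) -> AsyncGenerator[str, None]:
        # Emit the content one word at a time as SSE-style chunks.
        for word in content.split():
            yield f"data: {word}\n\n"
        yield "data: [DONE]\n\n"

    async def handle_chat_completion(
        self, request: Any
    ) -> Union[Dict[str, Any], StreamingResponse]:
        last_message = next(
            (msg for msg in reversed(request.messages) if msg.role == "user"), None
        )
        content = last_message.content if last_message else ""
        if getattr(request, "stream", False):
            return StreamingResponse(
                self.generate_stream_response(content, request.model),
                media_type="text/event-stream",
            )
        return {"model": request.model, "content": content}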
src/mockllm/providers/openai.py

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
+from typing import Any, AsyncGenerator, Dict, Union
+
+from fastapi import HTTPException
+from fastapi.responses import StreamingResponse
+
+from ..config import ResponseConfig
+from ..models import (
+    OpenAIChatRequest,
+    OpenAIChatResponse,
+    OpenAIDeltaMessage,
+    OpenAIStreamChoice,
+    OpenAIStreamResponse,
+)
+from ..utils import count_tokens
+from .base import LLMProvider
+
+
+class OpenAIProvider(LLMProvider):
+    def __init__(self, response_config: ResponseConfig):
+        self.response_config = response_config
+
+    async def generate_stream_response(
+        self, content: str, model: str
+    ) -> AsyncGenerator[str, None]:
+        first_chunk = OpenAIStreamResponse(
+            model=model,
+            choices=[OpenAIStreamChoice(delta=OpenAIDeltaMessage(role="assistant"))],
+        )
+        yield f"data: {first_chunk.model_dump_json()}\n\n"
+
+        async for chunk in self.response_config.get_streaming_response_with_lag(
+            content
+        ):
+            chunk_response = OpenAIStreamResponse(
+                model=model,
+                choices=[OpenAIStreamChoice(delta=OpenAIDeltaMessage(content=chunk))],
+            )
+            yield f"data: {chunk_response.model_dump_json()}\n\n"
+
+        final_chunk = OpenAIStreamResponse(
+            model=model,
+            choices=[
+                OpenAIStreamChoice(delta=OpenAIDeltaMessage(), finish_reason="stop")
+            ],
+        )
+        yield f"data: {final_chunk.model_dump_json()}\n\n"
+        yield "data: [DONE]\n\n"
+
+    async def handle_chat_completion(
+        self, request: OpenAIChatRequest
+    ) -> Union[Dict[str, Any], StreamingResponse]:
+        last_message = next(
+            (msg for msg in reversed(request.messages) if msg.role == "user"), None
+        )
+
+        if not last_message:
+            raise HTTPException(
+                status_code=400, detail="No user message found in request"
+            )
+
+        if request.stream:
+            return StreamingResponse(
+                self.generate_stream_response(last_message.content, request.model),
+                media_type="text/event-stream",
+            )
+
+        response_content = await self.response_config.get_response_with_lag(
+            last_message.content
+        )
+
+        prompt_tokens = count_tokens(str(request.messages), request.model)
+        completion_tokens = count_tokens(response_content, request.model)
+        total_tokens = prompt_tokens + completion_tokens
+
+        return OpenAIChatResponse(
+            model=request.model,
+            choices=[
+                {
+                    "index": 0,
+                    "message": {"role": "assistant", "content": response_content},
+                    "finish_reason": "stop",
+                }
+            ],
+            usage={
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": total_tokens,
+            },
+        ).model_dump()

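Both providers emit standard SSE framing (data: {...} chunks terminated by data: [DONE]), so the stream can be read with any SSE-aware client. A rough client-side sketch using httpx against a locally running server, assuming it listens on port 8000 and that the chunks serialize with the choices[0].delta layout built above:

import json

import httpx

payload = {
    "model": "mock-gpt",
    "stream": True,
    "messages": [{"role": "user", "content": "hello"}],
}

# Read the mock completion stream and print each content delta as it arrives.
with httpx.stream(
    "POST", "http://localhost:8000/v1/chat/completions", json=payload
) as response:
    for line in response.iter_lines():
        if not line.startswith("data: ") or line == "data: [DONE]":
            continue
        chunk = json.loads(line[len("data: "):])
        delta = chunk["choices"][0]["delta"]
        print(delta.get("content") or "", end="", flush=True)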
src/mockllm/server.py

Lines changed: 11 additions & 85 deletions
@@ -1,5 +1,5 @@
 import logging
-from typing import AsyncGenerator, Union
+from typing import Any, AsyncGenerator, Dict, Union
 
 import tiktoken
 from fastapi import FastAPI, HTTPException
@@ -9,15 +9,15 @@
 from .config import ResponseConfig
 from .models import (
     AnthropicChatRequest,
-    AnthropicChatResponse,
     AnthropicStreamDelta,
     AnthropicStreamResponse,
     OpenAIChatRequest,
-    OpenAIChatResponse,
     OpenAIDeltaMessage,
     OpenAIStreamChoice,
     OpenAIStreamResponse,
 )
+from .providers.anthropic import AnthropicProvider
+from .providers.openai import OpenAIProvider
 
 log_handler = logging.StreamHandler()
 log_handler.setFormatter(jsonlogger.JsonFormatter())
@@ -27,6 +27,8 @@
 app = FastAPI(title="Mock LLM Server")
 
 response_config = ResponseConfig()
+openai_provider = OpenAIProvider(response_config)
+anthropic_provider = AnthropicProvider(response_config)
 
 
 def count_tokens(text: str, model: str) -> int:
@@ -80,9 +82,8 @@ async def anthropic_stream_response(
 @app.post("/v1/chat/completions", response_model=None)
 async def openai_chat_completion(
     request: OpenAIChatRequest,
-) -> Union[OpenAIChatResponse, StreamingResponse]:
-    """Handle chat completion requests, supporting
-    both regular and streaming responses."""
+) -> Union[Dict[str, Any], StreamingResponse]:
+    """Handle OpenAI chat completion requests"""
     try:
         logger.info(
             "Received chat completion request",
@@ -92,47 +93,7 @@ async def openai_chat_completion(
                 "stream": request.stream,
             },
         )
-
-        last_message = next(
-            (msg for msg in reversed(request.messages) if msg.role == "user"), None
-        )
-
-        if not last_message:
-            raise HTTPException(
-                status_code=400, detail="No user message found in request"
-            )
-
-        if request.stream:
-            return StreamingResponse(
-                openai_stream_response(last_message.content, request.model),
-                media_type="text/event-stream",
-            )
-
-        response_content = await response_config.get_response_with_lag(
-            last_message.content
-        )
-
-        # Calculate mock token counts
-        prompt_tokens = count_tokens(str(request.messages), request.model)
-        completion_tokens = count_tokens(response_content, request.model)
-        total_tokens = prompt_tokens + completion_tokens
-
-        return OpenAIChatResponse(
-            model=request.model,
-            choices=[
-                {
-                    "index": 0,
-                    "message": {"role": "assistant", "content": response_content},
-                    "finish_reason": "stop",
-                }
-            ],
-            usage={
-                "prompt_tokens": prompt_tokens,
-                "completion_tokens": completion_tokens,
-                "total_tokens": total_tokens,
-            },
-        )
-
+        return await openai_provider.handle_chat_completion(request)
     except Exception as e:
         logger.error(f"Error processing request: {str(e)}")
         raise HTTPException(
@@ -143,9 +104,8 @@ async def openai_chat_completion(
 @app.post("/v1/messages", response_model=None)
 async def anthropic_chat_completion(
     request: AnthropicChatRequest,
-) -> Union[AnthropicChatResponse, StreamingResponse]:
-    """Handle Anthropic chat completion requests,
-    supporting both regular and streaming responses."""
+) -> Union[Dict[str, Any], StreamingResponse]:
+    """Handle Anthropic chat completion requests"""
     try:
         logger.info(
             "Received Anthropic chat completion request",
@@ -155,41 +115,7 @@ async def anthropic_chat_completion(
                 "stream": request.stream,
            },
         )
-
-        last_message = next(
-            (msg for msg in reversed(request.messages) if msg.role == "user"), None
-        )
-
-        if not last_message:
-            raise HTTPException(
-                status_code=400, detail="No user message found in request"
-            )
-
-        if request.stream:
-            return StreamingResponse(
-                anthropic_stream_response(last_message.content, request.model),
-                media_type="text/event-stream",
-            )
-
-        response_content = await response_config.get_response_with_lag(
-            last_message.content
-        )
-
-        # Calculate mock token counts
-        prompt_tokens = count_tokens(str(request.messages), request.model)
-        completion_tokens = count_tokens(response_content, request.model)
-        total_tokens = prompt_tokens + completion_tokens
-
-        return AnthropicChatResponse(
-            model=request.model,
-            content=[{"type": "text", "text": response_content}],
-            usage={
-                "input_tokens": prompt_tokens,
-                "output_tokens": completion_tokens,
-                "total_tokens": total_tokens,
-            },
-        )
-
+        return await anthropic_provider.handle_chat_completion(request)
     except Exception as e:
         logger.error(f"Error processing request: {str(e)}")
         raise HTTPException(

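With both endpoint bodies reduced to thin delegations, the routes can be covered by a compact test. A sketch using FastAPI's TestClient, assuming the package is importable as mockllm and that the default response configuration loads; the /v1/messages route can be tested the same way:

from fastapi.testclient import TestClient

from mockllm.server import app

client = TestClient(app)


def test_openai_completion_returns_usage() -> None:
    response = client.post(
        "/v1/chat/completions",
        json={
            "model": "mock-gpt",
            "messages": [{"role": "user", "content": "hello"}],
        },
    )
    assert response.status_code == 200
    body = response.json()
    # The provider fills in choices and token usage for non-streaming requests.
    assert body["choices"][0]["message"]["role"] == "assistant"
    assert "total_tokens" in body["usage"]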
src/mockllm/utils.py

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+import tiktoken
+
+
+def count_tokens(text: str, model: str) -> int:
+    """Get realistic token count for text using tiktoken"""
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+        return len(encoding.encode(text))
+    except Exception:
+        # Fallback to rough estimation if model not supported
+        return len(text.split())

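count_tokens degrades gracefully: model names known to tiktoken get a real BPE token count, anything else falls back to a whitespace word count. A quick illustration, assuming the package is importable as mockllm:

from mockllm.utils import count_tokens

# gpt-4 has a registered tiktoken encoding, so this returns a real BPE count.
print(count_tokens("Hello, world!", "gpt-4"))

# Model names tiktoken does not recognise fall back to counting words: 3 here.
print(count_tokens("counting words instead", "claude-3-sonnet-20240229"))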