Skip to content

Commit 9b2c382

Browse files
committed
feat: chat function: params and result schema
1 parent b6a58d3 commit 9b2c382

File tree

3 files changed

+119
-0
lines changed

3 files changed

+119
-0
lines changed

lf_toolkit/chat/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .params import ChatParams
2+
from .result import ChatResult

lf_toolkit/chat/params.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from typing import TypedDict
2+
from typing import List
3+
4+
class ChatParams(TypedDict):
5+
include_test_data: bool | None
6+
conversation_history: List[str] | None
7+
summary: str | None
8+
conversational_style: str | None
9+
question_response_details: str | None
10+
conversation_id: str | None

lf_toolkit/chat/result.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
from typing import Any
2+
from typing import Dict
3+
from typing import List
4+
from typing import Tuple
5+
from typing import Union
6+
7+
ResponseItem = Tuple[str, str]
8+
9+
def update_response(
10+
response: Dict[str, List[str]], response_items: List[ResponseItem]
11+
) -> Dict[str, List[str]]:
12+
for item in response_items:
13+
if (isinstance(item, tuple) or isinstance(item, list)) and len(item) == 2:
14+
response.setdefault(item[0], []).append(item[1])
15+
else:
16+
raise TypeError("Response item must be a tuple of (tag, chatbot_response).")
17+
18+
return response
19+
20+
21+
class ChatResult:
22+
__slots__ = ("_response",
23+
"_metadata",
24+
"_processing_time")
25+
__fields__ = (
26+
"response",
27+
"tags",
28+
"metadata",
29+
"processing_time",
30+
)
31+
32+
_response: Dict[str, List[str]]
33+
34+
_metadata: Dict[str, Any]
35+
_processing_time: float
36+
37+
def __init__(
38+
self,
39+
response_items: List[ResponseItem] = [],
40+
metadata: Dict[str, Any] = {},
41+
processing_time: float = 0,
42+
):
43+
self._response = update_response({}, response_items)
44+
self._metadata = metadata
45+
self._processing_time = processing_time
46+
47+
@property
48+
def response(self) -> str:
49+
return "<br>".join(
50+
[
51+
response_str
52+
for lists in self._response.values()
53+
for response_str in lists
54+
]
55+
)
56+
57+
@property
58+
def tags(self) -> Union[List[str], None]:
59+
return list(self._response.keys())
60+
61+
@property
62+
def metadata(self) -> Dict[str, Any]:
63+
return self._metadata
64+
65+
66+
def get_response(self, tag: str) -> List[str]:
67+
return self._response.get(tag, [])
68+
69+
def get_processing_time(self) -> float:
70+
return self._processing_time
71+
72+
def add_response(self, tag: str, response: str) -> None:
73+
self._response.setdefault(tag, []).append(response)
74+
75+
def add_metadata(self, name: str, data: Any) -> None:
76+
self._metadata[name] = data
77+
78+
def add_processing_time(self, time: float) -> None:
79+
self._processing_time = time
80+
81+
def to_dict(self, include_test_data: bool = False) -> Dict[str, Any]:
82+
res = {
83+
"chatbot_response": self.response,
84+
}
85+
86+
if include_test_data:
87+
res["tags"] = self.tags
88+
if len(self.metadata) > 0:
89+
res["metadata"] = self.metadata
90+
if self._processing_time >= 0:
91+
res["processing_time"] = self._processing_time
92+
93+
return res
94+
95+
def __repr__(self):
96+
members = ", ".join(f"{k}={repr(getattr(self, k))}" for k in self.__fields__)
97+
return f"Result({members})"
98+
99+
def __eq__(self, other):
100+
if type(self) is not type(other):
101+
return False
102+
103+
for k in self.__slots__:
104+
if getattr(self, k) != getattr(other, k):
105+
return False
106+
107+
return True

0 commit comments

Comments
 (0)