Skip to content

Commit e2b6c54

Browse files
authored
Merge pull request #13 from fetchai/feat/migrate-to-pydantic-v2
Migrate to pydantic v2
2 parents 6bda180 + 8a05a93 commit e2b6c54

File tree

5 files changed

+148
-76
lines changed

5 files changed

+148
-76
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ share/python-wheels/
2626
*.egg
2727
MANIFEST
2828

29+
#VS Code extensions
30+
.history/*
31+
2932
# PyInstaller
3033
# Usually these files are written by a python script from a template
3134
# before PyInstaller builds the exe, so as to inject date/other infos into it.

ai_engine_sdk/api_models/parsing_utils.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,21 @@
33

44
def get_options_from_raw_api_response(raw_api_response: dict) -> list[dict[str, str]]:
55
return [
6-
{
7-
'key': str(o['key']), 'title': o['value']
8-
}
9-
for o in raw_api_response['agent_json']['options']
6+
{"key": str(o["key"]), "title": o["value"]}
7+
for o in raw_api_response["agent_json"]["options"]
108
]
119

1210

1311
def get_task_options_from_options(options: list[dict[str, str]]) -> list[TaskOption]:
1412
return [
15-
TaskOption.parse_obj({
16-
"key": option['key'],
17-
"title": option['title']
18-
})
13+
TaskOption.model_validate({"key": option["key"], "title": option["title"]})
1914
for option in options
2015
]
2116

2217

23-
def get_indexed_task_options_from_raw_api_response(raw_api_response: dict) -> dict[TaskOption]:
18+
def get_indexed_task_options_from_raw_api_response(
19+
raw_api_response: dict,
20+
) -> dict[TaskOption]:
2421
options_list = get_options_from_raw_api_response(raw_api_response=raw_api_response)
2522
task_options_list = get_task_options_from_options(options=options_list)
2623

ai_engine_sdk/client.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,12 @@ async def _submit_message(self, payload: ApiMessagePayload):
107107
api_key=self._api_key,
108108
method='POST',
109109
endpoint=f"/v1beta1/engine/chat/sessions/{self.session_id}/submit",
110-
payload={'payload': payload.dict()}
110+
payload={'payload': payload.model_dump()}
111111
)
112112

113113
async def start(self, objective: str, context: Optional[str] = None):
114114
await self._submit_message(
115-
payload=ApiStartMessage.parse_obj({
115+
payload=ApiStartMessage.model_validate({
116116
'session_id': self.session_id,
117117
'bucket_id': self.function_group,
118118
'message_id': str(uuid4()).lower(),
@@ -123,7 +123,7 @@ async def start(self, objective: str, context: Optional[str] = None):
123123

124124
async def submit_task_selection(self, selection: TaskSelectionMessage, options: list[TaskOption]):
125125
await self._submit_message(
126-
payload=ApiUserJsonMessage.parse_obj({
126+
payload=ApiUserJsonMessage.model_validate({
127127
'session_id': self.session_id,
128128
'message_id': str(uuid4()).lower(),
129129
'referral_id': selection.id,
@@ -136,7 +136,7 @@ async def submit_task_selection(self, selection: TaskSelectionMessage, options:
136136

137137
async def submit_response(self, query: AgentMessage, response: str):
138138
await self._submit_message(
139-
payload=ApiUserMessageMessage.parse_obj(
139+
payload=ApiUserMessageMessage.model_validate(
140140
{
141141
'session_id': self.session_id,
142142
'message_id': str(uuid4()).lower(),
@@ -148,7 +148,7 @@ async def submit_response(self, query: AgentMessage, response: str):
148148

149149
async def submit_confirmation(self, confirmation: ConfirmationMessage):
150150
await self._submit_message(
151-
payload=ApiUserMessageMessage.parse_obj({
151+
payload=ApiUserMessageMessage.model_validate({
152152
'session_id': self.session_id,
153153
'message_id': str(uuid4()).lower(),
154154
'referral_id': confirmation.id,
@@ -158,7 +158,7 @@ async def submit_confirmation(self, confirmation: ConfirmationMessage):
158158

159159
async def reject_confirmation(self, confirmation: ConfirmationMessage, reason: str):
160160
await self._submit_message(
161-
payload=ApiUserMessageMessage.parse_obj({
161+
payload=ApiUserMessageMessage.model_validate({
162162
'session_id': self.session_id,
163163
'message_id': str(uuid4()).lower(),
164164
'referral_id': confirmation.id,
@@ -188,7 +188,7 @@ async def get_messages(self) -> List[ApiBaseMessage]:
188188
if is_task_selection_message(message_type=agent_json_type):
189189
indexed_task_options: dict = get_indexed_task_options_from_raw_api_response(raw_api_response=message)
190190
newMessages.append(
191-
TaskSelectionMessage.parse_obj({
191+
TaskSelectionMessage.model_validate({
192192
'type': agent_json_type,
193193
'id': message['message_id'],
194194
'timestamp': message['timestamp'],
@@ -198,7 +198,7 @@ async def get_messages(self) -> List[ApiBaseMessage]:
198198
)
199199
elif is_api_context_json(message_type=agent_json_type, agent_json_text=agent_json['text']):
200200
newMessages.append(
201-
ConfirmationMessage.parse_obj({
201+
ConfirmationMessage.model_validate({
202202
'id': message['message_id'],
203203
'timestamp': message['timestamp'],
204204
'text': agent_json['text'],
@@ -208,7 +208,7 @@ async def get_messages(self) -> List[ApiBaseMessage]:
208208
)
209209
elif is_data_request_message(message_type=agent_json_type):
210210
newMessages.append(
211-
DataRequestMessage.parse_obj({
211+
DataRequestMessage.model_validate({
212212
"id": message['message_id'],
213213
"text": agent_json['text'],
214214
"type": agent_json_type,
@@ -220,7 +220,7 @@ async def get_messages(self) -> List[ApiBaseMessage]:
220220
print(f"UNKNOWN-JSON: {message}")
221221
elif is_api_agent_info_message(message):
222222
newMessages.append(
223-
AiEngineMessage.parse_obj({
223+
AiEngineMessage.model_validate({
224224
'id': message['message_id'],
225225
'type': 'ai-engine',
226226
'timestamp': message['timestamp'],
@@ -229,7 +229,7 @@ async def get_messages(self) -> List[ApiBaseMessage]:
229229
)
230230
elif is_api_agent_message_message(message):
231231
newMessages.append(
232-
AgentMessage.parse_obj({
232+
AgentMessage.model_validate({
233233
'id': message['message_id'],
234234
'type': 'agent',
235235
'timestamp': message['timestamp'],
@@ -239,7 +239,7 @@ async def get_messages(self) -> List[ApiBaseMessage]:
239239
elif is_api_stop_message(message):
240240
print(f"STOP: {message}")
241241
newMessages.append(
242-
StopMessage.parse_obj({
242+
StopMessage.model_validate({
243243
'id': message['message_id'],
244244
'timestamp': message['timestamp'],
245245
'type': 'stop',
@@ -289,7 +289,7 @@ async def get_public_function_groups(self) -> List[FunctionGroup]:
289289
)
290290
return list(
291291
map(
292-
lambda item: FunctionGroup.parse_obj(item),
292+
lambda item: FunctionGroup.model_validate(item),
293293
raw_response
294294
)
295295
)
@@ -303,7 +303,7 @@ async def get_private_function_groups(self) -> List[FunctionGroup]:
303303
)
304304
return list(
305305
map(
306-
lambda item: FunctionGroup.parse_obj(item),
306+
lambda item: FunctionGroup.model_validate(item),
307307
raw_response
308308
)
309309
)
@@ -393,7 +393,7 @@ async def create_session(self, function_group: str, opts: Optional[dict] = None)
393393
api_key=self._api_key,
394394
method='POST',
395395
endpoint="/v1beta1/engine/chat/sessions",
396-
payload=request_payload.dict()
396+
payload=request_payload.model_dump()
397397
)
398398

399399
return Session(self._api_base_url, self._api_key, response['session_id'], function_group)

0 commit comments

Comments
 (0)