1
- from pydantic import BaseModel
2
- from dataclasses import asdict
3
- from typing import Any , Dict , Optional , Set , TYPE_CHECKING
4
1
from abc import ABC , abstractmethod
5
- from time import time
2
+ from dataclasses import asdict
6
3
from logging import Logger
4
+ from time import time
5
+ from typing import TYPE_CHECKING , Any , Dict , Optional , Set
6
+
7
7
from jupyter_ai .config_manager import ConfigManager
8
+ from jupyterlab_chat .models import Message , NewMessage , User
8
9
from jupyterlab_chat .ychat import YChat
9
- from jupyterlab_chat .models import User , Message , NewMessage
10
+ from pydantic import BaseModel
11
+
10
12
from .persona_awareness import PersonaAwareness
11
13
12
14
# prevents a circular import
13
15
# `PersonaManager` types have to be surrounded in single quotes
14
16
if TYPE_CHECKING :
15
- from .persona_manager import PersonaManager
16
17
from collections .abc import AsyncIterator
17
18
19
+ from .persona_manager import PersonaManager
20
+
21
+
18
22
class PersonaDefaults (BaseModel ):
19
23
"""
20
24
Data structure that represents the default settings of a persona. Each persona
21
25
must define some basic default settings, like its name.
22
26
23
27
Each of these settings can be overwritten through the settings UI.
24
28
"""
29
+
25
30
################################################
26
31
# required fields
27
32
################################################
28
- name : str # e.g. "Jupyternaut"
29
- description : str # e.g. "..."
30
- avatar_path : str # e.g. /avatars/jupyternaut.svg
31
- system_prompt : str # e.g. "You are a language model named..."
33
+ name : str # e.g. "Jupyternaut"
34
+ description : str # e.g. "..."
35
+ avatar_path : str # e.g. /avatars/jupyternaut.svg
36
+ system_prompt : str # e.g. "You are a language model named..."
32
37
33
38
################################################
34
39
# optional fields
35
40
################################################
36
- slash_commands : Set [str ] = set ("*" ) # change this to enable/disable slash commands
37
- model_uid : Optional [str ] = None # e.g. "ollama:deepseek-coder-v2"
41
+ slash_commands : Set [str ] = set ("*" ) # change this to enable/disable slash commands
42
+ model_uid : Optional [str ] = None # e.g. "ollama:deepseek-coder-v2"
38
43
# ^^^ set this to automatically default to a model after a fresh start, no config file
39
44
40
45
@@ -44,27 +49,32 @@ class BasePersona(ABC):
44
49
"""
45
50
46
51
ychat : YChat
47
- manager : ' PersonaManager'
52
+ manager : " PersonaManager"
48
53
config : ConfigManager
49
54
log : Logger
50
55
awareness : PersonaAwareness
51
56
52
57
################################################
53
58
# constructor
54
59
################################################
55
- def __init__ (self , * , ychat : YChat , manager : 'PersonaManager' , config : ConfigManager , log : Logger ):
60
+ def __init__ (
61
+ self ,
62
+ * ,
63
+ ychat : YChat ,
64
+ manager : "PersonaManager" ,
65
+ config : ConfigManager ,
66
+ log : Logger ,
67
+ ):
56
68
self .ychat = ychat
57
69
self .manager = manager
58
70
self .config = config
59
71
self .log = log
60
72
self .awareness = PersonaAwareness (
61
- ychat = self .ychat ,
62
- log = self .log ,
63
- user = self .as_user ()
73
+ ychat = self .ychat , log = self .log , user = self .as_user ()
64
74
)
65
75
66
76
self .ychat .set_user (self .as_user ())
67
-
77
+
68
78
################################################
69
79
# abstract methods, required by subclasses.
70
80
################################################
@@ -86,7 +96,7 @@ async def process_message(self, message: Message) -> None:
86
96
# support streaming
87
97
# handle multiple processed messages concurrently (if model service allows it)
88
98
pass
89
-
99
+
90
100
################################################
91
101
# base class methods, available to subclasses.
92
102
################################################
@@ -95,7 +105,7 @@ def id(self) -> str:
95
105
"""
96
106
Return a unique ID for this persona, which sets its username in the
97
107
`User` object shared with other collaborative extensions.
98
-
108
+
99
109
- This ID is guaranteed to be identical throughout this object's
100
110
lifecycle.
101
111
@@ -106,40 +116,40 @@ def id(self) -> str:
106
116
107
117
- For example, 'Jupyternaut' always has the ID
108
118
`jupyter-ai-personas::jupyter-ai::JupyternautPersona`.
109
-
119
+
110
120
- The ID must be unique, so if a package provides multiple personas,
111
121
their class names must be unique. Renaming the persona class changes the
112
122
ID of that persona, so you should also avoid renaming it if possible.
113
123
"""
114
- package_name = self .__module__ .split ('.' )[0 ]
124
+ package_name = self .__module__ .split ("." )[0 ]
115
125
class_name = self .__class__ .__name__
116
- return f' jupyter-ai-personas::{ package_name } ::{ class_name } '
126
+ return f" jupyter-ai-personas::{ package_name } ::{ class_name } "
117
127
118
128
@property
119
129
def name (self ) -> str :
120
130
return self .defaults .name
121
-
131
+
122
132
@property
123
133
def avatar_path (self ) -> str :
124
134
return self .defaults .avatar_path
125
135
126
136
@property
127
137
def system_prompt (self ) -> str :
128
138
return self .defaults .system_prompt
129
-
139
+
130
140
def as_user (self ) -> User :
131
141
return User (
132
142
username = self .id ,
133
143
name = self .name ,
134
144
display_name = self .name ,
135
145
avatar_url = self .avatar_path ,
136
146
)
137
-
147
+
138
148
def as_user_dict (self ) -> Dict [str , Any ]:
139
149
user = self .as_user ()
140
150
return asdict (user )
141
-
142
- async def forward_reply_stream (self , reply_stream : ' AsyncIterator' ):
151
+
152
+ async def forward_reply_stream (self , reply_stream : " AsyncIterator" ):
143
153
"""
144
154
Forwards an async iterator, dubbed the 'reply stream', to a new message
145
155
by this persona in the YChat.
@@ -152,13 +162,13 @@ async def forward_reply_stream(self, reply_stream: 'AsyncIterator'):
152
162
stream_id : Optional [str ] = None
153
163
154
164
try :
155
- self .awareness .set_local_state_field (' isWriting' , True )
165
+ self .awareness .set_local_state_field (" isWriting" , True )
156
166
async for chunk in reply_stream :
157
167
if not stream_id :
158
168
stream_id = self .ychat .add_message (
159
169
NewMessage (body = "" , sender = self .id )
160
170
)
161
-
171
+
162
172
assert stream_id
163
173
self .ychat .update_message (
164
174
Message (
@@ -171,8 +181,9 @@ async def forward_reply_stream(self, reply_stream: 'AsyncIterator'):
171
181
append = True ,
172
182
)
173
183
except Exception as e :
174
- self .log .error (f"Persona '{ self .name } ' encountered an exception printed below when attempting to stream output." )
184
+ self .log .error (
185
+ f"Persona '{ self .name } ' encountered an exception printed below when attempting to stream output."
186
+ )
175
187
self .log .exception (e )
176
188
finally :
177
- self .awareness .set_local_state_field ('isWriting' , False )
178
-
189
+ self .awareness .set_local_state_field ("isWriting" , False )
0 commit comments