Skip to content

Commit 33dc8b6

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent a64a8f8 commit 33dc8b6

File tree

11 files changed

+134
-97
lines changed

11 files changed

+134
-97
lines changed

packages/jupyter-ai-test/jupyter_ai_test/debug_persona.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from jupyter_ai.personas.base_persona import BasePersona, PersonaDefaults
22
from jupyterlab_chat.models import Message
33

4+
45
class DebugPersona(BasePersona):
56
"""
67
The Jupyternaut persona, the main persona provided by Jupyter AI.
@@ -15,9 +16,9 @@ def defaults(self):
1516
name="DebugPersona",
1617
avatar_path="/api/ai/static/jupyternaut.svg",
1718
description="A mock persona used for debugging in local dev environments.",
18-
system_prompt="..."
19+
system_prompt="...",
1920
)
20-
21+
2122
async def process_message(self, message: Message):
2223
self.log.info("HI IM DEBUGPERSONA AND IDK WHAT TO DO")
2324
return

packages/jupyter-ai/jupyter_ai/extension.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
from asyncio import get_event_loop_policy
21
import os
32
import re
43
import time
54
import types
5+
from asyncio import get_event_loop_policy
66
from functools import partial
7-
from typing import Dict, TYPE_CHECKING
7+
from typing import TYPE_CHECKING, Dict
88

99
import traitlets
1010
from dask.distributed import Client as DaskClient
@@ -253,13 +253,12 @@ def initialize(self):
253253
)
254254

255255
@property
256-
def event_loop(self) -> 'AbstractEventLoop':
256+
def event_loop(self) -> "AbstractEventLoop":
257257
"""
258258
Returns a reference to the asyncio event loop.
259259
"""
260260
return get_event_loop_policy().get_event_loop()
261261

262-
263262
async def connect_chat(
264263
self, logger: EventLogger, schema_id: str, data: dict
265264
) -> None:
@@ -326,7 +325,9 @@ def on_change(self, room_id: str, events: ArrayEvent) -> None:
326325
# message triggers 2 events, one with `raw_time` set to `True` and
327326
# another with `raw_time` set to `False` milliseconds later.
328327
# we should explore fixing this quirk in Jupyter Chat.
329-
new_messages = [Message(**m) for m in change['insert'] if not m.get('raw_time', False)]
328+
new_messages = [
329+
Message(**m) for m in change["insert"] if not m.get("raw_time", False)
330+
]
330331
for new_message in new_messages:
331332
persona_manager.route_message(new_message)
332333

@@ -419,7 +420,9 @@ def initialize_settings(self):
419420
# requires the event loop to be running on init. So instead we schedule
420421
# this as a task that is run as soon as the loop starts, and pass
421422
# consumers a Future that resolves to the Dask client when awaited.
422-
self.settings["dask_client_future"] = self.event_loop.create_task(self._get_dask_client())
423+
self.settings["dask_client_future"] = self.event_loop.create_task(
424+
self._get_dask_client()
425+
)
423426

424427
# Create empty context providers dict to be filled later.
425428
# This is created early to use as kwargs for chat handlers.

packages/jupyter-ai/jupyter_ai/history.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def _convert_to_langchain_messages(self, jchat_messages: List[JChatMessage]):
4545
"""
4646
messages: List[BaseMessage] = []
4747
for jchat_message in jchat_messages:
48-
if jchat_message.sender.startswith('jupyter-ai-personas::'):
48+
if jchat_message.sender.startswith("jupyter-ai-personas::"):
4949
messages.append(AIMessage(content=jchat_message.body))
5050
else:
5151
messages.append(HumanMessage(content=jchat_message.body))
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
from .base_persona import BasePersona, PersonaDefaults
2-
from .persona_manager import PersonaManager
2+
from .persona_manager import PersonaManager
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,45 @@
1-
from pydantic import BaseModel
2-
from dataclasses import asdict
3-
from typing import Any, Dict, Optional, Set, TYPE_CHECKING
41
from abc import ABC, abstractmethod
5-
from time import time
2+
from dataclasses import asdict
63
from logging import Logger
4+
from time import time
5+
from typing import TYPE_CHECKING, Any, Dict, Optional, Set
6+
77
from jupyter_ai.config_manager import ConfigManager
8+
from jupyterlab_chat.models import Message, NewMessage, User
89
from jupyterlab_chat.ychat import YChat
9-
from jupyterlab_chat.models import User, Message, NewMessage
10+
from pydantic import BaseModel
11+
1012
from .persona_awareness import PersonaAwareness
1113

1214
# prevents a circular import
1315
# `PersonaManager` types have to be surrounded in single quotes
1416
if TYPE_CHECKING:
15-
from .persona_manager import PersonaManager
1617
from collections.abc import AsyncIterator
1718

19+
from .persona_manager import PersonaManager
20+
21+
1822
class PersonaDefaults(BaseModel):
1923
"""
2024
Data structure that represents the default settings of a persona. Each persona
2125
must define some basic default settings, like its name.
2226
2327
Each of these settings can be overwritten through the settings UI.
2428
"""
29+
2530
################################################
2631
# required fields
2732
################################################
28-
name: str # e.g. "Jupyternaut"
29-
description: str # e.g. "..."
30-
avatar_path: str # e.g. /avatars/jupyternaut.svg
31-
system_prompt: str # e.g. "You are a language model named..."
33+
name: str # e.g. "Jupyternaut"
34+
description: str # e.g. "..."
35+
avatar_path: str # e.g. /avatars/jupyternaut.svg
36+
system_prompt: str # e.g. "You are a language model named..."
3237

3338
################################################
3439
# optional fields
3540
################################################
36-
slash_commands: Set[str] = set("*") # change this to enable/disable slash commands
37-
model_uid: Optional[str] = None # e.g. "ollama:deepseek-coder-v2"
41+
slash_commands: Set[str] = set("*") # change this to enable/disable slash commands
42+
model_uid: Optional[str] = None # e.g. "ollama:deepseek-coder-v2"
3843
# ^^^ set this to automatically default to a model after a fresh start, no config file
3944

4045

@@ -44,27 +49,32 @@ class BasePersona(ABC):
4449
"""
4550

4651
ychat: YChat
47-
manager: 'PersonaManager'
52+
manager: "PersonaManager"
4853
config: ConfigManager
4954
log: Logger
5055
awareness: PersonaAwareness
5156

5257
################################################
5358
# constructor
5459
################################################
55-
def __init__(self, *, ychat: YChat, manager: 'PersonaManager', config: ConfigManager, log: Logger):
60+
def __init__(
61+
self,
62+
*,
63+
ychat: YChat,
64+
manager: "PersonaManager",
65+
config: ConfigManager,
66+
log: Logger,
67+
):
5668
self.ychat = ychat
5769
self.manager = manager
5870
self.config = config
5971
self.log = log
6072
self.awareness = PersonaAwareness(
61-
ychat=self.ychat,
62-
log=self.log,
63-
user=self.as_user()
73+
ychat=self.ychat, log=self.log, user=self.as_user()
6474
)
6575

6676
self.ychat.set_user(self.as_user())
67-
77+
6878
################################################
6979
# abstract methods, required by subclasses.
7080
################################################
@@ -86,7 +96,7 @@ async def process_message(self, message: Message) -> None:
8696
# support streaming
8797
# handle multiple processed messages concurrently (if model service allows it)
8898
pass
89-
99+
90100
################################################
91101
# base class methods, available to subclasses.
92102
################################################
@@ -95,7 +105,7 @@ def id(self) -> str:
95105
"""
96106
Return a unique ID for this persona, which sets its username in the
97107
`User` object shared with other collaborative extensions.
98-
108+
99109
- This ID is guaranteed to be identical throughout this object's
100110
lifecycle.
101111
@@ -106,40 +116,40 @@ def id(self) -> str:
106116
107117
- For example, 'Jupyternaut' always has the ID
108118
`jupyter-ai-personas::jupyter-ai::JupyternautPersona`.
109-
119+
110120
- The ID must be unique, so if a package provides multiple personas,
111121
their class names must be unique. Renaming the persona class changes the
112122
ID of that persona, so you should also avoid renaming it if possible.
113123
"""
114-
package_name = self.__module__.split('.')[0]
124+
package_name = self.__module__.split(".")[0]
115125
class_name = self.__class__.__name__
116-
return f'jupyter-ai-personas::{package_name}::{class_name}'
126+
return f"jupyter-ai-personas::{package_name}::{class_name}"
117127

118128
@property
119129
def name(self) -> str:
120130
return self.defaults.name
121-
131+
122132
@property
123133
def avatar_path(self) -> str:
124134
return self.defaults.avatar_path
125135

126136
@property
127137
def system_prompt(self) -> str:
128138
return self.defaults.system_prompt
129-
139+
130140
def as_user(self) -> User:
131141
return User(
132142
username=self.id,
133143
name=self.name,
134144
display_name=self.name,
135145
avatar_url=self.avatar_path,
136146
)
137-
147+
138148
def as_user_dict(self) -> Dict[str, Any]:
139149
user = self.as_user()
140150
return asdict(user)
141-
142-
async def forward_reply_stream(self, reply_stream: 'AsyncIterator'):
151+
152+
async def forward_reply_stream(self, reply_stream: "AsyncIterator"):
143153
"""
144154
Forwards an async iterator, dubbed the 'reply stream', to a new message
145155
by this persona in the YChat.
@@ -152,13 +162,13 @@ async def forward_reply_stream(self, reply_stream: 'AsyncIterator'):
152162
stream_id: Optional[str] = None
153163

154164
try:
155-
self.awareness.set_local_state_field('isWriting', True)
165+
self.awareness.set_local_state_field("isWriting", True)
156166
async for chunk in reply_stream:
157167
if not stream_id:
158168
stream_id = self.ychat.add_message(
159169
NewMessage(body="", sender=self.id)
160170
)
161-
171+
162172
assert stream_id
163173
self.ychat.update_message(
164174
Message(
@@ -171,8 +181,9 @@ async def forward_reply_stream(self, reply_stream: 'AsyncIterator'):
171181
append=True,
172182
)
173183
except Exception as e:
174-
self.log.error(f"Persona '{self.name}' encountered an exception printed below when attempting to stream output.")
184+
self.log.error(
185+
f"Persona '{self.name}' encountered an exception printed below when attempting to stream output."
186+
)
175187
self.log.exception(e)
176188
finally:
177-
self.awareness.set_local_state_field('isWriting', False)
178-
189+
self.awareness.set_local_state_field("isWriting", False)
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .jupyternaut import JupyternautPersona
1+
from .jupyternaut import JupyternautPersona

packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from typing import Any
22

3+
from jupyterlab_chat.models import Message
34
from langchain_core.output_parsers import StrOutputParser
45
from langchain_core.runnables.history import RunnableWithMessageHistory
56

7+
from ...history import YChatHistory
68
from ..base_persona import BasePersona, PersonaDefaults
79
from .prompt_template import JUPYTERNAUT_PROMPT_TEMPLATE, JupyternautVariables
8-
from ...history import YChatHistory
9-
from jupyterlab_chat.models import Message
10+
1011

1112
class JupyternautPersona(BasePersona):
1213
"""
@@ -22,19 +23,19 @@ def defaults(self):
2223
name="Jupyternaut",
2324
avatar_path="/api/ai/static/jupyternaut.svg",
2425
description="The standard agent provided by JupyterLab. Currently has no tools.",
25-
system_prompt="..."
26+
system_prompt="...",
2627
)
27-
28+
2829
async def process_message(self, message: Message):
2930
provider_name = self.config.lm_provider.name
30-
model_id = self.config.lm_provider_params['model_id']
31+
model_id = self.config.lm_provider_params["model_id"]
3132

3233
runnable = self.build_runnable()
3334
variables = JupyternautVariables(
3435
input=message.body,
3536
model_id=model_id,
3637
provider_name=provider_name,
37-
persona_name=self.name
38+
persona_name=self.name,
3839
)
3940
variables_dict = variables.model_dump()
4041
reply_stream = runnable.astream(variables_dict)

packages/jupyter-ai/jupyter_ai/personas/jupyternaut/prompt_template.py

+13-10
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from typing import Optional
22

3-
from pydantic import BaseModel
43
from langchain.prompts import (
54
ChatPromptTemplate,
65
HumanMessagePromptTemplate,
76
MessagesPlaceholder,
87
SystemMessagePromptTemplate,
98
)
9+
from pydantic import BaseModel
1010

1111
_JUPYTERNAUT_SYSTEM_PROMPT_FORMAT = """
1212
<instructions>
@@ -48,23 +48,26 @@
4848
</context>
4949
""".strip()
5050

51-
JUPYTERNAUT_PROMPT_TEMPLATE = ChatPromptTemplate.from_messages([
52-
SystemMessagePromptTemplate.from_template(
53-
_JUPYTERNAUT_SYSTEM_PROMPT_FORMAT,
54-
template_format="jinja2"
55-
),
56-
MessagesPlaceholder(variable_name="history"),
57-
HumanMessagePromptTemplate.from_template("{input}")
58-
])
51+
JUPYTERNAUT_PROMPT_TEMPLATE = ChatPromptTemplate.from_messages(
52+
[
53+
SystemMessagePromptTemplate.from_template(
54+
_JUPYTERNAUT_SYSTEM_PROMPT_FORMAT, template_format="jinja2"
55+
),
56+
MessagesPlaceholder(variable_name="history"),
57+
HumanMessagePromptTemplate.from_template("{input}"),
58+
]
59+
)
60+
5961

6062
class JupyternautVariables(BaseModel):
6163
"""
6264
Variables expected by `JUPYTERNAUT_PROMPT_TEMPLATE`, defined as a Pydantic
63-
data model for developer convenience.
65+
data model for developer convenience.
6466
6567
Call the `.model_dump()` method on an instance to convert it to a Python
6668
dictionary.
6769
"""
70+
6871
input: str
6972
persona_name: str
7073
provider_name: str

0 commit comments

Comments
 (0)