diff --git a/migrations/versions/5c2f3eee5f90_introduce_workspaces.py b/migrations/versions/5c2f3eee5f90_introduce_workspaces.py
new file mode 100644
index 00000000..d928b852
--- /dev/null
+++ b/migrations/versions/5c2f3eee5f90_introduce_workspaces.py
@@ -0,0 +1,46 @@
+"""introduce workspaces
+
+Revision ID: 5c2f3eee5f90
+Revises: 30d0144e1a50
+Create Date: 2025-01-15 19:27:08.230296
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = '5c2f3eee5f90'
+down_revision: Union[str, None] = '30d0144e1a50'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # Workspaces table
+    op.execute(
+        """
+        CREATE TABLE workspaces (
+            id TEXT PRIMARY KEY,  -- UUID stored as TEXT
+            name TEXT NOT NULL,
+            is_active BOOLEAN NOT NULL DEFAULT 0
+        );
+        """
+    )
+    op.execute("INSERT INTO workspaces (id, name, is_active) VALUES ('1', 'default', 1);")
+    # Alter table prompts
+    op.execute("ALTER TABLE prompts ADD COLUMN workspace_id TEXT REFERENCES workspaces(id);")
+    op.execute("UPDATE prompts SET workspace_id = '1';")
+    # Create index for workspace_id
+    op.execute("CREATE INDEX idx_prompts_workspace_id ON prompts (workspace_id);")
+
+
+def downgrade() -> None:
+    # Revert in reverse order of upgrade().
+    op.execute("DROP INDEX IF EXISTS idx_prompts_workspace_id;")
+    # Older SQLite versions cannot DROP COLUMN directly; batch mode lets
+    # Alembic rebuild the table when needed.
+    with op.batch_alter_table("prompts") as batch_op:
+        batch_op.drop_column("workspace_id")
+    op.execute("DROP TABLE IF EXISTS workspaces;")
diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py
index 4894ad2a..88aff1f3 100644
--- a/src/codegate/db/connection.py
+++ b/src/codegate/db/connection.py
@@ -18,6 +18,7 @@
     GetPromptWithOutputsRow,
     Output,
     Prompt,
+    Workspace,
 )
 
 from codegate.pipeline.base import PipelineContext
@@ -286,6 +287,18 @@ async def get_alerts_with_prompt_and_output(self) -> List[GetAlertsWithPromptAnd
         prompts = await self._execute_select_pydantic_model(GetAlertsWithPromptAndOutputRow, sql)
         return prompts
 
+    async def get_workspaces(self) -> List[Workspace]:
+        sql = text(
+            """
+            SELECT
+                id, name, is_active
+            FROM workspaces
+            ORDER BY is_active DESC
+            """
+        )
+        workspaces = await self._execute_select_pydantic_model(Workspace, sql)
+        return workspaces
+
 
 def init_db_sync(db_path: Optional[str] = None):
     """DB will be initialized in the constructor in case it doesn't exist."""
diff --git a/src/codegate/db/models.py b/src/codegate/db/models.py
index 22859573..5ddc6148 100644
--- a/src/codegate/db/models.py
+++ b/src/codegate/db/models.py
@@ -26,6 +26,7 @@ class Prompt(pydantic.BaseModel):
     provider: Optional[Any]
     request: Any
     type: Any
+    workspace_id: Optional[Any]
 
 
 class Setting(pydantic.BaseModel):
@@ -37,4 +38,10 @@ class Setting(pydantic.BaseModel):
     other_settings: Optional[Any]
 
 
+class Workspace(pydantic.BaseModel):
+    id: Any
+    name: str
+    is_active: bool = False
+
+
 # Models for select queries
diff --git a/src/codegate/pipeline/base.py b/src/codegate/pipeline/base.py
index e22b2915..13231224 100644
--- a/src/codegate/pipeline/base.py
+++ b/src/codegate/pipeline/base.py
@@ -135,6 +135,7 @@ def add_input_request(
             provider=provider,
             type="fim" if is_fim_request else "chat",
             request=request_str,
+            workspace_id="1",  # TODO: This is a placeholder for now, using default workspace
         )
         # Uncomment the below to debug the input
         # logger.debug(f"Added input request to context: {self.input_request}")
diff --git a/src/codegate/pipeline/factory.py b/src/codegate/pipeline/factory.py
index 7a713332..cd14df8e 100644
--- a/src/codegate/pipeline/factory.py
+++ b/src/codegate/pipeline/factory.py
@@ -14,6 +14,7 @@
 )
 from codegate.pipeline.system_prompt.codegate import SystemPrompt
 from codegate.pipeline.version.version import CodegateVersion
+from codegate.pipeline.workspace.workspace import CodegateWorkspace
 
 
 class PipelineFactory:
@@ -28,6 +29,7 @@ def create_input_pipeline(self) -> SequentialPipelineProcessor:
             # later steps
             CodegateSecrets(),
             CodegateVersion(),
+            CodegateWorkspace(),
             CodeSnippetExtractor(),
             CodegateContextRetriever(),
             SystemPrompt(Config.get_config().prompts.default_chat),
diff --git a/src/codegate/pipeline/workspace/__init__.py b/src/codegate/pipeline/workspace/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/codegate/pipeline/workspace/workspace.py b/src/codegate/pipeline/workspace/workspace.py
new file mode 100644
index 00000000..da950357
--- /dev/null
+++ b/src/codegate/pipeline/workspace/workspace.py
@@ -0,0 +1,112 @@
+from litellm import ChatCompletionRequest
+
+from codegate.db.connection import DbReader
+from codegate.pipeline.base import (
+    PipelineContext,
+    PipelineResponse,
+    PipelineResult,
+    PipelineStep,
+)
+
+
+class WorkspaceCommands:
+
+    def __init__(self):
+        self._db_reader = DbReader()
+        self.commands = {
+            "list": self._list_workspaces,
+        }
+
+    async def _list_workspaces(self, *args):
+        """
+        List all workspaces
+        """
+        workspaces = await self._db_reader.get_workspaces()
+        respond_str = ""
+        for workspace in workspaces:
+            respond_str += f"{workspace.id} - {workspace.name}"
+            if workspace.is_active:
+                respond_str += " (active)"
+            respond_str += "\n"
+        return respond_str
+
+    async def execute(self, command: str, *args) -> str:
+        """
+        Execute the given command
+
+        Args:
+            command (str): The command to execute
+        """
+        command_to_execute = self.commands.get(command)
+        if command_to_execute is not None:
+            return await command_to_execute(*args)
+        else:
+            return "Command not found"
+
+    async def parse_execute_cmd(self, last_user_message: str) -> str:
+        """
+        Parse the last user message and execute the command
+
+        Args:
+            last_user_message (str): The last user message
+        """
+        # Locate the trigger case-insensitively: the caller matches it on a
+        # lowered copy of the message, so the original casing may differ.
+        trigger = "codegate-workspace"
+        idx = last_user_message.lower().find(trigger)
+        if idx == -1:
+            return "Command not found"
+        command_and_args = last_user_message[idx + len(trigger):].strip()
+        if not command_and_args:
+            return "Command not found"
+        command, *args = command_and_args.split()
+        return await self.execute(command, *args)
+
+
+class CodegateWorkspace(PipelineStep):
+    """Pipeline step that handles workspace information requests."""
+
+    @property
+    def name(self) -> str:
+        """
+        Returns the name of this pipeline step.
+
+        Returns:
+            str: The identifier 'codegate-workspace'
+        """
+        return "codegate-workspace"
+
+    async def process(
+        self, request: ChatCompletionRequest, context: PipelineContext
+    ) -> PipelineResult:
+        """
+        Checks if the last user message contains "codegate-workspace" and
+        responds with command specified.
+        This short-circuits the pipeline if the message is found.
+
+        Args:
+            request (ChatCompletionRequest): The chat completion request to process
+            context (PipelineContext): The current pipeline context
+
+        Returns:
+            PipelineResult: Contains workspace response if triggered, otherwise continues
+            pipeline
+        """
+        last_user_message = self.get_last_user_message(request)
+
+        if last_user_message is not None:
+            last_user_message_str, _ = last_user_message
+            if "codegate-workspace" in last_user_message_str.lower():
+                context.shortcut_response = True
+                command_output = await WorkspaceCommands().parse_execute_cmd(last_user_message_str)
+                return PipelineResult(
+                    response=PipelineResponse(
+                        step_name=self.name,
+                        content=command_output,
+                        model=request["model"],
+                    ),
+                    context=context,
+                )
+
+        # Fall through
+        return PipelineResult(request=request, context=context)