
Commit d61b1a5

Create the pipelines only once in the copilot provider
1 parent 483e6ae commit d61b1a5

File tree

src/codegate/providers/copilot/pipeline.py
src/codegate/providers/copilot/provider.py

2 files changed: +27 -11 lines changed

src/codegate/providers/copilot/pipeline.py (+12 -6)
@@ -24,6 +24,7 @@ class CopilotPipeline(ABC):
 
     def __init__(self, pipeline_factory: PipelineFactory):
         self.pipeline_factory = pipeline_factory
+        self.instance = self._create_pipeline()
         self.normalizer = self._create_normalizer()
         self.provider_name = "openai"
 
@@ -33,7 +34,7 @@ def _create_normalizer(self):
         pass
 
     @abstractmethod
-    def create_pipeline(self) -> SequentialPipelineProcessor:
+    def _create_pipeline(self) -> SequentialPipelineProcessor:
         """Each strategy defines which pipeline to create"""
         pass
 
@@ -84,7 +85,9 @@ def _create_shortcut_response(result: PipelineResult, model: str) -> bytes:
         body = response.model_dump_json(exclude_none=True, exclude_unset=True).encode()
         return body
 
-    async def process_body(self, headers: list[str], body: bytes) -> Tuple[bytes, PipelineContext]:
+    async def process_body(
+        self, headers: list[str], body: bytes,
+    ) -> Tuple[bytes, PipelineContext | None]:
         """Common processing logic for all strategies"""
         try:
             normalized_body = self.normalizer.normalize(body)
@@ -97,8 +100,7 @@ async def process_body(self, headers: list[str], body: bytes) -> Tuple[bytes, Pi
                 except ValueError:
                     continue
 
-            pipeline = self.create_pipeline()
-            result = await pipeline.process_request(
+            result = await self.instance.process_request(
                 request=normalized_body,
                 provider=self.provider_name,
                 model=normalized_body.get("model", "gpt-4o-mini"),
@@ -167,11 +169,13 @@ class CopilotFimPipeline(CopilotPipeline):
     A pipeline for the FIM format used by Copilot. Combines the normalizer for the FIM
     format and the FIM pipeline used by all providers.
     """
+    def __init__(self, pipeline_factory: PipelineFactory):
+        super().__init__(pipeline_factory)
 
     def _create_normalizer(self):
         return CopilotFimNormalizer()
 
-    def create_pipeline(self) -> SequentialPipelineProcessor:
+    def _create_pipeline(self) -> SequentialPipelineProcessor:
         return self.pipeline_factory.create_fim_pipeline()
 
 
@@ -180,9 +184,11 @@ class CopilotChatPipeline(CopilotPipeline):
     A pipeline for the Chat format used by Copilot. Combines the normalizer for the FIM
     format and the FIM pipeline used by all providers.
    """
+    def __init__(self, pipeline_factory: PipelineFactory):
+        super().__init__(pipeline_factory)
 
     def _create_normalizer(self):
         return CopilotChatNormalizer()
 
-    def create_pipeline(self) -> SequentialPipelineProcessor:
+    def _create_pipeline(self) -> SequentialPipelineProcessor:
         return self.pipeline_factory.create_input_pipeline()
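
The net effect in pipeline.py is that the abstract hook is renamed to _create_pipeline(), called exactly once in the constructor, and process_body() reuses the cached self.instance instead of rebuilding the pipeline on every request. Below is a minimal, self-contained sketch of that construct-once pattern; FakeProcessor, BasePipeline and ChatPipeline are simplified stand-ins for the real codegate classes, not the actual implementation.

    from abc import ABC, abstractmethod
    import asyncio


    class FakeProcessor:
        """Stand-in for SequentialPipelineProcessor."""

        async def process_request(self, request: dict, provider: str, model: str) -> dict:
            return {"provider": provider, "model": model, "request": request}


    class BasePipeline(ABC):
        """Stand-in for CopilotPipeline: the processor is built once, in __init__."""

        def __init__(self, pipeline_factory) -> None:
            self.pipeline_factory = pipeline_factory
            # Built exactly once per pipeline object, not once per request.
            self.instance = self._create_pipeline()

        @abstractmethod
        def _create_pipeline(self) -> FakeProcessor:
            """Each subclass decides which processor to build."""

        async def process_body(self, body: dict) -> dict:
            # Every request reuses the cached processor.
            return await self.instance.process_request(
                request=body,
                provider="openai",
                model=body.get("model", "gpt-4o-mini"),
            )


    class ChatPipeline(BasePipeline):
        def _create_pipeline(self) -> FakeProcessor:
            # The real code calls self.pipeline_factory.create_input_pipeline() here.
            return self.pipeline_factory()


    pipeline = ChatPipeline(FakeProcessor)   # processor is created here, once
    result = asyncio.run(pipeline.process_body({"model": "gpt-4o-mini"}))
    print(result)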

src/codegate/providers/copilot/provider.py (+15 -5)
@@ -150,8 +150,16 @@ def __init__(self, loop: asyncio.AbstractEventLoop):
         self.cert_manager = TLSCertDomainManager(self.ca)
         self._closing = False
         self.pipeline_factory = PipelineFactory(SecretsManager())
+        self.input_pipeline: Optional[CopilotPipeline] = None
+        self.fim_pipeline: Optional[CopilotPipeline] = None
+        # the context as provided by the pipeline
         self.context_tracking: Optional[PipelineContext] = None
 
+    def _ensure_pipelines(self):
+        if not self.input_pipeline or not self.fim_pipeline:
+            self.input_pipeline = CopilotChatPipeline(self.pipeline_factory)
+            self.fim_pipeline = CopilotFimPipeline(self.pipeline_factory)
+
     def _select_pipeline(self, method: str, path: str) -> Optional[CopilotPipeline]:
         if method != "POST":
             logger.debug("Not a POST request, no pipeline selected")
@@ -161,10 +169,10 @@ def _select_pipeline(self, method: str, path: str) -> Optional[CopilotPipeline]:
             if path == route.path:
                 if route.pipeline_type == PipelineType.FIM:
                     logger.debug("Selected FIM pipeline")
-                    return CopilotFimPipeline(self.pipeline_factory)
+                    return self.fim_pipeline
                 elif route.pipeline_type == PipelineType.CHAT:
                     logger.debug("Selected CHAT pipeline")
-                    return CopilotChatPipeline(self.pipeline_factory)
+                    return self.input_pipeline
 
         logger.debug("No pipeline selected")
         return None
@@ -181,7 +189,6 @@ async def _body_through_pipeline(
             # if we didn't select any strategy that would change the request
             # let's just pass through the body as-is
             return body, None
-        logger.debug(f"Processing body through pipeline: {len(body)} bytes")
         return await strategy.process_body(headers, body)
 
     async def _request_to_target(self, headers: list[str], body: bytes):
@@ -288,6 +295,8 @@ async def _forward_data_through_pipeline(self, data: bytes) -> Union[HttpRequest
             http_request.headers,
             http_request.body,
         )
+        # TODO: it's weird that we're overwriting the context. Should we set the context once? Maybe when
+        # creating the pipeline instance?
        self.context_tracking = context
 
        if context and context.shortcut_response:
@@ -431,7 +440,6 @@ def data_received(self, data: bytes) -> None:
        Handle received data from client. Since we need to process the complete body
        through our pipeline before forwarding, we accumulate the entire request first.
        """
-        logger.info(f"Received data from {self.peername}: {data}")
        try:
            if not self._check_buffer_size(data):
                self.send_error_response(413, b"Request body too large")
@@ -442,6 +450,7 @@ def data_received(self, data: bytes) -> None:
            if not self.headers_parsed:
                self.headers_parsed = self.parse_headers()
                if self.headers_parsed:
+                    self._ensure_pipelines()
                    if self.request.method == "CONNECT":
                        self.handle_connect()
                        self.buffer.clear()
@@ -452,7 +461,6 @@ def data_received(self, data: bytes) -> None:
                if self._has_complete_body():
                    # Process the complete request through the pipeline
                    complete_request = bytes(self.buffer)
-                    logger.debug(f"Complete request: {complete_request}")
                    self.buffer.clear()
                    asyncio.create_task(self._forward_data_to_target(complete_request))
 
@@ -756,10 +764,12 @@ def connection_made(self, transport: asyncio.Transport) -> None:
 
    def _ensure_output_processor(self) -> None:
        if self.proxy.context_tracking is None:
+            logger.debug("No context tracking, no need to process pipeline")
            # No context tracking, no need to process pipeline
            return
 
        if self.sse_processor is not None:
+            logger.debug("Already initialized, no need to reinitialize")
            # Already initialized, no need to reinitialize
            return
 