Skip to content

Commit 30e56e3

Browse files
committed
Initialize pipelines per instance in the base providers, too
1 parent 3e4790d commit 30e56e3

File tree

8 files changed

+30
-94
lines changed

8 files changed

+30
-94
lines changed

src/codegate/providers/anthropic/provider.py

+3-11
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
import json
2-
from typing import Optional
32

43
import structlog
54
from fastapi import Header, HTTPException, Request
65

7-
from codegate.pipeline.base import SequentialPipelineProcessor
8-
from codegate.pipeline.output import OutputPipelineProcessor
6+
from codegate.pipeline.factory import PipelineFactory
97
from codegate.providers.anthropic.adapter import AnthropicInputNormalizer, AnthropicOutputNormalizer
108
from codegate.providers.anthropic.completion_handler import AnthropicCompletion
119
from codegate.providers.base import BaseProvider
@@ -15,20 +13,14 @@
1513
class AnthropicProvider(BaseProvider):
1614
def __init__(
1715
self,
18-
pipeline_processor: Optional[SequentialPipelineProcessor] = None,
19-
fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None,
20-
output_pipeline_processor: Optional[OutputPipelineProcessor] = None,
21-
fim_output_pipeline_processor: Optional[OutputPipelineProcessor] = None,
16+
pipeline_factory: PipelineFactory,
2217
):
2318
completion_handler = AnthropicCompletion(stream_generator=anthropic_stream_generator)
2419
super().__init__(
2520
AnthropicInputNormalizer(),
2621
AnthropicOutputNormalizer(),
2722
completion_handler,
28-
pipeline_processor,
29-
fim_pipeline_processor,
30-
output_pipeline_processor,
31-
fim_output_pipeline_processor,
23+
pipeline_factory,
3224
)
3325

3426
@property

src/codegate/providers/base.py

+8-14
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
from codegate.pipeline.base import (
1111
PipelineContext,
1212
PipelineResult,
13-
SequentialPipelineProcessor,
1413
)
15-
from codegate.pipeline.output import OutputPipelineInstance, OutputPipelineProcessor
14+
from codegate.pipeline.factory import PipelineFactory
15+
from codegate.pipeline.output import OutputPipelineInstance
1616
from codegate.providers.completion.base import BaseCompletionHandler
1717
from codegate.providers.formatting.input_pipeline import PipelineResponseFormatter
1818
from codegate.providers.normalizer.base import ModelInputNormalizer, ModelOutputNormalizer
@@ -34,19 +34,13 @@ def __init__(
3434
input_normalizer: ModelInputNormalizer,
3535
output_normalizer: ModelOutputNormalizer,
3636
completion_handler: BaseCompletionHandler,
37-
pipeline_processor: Optional[SequentialPipelineProcessor] = None,
38-
fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None,
39-
output_pipeline_processor: Optional[OutputPipelineProcessor] = None,
40-
fim_output_pipeline_processor: Optional[OutputPipelineProcessor] = None,
37+
pipeline_factory: PipelineFactory,
4138
):
4239
self.router = APIRouter()
4340
self._completion_handler = completion_handler
4441
self._input_normalizer = input_normalizer
4542
self._output_normalizer = output_normalizer
46-
self._pipeline_processor = pipeline_processor
47-
self._fim_pipelin_processor = fim_pipeline_processor
48-
self._output_pipeline_processor = output_pipeline_processor
49-
self._fim_output_pipeline_processor = fim_output_pipeline_processor
43+
self._pipeline_factory = pipeline_factory
5044
self._db_recorder = DbRecorder()
5145
self._pipeline_response_formatter = PipelineResponseFormatter(
5246
output_normalizer, self._db_recorder
@@ -73,10 +67,10 @@ async def _run_output_stream_pipeline(
7367
# Decide which pipeline processor to use
7468
out_pipeline_processor = None
7569
if is_fim_request:
76-
out_pipeline_processor = self._fim_output_pipeline_processor
70+
out_pipeline_processor = self._pipeline_factory.create_fim_output_pipeline()
7771
logger.info("FIM pipeline selected for output.")
7872
else:
79-
out_pipeline_processor = self._output_pipeline_processor
73+
out_pipeline_processor = self._pipeline_factory.create_output_pipeline()
8074
logger.info("Chat completion pipeline selected for output.")
8175
if out_pipeline_processor is None:
8276
logger.info("No output pipeline processor found, passing through")
@@ -117,11 +111,11 @@ async def _run_input_pipeline(
117111
) -> PipelineResult:
118112
# Decide which pipeline processor to use
119113
if is_fim_request:
120-
pipeline_processor = self._fim_pipelin_processor
114+
pipeline_processor = self._pipeline_factory.create_fim_pipeline()
121115
logger.info("FIM pipeline selected for execution.")
122116
normalized_request = self._fim_normalizer.normalize(normalized_request)
123117
else:
124-
pipeline_processor = self._pipeline_processor
118+
pipeline_processor = self._pipeline_factory.create_input_pipeline()
125119
logger.info("Chat completion pipeline selected for execution.")
126120
if pipeline_processor is None:
127121
return PipelineResult(request=normalized_request)

src/codegate/providers/llamacpp/provider.py

+3-11
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
import json
2-
from typing import Optional
32

43
import structlog
54
from fastapi import HTTPException, Request
65

7-
from codegate.pipeline.base import SequentialPipelineProcessor
8-
from codegate.pipeline.output import OutputPipelineProcessor
6+
from codegate.pipeline.factory import PipelineFactory
97
from codegate.providers.base import BaseProvider
108
from codegate.providers.llamacpp.completion_handler import LlamaCppCompletionHandler
119
from codegate.providers.llamacpp.normalizer import LLamaCppInputNormalizer, LLamaCppOutputNormalizer
@@ -14,20 +12,14 @@
1412
class LlamaCppProvider(BaseProvider):
1513
def __init__(
1614
self,
17-
pipeline_processor: Optional[SequentialPipelineProcessor] = None,
18-
fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None,
19-
output_pipeline_processor: Optional[OutputPipelineProcessor] = None,
20-
fim_output_pipeline_processor: Optional[OutputPipelineProcessor] = None,
15+
pipeline_factory: PipelineFactory,
2116
):
2217
completion_handler = LlamaCppCompletionHandler()
2318
super().__init__(
2419
LLamaCppInputNormalizer(),
2520
LLamaCppOutputNormalizer(),
2621
completion_handler,
27-
pipeline_processor,
28-
fim_pipeline_processor,
29-
output_pipeline_processor,
30-
fim_output_pipeline_processor,
22+
pipeline_factory,
3123
)
3224

3325
@property

src/codegate/providers/ollama/provider.py

+3-10
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
import json
2-
from typing import Optional
32

43
import httpx
54
import structlog
65
from fastapi import HTTPException, Request
76

87
from codegate.config import Config
9-
from codegate.pipeline.base import SequentialPipelineProcessor
10-
from codegate.pipeline.output import OutputPipelineProcessor
8+
from codegate.pipeline.factory import PipelineFactory
119
from codegate.providers.base import BaseProvider
1210
from codegate.providers.ollama.adapter import OllamaInputNormalizer, OllamaOutputNormalizer
1311
from codegate.providers.ollama.completion_handler import OllamaShim
@@ -16,10 +14,7 @@
1614
class OllamaProvider(BaseProvider):
1715
def __init__(
1816
self,
19-
pipeline_processor: Optional[SequentialPipelineProcessor] = None,
20-
fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None,
21-
output_pipeline_processor: Optional[OutputPipelineProcessor] = None,
22-
fim_output_pipeline_processor: Optional[OutputPipelineProcessor] = None,
17+
pipeline_factory: PipelineFactory,
2318
):
2419
config = Config.get_config()
2520
if config is None:
@@ -32,9 +27,7 @@ def __init__(
3227
OllamaInputNormalizer(),
3328
OllamaOutputNormalizer(),
3429
completion_handler,
35-
pipeline_processor,
36-
fim_pipeline_processor,
37-
output_pipeline_processor,
30+
pipeline_factory,
3831
)
3932

4033
@property

src/codegate/providers/openai/provider.py

+3-11
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
import json
2-
from typing import Optional
32

43
import structlog
54
from fastapi import Header, HTTPException, Request
65

7-
from codegate.pipeline.base import SequentialPipelineProcessor
8-
from codegate.pipeline.output import OutputPipelineProcessor
6+
from codegate.pipeline.factory import PipelineFactory
97
from codegate.providers.base import BaseProvider
108
from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator
119
from codegate.providers.openai.adapter import OpenAIInputNormalizer, OpenAIOutputNormalizer
@@ -14,20 +12,14 @@
1412
class OpenAIProvider(BaseProvider):
1513
def __init__(
1614
self,
17-
pipeline_processor: Optional[SequentialPipelineProcessor] = None,
18-
fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None,
19-
output_pipeline_processor: Optional[OutputPipelineProcessor] = None,
20-
fim_output_pipeline_processor: Optional[OutputPipelineProcessor] = None,
15+
pipeline_factory: PipelineFactory,
2116
):
2217
completion_handler = LiteLLmShim(stream_generator=sse_stream_generator)
2318
super().__init__(
2419
OpenAIInputNormalizer(),
2520
OpenAIOutputNormalizer(),
2621
completion_handler,
27-
pipeline_processor,
28-
fim_pipeline_processor,
29-
output_pipeline_processor,
30-
fim_output_pipeline_processor,
22+
pipeline_factory,
3123
)
3224

3325
@property

src/codegate/providers/vllm/provider.py

+3-11
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
import json
2-
from typing import Optional
32

43
import httpx
54
import structlog
65
from fastapi import Header, HTTPException, Request
76
from litellm import atext_completion
87

98
from codegate.config import Config
10-
from codegate.pipeline.base import SequentialPipelineProcessor
11-
from codegate.pipeline.output import OutputPipelineProcessor
9+
from codegate.pipeline.factory import PipelineFactory
1210
from codegate.providers.base import BaseProvider
1311
from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator
1412
from codegate.providers.vllm.adapter import VLLMInputNormalizer, VLLMOutputNormalizer
@@ -17,10 +15,7 @@
1715
class VLLMProvider(BaseProvider):
1816
def __init__(
1917
self,
20-
pipeline_processor: Optional[SequentialPipelineProcessor] = None,
21-
fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None,
22-
output_pipeline_processor: Optional[OutputPipelineProcessor] = None,
23-
fim_output_pipeline_processor: Optional[OutputPipelineProcessor] = None,
18+
pipeline_factory: PipelineFactory,
2419
):
2520
completion_handler = LiteLLmShim(
2621
stream_generator=sse_stream_generator, fim_completion_func=atext_completion
@@ -29,10 +24,7 @@ def __init__(
2924
VLLMInputNormalizer(),
3025
VLLMOutputNormalizer(),
3126
completion_handler,
32-
pipeline_processor,
33-
fim_pipeline_processor,
34-
output_pipeline_processor,
35-
fim_output_pipeline_processor,
27+
pipeline_factory,
3628
)
3729

3830
@property

src/codegate/server.py

+5-22
Original file line numberDiff line numberDiff line change
@@ -51,47 +51,30 @@ def init_app(pipeline_factory: PipelineFactory) -> FastAPI:
5151
# Register all known providers
5252
registry.add_provider(
5353
"openai",
54-
OpenAIProvider(
55-
pipeline_processor=pipeline_factory.create_input_pipeline(),
56-
fim_pipeline_processor=pipeline_factory.create_fim_pipeline(),
57-
output_pipeline_processor=pipeline_factory.create_output_pipeline(),
58-
fim_output_pipeline_processor=pipeline_factory.create_fim_output_pipeline(),
59-
),
54+
OpenAIProvider(pipeline_factory),
6055
)
6156
registry.add_provider(
6257
"anthropic",
6358
AnthropicProvider(
64-
pipeline_processor=pipeline_factory.create_input_pipeline(),
65-
fim_pipeline_processor=pipeline_factory.create_fim_pipeline(),
66-
output_pipeline_processor=pipeline_factory.create_output_pipeline(),
67-
fim_output_pipeline_processor=pipeline_factory.create_fim_output_pipeline(),
59+
pipeline_factory,
6860
),
6961
)
7062
registry.add_provider(
7163
"llamacpp",
7264
LlamaCppProvider(
73-
pipeline_processor=pipeline_factory.create_input_pipeline(),
74-
fim_pipeline_processor=pipeline_factory.create_fim_pipeline(),
75-
output_pipeline_processor=pipeline_factory.create_output_pipeline(),
76-
fim_output_pipeline_processor=pipeline_factory.create_fim_output_pipeline(),
65+
pipeline_factory,
7766
),
7867
)
7968
registry.add_provider(
8069
"vllm",
8170
VLLMProvider(
82-
pipeline_processor=pipeline_factory.create_input_pipeline(),
83-
fim_pipeline_processor=pipeline_factory.create_fim_pipeline(),
84-
output_pipeline_processor=pipeline_factory.create_output_pipeline(),
85-
fim_output_pipeline_processor=pipeline_factory.create_fim_output_pipeline(),
71+
pipeline_factory,
8672
),
8773
)
8874
registry.add_provider(
8975
"ollama",
9076
OllamaProvider(
91-
pipeline_processor=pipeline_factory.create_input_pipeline(),
92-
fim_pipeline_processor=pipeline_factory.create_fim_pipeline(),
93-
output_pipeline_processor=pipeline_factory.create_output_pipeline(),
94-
fim_output_pipeline_processor=pipeline_factory.create_fim_output_pipeline(),
77+
pipeline_factory,
9578
),
9679
)
9780

tests/test_provider.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,12 @@ def __init__(self):
1111
mocked_input_normalizer = MagicMock()
1212
mocked_output_normalizer = MagicMock()
1313
mocked_completion_handler = MagicMock()
14-
mocked_pipepeline = MagicMock()
15-
mocked_fim_pipeline = MagicMock()
14+
mocked_factory = MagicMock()
1615
super().__init__(
1716
mocked_input_normalizer,
1817
mocked_output_normalizer,
1918
mocked_completion_handler,
20-
mocked_pipepeline,
21-
mocked_fim_pipeline,
19+
mocked_factory,
2220
)
2321

2422
def _setup_routes(self) -> None:

0 commit comments

Comments
 (0)